// schedulers/main.rs
//! LR scheduler composition — warmup + cosine + plateau fallback.
//!
//! Demonstrates composable learning rate scheduling: linear warmup into
//! cosine annealing, with a plateau scheduler as fallback when the
//! primary schedule finishes.
//!
//! Run: `cargo run --example schedulers`

use flodl::*;

11fn main() -> Result<()> {
12    let opts = TensorOptions::default();
13
14    // Build a small model.
15    let model = FlowBuilder::from(Linear::new(4, 16)?)
16        .through(GELU)
17        .through(LayerNorm::new(16)?)
18        .also(Linear::new(16, 16)?)
19        .through(Linear::new(16, 4)?)
20        .build()?;
21
22    let params = model.parameters();
23    let mut optimizer = Adam::new(&params, 0.001);
24    model.train();
25
26    let num_epochs = 100usize;
27
28    // --- Compose schedulers ---
29
30    // Warmup for 10 epochs, then cosine anneal for the remaining 90.
31    let cosine = CosineScheduler::new(0.001, 1e-5, num_epochs - 10);
32    let scheduler = WarmupScheduler::new(cosine, 0.001, 10);
33
34    // Plateau scheduler as a secondary control: if loss stalls,
35    // reduce LR further regardless of the primary schedule.
36    let mut plateau = PlateauScheduler::new(0.001, 10, 0.5, 1e-6);
37
38    println!("{:>5}  {:>10}  {:>10}  {:>10}", "epoch", "loss", "sched_lr", "eff_lr");
39    println!("{}", "-".repeat(40));
40
41    for epoch in 0..num_epochs {
42        let x = Tensor::randn(&[64, 4], opts)?;
43        let y = Tensor::randn(&[64, 4], opts)?;
44        let input = Variable::new(x, true);
45        let target = Variable::new(y, false);
46
47        optimizer.zero_grad();
48        let pred = model.forward(&input)?;
49        let loss = mse_loss(&pred, &target)?;
50        loss.backward()?;
51        clip_grad_norm(&params, 1.0)?;
52        optimizer.step()?;
53
54        let loss_val = loss.item()?;
55
56        // Primary schedule (warmup + cosine).
57        let sched_lr = scheduler.lr(epoch);
58
59        // Plateau feedback — takes the minimum of primary and plateau LR.
60        let plateau_lr = plateau.observe(loss_val);
61        let effective_lr = sched_lr.min(plateau_lr);
62
63        optimizer.set_lr(effective_lr);
64
65        if epoch % 10 == 0 || epoch == num_epochs - 1 {
66            println!(
67                "{:>5}  {:>10.6}  {:>10.6}  {:>10.6}",
68                epoch, loss_val, sched_lr, effective_lr
69            );
70        }
71    }
72
73    println!("\nFinal LR: {:.6}", optimizer.lr());
74    Ok(())
75}