1use flodl::*;
10use flodl::monitor::Monitor;
11
12fn main() -> Result<()> {
13 let opts = TensorOptions::default();
14
15 let model = FlowBuilder::from(Linear::new(4, 16)?)
17 .through(GELU)
18 .tag("hidden")
19 .through(LayerNorm::new(16)?)
20 .also(Linear::new(16, 16)?)
21 .tag("residual")
22 .through(Linear::new(16, 4)?)
23 .build()?;
24
25 let params = model.parameters();
26 let mut optimizer = Adam::new(¶ms, 0.01);
27 let scheduler = CosineScheduler::new(0.01, 1e-4, 200);
28 model.train();
29
30 let num_epochs = 200usize;
31 let mut monitor = Monitor::new(num_epochs);
32
33 let x_data = Tensor::randn(&[128, 4], opts)?;
35 let y_data = Tensor::randn(&[128, 4], opts)?;
36
37 let x_batches = x_data.batches(32)?;
39 let y_batches = y_data.batches(32)?;
40 let batches: Vec<_> = x_batches.into_iter().zip(y_batches).collect();
41
42 println!("{:>5} {:>10} {:>10} {:>8} {:>10}", "epoch", "loss", "slope", "stalled", "status");
43 println!("{}", "-".repeat(50));
44
45 for epoch in 0..num_epochs {
46 let t = std::time::Instant::now();
47
48 for (xb, yb) in &batches {
49 let input = Variable::new(xb.clone(), true);
50 let target = Variable::new(yb.clone(), false);
51
52 optimizer.zero_grad();
53 let pred = model.forward(&input)?;
54
55 model.collect(&["hidden", "residual"])?;
57
58 let loss = mse_loss(&pred, &target)?;
59 loss.backward()?;
60 clip_grad_norm(¶ms, 1.0)?;
61 optimizer.step()?;
62
63 model.record_scalar("loss", loss.item()?);
64 }
65
66 let lr = scheduler.lr(epoch);
68 optimizer.set_lr(lr);
69 model.record_scalar("lr", lr);
70 model.flush(&["hidden", "residual", "loss", "lr"]);
71 monitor.log(epoch, t.elapsed(), &model);
72
73 let loss_trend = model.trend("loss");
75 let stalled = loss_trend.stalled(10, 1e-5);
76 let converged = loss_trend.converged(10, 1e-5);
77
78 if epoch % 20 == 0 || stalled || converged {
79 let status = if converged {
80 "CONVERGED"
81 } else if stalled {
82 "stalled"
83 } else if loss_trend.improving(10) {
84 "improving"
85 } else {
86 ""
87 };
88
89 println!(
90 "{:>5} {:>10.6} {:>10.6} {:>8} {:>10}",
91 epoch,
92 loss_trend.latest(),
93 loss_trend.slope(10),
94 stalled,
95 status
96 );
97 }
98
99 if converged && epoch > 20 {
101 println!("\nEarly stop at epoch {} — loss converged.", epoch);
102 break;
103 }
104 }
105
106 monitor.finish();
107
108 let group = model.trends(&["hidden", "residual", "loss"]);
110 println!("\nFinal trend summary:");
111 println!(" All improving: {}", group.all_improving(10));
112 println!(" Any stalled: {}", group.any_stalled(10, 1e-5));
113 println!(" Mean slope: {:.6}", group.mean_slope(10));
114
115 Ok(())
116}