Example: observation/main.rs
1//! Observation and early stopping — collect, flush, trend, stop.
2//!
3//! Demonstrates the full observation workflow: tag graph nodes, collect
4//! per-batch metrics, flush to epoch history, query trends for early
5//! stopping decisions, and export training curves.
6//!
7//! Run: `cargo run --example observation`
8
9use flodl::*;
10use flodl::monitor::Monitor;
11
12fn main() -> Result<()> {
13    let opts = TensorOptions::default();
14
15    // Build a model with tagged intermediate nodes.
16    let model = FlowBuilder::from(Linear::new(4, 16)?)
17        .through(GELU)
18        .tag("hidden")
19        .through(LayerNorm::new(16)?)
20        .also(Linear::new(16, 16)?)
21        .tag("residual")
22        .through(Linear::new(16, 4)?)
23        .build()?;
24
25    let params = model.parameters();
26    let mut optimizer = Adam::new(&params, 0.01);
27    let scheduler = CosineScheduler::new(0.01, 1e-4, 200);
28    model.train();
29
30    let num_epochs = 200usize;
31    let mut monitor = Monitor::new(num_epochs);
32
33    // Generate stable data so trends are meaningful.
34    let x_data = Tensor::randn(&[128, 4], opts)?;
35    let y_data = Tensor::randn(&[128, 4], opts)?;
36
37    // Split into batches.
38    let x_batches = x_data.batches(32)?;
39    let y_batches = y_data.batches(32)?;
40    let batches: Vec<_> = x_batches.into_iter().zip(y_batches).collect();
41
42    println!("{:>5}  {:>10}  {:>10}  {:>8}  {:>10}", "epoch", "loss", "slope", "stalled", "status");
43    println!("{}", "-".repeat(50));
44
45    for epoch in 0..num_epochs {
46        let t = std::time::Instant::now();
47
48        for (xb, yb) in &batches {
49            let input = Variable::new(xb.clone(), true);
50            let target = Variable::new(yb.clone(), false);
51
52            optimizer.zero_grad();
53            let pred = model.forward(&input)?;
54
55            // Collect tagged node outputs as metrics.
56            model.collect(&["hidden", "residual"])?;
57
58            let loss = mse_loss(&pred, &target)?;
59            loss.backward()?;
60            clip_grad_norm(&params, 1.0)?;
61            optimizer.step()?;
62
63            model.record_scalar("loss", loss.item()?);
64        }
65
66        // Flush batch metrics to epoch history.
67        let lr = scheduler.lr(epoch);
68        optimizer.set_lr(lr);
69        model.record_scalar("lr", lr);
70        model.flush(&["hidden", "residual", "loss", "lr"]);
71        monitor.log(epoch, t.elapsed(), &model);
72
73        // Query trends for early stopping (window=10 epochs).
74        let loss_trend = model.trend("loss");
75        let stalled = loss_trend.stalled(10, 1e-5);
76        let converged = loss_trend.converged(10, 1e-5);
77
78        if epoch % 20 == 0 || stalled || converged {
79            let status = if converged {
80                "CONVERGED"
81            } else if stalled {
82                "stalled"
83            } else if loss_trend.improving(10) {
84                "improving"
85            } else {
86                ""
87            };
88
89            println!(
90                "{:>5}  {:>10.6}  {:>10.6}  {:>8}  {:>10}",
91                epoch,
92                loss_trend.latest(),
93                loss_trend.slope(10),
94                stalled,
95                status
96            );
97        }
98
99        // Early stop when converged.
100        if converged && epoch > 20 {
101            println!("\nEarly stop at epoch {} — loss converged.", epoch);
102            break;
103        }
104    }
105
106    monitor.finish();
107
108    // Group trend queries (window=10 for all).
109    let group = model.trends(&["hidden", "residual", "loss"]);
110    println!("\nFinal trend summary:");
111    println!("  All improving: {}", group.all_improving(10));
112    println!("  Any stalled:   {}", group.any_stalled(10, 1e-5));
113    println!("  Mean slope:    {:.6}", group.mean_slope(10));
114
115    Ok(())
116}