Skip to main content

quickstart/
main.rs

//! Quickstart — build, train, and monitor a model with a residual connection.
//!
//! Builds a small graph with a residual connection, trains it on randomly
//! generated data using Adam, and logs progress with the training monitor.
//!
//! Run: `cargo run --example quickstart`
7
8use flodl::*;
9use flodl::monitor::Monitor;
10
11fn main() -> Result<()> {
12    // Build the model.
13    let model = FlowBuilder::from(Linear::new(2, 16)?)
14        .through(GELU)
15        .through(LayerNorm::new(16)?)
16        .also(Linear::new(16, 16)?)
17        .through(Linear::new(16, 2)?)
18        .build()?;
19
20    // Set up training.
21    let params = model.parameters();
22    let mut optimizer = Adam::new(&params, 0.01);
23    model.train();
24
25    // Generate some random data (XOR-ish pattern).
26    let opts = TensorOptions::default();
27    let batches: Vec<(Tensor, Tensor)> = (0..32)
28        .map(|_| {
29            let x = Tensor::randn(&[16, 2], opts).unwrap();
30            let y = Tensor::randn(&[16, 2], opts).unwrap();
31            (x, y)
32        })
33        .collect();
34
35    // Training loop with monitor.
36    let num_epochs = 50;
37    let mut monitor = Monitor::new(num_epochs);
38    // monitor.serve(3000)?;   // uncomment for live dashboard
39    // monitor.watch(&model);  // uncomment to show graph SVG in dashboard
40
41    for epoch in 0..num_epochs {
42        let t = std::time::Instant::now();
43
44        for (input_t, target_t) in &batches {
45            let input = Variable::new(input_t.clone(), true);
46            let target = Variable::new(target_t.clone(), false);
47
48            let pred = model.forward(&input)?;
49            let loss = mse_loss(&pred, &target)?;
50
51            optimizer.zero_grad();
52            loss.backward()?;
53            clip_grad_norm(&params, 1.0)?;
54            optimizer.step()?;
55
56            model.record_scalar("loss", loss.item()?);
57        }
58
59        model.flush(&[]);
60        monitor.log(epoch, t.elapsed(), &model);
61    }
62
63    monitor.finish();
64    Ok(())
65}