// Example: sine_wave/main.rs
1//! Sine wave regression — the "it really works" moment.
2//!
3//! Trains a small network to learn sin(x) on [-2pi, 2pi], prints a
4//! prediction-vs-actual comparison table, and verifies checkpoint
5//! save/load round-trips.
6//!
7//! Run: `cargo run --example sine_wave`
8
9use flodl::*;
10use flodl::monitor::Monitor;
11
12fn main() -> Result<()> {
13    // --- Data: sin(x) on [-2pi, 2pi] ---
14    let opts = TensorOptions::default();
15    let n_samples = 200i64;
16    let tau = std::f64::consts::TAU;
17    let x_all = Tensor::linspace(-tau, tau, n_samples, opts)?; // [-2pi, 2pi]
18    let y_all = x_all.sin()?;
19
20    // Reshape to [n, 1] for the network.
21    let x_data = x_all.reshape(&[n_samples, 1])?;
22    let y_data = y_all.reshape(&[n_samples, 1])?;
23
24    // Split into batches of 50.
25    let x_batches = x_data.batches(50)?;
26    let y_batches = y_data.batches(50)?;
27    let batches: Vec<_> = x_batches.into_iter().zip(y_batches).collect();
28
29    // --- Model: Linear(1,32) -> GELU -> LayerNorm -> residual -> Linear(32,1) ---
30    let model = FlowBuilder::from(Linear::new(1, 32)?)
31        .through(GELU)
32        .through(LayerNorm::new(32)?)
33        .also(Linear::new(32, 32)?)   // residual connection
34        .through(Linear::new(32, 1)?)
35        .build()?;
36
37    let params = model.parameters();
38    let mut optimizer = Adam::new(&params, 0.005);
39    let scheduler = CosineScheduler::new(0.005, 1e-5, 200);
40    model.train();
41
42    // --- Training ---
43    let num_epochs = 200usize;
44    let mut monitor = Monitor::new(num_epochs);
45    // monitor.serve(3000)?;  // uncomment for live dashboard at http://localhost:3000
46    // monitor.watch(&model);
47
48    for epoch in 0..num_epochs {
49        let t = std::time::Instant::now();
50
51        for (xb, yb) in &batches {
52            let input = Variable::new(xb.clone(), true);
53            let target = Variable::new(yb.clone(), false);
54
55            optimizer.zero_grad();
56            let pred = model.forward(&input)?;
57            let loss = mse_loss(&pred, &target)?;
58            loss.backward()?;
59            clip_grad_norm(&params, 1.0)?;
60            optimizer.step()?;
61
62            model.record_scalar("loss", loss.item()?);
63        }
64
65        let lr = scheduler.lr(epoch);
66        optimizer.set_lr(lr);
67        model.record_scalar("lr", lr);
68        model.flush(&[]);
69        monitor.log(epoch, t.elapsed(), &model);
70    }
71
72    monitor.finish();
73
74    // --- Evaluation ---
75    model.eval();
76    println!("\n{:>8}  {:>10}  {:>10}  {:>8}", "x", "actual", "predicted", "error");
77    println!("{}", "-".repeat(42));
78
79    // Test on 10 evenly spaced points.
80    let test_x = Tensor::linspace(-tau, tau, 10, opts)?;
81    let test_y = test_x.sin()?;
82    let test_input = test_x.reshape(&[10, 1])?;
83
84    let pred = no_grad(|| {
85        let input = Variable::new(test_input.clone(), false);
86        model.forward(&input)
87    })?;
88
89    let pred_data = pred.data().to_f32_vec()?;
90    let actual_data = test_y.to_f32_vec()?;
91    let x_data_vec = test_x.to_f32_vec()?;
92
93    let mut max_err: f32 = 0.0;
94    for i in 0..10 {
95        let err = (pred_data[i] - actual_data[i]).abs();
96        if err > max_err {
97            max_err = err;
98        }
99        println!(
100            "{:>8.3}  {:>10.4}  {:>10.4}  {:>8.4}",
101            x_data_vec[i], actual_data[i], pred_data[i], err
102        );
103    }
104    println!("\nMax error: {:.4}", max_err);
105
106    // --- Checkpoint round-trip ---
107    let path = "sine_model.fdl";
108    let named = model.named_parameters();
109    let named_bufs = model.named_buffers();
110    save_checkpoint_file(path, &named, &named_bufs, Some(model.structural_hash()))?;
111    println!("Checkpoint saved to {}", path);
112
113    // Rebuild architecture and load weights.
114    let model2 = FlowBuilder::from(Linear::new(1, 32)?)
115        .through(GELU)
116        .through(LayerNorm::new(32)?)
117        .also(Linear::new(32, 32)?)
118        .through(Linear::new(32, 1)?)
119        .build()?;
120
121    let named2 = model2.named_parameters();
122    let named_bufs2 = model2.named_buffers();
123    load_checkpoint_file(path, &named2, &named_bufs2, Some(model2.structural_hash()))?;
124    model2.eval();
125
126    // Verify loaded model produces the same output.
127    let pred2 = no_grad(|| {
128        let input = Variable::new(test_input.clone(), false);
129        model2.forward(&input)
130    })?;
131
132    let pred2_data = pred2.data().to_f32_vec()?;
133    let mut reload_diff: f32 = 0.0;
134    for i in 0..10 {
135        let d = (pred_data[i] - pred2_data[i]).abs();
136        if d > reload_diff {
137            reload_diff = d;
138        }
139    }
140    println!("Checkpoint reload max diff: {:.6}", reload_diff);
141    assert!(
142        reload_diff < 1e-5,
143        "Checkpoint round-trip mismatch: {}",
144        reload_diff
145    );
146    println!("Checkpoint round-trip verified.");
147
148    // Clean up.
149    std::fs::remove_file(path).ok();
150    Ok(())
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// Quick convergence smoke test: 150 full-batch Adam steps on sin(x)
    /// must drive the MSE below 0.05.
    #[test]
    fn sine_wave_converges() -> Result<()> {
        let opts = TensorOptions::default();
        let n = 100i64;
        let tau = std::f64::consts::TAU;

        // Full dataset as a single [n, 1] batch of (x, sin x) pairs.
        let xs = Tensor::linspace(-tau, tau, n, opts)?;
        let ys = xs.sin()?;
        let inputs = xs.reshape(&[n, 1])?;
        let targets = ys.reshape(&[n, 1])?;

        // Same architecture as the example's main model.
        let net = FlowBuilder::from(Linear::new(1, 32)?)
            .through(GELU)
            .through(LayerNorm::new(32)?)
            .also(Linear::new(32, 32)?)
            .through(Linear::new(32, 1)?)
            .build()?;

        let weights = net.parameters();
        let mut optim = Adam::new(&weights, 0.005);
        net.train();

        let mut last_loss = f64::MAX;
        for _ in 0..150 {
            let input = Variable::new(inputs.clone(), true);
            let target = Variable::new(targets.clone(), false);

            optim.zero_grad();
            let loss = mse_loss(&net.forward(&input)?, &target)?;
            loss.backward()?;
            clip_grad_norm(&weights, 1.0)?;
            optim.step()?;

            last_loss = loss.item()?;
        }

        assert!(
            last_loss < 0.05,
            "sine wave loss should converge below 0.05, got {}",
            last_loss
        );
        Ok(())
    }
}