use flodl::*;
use flodl::monitor::Monitor;

12fn main() -> Result<()> {
13 let opts = TensorOptions::default();
15 let n_samples = 200i64;
16 let tau = std::f64::consts::TAU;
17 let x_all = Tensor::linspace(-tau, tau, n_samples, opts)?; let y_all = x_all.sin()?;
19
20 let x_data = x_all.reshape(&[n_samples, 1])?;
22 let y_data = y_all.reshape(&[n_samples, 1])?;
23
24 let x_batches = x_data.batches(50)?;
26 let y_batches = y_data.batches(50)?;
27 let batches: Vec<_> = x_batches.into_iter().zip(y_batches).collect();
28
29 let model = FlowBuilder::from(Linear::new(1, 32)?)
31 .through(GELU)
32 .through(LayerNorm::new(32)?)
33 .also(Linear::new(32, 32)?) .through(Linear::new(32, 1)?)
35 .build()?;
36
37 let params = model.parameters();
38 let mut optimizer = Adam::new(¶ms, 0.005);
39 let scheduler = CosineScheduler::new(0.005, 1e-5, 200);
40 model.train();
41
42 let num_epochs = 200usize;
44 let mut monitor = Monitor::new(num_epochs);
45 for epoch in 0..num_epochs {
49 let t = std::time::Instant::now();
50
51 for (xb, yb) in &batches {
52 let input = Variable::new(xb.clone(), true);
53 let target = Variable::new(yb.clone(), false);
54
55 optimizer.zero_grad();
56 let pred = model.forward(&input)?;
57 let loss = mse_loss(&pred, &target)?;
58 loss.backward()?;
59 clip_grad_norm(¶ms, 1.0)?;
60 optimizer.step()?;
61
62 model.record_scalar("loss", loss.item()?);
63 }
64
65 let lr = scheduler.lr(epoch);
66 optimizer.set_lr(lr);
67 model.record_scalar("lr", lr);
68 model.flush(&[]);
69 monitor.log(epoch, t.elapsed(), &model);
70 }
71
72 monitor.finish();
73
74 model.eval();
76 println!("\n{:>8} {:>10} {:>10} {:>8}", "x", "actual", "predicted", "error");
77 println!("{}", "-".repeat(42));
78
79 let test_x = Tensor::linspace(-tau, tau, 10, opts)?;
81 let test_y = test_x.sin()?;
82 let test_input = test_x.reshape(&[10, 1])?;
83
84 let pred = no_grad(|| {
85 let input = Variable::new(test_input.clone(), false);
86 model.forward(&input)
87 })?;
88
89 let pred_data = pred.data().to_f32_vec()?;
90 let actual_data = test_y.to_f32_vec()?;
91 let x_data_vec = test_x.to_f32_vec()?;
92
93 let mut max_err: f32 = 0.0;
94 for i in 0..10 {
95 let err = (pred_data[i] - actual_data[i]).abs();
96 if err > max_err {
97 max_err = err;
98 }
99 println!(
100 "{:>8.3} {:>10.4} {:>10.4} {:>8.4}",
101 x_data_vec[i], actual_data[i], pred_data[i], err
102 );
103 }
104 println!("\nMax error: {:.4}", max_err);
105
106 let path = "sine_model.fdl";
108 let named = model.named_parameters();
109 let named_bufs = model.named_buffers();
110 save_checkpoint_file(path, &named, &named_bufs, Some(model.structural_hash()))?;
111 println!("Checkpoint saved to {}", path);
112
113 let model2 = FlowBuilder::from(Linear::new(1, 32)?)
115 .through(GELU)
116 .through(LayerNorm::new(32)?)
117 .also(Linear::new(32, 32)?)
118 .through(Linear::new(32, 1)?)
119 .build()?;
120
121 let named2 = model2.named_parameters();
122 let named_bufs2 = model2.named_buffers();
123 load_checkpoint_file(path, &named2, &named_bufs2, Some(model2.structural_hash()))?;
124 model2.eval();
125
126 let pred2 = no_grad(|| {
128 let input = Variable::new(test_input.clone(), false);
129 model2.forward(&input)
130 })?;
131
132 let pred2_data = pred2.data().to_f32_vec()?;
133 let mut reload_diff: f32 = 0.0;
134 for i in 0..10 {
135 let d = (pred_data[i] - pred2_data[i]).abs();
136 if d > reload_diff {
137 reload_diff = d;
138 }
139 }
140 println!("Checkpoint reload max diff: {:.6}", reload_diff);
141 assert!(
142 reload_diff < 1e-5,
143 "Checkpoint round-trip mismatch: {}",
144 reload_diff
145 );
146 println!("Checkpoint round-trip verified.");
147
148 std::fs::remove_file(path).ok();
150 Ok(())
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end smoke test: full-batch training on 100 sine samples for
    /// 150 steps should drive the MSE loss below 0.05.
    #[test]
    fn sine_wave_converges() -> Result<()> {
        let opts = TensorOptions::default();
        let n = 100i64;
        let tau = std::f64::consts::TAU;
        let x = Tensor::linspace(-tau, tau, n, opts)?;
        let y = x.sin()?;
        let x_data = x.reshape(&[n, 1])?;
        let y_data = y.reshape(&[n, 1])?;

        // Same architecture as `main` so the test exercises the same flow.
        let model = FlowBuilder::from(Linear::new(1, 32)?)
            .through(GELU)
            .through(LayerNorm::new(32)?)
            .also(Linear::new(32, 32)?)
            .through(Linear::new(32, 1)?)
            .build()?;

        let params = model.parameters();
        // Fixed: `&params` had been corrupted to `¶ms` (mis-rendered `&para;`).
        let mut opt = Adam::new(&params, 0.005);
        model.train();

        let mut last_loss = f64::MAX;
        for _ in 0..150 {
            let input = Variable::new(x_data.clone(), true);
            let target = Variable::new(y_data.clone(), false);

            opt.zero_grad();
            let pred = model.forward(&input)?;
            let loss = mse_loss(&pred, &target)?;
            loss.backward()?;
            clip_grad_norm(&params, 1.0)?;
            opt.step()?;

            last_loss = loss.item()?;
        }

        assert!(
            last_loss < 0.05,
            "sine wave loss should converge below 0.05, got {}",
            last_loss
        );
        Ok(())
    }
}