pub trait Module {
// Required method
fn forward(&self, input: &Variable) -> Result<Variable>;
// Provided methods
fn parameters(&self) -> Vec<Parameter> { ... }
fn buffers(&self) -> Vec<Buffer> { ... }
fn name(&self) -> &str { ... }
fn sub_modules(&self) -> Vec<Rc<dyn Module>> { ... }
fn move_to_device(&self, _device: Device) { ... }
fn set_training(&self, _training: bool) { ... }
fn train(&self) { ... }
fn eval(&self) { ... }
fn trace(&self) -> Option<Variable> { ... }
fn as_named_input(&self) -> Option<&dyn NamedInputModule> { ... }
fn structural_hash(&self) -> Option<String> { ... }
fn reset(&self) { ... }
fn detach_state(&self) { ... }
}
The core module trait: forward pass + parameter access.
All neural network layers implement Module. Composite modules (Graph, loops, gates) implement Module too, so they compose like any other layer.
let model = Linear::new(4, 2)?;
let x = Variable::new(Tensor::randn(&[1, 4], opts)?, false);
let y = model.forward(&x)?; // [1, 4] → [1, 2]

Required Methods

fn forward(&self, input: &Variable) -> Result<Variable>

Provided Methods

fn parameters(&self) -> Vec<Parameter>
Return this module’s learnable parameters.
Default: recursively collects from sub_modules() with pointer dedup.
Leaf modules should override to return their own parameters.
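For a composite layer the default usually suffices once sub_modules() is wired up. A minimal sketch of a custom module, using only the trait surface shown on this page (the TwoLayer type is illustrative, not part of the crate):

use std::rc::Rc;

// Hypothetical block that chains two existing layers.
struct TwoLayer {
    first: Rc<dyn Module>,  // e.g. Rc::new(Linear::new(4, 16)?)
    second: Rc<dyn Module>, // e.g. Rc::new(Linear::new(16, 2)?)
}

impl Module for TwoLayer {
    // The only required method: run the children in sequence.
    fn forward(&self, input: &Variable) -> Result<Variable> {
        let hidden = self.first.forward(input)?;
        self.second.forward(&hidden)
    }

    // Expose the children so the default parameters()/buffers()
    // implementations can collect from them recursively.
    fn sub_modules(&self) -> Vec<Rc<dyn Module>> {
        vec![self.first.clone(), self.second.clone()]
    }

    // Lowercase identifier used as a node-ID prefix in visualizations.
    fn name(&self) -> &str {
        "two_layer"
    }
}

With sub_modules() in place, calling parameters() on a TwoLayer returns the parameters of both Linear children without any further overrides.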
Examples found in repository
fn main() -> Result<()> {
    // Build the model.
    let model = FlowBuilder::from(Linear::new(2, 16)?)
        .through(GELU)
        .through(LayerNorm::new(16)?)
        .also(Linear::new(16, 16)?)
        .through(Linear::new(16, 2)?)
        .build()?;

    // Set up training.
    let params = model.parameters();
    let mut optimizer = Adam::new(&params, 0.01);
    model.train();

    // Generate some random data (XOR-ish pattern).
    let opts = TensorOptions::default();
    let batches: Vec<(Tensor, Tensor)> = (0..32)
        .map(|_| {
            let x = Tensor::randn(&[16, 2], opts).unwrap();
            let y = Tensor::randn(&[16, 2], opts).unwrap();
            (x, y)
        })
        .collect();

    // Training loop with monitor.
    let num_epochs = 50;
    let mut monitor = Monitor::new(num_epochs);
    // monitor.serve(3000)?; // uncomment for live dashboard
    // monitor.watch(&model); // uncomment to show graph SVG in dashboard

    for epoch in 0..num_epochs {
        let t = std::time::Instant::now();

        for (input_t, target_t) in &batches {
            let input = Variable::new(input_t.clone(), true);
            let target = Variable::new(target_t.clone(), false);

            let pred = model.forward(&input)?;
            let loss = mse_loss(&pred, &target)?;

            optimizer.zero_grad();
            loss.backward()?;
            clip_grad_norm(&params, 1.0)?;
            optimizer.step()?;

            model.record_scalar("loss", loss.item()?);
        }

        model.flush(&[]);
        monitor.log(epoch, t.elapsed(), &model);
    }

    monitor.finish();
    Ok(())
}

More examples
fn main() -> Result<()> {
    let opts = TensorOptions::default();

    // Build a small model.
    let model = FlowBuilder::from(Linear::new(4, 16)?)
        .through(GELU)
        .through(LayerNorm::new(16)?)
        .also(Linear::new(16, 16)?)
        .through(Linear::new(16, 4)?)
        .build()?;

    let params = model.parameters();
    let mut optimizer = Adam::new(&params, 0.001);
    model.train();

    let num_epochs = 100usize;

    // --- Compose schedulers ---

    // Warmup for 10 epochs, then cosine anneal for the remaining 90.
    let cosine = CosineScheduler::new(0.001, 1e-5, num_epochs - 10);
    let scheduler = WarmupScheduler::new(cosine, 0.001, 10);

    // Plateau scheduler as a secondary control: if loss stalls,
    // reduce LR further regardless of the primary schedule.
    let mut plateau = PlateauScheduler::new(0.001, 10, 0.5, 1e-6);

    println!("{:>5} {:>10} {:>10} {:>10}", "epoch", "loss", "sched_lr", "eff_lr");
    println!("{}", "-".repeat(40));

    for epoch in 0..num_epochs {
        let x = Tensor::randn(&[64, 4], opts)?;
        let y = Tensor::randn(&[64, 4], opts)?;
        let input = Variable::new(x, true);
        let target = Variable::new(y, false);

        optimizer.zero_grad();
        let pred = model.forward(&input)?;
        let loss = mse_loss(&pred, &target)?;
        loss.backward()?;
        clip_grad_norm(&params, 1.0)?;
        optimizer.step()?;

        let loss_val = loss.item()?;

        // Primary schedule (warmup + cosine).
        let sched_lr = scheduler.lr(epoch);

        // Plateau feedback — takes the minimum of primary and plateau LR.
        let plateau_lr = plateau.observe(loss_val);
        let effective_lr = sched_lr.min(plateau_lr);

        optimizer.set_lr(effective_lr);

        if epoch % 10 == 0 || epoch == num_epochs - 1 {
            println!(
                "{:>5} {:>10.6} {:>10.6} {:>10.6}",
                epoch, loss_val, sched_lr, effective_lr
            );
        }
    }

    println!("\nFinal LR: {:.6}", optimizer.lr());
    Ok(())
}

fn main() -> Result<()> {
    let opts = TensorOptions::default();

    // Random regression data.
    let batches: Vec<(Tensor, Tensor)> = (0..4)
        .map(|_| {
            let x = Tensor::randn(&[50, 8], opts).unwrap();
            let y = Tensor::randn(&[50, 4], opts).unwrap();
            (x, y)
        })
        .collect();

    // Build the model (starts in float32).
    let model = FlowBuilder::from(Linear::new(8, 32)?)
        .through(GELU)
        .through(LayerNorm::new(32)?)
        .also(Linear::new(32, 32)?)
        .through(Linear::new(32, 4)?)
        .build()?;

    let params = model.parameters();

    // Cast parameters to float16 for reduced memory and faster matmuls.
    cast_parameters(&params, DType::Float16);
    println!("Parameters cast to float16");

    let mut optimizer = Adam::new(&params, 0.001);
    let mut scaler = GradScaler::new();
    model.train();

    let num_epochs = 100usize;
    let mut monitor = Monitor::new(num_epochs);

    for epoch in 0..num_epochs {
        let t = std::time::Instant::now();
        let mut steps_taken = 0u32;

        for (xb, yb) in &batches {
            // Cast inputs to match parameter dtype.
            let input = Variable::new(xb.to_dtype(DType::Float16)?, true);
            let target = Variable::new(yb.to_dtype(DType::Float16)?, false);

            optimizer.zero_grad();
            let pred = model.forward(&input)?;
            let loss = mse_loss(&pred, &target)?;

            // Scale loss before backward to prevent gradient underflow.
            let scaled_loss = scaler.scale(&loss)?;
            scaled_loss.backward()?;

            // Unscale, check for inf/nan, step if clean.
            let stepped = scaler.step(&params, &mut || optimizer.step())?;
            scaler.update();

            if stepped {
                steps_taken += 1;
            }

            model.record_scalar("loss", loss.item()?);
        }

        model.record_scalar("scale", scaler.scale_factor());
        model.flush(&[]);
        monitor.log(epoch, t.elapsed(), &model);

        if steps_taken == 0 {
            println!("epoch {}: all steps skipped (inf grads), scale={:.0}", epoch, scaler.scale_factor());
        }
    }

    monitor.finish();

    // Cast back to float32 for inference or checkpointing.
    cast_parameters(&params, DType::Float32);
    println!("Parameters cast back to float32 for export");

    Ok(())
}

fn main() -> Result<()> {
    let opts = TensorOptions::default();

    // Build a model with tagged intermediate nodes.
    let model = FlowBuilder::from(Linear::new(4, 16)?)
        .through(GELU)
        .tag("hidden")
        .through(LayerNorm::new(16)?)
        .also(Linear::new(16, 16)?)
        .tag("residual")
        .through(Linear::new(16, 4)?)
        .build()?;

    let params = model.parameters();
    let mut optimizer = Adam::new(&params, 0.01);
    let scheduler = CosineScheduler::new(0.01, 1e-4, 200);
    model.train();

    let num_epochs = 200usize;
    let mut monitor = Monitor::new(num_epochs);

    // Generate stable data so trends are meaningful.
    let x_data = Tensor::randn(&[128, 4], opts)?;
    let y_data = Tensor::randn(&[128, 4], opts)?;

    // Split into batches.
    let x_batches = x_data.batches(32)?;
    let y_batches = y_data.batches(32)?;
    let batches: Vec<_> = x_batches.into_iter().zip(y_batches).collect();

    println!("{:>5} {:>10} {:>10} {:>8} {:>10}", "epoch", "loss", "slope", "stalled", "status");
    println!("{}", "-".repeat(50));

    for epoch in 0..num_epochs {
        let t = std::time::Instant::now();

        for (xb, yb) in &batches {
            let input = Variable::new(xb.clone(), true);
            let target = Variable::new(yb.clone(), false);

            optimizer.zero_grad();
            let pred = model.forward(&input)?;

            // Collect tagged node outputs as metrics.
            model.collect(&["hidden", "residual"])?;

            let loss = mse_loss(&pred, &target)?;
            loss.backward()?;
            clip_grad_norm(&params, 1.0)?;
            optimizer.step()?;

            model.record_scalar("loss", loss.item()?);
        }

        // Flush batch metrics to epoch history.
        let lr = scheduler.lr(epoch);
        optimizer.set_lr(lr);
        model.record_scalar("lr", lr);
        model.flush(&["hidden", "residual", "loss", "lr"]);
        monitor.log(epoch, t.elapsed(), &model);

        // Query trends for early stopping (window=10 epochs).
        let loss_trend = model.trend("loss");
        let stalled = loss_trend.stalled(10, 1e-5);
        let converged = loss_trend.converged(10, 1e-5);

        if epoch % 20 == 0 || stalled || converged {
            let status = if converged {
                "CONVERGED"
            } else if stalled {
                "stalled"
            } else if loss_trend.improving(10) {
                "improving"
            } else {
                ""
            };

            println!(
                "{:>5} {:>10.6} {:>10.6} {:>8} {:>10}",
                epoch,
                loss_trend.latest(),
                loss_trend.slope(10),
                stalled,
                status
            );
        }

        // Early stop when converged.
        if converged && epoch > 20 {
            println!("\nEarly stop at epoch {} — loss converged.", epoch);
            break;
        }
    }

    monitor.finish();

    // Group trend queries (window=10 for all).
    let group = model.trends(&["hidden", "residual", "loss"]);
    println!("\nFinal trend summary:");
    println!("  All improving: {}", group.all_improving(10));
    println!("  Any stalled: {}", group.any_stalled(10, 1e-5));
    println!("  Mean slope: {:.6}", group.mean_slope(10));

    Ok(())
}

fn main() -> Result<()> {
    let opts = TensorOptions::default();

    // --- Phase 1: Train an encoder on a proxy task ---
    println!("=== Phase 1: Train encoder ===");

    let encoder = FlowBuilder::from(Linear::new(4, 16)?)
        .through(GELU)
        .through(LayerNorm::new(16)?)
        .tag("encoded")
        .through(Linear::new(16, 4)?)
        .build()?;

    let params = encoder.parameters();
    let mut opt = Adam::new(&params, 0.01);
    encoder.train();

    for epoch in 0..50 {
        let x = Tensor::randn(&[32, 4], opts)?;
        let y = Tensor::randn(&[32, 4], opts)?;
        let input = Variable::new(x, true);
        let target = Variable::new(y, false);

        opt.zero_grad();
        let pred = encoder.forward(&input)?;
        let loss = mse_loss(&pred, &target)?;
        loss.backward()?;
        clip_grad_norm(&params, 1.0)?;
        opt.step()?;

        if epoch % 10 == 0 {
            println!("  epoch {:>3}: loss={:.4}", epoch, loss.item()?);
        }
    }

    // Save the pretrained encoder.
    let ckpt_path = "pretrained_encoder.fdl";
    let named = encoder.named_parameters();
    let named_bufs = encoder.named_buffers();
    save_checkpoint_file(ckpt_path, &named, &named_bufs, None)?;
    println!("Encoder saved to {}", ckpt_path);

    // --- Phase 2: Build a new model and load encoder weights ---
    println!("\n=== Phase 2: Transfer to new architecture ===");

    // New architecture reuses the encoder layers but adds a different head.
    let model = FlowBuilder::from(Linear::new(4, 16)?)
        .through(GELU)
        .through(LayerNorm::new(16)?)
        .tag("encoded")
        .also(Linear::new(16, 16)?)   // new: residual connection
        .through(Linear::new(16, 2)?) // new: different output dim
        .build()?;

    // Partial load: matching names get loaded, new layers keep random init.
    let named2 = model.named_parameters();
    let named_bufs2 = model.named_buffers();
    let report = load_checkpoint_file(ckpt_path, &named2, &named_bufs2, None)?;

    println!("Loaded {} parameters, skipped {}, missing {}",
        report.loaded.len(), report.skipped.len(), report.missing.len());

    // Freeze the encoder layers (first 3 modules: Linear, GELU, LayerNorm).
    let all_params = model.parameters();
    for (i, p) in all_params.iter().enumerate() {
        if i < 3 {
            p.freeze()?;
        }
    }
    println!("Encoder layers frozen");

    // Set up optimizer with parameter groups:
    //  - Frozen params are excluded automatically (zero grad).
    //  - New head gets higher LR.
    let trainable: Vec<Parameter> = all_params
        .iter()
        .filter(|p| !p.is_frozen())
        .cloned()
        .collect();

    let mut opt2 = Adam::new(&trainable, 0.005);
    model.train();

    // --- Phase 3: Fine-tune the new head ---
    println!("\n=== Phase 3: Fine-tune new head ===");

    for epoch in 0..50 {
        let x = Tensor::randn(&[32, 4], opts)?;
        let y = Tensor::randn(&[32, 2], opts)?;
        let input = Variable::new(x, true);
        let target = Variable::new(y, false);

        opt2.zero_grad();
        let pred = model.forward(&input)?;
        let loss = mse_loss(&pred, &target)?;
        loss.backward()?;
        clip_grad_norm(&trainable, 1.0)?;
        opt2.step()?;

        if epoch % 10 == 0 {
            println!("  epoch {:>3}: loss={:.4}", epoch, loss.item()?);
        }
    }

    // --- Phase 4: Unfreeze and fine-tune everything ---
    println!("\n=== Phase 4: Full fine-tune (unfrozen) ===");

    for p in &all_params {
        p.unfreeze()?;
    }

    let mut opt3 = Adam::new(&all_params, 0.001); // lower LR for full model
    for epoch in 0..30 {
        let x = Tensor::randn(&[32, 4], opts)?;
        let y = Tensor::randn(&[32, 2], opts)?;
        let input = Variable::new(x, true);
        let target = Variable::new(y, false);

        opt3.zero_grad();
        let pred = model.forward(&input)?;
        let loss = mse_loss(&pred, &target)?;
        loss.backward()?;
        clip_grad_norm(&all_params, 1.0)?;
        opt3.step()?;

        if epoch % 10 == 0 {
            println!("  epoch {:>3}: loss={:.4}", epoch, loss.item()?);
        }
    }

    println!("\nTransfer learning complete.");

    // Clean up.
    std::fs::remove_file(ckpt_path).ok();
    Ok(())
}

fn main() -> Result<()> {
    // --- Data: sin(x) on [-2pi, 2pi] ---
    let opts = TensorOptions::default();
    let n_samples = 200i64;
    let tau = std::f64::consts::TAU;
    let x_all = Tensor::linspace(-tau, tau, n_samples, opts)?; // [-2pi, 2pi]
    let y_all = x_all.sin()?;

    // Reshape to [n, 1] for the network.
    let x_data = x_all.reshape(&[n_samples, 1])?;
    let y_data = y_all.reshape(&[n_samples, 1])?;

    // Split into batches of 50.
    let x_batches = x_data.batches(50)?;
    let y_batches = y_data.batches(50)?;
    let batches: Vec<_> = x_batches.into_iter().zip(y_batches).collect();

    // --- Model: Linear(1,32) -> GELU -> LayerNorm -> residual -> Linear(32,1) ---
    let model = FlowBuilder::from(Linear::new(1, 32)?)
        .through(GELU)
        .through(LayerNorm::new(32)?)
        .also(Linear::new(32, 32)?) // residual connection
        .through(Linear::new(32, 1)?)
        .build()?;

    let params = model.parameters();
    let mut optimizer = Adam::new(&params, 0.005);
    let scheduler = CosineScheduler::new(0.005, 1e-5, 200);
    model.train();

    // --- Training ---
    let num_epochs = 200usize;
    let mut monitor = Monitor::new(num_epochs);
    // monitor.serve(3000)?; // uncomment for live dashboard at http://localhost:3000
    // monitor.watch(&model);

    for epoch in 0..num_epochs {
        let t = std::time::Instant::now();

        for (xb, yb) in &batches {
            let input = Variable::new(xb.clone(), true);
            let target = Variable::new(yb.clone(), false);

            optimizer.zero_grad();
            let pred = model.forward(&input)?;
            let loss = mse_loss(&pred, &target)?;
            loss.backward()?;
            clip_grad_norm(&params, 1.0)?;
            optimizer.step()?;

            model.record_scalar("loss", loss.item()?);
        }

        let lr = scheduler.lr(epoch);
        optimizer.set_lr(lr);
        model.record_scalar("lr", lr);
        model.flush(&[]);
        monitor.log(epoch, t.elapsed(), &model);
    }

    monitor.finish();

    // --- Evaluation ---
    model.eval();
    println!("\n{:>8} {:>10} {:>10} {:>8}", "x", "actual", "predicted", "error");
    println!("{}", "-".repeat(42));

    // Test on 10 evenly spaced points.
    let test_x = Tensor::linspace(-tau, tau, 10, opts)?;
    let test_y = test_x.sin()?;
    let test_input = test_x.reshape(&[10, 1])?;

    let pred = no_grad(|| {
        let input = Variable::new(test_input.clone(), false);
        model.forward(&input)
    })?;

    let pred_data = pred.data().to_f32_vec()?;
    let actual_data = test_y.to_f32_vec()?;
    let x_data_vec = test_x.to_f32_vec()?;

    let mut max_err: f32 = 0.0;
    for i in 0..10 {
        let err = (pred_data[i] - actual_data[i]).abs();
        if err > max_err {
            max_err = err;
        }
        println!(
            "{:>8.3} {:>10.4} {:>10.4} {:>8.4}",
            x_data_vec[i], actual_data[i], pred_data[i], err
        );
    }
    println!("\nMax error: {:.4}", max_err);

    // --- Checkpoint round-trip ---
    let path = "sine_model.fdl";
    let named = model.named_parameters();
    let named_bufs = model.named_buffers();
    save_checkpoint_file(path, &named, &named_bufs, Some(model.structural_hash()))?;
    println!("Checkpoint saved to {}", path);

    // Rebuild architecture and load weights.
    let model2 = FlowBuilder::from(Linear::new(1, 32)?)
        .through(GELU)
        .through(LayerNorm::new(32)?)
        .also(Linear::new(32, 32)?)
        .through(Linear::new(32, 1)?)
        .build()?;

    let named2 = model2.named_parameters();
    let named_bufs2 = model2.named_buffers();
    load_checkpoint_file(path, &named2, &named_bufs2, Some(model2.structural_hash()))?;
    model2.eval();

    // Verify loaded model produces the same output.
    let pred2 = no_grad(|| {
        let input = Variable::new(test_input.clone(), false);
        model2.forward(&input)
    })?;

    let pred2_data = pred2.data().to_f32_vec()?;
    let mut reload_diff: f32 = 0.0;
    for i in 0..10 {
        let d = (pred_data[i] - pred2_data[i]).abs();
        if d > reload_diff {
            reload_diff = d;
        }
    }
    println!("Checkpoint reload max diff: {:.6}", reload_diff);
    assert!(
        reload_diff < 1e-5,
        "Checkpoint round-trip mismatch: {}",
        reload_diff
    );
    println!("Checkpoint round-trip verified.");

    // Clean up.
    std::fs::remove_file(path).ok();
    Ok(())
}

fn buffers(&self) -> Vec<Buffer>
Return this module’s non-learnable persistent buffers (e.g., running stats).
Default: recursively collects from sub_modules() with pointer dedup.
Leaf modules should override to return their own buffers.
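A leaf override mirrors parameters(); a sketch assuming a module that already holds a Buffer handle (how a Buffer is constructed is not shown on this page, and Buffer: Clone is an assumption):

// Hypothetical leaf module with running statistics.
struct RunningStats {
    running_mean: Buffer, // construction not shown on this page
}

impl Module for RunningStats {
    fn forward(&self, _input: &Variable) -> Result<Variable> {
        todo!() // a real module would normalize _input and update running_mean
    }

    // Leaf override: report this module's own buffer instead of recursing.
    fn buffers(&self) -> Vec<Buffer> {
        vec![self.running_mean.clone()] // assumes Buffer: Clone
    }
}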
fn name(&self) -> &str
Human-readable type name used as node ID prefix in graph visualization. Override to return a lowercase identifier (e.g., “linear”, “gelu”).
fn sub_modules(&self) -> Vec<Rc<dyn Module>>
Return direct child modules for recursive tree walks. Override in composite modules (loops, switches, gates).
fn move_to_device(&self, _device: Device)
Move all parameters and buffers to the given device. Override in modules like BatchNorm that hold non-parameter state.
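A typical override forwards the move to its children, then relocates any extra state the defaults cannot see. A sketch (whether Device is Copy is an assumption here, as is the WithCache type):

use std::rc::Rc;

// Hypothetical wrapper holding non-parameter state.
struct WithCache {
    inner: Rc<dyn Module>,
}

impl Module for WithCache {
    fn forward(&self, input: &Variable) -> Result<Variable> {
        self.inner.forward(input)
    }

    fn sub_modules(&self) -> Vec<Rc<dyn Module>> {
        vec![self.inner.clone()]
    }

    fn move_to_device(&self, device: Device) {
        for child in self.sub_modules() {
            child.move_to_device(device); // assumes Device: Copy
        }
        // a real override would also move cached tensors held directly here
    }
}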
fn set_training(&self, _training: bool)
Set training/eval mode. Affects Dropout, BatchNorm, etc. Override in modules with mode-dependent behavior.
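Because set_training takes &self, a mode flag needs interior mutability. A sketch of a mode-dependent leaf (the mask itself is elided, and Variable: Clone is an assumption; MyDropout is illustrative):

use std::cell::Cell;

// Hypothetical dropout-style layer.
struct MyDropout {
    p: f64,
    training: Cell<bool>, // &self receiver forces interior mutability
}

impl Module for MyDropout {
    fn forward(&self, input: &Variable) -> Result<Variable> {
        if self.training.get() {
            // ... sample a keep mask with probability 1 - self.p and apply it (elided) ...
        }
        Ok(input.clone()) // identity path; assumes Variable: Clone
    }

    fn set_training(&self, training: bool) {
        self.training.set(training);
    }
}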
fn train(&self)
Set training mode. Shorthand for set_training(true).
Examples found in repository

(The repository examples shown under parameters() above also demonstrate this method: each calls train() on the model before its training loop.)

fn eval(&self)
Set eval mode. Shorthand for set_training(false).
Examples found in repository
(The sine-regression example under parameters() above also demonstrates eval(): the model is switched to eval mode before test evaluation and again after the checkpoint reload.)

More examples
fn main() {
    println!("=== floDl showcase ===\n");

    // -- Build --
    println!("Building graph...");
    let g = build_showcase().expect("build failed");
    let n_params = g.parameters().len();
    println!("Parameters: {}", n_params);

    // -- Forward (with auxiliary input) --
    let result = g.forward_multi(&[make_input(false), make_context()])
        .expect("forward failed");
    println!("Output: {:?} (shape {:?})", result.data().to_f32_vec().unwrap(), result.shape());

    // -- Forward ref carries state --
    g.reset_state();
    let r1 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
    let v1 = r1.data().to_f32_vec().unwrap();
    let r2 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
    let v2 = r2.data().to_f32_vec().unwrap();
    println!("State drift: pass2 differs = {}", v1 != v2);

    // -- Reset --
    g.reset_state();
    let r3 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
    let v3 = r3.data().to_f32_vec().unwrap();
    println!("Reset restores: {}", v1 == v3);

    // -- DOT + SVG (structural) --
    let dot = g.dot();
    println!("DOT: {} bytes", dot.len());

    // Write structural DOT
    let dot_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase.dot");
    std::fs::write(dot_path, &dot).expect("write showcase.dot");
    println!("Wrote {}", dot_path);

    // Write structural SVG
    let svg_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase.svg");
    let svg = g.svg(Some(svg_path)).expect("write showcase.svg");
    println!("Wrote {} ({} bytes)", svg_path, svg.len());

    // -- Training loop with observation + profiling + monitor --
    println!("\n--- Training (5 epochs x 4 steps) ---");
    g.train();
    g.reset_state();
    g.enable_profiling();

    let params = g.parameters();
    let mut optimizer = Adam::new(&params, 0.001);
    let num_epochs = 5;
    let total_steps = num_epochs * 4;
    let sched = CosineScheduler::new(0.001, 1e-5, total_steps);
    let mut monitor = Monitor::new(num_epochs);

    let mut step_idx = 0;
    for epoch in 0..num_epochs {
        let t = std::time::Instant::now();
        for _ in 0..4 {
            optimizer.zero_grad();
            let input = make_input(true);
            let ctx = make_context();
            let target = make_target();

            let pred = g.forward_multi(&[input, ctx]).unwrap();
            let loss = mse_loss(&pred, &target).unwrap();

            loss.backward().unwrap();
            clip_grad_norm(&params, 1.0).unwrap();
            optimizer.set_lr(sched.lr(step_idx));
            optimizer.step().unwrap();
            step_idx += 1;

            g.record_scalar("loss", loss.item().unwrap());
            g.record_scalar("lr", sched.lr(step_idx - 1));
            g.end_step();
        }

        g.end_epoch();
        monitor.log(epoch, t.elapsed(), &g);
    }

    // -- Trends --
    let trend = g.trend("loss");
    println!(
        "\nLoss trend: {} epochs, slope={:.4}, improving={}",
        trend.len(),
        trend.slope(0),
        trend.improving(0),
    );

    // Timing trends use node IDs — pick the first tagged one
    let timing = g.timing_trend("input");
    println!(
        "Timing trend (input node): {} epochs, mean={:.1}us",
        timing.len(),
        timing.mean() * 1e6,
    );

    // -- Write profiling DOT + SVG --
    let profile_dot = g.dot_with_profile();
    let profile_dot_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_profile.dot");
    std::fs::write(profile_dot_path, &profile_dot).expect("write showcase_profile.dot");
    println!("Wrote {}", profile_dot_path);

    let profile_svg_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_profile.svg");
    let profile_svg = g.svg_with_profile(Some(profile_svg_path)).expect("write showcase_profile.svg");
    println!("Wrote {} ({} bytes)", profile_svg_path, profile_svg.len());

    // -- Write training HTML --
    let html_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_training.html");
    g.plot_html(html_path, &["loss"]).expect("write showcase_training.html");
    println!("Wrote {}", html_path);

    // -- Write training log --
    let log_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_training.log");
    g.write_log(log_path, 5, &["loss"]).expect("write showcase_training.log");
    println!("Wrote {}", log_path);

    // -- Checkpoint round-trip --
    let path = "/tmp/flodl_showcase_checkpoint.fdl";
    let named = g.named_parameters();
    let named_bufs = g.named_buffers();
    save_checkpoint_file(path, &named, &named_bufs, Some(g.structural_hash())).expect("save failed");
    let report = load_checkpoint_file(path, &named, &named_bufs, Some(g.structural_hash())).expect("load failed");
    println!("\nCheckpoint save/load: OK ({} loaded)", report.loaded.len());

    // -- no_grad inference (eval mode works now — BatchNorm has running stats from training) --
    g.eval();
    g.reset_state();
    let final_out = no_grad(|| g.forward_multi(&[make_input(false), make_context()])).unwrap();
    let final_vals = final_out.data().to_f32_vec().unwrap();
    println!("no_grad inference: {:?}", final_vals);
    assert!(final_vals.iter().all(|v| v.is_finite()), "no_grad output should be finite");

    println!("\nAll showcase checks passed.");
}

fn trace(&self) -> Option<Variable>
Return per-iteration side output for loop tracing. Override in loop body modules that capture trajectory data (e.g., attention fixation points). Returns None by default. When Some, the loop executor collects traces accessible via Graph::traces().
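A sketch of a loop-body override (Variable: Clone is an assumption, and TracingBody is illustrative): the body records a side value on each forward pass and hands it back through trace().

use std::cell::RefCell;
use std::rc::Rc;

// Hypothetical loop body that exposes its last side output.
struct TracingBody {
    inner: Rc<dyn Module>,
    last: RefCell<Option<Variable>>,
}

impl Module for TracingBody {
    fn forward(&self, input: &Variable) -> Result<Variable> {
        let out = self.inner.forward(input)?;
        // A real body might store attention fixation points instead.
        *self.last.borrow_mut() = Some(out.clone());
        Ok(out)
    }

    fn sub_modules(&self) -> Vec<Rc<dyn Module>> {
        vec![self.inner.clone()]
    }

    // Collected once per loop iteration; see Graph::traces().
    fn trace(&self) -> Option<Variable> {
        self.last.borrow().clone()
    }
}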
fn as_named_input(&self) -> Option<&dyn NamedInputModule>
Upcast to NamedInputModule for multi-input graphs. Override in types that implement NamedInputModule so they can receive additional named inputs via the graph's using().
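The override itself is the standard trait-object upcast, shown here as a fragment for a type that also implements NamedInputModule:

// Inside `impl Module for MyMultiInput`, where `MyMultiInput: NamedInputModule`:
fn as_named_input(&self) -> Option<&dyn NamedInputModule> {
    Some(self)
}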
fn structural_hash(&self) -> Option<String>
SHA-256 hex hash of module architecture for checkpoint validation. Override in composite modules (Graph) that compute a deterministic hash from their topology and parameter shapes.
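The sine-regression example above shows the intended use: the hash is saved with the checkpoint and checked again on load, so a checkpoint cannot be silently restored into a mismatched architecture.

save_checkpoint_file(path, &named, &named_bufs, Some(model.structural_hash()))?;
// ... later, after rebuilding the same architecture ...
load_checkpoint_file(path, &named2, &named_bufs2, Some(model2.structural_hash()))?;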
fn reset(&self)
Reset internal state (e.g. recurrent hidden state) between sequences. Called by loops before iterating to clear stale tensors whose grad_fns may reference freed saved tensors. Override in stateful modules.
fn detach_state(&self)
Detach internal state from the computation graph (for truncated BPTT). Called between training steps to break gradient chains on state carried across forward passes (e.g., recurrent hidden state). Override in stateful modules.
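A combined sketch of both state hooks, assuming Variable: Clone and a hypothetical Variable::detach() that returns a graph-detached copy (neither is confirmed by this page; the Stateful type is illustrative):

use std::cell::RefCell;
use std::rc::Rc;

// Hypothetical wrapper that carries state between forward passes.
struct Stateful {
    inner: Rc<dyn Module>,
    hidden: RefCell<Option<Variable>>,
}

impl Module for Stateful {
    fn forward(&self, input: &Variable) -> Result<Variable> {
        let out = self.inner.forward(input)?;
        *self.hidden.borrow_mut() = Some(out.clone()); // state carried across calls
        Ok(out)
    }

    fn sub_modules(&self) -> Vec<Rc<dyn Module>> {
        vec![self.inner.clone()]
    }

    // Called by loops before iterating: drop stale state so its grad_fn
    // cannot reference freed saved tensors.
    fn reset(&self) {
        *self.hidden.borrow_mut() = None;
    }

    // Truncated BPTT: keep the value, break the gradient chain.
    fn detach_state(&self) {
        let detached = self.hidden.borrow().as_ref().map(|h| h.detach()); // detach() is hypothetical
        *self.hidden.borrow_mut() = detached;
    }
}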