Skip to main content

Module

Trait Module 

Source
pub trait Module {
    // 14 methods

    // Required method
    fn forward(&self, input: &Variable) -> Result<Variable>;

    // Provided methods
    fn parameters(&self) -> Vec<Parameter> { ... }
    fn buffers(&self) -> Vec<Buffer> { ... }
    fn name(&self) -> &str { ... }
    fn sub_modules(&self) -> Vec<Rc<dyn Module>> { ... }
    fn move_to_device(&self, _device: Device) { ... }
    fn set_training(&self, _training: bool) { ... }
    fn train(&self) { ... }
    fn eval(&self) { ... }
    fn trace(&self) -> Option<Variable> { ... }
    fn as_named_input(&self) -> Option<&dyn NamedInputModule> { ... }
    fn structural_hash(&self) -> Option<String> { ... }
    fn reset(&self) { ... }
    fn detach_state(&self) { ... }
}
Expand description

The core module trait: forward pass + parameter access.

All neural network layers implement Module. Composite modules (Graph, loops, gates) implement Module too, so they compose like any other layer.

let model = Linear::new(4, 2)?;
let x = Variable::new(Tensor::randn(&[1, 4], opts)?, false);
let y = model.forward(&x)?; // [1, 4] → [1, 2]

Required Methods§

Source

fn forward(&self, input: &Variable) -> Result<Variable>

Run the forward pass on input and return the result.

Provided Methods§

Source

fn parameters(&self) -> Vec<Parameter>

Return this module’s learnable parameters. Default: recursively collects from sub_modules() with pointer dedup. Leaf modules should override to return their own parameters.

Examples found in repository?
examples/quickstart/main.rs (line 21)
11fn main() -> Result<()> {
12    // Build the model.
13    let model = FlowBuilder::from(Linear::new(2, 16)?)
14        .through(GELU)
15        .through(LayerNorm::new(16)?)
16        .also(Linear::new(16, 16)?)
17        .through(Linear::new(16, 2)?)
18        .build()?;
19
20    // Set up training.
21    let params = model.parameters();
22    let mut optimizer = Adam::new(&params, 0.01);
23    model.train();
24
25    // Generate some random data (XOR-ish pattern).
26    let opts = TensorOptions::default();
27    let batches: Vec<(Tensor, Tensor)> = (0..32)
28        .map(|_| {
29            let x = Tensor::randn(&[16, 2], opts).unwrap();
30            let y = Tensor::randn(&[16, 2], opts).unwrap();
31            (x, y)
32        })
33        .collect();
34
35    // Training loop with monitor.
36    let num_epochs = 50;
37    let mut monitor = Monitor::new(num_epochs);
38    // monitor.serve(3000)?;   // uncomment for live dashboard
39    // monitor.watch(&model);  // uncomment to show graph SVG in dashboard
40
41    for epoch in 0..num_epochs {
42        let t = std::time::Instant::now();
43
44        for (input_t, target_t) in &batches {
45            let input = Variable::new(input_t.clone(), true);
46            let target = Variable::new(target_t.clone(), false);
47
48            let pred = model.forward(&input)?;
49            let loss = mse_loss(&pred, &target)?;
50
51            optimizer.zero_grad();
52            loss.backward()?;
53            clip_grad_norm(&params, 1.0)?;
54            optimizer.step()?;
55
56            model.record_scalar("loss", loss.item()?);
57        }
58
59        model.flush(&[]);
60        monitor.log(epoch, t.elapsed(), &model);
61    }
62
63    monitor.finish();
64    Ok(())
65}
More examples
Hide additional examples
examples/schedulers/main.rs (line 22)
11fn main() -> Result<()> {
12    let opts = TensorOptions::default();
13
14    // Build a small model.
15    let model = FlowBuilder::from(Linear::new(4, 16)?)
16        .through(GELU)
17        .through(LayerNorm::new(16)?)
18        .also(Linear::new(16, 16)?)
19        .through(Linear::new(16, 4)?)
20        .build()?;
21
22    let params = model.parameters();
23    let mut optimizer = Adam::new(&params, 0.001);
24    model.train();
25
26    let num_epochs = 100usize;
27
28    // --- Compose schedulers ---
29
30    // Warmup for 10 epochs, then cosine anneal for the remaining 90.
31    let cosine = CosineScheduler::new(0.001, 1e-5, num_epochs - 10);
32    let scheduler = WarmupScheduler::new(cosine, 0.001, 10);
33
34    // Plateau scheduler as a secondary control: if loss stalls,
35    // reduce LR further regardless of the primary schedule.
36    let mut plateau = PlateauScheduler::new(0.001, 10, 0.5, 1e-6);
37
38    println!("{:>5}  {:>10}  {:>10}  {:>10}", "epoch", "loss", "sched_lr", "eff_lr");
39    println!("{}", "-".repeat(40));
40
41    for epoch in 0..num_epochs {
42        let x = Tensor::randn(&[64, 4], opts)?;
43        let y = Tensor::randn(&[64, 4], opts)?;
44        let input = Variable::new(x, true);
45        let target = Variable::new(y, false);
46
47        optimizer.zero_grad();
48        let pred = model.forward(&input)?;
49        let loss = mse_loss(&pred, &target)?;
50        loss.backward()?;
51        clip_grad_norm(&params, 1.0)?;
52        optimizer.step()?;
53
54        let loss_val = loss.item()?;
55
56        // Primary schedule (warmup + cosine).
57        let sched_lr = scheduler.lr(epoch);
58
59        // Plateau feedback — takes the minimum of primary and plateau LR.
60        let plateau_lr = plateau.observe(loss_val);
61        let effective_lr = sched_lr.min(plateau_lr);
62
63        optimizer.set_lr(effective_lr);
64
65        if epoch % 10 == 0 || epoch == num_epochs - 1 {
66            println!(
67                "{:>5}  {:>10.6}  {:>10.6}  {:>10.6}",
68                epoch, loss_val, sched_lr, effective_lr
69            );
70        }
71    }
72
73    println!("\nFinal LR: {:.6}", optimizer.lr());
74    Ok(())
75}
examples/mixed_precision/main.rs (line 32)
12fn main() -> Result<()> {
13    let opts = TensorOptions::default();
14
15    // Random regression data.
16    let batches: Vec<(Tensor, Tensor)> = (0..4)
17        .map(|_| {
18            let x = Tensor::randn(&[50, 8], opts).unwrap();
19            let y = Tensor::randn(&[50, 4], opts).unwrap();
20            (x, y)
21        })
22        .collect();
23
24    // Build the model (starts in float32).
25    let model = FlowBuilder::from(Linear::new(8, 32)?)
26        .through(GELU)
27        .through(LayerNorm::new(32)?)
28        .also(Linear::new(32, 32)?)
29        .through(Linear::new(32, 4)?)
30        .build()?;
31
32    let params = model.parameters();
33
34    // Cast parameters to float16 for reduced memory and faster matmuls.
35    cast_parameters(&params, DType::Float16);
36    println!("Parameters cast to float16");
37
38    let mut optimizer = Adam::new(&params, 0.001);
39    let mut scaler = GradScaler::new();
40    model.train();
41
42    let num_epochs = 100usize;
43    let mut monitor = Monitor::new(num_epochs);
44
45    for epoch in 0..num_epochs {
46        let t = std::time::Instant::now();
47        let mut steps_taken = 0u32;
48
49        for (xb, yb) in &batches {
50            // Cast inputs to match parameter dtype.
51            let input = Variable::new(xb.to_dtype(DType::Float16)?, true);
52            let target = Variable::new(yb.to_dtype(DType::Float16)?, false);
53
54            optimizer.zero_grad();
55            let pred = model.forward(&input)?;
56            let loss = mse_loss(&pred, &target)?;
57
58            // Scale loss before backward to prevent gradient underflow.
59            let scaled_loss = scaler.scale(&loss)?;
60            scaled_loss.backward()?;
61
62            // Unscale, check for inf/nan, step if clean.
63            let stepped = scaler.step(&params, &mut || optimizer.step())?;
64            scaler.update();
65
66            if stepped {
67                steps_taken += 1;
68            }
69
70            model.record_scalar("loss", loss.item()?);
71        }
72
73        model.record_scalar("scale", scaler.scale_factor());
74        model.flush(&[]);
75        monitor.log(epoch, t.elapsed(), &model);
76
77        if steps_taken == 0 {
78            println!("epoch {}: all steps skipped (inf grads), scale={:.0}", epoch, scaler.scale_factor());
79        }
80    }
81
82    monitor.finish();
83
84    // Cast back to float32 for inference or checkpointing.
85    cast_parameters(&params, DType::Float32);
86    println!("Parameters cast back to float32 for export");
87
88    Ok(())
89}
examples/observation/main.rs (line 25)
12fn main() -> Result<()> {
13    let opts = TensorOptions::default();
14
15    // Build a model with tagged intermediate nodes.
16    let model = FlowBuilder::from(Linear::new(4, 16)?)
17        .through(GELU)
18        .tag("hidden")
19        .through(LayerNorm::new(16)?)
20        .also(Linear::new(16, 16)?)
21        .tag("residual")
22        .through(Linear::new(16, 4)?)
23        .build()?;
24
25    let params = model.parameters();
26    let mut optimizer = Adam::new(&params, 0.01);
27    let scheduler = CosineScheduler::new(0.01, 1e-4, 200);
28    model.train();
29
30    let num_epochs = 200usize;
31    let mut monitor = Monitor::new(num_epochs);
32
33    // Generate stable data so trends are meaningful.
34    let x_data = Tensor::randn(&[128, 4], opts)?;
35    let y_data = Tensor::randn(&[128, 4], opts)?;
36
37    // Split into batches.
38    let x_batches = x_data.batches(32)?;
39    let y_batches = y_data.batches(32)?;
40    let batches: Vec<_> = x_batches.into_iter().zip(y_batches).collect();
41
42    println!("{:>5}  {:>10}  {:>10}  {:>8}  {:>10}", "epoch", "loss", "slope", "stalled", "status");
43    println!("{}", "-".repeat(50));
44
45    for epoch in 0..num_epochs {
46        let t = std::time::Instant::now();
47
48        for (xb, yb) in &batches {
49            let input = Variable::new(xb.clone(), true);
50            let target = Variable::new(yb.clone(), false);
51
52            optimizer.zero_grad();
53            let pred = model.forward(&input)?;
54
55            // Collect tagged node outputs as metrics.
56            model.collect(&["hidden", "residual"])?;
57
58            let loss = mse_loss(&pred, &target)?;
59            loss.backward()?;
60            clip_grad_norm(&params, 1.0)?;
61            optimizer.step()?;
62
63            model.record_scalar("loss", loss.item()?);
64        }
65
66        // Flush batch metrics to epoch history.
67        let lr = scheduler.lr(epoch);
68        optimizer.set_lr(lr);
69        model.record_scalar("lr", lr);
70        model.flush(&["hidden", "residual", "loss", "lr"]);
71        monitor.log(epoch, t.elapsed(), &model);
72
73        // Query trends for early stopping (window=10 epochs).
74        let loss_trend = model.trend("loss");
75        let stalled = loss_trend.stalled(10, 1e-5);
76        let converged = loss_trend.converged(10, 1e-5);
77
78        if epoch % 20 == 0 || stalled || converged {
79            let status = if converged {
80                "CONVERGED"
81            } else if stalled {
82                "stalled"
83            } else if loss_trend.improving(10) {
84                "improving"
85            } else {
86                ""
87            };
88
89            println!(
90                "{:>5}  {:>10.6}  {:>10.6}  {:>8}  {:>10}",
91                epoch,
92                loss_trend.latest(),
93                loss_trend.slope(10),
94                stalled,
95                status
96            );
97        }
98
99        // Early stop when converged.
100        if converged && epoch > 20 {
101            println!("\nEarly stop at epoch {} — loss converged.", epoch);
102            break;
103        }
104    }
105
106    monitor.finish();
107
108    // Group trend queries (window=10 for all).
109    let group = model.trends(&["hidden", "residual", "loss"]);
110    println!("\nFinal trend summary:");
111    println!("  All improving: {}", group.all_improving(10));
112    println!("  Any stalled:   {}", group.any_stalled(10, 1e-5));
113    println!("  Mean slope:    {:.6}", group.mean_slope(10));
114
115    Ok(())
116}
examples/transfer_learning/main.rs (line 24)
11fn main() -> Result<()> {
12    let opts = TensorOptions::default();
13
14    // --- Phase 1: Train an encoder on a proxy task ---
15    println!("=== Phase 1: Train encoder ===");
16
17    let encoder = FlowBuilder::from(Linear::new(4, 16)?)
18        .through(GELU)
19        .through(LayerNorm::new(16)?)
20        .tag("encoded")
21        .through(Linear::new(16, 4)?)
22        .build()?;
23
24    let params = encoder.parameters();
25    let mut opt = Adam::new(&params, 0.01);
26    encoder.train();
27
28    for epoch in 0..50 {
29        let x = Tensor::randn(&[32, 4], opts)?;
30        let y = Tensor::randn(&[32, 4], opts)?;
31        let input = Variable::new(x, true);
32        let target = Variable::new(y, false);
33
34        opt.zero_grad();
35        let pred = encoder.forward(&input)?;
36        let loss = mse_loss(&pred, &target)?;
37        loss.backward()?;
38        clip_grad_norm(&params, 1.0)?;
39        opt.step()?;
40
41        if epoch % 10 == 0 {
42            println!("  epoch {:>3}: loss={:.4}", epoch, loss.item()?);
43        }
44    }
45
46    // Save the pretrained encoder.
47    let ckpt_path = "pretrained_encoder.fdl";
48    let named = encoder.named_parameters();
49    let named_bufs = encoder.named_buffers();
50    save_checkpoint_file(ckpt_path, &named, &named_bufs, None)?;
51    println!("Encoder saved to {}", ckpt_path);
52
53    // --- Phase 2: Build a new model and load encoder weights ---
54    println!("\n=== Phase 2: Transfer to new architecture ===");
55
56    // New architecture reuses the encoder layers but adds a different head.
57    let model = FlowBuilder::from(Linear::new(4, 16)?)
58        .through(GELU)
59        .through(LayerNorm::new(16)?)
60        .tag("encoded")
61        .also(Linear::new(16, 16)?)      // new: residual connection
62        .through(Linear::new(16, 2)?)     // new: different output dim
63        .build()?;
64
65    // Partial load: matching names get loaded, new layers keep random init.
66    let named2 = model.named_parameters();
67    let named_bufs2 = model.named_buffers();
68    let report = load_checkpoint_file(ckpt_path, &named2, &named_bufs2, None)?;
69
70    println!("Loaded {} parameters, skipped {}, missing {}",
71        report.loaded.len(), report.skipped.len(), report.missing.len());
72
73    // Freeze the encoder layers (first 3 modules: Linear, GELU, LayerNorm).
74    let all_params = model.parameters();
75    for (i, p) in all_params.iter().enumerate() {
76        if i < 3 {
77            p.freeze()?;
78        }
79    }
80    println!("Encoder layers frozen");
81
82    // Set up optimizer with parameter groups:
83    // - Frozen params are excluded automatically (zero grad).
84    // - New head gets higher LR.
85    let trainable: Vec<Parameter> = all_params
86        .iter()
87        .filter(|p| !p.is_frozen())
88        .cloned()
89        .collect();
90
91    let mut opt2 = Adam::new(&trainable, 0.005);
92    model.train();
93
94    // --- Phase 3: Fine-tune the new head ---
95    println!("\n=== Phase 3: Fine-tune new head ===");
96
97    for epoch in 0..50 {
98        let x = Tensor::randn(&[32, 4], opts)?;
99        let y = Tensor::randn(&[32, 2], opts)?;
100        let input = Variable::new(x, true);
101        let target = Variable::new(y, false);
102
103        opt2.zero_grad();
104        let pred = model.forward(&input)?;
105        let loss = mse_loss(&pred, &target)?;
106        loss.backward()?;
107        clip_grad_norm(&trainable, 1.0)?;
108        opt2.step()?;
109
110        if epoch % 10 == 0 {
111            println!("  epoch {:>3}: loss={:.4}", epoch, loss.item()?);
112        }
113    }
114
115    // --- Phase 4: Unfreeze and fine-tune everything ---
116    println!("\n=== Phase 4: Full fine-tune (unfrozen) ===");
117
118    for p in &all_params {
119        p.unfreeze()?;
120    }
121
122    let mut opt3 = Adam::new(&all_params, 0.001); // lower LR for full model
123    for epoch in 0..30 {
124        let x = Tensor::randn(&[32, 4], opts)?;
125        let y = Tensor::randn(&[32, 2], opts)?;
126        let input = Variable::new(x, true);
127        let target = Variable::new(y, false);
128
129        opt3.zero_grad();
130        let pred = model.forward(&input)?;
131        let loss = mse_loss(&pred, &target)?;
132        loss.backward()?;
133        clip_grad_norm(&all_params, 1.0)?;
134        opt3.step()?;
135
136        if epoch % 10 == 0 {
137            println!("  epoch {:>3}: loss={:.4}", epoch, loss.item()?);
138        }
139    }
140
141    println!("\nTransfer learning complete.");
142
143    // Clean up.
144    std::fs::remove_file(ckpt_path).ok();
145    Ok(())
146}
examples/sine_wave/main.rs (line 37)
12fn main() -> Result<()> {
13    // --- Data: sin(x) on [-2pi, 2pi] ---
14    let opts = TensorOptions::default();
15    let n_samples = 200i64;
16    let tau = std::f64::consts::TAU;
17    let x_all = Tensor::linspace(-tau, tau, n_samples, opts)?; // [-2pi, 2pi]
18    let y_all = x_all.sin()?;
19
20    // Reshape to [n, 1] for the network.
21    let x_data = x_all.reshape(&[n_samples, 1])?;
22    let y_data = y_all.reshape(&[n_samples, 1])?;
23
24    // Split into batches of 50.
25    let x_batches = x_data.batches(50)?;
26    let y_batches = y_data.batches(50)?;
27    let batches: Vec<_> = x_batches.into_iter().zip(y_batches).collect();
28
29    // --- Model: Linear(1,32) -> GELU -> LayerNorm -> residual -> Linear(32,1) ---
30    let model = FlowBuilder::from(Linear::new(1, 32)?)
31        .through(GELU)
32        .through(LayerNorm::new(32)?)
33        .also(Linear::new(32, 32)?)   // residual connection
34        .through(Linear::new(32, 1)?)
35        .build()?;
36
37    let params = model.parameters();
38    let mut optimizer = Adam::new(&params, 0.005);
39    let scheduler = CosineScheduler::new(0.005, 1e-5, 200);
40    model.train();
41
42    // --- Training ---
43    let num_epochs = 200usize;
44    let mut monitor = Monitor::new(num_epochs);
45    // monitor.serve(3000)?;  // uncomment for live dashboard at http://localhost:3000
46    // monitor.watch(&model);
47
48    for epoch in 0..num_epochs {
49        let t = std::time::Instant::now();
50
51        for (xb, yb) in &batches {
52            let input = Variable::new(xb.clone(), true);
53            let target = Variable::new(yb.clone(), false);
54
55            optimizer.zero_grad();
56            let pred = model.forward(&input)?;
57            let loss = mse_loss(&pred, &target)?;
58            loss.backward()?;
59            clip_grad_norm(&params, 1.0)?;
60            optimizer.step()?;
61
62            model.record_scalar("loss", loss.item()?);
63        }
64
65        let lr = scheduler.lr(epoch);
66        optimizer.set_lr(lr);
67        model.record_scalar("lr", lr);
68        model.flush(&[]);
69        monitor.log(epoch, t.elapsed(), &model);
70    }
71
72    monitor.finish();
73
74    // --- Evaluation ---
75    model.eval();
76    println!("\n{:>8}  {:>10}  {:>10}  {:>8}", "x", "actual", "predicted", "error");
77    println!("{}", "-".repeat(42));
78
79    // Test on 10 evenly spaced points.
80    let test_x = Tensor::linspace(-tau, tau, 10, opts)?;
81    let test_y = test_x.sin()?;
82    let test_input = test_x.reshape(&[10, 1])?;
83
84    let pred = no_grad(|| {
85        let input = Variable::new(test_input.clone(), false);
86        model.forward(&input)
87    })?;
88
89    let pred_data = pred.data().to_f32_vec()?;
90    let actual_data = test_y.to_f32_vec()?;
91    let x_data_vec = test_x.to_f32_vec()?;
92
93    let mut max_err: f32 = 0.0;
94    for i in 0..10 {
95        let err = (pred_data[i] - actual_data[i]).abs();
96        if err > max_err {
97            max_err = err;
98        }
99        println!(
100            "{:>8.3}  {:>10.4}  {:>10.4}  {:>8.4}",
101            x_data_vec[i], actual_data[i], pred_data[i], err
102        );
103    }
104    println!("\nMax error: {:.4}", max_err);
105
106    // --- Checkpoint round-trip ---
107    let path = "sine_model.fdl";
108    let named = model.named_parameters();
109    let named_bufs = model.named_buffers();
110    save_checkpoint_file(path, &named, &named_bufs, Some(model.structural_hash()))?;
111    println!("Checkpoint saved to {}", path);
112
113    // Rebuild architecture and load weights.
114    let model2 = FlowBuilder::from(Linear::new(1, 32)?)
115        .through(GELU)
116        .through(LayerNorm::new(32)?)
117        .also(Linear::new(32, 32)?)
118        .through(Linear::new(32, 1)?)
119        .build()?;
120
121    let named2 = model2.named_parameters();
122    let named_bufs2 = model2.named_buffers();
123    load_checkpoint_file(path, &named2, &named_bufs2, Some(model2.structural_hash()))?;
124    model2.eval();
125
126    // Verify loaded model produces the same output.
127    let pred2 = no_grad(|| {
128        let input = Variable::new(test_input.clone(), false);
129        model2.forward(&input)
130    })?;
131
132    let pred2_data = pred2.data().to_f32_vec()?;
133    let mut reload_diff: f32 = 0.0;
134    for i in 0..10 {
135        let d = (pred_data[i] - pred2_data[i]).abs();
136        if d > reload_diff {
137            reload_diff = d;
138        }
139    }
140    println!("Checkpoint reload max diff: {:.6}", reload_diff);
141    assert!(
142        reload_diff < 1e-5,
143        "Checkpoint round-trip mismatch: {}",
144        reload_diff
145    );
146    println!("Checkpoint round-trip verified.");
147
148    // Clean up.
149    std::fs::remove_file(path).ok();
150    Ok(())
151}
Source

fn buffers(&self) -> Vec<Buffer>

Return this module’s non-learnable persistent buffers (e.g., running stats). Default: recursively collects from sub_modules() with pointer dedup. Leaf modules should override to return their own buffers.

Source

fn name(&self) -> &str

Human-readable type name used as node ID prefix in graph visualization. Override to return a lowercase identifier (e.g., “linear”, “gelu”).

Source

fn sub_modules(&self) -> Vec<Rc<dyn Module>>

Return direct child modules for recursive tree walks. Override in composite modules (loops, switches, gates).

Source

fn move_to_device(&self, _device: Device)

Move all parameters and buffers to the given device. Override in modules like BatchNorm that hold non-parameter state.

Source

fn set_training(&self, _training: bool)

Set training/eval mode. Affects Dropout, BatchNorm, etc. Override in modules with mode-dependent behavior.

Source

fn train(&self)

Set training mode. Shorthand for set_training(true).

Examples found in repository?
examples/quickstart/main.rs (line 23)
11fn main() -> Result<()> {
12    // Build the model.
13    let model = FlowBuilder::from(Linear::new(2, 16)?)
14        .through(GELU)
15        .through(LayerNorm::new(16)?)
16        .also(Linear::new(16, 16)?)
17        .through(Linear::new(16, 2)?)
18        .build()?;
19
20    // Set up training.
21    let params = model.parameters();
22    let mut optimizer = Adam::new(&params, 0.01);
23    model.train();
24
25    // Generate some random data (XOR-ish pattern).
26    let opts = TensorOptions::default();
27    let batches: Vec<(Tensor, Tensor)> = (0..32)
28        .map(|_| {
29            let x = Tensor::randn(&[16, 2], opts).unwrap();
30            let y = Tensor::randn(&[16, 2], opts).unwrap();
31            (x, y)
32        })
33        .collect();
34
35    // Training loop with monitor.
36    let num_epochs = 50;
37    let mut monitor = Monitor::new(num_epochs);
38    // monitor.serve(3000)?;   // uncomment for live dashboard
39    // monitor.watch(&model);  // uncomment to show graph SVG in dashboard
40
41    for epoch in 0..num_epochs {
42        let t = std::time::Instant::now();
43
44        for (input_t, target_t) in &batches {
45            let input = Variable::new(input_t.clone(), true);
46            let target = Variable::new(target_t.clone(), false);
47
48            let pred = model.forward(&input)?;
49            let loss = mse_loss(&pred, &target)?;
50
51            optimizer.zero_grad();
52            loss.backward()?;
53            clip_grad_norm(&params, 1.0)?;
54            optimizer.step()?;
55
56            model.record_scalar("loss", loss.item()?);
57        }
58
59        model.flush(&[]);
60        monitor.log(epoch, t.elapsed(), &model);
61    }
62
63    monitor.finish();
64    Ok(())
65}
More examples
Hide additional examples
examples/schedulers/main.rs (line 24)
11fn main() -> Result<()> {
12    let opts = TensorOptions::default();
13
14    // Build a small model.
15    let model = FlowBuilder::from(Linear::new(4, 16)?)
16        .through(GELU)
17        .through(LayerNorm::new(16)?)
18        .also(Linear::new(16, 16)?)
19        .through(Linear::new(16, 4)?)
20        .build()?;
21
22    let params = model.parameters();
23    let mut optimizer = Adam::new(&params, 0.001);
24    model.train();
25
26    let num_epochs = 100usize;
27
28    // --- Compose schedulers ---
29
30    // Warmup for 10 epochs, then cosine anneal for the remaining 90.
31    let cosine = CosineScheduler::new(0.001, 1e-5, num_epochs - 10);
32    let scheduler = WarmupScheduler::new(cosine, 0.001, 10);
33
34    // Plateau scheduler as a secondary control: if loss stalls,
35    // reduce LR further regardless of the primary schedule.
36    let mut plateau = PlateauScheduler::new(0.001, 10, 0.5, 1e-6);
37
38    println!("{:>5}  {:>10}  {:>10}  {:>10}", "epoch", "loss", "sched_lr", "eff_lr");
39    println!("{}", "-".repeat(40));
40
41    for epoch in 0..num_epochs {
42        let x = Tensor::randn(&[64, 4], opts)?;
43        let y = Tensor::randn(&[64, 4], opts)?;
44        let input = Variable::new(x, true);
45        let target = Variable::new(y, false);
46
47        optimizer.zero_grad();
48        let pred = model.forward(&input)?;
49        let loss = mse_loss(&pred, &target)?;
50        loss.backward()?;
51        clip_grad_norm(&params, 1.0)?;
52        optimizer.step()?;
53
54        let loss_val = loss.item()?;
55
56        // Primary schedule (warmup + cosine).
57        let sched_lr = scheduler.lr(epoch);
58
59        // Plateau feedback — takes the minimum of primary and plateau LR.
60        let plateau_lr = plateau.observe(loss_val);
61        let effective_lr = sched_lr.min(plateau_lr);
62
63        optimizer.set_lr(effective_lr);
64
65        if epoch % 10 == 0 || epoch == num_epochs - 1 {
66            println!(
67                "{:>5}  {:>10.6}  {:>10.6}  {:>10.6}",
68                epoch, loss_val, sched_lr, effective_lr
69            );
70        }
71    }
72
73    println!("\nFinal LR: {:.6}", optimizer.lr());
74    Ok(())
75}
examples/mixed_precision/main.rs (line 40)
12fn main() -> Result<()> {
13    let opts = TensorOptions::default();
14
15    // Random regression data.
16    let batches: Vec<(Tensor, Tensor)> = (0..4)
17        .map(|_| {
18            let x = Tensor::randn(&[50, 8], opts).unwrap();
19            let y = Tensor::randn(&[50, 4], opts).unwrap();
20            (x, y)
21        })
22        .collect();
23
24    // Build the model (starts in float32).
25    let model = FlowBuilder::from(Linear::new(8, 32)?)
26        .through(GELU)
27        .through(LayerNorm::new(32)?)
28        .also(Linear::new(32, 32)?)
29        .through(Linear::new(32, 4)?)
30        .build()?;
31
32    let params = model.parameters();
33
34    // Cast parameters to float16 for reduced memory and faster matmuls.
35    cast_parameters(&params, DType::Float16);
36    println!("Parameters cast to float16");
37
38    let mut optimizer = Adam::new(&params, 0.001);
39    let mut scaler = GradScaler::new();
40    model.train();
41
42    let num_epochs = 100usize;
43    let mut monitor = Monitor::new(num_epochs);
44
45    for epoch in 0..num_epochs {
46        let t = std::time::Instant::now();
47        let mut steps_taken = 0u32;
48
49        for (xb, yb) in &batches {
50            // Cast inputs to match parameter dtype.
51            let input = Variable::new(xb.to_dtype(DType::Float16)?, true);
52            let target = Variable::new(yb.to_dtype(DType::Float16)?, false);
53
54            optimizer.zero_grad();
55            let pred = model.forward(&input)?;
56            let loss = mse_loss(&pred, &target)?;
57
58            // Scale loss before backward to prevent gradient underflow.
59            let scaled_loss = scaler.scale(&loss)?;
60            scaled_loss.backward()?;
61
62            // Unscale, check for inf/nan, step if clean.
63            let stepped = scaler.step(&params, &mut || optimizer.step())?;
64            scaler.update();
65
66            if stepped {
67                steps_taken += 1;
68            }
69
70            model.record_scalar("loss", loss.item()?);
71        }
72
73        model.record_scalar("scale", scaler.scale_factor());
74        model.flush(&[]);
75        monitor.log(epoch, t.elapsed(), &model);
76
77        if steps_taken == 0 {
78            println!("epoch {}: all steps skipped (inf grads), scale={:.0}", epoch, scaler.scale_factor());
79        }
80    }
81
82    monitor.finish();
83
84    // Cast back to float32 for inference or checkpointing.
85    cast_parameters(&params, DType::Float32);
86    println!("Parameters cast back to float32 for export");
87
88    Ok(())
89}
examples/observation/main.rs (line 28)
12fn main() -> Result<()> {
13    let opts = TensorOptions::default();
14
15    // Build a model with tagged intermediate nodes.
16    let model = FlowBuilder::from(Linear::new(4, 16)?)
17        .through(GELU)
18        .tag("hidden")
19        .through(LayerNorm::new(16)?)
20        .also(Linear::new(16, 16)?)
21        .tag("residual")
22        .through(Linear::new(16, 4)?)
23        .build()?;
24
25    let params = model.parameters();
26    let mut optimizer = Adam::new(&params, 0.01);
27    let scheduler = CosineScheduler::new(0.01, 1e-4, 200);
28    model.train();
29
30    let num_epochs = 200usize;
31    let mut monitor = Monitor::new(num_epochs);
32
33    // Generate stable data so trends are meaningful.
34    let x_data = Tensor::randn(&[128, 4], opts)?;
35    let y_data = Tensor::randn(&[128, 4], opts)?;
36
37    // Split into batches.
38    let x_batches = x_data.batches(32)?;
39    let y_batches = y_data.batches(32)?;
40    let batches: Vec<_> = x_batches.into_iter().zip(y_batches).collect();
41
42    println!("{:>5}  {:>10}  {:>10}  {:>8}  {:>10}", "epoch", "loss", "slope", "stalled", "status");
43    println!("{}", "-".repeat(50));
44
45    for epoch in 0..num_epochs {
46        let t = std::time::Instant::now();
47
48        for (xb, yb) in &batches {
49            let input = Variable::new(xb.clone(), true);
50            let target = Variable::new(yb.clone(), false);
51
52            optimizer.zero_grad();
53            let pred = model.forward(&input)?;
54
55            // Collect tagged node outputs as metrics.
56            model.collect(&["hidden", "residual"])?;
57
58            let loss = mse_loss(&pred, &target)?;
59            loss.backward()?;
60            clip_grad_norm(&params, 1.0)?;
61            optimizer.step()?;
62
63            model.record_scalar("loss", loss.item()?);
64        }
65
66        // Flush batch metrics to epoch history.
67        let lr = scheduler.lr(epoch);
68        optimizer.set_lr(lr);
69        model.record_scalar("lr", lr);
70        model.flush(&["hidden", "residual", "loss", "lr"]);
71        monitor.log(epoch, t.elapsed(), &model);
72
73        // Query trends for early stopping (window=10 epochs).
74        let loss_trend = model.trend("loss");
75        let stalled = loss_trend.stalled(10, 1e-5);
76        let converged = loss_trend.converged(10, 1e-5);
77
78        if epoch % 20 == 0 || stalled || converged {
79            let status = if converged {
80                "CONVERGED"
81            } else if stalled {
82                "stalled"
83            } else if loss_trend.improving(10) {
84                "improving"
85            } else {
86                ""
87            };
88
89            println!(
90                "{:>5}  {:>10.6}  {:>10.6}  {:>8}  {:>10}",
91                epoch,
92                loss_trend.latest(),
93                loss_trend.slope(10),
94                stalled,
95                status
96            );
97        }
98
99        // Early stop when converged.
100        if converged && epoch > 20 {
101            println!("\nEarly stop at epoch {} — loss converged.", epoch);
102            break;
103        }
104    }
105
106    monitor.finish();
107
108    // Group trend queries (window=10 for all).
109    let group = model.trends(&["hidden", "residual", "loss"]);
110    println!("\nFinal trend summary:");
111    println!("  All improving: {}", group.all_improving(10));
112    println!("  Any stalled:   {}", group.any_stalled(10, 1e-5));
113    println!("  Mean slope:    {:.6}", group.mean_slope(10));
114
115    Ok(())
116}
examples/transfer_learning/main.rs (line 26)
11fn main() -> Result<()> {
12    let opts = TensorOptions::default();
13
14    // --- Phase 1: Train an encoder on a proxy task ---
15    println!("=== Phase 1: Train encoder ===");
16
17    let encoder = FlowBuilder::from(Linear::new(4, 16)?)
18        .through(GELU)
19        .through(LayerNorm::new(16)?)
20        .tag("encoded")
21        .through(Linear::new(16, 4)?)
22        .build()?;
23
24    let params = encoder.parameters();
25    let mut opt = Adam::new(&params, 0.01);
26    encoder.train();
27
28    for epoch in 0..50 {
29        let x = Tensor::randn(&[32, 4], opts)?;
30        let y = Tensor::randn(&[32, 4], opts)?;
31        let input = Variable::new(x, true);
32        let target = Variable::new(y, false);
33
34        opt.zero_grad();
35        let pred = encoder.forward(&input)?;
36        let loss = mse_loss(&pred, &target)?;
37        loss.backward()?;
38        clip_grad_norm(&params, 1.0)?;
39        opt.step()?;
40
41        if epoch % 10 == 0 {
42            println!("  epoch {:>3}: loss={:.4}", epoch, loss.item()?);
43        }
44    }
45
46    // Save the pretrained encoder.
47    let ckpt_path = "pretrained_encoder.fdl";
48    let named = encoder.named_parameters();
49    let named_bufs = encoder.named_buffers();
50    save_checkpoint_file(ckpt_path, &named, &named_bufs, None)?;
51    println!("Encoder saved to {}", ckpt_path);
52
53    // --- Phase 2: Build a new model and load encoder weights ---
54    println!("\n=== Phase 2: Transfer to new architecture ===");
55
56    // New architecture reuses the encoder layers but adds a different head.
57    let model = FlowBuilder::from(Linear::new(4, 16)?)
58        .through(GELU)
59        .through(LayerNorm::new(16)?)
60        .tag("encoded")
61        .also(Linear::new(16, 16)?)      // new: residual connection
62        .through(Linear::new(16, 2)?)     // new: different output dim
63        .build()?;
64
65    // Partial load: matching names get loaded, new layers keep random init.
66    let named2 = model.named_parameters();
67    let named_bufs2 = model.named_buffers();
68    let report = load_checkpoint_file(ckpt_path, &named2, &named_bufs2, None)?;
69
70    println!("Loaded {} parameters, skipped {}, missing {}",
71        report.loaded.len(), report.skipped.len(), report.missing.len());
72
73    // Freeze the encoder layers (first 3 modules: Linear, GELU, LayerNorm).
74    let all_params = model.parameters();
75    for (i, p) in all_params.iter().enumerate() {
76        if i < 3 {
77            p.freeze()?;
78        }
79    }
80    println!("Encoder layers frozen");
81
82    // Set up optimizer with parameter groups:
83    // - Frozen params are excluded automatically (zero grad).
84    // - New head gets higher LR.
85    let trainable: Vec<Parameter> = all_params
86        .iter()
87        .filter(|p| !p.is_frozen())
88        .cloned()
89        .collect();
90
91    let mut opt2 = Adam::new(&trainable, 0.005);
92    model.train();
93
94    // --- Phase 3: Fine-tune the new head ---
95    println!("\n=== Phase 3: Fine-tune new head ===");
96
97    for epoch in 0..50 {
98        let x = Tensor::randn(&[32, 4], opts)?;
99        let y = Tensor::randn(&[32, 2], opts)?;
100        let input = Variable::new(x, true);
101        let target = Variable::new(y, false);
102
103        opt2.zero_grad();
104        let pred = model.forward(&input)?;
105        let loss = mse_loss(&pred, &target)?;
106        loss.backward()?;
107        clip_grad_norm(&trainable, 1.0)?;
108        opt2.step()?;
109
110        if epoch % 10 == 0 {
111            println!("  epoch {:>3}: loss={:.4}", epoch, loss.item()?);
112        }
113    }
114
115    // --- Phase 4: Unfreeze and fine-tune everything ---
116    println!("\n=== Phase 4: Full fine-tune (unfrozen) ===");
117
118    for p in &all_params {
119        p.unfreeze()?;
120    }
121
122    let mut opt3 = Adam::new(&all_params, 0.001); // lower LR for full model
123    for epoch in 0..30 {
124        let x = Tensor::randn(&[32, 4], opts)?;
125        let y = Tensor::randn(&[32, 2], opts)?;
126        let input = Variable::new(x, true);
127        let target = Variable::new(y, false);
128
129        opt3.zero_grad();
130        let pred = model.forward(&input)?;
131        let loss = mse_loss(&pred, &target)?;
132        loss.backward()?;
133        clip_grad_norm(&all_params, 1.0)?;
134        opt3.step()?;
135
136        if epoch % 10 == 0 {
137            println!("  epoch {:>3}: loss={:.4}", epoch, loss.item()?);
138        }
139    }
140
141    println!("\nTransfer learning complete.");
142
143    // Clean up.
144    std::fs::remove_file(ckpt_path).ok();
145    Ok(())
146}
examples/sine_wave/main.rs (line 40)
12fn main() -> Result<()> {
13    // --- Data: sin(x) on [-2pi, 2pi] ---
14    let opts = TensorOptions::default();
15    let n_samples = 200i64;
16    let tau = std::f64::consts::TAU;
17    let x_all = Tensor::linspace(-tau, tau, n_samples, opts)?; // [-2pi, 2pi]
18    let y_all = x_all.sin()?;
19
20    // Reshape to [n, 1] for the network.
21    let x_data = x_all.reshape(&[n_samples, 1])?;
22    let y_data = y_all.reshape(&[n_samples, 1])?;
23
24    // Split into batches of 50.
25    let x_batches = x_data.batches(50)?;
26    let y_batches = y_data.batches(50)?;
27    let batches: Vec<_> = x_batches.into_iter().zip(y_batches).collect();
28
29    // --- Model: Linear(1,32) -> GELU -> LayerNorm -> residual -> Linear(32,1) ---
30    let model = FlowBuilder::from(Linear::new(1, 32)?)
31        .through(GELU)
32        .through(LayerNorm::new(32)?)
33        .also(Linear::new(32, 32)?)   // residual connection
34        .through(Linear::new(32, 1)?)
35        .build()?;
36
37    let params = model.parameters();
38    let mut optimizer = Adam::new(&params, 0.005);
39    let scheduler = CosineScheduler::new(0.005, 1e-5, 200);
40    model.train();
41
42    // --- Training ---
43    let num_epochs = 200usize;
44    let mut monitor = Monitor::new(num_epochs);
45    // monitor.serve(3000)?;  // uncomment for live dashboard at http://localhost:3000
46    // monitor.watch(&model);
47
48    for epoch in 0..num_epochs {
49        let t = std::time::Instant::now();
50
51        for (xb, yb) in &batches {
52            let input = Variable::new(xb.clone(), true);
53            let target = Variable::new(yb.clone(), false);
54
55            optimizer.zero_grad();
56            let pred = model.forward(&input)?;
57            let loss = mse_loss(&pred, &target)?;
58            loss.backward()?;
59            clip_grad_norm(&params, 1.0)?;
60            optimizer.step()?;
61
62            model.record_scalar("loss", loss.item()?);
63        }
64
65        let lr = scheduler.lr(epoch);
66        optimizer.set_lr(lr);
67        model.record_scalar("lr", lr);
68        model.flush(&[]);
69        monitor.log(epoch, t.elapsed(), &model);
70    }
71
72    monitor.finish();
73
74    // --- Evaluation ---
75    model.eval();
76    println!("\n{:>8}  {:>10}  {:>10}  {:>8}", "x", "actual", "predicted", "error");
77    println!("{}", "-".repeat(42));
78
79    // Test on 10 evenly spaced points.
80    let test_x = Tensor::linspace(-tau, tau, 10, opts)?;
81    let test_y = test_x.sin()?;
82    let test_input = test_x.reshape(&[10, 1])?;
83
84    let pred = no_grad(|| {
85        let input = Variable::new(test_input.clone(), false);
86        model.forward(&input)
87    })?;
88
89    let pred_data = pred.data().to_f32_vec()?;
90    let actual_data = test_y.to_f32_vec()?;
91    let x_data_vec = test_x.to_f32_vec()?;
92
93    let mut max_err: f32 = 0.0;
94    for i in 0..10 {
95        let err = (pred_data[i] - actual_data[i]).abs();
96        if err > max_err {
97            max_err = err;
98        }
99        println!(
100            "{:>8.3}  {:>10.4}  {:>10.4}  {:>8.4}",
101            x_data_vec[i], actual_data[i], pred_data[i], err
102        );
103    }
104    println!("\nMax error: {:.4}", max_err);
105
106    // --- Checkpoint round-trip ---
107    let path = "sine_model.fdl";
108    let named = model.named_parameters();
109    let named_bufs = model.named_buffers();
110    save_checkpoint_file(path, &named, &named_bufs, Some(model.structural_hash()))?;
111    println!("Checkpoint saved to {}", path);
112
113    // Rebuild architecture and load weights.
114    let model2 = FlowBuilder::from(Linear::new(1, 32)?)
115        .through(GELU)
116        .through(LayerNorm::new(32)?)
117        .also(Linear::new(32, 32)?)
118        .through(Linear::new(32, 1)?)
119        .build()?;
120
121    let named2 = model2.named_parameters();
122    let named_bufs2 = model2.named_buffers();
123    load_checkpoint_file(path, &named2, &named_bufs2, Some(model2.structural_hash()))?;
124    model2.eval();
125
126    // Verify loaded model produces the same output.
127    let pred2 = no_grad(|| {
128        let input = Variable::new(test_input.clone(), false);
129        model2.forward(&input)
130    })?;
131
132    let pred2_data = pred2.data().to_f32_vec()?;
133    let mut reload_diff: f32 = 0.0;
134    for i in 0..10 {
135        let d = (pred_data[i] - pred2_data[i]).abs();
136        if d > reload_diff {
137            reload_diff = d;
138        }
139    }
140    println!("Checkpoint reload max diff: {:.6}", reload_diff);
141    assert!(
142        reload_diff < 1e-5,
143        "Checkpoint round-trip mismatch: {}",
144        reload_diff
145    );
146    println!("Checkpoint round-trip verified.");
147
148    // Clean up.
149    std::fs::remove_file(path).ok();
150    Ok(())
151}
Source

fn eval(&self)

Set eval mode. Shorthand for set_training(false).

Examples found in repository?
examples/sine_wave/main.rs (line 75)
12fn main() -> Result<()> {
13    // --- Data: sin(x) on [-2pi, 2pi] ---
14    let opts = TensorOptions::default();
15    let n_samples = 200i64;
16    let tau = std::f64::consts::TAU;
17    let x_all = Tensor::linspace(-tau, tau, n_samples, opts)?; // [-2pi, 2pi]
18    let y_all = x_all.sin()?;
19
20    // Reshape to [n, 1] for the network.
21    let x_data = x_all.reshape(&[n_samples, 1])?;
22    let y_data = y_all.reshape(&[n_samples, 1])?;
23
24    // Split into batches of 50.
25    let x_batches = x_data.batches(50)?;
26    let y_batches = y_data.batches(50)?;
27    let batches: Vec<_> = x_batches.into_iter().zip(y_batches).collect();
28
29    // --- Model: Linear(1,32) -> GELU -> LayerNorm -> residual -> Linear(32,1) ---
30    let model = FlowBuilder::from(Linear::new(1, 32)?)
31        .through(GELU)
32        .through(LayerNorm::new(32)?)
33        .also(Linear::new(32, 32)?)   // residual connection
34        .through(Linear::new(32, 1)?)
35        .build()?;
36
37    let params = model.parameters();
38    let mut optimizer = Adam::new(&params, 0.005);
39    let scheduler = CosineScheduler::new(0.005, 1e-5, 200);
40    model.train();
41
42    // --- Training ---
43    let num_epochs = 200usize;
44    let mut monitor = Monitor::new(num_epochs);
45    // monitor.serve(3000)?;  // uncomment for live dashboard at http://localhost:3000
46    // monitor.watch(&model);
47
48    for epoch in 0..num_epochs {
49        let t = std::time::Instant::now();
50
51        for (xb, yb) in &batches {
52            let input = Variable::new(xb.clone(), true);
53            let target = Variable::new(yb.clone(), false);
54
55            optimizer.zero_grad();
56            let pred = model.forward(&input)?;
57            let loss = mse_loss(&pred, &target)?;
58            loss.backward()?;
59            clip_grad_norm(&params, 1.0)?;
60            optimizer.step()?;
61
62            model.record_scalar("loss", loss.item()?);
63        }
64
65        let lr = scheduler.lr(epoch);
66        optimizer.set_lr(lr);
67        model.record_scalar("lr", lr);
68        model.flush(&[]);
69        monitor.log(epoch, t.elapsed(), &model);
70    }
71
72    monitor.finish();
73
74    // --- Evaluation ---
75    model.eval();
76    println!("\n{:>8}  {:>10}  {:>10}  {:>8}", "x", "actual", "predicted", "error");
77    println!("{}", "-".repeat(42));
78
79    // Test on 10 evenly spaced points.
80    let test_x = Tensor::linspace(-tau, tau, 10, opts)?;
81    let test_y = test_x.sin()?;
82    let test_input = test_x.reshape(&[10, 1])?;
83
84    let pred = no_grad(|| {
85        let input = Variable::new(test_input.clone(), false);
86        model.forward(&input)
87    })?;
88
89    let pred_data = pred.data().to_f32_vec()?;
90    let actual_data = test_y.to_f32_vec()?;
91    let x_data_vec = test_x.to_f32_vec()?;
92
93    let mut max_err: f32 = 0.0;
94    for i in 0..10 {
95        let err = (pred_data[i] - actual_data[i]).abs();
96        if err > max_err {
97            max_err = err;
98        }
99        println!(
100            "{:>8.3}  {:>10.4}  {:>10.4}  {:>8.4}",
101            x_data_vec[i], actual_data[i], pred_data[i], err
102        );
103    }
104    println!("\nMax error: {:.4}", max_err);
105
106    // --- Checkpoint round-trip ---
107    let path = "sine_model.fdl";
108    let named = model.named_parameters();
109    let named_bufs = model.named_buffers();
110    save_checkpoint_file(path, &named, &named_bufs, Some(model.structural_hash()))?;
111    println!("Checkpoint saved to {}", path);
112
113    // Rebuild architecture and load weights.
114    let model2 = FlowBuilder::from(Linear::new(1, 32)?)
115        .through(GELU)
116        .through(LayerNorm::new(32)?)
117        .also(Linear::new(32, 32)?)
118        .through(Linear::new(32, 1)?)
119        .build()?;
120
121    let named2 = model2.named_parameters();
122    let named_bufs2 = model2.named_buffers();
123    load_checkpoint_file(path, &named2, &named_bufs2, Some(model2.structural_hash()))?;
124    model2.eval();
125
126    // Verify loaded model produces the same output.
127    let pred2 = no_grad(|| {
128        let input = Variable::new(test_input.clone(), false);
129        model2.forward(&input)
130    })?;
131
132    let pred2_data = pred2.data().to_f32_vec()?;
133    let mut reload_diff: f32 = 0.0;
134    for i in 0..10 {
135        let d = (pred_data[i] - pred2_data[i]).abs();
136        if d > reload_diff {
137            reload_diff = d;
138        }
139    }
140    println!("Checkpoint reload max diff: {:.6}", reload_diff);
141    assert!(
142        reload_diff < 1e-5,
143        "Checkpoint round-trip mismatch: {}",
144        reload_diff
145    );
146    println!("Checkpoint round-trip verified.");
147
148    // Clean up.
149    std::fs::remove_file(path).ok();
150    Ok(())
151}
More examples
Hide additional examples
examples/showcase/main.rs (line 822)
694fn main() {
695    println!("=== floDl showcase ===\n");
696
697    // -- Build --
698    println!("Building graph...");
699    let g = build_showcase().expect("build failed");
700    let n_params = g.parameters().len();
701    println!("Parameters: {}", n_params);
702
703    // -- Forward (with auxiliary input) --
704    let result = g.forward_multi(&[make_input(false), make_context()])
705        .expect("forward failed");
706    println!("Output: {:?} (shape {:?})", result.data().to_f32_vec().unwrap(), result.shape());
707
708    // -- Forward ref carries state --
709    g.reset_state();
710    let r1 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
711    let v1 = r1.data().to_f32_vec().unwrap();
712    let r2 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
713    let v2 = r2.data().to_f32_vec().unwrap();
714    println!("State drift: pass2 differs = {}", v1 != v2);
715
716    // -- Reset --
717    g.reset_state();
718    let r3 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
719    let v3 = r3.data().to_f32_vec().unwrap();
720    println!("Reset restores: {}", v1 == v3);
721
722    // -- DOT + SVG (structural) --
723    let dot = g.dot();
724    println!("DOT: {} bytes", dot.len());
725
726    // Write structural DOT
727    let dot_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase.dot");
728    std::fs::write(dot_path, &dot).expect("write showcase.dot");
729    println!("Wrote {}", dot_path);
730
731    // Write structural SVG
732    let svg_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase.svg");
733    let svg = g.svg(Some(svg_path)).expect("write showcase.svg");
734    println!("Wrote {} ({} bytes)", svg_path, svg.len());
735
736    // -- Training loop with observation + profiling + monitor --
737    println!("\n--- Training (5 epochs x 4 steps) ---");
738    g.train();
739    g.reset_state();
740    g.enable_profiling();
741
742    let params = g.parameters();
743    let mut optimizer = Adam::new(&params, 0.001);
744    let num_epochs = 5;
745    let total_steps = num_epochs * 4;
746    let sched = CosineScheduler::new(0.001, 1e-5, total_steps);
747    let mut monitor = Monitor::new(num_epochs);
748
749    let mut step_idx = 0;
750    for epoch in 0..num_epochs {
751        let t = std::time::Instant::now();
752        for _ in 0..4 {
753            optimizer.zero_grad();
754            let input = make_input(true);
755            let ctx = make_context();
756            let target = make_target();
757
758            let pred = g.forward_multi(&[input, ctx]).unwrap();
759            let loss = mse_loss(&pred, &target).unwrap();
760
761            loss.backward().unwrap();
762            clip_grad_norm(&params, 1.0).unwrap();
763            optimizer.set_lr(sched.lr(step_idx));
764            optimizer.step().unwrap();
765            step_idx += 1;
766
767            g.record_scalar("loss", loss.item().unwrap());
768            g.record_scalar("lr", sched.lr(step_idx - 1));
769            g.end_step();
770        }
771
772        g.end_epoch();
773        monitor.log(epoch, t.elapsed(), &g);
774    }
775
776    // -- Trends --
777    let trend = g.trend("loss");
778    println!(
779        "\nLoss trend: {} epochs, slope={:.4}, improving={}",
780        trend.len(),
781        trend.slope(0),
782        trend.improving(0),
783    );
784
785    // Timing trends use node IDs — pick the first tagged one
786    let timing = g.timing_trend("input");
787    println!(
788        "Timing trend (input node): {} epochs, mean={:.1}us",
789        timing.len(),
790        timing.mean() * 1e6,
791    );
792
793    // -- Write profiling DOT + SVG --
794    let profile_dot = g.dot_with_profile();
795    let profile_dot_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_profile.dot");
796    std::fs::write(profile_dot_path, &profile_dot).expect("write showcase_profile.dot");
797    println!("Wrote {}", profile_dot_path);
798
799    let profile_svg_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_profile.svg");
800    let profile_svg = g.svg_with_profile(Some(profile_svg_path)).expect("write showcase_profile.svg");
801    println!("Wrote {} ({} bytes)", profile_svg_path, profile_svg.len());
802
803    // -- Write training HTML --
804    let html_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_training.html");
805    g.plot_html(html_path, &["loss"]).expect("write showcase_training.html");
806    println!("Wrote {}", html_path);
807
808    // -- Write training log --
809    let log_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_training.log");
810    g.write_log(log_path, 5, &["loss"]).expect("write showcase_training.log");
811    println!("Wrote {}", log_path);
812
813    // -- Checkpoint round-trip --
814    let path = "/tmp/flodl_showcase_checkpoint.fdl";
815    let named = g.named_parameters();
816    let named_bufs = g.named_buffers();
817    save_checkpoint_file(path, &named, &named_bufs, Some(g.structural_hash())).expect("save failed");
818    let report = load_checkpoint_file(path, &named, &named_bufs, Some(g.structural_hash())).expect("load failed");
819    println!("\nCheckpoint save/load: OK ({} loaded)", report.loaded.len());
820
821    // -- no_grad inference (eval mode works now — BatchNorm has running stats from training) --
822    g.eval();
823    g.reset_state();
824    let final_out = no_grad(|| g.forward_multi(&[make_input(false), make_context()])).unwrap();
825    let final_vals = final_out.data().to_f32_vec().unwrap();
826    println!("no_grad inference: {:?}", final_vals);
827    assert!(final_vals.iter().all(|v| v.is_finite()), "no_grad output should be finite");
828
829    println!("\nAll showcase checks passed.");
830}
Source

fn trace(&self) -> Option<Variable>

Return the per-iteration side output for loop tracing. Override in loop-body modules that capture trajectory data (e.g., attention fixation points). Returns None by default. When it returns Some, the loop executor collects the traces, which are accessible via Graph::traces().

Source

fn as_named_input(&self) -> Option<&dyn NamedInputModule>

Upcast to NamedInputModule for multi-input graphs. Override in types that implement NamedInputModule to enable receiving additional named inputs via the graph's `using()` method.

Source

fn structural_hash(&self) -> Option<String>

SHA-256 hex hash of module architecture for checkpoint validation. Override in composite modules (Graph) that compute a deterministic hash from their topology and parameter shapes.

Source

fn reset(&self)

Reset internal state (e.g. recurrent hidden state) between sequences. Called by loops before iterating to clear stale tensors whose grad_fns may reference freed saved tensors. Override in stateful modules.

Source

fn detach_state(&self)

Detach internal state from the computation graph (for truncated BPTT). Called between training steps to break gradient chains on state carried across forward passes (e.g., recurrent hidden state). Override in stateful modules.

Implementors§