
train_and_infer.rs

//! End-to-end example: initialize model, train it, save, load, and run inference.
//!
//! ```bash
//! cargo run --example train_and_infer
//! ```

use std::path::Path;

use mamba_rs::config::MambaConfig;
use mamba_rs::ops::dims::{MambaDims, MambaRecurrentState};
use mamba_rs::train::backward::backward_mamba_backbone_batched;
use mamba_rs::train::flat::MambaBackboneFlat;
use mamba_rs::train::forward::forward_mamba_backbone_batched;
use mamba_rs::train::scratch::{BackwardPhaseScratch, PhaseScratch};
use mamba_rs::train::weights::{TrainMambaLayerWeights, TrainMambaWeights};
use mamba_rs::{MambaBackbone, MambaWeights};

fn main() {
    let cfg = MambaConfig::default(); // d_model=128, 3 layers, 366K params
    let input_dim = cfg.d_model;
    let seq_len = 16;
    let lr = 1e-3_f32;
    let steps = 50;

    println!("mamba-rs: train -> save -> load -> inference");
    println!("=============================================");
    println!(
        "Config: d_model={}, layers={}, d_inner={}, params={}",
        cfg.d_model,
        cfg.n_layers,
        cfg.d_inner(),
        MambaBackbone::init(cfg, input_dim, 42).param_count()
    );
    println!();

    // =====================================================================
    // Step 1: Initialize weights
    // =====================================================================
    let inf_weights = MambaWeights::init(&cfg, input_dim, 42);
    let mut tw = train_weights_from_inference(&inf_weights);

    let dims = MambaDims::from_config(&cfg, seq_len, input_dim);
    let di = cfg.d_inner();
    let ds = cfg.d_state;
    let dc = cfg.d_conv;
    let dm = cfg.d_model;

    // Pre-allocate scratch (reused every step)
    let mut acts = MambaBackboneFlat::zeros(dims);
    let mut fwd_scratch = PhaseScratch::zeros(&dims);
    let mut bwd_scratch = BackwardPhaseScratch::zeros(&dims);

    // Synthetic training data: a fixed pseudo-random input pattern in
    // [-0.5, 0.49]; the regression target is all zeros.
    let input: Vec<f32> = (0..seq_len * input_dim)
        .map(|i| ((i * 7 + 13) % 100) as f32 * 0.01 - 0.5)
        .collect();

    // =====================================================================
    // Step 2: Training loop (SGD, MSE loss)
    // =====================================================================
    println!("Training ({steps} steps, lr={lr}, seq_len={seq_len}):");

    for step in 0..steps {
        // Precompute a_neg = -exp(a_log)
        let mut a_neg = vec![0.0f32; cfg.n_layers * di * ds];
        for (l, lw) in tw.layers.iter().enumerate() {
            for i in 0..di * ds {
                a_neg[l * di * ds + i] = -lw.a_log[i].exp();
            }
        }
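
        // Why a_neg = -exp(a_log): storing log(-A) keeps A strictly negative
        // no matter what gradient updates do to `a_log`, so the discretized
        // transition exp(dt * A) stays in (0, 1) and the SSM recurrence stays
        // stable (the standard Mamba/S4 parameterization).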

        // Forward pass
        let mut conv = vec![0.0f32; cfg.n_layers * di * dc];
        let mut ssm = vec![0.0f32; cfg.n_layers * di * ds];
        let mut state = MambaRecurrentState {
            conv: &mut conv,
            ssm: &mut ssm,
            a_neg: &a_neg,
        };
        let mut temporal = vec![0.0f32; seq_len * dm];
        forward_mamba_backbone_batched(
            &mut temporal,
            &mut acts,
            &tw,
            &input,
            &mut state,
            &mut fwd_scratch,
            &dims,
        );

        // MSE loss
        let loss: f32 = temporal.iter().map(|v| v * v).sum::<f32>() / (seq_len * dm) as f32;

        // Backward pass
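        // With an all-zero target, L = (1/(T*d_model)) * sum(y^2) over the
        // T = seq_len outputs, so the seed gradient is dL/dy = 2*y/(T*d_model).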
        let scale = 2.0 / (seq_len * dm) as f32;
        let mut d_temporal: Vec<f32> = temporal.iter().map(|v| v * scale).collect();
        let mut grads = TrainMambaWeights::zeros_from_dims(&dims);
        backward_mamba_backbone_batched(
            &mut d_temporal,
            &mut grads,
            &acts,
            &tw,
            &a_neg,
            &mut bwd_scratch,
            &dims,
        );

        // SGD weight update
        sgd_update_all(&mut tw, &grads, lr);

        if step % 10 == 0 || step == steps - 1 {
            println!("  step {step:3}: loss = {loss:.6}");
        }
    }

    // =====================================================================
    // Step 3: Save trained weights
    // =====================================================================
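    // The .safetensors layout (a JSON header followed by raw tensor bytes) is
    // language-agnostic, so the saved file should also be readable by other
    // safetensors implementations, e.g. the Python `safetensors` package.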
    let save_path = std::env::temp_dir().join("mamba_rs_trained.safetensors");
    let trained_inf_weights = inference_weights_from_train(&tw, &inf_weights);
    mamba_rs::serialize::save(&save_path, &trained_inf_weights, &cfg, input_dim)
        .expect("save failed");
    println!("\nSaved to: {}", save_path.display());

    // =====================================================================
    // Step 4: Load and run inference
    // =====================================================================
    let (loaded_w, loaded_cfg, loaded_input_dim) =
        mamba_rs::serialize::load(Path::new(&save_path)).expect("load failed");
    let bb = MambaBackbone::from_weights(loaded_cfg, loaded_w).expect("from_weights failed");
    assert_eq!(bb.input_dim(), loaded_input_dim);
    println!(
        "Loaded: d_model={}, layers={}, params={}",
        loaded_cfg.d_model,
        loaded_cfg.n_layers,
        bb.param_count()
    );

    // Run inference step-by-step
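    // `forward_step` consumes one timestep and mutates the recurrent state in
    // place, so per-step cost is independent of how much history has been
    // seen: this is Mamba's O(1)-per-token recurrent inference mode.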
    let mut inf_state = bb.alloc_state();
    let mut scratch = bb.alloc_scratch();
    let mut output = vec![0.0f32; dm];

    println!("\nInference (10 steps):");
    for step in 0..10 {
        let inp: Vec<f32> = (0..input_dim)
            .map(|i| ((step * input_dim + i) * 7 + 13) as f32 % 100.0 * 0.01 - 0.5)
            .collect();
        bb.forward_step(&inp, &mut output, &mut inf_state, &mut scratch);

        let out_norm: f32 = output.iter().map(|v| v * v).sum::<f32>().sqrt();
        println!("  step {step}: output L2 norm = {out_norm:.6}");
    }

    // Cleanup
    std::fs::remove_file(&save_path).ok();
    println!("\nDone.");
}

fn sgd_update_all(tw: &mut TrainMambaWeights, grads: &TrainMambaWeights, lr: f32) {
    sgd_update(&mut tw.input_proj_w, &grads.input_proj_w, lr);
    sgd_update(&mut tw.input_proj_b, &grads.input_proj_b, lr);
    sgd_update(&mut tw.norm_f_weight, &grads.norm_f_weight, lr);
    for (lw, gw) in tw.layers.iter_mut().zip(grads.layers.iter()) {
        sgd_update(&mut lw.norm_weight, &gw.norm_weight, lr);
        sgd_update(&mut lw.in_proj_w, &gw.in_proj_w, lr);
        sgd_update(&mut lw.conv1d_weight, &gw.conv1d_weight, lr);
        sgd_update(&mut lw.conv1d_bias, &gw.conv1d_bias, lr);
        sgd_update(&mut lw.x_proj_w, &gw.x_proj_w, lr);
        sgd_update(&mut lw.dt_proj_w, &gw.dt_proj_w, lr);
        sgd_update(&mut lw.dt_proj_b, &gw.dt_proj_b, lr);
        sgd_update(&mut lw.a_log, &gw.a_log, lr);
        sgd_update(&mut lw.d_param, &gw.d_param, lr);
        sgd_update(&mut lw.out_proj_w, &gw.out_proj_w, lr);
    }
}

fn sgd_update(params: &mut [f32], grads: &[f32], lr: f32) {
    for (p, g) in params.iter_mut().zip(grads.iter()) {
        *p -= lr * g;
    }
}
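
// Sketch of a momentum variant for comparison; illustrative only and not
// called above (a real run would allocate one `velocity` buffer per parameter
// tensor and carry it across steps).
#[allow(dead_code)]
fn sgd_update_momentum(
    params: &mut [f32],
    grads: &[f32],
    velocity: &mut [f32],
    lr: f32,
    beta: f32,
) {
    for ((p, g), v) in params.iter_mut().zip(grads).zip(velocity.iter_mut()) {
        // Exponentially-weighted running sum of past gradients...
        *v = beta * *v + *g;
        // ...then a plain SGD step along the smoothed direction.
        *p -= lr * *v;
    }
}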

fn train_weights_from_inference(w: &MambaWeights) -> TrainMambaWeights {
    TrainMambaWeights {
        input_proj_w: w.input_proj_w.clone(),
        input_proj_b: w.input_proj_b.clone(),
        layers: w
            .layers
            .iter()
            .map(|lw| TrainMambaLayerWeights {
                norm_weight: lw.norm_weight.clone(),
                in_proj_w: lw.in_proj_w.clone(),
                conv1d_weight: lw.conv1d_weight.clone(),
                conv1d_bias: lw.conv1d_bias.clone(),
                x_proj_w: lw.x_proj_w.clone(),
                dt_proj_w: lw.dt_proj_w.clone(),
                dt_proj_b: lw.dt_proj_b.clone(),
                a_log: lw.a_log.clone(),
                d_param: lw.d_param.clone(),
                out_proj_w: lw.out_proj_w.clone(),
            })
            .collect(),
        norm_f_weight: w.norm_f_weight.clone(),
    }
}

fn inference_weights_from_train(tw: &TrainMambaWeights, template: &MambaWeights) -> MambaWeights {
    MambaWeights {
        input_proj_w: tw.input_proj_w.clone(),
        input_proj_b: tw.input_proj_b.clone(),
        layers: tw
            .layers
            .iter()
            .zip(template.layers.iter())
            .map(|(tlw, _)| mamba_rs::weights::MambaLayerWeights {
                norm_weight: tlw.norm_weight.clone(),
                in_proj_w: tlw.in_proj_w.clone(),
                conv1d_weight: tlw.conv1d_weight.clone(),
                conv1d_bias: tlw.conv1d_bias.clone(),
                x_proj_w: tlw.x_proj_w.clone(),
                dt_proj_w: tlw.dt_proj_w.clone(),
                dt_proj_b: tlw.dt_proj_b.clone(),
                a_log: tlw.a_log.clone(),
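                // `a_neg` is a derived cache, so recompute it from the freshly
                // trained `a_log` instead of copying the stale template value.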
                a_neg: tlw.a_log.iter().map(|v| -v.exp()).collect(),
                d_param: tlw.d_param.clone(),
                out_proj_w: tlw.out_proj_w.clone(),
            })
            .collect(),
        norm_f_weight: tw.norm_f_weight.clone(),
    }
}