use std::path::Path;

use mamba_rs::config::MambaConfig;
use mamba_rs::ops::dims::{MambaDims, MambaRecurrentState};
use mamba_rs::train::backward::backward_mamba_backbone_batched;
use mamba_rs::train::flat::MambaBackboneFlat;
use mamba_rs::train::forward::forward_mamba_backbone_batched;
use mamba_rs::train::scratch::{BackwardPhaseScratch, PhaseScratch};
use mamba_rs::train::weights::{TrainMambaLayerWeights, TrainMambaWeights};
use mamba_rs::{MambaBackbone, MambaWeights};

fn main() {
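    // Model configuration and toy training hyperparameters.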
    let cfg = MambaConfig::default();
    let input_dim = cfg.d_model;
    let seq_len = 16;
    let lr = 1e-3_f32;
    let steps = 50;

    println!("mamba-rs: train -> save -> load -> inference");
    println!("=============================================");
    println!(
        "Config: d_model={}, layers={}, d_inner={}, params={}",
        cfg.d_model,
        cfg.n_layers,
        cfg.d_inner(),
        MambaBackbone::init(cfg, input_dim, 42).param_count()
    );
    println!();

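    // Initialize inference weights, then mirror them into the trainable representation.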
    let inf_weights = MambaWeights::init(&cfg, input_dim, 42);
    let mut tw = train_weights_from_inference(&inf_weights);

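    // Cache the per-run dimensions; the buffers derived from them are reused across steps.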
    let dims = MambaDims::from_config(&cfg, seq_len, input_dim);
    let di = cfg.d_inner();
    let ds = cfg.d_state;
    let dc = cfg.d_conv;
    let dm = cfg.d_model;

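    // Activations recorded in the forward pass (consumed by backward), plus
    // forward/backward scratch; all allocated once, outside the training loop.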
    let mut acts = MambaBackboneFlat::zeros(dims);
    let mut fwd_scratch = PhaseScratch::zeros(&dims);
    let mut bwd_scratch = BackwardPhaseScratch::zeros(&dims);

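    // Deterministic pseudo-random input in [-0.5, 0.49]; no RNG dependency needed.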
    let input: Vec<f32> = (0..seq_len * input_dim)
        .map(|i| ((i * 7 + 13) % 100) as f32 * 0.01 - 0.5)
        .collect();

    println!("Training ({steps} steps, lr={lr}, seq_len={seq_len}):");

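    // Training loop: full-sequence forward/backward with plain SGD. The loss is
    // the mean squared output, so training drives the backbone output toward zero.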
    for step in 0..steps {
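        // A is parameterized as -exp(a_log), so a_neg is recomputed from the
        // current a_log at every step.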
        let mut a_neg = vec![0.0f32; cfg.n_layers * di * ds];
        for (l, lw) in tw.layers.iter().enumerate() {
            for i in 0..di * ds {
                a_neg[l * di * ds + i] = -lw.a_log[i].exp();
            }
        }

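        // Fresh recurrent state per step: conv taps and SSM state start at zero.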
        let mut conv = vec![0.0f32; cfg.n_layers * di * dc];
        let mut ssm = vec![0.0f32; cfg.n_layers * di * ds];
        let mut state = MambaRecurrentState {
            conv: &mut conv,
            ssm: &mut ssm,
            a_neg: &a_neg,
        };
        let mut temporal = vec![0.0f32; seq_len * dm];
        forward_mamba_backbone_batched(
            &mut temporal,
            &mut acts,
            &tw,
            &input,
            &mut state,
            &mut fwd_scratch,
            &dims,
        );

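        // Loss: mean of squared outputs over the whole sequence.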
        let loss: f32 = temporal.iter().map(|v| v * v).sum::<f32>() / (seq_len * dm) as f32;

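        // d(loss)/d(temporal) = 2 * temporal / (seq_len * d_model).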
        let scale = 2.0 / (seq_len * dm) as f32;
        let mut d_temporal: Vec<f32> = temporal.iter().map(|v| v * scale).collect();
        let mut grads = TrainMambaWeights::zeros_from_dims(&dims);
        backward_mamba_backbone_batched(
            &mut d_temporal,
            &mut grads,
            &acts,
            &tw,
            &a_neg,
            &mut bwd_scratch,
            &dims,
        );

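        // One vanilla SGD step over every trainable tensor.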
        sgd_update_all(&mut tw, &grads, lr);

        if step % 10 == 0 || step == steps - 1 {
            println!(" step {step:3}: loss = {loss:.6}");
        }
    }

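    // Convert the trained tensors back to the inference layout and save as safetensors.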
    let save_path = std::env::temp_dir().join("mamba_rs_trained.safetensors");
    let trained_inf_weights = inference_weights_from_train(&tw, &inf_weights);
    mamba_rs::serialize::save(&save_path, &trained_inf_weights, &cfg, input_dim)
        .expect("save failed");
    println!("\nSaved to: {}", save_path.display());

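    // Load the checkpoint back and rebuild a backbone for step-by-step inference.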
    let (loaded_w, loaded_cfg, loaded_input_dim) =
        mamba_rs::serialize::load(Path::new(&save_path)).expect("load failed");
    let bb = MambaBackbone::from_weights(loaded_cfg, loaded_w).expect("from_weights failed");
    assert_eq!(bb.input_dim(), loaded_input_dim);
    println!(
        "Loaded: d_model={}, layers={}, params={}",
        loaded_cfg.d_model,
        loaded_cfg.n_layers,
        bb.param_count()
    );

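    // Streaming inference: one timestep at a time with preallocated state and scratch.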
    let mut inf_state = bb.alloc_state();
    let mut scratch = bb.alloc_scratch();
    let mut output = vec![0.0f32; dm];

    println!("\nInference (10 steps):");
    for step in 0..10 {
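        // Same deterministic input pattern as training, indexed by absolute position.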
        let inp: Vec<f32> = (0..input_dim)
            .map(|i| (((step * input_dim + i) * 7 + 13) % 100) as f32 * 0.01 - 0.5)
            .collect();
        bb.forward_step(&inp, &mut output, &mut inf_state, &mut scratch);

        let out_norm: f32 = output.iter().map(|v| v * v).sum::<f32>().sqrt();
        println!(" step {step}: output L2 norm = {out_norm:.6}");
    }

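    // Best-effort cleanup of the temporary checkpoint file.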
    std::fs::remove_file(&save_path).ok();
    println!("\nDone.");
}

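/// One vanilla SGD step over every trainable tensor (no momentum, no weight decay).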
fn sgd_update_all(tw: &mut TrainMambaWeights, grads: &TrainMambaWeights, lr: f32) {
    sgd_update(&mut tw.input_proj_w, &grads.input_proj_w, lr);
    sgd_update(&mut tw.input_proj_b, &grads.input_proj_b, lr);
    sgd_update(&mut tw.norm_f_weight, &grads.norm_f_weight, lr);
    for (lw, gw) in tw.layers.iter_mut().zip(grads.layers.iter()) {
        sgd_update(&mut lw.norm_weight, &gw.norm_weight, lr);
        sgd_update(&mut lw.in_proj_w, &gw.in_proj_w, lr);
        sgd_update(&mut lw.conv1d_weight, &gw.conv1d_weight, lr);
        sgd_update(&mut lw.conv1d_bias, &gw.conv1d_bias, lr);
        sgd_update(&mut lw.x_proj_w, &gw.x_proj_w, lr);
        sgd_update(&mut lw.dt_proj_w, &gw.dt_proj_w, lr);
        sgd_update(&mut lw.dt_proj_b, &gw.dt_proj_b, lr);
        sgd_update(&mut lw.a_log, &gw.a_log, lr);
        sgd_update(&mut lw.d_param, &gw.d_param, lr);
        sgd_update(&mut lw.out_proj_w, &gw.out_proj_w, lr);
    }
}

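/// Elementwise in-place update: `p -= lr * g` over one parameter tensor.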
fn sgd_update(params: &mut [f32], grads: &[f32], lr: f32) {
    for (p, g) in params.iter_mut().zip(grads.iter()) {
        *p -= lr * g;
    }
}

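/// Clones inference weights into the trainable layout. `a_neg` is derived from
/// `a_log`, so only `a_log` is carried over; `a_neg` is recomputed when needed.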
fn train_weights_from_inference(w: &MambaWeights) -> TrainMambaWeights {
    TrainMambaWeights {
        input_proj_w: w.input_proj_w.clone(),
        input_proj_b: w.input_proj_b.clone(),
        layers: w
            .layers
            .iter()
            .map(|lw| TrainMambaLayerWeights {
                norm_weight: lw.norm_weight.clone(),
                in_proj_w: lw.in_proj_w.clone(),
                conv1d_weight: lw.conv1d_weight.clone(),
                conv1d_bias: lw.conv1d_bias.clone(),
                x_proj_w: lw.x_proj_w.clone(),
                dt_proj_w: lw.dt_proj_w.clone(),
                dt_proj_b: lw.dt_proj_b.clone(),
                a_log: lw.a_log.clone(),
                d_param: lw.d_param.clone(),
                out_proj_w: lw.out_proj_w.clone(),
            })
            .collect(),
        norm_f_weight: w.norm_f_weight.clone(),
    }
}

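/// Rebuilds inference weights from the trained tensors; `a_neg` is rederived
/// from the updated `a_log`. The template is only used to pair layers one-to-one.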
fn inference_weights_from_train(tw: &TrainMambaWeights, template: &MambaWeights) -> MambaWeights {
    MambaWeights {
        input_proj_w: tw.input_proj_w.clone(),
        input_proj_b: tw.input_proj_b.clone(),
        layers: tw
            .layers
            .iter()
            .zip(template.layers.iter())
            .map(|(tlw, _)| mamba_rs::weights::MambaLayerWeights {
                norm_weight: tlw.norm_weight.clone(),
                in_proj_w: tlw.in_proj_w.clone(),
                conv1d_weight: tlw.conv1d_weight.clone(),
                conv1d_bias: tlw.conv1d_bias.clone(),
                x_proj_w: tlw.x_proj_w.clone(),
                dt_proj_w: tlw.dt_proj_w.clone(),
                dt_proj_b: tlw.dt_proj_b.clone(),
                a_log: tlw.a_log.clone(),
                a_neg: tlw.a_log.iter().map(|v| -v.exp()).collect(),
                d_param: tlw.d_param.clone(),
                out_proj_w: tlw.out_proj_w.clone(),
            })
            .collect(),
        norm_f_weight: tw.norm_f_weight.clone(),
    }
}