1use mamba_rs::{MambaBackbone, MambaConfig};
2
3fn main() {
4 let cfg = MambaConfig::default();
5 let input_dim = cfg.d_model;
6
7 let backbone = MambaBackbone::init(cfg, input_dim, 42);
9 println!(
10 "Mamba: {} layers, d_model={}, d_inner={}, {} params",
11 backbone.n_layers(),
12 backbone.config().d_model,
13 backbone.config().d_inner(),
14 backbone.param_count(),
15 );
16
17 let mut state = backbone.alloc_state();
19 let mut scratch = backbone.alloc_scratch();
20 let mut output = vec![0.0f32; backbone.config().d_model];
21
22 for step in 0..10 {
24 let input = vec![0.1 * step as f32; input_dim];
25 backbone.forward_step(&input, &mut output, &mut state, &mut scratch);
26
27 let norm: f32 = output.iter().map(|x| x * x).sum::<f32>().sqrt();
28 println!("step {step}: output L2 norm = {norm:.6}");
29 }
30
31 state.reset();
33 println!("state reset");
34}