//! `inference/inference.rs`: example of step-by-step inference with a Mamba backbone.

use mamba_rs::{MambaBackbone, MambaConfig};

3fn main() {
4    let cfg = MambaConfig::default();
5    let input_dim = cfg.d_model;
6
7    // Initialize backbone with paper-default weights
8    let backbone = MambaBackbone::init(cfg, input_dim, 42);
9    println!(
10        "Mamba: {} layers, d_model={}, d_inner={}, {} params",
11        backbone.n_layers(),
12        backbone.config().d_model,
13        backbone.config().d_inner(),
14        backbone.param_count(),
15    );
16
17    // Allocate recurrent state + scratch (once, reuse across steps)
18    let mut state = backbone.alloc_state();
19    let mut scratch = backbone.alloc_scratch();
20    let mut output = vec![0.0f32; backbone.config().d_model];
21
22    // Run 10 inference steps
23    for step in 0..10 {
24        let input = vec![0.1 * step as f32; input_dim];
25        backbone.forward_step(&input, &mut output, &mut state, &mut scratch);
26
27        let norm: f32 = output.iter().map(|x| x * x).sum::<f32>().sqrt();
28        println!("step {step}: output L2 norm = {norm:.6}");
29    }
30
31    // Reset state for new sequence
32    state.reset();
33    println!("state reset");
34}