Skip to main content

MambaBackbone

Struct MambaBackbone 

Source
pub struct MambaBackbone { /* private fields */ }
Expand description

Complete Mamba backbone: input_proj -> N layers -> norm_f.

Owns all weights. Provides both single-step recurrent inference and access to raw weights for training integration.

use mamba_rs::module::MambaBackbone;
use mamba_rs::MambaConfig;

let cfg = MambaConfig::default();
let backbone = MambaBackbone::init(cfg, 128, 42);

let mut state = backbone.alloc_state();
let mut scratch = backbone.alloc_scratch();
let mut output = vec![0.0f32; backbone.config().d_model];

let input = vec![0.1f32; 128];
backbone.forward_step(&input, &mut output, &mut state, &mut scratch);

Implementations

Source

impl MambaBackbone

Source

pub fn init(cfg: MambaConfig, input_dim: usize, seed: u64) -> Self

Create a backbone with Mamba-specific weight initialization.

Uses Kaiming uniform for projections, log-space init for A, inverse-softplus init for dt_proj bias (Gu & Dao, Section 3.5).

Examples found in repository:
examples/inference.rs (line 8)
3fn main() {
4    let cfg = MambaConfig::default();
5    let input_dim = cfg.d_model;
6
7    // Initialize backbone with paper-default weights
8    let backbone = MambaBackbone::init(cfg, input_dim, 42);
9    println!(
10        "Mamba: {} layers, d_model={}, d_inner={}, {} params",
11        backbone.n_layers(),
12        backbone.config().d_model,
13        backbone.config().d_inner(),
14        backbone.param_count(),
15    );
16
17    // Allocate recurrent state + scratch (once, reuse across steps)
18    let mut state = backbone.alloc_state();
19    let mut scratch = backbone.alloc_scratch();
20    let mut output = vec![0.0f32; backbone.config().d_model];
21
22    // Run 10 inference steps
23    for step in 0..10 {
24        let input = vec![0.1 * step as f32; input_dim];
25        backbone.forward_step(&input, &mut output, &mut state, &mut scratch);
26
27        let norm: f32 = output.iter().map(|x| x * x).sum::<f32>().sqrt();
28        println!("step {step}: output L2 norm = {norm:.6}");
29    }
30
31    // Reset state for new sequence
32    state.reset();
33    println!("state reset");
34}
Additional example from the repository:
examples/train_and_infer.rs (line 32)
18fn main() {
19    let cfg = MambaConfig::default(); // d_model=128, 3 layers, 366K params
20    let input_dim = cfg.d_model;
21    let seq_len = 16;
22    let lr = 1e-3_f32;
23    let steps = 50;
24
25    println!("mamba-rs: train -> save -> load -> inference");
26    println!("=============================================");
27    println!(
28        "Config: d_model={}, layers={}, d_inner={}, params={}",
29        cfg.d_model,
30        cfg.n_layers,
31        cfg.d_inner(),
32        MambaBackbone::init(cfg, input_dim, 42).param_count()
33    );
34    println!();
35
36    // =====================================================================
37    // Step 1: Initialize weights
38    // =====================================================================
39    let inf_weights = MambaWeights::init(&cfg, input_dim, 42);
40    let mut tw = train_weights_from_inference(&inf_weights);
41
42    let dims = MambaDims::from_config(&cfg, seq_len, input_dim);
43    let di = cfg.d_inner();
44    let ds = cfg.d_state;
45    let dc = cfg.d_conv;
46    let dm = cfg.d_model;
47
48    // Pre-allocate scratch (reused every step)
49    let mut acts = MambaBackboneFlat::zeros(dims);
50    let mut fwd_scratch = PhaseScratch::zeros(&dims);
51    let mut bwd_scratch = BackwardPhaseScratch::zeros(&dims);
52
53    // Synthetic training data: random input, target = zeros (regression)
54    let input: Vec<f32> = (0..seq_len * input_dim)
55        .map(|i| ((i * 7 + 13) % 100) as f32 * 0.01 - 0.5)
56        .collect();
57
58    // =====================================================================
59    // Step 2: Training loop (SGD, MSE loss)
60    // =====================================================================
61    println!("Training ({steps} steps, lr={lr}, seq_len={seq_len}):");
62
63    for step in 0..steps {
64        // Precompute a_neg = -exp(a_log)
65        let mut a_neg = vec![0.0f32; cfg.n_layers * di * ds];
66        for (l, lw) in tw.layers.iter().enumerate() {
67            for i in 0..di * ds {
68                a_neg[l * di * ds + i] = -lw.a_log[i].exp();
69            }
70        }
71
72        // Forward pass
73        let mut conv = vec![0.0f32; cfg.n_layers * di * dc];
74        let mut ssm = vec![0.0f32; cfg.n_layers * di * ds];
75        let mut state = MambaRecurrentState {
76            conv: &mut conv,
77            ssm: &mut ssm,
78            a_neg: &a_neg,
79        };
80        let mut temporal = vec![0.0f32; seq_len * dm];
81        forward_mamba_backbone_batched(
82            &mut temporal,
83            &mut acts,
84            &tw,
85            &input,
86            &mut state,
87            &mut fwd_scratch,
88            &dims,
89        );
90
91        // MSE loss
92        let loss: f32 = temporal.iter().map(|v| v * v).sum::<f32>() / (seq_len * dm) as f32;
93
94        // Backward pass
95        let scale = 2.0 / (seq_len * dm) as f32;
96        let mut d_temporal: Vec<f32> = temporal.iter().map(|v| v * scale).collect();
97        let mut grads = TrainMambaWeights::zeros_from_dims(&dims);
98        backward_mamba_backbone_batched(
99            &mut d_temporal,
100            &mut grads,
101            &acts,
102            &tw,
103            &a_neg,
104            &mut bwd_scratch,
105            &dims,
106        );
107
108        // SGD weight update
109        sgd_update_all(&mut tw, &grads, lr);
110
111        if step % 10 == 0 || step == steps - 1 {
112            println!("  step {step:3}: loss = {loss:.6}");
113        }
114    }
115
116    // =====================================================================
117    // Step 3: Save trained weights
118    // =====================================================================
119    let save_path = std::env::temp_dir().join("mamba_rs_trained.safetensors");
120    let trained_inf_weights = inference_weights_from_train(&tw, &inf_weights);
121    mamba_rs::serialize::save(&save_path, &trained_inf_weights, &cfg, input_dim)
122        .expect("save failed");
123    println!("\nSaved to: {}", save_path.display());
124
125    // =====================================================================
126    // Step 4: Load and run inference
127    // =====================================================================
128    let (loaded_w, loaded_cfg, loaded_input_dim) =
129        mamba_rs::serialize::load(Path::new(&save_path)).expect("load failed");
130    let bb = MambaBackbone::from_weights(loaded_cfg, loaded_w).expect("from_weights failed");
131    assert_eq!(bb.input_dim(), loaded_input_dim);
132    println!(
133        "Loaded: d_model={}, layers={}, params={}",
134        loaded_cfg.d_model,
135        loaded_cfg.n_layers,
136        bb.param_count()
137    );
138
139    // Run inference step-by-step
140    let mut inf_state = bb.alloc_state();
141    let mut scratch = bb.alloc_scratch();
142    let mut output = vec![0.0f32; dm];
143
144    println!("\nInference (10 steps):");
145    for step in 0..10 {
146        let inp: Vec<f32> = (0..input_dim)
147            .map(|i| ((step * input_dim + i) * 7 + 13) as f32 % 100.0 * 0.01 - 0.5)
148            .collect();
149        bb.forward_step(&inp, &mut output, &mut inf_state, &mut scratch);
150
151        let out_norm: f32 = output.iter().map(|v| v * v).sum::<f32>().sqrt();
152        println!("  step {step}: output L2 norm = {out_norm:.6}");
153    }
154
155    // Cleanup
156    std::fs::remove_file(&save_path).ok();
157    println!("\nDone.");
158}
Source

pub fn from_weights(cfg: MambaConfig, weights: MambaWeights) -> Result<Self, String>

Create a backbone from pre-loaded weights.

Validates dimensions against config. Returns Err on mismatch.

Examples found in repository:
examples/train_and_infer.rs (line 130)
18fn main() {
19    let cfg = MambaConfig::default(); // d_model=128, 3 layers, 366K params
20    let input_dim = cfg.d_model;
21    let seq_len = 16;
22    let lr = 1e-3_f32;
23    let steps = 50;
24
25    println!("mamba-rs: train -> save -> load -> inference");
26    println!("=============================================");
27    println!(
28        "Config: d_model={}, layers={}, d_inner={}, params={}",
29        cfg.d_model,
30        cfg.n_layers,
31        cfg.d_inner(),
32        MambaBackbone::init(cfg, input_dim, 42).param_count()
33    );
34    println!();
35
36    // =====================================================================
37    // Step 1: Initialize weights
38    // =====================================================================
39    let inf_weights = MambaWeights::init(&cfg, input_dim, 42);
40    let mut tw = train_weights_from_inference(&inf_weights);
41
42    let dims = MambaDims::from_config(&cfg, seq_len, input_dim);
43    let di = cfg.d_inner();
44    let ds = cfg.d_state;
45    let dc = cfg.d_conv;
46    let dm = cfg.d_model;
47
48    // Pre-allocate scratch (reused every step)
49    let mut acts = MambaBackboneFlat::zeros(dims);
50    let mut fwd_scratch = PhaseScratch::zeros(&dims);
51    let mut bwd_scratch = BackwardPhaseScratch::zeros(&dims);
52
53    // Synthetic training data: random input, target = zeros (regression)
54    let input: Vec<f32> = (0..seq_len * input_dim)
55        .map(|i| ((i * 7 + 13) % 100) as f32 * 0.01 - 0.5)
56        .collect();
57
58    // =====================================================================
59    // Step 2: Training loop (SGD, MSE loss)
60    // =====================================================================
61    println!("Training ({steps} steps, lr={lr}, seq_len={seq_len}):");
62
63    for step in 0..steps {
64        // Precompute a_neg = -exp(a_log)
65        let mut a_neg = vec![0.0f32; cfg.n_layers * di * ds];
66        for (l, lw) in tw.layers.iter().enumerate() {
67            for i in 0..di * ds {
68                a_neg[l * di * ds + i] = -lw.a_log[i].exp();
69            }
70        }
71
72        // Forward pass
73        let mut conv = vec![0.0f32; cfg.n_layers * di * dc];
74        let mut ssm = vec![0.0f32; cfg.n_layers * di * ds];
75        let mut state = MambaRecurrentState {
76            conv: &mut conv,
77            ssm: &mut ssm,
78            a_neg: &a_neg,
79        };
80        let mut temporal = vec![0.0f32; seq_len * dm];
81        forward_mamba_backbone_batched(
82            &mut temporal,
83            &mut acts,
84            &tw,
85            &input,
86            &mut state,
87            &mut fwd_scratch,
88            &dims,
89        );
90
91        // MSE loss
92        let loss: f32 = temporal.iter().map(|v| v * v).sum::<f32>() / (seq_len * dm) as f32;
93
94        // Backward pass
95        let scale = 2.0 / (seq_len * dm) as f32;
96        let mut d_temporal: Vec<f32> = temporal.iter().map(|v| v * scale).collect();
97        let mut grads = TrainMambaWeights::zeros_from_dims(&dims);
98        backward_mamba_backbone_batched(
99            &mut d_temporal,
100            &mut grads,
101            &acts,
102            &tw,
103            &a_neg,
104            &mut bwd_scratch,
105            &dims,
106        );
107
108        // SGD weight update
109        sgd_update_all(&mut tw, &grads, lr);
110
111        if step % 10 == 0 || step == steps - 1 {
112            println!("  step {step:3}: loss = {loss:.6}");
113        }
114    }
115
116    // =====================================================================
117    // Step 3: Save trained weights
118    // =====================================================================
119    let save_path = std::env::temp_dir().join("mamba_rs_trained.safetensors");
120    let trained_inf_weights = inference_weights_from_train(&tw, &inf_weights);
121    mamba_rs::serialize::save(&save_path, &trained_inf_weights, &cfg, input_dim)
122        .expect("save failed");
123    println!("\nSaved to: {}", save_path.display());
124
125    // =====================================================================
126    // Step 4: Load and run inference
127    // =====================================================================
128    let (loaded_w, loaded_cfg, loaded_input_dim) =
129        mamba_rs::serialize::load(Path::new(&save_path)).expect("load failed");
130    let bb = MambaBackbone::from_weights(loaded_cfg, loaded_w).expect("from_weights failed");
131    assert_eq!(bb.input_dim(), loaded_input_dim);
132    println!(
133        "Loaded: d_model={}, layers={}, params={}",
134        loaded_cfg.d_model,
135        loaded_cfg.n_layers,
136        bb.param_count()
137    );
138
139    // Run inference step-by-step
140    let mut inf_state = bb.alloc_state();
141    let mut scratch = bb.alloc_scratch();
142    let mut output = vec![0.0f32; dm];
143
144    println!("\nInference (10 steps):");
145    for step in 0..10 {
146        let inp: Vec<f32> = (0..input_dim)
147            .map(|i| ((step * input_dim + i) * 7 + 13) as f32 % 100.0 * 0.01 - 0.5)
148            .collect();
149        bb.forward_step(&inp, &mut output, &mut inf_state, &mut scratch);
150
151        let out_norm: f32 = output.iter().map(|v| v * v).sum::<f32>().sqrt();
152        println!("  step {step}: output L2 norm = {out_norm:.6}");
153    }
154
155    // Cleanup
156    std::fs::remove_file(&save_path).ok();
157    println!("\nDone.");
158}
Source

pub fn into_weights(self) -> MambaWeights

Extract owned weights (consuming self).

Source

pub fn weights(&self) -> &MambaWeights

Read-only weight access.

Source

pub fn weights_mut(&mut self) -> &mut MambaWeights

Mutable weight access (for optimizer updates).

Source

pub fn layer(&self, index: usize) -> &MambaLayerWeights

Read-only access to a specific layer’s weights.

Source

pub fn layer_mut(&mut self, index: usize) -> &mut MambaLayerWeights

Mutable access to a specific layer’s weights.

Source

pub fn n_layers(&self) -> usize

Number of layers.

Examples found in repository:
examples/inference.rs (line 11)
3fn main() {
4    let cfg = MambaConfig::default();
5    let input_dim = cfg.d_model;
6
7    // Initialize backbone with paper-default weights
8    let backbone = MambaBackbone::init(cfg, input_dim, 42);
9    println!(
10        "Mamba: {} layers, d_model={}, d_inner={}, {} params",
11        backbone.n_layers(),
12        backbone.config().d_model,
13        backbone.config().d_inner(),
14        backbone.param_count(),
15    );
16
17    // Allocate recurrent state + scratch (once, reuse across steps)
18    let mut state = backbone.alloc_state();
19    let mut scratch = backbone.alloc_scratch();
20    let mut output = vec![0.0f32; backbone.config().d_model];
21
22    // Run 10 inference steps
23    for step in 0..10 {
24        let input = vec![0.1 * step as f32; input_dim];
25        backbone.forward_step(&input, &mut output, &mut state, &mut scratch);
26
27        let norm: f32 = output.iter().map(|x| x * x).sum::<f32>().sqrt();
28        println!("step {step}: output L2 norm = {norm:.6}");
29    }
30
31    // Reset state for new sequence
32    state.reset();
33    println!("state reset");
34}
Source

pub fn param_count(&self) -> usize

Total parameter count.

Examples found in repository:
examples/inference.rs (line 14)
3fn main() {
4    let cfg = MambaConfig::default();
5    let input_dim = cfg.d_model;
6
7    // Initialize backbone with paper-default weights
8    let backbone = MambaBackbone::init(cfg, input_dim, 42);
9    println!(
10        "Mamba: {} layers, d_model={}, d_inner={}, {} params",
11        backbone.n_layers(),
12        backbone.config().d_model,
13        backbone.config().d_inner(),
14        backbone.param_count(),
15    );
16
17    // Allocate recurrent state + scratch (once, reuse across steps)
18    let mut state = backbone.alloc_state();
19    let mut scratch = backbone.alloc_scratch();
20    let mut output = vec![0.0f32; backbone.config().d_model];
21
22    // Run 10 inference steps
23    for step in 0..10 {
24        let input = vec![0.1 * step as f32; input_dim];
25        backbone.forward_step(&input, &mut output, &mut state, &mut scratch);
26
27        let norm: f32 = output.iter().map(|x| x * x).sum::<f32>().sqrt();
28        println!("step {step}: output L2 norm = {norm:.6}");
29    }
30
31    // Reset state for new sequence
32    state.reset();
33    println!("state reset");
34}
Additional example from the repository:
examples/train_and_infer.rs (line 32)
18fn main() {
19    let cfg = MambaConfig::default(); // d_model=128, 3 layers, 366K params
20    let input_dim = cfg.d_model;
21    let seq_len = 16;
22    let lr = 1e-3_f32;
23    let steps = 50;
24
25    println!("mamba-rs: train -> save -> load -> inference");
26    println!("=============================================");
27    println!(
28        "Config: d_model={}, layers={}, d_inner={}, params={}",
29        cfg.d_model,
30        cfg.n_layers,
31        cfg.d_inner(),
32        MambaBackbone::init(cfg, input_dim, 42).param_count()
33    );
34    println!();
35
36    // =====================================================================
37    // Step 1: Initialize weights
38    // =====================================================================
39    let inf_weights = MambaWeights::init(&cfg, input_dim, 42);
40    let mut tw = train_weights_from_inference(&inf_weights);
41
42    let dims = MambaDims::from_config(&cfg, seq_len, input_dim);
43    let di = cfg.d_inner();
44    let ds = cfg.d_state;
45    let dc = cfg.d_conv;
46    let dm = cfg.d_model;
47
48    // Pre-allocate scratch (reused every step)
49    let mut acts = MambaBackboneFlat::zeros(dims);
50    let mut fwd_scratch = PhaseScratch::zeros(&dims);
51    let mut bwd_scratch = BackwardPhaseScratch::zeros(&dims);
52
53    // Synthetic training data: random input, target = zeros (regression)
54    let input: Vec<f32> = (0..seq_len * input_dim)
55        .map(|i| ((i * 7 + 13) % 100) as f32 * 0.01 - 0.5)
56        .collect();
57
58    // =====================================================================
59    // Step 2: Training loop (SGD, MSE loss)
60    // =====================================================================
61    println!("Training ({steps} steps, lr={lr}, seq_len={seq_len}):");
62
63    for step in 0..steps {
64        // Precompute a_neg = -exp(a_log)
65        let mut a_neg = vec![0.0f32; cfg.n_layers * di * ds];
66        for (l, lw) in tw.layers.iter().enumerate() {
67            for i in 0..di * ds {
68                a_neg[l * di * ds + i] = -lw.a_log[i].exp();
69            }
70        }
71
72        // Forward pass
73        let mut conv = vec![0.0f32; cfg.n_layers * di * dc];
74        let mut ssm = vec![0.0f32; cfg.n_layers * di * ds];
75        let mut state = MambaRecurrentState {
76            conv: &mut conv,
77            ssm: &mut ssm,
78            a_neg: &a_neg,
79        };
80        let mut temporal = vec![0.0f32; seq_len * dm];
81        forward_mamba_backbone_batched(
82            &mut temporal,
83            &mut acts,
84            &tw,
85            &input,
86            &mut state,
87            &mut fwd_scratch,
88            &dims,
89        );
90
91        // MSE loss
92        let loss: f32 = temporal.iter().map(|v| v * v).sum::<f32>() / (seq_len * dm) as f32;
93
94        // Backward pass
95        let scale = 2.0 / (seq_len * dm) as f32;
96        let mut d_temporal: Vec<f32> = temporal.iter().map(|v| v * scale).collect();
97        let mut grads = TrainMambaWeights::zeros_from_dims(&dims);
98        backward_mamba_backbone_batched(
99            &mut d_temporal,
100            &mut grads,
101            &acts,
102            &tw,
103            &a_neg,
104            &mut bwd_scratch,
105            &dims,
106        );
107
108        // SGD weight update
109        sgd_update_all(&mut tw, &grads, lr);
110
111        if step % 10 == 0 || step == steps - 1 {
112            println!("  step {step:3}: loss = {loss:.6}");
113        }
114    }
115
116    // =====================================================================
117    // Step 3: Save trained weights
118    // =====================================================================
119    let save_path = std::env::temp_dir().join("mamba_rs_trained.safetensors");
120    let trained_inf_weights = inference_weights_from_train(&tw, &inf_weights);
121    mamba_rs::serialize::save(&save_path, &trained_inf_weights, &cfg, input_dim)
122        .expect("save failed");
123    println!("\nSaved to: {}", save_path.display());
124
125    // =====================================================================
126    // Step 4: Load and run inference
127    // =====================================================================
128    let (loaded_w, loaded_cfg, loaded_input_dim) =
129        mamba_rs::serialize::load(Path::new(&save_path)).expect("load failed");
130    let bb = MambaBackbone::from_weights(loaded_cfg, loaded_w).expect("from_weights failed");
131    assert_eq!(bb.input_dim(), loaded_input_dim);
132    println!(
133        "Loaded: d_model={}, layers={}, params={}",
134        loaded_cfg.d_model,
135        loaded_cfg.n_layers,
136        bb.param_count()
137    );
138
139    // Run inference step-by-step
140    let mut inf_state = bb.alloc_state();
141    let mut scratch = bb.alloc_scratch();
142    let mut output = vec![0.0f32; dm];
143
144    println!("\nInference (10 steps):");
145    for step in 0..10 {
146        let inp: Vec<f32> = (0..input_dim)
147            .map(|i| ((step * input_dim + i) * 7 + 13) as f32 % 100.0 * 0.01 - 0.5)
148            .collect();
149        bb.forward_step(&inp, &mut output, &mut inf_state, &mut scratch);
150
151        let out_norm: f32 = output.iter().map(|v| v * v).sum::<f32>().sqrt();
152        println!("  step {step}: output L2 norm = {out_norm:.6}");
153    }
154
155    // Cleanup
156    std::fs::remove_file(&save_path).ok();
157    println!("\nDone.");
158}
Source

pub fn config(&self) -> &MambaConfig

The config this backbone was built with.

Examples found in repository:
examples/inference.rs (line 12)
3fn main() {
4    let cfg = MambaConfig::default();
5    let input_dim = cfg.d_model;
6
7    // Initialize backbone with paper-default weights
8    let backbone = MambaBackbone::init(cfg, input_dim, 42);
9    println!(
10        "Mamba: {} layers, d_model={}, d_inner={}, {} params",
11        backbone.n_layers(),
12        backbone.config().d_model,
13        backbone.config().d_inner(),
14        backbone.param_count(),
15    );
16
17    // Allocate recurrent state + scratch (once, reuse across steps)
18    let mut state = backbone.alloc_state();
19    let mut scratch = backbone.alloc_scratch();
20    let mut output = vec![0.0f32; backbone.config().d_model];
21
22    // Run 10 inference steps
23    for step in 0..10 {
24        let input = vec![0.1 * step as f32; input_dim];
25        backbone.forward_step(&input, &mut output, &mut state, &mut scratch);
26
27        let norm: f32 = output.iter().map(|x| x * x).sum::<f32>().sqrt();
28        println!("step {step}: output L2 norm = {norm:.6}");
29    }
30
31    // Reset state for new sequence
32    state.reset();
33    println!("state reset");
34}
Source

pub fn input_dim(&self) -> usize

External input dimension.

Examples found in repository:
examples/train_and_infer.rs (line 131)
18fn main() {
19    let cfg = MambaConfig::default(); // d_model=128, 3 layers, 366K params
20    let input_dim = cfg.d_model;
21    let seq_len = 16;
22    let lr = 1e-3_f32;
23    let steps = 50;
24
25    println!("mamba-rs: train -> save -> load -> inference");
26    println!("=============================================");
27    println!(
28        "Config: d_model={}, layers={}, d_inner={}, params={}",
29        cfg.d_model,
30        cfg.n_layers,
31        cfg.d_inner(),
32        MambaBackbone::init(cfg, input_dim, 42).param_count()
33    );
34    println!();
35
36    // =====================================================================
37    // Step 1: Initialize weights
38    // =====================================================================
39    let inf_weights = MambaWeights::init(&cfg, input_dim, 42);
40    let mut tw = train_weights_from_inference(&inf_weights);
41
42    let dims = MambaDims::from_config(&cfg, seq_len, input_dim);
43    let di = cfg.d_inner();
44    let ds = cfg.d_state;
45    let dc = cfg.d_conv;
46    let dm = cfg.d_model;
47
48    // Pre-allocate scratch (reused every step)
49    let mut acts = MambaBackboneFlat::zeros(dims);
50    let mut fwd_scratch = PhaseScratch::zeros(&dims);
51    let mut bwd_scratch = BackwardPhaseScratch::zeros(&dims);
52
53    // Synthetic training data: random input, target = zeros (regression)
54    let input: Vec<f32> = (0..seq_len * input_dim)
55        .map(|i| ((i * 7 + 13) % 100) as f32 * 0.01 - 0.5)
56        .collect();
57
58    // =====================================================================
59    // Step 2: Training loop (SGD, MSE loss)
60    // =====================================================================
61    println!("Training ({steps} steps, lr={lr}, seq_len={seq_len}):");
62
63    for step in 0..steps {
64        // Precompute a_neg = -exp(a_log)
65        let mut a_neg = vec![0.0f32; cfg.n_layers * di * ds];
66        for (l, lw) in tw.layers.iter().enumerate() {
67            for i in 0..di * ds {
68                a_neg[l * di * ds + i] = -lw.a_log[i].exp();
69            }
70        }
71
72        // Forward pass
73        let mut conv = vec![0.0f32; cfg.n_layers * di * dc];
74        let mut ssm = vec![0.0f32; cfg.n_layers * di * ds];
75        let mut state = MambaRecurrentState {
76            conv: &mut conv,
77            ssm: &mut ssm,
78            a_neg: &a_neg,
79        };
80        let mut temporal = vec![0.0f32; seq_len * dm];
81        forward_mamba_backbone_batched(
82            &mut temporal,
83            &mut acts,
84            &tw,
85            &input,
86            &mut state,
87            &mut fwd_scratch,
88            &dims,
89        );
90
91        // MSE loss
92        let loss: f32 = temporal.iter().map(|v| v * v).sum::<f32>() / (seq_len * dm) as f32;
93
94        // Backward pass
95        let scale = 2.0 / (seq_len * dm) as f32;
96        let mut d_temporal: Vec<f32> = temporal.iter().map(|v| v * scale).collect();
97        let mut grads = TrainMambaWeights::zeros_from_dims(&dims);
98        backward_mamba_backbone_batched(
99            &mut d_temporal,
100            &mut grads,
101            &acts,
102            &tw,
103            &a_neg,
104            &mut bwd_scratch,
105            &dims,
106        );
107
108        // SGD weight update
109        sgd_update_all(&mut tw, &grads, lr);
110
111        if step % 10 == 0 || step == steps - 1 {
112            println!("  step {step:3}: loss = {loss:.6}");
113        }
114    }
115
116    // =====================================================================
117    // Step 3: Save trained weights
118    // =====================================================================
119    let save_path = std::env::temp_dir().join("mamba_rs_trained.safetensors");
120    let trained_inf_weights = inference_weights_from_train(&tw, &inf_weights);
121    mamba_rs::serialize::save(&save_path, &trained_inf_weights, &cfg, input_dim)
122        .expect("save failed");
123    println!("\nSaved to: {}", save_path.display());
124
125    // =====================================================================
126    // Step 4: Load and run inference
127    // =====================================================================
128    let (loaded_w, loaded_cfg, loaded_input_dim) =
129        mamba_rs::serialize::load(Path::new(&save_path)).expect("load failed");
130    let bb = MambaBackbone::from_weights(loaded_cfg, loaded_w).expect("from_weights failed");
131    assert_eq!(bb.input_dim(), loaded_input_dim);
132    println!(
133        "Loaded: d_model={}, layers={}, params={}",
134        loaded_cfg.d_model,
135        loaded_cfg.n_layers,
136        bb.param_count()
137    );
138
139    // Run inference step-by-step
140    let mut inf_state = bb.alloc_state();
141    let mut scratch = bb.alloc_scratch();
142    let mut output = vec![0.0f32; dm];
143
144    println!("\nInference (10 steps):");
145    for step in 0..10 {
146        let inp: Vec<f32> = (0..input_dim)
147            .map(|i| ((step * input_dim + i) * 7 + 13) as f32 % 100.0 * 0.01 - 0.5)
148            .collect();
149        bb.forward_step(&inp, &mut output, &mut inf_state, &mut scratch);
150
151        let out_norm: f32 = output.iter().map(|v| v * v).sum::<f32>().sqrt();
152        println!("  step {step}: output L2 norm = {out_norm:.6}");
153    }
154
155    // Cleanup
156    std::fs::remove_file(&save_path).ok();
157    println!("\nDone.");
158}
Source

pub fn forward_step(&self, input: &[f32], output: &mut [f32], state: &mut MambaState, scratch: &mut MambaStepScratch)

Single-step recurrent forward through the full backbone.

input_proj(input) -> N x layer_step -> norm_f -> output

Zero allocations per call. Delegates to mamba_step.

Examples found in repository:
examples/inference.rs (line 25)
3fn main() {
4    let cfg = MambaConfig::default();
5    let input_dim = cfg.d_model;
6
7    // Initialize backbone with paper-default weights
8    let backbone = MambaBackbone::init(cfg, input_dim, 42);
9    println!(
10        "Mamba: {} layers, d_model={}, d_inner={}, {} params",
11        backbone.n_layers(),
12        backbone.config().d_model,
13        backbone.config().d_inner(),
14        backbone.param_count(),
15    );
16
17    // Allocate recurrent state + scratch (once, reuse across steps)
18    let mut state = backbone.alloc_state();
19    let mut scratch = backbone.alloc_scratch();
20    let mut output = vec![0.0f32; backbone.config().d_model];
21
22    // Run 10 inference steps
23    for step in 0..10 {
24        let input = vec![0.1 * step as f32; input_dim];
25        backbone.forward_step(&input, &mut output, &mut state, &mut scratch);
26
27        let norm: f32 = output.iter().map(|x| x * x).sum::<f32>().sqrt();
28        println!("step {step}: output L2 norm = {norm:.6}");
29    }
30
31    // Reset state for new sequence
32    state.reset();
33    println!("state reset");
34}
Additional example from the repository:
examples/train_and_infer.rs (line 149)
18fn main() {
19    let cfg = MambaConfig::default(); // d_model=128, 3 layers, 366K params
20    let input_dim = cfg.d_model;
21    let seq_len = 16;
22    let lr = 1e-3_f32;
23    let steps = 50;
24
25    println!("mamba-rs: train -> save -> load -> inference");
26    println!("=============================================");
27    println!(
28        "Config: d_model={}, layers={}, d_inner={}, params={}",
29        cfg.d_model,
30        cfg.n_layers,
31        cfg.d_inner(),
32        MambaBackbone::init(cfg, input_dim, 42).param_count()
33    );
34    println!();
35
36    // =====================================================================
37    // Step 1: Initialize weights
38    // =====================================================================
39    let inf_weights = MambaWeights::init(&cfg, input_dim, 42);
40    let mut tw = train_weights_from_inference(&inf_weights);
41
42    let dims = MambaDims::from_config(&cfg, seq_len, input_dim);
43    let di = cfg.d_inner();
44    let ds = cfg.d_state;
45    let dc = cfg.d_conv;
46    let dm = cfg.d_model;
47
48    // Pre-allocate scratch (reused every step)
49    let mut acts = MambaBackboneFlat::zeros(dims);
50    let mut fwd_scratch = PhaseScratch::zeros(&dims);
51    let mut bwd_scratch = BackwardPhaseScratch::zeros(&dims);
52
53    // Synthetic training data: random input, target = zeros (regression)
54    let input: Vec<f32> = (0..seq_len * input_dim)
55        .map(|i| ((i * 7 + 13) % 100) as f32 * 0.01 - 0.5)
56        .collect();
57
58    // =====================================================================
59    // Step 2: Training loop (SGD, MSE loss)
60    // =====================================================================
61    println!("Training ({steps} steps, lr={lr}, seq_len={seq_len}):");
62
63    for step in 0..steps {
64        // Precompute a_neg = -exp(a_log)
65        let mut a_neg = vec![0.0f32; cfg.n_layers * di * ds];
66        for (l, lw) in tw.layers.iter().enumerate() {
67            for i in 0..di * ds {
68                a_neg[l * di * ds + i] = -lw.a_log[i].exp();
69            }
70        }
71
72        // Forward pass
73        let mut conv = vec![0.0f32; cfg.n_layers * di * dc];
74        let mut ssm = vec![0.0f32; cfg.n_layers * di * ds];
75        let mut state = MambaRecurrentState {
76            conv: &mut conv,
77            ssm: &mut ssm,
78            a_neg: &a_neg,
79        };
80        let mut temporal = vec![0.0f32; seq_len * dm];
81        forward_mamba_backbone_batched(
82            &mut temporal,
83            &mut acts,
84            &tw,
85            &input,
86            &mut state,
87            &mut fwd_scratch,
88            &dims,
89        );
90
91        // MSE loss
92        let loss: f32 = temporal.iter().map(|v| v * v).sum::<f32>() / (seq_len * dm) as f32;
93
94        // Backward pass
95        let scale = 2.0 / (seq_len * dm) as f32;
96        let mut d_temporal: Vec<f32> = temporal.iter().map(|v| v * scale).collect();
97        let mut grads = TrainMambaWeights::zeros_from_dims(&dims);
98        backward_mamba_backbone_batched(
99            &mut d_temporal,
100            &mut grads,
101            &acts,
102            &tw,
103            &a_neg,
104            &mut bwd_scratch,
105            &dims,
106        );
107
108        // SGD weight update
109        sgd_update_all(&mut tw, &grads, lr);
110
111        if step % 10 == 0 || step == steps - 1 {
112            println!("  step {step:3}: loss = {loss:.6}");
113        }
114    }
115
116    // =====================================================================
117    // Step 3: Save trained weights
118    // =====================================================================
119    let save_path = std::env::temp_dir().join("mamba_rs_trained.safetensors");
120    let trained_inf_weights = inference_weights_from_train(&tw, &inf_weights);
121    mamba_rs::serialize::save(&save_path, &trained_inf_weights, &cfg, input_dim)
122        .expect("save failed");
123    println!("\nSaved to: {}", save_path.display());
124
125    // =====================================================================
126    // Step 4: Load and run inference
127    // =====================================================================
128    let (loaded_w, loaded_cfg, loaded_input_dim) =
129        mamba_rs::serialize::load(Path::new(&save_path)).expect("load failed");
130    let bb = MambaBackbone::from_weights(loaded_cfg, loaded_w).expect("from_weights failed");
131    assert_eq!(bb.input_dim(), loaded_input_dim);
132    println!(
133        "Loaded: d_model={}, layers={}, params={}",
134        loaded_cfg.d_model,
135        loaded_cfg.n_layers,
136        bb.param_count()
137    );
138
139    // Run inference step-by-step
140    let mut inf_state = bb.alloc_state();
141    let mut scratch = bb.alloc_scratch();
142    let mut output = vec![0.0f32; dm];
143
144    println!("\nInference (10 steps):");
145    for step in 0..10 {
146        let inp: Vec<f32> = (0..input_dim)
147            .map(|i| ((step * input_dim + i) * 7 + 13) as f32 % 100.0 * 0.01 - 0.5)
148            .collect();
149        bb.forward_step(&inp, &mut output, &mut inf_state, &mut scratch);
150
151        let out_norm: f32 = output.iter().map(|v| v * v).sum::<f32>().sqrt();
152        println!("  step {step}: output L2 norm = {out_norm:.6}");
153    }
154
155    // Cleanup
156    std::fs::remove_file(&save_path).ok();
157    println!("\nDone.");
158}
Source

pub fn forward_sequence(&self, inputs: &[f32], outputs: &mut [f32], state: &mut MambaState, scratch: &mut MambaStepScratch, seq_len: usize)

Run T inference steps sequentially, collecting all outputs.

inputs: [T * input_dim] — T sequential inputs. outputs: [T * d_model] — T sequential outputs (written in-place). State carries across all T steps (warm-up, offline eval, etc.).

Source

pub fn forward_step_batch(&self, inputs: &[f32], outputs: &mut [f32], states: &mut [MambaState], scratches: &mut [MambaStepScratch])

Batched single-step forward through the backbone.

Processes B independent samples with the same weights. inputs: [B * input_dim], outputs: [B * d_model].

Source

pub fn alloc_state(&self) -> MambaState

Allocate zeroed recurrent state matching this backbone.

Examples found in repository?
examples/inference.rs (line 18)
3fn main() {
4    let cfg = MambaConfig::default();
5    let input_dim = cfg.d_model;
6
7    // Initialize backbone with paper-default weights
8    let backbone = MambaBackbone::init(cfg, input_dim, 42);
9    println!(
10        "Mamba: {} layers, d_model={}, d_inner={}, {} params",
11        backbone.n_layers(),
12        backbone.config().d_model,
13        backbone.config().d_inner(),
14        backbone.param_count(),
15    );
16
17    // Allocate recurrent state + scratch (once, reuse across steps)
18    let mut state = backbone.alloc_state();
19    let mut scratch = backbone.alloc_scratch();
20    let mut output = vec![0.0f32; backbone.config().d_model];
21
22    // Run 10 inference steps
23    for step in 0..10 {
24        let input = vec![0.1 * step as f32; input_dim];
25        backbone.forward_step(&input, &mut output, &mut state, &mut scratch);
26
27        let norm: f32 = output.iter().map(|x| x * x).sum::<f32>().sqrt();
28        println!("step {step}: output L2 norm = {norm:.6}");
29    }
30
31    // Reset state for new sequence
32    state.reset();
33    println!("state reset");
34}
More examples
Hide additional examples
examples/train_and_infer.rs (line 140)
18fn main() {
19    let cfg = MambaConfig::default(); // d_model=128, 3 layers, 366K params
20    let input_dim = cfg.d_model;
21    let seq_len = 16;
22    let lr = 1e-3_f32;
23    let steps = 50;
24
25    println!("mamba-rs: train -> save -> load -> inference");
26    println!("=============================================");
27    println!(
28        "Config: d_model={}, layers={}, d_inner={}, params={}",
29        cfg.d_model,
30        cfg.n_layers,
31        cfg.d_inner(),
32        MambaBackbone::init(cfg, input_dim, 42).param_count()
33    );
34    println!();
35
36    // =====================================================================
37    // Step 1: Initialize weights
38    // =====================================================================
39    let inf_weights = MambaWeights::init(&cfg, input_dim, 42);
40    let mut tw = train_weights_from_inference(&inf_weights);
41
42    let dims = MambaDims::from_config(&cfg, seq_len, input_dim);
43    let di = cfg.d_inner();
44    let ds = cfg.d_state;
45    let dc = cfg.d_conv;
46    let dm = cfg.d_model;
47
48    // Pre-allocate scratch (reused every step)
49    let mut acts = MambaBackboneFlat::zeros(dims);
50    let mut fwd_scratch = PhaseScratch::zeros(&dims);
51    let mut bwd_scratch = BackwardPhaseScratch::zeros(&dims);
52
53    // Synthetic training data: random input, target = zeros (regression)
54    let input: Vec<f32> = (0..seq_len * input_dim)
55        .map(|i| ((i * 7 + 13) % 100) as f32 * 0.01 - 0.5)
56        .collect();
57
58    // =====================================================================
59    // Step 2: Training loop (SGD, MSE loss)
60    // =====================================================================
61    println!("Training ({steps} steps, lr={lr}, seq_len={seq_len}):");
62
63    for step in 0..steps {
64        // Precompute a_neg = -exp(a_log)
65        let mut a_neg = vec![0.0f32; cfg.n_layers * di * ds];
66        for (l, lw) in tw.layers.iter().enumerate() {
67            for i in 0..di * ds {
68                a_neg[l * di * ds + i] = -lw.a_log[i].exp();
69            }
70        }
71
72        // Forward pass
73        let mut conv = vec![0.0f32; cfg.n_layers * di * dc];
74        let mut ssm = vec![0.0f32; cfg.n_layers * di * ds];
75        let mut state = MambaRecurrentState {
76            conv: &mut conv,
77            ssm: &mut ssm,
78            a_neg: &a_neg,
79        };
80        let mut temporal = vec![0.0f32; seq_len * dm];
81        forward_mamba_backbone_batched(
82            &mut temporal,
83            &mut acts,
84            &tw,
85            &input,
86            &mut state,
87            &mut fwd_scratch,
88            &dims,
89        );
90
91        // MSE loss
92        let loss: f32 = temporal.iter().map(|v| v * v).sum::<f32>() / (seq_len * dm) as f32;
93
94        // Backward pass
95        let scale = 2.0 / (seq_len * dm) as f32;
96        let mut d_temporal: Vec<f32> = temporal.iter().map(|v| v * scale).collect();
97        let mut grads = TrainMambaWeights::zeros_from_dims(&dims);
98        backward_mamba_backbone_batched(
99            &mut d_temporal,
100            &mut grads,
101            &acts,
102            &tw,
103            &a_neg,
104            &mut bwd_scratch,
105            &dims,
106        );
107
108        // SGD weight update
109        sgd_update_all(&mut tw, &grads, lr);
110
111        if step % 10 == 0 || step == steps - 1 {
112            println!("  step {step:3}: loss = {loss:.6}");
113        }
114    }
115
116    // =====================================================================
117    // Step 3: Save trained weights
118    // =====================================================================
119    let save_path = std::env::temp_dir().join("mamba_rs_trained.safetensors");
120    let trained_inf_weights = inference_weights_from_train(&tw, &inf_weights);
121    mamba_rs::serialize::save(&save_path, &trained_inf_weights, &cfg, input_dim)
122        .expect("save failed");
123    println!("\nSaved to: {}", save_path.display());
124
125    // =====================================================================
126    // Step 4: Load and run inference
127    // =====================================================================
128    let (loaded_w, loaded_cfg, loaded_input_dim) =
129        mamba_rs::serialize::load(Path::new(&save_path)).expect("load failed");
130    let bb = MambaBackbone::from_weights(loaded_cfg, loaded_w).expect("from_weights failed");
131    assert_eq!(bb.input_dim(), loaded_input_dim);
132    println!(
133        "Loaded: d_model={}, layers={}, params={}",
134        loaded_cfg.d_model,
135        loaded_cfg.n_layers,
136        bb.param_count()
137    );
138
139    // Run inference step-by-step
140    let mut inf_state = bb.alloc_state();
141    let mut scratch = bb.alloc_scratch();
142    let mut output = vec![0.0f32; dm];
143
144    println!("\nInference (10 steps):");
145    for step in 0..10 {
146        let inp: Vec<f32> = (0..input_dim)
147            .map(|i| ((step * input_dim + i) * 7 + 13) as f32 % 100.0 * 0.01 - 0.5)
148            .collect();
149        bb.forward_step(&inp, &mut output, &mut inf_state, &mut scratch);
150
151        let out_norm: f32 = output.iter().map(|v| v * v).sum::<f32>().sqrt();
152        println!("  step {step}: output L2 norm = {out_norm:.6}");
153    }
154
155    // Cleanup
156    std::fs::remove_file(&save_path).ok();
157    println!("\nDone.");
158}
Source

pub fn alloc_scratch(&self) -> MambaStepScratch

Allocate inference scratch buffers matching this backbone.

Examples found in repository?
examples/inference.rs (line 19)
3fn main() {
4    let cfg = MambaConfig::default();
5    let input_dim = cfg.d_model;
6
7    // Initialize backbone with paper-default weights
8    let backbone = MambaBackbone::init(cfg, input_dim, 42);
9    println!(
10        "Mamba: {} layers, d_model={}, d_inner={}, {} params",
11        backbone.n_layers(),
12        backbone.config().d_model,
13        backbone.config().d_inner(),
14        backbone.param_count(),
15    );
16
17    // Allocate recurrent state + scratch (once, reuse across steps)
18    let mut state = backbone.alloc_state();
19    let mut scratch = backbone.alloc_scratch();
20    let mut output = vec![0.0f32; backbone.config().d_model];
21
22    // Run 10 inference steps
23    for step in 0..10 {
24        let input = vec![0.1 * step as f32; input_dim];
25        backbone.forward_step(&input, &mut output, &mut state, &mut scratch);
26
27        let norm: f32 = output.iter().map(|x| x * x).sum::<f32>().sqrt();
28        println!("step {step}: output L2 norm = {norm:.6}");
29    }
30
31    // Reset state for new sequence
32    state.reset();
33    println!("state reset");
34}
More examples
Hide additional examples
examples/train_and_infer.rs (line 141)
18fn main() {
19    let cfg = MambaConfig::default(); // d_model=128, 3 layers, 366K params
20    let input_dim = cfg.d_model;
21    let seq_len = 16;
22    let lr = 1e-3_f32;
23    let steps = 50;
24
25    println!("mamba-rs: train -> save -> load -> inference");
26    println!("=============================================");
27    println!(
28        "Config: d_model={}, layers={}, d_inner={}, params={}",
29        cfg.d_model,
30        cfg.n_layers,
31        cfg.d_inner(),
32        MambaBackbone::init(cfg, input_dim, 42).param_count()
33    );
34    println!();
35
36    // =====================================================================
37    // Step 1: Initialize weights
38    // =====================================================================
39    let inf_weights = MambaWeights::init(&cfg, input_dim, 42);
40    let mut tw = train_weights_from_inference(&inf_weights);
41
42    let dims = MambaDims::from_config(&cfg, seq_len, input_dim);
43    let di = cfg.d_inner();
44    let ds = cfg.d_state;
45    let dc = cfg.d_conv;
46    let dm = cfg.d_model;
47
48    // Pre-allocate scratch (reused every step)
49    let mut acts = MambaBackboneFlat::zeros(dims);
50    let mut fwd_scratch = PhaseScratch::zeros(&dims);
51    let mut bwd_scratch = BackwardPhaseScratch::zeros(&dims);
52
53    // Synthetic training data: random input, target = zeros (regression)
54    let input: Vec<f32> = (0..seq_len * input_dim)
55        .map(|i| ((i * 7 + 13) % 100) as f32 * 0.01 - 0.5)
56        .collect();
57
58    // =====================================================================
59    // Step 2: Training loop (SGD, MSE loss)
60    // =====================================================================
61    println!("Training ({steps} steps, lr={lr}, seq_len={seq_len}):");
62
63    for step in 0..steps {
64        // Precompute a_neg = -exp(a_log)
65        let mut a_neg = vec![0.0f32; cfg.n_layers * di * ds];
66        for (l, lw) in tw.layers.iter().enumerate() {
67            for i in 0..di * ds {
68                a_neg[l * di * ds + i] = -lw.a_log[i].exp();
69            }
70        }
71
72        // Forward pass
73        let mut conv = vec![0.0f32; cfg.n_layers * di * dc];
74        let mut ssm = vec![0.0f32; cfg.n_layers * di * ds];
75        let mut state = MambaRecurrentState {
76            conv: &mut conv,
77            ssm: &mut ssm,
78            a_neg: &a_neg,
79        };
80        let mut temporal = vec![0.0f32; seq_len * dm];
81        forward_mamba_backbone_batched(
82            &mut temporal,
83            &mut acts,
84            &tw,
85            &input,
86            &mut state,
87            &mut fwd_scratch,
88            &dims,
89        );
90
91        // MSE loss
92        let loss: f32 = temporal.iter().map(|v| v * v).sum::<f32>() / (seq_len * dm) as f32;
93
94        // Backward pass
95        let scale = 2.0 / (seq_len * dm) as f32;
96        let mut d_temporal: Vec<f32> = temporal.iter().map(|v| v * scale).collect();
97        let mut grads = TrainMambaWeights::zeros_from_dims(&dims);
98        backward_mamba_backbone_batched(
99            &mut d_temporal,
100            &mut grads,
101            &acts,
102            &tw,
103            &a_neg,
104            &mut bwd_scratch,
105            &dims,
106        );
107
108        // SGD weight update
109        sgd_update_all(&mut tw, &grads, lr);
110
111        if step % 10 == 0 || step == steps - 1 {
112            println!("  step {step:3}: loss = {loss:.6}");
113        }
114    }
115
116    // =====================================================================
117    // Step 3: Save trained weights
118    // =====================================================================
119    let save_path = std::env::temp_dir().join("mamba_rs_trained.safetensors");
120    let trained_inf_weights = inference_weights_from_train(&tw, &inf_weights);
121    mamba_rs::serialize::save(&save_path, &trained_inf_weights, &cfg, input_dim)
122        .expect("save failed");
123    println!("\nSaved to: {}", save_path.display());
124
125    // =====================================================================
126    // Step 4: Load and run inference
127    // =====================================================================
128    let (loaded_w, loaded_cfg, loaded_input_dim) =
129        mamba_rs::serialize::load(Path::new(&save_path)).expect("load failed");
130    let bb = MambaBackbone::from_weights(loaded_cfg, loaded_w).expect("from_weights failed");
131    assert_eq!(bb.input_dim(), loaded_input_dim);
132    println!(
133        "Loaded: d_model={}, layers={}, params={}",
134        loaded_cfg.d_model,
135        loaded_cfg.n_layers,
136        bb.param_count()
137    );
138
139    // Run inference step-by-step
140    let mut inf_state = bb.alloc_state();
141    let mut scratch = bb.alloc_scratch();
142    let mut output = vec![0.0f32; dm];
143
144    println!("\nInference (10 steps):");
145    for step in 0..10 {
146        let inp: Vec<f32> = (0..input_dim)
147            .map(|i| ((step * input_dim + i) * 7 + 13) as f32 % 100.0 * 0.01 - 0.5)
148            .collect();
149        bb.forward_step(&inp, &mut output, &mut inf_state, &mut scratch);
150
151        let out_norm: f32 = output.iter().map(|v| v * v).sum::<f32>().sqrt();
152        println!("  step {step}: output L2 norm = {out_norm:.6}");
153    }
154
155    // Cleanup
156    std::fs::remove_file(&save_path).ok();
157    println!("\nDone.");
158}

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a pointer with the given initializer. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.