pub struct MambaBackbone { /* private fields */ }

Expand description
Complete Mamba backbone: input_proj -> N layers -> norm_f.
Owns all weights. Provides both single-step recurrent inference and access to raw weights for training integration.
use mamba_rs::module::MambaBackbone;
use mamba_rs::MambaConfig;
let cfg = MambaConfig::default();
let backbone = MambaBackbone::init(cfg, 128, 42);
let mut state = backbone.alloc_state();
let mut scratch = backbone.alloc_scratch();
let mut output = vec![0.0f32; backbone.config().d_model];
let input = vec![0.1f32; 128];
backbone.forward_step(&input, &mut output, &mut state, &mut scratch);

Implementations§
Source§impl MambaBackbone
impl MambaBackbone
Source
pub fn init(cfg: MambaConfig, input_dim: usize, seed: u64) -> Self
pub fn init(cfg: MambaConfig, input_dim: usize, seed: u64) -> Self
Create a backbone with Mamba-specific weight initialization.
Uses Kaiming uniform for projections, log-space init for A, inverse-softplus init for dt_proj bias (Gu & Dao, Section 3.5).
Examples found in repository
3fn main() {
4 let cfg = MambaConfig::default();
5 let input_dim = cfg.d_model;
6
7 // Initialize backbone with paper-default weights
8 let backbone = MambaBackbone::init(cfg, input_dim, 42);
9 println!(
10 "Mamba: {} layers, d_model={}, d_inner={}, {} params",
11 backbone.n_layers(),
12 backbone.config().d_model,
13 backbone.config().d_inner(),
14 backbone.param_count(),
15 );
16
17 // Allocate recurrent state + scratch (once, reuse across steps)
18 let mut state = backbone.alloc_state();
19 let mut scratch = backbone.alloc_scratch();
20 let mut output = vec![0.0f32; backbone.config().d_model];
21
22 // Run 10 inference steps
23 for step in 0..10 {
24 let input = vec![0.1 * step as f32; input_dim];
25 backbone.forward_step(&input, &mut output, &mut state, &mut scratch);
26
27 let norm: f32 = output.iter().map(|x| x * x).sum::<f32>().sqrt();
28 println!("step {step}: output L2 norm = {norm:.6}");
29 }
30
31 // Reset state for new sequence
32 state.reset();
33 println!("state reset");
34}

More examples
18fn main() {
19 let cfg = MambaConfig::default(); // d_model=128, 3 layers, 366K params
20 let input_dim = cfg.d_model;
21 let seq_len = 16;
22 let lr = 1e-3_f32;
23 let steps = 50;
24
25 println!("mamba-rs: train -> save -> load -> inference");
26 println!("=============================================");
27 println!(
28 "Config: d_model={}, layers={}, d_inner={}, params={}",
29 cfg.d_model,
30 cfg.n_layers,
31 cfg.d_inner(),
32 MambaBackbone::init(cfg, input_dim, 42).param_count()
33 );
34 println!();
35
36 // =====================================================================
37 // Step 1: Initialize weights
38 // =====================================================================
39 let inf_weights = MambaWeights::init(&cfg, input_dim, 42);
40 let mut tw = train_weights_from_inference(&inf_weights);
41
42 let dims = MambaDims::from_config(&cfg, seq_len, input_dim);
43 let di = cfg.d_inner();
44 let ds = cfg.d_state;
45 let dc = cfg.d_conv;
46 let dm = cfg.d_model;
47
48 // Pre-allocate scratch (reused every step)
49 let mut acts = MambaBackboneFlat::zeros(dims);
50 let mut fwd_scratch = PhaseScratch::zeros(&dims);
51 let mut bwd_scratch = BackwardPhaseScratch::zeros(&dims);
52
53 // Synthetic training data: random input, target = zeros (regression)
54 let input: Vec<f32> = (0..seq_len * input_dim)
55 .map(|i| ((i * 7 + 13) % 100) as f32 * 0.01 - 0.5)
56 .collect();
57
58 // =====================================================================
59 // Step 2: Training loop (SGD, MSE loss)
60 // =====================================================================
61 println!("Training ({steps} steps, lr={lr}, seq_len={seq_len}):");
62
63 for step in 0..steps {
64 // Precompute a_neg = -exp(a_log)
65 let mut a_neg = vec![0.0f32; cfg.n_layers * di * ds];
66 for (l, lw) in tw.layers.iter().enumerate() {
67 for i in 0..di * ds {
68 a_neg[l * di * ds + i] = -lw.a_log[i].exp();
69 }
70 }
71
72 // Forward pass
73 let mut conv = vec![0.0f32; cfg.n_layers * di * dc];
74 let mut ssm = vec![0.0f32; cfg.n_layers * di * ds];
75 let mut state = MambaRecurrentState {
76 conv: &mut conv,
77 ssm: &mut ssm,
78 a_neg: &a_neg,
79 };
80 let mut temporal = vec![0.0f32; seq_len * dm];
81 forward_mamba_backbone_batched(
82 &mut temporal,
83 &mut acts,
84 &tw,
85 &input,
86 &mut state,
87 &mut fwd_scratch,
88 &dims,
89 );
90
91 // MSE loss
92 let loss: f32 = temporal.iter().map(|v| v * v).sum::<f32>() / (seq_len * dm) as f32;
93
94 // Backward pass
95 let scale = 2.0 / (seq_len * dm) as f32;
96 let mut d_temporal: Vec<f32> = temporal.iter().map(|v| v * scale).collect();
97 let mut grads = TrainMambaWeights::zeros_from_dims(&dims);
98 backward_mamba_backbone_batched(
99 &mut d_temporal,
100 &mut grads,
101 &acts,
102 &tw,
103 &a_neg,
104 &mut bwd_scratch,
105 &dims,
106 );
107
108 // SGD weight update
109 sgd_update_all(&mut tw, &grads, lr);
110
111 if step % 10 == 0 || step == steps - 1 {
112 println!(" step {step:3}: loss = {loss:.6}");
113 }
114 }
115
116 // =====================================================================
117 // Step 3: Save trained weights
118 // =====================================================================
119 let save_path = std::env::temp_dir().join("mamba_rs_trained.safetensors");
120 let trained_inf_weights = inference_weights_from_train(&tw, &inf_weights);
121 mamba_rs::serialize::save(&save_path, &trained_inf_weights, &cfg, input_dim)
122 .expect("save failed");
123 println!("\nSaved to: {}", save_path.display());
124
125 // =====================================================================
126 // Step 4: Load and run inference
127 // =====================================================================
128 let (loaded_w, loaded_cfg, loaded_input_dim) =
129 mamba_rs::serialize::load(Path::new(&save_path)).expect("load failed");
130 let bb = MambaBackbone::from_weights(loaded_cfg, loaded_w).expect("from_weights failed");
131 assert_eq!(bb.input_dim(), loaded_input_dim);
132 println!(
133 "Loaded: d_model={}, layers={}, params={}",
134 loaded_cfg.d_model,
135 loaded_cfg.n_layers,
136 bb.param_count()
137 );
138
139 // Run inference step-by-step
140 let mut inf_state = bb.alloc_state();
141 let mut scratch = bb.alloc_scratch();
142 let mut output = vec![0.0f32; dm];
143
144 println!("\nInference (10 steps):");
145 for step in 0..10 {
146 let inp: Vec<f32> = (0..input_dim)
147 .map(|i| ((step * input_dim + i) * 7 + 13) as f32 % 100.0 * 0.01 - 0.5)
148 .collect();
149 bb.forward_step(&inp, &mut output, &mut inf_state, &mut scratch);
150
151 let out_norm: f32 = output.iter().map(|v| v * v).sum::<f32>().sqrt();
152 println!(" step {step}: output L2 norm = {out_norm:.6}");
153 }
154
155 // Cleanup
156 std::fs::remove_file(&save_path).ok();
157 println!("\nDone.");
158}

Source
pub fn from_weights(
cfg: MambaConfig,
weights: MambaWeights,
) -> Result<Self, String>
pub fn from_weights( cfg: MambaConfig, weights: MambaWeights, ) -> Result<Self, String>
Create a backbone from pre-loaded weights.
Validates dimensions against config. Returns Err on mismatch.
Examples found in repository
18fn main() {
19 let cfg = MambaConfig::default(); // d_model=128, 3 layers, 366K params
20 let input_dim = cfg.d_model;
21 let seq_len = 16;
22 let lr = 1e-3_f32;
23 let steps = 50;
24
25 println!("mamba-rs: train -> save -> load -> inference");
26 println!("=============================================");
27 println!(
28 "Config: d_model={}, layers={}, d_inner={}, params={}",
29 cfg.d_model,
30 cfg.n_layers,
31 cfg.d_inner(),
32 MambaBackbone::init(cfg, input_dim, 42).param_count()
33 );
34 println!();
35
36 // =====================================================================
37 // Step 1: Initialize weights
38 // =====================================================================
39 let inf_weights = MambaWeights::init(&cfg, input_dim, 42);
40 let mut tw = train_weights_from_inference(&inf_weights);
41
42 let dims = MambaDims::from_config(&cfg, seq_len, input_dim);
43 let di = cfg.d_inner();
44 let ds = cfg.d_state;
45 let dc = cfg.d_conv;
46 let dm = cfg.d_model;
47
48 // Pre-allocate scratch (reused every step)
49 let mut acts = MambaBackboneFlat::zeros(dims);
50 let mut fwd_scratch = PhaseScratch::zeros(&dims);
51 let mut bwd_scratch = BackwardPhaseScratch::zeros(&dims);
52
53 // Synthetic training data: random input, target = zeros (regression)
54 let input: Vec<f32> = (0..seq_len * input_dim)
55 .map(|i| ((i * 7 + 13) % 100) as f32 * 0.01 - 0.5)
56 .collect();
57
58 // =====================================================================
59 // Step 2: Training loop (SGD, MSE loss)
60 // =====================================================================
61 println!("Training ({steps} steps, lr={lr}, seq_len={seq_len}):");
62
63 for step in 0..steps {
64 // Precompute a_neg = -exp(a_log)
65 let mut a_neg = vec![0.0f32; cfg.n_layers * di * ds];
66 for (l, lw) in tw.layers.iter().enumerate() {
67 for i in 0..di * ds {
68 a_neg[l * di * ds + i] = -lw.a_log[i].exp();
69 }
70 }
71
72 // Forward pass
73 let mut conv = vec![0.0f32; cfg.n_layers * di * dc];
74 let mut ssm = vec![0.0f32; cfg.n_layers * di * ds];
75 let mut state = MambaRecurrentState {
76 conv: &mut conv,
77 ssm: &mut ssm,
78 a_neg: &a_neg,
79 };
80 let mut temporal = vec![0.0f32; seq_len * dm];
81 forward_mamba_backbone_batched(
82 &mut temporal,
83 &mut acts,
84 &tw,
85 &input,
86 &mut state,
87 &mut fwd_scratch,
88 &dims,
89 );
90
91 // MSE loss
92 let loss: f32 = temporal.iter().map(|v| v * v).sum::<f32>() / (seq_len * dm) as f32;
93
94 // Backward pass
95 let scale = 2.0 / (seq_len * dm) as f32;
96 let mut d_temporal: Vec<f32> = temporal.iter().map(|v| v * scale).collect();
97 let mut grads = TrainMambaWeights::zeros_from_dims(&dims);
98 backward_mamba_backbone_batched(
99 &mut d_temporal,
100 &mut grads,
101 &acts,
102 &tw,
103 &a_neg,
104 &mut bwd_scratch,
105 &dims,
106 );
107
108 // SGD weight update
109 sgd_update_all(&mut tw, &grads, lr);
110
111 if step % 10 == 0 || step == steps - 1 {
112 println!(" step {step:3}: loss = {loss:.6}");
113 }
114 }
115
116 // =====================================================================
117 // Step 3: Save trained weights
118 // =====================================================================
119 let save_path = std::env::temp_dir().join("mamba_rs_trained.safetensors");
120 let trained_inf_weights = inference_weights_from_train(&tw, &inf_weights);
121 mamba_rs::serialize::save(&save_path, &trained_inf_weights, &cfg, input_dim)
122 .expect("save failed");
123 println!("\nSaved to: {}", save_path.display());
124
125 // =====================================================================
126 // Step 4: Load and run inference
127 // =====================================================================
128 let (loaded_w, loaded_cfg, loaded_input_dim) =
129 mamba_rs::serialize::load(Path::new(&save_path)).expect("load failed");
130 let bb = MambaBackbone::from_weights(loaded_cfg, loaded_w).expect("from_weights failed");
131 assert_eq!(bb.input_dim(), loaded_input_dim);
132 println!(
133 "Loaded: d_model={}, layers={}, params={}",
134 loaded_cfg.d_model,
135 loaded_cfg.n_layers,
136 bb.param_count()
137 );
138
139 // Run inference step-by-step
140 let mut inf_state = bb.alloc_state();
141 let mut scratch = bb.alloc_scratch();
142 let mut output = vec![0.0f32; dm];
143
144 println!("\nInference (10 steps):");
145 for step in 0..10 {
146 let inp: Vec<f32> = (0..input_dim)
147 .map(|i| ((step * input_dim + i) * 7 + 13) as f32 % 100.0 * 0.01 - 0.5)
148 .collect();
149 bb.forward_step(&inp, &mut output, &mut inf_state, &mut scratch);
150
151 let out_norm: f32 = output.iter().map(|v| v * v).sum::<f32>().sqrt();
152 println!(" step {step}: output L2 norm = {out_norm:.6}");
153 }
154
155 // Cleanup
156 std::fs::remove_file(&save_path).ok();
157 println!("\nDone.");
158}

Source
pub fn into_weights(self) -> MambaWeights
pub fn into_weights(self) -> MambaWeights
Extract owned weights (consuming self).
Source
pub fn weights(&self) -> &MambaWeights
pub fn weights(&self) -> &MambaWeights
Read-only weight access.
Source
pub fn weights_mut(&mut self) -> &mut MambaWeights
pub fn weights_mut(&mut self) -> &mut MambaWeights
Mutable weight access (for optimizer updates).
Source
pub fn layer(&self, index: usize) -> &MambaLayerWeights
pub fn layer(&self, index: usize) -> &MambaLayerWeights
Read-only access to a specific layer’s weights.
Source
pub fn layer_mut(&mut self, index: usize) -> &mut MambaLayerWeights
pub fn layer_mut(&mut self, index: usize) -> &mut MambaLayerWeights
Mutable access to a specific layer’s weights.
Source
pub fn n_layers(&self) -> usize
pub fn n_layers(&self) -> usize
Number of layers.
Examples found in repository
3fn main() {
4 let cfg = MambaConfig::default();
5 let input_dim = cfg.d_model;
6
7 // Initialize backbone with paper-default weights
8 let backbone = MambaBackbone::init(cfg, input_dim, 42);
9 println!(
10 "Mamba: {} layers, d_model={}, d_inner={}, {} params",
11 backbone.n_layers(),
12 backbone.config().d_model,
13 backbone.config().d_inner(),
14 backbone.param_count(),
15 );
16
17 // Allocate recurrent state + scratch (once, reuse across steps)
18 let mut state = backbone.alloc_state();
19 let mut scratch = backbone.alloc_scratch();
20 let mut output = vec![0.0f32; backbone.config().d_model];
21
22 // Run 10 inference steps
23 for step in 0..10 {
24 let input = vec![0.1 * step as f32; input_dim];
25 backbone.forward_step(&input, &mut output, &mut state, &mut scratch);
26
27 let norm: f32 = output.iter().map(|x| x * x).sum::<f32>().sqrt();
28 println!("step {step}: output L2 norm = {norm:.6}");
29 }
30
31 // Reset state for new sequence
32 state.reset();
33 println!("state reset");
34}

Source
pub fn param_count(&self) -> usize
pub fn param_count(&self) -> usize
Total parameter count.
Examples found in repository
3fn main() {
4 let cfg = MambaConfig::default();
5 let input_dim = cfg.d_model;
6
7 // Initialize backbone with paper-default weights
8 let backbone = MambaBackbone::init(cfg, input_dim, 42);
9 println!(
10 "Mamba: {} layers, d_model={}, d_inner={}, {} params",
11 backbone.n_layers(),
12 backbone.config().d_model,
13 backbone.config().d_inner(),
14 backbone.param_count(),
15 );
16
17 // Allocate recurrent state + scratch (once, reuse across steps)
18 let mut state = backbone.alloc_state();
19 let mut scratch = backbone.alloc_scratch();
20 let mut output = vec![0.0f32; backbone.config().d_model];
21
22 // Run 10 inference steps
23 for step in 0..10 {
24 let input = vec![0.1 * step as f32; input_dim];
25 backbone.forward_step(&input, &mut output, &mut state, &mut scratch);
26
27 let norm: f32 = output.iter().map(|x| x * x).sum::<f32>().sqrt();
28 println!("step {step}: output L2 norm = {norm:.6}");
29 }
30
31 // Reset state for new sequence
32 state.reset();
33 println!("state reset");
34}

More examples
18fn main() {
19 let cfg = MambaConfig::default(); // d_model=128, 3 layers, 366K params
20 let input_dim = cfg.d_model;
21 let seq_len = 16;
22 let lr = 1e-3_f32;
23 let steps = 50;
24
25 println!("mamba-rs: train -> save -> load -> inference");
26 println!("=============================================");
27 println!(
28 "Config: d_model={}, layers={}, d_inner={}, params={}",
29 cfg.d_model,
30 cfg.n_layers,
31 cfg.d_inner(),
32 MambaBackbone::init(cfg, input_dim, 42).param_count()
33 );
34 println!();
35
36 // =====================================================================
37 // Step 1: Initialize weights
38 // =====================================================================
39 let inf_weights = MambaWeights::init(&cfg, input_dim, 42);
40 let mut tw = train_weights_from_inference(&inf_weights);
41
42 let dims = MambaDims::from_config(&cfg, seq_len, input_dim);
43 let di = cfg.d_inner();
44 let ds = cfg.d_state;
45 let dc = cfg.d_conv;
46 let dm = cfg.d_model;
47
48 // Pre-allocate scratch (reused every step)
49 let mut acts = MambaBackboneFlat::zeros(dims);
50 let mut fwd_scratch = PhaseScratch::zeros(&dims);
51 let mut bwd_scratch = BackwardPhaseScratch::zeros(&dims);
52
53 // Synthetic training data: random input, target = zeros (regression)
54 let input: Vec<f32> = (0..seq_len * input_dim)
55 .map(|i| ((i * 7 + 13) % 100) as f32 * 0.01 - 0.5)
56 .collect();
57
58 // =====================================================================
59 // Step 2: Training loop (SGD, MSE loss)
60 // =====================================================================
61 println!("Training ({steps} steps, lr={lr}, seq_len={seq_len}):");
62
63 for step in 0..steps {
64 // Precompute a_neg = -exp(a_log)
65 let mut a_neg = vec![0.0f32; cfg.n_layers * di * ds];
66 for (l, lw) in tw.layers.iter().enumerate() {
67 for i in 0..di * ds {
68 a_neg[l * di * ds + i] = -lw.a_log[i].exp();
69 }
70 }
71
72 // Forward pass
73 let mut conv = vec![0.0f32; cfg.n_layers * di * dc];
74 let mut ssm = vec![0.0f32; cfg.n_layers * di * ds];
75 let mut state = MambaRecurrentState {
76 conv: &mut conv,
77 ssm: &mut ssm,
78 a_neg: &a_neg,
79 };
80 let mut temporal = vec![0.0f32; seq_len * dm];
81 forward_mamba_backbone_batched(
82 &mut temporal,
83 &mut acts,
84 &tw,
85 &input,
86 &mut state,
87 &mut fwd_scratch,
88 &dims,
89 );
90
91 // MSE loss
92 let loss: f32 = temporal.iter().map(|v| v * v).sum::<f32>() / (seq_len * dm) as f32;
93
94 // Backward pass
95 let scale = 2.0 / (seq_len * dm) as f32;
96 let mut d_temporal: Vec<f32> = temporal.iter().map(|v| v * scale).collect();
97 let mut grads = TrainMambaWeights::zeros_from_dims(&dims);
98 backward_mamba_backbone_batched(
99 &mut d_temporal,
100 &mut grads,
101 &acts,
102 &tw,
103 &a_neg,
104 &mut bwd_scratch,
105 &dims,
106 );
107
108 // SGD weight update
109 sgd_update_all(&mut tw, &grads, lr);
110
111 if step % 10 == 0 || step == steps - 1 {
112 println!(" step {step:3}: loss = {loss:.6}");
113 }
114 }
115
116 // =====================================================================
117 // Step 3: Save trained weights
118 // =====================================================================
119 let save_path = std::env::temp_dir().join("mamba_rs_trained.safetensors");
120 let trained_inf_weights = inference_weights_from_train(&tw, &inf_weights);
121 mamba_rs::serialize::save(&save_path, &trained_inf_weights, &cfg, input_dim)
122 .expect("save failed");
123 println!("\nSaved to: {}", save_path.display());
124
125 // =====================================================================
126 // Step 4: Load and run inference
127 // =====================================================================
128 let (loaded_w, loaded_cfg, loaded_input_dim) =
129 mamba_rs::serialize::load(Path::new(&save_path)).expect("load failed");
130 let bb = MambaBackbone::from_weights(loaded_cfg, loaded_w).expect("from_weights failed");
131 assert_eq!(bb.input_dim(), loaded_input_dim);
132 println!(
133 "Loaded: d_model={}, layers={}, params={}",
134 loaded_cfg.d_model,
135 loaded_cfg.n_layers,
136 bb.param_count()
137 );
138
139 // Run inference step-by-step
140 let mut inf_state = bb.alloc_state();
141 let mut scratch = bb.alloc_scratch();
142 let mut output = vec![0.0f32; dm];
143
144 println!("\nInference (10 steps):");
145 for step in 0..10 {
146 let inp: Vec<f32> = (0..input_dim)
147 .map(|i| ((step * input_dim + i) * 7 + 13) as f32 % 100.0 * 0.01 - 0.5)
148 .collect();
149 bb.forward_step(&inp, &mut output, &mut inf_state, &mut scratch);
150
151 let out_norm: f32 = output.iter().map(|v| v * v).sum::<f32>().sqrt();
152 println!(" step {step}: output L2 norm = {out_norm:.6}");
153 }
154
155 // Cleanup
156 std::fs::remove_file(&save_path).ok();
157 println!("\nDone.");
158}

Source
pub fn config(&self) -> &MambaConfig
pub fn config(&self) -> &MambaConfig
The config this backbone was built with.
Examples found in repository
3fn main() {
4 let cfg = MambaConfig::default();
5 let input_dim = cfg.d_model;
6
7 // Initialize backbone with paper-default weights
8 let backbone = MambaBackbone::init(cfg, input_dim, 42);
9 println!(
10 "Mamba: {} layers, d_model={}, d_inner={}, {} params",
11 backbone.n_layers(),
12 backbone.config().d_model,
13 backbone.config().d_inner(),
14 backbone.param_count(),
15 );
16
17 // Allocate recurrent state + scratch (once, reuse across steps)
18 let mut state = backbone.alloc_state();
19 let mut scratch = backbone.alloc_scratch();
20 let mut output = vec![0.0f32; backbone.config().d_model];
21
22 // Run 10 inference steps
23 for step in 0..10 {
24 let input = vec![0.1 * step as f32; input_dim];
25 backbone.forward_step(&input, &mut output, &mut state, &mut scratch);
26
27 let norm: f32 = output.iter().map(|x| x * x).sum::<f32>().sqrt();
28 println!("step {step}: output L2 norm = {norm:.6}");
29 }
30
31 // Reset state for new sequence
32 state.reset();
33 println!("state reset");
34}

Source
pub fn input_dim(&self) -> usize
pub fn input_dim(&self) -> usize
External input dimension.
Examples found in repository
18fn main() {
19 let cfg = MambaConfig::default(); // d_model=128, 3 layers, 366K params
20 let input_dim = cfg.d_model;
21 let seq_len = 16;
22 let lr = 1e-3_f32;
23 let steps = 50;
24
25 println!("mamba-rs: train -> save -> load -> inference");
26 println!("=============================================");
27 println!(
28 "Config: d_model={}, layers={}, d_inner={}, params={}",
29 cfg.d_model,
30 cfg.n_layers,
31 cfg.d_inner(),
32 MambaBackbone::init(cfg, input_dim, 42).param_count()
33 );
34 println!();
35
36 // =====================================================================
37 // Step 1: Initialize weights
38 // =====================================================================
39 let inf_weights = MambaWeights::init(&cfg, input_dim, 42);
40 let mut tw = train_weights_from_inference(&inf_weights);
41
42 let dims = MambaDims::from_config(&cfg, seq_len, input_dim);
43 let di = cfg.d_inner();
44 let ds = cfg.d_state;
45 let dc = cfg.d_conv;
46 let dm = cfg.d_model;
47
48 // Pre-allocate scratch (reused every step)
49 let mut acts = MambaBackboneFlat::zeros(dims);
50 let mut fwd_scratch = PhaseScratch::zeros(&dims);
51 let mut bwd_scratch = BackwardPhaseScratch::zeros(&dims);
52
53 // Synthetic training data: random input, target = zeros (regression)
54 let input: Vec<f32> = (0..seq_len * input_dim)
55 .map(|i| ((i * 7 + 13) % 100) as f32 * 0.01 - 0.5)
56 .collect();
57
58 // =====================================================================
59 // Step 2: Training loop (SGD, MSE loss)
60 // =====================================================================
61 println!("Training ({steps} steps, lr={lr}, seq_len={seq_len}):");
62
63 for step in 0..steps {
64 // Precompute a_neg = -exp(a_log)
65 let mut a_neg = vec![0.0f32; cfg.n_layers * di * ds];
66 for (l, lw) in tw.layers.iter().enumerate() {
67 for i in 0..di * ds {
68 a_neg[l * di * ds + i] = -lw.a_log[i].exp();
69 }
70 }
71
72 // Forward pass
73 let mut conv = vec![0.0f32; cfg.n_layers * di * dc];
74 let mut ssm = vec![0.0f32; cfg.n_layers * di * ds];
75 let mut state = MambaRecurrentState {
76 conv: &mut conv,
77 ssm: &mut ssm,
78 a_neg: &a_neg,
79 };
80 let mut temporal = vec![0.0f32; seq_len * dm];
81 forward_mamba_backbone_batched(
82 &mut temporal,
83 &mut acts,
84 &tw,
85 &input,
86 &mut state,
87 &mut fwd_scratch,
88 &dims,
89 );
90
91 // MSE loss
92 let loss: f32 = temporal.iter().map(|v| v * v).sum::<f32>() / (seq_len * dm) as f32;
93
94 // Backward pass
95 let scale = 2.0 / (seq_len * dm) as f32;
96 let mut d_temporal: Vec<f32> = temporal.iter().map(|v| v * scale).collect();
97 let mut grads = TrainMambaWeights::zeros_from_dims(&dims);
98 backward_mamba_backbone_batched(
99 &mut d_temporal,
100 &mut grads,
101 &acts,
102 &tw,
103 &a_neg,
104 &mut bwd_scratch,
105 &dims,
106 );
107
108 // SGD weight update
109 sgd_update_all(&mut tw, &grads, lr);
110
111 if step % 10 == 0 || step == steps - 1 {
112 println!(" step {step:3}: loss = {loss:.6}");
113 }
114 }
115
116 // =====================================================================
117 // Step 3: Save trained weights
118 // =====================================================================
119 let save_path = std::env::temp_dir().join("mamba_rs_trained.safetensors");
120 let trained_inf_weights = inference_weights_from_train(&tw, &inf_weights);
121 mamba_rs::serialize::save(&save_path, &trained_inf_weights, &cfg, input_dim)
122 .expect("save failed");
123 println!("\nSaved to: {}", save_path.display());
124
125 // =====================================================================
126 // Step 4: Load and run inference
127 // =====================================================================
128 let (loaded_w, loaded_cfg, loaded_input_dim) =
129 mamba_rs::serialize::load(Path::new(&save_path)).expect("load failed");
130 let bb = MambaBackbone::from_weights(loaded_cfg, loaded_w).expect("from_weights failed");
131 assert_eq!(bb.input_dim(), loaded_input_dim);
132 println!(
133 "Loaded: d_model={}, layers={}, params={}",
134 loaded_cfg.d_model,
135 loaded_cfg.n_layers,
136 bb.param_count()
137 );
138
139 // Run inference step-by-step
140 let mut inf_state = bb.alloc_state();
141 let mut scratch = bb.alloc_scratch();
142 let mut output = vec![0.0f32; dm];
143
144 println!("\nInference (10 steps):");
145 for step in 0..10 {
146 let inp: Vec<f32> = (0..input_dim)
147 .map(|i| ((step * input_dim + i) * 7 + 13) as f32 % 100.0 * 0.01 - 0.5)
148 .collect();
149 bb.forward_step(&inp, &mut output, &mut inf_state, &mut scratch);
150
151 let out_norm: f32 = output.iter().map(|v| v * v).sum::<f32>().sqrt();
152 println!(" step {step}: output L2 norm = {out_norm:.6}");
153 }
154
155 // Cleanup
156 std::fs::remove_file(&save_path).ok();
157 println!("\nDone.");
158}

Source
pub fn forward_step(
&self,
input: &[f32],
output: &mut [f32],
state: &mut MambaState,
scratch: &mut MambaStepScratch,
)
pub fn forward_step( &self, input: &[f32], output: &mut [f32], state: &mut MambaState, scratch: &mut MambaStepScratch, )
Single-step recurrent forward through the full backbone.
input_proj(input) -> N x layer_step -> norm_f -> output
Zero allocations per call. Delegates to mamba_step.
Examples found in repository
3fn main() {
4 let cfg = MambaConfig::default();
5 let input_dim = cfg.d_model;
6
7 // Initialize backbone with paper-default weights
8 let backbone = MambaBackbone::init(cfg, input_dim, 42);
9 println!(
10 "Mamba: {} layers, d_model={}, d_inner={}, {} params",
11 backbone.n_layers(),
12 backbone.config().d_model,
13 backbone.config().d_inner(),
14 backbone.param_count(),
15 );
16
17 // Allocate recurrent state + scratch (once, reuse across steps)
18 let mut state = backbone.alloc_state();
19 let mut scratch = backbone.alloc_scratch();
20 let mut output = vec![0.0f32; backbone.config().d_model];
21
22 // Run 10 inference steps
23 for step in 0..10 {
24 let input = vec![0.1 * step as f32; input_dim];
25 backbone.forward_step(&input, &mut output, &mut state, &mut scratch);
26
27 let norm: f32 = output.iter().map(|x| x * x).sum::<f32>().sqrt();
28 println!("step {step}: output L2 norm = {norm:.6}");
29 }
30
31 // Reset state for new sequence
32 state.reset();
33 println!("state reset");
34}

More examples
18fn main() {
19 let cfg = MambaConfig::default(); // d_model=128, 3 layers, 366K params
20 let input_dim = cfg.d_model;
21 let seq_len = 16;
22 let lr = 1e-3_f32;
23 let steps = 50;
24
25 println!("mamba-rs: train -> save -> load -> inference");
26 println!("=============================================");
27 println!(
28 "Config: d_model={}, layers={}, d_inner={}, params={}",
29 cfg.d_model,
30 cfg.n_layers,
31 cfg.d_inner(),
32 MambaBackbone::init(cfg, input_dim, 42).param_count()
33 );
34 println!();
35
36 // =====================================================================
37 // Step 1: Initialize weights
38 // =====================================================================
39 let inf_weights = MambaWeights::init(&cfg, input_dim, 42);
40 let mut tw = train_weights_from_inference(&inf_weights);
41
42 let dims = MambaDims::from_config(&cfg, seq_len, input_dim);
43 let di = cfg.d_inner();
44 let ds = cfg.d_state;
45 let dc = cfg.d_conv;
46 let dm = cfg.d_model;
47
48 // Pre-allocate scratch (reused every step)
49 let mut acts = MambaBackboneFlat::zeros(dims);
50 let mut fwd_scratch = PhaseScratch::zeros(&dims);
51 let mut bwd_scratch = BackwardPhaseScratch::zeros(&dims);
52
53 // Synthetic training data: random input, target = zeros (regression)
54 let input: Vec<f32> = (0..seq_len * input_dim)
55 .map(|i| ((i * 7 + 13) % 100) as f32 * 0.01 - 0.5)
56 .collect();
57
58 // =====================================================================
59 // Step 2: Training loop (SGD, MSE loss)
60 // =====================================================================
61 println!("Training ({steps} steps, lr={lr}, seq_len={seq_len}):");
62
63 for step in 0..steps {
64 // Precompute a_neg = -exp(a_log)
65 let mut a_neg = vec![0.0f32; cfg.n_layers * di * ds];
66 for (l, lw) in tw.layers.iter().enumerate() {
67 for i in 0..di * ds {
68 a_neg[l * di * ds + i] = -lw.a_log[i].exp();
69 }
70 }
71
72 // Forward pass
73 let mut conv = vec![0.0f32; cfg.n_layers * di * dc];
74 let mut ssm = vec![0.0f32; cfg.n_layers * di * ds];
75 let mut state = MambaRecurrentState {
76 conv: &mut conv,
77 ssm: &mut ssm,
78 a_neg: &a_neg,
79 };
80 let mut temporal = vec![0.0f32; seq_len * dm];
81 forward_mamba_backbone_batched(
82 &mut temporal,
83 &mut acts,
84 &tw,
85 &input,
86 &mut state,
87 &mut fwd_scratch,
88 &dims,
89 );
90
91 // MSE loss
92 let loss: f32 = temporal.iter().map(|v| v * v).sum::<f32>() / (seq_len * dm) as f32;
93
94 // Backward pass
95 let scale = 2.0 / (seq_len * dm) as f32;
96 let mut d_temporal: Vec<f32> = temporal.iter().map(|v| v * scale).collect();
97 let mut grads = TrainMambaWeights::zeros_from_dims(&dims);
98 backward_mamba_backbone_batched(
99 &mut d_temporal,
100 &mut grads,
101 &acts,
102 &tw,
103 &a_neg,
104 &mut bwd_scratch,
105 &dims,
106 );
107
108 // SGD weight update
109 sgd_update_all(&mut tw, &grads, lr);
110
111 if step % 10 == 0 || step == steps - 1 {
112 println!(" step {step:3}: loss = {loss:.6}");
113 }
114 }
115
116 // =====================================================================
117 // Step 3: Save trained weights
118 // =====================================================================
119 let save_path = std::env::temp_dir().join("mamba_rs_trained.safetensors");
120 let trained_inf_weights = inference_weights_from_train(&tw, &inf_weights);
121 mamba_rs::serialize::save(&save_path, &trained_inf_weights, &cfg, input_dim)
122 .expect("save failed");
123 println!("\nSaved to: {}", save_path.display());
124
125 // =====================================================================
126 // Step 4: Load and run inference
127 // =====================================================================
128 let (loaded_w, loaded_cfg, loaded_input_dim) =
129 mamba_rs::serialize::load(Path::new(&save_path)).expect("load failed");
130 let bb = MambaBackbone::from_weights(loaded_cfg, loaded_w).expect("from_weights failed");
131 assert_eq!(bb.input_dim(), loaded_input_dim);
132 println!(
133 "Loaded: d_model={}, layers={}, params={}",
134 loaded_cfg.d_model,
135 loaded_cfg.n_layers,
136 bb.param_count()
137 );
138
139 // Run inference step-by-step
140 let mut inf_state = bb.alloc_state();
141 let mut scratch = bb.alloc_scratch();
142 let mut output = vec![0.0f32; dm];
143
144 println!("\nInference (10 steps):");
145 for step in 0..10 {
146 let inp: Vec<f32> = (0..input_dim)
147 .map(|i| ((step * input_dim + i) * 7 + 13) as f32 % 100.0 * 0.01 - 0.5)
148 .collect();
149 bb.forward_step(&inp, &mut output, &mut inf_state, &mut scratch);
150
151 let out_norm: f32 = output.iter().map(|v| v * v).sum::<f32>().sqrt();
152 println!(" step {step}: output L2 norm = {out_norm:.6}");
153 }
154
155 // Cleanup
156 std::fs::remove_file(&save_path).ok();
157 println!("\nDone.");
158}

Source
pub fn forward_sequence(
&self,
inputs: &[f32],
outputs: &mut [f32],
state: &mut MambaState,
scratch: &mut MambaStepScratch,
seq_len: usize,
)
pub fn forward_sequence( &self, inputs: &[f32], outputs: &mut [f32], state: &mut MambaState, scratch: &mut MambaStepScratch, seq_len: usize, )
Run T inference steps sequentially, collecting all outputs.
inputs: [T * input_dim] — T sequential inputs.
outputs: [T * d_model] — T sequential outputs (written in-place).
State carries across all T steps (warm-up, offline eval, etc.).
Sourcepub fn forward_step_batch(
&self,
inputs: &[f32],
outputs: &mut [f32],
states: &mut [MambaState],
scratches: &mut [MambaStepScratch],
)
pub fn forward_step_batch( &self, inputs: &[f32], outputs: &mut [f32], states: &mut [MambaState], scratches: &mut [MambaStepScratch], )
Batched single-step forward through the backbone.
Processes B independent samples with the same weights.
inputs: [B * input_dim], outputs: [B * d_model].
Sourcepub fn alloc_state(&self) -> MambaState
pub fn alloc_state(&self) -> MambaState
Allocate zeroed recurrent state matching this backbone.
Examples found in repository?
3fn main() {
4 let cfg = MambaConfig::default();
5 let input_dim = cfg.d_model;
6
7 // Initialize backbone with paper-default weights
8 let backbone = MambaBackbone::init(cfg, input_dim, 42);
9 println!(
10 "Mamba: {} layers, d_model={}, d_inner={}, {} params",
11 backbone.n_layers(),
12 backbone.config().d_model,
13 backbone.config().d_inner(),
14 backbone.param_count(),
15 );
16
17 // Allocate recurrent state + scratch (once, reuse across steps)
18 let mut state = backbone.alloc_state();
19 let mut scratch = backbone.alloc_scratch();
20 let mut output = vec![0.0f32; backbone.config().d_model];
21
22 // Run 10 inference steps
23 for step in 0..10 {
24 let input = vec![0.1 * step as f32; input_dim];
25 backbone.forward_step(&input, &mut output, &mut state, &mut scratch);
26
27 let norm: f32 = output.iter().map(|x| x * x).sum::<f32>().sqrt();
28 println!("step {step}: output L2 norm = {norm:.6}");
29 }
30
31 // Reset state for new sequence
32 state.reset();
33 println!("state reset");
34}
More examples
18fn main() {
19 let cfg = MambaConfig::default(); // d_model=128, 3 layers, 366K params
20 let input_dim = cfg.d_model;
21 let seq_len = 16;
22 let lr = 1e-3_f32;
23 let steps = 50;
24
25 println!("mamba-rs: train -> save -> load -> inference");
26 println!("=============================================");
27 println!(
28 "Config: d_model={}, layers={}, d_inner={}, params={}",
29 cfg.d_model,
30 cfg.n_layers,
31 cfg.d_inner(),
32 MambaBackbone::init(cfg, input_dim, 42).param_count()
33 );
34 println!();
35
36 // =====================================================================
37 // Step 1: Initialize weights
38 // =====================================================================
39 let inf_weights = MambaWeights::init(&cfg, input_dim, 42);
40 let mut tw = train_weights_from_inference(&inf_weights);
41
42 let dims = MambaDims::from_config(&cfg, seq_len, input_dim);
43 let di = cfg.d_inner();
44 let ds = cfg.d_state;
45 let dc = cfg.d_conv;
46 let dm = cfg.d_model;
47
48 // Pre-allocate scratch (reused every step)
49 let mut acts = MambaBackboneFlat::zeros(dims);
50 let mut fwd_scratch = PhaseScratch::zeros(&dims);
51 let mut bwd_scratch = BackwardPhaseScratch::zeros(&dims);
52
53 // Synthetic training data: random input, target = zeros (regression)
54 let input: Vec<f32> = (0..seq_len * input_dim)
55 .map(|i| ((i * 7 + 13) % 100) as f32 * 0.01 - 0.5)
56 .collect();
57
58 // =====================================================================
59 // Step 2: Training loop (SGD, MSE loss)
60 // =====================================================================
61 println!("Training ({steps} steps, lr={lr}, seq_len={seq_len}):");
62
63 for step in 0..steps {
64 // Precompute a_neg = -exp(a_log)
65 let mut a_neg = vec![0.0f32; cfg.n_layers * di * ds];
66 for (l, lw) in tw.layers.iter().enumerate() {
67 for i in 0..di * ds {
68 a_neg[l * di * ds + i] = -lw.a_log[i].exp();
69 }
70 }
71
72 // Forward pass
73 let mut conv = vec![0.0f32; cfg.n_layers * di * dc];
74 let mut ssm = vec![0.0f32; cfg.n_layers * di * ds];
75 let mut state = MambaRecurrentState {
76 conv: &mut conv,
77 ssm: &mut ssm,
78 a_neg: &a_neg,
79 };
80 let mut temporal = vec![0.0f32; seq_len * dm];
81 forward_mamba_backbone_batched(
82 &mut temporal,
83 &mut acts,
84 &tw,
85 &input,
86 &mut state,
87 &mut fwd_scratch,
88 &dims,
89 );
90
91 // MSE loss
92 let loss: f32 = temporal.iter().map(|v| v * v).sum::<f32>() / (seq_len * dm) as f32;
93
94 // Backward pass
95 let scale = 2.0 / (seq_len * dm) as f32;
96 let mut d_temporal: Vec<f32> = temporal.iter().map(|v| v * scale).collect();
97 let mut grads = TrainMambaWeights::zeros_from_dims(&dims);
98 backward_mamba_backbone_batched(
99 &mut d_temporal,
100 &mut grads,
101 &acts,
102 &tw,
103 &a_neg,
104 &mut bwd_scratch,
105 &dims,
106 );
107
108 // SGD weight update
109 sgd_update_all(&mut tw, &grads, lr);
110
111 if step % 10 == 0 || step == steps - 1 {
112 println!(" step {step:3}: loss = {loss:.6}");
113 }
114 }
115
116 // =====================================================================
117 // Step 3: Save trained weights
118 // =====================================================================
119 let save_path = std::env::temp_dir().join("mamba_rs_trained.safetensors");
120 let trained_inf_weights = inference_weights_from_train(&tw, &inf_weights);
121 mamba_rs::serialize::save(&save_path, &trained_inf_weights, &cfg, input_dim)
122 .expect("save failed");
123 println!("\nSaved to: {}", save_path.display());
124
125 // =====================================================================
126 // Step 4: Load and run inference
127 // =====================================================================
128 let (loaded_w, loaded_cfg, loaded_input_dim) =
129 mamba_rs::serialize::load(Path::new(&save_path)).expect("load failed");
130 let bb = MambaBackbone::from_weights(loaded_cfg, loaded_w).expect("from_weights failed");
131 assert_eq!(bb.input_dim(), loaded_input_dim);
132 println!(
133 "Loaded: d_model={}, layers={}, params={}",
134 loaded_cfg.d_model,
135 loaded_cfg.n_layers,
136 bb.param_count()
137 );
138
139 // Run inference step-by-step
140 let mut inf_state = bb.alloc_state();
141 let mut scratch = bb.alloc_scratch();
142 let mut output = vec![0.0f32; dm];
143
144 println!("\nInference (10 steps):");
145 for step in 0..10 {
146 let inp: Vec<f32> = (0..input_dim)
147 .map(|i| ((step * input_dim + i) * 7 + 13) as f32 % 100.0 * 0.01 - 0.5)
148 .collect();
149 bb.forward_step(&inp, &mut output, &mut inf_state, &mut scratch);
150
151 let out_norm: f32 = output.iter().map(|v| v * v).sum::<f32>().sqrt();
152 println!(" step {step}: output L2 norm = {out_norm:.6}");
153 }
154
155 // Cleanup
156 std::fs::remove_file(&save_path).ok();
157 println!("\nDone.");
158}
Source§
pub fn alloc_scratch(&self) -> MambaStepScratch
pub fn alloc_scratch(&self) -> MambaStepScratch
Allocate inference scratch buffers matching this backbone.
Examples found in repository?
3fn main() {
4 let cfg = MambaConfig::default();
5 let input_dim = cfg.d_model;
6
7 // Initialize backbone with paper-default weights
8 let backbone = MambaBackbone::init(cfg, input_dim, 42);
9 println!(
10 "Mamba: {} layers, d_model={}, d_inner={}, {} params",
11 backbone.n_layers(),
12 backbone.config().d_model,
13 backbone.config().d_inner(),
14 backbone.param_count(),
15 );
16
17 // Allocate recurrent state + scratch (once, reuse across steps)
18 let mut state = backbone.alloc_state();
19 let mut scratch = backbone.alloc_scratch();
20 let mut output = vec![0.0f32; backbone.config().d_model];
21
22 // Run 10 inference steps
23 for step in 0..10 {
24 let input = vec![0.1 * step as f32; input_dim];
25 backbone.forward_step(&input, &mut output, &mut state, &mut scratch);
26
27 let norm: f32 = output.iter().map(|x| x * x).sum::<f32>().sqrt();
28 println!("step {step}: output L2 norm = {norm:.6}");
29 }
30
31 // Reset state for new sequence
32 state.reset();
33 println!("state reset");
34}More examples
18fn main() {
19 let cfg = MambaConfig::default(); // d_model=128, 3 layers, 366K params
20 let input_dim = cfg.d_model;
21 let seq_len = 16;
22 let lr = 1e-3_f32;
23 let steps = 50;
24
25 println!("mamba-rs: train -> save -> load -> inference");
26 println!("=============================================");
27 println!(
28 "Config: d_model={}, layers={}, d_inner={}, params={}",
29 cfg.d_model,
30 cfg.n_layers,
31 cfg.d_inner(),
32 MambaBackbone::init(cfg, input_dim, 42).param_count()
33 );
34 println!();
35
36 // =====================================================================
37 // Step 1: Initialize weights
38 // =====================================================================
39 let inf_weights = MambaWeights::init(&cfg, input_dim, 42);
40 let mut tw = train_weights_from_inference(&inf_weights);
41
42 let dims = MambaDims::from_config(&cfg, seq_len, input_dim);
43 let di = cfg.d_inner();
44 let ds = cfg.d_state;
45 let dc = cfg.d_conv;
46 let dm = cfg.d_model;
47
48 // Pre-allocate scratch (reused every step)
49 let mut acts = MambaBackboneFlat::zeros(dims);
50 let mut fwd_scratch = PhaseScratch::zeros(&dims);
51 let mut bwd_scratch = BackwardPhaseScratch::zeros(&dims);
52
53 // Synthetic training data: random input, target = zeros (regression)
54 let input: Vec<f32> = (0..seq_len * input_dim)
55 .map(|i| ((i * 7 + 13) % 100) as f32 * 0.01 - 0.5)
56 .collect();
57
58 // =====================================================================
59 // Step 2: Training loop (SGD, MSE loss)
60 // =====================================================================
61 println!("Training ({steps} steps, lr={lr}, seq_len={seq_len}):");
62
63 for step in 0..steps {
64 // Precompute a_neg = -exp(a_log)
65 let mut a_neg = vec![0.0f32; cfg.n_layers * di * ds];
66 for (l, lw) in tw.layers.iter().enumerate() {
67 for i in 0..di * ds {
68 a_neg[l * di * ds + i] = -lw.a_log[i].exp();
69 }
70 }
71
72 // Forward pass
73 let mut conv = vec![0.0f32; cfg.n_layers * di * dc];
74 let mut ssm = vec![0.0f32; cfg.n_layers * di * ds];
75 let mut state = MambaRecurrentState {
76 conv: &mut conv,
77 ssm: &mut ssm,
78 a_neg: &a_neg,
79 };
80 let mut temporal = vec![0.0f32; seq_len * dm];
81 forward_mamba_backbone_batched(
82 &mut temporal,
83 &mut acts,
84 &tw,
85 &input,
86 &mut state,
87 &mut fwd_scratch,
88 &dims,
89 );
90
91 // MSE loss
92 let loss: f32 = temporal.iter().map(|v| v * v).sum::<f32>() / (seq_len * dm) as f32;
93
94 // Backward pass
95 let scale = 2.0 / (seq_len * dm) as f32;
96 let mut d_temporal: Vec<f32> = temporal.iter().map(|v| v * scale).collect();
97 let mut grads = TrainMambaWeights::zeros_from_dims(&dims);
98 backward_mamba_backbone_batched(
99 &mut d_temporal,
100 &mut grads,
101 &acts,
102 &tw,
103 &a_neg,
104 &mut bwd_scratch,
105 &dims,
106 );
107
108 // SGD weight update
109 sgd_update_all(&mut tw, &grads, lr);
110
111 if step % 10 == 0 || step == steps - 1 {
112 println!(" step {step:3}: loss = {loss:.6}");
113 }
114 }
115
116 // =====================================================================
117 // Step 3: Save trained weights
118 // =====================================================================
119 let save_path = std::env::temp_dir().join("mamba_rs_trained.safetensors");
120 let trained_inf_weights = inference_weights_from_train(&tw, &inf_weights);
121 mamba_rs::serialize::save(&save_path, &trained_inf_weights, &cfg, input_dim)
122 .expect("save failed");
123 println!("\nSaved to: {}", save_path.display());
124
125 // =====================================================================
126 // Step 4: Load and run inference
127 // =====================================================================
128 let (loaded_w, loaded_cfg, loaded_input_dim) =
129 mamba_rs::serialize::load(Path::new(&save_path)).expect("load failed");
130 let bb = MambaBackbone::from_weights(loaded_cfg, loaded_w).expect("from_weights failed");
131 assert_eq!(bb.input_dim(), loaded_input_dim);
132 println!(
133 "Loaded: d_model={}, layers={}, params={}",
134 loaded_cfg.d_model,
135 loaded_cfg.n_layers,
136 bb.param_count()
137 );
138
139 // Run inference step-by-step
140 let mut inf_state = bb.alloc_state();
141 let mut scratch = bb.alloc_scratch();
142 let mut output = vec![0.0f32; dm];
143
144 println!("\nInference (10 steps):");
145 for step in 0..10 {
146 let inp: Vec<f32> = (0..input_dim)
147 .map(|i| ((step * input_dim + i) * 7 + 13) as f32 % 100.0 * 0.01 - 0.5)
148 .collect();
149 bb.forward_step(&inp, &mut output, &mut inf_state, &mut scratch);
150
151 let out_norm: f32 = output.iter().map(|v| v * v).sum::<f32>().sqrt();
152 println!(" step {step}: output L2 norm = {out_norm:.6}");
153 }
154
155 // Cleanup
156 std::fs::remove_file(&save_path).ok();
157 println!("\nDone.");
158}
Auto Trait Implementations§
impl Freeze for MambaBackbone
impl RefUnwindSafe for MambaBackbone
impl Send for MambaBackbone
impl Sync for MambaBackbone
impl Unpin for MambaBackbone
impl UnsafeUnpin for MambaBackbone
impl UnwindSafe for MambaBackbone
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T
where
    T: ?Sized,
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left is true.
Converts self into a Right variant of Either<Self, Self>
otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self>
otherwise. Read more