// mamba_rs/module/backbone.rs
1use crate::config::MambaConfig;
2use crate::inference::{MambaStepScratch, mamba_step};
3use crate::state::MambaState;
4use crate::weights::{MambaLayerWeights, MambaWeights};
5
6/// Complete Mamba backbone: input_proj -> N layers -> norm_f.
7///
8/// Owns all weights. Provides both single-step recurrent inference
9/// and access to raw weights for training integration.
10///
11/// ```rust
12/// use mamba_rs::module::MambaBackbone;
13/// use mamba_rs::MambaConfig;
14///
15/// let cfg = MambaConfig::default();
16/// let backbone = MambaBackbone::init(cfg, 128, 42);
17///
18/// let mut state = backbone.alloc_state();
19/// let mut scratch = backbone.alloc_scratch();
20/// let mut output = vec![0.0f32; backbone.config().d_model];
21///
22/// let input = vec![0.1f32; 128];
23/// backbone.forward_step(&input, &mut output, &mut state, &mut scratch);
24/// ```
pub struct MambaBackbone {
    /// All learnable parameters: input projection, per-layer weights, final norm.
    weights: MambaWeights,
    /// Architecture hyperparameters this backbone was built with.
    cfg: MambaConfig,
    /// External input feature dimension consumed by the input projection.
    input_dim: usize,
}
30
31impl MambaBackbone {
32 /// Create a backbone with Mamba-specific weight initialization.
33 ///
34 /// Uses Kaiming uniform for projections, log-space init for A,
35 /// inverse-softplus init for dt_proj bias (Gu & Dao, Section 3.5).
36 pub fn init(cfg: MambaConfig, input_dim: usize, seed: u64) -> Self {
37 let weights = MambaWeights::init(&cfg, input_dim, seed);
38 Self {
39 weights,
40 cfg,
41 input_dim,
42 }
43 }
44
45 /// Create a backbone from pre-loaded weights.
46 ///
47 /// Validates dimensions against config. Returns `Err` on mismatch.
48 pub fn from_weights(cfg: MambaConfig, weights: MambaWeights) -> Result<Self, String> {
49 let input_dim = weights.input_proj_w.len() / cfg.d_model;
50 weights.validate(&cfg, input_dim)?;
51 Ok(Self {
52 weights,
53 cfg,
54 input_dim,
55 })
56 }
57
58 /// Extract owned weights (consuming self).
59 pub fn into_weights(self) -> MambaWeights {
60 self.weights
61 }
62
63 /// Read-only weight access.
64 pub fn weights(&self) -> &MambaWeights {
65 &self.weights
66 }
67
68 /// Mutable weight access (for optimizer updates).
69 pub fn weights_mut(&mut self) -> &mut MambaWeights {
70 &mut self.weights
71 }
72
73 /// Read-only access to a specific layer's weights.
74 pub fn layer(&self, index: usize) -> &MambaLayerWeights {
75 &self.weights.layers[index]
76 }
77
78 /// Mutable access to a specific layer's weights.
79 pub fn layer_mut(&mut self, index: usize) -> &mut MambaLayerWeights {
80 &mut self.weights.layers[index]
81 }
82
83 /// Number of layers.
84 pub fn n_layers(&self) -> usize {
85 self.cfg.n_layers
86 }
87
88 /// Total parameter count.
89 pub fn param_count(&self) -> usize {
90 self.weights.param_count(self.input_dim, &self.cfg)
91 }
92
93 /// The config this backbone was built with.
94 pub fn config(&self) -> &MambaConfig {
95 &self.cfg
96 }
97
98 /// External input dimension.
99 pub fn input_dim(&self) -> usize {
100 self.input_dim
101 }
102
103 /// Single-step recurrent forward through the full backbone.
104 ///
105 /// `input_proj(input) -> N x layer_step -> norm_f -> output`
106 ///
107 /// Zero allocations per call. Delegates to [`mamba_step`].
108 pub fn forward_step(
109 &self,
110 input: &[f32],
111 output: &mut [f32],
112 state: &mut MambaState,
113 scratch: &mut MambaStepScratch,
114 ) {
115 mamba_step(
116 input,
117 output,
118 &self.weights,
119 &mut state.layers,
120 scratch,
121 &self.cfg,
122 self.input_dim,
123 );
124 }
125
126 /// Run T inference steps sequentially, collecting all outputs.
127 ///
128 /// `inputs`: `[T * input_dim]` — T sequential inputs.
129 /// `outputs`: `[T * d_model]` — T sequential outputs (written in-place).
130 /// State carries across all T steps (warm-up, offline eval, etc.).
131 pub fn forward_sequence(
132 &self,
133 inputs: &[f32],
134 outputs: &mut [f32],
135 state: &mut MambaState,
136 scratch: &mut MambaStepScratch,
137 seq_len: usize,
138 ) {
139 let dm = self.cfg.d_model;
140 debug_assert_eq!(inputs.len(), seq_len * self.input_dim);
141 debug_assert_eq!(outputs.len(), seq_len * dm);
142 for t in 0..seq_len {
143 let inp = &inputs[t * self.input_dim..(t + 1) * self.input_dim];
144 let out = &mut outputs[t * dm..(t + 1) * dm];
145 self.forward_step(inp, out, state, scratch);
146 }
147 }
148
149 /// Batched single-step forward through the backbone.
150 ///
151 /// Processes B independent samples with the same weights.
152 /// `inputs`: `[B * input_dim]`, `outputs`: `[B * d_model]`.
153 pub fn forward_step_batch(
154 &self,
155 inputs: &[f32],
156 outputs: &mut [f32],
157 states: &mut [MambaState],
158 scratches: &mut [MambaStepScratch],
159 ) {
160 crate::inference::mamba_step_batch(
161 inputs,
162 outputs,
163 &self.weights,
164 states,
165 scratches,
166 &self.cfg,
167 self.input_dim,
168 );
169 }
170
171 /// Allocate zeroed recurrent state matching this backbone.
172 pub fn alloc_state(&self) -> MambaState {
173 MambaState::zeros(
174 self.cfg.n_layers,
175 self.cfg.d_inner(),
176 self.cfg.d_state,
177 self.cfg.d_conv,
178 )
179 }
180
181 /// Allocate inference scratch buffers matching this backbone.
182 pub fn alloc_scratch(&self) -> MambaStepScratch {
183 MambaStepScratch::new(&self.cfg)
184 }
185}