//! mamba_rs/module/backbone.rs — complete Mamba backbone module.
1use crate::config::MambaConfig;
2use crate::inference::{MambaStepScratch, mamba_step};
3use crate::state::MambaState;
4use crate::weights::{MambaLayerWeights, MambaWeights};
5
/// Complete Mamba backbone: input_proj -> N layers -> norm_f.
///
/// Owns all weights. Provides both single-step recurrent inference
/// and access to raw weights for training integration.
///
/// ```rust
/// use mamba_rs::module::MambaBackbone;
/// use mamba_rs::MambaConfig;
///
/// let cfg = MambaConfig::default();
/// let backbone = MambaBackbone::init(cfg, 128, 42);
///
/// let mut state = backbone.alloc_state();
/// let mut scratch = backbone.alloc_scratch();
/// let mut output = vec![0.0f32; backbone.config().d_model];
///
/// let input = vec![0.1f32; 128];
/// backbone.forward_step(&input, &mut output, &mut state, &mut scratch);
/// ```
pub struct MambaBackbone {
    // All learned parameters (input projection, per-layer weights, final norm).
    weights: MambaWeights,
    // Hyperparameters the weights were built/validated against.
    cfg: MambaConfig,
    // External input feature dimension consumed by the input projection.
    input_dim: usize,
}
30
31impl MambaBackbone {
32    /// Create a backbone with Mamba-specific weight initialization.
33    ///
34    /// Uses Kaiming uniform for projections, log-space init for A,
35    /// inverse-softplus init for dt_proj bias (Gu & Dao, Section 3.5).
36    pub fn init(cfg: MambaConfig, input_dim: usize, seed: u64) -> Self {
37        let weights = MambaWeights::init(&cfg, input_dim, seed);
38        Self {
39            weights,
40            cfg,
41            input_dim,
42        }
43    }
44
45    /// Create a backbone from pre-loaded weights.
46    ///
47    /// Validates dimensions against config. Returns `Err` on mismatch.
48    pub fn from_weights(cfg: MambaConfig, weights: MambaWeights) -> Result<Self, String> {
49        let input_dim = weights.input_proj_w.len() / cfg.d_model;
50        weights.validate(&cfg, input_dim)?;
51        Ok(Self {
52            weights,
53            cfg,
54            input_dim,
55        })
56    }
57
58    /// Extract owned weights (consuming self).
59    pub fn into_weights(self) -> MambaWeights {
60        self.weights
61    }
62
63    /// Read-only weight access.
64    pub fn weights(&self) -> &MambaWeights {
65        &self.weights
66    }
67
68    /// Mutable weight access (for optimizer updates).
69    pub fn weights_mut(&mut self) -> &mut MambaWeights {
70        &mut self.weights
71    }
72
73    /// Read-only access to a specific layer's weights.
74    pub fn layer(&self, index: usize) -> &MambaLayerWeights {
75        &self.weights.layers[index]
76    }
77
78    /// Mutable access to a specific layer's weights.
79    pub fn layer_mut(&mut self, index: usize) -> &mut MambaLayerWeights {
80        &mut self.weights.layers[index]
81    }
82
83    /// Number of layers.
84    pub fn n_layers(&self) -> usize {
85        self.cfg.n_layers
86    }
87
88    /// Total parameter count.
89    pub fn param_count(&self) -> usize {
90        self.weights.param_count(self.input_dim, &self.cfg)
91    }
92
93    /// The config this backbone was built with.
94    pub fn config(&self) -> &MambaConfig {
95        &self.cfg
96    }
97
98    /// External input dimension.
99    pub fn input_dim(&self) -> usize {
100        self.input_dim
101    }
102
103    /// Single-step recurrent forward through the full backbone.
104    ///
105    /// `input_proj(input) -> N x layer_step -> norm_f -> output`
106    ///
107    /// Zero allocations per call. Delegates to [`mamba_step`].
108    pub fn forward_step(
109        &self,
110        input: &[f32],
111        output: &mut [f32],
112        state: &mut MambaState,
113        scratch: &mut MambaStepScratch,
114    ) {
115        mamba_step(
116            input,
117            output,
118            &self.weights,
119            &mut state.layers,
120            scratch,
121            &self.cfg,
122            self.input_dim,
123        );
124    }
125
126    /// Run T inference steps sequentially, collecting all outputs.
127    ///
128    /// `inputs`: `[T * input_dim]` — T sequential inputs.
129    /// `outputs`: `[T * d_model]` — T sequential outputs (written in-place).
130    /// State carries across all T steps (warm-up, offline eval, etc.).
131    pub fn forward_sequence(
132        &self,
133        inputs: &[f32],
134        outputs: &mut [f32],
135        state: &mut MambaState,
136        scratch: &mut MambaStepScratch,
137        seq_len: usize,
138    ) {
139        let dm = self.cfg.d_model;
140        debug_assert_eq!(inputs.len(), seq_len * self.input_dim);
141        debug_assert_eq!(outputs.len(), seq_len * dm);
142        for t in 0..seq_len {
143            let inp = &inputs[t * self.input_dim..(t + 1) * self.input_dim];
144            let out = &mut outputs[t * dm..(t + 1) * dm];
145            self.forward_step(inp, out, state, scratch);
146        }
147    }
148
149    /// Batched single-step forward through the backbone.
150    ///
151    /// Processes B independent samples with the same weights.
152    /// `inputs`: `[B * input_dim]`, `outputs`: `[B * d_model]`.
153    pub fn forward_step_batch(
154        &self,
155        inputs: &[f32],
156        outputs: &mut [f32],
157        states: &mut [MambaState],
158        scratches: &mut [MambaStepScratch],
159    ) {
160        crate::inference::mamba_step_batch(
161            inputs,
162            outputs,
163            &self.weights,
164            states,
165            scratches,
166            &self.cfg,
167            self.input_dim,
168        );
169    }
170
171    /// Allocate zeroed recurrent state matching this backbone.
172    pub fn alloc_state(&self) -> MambaState {
173        MambaState::zeros(
174            self.cfg.n_layers,
175            self.cfg.d_inner(),
176            self.cfg.d_state,
177            self.cfg.d_conv,
178        )
179    }
180
181    /// Allocate inference scratch buffers matching this backbone.
182    pub fn alloc_scratch(&self) -> MambaStepScratch {
183        MambaStepScratch::new(&self.cfg)
184    }
185}