gllm_kernels/ops/
mamba.rs

//! Mamba-2 hybrid selective state space utilities.
2
3use std::marker::PhantomData;
4
5use burn::tensor::backend::Backend;
6use burn::tensor::{Tensor, TensorData};
7
/// Mamba block configuration for selective state space models.
///
/// All dimensions must be non-zero; see [`MambaConfig::validate`].
#[derive(Debug, Clone, Copy)]
pub struct MambaConfig {
    /// State space dimension (size of the recurrent state per batch element).
    pub state_dim: usize,
    /// Input dimension (feature size of the incoming sequence).
    pub input_dim: usize,
    /// Expansion ratio applied to `input_dim` to form the expanded projection width (typical: 2).
    pub expand_ratio: usize,
    /// Enable selective gating (extra sigmoid modulation of the time step in the forward pass).
    pub selective: bool,
}
20
/// Mamba block for selective state space modeling.
///
/// Holds only dimension configuration; the learned weights are supplied per
/// call via [`MambaParameters`], and recurrent state via [`MambaState`].
pub struct MambaBlock<B: Backend> {
    /// State space dimension.
    state_dim: usize,
    /// Input dimension.
    input_dim: usize,
    /// Expansion ratio.
    expand_ratio: usize,
    // Ties the block to a backend type without storing any backend data.
    _marker: PhantomData<B>,
}
31
/// Parameters for Mamba selective state space projection.
///
/// Here `expanded_dim` denotes `input_dim * expand_ratio` of the owning
/// [`MambaBlock`]; shapes are checked by [`MambaParameters::validate`].
#[derive(Debug, Clone)]
pub struct MambaParameters<B: Backend> {
    /// Time-step projection weights: [expanded_dim, state_dim].
    pub dt_proj: Tensor<B, 2>,
    /// State decay parameters (diagonal): [state_dim].
    pub a: Tensor<B, 1>,
    /// Input projection weights: [expanded_dim, state_dim].
    pub b: Tensor<B, 2>,
    /// Output projection weights: [state_dim, expanded_dim].
    pub c: Tensor<B, 2>,
    /// Skip connection weights (diagonal): [expanded_dim].
    pub d: Tensor<B, 1>,
}
46
/// Stateful cache for Mamba recurrence.
///
/// Returned by `MambaBlock::forward` and may be fed back into a subsequent
/// call to continue the recurrence across sequence chunks.
#[derive(Debug, Clone)]
pub struct MambaState<B: Backend> {
    /// State tensor: [batch, state_dim].
    pub state: Tensor<B, 2>,
}
53
/// Hybrid strategy for mixing attention and Mamba outputs.
#[derive(Debug, Clone, Copy)]
pub enum HybridStrategy {
    /// Alternate between attention and Mamba per layer index
    /// (even layer indices use attention, odd use Mamba).
    Alternating,
    /// Parallel blend with a fixed Mamba weight.
    Parallel {
        /// Blend weight for the Mamba branch; clamped to [0, 1] at use.
        mamba_weight: f32,
    },
    /// Adaptive blend based on content energy.
    Adaptive {
        /// Lower bound on the adaptive Mamba weight.
        min_weight: f32,
        /// Upper bound on the adaptive Mamba weight.
        max_weight: f32,
    },
}
64
/// Hybrid layer for combining attention and Mamba outputs.
pub struct HybridLayer<B: Backend> {
    /// Blending strategy.
    strategy: HybridStrategy,
    /// Layer index; only consulted by [`HybridStrategy::Alternating`].
    layer_index: usize,
    // Ties the layer to a backend type without storing any backend data.
    _marker: PhantomData<B>,
}
73
74impl MambaConfig {
75    /// Validate configuration values.
76    pub fn validate(&self) -> Result<(), &'static str> {
77        if self.state_dim == 0 {
78            return Err("state_dim must be > 0");
79        }
80        if self.input_dim == 0 {
81            return Err("input_dim must be > 0");
82        }
83        if self.expand_ratio == 0 {
84            return Err("expand_ratio must be > 0");
85        }
86        Ok(())
87    }
88}
89
impl<B: Backend> MambaBlock<B> {
    /// Create a new Mamba block with explicit dimensions.
    pub fn new(state_dim: usize, input_dim: usize, expand_ratio: usize) -> Self {
        Self {
            state_dim,
            input_dim,
            expand_ratio,
            _marker: PhantomData,
        }
    }

    /// Create a Mamba block from configuration.
    ///
    /// Note: `config.selective` is not stored on the block; selectivity is
    /// chosen per call via [`Self::forward`] / [`Self::forward_with_config`].
    pub fn from_config(config: &MambaConfig) -> Self {
        Self::new(config.state_dim, config.input_dim, config.expand_ratio)
    }

    /// State dimension configured for the block.
    pub fn state_dim(&self) -> usize {
        self.state_dim
    }

    /// Input dimension configured for the block.
    pub fn input_dim(&self) -> usize {
        self.input_dim
    }

    /// Expansion ratio configured for the block.
    pub fn expand_ratio(&self) -> usize {
        self.expand_ratio
    }

    /// Forward pass for selective state space modeling.
    ///
    /// The recurrence runs sequentially on the CPU: all tensors are converted
    /// to `Vec<f32>`, the scan is evaluated time step by time step, and the
    /// results are rebuilt on `input`'s device. When `state` is `None`, the
    /// recurrence starts from zeros.
    ///
    /// # Shapes
    /// * `input`: [batch, seq_len, input_dim]
    /// * `state`: [batch, state_dim]
    ///
    /// # Errors
    /// Static messages for dimension mismatches, empty batch/sequence, failed
    /// f32 data conversion, or `input_dim * expand_ratio` overflow.
    pub fn forward(
        &self,
        input: Tensor<B, 3>,
        params: &MambaParameters<B>,
        state: Option<MambaState<B>>,
        selective: bool,
    ) -> Result<(Tensor<B, 3>, MambaState<B>), &'static str> {
        let [batch, seq_len, input_dim] = input.dims();
        if input_dim != self.input_dim {
            return Err("input dimension mismatch");
        }
        if batch == 0 || seq_len == 0 {
            return Err("input batch/seq must be > 0");
        }

        // expanded_dim = input_dim * expand_ratio; guard against usize overflow.
        let expanded_dim = match self.input_dim.checked_mul(self.expand_ratio) {
            Some(value) => value,
            None => return Err("expanded dimension overflow"),
        };
        params.validate(expanded_dim, self.state_dim)?;

        // Pull everything down to host f32 buffers for the sequential scan.
        let device = input.device();
        let input_data = input
            .into_data()
            .into_vec::<f32>()
            .map_err(|_| "input data conversion failed")?;
        let dt_proj_data = params
            .dt_proj
            .clone()
            .into_data()
            .into_vec::<f32>()
            .map_err(|_| "dt_proj conversion failed")?;
        let a_data = params
            .a
            .clone()
            .into_data()
            .into_vec::<f32>()
            .map_err(|_| "a conversion failed")?;
        let b_data = params
            .b
            .clone()
            .into_data()
            .into_vec::<f32>()
            .map_err(|_| "b conversion failed")?;
        let c_data = params
            .c
            .clone()
            .into_data()
            .into_vec::<f32>()
            .map_err(|_| "c conversion failed")?;
        let d_data = params
            .d
            .clone()
            .into_data()
            .into_vec::<f32>()
            .map_err(|_| "d conversion failed")?;

        // Resume from the caller's state, or start the recurrence from zeros.
        let mut state_data = match state {
            Some(state) => state
                .state
                .into_data()
                .into_vec::<f32>()
                .map_err(|_| "state conversion failed")?,
            None => vec![0.0f32; batch * self.state_dim],
        };
        if state_data.len() != batch * self.state_dim {
            return Err("state dimension mismatch");
        }

        // A = -exp(a): always negative, so exp(A * dt) below is a decay in (0, 1].
        let a_values: Vec<f32> = a_data.iter().map(|value| -value.exp()).collect();
        let mut output_data = vec![0.0f32; batch * seq_len * input_dim];

        for batch_idx in 0..batch {
            for time_idx in 0..seq_len {
                // Expand the input: each channel is replicated expand_ratio times.
                let input_offset = (batch_idx * seq_len + time_idx) * input_dim;
                let mut expanded_input = vec![0.0f32; expanded_dim];
                for i in 0..input_dim {
                    let value = input_data[input_offset + i];
                    for r in 0..self.expand_ratio {
                        expanded_input[i * self.expand_ratio + r] = value;
                    }
                }

                // State update: project the expanded input onto each state
                // channel, derive the (optionally gated) time step, then apply
                // the discretized recurrence h <- h * exp(A*dt) + (B x) * dt.
                for s in 0..self.state_dim {
                    let mut dt_pre = 0.0f32;
                    let mut input_proj = 0.0f32;
                    for i in 0..expanded_dim {
                        let x = expanded_input[i];
                        dt_pre += x * dt_proj_data[i * self.state_dim + s];
                        input_proj += x * b_data[i * self.state_dim + s];
                    }
                    let mut dt = softplus(dt_pre);
                    if selective {
                        // Selective gating: modulate the time step by sigmoid(dt_pre).
                        dt *= sigmoid(dt_pre);
                    }
                    let decay = (a_values[s] * dt).exp();
                    let state_idx = batch_idx * self.state_dim + s;
                    let next = state_data[state_idx] * decay + input_proj * dt;
                    state_data[state_idx] = next;
                }

                // Readout: y = C h + D x per expanded channel, then average the
                // expand_ratio copies back down to input_dim outputs.
                for j in 0..input_dim {
                    let mut sum = 0.0f32;
                    for r in 0..self.expand_ratio {
                        let idx = j * self.expand_ratio + r;
                        let mut y = 0.0f32;
                        for s in 0..self.state_dim {
                            y += state_data[batch_idx * self.state_dim + s]
                                * c_data[s * expanded_dim + idx];
                        }
                        y += expanded_input[idx] * d_data[idx];
                        sum += y;
                    }
                    output_data[(batch_idx * seq_len + time_idx) * input_dim + j] =
                        sum / self.expand_ratio as f32;
                }
            }
        }

        // Rebuild output and final state on the original device.
        let output =
            Tensor::from_data(TensorData::new(output_data, [batch, seq_len, input_dim]), &device);
        let state =
            Tensor::from_data(TensorData::new(state_data, [batch, self.state_dim]), &device);
        Ok((output, MambaState { state }))
    }

    /// Forward pass using configuration for selective gating.
    ///
    /// Validates `config` and requires it to match this block's dimensions
    /// before delegating to [`Self::forward`] with `config.selective`.
    pub fn forward_with_config(
        &self,
        input: Tensor<B, 3>,
        params: &MambaParameters<B>,
        state: Option<MambaState<B>>,
        config: &MambaConfig,
    ) -> Result<(Tensor<B, 3>, MambaState<B>), &'static str> {
        config.validate()?;
        if config.state_dim != self.state_dim
            || config.input_dim != self.input_dim
            || config.expand_ratio != self.expand_ratio
        {
            return Err("config mismatch for Mamba block");
        }
        self.forward(input, params, state, config.selective)
    }
}
270
271impl<B: Backend> MambaParameters<B> {
272    /// Validate parameter tensor shapes.
273    pub fn validate(&self, expanded_dim: usize, state_dim: usize) -> Result<(), &'static str> {
274        if self.dt_proj.dims() != [expanded_dim, state_dim] {
275            return Err("dt_proj shape mismatch");
276        }
277        if self.a.dims() != [state_dim] {
278            return Err("a shape mismatch");
279        }
280        if self.b.dims() != [expanded_dim, state_dim] {
281            return Err("b shape mismatch");
282        }
283        if self.c.dims() != [state_dim, expanded_dim] {
284            return Err("c shape mismatch");
285        }
286        if self.d.dims() != [expanded_dim] {
287            return Err("d shape mismatch");
288        }
289        Ok(())
290    }
291}
292
293impl<B: Backend> HybridLayer<B> {
294    /// Create a new hybrid layer.
295    pub fn new(strategy: HybridStrategy, layer_index: usize) -> Self {
296        Self {
297            strategy,
298            layer_index,
299            _marker: PhantomData,
300        }
301    }
302
303    /// Hybrid strategy configured for this layer.
304    pub fn strategy(&self) -> HybridStrategy {
305        self.strategy
306    }
307
308    /// Layer index for alternating strategy.
309    pub fn layer_index(&self) -> usize {
310        self.layer_index
311    }
312
313    /// Combine attention and Mamba outputs.
314    ///
315    /// # Shapes
316    /// * `attention`: [batch, seq_len, dim]
317    /// * `mamba`: [batch, seq_len, dim]
318    pub fn combine(
319        &self,
320        attention: Tensor<B, 3>,
321        mamba: Tensor<B, 3>,
322    ) -> Result<Tensor<B, 3>, &'static str> {
323        let attn_dims = attention.dims();
324        let mamba_dims = mamba.dims();
325        if attn_dims != mamba_dims {
326            return Err("attention/mamba dimension mismatch");
327        }
328
329        match self.strategy {
330            HybridStrategy::Alternating => {
331                if self.layer_index % 2 == 0 {
332                    Ok(attention)
333                } else {
334                    Ok(mamba)
335                }
336            }
337            HybridStrategy::Parallel { mamba_weight } => {
338                let weight = clamp_weight(mamba_weight);
339                blend_fixed(attention, mamba, weight)
340            }
341            HybridStrategy::Adaptive { min_weight, max_weight } => {
342                blend_adaptive(attention, mamba, min_weight, max_weight)
343            }
344        }
345    }
346}
347
/// Logistic function: 1 / (1 + e^(-x)), mapping any real input into (0, 1).
fn sigmoid(x: f32) -> f32 {
    let denom = 1.0 + (-x).exp();
    denom.recip()
}
351
/// Softplus activation: ln(1 + e^x), a smooth approximation of ReLU.
///
/// For large `x` the asymptotic identity softplus(x) ≈ x avoids overflow in
/// `exp`. Otherwise `ln_1p` is used instead of `(1.0 + x.exp()).ln()`: when
/// `e^x` is tiny (x ≲ -16), `1.0 + e^x` rounds to exactly 1.0 in f32 and the
/// naive form underflows to 0.0, while `ln_1p` keeps full precision.
fn softplus(x: f32) -> f32 {
    if x > 20.0 {
        x
    } else {
        x.exp().ln_1p()
    }
}
359
/// Clamp a blend weight into the valid range [0, 1].
///
/// Uses `f32::clamp`, which matches the previous manual branch chain,
/// including propagating NaN unchanged.
fn clamp_weight(weight: f32) -> f32 {
    weight.clamp(0.0, 1.0)
}
369
370fn blend_fixed<B: Backend>(
371    attention: Tensor<B, 3>,
372    mamba: Tensor<B, 3>,
373    weight: f32,
374) -> Result<Tensor<B, 3>, &'static str> {
375    let device = attention.device();
376    let dims = attention.dims();
377    let attn_data = attention
378        .into_data()
379        .into_vec::<f32>()
380        .map_err(|_| "attention conversion failed")?;
381    let mamba_data = mamba
382        .into_data()
383        .into_vec::<f32>()
384        .map_err(|_| "mamba conversion failed")?;
385    let mut output = vec![0.0f32; attn_data.len()];
386    let inv = 1.0 - weight;
387    for (idx, value) in output.iter_mut().enumerate() {
388        *value = attn_data[idx] * inv + mamba_data[idx] * weight;
389    }
390    Ok(Tensor::from_data(TensorData::new(output, dims), &device))
391}
392
393fn blend_adaptive<B: Backend>(
394    attention: Tensor<B, 3>,
395    mamba: Tensor<B, 3>,
396    min_weight: f32,
397    max_weight: f32,
398) -> Result<Tensor<B, 3>, &'static str> {
399    let device = attention.device();
400    let [batch, seq_len, dim] = attention.dims();
401    let attn_data = attention
402        .into_data()
403        .into_vec::<f32>()
404        .map_err(|_| "attention conversion failed")?;
405    let mamba_data = mamba
406        .into_data()
407        .into_vec::<f32>()
408        .map_err(|_| "mamba conversion failed")?;
409
410    let mut output = vec![0.0f32; attn_data.len()];
411    let per_batch = seq_len * dim;
412    for b in 0..batch {
413        let base = b * per_batch;
414        let mut attn_energy = 0.0f32;
415        let mut mamba_energy = 0.0f32;
416        for i in 0..per_batch {
417            attn_energy += attn_data[base + i].abs();
418            mamba_energy += mamba_data[base + i].abs();
419        }
420        let denom = attn_energy + mamba_energy + 1e-6;
421        let mut weight = mamba_energy / denom;
422        let min_w = min_weight.min(max_weight);
423        let max_w = max_weight.max(min_weight);
424        if weight < min_w {
425            weight = min_w;
426        } else if weight > max_w {
427            weight = max_w;
428        }
429        let inv = 1.0 - weight;
430        for i in 0..per_batch {
431            output[base + i] = attn_data[base + i] * inv + mamba_data[base + i] * weight;
432        }
433    }
434
435    Ok(Tensor::from_data(
436        TensorData::new(output, [batch, seq_len, dim]),
437        &device,
438    ))
439}
440
#[cfg(all(test, feature = "cpu"))]
mod tests {
    use super::*;
    use burn_ndarray::NdArray;

    /// The forward pass preserves the input shape and yields a
    /// [batch, state_dim] recurrent state.
    #[test]
    fn test_mamba_forward_shapes() {
        let config = MambaConfig {
            state_dim: 2,
            input_dim: 3,
            expand_ratio: 2,
            selective: true,
        };
        let device = <NdArray<f32> as Backend>::Device::default();
        let block = MambaBlock::<NdArray<f32>>::from_config(&config);

        // One batch, two time steps, three channels.
        let sequence = Tensor::from_data(
            TensorData::new(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [1, 2, 3]),
            &device,
        );
        // expanded_dim = input_dim * expand_ratio = 6.
        let params = MambaParameters {
            dt_proj: Tensor::from_data(TensorData::new(vec![0.05; 12], [6, 2]), &device),
            a: Tensor::from_data(TensorData::new(vec![0.1, 0.2], [2]), &device),
            b: Tensor::from_data(TensorData::new(vec![0.02; 12], [6, 2]), &device),
            c: Tensor::from_data(TensorData::new(vec![0.03; 12], [2, 6]), &device),
            d: Tensor::from_data(TensorData::new(vec![0.1; 6], [6]), &device),
        };

        let (output, state) = block
            .forward_with_config(sequence, &params, None, &config)
            .expect("forward");
        assert_eq!(output.dims(), [1, 2, 3]);
        assert_eq!(state.state.dims(), [1, 2]);
    }

    /// A fixed 0.25 Mamba weight blends 75% attention with 25% Mamba.
    #[test]
    fn test_hybrid_parallel_blend() {
        let device = <NdArray<f32> as Backend>::Device::default();
        let layer = HybridLayer::<NdArray<f32>>::new(
            HybridStrategy::Parallel { mamba_weight: 0.25 },
            0,
        );
        let attn_out =
            Tensor::from_data(TensorData::new(vec![1.0, 3.0], [1, 1, 2]), &device);
        let mamba_out = Tensor::from_data(TensorData::new(vec![5.0, 1.0], [1, 1, 2]), &device);

        let blended = layer.combine(attn_out, mamba_out).expect("combine");
        let values = blended
            .into_data()
            .into_vec::<f32>()
            .expect("output data");
        // 1.0 * 0.75 + 5.0 * 0.25 = 2.0; 3.0 * 0.75 + 1.0 * 0.25 = 2.5.
        assert!((values[0] - 2.0).abs() < 1e-4);
        assert!((values[1] - 2.5).abs() < 1e-4);
    }
}