use std::marker::PhantomData;

use burn::tensor::backend::Backend;
use burn::tensor::{Tensor, TensorData};
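
/// Configuration for a simplified Mamba-style selective state-space block.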
#[derive(Debug, Clone, Copy)]
pub struct MambaConfig {
    /// Size of the recurrent hidden state per sequence element.
    pub state_dim: usize,
    /// Dimensionality of the block input (and output).
    pub input_dim: usize,
    /// Factor by which the input is expanded before the state update.
    pub expand_ratio: usize,
    /// Whether to apply the input-dependent (selective) gate to the timestep.
    pub selective: bool,
}
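
/// A Mamba-style block that runs a sequential selective scan over the input.
///
/// The block itself holds only dimensions; weights are supplied per call as
/// [`MambaParameters`].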
pub struct MambaBlock<B: Backend> {
    state_dim: usize,
    input_dim: usize,
    expand_ratio: usize,
    _marker: PhantomData<B>,
}
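
/// Learnable parameters consumed by [`MambaBlock::forward`].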
#[derive(Debug, Clone)]
pub struct MambaParameters<B: Backend> {
    /// Projects the expanded input to per-state timestep logits, shape `[expanded_dim, state_dim]`.
    pub dt_proj: Tensor<B, 2>,
    /// State decay parameter, shape `[state_dim]`; used as `A = -exp(a)`.
    pub a: Tensor<B, 1>,
    /// Input projection, shape `[expanded_dim, state_dim]`.
    pub b: Tensor<B, 2>,
    /// Output projection, shape `[state_dim, expanded_dim]`.
    pub c: Tensor<B, 2>,
    /// Skip connection applied to the expanded input, shape `[expanded_dim]`.
    pub d: Tensor<B, 1>,
}
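
/// Recurrent state carried between calls to [`MambaBlock::forward`], shape `[batch, state_dim]`.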
#[derive(Debug, Clone)]
pub struct MambaState<B: Backend> {
    pub state: Tensor<B, 2>,
}
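
/// How a [`HybridLayer`] combines attention and Mamba outputs.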
#[derive(Debug, Clone, Copy)]
pub enum HybridStrategy {
    /// Use attention on even layers and Mamba on odd layers.
    Alternating,
    /// Fixed convex blend: `(1 - mamba_weight) * attention + mamba_weight * mamba`.
    Parallel { mamba_weight: f32 },
    /// Per-batch blend weight derived from relative output energy, clamped to `[min_weight, max_weight]`.
    Adaptive { min_weight: f32, max_weight: f32 },
}
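
/// Combines the outputs of an attention path and a Mamba path at one layer.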
pub struct HybridLayer<B: Backend> {
    strategy: HybridStrategy,
    layer_index: usize,
    _marker: PhantomData<B>,
}

impl MambaConfig {
    /// Ensures every dimension is non-zero.
    pub fn validate(&self) -> Result<(), &'static str> {
        if self.state_dim == 0 {
            return Err("state_dim must be > 0");
        }
        if self.input_dim == 0 {
            return Err("input_dim must be > 0");
        }
        if self.expand_ratio == 0 {
            return Err("expand_ratio must be > 0");
        }
        Ok(())
    }
}

impl<B: Backend> MambaBlock<B> {
    pub fn new(state_dim: usize, input_dim: usize, expand_ratio: usize) -> Self {
        Self {
            state_dim,
            input_dim,
            expand_ratio,
            _marker: PhantomData,
        }
    }

    pub fn from_config(config: &MambaConfig) -> Self {
        Self::new(config.state_dim, config.input_dim, config.expand_ratio)
    }

    pub fn state_dim(&self) -> usize {
        self.state_dim
    }

    pub fn input_dim(&self) -> usize {
        self.input_dim
    }

    pub fn expand_ratio(&self) -> usize {
        self.expand_ratio
    }
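
    /// Runs a sequential selective scan over `input` of shape `[batch, seq_len, input_dim]`.
    ///
    /// Returns the output sequence with the same shape together with the final
    /// recurrent state. Pass `state` to continue from a previous call, or `None`
    /// to start from a zero state.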
    pub fn forward(
        &self,
        input: Tensor<B, 3>,
        params: &MambaParameters<B>,
        state: Option<MambaState<B>>,
        selective: bool,
    ) -> Result<(Tensor<B, 3>, MambaState<B>), &'static str> {
        let [batch, seq_len, input_dim] = input.dims();
        if input_dim != self.input_dim {
            return Err("input dimension mismatch");
        }
        if batch == 0 || seq_len == 0 {
            return Err("input batch/seq must be > 0");
        }

        let expanded_dim = match self.input_dim.checked_mul(self.expand_ratio) {
            Some(value) => value,
            None => return Err("expanded dimension overflow"),
        };
        params.validate(expanded_dim, self.state_dim)?;

        let device = input.device();
        let input_data = input
            .into_data()
            .into_vec::<f32>()
            .map_err(|_| "input data conversion failed")?;
        let dt_proj_data = params
            .dt_proj
            .clone()
            .into_data()
            .into_vec::<f32>()
            .map_err(|_| "dt_proj conversion failed")?;
        let a_data = params
            .a
            .clone()
            .into_data()
            .into_vec::<f32>()
            .map_err(|_| "a conversion failed")?;
        let b_data = params
            .b
            .clone()
            .into_data()
            .into_vec::<f32>()
            .map_err(|_| "b conversion failed")?;
        let c_data = params
            .c
            .clone()
            .into_data()
            .into_vec::<f32>()
            .map_err(|_| "c conversion failed")?;
        let d_data = params
            .d
            .clone()
            .into_data()
            .into_vec::<f32>()
            .map_err(|_| "d conversion failed")?;

        let mut state_data = match state {
            Some(state) => state
                .state
                .into_data()
                .into_vec::<f32>()
                .map_err(|_| "state conversion failed")?,
            None => vec![0.0f32; batch * self.state_dim],
        };
        if state_data.len() != batch * self.state_dim {
            return Err("state dimension mismatch");
        }

        // Parameterize the decay as A = -exp(a) so it stays negative.
        let a_values: Vec<f32> = a_data.iter().map(|value| -value.exp()).collect();
        let mut output_data = vec![0.0f32; batch * seq_len * input_dim];

        for batch_idx in 0..batch {
            for time_idx in 0..seq_len {
                // Expand each input channel into `expand_ratio` copies.
                let input_offset = (batch_idx * seq_len + time_idx) * input_dim;
                let mut expanded_input = vec![0.0f32; expanded_dim];
                for i in 0..input_dim {
                    let value = input_data[input_offset + i];
                    for r in 0..self.expand_ratio {
                        expanded_input[i * self.expand_ratio + r] = value;
                    }
                }

                // Discretize the timestep per state and update the recurrence.
                for s in 0..self.state_dim {
                    let mut dt_pre = 0.0f32;
                    let mut input_proj = 0.0f32;
                    for i in 0..expanded_dim {
                        let x = expanded_input[i];
                        dt_pre += x * dt_proj_data[i * self.state_dim + s];
                        input_proj += x * b_data[i * self.state_dim + s];
                    }
                    let mut dt = softplus(dt_pre);
                    if selective {
                        // Input-dependent gate on the timestep (selective scan).
                        dt *= sigmoid(dt_pre);
                    }
                    let decay = (a_values[s] * dt).exp();
                    let state_idx = batch_idx * self.state_dim + s;
                    let next = state_data[state_idx] * decay + input_proj * dt;
                    state_data[state_idx] = next;
                }

                // Project the state back to the input dimension, averaging the
                // expanded copies, with a skip connection through `d`.
                for j in 0..input_dim {
                    let mut sum = 0.0f32;
                    for r in 0..self.expand_ratio {
                        let idx = j * self.expand_ratio + r;
                        let mut y = 0.0f32;
                        for s in 0..self.state_dim {
                            y += state_data[batch_idx * self.state_dim + s]
                                * c_data[s * expanded_dim + idx];
                        }
                        y += expanded_input[idx] * d_data[idx];
                        sum += y;
                    }
                    output_data[(batch_idx * seq_len + time_idx) * input_dim + j] =
                        sum / self.expand_ratio as f32;
                }
            }
        }

        let output =
            Tensor::from_data(TensorData::new(output_data, [batch, seq_len, input_dim]), &device);
        let state =
            Tensor::from_data(TensorData::new(state_data, [batch, self.state_dim]), &device);
        Ok((output, MambaState { state }))
    }
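
    /// Validates `config` against this block and then runs [`Self::forward`].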
    pub fn forward_with_config(
        &self,
        input: Tensor<B, 3>,
        params: &MambaParameters<B>,
        state: Option<MambaState<B>>,
        config: &MambaConfig,
    ) -> Result<(Tensor<B, 3>, MambaState<B>), &'static str> {
        config.validate()?;
        if config.state_dim != self.state_dim
            || config.input_dim != self.input_dim
            || config.expand_ratio != self.expand_ratio
        {
            return Err("config mismatch for Mamba block");
        }
        self.forward(input, params, state, config.selective)
    }
}

impl<B: Backend> MambaParameters<B> {
    /// Checks that every parameter tensor has the expected shape.
    pub fn validate(&self, expanded_dim: usize, state_dim: usize) -> Result<(), &'static str> {
        if self.dt_proj.dims() != [expanded_dim, state_dim] {
            return Err("dt_proj shape mismatch");
        }
        if self.a.dims() != [state_dim] {
            return Err("a shape mismatch");
        }
        if self.b.dims() != [expanded_dim, state_dim] {
            return Err("b shape mismatch");
        }
        if self.c.dims() != [state_dim, expanded_dim] {
            return Err("c shape mismatch");
        }
        if self.d.dims() != [expanded_dim] {
            return Err("d shape mismatch");
        }
        Ok(())
    }
}

impl<B: Backend> HybridLayer<B> {
    pub fn new(strategy: HybridStrategy, layer_index: usize) -> Self {
        Self {
            strategy,
            layer_index,
            _marker: PhantomData,
        }
    }

    pub fn strategy(&self) -> HybridStrategy {
        self.strategy
    }

    pub fn layer_index(&self) -> usize {
        self.layer_index
    }
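
    /// Merges the attention and Mamba outputs according to the configured strategy.
    ///
    /// Both tensors must have the same `[batch, seq_len, dim]` shape.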
    pub fn combine(
        &self,
        attention: Tensor<B, 3>,
        mamba: Tensor<B, 3>,
    ) -> Result<Tensor<B, 3>, &'static str> {
        let attn_dims = attention.dims();
        let mamba_dims = mamba.dims();
        if attn_dims != mamba_dims {
            return Err("attention/mamba dimension mismatch");
        }

        match self.strategy {
            HybridStrategy::Alternating => {
                if self.layer_index % 2 == 0 {
                    Ok(attention)
                } else {
                    Ok(mamba)
                }
            }
            HybridStrategy::Parallel { mamba_weight } => {
                let weight = clamp_weight(mamba_weight);
                blend_fixed(attention, mamba, weight)
            }
            HybridStrategy::Adaptive { min_weight, max_weight } => {
                blend_adaptive(attention, mamba, min_weight, max_weight)
            }
        }
    }
}

fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

fn softplus(x: f32) -> f32 {
    // For large inputs softplus(x) is effectively x; skip exp() to avoid overflow.
    if x > 20.0 {
        x
    } else {
        (1.0 + x.exp()).ln()
    }
}

fn clamp_weight(weight: f32) -> f32 {
    if weight < 0.0 {
        0.0
    } else if weight > 1.0 {
        1.0
    } else {
        weight
    }
}
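
/// Blends the two outputs element-wise with a fixed Mamba weight in `[0, 1]`.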
fn blend_fixed<B: Backend>(
    attention: Tensor<B, 3>,
    mamba: Tensor<B, 3>,
    weight: f32,
) -> Result<Tensor<B, 3>, &'static str> {
    let device = attention.device();
    let dims = attention.dims();
    let attn_data = attention
        .into_data()
        .into_vec::<f32>()
        .map_err(|_| "attention conversion failed")?;
    let mamba_data = mamba
        .into_data()
        .into_vec::<f32>()
        .map_err(|_| "mamba conversion failed")?;
    let mut output = vec![0.0f32; attn_data.len()];
    let inv = 1.0 - weight;
    for (idx, value) in output.iter_mut().enumerate() {
        *value = attn_data[idx] * inv + mamba_data[idx] * weight;
    }
    Ok(Tensor::from_data(TensorData::new(output, dims), &device))
}
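
/// Blends per batch element, weighting Mamba by its share of the total
/// absolute activation energy, clamped to `[min_weight, max_weight]`.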
fn blend_adaptive<B: Backend>(
    attention: Tensor<B, 3>,
    mamba: Tensor<B, 3>,
    min_weight: f32,
    max_weight: f32,
) -> Result<Tensor<B, 3>, &'static str> {
    let device = attention.device();
    let [batch, seq_len, dim] = attention.dims();
    let attn_data = attention
        .into_data()
        .into_vec::<f32>()
        .map_err(|_| "attention conversion failed")?;
    let mamba_data = mamba
        .into_data()
        .into_vec::<f32>()
        .map_err(|_| "mamba conversion failed")?;

    let mut output = vec![0.0f32; attn_data.len()];
    let per_batch = seq_len * dim;
    for b in 0..batch {
        let base = b * per_batch;
        // Total absolute activation per path for this batch element.
        let mut attn_energy = 0.0f32;
        let mut mamba_energy = 0.0f32;
        for i in 0..per_batch {
            attn_energy += attn_data[base + i].abs();
            mamba_energy += mamba_data[base + i].abs();
        }
        let denom = attn_energy + mamba_energy + 1e-6;
        let mut weight = mamba_energy / denom;
        // Clamp to [min_weight, max_weight], tolerating swapped bounds.
        let min_w = min_weight.min(max_weight);
        let max_w = max_weight.max(min_weight);
        if weight < min_w {
            weight = min_w;
        } else if weight > max_w {
            weight = max_w;
        }
        let inv = 1.0 - weight;
        for i in 0..per_batch {
            output[base + i] = attn_data[base + i] * inv + mamba_data[base + i] * weight;
        }
    }

    Ok(Tensor::from_data(
        TensorData::new(output, [batch, seq_len, dim]),
        &device,
    ))
}

#[cfg(all(test, feature = "cpu"))]
mod tests {
    use super::*;
    use burn_ndarray::NdArray;

    #[test]
    fn test_mamba_forward_shapes() {
        let config = MambaConfig {
            state_dim: 2,
            input_dim: 3,
            expand_ratio: 2,
            selective: true,
        };
        let block = MambaBlock::<NdArray<f32>>::from_config(&config);
        let device = <NdArray<f32> as Backend>::Device::default();
        let input = Tensor::from_data(
            TensorData::new(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [1, 2, 3]),
            &device,
        );
        let params = MambaParameters {
            dt_proj: Tensor::from_data(TensorData::new(vec![0.05; 12], [6, 2]), &device),
            a: Tensor::from_data(TensorData::new(vec![0.1, 0.2], [2]), &device),
            b: Tensor::from_data(TensorData::new(vec![0.02; 12], [6, 2]), &device),
            c: Tensor::from_data(TensorData::new(vec![0.03; 12], [2, 6]), &device),
            d: Tensor::from_data(TensorData::new(vec![0.1; 6], [6]), &device),
        };

        let (output, state) = block
            .forward_with_config(input, &params, None, &config)
            .expect("forward");
        assert_eq!(output.dims(), [1, 2, 3]);
        assert_eq!(state.state.dims(), [1, 2]);
    }

    #[test]
    fn test_hybrid_parallel_blend() {
        let layer = HybridLayer::<NdArray<f32>>::new(
            HybridStrategy::Parallel { mamba_weight: 0.25 },
            0,
        );
        let device = <NdArray<f32> as Backend>::Device::default();
        let attention =
            Tensor::from_data(TensorData::new(vec![1.0, 3.0], [1, 1, 2]), &device);
        let mamba = Tensor::from_data(TensorData::new(vec![5.0, 1.0], [1, 1, 2]), &device);

        let output = layer.combine(attention, mamba).expect("combine");
        let data = output
            .into_data()
            .into_vec::<f32>()
            .expect("output data");
        assert!((data[0] - 2.0).abs() < 1e-4);
        assert!((data[1] - 2.5).abs() < 1e-4);
    }
}