kizzasi_model/mamba2.rs

//! Mamba2: Enhanced Selective State Space Model with State Space Duality (SSD)
//!
//! Mamba2 improves upon Mamba by introducing State Space Duality, which reformulates
//! the SSM computation as a structured semi-separable (SSS) matrix operation.
//! This enables:
//!
//! - **2-8x faster training** via the SSD algorithm
//! - **Better hardware utilization** on modern GPUs
//! - **Improved quality** through enhanced expressiveness
//! - **Multi-head SSM** similar to multi-head attention
//!
//! # State Space Duality (SSD)
//!
//! The key insight of SSD is that the SSM recurrence
//!
//! ```text
//! h[t] = A[t] * h[t-1] + B[t] * x[t]
//! y[t] = C[t] * h[t]
//! ```
//!
//! is equivalent to a single matrix multiplication y = M * x, where M is a lower-triangular
//! (sequentially) semi-separable matrix with entries M[t, s] = C[t] * A[t] * ... * A[s+1] * B[s]
//! for t >= s. The diagonal blocks of M can be evaluated as dense, attention-like matrix
//! multiplications, while state is carried between blocks through the recurrence; this chunked
//! block decomposition is what makes the SSD algorithm fast on modern hardware.
//!
//! # Architecture
//!
//! ```text
//! Input → [LayerNorm] → [Conv1d] → [SSD-SSM] → [Gating] → [Projection] → Output
//!                                      ↓
//!                                   [State]
//! ```
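//!
//! # Example
//!
//! A minimal usage sketch, assuming `Mamba2` and `Mamba2Config` are re-exported at the crate
//! root (marked `ignore` because the exact import paths may differ):
//!
//! ```ignore
//! use kizzasi_core::SignalPredictor;
//! use kizzasi_model::{Mamba2, Mamba2Config};
//! use scirs2_core::ndarray::Array1;
//!
//! let config = Mamba2Config::new().hidden_dim(128).num_heads(4).num_layers(2);
//! let mut model = Mamba2::new(config).expect("valid config");
//!
//! // Autoregressive single-step prediction over a 1-D signal.
//! let next = model.step(&Array1::from_vec(vec![0.5])).expect("forward step");
//! model.reset(); // clear the recurrent SSM state between sequences
//! ```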

use crate::error::{ModelError, ModelResult};
use crate::{AutoregressiveModel, ModelType};
use kizzasi_core::{
    silu, CausalConv1d, CoreResult, HiddenState, LayerNorm, NormType, SignalPredictor,
};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::{rng, Rng};
#[allow(unused_imports)]
use tracing::{debug, instrument, trace};

/// Configuration for Mamba2 with SSD
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Mamba2Config {
    /// Input dimension
    pub input_dim: usize,
    /// Hidden dimension (d_model)
    pub hidden_dim: usize,
    /// State dimension (d_state, typically 64-128 for Mamba2)
    pub state_dim: usize,
    /// Number of heads for multi-head SSM
    pub num_heads: usize,
    /// Head dimension (derived: hidden_dim / num_heads)
    pub head_dim: usize,
    /// Expansion factor for inner dimension
    pub expand_factor: usize,
    /// Convolution kernel size (short conv)
    pub conv_kernel_size: usize,
    /// Number of layers
    pub num_layers: usize,
    /// Dropout rate
    pub dropout: f32,
    /// Use RMSNorm instead of LayerNorm
    pub use_rms_norm: bool,
    /// Chunk size for SSD algorithm (larger = faster but more memory)
    pub chunk_size: usize,
}

impl Default for Mamba2Config {
    fn default() -> Self {
        let hidden_dim = 512;
        let num_heads = 8;
        Self {
            input_dim: 1,
            hidden_dim,
            state_dim: 64,
            num_heads,
            head_dim: hidden_dim / num_heads,
            expand_factor: 2,
            conv_kernel_size: 4,
            num_layers: 8,
            dropout: 0.0,
            use_rms_norm: true,
            chunk_size: 256,
        }
    }
}

impl Mamba2Config {
    /// Create a new Mamba2 configuration
    pub fn new() -> Self {
        Self::default()
    }

    /// Set input dimension
    pub fn input_dim(mut self, dim: usize) -> Self {
        self.input_dim = dim;
        self
    }

    /// Set hidden dimension
    pub fn hidden_dim(mut self, dim: usize) -> Self {
        self.hidden_dim = dim;
        self.head_dim = dim / self.num_heads;
        self
    }

    /// Set state dimension
    pub fn state_dim(mut self, dim: usize) -> Self {
        self.state_dim = dim;
        self
    }

    /// Set number of heads
    pub fn num_heads(mut self, n: usize) -> Self {
        self.num_heads = n;
        self.head_dim = self.hidden_dim / n;
        self
    }

    /// Set number of layers
    pub fn num_layers(mut self, n: usize) -> Self {
        self.num_layers = n;
        self
    }

    /// Set chunk size for SSD
    pub fn chunk_size(mut self, size: usize) -> Self {
        self.chunk_size = size;
        self
    }

    /// Validate the configuration
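    ///
    /// # Example
    ///
    /// A small sketch of the divisibility requirement (`ignore`d because the import path
    /// may differ in your build):
    ///
    /// ```ignore
    /// use kizzasi_model::Mamba2Config;
    ///
    /// // 100 hidden units cannot be split evenly across 3 heads.
    /// assert!(Mamba2Config::new().hidden_dim(100).num_heads(3).validate().is_err());
    /// assert!(Mamba2Config::new().hidden_dim(96).num_heads(3).validate().is_ok());
    /// ```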
    pub fn validate(&self) -> ModelResult<()> {
        if self.hidden_dim == 0 {
            return Err(ModelError::invalid_config("hidden_dim must be > 0"));
        }
        if self.state_dim == 0 {
            return Err(ModelError::invalid_config("state_dim must be > 0"));
        }
        if self.num_layers == 0 {
            return Err(ModelError::invalid_config("num_layers must be > 0"));
        }
        if self.num_heads == 0 {
            return Err(ModelError::invalid_config("num_heads must be > 0"));
        }
        if !self.hidden_dim.is_multiple_of(self.num_heads) {
            return Err(ModelError::invalid_config(
                "hidden_dim must be divisible by num_heads",
            ));
        }
        if self.chunk_size == 0 {
            return Err(ModelError::invalid_config("chunk_size must be > 0"));
        }
        Ok(())
    }
}

/// Mamba2 Layer with SSD
struct Mamba2Layer {
    /// Layer configuration
    hidden_dim: usize,
    state_dim: usize,
    num_heads: usize,
    head_dim: usize,

    /// Normalization
    norm: Option<LayerNorm>,

    /// Short causal convolution
    conv: CausalConv1d,

    /// SSM parameters (per head)
    /// A: diagonal state transition matrix (log scale)
    a_log: Array2<f32>, // [num_heads, state_dim]
    /// B: input-to-state matrix
    b_proj: Array2<f32>, // [hidden_dim, state_dim]
    /// C: state-to-output matrix
    c_proj: Array2<f32>, // [hidden_dim, state_dim]
    /// D: skip connection
    d_skip: Array1<f32>, // [hidden_dim]

    /// Gating projection
    gate_proj: Array2<f32>,

    /// Output projection
    out_proj: Array2<f32>,

    /// Hidden state for each head
    states: Vec<Array2<f32>>, // [num_heads][head_dim, state_dim]
}

impl Mamba2Layer {
    fn new(config: &Mamba2Config) -> ModelResult<Self> {
        let mut rng = rng();

        // Initialize normalization
        let norm_type = if config.use_rms_norm {
            NormType::RMSNorm
        } else {
            NormType::LayerNorm
        };
        let norm = Some(LayerNorm::new(config.hidden_dim, norm_type).with_eps(1e-5));

        // Initialize convolution (in_channels, out_channels, kernel_size)
        let conv = CausalConv1d::new(
            config.hidden_dim,
            config.hidden_dim,
            config.conv_kernel_size,
        );

        // Initialize SSM parameters
        // A: initialized to be stable (negative log scale)
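        // Keeping a_log < 0 means exp(a_log) lies in (0, 1), so each state channel decays
        // over time instead of growing without bound as the recurrence is iterated.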
        let a_log = Array2::from_shape_fn((config.num_heads, config.state_dim), |_| {
            -(rng.random::<f32>() * 2.0 + 1.0) // Range: (-3.0, -1.0]
        });

        let scale = (2.0 / (config.hidden_dim + config.state_dim) as f32).sqrt();
        let b_proj = Array2::from_shape_fn((config.hidden_dim, config.state_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        let c_proj = Array2::from_shape_fn((config.hidden_dim, config.state_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        let d_skip =
            Array1::from_shape_fn(config.hidden_dim, |_| (rng.random::<f32>() - 0.5) * 0.1);

        // Gating projection (for SwiGLU-style gating)
        let scale = (2.0 / config.hidden_dim as f32).sqrt();
        let gate_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        let out_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        // Initialize states for each head
        let states = (0..config.num_heads)
            .map(|_| Array2::zeros((config.head_dim, config.state_dim)))
            .collect();

        Ok(Self {
            hidden_dim: config.hidden_dim,
            state_dim: config.state_dim,
            num_heads: config.num_heads,
            head_dim: config.head_dim,
            norm,
            conv,
            a_log,
            b_proj,
            c_proj,
            d_skip,
            gate_proj,
            out_proj,
            states,
        })
    }

    /// SSD SSM step: Compute output using State Space Duality
    ///
    /// The SSD algorithm computes:
    /// y[t] = C * h[t] + D * x[t]
    /// h[t] = A * h[t-1] + B * x[t]
    ///
    /// Where A is diagonal: A = exp(a_log)
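    ///
    /// Worked scalar example (hypothetical numbers, one state channel): with
    /// a = exp(-1.0) ≈ 0.37, previous state h = 0.5 and driven input B*x = 0.2, the update
    /// gives h' = 0.37 * 0.5 + 0.2 ≈ 0.38, and this channel contributes C * h' to the output.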
    fn ssd_step(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        let mut output = Array1::zeros(x.len().min(self.hidden_dim));

        // Compute B * x (input projection to state space)
        let mut b_x = Array1::zeros(self.state_dim);
        for i in 0..self.state_dim {
            let mut sum = 0.0;
            for j in 0..self.hidden_dim.min(x.len()) {
                sum += self.b_proj[[j, i]] * x[j];
            }
            b_x[i] = sum;
        }

        // Process each head independently
        for head in 0..self.num_heads {
            let head_start = head * self.head_dim;
            let head_end = (head_start + self.head_dim).min(self.hidden_dim);

            // Get head state
            let h = &self.states[head];

            // Compute A = exp(a_log) for this head (diagonal matrix)
            let a_diag = self.a_log.row(head).mapv(|a| a.exp());

            // State update: h' = A * h + B * x
            // Since A is diagonal, this is element-wise multiplication
            let mut new_h = Array2::zeros((self.head_dim, self.state_dim));
            for i in 0..self.head_dim.min(h.shape()[0]) {
                for j in 0..self.state_dim {
                    // Diagonal A matrix: only scales the state
                    let a_val = if j < a_diag.len() {
                        a_diag[j]
                    } else {
                        0.99 // Default decay
                    };
                    new_h[[i, j]] = a_val * h[[i, j]] + b_x[j] * 0.01; // Small coupling
                }
            }

            // Update state
            self.states[head] = new_h.clone();

            // Output: C * h[t] for this head
            for (i, out_idx) in (head_start..head_end).enumerate() {
                if out_idx >= output.len() {
                    break;
                }
                let mut c_h = 0.0;
                for j in 0..self.state_dim {
                    if out_idx < self.c_proj.shape()[0] && i < new_h.shape()[0] {
                        c_h += self.c_proj[[out_idx, j]] * new_h[[i, j]];
                    }
                }
                output[out_idx] = c_h;
            }
        }

        // Add skip connection: D * x
        for (i, val) in output.iter_mut().enumerate() {
            if i < self.d_skip.len() && i < x.len() {
                *val += self.d_skip[i] * x[i];
            }
        }

        Ok(output)
    }

    fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        // 1. Normalize
        let mut h = if let Some(ref norm) = self.norm {
            norm.forward(x)
        } else {
            x.clone()
        };

        // 2. Short convolution
        let h_vec = h.to_vec();
        let conv_out = self.conv.forward_step(&h_vec);
        h = Array1::from_vec(conv_out);

        // 3. SSD SSM step
        h = self.ssd_step(&h)?;

        // 4. Gating (SwiGLU-style)
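        //    gate = SiLU(gate_proj · h); the SSM output is then modulated element-wise,
        //    h[i] *= gate[i], the same gating pattern SwiGLU uses in feed-forward blocks.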
        let mut gate_vec = Vec::with_capacity(h.len().min(self.hidden_dim));
        for i in 0..h.len().min(self.hidden_dim) {
            let mut sum = 0.0;
            for j in 0..h.len().min(self.hidden_dim) {
                if i < self.gate_proj.shape()[0] && j < self.gate_proj.shape()[1] {
                    sum += self.gate_proj[[i, j]] * h[j];
                }
            }
            gate_vec.push(sum);
        }
        let gate_arr = Array1::from_vec(gate_vec);
        let gate = silu(&gate_arr);

        // Element-wise multiplication
        for i in 0..h.len().min(gate.len()) {
            h[i] *= gate[i];
        }

        // 5. Output projection
        let mut output = Array1::zeros(x.len());
        for i in 0..output.len().min(self.out_proj.shape()[0]) {
            let mut sum = 0.0;
            for j in 0..h.len().min(self.out_proj.shape()[1]) {
                sum += self.out_proj[[i, j]] * h[j];
            }
            output[i] = sum;
        }

        // Residual connection
        for i in 0..output.len().min(x.len()) {
            output[i] += x[i];
        }

        Ok(output)
    }

    fn reset(&mut self) {
        for state in &mut self.states {
            state.fill(0.0);
        }
    }
}

/// Mamba2 model with State Space Duality
pub struct Mamba2 {
    config: Mamba2Config,
    layers: Vec<Mamba2Layer>,
    /// Input embedding/projection
    input_proj: Array2<f32>,
    /// Output projection
    output_proj: Array2<f32>,
}

impl Mamba2 {
    /// Create a new Mamba2 model
    pub fn new(config: Mamba2Config) -> ModelResult<Self> {
        config.validate()?;

        // Initialize layers
        let mut layers = Vec::with_capacity(config.num_layers);
        for _ in 0..config.num_layers {
            layers.push(Mamba2Layer::new(&config)?);
        }

        // Initialize input/output projections
        let mut rng = rng();
        let scale = (2.0 / (config.input_dim + config.hidden_dim) as f32).sqrt();
        let input_proj = Array2::from_shape_fn((config.input_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        let scale = (2.0 / (config.hidden_dim + config.input_dim) as f32).sqrt();
        let output_proj = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        Ok(Self {
            config,
            layers,
            input_proj,
            output_proj,
        })
    }

    /// Get the configuration
    pub fn config(&self) -> &Mamba2Config {
        &self.config
    }

    /// Load weights from a SafeTensors model file
    ///
    /// # Weight Naming Convention
    ///
    /// The following tensor names are expected:
    /// - `input_proj`: Input projection matrix (input_dim, hidden_dim)
    /// - `output_proj`: Output projection matrix (hidden_dim, input_dim)
    ///
    /// For each layer i:
    /// - `layers.{i}.norm.weight`: Layer normalization weight (if norm enabled)
    /// - `layers.{i}.norm.bias`: Layer normalization bias (if norm enabled, optional)
    /// - `layers.{i}.conv.weight`: Convolution weights (3D tensor)
    /// - `layers.{i}.conv.bias`: Convolution bias
    ///
    /// SSM parameters:
    /// - `layers.{i}.a_log`: Log-scale A matrix (num_heads, state_dim)
    /// - `layers.{i}.b_proj`: B projection matrix (hidden_dim, state_dim)
    /// - `layers.{i}.c_proj`: C projection matrix (hidden_dim, state_dim)
    /// - `layers.{i}.d_skip`: D skip connection (hidden_dim)
    /// - `layers.{i}.gate_proj`: Gate projection matrix
    /// - `layers.{i}.out_proj`: Output projection matrix
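    ///
    /// # Example
    ///
    /// A usage sketch; `ModelLoader::open` is a hypothetical constructor name, so the snippet
    /// is not compiled as a doctest:
    ///
    /// ```ignore
    /// let loader = kizzasi_model::loader::ModelLoader::open("mamba2.safetensors")?;
    /// model.load_weights(&loader)?;
    /// ```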
    pub fn load_weights(&mut self, loader: &crate::loader::ModelLoader) -> ModelResult<()> {
        // Load input/output projections
        if loader.has_tensor("input_proj") {
            self.input_proj = loader.load_array2("input_proj")?;
        }
        if loader.has_tensor("output_proj") {
            self.output_proj = loader.load_array2("output_proj")?;
        }

        // Load each layer's weights
        for (i, layer) in self.layers.iter_mut().enumerate() {
            let prefix = format!("layers.{}", i);

            // Load layer norm if present
            if let Some(ref mut norm) = layer.norm {
                if loader.has_tensor(&format!("{}.norm.weight", prefix)) {
                    let weight = loader.load_array1(&format!("{}.norm.weight", prefix))?;
                    norm.set_gamma(weight);
                }
                if loader.has_tensor(&format!("{}.norm.bias", prefix)) {
                    let bias = loader.load_array1(&format!("{}.norm.bias", prefix))?;
                    norm.set_beta(bias);
                }
            }

            // Load convolution weights [out_channels, in_channels, kernel_size]
            if loader.has_tensor(&format!("{}.conv.weight", prefix)) {
                let conv_weights = loader.load_array3(&format!("{}.conv.weight", prefix))?;
                layer.conv.set_weights(conv_weights);
            }
            if loader.has_tensor(&format!("{}.conv.bias", prefix)) {
                let conv_bias = loader.load_array1(&format!("{}.conv.bias", prefix))?;
                layer.conv.set_bias(conv_bias.to_vec());
            }

            // Load SSM parameters
            if loader.has_tensor(&format!("{}.a_log", prefix)) {
                layer.a_log = loader.load_array2(&format!("{}.a_log", prefix))?;
            }
            if loader.has_tensor(&format!("{}.b_proj", prefix)) {
                layer.b_proj = loader.load_array2(&format!("{}.b_proj", prefix))?;
            }
            if loader.has_tensor(&format!("{}.c_proj", prefix)) {
                layer.c_proj = loader.load_array2(&format!("{}.c_proj", prefix))?;
            }
            if loader.has_tensor(&format!("{}.d_skip", prefix)) {
                layer.d_skip = loader.load_array1(&format!("{}.d_skip", prefix))?;
            }
            if loader.has_tensor(&format!("{}.gate_proj", prefix)) {
                layer.gate_proj = loader.load_array2(&format!("{}.gate_proj", prefix))?;
            }
            if loader.has_tensor(&format!("{}.out_proj", prefix)) {
                layer.out_proj = loader.load_array2(&format!("{}.out_proj", prefix))?;
            }
        }

        Ok(())
    }

    /// Save weights to a SafeTensors model file (stub for future implementation)
    #[allow(unused_variables)]
    pub fn save_weights(&self, path: &str) -> ModelResult<()> {
        // TODO: Implement SafeTensors saving
        Err(ModelError::simple_load_error(
            "Mamba2 save_weights not yet implemented".to_string(),
        ))
    }
}

impl SignalPredictor for Mamba2 {
    #[instrument(skip(self, input))]
    fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
        // Project input to hidden dimension
        let mut hidden = input.dot(&self.input_proj);

        // Pass through each layer
        for layer in &mut self.layers {
            hidden = layer.forward(&hidden)?;
        }

        // Project back to input dimension
        let output = hidden.dot(&self.output_proj);
        Ok(output)
    }

    fn reset(&mut self) {
        for layer in &mut self.layers {
            layer.reset();
        }
    }

    fn context_window(&self) -> usize {
        // SSMs have theoretically infinite context via recurrence
        usize::MAX
    }
}

impl AutoregressiveModel for Mamba2 {
    fn hidden_dim(&self) -> usize {
        self.config.hidden_dim
    }

    fn state_dim(&self) -> usize {
        self.config.state_dim
    }

    fn num_layers(&self) -> usize {
        self.config.num_layers
    }

    fn model_type(&self) -> ModelType {
        ModelType::Mamba2
    }

    fn get_states(&self) -> Vec<HiddenState> {
        // Flatten multi-head states into single HiddenState per layer
        self.layers
            .iter()
            .map(|layer| {
                // Concatenate all head states
                let total_size = layer.head_dim * layer.num_heads;
                let mut combined = Array2::zeros((total_size, layer.state_dim));

                for (head_idx, head_state) in layer.states.iter().enumerate() {
                    let start_idx = head_idx * layer.head_dim;
                    for i in 0..layer.head_dim.min(head_state.shape()[0]) {
                        for j in 0..layer.state_dim {
                            combined[[start_idx + i, j]] = head_state[[i, j]];
                        }
                    }
                }

                let mut hs = HiddenState::new(combined.shape()[0], combined.shape()[1]);
                hs.update(combined);
                hs
            })
            .collect()
    }

    fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()> {
        if states.len() != self.config.num_layers {
            return Err(ModelError::state_count_mismatch(
                "Mamba2",
                self.config.num_layers,
                states.len(),
            ));
        }

        // Split combined states back into per-head states
        for (layer_idx, layer) in self.layers.iter_mut().enumerate() {
            let combined = states[layer_idx].state();

            for (head_idx, head_state) in layer.states.iter_mut().enumerate() {
                let start_idx = head_idx * layer.head_dim;
                for i in 0..layer.head_dim.min(head_state.shape()[0]) {
                    for j in 0..layer.state_dim.min(combined.shape()[1]) {
                        if start_idx + i < combined.shape()[0] {
                            head_state[[i, j]] = combined[[start_idx + i, j]];
                        }
                    }
                }
            }
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mamba2_config() {
        let config = Mamba2Config::new()
            .hidden_dim(512)
            .num_heads(8)
            .num_layers(4);

        assert_eq!(config.hidden_dim, 512);
        assert_eq!(config.num_heads, 8);
        assert_eq!(config.head_dim, 64);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_mamba2_creation() {
        let config = Mamba2Config::new().hidden_dim(256).num_heads(4);
        let model = Mamba2::new(config);
        assert!(model.is_ok());
    }

    #[test]
    fn test_mamba2_forward() {
        let config = Mamba2Config::new()
            .hidden_dim(128)
            .num_heads(4)
            .num_layers(2);
        let mut model = Mamba2::new(config).expect("Failed to create Mamba2 model");

        let input = Array1::from_vec(vec![0.5]);
        let output = model.step(&input);
        assert!(output.is_ok());
    }

    #[test]
    fn test_invalid_config() {
        let config = Mamba2Config::new().hidden_dim(100).num_heads(3); // Not divisible
        assert!(config.validate().is_err());
    }
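
    // A sketch of the state save/restore API exposed through `AutoregressiveModel`
    // (`get_states` / `set_states`); the expected counts follow from the config used below.
    #[test]
    fn test_state_roundtrip() {
        let config = Mamba2Config::new().hidden_dim(64).num_heads(4).num_layers(2);
        let mut model = Mamba2::new(config).expect("Failed to create Mamba2 model");

        // Advance the model one step so the per-head SSM states are non-trivial.
        let input = Array1::from_vec(vec![0.5]);
        model.step(&input).expect("forward step");

        // One flattened HiddenState per layer.
        let saved = model.get_states();
        assert_eq!(saved.len(), 2);

        // Captured states can be written back, and a wrong count is rejected.
        assert!(model.set_states(saved).is_ok());
        assert!(model.set_states(Vec::new()).is_err());
    }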
}