kizzasi_model/mamba.rs

1//! Mamba: Selective State Space Model
2//!
3//! Mamba is a selective SSM that uses input-dependent state transitions,
4//! allowing it to selectively remember or forget information based on content.
5//!
6//! # Key Features
7//!
8//! - **O(1) inference**: Constant time per token during autoregressive generation
9//! - **Selectivity**: Input-dependent Δ, B, C parameters
10//! - **Hardware-efficient**: Parallel scan for training, recurrent for inference
11//! - **Continuous-native**: No discrete vocabulary needed for signal prediction
12//!
13//! # Architecture
14//!
15//! ```text
16//! Input → [Expand] → [Conv1D] → [SSM] → [Gate] → [Project] → Output
17//!                                  ↓
18//!                               [State]
19//! ```
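//!
//! In this file, Expand/Project correspond to `in_proj`/`out_proj`, Conv1D to
//! `CausalConv1d`, the SSM to the `SelectiveSSM` block, and the gate to a
//! SiLU-activated branch (see `MambaLayer::forward`).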
20//!
21//! # Selective SSM Formulation
22//!
23//! Unlike traditional SSMs with fixed parameters, Mamba computes:
24//!
25//! ```text
26//! Δ, B, C = Linear(x)  // Input-dependent parameters
27//! A̅ = exp(Δ·A)         // Discretized A matrix
28//! B̅ = (A̅ - I)·A^(-1)·B // Discretized B matrix
29//! h[t] = A̅·h[t-1] + B̅·x[t]
30//! y[t] = C·h[t]
31//! ```
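//!
//! For intuition, take a single state dimension with A = -1, B = C = 1, and an
//! input-dependent step Δ = 0.1:
//!
//! ```text
//! A̅ = exp(0.1·(-1)) ≈ 0.905
//! B̅ = (A̅ - 1)/A·B ≈ 0.095
//! h[t] = 0.905·h[t-1] + 0.095·x[t]   // an exponential moving average
//! ```
//!
//! A larger Δ makes the state forget faster and a smaller Δ makes it remember longer;
//! because Δ is computed from the input, the model makes this choice at every step.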
32//!
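//! # Example
//!
//! A minimal end-to-end sketch (mirroring the unit tests at the bottom of this file):
//!
//! ```ignore
//! use kizzasi_model::{Mamba, MambaConfig};
//! use kizzasi_core::SignalPredictor;
//! use scirs2_core::ndarray::Array1;
//!
//! let config = MambaConfig::tiny(3);
//! let mut model = Mamba::new(config)?;
//!
//! // O(1) per-token generation: feed one sample, get the next prediction.
//! let mut x = Array1::from_vec(vec![0.1, 0.2, 0.3]);
//! for _ in 0..16 {
//!     x = model.step(&x)?;
//! }
//! model.reset(); // clear recurrent state before starting a new sequence
//! ```
//!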
33//! # References
34//!
35//! - Mamba paper (Gu & Dao, 2023): <https://arxiv.org/abs/2312.00752>
36//! - Efficient implementation: parallel prefix scan for training, recurrence for inference
37
38use crate::error::{ModelError, ModelResult};
39use crate::AutoregressiveModel;
40use kizzasi_core::{
41    silu, CausalConv1d, CoreResult, HiddenState, LayerNorm, NormType, SignalPredictor,
42};
43use scirs2_core::ndarray::{Array1, Array2};
44use scirs2_core::random::{rng, Rng};
45use tracing::{debug, instrument, trace};
46
47/// Configuration for the Mamba model
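///
/// All fields are public; a config can come from a preset (e.g. [`MambaConfig::tiny`]),
/// from the builder-style setters below, or from struct-update syntax. A minimal sketch:
///
/// ```ignore
/// let config = MambaConfig { input_dim: 3, ..MambaConfig::default() };
/// ```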
48#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
49pub struct MambaConfig {
50    /// Input dimension
51    pub input_dim: usize,
52    /// Hidden dimension (d_model)
53    pub hidden_dim: usize,
54    /// State dimension (d_state, typically 16)
55    pub state_dim: usize,
56    /// Expansion factor for inner dimension
57    pub expand_factor: usize,
58    /// Convolution kernel size
59    pub conv_kernel_size: usize,
60    /// Number of layers
61    pub num_layers: usize,
62    /// Dropout rate
63    pub dropout: f32,
64    /// Use Mamba2 architecture (SSD)
65    pub use_mamba2: bool,
66}
67
68impl Default for MambaConfig {
69    fn default() -> Self {
70        Self {
71            input_dim: 1,
72            hidden_dim: 256,
73            state_dim: 16,
74            expand_factor: 2,
75            conv_kernel_size: 4,
76            num_layers: 4,
77            dropout: 0.0,
78            use_mamba2: true,
79        }
80    }
81}
82
83impl MambaConfig {
84    /// Create a new Mamba configuration
85    pub fn new() -> Self {
86        Self::default()
87    }
88
89    /// Mamba-Tiny: Lightweight configuration for fast inference and low memory
90    ///
91    /// Optimized for:
92    /// - Edge devices
93    /// - Real-time streaming applications
94    /// - Low-latency inference
95    ///
96    /// # Parameters
97    /// - Hidden dim: 128
98    /// - State dim: 8
99    /// - Layers: 2
100    /// - Target latency: <50μs per step
101    /// - Memory: <10MB
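    ///
    /// # Example
    ///
    /// A minimal sketch; the preset can be tuned further with the builder methods below:
    ///
    /// ```ignore
    /// let config = MambaConfig::tiny(1).num_layers(3);
    /// let model = Mamba::new(config)?;
    /// ```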
102    pub fn tiny(input_dim: usize) -> Self {
103        Self {
104            input_dim,
105            hidden_dim: 128,
106            state_dim: 8,
107            expand_factor: 2,
108            conv_kernel_size: 4,
109            num_layers: 2,
110            dropout: 0.0,
111            use_mamba2: false, // Use simpler Mamba for speed
112        }
113    }
114
115    /// Mamba-Small: Balanced configuration for moderate capacity
116    ///
117    /// Optimized for:
118    /// - General-purpose applications
119    /// - Moderate accuracy requirements
120    /// - Resource-constrained servers
121    ///
122    /// # Parameters
123    /// - Hidden dim: 256
124    /// - State dim: 16
125    /// - Layers: 4
126    /// - Target latency: <100μs per step
127    /// - Memory: <50MB
128    pub fn small(input_dim: usize) -> Self {
129        Self {
130            input_dim,
131            hidden_dim: 256,
132            state_dim: 16,
133            expand_factor: 2,
134            conv_kernel_size: 4,
135            num_layers: 4,
136            dropout: 0.1,
137            use_mamba2: true,
138        }
139    }
140
141    /// Mamba-Base: Standard configuration (default)
142    ///
143    /// Optimized for:
144    /// - Standard applications
145    /// - Good accuracy/speed tradeoff
146    /// - Server deployment
147    ///
148    /// # Parameters
149    /// - Hidden dim: 512
150    /// - State dim: 16
151    /// - Layers: 6
152    /// - Target latency: <200μs per step
153    /// - Memory: <200MB
154    pub fn base(input_dim: usize) -> Self {
155        Self {
156            input_dim,
157            hidden_dim: 512,
158            state_dim: 16,
159            expand_factor: 2,
160            conv_kernel_size: 4,
161            num_layers: 6,
162            dropout: 0.1,
163            use_mamba2: true,
164        }
165    }
166
167    /// Mamba-Large: High-capacity configuration for maximum accuracy
168    ///
169    /// Optimized for:
170    /// - High-accuracy applications
171    /// - Complex sequence modeling
172    /// - GPU deployment
173    ///
174    /// # Parameters
175    /// - Hidden dim: 1024
176    /// - State dim: 32
177    /// - Layers: 12
178    /// - Target latency: <500μs per step
179    /// - Memory: <1GB
180    pub fn large(input_dim: usize) -> Self {
181        Self {
182            input_dim,
183            hidden_dim: 1024,
184            state_dim: 32,
185            expand_factor: 2,
186            conv_kernel_size: 4,
187            num_layers: 12,
188            dropout: 0.1,
189            use_mamba2: true,
190        }
191    }
192
193    /// Mamba-XLarge: Experimental extra-large configuration
194    ///
195    /// Optimized for:
196    /// - Research and experimentation
197    /// - Maximum model capacity
198    /// - Multi-GPU deployment
199    ///
200    /// # Parameters
201    /// - Hidden dim: 2048
202    /// - State dim: 64
203    /// - Layers: 24
204    /// - Target latency: <1ms per step
205    /// - Memory: <4GB
206    pub fn xlarge(input_dim: usize) -> Self {
207        Self {
208            input_dim,
209            hidden_dim: 2048,
210            state_dim: 64,
211            expand_factor: 2,
212            conv_kernel_size: 4,
213            num_layers: 24,
214            dropout: 0.2,
215            use_mamba2: true,
216        }
217    }
218
219    /// Set input dimension
220    pub fn input_dim(mut self, dim: usize) -> Self {
221        self.input_dim = dim;
222        self
223    }
224
225    /// Set hidden dimension
226    pub fn hidden_dim(mut self, dim: usize) -> Self {
227        self.hidden_dim = dim;
228        self
229    }
230
231    /// Set state dimension
232    pub fn state_dim(mut self, dim: usize) -> Self {
233        self.state_dim = dim;
234        self
235    }
236
237    /// Set number of layers
238    pub fn num_layers(mut self, n: usize) -> Self {
239        self.num_layers = n;
240        self
241    }
242
243    /// Use Mamba2 (SSD) architecture
244    pub fn mamba2(mut self, use_mamba2: bool) -> Self {
245        self.use_mamba2 = use_mamba2;
246        self
247    }
248
249    /// Validate the configuration
250    pub fn validate(&self) -> ModelResult<()> {
251        if self.hidden_dim == 0 {
252            return Err(ModelError::invalid_config("hidden_dim must be > 0"));
253        }
254        if self.state_dim == 0 {
255            return Err(ModelError::invalid_config("state_dim must be > 0"));
256        }
257        if self.num_layers == 0 {
258            return Err(ModelError::invalid_config("num_layers must be > 0"));
259        }
260        if self.expand_factor == 0 {
261            return Err(ModelError::invalid_config("expand_factor must be > 0"));
262        }
263        Ok(())
264    }
265}
266
267/// Selective SSM block with input-dependent parameters
268struct SelectiveSSM {
269    state_dim: usize,
270    inner_dim: usize,
271
272    /// Fixed diagonal A matrix (in log space for stability)
273    /// A = -exp(log_a), initialized with HiPPO
274    log_a: Array1<f32>,
275
276    /// Projections for selective parameters
277    /// Δ (delta): discretization step size
278    delta_proj: Array2<f32>, // [inner_dim, inner_dim]
279    delta_bias: Array1<f32>, // [inner_dim]
280
281    /// B: input-to-state projection (selective)
282    b_proj: Array2<f32>, // [inner_dim, state_dim]
283
284    /// C: state-to-output projection (selective)
285    c_proj: Array2<f32>, // [inner_dim, state_dim]
286
287    /// D: skip connection
288    d_skip: Array1<f32>, // [inner_dim]
289
290    /// Current state
291    state: Array2<f32>, // [inner_dim, state_dim]
292}
293
294impl SelectiveSSM {
295    fn new(config: &MambaConfig) -> ModelResult<Self> {
296        let mut rng = rng();
297        let inner_dim = config.hidden_dim * config.expand_factor;
298
299        // Initialize diagonal A with HiPPO initialization
300        // A[n] = -(n + 1) for improved long-range modeling
301        // Store log of the absolute value since we'll negate later
302        let log_a = Array1::from_shape_fn(config.state_dim, |n| ((n + 1) as f32).ln());
303
304        // Initialize projections
305        let scale = (2.0 / inner_dim as f32).sqrt();
306
307        let delta_proj = Array2::from_shape_fn((inner_dim, inner_dim), |_| {
308            (rng.random::<f32>() - 0.5) * 2.0 * scale
309        });
310        let delta_bias = Array1::from_shape_fn(inner_dim, |_| rng.random::<f32>() * 0.1);
311
312        let b_proj = Array2::from_shape_fn((inner_dim, config.state_dim), |_| {
313            (rng.random::<f32>() - 0.5) * 2.0 * scale
314        });
315
316        let c_proj = Array2::from_shape_fn((inner_dim, config.state_dim), |_| {
317            (rng.random::<f32>() - 0.5) * 2.0 * scale
318        });
319
320        let d_skip = Array1::ones(inner_dim);
321
322        let state = Array2::zeros((inner_dim, config.state_dim));
323
324        Ok(Self {
325            state_dim: config.state_dim,
326            inner_dim,
327            log_a,
328            delta_proj,
329            delta_bias,
330            b_proj,
331            c_proj,
332            d_skip,
333            state,
334        })
335    }
336
337    /// Selective SSM forward step
338    ///
339    /// Computes input-dependent parameters and performs state update
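    ///
    /// In pseudocode (mirroring the numbered steps in the body, for a diagonal A):
    ///
    /// ```text
    /// Δ = softplus(W_Δ·x + b_Δ)              // step size, clamped to [1e-6, 0.1]
    /// B = W_B·x,  C = W_C·x                  // input-dependent projections
    /// A̅ = exp(Δ·A)                           // discretized state transition
    /// B̅ ≈ Δ·B   (or exact ZOH: (A̅ - 1)/A·B)
    /// h ← A̅ ⊙ h + B̅ ⊙ x                     // state update
    /// y = C·h + D ⊙ x                        // output with skip connection
    /// ```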
340    fn forward_step(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
341        let batch_size = x.len().min(self.inner_dim); // channels processed this step (not a true batch size)
342
343        // 1. Compute input-dependent Δ (discretization step size)
344        // Δ = Softplus(Linear(x) + bias)
345        let mut delta = Array1::zeros(batch_size);
346        for i in 0..batch_size {
347            let mut sum = self.delta_bias[i];
348            for j in 0..batch_size {
349                sum += self.delta_proj[[i, j]] * x[j];
350            }
351            // Softplus activation to ensure Δ > 0
352            // Clamp input to avoid overflow in exp
353            let clamped = sum.clamp(-20.0, 20.0);
354            delta[i] = (1.0 + clamped.exp()).ln().clamp(1e-6, 0.1);
355        }
356
357        // 2. Compute input-dependent B (input-to-state)
358        // B = Linear_B(x), not just copying weights
359        let mut b_vec = Array2::zeros((batch_size, self.state_dim));
360        for i in 0..batch_size {
361            for n in 0..self.state_dim {
362                let mut sum = 0.0;
363                for j in 0..batch_size {
364                    // Treat b_proj as weight matrix: b_vec[i, n] = sum_j b_proj[j, n] * x[j]
365                    sum += if j < self.b_proj.shape()[0] && n < self.b_proj.shape()[1] {
366                        self.b_proj[[j, n]] * x[j]
367                    } else {
368                        0.0
369                    };
370                }
371                b_vec[[i, n]] = sum;
372            }
373        }
374
375        // 3. Compute input-dependent C (state-to-output)
376        // C = Linear_C(x), not just copying weights
377        let mut c_vec = Array2::zeros((batch_size, self.state_dim));
378        for i in 0..batch_size {
379            for n in 0..self.state_dim {
380                let mut sum = 0.0;
381                for j in 0..batch_size {
382                    // Treat c_proj as weight matrix: c_vec[i, n] = sum_j c_proj[j, n] * x[j]
383                    sum += if j < self.c_proj.shape()[0] && n < self.c_proj.shape()[1] {
384                        self.c_proj[[j, n]] * x[j]
385                    } else {
386                        0.0
387                    };
388                }
389                c_vec[[i, n]] = sum;
390            }
391        }
392
393        // 4. Discretize: A̅ = exp(Δ·A)
394        // For diagonal A: A̅[n] = exp(Δ · A[n])
395        let mut a_bar = Array2::zeros((batch_size, self.state_dim));
396        for i in 0..batch_size {
397            for n in 0..self.state_dim {
398                let a_n = -self.log_a[n].exp(); // A[n] = -exp(log_a[n])
399                let delta_a = delta[i] * a_n;
400                // Clamp to prevent numerical overflow
401                a_bar[[i, n]] = delta_a.clamp(-20.0, 20.0).exp();
402            }
403        }
404
405        // 5. Discretize: B̅ using ZOH or Taylor approximation
406        // Exact: B̅ = (A̅ - I)·A^(-1)·B
407        // For small Δ: B̅ ≈ Δ·B (first-order Taylor)
408        // For moderate Δ: Use exact formula
409        let mut b_bar = Array2::zeros((batch_size, self.state_dim));
410        for i in 0..batch_size {
411            for n in 0..self.state_dim {
412                let a_n = -self.log_a[n].exp();
413
414                // Use Taylor approximation for small delta (more numerically stable)
415                if delta[i].abs() < 0.001 {
416                    // First-order: B̅ ≈ Δ·B
417                    b_bar[[i, n]] = delta[i] * b_vec[[i, n]];
418                } else {
419                    // Exact ZOH discretization
420                    // B̅[n] = (exp(Δ·A[n]) - 1) / A[n] · B[n]
421                    let safe_a_n = if a_n.abs() < 1e-8 { -1.0 } else { a_n };
422                    b_bar[[i, n]] = (a_bar[[i, n]] - 1.0) / safe_a_n * b_vec[[i, n]];
423                }
424            }
425        }
426
427        // 6. State update: h[t] = A̅·h[t-1] + B̅·x[t]
428        let mut new_state = Array2::zeros((batch_size, self.state_dim));
429        for i in 0..batch_size {
430            for n in 0..self.state_dim {
431                // Diagonal A: element-wise multiplication
432                let decay = a_bar[[i, n]];
433                let input_contrib = b_bar[[i, n]] * x[i];
434
435                new_state[[i, n]] = decay * self.state[[i, n]] + input_contrib;
436            }
437        }
438
439        // Update state
440        for i in 0..batch_size.min(self.state.shape()[0]) {
441            for n in 0..self.state_dim {
442                self.state[[i, n]] = new_state[[i, n]];
443            }
444        }
445
446        // 7. Output: y = C·h + D·x
447        let mut output = Array1::zeros(batch_size);
448        for i in 0..batch_size {
449            let mut c_h = 0.0;
450            for n in 0..self.state_dim {
451                c_h += c_vec[[i, n]] * new_state[[i, n]];
452            }
453            output[i] = c_h + self.d_skip[i] * x[i];
454        }
455
456        Ok(output)
457    }
458
459    fn reset(&mut self) {
460        self.state.fill(0.0);
461    }
462}
463
464/// Mamba Layer with Selective SSM
465struct MambaLayer {
466    hidden_dim: usize,
467    inner_dim: usize,
468
469    /// Layer normalization
470    norm: LayerNorm,
471
472    /// Expansion projection
473    in_proj: Array2<f32>, // [hidden_dim, inner_dim * 2]
474
475    /// Short causal convolution for local context
476    conv: CausalConv1d,
477
478    /// Selective SSM
479    ssm: SelectiveSSM,
480
481    /// Output projection (contracts inner_dim back to hidden_dim)
482    out_proj: Array2<f32>, // [inner_dim, hidden_dim]
483}
484
485impl MambaLayer {
486    fn new(config: &MambaConfig) -> ModelResult<Self> {
487        let inner_dim = config.hidden_dim * config.expand_factor;
488        let mut rng = rng();
489
490        // RMSNorm for better stability
491        let norm = LayerNorm::new(config.hidden_dim, NormType::RMSNorm).with_eps(1e-5);
492
493        // Input projection (expands and creates gate path)
494        let scale = (2.0 / config.hidden_dim as f32).sqrt();
495        let in_proj = Array2::from_shape_fn((config.hidden_dim, inner_dim * 2), |_| {
496            (rng.random::<f32>() - 0.5) * 2.0 * scale
497        });
498
499        // Causal convolution
500        let conv = CausalConv1d::new(inner_dim, inner_dim, config.conv_kernel_size);
501
502        // Selective SSM
503        let ssm = SelectiveSSM::new(config)?;
504
505        // Output projection
506        let scale = (2.0 / inner_dim as f32).sqrt();
507        let out_proj = Array2::from_shape_fn((inner_dim, config.hidden_dim), |_| {
508            (rng.random::<f32>() - 0.5) * 2.0 * scale
509        });
510
511        Ok(Self {
512            hidden_dim: config.hidden_dim,
513            inner_dim,
514            norm,
515            in_proj,
516            conv,
517            ssm,
518            out_proj,
519        })
520    }
521
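    /// Single-step forward pass for one Mamba block:
    /// norm → expand & split → causal conv → selective SSM → SiLU gate → project → residual add.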
522    fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
523        let batch_size = x.len().min(self.hidden_dim);
524
525        // 1. Layer normalization
526        let x_norm = self.norm.forward(x);
527
528        // 2. Expansion and gating
529        // Project to 2 * inner_dim, then split for SSM path and gate path
530        let mut projected = Array1::zeros(self.inner_dim * 2);
531        for i in 0..(self.inner_dim * 2) {
532            let mut sum = 0.0;
533            for j in 0..batch_size {
534                if i < self.in_proj.shape()[1] {
535                    sum += self.in_proj[[j, i]] * x_norm[j];
536                }
537            }
538            projected[i] = sum;
539        }
540
541        // Split: first half for SSM, second half for gate
542        let mut x_ssm = Array1::zeros(self.inner_dim);
543        let mut x_gate = Array1::zeros(self.inner_dim);
544        for i in 0..self.inner_dim {
545            x_ssm[i] = projected[i];
546            x_gate[i] = projected[self.inner_dim + i];
547        }
548
549        // 3. Short convolution on SSM path
550        let x_ssm_vec = x_ssm.to_vec();
551        let conv_out = self.conv.forward_step(&x_ssm_vec);
552        x_ssm = Array1::from_vec(conv_out);
553
554        // 4. Selective SSM
555        let ssm_out = self.ssm.forward_step(&x_ssm)?;
556
557        // 5. Gating with SiLU (Swish)
558        let gate = silu(&x_gate);
559
560        // Element-wise multiplication
561        let mut gated = Array1::zeros(ssm_out.len().min(gate.len()));
562        for i in 0..gated.len() {
563            gated[i] = ssm_out[i] * gate[i];
564        }
565
566        // 6. Output projection
567        let mut output = Array1::zeros(batch_size);
568        for i in 0..batch_size {
569            let mut sum = 0.0;
570            for j in 0..gated.len().min(self.out_proj.shape()[0]) {
571                sum += self.out_proj[[j, i]] * gated[j];
572            }
573            output[i] = sum;
574        }
575
576        // 7. Residual connection
577        for i in 0..output.len().min(x.len()) {
578            output[i] += x[i];
579        }
580
581        Ok(output)
582    }
583
584    fn reset(&mut self) {
585        self.ssm.reset();
586        self.conv.reset();
587    }
588}
589
590/// Mamba: Selective State Space Model
591pub struct Mamba {
592    config: MambaConfig,
593    layers: Vec<MambaLayer>,
594    input_proj: Array2<f32>,
595    output_proj: Array2<f32>,
596}
597
598impl Mamba {
599    /// Create a new Mamba model
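    ///
    /// # Errors
    ///
    /// Returns an error if the configuration fails [`MambaConfig::validate`]
    /// (zero `hidden_dim`, `state_dim`, `num_layers`, or `expand_factor`).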
600    #[instrument(skip(config), fields(input_dim = config.input_dim, hidden_dim = config.hidden_dim, num_layers = config.num_layers))]
601    pub fn new(config: MambaConfig) -> ModelResult<Self> {
602        debug!("Creating new Mamba model");
603        config.validate()?;
604
605        // Initialize layers
606        let mut layers = Vec::with_capacity(config.num_layers);
607        for layer_idx in 0..config.num_layers {
608            trace!("Initializing Mamba layer {}", layer_idx);
609            layers.push(MambaLayer::new(&config)?);
610        }
611        debug!("Initialized {} Mamba layers", layers.len());
612
613        // Initialize input/output projections
614        let mut rng = rng();
615        let scale = (2.0 / (config.input_dim + config.hidden_dim) as f32).sqrt();
616        let input_proj = Array2::from_shape_fn((config.input_dim, config.hidden_dim), |_| {
617            (rng.random::<f32>() - 0.5) * 2.0 * scale
618        });
619
620        let scale = (2.0 / (config.hidden_dim + config.input_dim) as f32).sqrt();
621        let output_proj = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
622            (rng.random::<f32>() - 0.5) * 2.0 * scale
623        });
624
625        debug!("Mamba model created successfully");
626        Ok(Self {
627            config,
628            layers,
629            input_proj,
630            output_proj,
631        })
632    }
633
634    /// Load pre-trained weights from a ModelLoader
635    ///
636    /// # Weight Format
637    ///
638    /// Expected weight names follow the pattern:
639    /// - `input_proj`: Input projection weights
640    /// - `output_proj`: Output projection weights
641    /// - `layers.{i}.norm.weight`: Layer normalization weights
642    /// - `layers.{i}.norm.bias`: Layer normalization bias (optional)
643    /// - `layers.{i}.in_proj`: Input projection for layer i
644    /// - `layers.{i}.conv.weight`: Convolution weights for layer i
645    /// - `layers.{i}.conv.bias`: Convolution bias for layer i (optional)
646    /// - `layers.{i}.ssm.log_a`: SSM diagonal A matrix (log space)
647    /// - `layers.{i}.ssm.delta_proj`: SSM delta projection weights
648    /// - `layers.{i}.ssm.delta_bias`: SSM delta projection bias
649    /// - `layers.{i}.ssm.b_proj`: SSM B projection weights
650    /// - `layers.{i}.ssm.c_proj`: SSM C projection weights
651    /// - `layers.{i}.ssm.d_skip`: SSM skip connection weights
652    /// - `layers.{i}.out_proj`: Output projection for layer i
653    ///
654    /// # Example
655    ///
656    /// ```ignore
657    /// use kizzasi_model::{Mamba, MambaConfig, loader::ModelLoader};
658    ///
659    /// let config = MambaConfig::new();
660    /// let mut model = Mamba::new(config)?;
661    /// let loader = ModelLoader::new("mamba_weights.safetensors")?;
662    /// model.load_weights(&loader)?;
663    /// ```
664    pub fn load_weights(&mut self, loader: &crate::loader::ModelLoader) -> ModelResult<()> {
665        // Load input/output projections
666        if loader.has_tensor("input_proj") {
667            self.input_proj = loader.load_array2("input_proj")?;
668        }
669        if loader.has_tensor("output_proj") {
670            self.output_proj = loader.load_array2("output_proj")?;
671        }
672
673        // Load layer weights
674        for (i, layer) in self.layers.iter_mut().enumerate() {
675            let prefix = format!("layers.{}", i);
676
677            // Load layer norm weights
678            if loader.has_tensor(&format!("{}.norm.weight", prefix)) {
679                let _weight = loader.load_array1(&format!("{}.norm.weight", prefix))?;
680                // TODO: Implement LayerNorm::set_weights() method in kizzasi-core
681                // For now, the loaded tensor is ignored and the norm keeps its initial weights
682            }
683
684            // Load input projection
685            if loader.has_tensor(&format!("{}.in_proj", prefix)) {
686                layer.in_proj = loader.load_array2(&format!("{}.in_proj", prefix))?;
687            }
688
689            // Load convolution weights
690            if loader.has_tensor(&format!("{}.conv.weight", prefix)) {
691                // TODO: Implement CausalConv1d::load_weights() method
692                // For now, the stored tensor is ignored and the conv keeps its initial weights
693            }
694
695            // Load SSM weights
696            if loader.has_tensor(&format!("{}.ssm.log_a", prefix)) {
697                layer.ssm.log_a = loader.load_array1(&format!("{}.ssm.log_a", prefix))?;
698            }
699            if loader.has_tensor(&format!("{}.ssm.delta_proj", prefix)) {
700                layer.ssm.delta_proj = loader.load_array2(&format!("{}.ssm.delta_proj", prefix))?;
701            }
702            if loader.has_tensor(&format!("{}.ssm.delta_bias", prefix)) {
703                layer.ssm.delta_bias = loader.load_array1(&format!("{}.ssm.delta_bias", prefix))?;
704            }
705            if loader.has_tensor(&format!("{}.ssm.b_proj", prefix)) {
706                layer.ssm.b_proj = loader.load_array2(&format!("{}.ssm.b_proj", prefix))?;
707            }
708            if loader.has_tensor(&format!("{}.ssm.c_proj", prefix)) {
709                layer.ssm.c_proj = loader.load_array2(&format!("{}.ssm.c_proj", prefix))?;
710            }
711            if loader.has_tensor(&format!("{}.ssm.d_skip", prefix)) {
712                layer.ssm.d_skip = loader.load_array1(&format!("{}.ssm.d_skip", prefix))?;
713            }
714
715            // Load output projection
716            if loader.has_tensor(&format!("{}.out_proj", prefix)) {
717                layer.out_proj = loader.load_array2(&format!("{}.out_proj", prefix))?;
718            }
719        }
720
721        Ok(())
722    }
723
724    /// Save model weights to safetensors format
725    ///
726    /// # Example
727    ///
728    /// ```ignore
729    /// let model = Mamba::new(config)?;
730    /// model.save_weights("mamba_checkpoint.safetensors")?;
731    /// ```
732    pub fn save_weights<P: AsRef<std::path::Path>>(&self, _path: P) -> ModelResult<()> {
733        // TODO: Implement safetensors serialization
734        // This requires creating safetensors::tensor::TensorView from our arrays
735        Err(ModelError::simple_load_error(
736            "save_weights not yet implemented".to_string(),
737        ))
738    }
739
740    /// Get the configuration
741    pub fn config(&self) -> &MambaConfig {
742        &self.config
743    }
744}
745
746impl SignalPredictor for Mamba {
747    #[instrument(skip(self, input), fields(input_size = input.len()))]
748    fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
749        trace!(
750            "Mamba step input range: [{}, {}]",
751            input.iter().cloned().fold(f32::INFINITY, f32::min),
752            input.iter().cloned().fold(f32::NEG_INFINITY, f32::max)
753        );
754
755        // Project input to hidden dimension
756        let mut hidden = input.dot(&self.input_proj);
757        trace!("After input projection: hidden_dim={}", hidden.len());
758
759        // Pass through each layer
760        for (layer_idx, layer) in self.layers.iter_mut().enumerate() {
761            trace!("Processing Mamba layer {}", layer_idx);
762            hidden = layer.forward(&hidden)?;
763        }
764
765        // Project back to input dimension
766        let output = hidden.dot(&self.output_proj);
767        trace!(
768            "Mamba step output range: [{}, {}]",
769            output.iter().cloned().fold(f32::INFINITY, f32::min),
770            output.iter().cloned().fold(f32::NEG_INFINITY, f32::max)
771        );
772        Ok(output)
773    }
774
775    #[instrument(skip(self))]
776    fn reset(&mut self) {
777        debug!("Resetting Mamba model state");
778        for (layer_idx, layer) in self.layers.iter_mut().enumerate() {
779            trace!("Resetting layer {}", layer_idx);
780            layer.reset();
781        }
782    }
783
784    fn context_window(&self) -> usize {
785        // SSMs have theoretically infinite context via recurrence
786        usize::MAX
787    }
788}
789
790impl AutoregressiveModel for Mamba {
791    fn hidden_dim(&self) -> usize {
792        self.config.hidden_dim
793    }
794
795    fn state_dim(&self) -> usize {
796        self.config.state_dim
797    }
798
799    fn num_layers(&self) -> usize {
800        self.config.num_layers
801    }
802
803    fn model_type(&self) -> crate::ModelType {
804        if self.config.use_mamba2 {
805            crate::ModelType::Mamba2
806        } else {
807            crate::ModelType::Mamba
808        }
809    }
810
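    /// Snapshot the per-layer SSM states and convolution history so generation can be
    /// checkpointed and resumed later via [`Self::set_states`]; a minimal sketch,
    /// assuming `model` and `input` are already set up:
    ///
    /// ```ignore
    /// let checkpoint = model.get_states();
    /// let _ = model.step(&input)?;       // speculative step
    /// model.set_states(checkpoint)?;     // roll back to the checkpoint
    /// ```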
811    fn get_states(&self) -> Vec<HiddenState> {
812        self.layers
813            .iter()
814            .map(|layer| {
815                let state = layer.ssm.state.clone();
816                let mut hs = HiddenState::new(state.shape()[0], state.shape()[1]);
817                hs.update(state);
818                // Also save convolution history
819                let conv_history = layer.conv.get_history();
820                hs.set_conv_history(conv_history);
821                hs
822            })
823            .collect()
824    }
825
826    fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()> {
827        if states.len() != self.config.num_layers {
828            return Err(ModelError::state_count_mismatch(
829                "Mamba",
830                self.config.num_layers,
831                states.len(),
832            ));
833        }
834
835        for (layer_idx, layer) in self.layers.iter_mut().enumerate() {
836            layer.ssm.state = states[layer_idx].state().clone();
837            // Also restore convolution history if available
838            if let Some(conv_history) = states[layer_idx].conv_history() {
839                layer.conv.set_history(conv_history.clone());
840            }
841        }
842
843        Ok(())
844    }
845}
846
847#[cfg(test)]
848mod tests {
849    use super::*;
850
851    #[test]
852    fn test_mamba_creation() {
853        let config = MambaConfig::new()
854            .input_dim(3)
855            .hidden_dim(64)
856            .state_dim(8)
857            .num_layers(2);
858
859        let mamba = Mamba::new(config);
860        assert!(mamba.is_ok());
861    }
862
863    #[test]
864    fn test_mamba_step() {
865        let config = MambaConfig::new()
866            .input_dim(3)
867            .hidden_dim(32)
868            .state_dim(8)
869            .num_layers(2);
870
871        let mut mamba = Mamba::new(config).expect("Failed to create Mamba model");
872        let input = Array1::from_vec(vec![0.1, 0.2, 0.3]);
873        let output = mamba.step(&input);
874
875        assert!(output.is_ok());
876        assert_eq!(output.expect("Failed to get output").len(), 3);
877    }
878
879    #[test]
880    fn test_mamba_tiny_config() {
881        let config = MambaConfig::tiny(4);
882        assert_eq!(config.hidden_dim, 128);
883        assert_eq!(config.state_dim, 8);
884        assert_eq!(config.num_layers, 2);
885        assert!(!config.use_mamba2);
886
887        let mut model = Mamba::new(config).expect("Failed to create Mamba model");
888        let input = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
889        let output = model.step(&input).expect("Failed to get output");
890        assert_eq!(output.len(), 4);
891    }
892
893    #[test]
894    fn test_mamba_small_config() {
895        // Test that small config has correct values
896        let config = MambaConfig::small(4);
897        assert_eq!(config.hidden_dim, 256);
898        assert_eq!(config.state_dim, 16);
899        assert_eq!(config.num_layers, 4);
900        assert!(config.use_mamba2);
901
902        // Use a minimal model to verify small config is valid (not full model)
903        // Full model test is too slow for regular testing
904        let minimal_config = MambaConfig::new()
905            .input_dim(4)
906            .hidden_dim(64)
907            .state_dim(8)
908            .num_layers(2);
909        let mut model = Mamba::new(minimal_config).expect("Failed to create Mamba model");
910        let input = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
911        let output = model.step(&input).expect("Failed to get output");
912        assert_eq!(output.len(), 4);
913    }
914
915    #[test]
916    fn test_mamba_base_config() {
917        // Test that base config has correct values
918        let config = MambaConfig::base(4);
919        assert_eq!(config.hidden_dim, 512);
920        assert_eq!(config.state_dim, 16);
921        assert_eq!(config.num_layers, 6);
922        assert!(config.use_mamba2);
923
924        // Use a minimal model to verify base config is valid (not full model)
925        // Full model test is too slow for regular testing
926        let minimal_config = MambaConfig::new()
927            .input_dim(4)
928            .hidden_dim(64)
929            .state_dim(8)
930            .num_layers(2);
931        let mut model = Mamba::new(minimal_config).expect("Failed to create Mamba model");
932        let input = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
933        let output = model.step(&input).expect("Failed to get output");
934        assert_eq!(output.len(), 4);
935    }
936
937    #[test]
938    #[ignore] // Slow test: ~670s due to large model initialization (hidden_dim=1024, num_layers=12)
939    fn test_mamba_large_config() {
940        let config = MambaConfig::large(4);
941        assert_eq!(config.hidden_dim, 1024);
942        assert_eq!(config.state_dim, 32);
943        assert_eq!(config.num_layers, 12);
944        assert!(config.use_mamba2);
945
946        let mut model = Mamba::new(config).expect("Failed to create Mamba model");
947        let input = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
948        let output = model.step(&input).expect("Failed to get output");
949        assert_eq!(output.len(), 4);
950    }
951
952    #[test]
953    #[ignore] // Slow test: ~610s due to very large model initialization (hidden_dim=2048, num_layers=24)
954    fn test_mamba_xlarge_config() {
955        let config = MambaConfig::xlarge(2);
956        assert_eq!(config.hidden_dim, 2048);
957        assert_eq!(config.state_dim, 64);
958        assert_eq!(config.num_layers, 24);
959        assert!(config.use_mamba2);
960
961        // Create model to verify configuration is valid
962        let model = Mamba::new(config);
963        assert!(model.is_ok());
964    }
965
966    #[test]
967    fn test_preset_configs_size_progression() {
968        // Verify that model sizes increase progressively
969        let tiny = MambaConfig::tiny(1);
970        let small = MambaConfig::small(1);
971        let base = MambaConfig::base(1);
972        let large = MambaConfig::large(1);
973        let xlarge = MambaConfig::xlarge(1);
974
975        assert!(tiny.hidden_dim < small.hidden_dim);
976        assert!(small.hidden_dim < base.hidden_dim);
977        assert!(base.hidden_dim < large.hidden_dim);
978        assert!(large.hidden_dim < xlarge.hidden_dim);
979
980        assert!(tiny.num_layers <= small.num_layers);
981        assert!(small.num_layers <= base.num_layers);
982        assert!(base.num_layers <= large.num_layers);
983        assert!(large.num_layers <= xlarge.num_layers);
984    }
985}