use crate::error::{ModelError, ModelResult};
use crate::{AutoregressiveModel, ModelType};
use kizzasi_core::{
    silu, CausalConv1d, CoreResult, HiddenState, LayerNorm, NormType, SignalPredictor,
};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::{rng, Rng};
#[allow(unused_imports)]
use tracing::{debug, instrument, trace};

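/// Configuration for the Mamba-2 model.
///
/// All setters are chainable; call [`Mamba2Config::validate`] before building a
/// model. A minimal usage sketch (the values are illustrative and the import
/// path depends on how this crate is re-exported):
///
/// ```ignore
/// let config = Mamba2Config::new()
///     .input_dim(1)
///     .hidden_dim(256)
///     .num_heads(4)
///     .num_layers(4);
/// assert!(config.validate().is_ok());
/// ```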
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Mamba2Config {
    /// Dimensionality of the input signal.
    pub input_dim: usize,
    /// Width of the hidden representation inside each layer.
    pub hidden_dim: usize,
    /// Size of the SSM state per head.
    pub state_dim: usize,
    /// Number of SSD heads.
    pub num_heads: usize,
    /// Per-head dimension (`hidden_dim / num_heads`).
    pub head_dim: usize,
    /// Expansion factor of the inner block.
    pub expand_factor: usize,
    /// Kernel size of the causal convolution.
    pub conv_kernel_size: usize,
    /// Number of stacked Mamba-2 layers.
    pub num_layers: usize,
    /// Dropout probability.
    pub dropout: f32,
    /// Use RMSNorm instead of LayerNorm.
    pub use_rms_norm: bool,
    /// Chunk size used by the SSD algorithm.
    pub chunk_size: usize,
}

impl Default for Mamba2Config {
    fn default() -> Self {
        let hidden_dim = 512;
        let num_heads = 8;
        Self {
            input_dim: 1,
            hidden_dim,
            state_dim: 64,
            num_heads,
            head_dim: hidden_dim / num_heads,
            expand_factor: 2,
            conv_kernel_size: 4,
            num_layers: 8,
            dropout: 0.0,
            use_rms_norm: true,
            chunk_size: 256,
        }
    }
}

impl Mamba2Config {
    /// Create a configuration with the default values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the input dimension.
    pub fn input_dim(mut self, dim: usize) -> Self {
        self.input_dim = dim;
        self
    }

    /// Set the hidden dimension and recompute the per-head dimension.
    pub fn hidden_dim(mut self, dim: usize) -> Self {
        self.hidden_dim = dim;
        self.head_dim = dim / self.num_heads;
        self
    }

    /// Set the SSM state dimension.
    pub fn state_dim(mut self, dim: usize) -> Self {
        self.state_dim = dim;
        self
    }

    /// Set the number of heads and recompute the per-head dimension.
    pub fn num_heads(mut self, n: usize) -> Self {
        self.num_heads = n;
        self.head_dim = self.hidden_dim / n;
        self
    }

    /// Set the number of layers.
    pub fn num_layers(mut self, n: usize) -> Self {
        self.num_layers = n;
        self
    }

    /// Set the chunk size.
    pub fn chunk_size(mut self, size: usize) -> Self {
        self.chunk_size = size;
        self
    }

    /// Check that the configuration is internally consistent.
    pub fn validate(&self) -> ModelResult<()> {
        if self.hidden_dim == 0 {
            return Err(ModelError::invalid_config("hidden_dim must be > 0"));
        }
        if self.state_dim == 0 {
            return Err(ModelError::invalid_config("state_dim must be > 0"));
        }
        if self.num_layers == 0 {
            return Err(ModelError::invalid_config("num_layers must be > 0"));
        }
        if self.num_heads == 0 {
            return Err(ModelError::invalid_config("num_heads must be > 0"));
        }
        if !self.hidden_dim.is_multiple_of(self.num_heads) {
            return Err(ModelError::invalid_config(
                "hidden_dim must be divisible by num_heads",
            ));
        }
        if self.chunk_size == 0 {
            return Err(ModelError::invalid_config("chunk_size must be > 0"));
        }
        Ok(())
    }
}

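/// A single Mamba-2 block: pre-norm, causal convolution, per-head selective
/// state update (SSD), SiLU-gated output projection, and a residual connection.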
struct Mamba2Layer {
    hidden_dim: usize,
    state_dim: usize,
    num_heads: usize,
    head_dim: usize,

    /// Pre-normalization (RMSNorm or LayerNorm).
    norm: Option<LayerNorm>,

    /// Causal convolution over the hidden stream.
    conv: CausalConv1d,

    /// Per-head state-transition parameters, stored in log space.
    a_log: Array2<f32>,
    /// Input-to-state projection.
    b_proj: Array2<f32>,
    /// State-to-output projection.
    c_proj: Array2<f32>,
    /// Direct skip (D) term.
    d_skip: Array1<f32>,
    /// Gating projection.
    gate_proj: Array2<f32>,

    /// Output projection.
    out_proj: Array2<f32>,

    /// Per-head recurrent state, each of shape `(head_dim, state_dim)`.
    states: Vec<Array2<f32>>,
}

impl Mamba2Layer {
    fn new(config: &Mamba2Config) -> ModelResult<Self> {
        let mut rng = rng();

        let norm_type = if config.use_rms_norm {
            NormType::RMSNorm
        } else {
            NormType::LayerNorm
        };
        let norm = Some(LayerNorm::new(config.hidden_dim, norm_type).with_eps(1e-5));

        let conv = CausalConv1d::new(
            config.hidden_dim,
            config.hidden_dim,
            config.conv_kernel_size,
        );

        // A is parameterized in log space and initialized to negative values in
        // roughly [-3, -1], so exp(a_log) < 1 and the recurrence decays stably.
        let a_log = Array2::from_shape_fn((config.num_heads, config.state_dim), |_| {
            -(rng.random::<f32>() * 2.0 + 1.0)
        });

        // Xavier-style initialization for the projections.
        let scale = (2.0 / (config.hidden_dim + config.state_dim) as f32).sqrt();
        let b_proj = Array2::from_shape_fn((config.hidden_dim, config.state_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        let c_proj = Array2::from_shape_fn((config.hidden_dim, config.state_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        let d_skip =
            Array1::from_shape_fn(config.hidden_dim, |_| (rng.random::<f32>() - 0.5) * 0.1);

        let scale = (2.0 / config.hidden_dim as f32).sqrt();
        let gate_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        let out_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        // One recurrent state per head.
        let states = (0..config.num_heads)
            .map(|_| Array2::zeros((config.head_dim, config.state_dim)))
            .collect();

        Ok(Self {
            hidden_dim: config.hidden_dim,
            state_dim: config.state_dim,
            num_heads: config.num_heads,
            head_dim: config.head_dim,
            norm,
            conv,
            a_log,
            b_proj,
            c_proj,
            d_skip,
            gate_proj,
            out_proj,
            states,
        })
    }

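    /// One step of the simplified, per-element SSD recurrence used here:
    /// for each head, the state is updated as
    /// `h[i, j] <- exp(a_log[head, j]) * h[i, j] + 0.01 * (B^T x)[j]`,
    /// and the output is `y = C h + D ⊙ x`, where `D ⊙ x` is the skip term.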
    fn ssd_step(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        let mut output = Array1::zeros(x.len().min(self.hidden_dim));

        // Project the input into state space: b_x = B^T x.
        let mut b_x = Array1::zeros(self.state_dim);
        for i in 0..self.state_dim {
            let mut sum = 0.0;
            for j in 0..self.hidden_dim.min(x.len()) {
                sum += self.b_proj[[j, i]] * x[j];
            }
            b_x[i] = sum;
        }

        for head in 0..self.num_heads {
            let head_start = head * self.head_dim;
            let head_end = (head_start + self.head_dim).min(self.hidden_dim);

            let h = &self.states[head];

            // Diagonal state transition for this head: a_diag = exp(a_log).
            let a_diag = self.a_log.row(head).mapv(|a| a.exp());

            // h_new = a_diag * h + scaled input contribution.
            let mut new_h = Array2::zeros((self.head_dim, self.state_dim));
            for i in 0..self.head_dim.min(h.shape()[0]) {
                for j in 0..self.state_dim {
                    let a_val = if j < a_diag.len() {
                        a_diag[j]
                    } else {
                        0.99
                    };
                    new_h[[i, j]] = a_val * h[[i, j]] + b_x[j] * 0.01;
                }
            }

            self.states[head] = new_h.clone();

            // Read out the state into this head's output slice: y = C h.
            for (i, out_idx) in (head_start..head_end).enumerate() {
                if out_idx >= output.len() {
                    break;
                }
                let mut c_h = 0.0;
                for j in 0..self.state_dim {
                    if out_idx < self.c_proj.shape()[0] && i < new_h.shape()[0] {
                        c_h += self.c_proj[[out_idx, j]] * new_h[[i, j]];
                    }
                }
                output[out_idx] = c_h;
            }
        }

        // Skip connection: y += D ⊙ x.
        for (i, val) in output.iter_mut().enumerate() {
            if i < self.d_skip.len() && i < x.len() {
                *val += self.d_skip[i] * x[i];
            }
        }

        Ok(output)
    }

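    /// Full block forward pass for a single time step: pre-norm -> causal conv
    /// -> SSD state update -> SiLU gate -> output projection, plus a residual
    /// connection back to the block input.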
    fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        // Pre-normalization.
        let mut h = if let Some(ref norm) = self.norm {
            norm.forward(x)
        } else {
            x.clone()
        };

        // Causal convolution (single-step, stateful).
        let h_vec = h.to_vec();
        let conv_out = self.conv.forward_step(&h_vec);
        h = Array1::from_vec(conv_out);

        // Selective state-space update.
        h = self.ssd_step(&h)?;

        // SiLU gate computed from a linear projection of h.
        let mut gate_vec = Vec::with_capacity(h.len().min(self.hidden_dim));
        for i in 0..h.len().min(self.hidden_dim) {
            let mut sum = 0.0;
            for j in 0..h.len().min(self.hidden_dim) {
                if i < self.gate_proj.shape()[0] && j < self.gate_proj.shape()[1] {
                    sum += self.gate_proj[[i, j]] * h[j];
                }
            }
            gate_vec.push(sum);
        }
        let gate_arr = Array1::from_vec(gate_vec);
        let gate = silu(&gate_arr);

        for i in 0..h.len().min(gate.len()) {
            h[i] *= gate[i];
        }

        // Output projection back to the input width.
        let mut output = Array1::zeros(x.len());
        for i in 0..output.len().min(self.out_proj.shape()[0]) {
            let mut sum = 0.0;
            for j in 0..h.len().min(self.out_proj.shape()[1]) {
                sum += self.out_proj[[i, j]] * h[j];
            }
            output[i] = sum;
        }

        // Residual connection.
        for i in 0..output.len().min(x.len()) {
            output[i] += x[i];
        }

        Ok(output)
    }

    fn reset(&mut self) {
        for state in &mut self.states {
            state.fill(0.0);
        }
    }
}

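/// Mamba-2 autoregressive model: an input projection, a stack of
/// [`Mamba2Layer`] blocks, and an output projection back to the input space.
///
/// A minimal usage sketch (the import path is illustrative; `step` comes from
/// the [`SignalPredictor`] trait):
///
/// ```ignore
/// let config = Mamba2Config::new().hidden_dim(128).num_heads(4).num_layers(2);
/// let mut model = Mamba2::new(config).expect("valid config");
/// let next = model.step(&Array1::from_vec(vec![0.5])).expect("step");
/// ```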
pub struct Mamba2 {
    config: Mamba2Config,
    layers: Vec<Mamba2Layer>,
    input_proj: Array2<f32>,
    output_proj: Array2<f32>,
}

impl Mamba2 {
    /// Create a new model from a validated configuration.
    pub fn new(config: Mamba2Config) -> ModelResult<Self> {
        config.validate()?;

        let mut layers = Vec::with_capacity(config.num_layers);
        for _ in 0..config.num_layers {
            layers.push(Mamba2Layer::new(&config)?);
        }

        // Xavier-style initialization for the input/output projections.
        let mut rng = rng();
        let scale = (2.0 / (config.input_dim + config.hidden_dim) as f32).sqrt();
        let input_proj = Array2::from_shape_fn((config.input_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        let scale = (2.0 / (config.hidden_dim + config.input_dim) as f32).sqrt();
        let output_proj = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        Ok(Self {
            config,
            layers,
            input_proj,
            output_proj,
        })
    }

    /// Borrow the model configuration.
    pub fn config(&self) -> &Mamba2Config {
        &self.config
    }

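    /// Load weights from a [`crate::loader::ModelLoader`].
    ///
    /// Tensors are looked up by the names used below (`input_proj`,
    /// `output_proj`, and `layers.{i}.*`); any tensor that is missing is left
    /// at its random initialization.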
    pub fn load_weights(&mut self, loader: &crate::loader::ModelLoader) -> ModelResult<()> {
        if loader.has_tensor("input_proj") {
            self.input_proj = loader.load_array2("input_proj")?;
        }
        if loader.has_tensor("output_proj") {
            self.output_proj = loader.load_array2("output_proj")?;
        }

        for (i, layer) in self.layers.iter_mut().enumerate() {
            let prefix = format!("layers.{}", i);

            if let Some(ref mut norm) = layer.norm {
                if loader.has_tensor(&format!("{}.norm.weight", prefix)) {
                    let weight = loader.load_array1(&format!("{}.norm.weight", prefix))?;
                    norm.set_gamma(weight);
                }
                if loader.has_tensor(&format!("{}.norm.bias", prefix)) {
                    let bias = loader.load_array1(&format!("{}.norm.bias", prefix))?;
                    norm.set_beta(bias);
                }
            }

            if loader.has_tensor(&format!("{}.conv.weight", prefix)) {
                let conv_weights = loader.load_array3(&format!("{}.conv.weight", prefix))?;
                layer.conv.set_weights(conv_weights);
            }
            if loader.has_tensor(&format!("{}.conv.bias", prefix)) {
                let conv_bias = loader.load_array1(&format!("{}.conv.bias", prefix))?;
                layer.conv.set_bias(conv_bias.to_vec());
            }

            if loader.has_tensor(&format!("{}.a_log", prefix)) {
                layer.a_log = loader.load_array2(&format!("{}.a_log", prefix))?;
            }
            if loader.has_tensor(&format!("{}.b_proj", prefix)) {
                layer.b_proj = loader.load_array2(&format!("{}.b_proj", prefix))?;
            }
            if loader.has_tensor(&format!("{}.c_proj", prefix)) {
                layer.c_proj = loader.load_array2(&format!("{}.c_proj", prefix))?;
            }
            if loader.has_tensor(&format!("{}.d_skip", prefix)) {
                layer.d_skip = loader.load_array1(&format!("{}.d_skip", prefix))?;
            }
            if loader.has_tensor(&format!("{}.gate_proj", prefix)) {
                layer.gate_proj = loader.load_array2(&format!("{}.gate_proj", prefix))?;
            }
            if loader.has_tensor(&format!("{}.out_proj", prefix)) {
                layer.out_proj = loader.load_array2(&format!("{}.out_proj", prefix))?;
            }
        }

        Ok(())
    }

    /// Persist weights to `path`. Not yet implemented.
    #[allow(unused_variables)]
    pub fn save_weights(&self, path: &str) -> ModelResult<()> {
        Err(ModelError::simple_load_error(
            "Mamba2 save_weights not yet implemented".to_string(),
        ))
    }
}

impl SignalPredictor for Mamba2 {
    #[instrument(skip(self, input))]
    fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
        let mut hidden = input.dot(&self.input_proj);

        for layer in &mut self.layers {
            hidden = layer.forward(&hidden)?;
        }

        let output = hidden.dot(&self.output_proj);
        Ok(output)
    }

    fn reset(&mut self) {
        for layer in &mut self.layers {
            layer.reset();
        }
    }

    fn context_window(&self) -> usize {
        // The recurrent state carries information indefinitely, so there is no
        // fixed context window.
        usize::MAX
    }
}

impl AutoregressiveModel for Mamba2 {
    fn hidden_dim(&self) -> usize {
        self.config.hidden_dim
    }

    fn state_dim(&self) -> usize {
        self.config.state_dim
    }

    fn num_layers(&self) -> usize {
        self.config.num_layers
    }

    fn model_type(&self) -> ModelType {
        ModelType::Mamba2
    }

    fn get_states(&self) -> Vec<HiddenState> {
        self.layers
            .iter()
            .map(|layer| {
                let total_size = layer.head_dim * layer.num_heads;
                let mut combined = Array2::zeros((total_size, layer.state_dim));

                for (head_idx, head_state) in layer.states.iter().enumerate() {
                    let start_idx = head_idx * layer.head_dim;
                    for i in 0..layer.head_dim.min(head_state.shape()[0]) {
                        for j in 0..layer.state_dim {
                            combined[[start_idx + i, j]] = head_state[[i, j]];
                        }
                    }
                }

                let mut hs = HiddenState::new(combined.shape()[0], combined.shape()[1]);
                hs.update(combined);
                hs
            })
            .collect()
    }

    fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()> {
        if states.len() != self.config.num_layers {
            return Err(ModelError::state_count_mismatch(
                "Mamba2",
                self.config.num_layers,
                states.len(),
            ));
        }

        for (layer_idx, layer) in self.layers.iter_mut().enumerate() {
            let combined = states[layer_idx].state();

            for (head_idx, head_state) in layer.states.iter_mut().enumerate() {
                let start_idx = head_idx * layer.head_dim;
                for i in 0..layer.head_dim.min(head_state.shape()[0]) {
                    for j in 0..layer.state_dim.min(combined.shape()[1]) {
                        if start_idx + i < combined.shape()[0] {
                            head_state[[i, j]] = combined[[start_idx + i, j]];
                        }
                    }
                }
            }
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mamba2_config() {
        let config = Mamba2Config::new()
            .hidden_dim(512)
            .num_heads(8)
            .num_layers(4);

        assert_eq!(config.hidden_dim, 512);
        assert_eq!(config.num_heads, 8);
        assert_eq!(config.head_dim, 64);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_mamba2_creation() {
        let config = Mamba2Config::new().hidden_dim(256).num_heads(4);
        let model = Mamba2::new(config);
        assert!(model.is_ok());
    }

    #[test]
    fn test_mamba2_forward() {
        let config = Mamba2Config::new()
            .hidden_dim(128)
            .num_heads(4)
            .num_layers(2);
        let mut model = Mamba2::new(config).expect("Failed to create Mamba2 model");

        let input = Array1::from_vec(vec![0.5]);
        let output = model.step(&input);
        assert!(output.is_ok());
    }

    #[test]
    fn test_invalid_config() {
        // 100 is not divisible by 3, so validation must fail.
        let config = Mamba2Config::new().hidden_dim(100).num_heads(3);
        assert!(config.validate().is_err());
    }
}