use crate::error::{ModelError, ModelResult};
use crate::AutoregressiveModel;
use kizzasi_core::{
    silu, CausalConv1d, CoreResult, HiddenState, LayerNorm, NormType, SignalPredictor,
};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::{rng, Rng};
use tracing::{debug, instrument, trace};

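/// Configuration for the [`Mamba`] model.
///
/// Build one with [`MambaConfig::new`] or a size preset (`tiny` through
/// `xlarge`), then adjust individual fields with the builder methods below.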
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct MambaConfig {
    /// Input feature dimension.
    pub input_dim: usize,
    /// Hidden (model) dimension.
    pub hidden_dim: usize,
    /// SSM state dimension.
    pub state_dim: usize,
    /// Expansion factor for the inner dimension (`inner_dim = hidden_dim * expand_factor`).
    pub expand_factor: usize,
    /// Kernel size of the causal convolution.
    pub conv_kernel_size: usize,
    /// Number of stacked Mamba layers.
    pub num_layers: usize,
    /// Dropout probability.
    pub dropout: f32,
    /// Use the Mamba-2 formulation instead of Mamba-1.
    pub use_mamba2: bool,
}

impl Default for MambaConfig {
    fn default() -> Self {
        Self {
            input_dim: 1,
            hidden_dim: 256,
            state_dim: 16,
            expand_factor: 2,
            conv_kernel_size: 4,
            num_layers: 4,
            dropout: 0.0,
            use_mamba2: true,
        }
    }
}

impl MambaConfig {
    /// Create a configuration with default values.
    pub fn new() -> Self {
        Self::default()
    }

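    /// Tiny preset: hidden_dim 128, state_dim 8, 2 layers, Mamba-1.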
    pub fn tiny(input_dim: usize) -> Self {
        Self {
            input_dim,
            hidden_dim: 128,
            state_dim: 8,
            expand_factor: 2,
            conv_kernel_size: 4,
            num_layers: 2,
            dropout: 0.0,
            use_mamba2: false,
        }
    }

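    /// Small preset: hidden_dim 256, state_dim 16, 4 layers, Mamba-2.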
    pub fn small(input_dim: usize) -> Self {
        Self {
            input_dim,
            hidden_dim: 256,
            state_dim: 16,
            expand_factor: 2,
            conv_kernel_size: 4,
            num_layers: 4,
            dropout: 0.1,
            use_mamba2: true,
        }
    }

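    /// Base preset: hidden_dim 512, state_dim 16, 6 layers, Mamba-2.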
    pub fn base(input_dim: usize) -> Self {
        Self {
            input_dim,
            hidden_dim: 512,
            state_dim: 16,
            expand_factor: 2,
            conv_kernel_size: 4,
            num_layers: 6,
            dropout: 0.1,
            use_mamba2: true,
        }
    }

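    /// Large preset: hidden_dim 1024, state_dim 32, 12 layers, Mamba-2.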
    pub fn large(input_dim: usize) -> Self {
        Self {
            input_dim,
            hidden_dim: 1024,
            state_dim: 32,
            expand_factor: 2,
            conv_kernel_size: 4,
            num_layers: 12,
            dropout: 0.1,
            use_mamba2: true,
        }
    }

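    /// Extra-large preset: hidden_dim 2048, state_dim 64, 24 layers, Mamba-2.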
    pub fn xlarge(input_dim: usize) -> Self {
        Self {
            input_dim,
            hidden_dim: 2048,
            state_dim: 64,
            expand_factor: 2,
            conv_kernel_size: 4,
            num_layers: 24,
            dropout: 0.2,
            use_mamba2: true,
        }
    }

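    /// Set the input dimension (builder style).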
    pub fn input_dim(mut self, dim: usize) -> Self {
        self.input_dim = dim;
        self
    }

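    /// Set the hidden dimension (builder style).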
    pub fn hidden_dim(mut self, dim: usize) -> Self {
        self.hidden_dim = dim;
        self
    }

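    /// Set the SSM state dimension (builder style).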
    pub fn state_dim(mut self, dim: usize) -> Self {
        self.state_dim = dim;
        self
    }

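    /// Set the number of layers (builder style).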
    pub fn num_layers(mut self, n: usize) -> Self {
        self.num_layers = n;
        self
    }

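    /// Toggle the Mamba-2 variant (builder style).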
    pub fn mamba2(mut self, use_mamba2: bool) -> Self {
        self.use_mamba2 = use_mamba2;
        self
    }

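    /// Validate the configuration, rejecting any zero-sized dimension.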
    pub fn validate(&self) -> ModelResult<()> {
        if self.hidden_dim == 0 {
            return Err(ModelError::invalid_config("hidden_dim must be > 0"));
        }
        if self.state_dim == 0 {
            return Err(ModelError::invalid_config("state_dim must be > 0"));
        }
        if self.num_layers == 0 {
            return Err(ModelError::invalid_config("num_layers must be > 0"));
        }
        if self.expand_factor == 0 {
            return Err(ModelError::invalid_config("expand_factor must be > 0"));
        }
        Ok(())
    }
}

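/// Selective state-space scan (the S6 core of Mamba): a discretized
/// recurrence whose step size, input projection, and output projection all
/// depend on the current input.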
struct SelectiveSSM {
    state_dim: usize,
    inner_dim: usize,

    /// State matrix A, stored as `log_a` so that `A = -exp(log_a)` stays
    /// strictly negative.
    log_a: Array1<f32>,

    /// Input-dependent step-size projection and bias (Δ).
    delta_proj: Array2<f32>,
    delta_bias: Array1<f32>,
    /// Input-dependent input projection (B).
    b_proj: Array2<f32>,
    /// Input-dependent output projection (C).
    c_proj: Array2<f32>,
    /// Direct feed-through (skip) term D.
    d_skip: Array1<f32>,
    /// Recurrent state, shape `(inner_dim, state_dim)`.
    state: Array2<f32>,
}

impl SelectiveSSM {
    fn new(config: &MambaConfig) -> ModelResult<Self> {
        let mut rng = rng();
        let inner_dim = config.hidden_dim * config.expand_factor;

        // S4D-real style initialization: A_n = -(n + 1), stored in log space.
        let log_a = Array1::from_shape_fn(config.state_dim, |n| ((n + 1) as f32).ln());

        let scale = (2.0 / inner_dim as f32).sqrt();

        let delta_proj = Array2::from_shape_fn((inner_dim, inner_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let delta_bias = Array1::from_shape_fn(inner_dim, |_| rng.random::<f32>() * 0.1);

        let b_proj = Array2::from_shape_fn((inner_dim, config.state_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        let c_proj = Array2::from_shape_fn((inner_dim, config.state_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        let d_skip = Array1::ones(inner_dim);

        let state = Array2::zeros((inner_dim, config.state_dim));

        Ok(Self {
            state_dim: config.state_dim,
            inner_dim,
            log_a,
            delta_proj,
            delta_bias,
            b_proj,
            c_proj,
            d_skip,
            state,
        })
    }

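    /// Advance the selective scan by one step: discretize `(A, B)` with the
    /// input-dependent step size Δ (zero-order hold), update the recurrent
    /// state, and emit `y = C·h + D·x`.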
    fn forward_step(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        let batch_size = x.len().min(self.inner_dim);

        // Δ = softplus(W_Δ · x + b_Δ), clamped for numerical stability.
        let mut delta = Array1::zeros(batch_size);
        for i in 0..batch_size {
            let mut sum = self.delta_bias[i];
            for j in 0..batch_size {
                sum += self.delta_proj[[i, j]] * x[j];
            }
            let clamped = sum.clamp(-20.0, 20.0);
            delta[i] = (1.0 + clamped.exp()).ln().clamp(1e-6, 0.1);
        }

        // Input-dependent B(x): the selection mechanism on the input side.
        let mut b_vec = Array2::zeros((batch_size, self.state_dim));
        for i in 0..batch_size {
            for n in 0..self.state_dim {
                let mut sum = 0.0;
                for j in 0..batch_size {
                    sum += if j < self.b_proj.shape()[0] && n < self.b_proj.shape()[1] {
                        self.b_proj[[j, n]] * x[j]
                    } else {
                        0.0
                    };
                }
                b_vec[[i, n]] = sum;
            }
        }

        // Input-dependent C(x): the selection mechanism on the output side.
        let mut c_vec = Array2::zeros((batch_size, self.state_dim));
        for i in 0..batch_size {
            for n in 0..self.state_dim {
                let mut sum = 0.0;
                for j in 0..batch_size {
                    sum += if j < self.c_proj.shape()[0] && n < self.c_proj.shape()[1] {
                        self.c_proj[[j, n]] * x[j]
                    } else {
                        0.0
                    };
                }
                c_vec[[i, n]] = sum;
            }
        }

        // Zero-order-hold discretization of A: Ā = exp(Δ · A).
        let mut a_bar = Array2::zeros((batch_size, self.state_dim));
        for i in 0..batch_size {
            for n in 0..self.state_dim {
                let a_n = -self.log_a[n].exp();
                let delta_a = delta[i] * a_n;
                a_bar[[i, n]] = delta_a.clamp(-20.0, 20.0).exp();
            }
        }

        // Discretize B: B̄ = (Ā - 1) / A · B, falling back to the first-order
        // approximation B̄ ≈ Δ · B when Δ is very small.
        let mut b_bar = Array2::zeros((batch_size, self.state_dim));
        for i in 0..batch_size {
            for n in 0..self.state_dim {
                let a_n = -self.log_a[n].exp();

                if delta[i].abs() < 0.001 {
                    b_bar[[i, n]] = delta[i] * b_vec[[i, n]];
                } else {
                    let safe_a_n = if a_n.abs() < 1e-8 { -1.0 } else { a_n };
                    b_bar[[i, n]] = (a_bar[[i, n]] - 1.0) / safe_a_n * b_vec[[i, n]];
                }
            }
        }

        // State recurrence: h' = Ā ⊙ h + B̄ · x.
        let mut new_state = Array2::zeros((batch_size, self.state_dim));
        for i in 0..batch_size {
            for n in 0..self.state_dim {
                let decay = a_bar[[i, n]];
                let input_contrib = b_bar[[i, n]] * x[i];

                new_state[[i, n]] = decay * self.state[[i, n]] + input_contrib;
            }
        }

        for i in 0..batch_size.min(self.state.shape()[0]) {
            for n in 0..self.state_dim {
                self.state[[i, n]] = new_state[[i, n]];
            }
        }

        // Output: y = C · h + D ⊙ x.
        let mut output = Array1::zeros(batch_size);
        for i in 0..batch_size {
            let mut c_h = 0.0;
            for n in 0..self.state_dim {
                c_h += c_vec[[i, n]] * new_state[[i, n]];
            }
            output[i] = c_h + self.d_skip[i] * x[i];
        }

        Ok(output)
    }

    fn reset(&mut self) {
        self.state.fill(0.0);
    }
}

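/// One Mamba block: RMSNorm, input projection into an SSM path and a gate
/// path, causal convolution on the SSM path, selective scan, SiLU gating,
/// output projection, and a residual connection.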
struct MambaLayer {
    hidden_dim: usize,
    inner_dim: usize,

    norm: LayerNorm,

    /// Projects the normalized input to `2 * inner_dim` (SSM path + gate path).
    in_proj: Array2<f32>,
    conv: CausalConv1d,

    ssm: SelectiveSSM,

    /// Projects the gated SSM output back to `hidden_dim`.
    out_proj: Array2<f32>,
}

impl MambaLayer {
    fn new(config: &MambaConfig) -> ModelResult<Self> {
        let inner_dim = config.hidden_dim * config.expand_factor;
        let mut rng = rng();

        let norm = LayerNorm::new(config.hidden_dim, NormType::RMSNorm).with_eps(1e-5);

        let scale = (2.0 / config.hidden_dim as f32).sqrt();
        let in_proj = Array2::from_shape_fn((config.hidden_dim, inner_dim * 2), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        let conv = CausalConv1d::new(inner_dim, inner_dim, config.conv_kernel_size);

        let ssm = SelectiveSSM::new(config)?;

        let scale = (2.0 / inner_dim as f32).sqrt();
        let out_proj = Array2::from_shape_fn((inner_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        Ok(Self {
            hidden_dim: config.hidden_dim,
            inner_dim,
            norm,
            in_proj,
            conv,
            ssm,
            out_proj,
        })
    }

    fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        let batch_size = x.len().min(self.hidden_dim);

        let x_norm = self.norm.forward(x);

        // Project into the concatenated SSM and gate paths.
        let mut projected = Array1::zeros(self.inner_dim * 2);
        for i in 0..(self.inner_dim * 2) {
            let mut sum = 0.0;
            for j in 0..batch_size {
                if i < self.in_proj.shape()[1] {
                    sum += self.in_proj[[j, i]] * x_norm[j];
                }
            }
            projected[i] = sum;
        }

        // Split into the SSM input and the gate input.
        let mut x_ssm = Array1::zeros(self.inner_dim);
        let mut x_gate = Array1::zeros(self.inner_dim);
        for i in 0..self.inner_dim {
            x_ssm[i] = projected[i];
            x_gate[i] = projected[self.inner_dim + i];
        }

        // Causal convolution over the SSM path.
        let x_ssm_vec = x_ssm.to_vec();
        let conv_out = self.conv.forward_step(&x_ssm_vec);
        x_ssm = Array1::from_vec(conv_out);

        let ssm_out = self.ssm.forward_step(&x_ssm)?;

        // SiLU gate, applied elementwise to the SSM output.
        let gate = silu(&x_gate);

        let mut gated = Array1::zeros(ssm_out.len().min(gate.len()));
        for i in 0..gated.len() {
            gated[i] = ssm_out[i] * gate[i];
        }

        // Project back to the hidden dimension.
        let mut output = Array1::zeros(batch_size);
        for i in 0..batch_size {
            let mut sum = 0.0;
            for j in 0..gated.len().min(self.out_proj.shape()[0]) {
                sum += self.out_proj[[j, i]] * gated[j];
            }
            output[i] = sum;
        }

        // Residual connection.
        for i in 0..output.len().min(x.len()) {
            output[i] += x[i];
        }

        Ok(output)
    }

    fn reset(&mut self) {
        self.ssm.reset();
        self.conv.reset();
    }
}

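/// Mamba selective state-space model for streaming signal prediction.
///
/// A minimal usage sketch, mirroring the unit tests below (imports are
/// assumed, hence `ignore`):
///
/// ```ignore
/// let config = MambaConfig::new().input_dim(3).hidden_dim(64).num_layers(2);
/// let mut model = Mamba::new(config)?;
/// let output = model.step(&Array1::from_vec(vec![0.1, 0.2, 0.3]))?;
/// assert_eq!(output.len(), 3);
/// ```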
pub struct Mamba {
    config: MambaConfig,
    layers: Vec<MambaLayer>,
    input_proj: Array2<f32>,
    output_proj: Array2<f32>,
}

impl Mamba {
    #[instrument(skip(config), fields(input_dim = config.input_dim, hidden_dim = config.hidden_dim, num_layers = config.num_layers))]
    pub fn new(config: MambaConfig) -> ModelResult<Self> {
        debug!("Creating new Mamba model");
        config.validate()?;

        let mut layers = Vec::with_capacity(config.num_layers);
        for layer_idx in 0..config.num_layers {
            trace!("Initializing Mamba layer {}", layer_idx);
            layers.push(MambaLayer::new(&config)?);
        }
        debug!("Initialized {} Mamba layers", layers.len());

        let mut rng = rng();
        let scale = (2.0 / (config.input_dim + config.hidden_dim) as f32).sqrt();
        let input_proj = Array2::from_shape_fn((config.input_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        let scale = (2.0 / (config.hidden_dim + config.input_dim) as f32).sqrt();
        let output_proj = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        debug!("Mamba model created successfully");
        Ok(Self {
            config,
            layers,
            input_proj,
            output_proj,
        })
    }

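    /// Load weights from a [`crate::loader::ModelLoader`]. Tensors missing
    /// from the checkpoint keep their random initialization.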
    pub fn load_weights(&mut self, loader: &crate::loader::ModelLoader) -> ModelResult<()> {
        if loader.has_tensor("input_proj") {
            self.input_proj = loader.load_array2("input_proj")?;
        }
        if loader.has_tensor("output_proj") {
            self.output_proj = loader.load_array2("output_proj")?;
        }

        for (i, layer) in self.layers.iter_mut().enumerate() {
            let prefix = format!("layers.{}", i);

            if loader.has_tensor(&format!("{}.norm.weight", prefix)) {
                // Norm weights exist in checkpoints but are not yet applied.
                let _weight = loader.load_array1(&format!("{}.norm.weight", prefix))?;
            }

            if loader.has_tensor(&format!("{}.in_proj", prefix)) {
                layer.in_proj = loader.load_array2(&format!("{}.in_proj", prefix))?;
            }

            if loader.has_tensor(&format!("{}.conv.weight", prefix)) {
                // Conv weights are not yet loadable; the layer keeps its
                // initialization.
            }

            if loader.has_tensor(&format!("{}.ssm.log_a", prefix)) {
                layer.ssm.log_a = loader.load_array1(&format!("{}.ssm.log_a", prefix))?;
            }
            if loader.has_tensor(&format!("{}.ssm.delta_proj", prefix)) {
                layer.ssm.delta_proj = loader.load_array2(&format!("{}.ssm.delta_proj", prefix))?;
            }
            if loader.has_tensor(&format!("{}.ssm.delta_bias", prefix)) {
                layer.ssm.delta_bias = loader.load_array1(&format!("{}.ssm.delta_bias", prefix))?;
            }
            if loader.has_tensor(&format!("{}.ssm.b_proj", prefix)) {
                layer.ssm.b_proj = loader.load_array2(&format!("{}.ssm.b_proj", prefix))?;
            }
            if loader.has_tensor(&format!("{}.ssm.c_proj", prefix)) {
                layer.ssm.c_proj = loader.load_array2(&format!("{}.ssm.c_proj", prefix))?;
            }
            if loader.has_tensor(&format!("{}.ssm.d_skip", prefix)) {
                layer.ssm.d_skip = loader.load_array1(&format!("{}.ssm.d_skip", prefix))?;
            }

            if loader.has_tensor(&format!("{}.out_proj", prefix)) {
                layer.out_proj = loader.load_array2(&format!("{}.out_proj", prefix))?;
            }
        }

        Ok(())
    }

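    /// Persist weights to disk. Not yet implemented; currently always
    /// returns an error.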
    pub fn save_weights<P: AsRef<std::path::Path>>(&self, _path: P) -> ModelResult<()> {
        Err(ModelError::simple_load_error(
            "save_weights not yet implemented".to_string(),
        ))
    }

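    /// The configuration this model was built with.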
    pub fn config(&self) -> &MambaConfig {
        &self.config
    }
}

impl SignalPredictor for Mamba {
    #[instrument(skip(self, input), fields(input_size = input.len()))]
    fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
        trace!(
            "Mamba step input range: [{}, {}]",
            input.iter().cloned().fold(f32::INFINITY, f32::min),
            input.iter().cloned().fold(f32::NEG_INFINITY, f32::max)
        );

        let mut hidden = input.dot(&self.input_proj);
        trace!("After input projection: hidden_dim={}", hidden.len());

        for (layer_idx, layer) in self.layers.iter_mut().enumerate() {
            trace!("Processing Mamba layer {}", layer_idx);
            hidden = layer.forward(&hidden)?;
        }

        let output = hidden.dot(&self.output_proj);
        trace!(
            "Mamba step output range: [{}, {}]",
            output.iter().cloned().fold(f32::INFINITY, f32::min),
            output.iter().cloned().fold(f32::NEG_INFINITY, f32::max)
        );
        Ok(output)
    }

    #[instrument(skip(self))]
    fn reset(&mut self) {
        debug!("Resetting Mamba model state");
        for (layer_idx, layer) in self.layers.iter_mut().enumerate() {
            trace!("Resetting layer {}", layer_idx);
            layer.reset();
        }
    }

    fn context_window(&self) -> usize {
        // The recurrent state carries information indefinitely, so there is
        // no fixed context window.
        usize::MAX
    }
}

impl AutoregressiveModel for Mamba {
    fn hidden_dim(&self) -> usize {
        self.config.hidden_dim
    }

    fn state_dim(&self) -> usize {
        self.config.state_dim
    }

    fn num_layers(&self) -> usize {
        self.config.num_layers
    }

    fn model_type(&self) -> crate::ModelType {
        if self.config.use_mamba2 {
            crate::ModelType::Mamba2
        } else {
            crate::ModelType::Mamba
        }
    }

    fn get_states(&self) -> Vec<HiddenState> {
        self.layers
            .iter()
            .map(|layer| {
                let state = layer.ssm.state.clone();
                let mut hs = HiddenState::new(state.shape()[0], state.shape()[1]);
                hs.update(state);
                let conv_history = layer.conv.get_history();
                hs.set_conv_history(conv_history);
                hs
            })
            .collect()
    }

    fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()> {
        if states.len() != self.config.num_layers {
            return Err(ModelError::state_count_mismatch(
                "Mamba",
                self.config.num_layers,
                states.len(),
            ));
        }

        for (layer_idx, layer) in self.layers.iter_mut().enumerate() {
            layer.ssm.state = states[layer_idx].state().clone();
            if let Some(conv_history) = states[layer_idx].conv_history() {
                layer.conv.set_history(conv_history.clone());
            }
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mamba_creation() {
        let config = MambaConfig::new()
            .input_dim(3)
            .hidden_dim(64)
            .state_dim(8)
            .num_layers(2);

        let mamba = Mamba::new(config);
        assert!(mamba.is_ok());
    }

    #[test]
    fn test_mamba_step() {
        let config = MambaConfig::new()
            .input_dim(3)
            .hidden_dim(32)
            .state_dim(8)
            .num_layers(2);

        let mut mamba = Mamba::new(config).expect("Failed to create Mamba model");
        let input = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let output = mamba.step(&input);

        assert!(output.is_ok());
        assert_eq!(output.expect("Failed to get output").len(), 3);
    }

    #[test]
    fn test_mamba_tiny_config() {
        let config = MambaConfig::tiny(4);
        assert_eq!(config.hidden_dim, 128);
        assert_eq!(config.state_dim, 8);
        assert_eq!(config.num_layers, 2);
        assert!(!config.use_mamba2);

        let mut model = Mamba::new(config).expect("Failed to create Mamba model");
        let input = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let output = model.step(&input).expect("Failed to get output");
        assert_eq!(output.len(), 4);
    }

    #[test]
    fn test_mamba_small_config() {
        let config = MambaConfig::small(4);
        assert_eq!(config.hidden_dim, 256);
        assert_eq!(config.state_dim, 16);
        assert_eq!(config.num_layers, 4);
        assert!(config.use_mamba2);

        // Step a smaller model so the test stays fast; the preset values
        // themselves are checked above.
        let minimal_config = MambaConfig::new()
            .input_dim(4)
            .hidden_dim(64)
            .state_dim(8)
            .num_layers(2);
        let mut model = Mamba::new(minimal_config).expect("Failed to create Mamba model");
        let input = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let output = model.step(&input).expect("Failed to get output");
        assert_eq!(output.len(), 4);
    }

    #[test]
    fn test_mamba_base_config() {
        let config = MambaConfig::base(4);
        assert_eq!(config.hidden_dim, 512);
        assert_eq!(config.state_dim, 16);
        assert_eq!(config.num_layers, 6);
        assert!(config.use_mamba2);

        // As above, step a smaller model to keep the test fast.
        let minimal_config = MambaConfig::new()
            .input_dim(4)
            .hidden_dim(64)
            .state_dim(8)
            .num_layers(2);
        let mut model = Mamba::new(minimal_config).expect("Failed to create Mamba model");
        let input = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let output = model.step(&input).expect("Failed to get output");
        assert_eq!(output.len(), 4);
    }

    #[test]
    #[ignore] // slow: builds a large model
    fn test_mamba_large_config() {
        let config = MambaConfig::large(4);
        assert_eq!(config.hidden_dim, 1024);
        assert_eq!(config.state_dim, 32);
        assert_eq!(config.num_layers, 12);
        assert!(config.use_mamba2);

        let mut model = Mamba::new(config).expect("Failed to create Mamba model");
        let input = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let output = model.step(&input).expect("Failed to get output");
        assert_eq!(output.len(), 4);
    }

    #[test]
    #[ignore] // slow: builds an extra-large model
    fn test_mamba_xlarge_config() {
        let config = MambaConfig::xlarge(2);
        assert_eq!(config.hidden_dim, 2048);
        assert_eq!(config.state_dim, 64);
        assert_eq!(config.num_layers, 24);
        assert!(config.use_mamba2);

        let model = Mamba::new(config);
        assert!(model.is_ok());
    }

    #[test]
    fn test_preset_configs_size_progression() {
        let tiny = MambaConfig::tiny(1);
        let small = MambaConfig::small(1);
        let base = MambaConfig::base(1);
        let large = MambaConfig::large(1);
        let xlarge = MambaConfig::xlarge(1);

        assert!(tiny.hidden_dim < small.hidden_dim);
        assert!(small.hidden_dim < base.hidden_dim);
        assert!(base.hidden_dim < large.hidden_dim);
        assert!(large.hidden_dim < xlarge.hidden_dim);

        assert!(tiny.num_layers <= small.num_layers);
        assert!(small.num_layers <= base.num_layers);
        assert!(base.num_layers <= large.num_layers);
        assert!(large.num_layers <= xlarge.num_layers);
    }
}