use crate::{CoreError, CoreResult, HiddenState};
use scirs2_core::ndarray::{s, Array1, Array2, Axis};
use serde::{Deserialize, Serialize};
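/// Configuration for a Mamba-2 style selective state-space (SSM) layer.
///
/// `d_model` is the embedding width, `d_state` the per-head SSM state size,
/// `expand` the inner-width multiplier (`d_inner = d_model * expand`), and
/// `dt_min`/`dt_max` clamp the discretization step produced by the dt
/// projection.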
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Mamba2Config {
pub d_model: usize,
pub d_state: usize,
pub expand: usize,
pub n_heads: usize,
pub d_head: usize,
pub chunk_size: usize,
pub use_bias: bool,
pub dt_rank: usize,
pub dt_min: f32,
pub dt_max: f32,
pub dt_init_scale: f32,
pub layer_norm_eps: f32,
}
impl Default for Mamba2Config {
fn default() -> Self {
Self {
d_model: 256,
d_state: 64,
expand: 2,
n_heads: 8,
d_head: 32,
chunk_size: 256,
use_bias: false,
dt_rank: 32,
dt_min: 0.001,
dt_max: 0.1,
dt_init_scale: 1.0,
layer_norm_eps: 1e-5,
}
}
}
impl Mamba2Config {
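/// Creates a config for the given model and state widths, choosing
/// `n_heads` so that each head is roughly 32 channels wide.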
pub fn new(d_model: usize, d_state: usize) -> Self {
let n_heads = (d_model / 32).max(1);
let d_head = d_model / n_heads;
Self {
d_model,
d_state,
n_heads,
d_head,
..Default::default()
}
}
pub fn with_expand(mut self, expand: usize) -> Self {
self.expand = expand;
self
}
pub fn with_heads(mut self, n_heads: usize) -> Self {
self.n_heads = n_heads;
self.d_head = self.d_model / n_heads;
self
}
pub fn with_chunk_size(mut self, chunk_size: usize) -> Self {
self.chunk_size = chunk_size;
self
}
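/// Inner (expanded) width of the layer: `d_model * expand`.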
pub fn d_inner(&self) -> usize {
self.d_model * self.expand
}
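/// Checks that the dimensions are non-zero, `d_model` is divisible by
/// `n_heads`, and the dt clamp range is well-ordered.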
pub fn validate(&self) -> CoreResult<()> {
if self.d_model == 0 {
return Err(CoreError::InvalidConfig("d_model must be > 0".into()));
}
if self.d_state == 0 {
return Err(CoreError::InvalidConfig("d_state must be > 0".into()));
}
if !self.d_model.is_multiple_of(self.n_heads) {
return Err(CoreError::InvalidConfig(
"d_model must be divisible by n_heads".into(),
));
}
if self.dt_min >= self.dt_max {
return Err(CoreError::InvalidConfig("dt_min must be < dt_max".into()));
}
Ok(())
}
}
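/// A single Mamba-2 layer: input projection, causal convolution, selective
/// SSM step, gating, and output projection, with a persistent recurrent
/// state for step-by-step inference.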
pub struct Mamba2Layer {
config: Mamba2Config,
in_proj_w: Array2<f32>,
out_proj_w: Array2<f32>,
conv1d_w: Array2<f32>,
conv1d_b: Array1<f32>,
a_log: Array2<f32>,
d_param: Array1<f32>,
dt_proj_w: Array2<f32>,
dt_proj_b: Array1<f32>,
b_proj_w: Array2<f32>,
c_proj_w: Array2<f32>,
norm_w: Array1<f32>,
norm_b: Array1<f32>,
hidden_state: HiddenState,
conv_buffer: Array2<f32>,
kernel_size: usize,
}
impl Mamba2Layer {
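/// Builds a layer with projection weights drawn uniformly from
/// `[-0.02, 0.02]`, a constant dt bias of `dt_init_scale / sqrt(dt_rank)`,
/// and a fixed convolution kernel size of 4.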
pub fn new(config: Mamba2Config) -> CoreResult<Self> {
config.validate()?;
let d_inner = config.d_inner();
let kernel_size = 4;
use scirs2_core::random::thread_rng;
let mut rng = thread_rng();
let init_scale = 0.02;
let in_proj_w = Array2::from_shape_fn((2 * d_inner, config.d_model), |_| {
(rng.random::<f32>() - 0.5) * 2.0 * init_scale
});
let out_proj_w = Array2::from_shape_fn((config.d_model, d_inner), |_| {
(rng.random::<f32>() - 0.5) * 2.0 * init_scale
});
let conv1d_w = Array2::from_shape_fn((d_inner, kernel_size), |_| {
(rng.random::<f32>() - 0.5) * 2.0 * init_scale
});
let conv1d_b = Array1::zeros(d_inner);
let a_log = Self::init_a_matrix(config.n_heads, config.d_state)?;
let d_param = Array1::ones(d_inner);
let dt_proj_w = Array2::from_shape_fn((d_inner, config.dt_rank), |_| {
(rng.random::<f32>() - 0.5) * 2.0 * init_scale
});
let dt_proj_b = Array1::from_elem(
d_inner,
config.dt_init_scale / (config.dt_rank as f32).sqrt(),
);
let b_proj_w = Array2::from_shape_fn((config.d_state * config.n_heads, d_inner), |_| {
(rng.random::<f32>() - 0.5) * 2.0 * init_scale
});
let c_proj_w = Array2::from_shape_fn((config.d_state * config.n_heads, d_inner), |_| {
(rng.random::<f32>() - 0.5) * 2.0 * init_scale
});
let norm_w = Array1::ones(config.d_model);
let norm_b = Array1::zeros(config.d_model);
let hidden_state = HiddenState::new(config.n_heads, config.d_state);
let conv_buffer = Array2::zeros((d_inner, kernel_size));
Ok(Self {
config,
in_proj_w,
out_proj_w,
conv1d_w,
conv1d_b,
a_log,
d_param,
dt_proj_w,
dt_proj_b,
b_proj_w,
c_proj_w,
norm_w,
norm_b,
hidden_state,
conv_buffer,
kernel_size,
})
}
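/// Initializes `a_log` with values in `[-ln 2, 0]`, so that
/// `A = -exp(a_log)` lies in `[-1, -0.5]` and the recurrence decays.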
fn init_a_matrix(n_heads: usize, d_state: usize) -> CoreResult<Array2<f32>> {
use scirs2_core::random::thread_rng;
let mut rng = thread_rng();
let a_log = Array2::from_shape_fn((n_heads, d_state), |_| {
-std::f32::consts::LN_2 * rng.random::<f32>()
});
Ok(a_log)
}
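/// Standard layer normalization with learned scale (`norm_w`) and shift
/// (`norm_b`).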
fn layer_norm(&self, x: &Array1<f32>) -> Array1<f32> {
let mean = x.mean().unwrap_or(0.0);
let variance = x.iter().map(|&xi| (xi - mean).powi(2)).sum::<f32>() / (x.len() as f32);
let std = (variance + self.config.layer_norm_eps).sqrt();
let normalized = x.mapv(|xi| (xi - mean) / std);
&normalized * &self.norm_w + &self.norm_b
}
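/// Depthwise causal 1-D convolution over a sliding window.
///
/// The buffer holds the last `kernel_size` inputs per channel, with the
/// newest sample in column 0, so each output depends only on the current
/// and past inputs.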
fn causal_conv1d(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
if x.len() != self.conv_buffer.nrows() {
return Err(CoreError::DimensionMismatch {
expected: self.conv_buffer.nrows(),
got: x.len(),
});
}
for i in (1..self.kernel_size).rev() {
for j in 0..self.conv_buffer.nrows() {
self.conv_buffer[[j, i]] = self.conv_buffer[[j, i - 1]];
}
}
for (j, &val) in x.iter().enumerate() {
self.conv_buffer[[j, 0]] = val;
}
let mut output = Array1::zeros(x.len());
for i in 0..x.len() {
let mut sum = 0.0;
for k in 0..self.kernel_size {
sum += self.conv_buffer[[i, k]] * self.conv1d_w[[i, k]];
}
output[i] = sum + self.conv1d_b[i];
}
Ok(output)
}
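/// Computes the discretization step `dt` from the first `dt_rank`
/// channels of the input: a low-rank projection followed by softplus,
/// clamped to `[dt_min, dt_max]`.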
fn compute_dt(&self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
let dt_rank_size = self.config.dt_rank.min(x.len());
let x_reduced = x.slice(s![..dt_rank_size]).to_owned();
let dt_proj = self.dt_proj_w.slice(s![.., ..dt_rank_size]);
let mut dt = Array1::zeros(self.config.d_inner());
for i in 0..dt.len() {
dt[i] = dt_proj.row(i).dot(&x_reduced) + self.dt_proj_b[i];
}
for dt_i in dt.iter_mut() {
let sp = (1.0 + dt_i.exp()).ln();
*dt_i = sp.clamp(self.config.dt_min, self.config.dt_max);
}
Ok(dt)
}
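/// Projects the input to the input-dependent (selective) SSM parameters
/// `B` and `C`, one `d_state`-sized row per head.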
fn compute_bc(&self, x: &Array1<f32>) -> CoreResult<(Array2<f32>, Array2<f32>)> {
let n_heads = self.config.n_heads;
let d_state = self.config.d_state;
let mut b = Array2::zeros((n_heads, d_state));
let mut c = Array2::zeros((n_heads, d_state));
for h in 0..n_heads {
for n in 0..d_state {
let idx = h * d_state + n;
if idx < self.b_proj_w.nrows() {
let b_row = self.b_proj_w.row(idx);
let c_row = self.c_proj_w.row(idx);
let x_len = x.len().min(b_row.len());
b[[h, n]] = b_row.slice(s![..x_len]).dot(&x.slice(s![..x_len]));
c[[h, n]] = c_row.slice(s![..x_len]).dot(&x.slice(s![..x_len]));
}
}
}
Ok((b, c))
}
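/// Discretizes the continuous SSM parameters: `a_bar = exp(A * dt)` with
/// the simplified Euler form `b_bar = B * dt`. `dt` is indexed per
/// (head, state) pair, falling back to `dt_min` past its length.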
fn discretize(
&self,
dt: &Array1<f32>,
a: &Array2<f32>,
b: &Array2<f32>,
) -> CoreResult<(Array2<f32>, Array2<f32>)> {
let n_heads = self.config.n_heads;
let d_state = self.config.d_state;
let mut a_bar = Array2::zeros((n_heads, d_state));
let mut b_bar = Array2::zeros((n_heads, d_state));
for head_idx in 0..n_heads {
for state_idx in 0..d_state {
let dt_idx = head_idx * d_state + state_idx;
let dt_val = if dt_idx < dt.len() {
dt[dt_idx]
} else {
self.config.dt_min
};
a_bar[[head_idx, state_idx]] = (a[[head_idx, state_idx]] * dt_val).exp();
b_bar[[head_idx, state_idx]] = b[[head_idx, state_idx]] * dt_val;
}
}
Ok((a_bar, b_bar))
}
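/// One recurrent SSM step per head:
/// `h' = a_bar ⊙ h + b_bar * x_mean` and `y = C · h'`, where `x_mean` is
/// the mean over the head's input channels and the scalar output is
/// broadcast back over those channels.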
fn ssm_step(
&mut self,
x: &Array1<f32>,
a_bar: &Array2<f32>,
b_bar: &Array2<f32>,
c: &Array2<f32>,
) -> CoreResult<Array1<f32>> {
let n_heads = self.config.n_heads;
let d_state = self.config.d_state;
let d_inner = x.len();
let d_head = d_inner / n_heads;
let h_state = self.hidden_state.state();
let mut h_new = Array2::zeros((n_heads, d_state));
for head_idx in 0..n_heads {
for state_idx in 0..d_state {
let mut x_contribution = 0.0;
for i in 0..d_head {
let x_idx = head_idx * d_head + i;
if x_idx < x.len() {
x_contribution += x[x_idx];
}
}
x_contribution /= d_head as f32;
h_new[[head_idx, state_idx]] = a_bar[[head_idx, state_idx]]
* h_state[[head_idx, state_idx]]
+ b_bar[[head_idx, state_idx]] * x_contribution;
}
}
let mut y = Array1::zeros(d_inner);
for head_idx in 0..n_heads {
for i in 0..d_head {
let y_idx = head_idx * d_head + i;
if y_idx < d_inner {
let mut state_sum = 0.0;
for state_idx in 0..d_state {
state_sum += c[[head_idx, state_idx]] * h_new[[head_idx, state_idx]];
}
y[y_idx] = state_sum;
}
}
}
self.hidden_state.update(h_new);
Ok(y)
}
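/// Full layer forward pass for one timestep:
/// layer norm -> input projection (split into `x` and gate `z`) ->
/// causal conv + SiLU -> selective SSM step -> D skip connection ->
/// SiLU(z) gating -> output projection -> residual add.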
pub fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
if x.len() != self.config.d_model {
return Err(CoreError::DimensionMismatch {
expected: self.config.d_model,
got: x.len(),
});
}
let x_norm = self.layer_norm(x);
let xz = self.in_proj_w.dot(&x_norm);
let d_inner = self.config.d_inner();
let x_proj = xz.slice(s![..d_inner]).to_owned();
let z = xz.slice(s![d_inner..]).to_owned();
let x_conv = self.causal_conv1d(&x_proj)?;
let x_act = x_conv.mapv(|v| v / (1.0 + (-v).exp()));
let dt = self.compute_dt(&x_act)?;
let (b, c) = self.compute_bc(&x_act)?;
// A = -exp(a_log) must stay negative so that a_bar = exp(A * dt) < 1
// and the hidden state decays instead of growing without bound.
let a = self.a_log.mapv(|v| -v.exp());
let (a_bar, b_bar) = self.discretize(&dt, &a, &b)?;
let y_ssm = self.ssm_step(&x_act, &a_bar, &b_bar, &c)?;
let y_skip = &y_ssm + &(&x_act * &self.d_param);
// Gate with SiLU(z), as in the Mamba gated branch.
let z_act = z.mapv(|v| v / (1.0 + (-v).exp()));
let y_gated = &y_skip * &z_act;
let output = self.out_proj_w.dot(&y_gated);
Ok(&output + x)
}
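/// Runs `forward` on each row of `x` in order. Rows share the layer's
/// recurrent state, so they are treated as consecutive timesteps of one
/// sequence rather than independent batch items.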
pub fn forward_batch(&mut self, x: &Array2<f32>) -> CoreResult<Array2<f32>> {
let batch_size = x.nrows();
let mut outputs = Array2::zeros((batch_size, self.config.d_model));
for (i, input) in x.axis_iter(Axis(0)).enumerate() {
let output = self.forward(&input.to_owned())?;
outputs.row_mut(i).assign(&output);
}
Ok(outputs)
}
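/// Clears the recurrent hidden state and the convolution buffer.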
pub fn reset(&mut self) {
self.hidden_state.reset();
self.conv_buffer.fill(0.0);
}
pub fn config(&self) -> &Mamba2Config {
&self.config
}
}
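/// A stack of [`Mamba2Layer`]s applied in sequence; each layer already
/// adds its own residual connection.
///
/// A minimal usage sketch (marked `ignore`, since the surrounding crate
/// paths are not shown here):
///
/// ```ignore
/// let config = Mamba2Config::new(64, 16);
/// let mut model = Mamba2Model::new(config, 2)?;
/// let x = Array1::from_elem(64, 0.1f32);
/// let y = model.forward(&x)?;
/// assert_eq!(y.len(), 64);
/// ```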
pub struct Mamba2Model {
layers: Vec<Mamba2Layer>,
config: Mamba2Config,
}
impl Mamba2Model {
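/// Builds `n_layers` identical layers from one validated config.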
pub fn new(config: Mamba2Config, n_layers: usize) -> CoreResult<Self> {
let mut layers = Vec::with_capacity(n_layers);
for _ in 0..n_layers {
layers.push(Mamba2Layer::new(config.clone())?);
}
Ok(Self { layers, config })
}
pub fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
let mut hidden = x.clone();
for layer in &mut self.layers {
hidden = layer.forward(&hidden)?;
}
Ok(hidden)
}
pub fn forward_batch(&mut self, x: &Array2<f32>) -> CoreResult<Array2<f32>> {
let mut hidden = x.clone();
for layer in &mut self.layers {
hidden = layer.forward_batch(&hidden)?;
}
Ok(hidden)
}
pub fn reset(&mut self) {
for layer in &mut self.layers {
layer.reset();
}
}
pub fn n_layers(&self) -> usize {
self.layers.len()
}
pub fn config(&self) -> &Mamba2Config {
&self.config
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_mamba2_config() {
let config = Mamba2Config::new(256, 64);
assert_eq!(config.d_model, 256);
assert_eq!(config.d_state, 64);
assert_eq!(config.d_inner(), 512);
assert!(config.validate().is_ok());
}
#[test]
fn test_mamba2_config_validation() {
let mut config = Mamba2Config::new(256, 64);
config.d_model = 0;
assert!(config.validate().is_err());
let mut config = Mamba2Config::new(257, 64);
config.n_heads = 8;
assert!(config.validate().is_err());
let mut config = Mamba2Config::new(256, 64);
config.dt_min = 0.5;
config.dt_max = 0.1;
assert!(config.validate().is_err());
}
#[test]
fn test_mamba2_layer_creation() {
let config = Mamba2Config::new(64, 16);
let result = Mamba2Layer::new(config);
assert!(result.is_ok());
}
#[test]
fn test_mamba2_forward() {
let config = Mamba2Config::new(64, 16);
let mut layer = Mamba2Layer::new(config).unwrap();
let input = Array1::from_vec(vec![0.1; 64]);
let output = layer.forward(&input);
assert!(output.is_ok());
let output = output.unwrap();
assert_eq!(output.len(), 64);
assert!(output.iter().all(|&x| x.is_finite()));
}
#[test]
fn test_mamba2_reset() {
let config = Mamba2Config::new(64, 16);
let mut layer = Mamba2Layer::new(config).unwrap();
let input = Array1::from_vec(vec![0.1; 64]);
layer.forward(&input).unwrap();
layer.forward(&input).unwrap();
layer.reset();
let h_state = layer.hidden_state.state();
assert!(h_state.iter().all(|&x| x == 0.0));
}
#[test]
fn test_mamba2_model() {
let config = Mamba2Config::new(64, 16);
let mut model = Mamba2Model::new(config, 4).unwrap();
assert_eq!(model.n_layers(), 4);
let input = Array1::from_vec(vec![0.1; 64]);
let output = model.forward(&input);
assert!(output.is_ok());
let output = output.unwrap();
assert_eq!(output.len(), 64);
assert!(output.iter().all(|&x| x.is_finite()));
}
#[test]
fn test_mamba2_batch() {
let config = Mamba2Config::new(64, 16);
let mut layer = Mamba2Layer::new(config).unwrap();
let batch = Array2::from_shape_fn((8, 64), |(i, j)| 0.1 * (i as f32 + j as f32));
let output = layer.forward_batch(&batch);
assert!(output.is_ok());
let output = output.unwrap();
assert_eq!(output.shape(), &[8, 64]);
assert!(output.iter().all(|&x| x.is_finite()));
}
#[test]
fn test_mamba2_no_nan() {
let config = Mamba2Config::new(64, 16);
let mut layer = Mamba2Layer::new(config).unwrap();
for _ in 0..10 {
let input = Array1::from_elem(64, 0.5);
let output = layer.forward(&input).unwrap();
assert!(output.iter().all(|&x| !x.is_nan()));
}
}
#[test]
fn test_layer_norm() {
let config = Mamba2Config::new(64, 16);
let layer = Mamba2Layer::new(config).unwrap();
let input = Array1::from_vec((0..64).map(|i| i as f32).collect());
let normalized = layer.layer_norm(&input);
let mean = normalized.mean().unwrap();
let std = (normalized.iter().map(|&x| x.powi(2)).sum::<f32>() / 64.0).sqrt();
assert!((mean.abs()) < 1e-5, "Mean should be close to 0");
assert!((std - 1.0).abs() < 1e-4, "Std should be close to 1");
}
}