//! A minimal, dependency-free Mamba (selective state space model) block
//! operating on `Vec<Vec<f64>>` sequences.

use crate::error::{Result, TimeSeriesError};
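/// Hyperparameters for a Mamba (selective state space) block.
///
/// `dt_min` and `dt_max` bound the softplus-activated step size Δ,
/// `dt_rank` is the rank of the Δ projection, and `conv_dim` is the
/// causal depthwise convolution kernel width.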
#[derive(Debug, Clone)]
pub struct MambaConfig {
pub model_dim: usize,
pub state_dim: usize,
pub expand_factor: usize,
pub dt_rank: usize,
pub dt_min: f64,
pub dt_max: f64,
pub conv_dim: usize,
}
impl Default for MambaConfig {
fn default() -> Self {
Self {
model_dim: 64,
state_dim: 16,
expand_factor: 2,
dt_rank: 4,
dt_min: 0.001,
dt_max: 0.1,
conv_dim: 4,
}
}
}
impl MambaConfig {
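    /// Width of the expanded inner branch: `expand_factor * model_dim`.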
pub fn inner_dim(&self) -> usize {
self.expand_factor * self.model_dim
}
}
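/// A single Mamba block with plain `Vec`-based weights.
///
/// `a_log` stores the state matrix in log space (`A = -exp(a_log)`), and
/// `d_param` is the per-channel skip connection.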
#[derive(Debug, Clone)]
pub struct MambaBlock {
pub in_proj_weight: Vec<Vec<f64>>,
pub conv1d_weight: Vec<f64>,
pub x_proj_weight: Vec<Vec<f64>>,
pub dt_proj_weight: Vec<Vec<f64>>,
pub dt_proj_bias: Vec<f64>,
pub a_log: Vec<Vec<f64>>,
pub d_param: Vec<f64>,
pub out_proj_weight: Vec<Vec<f64>>,
pub config: MambaConfig,
}
impl MambaBlock {
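    /// Builds a block with a fixed default seed (42); see [`Self::new_with_seed`].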
pub fn new(config: MambaConfig) -> Result<Self> {
Self::new_with_seed(config, 42)
}
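    /// Builds a block whose weights are drawn from a deterministic xorshift
    /// RNG seeded by `seed`, scaled by 1/sqrt(fan_in) per projection.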
pub fn new_with_seed(config: MambaConfig, seed: u64) -> Result<Self> {
let d = config.model_dim;
let n = config.state_dim;
let inner = config.inner_dim();
let dt_rank = config.dt_rank;
let conv_dim = config.conv_dim;
        if d == 0 || n == 0 || inner == 0 || dt_rank == 0 || conv_dim == 0 {
            return Err(TimeSeriesError::InvalidInput(
                "all Mamba dimensions must be > 0".into(),
            ));
        }
        // clamp() panics when dt_min > dt_max; Δ must also be strictly positive.
        if !(config.dt_min > 0.0 && config.dt_min <= config.dt_max) {
            return Err(TimeSeriesError::InvalidInput(
                "dt_min must be positive and dt_min <= dt_max".into(),
            ));
        }
        // xorshift64 must not be seeded with 0, or it stays at 0 forever.
        let mut rng_state = if seed == 0 { 0x9E37_79B9_7F4A_7C15 } else { seed };
let in_scale = 1.0 / (d as f64).sqrt();
let inner_scale = 1.0 / (inner as f64).sqrt();
let dt_scale = 1.0 / (dt_rank as f64).sqrt();
let in_proj_weight = xorshift_matrix(&mut rng_state, 2 * inner, d, in_scale);
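        // One kernel shared by all channels; full Mamba uses a separate
        // depthwise kernel per channel.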
let conv1d_weight =
xorshift_vec(&mut rng_state, conv_dim, 1.0 / (conv_dim as f64).sqrt());
let x_proj_weight =
xorshift_matrix(&mut rng_state, dt_rank + 2 * n, inner, inner_scale);
let dt_proj_weight = xorshift_matrix(&mut rng_state, inner, dt_rank, dt_scale);
let dt_proj_bias = xorshift_vec(&mut rng_state, inner, 0.01);
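        // S4D-real style init: the i-th state of every channel becomes
        // A = -(i + 1) after the -exp() applied in forward().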
let a_log: Vec<Vec<f64>> = (0..n)
.map(|i| vec![((i + 1) as f64).ln(); inner])
.collect();
let d_param = vec![1.0_f64; inner];
let out_proj_weight = xorshift_matrix(&mut rng_state, d, inner, inner_scale);
Ok(Self {
in_proj_weight,
conv1d_weight,
x_proj_weight,
dt_proj_weight,
dt_proj_bias,
a_log,
d_param,
out_proj_weight,
config,
})
}
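    /// Runs the block over a `seq_len x model_dim` input and returns an output
    /// of the same shape. Pipeline: in_proj -> split x/z -> causal conv + SiLU
    /// -> data-dependent (Δ, B, C) -> selective scan -> SiLU(z) gating -> out_proj.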
pub fn forward(&self, input: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
let seq_len = input.len();
if seq_len == 0 {
return Ok(vec![]);
}
let d = self.config.model_dim;
let inner = self.config.inner_dim();
let n = self.config.state_dim;
let dt_rank = self.config.dt_rank;
        for row in input {
if row.len() != d {
return Err(TimeSeriesError::DimensionMismatch {
expected: d,
actual: row.len(),
});
}
}
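        // in_proj maps each row to 2*inner columns: the SSM branch x and the gate branch z.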
let projected = linear_batch(input, &self.in_proj_weight)?;
let mut x_branch: Vec<Vec<f64>> = Vec::with_capacity(seq_len);
let mut z_branch: Vec<Vec<f64>> = Vec::with_capacity(seq_len);
for row in &projected {
x_branch.push(row[..inner].to_vec());
z_branch.push(row[inner..].to_vec());
}
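        // Causal depthwise convolution followed by SiLU on the SSM branch.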
let x_conv = depthwise_conv1d(&x_branch, &self.conv1d_weight)?;
let x_act = silu_batch(&x_conv);
let x_dbc = linear_batch(&x_act, &self.x_proj_weight)?;
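        // x_proj yields, per timestep, Δ_raw (dt_rank), B (n), and C (n), concatenated.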
let mut delta_raw: Vec<Vec<f64>> = Vec::with_capacity(seq_len);
let mut b_seq: Vec<Vec<f64>> = Vec::with_capacity(seq_len);
let mut c_seq: Vec<Vec<f64>> = Vec::with_capacity(seq_len);
for row in &x_dbc {
delta_raw.push(row[..dt_rank].to_vec());
b_seq.push(row[dt_rank..(dt_rank + n)].to_vec());
c_seq.push(row[(dt_rank + n)..(dt_rank + 2 * n)].to_vec());
}
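        // Δ = clamp(softplus(dt_proj(Δ_raw)), dt_min, dt_max): positive, bounded step sizes.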
let delta_proj = linear_batch_bias(&delta_raw, &self.dt_proj_weight, &self.dt_proj_bias)?;
let delta: Vec<Vec<f64>> = delta_proj
.iter()
.map(|row| {
row.iter()
.map(|&v| {
let sp = softplus(v);
sp.clamp(self.config.dt_min, self.config.dt_max)
})
.collect()
})
.collect();
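        // A = -exp(a_log): negative entries keep the recurrence exp(Δ·A) contractive.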
let a: Vec<Vec<f64>> = self
.a_log
.iter()
.map(|row| row.iter().map(|&v| -v.exp()).collect())
.collect();
let y_ssm = selective_scan(&x_act, &delta, &a, &b_seq, &c_seq, &self.d_param)?;
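        // Gate the scan output with SiLU(z), then project back to model_dim.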
let z_act = silu_batch(&z_branch);
let mut gated: Vec<Vec<f64>> = Vec::with_capacity(seq_len);
for t in 0..seq_len {
let row: Vec<f64> = y_ssm[t]
.iter()
.zip(z_act[t].iter())
.map(|(y, z)| y * z)
.collect();
gated.push(row);
}
let output = linear_batch(&gated, &self.out_proj_weight)?;
Ok(output)
}
}
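/// One step of the xorshift64 PRNG, mapped to an approximately uniform f64 in [-1, 1].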
fn xorshift_next(state: &mut u64) -> f64 {
*state ^= *state << 13;
*state ^= *state >> 7;
*state ^= *state << 17;
((*state as f64) / (u64::MAX as f64)) * 2.0 - 1.0
}
fn xorshift_matrix(state: &mut u64, rows: usize, cols: usize, scale: f64) -> Vec<Vec<f64>> {
(0..rows)
.map(|_| (0..cols).map(|_| xorshift_next(state) * scale).collect())
.collect()
}
fn xorshift_vec(state: &mut u64, len: usize, scale: f64) -> Vec<f64> {
(0..len).map(|_| xorshift_next(state) * scale).collect()
}
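/// Sequential selective scan over a discretized state space model.
///
/// Per timestep `t`, state `i`, channel `j`:
///
///   x[i][j] = exp(a[i][j] * delta[t][j]) * x[i][j]
///           + (delta[t][j] * b[t][i]) * u[t][j]
///   y[t][j] = sum_i c[t][i] * x[i][j] + d[j] * u[t][j]
///
/// `u` is the input (seq_len x dim), `a` the state matrix (state_dim x dim),
/// `b` and `c` the per-timestep input/output mixing vectors (seq_len x
/// state_dim), `delta` the step sizes (seq_len x dim), and `d` the skip
/// connection (dim). A is discretized with `exp(delta * a)` and B with the
/// first-order approximation `delta * b`, as in Mamba.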
pub fn selective_scan(
u: &[Vec<f64>],
delta: &[Vec<f64>],
a: &[Vec<f64>],
b: &[Vec<f64>],
c: &[Vec<f64>],
d: &[f64],
) -> Result<Vec<Vec<f64>>> {
let seq_len = u.len();
if seq_len == 0 {
return Ok(vec![]);
}
let dim = u[0].len();
let state_dim = a.len();
if d.len() != dim {
return Err(TimeSeriesError::DimensionMismatch {
expected: dim,
actual: d.len(),
});
}
    if delta.len() != seq_len || b.len() != seq_len || c.len() != seq_len {
        return Err(TimeSeriesError::InvalidInput(
            "delta, b, c must have the same seq_len as u".into(),
        ));
    }
    if a.iter().any(|row| row.len() != dim) {
        return Err(TimeSeriesError::InvalidInput(
            "every row of a must match the feature dim of u".into(),
        ));
    }
let mut x = vec![vec![0.0_f64; dim]; state_dim];
let mut output = Vec::with_capacity(seq_len);
for t in 0..seq_len {
let u_t = &u[t];
let delta_t = &delta[t];
let b_t = &b[t];
let c_t = &c[t];
for i in 0..state_dim {
for j in 0..dim {
let a_bar = (a[i][j] * delta_t[j]).exp();
let b_bar = delta_t[j] * b_t[i];
x[i][j] = a_bar * x[i][j] + b_bar * u_t[j];
}
}
let mut y_t = vec![0.0_f64; dim];
for j in 0..dim {
let mut val = d[j] * u_t[j];
for i in 0..state_dim {
val += c_t[i] * x[i][j];
}
y_t[j] = val;
}
output.push(y_t);
}
Ok(output)
}
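/// Applies `weight` (out_dim x in_dim) to every input row: `out = weight * row`.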
fn linear_batch(input: &[Vec<f64>], weight: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
let out_dim = weight.len();
if out_dim == 0 {
return Ok(input.iter().map(|_| vec![]).collect());
}
let in_dim = weight[0].len();
let mut output = Vec::with_capacity(input.len());
for row in input {
if row.len() != in_dim {
return Err(TimeSeriesError::DimensionMismatch {
expected: in_dim,
actual: row.len(),
});
}
let mut out = vec![0.0_f64; out_dim];
for i in 0..out_dim {
let mut s = 0.0_f64;
for j in 0..in_dim {
s += weight[i][j] * row[j];
}
out[i] = s;
}
output.push(out);
}
Ok(output)
}
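/// Like [`linear_batch`], then adds `bias` (one entry per output column) to every row.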
fn linear_batch_bias(
input: &[Vec<f64>],
weight: &[Vec<f64>],
bias: &[f64],
) -> Result<Vec<Vec<f64>>> {
let mut output = linear_batch(input, weight)?;
    if bias.len() != weight.len() {
        return Err(TimeSeriesError::DimensionMismatch {
            expected: weight.len(),
            actual: bias.len(),
        });
    }
    for row in &mut output {
        for (val, &b) in row.iter_mut().zip(bias.iter()) {
            *val += b;
        }
    }
Ok(output)
}
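/// Causal depthwise 1-D convolution with implicit left zero-padding: the
/// output at time `t` mixes inputs at `t, t-1, ..., t-k+1`, never the future.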
fn depthwise_conv1d(input: &[Vec<f64>], kernel: &[f64]) -> Result<Vec<Vec<f64>>> {
let seq_len = input.len();
if seq_len == 0 {
return Ok(vec![]);
}
let channels = input[0].len();
let mut output = Vec::with_capacity(seq_len);
for t in 0..seq_len {
let mut row = vec![0.0_f64; channels];
for (ki, &kv) in kernel.iter().enumerate() {
if t >= ki {
for ch in 0..channels {
row[ch] += kv * input[t - ki][ch];
}
}
}
output.push(row);
}
Ok(output)
}
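/// SiLU (swish) activation: `x * sigmoid(x)`.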
#[inline]
fn silu(x: f64) -> f64 {
x / (1.0 + (-x).exp())
}
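/// Numerically stable softplus `ln(1 + e^x)`, saturating to `x` above 20 and to 0 below -20.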
#[inline]
fn softplus(x: f64) -> f64 {
if x > 20.0 {
x
} else if x < -20.0 {
0.0
} else {
(1.0 + x.exp()).ln()
}
}
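/// Applies [`silu`] element-wise to each row.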
fn silu_batch(input: &[Vec<f64>]) -> Vec<Vec<f64>> {
input
.iter()
.map(|row| row.iter().map(|&v| silu(v)).collect())
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_selective_scan_output_shape() {
let seq_len = 10;
let dim = 4;
let state_dim = 3;
let u: Vec<Vec<f64>> = vec![vec![1.0; dim]; seq_len];
let delta: Vec<Vec<f64>> = vec![vec![0.01; dim]; seq_len];
let a: Vec<Vec<f64>> = vec![vec![-1.0; dim]; state_dim];
let b: Vec<Vec<f64>> = vec![vec![1.0; state_dim]; seq_len];
let c: Vec<Vec<f64>> = vec![vec![1.0; state_dim]; seq_len];
let d = vec![0.5; dim];
let output = selective_scan(&u, &delta, &a, &b, &c, &d).expect("should succeed");
assert_eq!(output.len(), seq_len);
assert_eq!(output[0].len(), dim);
}
#[test]
fn test_selective_scan_zero_delta() {
let seq_len = 5;
let dim = 3;
let state_dim = 2;
let u: Vec<Vec<f64>> = vec![vec![2.0; dim]; seq_len];
let delta: Vec<Vec<f64>> = vec![vec![0.0; dim]; seq_len];
let a: Vec<Vec<f64>> = vec![vec![-1.0; dim]; state_dim];
let b: Vec<Vec<f64>> = vec![vec![1.0; state_dim]; seq_len];
let c: Vec<Vec<f64>> = vec![vec![1.0; state_dim]; seq_len];
let d = vec![0.5; dim];
let output = selective_scan(&u, &delta, &a, &b, &c, &d).expect("should succeed");
for t in 0..seq_len {
for j in 0..dim {
                let expected = d[j] * u[t][j];
                assert!(
(output[t][j] - expected).abs() < 1e-12,
"t={} j={}: got {} expected {}",
t,
j,
output[t][j],
expected
);
}
}
}
#[test]
fn test_selective_scan_empty() {
        let a = vec![vec![-1.0; 2]; 3];
        let output = selective_scan(&[], &[], &a, &[], &[], &[1.0, 1.0]).expect("should succeed");
assert!(output.is_empty());
}
#[test]
fn test_selective_scan_dimension_mismatch() {
let u = vec![vec![1.0, 2.0]];
let delta = vec![vec![0.01, 0.01]];
let a = vec![vec![-1.0, -1.0]];
let b = vec![vec![1.0]];
let c = vec![vec![1.0]];
        let d_wrong = vec![1.0, 2.0, 3.0];
        assert!(selective_scan(&u, &delta, &a, &b, &c, &d_wrong).is_err());
}
#[test]
fn test_mamba_block_creation() {
let config = MambaConfig {
model_dim: 8,
state_dim: 4,
expand_factor: 2,
dt_rank: 2,
dt_min: 0.001,
dt_max: 0.1,
conv_dim: 4,
};
let block = MambaBlock::new(config).expect("should succeed");
assert_eq!(block.config.model_dim, 8);
assert_eq!(block.config.inner_dim(), 16);
}
#[test]
fn test_mamba_block_forward_dimensions() {
let config = MambaConfig {
model_dim: 8,
state_dim: 4,
expand_factor: 2,
dt_rank: 2,
dt_min: 0.001,
dt_max: 0.1,
conv_dim: 4,
};
let block = MambaBlock::new(config).expect("should succeed");
let seq_len = 6;
let input: Vec<Vec<f64>> = vec![vec![0.1; 8]; seq_len];
let output = block.forward(&input).expect("should succeed");
assert_eq!(output.len(), seq_len);
assert_eq!(output[0].len(), 8);
}
#[test]
fn test_mamba_block_empty_input() {
let block = MambaBlock::new(MambaConfig::default()).expect("should succeed");
let output = block.forward(&[]).expect("should succeed");
assert!(output.is_empty());
}
#[test]
fn test_mamba_block_wrong_input_dim() {
let config = MambaConfig {
model_dim: 8,
..MambaConfig::default()
};
let block = MambaBlock::new(config).expect("should succeed");
        let input = vec![vec![1.0; 4]];
        assert!(block.forward(&input).is_err());
}
#[test]
fn test_mamba_block_with_seed_deterministic() {
let config = MambaConfig {
model_dim: 8,
state_dim: 4,
expand_factor: 2,
dt_rank: 2,
dt_min: 0.001,
dt_max: 0.1,
conv_dim: 4,
};
let b1 = MambaBlock::new_with_seed(config.clone(), 123).expect("b1");
let b2 = MambaBlock::new_with_seed(config, 123).expect("b2");
assert_eq!(b1.in_proj_weight.len(), b2.in_proj_weight.len());
for (r1, r2) in b1.in_proj_weight.iter().zip(b2.in_proj_weight.iter()) {
for (v1, v2) in r1.iter().zip(r2.iter()) {
assert!((v1 - v2).abs() < 1e-15);
}
}
}
#[test]
fn test_mamba_default_config() {
let config = MambaConfig::default();
assert_eq!(config.model_dim, 64);
assert_eq!(config.state_dim, 16);
assert_eq!(config.expand_factor, 2);
assert_eq!(config.inner_dim(), 128);
}
#[test]
fn test_silu_values() {
assert!((silu(0.0) - 0.0).abs() < 1e-10);
assert!((silu(10.0) - 10.0).abs() < 0.001);
assert!(silu(1.0) > 0.0);
assert!(silu(-1.0) < 0.0);
}
#[test]
fn test_softplus_values() {
assert!((softplus(0.0) - (2.0_f64).ln()).abs() < 1e-10);
assert!((softplus(30.0) - 30.0).abs() < 1e-10);
assert!(softplus(-30.0).abs() < 1e-10);
}
#[test]
fn test_depthwise_conv1d() {
let input = vec![
vec![1.0, 2.0],
vec![3.0, 4.0],
vec![5.0, 6.0],
];
let kernel = vec![1.0, 0.5];
let output = depthwise_conv1d(&input, &kernel).expect("should succeed");
assert_eq!(output.len(), 3);
assert!((output[0][0] - 1.0).abs() < 1e-12);
assert!((output[0][1] - 2.0).abs() < 1e-12);
assert!((output[1][0] - 3.5).abs() < 1e-12);
assert!((output[1][1] - 5.0).abs() < 1e-12);
}
#[test]
fn test_mamba_forward_finite_output() {
let config = MambaConfig {
model_dim: 4,
state_dim: 2,
expand_factor: 2,
dt_rank: 1,
dt_min: 0.001,
dt_max: 0.1,
conv_dim: 2,
};
let block = MambaBlock::new_with_seed(config, 99).expect("should succeed");
let input = vec![vec![0.1, -0.2, 0.3, -0.4]; 8];
let output = block.forward(&input).expect("forward");
for (t, row) in output.iter().enumerate() {
for (j, &val) in row.iter().enumerate() {
assert!(
val.is_finite(),
"output[{}][{}] = {} is not finite",
t,
j,
val
);
}
}
}
}