use std::f64::consts::PI;
use scirs2_core::ndarray::Array2;
use scirs2_core::random::RngExt;
use crate::error::{Result, TimeSeriesError};
use super::config::MambaConfig;
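/// SiLU (swish) activation: x * sigmoid(x), written as x / (1 + e^(-x)).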
#[inline]
fn silu(x: f64) -> f64 {
x / (1.0 + (-x).exp())
}
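/// Numerically stable softplus, ln(1 + e^x). For x > 20 the result equals x
/// to within f64 precision, so we return x directly and avoid overflow in exp.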
#[inline]
fn softplus(x: f64) -> f64 {
if x > 20.0 {
x
} else {
(1.0 + x.exp()).ln()
}
}
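/// RMSNorm over a single feature vector: divide by the root-mean-square
/// (with a small epsilon for stability) and apply a learned per-feature scale.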
fn rms_norm(x: &[f64], scale: &[f64]) -> Vec<f64> {
let rms = (x.iter().map(|v| v * v).sum::<f64>() / x.len() as f64 + 1e-6).sqrt();
x.iter()
.zip(scale.iter())
.map(|(v, s)| v / rms * s)
.collect()
}
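/// Draws one N(mean, std^2) sample via the Box-Muller transform from two
/// uniform variates; u1 is clamped away from zero so ln(u1) stays finite.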
fn normal_sample(rng: &mut impl scirs2_core::random::Rng, mean: f64, std: f64) -> f64 {
let u1: f64 = rng.random::<f64>().max(1e-15);
let u2: f64 = rng.random::<f64>();
let z: f64 = (-2.0_f64 * u1.ln()).sqrt() * (2.0_f64 * PI * u2).cos();
mean + std * z
}
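/// A single Mamba (selective state-space) block: input projection, depthwise
/// causal convolution, input-dependent SSM parameters (Δ, B, C), a selective
/// scan over the sequence, SiLU gating, and an output projection.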
pub struct MambaBlock {
pub config: MambaConfig,
/// Input projection: d_model -> 2 * d_inner (the SSM input x and the gate z).
in_proj: Array2<f64>,
/// Depthwise causal conv kernels: one row of length d_conv per inner channel.
conv1d_weight: Array2<f64>,
conv1d_bias: Vec<f64>,
/// Projects conv activations to (Δ, B, C): d_inner -> dt_rank + 2 * d_state.
x_proj: Array2<f64>,
/// Expands the low-rank Δ to per-channel step sizes: dt_rank -> d_inner.
dt_proj: Array2<f64>,
/// Log-parameterized state matrix; the effective A is -exp(a_log).
a_log: Array2<f64>,
/// Per-channel skip (D) connection.
d_param: Vec<f64>,
/// Output projection: d_inner -> d_model.
out_proj: Array2<f64>,
}
impl MambaBlock {
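/// Builds a block with Xavier-uniform projections, small-normal conv and Δ
/// weights, and the S4D-real initialization for the state matrix:
/// a_log[c][n] = ln(n + 1), so the effective A[c][n] = -(n + 1).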
pub fn new(config: &MambaConfig, rng: &mut impl scirs2_core::random::Rng) -> Self {
let d_model = config.d_model;
let d_inner = config.d_inner();
let d_state = config.d_state;
let d_conv = config.d_conv;
let dt_rank = config.dt_rank;
let in_proj = xavier_uniform_2d(d_model, 2 * d_inner, rng);
let conv_scale = (2.0 / (d_inner * d_conv) as f64).sqrt();
let conv_data: Vec<f64> = (0..d_inner * d_conv)
.map(|_| normal_sample(rng, 0.0, conv_scale))
.collect();
let conv1d_weight = Array2::from_shape_vec((d_inner, d_conv), conv_data)
.unwrap_or_else(|_| Array2::zeros((d_inner, d_conv)));
let conv1d_bias = vec![0.0_f64; d_inner];
let x_proj_cols = dt_rank + 2 * d_state;
let x_proj = xavier_uniform_2d(d_inner, x_proj_cols, rng);
let dt_proj_scale = (2.0 / dt_rank as f64).sqrt();
let dt_proj_data: Vec<f64> = (0..dt_rank * d_inner)
.map(|_| normal_sample(rng, 0.0, dt_proj_scale))
.collect();
let dt_proj = Array2::from_shape_vec((dt_rank, d_inner), dt_proj_data)
.unwrap_or_else(|_| Array2::zeros((dt_rank, d_inner)));
let a_log_data: Vec<f64> = (0..d_inner)
.flat_map(|_| {
(0..d_state).map(|n| {
((n + 1) as f64).ln()
})
})
.collect();
let a_log = Array2::from_shape_vec((d_inner, d_state), a_log_data)
.unwrap_or_else(|_| Array2::zeros((d_inner, d_state)));
let d_param = vec![1.0_f64; d_inner];
let out_proj = xavier_uniform_2d(d_inner, d_model, rng);
MambaBlock {
config: config.clone(),
in_proj,
conv1d_weight,
conv1d_bias,
x_proj,
dt_proj,
a_log,
d_param,
out_proj,
}
}
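/// Depthwise causal 1-D convolution: each channel is convolved with its own
/// length-d_conv kernel, reading only current and past timesteps.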
fn causal_conv1d(&self, x: &Array2<f64>) -> Array2<f64> {
let (seq_len, d_inner) = x.dim();
let d_conv = self.config.d_conv;
let mut out = Array2::zeros((seq_len, d_inner));
for c in 0..d_inner {
let bias = self.conv1d_bias[c];
for t in 0..seq_len {
let mut acc = bias;
for k in 0..d_conv {
// Causal: tap k reads x at t - k, which exists only when t >= k.
if t >= k {
let src_t = t - k;
acc += self.conv1d_weight[[c, k]] * x[[src_t, c]];
}
}
out[[t, c]] = acc;
}
}
out
}
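/// Selective scan: for each channel, runs the discretized SSM recurrence
/// h_t[n] = exp(dt_t * A[n]) * h_{t-1}[n] + dt_t * B_t[n] * x_t,
/// y_t = sum_n C_t[n] * h_t[n] + D * x_t,
/// with A discretized by zero-order hold and B by a first-order (Euler) step.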
fn selective_scan(
&self,
x: &Array2<f64>,
dt: &Array2<f64>,
b: &Array2<f64>,
c: &Array2<f64>,
) -> Array2<f64> {
let (seq_len, d_inner) = x.dim();
let d_state = self.config.d_state;
let mut output = Array2::zeros((seq_len, d_inner));
for ch in 0..d_inner {
let mut h = vec![0.0_f64; d_state];
for t in 0..seq_len {
let dt_t = dt[[t, ch]];
let x_t = x[[t, ch]];
let mut y_t = 0.0_f64;
for n in 0..d_state {
// A = -exp(a_log) keeps every state eigenvalue negative (stable decay).
let a_val = -(self.a_log[[ch, n]].exp());
// Zero-order-hold discretization of A; Euler step for B.
let a_bar = (dt_t * a_val).exp();
let b_bar = dt_t * b[[t, n]];
h[n] = a_bar * h[n] + b_bar * x_t;
y_t += c[[t, n]] * h[n];
}
output[[t, ch]] = y_t + self.d_param[ch] * x_t;
}
}
output
}
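/// Row-wise matrix multiply: (seq_len, in_dim) x (in_dim, out_dim).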
fn matmul_seq(x: &Array2<f64>, w: &Array2<f64>) -> Array2<f64> {
let (seq_len, in_dim) = x.dim();
let (in_w, out_dim) = w.dim();
debug_assert_eq!(in_dim, in_w);
let mut out = Array2::zeros((seq_len, out_dim));
for l in 0..seq_len {
for j in 0..out_dim {
let mut acc = 0.0_f64;
for i in 0..in_dim {
acc += x[[l, i]] * w[[i, j]];
}
out[[l, j]] = acc;
}
}
out
}
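/// Block forward pass: project the input to [x; z], apply the causal conv and
/// SiLU to x, derive (Δ, B, C) from the activations, run the selective scan,
/// gate the result with SiLU(z), and project back to d_model.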
pub fn forward(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
let (seq_len, d_model) = x.dim();
if d_model != self.config.d_model {
return Err(TimeSeriesError::DimensionMismatch {
expected: self.config.d_model,
actual: d_model,
});
}
let d_inner = self.config.d_inner();
let d_state = self.config.d_state;
let dt_rank = self.config.dt_rank;
let xz = Self::matmul_seq(x, &self.in_proj);
let mut x_ssm = Array2::zeros((seq_len, d_inner));
let mut z = Array2::zeros((seq_len, d_inner));
for t in 0..seq_len {
for j in 0..d_inner {
x_ssm[[t, j]] = xz[[t, j]];
z[[t, j]] = xz[[t, j + d_inner]];
}
}
let x_conv = self.causal_conv1d(&x_ssm);
let x_act = Array2::from_shape_fn((seq_len, d_inner), |(t, j)| silu(x_conv[[t, j]]));
let xbc = Self::matmul_seq(&x_act, &self.x_proj);
let mut delta_raw = Array2::zeros((seq_len, dt_rank));
let mut b_mat = Array2::zeros((seq_len, d_state));
let mut c_mat = Array2::zeros((seq_len, d_state));
for t in 0..seq_len {
for j in 0..dt_rank {
delta_raw[[t, j]] = xbc[[t, j]];
}
for j in 0..d_state {
b_mat[[t, j]] = xbc[[t, dt_rank + j]];
c_mat[[t, j]] = xbc[[t, dt_rank + d_state + j]];
}
}
let dt_linear = Self::matmul_seq(&delta_raw, &self.dt_proj);
// softplus keeps the step size Δ strictly positive.
let dt = Array2::from_shape_fn((seq_len, d_inner), |(t, j)| softplus(dt_linear[[t, j]]));
let y_ssm = self.selective_scan(&x_act, &dt, &b_mat, &c_mat);
let y_gated =
Array2::from_shape_fn((seq_len, d_inner), |(t, j)| y_ssm[[t, j]] * silu(z[[t, j]]));
let output = Self::matmul_seq(&y_gated, &self.out_proj);
Ok(output)
}
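/// The log-parameterized state matrix (the effective A is -exp(a_log)).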
pub fn a_log(&self) -> &Array2<f64> {
&self.a_log
}
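/// The per-channel skip (D) parameters.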
pub fn d_param(&self) -> &[f64] {
&self.d_param
}
}
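/// A stack of pre-norm MambaBlocks with residual connections, a final RMSNorm,
/// and a linear output head.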
pub struct MambaModel {
pub config: MambaConfig,
layers: Vec<MambaBlock>,
norm_layers: Vec<Vec<f64>>,
final_norm: Vec<f64>,
output_proj: Array2<f64>,
output_dim: usize,
}
impl MambaModel {
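/// Builds `n_layers` blocks plus per-layer and final RMSNorm scales
/// (initialized to 1) and a Xavier-uniform output projection.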
pub fn new(
config: &MambaConfig,
output_dim: usize,
rng: &mut impl scirs2_core::random::Rng,
) -> Self {
let n_layers = config.n_layers;
let d_model = config.d_model;
let layers: Vec<MambaBlock> = (0..n_layers)
.map(|_| MambaBlock::new(config, rng))
.collect();
let norm_layers: Vec<Vec<f64>> = (0..n_layers).map(|_| vec![1.0_f64; d_model]).collect();
let final_norm = vec![1.0_f64; d_model];
let output_proj = xavier_uniform_2d(d_model, output_dim, rng);
MambaModel {
config: config.clone(),
layers,
norm_layers,
final_norm,
output_proj,
output_dim,
}
}
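/// Model forward pass: for each layer, hidden += block(rmsnorm(hidden));
/// then a final RMSNorm and the output projection to `output_dim`.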
pub fn forward(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
let (seq_len, d_model) = x.dim();
if d_model != self.config.d_model {
return Err(TimeSeriesError::DimensionMismatch {
expected: self.config.d_model,
actual: d_model,
});
}
let mut hidden = x.clone();
for (layer_idx, layer) in self.layers.iter().enumerate() {
let scale = &self.norm_layers[layer_idx];
let normed = apply_rmsnorm_seq(&hidden, scale);
let block_out = layer.forward(&normed)?;
for t in 0..seq_len {
for j in 0..d_model {
hidden[[t, j]] += block_out[[t, j]];
}
}
}
let normed_final = apply_rmsnorm_seq(&hidden, &self.final_norm);
let output = matmul_seq_static(&normed_final, &self.output_proj);
Ok(output)
}
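/// Autoregressive forecasting: embed the last `seq_len` values (left-padded
/// with the first observation if history is short), predict one step from the
/// final timestep, slide the window forward, and repeat `horizon` times.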
pub fn forecast(&self, history: &[f64], horizon: usize) -> Result<Vec<f64>> {
if history.is_empty() {
return Err(TimeSeriesError::InvalidInput(
"history must be non-empty".to_string(),
));
}
if horizon == 0 {
return Ok(Vec::new());
}
let seq_len = self.config.seq_len;
let d_model = self.config.d_model;
let window: Vec<f64> = if history.len() >= seq_len {
history[history.len() - seq_len..].to_vec()
} else {
let pad_len = seq_len - history.len();
let pad_val = history[0];
let mut w = vec![pad_val; pad_len];
w.extend_from_slice(history);
w
};
let mut predictions = Vec::with_capacity(horizon);
let mut current_window = window;
for _ in 0..horizon {
let x = embed_window(&current_window, d_model);
let output = self.forward(&x)?;
let pred = output[[seq_len - 1, 0]];
predictions.push(pred);
current_window.remove(0);
current_window.push(pred);
}
Ok(predictions)
}
}
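/// Xavier/Glorot uniform initialization: U(-limit, limit) with
/// limit = sqrt(6 / (in_dim + out_dim)).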
fn xavier_uniform_2d(
in_dim: usize,
out_dim: usize,
rng: &mut impl scirs2_core::random::Rng,
) -> Array2<f64> {
let limit = (6.0 / (in_dim + out_dim) as f64).sqrt();
let data: Vec<f64> = (0..in_dim * out_dim)
.map(|_| {
let u: f64 = rng.random::<f64>();
-limit + 2.0 * limit * u
})
.collect();
Array2::from_shape_vec((in_dim, out_dim), data)
.unwrap_or_else(|_| Array2::zeros((in_dim, out_dim)))
}
fn apply_rmsnorm_seq(x: &Array2<f64>, scale: &[f64]) -> Array2<f64> {
let (seq_len, d_model) = x.dim();
let mut out = Array2::zeros((seq_len, d_model));
for t in 0..seq_len {
let row: Vec<f64> = (0..d_model).map(|j| x[[t, j]]).collect();
let normed = rms_norm(&row, scale);
for j in 0..d_model {
out[[t, j]] = normed[j];
}
}
out
}
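/// Free-function twin of `MambaBlock::matmul_seq`, used by the model head.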
fn matmul_seq_static(x: &Array2<f64>, w: &Array2<f64>) -> Array2<f64> {
let (seq_len, in_dim) = x.dim();
let (in_w, out_dim) = w.dim();
debug_assert_eq!(in_dim, in_w);
let mut out = Array2::zeros((seq_len, out_dim));
for l in 0..seq_len {
for j in 0..out_dim {
let mut acc = 0.0_f64;
for i in 0..in_dim {
acc += x[[l, i]] * w[[i, j]];
}
out[[l, j]] = acc;
}
}
out
}
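/// Embeds a scalar window into d_model features by modulating each value with
/// cosine positional factors, giving the model timestep information.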
fn embed_window(window: &[f64], d_model: usize) -> Array2<f64> {
let seq_len = window.len();
let mut x = Array2::zeros((seq_len, d_model));
for t in 0..seq_len {
let v = window[t];
for j in 0..d_model {
let pos_angle =
(t as f64 * PI) / (seq_len as f64) * (j + 1) as f64 / (d_model + 1) as f64;
x[[t, j]] = v * pos_angle.cos();
}
}
x
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::SeedableRng;
fn make_rng() -> impl scirs2_core::random::Rng {
scirs2_core::random::rngs::StdRng::seed_from_u64(0)
}
fn small_config() -> MambaConfig {
MambaConfig {
d_model: 16,
d_state: 4,
d_conv: 3,
expand: 2,
dt_rank: 2,
seq_len: 16,
n_layers: 2,
dropout: 0.0,
}
}
#[test]
fn test_mamba_block_shape() {
let config = small_config();
let block = MambaBlock::new(&config, &mut make_rng());
let x = Array2::ones((16, 16));
let out = block.forward(&x).expect("block forward should succeed");
assert_eq!(out.dim(), (16, 16));
}
#[test]
fn test_mamba_block_shape_varied_seq() {
let config = small_config();
let block = MambaBlock::new(&config, &mut make_rng());
for seq_len in [1, 4, 16, 32] {
let x = Array2::ones((seq_len, 16));
let out = block.forward(&x).expect("block forward");
assert_eq!(out.dim(), (seq_len, 16));
}
}
#[test]
fn test_mamba_selective_scan_shape() {
let config = small_config();
let block = MambaBlock::new(&config, &mut make_rng());
let d_inner = config.d_inner();
let d_state = config.d_state;
let seq_len = 16;
let x = Array2::ones((seq_len, d_inner));
let dt = Array2::from_elem((seq_len, d_inner), 0.01);
let b = Array2::ones((seq_len, d_state));
let c = Array2::ones((seq_len, d_state));
let out = block.selective_scan(&x, &dt, &b, &c);
assert_eq!(out.dim(), (seq_len, d_inner));
}
#[test]
fn test_mamba_forward_causal() {
let config = small_config();
let block = MambaBlock::new(&config, &mut make_rng());
let mut x1 = Array2::zeros((8, 16));
let mut x2 = Array2::zeros((8, 16));
for t in 0..8 {
for j in 0..16 {
x1[[t, j]] = t as f64 * 0.1 + j as f64 * 0.01;
x2[[t, j]] = t as f64 * 0.1 + j as f64 * 0.01;
}
}
for t in 4..8 {
for j in 0..16 {
x2[[t, j]] = 999.0;
}
}
let out1 = block.forward(&x1).expect("forward x1");
let out2 = block.forward(&x2).expect("forward x2");
for t in 0..4 {
for j in 0..16 {
let diff = (out1[[t, j]] - out2[[t, j]]).abs();
assert!(
diff < 1e-10,
"t={t}, j={j}: causality violation: diff={diff:.2e}"
);
}
}
}
#[test]
fn test_mamba_model_forecast() {
let config = MambaConfig {
d_model: 16,
d_state: 4,
d_conv: 3,
expand: 2,
dt_rank: 2,
seq_len: 32,
n_layers: 2,
dropout: 0.0,
};
let mut rng = make_rng();
let model = MambaModel::new(&config, 1, &mut rng);
let history: Vec<f64> = (0..64).map(|i| (i as f64).sin()).collect();
let forecast = model.forecast(&history, 8).expect("forecast");
assert_eq!(forecast.len(), 8);
for (i, &v) in forecast.iter().enumerate() {
assert!(v.is_finite(), "prediction[{i}] is not finite: {v}");
}
}
#[test]
fn test_mamba_model_forward_shape() {
let config = small_config();
let mut rng = make_rng();
let model = MambaModel::new(&config, 1, &mut rng);
let x = Array2::zeros((16, 16));
let out = model.forward(&x).expect("model forward");
assert_eq!(out.dim(), (16, 1));
}
#[test]
fn test_mamba_longer_seq() {
let config = MambaConfig {
d_model: 16,
d_state: 4,
d_conv: 4,
expand: 2,
dt_rank: 2,
seq_len: 256,
n_layers: 2,
dropout: 0.0,
};
let block = MambaBlock::new(&config, &mut make_rng());
let x = Array2::zeros((256, 16));
let out = block.forward(&x).expect("long seq forward");
assert_eq!(out.dim(), (256, 16));
}
#[test]
fn test_mamba_d_param_skip_connection() {
let config = small_config();
let block = MambaBlock::new(&config, &mut make_rng());
let x_zero = Array2::zeros((8, 16));
let out_zero = block.forward(&x_zero).expect("forward zero");
for t in 0..8 {
for j in 0..16 {
assert!(out_zero[[t, j]].is_finite());
}
}
let d = block.d_param();
assert_eq!(d.len(), config.d_inner());
for &v in d {
assert!(v.is_finite());
}
}
#[test]
fn test_mamba_dimension_mismatch() {
let config = small_config();
let block = MambaBlock::new(&config, &mut make_rng());
let x = Array2::zeros((8, 4));
let result = block.forward(&x);
assert!(result.is_err());
}
#[test]
fn test_mamba_forecast_short_history() {
let config = MambaConfig {
d_model: 8,
d_state: 4,
d_conv: 3,
expand: 2,
dt_rank: 1,
seq_len: 16,
n_layers: 1,
dropout: 0.0,
};
let mut rng = make_rng();
let model = MambaModel::new(&config, 1, &mut rng);
let history = vec![1.0, 2.0, 3.0];
let forecast = model.forecast(&history, 4).expect("forecast");
assert_eq!(forecast.len(), 4);
}
#[test]
fn test_mamba_a_log_shape() {
let config = small_config();
let block = MambaBlock::new(&config, &mut make_rng());
let a_log = block.a_log();
assert_eq!(a_log.dim(), (config.d_inner(), config.d_state));
}
#[test]
fn test_silu_properties() {
assert!((silu(0.0)).abs() < 1e-10);
for x in [-5.0, -1.0, 0.0, 1.0, 5.0, 20.0] {
let v = silu(x);
assert!(v.is_finite(), "silu({x}) = {v}");
}
assert!(silu(1.0) > silu(0.0));
}
#[test]
fn test_softplus_properties() {
for x in [-10.0, -1.0, 0.0, 1.0, 10.0, 25.0] {
let v = softplus(x);
assert!(v > 0.0, "softplus({x}) = {v}");
assert!(v.is_finite(), "softplus({x}) = {v}");
}
assert!((softplus(25.0) - 25.0).abs() < 1e-6);
}
#[test]
fn test_mamba_default_config() {
let config = MambaConfig::default();
assert_eq!(config.d_model, 64);
assert_eq!(config.d_state, 16);
assert_eq!(config.d_conv, 4);
assert_eq!(config.expand, 2);
assert_eq!(config.d_inner(), 128);
assert_eq!(config.dt_rank, 4);
}
}