use alloc::vec;
use alloc::vec::Vec;
use crate::math;
use crate::ssm::discretize::zoh_discretize;
use crate::ssm::init::mamba_init;
use crate::ssm::projection::dot;
use crate::ssm::SSMLayer;
/// Single-input, single-output state-space layer with a diagonal state matrix.
///
/// Each call to `forward_scalar` applies, for every state index `n`,
/// `h[n] = a_bar[n] * h[n] + b_bar_factor[n] * b[n] * x` and then returns
/// `dot(c, h) + d_skip * x`.
#[derive(Debug, Clone)]
pub struct DiagonalSSM {
    /// Per-state parameter produced by `mamba_init`; the continuous-time diagonal
    /// entry used for discretisation is `-exp(log_a[n])`.
    log_a: Vec<f64>,
    /// Per-state input weights (multiply the input sample in the state update).
    b: Vec<f64>,
    /// Per-state output weights (dotted with the hidden state to form the output).
    c: Vec<f64>,
    /// Step size handed to `zoh_discretize` at construction time.
    delta: f64,
    /// Coefficient of the direct input-to-output (skip) term.
    d_skip: f64,
    /// Hidden state, one entry per state dimension; starts at zero.
    h: Vec<f64>,
    /// First component returned by `zoh_discretize` for each state, cached at
    /// construction; scales the previous hidden state.
    a_bar: Vec<f64>,
    /// Second component returned by `zoh_discretize` for each state, cached at
    /// construction; scales `b[n] * x` in the state update.
    b_bar_factor: Vec<f64>,
}
impl DiagonalSSM {
    /// Creates an SSM with `n_state` state dimensions and step size `delta`.
    ///
    /// Equivalent to [`DiagonalSSM::with_params`] with `b` and `c` filled with
    /// `1.0` and `d_skip` set to `0.0`.
    pub fn new(n_state: usize, delta: f64) -> Self {
        Self::with_params(n_state, delta, vec![1.0; n_state], vec![1.0; n_state], 0.0)
    }

    /// Creates an SSM with caller-supplied input weights `b`, output weights `c`
    /// and skip coefficient `d_skip`.
    ///
    /// `log_a` comes from `mamba_init(n_state)`; every entry `la` is mapped to
    /// `zoh_discretize(-exp(la), delta)` and both returned components are cached.
    /// The hidden state starts as `n_state` zeros.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions enabled, panics when `b.len()` or
    /// `c.len()` differs from `n_state`.
    pub fn with_params(n_state: usize, delta: f64, b: Vec<f64>, c: Vec<f64>, d_skip: f64) -> Self {
        debug_assert_eq!(b.len(), n_state);
        debug_assert_eq!(c.len(), n_state);

        let log_a = mamba_init(n_state);
        // Discretise once up front so the per-sample update is multiply-adds only.
        let (a_bar, b_bar_factor): (Vec<f64>, Vec<f64>) = log_a
            .iter()
            .map(|&la| zoh_discretize(-math::exp(la), delta))
            .unzip();

        Self {
            log_a,
            b,
            c,
            delta,
            d_skip,
            h: vec![0.0; n_state],
            a_bar,
            b_bar_factor,
        }
    }

    /// Advances the state by one input sample `x` and returns the output.
    ///
    /// For every state index `n` the update is
    /// `h[n] = a_bar[n] * h[n] + b_bar_factor[n] * b[n] * x`; the returned value
    /// is `dot(c, h) + d_skip * x`, computed from the updated state.
    ///
    /// # Panics
    ///
    /// Panics if `a_bar`, `b_bar_factor` or `b` holds fewer entries than the
    /// hidden state.
    #[inline]
    pub fn forward_scalar(&mut self, x: f64) -> f64 {
        let n_state = self.h.len();
        // Slicing to the state length up front keeps the out-of-bounds panic for
        // undersized parameter vectors while letting the loop run without
        // per-element index checks.
        let a_bar = &self.a_bar[..n_state];
        let gain = &self.b_bar_factor[..n_state];
        let b = &self.b[..n_state];

        for (((h, &a), &g), &b_n) in self.h.iter_mut().zip(a_bar).zip(gain).zip(b) {
            *h = a * *h + g * b_n * x;
        }

        dot(&self.c, &self.h) + self.d_skip * x
    }

    /// Number of state dimensions (length of the hidden state).
    #[inline]
    pub fn n_state(&self) -> usize {
        self.h.len()
    }

    /// The `log_a` parameters obtained from `mamba_init` at construction.
    #[inline]
    pub fn log_a(&self) -> &[f64] {
        &self.log_a
    }

    /// The step size this layer was discretised with.
    #[inline]
    pub fn delta(&self) -> f64 {
        self.delta
    }
}
impl SSMLayer for DiagonalSSM {
    /// Runs one step on the first element of `input` and wraps the scalar output
    /// in a one-element vector.
    ///
    /// An empty `input` is processed as the sample `0.0`; elements after the
    /// first are not read.
    fn forward(&mut self, input: &[f64]) -> Vec<f64> {
        let sample = input.first().copied().unwrap_or(0.0);
        let output = self.forward_scalar(sample);
        vec![output]
    }

    /// Borrows the current hidden state.
    fn state(&self) -> &[f64] {
        self.h.as_slice()
    }

    /// This layer always emits a single value per step.
    fn output_dim(&self) -> usize {
        1
    }

    /// Zeroes every hidden-state entry, keeping the state length unchanged.
    fn reset(&mut self) {
        self.h.fill(0.0);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum of the squared hidden-state entries of `ssm`.
    fn state_energy(ssm: &DiagonalSSM) -> f64 {
        ssm.state().iter().map(|h| h * h).sum()
    }

    /// A fresh model reports the requested size and an all-zero state.
    #[test]
    fn new_creates_zero_state() {
        let model = DiagonalSSM::new(8, 0.1);
        assert_eq!(model.n_state(), 8);
        assert!(
            model.state().iter().all(|&v| math::abs(v) < 1e-15),
            "initial state should be zero"
        );
    }

    /// A single step on a fresh model yields a finite number.
    #[test]
    fn forward_scalar_produces_finite_output() {
        let mut model = DiagonalSSM::new(4, 0.1);
        let out = model.forward_scalar(1.0);
        assert!(out.is_finite(), "output should be finite, got {}", out);
    }

    /// Feeding a non-zero sample moves the state away from zero.
    #[test]
    fn forward_updates_state() {
        let mut model = DiagonalSSM::new(4, 0.1);
        let _ = model.forward_scalar(1.0);
        assert!(
            state_energy(&model) > 0.0,
            "state should be non-zero after processing input"
        );
    }

    /// `reset` returns every state entry to zero.
    #[test]
    fn reset_clears_state() {
        let mut model = DiagonalSSM::new(4, 0.1);
        let _ = model.forward_scalar(1.0);
        model.reset();
        assert!(
            model.state().iter().all(|&v| math::abs(v) < 1e-15),
            "state should be zero after reset"
        );
    }

    /// After an impulse, 100 zero-input steps shrink the state energy below 1%.
    #[test]
    fn state_decays_without_input() {
        let mut model = DiagonalSSM::new(4, 0.1);
        let _ = model.forward_scalar(10.0);
        let energy_after_input = state_energy(&model);

        (0..100).for_each(|_| {
            let _ = model.forward_scalar(0.0);
        });

        let energy_after_decay = state_energy(&model);
        assert!(
            energy_after_decay < energy_after_input * 0.01,
            "state energy should decay: initial={}, after={}",
            energy_after_input,
            energy_after_decay
        );
    }

    /// The trait entry point returns exactly `output_dim()` values.
    #[test]
    fn ssm_layer_trait_works() {
        let mut model = DiagonalSSM::new(4, 0.1);
        let produced = model.forward(&[1.0]);
        assert_eq!(produced.len(), 1, "output_dim should be 1");
        assert_eq!(model.output_dim(), 1);
    }

    /// With a constant input, successive outputs eventually differ by < 1e-10.
    #[test]
    fn constant_input_converges() {
        let mut model = DiagonalSSM::new(4, 0.1);
        let mut previous = 0.0_f64;
        // `any` stops at the first step that satisfies the settling criterion.
        let settled = (0..500).any(|step| {
            let current = model.forward_scalar(1.0);
            let close_enough = step > 10 && math::abs(current - previous) < 1e-10;
            previous = current;
            close_enough
        });
        assert!(settled, "output should converge for constant input");
    }

    /// Zero `b`/`c` with `d_skip = 1` makes the layer an identity on its input.
    #[test]
    fn skip_connection_passes_through() {
        let zeros = vec![0.0; 4];
        let mut model = DiagonalSSM::with_params(4, 0.1, zeros.clone(), zeros, 1.0);
        let out = model.forward_scalar(5.0);
        assert!(
            math::abs(out - 5.0) < 1e-12,
            "with zero B/C and d_skip=1, output should equal input: got {}",
            out
        );
    }

    /// An empty input slice behaves like a zero sample.
    #[test]
    fn empty_input_treated_as_zero() {
        let mut model = DiagonalSSM::new(4, 0.1);
        let produced = model.forward(&[]);
        assert_eq!(produced.len(), 1);
        assert!(
            math::abs(produced[0]) < 1e-15,
            "empty input should be treated as zero"
        );
    }

    /// Two models that differ only in `delta` respond differently to an impulse.
    #[test]
    fn different_delta_changes_dynamics() {
        // Output of the step that follows a unit impulse, for a given step size.
        let response_after_impulse = |delta: f64| {
            let mut model = DiagonalSSM::new(4, delta);
            let _ = model.forward_scalar(1.0);
            model.forward_scalar(0.0)
        };

        let y_fast = response_after_impulse(1.0);
        let y_slow = response_after_impulse(0.001);
        assert!(
            math::abs(y_fast - y_slow) > 1e-6,
            "different delta should produce different dynamics: fast={}, slow={}",
            y_fast,
            y_slow
        );
    }
}