use alloc::vec;
use alloc::vec::Vec;
use crate::math;
use crate::rng::standard_normal;
use crate::ssm::init::s4d_inv_real;
use crate::ssm::projection::{dot, mat_vec, softplus, Xorshift64};
use crate::ssm::SSMLayer;
/// Selective state-space layer whose step size and input/output projections
/// are recomputed from every input vector.
///
/// The recurrent state is stored flat, with the state index as the slow axis:
/// element `(n, d)` lives at `h[n * d_in + d]`.
pub struct SelectiveSSM {
/// Per-state parameter produced by `s4d_inv_real(n_state)`; the forward pass
/// uses `-exp(log_a[n])` as the continuous-time decay rate of state `n`.
log_a: Vec<f64>,
/// Weights (length `d_in`) dotted with the input to form the raw step size.
w_delta: Vec<f64>,
/// Bias added to the raw step size before `softplus`; initialised to `0.0`.
b_delta: f64,
/// Flat `n_state * d_in` projection producing the per-state input gain;
/// entry `(n, d)` is at `w_b[n * d_in + d]`.
w_b: Vec<f64>,
/// Flat `n_state * d_in` projection producing the per-state readout gain;
/// same layout as `w_b`.
w_c: Vec<f64>,
/// Per-channel skip gain applied to the raw input (length `d_in`, initialised to `1.0`).
d_skip: Vec<f64>,
/// Recurrent state, `d_in * n_state` values, zero after construction and `reset`.
h: Vec<f64>,
/// Number of state dimensions per channel.
n_state: usize,
/// Number of input (and output) channels.
d_in: usize,
}
impl SelectiveSSM {
    /// Scale applied to every standard-normal draw used for projection weights.
    const INIT_SCALE: f64 = 0.1;

    /// Builds a layer with `d_in` channels and `n_state` state dimensions.
    ///
    /// `log_a` comes from `s4d_inv_real(n_state)`. The projection weights are
    /// standard-normal draws scaled by `INIT_SCALE`, taken from a
    /// `Xorshift64` seeded with `seed`, so the same `(d_in, n_state, seed)`
    /// triple always yields the same layer. The step-size bias starts at
    /// `0.0`, every skip gain at `1.0`, and the recurrent state at zero.
    ///
    /// NOTE(review): a plain xorshift generator never leaves the all-zero
    /// state, so `seed == 0` may produce degenerate weights — confirm how
    /// `Xorshift64` treats a zero seed.
    pub fn new(d_in: usize, n_state: usize, seed: u64) -> Self {
        let log_a = s4d_inv_real(n_state);
        let mut rng = Xorshift64(seed);

        // The draw order (w_delta, then w_b, then w_c) determines the weights
        // obtained for a given seed; it must not be rearranged.
        let mut draw = |len: usize| -> Vec<f64> {
            (0..len)
                .map(|_| rng.next_normal() * Self::INIT_SCALE)
                .collect()
        };
        let w_delta = draw(d_in);
        let w_b = draw(n_state * d_in);
        let w_c = draw(n_state * d_in);

        Self {
            log_a,
            w_delta,
            b_delta: 0.0,
            w_b,
            w_c,
            d_skip: vec![1.0; d_in],
            h: vec![0.0; d_in * n_state],
            n_state,
            d_in,
        }
    }

    /// Number of input (and output) channels.
    #[inline]
    pub fn d_in(&self) -> usize {
        self.d_in
    }

    /// Number of state dimensions per channel.
    #[inline]
    pub fn n_state(&self) -> usize {
        self.n_state
    }

    /// Resets everything that belongs to channel `d` and leaves all other
    /// channels untouched.
    ///
    /// The channel's state column is zeroed, its `w_delta` entry and its
    /// `w_b` / `w_c` columns are redrawn (standard normal × `INIT_SCALE`)
    /// from `rng`, and its skip gain is set back to `1.0`.
    ///
    /// # Panics
    ///
    /// Panics if `d >= d_in`.
    pub fn reinitialize_channel(&mut self, d: usize, rng: &mut u64) {
        assert!(
            d < self.d_in,
            "channel index {} out of range (d_in={})",
            d,
            self.d_in
        );

        let stride = self.d_in;
        // Flat index of (state n, channel d) in `h`, `w_b` and `w_c`.
        let at = |n: usize| n * stride + d;

        for n in 0..self.n_state {
            self.h[at(n)] = 0.0;
        }

        // Draw order: w_delta, the w_b column, then the w_c column.
        self.w_delta[d] = standard_normal(rng) * Self::INIT_SCALE;
        for n in 0..self.n_state {
            self.w_b[at(n)] = standard_normal(rng) * Self::INIT_SCALE;
        }
        for n in 0..self.n_state {
            self.w_c[at(n)] = standard_normal(rng) * Self::INIT_SCALE;
        }

        self.d_skip[d] = 1.0;
    }

    /// Advances the recurrent state by one step and returns the `d_in` outputs.
    ///
    /// Steps performed:
    /// 1. `delta = softplus(w_delta · input + b_delta)` — a scalar step size.
    /// 2. `b_t = W_b · input` and `c_t = W_c · input` — per-state gains.
    /// 3. For each state `n`, with `a = -exp(log_a[n])`: zero-order-hold
    ///    discretisation `a_bar = exp(delta * a)`,
    ///    `b_bar = (a_bar - 1) / a * b_t[n]`, then
    ///    `h[n][d] = a_bar * h[n][d] + b_bar * input[d]`.
    /// 4. `output[d] = Σ_n c_t[n] * h[n][d] + d_skip[d] * input[d]`.
    ///
    /// The discretisation, state update and readout share a single loop over
    /// `n`, so no per-step scratch vectors are needed for `a_bar` / `b_bar`.
    /// Accumulation into `output` still runs in ascending `n`.
    fn selective_forward(&mut self, input: &[f64]) -> Vec<f64> {
        let d_in = self.d_in;
        let n_state = self.n_state;

        let delta = softplus(dot(&self.w_delta, input) + self.b_delta);

        let mut b_t = vec![0.0; n_state];
        mat_vec(&self.w_b, input, n_state, d_in, &mut b_t);
        let mut c_t = vec![0.0; n_state];
        mat_vec(&self.w_c, input, n_state, d_in, &mut c_t);

        let mut output = vec![0.0; d_in];
        for n in 0..n_state {
            let a_n = -math::exp(self.log_a[n]);
            let a_bar = math::exp(delta * a_n);
            // As a_n -> 0 the quotient (a_bar - 1) / a_n tends to delta;
            // use that limit directly instead of dividing by ~0.
            let b_bar = if math::abs(a_n) < 1e-12 {
                delta * b_t[n]
            } else {
                (a_bar - 1.0) / a_n * b_t[n]
            };
            let c_n = c_t[n];

            let row = &mut self.h[n * d_in..(n + 1) * d_in];
            // The update covers only as many channels as `input` provides...
            for (h_d, &x_d) in row.iter_mut().zip(input) {
                *h_d = a_bar * *h_d + b_bar * x_d;
            }
            // ...while the readout always spans the full row.
            for (out_d, &h_d) in output.iter_mut().zip(row.iter()) {
                *out_d += c_n * h_d;
            }
        }

        // Direct skip path from input to output.
        for ((out_d, &skip), &x_d) in output.iter_mut().zip(&self.d_skip).zip(input) {
            *out_d += skip * x_d;
        }
        output
    }
}
impl SSMLayer for SelectiveSSM {
    /// Runs one selective step on `input` and returns the `d_in` outputs.
    ///
    /// The length check is a debug assertion only; release builds do not
    /// verify that `input.len() == d_in`.
    fn forward(&mut self, input: &[f64]) -> Vec<f64> {
        let got = input.len();
        debug_assert_eq!(
            got, self.d_in,
            "input length {} must match d_in {}",
            got, self.d_in
        );
        self.selective_forward(input)
    }

    /// Borrows the flat recurrent state (`d_in * n_state` values).
    fn state(&self) -> &[f64] {
        self.h.as_slice()
    }

    /// Output width, which equals the number of input channels.
    fn output_dim(&self) -> usize {
        self.d_in
    }

    /// Zeroes the recurrent state; the learned/initialised weights are kept.
    fn reset(&mut self) {
        self.h.fill(0.0);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum of squared entries, used as a simple "energy" measure of a state.
    fn sum_sq(values: &[f64]) -> f64 {
        values.iter().map(|v| v * v).sum()
    }

    /// Squared Euclidean distance over the common prefix of two slices.
    fn sq_dist(lhs: &[f64], rhs: &[f64]) -> f64 {
        lhs.iter()
            .zip(rhs.iter())
            .map(|(p, q)| (p - q) * (p - q))
            .sum()
    }

    /// Constructor reports the requested sizes and allocates `d_in * n_state` state.
    #[test]
    fn new_creates_correct_dimensions() {
        let layer = SelectiveSSM::new(4, 8, 42);
        assert_eq!(layer.d_in(), 4);
        assert_eq!(layer.n_state(), 8);
        assert_eq!(layer.state().len(), 4 * 8);
        assert_eq!(layer.output_dim(), 4);
    }

    /// A freshly built layer has an all-zero state.
    #[test]
    fn initial_state_is_zero() {
        let layer = SelectiveSSM::new(3, 16, 42);
        for value in layer.state().iter().copied() {
            assert!(math::abs(value) < 1e-15, "initial state should be zero");
        }
    }

    /// Output length equals the number of input channels.
    #[test]
    fn forward_produces_correct_output_dim() {
        let mut layer = SelectiveSSM::new(5, 8, 42);
        let result = layer.forward(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(result.len(), 5, "output dim should match d_in");
    }

    /// Ordinary inputs give finite outputs.
    #[test]
    fn forward_produces_finite_output() {
        let mut layer = SelectiveSSM::new(3, 8, 42);
        let result = layer.forward(&[1.0, -1.0, 0.5]);
        for i in 0..result.len() {
            let y = result[i];
            assert!(y.is_finite(), "output[{}] should be finite, got {}", i, y);
        }
    }

    /// A non-zero input leaves a non-zero state behind.
    #[test]
    fn forward_updates_state() {
        let mut layer = SelectiveSSM::new(3, 8, 42);
        let _ = layer.forward(&[1.0, 2.0, 3.0]);
        let state_norm = sum_sq(layer.state());
        assert!(
            state_norm > 0.0,
            "state should be non-zero after processing non-zero input"
        );
    }

    /// `reset` returns the state to zero.
    #[test]
    fn reset_clears_state() {
        let mut layer = SelectiveSSM::new(3, 8, 42);
        let _ = layer.forward(&[1.0, 2.0, 3.0]);
        layer.reset();
        for value in layer.state().iter().copied() {
            assert!(math::abs(value) < 1e-15, "state should be zero after reset");
        }
    }

    /// With zero input the state energy shrinks to under 1% within 200 steps.
    #[test]
    fn state_decays_without_input() {
        let mut layer = SelectiveSSM::new(2, 4, 42);
        let _ = layer.forward(&[10.0, 10.0]);
        let energy_after = sum_sq(layer.state());

        let silence = [0.0, 0.0];
        for _ in 0..200 {
            let _ = layer.forward(&silence);
        }
        let energy_decayed = sum_sq(layer.state());

        assert!(
            energy_decayed < energy_after * 0.01,
            "state should decay with zero input: initial={}, after={}",
            energy_after,
            energy_decayed
        );
    }

    /// Two layers built from the same seed agree on their first output.
    #[test]
    fn deterministic_with_same_seed() {
        let probe = [1.0, 2.0, 3.0];
        let first = SelectiveSSM::new(3, 8, 42).forward(&probe);
        let second = SelectiveSSM::new(3, 8, 42).forward(&probe);
        for (i, (a, b)) in first.iter().copied().zip(second.iter().copied()).enumerate() {
            assert!(
                math::abs(a - b) < 1e-15,
                "output[{}] should be identical for same seed: {} vs {}",
                i,
                a,
                b
            );
        }
    }

    /// Layers built from different seeds disagree on the same input.
    #[test]
    fn different_seeds_produce_different_outputs() {
        let probe = [1.0, 2.0, 3.0];
        let first = SelectiveSSM::new(3, 8, 42).forward(&probe);
        let second = SelectiveSSM::new(3, 8, 99).forward(&probe);
        let diff = sq_dist(&first, &second);
        assert!(
            diff > 1e-20,
            "different seeds should generally produce different outputs"
        );
    }

    /// Works with a single channel.
    #[test]
    fn single_channel_works() {
        let mut layer = SelectiveSSM::new(1, 4, 42);
        let result = layer.forward(&[3.0]);
        assert_eq!(result.len(), 1);
        assert!(result[0].is_finite());
    }

    /// Works with a single state dimension.
    #[test]
    fn single_state_dim_works() {
        let mut layer = SelectiveSSM::new(3, 1, 42);
        let result = layer.forward(&[1.0, 2.0, 3.0]);
        assert_eq!(result.len(), 3);
        for y in result.iter() {
            assert!(y.is_finite());
        }
    }

    /// Repeating the same input gives a different output because of the state.
    #[test]
    fn sequential_outputs_differ() {
        let mut layer = SelectiveSSM::new(2, 4, 42);
        let probe = [1.0, 0.0];
        let out1 = layer.forward(&probe);
        let out2 = layer.forward(&probe);
        let diff = sq_dist(&out1, &out2);
        assert!(
            diff > 1e-20,
            "sequential calls with same input should differ due to state: out1={:?}, out2={:?}",
            out1,
            out2
        );
    }

    /// Inputs of magnitude 1000 still give finite outputs.
    #[test]
    fn large_input_no_overflow() {
        let mut layer = SelectiveSSM::new(2, 4, 42);
        let result = layer.forward(&[1000.0, -1000.0]);
        for i in 0..result.len() {
            let y = result[i];
            assert!(
                y.is_finite(),
                "output[{}] should be finite for large inputs, got {}",
                i,
                y
            );
        }
    }

    /// Zero input on a zero state gives exactly zero output.
    #[test]
    fn zero_input_zero_state_gives_zero_output() {
        let mut layer = SelectiveSSM::new(3, 8, 42);
        let result = layer.forward(&[0.0, 0.0, 0.0]);
        for i in 0..result.len() {
            let y = result[i];
            assert!(
                math::abs(y) < 1e-15,
                "zero input with zero state should give zero output[{}], got {}",
                i,
                y
            );
        }
    }

    /// Re-initialising channel 1 resets that channel only; channels 0 and 2
    /// keep their state and weights.
    #[test]
    fn reinitialize_channel_preserves_others() {
        let mut ssm = SelectiveSSM::new(3, 8, 42);

        // Warm the layer up so the state is non-trivial.
        for step in 0..10 {
            let t = step as f64;
            let _ = ssm.forward(&[t * 0.3, t * -0.2, t * 0.1]);
        }

        // Snapshot everything that must survive the re-initialisation.
        let d_in = ssm.d_in;
        let n_state = ssm.n_state;
        let h_before = ssm.state().to_vec();
        let w_delta_before = ssm.w_delta.clone();
        let w_b_before = ssm.w_b.clone();
        let w_c_before = ssm.w_c.clone();

        let mut rng = 0xBEEF_u64;
        ssm.reinitialize_channel(1, &mut rng);

        // State columns of the untouched channels.
        for n in 0..n_state {
            let idx = n * d_in;
            assert!(
                math::abs(ssm.h[idx] - h_before[idx]) < 1e-15,
                "channel 0 state[{}] should be preserved after reinit of channel 1",
                n
            );
        }
        for n in 0..n_state {
            let idx = n * d_in + 2;
            assert!(
                math::abs(ssm.h[idx] - h_before[idx]) < 1e-15,
                "channel 2 state[{}] should be preserved after reinit of channel 1",
                n
            );
        }

        // State column of the re-initialised channel.
        for n in 0..n_state {
            let idx = n * d_in + 1;
            assert!(
                math::abs(ssm.h[idx]) < 1e-15,
                "channel 1 state[{}] should be zeroed after reinit, got {}",
                n,
                ssm.h[idx]
            );
        }

        // Step-size weights of the untouched channels.
        assert!(
            math::abs(ssm.w_delta[0] - w_delta_before[0]) < 1e-15,
            "w_delta[0] should be preserved"
        );
        assert!(
            math::abs(ssm.w_delta[2] - w_delta_before[2]) < 1e-15,
            "w_delta[2] should be preserved"
        );

        // Projection columns of the untouched channels.
        for n in 0..n_state {
            let col0 = n * d_in;
            let col2 = n * d_in + 2;
            assert!(
                math::abs(ssm.w_b[col0] - w_b_before[col0]) < 1e-15,
                "w_b col 0 row {} should be preserved",
                n
            );
            assert!(
                math::abs(ssm.w_b[col2] - w_b_before[col2]) < 1e-15,
                "w_b col 2 row {} should be preserved",
                n
            );
            assert!(
                math::abs(ssm.w_c[col0] - w_c_before[col0]) < 1e-15,
                "w_c col 0 row {} should be preserved",
                n
            );
            assert!(
                math::abs(ssm.w_c[col2] - w_c_before[col2]) < 1e-15,
                "w_c col 2 row {} should be preserved",
                n
            );
        }

        // The redrawn column must contain at least one non-zero weight.
        let any_wb_diff = (0..n_state).any(|n| math::abs(ssm.w_b[n * d_in + 1]) > 1e-15);
        assert!(
            any_wb_diff,
            "reinitialised channel 1 w_b should have non-zero weights"
        );

        assert!(
            math::abs(ssm.d_skip[1] - 1.0) < 1e-15,
            "d_skip[1] should be reset to 1.0 after reinit, got {}",
            ssm.d_skip[1]
        );
    }
}