use alloc::vec;
use alloc::vec::Vec;
use crate::math;
use crate::rng::standard_normal;
use crate::ssm::discretize::{exp_trapezoidal_complex, trapezoidal_complex};
use crate::ssm::init::s4d_inv_complex;
use crate::ssm::norm::BCNorm;
use crate::ssm::projection::{dot, mat_vec, sigmoid, softplus, Xorshift64};
use crate::ssm::SSMLayer;
/// Grouped selective state-space layer with complex diagonal dynamics and
/// trapezoidal discretization.
///
/// The `d_in` channels are split into `n_groups` contiguous groups of
/// `d_in / n_groups` channels. Each group has its own input-dependent B/C
/// projections and `n_state` complex state entries. These are driven by the
/// mean of the group's channels. The step size `delta` is shared by all groups.
pub struct SelectiveSSMv3 {
    /// Eigenvalue parameters, interleaved per eigenvalue `n`: index `2n` holds
    /// `ln(-Re A_n)` (the forward pass uses `-exp` of it) and index `2n + 1`
    /// holds `Im A_n`, used as-is.
    log_a_complex: Vec<f64>,
    /// Step-size projection of length `d_in`:
    /// `delta = softplus(w_delta · x + b_delta)`.
    w_delta: Vec<f64>,
    /// Step-size bias (initialized to 0).
    b_delta: f64,
    /// B projections: `n_groups` consecutive `n_state × d_in` blocks, each fed to
    /// `mat_vec` (layout presumably row-major, as `mat_vec` expects — confirm).
    w_b: Vec<f64>,
    /// C projections, same layout as `w_b`.
    w_c: Vec<f64>,
    /// Per-channel gain on the raw input added to the output (initialized to 1).
    d_skip: Vec<f64>,
    /// Complex state of `2 * n_groups * n_state` values. Entry `(g, n)` stores its
    /// real part at `(g * n_state + n) * 2` and its imaginary part right after.
    h: Vec<f64>,
    /// Number of complex eigenvalues per group.
    n_state: usize,
    /// Number of input/output channels.
    d_in: usize,
    /// Number of channel groups; always divides `d_in`.
    n_groups: usize,
}
impl SelectiveSSMv3 {
    /// Builds a layer with `n_groups` channel groups of `n_state` complex modes each.
    ///
    /// Weights are drawn deterministically from `Xorshift64(seed)` in the order
    /// `w_delta`, `w_b`, `w_c`, each scaled by 0.1. `b_delta` starts at 0,
    /// every `d_skip` gain at 1, and the state at zero. The eigenvalue
    /// parameters come from `s4d_inv_complex(n_state)`.
    ///
    /// # Panics
    ///
    /// Panics if `n_groups` is zero or does not divide `d_in`.
    pub fn new(d_in: usize, n_state: usize, n_groups: usize, seed: u64) -> Self {
        // Checked first so a zero group count gives a readable message instead
        // of the integer "remainder with a divisor of zero" panic below.
        assert!(n_groups > 0, "n_groups must be at least 1, got {}", n_groups);
        assert!(
            d_in % n_groups == 0,
            "d_in ({}) must be evenly divisible by n_groups ({})",
            d_in,
            n_groups
        );
        let log_a_complex = s4d_inv_complex(n_state);
        let mut rng = Xorshift64(seed);
        let scale = 0.1;
        // Keep this draw order: changing it changes the weights a seed produces.
        let w_delta: Vec<f64> = (0..d_in).map(|_| rng.next_normal() * scale).collect();
        let w_b: Vec<f64> = (0..n_groups * n_state * d_in)
            .map(|_| rng.next_normal() * scale)
            .collect();
        let w_c: Vec<f64> = (0..n_groups * n_state * d_in)
            .map(|_| rng.next_normal() * scale)
            .collect();
        Self {
            log_a_complex,
            w_delta,
            b_delta: 0.0,
            w_b,
            w_c,
            d_skip: vec![1.0; d_in],
            h: vec![0.0; 2 * n_groups * n_state],
            n_state,
            d_in,
            n_groups,
        }
    }

    /// Number of input (and output) channels.
    #[inline]
    pub fn d_in(&self) -> usize {
        self.d_in
    }

    /// Number of complex eigenvalues per group.
    #[inline]
    pub fn n_state(&self) -> usize {
        self.n_state
    }

    /// Number of channel groups.
    #[inline]
    pub fn n_groups(&self) -> usize {
        self.n_groups
    }

    /// Re-draws group `g`'s B/C weights, zeroes its state, and resets its skip
    /// gains to 1. All other groups are left untouched.
    ///
    /// New weights are `standard_normal(rng) * 0.1`. The whole `w_b` block is
    /// drawn first, then the `w_c` block.
    ///
    /// # Panics
    ///
    /// Panics if `g >= n_groups`.
    pub fn reinitialize_group(&mut self, g: usize, rng: &mut u64) {
        assert!(
            g < self.n_groups,
            "group index {} out of range (n_groups={})",
            g,
            self.n_groups
        );
        let scale = 0.1;
        let cpg = self.d_in / self.n_groups;
        let block = self.n_state * self.d_in;
        let w_range = g * block..(g + 1) * block;
        // Group g's state is one contiguous run of n_state (re, im) pairs.
        let h_len = 2 * self.n_state;

        self.h[g * h_len..(g + 1) * h_len].fill(0.0);
        for w in &mut self.w_b[w_range.clone()] {
            *w = standard_normal(rng) * scale;
        }
        for w in &mut self.w_c[w_range] {
            *w = standard_normal(rng) * scale;
        }
        self.d_skip[g * cpg..(g + 1) * cpg].fill(1.0);
    }

    /// Advances the recurrent state by one step and returns `d_in` outputs.
    ///
    /// For each group `g`:
    /// - Compute `B_g = W_B[g]·x` and `C_g = W_C[g]·x`, each of length `n_state`.
    /// - Take the group input `x̄_g` as the mean of the group's channels.
    /// - Update every complex state entry as `h ← Ā·h + B̄·(B_g[n]·x̄_g)`, with
    ///   `(Ā, B̄)` from `trapezoidal_complex`.
    /// - Give every channel of the group the same readout `Σ_n C_g[n]·Re(h_n)`,
    ///   plus its own `d_skip`-weighted input.
    fn mimo_forward(&mut self, input: &[f64]) -> Vec<f64> {
        let d_in = self.d_in;
        let n_state = self.n_state;
        let cpg = d_in / self.n_groups;
        let block = n_state * d_in;

        let delta = softplus(dot(&self.w_delta, input) + self.b_delta);

        // `delta` is shared by every group, so the discretized coefficients depend
        // only on the eigenvalue index. Compute them once per step instead of once
        // per (group, eigenvalue) pair.
        let coeffs: Vec<_> = (0..n_state)
            .map(|n| {
                let a_re = -math::exp(self.log_a_complex[2 * n]);
                let a_im = self.log_a_complex[2 * n + 1];
                trapezoidal_complex(a_re, a_im, delta)
            })
            .collect();

        let mut output = vec![0.0; d_in];
        for g in 0..self.n_groups {
            let w_off = g * block;
            let mut b_t_g = vec![0.0; n_state];
            mat_vec(&self.w_b[w_off..w_off + block], input, n_state, d_in, &mut b_t_g);
            let mut c_t_g = vec![0.0; n_state];
            mat_vec(&self.w_c[w_off..w_off + block], input, n_state, d_in, &mut c_t_g);

            let group_start = g * cpg;
            let group = group_start..group_start + cpg;
            let mut x_group = 0.0;
            for &x in &input[group.clone()] {
                x_group += x;
            }
            x_group /= cpg as f64;

            let mut y_group = 0.0;
            for (n, &(a_bar_re, a_bar_im, b_fac_re, b_fac_im)) in coeffs.iter().enumerate() {
                let bx = b_t_g[n] * x_group;
                let h_idx = (g * n_state + n) * 2;
                let h_re_old = self.h[h_idx];
                let h_im_old = self.h[h_idx + 1];
                // Complex multiply Ā·h plus the discretized input term.
                let h_re = a_bar_re * h_re_old - a_bar_im * h_im_old + b_fac_re * bx;
                let h_im = a_bar_re * h_im_old + a_bar_im * h_re_old + b_fac_im * bx;
                self.h[h_idx] = h_re;
                self.h[h_idx + 1] = h_im;
                // Only the real part of the state contributes to the readout.
                y_group += c_t_g[n] * h_re;
            }

            for idx in group {
                output[idx] = y_group + self.d_skip[idx] * input[idx];
            }
        }
        output
    }
}
impl SSMLayer for SelectiveSSMv3 {
    /// Runs one recurrent step. `input` must hold exactly `d_in` values; this
    /// is only checked in debug builds.
    fn forward(&mut self, input: &[f64]) -> Vec<f64> {
        let (got, want) = (input.len(), self.d_in);
        debug_assert_eq!(got, want, "input length {} must match d_in {}", got, want);
        self.mimo_forward(input)
    }

    /// Borrows the interleaved (re, im) complex state buffer.
    fn state(&self) -> &[f64] {
        self.h.as_slice()
    }

    /// The layer is shape-preserving: output width equals `d_in`.
    fn output_dim(&self) -> usize {
        self.d_in
    }

    /// Zeroes the recurrent state; weights are left untouched.
    fn reset(&mut self) {
        self.h.fill(0.0);
    }
}
/// Grouped selective state-space layer using the three-term recurrence produced
/// by `exp_trapezoidal_complex`.
///
/// The grouping and state layout match `SelectiveSSMv3`. In addition, each step
/// mixes the previous step's input term (kept in `prev_bx`) with the current one.
/// The mixing weight is an input-dependent `lambda`.
pub struct SelectiveSSMv3Exp {
    /// Eigenvalue parameters, interleaved per eigenvalue `n`: `[2n]` holds
    /// `ln(-Re A_n)` and `[2n + 1]` holds `Im A_n`.
    log_a_complex: Vec<f64>,
    /// Step-size projection of length `d_in`:
    /// `delta = softplus(w_delta · x + b_delta)`.
    w_delta: Vec<f64>,
    /// Step-size bias (initialized to 0).
    b_delta: f64,
    /// Mixing-weight projection of length `d_in`:
    /// `lambda = sigmoid(w_lambda · x + b_lambda)`.
    w_lambda: Vec<f64>,
    /// Mixing-weight bias (initialized to 0).
    b_lambda: f64,
    /// B projections: `n_groups` consecutive `n_state × d_in` blocks.
    w_b: Vec<f64>,
    /// C projections, same layout as `w_b`.
    w_c: Vec<f64>,
    /// Per-channel gain on the raw input added to the output (initialized to 1).
    d_skip: Vec<f64>,
    /// Complex state, `2 * n_groups * n_state` values; entry `(g, n)` lives at
    /// `(g * n_state + n) * 2` (real) and the next index (imaginary).
    h: Vec<f64>,
    /// Previous step's `B[n] · x̄_g`, same layout as `h`. Only the real slot is
    /// read; the imaginary slot is always written as 0.
    prev_bx: Vec<f64>,
    /// Number of complex eigenvalues per group.
    n_state: usize,
    /// Number of input/output channels.
    d_in: usize,
    /// Number of channel groups; always divides `d_in`.
    n_groups: usize,
    /// Optional normalization applied to both the B and C projections per group.
    bcnorm: Option<BCNorm>,
}
impl SelectiveSSMv3Exp {
pub fn new(d_in: usize, n_state: usize, n_groups: usize, seed: u64, use_bcnorm: bool) -> Self {
assert!(
d_in % n_groups == 0,
"d_in ({}) must be divisible by n_groups ({})",
d_in,
n_groups
);
let log_a_complex = s4d_inv_complex(n_state);
let mut rng = Xorshift64(seed);
let scale = 0.1;
let w_delta: Vec<f64> = (0..d_in).map(|_| rng.next_normal() * scale).collect();
let b_delta = 0.0;
let w_lambda: Vec<f64> = (0..d_in).map(|_| rng.next_normal() * scale).collect();
let b_lambda = 0.0_f64;
let w_b: Vec<f64> = (0..n_groups * n_state * d_in)
.map(|_| rng.next_normal() * scale)
.collect();
let w_c: Vec<f64> = (0..n_groups * n_state * d_in)
.map(|_| rng.next_normal() * scale)
.collect();
let d_skip = vec![1.0; d_in];
let h = vec![0.0; 2 * n_groups * n_state];
let prev_bx = vec![0.0; 2 * n_groups * n_state];
let bcnorm = if use_bcnorm {
Some(BCNorm::new(n_state))
} else {
None
};
Self {
log_a_complex,
w_delta,
b_delta,
w_lambda,
b_lambda,
w_b,
w_c,
d_skip,
h,
prev_bx,
n_state,
d_in,
n_groups,
bcnorm,
}
}
#[inline]
pub fn d_in(&self) -> usize {
self.d_in
}
#[inline]
pub fn n_state(&self) -> usize {
self.n_state
}
#[inline]
pub fn n_groups(&self) -> usize {
self.n_groups
}
#[inline]
pub fn uses_bcnorm(&self) -> bool {
self.bcnorm.is_some()
}
pub fn reinitialize_group(&mut self, g: usize, rng: &mut u64) {
assert!(g < self.n_groups, "group index {} out of range", g);
let scale = 0.1;
let cpg = self.d_in / self.n_groups;
for n in 0..self.n_state {
let idx = (g * self.n_state + n) * 2;
self.h[idx] = 0.0;
self.h[idx + 1] = 0.0;
self.prev_bx[idx] = 0.0;
self.prev_bx[idx + 1] = 0.0;
}
let wb_start = g * self.n_state * self.d_in;
for i in 0..self.n_state * self.d_in {
self.w_b[wb_start + i] = standard_normal(rng) * scale;
self.w_c[wb_start + i] = standard_normal(rng) * scale;
}
let ch_start = g * cpg;
for d in ch_start..ch_start + cpg {
self.d_skip[d] = 1.0;
}
}
fn exp_trap_forward(&mut self, input: &[f64]) -> Vec<f64> {
let d_in = self.d_in;
let n_state = self.n_state;
let n_groups = self.n_groups;
let cpg = d_in / n_groups;
let delta_raw = dot(&self.w_delta, input) + self.b_delta;
let delta = softplus(delta_raw);
let lambda_raw = dot(&self.w_lambda, input) + self.b_lambda;
let lambda = sigmoid(lambda_raw);
let mut output = vec![0.0; d_in];
for g in 0..n_groups {
let wb_offset = g * n_state * d_in;
let mut b_t_g = vec![0.0; n_state];
mat_vec(
&self.w_b[wb_offset..wb_offset + n_state * d_in],
input,
n_state,
d_in,
&mut b_t_g,
);
let wc_offset = g * n_state * d_in;
let mut c_t_g = vec![0.0; n_state];
mat_vec(
&self.w_c[wc_offset..wc_offset + n_state * d_in],
input,
n_state,
d_in,
&mut c_t_g,
);
if let Some(ref norm) = self.bcnorm {
b_t_g = norm.normalize(&b_t_g);
c_t_g = norm.normalize(&c_t_g);
}
let group_start = g * cpg;
let mut x_group = 0.0;
for d in 0..cpg {
x_group += input[group_start + d];
}
x_group /= cpg as f64;
let mut y_group = 0.0;
for n in 0..n_state {
let a_re = -math::exp(self.log_a_complex[2 * n]);
let a_im = self.log_a_complex[2 * n + 1];
let (alpha_re, alpha_im, beta_re, beta_im, gamma_re, gamma_im) =
exp_trapezoidal_complex(a_re, a_im, delta, lambda);
let h_idx = (g * n_state + n) * 2;
let h_re_old = self.h[h_idx];
let h_im_old = self.h[h_idx + 1];
let bx = b_t_g[n] * x_group;
let pbx_re = self.prev_bx[h_idx];
let ah_re = alpha_re * h_re_old - alpha_im * h_im_old;
let ah_im = alpha_re * h_im_old + alpha_im * h_re_old;
let b_prev_re = beta_re * pbx_re;
let b_prev_im = beta_im * pbx_re;
let b_curr_re = gamma_re * bx;
let b_curr_im = gamma_im * bx;
let h_re = ah_re + b_prev_re + b_curr_re;
let h_im = ah_im + b_prev_im + b_curr_im;
self.h[h_idx] = h_re;
self.h[h_idx + 1] = h_im;
self.prev_bx[h_idx] = bx;
self.prev_bx[h_idx + 1] = 0.0;
y_group += c_t_g[n] * h_re;
}
for d in 0..cpg {
let idx = group_start + d;
output[idx] = y_group + self.d_skip[idx] * input[idx];
}
}
output
}
}
impl SSMLayer for SelectiveSSMv3Exp {
    /// Runs one exponential-trapezoidal step. The input width is checked
    /// against `d_in` only in debug builds.
    fn forward(&mut self, input: &[f64]) -> Vec<f64> {
        let expected = self.d_in;
        debug_assert_eq!(input.len(), expected, "input length mismatch");
        self.exp_trap_forward(input)
    }

    /// Borrows the interleaved (re, im) complex state buffer.
    fn state(&self) -> &[f64] {
        self.h.as_slice()
    }

    /// Output width equals the input width.
    fn output_dim(&self) -> usize {
        self.d_in
    }

    /// Clears the state and the stored previous-step input term.
    fn reset(&mut self) {
        for buf in [&mut self.h, &mut self.prev_bx] {
            buf.fill(0.0);
        }
    }
}
/// Grouped selective state-space layer with per-channel complex state and
/// rank-`rank` B projections, using `exp_trapezoidal_complex` discretization.
///
/// Unlike `SelectiveSSMv3Exp`, channels are not averaged within a group. Each
/// channel `p` of group `g` drives its own state entries `h[g, n, p]`.
pub struct SelectiveSSMv3Mimo {
    /// Eigenvalue parameters, interleaved per eigenvalue `n`: `[2n]` holds
    /// `ln(-Re A_n)` and `[2n + 1]` holds `Im A_n`.
    log_a_complex: Vec<f64>,
    /// Step-size projection of length `d_in`:
    /// `delta = softplus(w_delta · x + b_delta)`.
    w_delta: Vec<f64>,
    /// Step-size bias (initialized to 0).
    b_delta: f64,
    /// Mixing-weight projection of length `d_in`:
    /// `lambda = sigmoid(w_lambda · x + b_lambda)`.
    w_lambda: Vec<f64>,
    /// Mixing-weight bias (initialized to 0).
    b_lambda: f64,
    /// B projections: `n_groups` consecutive `(n_state * rank) × d_in` blocks.
    /// Output row `n * rank + r` is the rank-`r` coefficient for eigenvalue `n`.
    w_b: Vec<f64>,
    /// C projections: `n_groups` consecutive `n_state × d_in` blocks.
    w_c: Vec<f64>,
    /// Per-channel gain on the raw input added to the output (initialized to 1).
    d_skip: Vec<f64>,
    /// Complex state, `2 * n_groups * n_state * cpg` values. Entry `(g, n, p)`
    /// stores its real part at `((g * n_state + n) * cpg + p) * 2` and its
    /// imaginary part right after.
    h: Vec<f64>,
    /// Previous step's input term, same layout as `h`. Only the real slot is
    /// read; the imaginary slot is always written as 0.
    prev_bx: Vec<f64>,
    /// Number of complex eigenvalues per group.
    n_state: usize,
    /// Number of input/output channels.
    d_in: usize,
    /// Number of channel groups; always divides `d_in`.
    n_groups: usize,
    /// Channels per group (`d_in / n_groups`).
    cpg: usize,
    /// Number of B rows per eigenvalue; channel `p` uses row `p % rank`.
    rank: usize,
    /// Optional normalization; `mimo_forward` applies it to C only.
    bcnorm: Option<BCNorm>,
}
impl SelectiveSSMv3Mimo {
pub fn new(
d_in: usize,
n_state: usize,
n_groups: usize,
rank: usize,
seed: u64,
use_bcnorm: bool,
) -> Self {
assert!(
d_in % n_groups == 0,
"d_in ({}) must be divisible by n_groups ({})",
d_in,
n_groups
);
assert!(rank >= 1, "rank must be >= 1, got {}", rank);
let cpg = d_in / n_groups;
let log_a_complex = s4d_inv_complex(n_state);
let mut rng = Xorshift64(seed);
let scale = 0.1;
let w_delta: Vec<f64> = (0..d_in).map(|_| rng.next_normal() * scale).collect();
let b_delta = 0.0;
let w_lambda: Vec<f64> = (0..d_in).map(|_| rng.next_normal() * scale).collect();
let b_lambda = 0.0_f64;
let w_b: Vec<f64> = (0..n_groups * n_state * rank * d_in)
.map(|_| rng.next_normal() * scale)
.collect();
let w_c: Vec<f64> = (0..n_groups * n_state * d_in)
.map(|_| rng.next_normal() * scale)
.collect();
let d_skip = vec![1.0; d_in];
let h = vec![0.0; 2 * n_groups * n_state * cpg];
let prev_bx = vec![0.0; 2 * n_groups * n_state * cpg];
let bcnorm = if use_bcnorm {
Some(BCNorm::new(n_state))
} else {
None
};
Self {
log_a_complex,
w_delta,
b_delta,
w_lambda,
b_lambda,
w_b,
w_c,
d_skip,
h,
prev_bx,
n_state,
d_in,
n_groups,
cpg,
rank,
bcnorm,
}
}
#[inline]
pub fn d_in(&self) -> usize {
self.d_in
}
#[inline]
pub fn n_state(&self) -> usize {
self.n_state
}
#[inline]
pub fn n_groups(&self) -> usize {
self.n_groups
}
#[inline]
pub fn rank(&self) -> usize {
self.rank
}
#[inline]
pub fn uses_bcnorm(&self) -> bool {
self.bcnorm.is_some()
}
pub fn reinitialize_group(&mut self, g: usize, rng: &mut u64) {
assert!(g < self.n_groups, "group {} out of range", g);
let scale = 0.1;
let cpg = self.cpg;
for n in 0..self.n_state {
for p in 0..cpg {
let idx = ((g * self.n_state + n) * cpg + p) * 2;
self.h[idx] = 0.0;
self.h[idx + 1] = 0.0;
self.prev_bx[idx] = 0.0;
self.prev_bx[idx + 1] = 0.0;
}
}
let wb_start = g * self.n_state * self.rank * self.d_in;
for i in 0..self.n_state * self.rank * self.d_in {
self.w_b[wb_start + i] = standard_normal(rng) * scale;
}
let wc_start = g * self.n_state * self.d_in;
for i in 0..self.n_state * self.d_in {
self.w_c[wc_start + i] = standard_normal(rng) * scale;
}
let ch_start = g * cpg;
for d in ch_start..ch_start + cpg {
self.d_skip[d] = 1.0;
}
}
fn mimo_forward(&mut self, input: &[f64]) -> Vec<f64> {
let d_in = self.d_in;
let n_state = self.n_state;
let n_groups = self.n_groups;
let cpg = self.cpg;
let rank = self.rank;
let delta = softplus(dot(&self.w_delta, input) + self.b_delta);
let lambda = sigmoid(dot(&self.w_lambda, input) + self.b_lambda);
let mut output = vec![0.0; d_in];
for g in 0..n_groups {
let wb_offset = g * n_state * rank * d_in;
let wb_rows = n_state * rank;
let mut b_t_flat = vec![0.0; wb_rows];
mat_vec(
&self.w_b[wb_offset..wb_offset + wb_rows * d_in],
input,
wb_rows,
d_in,
&mut b_t_flat,
);
let wc_offset = g * n_state * d_in;
let mut c_t_g = vec![0.0; n_state];
mat_vec(
&self.w_c[wc_offset..wc_offset + n_state * d_in],
input,
n_state,
d_in,
&mut c_t_g,
);
if let Some(ref norm) = self.bcnorm {
c_t_g = norm.normalize(&c_t_g);
}
let group_start = g * cpg;
let x_group_slice = &input[group_start..group_start + cpg];
let mut y_channel = vec![0.0; cpg];
for n in 0..n_state {
let a_re = -math::exp(self.log_a_complex[2 * n]);
let a_im = self.log_a_complex[2 * n + 1];
let (alpha_re, alpha_im, beta_re, beta_im, gamma_re, gamma_im) =
exp_trapezoidal_complex(a_re, a_im, delta, lambda);
for p in 0..cpg {
let h_idx = ((g * n_state + n) * cpg + p) * 2;
let h_re_old = self.h[h_idx];
let h_im_old = self.h[h_idx + 1];
let bx = if rank == 1 {
b_t_flat[n] * x_group_slice[p]
} else {
let r = p % rank;
b_t_flat[n * rank + r] * x_group_slice[p]
};
let pbx_re = self.prev_bx[h_idx];
let ah_re = alpha_re * h_re_old - alpha_im * h_im_old;
let ah_im = alpha_re * h_im_old + alpha_im * h_re_old;
let b_prev_re = beta_re * pbx_re;
let b_prev_im = beta_im * pbx_re;
let b_curr_re = gamma_re * bx;
let b_curr_im = gamma_im * bx;
let h_re = ah_re + b_prev_re + b_curr_re;
let h_im = ah_im + b_prev_im + b_curr_im;
self.h[h_idx] = h_re;
self.h[h_idx + 1] = h_im;
self.prev_bx[h_idx] = bx;
self.prev_bx[h_idx + 1] = 0.0;
y_channel[p] += c_t_g[n] * h_re;
}
}
for (p, &yp) in y_channel.iter().enumerate().take(cpg) {
let idx = group_start + p;
output[idx] = yp + self.d_skip[idx] * input[idx];
}
}
output
}
}
impl SSMLayer for SelectiveSSMv3Mimo {
    /// One recurrent step over all groups. Debug builds verify that
    /// `input.len() == d_in`.
    fn forward(&mut self, input: &[f64]) -> Vec<f64> {
        let width = input.len();
        debug_assert_eq!(width, self.d_in, "input length mismatch");
        self.mimo_forward(input)
    }

    /// Borrows the per-channel complex state as interleaved (re, im) pairs.
    fn state(&self) -> &[f64] {
        &self.h[..]
    }

    /// Output width equals the input width.
    fn output_dim(&self) -> usize {
        self.d_in
    }

    /// Clears the stored previous-step input term and the state.
    fn reset(&mut self) {
        self.prev_bx.fill(0.0);
        self.h.fill(0.0);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn selective_v3_output_dimension() {
        let mut ssm = SelectiveSSMv3::new(6, 8, 2, 42);
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let output = ssm.forward(&input);
        assert_eq!(
            output.len(),
            6,
            "output dim should match d_in, got {}",
            output.len()
        );
    }

    #[test]
    fn selective_v3_complex_state_bounded() {
        let mut ssm = SelectiveSSMv3::new(4, 8, 2, 42);
        let input = vec![1.0, -0.5, 0.3, -0.8];
        for step in 0..1000 {
            let output = ssm.forward(&input);
            for (i, &y) in output.iter().enumerate() {
                assert!(
                    y.is_finite(),
                    "output[{}] is not finite at step {}: {}",
                    i,
                    step,
                    y
                );
            }
        }
        for (i, &h) in ssm.state().iter().enumerate() {
            assert!(
                h.is_finite(),
                "state[{}] is not finite after 1000 steps: {}",
                i,
                h
            );
        }
        let state_norm: f64 = ssm.state().iter().map(|h| h * h).sum();
        assert!(
            state_norm < 1e12,
            "state norm should be bounded, got {}",
            state_norm
        );
    }

    #[test]
    fn selective_v3_trapezoidal_stability() {
        let log_a = s4d_inv_complex(16);
        let delta = 0.5;
        for n in 0..16 {
            let a_re = -math::exp(log_a[2 * n]);
            let a_im = log_a[2 * n + 1];
            let (a_bar_re, a_bar_im, _, _) = trapezoidal_complex(a_re, a_im, delta);
            let mag_sq = a_bar_re * a_bar_re + a_bar_im * a_bar_im;
            assert!(
                mag_sq < 1.0,
                "eigenvalue {} has |A_bar|^2 = {} >= 1 (a_re={}, a_im={}, delta={})",
                n,
                mag_sq,
                a_re,
                a_im,
                delta
            );
        }
    }

    #[test]
    fn selective_v3_mimo_groups() {
        let d_in = 4;
        let n_state = 4;
        let seed = 42;
        let mut ssm_one = SelectiveSSMv3::new(d_in, n_state, 1, seed);
        let mut ssm_max = SelectiveSSMv3::new(d_in, n_state, d_in, seed);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let out_one = ssm_one.forward(&input);
        let out_max = ssm_max.forward(&input);
        assert_eq!(out_one.len(), d_in);
        assert_eq!(out_max.len(), d_in);
        for &y in &out_one {
            assert!(y.is_finite(), "n_groups=1 output should be finite");
        }
        for &y in &out_max {
            assert!(y.is_finite(), "n_groups=d_in output should be finite");
        }
        let diff: f64 = out_one
            .iter()
            .zip(out_max.iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum();
        assert!(
            diff > 1e-20,
            "different n_groups should produce different outputs: diff={}",
            diff
        );
    }

    #[test]
    fn selective_v3_zero_groups_panics() {
        // A zero group count must be rejected at construction time.
        let result = std::panic::catch_unwind(|| SelectiveSSMv3::new(4, 8, 0, 42));
        assert!(result.is_err(), "n_groups = 0 must panic");
    }

    #[test]
    fn selective_v3_reset_clears_state() {
        let mut ssm = SelectiveSSMv3::new(4, 8, 2, 42);
        let _ = ssm.forward(&[1.0, 2.0, 3.0, 4.0]);
        let energy: f64 = ssm.state().iter().map(|h| h * h).sum();
        assert!(energy > 0.0, "state should be non-zero after forward pass");
        ssm.reset();
        for (i, &h) in ssm.state().iter().enumerate() {
            assert!(
                math::abs(h) < 1e-15,
                "state[{}] should be zero after reset, got {}",
                i,
                h
            );
        }
    }

    #[test]
    fn selective_v3_initial_state_zero() {
        let ssm = SelectiveSSMv3::new(4, 8, 2, 42);
        assert_eq!(
            ssm.state().len(),
            2 * 2 * 8,
            "state size = 2 * n_groups * n_state"
        );
        for &h in ssm.state() {
            assert!(math::abs(h) < 1e-15, "initial state should be zero");
        }
    }

    #[test]
    fn selective_v3_deterministic_same_seed() {
        let mut ssm1 = SelectiveSSMv3::new(4, 8, 2, 42);
        let mut ssm2 = SelectiveSSMv3::new(4, 8, 2, 42);
        let input = vec![1.0, -1.0, 0.5, -0.5];
        let out1 = ssm1.forward(&input);
        let out2 = ssm2.forward(&input);
        for (i, (&a, &b)) in out1.iter().zip(out2.iter()).enumerate() {
            assert!(
                math::abs(a - b) < 1e-15,
                "output[{}] should be identical for same seed: {} vs {}",
                i,
                a,
                b
            );
        }
    }

    #[test]
    fn selective_v3_zero_input_zero_output() {
        let mut ssm = SelectiveSSMv3::new(4, 8, 2, 42);
        let output = ssm.forward(&[0.0, 0.0, 0.0, 0.0]);
        for (i, &y) in output.iter().enumerate() {
            assert!(
                math::abs(y) < 1e-15,
                "zero input with zero state should give zero output[{}], got {}",
                i,
                y
            );
        }
    }

    /// n_groups == d_in, so each group holds exactly one channel.
    #[test]
    fn selective_v3_one_channel_per_group() {
        let mut ssm = SelectiveSSMv3::new(3, 4, 3, 42);
        let output = ssm.forward(&[1.0, 2.0, 3.0]);
        assert_eq!(output.len(), 3);
        for &y in &output {
            assert!(y.is_finite());
        }
    }

    #[test]
    fn selective_v3_sequential_outputs_differ() {
        let mut ssm = SelectiveSSMv3::new(4, 8, 2, 42);
        let input = vec![1.0, 0.0, -1.0, 0.5];
        let out1 = ssm.forward(&input);
        let out2 = ssm.forward(&input);
        let diff: f64 = out1
            .iter()
            .zip(out2.iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum();
        assert!(
            diff > 1e-20,
            "sequential calls should differ due to state: out1={:?}, out2={:?}",
            out1,
            out2
        );
    }

    #[test]
    fn selective_v3_accessors() {
        let ssm = SelectiveSSMv3::new(6, 4, 3, 42);
        assert_eq!(ssm.d_in(), 6);
        assert_eq!(ssm.n_state(), 4);
        assert_eq!(ssm.n_groups(), 3);
        assert_eq!(ssm.output_dim(), 6);
    }

    #[test]
    fn reinitialize_group_preserves_others() {
        let mut ssm = SelectiveSSMv3::new(6, 4, 3, 42);
        for step in 0..10 {
            let s = step as f64;
            let x = vec![s * 0.1, s * -0.2, s * 0.3, s * -0.1, s * 0.2, s * -0.3];
            let _ = ssm.forward(&x);
        }
        let state_before: Vec<f64> = ssm.state().to_vec();
        let n_state = ssm.n_state();
        let d_in = ssm.d_in();
        let wb_g0: Vec<f64> = ssm.w_b[0..n_state * d_in].to_vec();
        let wb_g2: Vec<f64> = ssm.w_b[2 * n_state * d_in..3 * n_state * d_in].to_vec();
        let wc_g0: Vec<f64> = ssm.w_c[0..n_state * d_in].to_vec();
        let wc_g2: Vec<f64> = ssm.w_c[2 * n_state * d_in..3 * n_state * d_in].to_vec();
        let mut rng = 0xBEEF_u64;
        ssm.reinitialize_group(1, &mut rng);
        for n in 0..n_state {
            let idx = n * 2;
            assert!(
                math::abs(ssm.h[idx] - state_before[idx]) < 1e-15,
                "group 0 state re[{}] should be preserved",
                n
            );
            assert!(
                math::abs(ssm.h[idx + 1] - state_before[idx + 1]) < 1e-15,
                "group 0 state im[{}] should be preserved",
                n
            );
        }
        for n in 0..n_state {
            let idx = (2 * n_state + n) * 2;
            assert!(
                math::abs(ssm.h[idx] - state_before[idx]) < 1e-15,
                "group 2 state re[{}] should be preserved",
                n
            );
            assert!(
                math::abs(ssm.h[idx + 1] - state_before[idx + 1]) < 1e-15,
                "group 2 state im[{}] should be preserved",
                n
            );
        }
        for n in 0..n_state {
            let idx = (n_state + n) * 2;
            assert!(
                math::abs(ssm.h[idx]) < 1e-15,
                "group 1 state re[{}] should be zero after reinit, got {}",
                n,
                ssm.h[idx]
            );
            assert!(
                math::abs(ssm.h[idx + 1]) < 1e-15,
                "group 1 state im[{}] should be zero after reinit, got {}",
                n,
                ssm.h[idx + 1]
            );
        }
        assert_eq!(
            &ssm.w_b[0..n_state * d_in],
            wb_g0.as_slice(),
            "group 0 w_b should be preserved"
        );
        assert_eq!(
            &ssm.w_b[2 * n_state * d_in..3 * n_state * d_in],
            wb_g2.as_slice(),
            "group 2 w_b should be preserved"
        );
        assert_eq!(
            &ssm.w_c[0..n_state * d_in],
            wc_g0.as_slice(),
            "group 0 w_c should be preserved"
        );
        assert_eq!(
            &ssm.w_c[2 * n_state * d_in..3 * n_state * d_in],
            wc_g2.as_slice(),
            "group 2 w_c should be preserved"
        );
        assert!(
            math::abs(ssm.d_skip[2] - 1.0) < 1e-15,
            "d_skip[2] should be 1.0 after group 1 reinit"
        );
        assert!(
            math::abs(ssm.d_skip[3] - 1.0) < 1e-15,
            "d_skip[3] should be 1.0 after group 1 reinit"
        );
    }

    #[test]
    fn v3exp_output_dimension() {
        let mut ssm = SelectiveSSMv3Exp::new(6, 8, 2, 42, false);
        let output = ssm.forward(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(output.len(), 6, "output dim must match d_in");
    }

    #[test]
    fn v3exp_state_is_finite_after_many_steps() {
        let mut ssm = SelectiveSSMv3Exp::new(4, 8, 2, 42, false);
        let input = vec![1.0, -0.5, 0.3, -0.8];
        for step in 0..1000 {
            let output = ssm.forward(&input);
            for (i, &y) in output.iter().enumerate() {
                assert!(
                    y.is_finite(),
                    "V3Exp output[{}] must be finite at step {}: {}",
                    i,
                    step,
                    y
                );
            }
        }
        for &h in ssm.state() {
            assert!(
                h.is_finite(),
                "V3Exp state must remain finite after 1000 steps"
            );
        }
    }

    /// A second step that follows a first step must differ from the same input
    /// applied to a freshly reset layer. Both the carried state and `prev_bx`
    /// contribute to the difference.
    #[test]
    fn v3exp_three_term_recurrence_correct() {
        let mut ssm = SelectiveSSMv3Exp::new(4, 8, 2, 42, false);
        let input_a = vec![1.0, 0.5, -0.3, 0.8];
        let input_b = vec![-0.5, 1.0, 0.2, -0.4];
        let out1 = ssm.forward(&input_a);
        let out2 = ssm.forward(&input_b);
        ssm.reset();
        let out2_from_reset = ssm.forward(&input_b);
        let diff: f64 = out2
            .iter()
            .zip(out2_from_reset.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum();
        assert!(
            diff > 1e-15,
            "3-term recurrence must produce different output than 2-term (prev_bx matters): \
             out2_3term={:?} vs out2_2term={:?}, diff={}",
            out2,
            out2_from_reset,
            diff
        );
        for &y in &out1 {
            assert!(y.is_finite(), "V3Exp step 1 output must be finite");
        }
        for &y in &out2 {
            assert!(y.is_finite(), "V3Exp step 2 output must be finite");
        }
    }

    #[test]
    fn v3exp_with_bcnorm_finite() {
        let mut ssm = SelectiveSSMv3Exp::new(4, 8, 2, 42, true);
        assert!(ssm.uses_bcnorm(), "BCNorm should be active");
        for step in 0..100 {
            let input = vec![(step as f64) * 0.1, -(step as f64) * 0.05, 0.3, -0.2];
            let output = ssm.forward(&input);
            for &y in &output {
                assert!(
                    y.is_finite(),
                    "V3Exp+BCNorm output must be finite at step {}",
                    step
                );
            }
        }
    }

    #[test]
    fn v3exp_reset_clears_prev_bx() {
        let mut ssm = SelectiveSSMv3Exp::new(4, 8, 2, 42, false);
        let _ = ssm.forward(&[1.0, 2.0, 3.0, 4.0]);
        ssm.reset();
        for &h in ssm.state() {
            assert!(h.abs() < 1e-15, "state must be zero after reset");
        }
        let zero_out = ssm.forward(&[0.0, 0.0, 0.0, 0.0]);
        for &y in &zero_out {
            assert!(
                y.abs() < 1e-15,
                "zero input after reset must give zero output (prev_bx=0): got {}",
                y
            );
        }
    }

    /// After an alternating-sign warmup, the layer must keep non-zero state and
    /// respond differently to the two parity inputs. No accuracy metric is
    /// measured here.
    #[test]
    fn v3exp_distinguishes_parity_inputs_after_warmup() {
        let mut ssm = SelectiveSSMv3Exp::new(2, 16, 2, 42, false);
        for step in 0..200 {
            let sign = if step % 2 == 0 { 1.0_f64 } else { -1.0_f64 };
            let _ = ssm.forward(&[sign, sign * 0.5]);
        }
        let state_energy: f64 = ssm.state().iter().map(|s| s * s).sum();
        assert!(
            state_energy > 0.0,
            "V3Exp state must be non-zero after parity sequence: energy={}",
            state_energy
        );
        let out_even = ssm.forward(&[1.0, 0.5]);
        let out_odd = ssm.forward(&[-1.0, -0.5]);
        let diff: f64 = out_even
            .iter()
            .zip(out_odd.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff > 1e-5,
            "V3Exp must produce different outputs for even/odd parity inputs after warmup: diff={}",
            diff
        );
    }

    #[test]
    fn v3mimo_output_dimension() {
        let mut ssm = SelectiveSSMv3Mimo::new(6, 8, 2, 1, 42, false);
        let output = ssm.forward(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(output.len(), 6, "MIMO output dim must match d_in");
    }

    #[test]
    fn v3mimo_state_is_matrix_valued_not_scalar() {
        let d_in = 4;
        let n_state = 4;
        let n_groups = 2;
        let cpg = d_in / n_groups;
        let ssm = SelectiveSSMv3Mimo::new(d_in, n_state, n_groups, 1, 42, false);
        let expected_state_len = 2 * n_groups * n_state * cpg;
        assert_eq!(
            ssm.state().len(),
            expected_state_len,
            "V3Mimo state must have length 2*n_groups*n_state*cpg={} (matrix-valued), got {}",
            expected_state_len,
            ssm.state().len()
        );
        let input_ch0_high = vec![10.0, 0.0, 10.0, 0.0];
        let input_ch1_high = vec![0.0, 10.0, 0.0, 10.0];
        let mut ssm_a = SelectiveSSMv3Mimo::new(d_in, n_state, n_groups, 1, 42, false);
        let mut ssm_b = SelectiveSSMv3Mimo::new(d_in, n_state, n_groups, 1, 42, false);
        let _ = ssm_a.forward(&input_ch0_high);
        let _ = ssm_b.forward(&input_ch1_high);
        let state_a = ssm_a.state();
        let state_b = ssm_b.state();
        let state_diff: f64 = state_a
            .iter()
            .zip(state_b.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum();
        assert!(
            state_diff > 1e-10,
            "V3Mimo state must differ for per-channel inputs (matrix-valued state): diff={}",
            state_diff
        );
        let mut ssm_eval_a = SelectiveSSMv3Mimo::new(d_in, n_state, n_groups, 1, 42, false);
        let mut ssm_eval_b = SelectiveSSMv3Mimo::new(d_in, n_state, n_groups, 1, 42, false);
        let out_a = ssm_eval_a.forward(&input_ch0_high);
        let out_b = ssm_eval_b.forward(&input_ch1_high);
        let out_diff: f64 = out_a
            .iter()
            .zip(out_b.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            out_diff > 1e-10,
            "V3Mimo output must differ for ch0-high vs ch1-high inputs: diff={}",
            out_diff
        );
    }

    #[test]
    fn v3mimo_rank1_finite_after_many_steps() {
        let mut ssm = SelectiveSSMv3Mimo::new(4, 8, 2, 1, 42, false);
        let input = vec![1.0, -0.5, 0.3, -0.8];
        for step in 0..1000 {
            let output = ssm.forward(&input);
            for (i, &y) in output.iter().enumerate() {
                assert!(
                    y.is_finite(),
                    "V3Mimo rank=1 output[{}] must be finite at step {}: {}",
                    i,
                    step,
                    y
                );
            }
        }
        for &h in ssm.state() {
            assert!(h.is_finite(), "V3Mimo state must remain finite");
        }
    }

    #[test]
    fn v3mimo_rank2_finite() {
        let mut ssm = SelectiveSSMv3Mimo::new(4, 4, 2, 2, 42, false);
        for _ in 0..100 {
            let y = ssm.forward(&[1.0, -1.0, 0.5, -0.5]);
            for &v in &y {
                assert!(v.is_finite(), "V3Mimo rank=2 output must be finite");
            }
        }
    }

    #[test]
    fn v3mimo_reset_clears_state() {
        let mut ssm = SelectiveSSMv3Mimo::new(4, 4, 2, 1, 42, false);
        let _ = ssm.forward(&[1.0, 2.0, 3.0, 4.0]);
        let energy: f64 = ssm.state().iter().map(|h| h * h).sum();
        assert!(energy > 0.0, "state must be non-zero after forward");
        ssm.reset();
        for &h in ssm.state() {
            assert!(h.abs() < 1e-15, "state must be zero after reset, got {}", h);
        }
    }

    #[test]
    fn v3mimo_reinitialize_group_preserves_others() {
        let (d_in, n_state, n_groups) = (4, 3, 2);
        let cpg = d_in / n_groups;
        let mut ssm = SelectiveSSMv3Mimo::new(d_in, n_state, n_groups, 2, 42, false);
        for step in 0..5 {
            let s = step as f64 + 1.0;
            let _ = ssm.forward(&[s * 0.3, -s * 0.2, s * 0.1, -s * 0.4]);
        }
        let before = ssm.state().to_vec();
        let group_len = 2 * n_state * cpg;
        let mut rng = 0xBEEF_u64;
        ssm.reinitialize_group(1, &mut rng);
        assert_eq!(
            &ssm.state()[..group_len],
            &before[..group_len],
            "group 0 state must be preserved"
        );
        assert!(
            ssm.state()[group_len..].iter().all(|&h| h == 0.0),
            "group 1 state must be zero after reinit"
        );
        assert!(
            ssm.prev_bx[group_len..].iter().all(|&v| v == 0.0),
            "group 1 prev_bx must be zero after reinit"
        );
    }

    #[test]
    fn v3mimo_accessors() {
        let ssm = SelectiveSSMv3Mimo::new(6, 4, 3, 2, 42, false);
        assert_eq!(ssm.d_in(), 6);
        assert_eq!(ssm.n_state(), 4);
        assert_eq!(ssm.n_groups(), 3);
        assert_eq!(ssm.rank(), 2);
        assert!(!ssm.uses_bcnorm());
    }

    #[test]
    fn v3mimo_with_bcnorm_finite() {
        let mut ssm = SelectiveSSMv3Mimo::new(4, 8, 2, 1, 42, true);
        assert!(ssm.uses_bcnorm());
        for _ in 0..100 {
            let y = ssm.forward(&[1.0, -2.0, 3.0, -1.0]);
            for &v in &y {
                assert!(v.is_finite(), "V3Mimo+BCNorm output must be finite");
            }
        }
    }
}