use alloc::vec;
use alloc::vec::Vec;
use crate::math;
use crate::rng::standard_normal;
use crate::ssm::init::s4d_inv_real;
use crate::ssm::projection::{dot, mat_vec, softplus, Xorshift64};
use crate::ssm::SSMLayer;
/// Selective state-space layer with a block-diagonal state transition.
///
/// The `d_in` channels are split into `n_blocks` contiguous groups of
/// `block_size` channels. Each group owns a dense `block_size x block_size`
/// transition matrix, so channels are coupled inside a group and evolve
/// independently across groups.
#[derive(Debug, Clone)]
pub struct SelectiveSSMBD {
    /// Transition matrices stored back to back: block `k` starts at
    /// `k * block_size * block_size`, entry `(i, j)` sits at offset
    /// `i * block_size + j` inside its block.
    a_matrices: Vec<f64>,
    /// Weights of the input-dependent `B` projection (`n_state * d_in` values).
    w_b: Vec<f64>,
    /// Weights of the input-dependent `C` projection (`n_state * d_in` values).
    w_c: Vec<f64>,
    /// Weights of the step-size projection (`d_in` values).
    w_delta: Vec<f64>,
    /// Bias added to the step-size projection before `softplus`.
    b_delta: f64,
    /// Per-channel skip weights applied to the raw input (`d_in` values).
    d_skip: Vec<f64>,
    /// Hidden state (`n_blocks * n_state * block_size` values); the entry for
    /// block `k`, state index `n`, channel `i` sits at
    /// `k * n_state * block_size + n * block_size + i`.
    h: Vec<f64>,
    /// Expected input length, which is also the output length.
    d_in: usize,
    /// Number of state vectors kept per block.
    n_state: usize,
    /// Number of channels per block.
    block_size: usize,
    /// Number of blocks, `d_in / block_size`.
    n_blocks: usize,
}
/// Rescales, in place, every row of the leading `m x m` row-major matrix in
/// `a` whose L1 norm is greater than 1 so that the norm becomes 1.
///
/// Rows whose absolute sum is at most 1 are left untouched, and so are rows
/// whose sum is NaN because the `> 1.0` comparison is false for them.
/// Panics if `a` holds fewer than `m * m` values.
fn normalize_row_l1(a: &mut [f64], m: usize) {
    for row in 0..m {
        let first = row * m;
        let cells = &mut a[first..first + m];
        let l1: f64 = cells.iter().map(|&v| math::abs(v)).sum();
        if l1 > 1.0 {
            cells.iter_mut().for_each(|v| *v /= l1);
        }
    }
}
impl SelectiveSSMBD {
/// Builds a layer with `d_in` channels grouped into blocks of `block_size`
/// channels, keeping `n_state` state vectors per block.
///
/// All random parameters come from a `Xorshift64` generator seeded with
/// `seed`, drawn in a fixed order: the off-diagonal entries of each block
/// (row by row), then `w_delta`, `w_b` and `w_c`.
///
/// * Diagonal entries of every block are `-exp(s4d_inv_real(block_size)[i])`.
/// * Off-diagonal entries are normal draws scaled by `0.02`.
/// * Each block is passed through `normalize_row_l1` afterwards.
/// * `w_delta`, `w_b` and `w_c` are normal draws scaled by `0.1`.
/// * `b_delta` is `0.0`, `d_skip` is all ones and the state starts at zero.
///
/// # Panics
///
/// Panics if `block_size` is zero or if `d_in` is not a multiple of
/// `block_size`.
pub fn new(d_in: usize, n_state: usize, block_size: usize, seed: u64) -> Self {
    // Checked first so the failure carries a message instead of surfacing as
    // a remainder-by-zero panic from the divisibility test below.
    assert!(block_size > 0, "block_size must be greater than zero");
    assert!(
        d_in % block_size == 0,
        "d_in ({}) must be evenly divisible by block_size ({})",
        d_in,
        block_size
    );
    let n_blocks = d_in / block_size;
    let m = block_size;
    let mut rng = Xorshift64(seed);
    let scale = 0.1;
    let off_diag_scale = 0.02;

    let log_a = s4d_inv_real(m);
    let mut a_matrices = vec![0.0; n_blocks * m * m];
    // `m > 0` is guaranteed above, so the chunk size is never zero and the
    // buffer splits into exactly `n_blocks` chunks.
    for block in a_matrices.chunks_exact_mut(m * m) {
        for i in 0..m {
            for j in 0..m {
                block[i * m + j] = if i == j {
                    -math::exp(log_a[i])
                } else {
                    rng.next_normal() * off_diag_scale
                };
            }
        }
        normalize_row_l1(block, m);
    }

    let w_delta: Vec<f64> = (0..d_in).map(|_| rng.next_normal() * scale).collect();
    let b_delta = 0.0;
    let w_b: Vec<f64> = (0..n_state * d_in)
        .map(|_| rng.next_normal() * scale)
        .collect();
    let w_c: Vec<f64> = (0..n_state * d_in)
        .map(|_| rng.next_normal() * scale)
        .collect();
    let d_skip = vec![1.0; d_in];
    let h = vec![0.0; n_blocks * n_state * block_size];

    Self {
        a_matrices,
        w_b,
        w_c,
        w_delta,
        b_delta,
        d_skip,
        h,
        d_in,
        n_state,
        block_size,
        n_blocks,
    }
}
/// Returns `d_in`, the input length this layer was constructed with.
#[inline]
pub fn d_in(&self) -> usize {
self.d_in
}
/// Returns `n_state`, the number of state vectors kept per block.
#[inline]
pub fn n_state(&self) -> usize {
self.n_state
}
/// Returns `block_size`, the number of channels grouped into each block.
#[inline]
pub fn block_size(&self) -> usize {
self.block_size
}
/// Returns the number of blocks (`d_in / block_size`, fixed at construction).
#[inline]
pub fn n_blocks(&self) -> usize {
self.n_blocks
}
/// Resets block `b` to a freshly initialised condition and leaves every
/// other block untouched.
///
/// The block's hidden state is zeroed, its transition matrix is rebuilt
/// (diagonal `-exp(s4d_inv_real(block_size)[i])`, off-diagonals
/// `standard_normal(rng) * 0.02`, then `normalize_row_l1`), and the skip
/// weights of its channels are set back to `1.0`. The projections `w_b`,
/// `w_c` and `w_delta` are not modified.
///
/// # Panics
///
/// Panics if `b >= n_blocks`.
pub fn reinitialize_block(&mut self, b: usize, rng: &mut u64) {
    assert!(
        b < self.n_blocks,
        "block index {} out of range (n_blocks={})",
        b,
        self.n_blocks
    );
    let m = self.block_size;
    let off_diag_scale = 0.02;

    // Clear the slice of the hidden state owned by this block.
    let h_start = b * self.n_state * m;
    let h_end = h_start + self.n_state * m;
    self.h[h_start..h_end].fill(0.0);

    // Rebuild the transition matrix row by row; the diagonal needs no
    // random draw, so the generator advances once per off-diagonal entry.
    let log_a = s4d_inv_real(m);
    let a_base = b * m * m;
    let block = &mut self.a_matrices[a_base..a_base + m * m];
    for (i, &la_i) in log_a.iter().enumerate().take(m) {
        for j in 0..m {
            block[i * m + j] = if i == j {
                -math::exp(la_i)
            } else {
                standard_normal(rng) * off_diag_scale
            };
        }
    }
    normalize_row_l1(block, m);

    // Restore the skip weights of the channels belonging to this block.
    let ch_start = b * m;
    self.d_skip[ch_start..ch_start + m].fill(1.0);
}
/// Advances every block by one time step and returns the `d_in` outputs.
///
/// For the given `input` the step computes:
///
/// * `delta = min(softplus(dot(w_delta, input) + b_delta), 1.0)`;
/// * `b_t` and `c_t` (`n_state` values each) via `mat_vec` on `w_b` / `w_c`;
/// * for every block and state index `n`, with `x` the block's slice of
///   `input`: `h_n <- (I + delta * A) * h_n + delta * b_t[n] * x`;
/// * `output = sum_n c_t[n] * h_n + d_skip * input`, channel by channel.
///
/// `input` is expected to hold `d_in` values; the block update indexes it
/// directly, so a shorter slice can panic.
fn bd_forward(&mut self, input: &[f64]) -> Vec<f64> {
    let d_in = self.d_in;
    let n_state = self.n_state;
    let m = self.block_size;
    let n_blocks = self.n_blocks;

    let delta_raw = dot(&self.w_delta, input) + self.b_delta;
    let delta = softplus(delta_raw).min(1.0);

    let mut b_t = vec![0.0; n_state];
    mat_vec(&self.w_b, input, n_state, d_in, &mut b_t);
    let mut c_t = vec![0.0; n_state];
    mat_vec(&self.w_c, input, n_state, d_in, &mut c_t);

    let mut output = vec![0.0; d_in];
    // One scratch buffer shared by all (block, state) updates. Every slot is
    // overwritten before it is read, so reusing it is safe and avoids an
    // allocation per update.
    let mut h_new = vec![0.0; m];

    for blk in 0..n_blocks {
        let a_base = blk * m * m;
        let x_start = blk * m;
        let h_block_base = blk * n_state * m;

        for (n, &b_n) in b_t.iter().enumerate() {
            let h_offset = h_block_base + n * m;
            let db = delta * b_n;
            // The new state is assembled in the scratch buffer because each
            // row needs the complete previous state vector.
            for (i, slot) in h_new.iter_mut().enumerate() {
                let a_row = a_base + i * m;
                let mut sum = 0.0;
                for j in 0..m {
                    let a_ij = self.a_matrices[a_row + j];
                    let a_disc = if i == j {
                        1.0 + delta * a_ij
                    } else {
                        delta * a_ij
                    };
                    sum += a_disc * self.h[h_offset + j];
                }
                *slot = sum + db * input[x_start + i];
            }
            self.h[h_offset..h_offset + m].copy_from_slice(&h_new);
        }

        // Read-out: accumulate the contribution of every state vector.
        for (n, &c_n) in c_t.iter().enumerate() {
            let h_offset = h_block_base + n * m;
            for i in 0..m {
                output[x_start + i] += c_n * self.h[h_offset + i];
            }
        }
    }

    // Skip connection from the raw input.
    for (out_d, (&skip, &x_d)) in output.iter_mut().zip(self.d_skip.iter().zip(input.iter())) {
        *out_d += skip * x_d;
    }
    output
}
}
impl SSMLayer for SelectiveSSMBD {
    /// Runs one time step on `input` and returns `d_in` output values.
    ///
    /// The input length is checked against `d_in` in debug builds only.
    fn forward(&mut self, input: &[f64]) -> Vec<f64> {
        let got = input.len();
        let expected = self.d_in;
        debug_assert_eq!(
            got, expected,
            "input length {} must match d_in {}",
            got, expected
        );
        self.bd_forward(input)
    }

    /// Borrows the flat hidden-state buffer.
    fn state(&self) -> &[f64] {
        self.h.as_slice()
    }

    /// Output length, equal to `d_in`.
    fn output_dim(&self) -> usize {
        self.d_in
    }

    /// Zeroes the hidden state; learned parameters are left as they are.
    fn reset(&mut self) {
        self.h.fill(0.0);
    }
}
#[cfg(test)]
mod tests {
use super::*;
/// Constructor must report the requested sizes and allocate a matching state.
#[test]
fn bd_new_correct_dimensions() {
    let layer = SelectiveSSMBD::new(6, 8, 2, 42);

    let reported = (
        layer.d_in(),
        layer.n_state(),
        layer.block_size(),
        layer.n_blocks(),
    );
    assert_eq!(reported.0, 6);
    assert_eq!(reported.1, 8);
    assert_eq!(reported.2, 2);
    assert_eq!(reported.3, 3);

    let state_len = layer.state().len();
    assert_eq!(
        state_len,
        3 * 8 * 2,
        "state size = n_blocks * n_state * block_size"
    );
    assert_eq!(layer.output_dim(), 6);
}
/// A freshly built layer starts with an all-zero hidden state.
#[test]
fn bd_initial_state_zero() {
    let layer = SelectiveSSMBD::new(4, 8, 2, 42);
    let all_zero = layer.state().iter().all(|&h| math::abs(h) < 1e-15);
    assert!(all_zero, "initial state should be zero");
}
/// `forward` returns exactly `d_in` values.
#[test]
fn bd_forward_correct_output_dim() {
    let mut layer = SelectiveSSMBD::new(6, 8, 3, 42);
    let input = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let produced = layer.forward(&input).len();
    assert_eq!(produced, 6, "output dim should match d_in");
}
/// A single step on a small mixed-sign input yields finite outputs.
#[test]
fn bd_forward_finite_output() {
    let mut layer = SelectiveSSMBD::new(4, 8, 2, 42);
    let output = layer.forward(&[1.0, -1.0, 0.5, -0.5]);
    for i in 0..output.len() {
        let y = output[i];
        assert!(y.is_finite(), "output[{}] should be finite, got {}", i, y);
    }
}
/// Feeding a non-zero input must leave a non-zero hidden state behind.
#[test]
fn bd_forward_updates_state() {
    let mut layer = SelectiveSSMBD::new(4, 8, 2, 42);
    let _ = layer.forward(&[1.0, 2.0, 3.0, 4.0]);

    let energy = layer.state().iter().map(|&h| h * h).sum::<f64>();
    assert!(
        energy > 0.0,
        "state should be non-zero after processing non-zero input"
    );
}
/// `reset` returns the hidden state to all zeros after it has been driven.
#[test]
fn bd_reset_clears_state() {
    let mut layer = SelectiveSSMBD::new(4, 8, 2, 42);
    let _ = layer.forward(&[1.0, 2.0, 3.0, 4.0]);
    layer.reset();

    let cleared = layer.state().iter().all(|&h| math::abs(h) < 1e-15);
    assert!(cleared, "state should be zero after reset");
}
/// Two layers built from the same seed must agree on the same input.
#[test]
fn bd_deterministic_same_seed() {
    let input = [1.0, -1.0, 0.5, -0.5];
    let first = SelectiveSSMBD::new(4, 8, 2, 42).forward(&input);
    let second = SelectiveSSMBD::new(4, 8, 2, 42).forward(&input);

    let shared = first.len().min(second.len());
    for i in 0..shared {
        let (a, b) = (first[i], second[i]);
        assert!(
            math::abs(a - b) < 1e-15,
            "output[{}] should be identical for same seed: {} vs {}",
            i,
            a,
            b
        );
    }
}
/// Layers built from different seeds should not produce the same output.
#[test]
fn bd_different_seeds_differ() {
    let input = [1.0, 2.0, 3.0, 4.0];
    let from_42 = SelectiveSSMBD::new(4, 8, 2, 42).forward(&input);
    let from_99 = SelectiveSSMBD::new(4, 8, 2, 99).forward(&input);

    let squared_distance = from_42
        .iter()
        .zip(from_99.iter())
        .map(|(a, b)| {
            let d = a - b;
            d * d
        })
        .sum::<f64>();
    assert!(
        squared_distance > 1e-20,
        "different seeds should generally produce different outputs"
    );
}
/// With a zero state and a zero input, every output channel is zero.
#[test]
fn bd_zero_input_zero_state_zero_output() {
    let mut layer = SelectiveSSMBD::new(4, 8, 2, 42);
    let silence = [0.0; 4];
    let output = layer.forward(&silence);

    for i in 0..output.len() {
        let y = output[i];
        assert!(
            math::abs(y) < 1e-15,
            "zero input with zero state should give zero output[{}], got {}",
            i,
            y
        );
    }
}
/// Checks that blocks wider than one channel couple the channels inside them.
///
/// Comparing the outputs of a `block_size = 1` and a `block_size = 2` model
/// is kept as a smoke test, but it cannot prove mixing on its own: the two
/// constructors draw a different number of random values before the
/// projections are sampled, so their outputs differ regardless. The second
/// half therefore drives only channel 0 and looks at channel 1's state.
#[test]
fn bd_cross_channel_mixing() {
    let d_in = 4;
    let n_state = 4;
    let seed = 42;
    let mut ssm_blk1 = SelectiveSSMBD::new(d_in, n_state, 1, seed);
    let mut ssm_blk2 = SelectiveSSMBD::new(d_in, n_state, 2, seed);
    let input = vec![1.0, 2.0, 3.0, 4.0];
    for _ in 0..5 {
        let _ = ssm_blk1.forward(&input);
        let _ = ssm_blk2.forward(&input);
    }
    let out1 = ssm_blk1.forward(&input);
    let out2 = ssm_blk2.forward(&input);
    for &y in &out1 {
        assert!(y.is_finite(), "block_size=1 output should be finite");
    }
    for &y in &out2 {
        assert!(y.is_finite(), "block_size=2 output should be finite");
    }
    let diff: f64 = out1
        .iter()
        .zip(out2.iter())
        .map(|(a, b)| (a - b) * (a - b))
        .sum();
    assert!(
        diff > 1e-20,
        "block_size=1 vs block_size=2 should produce different outputs due to cross-channel mixing: diff={}",
        diff
    );

    // Direct mixing check: excite channel 0 only. Two steps are needed
    // because the first step only loads channel 0's state; the off-diagonal
    // term carries it into channel 1 on the second step.
    let mut isolated = SelectiveSSMBD::new(d_in, n_state, 1, seed);
    let mut coupled = SelectiveSSMBD::new(d_in, n_state, 2, seed);
    let impulse = [1.0, 0.0, 0.0, 0.0];
    for _ in 0..2 {
        let _ = isolated.forward(&impulse);
        let _ = coupled.forward(&impulse);
    }

    // block_size = 1: channel 1 is block 1, i.e. state[n_state..2 * n_state].
    let isolated_ch1: f64 = isolated.state()[n_state..2 * n_state]
        .iter()
        .map(|h| h * h)
        .sum();
    assert!(
        isolated_ch1 < 1e-30,
        "block_size=1 must not leak channel 0 into channel 1, got energy {}",
        isolated_ch1
    );

    // block_size = 2: channel 1 is the odd slot of every state vector in
    // block 0, i.e. indices 1, 3, 5, ... below n_state * 2.
    let coupled_ch1: f64 = coupled.state()[..n_state * 2]
        .iter()
        .skip(1)
        .step_by(2)
        .map(|h| h * h)
        .sum();
    assert!(
        coupled_ch1 > 0.0,
        "block_size=2 should carry channel 0 into channel 1 through the off-diagonal terms"
    );
}
/// Driving the layer with a constant input for 1000 steps must keep both
/// outputs and state finite, with a bounded state norm.
#[test]
fn bd_state_bounded_under_constant_input() {
    let mut layer = SelectiveSSMBD::new(4, 8, 2, 42);
    let input = [1.0, -0.5, 0.3, -0.8];

    for step in 0..1000 {
        let output = layer.forward(&input);
        for i in 0..output.len() {
            let y = output[i];
            assert!(
                y.is_finite(),
                "output[{}] is not finite at step {}: {}",
                i,
                step,
                y
            );
        }
    }

    let final_state = layer.state();
    for i in 0..final_state.len() {
        let h = final_state[i];
        assert!(
            h.is_finite(),
            "state[{}] is not finite after 1000 steps: {}",
            i,
            h
        );
    }

    let state_norm = final_state.iter().map(|&h| h * h).sum::<f64>();
    assert!(
        state_norm < 1e12,
        "state norm should be bounded after 1000 constant-input steps, got {}",
        state_norm
    );
}
/// Reinitialising block 1 must reset that block (state, transition matrix,
/// skip weights) and leave blocks 0 and 2 exactly as they were.
#[test]
fn reinitialize_block_preserves_others() {
    let mut ssm = SelectiveSSMBD::new(6, 4, 2, 42);
    for step in 0..10 {
        let s = step as f64;
        let x = [s * 0.1, s * -0.2, s * 0.3, s * -0.1, s * 0.2, s * -0.3];
        let _ = ssm.forward(&x);
    }

    // The constructor sets every skip weight to 1.0, so move some away from
    // that value first; otherwise the skip-weight checks below would pass
    // even if `reinitialize_block` never touched `d_skip`.
    ssm.d_skip[0] = 0.5; // block 0
    ssm.d_skip[2] = 0.25; // block 1
    ssm.d_skip[3] = -0.75; // block 1
    ssm.d_skip[5] = 2.0; // block 2

    let state_before: Vec<f64> = ssm.state().to_vec();
    let a_before: Vec<f64> = ssm.a_matrices.clone();
    let state_len = ssm.n_state() * ssm.block_size();
    let a_len = ssm.block_size() * ssm.block_size();

    let mut rng = 0xBEEF_u64;
    ssm.reinitialize_block(1, &mut rng);

    // Hidden state: blocks 0 and 2 preserved, block 1 cleared.
    for i in 0..state_len {
        assert!(
            math::abs(ssm.h[i] - state_before[i]) < 1e-15,
            "block 0 state[{}] should be preserved after reinit of block 1",
            i
        );
    }
    for i in 2 * state_len..3 * state_len {
        assert!(
            math::abs(ssm.h[i] - state_before[i]) < 1e-15,
            "block 2 state[{}] should be preserved after reinit of block 1",
            i
        );
    }
    for i in state_len..2 * state_len {
        assert!(
            math::abs(ssm.h[i]) < 1e-15,
            "block 1 state[{}] should be zero after reinit, got {}",
            i,
            ssm.h[i]
        );
    }

    // Transition matrices: blocks 0 and 2 preserved, block 1 redrawn.
    for i in 0..a_len {
        assert!(
            math::abs(ssm.a_matrices[i] - a_before[i]) < 1e-15,
            "block 0 A[{}] should be preserved",
            i
        );
    }
    for i in 2 * a_len..3 * a_len {
        assert!(
            math::abs(ssm.a_matrices[i] - a_before[i]) < 1e-15,
            "block 2 A[{}] should be preserved",
            i
        );
    }
    let any_a_diff =
        (a_len..2 * a_len).any(|i| math::abs(ssm.a_matrices[i] - a_before[i]) > 1e-15);
    assert!(any_a_diff, "block 1 A matrix should differ after reinit");

    // Skip weights: block 1 restored to 1.0, other blocks left as perturbed.
    assert!(
        math::abs(ssm.d_skip[2] - 1.0) < 1e-15,
        "d_skip[2] should be 1.0 after block 1 reinit"
    );
    assert!(
        math::abs(ssm.d_skip[3] - 1.0) < 1e-15,
        "d_skip[3] should be 1.0 after block 1 reinit"
    );
    assert!(
        math::abs(ssm.d_skip[0] - 0.5) < 1e-15,
        "d_skip[0] (block 0) should be preserved after block 1 reinit"
    );
    assert!(
        math::abs(ssm.d_skip[5] - 2.0) < 1e-15,
        "d_skip[5] (block 2) should be preserved after block 1 reinit"
    );
}
/// Models with block sizes 2 and 4 (same seed, same input) must both stay
/// finite and must not produce the same output.
#[test]
fn bd_block_sizes_produce_different_outputs() {
    let d_in = 8;
    let n_state = 4;
    let seed = 42;
    let input = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

    let mut narrow = SelectiveSSMBD::new(d_in, n_state, 2, seed);
    let mut wide = SelectiveSSMBD::new(d_in, n_state, 4, seed);

    // Warm both models up, stepping them in lockstep.
    for _ in 0..5 {
        let _ = narrow.forward(&input);
        let _ = wide.forward(&input);
    }
    let out_narrow = narrow.forward(&input);
    let out_wide = wide.forward(&input);

    assert_eq!(out_narrow.len(), d_in);
    assert_eq!(out_wide.len(), d_in);
    assert!(
        out_narrow.iter().all(|y| y.is_finite()),
        "block_size=2 output should be finite"
    );
    assert!(
        out_wide.iter().all(|y| y.is_finite()),
        "block_size=4 output should be finite"
    );

    let diff = out_narrow
        .iter()
        .zip(out_wide.iter())
        .map(|(a, b)| {
            let d = a - b;
            d * d
        })
        .sum::<f64>();
    assert!(
        diff > 1e-20,
        "block_size=2 vs block_size=4 should produce different outputs: diff={}",
        diff
    );
}
}