use alloc::vec;
use alloc::vec::Vec;
use core::mem;
use crate::math;
use crate::rng::standard_normal;
/// Symmetric bound applied to the forget/input gate pre-activations before
/// they enter the exponential gating path, so `exp` cannot overflow.
const PRE_GATE_CLAMP: f64 = 20.0;
/// Base floor for the normalizer denominator; scaled by `exp(-m)` at use
/// sites so the floor tracks the log-space stabilizer.
const DENOM_EPS: f64 = 1e-6;
/// Stabilized LSTM (sLSTM) cell with exponential forget/input gating, a
/// normalizer state `n`, a log-space stabilizer `m`, and block-diagonal
/// (multi-head) recurrent weights.
///
/// Weights are allocated lazily on the first `forward` call, once the input
/// width is known.
pub struct SLSTMCell {
    /// Input-projection weights per gate, row-major `d_hidden x d_input`.
    w_input_f: Vec<f64>,
    w_input_i: Vec<f64>,
    w_input_o: Vec<f64>,
    w_input_z: Vec<f64>,
    /// Recurrent weights per gate: `n_heads` square blocks of
    /// `d_h_per_head x d_h_per_head`, stored contiguously row-major.
    r_f: Vec<f64>,
    r_i: Vec<f64>,
    r_o: Vec<f64>,
    r_z: Vec<f64>,
    /// Per-gate biases, length `d_hidden`.
    b_f: Vec<f64>,
    b_i: Vec<f64>,
    b_o: Vec<f64>,
    b_z: Vec<f64>,
    /// Hidden state (length `d_hidden`).
    h: Vec<f64>,
    /// Cell state.
    c: Vec<f64>,
    /// Normalizer state; starts at 1.0 (see `reset`).
    n: Vec<f64>,
    /// Log-space stabilizer state; starts at 0.0.
    m: Vec<f64>,
    /// Reusable workspace for `forward` (8 lanes of `d_hidden` each).
    scratch: Vec<f64>,
    /// Input width, fixed by the first `forward` call.
    d_input: usize,
    d_hidden: usize,
    n_heads: usize,
    /// `d_hidden / n_heads`.
    d_h_per_head: usize,
    /// Initial forget-gate bias per unit; reused by `reinitialize_unit`.
    forget_bias_init: Vec<f64>,
    /// True once weights have been allocated by `ensure_initialized`.
    initialized: bool,
    /// RNG state consumed during lazy weight initialization.
    rng_state: u64,
}
impl SLSTMCell {
pub fn new(d_hidden: usize, seed: u64) -> Self {
let forget_bias_init = vec![1.0; d_hidden];
Self {
w_input_f: Vec::new(),
w_input_i: Vec::new(),
w_input_o: Vec::new(),
w_input_z: Vec::new(),
r_f: Vec::new(),
r_i: Vec::new(),
r_o: Vec::new(),
r_z: Vec::new(),
b_f: Vec::new(),
b_i: Vec::new(),
b_o: Vec::new(),
b_z: Vec::new(),
h: vec![0.0; d_hidden],
c: vec![0.0; d_hidden],
n: vec![1.0; d_hidden],
m: vec![0.0; d_hidden],
scratch: Vec::new(),
d_input: 0,
d_hidden,
n_heads: 1,
d_h_per_head: d_hidden,
forget_bias_init,
initialized: false,
rng_state: seed,
}
}
/// Creates a cell with an explicit head count and a per-unit forget-gate
/// bias initialization vector.
///
/// # Panics
/// Panics when `n_heads` is zero, when `n_heads` does not divide
/// `d_hidden`, or when `forget_bias_init.len() != d_hidden`.
pub fn with_config(
    d_hidden: usize,
    n_heads: usize,
    forget_bias_init: Vec<f64>,
    seed: u64,
) -> Self {
    assert!(n_heads > 0, "n_heads must be > 0");
    assert!(
        d_hidden % n_heads == 0,
        "n_heads ({}) must divide d_hidden ({})",
        n_heads,
        d_hidden
    );
    assert_eq!(
        forget_bias_init.len(),
        d_hidden,
        "forget_bias_init length ({}) must equal d_hidden ({})",
        forget_bias_init.len(),
        d_hidden
    );
    let d_h_per_head = d_hidden / n_heads;
    Self {
        // Weight matrices are allocated lazily in `ensure_initialized`.
        w_input_f: Vec::new(),
        w_input_i: Vec::new(),
        w_input_o: Vec::new(),
        w_input_z: Vec::new(),
        r_f: Vec::new(),
        r_i: Vec::new(),
        r_o: Vec::new(),
        r_z: Vec::new(),
        b_f: Vec::new(),
        b_i: Vec::new(),
        b_o: Vec::new(),
        b_z: Vec::new(),
        // Canonical reset state: zero h/c/m, normalizer at 1.0.
        h: vec![0.0; d_hidden],
        c: vec![0.0; d_hidden],
        n: vec![1.0; d_hidden],
        m: vec![0.0; d_hidden],
        scratch: Vec::new(),
        d_input: 0,
        d_hidden,
        n_heads,
        d_h_per_head,
        forget_bias_init,
        initialized: false,
        rng_state: seed,
    }
}
/// Allocates and randomly initializes all weights on the first forward
/// pass, once the input width is known. No-op on subsequent calls.
fn ensure_initialized(&mut self, d_input: usize) {
    if self.initialized {
        return;
    }
    self.d_input = d_input;
    // Scale derived from the combined fan-in of input and hidden widths.
    let scale = math::sqrt(2.0 / (d_input + self.d_hidden) as f64);
    // Draw from a local copy of the RNG state so the sampling closure can
    // hold the mutable borrow while we assign into `self`'s weight fields.
    let mut rng = self.rng_state;
    let mut sample = |len: usize| -> Vec<f64> {
        (0..len)
            .map(|_| standard_normal(&mut rng) * scale)
            .collect()
    };
    // Draw order matches the original field order: input weights first.
    let n_in = self.d_hidden * d_input;
    self.w_input_f = sample(n_in);
    self.w_input_i = sample(n_in);
    self.w_input_o = sample(n_in);
    self.w_input_z = sample(n_in);
    // Recurrent matrices are block-diagonal: d_hidden rows of d_h_per_head.
    let n_rec = self.d_hidden * self.d_h_per_head;
    self.r_f = sample(n_rec);
    self.r_i = sample(n_rec);
    self.r_o = sample(n_rec);
    self.r_z = sample(n_rec);
    self.rng_state = rng;
    self.b_f = self.forget_bias_init.clone();
    self.b_i = vec![0.0; self.d_hidden];
    self.b_o = vec![0.0; self.d_hidden];
    self.b_z = vec![0.0; self.d_hidden];
    // Eight d_hidden-sized lanes consumed by `forward`'s split_at_mut chain.
    self.scratch = vec![0.0; 8 * self.d_hidden];
    self.initialized = true;
}
/// Advances the cell by one time step on input `x` and returns a view of
/// the new hidden state.
///
/// Uses sLSTM exponential gating: forget/input gates are handled in
/// log-space and shifted by the running stabilizer `m` so their
/// exponentials stay bounded, while the normalizer `n` accumulates the
/// effective gate mass used to normalize the cell state.
///
/// The first call infers and fixes `d_input` from `x.len()`; every later
/// call must pass a slice of the same length (asserted below — previously
/// a mismatched length silently fed wrong dimensions to the mat-vec).
pub fn forward(&mut self, x: &[f64]) -> &[f64] {
    self.ensure_initialized(x.len());
    assert_eq!(
        x.len(),
        self.d_input,
        "input length ({}) must equal d_input ({})",
        x.len(),
        self.d_input
    );
    let d_h = self.d_hidden;
    // Move the scratch buffer out of `self` so it can be split into
    // disjoint mutable lanes while `self`'s weight fields are borrowed.
    let mut scratch = mem::take(&mut self.scratch);
    let (pre_f, rest) = scratch.split_at_mut(d_h);
    let (pre_i, rest) = rest.split_at_mut(d_h);
    let (pre_o, rest) = rest.split_at_mut(d_h);
    let (pre_z, rest) = rest.split_at_mut(d_h);
    let (o_gate, rest) = rest.split_at_mut(d_h);
    let (z_gate, rest) = rest.split_at_mut(d_h);
    let (f_prime, i_prime) = rest.split_at_mut(d_h);
    // Input projections: pre_* = W_* x.
    crate::simd::simd_mat_vec(&self.w_input_f, x, d_h, self.d_input, pre_f);
    crate::simd::simd_mat_vec(&self.w_input_i, x, d_h, self.d_input, pre_i);
    crate::simd::simd_mat_vec(&self.w_input_o, x, d_h, self.d_input, pre_o);
    crate::simd::simd_mat_vec(&self.w_input_z, x, d_h, self.d_input, pre_z);
    // Recurrent contributions: pre_* += R_* h, block-diagonal per head.
    compute_block_diagonal_recurrent(&self.r_f, &self.h, d_h, self.d_h_per_head, pre_f);
    compute_block_diagonal_recurrent(&self.r_i, &self.h, d_h, self.d_h_per_head, pre_i);
    compute_block_diagonal_recurrent(&self.r_o, &self.h, d_h, self.d_h_per_head, pre_o);
    compute_block_diagonal_recurrent(&self.r_z, &self.h, d_h, self.d_h_per_head, pre_z);
    for j in 0..d_h {
        pre_f[j] += self.b_f[j];
        pre_i[j] += self.b_i[j];
        pre_o[j] += self.b_o[j];
        pre_z[j] += self.b_z[j];
        // Only the exponential-gate pre-activations need clamping; the
        // sigmoid/tanh gates below saturate on their own.
        pre_f[j] = clamp(pre_f[j], -PRE_GATE_CLAMP, PRE_GATE_CLAMP);
        pre_i[j] = clamp(pre_i[j], -PRE_GATE_CLAMP, PRE_GATE_CLAMP);
    }
    crate::simd::simd_sigmoid(pre_o, o_gate);
    crate::simd::simd_tanh(pre_z, z_gate);
    // Stabilization: m_new = max(log f + m_prev, log i); shift both gate
    // logs by -m_new so exp() below is always <= 1.
    for j in 0..d_h {
        let log_f = pre_f[j] + self.m[j];
        let m_new = if log_f > pre_i[j] { log_f } else { pre_i[j] };
        pre_f[j] = log_f - m_new;
        pre_i[j] -= m_new;
        self.m[j] = m_new;
    }
    crate::simd::simd_exp(pre_f, f_prime);
    crate::simd::simd_exp(pre_i, i_prime);
    for j in 0..d_h {
        self.c[j] = f_prime[j] * self.c[j] + i_prime[j] * z_gate[j];
        self.n[j] = f_prime[j] * self.n[j] + i_prime[j];
        let abs_n = math::abs(self.n[j]);
        // Scale-equivariant denominator floor: in stabilized units the
        // floor must shrink with exp(-m) or low-gate regimes get crushed.
        let floor = DENOM_EPS * math::exp(-self.m[j]);
        let denom = if abs_n > floor { abs_n } else { floor };
        self.h[j] = o_gate[j] * (self.c[j] / denom);
    }
    // Hand the scratch buffer back for reuse on the next step.
    self.scratch = scratch;
    &self.h
}
/// Computes what the hidden state would become after one step on `x`,
/// without mutating any cell state (`c`, `n`, `m` are worked on copies and
/// `h` is written to a fresh vector).
///
/// Mirrors the math in `forward`; keep the two in sync.
///
/// # Panics
/// Panics if called before `forward` has initialized the weights, or if
/// `x.len()` differs from the established input width (previously a
/// mismatched length silently fed wrong dimensions to the mat-vec).
pub fn forward_predict(&self, x: &[f64]) -> Vec<f64> {
    assert!(
        self.initialized,
        "forward_predict called before initialization; call forward() first"
    );
    assert_eq!(
        x.len(),
        self.d_input,
        "input length ({}) must equal d_input ({})",
        x.len(),
        self.d_input
    );
    let d_h = self.d_hidden;
    // Copies of the recurrent state so `self` stays untouched.
    let mut c_tmp = self.c.clone();
    let mut n_tmp = self.n.clone();
    let mut m_tmp = self.m.clone();
    let mut pre_f = vec![0.0; d_h];
    let mut pre_i = vec![0.0; d_h];
    let mut pre_o = vec![0.0; d_h];
    let mut pre_z = vec![0.0; d_h];
    let mut o_gate = vec![0.0; d_h];
    let mut z_gate = vec![0.0; d_h];
    let mut f_prime = vec![0.0; d_h];
    let mut i_prime = vec![0.0; d_h];
    // Input projections: pre_* = W_* x.
    crate::simd::simd_mat_vec(&self.w_input_f, x, d_h, self.d_input, &mut pre_f);
    crate::simd::simd_mat_vec(&self.w_input_i, x, d_h, self.d_input, &mut pre_i);
    crate::simd::simd_mat_vec(&self.w_input_o, x, d_h, self.d_input, &mut pre_o);
    crate::simd::simd_mat_vec(&self.w_input_z, x, d_h, self.d_input, &mut pre_z);
    // Recurrent contributions read the *current* hidden state.
    compute_block_diagonal_recurrent(&self.r_f, &self.h, d_h, self.d_h_per_head, &mut pre_f);
    compute_block_diagonal_recurrent(&self.r_i, &self.h, d_h, self.d_h_per_head, &mut pre_i);
    compute_block_diagonal_recurrent(&self.r_o, &self.h, d_h, self.d_h_per_head, &mut pre_o);
    compute_block_diagonal_recurrent(&self.r_z, &self.h, d_h, self.d_h_per_head, &mut pre_z);
    for j in 0..d_h {
        pre_f[j] += self.b_f[j];
        pre_i[j] += self.b_i[j];
        pre_o[j] += self.b_o[j];
        pre_z[j] += self.b_z[j];
        pre_f[j] = clamp(pre_f[j], -PRE_GATE_CLAMP, PRE_GATE_CLAMP);
        pre_i[j] = clamp(pre_i[j], -PRE_GATE_CLAMP, PRE_GATE_CLAMP);
    }
    crate::simd::simd_sigmoid(&pre_o, &mut o_gate);
    crate::simd::simd_tanh(&pre_z, &mut z_gate);
    // Log-space stabilization on the copied stabilizer state.
    for j in 0..d_h {
        let log_f = pre_f[j] + m_tmp[j];
        let m_new = if log_f > pre_i[j] { log_f } else { pre_i[j] };
        pre_f[j] = log_f - m_new;
        pre_i[j] -= m_new;
        m_tmp[j] = m_new;
    }
    crate::simd::simd_exp(&pre_f, &mut f_prime);
    crate::simd::simd_exp(&pre_i, &mut i_prime);
    let mut h_out = vec![0.0; d_h];
    for j in 0..d_h {
        c_tmp[j] = f_prime[j] * c_tmp[j] + i_prime[j] * z_gate[j];
        n_tmp[j] = f_prime[j] * n_tmp[j] + i_prime[j];
        let abs_n = math::abs(n_tmp[j]);
        // Same scale-equivariant floor as `forward`.
        let floor = DENOM_EPS * math::exp(-m_tmp[j]);
        let denom = if abs_n > floor { abs_n } else { floor };
        h_out[j] = o_gate[j] * (c_tmp[j] / denom);
    }
    h_out
}
/// Restores the recurrent state to its post-construction values while
/// preserving all learned weights and biases.
pub fn reset(&mut self) {
    // Everything except the normalizer is zeroed; `n` restarts at 1.0.
    for buf in [&mut self.h, &mut self.c, &mut self.m, &mut self.scratch] {
        for v in buf.iter_mut() {
            *v = 0.0;
        }
    }
    self.n.fill(1.0);
}
/// Read-only view of the current hidden state (length `d_hidden`).
#[inline]
pub fn hidden_state(&self) -> &[f64] {
    &self.h
}
/// Number of hidden units.
#[inline]
pub fn d_hidden(&self) -> usize {
    self.d_hidden
}
/// Output width of `forward`; equal to `d_hidden` for this cell.
#[inline]
pub fn output_dim(&self) -> usize {
    self.d_hidden
}
/// Re-randomizes all weights feeding hidden unit `j` and resets that
/// unit's bias and state to their initial values; all other units are
/// left untouched.
///
/// Draws from the caller-supplied RNG state (not the cell's own) so the
/// caller controls reproducibility of the reinitialization.
///
/// # Panics
/// Panics if the cell has not been initialized yet or `j >= d_hidden`.
pub fn reinitialize_unit(&mut self, j: usize, rng: &mut u64) {
    assert!(self.initialized, "cell must be initialized before reinit");
    assert!(
        j < self.d_hidden,
        "unit index {} out of range (d_hidden={})",
        j,
        self.d_hidden
    );
    // Same fan-in-based scale as `ensure_initialized`.
    let d_total = self.d_input + self.d_hidden;
    let scale = math::sqrt(2.0 / d_total as f64);
    // Row j of each row-major (d_hidden x d_input) input matrix.
    let input_row_start = j * self.d_input;
    for col in 0..self.d_input {
        self.w_input_f[input_row_start + col] = standard_normal(rng) * scale;
        self.w_input_i[input_row_start + col] = standard_normal(rng) * scale;
        self.w_input_o[input_row_start + col] = standard_normal(rng) * scale;
        self.w_input_z[input_row_start + col] = standard_normal(rng) * scale;
    }
    // Unit j lives in head k at local row l of that head's square block;
    // blocks are stored back-to-back, each d_h_per_head^2 long.
    let k = j / self.d_h_per_head;
    let l = j % self.d_h_per_head;
    let recurrent_row_start = k * self.d_h_per_head * self.d_h_per_head + l * self.d_h_per_head;
    for col in 0..self.d_h_per_head {
        self.r_f[recurrent_row_start + col] = standard_normal(rng) * scale;
        self.r_i[recurrent_row_start + col] = standard_normal(rng) * scale;
        self.r_o[recurrent_row_start + col] = standard_normal(rng) * scale;
        self.r_z[recurrent_row_start + col] = standard_normal(rng) * scale;
    }
    // Restore unit-local biases and state to post-construction values.
    self.b_f[j] = self.forget_bias_init[j];
    self.b_i[j] = 0.0;
    self.b_o[j] = 0.0;
    self.b_z[j] = 0.0;
    self.h[j] = 0.0;
    self.c[j] = 0.0;
    self.n[j] = 1.0;
    self.m[j] = 0.0;
}
/// Number of recurrent heads (block-diagonal blocks).
#[inline]
pub fn n_heads(&self) -> usize {
    self.n_heads
}
/// Builds an `n`-point linear ramp from `start` to `stop` (both endpoints
/// included), e.g. for per-unit forget-gate bias initialization.
///
/// # Panics
/// Panics when `n == 0`.
pub fn forget_bias_linspace(start: f64, stop: f64, n: usize) -> Vec<f64> {
    assert!(n > 0, "n must be > 0");
    let mut values = Vec::with_capacity(n);
    if n == 1 {
        // Degenerate ramp: a single point at `start`.
        values.push(start);
        return values;
    }
    let step = (stop - start) / (n - 1) as f64;
    for i in 0..n {
        values.push(start + step * i as f64);
    }
    values
}
}
/// Accumulates the block-diagonal recurrent product into `out`: for each
/// head `k`, `out_k += R_k * h_k`, where `R_k` is the k-th row-major
/// `d_h_per_head x d_h_per_head` block of `r` and `h_k`/`out_k` are the
/// matching `d_h_per_head`-wide slices of `h`/`out`.
fn compute_block_diagonal_recurrent(
    r: &[f64],
    h: &[f64],
    d_hidden: usize,
    d_h_per_head: usize,
    out: &mut [f64],
) {
    debug_assert_eq!(h.len(), d_hidden);
    debug_assert_eq!(out.len(), d_hidden);
    debug_assert_eq!(d_hidden % d_h_per_head, 0);
    debug_assert_eq!(r.len(), d_hidden * d_h_per_head);
    let block_len = d_h_per_head * d_h_per_head;
    // Walk heads in lockstep: hidden slice, output slice, weight block.
    let heads = h
        .chunks_exact(d_h_per_head)
        .zip(out.chunks_exact_mut(d_h_per_head))
        .zip(r.chunks_exact(block_len));
    for ((h_head, out_head), block) in heads {
        for (acc, row) in out_head.iter_mut().zip(block.chunks_exact(d_h_per_head)) {
            // Same left-to-right accumulation order as a plain index loop,
            // so results are bit-identical.
            let mut dot = 0.0;
            for (&w, &v) in row.iter().zip(h_head.iter()) {
                dot += w * v;
            }
            *acc += dot;
        }
    }
}
/// Clamps `x` into `[lo, hi]`. NaN falls through unchanged because both
/// comparisons evaluate false, matching the original behavior.
#[inline]
fn clamp(x: f64, lo: f64, hi: f64) -> f64 {
    if x < lo {
        return lo;
    }
    if x > hi {
        return hi;
    }
    x
}
#[cfg(test)]
mod tests {
    use super::*;

    // Constructor invariants: dimensions, pre-allocated state, lazy weights.
    #[test]
    fn slstm_cell_new() {
        let cell = SLSTMCell::new(16, 42);
        assert_eq!(cell.d_hidden(), 16, "d_hidden should match constructor arg");
        assert_eq!(cell.output_dim(), 16, "output_dim should equal d_hidden");
        assert!(
            !cell.initialized,
            "cell should not be initialized before first forward"
        );
        assert_eq!(
            cell.hidden_state().len(),
            16,
            "hidden state should be pre-allocated to d_hidden"
        );
    }

    // First forward call infers d_input and allocates the weight matrices.
    #[test]
    fn slstm_cell_forward_initializes() {
        let mut cell = SLSTMCell::new(8, 42);
        assert!(!cell.initialized, "should start uninitialized");
        let x = [0.1, -0.2, 0.3, 0.4];
        let h_len = cell.forward(&x).len();
        assert!(
            cell.initialized,
            "should be initialized after first forward"
        );
        assert_eq!(h_len, 8, "output length should be d_hidden");
        assert_eq!(
            cell.d_input, 4,
            "d_input should be inferred from input length"
        );
        assert_eq!(
            cell.w_input_f.len(),
            8 * 4,
            "w_input_f should have d_hidden * d_input elements"
        );
        assert_eq!(
            cell.r_f.len(),
            8 * 8,
            "r_f should have d_hidden * d_hidden elements"
        );
    }

    // Outputs stay finite for ordinary inputs.
    #[test]
    fn slstm_cell_forward_finite() {
        let mut cell = SLSTMCell::new(8, 123);
        let x = [1.0, -0.5, 0.3, 2.0, -1.0];
        let h = cell.forward(&x);
        for (i, &val) in h.iter().enumerate() {
            assert!(
                val.is_finite(),
                "h[{}] = {} should be finite after forward",
                i,
                val
            );
        }
    }

    // forward_predict must be a pure lookahead: no state mutation at all.
    #[test]
    fn slstm_cell_forward_predict_no_state_change() {
        let mut cell = SLSTMCell::new(4, 99);
        let x = [0.5, -0.3, 0.8];
        cell.forward(&x);
        let h_before = cell.h.clone();
        let c_before = cell.c.clone();
        let n_before = cell.n.clone();
        let m_before = cell.m.clone();
        let x2 = [0.1, 0.2, -0.4];
        let _h_predict = cell.forward_predict(&x2);
        assert_eq!(
            cell.h, h_before,
            "hidden state should not change after forward_predict"
        );
        assert_eq!(
            cell.c, c_before,
            "cell state should not change after forward_predict"
        );
        assert_eq!(
            cell.n, n_before,
            "normalizer state should not change after forward_predict"
        );
        assert_eq!(
            cell.m, m_before,
            "stabilizer state should not change after forward_predict"
        );
    }

    // reset() restores initial state but must preserve learned weights.
    #[test]
    fn slstm_cell_reset() {
        let mut cell = SLSTMCell::new(4, 77);
        let x = [1.0, -1.0];
        for _ in 0..5 {
            cell.forward(&x);
        }
        let w_f_before = cell.r_f.clone();
        let w_i_before = cell.r_i.clone();
        cell.reset();
        assert!(
            cell.h.iter().all(|&v| v == 0.0),
            "h should be all zeros after reset"
        );
        assert!(
            cell.c.iter().all(|&v| v == 0.0),
            "c should be all zeros after reset"
        );
        assert!(
            cell.n.iter().all(|&v| v == 1.0),
            "n should be all 1.0 after reset"
        );
        assert!(
            cell.m.iter().all(|&v| v == 0.0),
            "m should be all zeros after reset"
        );
        assert_eq!(
            cell.r_f, w_f_before,
            "r_f weights should be preserved after reset"
        );
        assert_eq!(
            cell.r_i, w_i_before,
            "r_i weights should be preserved after reset"
        );
    }

    // The log-space stabilizer must keep exponential gating finite even
    // under large-magnitude inputs sustained over many steps.
    #[test]
    fn slstm_cell_exponential_gating_range() {
        let mut cell = SLSTMCell::new(16, 55);
        let x_large: Vec<f64> = (0..10).map(|i| (i as f64 - 5.0) * 10.0).collect();
        for _ in 0..50 {
            let h = cell.forward(&x_large);
            for (i, &val) in h.iter().enumerate() {
                assert!(
                    val.is_finite(),
                    "h[{}] = {} should be finite even with large inputs",
                    i,
                    val
                );
                assert!(
                    !val.is_nan(),
                    "h[{}] should not be NaN even with large inputs",
                    i,
                );
            }
        }
    }

    // Feeding the same input repeatedly must still evolve the state.
    #[test]
    fn slstm_cell_sequence_evolves_state() {
        let mut cell = SLSTMCell::new(4, 42);
        let x = [0.5, -0.3, 0.8];
        let h1 = cell.forward(&x).to_vec();
        let h2 = cell.forward(&x).to_vec();
        let h3 = cell.forward(&x).to_vec();
        assert_ne!(
            h1, h2,
            "hidden state should evolve between step 1 and step 2"
        );
        assert_ne!(
            h2, h3,
            "hidden state should evolve between step 2 and step 3"
        );
    }

    // reinitialize_unit must only touch the targeted unit's state.
    #[test]
    fn reinitialize_unit_resets_target_only() {
        let mut cell = SLSTMCell::new(4, 42);
        let x = [0.5, -0.3, 0.8];
        for _ in 0..10 {
            cell.forward(&x);
        }
        let h0_before = cell.h[0];
        let h2_before = cell.h[2];
        let c2_before = cell.c[2];
        let mut rng = 999u64;
        cell.reinitialize_unit(1, &mut rng);
        assert!(
            math::abs(cell.h[1]) < 1e-15,
            "reinit unit h should be zero, got {}",
            cell.h[1]
        );
        assert!(
            math::abs(cell.c[1]) < 1e-15,
            "reinit unit c should be zero, got {}",
            cell.c[1]
        );
        assert!(
            (cell.n[1] - 1.0).abs() < 1e-15,
            "reinit unit n should be 1.0, got {}",
            cell.n[1]
        );
        assert!(
            (cell.h[0] - h0_before).abs() < 1e-15,
            "unit 0 h should be unchanged after reinit of unit 1"
        );
        assert!(
            (cell.h[2] - h2_before).abs() < 1e-15,
            "unit 2 h should be unchanged after reinit of unit 1"
        );
        assert!(
            (cell.c[2] - c2_before).abs() < 1e-15,
            "unit 2 c should be unchanged after reinit of unit 1"
        );
    }

    // reinitialize_unit must draw new weights and restore the forget bias.
    #[test]
    fn reinitialize_unit_produces_fresh_weights() {
        let mut cell = SLSTMCell::new(4, 42);
        cell.forward(&[0.1, 0.2, 0.3]);
        // Row 1 of r_f: single head, so the row stride equals d_h_per_head.
        let d_h = cell.d_h_per_head; let row_start = d_h;
        let w_f_before: Vec<f64> = cell.r_f[row_start..row_start + d_h].to_vec();
        let mut rng = 777u64;
        cell.reinitialize_unit(1, &mut rng);
        let w_f_after: Vec<f64> = cell.r_f[row_start..row_start + d_h].to_vec();
        let diff: f64 = w_f_before
            .iter()
            .zip(w_f_after.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff > 1e-10,
            "reinitialized weights should differ from original"
        );
        assert!(
            (cell.b_f[1] - 1.0).abs() < 1e-15,
            "forget bias should be 1.0 after reinit, got {}",
            cell.b_f[1]
        );
    }

    // forget_bias_linspace must be an inclusive, strictly increasing ramp
    // and survive construction + first forward unchanged.
    #[test]
    fn forget_bias_uses_linspace_3_to_6() {
        let d = 8usize;
        let bias = SLSTMCell::forget_bias_linspace(3.0, 6.0, d);
        assert_eq!(bias.len(), d, "linspace length must equal d_hidden");
        assert!(
            (bias[0] - 3.0).abs() < 1e-12,
            "first bias value must be 3.0, got {}",
            bias[0]
        );
        assert!(
            (bias[d - 1] - 6.0).abs() < 1e-12,
            "last bias value must be 6.0, got {}",
            bias[d - 1]
        );
        for i in 1..d {
            assert!(
                bias[i] > bias[i - 1],
                "linspace must be strictly increasing at index {}",
                i
            );
        }
        let step = (6.0 - 3.0) / (d - 1) as f64;
        for (i, &b) in bias.iter().enumerate() {
            let expected = 3.0 + step * i as f64;
            assert!(
                (b - expected).abs() < 1e-12,
                "bias[{}] expected {}, got {}",
                i,
                expected,
                b
            );
        }
        let mut cell = SLSTMCell::with_config(d, 1, bias.clone(), 42);
        cell.forward(&[0.1, 0.2]); for (j, &expected) in bias.iter().enumerate() {
            assert!(
                (cell.forget_bias_init[j] - expected).abs() < 1e-12,
                "forget_bias_init[{}] must equal linspace value {}, got {}",
                j,
                expected,
                cell.forget_bias_init[j]
            );
        }
    }

    // The exp(-m)-scaled denominator floor must preserve signal in the
    // low-gate regime where a constant floor of 1.0 would crush it.
    #[test]
    fn denominator_is_scale_equivariant_in_low_gate_regime() {
        let d = 4usize;
        let mut cell_a = SLSTMCell::new(d, 7);
        cell_a.forward(&[0.1, 0.2]);
        for j in 0..d {
            cell_a.m[j] = -10.0; cell_a.n[j] = 1e-9; cell_a.c[j] = 1.0; }
        let h_equivariant: Vec<f64> = (0..d)
            .map(|j| {
                let abs_n = math::abs(cell_a.n[j]);
                let floor = DENOM_EPS * math::exp(-cell_a.m[j]);
                let denom = if abs_n > floor { abs_n } else { floor };
                cell_a.c[j] / denom
            })
            .collect();
        let h_constant_floor: Vec<f64> = (0..d)
            .map(|j| {
                let abs_n = math::abs(cell_a.n[j]);
                let denom = if abs_n > 1.0 { abs_n } else { 1.0 };
                cell_a.c[j] / denom
            })
            .collect();
        for (j, (&he, &hc)) in h_equivariant
            .iter()
            .zip(h_constant_floor.iter())
            .enumerate()
        {
            assert!(
                he.abs() > hc.abs(),
                "scale-equivariant h[{}]={:.6} must exceed constant-floor h[{}]={:.6} in low-gate regime",
                j, he, j, hc
            );
        }
    }

    // A single-head with_config cell must be numerically identical to the
    // default constructor given the same seed.
    #[test]
    fn n_heads_1_matches_dense_path() {
        let d = 8usize;
        let seed = 77u64;
        let forget_bias = vec![1.0; d];
        let mut cell_dense = SLSTMCell::new(d, seed);
        let mut cell_config = SLSTMCell::with_config(d, 1, forget_bias, seed);
        let inputs: &[&[f64]] = &[&[0.1, -0.2, 0.3], &[0.5, 0.0, -0.1], &[-0.3, 0.8, 0.2]];
        for &x in inputs {
            let h_dense = cell_dense.forward(x).to_vec();
            let h_config = cell_config.forward(x).to_vec();
            for (j, (a, b)) in h_dense.iter().zip(h_config.iter()).enumerate() {
                assert!(
                    (a - b).abs() < 1e-14,
                    "n_heads=1 config path must match dense path at unit {j}: dense={a}, config={b}"
                );
            }
        }
    }

    // Multi-head configuration runs stably and reports its head count.
    #[test]
    fn slstm_multi_head_forward_finite_and_correct_n_heads() {
        let d = 8usize;
        let n_heads = 2usize;
        let bias = SLSTMCell::forget_bias_linspace(3.0, 6.0, d);
        let mut cell = SLSTMCell::with_config(d, n_heads, bias, 42);
        assert_eq!(
            cell.n_heads(),
            n_heads,
            "n_heads accessor must match constructor arg"
        );
        let x = [0.1f64, -0.2, 0.3, 0.4];
        for _ in 0..10 {
            let h = cell.forward(&x);
            assert_eq!(h.len(), d, "output length must equal d_hidden");
            for (j, &v) in h.iter().enumerate() {
                assert!(v.is_finite(), "multi-head h[{j}]={v} must be finite");
            }
        }
    }
}