/// Discretizes a diagonal continuous-time state-space model with step size `delta`.
///
/// For every state index `i` in `0..a_diag.len()`:
///
/// * `a_bar[i] = exp(delta * a_diag[i])`
/// * `b_bar[i] = delta * b[i]`
///
/// Only the first `a_diag.len()` entries of `a_bar` and `b_bar` are written;
/// trailing entries of longer output buffers are left untouched.
///
/// # Panics
///
/// Panics if `a_diag` and `b` differ in length, if `delta` is not strictly
/// positive (this also rejects NaN), or if either output buffer holds fewer
/// than `a_diag.len()` elements. All checks run before anything is written,
/// so a failed call never leaves the outputs half-updated.
pub fn ssm_discretize(
    a_diag: &[f32],
    b: &[f32],
    delta: f32,
    a_bar: &mut [f32],
    b_bar: &mut [f32],
) {
    assert_eq!(a_diag.len(), b.len(), "A and B dimension mismatch");
    assert!(delta > 0.0, "Delta must be positive");
    let n = a_diag.len();
    assert!(
        a_bar.len() >= n,
        "a_bar is too short: need {n} elements, got {}",
        a_bar.len()
    );
    assert!(
        b_bar.len() >= n,
        "b_bar is too short: need {n} elements, got {}",
        b_bar.len()
    );

    // `zip` stops at the shortest side, i.e. after exactly `n` states.
    let inputs = a_diag.iter().zip(b);
    let outputs = a_bar.iter_mut().zip(b_bar.iter_mut());
    for ((&a, &b_in), (a_out, b_out)) in inputs.zip(outputs) {
        *a_out = (delta * a).exp();
        *b_out = delta * b_in;
    }
}
/// Runs a sequential scan of a diagonal state-space model over the scalar input sequence `x`.
///
/// The hidden state starts at zero. At each step `t`, every state component is
/// updated as `h[i] = a_bar[i] * h[i] + b_bar[i] * x[t]`, and then
/// `output[t]` is set to the dot product of `c` with the updated state.
///
/// # Panics
///
/// Panics if `a_bar`, `b_bar` and `c` do not all have the same length, or if
/// `output` and `x` differ in length.
pub fn ssm_scan(x: &[f32], a_bar: &[f32], b_bar: &[f32], c: &[f32], output: &mut [f32]) {
    let seq_len = x.len();
    let state_dim = a_bar.len();
    assert_eq!(state_dim, b_bar.len());
    assert_eq!(state_dim, c.len());
    assert_eq!(seq_len, output.len());

    let mut state = vec![0.0f32; state_dim];
    for (y_t, &x_t) in output.iter_mut().zip(x) {
        // Advance every state component by one step.
        for ((h_i, &a_i), &b_i) in state.iter_mut().zip(a_bar).zip(b_bar) {
            *h_i = a_i * *h_i + b_i * x_t;
        }
        // Read out: accumulate c[i] * h[i] left to right, starting from 0.0.
        *y_t = c
            .iter()
            .zip(state.iter())
            .fold(0.0f32, |acc, (&c_i, &h_i)| acc + c_i * h_i);
    }
}
/// Computes the input-dependent step size and the `B` / `C` projections for one input vector.
///
/// * `*delta_out` receives `softplus(dot(w_delta, x) + b_delta)`.
/// * `b_out[i]` receives the dot product of row `i` of `w_b` with `x`.
/// * `c_out[i]` receives the dot product of row `i` of `w_c` with `x`.
///
/// `w_b` and `w_c` are indexed as row-major `state_dim x x.len()` matrices,
/// i.e. element `(i, j)` lives at `i * x.len() + j`.
///
/// # Panics
///
/// Panics if `w_delta.len() != x.len()`, if `w_b` or `w_c` does not contain
/// exactly `x.len() * state_dim` elements, or if `b_out` / `c_out` is not
/// `state_dim` long.
pub fn selective_gate(
    x: &[f32],
    w_delta: &[f32],
    b_delta: f32,
    w_b: &[f32],
    w_c: &[f32],
    state_dim: usize,
    delta_out: &mut f32,
    b_out: &mut [f32],
    c_out: &mut [f32],
) {
    let input_dim = x.len();
    assert_eq!(w_delta.len(), input_dim);
    assert_eq!(w_b.len(), input_dim * state_dim);
    assert_eq!(w_c.len(), input_dim * state_dim);
    assert_eq!(b_out.len(), state_dim);
    assert_eq!(c_out.len(), state_dim);

    // Step size: affine map of the input squashed through softplus.
    let pre_activation = w_delta.iter().zip(x).map(|(w, xi)| w * xi).sum::<f32>() + b_delta;
    *delta_out = softplus(pre_activation);

    // Row-major matrix-vector product shared by the B and C projections.
    // Rows are sliced by index (not chunked) so an empty `x` is still valid.
    let project = |weights: &[f32], out: &mut [f32]| {
        for (row, slot) in out.iter_mut().enumerate() {
            let start = row * input_dim;
            *slot = weights[start..start + input_dim]
                .iter()
                .zip(x)
                .fold(0.0f32, |acc, (w, xj)| acc + w * xj);
        }
    };
    project(w_b, b_out);
    project(w_c, c_out);
}
/// Softplus activation, `ln(1 + e^x)`.
///
/// Inputs above `20.0` are returned unchanged: there `e^x` dominates so
/// completely that the result equals `x` at `f32` precision, and skipping the
/// exponential avoids overflow for large `x`.
///
/// All other inputs go through `exp` followed by `ln_1p`. Unlike computing
/// `(1.0 + e^x).ln()` directly, this does not lose `e^x` when it is added to
/// `1.0`, so negative inputs still yield a strictly positive result (until
/// `e^x` itself underflows). Callers such as step-size gates rely on that
/// positivity.
fn softplus(x: f32) -> f32 {
    if x > 20.0 {
        x
    } else {
        x.exp().ln_1p()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Both discretized vectors match their closed-form values for every state.
    #[test]
    fn test_ssm_discretize_basic() {
        let a = [-0.5, -1.0];
        let b = [1.0, 0.5];
        let delta = 0.1;
        let mut a_bar = [0.0; 2];
        let mut b_bar = [0.0; 2];
        ssm_discretize(&a, &b, delta, &mut a_bar, &mut b_bar);
        assert!((a_bar[0] - (-0.05f32).exp()).abs() < 1e-5);
        assert!((a_bar[1] - (-0.1f32).exp()).abs() < 1e-5);
        assert!((b_bar[0] - 0.1).abs() < 1e-5);
        assert!((b_bar[1] - 0.05).abs() < 1e-5);
    }

    /// Changing only the last input must leave earlier outputs untouched
    /// while still changing the last output.
    #[test]
    fn test_ssm_scan_causality() {
        let a_bar = [0.9];
        let b_bar = [0.1];
        let c = [1.0];
        let x = [1.0, 2.0, 3.0, 4.0];
        let mut y = [0.0; 4];
        ssm_scan(&x, &a_bar, &b_bar, &c, &mut y);
        assert!(y[0] > 0.0);

        let x2 = [1.0, 2.0, 3.0, 99.0];
        let mut y2 = [0.0; 4];
        ssm_scan(&x2, &a_bar, &b_bar, &c, &mut y2);
        assert!((y[0] - y2[0]).abs() < 1e-10);
        assert!((y[1] - y2[1]).abs() < 1e-10);
        assert!((y[2] - y2[2]).abs() < 1e-10);
        assert!(
            (y[3] - y2[3]).abs() > 1.0,
            "last output should react to its own input"
        );
    }

    /// The gate yields a positive step size and the expected row-wise projections.
    #[test]
    fn test_selective_gate_positivity() {
        let x = [1.0, -1.0, 0.5];
        let w_delta = [0.3, 0.2, 0.1];
        let w_b = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6];
        let w_c = [0.6, 0.5, 0.4, 0.3, 0.2, 0.1];
        let mut delta = 0.0;
        let mut b_out = [0.0; 2];
        let mut c_out = [0.0; 2];
        selective_gate(
            &x, &w_delta, 0.0, &w_b, &w_c, 2, &mut delta, &mut b_out, &mut c_out,
        );
        assert!(delta > 0.0, "Delta must be positive, got {delta}");
        // Each output is the dot product of one weight row with x.
        assert!((b_out[0] - 0.05).abs() < 1e-5);
        assert!((b_out[1] - 0.2).abs() < 1e-5);
        assert!((c_out[0] - 0.3).abs() < 1e-5);
        assert!((c_out[1] - 0.15).abs() < 1e-5);
    }

    /// Softplus is non-negative, equals ln 2 at zero, and is ~identity for large inputs.
    #[test]
    fn test_softplus_properties() {
        assert!(softplus(0.0) > 0.0);
        assert!(softplus(-100.0) >= 0.0);
        assert!((softplus(0.0) - 0.6931).abs() < 0.01);
        assert!((softplus(100.0) - 100.0).abs() < 0.01);
    }
}