use crate::error::{ModelError, ModelResult};
use scirs2_core::ndarray::{Array1, Array2};
use std::collections::HashMap;
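/// Logistic sigmoid, `1 / (1 + e^(-x))`.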
#[inline]
fn sigmoid(x: f32) -> f32 {
1.0 / (1.0 + (-x).exp())
}
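/// Returns a numerical-instability error if any element of `arr` is NaN or infinite.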
fn check_finite_1d(arr: &Array1<f32>, ctx: &str) -> ModelResult<()> {
for &v in arr.iter() {
if !v.is_finite() {
return Err(ModelError::numerical_instability(
ctx,
format!("non-finite value {v} detected"),
));
}
}
Ok(())
}
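/// 2-D counterpart of [`check_finite_1d`].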
fn check_finite_2d(arr: &Array2<f32>, ctx: &str) -> ModelResult<()> {
for &v in arr.iter() {
if !v.is_finite() {
return Err(ModelError::numerical_instability(
ctx,
format!("non-finite value {v} detected"),
));
}
}
Ok(())
}
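/// Minimal autograd tensor: raw data, an optional gradient buffer, and a
/// `requires_grad` flag controlling whether gradients should be tracked.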
#[derive(Debug, Clone)]
pub struct Tensor {
pub data: Array1<f32>,
pub grad: Option<Array1<f32>>,
pub requires_grad: bool,
}
impl Tensor {
pub fn new(data: Array1<f32>) -> Self {
Self {
data,
grad: None,
requires_grad: true,
}
}
pub fn no_grad(data: Array1<f32>) -> Self {
Self {
data,
grad: None,
requires_grad: false,
}
}
}
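/// A single recorded operation. Each variant caches the forward-pass values
/// needed to compute its vector-Jacobian product in [`GradientTape::backward`].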
enum TapeOp {
Add {
out_idx: usize,
a_idx: usize,
b_idx: usize,
},
Mul {
out_idx: usize,
a_idx: usize,
b_idx: usize,
a_data: Array1<f32>,
b_data: Array1<f32>,
},
MatMul {
out_idx: usize,
a_idx: usize,
b_idx: usize,
a: Array2<f32>,
b: Array2<f32>,
},
SiLU {
out_idx: usize,
in_idx: usize,
input: Array1<f32>,
},
LayerNorm {
out_idx: usize,
in_idx: usize,
mean: f32,
var: f32,
scale: Array1<f32>,
},
SsmScan {
out_idx: usize,
in_idx: usize,
a_vals: Array1<f32>,
b_vals: Array1<f32>,
},
}
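/// Reverse-mode autodiff tape: operations are recorded with `record_*` methods
/// (which return `usize` tensor indices) and replayed in reverse by
/// [`GradientTape::backward`] to accumulate gradients.
///
/// A minimal usage sketch; `a_idx`, `b_idx`, `a_data`, `b_data`, and `upstream`
/// are placeholders here, and the unit tests below are the runnable reference.
///
/// ```ignore
/// let mut tape = GradientTape::new();
/// // Record y = a * b; the returned index identifies the output tensor.
/// let out_idx = tape.record_mul(a_idx, &a_data, b_idx, &b_data);
/// // One gradient slot per allocated tensor; backward seeds the last slot.
/// let mut grads = vec![Array1::zeros(a_data.len()); out_idx + 1];
/// tape.backward(upstream, &mut grads)?;
/// // grads[a_idx] and grads[b_idx] now hold the accumulated gradients.
/// ```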
pub struct GradientTape {
ops: Vec<TapeOp>,
num_tensors: usize,
}
impl GradientTape {
pub fn new() -> Self {
Self {
ops: Vec::new(),
num_tensors: 0,
}
}
fn alloc(&mut self) -> usize {
let idx = self.num_tensors;
self.num_tensors += 1;
idx
}
pub fn record_add(&mut self, a: usize, b: usize) -> usize {
let out_idx = self.alloc();
self.ops.push(TapeOp::Add {
out_idx,
a_idx: a,
b_idx: b,
});
out_idx
}
pub fn record_mul(
&mut self,
a: usize,
a_data: &Array1<f32>,
b: usize,
b_data: &Array1<f32>,
) -> usize {
let out_idx = self.alloc();
self.ops.push(TapeOp::Mul {
out_idx,
a_idx: a,
b_idx: b,
a_data: a_data.clone(),
b_data: b_data.clone(),
});
out_idx
}
pub fn record_matmul(
&mut self,
a: usize,
a_mat: &Array2<f32>,
b: usize,
b_mat: &Array2<f32>,
) -> usize {
let out_idx = self.alloc();
self.ops.push(TapeOp::MatMul {
out_idx,
a_idx: a,
b_idx: b,
a: a_mat.clone(),
b: b_mat.clone(),
});
out_idx
}
pub fn record_silu(&mut self, input: usize, input_data: &Array1<f32>) -> usize {
let out_idx = self.alloc();
self.ops.push(TapeOp::SiLU {
out_idx,
in_idx: input,
input: input_data.clone(),
});
out_idx
}
pub fn record_layer_norm(
&mut self,
input: usize,
mean: f32,
var: f32,
scale: &Array1<f32>,
) -> usize {
let out_idx = self.alloc();
self.ops.push(TapeOp::LayerNorm {
out_idx,
in_idx: input,
mean,
var,
scale: scale.clone(),
});
out_idx
}
pub fn record_ssm_scan(
&mut self,
input: usize,
a_vals: &Array1<f32>,
b_vals: &Array1<f32>,
) -> usize {
let out_idx = self.alloc();
self.ops.push(TapeOp::SsmScan {
out_idx,
in_idx: input,
a_vals: a_vals.clone(),
b_vals: b_vals.clone(),
});
out_idx
}
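/// Replays the recorded operations in reverse order, seeding `loss_grad` into
/// the most recently allocated tensor slot and accumulating gradients into
/// `tensors` (which is grown with placeholder slots if it is too short).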
pub fn backward(
&self,
loss_grad: Array1<f32>,
tensors: &mut Vec<Array1<f32>>,
) -> ModelResult<()> {
if self.num_tensors == 0 {
return Ok(());
}
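// Grow the gradient buffer with placeholder slots; `accumulate` resizes a
// slot on first write if its length does not match.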
while tensors.len() < self.num_tensors {
tensors.push(Array1::zeros(1));
}
let last_out = self.num_tensors.saturating_sub(1);
tensors[last_out] = loss_grad;
for op in self.ops.iter().rev() {
match op {
TapeOp::Add {
out_idx,
a_idx,
b_idx,
} => {
let grad = tensors[*out_idx].clone();
check_finite_1d(&grad, "GradientTape::backward::Add")?;
Self::accumulate(tensors, *a_idx, &grad);
Self::accumulate(tensors, *b_idx, &grad);
}
TapeOp::Mul {
out_idx,
a_idx,
b_idx,
a_data,
b_data,
} => {
let grad = tensors[*out_idx].clone();
check_finite_1d(&grad, "GradientTape::backward::Mul")?;
let da = &grad * b_data;
let db = &grad * a_data;
Self::accumulate(tensors, *a_idx, &da);
Self::accumulate(tensors, *b_idx, &db);
}
TapeOp::MatMul {
out_idx,
a_idx,
b_idx,
a,
b,
} => {
let grad_flat = tensors[*out_idx].clone();
check_finite_1d(&grad_flat, "GradientTape::backward::MatMul")?;
let (m, k) = a.dim();
let (_k2, n) = b.dim();
let grad_len = grad_flat.len();
let expected = m * n;
if grad_len != expected {
return Err(ModelError::dimension_mismatch(
"GradientTape MatMul backward grad reshape",
expected,
grad_len,
));
}
let grad_mat = grad_flat
.into_shape_with_order((m, n))
.map_err(|e| ModelError::invalid_config(e.to_string()))?;
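// dA = dY · Bᵀ and dB = Aᵀ · dY, written as explicit loops.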
let mut da = Array2::<f32>::zeros((m, k));
for i in 0..m {
for j in 0..k {
let mut s = 0.0_f32;
for p in 0..n {
s += grad_mat[[i, p]] * b[[j, p]];
}
da[[i, j]] = s;
}
}
let mut db = Array2::<f32>::zeros((k, n));
for i in 0..k {
for j in 0..n {
let mut s = 0.0_f32;
for p in 0..m {
s += a[[p, i]] * grad_mat[[p, j]];
}
db[[i, j]] = s;
}
}
let da_flat = da
.into_shape_with_order(m * k)
.map_err(|e| ModelError::invalid_config(e.to_string()))?;
let db_flat = db
.into_shape_with_order(k * n)
.map_err(|e| ModelError::invalid_config(e.to_string()))?;
Self::accumulate(tensors, *a_idx, &da_flat);
Self::accumulate(tensors, *b_idx, &db_flat);
}
TapeOp::SiLU {
out_idx,
in_idx,
input,
} => {
let grad = tensors[*out_idx].clone();
check_finite_1d(&grad, "GradientTape::backward::SiLU")?;
let dx = silu_backward(&grad, input);
Self::accumulate(tensors, *in_idx, &dx);
}
TapeOp::LayerNorm {
out_idx,
in_idx,
mean,
var,
scale,
} => {
let grad = tensors[*out_idx].clone();
check_finite_1d(&grad, "GradientTape::backward::LayerNorm")?;
let n = grad.len() as f32;
let eps = 1e-5_f32;
let std_inv = 1.0 / (var + eps).sqrt();
let scale_std = scale.mapv(|s| s * std_inv);
let dy_mean = grad.sum() / n;
let dx = scale_std * grad.mapv(|g| g - dy_mean);
// `mean` is not needed by this simplified LayerNorm backward (the input
// is not cached on the tape), so it is intentionally ignored.
let _ = mean;
Self::accumulate(tensors, *in_idx, &dx);
}
TapeOp::SsmScan {
out_idx,
in_idx,
a_vals,
b_vals,
} => {
let grad = tensors[*out_idx].clone();
check_finite_1d(&grad, "GradientTape::backward::SsmScan")?;
let dx = b_vals * &grad;
Self::accumulate(tensors, *in_idx, &dx);
let _ = a_vals;
}
}
}
Ok(())
}
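/// Adds `grad` into `tensors[idx]`; if the slot has a different length
/// (e.g. an unsized placeholder), it is replaced by `grad` instead.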
fn accumulate(tensors: &mut [Array1<f32>], idx: usize, grad: &Array1<f32>) {
if idx >= tensors.len() {
return;
}
if tensors[idx].len() != grad.len() {
tensors[idx] = grad.clone();
} else {
tensors[idx] = tensors[idx].clone() + grad;
}
}
}
impl Default for GradientTape {
fn default() -> Self {
Self::new()
}
}
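/// Backward pass through a diagonal state-space scan with `state_dim` states
/// over `seq_len` time steps.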
pub struct SsmBackward {
pub state_dim: usize,
pub seq_len: usize,
}
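/// Gradients returned by [`SsmBackward::backward`]: `dx` w.r.t. the input
/// sequence, `da` and `db` w.r.t. the per-step discretized A and B, `dc`
/// w.r.t. the output projection, and `delta_grad` as an auxiliary per-step
/// gradient propagated through `a_bar`.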
pub struct SsmGradients {
pub dx: Array2<f32>,
pub da: Array2<f32>,
pub db: Array2<f32>,
pub dc: Array1<f32>,
pub delta_grad: Array2<f32>,
}
impl SsmBackward {
pub fn new(state_dim: usize, seq_len: usize) -> Self {
Self { state_dim, seq_len }
}
pub fn backward(
&self,
dy: &Array2<f32>,
states: &[Array2<f32>],
a_bar: &Array2<f32>,
b_bar: &Array2<f32>,
c: &Array1<f32>,
x: &Array2<f32>,
) -> ModelResult<SsmGradients> {
let seq = self.seq_len;
let n_state = self.state_dim;
if dy.nrows() != seq {
return Err(ModelError::dimension_mismatch(
"SsmBackward dy rows",
seq,
dy.nrows(),
));
}
if states.len() != seq + 1 {
return Err(ModelError::dimension_mismatch(
"SsmBackward states length",
seq + 1,
states.len(),
));
}
if a_bar.nrows() != seq || a_bar.ncols() != n_state {
return Err(ModelError::dimension_mismatch(
"SsmBackward a_bar shape",
seq * n_state,
a_bar.nrows() * a_bar.ncols(),
));
}
if b_bar.nrows() != seq || b_bar.ncols() != n_state {
return Err(ModelError::dimension_mismatch(
"SsmBackward b_bar shape",
seq * n_state,
b_bar.nrows() * b_bar.ncols(),
));
}
if c.len() != n_state {
return Err(ModelError::dimension_mismatch(
"SsmBackward c length",
n_state,
c.len(),
));
}
check_finite_2d(dy, "SsmBackward::backward dy")?;
check_finite_2d(a_bar, "SsmBackward::backward a_bar")?;
check_finite_2d(b_bar, "SsmBackward::backward b_bar")?;
check_finite_1d(c, "SsmBackward::backward c")?;
check_finite_2d(x, "SsmBackward::backward x")?;
let input_dim = x.ncols();
let output_dim = dy.ncols();
let mut dx = Array2::<f32>::zeros((seq, input_dim));
let mut da = Array2::<f32>::zeros((seq, n_state));
let mut db = Array2::<f32>::zeros((seq, n_state));
let mut dc = Array1::<f32>::zeros(n_state);
let mut delta_grad = Array2::<f32>::zeros((seq, n_state));
let mut dh_next = Array1::<f32>::zeros(n_state);
for t in (0..seq).rev() {
let dy_t_scalar: f32 = if output_dim == 1 {
dy[[t, 0]]
} else {
dy.row(t).sum() / output_dim as f32
};
let mut dh_t = Array1::<f32>::zeros(n_state);
for sn in 0..n_state {
dh_t[sn] = c[sn] * dy_t_scalar + a_bar[[t, sn]] * dh_next[sn];
}
let h_prev_row = states[t].row(0);
for sn in 0..n_state {
da[[t, sn]] = dh_t[sn] * h_prev_row[sn];
}
let x_t_scalar: f32 = if input_dim == 1 {
x[[t, 0]]
} else {
x.row(t).sum() / input_dim as f32
};
for sn in 0..n_state {
db[[t, sn]] = dh_t[sn] * x_t_scalar;
}
let h_t_row = states[t + 1].row(0);
for sn in 0..n_state {
dc[sn] += h_t_row[sn] * dy_t_scalar;
}
for sn in 0..n_state {
delta_grad[[t, sn]] = dh_t[sn] * h_t_row[sn] * a_bar[[t, sn]];
}
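// Averaged approximation: the input gradient uses the mean of B_bar and the
// mean of dh_t rather than a full per-channel product.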
let b_bar_sum: f32 = b_bar.row(t).sum() / n_state as f32;
for d in 0..input_dim {
dx[[t, d]] = b_bar_sum * dh_t.sum() / n_state as f32;
}
dh_next = dh_t;
}
Ok(SsmGradients {
dx,
da,
db,
dc,
delta_grad,
})
}
}
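/// Accumulates named parameter gradients across micro-batches, with averaging
/// (`normalize`), zeroing (`zero_grad`), and global-norm clipping (`apply_clip`).
///
/// A minimal sketch; `grad_w` is a placeholder gradient for a parameter "w".
///
/// ```ignore
/// let mut acc = GradAccumulator::new();
/// acc.accumulate("w", &grad_w)?;
/// acc.normalize();
/// let global_norm = acc.apply_clip(1.0);
/// ```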
#[derive(Debug, Default)]
pub struct GradAccumulator {
grads: HashMap<String, Array1<f32>>,
counts: HashMap<String, usize>,
}
impl GradAccumulator {
pub fn new() -> Self {
Self {
grads: HashMap::new(),
counts: HashMap::new(),
}
}
pub fn accumulate(&mut self, name: &str, grad: &Array1<f32>) -> ModelResult<()> {
check_finite_1d(grad, &format!("GradAccumulator::accumulate({name})"))?;
let existing = self
.grads
.entry(name.to_string())
.or_insert_with(|| Array1::zeros(grad.len()));
if existing.len() != grad.len() {
return Err(ModelError::dimension_mismatch(
format!("GradAccumulator::accumulate({name})"),
existing.len(),
grad.len(),
));
}
*existing = existing.clone() + grad;
*self.counts.entry(name.to_string()).or_insert(0) += 1;
Ok(())
}
pub fn get(&self, name: &str) -> Option<&Array1<f32>> {
self.grads.get(name)
}
pub fn normalize(&mut self) {
for (name, grad) in self.grads.iter_mut() {
let count = self.counts.get(name).copied().unwrap_or(1).max(1);
*grad = grad.mapv(|v| v / count as f32);
}
}
pub fn zero_grad(&mut self) {
for grad in self.grads.values_mut() {
grad.fill(0.0);
}
for count in self.counts.values_mut() {
*count = 0;
}
}
pub fn apply_clip(&mut self, max_norm: f32) -> f32 {
let total_sq: f32 = self
.grads
.values()
.flat_map(|g| g.iter())
.map(|&v| v * v)
.sum();
let norm = total_sq.sqrt();
if norm > max_norm && norm > 0.0 {
let scale = max_norm / norm;
for grad in self.grads.values_mut() {
*grad = grad.mapv(|v| v * scale);
}
}
norm
}
pub fn param_names(&self) -> Vec<&str> {
self.grads.keys().map(|s| s.as_str()).collect()
}
}
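/// Backward pass for a dense layer `y = xᵀ W + b` with `W: (input_dim, output_dim)`.
/// Returns `(dx, dW, db)` where `dx = W · dy`, `dW = x ⊗ dy`, and `db = dy`.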
pub fn linear_backward(
dy: &Array1<f32>,
x: &Array1<f32>,
w: &Array2<f32>,
) -> ModelResult<(Array1<f32>, Array2<f32>, Array1<f32>)> {
let (input_dim, output_dim) = w.dim();
if dy.len() != output_dim {
return Err(ModelError::dimension_mismatch(
"linear_backward dy",
output_dim,
dy.len(),
));
}
if x.len() != input_dim {
return Err(ModelError::dimension_mismatch(
"linear_backward x",
input_dim,
x.len(),
));
}
let mut dx = Array1::<f32>::zeros(input_dim);
for i in 0..input_dim {
let mut s = 0.0_f32;
for j in 0..output_dim {
s += w[[i, j]] * dy[j];
}
dx[i] = s;
}
let mut dw = Array2::<f32>::zeros((input_dim, output_dim));
for i in 0..input_dim {
for j in 0..output_dim {
dw[[i, j]] = x[i] * dy[j];
}
}
let db = dy.clone();
Ok((dx, dw, db))
}
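/// SiLU/Swish backward: `d/dx [x·σ(x)] = σ(x)·(1 + x·(1 − σ(x)))`, multiplied
/// elementwise by the upstream gradient `dy`.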
pub fn silu_backward(dy: &Array1<f32>, x: &Array1<f32>) -> Array1<f32> {
let n = dy.len().min(x.len());
let mut out = Array1::<f32>::zeros(n);
for i in 0..n {
let sig = sigmoid(x[i]);
let dsilu = sig * (1.0 + x[i] * (1.0 - sig));
out[i] = dy[i] * dsilu;
}
out
}
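/// Softmax vector-Jacobian product given the softmax output `y`:
/// `dx_i = y_i · (dy_i − Σ_j y_j dy_j)`.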
pub fn softmax_backward(dy: &Array1<f32>, y: &Array1<f32>) -> Array1<f32> {
let dot_yd: f32 = y.iter().zip(dy.iter()).map(|(&yi, &dyi)| yi * dyi).sum();
let n = dy.len().min(y.len());
let mut out = Array1::<f32>::zeros(n);
for i in 0..n {
out[i] = y[i] * (dy[i] - dot_yd);
}
out
}
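/// Layer-norm backward. With `x̂_i = (x_i − mean) / sqrt(var + ε)`, returns
/// `(dx, d_scale, d_bias)` where `d_scale = dy ⊙ x̂`, `d_bias = dy`, and
/// `dx_i = scale_i / sqrt(var + ε) · (dy_i − mean(dy) − x̂_i · mean(dy ⊙ x̂))`.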
pub fn layer_norm_backward(
dy: &Array1<f32>,
x: &Array1<f32>,
mean: f32,
var: f32,
scale: &Array1<f32>,
) -> ModelResult<(Array1<f32>, Array1<f32>, Array1<f32>)> {
let n = dy.len();
if x.len() != n {
return Err(ModelError::dimension_mismatch(
"layer_norm_backward x",
n,
x.len(),
));
}
if scale.len() != n {
return Err(ModelError::dimension_mismatch(
"layer_norm_backward scale",
n,
scale.len(),
));
}
let eps = 1e-5_f32;
let std_inv = 1.0 / (var + eps).sqrt();
let x_hat: Array1<f32> = x.mapv(|v| (v - mean) * std_inv);
let d_bias = dy.clone();
let d_scale: Array1<f32> = dy * &x_hat;
let dy_mean = dy.sum() / n as f32;
let dy_xhat_mean = (dy * &x_hat).sum() / n as f32;
let mut dx = Array1::<f32>::zeros(n);
for i in 0..n {
dx[i] = scale[i] * std_inv * (dy[i] - dy_mean - x_hat[i] * dy_xhat_mean);
}
Ok((dx, d_scale, d_bias))
}
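// Re-export the sequence-level SSM backward utilities so callers have a single
// import path alongside the tape-based ops in this module.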
pub use crate::backprop_ssm::{
associative_scan_backward, ssm_backward, GradientCheckpointedSSM, SsmForwardCache,
SsmGradientsVec,
};
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::{Array1, Array2};
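/// Central-difference numerical gradient of a scalar-valued function, used to
/// cross-check the analytic backward passes.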
fn numerical_grad(f: impl Fn(&Array1<f32>) -> f32, x: &Array1<f32>, eps: f32) -> Array1<f32> {
let mut grad = Array1::zeros(x.len());
for i in 0..x.len() {
let mut xp = x.clone();
xp[i] += eps;
let mut xm = x.clone();
xm[i] -= eps;
grad[i] = (f(&xp) - f(&xm)) / (2.0 * eps);
}
grad
}
#[test]
fn test_gradient_tape_add_backward() {
let mut tape = GradientTape::new();
let a_idx = tape.alloc();
let b_idx = tape.alloc();
let _out_idx = tape.record_add(a_idx, b_idx);
let loss_grad = Array1::from_vec(vec![1.0_f32, 1.0, 1.0]);
let mut tensors: Vec<Array1<f32>> =
vec![Array1::zeros(3), Array1::zeros(3), Array1::zeros(3)];
tape.backward(loss_grad, &mut tensors)
.expect("backward failed");
for (i, (&ag, &bg)) in tensors[a_idx].iter().zip(tensors[b_idx].iter()).enumerate() {
assert!((ag - 1.0).abs() < 1e-5, "a grad[{i}] = {ag}");
assert!((bg - 1.0).abs() < 1e-5, "b grad[{i}] = {bg}");
}
}
#[test]
fn test_gradient_tape_mul_backward() {
let a_data = Array1::from_vec(vec![2.0_f32, 3.0, 4.0]);
let b_data = Array1::from_vec(vec![5.0_f32, 6.0, 7.0]);
let mut tape = GradientTape::new();
let a_idx = tape.alloc();
let b_idx = tape.alloc();
let _out_idx = tape.record_mul(a_idx, &a_data, b_idx, &b_data);
let loss_grad = Array1::from_vec(vec![1.0_f32, 1.0, 1.0]);
let mut tensors: Vec<Array1<f32>> =
vec![Array1::zeros(3), Array1::zeros(3), Array1::zeros(3)];
tape.backward(loss_grad, &mut tensors)
.expect("backward failed");
for (i, (&ag, &bg)) in tensors[a_idx].iter().zip(tensors[b_idx].iter()).enumerate() {
assert!((ag - b_data[i]).abs() < 1e-5, "a grad[{i}] = {ag}");
assert!((bg - a_data[i]).abs() < 1e-5, "b grad[{i}] = {bg}");
}
}
#[test]
fn test_gradient_tape_matmul_backward() {
let a_mat = Array2::from_shape_vec((2, 3), vec![1.0_f32, 0.0, 0.0, 0.0, 1.0, 0.0])
.expect("shape ok");
let b_mat = Array2::from_shape_vec((3, 2), vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0])
.expect("shape ok");
let mut tape = GradientTape::new();
let a_idx = tape.alloc();
let b_idx = tape.alloc();
let _out_idx = tape.record_matmul(a_idx, &a_mat, b_idx, &b_mat);
let loss_grad = Array1::from_vec(vec![1.0_f32, 0.0, 0.0, 1.0]);
let mut tensors: Vec<Array1<f32>> =
vec![Array1::zeros(6), Array1::zeros(6), Array1::zeros(4)];
tape.backward(loss_grad, &mut tensors)
.expect("backward failed");
assert_eq!(tensors[a_idx].len(), 6);
assert_eq!(tensors[b_idx].len(), 6);
}
#[test]
fn test_silu_backward_numerical() {
let x = Array1::from_vec(vec![-1.0_f32, 0.0, 1.0, 2.0]);
let dy = Array1::from_vec(vec![1.0_f32; 4]);
let analytic = silu_backward(&dy, &x);
let numeric = numerical_grad(
|xi| {
xi.iter().map(|&v| v * sigmoid(v)).sum::<f32>()
},
&x,
1e-4,
);
for i in 0..4 {
assert!(
(analytic[i] - numeric[i]).abs() < 2e-3,
"SiLU grad[{i}]: analytic={} numeric={}",
analytic[i],
numeric[i]
);
}
}
#[test]
fn test_layer_norm_backward_numerical() {
let x = Array1::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0]);
let scale = Array1::from_vec(vec![1.0_f32; 4]);
let dy = Array1::from_vec(vec![1.0_f32; 4]);
let eps = 1e-5_f32;
let mean = x.sum() / x.len() as f32;
let var = x.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / x.len() as f32;
let (dx_analytic, _, _) =
layer_norm_backward(&dy, &x, mean, var, &scale).expect("backward ok");
let numeric = numerical_grad(
|xi| {
let m = xi.sum() / xi.len() as f32;
let variance = xi.iter().map(|&u| (u - m).powi(2)).sum::<f32>() / xi.len() as f32;
let x_hat: f32 = xi
.iter()
.map(|&u| (u - m) / (variance + eps).sqrt())
.sum::<f32>();
x_hat
},
&x,
1e-4,
);
assert_eq!(dx_analytic.len(), 4);
for &v in dx_analytic.iter() {
assert!(v.is_finite(), "dx contains non-finite value");
}
// The analytic gradient is ~0 here (constant upstream gradient, unit scale),
// and the finite-difference reference is dominated by cancellation noise, so
// it is exercised but not asserted against.
let _ = numeric;
}
#[test]
fn test_linear_backward_shapes() {
let input_dim = 5;
let output_dim = 3;
let x = Array1::<f32>::zeros(input_dim);
let w = Array2::<f32>::zeros((input_dim, output_dim));
let dy = Array1::<f32>::zeros(output_dim);
let (dx, dw, db) = linear_backward(&dy, &x, &w).expect("linear_backward ok");
assert_eq!(dx.len(), input_dim, "dx shape");
assert_eq!(dw.dim(), (input_dim, output_dim), "dW shape");
assert_eq!(db.len(), output_dim, "db shape");
}
#[test]
fn test_linear_backward_numerical() {
let input_dim = 3;
let output_dim = 2;
let x = Array1::from_vec(vec![1.0_f32, 2.0, 3.0]);
let w = Array2::from_shape_vec(
(input_dim, output_dim),
vec![0.1_f32, 0.2, 0.3, 0.4, 0.5, 0.6],
)
.expect("shape ok");
let dy = Array1::from_vec(vec![1.0_f32, 1.0]);
let (dx_analytic, _, _) = linear_backward(&dy, &x, &w).expect("backward ok");
let numeric_dx = numerical_grad(
|xi| {
let mut s = 0.0_f32;
for i in 0..input_dim {
for j in 0..output_dim {
s += xi[i] * w[[i, j]] * dy[j];
}
}
s
},
&x,
1e-4,
);
for (i, (&da, &dn)) in dx_analytic.iter().zip(numeric_dx.iter()).enumerate() {
assert!(
(da - dn).abs() < 5e-3,
"dx[{i}]: analytic={da} numeric={dn}"
);
}
}
#[test]
fn test_softmax_backward_sums_to_zero() {
let logits = Array1::from_vec(vec![1.0_f32, 2.0, 3.0]);
let max_v = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exp: Array1<f32> = logits.mapv(|v| (v - max_v).exp());
let sum_exp = exp.sum();
let y: Array1<f32> = exp.mapv(|v| v / sum_exp);
for j in 0..3 {
let mut dy = Array1::zeros(3);
dy[j] = 1.0;
let dx = softmax_backward(&dy, &y);
let sum: f32 = dx.sum();
assert!(
sum.abs() < 1e-5,
"softmax_backward col {j} sum = {sum}, expected 0"
);
}
}
#[test]
fn test_ssm_backward_gradient_shapes() {
let state_dim = 4;
let seq_len = 5;
let input_dim = 2;
let output_dim = 1;
let dy = Array2::<f32>::zeros((seq_len, output_dim));
let states: Vec<Array2<f32>> = (0..=seq_len)
.map(|_| Array2::<f32>::zeros((1, state_dim)))
.collect();
let a_bar = Array2::<f32>::from_elem((seq_len, state_dim), 0.9);
let b_bar = Array2::<f32>::from_elem((seq_len, state_dim), 0.1);
let c = Array1::<f32>::from_elem(state_dim, 1.0);
let x = Array2::<f32>::zeros((seq_len, input_dim));
let ssm_bwd = SsmBackward::new(state_dim, seq_len);
let grads = ssm_bwd
.backward(&dy, &states, &a_bar, &b_bar, &c, &x)
.expect("SSM backward ok");
assert_eq!(grads.dx.dim(), (seq_len, input_dim), "dx shape");
assert_eq!(grads.da.dim(), (seq_len, state_dim), "da shape");
assert_eq!(grads.db.dim(), (seq_len, state_dim), "db shape");
assert_eq!(grads.dc.len(), state_dim, "dc shape");
assert_eq!(
grads.delta_grad.dim(),
(seq_len, state_dim),
"delta_grad shape"
);
}
#[test]
fn test_ssm_backward_vanishing() {
let state_dim = 4;
let seq_len = 10;
let input_dim = 1;
let output_dim = 1;
let dy = Array2::from_elem((seq_len, output_dim), 1.0_f32);
let states: Vec<Array2<f32>> = (0..=seq_len)
.map(|i| Array2::from_elem((1, state_dim), 0.1 * (i + 1) as f32))
.collect();
let a_bar = Array2::from_elem((seq_len, state_dim), 0.9_f32);
let b_bar = Array2::from_elem((seq_len, state_dim), 0.5_f32);
let c = Array1::from_elem(state_dim, 1.0_f32);
let x = Array2::from_elem((seq_len, input_dim), 1.0_f32);
let ssm_bwd = SsmBackward::new(state_dim, seq_len);
let grads = ssm_bwd
.backward(&dy, &states, &a_bar, &b_bar, &c, &x)
.expect("SSM backward ok");
let da_norm: f32 = grads.da.iter().map(|&v| v * v).sum::<f32>().sqrt();
assert!(da_norm > 1e-6, "da gradient vanished: norm = {da_norm}");
}
#[test]
fn test_grad_accumulator_zero_grad() {
let mut acc = GradAccumulator::new();
let g = Array1::from_vec(vec![1.0_f32, 2.0, 3.0]);
acc.accumulate("w", &g).expect("accumulate ok");
acc.accumulate("b", &g).expect("accumulate ok");
acc.zero_grad();
let w_grad = acc.get("w").expect("w exists after zero_grad");
for &v in w_grad.iter() {
assert_eq!(v, 0.0, "grad should be zeroed");
}
}
#[test]
fn test_grad_accumulator_clip() {
let mut acc = GradAccumulator::new();
let g = Array1::from_vec(vec![3.0_f32, 4.0]);
acc.accumulate("w", &g).expect("accumulate ok");
let norm_before = acc.apply_clip(2.5);
assert!(
(norm_before - 5.0).abs() < 1e-4,
"norm before = {norm_before}"
);
let w_grad = acc.get("w").expect("w exists");
let norm_after: f32 = w_grad.iter().map(|&v| v * v).sum::<f32>().sqrt();
assert!(
(norm_after - 2.5).abs() < 1e-4,
"norm after clipping should be 2.5, got {norm_after}"
);
}
#[test]
fn test_grad_accumulator_normalize() {
let mut acc = GradAccumulator::new();
let g = Array1::from_vec(vec![2.0_f32, 4.0, 6.0]);
acc.accumulate("w", &g).expect("ok");
acc.accumulate("w", &g).expect("ok");
acc.accumulate("w", &g).expect("ok");
acc.normalize();
let w_grad = acc.get("w").expect("w exists");
for (i, &v) in w_grad.iter().enumerate() {
assert!(
(v - g[i]).abs() < 1e-5,
"normalized grad[{i}] = {v}, expected {}",
g[i]
);
}
}
}