use crate::error::{ModelError, ModelResult};
use scirs2_core::ndarray::{Array1, Array2};
use std::collections::HashMap;

/// Logistic sigmoid, used by the SiLU backward pass.
#[inline]
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

/// Returns a numerical-instability error if any element of `arr` is NaN or infinite.
fn check_finite_1d(arr: &Array1<f32>, ctx: &str) -> ModelResult<()> {
    for &v in arr.iter() {
        if !v.is_finite() {
            return Err(ModelError::numerical_instability(
                ctx,
                format!("non-finite value {v} detected"),
            ));
        }
    }
    Ok(())
}

/// Two-dimensional counterpart of [`check_finite_1d`].
fn check_finite_2d(arr: &Array2<f32>, ctx: &str) -> ModelResult<()> {
    for &v in arr.iter() {
        if !v.is_finite() {
            return Err(ModelError::numerical_instability(
                ctx,
                format!("non-finite value {v} detected"),
            ));
        }
    }
    Ok(())
}
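/// A value tracked for reverse-mode differentiation: forward data plus an
/// optional gradient buffer. A minimal construction sketch:
///
/// ```ignore
/// let t = Tensor::new(Array1::from_vec(vec![1.0_f32, 2.0]));
/// assert!(t.requires_grad && t.grad.is_none());
/// ```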
#[derive(Debug, Clone)]
pub struct Tensor {
    /// Forward-pass values.
    pub data: Array1<f32>,
    /// Accumulated gradient, populated by a backward pass.
    pub grad: Option<Array1<f32>>,
    /// Whether backward passes should write a gradient for this tensor.
    pub requires_grad: bool,
}

impl Tensor {
    /// Creates a tensor that participates in gradient computation.
    pub fn new(data: Array1<f32>) -> Self {
        Self {
            data,
            grad: None,
            requires_grad: true,
        }
    }

    /// Creates a tensor that is excluded from gradient computation.
    pub fn no_grad(data: Array1<f32>) -> Self {
        Self {
            data,
            grad: None,
            requires_grad: false,
        }
    }
}
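/// One recorded operation on the tape. Each variant stores the tensor indices
/// it connects plus whatever forward values its backward rule needs (e.g.
/// `Mul` keeps both operands, `SiLU` keeps its pre-activation input).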
enum TapeOp {
    /// `out = a + b`.
    Add {
        out_idx: usize,
        a_idx: usize,
        b_idx: usize,
    },
    /// Elementwise `out = a * b`; keeps both operands for the product rule.
    Mul {
        out_idx: usize,
        a_idx: usize,
        b_idx: usize,
        a_data: Array1<f32>,
        b_data: Array1<f32>,
    },
    /// Matrix product, with the output flattened row-major; keeps both matrices.
    MatMul {
        out_idx: usize,
        a_idx: usize,
        b_idx: usize,
        a: Array2<f32>,
        b: Array2<f32>,
    },
    /// `out = silu(input)`; keeps the pre-activation input.
    SiLU {
        out_idx: usize,
        in_idx: usize,
        input: Array1<f32>,
    },
    /// Layer normalization; keeps the forward statistics and scale.
    LayerNorm {
        out_idx: usize,
        in_idx: usize,
        mean: f32,
        var: f32,
        scale: Array1<f32>,
    },
    /// Diagonal SSM scan step; keeps the discretized A and B values.
    SsmScan {
        out_idx: usize,
        in_idx: usize,
        a_vals: Array1<f32>,
        b_vals: Array1<f32>,
    },
}
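/// A tape for reverse-mode automatic differentiation: `record_*` calls append
/// operations during the forward pass, and [`GradientTape::backward`] replays
/// them in reverse to accumulate gradients into a caller-owned buffer indexed
/// by tensor slot. A minimal sketch of the intended flow (mirroring the unit
/// tests below, which reserve input slots on the tape first):
///
/// ```ignore
/// let mut tape = GradientTape::new();
/// let a = Array1::from_vec(vec![2.0_f32]);
/// let b = Array1::from_vec(vec![3.0_f32]);
/// let (a_idx, b_idx) = (tape.alloc(), tape.alloc());
/// let out_idx = tape.record_mul(a_idx, &a, b_idx, &b);
/// let mut grads = vec![Array1::zeros(1); 3];
/// tape.backward(Array1::from_vec(vec![1.0]), &mut grads)?;
/// // grads[a_idx] == b and grads[b_idx] == a by the product rule.
/// ```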
pub struct GradientTape {
    /// Operations in the order they were recorded.
    ops: Vec<TapeOp>,
    /// Number of tensor slots handed out so far.
    num_tensors: usize,
}

impl GradientTape {
    pub fn new() -> Self {
        Self {
            ops: Vec::new(),
            num_tensors: 0,
        }
    }

    /// Reserves the next tensor index on the tape.
    fn alloc(&mut self) -> usize {
        let idx = self.num_tensors;
        self.num_tensors += 1;
        idx
    }

    /// Records `out = a + b` and returns the output index.
    pub fn record_add(&mut self, a: usize, b: usize) -> usize {
        let out_idx = self.alloc();
        self.ops.push(TapeOp::Add {
            out_idx,
            a_idx: a,
            b_idx: b,
        });
        out_idx
    }

    /// Records an elementwise product and returns the output index.
    pub fn record_mul(
        &mut self,
        a: usize,
        a_data: &Array1<f32>,
        b: usize,
        b_data: &Array1<f32>,
    ) -> usize {
        let out_idx = self.alloc();
        self.ops.push(TapeOp::Mul {
            out_idx,
            a_idx: a,
            b_idx: b,
            a_data: a_data.clone(),
            b_data: b_data.clone(),
        });
        out_idx
    }

    /// Records a matrix product and returns the output index.
    pub fn record_matmul(
        &mut self,
        a: usize,
        a_mat: &Array2<f32>,
        b: usize,
        b_mat: &Array2<f32>,
    ) -> usize {
        let out_idx = self.alloc();
        self.ops.push(TapeOp::MatMul {
            out_idx,
            a_idx: a,
            b_idx: b,
            a: a_mat.clone(),
            b: b_mat.clone(),
        });
        out_idx
    }

    /// Records a SiLU activation and returns the output index.
    pub fn record_silu(&mut self, input: usize, input_data: &Array1<f32>) -> usize {
        let out_idx = self.alloc();
        self.ops.push(TapeOp::SiLU {
            out_idx,
            in_idx: input,
            input: input_data.clone(),
        });
        out_idx
    }

    /// Records a layer normalization and returns the output index.
    pub fn record_layer_norm(
        &mut self,
        input: usize,
        mean: f32,
        var: f32,
        scale: &Array1<f32>,
    ) -> usize {
        let out_idx = self.alloc();
        self.ops.push(TapeOp::LayerNorm {
            out_idx,
            in_idx: input,
            mean,
            var,
            scale: scale.clone(),
        });
        out_idx
    }

    /// Records a diagonal SSM scan and returns the output index.
    pub fn record_ssm_scan(
        &mut self,
        input: usize,
        a_vals: &Array1<f32>,
        b_vals: &Array1<f32>,
    ) -> usize {
        let out_idx = self.alloc();
        self.ops.push(TapeOp::SsmScan {
            out_idx,
            in_idx: input,
            a_vals: a_vals.clone(),
            b_vals: b_vals.clone(),
        });
        out_idx
    }
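    /// Runs reverse-mode accumulation over the recorded tape. `loss_grad` seeds
    /// the gradient of the most recently allocated tensor; each op then pushes
    /// its output gradient back to its inputs via [`GradientTape::accumulate`].
    /// `tensors` is padded with placeholder zero vectors so that every tape
    /// index has a slot.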
    pub fn backward(
        &self,
        loss_grad: Array1<f32>,
        tensors: &mut Vec<Array1<f32>>,
    ) -> ModelResult<()> {
        if self.num_tensors == 0 {
            return Ok(());
        }

        // Pad with placeholder slots; `accumulate` replaces them on first write.
        while tensors.len() < self.num_tensors {
            tensors.push(Array1::zeros(1));
        }

        // Seed the gradient of the last allocated tensor with the loss gradient.
        let last_out = self.num_tensors.saturating_sub(1);
        tensors[last_out] = loss_grad;

        // Replay the tape in reverse, propagating output gradients to inputs.
        for op in self.ops.iter().rev() {
            match op {
                TapeOp::Add {
                    out_idx,
                    a_idx,
                    b_idx,
                } => {
                    // d(a + b)/da = d(a + b)/db = 1: pass the gradient through.
                    let grad = tensors[*out_idx].clone();
                    check_finite_1d(&grad, "GradientTape::backward::Add")?;
                    Self::accumulate(tensors, *a_idx, &grad);
                    Self::accumulate(tensors, *b_idx, &grad);
                }

                TapeOp::Mul {
                    out_idx,
                    a_idx,
                    b_idx,
                    a_data,
                    b_data,
                } => {
                    // Product rule: da = grad ⊙ b, db = grad ⊙ a.
                    let grad = tensors[*out_idx].clone();
                    check_finite_1d(&grad, "GradientTape::backward::Mul")?;
                    let da = &grad * b_data;
                    let db = &grad * a_data;
                    Self::accumulate(tensors, *a_idx, &da);
                    Self::accumulate(tensors, *b_idx, &db);
                }

                TapeOp::MatMul {
                    out_idx,
                    a_idx,
                    b_idx,
                    a,
                    b,
                } => {
                    let grad_flat = tensors[*out_idx].clone();
                    check_finite_1d(&grad_flat, "GradientTape::backward::MatMul")?;

                    let (m, k) = a.dim();
                    let (_k2, n) = b.dim();

                    let grad_len = grad_flat.len();
                    let expected = m * n;
                    if grad_len != expected {
                        return Err(ModelError::dimension_mismatch(
                            "GradientTape MatMul backward grad reshape",
                            expected,
                            grad_len,
                        ));
                    }
                    let grad_mat = grad_flat
                        .into_shape_with_order((m, n))
                        .map_err(|e| ModelError::invalid_config(e.to_string()))?;

                    // dA = grad · Bᵀ
                    let mut da = Array2::<f32>::zeros((m, k));
                    for i in 0..m {
                        for j in 0..k {
                            let mut s = 0.0_f32;
                            for p in 0..n {
                                s += grad_mat[[i, p]] * b[[j, p]];
                            }
                            da[[i, j]] = s;
                        }
                    }

                    // dB = Aᵀ · grad
                    let mut db = Array2::<f32>::zeros((k, n));
                    for i in 0..k {
                        for j in 0..n {
                            let mut s = 0.0_f32;
                            for p in 0..m {
                                s += a[[p, i]] * grad_mat[[p, j]];
                            }
                            db[[i, j]] = s;
                        }
                    }

                    let da_flat = da
                        .into_shape_with_order(m * k)
                        .map_err(|e| ModelError::invalid_config(e.to_string()))?;
                    let db_flat = db
                        .into_shape_with_order(k * n)
                        .map_err(|e| ModelError::invalid_config(e.to_string()))?;

                    Self::accumulate(tensors, *a_idx, &da_flat);
                    Self::accumulate(tensors, *b_idx, &db_flat);
                }

                TapeOp::SiLU {
                    out_idx,
                    in_idx,
                    input,
                } => {
                    let grad = tensors[*out_idx].clone();
                    check_finite_1d(&grad, "GradientTape::backward::SiLU")?;
                    let dx = silu_backward(&grad, input);
                    Self::accumulate(tensors, *in_idx, &dx);
                }

                TapeOp::LayerNorm {
                    out_idx,
                    in_idx,
                    mean,
                    var,
                    scale,
                } => {
                    let grad = tensors[*out_idx].clone();
                    check_finite_1d(&grad, "GradientTape::backward::LayerNorm")?;
                    // Simplified LayerNorm gradient: centers `dy` and rescales by
                    // scale / sqrt(var + eps), but drops the x_hat (variance-path)
                    // term used by the exact rule in `layer_norm_backward`; `mean`
                    // is therefore unused here.
                    let n = grad.len() as f32;
                    let eps = 1e-5_f32;
                    let std_inv = 1.0 / (var + eps).sqrt();
                    let scale_std = scale.mapv(|s| s * std_inv);
                    let dy_mean = grad.sum() / n;
                    let dx = scale_std * grad.mapv(|g| g - dy_mean);
                    let _ = mean;
                    Self::accumulate(tensors, *in_idx, &dx);
                }

                TapeOp::SsmScan {
                    out_idx,
                    in_idx,
                    a_vals,
                    b_vals,
                } => {
                    let grad = tensors[*out_idx].clone();
                    check_finite_1d(&grad, "GradientTape::backward::SsmScan")?;
                    // Local (single-step) gradient through the scan: dx = B̄ ⊙ grad.
                    // Propagation through the recurrence via Ā is handled by the
                    // dedicated SSM backward paths, so `a_vals` is unused here.
                    let dx = b_vals * &grad;
                    Self::accumulate(tensors, *in_idx, &dx);
                    let _ = a_vals;
                }
            }
        }

        Ok(())
    }
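    /// Adds `grad` into slot `idx`. A length mismatch means the slot still
    /// holds a placeholder from padding, so the gradient replaces it rather
    /// than summing.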
    fn accumulate(tensors: &mut [Array1<f32>], idx: usize, grad: &Array1<f32>) {
        if idx >= tensors.len() {
            return;
        }
        if tensors[idx].len() != grad.len() {
            // First write into a placeholder slot: take the gradient as-is.
            tensors[idx] = grad.clone();
        } else {
            tensors[idx] += grad;
        }
    }
}

impl Default for GradientTape {
    fn default() -> Self {
        Self::new()
    }
}
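/// Backward pass through a diagonal state-space (SSM) scan. The assumed forward
/// recurrence, per state channel `n`, is
///
/// ```text
/// h[t][n] = a_bar[t][n] * h[t-1][n] + b_bar[t][n] * x[t]
/// y[t]    = Σ_n c[n] * h[t][n]
/// ```
///
/// so the state adjoint obeys `dh[t] = c * dy[t] + a_bar[t] ⊙ dh[t+1]`, which is
/// what the time-reversed loop in [`SsmBackward::backward`] implements.
/// Multi-channel inputs and outputs are collapsed to per-step scalar means, a
/// deliberate simplification of the full sequence gradient.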
pub struct SsmBackward {
    /// Number of state channels N.
    pub state_dim: usize,
    /// Sequence length T.
    pub seq_len: usize,
}
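/// Gradient buffers produced by [`SsmBackward::backward`], one per parameter
/// group of the scan: the input sequence (`dx`), the discretized transition
/// (`da`) and input projection (`db`), the output projection (`dc`), and the
/// gradient routed toward the discretization step (`delta_grad`).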
pub struct SsmGradients {
    /// Gradient w.r.t. the input sequence, shape `(seq_len, input_dim)`.
    pub dx: Array2<f32>,
    /// Gradient w.r.t. the discretized transition Ā, shape `(seq_len, state_dim)`.
    pub da: Array2<f32>,
    /// Gradient w.r.t. the discretized input projection B̄, shape `(seq_len, state_dim)`.
    pub db: Array2<f32>,
    /// Gradient w.r.t. the output projection C, length `state_dim`.
    pub dc: Array1<f32>,
    /// Gradient routed to the discretization step, shape `(seq_len, state_dim)`.
    pub delta_grad: Array2<f32>,
}

impl SsmBackward {
    pub fn new(state_dim: usize, seq_len: usize) -> Self {
        Self { state_dim, seq_len }
    }

    pub fn backward(
        &self,
        dy: &Array2<f32>,
        states: &[Array2<f32>],
        a_bar: &Array2<f32>,
        b_bar: &Array2<f32>,
        c: &Array1<f32>,
        x: &Array2<f32>,
    ) -> ModelResult<SsmGradients> {
        let seq = self.seq_len;
        let n_state = self.state_dim;

        if dy.nrows() != seq {
            return Err(ModelError::dimension_mismatch(
                "SsmBackward dy rows",
                seq,
                dy.nrows(),
            ));
        }
        // `states` holds the initial state plus one state per step.
        if states.len() != seq + 1 {
            return Err(ModelError::dimension_mismatch(
                "SsmBackward states length",
                seq + 1,
                states.len(),
            ));
        }
        if a_bar.nrows() != seq || a_bar.ncols() != n_state {
            return Err(ModelError::dimension_mismatch(
                "SsmBackward a_bar shape",
                seq * n_state,
                a_bar.nrows() * a_bar.ncols(),
            ));
        }
        if b_bar.nrows() != seq || b_bar.ncols() != n_state {
            return Err(ModelError::dimension_mismatch(
                "SsmBackward b_bar shape",
                seq * n_state,
                b_bar.nrows() * b_bar.ncols(),
            ));
        }
        if c.len() != n_state {
            return Err(ModelError::dimension_mismatch(
                "SsmBackward c length",
                n_state,
                c.len(),
            ));
        }

        check_finite_2d(dy, "SsmBackward::backward dy")?;
        check_finite_2d(a_bar, "SsmBackward::backward a_bar")?;
        check_finite_2d(b_bar, "SsmBackward::backward b_bar")?;
        check_finite_1d(c, "SsmBackward::backward c")?;
        check_finite_2d(x, "SsmBackward::backward x")?;

        let input_dim = x.ncols();
        let output_dim = dy.ncols();

        let mut dx = Array2::<f32>::zeros((seq, input_dim));
        let mut da = Array2::<f32>::zeros((seq, n_state));
        let mut db = Array2::<f32>::zeros((seq, n_state));
        let mut dc = Array1::<f32>::zeros(n_state);
        let mut delta_grad = Array2::<f32>::zeros((seq, n_state));

        // Adjoint of the hidden state flowing back from step t + 1.
        let mut dh_next = Array1::<f32>::zeros(n_state);

        for t in (0..seq).rev() {
            // Multi-channel outputs are reduced to a per-step scalar (mean).
            let dy_t_scalar: f32 = if output_dim == 1 {
                dy[[t, 0]]
            } else {
                dy.row(t).sum() / output_dim as f32
            };

            // dh[t] = c * dy[t] + a_bar[t] ⊙ dh[t+1]
            let mut dh_t = Array1::<f32>::zeros(n_state);
            for sn in 0..n_state {
                dh_t[sn] = c[sn] * dy_t_scalar + a_bar[[t, sn]] * dh_next[sn];
            }

            // States are stored as (1, state_dim) rows; states[t] is h[t-1].
            // dĀ[t] = dh[t] ⊙ h[t-1].
            let h_prev_row = states[t].row(0);
            for sn in 0..n_state {
                da[[t, sn]] = dh_t[sn] * h_prev_row[sn];
            }

            // Multi-channel inputs are likewise reduced to a per-step scalar.
            let x_t_scalar: f32 = if input_dim == 1 {
                x[[t, 0]]
            } else {
                x.row(t).sum() / input_dim as f32
            };
            // dB̄[t] = dh[t] · x[t]
            for sn in 0..n_state {
                db[[t, sn]] = dh_t[sn] * x_t_scalar;
            }

            // y[t] = ⟨c, h[t]⟩, so dc += h[t] · dy[t].
            let h_t_row = states[t + 1].row(0);
            for sn in 0..n_state {
                dc[sn] += h_t_row[sn] * dy_t_scalar;
            }

            // Gradient routed to the discretization step.
            for sn in 0..n_state {
                delta_grad[[t, sn]] = dh_t[sn] * h_t_row[sn] * a_bar[[t, sn]];
            }

            // Input gradient, broadcast uniformly across input channels.
            let b_bar_sum: f32 = b_bar.row(t).sum() / n_state as f32;
            for d in 0..input_dim {
                dx[[t, d]] = b_bar_sum * dh_t.sum() / n_state as f32;
            }

            dh_next = dh_t;
        }

        Ok(SsmGradients {
            dx,
            da,
            db,
            dc,
            delta_grad,
        })
    }
}
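/// Accumulates named parameter gradients across micro-batches. A minimal
/// sketch of the intended loop (parameter names are illustrative):
///
/// ```ignore
/// let mut acc = GradAccumulator::new();
/// for batch_grad in &batch_grads {
///     acc.accumulate("w", batch_grad)?;
/// }
/// acc.normalize();                // mean over accumulation steps
/// let norm = acc.apply_clip(1.0); // global-norm clipping; returns pre-clip norm
/// ```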
#[derive(Debug, Default)]
pub struct GradAccumulator {
    /// Summed gradients keyed by parameter name.
    grads: HashMap<String, Array1<f32>>,
    /// Number of accumulation steps per parameter, used by `normalize`.
    counts: HashMap<String, usize>,
}

impl GradAccumulator {
    pub fn new() -> Self {
        Self {
            grads: HashMap::new(),
            counts: HashMap::new(),
        }
    }

    /// Adds `grad` into the named slot, rejecting non-finite values and
    /// length mismatches against a previously registered gradient.
    pub fn accumulate(&mut self, name: &str, grad: &Array1<f32>) -> ModelResult<()> {
        check_finite_1d(grad, &format!("GradAccumulator::accumulate({name})"))?;

        let existing = self
            .grads
            .entry(name.to_string())
            .or_insert_with(|| Array1::zeros(grad.len()));

        if existing.len() != grad.len() {
            return Err(ModelError::dimension_mismatch(
                format!("GradAccumulator::accumulate({name})"),
                existing.len(),
                grad.len(),
            ));
        }

        *existing += grad;
        *self.counts.entry(name.to_string()).or_insert(0) += 1;

        Ok(())
    }

    pub fn get(&self, name: &str) -> Option<&Array1<f32>> {
        self.grads.get(name)
    }

    /// Divides each gradient by its accumulation count (mean over steps).
    pub fn normalize(&mut self) {
        for (name, grad) in self.grads.iter_mut() {
            let count = self.counts.get(name).copied().unwrap_or(1).max(1);
            *grad = grad.mapv(|v| v / count as f32);
        }
    }

    /// Resets all gradients to zero and all counts to zero, keeping the slots.
    pub fn zero_grad(&mut self) {
        for grad in self.grads.values_mut() {
            grad.fill(0.0);
        }
        for count in self.counts.values_mut() {
            *count = 0;
        }
    }

    /// Clips by global L2 norm across all parameters and returns the
    /// pre-clipping norm.
    pub fn apply_clip(&mut self, max_norm: f32) -> f32 {
        let total_sq: f32 = self
            .grads
            .values()
            .flat_map(|g| g.iter())
            .map(|&v| v * v)
            .sum();
        let norm = total_sq.sqrt();
        if norm > max_norm && norm > 0.0 {
            let scale = max_norm / norm;
            for grad in self.grads.values_mut() {
                *grad = grad.mapv(|v| v * scale);
            }
        }
        norm
    }

    pub fn param_names(&self) -> Vec<&str> {
        self.grads.keys().map(|s| s.as_str()).collect()
    }
}
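/// Backward pass for a dense layer `y = xᵀW + b`, with `w` laid out as
/// `(input_dim, output_dim)`:
///
/// ```text
/// dx[i]    = Σ_j w[i][j] · dy[j]   // dx = W·dy
/// dW[i][j] = x[i] · dy[j]          // outer product x ⊗ dy
/// db       = dy                    // bias is additive
/// ```
///
/// Returns `(dx, dW, db)`.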
pub fn linear_backward(
    dy: &Array1<f32>,
    x: &Array1<f32>,
    w: &Array2<f32>,
) -> ModelResult<(Array1<f32>, Array2<f32>, Array1<f32>)> {
    let (input_dim, output_dim) = w.dim();

    if dy.len() != output_dim {
        return Err(ModelError::dimension_mismatch(
            "linear_backward dy",
            output_dim,
            dy.len(),
        ));
    }
    if x.len() != input_dim {
        return Err(ModelError::dimension_mismatch(
            "linear_backward x",
            input_dim,
            x.len(),
        ));
    }

    // dx = W · dy
    let mut dx = Array1::<f32>::zeros(input_dim);
    for i in 0..input_dim {
        let mut s = 0.0_f32;
        for j in 0..output_dim {
            s += w[[i, j]] * dy[j];
        }
        dx[i] = s;
    }

    // dW = x ⊗ dy (outer product)
    let mut dw = Array2::<f32>::zeros((input_dim, output_dim));
    for i in 0..input_dim {
        for j in 0..output_dim {
            dw[[i, j]] = x[i] * dy[j];
        }
    }

    // db = dy, since the bias enters the output additively.
    let db = dy.clone();

    Ok((dx, dw, db))
}
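/// Backward pass for SiLU, `silu(x) = x · σ(x)`. By the product rule and
/// `σ'(x) = σ(x)(1 − σ(x))`,
///
/// ```text
/// d/dx silu(x) = σ(x) + x·σ(x)(1 − σ(x)) = σ(x)·(1 + x·(1 − σ(x)))
/// ```
///
/// which is the factor applied elementwise to `dy` below.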
pub fn silu_backward(dy: &Array1<f32>, x: &Array1<f32>) -> Array1<f32> {
    let n = dy.len().min(x.len());
    let mut out = Array1::<f32>::zeros(n);
    for i in 0..n {
        let sig = sigmoid(x[i]);
        let dsilu = sig * (1.0 + x[i] * (1.0 - sig));
        out[i] = dy[i] * dsilu;
    }
    out
}
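/// Backward pass for softmax given its forward output `y`: the Jacobian-vector
/// product reduces to
///
/// ```text
/// dx[i] = y[i] · (dy[i] − Σ_j y[j]·dy[j])
/// ```
///
/// One useful invariant: the entries of `dx` sum to zero, because softmax is
/// unchanged by adding a constant to all logits (the tests below check this).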
pub fn softmax_backward(dy: &Array1<f32>, y: &Array1<f32>) -> Array1<f32> {
    let dot_yd: f32 = y.iter().zip(dy.iter()).map(|(&yi, &dyi)| yi * dyi).sum();
    let n = dy.len().min(y.len());
    let mut out = Array1::<f32>::zeros(n);
    for i in 0..n {
        out[i] = y[i] * (dy[i] - dot_yd);
    }
    out
}
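/// Backward pass for layer normalization `y = scale ⊙ x_hat + bias`, where
/// `x_hat = (x − mean) / sqrt(var + eps)`. The input gradient computed below is
///
/// ```text
/// dx[i] = scale[i] / sqrt(var + eps) · (dy[i] − mean(dy) − x_hat[i] · mean(dy ⊙ x_hat))
/// ```
///
/// with `d_scale = dy ⊙ x_hat` and `d_bias = dy`. Note the means are taken over
/// `dy` directly rather than `dy ⊙ scale`; this matches the exact gradient
/// whenever `scale` is uniform (as in the tests below).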
pub fn layer_norm_backward(
    dy: &Array1<f32>,
    x: &Array1<f32>,
    mean: f32,
    var: f32,
    scale: &Array1<f32>,
) -> ModelResult<(Array1<f32>, Array1<f32>, Array1<f32>)> {
    let n = dy.len();
    if x.len() != n {
        return Err(ModelError::dimension_mismatch(
            "layer_norm_backward x",
            n,
            x.len(),
        ));
    }
    if scale.len() != n {
        return Err(ModelError::dimension_mismatch(
            "layer_norm_backward scale",
            n,
            scale.len(),
        ));
    }

    let eps = 1e-5_f32;
    let std_inv = 1.0 / (var + eps).sqrt();

    // Normalized input from the forward pass.
    let x_hat: Array1<f32> = x.mapv(|v| (v - mean) * std_inv);

    // The bias enters additively, so its gradient is dy itself.
    let d_bias = dy.clone();

    // y = scale ⊙ x_hat + bias, so d_scale = dy ⊙ x_hat.
    let d_scale: Array1<f32> = dy * &x_hat;

    let dy_mean = dy.sum() / n as f32;
    let dy_xhat_mean = (dy * &x_hat).sum() / n as f32;

    let mut dx = Array1::<f32>::zeros(n);
    for i in 0..n {
        dx[i] = scale[i] * std_inv * (dy[i] - dy_mean - x_hat[i] * dy_xhat_mean);
    }

    Ok((dx, d_scale, d_bias))
}

// Re-export the sequence-level SSM backward utilities so callers can reach
// them through this module.
pub use crate::backprop_ssm::{
    associative_scan_backward, ssm_backward, GradientCheckpointedSSM, SsmForwardCache,
    SsmGradientsVec,
};
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Central-difference numerical gradient of `f` at `x`, used to validate
    /// the analytic backward passes.
    fn numerical_grad(f: impl Fn(&Array1<f32>) -> f32, x: &Array1<f32>, eps: f32) -> Array1<f32> {
        let mut grad = Array1::zeros(x.len());
        for i in 0..x.len() {
            let mut xp = x.clone();
            xp[i] += eps;
            let mut xm = x.clone();
            xm[i] -= eps;
            grad[i] = (f(&xp) - f(&xm)) / (2.0 * eps);
        }
        grad
    }

    #[test]
    fn test_gradient_tape_add_backward() {
        let mut tape = GradientTape::new();
        let a_idx = tape.alloc();
        let b_idx = tape.alloc();
        let _out_idx = tape.record_add(a_idx, b_idx);

        let loss_grad = Array1::from_vec(vec![1.0_f32, 1.0, 1.0]);
        let mut tensors: Vec<Array1<f32>> =
            vec![Array1::zeros(3), Array1::zeros(3), Array1::zeros(3)];

        tape.backward(loss_grad, &mut tensors)
            .expect("backward failed");

        // d(a + b)/da = d(a + b)/db = 1 for every element.
        for (i, (&ag, &bg)) in tensors[a_idx].iter().zip(tensors[b_idx].iter()).enumerate() {
            assert!((ag - 1.0).abs() < 1e-5, "a grad[{i}] = {ag}");
            assert!((bg - 1.0).abs() < 1e-5, "b grad[{i}] = {bg}");
        }
    }

    #[test]
    fn test_gradient_tape_mul_backward() {
        let a_data = Array1::from_vec(vec![2.0_f32, 3.0, 4.0]);
        let b_data = Array1::from_vec(vec![5.0_f32, 6.0, 7.0]);

        let mut tape = GradientTape::new();
        let a_idx = tape.alloc();
        let b_idx = tape.alloc();
        let _out_idx = tape.record_mul(a_idx, &a_data, b_idx, &b_data);

        let loss_grad = Array1::from_vec(vec![1.0_f32, 1.0, 1.0]);
        let mut tensors: Vec<Array1<f32>> =
            vec![Array1::zeros(3), Array1::zeros(3), Array1::zeros(3)];

        tape.backward(loss_grad, &mut tensors)
            .expect("backward failed");

        for (i, (&ag, &bg)) in tensors[a_idx].iter().zip(tensors[b_idx].iter()).enumerate() {
            assert!((ag - b_data[i]).abs() < 1e-5, "a grad[{i}] = {ag}");
            assert!((bg - a_data[i]).abs() < 1e-5, "b grad[{i}] = {bg}");
        }
    }
    #[test]
    fn test_gradient_tape_matmul_backward() {
        let a_mat = Array2::from_shape_vec((2, 3), vec![1.0_f32, 0.0, 0.0, 0.0, 1.0, 0.0])
            .expect("shape ok");
        let b_mat = Array2::from_shape_vec((3, 2), vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("shape ok");

        let mut tape = GradientTape::new();
        let a_idx = tape.alloc();
        let b_idx = tape.alloc();
        let _out_idx = tape.record_matmul(a_idx, &a_mat, b_idx, &b_mat);

        let loss_grad = Array1::from_vec(vec![1.0_f32, 0.0, 0.0, 1.0]);
        let mut tensors: Vec<Array1<f32>> =
            vec![Array1::zeros(6), Array1::zeros(6), Array1::zeros(4)];

        tape.backward(loss_grad, &mut tensors)
            .expect("backward failed");

        assert_eq!(tensors[a_idx].len(), 6);
        assert_eq!(tensors[b_idx].len(), 6);
    }

    #[test]
    fn test_silu_backward_numerical() {
        let x = Array1::from_vec(vec![-1.0_f32, 0.0, 1.0, 2.0]);
        let dy = Array1::from_vec(vec![1.0_f32; 4]);

        let analytic = silu_backward(&dy, &x);

        let numeric = numerical_grad(
            |xi| xi.iter().map(|&v| v * sigmoid(v)).sum::<f32>(),
            &x,
            1e-4,
        );

        for i in 0..4 {
            assert!(
                (analytic[i] - numeric[i]).abs() < 2e-3,
                "SiLU grad[{i}]: analytic={} numeric={}",
                analytic[i],
                numeric[i]
            );
        }
    }
    #[test]
    fn test_layer_norm_backward_numerical() {
        let x = Array1::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0]);
        let scale = Array1::from_vec(vec![1.0_f32; 4]);
        let dy = Array1::from_vec(vec![1.0_f32; 4]);
        let eps = 1e-5_f32;

        let mean = x.sum() / x.len() as f32;
        let var = x.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / x.len() as f32;

        let (dx_analytic, _, _) =
            layer_norm_backward(&dy, &x, mean, var, &scale).expect("backward ok");

        let numeric = numerical_grad(
            |xi| {
                let m = xi.sum() / xi.len() as f32;
                let variance = xi.iter().map(|&u| (u - m).powi(2)).sum::<f32>() / xi.len() as f32;
                let x_hat: f32 = xi
                    .iter()
                    .map(|&u| (u - m) / (variance + eps).sqrt())
                    .sum::<f32>();
                x_hat
            },
            &x,
            1e-4,
        );

        assert_eq!(dx_analytic.len(), 4);
        for &v in dx_analytic.iter() {
            assert!(v.is_finite(), "dx contains non-finite value");
        }
        // The loss here (sum of x_hat) is identically zero, so both gradients
        // are ~0 and the numeric estimate is dominated by f32 rounding noise;
        // it is computed for reference but not compared elementwise.
        let _ = numeric;
    }

    #[test]
    fn test_linear_backward_shapes() {
        let input_dim = 5;
        let output_dim = 3;
        let x = Array1::<f32>::zeros(input_dim);
        let w = Array2::<f32>::zeros((input_dim, output_dim));
        let dy = Array1::<f32>::zeros(output_dim);

        let (dx, dw, db) = linear_backward(&dy, &x, &w).expect("linear_backward ok");

        assert_eq!(dx.len(), input_dim, "dx shape");
        assert_eq!(dw.dim(), (input_dim, output_dim), "dW shape");
        assert_eq!(db.len(), output_dim, "db shape");
    }
    #[test]
    fn test_linear_backward_numerical() {
        let input_dim = 3;
        let output_dim = 2;

        let x = Array1::from_vec(vec![1.0_f32, 2.0, 3.0]);
        let w = Array2::from_shape_vec(
            (input_dim, output_dim),
            vec![0.1_f32, 0.2, 0.3, 0.4, 0.5, 0.6],
        )
        .expect("shape ok");
        let dy = Array1::from_vec(vec![1.0_f32, 1.0]);

        let (dx_analytic, _, _) = linear_backward(&dy, &x, &w).expect("backward ok");

        let numeric_dx = numerical_grad(
            |xi| {
                let mut s = 0.0_f32;
                for i in 0..input_dim {
                    for j in 0..output_dim {
                        s += xi[i] * w[[i, j]] * dy[j];
                    }
                }
                s
            },
            &x,
            1e-4,
        );

        for (i, (&da, &dn)) in dx_analytic.iter().zip(numeric_dx.iter()).enumerate() {
            assert!(
                (da - dn).abs() < 5e-3,
                "dx[{i}]: analytic={da} numeric={dn}"
            );
        }
    }

    #[test]
    fn test_softmax_backward_sums_to_zero() {
        let logits = Array1::from_vec(vec![1.0_f32, 2.0, 3.0]);
        let max_v = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exp: Array1<f32> = logits.mapv(|v| (v - max_v).exp());
        let sum_exp = exp.sum();
        let y: Array1<f32> = exp.mapv(|v| v / sum_exp);

        for j in 0..3 {
            let mut dy = Array1::zeros(3);
            dy[j] = 1.0;
            let dx = softmax_backward(&dy, &y);
            let sum: f32 = dx.sum();
            assert!(
                sum.abs() < 1e-5,
                "softmax_backward col {j} sum = {sum}, expected 0"
            );
        }
    }
    #[test]
    fn test_ssm_backward_gradient_shapes() {
        let state_dim = 4;
        let seq_len = 5;
        let input_dim = 2;
        let output_dim = 1;

        let dy = Array2::<f32>::zeros((seq_len, output_dim));
        let states: Vec<Array2<f32>> = (0..=seq_len)
            .map(|_| Array2::<f32>::zeros((1, state_dim)))
            .collect();
        let a_bar = Array2::<f32>::from_elem((seq_len, state_dim), 0.9);
        let b_bar = Array2::<f32>::from_elem((seq_len, state_dim), 0.1);
        let c = Array1::<f32>::from_elem(state_dim, 1.0);
        let x = Array2::<f32>::zeros((seq_len, input_dim));

        let ssm_bwd = SsmBackward::new(state_dim, seq_len);
        let grads = ssm_bwd
            .backward(&dy, &states, &a_bar, &b_bar, &c, &x)
            .expect("SSM backward ok");

        assert_eq!(grads.dx.dim(), (seq_len, input_dim), "dx shape");
        assert_eq!(grads.da.dim(), (seq_len, state_dim), "da shape");
        assert_eq!(grads.db.dim(), (seq_len, state_dim), "db shape");
        assert_eq!(grads.dc.len(), state_dim, "dc shape");
        assert_eq!(
            grads.delta_grad.dim(),
            (seq_len, state_dim),
            "delta_grad shape"
        );
    }

    #[test]
    fn test_ssm_backward_vanishing() {
        let state_dim = 4;
        let seq_len = 10;
        let input_dim = 1;
        let output_dim = 1;

        let dy = Array2::from_elem((seq_len, output_dim), 1.0_f32);

        let states: Vec<Array2<f32>> = (0..=seq_len)
            .map(|i| Array2::from_elem((1, state_dim), 0.1 * (i + 1) as f32))
            .collect();

        let a_bar = Array2::from_elem((seq_len, state_dim), 0.9_f32);
        let b_bar = Array2::from_elem((seq_len, state_dim), 0.5_f32);
        let c = Array1::from_elem(state_dim, 1.0_f32);
        let x = Array2::from_elem((seq_len, input_dim), 1.0_f32);

        let ssm_bwd = SsmBackward::new(state_dim, seq_len);
        let grads = ssm_bwd
            .backward(&dy, &states, &a_bar, &b_bar, &c, &x)
            .expect("SSM backward ok");

        let da_norm: f32 = grads.da.iter().map(|&v| v * v).sum::<f32>().sqrt();
        assert!(da_norm > 1e-6, "da gradient vanished: norm = {da_norm}");
    }
    #[test]
    fn test_grad_accumulator_zero_grad() {
        let mut acc = GradAccumulator::new();
        let g = Array1::from_vec(vec![1.0_f32, 2.0, 3.0]);
        acc.accumulate("w", &g).expect("accumulate ok");
        acc.accumulate("b", &g).expect("accumulate ok");

        acc.zero_grad();

        let w_grad = acc.get("w").expect("w exists after zero_grad");
        for &v in w_grad.iter() {
            assert_eq!(v, 0.0, "grad should be zeroed");
        }
    }

    #[test]
    fn test_grad_accumulator_clip() {
        let mut acc = GradAccumulator::new();
        // ||(3, 4)|| = 5, which exceeds the max norm of 2.5.
        let g = Array1::from_vec(vec![3.0_f32, 4.0]);
        acc.accumulate("w", &g).expect("accumulate ok");

        let norm_before = acc.apply_clip(2.5);
        assert!(
            (norm_before - 5.0).abs() < 1e-4,
            "norm before = {norm_before}"
        );

        let w_grad = acc.get("w").expect("w exists");
        let norm_after: f32 = w_grad.iter().map(|&v| v * v).sum::<f32>().sqrt();
        assert!(
            (norm_after - 2.5).abs() < 1e-4,
            "norm after clipping should be 2.5, got {norm_after}"
        );
    }

    #[test]
    fn test_grad_accumulator_normalize() {
        let mut acc = GradAccumulator::new();
        let g = Array1::from_vec(vec![2.0_f32, 4.0, 6.0]);

        acc.accumulate("w", &g).expect("ok");
        acc.accumulate("w", &g).expect("ok");
        acc.accumulate("w", &g).expect("ok");

        acc.normalize();

        let w_grad = acc.get("w").expect("w exists");
        for (i, &v) in w_grad.iter().enumerate() {
            assert!(
                (v - g[i]).abs() < 1e-5,
                "normalized grad[{i}] = {v}, expected {}",
                g[i]
            );
        }
    }
}