1use cjc_runtime::Tensor;
8use std::cell::RefCell;
9use std::rc::Rc;
10
11pub mod pinn;
12
13#[derive(Debug, Clone)]
31pub struct Dual {
32 pub value: f64,
34 pub deriv: f64,
36}
37
38impl Dual {
39 pub fn new(value: f64, deriv: f64) -> Self {
46 Self { value, deriv }
47 }
48
49 pub fn constant(value: f64) -> Self {
55 Self { value, deriv: 0.0 }
56 }
57
58 pub fn variable(value: f64) -> Self {
66 Self { value, deriv: 1.0 }
67 }
68
69 pub fn zero() -> Self {
71 Self {
72 value: 0.0,
73 deriv: 0.0,
74 }
75 }
76
77 pub fn one() -> Self {
79 Self {
80 value: 1.0,
81 deriv: 0.0,
82 }
83 }
84}
85
86impl std::ops::Add for Dual {
87 type Output = Dual;
88 fn add(self, rhs: Dual) -> Dual {
89 Dual {
90 value: self.value + rhs.value,
91 deriv: self.deriv + rhs.deriv,
92 }
93 }
94}
95
96impl std::ops::Sub for Dual {
97 type Output = Dual;
98 fn sub(self, rhs: Dual) -> Dual {
99 Dual {
100 value: self.value - rhs.value,
101 deriv: self.deriv - rhs.deriv,
102 }
103 }
104}
105
106impl std::ops::Mul for Dual {
107 type Output = Dual;
108 fn mul(self, rhs: Dual) -> Dual {
109 Dual {
110 value: self.value * rhs.value,
111 deriv: self.value * rhs.deriv + self.deriv * rhs.value,
112 }
113 }
114}
115
116impl std::ops::Div for Dual {
117 type Output = Dual;
118 fn div(self, rhs: Dual) -> Dual {
119 let denom = rhs.value * rhs.value;
120 Dual {
121 value: self.value / rhs.value,
122 deriv: (self.deriv * rhs.value - self.value * rhs.deriv) / denom,
123 }
124 }
125}
126
127impl std::ops::Neg for Dual {
128 type Output = Dual;
129 fn neg(self) -> Dual {
130 Dual {
131 value: -self.value,
132 deriv: -self.deriv,
133 }
134 }
135}
136
137impl Dual {
138 pub fn sin(self) -> Dual {
140 Dual {
141 value: self.value.sin(),
142 deriv: self.deriv * self.value.cos(),
143 }
144 }
145
146 pub fn cos(self) -> Dual {
148 Dual {
149 value: self.value.cos(),
150 deriv: -self.deriv * self.value.sin(),
151 }
152 }
153
154 pub fn exp(self) -> Dual {
156 let e = self.value.exp();
157 Dual {
158 value: e,
159 deriv: self.deriv * e,
160 }
161 }
162
163 pub fn ln(self) -> Dual {
165 Dual {
166 value: self.value.ln(),
167 deriv: self.deriv / self.value,
168 }
169 }
170
171 pub fn sqrt(self) -> Dual {
173 let s = self.value.sqrt();
174 Dual {
175 value: s,
176 deriv: self.deriv / (2.0 * s),
177 }
178 }
179
180 pub fn pow(self, n: f64) -> Dual {
186 Dual {
187 value: self.value.powf(n),
188 deriv: self.deriv * n * self.value.powf(n - 1.0),
189 }
190 }
191}
192
193#[derive(Debug, Clone)]
200pub enum GradOp {
201 Input,
203 Parameter,
205 Add(usize, usize),
207 Sub(usize, usize),
209 Mul(usize, usize),
211 Div(usize, usize),
213 Neg(usize),
215 MatMul(usize, usize),
217 Sum(usize),
219 Mean(usize),
221 ScalarMul(usize, f64),
223 Exp(usize),
225 Ln(usize),
227 StructField {
229 parent: usize,
230 field_index: usize,
231 total_fields: usize,
232 },
233 MapLookup {
235 map_node: usize,
236 key_index: usize,
237 total_keys: usize,
238 },
239 Sin(usize),
241 Cos(usize),
243 Sqrt(usize),
245 Pow(usize, f64),
247 Sigmoid(usize),
249 Relu(usize),
251 TanhAct(usize),
253 Abs(usize),
255 Log2(usize),
257 Softmax(usize),
259 CrossEntropy {
261 logits: usize,
263 targets: usize,
265 },
266 LayerNorm(usize),
268 BatchNorm(usize),
270 Clamp {
272 input: usize,
274 min: f64,
276 max: f64,
278 },
279 Where {
281 cond: usize,
283 on_true: usize,
285 on_false: usize,
287 },
288 Reshape {
290 input: usize,
292 original_shape: Vec<usize>,
294 },
295 TransposeOp(usize),
297 CatOp {
299 inputs: Vec<usize>,
301 axis: usize,
303 sizes: Vec<usize>,
305 },
306 GatherOp {
308 input: usize,
310 indices: Vec<usize>,
312 axis: usize,
314 },
315}
316
317#[derive(Debug, Clone)]
319pub struct GradNode {
320 pub op: GradOp,
321 pub tensor: Tensor,
322 pub grad: Option<Tensor>,
323}
324
325pub struct GradGraph {
327 pub nodes: Vec<Rc<RefCell<GradNode>>>,
328}
329
330impl GradGraph {
331 pub fn new() -> Self {
332 Self { nodes: Vec::new() }
333 }
334
335 pub fn input(&mut self, tensor: Tensor) -> usize {
337 let idx = self.nodes.len();
338 self.nodes.push(Rc::new(RefCell::new(GradNode {
339 op: GradOp::Input,
340 tensor,
341 grad: None,
342 })));
343 idx
344 }
345
346 pub fn parameter(&mut self, tensor: Tensor) -> usize {
348 let idx = self.nodes.len();
349 let shape = tensor.shape().to_vec();
350 self.nodes.push(Rc::new(RefCell::new(GradNode {
351 op: GradOp::Parameter,
352 tensor,
353 grad: Some(Tensor::zeros(&shape)),
354 })));
355 idx
356 }
357
358 pub fn add(&mut self, a: usize, b: usize) -> usize {
360 let a_t = self.nodes[a].borrow().tensor.clone();
361 let b_t = self.nodes[b].borrow().tensor.clone();
362 let result = a_t.add_unchecked(&b_t);
363 let idx = self.nodes.len();
364 self.nodes.push(Rc::new(RefCell::new(GradNode {
365 op: GradOp::Add(a, b),
366 tensor: result,
367 grad: None,
368 })));
369 idx
370 }
371
372 pub fn sub(&mut self, a: usize, b: usize) -> usize {
374 let a_t = self.nodes[a].borrow().tensor.clone();
375 let b_t = self.nodes[b].borrow().tensor.clone();
376 let result = a_t.sub_unchecked(&b_t);
377 let idx = self.nodes.len();
378 self.nodes.push(Rc::new(RefCell::new(GradNode {
379 op: GradOp::Sub(a, b),
380 tensor: result,
381 grad: None,
382 })));
383 idx
384 }
385
386 pub fn mul(&mut self, a: usize, b: usize) -> usize {
388 let a_t = self.nodes[a].borrow().tensor.clone();
389 let b_t = self.nodes[b].borrow().tensor.clone();
390 let result = a_t.mul_elem_unchecked(&b_t);
391 let idx = self.nodes.len();
392 self.nodes.push(Rc::new(RefCell::new(GradNode {
393 op: GradOp::Mul(a, b),
394 tensor: result,
395 grad: None,
396 })));
397 idx
398 }
399
400 pub fn matmul(&mut self, a: usize, b: usize) -> usize {
402 let a_t = self.nodes[a].borrow().tensor.clone();
403 let b_t = self.nodes[b].borrow().tensor.clone();
404 let result = a_t.matmul_unchecked(&b_t);
405 let idx = self.nodes.len();
406 self.nodes.push(Rc::new(RefCell::new(GradNode {
407 op: GradOp::MatMul(a, b),
408 tensor: result,
409 grad: None,
410 })));
411 idx
412 }
413
414 pub fn sum(&mut self, a: usize) -> usize {
416 let a_t = self.nodes[a].borrow().tensor.clone();
417 let s = a_t.sum();
418 let result = Tensor::from_vec_unchecked(vec![s], &[1]);
419 let idx = self.nodes.len();
420 self.nodes.push(Rc::new(RefCell::new(GradNode {
421 op: GradOp::Sum(a),
422 tensor: result,
423 grad: None,
424 })));
425 idx
426 }
427
428 pub fn mean(&mut self, a: usize) -> usize {
430 let a_t = self.nodes[a].borrow().tensor.clone();
431 let m = a_t.mean();
432 let result = Tensor::from_vec_unchecked(vec![m], &[1]);
433 let idx = self.nodes.len();
434 self.nodes.push(Rc::new(RefCell::new(GradNode {
435 op: GradOp::Mean(a),
436 tensor: result,
437 grad: None,
438 })));
439 idx
440 }
441
442 pub fn sin(&mut self, a: usize) -> usize {
446 let a_t = self.nodes[a].borrow().tensor.clone();
447 let data = a_t.to_vec();
448 let result = Tensor::from_vec_unchecked(
449 data.iter().map(|&x| x.sin()).collect(),
450 a_t.shape(),
451 );
452 let idx = self.nodes.len();
453 self.nodes.push(Rc::new(RefCell::new(GradNode {
454 op: GradOp::Sin(a),
455 tensor: result,
456 grad: None,
457 })));
458 idx
459 }
460
461 pub fn cos(&mut self, a: usize) -> usize {
463 let a_t = self.nodes[a].borrow().tensor.clone();
464 let data = a_t.to_vec();
465 let result = Tensor::from_vec_unchecked(
466 data.iter().map(|&x| x.cos()).collect(),
467 a_t.shape(),
468 );
469 let idx = self.nodes.len();
470 self.nodes.push(Rc::new(RefCell::new(GradNode {
471 op: GradOp::Cos(a),
472 tensor: result,
473 grad: None,
474 })));
475 idx
476 }
477
478 pub fn sqrt(&mut self, a: usize) -> usize {
480 let a_t = self.nodes[a].borrow().tensor.clone();
481 let data = a_t.to_vec();
482 let result = Tensor::from_vec_unchecked(
483 data.iter().map(|&x| x.sqrt()).collect(),
484 a_t.shape(),
485 );
486 let idx = self.nodes.len();
487 self.nodes.push(Rc::new(RefCell::new(GradNode {
488 op: GradOp::Sqrt(a),
489 tensor: result,
490 grad: None,
491 })));
492 idx
493 }
494
495 pub fn pow(&mut self, a: usize, n: f64) -> usize {
497 let a_t = self.nodes[a].borrow().tensor.clone();
498 let data = a_t.to_vec();
499 let result = Tensor::from_vec_unchecked(
500 data.iter().map(|&x| x.powf(n)).collect(),
501 a_t.shape(),
502 );
503 let idx = self.nodes.len();
504 self.nodes.push(Rc::new(RefCell::new(GradNode {
505 op: GradOp::Pow(a, n),
506 tensor: result,
507 grad: None,
508 })));
509 idx
510 }
511
512 pub fn sigmoid(&mut self, a: usize) -> usize {
514 let a_t = self.nodes[a].borrow().tensor.clone();
515 let data = a_t.to_vec();
516 let result = Tensor::from_vec_unchecked(
517 data.iter().map(|&x| 1.0 / (1.0 + (-x).exp())).collect(),
518 a_t.shape(),
519 );
520 let idx = self.nodes.len();
521 self.nodes.push(Rc::new(RefCell::new(GradNode {
522 op: GradOp::Sigmoid(a),
523 tensor: result,
524 grad: None,
525 })));
526 idx
527 }
528
529 pub fn relu(&mut self, a: usize) -> usize {
531 let a_t = self.nodes[a].borrow().tensor.clone();
532 let data = a_t.to_vec();
533 let result = Tensor::from_vec_unchecked(
534 data.iter().map(|&x| if x > 0.0 { x } else { 0.0 }).collect(),
535 a_t.shape(),
536 );
537 let idx = self.nodes.len();
538 self.nodes.push(Rc::new(RefCell::new(GradNode {
539 op: GradOp::Relu(a),
540 tensor: result,
541 grad: None,
542 })));
543 idx
544 }
545
546 pub fn tanh_act(&mut self, a: usize) -> usize {
548 let a_t = self.nodes[a].borrow().tensor.clone();
549 let data = a_t.to_vec();
550 let result = Tensor::from_vec_unchecked(
551 data.iter().map(|&x| x.tanh()).collect(),
552 a_t.shape(),
553 );
554 let idx = self.nodes.len();
555 self.nodes.push(Rc::new(RefCell::new(GradNode {
556 op: GradOp::TanhAct(a),
557 tensor: result,
558 grad: None,
559 })));
560 idx
561 }
562
563 pub fn abs(&mut self, a: usize) -> usize {
567 let a_t = self.nodes[a].borrow().tensor.clone();
568 let data = a_t.to_vec();
569 let result = Tensor::from_vec_unchecked(
570 data.iter().map(|&x| x.abs()).collect(),
571 a_t.shape(),
572 );
573 let idx = self.nodes.len();
574 self.nodes.push(Rc::new(RefCell::new(GradNode {
575 op: GradOp::Abs(a),
576 tensor: result,
577 grad: None,
578 })));
579 idx
580 }
581
582 pub fn log2(&mut self, a: usize) -> usize {
584 let a_t = self.nodes[a].borrow().tensor.clone();
585 let data = a_t.to_vec();
586 let result = Tensor::from_vec_unchecked(
587 data.iter().map(|&x| x.log2()).collect(),
588 a_t.shape(),
589 );
590 let idx = self.nodes.len();
591 self.nodes.push(Rc::new(RefCell::new(GradNode {
592 op: GradOp::Log2(a),
593 tensor: result,
594 grad: None,
595 })));
596 idx
597 }
598
599 pub fn softmax(&mut self, a: usize) -> usize {
602 use cjc_repro::KahanAccumulatorF64;
603 let a_t = self.nodes[a].borrow().tensor.clone();
604 let data = a_t.to_vec();
605 let max_val = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
606 let exp_shifted: Vec<f64> = data.iter().map(|&x| (x - max_val).exp()).collect();
607 let mut sum_acc = KahanAccumulatorF64::new();
608 for &v in &exp_shifted {
609 sum_acc.add(v);
610 }
611 let sum_exp = sum_acc.finalize();
612 let softmax_data: Vec<f64> = exp_shifted.iter().map(|&e| e / sum_exp).collect();
613 let result = Tensor::from_vec_unchecked(softmax_data, a_t.shape());
614 let idx = self.nodes.len();
615 self.nodes.push(Rc::new(RefCell::new(GradNode {
616 op: GradOp::Softmax(a),
617 tensor: result,
618 grad: None,
619 })));
620 idx
621 }
622
623 pub fn cross_entropy(&mut self, logits: usize, targets: usize) -> usize {
627 use cjc_repro::KahanAccumulatorF64;
628 let logits_t = self.nodes[logits].borrow().tensor.clone();
629 let targets_t = self.nodes[targets].borrow().tensor.clone();
630 let logits_data = logits_t.to_vec();
631 let targets_data = targets_t.to_vec();
632 let max_val = logits_data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
634 let shifted: Vec<f64> = logits_data.iter().map(|&x| x - max_val).collect();
635 let exp_shifted: Vec<f64> = shifted.iter().map(|&x| x.exp()).collect();
636 let mut sum_acc = KahanAccumulatorF64::new();
637 for &v in &exp_shifted {
638 sum_acc.add(v);
639 }
640 let log_sum_exp = sum_acc.finalize().ln();
641 let log_softmax: Vec<f64> = shifted.iter().map(|&x| x - log_sum_exp).collect();
642 let mut ce_acc = KahanAccumulatorF64::new();
644 for (t, ls) in targets_data.iter().zip(log_softmax.iter()) {
645 ce_acc.add(-t * ls);
646 }
647 let ce = ce_acc.finalize();
648 let result = Tensor::from_vec_unchecked(vec![ce], &[1]);
649 let idx = self.nodes.len();
650 self.nodes.push(Rc::new(RefCell::new(GradNode {
651 op: GradOp::CrossEntropy { logits, targets },
652 tensor: result,
653 grad: None,
654 })));
655 idx
656 }
657
658 pub fn layer_norm(&mut self, a: usize) -> usize {
661 use cjc_repro::KahanAccumulatorF64;
662 let a_t = self.nodes[a].borrow().tensor.clone();
663 let data = a_t.to_vec();
664 let n = data.len() as f64;
665 let mut mean_acc = KahanAccumulatorF64::new();
667 for &v in &data {
668 mean_acc.add(v);
669 }
670 let mean = mean_acc.finalize() / n;
671 let mut var_acc = KahanAccumulatorF64::new();
673 for &v in &data {
674 let d = v - mean;
675 var_acc.add(d * d);
676 }
677 let var = var_acc.finalize() / n;
678 let eps = 1e-5;
679 let std = (var + eps).sqrt();
680 let normed: Vec<f64> = data.iter().map(|&x| (x - mean) / std).collect();
681 let result = Tensor::from_vec_unchecked(normed, a_t.shape());
682 let idx = self.nodes.len();
683 self.nodes.push(Rc::new(RefCell::new(GradNode {
684 op: GradOp::LayerNorm(a),
685 tensor: result,
686 grad: None,
687 })));
688 idx
689 }
690
691 pub fn batch_norm(&mut self, a: usize) -> usize {
696 use cjc_repro::KahanAccumulatorF64;
697 let a_t = self.nodes[a].borrow().tensor.clone();
698 let data = a_t.to_vec();
699 let n = data.len() as f64;
700 let mut mean_acc = KahanAccumulatorF64::new();
701 for &v in &data {
702 mean_acc.add(v);
703 }
704 let mean = mean_acc.finalize() / n;
705 let mut var_acc = KahanAccumulatorF64::new();
706 for &v in &data {
707 let d = v - mean;
708 var_acc.add(d * d);
709 }
710 let var = var_acc.finalize() / n;
711 let eps = 1e-5;
712 let std = (var + eps).sqrt();
713 let normed: Vec<f64> = data.iter().map(|&x| (x - mean) / std).collect();
714 let result = Tensor::from_vec_unchecked(normed, a_t.shape());
715 let idx = self.nodes.len();
716 self.nodes.push(Rc::new(RefCell::new(GradNode {
717 op: GradOp::BatchNorm(a),
718 tensor: result,
719 grad: None,
720 })));
721 idx
722 }
723
724 pub fn clamp(&mut self, a: usize, min: f64, max: f64) -> usize {
726 let a_t = self.nodes[a].borrow().tensor.clone();
727 let data = a_t.to_vec();
728 let result = Tensor::from_vec_unchecked(
729 data.iter().map(|&x| x.max(min).min(max)).collect(),
730 a_t.shape(),
731 );
732 let idx = self.nodes.len();
733 self.nodes.push(Rc::new(RefCell::new(GradNode {
734 op: GradOp::Clamp { input: a, min, max },
735 tensor: result,
736 grad: None,
737 })));
738 idx
739 }
740
741 pub fn where_cond(&mut self, cond: usize, on_true: usize, on_false: usize) -> usize {
744 let cond_t = self.nodes[cond].borrow().tensor.clone();
745 let true_t = self.nodes[on_true].borrow().tensor.clone();
746 let false_t = self.nodes[on_false].borrow().tensor.clone();
747 let c = cond_t.to_vec();
748 let t = true_t.to_vec();
749 let f = false_t.to_vec();
750 let result_data: Vec<f64> = c.iter().zip(t.iter().zip(f.iter()))
751 .map(|(&ci, (&ti, &fi))| if ci != 0.0 { ti } else { fi })
752 .collect();
753 let result = Tensor::from_vec_unchecked(result_data, cond_t.shape());
754 let idx = self.nodes.len();
755 self.nodes.push(Rc::new(RefCell::new(GradNode {
756 op: GradOp::Where { cond, on_true, on_false },
757 tensor: result,
758 grad: None,
759 })));
760 idx
761 }
762
763 pub fn reshape(&mut self, a: usize, new_shape: &[usize]) -> usize {
765 let a_t = self.nodes[a].borrow().tensor.clone();
766 let original_shape = a_t.shape().to_vec();
767 let result = a_t.reshape(new_shape).expect("GradGraph::reshape: shape mismatch");
768 let idx = self.nodes.len();
769 self.nodes.push(Rc::new(RefCell::new(GradNode {
770 op: GradOp::Reshape { input: a, original_shape },
771 tensor: result,
772 grad: None,
773 })));
774 idx
775 }
776
777 pub fn transpose_op(&mut self, a: usize) -> usize {
779 let a_t = self.nodes[a].borrow().tensor.clone();
780 let result = a_t.transpose();
781 let idx = self.nodes.len();
782 self.nodes.push(Rc::new(RefCell::new(GradNode {
783 op: GradOp::TransposeOp(a),
784 tensor: result,
785 grad: None,
786 })));
787 idx
788 }
789
790 pub fn cat(&mut self, inputs: &[usize], axis: usize) -> usize {
793 let tensors: Vec<Tensor> = inputs.iter()
794 .map(|&i| self.nodes[i].borrow().tensor.clone())
795 .collect();
796 let sizes: Vec<usize> = tensors.iter()
797 .map(|t| t.shape()[axis])
798 .collect();
799 let mut all_data = Vec::new();
802 let mut total_along_axis = 0usize;
803 let ndim = tensors[0].ndim();
804 let mut result_shape = tensors[0].shape().to_vec();
805
806 if ndim == 1 {
807 for t in &tensors {
809 all_data.extend(t.to_vec());
810 total_along_axis += t.shape()[0];
811 }
812 result_shape[0] = total_along_axis;
813 } else if ndim == 2 && axis == 0 {
814 for t in &tensors {
816 all_data.extend(t.to_vec());
817 total_along_axis += t.shape()[0];
818 }
819 result_shape[0] = total_along_axis;
820 } else if ndim == 2 && axis == 1 {
821 let nrows = tensors[0].shape()[0];
823 for row in 0..nrows {
824 for t in &tensors {
825 let cols = t.shape()[1];
826 let row_data = t.to_vec();
827 let start = row * cols;
828 all_data.extend_from_slice(&row_data[start..start + cols]);
829 }
830 }
831 total_along_axis = sizes.iter().sum();
832 result_shape[1] = total_along_axis;
833 } else {
834 for t in &tensors {
836 all_data.extend(t.to_vec());
837 total_along_axis += t.shape()[axis];
838 }
839 result_shape[axis] = total_along_axis;
840 }
841
842 let result = Tensor::from_vec_unchecked(all_data, &result_shape);
843 let input_vec = inputs.to_vec();
844 let idx = self.nodes.len();
845 self.nodes.push(Rc::new(RefCell::new(GradNode {
846 op: GradOp::CatOp { inputs: input_vec, axis, sizes },
847 tensor: result,
848 grad: None,
849 })));
850 idx
851 }
852
853 pub fn gather(&mut self, a: usize, indices: &[usize], axis: usize) -> usize {
856 let a_t = self.nodes[a].borrow().tensor.clone();
857 let data = a_t.to_vec();
858 let gathered: Vec<f64> = if a_t.ndim() == 1 {
860 indices.iter().map(|&i| data[i]).collect()
861 } else if a_t.ndim() == 2 && axis == 0 {
862 let cols = a_t.shape()[1];
863 indices.iter().flat_map(|&i| {
864 let start = i * cols;
865 data[start..start + cols].to_vec()
866 }).collect()
867 } else {
868 indices.iter().map(|&i| data[i]).collect()
870 };
871 let mut result_shape = a_t.shape().to_vec();
872 if a_t.ndim() == 1 {
873 result_shape[0] = indices.len();
874 } else if axis == 0 {
875 result_shape[0] = indices.len();
876 } else {
877 result_shape[axis] = indices.len();
878 }
879 let result = Tensor::from_vec_unchecked(gathered, &result_shape);
880 let idx = self.nodes.len();
881 self.nodes.push(Rc::new(RefCell::new(GradNode {
882 op: GradOp::GatherOp { input: a, indices: indices.to_vec(), axis },
883 tensor: result,
884 grad: None,
885 })));
886 idx
887 }
888
889 pub fn div(&mut self, a: usize, b: usize) -> usize {
894 let a_tensor = self.nodes[a].borrow().tensor.clone();
895 let b_tensor = self.nodes[b].borrow().tensor.clone();
896 let result = a_tensor.div_elem_unchecked(&b_tensor);
897 let node = GradNode { op: GradOp::Div(a, b), tensor: result, grad: None };
898 self.nodes.push(Rc::new(RefCell::new(node)));
899 self.nodes.len() - 1
900 }
901
902 pub fn neg(&mut self, a: usize) -> usize {
905 let a_tensor = self.nodes[a].borrow().tensor.clone();
906 let result = a_tensor.neg();
907 let node = GradNode { op: GradOp::Neg(a), tensor: result, grad: None };
908 self.nodes.push(Rc::new(RefCell::new(node)));
909 self.nodes.len() - 1
910 }
911
912 pub fn scalar_mul(&mut self, a: usize, s: f64) -> usize {
915 let a_tensor = self.nodes[a].borrow().tensor.clone();
916 let result = a_tensor.scalar_mul(s);
917 let node = GradNode { op: GradOp::ScalarMul(a, s), tensor: result, grad: None };
918 self.nodes.push(Rc::new(RefCell::new(node)));
919 self.nodes.len() - 1
920 }
921
922 pub fn exp(&mut self, a: usize) -> usize {
925 let a_tensor = self.nodes[a].borrow().tensor.clone();
926 let result = Tensor::from_vec_unchecked(
927 a_tensor.to_vec().iter().map(|x| x.exp()).collect(),
928 a_tensor.shape(),
929 );
930 let node = GradNode { op: GradOp::Exp(a), tensor: result, grad: None };
931 self.nodes.push(Rc::new(RefCell::new(node)));
932 self.nodes.len() - 1
933 }
934
935 pub fn ln(&mut self, a: usize) -> usize {
938 let a_tensor = self.nodes[a].borrow().tensor.clone();
939 let result = Tensor::from_vec_unchecked(
940 a_tensor.to_vec().iter().map(|x| x.ln()).collect(),
941 a_tensor.shape(),
942 );
943 let node = GradNode { op: GradOp::Ln(a), tensor: result, grad: None };
944 self.nodes.push(Rc::new(RefCell::new(node)));
945 self.nodes.len() - 1
946 }
947
948 pub fn value(&self, idx: usize) -> f64 {
950 let node = self.nodes[idx].borrow();
951 let data = node.tensor.to_vec();
952 data[0]
953 }
954
955 pub fn tensor(&self, idx: usize) -> Tensor {
957 self.nodes[idx].borrow().tensor.clone()
958 }
959
960 pub fn set_tensor(&self, idx: usize, tensor: Tensor) {
962 self.nodes[idx].borrow_mut().tensor = tensor;
963 }
964
965 pub fn grad(&self, idx: usize) -> Option<Tensor> {
967 self.nodes[idx].borrow().grad.clone()
968 }
969
970 pub fn zero_grad(&self) {
972 for node in &self.nodes {
973 let mut n = node.borrow_mut();
974 if let Some(ref mut grad) = n.grad {
975 let shape = grad.shape().to_vec();
976 *grad = Tensor::zeros(&shape);
977 }
978 }
979 }
980
981 pub fn clip_grad(&self, max_norm: f64) {
984 for node in &self.nodes {
985 let mut n = node.borrow_mut();
986 if let Some(ref mut grad) = n.grad {
987 let data = grad.to_vec();
988 let clipped: Vec<f64> = data.iter()
989 .map(|&x| x.max(-max_norm).min(max_norm))
990 .collect();
991 let shape = grad.shape().to_vec();
992 *grad = Tensor::from_vec_unchecked(clipped, &shape);
993 }
994 }
995 }
996
997 pub fn clip_grad_norm(&self, max_norm: f64) -> f64 {
1001 use cjc_repro::KahanAccumulatorF64;
1002 let mut acc = KahanAccumulatorF64::new();
1004 for node in &self.nodes {
1005 let n = node.borrow();
1006 if let Some(ref grad) = n.grad {
1007 for &v in &grad.to_vec() {
1008 acc.add(v * v);
1009 }
1010 }
1011 }
1012 let global_norm = acc.finalize().sqrt();
1013
1014 if global_norm > max_norm && global_norm > 0.0 {
1015 let scale = max_norm / global_norm;
1016 for node in &self.nodes {
1017 let mut n = node.borrow_mut();
1018 if let Some(ref mut grad) = n.grad {
1019 let data = grad.to_vec();
1020 let scaled: Vec<f64> = data.iter().map(|&x| x * scale).collect();
1021 let shape = grad.shape().to_vec();
1022 *grad = Tensor::from_vec_unchecked(scaled, &shape);
1023 }
1024 }
1025 }
1026
1027 global_norm
1028 }
1029
1030 pub fn backward(&self, loss_idx: usize) {
1032 let n = self.nodes.len();
1033
1034 let mut grads: Vec<Option<Tensor>> = vec![None; n];
1036
1037 let loss_shape = self.nodes[loss_idx].borrow().tensor.shape().to_vec();
1039 grads[loss_idx] = Some(Tensor::ones(&loss_shape));
1040
1041 for i in (0..=loss_idx).rev() {
1043 let grad = match grads[i].take() {
1044 Some(g) => g,
1045 None => continue,
1046 };
1047
1048 let (op, node_tensor) = {
1050 let node = self.nodes[i].borrow();
1051 (node.op.clone(), node.tensor.clone())
1052 };
1053
1054 match op {
1055 GradOp::Input => {}
1056 GradOp::Parameter => {
1057 let mut node_mut = self.nodes[i].borrow_mut();
1058 if let Some(ref mut existing_grad) = node_mut.grad {
1059 *existing_grad = existing_grad.add_unchecked(&grad);
1060 } else {
1061 node_mut.grad = Some(grad);
1062 }
1063 }
1064 GradOp::Add(a, b) => {
1065 accumulate_grad(&mut grads, a, &grad);
1066 accumulate_grad(&mut grads, b, &grad);
1067 }
1068 GradOp::Sub(a, b) => {
1069 accumulate_grad(&mut grads, a, &grad);
1070 let neg_grad = grad.neg();
1071 accumulate_grad(&mut grads, b, &neg_grad);
1072 }
1073 GradOp::Mul(a, b) => {
1074 let a_val = self.nodes[a].borrow().tensor.clone();
1075 let b_val = self.nodes[b].borrow().tensor.clone();
1076
1077 let grad_a = grad.mul_elem_unchecked(&b_val);
1078 let grad_b = grad.mul_elem_unchecked(&a_val);
1079
1080 accumulate_grad(&mut grads, a, &grad_a);
1081 accumulate_grad(&mut grads, b, &grad_b);
1082 }
1083 GradOp::Div(a, b) => {
1084 let a_val = self.nodes[a].borrow().tensor.clone();
1085 let b_val = self.nodes[b].borrow().tensor.clone();
1086
1087 let grad_a = grad.div_elem_unchecked(&b_val);
1089 let b_sq = b_val.mul_elem_unchecked(&b_val);
1091 let neg_a = a_val.neg();
1092 let grad_b = grad.mul_elem_unchecked(&neg_a.div_elem_unchecked(&b_sq));
1093
1094 accumulate_grad(&mut grads, a, &grad_a);
1095 accumulate_grad(&mut grads, b, &grad_b);
1096 }
1097 GradOp::Neg(a) => {
1098 let neg_grad = grad.neg();
1099 accumulate_grad(&mut grads, a, &neg_grad);
1100 }
1101 GradOp::MatMul(a, b) => {
1102 let a_val = self.nodes[a].borrow().tensor.clone();
1105 let b_val = self.nodes[b].borrow().tensor.clone();
1106
1107 let b_t = b_val.transpose();
1108 let a_t = a_val.transpose();
1109
1110 let grad_a = grad.matmul_unchecked(&b_t);
1111 let grad_b = a_t.matmul_unchecked(&grad);
1112
1113 accumulate_grad(&mut grads, a, &grad_a);
1114 accumulate_grad(&mut grads, b, &grad_b);
1115 }
1116 GradOp::Sum(a) => {
1117 let a_shape = self.nodes[a].borrow().tensor.shape().to_vec();
1119 let grad_val = grad.to_vec()[0];
1120 let expanded = Tensor::from_vec_unchecked(
1121 vec![grad_val; a_shape.iter().product()],
1122 &a_shape,
1123 );
1124 accumulate_grad(&mut grads, a, &expanded);
1125 }
1126 GradOp::Mean(a) => {
1127 let a_shape = self.nodes[a].borrow().tensor.shape().to_vec();
1128 let n_elem = a_shape.iter().product::<usize>() as f64;
1129 let grad_val = grad.to_vec()[0] / n_elem;
1130 let expanded = Tensor::from_vec_unchecked(
1131 vec![grad_val; a_shape.iter().product()],
1132 &a_shape,
1133 );
1134 accumulate_grad(&mut grads, a, &expanded);
1135 }
1136 GradOp::ScalarMul(a, s) => {
1137 let scaled = grad.scalar_mul(s);
1138 accumulate_grad(&mut grads, a, &scaled);
1139 }
1140 GradOp::Exp(a) => {
1141 let grad_a = grad.mul_elem_unchecked(&node_tensor);
1142 accumulate_grad(&mut grads, a, &grad_a);
1143 }
1144 GradOp::Ln(a) => {
1145 let a_val = self.nodes[a].borrow().tensor.clone();
1146 let grad_a = grad.div_elem_unchecked(&a_val);
1147 accumulate_grad(&mut grads, a, &grad_a);
1148 }
1149 GradOp::StructField {
1151 parent,
1152 field_index,
1153 total_fields,
1154 } => {
1155 let _ = (field_index, total_fields);
1159 accumulate_grad(&mut grads, parent, &grad);
1160 }
1161 GradOp::MapLookup {
1162 map_node,
1163 key_index,
1164 total_keys,
1165 } => {
1166 let _ = (key_index, total_keys);
1169 accumulate_grad(&mut grads, map_node, &grad);
1170 }
1171 GradOp::Sin(a) => {
1173 let a_val = self.nodes[a].borrow().tensor.clone();
1174 let cos_a = Tensor::from_vec_unchecked(
1175 a_val.to_vec().iter().map(|&x| x.cos()).collect(),
1176 a_val.shape(),
1177 );
1178 let grad_a = grad.mul_elem_unchecked(&cos_a);
1179 accumulate_grad(&mut grads, a, &grad_a);
1180 }
1181 GradOp::Cos(a) => {
1182 let a_val = self.nodes[a].borrow().tensor.clone();
1183 let neg_sin_a = Tensor::from_vec_unchecked(
1184 a_val.to_vec().iter().map(|&x| -x.sin()).collect(),
1185 a_val.shape(),
1186 );
1187 let grad_a = grad.mul_elem_unchecked(&neg_sin_a);
1188 accumulate_grad(&mut grads, a, &grad_a);
1189 }
1190 GradOp::Sqrt(a) => {
1191 let inv_2sqrt = Tensor::from_vec_unchecked(
1193 node_tensor.to_vec().iter().map(|&x| 0.5 / x).collect(),
1194 node_tensor.shape(),
1195 );
1196 let grad_a = grad.mul_elem_unchecked(&inv_2sqrt);
1197 accumulate_grad(&mut grads, a, &grad_a);
1198 }
1199 GradOp::Pow(a, n) => {
1200 let a_val = self.nodes[a].borrow().tensor.clone();
1201 let coeff = Tensor::from_vec_unchecked(
1202 a_val.to_vec().iter().map(|&x| n * x.powf(n - 1.0)).collect(),
1203 a_val.shape(),
1204 );
1205 let grad_a = grad.mul_elem_unchecked(&coeff);
1206 accumulate_grad(&mut grads, a, &grad_a);
1207 }
1208 GradOp::Sigmoid(a) => {
1209 let sig = &node_tensor;
1211 let one_minus = Tensor::from_vec_unchecked(
1212 sig.to_vec().iter().map(|&s| 1.0 - s).collect(),
1213 sig.shape(),
1214 );
1215 let local = sig.mul_elem_unchecked(&one_minus);
1216 let grad_a = grad.mul_elem_unchecked(&local);
1217 accumulate_grad(&mut grads, a, &grad_a);
1218 }
1219 GradOp::Relu(a) => {
1220 let a_val = self.nodes[a].borrow().tensor.clone();
1221 let mask = Tensor::from_vec_unchecked(
1222 a_val.to_vec().iter().map(|&x| if x > 0.0 { 1.0 } else { 0.0 }).collect(),
1223 a_val.shape(),
1224 );
1225 let grad_a = grad.mul_elem_unchecked(&mask);
1226 accumulate_grad(&mut grads, a, &grad_a);
1227 }
1228 GradOp::TanhAct(a) => {
1229 let t = &node_tensor;
1231 let one_minus_sq = Tensor::from_vec_unchecked(
1232 t.to_vec().iter().map(|&x| 1.0 - x * x).collect(),
1233 t.shape(),
1234 );
1235 let grad_a = grad.mul_elem_unchecked(&one_minus_sq);
1236 accumulate_grad(&mut grads, a, &grad_a);
1237 }
1238 GradOp::Abs(a) => {
1240 let a_val = self.nodes[a].borrow().tensor.clone();
1241 let sign = Tensor::from_vec_unchecked(
1242 a_val.to_vec().iter().map(|&x| {
1243 if x > 0.0 { 1.0 } else if x < 0.0 { -1.0 } else { 0.0 }
1244 }).collect(),
1245 a_val.shape(),
1246 );
1247 let grad_a = grad.mul_elem_unchecked(&sign);
1248 accumulate_grad(&mut grads, a, &grad_a);
1249 }
1250 GradOp::Log2(a) => {
1251 let a_val = self.nodes[a].borrow().tensor.clone();
1253 let ln2 = std::f64::consts::LN_2;
1254 let local = Tensor::from_vec_unchecked(
1255 a_val.to_vec().iter().map(|&x| 1.0 / (x * ln2)).collect(),
1256 a_val.shape(),
1257 );
1258 let grad_a = grad.mul_elem_unchecked(&local);
1259 accumulate_grad(&mut grads, a, &grad_a);
1260 }
1261 GradOp::Softmax(a) => {
1262 use cjc_repro::KahanAccumulatorF64;
1264 let sm = &node_tensor;
1265 let sm_data = sm.to_vec();
1266 let grad_data = grad.to_vec();
1267 let mut dot_acc = KahanAccumulatorF64::new();
1268 for (&g, &s) in grad_data.iter().zip(sm_data.iter()) {
1269 dot_acc.add(g * s);
1270 }
1271 let dot = dot_acc.finalize();
1272 let grad_input: Vec<f64> = sm_data.iter().zip(grad_data.iter())
1273 .map(|(&s, &g)| s * (g - dot))
1274 .collect();
1275 let grad_a = Tensor::from_vec_unchecked(grad_input, sm.shape());
1276 accumulate_grad(&mut grads, a, &grad_a);
1277 }
1278 GradOp::CrossEntropy { logits, targets } => {
1279 use cjc_repro::KahanAccumulatorF64;
1281 let logits_val = self.nodes[logits].borrow().tensor.clone();
1282 let targets_val = self.nodes[targets].borrow().tensor.clone();
1283 let logits_data = logits_val.to_vec();
1284 let targets_data = targets_val.to_vec();
1285 let max_val = logits_data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
1287 let exp_shifted: Vec<f64> = logits_data.iter().map(|&x| (x - max_val).exp()).collect();
1288 let mut sum_acc = KahanAccumulatorF64::new();
1289 for &v in &exp_shifted {
1290 sum_acc.add(v);
1291 }
1292 let sum_exp = sum_acc.finalize();
1293 let softmax: Vec<f64> = exp_shifted.iter().map(|&e| e / sum_exp).collect();
1294 let upstream = grad.to_vec()[0]; let grad_logits: Vec<f64> = softmax.iter().zip(targets_data.iter())
1297 .map(|(&s, &t)| upstream * (s - t))
1298 .collect();
1299 let gl = Tensor::from_vec_unchecked(grad_logits, logits_val.shape());
1300 accumulate_grad(&mut grads, logits, &gl);
1301 }
1303 GradOp::LayerNorm(a) => {
1304 use cjc_repro::KahanAccumulatorF64;
1308 let x_hat = &node_tensor;
1309 let x_hat_data = x_hat.to_vec();
1310 let grad_data = grad.to_vec();
1311 let n = x_hat_data.len() as f64;
1312 let a_val = self.nodes[a].borrow().tensor.clone();
1314 let a_data = a_val.to_vec();
1315 let mut mean_acc = KahanAccumulatorF64::new();
1316 for &v in &a_data {
1317 mean_acc.add(v);
1318 }
1319 let mean = mean_acc.finalize() / n;
1320 let mut var_acc = KahanAccumulatorF64::new();
1321 for &v in &a_data {
1322 let d = v - mean;
1323 var_acc.add(d * d);
1324 }
1325 let var = var_acc.finalize() / n;
1326 let eps = 1e-5;
1327 let std_val = (var + eps).sqrt();
1328 let mut mg_acc = KahanAccumulatorF64::new();
1330 for &g in &grad_data {
1331 mg_acc.add(g);
1332 }
1333 let mean_grad = mg_acc.finalize() / n;
1334 let mut mgx_acc = KahanAccumulatorF64::new();
1336 for (&g, &xh) in grad_data.iter().zip(x_hat_data.iter()) {
1337 mgx_acc.add(g * xh);
1338 }
1339 let mean_grad_xhat = mgx_acc.finalize() / n;
1340 let dx: Vec<f64> = grad_data.iter().zip(x_hat_data.iter())
1341 .map(|(&g, &xh)| (g - mean_grad - xh * mean_grad_xhat) / std_val)
1342 .collect();
1343 let grad_a = Tensor::from_vec_unchecked(dx, a_val.shape());
1344 accumulate_grad(&mut grads, a, &grad_a);
1345 }
1346 GradOp::BatchNorm(a) => {
1347 use cjc_repro::KahanAccumulatorF64;
1349 let x_hat = &node_tensor;
1350 let x_hat_data = x_hat.to_vec();
1351 let grad_data = grad.to_vec();
1352 let n = x_hat_data.len() as f64;
1353 let a_val = self.nodes[a].borrow().tensor.clone();
1354 let a_data = a_val.to_vec();
1355 let mut mean_acc = KahanAccumulatorF64::new();
1356 for &v in &a_data {
1357 mean_acc.add(v);
1358 }
1359 let mean = mean_acc.finalize() / n;
1360 let mut var_acc = KahanAccumulatorF64::new();
1361 for &v in &a_data {
1362 let d = v - mean;
1363 var_acc.add(d * d);
1364 }
1365 let var = var_acc.finalize() / n;
1366 let eps = 1e-5;
1367 let std_val = (var + eps).sqrt();
1368 let mut mg_acc = KahanAccumulatorF64::new();
1369 for &g in &grad_data {
1370 mg_acc.add(g);
1371 }
1372 let mean_grad = mg_acc.finalize() / n;
1373 let mut mgx_acc = KahanAccumulatorF64::new();
1374 for (&g, &xh) in grad_data.iter().zip(x_hat_data.iter()) {
1375 mgx_acc.add(g * xh);
1376 }
1377 let mean_grad_xhat = mgx_acc.finalize() / n;
1378 let dx: Vec<f64> = grad_data.iter().zip(x_hat_data.iter())
1379 .map(|(&g, &xh)| (g - mean_grad - xh * mean_grad_xhat) / std_val)
1380 .collect();
1381 let grad_a = Tensor::from_vec_unchecked(dx, a_val.shape());
1382 accumulate_grad(&mut grads, a, &grad_a);
1383 }
1384 GradOp::Clamp { input, min, max } => {
1385 let a_val = self.nodes[input].borrow().tensor.clone();
1387 let mask = Tensor::from_vec_unchecked(
1388 a_val.to_vec().iter().map(|&x| {
1389 if x >= min && x <= max { 1.0 } else { 0.0 }
1390 }).collect(),
1391 a_val.shape(),
1392 );
1393 let grad_a = grad.mul_elem_unchecked(&mask);
1394 accumulate_grad(&mut grads, input, &grad_a);
1395 }
1396 GradOp::Where { cond, on_true, on_false } => {
1397 let cond_data = self.nodes[cond].borrow().tensor.to_vec();
1398 let grad_data = grad.to_vec();
1399 let shape = grad.shape().to_vec();
1400 let grad_true: Vec<f64> = cond_data.iter().zip(grad_data.iter())
1401 .map(|(&c, &g)| if c != 0.0 { g } else { 0.0 })
1402 .collect();
1403 let grad_false: Vec<f64> = cond_data.iter().zip(grad_data.iter())
1404 .map(|(&c, &g)| if c != 0.0 { 0.0 } else { g })
1405 .collect();
1406 let gt = Tensor::from_vec_unchecked(grad_true, &shape);
1407 let gf = Tensor::from_vec_unchecked(grad_false, &shape);
1408 accumulate_grad(&mut grads, on_true, >);
1409 accumulate_grad(&mut grads, on_false, &gf);
1410 }
1412 GradOp::Reshape { input, ref original_shape } => {
1413 let grad_a = grad.reshape(original_shape)
1415 .expect("Reshape backward: shape mismatch");
1416 accumulate_grad(&mut grads, input, &grad_a);
1417 }
1418 GradOp::TransposeOp(a) => {
1419 let grad_a = grad.transpose();
1421 accumulate_grad(&mut grads, a, &grad_a);
1422 }
1423 GradOp::CatOp { ref inputs, axis, ref sizes } => {
1424 let grad_data = grad.to_vec();
1426 let grad_shape = grad.shape().to_vec();
1427 let ndim = grad_shape.len();
1428 if ndim == 1 {
1429 let mut offset = 0usize;
1430 for (&idx, &sz) in inputs.iter().zip(sizes.iter()) {
1431 let piece = grad_data[offset..offset + sz].to_vec();
1432 let gt = Tensor::from_vec_unchecked(piece, &[sz]);
1433 accumulate_grad(&mut grads, idx, >);
1434 offset += sz;
1435 }
1436 } else if ndim == 2 && axis == 0 {
1437 let cols = grad_shape[1];
1438 let mut row_offset = 0usize;
1439 for (&idx, &sz) in inputs.iter().zip(sizes.iter()) {
1440 let start = row_offset * cols;
1441 let end = start + sz * cols;
1442 let piece = grad_data[start..end].to_vec();
1443 let gt = Tensor::from_vec_unchecked(piece, &[sz, cols]);
1444 accumulate_grad(&mut grads, idx, >);
1445 row_offset += sz;
1446 }
1447 } else if ndim == 2 && axis == 1 {
1448 let nrows = grad_shape[0];
1449 let total_cols = grad_shape[1];
1450 for (input_idx, (&idx, &sz)) in inputs.iter().zip(sizes.iter()).enumerate() {
1451 let mut piece = Vec::with_capacity(nrows * sz);
1452 let col_offset: usize = sizes[..input_idx].iter().sum();
1453 for row in 0..nrows {
1454 let row_start = row * total_cols + col_offset;
1455 piece.extend_from_slice(&grad_data[row_start..row_start + sz]);
1456 }
1457 let gt = Tensor::from_vec_unchecked(piece, &[nrows, sz]);
1458 accumulate_grad(&mut grads, idx, >);
1459 }
1460 } else {
1461 let mut offset = 0usize;
1463 for (&idx, &sz) in inputs.iter().zip(sizes.iter()) {
1464 let piece_len = sz * grad_data.len() / grad_shape[axis];
1465 let piece = grad_data[offset..offset + piece_len].to_vec();
1466 let mut piece_shape = grad_shape.clone();
1467 piece_shape[axis] = sz;
1468 let gt = Tensor::from_vec_unchecked(piece, &piece_shape);
1469 accumulate_grad(&mut grads, idx, >);
1470 offset += piece_len;
1471 }
1472 }
1473 }
1474 GradOp::GatherOp { input, ref indices, axis } => {
1475 let input_shape = self.nodes[input].borrow().tensor.shape().to_vec();
1477 let input_len: usize = input_shape.iter().product();
1478 let mut scatter = vec![0.0_f64; input_len];
1479 let grad_data = grad.to_vec();
1480 if self.nodes[input].borrow().tensor.ndim() == 1 {
1481 for (gi, &idx) in indices.iter().enumerate() {
1482 scatter[idx] += grad_data[gi];
1483 }
1484 } else if axis == 0 && self.nodes[input].borrow().tensor.ndim() == 2 {
1485 let cols = input_shape[1];
1486 for (gi, &idx) in indices.iter().enumerate() {
1487 for c in 0..cols {
1488 scatter[idx * cols + c] += grad_data[gi * cols + c];
1489 }
1490 }
1491 } else {
1492 for (gi, &idx) in indices.iter().enumerate() {
1494 scatter[idx] += grad_data[gi];
1495 }
1496 }
1497 let grad_a = Tensor::from_vec_unchecked(scatter, &input_shape);
1498 accumulate_grad(&mut grads, input, &grad_a);
1499 }
1500 }
1501 }
1502 }
1503
1504 pub fn jacobian(&mut self, output_idx: usize, param_idx: usize) -> Tensor {
1509 let output_shape = self.nodes[output_idx].borrow().tensor.shape().to_vec();
1510 let output_dim: usize = output_shape.iter().product();
1511 let param_shape = self.nodes[param_idx].borrow().tensor.shape().to_vec();
1512 let param_dim: usize = param_shape.iter().product();
1513
1514 let mut jac_data = vec![0.0_f64; output_dim * param_dim];
1515
1516 for i in 0..output_dim {
1517 self.zero_grad();
1519
1520 let mut seed = vec![0.0_f64; output_dim];
1522 seed[i] = 1.0;
1523 let seed_tensor = Tensor::from_vec_unchecked(seed, &output_shape);
1524
1525 self.backward_with_seed(output_idx, &seed_tensor);
1527
1528 let grad = self.nodes[param_idx].borrow().grad.clone();
1530 if let Some(g) = grad {
1531 let g_vec = g.to_vec();
1532 for j in 0..param_dim {
1533 jac_data[i * param_dim + j] = g_vec[j];
1534 }
1535 }
1536 }
1537
1538 Tensor::from_vec_unchecked(jac_data, &[output_dim, param_dim])
1539 }
1540
1541 pub fn hessian_diag(&mut self, loss_idx: usize, param_idx: usize, eps: f64) -> Tensor {
1547 let param_shape = self.nodes[param_idx].borrow().tensor.shape().to_vec();
1548 let param_dim: usize = param_shape.iter().product();
1549 let original = self.nodes[param_idx].borrow().tensor.to_vec();
1550 let mut hess_diag = vec![0.0_f64; param_dim];
1551
1552 for i in 0..param_dim {
1553 let mut plus = original.clone();
1555 plus[i] += eps;
1556 self.nodes[param_idx].borrow_mut().tensor =
1557 Tensor::from_vec_unchecked(plus, ¶m_shape);
1558 self.zero_grad();
1559 self.backward(loss_idx);
1560 let grad_plus = self.nodes[param_idx]
1561 .borrow()
1562 .grad
1563 .as_ref()
1564 .map(|g| g.to_vec()[i])
1565 .unwrap_or(0.0);
1566
1567 let mut minus = original.clone();
1569 minus[i] -= eps;
1570 self.nodes[param_idx].borrow_mut().tensor =
1571 Tensor::from_vec_unchecked(minus, ¶m_shape);
1572 self.zero_grad();
1573 self.backward(loss_idx);
1574 let grad_minus = self.nodes[param_idx]
1575 .borrow()
1576 .grad
1577 .as_ref()
1578 .map(|g| g.to_vec()[i])
1579 .unwrap_or(0.0);
1580
1581 hess_diag[i] = (grad_plus - grad_minus) / (2.0 * eps);
1582 }
1583
1584 self.nodes[param_idx].borrow_mut().tensor =
1586 Tensor::from_vec_unchecked(original, ¶m_shape);
1587
1588 Tensor::from_vec_unchecked(hess_diag, ¶m_shape)
1589 }
1590
1591 pub fn hessian(&mut self, loss_idx: usize, param_idx: usize) -> Tensor {
1600 let eps = 1e-5;
1601 let param_shape = self.nodes[param_idx].borrow().tensor.shape().to_vec();
1602 let param_dim: usize = param_shape.iter().product();
1603 let original = self.nodes[param_idx].borrow().tensor.to_vec();
1604 let mut hess_data = vec![0.0_f64; param_dim * param_dim];
1605
1606 for i in 0..param_dim {
1607 let mut plus = original.clone();
1609 plus[i] += eps;
1610 self.nodes[param_idx].borrow_mut().tensor =
1611 Tensor::from_vec_unchecked(plus, ¶m_shape);
1612 self.reforward(param_idx + 1, loss_idx);
1613 self.zero_grad();
1614 self.backward(loss_idx);
1615 let grad_plus: Vec<f64> = self.nodes[param_idx]
1616 .borrow()
1617 .grad
1618 .as_ref()
1619 .map(|g| g.to_vec())
1620 .unwrap_or_else(|| vec![0.0; param_dim]);
1621
1622 let mut minus = original.clone();
1624 minus[i] -= eps;
1625 self.nodes[param_idx].borrow_mut().tensor =
1626 Tensor::from_vec_unchecked(minus, ¶m_shape);
1627 self.reforward(param_idx + 1, loss_idx);
1628 self.zero_grad();
1629 self.backward(loss_idx);
1630 let grad_minus: Vec<f64> = self.nodes[param_idx]
1631 .borrow()
1632 .grad
1633 .as_ref()
1634 .map(|g| g.to_vec())
1635 .unwrap_or_else(|| vec![0.0; param_dim]);
1636
1637 for j in 0..param_dim {
1639 hess_data[i * param_dim + j] = (grad_plus[j] - grad_minus[j]) / (2.0 * eps);
1640 }
1641 }
1642
1643 self.nodes[param_idx].borrow_mut().tensor =
1645 Tensor::from_vec_unchecked(original, ¶m_shape);
1646 self.reforward(param_idx + 1, loss_idx);
1647
1648 Tensor::from_vec_unchecked(hess_data, &[param_dim, param_dim])
1649 }
1650
1651 fn reforward(&mut self, start: usize, end: usize) {
1656 for node_i in start..=end {
1657 let new_tensor = {
1658 let node = self.nodes[node_i].borrow();
1659 match &node.op {
1660 GradOp::Input | GradOp::Parameter => continue,
1661 GradOp::Add(a, b) => {
1662 let at = self.nodes[*a].borrow().tensor.clone();
1663 let bt = self.nodes[*b].borrow().tensor.clone();
1664 at.add_unchecked(&bt)
1665 }
1666 GradOp::Sub(a, b) => {
1667 let at = self.nodes[*a].borrow().tensor.clone();
1668 let bt = self.nodes[*b].borrow().tensor.clone();
1669 at.sub_unchecked(&bt)
1670 }
1671 GradOp::Mul(a, b) => {
1672 let at = self.nodes[*a].borrow().tensor.clone();
1673 let bt = self.nodes[*b].borrow().tensor.clone();
1674 at.mul_elem_unchecked(&bt)
1675 }
1676 GradOp::Div(a, b) => {
1677 let at = self.nodes[*a].borrow().tensor.clone();
1678 let bt = self.nodes[*b].borrow().tensor.clone();
1679 at.div_elem_unchecked(&bt)
1680 }
1681 GradOp::Neg(a) => {
1682 self.nodes[*a].borrow().tensor.neg()
1683 }
1684 GradOp::ScalarMul(a, s) => {
1685 let s = *s;
1686 self.nodes[*a].borrow().tensor.scalar_mul(s)
1687 }
1688 GradOp::MatMul(a, b) => {
1689 let at = self.nodes[*a].borrow().tensor.clone();
1690 let bt = self.nodes[*b].borrow().tensor.clone();
1691 at.matmul_unchecked(&bt)
1692 }
1693 GradOp::Sum(a) => {
1694 let s = self.nodes[*a].borrow().tensor.sum();
1695 Tensor::from_vec_unchecked(vec![s], &[1])
1696 }
1697 GradOp::Mean(a) => {
1698 let m = self.nodes[*a].borrow().tensor.mean();
1699 Tensor::from_vec_unchecked(vec![m], &[1])
1700 }
1701 GradOp::Exp(a) => {
1702 let (data, shape) = {
1703 let t = self.nodes[*a].borrow();
1704 (t.tensor.to_vec(), t.tensor.shape().to_vec())
1705 };
1706 Tensor::from_vec_unchecked(data.iter().map(|x| x.exp()).collect(), &shape)
1707 }
1708 GradOp::Ln(a) => {
1709 let (data, shape) = {
1710 let t = self.nodes[*a].borrow();
1711 (t.tensor.to_vec(), t.tensor.shape().to_vec())
1712 };
1713 Tensor::from_vec_unchecked(data.iter().map(|x| x.ln()).collect(), &shape)
1714 }
1715 GradOp::Sin(a) => {
1716 let (data, shape) = {
1717 let t = self.nodes[*a].borrow();
1718 (t.tensor.to_vec(), t.tensor.shape().to_vec())
1719 };
1720 Tensor::from_vec_unchecked(data.iter().map(|x| x.sin()).collect(), &shape)
1721 }
1722 GradOp::Cos(a) => {
1723 let (data, shape) = {
1724 let t = self.nodes[*a].borrow();
1725 (t.tensor.to_vec(), t.tensor.shape().to_vec())
1726 };
1727 Tensor::from_vec_unchecked(data.iter().map(|x| x.cos()).collect(), &shape)
1728 }
1729 GradOp::Sqrt(a) => {
1730 let (data, shape) = {
1731 let t = self.nodes[*a].borrow();
1732 (t.tensor.to_vec(), t.tensor.shape().to_vec())
1733 };
1734 Tensor::from_vec_unchecked(data.iter().map(|x| x.sqrt()).collect(), &shape)
1735 }
1736 GradOp::Pow(a, n) => {
1737 let n = *n;
1738 let (data, shape) = {
1739 let t = self.nodes[*a].borrow();
1740 (t.tensor.to_vec(), t.tensor.shape().to_vec())
1741 };
1742 Tensor::from_vec_unchecked(data.iter().map(|x| x.powf(n)).collect(), &shape)
1743 }
1744 GradOp::Sigmoid(a) => {
1745 let (data, shape) = {
1746 let t = self.nodes[*a].borrow();
1747 (t.tensor.to_vec(), t.tensor.shape().to_vec())
1748 };
1749 Tensor::from_vec_unchecked(
1750 data.iter().map(|&x| 1.0 / (1.0 + (-x).exp())).collect(),
1751 &shape,
1752 )
1753 }
1754 GradOp::Relu(a) => {
1755 let (data, shape) = {
1756 let t = self.nodes[*a].borrow();
1757 (t.tensor.to_vec(), t.tensor.shape().to_vec())
1758 };
1759 Tensor::from_vec_unchecked(
1760 data.iter().map(|&x| if x > 0.0 { x } else { 0.0 }).collect(),
1761 &shape,
1762 )
1763 }
1764 GradOp::TanhAct(a) => {
1765 let (data, shape) = {
1766 let t = self.nodes[*a].borrow();
1767 (t.tensor.to_vec(), t.tensor.shape().to_vec())
1768 };
1769 Tensor::from_vec_unchecked(data.iter().map(|x| x.tanh()).collect(), &shape)
1770 }
1771 GradOp::Abs(a) => {
1772 let (data, shape) = {
1773 let t = self.nodes[*a].borrow();
1774 (t.tensor.to_vec(), t.tensor.shape().to_vec())
1775 };
1776 Tensor::from_vec_unchecked(data.iter().map(|x| x.abs()).collect(), &shape)
1777 }
1778 GradOp::Clamp { input, min, max } => {
1779 let min = *min;
1780 let max = *max;
1781 let (data, shape) = {
1782 let t = self.nodes[*input].borrow();
1783 (t.tensor.to_vec(), t.tensor.shape().to_vec())
1784 };
1785 Tensor::from_vec_unchecked(
1786 data.iter().map(|&x| x.max(min).min(max)).collect(),
1787 &shape,
1788 )
1789 }
1790 GradOp::Reshape { input, .. } => {
1791 let current_shape = node.tensor.shape().to_vec();
1792 let data = self.nodes[*input].borrow().tensor.to_vec();
1793 Tensor::from_vec_unchecked(data, ¤t_shape)
1794 }
1795 GradOp::TransposeOp(a) => {
1796 self.nodes[*a].borrow().tensor.transpose()
1797 }
1798 _ => node.tensor.clone(),
1801 }
1802 };
1803 self.nodes[node_i].borrow_mut().tensor = new_tensor;
1804 }
1805 }
1806
1807 pub fn double_backward(&mut self, loss_idx: usize, param_idx: usize) -> Tensor {
1816 let eps = 1e-5;
1817 let param_shape = self.nodes[param_idx].borrow().tensor.shape().to_vec();
1818 let param_dim: usize = param_shape.iter().product();
1819 let original = self.nodes[param_idx].borrow().tensor.to_vec();
1820 let mut diag = vec![0.0_f64; param_dim];
1821
1822 for i in 0..param_dim {
1823 let mut plus = original.clone();
1825 plus[i] += eps;
1826 self.nodes[param_idx].borrow_mut().tensor =
1827 Tensor::from_vec_unchecked(plus, ¶m_shape);
1828 self.reforward(param_idx + 1, loss_idx);
1829 self.zero_grad();
1830 self.backward(loss_idx);
1831 let grad_plus = self.nodes[param_idx]
1832 .borrow()
1833 .grad
1834 .as_ref()
1835 .map(|g| g.to_vec()[i])
1836 .unwrap_or(0.0);
1837
1838 let mut minus = original.clone();
1840 minus[i] -= eps;
1841 self.nodes[param_idx].borrow_mut().tensor =
1842 Tensor::from_vec_unchecked(minus, ¶m_shape);
1843 self.reforward(param_idx + 1, loss_idx);
1844 self.zero_grad();
1845 self.backward(loss_idx);
1846 let grad_minus = self.nodes[param_idx]
1847 .borrow()
1848 .grad
1849 .as_ref()
1850 .map(|g| g.to_vec()[i])
1851 .unwrap_or(0.0);
1852
1853 diag[i] = (grad_plus - grad_minus) / (2.0 * eps);
1854 }
1855
1856 self.nodes[param_idx].borrow_mut().tensor =
1858 Tensor::from_vec_unchecked(original, ¶m_shape);
1859 self.reforward(param_idx + 1, loss_idx);
1860
1861 Tensor::from_vec_unchecked(diag, ¶m_shape)
1862 }
1863
1864 pub fn vmap_forward(&mut self, input_idx: usize, batch_data: &[Tensor]) -> Vec<usize> {
1883 let mut result_indices = Vec::with_capacity(batch_data.len());
1884
1885 let graph_len = self.nodes.len();
1888
1889 for batch_tensor in batch_data {
1890 self.nodes[input_idx].borrow_mut().tensor = batch_tensor.clone();
1892
1893 for node_i in (input_idx + 1)..graph_len {
1895 let (op, new_tensor) = {
1896 let node = self.nodes[node_i].borrow();
1897 let op = node.op.clone();
1898 let new_tensor = match &op {
1899 GradOp::Add(a, b) => {
1900 let at = self.nodes[*a].borrow().tensor.clone();
1901 let bt = self.nodes[*b].borrow().tensor.clone();
1902 at.add_unchecked(&bt)
1903 }
1904 GradOp::Sub(a, b) => {
1905 let at = self.nodes[*a].borrow().tensor.clone();
1906 let bt = self.nodes[*b].borrow().tensor.clone();
1907 at.sub_unchecked(&bt)
1908 }
1909 GradOp::Mul(a, b) => {
1910 let at = self.nodes[*a].borrow().tensor.clone();
1911 let bt = self.nodes[*b].borrow().tensor.clone();
1912 at.mul_elem_unchecked(&bt)
1913 }
1914 GradOp::Div(a, b) => {
1915 let at = self.nodes[*a].borrow().tensor.clone();
1916 let bt = self.nodes[*b].borrow().tensor.clone();
1917 at.div_elem_unchecked(&bt)
1918 }
1919 GradOp::Neg(a) => {
1920 self.nodes[*a].borrow().tensor.neg()
1921 }
1922 GradOp::ScalarMul(a, s) => {
1923 self.nodes[*a].borrow().tensor.scalar_mul(*s)
1924 }
1925 GradOp::MatMul(a, b) => {
1926 let at = self.nodes[*a].borrow().tensor.clone();
1927 let bt = self.nodes[*b].borrow().tensor.clone();
1928 at.matmul_unchecked(&bt)
1929 }
1930 GradOp::Sum(a) => {
1931 let s = self.nodes[*a].borrow().tensor.sum();
1932 let shape = vec![1usize];
1933 Tensor::from_vec_unchecked(vec![s], &shape)
1934 }
1935 GradOp::Mean(a) => {
1936 let m = self.nodes[*a].borrow().tensor.mean();
1937 Tensor::from_vec_unchecked(vec![m], &[1])
1938 }
1939 GradOp::Exp(a) => {
1940 let data = self.nodes[*a].borrow().tensor.to_vec();
1941 let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1942 Tensor::from_vec_unchecked(
1943 data.iter().map(|x| x.exp()).collect(),
1944 &shape,
1945 )
1946 }
1947 GradOp::Ln(a) => {
1948 let data = self.nodes[*a].borrow().tensor.to_vec();
1949 let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1950 Tensor::from_vec_unchecked(
1951 data.iter().map(|x| x.ln()).collect(),
1952 &shape,
1953 )
1954 }
1955 GradOp::Sin(a) => {
1956 let data = self.nodes[*a].borrow().tensor.to_vec();
1957 let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1958 Tensor::from_vec_unchecked(
1959 data.iter().map(|x| x.sin()).collect(),
1960 &shape,
1961 )
1962 }
1963 GradOp::Cos(a) => {
1964 let data = self.nodes[*a].borrow().tensor.to_vec();
1965 let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1966 Tensor::from_vec_unchecked(
1967 data.iter().map(|x| x.cos()).collect(),
1968 &shape,
1969 )
1970 }
1971 GradOp::Sqrt(a) => {
1972 let data = self.nodes[*a].borrow().tensor.to_vec();
1973 let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1974 Tensor::from_vec_unchecked(
1975 data.iter().map(|x| x.sqrt()).collect(),
1976 &shape,
1977 )
1978 }
1979 GradOp::Pow(a, n) => {
1980 let data = self.nodes[*a].borrow().tensor.to_vec();
1981 let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1982 Tensor::from_vec_unchecked(
1983 data.iter().map(|x| x.powf(*n)).collect(),
1984 &shape,
1985 )
1986 }
1987 GradOp::Sigmoid(a) => {
1988 let data = self.nodes[*a].borrow().tensor.to_vec();
1989 let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1990 Tensor::from_vec_unchecked(
1991 data.iter().map(|&x| 1.0 / (1.0 + (-x).exp())).collect(),
1992 &shape,
1993 )
1994 }
1995 GradOp::Relu(a) => {
1996 let data = self.nodes[*a].borrow().tensor.to_vec();
1997 let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1998 Tensor::from_vec_unchecked(
1999 data.iter().map(|&x| if x > 0.0 { x } else { 0.0 }).collect(),
2000 &shape,
2001 )
2002 }
2003 GradOp::TanhAct(a) => {
2004 let data = self.nodes[*a].borrow().tensor.to_vec();
2005 let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
2006 Tensor::from_vec_unchecked(
2007 data.iter().map(|x| x.tanh()).collect(),
2008 &shape,
2009 )
2010 }
2011 GradOp::Abs(a) => {
2012 let data = self.nodes[*a].borrow().tensor.to_vec();
2013 let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
2014 Tensor::from_vec_unchecked(
2015 data.iter().map(|x| x.abs()).collect(),
2016 &shape,
2017 )
2018 }
2019 GradOp::Clamp { input, min, max } => {
2020 let data = self.nodes[*input].borrow().tensor.to_vec();
2021 let shape = self.nodes[*input].borrow().tensor.shape().to_vec();
2022 Tensor::from_vec_unchecked(
2023 data.iter().map(|&x| x.max(*min).min(*max)).collect(),
2024 &shape,
2025 )
2026 }
2027 GradOp::Reshape { input, .. } => {
2028 let data = self.nodes[*input].borrow().tensor.to_vec();
2030 let shape = node.tensor.shape().to_vec();
2031 Tensor::from_vec_unchecked(data, &shape)
2032 }
2033 GradOp::TransposeOp(a) => {
2034 self.nodes[*a].borrow().tensor.transpose()
2035 }
2036 _ => node.tensor.clone(),
2039 };
2040 (op, new_tensor)
2041 };
2042 let _ = op; self.nodes[node_i].borrow_mut().tensor = new_tensor;
2044 }
2045
2046 let output_tensor = self.nodes[graph_len - 1].borrow().tensor.clone();
2049 let snapshot_idx = self.nodes.len();
2050 self.nodes.push(Rc::new(RefCell::new(GradNode {
2051 op: GradOp::Input,
2052 tensor: output_tensor,
2053 grad: None,
2054 })));
2055 result_indices.push(snapshot_idx);
2056 }
2057
2058 result_indices
2059 }
2060
2061 pub fn backward_with_seed(&mut self, loss_idx: usize, seed: &Tensor) {
2063 let n = self.nodes.len();
2064 let mut grads: Vec<Option<Tensor>> = vec![None; n];
2065 grads[loss_idx] = Some(seed.clone());
2066
2067 for i in (0..n).rev() {
2068 let grad = match grads[i].take() {
2069 Some(g) => g,
2070 None => continue,
2071 };
2072
2073 let node = self.nodes[i].borrow();
2074 if let Some(ref _param_grad) = node.grad {
2075 drop(node);
2077 let new_grad = {
2078 let n = self.nodes[i].borrow();
2079 if let Some(ref existing) = n.grad {
2080 if existing.to_vec().iter().all(|&x| x == 0.0) {
2081 grad.clone()
2082 } else {
2083 existing.add_unchecked(&grad)
2084 }
2085 } else {
2086 grad.clone()
2087 }
2088 };
2089 self.nodes[i].borrow_mut().grad = Some(new_grad);
2090 } else {
2091 drop(node);
2092 }
2093
2094 let node = self.nodes[i].borrow();
2096 let node_tensor = node.tensor.clone();
2097 match &node.op {
2098 GradOp::Input | GradOp::Parameter => {}
2099 GradOp::Add(a, b) => {
2100 accumulate_grad(&mut grads, *a, &grad);
2101 accumulate_grad(&mut grads, *b, &grad);
2102 }
2103 GradOp::Sub(a, b) => {
2104 accumulate_grad(&mut grads, *a, &grad);
2105 accumulate_grad(&mut grads, *b, &grad.neg());
2106 }
2107 GradOp::Mul(a, b) => {
2108 let a_val = self.nodes[*a].borrow().tensor.clone();
2109 let b_val = self.nodes[*b].borrow().tensor.clone();
2110 accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&b_val));
2111 accumulate_grad(&mut grads, *b, &grad.mul_elem_unchecked(&a_val));
2112 }
2113 GradOp::Div(a, b) => {
2114 let a_val = self.nodes[*a].borrow().tensor.clone();
2115 let b_val = self.nodes[*b].borrow().tensor.clone();
2116 let grad_a = grad.div_elem_unchecked(&b_val);
2117 let neg_a_over_b2 = a_val.neg().div_elem_unchecked(
2118 &b_val.mul_elem_unchecked(&b_val),
2119 );
2120 let grad_b = grad.mul_elem_unchecked(&neg_a_over_b2);
2121 accumulate_grad(&mut grads, *a, &grad_a);
2122 accumulate_grad(&mut grads, *b, &grad_b);
2123 }
2124 GradOp::Neg(a) => {
2125 accumulate_grad(&mut grads, *a, &grad.neg());
2126 }
2127 GradOp::ScalarMul(a, s) => {
2128 accumulate_grad(&mut grads, *a, &grad.scalar_mul(*s));
2129 }
2130 GradOp::Exp(a) => {
2131 accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&node_tensor));
2132 }
2133 GradOp::Ln(a) => {
2134 let a_val = self.nodes[*a].borrow().tensor.clone();
2135 let inv = Tensor::from_vec_unchecked(
2136 a_val.to_vec().iter().map(|&x| 1.0 / x).collect(),
2137 a_val.shape(),
2138 );
2139 accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&inv));
2140 }
2141 GradOp::Sin(a) => {
2142 let a_val = self.nodes[*a].borrow().tensor.clone();
2143 let cos_a = Tensor::from_vec_unchecked(
2144 a_val.to_vec().iter().map(|&x| x.cos()).collect(),
2145 a_val.shape(),
2146 );
2147 accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&cos_a));
2148 }
2149 GradOp::Cos(a) => {
2150 let a_val = self.nodes[*a].borrow().tensor.clone();
2151 let neg_sin = Tensor::from_vec_unchecked(
2152 a_val.to_vec().iter().map(|&x| -x.sin()).collect(),
2153 a_val.shape(),
2154 );
2155 accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&neg_sin));
2156 }
2157 GradOp::Sqrt(a) => {
2158 let inv2sqrt = Tensor::from_vec_unchecked(
2159 node_tensor.to_vec().iter().map(|&x| 0.5 / x).collect(),
2160 node_tensor.shape(),
2161 );
2162 accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&inv2sqrt));
2163 }
2164 GradOp::Pow(a, exp) => {
2165 let a_val = self.nodes[*a].borrow().tensor.clone();
2166 let local = Tensor::from_vec_unchecked(
2167 a_val.to_vec().iter().map(|&x| exp * x.powf(exp - 1.0)).collect(),
2168 a_val.shape(),
2169 );
2170 accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&local));
2171 }
2172 GradOp::Sigmoid(a) => {
2173 let sig = &node_tensor;
2174 let one_minus = Tensor::from_vec_unchecked(
2175 sig.to_vec().iter().map(|&x| 1.0 - x).collect(),
2176 sig.shape(),
2177 );
2178 let local = sig.mul_elem_unchecked(&one_minus);
2179 accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&local));
2180 }
2181 GradOp::Relu(a) => {
2182 let a_val = self.nodes[*a].borrow().tensor.clone();
2183 let mask = Tensor::from_vec_unchecked(
2184 a_val.to_vec().iter().map(|&x| if x > 0.0 { 1.0 } else { 0.0 }).collect(),
2185 a_val.shape(),
2186 );
2187 accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&mask));
2188 }
2189 GradOp::TanhAct(a) => {
2190 let one_minus_sq = Tensor::from_vec_unchecked(
2191 node_tensor.to_vec().iter().map(|&x| 1.0 - x * x).collect(),
2192 node_tensor.shape(),
2193 );
2194 accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&one_minus_sq));
2195 }
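// Matrix-multiply backward uses the standard results for C = A * B:
// dL/dA = G * B^T and dL/dB = A^T * G, where G is the upstream gradient.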
2196 GradOp::MatMul(a, b) => {
2197 let a_val = self.nodes[*a].borrow().tensor.clone();
2198 let b_val = self.nodes[*b].borrow().tensor.clone();
2199 accumulate_grad(&mut grads, *a, &grad.matmul_unchecked(&b_val.transpose()));
2200 accumulate_grad(&mut grads, *b, &a_val.transpose().matmul_unchecked(&grad));
2201 }
2202 GradOp::Sum(a) => {
2203 let a_shape = self.nodes[*a].borrow().tensor.shape().to_vec();
2204 let grad_val = grad.to_vec()[0];
2205 let expanded = Tensor::from_vec_unchecked(
2206 vec![grad_val; a_shape.iter().product()],
2207 &a_shape,
2208 );
2209 accumulate_grad(&mut grads, *a, &expanded);
2210 }
2211 GradOp::Mean(a) => {
2212 let a_shape = self.nodes[*a].borrow().tensor.shape().to_vec();
2213 let n_elem = a_shape.iter().product::<usize>() as f64;
2214 let grad_val = grad.to_vec()[0] / n_elem;
2215 let expanded = Tensor::from_vec_unchecked(
2216 vec![grad_val; a_shape.iter().product()],
2217 &a_shape,
2218 );
2219 accumulate_grad(&mut grads, *a, &expanded);
2220 }
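// StructField / MapLookup backward scatters the slice gradient back into the
// flat parent buffer, assuming the parent splits into `total_fields`
// (resp. `total_keys`) equally sized chunks, mirroring how the forward
// lookup appears to index them.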
2221 GradOp::StructField { parent, field_index, total_fields } => {
2222 let parent_shape = self.nodes[*parent].borrow().tensor.shape().to_vec();
2223 let parent_n: usize = parent_shape.iter().product();
2224 let chunk = parent_n / total_fields;
2225 let start = field_index * chunk;
2226 let mut parent_grad = vec![0.0_f64; parent_n];
2227 let g_vec = grad.to_vec();
2228 for (j, &gv) in g_vec.iter().enumerate() {
2229 parent_grad[start + j] = gv;
2230 }
2231 let pg = Tensor::from_vec_unchecked(parent_grad, &parent_shape);
2232 accumulate_grad(&mut grads, *parent, &pg);
2233 }
2234 GradOp::MapLookup { map_node, key_index, total_keys } => {
2235 let map_shape = self.nodes[*map_node].borrow().tensor.shape().to_vec();
2236 let map_n: usize = map_shape.iter().product();
2237 let chunk = map_n / total_keys;
2238 let start = key_index * chunk;
2239 let mut map_grad = vec![0.0_f64; map_n];
2240 let g_vec = grad.to_vec();
2241 for (j, &gv) in g_vec.iter().enumerate() {
2242 map_grad[start + j] = gv;
2243 }
2244 let mg = Tensor::from_vec_unchecked(map_grad, &map_shape);
2245 accumulate_grad(&mut grads, *map_node, &mg);
2246 }
2247 GradOp::Abs(a) => {
2249 let a_val = self.nodes[*a].borrow().tensor.clone();
2250 let sign = Tensor::from_vec_unchecked(
2251 a_val.to_vec().iter().map(|&x| {
2252 if x > 0.0 { 1.0 } else if x < 0.0 { -1.0 } else { 0.0 }
2253 }).collect(),
2254 a_val.shape(),
2255 );
2256 accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&sign));
2257 }
2258 GradOp::Log2(a) => {
2259 let a_val = self.nodes[*a].borrow().tensor.clone();
2260 let ln2 = std::f64::consts::LN_2;
2261 let local = Tensor::from_vec_unchecked(
2262 a_val.to_vec().iter().map(|&x| 1.0 / (x * ln2)).collect(),
2263 a_val.shape(),
2264 );
2265 accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&local));
2266 }
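// Softmax backward is the Jacobian-vector product
// dL/dx_i = s_i * (g_i - sum_j(g_j * s_j)); the dot product sum_j(g_j * s_j)
// is accumulated with Kahan summation to keep the reduction accurate and
// reproducible.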
2267 GradOp::Softmax(a) => {
2268 use cjc_repro::KahanAccumulatorF64;
2269 let sm = &node_tensor;
2270 let sm_data = sm.to_vec();
2271 let grad_data = grad.to_vec();
2272 let mut dot_acc = KahanAccumulatorF64::new();
2273 for (&g, &s) in grad_data.iter().zip(sm_data.iter()) {
2274 dot_acc.add(g * s);
2275 }
2276 let dot = dot_acc.finalize();
2277 let grad_input: Vec<f64> = sm_data.iter().zip(grad_data.iter())
2278 .map(|(&s, &g)| s * (g - dot))
2279 .collect();
2280 let grad_a = Tensor::from_vec_unchecked(grad_input, sm.shape());
2281 accumulate_grad(&mut grads, *a, &grad_a);
2282 }
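// Softmax-cross-entropy backward: assuming the usual loss
// L = -sum_i(t_i * ln(softmax(z)_i)), the gradient is softmax(z) - t, scaled
// by the upstream scalar. The softmax is recomputed here with the max-shift
// trick for numerical stability.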
2283 GradOp::CrossEntropy { logits, targets } => {
2284 use cjc_repro::KahanAccumulatorF64;
2285 let logits_val = self.nodes[*logits].borrow().tensor.clone();
2286 let targets_val = self.nodes[*targets].borrow().tensor.clone();
2287 let logits_data = logits_val.to_vec();
2288 let targets_data = targets_val.to_vec();
2289 let max_val = logits_data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
2290 let exp_shifted: Vec<f64> = logits_data.iter().map(|&x| (x - max_val).exp()).collect();
2291 let mut sum_acc = KahanAccumulatorF64::new();
2292 for &v in &exp_shifted {
2293 sum_acc.add(v);
2294 }
2295 let sum_exp = sum_acc.finalize();
2296 let softmax: Vec<f64> = exp_shifted.iter().map(|&e| e / sum_exp).collect();
2297 let upstream = grad.to_vec()[0];
2298 let grad_logits: Vec<f64> = softmax.iter().zip(targets_data.iter())
2299 .map(|(&s, &t)| upstream * (s - t))
2300 .collect();
2301 let gl = Tensor::from_vec_unchecked(grad_logits, logits_val.shape());
2302 accumulate_grad(&mut grads, *logits, &gl);
2303 }
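// LayerNorm backward (no affine scale/shift), with
// x_hat = (x - mean) / sqrt(var + eps) and eps = 1e-5:
// dx_i = (g_i - mean(g) - x_hat_i * mean(g * x_hat)) / sqrt(var + eps).
// Mean and variance are recomputed from the stored input using Kahan sums.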
2304 GradOp::LayerNorm(a) => {
2305 use cjc_repro::KahanAccumulatorF64;
2306 let x_hat = &node_tensor;
2307 let x_hat_data = x_hat.to_vec();
2308 let grad_data = grad.to_vec();
2309 let n = x_hat_data.len() as f64;
2310 let a_val = self.nodes[*a].borrow().tensor.clone();
2311 let a_data = a_val.to_vec();
2312 let mut mean_acc = KahanAccumulatorF64::new();
2313 for &v in &a_data { mean_acc.add(v); }
2314 let mean = mean_acc.finalize() / n;
2315 let mut var_acc = KahanAccumulatorF64::new();
2316 for &v in &a_data { let d = v - mean; var_acc.add(d * d); }
2317 let var = var_acc.finalize() / n;
2318 let std_val = (var + 1e-5).sqrt();
2319 let mut mg_acc = KahanAccumulatorF64::new();
2320 for &g in &grad_data { mg_acc.add(g); }
2321 let mean_grad = mg_acc.finalize() / n;
2322 let mut mgx_acc = KahanAccumulatorF64::new();
2323 for (&g, &xh) in grad_data.iter().zip(x_hat_data.iter()) { mgx_acc.add(g * xh); }
2324 let mean_grad_xhat = mgx_acc.finalize() / n;
2325 let dx: Vec<f64> = grad_data.iter().zip(x_hat_data.iter())
2326 .map(|(&g, &xh)| (g - mean_grad - xh * mean_grad_xhat) / std_val)
2327 .collect();
2328 accumulate_grad(&mut grads, *a, &Tensor::from_vec_unchecked(dx, a_val.shape()));
2329 }
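// BatchNorm backward is identical to the LayerNorm case above, because this
// graph normalizes the whole tensor as a single group rather than per feature.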
2330 GradOp::BatchNorm(a) => {
2331 use cjc_repro::KahanAccumulatorF64;
2332 let x_hat = &node_tensor;
2333 let x_hat_data = x_hat.to_vec();
2334 let grad_data = grad.to_vec();
2335 let n = x_hat_data.len() as f64;
2336 let a_val = self.nodes[*a].borrow().tensor.clone();
2337 let a_data = a_val.to_vec();
2338 let mut mean_acc = KahanAccumulatorF64::new();
2339 for &v in &a_data { mean_acc.add(v); }
2340 let mean = mean_acc.finalize() / n;
2341 let mut var_acc = KahanAccumulatorF64::new();
2342 for &v in &a_data { let d = v - mean; var_acc.add(d * d); }
2343 let var = var_acc.finalize() / n;
2344 let std_val = (var + 1e-5).sqrt();
2345 let mut mg_acc = KahanAccumulatorF64::new();
2346 for &g in &grad_data { mg_acc.add(g); }
2347 let mean_grad = mg_acc.finalize() / n;
2348 let mut mgx_acc = KahanAccumulatorF64::new();
2349 for (&g, &xh) in grad_data.iter().zip(x_hat_data.iter()) { mgx_acc.add(g * xh); }
2350 let mean_grad_xhat = mgx_acc.finalize() / n;
2351 let dx: Vec<f64> = grad_data.iter().zip(x_hat_data.iter())
2352 .map(|(&g, &xh)| (g - mean_grad - xh * mean_grad_xhat) / std_val)
2353 .collect();
2354 accumulate_grad(&mut grads, *a, &Tensor::from_vec_unchecked(dx, a_val.shape()));
2355 }
2356 GradOp::Clamp { input, min, max } => {
2357 let a_val = self.nodes[*input].borrow().tensor.clone();
2358 let mask = Tensor::from_vec_unchecked(
2359 a_val.to_vec().iter().map(|&x| {
2360 if x >= *min && x <= *max { 1.0 } else { 0.0 }
2361 }).collect(),
2362 a_val.shape(),
2363 );
2364 accumulate_grad(&mut grads, *input, &grad.mul_elem_unchecked(&mask));
2365 }
2366 GradOp::Where { cond, on_true, on_false } => {
2367 let cond_data = self.nodes[*cond].borrow().tensor.to_vec();
2368 let grad_data = grad.to_vec();
2369 let shape = grad.shape().to_vec();
2370 let grad_true: Vec<f64> = cond_data.iter().zip(grad_data.iter())
2371 .map(|(&c, &g)| if c != 0.0 { g } else { 0.0 }).collect();
2372 let grad_false: Vec<f64> = cond_data.iter().zip(grad_data.iter())
2373 .map(|(&c, &g)| if c != 0.0 { 0.0 } else { g }).collect();
2374 accumulate_grad(&mut grads, *on_true, &Tensor::from_vec_unchecked(grad_true, &shape));
2375 accumulate_grad(&mut grads, *on_false, &Tensor::from_vec_unchecked(grad_false, &shape));
2376 }
2377 GradOp::Reshape { input, ref original_shape } => {
2378 let grad_a = grad.reshape(original_shape).expect("Reshape backward: shape mismatch");
2379 accumulate_grad(&mut grads, *input, &grad_a);
2380 }
2381 GradOp::TransposeOp(a) => {
2382 accumulate_grad(&mut grads, *a, &grad.transpose());
2383 }
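// Concatenation backward slices the upstream gradient back into one piece per
// input. 1-D and 2-D (axis 0 and axis 1) layouts are handled explicitly; the
// final branch falls back to contiguous chunking, which only matches
// concatenation along the leading axis.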
2384 GradOp::CatOp { ref inputs, axis, ref sizes } => {
2385 let grad_data = grad.to_vec();
2386 let grad_shape = grad.shape().to_vec();
2387 let ndim = grad_shape.len();
2388 if ndim == 1 {
2389 let mut offset = 0usize;
2390 for (idx, &sz) in inputs.iter().zip(sizes.iter()) {
2391 let piece = grad_data[offset..offset + sz].to_vec();
2392 accumulate_grad(&mut grads, *idx, &Tensor::from_vec_unchecked(piece, &[sz]));
2393 offset += sz;
2394 }
2395 } else if ndim == 2 && *axis == 0 {
2396 let cols = grad_shape[1];
2397 let mut row_offset = 0usize;
2398 for (idx, &sz) in inputs.iter().zip(sizes.iter()) {
2399 let start = row_offset * cols;
2400 let end = start + sz * cols;
2401 let piece = grad_data[start..end].to_vec();
2402 accumulate_grad(&mut grads, *idx, &Tensor::from_vec_unchecked(piece, &[sz, cols]));
2403 row_offset += sz;
2404 }
2405 } else if ndim == 2 && *axis == 1 {
2406 let nrows = grad_shape[0];
2407 let total_cols = grad_shape[1];
2408 for (input_idx, (idx, &sz)) in inputs.iter().zip(sizes.iter()).enumerate() {
2409 let mut piece = Vec::with_capacity(nrows * sz);
2410 let col_offset: usize = sizes[..input_idx].iter().sum();
2411 for row in 0..nrows {
2412 let row_start = row * total_cols + col_offset;
2413 piece.extend_from_slice(&grad_data[row_start..row_start + sz]);
2414 }
2415 accumulate_grad(&mut grads, *idx, &Tensor::from_vec_unchecked(piece, &[nrows, sz]));
2416 }
2417 } else {
2418 let mut offset = 0usize;
2419 for (idx, &sz) in inputs.iter().zip(sizes.iter()) {
2420 let piece_len = sz * grad_data.len() / grad_shape[*axis];
2421 let piece = grad_data[offset..offset + piece_len].to_vec();
2422 let mut piece_shape = grad_shape.clone();
2423 piece_shape[*axis] = sz;
2424 accumulate_grad(&mut grads, *idx, &Tensor::from_vec_unchecked(piece, &piece_shape));
2425 offset += piece_len;
2426 }
2427 }
2428 }
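// Gather backward scatter-adds the upstream gradient back onto the gathered
// rows/elements; duplicate indices accumulate, which
// test_reverse_gather_duplicate_indices below relies on.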
2429 GradOp::GatherOp { input, ref indices, axis } => {
2430 let input_shape = self.nodes[*input].borrow().tensor.shape().to_vec();
2431 let input_len: usize = input_shape.iter().product();
2432 let mut scatter = vec![0.0_f64; input_len];
2433 let grad_data = grad.to_vec();
2434 if self.nodes[*input].borrow().tensor.ndim() == 1 {
2435 for (gi, &idx) in indices.iter().enumerate() {
2436 scatter[idx] += grad_data[gi];
2437 }
2438 } else if *axis == 0 && self.nodes[*input].borrow().tensor.ndim() == 2 {
2439 let cols = input_shape[1];
2440 for (gi, &idx) in indices.iter().enumerate() {
2441 for c in 0..cols {
2442 scatter[idx * cols + c] += grad_data[gi * cols + c];
2443 }
2444 }
2445 } else {
2446 for (gi, &idx) in indices.iter().enumerate() {
2447 scatter[idx] += grad_data[gi];
2448 }
2449 }
2450 accumulate_grad(&mut grads, *input, &Tensor::from_vec_unchecked(scatter, &input_shape));
2451 }
2452 }
2453 }
2454 }
2455}
2456
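// Adds a gradient contribution into the per-node scratch table used by the
// backward sweep; a node consumed by several downstream ops receives the sum
// of all its contributions.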
2457fn accumulate_grad(grads: &mut [Option<Tensor>], idx: usize, grad: &Tensor) {
2458 if let Some(existing) = &grads[idx] {
2459 grads[idx] = Some(existing.add_unchecked(grad));
2460 } else {
2461 grads[idx] = Some(grad.clone());
2462 }
2463}
2464
2465impl Default for GradGraph {
2468 fn default() -> Self {
2469 Self::new()
2470 }
2471}
2472
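// Central finite-difference check: (f(x + eps) - f(x - eps)) / (2 * eps)
// approximates f'(x) with O(eps^2) truncation error; returns true when the
// estimate agrees with `expected_grad` to within `tol`.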
2473pub fn check_grad_finite_diff<F>(
2477 f: F,
2478 x: f64,
2479 expected_grad: f64,
2480 eps: f64,
2481 tol: f64,
2482) -> bool
2483where
2484 F: Fn(f64) -> f64,
2485{
2486 let fd_grad = (f(x + eps) - f(x - eps)) / (2.0 * eps);
2487 (fd_grad - expected_grad).abs() < tol
2488}
2489
2490#[cfg(test)]
2491mod tests {
2492 use super::*;
2493
2494 #[test]
2497 fn test_dual_add() {
2498 let a = Dual::variable(3.0);
2499 let b = Dual::constant(2.0);
2500 let c = a + b;
2501 assert_eq!(c.value, 5.0);
2502 assert_eq!(c.deriv, 1.0);
2503 }
2504
2505 #[test]
2506 fn test_dual_mul() {
2507 let a = Dual::variable(3.0);
2508 let b = Dual::constant(2.0);
2509 let c = a * b;
2510 assert_eq!(c.value, 6.0);
2511 assert_eq!(c.deriv, 2.0);
2512 }
2513
2514 #[test]
2515 fn test_dual_chain_rule() {
2516 let x = Dual::variable(3.0);
2518 let result = x.clone() * x.clone() + Dual::constant(2.0) * x + Dual::one();
2519 assert_eq!(result.value, 16.0);
2520 assert_eq!(result.deriv, 8.0);
2521 }
2522
2523 #[test]
2524 fn test_dual_exp() {
2525 let x = Dual::variable(1.0);
2526 let result = x.exp();
2527 assert!((result.value - std::f64::consts::E).abs() < 1e-10);
2528 assert!((result.deriv - std::f64::consts::E).abs() < 1e-10);
2529 }
2530
2531 #[test]
2532 fn test_dual_sin_cos() {
2533 let x = Dual::variable(0.0);
2534 let sin_x = x.clone().sin();
2535 let cos_x = x.cos();
2536 assert!((sin_x.value - 0.0).abs() < 1e-10);
2537 assert!((sin_x.deriv - 1.0).abs() < 1e-10);
2538 assert!((cos_x.value - 1.0).abs() < 1e-10);
2539 assert!((cos_x.deriv - 0.0).abs() < 1e-10);
2540 }
2541
2542 #[test]
2543 fn test_dual_div() {
2544 let a = Dual::variable(6.0);
2545 let b = Dual::constant(3.0);
2546 let c = a / b;
2547 assert_eq!(c.value, 2.0);
2548 assert!((c.deriv - 1.0 / 3.0).abs() < 1e-10);
2549 }
2550
2551 #[test]
2552 fn test_finite_diff_validation() {
2553 let f = |x: f64| x * x;
2555 assert!(check_grad_finite_diff(f, 3.0, 6.0, 1e-7, 1e-5));
2556 }
2557
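    // Illustrative cross-check (an added sketch, not part of the original
    // suite): the forward-mode Dual derivative of f(x) = sin(x^2) should agree
    // with the central finite-difference estimate from check_grad_finite_diff.
    // Only the Dual and checker APIs defined above are used; f, x, eps and tol
    // are arbitrary example choices.
    #[test]
    fn test_dual_matches_finite_diff_sketch() {
        // d/dx sin(x^2) = 2x * cos(x^2), obtained here via dual numbers.
        let f = |x: f64| (x * x).sin();
        let x = 0.7_f64;
        let d = (Dual::variable(x) * Dual::variable(x)).sin();
        assert!(check_grad_finite_diff(f, x, d.deriv, 1e-6, 1e-5));
    }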
2558 #[test]
2561 fn test_reverse_add() {
2562 let mut g = GradGraph::new();
2563 let a = g.parameter(Tensor::from_vec_unchecked(vec![3.0], &[1]));
2564 let b = g.parameter(Tensor::from_vec_unchecked(vec![2.0], &[1]));
2565 let c = g.add(a, b);
2566
2567 g.backward(c);
2568
2569 let ga = g.grad(a).unwrap();
2570 let gb = g.grad(b).unwrap();
2571 assert!((ga.to_vec()[0] - 1.0).abs() < 1e-10);
2572 assert!((gb.to_vec()[0] - 1.0).abs() < 1e-10);
2573 }
2574
2575 #[test]
2576 fn test_reverse_mul() {
2577 let mut g = GradGraph::new();
2578 let a = g.parameter(Tensor::from_vec_unchecked(vec![3.0], &[1]));
2579 let b = g.parameter(Tensor::from_vec_unchecked(vec![2.0], &[1]));
2580 let c = g.mul(a, b);
2581
2582 g.backward(c);
2583
2584 let ga = g.grad(a).unwrap();
2585 let gb = g.grad(b).unwrap();
2586 assert!((ga.to_vec()[0] - 2.0).abs() < 1e-10);
2587 assert!((gb.to_vec()[0] - 3.0).abs() < 1e-10);
2588 }
2589
2590 #[test]
2591 fn test_reverse_matmul_gradient() {
2592 let mut g = GradGraph::new();
2593
2594 let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]));
2596 let b = g.parameter(Tensor::from_vec_unchecked(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]));
2597 let c = g.matmul(a, b);
2598 let loss = g.sum(c);
2599
2600 g.backward(loss);
2601
2602 let ga = g.grad(a).unwrap();
2604 let ga_data = ga.to_vec();
2605 assert!((ga_data[0] - 11.0).abs() < 1e-10);
2610 assert!((ga_data[1] - 15.0).abs() < 1e-10);
2611 }
2612
2613 #[test]
2614 fn test_reverse_mean_gradient() {
2615 let mut g = GradGraph::new();
2616 let a = g.parameter(Tensor::from_vec_unchecked(vec![2.0, 4.0, 6.0, 8.0], &[4]));
2617 let loss = g.mean(a);
2618
2619 g.backward(loss);
2620
2621 let ga = g.grad(a).unwrap();
2622 let ga_data = ga.to_vec();
2623 for &v in &ga_data {
2625 assert!((v - 0.25).abs() < 1e-10);
2626 }
2627 }
2628
2629 #[test]
2632 fn test_reverse_sin() {
2633 let mut g = GradGraph::new();
2634 let a = g.parameter(Tensor::from_vec_unchecked(vec![0.0], &[1]));
2635 let b = g.sin(a);
2636 g.backward(b);
2637 let ga = g.grad(a).unwrap();
2638 assert!((ga.to_vec()[0] - 1.0).abs() < 1e-10);
2640 }
2641
2642 #[test]
2643 fn test_reverse_cos() {
2644 let mut g = GradGraph::new();
2645 let a = g.parameter(Tensor::from_vec_unchecked(vec![0.0], &[1]));
2646 let b = g.cos(a);
2647 g.backward(b);
2648 let ga = g.grad(a).unwrap();
2649 assert!(ga.to_vec()[0].abs() < 1e-10);
2651 }
2652
2653 #[test]
2654 fn test_reverse_sqrt() {
2655 let mut g = GradGraph::new();
2656 let a = g.parameter(Tensor::from_vec_unchecked(vec![4.0], &[1]));
2657 let b = g.sqrt(a);
2658 g.backward(b);
2659 let ga = g.grad(a).unwrap();
2660 assert!((ga.to_vec()[0] - 0.25).abs() < 1e-10);
2662 }
2663
2664 #[test]
2665 fn test_reverse_pow() {
2666 let mut g = GradGraph::new();
2667 let a = g.parameter(Tensor::from_vec_unchecked(vec![2.0], &[1]));
2668 let b = g.pow(a, 3.0);
2669 g.backward(b);
2670 let ga = g.grad(a).unwrap();
2671 assert!((ga.to_vec()[0] - 12.0).abs() < 1e-10);
2673 }
2674
2675 #[test]
2676 fn test_reverse_sigmoid() {
2677 let mut g = GradGraph::new();
2678 let a = g.parameter(Tensor::from_vec_unchecked(vec![0.0], &[1]));
2679 let b = g.sigmoid(a);
2680 g.backward(b);
2681 let ga = g.grad(a).unwrap();
2682 assert!((ga.to_vec()[0] - 0.25).abs() < 1e-10);
2684 }
2685
2686 #[test]
2687 fn test_reverse_relu_positive() {
2688 let mut g = GradGraph::new();
2689 let a = g.parameter(Tensor::from_vec_unchecked(vec![3.0], &[1]));
2690 let b = g.relu(a);
2691 g.backward(b);
2692 let ga = g.grad(a).unwrap();
2693 assert!((ga.to_vec()[0] - 1.0).abs() < 1e-10);
2695 }
2696
2697 #[test]
2698 fn test_reverse_relu_negative() {
2699 let mut g = GradGraph::new();
2700 let a = g.parameter(Tensor::from_vec_unchecked(vec![-2.0], &[1]));
2701 let b = g.relu(a);
2702 g.backward(b);
2703 let ga = g.grad(a).unwrap();
2704 assert!(ga.to_vec()[0].abs() < 1e-10);
2706 }
2707
2708 #[test]
2709 fn test_reverse_tanh() {
2710 let mut g = GradGraph::new();
2711 let a = g.parameter(Tensor::from_vec_unchecked(vec![0.0], &[1]));
2712 let b = g.tanh_act(a);
2713 g.backward(b);
2714 let ga = g.grad(a).unwrap();
2715 assert!((ga.to_vec()[0] - 1.0).abs() < 1e-10);
2717 }
2718
2719 #[test]
2720 fn test_reverse_sin_cos_chain() {
2721 let mut g = GradGraph::new();
2724 let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0], &[1]));
2725 let c = g.cos(a);
2726 let s = g.sin(c);
2727 g.backward(s);
2728 let ga = g.grad(a).unwrap();
2729 let expected = 1.0_f64.cos().cos() * (-1.0_f64.sin());
2730 assert!((ga.to_vec()[0] - expected).abs() < 1e-10, "got {}, expected {expected}", ga.to_vec()[0]);
2731 }
2732
2733 #[test]
2734 fn test_reverse_sigmoid_sum() {
2735 let mut g = GradGraph::new();
2737 let a = g.parameter(Tensor::from_vec_unchecked(vec![0.0, 1.0, -1.0], &[3]));
2738 let s = g.sigmoid(a);
2739 let loss = g.sum(s);
2740 g.backward(loss);
2741 let ga = g.grad(a).unwrap();
2742 let ga_data = ga.to_vec();
2743 let sig1 = 1.0 / (1.0 + (-1.0_f64).exp());
2745 let sig_neg1 = 1.0 / (1.0 + 1.0_f64.exp());
2746 assert!((ga_data[0] - 0.25).abs() < 1e-10);
2747 assert!((ga_data[1] - sig1 * (1.0 - sig1)).abs() < 1e-10);
2748 assert!((ga_data[2] - sig_neg1 * (1.0 - sig_neg1)).abs() < 1e-10);
2749 }
2750
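// The determinism tests compare gradients bit-for-bit (via to_bits) across
// independently built graphs, asserting that identical inputs produce
// bit-identical gradients on repeated runs.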
2751 #[test]
2752 fn test_b8_determinism() {
2753 let mut g1 = GradGraph::new();
2754 let a1 = g1.parameter(Tensor::from_vec_unchecked(vec![1.5], &[1]));
2755 let s1 = g1.sin(a1);
2756 g1.backward(s1);
2757 let ga1 = g1.grad(a1).unwrap().to_vec()[0];
2758
2759 let mut g2 = GradGraph::new();
2760 let a2 = g2.parameter(Tensor::from_vec_unchecked(vec![1.5], &[1]));
2761 let s2 = g2.sin(a2);
2762 g2.backward(s2);
2763 let ga2 = g2.grad(a2).unwrap().to_vec()[0];
2764
2765 assert_eq!(ga1.to_bits(), ga2.to_bits());
2766 }
2767
2768 #[test]
2769 fn test_reverse_mse_loss() {
2770 let mut g = GradGraph::new();
2772
2773 let w = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 1.0], &[2, 1]));
2774 let x = g.input(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]));
2775 let target = g.input(Tensor::from_vec_unchecked(vec![3.0, 7.0], &[2, 1]));
2776
2777 let pred = g.matmul(x, w);
2778 let diff = g.sub(pred, target);
2779 let sq = g.mul(diff, diff);
2780 let loss = g.mean(sq);
2781
2782 let loss_val = g.value(loss);
2783 g.backward(loss);
2784
2785 let gw = g.grad(w).unwrap();
2786
2787 assert!(loss_val.is_finite());
2789 assert_eq!(gw.to_vec().len(), 2);
2790 for &v in &gw.to_vec() {
2791 assert!(v.is_finite());
2792 }
2793 }
2794
2795 #[test]
2798 fn test_reverse_div() {
2799 let mut g = GradGraph::new();
2801 let a = g.parameter(Tensor::from_vec_unchecked(vec![6.0], &[1]));
2802 let b = g.input(Tensor::from_vec_unchecked(vec![2.0], &[1]));
2803 let c = g.div(a, b);
2804 g.backward(c);
2805 let ga = g.grad(a).unwrap();
2806 assert!((ga.to_vec()[0] - 0.5).abs() < 1e-10);
2807 }
2808
2809 #[test]
2810 fn test_reverse_neg() {
2811 let mut g = GradGraph::new();
2813 let a = g.parameter(Tensor::from_vec_unchecked(vec![3.0], &[1]));
2814 let c = g.neg(a);
2815 g.backward(c);
2816 let ga = g.grad(a).unwrap();
2817 assert!((ga.to_vec()[0] - (-1.0)).abs() < 1e-10);
2818 }
2819
2820 #[test]
2821 fn test_reverse_scalar_mul() {
2822 let mut g = GradGraph::new();
2824 let a = g.parameter(Tensor::from_vec_unchecked(vec![2.0], &[1]));
2825 let c = g.scalar_mul(a, 3.0);
2826 g.backward(c);
2827 let ga = g.grad(a).unwrap();
2828 assert!((ga.to_vec()[0] - 3.0).abs() < 1e-10);
2829 }
2830
2831 #[test]
2832 fn test_reverse_exp() {
2833 let mut g = GradGraph::new();
2835 let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0], &[1]));
2836 let c = g.exp(a);
2837 g.backward(c);
2838 let ga = g.grad(a).unwrap();
2839 assert!((ga.to_vec()[0] - std::f64::consts::E).abs() < 1e-10);
2840 }
2841
2842 #[test]
2843 fn test_reverse_ln() {
2844 let mut g = GradGraph::new();
2846 let a = g.parameter(Tensor::from_vec_unchecked(vec![2.0], &[1]));
2847 let c = g.ln(a);
2848 g.backward(c);
2849 let ga = g.grad(a).unwrap();
2850 assert!((ga.to_vec()[0] - 0.5).abs() < 1e-10);
2851 }
2852
2853 #[test]
2856 fn test_reverse_abs_positive() {
2857 let mut g = GradGraph::new();
2858 let a = g.parameter(Tensor::from_vec_unchecked(vec![3.0], &[1]));
2859 let b = g.abs(a);
2860 g.backward(b);
2861 let ga = g.grad(a).unwrap();
2862 assert!((ga.to_vec()[0] - 1.0).abs() < 1e-10);
2864 }
2865
2866 #[test]
2867 fn test_reverse_abs_negative() {
2868 let mut g = GradGraph::new();
2869 let a = g.parameter(Tensor::from_vec_unchecked(vec![-2.5], &[1]));
2870 let b = g.abs(a);
2871 g.backward(b);
2872 let ga = g.grad(a).unwrap();
2873 assert!((ga.to_vec()[0] - (-1.0)).abs() < 1e-10);
2875 }
2876
2877 #[test]
2878 fn test_reverse_abs_zero() {
2879 let mut g = GradGraph::new();
2880 let a = g.parameter(Tensor::from_vec_unchecked(vec![0.0], &[1]));
2881 let b = g.abs(a);
2882 g.backward(b);
2883 let ga = g.grad(a).unwrap();
2884 assert!(ga.to_vec()[0].abs() < 1e-10);
2886 }
2887
2888 #[test]
2889 fn test_reverse_abs_vector() {
2890 let mut g = GradGraph::new();
2891 let a = g.parameter(Tensor::from_vec_unchecked(vec![-1.0, 2.0, 0.0, -3.0], &[4]));
2892 let b = g.abs(a);
2893 let loss = g.sum(b);
2894 g.backward(loss);
2895 let ga = g.grad(a).unwrap();
2896 let expected = vec![-1.0, 1.0, 0.0, -1.0];
2897 for (i, (&got, &exp)) in ga.to_vec().iter().zip(expected.iter()).enumerate() {
2898 assert!((got - exp).abs() < 1e-10, "abs grad[{i}]: got {got}, expected {exp}");
2899 }
2900 }
2901
2902 #[test]
2903 fn test_reverse_log2() {
2904 let mut g = GradGraph::new();
2905 let a = g.parameter(Tensor::from_vec_unchecked(vec![4.0], &[1]));
2906 let b = g.log2(a);
2907 g.backward(b);
2908 let ga = g.grad(a).unwrap();
2909 let expected = 1.0 / (4.0 * std::f64::consts::LN_2);
2911 assert!((ga.to_vec()[0] - expected).abs() < 1e-10);
2912 }
2913
2914 #[test]
2915 fn test_reverse_log2_vector() {
2916 let mut g = GradGraph::new();
2917 let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 8.0], &[3]));
2918 let b = g.log2(a);
2919 let loss = g.sum(b);
2920 g.backward(loss);
2921 let ga = g.grad(a).unwrap();
2922 let ln2 = std::f64::consts::LN_2;
2923 let expected = vec![1.0 / (1.0 * ln2), 1.0 / (2.0 * ln2), 1.0 / (8.0 * ln2)];
2924 for (i, (&got, &exp)) in ga.to_vec().iter().zip(expected.iter()).enumerate() {
2925 assert!((got - exp).abs() < 1e-10, "log2 grad[{i}]: got {got}, expected {exp}");
2926 }
2927 }
2928
2929 #[test]
2930 fn test_softmax_forward() {
2931 let mut g = GradGraph::new();
2932 let a = g.input(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0], &[3]));
2933 let b = g.softmax(a);
2934 let sm = g.tensor(b);
2935 let sm_data = sm.to_vec();
2936 let sum: f64 = sm_data.iter().sum();
2938 assert!((sum - 1.0).abs() < 1e-10);
2939 assert!(sm_data[2] > sm_data[1]);
2941 assert!(sm_data[1] > sm_data[0]);
2942 }
2943
2944 #[test]
2945 fn test_reverse_softmax() {
2946 let mut g = GradGraph::new();
2948 let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0], &[3]));
2949 let b = g.softmax(a);
2950 let loss = g.sum(b);
2951 g.backward(loss);
2952 let ga = g.grad(a).unwrap();
2953 for &v in &ga.to_vec() {
2955 assert!(v.abs() < 1e-10, "softmax sum grad should be 0, got {v}");
2956 }
2957 }
2958
2959 #[test]
2960 fn test_reverse_softmax_single_element() {
2961 let mut g = GradGraph::new();
2963 let a = g.parameter(Tensor::from_vec_unchecked(vec![2.0, 1.0], &[2]));
2964 let b = g.softmax(a);
2965 let loss = g.sum(b);
2968 g.backward(loss);
2969 let ga = g.grad(a).unwrap();
2970 for &v in &ga.to_vec() {
2971 assert!(v.abs() < 1e-10);
2972 }
2973 }
2974
2975 #[test]
2976 fn test_cross_entropy_forward() {
2977 let mut g = GradGraph::new();
2978 let logits = g.input(Tensor::from_vec_unchecked(vec![2.0, 1.0, 0.1], &[3]));
2979 let targets = g.input(Tensor::from_vec_unchecked(vec![1.0, 0.0, 0.0], &[3]));
2980 let ce = g.cross_entropy(logits, targets);
2981 let loss_val = g.value(ce);
2982 assert!(loss_val > 0.0, "CE loss should be positive");
2983 assert!(loss_val.is_finite(), "CE loss should be finite");
2984 }
2985
2986 #[test]
2987 fn test_reverse_cross_entropy() {
2988 let mut g = GradGraph::new();
2989 let logits = g.parameter(Tensor::from_vec_unchecked(vec![2.0, 1.0, 0.1], &[3]));
2990 let targets = g.input(Tensor::from_vec_unchecked(vec![1.0, 0.0, 0.0], &[3]));
2991 let ce = g.cross_entropy(logits, targets);
2992 g.backward(ce);
2993 let ga = g.grad(logits).unwrap();
2994 let ga_data = ga.to_vec();
2995 assert!(ga_data[0] < 0.0, "CE grad for correct class should be negative");
2999 assert!(ga_data[1] > 0.0, "CE grad for incorrect class should be positive");
3000 assert!(ga_data[2] > 0.0, "CE grad for incorrect class should be positive");
3001 let sum: f64 = ga_data.iter().sum();
3003 assert!(sum.abs() < 1e-10, "CE grad should sum to 0, got {sum}");
3004 }
3005
3006 #[test]
3007 fn test_layer_norm_forward() {
3008 let mut g = GradGraph::new();
3009 let a = g.input(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0], &[4]));
3010 let b = g.layer_norm(a);
3011 let normed = g.tensor(b).to_vec();
3012 let mean: f64 = normed.iter().sum::<f64>() / normed.len() as f64;
3014 assert!(mean.abs() < 1e-5, "LayerNorm mean should be ~0, got {mean}");
3015 let var: f64 = normed.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / normed.len() as f64;
3016 assert!((var - 1.0).abs() < 0.01, "LayerNorm variance should be ~1, got {var}");
3017 }
3018
3019 #[test]
3020 fn test_reverse_layer_norm() {
3021 let mut g = GradGraph::new();
3022 let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0], &[4]));
3023 let b = g.layer_norm(a);
3024 let loss = g.sum(b);
3025 g.backward(loss);
3026 let ga = g.grad(a).unwrap();
3027 for &v in &ga.to_vec() {
3029 assert!(v.is_finite(), "LayerNorm grad should be finite");
3030 }
3031 let eps = 1e-5;
3033 let input_data = vec![1.0, 2.0, 3.0, 4.0];
3034 for i in 0..4 {
3035 let mut plus = input_data.clone();
3036 plus[i] += eps;
3037 let mut g_plus = GradGraph::new();
3038 let a_plus = g_plus.input(Tensor::from_vec_unchecked(plus, &[4]));
3039 let b_plus = g_plus.layer_norm(a_plus);
3040 let loss_plus = g_plus.sum(b_plus);
3041 let val_plus = g_plus.value(loss_plus);
3042
3043 let mut minus = input_data.clone();
3044 minus[i] -= eps;
3045 let mut g_minus = GradGraph::new();
3046 let a_minus = g_minus.input(Tensor::from_vec_unchecked(minus, &[4]));
3047 let b_minus = g_minus.layer_norm(a_minus);
3048 let loss_minus = g_minus.sum(b_minus);
3049 let val_minus = g_minus.value(loss_minus);
3050
3051 let fd_grad = (val_plus - val_minus) / (2.0 * eps);
3052 let ad_grad = ga.to_vec()[i];
3053 assert!(
3054 (fd_grad - ad_grad).abs() < 1e-4,
3055 "LayerNorm FD check failed at [{i}]: fd={fd_grad}, ad={ad_grad}"
3056 );
3057 }
3058 }
3059
3060 #[test]
3061 fn test_batch_norm_forward() {
3062 let mut g = GradGraph::new();
3063 let a = g.input(Tensor::from_vec_unchecked(vec![2.0, 4.0, 6.0, 8.0], &[4]));
3064 let b = g.batch_norm(a);
3065 let normed = g.tensor(b).to_vec();
3066 let mean: f64 = normed.iter().sum::<f64>() / normed.len() as f64;
3067 assert!(mean.abs() < 1e-5, "BatchNorm mean should be ~0, got {mean}");
3068 }
3069
3070 #[test]
3071 fn test_reverse_batch_norm() {
3072 let mut g = GradGraph::new();
3073 let a = g.parameter(Tensor::from_vec_unchecked(vec![2.0, 4.0, 6.0, 8.0], &[4]));
3074 let b = g.batch_norm(a);
3075 let loss = g.sum(b);
3076 g.backward(loss);
3077 let ga = g.grad(a).unwrap();
3078 for &v in &ga.to_vec() {
3079 assert!(v.is_finite(), "BatchNorm grad should be finite");
3080 }
3081 }
3082
3083 #[test]
3084 fn test_reverse_clamp_in_range() {
3085 let mut g = GradGraph::new();
3086 let a = g.parameter(Tensor::from_vec_unchecked(vec![1.5], &[1]));
3087 let b = g.clamp(a, 0.0, 3.0);
3088 g.backward(b);
3089 let ga = g.grad(a).unwrap();
3090 assert!((ga.to_vec()[0] - 1.0).abs() < 1e-10);
3092 }
3093
3094 #[test]
3095 fn test_reverse_clamp_out_of_range() {
3096 let mut g = GradGraph::new();
3097 let a = g.parameter(Tensor::from_vec_unchecked(vec![5.0], &[1]));
3098 let b = g.clamp(a, 0.0, 3.0);
3099 g.backward(b);
3100 let ga = g.grad(a).unwrap();
3101 assert!(ga.to_vec()[0].abs() < 1e-10);
3103 }
3104
3105 #[test]
3106 fn test_reverse_clamp_vector() {
3107 let mut g = GradGraph::new();
3108 let a = g.parameter(Tensor::from_vec_unchecked(vec![-1.0, 0.5, 2.0, 4.0], &[4]));
3109 let b = g.clamp(a, 0.0, 3.0);
3110 let loss = g.sum(b);
3111 g.backward(loss);
3112 let ga = g.grad(a).unwrap();
3113 let expected = vec![0.0, 1.0, 1.0, 0.0];
3114 for (i, (&got, &exp)) in ga.to_vec().iter().zip(expected.iter()).enumerate() {
3115 assert!((got - exp).abs() < 1e-10, "clamp grad[{i}]: got {got}, expected {exp}");
3116 }
3117 }
3118
3119 #[test]
3120 fn test_reverse_where_cond() {
3121 let mut g = GradGraph::new();
3122 let cond = g.input(Tensor::from_vec_unchecked(vec![1.0, 0.0, 1.0], &[3]));
3123 let a = g.parameter(Tensor::from_vec_unchecked(vec![10.0, 20.0, 30.0], &[3]));
3124 let b = g.parameter(Tensor::from_vec_unchecked(vec![100.0, 200.0, 300.0], &[3]));
3125 let w = g.where_cond(cond, a, b);
3126 let result = g.tensor(w).to_vec();
3128 assert!((result[0] - 10.0).abs() < 1e-10);
3129 assert!((result[1] - 200.0).abs() < 1e-10);
3130 assert!((result[2] - 30.0).abs() < 1e-10);
3131 let loss = g.sum(w);
3132 g.backward(loss);
3133 let ga = g.grad(a).unwrap().to_vec();
3134 let gb = g.grad(b).unwrap().to_vec();
3135 assert!((ga[0] - 1.0).abs() < 1e-10);
3137 assert!(ga[1].abs() < 1e-10);
3138 assert!((ga[2] - 1.0).abs() < 1e-10);
3139 assert!(gb[0].abs() < 1e-10);
3140 assert!((gb[1] - 1.0).abs() < 1e-10);
3141 assert!(gb[2].abs() < 1e-10);
3142 }
3143
3144 #[test]
3145 fn test_reverse_reshape() {
3146 let mut g = GradGraph::new();
3147 let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]));
3148 let b = g.reshape(a, &[3, 2]);
3149 let loss = g.sum(b);
3150 g.backward(loss);
3151 let ga = g.grad(a).unwrap();
3152 assert_eq!(ga.shape(), &[2, 3]);
3154 for &v in &ga.to_vec() {
3155 assert!((v - 1.0).abs() < 1e-10);
3156 }
3157 }
3158
3159 #[test]
3160 fn test_reverse_transpose_op() {
3161 let mut g = GradGraph::new();
3162 let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]));
3163 let b = g.transpose_op(a);
3164 assert_eq!(g.tensor(b).shape(), &[3, 2]);
3166 let loss = g.sum(b);
3167 g.backward(loss);
3168 let ga = g.grad(a).unwrap();
3169 assert_eq!(ga.shape(), &[2, 3]);
3171 for &v in &ga.to_vec() {
3172 assert!((v - 1.0).abs() < 1e-10);
3173 }
3174 }
3175
3176 #[test]
3177 fn test_reverse_cat_1d() {
3178 let mut g = GradGraph::new();
3179 let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0], &[2]));
3180 let b = g.parameter(Tensor::from_vec_unchecked(vec![3.0, 4.0, 5.0], &[3]));
3181 let c = g.cat(&[a, b], 0);
3182 let result = g.tensor(c).to_vec();
3184 assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
3185 let loss = g.sum(c);
3186 g.backward(loss);
3187 let ga = g.grad(a).unwrap().to_vec();
3188 let gb = g.grad(b).unwrap().to_vec();
3189 assert_eq!(ga, vec![1.0, 1.0]);
3191 assert_eq!(gb, vec![1.0, 1.0, 1.0]);
3192 }
3193
3194 #[test]
3195 fn test_reverse_cat_2d_axis0() {
3196 let mut g = GradGraph::new();
3197 let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0], &[1, 2]));
3198 let b = g.parameter(Tensor::from_vec_unchecked(vec![3.0, 4.0, 5.0, 6.0], &[2, 2]));
3199 let c = g.cat(&[a, b], 0);
3200 assert_eq!(g.tensor(c).shape(), &[3, 2]);
3201 let loss = g.sum(c);
3202 g.backward(loss);
3203 let ga = g.grad(a).unwrap();
3204 let gb = g.grad(b).unwrap();
3205 assert_eq!(ga.shape(), &[1, 2]);
3206 assert_eq!(gb.shape(), &[2, 2]);
3207 for &v in &ga.to_vec() {
3208 assert!((v - 1.0).abs() < 1e-10);
3209 }
3210 for &v in &gb.to_vec() {
3211 assert!((v - 1.0).abs() < 1e-10);
3212 }
3213 }
3214
3215 #[test]
3216 fn test_reverse_gather_1d() {
3217 let mut g = GradGraph::new();
3218 let a = g.parameter(Tensor::from_vec_unchecked(vec![10.0, 20.0, 30.0, 40.0], &[4]));
3219 let b = g.gather(a, &[1, 3], 0);
3220 let result = g.tensor(b).to_vec();
3222 assert!((result[0] - 20.0).abs() < 1e-10);
3223 assert!((result[1] - 40.0).abs() < 1e-10);
3224 let loss = g.sum(b);
3225 g.backward(loss);
3226 let ga = g.grad(a).unwrap().to_vec();
3227 assert!((ga[0]).abs() < 1e-10);
3229 assert!((ga[1] - 1.0).abs() < 1e-10);
3230 assert!((ga[2]).abs() < 1e-10);
3231 assert!((ga[3] - 1.0).abs() < 1e-10);
3232 }
3233
3234 #[test]
3235 fn test_reverse_gather_duplicate_indices() {
3236 let mut g = GradGraph::new();
3237 let a = g.parameter(Tensor::from_vec_unchecked(vec![10.0, 20.0, 30.0], &[3]));
3238 let b = g.gather(a, &[1, 1, 2], 0);
3239 let loss = g.sum(b);
3240 g.backward(loss);
3241 let ga = g.grad(a).unwrap().to_vec();
3242 assert!((ga[0]).abs() < 1e-10);
3244 assert!((ga[1] - 2.0).abs() < 1e-10);
3245 assert!((ga[2] - 1.0).abs() < 1e-10);
3246 }
3247
3248 #[test]
3249 fn test_phase8_determinism() {
3250 for _ in 0..2 {
3252 let run = || {
3253 let mut g = GradGraph::new();
3254 let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, -2.0, 3.0, -0.5], &[4]));
3255 let b = g.abs(a);
3256 let c = g.clamp(b, 0.0, 2.5);
3257 let d = g.layer_norm(c);
3258 let loss = g.sum(d);
3259 g.backward(loss);
3260 g.grad(a).unwrap().to_vec()
3261 };
3262 let r1 = run();
3263 let r2 = run();
3264 for (i, (v1, v2)) in r1.iter().zip(r2.iter()).enumerate() {
3265 assert_eq!(v1.to_bits(), v2.to_bits(), "Determinism failed at [{i}]");
3266 }
3267 }
3268 }
3269
3270 #[test]
3271 fn test_phase8_softmax_cross_entropy_chain() {
3272 let mut g = GradGraph::new();
3274 let logits = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0], &[3]));
3275 let targets = g.input(Tensor::from_vec_unchecked(vec![0.0, 0.0, 1.0], &[3]));
3276 let ce = g.cross_entropy(logits, targets);
3277 g.backward(ce);
3278 let ga = g.grad(logits).unwrap().to_vec();
3279 for &v in &ga {
3281 assert!(v.is_finite());
3282 }
3283 assert!(ga[2] < 0.0, "CE grad for correct class should be negative");
3285 }
3286
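// Higher-order checks: for f(x) = x^3 the second derivative is 6x, i.e. 12 at
// x = 2, and for f(p) = sum_i(p_i^2) the Hessian is 2 * I. The two tests below
// assert exactly these values, to looser tolerances than the first-order tests.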
3287 #[test]
3290 fn test_double_backward_cubic() {
3291 let mut g = GradGraph::new();
3293 let x = g.parameter(Tensor::from_vec_unchecked(vec![2.0], &[1]));
3294 let x2 = g.mul(x, x);
3295 let x3 = g.mul(x2, x);
3296 let loss = g.sum(x3);
3297 let hess = g.double_backward(loss, x);
3298 assert!((hess.to_vec()[0] - 12.0).abs() < 1e-4);
3300 }
3301
3302 #[test]
3303 fn test_full_hessian_quadratic() {
3304 let mut g = GradGraph::new();
3307 let p = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 1.0], &[2]));
3308 let p2 = g.mul(p, p);
3309 let s = g.sum(p2);
3310 let hess = g.hessian(s, p);
3311 let h = hess.to_vec();
3312 assert!((h[0] - 2.0).abs() < 1e-3);
3313 assert!((h[1] - 0.0).abs() < 1e-3);
3314 assert!((h[2] - 0.0).abs() < 1e-3);
3315 assert!((h[3] - 2.0).abs() < 1e-3);
3316 }
3317
3318 #[test]
3319 fn test_vmap_forward() {
3320 let mut g = GradGraph::new();
3321 let x = g.parameter(Tensor::from_vec_unchecked(vec![1.0], &[1]));
3322 let x2 = g.mul(x, x);
3323 let loss = g.sum(x2);
3324
3325 let batch = vec![
3327 Tensor::from_vec_unchecked(vec![1.0], &[1]),
3328 Tensor::from_vec_unchecked(vec![2.0], &[1]),
3329 Tensor::from_vec_unchecked(vec![3.0], &[1]),
3330 ];
3331 let results = g.vmap_forward(x, &batch);
3332 assert!((g.value(results[0]) - 1.0).abs() < 1e-10);
3334 assert!((g.value(results[1]) - 4.0).abs() < 1e-10);
3335 assert!((g.value(results[2]) - 9.0).abs() < 1e-10);
3336 }
3337
3338 #[test]
3339 fn test_hessian_determinism() {
3340 let mut g = GradGraph::new();
3341 let p = g.parameter(Tensor::from_vec_unchecked(vec![3.0, 4.0], &[2]));
3342 let p2 = g.mul(p, p);
3343 let s = g.sum(p2);
3344 let h1 = g.hessian(s, p);
3345 g.nodes[p].borrow_mut().tensor = Tensor::from_vec_unchecked(vec![3.0, 4.0], &[2]);
3347 let h2 = g.hessian(s, p);
3348 assert_eq!(h1.to_vec(), h2.to_vec(), "Hessian must be deterministic");
3349 }
3350}