use cjc_runtime::Tensor;
use std::cell::RefCell;
use std::rc::Rc;

pub mod pinn;

/// A dual number for forward-mode automatic differentiation:
/// `value` carries the primal value and `deriv` the derivative propagated alongside it.
#[derive(Debug, Clone)]
pub struct Dual {
    pub value: f64,
    pub deriv: f64,
}

impl Dual {
    pub fn new(value: f64, deriv: f64) -> Self {
        Self { value, deriv }
    }

    pub fn constant(value: f64) -> Self {
        Self { value, deriv: 0.0 }
    }

    pub fn variable(value: f64) -> Self {
        Self { value, deriv: 1.0 }
    }

    pub fn zero() -> Self {
        Self {
            value: 0.0,
            deriv: 0.0,
        }
    }

    pub fn one() -> Self {
        Self {
            value: 1.0,
            deriv: 0.0,
        }
    }
}

impl std::ops::Add for Dual {
    type Output = Dual;
    fn add(self, rhs: Dual) -> Dual {
        Dual {
            value: self.value + rhs.value,
            deriv: self.deriv + rhs.deriv,
        }
    }
}

impl std::ops::Sub for Dual {
    type Output = Dual;
    fn sub(self, rhs: Dual) -> Dual {
        Dual {
            value: self.value - rhs.value,
            deriv: self.deriv - rhs.deriv,
        }
    }
}

impl std::ops::Mul for Dual {
    type Output = Dual;
    fn mul(self, rhs: Dual) -> Dual {
        Dual {
            value: self.value * rhs.value,
            deriv: self.value * rhs.deriv + self.deriv * rhs.value,
        }
    }
}

impl std::ops::Div for Dual {
    type Output = Dual;
    fn div(self, rhs: Dual) -> Dual {
        let denom = rhs.value * rhs.value;
        Dual {
            value: self.value / rhs.value,
            deriv: (self.deriv * rhs.value - self.value * rhs.deriv) / denom,
        }
    }
}

impl std::ops::Neg for Dual {
    type Output = Dual;
    fn neg(self) -> Dual {
        Dual {
            value: -self.value,
            deriv: -self.deriv,
        }
    }
}

impl Dual {
    pub fn sin(self) -> Dual {
        Dual {
            value: self.value.sin(),
            deriv: self.deriv * self.value.cos(),
        }
    }

    pub fn cos(self) -> Dual {
        Dual {
            value: self.value.cos(),
            deriv: -self.deriv * self.value.sin(),
        }
    }

    pub fn exp(self) -> Dual {
        let e = self.value.exp();
        Dual {
            value: e,
            deriv: self.deriv * e,
        }
    }

    pub fn ln(self) -> Dual {
        Dual {
            value: self.value.ln(),
            deriv: self.deriv / self.value,
        }
    }

    pub fn sqrt(self) -> Dual {
        let s = self.value.sqrt();
        Dual {
            value: s,
            deriv: self.deriv / (2.0 * s),
        }
    }

    pub fn pow(self, n: f64) -> Dual {
        Dual {
            value: self.value.powf(n),
            deriv: self.deriv * n * self.value.powf(n - 1.0),
        }
    }
}
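
// Example (a minimal sketch, not part of the original API surface): forward-mode
// differentiation with `Dual`. Seeding `deriv = 1.0` on the variable propagates
// df/dx alongside the value, so for f(x) = x * sin(x) at x = 2 both f(2) and
// f'(2) = sin(2) + 2 * cos(2) fall out of a single evaluation.
#[cfg(test)]
mod dual_forward_mode_example {
    use super::Dual;

    #[test]
    fn propagates_first_derivative() {
        let x = Dual::variable(2.0);
        // Left operand is cloned because `sin` consumes its receiver.
        let f = x.clone() * x.sin();
        assert!((f.value - 2.0 * 2.0_f64.sin()).abs() < 1e-12);
        assert!((f.deriv - (2.0_f64.sin() + 2.0 * 2.0_f64.cos())).abs() < 1e-12);
    }
}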

/// Operation recorded on the reverse-mode tape; index fields refer to positions
/// of operand nodes in `GradGraph::nodes`.
#[derive(Debug, Clone)]
pub enum GradOp {
    Input,
    Parameter,
    Add(usize, usize),
    Sub(usize, usize),
    Mul(usize, usize),
    Div(usize, usize),
    Neg(usize),
    MatMul(usize, usize),
    Sum(usize),
    Mean(usize),
    ScalarMul(usize, f64),
    Exp(usize),
    Ln(usize),
    StructField {
        parent: usize,
        field_index: usize,
        total_fields: usize,
    },
    MapLookup {
        map_node: usize,
        key_index: usize,
        total_keys: usize,
    },
    Sin(usize),
    Cos(usize),
    Sqrt(usize),
    Pow(usize, f64),
    Sigmoid(usize),
    Relu(usize),
    TanhAct(usize),
    Abs(usize),
    Log2(usize),
    Softmax(usize),
    CrossEntropy { logits: usize, targets: usize },
    LayerNorm(usize),
    BatchNorm(usize),
    Clamp { input: usize, min: f64, max: f64 },
    Where { cond: usize, on_true: usize, on_false: usize },
    Reshape { input: usize, original_shape: Vec<usize> },
    TransposeOp(usize),
    CatOp { inputs: Vec<usize>, axis: usize, sizes: Vec<usize> },
    GatherOp { input: usize, indices: Vec<usize>, axis: usize },
}

/// A tape node: the operation that produced it, its forward value, and an
/// optional accumulated gradient (allocated for parameters).
#[derive(Debug, Clone)]
pub struct GradNode {
    pub op: GradOp,
    pub tensor: Tensor,
    pub grad: Option<Tensor>,
}

/// A dynamically built computation graph (tape) for reverse-mode automatic differentiation.
pub struct GradGraph {
    pub nodes: Vec<Rc<RefCell<GradNode>>>,
}
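
// Example (a minimal sketch): building a tiny tape and running reverse mode with
// the `GradGraph` methods defined below. It assumes the `Tensor` constructors and
// the element-wise `*_unchecked` operations behave as they are used throughout
// this module. For loss = sum(x * w) (element-wise product), d(loss)/dw = x.
#[cfg(test)]
mod grad_graph_example {
    use super::GradGraph;
    use cjc_runtime::Tensor;

    #[test]
    fn gradient_of_elementwise_product() {
        let mut g = GradGraph::new();
        let x = g.input(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0], &[3]));
        let w = g.parameter(Tensor::from_vec_unchecked(vec![0.5, 0.5, 0.5], &[3]));
        let prod = g.mul(x, w);
        let loss = g.sum(prod);
        g.backward(loss);
        // The gradient accumulates into the parameter's pre-allocated buffer.
        let grad_w = g.grad(w).expect("parameters always hold a gradient buffer");
        assert_eq!(grad_w.to_vec(), vec![1.0, 2.0, 3.0]);
    }
}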

impl GradGraph {
    pub fn new() -> Self {
        Self { nodes: Vec::new() }
    }

    pub fn input(&mut self, tensor: Tensor) -> usize {
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::Input,
            tensor,
            grad: None,
        })));
        idx
    }

    pub fn parameter(&mut self, tensor: Tensor) -> usize {
        let idx = self.nodes.len();
        let shape = tensor.shape().to_vec();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::Parameter,
            tensor,
            grad: Some(Tensor::zeros(&shape)),
        })));
        idx
    }

    pub fn add(&mut self, a: usize, b: usize) -> usize {
        let a_t = self.nodes[a].borrow().tensor.clone();
        let b_t = self.nodes[b].borrow().tensor.clone();
        let result = a_t.add_unchecked(&b_t);
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::Add(a, b),
            tensor: result,
            grad: None,
        })));
        idx
    }

    pub fn sub(&mut self, a: usize, b: usize) -> usize {
        let a_t = self.nodes[a].borrow().tensor.clone();
        let b_t = self.nodes[b].borrow().tensor.clone();
        let result = a_t.sub_unchecked(&b_t);
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::Sub(a, b),
            tensor: result,
            grad: None,
        })));
        idx
    }

    pub fn mul(&mut self, a: usize, b: usize) -> usize {
        let a_t = self.nodes[a].borrow().tensor.clone();
        let b_t = self.nodes[b].borrow().tensor.clone();
        let result = a_t.mul_elem_unchecked(&b_t);
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::Mul(a, b),
            tensor: result,
            grad: None,
        })));
        idx
    }

    pub fn matmul(&mut self, a: usize, b: usize) -> usize {
        let a_t = self.nodes[a].borrow().tensor.clone();
        let b_t = self.nodes[b].borrow().tensor.clone();
        let result = a_t.matmul_unchecked(&b_t);
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::MatMul(a, b),
            tensor: result,
            grad: None,
        })));
        idx
    }

    pub fn sum(&mut self, a: usize) -> usize {
        let a_t = self.nodes[a].borrow().tensor.clone();
        let s = a_t.sum();
        let result = Tensor::from_vec_unchecked(vec![s], &[1]);
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::Sum(a),
            tensor: result,
            grad: None,
        })));
        idx
    }

    pub fn mean(&mut self, a: usize) -> usize {
        let a_t = self.nodes[a].borrow().tensor.clone();
        let m = a_t.mean();
        let result = Tensor::from_vec_unchecked(vec![m], &[1]);
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::Mean(a),
            tensor: result,
            grad: None,
        })));
        idx
    }

    pub fn sin(&mut self, a: usize) -> usize {
        let a_t = self.nodes[a].borrow().tensor.clone();
        let data = a_t.to_vec();
        let result = Tensor::from_vec_unchecked(
            data.iter().map(|&x| x.sin()).collect(),
            a_t.shape(),
        );
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::Sin(a),
            tensor: result,
            grad: None,
        })));
        idx
    }

    pub fn cos(&mut self, a: usize) -> usize {
        let a_t = self.nodes[a].borrow().tensor.clone();
        let data = a_t.to_vec();
        let result = Tensor::from_vec_unchecked(
            data.iter().map(|&x| x.cos()).collect(),
            a_t.shape(),
        );
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::Cos(a),
            tensor: result,
            grad: None,
        })));
        idx
    }

    pub fn sqrt(&mut self, a: usize) -> usize {
        let a_t = self.nodes[a].borrow().tensor.clone();
        let data = a_t.to_vec();
        let result = Tensor::from_vec_unchecked(
            data.iter().map(|&x| x.sqrt()).collect(),
            a_t.shape(),
        );
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::Sqrt(a),
            tensor: result,
            grad: None,
        })));
        idx
    }

    pub fn pow(&mut self, a: usize, n: f64) -> usize {
        let a_t = self.nodes[a].borrow().tensor.clone();
        let data = a_t.to_vec();
        let result = Tensor::from_vec_unchecked(
            data.iter().map(|&x| x.powf(n)).collect(),
            a_t.shape(),
        );
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::Pow(a, n),
            tensor: result,
            grad: None,
        })));
        idx
    }

    pub fn sigmoid(&mut self, a: usize) -> usize {
        let a_t = self.nodes[a].borrow().tensor.clone();
        let data = a_t.to_vec();
        let result = Tensor::from_vec_unchecked(
            data.iter().map(|&x| 1.0 / (1.0 + (-x).exp())).collect(),
            a_t.shape(),
        );
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::Sigmoid(a),
            tensor: result,
            grad: None,
        })));
        idx
    }

    pub fn relu(&mut self, a: usize) -> usize {
        let a_t = self.nodes[a].borrow().tensor.clone();
        let data = a_t.to_vec();
        let result = Tensor::from_vec_unchecked(
            data.iter().map(|&x| if x > 0.0 { x } else { 0.0 }).collect(),
            a_t.shape(),
        );
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::Relu(a),
            tensor: result,
            grad: None,
        })));
        idx
    }

    pub fn tanh_act(&mut self, a: usize) -> usize {
        let a_t = self.nodes[a].borrow().tensor.clone();
        let data = a_t.to_vec();
        let result = Tensor::from_vec_unchecked(
            data.iter().map(|&x| x.tanh()).collect(),
            a_t.shape(),
        );
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::TanhAct(a),
            tensor: result,
            grad: None,
        })));
        idx
    }

    pub fn abs(&mut self, a: usize) -> usize {
        let a_t = self.nodes[a].borrow().tensor.clone();
        let data = a_t.to_vec();
        let result = Tensor::from_vec_unchecked(
            data.iter().map(|&x| x.abs()).collect(),
            a_t.shape(),
        );
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::Abs(a),
            tensor: result,
            grad: None,
        })));
        idx
    }

    pub fn log2(&mut self, a: usize) -> usize {
        let a_t = self.nodes[a].borrow().tensor.clone();
        let data = a_t.to_vec();
        let result = Tensor::from_vec_unchecked(
            data.iter().map(|&x| x.log2()).collect(),
            a_t.shape(),
        );
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::Log2(a),
            tensor: result,
            grad: None,
        })));
        idx
    }

    pub fn softmax(&mut self, a: usize) -> usize {
        use cjc_repro::KahanAccumulatorF64;
        let a_t = self.nodes[a].borrow().tensor.clone();
        let data = a_t.to_vec();
        let max_val = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exp_shifted: Vec<f64> = data.iter().map(|&x| (x - max_val).exp()).collect();
        let mut sum_acc = KahanAccumulatorF64::new();
        for &v in &exp_shifted {
            sum_acc.add(v);
        }
        let sum_exp = sum_acc.finalize();
        let softmax_data: Vec<f64> = exp_shifted.iter().map(|&e| e / sum_exp).collect();
        let result = Tensor::from_vec_unchecked(softmax_data, a_t.shape());
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::Softmax(a),
            tensor: result,
            grad: None,
        })));
        idx
    }

    pub fn cross_entropy(&mut self, logits: usize, targets: usize) -> usize {
        use cjc_repro::KahanAccumulatorF64;
        let logits_t = self.nodes[logits].borrow().tensor.clone();
        let targets_t = self.nodes[targets].borrow().tensor.clone();
        let logits_data = logits_t.to_vec();
        let targets_data = targets_t.to_vec();
        let max_val = logits_data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let shifted: Vec<f64> = logits_data.iter().map(|&x| x - max_val).collect();
        let exp_shifted: Vec<f64> = shifted.iter().map(|&x| x.exp()).collect();
        let mut sum_acc = KahanAccumulatorF64::new();
        for &v in &exp_shifted {
            sum_acc.add(v);
        }
        let log_sum_exp = sum_acc.finalize().ln();
        let log_softmax: Vec<f64> = shifted.iter().map(|&x| x - log_sum_exp).collect();
        let mut ce_acc = KahanAccumulatorF64::new();
        for (t, ls) in targets_data.iter().zip(log_softmax.iter()) {
            ce_acc.add(-t * ls);
        }
        let ce = ce_acc.finalize();
        let result = Tensor::from_vec_unchecked(vec![ce], &[1]);
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::CrossEntropy { logits, targets },
            tensor: result,
            grad: None,
        })));
        idx
    }

    pub fn layer_norm(&mut self, a: usize) -> usize {
        use cjc_repro::KahanAccumulatorF64;
        let a_t = self.nodes[a].borrow().tensor.clone();
        let data = a_t.to_vec();
        let n = data.len() as f64;
        let mut mean_acc = KahanAccumulatorF64::new();
        for &v in &data {
            mean_acc.add(v);
        }
        let mean = mean_acc.finalize() / n;
        let mut var_acc = KahanAccumulatorF64::new();
        for &v in &data {
            let d = v - mean;
            var_acc.add(d * d);
        }
        let var = var_acc.finalize() / n;
        let eps = 1e-5;
        let std = (var + eps).sqrt();
        let normed: Vec<f64> = data.iter().map(|&x| (x - mean) / std).collect();
        let result = Tensor::from_vec_unchecked(normed, a_t.shape());
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::LayerNorm(a),
            tensor: result,
            grad: None,
        })));
        idx
    }

    pub fn batch_norm(&mut self, a: usize) -> usize {
        use cjc_repro::KahanAccumulatorF64;
        let a_t = self.nodes[a].borrow().tensor.clone();
        let data = a_t.to_vec();
        let n = data.len() as f64;
        let mut mean_acc = KahanAccumulatorF64::new();
        for &v in &data {
            mean_acc.add(v);
        }
        let mean = mean_acc.finalize() / n;
        let mut var_acc = KahanAccumulatorF64::new();
        for &v in &data {
            let d = v - mean;
            var_acc.add(d * d);
        }
        let var = var_acc.finalize() / n;
        let eps = 1e-5;
        let std = (var + eps).sqrt();
        let normed: Vec<f64> = data.iter().map(|&x| (x - mean) / std).collect();
        let result = Tensor::from_vec_unchecked(normed, a_t.shape());
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::BatchNorm(a),
            tensor: result,
            grad: None,
        })));
        idx
    }

    pub fn clamp(&mut self, a: usize, min: f64, max: f64) -> usize {
        let a_t = self.nodes[a].borrow().tensor.clone();
        let data = a_t.to_vec();
        let result = Tensor::from_vec_unchecked(
            data.iter().map(|&x| x.max(min).min(max)).collect(),
            a_t.shape(),
        );
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::Clamp { input: a, min, max },
            tensor: result,
            grad: None,
        })));
        idx
    }

    pub fn where_cond(&mut self, cond: usize, on_true: usize, on_false: usize) -> usize {
        let cond_t = self.nodes[cond].borrow().tensor.clone();
        let true_t = self.nodes[on_true].borrow().tensor.clone();
        let false_t = self.nodes[on_false].borrow().tensor.clone();
        let c = cond_t.to_vec();
        let t = true_t.to_vec();
        let f = false_t.to_vec();
        let result_data: Vec<f64> = c.iter().zip(t.iter().zip(f.iter()))
            .map(|(&ci, (&ti, &fi))| if ci != 0.0 { ti } else { fi })
            .collect();
        let result = Tensor::from_vec_unchecked(result_data, cond_t.shape());
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::Where { cond, on_true, on_false },
            tensor: result,
            grad: None,
        })));
        idx
    }

    pub fn reshape(&mut self, a: usize, new_shape: &[usize]) -> usize {
        let a_t = self.nodes[a].borrow().tensor.clone();
        let original_shape = a_t.shape().to_vec();
        let result = a_t.reshape(new_shape).expect("GradGraph::reshape: shape mismatch");
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::Reshape { input: a, original_shape },
            tensor: result,
            grad: None,
        })));
        idx
    }

    pub fn transpose_op(&mut self, a: usize) -> usize {
        let a_t = self.nodes[a].borrow().tensor.clone();
        let result = a_t.transpose();
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::TransposeOp(a),
            tensor: result,
            grad: None,
        })));
        idx
    }

    pub fn cat(&mut self, inputs: &[usize], axis: usize) -> usize {
        let tensors: Vec<Tensor> = inputs.iter()
            .map(|&i| self.nodes[i].borrow().tensor.clone())
            .collect();
        let sizes: Vec<usize> = tensors.iter()
            .map(|t| t.shape()[axis])
            .collect();
        let mut all_data = Vec::new();
        let mut total_along_axis = 0usize;
        let ndim = tensors[0].ndim();
        let mut result_shape = tensors[0].shape().to_vec();

        if ndim == 1 {
            for t in &tensors {
                all_data.extend(t.to_vec());
                total_along_axis += t.shape()[0];
            }
            result_shape[0] = total_along_axis;
        } else if ndim == 2 && axis == 0 {
            for t in &tensors {
                all_data.extend(t.to_vec());
                total_along_axis += t.shape()[0];
            }
            result_shape[0] = total_along_axis;
        } else if ndim == 2 && axis == 1 {
            let nrows = tensors[0].shape()[0];
            for row in 0..nrows {
                for t in &tensors {
                    let cols = t.shape()[1];
                    let row_data = t.to_vec();
                    let start = row * cols;
                    all_data.extend_from_slice(&row_data[start..start + cols]);
                }
            }
            total_along_axis = sizes.iter().sum();
            result_shape[1] = total_along_axis;
        } else {
            for t in &tensors {
                all_data.extend(t.to_vec());
                total_along_axis += t.shape()[axis];
            }
            result_shape[axis] = total_along_axis;
        }

        let result = Tensor::from_vec_unchecked(all_data, &result_shape);
        let input_vec = inputs.to_vec();
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::CatOp { inputs: input_vec, axis, sizes },
            tensor: result,
            grad: None,
        })));
        idx
    }

    pub fn gather(&mut self, a: usize, indices: &[usize], axis: usize) -> usize {
        let a_t = self.nodes[a].borrow().tensor.clone();
        let data = a_t.to_vec();
        let gathered: Vec<f64> = if a_t.ndim() == 1 {
            indices.iter().map(|&i| data[i]).collect()
        } else if a_t.ndim() == 2 && axis == 0 {
            let cols = a_t.shape()[1];
            indices.iter().flat_map(|&i| {
                let start = i * cols;
                data[start..start + cols].to_vec()
            }).collect()
        } else {
            indices.iter().map(|&i| data[i]).collect()
        };
        let mut result_shape = a_t.shape().to_vec();
        if a_t.ndim() == 1 {
            result_shape[0] = indices.len();
        } else if axis == 0 {
            result_shape[0] = indices.len();
        } else {
            result_shape[axis] = indices.len();
        }
        let result = Tensor::from_vec_unchecked(gathered, &result_shape);
        let idx = self.nodes.len();
        self.nodes.push(Rc::new(RefCell::new(GradNode {
            op: GradOp::GatherOp { input: a, indices: indices.to_vec(), axis },
            tensor: result,
            grad: None,
        })));
        idx
    }

    pub fn div(&mut self, a: usize, b: usize) -> usize {
        let a_tensor = self.nodes[a].borrow().tensor.clone();
        let b_tensor = self.nodes[b].borrow().tensor.clone();
        let result = a_tensor.div_elem_unchecked(&b_tensor);
        let node = GradNode { op: GradOp::Div(a, b), tensor: result, grad: None };
        self.nodes.push(Rc::new(RefCell::new(node)));
        self.nodes.len() - 1
    }

    pub fn neg(&mut self, a: usize) -> usize {
        let a_tensor = self.nodes[a].borrow().tensor.clone();
        let result = a_tensor.neg();
        let node = GradNode { op: GradOp::Neg(a), tensor: result, grad: None };
        self.nodes.push(Rc::new(RefCell::new(node)));
        self.nodes.len() - 1
    }

    pub fn scalar_mul(&mut self, a: usize, s: f64) -> usize {
        let a_tensor = self.nodes[a].borrow().tensor.clone();
        let result = a_tensor.scalar_mul(s);
        let node = GradNode { op: GradOp::ScalarMul(a, s), tensor: result, grad: None };
        self.nodes.push(Rc::new(RefCell::new(node)));
        self.nodes.len() - 1
    }

    pub fn exp(&mut self, a: usize) -> usize {
        let a_tensor = self.nodes[a].borrow().tensor.clone();
        let result = Tensor::from_vec_unchecked(
            a_tensor.to_vec().iter().map(|x| x.exp()).collect(),
            a_tensor.shape(),
        );
        let node = GradNode { op: GradOp::Exp(a), tensor: result, grad: None };
        self.nodes.push(Rc::new(RefCell::new(node)));
        self.nodes.len() - 1
    }

    pub fn ln(&mut self, a: usize) -> usize {
        let a_tensor = self.nodes[a].borrow().tensor.clone();
        let result = Tensor::from_vec_unchecked(
            a_tensor.to_vec().iter().map(|x| x.ln()).collect(),
            a_tensor.shape(),
        );
        let node = GradNode { op: GradOp::Ln(a), tensor: result, grad: None };
        self.nodes.push(Rc::new(RefCell::new(node)));
        self.nodes.len() - 1
    }

    pub fn value(&self, idx: usize) -> f64 {
        let node = self.nodes[idx].borrow();
        let data = node.tensor.to_vec();
        data[0]
    }

    pub fn tensor(&self, idx: usize) -> Tensor {
        self.nodes[idx].borrow().tensor.clone()
    }

    pub fn set_tensor(&self, idx: usize, tensor: Tensor) {
        self.nodes[idx].borrow_mut().tensor = tensor;
    }

    pub fn grad(&self, idx: usize) -> Option<Tensor> {
        self.nodes[idx].borrow().grad.clone()
    }

    pub fn zero_grad(&self) {
        for node in &self.nodes {
            let mut n = node.borrow_mut();
            if let Some(ref mut grad) = n.grad {
                let shape = grad.shape().to_vec();
                *grad = Tensor::zeros(&shape);
            }
        }
    }

    pub fn clip_grad(&self, max_norm: f64) {
        for node in &self.nodes {
            let mut n = node.borrow_mut();
            if let Some(ref mut grad) = n.grad {
                let data = grad.to_vec();
                let clipped: Vec<f64> = data.iter()
                    .map(|&x| x.max(-max_norm).min(max_norm))
                    .collect();
                let shape = grad.shape().to_vec();
                *grad = Tensor::from_vec_unchecked(clipped, &shape);
            }
        }
    }

    pub fn clip_grad_norm(&self, max_norm: f64) -> f64 {
        use cjc_repro::KahanAccumulatorF64;
        let mut acc = KahanAccumulatorF64::new();
        for node in &self.nodes {
            let n = node.borrow();
            if let Some(ref grad) = n.grad {
                for &v in &grad.to_vec() {
                    acc.add(v * v);
                }
            }
        }
        let global_norm = acc.finalize().sqrt();

        if global_norm > max_norm && global_norm > 0.0 {
            let scale = max_norm / global_norm;
            for node in &self.nodes {
                let mut n = node.borrow_mut();
                if let Some(ref mut grad) = n.grad {
                    let data = grad.to_vec();
                    let scaled: Vec<f64> = data.iter().map(|&x| x * scale).collect();
                    let shape = grad.shape().to_vec();
                    *grad = Tensor::from_vec_unchecked(scaled, &shape);
                }
            }
        }

        global_norm
    }

    pub fn backward(&self, loss_idx: usize) {
        let n = self.nodes.len();

        let mut grads: Vec<Option<Tensor>> = vec![None; n];

        let loss_shape = self.nodes[loss_idx].borrow().tensor.shape().to_vec();
        grads[loss_idx] = Some(Tensor::ones(&loss_shape));

        for i in (0..=loss_idx).rev() {
            let grad = match grads[i].take() {
                Some(g) => g,
                None => continue,
            };

            let (op, node_tensor) = {
                let node = self.nodes[i].borrow();
                (node.op.clone(), node.tensor.clone())
            };

            match op {
                GradOp::Input => {}
                GradOp::Parameter => {
                    let mut node_mut = self.nodes[i].borrow_mut();
                    if let Some(ref mut existing_grad) = node_mut.grad {
                        *existing_grad = existing_grad.add_unchecked(&grad);
                    } else {
                        node_mut.grad = Some(grad);
                    }
                }
                GradOp::Add(a, b) => {
                    accumulate_grad(&mut grads, a, &grad);
                    accumulate_grad(&mut grads, b, &grad);
                }
                GradOp::Sub(a, b) => {
                    accumulate_grad(&mut grads, a, &grad);
                    let neg_grad = grad.neg();
                    accumulate_grad(&mut grads, b, &neg_grad);
                }
                GradOp::Mul(a, b) => {
                    let a_val = self.nodes[a].borrow().tensor.clone();
                    let b_val = self.nodes[b].borrow().tensor.clone();

                    let grad_a = grad.mul_elem_unchecked(&b_val);
                    let grad_b = grad.mul_elem_unchecked(&a_val);

                    accumulate_grad(&mut grads, a, &grad_a);
                    accumulate_grad(&mut grads, b, &grad_b);
                }
                GradOp::Div(a, b) => {
                    let a_val = self.nodes[a].borrow().tensor.clone();
                    let b_val = self.nodes[b].borrow().tensor.clone();

                    let grad_a = grad.div_elem_unchecked(&b_val);
                    let b_sq = b_val.mul_elem_unchecked(&b_val);
                    let neg_a = a_val.neg();
                    let grad_b = grad.mul_elem_unchecked(&neg_a.div_elem_unchecked(&b_sq));

                    accumulate_grad(&mut grads, a, &grad_a);
                    accumulate_grad(&mut grads, b, &grad_b);
                }
                GradOp::Neg(a) => {
                    let neg_grad = grad.neg();
                    accumulate_grad(&mut grads, a, &neg_grad);
                }
                GradOp::MatMul(a, b) => {
                    let a_val = self.nodes[a].borrow().tensor.clone();
                    let b_val = self.nodes[b].borrow().tensor.clone();

                    let b_t = b_val.transpose();
                    let a_t = a_val.transpose();

                    let grad_a = grad.matmul_unchecked(&b_t);
                    let grad_b = a_t.matmul_unchecked(&grad);

                    accumulate_grad(&mut grads, a, &grad_a);
                    accumulate_grad(&mut grads, b, &grad_b);
                }
                GradOp::Sum(a) => {
                    let a_shape = self.nodes[a].borrow().tensor.shape().to_vec();
                    let grad_val = grad.to_vec()[0];
                    let expanded = Tensor::from_vec_unchecked(
                        vec![grad_val; a_shape.iter().product()],
                        &a_shape,
                    );
                    accumulate_grad(&mut grads, a, &expanded);
                }
                GradOp::Mean(a) => {
                    let a_shape = self.nodes[a].borrow().tensor.shape().to_vec();
                    let n_elem = a_shape.iter().product::<usize>() as f64;
                    let grad_val = grad.to_vec()[0] / n_elem;
                    let expanded = Tensor::from_vec_unchecked(
                        vec![grad_val; a_shape.iter().product()],
                        &a_shape,
                    );
                    accumulate_grad(&mut grads, a, &expanded);
                }
                GradOp::ScalarMul(a, s) => {
                    let scaled = grad.scalar_mul(s);
                    accumulate_grad(&mut grads, a, &scaled);
                }
                GradOp::Exp(a) => {
                    let grad_a = grad.mul_elem_unchecked(&node_tensor);
                    accumulate_grad(&mut grads, a, &grad_a);
                }
                GradOp::Ln(a) => {
                    let a_val = self.nodes[a].borrow().tensor.clone();
                    let grad_a = grad.div_elem_unchecked(&a_val);
                    accumulate_grad(&mut grads, a, &grad_a);
                }
                GradOp::StructField {
                    parent,
                    field_index,
                    total_fields,
                } => {
                    let _ = (field_index, total_fields);
                    accumulate_grad(&mut grads, parent, &grad);
                }
                GradOp::MapLookup {
                    map_node,
                    key_index,
                    total_keys,
                } => {
                    let _ = (key_index, total_keys);
                    accumulate_grad(&mut grads, map_node, &grad);
                }
                GradOp::Sin(a) => {
                    let a_val = self.nodes[a].borrow().tensor.clone();
                    let cos_a = Tensor::from_vec_unchecked(
                        a_val.to_vec().iter().map(|&x| x.cos()).collect(),
                        a_val.shape(),
                    );
                    let grad_a = grad.mul_elem_unchecked(&cos_a);
                    accumulate_grad(&mut grads, a, &grad_a);
                }
                GradOp::Cos(a) => {
                    let a_val = self.nodes[a].borrow().tensor.clone();
                    let neg_sin_a = Tensor::from_vec_unchecked(
                        a_val.to_vec().iter().map(|&x| -x.sin()).collect(),
                        a_val.shape(),
                    );
                    let grad_a = grad.mul_elem_unchecked(&neg_sin_a);
                    accumulate_grad(&mut grads, a, &grad_a);
                }
                GradOp::Sqrt(a) => {
                    let inv_2sqrt = Tensor::from_vec_unchecked(
                        node_tensor.to_vec().iter().map(|&x| 0.5 / x).collect(),
                        node_tensor.shape(),
                    );
                    let grad_a = grad.mul_elem_unchecked(&inv_2sqrt);
                    accumulate_grad(&mut grads, a, &grad_a);
                }
                GradOp::Pow(a, n) => {
                    let a_val = self.nodes[a].borrow().tensor.clone();
                    let coeff = Tensor::from_vec_unchecked(
                        a_val.to_vec().iter().map(|&x| n * x.powf(n - 1.0)).collect(),
                        a_val.shape(),
                    );
                    let grad_a = grad.mul_elem_unchecked(&coeff);
                    accumulate_grad(&mut grads, a, &grad_a);
                }
                GradOp::Sigmoid(a) => {
                    let sig = &node_tensor;
                    let one_minus = Tensor::from_vec_unchecked(
                        sig.to_vec().iter().map(|&s| 1.0 - s).collect(),
                        sig.shape(),
                    );
                    let local = sig.mul_elem_unchecked(&one_minus);
                    let grad_a = grad.mul_elem_unchecked(&local);
                    accumulate_grad(&mut grads, a, &grad_a);
                }
                GradOp::Relu(a) => {
                    let a_val = self.nodes[a].borrow().tensor.clone();
                    let mask = Tensor::from_vec_unchecked(
                        a_val.to_vec().iter().map(|&x| if x > 0.0 { 1.0 } else { 0.0 }).collect(),
                        a_val.shape(),
                    );
                    let grad_a = grad.mul_elem_unchecked(&mask);
                    accumulate_grad(&mut grads, a, &grad_a);
                }
                GradOp::TanhAct(a) => {
                    let t = &node_tensor;
                    let one_minus_sq = Tensor::from_vec_unchecked(
                        t.to_vec().iter().map(|&x| 1.0 - x * x).collect(),
                        t.shape(),
                    );
                    let grad_a = grad.mul_elem_unchecked(&one_minus_sq);
                    accumulate_grad(&mut grads, a, &grad_a);
                }
                GradOp::Abs(a) => {
                    let a_val = self.nodes[a].borrow().tensor.clone();
                    let sign = Tensor::from_vec_unchecked(
                        a_val.to_vec().iter().map(|&x| {
                            if x > 0.0 { 1.0 } else if x < 0.0 { -1.0 } else { 0.0 }
                        }).collect(),
                        a_val.shape(),
                    );
                    let grad_a = grad.mul_elem_unchecked(&sign);
                    accumulate_grad(&mut grads, a, &grad_a);
                }
                GradOp::Log2(a) => {
                    let a_val = self.nodes[a].borrow().tensor.clone();
                    let ln2 = std::f64::consts::LN_2;
                    let local = Tensor::from_vec_unchecked(
                        a_val.to_vec().iter().map(|&x| 1.0 / (x * ln2)).collect(),
                        a_val.shape(),
                    );
                    let grad_a = grad.mul_elem_unchecked(&local);
                    accumulate_grad(&mut grads, a, &grad_a);
                }
                GradOp::Softmax(a) => {
                    use cjc_repro::KahanAccumulatorF64;
                    let sm = &node_tensor;
                    let sm_data = sm.to_vec();
                    let grad_data = grad.to_vec();
                    let mut dot_acc = KahanAccumulatorF64::new();
                    for (&g, &s) in grad_data.iter().zip(sm_data.iter()) {
                        dot_acc.add(g * s);
                    }
                    let dot = dot_acc.finalize();
                    let grad_input: Vec<f64> = sm_data.iter().zip(grad_data.iter())
                        .map(|(&s, &g)| s * (g - dot))
                        .collect();
                    let grad_a = Tensor::from_vec_unchecked(grad_input, sm.shape());
                    accumulate_grad(&mut grads, a, &grad_a);
                }
                GradOp::CrossEntropy { logits, targets } => {
                    use cjc_repro::KahanAccumulatorF64;
                    let logits_val = self.nodes[logits].borrow().tensor.clone();
                    let targets_val = self.nodes[targets].borrow().tensor.clone();
                    let logits_data = logits_val.to_vec();
                    let targets_data = targets_val.to_vec();
                    let max_val = logits_data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                    let exp_shifted: Vec<f64> = logits_data.iter().map(|&x| (x - max_val).exp()).collect();
                    let mut sum_acc = KahanAccumulatorF64::new();
                    for &v in &exp_shifted {
                        sum_acc.add(v);
                    }
                    let sum_exp = sum_acc.finalize();
                    let softmax: Vec<f64> = exp_shifted.iter().map(|&e| e / sum_exp).collect();
                    let upstream = grad.to_vec()[0];
                    let grad_logits: Vec<f64> = softmax.iter().zip(targets_data.iter())
                        .map(|(&s, &t)| upstream * (s - t))
                        .collect();
                    let gl = Tensor::from_vec_unchecked(grad_logits, logits_val.shape());
                    accumulate_grad(&mut grads, logits, &gl);
                }
                GradOp::LayerNorm(a) => {
                    use cjc_repro::KahanAccumulatorF64;
                    let x_hat = &node_tensor;
                    let x_hat_data = x_hat.to_vec();
                    let grad_data = grad.to_vec();
                    let n = x_hat_data.len() as f64;
                    let a_val = self.nodes[a].borrow().tensor.clone();
                    let a_data = a_val.to_vec();
                    let mut mean_acc = KahanAccumulatorF64::new();
                    for &v in &a_data {
                        mean_acc.add(v);
                    }
                    let mean = mean_acc.finalize() / n;
                    let mut var_acc = KahanAccumulatorF64::new();
                    for &v in &a_data {
                        let d = v - mean;
                        var_acc.add(d * d);
                    }
                    let var = var_acc.finalize() / n;
                    let eps = 1e-5;
                    let std_val = (var + eps).sqrt();
                    let mut mg_acc = KahanAccumulatorF64::new();
                    for &g in &grad_data {
                        mg_acc.add(g);
                    }
                    let mean_grad = mg_acc.finalize() / n;
                    let mut mgx_acc = KahanAccumulatorF64::new();
                    for (&g, &xh) in grad_data.iter().zip(x_hat_data.iter()) {
                        mgx_acc.add(g * xh);
                    }
                    let mean_grad_xhat = mgx_acc.finalize() / n;
                    let dx: Vec<f64> = grad_data.iter().zip(x_hat_data.iter())
                        .map(|(&g, &xh)| (g - mean_grad - xh * mean_grad_xhat) / std_val)
                        .collect();
                    let grad_a = Tensor::from_vec_unchecked(dx, a_val.shape());
                    accumulate_grad(&mut grads, a, &grad_a);
                }
                GradOp::BatchNorm(a) => {
                    use cjc_repro::KahanAccumulatorF64;
                    let x_hat = &node_tensor;
                    let x_hat_data = x_hat.to_vec();
                    let grad_data = grad.to_vec();
                    let n = x_hat_data.len() as f64;
                    let a_val = self.nodes[a].borrow().tensor.clone();
                    let a_data = a_val.to_vec();
                    let mut mean_acc = KahanAccumulatorF64::new();
                    for &v in &a_data {
                        mean_acc.add(v);
                    }
                    let mean = mean_acc.finalize() / n;
                    let mut var_acc = KahanAccumulatorF64::new();
                    for &v in &a_data {
                        let d = v - mean;
                        var_acc.add(d * d);
                    }
                    let var = var_acc.finalize() / n;
                    let eps = 1e-5;
                    let std_val = (var + eps).sqrt();
                    let mut mg_acc = KahanAccumulatorF64::new();
                    for &g in &grad_data {
                        mg_acc.add(g);
                    }
                    let mean_grad = mg_acc.finalize() / n;
                    let mut mgx_acc = KahanAccumulatorF64::new();
                    for (&g, &xh) in grad_data.iter().zip(x_hat_data.iter()) {
                        mgx_acc.add(g * xh);
                    }
                    let mean_grad_xhat = mgx_acc.finalize() / n;
                    let dx: Vec<f64> = grad_data.iter().zip(x_hat_data.iter())
                        .map(|(&g, &xh)| (g - mean_grad - xh * mean_grad_xhat) / std_val)
                        .collect();
                    let grad_a = Tensor::from_vec_unchecked(dx, a_val.shape());
                    accumulate_grad(&mut grads, a, &grad_a);
                }
                GradOp::Clamp { input, min, max } => {
                    let a_val = self.nodes[input].borrow().tensor.clone();
                    let mask = Tensor::from_vec_unchecked(
                        a_val.to_vec().iter().map(|&x| {
                            if x >= min && x <= max { 1.0 } else { 0.0 }
                        }).collect(),
                        a_val.shape(),
                    );
                    let grad_a = grad.mul_elem_unchecked(&mask);
                    accumulate_grad(&mut grads, input, &grad_a);
                }
                GradOp::Where { cond, on_true, on_false } => {
                    let cond_data = self.nodes[cond].borrow().tensor.to_vec();
                    let grad_data = grad.to_vec();
                    let shape = grad.shape().to_vec();
                    let grad_true: Vec<f64> = cond_data.iter().zip(grad_data.iter())
                        .map(|(&c, &g)| if c != 0.0 { g } else { 0.0 })
                        .collect();
                    let grad_false: Vec<f64> = cond_data.iter().zip(grad_data.iter())
                        .map(|(&c, &g)| if c != 0.0 { 0.0 } else { g })
                        .collect();
                    let gt = Tensor::from_vec_unchecked(grad_true, &shape);
                    let gf = Tensor::from_vec_unchecked(grad_false, &shape);
                    accumulate_grad(&mut grads, on_true, &gt);
                    accumulate_grad(&mut grads, on_false, &gf);
                }
                GradOp::Reshape { input, ref original_shape } => {
                    let grad_a = grad.reshape(original_shape)
                        .expect("Reshape backward: shape mismatch");
                    accumulate_grad(&mut grads, input, &grad_a);
                }
                GradOp::TransposeOp(a) => {
                    let grad_a = grad.transpose();
                    accumulate_grad(&mut grads, a, &grad_a);
                }
                GradOp::CatOp { ref inputs, axis, ref sizes } => {
                    let grad_data = grad.to_vec();
                    let grad_shape = grad.shape().to_vec();
                    let ndim = grad_shape.len();
                    if ndim == 1 {
                        let mut offset = 0usize;
                        for (&idx, &sz) in inputs.iter().zip(sizes.iter()) {
                            let piece = grad_data[offset..offset + sz].to_vec();
                            let gt = Tensor::from_vec_unchecked(piece, &[sz]);
                            accumulate_grad(&mut grads, idx, &gt);
                            offset += sz;
                        }
                    } else if ndim == 2 && axis == 0 {
                        let cols = grad_shape[1];
                        let mut row_offset = 0usize;
                        for (&idx, &sz) in inputs.iter().zip(sizes.iter()) {
                            let start = row_offset * cols;
                            let end = start + sz * cols;
                            let piece = grad_data[start..end].to_vec();
                            let gt = Tensor::from_vec_unchecked(piece, &[sz, cols]);
                            accumulate_grad(&mut grads, idx, &gt);
                            row_offset += sz;
                        }
                    } else if ndim == 2 && axis == 1 {
                        let nrows = grad_shape[0];
                        let total_cols = grad_shape[1];
                        for (input_idx, (&idx, &sz)) in inputs.iter().zip(sizes.iter()).enumerate() {
                            let mut piece = Vec::with_capacity(nrows * sz);
                            let col_offset: usize = sizes[..input_idx].iter().sum();
                            for row in 0..nrows {
                                let row_start = row * total_cols + col_offset;
                                piece.extend_from_slice(&grad_data[row_start..row_start + sz]);
                            }
                            let gt = Tensor::from_vec_unchecked(piece, &[nrows, sz]);
                            accumulate_grad(&mut grads, idx, &gt);
                        }
                    } else {
                        let mut offset = 0usize;
                        for (&idx, &sz) in inputs.iter().zip(sizes.iter()) {
                            let piece_len = sz * grad_data.len() / grad_shape[axis];
                            let piece = grad_data[offset..offset + piece_len].to_vec();
                            let mut piece_shape = grad_shape.clone();
                            piece_shape[axis] = sz;
                            let gt = Tensor::from_vec_unchecked(piece, &piece_shape);
                            accumulate_grad(&mut grads, idx, &gt);
                            offset += piece_len;
                        }
                    }
                }
                GradOp::GatherOp { input, ref indices, axis } => {
                    let input_shape = self.nodes[input].borrow().tensor.shape().to_vec();
                    let input_len: usize = input_shape.iter().product();
                    let mut scatter = vec![0.0_f64; input_len];
                    let grad_data = grad.to_vec();
                    if self.nodes[input].borrow().tensor.ndim() == 1 {
                        for (gi, &idx) in indices.iter().enumerate() {
                            scatter[idx] += grad_data[gi];
                        }
                    } else if axis == 0 && self.nodes[input].borrow().tensor.ndim() == 2 {
                        let cols = input_shape[1];
                        for (gi, &idx) in indices.iter().enumerate() {
                            for c in 0..cols {
                                scatter[idx * cols + c] += grad_data[gi * cols + c];
                            }
                        }
                    } else {
                        for (gi, &idx) in indices.iter().enumerate() {
                            scatter[idx] += grad_data[gi];
                        }
                    }
                    let grad_a = Tensor::from_vec_unchecked(scatter, &input_shape);
                    accumulate_grad(&mut grads, input, &grad_a);
                }
            }
        }
    }

    pub fn jacobian(&mut self, output_idx: usize, param_idx: usize) -> Tensor {
        let output_shape = self.nodes[output_idx].borrow().tensor.shape().to_vec();
        let output_dim: usize = output_shape.iter().product();
        let param_shape = self.nodes[param_idx].borrow().tensor.shape().to_vec();
        let param_dim: usize = param_shape.iter().product();

        let mut jac_data = vec![0.0_f64; output_dim * param_dim];

        for i in 0..output_dim {
            self.zero_grad();

            let mut seed = vec![0.0_f64; output_dim];
            seed[i] = 1.0;
            let seed_tensor = Tensor::from_vec_unchecked(seed, &output_shape);

            self.backward_with_seed(output_idx, &seed_tensor);

            let grad = self.nodes[param_idx].borrow().grad.clone();
            if let Some(g) = grad {
                let g_vec = g.to_vec();
                for j in 0..param_dim {
                    jac_data[i * param_dim + j] = g_vec[j];
                }
            }
        }

        Tensor::from_vec_unchecked(jac_data, &[output_dim, param_dim])
    }

    pub fn hessian_diag(&mut self, loss_idx: usize, param_idx: usize, eps: f64) -> Tensor {
        let param_shape = self.nodes[param_idx].borrow().tensor.shape().to_vec();
        let param_dim: usize = param_shape.iter().product();
        let original = self.nodes[param_idx].borrow().tensor.to_vec();
        let mut hess_diag = vec![0.0_f64; param_dim];

        for i in 0..param_dim {
            let mut plus = original.clone();
            plus[i] += eps;
            self.nodes[param_idx].borrow_mut().tensor =
                Tensor::from_vec_unchecked(plus, &param_shape);
            self.zero_grad();
            self.backward(loss_idx);
            let grad_plus = self.nodes[param_idx]
                .borrow()
                .grad
                .as_ref()
                .map(|g| g.to_vec()[i])
                .unwrap_or(0.0);

            let mut minus = original.clone();
            minus[i] -= eps;
            self.nodes[param_idx].borrow_mut().tensor =
                Tensor::from_vec_unchecked(minus, &param_shape);
            self.zero_grad();
            self.backward(loss_idx);
            let grad_minus = self.nodes[param_idx]
                .borrow()
                .grad
                .as_ref()
                .map(|g| g.to_vec()[i])
                .unwrap_or(0.0);

            hess_diag[i] = (grad_plus - grad_minus) / (2.0 * eps);
        }

        self.nodes[param_idx].borrow_mut().tensor =
            Tensor::from_vec_unchecked(original, &param_shape);

        Tensor::from_vec_unchecked(hess_diag, &param_shape)
    }

    pub fn hessian(&mut self, loss_idx: usize, param_idx: usize) -> Tensor {
        let eps = 1e-5;
        let param_shape = self.nodes[param_idx].borrow().tensor.shape().to_vec();
        let param_dim: usize = param_shape.iter().product();
        let original = self.nodes[param_idx].borrow().tensor.to_vec();
        let mut hess_data = vec![0.0_f64; param_dim * param_dim];

        for i in 0..param_dim {
            let mut plus = original.clone();
            plus[i] += eps;
            self.nodes[param_idx].borrow_mut().tensor =
                Tensor::from_vec_unchecked(plus, &param_shape);
            self.reforward(param_idx + 1, loss_idx);
            self.zero_grad();
            self.backward(loss_idx);
            let grad_plus: Vec<f64> = self.nodes[param_idx]
                .borrow()
                .grad
                .as_ref()
                .map(|g| g.to_vec())
                .unwrap_or_else(|| vec![0.0; param_dim]);

            let mut minus = original.clone();
            minus[i] -= eps;
            self.nodes[param_idx].borrow_mut().tensor =
                Tensor::from_vec_unchecked(minus, &param_shape);
            self.reforward(param_idx + 1, loss_idx);
            self.zero_grad();
            self.backward(loss_idx);
            let grad_minus: Vec<f64> = self.nodes[param_idx]
                .borrow()
                .grad
                .as_ref()
                .map(|g| g.to_vec())
                .unwrap_or_else(|| vec![0.0; param_dim]);

            for j in 0..param_dim {
                hess_data[i * param_dim + j] = (grad_plus[j] - grad_minus[j]) / (2.0 * eps);
            }
        }

        self.nodes[param_idx].borrow_mut().tensor =
            Tensor::from_vec_unchecked(original, &param_shape);
        self.reforward(param_idx + 1, loss_idx);

        Tensor::from_vec_unchecked(hess_data, &[param_dim, param_dim])
    }

    fn reforward(&mut self, start: usize, end: usize) {
        for node_i in start..=end {
            let new_tensor = {
                let node = self.nodes[node_i].borrow();
                match &node.op {
                    GradOp::Input | GradOp::Parameter => continue,
                    GradOp::Add(a, b) => {
                        let at = self.nodes[*a].borrow().tensor.clone();
                        let bt = self.nodes[*b].borrow().tensor.clone();
                        at.add_unchecked(&bt)
                    }
                    GradOp::Sub(a, b) => {
                        let at = self.nodes[*a].borrow().tensor.clone();
                        let bt = self.nodes[*b].borrow().tensor.clone();
                        at.sub_unchecked(&bt)
                    }
                    GradOp::Mul(a, b) => {
                        let at = self.nodes[*a].borrow().tensor.clone();
                        let bt = self.nodes[*b].borrow().tensor.clone();
                        at.mul_elem_unchecked(&bt)
                    }
                    GradOp::Div(a, b) => {
                        let at = self.nodes[*a].borrow().tensor.clone();
                        let bt = self.nodes[*b].borrow().tensor.clone();
                        at.div_elem_unchecked(&bt)
                    }
                    GradOp::Neg(a) => {
                        self.nodes[*a].borrow().tensor.neg()
                    }
                    GradOp::ScalarMul(a, s) => {
                        let s = *s;
                        self.nodes[*a].borrow().tensor.scalar_mul(s)
                    }
                    GradOp::MatMul(a, b) => {
                        let at = self.nodes[*a].borrow().tensor.clone();
                        let bt = self.nodes[*b].borrow().tensor.clone();
                        at.matmul_unchecked(&bt)
                    }
                    GradOp::Sum(a) => {
                        let s = self.nodes[*a].borrow().tensor.sum();
                        Tensor::from_vec_unchecked(vec![s], &[1])
                    }
                    GradOp::Mean(a) => {
                        let m = self.nodes[*a].borrow().tensor.mean();
                        Tensor::from_vec_unchecked(vec![m], &[1])
                    }
                    GradOp::Exp(a) => {
                        let (data, shape) = {
                            let t = self.nodes[*a].borrow();
                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
                        };
                        Tensor::from_vec_unchecked(data.iter().map(|x| x.exp()).collect(), &shape)
                    }
                    GradOp::Ln(a) => {
                        let (data, shape) = {
                            let t = self.nodes[*a].borrow();
                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
                        };
                        Tensor::from_vec_unchecked(data.iter().map(|x| x.ln()).collect(), &shape)
                    }
                    GradOp::Sin(a) => {
                        let (data, shape) = {
                            let t = self.nodes[*a].borrow();
                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
                        };
                        Tensor::from_vec_unchecked(data.iter().map(|x| x.sin()).collect(), &shape)
                    }
                    GradOp::Cos(a) => {
                        let (data, shape) = {
                            let t = self.nodes[*a].borrow();
                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
                        };
                        Tensor::from_vec_unchecked(data.iter().map(|x| x.cos()).collect(), &shape)
                    }
                    GradOp::Sqrt(a) => {
                        let (data, shape) = {
                            let t = self.nodes[*a].borrow();
                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
                        };
                        Tensor::from_vec_unchecked(data.iter().map(|x| x.sqrt()).collect(), &shape)
                    }
                    GradOp::Pow(a, n) => {
                        let n = *n;
                        let (data, shape) = {
                            let t = self.nodes[*a].borrow();
                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
                        };
                        Tensor::from_vec_unchecked(data.iter().map(|x| x.powf(n)).collect(), &shape)
                    }
                    GradOp::Sigmoid(a) => {
                        let (data, shape) = {
                            let t = self.nodes[*a].borrow();
                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
                        };
                        Tensor::from_vec_unchecked(
                            data.iter().map(|&x| 1.0 / (1.0 + (-x).exp())).collect(),
                            &shape,
                        )
                    }
                    GradOp::Relu(a) => {
                        let (data, shape) = {
                            let t = self.nodes[*a].borrow();
                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
                        };
                        Tensor::from_vec_unchecked(
                            data.iter().map(|&x| if x > 0.0 { x } else { 0.0 }).collect(),
                            &shape,
                        )
                    }
                    GradOp::TanhAct(a) => {
                        let (data, shape) = {
                            let t = self.nodes[*a].borrow();
                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
                        };
                        Tensor::from_vec_unchecked(data.iter().map(|x| x.tanh()).collect(), &shape)
                    }
                    GradOp::Abs(a) => {
                        let (data, shape) = {
                            let t = self.nodes[*a].borrow();
                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
                        };
                        Tensor::from_vec_unchecked(data.iter().map(|x| x.abs()).collect(), &shape)
                    }
                    GradOp::Clamp { input, min, max } => {
                        let min = *min;
                        let max = *max;
                        let (data, shape) = {
                            let t = self.nodes[*input].borrow();
                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
                        };
                        Tensor::from_vec_unchecked(
                            data.iter().map(|&x| x.max(min).min(max)).collect(),
                            &shape,
                        )
                    }
                    GradOp::Reshape { input, .. } => {
                        let current_shape = node.tensor.shape().to_vec();
                        let data = self.nodes[*input].borrow().tensor.to_vec();
                        Tensor::from_vec_unchecked(data, &current_shape)
                    }
                    GradOp::TransposeOp(a) => {
                        self.nodes[*a].borrow().tensor.transpose()
                    }
                    _ => node.tensor.clone(),
                }
            };
            self.nodes[node_i].borrow_mut().tensor = new_tensor;
        }
    }

    pub fn double_backward(&mut self, loss_idx: usize, param_idx: usize) -> Tensor {
        let eps = 1e-5;
        let param_shape = self.nodes[param_idx].borrow().tensor.shape().to_vec();
        let param_dim: usize = param_shape.iter().product();
        let original = self.nodes[param_idx].borrow().tensor.to_vec();
        let mut diag = vec![0.0_f64; param_dim];

        for i in 0..param_dim {
            let mut plus = original.clone();
            plus[i] += eps;
            self.nodes[param_idx].borrow_mut().tensor =
                Tensor::from_vec_unchecked(plus, &param_shape);
            self.reforward(param_idx + 1, loss_idx);
            self.zero_grad();
            self.backward(loss_idx);
            let grad_plus = self.nodes[param_idx]
                .borrow()
                .grad
                .as_ref()
                .map(|g| g.to_vec()[i])
                .unwrap_or(0.0);

            let mut minus = original.clone();
            minus[i] -= eps;
            self.nodes[param_idx].borrow_mut().tensor =
                Tensor::from_vec_unchecked(minus, &param_shape);
            self.reforward(param_idx + 1, loss_idx);
            self.zero_grad();
            self.backward(loss_idx);
            let grad_minus = self.nodes[param_idx]
                .borrow()
                .grad
                .as_ref()
                .map(|g| g.to_vec()[i])
                .unwrap_or(0.0);

            diag[i] = (grad_plus - grad_minus) / (2.0 * eps);
        }

        self.nodes[param_idx].borrow_mut().tensor =
            Tensor::from_vec_unchecked(original, &param_shape);
        self.reforward(param_idx + 1, loss_idx);

        Tensor::from_vec_unchecked(diag, &param_shape)
    }

    pub fn vmap_forward(&mut self, input_idx: usize, batch_data: &[Tensor]) -> Vec<usize> {
        let mut result_indices = Vec::with_capacity(batch_data.len());

        let graph_len = self.nodes.len();

        for batch_tensor in batch_data {
            self.nodes[input_idx].borrow_mut().tensor = batch_tensor.clone();

            for node_i in (input_idx + 1)..graph_len {
                let new_tensor = {
                    let node = self.nodes[node_i].borrow();
                    match &node.op {
                        GradOp::Add(a, b) => {
                            let at = self.nodes[*a].borrow().tensor.clone();
                            let bt = self.nodes[*b].borrow().tensor.clone();
                            at.add_unchecked(&bt)
                        }
                        GradOp::Sub(a, b) => {
                            let at = self.nodes[*a].borrow().tensor.clone();
                            let bt = self.nodes[*b].borrow().tensor.clone();
                            at.sub_unchecked(&bt)
                        }
                        GradOp::Mul(a, b) => {
                            let at = self.nodes[*a].borrow().tensor.clone();
                            let bt = self.nodes[*b].borrow().tensor.clone();
                            at.mul_elem_unchecked(&bt)
                        }
                        GradOp::Div(a, b) => {
                            let at = self.nodes[*a].borrow().tensor.clone();
                            let bt = self.nodes[*b].borrow().tensor.clone();
                            at.div_elem_unchecked(&bt)
                        }
                        GradOp::Neg(a) => {
                            self.nodes[*a].borrow().tensor.neg()
                        }
                        GradOp::ScalarMul(a, s) => {
                            self.nodes[*a].borrow().tensor.scalar_mul(*s)
                        }
                        GradOp::MatMul(a, b) => {
                            let at = self.nodes[*a].borrow().tensor.clone();
                            let bt = self.nodes[*b].borrow().tensor.clone();
                            at.matmul_unchecked(&bt)
                        }
                        GradOp::Sum(a) => {
                            let s = self.nodes[*a].borrow().tensor.sum();
                            let shape = vec![1usize];
                            Tensor::from_vec_unchecked(vec![s], &shape)
                        }
                        GradOp::Mean(a) => {
                            let m = self.nodes[*a].borrow().tensor.mean();
                            Tensor::from_vec_unchecked(vec![m], &[1])
                        }
                        GradOp::Exp(a) => {
                            let data = self.nodes[*a].borrow().tensor.to_vec();
                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
                            Tensor::from_vec_unchecked(
                                data.iter().map(|x| x.exp()).collect(),
                                &shape,
                            )
                        }
                        GradOp::Ln(a) => {
                            let data = self.nodes[*a].borrow().tensor.to_vec();
                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
                            Tensor::from_vec_unchecked(
                                data.iter().map(|x| x.ln()).collect(),
                                &shape,
                            )
                        }
                        GradOp::Sin(a) => {
                            let data = self.nodes[*a].borrow().tensor.to_vec();
                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
                            Tensor::from_vec_unchecked(
                                data.iter().map(|x| x.sin()).collect(),
                                &shape,
                            )
                        }
                        GradOp::Cos(a) => {
                            let data = self.nodes[*a].borrow().tensor.to_vec();
                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
                            Tensor::from_vec_unchecked(
                                data.iter().map(|x| x.cos()).collect(),
                                &shape,
                            )
                        }
                        GradOp::Sqrt(a) => {
                            let data = self.nodes[*a].borrow().tensor.to_vec();
                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
                            Tensor::from_vec_unchecked(
                                data.iter().map(|x| x.sqrt()).collect(),
                                &shape,
                            )
                        }
                        GradOp::Pow(a, n) => {
                            let data = self.nodes[*a].borrow().tensor.to_vec();
                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
                            Tensor::from_vec_unchecked(
                                data.iter().map(|x| x.powf(*n)).collect(),
                                &shape,
                            )
                        }
                        GradOp::Sigmoid(a) => {
                            let data = self.nodes[*a].borrow().tensor.to_vec();
                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
                            Tensor::from_vec_unchecked(
                                data.iter().map(|&x| 1.0 / (1.0 + (-x).exp())).collect(),
                                &shape,
                            )
                        }
                        GradOp::Relu(a) => {
                            let data = self.nodes[*a].borrow().tensor.to_vec();
                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
                            Tensor::from_vec_unchecked(
                                data.iter().map(|&x| if x > 0.0 { x } else { 0.0 }).collect(),
                                &shape,
                            )
                        }
                        GradOp::TanhAct(a) => {
                            let data = self.nodes[*a].borrow().tensor.to_vec();
                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
                            Tensor::from_vec_unchecked(
                                data.iter().map(|x| x.tanh()).collect(),
                                &shape,
                            )
                        }
                        GradOp::Abs(a) => {
                            let data = self.nodes[*a].borrow().tensor.to_vec();
                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
                            Tensor::from_vec_unchecked(
                                data.iter().map(|x| x.abs()).collect(),
                                &shape,
                            )
                        }
                        GradOp::Clamp { input, min, max } => {
                            let data = self.nodes[*input].borrow().tensor.to_vec();
                            let shape = self.nodes[*input].borrow().tensor.shape().to_vec();
                            Tensor::from_vec_unchecked(
                                data.iter().map(|&x| x.max(*min).min(*max)).collect(),
                                &shape,
                            )
                        }
                        GradOp::Reshape { input, .. } => {
                            let data = self.nodes[*input].borrow().tensor.to_vec();
                            let shape = node.tensor.shape().to_vec();
                            Tensor::from_vec_unchecked(data, &shape)
                        }
                        GradOp::TransposeOp(a) => {
                            self.nodes[*a].borrow().tensor.transpose()
                        }
                        _ => node.tensor.clone(),
                    }
                };
                self.nodes[node_i].borrow_mut().tensor = new_tensor;
            }

            let output_tensor = self.nodes[graph_len - 1].borrow().tensor.clone();
            let snapshot_idx = self.nodes.len();
            self.nodes.push(Rc::new(RefCell::new(GradNode {
                op: GradOp::Input,
                tensor: output_tensor,
                grad: None,
            })));
            result_indices.push(snapshot_idx);
        }

        result_indices
    }
1953
1954 pub fn backward_with_seed(&mut self, loss_idx: usize, seed: &Tensor) {
1956 let n = self.nodes.len();
1957 let mut grads: Vec<Option<Tensor>> = vec![None; n];
1958 grads[loss_idx] = Some(seed.clone());
1959
1960 for i in (0..n).rev() {
1961 let grad = match grads[i].take() {
1962 Some(g) => g,
1963 None => continue,
1964 };
1965
1966 let node = self.nodes[i].borrow();
1967 if let Some(ref _param_grad) = node.grad {
1968 drop(node);
1970 let new_grad = {
1971 let n = self.nodes[i].borrow();
1972 if let Some(ref existing) = n.grad {
1973 if existing.to_vec().iter().all(|&x| x == 0.0) {
1974 grad.clone()
1975 } else {
1976 existing.add_unchecked(&grad)
1977 }
1978 } else {
1979 grad.clone()
1980 }
1981 };
1982 self.nodes[i].borrow_mut().grad = Some(new_grad);
1983 } else {
1984 drop(node);
1985 }
1986
1987 let node = self.nodes[i].borrow();
1989 let node_tensor = node.tensor.clone();
1990 match &node.op {
1991 GradOp::Input | GradOp::Parameter => {}
1992 GradOp::Add(a, b) => {
1993 accumulate_grad(&mut grads, *a, &grad);
1994 accumulate_grad(&mut grads, *b, &grad);
1995 }
1996 GradOp::Sub(a, b) => {
1997 accumulate_grad(&mut grads, *a, &grad);
1998 accumulate_grad(&mut grads, *b, &grad.neg());
1999 }
2000 GradOp::Mul(a, b) => {
2001 let a_val = self.nodes[*a].borrow().tensor.clone();
2002 let b_val = self.nodes[*b].borrow().tensor.clone();
2003 accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&b_val));
2004 accumulate_grad(&mut grads, *b, &grad.mul_elem_unchecked(&a_val));
2005 }
2006 GradOp::Div(a, b) => {
2007 let a_val = self.nodes[*a].borrow().tensor.clone();
2008 let b_val = self.nodes[*b].borrow().tensor.clone();
2009 let grad_a = grad.div_elem_unchecked(&b_val);
2010 let neg_a_over_b2 = a_val.neg().div_elem_unchecked(
2011 &b_val.mul_elem_unchecked(&b_val),
2012 );
2013 let grad_b = grad.mul_elem_unchecked(&neg_a_over_b2);
2014 accumulate_grad(&mut grads, *a, &grad_a);
2015 accumulate_grad(&mut grads, *b, &grad_b);
2016 }
2017 GradOp::Neg(a) => {
2018 accumulate_grad(&mut grads, *a, &grad.neg());
2019 }
2020 GradOp::ScalarMul(a, s) => {
2021 accumulate_grad(&mut grads, *a, &grad.scalar_mul(*s));
2022 }
2023 GradOp::Exp(a) => {
2024 accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&node_tensor));
2025 }
2026 GradOp::Ln(a) => {
2027 let a_val = self.nodes[*a].borrow().tensor.clone();
2028 let inv = Tensor::from_vec_unchecked(
2029 a_val.to_vec().iter().map(|&x| 1.0 / x).collect(),
2030 a_val.shape(),
2031 );
2032 accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&inv));
2033 }
2034 GradOp::Sin(a) => {
2035 let a_val = self.nodes[*a].borrow().tensor.clone();
2036 let cos_a = Tensor::from_vec_unchecked(
2037 a_val.to_vec().iter().map(|&x| x.cos()).collect(),
2038 a_val.shape(),
2039 );
2040 accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&cos_a));
2041 }
2042 GradOp::Cos(a) => {
2043 let a_val = self.nodes[*a].borrow().tensor.clone();
2044 let neg_sin = Tensor::from_vec_unchecked(
2045 a_val.to_vec().iter().map(|&x| -x.sin()).collect(),
2046 a_val.shape(),
2047 );
2048 accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&neg_sin));
2049 }
2050 GradOp::Sqrt(a) => {
2051 let inv2sqrt = Tensor::from_vec_unchecked(
2052 node_tensor.to_vec().iter().map(|&x| 0.5 / x).collect(),
2053 node_tensor.shape(),
2054 );
2055 accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&inv2sqrt));
2056 }
2057 GradOp::Pow(a, exp) => {
2058 let a_val = self.nodes[*a].borrow().tensor.clone();
2059 let local = Tensor::from_vec_unchecked(
2060 a_val.to_vec().iter().map(|&x| exp * x.powf(exp - 1.0)).collect(),
2061 a_val.shape(),
2062 );
2063 accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&local));
2064 }
2065 GradOp::Sigmoid(a) => {
2066 let sig = &node_tensor;
2067 let one_minus = Tensor::from_vec_unchecked(
2068 sig.to_vec().iter().map(|&x| 1.0 - x).collect(),
2069 sig.shape(),
2070 );
2071 let local = sig.mul_elem_unchecked(&one_minus);
2072 accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&local));
2073 }
2074 GradOp::Relu(a) => {
2075 let a_val = self.nodes[*a].borrow().tensor.clone();
2076 let mask = Tensor::from_vec_unchecked(
2077 a_val.to_vec().iter().map(|&x| if x > 0.0 { 1.0 } else { 0.0 }).collect(),
2078 a_val.shape(),
2079 );
2080 accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&mask));
2081 }
2082 GradOp::TanhAct(a) => {
2083 let one_minus_sq = Tensor::from_vec_unchecked(
2084 node_tensor.to_vec().iter().map(|&x| 1.0 - x * x).collect(),
2085 node_tensor.shape(),
2086 );
2087 accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&one_minus_sq));
2088 }
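// MatMul: for C = A * B with upstream gradient G, dA = G * B^T and dB = A^T * G.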
2089 GradOp::MatMul(a, b) => {
2090 let a_val = self.nodes[*a].borrow().tensor.clone();
2091 let b_val = self.nodes[*b].borrow().tensor.clone();
2092 accumulate_grad(&mut grads, *a, &grad.matmul_unchecked(&b_val.transpose()));
2093 accumulate_grad(&mut grads, *b, &a_val.transpose().matmul_unchecked(&grad));
2094 }
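// Sum broadcasts the upstream scalar to every input element; Mean does the
// same but divides by the element count.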
2095 GradOp::Sum(a) => {
2096 let a_shape = self.nodes[*a].borrow().tensor.shape().to_vec();
2097 let grad_val = grad.to_vec()[0];
2098 let expanded = Tensor::from_vec_unchecked(
2099 vec![grad_val; a_shape.iter().product()],
2100 &a_shape,
2101 );
2102 accumulate_grad(&mut grads, *a, &expanded);
2103 }
2104 GradOp::Mean(a) => {
2105 let a_shape = self.nodes[*a].borrow().tensor.shape().to_vec();
2106 let n_elem = a_shape.iter().product::<usize>() as f64;
2107 let grad_val = grad.to_vec()[0] / n_elem;
2108 let expanded = Tensor::from_vec_unchecked(
2109 vec![grad_val; a_shape.iter().product()],
2110 &a_shape,
2111 );
2112 accumulate_grad(&mut grads, *a, &expanded);
2113 }
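// StructField / MapLookup: the parent tensor is a flat concatenation of
// equally sized chunks, so the child's gradient is scattered back into the
// chunk it was read from and everything else stays zero.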
2114 GradOp::StructField { parent, field_index, total_fields } => {
2115 let parent_shape = self.nodes[*parent].borrow().tensor.shape().to_vec();
2116 let parent_n: usize = parent_shape.iter().product();
2117 let chunk = parent_n / total_fields;
2118 let start = field_index * chunk;
2119 let mut parent_grad = vec![0.0_f64; parent_n];
2120 let g_vec = grad.to_vec();
2121 for (j, &gv) in g_vec.iter().enumerate() {
2122 parent_grad[start + j] = gv;
2123 }
2124 let pg = Tensor::from_vec_unchecked(parent_grad, &parent_shape);
2125 accumulate_grad(&mut grads, *parent, &pg);
2126 }
2127 GradOp::MapLookup { map_node, key_index, total_keys } => {
2128 let map_shape = self.nodes[*map_node].borrow().tensor.shape().to_vec();
2129 let map_n: usize = map_shape.iter().product();
2130 let chunk = map_n / total_keys;
2131 let start = key_index * chunk;
2132 let mut map_grad = vec![0.0_f64; map_n];
2133 let g_vec = grad.to_vec();
2134 for (j, &gv) in g_vec.iter().enumerate() {
2135 map_grad[start + j] = gv;
2136 }
2137 let mg = Tensor::from_vec_unchecked(map_grad, &map_shape);
2138 accumulate_grad(&mut grads, *map_node, &mg);
2139 }
2140 GradOp::Abs(a) => {
2142 let a_val = self.nodes[*a].borrow().tensor.clone();
2143 let sign = Tensor::from_vec_unchecked(
2144 a_val.to_vec().iter().map(|&x| {
2145 if x > 0.0 { 1.0 } else if x < 0.0 { -1.0 } else { 0.0 }
2146 }).collect(),
2147 a_val.shape(),
2148 );
2149 accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&sign));
2150 }
2151 GradOp::Log2(a) => {
2152 let a_val = self.nodes[*a].borrow().tensor.clone();
2153 let ln2 = std::f64::consts::LN_2;
2154 let local = Tensor::from_vec_unchecked(
2155 a_val.to_vec().iter().map(|&x| 1.0 / (x * ln2)).collect(),
2156 a_val.shape(),
2157 );
2158 accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&local));
2159 }
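// Softmax: with y = softmax(x) and upstream g, dx_i = y_i * (g_i - sum_j g_j * y_j).
// The dot product uses Kahan summation to keep the reduction error small and
// the result reproducible.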
2160 GradOp::Softmax(a) => {
2161 use cjc_repro::KahanAccumulatorF64;
2162 let sm = &node_tensor;
2163 let sm_data = sm.to_vec();
2164 let grad_data = grad.to_vec();
2165 let mut dot_acc = KahanAccumulatorF64::new();
2166 for (&g, &s) in grad_data.iter().zip(sm_data.iter()) {
2167 dot_acc.add(g * s);
2168 }
2169 let dot = dot_acc.finalize();
2170 let grad_input: Vec<f64> = sm_data.iter().zip(grad_data.iter())
2171 .map(|(&s, &g)| s * (g - dot))
2172 .collect();
2173 let grad_a = Tensor::from_vec_unchecked(grad_input, sm.shape());
2174 accumulate_grad(&mut grads, *a, &grad_a);
2175 }
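// Cross-entropy fused with softmax: d(loss)/d(logits) = upstream * (softmax(logits) - targets).
// Logits are max-shifted before exponentiation for numerical stability.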
2176 GradOp::CrossEntropy { logits, targets } => {
2177 use cjc_repro::KahanAccumulatorF64;
2178 let logits_val = self.nodes[*logits].borrow().tensor.clone();
2179 let targets_val = self.nodes[*targets].borrow().tensor.clone();
2180 let logits_data = logits_val.to_vec();
2181 let targets_data = targets_val.to_vec();
2182 let max_val = logits_data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
2183 let exp_shifted: Vec<f64> = logits_data.iter().map(|&x| (x - max_val).exp()).collect();
2184 let mut sum_acc = KahanAccumulatorF64::new();
2185 for &v in &exp_shifted {
2186 sum_acc.add(v);
2187 }
2188 let sum_exp = sum_acc.finalize();
2189 let softmax: Vec<f64> = exp_shifted.iter().map(|&e| e / sum_exp).collect();
2190 let upstream = grad.to_vec()[0];
2191 let grad_logits: Vec<f64> = softmax.iter().zip(targets_data.iter())
2192 .map(|(&s, &t)| upstream * (s - t))
2193 .collect();
2194 let gl = Tensor::from_vec_unchecked(grad_logits, logits_val.shape());
2195 accumulate_grad(&mut grads, *logits, &gl);
2196 }
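// LayerNorm (and the BatchNorm arm below, which normalizes the whole tensor
// the same way): with x_hat the normalized output and g the upstream gradient,
// dx_i = (g_i - mean(g) - x_hat_i * mean(g * x_hat)) / sqrt(var + 1e-5).
// These ops carry no affine scale/shift parameters.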
2197 GradOp::LayerNorm(a) => {
2198 use cjc_repro::KahanAccumulatorF64;
2199 let x_hat = &node_tensor;
2200 let x_hat_data = x_hat.to_vec();
2201 let grad_data = grad.to_vec();
2202 let n = x_hat_data.len() as f64;
2203 let a_val = self.nodes[*a].borrow().tensor.clone();
2204 let a_data = a_val.to_vec();
2205 let mut mean_acc = KahanAccumulatorF64::new();
2206 for &v in &a_data { mean_acc.add(v); }
2207 let mean = mean_acc.finalize() / n;
2208 let mut var_acc = KahanAccumulatorF64::new();
2209 for &v in &a_data { let d = v - mean; var_acc.add(d * d); }
2210 let var = var_acc.finalize() / n;
2211 let std_val = (var + 1e-5).sqrt();
2212 let mut mg_acc = KahanAccumulatorF64::new();
2213 for &g in &grad_data { mg_acc.add(g); }
2214 let mean_grad = mg_acc.finalize() / n;
2215 let mut mgx_acc = KahanAccumulatorF64::new();
2216 for (&g, &xh) in grad_data.iter().zip(x_hat_data.iter()) { mgx_acc.add(g * xh); }
2217 let mean_grad_xhat = mgx_acc.finalize() / n;
2218 let dx: Vec<f64> = grad_data.iter().zip(x_hat_data.iter())
2219 .map(|(&g, &xh)| (g - mean_grad - xh * mean_grad_xhat) / std_val)
2220 .collect();
2221 accumulate_grad(&mut grads, *a, &Tensor::from_vec_unchecked(dx, a_val.shape()));
2222 }
2223 GradOp::BatchNorm(a) => {
2224 use cjc_repro::KahanAccumulatorF64;
2225 let x_hat = &node_tensor;
2226 let x_hat_data = x_hat.to_vec();
2227 let grad_data = grad.to_vec();
2228 let n = x_hat_data.len() as f64;
2229 let a_val = self.nodes[*a].borrow().tensor.clone();
2230 let a_data = a_val.to_vec();
2231 let mut mean_acc = KahanAccumulatorF64::new();
2232 for &v in &a_data { mean_acc.add(v); }
2233 let mean = mean_acc.finalize() / n;
2234 let mut var_acc = KahanAccumulatorF64::new();
2235 for &v in &a_data { let d = v - mean; var_acc.add(d * d); }
2236 let var = var_acc.finalize() / n;
2237 let std_val = (var + 1e-5).sqrt();
2238 let mut mg_acc = KahanAccumulatorF64::new();
2239 for &g in &grad_data { mg_acc.add(g); }
2240 let mean_grad = mg_acc.finalize() / n;
2241 let mut mgx_acc = KahanAccumulatorF64::new();
2242 for (&g, &xh) in grad_data.iter().zip(x_hat_data.iter()) { mgx_acc.add(g * xh); }
2243 let mean_grad_xhat = mgx_acc.finalize() / n;
2244 let dx: Vec<f64> = grad_data.iter().zip(x_hat_data.iter())
2245 .map(|(&g, &xh)| (g - mean_grad - xh * mean_grad_xhat) / std_val)
2246 .collect();
2247 accumulate_grad(&mut grads, *a, &Tensor::from_vec_unchecked(dx, a_val.shape()));
2248 }
2249 GradOp::Clamp { input, min, max } => {
2250 let a_val = self.nodes[*input].borrow().tensor.clone();
2251 let mask = Tensor::from_vec_unchecked(
2252 a_val.to_vec().iter().map(|&x| {
2253 if x >= *min && x <= *max { 1.0 } else { 0.0 }
2254 }).collect(),
2255 a_val.shape(),
2256 );
2257 accumulate_grad(&mut grads, *input, &grad.mul_elem_unchecked(&mask));
2258 }
2259 GradOp::Where { cond, on_true, on_false } => {
2260 let cond_data = self.nodes[*cond].borrow().tensor.to_vec();
2261 let grad_data = grad.to_vec();
2262 let shape = grad.shape().to_vec();
2263 let grad_true: Vec<f64> = cond_data.iter().zip(grad_data.iter())
2264 .map(|(&c, &g)| if c != 0.0 { g } else { 0.0 }).collect();
2265 let grad_false: Vec<f64> = cond_data.iter().zip(grad_data.iter())
2266 .map(|(&c, &g)| if c != 0.0 { 0.0 } else { g }).collect();
2267 accumulate_grad(&mut grads, *on_true, &Tensor::from_vec_unchecked(grad_true, &shape));
2268 accumulate_grad(&mut grads, *on_false, &Tensor::from_vec_unchecked(grad_false, &shape));
2269 }
2270 GradOp::Reshape { input, ref original_shape } => {
2271 let grad_a = grad.reshape(original_shape).expect("Reshape backward: shape mismatch");
2272 accumulate_grad(&mut grads, *input, &grad_a);
2273 }
2274 GradOp::TransposeOp(a) => {
2275 accumulate_grad(&mut grads, *a, &grad.transpose());
2276 }
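// Concatenation: the upstream gradient is sliced back into per-input pieces
// along `axis`. 1-D and 2-D (axis 0 or 1) layouts are handled explicitly; the
// fallback branch assumes a contiguous split along the leading axis.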
2277 GradOp::CatOp { ref inputs, axis, ref sizes } => {
2278 let grad_data = grad.to_vec();
2279 let grad_shape = grad.shape().to_vec();
2280 let ndim = grad_shape.len();
2281 if ndim == 1 {
2282 let mut offset = 0usize;
2283 for (idx, &sz) in inputs.iter().zip(sizes.iter()) {
2284 let piece = grad_data[offset..offset + sz].to_vec();
2285 accumulate_grad(&mut grads, *idx, &Tensor::from_vec_unchecked(piece, &[sz]));
2286 offset += sz;
2287 }
2288 } else if ndim == 2 && *axis == 0 {
2289 let cols = grad_shape[1];
2290 let mut row_offset = 0usize;
2291 for (idx, &sz) in inputs.iter().zip(sizes.iter()) {
2292 let start = row_offset * cols;
2293 let end = start + sz * cols;
2294 let piece = grad_data[start..end].to_vec();
2295 accumulate_grad(&mut grads, *idx, &Tensor::from_vec_unchecked(piece, &[sz, cols]));
2296 row_offset += sz;
2297 }
2298 } else if ndim == 2 && *axis == 1 {
2299 let nrows = grad_shape[0];
2300 let total_cols = grad_shape[1];
2301 for (input_idx, (idx, &sz)) in inputs.iter().zip(sizes.iter()).enumerate() {
2302 let mut piece = Vec::with_capacity(nrows * sz);
2303 let col_offset: usize = sizes[..input_idx].iter().sum();
2304 for row in 0..nrows {
2305 let row_start = row * total_cols + col_offset;
2306 piece.extend_from_slice(&grad_data[row_start..row_start + sz]);
2307 }
2308 accumulate_grad(&mut grads, *idx, &Tensor::from_vec_unchecked(piece, &[nrows, sz]));
2309 }
2310 } else {
2311 let mut offset = 0usize;
2312 for (idx, &sz) in inputs.iter().zip(sizes.iter()) {
2313 let piece_len = sz * grad_data.len() / grad_shape[*axis];
2314 let piece = grad_data[offset..offset + piece_len].to_vec();
2315 let mut piece_shape = grad_shape.clone();
2316 piece_shape[*axis] = sz;
2317 accumulate_grad(&mut grads, *idx, &Tensor::from_vec_unchecked(piece, &piece_shape));
2318 offset += piece_len;
2319 }
2320 }
2321 }
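// Gather: backward is a scatter-add into the input's shape, so indices that
// were selected more than once accumulate their gradients.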
2322 GradOp::GatherOp { input, ref indices, axis } => {
2323 let input_shape = self.nodes[*input].borrow().tensor.shape().to_vec();
2324 let input_len: usize = input_shape.iter().product();
2325 let mut scatter = vec![0.0_f64; input_len];
2326 let grad_data = grad.to_vec();
2327 if self.nodes[*input].borrow().tensor.ndim() == 1 {
2328 for (gi, &idx) in indices.iter().enumerate() {
2329 scatter[idx] += grad_data[gi];
2330 }
2331 } else if *axis == 0 && self.nodes[*input].borrow().tensor.ndim() == 2 {
2332 let cols = input_shape[1];
2333 for (gi, &idx) in indices.iter().enumerate() {
2334 for c in 0..cols {
2335 scatter[idx * cols + c] += grad_data[gi * cols + c];
2336 }
2337 }
2338 } else {
2339 for (gi, &idx) in indices.iter().enumerate() {
2340 scatter[idx] += grad_data[gi];
2341 }
2342 }
2343 accumulate_grad(&mut grads, *input, &Tensor::from_vec_unchecked(scatter, &input_shape));
2344 }
2345 }
2346 }
2347 }
2348}
2349
2350fn accumulate_grad(grads: &mut [Option<Tensor>], idx: usize, grad: &Tensor) {
2351 if let Some(existing) = &grads[idx] {
2352 grads[idx] = Some(existing.add_unchecked(grad));
2353 } else {
2354 grads[idx] = Some(grad.clone());
2355 }
2356}
2357
2358impl Default for GradGraph {
2361 fn default() -> Self {
2362 Self::new()
2363 }
2364}
2365
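/// Central-difference gradient check: compares `expected_grad` against
/// (f(x + eps) - f(x - eps)) / (2 * eps), whose truncation error is O(eps^2).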
2366pub fn check_grad_finite_diff<F>(
2370 f: F,
2371 x: f64,
2372 expected_grad: f64,
2373 eps: f64,
2374 tol: f64,
2375) -> bool
2376where
2377 F: Fn(f64) -> f64,
2378{
2379 let fd_grad = (f(x + eps) - f(x - eps)) / (2.0 * eps);
2380 (fd_grad - expected_grad).abs() < tol
2381}
2382
2383#[cfg(test)]
2384mod tests {
2385 use super::*;
2386
2387 #[test]
2390 fn test_dual_add() {
2391 let a = Dual::variable(3.0);
2392 let b = Dual::constant(2.0);
2393 let c = a + b;
2394 assert_eq!(c.value, 5.0);
2395 assert_eq!(c.deriv, 1.0);
2396 }
2397
2398 #[test]
2399 fn test_dual_mul() {
2400 let a = Dual::variable(3.0);
2401 let b = Dual::constant(2.0);
2402 let c = a * b;
2403 assert_eq!(c.value, 6.0);
2404 assert_eq!(c.deriv, 2.0);
2405 }
2406
2407 #[test]
2408 fn test_dual_chain_rule() {
2409 let x = Dual::variable(3.0);
2411 let result = x.clone() * x.clone() + Dual::constant(2.0) * x + Dual::one();
2412 assert_eq!(result.value, 16.0);
2413 assert_eq!(result.deriv, 8.0);
2414 }
2415
2416 #[test]
2417 fn test_dual_exp() {
2418 let x = Dual::variable(1.0);
2419 let result = x.exp();
2420 assert!((result.value - std::f64::consts::E).abs() < 1e-10);
2421 assert!((result.deriv - std::f64::consts::E).abs() < 1e-10);
2422 }
2423
2424 #[test]
2425 fn test_dual_sin_cos() {
2426 let x = Dual::variable(0.0);
2427 let sin_x = x.clone().sin();
2428 let cos_x = x.cos();
2429 assert!((sin_x.value - 0.0).abs() < 1e-10);
2430 assert!((sin_x.deriv - 1.0).abs() < 1e-10);
2431 assert!((cos_x.value - 1.0).abs() < 1e-10);
2432 assert!((cos_x.deriv - 0.0).abs() < 1e-10);
2433 }
2434
2435 #[test]
2436 fn test_dual_div() {
2437 let a = Dual::variable(6.0);
2438 let b = Dual::constant(3.0);
2439 let c = a / b;
2440 assert_eq!(c.value, 2.0);
2441 assert!((c.deriv - 1.0 / 3.0).abs() < 1e-10);
2442 }
2443
2444 #[test]
2445 fn test_finite_diff_validation() {
2446 let f = |x: f64| x * x;
2448 assert!(check_grad_finite_diff(f, 3.0, 6.0, 1e-7, 1e-5));
2449 }
2450
2451 #[test]
2454 fn test_reverse_add() {
2455 let mut g = GradGraph::new();
2456 let a = g.parameter(Tensor::from_vec_unchecked(vec![3.0], &[1]));
2457 let b = g.parameter(Tensor::from_vec_unchecked(vec![2.0], &[1]));
2458 let c = g.add(a, b);
2459
2460 g.backward(c);
2461
2462 let ga = g.grad(a).unwrap();
2463 let gb = g.grad(b).unwrap();
2464 assert!((ga.to_vec()[0] - 1.0).abs() < 1e-10);
2465 assert!((gb.to_vec()[0] - 1.0).abs() < 1e-10);
2466 }
2467
2468 #[test]
2469 fn test_reverse_mul() {
2470 let mut g = GradGraph::new();
2471 let a = g.parameter(Tensor::from_vec_unchecked(vec![3.0], &[1]));
2472 let b = g.parameter(Tensor::from_vec_unchecked(vec![2.0], &[1]));
2473 let c = g.mul(a, b);
2474
2475 g.backward(c);
2476
2477 let ga = g.grad(a).unwrap();
2478 let gb = g.grad(b).unwrap();
2479 assert!((ga.to_vec()[0] - 2.0).abs() < 1e-10);
2480 assert!((gb.to_vec()[0] - 3.0).abs() < 1e-10);
2481 }
2482
2483 #[test]
2484 fn test_reverse_matmul_gradient() {
2485 let mut g = GradGraph::new();
2486
2487 let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]));
2489 let b = g.parameter(Tensor::from_vec_unchecked(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]));
2490 let c = g.matmul(a, b);
2491 let loss = g.sum(c);
2492
2493 g.backward(loss);
2494
2495 let ga = g.grad(a).unwrap();
2497 let ga_data = ga.to_vec();
2498 assert!((ga_data[0] - 11.0).abs() < 1e-10);
2503 assert!((ga_data[1] - 15.0).abs() < 1e-10);
2504 }
2505
2506 #[test]
2507 fn test_reverse_mean_gradient() {
2508 let mut g = GradGraph::new();
2509 let a = g.parameter(Tensor::from_vec_unchecked(vec![2.0, 4.0, 6.0, 8.0], &[4]));
2510 let loss = g.mean(a);
2511
2512 g.backward(loss);
2513
2514 let ga = g.grad(a).unwrap();
2515 let ga_data = ga.to_vec();
2516 for &v in &ga_data {
2518 assert!((v - 0.25).abs() < 1e-10);
2519 }
2520 }
2521
2522 #[test]
2525 fn test_reverse_sin() {
2526 let mut g = GradGraph::new();
2527 let a = g.parameter(Tensor::from_vec_unchecked(vec![0.0], &[1]));
2528 let b = g.sin(a);
2529 g.backward(b);
2530 let ga = g.grad(a).unwrap();
2531 assert!((ga.to_vec()[0] - 1.0).abs() < 1e-10);
2533 }
2534
2535 #[test]
2536 fn test_reverse_cos() {
2537 let mut g = GradGraph::new();
2538 let a = g.parameter(Tensor::from_vec_unchecked(vec![0.0], &[1]));
2539 let b = g.cos(a);
2540 g.backward(b);
2541 let ga = g.grad(a).unwrap();
2542 assert!(ga.to_vec()[0].abs() < 1e-10);
2544 }
2545
2546 #[test]
2547 fn test_reverse_sqrt() {
2548 let mut g = GradGraph::new();
2549 let a = g.parameter(Tensor::from_vec_unchecked(vec![4.0], &[1]));
2550 let b = g.sqrt(a);
2551 g.backward(b);
2552 let ga = g.grad(a).unwrap();
2553 assert!((ga.to_vec()[0] - 0.25).abs() < 1e-10);
2555 }
2556
2557 #[test]
2558 fn test_reverse_pow() {
2559 let mut g = GradGraph::new();
2560 let a = g.parameter(Tensor::from_vec_unchecked(vec![2.0], &[1]));
2561 let b = g.pow(a, 3.0);
2562 g.backward(b);
2563 let ga = g.grad(a).unwrap();
2564 assert!((ga.to_vec()[0] - 12.0).abs() < 1e-10);
2566 }
2567
2568 #[test]
2569 fn test_reverse_sigmoid() {
2570 let mut g = GradGraph::new();
2571 let a = g.parameter(Tensor::from_vec_unchecked(vec![0.0], &[1]));
2572 let b = g.sigmoid(a);
2573 g.backward(b);
2574 let ga = g.grad(a).unwrap();
2575 assert!((ga.to_vec()[0] - 0.25).abs() < 1e-10);
2577 }
2578
2579 #[test]
2580 fn test_reverse_relu_positive() {
2581 let mut g = GradGraph::new();
2582 let a = g.parameter(Tensor::from_vec_unchecked(vec![3.0], &[1]));
2583 let b = g.relu(a);
2584 g.backward(b);
2585 let ga = g.grad(a).unwrap();
2586 assert!((ga.to_vec()[0] - 1.0).abs() < 1e-10);
2588 }
2589
2590 #[test]
2591 fn test_reverse_relu_negative() {
2592 let mut g = GradGraph::new();
2593 let a = g.parameter(Tensor::from_vec_unchecked(vec![-2.0], &[1]));
2594 let b = g.relu(a);
2595 g.backward(b);
2596 let ga = g.grad(a).unwrap();
2597 assert!(ga.to_vec()[0].abs() < 1e-10);
2599 }
2600
2601 #[test]
2602 fn test_reverse_tanh() {
2603 let mut g = GradGraph::new();
2604 let a = g.parameter(Tensor::from_vec_unchecked(vec![0.0], &[1]));
2605 let b = g.tanh_act(a);
2606 g.backward(b);
2607 let ga = g.grad(a).unwrap();
2608 assert!((ga.to_vec()[0] - 1.0).abs() < 1e-10);
2610 }
2611
2612 #[test]
2613 fn test_reverse_sin_cos_chain() {
2614 let mut g = GradGraph::new();
2617 let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0], &[1]));
2618 let c = g.cos(a);
2619 let s = g.sin(c);
2620 g.backward(s);
2621 let ga = g.grad(a).unwrap();
2622 let expected = 1.0_f64.cos().cos() * (-1.0_f64.sin());
2623 assert!((ga.to_vec()[0] - expected).abs() < 1e-10, "got {}, expected {expected}", ga.to_vec()[0]);
2624 }
2625
2626 #[test]
2627 fn test_reverse_sigmoid_sum() {
2628 let mut g = GradGraph::new();
2630 let a = g.parameter(Tensor::from_vec_unchecked(vec![0.0, 1.0, -1.0], &[3]));
2631 let s = g.sigmoid(a);
2632 let loss = g.sum(s);
2633 g.backward(loss);
2634 let ga = g.grad(a).unwrap();
2635 let ga_data = ga.to_vec();
2636 let sig1 = 1.0 / (1.0 + (-1.0_f64).exp());
2638 let sig_neg1 = 1.0 / (1.0 + 1.0_f64.exp());
2639 assert!((ga_data[0] - 0.25).abs() < 1e-10);
2640 assert!((ga_data[1] - sig1 * (1.0 - sig1)).abs() < 1e-10);
2641 assert!((ga_data[2] - sig_neg1 * (1.0 - sig_neg1)).abs() < 1e-10);
2642 }
2643
2644 #[test]
2645 fn test_b8_determinism() {
2646 let mut g1 = GradGraph::new();
2647 let a1 = g1.parameter(Tensor::from_vec_unchecked(vec![1.5], &[1]));
2648 let s1 = g1.sin(a1);
2649 g1.backward(s1);
2650 let ga1 = g1.grad(a1).unwrap().to_vec()[0];
2651
2652 let mut g2 = GradGraph::new();
2653 let a2 = g2.parameter(Tensor::from_vec_unchecked(vec![1.5], &[1]));
2654 let s2 = g2.sin(a2);
2655 g2.backward(s2);
2656 let ga2 = g2.grad(a2).unwrap().to_vec()[0];
2657
2658 assert_eq!(ga1.to_bits(), ga2.to_bits());
2659 }
2660
2661 #[test]
2662 fn test_reverse_mse_loss() {
2663 let mut g = GradGraph::new();
2665
2666 let w = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 1.0], &[2, 1]));
2667 let x = g.input(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]));
2668 let target = g.input(Tensor::from_vec_unchecked(vec![3.0, 7.0], &[2, 1]));
2669
2670 let pred = g.matmul(x, w);
2671 let diff = g.sub(pred, target);
2672 let sq = g.mul(diff, diff);
2673 let loss = g.mean(sq);
2674
2675 let loss_val = g.value(loss);
2676 g.backward(loss);
2677
2678 let gw = g.grad(w).unwrap();
2679
2680 assert!(loss_val.is_finite());
2682 assert_eq!(gw.to_vec().len(), 2);
2683 for &v in &gw.to_vec() {
2684 assert!(v.is_finite());
2685 }
2686 }
2687
2688 #[test]
2691 fn test_reverse_div() {
2692 let mut g = GradGraph::new();
2694 let a = g.parameter(Tensor::from_vec_unchecked(vec![6.0], &[1]));
2695 let b = g.input(Tensor::from_vec_unchecked(vec![2.0], &[1]));
2696 let c = g.div(a, b);
2697 g.backward(c);
2698 let ga = g.grad(a).unwrap();
2699 assert!((ga.to_vec()[0] - 0.5).abs() < 1e-10);
2700 }
2701
2702 #[test]
2703 fn test_reverse_neg() {
2704 let mut g = GradGraph::new();
2706 let a = g.parameter(Tensor::from_vec_unchecked(vec![3.0], &[1]));
2707 let c = g.neg(a);
2708 g.backward(c);
2709 let ga = g.grad(a).unwrap();
2710 assert!((ga.to_vec()[0] - (-1.0)).abs() < 1e-10);
2711 }
2712
2713 #[test]
2714 fn test_reverse_scalar_mul() {
2715 let mut g = GradGraph::new();
2717 let a = g.parameter(Tensor::from_vec_unchecked(vec![2.0], &[1]));
2718 let c = g.scalar_mul(a, 3.0);
2719 g.backward(c);
2720 let ga = g.grad(a).unwrap();
2721 assert!((ga.to_vec()[0] - 3.0).abs() < 1e-10);
2722 }
2723
2724 #[test]
2725 fn test_reverse_exp() {
2726 let mut g = GradGraph::new();
2728 let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0], &[1]));
2729 let c = g.exp(a);
2730 g.backward(c);
2731 let ga = g.grad(a).unwrap();
2732 assert!((ga.to_vec()[0] - std::f64::consts::E).abs() < 1e-10);
2733 }
2734
2735 #[test]
2736 fn test_reverse_ln() {
2737 let mut g = GradGraph::new();
2739 let a = g.parameter(Tensor::from_vec_unchecked(vec![2.0], &[1]));
2740 let c = g.ln(a);
2741 g.backward(c);
2742 let ga = g.grad(a).unwrap();
2743 assert!((ga.to_vec()[0] - 0.5).abs() < 1e-10);
2744 }
2745
2746 #[test]
2749 fn test_reverse_abs_positive() {
2750 let mut g = GradGraph::new();
2751 let a = g.parameter(Tensor::from_vec_unchecked(vec![3.0], &[1]));
2752 let b = g.abs(a);
2753 g.backward(b);
2754 let ga = g.grad(a).unwrap();
2755 assert!((ga.to_vec()[0] - 1.0).abs() < 1e-10);
2757 }
2758
2759 #[test]
2760 fn test_reverse_abs_negative() {
2761 let mut g = GradGraph::new();
2762 let a = g.parameter(Tensor::from_vec_unchecked(vec![-2.5], &[1]));
2763 let b = g.abs(a);
2764 g.backward(b);
2765 let ga = g.grad(a).unwrap();
2766 assert!((ga.to_vec()[0] - (-1.0)).abs() < 1e-10);
2768 }
2769
2770 #[test]
2771 fn test_reverse_abs_zero() {
2772 let mut g = GradGraph::new();
2773 let a = g.parameter(Tensor::from_vec_unchecked(vec![0.0], &[1]));
2774 let b = g.abs(a);
2775 g.backward(b);
2776 let ga = g.grad(a).unwrap();
2777 assert!(ga.to_vec()[0].abs() < 1e-10);
2779 }
2780
2781 #[test]
2782 fn test_reverse_abs_vector() {
2783 let mut g = GradGraph::new();
2784 let a = g.parameter(Tensor::from_vec_unchecked(vec![-1.0, 2.0, 0.0, -3.0], &[4]));
2785 let b = g.abs(a);
2786 let loss = g.sum(b);
2787 g.backward(loss);
2788 let ga = g.grad(a).unwrap();
2789 let expected = vec![-1.0, 1.0, 0.0, -1.0];
2790 for (i, (&got, &exp)) in ga.to_vec().iter().zip(expected.iter()).enumerate() {
2791 assert!((got - exp).abs() < 1e-10, "abs grad[{i}]: got {got}, expected {exp}");
2792 }
2793 }
2794
2795 #[test]
2796 fn test_reverse_log2() {
2797 let mut g = GradGraph::new();
2798 let a = g.parameter(Tensor::from_vec_unchecked(vec![4.0], &[1]));
2799 let b = g.log2(a);
2800 g.backward(b);
2801 let ga = g.grad(a).unwrap();
2802 let expected = 1.0 / (4.0 * std::f64::consts::LN_2);
2804 assert!((ga.to_vec()[0] - expected).abs() < 1e-10);
2805 }
2806
2807 #[test]
2808 fn test_reverse_log2_vector() {
2809 let mut g = GradGraph::new();
2810 let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 8.0], &[3]));
2811 let b = g.log2(a);
2812 let loss = g.sum(b);
2813 g.backward(loss);
2814 let ga = g.grad(a).unwrap();
2815 let ln2 = std::f64::consts::LN_2;
2816 let expected = vec![1.0 / (1.0 * ln2), 1.0 / (2.0 * ln2), 1.0 / (8.0 * ln2)];
2817 for (i, (&got, &exp)) in ga.to_vec().iter().zip(expected.iter()).enumerate() {
2818 assert!((got - exp).abs() < 1e-10, "log2 grad[{i}]: got {got}, expected {exp}");
2819 }
2820 }
2821
2822 #[test]
2823 fn test_softmax_forward() {
2824 let mut g = GradGraph::new();
2825 let a = g.input(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0], &[3]));
2826 let b = g.softmax(a);
2827 let sm = g.tensor(b);
2828 let sm_data = sm.to_vec();
2829 let sum: f64 = sm_data.iter().sum();
2831 assert!((sum - 1.0).abs() < 1e-10);
2832 assert!(sm_data[2] > sm_data[1]);
2834 assert!(sm_data[1] > sm_data[0]);
2835 }
2836
2837 #[test]
2838 fn test_reverse_softmax() {
2839 let mut g = GradGraph::new();
2841 let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0], &[3]));
2842 let b = g.softmax(a);
2843 let loss = g.sum(b);
2844 g.backward(loss);
2845 let ga = g.grad(a).unwrap();
2846 for &v in &ga.to_vec() {
2848 assert!(v.abs() < 1e-10, "softmax sum grad should be 0, got {v}");
2849 }
2850 }
2851
2852 #[test]
2853 fn test_reverse_softmax_single_element() {
2854 let mut g = GradGraph::new();
2856 let a = g.parameter(Tensor::from_vec_unchecked(vec![2.0, 1.0], &[2]));
2857 let b = g.softmax(a);
2858 let loss = g.sum(b);
2861 g.backward(loss);
2862 let ga = g.grad(a).unwrap();
2863 for &v in &ga.to_vec() {
2864 assert!(v.abs() < 1e-10);
2865 }
2866 }
2867
2868 #[test]
2869 fn test_cross_entropy_forward() {
2870 let mut g = GradGraph::new();
2871 let logits = g.input(Tensor::from_vec_unchecked(vec![2.0, 1.0, 0.1], &[3]));
2872 let targets = g.input(Tensor::from_vec_unchecked(vec![1.0, 0.0, 0.0], &[3]));
2873 let ce = g.cross_entropy(logits, targets);
2874 let loss_val = g.value(ce);
2875 assert!(loss_val > 0.0, "CE loss should be positive");
2876 assert!(loss_val.is_finite(), "CE loss should be finite");
2877 }
2878
2879 #[test]
2880 fn test_reverse_cross_entropy() {
2881 let mut g = GradGraph::new();
2882 let logits = g.parameter(Tensor::from_vec_unchecked(vec![2.0, 1.0, 0.1], &[3]));
2883 let targets = g.input(Tensor::from_vec_unchecked(vec![1.0, 0.0, 0.0], &[3]));
2884 let ce = g.cross_entropy(logits, targets);
2885 g.backward(ce);
2886 let ga = g.grad(logits).unwrap();
2887 let ga_data = ga.to_vec();
2888 assert!(ga_data[0] < 0.0, "CE grad for correct class should be negative");
2892 assert!(ga_data[1] > 0.0, "CE grad for incorrect class should be positive");
2893 assert!(ga_data[2] > 0.0, "CE grad for incorrect class should be positive");
2894 let sum: f64 = ga_data.iter().sum();
2896 assert!(sum.abs() < 1e-10, "CE grad should sum to 0, got {sum}");
2897 }
2898
2899 #[test]
2900 fn test_layer_norm_forward() {
2901 let mut g = GradGraph::new();
2902 let a = g.input(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0], &[4]));
2903 let b = g.layer_norm(a);
2904 let normed = g.tensor(b).to_vec();
2905 let mean: f64 = normed.iter().sum::<f64>() / normed.len() as f64;
2907 assert!(mean.abs() < 1e-5, "LayerNorm mean should be ~0, got {mean}");
2908 let var: f64 = normed.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / normed.len() as f64;
2909 assert!((var - 1.0).abs() < 0.01, "LayerNorm variance should be ~1, got {var}");
2910 }
2911
2912 #[test]
2913 fn test_reverse_layer_norm() {
2914 let mut g = GradGraph::new();
2915 let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0], &[4]));
2916 let b = g.layer_norm(a);
2917 let loss = g.sum(b);
2918 g.backward(loss);
2919 let ga = g.grad(a).unwrap();
2920 for &v in &ga.to_vec() {
2922 assert!(v.is_finite(), "LayerNorm grad should be finite");
2923 }
2924 let eps = 1e-5;
2926 let input_data = vec![1.0, 2.0, 3.0, 4.0];
2927 for i in 0..4 {
2928 let mut plus = input_data.clone();
2929 plus[i] += eps;
2930 let mut g_plus = GradGraph::new();
2931 let a_plus = g_plus.input(Tensor::from_vec_unchecked(plus, &[4]));
2932 let b_plus = g_plus.layer_norm(a_plus);
2933 let loss_plus = g_plus.sum(b_plus);
2934 let val_plus = g_plus.value(loss_plus);
2935
2936 let mut minus = input_data.clone();
2937 minus[i] -= eps;
2938 let mut g_minus = GradGraph::new();
2939 let a_minus = g_minus.input(Tensor::from_vec_unchecked(minus, &[4]));
2940 let b_minus = g_minus.layer_norm(a_minus);
2941 let loss_minus = g_minus.sum(b_minus);
2942 let val_minus = g_minus.value(loss_minus);
2943
2944 let fd_grad = (val_plus - val_minus) / (2.0 * eps);
2945 let ad_grad = ga.to_vec()[i];
2946 assert!(
2947 (fd_grad - ad_grad).abs() < 1e-4,
2948 "LayerNorm FD check failed at [{i}]: fd={fd_grad}, ad={ad_grad}"
2949 );
2950 }
2951 }
2952
2953 #[test]
2954 fn test_batch_norm_forward() {
2955 let mut g = GradGraph::new();
2956 let a = g.input(Tensor::from_vec_unchecked(vec![2.0, 4.0, 6.0, 8.0], &[4]));
2957 let b = g.batch_norm(a);
2958 let normed = g.tensor(b).to_vec();
2959 let mean: f64 = normed.iter().sum::<f64>() / normed.len() as f64;
2960 assert!(mean.abs() < 1e-5, "BatchNorm mean should be ~0, got {mean}");
2961 }
2962
2963 #[test]
2964 fn test_reverse_batch_norm() {
2965 let mut g = GradGraph::new();
2966 let a = g.parameter(Tensor::from_vec_unchecked(vec![2.0, 4.0, 6.0, 8.0], &[4]));
2967 let b = g.batch_norm(a);
2968 let loss = g.sum(b);
2969 g.backward(loss);
2970 let ga = g.grad(a).unwrap();
2971 for &v in &ga.to_vec() {
2972 assert!(v.is_finite(), "BatchNorm grad should be finite");
2973 }
2974 }
2975
2976 #[test]
2977 fn test_reverse_clamp_in_range() {
2978 let mut g = GradGraph::new();
2979 let a = g.parameter(Tensor::from_vec_unchecked(vec![1.5], &[1]));
2980 let b = g.clamp(a, 0.0, 3.0);
2981 g.backward(b);
2982 let ga = g.grad(a).unwrap();
2983 assert!((ga.to_vec()[0] - 1.0).abs() < 1e-10);
2985 }
2986
2987 #[test]
2988 fn test_reverse_clamp_out_of_range() {
2989 let mut g = GradGraph::new();
2990 let a = g.parameter(Tensor::from_vec_unchecked(vec![5.0], &[1]));
2991 let b = g.clamp(a, 0.0, 3.0);
2992 g.backward(b);
2993 let ga = g.grad(a).unwrap();
2994 assert!(ga.to_vec()[0].abs() < 1e-10);
2996 }
2997
2998 #[test]
2999 fn test_reverse_clamp_vector() {
3000 let mut g = GradGraph::new();
3001 let a = g.parameter(Tensor::from_vec_unchecked(vec![-1.0, 0.5, 2.0, 4.0], &[4]));
3002 let b = g.clamp(a, 0.0, 3.0);
3003 let loss = g.sum(b);
3004 g.backward(loss);
3005 let ga = g.grad(a).unwrap();
3006 let expected = vec![0.0, 1.0, 1.0, 0.0];
3007 for (i, (&got, &exp)) in ga.to_vec().iter().zip(expected.iter()).enumerate() {
3008 assert!((got - exp).abs() < 1e-10, "clamp grad[{i}]: got {got}, expected {exp}");
3009 }
3010 }
3011
3012 #[test]
3013 fn test_reverse_where_cond() {
3014 let mut g = GradGraph::new();
3015 let cond = g.input(Tensor::from_vec_unchecked(vec![1.0, 0.0, 1.0], &[3]));
3016 let a = g.parameter(Tensor::from_vec_unchecked(vec![10.0, 20.0, 30.0], &[3]));
3017 let b = g.parameter(Tensor::from_vec_unchecked(vec![100.0, 200.0, 300.0], &[3]));
3018 let w = g.where_cond(cond, a, b);
3019 let result = g.tensor(w).to_vec();
3021 assert!((result[0] - 10.0).abs() < 1e-10);
3022 assert!((result[1] - 200.0).abs() < 1e-10);
3023 assert!((result[2] - 30.0).abs() < 1e-10);
3024 let loss = g.sum(w);
3025 g.backward(loss);
3026 let ga = g.grad(a).unwrap().to_vec();
3027 let gb = g.grad(b).unwrap().to_vec();
3028 assert!((ga[0] - 1.0).abs() < 1e-10);
3030 assert!(ga[1].abs() < 1e-10);
3031 assert!((ga[2] - 1.0).abs() < 1e-10);
3032 assert!(gb[0].abs() < 1e-10);
3033 assert!((gb[1] - 1.0).abs() < 1e-10);
3034 assert!(gb[2].abs() < 1e-10);
3035 }
3036
3037 #[test]
3038 fn test_reverse_reshape() {
3039 let mut g = GradGraph::new();
3040 let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]));
3041 let b = g.reshape(a, &[3, 2]);
3042 let loss = g.sum(b);
3043 g.backward(loss);
3044 let ga = g.grad(a).unwrap();
3045 assert_eq!(ga.shape(), &[2, 3]);
3047 for &v in &ga.to_vec() {
3048 assert!((v - 1.0).abs() < 1e-10);
3049 }
3050 }
3051
3052 #[test]
3053 fn test_reverse_transpose_op() {
3054 let mut g = GradGraph::new();
3055 let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]));
3056 let b = g.transpose_op(a);
3057 assert_eq!(g.tensor(b).shape(), &[3, 2]);
3059 let loss = g.sum(b);
3060 g.backward(loss);
3061 let ga = g.grad(a).unwrap();
3062 assert_eq!(ga.shape(), &[2, 3]);
3064 for &v in &ga.to_vec() {
3065 assert!((v - 1.0).abs() < 1e-10);
3066 }
3067 }
3068
3069 #[test]
3070 fn test_reverse_cat_1d() {
3071 let mut g = GradGraph::new();
3072 let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0], &[2]));
3073 let b = g.parameter(Tensor::from_vec_unchecked(vec![3.0, 4.0, 5.0], &[3]));
3074 let c = g.cat(&[a, b], 0);
3075 let result = g.tensor(c).to_vec();
3077 assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
3078 let loss = g.sum(c);
3079 g.backward(loss);
3080 let ga = g.grad(a).unwrap().to_vec();
3081 let gb = g.grad(b).unwrap().to_vec();
3082 assert_eq!(ga, vec![1.0, 1.0]);
3084 assert_eq!(gb, vec![1.0, 1.0, 1.0]);
3085 }
3086
3087 #[test]
3088 fn test_reverse_cat_2d_axis0() {
3089 let mut g = GradGraph::new();
3090 let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0], &[1, 2]));
3091 let b = g.parameter(Tensor::from_vec_unchecked(vec![3.0, 4.0, 5.0, 6.0], &[2, 2]));
3092 let c = g.cat(&[a, b], 0);
3093 assert_eq!(g.tensor(c).shape(), &[3, 2]);
3094 let loss = g.sum(c);
3095 g.backward(loss);
3096 let ga = g.grad(a).unwrap();
3097 let gb = g.grad(b).unwrap();
3098 assert_eq!(ga.shape(), &[1, 2]);
3099 assert_eq!(gb.shape(), &[2, 2]);
3100 for &v in &ga.to_vec() {
3101 assert!((v - 1.0).abs() < 1e-10);
3102 }
3103 for &v in &gb.to_vec() {
3104 assert!((v - 1.0).abs() < 1e-10);
3105 }
3106 }
3107
3108 #[test]
3109 fn test_reverse_gather_1d() {
3110 let mut g = GradGraph::new();
3111 let a = g.parameter(Tensor::from_vec_unchecked(vec![10.0, 20.0, 30.0, 40.0], &[4]));
3112 let b = g.gather(a, &[1, 3], 0);
3113 let result = g.tensor(b).to_vec();
3115 assert!((result[0] - 20.0).abs() < 1e-10);
3116 assert!((result[1] - 40.0).abs() < 1e-10);
3117 let loss = g.sum(b);
3118 g.backward(loss);
3119 let ga = g.grad(a).unwrap().to_vec();
3120 assert!((ga[0]).abs() < 1e-10);
3122 assert!((ga[1] - 1.0).abs() < 1e-10);
3123 assert!((ga[2]).abs() < 1e-10);
3124 assert!((ga[3] - 1.0).abs() < 1e-10);
3125 }
3126
3127 #[test]
3128 fn test_reverse_gather_duplicate_indices() {
3129 let mut g = GradGraph::new();
3130 let a = g.parameter(Tensor::from_vec_unchecked(vec![10.0, 20.0, 30.0], &[3]));
3131 let b = g.gather(a, &[1, 1, 2], 0);
3132 let loss = g.sum(b);
3133 g.backward(loss);
3134 let ga = g.grad(a).unwrap().to_vec();
3135 assert!((ga[0]).abs() < 1e-10);
3137 assert!((ga[1] - 2.0).abs() < 1e-10);
3138 assert!((ga[2] - 1.0).abs() < 1e-10);
3139 }
3140
3141 #[test]
3142 fn test_phase8_determinism() {
3143 for _ in 0..2 {
3145 let run = || {
3146 let mut g = GradGraph::new();
3147 let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, -2.0, 3.0, -0.5], &[4]));
3148 let b = g.abs(a);
3149 let c = g.clamp(b, 0.0, 2.5);
3150 let d = g.layer_norm(c);
3151 let loss = g.sum(d);
3152 g.backward(loss);
3153 g.grad(a).unwrap().to_vec()
3154 };
3155 let r1 = run();
3156 let r2 = run();
3157 for (i, (v1, v2)) in r1.iter().zip(r2.iter()).enumerate() {
3158 assert_eq!(v1.to_bits(), v2.to_bits(), "Determinism failed at [{i}]");
3159 }
3160 }
3161 }
3162
3163 #[test]
3164 fn test_phase8_softmax_cross_entropy_chain() {
3165 let mut g = GradGraph::new();
3167 let logits = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0], &[3]));
3168 let targets = g.input(Tensor::from_vec_unchecked(vec![0.0, 0.0, 1.0], &[3]));
3169 let ce = g.cross_entropy(logits, targets);
3170 g.backward(ce);
3171 let ga = g.grad(logits).unwrap().to_vec();
3172 for &v in &ga {
3174 assert!(v.is_finite());
3175 }
3176 assert!(ga[2] < 0.0, "CE grad for correct class should be negative");
3178 }
3179
3180 #[test]
3183 fn test_double_backward_cubic() {
3184 let mut g = GradGraph::new();
3186 let x = g.parameter(Tensor::from_vec_unchecked(vec![2.0], &[1]));
3187 let x2 = g.mul(x, x);
3188 let x3 = g.mul(x2, x);
3189 let loss = g.sum(x3);
3190 let hess = g.double_backward(loss, x);
3191 assert!((hess.to_vec()[0] - 12.0).abs() < 1e-4);
3193 }
3194
3195 #[test]
3196 fn test_full_hessian_quadratic() {
3197 let mut g = GradGraph::new();
3200 let p = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 1.0], &[2]));
3201 let p2 = g.mul(p, p);
3202 let s = g.sum(p2);
3203 let hess = g.hessian(s, p);
3204 let h = hess.to_vec();
3205 assert!((h[0] - 2.0).abs() < 1e-3);
3206 assert!((h[1] - 0.0).abs() < 1e-3);
3207 assert!((h[2] - 0.0).abs() < 1e-3);
3208 assert!((h[3] - 2.0).abs() < 1e-3);
3209 }
3210
3211 #[test]
3212 fn test_vmap_forward() {
3213 let mut g = GradGraph::new();
3214 let x = g.parameter(Tensor::from_vec_unchecked(vec![1.0], &[1]));
3215 let x2 = g.mul(x, x);
3216 let loss = g.sum(x2);
3217
3218 let batch = vec![
3220 Tensor::from_vec_unchecked(vec![1.0], &[1]),
3221 Tensor::from_vec_unchecked(vec![2.0], &[1]),
3222 Tensor::from_vec_unchecked(vec![3.0], &[1]),
3223 ];
3224 let results = g.vmap_forward(x, &batch);
3225 assert!((g.value(results[0]) - 1.0).abs() < 1e-10);
3227 assert!((g.value(results[1]) - 4.0).abs() < 1e-10);
3228 assert!((g.value(results[2]) - 9.0).abs() < 1e-10);
3229 }
3230
3231 #[test]
3232 fn test_hessian_determinism() {
3233 let mut g = GradGraph::new();
3234 let p = g.parameter(Tensor::from_vec_unchecked(vec![3.0, 4.0], &[2]));
3235 let p2 = g.mul(p, p);
3236 let s = g.sum(p2);
3237 let h1 = g.hessian(s, p);
3238 g.nodes[p].borrow_mut().tensor = Tensor::from_vec_unchecked(vec![3.0, 4.0], &[2]);
3240 let h2 = g.hessian(s, p);
3241 assert_eq!(h1.to_vec(), h2.to_vec(), "Hessian must be deterministic");
3242 }
3243}