cjc_ad/lib.rs

1//! Automatic differentiation for CJC.
2//!
3//! Provides forward-mode differentiation via dual numbers and reverse-mode
4//! differentiation via a computation tape. Supports `grad()`, `jacobian()`,
5//! and gradient graph construction for ML training loops.
6
7use cjc_runtime::Tensor;
8use std::cell::RefCell;
9use std::rc::Rc;
10
11pub mod pinn;
12
13// ── Forward-Mode AD (Dual Numbers) ──────────────────────────────
14
15/// Dual number for forward-mode automatic differentiation.
16#[derive(Debug, Clone)]
17pub struct Dual {
18    pub value: f64,
19    pub deriv: f64,
20}
21
22impl Dual {
23    pub fn new(value: f64, deriv: f64) -> Self {
24        Self { value, deriv }
25    }
26
27    pub fn constant(value: f64) -> Self {
28        Self { value, deriv: 0.0 }
29    }
30
31    pub fn variable(value: f64) -> Self {
32        Self { value, deriv: 1.0 }
33    }
34
35    pub fn zero() -> Self {
36        Self {
37            value: 0.0,
38            deriv: 0.0,
39        }
40    }
41
42    pub fn one() -> Self {
43        Self {
44            value: 1.0,
45            deriv: 0.0,
46        }
47    }
48}
49
50impl std::ops::Add for Dual {
51    type Output = Dual;
52    fn add(self, rhs: Dual) -> Dual {
53        Dual {
54            value: self.value + rhs.value,
55            deriv: self.deriv + rhs.deriv,
56        }
57    }
58}
59
60impl std::ops::Sub for Dual {
61    type Output = Dual;
62    fn sub(self, rhs: Dual) -> Dual {
63        Dual {
64            value: self.value - rhs.value,
65            deriv: self.deriv - rhs.deriv,
66        }
67    }
68}
69
70impl std::ops::Mul for Dual {
71    type Output = Dual;
72    fn mul(self, rhs: Dual) -> Dual {
73        Dual {
74            value: self.value * rhs.value,
75            deriv: self.value * rhs.deriv + self.deriv * rhs.value,
76        }
77    }
78}
79
80impl std::ops::Div for Dual {
81    type Output = Dual;
82    fn div(self, rhs: Dual) -> Dual {
83        let denom = rhs.value * rhs.value;
84        Dual {
85            value: self.value / rhs.value,
86            deriv: (self.deriv * rhs.value - self.value * rhs.deriv) / denom,
87        }
88    }
89}
90
91impl std::ops::Neg for Dual {
92    type Output = Dual;
93    fn neg(self) -> Dual {
94        Dual {
95            value: -self.value,
96            deriv: -self.deriv,
97        }
98    }
99}
100
101impl Dual {
102    pub fn sin(self) -> Dual {
103        Dual {
104            value: self.value.sin(),
105            deriv: self.deriv * self.value.cos(),
106        }
107    }
108
109    pub fn cos(self) -> Dual {
110        Dual {
111            value: self.value.cos(),
112            deriv: -self.deriv * self.value.sin(),
113        }
114    }
115
116    pub fn exp(self) -> Dual {
117        let e = self.value.exp();
118        Dual {
119            value: e,
120            deriv: self.deriv * e,
121        }
122    }
123
124    pub fn ln(self) -> Dual {
125        Dual {
126            value: self.value.ln(),
127            deriv: self.deriv / self.value,
128        }
129    }
130
131    pub fn sqrt(self) -> Dual {
132        let s = self.value.sqrt();
133        Dual {
134            value: s,
135            deriv: self.deriv / (2.0 * s),
136        }
137    }
138
139    pub fn pow(self, n: f64) -> Dual {
140        Dual {
141            value: self.value.powf(n),
142            deriv: self.deriv * n * self.value.powf(n - 1.0),
143        }
144    }
145}
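
// A minimal forward-mode sketch using the `Dual` type above: seeding `deriv = 1.0`
// on the variable propagates d/dx through the overloaded operators, so for
// f(x) = x * sin(x) we expect f'(x) = sin(x) + x * cos(x).
#[cfg(test)]
mod dual_forward_mode_example {
    use super::Dual;

    #[test]
    fn derivative_of_x_sin_x() {
        let x = Dual::variable(2.0);
        let y = x.clone() * x.clone().sin();
        // Value: 2 * sin(2); derivative via the product rule: sin(2) + 2 * cos(2).
        assert!((y.value - 2.0 * 2.0_f64.sin()).abs() < 1e-12);
        assert!((y.deriv - (2.0_f64.sin() + 2.0 * 2.0_f64.cos())).abs() < 1e-12);
    }
}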
146
147// ── Reverse-Mode AD (Computational Graph) ───────────────────────
148
149/// Operation recorded in the computation graph.
150#[derive(Debug, Clone)]
151pub enum GradOp {
152    Input,
153    Parameter,
154    Add(usize, usize),
155    Sub(usize, usize),
156    Mul(usize, usize),
157    Div(usize, usize),
158    Neg(usize),
159    MatMul(usize, usize),
160    Sum(usize),
161    Mean(usize),
162    ScalarMul(usize, f64),
163    Exp(usize),
164    Ln(usize),
165    /// Gradient through struct field access: parent node, field index.
166    StructField {
167        parent: usize,
168        field_index: usize,
169        total_fields: usize,
170    },
171    /// Gradient through map lookup: map node, key index in insertion order.
172    MapLookup {
173        map_node: usize,
174        key_index: usize,
175        total_keys: usize,
176    },
177    // Phase B8: Transcendental & activation ops
178    Sin(usize),
179    Cos(usize),
180    Sqrt(usize),
181    Pow(usize, f64),
182    Sigmoid(usize),
183    Relu(usize),
184    TanhAct(usize),
185    // Phase 8: Extended AD ops
186    Abs(usize),
187    Log2(usize),
188    Softmax(usize),
189    /// Cross-entropy loss: CrossEntropy(logits, targets)
190    CrossEntropy { logits: usize, targets: usize },
191    /// Layer normalization: LayerNorm(input); the node stores the normalized output (std is recomputed in backward)
192    LayerNorm(usize),
193    /// Batch normalization: BatchNorm(input); the node stores the normalized output (std is recomputed in backward)
194    BatchNorm(usize),
195    /// Clamp to [min, max]
196    Clamp { input: usize, min: f64, max: f64 },
197    /// Conditional select: Where(condition, on_true, on_false)
198    /// condition is a tensor of 0.0/1.0 masks
199    Where { cond: usize, on_true: usize, on_false: usize },
200    /// Reshape with stored original shape for backward
201    Reshape { input: usize, original_shape: Vec<usize> },
202    /// Transpose (2-D)
203    TransposeOp(usize),
204    /// Concatenation along axis with sizes for splitting on backward
205    CatOp { inputs: Vec<usize>, axis: usize, sizes: Vec<usize> },
206    /// Gather along axis: GatherOp { input, indices, axis }
207    GatherOp { input: usize, indices: Vec<usize>, axis: usize },
208}
209
210/// A node in the reverse-mode AD graph.
211#[derive(Debug, Clone)]
212pub struct GradNode {
213    pub op: GradOp,
214    pub tensor: Tensor,
215    pub grad: Option<Tensor>,
216}
217
218/// The reverse-mode AD tape/graph.
219pub struct GradGraph {
220    pub nodes: Vec<Rc<RefCell<GradNode>>>,
221}
222
223impl GradGraph {
224    pub fn new() -> Self {
225        Self { nodes: Vec::new() }
226    }
227
228    /// Create an input node (data, no gradient).
229    pub fn input(&mut self, tensor: Tensor) -> usize {
230        let idx = self.nodes.len();
231        self.nodes.push(Rc::new(RefCell::new(GradNode {
232            op: GradOp::Input,
233            tensor,
234            grad: None,
235        })));
236        idx
237    }
238
239    /// Create a parameter node (trainable, accumulates gradients).
240    pub fn parameter(&mut self, tensor: Tensor) -> usize {
241        let idx = self.nodes.len();
242        let shape = tensor.shape().to_vec();
243        self.nodes.push(Rc::new(RefCell::new(GradNode {
244            op: GradOp::Parameter,
245            tensor,
246            grad: Some(Tensor::zeros(&shape)),
247        })));
248        idx
249    }
250
251    /// Element-wise addition.
252    pub fn add(&mut self, a: usize, b: usize) -> usize {
253        let a_t = self.nodes[a].borrow().tensor.clone();
254        let b_t = self.nodes[b].borrow().tensor.clone();
255        let result = a_t.add_unchecked(&b_t);
256        let idx = self.nodes.len();
257        self.nodes.push(Rc::new(RefCell::new(GradNode {
258            op: GradOp::Add(a, b),
259            tensor: result,
260            grad: None,
261        })));
262        idx
263    }
264
265    /// Element-wise subtraction.
266    pub fn sub(&mut self, a: usize, b: usize) -> usize {
267        let a_t = self.nodes[a].borrow().tensor.clone();
268        let b_t = self.nodes[b].borrow().tensor.clone();
269        let result = a_t.sub_unchecked(&b_t);
270        let idx = self.nodes.len();
271        self.nodes.push(Rc::new(RefCell::new(GradNode {
272            op: GradOp::Sub(a, b),
273            tensor: result,
274            grad: None,
275        })));
276        idx
277    }
278
279    /// Element-wise multiplication.
280    pub fn mul(&mut self, a: usize, b: usize) -> usize {
281        let a_t = self.nodes[a].borrow().tensor.clone();
282        let b_t = self.nodes[b].borrow().tensor.clone();
283        let result = a_t.mul_elem_unchecked(&b_t);
284        let idx = self.nodes.len();
285        self.nodes.push(Rc::new(RefCell::new(GradNode {
286            op: GradOp::Mul(a, b),
287            tensor: result,
288            grad: None,
289        })));
290        idx
291    }
292
293    /// Matrix multiplication.
294    pub fn matmul(&mut self, a: usize, b: usize) -> usize {
295        let a_t = self.nodes[a].borrow().tensor.clone();
296        let b_t = self.nodes[b].borrow().tensor.clone();
297        let result = a_t.matmul_unchecked(&b_t);
298        let idx = self.nodes.len();
299        self.nodes.push(Rc::new(RefCell::new(GradNode {
300            op: GradOp::MatMul(a, b),
301            tensor: result,
302            grad: None,
303        })));
304        idx
305    }
306
307    /// Sum all elements.
308    pub fn sum(&mut self, a: usize) -> usize {
309        let a_t = self.nodes[a].borrow().tensor.clone();
310        let s = a_t.sum();
311        let result = Tensor::from_vec_unchecked(vec![s], &[1]);
312        let idx = self.nodes.len();
313        self.nodes.push(Rc::new(RefCell::new(GradNode {
314            op: GradOp::Sum(a),
315            tensor: result,
316            grad: None,
317        })));
318        idx
319    }
320
321    /// Mean of all elements.
322    pub fn mean(&mut self, a: usize) -> usize {
323        let a_t = self.nodes[a].borrow().tensor.clone();
324        let m = a_t.mean();
325        let result = Tensor::from_vec_unchecked(vec![m], &[1]);
326        let idx = self.nodes.len();
327        self.nodes.push(Rc::new(RefCell::new(GradNode {
328            op: GradOp::Mean(a),
329            tensor: result,
330            grad: None,
331        })));
332        idx
333    }
334
335    // ── Phase B8: Transcendental & activation forward ops ──
336
337    /// Element-wise sine.
338    pub fn sin(&mut self, a: usize) -> usize {
339        let a_t = self.nodes[a].borrow().tensor.clone();
340        let data = a_t.to_vec();
341        let result = Tensor::from_vec_unchecked(
342            data.iter().map(|&x| x.sin()).collect(),
343            a_t.shape(),
344        );
345        let idx = self.nodes.len();
346        self.nodes.push(Rc::new(RefCell::new(GradNode {
347            op: GradOp::Sin(a),
348            tensor: result,
349            grad: None,
350        })));
351        idx
352    }
353
354    /// Element-wise cosine.
355    pub fn cos(&mut self, a: usize) -> usize {
356        let a_t = self.nodes[a].borrow().tensor.clone();
357        let data = a_t.to_vec();
358        let result = Tensor::from_vec_unchecked(
359            data.iter().map(|&x| x.cos()).collect(),
360            a_t.shape(),
361        );
362        let idx = self.nodes.len();
363        self.nodes.push(Rc::new(RefCell::new(GradNode {
364            op: GradOp::Cos(a),
365            tensor: result,
366            grad: None,
367        })));
368        idx
369    }
370
371    /// Element-wise square root.
372    pub fn sqrt(&mut self, a: usize) -> usize {
373        let a_t = self.nodes[a].borrow().tensor.clone();
374        let data = a_t.to_vec();
375        let result = Tensor::from_vec_unchecked(
376            data.iter().map(|&x| x.sqrt()).collect(),
377            a_t.shape(),
378        );
379        let idx = self.nodes.len();
380        self.nodes.push(Rc::new(RefCell::new(GradNode {
381            op: GradOp::Sqrt(a),
382            tensor: result,
383            grad: None,
384        })));
385        idx
386    }
387
388    /// Element-wise power with constant exponent.
389    pub fn pow(&mut self, a: usize, n: f64) -> usize {
390        let a_t = self.nodes[a].borrow().tensor.clone();
391        let data = a_t.to_vec();
392        let result = Tensor::from_vec_unchecked(
393            data.iter().map(|&x| x.powf(n)).collect(),
394            a_t.shape(),
395        );
396        let idx = self.nodes.len();
397        self.nodes.push(Rc::new(RefCell::new(GradNode {
398            op: GradOp::Pow(a, n),
399            tensor: result,
400            grad: None,
401        })));
402        idx
403    }
404
405    /// Sigmoid activation: 1 / (1 + exp(-x)).
406    pub fn sigmoid(&mut self, a: usize) -> usize {
407        let a_t = self.nodes[a].borrow().tensor.clone();
408        let data = a_t.to_vec();
409        let result = Tensor::from_vec_unchecked(
410            data.iter().map(|&x| 1.0 / (1.0 + (-x).exp())).collect(),
411            a_t.shape(),
412        );
413        let idx = self.nodes.len();
414        self.nodes.push(Rc::new(RefCell::new(GradNode {
415            op: GradOp::Sigmoid(a),
416            tensor: result,
417            grad: None,
418        })));
419        idx
420    }
421
422    /// ReLU activation: max(0, x).
423    pub fn relu(&mut self, a: usize) -> usize {
424        let a_t = self.nodes[a].borrow().tensor.clone();
425        let data = a_t.to_vec();
426        let result = Tensor::from_vec_unchecked(
427            data.iter().map(|&x| if x > 0.0 { x } else { 0.0 }).collect(),
428            a_t.shape(),
429        );
430        let idx = self.nodes.len();
431        self.nodes.push(Rc::new(RefCell::new(GradNode {
432            op: GradOp::Relu(a),
433            tensor: result,
434            grad: None,
435        })));
436        idx
437    }
438
439    /// Tanh activation.
440    pub fn tanh_act(&mut self, a: usize) -> usize {
441        let a_t = self.nodes[a].borrow().tensor.clone();
442        let data = a_t.to_vec();
443        let result = Tensor::from_vec_unchecked(
444            data.iter().map(|&x| x.tanh()).collect(),
445            a_t.shape(),
446        );
447        let idx = self.nodes.len();
448        self.nodes.push(Rc::new(RefCell::new(GradNode {
449            op: GradOp::TanhAct(a),
450            tensor: result,
451            grad: None,
452        })));
453        idx
454    }
455
456    // ── Phase 8: Extended AD forward ops ──
457
458    /// Element-wise absolute value.
459    pub fn abs(&mut self, a: usize) -> usize {
460        let a_t = self.nodes[a].borrow().tensor.clone();
461        let data = a_t.to_vec();
462        let result = Tensor::from_vec_unchecked(
463            data.iter().map(|&x| x.abs()).collect(),
464            a_t.shape(),
465        );
466        let idx = self.nodes.len();
467        self.nodes.push(Rc::new(RefCell::new(GradNode {
468            op: GradOp::Abs(a),
469            tensor: result,
470            grad: None,
471        })));
472        idx
473    }
474
475    /// Element-wise log base 2.
476    pub fn log2(&mut self, a: usize) -> usize {
477        let a_t = self.nodes[a].borrow().tensor.clone();
478        let data = a_t.to_vec();
479        let result = Tensor::from_vec_unchecked(
480            data.iter().map(|&x| x.log2()).collect(),
481            a_t.shape(),
482        );
483        let idx = self.nodes.len();
484        self.nodes.push(Rc::new(RefCell::new(GradNode {
485            op: GradOp::Log2(a),
486            tensor: result,
487            grad: None,
488        })));
489        idx
490    }
491
492    /// Softmax over all elements (the tensor is treated as a flat vector).
493    /// Numerically stable via the max-shift trick: softmax(x_i) = exp(x_i - max(x)) / sum(exp(x_j - max(x)))
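    /// A sketch of the stability benefit, assuming the `Tensor` constructors used in this crate:
    ///
    /// ```ignore
    /// let mut g = GradGraph::new();
    /// let x = g.input(Tensor::from_vec_unchecked(vec![1000.0, 1001.0], &[2]));
    /// let s = g.softmax(x);
    /// // A naive exp(1000.0) would overflow to infinity; the max-shift keeps it finite.
    /// let probs = g.tensor(s).to_vec();
    /// assert!((probs.iter().sum::<f64>() - 1.0).abs() < 1e-12);
    /// ```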
494    pub fn softmax(&mut self, a: usize) -> usize {
495        use cjc_repro::KahanAccumulatorF64;
496        let a_t = self.nodes[a].borrow().tensor.clone();
497        let data = a_t.to_vec();
498        let max_val = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
499        let exp_shifted: Vec<f64> = data.iter().map(|&x| (x - max_val).exp()).collect();
500        let mut sum_acc = KahanAccumulatorF64::new();
501        for &v in &exp_shifted {
502            sum_acc.add(v);
503        }
504        let sum_exp = sum_acc.finalize();
505        let softmax_data: Vec<f64> = exp_shifted.iter().map(|&e| e / sum_exp).collect();
506        let result = Tensor::from_vec_unchecked(softmax_data, a_t.shape());
507        let idx = self.nodes.len();
508        self.nodes.push(Rc::new(RefCell::new(GradNode {
509            op: GradOp::Softmax(a),
510            tensor: result,
511            grad: None,
512        })));
513        idx
514    }
515
516    /// Cross-entropy loss: -sum(targets * log(softmax(logits)))
517    /// Uses numerically stable log-sum-exp internally.
518    /// Returns a scalar [1] tensor.
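    /// Illustrative sketch: with uniform logits and a one-hot target, the loss is ln(2).
    ///
    /// ```ignore
    /// let mut g = GradGraph::new();
    /// let logits = g.parameter(Tensor::from_vec_unchecked(vec![0.0, 0.0], &[2]));
    /// let targets = g.input(Tensor::from_vec_unchecked(vec![1.0, 0.0], &[2]));
    /// let loss = g.cross_entropy(logits, targets);
    /// assert!((g.value(loss) - std::f64::consts::LN_2).abs() < 1e-12);
    /// ```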
519    pub fn cross_entropy(&mut self, logits: usize, targets: usize) -> usize {
520        use cjc_repro::KahanAccumulatorF64;
521        let logits_t = self.nodes[logits].borrow().tensor.clone();
522        let targets_t = self.nodes[targets].borrow().tensor.clone();
523        let logits_data = logits_t.to_vec();
524        let targets_data = targets_t.to_vec();
525        // Numerically stable: log_softmax = x_i - max(x) - log(sum(exp(x_j - max(x))))
526        let max_val = logits_data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
527        let shifted: Vec<f64> = logits_data.iter().map(|&x| x - max_val).collect();
528        let exp_shifted: Vec<f64> = shifted.iter().map(|&x| x.exp()).collect();
529        let mut sum_acc = KahanAccumulatorF64::new();
530        for &v in &exp_shifted {
531            sum_acc.add(v);
532        }
533        let log_sum_exp = sum_acc.finalize().ln();
534        let log_softmax: Vec<f64> = shifted.iter().map(|&x| x - log_sum_exp).collect();
535        // CE = -sum(targets * log_softmax)
536        let mut ce_acc = KahanAccumulatorF64::new();
537        for (t, ls) in targets_data.iter().zip(log_softmax.iter()) {
538            ce_acc.add(-t * ls);
539        }
540        let ce = ce_acc.finalize();
541        let result = Tensor::from_vec_unchecked(vec![ce], &[1]);
542        let idx = self.nodes.len();
543        self.nodes.push(Rc::new(RefCell::new(GradNode {
544            op: GradOp::CrossEntropy { logits, targets },
545            tensor: result,
546            grad: None,
547        })));
548        idx
549    }
550
551    /// Layer normalization: normalize input to zero mean and unit variance.
552    /// y = (x - mean(x)) / sqrt(var(x) + eps), where eps = 1e-5.
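    /// Sketch: the normalized output has (approximately) zero mean.
    ///
    /// ```ignore
    /// let mut g = GradGraph::new();
    /// let x = g.input(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0], &[3]));
    /// let y = g.layer_norm(x);
    /// let mean: f64 = g.tensor(y).to_vec().iter().sum::<f64>() / 3.0;
    /// assert!(mean.abs() < 1e-12);
    /// ```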
553    pub fn layer_norm(&mut self, a: usize) -> usize {
554        use cjc_repro::KahanAccumulatorF64;
555        let a_t = self.nodes[a].borrow().tensor.clone();
556        let data = a_t.to_vec();
557        let n = data.len() as f64;
558        // Mean
559        let mut mean_acc = KahanAccumulatorF64::new();
560        for &v in &data {
561            mean_acc.add(v);
562        }
563        let mean = mean_acc.finalize() / n;
564        // Variance
565        let mut var_acc = KahanAccumulatorF64::new();
566        for &v in &data {
567            let d = v - mean;
568            var_acc.add(d * d);
569        }
570        let var = var_acc.finalize() / n;
571        let eps = 1e-5;
572        let std = (var + eps).sqrt();
573        let normed: Vec<f64> = data.iter().map(|&x| (x - mean) / std).collect();
574        let result = Tensor::from_vec_unchecked(normed, a_t.shape());
575        let idx = self.nodes.len();
576        self.nodes.push(Rc::new(RefCell::new(GradNode {
577            op: GradOp::LayerNorm(a),
578            tensor: result,
579            grad: None,
580        })));
581        idx
582    }
583
584    /// Batch normalization over the whole tensor.
585    /// Currently computes a single mean and variance across all elements (identical
586    /// to `layer_norm`); per-feature statistics along the batch axis are not yet computed.
587    /// y = (x - mean(x)) / sqrt(var(x) + eps), where eps = 1e-5.
588    pub fn batch_norm(&mut self, a: usize) -> usize {
589        use cjc_repro::KahanAccumulatorF64;
590        let a_t = self.nodes[a].borrow().tensor.clone();
591        let data = a_t.to_vec();
592        let n = data.len() as f64;
593        let mut mean_acc = KahanAccumulatorF64::new();
594        for &v in &data {
595            mean_acc.add(v);
596        }
597        let mean = mean_acc.finalize() / n;
598        let mut var_acc = KahanAccumulatorF64::new();
599        for &v in &data {
600            let d = v - mean;
601            var_acc.add(d * d);
602        }
603        let var = var_acc.finalize() / n;
604        let eps = 1e-5;
605        let std = (var + eps).sqrt();
606        let normed: Vec<f64> = data.iter().map(|&x| (x - mean) / std).collect();
607        let result = Tensor::from_vec_unchecked(normed, a_t.shape());
608        let idx = self.nodes.len();
609        self.nodes.push(Rc::new(RefCell::new(GradNode {
610            op: GradOp::BatchNorm(a),
611            tensor: result,
612            grad: None,
613        })));
614        idx
615    }
616
617    /// Element-wise clamp to [min, max].
618    pub fn clamp(&mut self, a: usize, min: f64, max: f64) -> usize {
619        let a_t = self.nodes[a].borrow().tensor.clone();
620        let data = a_t.to_vec();
621        let result = Tensor::from_vec_unchecked(
622            data.iter().map(|&x| x.max(min).min(max)).collect(),
623            a_t.shape(),
624        );
625        let idx = self.nodes.len();
626        self.nodes.push(Rc::new(RefCell::new(GradNode {
627            op: GradOp::Clamp { input: a, min, max },
628            tensor: result,
629            grad: None,
630        })));
631        idx
632    }
633
634    /// Conditional select: where(cond, on_true, on_false).
635    /// cond is a tensor of 0.0/1.0 values acting as a mask.
636    pub fn where_cond(&mut self, cond: usize, on_true: usize, on_false: usize) -> usize {
637        let cond_t = self.nodes[cond].borrow().tensor.clone();
638        let true_t = self.nodes[on_true].borrow().tensor.clone();
639        let false_t = self.nodes[on_false].borrow().tensor.clone();
640        let c = cond_t.to_vec();
641        let t = true_t.to_vec();
642        let f = false_t.to_vec();
643        let result_data: Vec<f64> = c.iter().zip(t.iter().zip(f.iter()))
644            .map(|(&ci, (&ti, &fi))| if ci != 0.0 { ti } else { fi })
645            .collect();
646        let result = Tensor::from_vec_unchecked(result_data, cond_t.shape());
647        let idx = self.nodes.len();
648        self.nodes.push(Rc::new(RefCell::new(GradNode {
649            op: GradOp::Where { cond, on_true, on_false },
650            tensor: result,
651            grad: None,
652        })));
653        idx
654    }
655
656    /// Reshape a tensor. Stores the original shape for backward.
657    pub fn reshape(&mut self, a: usize, new_shape: &[usize]) -> usize {
658        let a_t = self.nodes[a].borrow().tensor.clone();
659        let original_shape = a_t.shape().to_vec();
660        let result = a_t.reshape(new_shape).expect("GradGraph::reshape: shape mismatch");
661        let idx = self.nodes.len();
662        self.nodes.push(Rc::new(RefCell::new(GradNode {
663            op: GradOp::Reshape { input: a, original_shape },
664            tensor: result,
665            grad: None,
666        })));
667        idx
668    }
669
670    /// Transpose a 2-D tensor.
671    pub fn transpose_op(&mut self, a: usize) -> usize {
672        let a_t = self.nodes[a].borrow().tensor.clone();
673        let result = a_t.transpose();
674        let idx = self.nodes.len();
675        self.nodes.push(Rc::new(RefCell::new(GradNode {
676            op: GradOp::TransposeOp(a),
677            tensor: result,
678            grad: None,
679        })));
680        idx
681    }
682
683    /// Concatenate tensors along an axis.
684    /// All input tensors must have the same shape except along the concatenation axis.
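    /// Sketch: concatenating a `[2, 3]` and a `[2, 2]` tensor along axis 1 yields shape `[2, 5]`.
    ///
    /// ```ignore
    /// let mut g = GradGraph::new();
    /// let a = g.input(Tensor::from_vec_unchecked(vec![1.0; 6], &[2, 3]));
    /// let b = g.input(Tensor::from_vec_unchecked(vec![2.0; 4], &[2, 2]));
    /// let c = g.cat(&[a, b], 1);
    /// assert_eq!(g.tensor(c).shape(), &[2, 5]);
    /// ```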
685    pub fn cat(&mut self, inputs: &[usize], axis: usize) -> usize {
686        let tensors: Vec<Tensor> = inputs.iter()
687            .map(|&i| self.nodes[i].borrow().tensor.clone())
688            .collect();
689        let sizes: Vec<usize> = tensors.iter()
690            .map(|t| t.shape()[axis])
691            .collect();
692        // Build the concatenated data along `axis`.
693        // Handles 1-D tensors and 2-D tensors (axis 0 or 1); anything else falls back to flat concatenation.
694        let mut all_data = Vec::new();
695        let mut total_along_axis = 0usize;
696        let ndim = tensors[0].ndim();
697        let mut result_shape = tensors[0].shape().to_vec();
698
699        if ndim == 1 {
700            // 1-D: just concatenate flat data
701            for t in &tensors {
702                all_data.extend(t.to_vec());
703                total_along_axis += t.shape()[0];
704            }
705            result_shape[0] = total_along_axis;
706        } else if ndim == 2 && axis == 0 {
707            // Concat along rows
708            for t in &tensors {
709                all_data.extend(t.to_vec());
710                total_along_axis += t.shape()[0];
711            }
712            result_shape[0] = total_along_axis;
713        } else if ndim == 2 && axis == 1 {
714            // Concat along columns
715            let nrows = tensors[0].shape()[0];
716            for row in 0..nrows {
717                for t in &tensors {
718                    let cols = t.shape()[1];
719                    let row_data = t.to_vec();
720                    let start = row * cols;
721                    all_data.extend_from_slice(&row_data[start..start + cols]);
722                }
723            }
724            total_along_axis = sizes.iter().sum();
725            result_shape[1] = total_along_axis;
726        } else {
727            // General case: flatten, cat, reshape (deterministic but limited)
728            for t in &tensors {
729                all_data.extend(t.to_vec());
730                total_along_axis += t.shape()[axis];
731            }
732            result_shape[axis] = total_along_axis;
733        }
734
735        let result = Tensor::from_vec_unchecked(all_data, &result_shape);
736        let input_vec = inputs.to_vec();
737        let idx = self.nodes.len();
738        self.nodes.push(Rc::new(RefCell::new(GradNode {
739            op: GradOp::CatOp { inputs: input_vec, axis, sizes },
740            tensor: result,
741            grad: None,
742        })));
743        idx
744    }
745
746    /// Gather elements along an axis using indices.
747    /// For a 1-D tensor, returns tensor[indices].
748    pub fn gather(&mut self, a: usize, indices: &[usize], axis: usize) -> usize {
749        let a_t = self.nodes[a].borrow().tensor.clone();
750        let data = a_t.to_vec();
751        // For 1-D: just pick elements at indices
752        let gathered: Vec<f64> = if a_t.ndim() == 1 {
753            indices.iter().map(|&i| data[i]).collect()
754        } else if a_t.ndim() == 2 && axis == 0 {
755            let cols = a_t.shape()[1];
756            indices.iter().flat_map(|&i| {
757                let start = i * cols;
758                data[start..start + cols].to_vec()
759            }).collect()
760        } else {
761            // Fallback: gather from flat data
762            indices.iter().map(|&i| data[i]).collect()
763        };
764        let mut result_shape = a_t.shape().to_vec();
765        if a_t.ndim() == 1 {
766            result_shape[0] = indices.len();
767        } else if axis == 0 {
768            result_shape[0] = indices.len();
769        } else {
770            result_shape[axis] = indices.len();
771        }
772        let result = Tensor::from_vec_unchecked(gathered, &result_shape);
773        let idx = self.nodes.len();
774        self.nodes.push(Rc::new(RefCell::new(GradNode {
775            op: GradOp::GatherOp { input: a, indices: indices.to_vec(), axis },
776            tensor: result,
777            grad: None,
778        })));
779        idx
780    }
781
782    // ── Phase C1: Missing forward methods ──
783
784    /// Element-wise division: a / b.
785    /// GradOp::Div(a, b) already has backward implementation.
786    pub fn div(&mut self, a: usize, b: usize) -> usize {
787        let a_tensor = self.nodes[a].borrow().tensor.clone();
788        let b_tensor = self.nodes[b].borrow().tensor.clone();
789        let result = a_tensor.div_elem_unchecked(&b_tensor);
790        let node = GradNode { op: GradOp::Div(a, b), tensor: result, grad: None };
791        self.nodes.push(Rc::new(RefCell::new(node)));
792        self.nodes.len() - 1
793    }
794
795    /// Element-wise negation: -a.
796    /// GradOp::Neg(a) already has backward implementation.
797    pub fn neg(&mut self, a: usize) -> usize {
798        let a_tensor = self.nodes[a].borrow().tensor.clone();
799        let result = a_tensor.neg();
800        let node = GradNode { op: GradOp::Neg(a), tensor: result, grad: None };
801        self.nodes.push(Rc::new(RefCell::new(node)));
802        self.nodes.len() - 1
803    }
804
805    /// Scalar multiply: a * s (where s is an f64 constant).
806    /// GradOp::ScalarMul(a, s) already has backward implementation.
807    pub fn scalar_mul(&mut self, a: usize, s: f64) -> usize {
808        let a_tensor = self.nodes[a].borrow().tensor.clone();
809        let result = a_tensor.scalar_mul(s);
810        let node = GradNode { op: GradOp::ScalarMul(a, s), tensor: result, grad: None };
811        self.nodes.push(Rc::new(RefCell::new(node)));
812        self.nodes.len() - 1
813    }
814
815    /// Element-wise exponential: exp(a).
816    /// GradOp::Exp(a) already has backward implementation.
817    pub fn exp(&mut self, a: usize) -> usize {
818        let a_tensor = self.nodes[a].borrow().tensor.clone();
819        let result = Tensor::from_vec_unchecked(
820            a_tensor.to_vec().iter().map(|x| x.exp()).collect(),
821            a_tensor.shape(),
822        );
823        let node = GradNode { op: GradOp::Exp(a), tensor: result, grad: None };
824        self.nodes.push(Rc::new(RefCell::new(node)));
825        self.nodes.len() - 1
826    }
827
828    /// Element-wise natural logarithm: ln(a).
829    /// GradOp::Ln(a) already has backward implementation.
830    pub fn ln(&mut self, a: usize) -> usize {
831        let a_tensor = self.nodes[a].borrow().tensor.clone();
832        let result = Tensor::from_vec_unchecked(
833            a_tensor.to_vec().iter().map(|x| x.ln()).collect(),
834            a_tensor.shape(),
835        );
836        let node = GradNode { op: GradOp::Ln(a), tensor: result, grad: None };
837        self.nodes.push(Rc::new(RefCell::new(node)));
838        self.nodes.len() - 1
839    }
840
841    /// Get the scalar value from a 1-element tensor node.
842    pub fn value(&self, idx: usize) -> f64 {
843        let node = self.nodes[idx].borrow();
844        let data = node.tensor.to_vec();
845        data[0]
846    }
847
848    /// Get the tensor at a node.
849    pub fn tensor(&self, idx: usize) -> Tensor {
850        self.nodes[idx].borrow().tensor.clone()
851    }
852
853    /// Set the tensor at a node (for parameter updates).
854    pub fn set_tensor(&self, idx: usize, tensor: Tensor) {
855        self.nodes[idx].borrow_mut().tensor = tensor;
856    }
857
858    /// Get the gradient at a node.
859    pub fn grad(&self, idx: usize) -> Option<Tensor> {
860        self.nodes[idx].borrow().grad.clone()
861    }
862
863    /// Zero out all gradients.
864    pub fn zero_grad(&self) {
865        for node in &self.nodes {
866            let mut n = node.borrow_mut();
867            if let Some(ref mut grad) = n.grad {
868                let shape = grad.shape().to_vec();
869                *grad = Tensor::zeros(&shape);
870            }
871        }
872    }
873
874    /// Clip each gradient element to `[-max_norm, max_norm]`.
875    /// This helps mitigate gradient explosion during training.
876    pub fn clip_grad(&self, max_norm: f64) {
877        for node in &self.nodes {
878            let mut n = node.borrow_mut();
879            if let Some(ref mut grad) = n.grad {
880                let data = grad.to_vec();
881                let clipped: Vec<f64> = data.iter()
882                    .map(|&x| x.max(-max_norm).min(max_norm))
883                    .collect();
884                let shape = grad.shape().to_vec();
885                *grad = Tensor::from_vec_unchecked(clipped, &shape);
886            }
887        }
888    }
889
890    /// Clip gradients by global norm: if ||grads||_2 > max_norm, scale all
891    /// gradients so the global norm equals max_norm. Deterministic via
892    /// sequential accumulation.
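    /// Sketch: a single parameter gradient of `[3.0, 4.0]` has global norm 5, so
    /// `clip_grad_norm(1.0)` rescales it to `[0.6, 0.8]` and returns the pre-clip norm 5.0.
    ///
    /// ```ignore
    /// let norm = graph.clip_grad_norm(1.0); // `graph` already holds gradients from backward()
    /// assert!((norm - 5.0).abs() < 1e-12);
    /// ```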
893    pub fn clip_grad_norm(&self, max_norm: f64) -> f64 {
894        use cjc_repro::KahanAccumulatorF64;
895        // Compute global norm
896        let mut acc = KahanAccumulatorF64::new();
897        for node in &self.nodes {
898            let n = node.borrow();
899            if let Some(ref grad) = n.grad {
900                for &v in &grad.to_vec() {
901                    acc.add(v * v);
902                }
903            }
904        }
905        let global_norm = acc.finalize().sqrt();
906
907        if global_norm > max_norm && global_norm > 0.0 {
908            let scale = max_norm / global_norm;
909            for node in &self.nodes {
910                let mut n = node.borrow_mut();
911                if let Some(ref mut grad) = n.grad {
912                    let data = grad.to_vec();
913                    let scaled: Vec<f64> = data.iter().map(|&x| x * scale).collect();
914                    let shape = grad.shape().to_vec();
915                    *grad = Tensor::from_vec_unchecked(scaled, &shape);
916                }
917            }
918        }
919
920        global_norm
921    }
922
923    /// Run backward pass from a loss node.
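    /// A minimal end-to-end sketch (parameter gradients are read back with `grad()`):
    ///
    /// ```ignore
    /// let mut g = GradGraph::new();
    /// let w = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0], &[3]));
    /// let x = g.input(Tensor::from_vec_unchecked(vec![4.0, 5.0, 6.0], &[3]));
    /// let y = g.mul(w, x);   // element-wise product
    /// let loss = g.sum(y);   // scalar loss
    /// g.backward(loss);
    /// // d(loss)/dw = x
    /// assert_eq!(g.grad(w).unwrap().to_vec(), vec![4.0, 5.0, 6.0]);
    /// ```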
924    pub fn backward(&self, loss_idx: usize) {
925        let n = self.nodes.len();
926
927        // Initialize gradients
928        let mut grads: Vec<Option<Tensor>> = vec![None; n];
929
930        // Loss gradient is 1.0
931        let loss_shape = self.nodes[loss_idx].borrow().tensor.shape().to_vec();
932        grads[loss_idx] = Some(Tensor::ones(&loss_shape));
933
934        // Backward pass in reverse topological order
935        for i in (0..=loss_idx).rev() {
936            let grad = match grads[i].take() {
937                Some(g) => g,
938                None => continue,
939            };
940
941            // Clone op and tensor out of the borrow so we don't hold the RefCell across match arms
942            let (op, node_tensor) = {
943                let node = self.nodes[i].borrow();
944                (node.op.clone(), node.tensor.clone())
945            };
946
947            match op {
948                GradOp::Input => {}
949                GradOp::Parameter => {
950                    let mut node_mut = self.nodes[i].borrow_mut();
951                    if let Some(ref mut existing_grad) = node_mut.grad {
952                        *existing_grad = existing_grad.add_unchecked(&grad);
953                    } else {
954                        node_mut.grad = Some(grad);
955                    }
956                }
957                GradOp::Add(a, b) => {
958                    accumulate_grad(&mut grads, a, &grad);
959                    accumulate_grad(&mut grads, b, &grad);
960                }
961                GradOp::Sub(a, b) => {
962                    accumulate_grad(&mut grads, a, &grad);
963                    let neg_grad = grad.neg();
964                    accumulate_grad(&mut grads, b, &neg_grad);
965                }
966                GradOp::Mul(a, b) => {
967                    let a_val = self.nodes[a].borrow().tensor.clone();
968                    let b_val = self.nodes[b].borrow().tensor.clone();
969
970                    let grad_a = grad.mul_elem_unchecked(&b_val);
971                    let grad_b = grad.mul_elem_unchecked(&a_val);
972
973                    accumulate_grad(&mut grads, a, &grad_a);
974                    accumulate_grad(&mut grads, b, &grad_b);
975                }
976                GradOp::Div(a, b) => {
977                    let a_val = self.nodes[a].borrow().tensor.clone();
978                    let b_val = self.nodes[b].borrow().tensor.clone();
979
980                    // d/da (a/b) = 1/b
981                    let grad_a = grad.div_elem_unchecked(&b_val);
982                    // d/db (a/b) = -a/b^2
983                    let b_sq = b_val.mul_elem_unchecked(&b_val);
984                    let neg_a = a_val.neg();
985                    let grad_b = grad.mul_elem_unchecked(&neg_a.div_elem_unchecked(&b_sq));
986
987                    accumulate_grad(&mut grads, a, &grad_a);
988                    accumulate_grad(&mut grads, b, &grad_b);
989                }
990                GradOp::Neg(a) => {
991                    let neg_grad = grad.neg();
992                    accumulate_grad(&mut grads, a, &neg_grad);
993                }
994                GradOp::MatMul(a, b) => {
995                    // d/da (a @ b) = grad @ b^T
996                    // d/db (a @ b) = a^T @ grad
997                    let a_val = self.nodes[a].borrow().tensor.clone();
998                    let b_val = self.nodes[b].borrow().tensor.clone();
999
1000                    let b_t = b_val.transpose();
1001                    let a_t = a_val.transpose();
1002
1003                    let grad_a = grad.matmul_unchecked(&b_t);
1004                    let grad_b = a_t.matmul_unchecked(&grad);
1005
1006                    accumulate_grad(&mut grads, a, &grad_a);
1007                    accumulate_grad(&mut grads, b, &grad_b);
1008                }
1009                GradOp::Sum(a) => {
1010                    // Gradient of sum is all ones, scaled by upstream grad
1011                    let a_shape = self.nodes[a].borrow().tensor.shape().to_vec();
1012                    let grad_val = grad.to_vec()[0];
1013                    let expanded = Tensor::from_vec_unchecked(
1014                        vec![grad_val; a_shape.iter().product()],
1015                        &a_shape,
1016                    );
1017                    accumulate_grad(&mut grads, a, &expanded);
1018                }
1019                GradOp::Mean(a) => {
1020                    let a_shape = self.nodes[a].borrow().tensor.shape().to_vec();
1021                    let n_elem = a_shape.iter().product::<usize>() as f64;
1022                    let grad_val = grad.to_vec()[0] / n_elem;
1023                    let expanded = Tensor::from_vec_unchecked(
1024                        vec![grad_val; a_shape.iter().product()],
1025                        &a_shape,
1026                    );
1027                    accumulate_grad(&mut grads, a, &expanded);
1028                }
1029                GradOp::ScalarMul(a, s) => {
1030                    let scaled = grad.scalar_mul(s);
1031                    accumulate_grad(&mut grads, a, &scaled);
1032                }
1033                GradOp::Exp(a) => {
1034                    let grad_a = grad.mul_elem_unchecked(&node_tensor);
1035                    accumulate_grad(&mut grads, a, &grad_a);
1036                }
1037                GradOp::Ln(a) => {
1038                    let a_val = self.nodes[a].borrow().tensor.clone();
1039                    let grad_a = grad.div_elem_unchecked(&a_val);
1040                    accumulate_grad(&mut grads, a, &grad_a);
1041                }
1042                // Differentiable container ops: gradient flows back to parent
1043                GradOp::StructField {
1044                    parent,
1045                    field_index,
1046                    total_fields,
1047                } => {
1048                    // Gradient for field access: create a "one-hot" gradient
1049                    // where only the accessed field gets the incoming gradient.
1050                    // For now, accumulate directly to parent.
1051                    let _ = (field_index, total_fields);
1052                    accumulate_grad(&mut grads, parent, &grad);
1053                }
1054                GradOp::MapLookup {
1055                    map_node,
1056                    key_index,
1057                    total_keys,
1058                } => {
1059                    // Gradient for map lookup: accumulate to map node.
1060                    // Deterministic: key_index determines order of accumulation.
1061                    let _ = (key_index, total_keys);
1062                    accumulate_grad(&mut grads, map_node, &grad);
1063                }
1064                // Phase B8: Transcendental & activation backward
1065                GradOp::Sin(a) => {
1066                    let a_val = self.nodes[a].borrow().tensor.clone();
1067                    let cos_a = Tensor::from_vec_unchecked(
1068                        a_val.to_vec().iter().map(|&x| x.cos()).collect(),
1069                        a_val.shape(),
1070                    );
1071                    let grad_a = grad.mul_elem_unchecked(&cos_a);
1072                    accumulate_grad(&mut grads, a, &grad_a);
1073                }
1074                GradOp::Cos(a) => {
1075                    let a_val = self.nodes[a].borrow().tensor.clone();
1076                    let neg_sin_a = Tensor::from_vec_unchecked(
1077                        a_val.to_vec().iter().map(|&x| -x.sin()).collect(),
1078                        a_val.shape(),
1079                    );
1080                    let grad_a = grad.mul_elem_unchecked(&neg_sin_a);
1081                    accumulate_grad(&mut grads, a, &grad_a);
1082                }
1083                GradOp::Sqrt(a) => {
1084                    // d/dx sqrt(x) = 0.5 / sqrt(x) = 0.5 / node_tensor
1085                    let inv_2sqrt = Tensor::from_vec_unchecked(
1086                        node_tensor.to_vec().iter().map(|&x| 0.5 / x).collect(),
1087                        node_tensor.shape(),
1088                    );
1089                    let grad_a = grad.mul_elem_unchecked(&inv_2sqrt);
1090                    accumulate_grad(&mut grads, a, &grad_a);
1091                }
1092                GradOp::Pow(a, n) => {
1093                    let a_val = self.nodes[a].borrow().tensor.clone();
1094                    let coeff = Tensor::from_vec_unchecked(
1095                        a_val.to_vec().iter().map(|&x| n * x.powf(n - 1.0)).collect(),
1096                        a_val.shape(),
1097                    );
1098                    let grad_a = grad.mul_elem_unchecked(&coeff);
1099                    accumulate_grad(&mut grads, a, &grad_a);
1100                }
1101                GradOp::Sigmoid(a) => {
1102                    // sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x))
1103                    let sig = &node_tensor;
1104                    let one_minus = Tensor::from_vec_unchecked(
1105                        sig.to_vec().iter().map(|&s| 1.0 - s).collect(),
1106                        sig.shape(),
1107                    );
1108                    let local = sig.mul_elem_unchecked(&one_minus);
1109                    let grad_a = grad.mul_elem_unchecked(&local);
1110                    accumulate_grad(&mut grads, a, &grad_a);
1111                }
1112                GradOp::Relu(a) => {
1113                    let a_val = self.nodes[a].borrow().tensor.clone();
1114                    let mask = Tensor::from_vec_unchecked(
1115                        a_val.to_vec().iter().map(|&x| if x > 0.0 { 1.0 } else { 0.0 }).collect(),
1116                        a_val.shape(),
1117                    );
1118                    let grad_a = grad.mul_elem_unchecked(&mask);
1119                    accumulate_grad(&mut grads, a, &grad_a);
1120                }
1121                GradOp::TanhAct(a) => {
1122                    // tanh'(x) = 1 - tanh(x)^2
1123                    let t = &node_tensor;
1124                    let one_minus_sq = Tensor::from_vec_unchecked(
1125                        t.to_vec().iter().map(|&x| 1.0 - x * x).collect(),
1126                        t.shape(),
1127                    );
1128                    let grad_a = grad.mul_elem_unchecked(&one_minus_sq);
1129                    accumulate_grad(&mut grads, a, &grad_a);
1130                }
1131                // Phase 8: Extended AD backward
1132                GradOp::Abs(a) => {
1133                    let a_val = self.nodes[a].borrow().tensor.clone();
1134                    let sign = Tensor::from_vec_unchecked(
1135                        a_val.to_vec().iter().map(|&x| {
1136                            if x > 0.0 { 1.0 } else if x < 0.0 { -1.0 } else { 0.0 }
1137                        }).collect(),
1138                        a_val.shape(),
1139                    );
1140                    let grad_a = grad.mul_elem_unchecked(&sign);
1141                    accumulate_grad(&mut grads, a, &grad_a);
1142                }
1143                GradOp::Log2(a) => {
1144                    // d/dx log2(x) = 1 / (x * ln(2))
1145                    let a_val = self.nodes[a].borrow().tensor.clone();
1146                    let ln2 = std::f64::consts::LN_2;
1147                    let local = Tensor::from_vec_unchecked(
1148                        a_val.to_vec().iter().map(|&x| 1.0 / (x * ln2)).collect(),
1149                        a_val.shape(),
1150                    );
1151                    let grad_a = grad.mul_elem_unchecked(&local);
1152                    accumulate_grad(&mut grads, a, &grad_a);
1153                }
1154                GradOp::Softmax(a) => {
1155                    // Jacobian-vector product: grad_input = softmax * (grad - sum(grad * softmax))
1156                    use cjc_repro::KahanAccumulatorF64;
1157                    let sm = &node_tensor;
1158                    let sm_data = sm.to_vec();
1159                    let grad_data = grad.to_vec();
1160                    let mut dot_acc = KahanAccumulatorF64::new();
1161                    for (&g, &s) in grad_data.iter().zip(sm_data.iter()) {
1162                        dot_acc.add(g * s);
1163                    }
1164                    let dot = dot_acc.finalize();
1165                    let grad_input: Vec<f64> = sm_data.iter().zip(grad_data.iter())
1166                        .map(|(&s, &g)| s * (g - dot))
1167                        .collect();
1168                    let grad_a = Tensor::from_vec_unchecked(grad_input, sm.shape());
1169                    accumulate_grad(&mut grads, a, &grad_a);
1170                }
1171                GradOp::CrossEntropy { logits, targets } => {
1172                    // Combined softmax + cross-entropy gradient: grad_logits = grad * (softmax - targets)
1173                    use cjc_repro::KahanAccumulatorF64;
1174                    let logits_val = self.nodes[logits].borrow().tensor.clone();
1175                    let targets_val = self.nodes[targets].borrow().tensor.clone();
1176                    let logits_data = logits_val.to_vec();
1177                    let targets_data = targets_val.to_vec();
1178                    // Compute softmax of logits (numerically stable)
1179                    let max_val = logits_data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
1180                    let exp_shifted: Vec<f64> = logits_data.iter().map(|&x| (x - max_val).exp()).collect();
1181                    let mut sum_acc = KahanAccumulatorF64::new();
1182                    for &v in &exp_shifted {
1183                        sum_acc.add(v);
1184                    }
1185                    let sum_exp = sum_acc.finalize();
1186                    let softmax: Vec<f64> = exp_shifted.iter().map(|&e| e / sum_exp).collect();
1187                    // grad_logits = upstream_grad * (softmax - targets)
1188                    let upstream = grad.to_vec()[0]; // CE produces scalar
1189                    let grad_logits: Vec<f64> = softmax.iter().zip(targets_data.iter())
1190                        .map(|(&s, &t)| upstream * (s - t))
1191                        .collect();
1192                    let gl = Tensor::from_vec_unchecked(grad_logits, logits_val.shape());
1193                    accumulate_grad(&mut grads, logits, &gl);
1194                    // No gradient flows to targets (they are labels)
1195                }
1196                GradOp::LayerNorm(a) => {
1197                    // Layer norm backward:
1198                    // dx = (1/std) * (grad - mean(grad) - x_hat * mean(grad * x_hat))
1199                    // where x_hat = normalized output (node_tensor)
1200                    use cjc_repro::KahanAccumulatorF64;
1201                    let x_hat = &node_tensor;
1202                    let x_hat_data = x_hat.to_vec();
1203                    let grad_data = grad.to_vec();
1204                    let n = x_hat_data.len() as f64;
1205                    // Reconstruct std from input
1206                    let a_val = self.nodes[a].borrow().tensor.clone();
1207                    let a_data = a_val.to_vec();
1208                    let mut mean_acc = KahanAccumulatorF64::new();
1209                    for &v in &a_data {
1210                        mean_acc.add(v);
1211                    }
1212                    let mean = mean_acc.finalize() / n;
1213                    let mut var_acc = KahanAccumulatorF64::new();
1214                    for &v in &a_data {
1215                        let d = v - mean;
1216                        var_acc.add(d * d);
1217                    }
1218                    let var = var_acc.finalize() / n;
1219                    let eps = 1e-5;
1220                    let std_val = (var + eps).sqrt();
1221                    // mean(grad)
1222                    let mut mg_acc = KahanAccumulatorF64::new();
1223                    for &g in &grad_data {
1224                        mg_acc.add(g);
1225                    }
1226                    let mean_grad = mg_acc.finalize() / n;
1227                    // mean(grad * x_hat)
1228                    let mut mgx_acc = KahanAccumulatorF64::new();
1229                    for (&g, &xh) in grad_data.iter().zip(x_hat_data.iter()) {
1230                        mgx_acc.add(g * xh);
1231                    }
1232                    let mean_grad_xhat = mgx_acc.finalize() / n;
1233                    let dx: Vec<f64> = grad_data.iter().zip(x_hat_data.iter())
1234                        .map(|(&g, &xh)| (g - mean_grad - xh * mean_grad_xhat) / std_val)
1235                        .collect();
1236                    let grad_a = Tensor::from_vec_unchecked(dx, a_val.shape());
1237                    accumulate_grad(&mut grads, a, &grad_a);
1238                }
1239                GradOp::BatchNorm(a) => {
1240                    // Identical to LayerNorm backward for the per-tensor case
1241                    use cjc_repro::KahanAccumulatorF64;
1242                    let x_hat = &node_tensor;
1243                    let x_hat_data = x_hat.to_vec();
1244                    let grad_data = grad.to_vec();
1245                    let n = x_hat_data.len() as f64;
1246                    let a_val = self.nodes[a].borrow().tensor.clone();
1247                    let a_data = a_val.to_vec();
1248                    let mut mean_acc = KahanAccumulatorF64::new();
1249                    for &v in &a_data {
1250                        mean_acc.add(v);
1251                    }
1252                    let mean = mean_acc.finalize() / n;
1253                    let mut var_acc = KahanAccumulatorF64::new();
1254                    for &v in &a_data {
1255                        let d = v - mean;
1256                        var_acc.add(d * d);
1257                    }
1258                    let var = var_acc.finalize() / n;
1259                    let eps = 1e-5;
1260                    let std_val = (var + eps).sqrt();
1261                    let mut mg_acc = KahanAccumulatorF64::new();
1262                    for &g in &grad_data {
1263                        mg_acc.add(g);
1264                    }
1265                    let mean_grad = mg_acc.finalize() / n;
1266                    let mut mgx_acc = KahanAccumulatorF64::new();
1267                    for (&g, &xh) in grad_data.iter().zip(x_hat_data.iter()) {
1268                        mgx_acc.add(g * xh);
1269                    }
1270                    let mean_grad_xhat = mgx_acc.finalize() / n;
1271                    let dx: Vec<f64> = grad_data.iter().zip(x_hat_data.iter())
1272                        .map(|(&g, &xh)| (g - mean_grad - xh * mean_grad_xhat) / std_val)
1273                        .collect();
1274                    let grad_a = Tensor::from_vec_unchecked(dx, a_val.shape());
1275                    accumulate_grad(&mut grads, a, &grad_a);
1276                }
1277                GradOp::Clamp { input, min, max } => {
1278                    // Gradient passes through where input is in [min, max], else 0
1279                    let a_val = self.nodes[input].borrow().tensor.clone();
1280                    let mask = Tensor::from_vec_unchecked(
1281                        a_val.to_vec().iter().map(|&x| {
1282                            if x >= min && x <= max { 1.0 } else { 0.0 }
1283                        }).collect(),
1284                        a_val.shape(),
1285                    );
1286                    let grad_a = grad.mul_elem_unchecked(&mask);
1287                    accumulate_grad(&mut grads, input, &grad_a);
1288                }
1289                GradOp::Where { cond, on_true, on_false } => {
1290                    let cond_data = self.nodes[cond].borrow().tensor.to_vec();
1291                    let grad_data = grad.to_vec();
1292                    let shape = grad.shape().to_vec();
1293                    let grad_true: Vec<f64> = cond_data.iter().zip(grad_data.iter())
1294                        .map(|(&c, &g)| if c != 0.0 { g } else { 0.0 })
1295                        .collect();
1296                    let grad_false: Vec<f64> = cond_data.iter().zip(grad_data.iter())
1297                        .map(|(&c, &g)| if c != 0.0 { 0.0 } else { g })
1298                        .collect();
1299                    let gt = Tensor::from_vec_unchecked(grad_true, &shape);
1300                    let gf = Tensor::from_vec_unchecked(grad_false, &shape);
1301                    accumulate_grad(&mut grads, on_true, &gt);
1302                    accumulate_grad(&mut grads, on_false, &gf);
1303                    // No gradient to condition
1304                }
1305                GradOp::Reshape { input, ref original_shape } => {
1306                    // Backward: reshape grad back to original shape
1307                    let grad_a = grad.reshape(original_shape)
1308                        .expect("Reshape backward: shape mismatch");
1309                    accumulate_grad(&mut grads, input, &grad_a);
1310                }
1311                GradOp::TransposeOp(a) => {
1312                    // Transpose is its own inverse for 2-D
1313                    let grad_a = grad.transpose();
1314                    accumulate_grad(&mut grads, a, &grad_a);
1315                }
1316                GradOp::CatOp { ref inputs, axis, ref sizes } => {
1317                    // Split grad along axis into pieces matching sizes
1318                    let grad_data = grad.to_vec();
1319                    let grad_shape = grad.shape().to_vec();
1320                    let ndim = grad_shape.len();
1321                    if ndim == 1 {
1322                        let mut offset = 0usize;
1323                        for (&idx, &sz) in inputs.iter().zip(sizes.iter()) {
1324                            let piece = grad_data[offset..offset + sz].to_vec();
1325                            let gt = Tensor::from_vec_unchecked(piece, &[sz]);
1326                            accumulate_grad(&mut grads, idx, &gt);
1327                            offset += sz;
1328                        }
1329                    } else if ndim == 2 && axis == 0 {
1330                        let cols = grad_shape[1];
1331                        let mut row_offset = 0usize;
1332                        for (&idx, &sz) in inputs.iter().zip(sizes.iter()) {
1333                            let start = row_offset * cols;
1334                            let end = start + sz * cols;
1335                            let piece = grad_data[start..end].to_vec();
1336                            let gt = Tensor::from_vec_unchecked(piece, &[sz, cols]);
1337                            accumulate_grad(&mut grads, idx, &gt);
1338                            row_offset += sz;
1339                        }
1340                    } else if ndim == 2 && axis == 1 {
1341                        let nrows = grad_shape[0];
1342                        let total_cols = grad_shape[1];
1343                        for (input_idx, (&idx, &sz)) in inputs.iter().zip(sizes.iter()).enumerate() {
1344                            let mut piece = Vec::with_capacity(nrows * sz);
1345                            let col_offset: usize = sizes[..input_idx].iter().sum();
1346                            for row in 0..nrows {
1347                                let row_start = row * total_cols + col_offset;
1348                                piece.extend_from_slice(&grad_data[row_start..row_start + sz]);
1349                            }
1350                            let gt = Tensor::from_vec_unchecked(piece, &[nrows, sz]);
1351                            accumulate_grad(&mut grads, idx, &gt);
1352                        }
1353                    } else {
1354                        // General fallback: split flat data proportionally. This is exact only
                        // when concatenating along the leading axis; higher axes would need a strided split.
1355                        let mut offset = 0usize;
1356                        for (&idx, &sz) in inputs.iter().zip(sizes.iter()) {
1357                            let piece_len = sz * grad_data.len() / grad_shape[axis];
1358                            let piece = grad_data[offset..offset + piece_len].to_vec();
1359                            let mut piece_shape = grad_shape.clone();
1360                            piece_shape[axis] = sz;
1361                            let gt = Tensor::from_vec_unchecked(piece, &piece_shape);
1362                            accumulate_grad(&mut grads, idx, &gt);
1363                            offset += piece_len;
1364                        }
1365                    }
1366                }
1367                GradOp::GatherOp { input, ref indices, axis } => {
1368                    // Scatter-add: distribute grad back to input positions
1369                    let input_shape = self.nodes[input].borrow().tensor.shape().to_vec();
1370                    let input_len: usize = input_shape.iter().product();
1371                    let mut scatter = vec![0.0_f64; input_len];
1372                    let grad_data = grad.to_vec();
1373                    if self.nodes[input].borrow().tensor.ndim() == 1 {
1374                        for (gi, &idx) in indices.iter().enumerate() {
1375                            scatter[idx] += grad_data[gi];
1376                        }
1377                    } else if axis == 0 && self.nodes[input].borrow().tensor.ndim() == 2 {
1378                        let cols = input_shape[1];
1379                        for (gi, &idx) in indices.iter().enumerate() {
1380                            for c in 0..cols {
1381                                scatter[idx * cols + c] += grad_data[gi * cols + c];
1382                            }
1383                        }
1384                    } else {
1385                        // Fallback: treat as flat index
1386                        for (gi, &idx) in indices.iter().enumerate() {
1387                            scatter[idx] += grad_data[gi];
1388                        }
1389                    }
1390                    let grad_a = Tensor::from_vec_unchecked(scatter, &input_shape);
1391                    accumulate_grad(&mut grads, input, &grad_a);
1392                }
1393            }
1394        }
1395    }
1396
1397    /// Compute the Jacobian of a vector-valued output node with respect to
1398    /// a parameter node. Returns a 2D tensor of shape [output_dim, param_dim].
1399    ///
1400    /// Strategy: run backward once per output element with a one-hot seed.
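    ///
    /// A minimal usage sketch (kept out of doctests with `ignore`), assuming a graph
    /// built from `parameter` and the element-wise `mul` exercised in the tests below:
    ///
    /// ```ignore
    /// let mut g = GradGraph::new();
    /// let p = g.parameter(Tensor::from_vec_unchecked(vec![2.0, 3.0], &[2]));
    /// let y = g.mul(p, p);        // y_i = p_i^2, so dy_i/dp_j = 2 * p_i if i == j, else 0
    /// let jac = g.jacobian(y, p); // shape [2, 2], expected [[4, 0], [0, 6]]
    /// ```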
1401    pub fn jacobian(&mut self, output_idx: usize, param_idx: usize) -> Tensor {
1402        let output_shape = self.nodes[output_idx].borrow().tensor.shape().to_vec();
1403        let output_dim: usize = output_shape.iter().product();
1404        let param_shape = self.nodes[param_idx].borrow().tensor.shape().to_vec();
1405        let param_dim: usize = param_shape.iter().product();
1406
1407        let mut jac_data = vec![0.0_f64; output_dim * param_dim];
1408
1409        for i in 0..output_dim {
1410            // Zero all gradients
1411            self.zero_grad();
1412
1413            // Create one-hot seed for output element i
1414            let mut seed = vec![0.0_f64; output_dim];
1415            seed[i] = 1.0;
1416            let seed_tensor = Tensor::from_vec_unchecked(seed, &output_shape);
1417
1418            // Run backward with this seed
1419            self.backward_with_seed(output_idx, &seed_tensor);
1420
1421            // Read gradient of param node
1422            let grad = self.nodes[param_idx].borrow().grad.clone();
1423            if let Some(g) = grad {
1424                let g_vec = g.to_vec();
1425                for j in 0..param_dim {
1426                    jac_data[i * param_dim + j] = g_vec[j];
1427                }
1428            }
1429        }
1430
1431        Tensor::from_vec_unchecked(jac_data, &[output_dim, param_dim])
1432    }
1433
1434    /// Compute the diagonal of the Hessian of a scalar loss with respect to
1435    /// a parameter node. Uses central finite differences on the gradient:
1436    /// perturb, re-run the forward pass, re-compute the gradient, and take the difference quotient.
1437    ///
1438    /// Returns a tensor of the same shape as the parameter.
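    ///
    /// A minimal sketch (`ignore`d), using only builders exercised in the tests below;
    /// for loss = sum_i p_i^2 every diagonal entry of the Hessian is 2:
    ///
    /// ```ignore
    /// let mut g = GradGraph::new();
    /// let p = g.parameter(Tensor::from_vec_unchecked(vec![1.0, -2.0], &[2]));
    /// let sq = g.mul(p, p);   // element-wise square
    /// let loss = g.sum(sq);   // loss = sum_i p_i^2
    /// let h = g.hessian_diag(loss, p, 1e-5);
    /// // h.to_vec() is approximately [2.0, 2.0]
    /// ```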
1439    pub fn hessian_diag(&mut self, loss_idx: usize, param_idx: usize, eps: f64) -> Tensor {
1440        let param_shape = self.nodes[param_idx].borrow().tensor.shape().to_vec();
1441        let param_dim: usize = param_shape.iter().product();
1442        let original = self.nodes[param_idx].borrow().tensor.to_vec();
1443        let mut hess_diag = vec![0.0_f64; param_dim];
1444
1445        for i in 0..param_dim {
1446            // Perturb +eps
1447            let mut plus = original.clone();
1448            plus[i] += eps;
1449            self.nodes[param_idx].borrow_mut().tensor =
1450                Tensor::from_vec_unchecked(plus, &param_shape);
            self.reforward(param_idx + 1, loss_idx);
1451            self.zero_grad();
1452            self.backward(loss_idx);
1453            let grad_plus = self.nodes[param_idx]
1454                .borrow()
1455                .grad
1456                .as_ref()
1457                .map(|g| g.to_vec()[i])
1458                .unwrap_or(0.0);
1459
1460            // Perturb -eps
1461            let mut minus = original.clone();
1462            minus[i] -= eps;
1463            self.nodes[param_idx].borrow_mut().tensor =
1464                Tensor::from_vec_unchecked(minus, &param_shape);
            self.reforward(param_idx + 1, loss_idx);
1465            self.zero_grad();
1466            self.backward(loss_idx);
1467            let grad_minus = self.nodes[param_idx]
1468                .borrow()
1469                .grad
1470                .as_ref()
1471                .map(|g| g.to_vec()[i])
1472                .unwrap_or(0.0);
1473
1474            hess_diag[i] = (grad_plus - grad_minus) / (2.0 * eps);
1475        }
1476
1477        // Restore original parameter and re-forward to clean state
1478        self.nodes[param_idx].borrow_mut().tensor =
1479            Tensor::from_vec_unchecked(original, &param_shape);
        self.reforward(param_idx + 1, loss_idx);
1480
1481        Tensor::from_vec_unchecked(hess_diag, &param_shape)
1482    }
1483
1484    /// Compute the full Hessian matrix of a scalar loss with respect to a parameter node.
1485    ///
1486    /// Returns a 2D tensor of shape [param_dim, param_dim] where H[i, j] = d²loss / (dp_i dp_j).
1487    ///
1488    /// Strategy: For each parameter element i, perturb param[i] by +eps and -eps, re-run
1489    /// the forward pass to update intermediate node values, then run backward() to get the
1490    /// gradient vector. The i-th row of the Hessian is (grad_plus - grad_minus) / (2 * eps).
1491    /// Uses eps = 1e-5 for accurate central differences.
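    ///
    /// A minimal sketch (`ignore`d); for loss = sum_i p_i^2 the Hessian is 2·I:
    ///
    /// ```ignore
    /// let mut g = GradGraph::new();
    /// let p = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0], &[2]));
    /// let sq = g.mul(p, p);
    /// let loss = g.sum(sq);
    /// let h = g.hessian(loss, p); // shape [2, 2], approximately [[2, 0], [0, 2]]
    /// ```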
1492    pub fn hessian(&mut self, loss_idx: usize, param_idx: usize) -> Tensor {
1493        let eps = 1e-5;
1494        let param_shape = self.nodes[param_idx].borrow().tensor.shape().to_vec();
1495        let param_dim: usize = param_shape.iter().product();
1496        let original = self.nodes[param_idx].borrow().tensor.to_vec();
1497        let mut hess_data = vec![0.0_f64; param_dim * param_dim];
1498
1499        for i in 0..param_dim {
1500            // Perturb +eps at index i, re-forward so intermediate nodes are up-to-date
1501            let mut plus = original.clone();
1502            plus[i] += eps;
1503            self.nodes[param_idx].borrow_mut().tensor =
1504                Tensor::from_vec_unchecked(plus, &param_shape);
1505            self.reforward(param_idx + 1, loss_idx);
1506            self.zero_grad();
1507            self.backward(loss_idx);
1508            let grad_plus: Vec<f64> = self.nodes[param_idx]
1509                .borrow()
1510                .grad
1511                .as_ref()
1512                .map(|g| g.to_vec())
1513                .unwrap_or_else(|| vec![0.0; param_dim]);
1514
1515            // Perturb -eps at index i, re-forward
1516            let mut minus = original.clone();
1517            minus[i] -= eps;
1518            self.nodes[param_idx].borrow_mut().tensor =
1519                Tensor::from_vec_unchecked(minus, &param_shape);
1520            self.reforward(param_idx + 1, loss_idx);
1521            self.zero_grad();
1522            self.backward(loss_idx);
1523            let grad_minus: Vec<f64> = self.nodes[param_idx]
1524                .borrow()
1525                .grad
1526                .as_ref()
1527                .map(|g| g.to_vec())
1528                .unwrap_or_else(|| vec![0.0; param_dim]);
1529
1530            // Row i of the Hessian: (grad_plus[j] - grad_minus[j]) / (2 * eps) for each j
1531            for j in 0..param_dim {
1532                hess_data[i * param_dim + j] = (grad_plus[j] - grad_minus[j]) / (2.0 * eps);
1533            }
1534        }
1535
1536        // Restore original parameter and re-forward to clean state
1537        self.nodes[param_idx].borrow_mut().tensor =
1538            Tensor::from_vec_unchecked(original, &param_shape);
1539        self.reforward(param_idx + 1, loss_idx);
1540
1541        Tensor::from_vec_unchecked(hess_data, &[param_dim, param_dim])
1542    }
1543
1544    /// Re-run the forward pass for all nodes from `start` up to and including `end`.
1545    ///
1546    /// This is needed before backward when a parameter has been perturbed, so that
1547    /// intermediate computation nodes hold updated tensor values.
1548    fn reforward(&mut self, start: usize, end: usize) {
1549        for node_i in start..=end {
1550            let new_tensor = {
1551                let node = self.nodes[node_i].borrow();
1552                match &node.op {
1553                    GradOp::Input | GradOp::Parameter => continue,
1554                    GradOp::Add(a, b) => {
1555                        let at = self.nodes[*a].borrow().tensor.clone();
1556                        let bt = self.nodes[*b].borrow().tensor.clone();
1557                        at.add_unchecked(&bt)
1558                    }
1559                    GradOp::Sub(a, b) => {
1560                        let at = self.nodes[*a].borrow().tensor.clone();
1561                        let bt = self.nodes[*b].borrow().tensor.clone();
1562                        at.sub_unchecked(&bt)
1563                    }
1564                    GradOp::Mul(a, b) => {
1565                        let at = self.nodes[*a].borrow().tensor.clone();
1566                        let bt = self.nodes[*b].borrow().tensor.clone();
1567                        at.mul_elem_unchecked(&bt)
1568                    }
1569                    GradOp::Div(a, b) => {
1570                        let at = self.nodes[*a].borrow().tensor.clone();
1571                        let bt = self.nodes[*b].borrow().tensor.clone();
1572                        at.div_elem_unchecked(&bt)
1573                    }
1574                    GradOp::Neg(a) => {
1575                        self.nodes[*a].borrow().tensor.neg()
1576                    }
1577                    GradOp::ScalarMul(a, s) => {
1578                        let s = *s;
1579                        self.nodes[*a].borrow().tensor.scalar_mul(s)
1580                    }
1581                    GradOp::MatMul(a, b) => {
1582                        let at = self.nodes[*a].borrow().tensor.clone();
1583                        let bt = self.nodes[*b].borrow().tensor.clone();
1584                        at.matmul_unchecked(&bt)
1585                    }
1586                    GradOp::Sum(a) => {
1587                        let s = self.nodes[*a].borrow().tensor.sum();
1588                        Tensor::from_vec_unchecked(vec![s], &[1])
1589                    }
1590                    GradOp::Mean(a) => {
1591                        let m = self.nodes[*a].borrow().tensor.mean();
1592                        Tensor::from_vec_unchecked(vec![m], &[1])
1593                    }
1594                    GradOp::Exp(a) => {
1595                        let (data, shape) = {
1596                            let t = self.nodes[*a].borrow();
1597                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
1598                        };
1599                        Tensor::from_vec_unchecked(data.iter().map(|x| x.exp()).collect(), &shape)
1600                    }
1601                    GradOp::Ln(a) => {
1602                        let (data, shape) = {
1603                            let t = self.nodes[*a].borrow();
1604                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
1605                        };
1606                        Tensor::from_vec_unchecked(data.iter().map(|x| x.ln()).collect(), &shape)
1607                    }
1608                    GradOp::Sin(a) => {
1609                        let (data, shape) = {
1610                            let t = self.nodes[*a].borrow();
1611                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
1612                        };
1613                        Tensor::from_vec_unchecked(data.iter().map(|x| x.sin()).collect(), &shape)
1614                    }
1615                    GradOp::Cos(a) => {
1616                        let (data, shape) = {
1617                            let t = self.nodes[*a].borrow();
1618                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
1619                        };
1620                        Tensor::from_vec_unchecked(data.iter().map(|x| x.cos()).collect(), &shape)
1621                    }
1622                    GradOp::Sqrt(a) => {
1623                        let (data, shape) = {
1624                            let t = self.nodes[*a].borrow();
1625                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
1626                        };
1627                        Tensor::from_vec_unchecked(data.iter().map(|x| x.sqrt()).collect(), &shape)
1628                    }
1629                    GradOp::Pow(a, n) => {
1630                        let n = *n;
1631                        let (data, shape) = {
1632                            let t = self.nodes[*a].borrow();
1633                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
1634                        };
1635                        Tensor::from_vec_unchecked(data.iter().map(|x| x.powf(n)).collect(), &shape)
1636                    }
1637                    GradOp::Sigmoid(a) => {
1638                        let (data, shape) = {
1639                            let t = self.nodes[*a].borrow();
1640                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
1641                        };
1642                        Tensor::from_vec_unchecked(
1643                            data.iter().map(|&x| 1.0 / (1.0 + (-x).exp())).collect(),
1644                            &shape,
1645                        )
1646                    }
1647                    GradOp::Relu(a) => {
1648                        let (data, shape) = {
1649                            let t = self.nodes[*a].borrow();
1650                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
1651                        };
1652                        Tensor::from_vec_unchecked(
1653                            data.iter().map(|&x| if x > 0.0 { x } else { 0.0 }).collect(),
1654                            &shape,
1655                        )
1656                    }
1657                    GradOp::TanhAct(a) => {
1658                        let (data, shape) = {
1659                            let t = self.nodes[*a].borrow();
1660                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
1661                        };
1662                        Tensor::from_vec_unchecked(data.iter().map(|x| x.tanh()).collect(), &shape)
1663                    }
1664                    GradOp::Abs(a) => {
1665                        let (data, shape) = {
1666                            let t = self.nodes[*a].borrow();
1667                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
1668                        };
1669                        Tensor::from_vec_unchecked(data.iter().map(|x| x.abs()).collect(), &shape)
1670                    }
1671                    GradOp::Clamp { input, min, max } => {
1672                        let min = *min;
1673                        let max = *max;
1674                        let (data, shape) = {
1675                            let t = self.nodes[*input].borrow();
1676                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
1677                        };
1678                        Tensor::from_vec_unchecked(
1679                            data.iter().map(|&x| x.max(min).min(max)).collect(),
1680                            &shape,
1681                        )
1682                    }
1683                    GradOp::Reshape { input, .. } => {
1684                        let current_shape = node.tensor.shape().to_vec();
1685                        let data = self.nodes[*input].borrow().tensor.to_vec();
1686                        Tensor::from_vec_unchecked(data, &current_shape)
1687                    }
1688                    GradOp::TransposeOp(a) => {
1689                        self.nodes[*a].borrow().tensor.transpose()
1690                    }
1691                    // For complex ops (softmax, layernorm, etc.), keep the existing tensor;
1692                    // finite-difference Hessians that pass through them will see stale values.
1693                    _ => node.tensor.clone(),
1694                }
1695            };
1696            self.nodes[node_i].borrow_mut().tensor = new_tensor;
1697        }
1698    }
1699
1700    /// Compute the second derivative of a scalar loss with respect to a parameter node.
1701    ///
1702    /// Implements double_backward via finite differences on the backward pass:
1703    /// perturbs the parameter by +eps/-eps, re-runs the forward and backward pass,
1704    /// and computes d(grad)/d(param) numerically via central differences. For a scalar
1705    /// param this approximates d²loss/dparam²; for a vector param it is the Hessian diagonal.
1706    ///
1707    /// Returns a tensor of the same shape as the parameter containing second derivatives.
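    ///
    /// A minimal sketch (`ignore`d); for loss = sum_i p_i^3 the second derivative with
    /// respect to p_i is 6 * p_i:
    ///
    /// ```ignore
    /// let mut g = GradGraph::new();
    /// let p = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0], &[2]));
    /// let sq = g.mul(p, p);
    /// let cube = g.mul(sq, p);
    /// let loss = g.sum(cube);
    /// let d2 = g.double_backward(loss, p);
    /// // d2.to_vec() is approximately [6.0, 12.0]
    /// ```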
1708    pub fn double_backward(&mut self, loss_idx: usize, param_idx: usize) -> Tensor {
1709        let eps = 1e-5;
1710        let param_shape = self.nodes[param_idx].borrow().tensor.shape().to_vec();
1711        let param_dim: usize = param_shape.iter().product();
1712        let original = self.nodes[param_idx].borrow().tensor.to_vec();
1713        let mut diag = vec![0.0_f64; param_dim];
1714
1715        for i in 0..param_dim {
1716            // Perturb +eps, re-forward, backward
1717            let mut plus = original.clone();
1718            plus[i] += eps;
1719            self.nodes[param_idx].borrow_mut().tensor =
1720                Tensor::from_vec_unchecked(plus, &param_shape);
1721            self.reforward(param_idx + 1, loss_idx);
1722            self.zero_grad();
1723            self.backward(loss_idx);
1724            let grad_plus = self.nodes[param_idx]
1725                .borrow()
1726                .grad
1727                .as_ref()
1728                .map(|g| g.to_vec()[i])
1729                .unwrap_or(0.0);
1730
1731            // Perturb -eps, re-forward, backward
1732            let mut minus = original.clone();
1733            minus[i] -= eps;
1734            self.nodes[param_idx].borrow_mut().tensor =
1735                Tensor::from_vec_unchecked(minus, &param_shape);
1736            self.reforward(param_idx + 1, loss_idx);
1737            self.zero_grad();
1738            self.backward(loss_idx);
1739            let grad_minus = self.nodes[param_idx]
1740                .borrow()
1741                .grad
1742                .as_ref()
1743                .map(|g| g.to_vec()[i])
1744                .unwrap_or(0.0);
1745
1746            diag[i] = (grad_plus - grad_minus) / (2.0 * eps);
1747        }
1748
1749        // Restore original parameter and re-forward to clean state
1750        self.nodes[param_idx].borrow_mut().tensor =
1751            Tensor::from_vec_unchecked(original, &param_shape);
1752        self.reforward(param_idx + 1, loss_idx);
1753
1754        Tensor::from_vec_unchecked(diag, &param_shape)
1755    }
1756
1757    /// Vectorized map (batched evaluation) over a batch dimension.
1758    ///
1759    /// For each tensor in `batch_data`, sets the input node `input_idx` to that tensor,
1760    /// re-evaluates every node after `input_idx` by replaying its op, and snapshots the
1761    /// final node's value into a fresh input node.
1762    ///
1763    /// Returns a `Vec<usize>` of snapshot node indices (one per batch element). After
1764    /// calling this, `g.value(results[k])` returns the output for batch element k.
1765    ///
1766    /// Note: this is a simple batched-evaluation helper that mutates node tensors
1767    /// in place, so after it returns the graph itself holds the values for the LAST
1768    /// batch element; individual per-element results live only in the snapshot nodes.
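    ///
    /// A minimal sketch (`ignore`d). The swapped node is created with `parameter` here
    /// purely for illustration (any node whose tensor is overwritten per element behaves
    /// the same); `value` is the accessor referenced above:
    ///
    /// ```ignore
    /// let mut g = GradGraph::new();
    /// let x = g.parameter(Tensor::from_vec_unchecked(vec![0.0, 0.0], &[2]));
    /// let y = g.mul(x, x); // y_i = x_i^2
    /// let batch = vec![
    ///     Tensor::from_vec_unchecked(vec![1.0, 2.0], &[2]),
    ///     Tensor::from_vec_unchecked(vec![3.0, 4.0], &[2]),
    /// ];
    /// let results = g.vmap_forward(x, &batch);
    /// // g.value(results[0]) is [1.0, 4.0]; g.value(results[1]) is [9.0, 16.0]
    /// ```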
1775    pub fn vmap_forward(&mut self, input_idx: usize, batch_data: &[Tensor]) -> Vec<usize> {
1776        let mut result_indices = Vec::with_capacity(batch_data.len());
1777
1778        // Identify the topological range: nodes from input_idx onward that depend on it.
1779        // We re-evaluate all nodes from input_idx to the end of the current graph.
1780        let graph_len = self.nodes.len();
1781
1782        for batch_tensor in batch_data {
1783            // Set input node to batch element
1784            self.nodes[input_idx].borrow_mut().tensor = batch_tensor.clone();
1785
1786            // Re-run forward pass for all nodes after input_idx by replaying their ops
1787            for node_i in (input_idx + 1)..graph_len {
1788                let (op, new_tensor) = {
1789                    let node = self.nodes[node_i].borrow();
1790                    let op = node.op.clone();
1791                    let new_tensor = match &op {
1792                        GradOp::Add(a, b) => {
1793                            let at = self.nodes[*a].borrow().tensor.clone();
1794                            let bt = self.nodes[*b].borrow().tensor.clone();
1795                            at.add_unchecked(&bt)
1796                        }
1797                        GradOp::Sub(a, b) => {
1798                            let at = self.nodes[*a].borrow().tensor.clone();
1799                            let bt = self.nodes[*b].borrow().tensor.clone();
1800                            at.sub_unchecked(&bt)
1801                        }
1802                        GradOp::Mul(a, b) => {
1803                            let at = self.nodes[*a].borrow().tensor.clone();
1804                            let bt = self.nodes[*b].borrow().tensor.clone();
1805                            at.mul_elem_unchecked(&bt)
1806                        }
1807                        GradOp::Div(a, b) => {
1808                            let at = self.nodes[*a].borrow().tensor.clone();
1809                            let bt = self.nodes[*b].borrow().tensor.clone();
1810                            at.div_elem_unchecked(&bt)
1811                        }
1812                        GradOp::Neg(a) => {
1813                            self.nodes[*a].borrow().tensor.neg()
1814                        }
1815                        GradOp::ScalarMul(a, s) => {
1816                            self.nodes[*a].borrow().tensor.scalar_mul(*s)
1817                        }
1818                        GradOp::MatMul(a, b) => {
1819                            let at = self.nodes[*a].borrow().tensor.clone();
1820                            let bt = self.nodes[*b].borrow().tensor.clone();
1821                            at.matmul_unchecked(&bt)
1822                        }
1823                        GradOp::Sum(a) => {
1824                            let s = self.nodes[*a].borrow().tensor.sum();
1825                            let shape = vec![1usize];
1826                            Tensor::from_vec_unchecked(vec![s], &shape)
1827                        }
1828                        GradOp::Mean(a) => {
1829                            let m = self.nodes[*a].borrow().tensor.mean();
1830                            Tensor::from_vec_unchecked(vec![m], &[1])
1831                        }
1832                        GradOp::Exp(a) => {
1833                            let data = self.nodes[*a].borrow().tensor.to_vec();
1834                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1835                            Tensor::from_vec_unchecked(
1836                                data.iter().map(|x| x.exp()).collect(),
1837                                &shape,
1838                            )
1839                        }
1840                        GradOp::Ln(a) => {
1841                            let data = self.nodes[*a].borrow().tensor.to_vec();
1842                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1843                            Tensor::from_vec_unchecked(
1844                                data.iter().map(|x| x.ln()).collect(),
1845                                &shape,
1846                            )
1847                        }
1848                        GradOp::Sin(a) => {
1849                            let data = self.nodes[*a].borrow().tensor.to_vec();
1850                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1851                            Tensor::from_vec_unchecked(
1852                                data.iter().map(|x| x.sin()).collect(),
1853                                &shape,
1854                            )
1855                        }
1856                        GradOp::Cos(a) => {
1857                            let data = self.nodes[*a].borrow().tensor.to_vec();
1858                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1859                            Tensor::from_vec_unchecked(
1860                                data.iter().map(|x| x.cos()).collect(),
1861                                &shape,
1862                            )
1863                        }
1864                        GradOp::Sqrt(a) => {
1865                            let data = self.nodes[*a].borrow().tensor.to_vec();
1866                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1867                            Tensor::from_vec_unchecked(
1868                                data.iter().map(|x| x.sqrt()).collect(),
1869                                &shape,
1870                            )
1871                        }
1872                        GradOp::Pow(a, n) => {
1873                            let data = self.nodes[*a].borrow().tensor.to_vec();
1874                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1875                            Tensor::from_vec_unchecked(
1876                                data.iter().map(|x| x.powf(*n)).collect(),
1877                                &shape,
1878                            )
1879                        }
1880                        GradOp::Sigmoid(a) => {
1881                            let data = self.nodes[*a].borrow().tensor.to_vec();
1882                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1883                            Tensor::from_vec_unchecked(
1884                                data.iter().map(|&x| 1.0 / (1.0 + (-x).exp())).collect(),
1885                                &shape,
1886                            )
1887                        }
1888                        GradOp::Relu(a) => {
1889                            let data = self.nodes[*a].borrow().tensor.to_vec();
1890                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1891                            Tensor::from_vec_unchecked(
1892                                data.iter().map(|&x| if x > 0.0 { x } else { 0.0 }).collect(),
1893                                &shape,
1894                            )
1895                        }
1896                        GradOp::TanhAct(a) => {
1897                            let data = self.nodes[*a].borrow().tensor.to_vec();
1898                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1899                            Tensor::from_vec_unchecked(
1900                                data.iter().map(|x| x.tanh()).collect(),
1901                                &shape,
1902                            )
1903                        }
1904                        GradOp::Abs(a) => {
1905                            let data = self.nodes[*a].borrow().tensor.to_vec();
1906                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1907                            Tensor::from_vec_unchecked(
1908                                data.iter().map(|x| x.abs()).collect(),
1909                                &shape,
1910                            )
1911                        }
1912                        GradOp::Clamp { input, min, max } => {
1913                            let data = self.nodes[*input].borrow().tensor.to_vec();
1914                            let shape = self.nodes[*input].borrow().tensor.shape().to_vec();
1915                            Tensor::from_vec_unchecked(
1916                                data.iter().map(|&x| x.max(*min).min(*max)).collect(),
1917                                &shape,
1918                            )
1919                        }
1920                        GradOp::Reshape { input, .. } => {
1921                            // Keep same data, use current node's shape
1922                            let data = self.nodes[*input].borrow().tensor.to_vec();
1923                            let shape = node.tensor.shape().to_vec();
1924                            Tensor::from_vec_unchecked(data, &shape)
1925                        }
1926                        GradOp::TransposeOp(a) => {
1927                            self.nodes[*a].borrow().tensor.transpose()
1928                        }
1929                        // Input/Parameter nodes and complex ops keep their existing tensor value;
1930                        // nodes that reach the batched input only through a complex op will be stale.
1931                        _ => node.tensor.clone(),
1932                    };
1933                    (op, new_tensor)
1934                };
1935                let _ = op; // the cloned op was only needed to compute new_tensor; discard it
1936                self.nodes[node_i].borrow_mut().tensor = new_tensor;
1937            }
1938
1939            // Record the output value from the last node (graph_len - 1) by creating
1940            // a snapshot input node with the current output value.
1941            let output_tensor = self.nodes[graph_len - 1].borrow().tensor.clone();
1942            let snapshot_idx = self.nodes.len();
1943            self.nodes.push(Rc::new(RefCell::new(GradNode {
1944                op: GradOp::Input,
1945                tensor: output_tensor,
1946                grad: None,
1947            })));
1948            result_indices.push(snapshot_idx);
1949        }
1950
1951        result_indices
1952    }
1953
1954    /// Backward pass with a custom gradient seed tensor (for Jacobian computation).
1955    pub fn backward_with_seed(&mut self, loss_idx: usize, seed: &Tensor) {
1956        let n = self.nodes.len();
1957        let mut grads: Vec<Option<Tensor>> = vec![None; n];
1958        grads[loss_idx] = Some(seed.clone());
1959
1960        for i in (0..n).rev() {
1961            let grad = match grads[i].take() {
1962                Some(g) => g,
1963                None => continue,
1964            };
1965
1966            let node = self.nodes[i].borrow();
1967            if let Some(ref _param_grad) = node.grad {
1968                // Accumulate into parameter grad storage
1969                drop(node);
1970                let new_grad = {
1971                    let n = self.nodes[i].borrow();
1972                    if let Some(ref existing) = n.grad {
1973                        if existing.to_vec().iter().all(|&x| x == 0.0) {
1974                            grad.clone()
1975                        } else {
1976                            existing.add_unchecked(&grad)
1977                        }
1978                    } else {
1979                        grad.clone()
1980                    }
1981                };
1982                self.nodes[i].borrow_mut().grad = Some(new_grad);
1983            } else {
1984                drop(node);
1985            }
1986
1987            // Propagate gradients using the same rules as backward()
1988            let node = self.nodes[i].borrow();
1989            let node_tensor = node.tensor.clone();
1990            match &node.op {
1991                GradOp::Input | GradOp::Parameter => {}
1992                GradOp::Add(a, b) => {
1993                    accumulate_grad(&mut grads, *a, &grad);
1994                    accumulate_grad(&mut grads, *b, &grad);
1995                }
1996                GradOp::Sub(a, b) => {
1997                    accumulate_grad(&mut grads, *a, &grad);
1998                    accumulate_grad(&mut grads, *b, &grad.neg());
1999                }
2000                GradOp::Mul(a, b) => {
2001                    let a_val = self.nodes[*a].borrow().tensor.clone();
2002                    let b_val = self.nodes[*b].borrow().tensor.clone();
2003                    accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&b_val));
2004                    accumulate_grad(&mut grads, *b, &grad.mul_elem_unchecked(&a_val));
2005                }
2006                GradOp::Div(a, b) => {
2007                    let a_val = self.nodes[*a].borrow().tensor.clone();
2008                    let b_val = self.nodes[*b].borrow().tensor.clone();
2009                    let grad_a = grad.div_elem_unchecked(&b_val);
2010                    let neg_a_over_b2 = a_val.neg().div_elem_unchecked(
2011                        &b_val.mul_elem_unchecked(&b_val),
2012                    );
2013                    let grad_b = grad.mul_elem_unchecked(&neg_a_over_b2);
2014                    accumulate_grad(&mut grads, *a, &grad_a);
2015                    accumulate_grad(&mut grads, *b, &grad_b);
2016                }
2017                GradOp::Neg(a) => {
2018                    accumulate_grad(&mut grads, *a, &grad.neg());
2019                }
2020                GradOp::ScalarMul(a, s) => {
2021                    accumulate_grad(&mut grads, *a, &grad.scalar_mul(*s));
2022                }
2023                GradOp::Exp(a) => {
2024                    accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&node_tensor));
2025                }
2026                GradOp::Ln(a) => {
2027                    let a_val = self.nodes[*a].borrow().tensor.clone();
2028                    let inv = Tensor::from_vec_unchecked(
2029                        a_val.to_vec().iter().map(|&x| 1.0 / x).collect(),
2030                        a_val.shape(),
2031                    );
2032                    accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&inv));
2033                }
2034                GradOp::Sin(a) => {
2035                    let a_val = self.nodes[*a].borrow().tensor.clone();
2036                    let cos_a = Tensor::from_vec_unchecked(
2037                        a_val.to_vec().iter().map(|&x| x.cos()).collect(),
2038                        a_val.shape(),
2039                    );
2040                    accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&cos_a));
2041                }
2042                GradOp::Cos(a) => {
2043                    let a_val = self.nodes[*a].borrow().tensor.clone();
2044                    let neg_sin = Tensor::from_vec_unchecked(
2045                        a_val.to_vec().iter().map(|&x| -x.sin()).collect(),
2046                        a_val.shape(),
2047                    );
2048                    accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&neg_sin));
2049                }
2050                GradOp::Sqrt(a) => {
2051                    let inv2sqrt = Tensor::from_vec_unchecked(
2052                        node_tensor.to_vec().iter().map(|&x| 0.5 / x).collect(),
2053                        node_tensor.shape(),
2054                    );
2055                    accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&inv2sqrt));
2056                }
2057                GradOp::Pow(a, exp) => {
2058                    let a_val = self.nodes[*a].borrow().tensor.clone();
2059                    let local = Tensor::from_vec_unchecked(
2060                        a_val.to_vec().iter().map(|&x| exp * x.powf(exp - 1.0)).collect(),
2061                        a_val.shape(),
2062                    );
2063                    accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&local));
2064                }
2065                GradOp::Sigmoid(a) => {
2066                    let sig = &node_tensor;
2067                    let one_minus = Tensor::from_vec_unchecked(
2068                        sig.to_vec().iter().map(|&x| 1.0 - x).collect(),
2069                        sig.shape(),
2070                    );
2071                    let local = sig.mul_elem_unchecked(&one_minus);
2072                    accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&local));
2073                }
2074                GradOp::Relu(a) => {
2075                    let a_val = self.nodes[*a].borrow().tensor.clone();
2076                    let mask = Tensor::from_vec_unchecked(
2077                        a_val.to_vec().iter().map(|&x| if x > 0.0 { 1.0 } else { 0.0 }).collect(),
2078                        a_val.shape(),
2079                    );
2080                    accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&mask));
2081                }
2082                GradOp::TanhAct(a) => {
2083                    let one_minus_sq = Tensor::from_vec_unchecked(
2084                        node_tensor.to_vec().iter().map(|&x| 1.0 - x * x).collect(),
2085                        node_tensor.shape(),
2086                    );
2087                    accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&one_minus_sq));
2088                }
2089                GradOp::MatMul(a, b) => {
2090                    let a_val = self.nodes[*a].borrow().tensor.clone();
2091                    let b_val = self.nodes[*b].borrow().tensor.clone();
2092                    accumulate_grad(&mut grads, *a, &grad.matmul_unchecked(&b_val.transpose()));
2093                    accumulate_grad(&mut grads, *b, &a_val.transpose().matmul_unchecked(&grad));
2094                }
2095                GradOp::Sum(a) => {
2096                    let a_shape = self.nodes[*a].borrow().tensor.shape().to_vec();
2097                    let grad_val = grad.to_vec()[0];
2098                    let expanded = Tensor::from_vec_unchecked(
2099                        vec![grad_val; a_shape.iter().product()],
2100                        &a_shape,
2101                    );
2102                    accumulate_grad(&mut grads, *a, &expanded);
2103                }
2104                GradOp::Mean(a) => {
2105                    let a_shape = self.nodes[*a].borrow().tensor.shape().to_vec();
2106                    let n_elem = a_shape.iter().product::<usize>() as f64;
2107                    let grad_val = grad.to_vec()[0] / n_elem;
2108                    let expanded = Tensor::from_vec_unchecked(
2109                        vec![grad_val; a_shape.iter().product()],
2110                        &a_shape,
2111                    );
2112                    accumulate_grad(&mut grads, *a, &expanded);
2113                }
2114                GradOp::StructField { parent, field_index, total_fields } => {
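                    // A struct field occupies a contiguous chunk of the flattened parent
                    // tensor; scatter the incoming grad into that slice, zeros elsewhere.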
2115                    let parent_shape = self.nodes[*parent].borrow().tensor.shape().to_vec();
2116                    let parent_n: usize = parent_shape.iter().product();
2117                    let chunk = parent_n / total_fields;
2118                    let start = field_index * chunk;
2119                    let mut parent_grad = vec![0.0_f64; parent_n];
2120                    let g_vec = grad.to_vec();
2121                    for (j, &gv) in g_vec.iter().enumerate() {
2122                        parent_grad[start + j] = gv;
2123                    }
2124                    let pg = Tensor::from_vec_unchecked(parent_grad, &parent_shape);
2125                    accumulate_grad(&mut grads, *parent, &pg);
2126                }
2127                GradOp::MapLookup { map_node, key_index, total_keys } => {
2128                    let map_shape = self.nodes[*map_node].borrow().tensor.shape().to_vec();
2129                    let map_n: usize = map_shape.iter().product();
2130                    let chunk = map_n / total_keys;
2131                    let start = key_index * chunk;
2132                    let mut map_grad = vec![0.0_f64; map_n];
2133                    let g_vec = grad.to_vec();
2134                    for (j, &gv) in g_vec.iter().enumerate() {
2135                        map_grad[start + j] = gv;
2136                    }
2137                    let mg = Tensor::from_vec_unchecked(map_grad, &map_shape);
2138                    accumulate_grad(&mut grads, *map_node, &mg);
2139                }
2140                // Phase 8: Extended AD backward (backward_with_seed)
2141                GradOp::Abs(a) => {
2142                    let a_val = self.nodes[*a].borrow().tensor.clone();
2143                    let sign = Tensor::from_vec_unchecked(
2144                        a_val.to_vec().iter().map(|&x| {
2145                            if x > 0.0 { 1.0 } else if x < 0.0 { -1.0 } else { 0.0 }
2146                        }).collect(),
2147                        a_val.shape(),
2148                    );
2149                    accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&sign));
2150                }
2151                GradOp::Log2(a) => {
2152                    let a_val = self.nodes[*a].borrow().tensor.clone();
2153                    let ln2 = std::f64::consts::LN_2;
2154                    let local = Tensor::from_vec_unchecked(
2155                        a_val.to_vec().iter().map(|&x| 1.0 / (x * ln2)).collect(),
2156                        a_val.shape(),
2157                    );
2158                    accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&local));
2159                }
2160                GradOp::Softmax(a) => {
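                    // Softmax vector-Jacobian product: dL/dx_i = s_i * (g_i - sum_j g_j * s_j),
                    // with the dot product accumulated in compensated (Kahan) arithmetic.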
2161                    use cjc_repro::KahanAccumulatorF64;
2162                    let sm = &node_tensor;
2163                    let sm_data = sm.to_vec();
2164                    let grad_data = grad.to_vec();
2165                    let mut dot_acc = KahanAccumulatorF64::new();
2166                    for (&g, &s) in grad_data.iter().zip(sm_data.iter()) {
2167                        dot_acc.add(g * s);
2168                    }
2169                    let dot = dot_acc.finalize();
2170                    let grad_input: Vec<f64> = sm_data.iter().zip(grad_data.iter())
2171                        .map(|(&s, &g)| s * (g - dot))
2172                        .collect();
2173                    let grad_a = Tensor::from_vec_unchecked(grad_input, sm.shape());
2174                    accumulate_grad(&mut grads, *a, &grad_a);
2175                }
2176                GradOp::CrossEntropy { logits, targets } => {
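                    // Softmax-cross-entropy gradient: dL/dlogit_i = upstream * (softmax_i - target_i),
                    // computed with a max-shifted, Kahan-summed softmax for stability.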
2177                    use cjc_repro::KahanAccumulatorF64;
2178                    let logits_val = self.nodes[*logits].borrow().tensor.clone();
2179                    let targets_val = self.nodes[*targets].borrow().tensor.clone();
2180                    let logits_data = logits_val.to_vec();
2181                    let targets_data = targets_val.to_vec();
2182                    let max_val = logits_data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
2183                    let exp_shifted: Vec<f64> = logits_data.iter().map(|&x| (x - max_val).exp()).collect();
2184                    let mut sum_acc = KahanAccumulatorF64::new();
2185                    for &v in &exp_shifted {
2186                        sum_acc.add(v);
2187                    }
2188                    let sum_exp = sum_acc.finalize();
2189                    let softmax: Vec<f64> = exp_shifted.iter().map(|&e| e / sum_exp).collect();
2190                    let upstream = grad.to_vec()[0];
2191                    let grad_logits: Vec<f64> = softmax.iter().zip(targets_data.iter())
2192                        .map(|(&s, &t)| upstream * (s - t))
2193                        .collect();
2194                    let gl = Tensor::from_vec_unchecked(grad_logits, logits_val.shape());
2195                    accumulate_grad(&mut grads, *logits, &gl);
2196                }
2197                GradOp::LayerNorm(a) => {
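                    // Backward through x_hat = (x - mean) / std:
                    // dL/dx_i = (g_i - mean(g) - x_hat_i * mean(g .* x_hat)) / std.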
2198                    use cjc_repro::KahanAccumulatorF64;
2199                    let x_hat = &node_tensor;
2200                    let x_hat_data = x_hat.to_vec();
2201                    let grad_data = grad.to_vec();
2202                    let n = x_hat_data.len() as f64;
2203                    let a_val = self.nodes[*a].borrow().tensor.clone();
2204                    let a_data = a_val.to_vec();
2205                    let mut mean_acc = KahanAccumulatorF64::new();
2206                    for &v in &a_data { mean_acc.add(v); }
2207                    let mean = mean_acc.finalize() / n;
2208                    let mut var_acc = KahanAccumulatorF64::new();
2209                    for &v in &a_data { let d = v - mean; var_acc.add(d * d); }
2210                    let var = var_acc.finalize() / n;
2211                    let std_val = (var + 1e-5).sqrt();
2212                    let mut mg_acc = KahanAccumulatorF64::new();
2213                    for &g in &grad_data { mg_acc.add(g); }
2214                    let mean_grad = mg_acc.finalize() / n;
2215                    let mut mgx_acc = KahanAccumulatorF64::new();
2216                    for (&g, &xh) in grad_data.iter().zip(x_hat_data.iter()) { mgx_acc.add(g * xh); }
2217                    let mean_grad_xhat = mgx_acc.finalize() / n;
2218                    let dx: Vec<f64> = grad_data.iter().zip(x_hat_data.iter())
2219                        .map(|(&g, &xh)| (g - mean_grad - xh * mean_grad_xhat) / std_val)
2220                        .collect();
2221                    accumulate_grad(&mut grads, *a, &Tensor::from_vec_unchecked(dx, a_val.shape()));
2222                }
2223                GradOp::BatchNorm(a) => {
2224                    use cjc_repro::KahanAccumulatorF64;
2225                    let x_hat = &node_tensor;
2226                    let x_hat_data = x_hat.to_vec();
2227                    let grad_data = grad.to_vec();
2228                    let n = x_hat_data.len() as f64;
2229                    let a_val = self.nodes[*a].borrow().tensor.clone();
2230                    let a_data = a_val.to_vec();
2231                    let mut mean_acc = KahanAccumulatorF64::new();
2232                    for &v in &a_data { mean_acc.add(v); }
2233                    let mean = mean_acc.finalize() / n;
2234                    let mut var_acc = KahanAccumulatorF64::new();
2235                    for &v in &a_data { let d = v - mean; var_acc.add(d * d); }
2236                    let var = var_acc.finalize() / n;
2237                    let std_val = (var + 1e-5).sqrt();
2238                    let mut mg_acc = KahanAccumulatorF64::new();
2239                    for &g in &grad_data { mg_acc.add(g); }
2240                    let mean_grad = mg_acc.finalize() / n;
2241                    let mut mgx_acc = KahanAccumulatorF64::new();
2242                    for (&g, &xh) in grad_data.iter().zip(x_hat_data.iter()) { mgx_acc.add(g * xh); }
2243                    let mean_grad_xhat = mgx_acc.finalize() / n;
2244                    let dx: Vec<f64> = grad_data.iter().zip(x_hat_data.iter())
2245                        .map(|(&g, &xh)| (g - mean_grad - xh * mean_grad_xhat) / std_val)
2246                        .collect();
2247                    accumulate_grad(&mut grads, *a, &Tensor::from_vec_unchecked(dx, a_val.shape()));
2248                }
2249                GradOp::Clamp { input, min, max } => {
2250                    let a_val = self.nodes[*input].borrow().tensor.clone();
2251                    let mask = Tensor::from_vec_unchecked(
2252                        a_val.to_vec().iter().map(|&x| {
2253                            if x >= *min && x <= *max { 1.0 } else { 0.0 }
2254                        }).collect(),
2255                        a_val.shape(),
2256                    );
2257                    accumulate_grad(&mut grads, *input, &grad.mul_elem_unchecked(&mask));
2258                }
2259                GradOp::Where { cond, on_true, on_false } => {
2260                    let cond_data = self.nodes[*cond].borrow().tensor.to_vec();
2261                    let grad_data = grad.to_vec();
2262                    let shape = grad.shape().to_vec();
2263                    let grad_true: Vec<f64> = cond_data.iter().zip(grad_data.iter())
2264                        .map(|(&c, &g)| if c != 0.0 { g } else { 0.0 }).collect();
2265                    let grad_false: Vec<f64> = cond_data.iter().zip(grad_data.iter())
2266                        .map(|(&c, &g)| if c != 0.0 { 0.0 } else { g }).collect();
2267                    accumulate_grad(&mut grads, *on_true, &Tensor::from_vec_unchecked(grad_true, &shape));
2268                    accumulate_grad(&mut grads, *on_false, &Tensor::from_vec_unchecked(grad_false, &shape));
2269                }
2270                GradOp::Reshape { input, ref original_shape } => {
2271                    let grad_a = grad.reshape(original_shape).expect("Reshape backward: shape mismatch");
2272                    accumulate_grad(&mut grads, *input, &grad_a);
2273                }
2274                GradOp::TransposeOp(a) => {
2275                    accumulate_grad(&mut grads, *a, &grad.transpose());
2276                }
2277                GradOp::CatOp { ref inputs, axis, ref sizes } => {
2278                    let grad_data = grad.to_vec();
2279                    let grad_shape = grad.shape().to_vec();
2280                    let ndim = grad_shape.len();
2281                    if ndim == 1 {
2282                        let mut offset = 0usize;
2283                        for (idx, &sz) in inputs.iter().zip(sizes.iter()) {
2284                            let piece = grad_data[offset..offset + sz].to_vec();
2285                            accumulate_grad(&mut grads, *idx, &Tensor::from_vec_unchecked(piece, &[sz]));
2286                            offset += sz;
2287                        }
2288                    } else if ndim == 2 && *axis == 0 {
2289                        let cols = grad_shape[1];
2290                        let mut row_offset = 0usize;
2291                        for (idx, &sz) in inputs.iter().zip(sizes.iter()) {
2292                            let start = row_offset * cols;
2293                            let end = start + sz * cols;
2294                            let piece = grad_data[start..end].to_vec();
2295                            accumulate_grad(&mut grads, *idx, &Tensor::from_vec_unchecked(piece, &[sz, cols]));
2296                            row_offset += sz;
2297                        }
2298                    } else if ndim == 2 && *axis == 1 {
2299                        let nrows = grad_shape[0];
2300                        let total_cols = grad_shape[1];
2301                        for (input_idx, (idx, &sz)) in inputs.iter().zip(sizes.iter()).enumerate() {
2302                            let mut piece = Vec::with_capacity(nrows * sz);
2303                            let col_offset: usize = sizes[..input_idx].iter().sum();
2304                            for row in 0..nrows {
2305                                let row_start = row * total_cols + col_offset;
2306                                piece.extend_from_slice(&grad_data[row_start..row_start + sz]);
2307                            }
2308                            accumulate_grad(&mut grads, *idx, &Tensor::from_vec_unchecked(piece, &[nrows, sz]));
2309                        }
2310                    } else {
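                        // Fallback flat split (exact only for concatenation along the leading axis).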
2311                        let mut offset = 0usize;
2312                        for (idx, &sz) in inputs.iter().zip(sizes.iter()) {
2313                            let piece_len = sz * grad_data.len() / grad_shape[*axis];
2314                            let piece = grad_data[offset..offset + piece_len].to_vec();
2315                            let mut piece_shape = grad_shape.clone();
2316                            piece_shape[*axis] = sz;
2317                            accumulate_grad(&mut grads, *idx, &Tensor::from_vec_unchecked(piece, &piece_shape));
2318                            offset += piece_len;
2319                        }
2320                    }
2321                }
2322                GradOp::GatherOp { input, ref indices, axis } => {
2323                    let input_shape = self.nodes[*input].borrow().tensor.shape().to_vec();
2324                    let input_len: usize = input_shape.iter().product();
2325                    let mut scatter = vec![0.0_f64; input_len];
2326                    let grad_data = grad.to_vec();
2327                    if self.nodes[*input].borrow().tensor.ndim() == 1 {
2328                        for (gi, &idx) in indices.iter().enumerate() {
2329                            scatter[idx] += grad_data[gi];
2330                        }
2331                    } else if *axis == 0 && self.nodes[*input].borrow().tensor.ndim() == 2 {
2332                        let cols = input_shape[1];
2333                        for (gi, &idx) in indices.iter().enumerate() {
2334                            for c in 0..cols {
2335                                scatter[idx * cols + c] += grad_data[gi * cols + c];
2336                            }
2337                        }
2338                    } else {
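                        // Fallback: treat indices as flat positions into the input.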
2339                        for (gi, &idx) in indices.iter().enumerate() {
2340                            scatter[idx] += grad_data[gi];
2341                        }
2342                    }
2343                    accumulate_grad(&mut grads, *input, &Tensor::from_vec_unchecked(scatter, &input_shape));
2344                }
2345            }
2346        }
2347    }
2348}
2349
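/// Add `grad` into the pending gradient buffer for node `idx`, cloning it on first touch.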
2350fn accumulate_grad(grads: &mut [Option<Tensor>], idx: usize, grad: &Tensor) {
2351    if let Some(existing) = &grads[idx] {
2352        grads[idx] = Some(existing.add_unchecked(grad));
2353    } else {
2354        grads[idx] = Some(grad.clone());
2355    }
2356}
2357
2360impl Default for GradGraph {
2361    fn default() -> Self {
2362        Self::new()
2363    }
2364}
2365
2366// ── Finite Difference Validation ────────────────────────────────
2367
2368/// Validate gradient using finite differences.
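///
/// A small sketch: the derivative of sin at 0 should match cos(0) = 1 within `tol`.
///
/// ```ignore
/// assert!(check_grad_finite_diff(|x: f64| x.sin(), 0.0, 1.0, 1e-6, 1e-6));
/// ```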
2369pub fn check_grad_finite_diff<F>(
2370    f: F,
2371    x: f64,
2372    expected_grad: f64,
2373    eps: f64,
2374    tol: f64,
2375) -> bool
2376where
2377    F: Fn(f64) -> f64,
2378{
2379    let fd_grad = (f(x + eps) - f(x - eps)) / (2.0 * eps);
2380    (fd_grad - expected_grad).abs() < tol
2381}
2382
2383#[cfg(test)]
2384mod tests {
2385    use super::*;
2386
2387    // ── Forward Mode Tests ──────────────────────────────────
2388
2389    #[test]
2390    fn test_dual_add() {
2391        let a = Dual::variable(3.0);
2392        let b = Dual::constant(2.0);
2393        let c = a + b;
2394        assert_eq!(c.value, 5.0);
2395        assert_eq!(c.deriv, 1.0);
2396    }
2397
2398    #[test]
2399    fn test_dual_mul() {
2400        let a = Dual::variable(3.0);
2401        let b = Dual::constant(2.0);
2402        let c = a * b;
2403        assert_eq!(c.value, 6.0);
2404        assert_eq!(c.deriv, 2.0); // d/dx (x * 2) = 2
2405    }
2406
2407    #[test]
2408    fn test_dual_chain_rule() {
2409        // f(x) = x^2 + 2x + 1, f'(x) = 2x + 2, f'(3) = 8
2410        let x = Dual::variable(3.0);
2411        let result = x.clone() * x.clone() + Dual::constant(2.0) * x + Dual::one();
2412        assert_eq!(result.value, 16.0);
2413        assert_eq!(result.deriv, 8.0);
2414    }
2415
2416    #[test]
2417    fn test_dual_exp() {
2418        let x = Dual::variable(1.0);
2419        let result = x.exp();
2420        assert!((result.value - std::f64::consts::E).abs() < 1e-10);
2421        assert!((result.deriv - std::f64::consts::E).abs() < 1e-10);
2422    }
2423
2424    #[test]
2425    fn test_dual_sin_cos() {
2426        let x = Dual::variable(0.0);
2427        let sin_x = x.clone().sin();
2428        let cos_x = x.cos();
2429        assert!((sin_x.value - 0.0).abs() < 1e-10);
2430        assert!((sin_x.deriv - 1.0).abs() < 1e-10); // d/dx sin(x) at 0 = cos(0) = 1
2431        assert!((cos_x.value - 1.0).abs() < 1e-10);
2432        assert!((cos_x.deriv - 0.0).abs() < 1e-10); // d/dx cos(x) at 0 = -sin(0) = 0
2433    }
2434
2435    #[test]
2436    fn test_dual_div() {
2437        let a = Dual::variable(6.0);
2438        let b = Dual::constant(3.0);
2439        let c = a / b;
2440        assert_eq!(c.value, 2.0);
2441        assert!((c.deriv - 1.0 / 3.0).abs() < 1e-10);
2442    }
2443
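    // Added sketch (not part of the original suite): forward-mode derivatives
    // for the sqrt and pow helpers defined on Dual above.
    #[test]
    fn test_dual_sqrt_pow_sketch() {
        // f(x) = sqrt(x), f'(4) = 1 / (2 * sqrt(4)) = 0.25
        let x = Dual::variable(4.0);
        let s = x.sqrt();
        assert!((s.value - 2.0).abs() < 1e-10);
        assert!((s.deriv - 0.25).abs() < 1e-10);

        // g(x) = x^3, g'(2) = 3 * 2^2 = 12
        let y = Dual::variable(2.0);
        let p = y.pow(3.0);
        assert!((p.value - 8.0).abs() < 1e-10);
        assert!((p.deriv - 12.0).abs() < 1e-10);
    }
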
2444    #[test]
2445    fn test_finite_diff_validation() {
2446        // f(x) = x^2, f'(3) = 6
2447        let f = |x: f64| x * x;
2448        assert!(check_grad_finite_diff(f, 3.0, 6.0, 1e-7, 1e-5));
2449    }
2450
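    // Added sketch: cross-check a forward-mode (Dual) derivative against the
    // central-difference helper above; the function and tolerances here are
    // illustrative choices, not fixtures from the original suite.
    #[test]
    fn test_dual_vs_finite_diff_sketch() {
        // f(x) = x * exp(x), so f'(x) = exp(x) * (1 + x)
        let x = Dual::variable(0.5);
        let fx = x.clone() * x.exp();
        assert!((fx.deriv - 1.5 * 0.5_f64.exp()).abs() < 1e-10);
        // The dual-number derivative should agree with a central difference.
        assert!(check_grad_finite_diff(|t| t * t.exp(), 0.5, fx.deriv, 1e-6, 1e-5));
    }
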
2451    // ── Reverse Mode Tests ──────────────────────────────────
2452
2453    #[test]
2454    fn test_reverse_add() {
2455        let mut g = GradGraph::new();
2456        let a = g.parameter(Tensor::from_vec_unchecked(vec![3.0], &[1]));
2457        let b = g.parameter(Tensor::from_vec_unchecked(vec![2.0], &[1]));
2458        let c = g.add(a, b);
2459
2460        g.backward(c);
2461
2462        let ga = g.grad(a).unwrap();
2463        let gb = g.grad(b).unwrap();
2464        assert!((ga.to_vec()[0] - 1.0).abs() < 1e-10);
2465        assert!((gb.to_vec()[0] - 1.0).abs() < 1e-10);
2466    }
2467
2468    #[test]
2469    fn test_reverse_mul() {
2470        let mut g = GradGraph::new();
2471        let a = g.parameter(Tensor::from_vec_unchecked(vec![3.0], &[1]));
2472        let b = g.parameter(Tensor::from_vec_unchecked(vec![2.0], &[1]));
2473        let c = g.mul(a, b);
2474
2475        g.backward(c);
2476
2477        let ga = g.grad(a).unwrap();
2478        let gb = g.grad(b).unwrap();
2479        assert!((ga.to_vec()[0] - 2.0).abs() < 1e-10); // d/da (a*b) = b = 2
2480        assert!((gb.to_vec()[0] - 3.0).abs() < 1e-10); // d/db (a*b) = a = 3
2481    }
2482
2483    #[test]
2484    fn test_reverse_matmul_gradient() {
2485        let mut g = GradGraph::new();
2486
2487        // Simple 2x2 matmul
2488        let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]));
2489        let b = g.parameter(Tensor::from_vec_unchecked(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]));
2490        let c = g.matmul(a, b);
2491        let loss = g.sum(c);
2492
2493        g.backward(loss);
2494
2495        // Gradient of sum(A @ B) w.r.t. A = ones @ B^T
2496        let ga = g.grad(a).unwrap();
2497        let ga_data = ga.to_vec();
2498        // The upstream gradient from sum() is ones(2,2), so
2499        // dL/dA = ones(2,2) @ B^T
2500        // B^T = [[5,7],[6,8]]
2501        // ones(2,2) @ B^T = [[5+6, 7+8],[5+6, 7+8]] = [[11,15],[11,15]]
2502        assert!((ga_data[0] - 11.0).abs() < 1e-10);
2503        assert!((ga_data[1] - 15.0).abs() < 1e-10);
2504    }
2505
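    // Added sketch: the companion matmul rule for the right-hand operand,
    // d sum(A @ B)/dB = A^T @ ones, using the same fixtures as the test above.
    #[test]
    fn test_reverse_matmul_gradient_rhs_sketch() {
        let mut g = GradGraph::new();
        let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]));
        let b = g.parameter(Tensor::from_vec_unchecked(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]));
        let c = g.matmul(a, b);
        let loss = g.sum(c);
        g.backward(loss);

        // A^T = [[1,3],[2,4]], A^T @ ones(2,2) = [[1+3, 1+3],[2+4, 2+4]] = [[4,4],[6,6]]
        let gb = g.grad(b).unwrap().to_vec();
        let expected = [4.0, 4.0, 6.0, 6.0];
        for (i, (&got, &exp)) in gb.iter().zip(expected.iter()).enumerate() {
            assert!((got - exp).abs() < 1e-10, "matmul rhs grad[{i}]: got {got}, expected {exp}");
        }
    }
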
2506    #[test]
2507    fn test_reverse_mean_gradient() {
2508        let mut g = GradGraph::new();
2509        let a = g.parameter(Tensor::from_vec_unchecked(vec![2.0, 4.0, 6.0, 8.0], &[4]));
2510        let loss = g.mean(a);
2511
2512        g.backward(loss);
2513
2514        let ga = g.grad(a).unwrap();
2515        let ga_data = ga.to_vec();
2516        // d/da mean(a) = 1/N for each element
2517        for &v in &ga_data {
2518            assert!((v - 0.25).abs() < 1e-10);
2519        }
2520    }
2521
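    // Added sketch: d/da sum(a) = 1 for every element, the counterpart of the
    // mean gradient above (which scales each element by 1/N).
    #[test]
    fn test_reverse_sum_gradient_sketch() {
        let mut g = GradGraph::new();
        let a = g.parameter(Tensor::from_vec_unchecked(vec![2.0, 4.0, 6.0, 8.0], &[4]));
        let loss = g.sum(a);
        g.backward(loss);
        for &v in &g.grad(a).unwrap().to_vec() {
            assert!((v - 1.0).abs() < 1e-10);
        }
    }
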
2522    // ── Phase B8: Reverse Mode Transcendental & Activation Tests ──
2523
2524    #[test]
2525    fn test_reverse_sin() {
2526        let mut g = GradGraph::new();
2527        let a = g.parameter(Tensor::from_vec_unchecked(vec![0.0], &[1]));
2528        let b = g.sin(a);
2529        g.backward(b);
2530        let ga = g.grad(a).unwrap();
2531        // d/dx sin(x) at x=0 = cos(0) = 1.0
2532        assert!((ga.to_vec()[0] - 1.0).abs() < 1e-10);
2533    }
2534
2535    #[test]
2536    fn test_reverse_cos() {
2537        let mut g = GradGraph::new();
2538        let a = g.parameter(Tensor::from_vec_unchecked(vec![0.0], &[1]));
2539        let b = g.cos(a);
2540        g.backward(b);
2541        let ga = g.grad(a).unwrap();
2542        // d/dx cos(x) at x=0 = -sin(0) = 0.0
2543        assert!(ga.to_vec()[0].abs() < 1e-10);
2544    }
2545
2546    #[test]
2547    fn test_reverse_sqrt() {
2548        let mut g = GradGraph::new();
2549        let a = g.parameter(Tensor::from_vec_unchecked(vec![4.0], &[1]));
2550        let b = g.sqrt(a);
2551        g.backward(b);
2552        let ga = g.grad(a).unwrap();
2553        // d/dx sqrt(x) at x=4 = 1/(2*2) = 0.25
2554        assert!((ga.to_vec()[0] - 0.25).abs() < 1e-10);
2555    }
2556
2557    #[test]
2558    fn test_reverse_pow() {
2559        let mut g = GradGraph::new();
2560        let a = g.parameter(Tensor::from_vec_unchecked(vec![2.0], &[1]));
2561        let b = g.pow(a, 3.0); // x^3
2562        g.backward(b);
2563        let ga = g.grad(a).unwrap();
2564        // d/dx x^3 at x=2 = 3*4 = 12.0
2565        assert!((ga.to_vec()[0] - 12.0).abs() < 1e-10);
2566    }
2567
2568    #[test]
2569    fn test_reverse_sigmoid() {
2570        let mut g = GradGraph::new();
2571        let a = g.parameter(Tensor::from_vec_unchecked(vec![0.0], &[1]));
2572        let b = g.sigmoid(a);
2573        g.backward(b);
2574        let ga = g.grad(a).unwrap();
2575        // sigmoid(0) = 0.5, sigmoid'(0) = 0.5 * 0.5 = 0.25
2576        assert!((ga.to_vec()[0] - 0.25).abs() < 1e-10);
2577    }
2578
2579    #[test]
2580    fn test_reverse_relu_positive() {
2581        let mut g = GradGraph::new();
2582        let a = g.parameter(Tensor::from_vec_unchecked(vec![3.0], &[1]));
2583        let b = g.relu(a);
2584        g.backward(b);
2585        let ga = g.grad(a).unwrap();
2586        // relu'(3) = 1.0
2587        assert!((ga.to_vec()[0] - 1.0).abs() < 1e-10);
2588    }
2589
2590    #[test]
2591    fn test_reverse_relu_negative() {
2592        let mut g = GradGraph::new();
2593        let a = g.parameter(Tensor::from_vec_unchecked(vec![-2.0], &[1]));
2594        let b = g.relu(a);
2595        g.backward(b);
2596        let ga = g.grad(a).unwrap();
2597        // relu'(-2) = 0.0
2598        assert!(ga.to_vec()[0].abs() < 1e-10);
2599    }
2600
2601    #[test]
2602    fn test_reverse_tanh() {
2603        let mut g = GradGraph::new();
2604        let a = g.parameter(Tensor::from_vec_unchecked(vec![0.0], &[1]));
2605        let b = g.tanh_act(a);
2606        g.backward(b);
2607        let ga = g.grad(a).unwrap();
2608        // tanh'(0) = 1 - tanh(0)^2 = 1 - 0 = 1.0
2609        assert!((ga.to_vec()[0] - 1.0).abs() < 1e-10);
2610    }
2611
2612    #[test]
2613    fn test_reverse_sin_cos_chain() {
2614        // f(x) = sin(cos(x)), f'(x) = cos(cos(x)) * (-sin(x))
2615        // at x=1: f'(1) = cos(cos(1)) * (-sin(1))
2616        let mut g = GradGraph::new();
2617        let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0], &[1]));
2618        let c = g.cos(a);
2619        let s = g.sin(c);
2620        g.backward(s);
2621        let ga = g.grad(a).unwrap();
2622        let expected = 1.0_f64.cos().cos() * (-1.0_f64.sin());
2623        assert!((ga.to_vec()[0] - expected).abs() < 1e-10, "got {}, expected {expected}", ga.to_vec()[0]);
2624    }
2625
2626    #[test]
2627    fn test_reverse_sigmoid_sum() {
2628        // f(x) = sum(sigmoid(x)) for vector x
2629        let mut g = GradGraph::new();
2630        let a = g.parameter(Tensor::from_vec_unchecked(vec![0.0, 1.0, -1.0], &[3]));
2631        let s = g.sigmoid(a);
2632        let loss = g.sum(s);
2633        g.backward(loss);
2634        let ga = g.grad(a).unwrap();
2635        let ga_data = ga.to_vec();
2636        // sigmoid'(0) = 0.25, sigmoid'(1) = sig(1)*(1-sig(1)), sigmoid'(-1) = sig(-1)*(1-sig(-1))
2637        let sig1 = 1.0 / (1.0 + (-1.0_f64).exp());
2638        let sig_neg1 = 1.0 / (1.0 + 1.0_f64.exp());
2639        assert!((ga_data[0] - 0.25).abs() < 1e-10);
2640        assert!((ga_data[1] - sig1 * (1.0 - sig1)).abs() < 1e-10);
2641        assert!((ga_data[2] - sig_neg1 * (1.0 - sig_neg1)).abs() < 1e-10);
2642    }
2643
2644    #[test]
2645    fn test_b8_determinism() {
2646        let mut g1 = GradGraph::new();
2647        let a1 = g1.parameter(Tensor::from_vec_unchecked(vec![1.5], &[1]));
2648        let s1 = g1.sin(a1);
2649        g1.backward(s1);
2650        let ga1 = g1.grad(a1).unwrap().to_vec()[0];
2651
2652        let mut g2 = GradGraph::new();
2653        let a2 = g2.parameter(Tensor::from_vec_unchecked(vec![1.5], &[1]));
2654        let s2 = g2.sin(a2);
2655        g2.backward(s2);
2656        let ga2 = g2.grad(a2).unwrap().to_vec()[0];
2657
2658        assert_eq!(ga1.to_bits(), ga2.to_bits());
2659    }
2660
2661    #[test]
2662    fn test_reverse_mse_loss() {
2663        // MSE = mean((pred - target)^2)
2664        let mut g = GradGraph::new();
2665
2666        let w = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 1.0], &[2, 1]));
2667        let x = g.input(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]));
2668        let target = g.input(Tensor::from_vec_unchecked(vec![3.0, 7.0], &[2, 1]));
2669
2670        let pred = g.matmul(x, w);
2671        let diff = g.sub(pred, target);
2672        let sq = g.mul(diff, diff);
2673        let loss = g.mean(sq);
2674
2675        let loss_val = g.value(loss);
2676        g.backward(loss);
2677
2678        let gw = g.grad(w).unwrap();
2679
2680        // Verify loss is finite and gradient exists
2681        assert!(loss_val.is_finite());
2682        assert_eq!(gw.to_vec().len(), 2);
2683        for &v in &gw.to_vec() {
2684            assert!(v.is_finite());
2685        }
2686    }
2687
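    // Added sketch: an analytic MSE gradient check with a target away from the
    // optimum. For loss = mean((X @ w - t)^2), dL/dw = (2/N) * X^T @ (X @ w - t);
    // with the values below that is X^T @ [1, 1] = [4, 6].
    #[test]
    fn test_reverse_mse_loss_analytic_sketch() {
        let mut g = GradGraph::new();
        let w = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 1.0], &[2, 1]));
        let x = g.input(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]));
        let target = g.input(Tensor::from_vec_unchecked(vec![2.0, 6.0], &[2, 1]));

        // pred = X @ w = [3, 7], diff = [1, 1], loss = mean([1, 1]) = 1
        let pred = g.matmul(x, w);
        let diff = g.sub(pred, target);
        let sq = g.mul(diff, diff);
        let loss = g.mean(sq);

        assert!((g.value(loss) - 1.0).abs() < 1e-10);
        g.backward(loss);

        let gw = g.grad(w).unwrap().to_vec();
        assert!((gw[0] - 4.0).abs() < 1e-10);
        assert!((gw[1] - 6.0).abs() < 1e-10);
    }
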
2688    // ── Phase C1: Reverse Mode Tests for New Forward Methods ──
2689
2690    #[test]
2691    fn test_reverse_div() {
2692        // f(x) = x / 2, f'(x) = 0.5
2693        let mut g = GradGraph::new();
2694        let a = g.parameter(Tensor::from_vec_unchecked(vec![6.0], &[1]));
2695        let b = g.input(Tensor::from_vec_unchecked(vec![2.0], &[1]));
2696        let c = g.div(a, b);
2697        g.backward(c);
2698        let ga = g.grad(a).unwrap();
2699        assert!((ga.to_vec()[0] - 0.5).abs() < 1e-10);
2700    }
2701
2702    #[test]
2703    fn test_reverse_neg() {
2704        // f(x) = -x, f'(x) = -1
2705        let mut g = GradGraph::new();
2706        let a = g.parameter(Tensor::from_vec_unchecked(vec![3.0], &[1]));
2707        let c = g.neg(a);
2708        g.backward(c);
2709        let ga = g.grad(a).unwrap();
2710        assert!((ga.to_vec()[0] - (-1.0)).abs() < 1e-10);
2711    }
2712
2713    #[test]
2714    fn test_reverse_scalar_mul() {
2715        // f(x) = 3x, f'(x) = 3
2716        let mut g = GradGraph::new();
2717        let a = g.parameter(Tensor::from_vec_unchecked(vec![2.0], &[1]));
2718        let c = g.scalar_mul(a, 3.0);
2719        g.backward(c);
2720        let ga = g.grad(a).unwrap();
2721        assert!((ga.to_vec()[0] - 3.0).abs() < 1e-10);
2722    }
2723
2724    #[test]
2725    fn test_reverse_exp() {
2726        // f(x) = exp(x), f'(x) = exp(x)
2727        let mut g = GradGraph::new();
2728        let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0], &[1]));
2729        let c = g.exp(a);
2730        g.backward(c);
2731        let ga = g.grad(a).unwrap();
2732        assert!((ga.to_vec()[0] - std::f64::consts::E).abs() < 1e-10);
2733    }
2734
2735    #[test]
2736    fn test_reverse_ln() {
2737        // f(x) = ln(x), f'(x) = 1/x, at x=2 → 0.5
2738        let mut g = GradGraph::new();
2739        let a = g.parameter(Tensor::from_vec_unchecked(vec![2.0], &[1]));
2740        let c = g.ln(a);
2741        g.backward(c);
2742        let ga = g.grad(a).unwrap();
2743        assert!((ga.to_vec()[0] - 0.5).abs() < 1e-10);
2744    }
2745
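    // Added sketch: the chain ln(exp(x)) = x has derivative exactly 1 for any x,
    // which exercises the exp → ln composition in reverse mode.
    #[test]
    fn test_reverse_ln_exp_identity_sketch() {
        let mut g = GradGraph::new();
        let a = g.parameter(Tensor::from_vec_unchecked(vec![0.7], &[1]));
        let e = g.exp(a);
        let l = g.ln(e);
        g.backward(l);
        let ga = g.grad(a).unwrap();
        assert!((ga.to_vec()[0] - 1.0).abs() < 1e-10);
    }
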
2746    // ── Phase 8: Extended AD Tests ──
2747
2748    #[test]
2749    fn test_reverse_abs_positive() {
2750        let mut g = GradGraph::new();
2751        let a = g.parameter(Tensor::from_vec_unchecked(vec![3.0], &[1]));
2752        let b = g.abs(a);
2753        g.backward(b);
2754        let ga = g.grad(a).unwrap();
2755        // abs'(3) = sign(3) = 1.0
2756        assert!((ga.to_vec()[0] - 1.0).abs() < 1e-10);
2757    }
2758
2759    #[test]
2760    fn test_reverse_abs_negative() {
2761        let mut g = GradGraph::new();
2762        let a = g.parameter(Tensor::from_vec_unchecked(vec![-2.5], &[1]));
2763        let b = g.abs(a);
2764        g.backward(b);
2765        let ga = g.grad(a).unwrap();
2766        // abs'(-2.5) = sign(-2.5) = -1.0
2767        assert!((ga.to_vec()[0] - (-1.0)).abs() < 1e-10);
2768    }
2769
2770    #[test]
2771    fn test_reverse_abs_zero() {
2772        let mut g = GradGraph::new();
2773        let a = g.parameter(Tensor::from_vec_unchecked(vec![0.0], &[1]));
2774        let b = g.abs(a);
2775        g.backward(b);
2776        let ga = g.grad(a).unwrap();
2777        // abs'(0) = 0.0 (subgradient convention)
2778        assert!(ga.to_vec()[0].abs() < 1e-10);
2779    }
2780
2781    #[test]
2782    fn test_reverse_abs_vector() {
2783        let mut g = GradGraph::new();
2784        let a = g.parameter(Tensor::from_vec_unchecked(vec![-1.0, 2.0, 0.0, -3.0], &[4]));
2785        let b = g.abs(a);
2786        let loss = g.sum(b);
2787        g.backward(loss);
2788        let ga = g.grad(a).unwrap();
2789        let expected = vec![-1.0, 1.0, 0.0, -1.0];
2790        for (i, (&got, &exp)) in ga.to_vec().iter().zip(expected.iter()).enumerate() {
2791            assert!((got - exp).abs() < 1e-10, "abs grad[{i}]: got {got}, expected {exp}");
2792        }
2793    }
2794
2795    #[test]
2796    fn test_reverse_log2() {
2797        let mut g = GradGraph::new();
2798        let a = g.parameter(Tensor::from_vec_unchecked(vec![4.0], &[1]));
2799        let b = g.log2(a);
2800        g.backward(b);
2801        let ga = g.grad(a).unwrap();
2802        // d/dx log2(x) at x=4 = 1/(4 * ln(2))
2803        let expected = 1.0 / (4.0 * std::f64::consts::LN_2);
2804        assert!((ga.to_vec()[0] - expected).abs() < 1e-10);
2805    }
2806
2807    #[test]
2808    fn test_reverse_log2_vector() {
2809        let mut g = GradGraph::new();
2810        let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 8.0], &[3]));
2811        let b = g.log2(a);
2812        let loss = g.sum(b);
2813        g.backward(loss);
2814        let ga = g.grad(a).unwrap();
2815        let ln2 = std::f64::consts::LN_2;
2816        let expected = vec![1.0 / (1.0 * ln2), 1.0 / (2.0 * ln2), 1.0 / (8.0 * ln2)];
2817        for (i, (&got, &exp)) in ga.to_vec().iter().zip(expected.iter()).enumerate() {
2818            assert!((got - exp).abs() < 1e-10, "log2 grad[{i}]: got {got}, expected {exp}");
2819        }
2820    }
2821
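    // Added sketch: log2(x) = ln(x) / ln(2), so the log2 gradient should equal
    // the ln gradient scaled by 1/ln(2). Built from the ops already tested above.
    #[test]
    fn test_reverse_log2_vs_ln_sketch() {
        let mut g = GradGraph::new();
        let a = g.parameter(Tensor::from_vec_unchecked(vec![3.0], &[1]));
        let b = g.log2(a);
        g.backward(b);
        let grad_log2 = g.grad(a).unwrap().to_vec()[0];

        let mut g2 = GradGraph::new();
        let a2 = g2.parameter(Tensor::from_vec_unchecked(vec![3.0], &[1]));
        let b2 = g2.ln(a2);
        g2.backward(b2);
        let grad_ln = g2.grad(a2).unwrap().to_vec()[0];

        assert!((grad_log2 - grad_ln / std::f64::consts::LN_2).abs() < 1e-10);
    }
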
2822    #[test]
2823    fn test_softmax_forward() {
2824        let mut g = GradGraph::new();
2825        let a = g.input(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0], &[3]));
2826        let b = g.softmax(a);
2827        let sm = g.tensor(b);
2828        let sm_data = sm.to_vec();
2829        // softmax values should sum to 1
2830        let sum: f64 = sm_data.iter().sum();
2831        assert!((sum - 1.0).abs() < 1e-10);
2832        // Verify ordering: softmax(3) > softmax(2) > softmax(1)
2833        assert!(sm_data[2] > sm_data[1]);
2834        assert!(sm_data[1] > sm_data[0]);
2835    }
2836
2837    #[test]
2838    fn test_reverse_softmax() {
2839        // Analytic check of the softmax gradient through a sum (no finite differences needed)
2840        let mut g = GradGraph::new();
2841        let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0], &[3]));
2842        let b = g.softmax(a);
2843        let loss = g.sum(b);
2844        g.backward(loss);
2845        let ga = g.grad(a).unwrap();
2846        // sum(softmax(x)) = 1 always, so d/dx sum(softmax(x)) = 0
2847        for &v in &ga.to_vec() {
2848            assert!(v.abs() < 1e-10, "softmax sum grad should be 0, got {v}");
2849        }
2850    }
2851
2852    #[test]
2853    fn test_reverse_softmax_sum_two_elements() {
2854        // Second check with a two-element input: sum(softmax(x)) is identically 1,
2855        // so the gradient of the summed softmax is zero for every component.
2856        let mut g = GradGraph::new();
2857        let a = g.parameter(Tensor::from_vec_unchecked(vec![2.0, 1.0], &[2]));
2858        let b = g.softmax(a);
2859        // Summing all outputs (rather than selecting one) keeps the result constant.
2860        let loss = g.sum(b);
2861        g.backward(loss);
2862        let ga = g.grad(a).unwrap();
2863        for &v in &ga.to_vec() {
2864            assert!(v.abs() < 1e-10);
2865        }
2866    }
2867
2868    #[test]
2869    fn test_cross_entropy_forward() {
2870        let mut g = GradGraph::new();
2871        let logits = g.input(Tensor::from_vec_unchecked(vec![2.0, 1.0, 0.1], &[3]));
2872        let targets = g.input(Tensor::from_vec_unchecked(vec![1.0, 0.0, 0.0], &[3])); // one-hot
2873        let ce = g.cross_entropy(logits, targets);
2874        let loss_val = g.value(ce);
2875        assert!(loss_val > 0.0, "CE loss should be positive");
2876        assert!(loss_val.is_finite(), "CE loss should be finite");
2877    }
2878
2879    #[test]
2880    fn test_reverse_cross_entropy() {
2881        let mut g = GradGraph::new();
2882        let logits = g.parameter(Tensor::from_vec_unchecked(vec![2.0, 1.0, 0.1], &[3]));
2883        let targets = g.input(Tensor::from_vec_unchecked(vec![1.0, 0.0, 0.0], &[3]));
2884        let ce = g.cross_entropy(logits, targets);
2885        g.backward(ce);
2886        let ga = g.grad(logits).unwrap();
2887        let ga_data = ga.to_vec();
2888        // grad = softmax(logits) - targets
2889        // softmax should give something like [0.659, 0.242, 0.099]
2890        // grad[0] should be negative (softmax < 1.0 for correct class)
2891        assert!(ga_data[0] < 0.0, "CE grad for correct class should be negative");
2892        assert!(ga_data[1] > 0.0, "CE grad for incorrect class should be positive");
2893        assert!(ga_data[2] > 0.0, "CE grad for incorrect class should be positive");
2894        // Sum of gradients should be ~0 (softmax sums to 1, targets sum to 1)
2895        let sum: f64 = ga_data.iter().sum();
2896        assert!(sum.abs() < 1e-10, "CE grad should sum to 0, got {sum}");
2897    }
2898
2899    #[test]
2900    fn test_layer_norm_forward() {
2901        let mut g = GradGraph::new();
2902        let a = g.input(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0], &[4]));
2903        let b = g.layer_norm(a);
2904        let normed = g.tensor(b).to_vec();
2905        // After layer norm, the mean should be ~0 and the variance ~1
2906        let mean: f64 = normed.iter().sum::<f64>() / normed.len() as f64;
2907        assert!(mean.abs() < 1e-5, "LayerNorm mean should be ~0, got {mean}");
2908        let var: f64 = normed.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / normed.len() as f64;
2909        assert!((var - 1.0).abs() < 0.01, "LayerNorm variance should be ~1, got {var}");
2910    }
2911
2912    #[test]
2913    fn test_reverse_layer_norm() {
2914        let mut g = GradGraph::new();
2915        let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0], &[4]));
2916        let b = g.layer_norm(a);
2917        let loss = g.sum(b);
2918        g.backward(loss);
2919        let ga = g.grad(a).unwrap();
2920        // Gradients should be finite; the finite-difference check below verifies their values
2921        for &v in &ga.to_vec() {
2922            assert!(v.is_finite(), "LayerNorm grad should be finite");
2923        }
2924        // Finite difference check
2925        let eps = 1e-5;
2926        let input_data = vec![1.0, 2.0, 3.0, 4.0];
2927        for i in 0..4 {
2928            let mut plus = input_data.clone();
2929            plus[i] += eps;
2930            let mut g_plus = GradGraph::new();
2931            let a_plus = g_plus.input(Tensor::from_vec_unchecked(plus, &[4]));
2932            let b_plus = g_plus.layer_norm(a_plus);
2933            let loss_plus = g_plus.sum(b_plus);
2934            let val_plus = g_plus.value(loss_plus);
2935
2936            let mut minus = input_data.clone();
2937            minus[i] -= eps;
2938            let mut g_minus = GradGraph::new();
2939            let a_minus = g_minus.input(Tensor::from_vec_unchecked(minus, &[4]));
2940            let b_minus = g_minus.layer_norm(a_minus);
2941            let loss_minus = g_minus.sum(b_minus);
2942            let val_minus = g_minus.value(loss_minus);
2943
2944            let fd_grad = (val_plus - val_minus) / (2.0 * eps);
2945            let ad_grad = ga.to_vec()[i];
2946            assert!(
2947                (fd_grad - ad_grad).abs() < 1e-4,
2948                "LayerNorm FD check failed at [{i}]: fd={fd_grad}, ad={ad_grad}"
2949            );
2950        }
2951    }
2952
2953    #[test]
2954    fn test_batch_norm_forward() {
2955        let mut g = GradGraph::new();
2956        let a = g.input(Tensor::from_vec_unchecked(vec![2.0, 4.0, 6.0, 8.0], &[4]));
2957        let b = g.batch_norm(a);
2958        let normed = g.tensor(b).to_vec();
2959        let mean: f64 = normed.iter().sum::<f64>() / normed.len() as f64;
2960        assert!(mean.abs() < 1e-5, "BatchNorm mean should be ~0, got {mean}");
2961    }
2962
2963    #[test]
2964    fn test_reverse_batch_norm() {
2965        let mut g = GradGraph::new();
2966        let a = g.parameter(Tensor::from_vec_unchecked(vec![2.0, 4.0, 6.0, 8.0], &[4]));
2967        let b = g.batch_norm(a);
2968        let loss = g.sum(b);
2969        g.backward(loss);
2970        let ga = g.grad(a).unwrap();
2971        for &v in &ga.to_vec() {
2972            assert!(v.is_finite(), "BatchNorm grad should be finite");
2973        }
2974    }
2975
2976    #[test]
2977    fn test_reverse_clamp_in_range() {
2978        let mut g = GradGraph::new();
2979        let a = g.parameter(Tensor::from_vec_unchecked(vec![1.5], &[1]));
2980        let b = g.clamp(a, 0.0, 3.0);
2981        g.backward(b);
2982        let ga = g.grad(a).unwrap();
2983        // 1.5 is in [0, 3], so grad passes through: 1.0
2984        assert!((ga.to_vec()[0] - 1.0).abs() < 1e-10);
2985    }
2986
2987    #[test]
2988    fn test_reverse_clamp_out_of_range() {
2989        let mut g = GradGraph::new();
2990        let a = g.parameter(Tensor::from_vec_unchecked(vec![5.0], &[1]));
2991        let b = g.clamp(a, 0.0, 3.0);
2992        g.backward(b);
2993        let ga = g.grad(a).unwrap();
2994        // 5.0 is outside [0, 3], so grad is 0
2995        assert!(ga.to_vec()[0].abs() < 1e-10);
2996    }
2997
2998    #[test]
2999    fn test_reverse_clamp_vector() {
3000        let mut g = GradGraph::new();
3001        let a = g.parameter(Tensor::from_vec_unchecked(vec![-1.0, 0.5, 2.0, 4.0], &[4]));
3002        let b = g.clamp(a, 0.0, 3.0);
3003        let loss = g.sum(b);
3004        g.backward(loss);
3005        let ga = g.grad(a).unwrap();
3006        let expected = vec![0.0, 1.0, 1.0, 0.0]; // -1 out, 0.5 in, 2 in, 4 out
3007        for (i, (&got, &exp)) in ga.to_vec().iter().zip(expected.iter()).enumerate() {
3008            assert!((got - exp).abs() < 1e-10, "clamp grad[{i}]: got {got}, expected {exp}");
3009        }
3010    }
3011
3012    #[test]
3013    fn test_reverse_where_cond() {
3014        let mut g = GradGraph::new();
3015        let cond = g.input(Tensor::from_vec_unchecked(vec![1.0, 0.0, 1.0], &[3]));
3016        let a = g.parameter(Tensor::from_vec_unchecked(vec![10.0, 20.0, 30.0], &[3]));
3017        let b = g.parameter(Tensor::from_vec_unchecked(vec![100.0, 200.0, 300.0], &[3]));
3018        let w = g.where_cond(cond, a, b);
3019        // Forward: should select [10, 200, 30]
3020        let result = g.tensor(w).to_vec();
3021        assert!((result[0] - 10.0).abs() < 1e-10);
3022        assert!((result[1] - 200.0).abs() < 1e-10);
3023        assert!((result[2] - 30.0).abs() < 1e-10);
3024        let loss = g.sum(w);
3025        g.backward(loss);
3026        let ga = g.grad(a).unwrap().to_vec();
3027        let gb = g.grad(b).unwrap().to_vec();
3028        // grad flows to a where cond=1, to b where cond=0
3029        assert!((ga[0] - 1.0).abs() < 1e-10);
3030        assert!(ga[1].abs() < 1e-10);
3031        assert!((ga[2] - 1.0).abs() < 1e-10);
3032        assert!(gb[0].abs() < 1e-10);
3033        assert!((gb[1] - 1.0).abs() < 1e-10);
3034        assert!(gb[2].abs() < 1e-10);
3035    }
3036
3037    #[test]
3038    fn test_reverse_reshape() {
3039        let mut g = GradGraph::new();
3040        let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]));
3041        let b = g.reshape(a, &[3, 2]);
3042        let loss = g.sum(b);
3043        g.backward(loss);
3044        let ga = g.grad(a).unwrap();
3045        // Reshape backward: grad should have original shape [2, 3], all ones
3046        assert_eq!(ga.shape(), &[2, 3]);
3047        for &v in &ga.to_vec() {
3048            assert!((v - 1.0).abs() < 1e-10);
3049        }
3050    }
3051
3052    #[test]
3053    fn test_reverse_transpose_op() {
3054        let mut g = GradGraph::new();
3055        let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]));
3056        let b = g.transpose_op(a);
3057        // Transposed shape should be [3, 2]
3058        assert_eq!(g.tensor(b).shape(), &[3, 2]);
3059        let loss = g.sum(b);
3060        g.backward(loss);
3061        let ga = g.grad(a).unwrap();
3062        // Transpose backward: grad should have original shape [2, 3], all ones
3063        assert_eq!(ga.shape(), &[2, 3]);
3064        for &v in &ga.to_vec() {
3065            assert!((v - 1.0).abs() < 1e-10);
3066        }
3067    }
3068
3069    #[test]
3070    fn test_reverse_cat_1d() {
3071        let mut g = GradGraph::new();
3072        let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0], &[2]));
3073        let b = g.parameter(Tensor::from_vec_unchecked(vec![3.0, 4.0, 5.0], &[3]));
3074        let c = g.cat(&[a, b], 0);
3075        // Forward: [1, 2, 3, 4, 5]
3076        let result = g.tensor(c).to_vec();
3077        assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
3078        let loss = g.sum(c);
3079        g.backward(loss);
3080        let ga = g.grad(a).unwrap().to_vec();
3081        let gb = g.grad(b).unwrap().to_vec();
3082        // All gradients should be 1 (from sum)
3083        assert_eq!(ga, vec![1.0, 1.0]);
3084        assert_eq!(gb, vec![1.0, 1.0, 1.0]);
3085    }
3086
3087    #[test]
3088    fn test_reverse_cat_2d_axis0() {
3089        let mut g = GradGraph::new();
3090        let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0], &[1, 2]));
3091        let b = g.parameter(Tensor::from_vec_unchecked(vec![3.0, 4.0, 5.0, 6.0], &[2, 2]));
3092        let c = g.cat(&[a, b], 0);
3093        assert_eq!(g.tensor(c).shape(), &[3, 2]);
3094        let loss = g.sum(c);
3095        g.backward(loss);
3096        let ga = g.grad(a).unwrap();
3097        let gb = g.grad(b).unwrap();
3098        assert_eq!(ga.shape(), &[1, 2]);
3099        assert_eq!(gb.shape(), &[2, 2]);
3100        for &v in &ga.to_vec() {
3101            assert!((v - 1.0).abs() < 1e-10);
3102        }
3103        for &v in &gb.to_vec() {
3104            assert!((v - 1.0).abs() < 1e-10);
3105        }
3106    }
3107
3108    #[test]
3109    fn test_reverse_gather_1d() {
3110        let mut g = GradGraph::new();
3111        let a = g.parameter(Tensor::from_vec_unchecked(vec![10.0, 20.0, 30.0, 40.0], &[4]));
3112        let b = g.gather(a, &[1, 3], 0);
3113        // Forward: [20, 40]
3114        let result = g.tensor(b).to_vec();
3115        assert!((result[0] - 20.0).abs() < 1e-10);
3116        assert!((result[1] - 40.0).abs() < 1e-10);
3117        let loss = g.sum(b);
3118        g.backward(loss);
3119        let ga = g.grad(a).unwrap().to_vec();
3120        // Scatter-add: indices [1, 3] get grad 1.0 each, others get 0
3121        assert!((ga[0]).abs() < 1e-10);
3122        assert!((ga[1] - 1.0).abs() < 1e-10);
3123        assert!((ga[2]).abs() < 1e-10);
3124        assert!((ga[3] - 1.0).abs() < 1e-10);
3125    }
3126
3127    #[test]
3128    fn test_reverse_gather_duplicate_indices() {
3129        let mut g = GradGraph::new();
3130        let a = g.parameter(Tensor::from_vec_unchecked(vec![10.0, 20.0, 30.0], &[3]));
3131        let b = g.gather(a, &[1, 1, 2], 0);
3132        let loss = g.sum(b);
3133        g.backward(loss);
3134        let ga = g.grad(a).unwrap().to_vec();
3135        // Index 1 appears twice, so its grad should be 2.0
3136        assert!((ga[0]).abs() < 1e-10);
3137        assert!((ga[1] - 2.0).abs() < 1e-10);
3138        assert!((ga[2] - 1.0).abs() < 1e-10);
3139    }
3140
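    // Added sketch: a row gather on a 2-D tensor along axis 0, exercising the
    // ndim == 2 scatter-add branch in the GatherOp backward. Assumes the forward
    // gather supports 2-D inputs along axis 0 (only the 1-D case is tested above).
    #[test]
    fn test_reverse_gather_2d_rows_sketch() {
        let mut g = GradGraph::new();
        let a = g.parameter(Tensor::from_vec_unchecked(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            &[3, 2],
        ));
        let b = g.gather(a, &[2, 0], 0);
        let loss = g.sum(b);
        g.backward(loss);
        let ga = g.grad(a).unwrap().to_vec();
        // Rows 0 and 2 were each gathered once; row 1 receives no gradient.
        let expected = [1.0, 1.0, 0.0, 0.0, 1.0, 1.0];
        for (i, (&got, &exp)) in ga.iter().zip(expected.iter()).enumerate() {
            assert!((got - exp).abs() < 1e-10, "gather 2d grad[{i}]: got {got}, expected {exp}");
        }
    }
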
3141    #[test]
3142    fn test_phase8_determinism() {
3143        // Run twice and verify bit-identical gradients
3144        for _ in 0..2 {
3145            let run = || {
3146                let mut g = GradGraph::new();
3147                let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, -2.0, 3.0, -0.5], &[4]));
3148                let b = g.abs(a);
3149                let c = g.clamp(b, 0.0, 2.5);
3150                let d = g.layer_norm(c);
3151                let loss = g.sum(d);
3152                g.backward(loss);
3153                g.grad(a).unwrap().to_vec()
3154            };
3155            let r1 = run();
3156            let r2 = run();
3157            for (i, (v1, v2)) in r1.iter().zip(r2.iter()).enumerate() {
3158                assert_eq!(v1.to_bits(), v2.to_bits(), "Determinism failed at [{i}]");
3159            }
3160        }
3161    }
3162
3163    #[test]
3164    fn test_phase8_softmax_cross_entropy_chain() {
3165        // Verify that softmax + CE combined gradient works end-to-end
3166        let mut g = GradGraph::new();
3167        let logits = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0], &[3]));
3168        let targets = g.input(Tensor::from_vec_unchecked(vec![0.0, 0.0, 1.0], &[3]));
3169        let ce = g.cross_entropy(logits, targets);
3170        g.backward(ce);
3171        let ga = g.grad(logits).unwrap().to_vec();
3172        // Verify all gradients are finite
3173        for &v in &ga {
3174            assert!(v.is_finite());
3175        }
3176        // CE grad for correct class (index 2) should be negative (softmax - 1)
3177        assert!(ga[2] < 0.0, "CE grad for correct class should be negative");
3178    }
3179
3180    // ── Sprint 1 SciML Hardening: hessian, double_backward, vmap_forward ──
3181
3182    #[test]
3183    fn test_double_backward_cubic() {
3184        // f(x) = x^3, f'(x) = 3x^2, f''(x) = 6x
3185        let mut g = GradGraph::new();
3186        let x = g.parameter(Tensor::from_vec_unchecked(vec![2.0], &[1]));
3187        let x2 = g.mul(x, x);
3188        let x3 = g.mul(x2, x);
3189        let loss = g.sum(x3);
3190        let hess = g.double_backward(loss, x);
3191        // f''(2) = 6*2 = 12
3192        assert!((hess.to_vec()[0] - 12.0).abs() < 1e-4);
3193    }
3194
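    // Added sketch: a second second-derivative check through repeated `mul`.
    // f(x) = x^4, f''(x) = 12 x^2, so f''(1.5) = 27.
    #[test]
    fn test_double_backward_quartic_sketch() {
        let mut g = GradGraph::new();
        let x = g.parameter(Tensor::from_vec_unchecked(vec![1.5], &[1]));
        let x2 = g.mul(x, x);
        let x4 = g.mul(x2, x2);
        let loss = g.sum(x4);
        let hess = g.double_backward(loss, x);
        assert!((hess.to_vec()[0] - 27.0).abs() < 1e-4);
    }
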
3195    #[test]
3196    fn test_full_hessian_quadratic() {
3197        // f(x, y) = x^2 + y^2 using p = [x, y]
3198        // H = [[2, 0], [0, 2]]
3199        let mut g = GradGraph::new();
3200        let p = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 1.0], &[2]));
3201        let p2 = g.mul(p, p); // [x^2, y^2]
3202        let s = g.sum(p2); // x^2 + y^2
3203        let hess = g.hessian(s, p);
3204        let h = hess.to_vec();
3205        assert!((h[0] - 2.0).abs() < 1e-3); // d2f/dx2 = 2
3206        assert!((h[1] - 0.0).abs() < 1e-3); // d2f/dxdy = 0
3207        assert!((h[2] - 0.0).abs() < 1e-3); // d2f/dydx = 0
3208        assert!((h[3] - 2.0).abs() < 1e-3); // d2f/dy2 = 2
3209    }
3210
3211    #[test]
3212    fn test_vmap_forward() {
3213        let mut g = GradGraph::new();
3214        let x = g.parameter(Tensor::from_vec_unchecked(vec![1.0], &[1]));
3215        let x2 = g.mul(x, x);
3216        let loss = g.sum(x2);
3217
3218        // vmap over batch [1.0, 2.0, 3.0]
3219        let batch = vec![
3220            Tensor::from_vec_unchecked(vec![1.0], &[1]),
3221            Tensor::from_vec_unchecked(vec![2.0], &[1]),
3222            Tensor::from_vec_unchecked(vec![3.0], &[1]),
3223        ];
3224        let results = g.vmap_forward(x, &batch);
3225        // Results should be [1.0, 4.0, 9.0]
3226        assert!((g.value(results[0]) - 1.0).abs() < 1e-10);
3227        assert!((g.value(results[1]) - 4.0).abs() < 1e-10);
3228        assert!((g.value(results[2]) - 9.0).abs() < 1e-10);
3229    }
3230
3231    #[test]
3232    fn test_hessian_determinism() {
3233        let mut g = GradGraph::new();
3234        let p = g.parameter(Tensor::from_vec_unchecked(vec![3.0, 4.0], &[2]));
3235        let p2 = g.mul(p, p);
3236        let s = g.sum(p2);
3237        let h1 = g.hessian(s, p);
3238        // Reset and redo
3239        g.nodes[p].borrow_mut().tensor = Tensor::from_vec_unchecked(vec![3.0, 4.0], &[2]);
3240        let h2 = g.hessian(s, p);
3241        assert_eq!(h1.to_vec(), h2.to_vec(), "Hessian must be deterministic");
3242    }
3243}