cjc_ad/lib.rs

1//! Automatic differentiation for CJC.
2//!
3//! Provides forward-mode differentiation via dual numbers and reverse-mode
4//! differentiation via a computation tape. Supports `grad()`, `jacobian()`,
5//! and gradient graph construction for ML training loops.
6
7use cjc_runtime::Tensor;
8use std::cell::RefCell;
9use std::rc::Rc;
10
11pub mod pinn;
12
13// ── Forward-Mode AD (Dual Numbers) ──────────────────────────────
14
15/// Dual number for forward-mode automatic differentiation.
16///
17/// Carries a primal value and its derivative (tangent) through arithmetic
18/// operations so that `f(Dual::variable(x))` yields both `f(x)` and `f'(x)`
19/// in a single forward pass.
20///
21/// # Examples
22///
23/// ```rust,ignore
24/// // Compute f(x) = x^2 and f'(x) at x = 3
25/// let x = Dual::variable(3.0);
26/// let y = x.clone() * x;
27/// assert_eq!(y.value, 9.0);
28/// assert_eq!(y.deriv, 6.0);
29/// ```
30#[derive(Debug, Clone)]
31pub struct Dual {
32    /// The primal (function) value.
33    pub value: f64,
34    /// The tangent (derivative) value.
35    pub deriv: f64,
36}
37
38impl Dual {
39    /// Create a dual number with an explicit value and derivative.
40    ///
41    /// # Arguments
42    ///
43    /// * `value` - The primal value.
44    /// * `deriv` - The tangent (derivative) seed.
45    pub fn new(value: f64, deriv: f64) -> Self {
46        Self { value, deriv }
47    }
48
49    /// Create a dual number representing a constant (derivative = 0).
50    ///
51    /// # Arguments
52    ///
53    /// * `value` - The constant value.
54    pub fn constant(value: f64) -> Self {
55        Self { value, deriv: 0.0 }
56    }
57
58    /// Create a dual number representing the independent variable (derivative = 1).
59    ///
60    /// Use this for the variable with respect to which you are differentiating.
61    ///
62    /// # Arguments
63    ///
64    /// * `value` - The point at which to evaluate.
65    pub fn variable(value: f64) -> Self {
66        Self { value, deriv: 1.0 }
67    }
68
69    /// Return the additive identity dual number (value = 0, derivative = 0).
70    pub fn zero() -> Self {
71        Self {
72            value: 0.0,
73            deriv: 0.0,
74        }
75    }
76
77    /// Return the multiplicative identity dual number (value = 1, derivative = 0).
78    pub fn one() -> Self {
79        Self {
80            value: 1.0,
81            deriv: 0.0,
82        }
83    }
84}
85
86impl std::ops::Add for Dual {
87    type Output = Dual;
88    fn add(self, rhs: Dual) -> Dual {
89        Dual {
90            value: self.value + rhs.value,
91            deriv: self.deriv + rhs.deriv,
92        }
93    }
94}
95
96impl std::ops::Sub for Dual {
97    type Output = Dual;
98    fn sub(self, rhs: Dual) -> Dual {
99        Dual {
100            value: self.value - rhs.value,
101            deriv: self.deriv - rhs.deriv,
102        }
103    }
104}
105
106impl std::ops::Mul for Dual {
107    type Output = Dual;
108    fn mul(self, rhs: Dual) -> Dual {
109        Dual {
110            value: self.value * rhs.value,
111            deriv: self.value * rhs.deriv + self.deriv * rhs.value,
112        }
113    }
114}
115
116impl std::ops::Div for Dual {
117    type Output = Dual;
118    fn div(self, rhs: Dual) -> Dual {
119        let denom = rhs.value * rhs.value;
120        Dual {
121            value: self.value / rhs.value,
122            deriv: (self.deriv * rhs.value - self.value * rhs.deriv) / denom,
123        }
124    }
125}
126
127impl std::ops::Neg for Dual {
128    type Output = Dual;
129    fn neg(self) -> Dual {
130        Dual {
131            value: -self.value,
132            deriv: -self.deriv,
133        }
134    }
135}
136
137impl Dual {
138    /// Compute the sine, propagating the derivative via the chain rule: `d/dx sin(x) = cos(x)`.
139    pub fn sin(self) -> Dual {
140        Dual {
141            value: self.value.sin(),
142            deriv: self.deriv * self.value.cos(),
143        }
144    }
145
146    /// Compute the cosine, propagating the derivative via the chain rule: `d/dx cos(x) = -sin(x)`.
147    pub fn cos(self) -> Dual {
148        Dual {
149            value: self.value.cos(),
150            deriv: -self.deriv * self.value.sin(),
151        }
152    }
153
154    /// Compute the exponential, propagating the derivative: `d/dx exp(x) = exp(x)`.
155    pub fn exp(self) -> Dual {
156        let e = self.value.exp();
157        Dual {
158            value: e,
159            deriv: self.deriv * e,
160        }
161    }
162
163    /// Compute the natural logarithm, propagating the derivative: `d/dx ln(x) = 1/x`.
164    pub fn ln(self) -> Dual {
165        Dual {
166            value: self.value.ln(),
167            deriv: self.deriv / self.value,
168        }
169    }
170
171    /// Compute the square root, propagating the derivative: `d/dx sqrt(x) = 1/(2*sqrt(x))`.
172    pub fn sqrt(self) -> Dual {
173        let s = self.value.sqrt();
174        Dual {
175            value: s,
176            deriv: self.deriv / (2.0 * s),
177        }
178    }
179
180    /// Raise to a constant power `n`, propagating the derivative: `d/dx x^n = n * x^(n-1)`.
181    ///
182    /// # Arguments
183    ///
184    /// * `n` - The exponent (constant, not differentiated).
185    pub fn pow(self, n: f64) -> Dual {
186        Dual {
187            value: self.value.powf(n),
188            deriv: self.deriv * n * self.value.powf(n - 1.0),
189        }
190    }
191}
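
// A minimal forward-mode sketch using only the `Dual` methods above: composing
// these primitives propagates derivatives through the chain rule, e.g. for
// f(x) = sin(x^2) at x = 2.0:
//
//     let x = Dual::variable(2.0);
//     let y = x.clone().pow(2.0).sin();
//     // y.value == (4.0_f64).sin()
//     // y.deriv == 4.0 * (4.0_f64).cos()   // d/dx sin(x^2) = 2x * cos(x^2)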
192
193// ── Reverse-Mode AD (Computational Graph) ───────────────────────
194
195/// Operation recorded in the reverse-mode AD computation graph.
196///
197/// Each variant stores the node indices of its operands so the backward pass
198/// can look up parent tensors and propagate gradients.
199#[derive(Debug, Clone)]
200pub enum GradOp {
201    /// External input data (no gradient accumulated).
202    Input,
203    /// Trainable parameter (gradients are accumulated here during backward).
204    Parameter,
205    /// Element-wise addition of two nodes.
206    Add(usize, usize),
207    /// Element-wise subtraction of two nodes.
208    Sub(usize, usize),
209    /// Element-wise (Hadamard) multiplication of two nodes.
210    Mul(usize, usize),
211    /// Element-wise division of two nodes.
212    Div(usize, usize),
213    /// Element-wise negation.
214    Neg(usize),
215    /// Matrix multiplication of two 2-D nodes.
216    MatMul(usize, usize),
217    /// Sum all elements to a scalar `[1]` tensor.
218    Sum(usize),
219    /// Mean of all elements to a scalar `[1]` tensor.
220    Mean(usize),
221    /// Multiply every element by a constant scalar.
222    ScalarMul(usize, f64),
223    /// Element-wise exponential.
224    Exp(usize),
225    /// Element-wise natural logarithm.
226    Ln(usize),
227    /// Gradient through struct field access: parent node, field index.
228    StructField {
229        parent: usize,
230        field_index: usize,
231        total_fields: usize,
232    },
233    /// Gradient through map lookup: map node, key index in insertion order.
234    MapLookup {
235        map_node: usize,
236        key_index: usize,
237        total_keys: usize,
238    },
239    /// Element-wise sine: `d/dx sin(x) = cos(x)`.
240    Sin(usize),
241    /// Element-wise cosine: `d/dx cos(x) = -sin(x)`.
242    Cos(usize),
243    /// Element-wise square root: `d/dx sqrt(x) = 1/(2*sqrt(x))`.
244    Sqrt(usize),
245    /// Element-wise power with a constant exponent: `d/dx x^n = n * x^(n-1)`.
246    Pow(usize, f64),
247    /// Logistic sigmoid activation: `sigma(x) = 1 / (1 + exp(-x))`.
248    Sigmoid(usize),
249    /// Rectified linear unit activation: `max(0, x)`.
250    Relu(usize),
251    /// Hyperbolic tangent activation: `tanh(x)`.
252    TanhAct(usize),
253    /// Element-wise absolute value with sub-gradient `sign(x)` at zero.
254    Abs(usize),
255    /// Base-2 logarithm: `d/dx log2(x) = 1/(x * ln(2))`.
256    Log2(usize),
257    /// Softmax over the last axis, producing a probability distribution.
258    Softmax(usize),
259    /// Cross-entropy loss between predicted logits and target labels.
260    CrossEntropy {
261        /// Node index of the raw logit tensor.
262        logits: usize,
263        /// Node index of the target (one-hot or class-index) tensor.
264        targets: usize,
265    },
266    /// Layer normalization over the last axis; statistics are recomputed during backward.
267    LayerNorm(usize),
268    /// Batch normalization over the first axis; statistics are recomputed during backward.
269    BatchNorm(usize),
270    /// Element-wise clamping to the range `[min, max]`.
271    Clamp {
272        /// Node index of the input tensor.
273        input: usize,
274        /// Lower bound.
275        min: f64,
276        /// Upper bound.
277        max: f64,
278    },
279    /// Element-wise conditional select using a `{0.0, 1.0}` mask tensor.
280    Where {
281        /// Node index of the condition mask.
282        cond: usize,
283        /// Node index selected where condition is `1.0`.
284        on_true: usize,
285        /// Node index selected where condition is `0.0`.
286        on_false: usize,
287    },
288    /// Reshape a tensor, storing the original shape for backward reconstruction.
289    Reshape {
290        /// Node index of the input tensor.
291        input: usize,
292        /// Shape before the reshape (used during backward).
293        original_shape: Vec<usize>,
294    },
295    /// Transpose a 2-D tensor (swap rows and columns).
296    TransposeOp(usize),
297    /// Concatenate tensors along an axis, storing per-input sizes for backward splitting.
298    CatOp {
299        /// Node indices of the tensors to concatenate.
300        inputs: Vec<usize>,
301        /// Axis along which to concatenate.
302        axis: usize,
303        /// Size of each input along the concatenation axis.
304        sizes: Vec<usize>,
305    },
306    /// Gather elements along an axis by index.
307    GatherOp {
308        /// Node index of the source tensor.
309        input: usize,
310        /// Indices to gather.
311        indices: Vec<usize>,
312        /// Axis along which to gather.
313        axis: usize,
314    },
315}
316
317/// A node in the reverse-mode AD graph.
318#[derive(Debug, Clone)]
319pub struct GradNode {
320    pub op: GradOp,
321    pub tensor: Tensor,
322    pub grad: Option<Tensor>,
323}
324
325/// The reverse-mode AD tape/graph.
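///
/// # Examples
///
/// A minimal end-to-end sketch using the methods defined below (tensor
/// constructors come from `cjc_runtime`): build `loss = mean(w * x)`, run the
/// backward pass, and read the parameter gradient.
///
/// ```rust,ignore
/// let mut g = GradGraph::new();
/// let x = g.input(Tensor::from_vec_unchecked(vec![1.0, 2.0], &[2]));
/// let w = g.parameter(Tensor::from_vec_unchecked(vec![0.5, 0.5], &[2]));
/// let y = g.mul(w, x);    // element-wise product
/// let loss = g.mean(y);   // scalar [1] tensor
/// g.backward(loss);
/// let dw = g.grad(w);     // Some([0.5, 1.0]) = d loss / d w
/// ```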
326pub struct GradGraph {
327    pub nodes: Vec<Rc<RefCell<GradNode>>>,
328}
329
330impl GradGraph {
331    pub fn new() -> Self {
332        Self { nodes: Vec::new() }
333    }
334
335    /// Create an input node (data, no gradient).
336    pub fn input(&mut self, tensor: Tensor) -> usize {
337        let idx = self.nodes.len();
338        self.nodes.push(Rc::new(RefCell::new(GradNode {
339            op: GradOp::Input,
340            tensor,
341            grad: None,
342        })));
343        idx
344    }
345
346    /// Create a parameter node (trainable, accumulates gradients).
347    pub fn parameter(&mut self, tensor: Tensor) -> usize {
348        let idx = self.nodes.len();
349        let shape = tensor.shape().to_vec();
350        self.nodes.push(Rc::new(RefCell::new(GradNode {
351            op: GradOp::Parameter,
352            tensor,
353            grad: Some(Tensor::zeros(&shape)),
354        })));
355        idx
356    }
357
358    /// Element-wise addition.
359    pub fn add(&mut self, a: usize, b: usize) -> usize {
360        let a_t = self.nodes[a].borrow().tensor.clone();
361        let b_t = self.nodes[b].borrow().tensor.clone();
362        let result = a_t.add_unchecked(&b_t);
363        let idx = self.nodes.len();
364        self.nodes.push(Rc::new(RefCell::new(GradNode {
365            op: GradOp::Add(a, b),
366            tensor: result,
367            grad: None,
368        })));
369        idx
370    }
371
372    /// Element-wise subtraction.
373    pub fn sub(&mut self, a: usize, b: usize) -> usize {
374        let a_t = self.nodes[a].borrow().tensor.clone();
375        let b_t = self.nodes[b].borrow().tensor.clone();
376        let result = a_t.sub_unchecked(&b_t);
377        let idx = self.nodes.len();
378        self.nodes.push(Rc::new(RefCell::new(GradNode {
379            op: GradOp::Sub(a, b),
380            tensor: result,
381            grad: None,
382        })));
383        idx
384    }
385
386    /// Element-wise multiplication.
387    pub fn mul(&mut self, a: usize, b: usize) -> usize {
388        let a_t = self.nodes[a].borrow().tensor.clone();
389        let b_t = self.nodes[b].borrow().tensor.clone();
390        let result = a_t.mul_elem_unchecked(&b_t);
391        let idx = self.nodes.len();
392        self.nodes.push(Rc::new(RefCell::new(GradNode {
393            op: GradOp::Mul(a, b),
394            tensor: result,
395            grad: None,
396        })));
397        idx
398    }
399
400    /// Matrix multiplication.
401    pub fn matmul(&mut self, a: usize, b: usize) -> usize {
402        let a_t = self.nodes[a].borrow().tensor.clone();
403        let b_t = self.nodes[b].borrow().tensor.clone();
404        let result = a_t.matmul_unchecked(&b_t);
405        let idx = self.nodes.len();
406        self.nodes.push(Rc::new(RefCell::new(GradNode {
407            op: GradOp::MatMul(a, b),
408            tensor: result,
409            grad: None,
410        })));
411        idx
412    }
413
414    /// Sum all elements.
415    pub fn sum(&mut self, a: usize) -> usize {
416        let a_t = self.nodes[a].borrow().tensor.clone();
417        let s = a_t.sum();
418        let result = Tensor::from_vec_unchecked(vec![s], &[1]);
419        let idx = self.nodes.len();
420        self.nodes.push(Rc::new(RefCell::new(GradNode {
421            op: GradOp::Sum(a),
422            tensor: result,
423            grad: None,
424        })));
425        idx
426    }
427
428    /// Mean of all elements.
429    pub fn mean(&mut self, a: usize) -> usize {
430        let a_t = self.nodes[a].borrow().tensor.clone();
431        let m = a_t.mean();
432        let result = Tensor::from_vec_unchecked(vec![m], &[1]);
433        let idx = self.nodes.len();
434        self.nodes.push(Rc::new(RefCell::new(GradNode {
435            op: GradOp::Mean(a),
436            tensor: result,
437            grad: None,
438        })));
439        idx
440    }
441
442    // ── Phase B8: Transcendental & activation forward ops ──
443
444    /// Element-wise sine.
445    pub fn sin(&mut self, a: usize) -> usize {
446        let a_t = self.nodes[a].borrow().tensor.clone();
447        let data = a_t.to_vec();
448        let result = Tensor::from_vec_unchecked(
449            data.iter().map(|&x| x.sin()).collect(),
450            a_t.shape(),
451        );
452        let idx = self.nodes.len();
453        self.nodes.push(Rc::new(RefCell::new(GradNode {
454            op: GradOp::Sin(a),
455            tensor: result,
456            grad: None,
457        })));
458        idx
459    }
460
461    /// Element-wise cosine.
462    pub fn cos(&mut self, a: usize) -> usize {
463        let a_t = self.nodes[a].borrow().tensor.clone();
464        let data = a_t.to_vec();
465        let result = Tensor::from_vec_unchecked(
466            data.iter().map(|&x| x.cos()).collect(),
467            a_t.shape(),
468        );
469        let idx = self.nodes.len();
470        self.nodes.push(Rc::new(RefCell::new(GradNode {
471            op: GradOp::Cos(a),
472            tensor: result,
473            grad: None,
474        })));
475        idx
476    }
477
478    /// Element-wise square root.
479    pub fn sqrt(&mut self, a: usize) -> usize {
480        let a_t = self.nodes[a].borrow().tensor.clone();
481        let data = a_t.to_vec();
482        let result = Tensor::from_vec_unchecked(
483            data.iter().map(|&x| x.sqrt()).collect(),
484            a_t.shape(),
485        );
486        let idx = self.nodes.len();
487        self.nodes.push(Rc::new(RefCell::new(GradNode {
488            op: GradOp::Sqrt(a),
489            tensor: result,
490            grad: None,
491        })));
492        idx
493    }
494
495    /// Element-wise power with constant exponent.
496    pub fn pow(&mut self, a: usize, n: f64) -> usize {
497        let a_t = self.nodes[a].borrow().tensor.clone();
498        let data = a_t.to_vec();
499        let result = Tensor::from_vec_unchecked(
500            data.iter().map(|&x| x.powf(n)).collect(),
501            a_t.shape(),
502        );
503        let idx = self.nodes.len();
504        self.nodes.push(Rc::new(RefCell::new(GradNode {
505            op: GradOp::Pow(a, n),
506            tensor: result,
507            grad: None,
508        })));
509        idx
510    }
511
512    /// Sigmoid activation: 1 / (1 + exp(-x)).
513    pub fn sigmoid(&mut self, a: usize) -> usize {
514        let a_t = self.nodes[a].borrow().tensor.clone();
515        let data = a_t.to_vec();
516        let result = Tensor::from_vec_unchecked(
517            data.iter().map(|&x| 1.0 / (1.0 + (-x).exp())).collect(),
518            a_t.shape(),
519        );
520        let idx = self.nodes.len();
521        self.nodes.push(Rc::new(RefCell::new(GradNode {
522            op: GradOp::Sigmoid(a),
523            tensor: result,
524            grad: None,
525        })));
526        idx
527    }
528
529    /// ReLU activation: max(0, x).
530    pub fn relu(&mut self, a: usize) -> usize {
531        let a_t = self.nodes[a].borrow().tensor.clone();
532        let data = a_t.to_vec();
533        let result = Tensor::from_vec_unchecked(
534            data.iter().map(|&x| if x > 0.0 { x } else { 0.0 }).collect(),
535            a_t.shape(),
536        );
537        let idx = self.nodes.len();
538        self.nodes.push(Rc::new(RefCell::new(GradNode {
539            op: GradOp::Relu(a),
540            tensor: result,
541            grad: None,
542        })));
543        idx
544    }
545
546    /// Tanh activation.
547    pub fn tanh_act(&mut self, a: usize) -> usize {
548        let a_t = self.nodes[a].borrow().tensor.clone();
549        let data = a_t.to_vec();
550        let result = Tensor::from_vec_unchecked(
551            data.iter().map(|&x| x.tanh()).collect(),
552            a_t.shape(),
553        );
554        let idx = self.nodes.len();
555        self.nodes.push(Rc::new(RefCell::new(GradNode {
556            op: GradOp::TanhAct(a),
557            tensor: result,
558            grad: None,
559        })));
560        idx
561    }
562
563    // ── Phase 8: Extended AD forward ops ──
564
565    /// Element-wise absolute value.
566    pub fn abs(&mut self, a: usize) -> usize {
567        let a_t = self.nodes[a].borrow().tensor.clone();
568        let data = a_t.to_vec();
569        let result = Tensor::from_vec_unchecked(
570            data.iter().map(|&x| x.abs()).collect(),
571            a_t.shape(),
572        );
573        let idx = self.nodes.len();
574        self.nodes.push(Rc::new(RefCell::new(GradNode {
575            op: GradOp::Abs(a),
576            tensor: result,
577            grad: None,
578        })));
579        idx
580    }
581
582    /// Element-wise log base 2.
583    pub fn log2(&mut self, a: usize) -> usize {
584        let a_t = self.nodes[a].borrow().tensor.clone();
585        let data = a_t.to_vec();
586        let result = Tensor::from_vec_unchecked(
587            data.iter().map(|&x| x.log2()).collect(),
588            a_t.shape(),
589        );
590        let idx = self.nodes.len();
591        self.nodes.push(Rc::new(RefCell::new(GradNode {
592            op: GradOp::Log2(a),
593            tensor: result,
594            grad: None,
595        })));
596        idx
597    }
598
599    /// Softmax over all elements (the tensor is treated as a single flat vector).
600    /// Numerically stable via max-shifting: softmax(x_i) = exp(x_i - max(x)) / sum_j(exp(x_j - max(x))).
601    pub fn softmax(&mut self, a: usize) -> usize {
602        use cjc_repro::KahanAccumulatorF64;
603        let a_t = self.nodes[a].borrow().tensor.clone();
604        let data = a_t.to_vec();
605        let max_val = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
606        let exp_shifted: Vec<f64> = data.iter().map(|&x| (x - max_val).exp()).collect();
607        let mut sum_acc = KahanAccumulatorF64::new();
608        for &v in &exp_shifted {
609            sum_acc.add(v);
610        }
611        let sum_exp = sum_acc.finalize();
612        let softmax_data: Vec<f64> = exp_shifted.iter().map(|&e| e / sum_exp).collect();
613        let result = Tensor::from_vec_unchecked(softmax_data, a_t.shape());
614        let idx = self.nodes.len();
615        self.nodes.push(Rc::new(RefCell::new(GradNode {
616            op: GradOp::Softmax(a),
617            tensor: result,
618            grad: None,
619        })));
620        idx
621    }
622
623    /// Cross-entropy loss: -sum(targets * log(softmax(logits)))
624    /// Uses numerically stable log-sum-exp internally.
625    /// Returns a scalar [1] tensor.
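    ///
    /// A small usage sketch (assumes `g` is a `GradGraph`; tensor constructors from `cjc_runtime`):
    ///
    /// ```rust,ignore
    /// // 3-class logits (trainable) against a one-hot target for class 1.
    /// let logits = g.parameter(Tensor::from_vec_unchecked(vec![2.0, 1.0, 0.1], &[3]));
    /// let targets = g.input(Tensor::from_vec_unchecked(vec![0.0, 1.0, 0.0], &[3]));
    /// let loss = g.cross_entropy(logits, targets);
    /// g.backward(loss); // the gradient at `logits` becomes softmax(logits) - targets
    /// ```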
626    pub fn cross_entropy(&mut self, logits: usize, targets: usize) -> usize {
627        use cjc_repro::KahanAccumulatorF64;
628        let logits_t = self.nodes[logits].borrow().tensor.clone();
629        let targets_t = self.nodes[targets].borrow().tensor.clone();
630        let logits_data = logits_t.to_vec();
631        let targets_data = targets_t.to_vec();
632        // Numerically stable: log_softmax = x_i - max(x) - log(sum(exp(x_j - max(x))))
633        let max_val = logits_data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
634        let shifted: Vec<f64> = logits_data.iter().map(|&x| x - max_val).collect();
635        let exp_shifted: Vec<f64> = shifted.iter().map(|&x| x.exp()).collect();
636        let mut sum_acc = KahanAccumulatorF64::new();
637        for &v in &exp_shifted {
638            sum_acc.add(v);
639        }
640        let log_sum_exp = sum_acc.finalize().ln();
641        let log_softmax: Vec<f64> = shifted.iter().map(|&x| x - log_sum_exp).collect();
642        // CE = -sum(targets * log_softmax)
643        let mut ce_acc = KahanAccumulatorF64::new();
644        for (t, ls) in targets_data.iter().zip(log_softmax.iter()) {
645            ce_acc.add(-t * ls);
646        }
647        let ce = ce_acc.finalize();
648        let result = Tensor::from_vec_unchecked(vec![ce], &[1]);
649        let idx = self.nodes.len();
650        self.nodes.push(Rc::new(RefCell::new(GradNode {
651            op: GradOp::CrossEntropy { logits, targets },
652            tensor: result,
653            grad: None,
654        })));
655        idx
656    }
657
658    /// Layer normalization: normalize input to zero mean and unit variance.
659    /// y = (x - mean(x)) / sqrt(var(x) + eps), where eps = 1e-5.
660    pub fn layer_norm(&mut self, a: usize) -> usize {
661        use cjc_repro::KahanAccumulatorF64;
662        let a_t = self.nodes[a].borrow().tensor.clone();
663        let data = a_t.to_vec();
664        let n = data.len() as f64;
665        // Mean
666        let mut mean_acc = KahanAccumulatorF64::new();
667        for &v in &data {
668            mean_acc.add(v);
669        }
670        let mean = mean_acc.finalize() / n;
671        // Variance
672        let mut var_acc = KahanAccumulatorF64::new();
673        for &v in &data {
674            let d = v - mean;
675            var_acc.add(d * d);
676        }
677        let var = var_acc.finalize() / n;
678        let eps = 1e-5;
679        let std = (var + eps).sqrt();
680        let normed: Vec<f64> = data.iter().map(|&x| (x - mean) / std).collect();
681        let result = Tensor::from_vec_unchecked(normed, a_t.shape());
682        let idx = self.nodes.len();
683        self.nodes.push(Rc::new(RefCell::new(GradNode {
684            op: GradOp::LayerNorm(a),
685            tensor: result,
686            grad: None,
687        })));
688        idx
689    }
690
691    /// Batch normalization: y = (x - mean(x)) / sqrt(var(x) + eps), where eps = 1e-5.
692    /// Note: the current implementation computes a single mean and variance over all
693    /// elements, so for every input it behaves identically to `layer_norm` rather than
694    /// normalizing each feature across the batch.
695    pub fn batch_norm(&mut self, a: usize) -> usize {
696        use cjc_repro::KahanAccumulatorF64;
697        let a_t = self.nodes[a].borrow().tensor.clone();
698        let data = a_t.to_vec();
699        let n = data.len() as f64;
700        let mut mean_acc = KahanAccumulatorF64::new();
701        for &v in &data {
702            mean_acc.add(v);
703        }
704        let mean = mean_acc.finalize() / n;
705        let mut var_acc = KahanAccumulatorF64::new();
706        for &v in &data {
707            let d = v - mean;
708            var_acc.add(d * d);
709        }
710        let var = var_acc.finalize() / n;
711        let eps = 1e-5;
712        let std = (var + eps).sqrt();
713        let normed: Vec<f64> = data.iter().map(|&x| (x - mean) / std).collect();
714        let result = Tensor::from_vec_unchecked(normed, a_t.shape());
715        let idx = self.nodes.len();
716        self.nodes.push(Rc::new(RefCell::new(GradNode {
717            op: GradOp::BatchNorm(a),
718            tensor: result,
719            grad: None,
720        })));
721        idx
722    }
723
724    /// Element-wise clamp to [min, max].
725    pub fn clamp(&mut self, a: usize, min: f64, max: f64) -> usize {
726        let a_t = self.nodes[a].borrow().tensor.clone();
727        let data = a_t.to_vec();
728        let result = Tensor::from_vec_unchecked(
729            data.iter().map(|&x| x.max(min).min(max)).collect(),
730            a_t.shape(),
731        );
732        let idx = self.nodes.len();
733        self.nodes.push(Rc::new(RefCell::new(GradNode {
734            op: GradOp::Clamp { input: a, min, max },
735            tensor: result,
736            grad: None,
737        })));
738        idx
739    }
740
741    /// Conditional select: where(cond, on_true, on_false).
742    /// `cond` is a mask tensor: any non-zero element selects from `on_true`, zero selects from `on_false`.
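    ///
    /// A small usage sketch (assumes `g` is a `GradGraph` and `a`, `b` are existing node indices of shape `[3]`):
    ///
    /// ```rust,ignore
    /// let mask = g.input(Tensor::from_vec_unchecked(vec![1.0, 0.0, 1.0], &[3]));
    /// let out = g.where_cond(mask, a, b); // picks a[0], b[1], a[2]
    /// ```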
743    pub fn where_cond(&mut self, cond: usize, on_true: usize, on_false: usize) -> usize {
744        let cond_t = self.nodes[cond].borrow().tensor.clone();
745        let true_t = self.nodes[on_true].borrow().tensor.clone();
746        let false_t = self.nodes[on_false].borrow().tensor.clone();
747        let c = cond_t.to_vec();
748        let t = true_t.to_vec();
749        let f = false_t.to_vec();
750        let result_data: Vec<f64> = c.iter().zip(t.iter().zip(f.iter()))
751            .map(|(&ci, (&ti, &fi))| if ci != 0.0 { ti } else { fi })
752            .collect();
753        let result = Tensor::from_vec_unchecked(result_data, cond_t.shape());
754        let idx = self.nodes.len();
755        self.nodes.push(Rc::new(RefCell::new(GradNode {
756            op: GradOp::Where { cond, on_true, on_false },
757            tensor: result,
758            grad: None,
759        })));
760        idx
761    }
762
763    /// Reshape a tensor. Stores the original shape for backward.
764    pub fn reshape(&mut self, a: usize, new_shape: &[usize]) -> usize {
765        let a_t = self.nodes[a].borrow().tensor.clone();
766        let original_shape = a_t.shape().to_vec();
767        let result = a_t.reshape(new_shape).expect("GradGraph::reshape: shape mismatch");
768        let idx = self.nodes.len();
769        self.nodes.push(Rc::new(RefCell::new(GradNode {
770            op: GradOp::Reshape { input: a, original_shape },
771            tensor: result,
772            grad: None,
773        })));
774        idx
775    }
776
777    /// Transpose a 2-D tensor.
778    pub fn transpose_op(&mut self, a: usize) -> usize {
779        let a_t = self.nodes[a].borrow().tensor.clone();
780        let result = a_t.transpose();
781        let idx = self.nodes.len();
782        self.nodes.push(Rc::new(RefCell::new(GradNode {
783            op: GradOp::TransposeOp(a),
784            tensor: result,
785            grad: None,
786        })));
787        idx
788    }
789
790    /// Concatenate tensors along an axis.
791    /// All input tensors must have the same shape except along the concatenation axis.
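    ///
    /// A small usage sketch (assumes `g` is a `GradGraph` and `a`, `b` are 1-D node indices of lengths 2 and 3):
    ///
    /// ```rust,ignore
    /// let c = g.cat(&[a, b], 0); // resulting node has shape [5]
    /// ```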
792    pub fn cat(&mut self, inputs: &[usize], axis: usize) -> usize {
793        let tensors: Vec<Tensor> = inputs.iter()
794            .map(|&i| self.nodes[i].borrow().tensor.clone())
795            .collect();
796        let sizes: Vec<usize> = tensors.iter()
797            .map(|t| t.shape()[axis])
798            .collect();
799        // Build concatenated data along axis
800        // Handles 1-D tensors and 2-D tensors along axis 0 or 1; other cases fall back to a flat concat
801        let mut all_data = Vec::new();
802        let mut total_along_axis = 0usize;
803        let ndim = tensors[0].ndim();
804        let mut result_shape = tensors[0].shape().to_vec();
805
806        if ndim == 1 {
807            // 1-D: just concatenate flat data
808            for t in &tensors {
809                all_data.extend(t.to_vec());
810                total_along_axis += t.shape()[0];
811            }
812            result_shape[0] = total_along_axis;
813        } else if ndim == 2 && axis == 0 {
814            // Concat along rows
815            for t in &tensors {
816                all_data.extend(t.to_vec());
817                total_along_axis += t.shape()[0];
818            }
819            result_shape[0] = total_along_axis;
820        } else if ndim == 2 && axis == 1 {
821            // Concat along columns
822            let nrows = tensors[0].shape()[0];
823            for row in 0..nrows {
824                for t in &tensors {
825                    let cols = t.shape()[1];
826                    let row_data = t.to_vec();
827                    let start = row * cols;
828                    all_data.extend_from_slice(&row_data[start..start + cols]);
829                }
830            }
831            total_along_axis = sizes.iter().sum();
832            result_shape[1] = total_along_axis;
833        } else {
834            // General case: flatten, cat, reshape (deterministic but limited)
835            for t in &tensors {
836                all_data.extend(t.to_vec());
837                total_along_axis += t.shape()[axis];
838            }
839            result_shape[axis] = total_along_axis;
840        }
841
842        let result = Tensor::from_vec_unchecked(all_data, &result_shape);
843        let input_vec = inputs.to_vec();
844        let idx = self.nodes.len();
845        self.nodes.push(Rc::new(RefCell::new(GradNode {
846            op: GradOp::CatOp { inputs: input_vec, axis, sizes },
847            tensor: result,
848            grad: None,
849        })));
850        idx
851    }
852
853    /// Gather elements along an axis using indices.
854    /// For a 1-D tensor, returns tensor[indices].
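    ///
    /// A small usage sketch (assumes `g` is a `GradGraph` and `a` is a 1-D node index):
    ///
    /// ```rust,ignore
    /// let picked = g.gather(a, &[2, 0], 0); // elements 2 and 0, in that order
    /// ```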
855    pub fn gather(&mut self, a: usize, indices: &[usize], axis: usize) -> usize {
856        let a_t = self.nodes[a].borrow().tensor.clone();
857        let data = a_t.to_vec();
858        // For 1-D: just pick elements at indices
859        let gathered: Vec<f64> = if a_t.ndim() == 1 {
860            indices.iter().map(|&i| data[i]).collect()
861        } else if a_t.ndim() == 2 && axis == 0 {
862            let cols = a_t.shape()[1];
863            indices.iter().flat_map(|&i| {
864                let start = i * cols;
865                data[start..start + cols].to_vec()
866            }).collect()
867        } else {
868            // Fallback: gather from flat data
869            indices.iter().map(|&i| data[i]).collect()
870        };
871        // The result shape must mirror the data branches above: the 2-D axis-0 case
872        // keeps the column count; every other path gathers from flat data and
873        // therefore yields a flat vector of `indices.len()` elements.
874        let result_shape: Vec<usize> = if a_t.ndim() == 2 && axis == 0 {
875            vec![indices.len(), a_t.shape()[1]]
876        } else {
877            vec![indices.len()]
878        };
879        let result = Tensor::from_vec_unchecked(gathered, &result_shape);
880        let idx = self.nodes.len();
881        self.nodes.push(Rc::new(RefCell::new(GradNode {
882            op: GradOp::GatherOp { input: a, indices: indices.to_vec(), axis },
883            tensor: result,
884            grad: None,
885        })));
886        idx
887    }
888
889    // ── Phase C1: Missing forward methods ──
890
891    /// Element-wise division: a / b.
892    /// GradOp::Div(a, b) already has backward implementation.
893    pub fn div(&mut self, a: usize, b: usize) -> usize {
894        let a_tensor = self.nodes[a].borrow().tensor.clone();
895        let b_tensor = self.nodes[b].borrow().tensor.clone();
896        let result = a_tensor.div_elem_unchecked(&b_tensor);
897        let node = GradNode { op: GradOp::Div(a, b), tensor: result, grad: None };
898        self.nodes.push(Rc::new(RefCell::new(node)));
899        self.nodes.len() - 1
900    }
901
902    /// Element-wise negation: -a.
903    /// GradOp::Neg(a) already has backward implementation.
904    pub fn neg(&mut self, a: usize) -> usize {
905        let a_tensor = self.nodes[a].borrow().tensor.clone();
906        let result = a_tensor.neg();
907        let node = GradNode { op: GradOp::Neg(a), tensor: result, grad: None };
908        self.nodes.push(Rc::new(RefCell::new(node)));
909        self.nodes.len() - 1
910    }
911
912    /// Scalar multiply: a * s (where s is an f64 constant).
913    /// GradOp::ScalarMul(a, s) already has backward implementation.
914    pub fn scalar_mul(&mut self, a: usize, s: f64) -> usize {
915        let a_tensor = self.nodes[a].borrow().tensor.clone();
916        let result = a_tensor.scalar_mul(s);
917        let node = GradNode { op: GradOp::ScalarMul(a, s), tensor: result, grad: None };
918        self.nodes.push(Rc::new(RefCell::new(node)));
919        self.nodes.len() - 1
920    }
921
922    /// Element-wise exponential: exp(a).
923    /// GradOp::Exp(a) already has backward implementation.
924    pub fn exp(&mut self, a: usize) -> usize {
925        let a_tensor = self.nodes[a].borrow().tensor.clone();
926        let result = Tensor::from_vec_unchecked(
927            a_tensor.to_vec().iter().map(|x| x.exp()).collect(),
928            a_tensor.shape(),
929        );
930        let node = GradNode { op: GradOp::Exp(a), tensor: result, grad: None };
931        self.nodes.push(Rc::new(RefCell::new(node)));
932        self.nodes.len() - 1
933    }
934
935    /// Element-wise natural logarithm: ln(a).
936    /// GradOp::Ln(a) already has backward implementation.
937    pub fn ln(&mut self, a: usize) -> usize {
938        let a_tensor = self.nodes[a].borrow().tensor.clone();
939        let result = Tensor::from_vec_unchecked(
940            a_tensor.to_vec().iter().map(|x| x.ln()).collect(),
941            a_tensor.shape(),
942        );
943        let node = GradNode { op: GradOp::Ln(a), tensor: result, grad: None };
944        self.nodes.push(Rc::new(RefCell::new(node)));
945        self.nodes.len() - 1
946    }
947
948    /// Get the scalar value from a 1-element tensor node.
949    pub fn value(&self, idx: usize) -> f64 {
950        let node = self.nodes[idx].borrow();
951        let data = node.tensor.to_vec();
952        data[0]
953    }
954
955    /// Get the tensor at a node.
956    pub fn tensor(&self, idx: usize) -> Tensor {
957        self.nodes[idx].borrow().tensor.clone()
958    }
959
960    /// Set the tensor at a node (for parameter updates).
961    pub fn set_tensor(&self, idx: usize, tensor: Tensor) {
962        self.nodes[idx].borrow_mut().tensor = tensor;
963    }
964
965    /// Get the gradient at a node.
966    pub fn grad(&self, idx: usize) -> Option<Tensor> {
967        self.nodes[idx].borrow().grad.clone()
968    }
969
970    /// Zero out all gradients.
971    pub fn zero_grad(&self) {
972        for node in &self.nodes {
973            let mut n = node.borrow_mut();
974            if let Some(ref mut grad) = n.grad {
975                let shape = grad.shape().to_vec();
976                *grad = Tensor::zeros(&shape);
977            }
978        }
979    }
980
981    /// Clip all gradients to `[-max_norm, max_norm]` (element-wise).
982    /// This helps mitigate gradient explosion during training.
983    pub fn clip_grad(&self, max_norm: f64) {
984        for node in &self.nodes {
985            let mut n = node.borrow_mut();
986            if let Some(ref mut grad) = n.grad {
987                let data = grad.to_vec();
988                let clipped: Vec<f64> = data.iter()
989                    .map(|&x| x.max(-max_norm).min(max_norm))
990                    .collect();
991                let shape = grad.shape().to_vec();
992                *grad = Tensor::from_vec_unchecked(clipped, &shape);
993            }
994        }
995    }
996
997    /// Clip gradients by global norm: if ||grads||_2 > max_norm, scale all
998    /// gradients so the global norm equals max_norm. Deterministic via
999    /// sequential accumulation.
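    ///
    /// A small usage sketch (assumes `g` is a `GradGraph` and `loss` is a scalar loss node):
    ///
    /// ```rust,ignore
    /// g.backward(loss);
    /// let norm_before_clipping = g.clip_grad_norm(1.0); // scales grads if their global norm exceeds 1.0
    /// ```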
1000    pub fn clip_grad_norm(&self, max_norm: f64) -> f64 {
1001        use cjc_repro::KahanAccumulatorF64;
1002        // Compute global norm
1003        let mut acc = KahanAccumulatorF64::new();
1004        for node in &self.nodes {
1005            let n = node.borrow();
1006            if let Some(ref grad) = n.grad {
1007                for &v in &grad.to_vec() {
1008                    acc.add(v * v);
1009                }
1010            }
1011        }
1012        let global_norm = acc.finalize().sqrt();
1013
1014        if global_norm > max_norm && global_norm > 0.0 {
1015            let scale = max_norm / global_norm;
1016            for node in &self.nodes {
1017                let mut n = node.borrow_mut();
1018                if let Some(ref mut grad) = n.grad {
1019                    let data = grad.to_vec();
1020                    let scaled: Vec<f64> = data.iter().map(|&x| x * scale).collect();
1021                    let shape = grad.shape().to_vec();
1022                    *grad = Tensor::from_vec_unchecked(scaled, &shape);
1023                }
1024            }
1025        }
1026
1027        global_norm
1028    }
1029
1030    /// Run the backward pass from a loss node. Nodes are visited in reverse index order, which is a valid reverse topological order because the graph is append-only (operands always precede results).
1031    pub fn backward(&self, loss_idx: usize) {
1032        let n = self.nodes.len();
1033
1034        // Initialize gradients
1035        let mut grads: Vec<Option<Tensor>> = vec![None; n];
1036
1037        // Loss gradient is 1.0
1038        let loss_shape = self.nodes[loss_idx].borrow().tensor.shape().to_vec();
1039        grads[loss_idx] = Some(Tensor::ones(&loss_shape));
1040
1041        // Backward pass in reverse topological order
1042        for i in (0..=loss_idx).rev() {
1043            let grad = match grads[i].take() {
1044                Some(g) => g,
1045                None => continue,
1046            };
1047
1048            // Clone op and tensor out of the borrow so we don't hold the RefCell across match arms
1049            let (op, node_tensor) = {
1050                let node = self.nodes[i].borrow();
1051                (node.op.clone(), node.tensor.clone())
1052            };
1053
1054            match op {
1055                GradOp::Input => {}
1056                GradOp::Parameter => {
1057                    let mut node_mut = self.nodes[i].borrow_mut();
1058                    if let Some(ref mut existing_grad) = node_mut.grad {
1059                        *existing_grad = existing_grad.add_unchecked(&grad);
1060                    } else {
1061                        node_mut.grad = Some(grad);
1062                    }
1063                }
1064                GradOp::Add(a, b) => {
1065                    accumulate_grad(&mut grads, a, &grad);
1066                    accumulate_grad(&mut grads, b, &grad);
1067                }
1068                GradOp::Sub(a, b) => {
1069                    accumulate_grad(&mut grads, a, &grad);
1070                    let neg_grad = grad.neg();
1071                    accumulate_grad(&mut grads, b, &neg_grad);
1072                }
1073                GradOp::Mul(a, b) => {
1074                    let a_val = self.nodes[a].borrow().tensor.clone();
1075                    let b_val = self.nodes[b].borrow().tensor.clone();
1076
1077                    let grad_a = grad.mul_elem_unchecked(&b_val);
1078                    let grad_b = grad.mul_elem_unchecked(&a_val);
1079
1080                    accumulate_grad(&mut grads, a, &grad_a);
1081                    accumulate_grad(&mut grads, b, &grad_b);
1082                }
1083                GradOp::Div(a, b) => {
1084                    let a_val = self.nodes[a].borrow().tensor.clone();
1085                    let b_val = self.nodes[b].borrow().tensor.clone();
1086
1087                    // d/da (a/b) = 1/b
1088                    let grad_a = grad.div_elem_unchecked(&b_val);
1089                    // d/db (a/b) = -a/b^2
1090                    let b_sq = b_val.mul_elem_unchecked(&b_val);
1091                    let neg_a = a_val.neg();
1092                    let grad_b = grad.mul_elem_unchecked(&neg_a.div_elem_unchecked(&b_sq));
1093
1094                    accumulate_grad(&mut grads, a, &grad_a);
1095                    accumulate_grad(&mut grads, b, &grad_b);
1096                }
1097                GradOp::Neg(a) => {
1098                    let neg_grad = grad.neg();
1099                    accumulate_grad(&mut grads, a, &neg_grad);
1100                }
1101                GradOp::MatMul(a, b) => {
1102                    // d/da (a @ b) = grad @ b^T
1103                    // d/db (a @ b) = a^T @ grad
1104                    let a_val = self.nodes[a].borrow().tensor.clone();
1105                    let b_val = self.nodes[b].borrow().tensor.clone();
1106
1107                    let b_t = b_val.transpose();
1108                    let a_t = a_val.transpose();
1109
1110                    let grad_a = grad.matmul_unchecked(&b_t);
1111                    let grad_b = a_t.matmul_unchecked(&grad);
1112
1113                    accumulate_grad(&mut grads, a, &grad_a);
1114                    accumulate_grad(&mut grads, b, &grad_b);
1115                }
1116                GradOp::Sum(a) => {
1117                    // Gradient of sum is all ones, scaled by upstream grad
1118                    let a_shape = self.nodes[a].borrow().tensor.shape().to_vec();
1119                    let grad_val = grad.to_vec()[0];
1120                    let expanded = Tensor::from_vec_unchecked(
1121                        vec![grad_val; a_shape.iter().product()],
1122                        &a_shape,
1123                    );
1124                    accumulate_grad(&mut grads, a, &expanded);
1125                }
1126                GradOp::Mean(a) => {
1127                    let a_shape = self.nodes[a].borrow().tensor.shape().to_vec();
1128                    let n_elem = a_shape.iter().product::<usize>() as f64;
1129                    let grad_val = grad.to_vec()[0] / n_elem;
1130                    let expanded = Tensor::from_vec_unchecked(
1131                        vec![grad_val; a_shape.iter().product()],
1132                        &a_shape,
1133                    );
1134                    accumulate_grad(&mut grads, a, &expanded);
1135                }
1136                GradOp::ScalarMul(a, s) => {
1137                    let scaled = grad.scalar_mul(s);
1138                    accumulate_grad(&mut grads, a, &scaled);
1139                }
1140                GradOp::Exp(a) => {
1141                    let grad_a = grad.mul_elem_unchecked(&node_tensor);
1142                    accumulate_grad(&mut grads, a, &grad_a);
1143                }
1144                GradOp::Ln(a) => {
1145                    let a_val = self.nodes[a].borrow().tensor.clone();
1146                    let grad_a = grad.div_elem_unchecked(&a_val);
1147                    accumulate_grad(&mut grads, a, &grad_a);
1148                }
1149                // Differentiable container ops: gradient flows back to parent
1150                GradOp::StructField {
1151                    parent,
1152                    field_index,
1153                    total_fields,
1154                } => {
1155                    // Gradient for field access: create a "one-hot" gradient
1156                    // where only the accessed field gets the incoming gradient.
1157                    // For now, accumulate directly to parent.
1158                    let _ = (field_index, total_fields);
1159                    accumulate_grad(&mut grads, parent, &grad);
1160                }
1161                GradOp::MapLookup {
1162                    map_node,
1163                    key_index,
1164                    total_keys,
1165                } => {
1166                    // Gradient for map lookup: accumulate to map node.
1167                    // Deterministic: key_index determines order of accumulation.
1168                    let _ = (key_index, total_keys);
1169                    accumulate_grad(&mut grads, map_node, &grad);
1170                }
1171                // Phase B8: Transcendental & activation backward
1172                GradOp::Sin(a) => {
1173                    let a_val = self.nodes[a].borrow().tensor.clone();
1174                    let cos_a = Tensor::from_vec_unchecked(
1175                        a_val.to_vec().iter().map(|&x| x.cos()).collect(),
1176                        a_val.shape(),
1177                    );
1178                    let grad_a = grad.mul_elem_unchecked(&cos_a);
1179                    accumulate_grad(&mut grads, a, &grad_a);
1180                }
1181                GradOp::Cos(a) => {
1182                    let a_val = self.nodes[a].borrow().tensor.clone();
1183                    let neg_sin_a = Tensor::from_vec_unchecked(
1184                        a_val.to_vec().iter().map(|&x| -x.sin()).collect(),
1185                        a_val.shape(),
1186                    );
1187                    let grad_a = grad.mul_elem_unchecked(&neg_sin_a);
1188                    accumulate_grad(&mut grads, a, &grad_a);
1189                }
1190                GradOp::Sqrt(a) => {
1191                    // d/dx sqrt(x) = 0.5 / sqrt(x) = 0.5 / node_tensor
1192                    let inv_2sqrt = Tensor::from_vec_unchecked(
1193                        node_tensor.to_vec().iter().map(|&x| 0.5 / x).collect(),
1194                        node_tensor.shape(),
1195                    );
1196                    let grad_a = grad.mul_elem_unchecked(&inv_2sqrt);
1197                    accumulate_grad(&mut grads, a, &grad_a);
1198                }
1199                GradOp::Pow(a, n) => {
1200                    let a_val = self.nodes[a].borrow().tensor.clone();
1201                    let coeff = Tensor::from_vec_unchecked(
1202                        a_val.to_vec().iter().map(|&x| n * x.powf(n - 1.0)).collect(),
1203                        a_val.shape(),
1204                    );
1205                    let grad_a = grad.mul_elem_unchecked(&coeff);
1206                    accumulate_grad(&mut grads, a, &grad_a);
1207                }
1208                GradOp::Sigmoid(a) => {
1209                    // sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x))
1210                    let sig = &node_tensor;
1211                    let one_minus = Tensor::from_vec_unchecked(
1212                        sig.to_vec().iter().map(|&s| 1.0 - s).collect(),
1213                        sig.shape(),
1214                    );
1215                    let local = sig.mul_elem_unchecked(&one_minus);
1216                    let grad_a = grad.mul_elem_unchecked(&local);
1217                    accumulate_grad(&mut grads, a, &grad_a);
1218                }
1219                GradOp::Relu(a) => {
1220                    let a_val = self.nodes[a].borrow().tensor.clone();
1221                    let mask = Tensor::from_vec_unchecked(
1222                        a_val.to_vec().iter().map(|&x| if x > 0.0 { 1.0 } else { 0.0 }).collect(),
1223                        a_val.shape(),
1224                    );
1225                    let grad_a = grad.mul_elem_unchecked(&mask);
1226                    accumulate_grad(&mut grads, a, &grad_a);
1227                }
1228                GradOp::TanhAct(a) => {
1229                    // tanh'(x) = 1 - tanh(x)^2
1230                    let t = &node_tensor;
1231                    let one_minus_sq = Tensor::from_vec_unchecked(
1232                        t.to_vec().iter().map(|&x| 1.0 - x * x).collect(),
1233                        t.shape(),
1234                    );
1235                    let grad_a = grad.mul_elem_unchecked(&one_minus_sq);
1236                    accumulate_grad(&mut grads, a, &grad_a);
1237                }
1238                // Phase 8: Extended AD backward
1239                GradOp::Abs(a) => {
1240                    let a_val = self.nodes[a].borrow().tensor.clone();
1241                    let sign = Tensor::from_vec_unchecked(
1242                        a_val.to_vec().iter().map(|&x| {
1243                            if x > 0.0 { 1.0 } else if x < 0.0 { -1.0 } else { 0.0 }
1244                        }).collect(),
1245                        a_val.shape(),
1246                    );
1247                    let grad_a = grad.mul_elem_unchecked(&sign);
1248                    accumulate_grad(&mut grads, a, &grad_a);
1249                }
1250                GradOp::Log2(a) => {
1251                    // d/dx log2(x) = 1 / (x * ln(2))
1252                    let a_val = self.nodes[a].borrow().tensor.clone();
1253                    let ln2 = std::f64::consts::LN_2;
1254                    let local = Tensor::from_vec_unchecked(
1255                        a_val.to_vec().iter().map(|&x| 1.0 / (x * ln2)).collect(),
1256                        a_val.shape(),
1257                    );
1258                    let grad_a = grad.mul_elem_unchecked(&local);
1259                    accumulate_grad(&mut grads, a, &grad_a);
1260                }
1261                GradOp::Softmax(a) => {
1262                    // Jacobian-vector product: grad_input = softmax * (grad - sum(grad * softmax))
1263                    use cjc_repro::KahanAccumulatorF64;
1264                    let sm = &node_tensor;
1265                    let sm_data = sm.to_vec();
1266                    let grad_data = grad.to_vec();
1267                    let mut dot_acc = KahanAccumulatorF64::new();
1268                    for (&g, &s) in grad_data.iter().zip(sm_data.iter()) {
1269                        dot_acc.add(g * s);
1270                    }
1271                    let dot = dot_acc.finalize();
1272                    let grad_input: Vec<f64> = sm_data.iter().zip(grad_data.iter())
1273                        .map(|(&s, &g)| s * (g - dot))
1274                        .collect();
1275                    let grad_a = Tensor::from_vec_unchecked(grad_input, sm.shape());
1276                    accumulate_grad(&mut grads, a, &grad_a);
1277                }
1278                GradOp::CrossEntropy { logits, targets } => {
1279                    // Combined softmax + cross-entropy gradient: grad_logits = grad * (softmax - targets)
1280                    use cjc_repro::KahanAccumulatorF64;
1281                    let logits_val = self.nodes[logits].borrow().tensor.clone();
1282                    let targets_val = self.nodes[targets].borrow().tensor.clone();
1283                    let logits_data = logits_val.to_vec();
1284                    let targets_data = targets_val.to_vec();
1285                    // Compute softmax of logits (numerically stable)
1286                    let max_val = logits_data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
1287                    let exp_shifted: Vec<f64> = logits_data.iter().map(|&x| (x - max_val).exp()).collect();
1288                    let mut sum_acc = KahanAccumulatorF64::new();
1289                    for &v in &exp_shifted {
1290                        sum_acc.add(v);
1291                    }
1292                    let sum_exp = sum_acc.finalize();
1293                    let softmax: Vec<f64> = exp_shifted.iter().map(|&e| e / sum_exp).collect();
1294                    // grad_logits = upstream_grad * (softmax - targets)
1295                    let upstream = grad.to_vec()[0]; // CE produces scalar
1296                    let grad_logits: Vec<f64> = softmax.iter().zip(targets_data.iter())
1297                        .map(|(&s, &t)| upstream * (s - t))
1298                        .collect();
1299                    let gl = Tensor::from_vec_unchecked(grad_logits, logits_val.shape());
1300                    accumulate_grad(&mut grads, logits, &gl);
1301                    // No gradient flows to targets (they are labels)
1302                }
1303                GradOp::LayerNorm(a) => {
1304                    // Layer norm backward:
1305                    // dx = (1/std) * (grad - mean(grad) - x_hat * mean(grad * x_hat))
1306                    // where x_hat = normalized output (node_tensor)
1307                    use cjc_repro::KahanAccumulatorF64;
1308                    let x_hat = &node_tensor;
1309                    let x_hat_data = x_hat.to_vec();
1310                    let grad_data = grad.to_vec();
1311                    let n = x_hat_data.len() as f64;
1312                    // Reconstruct std from input
1313                    let a_val = self.nodes[a].borrow().tensor.clone();
1314                    let a_data = a_val.to_vec();
1315                    let mut mean_acc = KahanAccumulatorF64::new();
1316                    for &v in &a_data {
1317                        mean_acc.add(v);
1318                    }
1319                    let mean = mean_acc.finalize() / n;
1320                    let mut var_acc = KahanAccumulatorF64::new();
1321                    for &v in &a_data {
1322                        let d = v - mean;
1323                        var_acc.add(d * d);
1324                    }
1325                    let var = var_acc.finalize() / n;
1326                    let eps = 1e-5;
1327                    let std_val = (var + eps).sqrt();
1328                    // mean(grad)
1329                    let mut mg_acc = KahanAccumulatorF64::new();
1330                    for &g in &grad_data {
1331                        mg_acc.add(g);
1332                    }
1333                    let mean_grad = mg_acc.finalize() / n;
1334                    // mean(grad * x_hat)
1335                    let mut mgx_acc = KahanAccumulatorF64::new();
1336                    for (&g, &xh) in grad_data.iter().zip(x_hat_data.iter()) {
1337                        mgx_acc.add(g * xh);
1338                    }
1339                    let mean_grad_xhat = mgx_acc.finalize() / n;
1340                    let dx: Vec<f64> = grad_data.iter().zip(x_hat_data.iter())
1341                        .map(|(&g, &xh)| (g - mean_grad - xh * mean_grad_xhat) / std_val)
1342                        .collect();
1343                    let grad_a = Tensor::from_vec_unchecked(dx, a_val.shape());
1344                    accumulate_grad(&mut grads, a, &grad_a);
1345                }
1346                GradOp::BatchNorm(a) => {
1347                    // Identical to LayerNorm backward for the per-tensor case
1348                    use cjc_repro::KahanAccumulatorF64;
1349                    let x_hat = &node_tensor;
1350                    let x_hat_data = x_hat.to_vec();
1351                    let grad_data = grad.to_vec();
1352                    let n = x_hat_data.len() as f64;
1353                    let a_val = self.nodes[a].borrow().tensor.clone();
1354                    let a_data = a_val.to_vec();
1355                    let mut mean_acc = KahanAccumulatorF64::new();
1356                    for &v in &a_data {
1357                        mean_acc.add(v);
1358                    }
1359                    let mean = mean_acc.finalize() / n;
1360                    let mut var_acc = KahanAccumulatorF64::new();
1361                    for &v in &a_data {
1362                        let d = v - mean;
1363                        var_acc.add(d * d);
1364                    }
1365                    let var = var_acc.finalize() / n;
1366                    let eps = 1e-5;
1367                    let std_val = (var + eps).sqrt();
1368                    let mut mg_acc = KahanAccumulatorF64::new();
1369                    for &g in &grad_data {
1370                        mg_acc.add(g);
1371                    }
1372                    let mean_grad = mg_acc.finalize() / n;
1373                    let mut mgx_acc = KahanAccumulatorF64::new();
1374                    for (&g, &xh) in grad_data.iter().zip(x_hat_data.iter()) {
1375                        mgx_acc.add(g * xh);
1376                    }
1377                    let mean_grad_xhat = mgx_acc.finalize() / n;
1378                    let dx: Vec<f64> = grad_data.iter().zip(x_hat_data.iter())
1379                        .map(|(&g, &xh)| (g - mean_grad - xh * mean_grad_xhat) / std_val)
1380                        .collect();
1381                    let grad_a = Tensor::from_vec_unchecked(dx, a_val.shape());
1382                    accumulate_grad(&mut grads, a, &grad_a);
1383                }
1384                GradOp::Clamp { input, min, max } => {
1385                    // Gradient passes through where input is in [min, max], else 0
1386                    let a_val = self.nodes[input].borrow().tensor.clone();
1387                    let mask = Tensor::from_vec_unchecked(
1388                        a_val.to_vec().iter().map(|&x| {
1389                            if x >= min && x <= max { 1.0 } else { 0.0 }
1390                        }).collect(),
1391                        a_val.shape(),
1392                    );
1393                    let grad_a = grad.mul_elem_unchecked(&mask);
1394                    accumulate_grad(&mut grads, input, &grad_a);
1395                }
1396                GradOp::Where { cond, on_true, on_false } => {
1397                    let cond_data = self.nodes[cond].borrow().tensor.to_vec();
1398                    let grad_data = grad.to_vec();
1399                    let shape = grad.shape().to_vec();
1400                    let grad_true: Vec<f64> = cond_data.iter().zip(grad_data.iter())
1401                        .map(|(&c, &g)| if c != 0.0 { g } else { 0.0 })
1402                        .collect();
1403                    let grad_false: Vec<f64> = cond_data.iter().zip(grad_data.iter())
1404                        .map(|(&c, &g)| if c != 0.0 { 0.0 } else { g })
1405                        .collect();
1406                    let gt = Tensor::from_vec_unchecked(grad_true, &shape);
1407                    let gf = Tensor::from_vec_unchecked(grad_false, &shape);
1408                    accumulate_grad(&mut grads, on_true, &gt);
1409                    accumulate_grad(&mut grads, on_false, &gf);
1410                    // No gradient to condition
1411                }
1412                GradOp::Reshape { input, ref original_shape } => {
1413                    // Backward: reshape grad back to original shape
1414                    let grad_a = grad.reshape(original_shape)
1415                        .expect("Reshape backward: shape mismatch");
1416                    accumulate_grad(&mut grads, input, &grad_a);
1417                }
1418                GradOp::TransposeOp(a) => {
1419                    // Transpose is its own inverse for 2-D
1420                    let grad_a = grad.transpose();
1421                    accumulate_grad(&mut grads, a, &grad_a);
1422                }
1423                GradOp::CatOp { ref inputs, axis, ref sizes } => {
1424                    // Split grad along axis into pieces matching sizes
1425                    let grad_data = grad.to_vec();
1426                    let grad_shape = grad.shape().to_vec();
1427                    let ndim = grad_shape.len();
1428                    if ndim == 1 {
1429                        let mut offset = 0usize;
1430                        for (&idx, &sz) in inputs.iter().zip(sizes.iter()) {
1431                            let piece = grad_data[offset..offset + sz].to_vec();
1432                            let gt = Tensor::from_vec_unchecked(piece, &[sz]);
1433                            accumulate_grad(&mut grads, idx, &gt);
1434                            offset += sz;
1435                        }
1436                    } else if ndim == 2 && axis == 0 {
1437                        let cols = grad_shape[1];
1438                        let mut row_offset = 0usize;
1439                        for (&idx, &sz) in inputs.iter().zip(sizes.iter()) {
1440                            let start = row_offset * cols;
1441                            let end = start + sz * cols;
1442                            let piece = grad_data[start..end].to_vec();
1443                            let gt = Tensor::from_vec_unchecked(piece, &[sz, cols]);
1444                            accumulate_grad(&mut grads, idx, &gt);
1445                            row_offset += sz;
1446                        }
1447                    } else if ndim == 2 && axis == 1 {
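                        // Axis-1 concat interleaves columns: for each row, this input's slice of
                        // the gradient is the contiguous column block starting at col_offset.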
1448                        let nrows = grad_shape[0];
1449                        let total_cols = grad_shape[1];
1450                        for (input_idx, (&idx, &sz)) in inputs.iter().zip(sizes.iter()).enumerate() {
1451                            let mut piece = Vec::with_capacity(nrows * sz);
1452                            let col_offset: usize = sizes[..input_idx].iter().sum();
1453                            for row in 0..nrows {
1454                                let row_start = row * total_cols + col_offset;
1455                                piece.extend_from_slice(&grad_data[row_start..row_start + sz]);
1456                            }
1457                            let gt = Tensor::from_vec_unchecked(piece, &[nrows, sz]);
1458                            accumulate_grad(&mut grads, idx, &gt);
1459                        }
1460                    } else {
1461                        // General fallback: split flat data proportionally (exact only when axis is the leading dimension)
1462                        let mut offset = 0usize;
1463                        for (&idx, &sz) in inputs.iter().zip(sizes.iter()) {
1464                            let piece_len = sz * grad_data.len() / grad_shape[axis];
1465                            let piece = grad_data[offset..offset + piece_len].to_vec();
1466                            let mut piece_shape = grad_shape.clone();
1467                            piece_shape[axis] = sz;
1468                            let gt = Tensor::from_vec_unchecked(piece, &piece_shape);
1469                            accumulate_grad(&mut grads, idx, &gt);
1470                            offset += piece_len;
1471                        }
1472                    }
1473                }
1474                GradOp::GatherOp { input, ref indices, axis } => {
1475                    // Scatter-add: distribute grad back to input positions
1476                    let input_shape = self.nodes[input].borrow().tensor.shape().to_vec();
1477                    let input_len: usize = input_shape.iter().product();
1478                    let mut scatter = vec![0.0_f64; input_len];
1479                    let grad_data = grad.to_vec();
1480                    if self.nodes[input].borrow().tensor.ndim() == 1 {
1481                        for (gi, &idx) in indices.iter().enumerate() {
1482                            scatter[idx] += grad_data[gi];
1483                        }
1484                    } else if axis == 0 && self.nodes[input].borrow().tensor.ndim() == 2 {
1485                        let cols = input_shape[1];
1486                        for (gi, &idx) in indices.iter().enumerate() {
1487                            for c in 0..cols {
1488                                scatter[idx * cols + c] += grad_data[gi * cols + c];
1489                            }
1490                        }
1491                    } else {
1492                        // Fallback: treat as flat index
1493                        for (gi, &idx) in indices.iter().enumerate() {
1494                            scatter[idx] += grad_data[gi];
1495                        }
1496                    }
1497                    let grad_a = Tensor::from_vec_unchecked(scatter, &input_shape);
1498                    accumulate_grad(&mut grads, input, &grad_a);
1499                }
1500            }
1501        }
1502    }
1503
1504    /// Compute the Jacobian of a vector-valued output node with respect to
1505    /// a parameter node. Returns a 2D tensor of shape [output_dim, param_dim].
1506    ///
1507    /// Strategy: run backward once per output element with a one-hot seed.
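    ///
    /// # Examples
    ///
    /// A minimal sketch; the graph-construction calls (`parameter`, `input`, `matmul`)
    /// are assumed to exist elsewhere in this crate and are used here only to set up
    /// node indices for `jacobian()`.
    ///
    /// ```rust,ignore
    /// let mut g = GradGraph::new();
    /// let w = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]));
    /// let x = g.input(Tensor::from_vec_unchecked(vec![1.0, 1.0], &[2, 1]));
    /// let y = g.matmul(w, x);        // 2 output elements, 4 parameter elements
    /// let j = g.jacobian(y, w);      // shape [2, 4]
    /// ```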
1508    pub fn jacobian(&mut self, output_idx: usize, param_idx: usize) -> Tensor {
1509        let output_shape = self.nodes[output_idx].borrow().tensor.shape().to_vec();
1510        let output_dim: usize = output_shape.iter().product();
1511        let param_shape = self.nodes[param_idx].borrow().tensor.shape().to_vec();
1512        let param_dim: usize = param_shape.iter().product();
1513
1514        let mut jac_data = vec![0.0_f64; output_dim * param_dim];
1515
1516        for i in 0..output_dim {
1517            // Zero all gradients
1518            self.zero_grad();
1519
1520            // Create one-hot seed for output element i
1521            let mut seed = vec![0.0_f64; output_dim];
1522            seed[i] = 1.0;
1523            let seed_tensor = Tensor::from_vec_unchecked(seed, &output_shape);
1524
1525            // Run backward with this seed
1526            self.backward_with_seed(output_idx, &seed_tensor);
1527
1528            // Read gradient of param node
1529            let grad = self.nodes[param_idx].borrow().grad.clone();
1530            if let Some(g) = grad {
1531                let g_vec = g.to_vec();
1532                for j in 0..param_dim {
1533                    jac_data[i * param_dim + j] = g_vec[j];
1534                }
1535            }
1536        }
1537
1538        Tensor::from_vec_unchecked(jac_data, &[output_dim, param_dim])
1539    }
1540
1541    /// Compute the diagonal of the Hessian of a scalar loss with respect to
1542    /// a parameter node. Uses central finite differences on the gradient:
1543    /// perturb the parameter by +/-eps, re-run the forward pass, and recompute the gradient.
1544    ///
1545    /// Returns a tensor of the same shape as the parameter.
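    ///
    /// Each entry is the central difference
    /// `H[i][i] ≈ (g_i(p + eps*e_i) - g_i(p - eps*e_i)) / (2 * eps)`,
    /// where `g_i` is the i-th gradient component and `e_i` the i-th unit vector.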
1546    pub fn hessian_diag(&mut self, loss_idx: usize, param_idx: usize, eps: f64) -> Tensor {
1547        let param_shape = self.nodes[param_idx].borrow().tensor.shape().to_vec();
1548        let param_dim: usize = param_shape.iter().product();
1549        let original = self.nodes[param_idx].borrow().tensor.to_vec();
1550        let mut hess_diag = vec![0.0_f64; param_dim];
1551
1552        for i in 0..param_dim {
1553            // Perturb +eps and re-run the forward pass so intermediate nodes are up-to-date
1554            let mut plus = original.clone();
1555            plus[i] += eps;
1556            self.nodes[param_idx].borrow_mut().tensor =
1557                Tensor::from_vec_unchecked(plus, &param_shape);
            self.reforward(param_idx + 1, loss_idx);
1558            self.zero_grad();
1559            self.backward(loss_idx);
1560            let grad_plus = self.nodes[param_idx]
1561                .borrow()
1562                .grad
1563                .as_ref()
1564                .map(|g| g.to_vec()[i])
1565                .unwrap_or(0.0);
1566
1567            // Perturb -eps and re-run the forward pass
1568            let mut minus = original.clone();
1569            minus[i] -= eps;
1570            self.nodes[param_idx].borrow_mut().tensor =
1571                Tensor::from_vec_unchecked(minus, &param_shape);
            self.reforward(param_idx + 1, loss_idx);
1572            self.zero_grad();
1573            self.backward(loss_idx);
1574            let grad_minus = self.nodes[param_idx]
1575                .borrow()
1576                .grad
1577                .as_ref()
1578                .map(|g| g.to_vec()[i])
1579                .unwrap_or(0.0);
1580
1581            hess_diag[i] = (grad_plus - grad_minus) / (2.0 * eps);
1582        }
1583
1584        // Restore original parameter and re-forward to a clean state
1585        self.nodes[param_idx].borrow_mut().tensor =
1586            Tensor::from_vec_unchecked(original, &param_shape);
        self.reforward(param_idx + 1, loss_idx);
1587
1588        Tensor::from_vec_unchecked(hess_diag, &param_shape)
1589    }
1590
1591    /// Compute the full Hessian matrix of a scalar loss with respect to a parameter node.
1592    ///
1593    /// Returns a 2D tensor of shape [param_dim, param_dim] where H[i, j] = d²loss / (dp_i dp_j).
1594    ///
1595    /// Strategy: For each parameter element i, perturb param[i] by +eps and -eps, re-run
1596    /// the forward pass to update intermediate node values, then run backward() to get the
1597    /// gradient vector. The i-th row of the Hessian is (grad_plus - grad_minus) / (2 * eps).
1598    /// Uses eps = 1e-5 as a compromise between truncation and round-off error in the central differences.
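    ///
    /// # Examples
    ///
    /// A sketch only; `loss` and `w` are node indices for a scalar loss and a parameter
    /// built elsewhere with this graph.
    ///
    /// ```rust,ignore
    /// // For loss = sum(w * w), the Hessian is 2 * I up to finite-difference error.
    /// let h = g.hessian(loss, w);
    /// ```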
1599    pub fn hessian(&mut self, loss_idx: usize, param_idx: usize) -> Tensor {
1600        let eps = 1e-5;
1601        let param_shape = self.nodes[param_idx].borrow().tensor.shape().to_vec();
1602        let param_dim: usize = param_shape.iter().product();
1603        let original = self.nodes[param_idx].borrow().tensor.to_vec();
1604        let mut hess_data = vec![0.0_f64; param_dim * param_dim];
1605
1606        for i in 0..param_dim {
1607            // Perturb +eps at index i, re-forward so intermediate nodes are up-to-date
1608            let mut plus = original.clone();
1609            plus[i] += eps;
1610            self.nodes[param_idx].borrow_mut().tensor =
1611                Tensor::from_vec_unchecked(plus, &param_shape);
1612            self.reforward(param_idx + 1, loss_idx);
1613            self.zero_grad();
1614            self.backward(loss_idx);
1615            let grad_plus: Vec<f64> = self.nodes[param_idx]
1616                .borrow()
1617                .grad
1618                .as_ref()
1619                .map(|g| g.to_vec())
1620                .unwrap_or_else(|| vec![0.0; param_dim]);
1621
1622            // Perturb -eps at index i, re-forward
1623            let mut minus = original.clone();
1624            minus[i] -= eps;
1625            self.nodes[param_idx].borrow_mut().tensor =
1626                Tensor::from_vec_unchecked(minus, &param_shape);
1627            self.reforward(param_idx + 1, loss_idx);
1628            self.zero_grad();
1629            self.backward(loss_idx);
1630            let grad_minus: Vec<f64> = self.nodes[param_idx]
1631                .borrow()
1632                .grad
1633                .as_ref()
1634                .map(|g| g.to_vec())
1635                .unwrap_or_else(|| vec![0.0; param_dim]);
1636
1637            // Row i of the Hessian: (grad_plus[j] - grad_minus[j]) / (2 * eps) for each j
1638            for j in 0..param_dim {
1639                hess_data[i * param_dim + j] = (grad_plus[j] - grad_minus[j]) / (2.0 * eps);
1640            }
1641        }
1642
1643        // Restore original parameter and re-forward to clean state
1644        self.nodes[param_idx].borrow_mut().tensor =
1645            Tensor::from_vec_unchecked(original, &param_shape);
1646        self.reforward(param_idx + 1, loss_idx);
1647
1648        Tensor::from_vec_unchecked(hess_data, &[param_dim, param_dim])
1649    }
1650
1651    /// Re-run the forward pass for all nodes from `start` up to and including `end`.
1652    ///
1653    /// This is needed before backward when a parameter has been perturbed, so that
1654    /// intermediate computation nodes hold updated tensor values.
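    ///
    /// Assumes the node list is topologically ordered (every operand has a smaller
    /// index than the node that consumes it), so one linear pass over `start..=end`
    /// is a valid forward pass.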
1655    fn reforward(&mut self, start: usize, end: usize) {
1656        for node_i in start..=end {
1657            let new_tensor = {
1658                let node = self.nodes[node_i].borrow();
1659                match &node.op {
1660                    GradOp::Input | GradOp::Parameter => continue,
1661                    GradOp::Add(a, b) => {
1662                        let at = self.nodes[*a].borrow().tensor.clone();
1663                        let bt = self.nodes[*b].borrow().tensor.clone();
1664                        at.add_unchecked(&bt)
1665                    }
1666                    GradOp::Sub(a, b) => {
1667                        let at = self.nodes[*a].borrow().tensor.clone();
1668                        let bt = self.nodes[*b].borrow().tensor.clone();
1669                        at.sub_unchecked(&bt)
1670                    }
1671                    GradOp::Mul(a, b) => {
1672                        let at = self.nodes[*a].borrow().tensor.clone();
1673                        let bt = self.nodes[*b].borrow().tensor.clone();
1674                        at.mul_elem_unchecked(&bt)
1675                    }
1676                    GradOp::Div(a, b) => {
1677                        let at = self.nodes[*a].borrow().tensor.clone();
1678                        let bt = self.nodes[*b].borrow().tensor.clone();
1679                        at.div_elem_unchecked(&bt)
1680                    }
1681                    GradOp::Neg(a) => {
1682                        self.nodes[*a].borrow().tensor.neg()
1683                    }
1684                    GradOp::ScalarMul(a, s) => {
1685                        let s = *s;
1686                        self.nodes[*a].borrow().tensor.scalar_mul(s)
1687                    }
1688                    GradOp::MatMul(a, b) => {
1689                        let at = self.nodes[*a].borrow().tensor.clone();
1690                        let bt = self.nodes[*b].borrow().tensor.clone();
1691                        at.matmul_unchecked(&bt)
1692                    }
1693                    GradOp::Sum(a) => {
1694                        let s = self.nodes[*a].borrow().tensor.sum();
1695                        Tensor::from_vec_unchecked(vec![s], &[1])
1696                    }
1697                    GradOp::Mean(a) => {
1698                        let m = self.nodes[*a].borrow().tensor.mean();
1699                        Tensor::from_vec_unchecked(vec![m], &[1])
1700                    }
1701                    GradOp::Exp(a) => {
1702                        let (data, shape) = {
1703                            let t = self.nodes[*a].borrow();
1704                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
1705                        };
1706                        Tensor::from_vec_unchecked(data.iter().map(|x| x.exp()).collect(), &shape)
1707                    }
1708                    GradOp::Ln(a) => {
1709                        let (data, shape) = {
1710                            let t = self.nodes[*a].borrow();
1711                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
1712                        };
1713                        Tensor::from_vec_unchecked(data.iter().map(|x| x.ln()).collect(), &shape)
1714                    }
1715                    GradOp::Sin(a) => {
1716                        let (data, shape) = {
1717                            let t = self.nodes[*a].borrow();
1718                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
1719                        };
1720                        Tensor::from_vec_unchecked(data.iter().map(|x| x.sin()).collect(), &shape)
1721                    }
1722                    GradOp::Cos(a) => {
1723                        let (data, shape) = {
1724                            let t = self.nodes[*a].borrow();
1725                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
1726                        };
1727                        Tensor::from_vec_unchecked(data.iter().map(|x| x.cos()).collect(), &shape)
1728                    }
1729                    GradOp::Sqrt(a) => {
1730                        let (data, shape) = {
1731                            let t = self.nodes[*a].borrow();
1732                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
1733                        };
1734                        Tensor::from_vec_unchecked(data.iter().map(|x| x.sqrt()).collect(), &shape)
1735                    }
1736                    GradOp::Pow(a, n) => {
1737                        let n = *n;
1738                        let (data, shape) = {
1739                            let t = self.nodes[*a].borrow();
1740                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
1741                        };
1742                        Tensor::from_vec_unchecked(data.iter().map(|x| x.powf(n)).collect(), &shape)
1743                    }
1744                    GradOp::Sigmoid(a) => {
1745                        let (data, shape) = {
1746                            let t = self.nodes[*a].borrow();
1747                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
1748                        };
1749                        Tensor::from_vec_unchecked(
1750                            data.iter().map(|&x| 1.0 / (1.0 + (-x).exp())).collect(),
1751                            &shape,
1752                        )
1753                    }
1754                    GradOp::Relu(a) => {
1755                        let (data, shape) = {
1756                            let t = self.nodes[*a].borrow();
1757                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
1758                        };
1759                        Tensor::from_vec_unchecked(
1760                            data.iter().map(|&x| if x > 0.0 { x } else { 0.0 }).collect(),
1761                            &shape,
1762                        )
1763                    }
1764                    GradOp::TanhAct(a) => {
1765                        let (data, shape) = {
1766                            let t = self.nodes[*a].borrow();
1767                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
1768                        };
1769                        Tensor::from_vec_unchecked(data.iter().map(|x| x.tanh()).collect(), &shape)
1770                    }
1771                    GradOp::Abs(a) => {
1772                        let (data, shape) = {
1773                            let t = self.nodes[*a].borrow();
1774                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
1775                        };
1776                        Tensor::from_vec_unchecked(data.iter().map(|x| x.abs()).collect(), &shape)
1777                    }
1778                    GradOp::Clamp { input, min, max } => {
1779                        let min = *min;
1780                        let max = *max;
1781                        let (data, shape) = {
1782                            let t = self.nodes[*input].borrow();
1783                            (t.tensor.to_vec(), t.tensor.shape().to_vec())
1784                        };
1785                        Tensor::from_vec_unchecked(
1786                            data.iter().map(|&x| x.max(min).min(max)).collect(),
1787                            &shape,
1788                        )
1789                    }
1790                    GradOp::Reshape { input, .. } => {
1791                        let current_shape = node.tensor.shape().to_vec();
1792                        let data = self.nodes[*input].borrow().tensor.to_vec();
1793                        Tensor::from_vec_unchecked(data, &current_shape)
1794                    }
1795                    GradOp::TransposeOp(a) => {
1796                        self.nodes[*a].borrow().tensor.transpose()
1797                    }
1798                    // For complex ops (softmax, layernorm, etc.), keep the existing tensor.
1799                    // Caution: if such an op lies between the parameter and the loss, its value
                    // goes stale and the finite-difference Hessian will be inaccurate.
1800                    _ => node.tensor.clone(),
1801                }
1802            };
1803            self.nodes[node_i].borrow_mut().tensor = new_tensor;
1804        }
1805    }
1806
1807    /// Compute the second derivative of a scalar loss with respect to a parameter node.
1808    ///
1809    /// Implements double_backward via finite differences on the backward pass:
1810    /// perturbs the parameter by +eps/-eps, re-runs the forward and backward pass,
1811    /// and computes d(grad)/d(param) numerically. For a scalar param this gives the
1812    /// exact second derivative d²loss/dparam².
1813    ///
1814    /// Returns a tensor of the same shape as the parameter containing second derivatives.
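    ///
    /// # Examples
    ///
    /// A sketch only; building `loss` (here assumed to compute `sum(p * p)`) from the
    /// parameter node `p` uses graph-construction helpers not shown here.
    ///
    /// ```rust,ignore
    /// let d2 = g.double_backward(loss, p);
    /// // d^2(sum(p^2))/dp_i^2 = 2 for every element, up to finite-difference error.
    /// assert!((d2.to_vec()[0] - 2.0).abs() < 1e-3);
    /// ```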
1815    pub fn double_backward(&mut self, loss_idx: usize, param_idx: usize) -> Tensor {
1816        let eps = 1e-5;
1817        let param_shape = self.nodes[param_idx].borrow().tensor.shape().to_vec();
1818        let param_dim: usize = param_shape.iter().product();
1819        let original = self.nodes[param_idx].borrow().tensor.to_vec();
1820        let mut diag = vec![0.0_f64; param_dim];
1821
1822        for i in 0..param_dim {
1823            // Perturb +eps, re-forward, backward
1824            let mut plus = original.clone();
1825            plus[i] += eps;
1826            self.nodes[param_idx].borrow_mut().tensor =
1827                Tensor::from_vec_unchecked(plus, &param_shape);
1828            self.reforward(param_idx + 1, loss_idx);
1829            self.zero_grad();
1830            self.backward(loss_idx);
1831            let grad_plus = self.nodes[param_idx]
1832                .borrow()
1833                .grad
1834                .as_ref()
1835                .map(|g| g.to_vec()[i])
1836                .unwrap_or(0.0);
1837
1838            // Perturb -eps, re-forward, backward
1839            let mut minus = original.clone();
1840            minus[i] -= eps;
1841            self.nodes[param_idx].borrow_mut().tensor =
1842                Tensor::from_vec_unchecked(minus, &param_shape);
1843            self.reforward(param_idx + 1, loss_idx);
1844            self.zero_grad();
1845            self.backward(loss_idx);
1846            let grad_minus = self.nodes[param_idx]
1847                .borrow()
1848                .grad
1849                .as_ref()
1850                .map(|g| g.to_vec()[i])
1851                .unwrap_or(0.0);
1852
1853            diag[i] = (grad_plus - grad_minus) / (2.0 * eps);
1854        }
1855
1856        // Restore original parameter and re-forward to clean state
1857        self.nodes[param_idx].borrow_mut().tensor =
1858            Tensor::from_vec_unchecked(original, &param_shape);
1859        self.reforward(param_idx + 1, loss_idx);
1860
1861        Tensor::from_vec_unchecked(diag, &param_shape)
1862    }
1863
1864    /// Vectorized map (batched evaluation) over a batch dimension.
1865    ///
1866    /// For each tensor in `batch_data`, sets the input node `input_idx` to that tensor,
1867    /// re-evaluates all downstream nodes by re-running the forward pass (recomputing
1868    /// tensor values from the graph structure), and records the output node index
1869    /// after each evaluation.
1870    ///
1871    /// Returns a `Vec<usize>` of output node indices (one per batch element). After
1872    /// calling this, `g.value(results[k])` returns the output for batch element k.
1873    ///
1874    /// Note: This is a simple batched evaluation helper. It mutates node tensors
1875    /// in-place. After calling vmap_forward, the graph holds the values for the
1876    /// LAST batch element. Use `g.value(results[k])` to read individual results
1877    /// (stored in snapshot tensors inside each returned node).
1878    ///
1879    /// Implementation: For each batch element, set the input tensor, re-forward every
1880    /// node after `input_idx` (to the end of the graph) by replaying its op, and record
1881    /// the last node's value in a fresh snapshot node (an `Input` node holding a copy).
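    ///
    /// # Examples
    ///
    /// A sketch only; `x` is an input node index and `t1`, `t2` are `Tensor`s prepared
    /// by the caller.
    ///
    /// ```rust,ignore
    /// let results = g.vmap_forward(x, &[t1, t2]);
    /// let y0 = g.value(results[0]); // output for the first batch element
    /// let y1 = g.value(results[1]); // output for the second
    /// ```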
1882    pub fn vmap_forward(&mut self, input_idx: usize, batch_data: &[Tensor]) -> Vec<usize> {
1883        let mut result_indices = Vec::with_capacity(batch_data.len());
1884
1885        // Re-evaluate every node after input_idx: a conservative superset of the nodes
1886        // that actually depend on it, up to the end of the current graph.
1887        let graph_len = self.nodes.len();
1888
1889        for batch_tensor in batch_data {
1890            // Set input node to batch element
1891            self.nodes[input_idx].borrow_mut().tensor = batch_tensor.clone();
1892
1893            // Re-run forward pass for all nodes after input_idx by replaying their ops
1894            for node_i in (input_idx + 1)..graph_len {
1895                let (op, new_tensor) = {
1896                    let node = self.nodes[node_i].borrow();
1897                    let op = node.op.clone();
1898                    let new_tensor = match &op {
1899                        GradOp::Add(a, b) => {
1900                            let at = self.nodes[*a].borrow().tensor.clone();
1901                            let bt = self.nodes[*b].borrow().tensor.clone();
1902                            at.add_unchecked(&bt)
1903                        }
1904                        GradOp::Sub(a, b) => {
1905                            let at = self.nodes[*a].borrow().tensor.clone();
1906                            let bt = self.nodes[*b].borrow().tensor.clone();
1907                            at.sub_unchecked(&bt)
1908                        }
1909                        GradOp::Mul(a, b) => {
1910                            let at = self.nodes[*a].borrow().tensor.clone();
1911                            let bt = self.nodes[*b].borrow().tensor.clone();
1912                            at.mul_elem_unchecked(&bt)
1913                        }
1914                        GradOp::Div(a, b) => {
1915                            let at = self.nodes[*a].borrow().tensor.clone();
1916                            let bt = self.nodes[*b].borrow().tensor.clone();
1917                            at.div_elem_unchecked(&bt)
1918                        }
1919                        GradOp::Neg(a) => {
1920                            self.nodes[*a].borrow().tensor.neg()
1921                        }
1922                        GradOp::ScalarMul(a, s) => {
1923                            self.nodes[*a].borrow().tensor.scalar_mul(*s)
1924                        }
1925                        GradOp::MatMul(a, b) => {
1926                            let at = self.nodes[*a].borrow().tensor.clone();
1927                            let bt = self.nodes[*b].borrow().tensor.clone();
1928                            at.matmul_unchecked(&bt)
1929                        }
1930                        GradOp::Sum(a) => {
1931                            let s = self.nodes[*a].borrow().tensor.sum();
1932                            let shape = vec![1usize];
1933                            Tensor::from_vec_unchecked(vec![s], &shape)
1934                        }
1935                        GradOp::Mean(a) => {
1936                            let m = self.nodes[*a].borrow().tensor.mean();
1937                            Tensor::from_vec_unchecked(vec![m], &[1])
1938                        }
1939                        GradOp::Exp(a) => {
1940                            let data = self.nodes[*a].borrow().tensor.to_vec();
1941                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1942                            Tensor::from_vec_unchecked(
1943                                data.iter().map(|x| x.exp()).collect(),
1944                                &shape,
1945                            )
1946                        }
1947                        GradOp::Ln(a) => {
1948                            let data = self.nodes[*a].borrow().tensor.to_vec();
1949                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1950                            Tensor::from_vec_unchecked(
1951                                data.iter().map(|x| x.ln()).collect(),
1952                                &shape,
1953                            )
1954                        }
1955                        GradOp::Sin(a) => {
1956                            let data = self.nodes[*a].borrow().tensor.to_vec();
1957                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1958                            Tensor::from_vec_unchecked(
1959                                data.iter().map(|x| x.sin()).collect(),
1960                                &shape,
1961                            )
1962                        }
1963                        GradOp::Cos(a) => {
1964                            let data = self.nodes[*a].borrow().tensor.to_vec();
1965                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1966                            Tensor::from_vec_unchecked(
1967                                data.iter().map(|x| x.cos()).collect(),
1968                                &shape,
1969                            )
1970                        }
1971                        GradOp::Sqrt(a) => {
1972                            let data = self.nodes[*a].borrow().tensor.to_vec();
1973                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1974                            Tensor::from_vec_unchecked(
1975                                data.iter().map(|x| x.sqrt()).collect(),
1976                                &shape,
1977                            )
1978                        }
1979                        GradOp::Pow(a, n) => {
1980                            let data = self.nodes[*a].borrow().tensor.to_vec();
1981                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1982                            Tensor::from_vec_unchecked(
1983                                data.iter().map(|x| x.powf(*n)).collect(),
1984                                &shape,
1985                            )
1986                        }
1987                        GradOp::Sigmoid(a) => {
1988                            let data = self.nodes[*a].borrow().tensor.to_vec();
1989                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1990                            Tensor::from_vec_unchecked(
1991                                data.iter().map(|&x| 1.0 / (1.0 + (-x).exp())).collect(),
1992                                &shape,
1993                            )
1994                        }
1995                        GradOp::Relu(a) => {
1996                            let data = self.nodes[*a].borrow().tensor.to_vec();
1997                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
1998                            Tensor::from_vec_unchecked(
1999                                data.iter().map(|&x| if x > 0.0 { x } else { 0.0 }).collect(),
2000                                &shape,
2001                            )
2002                        }
2003                        GradOp::TanhAct(a) => {
2004                            let data = self.nodes[*a].borrow().tensor.to_vec();
2005                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
2006                            Tensor::from_vec_unchecked(
2007                                data.iter().map(|x| x.tanh()).collect(),
2008                                &shape,
2009                            )
2010                        }
2011                        GradOp::Abs(a) => {
2012                            let data = self.nodes[*a].borrow().tensor.to_vec();
2013                            let shape = self.nodes[*a].borrow().tensor.shape().to_vec();
2014                            Tensor::from_vec_unchecked(
2015                                data.iter().map(|x| x.abs()).collect(),
2016                                &shape,
2017                            )
2018                        }
2019                        GradOp::Clamp { input, min, max } => {
2020                            let data = self.nodes[*input].borrow().tensor.to_vec();
2021                            let shape = self.nodes[*input].borrow().tensor.shape().to_vec();
2022                            Tensor::from_vec_unchecked(
2023                                data.iter().map(|&x| x.max(*min).min(*max)).collect(),
2024                                &shape,
2025                            )
2026                        }
2027                        GradOp::Reshape { input, .. } => {
2028                            // Keep same data, use current node's shape
2029                            let data = self.nodes[*input].borrow().tensor.to_vec();
2030                            let shape = node.tensor.shape().to_vec();
2031                            Tensor::from_vec_unchecked(data, &shape)
2032                        }
2033                        GradOp::TransposeOp(a) => {
2034                            self.nodes[*a].borrow().tensor.transpose()
2035                        }
2036                        // Ops that do not depend on input_idx keep their existing (still valid)
2037                        // values. Complex ops (softmax, layernorm, etc.) are not replayed here, so
                        // their values go stale if they do depend on the batched input.
2038                        _ => node.tensor.clone(),
2039                    };
2040                    (op, new_tensor)
2041                };
2042                let _ = op; // the cloned op was only needed while matching; discard it
2043                self.nodes[node_i].borrow_mut().tensor = new_tensor;
2044            }
2045
2046            // Record the output value from the last node (graph_len - 1) by creating
2047            // a snapshot input node with the current output value.
2048            let output_tensor = self.nodes[graph_len - 1].borrow().tensor.clone();
2049            let snapshot_idx = self.nodes.len();
2050            self.nodes.push(Rc::new(RefCell::new(GradNode {
2051                op: GradOp::Input,
2052                tensor: output_tensor,
2053                grad: None,
2054            })));
2055            result_indices.push(snapshot_idx);
2056        }
2057
2058        result_indices
2059    }
2060
2061    /// Backward pass with a custom gradient seed tensor (for Jacobian computation).
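    ///
    /// # Examples
    ///
    /// A sketch only; `y` is a 2-element output node built elsewhere. Seeding with a
    /// one-hot vector extracts one row of the Jacobian into the parameter gradients.
    ///
    /// ```rust,ignore
    /// let seed = Tensor::from_vec_unchecked(vec![1.0, 0.0], &[2]);
    /// g.zero_grad();
    /// g.backward_with_seed(y, &seed);
    /// ```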
2062    pub fn backward_with_seed(&mut self, loss_idx: usize, seed: &Tensor) {
2063        let n = self.nodes.len();
2064        let mut grads: Vec<Option<Tensor>> = vec![None; n];
2065        grads[loss_idx] = Some(seed.clone());
2066
2067        for i in (0..n).rev() {
2068            let grad = match grads[i].take() {
2069                Some(g) => g,
2070                None => continue,
2071            };
2072
2073            let node = self.nodes[i].borrow();
2074            if let Some(ref _param_grad) = node.grad {
2075                // Accumulate into parameter grad storage
2076                drop(node);
2077                let new_grad = {
2078                    let n = self.nodes[i].borrow();
2079                    if let Some(ref existing) = n.grad {
2080                        if existing.to_vec().iter().all(|&x| x == 0.0) {
2081                            grad.clone()
2082                        } else {
2083                            existing.add_unchecked(&grad)
2084                        }
2085                    } else {
2086                        grad.clone()
2087                    }
2088                };
2089                self.nodes[i].borrow_mut().grad = Some(new_grad);
2090            } else {
2091                drop(node);
2092            }
2093
2094            // Propagate gradients using the same rules as backward()
2095            let node = self.nodes[i].borrow();
2096            let node_tensor = node.tensor.clone();
2097            match &node.op {
2098                GradOp::Input | GradOp::Parameter => {}
2099                GradOp::Add(a, b) => {
2100                    accumulate_grad(&mut grads, *a, &grad);
2101                    accumulate_grad(&mut grads, *b, &grad);
2102                }
2103                GradOp::Sub(a, b) => {
2104                    accumulate_grad(&mut grads, *a, &grad);
2105                    accumulate_grad(&mut grads, *b, &grad.neg());
2106                }
2107                GradOp::Mul(a, b) => {
2108                    let a_val = self.nodes[*a].borrow().tensor.clone();
2109                    let b_val = self.nodes[*b].borrow().tensor.clone();
2110                    accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&b_val));
2111                    accumulate_grad(&mut grads, *b, &grad.mul_elem_unchecked(&a_val));
2112                }
2113                GradOp::Div(a, b) => {
2114                    let a_val = self.nodes[*a].borrow().tensor.clone();
2115                    let b_val = self.nodes[*b].borrow().tensor.clone();
2116                    let grad_a = grad.div_elem_unchecked(&b_val);
2117                    let neg_a_over_b2 = a_val.neg().div_elem_unchecked(
2118                        &b_val.mul_elem_unchecked(&b_val),
2119                    );
2120                    let grad_b = grad.mul_elem_unchecked(&neg_a_over_b2);
2121                    accumulate_grad(&mut grads, *a, &grad_a);
2122                    accumulate_grad(&mut grads, *b, &grad_b);
2123                }
2124                GradOp::Neg(a) => {
2125                    accumulate_grad(&mut grads, *a, &grad.neg());
2126                }
2127                GradOp::ScalarMul(a, s) => {
2128                    accumulate_grad(&mut grads, *a, &grad.scalar_mul(*s));
2129                }
2130                GradOp::Exp(a) => {
2131                    accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&node_tensor));
2132                }
2133                GradOp::Ln(a) => {
2134                    let a_val = self.nodes[*a].borrow().tensor.clone();
2135                    let inv = Tensor::from_vec_unchecked(
2136                        a_val.to_vec().iter().map(|&x| 1.0 / x).collect(),
2137                        a_val.shape(),
2138                    );
2139                    accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&inv));
2140                }
2141                GradOp::Sin(a) => {
2142                    let a_val = self.nodes[*a].borrow().tensor.clone();
2143                    let cos_a = Tensor::from_vec_unchecked(
2144                        a_val.to_vec().iter().map(|&x| x.cos()).collect(),
2145                        a_val.shape(),
2146                    );
2147                    accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&cos_a));
2148                }
2149                GradOp::Cos(a) => {
2150                    let a_val = self.nodes[*a].borrow().tensor.clone();
2151                    let neg_sin = Tensor::from_vec_unchecked(
2152                        a_val.to_vec().iter().map(|&x| -x.sin()).collect(),
2153                        a_val.shape(),
2154                    );
2155                    accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&neg_sin));
2156                }
2157                GradOp::Sqrt(a) => {
2158                    let inv2sqrt = Tensor::from_vec_unchecked(
2159                        node_tensor.to_vec().iter().map(|&x| 0.5 / x).collect(),
2160                        node_tensor.shape(),
2161                    );
2162                    accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&inv2sqrt));
2163                }
2164                GradOp::Pow(a, exp) => {
2165                    let a_val = self.nodes[*a].borrow().tensor.clone();
2166                    let local = Tensor::from_vec_unchecked(
2167                        a_val.to_vec().iter().map(|&x| exp * x.powf(exp - 1.0)).collect(),
2168                        a_val.shape(),
2169                    );
2170                    accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&local));
2171                }
2172                GradOp::Sigmoid(a) => {
2173                    let sig = &node_tensor;
2174                    let one_minus = Tensor::from_vec_unchecked(
2175                        sig.to_vec().iter().map(|&x| 1.0 - x).collect(),
2176                        sig.shape(),
2177                    );
2178                    let local = sig.mul_elem_unchecked(&one_minus);
2179                    accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&local));
2180                }
2181                GradOp::Relu(a) => {
2182                    let a_val = self.nodes[*a].borrow().tensor.clone();
2183                    let mask = Tensor::from_vec_unchecked(
2184                        a_val.to_vec().iter().map(|&x| if x > 0.0 { 1.0 } else { 0.0 }).collect(),
2185                        a_val.shape(),
2186                    );
2187                    accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&mask));
2188                }
2189                GradOp::TanhAct(a) => {
2190                    let one_minus_sq = Tensor::from_vec_unchecked(
2191                        node_tensor.to_vec().iter().map(|&x| 1.0 - x * x).collect(),
2192                        node_tensor.shape(),
2193                    );
2194                    accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&one_minus_sq));
2195                }
2196                GradOp::MatMul(a, b) => {
2197                    let a_val = self.nodes[*a].borrow().tensor.clone();
2198                    let b_val = self.nodes[*b].borrow().tensor.clone();
2199                    accumulate_grad(&mut grads, *a, &grad.matmul_unchecked(&b_val.transpose()));
2200                    accumulate_grad(&mut grads, *b, &a_val.transpose().matmul_unchecked(&grad));
2201                }
2202                GradOp::Sum(a) => {
2203                    let a_shape = self.nodes[*a].borrow().tensor.shape().to_vec();
2204                    let grad_val = grad.to_vec()[0];
2205                    let expanded = Tensor::from_vec_unchecked(
2206                        vec![grad_val; a_shape.iter().product()],
2207                        &a_shape,
2208                    );
2209                    accumulate_grad(&mut grads, *a, &expanded);
2210                }
2211                GradOp::Mean(a) => {
2212                    let a_shape = self.nodes[*a].borrow().tensor.shape().to_vec();
2213                    let n_elem = a_shape.iter().product::<usize>() as f64;
2214                    let grad_val = grad.to_vec()[0] / n_elem;
2215                    let expanded = Tensor::from_vec_unchecked(
2216                        vec![grad_val; a_shape.iter().product()],
2217                        &a_shape,
2218                    );
2219                    accumulate_grad(&mut grads, *a, &expanded);
2220                }
2221                GradOp::StructField { parent, field_index, total_fields } => {
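                    // The field occupies a contiguous chunk of the flattened parent tensor;
                    // scatter this node's gradient into that slice and leave zeros elsewhere.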
2222                    let parent_shape = self.nodes[*parent].borrow().tensor.shape().to_vec();
2223                    let parent_n: usize = parent_shape.iter().product();
2224                    let chunk = parent_n / total_fields;
2225                    let start = field_index * chunk;
2226                    let mut parent_grad = vec![0.0_f64; parent_n];
2227                    let g_vec = grad.to_vec();
2228                    for (j, &gv) in g_vec.iter().enumerate() {
2229                        parent_grad[start + j] = gv;
2230                    }
2231                    let pg = Tensor::from_vec_unchecked(parent_grad, &parent_shape);
2232                    accumulate_grad(&mut grads, *parent, &pg);
2233                }
2234                GradOp::MapLookup { map_node, key_index, total_keys } => {
2235                    let map_shape = self.nodes[*map_node].borrow().tensor.shape().to_vec();
2236                    let map_n: usize = map_shape.iter().product();
2237                    let chunk = map_n / total_keys;
2238                    let start = key_index * chunk;
2239                    let mut map_grad = vec![0.0_f64; map_n];
2240                    let g_vec = grad.to_vec();
2241                    for (j, &gv) in g_vec.iter().enumerate() {
2242                        map_grad[start + j] = gv;
2243                    }
2244                    let mg = Tensor::from_vec_unchecked(map_grad, &map_shape);
2245                    accumulate_grad(&mut grads, *map_node, &mg);
2246                }
2247                // Phase 8: Extended AD backward (backward_with_seed)
2248                GradOp::Abs(a) => {
2249                    let a_val = self.nodes[*a].borrow().tensor.clone();
2250                    let sign = Tensor::from_vec_unchecked(
2251                        a_val.to_vec().iter().map(|&x| {
2252                            if x > 0.0 { 1.0 } else if x < 0.0 { -1.0 } else { 0.0 }
2253                        }).collect(),
2254                        a_val.shape(),
2255                    );
2256                    accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&sign));
2257                }
2258                GradOp::Log2(a) => {
2259                    let a_val = self.nodes[*a].borrow().tensor.clone();
2260                    let ln2 = std::f64::consts::LN_2;
2261                    let local = Tensor::from_vec_unchecked(
2262                        a_val.to_vec().iter().map(|&x| 1.0 / (x * ln2)).collect(),
2263                        a_val.shape(),
2264                    );
2265                    accumulate_grad(&mut grads, *a, &grad.mul_elem_unchecked(&local));
2266                }
2267                GradOp::Softmax(a) => {
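                    // Softmax backward: dL/dz_i = s_i * (dL/ds_i - sum_j dL/ds_j * s_j)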
2268                    use cjc_repro::KahanAccumulatorF64;
2269                    let sm = &node_tensor;
2270                    let sm_data = sm.to_vec();
2271                    let grad_data = grad.to_vec();
2272                    let mut dot_acc = KahanAccumulatorF64::new();
2273                    for (&g, &s) in grad_data.iter().zip(sm_data.iter()) {
2274                        dot_acc.add(g * s);
2275                    }
2276                    let dot = dot_acc.finalize();
2277                    let grad_input: Vec<f64> = sm_data.iter().zip(grad_data.iter())
2278                        .map(|(&s, &g)| s * (g - dot))
2279                        .collect();
2280                    let grad_a = Tensor::from_vec_unchecked(grad_input, sm.shape());
2281                    accumulate_grad(&mut grads, *a, &grad_a);
2282                }
2283                GradOp::CrossEntropy { logits, targets } => {
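                    // Fused softmax-cross-entropy backward: dL/dlogits = upstream * (softmax(logits) - targets)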
2284                    use cjc_repro::KahanAccumulatorF64;
2285                    let logits_val = self.nodes[*logits].borrow().tensor.clone();
2286                    let targets_val = self.nodes[*targets].borrow().tensor.clone();
2287                    let logits_data = logits_val.to_vec();
2288                    let targets_data = targets_val.to_vec();
2289                    let max_val = logits_data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
2290                    let exp_shifted: Vec<f64> = logits_data.iter().map(|&x| (x - max_val).exp()).collect();
2291                    let mut sum_acc = KahanAccumulatorF64::new();
2292                    for &v in &exp_shifted {
2293                        sum_acc.add(v);
2294                    }
2295                    let sum_exp = sum_acc.finalize();
2296                    let softmax: Vec<f64> = exp_shifted.iter().map(|&e| e / sum_exp).collect();
2297                    let upstream = grad.to_vec()[0];
2298                    let grad_logits: Vec<f64> = softmax.iter().zip(targets_data.iter())
2299                        .map(|(&s, &t)| upstream * (s - t))
2300                        .collect();
2301                    let gl = Tensor::from_vec_unchecked(grad_logits, logits_val.shape());
2302                    accumulate_grad(&mut grads, *logits, &gl);
2303                }
2304                GradOp::LayerNorm(a) => {
2305                    use cjc_repro::KahanAccumulatorF64;
2306                    let x_hat = &node_tensor;
2307                    let x_hat_data = x_hat.to_vec();
2308                    let grad_data = grad.to_vec();
2309                    let n = x_hat_data.len() as f64;
2310                    let a_val = self.nodes[*a].borrow().tensor.clone();
2311                    let a_data = a_val.to_vec();
2312                    let mut mean_acc = KahanAccumulatorF64::new();
2313                    for &v in &a_data { mean_acc.add(v); }
2314                    let mean = mean_acc.finalize() / n;
2315                    let mut var_acc = KahanAccumulatorF64::new();
2316                    for &v in &a_data { let d = v - mean; var_acc.add(d * d); }
2317                    let var = var_acc.finalize() / n;
2318                    let std_val = (var + 1e-5).sqrt();
2319                    let mut mg_acc = KahanAccumulatorF64::new();
2320                    for &g in &grad_data { mg_acc.add(g); }
2321                    let mean_grad = mg_acc.finalize() / n;
2322                    let mut mgx_acc = KahanAccumulatorF64::new();
2323                    for (&g, &xh) in grad_data.iter().zip(x_hat_data.iter()) { mgx_acc.add(g * xh); }
2324                    let mean_grad_xhat = mgx_acc.finalize() / n;
2325                    let dx: Vec<f64> = grad_data.iter().zip(x_hat_data.iter())
2326                        .map(|(&g, &xh)| (g - mean_grad - xh * mean_grad_xhat) / std_val)
2327                        .collect();
2328                    accumulate_grad(&mut grads, *a, &Tensor::from_vec_unchecked(dx, a_val.shape()));
2329                }
2330                GradOp::BatchNorm(a) => {
2331                    use cjc_repro::KahanAccumulatorF64;
2332                    let x_hat = &node_tensor;
2333                    let x_hat_data = x_hat.to_vec();
2334                    let grad_data = grad.to_vec();
2335                    let n = x_hat_data.len() as f64;
2336                    let a_val = self.nodes[*a].borrow().tensor.clone();
2337                    let a_data = a_val.to_vec();
2338                    let mut mean_acc = KahanAccumulatorF64::new();
2339                    for &v in &a_data { mean_acc.add(v); }
2340                    let mean = mean_acc.finalize() / n;
2341                    let mut var_acc = KahanAccumulatorF64::new();
2342                    for &v in &a_data { let d = v - mean; var_acc.add(d * d); }
2343                    let var = var_acc.finalize() / n;
2344                    let std_val = (var + 1e-5).sqrt();
2345                    let mut mg_acc = KahanAccumulatorF64::new();
2346                    for &g in &grad_data { mg_acc.add(g); }
2347                    let mean_grad = mg_acc.finalize() / n;
2348                    let mut mgx_acc = KahanAccumulatorF64::new();
2349                    for (&g, &xh) in grad_data.iter().zip(x_hat_data.iter()) { mgx_acc.add(g * xh); }
2350                    let mean_grad_xhat = mgx_acc.finalize() / n;
2351                    let dx: Vec<f64> = grad_data.iter().zip(x_hat_data.iter())
2352                        .map(|(&g, &xh)| (g - mean_grad - xh * mean_grad_xhat) / std_val)
2353                        .collect();
2354                    accumulate_grad(&mut grads, *a, &Tensor::from_vec_unchecked(dx, a_val.shape()));
2355                }
2356                GradOp::Clamp { input, min, max } => {
2357                    let a_val = self.nodes[*input].borrow().tensor.clone();
2358                    let mask = Tensor::from_vec_unchecked(
2359                        a_val.to_vec().iter().map(|&x| {
2360                            if x >= *min && x <= *max { 1.0 } else { 0.0 }
2361                        }).collect(),
2362                        a_val.shape(),
2363                    );
2364                    accumulate_grad(&mut grads, *input, &grad.mul_elem_unchecked(&mask));
2365                }
2366                GradOp::Where { cond, on_true, on_false } => {
2367                    let cond_data = self.nodes[*cond].borrow().tensor.to_vec();
2368                    let grad_data = grad.to_vec();
2369                    let shape = grad.shape().to_vec();
2370                    let grad_true: Vec<f64> = cond_data.iter().zip(grad_data.iter())
2371                        .map(|(&c, &g)| if c != 0.0 { g } else { 0.0 }).collect();
2372                    let grad_false: Vec<f64> = cond_data.iter().zip(grad_data.iter())
2373                        .map(|(&c, &g)| if c != 0.0 { 0.0 } else { g }).collect();
2374                    accumulate_grad(&mut grads, *on_true, &Tensor::from_vec_unchecked(grad_true, &shape));
2375                    accumulate_grad(&mut grads, *on_false, &Tensor::from_vec_unchecked(grad_false, &shape));
2376                }
2377                GradOp::Reshape { input, ref original_shape } => {
2378                    let grad_a = grad.reshape(original_shape).expect("Reshape backward: shape mismatch");
2379                    accumulate_grad(&mut grads, *input, &grad_a);
2380                }
2381                GradOp::TransposeOp(a) => {
2382                    accumulate_grad(&mut grads, *a, &grad.transpose());
2383                }
2384                GradOp::CatOp { ref inputs, axis, ref sizes } => {
2385                    let grad_data = grad.to_vec();
2386                    let grad_shape = grad.shape().to_vec();
2387                    let ndim = grad_shape.len();
2388                    if ndim == 1 {
2389                        let mut offset = 0usize;
2390                        for (idx, &sz) in inputs.iter().zip(sizes.iter()) {
2391                            let piece = grad_data[offset..offset + sz].to_vec();
2392                            accumulate_grad(&mut grads, *idx, &Tensor::from_vec_unchecked(piece, &[sz]));
2393                            offset += sz;
2394                        }
2395                    } else if ndim == 2 && *axis == 0 {
2396                        let cols = grad_shape[1];
2397                        let mut row_offset = 0usize;
2398                        for (idx, &sz) in inputs.iter().zip(sizes.iter()) {
2399                            let start = row_offset * cols;
2400                            let end = start + sz * cols;
2401                            let piece = grad_data[start..end].to_vec();
2402                            accumulate_grad(&mut grads, *idx, &Tensor::from_vec_unchecked(piece, &[sz, cols]));
2403                            row_offset += sz;
2404                        }
2405                    } else if ndim == 2 && *axis == 1 {
2406                        let nrows = grad_shape[0];
2407                        let total_cols = grad_shape[1];
2408                        for (input_idx, (idx, &sz)) in inputs.iter().zip(sizes.iter()).enumerate() {
2409                            let mut piece = Vec::with_capacity(nrows * sz);
2410                            let col_offset: usize = sizes[..input_idx].iter().sum();
2411                            for row in 0..nrows {
2412                                let row_start = row * total_cols + col_offset;
2413                                piece.extend_from_slice(&grad_data[row_start..row_start + sz]);
2414                            }
2415                            accumulate_grad(&mut grads, *idx, &Tensor::from_vec_unchecked(piece, &[nrows, sz]));
2416                        }
2417                    } else {
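                        // Generic fallback: split the gradient as contiguous blocks.
                        // This is exact for the leading axis; other axes would need strided copies.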
2418                        let mut offset = 0usize;
2419                        for (idx, &sz) in inputs.iter().zip(sizes.iter()) {
2420                            let piece_len = sz * grad_data.len() / grad_shape[*axis];
2421                            let piece = grad_data[offset..offset + piece_len].to_vec();
2422                            let mut piece_shape = grad_shape.clone();
2423                            piece_shape[*axis] = sz;
2424                            accumulate_grad(&mut grads, *idx, &Tensor::from_vec_unchecked(piece, &piece_shape));
2425                            offset += piece_len;
2426                        }
2427                    }
2428                }
2429                GradOp::GatherOp { input, ref indices, axis } => {
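                    // Gather backward is a scatter-add: each gathered position adds its
                    // upstream gradient back into its source index, so duplicates accumulate.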
2430                    let input_shape = self.nodes[*input].borrow().tensor.shape().to_vec();
2431                    let input_len: usize = input_shape.iter().product();
2432                    let mut scatter = vec![0.0_f64; input_len];
2433                    let grad_data = grad.to_vec();
2434                    if self.nodes[*input].borrow().tensor.ndim() == 1 {
2435                        for (gi, &idx) in indices.iter().enumerate() {
2436                            scatter[idx] += grad_data[gi];
2437                        }
2438                    } else if *axis == 0 && self.nodes[*input].borrow().tensor.ndim() == 2 {
2439                        let cols = input_shape[1];
2440                        for (gi, &idx) in indices.iter().enumerate() {
2441                            for c in 0..cols {
2442                                scatter[idx * cols + c] += grad_data[gi * cols + c];
2443                            }
2444                        }
2445                    } else {
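                        // Fallback: treat indices as flat element offsets (same as the 1-D case).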
2446                        for (gi, &idx) in indices.iter().enumerate() {
2447                            scatter[idx] += grad_data[gi];
2448                        }
2449                    }
2450                    accumulate_grad(&mut grads, *input, &Tensor::from_vec_unchecked(scatter, &input_shape));
2451                }
2452            }
2453        }
2454    }
2455}
2456
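/// Accumulate `grad` into the slot for node `idx`, summing with any existing
/// contribution. Reverse mode must add gradients from every use of a node, so
/// the first write clones and later writes add element-wise.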
2457fn accumulate_grad(grads: &mut [Option<Tensor>], idx: usize, grad: &Tensor) {
2458    if let Some(existing) = &grads[idx] {
2459        grads[idx] = Some(existing.add_unchecked(grad));
2460    } else {
2461        grads[idx] = Some(grad.clone());
2462    }
2463}
2464
2465// backward_with_seed, jacobian, and hessian_diag are defined inside the GradGraph impl above.
2466
2467impl Default for GradGraph {
2468    fn default() -> Self {
2469        Self::new()
2470    }
2471}
2472
2473// ── Finite Difference Validation ────────────────────────────────
2474
2475/// Validate an analytic gradient against a central finite-difference estimate.
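///
/// Computes the central difference `(f(x + eps) - f(x - eps)) / (2.0 * eps)` and
/// returns `true` when it lies within `tol` of `expected_grad`.
///
/// # Arguments
///
/// * `f` - The scalar function to differentiate numerically.
/// * `x` - The point at which to compare gradients.
/// * `expected_grad` - The analytic (AD) gradient to validate.
/// * `eps` - Finite-difference step size.
/// * `tol` - Absolute tolerance for the comparison.
///
/// # Examples
///
/// ```rust,ignore
/// // Mirrors `test_finite_diff_validation` below: f(x) = x^2, f'(3) = 6
/// assert!(check_grad_finite_diff(|x| x * x, 3.0, 6.0, 1e-7, 1e-5));
/// ```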
2476pub fn check_grad_finite_diff<F>(
2477    f: F,
2478    x: f64,
2479    expected_grad: f64,
2480    eps: f64,
2481    tol: f64,
2482) -> bool
2483where
2484    F: Fn(f64) -> f64,
2485{
2486    let fd_grad = (f(x + eps) - f(x - eps)) / (2.0 * eps);
2487    (fd_grad - expected_grad).abs() < tol
2488}
2489
2490#[cfg(test)]
2491mod tests {
2492    use super::*;
2493
2494    // ── Forward Mode Tests ──────────────────────────────────
2495
2496    #[test]
2497    fn test_dual_add() {
2498        let a = Dual::variable(3.0);
2499        let b = Dual::constant(2.0);
2500        let c = a + b;
2501        assert_eq!(c.value, 5.0);
2502        assert_eq!(c.deriv, 1.0);
2503    }
2504
2505    #[test]
2506    fn test_dual_mul() {
2507        let a = Dual::variable(3.0);
2508        let b = Dual::constant(2.0);
2509        let c = a * b;
2510        assert_eq!(c.value, 6.0);
2511        assert_eq!(c.deriv, 2.0); // d/dx (x * 2) = 2
2512    }
2513
2514    #[test]
2515    fn test_dual_chain_rule() {
2516        // f(x) = x^2 + 2x + 1, f'(x) = 2x + 2, f'(3) = 8
2517        let x = Dual::variable(3.0);
2518        let result = x.clone() * x.clone() + Dual::constant(2.0) * x + Dual::one();
2519        assert_eq!(result.value, 16.0);
2520        assert_eq!(result.deriv, 8.0);
2521    }
2522
2523    #[test]
2524    fn test_dual_exp() {
2525        let x = Dual::variable(1.0);
2526        let result = x.exp();
2527        assert!((result.value - std::f64::consts::E).abs() < 1e-10);
2528        assert!((result.deriv - std::f64::consts::E).abs() < 1e-10);
2529    }
2530
2531    #[test]
2532    fn test_dual_sin_cos() {
2533        let x = Dual::variable(0.0);
2534        let sin_x = x.clone().sin();
2535        let cos_x = x.cos();
2536        assert!((sin_x.value - 0.0).abs() < 1e-10);
2537        assert!((sin_x.deriv - 1.0).abs() < 1e-10); // d/dx sin(x) at 0 = cos(0) = 1
2538        assert!((cos_x.value - 1.0).abs() < 1e-10);
2539        assert!((cos_x.deriv - 0.0).abs() < 1e-10); // d/dx cos(x) at 0 = -sin(0) = 0
2540    }
2541
2542    #[test]
2543    fn test_dual_div() {
2544        let a = Dual::variable(6.0);
2545        let b = Dual::constant(3.0);
2546        let c = a / b;
2547        assert_eq!(c.value, 2.0);
2548        assert!((c.deriv - 1.0 / 3.0).abs() < 1e-10);
2549    }
2550
2551    #[test]
2552    fn test_finite_diff_validation() {
2553        // f(x) = x^2, f'(3) = 6
2554        let f = |x: f64| x * x;
2555        assert!(check_grad_finite_diff(f, 3.0, 6.0, 1e-7, 1e-5));
2556    }
2557
2558    // ── Reverse Mode Tests ──────────────────────────────────
2559
2560    #[test]
2561    fn test_reverse_add() {
2562        let mut g = GradGraph::new();
2563        let a = g.parameter(Tensor::from_vec_unchecked(vec![3.0], &[1]));
2564        let b = g.parameter(Tensor::from_vec_unchecked(vec![2.0], &[1]));
2565        let c = g.add(a, b);
2566
2567        g.backward(c);
2568
2569        let ga = g.grad(a).unwrap();
2570        let gb = g.grad(b).unwrap();
2571        assert!((ga.to_vec()[0] - 1.0).abs() < 1e-10);
2572        assert!((gb.to_vec()[0] - 1.0).abs() < 1e-10);
2573    }
2574
2575    #[test]
2576    fn test_reverse_mul() {
2577        let mut g = GradGraph::new();
2578        let a = g.parameter(Tensor::from_vec_unchecked(vec![3.0], &[1]));
2579        let b = g.parameter(Tensor::from_vec_unchecked(vec![2.0], &[1]));
2580        let c = g.mul(a, b);
2581
2582        g.backward(c);
2583
2584        let ga = g.grad(a).unwrap();
2585        let gb = g.grad(b).unwrap();
2586        assert!((ga.to_vec()[0] - 2.0).abs() < 1e-10); // d/da (a*b) = b = 2
2587        assert!((gb.to_vec()[0] - 3.0).abs() < 1e-10); // d/db (a*b) = a = 3
2588    }
2589
2590    #[test]
2591    fn test_reverse_matmul_gradient() {
2592        let mut g = GradGraph::new();
2593
2594        // Simple 2x2 matmul
2595        let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]));
2596        let b = g.parameter(Tensor::from_vec_unchecked(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]));
2597        let c = g.matmul(a, b);
2598        let loss = g.sum(c);
2599
2600        g.backward(loss);
2601
2602        // Gradient of sum(A @ B) w.r.t. A = ones @ B^T
2603        let ga = g.grad(a).unwrap();
2604        let ga_data = ga.to_vec();
2605        // The upstream gradient of sum is ones(2,2).
2606        // B^T = [[5,7],[6,8]]
2607        // Each row of ones(2,2) @ B^T is the column sums of B^T: [5+6, 7+8] = [11, 15]
2608        // so dL/dA = [[11,15],[11,15]]
2609        assert!((ga_data[0] - 11.0).abs() < 1e-10);
2610        assert!((ga_data[1] - 15.0).abs() < 1e-10);
2611    }
2612
2613    #[test]
2614    fn test_reverse_mean_gradient() {
2615        let mut g = GradGraph::new();
2616        let a = g.parameter(Tensor::from_vec_unchecked(vec![2.0, 4.0, 6.0, 8.0], &[4]));
2617        let loss = g.mean(a);
2618
2619        g.backward(loss);
2620
2621        let ga = g.grad(a).unwrap();
2622        let ga_data = ga.to_vec();
2623        // d/da mean(a) = 1/N for each element
2624        for &v in &ga_data {
2625            assert!((v - 0.25).abs() < 1e-10);
2626        }
2627    }
2628
2629    // ── Phase B8: Reverse Mode Transcendental & Activation Tests ──
2630
2631    #[test]
2632    fn test_reverse_sin() {
2633        let mut g = GradGraph::new();
2634        let a = g.parameter(Tensor::from_vec_unchecked(vec![0.0], &[1]));
2635        let b = g.sin(a);
2636        g.backward(b);
2637        let ga = g.grad(a).unwrap();
2638        // d/dx sin(x) at x=0 = cos(0) = 1.0
2639        assert!((ga.to_vec()[0] - 1.0).abs() < 1e-10);
2640    }
2641
2642    #[test]
2643    fn test_reverse_cos() {
2644        let mut g = GradGraph::new();
2645        let a = g.parameter(Tensor::from_vec_unchecked(vec![0.0], &[1]));
2646        let b = g.cos(a);
2647        g.backward(b);
2648        let ga = g.grad(a).unwrap();
2649        // d/dx cos(x) at x=0 = -sin(0) = 0.0
2650        assert!(ga.to_vec()[0].abs() < 1e-10);
2651    }
2652
2653    #[test]
2654    fn test_reverse_sqrt() {
2655        let mut g = GradGraph::new();
2656        let a = g.parameter(Tensor::from_vec_unchecked(vec![4.0], &[1]));
2657        let b = g.sqrt(a);
2658        g.backward(b);
2659        let ga = g.grad(a).unwrap();
2660        // d/dx sqrt(x) at x=4 = 1/(2*2) = 0.25
2661        assert!((ga.to_vec()[0] - 0.25).abs() < 1e-10);
2662    }
2663
2664    #[test]
2665    fn test_reverse_pow() {
2666        let mut g = GradGraph::new();
2667        let a = g.parameter(Tensor::from_vec_unchecked(vec![2.0], &[1]));
2668        let b = g.pow(a, 3.0); // x^3
2669        g.backward(b);
2670        let ga = g.grad(a).unwrap();
2671        // d/dx x^3 at x=2 = 3*4 = 12.0
2672        assert!((ga.to_vec()[0] - 12.0).abs() < 1e-10);
2673    }
2674
2675    #[test]
2676    fn test_reverse_sigmoid() {
2677        let mut g = GradGraph::new();
2678        let a = g.parameter(Tensor::from_vec_unchecked(vec![0.0], &[1]));
2679        let b = g.sigmoid(a);
2680        g.backward(b);
2681        let ga = g.grad(a).unwrap();
2682        // sigmoid(0) = 0.5, sigmoid'(0) = 0.5 * 0.5 = 0.25
2683        assert!((ga.to_vec()[0] - 0.25).abs() < 1e-10);
2684    }
2685
2686    #[test]
2687    fn test_reverse_relu_positive() {
2688        let mut g = GradGraph::new();
2689        let a = g.parameter(Tensor::from_vec_unchecked(vec![3.0], &[1]));
2690        let b = g.relu(a);
2691        g.backward(b);
2692        let ga = g.grad(a).unwrap();
2693        // relu'(3) = 1.0
2694        assert!((ga.to_vec()[0] - 1.0).abs() < 1e-10);
2695    }
2696
2697    #[test]
2698    fn test_reverse_relu_negative() {
2699        let mut g = GradGraph::new();
2700        let a = g.parameter(Tensor::from_vec_unchecked(vec![-2.0], &[1]));
2701        let b = g.relu(a);
2702        g.backward(b);
2703        let ga = g.grad(a).unwrap();
2704        // relu'(-2) = 0.0
2705        assert!(ga.to_vec()[0].abs() < 1e-10);
2706    }
2707
2708    #[test]
2709    fn test_reverse_tanh() {
2710        let mut g = GradGraph::new();
2711        let a = g.parameter(Tensor::from_vec_unchecked(vec![0.0], &[1]));
2712        let b = g.tanh_act(a);
2713        g.backward(b);
2714        let ga = g.grad(a).unwrap();
2715        // tanh'(0) = 1 - tanh(0)^2 = 1 - 0 = 1.0
2716        assert!((ga.to_vec()[0] - 1.0).abs() < 1e-10);
2717    }
2718
2719    #[test]
2720    fn test_reverse_sin_cos_chain() {
2721        // f(x) = sin(cos(x)), f'(x) = cos(cos(x)) * (-sin(x))
2722        // at x=1: f'(1) = cos(cos(1)) * (-sin(1))
2723        let mut g = GradGraph::new();
2724        let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0], &[1]));
2725        let c = g.cos(a);
2726        let s = g.sin(c);
2727        g.backward(s);
2728        let ga = g.grad(a).unwrap();
2729        let expected = 1.0_f64.cos().cos() * (-1.0_f64.sin());
2730        assert!((ga.to_vec()[0] - expected).abs() < 1e-10, "got {}, expected {expected}", ga.to_vec()[0]);
2731    }
2732
2733    #[test]
2734    fn test_reverse_sigmoid_sum() {
2735        // f(x) = sum(sigmoid(x)) for vector x
2736        let mut g = GradGraph::new();
2737        let a = g.parameter(Tensor::from_vec_unchecked(vec![0.0, 1.0, -1.0], &[3]));
2738        let s = g.sigmoid(a);
2739        let loss = g.sum(s);
2740        g.backward(loss);
2741        let ga = g.grad(a).unwrap();
2742        let ga_data = ga.to_vec();
2743        // sigmoid'(0) = 0.25, sigmoid'(1) = sig(1)*(1-sig(1)), sigmoid'(-1) = sig(-1)*(1-sig(-1))
2744        let sig1 = 1.0 / (1.0 + (-1.0_f64).exp());
2745        let sig_neg1 = 1.0 / (1.0 + 1.0_f64.exp());
2746        assert!((ga_data[0] - 0.25).abs() < 1e-10);
2747        assert!((ga_data[1] - sig1 * (1.0 - sig1)).abs() < 1e-10);
2748        assert!((ga_data[2] - sig_neg1 * (1.0 - sig_neg1)).abs() < 1e-10);
2749    }
2750
2751    #[test]
2752    fn test_b8_determinism() {
2753        let mut g1 = GradGraph::new();
2754        let a1 = g1.parameter(Tensor::from_vec_unchecked(vec![1.5], &[1]));
2755        let s1 = g1.sin(a1);
2756        g1.backward(s1);
2757        let ga1 = g1.grad(a1).unwrap().to_vec()[0];
2758
2759        let mut g2 = GradGraph::new();
2760        let a2 = g2.parameter(Tensor::from_vec_unchecked(vec![1.5], &[1]));
2761        let s2 = g2.sin(a2);
2762        g2.backward(s2);
2763        let ga2 = g2.grad(a2).unwrap().to_vec()[0];
2764
2765        assert_eq!(ga1.to_bits(), ga2.to_bits());
2766    }
2767
2768    #[test]
2769    fn test_reverse_mse_loss() {
2770        // MSE = mean((pred - target)^2)
2771        let mut g = GradGraph::new();
2772
2773        let w = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 1.0], &[2, 1]));
2774        let x = g.input(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]));
2775        let target = g.input(Tensor::from_vec_unchecked(vec![3.0, 7.0], &[2, 1]));
2776
2777        let pred = g.matmul(x, w);
2778        let diff = g.sub(pred, target);
2779        let sq = g.mul(diff, diff);
2780        let loss = g.mean(sq);
2781
2782        let loss_val = g.value(loss);
2783        g.backward(loss);
2784
2785        let gw = g.grad(w).unwrap();
2786
2787        // Verify loss is finite and gradient exists
2788        assert!(loss_val.is_finite());
2789        assert_eq!(gw.to_vec().len(), 2);
2790        for &v in &gw.to_vec() {
2791            assert!(v.is_finite());
2792        }
2793    }
2794
2795    // ── Phase C1: Reverse Mode Tests for New Forward Methods ──
2796
2797    #[test]
2798    fn test_reverse_div() {
2799        // f(x) = x / 2, f'(x) = 0.5
2800        let mut g = GradGraph::new();
2801        let a = g.parameter(Tensor::from_vec_unchecked(vec![6.0], &[1]));
2802        let b = g.input(Tensor::from_vec_unchecked(vec![2.0], &[1]));
2803        let c = g.div(a, b);
2804        g.backward(c);
2805        let ga = g.grad(a).unwrap();
2806        assert!((ga.to_vec()[0] - 0.5).abs() < 1e-10);
2807    }
2808
2809    #[test]
2810    fn test_reverse_neg() {
2811        // f(x) = -x, f'(x) = -1
2812        let mut g = GradGraph::new();
2813        let a = g.parameter(Tensor::from_vec_unchecked(vec![3.0], &[1]));
2814        let c = g.neg(a);
2815        g.backward(c);
2816        let ga = g.grad(a).unwrap();
2817        assert!((ga.to_vec()[0] - (-1.0)).abs() < 1e-10);
2818    }
2819
2820    #[test]
2821    fn test_reverse_scalar_mul() {
2822        // f(x) = 3x, f'(x) = 3
2823        let mut g = GradGraph::new();
2824        let a = g.parameter(Tensor::from_vec_unchecked(vec![2.0], &[1]));
2825        let c = g.scalar_mul(a, 3.0);
2826        g.backward(c);
2827        let ga = g.grad(a).unwrap();
2828        assert!((ga.to_vec()[0] - 3.0).abs() < 1e-10);
2829    }
2830
2831    #[test]
2832    fn test_reverse_exp() {
2833        // f(x) = exp(x), f'(x) = exp(x)
2834        let mut g = GradGraph::new();
2835        let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0], &[1]));
2836        let c = g.exp(a);
2837        g.backward(c);
2838        let ga = g.grad(a).unwrap();
2839        assert!((ga.to_vec()[0] - std::f64::consts::E).abs() < 1e-10);
2840    }
2841
2842    #[test]
2843    fn test_reverse_ln() {
2844        // f(x) = ln(x), f'(x) = 1/x, at x=2 → 0.5
2845        let mut g = GradGraph::new();
2846        let a = g.parameter(Tensor::from_vec_unchecked(vec![2.0], &[1]));
2847        let c = g.ln(a);
2848        g.backward(c);
2849        let ga = g.grad(a).unwrap();
2850        assert!((ga.to_vec()[0] - 0.5).abs() < 1e-10);
2851    }
2852
2853    // ── Phase 8: Extended AD Tests ──
2854
2855    #[test]
2856    fn test_reverse_abs_positive() {
2857        let mut g = GradGraph::new();
2858        let a = g.parameter(Tensor::from_vec_unchecked(vec![3.0], &[1]));
2859        let b = g.abs(a);
2860        g.backward(b);
2861        let ga = g.grad(a).unwrap();
2862        // abs'(3) = sign(3) = 1.0
2863        assert!((ga.to_vec()[0] - 1.0).abs() < 1e-10);
2864    }
2865
2866    #[test]
2867    fn test_reverse_abs_negative() {
2868        let mut g = GradGraph::new();
2869        let a = g.parameter(Tensor::from_vec_unchecked(vec![-2.5], &[1]));
2870        let b = g.abs(a);
2871        g.backward(b);
2872        let ga = g.grad(a).unwrap();
2873        // abs'(-2.5) = sign(-2.5) = -1.0
2874        assert!((ga.to_vec()[0] - (-1.0)).abs() < 1e-10);
2875    }
2876
2877    #[test]
2878    fn test_reverse_abs_zero() {
2879        let mut g = GradGraph::new();
2880        let a = g.parameter(Tensor::from_vec_unchecked(vec![0.0], &[1]));
2881        let b = g.abs(a);
2882        g.backward(b);
2883        let ga = g.grad(a).unwrap();
2884        // abs'(0) = 0.0 (subgradient convention)
2885        assert!(ga.to_vec()[0].abs() < 1e-10);
2886    }
2887
2888    #[test]
2889    fn test_reverse_abs_vector() {
2890        let mut g = GradGraph::new();
2891        let a = g.parameter(Tensor::from_vec_unchecked(vec![-1.0, 2.0, 0.0, -3.0], &[4]));
2892        let b = g.abs(a);
2893        let loss = g.sum(b);
2894        g.backward(loss);
2895        let ga = g.grad(a).unwrap();
2896        let expected = vec![-1.0, 1.0, 0.0, -1.0];
2897        for (i, (&got, &exp)) in ga.to_vec().iter().zip(expected.iter()).enumerate() {
2898            assert!((got - exp).abs() < 1e-10, "abs grad[{i}]: got {got}, expected {exp}");
2899        }
2900    }
2901
2902    #[test]
2903    fn test_reverse_log2() {
2904        let mut g = GradGraph::new();
2905        let a = g.parameter(Tensor::from_vec_unchecked(vec![4.0], &[1]));
2906        let b = g.log2(a);
2907        g.backward(b);
2908        let ga = g.grad(a).unwrap();
2909        // d/dx log2(x) at x=4 = 1/(4 * ln(2))
2910        let expected = 1.0 / (4.0 * std::f64::consts::LN_2);
2911        assert!((ga.to_vec()[0] - expected).abs() < 1e-10);
2912    }
2913
2914    #[test]
2915    fn test_reverse_log2_vector() {
2916        let mut g = GradGraph::new();
2917        let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 8.0], &[3]));
2918        let b = g.log2(a);
2919        let loss = g.sum(b);
2920        g.backward(loss);
2921        let ga = g.grad(a).unwrap();
2922        let ln2 = std::f64::consts::LN_2;
2923        let expected = vec![1.0 / (1.0 * ln2), 1.0 / (2.0 * ln2), 1.0 / (8.0 * ln2)];
2924        for (i, (&got, &exp)) in ga.to_vec().iter().zip(expected.iter()).enumerate() {
2925            assert!((got - exp).abs() < 1e-10, "log2 grad[{i}]: got {got}, expected {exp}");
2926        }
2927    }
2928
2929    #[test]
2930    fn test_softmax_forward() {
2931        let mut g = GradGraph::new();
2932        let a = g.input(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0], &[3]));
2933        let b = g.softmax(a);
2934        let sm = g.tensor(b);
2935        let sm_data = sm.to_vec();
2936        // softmax values should sum to 1
2937        let sum: f64 = sm_data.iter().sum();
2938        assert!((sum - 1.0).abs() < 1e-10);
2939        // Verify ordering: softmax(3) > softmax(2) > softmax(1)
2940        assert!(sm_data[2] > sm_data[1]);
2941        assert!(sm_data[1] > sm_data[0]);
2942    }
2943
2944    #[test]
2945    fn test_reverse_softmax() {
2946        // Finite difference check for softmax gradient
2947        let mut g = GradGraph::new();
2948        let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0], &[3]));
2949        let b = g.softmax(a);
2950        let loss = g.sum(b);
2951        g.backward(loss);
2952        let ga = g.grad(a).unwrap();
2953        // sum(softmax(x)) = 1 always, so d/dx sum(softmax(x)) = 0
2954        for &v in &ga.to_vec() {
2955            assert!(v.abs() < 1e-10, "softmax sum grad should be 0, got {v}");
2956        }
2957    }
2958
2959    #[test]
2960    fn test_reverse_softmax_sum_two_elements() {
2961        // sum(softmax(x)) is identically 1, so its gradient w.r.t. x should vanish
2962        let mut g = GradGraph::new();
2963        let a = g.parameter(Tensor::from_vec_unchecked(vec![2.0, 1.0], &[2]));
2964        let b = g.softmax(a);
2965        // Summing all softmax outputs yields a constant (1.0),
2966        // so every entry of the gradient should be 0
2967        let loss = g.sum(b);
2968        g.backward(loss);
2969        let ga = g.grad(a).unwrap();
2970        for &v in &ga.to_vec() {
2971            assert!(v.abs() < 1e-10);
2972        }
2973    }
2974
2975    #[test]
2976    fn test_cross_entropy_forward() {
2977        let mut g = GradGraph::new();
2978        let logits = g.input(Tensor::from_vec_unchecked(vec![2.0, 1.0, 0.1], &[3]));
2979        let targets = g.input(Tensor::from_vec_unchecked(vec![1.0, 0.0, 0.0], &[3])); // one-hot
2980        let ce = g.cross_entropy(logits, targets);
2981        let loss_val = g.value(ce);
2982        assert!(loss_val > 0.0, "CE loss should be positive");
2983        assert!(loss_val.is_finite(), "CE loss should be finite");
2984    }
2985
2986    #[test]
2987    fn test_reverse_cross_entropy() {
2988        let mut g = GradGraph::new();
2989        let logits = g.parameter(Tensor::from_vec_unchecked(vec![2.0, 1.0, 0.1], &[3]));
2990        let targets = g.input(Tensor::from_vec_unchecked(vec![1.0, 0.0, 0.0], &[3]));
2991        let ce = g.cross_entropy(logits, targets);
2992        g.backward(ce);
2993        let ga = g.grad(logits).unwrap();
2994        let ga_data = ga.to_vec();
2995        // grad = softmax(logits) - targets
2996        // softmax should give something like [0.659, 0.242, 0.099]
2997        // grad[0] should be negative (softmax < 1.0 for correct class)
2998        assert!(ga_data[0] < 0.0, "CE grad for correct class should be negative");
2999        assert!(ga_data[1] > 0.0, "CE grad for incorrect class should be positive");
3000        assert!(ga_data[2] > 0.0, "CE grad for incorrect class should be positive");
3001        // Sum of gradients should be ~0 (softmax sums to 1, targets sum to 1)
3002        let sum: f64 = ga_data.iter().sum();
3003        assert!(sum.abs() < 1e-10, "CE grad should sum to 0, got {sum}");
3004    }
3005
3006    #[test]
3007    fn test_layer_norm_forward() {
3008        let mut g = GradGraph::new();
3009        let a = g.input(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0], &[4]));
3010        let b = g.layer_norm(a);
3011        let normed = g.tensor(b).to_vec();
3012        // After layer norm, mean should be ~0 and std ~1
3013        let mean: f64 = normed.iter().sum::<f64>() / normed.len() as f64;
3014        assert!(mean.abs() < 1e-5, "LayerNorm mean should be ~0, got {mean}");
3015        let var: f64 = normed.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / normed.len() as f64;
3016        assert!((var - 1.0).abs() < 0.01, "LayerNorm variance should be ~1, got {var}");
3017    }
3018
3019    #[test]
3020    fn test_reverse_layer_norm() {
3021        let mut g = GradGraph::new();
3022        let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0], &[4]));
3023        let b = g.layer_norm(a);
3024        let loss = g.sum(b);
3025        g.backward(loss);
3026        let ga = g.grad(a).unwrap();
3027        // Gradient should be finite and non-zero
3028        for &v in &ga.to_vec() {
3029            assert!(v.is_finite(), "LayerNorm grad should be finite");
3030        }
3031        // Finite difference check
3032        let eps = 1e-5;
3033        let input_data = vec![1.0, 2.0, 3.0, 4.0];
3034        for i in 0..4 {
3035            let mut plus = input_data.clone();
3036            plus[i] += eps;
3037            let mut g_plus = GradGraph::new();
3038            let a_plus = g_plus.input(Tensor::from_vec_unchecked(plus, &[4]));
3039            let b_plus = g_plus.layer_norm(a_plus);
3040            let loss_plus = g_plus.sum(b_plus);
3041            let val_plus = g_plus.value(loss_plus);
3042
3043            let mut minus = input_data.clone();
3044            minus[i] -= eps;
3045            let mut g_minus = GradGraph::new();
3046            let a_minus = g_minus.input(Tensor::from_vec_unchecked(minus, &[4]));
3047            let b_minus = g_minus.layer_norm(a_minus);
3048            let loss_minus = g_minus.sum(b_minus);
3049            let val_minus = g_minus.value(loss_minus);
3050
3051            let fd_grad = (val_plus - val_minus) / (2.0 * eps);
3052            let ad_grad = ga.to_vec()[i];
3053            assert!(
3054                (fd_grad - ad_grad).abs() < 1e-4,
3055                "LayerNorm FD check failed at [{i}]: fd={fd_grad}, ad={ad_grad}"
3056            );
3057        }
3058    }
3059
3060    #[test]
3061    fn test_batch_norm_forward() {
3062        let mut g = GradGraph::new();
3063        let a = g.input(Tensor::from_vec_unchecked(vec![2.0, 4.0, 6.0, 8.0], &[4]));
3064        let b = g.batch_norm(a);
3065        let normed = g.tensor(b).to_vec();
3066        let mean: f64 = normed.iter().sum::<f64>() / normed.len() as f64;
3067        assert!(mean.abs() < 1e-5, "BatchNorm mean should be ~0, got {mean}");
3068    }
3069
3070    #[test]
3071    fn test_reverse_batch_norm() {
3072        let mut g = GradGraph::new();
3073        let a = g.parameter(Tensor::from_vec_unchecked(vec![2.0, 4.0, 6.0, 8.0], &[4]));
3074        let b = g.batch_norm(a);
3075        let loss = g.sum(b);
3076        g.backward(loss);
3077        let ga = g.grad(a).unwrap();
3078        for &v in &ga.to_vec() {
3079            assert!(v.is_finite(), "BatchNorm grad should be finite");
3080        }
3081    }
3082
3083    #[test]
3084    fn test_reverse_clamp_in_range() {
3085        let mut g = GradGraph::new();
3086        let a = g.parameter(Tensor::from_vec_unchecked(vec![1.5], &[1]));
3087        let b = g.clamp(a, 0.0, 3.0);
3088        g.backward(b);
3089        let ga = g.grad(a).unwrap();
3090        // 1.5 is in [0, 3], so grad passes through: 1.0
3091        assert!((ga.to_vec()[0] - 1.0).abs() < 1e-10);
3092    }
3093
3094    #[test]
3095    fn test_reverse_clamp_out_of_range() {
3096        let mut g = GradGraph::new();
3097        let a = g.parameter(Tensor::from_vec_unchecked(vec![5.0], &[1]));
3098        let b = g.clamp(a, 0.0, 3.0);
3099        g.backward(b);
3100        let ga = g.grad(a).unwrap();
3101        // 5.0 is outside [0, 3], so grad is 0
3102        assert!(ga.to_vec()[0].abs() < 1e-10);
3103    }
3104
3105    #[test]
3106    fn test_reverse_clamp_vector() {
3107        let mut g = GradGraph::new();
3108        let a = g.parameter(Tensor::from_vec_unchecked(vec![-1.0, 0.5, 2.0, 4.0], &[4]));
3109        let b = g.clamp(a, 0.0, 3.0);
3110        let loss = g.sum(b);
3111        g.backward(loss);
3112        let ga = g.grad(a).unwrap();
3113        let expected = vec![0.0, 1.0, 1.0, 0.0]; // -1 out, 0.5 in, 2 in, 4 out
3114        for (i, (&got, &exp)) in ga.to_vec().iter().zip(expected.iter()).enumerate() {
3115            assert!((got - exp).abs() < 1e-10, "clamp grad[{i}]: got {got}, expected {exp}");
3116        }
3117    }
3118
3119    #[test]
3120    fn test_reverse_where_cond() {
3121        let mut g = GradGraph::new();
3122        let cond = g.input(Tensor::from_vec_unchecked(vec![1.0, 0.0, 1.0], &[3]));
3123        let a = g.parameter(Tensor::from_vec_unchecked(vec![10.0, 20.0, 30.0], &[3]));
3124        let b = g.parameter(Tensor::from_vec_unchecked(vec![100.0, 200.0, 300.0], &[3]));
3125        let w = g.where_cond(cond, a, b);
3126        // Forward: should select [10, 200, 30]
3127        let result = g.tensor(w).to_vec();
3128        assert!((result[0] - 10.0).abs() < 1e-10);
3129        assert!((result[1] - 200.0).abs() < 1e-10);
3130        assert!((result[2] - 30.0).abs() < 1e-10);
3131        let loss = g.sum(w);
3132        g.backward(loss);
3133        let ga = g.grad(a).unwrap().to_vec();
3134        let gb = g.grad(b).unwrap().to_vec();
3135        // grad flows to a where cond=1, to b where cond=0
3136        assert!((ga[0] - 1.0).abs() < 1e-10);
3137        assert!(ga[1].abs() < 1e-10);
3138        assert!((ga[2] - 1.0).abs() < 1e-10);
3139        assert!(gb[0].abs() < 1e-10);
3140        assert!((gb[1] - 1.0).abs() < 1e-10);
3141        assert!(gb[2].abs() < 1e-10);
3142    }
3143
3144    #[test]
3145    fn test_reverse_reshape() {
3146        let mut g = GradGraph::new();
3147        let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]));
3148        let b = g.reshape(a, &[3, 2]);
3149        let loss = g.sum(b);
3150        g.backward(loss);
3151        let ga = g.grad(a).unwrap();
3152        // Reshape backward: grad should have original shape [2, 3], all ones
3153        assert_eq!(ga.shape(), &[2, 3]);
3154        for &v in &ga.to_vec() {
3155            assert!((v - 1.0).abs() < 1e-10);
3156        }
3157    }
3158
3159    #[test]
3160    fn test_reverse_transpose_op() {
3161        let mut g = GradGraph::new();
3162        let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]));
3163        let b = g.transpose_op(a);
3164        // Transposed shape should be [3, 2]
3165        assert_eq!(g.tensor(b).shape(), &[3, 2]);
3166        let loss = g.sum(b);
3167        g.backward(loss);
3168        let ga = g.grad(a).unwrap();
3169        // Transpose backward: grad should have original shape [2, 3], all ones
3170        assert_eq!(ga.shape(), &[2, 3]);
3171        for &v in &ga.to_vec() {
3172            assert!((v - 1.0).abs() < 1e-10);
3173        }
3174    }
3175
3176    #[test]
3177    fn test_reverse_cat_1d() {
3178        let mut g = GradGraph::new();
3179        let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0], &[2]));
3180        let b = g.parameter(Tensor::from_vec_unchecked(vec![3.0, 4.0, 5.0], &[3]));
3181        let c = g.cat(&[a, b], 0);
3182        // Forward: [1, 2, 3, 4, 5]
3183        let result = g.tensor(c).to_vec();
3184        assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
3185        let loss = g.sum(c);
3186        g.backward(loss);
3187        let ga = g.grad(a).unwrap().to_vec();
3188        let gb = g.grad(b).unwrap().to_vec();
3189        // All gradients should be 1 (from sum)
3190        assert_eq!(ga, vec![1.0, 1.0]);
3191        assert_eq!(gb, vec![1.0, 1.0, 1.0]);
3192    }
3193
3194    #[test]
3195    fn test_reverse_cat_2d_axis0() {
3196        let mut g = GradGraph::new();
3197        let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0], &[1, 2]));
3198        let b = g.parameter(Tensor::from_vec_unchecked(vec![3.0, 4.0, 5.0, 6.0], &[2, 2]));
3199        let c = g.cat(&[a, b], 0);
3200        assert_eq!(g.tensor(c).shape(), &[3, 2]);
3201        let loss = g.sum(c);
3202        g.backward(loss);
3203        let ga = g.grad(a).unwrap();
3204        let gb = g.grad(b).unwrap();
3205        assert_eq!(ga.shape(), &[1, 2]);
3206        assert_eq!(gb.shape(), &[2, 2]);
3207        for &v in &ga.to_vec() {
3208            assert!((v - 1.0).abs() < 1e-10);
3209        }
3210        for &v in &gb.to_vec() {
3211            assert!((v - 1.0).abs() < 1e-10);
3212        }
3213    }
3214
3215    #[test]
3216    fn test_reverse_gather_1d() {
3217        let mut g = GradGraph::new();
3218        let a = g.parameter(Tensor::from_vec_unchecked(vec![10.0, 20.0, 30.0, 40.0], &[4]));
3219        let b = g.gather(a, &[1, 3], 0);
3220        // Forward: [20, 40]
3221        let result = g.tensor(b).to_vec();
3222        assert!((result[0] - 20.0).abs() < 1e-10);
3223        assert!((result[1] - 40.0).abs() < 1e-10);
3224        let loss = g.sum(b);
3225        g.backward(loss);
3226        let ga = g.grad(a).unwrap().to_vec();
3227        // Scatter-add: indices [1, 3] get grad 1.0 each, others get 0
3228        assert!((ga[0]).abs() < 1e-10);
3229        assert!((ga[1] - 1.0).abs() < 1e-10);
3230        assert!((ga[2]).abs() < 1e-10);
3231        assert!((ga[3] - 1.0).abs() < 1e-10);
3232    }
3233
3234    #[test]
3235    fn test_reverse_gather_duplicate_indices() {
3236        let mut g = GradGraph::new();
3237        let a = g.parameter(Tensor::from_vec_unchecked(vec![10.0, 20.0, 30.0], &[3]));
3238        let b = g.gather(a, &[1, 1, 2], 0);
3239        let loss = g.sum(b);
3240        g.backward(loss);
3241        let ga = g.grad(a).unwrap().to_vec();
3242        // Index 1 appears twice, so its grad should be 2.0
3243        assert!((ga[0]).abs() < 1e-10);
3244        assert!((ga[1] - 2.0).abs() < 1e-10);
3245        assert!((ga[2] - 1.0).abs() < 1e-10);
3246    }
3247
3248    #[test]
3249    fn test_phase8_determinism() {
3250        // Run twice and verify bit-identical gradients
3251        for _ in 0..2 {
3252            let run = || {
3253                let mut g = GradGraph::new();
3254                let a = g.parameter(Tensor::from_vec_unchecked(vec![1.0, -2.0, 3.0, -0.5], &[4]));
3255                let b = g.abs(a);
3256                let c = g.clamp(b, 0.0, 2.5);
3257                let d = g.layer_norm(c);
3258                let loss = g.sum(d);
3259                g.backward(loss);
3260                g.grad(a).unwrap().to_vec()
3261            };
3262            let r1 = run();
3263            let r2 = run();
3264            for (i, (v1, v2)) in r1.iter().zip(r2.iter()).enumerate() {
3265                assert_eq!(v1.to_bits(), v2.to_bits(), "Determinism failed at [{i}]");
3266            }
3267        }
3268    }
3269
3270    #[test]
3271    fn test_phase8_softmax_cross_entropy_chain() {
3272        // Verify that softmax + CE combined gradient works end-to-end
3273        let mut g = GradGraph::new();
3274        let logits = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 2.0, 3.0], &[3]));
3275        let targets = g.input(Tensor::from_vec_unchecked(vec![0.0, 0.0, 1.0], &[3]));
3276        let ce = g.cross_entropy(logits, targets);
3277        g.backward(ce);
3278        let ga = g.grad(logits).unwrap().to_vec();
3279        // Verify all gradients are finite
3280        for &v in &ga {
3281            assert!(v.is_finite());
3282        }
3283        // CE grad for correct class (index 2) should be negative (softmax - 1)
3284        assert!(ga[2] < 0.0, "CE grad for correct class should be negative");
3285    }
3286
3287    // ── Sprint 1 SciML Hardening: hessian, double_backward, vmap_forward ──
3288
3289    #[test]
3290    fn test_double_backward_cubic() {
3291        // f(x) = x^3, f'(x) = 3x^2, f''(x) = 6x
3292        let mut g = GradGraph::new();
3293        let x = g.parameter(Tensor::from_vec_unchecked(vec![2.0], &[1]));
3294        let x2 = g.mul(x, x);
3295        let x3 = g.mul(x2, x);
3296        let loss = g.sum(x3);
3297        let hess = g.double_backward(loss, x);
3298        // f''(2) = 6*2 = 12
3299        assert!((hess.to_vec()[0] - 12.0).abs() < 1e-4);
3300    }
3301
3302    #[test]
3303    fn test_full_hessian_quadratic() {
3304        // f(x, y) = x^2 + y^2 using p = [x, y]
3305        // H = [[2, 0], [0, 2]]
3306        let mut g = GradGraph::new();
3307        let p = g.parameter(Tensor::from_vec_unchecked(vec![1.0, 1.0], &[2]));
3308        let p2 = g.mul(p, p); // [x^2, y^2]
3309        let s = g.sum(p2); // x^2 + y^2
3310        let hess = g.hessian(s, p);
3311        let h = hess.to_vec();
3312        assert!((h[0] - 2.0).abs() < 1e-3); // d2f/dx2 = 2
3313        assert!((h[1] - 0.0).abs() < 1e-3); // d2f/dxdy = 0
3314        assert!((h[2] - 0.0).abs() < 1e-3); // d2f/dydx = 0
3315        assert!((h[3] - 2.0).abs() < 1e-3); // d2f/dy2 = 2
3316    }
3317
3318    #[test]
3319    fn test_vmap_forward() {
3320        let mut g = GradGraph::new();
3321        let x = g.parameter(Tensor::from_vec_unchecked(vec![1.0], &[1]));
3322        let x2 = g.mul(x, x);
3323        let loss = g.sum(x2);
3324
3325        // vmap over batch [1.0, 2.0, 3.0]
3326        let batch = vec![
3327            Tensor::from_vec_unchecked(vec![1.0], &[1]),
3328            Tensor::from_vec_unchecked(vec![2.0], &[1]),
3329            Tensor::from_vec_unchecked(vec![3.0], &[1]),
3330        ];
3331        let results = g.vmap_forward(x, &batch);
3332        // Results should be [1.0, 4.0, 9.0]
3333        assert!((g.value(results[0]) - 1.0).abs() < 1e-10);
3334        assert!((g.value(results[1]) - 4.0).abs() < 1e-10);
3335        assert!((g.value(results[2]) - 9.0).abs() < 1e-10);
3336    }
3337
3338    #[test]
3339    fn test_hessian_determinism() {
3340        let mut g = GradGraph::new();
3341        let p = g.parameter(Tensor::from_vec_unchecked(vec![3.0, 4.0], &[2]));
3342        let p2 = g.mul(p, p);
3343        let s = g.sum(p2);
3344        let h1 = g.hessian(s, p);
3345        // Reset and redo
3346        g.nodes[p].borrow_mut().tensor = Tensor::from_vec_unchecked(vec![3.0, 4.0], &[2]);
3347        let h2 = g.hessian(s, p);
3348        assert_eq!(h1.to_vec(), h2.to_vec(), "Hessian must be deterministic");
3349    }
3350}