Skip to main content

numra_autodiff/
reverse.rs

1//! Reverse-mode automatic differentiation via tape-based computation graph.
2//!
3//! This module provides [`Var`], a tracked variable type that records operations
4//! on a shared [`Tape`]. After building the computation graph in the forward pass,
5//! call [`Tape::gradient`] to compute derivatives in a single backward pass.
6//!
7//! # Comparison with forward-mode
8//!
9//! - **Forward-mode** ([`Dual`](crate::Dual)): One directional derivative per pass.
10//!   Cost: O(n) passes for n inputs. Best for few inputs, many outputs.
11//! - **Reverse-mode** ([`Var`]): All gradients in one backward pass.
12//!   Cost: O(m) passes for m outputs. Best for many inputs, few outputs (optimization).
13//!
14//! # Example
15//!
16//! ```rust
//! use numra_autodiff::reverse::grad;
18//!
19//! // Gradient of Rosenbrock: f(x,y) = (1-x)^2 + 100*(y-x^2)^2
20//! let g = grad(
21//!     |x| {
22//!         let a = x[0].cst(1.0) - x[0].clone();
23//!         let b = x[1].clone() - x[0].clone() * x[0].clone();
24//!         a.clone() * a + x[0].cst(100.0) * b.clone() * b
25//!     },
26//!     &[1.0, 1.0],
27//! );
28//! // At (1,1), gradient should be (0, 0)
29//! assert!(g[0].abs() < 1e-10);
30//! assert!(g[1].abs() < 1e-10);
31//! ```
32//!
33//! Author: Moussa Leblouba
34//! Date: 9 February 2026
35//! Modified: 2 May 2026
36
37use crate::tape::{Tape, TapeRef};
38use std::ops::{Add, Div, Mul, Neg, Sub};
39use std::rc::Rc;
40
41/// A reverse-mode AD variable tracked on a computation tape.
42///
43/// Arithmetic operations on `Var` automatically record themselves on the
44/// shared tape. After computation, use [`Tape::gradient`] to differentiate.
45#[derive(Clone, Debug)]
46pub struct Var {
47    /// Index of this variable's node in the tape.
48    pub(crate) index: usize,
49    /// Primal (forward) value.
50    pub value: f64,
51    /// Reference to the shared tape.
52    pub(crate) tape: TapeRef,
53}
54
55impl Var {
56    /// Create a constant (not differentiated) on the same tape as `self`.
57    pub fn cst(&self, value: f64) -> Var {
58        let (index, value) = {
59            let mut t = self.tape.borrow_mut();
60            let idx = t.nodes.len();
61            t.nodes.push(crate::tape::Node {
62                value,
63                parent1: None,
64                parent2: None,
65            });
66            (idx, value)
67        };
68        Var {
69            index,
70            value,
71            tape: Rc::clone(&self.tape),
72        }
73    }
74
75    /// sin(self)
76    pub fn sin(&self) -> Var {
77        let val = self.value.sin();
78        let deriv = self.value.cos(); // d sin(x)/dx = cos(x)
79        let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
80        Var {
81            index,
82            value,
83            tape: Rc::clone(&self.tape),
84        }
85    }
86
87    /// cos(self)
88    pub fn cos(&self) -> Var {
89        let val = self.value.cos();
90        let deriv = -self.value.sin(); // d cos(x)/dx = -sin(x)
91        let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
92        Var {
93            index,
94            value,
95            tape: Rc::clone(&self.tape),
96        }
97    }
98
99    /// tan(self)
100    pub fn tan(&self) -> Var {
101        let val = self.value.tan();
102        let c = self.value.cos();
103        let deriv = 1.0 / (c * c); // d tan(x)/dx = sec^2(x)
104        let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
105        Var {
106            index,
107            value,
108            tape: Rc::clone(&self.tape),
109        }
110    }
111
112    /// exp(self)
113    pub fn exp(&self) -> Var {
114        let val = self.value.exp();
115        let deriv = val; // d exp(x)/dx = exp(x)
116        let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
117        Var {
118            index,
119            value,
120            tape: Rc::clone(&self.tape),
121        }
122    }
123
124    /// ln(self) (natural logarithm)
125    pub fn ln(&self) -> Var {
126        let val = self.value.ln();
127        let deriv = 1.0 / self.value; // d ln(x)/dx = 1/x
128        let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
129        Var {
130            index,
131            value,
132            tape: Rc::clone(&self.tape),
133        }
134    }
135
136    /// sqrt(self)
137    pub fn sqrt(&self) -> Var {
138        let val = self.value.sqrt();
139        let deriv = 0.5 / val; // d sqrt(x)/dx = 1/(2*sqrt(x))
140        let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
141        Var {
142            index,
143            value,
144            tape: Rc::clone(&self.tape),
145        }
146    }
147
148    /// abs(self)
149    pub fn abs(&self) -> Var {
150        let val = self.value.abs();
151        let deriv = if self.value >= 0.0 { 1.0 } else { -1.0 };
152        let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
153        Var {
154            index,
155            value,
156            tape: Rc::clone(&self.tape),
157        }
158    }
159
160    /// tanh(self)
161    pub fn tanh(&self) -> Var {
162        let val = self.value.tanh();
163        let deriv = 1.0 - val * val; // d tanh(x)/dx = 1 - tanh^2(x)
164        let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
165        Var {
166            index,
167            value,
168            tape: Rc::clone(&self.tape),
169        }
170    }
171
172    /// sinh(self)
173    pub fn sinh(&self) -> Var {
174        let val = self.value.sinh();
175        let deriv = self.value.cosh();
176        let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
177        Var {
178            index,
179            value,
180            tape: Rc::clone(&self.tape),
181        }
182    }
183
184    /// cosh(self)
185    pub fn cosh(&self) -> Var {
186        let val = self.value.cosh();
187        let deriv = self.value.sinh();
188        let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
189        Var {
190            index,
191            value,
192            tape: Rc::clone(&self.tape),
193        }
194    }
195
196    /// self^n (power with another Var)
197    pub fn pow(&self, n: &Var) -> Var {
198        let val = self.value.powf(n.value);
199        // d(x^y)/dx = y * x^(y-1)
200        let d_self = n.value * self.value.powf(n.value - 1.0);
201        // d(x^y)/dy = x^y * ln(x)
202        let d_n = val * self.value.ln();
203        let (index, value) = Tape::push_binary(&self.tape, val, self.index, d_self, n.index, d_n);
204        Var {
205            index,
206            value,
207            tape: Rc::clone(&self.tape),
208        }
209    }
210
211    /// self^n (power with f64 constant)
212    pub fn powf(&self, n: f64) -> Var {
213        let val = self.value.powf(n);
214        let deriv = n * self.value.powf(n - 1.0);
215        let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
216        Var {
217            index,
218            value,
219            tape: Rc::clone(&self.tape),
220        }
221    }
222
223    /// self^n (integer power)
224    pub fn powi(&self, n: i32) -> Var {
225        self.powf(n as f64)
226    }
227
228    /// asin(self)
229    pub fn asin(&self) -> Var {
230        let val = self.value.asin();
231        let deriv = 1.0 / (1.0 - self.value * self.value).sqrt();
232        let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
233        Var {
234            index,
235            value,
236            tape: Rc::clone(&self.tape),
237        }
238    }
239
240    /// acos(self)
241    pub fn acos(&self) -> Var {
242        let val = self.value.acos();
243        let deriv = -1.0 / (1.0 - self.value * self.value).sqrt();
244        let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
245        Var {
246            index,
247            value,
248            tape: Rc::clone(&self.tape),
249        }
250    }
251
252    /// atan(self)
253    pub fn atan(&self) -> Var {
254        let val = self.value.atan();
255        let deriv = 1.0 / (1.0 + self.value * self.value);
256        let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
257        Var {
258            index,
259            value,
260            tape: Rc::clone(&self.tape),
261        }
262    }
263}
264
265// ==================== Operator overloading ====================
266
267// Var + Var
268impl Add for Var {
269    type Output = Var;
270    fn add(self, rhs: Var) -> Var {
271        let val = self.value + rhs.value;
272        let (index, value) = Tape::push_binary(&self.tape, val, self.index, 1.0, rhs.index, 1.0);
273        Var {
274            index,
275            value,
276            tape: self.tape,
277        }
278    }
279}
280
281// &Var + &Var
282impl Add for &Var {
283    type Output = Var;
284    fn add(self, rhs: &Var) -> Var {
285        let val = self.value + rhs.value;
286        let (index, value) = Tape::push_binary(&self.tape, val, self.index, 1.0, rhs.index, 1.0);
287        Var {
288            index,
289            value,
290            tape: Rc::clone(&self.tape),
291        }
292    }
293}
294
295// Var + f64
296impl Add<f64> for Var {
297    type Output = Var;
298    fn add(self, rhs: f64) -> Var {
299        let val = self.value + rhs;
300        let (index, value) = Tape::push_unary(&self.tape, val, self.index, 1.0);
301        Var {
302            index,
303            value,
304            tape: self.tape,
305        }
306    }
307}
308
309// f64 + Var
310impl Add<Var> for f64 {
311    type Output = Var;
312    fn add(self, rhs: Var) -> Var {
313        let val = self + rhs.value;
314        let (index, value) = Tape::push_unary(&rhs.tape, val, rhs.index, 1.0);
315        Var {
316            index,
317            value,
318            tape: rhs.tape,
319        }
320    }
321}
322
323// Var - Var
324impl Sub for Var {
325    type Output = Var;
326    fn sub(self, rhs: Var) -> Var {
327        let val = self.value - rhs.value;
328        let (index, value) = Tape::push_binary(&self.tape, val, self.index, 1.0, rhs.index, -1.0);
329        Var {
330            index,
331            value,
332            tape: self.tape,
333        }
334    }
335}
336
337// &Var - &Var
338impl Sub for &Var {
339    type Output = Var;
340    fn sub(self, rhs: &Var) -> Var {
341        let val = self.value - rhs.value;
342        let (index, value) = Tape::push_binary(&self.tape, val, self.index, 1.0, rhs.index, -1.0);
343        Var {
344            index,
345            value,
346            tape: Rc::clone(&self.tape),
347        }
348    }
349}
350
351// Var - f64
352impl Sub<f64> for Var {
353    type Output = Var;
354    fn sub(self, rhs: f64) -> Var {
355        let val = self.value - rhs;
356        let (index, value) = Tape::push_unary(&self.tape, val, self.index, 1.0);
357        Var {
358            index,
359            value,
360            tape: self.tape,
361        }
362    }
363}
364
365// f64 - Var
366impl Sub<Var> for f64 {
367    type Output = Var;
368    fn sub(self, rhs: Var) -> Var {
369        let val = self - rhs.value;
370        let (index, value) = Tape::push_unary(&rhs.tape, val, rhs.index, -1.0);
371        Var {
372            index,
373            value,
374            tape: rhs.tape,
375        }
376    }
377}
378
379// Var * Var
380impl Mul for Var {
381    type Output = Var;
382    fn mul(self, rhs: Var) -> Var {
383        let val = self.value * rhs.value;
384        let (index, value) = Tape::push_binary(
385            &self.tape, val, self.index, rhs.value, // d(x*y)/dx = y
386            rhs.index, self.value, // d(x*y)/dy = x
387        );
388        Var {
389            index,
390            value,
391            tape: self.tape,
392        }
393    }
394}
395
396// &Var * &Var
397impl Mul for &Var {
398    type Output = Var;
399    fn mul(self, rhs: &Var) -> Var {
400        let val = self.value * rhs.value;
401        let (index, value) = Tape::push_binary(
402            &self.tape, val, self.index, rhs.value, rhs.index, self.value,
403        );
404        Var {
405            index,
406            value,
407            tape: Rc::clone(&self.tape),
408        }
409    }
410}
411
412// Var * f64
413impl Mul<f64> for Var {
414    type Output = Var;
415    fn mul(self, rhs: f64) -> Var {
416        let val = self.value * rhs;
417        let (index, value) = Tape::push_unary(&self.tape, val, self.index, rhs);
418        Var {
419            index,
420            value,
421            tape: self.tape,
422        }
423    }
424}
425
426// f64 * Var
427impl Mul<Var> for f64 {
428    type Output = Var;
429    fn mul(self, rhs: Var) -> Var {
430        let val = self * rhs.value;
431        let (index, value) = Tape::push_unary(&rhs.tape, val, rhs.index, self);
432        Var {
433            index,
434            value,
435            tape: rhs.tape,
436        }
437    }
438}
439
440// Var * &Var
441impl Mul<&Var> for Var {
442    type Output = Var;
443    fn mul(self, rhs: &Var) -> Var {
444        let val = self.value * rhs.value;
445        let (index, value) = Tape::push_binary(
446            &self.tape, val, self.index, rhs.value, rhs.index, self.value,
447        );
448        Var {
449            index,
450            value,
451            tape: self.tape,
452        }
453    }
454}
455
456// &Var * Var
457impl Mul<Var> for &Var {
458    type Output = Var;
459    fn mul(self, rhs: Var) -> Var {
460        let val = self.value * rhs.value;
461        let (index, value) = Tape::push_binary(
462            &self.tape, val, self.index, rhs.value, rhs.index, self.value,
463        );
464        Var {
465            index,
466            value,
467            tape: rhs.tape,
468        }
469    }
470}
471
472// Var / Var
473impl Div for Var {
474    type Output = Var;
475    fn div(self, rhs: Var) -> Var {
476        let val = self.value / rhs.value;
477        let (index, value) = Tape::push_binary(
478            &self.tape,
479            val,
480            self.index,
481            1.0 / rhs.value, // d(x/y)/dx = 1/y
482            rhs.index,
483            -self.value / (rhs.value * rhs.value), // d(x/y)/dy = -x/y^2
484        );
485        Var {
486            index,
487            value,
488            tape: self.tape,
489        }
490    }
491}
492
493// &Var / &Var
494impl Div for &Var {
495    type Output = Var;
496    fn div(self, rhs: &Var) -> Var {
497        let val = self.value / rhs.value;
498        let (index, value) = Tape::push_binary(
499            &self.tape,
500            val,
501            self.index,
502            1.0 / rhs.value,
503            rhs.index,
504            -self.value / (rhs.value * rhs.value),
505        );
506        Var {
507            index,
508            value,
509            tape: Rc::clone(&self.tape),
510        }
511    }
512}
513
514// Var / f64
515impl Div<f64> for Var {
516    type Output = Var;
517    fn div(self, rhs: f64) -> Var {
518        let val = self.value / rhs;
519        let (index, value) = Tape::push_unary(&self.tape, val, self.index, 1.0 / rhs);
520        Var {
521            index,
522            value,
523            tape: self.tape,
524        }
525    }
526}
527
528// -Var
529impl Neg for Var {
530    type Output = Var;
531    fn neg(self) -> Var {
532        let val = -self.value;
533        let (index, value) = Tape::push_unary(&self.tape, val, self.index, -1.0);
534        Var {
535            index,
536            value,
537            tape: self.tape,
538        }
539    }
540}
541
542// -&Var
543impl Neg for &Var {
544    type Output = Var;
545    fn neg(self) -> Var {
546        let val = -self.value;
547        let (index, value) = Tape::push_unary(&self.tape, val, self.index, -1.0);
548        Var {
549            index,
550            value,
551            tape: Rc::clone(&self.tape),
552        }
553    }
554}
555
556// ==================== Convenience functions ====================
557
558/// Compute the gradient of a scalar function f: R^n -> R using reverse-mode AD.
559///
560/// This is more efficient than forward-mode [`gradient`](crate::gradient()) when n is large,
561/// as it requires only a single backward pass regardless of n.
562///
563/// # Example
564///
565/// ```rust
566/// use numra_autodiff::reverse::grad;
567///
568/// let g = grad(|x| x[0].clone() * x[0].clone() + x[1].clone() * x[1].clone(), &[3.0, 4.0]);
569/// assert!((g[0] - 6.0).abs() < 1e-12);
570/// assert!((g[1] - 8.0).abs() < 1e-12);
571/// ```
572pub fn grad<F>(f: F, x: &[f64]) -> Vec<f64>
573where
574    F: Fn(&[Var]) -> Var,
575{
576    let tape = Tape::new();
577    let vars: Vec<Var> = x.iter().map(|&xi| Tape::var(&tape, xi)).collect();
578    let output = f(&vars);
579    Tape::gradient(&tape, &output)
580}
581
582/// Compute the Jacobian of a vector function f: R^n -> R^m using reverse-mode AD.
583///
584/// Returns an m x n matrix (Vec of Vec), where row i is the gradient of output i.
585///
586/// # Example
587///
588/// ```rust
589/// use numra_autodiff::reverse::jacobian_reverse;
590///
591/// // f(x,y) = (x+y, x*y)
592/// let jac = jacobian_reverse(
593///     |x| vec![&x[0] + &x[1], &x[0] * &x[1]],
594///     &[2.0, 3.0],
595/// );
596/// assert!((jac[0][0] - 1.0).abs() < 1e-12); // df1/dx
597/// assert!((jac[1][0] - 3.0).abs() < 1e-12); // df2/dx = y
598/// ```
599pub fn jacobian_reverse<F>(f: F, x: &[f64]) -> Vec<Vec<f64>>
600where
601    F: Fn(&[Var]) -> Vec<Var>,
602{
603    let tape = Tape::new();
604    let vars: Vec<Var> = x.iter().map(|&xi| Tape::var(&tape, xi)).collect();
605    let outputs = f(&vars);
606    Tape::jacobian(&tape, &outputs)
607}
608
609/// Compute the Hessian of a scalar function f: R^n -> R.
610///
611/// Uses finite-difference of reverse-mode gradients. For each input variable,
612/// perturbs slightly and recomputes the gradient to get the Hessian row.
613///
614/// Returns an n x n matrix as `Vec<Vec<f64>>`.
615///
616/// # Example
617///
618/// ```rust
619/// use numra_autodiff::reverse::hessian;
620///
621/// // f(x,y) = x^2 + 2*x*y + 3*y^2
622/// let h = hessian(|x| {
623///     &x[0] * &x[0] + x[0].cst(2.0) * &x[0] * &x[1] + x[0].cst(3.0) * &x[1] * &x[1]
624/// }, &[1.0, 1.0]);
625/// assert!((h[0][0] - 2.0).abs() < 1e-6);  // d2f/dx2 = 2
626/// assert!((h[0][1] - 2.0).abs() < 1e-6);  // d2f/dxdy = 2
627/// assert!((h[1][0] - 2.0).abs() < 1e-6);  // d2f/dydx = 2
628/// assert!((h[1][1] - 6.0).abs() < 1e-6);  // d2f/dy2 = 6
629/// ```
630pub fn hessian<F>(f: F, x: &[f64]) -> Vec<Vec<f64>>
631where
632    F: Fn(&[Var]) -> Var,
633{
634    let n = x.len();
635    let eps = 1e-7;
636    let mut h = vec![vec![0.0; n]; n];
637
638    let g0 = grad(&f, x);
639
640    for j in 0..n {
641        let mut x_pert = x.to_vec();
642        x_pert[j] += eps;
643        let g_pert = grad(&f, &x_pert);
644
645        for i in 0..n {
646            h[i][j] = (g_pert[i] - g0[i]) / eps;
647        }
648    }
649
650    h
651}
652
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `actual` lies within `tol` of `expected`, reporting both
    /// values when the check fails.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {}, got {} (tolerance {})",
            expected,
            actual,
            tol
        );
    }

    /// Evaluate `op` on a single input recorded on a fresh tape and return
    /// the derivative of the result with respect to that input.
    fn unary_grad(x0: f64, op: impl Fn(&Var) -> Var) -> f64 {
        let tape = Tape::new();
        let x = Tape::var(&tape, x0);
        let out = op(&x);
        Tape::gradient(&tape, &out)[0]
    }

    /// Rosenbrock function f(x,y) = (1-x)^2 + 100*(y-x^2)^2.
    fn rosenbrock(x: &[Var]) -> Var {
        let a = x[0].cst(1.0) - x[0].clone(); // 1 - x
        let b = x[1].clone() - x[0].clone() * x[0].clone(); // y - x^2
        a.clone() * a + x[0].cst(100.0) * b.clone() * b
    }

    #[test]
    fn test_basic_arithmetic() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 2.0);
        let y = Tape::var(&tape, 3.0);

        // z = x + y = 5
        let z = x.clone() + y.clone();
        assert_close(z.value, 5.0, 1e-14);
        let g = Tape::gradient(&tape, &z);
        assert_close(g[0], 1.0, 1e-14);
        assert_close(g[1], 1.0, 1e-14);
    }

    #[test]
    fn test_multiplication() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 2.0);
        let y = Tape::var(&tape, 3.0);

        // z = x * y = 6
        let z = x * y;
        let g = Tape::gradient(&tape, &z);
        assert_close(g[0], 3.0, 1e-14); // dz/dx = y
        assert_close(g[1], 2.0, 1e-14); // dz/dy = x
    }

    #[test]
    fn test_chain_rule() {
        // z = (x * x) * x = x^3, dz/dx = 3x^2 = 12
        let g = unary_grad(2.0, |x| x.clone() * x.clone() * x.clone());
        assert_close(g, 12.0, 1e-12);
    }

    #[test]
    fn test_subtraction() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 5.0);
        let y = Tape::var(&tape, 3.0);

        let z = x - y;
        assert_close(z.value, 2.0, 1e-14);
        let g = Tape::gradient(&tape, &z);
        assert_close(g[0], 1.0, 1e-14);
        assert_close(g[1], -1.0, 1e-14);
    }

    #[test]
    fn test_division() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 6.0);
        let y = Tape::var(&tape, 3.0);

        let z = x / y; // z = 2
        let g = Tape::gradient(&tape, &z);
        assert_close(g[0], 1.0 / 3.0, 1e-14); // dz/dx = 1/y
        assert_close(g[1], -6.0 / 9.0, 1e-14); // dz/dy = -x/y^2
    }

    #[test]
    fn test_negation() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 3.0);
        let z = -x;
        assert_close(z.value, -3.0, 1e-14);
        let g = Tape::gradient(&tape, &z);
        assert_close(g[0], -1.0, 1e-14);
    }

    #[test]
    fn test_sin_cos() {
        assert_close(unary_grad(1.0, Var::sin), 1.0_f64.cos(), 1e-14);
        assert_close(unary_grad(1.0, Var::cos), -(1.0_f64.sin()), 1e-14);
    }

    #[test]
    fn test_tan() {
        // d tan(x)/dx = 1 / cos^2(x)
        let c = 0.5_f64.cos();
        assert_close(unary_grad(0.5, Var::tan), 1.0 / (c * c), 1e-14);
    }

    #[test]
    fn test_exp_ln() {
        assert_close(unary_grad(2.0, Var::exp), 2.0_f64.exp(), 1e-12);
        assert_close(unary_grad(3.0, Var::ln), 1.0 / 3.0, 1e-14);
    }

    #[test]
    fn test_sqrt() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 4.0);
        let s = x.sqrt();
        assert_close(s.value, 2.0, 1e-14);
        let g = Tape::gradient(&tape, &s);
        assert_close(g[0], 0.25, 1e-14); // 1/(2*sqrt(4)) = 0.25
    }

    #[test]
    fn test_abs() {
        assert_close(unary_grad(-2.0, Var::abs), -1.0, 1e-14);
        assert_close(unary_grad(3.0, Var::abs), 1.0, 1e-14);
    }

    #[test]
    fn test_tanh() {
        let expected = 1.0 - 1.0_f64.tanh().powi(2);
        assert_close(unary_grad(1.0, Var::tanh), expected, 1e-14);
    }

    #[test]
    fn test_hyperbolic() {
        // d sinh(x)/dx = cosh(x), d cosh(x)/dx = sinh(x)
        assert_close(unary_grad(1.0, Var::sinh), 1.0_f64.cosh(), 1e-14);
        assert_close(unary_grad(1.0, Var::cosh), 1.0_f64.sinh(), 1e-14);
    }

    #[test]
    fn test_powf() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 3.0);
        let p = x.powf(2.0); // x^2
        assert_close(p.value, 9.0, 1e-14);
        let g = Tape::gradient(&tape, &p);
        assert_close(g[0], 6.0, 1e-12); // d(x^2)/dx = 2x = 6
    }

    #[test]
    fn test_powi() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 2.0);
        let p = x.powi(3); // x^3 = 8
        assert_close(p.value, 8.0, 1e-12);
        let g = Tape::gradient(&tape, &p);
        assert_close(g[0], 12.0, 1e-12); // d(x^3)/dx = 3x^2 = 12
    }

    #[test]
    fn test_pow_var() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 2.0);
        let y = Tape::var(&tape, 3.0);
        let p = x.pow(&y); // x^y = 8
        assert_close(p.value, 8.0, 1e-12);
        let g = Tape::gradient(&tape, &p);
        // d(x^y)/dx = y * x^(y-1) = 3 * 4 = 12
        assert_close(g[0], 12.0, 1e-10);
        // d(x^y)/dy = x^y * ln(x) = 8 * ln(2)
        assert_close(g[1], 8.0 * 2.0_f64.ln(), 1e-10);
    }

    #[test]
    fn test_scalar_ops() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 3.0);

        assert_close((x.clone() + 2.0).value, 5.0, 1e-14); // x + 2.0
        assert_close((2.0 + x.clone()).value, 5.0, 1e-14); // 2.0 + x
        assert_close((x.clone() * 3.0).value, 9.0, 1e-14); // x * 3.0
        assert_close((3.0 * x.clone()).value, 9.0, 1e-14); // 3.0 * x
        assert_close((x.clone() - 1.0).value, 2.0, 1e-14); // x - 1.0
        assert_close((10.0 - x.clone()).value, 7.0, 1e-14); // 10.0 - x
        assert_close((x / 2.0).value, 1.5, 1e-14); // x / 2.0
    }

    #[test]
    fn test_scalar_op_gradients() {
        assert_close(unary_grad(3.0, |x| x.clone() + 2.0), 1.0, 1e-14);
        assert_close(unary_grad(3.0, |x| 2.0 + x.clone()), 1.0, 1e-14);
        assert_close(unary_grad(3.0, |x| x.clone() * 3.0), 3.0, 1e-14);
        assert_close(unary_grad(3.0, |x| 3.0 * x.clone()), 3.0, 1e-14);
        assert_close(unary_grad(3.0, |x| x.clone() - 1.0), 1.0, 1e-14);
        assert_close(unary_grad(3.0, |x| 10.0 - x.clone()), -1.0, 1e-14);
        assert_close(unary_grad(3.0, |x| x.clone() / 2.0), 0.5, 1e-14);
    }

    #[test]
    fn test_reference_ops() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 2.0);
        let y = Tape::var(&tape, 3.0);

        assert_close((&x + &y).value, 5.0, 1e-14);
        assert_close((&x * &y).value, 6.0, 1e-14);
        assert_close((&x - &y).value, -1.0, 1e-14);
        assert_close((&x / &y).value, 2.0 / 3.0, 1e-14);
    }

    #[test]
    fn test_mixed_ref_mul() {
        // Var * &Var
        let tape = Tape::new();
        let x = Tape::var(&tape, 2.0);
        let y = Tape::var(&tape, 3.0);
        let z = x.clone() * &y;
        assert_close(z.value, 6.0, 1e-14);
        let g = Tape::gradient(&tape, &z);
        assert_close(g[0], 3.0, 1e-14);
        assert_close(g[1], 2.0, 1e-14);

        // &Var * Var
        let tape2 = Tape::new();
        let x2 = Tape::var(&tape2, 2.0);
        let y2 = Tape::var(&tape2, 3.0);
        let z2 = &x2 * y2.clone();
        assert_close(z2.value, 6.0, 1e-14);
        let g2 = Tape::gradient(&tape2, &z2);
        assert_close(g2[0], 3.0, 1e-14);
        assert_close(g2[1], 2.0, 1e-14);
    }

    #[test]
    fn test_grad_rosenbrock() {
        // At the minimum (1,1), gradient = (0, 0)
        let g = grad(rosenbrock, &[1.0, 1.0]);
        assert_close(g[0], 0.0, 1e-10);
        assert_close(g[1], 0.0, 1e-10);
    }

    #[test]
    fn test_grad_rosenbrock_nonzero() {
        // At (0, 0): df/dx = -2(1-x) - 400x(y-x^2) = -2, df/dy = 200(y-x^2) = 0
        let g = grad(rosenbrock, &[0.0, 0.0]);
        assert_close(g[0], -2.0, 1e-10);
        assert_close(g[1], 0.0, 1e-10);
    }

    #[test]
    fn test_jacobian_reverse_fn() {
        // f(x,y) = (x+y, x*y)
        let jac = jacobian_reverse(|x| vec![&x[0] + &x[1], &x[0] * &x[1]], &[2.0, 3.0]);
        assert_eq!(jac.len(), 2);
        assert_close(jac[0][0], 1.0, 1e-14);
        assert_close(jac[0][1], 1.0, 1e-14);
        assert_close(jac[1][0], 3.0, 1e-14);
        assert_close(jac[1][1], 2.0, 1e-14);
    }

    #[test]
    fn test_jacobian_rotation() {
        // Rotation by angle theta: (x*cos(theta) - y*sin(theta), x*sin(theta) + y*cos(theta))
        // Jacobian = [[cos, -sin], [sin, cos]]
        let theta: f64 = 0.5;
        let jac = jacobian_reverse(
            |v| {
                let x = &v[0];
                let y = &v[1];
                let ct = x.cst(theta.cos());
                let st = x.cst(theta.sin());
                vec![&(x * &ct) - &(y * &st), &(x * &st) + &(y * &ct)]
            },
            &[1.0, 0.0],
        );
        assert_close(jac[0][0], theta.cos(), 1e-12);
        assert_close(jac[0][1], -theta.sin(), 1e-12);
        assert_close(jac[1][0], theta.sin(), 1e-12);
        assert_close(jac[1][1], theta.cos(), 1e-12);
    }

    #[test]
    fn test_hessian_quadratic() {
        // f(x,y) = x^2 + 2*x*y + 3*y^2
        // Hessian = [[2, 2], [2, 6]]
        let h = hessian(
            |x| &x[0] * &x[0] + x[0].cst(2.0) * &x[0] * &x[1] + x[0].cst(3.0) * &x[1] * &x[1],
            &[1.0, 1.0],
        );
        assert_close(h[0][0], 2.0, 1e-5);
        assert_close(h[0][1], 2.0, 1e-5);
        assert_close(h[1][0], 2.0, 1e-5);
        assert_close(h[1][1], 6.0, 1e-5);
    }

    #[test]
    fn test_hessian_rosenbrock() {
        // Rosenbrock at (1,1): Hessian = [[802, -400], [-400, 200]]
        let h = hessian(rosenbrock, &[1.0, 1.0]);
        assert_close(h[0][0], 802.0, 1e-3);
        assert_close(h[0][1], -400.0, 1e-3);
        assert_close(h[1][0], -400.0, 1e-3);
        assert_close(h[1][1], 200.0, 1e-3);
    }

    #[test]
    fn test_inverse_trig() {
        // asin: 1 / sqrt(1 - x^2)
        assert_close(unary_grad(0.5, Var::asin), 1.0 / (1.0 - 0.25_f64).sqrt(), 1e-12);

        // acos: -1 / sqrt(1 - x^2)
        assert_close(unary_grad(0.5, Var::acos), -1.0 / (1.0 - 0.25_f64).sqrt(), 1e-12);

        // atan: 1/(1+1^2) = 0.5
        assert_close(unary_grad(1.0, Var::atan), 0.5, 1e-14);
    }

    #[test]
    fn test_grad_matches_fd() {
        // Verify reverse-mode gradient matches finite differences
        let f = |x: &[Var]| x[0].sin() * x[1].exp() + x[2].clone() * x[2].clone();
        let x0 = [1.0, 2.0, 3.0];
        let g = grad(f, &x0);

        // Finite difference
        let eps = 1e-7;
        let f_val = |x: &[f64]| x[0].sin() * x[1].exp() + x[2] * x[2];
        let f0 = f_val(&x0);
        for i in 0..3 {
            let mut xp = x0;
            xp[i] += eps;
            let fd = (f_val(&xp) - f0) / eps;
            assert!(
                (g[i] - fd).abs() < 1e-5,
                "component {} mismatch: {} vs {}",
                i,
                g[i],
                fd
            );
        }
    }

    #[test]
    fn test_constant() {
        // z = x * 5 with x = 2: z = 10, dz/dx = 5
        let g = unary_grad(2.0, |x| x.clone() * x.cst(5.0));
        assert_close(g, 5.0, 1e-14);
    }

    #[test]
    fn test_complex_composition() {
        // f(x) = sin(x^2 + exp(x))
        let g = grad(
            |x| {
                let x2 = &x[0] * &x[0];
                let ex = x[0].exp();
                let inner = x2 + ex;
                inner.sin()
            },
            &[1.0],
        );
        // df/dx = cos(x^2+exp(x)) * (2x + exp(x))
        let x = 1.0_f64;
        let inner = x * x + x.exp();
        let expected = inner.cos() * (2.0 * x + x.exp());
        assert_close(g[0], expected, 1e-10);
    }
}