Skip to main content

oxiphysics_core/
autodiff.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! Forward-mode automatic differentiation using dual numbers.
5//!
6//! A [`Dual`] number `(v, dv)` carries a value and its derivative
7//! simultaneously.  All arithmetic operations and transcendental functions
8//! apply the chain rule so that the derivative tracks through expressions
9//! automatically.
10//!
11//! # Quick start
12//!
13//! ```text
//! use oxiphysics_core::autodiff::{grad1, Dual};
//! let f = |x: Dual| x * x + x;  // f(x) = x² + x
//! let df = grad1(f, 3.0);       // f'(3) = 2·3 + 1 = 7
17//! assert!((df - 7.0).abs() < 1e-12);
18//! ```
19
20use std::ops::{Add, Div, Mul, Neg, Sub};
21
22// ---------------------------------------------------------------------------
23// Dual number
24// ---------------------------------------------------------------------------
25
/// A dual number `(v, dv)` representing a value and its first derivative.
///
/// All arithmetic operators and elementary functions implement the standard
/// chain-rule so that derivatives propagate automatically.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Dual {
    /// Primal (real) value.
    pub v: f64,
    /// Derivative (dual) component.
    pub dv: f64,
}

/// Construct a [`Dual`] from a value and a derivative seed.
///
/// Set `dx = 1.0` to differentiate with respect to this variable,
/// or `dx = 0.0` to treat it as a constant.
pub fn dual(x: f64, dx: f64) -> Dual {
    Dual { v: x, dv: dx }
}

impl Dual {
    /// Create a dual number representing the constant `c` (derivative = 0).
    #[allow(dead_code)]
    pub fn constant(c: f64) -> Self {
        Dual { v: c, dv: 0.0 }
    }

    /// Create a dual number representing the variable `x` with seed `dx = 1`.
    #[allow(dead_code)]
    pub fn variable(x: f64) -> Self {
        Dual { v: x, dv: 1.0 }
    }

    /// Returns `true` when this number carries no derivative information,
    /// i.e. it is a constant with respect to the seeded variable.
    ///
    /// Forward-mode tangent propagation is linear in the tangent, so a zero
    /// input tangent must give a zero output tangent. Checking this before
    /// applying a singular local derivative (`1/0`, `0^-1`, ...) keeps
    /// constants from turning into `0 · ∞ = NaN` and poisoning every
    /// downstream derivative.
    #[inline]
    fn has_zero_tangent(self) -> bool {
        self.dv == 0.0
    }

    /// Sine with chain rule: `sin(u)' = cos(u) * u'`.
    pub fn sin(self) -> Self {
        Dual {
            v: self.v.sin(),
            dv: self.v.cos() * self.dv,
        }
    }

    /// Cosine with chain rule: `cos(u)' = -sin(u) * u'`.
    pub fn cos(self) -> Self {
        Dual {
            v: self.v.cos(),
            dv: -self.v.sin() * self.dv,
        }
    }

    /// Natural exponential with chain rule: `exp(u)' = exp(u) * u'`.
    pub fn exp(self) -> Self {
        let ev = self.v.exp();
        Dual {
            v: ev,
            dv: ev * self.dv,
        }
    }

    /// Natural logarithm with chain rule: `ln(u)' = u'/u`.
    ///
    /// A constant argument (`u' = 0`) yields a zero derivative even at
    /// `u = 0`, where the naive `0/0` would produce `NaN`.
    pub fn ln(self) -> Self {
        Dual {
            v: self.v.ln(),
            dv: if self.has_zero_tangent() {
                0.0
            } else {
                self.dv / self.v
            },
        }
    }

    /// Square root with chain rule: `sqrt(u)' = u' / (2 * sqrt(u))`.
    ///
    /// A constant argument (`u' = 0`) yields a zero derivative even at
    /// `u = 0`, where the naive `0/0` would produce `NaN`.
    pub fn sqrt(self) -> Self {
        let sv = self.v.sqrt();
        Dual {
            v: sv,
            dv: if self.has_zero_tangent() {
                0.0
            } else {
                self.dv / (2.0 * sv)
            },
        }
    }

    /// Absolute value with (sub)derivative: `|u|' = sign(u) * u'`.
    ///
    /// At `u = 0` (either sign of zero) the subgradient `0` is used.
    /// `f64::signum` is *not* the mathematical sign function there: it
    /// returns `+1` for `+0.0` and `-1` for `-0.0`. Using it directly would
    /// make the derivative at the kink depend on the sign bit of zero.
    pub fn abs(self) -> Self {
        let sign = if self.v == 0.0 { 0.0 } else { self.v.signum() };
        Dual {
            v: self.v.abs(),
            dv: sign * self.dv,
        }
    }

    /// Integer power with chain rule: `u^n' = n * u^(n-1) * u'`.
    ///
    /// `u^0` is the constant `1`, so its derivative is exactly `0`. The
    /// generic formula would evaluate `0 * 0^-1 = 0 * ∞ = NaN` at `u = 0`.
    /// A constant argument likewise yields a zero derivative.
    pub fn powi(self, n: i32) -> Self {
        let dv = if n == 0 || self.has_zero_tangent() {
            0.0
        } else {
            // `n - 1` overflows for `n == i32::MIN`. Fall back to `powf`,
            // which represents that exponent exactly as an `f64`.
            let reduced = match n.checked_sub(1) {
                Some(m) => self.v.powi(m),
                None => self.v.powf(f64::from(n) - 1.0),
            };
            f64::from(n) * reduced * self.dv
        };
        Dual {
            v: self.v.powi(n),
            dv,
        }
    }

    /// Floating-point power: `u^p' = p * u^(p-1) * u'`.
    ///
    /// `u^0` is the constant `1`, so its derivative is exactly `0`, even at
    /// `u = 0`. A constant argument likewise yields a zero derivative.
    pub fn powf(self, p: f64) -> Self {
        let dv = if p == 0.0 || self.has_zero_tangent() {
            0.0
        } else {
            p * self.v.powf(p - 1.0) * self.dv
        };
        Dual {
            v: self.v.powf(p),
            dv,
        }
    }

    /// Tangent: `tan(u)' = u' / cos²(u)`.
    #[allow(dead_code)]
    pub fn tan(self) -> Self {
        let c = self.v.cos();
        Dual {
            v: self.v.tan(),
            dv: self.dv / (c * c),
        }
    }

    /// Hyperbolic sine: `sinh(u)' = cosh(u) * u'`.
    #[allow(dead_code)]
    pub fn sinh(self) -> Self {
        Dual {
            v: self.v.sinh(),
            dv: self.v.cosh() * self.dv,
        }
    }

    /// Hyperbolic cosine: `cosh(u)' = sinh(u) * u'`.
    #[allow(dead_code)]
    pub fn cosh(self) -> Self {
        Dual {
            v: self.v.cosh(),
            dv: self.v.sinh() * self.dv,
        }
    }
}
153
154// ---------------------------------------------------------------------------
155// Arithmetic operators
156// ---------------------------------------------------------------------------
157
158impl Add for Dual {
159    type Output = Self;
160    fn add(self, rhs: Self) -> Self {
161        Dual {
162            v: self.v + rhs.v,
163            dv: self.dv + rhs.dv,
164        }
165    }
166}
167
168impl Add<f64> for Dual {
169    type Output = Self;
170    fn add(self, rhs: f64) -> Self {
171        Dual {
172            v: self.v + rhs,
173            dv: self.dv,
174        }
175    }
176}
177
178impl Add<Dual> for f64 {
179    type Output = Dual;
180    fn add(self, rhs: Dual) -> Dual {
181        Dual {
182            v: self + rhs.v,
183            dv: rhs.dv,
184        }
185    }
186}
187
188impl Sub for Dual {
189    type Output = Self;
190    fn sub(self, rhs: Self) -> Self {
191        Dual {
192            v: self.v - rhs.v,
193            dv: self.dv - rhs.dv,
194        }
195    }
196}
197
198impl Sub<f64> for Dual {
199    type Output = Self;
200    fn sub(self, rhs: f64) -> Self {
201        Dual {
202            v: self.v - rhs,
203            dv: self.dv,
204        }
205    }
206}
207
208impl Sub<Dual> for f64 {
209    type Output = Dual;
210    fn sub(self, rhs: Dual) -> Dual {
211        Dual {
212            v: self - rhs.v,
213            dv: -rhs.dv,
214        }
215    }
216}
217
218impl Mul for Dual {
219    type Output = Self;
220    fn mul(self, rhs: Self) -> Self {
221        Dual {
222            v: self.v * rhs.v,
223            dv: self.dv * rhs.v + self.v * rhs.dv,
224        }
225    }
226}
227
228impl Mul<f64> for Dual {
229    type Output = Self;
230    fn mul(self, rhs: f64) -> Self {
231        Dual {
232            v: self.v * rhs,
233            dv: self.dv * rhs,
234        }
235    }
236}
237
238impl Mul<Dual> for f64 {
239    type Output = Dual;
240    fn mul(self, rhs: Dual) -> Dual {
241        Dual {
242            v: self * rhs.v,
243            dv: self * rhs.dv,
244        }
245    }
246}
247
248impl Div for Dual {
249    type Output = Self;
250    fn div(self, rhs: Self) -> Self {
251        Dual {
252            v: self.v / rhs.v,
253            dv: (self.dv * rhs.v - self.v * rhs.dv) / (rhs.v * rhs.v),
254        }
255    }
256}
257
258impl Div<f64> for Dual {
259    type Output = Self;
260    fn div(self, rhs: f64) -> Self {
261        Dual {
262            v: self.v / rhs,
263            dv: self.dv / rhs,
264        }
265    }
266}
267
268impl Neg for Dual {
269    type Output = Self;
270    fn neg(self) -> Self {
271        Dual {
272            v: -self.v,
273            dv: -self.dv,
274        }
275    }
276}
277
278// ---------------------------------------------------------------------------
279// Convenience differentiation utilities
280// ---------------------------------------------------------------------------
281
282/// Compute the first derivative of `f` at `x` using forward-mode AD.
283///
284/// Sets the seed derivative to 1 and evaluates `f(dual(x, 1.0)).dv`.
285pub fn grad1(f: impl Fn(Dual) -> Dual, x: f64) -> f64 {
286    f(dual(x, 1.0)).dv
287}
288
289/// Compute the diagonal of the Hessian of a scalar function `f: Rⁿ → R`.
290///
291/// The second derivative `∂²f/∂xᵢ²` is approximated by perturbing the
292/// `i`-th coordinate with a small step `h` and differencing derivatives:
293///
294/// ```text
295/// d²f/dxᵢ² ≈ (f'(xᵢ + h) - f'(xᵢ - h)) / (2h)
296/// ```
297///
298/// where `f'(xᵢ)` is obtained via forward-mode AD holding all other
299/// coordinates constant.
300pub fn hessian_diag(f: impl Fn(Dual) -> Dual, xs: &[f64]) -> Vec<f64> {
301    let h = 1e-5;
302    xs.iter()
303        .map(|&xi| {
304            let fp = f(dual(xi + h, 1.0)).dv;
305            let fm = f(dual(xi - h, 1.0)).dv;
306            (fp - fm) / (2.0 * h)
307        })
308        .collect()
309}
310
311/// Compute one row of the Jacobian: `∂f/∂xᵢ` for all `i`.
312///
313/// `f` maps a slice of dual numbers to a single dual output.
314/// For each index `i`, we seed `xs[i]` with derivative 1 and evaluate.
315pub fn jacobian_row(f: impl Fn(&[Dual]) -> Dual, xs: &[f64]) -> Vec<f64> {
316    let n = xs.len();
317    let mut row = Vec::with_capacity(n);
318    for i in 0..n {
319        let duals: Vec<Dual> = xs
320            .iter()
321            .enumerate()
322            .map(|(j, &x)| dual(x, if j == i { 1.0 } else { 0.0 }))
323            .collect();
324        row.push(f(&duals).dv);
325    }
326    row
327}
328
329// ---------------------------------------------------------------------------
330// DualVec — vector of Dual numbers
331// ---------------------------------------------------------------------------
332
333/// A vector of [`Dual`] numbers for multivariate forward-mode differentiation.
334#[allow(dead_code)]
335#[derive(Debug, Clone)]
336pub struct DualVec {
337    /// The dual components.
338    pub components: Vec<Dual>,
339}
340
341impl DualVec {
342    /// Construct from a slice of `(value, derivative)` pairs.
343    #[allow(dead_code)]
344    pub fn from_pairs(pairs: &[(f64, f64)]) -> Self {
345        DualVec {
346            components: pairs.iter().map(|&(v, dv)| Dual { v, dv }).collect(),
347        }
348    }
349
350    /// Construct a variable vector: component `i` has seed 1, others 0.
351    #[allow(dead_code)]
352    pub fn variable(xs: &[f64], seed_idx: usize) -> Self {
353        DualVec {
354            components: xs
355                .iter()
356                .enumerate()
357                .map(|(i, &x)| dual(x, if i == seed_idx { 1.0 } else { 0.0 }))
358                .collect(),
359        }
360    }
361
362    /// Length of the vector.
363    #[allow(dead_code)]
364    pub fn len(&self) -> usize {
365        self.components.len()
366    }
367
368    /// Returns `true` if the vector is empty.
369    #[allow(dead_code)]
370    pub fn is_empty(&self) -> bool {
371        self.components.is_empty()
372    }
373}
374
375/// Compute the full gradient of `f: Rⁿ → R` at `xs`.
376///
377/// Makes `n` forward passes, one per variable.
378pub fn gradient(f: impl Fn(&[Dual]) -> Dual, xs: &[f64]) -> Vec<f64> {
379    jacobian_row(f, xs)
380}
381
382// ---------------------------------------------------------------------------
383// Newton-Raphson step via autodiff
384// ---------------------------------------------------------------------------
385
386/// Perform one Newton-Raphson step: `x_new = x - f(x) / f'(x)`.
387///
388/// Uses forward-mode autodiff to compute `f'(x)` automatically.
389///
390/// # Panics
391///
392/// Does **not** panic on zero derivative; instead returns `x` unchanged when
393/// `|f'(x)| < 1e-14`.
394pub fn newton_step(f: impl Fn(Dual) -> Dual, x: f64) -> f64 {
395    let d = f(dual(x, 1.0));
396    if d.dv.abs() < 1e-14 {
397        x
398    } else {
399        x - d.v / d.dv
400    }
401}
402
403// ---------------------------------------------------------------------------
404// TaylorExpand
405// ---------------------------------------------------------------------------
406
407/// Stores the value of a function and its first three derivatives at a point,
408/// enabling Taylor polynomial evaluation.
409#[allow(dead_code)]
410#[derive(Debug, Clone)]
411pub struct TaylorExpand {
412    /// The expansion center `x₀`.
413    pub center: f64,
414    /// `f(x₀)`.
415    pub f0: f64,
416    /// `f'(x₀)`.
417    pub f1: f64,
418    /// `f''(x₀)` (second derivative).
419    pub f2: f64,
420    /// `f'''(x₀)` (third derivative).
421    pub f3: f64,
422}
423
424impl TaylorExpand {
425    /// Build a [`TaylorExpand`] by numerically estimating derivatives via
426    /// forward-mode AD and central finite differences.
427    ///
428    /// - `f1` — exact via autodiff.
429    /// - `f2` — via central difference of `f'`.
430    /// - `f3` — via central difference of `f''`.
431    pub fn build<F>(f: F, x0: f64) -> Self
432    where
433        F: Fn(Dual) -> Dual + Copy,
434    {
435        let h = 1e-5;
436        let f0 = f(dual(x0, 0.0)).v;
437        let f1 = f(dual(x0, 1.0)).dv;
438        // Second derivative: (f'(x0+h) - f'(x0-h)) / (2h)
439        let f2 = (f(dual(x0 + h, 1.0)).dv - f(dual(x0 - h, 1.0)).dv) / (2.0 * h);
440        // Third derivative: (f''(x0+h) - f''(x0-h)) / (2h)
441        let f2p = |xv: f64| (f(dual(xv + h, 1.0)).dv - f(dual(xv - h, 1.0)).dv) / (2.0 * h);
442        let f3 = (f2p(x0 + h) - f2p(x0 - h)) / (2.0 * h);
443        TaylorExpand {
444            center: x0,
445            f0,
446            f1,
447            f2,
448            f3,
449        }
450    }
451
452    /// Evaluate the Taylor polynomial at `x`:
453    /// `f0 + f1*(x-x0) + f2/2*(x-x0)² + f3/6*(x-x0)³`
454    pub fn eval(&self, x: f64) -> f64 {
455        let dx = x - self.center;
456        self.f0 + self.f1 * dx + (self.f2 / 2.0) * dx * dx + (self.f3 / 6.0) * dx * dx * dx
457    }
458}
459
460// ---------------------------------------------------------------------------
461// Tests
462// ---------------------------------------------------------------------------
463
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::{E, PI};

    /// Tolerance for results that should be exact up to rounding.
    const EPS: f64 = 1e-9;
    /// Tolerance for quantities obtained through finite differences.
    const LOOSE: f64 = 1e-5;

    /// Assert that `actual` lies strictly within `tol` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!((actual - expected).abs() < tol);
    }

    // --- Dual arithmetic ---

    #[test]
    fn test_dual_add() {
        let sum = dual(3.0, 1.0) + dual(2.0, 4.0);
        assert_close(sum.v, 5.0, EPS);
        assert_close(sum.dv, 5.0, EPS);
    }

    #[test]
    fn test_dual_sub() {
        let diff = dual(5.0, 2.0) - dual(3.0, 1.0);
        assert_close(diff.v, 2.0, EPS);
        assert_close(diff.dv, 1.0, EPS);
    }

    #[test]
    fn test_dual_mul_product_rule() {
        // (x·x)' = 2x, so the derivative at x = 3 is 6.
        let x = dual(3.0, 1.0);
        let square = x * x;
        assert_close(square.v, 9.0, EPS);
        assert_close(square.dv, 6.0, EPS);
    }

    #[test]
    fn test_dual_div_quotient_rule() {
        // (x / (x + 1))' = 1 / (x + 1)², i.e. 1/9 at x = 2.
        let x = dual(2.0, 1.0);
        let ratio = x / (x + 1.0);
        assert_close(ratio.v, 2.0 / 3.0, EPS);
        assert_close(ratio.dv, 1.0 / 9.0, EPS);
    }

    #[test]
    fn test_dual_neg() {
        let negated = -dual(4.0, 1.0);
        assert_close(negated.v, -4.0, EPS);
        assert_close(negated.dv, -1.0, EPS);
    }

    #[test]
    fn test_dual_add_f64() {
        let shifted = dual(1.0, 1.0) + 5.0;
        assert_close(shifted.v, 6.0, EPS);
        assert_close(shifted.dv, 1.0, EPS);
    }

    #[test]
    fn test_dual_mul_f64() {
        let scaled = dual(3.0, 1.0) * 4.0;
        assert_close(scaled.v, 12.0, EPS);
        assert_close(scaled.dv, 4.0, EPS);
    }

    #[test]
    fn test_dual_sub_f64_rhs() {
        let reflected = 10.0_f64 - dual(3.0, 1.0);
        assert_close(reflected.v, 7.0, EPS);
        assert_close(reflected.dv, -1.0, EPS);
    }

    // --- Transcendental functions ---

    #[test]
    fn test_dual_sin_derivative() {
        // sin' = cos, checked at π/4.
        let a = PI / 4.0;
        let y = dual(a, 1.0).sin();
        assert_close(y.v, a.sin(), EPS);
        assert_close(y.dv, a.cos(), EPS);
    }

    #[test]
    fn test_dual_cos_derivative() {
        let a = PI / 3.0;
        let y = dual(a, 1.0).cos();
        assert_close(y.v, a.cos(), EPS);
        assert_close(y.dv, -a.sin(), EPS);
    }

    #[test]
    fn test_dual_exp_derivative() {
        // exp is its own derivative.
        let y = dual(2.0, 1.0).exp();
        assert_close(y.v, E * E, 1e-10);
        assert_close(y.dv, E * E, 1e-10);
    }

    #[test]
    fn test_dual_ln_derivative() {
        // ln'(x) = 1/x
        let y = dual(3.0, 1.0).ln();
        assert_close(y.v, 3.0_f64.ln(), EPS);
        assert_close(y.dv, 1.0 / 3.0, EPS);
    }

    #[test]
    fn test_dual_sqrt_derivative() {
        // sqrt'(x) = 1 / (2·sqrt(x))
        let y = dual(4.0, 1.0).sqrt();
        assert_close(y.v, 2.0, EPS);
        assert_close(y.dv, 0.25, EPS);
    }

    #[test]
    fn test_dual_abs_positive() {
        let y = dual(3.0, 1.0).abs();
        assert_close(y.v, 3.0, EPS);
        assert_close(y.dv, 1.0, EPS);
    }

    #[test]
    fn test_dual_abs_negative() {
        let y = dual(-3.0, 1.0).abs();
        assert_close(y.v, 3.0, EPS);
        assert_close(y.dv, -1.0, EPS);
    }

    #[test]
    fn test_dual_powi() {
        // (x³)' = 3x², i.e. 12 at x = 2.
        let y = dual(2.0, 1.0).powi(3);
        assert_close(y.v, 8.0, EPS);
        assert_close(y.dv, 12.0, EPS);
    }

    #[test]
    fn test_dual_powf() {
        // (x^2.5)' = 2.5·x^1.5, i.e. 20 at x = 4.
        let y = dual(4.0, 1.0).powf(2.5);
        assert_close(y.v, 4.0_f64.powf(2.5), 1e-10);
        assert_close(y.dv, 2.5 * 4.0_f64.powf(1.5), 1e-10);
    }

    // --- grad1 ---

    #[test]
    fn test_grad1_quadratic() {
        // f(x) = x² + 3x - 5 -> f'(4) = 2·4 + 3 = 11
        let f = |x: Dual| x * x + dual(3.0, 0.0) * x - 5.0;
        assert_close(grad1(f, 4.0), 11.0, EPS);
    }

    #[test]
    fn test_grad1_sin() {
        assert_close(grad1(|x: Dual| x.sin(), PI / 6.0), (PI / 6.0).cos(), EPS);
    }

    #[test]
    fn test_grad1_chain_rule() {
        // exp(x²)' = 2x·exp(x²), i.e. 2e at x = 1.
        assert_close(grad1(|x: Dual| (x * x).exp(), 1.0), 2.0 * E, 1e-10);
    }

    #[test]
    fn test_grad1_constant_function() {
        assert_close(grad1(|_x: Dual| dual(42.0, 0.0), 1.0), 0.0, EPS);
    }

    // --- hessian_diag ---

    #[test]
    fn test_hessian_diag_quadratic() {
        // f(x) = x² has f'' = 2 everywhere.
        for hi in &hessian_diag(|x: Dual| x * x, &[1.0, 2.0, -3.0]) {
            assert!((hi - 2.0).abs() < LOOSE, "expected 2, got {hi}");
        }
    }

    #[test]
    fn test_hessian_diag_sin() {
        // sin'' = -sin
        let a = PI / 4.0;
        let h = hessian_diag(|x: Dual| x.sin(), &[a]);
        assert_close(h[0], -a.sin(), LOOSE);
    }

    // --- jacobian_row / gradient ---

    #[test]
    fn test_jacobian_row_linear() {
        // f(x, y) = 2x + 3y -> ∇f = [2, 3]
        let f = |xs: &[Dual]| dual(2.0, 0.0) * xs[0] + dual(3.0, 0.0) * xs[1];
        let row = jacobian_row(f, &[1.0, 1.0]);
        assert_close(row[0], 2.0, EPS);
        assert_close(row[1], 3.0, EPS);
    }

    #[test]
    fn test_gradient_quadratic_surface() {
        // f(x, y) = x² + y² -> ∇f(3, 4) = [6, 8]
        let f = |xs: &[Dual]| xs[0] * xs[0] + xs[1] * xs[1];
        let g = gradient(f, &[3.0, 4.0]);
        assert_close(g[0], 6.0, EPS);
        assert_close(g[1], 8.0, EPS);
    }

    #[test]
    fn test_gradient_cross_term() {
        // f(x, y, z) = xyz -> ∇f(2, 3, 4) = [yz, xz, xy] = [12, 8, 6]
        let f = |xs: &[Dual]| xs[0] * xs[1] * xs[2];
        let g = gradient(f, &[2.0, 3.0, 4.0]);
        assert_close(g[0], 12.0, EPS);
        assert_close(g[1], 8.0, EPS);
        assert_close(g[2], 6.0, EPS);
    }

    // --- newton_step ---

    #[test]
    fn test_newton_step_sqrt2() {
        // Root of x² - 2 is √2.
        let f = |x: Dual| x * x - 2.0;
        let root = (0..20).fold(2.0_f64, |x, _| newton_step(f, x));
        assert_close(root, 2.0_f64.sqrt(), 1e-12);
    }

    #[test]
    fn test_newton_step_cube_root() {
        // Root of x³ - 8 is 2.
        let f = |x: Dual| x.powi(3) - 8.0;
        let root = (0..30).fold(3.0_f64, |x, _| newton_step(f, x));
        assert_close(root, 2.0, 1e-10);
    }

    #[test]
    fn test_newton_step_zero_derivative_safe() {
        // f(x) = x² is flat at its minimum; the step must stay finite.
        assert!(newton_step(|x: Dual| x * x, 0.0).is_finite());
    }

    // --- DualVec ---

    #[test]
    fn test_dualvec_len() {
        assert_eq!(DualVec::variable(&[1.0, 2.0, 3.0], 0).len(), 3);
    }

    #[test]
    fn test_dualvec_seed() {
        let seeded = DualVec::variable(&[1.0, 2.0, 3.0], 1);
        assert_close(seeded.components[0].dv, 0.0, EPS);
        assert_close(seeded.components[1].dv, 1.0, EPS);
        assert_close(seeded.components[2].dv, 0.0, EPS);
    }

    #[test]
    fn test_dualvec_from_pairs() {
        let pairs = DualVec::from_pairs(&[(1.0, 0.0), (2.0, 1.0)]);
        assert_eq!(pairs.len(), 2);
        assert_close(pairs.components[1].dv, 1.0, EPS);
    }

    // --- TaylorExpand ---

    #[test]
    fn test_taylor_expand_sin_at_zero() {
        // sin about 0: T3(x) = x - x³/6
        let t = TaylorExpand::build(|x| x.sin(), 0.0);
        let x = 0.3_f64;
        let approx = t.eval(x);
        let exact = x.sin();
        assert!(
            (approx - exact).abs() < 1e-4,
            "approx={approx}, exact={exact}"
        );
    }

    #[test]
    fn test_taylor_expand_exp_at_zero() {
        // exp about 0: T3(x) = 1 + x + x²/2 + x³/6
        let t = TaylorExpand::build(|x| x.exp(), 0.0);
        let x = 0.5_f64;
        let approx = t.eval(x);
        let exact = x.exp();
        assert!(
            (approx - exact).abs() < 1e-2,
            "approx={approx}, exact={exact}"
        );
    }

    #[test]
    fn test_taylor_expand_exact_at_center() {
        // At the centre the polynomial reduces to f(2) = 4 + 2 = 6.
        let approx = TaylorExpand::build(|x| x * x + x, 2.0).eval(2.0);
        assert!((approx - 6.0).abs() < 1e-9, "approx={approx}");
    }

    #[test]
    fn test_taylor_expand_f1_matches_grad1() {
        let f = |x: Dual| (x * x).sin();
        let x0 = 1.0;
        let t = TaylorExpand::build(f, x0);
        assert_close(t.f1, grad1(f, x0), 1e-10);
    }

    // --- Dual::constant / variable helpers ---

    #[test]
    fn test_dual_constant_zero_derivative() {
        let c = Dual::constant(7.0);
        assert_close(c.v, 7.0, EPS);
        assert_close(c.dv, 0.0, EPS);
    }

    #[test]
    fn test_dual_variable_unit_derivative() {
        let var = Dual::variable(5.0);
        assert_close(var.v, 5.0, EPS);
        assert_close(var.dv, 1.0, EPS);
    }

    // --- tan / sinh / cosh ---

    #[test]
    fn test_dual_tan_derivative() {
        // tan' = 1/cos², which is 2 at π/4.
        let a = PI / 4.0;
        let y = dual(a, 1.0).tan();
        assert_close(y.dv, 1.0 / a.cos().powi(2), 1e-10);
    }

    #[test]
    fn test_dual_sinh_derivative() {
        assert_close(dual(1.0, 1.0).sinh().dv, 1.0_f64.cosh(), EPS);
    }

    #[test]
    fn test_dual_cosh_derivative() {
        assert_close(dual(1.0, 1.0).cosh().dv, 1.0_f64.sinh(), EPS);
    }

    // --- operator overloads for f64 lhs ---

    #[test]
    fn test_f64_add_dual() {
        let sum = 3.0_f64 + dual(2.0, 1.0);
        assert_close(sum.v, 5.0, EPS);
        assert_close(sum.dv, 1.0, EPS);
    }

    #[test]
    fn test_f64_mul_dual() {
        let product = 5.0_f64 * dual(3.0, 1.0);
        assert_close(product.v, 15.0, EPS);
        assert_close(product.dv, 5.0, EPS);
    }

    // --- composed functions ---

    #[test]
    fn test_composed_sin_exp() {
        // sin(exp(x))' at 0 = cos(1)·1
        assert_close(grad1(|x: Dual| x.exp().sin(), 0.0), 1.0_f64.cos(), EPS);
    }

    #[test]
    fn test_composed_sqrt_ln() {
        // sqrt(ln(x))' at e = 1/(2·sqrt(1)) · 1/e = 1/(2e)
        assert_close(grad1(|x: Dual| x.ln().sqrt(), E), 1.0 / (2.0 * E), 1e-10);
    }

    #[test]
    fn test_div_f64() {
        let halved = dual(6.0, 3.0) / 2.0;
        assert_close(halved.v, 3.0, EPS);
        assert_close(halved.dv, 1.5, EPS);
    }
}