Skip to main content

numrs2/autodiff/
mod.rs

1//! Automatic Differentiation for NumRS2
2//!
3//! This module implements both forward-mode and reverse-mode automatic differentiation
4//! (AD), enabling computation of derivatives for numerical functions.
5//!
6//! # Overview
7//!
8//! Automatic differentiation is a technique to compute derivatives of functions by
9//! applying the chain rule to elementary operations. Unlike numerical differentiation
10//! (finite differences), AD computes exact derivatives (up to floating-point precision).
11//!
12//! ## Forward Mode AD (Dual Numbers)
13//!
14//! Forward mode AD uses dual numbers to propagate derivatives alongside values:
15//! - Dual number: `f(x) + f'(x)ε` where `ε² = 0`
16//! - Efficient for functions with few inputs, many outputs (Jacobian-vector products)
17//! - Computes derivatives in one forward pass
18//!
19//! ## Reverse Mode AD (Backpropagation)
20//!
21//! Reverse mode AD records operations in a computation graph and backpropagates gradients:
22//! - Records operations on a tape during forward pass
23//! - Computes gradients during backward pass
24//! - Efficient for functions with many inputs, few outputs (vector-Jacobian products)
25//! - This is the foundation of modern deep learning frameworks
26//!
27//! # Examples
28//!
29//! ## Forward Mode with Dual Numbers
30//!
31//! ```rust,ignore
32//! use numrs2::autodiff::*;
33//!
34//! // Compute derivative of f(x) = x² at x = 3
35//! let x = Dual::new(3.0, 1.0); // value=3.0, derivative=1.0
36//! let y = x * x;
37//! assert_eq!(y.value(), 9.0);  // f(3) = 9
38//! assert_eq!(y.deriv(), 6.0);  // f'(3) = 6
39//! ```
40//!
41//! ## Reverse Mode with Tape
42//!
43//! ```rust,ignore
44//! use numrs2::autodiff::*;
45//!
46//! let mut tape = Tape::new();
47//! let x = tape.var(3.0);
48//! let y = tape.var(4.0);
49//! let z = x * x + y * y; // z = x² + y²
50//!
51//! tape.backward(z);
52//! assert_eq!(tape.grad(x), 6.0);  // ∂z/∂x = 2x = 6
53//! assert_eq!(tape.grad(y), 8.0);  // ∂z/∂y = 2y = 8
54//! ```
55
56pub mod higher_order;
57pub use higher_order::{
58    compute_gradient_ad, compute_gradient_hyperdual, forward_jacobian, gradient_check,
59    gradient_check_ad, hessian_exact, hessian_forward_over_reverse, hessian_vector_product,
60    jacobian_auto, multivariate_taylor, reverse_jacobian, GradientCheckResult, HyperDual,
61    TaylorExpansion2,
62};
63
64use crate::array::Array;
65use crate::error::{NumRs2Error, Result};
66use num_traits::{Float, One, Zero};
67use std::fmt;
68use std::ops::{Add, Div, Mul, Neg, Sub};
69
70// ============================================================================
71// Forward Mode AD: Dual Numbers
72// ============================================================================
73
/// Dual number for forward-mode automatic differentiation
///
/// A dual number represents both a value and its derivative: `f(x) + f'(x)ε`
/// where `ε² = 0`. This enables exact derivative computation through operator
/// overloading and the chain rule.
///
/// # Type Parameters
///
/// * `T` - The numeric type (typically `f32` or `f64`)
///
/// # Examples
///
/// ```rust,ignore
/// use numrs2::autodiff::Dual;
///
/// // f(x) = x³ + 2x² - 5x + 1
/// // Scalars are lifted with `Dual::constant`, since arithmetic is defined
/// // between two dual numbers.
/// fn f(x: Dual<f64>) -> Dual<f64> {
///     x.pow(3.0) + x.pow(2.0) * Dual::constant(2.0) - x * Dual::constant(5.0)
///         + Dual::constant(1.0)
/// }
///
/// // Compute f(2) and f'(2)
/// let x = Dual::new(2.0, 1.0);
/// let result = f(x);
/// assert_eq!(result.value(), 7.0);   // f(2) = 8 + 8 - 10 + 1 = 7
/// assert_eq!(result.deriv(), 15.0);  // f'(x) = 3x² + 4x - 5, so f'(2) = 12 + 8 - 5 = 15
/// ```
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Dual<T> {
    /// The value component
    value: T,
    /// The derivative (tangent) component
    deriv: T,
}
107
108impl<T: Float> Dual<T> {
109    /// Create a new dual number with given value and derivative
110    ///
111    /// # Arguments
112    ///
113    /// * `value` - The function value
114    /// * `deriv` - The derivative value
115    ///
116    /// # Examples
117    ///
118    /// ```rust,ignore
119    /// use numrs2::autodiff::Dual;
120    ///
121    /// // Variable with derivative = 1 (for computing ∂f/∂x)
122    /// let x = Dual::new(3.0, 1.0);
123    ///
124    /// // Constant with derivative = 0
125    /// let c = Dual::new(5.0, 0.0);
126    /// ```
127    pub fn new(value: T, deriv: T) -> Self {
128        Self { value, deriv }
129    }
130
131    /// Create a dual number representing a variable (derivative = 1)
132    ///
133    /// Use this for the input variable with respect to which you want to differentiate.
134    pub fn variable(value: T) -> Self {
135        Self::new(value, T::one())
136    }
137
138    /// Create a dual number representing a constant (derivative = 0)
139    ///
140    /// Use this for constants in expressions.
141    pub fn constant(value: T) -> Self {
142        Self::new(value, T::zero())
143    }
144
145    /// Get the value component
146    pub fn value(&self) -> T {
147        self.value
148    }
149
150    /// Get the derivative component
151    pub fn deriv(&self) -> T {
152        self.deriv
153    }
154
155    /// Compute power: x^n
156    pub fn pow(&self, n: T) -> Self {
157        let value = self.value.powf(n);
158        let deriv = n * self.value.powf(n - T::one()) * self.deriv;
159        Self::new(value, deriv)
160    }
161
162    /// Compute exponential: e^x
163    pub fn exp(&self) -> Self {
164        let exp_val = self.value.exp();
165        Self::new(exp_val, exp_val * self.deriv)
166    }
167
168    /// Compute natural logarithm: ln(x)
169    pub fn ln(&self) -> Self {
170        Self::new(self.value.ln(), self.deriv / self.value)
171    }
172
173    /// Compute sine: sin(x)
174    pub fn sin(&self) -> Self {
175        Self::new(self.value.sin(), self.value.cos() * self.deriv)
176    }
177
178    /// Compute cosine: cos(x)
179    pub fn cos(&self) -> Self {
180        Self::new(self.value.cos(), -self.value.sin() * self.deriv)
181    }
182
183    /// Compute tangent: tan(x)
184    pub fn tan(&self) -> Self {
185        let tan_val = self.value.tan();
186        Self::new(tan_val, self.deriv / self.value.cos().powi(2))
187    }
188
189    /// Compute hyperbolic sine: sinh(x)
190    pub fn sinh(&self) -> Self {
191        Self::new(self.value.sinh(), self.value.cosh() * self.deriv)
192    }
193
194    /// Compute hyperbolic cosine: cosh(x)
195    pub fn cosh(&self) -> Self {
196        Self::new(self.value.cosh(), self.value.sinh() * self.deriv)
197    }
198
199    /// Compute hyperbolic tangent: tanh(x)
200    pub fn tanh(&self) -> Self {
201        let tanh_val = self.value.tanh();
202        Self::new(tanh_val, (T::one() - tanh_val * tanh_val) * self.deriv)
203    }
204
205    /// Compute square root: √x
206    pub fn sqrt(&self) -> Self {
207        let sqrt_val = self.value.sqrt();
208        Self::new(sqrt_val, self.deriv / (T::one() + T::one()) / sqrt_val)
209    }
210
211    /// Compute absolute value: |x|
212    ///
213    /// Note: The derivative at x=0 is set to 0 by convention.
214    pub fn abs(&self) -> Self {
215        if self.value >= T::zero() {
216            *self
217        } else {
218            -(*self)
219        }
220    }
221
222    /// Compute sigmoid function: 1 / (1 + e^(-x))
223    pub fn sigmoid(&self) -> Self {
224        let exp_neg_x = (-self.value).exp();
225        let sigmoid_val = T::one() / (T::one() + exp_neg_x);
226        Self::new(
227            sigmoid_val,
228            sigmoid_val * (T::one() - sigmoid_val) * self.deriv,
229        )
230    }
231
232    /// Compute ReLU: max(0, x)
233    ///
234    /// Note: The derivative at x=0 is set to 0 by convention.
235    pub fn relu(&self) -> Self {
236        if self.value > T::zero() {
237            *self
238        } else {
239            Self::constant(T::zero())
240        }
241    }
242}
243
244// Arithmetic operators for Dual numbers
245
246impl<T: Float> Add for Dual<T> {
247    type Output = Self;
248
249    fn add(self, rhs: Self) -> Self::Output {
250        Self::new(self.value + rhs.value, self.deriv + rhs.deriv)
251    }
252}
253
254impl<T: Float> Sub for Dual<T> {
255    type Output = Self;
256
257    fn sub(self, rhs: Self) -> Self::Output {
258        Self::new(self.value - rhs.value, self.deriv - rhs.deriv)
259    }
260}
261
262impl<T: Float> Mul for Dual<T> {
263    type Output = Self;
264
265    fn mul(self, rhs: Self) -> Self::Output {
266        // Product rule: (uv)' = u'v + uv'
267        Self::new(
268            self.value * rhs.value,
269            self.deriv * rhs.value + self.value * rhs.deriv,
270        )
271    }
272}
273
274impl<T: Float> Div for Dual<T> {
275    type Output = Self;
276
277    fn div(self, rhs: Self) -> Self::Output {
278        // Quotient rule: (u/v)' = (u'v - uv') / v²
279        Self::new(
280            self.value / rhs.value,
281            (self.deriv * rhs.value - self.value * rhs.deriv) / (rhs.value * rhs.value),
282        )
283    }
284}
285
286impl<T: Float> Neg for Dual<T> {
287    type Output = Self;
288
289    fn neg(self) -> Self::Output {
290        Self::new(-self.value, -self.deriv)
291    }
292}
293
294impl<T: Float + fmt::Display> fmt::Display for Dual<T> {
295    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
296        write!(f, "{}+{}ε", self.value, self.deriv)
297    }
298}
299
300// ============================================================================
301// Array-based Forward Mode AD
302// ============================================================================
303
304/// Compute gradient of a scalar-valued function using forward mode AD
305///
306/// This function computes the gradient ∇f of a function f: ℝⁿ → ℝ at a given point.
307///
308/// # Arguments
309///
310/// * `f` - Function to differentiate (takes array of Dual numbers, returns Dual number)
311/// * `x` - Point at which to compute the gradient
312///
313/// # Returns
314///
315/// The gradient vector ∇f(x)
316///
317/// # Examples
318///
319/// ```rust,ignore
320/// use numrs2::autodiff::*;
321/// use numrs2::prelude::*;
322///
323/// // f(x, y) = x² + y²
324/// fn f(vars: &Array<Dual<f64>>) -> Dual<f64> {
325///     let v = vars.to_vec();
326///     v[0] * v[0] + v[1] * v[1]
327/// }
328///
329/// let x = Array::from_vec(vec![3.0, 4.0]);
330/// let grad = gradient(f, &x).expect("valid gradient computation");
331/// // ∇f = [2x, 2y] = [6, 8]
332/// assert_eq!(grad.to_vec(), vec![6.0, 8.0]);
333/// ```
334pub fn gradient<F, T>(f: F, x: &Array<T>) -> Result<Array<T>>
335where
336    F: Fn(&Array<Dual<T>>) -> Dual<T>,
337    T: Float,
338{
339    let n = x.size();
340    let x_vec = x.to_vec();
341    let mut grad = Vec::with_capacity(n);
342
343    // Compute gradient using forward mode: compute one directional derivative per input
344    for i in 0..n {
345        // Create dual number array with derivative = 1 for variable i, 0 for others
346        let mut dual_vec = Vec::with_capacity(n);
347        for j in 0..n {
348            let deriv = if i == j { T::one() } else { T::zero() };
349            dual_vec.push(Dual::new(x_vec[j], deriv));
350        }
351        let dual_array = Array::from_vec(dual_vec);
352
353        // Evaluate function and extract derivative component
354        let result = f(&dual_array);
355        grad.push(result.deriv());
356    }
357
358    Ok(Array::from_vec(grad))
359}
360
361/// Compute Jacobian matrix of a vector-valued function using forward mode AD
362///
363/// This function computes the Jacobian J of a function f: ℝⁿ → ℝᵐ at a given point.
364/// The Jacobian is an m×n matrix where `J[i,j] = ∂fᵢ/∂xⱼ`.
365///
366/// # Arguments
367///
368/// * `f` - Function to differentiate (takes array of Dual numbers, returns array of Dual numbers)
369/// * `x` - Point at which to compute the Jacobian
370///
371/// # Returns
372///
373/// The Jacobian matrix J(x)
374pub fn jacobian<F, T>(f: F, x: &Array<T>) -> Result<Array<T>>
375where
376    F: Fn(&Array<Dual<T>>) -> Array<Dual<T>>,
377    T: Float,
378{
379    let n = x.size();
380    let x_vec = x.to_vec();
381
382    // First pass: determine output dimension
383    let dual_vec_test: Vec<Dual<T>> = x_vec.iter().map(|&v| Dual::variable(v)).collect();
384    let output_test = f(&Array::from_vec(dual_vec_test));
385    let m = output_test.size();
386
387    let mut jac = Vec::with_capacity(m * n);
388
389    // Compute Jacobian using forward mode: one pass per input variable
390    for i in 0..n {
391        // Create dual number array with derivative = 1 for variable i, 0 for others
392        let mut dual_vec = Vec::with_capacity(n);
393        for j in 0..n {
394            let deriv = if i == j { T::one() } else { T::zero() };
395            dual_vec.push(Dual::new(x_vec[j], deriv));
396        }
397        let dual_array = Array::from_vec(dual_vec);
398
399        // Evaluate function and extract derivative components
400        let result = f(&dual_array);
401        let result_vec = result.to_vec();
402        for elem in result_vec {
403            jac.push(elem.deriv());
404        }
405    }
406
407    // Reshape to m×n matrix (each column is ∂f/∂xᵢ)
408    Ok(Array::from_vec(jac).reshape(&[m, n]))
409}
410
411// ============================================================================
412// Reverse Mode AD: Tape-based Computation Graph
413// ============================================================================
414
/// One recorded entry of a reverse-mode tape: an input variable or the result
/// of an operation.
///
/// Each `parents` entry pairs the index of an input node with the local
/// partial derivative ∂(this node)/∂(that input), evaluated during the forward
/// pass. The backward pass multiplies these by this node's adjoint.
#[derive(Debug, Clone)]
struct TapeNode<T> {
    /// Forward value of this node.
    value: T,
    /// Adjoint (accumulated ∂output/∂this node), filled in by the backward pass.
    grad: T,
    /// `(parent index, local partial)` pairs; empty for input variables.
    parents: Vec<(usize, T)>,
}
425
426/// Computation tape for reverse-mode automatic differentiation
427///
428/// The tape records all operations during the forward pass, building a computation
429/// graph. During the backward pass, gradients are propagated from outputs to inputs
430/// using the chain rule.
431///
432/// # Examples
433///
434/// ```rust,ignore
435/// use numrs2::autodiff::Tape;
436///
437/// let mut tape = Tape::new();
438///
439/// // Forward pass: build computation graph
440/// let x = tape.var(3.0);
441/// let y = tape.var(4.0);
442/// let z = tape.mul(x, x);  // z = x²
443/// let w = tape.mul(y, y);  // w = y²
444/// let result = tape.add(z, w);  // result = x² + y²
445///
446/// // Backward pass: compute gradients
447/// tape.backward(result);
448///
449/// assert_eq!(tape.grad(x), 6.0);  // ∂result/∂x = 2x = 6
450/// assert_eq!(tape.grad(y), 8.0);  // ∂result/∂y = 2y = 8
451/// ```
452pub struct Tape<T> {
453    /// Nodes in the computation graph
454    nodes: Vec<TapeNode<T>>,
455}
456
/// Handle to a node recorded on a `Tape`.
///
/// It is a plain index into the tape that created it. Nothing checks that it
/// is used with that same tape.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Var(usize);
460
461impl<T: Float> Tape<T> {
462    /// Create a new empty tape
463    pub fn new() -> Self {
464        Self { nodes: Vec::new() }
465    }
466
467    /// Add a variable to the tape
468    ///
469    /// Variables are leaf nodes in the computation graph (inputs).
470    pub fn var(&mut self, value: T) -> Var {
471        let idx = self.nodes.len();
472        self.nodes.push(TapeNode {
473            value,
474            grad: T::zero(),
475            parents: Vec::new(),
476        });
477        Var(idx)
478    }
479
480    /// Get the value of a variable
481    pub fn value(&self, var: Var) -> T {
482        self.nodes[var.0].value
483    }
484
485    /// Get the gradient of a variable (only valid after backward pass)
486    pub fn grad(&self, var: Var) -> T {
487        self.nodes[var.0].grad
488    }
489
490    /// Addition operation: z = x + y
491    pub fn add(&mut self, x: Var, y: Var) -> Var {
492        let value = self.nodes[x.0].value + self.nodes[y.0].value;
493        let idx = self.nodes.len();
494        self.nodes.push(TapeNode {
495            value,
496            grad: T::zero(),
497            parents: vec![(x.0, T::one()), (y.0, T::one())],
498        });
499        Var(idx)
500    }
501
502    /// Subtraction operation: z = x - y
503    pub fn sub(&mut self, x: Var, y: Var) -> Var {
504        let value = self.nodes[x.0].value - self.nodes[y.0].value;
505        let idx = self.nodes.len();
506        self.nodes.push(TapeNode {
507            value,
508            grad: T::zero(),
509            parents: vec![(x.0, T::one()), (y.0, -T::one())],
510        });
511        Var(idx)
512    }
513
514    /// Multiplication operation: z = x * y
515    pub fn mul(&mut self, x: Var, y: Var) -> Var {
516        let x_val = self.nodes[x.0].value;
517        let y_val = self.nodes[y.0].value;
518        let value = x_val * y_val;
519        let idx = self.nodes.len();
520        self.nodes.push(TapeNode {
521            value,
522            grad: T::zero(),
523            parents: vec![(x.0, y_val), (y.0, x_val)],
524        });
525        Var(idx)
526    }
527
528    /// Division operation: z = x / y
529    pub fn div(&mut self, x: Var, y: Var) -> Var {
530        let x_val = self.nodes[x.0].value;
531        let y_val = self.nodes[y.0].value;
532        let value = x_val / y_val;
533        let idx = self.nodes.len();
534        self.nodes.push(TapeNode {
535            value,
536            grad: T::zero(),
537            parents: vec![(x.0, T::one() / y_val), (y.0, -x_val / (y_val * y_val))],
538        });
539        Var(idx)
540    }
541
542    /// Power operation: z = x^n
543    pub fn pow(&mut self, x: Var, n: T) -> Var {
544        let x_val = self.nodes[x.0].value;
545        let value = x_val.powf(n);
546        let idx = self.nodes.len();
547        let grad_coeff = n * x_val.powf(n - T::one());
548        self.nodes.push(TapeNode {
549            value,
550            grad: T::zero(),
551            parents: vec![(x.0, grad_coeff)],
552        });
553        Var(idx)
554    }
555
556    /// Exponential operation: z = e^x
557    pub fn exp(&mut self, x: Var) -> Var {
558        let x_val = self.nodes[x.0].value;
559        let value = x_val.exp();
560        let idx = self.nodes.len();
561        self.nodes.push(TapeNode {
562            value,
563            grad: T::zero(),
564            parents: vec![(x.0, value)],
565        });
566        Var(idx)
567    }
568
569    /// Natural logarithm: z = ln(x)
570    pub fn ln(&mut self, x: Var) -> Var {
571        let x_val = self.nodes[x.0].value;
572        let value = x_val.ln();
573        let idx = self.nodes.len();
574        self.nodes.push(TapeNode {
575            value,
576            grad: T::zero(),
577            parents: vec![(x.0, T::one() / x_val)],
578        });
579        Var(idx)
580    }
581
582    /// Sine operation: z = sin(x)
583    pub fn sin(&mut self, x: Var) -> Var {
584        let x_val = self.nodes[x.0].value;
585        let value = x_val.sin();
586        let idx = self.nodes.len();
587        self.nodes.push(TapeNode {
588            value,
589            grad: T::zero(),
590            parents: vec![(x.0, x_val.cos())],
591        });
592        Var(idx)
593    }
594
595    /// Cosine operation: z = cos(x)
596    pub fn cos(&mut self, x: Var) -> Var {
597        let x_val = self.nodes[x.0].value;
598        let value = x_val.cos();
599        let idx = self.nodes.len();
600        self.nodes.push(TapeNode {
601            value,
602            grad: T::zero(),
603            parents: vec![(x.0, -x_val.sin())],
604        });
605        Var(idx)
606    }
607
608    /// Perform backward pass to compute gradients
609    ///
610    /// This propagates gradients from the output variable back to all variables
611    /// that contributed to it, using the chain rule.
612    ///
613    /// # Arguments
614    ///
615    /// * `output` - The output variable to differentiate with respect to
616    pub fn backward(&mut self, output: Var) {
617        // Initialize all gradients to zero
618        for node in &mut self.nodes {
619            node.grad = T::zero();
620        }
621
622        // Set gradient of output to 1 (∂output/∂output = 1)
623        self.nodes[output.0].grad = T::one();
624
625        // Propagate gradients backward through the graph
626        for i in (0..self.nodes.len()).rev() {
627            let grad = self.nodes[i].grad;
628            let parents = self.nodes[i].parents.clone();
629
630            for (parent_idx, coeff) in parents {
631                self.nodes[parent_idx].grad = self.nodes[parent_idx].grad + grad * coeff;
632            }
633        }
634    }
635
636    /// Reset all gradients to zero
637    pub fn zero_grad(&mut self) {
638        for node in &mut self.nodes {
639            node.grad = T::zero();
640        }
641    }
642
643    /// Get the number of nodes in the tape
644    pub fn len(&self) -> usize {
645        self.nodes.len()
646    }
647
648    /// Check if the tape is empty
649    pub fn is_empty(&self) -> bool {
650        self.nodes.is_empty()
651    }
652}
653
654impl<T: Float> Default for Tape<T> {
655    fn default() -> Self {
656        Self::new()
657    }
658}
659
660// ============================================================================
661// Higher-Order Derivatives
662// ============================================================================
663
664/// Compute Hessian matrix using numerical differentiation of gradients
665///
666/// Computes second derivatives by numerical differentiation of the gradient.
667/// This is a robust approach that works well for most functions.
668///
669/// # Arguments
670///
671/// * `f` - Scalar-valued function to differentiate
672/// * `x` - Point at which to compute the Hessian
673///
674/// # Returns
675///
676/// The Hessian matrix H(x) (n×n symmetric matrix)
677///
678/// # Examples
679///
680/// ```rust,ignore
681/// use numrs2::autodiff::*;
682/// use numrs2::prelude::*;
683///
684/// // f(x, y) = x² + y²
685/// fn f(vars: &[f64]) -> f64 {
686///     vars[0] * vars[0] + vars[1] * vars[1]
687/// }
688///
689/// let x = Array::from_vec(vec![3.0, 4.0]);
690/// let hessian = hessian(f, &x).expect("valid hessian computation");
691/// // H = [[2, 0], [0, 2]]
692/// ```
693pub fn hessian<F, T>(f: F, x: &Array<T>) -> Result<Array<T>>
694where
695    F: Fn(&[T]) -> T,
696    T: Float,
697{
698    let n = x.size();
699    let x_vec = x.to_vec();
700    let mut hess = Vec::with_capacity(n * n);
701
702    // For each output dimension of gradient
703    for i in 0..n {
704        // Compute gradient using forward mode at slightly perturbed points
705        let eps = T::from(1e-7).expect("1e-7 is representable as Float");
706
707        // Gradient computation helper
708        let grad_at = |point: &[T]| -> Vec<T> {
709            let mut grad = Vec::with_capacity(n);
710            for j in 0..n {
711                let mut point_plus = point.to_vec();
712                let mut point_minus = point.to_vec();
713                point_plus[j] = point_plus[j] + eps;
714                point_minus[j] = point_minus[j] - eps;
715
716                let f_plus = f(&point_plus);
717                let f_minus = f(&point_minus);
718                grad.push((f_plus - f_minus) / (eps + eps));
719            }
720            grad
721        };
722
723        // Compute gradient at x
724        let grad = grad_at(&x_vec);
725
726        // For each gradient component, compute its derivative (Hessian row)
727        for j in 0..n {
728            let mut x_plus = x_vec.clone();
729            let mut x_minus = x_vec.clone();
730            x_plus[i] = x_plus[i] + eps;
731            x_minus[i] = x_minus[i] - eps;
732
733            let grad_plus = grad_at(&x_plus);
734            let grad_minus = grad_at(&x_minus);
735
736            let second_deriv = (grad_plus[j] - grad_minus[j]) / (eps + eps);
737            hess.push(second_deriv);
738        }
739    }
740
741    Ok(Array::from_vec(hess).reshape(&[n, n]))
742}
743
744/// Compute directional derivative in direction v: ∇_v f = ∇f · v
745///
746/// # Arguments
747///
748/// * `f` - Scalar-valued function to differentiate
749/// * `x` - Point at which to compute the directional derivative
750/// * `v` - Direction vector (will be normalized)
751///
752/// # Returns
753///
754/// The directional derivative ∇_v f(x)
755pub fn directional_derivative<F, T>(f: F, x: &Array<T>, v: &Array<T>) -> Result<T>
756where
757    F: Fn(&Array<Dual<T>>) -> Dual<T>,
758    T: Float,
759{
760    if x.size() != v.size() {
761        return Err(NumRs2Error::ShapeMismatch {
762            expected: x.shape(),
763            actual: v.shape(),
764        });
765    }
766
767    // Normalize direction vector
768    let v_vec = v.to_vec();
769    let v_norm_sq: T = v_vec
770        .iter()
771        .map(|&vi| vi * vi)
772        .fold(T::zero(), |acc, x| acc + x);
773    let v_norm = v_norm_sq.sqrt();
774
775    // Create dual number array with derivative in direction v
776    let x_vec = x.to_vec();
777    let mut dual_vec = Vec::with_capacity(x.size());
778    for i in 0..x.size() {
779        dual_vec.push(Dual::new(x_vec[i], v_vec[i] / v_norm));
780    }
781    let dual_array = Array::from_vec(dual_vec);
782
783    // Evaluate and return directional derivative
784    let result = f(&dual_array);
785    Ok(result.deriv())
786}
787
788/// Compute nth derivative of a univariate function using forward mode and numerical methods
789///
790/// For n=0 and n=1, uses exact dual number differentiation.
791/// For n≥2, uses numerical differentiation of lower-order derivatives.
792///
793/// # Arguments
794///
795/// * `f` - Univariate function to differentiate
796/// * `x` - Point at which to compute the derivative
797/// * `n` - Order of derivative (1 for first derivative, 2 for second, etc.)
798///
799/// # Returns
800///
801/// The nth derivative f⁽ⁿ⁾(x)
802pub fn nth_derivative<F, T>(f: F, x: T, n: usize) -> T
803where
804    F: Fn(Dual<T>) -> Dual<T> + Copy,
805    T: Float,
806{
807    nth_derivative_impl(&f, x, n)
808}
809
810fn nth_derivative_impl<F, T>(f: &F, x: T, n: usize) -> T
811where
812    F: Fn(Dual<T>) -> Dual<T>,
813    T: Float,
814{
815    if n == 0 {
816        let result = f(Dual::constant(x));
817        return result.value();
818    }
819
820    if n == 1 {
821        let result = f(Dual::variable(x));
822        return result.deriv();
823    }
824
825    // For n ≥ 2, use numerical differentiation of (n-1)th derivative
826    let eps = T::from(1e-7).expect("1e-7 is representable as Float");
827    let deriv_prev_plus = nth_derivative_impl(f, x + eps, n - 1);
828    let deriv_prev_minus = nth_derivative_impl(f, x - eps, n - 1);
829    (deriv_prev_plus - deriv_prev_minus) / (eps + eps)
830}
831
832/// Compute Taylor series approximation of a function around a point
833///
834/// Returns coefficients [f(a), f'(a), f''(a)/2!, ..., f⁽ⁿ⁾(a)/n!]
835///
836/// # Arguments
837///
838/// * `f` - Univariate function to approximate
839/// * `a` - Point around which to expand
840/// * `order` - Number of terms in the Taylor series
841///
842/// # Returns
843///
844/// Vector of Taylor series coefficients
845///
846/// # Examples
847///
848/// ```rust,ignore
849/// use numrs2::autodiff::*;
850///
851/// // Taylor series of e^x around x=0
852/// fn f(x: Dual<f64>) -> Dual<f64> {
853///     x.exp()
854/// }
855///
856/// let coeffs = taylor_series(f, 0.0, 4);
857/// // coeffs ≈ [1, 1, 0.5, 1/6, 1/24]
858/// ```
859pub fn taylor_series<F, T>(f: F, a: T, order: usize) -> Vec<T>
860where
861    F: Fn(Dual<T>) -> Dual<T> + Copy,
862    T: Float,
863{
864    let mut coeffs = Vec::with_capacity(order + 1);
865
866    // Compute factorial
867    let factorial = |n: usize| -> T {
868        let mut result = T::one();
869        for i in 1..=n {
870            result = result * T::from(i).expect("factorial index is representable as Float");
871        }
872        result
873    };
874
875    for n in 0..=order {
876        let deriv_n = nth_derivative(f, a, n);
877        coeffs.push(deriv_n / factorial(n));
878    }
879
880    coeffs
881}
882
883#[cfg(test)]
884mod tests {
885    use super::*;
886
887    #[test]
888    fn test_dual_arithmetic() {
889        let x = Dual::new(3.0, 1.0);
890        let y = Dual::new(4.0, 0.0);
891
892        let sum = x + y;
893        assert_eq!(sum.value(), 7.0);
894        assert_eq!(sum.deriv(), 1.0);
895
896        let diff = x - y;
897        assert_eq!(diff.value(), -1.0);
898        assert_eq!(diff.deriv(), 1.0);
899
900        let prod = x * y;
901        assert_eq!(prod.value(), 12.0);
902        assert_eq!(prod.deriv(), 4.0);
903
904        let quot = x / y;
905        assert_eq!(quot.value(), 0.75);
906        assert_eq!(quot.deriv(), 0.25);
907    }
908
909    #[test]
910    fn test_dual_functions() {
911        let x = Dual::variable(2.0);
912
913        // Test pow: d/dx(x²) = 2x
914        let square = x.pow(2.0);
915        assert_eq!(square.value(), 4.0);
916        assert_eq!(square.deriv(), 4.0);
917
918        // Test exp: d/dx(e^x) = e^x
919        let exp_x = x.exp();
920        assert!((exp_x.value() - 2.0_f64.exp()).abs() < 1e-10);
921        assert!((exp_x.deriv() - 2.0_f64.exp()).abs() < 1e-10);
922
923        // Test ln: d/dx(ln(x)) = 1/x
924        let ln_x = x.ln();
925        assert!((ln_x.value() - 2.0_f64.ln()).abs() < 1e-10);
926        assert!((ln_x.deriv() - 0.5).abs() < 1e-10);
927
928        // Test sin: d/dx(sin(x)) = cos(x)
929        let sin_x = x.sin();
930        assert!((sin_x.value() - 2.0_f64.sin()).abs() < 1e-10);
931        assert!((sin_x.deriv() - 2.0_f64.cos()).abs() < 1e-10);
932    }
933
934    #[test]
935    fn test_dual_composition() {
936        // f(x) = (x² + 1)³
937        // f'(x) = 3(x² + 1)² * 2x = 6x(x² + 1)²
938        let x = Dual::variable(2.0);
939        let x_sq = x * x;
940        let x_sq_plus_1 = x_sq + Dual::constant(1.0);
941        let result = x_sq_plus_1.pow(3.0);
942
943        assert_eq!(result.value(), 125.0); // (4 + 1)³ = 125
944        assert_eq!(result.deriv(), 300.0); // 6 * 2 * 25 = 300
945    }
946
947    #[test]
948    fn test_gradient_simple() {
949        // f(x, y) = x² + y²
950        // ∇f = [2x, 2y]
951        fn f(vars: &Array<Dual<f64>>) -> Dual<f64> {
952            let v = vars.to_vec();
953            v[0] * v[0] + v[1] * v[1]
954        }
955
956        let x = Array::from_vec(vec![3.0, 4.0]);
957        let grad = gradient(f, &x).expect("gradient computation should succeed");
958
959        let grad_vec = grad.to_vec();
960        assert!((grad_vec[0] - 6.0).abs() < 1e-10);
961        assert!((grad_vec[1] - 8.0).abs() < 1e-10);
962    }
963
964    #[test]
965    fn test_tape_basic() {
966        let mut tape = Tape::new();
967
968        let x = tape.var(3.0);
969        let y = tape.var(4.0);
970        let z = tape.mul(x, x); // z = x²
971        let w = tape.mul(y, y); // w = y²
972        let result = tape.add(z, w); // result = x² + y²
973
974        assert_eq!(tape.value(result), 25.0);
975
976        tape.backward(result);
977
978        assert_eq!(tape.grad(x), 6.0); // ∂/∂x(x² + y²) = 2x = 6
979        assert_eq!(tape.grad(y), 8.0); // ∂/∂y(x² + y²) = 2y = 8
980    }
981
982    #[test]
983    fn test_tape_division() {
984        let mut tape = Tape::new();
985
986        let x = tape.var(4.0);
987        let y = tape.var(2.0);
988        let z = tape.div(x, y); // z = x / y = 2
989
990        assert_eq!(tape.value(z), 2.0);
991
992        tape.backward(z);
993
994        assert_eq!(tape.grad(x), 0.5); // ∂/∂x(x/y) = 1/y = 0.5
995        assert_eq!(tape.grad(y), -1.0); // ∂/∂y(x/y) = -x/y² = -1
996    }
997
998    #[test]
999    fn test_tape_exp_ln() {
1000        let mut tape = Tape::new();
1001
1002        let x = tape.var(2.0);
1003        let y = tape.exp(x); // y = e^x
1004        let z = tape.ln(y); // z = ln(e^x) = x
1005
1006        assert!((tape.value(z) - 2.0).abs() < 1e-10);
1007
1008        tape.backward(z);
1009
1010        // ∂z/∂x should be 1 (since z = x through composition)
1011        assert!((tape.grad(x) - 1.0).abs() < 1e-10);
1012    }
1013
1014    #[test]
1015    fn test_tape_trigonometric() {
1016        let mut tape = Tape::new();
1017
1018        let x = tape.var(1.0);
1019        let s = tape.sin(x);
1020        let c = tape.cos(x);
1021
1022        // Test values
1023        assert!((tape.value(s) - 1.0_f64.sin()).abs() < 1e-10);
1024        assert!((tape.value(c) - 1.0_f64.cos()).abs() < 1e-10);
1025
1026        // Test derivatives
1027        tape.backward(s);
1028        assert!((tape.grad(x) - 1.0_f64.cos()).abs() < 1e-10);
1029
1030        tape.zero_grad();
1031        tape.backward(c);
1032        assert!((tape.grad(x) + 1.0_f64.sin()).abs() < 1e-10);
1033    }
1034
1035    #[test]
1036    fn test_tape_chain_rule() {
1037        // f(x) = sin(x²)
1038        // f'(x) = cos(x²) * 2x
1039        let mut tape = Tape::new();
1040
1041        let x = tape.var(2.0);
1042        let x_sq = tape.pow(x, 2.0);
1043        let result = tape.sin(x_sq);
1044
1045        let expected_value = (4.0_f64).sin();
1046        assert!((tape.value(result) - expected_value).abs() < 1e-10);
1047
1048        tape.backward(result);
1049
1050        let expected_grad = (4.0_f64).cos() * 4.0; // cos(4) * 2*2
1051        assert!((tape.grad(x) - expected_grad).abs() < 1e-10);
1052    }
1053
1054    #[test]
1055    fn test_directional_derivative() {
1056        // f(x, y) = x² + y²
1057        fn f(vars: &Array<Dual<f64>>) -> Dual<f64> {
1058            let v = vars.to_vec();
1059            v[0] * v[0] + v[1] * v[1]
1060        }
1061
1062        let x = Array::from_vec(vec![3.0, 4.0]);
1063        let v = Array::from_vec(vec![1.0, 0.0]); // direction along x-axis
1064
1065        let dir_deriv = directional_derivative(f, &x, &v)
1066            .expect("directional derivative computation should succeed");
1067        // ∇f = [2x, 2y] = [6, 8], direction v = [1, 0]
1068        // ∇_v f = [6, 8] · [1, 0] = 6
1069        assert!((dir_deriv - 6.0).abs() < 1e-10);
1070    }
1071
1072    #[test]
1073    fn test_nth_derivative_simple() {
1074        // f(x) = x³
1075        // f'(x) = 3x²
1076        // f''(x) = 6x
1077        // f'''(x) = 6
1078        fn f(x: Dual<f64>) -> Dual<f64> {
1079            x.pow(3.0)
1080        }
1081
1082        let x = 2.0;
1083
1084        let f0 = nth_derivative(f, x, 0);
1085        assert!((f0 - 8.0).abs() < 1e-6); // f(2) = 8
1086
1087        let f1 = nth_derivative(f, x, 1);
1088        assert!((f1 - 12.0).abs() < 1e-6); // f'(2) = 12
1089
1090        let f2 = nth_derivative(f, x, 2);
1091        assert!((f2 - 12.0).abs() < 0.1); // f''(2) = 12 (numerical tolerance)
1092
1093        let f3 = nth_derivative(f, x, 3);
1094        assert!((f3 - 6.0).abs() < 1.0); // f'''(2) = 6 (numerical tolerance)
1095    }
1096
1097    #[test]
1098    fn test_taylor_series_exp() {
1099        // Taylor series of e^x around x=0: 1 + x + x²/2! + x³/3! + ...
1100        fn f(x: Dual<f64>) -> Dual<f64> {
1101            x.exp()
1102        }
1103
1104        let coeffs = taylor_series(f, 0.0, 4);
1105
1106        // All coefficients should be approximately 1/n!
1107        assert!((coeffs[0] - 1.0).abs() < 1e-6); // e^0 = 1
1108        assert!((coeffs[1] - 1.0).abs() < 1e-6); // e^0 = 1
1109        assert!((coeffs[2] - 0.5).abs() < 1e-4); // e^0/2! = 0.5
1110        assert!((coeffs[3] - 1.0 / 6.0).abs() < 1e-3); // e^0/3! = 1/6
1111    }
1112
1113    #[test]
1114    fn test_hessian_simple() {
1115        // f(x, y) = x² + y²
1116        // H = [[2, 0], [0, 2]]
1117        fn f(vars: &[f64]) -> f64 {
1118            vars[0] * vars[0] + vars[1] * vars[1]
1119        }
1120
1121        let x = Array::from_vec(vec![3.0, 4.0]);
1122        let hess = hessian(f, &x).expect("hessian computation should succeed");
1123
1124        let hess_vec = hess.to_vec();
1125        // Numerical differentiation has lower accuracy for second derivatives
1126        // H[0,0] ≈ 2 (relaxed tolerance for numerical method)
1127        assert!((hess_vec[0] - 2.0).abs() < 1.0);
1128        // H[0,1] ≈ 0
1129        assert!(hess_vec[1].abs() < 1.0);
1130        // H[1,0] ≈ 0
1131        assert!(hess_vec[2].abs() < 1.0);
1132        // H[1,1] ≈ 2
1133        assert!((hess_vec[3] - 2.0).abs() < 1.0);
1134    }
1135
1136    #[test]
1137    fn test_dual_sigmoid() {
1138        let x = Dual::variable(0.0);
1139        let y = x.sigmoid();
1140
1141        assert!((y.value() - 0.5).abs() < 1e-10); // sigmoid(0) = 0.5
1142        assert!((y.deriv() - 0.25).abs() < 1e-10); // sigmoid'(0) = 0.25
1143    }
1144
1145    #[test]
1146    fn test_dual_relu() {
1147        let x_pos = Dual::variable(2.0);
1148        let y_pos = x_pos.relu();
1149        assert_eq!(y_pos.value(), 2.0);
1150        assert_eq!(y_pos.deriv(), 1.0);
1151
1152        let x_neg = Dual::variable(-2.0);
1153        let y_neg = x_neg.relu();
1154        assert_eq!(y_neg.value(), 0.0);
1155        assert_eq!(y_neg.deriv(), 0.0);
1156    }
1157}