// amari_dual/types.rs

//! Core dual number types for automatic differentiation
//!
//! Dual numbers extend the real numbers with an infinitesimal unit ε satisfying ε² = 0.
//! Evaluating a function on a dual number yields its derivative exactly (up to
//! floating-point rounding), with no finite-difference approximation and no recorded
//! computational graph. This makes dual numbers the natural basis for forward-mode
//! automatic differentiation.
//!
//! A dual number has the form a + b·ε, where:
//! - a is the real part (the function value)
//! - b is the dual part (the derivative)
10
11use core::fmt;
12use core::ops::{Add, Div, Mul, Neg, Sub};
13use num_traits::{Float, One, Zero};
14
15/// A dual number for automatic differentiation
16///
17/// Dual numbers enable exact derivative computation using the algebraic property ε² = 0.
18/// For a function f(x), evaluating f(x + ε) automatically computes both f(x) and f'(x).
19///
20/// # Examples
21///
22/// ```
23/// use amari_dual::DualNumber;
24///
25/// // Create a variable x = 3.0 (with derivative dx/dx = 1.0)
26/// let x = DualNumber::variable(3.0);
27///
28/// // Compute f(x) = x² + 2x + 1
29/// let result = x * x + DualNumber::constant(2.0) * x + DualNumber::constant(1.0);
30///
31/// // result.real = f(3) = 9 + 6 + 1 = 16
32/// // result.dual = f'(3) = 2(3) + 2 = 8
33/// assert_eq!(result.real, 16.0);
34/// assert_eq!(result.dual, 8.0);
35/// ```
36#[derive(Debug, Clone, Copy, PartialEq)]
37pub struct DualNumber<T: Float> {
38    /// The real part (function value)
39    pub real: T,
40    /// The dual part (derivative)
41    pub dual: T,
42}
43
44impl<T: Float> DualNumber<T> {
45    /// Create a new dual number
46    pub fn new(real: T, dual: T) -> Self {
47        Self { real, dual }
48    }
49
50    /// Create a constant (derivative = 0)
51    ///
52    /// Constants have zero derivative since d/dx(c) = 0.
53    pub fn constant(value: T) -> Self {
54        Self {
55            real: value,
56            dual: T::zero(),
57        }
58    }
59
60    /// Create a variable (derivative = 1)
61    ///
62    /// Variables have unit derivative since d/dx(x) = 1.
63    pub fn variable(value: T) -> Self {
64        Self {
65            real: value,
66            dual: T::one(),
67        }
68    }
69
70    /// Get the value (real part)
71    pub fn value(&self) -> T {
72        self.real
73    }
74
75    /// Get the derivative (dual part)
76    pub fn derivative(&self) -> T {
77        self.dual
78    }
79
80    /// Exponential function: exp(a + b·ε) = exp(a) + b·exp(a)·ε
81    ///
82    /// Uses the chain rule: d/dx(e^f) = f'·e^f
83    pub fn exp(self) -> Self {
84        let exp_real = self.real.exp();
85        Self {
86            real: exp_real,
87            dual: self.dual * exp_real,
88        }
89    }
90
91    /// Natural logarithm: ln(a + b·ε) = ln(a) + (b/a)·ε
92    ///
93    /// Uses the chain rule: d/dx(ln f) = f'/f
94    pub fn ln(self) -> Self {
95        Self {
96            real: self.real.ln(),
97            dual: self.dual / self.real,
98        }
99    }
100
101    /// Sine function: sin(a + b·ε) = sin(a) + b·cos(a)·ε
102    ///
103    /// Uses the chain rule: d/dx(sin f) = f'·cos(f)
104    pub fn sin(self) -> Self {
105        Self {
106            real: self.real.sin(),
107            dual: self.dual * self.real.cos(),
108        }
109    }
110
111    /// Cosine function: cos(a + b·ε) = cos(a) - b·sin(a)·ε
112    ///
113    /// Uses the chain rule: d/dx(cos f) = -f'·sin(f)
114    pub fn cos(self) -> Self {
115        Self {
116            real: self.real.cos(),
117            dual: -self.dual * self.real.sin(),
118        }
119    }
120
121    /// Tangent function: tan(a + b·ε) = tan(a) + b·sec²(a)·ε
122    ///
123    /// Uses the chain rule: d/dx(tan f) = f'·sec²(f) = f'/cos²(f)
124    pub fn tan(self) -> Self {
125        let tan_real = self.real.tan();
126        let cos_real = self.real.cos();
127        Self {
128            real: tan_real,
129            dual: self.dual / (cos_real * cos_real),
130        }
131    }
132
133    /// Square root: sqrt(a + b·ε) = sqrt(a) + (b/(2·sqrt(a)))·ε
134    ///
135    /// Uses the chain rule: d/dx(√f) = f'/(2√f)
136    pub fn sqrt(self) -> Self {
137        let sqrt_real = self.real.sqrt();
138        Self {
139            real: sqrt_real,
140            dual: self.dual / (T::from(2.0).unwrap() * sqrt_real),
141        }
142    }
143
144    /// Power function: (a + b·ε)^n = a^n + n·b·a^(n-1)·ε
145    ///
146    /// Uses the power rule: d/dx(f^n) = n·f'·f^(n-1)
147    pub fn powf(self, n: T) -> Self {
148        let pow_real = self.real.powf(n);
149        Self {
150            real: pow_real,
151            dual: n * self.dual * self.real.powf(n - T::one()),
152        }
153    }
154
155    /// Integer power function: (a + b·ε)^n = a^n + n·b·a^(n-1)·ε
156    ///
157    /// Uses the power rule: d/dx(f^n) = n·f'·f^(n-1)
158    pub fn powi(self, n: i32) -> Self {
159        let n_float = T::from(n).unwrap();
160        let pow_real = self.real.powi(n);
161        Self {
162            real: pow_real,
163            dual: n_float * self.dual * self.real.powi(n - 1),
164        }
165    }
166
167    /// Absolute value: |a + b·ε| = |a| + b·sign(a)·ε
168    ///
169    /// Derivative is sign(a), undefined at a=0
170    pub fn abs(self) -> Self {
171        let sign = if self.real >= T::zero() {
172            T::one()
173        } else {
174            -T::one()
175        };
176        Self {
177            real: self.real.abs(),
178            dual: self.dual * sign,
179        }
180    }
181
182    /// Hyperbolic sine: sinh(a + b·ε) = sinh(a) + b·cosh(a)·ε
183    pub fn sinh(self) -> Self {
184        Self {
185            real: self.real.sinh(),
186            dual: self.dual * self.real.cosh(),
187        }
188    }
189
190    /// Hyperbolic cosine: cosh(a + b·ε) = cosh(a) + b·sinh(a)·ε
191    pub fn cosh(self) -> Self {
192        Self {
193            real: self.real.cosh(),
194            dual: self.dual * self.real.sinh(),
195        }
196    }
197
198    /// Hyperbolic tangent: tanh(a + b·ε) = tanh(a) + b·sech²(a)·ε
199    pub fn tanh(self) -> Self {
200        let tanh_real = self.real.tanh();
201        let cosh_real = self.real.cosh();
202        Self {
203            real: tanh_real,
204            dual: self.dual / (cosh_real * cosh_real),
205        }
206    }
207
208    /// Maximum of two dual numbers (non-differentiable at equality)
209    pub fn max(self, other: Self) -> Self {
210        if self.real >= other.real {
211            self
212        } else {
213            other
214        }
215    }
216
217    /// Minimum of two dual numbers (non-differentiable at equality)
218    pub fn min(self, other: Self) -> Self {
219        if self.real <= other.real {
220            self
221        } else {
222            other
223        }
224    }
225
226    /// Sigmoid (logistic) function: σ(x) = 1/(1 + e^(-x))
227    ///
228    /// Uses the chain rule: d/dx(σ(f)) = σ(f)·(1 - σ(f))·f'
229    pub fn sigmoid(self) -> Self {
230        let exp_neg = (-self.real).exp();
231        let sigmoid_real = T::one() / (T::one() + exp_neg);
232        let sigmoid_deriv = sigmoid_real * (T::one() - sigmoid_real);
233        Self {
234            real: sigmoid_real,
235            dual: self.dual * sigmoid_deriv,
236        }
237    }
238
239    /// Apply a function with its derivative
240    ///
241    /// This is useful for applying functions where you know both f(x) and f'(x).
242    /// The chain rule is applied automatically.
243    ///
244    /// # Arguments
245    /// * `f` - The function to apply
246    /// * `df` - The derivative of the function
247    pub fn apply_with_derivative<F, G>(self, f: F, df: G) -> Self
248    where
249        F: Fn(T) -> T,
250        G: Fn(T) -> T,
251    {
252        Self {
253            real: f(self.real),
254            dual: self.dual * df(self.real),
255        }
256    }
257}
258
// Arithmetic operations using dual number algebra
//
// Dual number arithmetic rules (where ε² = 0):
// (a + b·ε) + (c + d·ε) = (a + c) + (b + d)·ε
// (a + b·ε) - (c + d·ε) = (a - c) + (b - d)·ε
// (a + b·ε) * (c + d·ε) = a·c + (a·d + b·c)·ε
// (a + b·ε) / (c + d·ε) = a/c + (b·c - a·d)/c²·ε   (for c ≠ 0)
266
267impl<T: Float> Add for DualNumber<T> {
268    type Output = Self;
269
270    fn add(self, other: Self) -> Self {
271        Self {
272            real: self.real + other.real,
273            dual: self.dual + other.dual,
274        }
275    }
276}
277
278impl<T: Float> Sub for DualNumber<T> {
279    type Output = Self;
280
281    fn sub(self, other: Self) -> Self {
282        Self {
283            real: self.real - other.real,
284            dual: self.dual - other.dual,
285        }
286    }
287}
288
289impl<T: Float> Mul for DualNumber<T> {
290    type Output = Self;
291
292    fn mul(self, other: Self) -> Self {
293        Self {
294            real: self.real * other.real,
295            dual: self.real * other.dual + self.dual * other.real,
296        }
297    }
298}
299
300impl<T: Float> Div for DualNumber<T> {
301    type Output = Self;
302
303    fn div(self, other: Self) -> Self {
304        let real = self.real / other.real;
305        let dual = (self.dual * other.real - self.real * other.dual) / (other.real * other.real);
306        Self { real, dual }
307    }
308}
309
310impl<T: Float> Neg for DualNumber<T> {
311    type Output = Self;
312
313    fn neg(self) -> Self {
314        Self {
315            real: -self.real,
316            dual: -self.dual,
317        }
318    }
319}
320
321impl<T: Float> Zero for DualNumber<T> {
322    fn zero() -> Self {
323        Self::constant(T::zero())
324    }
325
326    fn is_zero(&self) -> bool {
327        self.real.is_zero() && self.dual.is_zero()
328    }
329}
330
331impl<T: Float> One for DualNumber<T> {
332    fn one() -> Self {
333        Self::constant(T::one())
334    }
335}
336
337impl<T: Float + fmt::Display> fmt::Display for DualNumber<T> {
338    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
339        write!(f, "{} + {}ε", self.real, self.dual)
340    }
341}
342
343// Precision type aliases for dual numbers
344/// Standard-precision dual number (f64)
345pub type StandardDual = DualNumber<f64>;
346
347/// Extended-precision dual number (uses extended precision float from amari-core)
348#[cfg(feature = "high-precision")]
349pub type ExtendedDual = DualNumber<crate::ExtendedFloat>;
350
351/// Standard-precision multi-dual number (f64)
352pub type StandardMultiDual = MultiDualNumber<f64>;
353
354/// Extended-precision multi-dual number (uses extended precision float from amari-core)
355#[cfg(feature = "high-precision")]
356pub type ExtendedMultiDual = MultiDualNumber<crate::ExtendedFloat>;
357
358/// Multi-variable dual number for computing gradients
359///
360/// A MultiDualNumber represents a scalar function of multiple variables,
361/// storing the function value and partial derivatives with respect to each variable.
362///
363/// # Examples
364///
365/// ```
366/// use amari_dual::MultiDualNumber;
367///
368/// // f(x, y) = x² + xy + y²
369/// // ∂f/∂x = 2x + y
370/// // ∂f/∂y = x + 2y
371///
372/// let x = 2.0;
373/// let y = 3.0;
374///
375/// let value = x * x + x * y + y * y; // = 4 + 6 + 9 = 19
376/// let gradient = vec![2.0 * x + y, x + 2.0 * y]; // = [7, 8]
377///
378/// let result = MultiDualNumber::new(value, gradient);
379/// ```
380#[derive(Debug, Clone, PartialEq)]
381pub struct MultiDualNumber<T: Float> {
382    /// The function value
383    pub value: T,
384    /// The gradient (partial derivatives with respect to each variable)
385    pub gradient: Vec<T>,
386}
387
388impl<T: Float> MultiDualNumber<T> {
389    /// Create a new multi-dual number
390    pub fn new(value: T, gradient: Vec<T>) -> Self {
391        Self { value, gradient }
392    }
393
394    /// Create a constant (all partial derivatives = 0)
395    pub fn constant(value: T, n_vars: usize) -> Self {
396        Self {
397            value,
398            gradient: vec![T::zero(); n_vars],
399        }
400    }
401
402    /// Create a variable (partial derivative = 1 for this variable, 0 for others)
403    pub fn variable(value: T, var_index: usize, n_vars: usize) -> Self {
404        let mut gradient = vec![T::zero(); n_vars];
405        gradient[var_index] = T::one();
406        Self { value, gradient }
407    }
408
409    /// Get the number of variables
410    pub fn n_vars(&self) -> usize {
411        self.gradient.len()
412    }
413
414    /// Get the value
415    pub fn get_value(&self) -> T {
416        self.value
417    }
418
419    /// Get the gradient
420    pub fn get_gradient(&self) -> &[T] {
421        &self.gradient
422    }
423}
424
425impl<T: Float> Add for MultiDualNumber<T> {
426    type Output = Self;
427
428    fn add(self, other: Self) -> Self {
429        assert_eq!(
430            self.gradient.len(),
431            other.gradient.len(),
432            "Gradient dimension mismatch"
433        );
434        let gradient = self
435            .gradient
436            .iter()
437            .zip(&other.gradient)
438            .map(|(&a, &b)| a + b)
439            .collect();
440        Self {
441            value: self.value + other.value,
442            gradient,
443        }
444    }
445}
446
447impl<T: Float> Sub for MultiDualNumber<T> {
448    type Output = Self;
449
450    fn sub(self, other: Self) -> Self {
451        assert_eq!(
452            self.gradient.len(),
453            other.gradient.len(),
454            "Gradient dimension mismatch"
455        );
456        let gradient = self
457            .gradient
458            .iter()
459            .zip(&other.gradient)
460            .map(|(&a, &b)| a - b)
461            .collect();
462        Self {
463            value: self.value - other.value,
464            gradient,
465        }
466    }
467}
468
469impl<T: Float> Mul for MultiDualNumber<T> {
470    type Output = Self;
471
472    fn mul(self, other: Self) -> Self {
473        assert_eq!(
474            self.gradient.len(),
475            other.gradient.len(),
476            "Gradient dimension mismatch"
477        );
478        // Product rule: (fg)' = f'g + fg'
479        let gradient = self
480            .gradient
481            .iter()
482            .zip(&other.gradient)
483            .map(|(&df, &dg)| df * other.value + self.value * dg)
484            .collect();
485        Self {
486            value: self.value * other.value,
487            gradient,
488        }
489    }
490}
491
492impl<T: Float> Div for MultiDualNumber<T> {
493    type Output = Self;
494
495    fn div(self, other: Self) -> Self {
496        assert_eq!(
497            self.gradient.len(),
498            other.gradient.len(),
499            "Gradient dimension mismatch"
500        );
501        // Quotient rule: (f/g)' = (f'g - fg')/g²
502        let g_squared = other.value * other.value;
503        let gradient = self
504            .gradient
505            .iter()
506            .zip(&other.gradient)
507            .map(|(&df, &dg)| (df * other.value - self.value * dg) / g_squared)
508            .collect();
509        Self {
510            value: self.value / other.value,
511            gradient,
512        }
513    }
514}
515
516impl<T: Float> Neg for MultiDualNumber<T> {
517    type Output = Self;
518
519    fn neg(self) -> Self {
520        let gradient = self.gradient.iter().map(|&x| -x).collect();
521        Self {
522            value: -self.value,
523            gradient,
524        }
525    }
526}
527
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Tolerance for results that pass through transcendental functions.
    const TOL: f64 = 1e-10;

    /// Assert that `d` is approximately `value + slope·ε`.
    fn assert_dual_close(d: DualNumber<f64>, value: f64, slope: f64) {
        assert_relative_eq!(d.real, value, epsilon = TOL);
        assert_relative_eq!(d.dual, slope, epsilon = TOL);
    }

    #[test]
    fn test_dual_number_creation() {
        let c = DualNumber::constant(5.0);
        assert_eq!((c.real, c.dual), (5.0, 0.0));

        let v = DualNumber::variable(3.0);
        assert_eq!((v.real, v.dual), (3.0, 1.0));
    }

    #[test]
    fn test_dual_number_arithmetic() {
        let x = DualNumber::variable(2.0);
        let three = DualNumber::constant(3.0);

        // f(x) = 3x: f(2) = 6, f'(2) = 3
        let scaled = three * x;
        assert_eq!((scaled.real, scaled.dual), (6.0, 3.0));

        // f(x) = x²: f(2) = 4, f'(2) = 2·2 = 4
        let squared = x * x;
        assert_eq!((squared.real, squared.dual), (4.0, 4.0));
    }

    #[test]
    fn test_dual_number_division() {
        // f(x) = x / 2: f(4) = 2, f'(4) = 1/2
        let quotient = DualNumber::variable(4.0) / DualNumber::constant(2.0);
        assert_eq!((quotient.real, quotient.dual), (2.0, 0.5));
    }

    #[test]
    fn test_dual_number_exp() {
        // e^0 = 1 and d/dx e^x = e^x = 1 at x = 0
        assert_dual_close(DualNumber::variable(0.0).exp(), 1.0, 1.0);
    }

    #[test]
    fn test_dual_number_ln() {
        // ln 1 = 0 and d/dx ln x = 1/x = 1 at x = 1
        assert_dual_close(DualNumber::variable(1.0).ln(), 0.0, 1.0);
    }

    #[test]
    fn test_dual_number_sin() {
        // sin 0 = 0 and d/dx sin x = cos x = 1 at x = 0
        assert_dual_close(DualNumber::variable(0.0).sin(), 0.0, 1.0);
    }

    #[test]
    fn test_dual_number_cos() {
        // cos 0 = 1 and d/dx cos x = -sin x = 0 at x = 0
        assert_dual_close(DualNumber::variable(0.0).cos(), 1.0, 0.0);
    }

    #[test]
    fn test_dual_number_sqrt() {
        // √4 = 2 and d/dx √x = 1/(2√x) = 1/4 at x = 4
        assert_dual_close(DualNumber::variable(4.0).sqrt(), 2.0, 0.25);
    }

    #[test]
    fn test_multi_dual_number() {
        // f(x, y) = x + y: f(2, 3) = 5, ∇f = [1, 1]
        let x = MultiDualNumber::variable(2.0, 0, 2);
        let y = MultiDualNumber::variable(3.0, 1, 2);

        let sum = x + y;
        assert_eq!(sum.value, 5.0);
        assert_eq!(sum.gradient, vec![1.0, 1.0]);
    }

    #[test]
    fn test_multi_dual_number_product() {
        // f(x, y) = x·y: f(2, 3) = 6, ∇f = [y, x] = [3, 2]
        let x = MultiDualNumber::variable(2.0, 0, 2);
        let y = MultiDualNumber::variable(3.0, 1, 2);

        let product = x * y;
        assert_eq!(product.value, 6.0);
        assert_eq!(product.gradient, vec![3.0, 2.0]);
    }

    #[test]
    fn test_chain_rule() {
        // f(x) = sin(x²): f(1) = sin 1, f'(1) = cos(1)·2x = 2·cos 1
        let x = DualNumber::variable(1.0);
        assert_dual_close((x * x).sin(), 1.0_f64.sin(), 2.0 * 1.0_f64.cos());
    }
}