Skip to main content

num_dual/datatypes/
dual2.rs

1use crate::{DualNum, DualNumFloat, DualStruct};
2use num_traits::{Float, FloatConst, FromPrimitive, Inv, Num, One, Signed, Zero};
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5use std::fmt;
6use std::iter::{Product, Sum};
7use std::marker::PhantomData;
8use std::ops::{
9    Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
10};
11
12/// A scalar second order dual number for the calculation of second derivatives.
13#[derive(Copy, Clone, Debug)]
14#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
15pub struct Dual2<T: DualNum<F>, F> {
16    /// Real part of the second order dual number
17    pub re: T,
18    /// First derivative part of the second order dual number
19    pub v1: T,
20    /// Second derivative part of the second order dual number
21    pub v2: T,
22    #[cfg_attr(feature = "serde", serde(skip))]
23    f: PhantomData<F>,
24}
25
// Registers `Dual2` as an ndarray scalar operand so it can appear as the scalar
// side of array-scalar arithmetic (only compiled with the `ndarray` feature).
#[cfg(feature = "ndarray")]
impl<T, F> ndarray::ScalarOperand for Dual2<T, F>
where
    T: DualNum<F>,
    F: DualNumFloat,
{
}
28
/// Second-order dual number with `f32` components and `f32` marker type.
pub type Dual2_32 = Dual2<f32, f32>;
/// Second-order dual number with `f64` components and `f64` marker type.
pub type Dual2_64 = Dual2<f64, f64>;
31
32impl<T: DualNum<F>, F> Dual2<T, F> {
33    /// Create a new second order dual number from its fields.
34    #[inline]
35    pub fn new(re: T, v1: T, v2: T) -> Self {
36        Self {
37            re,
38            v1,
39            v2,
40            f: PhantomData,
41        }
42    }
43}
44
45impl<T: DualNum<F>, F> Dual2<T, F> {
46    /// Set the derivative part to 1.
47    /// ```
48    /// # use num_dual::{Dual2, DualNum};
49    /// let x = Dual2::from_re(5.0).derivative().powi(2);
50    /// assert_eq!(x.re, 25.0);             // x²
51    /// assert_eq!(x.v1, 10.0);    // 2x
52    /// assert_eq!(x.v2, 2.0);     // 2
53    /// ```
54    ///
55    /// Can also be used for higher order derivatives.
56    /// ```
57    /// # use num_dual::{Dual64, Dual2, DualNum};
58    /// let x = Dual2::from_re(Dual64::from_re(5.0).derivative())
59    ///     .derivative()
60    ///     .powi(2);
61    /// assert_eq!(x.re.re, 25.0);      // x²
62    /// assert_eq!(x.re.eps, 10.0);     // 2x
63    /// assert_eq!(x.v1.re, 10.0);      // 2x
64    /// assert_eq!(x.v1.eps, 2.0);      // 2
65    /// assert_eq!(x.v2.re, 2.0);       // 2
66    /// ```
67    #[inline]
68    pub fn derivative(mut self) -> Self {
69        self.v1 = T::one();
70        self
71    }
72}
73
74impl<T: DualNum<F>, F> Dual2<T, F> {
75    /// Create a new second order dual number from the real part.
76    #[inline]
77    pub fn from_re(re: T) -> Self {
78        Self::new(re, T::zero(), T::zero())
79    }
80}
81
82/* chain rule */
83impl<T: DualNum<F>, F: Float> Dual2<T, F> {
84    #[inline]
85    fn chain_rule(&self, f0: T, f1: T, f2: T) -> Self {
86        Self::new(
87            f0,
88            self.v1.clone() * f1.clone(),
89            self.v2.clone() * f1 + self.v1.clone() * self.v1.clone() * f2,
90        )
91    }
92}
93
94/* product rule */
95impl<T: DualNum<F>, F: Float> Mul<&Dual2<T, F>> for &Dual2<T, F> {
96    type Output = Dual2<T, F>;
97    #[inline]
98    fn mul(self, other: &Dual2<T, F>) -> Dual2<T, F> {
99        Dual2::new(
100            self.re.clone() * other.re.clone(),
101            other.v1.clone() * self.re.clone() + self.v1.clone() * other.re.clone(),
102            other.v2.clone() * self.re.clone()
103                + self.v1.clone() * other.v1.clone()
104                + other.v1.clone() * self.v1.clone()
105                + self.v2.clone() * other.re.clone(),
106        )
107    }
108}
109
110/* quotient rule */
111impl<T: DualNum<F>, F: Float> Div<&Dual2<T, F>> for &Dual2<T, F> {
112    type Output = Dual2<T, F>;
113    #[inline]
114    fn div(self, other: &Dual2<T, F>) -> Dual2<T, F> {
115        let inv = other.re.recip();
116        let inv2 = inv.clone() * inv.clone();
117        Dual2::new(
118            self.re.clone() * inv.clone(),
119            (self.v1.clone() * other.re.clone() - other.v1.clone() * self.re.clone())
120                * inv2.clone(),
121            self.v2.clone() * inv.clone()
122                - (other.v2.clone() * self.re.clone()
123                    + self.v1.clone() * other.v1.clone()
124                    + other.v1.clone() * self.v1.clone())
125                    * inv2.clone()
126                + other.v1.clone()
127                    * other.v1.clone()
128                    * ((T::one() + T::one()) * self.re.clone() * inv2 * inv),
129        )
130    }
131}
132
133/* string conversions */
134impl<T: DualNum<F>, F: fmt::Display> fmt::Display for Dual2<T, F> {
135    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
136        write!(f, "{} + {}ε1 + {}ε1²", self.re, self.v1, self.v2)
137    }
138}
139
140impl_second_derivatives!(Dual2, [v1, v2]);
141impl_dual!(Dual2, [v1, v2]);
142impl_nalgebra!(Dual2, [v1, v2]);