num_dual/datatypes/
hyperdual.rs

1use crate::{DualNum, DualNumFloat, DualStruct};
2use num_traits::{Float, FloatConst, FromPrimitive, Inv, Num, One, Signed, Zero};
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5use std::fmt;
6use std::iter::{Product, Sum};
7use std::marker::PhantomData;
8use std::ops::{
9    Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
10};
11
12/// A scalar hyper-dual number for the calculation of second partial derivatives.
13#[derive(PartialEq, Eq, Copy, Clone, Debug)]
14#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
15pub struct HyperDual<T: DualNum<F>, F> {
16    /// Real part of the hyper-dual number
17    pub re: T,
18    /// Partial derivative part of the hyper-dual number
19    pub eps1: T,
20    /// Partial derivative part of the hyper-dual number
21    pub eps2: T,
22    /// Second partial derivative part of the hyper-dual number
23    pub eps1eps2: T,
24    #[cfg_attr(feature = "serde", serde(skip))]
25    f: PhantomData<F>,
26}
27
// Marks hyper-dual numbers as usable scalar operands in `ndarray` arithmetic
// (compiled only when the `ndarray` feature is enabled).
#[cfg(feature = "ndarray")]
impl<T, F> ndarray::ScalarOperand for HyperDual<T, F>
where
    T: DualNum<F>,
    F: DualNumFloat,
{
}
30
31pub type HyperDual32 = HyperDual<f32, f32>;
32pub type HyperDual64 = HyperDual<f64, f64>;
33
34impl<T: DualNum<F>, F> HyperDual<T, F> {
35    /// Create a new hyper-dual number from its fields.
36    #[inline]
37    pub fn new(re: T, eps1: T, eps2: T, eps1eps2: T) -> Self {
38        Self {
39            re,
40            eps1,
41            eps2,
42            eps1eps2,
43            f: PhantomData,
44        }
45    }
46}
47
48impl<T: DualNum<F>, F> HyperDual<T, F> {
49    /// Set the partial derivative part w.r.t. the 1st variable to 1.
50    #[inline]
51    pub fn derivative1(mut self) -> Self {
52        self.eps1 = T::one();
53        self
54    }
55
56    /// Set the partial derivative part w.r.t. the 2nd variable to 1.
57    #[inline]
58    pub fn derivative2(mut self) -> Self {
59        self.eps2 = T::one();
60        self
61    }
62}
63
64impl<T: DualNum<F>, F> HyperDual<T, F> {
65    /// Create a new hyper-dual number from the real part.
66    #[inline]
67    pub fn from_re(re: T) -> Self {
68        Self::new(re, T::zero(), T::zero(), T::zero())
69    }
70}
71
72/* chain rule */
73impl<T: DualNum<F>, F: Float> HyperDual<T, F> {
74    #[inline]
75    fn chain_rule(&self, f0: T, f1: T, f2: T) -> Self {
76        Self::new(
77            f0,
78            self.eps1.clone() * f1.clone(),
79            self.eps2.clone() * f1.clone(),
80            self.eps1eps2.clone() * f1 + self.eps1.clone() * self.eps2.clone() * f2,
81        )
82    }
83}
84
85/* product rule */
86impl<T: DualNum<F>, F: Float> Mul<&HyperDual<T, F>> for &HyperDual<T, F> {
87    type Output = HyperDual<T, F>;
88    #[inline]
89    fn mul(self, other: &HyperDual<T, F>) -> HyperDual<T, F> {
90        HyperDual::new(
91            self.re.clone() * other.re.clone(),
92            other.eps1.clone() * self.re.clone() + self.eps1.clone() * other.re.clone(),
93            other.eps2.clone() * self.re.clone() + self.eps2.clone() * other.re.clone(),
94            other.eps1eps2.clone() * self.re.clone()
95                + self.eps1.clone() * other.eps2.clone()
96                + other.eps1.clone() * self.eps2.clone()
97                + self.eps1eps2.clone() * other.re.clone(),
98        )
99    }
100}
101
102/* quotient rule */
103impl<T: DualNum<F>, F: Float> Div<&HyperDual<T, F>> for &HyperDual<T, F> {
104    type Output = HyperDual<T, F>;
105    #[inline]
106    fn div(self, other: &HyperDual<T, F>) -> HyperDual<T, F> {
107        let inv = other.re.recip();
108        let inv2 = inv.clone() * &inv;
109        HyperDual::new(
110            self.re.clone() * &inv,
111            (self.eps1.clone() * other.re.clone() - other.eps1.clone() * self.re.clone())
112                * inv2.clone(),
113            (self.eps2.clone() * other.re.clone() - other.eps2.clone() * self.re.clone())
114                * inv2.clone(),
115            self.eps1eps2.clone() * inv.clone()
116                - (other.eps1eps2.clone() * self.re.clone()
117                    + self.eps1.clone() * other.eps2.clone()
118                    + other.eps1.clone() * self.eps2.clone())
119                    * inv2.clone()
120                + other.eps1.clone()
121                    * other.eps2.clone()
122                    * ((T::one() + T::one()) * self.re.clone() * inv2 * inv),
123        )
124    }
125}
126
127/* string conversions */
128impl<T: DualNum<F>, F: fmt::Display> fmt::Display for HyperDual<T, F> {
129    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
130        fmt::Display::fmt(&self.re, f)?;
131        write!(f, " + ")?;
132        fmt::Display::fmt(&self.eps1, f)?;
133        write!(f, "ε1 + ")?;
134        fmt::Display::fmt(&self.eps2, f)?;
135        write!(f, "ε2 + ")?;
136        fmt::Display::fmt(&self.eps1eps2, f)?;
137        write!(f, "ε1ε2")
138    }
139}
140
141impl_second_derivatives!(HyperDual, [eps1, eps2, eps1eps2]);
142impl_dual!(HyperDual, [eps1, eps2, eps1eps2]);