// num_dual/datatypes/hyperdual_vec.rs

1use crate::{Derivative, DualNum, DualNumFloat, DualStruct};
2use nalgebra::allocator::Allocator;
3use nalgebra::{Const, DefaultAllocator, Dim, Dyn, U1};
4use num_traits::{Float, FloatConst, FromPrimitive, Inv, Num, One, Signed, Zero};
5use std::fmt;
6use std::iter::{Product, Sum};
7use std::marker::PhantomData;
8use std::ops::{
9    Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
10};
11
12/// A vector hyper-dual number for the calculation of partial Hessians.
13#[derive(PartialEq, Eq, Clone, Debug)]
14pub struct HyperDualVec<T: DualNum<F>, F, M: Dim, N: Dim>
15where
16    DefaultAllocator: Allocator<M> + Allocator<M, N> + Allocator<U1, N>,
17{
18    /// Real part of the hyper-dual number
19    pub re: T,
20    /// Gradient part of the hyper-dual number
21    pub eps1: Derivative<T, F, M, U1>,
22    /// Gradient part of the hyper-dual number
23    pub eps2: Derivative<T, F, U1, N>,
24    /// Partial Hessian part of the hyper-dual number
25    pub eps1eps2: Derivative<T, F, M, N>,
26    f: PhantomData<F>,
27}
28
29impl<T: DualNum<F> + Copy, F: Copy, const M: usize, const N: usize> Copy
30    for HyperDualVec<T, F, Const<M>, Const<N>>
31{
32}
33
/// Lets hyper-dual numbers act as scalar operands in `ndarray` arithmetic.
/// Only compiled when the `ndarray` feature is enabled.
#[cfg(feature = "ndarray")]
impl<T, F, M, N> ndarray::ScalarOperand for HyperDualVec<T, F, M, N>
where
    T: DualNum<F>,
    F: DualNumFloat,
    M: Dim,
    N: Dim,
    DefaultAllocator: Allocator<M> + Allocator<M, N> + Allocator<U1, N>,
{
}
41
42pub type HyperDualVec32<M, N> = HyperDualVec<f32, f32, M, N>;
43pub type HyperDualVec64<M, N> = HyperDualVec<f64, f64, M, N>;
44pub type HyperDualSVec32<const M: usize, const N: usize> =
45    HyperDualVec<f32, f32, Const<M>, Const<N>>;
46pub type HyperDualSVec64<const M: usize, const N: usize> =
47    HyperDualVec<f64, f64, Const<M>, Const<N>>;
48pub type HyperDualDVec32 = HyperDualVec<f32, f32, Dyn, Dyn>;
49pub type HyperDualDVec64 = HyperDualVec<f64, f64, Dyn, Dyn>;
50
51impl<T: DualNum<F>, F, M: Dim, N: Dim> HyperDualVec<T, F, M, N>
52where
53    DefaultAllocator: Allocator<M> + Allocator<M, N> + Allocator<U1, N>,
54{
55    /// Create a new hyper-dual number from its fields.
56    #[inline]
57    pub fn new(
58        re: T,
59        eps1: Derivative<T, F, M, U1>,
60        eps2: Derivative<T, F, U1, N>,
61        eps1eps2: Derivative<T, F, M, N>,
62    ) -> Self {
63        Self {
64            re,
65            eps1,
66            eps2,
67            eps1eps2,
68            f: PhantomData,
69        }
70    }
71}
72
73impl<T: DualNum<F>, F, M: Dim, N: Dim> HyperDualVec<T, F, M, N>
74where
75    DefaultAllocator: Allocator<M> + Allocator<M, N> + Allocator<U1, N>,
76{
77    /// Create a new hyper-dual number from the real part.
78    #[inline]
79    pub fn from_re(re: T) -> Self {
80        Self::new(
81            re,
82            Derivative::none(),
83            Derivative::none(),
84            Derivative::none(),
85        )
86    }
87}
88
89/* chain rule */
90impl<T: DualNum<F>, F: Float, M: Dim, N: Dim> HyperDualVec<T, F, M, N>
91where
92    DefaultAllocator: Allocator<M> + Allocator<M, N> + Allocator<U1, N>,
93{
94    #[inline]
95    fn chain_rule(&self, f0: T, f1: T, f2: T) -> Self {
96        Self::new(
97            f0,
98            &self.eps1 * f1.clone(),
99            &self.eps2 * f1.clone(),
100            &self.eps1eps2 * f1 + &self.eps1 * &self.eps2 * f2,
101        )
102    }
103}
104
105/* product rule */
106impl<T: DualNum<F>, F: Float, M: Dim, N: Dim> Mul<&HyperDualVec<T, F, M, N>>
107    for &HyperDualVec<T, F, M, N>
108where
109    DefaultAllocator: Allocator<M> + Allocator<M, N> + Allocator<U1, N>,
110{
111    type Output = HyperDualVec<T, F, M, N>;
112    #[inline]
113    fn mul(self, other: &HyperDualVec<T, F, M, N>) -> HyperDualVec<T, F, M, N> {
114        HyperDualVec::new(
115            self.re.clone() * other.re.clone(),
116            &other.eps1 * self.re.clone() + &self.eps1 * other.re.clone(),
117            &other.eps2 * self.re.clone() + &self.eps2 * other.re.clone(),
118            &other.eps1eps2 * self.re.clone()
119                + &self.eps1 * &other.eps2
120                + &other.eps1 * &self.eps2
121                + &self.eps1eps2 * other.re.clone(),
122        )
123    }
124}
125
126/* quotient rule */
127impl<T: DualNum<F>, F: Float, M: Dim, N: Dim> Div<&HyperDualVec<T, F, M, N>>
128    for &HyperDualVec<T, F, M, N>
129where
130    DefaultAllocator: Allocator<M> + Allocator<M, N> + Allocator<U1, N>,
131{
132    type Output = HyperDualVec<T, F, M, N>;
133    #[inline]
134    fn div(self, other: &HyperDualVec<T, F, M, N>) -> HyperDualVec<T, F, M, N> {
135        let inv = other.re.recip();
136        let inv2 = inv.clone() * &inv;
137        HyperDualVec::new(
138            self.re.clone() * &inv,
139            (&self.eps1 * other.re.clone() - &other.eps1 * self.re.clone()) * inv2.clone(),
140            (&self.eps2 * other.re.clone() - &other.eps2 * self.re.clone()) * inv2.clone(),
141            &self.eps1eps2 * inv.clone()
142                - (&other.eps1eps2 * self.re.clone()
143                    + &self.eps1 * &other.eps2
144                    + &other.eps1 * &self.eps2)
145                    * inv2.clone()
146                + &other.eps1
147                    * &other.eps2
148                    * ((T::one() + T::one()) * self.re.clone() * inv2 * inv),
149        )
150    }
151}
152
153/* string conversions */
154impl<T: DualNum<F>, F: fmt::Display, M: Dim, N: Dim> fmt::Display for HyperDualVec<T, F, M, N>
155where
156    DefaultAllocator: Allocator<M> + Allocator<M, N> + Allocator<U1, N>,
157{
158    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
159        write!(f, "{}", self.re)?;
160        self.eps1.fmt(f, "ε1")?;
161        self.eps2.fmt(f, "ε2")?;
162        self.eps1eps2.fmt(f, "ε1ε2")
163    }
164}
165
166impl_second_derivatives!(
167    HyperDualVec,
168    [eps1, eps2, eps1eps2],
169    [M, N],
170    [M],
171    [M, N],
172    [U1, N]
173);
174impl_dual!(
175    HyperDualVec,
176    [eps1, eps2, eps1eps2],
177    [M, N],
178    [M],
179    [M, N],
180    [U1, N]
181);