// num_dual/datatypes/hyperdual_vec.rs

1use crate::{Derivative, DualNum, DualNumFloat, DualStruct};
2use nalgebra::allocator::Allocator;
3use nalgebra::{Const, DefaultAllocator, Dim, Dyn, U1};
4use num_traits::{Float, FloatConst, FromPrimitive, Inv, Num, One, Signed, Zero};
5use std::fmt;
6use std::iter::{Product, Sum};
7use std::marker::PhantomData;
8use std::ops::{
9    Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
10};
11
12/// A vector hyper-dual number for the calculation of partial Hessians.
13#[derive(PartialEq, Eq, Clone, Debug)]
14pub struct HyperDualVec<T: DualNum<F>, F, M: Dim, N: Dim>
15where
16    DefaultAllocator: Allocator<M> + Allocator<M, N> + Allocator<U1, N>,
17{
18    /// Real part of the hyper-dual number
19    pub re: T,
20    /// Gradient part of the hyper-dual number
21    pub eps1: Derivative<T, F, M, U1>,
22    /// Gradient part of the hyper-dual number
23    pub eps2: Derivative<T, F, U1, N>,
24    /// Partial Hessian part of the hyper-dual number
25    pub eps1eps2: Derivative<T, F, M, N>,
26    f: PhantomData<F>,
27}
28
29impl<T: DualNum<F> + Copy, F: Copy, const M: usize, const N: usize> Copy
30    for HyperDualVec<T, F, Const<M>, Const<N>>
31{
32}
33
// Lets hyper-dual numbers act as scalar operands in ndarray array arithmetic.
// This impl is only compiled when the `ndarray` feature is enabled.
#[cfg(feature = "ndarray")]
impl<T, F, M, N> ndarray::ScalarOperand for HyperDualVec<T, F, M, N>
where
    T: DualNum<F>,
    F: DualNumFloat,
    M: Dim,
    N: Dim,
    DefaultAllocator: Allocator<M> + Allocator<M, N> + Allocator<U1, N>,
{
}
41
42pub type HyperDualSVec<T, F, const M: usize, const N: usize> =
43    HyperDualVec<T, F, Const<M>, Const<N>>;
44pub type HyperDualDVec<T, F> = HyperDualVec<T, F, Dyn, Dyn>;
45pub type HyperDualVec32<M, N> = HyperDualVec<f32, f32, M, N>;
46pub type HyperDualVec64<M, N> = HyperDualVec<f64, f64, M, N>;
47pub type HyperDualSVec32<const M: usize, const N: usize> =
48    HyperDualVec<f32, f32, Const<M>, Const<N>>;
49pub type HyperDualSVec64<const M: usize, const N: usize> =
50    HyperDualVec<f64, f64, Const<M>, Const<N>>;
51pub type HyperDualDVec32 = HyperDualVec<f32, f32, Dyn, Dyn>;
52pub type HyperDualDVec64 = HyperDualVec<f64, f64, Dyn, Dyn>;
53
54impl<T: DualNum<F>, F, M: Dim, N: Dim> HyperDualVec<T, F, M, N>
55where
56    DefaultAllocator: Allocator<M> + Allocator<M, N> + Allocator<U1, N>,
57{
58    /// Create a new hyper-dual number from its fields.
59    #[inline]
60    pub fn new(
61        re: T,
62        eps1: Derivative<T, F, M, U1>,
63        eps2: Derivative<T, F, U1, N>,
64        eps1eps2: Derivative<T, F, M, N>,
65    ) -> Self {
66        Self {
67            re,
68            eps1,
69            eps2,
70            eps1eps2,
71            f: PhantomData,
72        }
73    }
74}
75
76impl<T: DualNum<F>, F, M: Dim, N: Dim> HyperDualVec<T, F, M, N>
77where
78    DefaultAllocator: Allocator<M> + Allocator<M, N> + Allocator<U1, N>,
79{
80    /// Create a new hyper-dual number from the real part.
81    #[inline]
82    pub fn from_re(re: T) -> Self {
83        Self::new(
84            re,
85            Derivative::none(),
86            Derivative::none(),
87            Derivative::none(),
88        )
89    }
90}
91
92impl<T: DualNum<F>, F, const M: usize, const N: usize> HyperDualSVec<T, F, M, N> {
93    /// Set the 1st dimension derivative of variable `index` to 1.
94    ///
95    /// For most cases, the [`partial_hessian`](crate::partial_hessian) function provides a
96    /// convenient interface to calculate derivatives. This function exists for the more edge
97    /// cases where more control over the variables is required.
98    #[inline]
99    pub fn derivative1(mut self, index: usize) -> Self {
100        self.eps1 = Derivative::derivative_generic(Const::<M>, U1, index);
101        self
102    }
103
104    /// Set the 2nd dimension derivative of variable `index` to 1.
105    ///
106    /// For most cases, the [`partial_hessian`](crate::partial_hessian) function provides a
107    /// convenient interface to calculate derivatives. This function exists for the more edge
108    /// cases where more control over the variables is required.
109    #[inline]
110    pub fn derivative2(mut self, index: usize) -> Self {
111        self.eps2 = Derivative::derivative_generic(U1, Const::<N>, index);
112        self
113    }
114}
115
116impl<T: DualNum<F>, F> HyperDualDVec<T, F> {
117    /// Set the 1st dimension derivative part of variable `index` to 1.
118    ///
119    /// For most cases, the [`partial_hessian`](crate::partial_hessian) function provides a
120    /// convenient interface to calculate derivatives. This function exists for the more edge
121    /// cases where more control over the variables is required.
122    #[inline]
123    pub fn derivative1(mut self, variables: usize, index: usize) -> Self {
124        self.eps1 = Derivative::derivative_generic(Dyn(variables), U1, index);
125        self
126    }
127
128    /// Set the 2nd dimension derivative part of variable `index` to 1.
129    ///
130    /// For most cases, the [`partial_hessian`](crate::partial_hessian) function provides a
131    /// convenient interface to calculate derivatives. This function exists for the more edge
132    /// cases where more control over the variables is required.
133    #[inline]
134    pub fn derivative2(mut self, variables: usize, index: usize) -> Self {
135        self.eps2 = Derivative::derivative_generic(U1, Dyn(variables), index);
136        self
137    }
138}
139
140/* chain rule */
141impl<T: DualNum<F>, F: Float, M: Dim, N: Dim> HyperDualVec<T, F, M, N>
142where
143    DefaultAllocator: Allocator<M> + Allocator<M, N> + Allocator<U1, N>,
144{
145    #[inline]
146    fn chain_rule(&self, f0: T, f1: T, f2: T) -> Self {
147        Self::new(
148            f0,
149            &self.eps1 * f1.clone(),
150            &self.eps2 * f1.clone(),
151            &self.eps1eps2 * f1 + &self.eps1 * &self.eps2 * f2,
152        )
153    }
154}
155
156/* product rule */
157impl<T: DualNum<F>, F: Float, M: Dim, N: Dim> Mul<&HyperDualVec<T, F, M, N>>
158    for &HyperDualVec<T, F, M, N>
159where
160    DefaultAllocator: Allocator<M> + Allocator<M, N> + Allocator<U1, N>,
161{
162    type Output = HyperDualVec<T, F, M, N>;
163    #[inline]
164    fn mul(self, other: &HyperDualVec<T, F, M, N>) -> HyperDualVec<T, F, M, N> {
165        HyperDualVec::new(
166            self.re.clone() * other.re.clone(),
167            &other.eps1 * self.re.clone() + &self.eps1 * other.re.clone(),
168            &other.eps2 * self.re.clone() + &self.eps2 * other.re.clone(),
169            &other.eps1eps2 * self.re.clone()
170                + &self.eps1 * &other.eps2
171                + &other.eps1 * &self.eps2
172                + &self.eps1eps2 * other.re.clone(),
173        )
174    }
175}
176
177/* quotient rule */
178impl<T: DualNum<F>, F: Float, M: Dim, N: Dim> Div<&HyperDualVec<T, F, M, N>>
179    for &HyperDualVec<T, F, M, N>
180where
181    DefaultAllocator: Allocator<M> + Allocator<M, N> + Allocator<U1, N>,
182{
183    type Output = HyperDualVec<T, F, M, N>;
184    #[inline]
185    fn div(self, other: &HyperDualVec<T, F, M, N>) -> HyperDualVec<T, F, M, N> {
186        let inv = other.re.recip();
187        let inv2 = inv.clone() * &inv;
188        HyperDualVec::new(
189            self.re.clone() * &inv,
190            (&self.eps1 * other.re.clone() - &other.eps1 * self.re.clone()) * inv2.clone(),
191            (&self.eps2 * other.re.clone() - &other.eps2 * self.re.clone()) * inv2.clone(),
192            &self.eps1eps2 * inv.clone()
193                - (&other.eps1eps2 * self.re.clone()
194                    + &self.eps1 * &other.eps2
195                    + &other.eps1 * &self.eps2)
196                    * inv2.clone()
197                + &other.eps1
198                    * &other.eps2
199                    * ((T::one() + T::one()) * self.re.clone() * inv2 * inv),
200        )
201    }
202}
203
204/* string conversions */
205impl<T: DualNum<F>, F: fmt::Display, M: Dim, N: Dim> fmt::Display for HyperDualVec<T, F, M, N>
206where
207    DefaultAllocator: Allocator<M> + Allocator<M, N> + Allocator<U1, N>,
208{
209    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
210        write!(f, "{}", self.re)?;
211        self.eps1.fmt(f, "ε1")?;
212        self.eps2.fmt(f, "ε2")?;
213        self.eps1eps2.fmt(f, "ε1ε2")
214    }
215}
216
217impl_second_derivatives!(
218    HyperDualVec,
219    [eps1, eps2, eps1eps2],
220    [M, N],
221    [M],
222    [M, N],
223    [U1, N]
224);
225impl_dual!(
226    HyperDualVec,
227    [eps1, eps2, eps1eps2],
228    [M, N],
229    [M],
230    [M, N],
231    [U1, N]
232);