Skip to main content

num_dual/datatypes/
dual2_vec.rs

1use crate::{Derivative, DualNum, DualNumFloat, DualStruct};
2use nalgebra::allocator::Allocator;
3use nalgebra::*;
4use num_traits::{Float, FloatConst, FromPrimitive, Inv, Num, One, Signed, Zero};
5use std::fmt;
6use std::iter::{Product, Sum};
7use std::marker::PhantomData;
8use std::ops::{
9    Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
10};
11
12/// A vector second order dual number for the calculation of Hessians.
13#[derive(Clone, Debug)]
14pub struct Dual2Vec<T: DualNum<F>, F, D: Dim>
15where
16    DefaultAllocator: Allocator<U1, D> + Allocator<D, D>,
17{
18    /// Real part of the second order dual number
19    pub re: T,
20    /// Gradient part of the second order dual number
21    pub v1: Derivative<T, F, U1, D>,
22    /// Hessian part of the second order dual number
23    pub v2: Derivative<T, F, D, D>,
24    f: PhantomData<F>,
25}
26
27impl<T: DualNum<F> + Copy, F: Copy, const N: usize> Copy for Dual2Vec<T, F, Const<N>> {}
28
/// Marker implementation of `ndarray::ScalarOperand`; only compiled when the
/// `ndarray` feature is enabled.
#[cfg(feature = "ndarray")]
impl<T, F, D> ndarray::ScalarOperand for Dual2Vec<T, F, D>
where
    T: DualNum<F>,
    F: DualNumFloat,
    D: Dim,
    DefaultAllocator: Allocator<U1, D> + Allocator<D, D>,
{
}
34
35pub type Dual2SVec<T, F, const N: usize> = Dual2Vec<T, F, Const<N>>;
36pub type Dual2DVec<T, F> = Dual2Vec<T, F, Dyn>;
37pub type Dual2Vec32<D> = Dual2Vec<f32, f32, D>;
38pub type Dual2Vec64<D> = Dual2Vec<f64, f64, D>;
39pub type Dual2SVec32<const N: usize> = Dual2Vec<f32, f32, Const<N>>;
40pub type Dual2SVec64<const N: usize> = Dual2Vec<f64, f64, Const<N>>;
41pub type Dual2DVec32 = Dual2Vec<f32, f32, Dyn>;
42pub type Dual2DVec64 = Dual2Vec<f64, f64, Dyn>;
43
44impl<T: DualNum<F>, F, D: Dim> Dual2Vec<T, F, D>
45where
46    DefaultAllocator: Allocator<U1, D> + Allocator<D, D>,
47{
48    /// Create a new second order dual number from its fields.
49    #[inline]
50    pub fn new(re: T, v1: Derivative<T, F, U1, D>, v2: Derivative<T, F, D, D>) -> Self {
51        Self {
52            re,
53            v1,
54            v2,
55            f: PhantomData,
56        }
57    }
58}
59
60impl<T: DualNum<F>, F, const N: usize> Dual2SVec<T, F, N> {
61    /// Set the derivative part of variable `index` to 1.
62    ///
63    /// For most cases, the [`hessian`](crate::hessian) function provides a convenient
64    /// interface to calculate derivatives. This function exists for the more edge cases
65    /// where more control over the variables is required.
66    /// ```
67    /// # use num_dual::Dual2SVec64;
68    /// # use nalgebra::{U1, U2, matrix};
69    /// let x: Dual2SVec64<2> = Dual2SVec64::from_re(5.0).derivative(0);
70    /// let y: Dual2SVec64<2> = Dual2SVec64::from_re(3.0).derivative(1);
71    /// let z = x * x * y;
72    /// assert_eq!(z.re, 75.0);                                                 // x²y
73    /// assert_eq!(z.v1.unwrap_generic(U1, U2), matrix![30.0, 25.0]);           // [2xy, x²]
74    /// assert_eq!(z.v2.unwrap_generic(U2, U2), matrix![6.0, 10.0; 10.0, 0.0]); // [2y, 2x; 2x, 0]
75    /// ```
76    #[inline]
77    pub fn derivative(mut self, index: usize) -> Self {
78        self.v1 = Derivative::derivative_generic(U1, Const::<N>, index);
79        self
80    }
81}
82
83impl<T: DualNum<F>, F> Dual2DVec<T, F> {
84    /// Set the derivative part of variable `index` to 1.
85    ///
86    /// For most cases, the [`hessian`](crate::hessian) function provides a convenient interface
87    /// to calculate derivatives. This function exists for the more edge cases
88    /// where more control over the variables is required.
89    /// ```
90    /// # use num_dual::Dual2DVec64;
91    /// # use nalgebra::{Dyn, U1, dmatrix};
92    /// let x: Dual2DVec64 = Dual2DVec64::from_re(5.0).derivative(2, 0);
93    /// let y: Dual2DVec64 = Dual2DVec64::from_re(3.0).derivative(2, 1);
94    /// let z = &x * &x * y;
95    /// assert_eq!(z.re, 75.0);                                                          // x²y
96    /// assert_eq!(z.v1.unwrap_generic(U1, Dyn(2)), dmatrix![30.0, 25.0]);               // [2xy, x²]
97    /// assert_eq!(z.v2.unwrap_generic(Dyn(2), Dyn(2)), dmatrix![6.0, 10.0; 10.0, 0.0]); // [2y, 2x; 2x, 0]
98    /// ```
99    #[inline]
100    pub fn derivative(mut self, variables: usize, index: usize) -> Self {
101        self.v1 = Derivative::derivative_generic(U1, Dyn(variables), index);
102        self
103    }
104}
105
106impl<T: DualNum<F>, F, D: Dim> Dual2Vec<T, F, D>
107where
108    DefaultAllocator: Allocator<U1, D> + Allocator<D, D>,
109{
110    /// Create a new second order dual number from the real part.
111    #[inline]
112    pub fn from_re(re: T) -> Self {
113        Self::new(re, Derivative::none(), Derivative::none())
114    }
115}
116
117/* chain rule */
118impl<T: DualNum<F>, F: Float, D: Dim> Dual2Vec<T, F, D>
119where
120    DefaultAllocator: Allocator<U1, D> + Allocator<D, D>,
121{
122    #[inline]
123    fn chain_rule(&self, f0: T, f1: T, f2: T) -> Self {
124        Self::new(
125            f0,
126            &self.v1 * f1.clone(),
127            &self.v2 * f1 + self.v1.tr_mul(&self.v1) * f2,
128        )
129    }
130}
131
132/* product rule */
133impl<T: DualNum<F>, F: Float, D: Dim> Mul<&Dual2Vec<T, F, D>> for &Dual2Vec<T, F, D>
134where
135    DefaultAllocator: Allocator<U1, D> + Allocator<D, D>,
136{
137    type Output = Dual2Vec<T, F, D>;
138    #[inline]
139    fn mul(self, other: &Dual2Vec<T, F, D>) -> Dual2Vec<T, F, D> {
140        Dual2Vec::new(
141            self.re.clone() * other.re.clone(),
142            &other.v1 * self.re.clone() + &self.v1 * other.re.clone(),
143            &other.v2 * self.re.clone()
144                + self.v1.tr_mul(&other.v1)
145                + other.v1.tr_mul(&self.v1)
146                + &self.v2 * other.re.clone(),
147        )
148    }
149}
150
151/* quotient rule */
152impl<T: DualNum<F>, F: Float, D: Dim> Div<&Dual2Vec<T, F, D>> for &Dual2Vec<T, F, D>
153where
154    DefaultAllocator: Allocator<U1, D> + Allocator<D, D>,
155{
156    type Output = Dual2Vec<T, F, D>;
157    #[inline]
158    fn div(self, other: &Dual2Vec<T, F, D>) -> Dual2Vec<T, F, D> {
159        let inv = other.re.recip();
160        let inv2 = inv.clone() * inv.clone();
161        Dual2Vec::new(
162            self.re.clone() * inv.clone(),
163            (&self.v1 * other.re.clone() - &other.v1 * self.re.clone()) * inv2.clone(),
164            &self.v2 * inv.clone()
165                - (&other.v2 * self.re.clone()
166                    + self.v1.tr_mul(&other.v1)
167                    + other.v1.tr_mul(&self.v1))
168                    * inv2.clone()
169                + other.v1.tr_mul(&other.v1)
170                    * ((T::one() + T::one()) * self.re.clone() * inv2 * inv),
171        )
172    }
173}
174
175/* string conversions */
176impl<T: DualNum<F>, F: fmt::Display, D: Dim> fmt::Display for Dual2Vec<T, F, D>
177where
178    DefaultAllocator: Allocator<U1, D> + Allocator<D, D>,
179{
180    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
181        write!(f, "{}", self.re)?;
182        self.v1.fmt(f, "ε1")?;
183        self.v2.fmt(f, "ε1²")
184    }
185}
186
187impl_second_derivatives!(Dual2Vec, [v1, v2], [D], [U1, D], [D, D]);
188impl_dual!(Dual2Vec, [v1, v2], [D], [U1, D], [D, D]);
189impl_nalgebra!(Dual2Vec, [v1, v2], [D], [U1, D], [D, D]);