1use crate::{Derivative, DualNum, DualNumFloat, DualStruct};
2use nalgebra::allocator::Allocator;
3use nalgebra::{Const, DefaultAllocator, Dim, Dyn, U1};
4use num_traits::{Float, FloatConst, FromPrimitive, Inv, Num, One, Signed, Zero};
5use std::fmt;
6use std::iter::{Product, Sum};
7use std::marker::PhantomData;
8use std::ops::{
9 Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
10};
11
12#[derive(PartialEq, Eq, Clone, Debug)]
14pub struct HyperDualVec<T: DualNum<F>, F, M: Dim, N: Dim>
15where
16 DefaultAllocator: Allocator<M> + Allocator<M, N> + Allocator<U1, N>,
17{
18 pub re: T,
20 pub eps1: Derivative<T, F, M, U1>,
22 pub eps2: Derivative<T, F, U1, N>,
24 pub eps1eps2: Derivative<T, F, M, N>,
26 f: PhantomData<F>,
27}
28
29impl<T: DualNum<F> + Copy, F: Copy, const M: usize, const N: usize> Copy
30 for HyperDualVec<T, F, Const<M>, Const<N>>
31{
32}
33
/// Marks `HyperDualVec` as an `ndarray` scalar operand.
///
/// Only compiled when the `ndarray` feature is enabled.
#[cfg(feature = "ndarray")]
impl<T, F, M, N> ndarray::ScalarOperand for HyperDualVec<T, F, M, N>
where
    T: DualNum<F>,
    F: DualNumFloat,
    M: Dim,
    N: Dim,
    DefaultAllocator: Allocator<M> + Allocator<M, N> + Allocator<U1, N>,
{
}
41
42pub type HyperDualVec32<M, N> = HyperDualVec<f32, f32, M, N>;
43pub type HyperDualVec64<M, N> = HyperDualVec<f64, f64, M, N>;
44pub type HyperDualSVec32<const M: usize, const N: usize> =
45 HyperDualVec<f32, f32, Const<M>, Const<N>>;
46pub type HyperDualSVec64<const M: usize, const N: usize> =
47 HyperDualVec<f64, f64, Const<M>, Const<N>>;
48pub type HyperDualDVec32 = HyperDualVec<f32, f32, Dyn, Dyn>;
49pub type HyperDualDVec64 = HyperDualVec<f64, f64, Dyn, Dyn>;
50
51impl<T: DualNum<F>, F, M: Dim, N: Dim> HyperDualVec<T, F, M, N>
52where
53 DefaultAllocator: Allocator<M> + Allocator<M, N> + Allocator<U1, N>,
54{
55 #[inline]
57 pub fn new(
58 re: T,
59 eps1: Derivative<T, F, M, U1>,
60 eps2: Derivative<T, F, U1, N>,
61 eps1eps2: Derivative<T, F, M, N>,
62 ) -> Self {
63 Self {
64 re,
65 eps1,
66 eps2,
67 eps1eps2,
68 f: PhantomData,
69 }
70 }
71}
72
73impl<T: DualNum<F>, F, M: Dim, N: Dim> HyperDualVec<T, F, M, N>
74where
75 DefaultAllocator: Allocator<M> + Allocator<M, N> + Allocator<U1, N>,
76{
77 #[inline]
79 pub fn from_re(re: T) -> Self {
80 Self::new(
81 re,
82 Derivative::none(),
83 Derivative::none(),
84 Derivative::none(),
85 )
86 }
87}
88
89impl<T: DualNum<F>, F: Float, M: Dim, N: Dim> HyperDualVec<T, F, M, N>
91where
92 DefaultAllocator: Allocator<M> + Allocator<M, N> + Allocator<U1, N>,
93{
94 #[inline]
95 fn chain_rule(&self, f0: T, f1: T, f2: T) -> Self {
96 Self::new(
97 f0,
98 &self.eps1 * f1.clone(),
99 &self.eps2 * f1.clone(),
100 &self.eps1eps2 * f1 + &self.eps1 * &self.eps2 * f2,
101 )
102 }
103}
104
105impl<T: DualNum<F>, F: Float, M: Dim, N: Dim> Mul<&HyperDualVec<T, F, M, N>>
107 for &HyperDualVec<T, F, M, N>
108where
109 DefaultAllocator: Allocator<M> + Allocator<M, N> + Allocator<U1, N>,
110{
111 type Output = HyperDualVec<T, F, M, N>;
112 #[inline]
113 fn mul(self, other: &HyperDualVec<T, F, M, N>) -> HyperDualVec<T, F, M, N> {
114 HyperDualVec::new(
115 self.re.clone() * other.re.clone(),
116 &other.eps1 * self.re.clone() + &self.eps1 * other.re.clone(),
117 &other.eps2 * self.re.clone() + &self.eps2 * other.re.clone(),
118 &other.eps1eps2 * self.re.clone()
119 + &self.eps1 * &other.eps2
120 + &other.eps1 * &self.eps2
121 + &self.eps1eps2 * other.re.clone(),
122 )
123 }
124}
125
126impl<T: DualNum<F>, F: Float, M: Dim, N: Dim> Div<&HyperDualVec<T, F, M, N>>
128 for &HyperDualVec<T, F, M, N>
129where
130 DefaultAllocator: Allocator<M> + Allocator<M, N> + Allocator<U1, N>,
131{
132 type Output = HyperDualVec<T, F, M, N>;
133 #[inline]
134 fn div(self, other: &HyperDualVec<T, F, M, N>) -> HyperDualVec<T, F, M, N> {
135 let inv = other.re.recip();
136 let inv2 = inv.clone() * &inv;
137 HyperDualVec::new(
138 self.re.clone() * &inv,
139 (&self.eps1 * other.re.clone() - &other.eps1 * self.re.clone()) * inv2.clone(),
140 (&self.eps2 * other.re.clone() - &other.eps2 * self.re.clone()) * inv2.clone(),
141 &self.eps1eps2 * inv.clone()
142 - (&other.eps1eps2 * self.re.clone()
143 + &self.eps1 * &other.eps2
144 + &other.eps1 * &self.eps2)
145 * inv2.clone()
146 + &other.eps1
147 * &other.eps2
148 * ((T::one() + T::one()) * self.re.clone() * inv2 * inv),
149 )
150 }
151}
152
153impl<T: DualNum<F>, F: fmt::Display, M: Dim, N: Dim> fmt::Display for HyperDualVec<T, F, M, N>
155where
156 DefaultAllocator: Allocator<M> + Allocator<M, N> + Allocator<U1, N>,
157{
158 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
159 write!(f, "{}", self.re)?;
160 self.eps1.fmt(f, "ε1")?;
161 self.eps2.fmt(f, "ε2")?;
162 self.eps1eps2.fmt(f, "ε1ε2")
163 }
164}
165
166impl_second_derivatives!(
167 HyperDualVec,
168 [eps1, eps2, eps1eps2],
169 [M, N],
170 [M],
171 [M, N],
172 [U1, N]
173);
174impl_dual!(
175 HyperDualVec,
176 [eps1, eps2, eps1eps2],
177 [M, N],
178 [M],
179 [M, N],
180 [U1, N]
181);