1use crate::{DualNum, DualNumFloat, DualStruct};
2use num_traits::{Float, FloatConst, FromPrimitive, Inv, Num, One, Signed, Zero};
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5use std::fmt;
6use std::iter::{Product, Sum};
7use std::marker::PhantomData;
8use std::ops::{
9 Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
10};
11
12#[derive(PartialEq, Eq, Copy, Clone, Debug)]
14#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
15pub struct HyperDual<T: DualNum<F>, F> {
16 pub re: T,
18 pub eps1: T,
20 pub eps2: T,
22 pub eps1eps2: T,
24 #[cfg_attr(feature = "serde", serde(skip))]
25 f: PhantomData<F>,
26}
27
/// Marks `HyperDual` as usable as a scalar operand in `ndarray`
/// array–scalar arithmetic (only compiled with the `ndarray` feature).
#[cfg(feature = "ndarray")]
impl<T, F> ndarray::ScalarOperand for HyperDual<T, F>
where
    T: DualNum<F>,
    F: DualNumFloat,
{
}
30
/// Hyper-dual number whose components and float type are `f32`.
pub type HyperDual32 = HyperDual<f32, f32>;
/// Hyper-dual number whose components and float type are `f64`.
pub type HyperDual64 = HyperDual<f64, f64>;
33
34impl<T: DualNum<F>, F> HyperDual<T, F> {
35 #[inline]
37 pub fn new(re: T, eps1: T, eps2: T, eps1eps2: T) -> Self {
38 Self {
39 re,
40 eps1,
41 eps2,
42 eps1eps2,
43 f: PhantomData,
44 }
45 }
46}
47
48impl<T: DualNum<F>, F> HyperDual<T, F> {
49 #[inline]
51 pub fn derivative1(mut self) -> Self {
52 self.eps1 = T::one();
53 self
54 }
55
56 #[inline]
58 pub fn derivative2(mut self) -> Self {
59 self.eps2 = T::one();
60 self
61 }
62}
63
64impl<T: DualNum<F>, F> HyperDual<T, F> {
65 #[inline]
67 pub fn from_re(re: T) -> Self {
68 Self::new(re, T::zero(), T::zero(), T::zero())
69 }
70}
71
72impl<T: DualNum<F>, F: Float> HyperDual<T, F> {
74 #[inline]
75 fn chain_rule(&self, f0: T, f1: T, f2: T) -> Self {
76 Self::new(
77 f0,
78 self.eps1.clone() * f1.clone(),
79 self.eps2.clone() * f1.clone(),
80 self.eps1eps2.clone() * f1 + self.eps1.clone() * self.eps2.clone() * f2,
81 )
82 }
83}
84
85impl<T: DualNum<F>, F: Float> Mul<&HyperDual<T, F>> for &HyperDual<T, F> {
87 type Output = HyperDual<T, F>;
88 #[inline]
89 fn mul(self, other: &HyperDual<T, F>) -> HyperDual<T, F> {
90 HyperDual::new(
91 self.re.clone() * other.re.clone(),
92 other.eps1.clone() * self.re.clone() + self.eps1.clone() * other.re.clone(),
93 other.eps2.clone() * self.re.clone() + self.eps2.clone() * other.re.clone(),
94 other.eps1eps2.clone() * self.re.clone()
95 + self.eps1.clone() * other.eps2.clone()
96 + other.eps1.clone() * self.eps2.clone()
97 + self.eps1eps2.clone() * other.re.clone(),
98 )
99 }
100}
101
102impl<T: DualNum<F>, F: Float> Div<&HyperDual<T, F>> for &HyperDual<T, F> {
104 type Output = HyperDual<T, F>;
105 #[inline]
106 fn div(self, other: &HyperDual<T, F>) -> HyperDual<T, F> {
107 let inv = other.re.recip();
108 let inv2 = inv.clone() * &inv;
109 HyperDual::new(
110 self.re.clone() * &inv,
111 (self.eps1.clone() * other.re.clone() - other.eps1.clone() * self.re.clone())
112 * inv2.clone(),
113 (self.eps2.clone() * other.re.clone() - other.eps2.clone() * self.re.clone())
114 * inv2.clone(),
115 self.eps1eps2.clone() * inv.clone()
116 - (other.eps1eps2.clone() * self.re.clone()
117 + self.eps1.clone() * other.eps2.clone()
118 + other.eps1.clone() * self.eps2.clone())
119 * inv2.clone()
120 + other.eps1.clone()
121 * other.eps2.clone()
122 * ((T::one() + T::one()) * self.re.clone() * inv2 * inv),
123 )
124 }
125}
126
127impl<T: DualNum<F>, F: fmt::Display> fmt::Display for HyperDual<T, F> {
129 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
130 fmt::Display::fmt(&self.re, f)?;
131 write!(f, " + ")?;
132 fmt::Display::fmt(&self.eps1, f)?;
133 write!(f, "ε1 + ")?;
134 fmt::Display::fmt(&self.eps2, f)?;
135 write!(f, "ε2 + ")?;
136 fmt::Display::fmt(&self.eps1eps2, f)?;
137 write!(f, "ε1ε2")
138 }
139}
140
// Macro-generated impls shared with the other dual-number types; the list
// names the infinitesimal fields of `HyperDual` the macros operate on.
// NOTE(review): the macros are defined elsewhere in the crate, so the exact
// set of generated items (presumably second-derivative helpers and the
// remaining dual-number trait impls) should be confirmed there.
impl_second_derivatives!(HyperDual, [eps1, eps2, eps1eps2]);
impl_dual!(HyperDual, [eps1, eps2, eps1eps2]);