1use crate::{DualNum, DualNumFloat, DualStruct};
2use num_traits::{Float, FloatConst, FromPrimitive, Inv, Num, One, Signed, Zero};
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5use std::fmt;
6use std::iter::{Product, Sum};
7use std::marker::PhantomData;
8use std::ops::{
9 Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
10};
11
12#[derive(Copy, Clone, Debug)]
14#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
15pub struct Dual2<T: DualNum<F>, F> {
16 pub re: T,
18 pub v1: T,
20 pub v2: T,
22 #[cfg_attr(feature = "serde", serde(skip))]
23 f: PhantomData<F>,
24}
25
/// Lets a `Dual2` be used as a scalar operand in `ndarray` arithmetic
/// (only compiled with the `ndarray` feature).
#[cfg(feature = "ndarray")]
impl<T, F> ndarray::ScalarOperand for Dual2<T, F>
where
    T: DualNum<F>,
    F: DualNumFloat,
{
}
28
29pub type Dual2_32 = Dual2<f32, f32>;
30pub type Dual2_64 = Dual2<f64, f64>;
31
32impl<T: DualNum<F>, F> Dual2<T, F> {
33 #[inline]
35 pub fn new(re: T, v1: T, v2: T) -> Self {
36 Self {
37 re,
38 v1,
39 v2,
40 f: PhantomData,
41 }
42 }
43}
44
45impl<T: DualNum<F>, F> Dual2<T, F> {
46 #[inline]
68 pub fn derivative(mut self) -> Self {
69 self.v1 = T::one();
70 self
71 }
72}
73
74impl<T: DualNum<F>, F> Dual2<T, F> {
75 #[inline]
77 pub fn from_re(re: T) -> Self {
78 Self::new(re, T::zero(), T::zero())
79 }
80}
81
82impl<T: DualNum<F>, F: Float> Dual2<T, F> {
84 #[inline]
85 fn chain_rule(&self, f0: T, f1: T, f2: T) -> Self {
86 Self::new(
87 f0,
88 self.v1.clone() * f1.clone(),
89 self.v2.clone() * f1 + self.v1.clone() * self.v1.clone() * f2,
90 )
91 }
92}
93
94impl<T: DualNum<F>, F: Float> Mul<&Dual2<T, F>> for &Dual2<T, F> {
96 type Output = Dual2<T, F>;
97 #[inline]
98 fn mul(self, other: &Dual2<T, F>) -> Dual2<T, F> {
99 Dual2::new(
100 self.re.clone() * other.re.clone(),
101 other.v1.clone() * self.re.clone() + self.v1.clone() * other.re.clone(),
102 other.v2.clone() * self.re.clone()
103 + self.v1.clone() * other.v1.clone()
104 + other.v1.clone() * self.v1.clone()
105 + self.v2.clone() * other.re.clone(),
106 )
107 }
108}
109
110impl<T: DualNum<F>, F: Float> Div<&Dual2<T, F>> for &Dual2<T, F> {
112 type Output = Dual2<T, F>;
113 #[inline]
114 fn div(self, other: &Dual2<T, F>) -> Dual2<T, F> {
115 let inv = other.re.recip();
116 let inv2 = inv.clone() * inv.clone();
117 Dual2::new(
118 self.re.clone() * inv.clone(),
119 (self.v1.clone() * other.re.clone() - other.v1.clone() * self.re.clone())
120 * inv2.clone(),
121 self.v2.clone() * inv.clone()
122 - (other.v2.clone() * self.re.clone()
123 + self.v1.clone() * other.v1.clone()
124 + other.v1.clone() * self.v1.clone())
125 * inv2.clone()
126 + other.v1.clone()
127 * other.v1.clone()
128 * ((T::one() + T::one()) * self.re.clone() * inv2 * inv),
129 )
130 }
131}
132
133impl<T: DualNum<F>, F: fmt::Display> fmt::Display for Dual2<T, F> {
135 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
136 write!(f, "{} + {}ε1 + {}ε1²", self.re, self.v1, self.v2)
137 }
138}
139
140impl_second_derivatives!(Dual2, [v1, v2]);
141impl_dual!(Dual2, [v1, v2]);
142impl_nalgebra!(Dual2, [v1, v2]);