1use crate::{DualNum, DualNumFloat, DualStruct};
2use num_traits::{Float, FloatConst, FromPrimitive, Inv, Num, One, Signed, Zero};
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5use std::fmt;
6use std::iter::{Product, Sum};
7use std::marker::PhantomData;
8use std::ops::*;
9
/// A third-order dual number.
///
/// Besides the real part `re` it carries the first, second and third
/// derivative components `v1`, `v2` and `v3`. The type parameter `F`
/// (defaulting to `T`) names the underlying float type and is only
/// stored as a zero-sized marker.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Dual3<T, F = T> {
    /// Real part.
    pub re: T,
    /// First-derivative component.
    pub v1: T,
    /// Second-derivative component.
    pub v2: T,
    /// Third-derivative component.
    pub v3: T,
    // Zero-sized marker for `F`; skipped when (de)serializing.
    #[cfg_attr(feature = "serde", serde(skip))]
    f: PhantomData<F>,
}
25
/// Allows `Dual3` values to act as scalar operands in `ndarray`
/// arithmetic. Only compiled when the `ndarray` feature is enabled.
#[cfg(feature = "ndarray")]
impl<T, F> ndarray::ScalarOperand for Dual3<T, F>
where
    T: DualNum<F>,
    F: DualNumFloat,
{
}
28
29pub type Dual3_32 = Dual3<f32>;
30pub type Dual3_64 = Dual3<f64>;
31
32impl<T, F> Dual3<T, F> {
33 #[inline]
35 pub fn new(re: T, v1: T, v2: T, v3: T) -> Self {
36 Self {
37 re,
38 v1,
39 v2,
40 v3,
41 f: PhantomData,
42 }
43 }
44}
45
46impl<T: DualNum<F>, F> Dual3<T, F> {
47 #[inline]
49 pub fn from_re(re: T) -> Self {
50 Self::new(re, T::zero(), T::zero(), T::zero())
51 }
52
53 #[inline]
63 pub fn derivative(mut self) -> Self {
64 self.v1 = T::one();
65 self
66 }
67}
68
69impl<T: DualNum<F>, F: Float> Dual3<T, F> {
70 #[inline]
71 fn chain_rule(&self, f0: T, f1: T, f2: T, f3: T) -> Self {
72 let three = T::one() + T::one() + T::one();
73 Self::new(
74 f0,
75 f1.clone() * &self.v1,
76 f2.clone() * &self.v1 * &self.v1 + f1.clone() * &self.v2,
77 f3 * &self.v1 * &self.v1 * &self.v1 + three * f2 * &self.v1 * &self.v2 + f1 * &self.v3,
78 )
79 }
80}
81
82impl<T: DualNum<F>, F: Float> Mul<&Dual3<T, F>> for &Dual3<T, F> {
83 type Output = Dual3<T, F>;
84 #[inline]
85 fn mul(self, rhs: &Dual3<T, F>) -> Dual3<T, F> {
86 let two = T::one() + T::one();
87 let three = T::one() + &two;
88 Dual3::new(
89 self.re.clone() * &rhs.re,
90 self.v1.clone() * &rhs.re + self.re.clone() * &rhs.v1,
91 self.v2.clone() * &rhs.re + two * &self.v1 * &rhs.v1 + self.re.clone() * &rhs.v2,
92 self.v3.clone() * &rhs.re
93 + three * (self.v2.clone() * &rhs.v1 + self.v1.clone() * &rhs.v2)
94 + self.re.clone() * &rhs.v3,
95 )
96 }
97}
98
99impl<T: DualNum<F>, F: Float> Div<&Dual3<T, F>> for &Dual3<T, F> {
100 type Output = Dual3<T, F>;
101 #[inline]
102 fn div(self, rhs: &Dual3<T, F>) -> Dual3<T, F> {
103 let rec = T::one() / &rhs.re;
104 let f0 = rec.clone();
105 let f1 = -f0.clone() * &rec;
106 let f2 = f1.clone() * &rec * F::from(-2.0).unwrap();
107 let f3 = f2.clone() * rec * F::from(-3.0).unwrap();
108 self * rhs.chain_rule(f0, f1, f2, f3)
109 }
110}
111
112impl<T: fmt::Display, F> fmt::Display for Dual3<T, F> {
114 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
115 write!(
116 f,
117 "{} + {}v1 + {}v2 + {}v3",
118 self.re, self.v1, self.v2, self.v3
119 )
120 }
121}
122
// Remaining boilerplate comes from macros defined elsewhere (not visible
// here). Each macro receives the type name and the list of its derivative
// fields.
// NOTE(review): presumably `impl_third_derivatives!` builds the elementary
// functions on top of the private `chain_rule`, and `impl_dual!` supplies the
// other operator / num-traits impls — confirm against the macro definitions.
impl_third_derivatives!(Dual3, [v1, v2, v3]);
impl_dual!(Dual3, [v1, v2, v3]);