1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
use crate::{DualNum, DualNumFloat};
use num_traits::{Float, FloatConst, FromPrimitive, Inv, Num, One, Signed, Zero};
use std::convert::Infallible;
use std::fmt;
use std::iter::{Product, Sum};
use std::marker::PhantomData;
use std::ops::*;

/// Third-order dual number: a scalar value carried together with three derivative parts,
/// used to compute derivatives up to third order.
///
/// `T` is the type of every component; `F` defaults to `T` and is tracked only at the type
/// level through a zero-sized marker field.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Dual3<T, F = T> {
    /// Value (zeroth-order) part.
    pub re: T,
    /// First-order derivative part.
    pub v1: T,
    /// Second-order derivative part.
    pub v2: T,
    /// Third-order derivative part.
    pub v3: T,
    /// Zero-sized marker that carries the otherwise unused type parameter `F`.
    f: PhantomData<F>,
}

/// [`Dual3`] with `f32` components.
pub type Dual3_32 = Dual3<f32>;
/// [`Dual3`] with `f64` components.
pub type Dual3_64 = Dual3<f64>;

impl<T, F> Dual3<T, F> {
    /// Create a new third order dual number from its fields.
    #[inline]
    pub fn new(re: T, v1: T, v2: T, v3: T) -> Self {
        Self {
            re,
            v1,
            v2,
            v3,
            f: PhantomData,
        }
    }
}

impl<T: DualNum<F>, F> Dual3<T, F> {
    /// Build a constant third-order dual number: the value is `re` and all three
    /// derivative parts are zero.
    #[inline]
    pub fn from_re(re: T) -> Self {
        let (v1, v2, v3) = (T::zero(), T::zero(), T::zero());
        Self::new(re, v1, v2, v3)
    }

    /// Seed this number as the independent variable by setting `v1` to one.
    ///
    /// `re`, `v2` and `v3` are kept as they are. Starting from [`Dual3::from_re`] and then
    /// evaluating a function on the seeded number yields that function's derivatives in the
    /// result's `v1`, `v2` and `v3` parts:
    /// ```
    /// # use num_dual::{Dual3, DualNum};
    /// let x = Dual3::from_re(5.0).derivative().powi(3);
    /// assert_eq!(x.re, 125.0);
    /// assert_eq!(x.v1, 75.0);
    /// assert_eq!(x.v2, 30.0);
    /// assert_eq!(x.v3, 6.0);
    /// ```
    #[inline]
    pub fn derivative(self) -> Self {
        Self { v1: T::one(), ..self }
    }
}

/// Evaluate a univariate function `g` at `x` together with its first three derivatives.
///
/// Returns the tuple `(g(x), g'(x), g''(x), g'''(x))`.
/// ```
/// # use num_dual::{third_derivative, DualNum};
/// let (f, df, d2f, d3f) = third_derivative(|x| x.powi(3), 5.0);
/// assert_eq!(f, 125.0);      // x³
/// assert_eq!(df, 75.0);      // 3x²
/// assert_eq!(d2f, 30.0);     // 6x
/// assert_eq!(d3f, 6.0);      // 6
/// ```
pub fn third_derivative<G, T: DualNum<F>, F>(g: G, x: T) -> (T, T, T, T)
where
    G: FnOnce(Dual3<T, F>) -> Dual3<T, F>,
{
    // Wrap the infallible closure so the fallible variant can do the actual work.
    let never_fails = |arg: Dual3<T, F>| -> Result<Dual3<T, F>, Infallible> { Ok(g(arg)) };
    match try_third_derivative(never_fails, x) {
        Ok(parts) => parts,
        Err(never) => match never {},
    }
}

/// Variant of [third_derivative] for fallible functions.
pub fn try_third_derivative<G, T: DualNum<F>, F, E>(g: G, x: T) -> Result<(T, T, T, T), E>
where
    G: FnOnce(Dual3<T, F>) -> Result<Dual3<T, F>, E>,
{
    let mut x = Dual3::from_re(x);
    x.v1 = T::one();
    g(x).map(|r| (r.re, r.v1, r.v2, r.v3))
}

impl<T: DualNum<F>, F: Float> Dual3<T, F> {
    /// Propagate an outer scalar function through `self` using the third-order chain rule.
    ///
    /// `f0..f3` are expected to be the outer function's value and first three derivatives
    /// at `self.re`. The resulting components are
    /// `f0`, `f1·v1`, `f2·v1² + f1·v2` and `f3·v1³ + 3·f2·v1·v2 + f1·v3`.
    #[inline]
    fn chain_rule(&self, f0: T, f1: T, f2: T, f3: T) -> Self {
        let Self { v1, v2, v3, .. } = self;
        let three = T::one() + T::one() + T::one();
        // The multiplication order matches the textbook expansion term by term.
        let d1 = f1.clone() * v1;
        let d2 = f2.clone() * v1 * v1 + f1.clone() * v2;
        let d3 = f3 * v1 * v1 * v1 + three * f2 * v1 * v2 + f1 * v3;
        Self::new(f0, d1, d2, d3)
    }
}

impl<'a, 'b, T: DualNum<F>, F: Float> Mul<&'a Dual3<T, F>> for &'b Dual3<T, F> {
    type Output = Dual3<T, F>;

    /// Product via the Leibniz rule up to third order:
    /// `(ab)' = a'b + ab'`, `(ab)'' = a''b + 2a'b' + ab''`,
    /// `(ab)''' = a'''b + 3(a''b' + a'b'') + ab'''`.
    #[inline]
    fn mul(self, rhs: &Dual3<T, F>) -> Dual3<T, F> {
        let (a, b) = (self, rhs);
        let two = T::one() + T::one();
        let three = T::one() + &two;

        let re = a.re.clone() * &b.re;
        let v1 = a.v1.clone() * &b.re + a.re.clone() * &b.v1;
        let v2 = a.v2.clone() * &b.re + two * &a.v1 * &b.v1 + a.re.clone() * &b.v2;
        let v3 = a.v3.clone() * &b.re
            + three * (a.v2.clone() * &b.v1 + a.v1.clone() * &b.v2)
            + a.re.clone() * &b.v3;

        Dual3::new(re, v1, v2, v3)
    }
}

impl<'a, 'b, T: DualNum<F>, F: Float> Div<&'a Dual3<T, F>> for &'b Dual3<T, F> {
    type Output = Dual3<T, F>;

    /// Quotient computed as `self * (1 / rhs)`.
    ///
    /// The reciprocal is obtained by pushing `1/x` through `rhs` with [`Dual3::chain_rule`],
    /// using the derivatives `1/x`, `-1/x²`, `2/x³` and `-6/x⁴` at `x = rhs.re`.
    #[inline]
    fn div(self, rhs: &Dual3<T, F>) -> Dual3<T, F> {
        let inv = T::one() / &rhs.re;
        let minus_two = F::from(-2.0).unwrap();
        let minus_three = F::from(-3.0).unwrap();

        // Each derivative of 1/x is the previous one times (-k)/x.
        let d0 = inv.clone();
        let d1 = -d0.clone() * &inv;
        let d2 = d1.clone() * &inv * minus_two;
        let d3 = d2.clone() * inv * minus_three;

        let reciprocal = rhs.chain_rule(d0, d1, d2, d3);
        self * reciprocal
    }
}

/* string conversions */
/// Formats as `re + v1v1 + v2v2 + v3v3`; for example, `1 + 2v1 + 0v2 + 0v3`.
///
/// A precision in the format spec (e.g. `{:.3}`) is applied to each of the four components.
/// Without one, every component uses its plain `Display` output. Other flags such as width or
/// alignment are not forwarded.
impl<T: fmt::Display, F> fmt::Display for Dual3<T, F> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match f.precision() {
            // `write!` starts from a fresh format spec, so the caller's precision must be
            // passed on explicitly; otherwise it would be silently dropped.
            Some(p) => write!(
                f,
                "{:.*} + {:.*}v1 + {:.*}v2 + {:.*}v3",
                p, self.re, p, self.v1, p, self.v2, p, self.v3
            ),
            None => write!(
                f,
                "{} + {}v1 + {}v2 + {}v3",
                self.re, self.v1, self.v2, self.v3
            ),
        }
    }
}

// The remaining trait impls for `Dual3` are generated by crate macros defined outside this file.
// NOTE(review): presumably `impl_third_derivatives!` supplies the elementary functions (built on
// `chain_rule`) and `impl_dual!` the shared arithmetic / num-traits boilerplate over the listed
// derivative fields `[v1, v2, v3]`. Confirm against the macro definitions.
impl_third_derivatives!(Dual3, [v1, v2, v3]);
impl_dual!(Dual3, [v1, v2, v3]);