num_dual/
impl_derivatives.rs

/// Implements the elementary functions of `DualNum<F>` for a dual-number struct.
///
/// Every generated method evaluates the scalar function on `self.re` (`f0`) and,
/// depending on `$deriv`, its first (`f1`), second (`f2`) and third (`f3`)
/// derivative at `self.re`. The values are handed to `Self::chain_rule` through
/// the `chain_rule!` helper, which forwards only as many of them as `$deriv` asks for.
///
/// Arguments:
/// * `$deriv` – `zeroth`, `first`, `second` or `third`; selects which of the
///   `f1`/`f2`/`f3` statements are emitted (via `zeroth!`/`first!`/`second!`/`third!`).
/// * `$nderiv` – added to `T::NDERIV` to form the associated constant `NDERIV`.
/// * `$struct` – the struct receiving the impl; it is used as
///   `$struct<T, F, dims...>` and must expose a `re` field of type `T`.
/// * `[$im, ...]` – accepted for the caller's convenience; not referenced in the
///   generated code.
/// * optional `[$dim, ...]` – additional generic parameters, each bounded by `Dim`.
/// * optional further `[$ddim, ...]` lists – each list becomes one
///   `DefaultAllocator: Allocator<...>` bound in the `where` clause.
#[macro_export]
macro_rules! impl_derivatives {
    ($deriv:ident, $nderiv:expr, $struct:ident, [$($im:ident),*]$(, [$($dim:tt),*]$(, [$($ddim:tt),*])*)?) => {
        impl<T: DualNum<F>, F: DualNumFloat$($(, $dim: Dim)*)?> DualNum<F> for $struct<T, F$($(, $dim)*)?>
        where
        $($(DefaultAllocator: Allocator<$($ddim,)*>),*)?
        {
            const NDERIV: usize = T::NDERIV + $nderiv;

            // 1/x: f1 = -1/x^2, f2 = 2/x^3, f3 = -6/x^4.
            #[inline]
            fn recip(&self) -> Self {
                let rec = self.re.recip();
                let f0 = rec.clone();
                first!($deriv, let f1 = -f0.clone() * &rec;);
                second!($deriv, let f2 = f1.clone() * &rec * F::from(-2.0).unwrap(););
                third!($deriv, let f3 = f2.clone() * rec * F::from(-3.0).unwrap(););
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // x^exp for integer exponents. Exponents 0, 1 and 2 are special-cased;
            // all others build every value from x^(exp - 3).
            #[inline]
            fn powi(&self, exp: i32) -> Self {
                match exp {
                    0 => Self::one(),
                    1 => self.clone(),
                    2 => self * self,
                    _ => {
                        // The falling-factorial coefficients are multiplied in `F`:
                        // the i32 products `exp * (exp - 1) [* (exp - 2)]` overflow
                        // once |exp| reaches roughly 1290.
                        first!($deriv, let n = F::from(exp).unwrap(););
                        second!($deriv, let n1 = n - F::one(););
                        third!($deriv, let n2 = n1 - F::one(););
                        let pow3 = self.re.powi(exp - 3);
                        let f0 = pow3.clone() * &self.re * &self.re * &self.re;
                        first!($deriv, let f1 = pow3.clone() * &self.re * &self.re * n;);
                        second!($deriv, let f2 = pow3.clone() * &self.re * (n * n1););
                        third!($deriv, let f3 = pow3 * (n * n1 * n2););
                        chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
                    }
                }
            }

            // x^n for a scalar exponent. n = 0, n = 1 and n within epsilon of 2
            // are special-cased; all others build every value from x^(n - 3).
            #[inline]
            fn powf(&self, n: F) -> Self {
                if n.is_zero() {
                    Self::one()
                } else if n.is_one() {
                    self.clone()
                } else if (n - F::one() - F::one()).abs() < F::epsilon() {
                    self * self
                } else {
                    let n1 = n - F::one();
                    let n2 = n1 - F::one();
                    let n3 = n2 - F::one();
                    let pow3 = self.re.powf(n3);
                    let f0 = pow3.clone() * &self.re * &self.re * &self.re;
                    first!($deriv, let f1 = pow3.clone() * &self.re * &self.re * n;);
                    second!($deriv, let f2 = pow3.clone() * &self.re * n * n1;);
                    third!($deriv, let f3 = pow3 * n * n1 * n2;);
                    chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
                }
            }

            // sqrt(x): each derivative is the previous one times (1/2 - k) / x.
            #[inline]
            fn sqrt(&self) -> Self {
                first!($deriv, let rec = self.re.recip(););
                first!($deriv, let half = F::from(0.5).unwrap(););
                let f0 = self.re.sqrt();
                first!($deriv, let f1 = f0.clone() * &rec * half;);
                second!($deriv, let f2 = -f1.clone() * &rec * half;);
                third!($deriv, let f3 = f2.clone() * rec * (-F::one() - half););
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // cbrt(x): each derivative is the previous one times (1/3 - k) / x.
            #[inline]
            fn cbrt(&self) -> Self {
                first!($deriv, let rec = self.re.recip(););
                first!($deriv, let third = F::from(1.0 / 3.0).unwrap(););
                let f0 = self.re.cbrt();
                first!($deriv, let f1 = f0.clone() * &rec * third;);
                second!($deriv, let f2 = f1.clone() * &rec * (third - F::one()););
                third!($deriv, let f3 = f2.clone() * rec * (third - F::one() - F::one()););
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // exp(x): the function is its own derivative at every order.
            #[inline]
            fn exp(&self) -> Self {
                let f = self.re.exp();
                chain_rule!($deriv, Self::chain_rule(self, f.clone(), f.clone(), f.clone(), f))
            }

            // 2^x: every derivative multiplies by ln(2).
            #[inline]
            fn exp2(&self) -> Self {
                first!($deriv, let ln2 = F::from(2.0).unwrap().ln(););
                let f0 = self.re.exp2();
                first!($deriv, let f1 = f0.clone() * ln2;);
                second!($deriv, let f2 = f1.clone() * ln2;);
                third!($deriv, let f3 = f2.clone() * ln2;);
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // exp(x) - 1: the value uses `exp_m1`, all derivatives are exp(x).
            #[inline]
            fn exp_m1(&self) -> Self {
                let f0 = self.re.exp_m1();
                first!($deriv, let f1 = self.re.exp(););
                chain_rule!($deriv, Self::chain_rule(self, f0, f1.clone(), f1.clone(), f1))
            }

            // ln(x): f1 = 1/x, f2 = -1/x^2, f3 = 2/x^3.
            #[inline]
            fn ln(&self) -> Self {
                first!($deriv, let rec = self.re.recip(););
                let f0 = self.re.ln();
                first!($deriv, let f1 = rec.clone(););
                second!($deriv, let f2 = -f1.clone() * &rec;);
                third!($deriv, let f3 = f2.clone() * rec * F::from(-2.0).unwrap(););
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // log_base(x): the derivatives of ln(x) divided by ln(base).
            #[inline]
            fn log(&self, base: F) -> Self {
                first!($deriv, let rec = self.re.recip(););
                let f0 = self.re.log(base);
                first!($deriv, let f1 = rec.clone() / base.ln(););
                second!($deriv, let f2 = -f1.clone() * &rec;);
                third!($deriv, let f3 = f2.clone() * rec * F::from(-2.0).unwrap(););
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // log2(x): the derivatives of ln(x) divided by ln(2).
            #[inline]
            fn log2(&self) -> Self {
                first!($deriv, let rec = self.re.recip(););
                let f0 = self.re.log2();
                first!($deriv, let f1 = rec.clone() / (F::one() + F::one()).ln(););
                second!($deriv, let f2 = -f1.clone() * &rec;);
                third!($deriv, let f3 = f2.clone() * rec * F::from(-2.0).unwrap(););
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // log10(x): the derivatives of ln(x) divided by ln(10).
            #[inline]
            fn log10(&self) -> Self {
                first!($deriv, let rec = self.re.recip(););
                let f0 = self.re.log10();
                first!($deriv, let f1 = rec.clone() / F::from(10.0).unwrap().ln(););
                second!($deriv, let f2 = -f1.clone() * &rec;);
                third!($deriv, let f3 = f2.clone() * rec * F::from(-2.0).unwrap(););
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // ln(1 + x): same pattern as `ln` with 1/(1 + x) in place of 1/x.
            #[inline]
            fn ln_1p(&self) -> Self {
                first!($deriv, let rec = (self.re.clone() + F::one()).recip(););
                let f0 = self.re.ln_1p();
                first!($deriv, let f1 = rec.clone(););
                second!($deriv, let f2 = -f1.clone() * &rec;);
                third!($deriv, let f3 = f2.clone() * rec * F::from(-2.0).unwrap(););
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // sin(x): derivatives cycle through cos, -sin, -cos.
            // Only the sine is evaluated when no derivatives are requested.
            #[inline]
            fn sin(&self) -> Self {
                zeroth!($deriv, let s = self.re.sin(););
                first!($deriv, let (s, c) = self.re.sin_cos(););
                chain_rule!($deriv, Self::chain_rule(self, s.clone(), c.clone(), -s, -c))
            }

            // cos(x): derivatives cycle through -sin, -cos, sin.
            // Only the cosine is evaluated when no derivatives are requested.
            #[inline]
            fn cos(&self) -> Self {
                zeroth!($deriv, let c = self.re.cos(););
                first!($deriv, let (s, c) = self.re.sin_cos(););
                chain_rule!($deriv, Self::chain_rule(self, c.clone(), -s.clone(), -c, s))
            }

            // (sin(x), cos(x)) from a single `sin_cos` evaluation of the real part.
            #[inline]
            fn sin_cos(&self) -> (Self, Self) {
                let (s, c) = self.re.sin_cos();
                (
                    chain_rule!($deriv, Self::chain_rule(self, s.clone(), c.clone(), -s.clone(), -c.clone())),
                    chain_rule!($deriv, Self::chain_rule(self, c.clone(), -s.clone(), -c, s)))
            }

            // tan(x) as the quotient sin(x) / cos(x) of the dual numbers.
            #[inline]
            fn tan(&self) -> Self {
                let (sin, cos) = self.sin_cos();
                sin / cos
            }

            // asin(x), with r = 1/(1 - x^2):
            // f1 = sqrt(r), f2 = x * f1 * r, f3 = (2x^2 + 1) * f1 * r^2.
            #[inline]
            fn asin(&self) -> Self {
                first!($deriv, let rec = (T::one() - self.re.clone() * &self.re).recip(););
                let f0 = self.re.asin();
                first!($deriv, let f1 = rec.sqrt(););
                second!($deriv, let f2 = self.re.clone() * &f1 * &rec;);
                third!($deriv, let f3 = (self.re.clone() * &self.re * (F::one() + F::one()) + F::one()) * &f1 * &rec * rec;);
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // acos(x): the derivatives of asin(x) with opposite sign.
            #[inline]
            fn acos(&self) -> Self {
                first!($deriv, let rec = (T::one() - self.re.clone() * &self.re).recip(););
                let f0 = self.re.acos();
                first!($deriv, let f1 = -rec.sqrt(););
                second!($deriv, let f2 = self.re.clone() * &f1 * &rec;);
                third!($deriv, let f3 = (self.re.clone() * &self.re * (F::one() + F::one()) + F::one()) * &f1 * &rec * rec;);
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // atan(x), with r = 1/(1 + x^2):
            // f1 = r, f2 = -2x * r^2, f3 = (6x^2 - 2) * r^3.
            #[inline]
            fn atan(&self) -> Self {
                first!($deriv, let rec = (T::one() + self.re.clone() * &self.re).recip(););
                let f0 = self.re.atan();
                first!($deriv, let f1 = rec.clone(););
                second!($deriv, let two = F::one() + F::one(););
                second!($deriv, let f2 = -self.re.clone() * &f1 * &rec * two;);
                third!($deriv, let f3 = (self.re.clone() * &self.re * F::from(6.0).unwrap() - two) * &f1 * &rec * rec;);
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // atan2(self, other): computed as atan(self / other), after which the
            // real part is overwritten with `self.re.atan2(other.re)`.
            #[inline]
            fn atan2(&self, other: Self) -> Self {
                let mut res = (self / other.clone()).atan();
                res.re = self.re.atan2(other.re);
                res
            }

            // sinh(x): derivatives alternate cosh, sinh, cosh.
            #[inline]
            fn sinh(&self) -> Self {
                let s = self.re.sinh();
                first!($deriv, let c = self.re.cosh(););
                chain_rule!($deriv, Self::chain_rule(self, s.clone(), c.clone(), s, c))
            }

            // cosh(x): derivatives alternate sinh, cosh, sinh.
            #[inline]
            fn cosh(&self) -> Self {
                first!($deriv, let s = self.re.sinh(););
                let c = self.re.cosh();
                chain_rule!($deriv, Self::chain_rule(self, c.clone(), s.clone(), c, s))
            }

            // tanh(x) as the quotient sinh(x) / cosh(x) of the dual numbers.
            #[inline]
            fn tanh(&self) -> Self {
                self.sinh() / self.cosh()
            }

            // asinh(x), with r = 1/(1 + x^2):
            // f1 = sqrt(r), f2 = -x * f1 * r, f3 = (2x^2 - 1) * f1 * r^2.
            #[inline]
            fn asinh(&self) -> Self {
                first!($deriv, let rec = (T::one() + self.re.clone() * &self.re).recip(););
                let f0 = self.re.asinh();
                first!($deriv, let f1 = rec.sqrt(););
                second!($deriv, let f2 = -self.re.clone() * &f1 * &rec;);
                third!($deriv, let f3 = (self.re.clone() * &self.re * (F::one() + F::one()) - F::one()) * &f1 * &rec * rec;);
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // acosh(x), with r = 1/(x^2 - 1):
            // f1 = sqrt(r), f2 = -x * f1 * r, f3 = (2x^2 + 1) * f1 * r^2.
            #[inline]
            fn acosh(&self) -> Self {
                first!($deriv, let rec = (self.re.clone() * &self.re - F::one()).recip(););
                let f0 = self.re.acosh();
                first!($deriv, let f1 = rec.sqrt(););
                second!($deriv, let f2 = -self.re.clone() * &f1 * &rec;);
                third!($deriv, let f3 = (self.re.clone() * &self.re * (F::one() + F::one()) + F::one()) * &f1 * &rec * rec;);
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // atanh(x), with r = 1/(1 - x^2):
            // f1 = r, f2 = 2x * r^2, f3 = (6x^2 + 2) * r^3.
            #[inline]
            fn atanh(&self) -> Self {
                first!($deriv, let rec = (T::one() - self.re.clone() * &self.re).recip(););
                let f0 = self.re.atanh();
                first!($deriv, let f1 = rec.clone(););
                second!($deriv, let two = F::one() + F::one(););
                second!($deriv, let f2 = self.re.clone() * &f1 * &rec * two;);
                third!($deriv, let f3 = (self.re.clone() * &self.re * F::from(6.0).unwrap() + two) * &f1 * &rec * rec;);
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // Spherical Bessel function j0(x) = sin(x) / x.
            // Within epsilon of zero (on either side) the quotient is replaced by
            // the series 1 - x^2/6; the magnitude is tested so that negative
            // arguments still take the closed form.
            #[inline]
            fn sph_j0(&self) -> Self {
                if self.re().abs() < F::epsilon() {
                    Self::one() - self * self / F::from(6.0).unwrap()
                } else {
                    self.sin() / self
                }
            }

            // Spherical Bessel function j1(x) = (sin(x) - x cos(x)) / x^2.
            // Within epsilon of zero (on either side) the leading series term x/3
            // is used; the magnitude is tested so that negative arguments still
            // take the closed form.
            #[inline]
            fn sph_j1(&self) -> Self {
                if self.re().abs() < F::epsilon() {
                    self.clone() / F::from(3.0).unwrap()
                } else {
                    let (s, c) = self.sin_cos();
                    (s - self * c) / (self * self)
                }
            }

            // Spherical Bessel function
            // j2(x) = (3 (sin(x) - x cos(x)) - x^2 sin(x)) / x^3.
            // Within epsilon of zero (on either side) the leading series term
            // x^2/15 is used; the magnitude is tested so that negative arguments
            // still take the closed form.
            #[inline]
            fn sph_j2(&self) -> Self {
                if self.re().abs() < F::epsilon() {
                    self * self / F::from(15.0).unwrap()
                } else {
                    let (s, c) = self.sin_cos();
                    let s2 = self * self;
                    ((&s - self * c) * F::from(3.0).unwrap() - &s2 * s) / (s2 * self)
                }
            }
        }
    };
}
303
/// Emits the given statements only for derivative order `zeroth`.
///
/// The leading token is the derivative order (`zeroth`, `first`, `second` or
/// `third`); for every order other than `zeroth` the expansion is empty.
#[macro_export]
macro_rules! zeroth {
    (third, $($stmt:tt)*) => {};
    (second, $($stmt:tt)*) => {};
    (first, $($stmt:tt)*) => {};
    (zeroth, $($stmt:tt)*) => { $($stmt)* };
}
313
/// Emits the given statements for derivative orders `first`, `second` and `third`.
///
/// The leading token is the derivative order; for `zeroth` the expansion is empty.
#[macro_export]
macro_rules! first {
    (third, $($stmt:tt)*) => { $($stmt)* };
    (second, $($stmt:tt)*) => { $($stmt)* };
    (first, $($stmt:tt)*) => { $($stmt)* };
    (zeroth, $($stmt:tt)*) => {};
}
327
/// Emits the given statements for derivative orders `second` and `third`.
///
/// The leading token is the derivative order; for `zeroth` and `first` the
/// expansion is empty.
#[macro_export]
macro_rules! second {
    (third, $($stmt:tt)*) => { $($stmt)* };
    (second, $($stmt:tt)*) => { $($stmt)* };
    (first, $($stmt:tt)*) => {};
    (zeroth, $($stmt:tt)*) => {};
}
339
/// Emits the given statements only for derivative order `third`.
///
/// The leading token is the derivative order; for `zeroth`, `first` and
/// `second` the expansion is empty.
#[macro_export]
macro_rules! third {
    (third, $($stmt:tt)*) => { $($stmt)* };
    (second, $($stmt:tt)*) => {};
    (first, $($stmt:tt)*) => {};
    (zeroth, $($stmt:tt)*) => {};
}
349
/// Rewrites a five-argument `Self::chain_rule(self, f0, f1, f2, f3)` call so that
/// only the arguments belonging to the selected derivative order are passed on.
///
/// * `zeroth` keeps `f0`
/// * `first` keeps `f0, f1`
/// * `second` keeps `f0, f1, f2`
/// * `third` keeps all four
///
/// The dropped expressions do not appear in the expansion, so they may refer to
/// bindings that were never emitted for the selected order.
#[macro_export]
macro_rules! chain_rule {
    (third, Self::chain_rule($this:ident, $val:expr, $d1:expr, $d2:expr, $d3:expr)) => {
        Self::chain_rule($this, $val, $d1, $d2, $d3)
    };
    (second, Self::chain_rule($this:ident, $val:expr, $d1:expr, $d2:expr, $d3:expr)) => {
        Self::chain_rule($this, $val, $d1, $d2)
    };
    (first, Self::chain_rule($this:ident, $val:expr, $d1:expr, $d2:expr, $d3:expr)) => {
        Self::chain_rule($this, $val, $d1)
    };
    (zeroth, Self::chain_rule($this:ident, $val:expr, $d1:expr, $d2:expr, $d3:expr)) => {
        Self::chain_rule($this, $val)
    };
}
365
/// Forwards to `impl_derivatives!` with derivative order `zeroth` and an
/// `NDERIV` increment of `0`; the struct name and the bracketed lists are
/// passed through unchanged.
#[macro_export]
macro_rules! impl_zeroth_derivatives {
    ($ty:ident, [$($field:ident),*]$(, [$($d:tt),*]$(, [$($alloc:tt),*])*)?) => {
        impl_derivatives! {
            zeroth, 0, $ty, [$($field),*]$(, [$($d),*]$(, [$($alloc),*])*)?
        }
    };
}
372
/// Forwards to `impl_derivatives!` with derivative order `first` and an
/// `NDERIV` increment of `1`; the struct name and the bracketed lists are
/// passed through unchanged.
#[macro_export]
macro_rules! impl_first_derivatives {
    ($ty:ident, [$($field:ident),*]$(, [$($d:tt),*]$(, [$($alloc:tt),*])*)?) => {
        impl_derivatives! {
            first, 1, $ty, [$($field),*]$(, [$($d),*]$(, [$($alloc),*])*)?
        }
    };
}
379
/// Forwards to `impl_derivatives!` with derivative order `second` and an
/// `NDERIV` increment of `2`; the struct name and the bracketed lists are
/// passed through unchanged.
#[macro_export]
macro_rules! impl_second_derivatives {
    ($ty:ident, [$($field:ident),*]$(, [$($d:tt),*]$(, [$($alloc:tt),*])*)?) => {
        impl_derivatives! {
            second, 2, $ty, [$($field),*]$(, [$($d),*]$(, [$($alloc),*])*)?
        }
    };
}
386
/// Forwards to `impl_derivatives!` with derivative order `third` and an
/// `NDERIV` increment of `3`; the struct name and the bracketed lists are
/// passed through unchanged.
#[macro_export]
macro_rules! impl_third_derivatives {
    ($ty:ident, [$($field:ident),*]$(, [$($d:tt),*]$(, [$($alloc:tt),*])*)?) => {
        impl_derivatives! {
            third, 3, $ty, [$($field),*]$(, [$($d),*]$(, [$($alloc),*])*)?
        }
    };
}