/// Implements `DualNum<F>` for a dual-number struct by applying the chain rule to each
/// elementary function.
///
/// * `$deriv` — derivative order, one of `zeroth`, `first`, `second`, `third`. Through the
///   `first!`/`second!`/`third!` macros it gates which derivative factors `f1`, `f2`, `f3`
///   are computed, and `chain_rule!` forwards only those factors to the struct's
///   `Self::chain_rule`.
/// * `$nderiv` — added to `T::NDERIV` to form the implementation's `NDERIV`.
/// * `$struct` — the struct the trait is implemented for, as `$struct<T, F, ...>`.
/// * `[$im, ...]` — accepted for call-site uniformity; not used by this macro.
/// * optional `[$dim, ...]` — extra type parameters, each bounded by `Dim` and appended to
///   `$struct`'s generics; each following `[$ddim, ...]` list adds one
///   `DefaultAllocator: Allocator<$ddim, ...>` bound.
#[macro_export]
macro_rules! impl_derivatives {
    ($deriv:ident, $nderiv:expr, $struct:ident, [$($im:ident),*]$(, [$($dim:tt),*]$(, [$($ddim:tt),*])*)?) => {
        impl<T: DualNum<F>, F: DualNumFloat$($(, $dim: Dim)*)?> DualNum<F> for $struct<T, F$($(, $dim)*)?>
        where
            $($(DefaultAllocator: Allocator<$($ddim,)*>),*)?
        {
            const NDERIV: usize = T::NDERIV + $nderiv;

            // 1/x: f1 = -1/x², f2 = 2/x³, f3 = -6/x⁴.
            #[inline]
            fn recip(&self) -> Self {
                let rec = self.re.recip();
                let f0 = rec.clone();
                first!($deriv, let f1 = -f0.clone() * &rec;);
                second!($deriv, let f2 = f1.clone() * &rec * F::from(-2.0).unwrap(););
                third!($deriv, let f3 = f2.clone() * rec * F::from(-3.0).unwrap(););
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // x^n for integer n. Exponents 0, 1 and 2 are handled exactly; otherwise x^(n-3)
            // is computed once and multiplied up, giving f0 = x^n, f1 = n x^(n-1),
            // f2 = n(n-1) x^(n-2), f3 = n(n-1)(n-2) x^(n-3).
            #[inline]
            fn powi(&self, exp: i32) -> Self {
                match exp {
                    0 => Self::one(),
                    1 => self.clone(),
                    2 => self * self,
                    _ => {
                        let pow3 = self.re.powi(exp - 3);
                        let f0 = pow3.clone() * &self.re * &self.re * &self.re;
                        // The coefficients n, n(n-1), n(n-1)(n-2) are formed in `F`, not `i32`:
                        // the integer products overflow once |n| exceeds roughly 1290
                        // (a panic in debug builds, a wrong derivative in release builds).
                        first!($deriv, let n = F::from(exp).unwrap(););
                        first!($deriv, let f1 = pow3.clone() * &self.re * &self.re * n;);
                        second!($deriv, let n1 = n - F::one(););
                        second!($deriv, let f2 = pow3.clone() * &self.re * (n * n1););
                        third!($deriv, let f3 = pow3 * (n * n1 * (n1 - F::one())););
                        chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
                    }
                }
            }

            // x^n for real n. Exponents 0, 1 and (within epsilon) 2 are handled directly;
            // otherwise x^(n-3) is computed once and multiplied up, giving f0 = x^n,
            // f1 = n x^(n-1), f2 = n(n-1) x^(n-2), f3 = n(n-1)(n-2) x^(n-3).
            // NOTE(review): with a zero real part and n < 3, `pow3` is 0^(n-3), which is
            // infinite for a floating-point real part, so f0 = pow3 · x³ may come out NaN
            // rather than 0. Confirm whether a dedicated zero branch is wanted.
            #[inline]
            fn powf(&self, n: F) -> Self {
                if n.is_zero() {
                    Self::one()
                } else if n.is_one() {
                    self.clone()
                } else if (n - F::one() - F::one()).abs() < F::epsilon() {
                    self * self
                } else {
                    let n1 = n - F::one();
                    let n2 = n1 - F::one();
                    let n3 = n2 - F::one();
                    let pow3 = self.re.powf(n3);
                    let f0 = pow3.clone() * &self.re * &self.re * &self.re;
                    first!($deriv, let f1 = pow3.clone() * &self.re * &self.re * n;);
                    second!($deriv, let f2 = pow3.clone() * &self.re * n * n1;);
                    third!($deriv, let f3 = pow3 * n * n1 * n2;);
                    chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
                }
            }

            // √x: f1 = 1/(2√x), f2 = -1/(4 x^(3/2)), f3 = 3/(8 x^(5/2)).
            #[inline]
            fn sqrt(&self) -> Self {
                first!($deriv, let rec = self.re.recip(););
                first!($deriv, let half = F::from(0.5).unwrap(););
                let f0 = self.re.sqrt();
                first!($deriv, let f1 = f0.clone() * &rec * half;);
                second!($deriv, let f2 = -f1.clone() * &rec * half;);
                third!($deriv, let f3 = f2.clone() * rec * (-F::one() - half););
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // ∛x: f1 = x^(-2/3)/3, f2 = -2/9 x^(-5/3), f3 = 10/27 x^(-8/3).
            #[inline]
            fn cbrt(&self) -> Self {
                first!($deriv, let rec = self.re.recip(););
                first!($deriv, let third = F::from(1.0 / 3.0).unwrap(););
                let f0 = self.re.cbrt();
                first!($deriv, let f1 = f0.clone() * &rec * third;);
                second!($deriv, let f2 = f1.clone() * &rec * (third - F::one()););
                third!($deriv, let f3 = f2.clone() * rec * (third - F::one() - F::one()););
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // eˣ is its own derivative of every order.
            #[inline]
            fn exp(&self) -> Self {
                let f = self.re.exp();
                chain_rule!($deriv, Self::chain_rule(self, f.clone(), f.clone(), f.clone(), f))
            }

            // 2ˣ: the k-th derivative is 2ˣ (ln 2)^k.
            #[inline]
            fn exp2(&self) -> Self {
                first!($deriv, let ln2 = F::from(2.0).unwrap().ln(););
                let f0 = self.re.exp2();
                first!($deriv, let f1 = f0.clone() * ln2;);
                second!($deriv, let f2 = f1.clone() * ln2;);
                third!($deriv, let f3 = f2.clone() * ln2;);
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // eˣ - 1: the value uses the accurate `exp_m1`; all derivatives are eˣ.
            #[inline]
            fn exp_m1(&self) -> Self {
                let f0 = self.re.exp_m1();
                first!($deriv, let f1 = self.re.exp(););
                chain_rule!($deriv, Self::chain_rule(self, f0, f1.clone(), f1.clone(), f1))
            }

            // ln x: f1 = 1/x, f2 = -1/x², f3 = 2/x³.
            #[inline]
            fn ln(&self) -> Self {
                first!($deriv, let rec = self.re.recip(););
                let f0 = self.re.ln();
                first!($deriv, let f1 = rec.clone(););
                second!($deriv, let f2 = -f1.clone() * &rec;);
                third!($deriv, let f3 = f2.clone() * rec * F::from(-2.0).unwrap(););
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // log_b x: the derivatives of ln x divided by ln b.
            #[inline]
            fn log(&self, base: F) -> Self {
                first!($deriv, let rec = self.re.recip(););
                let f0 = self.re.log(base);
                first!($deriv, let f1 = rec.clone() / base.ln(););
                second!($deriv, let f2 = -f1.clone() * &rec;);
                third!($deriv, let f3 = f2.clone() * rec * F::from(-2.0).unwrap(););
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // log₂ x: the derivatives of ln x divided by ln 2.
            #[inline]
            fn log2(&self) -> Self {
                first!($deriv, let rec = self.re.recip(););
                let f0 = self.re.log2();
                first!($deriv, let f1 = rec.clone() / (F::one() + F::one()).ln(););
                second!($deriv, let f2 = -f1.clone() * &rec;);
                third!($deriv, let f3 = f2.clone() * rec * F::from(-2.0).unwrap(););
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // log₁₀ x: the derivatives of ln x divided by ln 10.
            #[inline]
            fn log10(&self) -> Self {
                first!($deriv, let rec = self.re.recip(););
                let f0 = self.re.log10();
                first!($deriv, let f1 = rec.clone() / F::from(10.0).unwrap().ln(););
                second!($deriv, let f2 = -f1.clone() * &rec;);
                third!($deriv, let f3 = f2.clone() * rec * F::from(-2.0).unwrap(););
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // ln(1 + x): f1 = 1/(1+x), f2 = -1/(1+x)², f3 = 2/(1+x)³.
            #[inline]
            fn ln_1p(&self) -> Self {
                first!($deriv, let rec = (self.re.clone() + F::one()).recip(););
                let f0 = self.re.ln_1p();
                first!($deriv, let f1 = rec.clone(););
                second!($deriv, let f2 = -f1.clone() * &rec;);
                third!($deriv, let f3 = f2.clone() * rec * F::from(-2.0).unwrap(););
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // sin: f1 = cos, f2 = -sin, f3 = -cos. The zeroth-order build needs only sin.
            #[inline]
            fn sin(&self) -> Self {
                zeroth!($deriv, let s = self.re.sin(););
                first!($deriv, let (s, c) = self.re.sin_cos(););
                chain_rule!($deriv, Self::chain_rule(self, s.clone(), c.clone(), -s, -c))
            }

            // cos: f1 = -sin, f2 = -cos, f3 = sin. The zeroth-order build needs only cos.
            #[inline]
            fn cos(&self) -> Self {
                zeroth!($deriv, let c = self.re.cos(););
                first!($deriv, let (s, c) = self.re.sin_cos(););
                chain_rule!($deriv, Self::chain_rule(self, c.clone(), -s.clone(), -c, s))
            }

            // Both results are built from a single `sin_cos` of the real part.
            #[inline]
            fn sin_cos(&self) -> (Self, Self) {
                let (s, c) = self.re.sin_cos();
                (
                    chain_rule!($deriv, Self::chain_rule(self, s.clone(), c.clone(), -s.clone(), -c.clone())),
                    chain_rule!($deriv, Self::chain_rule(self, c.clone(), -s.clone(), -c, s)),
                )
            }

            // tan = sin / cos, evaluated in dual arithmetic (quotient rule).
            #[inline]
            fn tan(&self) -> Self {
                let (sin, cos) = self.sin_cos();
                sin / cos
            }

            // asin: with r = 1/(1-x²), f1 = √r, f2 = x r^(3/2), f3 = (2x²+1) r^(5/2).
            #[inline]
            fn asin(&self) -> Self {
                first!($deriv, let rec = (T::one() - self.re.clone() * &self.re).recip(););
                let f0 = self.re.asin();
                first!($deriv, let f1 = rec.sqrt(););
                second!($deriv, let f2 = self.re.clone() * &f1 * &rec;);
                third!($deriv, let f3 = (self.re.clone() * &self.re * (F::one() + F::one()) + F::one()) * &f1 * &rec * rec;);
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // acos: the derivatives of asin with opposite sign.
            #[inline]
            fn acos(&self) -> Self {
                first!($deriv, let rec = (T::one() - self.re.clone() * &self.re).recip(););
                let f0 = self.re.acos();
                first!($deriv, let f1 = -rec.sqrt(););
                second!($deriv, let f2 = self.re.clone() * &f1 * &rec;);
                third!($deriv, let f3 = (self.re.clone() * &self.re * (F::one() + F::one()) + F::one()) * &f1 * &rec * rec;);
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // atan: with r = 1/(1+x²), f1 = r, f2 = -2x r², f3 = (6x²-2) r³.
            #[inline]
            fn atan(&self) -> Self {
                first!($deriv, let rec = (T::one() + self.re.clone() * &self.re).recip(););
                let f0 = self.re.atan();
                first!($deriv, let f1 = rec.clone(););
                second!($deriv, let two = F::one() + F::one(););
                second!($deriv, let f2 = -self.re.clone() * &f1 * &rec * two;);
                third!($deriv, let f3 = (self.re.clone() * &self.re * F::from(6.0).unwrap() - two) * &f1 * &rec * rec;);
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // The derivative parts come from atan(self / other); only the real part is
            // replaced with the quadrant-aware atan2.
            // NOTE(review): because of the division, the derivatives are not finite when
            // other's real part is zero, although atan2 itself is defined there.
            #[inline]
            fn atan2(&self, other: Self) -> Self {
                let mut res = (self / other.clone()).atan();
                res.re = self.re.atan2(other.re);
                res
            }

            // sinh: derivatives alternate cosh, sinh, cosh.
            #[inline]
            fn sinh(&self) -> Self {
                let s = self.re.sinh();
                first!($deriv, let c = self.re.cosh(););
                chain_rule!($deriv, Self::chain_rule(self, s.clone(), c.clone(), s, c))
            }

            // cosh: derivatives alternate sinh, cosh, sinh.
            #[inline]
            fn cosh(&self) -> Self {
                first!($deriv, let s = self.re.sinh(););
                let c = self.re.cosh();
                chain_rule!($deriv, Self::chain_rule(self, c.clone(), s.clone(), c, s))
            }

            // tanh = sinh / cosh, evaluated in dual arithmetic.
            #[inline]
            fn tanh(&self) -> Self {
                self.sinh() / self.cosh()
            }

            // asinh: with r = 1/(1+x²), f1 = √r, f2 = -x r^(3/2), f3 = (2x²-1) r^(5/2).
            #[inline]
            fn asinh(&self) -> Self {
                first!($deriv, let rec = (T::one() + self.re.clone() * &self.re).recip(););
                let f0 = self.re.asinh();
                first!($deriv, let f1 = rec.sqrt(););
                second!($deriv, let f2 = -self.re.clone() * &f1 * &rec;);
                third!($deriv, let f3 = (self.re.clone() * &self.re * (F::one() + F::one()) - F::one()) * &f1 * &rec * rec;);
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // acosh: with r = 1/(x²-1), f1 = √r, f2 = -x r^(3/2), f3 = (2x²+1) r^(5/2).
            #[inline]
            fn acosh(&self) -> Self {
                first!($deriv, let rec = (self.re.clone() * &self.re - F::one()).recip(););
                let f0 = self.re.acosh();
                first!($deriv, let f1 = rec.sqrt(););
                second!($deriv, let f2 = -self.re.clone() * &f1 * &rec;);
                third!($deriv, let f3 = (self.re.clone() * &self.re * (F::one() + F::one()) + F::one()) * &f1 * &rec * rec;);
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // atanh: with r = 1/(1-x²), f1 = r, f2 = 2x r², f3 = (6x²+2) r³.
            #[inline]
            fn atanh(&self) -> Self {
                first!($deriv, let rec = (T::one() - self.re.clone() * &self.re).recip(););
                let f0 = self.re.atanh();
                first!($deriv, let f1 = rec.clone(););
                second!($deriv, let two = F::one() + F::one(););
                second!($deriv, let f2 = self.re.clone() * &f1 * &rec * two;);
                third!($deriv, let f3 = (self.re.clone() * &self.re * F::from(6.0).unwrap() + two) * &f1 * &rec * rec;);
                chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
            }

            // Spherical Bessel j0(x) = sin(x)/x. Near x = 0 it is replaced by 1 - x²/6, which
            // matches j0 and its first three derivatives at 0. The guard compares |x| so that
            // negative arguments use the exact quotient instead of the expansion.
            #[inline]
            fn sph_j0(&self) -> Self {
                if self.re().abs() < F::epsilon() {
                    Self::one() - self * self / F::from(6.0).unwrap()
                } else {
                    self.sin() / self
                }
            }

            // Spherical Bessel j1(x) = (sin(x) - x cos(x)) / x². Near x = 0 it is replaced by
            // x/3 - x³/30; the cubic term is needed for the correct third derivative (-1/5)
            // at 0. The guard compares |x| so negative arguments use the closed form.
            // NOTE(review): sin(x) - x cos(x) cancels for small |x| well above epsilon
            // (both terms are ≈ x, the difference is ≈ x³/3); a larger switch-over threshold
            // with more series terms may be more accurate. Verify numerically before changing.
            #[inline]
            fn sph_j1(&self) -> Self {
                if self.re().abs() < F::epsilon() {
                    let x2 = self * self;
                    (self.clone() - x2 * self / F::from(10.0).unwrap()) / F::from(3.0).unwrap()
                } else {
                    let (s, c) = self.sin_cos();
                    (s - self * c) / (self * self)
                }
            }

            // Spherical Bessel j2(x) = (3 (sin(x) - x cos(x)) - x² sin(x)) / x³. Near x = 0 it
            // is replaced by x²/15, which matches j2 and its first three derivatives at 0.
            // The guard compares |x| so negative arguments use the closed form.
            #[inline]
            fn sph_j2(&self) -> Self {
                if self.re().abs() < F::epsilon() {
                    self * self / F::from(15.0).unwrap()
                } else {
                    let (s, c) = self.sin_cos();
                    let s2 = self * self;
                    ((&s - self * c) * F::from(3.0).unwrap() - &s2 * s) / (s2 * self)
                }
            }
        }
    };
}
303
/// Keeps the enclosed tokens only for the `zeroth` derivative order.
///
/// The first argument is one of the order identifiers `zeroth`, `first`, `second` or
/// `third`. For `zeroth` the remaining tokens are emitted unchanged; for every other order
/// the invocation expands to nothing.
#[macro_export]
macro_rules! zeroth {
    (first, $($tokens:tt)*) => {};
    (second, $($tokens:tt)*) => {};
    (third, $($tokens:tt)*) => {};
    (zeroth, $($tokens:tt)*) => ($($tokens)*);
}
313
/// Keeps the enclosed tokens for derivative orders `first`, `second` and `third`.
///
/// The first argument is an order identifier. For `zeroth` the invocation expands to
/// nothing; for any higher order the remaining tokens are emitted unchanged.
#[macro_export]
macro_rules! first {
    (third, $($tokens:tt)*) => ($($tokens)*);
    (second, $($tokens:tt)*) => ($($tokens)*);
    (first, $($tokens:tt)*) => ($($tokens)*);
    (zeroth, $($tokens:tt)*) => {};
}
327
/// Keeps the enclosed tokens for derivative orders `second` and `third`.
///
/// The first argument is an order identifier. For `zeroth` and `first` the invocation
/// expands to nothing; for `second` and `third` the remaining tokens are emitted unchanged.
#[macro_export]
macro_rules! second {
    (third, $($tokens:tt)*) => ($($tokens)*);
    (second, $($tokens:tt)*) => ($($tokens)*);
    (first, $($tokens:tt)*) => {};
    (zeroth, $($tokens:tt)*) => {};
}
339
/// Keeps the enclosed tokens only for the `third` derivative order.
///
/// The first argument is an order identifier. For `zeroth`, `first` and `second` the
/// invocation expands to nothing; for `third` the remaining tokens are emitted unchanged.
#[macro_export]
macro_rules! third {
    (third, $($tokens:tt)*) => ($($tokens)*);
    (second, $($tokens:tt)*) => {};
    (first, $($tokens:tt)*) => {};
    (zeroth, $($tokens:tt)*) => {};
}
349
/// Forwards to `Self::chain_rule` with as many derivative factors as the order needs.
///
/// The call is always written with all four factors `f0..f3`. `zeroth` keeps only `f0`,
/// `first` keeps `f0, f1`, `second` keeps `f0, f1, f2`, and `third` keeps all four. Dropped
/// factors are parsed as expressions but never emitted, so they may name variables that were
/// not declared for the lower orders.
#[macro_export]
macro_rules! chain_rule {
    (third, Self::chain_rule($this:ident, $d0:expr, $d1:expr, $d2:expr, $d3:expr)) => (
        Self::chain_rule($this, $d0, $d1, $d2, $d3)
    );
    (second, Self::chain_rule($this:ident, $d0:expr, $d1:expr, $d2:expr, $d3:expr)) => (
        Self::chain_rule($this, $d0, $d1, $d2)
    );
    (first, Self::chain_rule($this:ident, $d0:expr, $d1:expr, $d2:expr, $d3:expr)) => (
        Self::chain_rule($this, $d0, $d1)
    );
    (zeroth, Self::chain_rule($this:ident, $d0:expr, $d1:expr, $d2:expr, $d3:expr)) => (
        Self::chain_rule($this, $d0)
    );
}
365
/// Implements `DualNum<F>` for `$target` at derivative order `zeroth`.
///
/// Forwards to `impl_derivatives!` with order `zeroth` and `0` as the `$nderiv` argument.
/// The bracketed identifier list and the optional dimension lists are passed through
/// unchanged.
#[macro_export]
macro_rules! impl_zeroth_derivatives {
    ($target:ident, [$($part:ident),*]$(, [$($dim_param:tt),*]$(, [$($alloc_dim:tt),*])*)?) => (
        impl_derivatives!(zeroth, 0, $target, [$($part),*]$(, [$($dim_param),*]$(, [$($alloc_dim),*])*)?);
    );
}
372
/// Implements `DualNum<F>` for `$target` at derivative order `first`.
///
/// Forwards to `impl_derivatives!` with order `first` and `1` as the `$nderiv` argument.
/// The bracketed identifier list and the optional dimension lists are passed through
/// unchanged.
#[macro_export]
macro_rules! impl_first_derivatives {
    ($target:ident, [$($part:ident),*]$(, [$($dim_param:tt),*]$(, [$($alloc_dim:tt),*])*)?) => (
        impl_derivatives!(first, 1, $target, [$($part),*]$(, [$($dim_param),*]$(, [$($alloc_dim),*])*)?);
    );
}
379
/// Implements `DualNum<F>` for `$target` at derivative order `second`.
///
/// Forwards to `impl_derivatives!` with order `second` and `2` as the `$nderiv` argument.
/// The bracketed identifier list and the optional dimension lists are passed through
/// unchanged.
#[macro_export]
macro_rules! impl_second_derivatives {
    ($target:ident, [$($part:ident),*]$(, [$($dim_param:tt),*]$(, [$($alloc_dim:tt),*])*)?) => (
        impl_derivatives!(second, 2, $target, [$($part),*]$(, [$($dim_param),*]$(, [$($alloc_dim),*])*)?);
    );
}
386
/// Implements `DualNum<F>` for `$target` at derivative order `third`.
///
/// Forwards to `impl_derivatives!` with order `third` and `3` as the `$nderiv` argument.
/// The bracketed identifier list and the optional dimension lists are passed through
/// unchanged.
#[macro_export]
macro_rules! impl_third_derivatives {
    ($target:ident, [$($part:ident),*]$(, [$($dim_param:tt),*]$(, [$($alloc_dim:tt),*])*)?) => (
        impl_derivatives!(third, 3, $target, [$($part),*]$(, [$($dim_param),*]$(, [$($alloc_dim),*])*)?);
    );
}