num_dual/lib.rs
1//! Generalized, recursive, scalar and vector (hyper) dual numbers for the automatic and exact calculation of (partial) derivatives.
2//!
3//! # Example
4//! This example defines a generic scalar and a generic vector function that can be called using any (hyper-) dual number and automatically calculates derivatives.
5//! ```
6//! use num_dual::*;
7//! use nalgebra::SVector;
8//!
9//! fn foo<D: DualNum<f64>>(x: D) -> D {
10//! x.powi(3)
11//! }
12//!
13//! fn bar<D: DualNum<f64>, const N: usize>(x: SVector<D, N>) -> D {
14//! x.dot(&x).sqrt()
15//! }
16//!
17//! fn main() {
18//! // Calculate a simple derivative
19//! let (f, df) = first_derivative(foo, 5.0);
20//! assert_eq!(f, 125.0);
21//! assert_eq!(df, 75.0);
22//!
23//! // Manually construct the dual number
24//! let x = Dual64::new(5.0, 1.0);
25//! println!("{}", foo(x)); // 125 + 75ε
26//!
27//! // Calculate a gradient
28//! let (f, g) = gradient(bar, &SVector::from([4.0, 3.0]));
29//! assert_eq!(f, 5.0);
30//! assert_eq!(g[0], 0.8);
31//!
32//! // Calculate a Hessian
33//! let (f, g, h) = hessian(bar, &SVector::from([4.0, 3.0]));
34//! println!("{h}"); // [[0.072, -0.096], [-0.096, 0.128]]
35//!
36//! // for x=cos(t) calculate the third derivative of foo w.r.t. t
37//! let (f0, f1, f2, f3) = third_derivative(|t| foo(t.cos()), 1.0);
38//! println!("{f3}"); // 1.5836632930100278
39//! }
40//! ```
41//!
42//! # Usage
43//! There are two ways to use the data structures and functions provided in this crate:
44//! 1. (recommended) Using the provided functions for explicit ([`first_derivative`], [`gradient`], ...) and
45//! implicit ([`implicit_derivative`], [`implicit_derivative_binary`], [`implicit_derivative_vec`]) functions.
46//! 2. (for experienced users) Using the different dual number types ([`Dual`], [`HyperDual`], [`DualVec`], ...) directly.
47//!
48//! The following examples and explanations focus on the first way.
49//!
50//! # Derivatives of explicit functions
51//! To be able to calculate the derivative of a function, it needs to be generic over the type of dual number used.
52//! Most commonly this would look like this:
53//! ```compile_fail
54//! fn foo<D: DualNum<f64> + Copy>(x: X) -> O {...}
55//! ```
56//! Of course, the function could also use single precision ([`f32`]) or be generic over the precision (`F:` [`DualNumFloat`]).
57//! For now, [`Copy`] is not a supertrait of [`DualNum`] to enable the calculation of derivatives with respect
58//! to a dynamic number of variables. However, in practice, using the [`Copy`] trait bound leads to an
59//! implementation that is more similar to one not using AD and there could be severe performance ramifications
60//! when using dynamically allocated dual numbers.
61//!
62//! The type `X` above is `D` for univariate functions, [`&OVector`](nalgebra::OVector) for multivariate
63//! functions, and `(D, D)` or `(&OVector, &OVector)` for partial derivatives. In the simplest case, the output
64//! `O` is a scalar `D`. However, it is generalized using the [`Mappable`] trait to also include types like
65//! [`Option<D>`] or [`Result<D, E>`], collections like [`Vec<D>`] or [`HashMap<K, D>`], or custom structs that
66//! implement the [`Mappable`] trait. Therefore, it is, e.g., possible to calculate the derivative of a fallible
67//! function:
68//!
69//! ```no_run
70//! # use num_dual::{DualNum, first_derivative};
71//! # type E = ();
72//! fn foo<D: DualNum<f64> + Copy>(x: D) -> Result<D, E> { todo!() }
73//!
74//! fn main() -> Result<(), E> {
75//! let (val, deriv) = first_derivative(foo, 2.0)?;
76//! // ...
77//! Ok(())
78//! }
79//! ```
80//! All dual number types can contain other dual numbers as inner types. Therefore, it is also possible to
81//! use the different derivative functions inside of each other.
82//!
83//! ## Extra arguments
84//! The [`partial`] and [`partial2`] functions are used to pass additional arguments to the function, e.g.:
85//! ```no_run
86//! # use num_dual::{DualNum, first_derivative, partial};
87//! fn foo<D: DualNum<f64> + Copy>(x: D, args: &(D, D)) -> D { todo!() }
88//!
89//! fn main() {
90//! let (val, deriv) = first_derivative(partial(foo, &(3.0, 4.0)), 5.0);
91//! }
92//! ```
93//! All types that implement the [`DualStruct`] trait can be used as additional function arguments. The
//! only difference between using the [`partial`] and [`partial2`] functions and passing the extra
//! arguments via a closure is that the type of the extra arguments is automatically adjusted to the correct
96//! dual number type used for the automatic differentiation. Note that the following code would not compile:
97//! ```compile_fail
98//! # use num_dual::{DualNum, first_derivative};
99//! # fn foo<D: DualNum<f64> + Copy>(x: D, args: &(D, D)) -> D { todo!() }
100//! fn main() {
101//! let (val, deriv) = first_derivative(|x| foo(x, &(3.0, 4.0)), 5.0);
102//! }
103//! ```
104//! The code created by [`partial`] essentially translates to:
105//! ```no_run
106//! # use num_dual::{DualNum, first_derivative, Dual, DualStruct};
107//! # fn foo<D: DualNum<f64> + Copy>(x: D, args: &(D, D)) -> D { todo!() }
108//! fn main() {
109//! let (val, deriv) = first_derivative(|x| foo(x, &(Dual::from_re(3.0), Dual::from_re(4.0))), 5.0);
110//! }
111//! ```
112//!
113//! ## The [`Gradients`] trait
114//! The functions [`gradient`], [`hessian`], [`partial_hessian`] and [`jacobian`] are generic over the dimensionality
115//! of the variable vector. However, to use the functions in a generic context requires not using the [`Copy`] trait
116//! bound on the dual number type, because the dynamically sized dual numbers can by construction not implement
117//! [`Copy`]. Also, due to frequent heap allocations, the performance of the automatic differentiation could
118//! suffer significantly for dynamically sized dual numbers compared to statically sized dual numbers. The
119//! [`Gradients`] trait is introduced to overcome these limitations.
120//! ```
121//! # use num_dual::{DualNum, Gradients};
122//! # use nalgebra::{OVector, DefaultAllocator, allocator::Allocator, vector, dvector};
123//! # use approx::assert_relative_eq;
124//! fn foo<D: DualNum<f64> + Copy, N: Gradients>(x: OVector<D, N>, n: &D) -> D where DefaultAllocator: Allocator<N> {
125//! x.dot(&x).sqrt() - n
126//! }
127//!
128//! fn main() {
129//! let x = vector![1.0, 5.0, 5.0, 7.0];
130//! let (f, grad) = Gradients::gradient(foo, &x, &10.0);
131//! assert_eq!(f, 0.0);
132//! assert_relative_eq!(grad, vector![0.1, 0.5, 0.5, 0.7]);
133//!
134//! let x = dvector![1.0, 5.0, 5.0, 7.0];
135//! let (f, grad) = Gradients::gradient(foo, &x, &10.0);
136//! assert_eq!(f, 0.0);
137//! assert_relative_eq!(grad, dvector![0.1, 0.5, 0.5, 0.7]);
138//! }
139//! ```
140//! For dynamically sized input arrays, the [`Gradients`] trait evaluates gradients or higher-order derivatives
141//! by iteratively evaluating scalar derivatives. For functions that do not rely on the [`Copy`] trait bound,
//! only benchmarking can reveal whether the increased performance through the avoidance of heap allocations
143//! can overcome the overhead of repeated function evaluations, i.e., if [`Gradients`] outperforms directly
144//! calling [`gradient`], [`hessian`], [`partial_hessian`] or [`jacobian`].
145//!
146//! # Derivatives of implicit functions
147//! Implicit differentiation is used to determine the derivative `dy/dx` where the output `y` is only related
148//! implicitly to the input `x` via the equation `f(x,y)=0`. Automatic implicit differentiation generalizes the
149//! idea to determining the output `y` with full derivative information. Note that the first step in calculating
150//! an implicit derivative is always determining the "real" part (i.e., neglecting all derivatives) of the equation
151//! `f(x,y)=0`. The `num-dual` library is focused on automatic differentiation and not nonlinear equation
152//! solving. Therefore, this first step needs to be done with your own custom solutions, or Rust crates for
153//! nonlinear equation solving and optimization like, e.g., [argmin](https://argmin-rs.org/).
154//!
155//! The following example implements a square root for generic dual numbers using implicit differentiation. Of
156//! course, the derivatives of the square root can also be determined explicitly using the chain rule, so the
157//! example serves mostly as illustration. `x.re()` provides the "real" part of the dual number which is a [`f64`]
158//! and therefore, we can use all the functionalities from the std library (including the square root).
159//! ```
160//! # use num_dual::{DualNum, implicit_derivative, first_derivative};
161//! fn implicit_sqrt<D: DualNum<f64> + Copy>(x: D) -> D {
162//! implicit_derivative(|s, x| s * s - x, x.re().sqrt(), &x)
163//! }
164//!
165//! fn main() {
166//! // sanity check, not actually calculating any derivative
167//! assert_eq!(implicit_sqrt(25.0), 5.0);
168//!
169//! let (sq, deriv) = first_derivative(implicit_sqrt, 25.0);
170//! assert_eq!(sq, 5.0);
171//! // The derivative of sqrt(x) is 1/(2*sqrt(x)) which should evaluate to 0.1
172//! assert_eq!(deriv, 0.1);
173//! }
174//! ```
175//! The `implicit_sqrt` or any likewise defined function is generic over the dual type `D`
//! and can, therefore, be used anywhere as a part of an arbitrarily complex computation. The functions
177//! [`implicit_derivative_binary`] and [`implicit_derivative_vec`] can be used for implicit functions
178//! with more than one variable.
179//!
180//! For implicit functions that contain complex models and a large number of parameters, the [`ImplicitDerivative`]
181//! interface might come in handy. The idea is to define the implicit function using the [`ImplicitFunction`] trait
//! and to feed it into the [`ImplicitDerivative`] struct, which internally stores the parameters as dual numbers
183//! and their real parts. The [`ImplicitDerivative`] then provides methods for the evaluation of the real part
184//! of the residual (which can be passed to a nonlinear solver) and the implicit derivative which can be called
185//! after solving for the real part of the solution to reconstruct all the derivatives.
186//! ```
187//! # use num_dual::{ImplicitFunction, DualNum, Dual, ImplicitDerivative};
188//! struct ImplicitSqrt;
189//! impl ImplicitFunction<f64> for ImplicitSqrt {
190//! type Parameters<D> = D;
191//! type Variable<D> = D;
192//! fn residual<D: DualNum<f64> + Copy>(x: D, square: &D) -> D {
193//! *square - x * x
194//! }
195//! }
196//!
197//! fn main() {
198//! let x = Dual::from_re(25.0).derivative();
199//! let func = ImplicitDerivative::new(ImplicitSqrt, x);
200//! assert_eq!(func.residual(5.0), 0.0);
201//! assert_eq!(x.sqrt(), func.implicit_derivative(5.0));
202//! }
203//! ```
204//!
205//! ## Combination with nonlinear solver libraries
206//! As mentioned previously, this crate does not contain any algorithms for nonlinear optimization or root finding.
207//! However, combining the capabilities of automatic differentiation with nonlinear solving can be very fruitful.
208//! Most importantly, the calculation of Jacobians or Hessians can be completely automated, if the model can be
//! expressed within the functionalities of the [`DualNum`] trait. On top of that, implicit derivatives can be of
210//! interest, if derivatives of the result of the optimization itself are relevant (e.g., in a bilevel
211//! optimization). The synergy is exploited in the [`ipopt-ad`](https://github.com/prehner/ipopt-ad) crate that
212//! turns the NLP solver [IPOPT](https://github.com/coin-or/Ipopt) into a black-box optimization algorithm (i.e.,
//! it only requires a function that returns the values of the objective function and constraints), without
214//! any repercussions regarding the robustness or speed of convergence of the solver.
215//!
216//! If you are developing nonlinear optimization algorithms in Rust, feel free to reach out to us. We are happy to
217//! discuss how to enhance your algorithms with the automatic differentiation capabilities of this crate.
218
219#![warn(clippy::all)]
220#![warn(clippy::allow_attributes)]
221
222use nalgebra::allocator::Allocator;
223use nalgebra::{DefaultAllocator, Dim, OMatrix, Scalar};
224#[cfg(feature = "ndarray")]
225use ndarray::ScalarOperand;
226use num_traits::{Float, FloatConst, FromPrimitive, Inv, NumAssignOps, NumOps, Signed};
227use std::collections::HashMap;
228use std::fmt;
229use std::hash::Hash;
230use std::iter::{Product, Sum};
231
232#[macro_use]
233mod macros;
234#[macro_use]
235mod nalgebra_macros;
236#[macro_use]
237mod impl_derivatives;
238
239mod bessel;
240mod datatypes;
241mod explicit;
242mod implicit;
243pub use bessel::BesselDual;
244pub use datatypes::derivative::Derivative;
245pub use datatypes::dual::{Dual, Dual32, Dual64};
246pub use datatypes::dual_vec::{
247 DualDVec32, DualDVec64, DualSVec, DualSVec32, DualSVec64, DualVec, DualVec32, DualVec64,
248};
249pub use datatypes::dual2::{Dual2, Dual2_32, Dual2_64};
250pub use datatypes::dual2_vec::{
251 Dual2DVec, Dual2DVec32, Dual2DVec64, Dual2SVec, Dual2SVec32, Dual2SVec64, Dual2Vec, Dual2Vec32,
252 Dual2Vec64,
253};
254pub use datatypes::dual3::{Dual3, Dual3_32, Dual3_64};
255pub use datatypes::hyperdual::{HyperDual, HyperDual32, HyperDual64};
256pub use datatypes::hyperdual_vec::{
257 HyperDualDVec32, HyperDualDVec64, HyperDualSVec32, HyperDualSVec64, HyperDualVec,
258 HyperDualVec32, HyperDualVec64,
259};
260pub use datatypes::hyperhyperdual::{HyperHyperDual, HyperHyperDual32, HyperHyperDual64};
261pub use datatypes::real::Real;
262pub use explicit::{
263 Gradients, first_derivative, gradient, hessian, jacobian, partial, partial_hessian, partial2,
264 partial3, second_derivative, second_partial_derivative, third_derivative,
265 third_partial_derivative, third_partial_derivative_vec, zeroth_derivative,
266};
267pub use implicit::{
268 ImplicitDerivative, ImplicitFunction, implicit_derivative, implicit_derivative_binary,
269 implicit_derivative_sp, implicit_derivative_vec,
270};
271
272pub mod linalg;
273
274#[cfg(feature = "python")]
275pub mod python;
276
277#[cfg(feature = "python_macro")]
278mod python_macro;
279
280/// A generalized (hyper) dual number.
281#[cfg(feature = "ndarray")]
282pub trait DualNum<F>:
283 NumOps
284 + for<'r> NumOps<&'r Self>
285 + Signed
286 + NumOps<F>
287 + NumAssignOps
288 + NumAssignOps<F>
289 + Clone
290 + Inv<Output = Self>
291 + Sum
292 + Product
293 + FromPrimitive
294 + From<F>
295 + DualStruct<F, Real = F>
296 + Mappable<Self>
297 + fmt::Display
298 + PartialOrd
299 + PartialOrd<F>
300 + fmt::Debug
301 + ScalarOperand
302 + 'static
303{
304 /// Highest derivative that can be calculated with this struct
305 const NDERIV: usize;
306
307 /// The type of the individual elements of this dual number
308 type InnerDual: DualNum<F>;
309
310 /// Build a dual number from its real part, setting all other values to 0
311 fn from_re(re: Self::InnerDual) -> Self;
312
313 /// Reciprocal (inverse) of a number `1/x`
314 fn recip(&self) -> Self;
315
316 /// Power with integer exponent `x^n`
317 fn powi(&self, n: i32) -> Self;
318
319 /// Power with real exponent `x^n`
320 fn powf(&self, n: F) -> Self;
321
322 /// Square root
323 fn sqrt(&self) -> Self;
324
325 /// Cubic root
326 fn cbrt(&self) -> Self;
327
328 /// Exponential `e^x`
329 fn exp(&self) -> Self;
330
331 /// Exponential with base 2 `2^x`
332 fn exp2(&self) -> Self;
333
334 /// Exponential minus 1 `e^x-1`
335 fn exp_m1(&self) -> Self;
336
337 /// Natural logarithm
338 fn ln(&self) -> Self;
339
340 /// Logarithm with arbitrary base
341 fn log(&self, base: F) -> Self;
342
343 /// Logarithm with base 2
344 fn log2(&self) -> Self;
345
346 /// Logarithm with base 10
347 fn log10(&self) -> Self;
348
349 /// Logarithm on x plus one `ln(1+x)`
350 fn ln_1p(&self) -> Self;
351
352 /// Sine
353 fn sin(&self) -> Self;
354
355 /// Cosine
356 fn cos(&self) -> Self;
357
358 /// Tangent
359 fn tan(&self) -> Self;
360
361 /// Calculate sine and cosine simultaneously
362 fn sin_cos(&self) -> (Self, Self);
363
364 /// Arcsine
365 fn asin(&self) -> Self;
366
367 /// Arccosine
368 fn acos(&self) -> Self;
369
370 /// Arctangent
371 fn atan(&self) -> Self;
372
373 /// Arctangent
374 fn atan2(&self, other: Self) -> Self;
375
376 /// Hyperbolic sine
377 fn sinh(&self) -> Self;
378
379 /// Hyperbolic cosine
380 fn cosh(&self) -> Self;
381
382 /// Hyperbolic tangent
383 fn tanh(&self) -> Self;
384
385 /// Area hyperbolic sine
386 fn asinh(&self) -> Self;
387
388 /// Area hyperbolic cosine
389 fn acosh(&self) -> Self;
390
391 /// Area hyperbolic tangent
392 fn atanh(&self) -> Self;
393
394 /// 0th order spherical Bessel function of the first kind
395 fn sph_j0(&self) -> Self;
396
397 /// 1st order spherical Bessel function of the first kind
398 fn sph_j1(&self) -> Self;
399
400 /// 2nd order spherical Bessel function of the first kind
401 fn sph_j2(&self) -> Self;
402
403 /// Fused multiply-add
404 #[inline]
405 fn mul_add(&self, a: Self, b: Self) -> Self {
406 self.clone() * a + b
407 }
408
409 /// Power with dual exponent `x^n`
410 #[inline]
411 fn powd(&self, exp: Self) -> Self {
412 (self.ln() * exp).exp()
413 }
414}
415
416/// A generalized (hyper) dual number.
417#[cfg(not(feature = "ndarray"))]
418pub trait DualNum<F>:
419 NumOps
420 + for<'r> NumOps<&'r Self>
421 + Signed
422 + NumOps<F>
423 + NumAssignOps
424 + NumAssignOps<F>
425 + Clone
426 + Inv<Output = Self>
427 + Sum
428 + Product
429 + FromPrimitive
430 + From<F>
431 + DualStruct<F, Real = F>
432 + Mappable<Self>
433 + fmt::Display
434 + PartialOrd
435 + PartialOrd<F>
436 + fmt::Debug
437 + 'static
438{
439 /// Highest derivative that can be calculated with this struct
440 const NDERIV: usize;
441
442 /// The type of the individual elements of this dual number
443 type InnerDual: DualNum<F>;
444
445 /// Build a dual number from its real part, setting all other values to 0
446 fn from_re(re: Self::InnerDual) -> Self;
447
448 /// Reciprocal (inverse) of a number `1/x`
449 fn recip(&self) -> Self;
450
451 /// Power with integer exponent `x^n`
452 fn powi(&self, n: i32) -> Self;
453
454 /// Power with real exponent `x^n`
455 fn powf(&self, n: F) -> Self;
456
457 /// Square root
458 fn sqrt(&self) -> Self;
459
460 /// Cubic root
461 fn cbrt(&self) -> Self;
462
463 /// Exponential `e^x`
464 fn exp(&self) -> Self;
465
466 /// Exponential with base 2 `2^x`
467 fn exp2(&self) -> Self;
468
469 /// Exponential minus 1 `e^x-1`
470 fn exp_m1(&self) -> Self;
471
472 /// Natural logarithm
473 fn ln(&self) -> Self;
474
475 /// Logarithm with arbitrary base
476 fn log(&self, base: F) -> Self;
477
478 /// Logarithm with base 2
479 fn log2(&self) -> Self;
480
481 /// Logarithm with base 10
482 fn log10(&self) -> Self;
483
484 /// Logarithm on x plus one `ln(1+x)`
485 fn ln_1p(&self) -> Self;
486
487 /// Sine
488 fn sin(&self) -> Self;
489
490 /// Cosine
491 fn cos(&self) -> Self;
492
493 /// Tangent
494 fn tan(&self) -> Self;
495
496 /// Calculate sine and cosine simultaneously
497 fn sin_cos(&self) -> (Self, Self);
498
499 /// Arcsine
500 fn asin(&self) -> Self;
501
502 /// Arccosine
503 fn acos(&self) -> Self;
504
505 /// Arctangent
506 fn atan(&self) -> Self;
507
508 /// Arctangent
509 fn atan2(&self, other: Self) -> Self;
510
511 /// Hyperbolic sine
512 fn sinh(&self) -> Self;
513
514 /// Hyperbolic cosine
515 fn cosh(&self) -> Self;
516
517 /// Hyperbolic tangent
518 fn tanh(&self) -> Self;
519
520 /// Area hyperbolic sine
521 fn asinh(&self) -> Self;
522
523 /// Area hyperbolic cosine
524 fn acosh(&self) -> Self;
525
526 /// Area hyperbolic tangent
527 fn atanh(&self) -> Self;
528
529 /// 0th order spherical Bessel function of the first kind
530 fn sph_j0(&self) -> Self;
531
532 /// 1st order spherical Bessel function of the first kind
533 fn sph_j1(&self) -> Self;
534
535 /// 2nd order spherical Bessel function of the first kind
536 fn sph_j2(&self) -> Self;
537
538 /// Fused multiply-add
539 #[inline]
540 fn mul_add(&self, a: Self, b: Self) -> Self {
541 self.clone() * a + b
542 }
543
544 /// Power with dual exponent `x^n`
545 #[inline]
546 fn powd(&self, exp: Self) -> Self {
547 (self.ln() * exp).exp()
548 }
549}
550
551/// A generalized (hyper) dual number that has a static size.
552pub trait DualNumCopy<F>: DualNum<F> + Copy + Send + Sync {}
553impl<T: DualNum<F> + Copy + Send + Sync, F> DualNumCopy<F> for T {}
554
555/// The underlying data type of individual derivatives. Usually f32 or f64.
556pub trait DualNumFloat:
557 Float + FloatConst + FromPrimitive + Signed + fmt::Display + fmt::Debug + Sync + Send + 'static
558{
559}
560impl<T> DualNumFloat for T where
561 T: Float
562 + FloatConst
563 + FromPrimitive
564 + Signed
565 + fmt::Display
566 + fmt::Debug
567 + Sync
568 + Send
569 + 'static
570{
571}
572
573macro_rules! impl_dual_num_float {
574 ($float:ty) => {
575 impl DualNum<$float> for $float {
576 const NDERIV: usize = 0;
577
578 type InnerDual = $float;
579 fn from_re(re: $float) -> Self {
580 re
581 }
582
583 fn mul_add(&self, a: Self, b: Self) -> Self {
584 <$float>::mul_add(*self, a, b)
585 }
586 fn recip(&self) -> Self {
587 <$float>::recip(*self)
588 }
589 fn powi(&self, n: i32) -> Self {
590 <$float>::powi(*self, n)
591 }
592 fn powf(&self, n: Self) -> Self {
593 <$float>::powf(*self, n)
594 }
595 fn powd(&self, n: Self) -> Self {
596 <$float>::powf(*self, n)
597 }
598 fn sqrt(&self) -> Self {
599 <$float>::sqrt(*self)
600 }
601 fn exp(&self) -> Self {
602 <$float>::exp(*self)
603 }
604 fn exp2(&self) -> Self {
605 <$float>::exp2(*self)
606 }
607 fn ln(&self) -> Self {
608 <$float>::ln(*self)
609 }
610 fn log(&self, base: Self) -> Self {
611 <$float>::log(*self, base)
612 }
613 fn log2(&self) -> Self {
614 <$float>::log2(*self)
615 }
616 fn log10(&self) -> Self {
617 <$float>::log10(*self)
618 }
619 fn cbrt(&self) -> Self {
620 <$float>::cbrt(*self)
621 }
622 fn sin(&self) -> Self {
623 <$float>::sin(*self)
624 }
625 fn cos(&self) -> Self {
626 <$float>::cos(*self)
627 }
628 fn tan(&self) -> Self {
629 <$float>::tan(*self)
630 }
631 fn asin(&self) -> Self {
632 <$float>::asin(*self)
633 }
634 fn acos(&self) -> Self {
635 <$float>::acos(*self)
636 }
637 fn atan(&self) -> Self {
638 <$float>::atan(*self)
639 }
640 fn atan2(&self, other: $float) -> Self {
641 <$float>::atan2(*self, other)
642 }
643 fn sin_cos(&self) -> (Self, Self) {
644 <$float>::sin_cos(*self)
645 }
646 fn exp_m1(&self) -> Self {
647 <$float>::exp_m1(*self)
648 }
649 fn ln_1p(&self) -> Self {
650 <$float>::ln_1p(*self)
651 }
652 fn sinh(&self) -> Self {
653 <$float>::sinh(*self)
654 }
655 fn cosh(&self) -> Self {
656 <$float>::cosh(*self)
657 }
658 fn tanh(&self) -> Self {
659 <$float>::tanh(*self)
660 }
661 fn asinh(&self) -> Self {
662 <$float>::asinh(*self)
663 }
664 fn acosh(&self) -> Self {
665 <$float>::acosh(*self)
666 }
667 fn atanh(&self) -> Self {
668 <$float>::atanh(*self)
669 }
670 fn sph_j0(&self) -> Self {
671 if self.abs() < <$float>::EPSILON {
672 1.0 - self * self / 6.0
673 } else {
674 self.sin() / self
675 }
676 }
677 fn sph_j1(&self) -> Self {
678 if self.abs() < <$float>::EPSILON {
679 self / 3.0
680 } else {
681 let sc = self.sin_cos();
682 let rec = self.recip();
683 (sc.0 * rec - sc.1) * rec
684 }
685 }
686 fn sph_j2(&self) -> Self {
687 if self.abs() < <$float>::EPSILON {
688 self * self / 15.0
689 } else {
690 let sc = self.sin_cos();
691 let s2 = self * self;
692 ((3.0 - s2) * sc.0 - 3.0 * self * sc.1) / (self * s2)
693 }
694 }
695 }
696 };
697}
698
699impl_dual_num_float!(f32);
700impl_dual_num_float!(f64);
701
702/// A struct that contains dual numbers. Needed for arbitrary arguments in [ImplicitFunction].
703///
704/// The trait is implemented for all dual types themselves, and common data types (tuple, vec,
705/// array, ...) and can be implemented for custom data types to achieve full flexibility.
706pub trait DualStruct<F> {
707 type Real;
708 type Inner: DualStruct<F>;
709 fn re(&self) -> Self::Real;
710 fn from_inner(inner: &Self::Inner) -> Self;
711}
712
713/// Trait for structs used as an output of functions for which derivatives are calculated.
714///
715/// The main intention is to generalize the calculation of derivatives to fallible functions, but
716/// other use cases might also appear in the future.
717pub trait Mappable<D> {
718 type Output<O>;
719 fn map_dual<M: Fn(D) -> O, O>(self, f: M) -> Self::Output<O>;
720}
721
722impl<F> DualStruct<F> for () {
723 type Real = ();
724 type Inner = ();
725 fn re(&self) {}
726 fn from_inner(_: &Self::Inner) -> Self {}
727}
728
729impl<D> Mappable<D> for () {
730 type Output<O> = ();
731 fn map_dual<M: FnOnce(D) -> O, O>(self, _: M) {}
732}
733
734impl DualStruct<f32> for f32 {
735 type Real = f32;
736 type Inner = f32;
737 fn re(&self) -> f32 {
738 *self
739 }
740 fn from_inner(inner: &Self::Inner) -> Self {
741 *inner
742 }
743}
744
745impl Mappable<f32> for f32 {
746 type Output<O> = O;
747 fn map_dual<M: FnOnce(f32) -> O, O>(self, f: M) -> Self::Output<O> {
748 f(self)
749 }
750}
751
752impl DualStruct<f64> for f64 {
753 type Real = f64;
754 type Inner = f64;
755 fn re(&self) -> f64 {
756 *self
757 }
758 fn from_inner(inner: &Self::Inner) -> Self {
759 *inner
760 }
761}
762
763impl Mappable<f64> for f64 {
764 type Output<O> = O;
765 fn map_dual<M: FnOnce(f64) -> O, O>(self, f: M) -> Self::Output<O> {
766 f(self)
767 }
768}
769
770impl<T1: DualStruct<F>, T2: DualStruct<F>, F> DualStruct<F> for (T1, T2) {
771 type Real = (T1::Real, T2::Real);
772 type Inner = (T1::Inner, T2::Inner);
773 fn re(&self) -> Self::Real {
774 let (s1, s2) = self;
775 (s1.re(), s2.re())
776 }
777 fn from_inner(re: &Self::Inner) -> Self {
778 let (r1, r2) = re;
779 (T1::from_inner(r1), T2::from_inner(r2))
780 }
781}
782
783impl<D, T1: Mappable<D>, T2: Mappable<D>> Mappable<D> for (T1, T2) {
784 type Output<O> = (T1::Output<O>, T2::Output<O>);
785 fn map_dual<M: Fn(D) -> O, O>(self, f: M) -> Self::Output<O> {
786 let (s1, s2) = self;
787 (s1.map_dual(&f), s2.map_dual(&f))
788 }
789}
790
791impl<F, T1: DualStruct<F>, T2: DualStruct<F>, T3: DualStruct<F>> DualStruct<F> for (T1, T2, T3) {
792 type Real = (T1::Real, T2::Real, T3::Real);
793 type Inner = (T1::Inner, T2::Inner, T3::Inner);
794 fn re(&self) -> Self::Real {
795 let (s1, s2, s3) = self;
796 (s1.re(), s2.re(), s3.re())
797 }
798 fn from_inner(inner: &Self::Inner) -> Self {
799 let (r1, r2, r3) = inner;
800 (T1::from_inner(r1), T2::from_inner(r2), T3::from_inner(r3))
801 }
802}
803
804impl<D, T1: Mappable<D>, T2: Mappable<D>, T3: Mappable<D>> Mappable<D> for (T1, T2, T3) {
805 type Output<O> = (T1::Output<O>, T2::Output<O>, T3::Output<O>);
806 fn map_dual<M: Fn(D) -> O, O>(self, f: M) -> Self::Output<O> {
807 let (s1, s2, s3) = self;
808 (s1.map_dual(&f), s2.map_dual(&f), s3.map_dual(&f))
809 }
810}
811
812impl<F, T1: DualStruct<F>, T2: DualStruct<F>, T3: DualStruct<F>, T4: DualStruct<F>> DualStruct<F>
813 for (T1, T2, T3, T4)
814{
815 type Real = (T1::Real, T2::Real, T3::Real, T4::Real);
816 type Inner = (T1::Inner, T2::Inner, T3::Inner, T4::Inner);
817 fn re(&self) -> Self::Real {
818 let (s1, s2, s3, s4) = self;
819 (s1.re(), s2.re(), s3.re(), s4.re())
820 }
821 fn from_inner(inner: &Self::Inner) -> Self {
822 let (r1, r2, r3, r4) = inner;
823 (
824 T1::from_inner(r1),
825 T2::from_inner(r2),
826 T3::from_inner(r3),
827 T4::from_inner(r4),
828 )
829 }
830}
831
832impl<D, T1: Mappable<D>, T2: Mappable<D>, T3: Mappable<D>, T4: Mappable<D>> Mappable<D>
833 for (T1, T2, T3, T4)
834{
835 type Output<O> = (T1::Output<O>, T2::Output<O>, T3::Output<O>, T4::Output<O>);
836 fn map_dual<M: Fn(D) -> O, O>(self, f: M) -> Self::Output<O> {
837 let (s1, s2, s3, s4) = self;
838 (
839 s1.map_dual(&f),
840 s2.map_dual(&f),
841 s3.map_dual(&f),
842 s4.map_dual(&f),
843 )
844 }
845}
846
847impl<
848 F,
849 T1: DualStruct<F>,
850 T2: DualStruct<F>,
851 T3: DualStruct<F>,
852 T4: DualStruct<F>,
853 T5: DualStruct<F>,
854> DualStruct<F> for (T1, T2, T3, T4, T5)
855{
856 type Real = (T1::Real, T2::Real, T3::Real, T4::Real, T5::Real);
857 type Inner = (T1::Inner, T2::Inner, T3::Inner, T4::Inner, T5::Inner);
858 fn re(&self) -> Self::Real {
859 let (s1, s2, s3, s4, s5) = self;
860 (s1.re(), s2.re(), s3.re(), s4.re(), s5.re())
861 }
862 fn from_inner(inner: &Self::Inner) -> Self {
863 let (r1, r2, r3, r4, r5) = inner;
864 (
865 T1::from_inner(r1),
866 T2::from_inner(r2),
867 T3::from_inner(r3),
868 T4::from_inner(r4),
869 T5::from_inner(r5),
870 )
871 }
872}
873
874impl<D, T1: Mappable<D>, T2: Mappable<D>, T3: Mappable<D>, T4: Mappable<D>, T5: Mappable<D>>
875 Mappable<D> for (T1, T2, T3, T4, T5)
876{
877 type Output<O> = (
878 T1::Output<O>,
879 T2::Output<O>,
880 T3::Output<O>,
881 T4::Output<O>,
882 T5::Output<O>,
883 );
884 fn map_dual<M: Fn(D) -> O, O>(self, f: M) -> Self::Output<O> {
885 let (s1, s2, s3, s4, s5) = self;
886 (
887 s1.map_dual(&f),
888 s2.map_dual(&f),
889 s3.map_dual(&f),
890 s4.map_dual(&f),
891 s5.map_dual(&f),
892 )
893 }
894}
895
896impl<F, T: DualStruct<F>, const N: usize> DualStruct<F> for [T; N] {
897 type Real = [T::Real; N];
898 type Inner = [T::Inner; N];
899 fn re(&self) -> Self::Real {
900 self.each_ref().map(|x| x.re())
901 }
902 fn from_inner(re: &Self::Inner) -> Self {
903 re.each_ref().map(T::from_inner)
904 }
905}
906
907impl<D, T: Mappable<D>, const N: usize> Mappable<D> for [T; N] {
908 type Output<O> = [T::Output<O>; N];
909 fn map_dual<M: Fn(D) -> O, O>(self, f: M) -> Self::Output<O> {
910 self.map(|x| x.map_dual(&f))
911 }
912}
913
914impl<F, T: DualStruct<F>> DualStruct<F> for Option<T> {
915 type Real = Option<T::Real>;
916 type Inner = Option<T::Inner>;
917 fn re(&self) -> Self::Real {
918 self.as_ref().map(|x| x.re())
919 }
920 fn from_inner(inner: &Self::Inner) -> Self {
921 inner.as_ref().map(|x| T::from_inner(x))
922 }
923}
924
925impl<D, T: Mappable<D>> Mappable<D> for Option<T> {
926 type Output<O> = Option<T::Output<O>>;
927 fn map_dual<M: Fn(D) -> O, O>(self, f: M) -> Self::Output<O> {
928 self.map(|x| x.map_dual(f))
929 }
930}
931
932impl<D, T: Mappable<D>, E> Mappable<D> for Result<T, E> {
933 type Output<O> = Result<T::Output<O>, E>;
934 fn map_dual<M: Fn(D) -> O, O>(self, f: M) -> Self::Output<O> {
935 self.map(|x| x.map_dual(f))
936 }
937}
938
939impl<F, T: DualStruct<F>> DualStruct<F> for Vec<T> {
940 type Real = Vec<T::Real>;
941 type Inner = Vec<T::Inner>;
942 fn re(&self) -> Self::Real {
943 self.iter().map(|x| x.re()).collect()
944 }
945 fn from_inner(inner: &Self::Inner) -> Self {
946 inner.iter().map(|x| T::from_inner(x)).collect()
947 }
948}
949
950impl<D, T: Mappable<D>> Mappable<D> for Vec<T> {
951 type Output<O> = Vec<T::Output<O>>;
952 fn map_dual<M: Fn(D) -> O, O>(self, f: M) -> Self::Output<O> {
953 self.into_iter().map(|x| x.map_dual(&f)).collect()
954 }
955}
956
957impl<F, T: DualStruct<F>, K: Clone + Eq + Hash> DualStruct<F> for HashMap<K, T> {
958 type Real = HashMap<K, T::Real>;
959 type Inner = HashMap<K, T::Inner>;
960 fn re(&self) -> Self::Real {
961 self.iter().map(|(k, x)| (k.clone(), x.re())).collect()
962 }
963 fn from_inner(inner: &Self::Inner) -> Self {
964 inner
965 .iter()
966 .map(|(k, x)| (k.clone(), T::from_inner(x)))
967 .collect()
968 }
969}
970
971impl<D, T: Mappable<D>, K: Eq + Hash> Mappable<D> for HashMap<K, T> {
972 type Output<O> = HashMap<K, T::Output<O>>;
973 fn map_dual<M: Fn(D) -> O, O>(self, f: M) -> Self::Output<O> {
974 self.into_iter().map(|(k, x)| (k, x.map_dual(&f))).collect()
975 }
976}
977
978impl<F: DualNumFloat, D: DualNum<F>, R: Dim, C: Dim> DualStruct<F> for OMatrix<D, R, C>
979where
980 DefaultAllocator: Allocator<R, C>,
981{
982 type Real = OMatrix<D::Real, R, C>;
983 type Inner = OMatrix<D::InnerDual, R, C>;
984 fn re(&self) -> Self::Real {
985 self.map(|x| x.re())
986 }
987 fn from_inner(inner: &Self::Inner) -> Self {
988 inner.map(|x| DualNum::from_re(x))
989 }
990}
991
992impl<D: Scalar, R: Dim, C: Dim> Mappable<Self> for OMatrix<D, R, C>
993where
994 DefaultAllocator: Allocator<R, C>,
995{
996 type Output<O> = O;
997 fn map_dual<M: Fn(Self) -> O, O>(self, f: M) -> O {
998 f(self)
999 }
1000}