num_dual/
lib.rs

1//! Generalized, recursive, scalar and vector (hyper) dual numbers for the automatic and exact calculation of (partial) derivatives.
2//!
3//! # Example
4//! This example defines a generic scalar function and a generic vector function that can be called with any (hyper-) dual number, and automatically calculates their derivatives.
5//! ```
6//! use num_dual::*;
7//! use nalgebra::SVector;
8//!
9//! fn foo<D: DualNum<f64>>(x: D) -> D {
10//!     x.powi(3)
11//! }
12//!
13//! fn bar<D: DualNum<f64>, const N: usize>(x: SVector<D, N>) -> D {
14//!     x.dot(&x).sqrt()
15//! }
16//!
17//! fn main() {
18//!     // Calculate a simple derivative
19//!     let (f, df) = first_derivative(foo, 5.0);
20//!     assert_eq!(f, 125.0);
21//!     assert_eq!(df, 75.0);
22//!
23//!     // Manually construct the dual number
24//!     let x = Dual64::new(5.0, 1.0);
25//!     println!("{}", foo(x));                     // 125 + 75ε
26//!
27//!     // Calculate a gradient
28//!     let (f, g) = gradient(bar, &SVector::from([4.0, 3.0]));
29//!     assert_eq!(f, 5.0);
30//!     assert_eq!(g[0], 0.8);
31//!
32//!     // Calculate a Hessian
33//!     let (f, g, h) = hessian(bar, &SVector::from([4.0, 3.0]));
34//!     println!("{h}");                            // [[0.072, -0.096], [-0.096, 0.128]]
35//!
36//!     // for x=cos(t) calculate the third derivative of foo w.r.t. t
37//!     let (f0, f1, f2, f3) = third_derivative(|t| foo(t.cos()), 1.0);
38//!     println!("{f3}");                           // 1.5836632930100278
39//! }
40//! ```
41//!
42//! # Usage
43//! There are two ways to use the data structures and functions provided in this crate:
44//! 1. (recommended) Using the provided functions for derivatives of explicit functions ([`first_derivative`], [`gradient`], ...) and
45//!    implicit functions ([`implicit_derivative`], [`implicit_derivative_binary`], [`implicit_derivative_vec`]).
46//! 2. (for experienced users) Using the different dual number types ([`Dual`], [`HyperDual`], [`DualVec`], ...) directly.
47//!
48//! The following examples and explanations focus on the first way.
49//!
50//! # Derivatives of explicit functions
51//! To be able to calculate the derivative of a function, it needs to be generic over the type of dual number used.
52//! Most commonly this would look like this:
53//! ```compile_fail
54//! fn foo<D: DualNum<f64> + Copy>(x: X) -> O {...}
55//! ```
56//! Of course, the function could also use single precision ([`f32`]) or be generic over the precision (`F:` [`DualNumFloat`]).
57//! For now, [`Copy`] is not a supertrait of [`DualNum`] to enable the calculation of derivatives with respect
58//! to a dynamic number of variables. However, in practice, using the [`Copy`] trait bound leads to an
59//! implementation that is more similar to one not using AD and there could be severe performance ramifications
60//! when using dynamically allocated dual numbers.
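//!
//! The reason for the generic signature can be illustrated without any of this crate's types. The following is a
//! minimal sketch, assuming a hypothetical `MiniDual` struct (not the crate's [`Dual`]) that only supports
//! multiplication; `cube` and `cube_with_derivative` are likewise illustrative names.
//! ```rust
//! use std::ops::Mul;
//!
//! // Hypothetical dual number re + eps*ε with ε² = 0; for illustration only.
//! #[derive(Clone, Copy, Debug, PartialEq)]
//! struct MiniDual {
//!     re: f64,
//!     eps: f64,
//! }
//!
//! impl Mul for MiniDual {
//!     type Output = Self;
//!     fn mul(self, rhs: Self) -> Self {
//!         // Product rule: (a + a'ε)(b + b'ε) = ab + (a'b + ab')ε
//!         MiniDual {
//!             re: self.re * rhs.re,
//!             eps: self.eps * rhs.re + self.re * rhs.eps,
//!         }
//!     }
//! }
//!
//! // Generic over the number type, so it accepts both `f64` and `MiniDual`.
//! fn cube<T: Mul<Output = T> + Copy>(x: T) -> T {
//!     x * x * x
//! }
//!
//! // Seeds the derivative part with 1 and reads value and derivative off the result.
//! fn cube_with_derivative(x: f64) -> (f64, f64) {
//!     let y = cube(MiniDual { re: x, eps: 1.0 });
//!     (y.re, y.eps)
//! }
//! ```
//! In this sketch, `cube_with_derivative(5.0)` evaluates to `(125.0, 75.0)`, while `cube(5.0_f64)` still works on
//! plain floats because the function never names a concrete number type.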
61//!
62//! The type `X` above is `D` for univariate functions, [`&OVector`](nalgebra::OVector) for multivariate
63//! functions, and `(D, D)` or `(&OVector, &OVector)` for partial derivatives. In the simplest case, the output
64//! `O` is a scalar `D`. However, it is generalized using the [`Mappable`] trait to also include types like
65//! [`Option<D>`] or [`Result<D, E>`], collections like [`Vec<D>`] or [`HashMap<K, D>`], or custom structs that
66//! implement the [`Mappable`] trait. Therefore, it is, e.g., possible to calculate the derivative of a fallible
67//! function:
68//!
69//! ```no_run
70//! # use num_dual::{DualNum, first_derivative};
71//! # type E = ();
72//! fn foo<D: DualNum<f64> + Copy>(x: D) -> Result<D, E> { todo!() }
73//!
74//! fn main() -> Result<(), E> {
75//!     let (val, deriv) = first_derivative(foo, 2.0)?;
76//!     // ...
77//!     Ok(())
78//! }
79//! ```
80//! All dual number types can contain other dual numbers as inner types. Therefore, it is also possible to
81//! use the different derivative functions inside of each other.
82//!
83//! ## Extra arguments
84//! The [`partial`] and [`partial2`] functions are used to pass additional arguments to the function, e.g.:
85//! ```no_run
86//! # use num_dual::{DualNum, first_derivative, partial};
87//! fn foo<D: DualNum<f64> + Copy>(x: D, args: &(D, D)) -> D { todo!() }
88//!
89//! fn main() {
90//!     let (val, deriv) = first_derivative(partial(foo, &(3.0, 4.0)), 5.0);
91//! }
92//! ```
93//! All types that implement the [`DualStruct`] trait can be used as additional function arguments. The
94//! only difference between using the [`partial`] and [`partial2`] functions and passing the extra
95//! arguments via a closure is that the type of the extra arguments is automatically adjusted to the correct
96//! dual number type used for the automatic differentiation. Note that the following code would not compile, because the extra arguments are plain `f64` values while `x` is a dual number:
97//! ```compile_fail
98//! # use num_dual::{DualNum, first_derivative};
99//! # fn foo<D: DualNum<f64> + Copy>(x: D, args: &(D, D)) -> D { todo!() }
100//! fn main() {
101//!     let (val, deriv) = first_derivative(|x| foo(x, &(3.0, 4.0)), 5.0);
102//! }
103//! ```
104//! The code created by [`partial`] essentially translates to:
105//! ```no_run
106//! # use num_dual::{DualNum, first_derivative, Dual, DualStruct};
107//! # fn foo<D: DualNum<f64> + Copy>(x: D, args: &(D, D)) -> D { todo!() }
108//! fn main() {
109//!     let (val, deriv) = first_derivative(|x| foo(x, &(Dual::from_inner(&3.0), Dual::from_inner(&4.0))), 5.0);
110//! }
111//! ```
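//! What happens to the extra arguments can be sketched without the crate's machinery. The following assumes a
//! hypothetical `MiniDual` type with `constant` and `variable` constructors (illustrative names, not part of this
//! crate): constants are lifted with a zero derivative part so that their type matches the type of `x`.
//! ```rust
//! use std::ops::{Add, Mul};
//!
//! // Hypothetical dual number re + eps*ε with ε² = 0; for illustration only.
//! #[derive(Clone, Copy, Debug, PartialEq)]
//! struct MiniDual {
//!     re: f64,
//!     eps: f64,
//! }
//!
//! impl MiniDual {
//!     // Lifts a constant: its derivative part is zero.
//!     fn constant(re: f64) -> Self {
//!         MiniDual { re, eps: 0.0 }
//!     }
//!     // Marks the variable that is differentiated: its derivative part is one.
//!     fn variable(re: f64) -> Self {
//!         MiniDual { re, eps: 1.0 }
//!     }
//! }
//!
//! impl Add for MiniDual {
//!     type Output = Self;
//!     fn add(self, rhs: Self) -> Self {
//!         MiniDual { re: self.re + rhs.re, eps: self.eps + rhs.eps }
//!     }
//! }
//!
//! impl Mul for MiniDual {
//!     type Output = Self;
//!     fn mul(self, rhs: Self) -> Self {
//!         MiniDual {
//!             re: self.re * rhs.re,
//!             eps: self.eps * rhs.re + self.re * rhs.eps,
//!         }
//!     }
//! }
//!
//! // Model a*x + b; `args` must have the same number type as `x`.
//! fn line<T: Add<Output = T> + Mul<Output = T> + Copy>(x: T, args: &(T, T)) -> T {
//!     args.0 * x + args.1
//! }
//!
//! // Lifts the plain arguments to the dual type before calling the generic model.
//! fn line_with_derivative(x: f64, args: (f64, f64)) -> (f64, f64) {
//!     let lifted = (MiniDual::constant(args.0), MiniDual::constant(args.1));
//!     let y = line(MiniDual::variable(x), &lifted);
//!     (y.re, y.eps)
//! }
//! ```
//! With these assumptions, `line_with_derivative(5.0, (3.0, 4.0))` evaluates to `(19.0, 3.0)`.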
112//!
113//! ## The [`Gradients`] trait
114//! The functions [`gradient`], [`hessian`], [`partial_hessian`] and [`jacobian`] are generic over the dimensionality
115//! of the variable vector. However, using the functions in a generic context requires dropping the [`Copy`] trait
116//! bound on the dual number type, because dynamically sized dual numbers cannot, by construction, implement
117//! [`Copy`]. Also, due to frequent heap allocations, the performance of the automatic differentiation could
118//! suffer significantly for dynamically sized dual numbers compared to statically sized dual numbers. The
119//! [`Gradients`] trait is introduced to overcome these limitations.
120//! ```
121//! # use num_dual::{DualNum, Gradients};
122//! # use nalgebra::{OVector, DefaultAllocator, allocator::Allocator, vector, dvector};
123//! # use approx::assert_relative_eq;
124//! fn foo<D: DualNum<f64> + Copy, N: Gradients>(x: OVector<D, N>, n: &D) -> D where DefaultAllocator: Allocator<N> {
125//!     x.dot(&x).sqrt() - n
126//! }
127//!
128//! fn main() {
129//!     let x = vector![1.0, 5.0, 5.0, 7.0];
130//!     let (f, grad) = Gradients::gradient(foo, &x, &10.0);
131//!     assert_eq!(f, 0.0);
132//!     assert_relative_eq!(grad, vector![0.1, 0.5, 0.5, 0.7]);
133//!
134//!     let x = dvector![1.0, 5.0, 5.0, 7.0];
135//!     let (f, grad) = Gradients::gradient(foo, &x, &10.0);
136//!     assert_eq!(f, 0.0);
137//!     assert_relative_eq!(grad, dvector![0.1, 0.5, 0.5, 0.7]);
138//! }
139//! ```
140//! For dynamically sized input arrays, the [`Gradients`] trait evaluates gradients or higher-order derivatives
141//! by iteratively evaluating scalar derivatives. For functions that do not rely on the [`Copy`] trait bound,
142//! only benchmarking can reveal whether the increased performance through the avoidance of heap allocations
143//! can overcome the overhead of repeated function evaluations, i.e., if [`Gradients`] outperforms directly
144//! calling [`gradient`], [`hessian`], [`partial_hessian`] or [`jacobian`].
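//!
//! The idea of evaluating a gradient through repeated scalar derivatives can be sketched as follows. This is only
//! an illustration with a hypothetical `MiniDual` type and a hypothetical `gradient_by_passes` helper, not the
//! implementation behind [`Gradients`]: every pass seeds a derivative of one in a single component, so the dual
//! numbers stay `Copy` at the cost of one function evaluation per variable.
//! ```rust
//! use std::ops::{Add, Mul};
//!
//! // Hypothetical dual number re + eps*ε with ε² = 0; for illustration only.
//! #[derive(Clone, Copy, Debug, PartialEq)]
//! struct MiniDual {
//!     re: f64,
//!     eps: f64,
//! }
//!
//! impl Add for MiniDual {
//!     type Output = Self;
//!     fn add(self, rhs: Self) -> Self {
//!         MiniDual { re: self.re + rhs.re, eps: self.eps + rhs.eps }
//!     }
//! }
//!
//! impl Mul for MiniDual {
//!     type Output = Self;
//!     fn mul(self, rhs: Self) -> Self {
//!         MiniDual {
//!             re: self.re * rhs.re,
//!             eps: self.eps * rhs.re + self.re * rhs.eps,
//!         }
//!     }
//! }
//!
//! impl MiniDual {
//!     fn sqrt(self) -> Self {
//!         let re = self.re.sqrt();
//!         // d/dx sqrt(x) = 1 / (2 sqrt(x))
//!         MiniDual { re, eps: self.eps / (2.0 * re) }
//!     }
//! }
//!
//! // Euclidean norm of a vector of dual numbers.
//! fn norm(x: &[MiniDual]) -> MiniDual {
//!     let mut sum = MiniDual { re: 0.0, eps: 0.0 };
//!     for &xi in x {
//!         sum = sum + xi * xi;
//!     }
//!     sum.sqrt()
//! }
//!
//! // One scalar pass per variable: only component `i` carries a derivative of one.
//! fn gradient_by_passes(x: &[f64]) -> (f64, Vec<f64>) {
//!     let mut value = 0.0;
//!     let mut grad = Vec::with_capacity(x.len());
//!     for i in 0..x.len() {
//!         let seeded: Vec<MiniDual> = x
//!             .iter()
//!             .enumerate()
//!             .map(|(j, &re)| MiniDual { re, eps: if i == j { 1.0 } else { 0.0 } })
//!             .collect();
//!         let y = norm(&seeded);
//!         value = y.re;
//!         grad.push(y.eps);
//!     }
//!     (value, grad)
//! }
//! ```
//! For `x = [4.0, 3.0]`, this sketch yields the value `5` and a gradient of approximately `[0.8, 0.6]`.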
145//!
146//! # Derivatives of implicit functions
147//! Implicit differentiation is used to determine the derivative `dy/dx` where the output `y` is only related
148//! implicitly to the input `x` via the equation `f(x,y)=0`. Automatic implicit differentiation generalizes the
149//! idea to determining the output `y` with full derivative information. Note that the first step in calculating
150//! an implicit derivative is always determining the "real" part (i.e., neglecting all derivatives) of the solution `y`
151//! of the equation `f(x,y)=0`. The `num-dual` library is focused on automatic differentiation and not nonlinear equation
152//! solving. Therefore, this first step needs to be done with your own custom solutions, or Rust crates for
153//! nonlinear equation solving and optimization like, e.g., [argmin](https://argmin-rs.org/).
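//!
//! For a scalar equation, differentiating `f(x,y)=0` with respect to `x` gives `dy/dx = -(∂f/∂x)/(∂f/∂y)`. A
//! minimal sketch of that relation, assuming a hypothetical `MiniDual` type and the residual `f(x,y) = y*y - x`
//! (neither is part of this crate's API), could look like this:
//! ```rust
//! use std::ops::{Mul, Sub};
//!
//! // Hypothetical dual number re + eps*ε with ε² = 0; for illustration only.
//! #[derive(Clone, Copy, Debug, PartialEq)]
//! struct MiniDual {
//!     re: f64,
//!     eps: f64,
//! }
//!
//! impl Sub for MiniDual {
//!     type Output = Self;
//!     fn sub(self, rhs: Self) -> Self {
//!         MiniDual { re: self.re - rhs.re, eps: self.eps - rhs.eps }
//!     }
//! }
//!
//! impl Mul for MiniDual {
//!     type Output = Self;
//!     fn mul(self, rhs: Self) -> Self {
//!         MiniDual {
//!             re: self.re * rhs.re,
//!             eps: self.eps * rhs.re + self.re * rhs.eps,
//!         }
//!     }
//! }
//!
//! // Residual f(x, y) = y*y - x, whose positive root is y = sqrt(x).
//! fn residual<T: Mul<Output = T> + Sub<Output = T> + Copy>(x: T, y: T) -> T {
//!     y * y - x
//! }
//!
//! // Given the real solution y0 with f(x, y0) = 0, returns dy/dx = -(∂f/∂x) / (∂f/∂y).
//! fn implicit_dy_dx(x: f64, y0: f64) -> f64 {
//!     let constant = |re| MiniDual { re, eps: 0.0 };
//!     let variable = |re| MiniDual { re, eps: 1.0 };
//!     // Each partial derivative is obtained by seeding exactly one of the two arguments.
//!     let df_dx = residual(variable(x), constant(y0)).eps;
//!     let df_dy = residual(constant(x), variable(y0)).eps;
//!     -df_dx / df_dy
//! }
//! ```
//! The real part `y0` has to come from a solver; here, `implicit_dy_dx(25.0, 25.0_f64.sqrt())` gives
//! approximately `0.1`, in agreement with `1/(2*sqrt(x))`.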
154//!
155//! The following example implements a square root for generic dual numbers using implicit differentiation. Of
156//! course, the derivatives of the square root can also be determined explicitly using the chain rule, so the
157//! example serves mostly as an illustration. `x.re()` provides the "real" part of the dual number, which is an [`f64`],
158//! and therefore we can use all the functionality of the standard library (including the square root).
159//! ```
160//! # use num_dual::{DualNum, implicit_derivative, first_derivative};
161//! fn implicit_sqrt<D: DualNum<f64> + Copy>(x: D) -> D {
162//!     implicit_derivative(|s, x| s * s - x, x.re().sqrt(), &x)
163//! }
164//!
165//! fn main() {
166//!     // sanity check, not actually calculating any derivative
167//!     assert_eq!(implicit_sqrt(25.0), 5.0);
168//!     
169//!     let (sq, deriv) = first_derivative(implicit_sqrt, 25.0);
170//!     assert_eq!(sq, 5.0);
171//!     // The derivative of sqrt(x) is 1/(2*sqrt(x)) which should evaluate to 0.1
172//!     assert_eq!(deriv, 0.1);
173//! }
174//! ```
175//! The `implicit_sqrt` or any likewise defined function is generic over the dual type `D`
176//! and can, therefore, be used anywhere as a part of an arbitrarily complex computation. The functions
177//! [`implicit_derivative_binary`] and [`implicit_derivative_vec`] can be used for implicit functions
178//! with more than one variable.
179//!
180//! For implicit functions that contain complex models and a large number of parameters, the [`ImplicitDerivative`]
181//! interface might come in handy. The idea is to define the implicit function using the [`ImplicitFunction`] trait
182//! and feed it into the [`ImplicitDerivative`] struct, which internally stores the parameters as dual numbers
183//! and their real parts. The [`ImplicitDerivative`] then provides methods for the evaluation of the real part
184//! of the residual (which can be passed to a nonlinear solver) and the implicit derivative which can be called
185//! after solving for the real part of the solution to reconstruct all the derivatives.
186//! ```
187//! # use num_dual::{ImplicitFunction, DualNum, Dual, ImplicitDerivative};
188//! struct ImplicitSqrt;
189//! impl ImplicitFunction<f64> for ImplicitSqrt {
190//!     type Parameters<D> = D;
191//!     type Variable<D> = D;
192//!     fn residual<D: DualNum<f64> + Copy>(x: D, square: &D) -> D {
193//!         *square - x * x
194//!     }
195//! }
196//!
197//! fn main() {
198//!     let x = Dual::from_re(25.0).derivative();
199//!     let func = ImplicitDerivative::new(ImplicitSqrt, x);
200//!     assert_eq!(func.residual(5.0), 0.0);
201//!     assert_eq!(x.sqrt(), func.implicit_derivative(5.0));
202//! }
203//! ```
204//!
205//! ## Combination with nonlinear solver libraries
206//! As mentioned previously, this crate does not contain any algorithms for nonlinear optimization or root finding.
207//! However, combining the capabilities of automatic differentiation with nonlinear solving can be very fruitful.
208//! Most importantly, the calculation of Jacobians or Hessians can be completely automated, if the model can be
209//! expressed within the functionalities of the [`DualNum`] trait. On top of that, implicit derivatives can be of
210//! interest, if derivatives of the result of the optimization itself are relevant (e.g., in a bilevel
211//! optimization). The synergy is exploited in the [`ipopt-ad`](https://github.com/prehner/ipopt-ad) crate that
212//! turns the NLP solver [IPOPT](https://github.com/coin-or/Ipopt) into a black-box optimization algorithm (i.e.,
213//! it only requires a function that returns the values of the optimization variable and constraints), without
214//! any repercussions regarding the robustness or speed of convergence of the solver.
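//!
//! As a rough illustration of this combination, the sketch below pairs a textbook Newton iteration with a
//! hypothetical `MiniDual` type (an assumption for this example, not a type of this crate), so that the derivative
//! needed by the solver is obtained from the model itself instead of being coded by hand. The names `model` and
//! `newton` are illustrative as well.
//! ```rust
//! use std::ops::{Mul, Sub};
//!
//! // Hypothetical dual number re + eps*ε with ε² = 0; for illustration only.
//! #[derive(Clone, Copy, Debug, PartialEq)]
//! struct MiniDual {
//!     re: f64,
//!     eps: f64,
//! }
//!
//! impl Sub for MiniDual {
//!     type Output = Self;
//!     fn sub(self, rhs: Self) -> Self {
//!         MiniDual { re: self.re - rhs.re, eps: self.eps - rhs.eps }
//!     }
//! }
//!
//! impl Mul for MiniDual {
//!     type Output = Self;
//!     fn mul(self, rhs: Self) -> Self {
//!         MiniDual {
//!             re: self.re * rhs.re,
//!             eps: self.eps * rhs.re + self.re * rhs.eps,
//!         }
//!     }
//! }
//!
//! // Example model x³ - 2, whose root is the cube root of 2.
//! fn model(x: MiniDual) -> MiniDual {
//!     x * x * x - MiniDual { re: 2.0, eps: 0.0 }
//! }
//!
//! // Newton iteration; value and derivative come from a single dual evaluation.
//! fn newton(f: impl Fn(MiniDual) -> MiniDual, mut x: f64, tol: f64, max_iter: usize) -> Option<f64> {
//!     for _ in 0..max_iter {
//!         let y = f(MiniDual { re: x, eps: 1.0 });
//!         if y.re.abs() < tol {
//!             return Some(x);
//!         }
//!         if y.eps == 0.0 {
//!             // A vanishing derivative would lead to a division by zero.
//!             return None;
//!         }
//!         x -= y.re / y.eps;
//!     }
//!     None
//! }
//! ```
//! Calling `newton(model, 1.0, 1e-12, 50)` should converge to the cube root of 2 within a handful of iterations.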
215//!
216//! If you are developing nonlinear optimization algorithms in Rust, feel free to reach out to us. We are happy to
217//! discuss how to enhance your algorithms with the automatic differentiation capabilities of this crate.
218
219#![warn(clippy::all)]
220#![warn(clippy::allow_attributes)]
221
222use nalgebra::allocator::Allocator;
223use nalgebra::{DefaultAllocator, Dim, OMatrix, Scalar};
224#[cfg(feature = "ndarray")]
225use ndarray::ScalarOperand;
226use num_traits::{Float, FloatConst, FromPrimitive, Inv, NumAssignOps, NumOps, Signed};
227use std::collections::HashMap;
228use std::fmt;
229use std::hash::Hash;
230use std::iter::{Product, Sum};
231
232#[macro_use]
233mod macros;
234#[macro_use]
235mod impl_derivatives;
236
237mod bessel;
238mod datatypes;
239mod explicit;
240mod implicit;
241pub use bessel::BesselDual;
242pub use datatypes::derivative::Derivative;
243pub use datatypes::dual::{Dual, Dual32, Dual64};
244pub use datatypes::dual2::{Dual2, Dual2_32, Dual2_64};
245pub use datatypes::dual2_vec::{
246    Dual2DVec32, Dual2DVec64, Dual2SVec32, Dual2SVec64, Dual2Vec, Dual2Vec32, Dual2Vec64,
247};
248pub use datatypes::dual3::{Dual3, Dual3_32, Dual3_64};
249pub use datatypes::dual_vec::{
250    DualDVec32, DualDVec64, DualSVec, DualSVec32, DualSVec64, DualVec, DualVec32, DualVec64,
251};
252pub use datatypes::hyperdual::{HyperDual, HyperDual32, HyperDual64};
253pub use datatypes::hyperdual_vec::{
254    HyperDualDVec32, HyperDualDVec64, HyperDualSVec32, HyperDualSVec64, HyperDualVec,
255    HyperDualVec32, HyperDualVec64,
256};
257pub use datatypes::hyperhyperdual::{HyperHyperDual, HyperHyperDual32, HyperHyperDual64};
258pub use datatypes::real::Real;
259pub use explicit::{
260    first_derivative, gradient, hessian, jacobian, partial, partial2, partial_hessian,
261    second_derivative, second_partial_derivative, third_derivative, third_partial_derivative,
262    third_partial_derivative_vec, zeroth_derivative, Gradients,
263};
264pub use implicit::{
265    implicit_derivative, implicit_derivative_binary, implicit_derivative_vec, ImplicitDerivative,
266    ImplicitFunction,
267};
268
269pub mod linalg;
270
271#[cfg(feature = "python")]
272pub mod python;
273
274#[cfg(feature = "python_macro")]
275mod python_macro;
276
277/// A generalized (hyper) dual number.
278#[cfg(feature = "ndarray")]
279pub trait DualNum<F>:
280    NumOps
281    + for<'r> NumOps<&'r Self>
282    + Signed
283    + NumOps<F>
284    + NumAssignOps
285    + NumAssignOps<F>
286    + Clone
287    + Inv<Output = Self>
288    + Sum
289    + Product
290    + FromPrimitive
291    + From<F>
292    + DualStruct<Self, F, Real = F>
293    + Mappable<Self>
294    + fmt::Display
295    + PartialEq
296    + fmt::Debug
297    + ScalarOperand
298    + 'static
299{
300    /// Highest derivative that can be calculated with this struct
301    const NDERIV: usize;
302
303    /// Reciprocal (inverse) of a number `1/x`
304    fn recip(&self) -> Self;
305
306    /// Power with integer exponent `x^n`
307    fn powi(&self, n: i32) -> Self;
308
309    /// Power with real exponent `x^n`
310    fn powf(&self, n: F) -> Self;
311
312    /// Square root
313    fn sqrt(&self) -> Self;
314
315    /// Cube root
316    fn cbrt(&self) -> Self;
317
318    /// Exponential `e^x`
319    fn exp(&self) -> Self;
320
321    /// Exponential with base 2 `2^x`
322    fn exp2(&self) -> Self;
323
324    /// Exponential minus 1 `e^x-1`
325    fn exp_m1(&self) -> Self;
326
327    /// Natural logarithm
328    fn ln(&self) -> Self;
329
330    /// Logarithm with arbitrary base
331    fn log(&self, base: F) -> Self;
332
333    /// Logarithm with base 2
334    fn log2(&self) -> Self;
335
336    /// Logarithm with base 10
337    fn log10(&self) -> Self;
338
339    /// Natural logarithm of one plus x `ln(1+x)`
340    fn ln_1p(&self) -> Self;
341
342    /// Sine
343    fn sin(&self) -> Self;
344
345    /// Cosine
346    fn cos(&self) -> Self;
347
348    /// Tangent
349    fn tan(&self) -> Self;
350
351    /// Calculate sine and cosine simultaneously
352    fn sin_cos(&self) -> (Self, Self);
353
354    /// Arcsine
355    fn asin(&self) -> Self;
356
357    /// Arccosine
358    fn acos(&self) -> Self;
359
360    /// Arctangent
361    fn atan(&self) -> Self;
362
363    /// Four-quadrant arctangent `atan2(self, other)`
364    fn atan2(&self, other: Self) -> Self;
365
366    /// Hyperbolic sine
367    fn sinh(&self) -> Self;
368
369    /// Hyperbolic cosine
370    fn cosh(&self) -> Self;
371
372    /// Hyperbolic tangent
373    fn tanh(&self) -> Self;
374
375    /// Area hyperbolic sine
376    fn asinh(&self) -> Self;
377
378    /// Area hyperbolic cosine
379    fn acosh(&self) -> Self;
380
381    /// Area hyperbolic tangent
382    fn atanh(&self) -> Self;
383
384    /// 0th order spherical Bessel function of the first kind
385    fn sph_j0(&self) -> Self;
386
387    /// 1st order spherical Bessel function of the first kind
388    fn sph_j1(&self) -> Self;
389
390    /// 2nd order spherical Bessel function of the first kind
391    fn sph_j2(&self) -> Self;
392
393    /// Multiply-add `self * a + b` (the default implementation is not fused)
394    #[inline]
395    fn mul_add(&self, a: Self, b: Self) -> Self {
396        self.clone() * a + b
397    }
398
399    /// Power with dual exponent `x^n`
400    #[inline]
401    fn powd(&self, exp: Self) -> Self {
402        (self.ln() * exp).exp()
403    }
404}
405
406/// A generalized (hyper) dual number.
407#[cfg(not(feature = "ndarray"))]
408pub trait DualNum<F>:
409    NumOps
410    + for<'r> NumOps<&'r Self>
411    + Signed
412    + NumOps<F>
413    + NumAssignOps
414    + NumAssignOps<F>
415    + Clone
416    + Inv<Output = Self>
417    + Sum
418    + Product
419    + FromPrimitive
420    + From<F>
421    + DualStruct<Self, F, Real = F>
422    + Mappable<Self>
423    + fmt::Display
424    + PartialEq
425    + fmt::Debug
426    + 'static
427{
428    /// Highest derivative that can be calculated with this struct
429    const NDERIV: usize;
430
431    /// Reciprocal (inverse) of a number `1/x`
432    fn recip(&self) -> Self;
433
434    /// Power with integer exponent `x^n`
435    fn powi(&self, n: i32) -> Self;
436
437    /// Power with real exponent `x^n`
438    fn powf(&self, n: F) -> Self;
439
440    /// Square root
441    fn sqrt(&self) -> Self;
442
443    /// Cube root
444    fn cbrt(&self) -> Self;
445
446    /// Exponential `e^x`
447    fn exp(&self) -> Self;
448
449    /// Exponential with base 2 `2^x`
450    fn exp2(&self) -> Self;
451
452    /// Exponential minus 1 `e^x-1`
453    fn exp_m1(&self) -> Self;
454
455    /// Natural logarithm
456    fn ln(&self) -> Self;
457
458    /// Logarithm with arbitrary base
459    fn log(&self, base: F) -> Self;
460
461    /// Logarithm with base 2
462    fn log2(&self) -> Self;
463
464    /// Logarithm with base 10
465    fn log10(&self) -> Self;
466
467    /// Natural logarithm of one plus x `ln(1+x)`
468    fn ln_1p(&self) -> Self;
469
470    /// Sine
471    fn sin(&self) -> Self;
472
473    /// Cosine
474    fn cos(&self) -> Self;
475
476    /// Tangent
477    fn tan(&self) -> Self;
478
479    /// Calculate sine and cosine simultaneously
480    fn sin_cos(&self) -> (Self, Self);
481
482    /// Arcsine
483    fn asin(&self) -> Self;
484
485    /// Arccosine
486    fn acos(&self) -> Self;
487
488    /// Arctangent
489    fn atan(&self) -> Self;
490
491    /// Four-quadrant arctangent `atan2(self, other)`
492    fn atan2(&self, other: Self) -> Self;
493
494    /// Hyperbolic sine
495    fn sinh(&self) -> Self;
496
497    /// Hyperbolic cosine
498    fn cosh(&self) -> Self;
499
500    /// Hyperbolic tangent
501    fn tanh(&self) -> Self;
502
503    /// Area hyperbolic sine
504    fn asinh(&self) -> Self;
505
506    /// Area hyperbolic cosine
507    fn acosh(&self) -> Self;
508
509    /// Area hyperbolic tangent
510    fn atanh(&self) -> Self;
511
512    /// 0th order spherical Bessel function of the first kind
513    fn sph_j0(&self) -> Self;
514
515    /// 1st order spherical Bessel function of the first kind
516    fn sph_j1(&self) -> Self;
517
518    /// 2nd order spherical Bessel function of the first kind
519    fn sph_j2(&self) -> Self;
520
521    /// Multiply-add `self * a + b` (the default implementation is not fused)
522    #[inline]
523    fn mul_add(&self, a: Self, b: Self) -> Self {
524        self.clone() * a + b
525    }
526
527    /// Power with dual exponent `x^n`
528    #[inline]
529    fn powd(&self, exp: Self) -> Self {
530        (self.ln() * exp).exp()
531    }
532}
533
534/// The underlying data type of individual derivatives. Usually f32 or f64.
535pub trait DualNumFloat:
536    Float + FloatConst + FromPrimitive + Signed + fmt::Display + fmt::Debug + Sync + Send + 'static
537{
538}
539impl<T> DualNumFloat for T where
540    T: Float
541        + FloatConst
542        + FromPrimitive
543        + Signed
544        + fmt::Display
545        + fmt::Debug
546        + Sync
547        + Send
548        + 'static
549{
550}
551
552macro_rules! impl_dual_num_float {
553    ($float:ty) => {
554        impl DualNum<$float> for $float {
555            const NDERIV: usize = 0;
556
557            fn mul_add(&self, a: Self, b: Self) -> Self {
558                <$float>::mul_add(*self, a, b)
559            }
560            fn recip(&self) -> Self {
561                <$float>::recip(*self)
562            }
563            fn powi(&self, n: i32) -> Self {
564                <$float>::powi(*self, n)
565            }
566            fn powf(&self, n: Self) -> Self {
567                <$float>::powf(*self, n)
568            }
569            fn powd(&self, n: Self) -> Self {
570                <$float>::powf(*self, n)
571            }
572            fn sqrt(&self) -> Self {
573                <$float>::sqrt(*self)
574            }
575            fn exp(&self) -> Self {
576                <$float>::exp(*self)
577            }
578            fn exp2(&self) -> Self {
579                <$float>::exp2(*self)
580            }
581            fn ln(&self) -> Self {
582                <$float>::ln(*self)
583            }
584            fn log(&self, base: Self) -> Self {
585                <$float>::log(*self, base)
586            }
587            fn log2(&self) -> Self {
588                <$float>::log2(*self)
589            }
590            fn log10(&self) -> Self {
591                <$float>::log10(*self)
592            }
593            fn cbrt(&self) -> Self {
594                <$float>::cbrt(*self)
595            }
596            fn sin(&self) -> Self {
597                <$float>::sin(*self)
598            }
599            fn cos(&self) -> Self {
600                <$float>::cos(*self)
601            }
602            fn tan(&self) -> Self {
603                <$float>::tan(*self)
604            }
605            fn asin(&self) -> Self {
606                <$float>::asin(*self)
607            }
608            fn acos(&self) -> Self {
609                <$float>::acos(*self)
610            }
611            fn atan(&self) -> Self {
612                <$float>::atan(*self)
613            }
614            fn atan2(&self, other: $float) -> Self {
615                <$float>::atan2(*self, other)
616            }
617            fn sin_cos(&self) -> (Self, Self) {
618                <$float>::sin_cos(*self)
619            }
620            fn exp_m1(&self) -> Self {
621                <$float>::exp_m1(*self)
622            }
623            fn ln_1p(&self) -> Self {
624                <$float>::ln_1p(*self)
625            }
626            fn sinh(&self) -> Self {
627                <$float>::sinh(*self)
628            }
629            fn cosh(&self) -> Self {
630                <$float>::cosh(*self)
631            }
632            fn tanh(&self) -> Self {
633                <$float>::tanh(*self)
634            }
635            fn asinh(&self) -> Self {
636                <$float>::asinh(*self)
637            }
638            fn acosh(&self) -> Self {
639                <$float>::acosh(*self)
640            }
641            fn atanh(&self) -> Self {
642                <$float>::atanh(*self)
643            }
644            fn sph_j0(&self) -> Self {
645                if self.abs() < <$float>::EPSILON {
646                    1.0 - self * self / 6.0
647                } else {
648                    self.sin() / self
649                }
650            }
651            fn sph_j1(&self) -> Self {
652                if self.abs() < <$float>::EPSILON {
653                    self / 3.0
654                } else {
655                    let sc = self.sin_cos();
656                    let rec = self.recip();
657                    (sc.0 * rec - sc.1) * rec
658                }
659            }
660            fn sph_j2(&self) -> Self {
661                if self.abs() < <$float>::EPSILON {
662                    self * self / 15.0
663                } else {
664                    let sc = self.sin_cos();
665                    let s2 = self * self;
666                    ((3.0 - s2) * sc.0 - 3.0 * self * sc.1) / (self * s2)
667                }
668            }
669        }
670    };
671}
672
673impl_dual_num_float!(f32);
674impl_dual_num_float!(f64);
675
676/// A struct that contains dual numbers. Needed for arbitrary arguments in [`ImplicitFunction`].
677///
678/// The trait is implemented for all dual types themselves, and common data types (tuple, vec,
679/// array, ...) and can be implemented for custom data types to achieve full flexibility.
680pub trait DualStruct<D, F> {
681    type Real;
682    type Inner;
683    fn re(&self) -> Self::Real;
684    fn from_inner(inner: &Self::Inner) -> Self;
685}
686
687/// Trait for structs used as an output of functions for which derivatives are calculated.
688///
689/// The main intention is to generalize the calculation of derivatives to fallible functions, but
690/// other use cases might also appear in the future.
691pub trait Mappable<D> {
692    type Output<O>;
693    fn map_dual<M: Fn(D) -> O, O>(self, f: M) -> Self::Output<O>;
694}
695
696impl<D, F> DualStruct<D, F> for () {
697    type Real = ();
698    type Inner = ();
699    fn re(&self) {}
700    fn from_inner(_: &Self::Inner) -> Self {}
701}
702
703impl<D> Mappable<D> for () {
704    type Output<O> = ();
705    fn map_dual<M: FnOnce(D) -> O, O>(self, _: M) {}
706}
707
708impl DualStruct<f32, f32> for f32 {
709    type Real = f32;
710    type Inner = f32;
711    fn re(&self) -> f32 {
712        *self
713    }
714    fn from_inner(inner: &Self::Inner) -> Self {
715        *inner
716    }
717}
718
719impl Mappable<f32> for f32 {
720    type Output<O> = O;
721    fn map_dual<M: FnOnce(f32) -> O, O>(self, f: M) -> Self::Output<O> {
722        f(self)
723    }
724}
725
726impl DualStruct<f64, f64> for f64 {
727    type Real = f64;
728    type Inner = f64;
729    fn re(&self) -> f64 {
730        *self
731    }
732    fn from_inner(inner: &Self::Inner) -> Self {
733        *inner
734    }
735}
736
737impl Mappable<f64> for f64 {
738    type Output<O> = O;
739    fn map_dual<M: FnOnce(f64) -> O, O>(self, f: M) -> Self::Output<O> {
740        f(self)
741    }
742}
743
744impl<D, F, T1: DualStruct<D, F>, T2: DualStruct<D, F>> DualStruct<D, F> for (T1, T2) {
745    type Real = (T1::Real, T2::Real);
746    type Inner = (T1::Inner, T2::Inner);
747    fn re(&self) -> Self::Real {
748        let (s1, s2) = self;
749        (s1.re(), s2.re())
750    }
751    fn from_inner(re: &Self::Inner) -> Self {
752        let (r1, r2) = re;
753        (T1::from_inner(r1), T2::from_inner(r2))
754    }
755}
756
757impl<D, T1: Mappable<D>, T2: Mappable<D>> Mappable<D> for (T1, T2) {
758    type Output<O> = (T1::Output<O>, T2::Output<O>);
759    fn map_dual<M: Fn(D) -> O, O>(self, f: M) -> Self::Output<O> {
760        let (s1, s2) = self;
761        (s1.map_dual(&f), s2.map_dual(&f))
762    }
763}
764
765impl<D, F, T1: DualStruct<D, F>, T2: DualStruct<D, F>, T3: DualStruct<D, F>> DualStruct<D, F>
766    for (T1, T2, T3)
767{
768    type Real = (T1::Real, T2::Real, T3::Real);
769    type Inner = (T1::Inner, T2::Inner, T3::Inner);
770    fn re(&self) -> Self::Real {
771        let (s1, s2, s3) = self;
772        (s1.re(), s2.re(), s3.re())
773    }
774    fn from_inner(inner: &Self::Inner) -> Self {
775        let (r1, r2, r3) = inner;
776        (T1::from_inner(r1), T2::from_inner(r2), T3::from_inner(r3))
777    }
778}
779
780impl<D, T1: Mappable<D>, T2: Mappable<D>, T3: Mappable<D>> Mappable<D> for (T1, T2, T3) {
781    type Output<O> = (T1::Output<O>, T2::Output<O>, T3::Output<O>);
782    fn map_dual<M: Fn(D) -> O, O>(self, f: M) -> Self::Output<O> {
783        let (s1, s2, s3) = self;
784        (s1.map_dual(&f), s2.map_dual(&f), s3.map_dual(&f))
785    }
786}
787
788impl<
789        D,
790        F,
791        T1: DualStruct<D, F>,
792        T2: DualStruct<D, F>,
793        T3: DualStruct<D, F>,
794        T4: DualStruct<D, F>,
795    > DualStruct<D, F> for (T1, T2, T3, T4)
796{
797    type Real = (T1::Real, T2::Real, T3::Real, T4::Real);
798    type Inner = (T1::Inner, T2::Inner, T3::Inner, T4::Inner);
799    fn re(&self) -> Self::Real {
800        let (s1, s2, s3, s4) = self;
801        (s1.re(), s2.re(), s3.re(), s4.re())
802    }
803    fn from_inner(inner: &Self::Inner) -> Self {
804        let (r1, r2, r3, r4) = inner;
805        (
806            T1::from_inner(r1),
807            T2::from_inner(r2),
808            T3::from_inner(r3),
809            T4::from_inner(r4),
810        )
811    }
812}
813
814impl<D, T1: Mappable<D>, T2: Mappable<D>, T3: Mappable<D>, T4: Mappable<D>> Mappable<D>
815    for (T1, T2, T3, T4)
816{
817    type Output<O> = (T1::Output<O>, T2::Output<O>, T3::Output<O>, T4::Output<O>);
818    fn map_dual<M: Fn(D) -> O, O>(self, f: M) -> Self::Output<O> {
819        let (s1, s2, s3, s4) = self;
820        (
821            s1.map_dual(&f),
822            s2.map_dual(&f),
823            s3.map_dual(&f),
824            s4.map_dual(&f),
825        )
826    }
827}
828
829impl<
830        D,
831        F,
832        T1: DualStruct<D, F>,
833        T2: DualStruct<D, F>,
834        T3: DualStruct<D, F>,
835        T4: DualStruct<D, F>,
836        T5: DualStruct<D, F>,
837    > DualStruct<D, F> for (T1, T2, T3, T4, T5)
838{
839    type Real = (T1::Real, T2::Real, T3::Real, T4::Real, T5::Real);
840    type Inner = (T1::Inner, T2::Inner, T3::Inner, T4::Inner, T5::Inner);
841    fn re(&self) -> Self::Real {
842        let (s1, s2, s3, s4, s5) = self;
843        (s1.re(), s2.re(), s3.re(), s4.re(), s5.re())
844    }
845    fn from_inner(inner: &Self::Inner) -> Self {
846        let (r1, r2, r3, r4, r5) = inner;
847        (
848            T1::from_inner(r1),
849            T2::from_inner(r2),
850            T3::from_inner(r3),
851            T4::from_inner(r4),
852            T5::from_inner(r5),
853        )
854    }
855}
856
857impl<D, T1: Mappable<D>, T2: Mappable<D>, T3: Mappable<D>, T4: Mappable<D>, T5: Mappable<D>>
858    Mappable<D> for (T1, T2, T3, T4, T5)
859{
860    type Output<O> = (
861        T1::Output<O>,
862        T2::Output<O>,
863        T3::Output<O>,
864        T4::Output<O>,
865        T5::Output<O>,
866    );
867    fn map_dual<M: Fn(D) -> O, O>(self, f: M) -> Self::Output<O> {
868        let (s1, s2, s3, s4, s5) = self;
869        (
870            s1.map_dual(&f),
871            s2.map_dual(&f),
872            s3.map_dual(&f),
873            s4.map_dual(&f),
874            s5.map_dual(&f),
875        )
876    }
877}
878
879impl<D, F, T: DualStruct<D, F>, const N: usize> DualStruct<D, F> for [T; N] {
880    type Real = [T::Real; N];
881    type Inner = [T::Inner; N];
882    fn re(&self) -> Self::Real {
883        self.each_ref().map(|x| x.re())
884    }
885    fn from_inner(re: &Self::Inner) -> Self {
886        re.each_ref().map(T::from_inner)
887    }
888}
889
890impl<D, T: Mappable<D>, const N: usize> Mappable<D> for [T; N] {
891    type Output<O> = [T::Output<O>; N];
892    fn map_dual<M: Fn(D) -> O, O>(self, f: M) -> Self::Output<O> {
893        self.map(|x| x.map_dual(&f))
894    }
895}
896
897impl<D, F, T: DualStruct<D, F>> DualStruct<D, F> for Option<T> {
898    type Real = Option<T::Real>;
899    type Inner = Option<T::Inner>;
900    fn re(&self) -> Self::Real {
901        self.as_ref().map(|x| x.re())
902    }
903    fn from_inner(inner: &Self::Inner) -> Self {
904        inner.as_ref().map(|x| T::from_inner(x))
905    }
906}
907
908impl<D, T: Mappable<D>> Mappable<D> for Option<T> {
909    type Output<O> = Option<T::Output<O>>;
910    fn map_dual<M: Fn(D) -> O, O>(self, f: M) -> Self::Output<O> {
911        self.map(|x| x.map_dual(f))
912    }
913}
914
915impl<D, T: Mappable<D>, E> Mappable<D> for Result<T, E> {
916    type Output<O> = Result<T::Output<O>, E>;
917    fn map_dual<M: Fn(D) -> O, O>(self, f: M) -> Self::Output<O> {
918        self.map(|x| x.map_dual(f))
919    }
920}
921
922impl<D, F, T: DualStruct<D, F>> DualStruct<D, F> for Vec<T> {
923    type Real = Vec<T::Real>;
924    type Inner = Vec<T::Inner>;
925    fn re(&self) -> Self::Real {
926        self.iter().map(|x| x.re()).collect()
927    }
928    fn from_inner(inner: &Self::Inner) -> Self {
929        inner.iter().map(|x| T::from_inner(x)).collect()
930    }
931}
932
933impl<D, T: Mappable<D>> Mappable<D> for Vec<T> {
934    type Output<O> = Vec<T::Output<O>>;
935    fn map_dual<M: Fn(D) -> O, O>(self, f: M) -> Self::Output<O> {
936        self.into_iter().map(|x| x.map_dual(&f)).collect()
937    }
938}
939
940impl<D, F, T: DualStruct<D, F>, K: Clone + Eq + Hash> DualStruct<D, F> for HashMap<K, T> {
941    type Real = HashMap<K, T::Real>;
942    type Inner = HashMap<K, T::Inner>;
943    fn re(&self) -> Self::Real {
944        self.iter().map(|(k, x)| (k.clone(), x.re())).collect()
945    }
946    fn from_inner(inner: &Self::Inner) -> Self {
947        inner
948            .iter()
949            .map(|(k, x)| (k.clone(), T::from_inner(x)))
950            .collect()
951    }
952}
953
954impl<D, T: Mappable<D>, K: Eq + Hash> Mappable<D> for HashMap<K, T> {
955    type Output<O> = HashMap<K, T::Output<O>>;
956    fn map_dual<M: Fn(D) -> O, O>(self, f: M) -> Self::Output<O> {
957        self.into_iter().map(|(k, x)| (k, x.map_dual(&f))).collect()
958    }
959}
960
961impl<D: DualNum<F>, F: DualNumFloat, R: Dim, C: Dim> DualStruct<D, F> for OMatrix<D, R, C>
962where
963    DefaultAllocator: Allocator<R, C>,
964    D::Inner: DualNum<F>,
965{
966    type Real = OMatrix<F, R, C>;
967    type Inner = OMatrix<D::Inner, R, C>;
968    fn re(&self) -> Self::Real {
969        self.map(|x| x.re())
970    }
971    fn from_inner(inner: &Self::Inner) -> Self {
972        inner.map(|x| D::from_inner(&x))
973    }
974}
975
976impl<D: Scalar, R: Dim, C: Dim> Mappable<Self> for OMatrix<D, R, C>
977where
978    DefaultAllocator: Allocator<R, C>,
979{
980    type Output<O> = O;
981    fn map_dual<M: Fn(Self) -> O, O>(self, f: M) -> O {
982        f(self)
983    }
984}