Skip to main content

num_dual/datatypes/
dual.rs

1use crate::{DualNum, DualNumFloat, DualStruct};
2use approx::{AbsDiffEq, RelativeEq, UlpsEq};
3use nalgebra::*;
4use num_traits::{Float, FloatConst, FromPrimitive, Inv, Num, One, Signed, Zero};
5#[cfg(feature = "serde")]
6use serde::{Deserialize, Serialize};
7use std::fmt;
8use std::iter::{Product, Sum};
9use std::marker::PhantomData;
10use std::ops::{
11    Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
12};
13
14/// A scalar dual number for the calculations of first derivatives.
15#[derive(Copy, Clone, Debug)]
16#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
17pub struct Dual<T: DualNum<F>, F> {
18    /// Real part of the dual number
19    pub re: T,
20    /// Derivative part of the dual number
21    pub eps: T,
22    #[cfg_attr(feature = "serde", serde(skip))]
23    f: PhantomData<F>,
24}
25
26#[cfg(feature = "ndarray")]
27impl<T: DualNum<F>, F: DualNumFloat> ndarray::ScalarOperand for Dual<T, F> {}
28
29pub type Dual32 = Dual<f32, f32>;
30pub type Dual64 = Dual<f64, f64>;
31
32impl<T: DualNum<F>, F> Dual<T, F> {
33    /// Create a new dual number from its fields.
34    #[inline]
35    pub fn new(re: T, eps: T) -> Self {
36        Self {
37            re,
38            eps,
39            f: PhantomData,
40        }
41    }
42}
43
44impl<T: DualNum<F> + Zero, F> Dual<T, F> {
45    /// Create a new dual number from the real part.
46    #[inline]
47    pub fn from_re(re: T) -> Self {
48        Self::new(re, T::zero())
49    }
50}
51
52impl<T: DualNum<F> + One, F> Dual<T, F> {
53    /// Set the derivative part to 1.
54    /// ```
55    /// # use num_dual::{Dual64, DualNum};
56    /// let x = Dual64::from_re(5.0).derivative().powi(2);
57    /// assert_eq!(x.re, 25.0);
58    /// assert_eq!(x.eps, 10.0);
59    /// ```
60    #[inline]
61    pub fn derivative(mut self) -> Self {
62        self.eps = T::one();
63        self
64    }
65}
66
67/* chain rule */
68impl<T: DualNum<F>, F: Float> Dual<T, F> {
69    #[inline]
70    fn chain_rule(&self, f0: T, f1: T) -> Self {
71        Self::new(f0, self.eps.clone() * f1)
72    }
73}
74
75/* product rule */
76impl<T: DualNum<F>, F: Float> Mul<&Dual<T, F>> for &Dual<T, F> {
77    type Output = Dual<T, F>;
78    #[inline]
79    fn mul(self, other: &Dual<T, F>) -> Self::Output {
80        Dual::new(
81            self.re.clone() * other.re.clone(),
82            self.eps.clone() * other.re.clone() + other.eps.clone() * self.re.clone(),
83        )
84    }
85}
86
87/* quotient rule */
88impl<T: DualNum<F>, F: Float> Div<&Dual<T, F>> for &Dual<T, F> {
89    type Output = Dual<T, F>;
90    #[inline]
91    fn div(self, other: &Dual<T, F>) -> Dual<T, F> {
92        let inv = other.re.recip();
93        Dual::new(
94            self.re.clone() * inv.clone(),
95            (self.eps.clone() * other.re.clone() - other.eps.clone() * self.re.clone())
96                * inv.clone()
97                * inv,
98        )
99    }
100}
101
102/* string conversions */
103impl<T: DualNum<F>, F> fmt::Display for Dual<T, F> {
104    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
105        write!(f, "{} + {}ε", self.re, self.eps)
106    }
107}
108
109impl_first_derivatives!(Dual, [eps]);
110impl_dual!(Dual, [eps]);
111
112/**
113 * The SimdValue trait is for rearranging data into a form more suitable for Simd,
114 * and rearranging it back into a usable form. It is not documented particularly well.
115 *
116 * The primary job of this SimdValue impl is to allow people to use `simba::simd::f32x4` etc,
117 * instead of f32/f64. Those types implement nalgebra::SimdRealField/ComplexField, so they
118 * behave like scalars. When we use them, we would have `Dual<f32x4, f32, N>` etc, with our
119 * F parameter set to `<T as SimdValue>::Element`. We will need to be able to split up that type
120 * into four of Dual in order to get out of simd-land. That's what the SimdValue trait is for.
121 *
122 * Ultimately, someone will have to to implement SimdRealField on Dual and call the
123 * simd_ functions of `<T as SimdRealField>`. That's future work for someone who finds
124 * num_dual is not fast enough.
125 *
126 * Unfortunately, doing anything with SIMD is blocked on
127 * <https://github.com/dimforge/simba/issues/44>.
128 *
129 */
130impl<T> nalgebra::SimdValue for Dual<T, T::Element>
131where
132    T: DualNum<T::Element> + SimdValue + Scalar,
133    T::Element: DualNum<T::Element> + Scalar,
134{
135    // Say T = simba::f32x4. T::Element is f32. T::SimdBool is AutoSimd<[bool; 4]>.
136    // AutoSimd<[f32; 4]> stores an actual [f32; 4], i.e. four floats in one slot.
137    // So our Dual<AutoSimd<[f32; 4], f32, N> has 4 * (1+N) floats in it, stored in blocks of
138    // four. When we want to do any math on it but ignore its f32x4 storage mode, we need to break
139    // that type into FOUR of Dual<f32, f32, N>; then we do math on it, then we bring it back
140    // together.
141    //
142    // Hence this definition of Element:
143    type Element = Dual<T::Element, T::Element>;
144    type SimdBool = T::SimdBool;
145
146    const LANES: usize = T::LANES;
147
148    #[inline]
149    fn splat(val: Self::Element) -> Self {
150        // Need to make `lanes` copies of each of:
151        // - the real part
152        // - each of the N epsilon parts
153        let re = T::splat(val.re);
154        let eps = T::splat(val.eps);
155        Self::new(re, eps)
156    }
157
158    #[inline]
159    fn extract(&self, i: usize) -> Self::Element {
160        let re = self.re.extract(i);
161        let eps = self.eps.extract(i);
162        Self::Element {
163            re,
164            eps,
165            f: PhantomData,
166        }
167    }
168
169    #[inline]
170    unsafe fn extract_unchecked(&self, i: usize) -> Self::Element {
171        let re = unsafe { self.re.extract_unchecked(i) };
172        let eps = unsafe { self.eps.extract_unchecked(i) };
173        Self::Element {
174            re,
175            eps,
176            f: PhantomData,
177        }
178    }
179
180    #[inline]
181    fn replace(&mut self, i: usize, val: Self::Element) {
182        self.re.replace(i, val.re);
183        self.eps.replace(i, val.eps);
184    }
185
186    #[inline]
187    unsafe fn replace_unchecked(&mut self, i: usize, val: Self::Element) {
188        unsafe { self.re.replace_unchecked(i, val.re) };
189        unsafe { self.eps.replace_unchecked(i, val.eps) };
190    }
191
192    #[inline]
193    fn select(self, cond: Self::SimdBool, other: Self) -> Self {
194        let re = self.re.select(cond, other.re);
195        let eps = self.eps.select(cond, other.eps);
196        Self::new(re, eps)
197    }
198}
199
200/// Like PartialEq, comparisons are only made based on the real part. This allows the code to follow the
201/// same execution path as real-valued code would.
202impl<T: DualNum<F> + approx::AbsDiffEq<Epsilon = T>, F: Float> approx::AbsDiffEq for Dual<T, F> {
203    type Epsilon = Self;
204    #[inline]
205    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
206        self.re.abs_diff_eq(&other.re, epsilon.re)
207    }
208
209    #[inline]
210    fn default_epsilon() -> Self::Epsilon {
211        Self::from_re(T::default_epsilon())
212    }
213}
214/// Like PartialEq, comparisons are only made based on the real part. This allows the code to follow the
215/// same execution path as real-valued code would.
216impl<T: DualNum<F> + approx::RelativeEq<Epsilon = T>, F: Float> approx::RelativeEq for Dual<T, F> {
217    #[inline]
218    fn default_max_relative() -> Self::Epsilon {
219        Self::from_re(T::default_max_relative())
220    }
221
222    #[inline]
223    fn relative_eq(
224        &self,
225        other: &Self,
226        epsilon: Self::Epsilon,
227        max_relative: Self::Epsilon,
228    ) -> bool {
229        self.re.relative_eq(&other.re, epsilon.re, max_relative.re)
230    }
231}
232impl<T: DualNum<F> + UlpsEq<Epsilon = T>, F: Float> UlpsEq for Dual<T, F> {
233    #[inline]
234    fn default_max_ulps() -> u32 {
235        T::default_max_ulps()
236    }
237
238    #[inline]
239    fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool {
240        T::ulps_eq(&self.re, &other.re, epsilon.re, max_ulps)
241    }
242}
243
244impl<T> nalgebra::Field for Dual<T, T::Element>
245where
246    T: DualNum<T::Element> + SimdValue,
247    T::Element: DualNum<T::Element> + Scalar + Float,
248{
249}
250
251use simba::scalar::{SubsetOf, SupersetOf};
252
253impl<TSuper, FSuper, T, F> SubsetOf<Dual<TSuper, FSuper>> for Dual<T, F>
254where
255    TSuper: DualNum<FSuper> + SupersetOf<T>,
256    T: DualNum<F>,
257{
258    #[inline(always)]
259    fn to_superset(&self) -> Dual<TSuper, FSuper> {
260        let re = TSuper::from_subset(&self.re);
261        let eps = TSuper::from_subset(&self.eps);
262        Dual {
263            re,
264            eps,
265            f: PhantomData,
266        }
267    }
268    #[inline(always)]
269    fn from_superset(element: &Dual<TSuper, FSuper>) -> Option<Self> {
270        let re = TSuper::to_subset(&element.re)?;
271        let eps = TSuper::to_subset(&element.eps)?;
272        Some(Self::new(re, eps))
273    }
274    #[inline(always)]
275    fn from_superset_unchecked(element: &Dual<TSuper, FSuper>) -> Self {
276        let re = TSuper::to_subset_unchecked(&element.re);
277        let eps = TSuper::to_subset_unchecked(&element.eps);
278        Self::new(re, eps)
279    }
280    #[inline(always)]
281    fn is_in_subset(element: &Dual<TSuper, FSuper>) -> bool {
282        TSuper::is_in_subset(&element.re) && TSuper::is_in_subset(&element.eps)
283    }
284}
285
286impl<TSuper, FSuper> SupersetOf<f32> for Dual<TSuper, FSuper>
287where
288    TSuper: DualNum<FSuper> + SupersetOf<f32>,
289{
290    #[inline(always)]
291    fn is_in_subset(&self) -> bool {
292        self.re.is_in_subset()
293    }
294
295    #[inline(always)]
296    fn to_subset_unchecked(&self) -> f32 {
297        self.re.to_subset_unchecked()
298    }
299
300    #[inline(always)]
301    fn from_subset(element: &f32) -> Self {
302        // Interpret as a purely real number
303        let re = TSuper::from_subset(element);
304        let eps = TSuper::zero();
305        Self::new(re, eps)
306    }
307}
308
309impl<TSuper, FSuper> SupersetOf<f64> for Dual<TSuper, FSuper>
310where
311    TSuper: DualNum<FSuper> + SupersetOf<f64>,
312{
313    #[inline(always)]
314    fn is_in_subset(&self) -> bool {
315        self.re.is_in_subset()
316    }
317
318    #[inline(always)]
319    fn to_subset_unchecked(&self) -> f64 {
320        self.re.to_subset_unchecked()
321    }
322
323    #[inline(always)]
324    fn from_subset(element: &f64) -> Self {
325        // Interpret as a purely real number
326        let re = TSuper::from_subset(element);
327        let eps = TSuper::zero();
328        Self::new(re, eps)
329    }
330}
331
332// We can't do a simd implementation until simba lets us implement SimdPartialOrd
333// using _T_'s SimdBool. The blanket impl gets in the way. So we must constrain
334// T to SimdValue<Element = T, SimdBool = bool>, which is basically the same as
335// saying f32 or f64 only.
336//
337// Limitation of simba. See https://github.com/dimforge/simba/issues/44
338
339use nalgebra::{ComplexField, RealField};
340// This impl is modelled on `impl ComplexField for f32`. The imaginary part is nothing.
341impl<T> ComplexField for Dual<T, T::Element>
342where
343    T: DualNum<T::Element> + SupersetOf<T> + AbsDiffEq<Epsilon = T> + Sync + Send,
344    T::Element: DualNum<T::Element> + Scalar + DualNumFloat + Sync + Send,
345    T: SupersetOf<T::Element>,
346    T: SupersetOf<f32>,
347    T: SupersetOf<f64>,
348    T: SimdPartialOrd + PartialOrd,
349    T: SimdValue<Element = T, SimdBool = bool>,
350    T: RelativeEq + UlpsEq + AbsDiffEq,
351{
352    type RealField = Self;
353
354    #[inline]
355    fn from_real(re: Self::RealField) -> Self {
356        re
357    }
358
359    #[inline]
360    fn real(self) -> Self::RealField {
361        self
362    }
363
364    #[inline]
365    fn imaginary(self) -> Self::RealField {
366        Self::zero()
367    }
368
369    #[inline]
370    fn modulus(self) -> Self::RealField {
371        self.abs()
372    }
373
374    #[inline]
375    fn modulus_squared(self) -> Self::RealField {
376        self * self
377    }
378
379    #[inline]
380    fn argument(self) -> Self::RealField {
381        Self::zero()
382    }
383
384    #[inline]
385    fn norm1(self) -> Self::RealField {
386        self.abs()
387    }
388
389    #[inline]
390    fn scale(self, factor: Self::RealField) -> Self {
391        self * factor
392    }
393
394    #[inline]
395    fn unscale(self, factor: Self::RealField) -> Self {
396        self / factor
397    }
398
399    #[inline]
400    fn floor(self) -> Self {
401        panic!("called floor() on a dual number")
402    }
403
404    #[inline]
405    fn ceil(self) -> Self {
406        panic!("called ceil() on a dual number")
407    }
408
409    #[inline]
410    fn round(self) -> Self {
411        panic!("called round() on a dual number")
412    }
413
414    #[inline]
415    fn trunc(self) -> Self {
416        panic!("called trunc() on a dual number")
417    }
418
419    #[inline]
420    fn fract(self) -> Self {
421        panic!("called fract() on a dual number")
422    }
423
424    #[inline]
425    fn mul_add(self, a: Self, b: Self) -> Self {
426        DualNum::mul_add(&self, a, b)
427    }
428
429    #[inline]
430    fn abs(self) -> Self::RealField {
431        Signed::abs(&self)
432    }
433
434    #[inline]
435    fn hypot(self, other: Self) -> Self::RealField {
436        let sum_sq = self.powi(2) + other.powi(2);
437        DualNum::sqrt(&sum_sq)
438    }
439
440    #[inline]
441    fn recip(self) -> Self {
442        DualNum::recip(&self)
443    }
444
445    #[inline]
446    fn conjugate(self) -> Self {
447        self
448    }
449
450    #[inline]
451    fn sin(self) -> Self {
452        DualNum::sin(&self)
453    }
454
455    #[inline]
456    fn cos(self) -> Self {
457        DualNum::cos(&self)
458    }
459
460    #[inline]
461    fn sin_cos(self) -> (Self, Self) {
462        DualNum::sin_cos(&self)
463    }
464
465    #[inline]
466    fn tan(self) -> Self {
467        DualNum::tan(&self)
468    }
469
470    #[inline]
471    fn asin(self) -> Self {
472        DualNum::asin(&self)
473    }
474
475    #[inline]
476    fn acos(self) -> Self {
477        DualNum::acos(&self)
478    }
479
480    #[inline]
481    fn atan(self) -> Self {
482        DualNum::atan(&self)
483    }
484
485    #[inline]
486    fn sinh(self) -> Self {
487        DualNum::sinh(&self)
488    }
489
490    #[inline]
491    fn cosh(self) -> Self {
492        DualNum::cosh(&self)
493    }
494
495    #[inline]
496    fn tanh(self) -> Self {
497        DualNum::tanh(&self)
498    }
499
500    #[inline]
501    fn asinh(self) -> Self {
502        DualNum::asinh(&self)
503    }
504
505    #[inline]
506    fn acosh(self) -> Self {
507        DualNum::acosh(&self)
508    }
509
510    #[inline]
511    fn atanh(self) -> Self {
512        DualNum::atanh(&self)
513    }
514
515    #[inline]
516    fn log(self, base: Self::RealField) -> Self {
517        DualNum::ln(&self) / DualNum::ln(&base)
518    }
519
520    #[inline]
521    fn log2(self) -> Self {
522        DualNum::log2(&self)
523    }
524
525    #[inline]
526    fn log10(self) -> Self {
527        DualNum::log10(&self)
528    }
529
530    #[inline]
531    fn ln(self) -> Self {
532        DualNum::ln(&self)
533    }
534
535    #[inline]
536    fn ln_1p(self) -> Self {
537        DualNum::ln_1p(&self)
538    }
539
540    #[inline]
541    fn sqrt(self) -> Self {
542        DualNum::sqrt(&self)
543    }
544
545    #[inline]
546    fn exp(self) -> Self {
547        DualNum::exp(&self)
548    }
549
550    #[inline]
551    fn exp2(self) -> Self {
552        DualNum::exp2(&self)
553    }
554
555    #[inline]
556    fn exp_m1(self) -> Self {
557        DualNum::exp_m1(&self)
558    }
559
560    #[inline]
561    fn powi(self, n: i32) -> Self {
562        DualNum::powi(&self, n)
563    }
564
565    #[inline]
566    fn powf(self, n: Self::RealField) -> Self {
567        // n could be a dual.
568        DualNum::powd(&self, n)
569    }
570
571    #[inline]
572    fn powc(self, n: Self) -> Self {
573        // same as powf, Self isn't complex
574        self.powf(n)
575    }
576
577    #[inline]
578    fn cbrt(self) -> Self {
579        DualNum::cbrt(&self)
580    }
581
582    #[inline]
583    fn is_finite(&self) -> bool {
584        self.re.is_finite()
585    }
586
587    #[inline]
588    fn try_sqrt(self) -> Option<Self> {
589        if self > Self::zero() {
590            Some(DualNum::sqrt(&self))
591        } else {
592            None
593        }
594    }
595}
596
597impl<T> RealField for Dual<T, T::Element>
598where
599    T: DualNum<T::Element> + SupersetOf<T> + Sync + Send,
600    T::Element: DualNum<T::Element> + Scalar + DualNumFloat,
601    T: SupersetOf<T::Element>,
602    T: SupersetOf<f32>,
603    T: SupersetOf<f64>,
604    T: SimdPartialOrd + PartialOrd,
605    T: RelativeEq + AbsDiffEq<Epsilon = T>,
606    T: SimdValue<Element = T, SimdBool = bool>,
607    T: UlpsEq,
608    T: AbsDiffEq,
609{
610    #[inline]
611    fn copysign(self, sign: Self) -> Self {
612        if sign.re.is_sign_positive() {
613            self.simd_abs()
614        } else {
615            -self.simd_abs()
616        }
617    }
618
619    #[inline]
620    fn atan2(self, other: Self) -> Self {
621        DualNum::atan2(&self, other)
622    }
623
624    #[inline]
625    fn pi() -> Self {
626        Self::from_re(<T as FloatConst>::PI())
627    }
628
629    #[inline]
630    fn two_pi() -> Self {
631        Self::from_re(<T as FloatConst>::TAU())
632    }
633
634    #[inline]
635    fn frac_pi_2() -> Self {
636        Self::from_re(<T as FloatConst>::FRAC_PI_4())
637    }
638
639    #[inline]
640    fn frac_pi_3() -> Self {
641        Self::from_re(<T as FloatConst>::FRAC_PI_3())
642    }
643
644    #[inline]
645    fn frac_pi_4() -> Self {
646        Self::from_re(<T as FloatConst>::FRAC_PI_4())
647    }
648
649    #[inline]
650    fn frac_pi_6() -> Self {
651        Self::from_re(<T as FloatConst>::FRAC_PI_6())
652    }
653
654    #[inline]
655    fn frac_pi_8() -> Self {
656        Self::from_re(<T as FloatConst>::FRAC_PI_8())
657    }
658
659    #[inline]
660    fn frac_1_pi() -> Self {
661        Self::from_re(<T as FloatConst>::FRAC_1_PI())
662    }
663
664    #[inline]
665    fn frac_2_pi() -> Self {
666        Self::from_re(<T as FloatConst>::FRAC_2_PI())
667    }
668
669    #[inline]
670    fn frac_2_sqrt_pi() -> Self {
671        Self::from_re(<T as FloatConst>::FRAC_2_SQRT_PI())
672    }
673
674    #[inline]
675    fn e() -> Self {
676        Self::from_re(<T as FloatConst>::E())
677    }
678
679    #[inline]
680    fn log2_e() -> Self {
681        Self::from_re(<T as FloatConst>::LOG2_E())
682    }
683
684    #[inline]
685    fn log10_e() -> Self {
686        Self::from_re(<T as FloatConst>::LOG10_E())
687    }
688
689    #[inline]
690    fn ln_2() -> Self {
691        Self::from_re(<T as FloatConst>::LN_2())
692    }
693
694    #[inline]
695    fn ln_10() -> Self {
696        Self::from_re(<T as FloatConst>::LN_10())
697    }
698
699    #[inline]
700    fn is_sign_positive(&self) -> bool {
701        self.re.is_sign_positive()
702    }
703
704    #[inline]
705    fn is_sign_negative(&self) -> bool {
706        self.re.is_sign_negative()
707    }
708
709    /// Got to be careful using this, because it throws away the derivatives of the one not chosen
710    #[inline]
711    fn max(self, other: Self) -> Self {
712        if other > self { other } else { self }
713    }
714
715    /// Got to be careful using this, because it throws away the derivatives of the one not chosen
716    #[inline]
717    fn min(self, other: Self) -> Self {
718        if other < self { other } else { self }
719    }
720
721    /// If the min/max values are constants and the clamping has an effect, you lose your gradients.
722    #[inline]
723    fn clamp(self, min: Self, max: Self) -> Self {
724        if self < min {
725            min
726        } else if self > max {
727            max
728        } else {
729            self
730        }
731    }
732
733    #[inline]
734    fn min_value() -> Option<Self> {
735        Some(Self::from_re(T::min_value()))
736    }
737
738    #[inline]
739    fn max_value() -> Option<Self> {
740        Some(Self::from_re(T::max_value()))
741    }
742}