num_dual/datatypes/
dual.rs

1use crate::{DualNum, DualNumFloat, DualStruct};
2use approx::{AbsDiffEq, RelativeEq, UlpsEq};
3use nalgebra::*;
4use num_traits::{Float, FloatConst, FromPrimitive, Inv, Num, One, Signed, Zero};
5#[cfg(feature = "serde")]
6use serde::{Deserialize, Serialize};
7use std::fmt;
8use std::iter::{Product, Sum};
9use std::marker::PhantomData;
10use std::ops::{
11    Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
12};
13
14/// A scalar dual number for the calculations of first derivatives.
15#[derive(Copy, Clone, Debug)]
16#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
17pub struct Dual<T: DualNum<F>, F> {
18    /// Real part of the dual number
19    pub re: T,
20    /// Derivative part of the dual number
21    pub eps: T,
22    #[cfg_attr(feature = "serde", serde(skip))]
23    f: PhantomData<F>,
24}
25
// Marks `Dual` as usable as a scalar operand in `ndarray` arithmetic (only compiled
// when the `ndarray` feature is enabled).
#[cfg(feature = "ndarray")]
impl<T, F> ndarray::ScalarOperand for Dual<T, F>
where
    T: DualNum<F>,
    F: DualNumFloat,
{
}
28
/// First-derivative dual number with `f32` real and derivative parts.
pub type Dual32 = Dual<f32, f32>;
/// First-derivative dual number with `f64` real and derivative parts.
pub type Dual64 = Dual<f64, f64>;
31
32impl<T: DualNum<F>, F> Dual<T, F> {
33    /// Create a new dual number from its fields.
34    #[inline]
35    pub fn new(re: T, eps: T) -> Self {
36        Self {
37            re,
38            eps,
39            f: PhantomData,
40        }
41    }
42}
43
44impl<T: DualNum<F> + Zero, F> Dual<T, F> {
45    /// Create a new dual number from the real part.
46    #[inline]
47    pub fn from_re(re: T) -> Self {
48        Self::new(re, T::zero())
49    }
50}
51
52impl<T: DualNum<F> + One, F> Dual<T, F> {
53    /// Set the derivative part to 1.
54    /// ```
55    /// # use num_dual::{Dual64, DualNum};
56    /// let x = Dual64::from_re(5.0).derivative().powi(2);
57    /// assert_eq!(x.re, 25.0);
58    /// assert_eq!(x.eps, 10.0);
59    /// ```
60    #[inline]
61    pub fn derivative(mut self) -> Self {
62        self.eps = T::one();
63        self
64    }
65}
66
67/* chain rule */
68impl<T: DualNum<F>, F: Float> Dual<T, F> {
69    #[inline]
70    fn chain_rule(&self, f0: T, f1: T) -> Self {
71        Self::new(f0, self.eps.clone() * f1)
72    }
73}
74
75/* product rule */
76impl<T: DualNum<F>, F: Float> Mul<&Dual<T, F>> for &Dual<T, F> {
77    type Output = Dual<T, F>;
78    #[inline]
79    fn mul(self, other: &Dual<T, F>) -> Self::Output {
80        Dual::new(
81            self.re.clone() * other.re.clone(),
82            self.eps.clone() * other.re.clone() + other.eps.clone() * self.re.clone(),
83        )
84    }
85}
86
87/* quotient rule */
88impl<T: DualNum<F>, F: Float> Div<&Dual<T, F>> for &Dual<T, F> {
89    type Output = Dual<T, F>;
90    #[inline]
91    fn div(self, other: &Dual<T, F>) -> Dual<T, F> {
92        let inv = other.re.recip();
93        Dual::new(
94            self.re.clone() * inv.clone(),
95            (self.eps.clone() * other.re.clone() - other.eps.clone() * self.re.clone())
96                * inv.clone()
97                * inv,
98        )
99    }
100}
101
102/* string conversions */
103impl<T: DualNum<F>, F> fmt::Display for Dual<T, F> {
104    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
105        write!(f, "{} + {}ε", self.re, self.eps)
106    }
107}
108
// Macro defined elsewhere in the crate; presumably generates the elementary-function
// derivatives for `Dual` over its `eps` field (likely the users of the private
// `chain_rule` above) — TODO confirm against the macro definition.
impl_first_derivatives!(Dual, [eps]);
// Macro defined elsewhere in the crate; presumably generates the shared dual-number
// boilerplate (e.g. the `DualNum`/`Signed`/owned-operator impls that the `ComplexField`
// impl below relies on but which are not written out in this file) — TODO confirm.
impl_dual!(Dual, [eps]);
111
112/**
113 * The SimdValue trait is for rearranging data into a form more suitable for Simd,
114 * and rearranging it back into a usable form. It is not documented particularly well.
115 *
116 * The primary job of this SimdValue impl is to allow people to use `simba::simd::f32x4` etc,
117 * instead of f32/f64. Those types implement nalgebra::SimdRealField/ComplexField, so they
118 * behave like scalars. When we use them, we would have `Dual<f32x4, f32, N>` etc, with our
119 * F parameter set to `<T as SimdValue>::Element`. We will need to be able to split up that type
120 * into four of Dual in order to get out of simd-land. That's what the SimdValue trait is for.
121 *
122 * Ultimately, someone will have to to implement SimdRealField on Dual and call the
123 * simd_ functions of `<T as SimdRealField>`. That's future work for someone who finds
124 * num_dual is not fast enough.
125 *
126 * Unfortunately, doing anything with SIMD is blocked on
127 * <https://github.com/dimforge/simba/issues/44>.
128 *
129 */
130impl<T> nalgebra::SimdValue for Dual<T, T::Element>
131where
132    T: DualNum<T::Element> + SimdValue + Scalar,
133    T::Element: DualNum<T::Element> + Scalar,
134{
135    // Say T = simba::f32x4. T::Element is f32. T::SimdBool is AutoSimd<[bool; 4]>.
136    // AutoSimd<[f32; 4]> stores an actual [f32; 4], i.e. four floats in one slot.
137    // So our Dual<AutoSimd<[f32; 4], f32, N> has 4 * (1+N) floats in it, stored in blocks of
138    // four. When we want to do any math on it but ignore its f32x4 storage mode, we need to break
139    // that type into FOUR of Dual<f32, f32, N>; then we do math on it, then we bring it back
140    // together.
141    //
142    // Hence this definition of Element:
143    type Element = Dual<T::Element, T::Element>;
144    type SimdBool = T::SimdBool;
145
146    const LANES: usize = T::LANES;
147
148    #[inline]
149    fn splat(val: Self::Element) -> Self {
150        // Need to make `lanes` copies of each of:
151        // - the real part
152        // - each of the N epsilon parts
153        let re = T::splat(val.re);
154        let eps = T::splat(val.eps);
155        Self::new(re, eps)
156    }
157
158    #[inline]
159    fn extract(&self, i: usize) -> Self::Element {
160        let re = self.re.extract(i);
161        let eps = self.eps.extract(i);
162        Self::Element {
163            re,
164            eps,
165            f: PhantomData,
166        }
167    }
168
169    #[inline]
170    unsafe fn extract_unchecked(&self, i: usize) -> Self::Element {
171        let re = unsafe { self.re.extract_unchecked(i) };
172        let eps = unsafe { self.eps.extract_unchecked(i) };
173        Self::Element {
174            re,
175            eps,
176            f: PhantomData,
177        }
178    }
179
180    #[inline]
181    fn replace(&mut self, i: usize, val: Self::Element) {
182        self.re.replace(i, val.re);
183        self.eps.replace(i, val.eps);
184    }
185
186    #[inline]
187    unsafe fn replace_unchecked(&mut self, i: usize, val: Self::Element) {
188        unsafe { self.re.replace_unchecked(i, val.re) };
189        unsafe { self.eps.replace_unchecked(i, val.eps) };
190    }
191
192    #[inline]
193    fn select(self, cond: Self::SimdBool, other: Self) -> Self {
194        let re = self.re.select(cond, other.re);
195        let eps = self.eps.select(cond, other.eps);
196        Self::new(re, eps)
197    }
198}
199
200/// Comparisons are only made based on the real part. This allows the code to follow the
201/// same execution path as real-valued code would.
202impl<T: DualNum<F> + PartialEq, F: Float> PartialEq for Dual<T, F> {
203    #[inline]
204    fn eq(&self, other: &Self) -> bool {
205        self.re.eq(&other.re)
206    }
207}
208/// Like PartialEq, comparisons are only made based on the real part. This allows the code to follow the
209/// same execution path as real-valued code would.
210impl<T: DualNum<F> + PartialOrd, F: Float> PartialOrd for Dual<T, F> {
211    #[inline]
212    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
213        self.re.partial_cmp(&other.re)
214    }
215}
216/// Like PartialEq, comparisons are only made based on the real part. This allows the code to follow the
217/// same execution path as real-valued code would.
218impl<T: DualNum<F> + approx::AbsDiffEq<Epsilon = T>, F: Float> approx::AbsDiffEq for Dual<T, F> {
219    type Epsilon = Self;
220    #[inline]
221    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
222        self.re.abs_diff_eq(&other.re, epsilon.re)
223    }
224
225    #[inline]
226    fn default_epsilon() -> Self::Epsilon {
227        Self::from_re(T::default_epsilon())
228    }
229}
230/// Like PartialEq, comparisons are only made based on the real part. This allows the code to follow the
231/// same execution path as real-valued code would.
232impl<T: DualNum<F> + approx::RelativeEq<Epsilon = T>, F: Float> approx::RelativeEq for Dual<T, F> {
233    #[inline]
234    fn default_max_relative() -> Self::Epsilon {
235        Self::from_re(T::default_max_relative())
236    }
237
238    #[inline]
239    fn relative_eq(
240        &self,
241        other: &Self,
242        epsilon: Self::Epsilon,
243        max_relative: Self::Epsilon,
244    ) -> bool {
245        self.re.relative_eq(&other.re, epsilon.re, max_relative.re)
246    }
247}
248impl<T: DualNum<F> + UlpsEq<Epsilon = T>, F: Float> UlpsEq for Dual<T, F> {
249    #[inline]
250    fn default_max_ulps() -> u32 {
251        T::default_max_ulps()
252    }
253
254    #[inline]
255    fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool {
256        T::ulps_eq(&self.re, &other.re, epsilon.re, max_ulps)
257    }
258}
259
// Marker impl with an empty body: `Dual` over a `SimdValue` component type is declared a
// nalgebra `Field`; everything `Field` needs comes from trait impls provided elsewhere.
impl<T> nalgebra::Field for Dual<T, T::Element>
where
    T: DualNum<T::Element> + SimdValue,
    T::Element: DualNum<T::Element> + Scalar + Float,
{
}
266
267use simba::scalar::{SubsetOf, SupersetOf};
268
269impl<TSuper, FSuper, T, F> SubsetOf<Dual<TSuper, FSuper>> for Dual<T, F>
270where
271    TSuper: DualNum<FSuper> + SupersetOf<T>,
272    T: DualNum<F>,
273{
274    #[inline(always)]
275    fn to_superset(&self) -> Dual<TSuper, FSuper> {
276        let re = TSuper::from_subset(&self.re);
277        let eps = TSuper::from_subset(&self.eps);
278        Dual {
279            re,
280            eps,
281            f: PhantomData,
282        }
283    }
284    #[inline(always)]
285    fn from_superset(element: &Dual<TSuper, FSuper>) -> Option<Self> {
286        let re = TSuper::to_subset(&element.re)?;
287        let eps = TSuper::to_subset(&element.eps)?;
288        Some(Self::new(re, eps))
289    }
290    #[inline(always)]
291    fn from_superset_unchecked(element: &Dual<TSuper, FSuper>) -> Self {
292        let re = TSuper::to_subset_unchecked(&element.re);
293        let eps = TSuper::to_subset_unchecked(&element.eps);
294        Self::new(re, eps)
295    }
296    #[inline(always)]
297    fn is_in_subset(element: &Dual<TSuper, FSuper>) -> bool {
298        TSuper::is_in_subset(&element.re) && TSuper::is_in_subset(&element.eps)
299    }
300}
301
302impl<TSuper, FSuper> SupersetOf<f32> for Dual<TSuper, FSuper>
303where
304    TSuper: DualNum<FSuper> + SupersetOf<f32>,
305{
306    #[inline(always)]
307    fn is_in_subset(&self) -> bool {
308        self.re.is_in_subset()
309    }
310
311    #[inline(always)]
312    fn to_subset_unchecked(&self) -> f32 {
313        self.re.to_subset_unchecked()
314    }
315
316    #[inline(always)]
317    fn from_subset(element: &f32) -> Self {
318        // Interpret as a purely real number
319        let re = TSuper::from_subset(element);
320        let eps = TSuper::zero();
321        Self::new(re, eps)
322    }
323}
324
325impl<TSuper, FSuper> SupersetOf<f64> for Dual<TSuper, FSuper>
326where
327    TSuper: DualNum<FSuper> + SupersetOf<f64>,
328{
329    #[inline(always)]
330    fn is_in_subset(&self) -> bool {
331        self.re.is_in_subset()
332    }
333
334    #[inline(always)]
335    fn to_subset_unchecked(&self) -> f64 {
336        self.re.to_subset_unchecked()
337    }
338
339    #[inline(always)]
340    fn from_subset(element: &f64) -> Self {
341        // Interpret as a purely real number
342        let re = TSuper::from_subset(element);
343        let eps = TSuper::zero();
344        Self::new(re, eps)
345    }
346}
347
348// We can't do a simd implementation until simba lets us implement SimdPartialOrd
349// using _T_'s SimdBool. The blanket impl gets in the way. So we must constrain
350// T to SimdValue<Element = T, SimdBool = bool>, which is basically the same as
351// saying f32 or f64 only.
352//
353// Limitation of simba. See https://github.com/dimforge/simba/issues/44
354
355use nalgebra::{ComplexField, RealField};
356// This impl is modelled on `impl ComplexField for f32`. The imaginary part is nothing.
357impl<T> ComplexField for Dual<T, T::Element>
358where
359    T: DualNum<T::Element> + SupersetOf<T> + AbsDiffEq<Epsilon = T> + Sync + Send,
360    T::Element: DualNum<T::Element> + Scalar + DualNumFloat + Sync + Send,
361    T: SupersetOf<T::Element>,
362    T: SupersetOf<f32>,
363    T: SupersetOf<f64>,
364    T: SimdPartialOrd + PartialOrd,
365    T: SimdValue<Element = T, SimdBool = bool>,
366    T: RelativeEq + UlpsEq + AbsDiffEq,
367{
368    type RealField = Self;
369
370    #[inline]
371    fn from_real(re: Self::RealField) -> Self {
372        re
373    }
374
375    #[inline]
376    fn real(self) -> Self::RealField {
377        self
378    }
379
380    #[inline]
381    fn imaginary(self) -> Self::RealField {
382        Self::zero()
383    }
384
385    #[inline]
386    fn modulus(self) -> Self::RealField {
387        self.abs()
388    }
389
390    #[inline]
391    fn modulus_squared(self) -> Self::RealField {
392        self * self
393    }
394
395    #[inline]
396    fn argument(self) -> Self::RealField {
397        Self::zero()
398    }
399
400    #[inline]
401    fn norm1(self) -> Self::RealField {
402        self.abs()
403    }
404
405    #[inline]
406    fn scale(self, factor: Self::RealField) -> Self {
407        self * factor
408    }
409
410    #[inline]
411    fn unscale(self, factor: Self::RealField) -> Self {
412        self / factor
413    }
414
415    #[inline]
416    fn floor(self) -> Self {
417        panic!("called floor() on a dual number")
418    }
419
420    #[inline]
421    fn ceil(self) -> Self {
422        panic!("called ceil() on a dual number")
423    }
424
425    #[inline]
426    fn round(self) -> Self {
427        panic!("called round() on a dual number")
428    }
429
430    #[inline]
431    fn trunc(self) -> Self {
432        panic!("called trunc() on a dual number")
433    }
434
435    #[inline]
436    fn fract(self) -> Self {
437        panic!("called fract() on a dual number")
438    }
439
440    #[inline]
441    fn mul_add(self, a: Self, b: Self) -> Self {
442        DualNum::mul_add(&self, a, b)
443    }
444
445    #[inline]
446    fn abs(self) -> Self::RealField {
447        Signed::abs(&self)
448    }
449
450    #[inline]
451    fn hypot(self, other: Self) -> Self::RealField {
452        let sum_sq = self.powi(2) + other.powi(2);
453        DualNum::sqrt(&sum_sq)
454    }
455
456    #[inline]
457    fn recip(self) -> Self {
458        DualNum::recip(&self)
459    }
460
461    #[inline]
462    fn conjugate(self) -> Self {
463        self
464    }
465
466    #[inline]
467    fn sin(self) -> Self {
468        DualNum::sin(&self)
469    }
470
471    #[inline]
472    fn cos(self) -> Self {
473        DualNum::cos(&self)
474    }
475
476    #[inline]
477    fn sin_cos(self) -> (Self, Self) {
478        DualNum::sin_cos(&self)
479    }
480
481    #[inline]
482    fn tan(self) -> Self {
483        DualNum::tan(&self)
484    }
485
486    #[inline]
487    fn asin(self) -> Self {
488        DualNum::asin(&self)
489    }
490
491    #[inline]
492    fn acos(self) -> Self {
493        DualNum::acos(&self)
494    }
495
496    #[inline]
497    fn atan(self) -> Self {
498        DualNum::atan(&self)
499    }
500
501    #[inline]
502    fn sinh(self) -> Self {
503        DualNum::sinh(&self)
504    }
505
506    #[inline]
507    fn cosh(self) -> Self {
508        DualNum::cosh(&self)
509    }
510
511    #[inline]
512    fn tanh(self) -> Self {
513        DualNum::tanh(&self)
514    }
515
516    #[inline]
517    fn asinh(self) -> Self {
518        DualNum::asinh(&self)
519    }
520
521    #[inline]
522    fn acosh(self) -> Self {
523        DualNum::acosh(&self)
524    }
525
526    #[inline]
527    fn atanh(self) -> Self {
528        DualNum::atanh(&self)
529    }
530
531    #[inline]
532    fn log(self, base: Self::RealField) -> Self {
533        DualNum::ln(&self) / DualNum::ln(&base)
534    }
535
536    #[inline]
537    fn log2(self) -> Self {
538        DualNum::log2(&self)
539    }
540
541    #[inline]
542    fn log10(self) -> Self {
543        DualNum::log10(&self)
544    }
545
546    #[inline]
547    fn ln(self) -> Self {
548        DualNum::ln(&self)
549    }
550
551    #[inline]
552    fn ln_1p(self) -> Self {
553        DualNum::ln_1p(&self)
554    }
555
556    #[inline]
557    fn sqrt(self) -> Self {
558        DualNum::sqrt(&self)
559    }
560
561    #[inline]
562    fn exp(self) -> Self {
563        DualNum::exp(&self)
564    }
565
566    #[inline]
567    fn exp2(self) -> Self {
568        DualNum::exp2(&self)
569    }
570
571    #[inline]
572    fn exp_m1(self) -> Self {
573        DualNum::exp_m1(&self)
574    }
575
576    #[inline]
577    fn powi(self, n: i32) -> Self {
578        DualNum::powi(&self, n)
579    }
580
581    #[inline]
582    fn powf(self, n: Self::RealField) -> Self {
583        // n could be a dual.
584        DualNum::powd(&self, n)
585    }
586
587    #[inline]
588    fn powc(self, n: Self) -> Self {
589        // same as powf, Self isn't complex
590        self.powf(n)
591    }
592
593    #[inline]
594    fn cbrt(self) -> Self {
595        DualNum::cbrt(&self)
596    }
597
598    #[inline]
599    fn is_finite(&self) -> bool {
600        self.re.is_finite()
601    }
602
603    #[inline]
604    fn try_sqrt(self) -> Option<Self> {
605        if self > Self::zero() {
606            Some(DualNum::sqrt(&self))
607        } else {
608            None
609        }
610    }
611}
612
613impl<T> RealField for Dual<T, T::Element>
614where
615    T: DualNum<T::Element> + SupersetOf<T> + Sync + Send,
616    T::Element: DualNum<T::Element> + Scalar + DualNumFloat,
617    T: SupersetOf<T::Element>,
618    T: SupersetOf<f32>,
619    T: SupersetOf<f64>,
620    T: SimdPartialOrd + PartialOrd,
621    T: RelativeEq + AbsDiffEq<Epsilon = T>,
622    T: SimdValue<Element = T, SimdBool = bool>,
623    T: UlpsEq,
624    T: AbsDiffEq,
625{
626    #[inline]
627    fn copysign(self, sign: Self) -> Self {
628        if sign.re.is_sign_positive() {
629            self.simd_abs()
630        } else {
631            -self.simd_abs()
632        }
633    }
634
635    #[inline]
636    fn atan2(self, other: Self) -> Self {
637        DualNum::atan2(&self, other)
638    }
639
640    #[inline]
641    fn pi() -> Self {
642        Self::from_re(<T as FloatConst>::PI())
643    }
644
645    #[inline]
646    fn two_pi() -> Self {
647        Self::from_re(<T as FloatConst>::TAU())
648    }
649
650    #[inline]
651    fn frac_pi_2() -> Self {
652        Self::from_re(<T as FloatConst>::FRAC_PI_4())
653    }
654
655    #[inline]
656    fn frac_pi_3() -> Self {
657        Self::from_re(<T as FloatConst>::FRAC_PI_3())
658    }
659
660    #[inline]
661    fn frac_pi_4() -> Self {
662        Self::from_re(<T as FloatConst>::FRAC_PI_4())
663    }
664
665    #[inline]
666    fn frac_pi_6() -> Self {
667        Self::from_re(<T as FloatConst>::FRAC_PI_6())
668    }
669
670    #[inline]
671    fn frac_pi_8() -> Self {
672        Self::from_re(<T as FloatConst>::FRAC_PI_8())
673    }
674
675    #[inline]
676    fn frac_1_pi() -> Self {
677        Self::from_re(<T as FloatConst>::FRAC_1_PI())
678    }
679
680    #[inline]
681    fn frac_2_pi() -> Self {
682        Self::from_re(<T as FloatConst>::FRAC_2_PI())
683    }
684
685    #[inline]
686    fn frac_2_sqrt_pi() -> Self {
687        Self::from_re(<T as FloatConst>::FRAC_2_SQRT_PI())
688    }
689
690    #[inline]
691    fn e() -> Self {
692        Self::from_re(<T as FloatConst>::E())
693    }
694
695    #[inline]
696    fn log2_e() -> Self {
697        Self::from_re(<T as FloatConst>::LOG2_E())
698    }
699
700    #[inline]
701    fn log10_e() -> Self {
702        Self::from_re(<T as FloatConst>::LOG10_E())
703    }
704
705    #[inline]
706    fn ln_2() -> Self {
707        Self::from_re(<T as FloatConst>::LN_2())
708    }
709
710    #[inline]
711    fn ln_10() -> Self {
712        Self::from_re(<T as FloatConst>::LN_10())
713    }
714
715    #[inline]
716    fn is_sign_positive(&self) -> bool {
717        self.re.is_sign_positive()
718    }
719
720    #[inline]
721    fn is_sign_negative(&self) -> bool {
722        self.re.is_sign_negative()
723    }
724
725    /// Got to be careful using this, because it throws away the derivatives of the one not chosen
726    #[inline]
727    fn max(self, other: Self) -> Self {
728        if other > self { other } else { self }
729    }
730
731    /// Got to be careful using this, because it throws away the derivatives of the one not chosen
732    #[inline]
733    fn min(self, other: Self) -> Self {
734        if other < self { other } else { self }
735    }
736
737    /// If the min/max values are constants and the clamping has an effect, you lose your gradients.
738    #[inline]
739    fn clamp(self, min: Self, max: Self) -> Self {
740        if self < min {
741            min
742        } else if self > max {
743            max
744        } else {
745            self
746        }
747    }
748
749    #[inline]
750    fn min_value() -> Option<Self> {
751        Some(Self::from_re(T::min_value()))
752    }
753
754    #[inline]
755    fn max_value() -> Option<Self> {
756        Some(Self::from_re(T::max_value()))
757    }
758}