1use crate::{DualNum, DualNumFloat, DualStruct};
2use approx::{AbsDiffEq, RelativeEq, UlpsEq};
3use nalgebra::*;
4use num_traits::{Float, FloatConst, FromPrimitive, Inv, Num, One, Signed, Zero};
5#[cfg(feature = "serde")]
6use serde::{Deserialize, Serialize};
7use std::fmt;
8use std::iter::{Product, Sum};
9use std::marker::PhantomData;
10use std::ops::{
11 Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
12};
13
/// A dual number `re + eps·ε` with `ε² = 0`, used for forward-mode automatic
/// differentiation of first derivatives.
///
/// `T` is the type of both components (a float, or another dual-like type for
/// nested differentiation); `F` is the underlying floating-point type.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Dual<T: DualNum<F>, F> {
    /// Real (function-value) part.
    pub re: T,
    /// Derivative (ε) part.
    pub eps: T,
    // Zero-sized marker tying `T` to its float type `F`; it carries no data,
    // so it is skipped during (de)serialization.
    #[cfg_attr(feature = "serde", serde(skip))]
    f: PhantomData<F>,
}
25
/// Lets `Dual` values act as scalar operands in `ndarray` arithmetic.
/// Only compiled when the `ndarray` feature is enabled.
#[cfg(feature = "ndarray")]
impl<T, F> ndarray::ScalarOperand for Dual<T, F>
where
    T: DualNum<F>,
    F: DualNumFloat,
{
}
28
/// Single-precision dual number (`f32` components).
pub type Dual32 = Dual<f32, f32>;
/// Double-precision dual number (`f64` components).
pub type Dual64 = Dual<f64, f64>;
31
32impl<T: DualNum<F>, F> Dual<T, F> {
33 #[inline]
35 pub fn new(re: T, eps: T) -> Self {
36 Self {
37 re,
38 eps,
39 f: PhantomData,
40 }
41 }
42}
43
44impl<T: DualNum<F> + Zero, F> Dual<T, F> {
45 #[inline]
47 pub fn from_re(re: T) -> Self {
48 Self::new(re, T::zero())
49 }
50}
51
52impl<T: DualNum<F> + One, F> Dual<T, F> {
53 #[inline]
61 pub fn derivative(mut self) -> Self {
62 self.eps = T::one();
63 self
64 }
65}
66
67impl<T: DualNum<F>, F: Float> Dual<T, F> {
69 #[inline]
70 fn chain_rule(&self, f0: T, f1: T) -> Self {
71 Self::new(f0, self.eps.clone() * f1)
72 }
73}
74
75impl<T: DualNum<F>, F: Float> Mul<&Dual<T, F>> for &Dual<T, F> {
77 type Output = Dual<T, F>;
78 #[inline]
79 fn mul(self, other: &Dual<T, F>) -> Self::Output {
80 Dual::new(
81 self.re.clone() * other.re.clone(),
82 self.eps.clone() * other.re.clone() + other.eps.clone() * self.re.clone(),
83 )
84 }
85}
86
87impl<T: DualNum<F>, F: Float> Div<&Dual<T, F>> for &Dual<T, F> {
89 type Output = Dual<T, F>;
90 #[inline]
91 fn div(self, other: &Dual<T, F>) -> Dual<T, F> {
92 let inv = other.re.recip();
93 Dual::new(
94 self.re.clone() * inv.clone(),
95 (self.eps.clone() * other.re.clone() - other.eps.clone() * self.re.clone())
96 * inv.clone()
97 * inv,
98 )
99 }
100}
101
102impl<T: DualNum<F>, F> fmt::Display for Dual<T, F> {
104 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
105 write!(f, "{} + {}ε", self.re, self.eps)
106 }
107}
108
// Crate-level macros (defined outside this file) presumably generate the
// remaining arithmetic/function impls for `Dual`; `[eps]` names the derivative
// field they operate on.
// NOTE(review): the private `chain_rule` helper above has no caller in this
// file and is presumably used by `impl_first_derivatives!` — confirm.
impl_first_derivatives!(Dual, [eps]);
impl_dual!(Dual, [eps]);
111
112impl<T> nalgebra::SimdValue for Dual<T, T::Element>
131where
132 T: DualNum<T::Element> + SimdValue + Scalar,
133 T::Element: DualNum<T::Element> + Scalar,
134{
135 type Element = Dual<T::Element, T::Element>;
144 type SimdBool = T::SimdBool;
145
146 const LANES: usize = T::LANES;
147
148 #[inline]
149 fn splat(val: Self::Element) -> Self {
150 let re = T::splat(val.re);
154 let eps = T::splat(val.eps);
155 Self::new(re, eps)
156 }
157
158 #[inline]
159 fn extract(&self, i: usize) -> Self::Element {
160 let re = self.re.extract(i);
161 let eps = self.eps.extract(i);
162 Self::Element {
163 re,
164 eps,
165 f: PhantomData,
166 }
167 }
168
169 #[inline]
170 unsafe fn extract_unchecked(&self, i: usize) -> Self::Element {
171 let re = unsafe { self.re.extract_unchecked(i) };
172 let eps = unsafe { self.eps.extract_unchecked(i) };
173 Self::Element {
174 re,
175 eps,
176 f: PhantomData,
177 }
178 }
179
180 #[inline]
181 fn replace(&mut self, i: usize, val: Self::Element) {
182 self.re.replace(i, val.re);
183 self.eps.replace(i, val.eps);
184 }
185
186 #[inline]
187 unsafe fn replace_unchecked(&mut self, i: usize, val: Self::Element) {
188 unsafe { self.re.replace_unchecked(i, val.re) };
189 unsafe { self.eps.replace_unchecked(i, val.eps) };
190 }
191
192 #[inline]
193 fn select(self, cond: Self::SimdBool, other: Self) -> Self {
194 let re = self.re.select(cond, other.re);
195 let eps = self.eps.select(cond, other.eps);
196 Self::new(re, eps)
197 }
198}
199
200impl<T: DualNum<F> + PartialEq, F: Float> PartialEq for Dual<T, F> {
203 #[inline]
204 fn eq(&self, other: &Self) -> bool {
205 self.re.eq(&other.re)
206 }
207}
208impl<T: DualNum<F> + PartialOrd, F: Float> PartialOrd for Dual<T, F> {
211 #[inline]
212 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
213 self.re.partial_cmp(&other.re)
214 }
215}
216impl<T: DualNum<F> + approx::AbsDiffEq<Epsilon = T>, F: Float> approx::AbsDiffEq for Dual<T, F> {
219 type Epsilon = Self;
220 #[inline]
221 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
222 self.re.abs_diff_eq(&other.re, epsilon.re)
223 }
224
225 #[inline]
226 fn default_epsilon() -> Self::Epsilon {
227 Self::from_re(T::default_epsilon())
228 }
229}
230impl<T: DualNum<F> + approx::RelativeEq<Epsilon = T>, F: Float> approx::RelativeEq for Dual<T, F> {
233 #[inline]
234 fn default_max_relative() -> Self::Epsilon {
235 Self::from_re(T::default_max_relative())
236 }
237
238 #[inline]
239 fn relative_eq(
240 &self,
241 other: &Self,
242 epsilon: Self::Epsilon,
243 max_relative: Self::Epsilon,
244 ) -> bool {
245 self.re.relative_eq(&other.re, epsilon.re, max_relative.re)
246 }
247}
248impl<T: DualNum<F> + UlpsEq<Epsilon = T>, F: Float> UlpsEq for Dual<T, F> {
249 #[inline]
250 fn default_max_ulps() -> u32 {
251 T::default_max_ulps()
252 }
253
254 #[inline]
255 fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool {
256 T::ulps_eq(&self.re, &other.re, epsilon.re, max_ulps)
257 }
258}
259
/// Marks `Dual` as a `nalgebra` field.
///
/// The impl body is empty. The trait's arithmetic requirements are presumably
/// met by the operator impls generated for `Dual` elsewhere (e.g. by the crate
/// macros invoked in this file).
impl<T> nalgebra::Field for Dual<T, T::Element>
where
    T: DualNum<T::Element> + SimdValue,
    T::Element: DualNum<T::Element> + Scalar + Float,
{
}
266
267use simba::scalar::{SubsetOf, SupersetOf};
268
269impl<TSuper, FSuper, T, F> SubsetOf<Dual<TSuper, FSuper>> for Dual<T, F>
270where
271 TSuper: DualNum<FSuper> + SupersetOf<T>,
272 T: DualNum<F>,
273{
274 #[inline(always)]
275 fn to_superset(&self) -> Dual<TSuper, FSuper> {
276 let re = TSuper::from_subset(&self.re);
277 let eps = TSuper::from_subset(&self.eps);
278 Dual {
279 re,
280 eps,
281 f: PhantomData,
282 }
283 }
284 #[inline(always)]
285 fn from_superset(element: &Dual<TSuper, FSuper>) -> Option<Self> {
286 let re = TSuper::to_subset(&element.re)?;
287 let eps = TSuper::to_subset(&element.eps)?;
288 Some(Self::new(re, eps))
289 }
290 #[inline(always)]
291 fn from_superset_unchecked(element: &Dual<TSuper, FSuper>) -> Self {
292 let re = TSuper::to_subset_unchecked(&element.re);
293 let eps = TSuper::to_subset_unchecked(&element.eps);
294 Self::new(re, eps)
295 }
296 #[inline(always)]
297 fn is_in_subset(element: &Dual<TSuper, FSuper>) -> bool {
298 TSuper::is_in_subset(&element.re) && TSuper::is_in_subset(&element.eps)
299 }
300}
301
302impl<TSuper, FSuper> SupersetOf<f32> for Dual<TSuper, FSuper>
303where
304 TSuper: DualNum<FSuper> + SupersetOf<f32>,
305{
306 #[inline(always)]
307 fn is_in_subset(&self) -> bool {
308 self.re.is_in_subset()
309 }
310
311 #[inline(always)]
312 fn to_subset_unchecked(&self) -> f32 {
313 self.re.to_subset_unchecked()
314 }
315
316 #[inline(always)]
317 fn from_subset(element: &f32) -> Self {
318 let re = TSuper::from_subset(element);
320 let eps = TSuper::zero();
321 Self::new(re, eps)
322 }
323}
324
325impl<TSuper, FSuper> SupersetOf<f64> for Dual<TSuper, FSuper>
326where
327 TSuper: DualNum<FSuper> + SupersetOf<f64>,
328{
329 #[inline(always)]
330 fn is_in_subset(&self) -> bool {
331 self.re.is_in_subset()
332 }
333
334 #[inline(always)]
335 fn to_subset_unchecked(&self) -> f64 {
336 self.re.to_subset_unchecked()
337 }
338
339 #[inline(always)]
340 fn from_subset(element: &f64) -> Self {
341 let re = TSuper::from_subset(element);
343 let eps = TSuper::zero();
344 Self::new(re, eps)
345 }
346}
347
348use nalgebra::{ComplexField, RealField};
356impl<T> ComplexField for Dual<T, T::Element>
358where
359 T: DualNum<T::Element> + SupersetOf<T> + AbsDiffEq<Epsilon = T> + Sync + Send,
360 T::Element: DualNum<T::Element> + Scalar + DualNumFloat + Sync + Send,
361 T: SupersetOf<T::Element>,
362 T: SupersetOf<f32>,
363 T: SupersetOf<f64>,
364 T: SimdPartialOrd + PartialOrd,
365 T: SimdValue<Element = T, SimdBool = bool>,
366 T: RelativeEq + UlpsEq + AbsDiffEq,
367{
368 type RealField = Self;
369
370 #[inline]
371 fn from_real(re: Self::RealField) -> Self {
372 re
373 }
374
375 #[inline]
376 fn real(self) -> Self::RealField {
377 self
378 }
379
380 #[inline]
381 fn imaginary(self) -> Self::RealField {
382 Self::zero()
383 }
384
385 #[inline]
386 fn modulus(self) -> Self::RealField {
387 self.abs()
388 }
389
390 #[inline]
391 fn modulus_squared(self) -> Self::RealField {
392 self * self
393 }
394
395 #[inline]
396 fn argument(self) -> Self::RealField {
397 Self::zero()
398 }
399
400 #[inline]
401 fn norm1(self) -> Self::RealField {
402 self.abs()
403 }
404
405 #[inline]
406 fn scale(self, factor: Self::RealField) -> Self {
407 self * factor
408 }
409
410 #[inline]
411 fn unscale(self, factor: Self::RealField) -> Self {
412 self / factor
413 }
414
415 #[inline]
416 fn floor(self) -> Self {
417 panic!("called floor() on a dual number")
418 }
419
420 #[inline]
421 fn ceil(self) -> Self {
422 panic!("called ceil() on a dual number")
423 }
424
425 #[inline]
426 fn round(self) -> Self {
427 panic!("called round() on a dual number")
428 }
429
430 #[inline]
431 fn trunc(self) -> Self {
432 panic!("called trunc() on a dual number")
433 }
434
435 #[inline]
436 fn fract(self) -> Self {
437 panic!("called fract() on a dual number")
438 }
439
440 #[inline]
441 fn mul_add(self, a: Self, b: Self) -> Self {
442 DualNum::mul_add(&self, a, b)
443 }
444
445 #[inline]
446 fn abs(self) -> Self::RealField {
447 Signed::abs(&self)
448 }
449
450 #[inline]
451 fn hypot(self, other: Self) -> Self::RealField {
452 let sum_sq = self.powi(2) + other.powi(2);
453 DualNum::sqrt(&sum_sq)
454 }
455
456 #[inline]
457 fn recip(self) -> Self {
458 DualNum::recip(&self)
459 }
460
461 #[inline]
462 fn conjugate(self) -> Self {
463 self
464 }
465
466 #[inline]
467 fn sin(self) -> Self {
468 DualNum::sin(&self)
469 }
470
471 #[inline]
472 fn cos(self) -> Self {
473 DualNum::cos(&self)
474 }
475
476 #[inline]
477 fn sin_cos(self) -> (Self, Self) {
478 DualNum::sin_cos(&self)
479 }
480
481 #[inline]
482 fn tan(self) -> Self {
483 DualNum::tan(&self)
484 }
485
486 #[inline]
487 fn asin(self) -> Self {
488 DualNum::asin(&self)
489 }
490
491 #[inline]
492 fn acos(self) -> Self {
493 DualNum::acos(&self)
494 }
495
496 #[inline]
497 fn atan(self) -> Self {
498 DualNum::atan(&self)
499 }
500
501 #[inline]
502 fn sinh(self) -> Self {
503 DualNum::sinh(&self)
504 }
505
506 #[inline]
507 fn cosh(self) -> Self {
508 DualNum::cosh(&self)
509 }
510
511 #[inline]
512 fn tanh(self) -> Self {
513 DualNum::tanh(&self)
514 }
515
516 #[inline]
517 fn asinh(self) -> Self {
518 DualNum::asinh(&self)
519 }
520
521 #[inline]
522 fn acosh(self) -> Self {
523 DualNum::acosh(&self)
524 }
525
526 #[inline]
527 fn atanh(self) -> Self {
528 DualNum::atanh(&self)
529 }
530
531 #[inline]
532 fn log(self, base: Self::RealField) -> Self {
533 DualNum::ln(&self) / DualNum::ln(&base)
534 }
535
536 #[inline]
537 fn log2(self) -> Self {
538 DualNum::log2(&self)
539 }
540
541 #[inline]
542 fn log10(self) -> Self {
543 DualNum::log10(&self)
544 }
545
546 #[inline]
547 fn ln(self) -> Self {
548 DualNum::ln(&self)
549 }
550
551 #[inline]
552 fn ln_1p(self) -> Self {
553 DualNum::ln_1p(&self)
554 }
555
556 #[inline]
557 fn sqrt(self) -> Self {
558 DualNum::sqrt(&self)
559 }
560
561 #[inline]
562 fn exp(self) -> Self {
563 DualNum::exp(&self)
564 }
565
566 #[inline]
567 fn exp2(self) -> Self {
568 DualNum::exp2(&self)
569 }
570
571 #[inline]
572 fn exp_m1(self) -> Self {
573 DualNum::exp_m1(&self)
574 }
575
576 #[inline]
577 fn powi(self, n: i32) -> Self {
578 DualNum::powi(&self, n)
579 }
580
581 #[inline]
582 fn powf(self, n: Self::RealField) -> Self {
583 DualNum::powd(&self, n)
585 }
586
587 #[inline]
588 fn powc(self, n: Self) -> Self {
589 self.powf(n)
591 }
592
593 #[inline]
594 fn cbrt(self) -> Self {
595 DualNum::cbrt(&self)
596 }
597
598 #[inline]
599 fn is_finite(&self) -> bool {
600 self.re.is_finite()
601 }
602
603 #[inline]
604 fn try_sqrt(self) -> Option<Self> {
605 if self > Self::zero() {
606 Some(DualNum::sqrt(&self))
607 } else {
608 None
609 }
610 }
611}
612
613impl<T> RealField for Dual<T, T::Element>
614where
615 T: DualNum<T::Element> + SupersetOf<T> + Sync + Send,
616 T::Element: DualNum<T::Element> + Scalar + DualNumFloat,
617 T: SupersetOf<T::Element>,
618 T: SupersetOf<f32>,
619 T: SupersetOf<f64>,
620 T: SimdPartialOrd + PartialOrd,
621 T: RelativeEq + AbsDiffEq<Epsilon = T>,
622 T: SimdValue<Element = T, SimdBool = bool>,
623 T: UlpsEq,
624 T: AbsDiffEq,
625{
626 #[inline]
627 fn copysign(self, sign: Self) -> Self {
628 if sign.re.is_sign_positive() {
629 self.simd_abs()
630 } else {
631 -self.simd_abs()
632 }
633 }
634
635 #[inline]
636 fn atan2(self, other: Self) -> Self {
637 DualNum::atan2(&self, other)
638 }
639
640 #[inline]
641 fn pi() -> Self {
642 Self::from_re(<T as FloatConst>::PI())
643 }
644
645 #[inline]
646 fn two_pi() -> Self {
647 Self::from_re(<T as FloatConst>::TAU())
648 }
649
650 #[inline]
651 fn frac_pi_2() -> Self {
652 Self::from_re(<T as FloatConst>::FRAC_PI_4())
653 }
654
655 #[inline]
656 fn frac_pi_3() -> Self {
657 Self::from_re(<T as FloatConst>::FRAC_PI_3())
658 }
659
660 #[inline]
661 fn frac_pi_4() -> Self {
662 Self::from_re(<T as FloatConst>::FRAC_PI_4())
663 }
664
665 #[inline]
666 fn frac_pi_6() -> Self {
667 Self::from_re(<T as FloatConst>::FRAC_PI_6())
668 }
669
670 #[inline]
671 fn frac_pi_8() -> Self {
672 Self::from_re(<T as FloatConst>::FRAC_PI_8())
673 }
674
675 #[inline]
676 fn frac_1_pi() -> Self {
677 Self::from_re(<T as FloatConst>::FRAC_1_PI())
678 }
679
680 #[inline]
681 fn frac_2_pi() -> Self {
682 Self::from_re(<T as FloatConst>::FRAC_2_PI())
683 }
684
685 #[inline]
686 fn frac_2_sqrt_pi() -> Self {
687 Self::from_re(<T as FloatConst>::FRAC_2_SQRT_PI())
688 }
689
690 #[inline]
691 fn e() -> Self {
692 Self::from_re(<T as FloatConst>::E())
693 }
694
695 #[inline]
696 fn log2_e() -> Self {
697 Self::from_re(<T as FloatConst>::LOG2_E())
698 }
699
700 #[inline]
701 fn log10_e() -> Self {
702 Self::from_re(<T as FloatConst>::LOG10_E())
703 }
704
705 #[inline]
706 fn ln_2() -> Self {
707 Self::from_re(<T as FloatConst>::LN_2())
708 }
709
710 #[inline]
711 fn ln_10() -> Self {
712 Self::from_re(<T as FloatConst>::LN_10())
713 }
714
715 #[inline]
716 fn is_sign_positive(&self) -> bool {
717 self.re.is_sign_positive()
718 }
719
720 #[inline]
721 fn is_sign_negative(&self) -> bool {
722 self.re.is_sign_negative()
723 }
724
725 #[inline]
727 fn max(self, other: Self) -> Self {
728 if other > self { other } else { self }
729 }
730
731 #[inline]
733 fn min(self, other: Self) -> Self {
734 if other < self { other } else { self }
735 }
736
737 #[inline]
739 fn clamp(self, min: Self, max: Self) -> Self {
740 if self < min {
741 min
742 } else if self > max {
743 max
744 } else {
745 self
746 }
747 }
748
749 #[inline]
750 fn min_value() -> Option<Self> {
751 Some(Self::from_re(T::min_value()))
752 }
753
754 #[inline]
755 fn max_value() -> Option<Self> {
756 Some(Self::from_re(T::max_value()))
757 }
758}