1use crate::{DualNum, DualNumFloat, DualStruct};
2use approx::{AbsDiffEq, RelativeEq, UlpsEq};
3use nalgebra::*;
4use num_traits::{Float, FloatConst, FromPrimitive, Inv, Num, One, Signed, Zero};
5#[cfg(feature = "serde")]
6use serde::{Deserialize, Serialize};
7use std::fmt;
8use std::iter::{Product, Sum};
9use std::marker::PhantomData;
10use std::ops::{
11 Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
12};
13
/// A dual number `re + eps·ε` (with `ε² = 0`, as reflected by the product
/// rule in the `Mul` impl) used to propagate first derivatives.
///
/// `T` is the type of the components and `F` is the underlying scalar
/// floating-point type.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Dual<T: DualNum<F>, F> {
    /// Real part of the dual number.
    pub re: T,
    /// First-derivative (ε) part of the dual number.
    pub eps: T,
    // Marker tying the struct to `F` without storing a value of it;
    // skipped during (de)serialization.
    #[cfg_attr(feature = "serde", serde(skip))]
    f: PhantomData<F>,
}
25
// Marker impl so `Dual` can act as a scalar operand in ndarray arithmetic;
// only compiled when the `ndarray` feature is enabled.
#[cfg(feature = "ndarray")]
impl<T: DualNum<F>, F: DualNumFloat> ndarray::ScalarOperand for Dual<T, F> {}
28
/// Single-precision dual number (`f32` components).
pub type Dual32 = Dual<f32, f32>;
/// Double-precision dual number (`f64` components).
pub type Dual64 = Dual<f64, f64>;
31
32impl<T: DualNum<F>, F> Dual<T, F> {
33 #[inline]
35 pub fn new(re: T, eps: T) -> Self {
36 Self {
37 re,
38 eps,
39 f: PhantomData,
40 }
41 }
42}
43
44impl<T: DualNum<F> + Zero, F> Dual<T, F> {
45 #[inline]
47 pub fn from_re(re: T) -> Self {
48 Self::new(re, T::zero())
49 }
50}
51
52impl<T: DualNum<F> + One, F> Dual<T, F> {
53 #[inline]
61 pub fn derivative(mut self) -> Self {
62 self.eps = T::one();
63 self
64 }
65}
66
67impl<T: DualNum<F>, F: Float> Dual<T, F> {
69 #[inline]
70 fn chain_rule(&self, f0: T, f1: T) -> Self {
71 Self::new(f0, self.eps.clone() * f1)
72 }
73}
74
75impl<T: DualNum<F>, F: Float> Mul<&Dual<T, F>> for &Dual<T, F> {
77 type Output = Dual<T, F>;
78 #[inline]
79 fn mul(self, other: &Dual<T, F>) -> Self::Output {
80 Dual::new(
81 self.re.clone() * other.re.clone(),
82 self.eps.clone() * other.re.clone() + other.eps.clone() * self.re.clone(),
83 )
84 }
85}
86
87impl<T: DualNum<F>, F: Float> Div<&Dual<T, F>> for &Dual<T, F> {
89 type Output = Dual<T, F>;
90 #[inline]
91 fn div(self, other: &Dual<T, F>) -> Dual<T, F> {
92 let inv = other.re.recip();
93 Dual::new(
94 self.re.clone() * inv.clone(),
95 (self.eps.clone() * other.re.clone() - other.eps.clone() * self.re.clone())
96 * inv.clone()
97 * inv,
98 )
99 }
100}
101
102impl<T: DualNum<F>, F> fmt::Display for Dual<T, F> {
104 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
105 write!(f, "{} + {}ε", self.re, self.eps)
106 }
107}
108
// Trait impls generated by macros defined elsewhere in the crate. Judging by
// their names, these presumably provide the elementary functions (via
// `chain_rule`) and the remaining operator / num-traits impls for the single
// derivative field `eps` — confirm against the macro definitions.
impl_first_derivatives!(Dual, [eps]);
impl_dual!(Dual, [eps]);
111
112impl<T> nalgebra::SimdValue for Dual<T, T::Element>
131where
132 T: DualNum<T::Element> + SimdValue + Scalar,
133 T::Element: DualNum<T::Element> + Scalar,
134{
135 type Element = Dual<T::Element, T::Element>;
144 type SimdBool = T::SimdBool;
145
146 const LANES: usize = T::LANES;
147
148 #[inline]
149 fn splat(val: Self::Element) -> Self {
150 let re = T::splat(val.re);
154 let eps = T::splat(val.eps);
155 Self::new(re, eps)
156 }
157
158 #[inline]
159 fn extract(&self, i: usize) -> Self::Element {
160 let re = self.re.extract(i);
161 let eps = self.eps.extract(i);
162 Self::Element {
163 re,
164 eps,
165 f: PhantomData,
166 }
167 }
168
169 #[inline]
170 unsafe fn extract_unchecked(&self, i: usize) -> Self::Element {
171 let re = unsafe { self.re.extract_unchecked(i) };
172 let eps = unsafe { self.eps.extract_unchecked(i) };
173 Self::Element {
174 re,
175 eps,
176 f: PhantomData,
177 }
178 }
179
180 #[inline]
181 fn replace(&mut self, i: usize, val: Self::Element) {
182 self.re.replace(i, val.re);
183 self.eps.replace(i, val.eps);
184 }
185
186 #[inline]
187 unsafe fn replace_unchecked(&mut self, i: usize, val: Self::Element) {
188 unsafe { self.re.replace_unchecked(i, val.re) };
189 unsafe { self.eps.replace_unchecked(i, val.eps) };
190 }
191
192 #[inline]
193 fn select(self, cond: Self::SimdBool, other: Self) -> Self {
194 let re = self.re.select(cond, other.re);
195 let eps = self.eps.select(cond, other.eps);
196 Self::new(re, eps)
197 }
198}
199
200impl<T: DualNum<F> + approx::AbsDiffEq<Epsilon = T>, F: Float> approx::AbsDiffEq for Dual<T, F> {
203 type Epsilon = Self;
204 #[inline]
205 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
206 self.re.abs_diff_eq(&other.re, epsilon.re)
207 }
208
209 #[inline]
210 fn default_epsilon() -> Self::Epsilon {
211 Self::from_re(T::default_epsilon())
212 }
213}
214impl<T: DualNum<F> + approx::RelativeEq<Epsilon = T>, F: Float> approx::RelativeEq for Dual<T, F> {
217 #[inline]
218 fn default_max_relative() -> Self::Epsilon {
219 Self::from_re(T::default_max_relative())
220 }
221
222 #[inline]
223 fn relative_eq(
224 &self,
225 other: &Self,
226 epsilon: Self::Epsilon,
227 max_relative: Self::Epsilon,
228 ) -> bool {
229 self.re.relative_eq(&other.re, epsilon.re, max_relative.re)
230 }
231}
232impl<T: DualNum<F> + UlpsEq<Epsilon = T>, F: Float> UlpsEq for Dual<T, F> {
233 #[inline]
234 fn default_max_ulps() -> u32 {
235 T::default_max_ulps()
236 }
237
238 #[inline]
239 fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool {
240 T::ulps_eq(&self.re, &other.re, epsilon.re, max_ulps)
241 }
242}
243
// Empty marker impl: declares dual numbers (with a `Float` lane element type)
// to be an nalgebra `Field`. The operator and `SimdValue` impls the trait relies
// on are provided elsewhere (e.g. the `SimdValue` impl above in this file and
// presumably the `impl_dual!` macro).
impl<T> nalgebra::Field for Dual<T, T::Element>
where
    T: DualNum<T::Element> + SimdValue,
    T::Element: DualNum<T::Element> + Scalar + Float,
{
}
250
251use simba::scalar::{SubsetOf, SupersetOf};
252
253impl<TSuper, FSuper, T, F> SubsetOf<Dual<TSuper, FSuper>> for Dual<T, F>
254where
255 TSuper: DualNum<FSuper> + SupersetOf<T>,
256 T: DualNum<F>,
257{
258 #[inline(always)]
259 fn to_superset(&self) -> Dual<TSuper, FSuper> {
260 let re = TSuper::from_subset(&self.re);
261 let eps = TSuper::from_subset(&self.eps);
262 Dual {
263 re,
264 eps,
265 f: PhantomData,
266 }
267 }
268 #[inline(always)]
269 fn from_superset(element: &Dual<TSuper, FSuper>) -> Option<Self> {
270 let re = TSuper::to_subset(&element.re)?;
271 let eps = TSuper::to_subset(&element.eps)?;
272 Some(Self::new(re, eps))
273 }
274 #[inline(always)]
275 fn from_superset_unchecked(element: &Dual<TSuper, FSuper>) -> Self {
276 let re = TSuper::to_subset_unchecked(&element.re);
277 let eps = TSuper::to_subset_unchecked(&element.eps);
278 Self::new(re, eps)
279 }
280 #[inline(always)]
281 fn is_in_subset(element: &Dual<TSuper, FSuper>) -> bool {
282 TSuper::is_in_subset(&element.re) && TSuper::is_in_subset(&element.eps)
283 }
284}
285
286impl<TSuper, FSuper> SupersetOf<f32> for Dual<TSuper, FSuper>
287where
288 TSuper: DualNum<FSuper> + SupersetOf<f32>,
289{
290 #[inline(always)]
291 fn is_in_subset(&self) -> bool {
292 self.re.is_in_subset()
293 }
294
295 #[inline(always)]
296 fn to_subset_unchecked(&self) -> f32 {
297 self.re.to_subset_unchecked()
298 }
299
300 #[inline(always)]
301 fn from_subset(element: &f32) -> Self {
302 let re = TSuper::from_subset(element);
304 let eps = TSuper::zero();
305 Self::new(re, eps)
306 }
307}
308
309impl<TSuper, FSuper> SupersetOf<f64> for Dual<TSuper, FSuper>
310where
311 TSuper: DualNum<FSuper> + SupersetOf<f64>,
312{
313 #[inline(always)]
314 fn is_in_subset(&self) -> bool {
315 self.re.is_in_subset()
316 }
317
318 #[inline(always)]
319 fn to_subset_unchecked(&self) -> f64 {
320 self.re.to_subset_unchecked()
321 }
322
323 #[inline(always)]
324 fn from_subset(element: &f64) -> Self {
325 let re = TSuper::from_subset(element);
327 let eps = TSuper::zero();
328 Self::new(re, eps)
329 }
330}
331
332use nalgebra::{ComplexField, RealField};
340impl<T> ComplexField for Dual<T, T::Element>
342where
343 T: DualNum<T::Element> + SupersetOf<T> + AbsDiffEq<Epsilon = T> + Sync + Send,
344 T::Element: DualNum<T::Element> + Scalar + DualNumFloat + Sync + Send,
345 T: SupersetOf<T::Element>,
346 T: SupersetOf<f32>,
347 T: SupersetOf<f64>,
348 T: SimdPartialOrd + PartialOrd,
349 T: SimdValue<Element = T, SimdBool = bool>,
350 T: RelativeEq + UlpsEq + AbsDiffEq,
351{
352 type RealField = Self;
353
354 #[inline]
355 fn from_real(re: Self::RealField) -> Self {
356 re
357 }
358
359 #[inline]
360 fn real(self) -> Self::RealField {
361 self
362 }
363
364 #[inline]
365 fn imaginary(self) -> Self::RealField {
366 Self::zero()
367 }
368
369 #[inline]
370 fn modulus(self) -> Self::RealField {
371 self.abs()
372 }
373
374 #[inline]
375 fn modulus_squared(self) -> Self::RealField {
376 self * self
377 }
378
379 #[inline]
380 fn argument(self) -> Self::RealField {
381 Self::zero()
382 }
383
384 #[inline]
385 fn norm1(self) -> Self::RealField {
386 self.abs()
387 }
388
389 #[inline]
390 fn scale(self, factor: Self::RealField) -> Self {
391 self * factor
392 }
393
394 #[inline]
395 fn unscale(self, factor: Self::RealField) -> Self {
396 self / factor
397 }
398
399 #[inline]
400 fn floor(self) -> Self {
401 panic!("called floor() on a dual number")
402 }
403
404 #[inline]
405 fn ceil(self) -> Self {
406 panic!("called ceil() on a dual number")
407 }
408
409 #[inline]
410 fn round(self) -> Self {
411 panic!("called round() on a dual number")
412 }
413
414 #[inline]
415 fn trunc(self) -> Self {
416 panic!("called trunc() on a dual number")
417 }
418
419 #[inline]
420 fn fract(self) -> Self {
421 panic!("called fract() on a dual number")
422 }
423
424 #[inline]
425 fn mul_add(self, a: Self, b: Self) -> Self {
426 DualNum::mul_add(&self, a, b)
427 }
428
429 #[inline]
430 fn abs(self) -> Self::RealField {
431 Signed::abs(&self)
432 }
433
434 #[inline]
435 fn hypot(self, other: Self) -> Self::RealField {
436 let sum_sq = self.powi(2) + other.powi(2);
437 DualNum::sqrt(&sum_sq)
438 }
439
440 #[inline]
441 fn recip(self) -> Self {
442 DualNum::recip(&self)
443 }
444
445 #[inline]
446 fn conjugate(self) -> Self {
447 self
448 }
449
450 #[inline]
451 fn sin(self) -> Self {
452 DualNum::sin(&self)
453 }
454
455 #[inline]
456 fn cos(self) -> Self {
457 DualNum::cos(&self)
458 }
459
460 #[inline]
461 fn sin_cos(self) -> (Self, Self) {
462 DualNum::sin_cos(&self)
463 }
464
465 #[inline]
466 fn tan(self) -> Self {
467 DualNum::tan(&self)
468 }
469
470 #[inline]
471 fn asin(self) -> Self {
472 DualNum::asin(&self)
473 }
474
475 #[inline]
476 fn acos(self) -> Self {
477 DualNum::acos(&self)
478 }
479
480 #[inline]
481 fn atan(self) -> Self {
482 DualNum::atan(&self)
483 }
484
485 #[inline]
486 fn sinh(self) -> Self {
487 DualNum::sinh(&self)
488 }
489
490 #[inline]
491 fn cosh(self) -> Self {
492 DualNum::cosh(&self)
493 }
494
495 #[inline]
496 fn tanh(self) -> Self {
497 DualNum::tanh(&self)
498 }
499
500 #[inline]
501 fn asinh(self) -> Self {
502 DualNum::asinh(&self)
503 }
504
505 #[inline]
506 fn acosh(self) -> Self {
507 DualNum::acosh(&self)
508 }
509
510 #[inline]
511 fn atanh(self) -> Self {
512 DualNum::atanh(&self)
513 }
514
515 #[inline]
516 fn log(self, base: Self::RealField) -> Self {
517 DualNum::ln(&self) / DualNum::ln(&base)
518 }
519
520 #[inline]
521 fn log2(self) -> Self {
522 DualNum::log2(&self)
523 }
524
525 #[inline]
526 fn log10(self) -> Self {
527 DualNum::log10(&self)
528 }
529
530 #[inline]
531 fn ln(self) -> Self {
532 DualNum::ln(&self)
533 }
534
535 #[inline]
536 fn ln_1p(self) -> Self {
537 DualNum::ln_1p(&self)
538 }
539
540 #[inline]
541 fn sqrt(self) -> Self {
542 DualNum::sqrt(&self)
543 }
544
545 #[inline]
546 fn exp(self) -> Self {
547 DualNum::exp(&self)
548 }
549
550 #[inline]
551 fn exp2(self) -> Self {
552 DualNum::exp2(&self)
553 }
554
555 #[inline]
556 fn exp_m1(self) -> Self {
557 DualNum::exp_m1(&self)
558 }
559
560 #[inline]
561 fn powi(self, n: i32) -> Self {
562 DualNum::powi(&self, n)
563 }
564
565 #[inline]
566 fn powf(self, n: Self::RealField) -> Self {
567 DualNum::powd(&self, n)
569 }
570
571 #[inline]
572 fn powc(self, n: Self) -> Self {
573 self.powf(n)
575 }
576
577 #[inline]
578 fn cbrt(self) -> Self {
579 DualNum::cbrt(&self)
580 }
581
582 #[inline]
583 fn is_finite(&self) -> bool {
584 self.re.is_finite()
585 }
586
587 #[inline]
588 fn try_sqrt(self) -> Option<Self> {
589 if self > Self::zero() {
590 Some(DualNum::sqrt(&self))
591 } else {
592 None
593 }
594 }
595}
596
597impl<T> RealField for Dual<T, T::Element>
598where
599 T: DualNum<T::Element> + SupersetOf<T> + Sync + Send,
600 T::Element: DualNum<T::Element> + Scalar + DualNumFloat,
601 T: SupersetOf<T::Element>,
602 T: SupersetOf<f32>,
603 T: SupersetOf<f64>,
604 T: SimdPartialOrd + PartialOrd,
605 T: RelativeEq + AbsDiffEq<Epsilon = T>,
606 T: SimdValue<Element = T, SimdBool = bool>,
607 T: UlpsEq,
608 T: AbsDiffEq,
609{
610 #[inline]
611 fn copysign(self, sign: Self) -> Self {
612 if sign.re.is_sign_positive() {
613 self.simd_abs()
614 } else {
615 -self.simd_abs()
616 }
617 }
618
619 #[inline]
620 fn atan2(self, other: Self) -> Self {
621 DualNum::atan2(&self, other)
622 }
623
624 #[inline]
625 fn pi() -> Self {
626 Self::from_re(<T as FloatConst>::PI())
627 }
628
629 #[inline]
630 fn two_pi() -> Self {
631 Self::from_re(<T as FloatConst>::TAU())
632 }
633
634 #[inline]
635 fn frac_pi_2() -> Self {
636 Self::from_re(<T as FloatConst>::FRAC_PI_4())
637 }
638
639 #[inline]
640 fn frac_pi_3() -> Self {
641 Self::from_re(<T as FloatConst>::FRAC_PI_3())
642 }
643
644 #[inline]
645 fn frac_pi_4() -> Self {
646 Self::from_re(<T as FloatConst>::FRAC_PI_4())
647 }
648
649 #[inline]
650 fn frac_pi_6() -> Self {
651 Self::from_re(<T as FloatConst>::FRAC_PI_6())
652 }
653
654 #[inline]
655 fn frac_pi_8() -> Self {
656 Self::from_re(<T as FloatConst>::FRAC_PI_8())
657 }
658
659 #[inline]
660 fn frac_1_pi() -> Self {
661 Self::from_re(<T as FloatConst>::FRAC_1_PI())
662 }
663
664 #[inline]
665 fn frac_2_pi() -> Self {
666 Self::from_re(<T as FloatConst>::FRAC_2_PI())
667 }
668
669 #[inline]
670 fn frac_2_sqrt_pi() -> Self {
671 Self::from_re(<T as FloatConst>::FRAC_2_SQRT_PI())
672 }
673
674 #[inline]
675 fn e() -> Self {
676 Self::from_re(<T as FloatConst>::E())
677 }
678
679 #[inline]
680 fn log2_e() -> Self {
681 Self::from_re(<T as FloatConst>::LOG2_E())
682 }
683
684 #[inline]
685 fn log10_e() -> Self {
686 Self::from_re(<T as FloatConst>::LOG10_E())
687 }
688
689 #[inline]
690 fn ln_2() -> Self {
691 Self::from_re(<T as FloatConst>::LN_2())
692 }
693
694 #[inline]
695 fn ln_10() -> Self {
696 Self::from_re(<T as FloatConst>::LN_10())
697 }
698
699 #[inline]
700 fn is_sign_positive(&self) -> bool {
701 self.re.is_sign_positive()
702 }
703
704 #[inline]
705 fn is_sign_negative(&self) -> bool {
706 self.re.is_sign_negative()
707 }
708
709 #[inline]
711 fn max(self, other: Self) -> Self {
712 if other > self { other } else { self }
713 }
714
715 #[inline]
717 fn min(self, other: Self) -> Self {
718 if other < self { other } else { self }
719 }
720
721 #[inline]
723 fn clamp(self, min: Self, max: Self) -> Self {
724 if self < min {
725 min
726 } else if self > max {
727 max
728 } else {
729 self
730 }
731 }
732
733 #[inline]
734 fn min_value() -> Option<Self> {
735 Some(Self::from_re(T::min_value()))
736 }
737
738 #[inline]
739 fn max_value() -> Option<Self> {
740 Some(Self::from_re(T::max_value()))
741 }
742}