1use crate::{DualNum, DualNumFloat, DualStruct};
2use approx::{AbsDiffEq, RelativeEq, UlpsEq};
3use nalgebra::*;
4use num_traits::{Float, FloatConst, FromPrimitive, Inv, Num, One, Signed, Zero};
5#[cfg(feature = "serde")]
6use serde::{Deserialize, Serialize};
7use std::fmt;
8use std::iter::{Product, Sum};
9use std::marker::PhantomData;
10use std::ops::{
11 Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
12};
13
/// A scalar second-order dual number made of a real part and two derivative parts.
///
/// Its `Display` form is `re + v1ε1 + v2ε1²`.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Dual2<T: DualNum<F>, F> {
    /// Real part of the second-order dual number.
    pub re: T,
    /// First-derivative part (printed as the coefficient of `ε1`).
    pub v1: T,
    /// Second-derivative part (printed as the coefficient of `ε1²`).
    pub v2: T,
    // Zero-sized marker that ties the struct to `F`; skipped by serde.
    #[cfg_attr(feature = "serde", serde(skip))]
    f: PhantomData<F>,
}
27
// Marker impl: opts `Dual2` into `ndarray::ScalarOperand`. Only compiled when the
// `ndarray` feature is enabled.
#[cfg(feature = "ndarray")]
impl<T: DualNum<F>, F: DualNumFloat> ndarray::ScalarOperand for Dual2<T, F> {}
30
/// `Dual2` whose parts are all `f32`.
pub type Dual2_32 = Dual2<f32, f32>;
/// `Dual2` whose parts are all `f64`.
pub type Dual2_64 = Dual2<f64, f64>;
33
34impl<T: DualNum<F>, F> Dual2<T, F> {
35 #[inline]
37 pub fn new(re: T, v1: T, v2: T) -> Self {
38 Self {
39 re,
40 v1,
41 v2,
42 f: PhantomData,
43 }
44 }
45}
46
47impl<T: DualNum<F>, F> Dual2<T, F> {
48 #[inline]
70 pub fn derivative(mut self) -> Self {
71 self.v1 = T::one();
72 self
73 }
74}
75
76impl<T: DualNum<F>, F> Dual2<T, F> {
77 #[inline]
79 pub fn from_re(re: T) -> Self {
80 Self::new(re, T::zero(), T::zero())
81 }
82}
83
84impl<T: DualNum<F>, F: Float> Dual2<T, F> {
86 #[inline]
87 fn chain_rule(&self, f0: T, f1: T, f2: T) -> Self {
88 Self::new(
89 f0,
90 self.v1.clone() * f1.clone(),
91 self.v2.clone() * f1 + self.v1.clone() * self.v1.clone() * f2,
92 )
93 }
94}
95
96impl<T: DualNum<F>, F: Float> Mul<&Dual2<T, F>> for &Dual2<T, F> {
98 type Output = Dual2<T, F>;
99 #[inline]
100 fn mul(self, other: &Dual2<T, F>) -> Dual2<T, F> {
101 Dual2::new(
102 self.re.clone() * other.re.clone(),
103 other.v1.clone() * self.re.clone() + self.v1.clone() * other.re.clone(),
104 other.v2.clone() * self.re.clone()
105 + self.v1.clone() * other.v1.clone()
106 + other.v1.clone() * self.v1.clone()
107 + self.v2.clone() * other.re.clone(),
108 )
109 }
110}
111
112impl<T: DualNum<F>, F: Float> Div<&Dual2<T, F>> for &Dual2<T, F> {
114 type Output = Dual2<T, F>;
115 #[inline]
116 fn div(self, other: &Dual2<T, F>) -> Dual2<T, F> {
117 let inv = other.re.recip();
118 let inv2 = inv.clone() * inv.clone();
119 Dual2::new(
120 self.re.clone() * inv.clone(),
121 (self.v1.clone() * other.re.clone() - other.v1.clone() * self.re.clone())
122 * inv2.clone(),
123 self.v2.clone() * inv.clone()
124 - (other.v2.clone() * self.re.clone()
125 + self.v1.clone() * other.v1.clone()
126 + other.v1.clone() * self.v1.clone())
127 * inv2.clone()
128 + other.v1.clone()
129 * other.v1.clone()
130 * ((T::one() + T::one()) * self.re.clone() * inv2 * inv),
131 )
132 }
133}
134
135impl<T: DualNum<F>, F: fmt::Display> fmt::Display for Dual2<T, F> {
137 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
138 write!(f, "{} + {}ε1 + {}ε1²", self.re, self.v1, self.v2)
139 }
140}
141
// Crate-level macros defined outside this chunk, invoked for `Dual2` with its
// derivative fields `v1` and `v2`.
// NOTE(review): presumably these generate the `DualNum` implementation and the
// arithmetic operator impls used throughout this file — confirm against the macro definitions.
impl_second_derivatives!(Dual2, [v1, v2]);
impl_dual!(Dual2, [v1, v2]);
144
145impl<T> nalgebra::SimdValue for Dual2<T, T::Element>
164where
165 T: DualNum<T::Element> + SimdValue + Scalar,
166 T::Element: DualNum<T::Element> + Scalar,
167{
168 type Element = Dual2<T::Element, T::Element>;
177 type SimdBool = T::SimdBool;
178
179 const LANES: usize = T::LANES;
180
181 #[inline]
182 fn splat(val: Self::Element) -> Self {
183 let re = T::splat(val.re);
187 let v1 = T::splat(val.v1);
188 let v2 = T::splat(val.v2);
189 Self::new(re, v1, v2)
190 }
191
192 #[inline]
193 fn extract(&self, i: usize) -> Self::Element {
194 let re = self.re.extract(i);
195 let v1 = self.v1.extract(i);
196 let v2 = self.v2.extract(i);
197 Self::Element {
198 re,
199 v1,
200 v2,
201 f: PhantomData,
202 }
203 }
204
205 #[inline]
206 unsafe fn extract_unchecked(&self, i: usize) -> Self::Element {
207 let re = unsafe { self.re.extract_unchecked(i) };
208 let v1 = unsafe { self.v1.extract_unchecked(i) };
209 let v2 = unsafe { self.v2.extract_unchecked(i) };
210 Self::Element {
211 re,
212 v1,
213 v2,
214 f: PhantomData,
215 }
216 }
217
218 #[inline]
219 fn replace(&mut self, i: usize, val: Self::Element) {
220 self.re.replace(i, val.re);
221 self.v1.replace(i, val.v1);
222 self.v2.replace(i, val.v2);
223 }
224
225 #[inline]
226 unsafe fn replace_unchecked(&mut self, i: usize, val: Self::Element) {
227 unsafe { self.re.replace_unchecked(i, val.re) };
228 unsafe { self.v1.replace_unchecked(i, val.v1) };
229 unsafe { self.v2.replace_unchecked(i, val.v2) };
230 }
231
232 #[inline]
233 fn select(self, cond: Self::SimdBool, other: Self) -> Self {
234 let re = self.re.select(cond, other.re);
235 let v1 = self.v1.select(cond, other.v1);
236 let v2 = self.v2.select(cond, other.v2);
237 Self::new(re, v1, v2)
238 }
239}
240
241impl<T: DualNum<F> + PartialEq, F: Float> PartialEq for Dual2<T, F> {
244 #[inline]
245 fn eq(&self, other: &Self) -> bool {
246 self.re.eq(&other.re)
247 }
248}
249impl<T: DualNum<F> + PartialOrd, F: Float> PartialOrd for Dual2<T, F> {
252 #[inline]
253 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
254 self.re.partial_cmp(&other.re)
255 }
256}
257impl<T: DualNum<F> + approx::AbsDiffEq<Epsilon = T>, F: Float> approx::AbsDiffEq for Dual2<T, F> {
260 type Epsilon = Self;
261 #[inline]
262 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
263 self.re.abs_diff_eq(&other.re, epsilon.re)
264 }
265
266 #[inline]
267 fn default_epsilon() -> Self::Epsilon {
268 Self::from_re(T::default_epsilon())
269 }
270}
271impl<T: DualNum<F> + approx::RelativeEq<Epsilon = T>, F: Float> approx::RelativeEq for Dual2<T, F> {
274 #[inline]
275 fn default_max_relative() -> Self::Epsilon {
276 Self::from_re(T::default_max_relative())
277 }
278
279 #[inline]
280 fn relative_eq(
281 &self,
282 other: &Self,
283 epsilon: Self::Epsilon,
284 max_relative: Self::Epsilon,
285 ) -> bool {
286 self.re.relative_eq(&other.re, epsilon.re, max_relative.re)
287 }
288}
289impl<T: DualNum<F> + UlpsEq<Epsilon = T>, F: Float> UlpsEq for Dual2<T, F> {
290 #[inline]
291 fn default_max_ulps() -> u32 {
292 T::default_max_ulps()
293 }
294
295 #[inline]
296 fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool {
297 T::ulps_eq(&self.re, &other.re, epsilon.re, max_ulps)
298 }
299}
300
// Marker impl: `nalgebra::Field` declares no items here, so the body is empty.
impl<T> nalgebra::Field for Dual2<T, T::Element>
where
    T: DualNum<T::Element> + SimdValue,
    T::Element: DualNum<T::Element> + Scalar + Float,
{
}
307
308use simba::scalar::{SubsetOf, SupersetOf};
309
310impl<TSuper, FSuper, T, F> SubsetOf<Dual2<TSuper, FSuper>> for Dual2<T, F>
311where
312 TSuper: DualNum<FSuper> + SupersetOf<T>,
313 T: DualNum<F>,
314{
315 #[inline(always)]
316 fn to_superset(&self) -> Dual2<TSuper, FSuper> {
317 let re = TSuper::from_subset(&self.re);
318 let v1 = TSuper::from_subset(&self.v1);
319 let v2 = TSuper::from_subset(&self.v2);
320 Dual2 {
321 re,
322 v1,
323 v2,
324 f: PhantomData,
325 }
326 }
327 #[inline(always)]
328 fn from_superset(element: &Dual2<TSuper, FSuper>) -> Option<Self> {
329 let re = TSuper::to_subset(&element.re)?;
330 let v1 = TSuper::to_subset(&element.v1)?;
331 let v2 = TSuper::to_subset(&element.v2)?;
332 Some(Self::new(re, v1, v2))
333 }
334 #[inline(always)]
335 fn from_superset_unchecked(element: &Dual2<TSuper, FSuper>) -> Self {
336 let re = TSuper::to_subset_unchecked(&element.re);
337 let v1 = TSuper::to_subset_unchecked(&element.v1);
338 let v2 = TSuper::to_subset_unchecked(&element.v2);
339 Self::new(re, v1, v2)
340 }
341 #[inline(always)]
342 fn is_in_subset(element: &Dual2<TSuper, FSuper>) -> bool {
343 TSuper::is_in_subset(&element.re)
344 && TSuper::is_in_subset(&element.v1)
345 && TSuper::is_in_subset(&element.v2)
346 }
347}
348
349impl<TSuper, FSuper> SupersetOf<f32> for Dual2<TSuper, FSuper>
350where
351 TSuper: DualNum<FSuper> + SupersetOf<f32>,
352{
353 #[inline(always)]
354 fn is_in_subset(&self) -> bool {
355 self.re.is_in_subset()
356 }
357
358 #[inline(always)]
359 fn to_subset_unchecked(&self) -> f32 {
360 self.re.to_subset_unchecked()
361 }
362
363 #[inline(always)]
364 fn from_subset(element: &f32) -> Self {
365 let re = TSuper::from_subset(element);
367 let v1 = TSuper::zero();
368 let v2 = TSuper::zero();
369 Self::new(re, v1, v2)
370 }
371}
372
373impl<TSuper, FSuper> SupersetOf<f64> for Dual2<TSuper, FSuper>
374where
375 TSuper: DualNum<FSuper> + SupersetOf<f64>,
376{
377 #[inline(always)]
378 fn is_in_subset(&self) -> bool {
379 self.re.is_in_subset()
380 }
381
382 #[inline(always)]
383 fn to_subset_unchecked(&self) -> f64 {
384 self.re.to_subset_unchecked()
385 }
386
387 #[inline(always)]
388 fn from_subset(element: &f64) -> Self {
389 let re = TSuper::from_subset(element);
391 let v1 = TSuper::zero();
392 let v2 = TSuper::zero();
393 Self::new(re, v1, v2)
394 }
395}
396
397use nalgebra::{ComplexField, RealField};
/// `ComplexField` implementation that treats `Dual2` as a purely real field:
/// `RealField` is `Self`, the imaginary part is zero and conjugation is the identity.
/// Most methods simply forward to the corresponding `DualNum` function.
impl<T> ComplexField for Dual2<T, T::Element>
where
    T: DualNum<T::Element> + SupersetOf<T> + AbsDiffEq<Epsilon = T> + Sync + Send,
    T::Element: DualNum<T::Element> + Scalar + DualNumFloat + Sync + Send,
    T: SupersetOf<T::Element>,
    T: SupersetOf<f32>,
    T: SupersetOf<f64>,
    T: SimdPartialOrd + PartialOrd,
    T: SimdValue<Element = T, SimdBool = bool>,
    T: RelativeEq + UlpsEq + AbsDiffEq,
{
    type RealField = Self;

    /// Identity: the real field is the type itself.
    #[inline]
    fn from_real(re: Self::RealField) -> Self {
        re
    }

    /// Identity: the whole number is its own real part.
    #[inline]
    fn real(self) -> Self::RealField {
        self
    }

    /// Always zero.
    #[inline]
    fn imaginary(self) -> Self::RealField {
        Self::zero()
    }

    /// Same as `abs`.
    #[inline]
    fn modulus(self) -> Self::RealField {
        self.abs()
    }

    /// The number multiplied by itself.
    #[inline]
    fn modulus_squared(self) -> Self::RealField {
        self * self
    }

    /// Always zero.
    // NOTE(review): this is zero even when the real part is negative — confirm
    // that callers do not expect π for negative values.
    #[inline]
    fn argument(self) -> Self::RealField {
        Self::zero()
    }

    /// Same as `abs`.
    #[inline]
    fn norm1(self) -> Self::RealField {
        self.abs()
    }

    /// Multiplies by `factor`.
    #[inline]
    fn scale(self, factor: Self::RealField) -> Self {
        self * factor
    }

    /// Divides by `factor`.
    #[inline]
    fn unscale(self, factor: Self::RealField) -> Self {
        self / factor
    }

    /// Not supported for dual numbers.
    ///
    /// # Panics
    ///
    /// Always panics.
    #[inline]
    fn floor(self) -> Self {
        panic!("called floor() on a dual number")
    }

    /// Not supported for dual numbers.
    ///
    /// # Panics
    ///
    /// Always panics.
    #[inline]
    fn ceil(self) -> Self {
        panic!("called ceil() on a dual number")
    }

    /// Not supported for dual numbers.
    ///
    /// # Panics
    ///
    /// Always panics.
    #[inline]
    fn round(self) -> Self {
        panic!("called round() on a dual number")
    }

    /// Not supported for dual numbers.
    ///
    /// # Panics
    ///
    /// Always panics.
    #[inline]
    fn trunc(self) -> Self {
        panic!("called trunc() on a dual number")
    }

    /// Not supported for dual numbers.
    ///
    /// # Panics
    ///
    /// Always panics.
    #[inline]
    fn fract(self) -> Self {
        panic!("called fract() on a dual number")
    }

    /// Forwards to `DualNum::mul_add`.
    #[inline]
    fn mul_add(self, a: Self, b: Self) -> Self {
        DualNum::mul_add(&self, a, b)
    }

    /// Forwards to `Signed::abs`.
    #[inline]
    fn abs(self) -> Self::RealField {
        Signed::abs(&self)
    }

    /// Square root of `self² + other²`, computed with `powi(2)` and `DualNum::sqrt`.
    #[inline]
    fn hypot(self, other: Self) -> Self::RealField {
        let sum_sq = self.powi(2) + other.powi(2);
        DualNum::sqrt(&sum_sq)
    }

    /// Forwards to `DualNum::recip`.
    #[inline]
    fn recip(self) -> Self {
        DualNum::recip(&self)
    }

    /// Identity: there is no imaginary part to negate.
    #[inline]
    fn conjugate(self) -> Self {
        self
    }

    // --- Trigonometric and hyperbolic functions: plain `DualNum` forwards. ---

    #[inline]
    fn sin(self) -> Self {
        DualNum::sin(&self)
    }

    #[inline]
    fn cos(self) -> Self {
        DualNum::cos(&self)
    }

    #[inline]
    fn sin_cos(self) -> (Self, Self) {
        DualNum::sin_cos(&self)
    }

    #[inline]
    fn tan(self) -> Self {
        DualNum::tan(&self)
    }

    #[inline]
    fn asin(self) -> Self {
        DualNum::asin(&self)
    }

    #[inline]
    fn acos(self) -> Self {
        DualNum::acos(&self)
    }

    #[inline]
    fn atan(self) -> Self {
        DualNum::atan(&self)
    }

    #[inline]
    fn sinh(self) -> Self {
        DualNum::sinh(&self)
    }

    #[inline]
    fn cosh(self) -> Self {
        DualNum::cosh(&self)
    }

    #[inline]
    fn tanh(self) -> Self {
        DualNum::tanh(&self)
    }

    #[inline]
    fn asinh(self) -> Self {
        DualNum::asinh(&self)
    }

    #[inline]
    fn acosh(self) -> Self {
        DualNum::acosh(&self)
    }

    #[inline]
    fn atanh(self) -> Self {
        DualNum::atanh(&self)
    }

    // --- Logarithms, roots and exponentials. ---

    /// Logarithm with a dual `base`, computed as `ln(self) / ln(base)`.
    #[inline]
    fn log(self, base: Self::RealField) -> Self {
        DualNum::ln(&self) / DualNum::ln(&base)
    }

    #[inline]
    fn log2(self) -> Self {
        DualNum::log2(&self)
    }

    #[inline]
    fn log10(self) -> Self {
        DualNum::log10(&self)
    }

    #[inline]
    fn ln(self) -> Self {
        DualNum::ln(&self)
    }

    #[inline]
    fn ln_1p(self) -> Self {
        DualNum::ln_1p(&self)
    }

    #[inline]
    fn sqrt(self) -> Self {
        DualNum::sqrt(&self)
    }

    #[inline]
    fn exp(self) -> Self {
        DualNum::exp(&self)
    }

    #[inline]
    fn exp2(self) -> Self {
        DualNum::exp2(&self)
    }

    #[inline]
    fn exp_m1(self) -> Self {
        DualNum::exp_m1(&self)
    }

    #[inline]
    fn powi(self, n: i32) -> Self {
        DualNum::powi(&self, n)
    }

    /// Power with a dual exponent; forwards to `DualNum::powd`.
    #[inline]
    fn powf(self, n: Self::RealField) -> Self {
        DualNum::powd(&self, n)
    }

    /// Same as `powf`, since the "complex" exponent is the same type.
    #[inline]
    fn powc(self, n: Self) -> Self {
        self.powf(n)
    }

    #[inline]
    fn cbrt(self) -> Self {
        DualNum::cbrt(&self)
    }

    /// Checks the real part only; `v1` and `v2` are not inspected.
    #[inline]
    fn is_finite(&self) -> bool {
        self.re.is_finite()
    }

    /// Returns the square root when `self > 0` and `None` otherwise.
    ///
    /// Note that zero also yields `None`, because the comparison is strict.
    #[inline]
    fn try_sqrt(self) -> Option<Self> {
        if self > Self::zero() {
            Some(DualNum::sqrt(&self))
        } else {
            None
        }
    }
}
661
662impl<T> RealField for Dual2<T, T::Element>
663where
664 T: DualNum<T::Element> + SupersetOf<T> + Sync + Send,
665 T::Element: DualNum<T::Element> + Scalar + DualNumFloat,
666 T: SupersetOf<T::Element>,
667 T: SupersetOf<f32>,
668 T: SupersetOf<f64>,
669 T: SimdPartialOrd + PartialOrd,
670 T: RelativeEq + AbsDiffEq<Epsilon = T>,
671 T: SimdValue<Element = T, SimdBool = bool>,
672 T: UlpsEq,
673 T: AbsDiffEq,
674{
675 #[inline]
676 fn copysign(self, sign: Self) -> Self {
677 if sign.re.is_sign_positive() {
678 self.simd_abs()
679 } else {
680 -self.simd_abs()
681 }
682 }
683
684 #[inline]
685 fn atan2(self, other: Self) -> Self {
686 DualNum::atan2(&self, other)
687 }
688
689 #[inline]
690 fn pi() -> Self {
691 Self::from_re(<T as FloatConst>::PI())
692 }
693
694 #[inline]
695 fn two_pi() -> Self {
696 Self::from_re(<T as FloatConst>::TAU())
697 }
698
699 #[inline]
700 fn frac_pi_2() -> Self {
701 Self::from_re(<T as FloatConst>::FRAC_PI_4())
702 }
703
704 #[inline]
705 fn frac_pi_3() -> Self {
706 Self::from_re(<T as FloatConst>::FRAC_PI_3())
707 }
708
709 #[inline]
710 fn frac_pi_4() -> Self {
711 Self::from_re(<T as FloatConst>::FRAC_PI_4())
712 }
713
714 #[inline]
715 fn frac_pi_6() -> Self {
716 Self::from_re(<T as FloatConst>::FRAC_PI_6())
717 }
718
719 #[inline]
720 fn frac_pi_8() -> Self {
721 Self::from_re(<T as FloatConst>::FRAC_PI_8())
722 }
723
724 #[inline]
725 fn frac_1_pi() -> Self {
726 Self::from_re(<T as FloatConst>::FRAC_1_PI())
727 }
728
729 #[inline]
730 fn frac_2_pi() -> Self {
731 Self::from_re(<T as FloatConst>::FRAC_2_PI())
732 }
733
734 #[inline]
735 fn frac_2_sqrt_pi() -> Self {
736 Self::from_re(<T as FloatConst>::FRAC_2_SQRT_PI())
737 }
738
739 #[inline]
740 fn e() -> Self {
741 Self::from_re(<T as FloatConst>::E())
742 }
743
744 #[inline]
745 fn log2_e() -> Self {
746 Self::from_re(<T as FloatConst>::LOG2_E())
747 }
748
749 #[inline]
750 fn log10_e() -> Self {
751 Self::from_re(<T as FloatConst>::LOG10_E())
752 }
753
754 #[inline]
755 fn ln_2() -> Self {
756 Self::from_re(<T as FloatConst>::LN_2())
757 }
758
759 #[inline]
760 fn ln_10() -> Self {
761 Self::from_re(<T as FloatConst>::LN_10())
762 }
763
764 #[inline]
765 fn is_sign_positive(&self) -> bool {
766 self.re.is_sign_positive()
767 }
768
769 #[inline]
770 fn is_sign_negative(&self) -> bool {
771 self.re.is_sign_negative()
772 }
773
774 #[inline]
776 fn max(self, other: Self) -> Self {
777 if other > self { other } else { self }
778 }
779
780 #[inline]
782 fn min(self, other: Self) -> Self {
783 if other < self { other } else { self }
784 }
785
786 #[inline]
788 fn clamp(self, min: Self, max: Self) -> Self {
789 if self < min {
790 min
791 } else if self > max {
792 max
793 } else {
794 self
795 }
796 }
797
798 #[inline]
799 fn min_value() -> Option<Self> {
800 Some(Self::from_re(T::min_value()))
801 }
802
803 #[inline]
804 fn max_value() -> Option<Self> {
805 Some(Self::from_re(T::max_value()))
806 }
807}