1use crate::{DualNum, DualNumFloat, DualStruct};
2use approx::{AbsDiffEq, RelativeEq, UlpsEq};
3use nalgebra::*;
4use num_traits::{Float, FloatConst, FromPrimitive, Inv, Num, One, Signed, Zero};
5#[cfg(feature = "serde")]
6use serde::{Deserialize, Serialize};
7use std::fmt;
8use std::iter::{Product, Sum};
9use std::marker::PhantomData;
10use std::ops::{
11 Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
12};
13
/// A second-order dual number: a value together with its first and second
/// derivative with respect to a single variable.
///
/// `F` is the underlying floating-point type; it is only tracked at the type level.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Dual2<T: DualNum<F>, F> {
    /// Real part (the function value).
    pub re: T,
    /// First derivative part.
    pub v1: T,
    /// Second derivative part.
    pub v2: T,
    // Type-level marker for `F`; carries no data and is skipped by serde.
    #[cfg_attr(feature = "serde", serde(skip))]
    f: PhantomData<F>,
}
27
// Marks `Dual2` as a scalar type that `ndarray` accepts as a direct operand in
// array/scalar arithmetic (only compiled with the `ndarray` feature).
#[cfg(feature = "ndarray")]
impl<T: DualNum<F>, F: DualNumFloat> ndarray::ScalarOperand for Dual2<T, F> {}
30
/// Second-order dual number whose value and derivative parts are `f32`.
pub type Dual2_32 = Dual2<f32, f32>;
/// Second-order dual number whose value and derivative parts are `f64`.
pub type Dual2_64 = Dual2<f64, f64>;
33
34impl<T: DualNum<F>, F> Dual2<T, F> {
35 #[inline]
37 pub fn new(re: T, v1: T, v2: T) -> Self {
38 Self {
39 re,
40 v1,
41 v2,
42 f: PhantomData,
43 }
44 }
45}
46
47impl<T: DualNum<F>, F> Dual2<T, F> {
48 #[inline]
70 pub fn derivative(mut self) -> Self {
71 self.v1 = T::one();
72 self
73 }
74}
75
76impl<T: DualNum<F>, F> Dual2<T, F> {
77 #[inline]
79 pub fn from_re(re: T) -> Self {
80 Self::new(re, T::zero(), T::zero())
81 }
82}
83
84impl<T: DualNum<F>, F: Float> Dual2<T, F> {
86 #[inline]
87 fn chain_rule(&self, f0: T, f1: T, f2: T) -> Self {
88 Self::new(
89 f0,
90 self.v1.clone() * f1.clone(),
91 self.v2.clone() * f1 + self.v1.clone() * self.v1.clone() * f2,
92 )
93 }
94}
95
96impl<T: DualNum<F>, F: Float> Mul<&Dual2<T, F>> for &Dual2<T, F> {
98 type Output = Dual2<T, F>;
99 #[inline]
100 fn mul(self, other: &Dual2<T, F>) -> Dual2<T, F> {
101 Dual2::new(
102 self.re.clone() * other.re.clone(),
103 other.v1.clone() * self.re.clone() + self.v1.clone() * other.re.clone(),
104 other.v2.clone() * self.re.clone()
105 + self.v1.clone() * other.v1.clone()
106 + other.v1.clone() * self.v1.clone()
107 + self.v2.clone() * other.re.clone(),
108 )
109 }
110}
111
112impl<T: DualNum<F>, F: Float> Div<&Dual2<T, F>> for &Dual2<T, F> {
114 type Output = Dual2<T, F>;
115 #[inline]
116 fn div(self, other: &Dual2<T, F>) -> Dual2<T, F> {
117 let inv = other.re.recip();
118 let inv2 = inv.clone() * inv.clone();
119 Dual2::new(
120 self.re.clone() * inv.clone(),
121 (self.v1.clone() * other.re.clone() - other.v1.clone() * self.re.clone())
122 * inv2.clone(),
123 self.v2.clone() * inv.clone()
124 - (other.v2.clone() * self.re.clone()
125 + self.v1.clone() * other.v1.clone()
126 + other.v1.clone() * self.v1.clone())
127 * inv2.clone()
128 + other.v1.clone()
129 * other.v1.clone()
130 * ((T::one() + T::one()) * self.re.clone() * inv2 * inv),
131 )
132 }
133}
134
135impl<T: DualNum<F>, F: fmt::Display> fmt::Display for Dual2<T, F> {
137 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
138 write!(f, "{} + {}ε1 + {}ε1²", self.re, self.v1, self.v2)
139 }
140}
141
// Trait implementations generated by crate-level macros (defined elsewhere).
// NOTE(review): `impl_second_derivatives!` presumably implements the elementary
// functions on top of `chain_rule`, and `impl_dual!` the remaining arithmetic and
// trait boilerplate over the derivative fields `v1`, `v2` — confirm against the macros.
impl_second_derivatives!(Dual2, [v1, v2]);
impl_dual!(Dual2, [v1, v2]);
144
145impl<T> nalgebra::SimdValue for Dual2<T, T::Element>
164where
165 T: DualNum<T::Element> + SimdValue + Scalar,
166 T::Element: DualNum<T::Element> + Scalar,
167{
168 type Element = Dual2<T::Element, T::Element>;
177 type SimdBool = T::SimdBool;
178
179 const LANES: usize = T::LANES;
180
181 #[inline]
182 fn splat(val: Self::Element) -> Self {
183 let re = T::splat(val.re);
187 let v1 = T::splat(val.v1);
188 let v2 = T::splat(val.v2);
189 Self::new(re, v1, v2)
190 }
191
192 #[inline]
193 fn extract(&self, i: usize) -> Self::Element {
194 let re = self.re.extract(i);
195 let v1 = self.v1.extract(i);
196 let v2 = self.v2.extract(i);
197 Self::Element {
198 re,
199 v1,
200 v2,
201 f: PhantomData,
202 }
203 }
204
205 #[inline]
206 unsafe fn extract_unchecked(&self, i: usize) -> Self::Element {
207 let re = unsafe { self.re.extract_unchecked(i) };
208 let v1 = unsafe { self.v1.extract_unchecked(i) };
209 let v2 = unsafe { self.v2.extract_unchecked(i) };
210 Self::Element {
211 re,
212 v1,
213 v2,
214 f: PhantomData,
215 }
216 }
217
218 #[inline]
219 fn replace(&mut self, i: usize, val: Self::Element) {
220 self.re.replace(i, val.re);
221 self.v1.replace(i, val.v1);
222 self.v2.replace(i, val.v2);
223 }
224
225 #[inline]
226 unsafe fn replace_unchecked(&mut self, i: usize, val: Self::Element) {
227 unsafe { self.re.replace_unchecked(i, val.re) };
228 unsafe { self.v1.replace_unchecked(i, val.v1) };
229 unsafe { self.v2.replace_unchecked(i, val.v2) };
230 }
231
232 #[inline]
233 fn select(self, cond: Self::SimdBool, other: Self) -> Self {
234 let re = self.re.select(cond, other.re);
235 let v1 = self.v1.select(cond, other.v1);
236 let v2 = self.v2.select(cond, other.v2);
237 Self::new(re, v1, v2)
238 }
239}
240
241impl<T: DualNum<F> + approx::AbsDiffEq<Epsilon = T>, F: Float> approx::AbsDiffEq for Dual2<T, F> {
244 type Epsilon = Self;
245 #[inline]
246 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
247 self.re.abs_diff_eq(&other.re, epsilon.re)
248 }
249
250 #[inline]
251 fn default_epsilon() -> Self::Epsilon {
252 Self::from_re(T::default_epsilon())
253 }
254}
255impl<T: DualNum<F> + approx::RelativeEq<Epsilon = T>, F: Float> approx::RelativeEq for Dual2<T, F> {
258 #[inline]
259 fn default_max_relative() -> Self::Epsilon {
260 Self::from_re(T::default_max_relative())
261 }
262
263 #[inline]
264 fn relative_eq(
265 &self,
266 other: &Self,
267 epsilon: Self::Epsilon,
268 max_relative: Self::Epsilon,
269 ) -> bool {
270 self.re.relative_eq(&other.re, epsilon.re, max_relative.re)
271 }
272}
273impl<T: DualNum<F> + UlpsEq<Epsilon = T>, F: Float> UlpsEq for Dual2<T, F> {
274 #[inline]
275 fn default_max_ulps() -> u32 {
276 T::default_max_ulps()
277 }
278
279 #[inline]
280 fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool {
281 T::ulps_eq(&self.re, &other.re, epsilon.re, max_ulps)
282 }
283}
284
// Empty impl: no items are defined here. It only declares `Dual2` a nalgebra `Field`,
// relying on the supertrait impls (arithmetic, `SimdValue`, …) provided elsewhere.
impl<T> nalgebra::Field for Dual2<T, T::Element>
where
    T: DualNum<T::Element> + SimdValue,
    T::Element: DualNum<T::Element> + Scalar + Float,
{
}
291
292use simba::scalar::{SubsetOf, SupersetOf};
293
294impl<TSuper, FSuper, T, F> SubsetOf<Dual2<TSuper, FSuper>> for Dual2<T, F>
295where
296 TSuper: DualNum<FSuper> + SupersetOf<T>,
297 T: DualNum<F>,
298{
299 #[inline(always)]
300 fn to_superset(&self) -> Dual2<TSuper, FSuper> {
301 let re = TSuper::from_subset(&self.re);
302 let v1 = TSuper::from_subset(&self.v1);
303 let v2 = TSuper::from_subset(&self.v2);
304 Dual2 {
305 re,
306 v1,
307 v2,
308 f: PhantomData,
309 }
310 }
311 #[inline(always)]
312 fn from_superset(element: &Dual2<TSuper, FSuper>) -> Option<Self> {
313 let re = TSuper::to_subset(&element.re)?;
314 let v1 = TSuper::to_subset(&element.v1)?;
315 let v2 = TSuper::to_subset(&element.v2)?;
316 Some(Self::new(re, v1, v2))
317 }
318 #[inline(always)]
319 fn from_superset_unchecked(element: &Dual2<TSuper, FSuper>) -> Self {
320 let re = TSuper::to_subset_unchecked(&element.re);
321 let v1 = TSuper::to_subset_unchecked(&element.v1);
322 let v2 = TSuper::to_subset_unchecked(&element.v2);
323 Self::new(re, v1, v2)
324 }
325 #[inline(always)]
326 fn is_in_subset(element: &Dual2<TSuper, FSuper>) -> bool {
327 TSuper::is_in_subset(&element.re)
328 && TSuper::is_in_subset(&element.v1)
329 && TSuper::is_in_subset(&element.v2)
330 }
331}
332
333impl<TSuper, FSuper> SupersetOf<f32> for Dual2<TSuper, FSuper>
334where
335 TSuper: DualNum<FSuper> + SupersetOf<f32>,
336{
337 #[inline(always)]
338 fn is_in_subset(&self) -> bool {
339 self.re.is_in_subset()
340 }
341
342 #[inline(always)]
343 fn to_subset_unchecked(&self) -> f32 {
344 self.re.to_subset_unchecked()
345 }
346
347 #[inline(always)]
348 fn from_subset(element: &f32) -> Self {
349 let re = TSuper::from_subset(element);
351 let v1 = TSuper::zero();
352 let v2 = TSuper::zero();
353 Self::new(re, v1, v2)
354 }
355}
356
357impl<TSuper, FSuper> SupersetOf<f64> for Dual2<TSuper, FSuper>
358where
359 TSuper: DualNum<FSuper> + SupersetOf<f64>,
360{
361 #[inline(always)]
362 fn is_in_subset(&self) -> bool {
363 self.re.is_in_subset()
364 }
365
366 #[inline(always)]
367 fn to_subset_unchecked(&self) -> f64 {
368 self.re.to_subset_unchecked()
369 }
370
371 #[inline(always)]
372 fn from_subset(element: &f64) -> Self {
373 let re = TSuper::from_subset(element);
375 let v1 = TSuper::zero();
376 let v2 = TSuper::zero();
377 Self::new(re, v1, v2)
378 }
379}
380
381use nalgebra::{ComplexField, RealField};
/// `ComplexField` implementation that treats `Dual2` as a purely real scalar:
/// `RealField = Self` and the imaginary part is always zero.
///
/// The rounding functions (`floor`, `ceil`, `round`, `trunc`, `fract`) are not
/// supported and panic. Most other methods forward to the `DualNum` implementation.
/// The fully-qualified `DualNum::…(&self)` form selects that implementation
/// explicitly. A plain `self.sin()` inside `fn sin(self)` would resolve to this
/// by-value `ComplexField` method and recurse.
impl<T> ComplexField for Dual2<T, T::Element>
where
    T: DualNum<T::Element> + SupersetOf<T> + AbsDiffEq<Epsilon = T> + Sync + Send,
    T::Element: DualNum<T::Element> + Scalar + DualNumFloat + Sync + Send,
    T: SupersetOf<T::Element>,
    T: SupersetOf<f32>,
    T: SupersetOf<f64>,
    T: SimdPartialOrd + PartialOrd,
    T: SimdValue<Element = T, SimdBool = bool>,
    T: RelativeEq + UlpsEq + AbsDiffEq,
{
    type RealField = Self;

    // Real scalar: the "real part" is the number itself.
    #[inline]
    fn from_real(re: Self::RealField) -> Self {
        re
    }

    #[inline]
    fn real(self) -> Self::RealField {
        self
    }

    // Real scalar: no imaginary component.
    #[inline]
    fn imaginary(self) -> Self::RealField {
        Self::zero()
    }

    #[inline]
    fn modulus(self) -> Self::RealField {
        self.abs()
    }

    #[inline]
    fn modulus_squared(self) -> Self::RealField {
        self * self
    }

    // NOTE(review): returns zero for every input, including negative values, whose
    // complex argument would be π — confirm this is intended.
    #[inline]
    fn argument(self) -> Self::RealField {
        Self::zero()
    }

    #[inline]
    fn norm1(self) -> Self::RealField {
        self.abs()
    }

    #[inline]
    fn scale(self, factor: Self::RealField) -> Self {
        self * factor
    }

    #[inline]
    fn unscale(self, factor: Self::RealField) -> Self {
        self / factor
    }

    // Rounding is not supported for dual numbers; these always panic.
    #[inline]
    fn floor(self) -> Self {
        panic!("called floor() on a dual number")
    }

    #[inline]
    fn ceil(self) -> Self {
        panic!("called ceil() on a dual number")
    }

    #[inline]
    fn round(self) -> Self {
        panic!("called round() on a dual number")
    }

    #[inline]
    fn trunc(self) -> Self {
        panic!("called trunc() on a dual number")
    }

    #[inline]
    fn fract(self) -> Self {
        panic!("called fract() on a dual number")
    }

    #[inline]
    fn mul_add(self, a: Self, b: Self) -> Self {
        DualNum::mul_add(&self, a, b)
    }

    // Delegates to `Signed::abs` (implemented elsewhere for `Dual2`).
    #[inline]
    fn abs(self) -> Self::RealField {
        Signed::abs(&self)
    }

    // Computed directly as sqrt(self² + other²); `powi` here is `ComplexField::powi`.
    #[inline]
    fn hypot(self, other: Self) -> Self::RealField {
        let sum_sq = self.powi(2) + other.powi(2);
        DualNum::sqrt(&sum_sq)
    }

    #[inline]
    fn recip(self) -> Self {
        DualNum::recip(&self)
    }

    // Real scalar: conjugation is the identity.
    #[inline]
    fn conjugate(self) -> Self {
        self
    }

    #[inline]
    fn sin(self) -> Self {
        DualNum::sin(&self)
    }

    #[inline]
    fn cos(self) -> Self {
        DualNum::cos(&self)
    }

    #[inline]
    fn sin_cos(self) -> (Self, Self) {
        DualNum::sin_cos(&self)
    }

    #[inline]
    fn tan(self) -> Self {
        DualNum::tan(&self)
    }

    #[inline]
    fn asin(self) -> Self {
        DualNum::asin(&self)
    }

    #[inline]
    fn acos(self) -> Self {
        DualNum::acos(&self)
    }

    #[inline]
    fn atan(self) -> Self {
        DualNum::atan(&self)
    }

    #[inline]
    fn sinh(self) -> Self {
        DualNum::sinh(&self)
    }

    #[inline]
    fn cosh(self) -> Self {
        DualNum::cosh(&self)
    }

    #[inline]
    fn tanh(self) -> Self {
        DualNum::tanh(&self)
    }

    #[inline]
    fn asinh(self) -> Self {
        DualNum::asinh(&self)
    }

    #[inline]
    fn acosh(self) -> Self {
        DualNum::acosh(&self)
    }

    #[inline]
    fn atanh(self) -> Self {
        DualNum::atanh(&self)
    }

    // Change of base: ln(self) / ln(base); derivatives of `base` propagate too.
    #[inline]
    fn log(self, base: Self::RealField) -> Self {
        DualNum::ln(&self) / DualNum::ln(&base)
    }

    #[inline]
    fn log2(self) -> Self {
        DualNum::log2(&self)
    }

    #[inline]
    fn log10(self) -> Self {
        DualNum::log10(&self)
    }

    #[inline]
    fn ln(self) -> Self {
        DualNum::ln(&self)
    }

    #[inline]
    fn ln_1p(self) -> Self {
        DualNum::ln_1p(&self)
    }

    #[inline]
    fn sqrt(self) -> Self {
        DualNum::sqrt(&self)
    }

    #[inline]
    fn exp(self) -> Self {
        DualNum::exp(&self)
    }

    #[inline]
    fn exp2(self) -> Self {
        DualNum::exp2(&self)
    }

    #[inline]
    fn exp_m1(self) -> Self {
        DualNum::exp_m1(&self)
    }

    #[inline]
    fn powi(self, n: i32) -> Self {
        DualNum::powi(&self, n)
    }

    // Dual-valued exponent: delegates to `DualNum::powd`.
    #[inline]
    fn powf(self, n: Self::RealField) -> Self {
        DualNum::powd(&self, n)
    }

    // Real scalar: same as `powf`.
    #[inline]
    fn powc(self, n: Self) -> Self {
        self.powf(n)
    }

    #[inline]
    fn cbrt(self) -> Self {
        DualNum::cbrt(&self)
    }

    // Only the real part is checked; derivative parts may still be non-finite.
    #[inline]
    fn is_finite(&self) -> bool {
        self.re.is_finite()
    }

    // Succeeds only for strictly positive values. Zero yields `None` too, presumably
    // because the derivative of sqrt is unbounded at 0 — confirm.
    #[inline]
    fn try_sqrt(self) -> Option<Self> {
        if self > Self::zero() {
            Some(DualNum::sqrt(&self))
        } else {
            None
        }
    }
}
645
646impl<T> RealField for Dual2<T, T::Element>
647where
648 T: DualNum<T::Element> + SupersetOf<T> + Sync + Send,
649 T::Element: DualNum<T::Element> + Scalar + DualNumFloat,
650 T: SupersetOf<T::Element>,
651 T: SupersetOf<f32>,
652 T: SupersetOf<f64>,
653 T: SimdPartialOrd + PartialOrd,
654 T: RelativeEq + AbsDiffEq<Epsilon = T>,
655 T: SimdValue<Element = T, SimdBool = bool>,
656 T: UlpsEq,
657 T: AbsDiffEq,
658{
659 #[inline]
660 fn copysign(self, sign: Self) -> Self {
661 if sign.re.is_sign_positive() {
662 self.simd_abs()
663 } else {
664 -self.simd_abs()
665 }
666 }
667
668 #[inline]
669 fn atan2(self, other: Self) -> Self {
670 DualNum::atan2(&self, other)
671 }
672
673 #[inline]
674 fn pi() -> Self {
675 Self::from_re(<T as FloatConst>::PI())
676 }
677
678 #[inline]
679 fn two_pi() -> Self {
680 Self::from_re(<T as FloatConst>::TAU())
681 }
682
683 #[inline]
684 fn frac_pi_2() -> Self {
685 Self::from_re(<T as FloatConst>::FRAC_PI_4())
686 }
687
688 #[inline]
689 fn frac_pi_3() -> Self {
690 Self::from_re(<T as FloatConst>::FRAC_PI_3())
691 }
692
693 #[inline]
694 fn frac_pi_4() -> Self {
695 Self::from_re(<T as FloatConst>::FRAC_PI_4())
696 }
697
698 #[inline]
699 fn frac_pi_6() -> Self {
700 Self::from_re(<T as FloatConst>::FRAC_PI_6())
701 }
702
703 #[inline]
704 fn frac_pi_8() -> Self {
705 Self::from_re(<T as FloatConst>::FRAC_PI_8())
706 }
707
708 #[inline]
709 fn frac_1_pi() -> Self {
710 Self::from_re(<T as FloatConst>::FRAC_1_PI())
711 }
712
713 #[inline]
714 fn frac_2_pi() -> Self {
715 Self::from_re(<T as FloatConst>::FRAC_2_PI())
716 }
717
718 #[inline]
719 fn frac_2_sqrt_pi() -> Self {
720 Self::from_re(<T as FloatConst>::FRAC_2_SQRT_PI())
721 }
722
723 #[inline]
724 fn e() -> Self {
725 Self::from_re(<T as FloatConst>::E())
726 }
727
728 #[inline]
729 fn log2_e() -> Self {
730 Self::from_re(<T as FloatConst>::LOG2_E())
731 }
732
733 #[inline]
734 fn log10_e() -> Self {
735 Self::from_re(<T as FloatConst>::LOG10_E())
736 }
737
738 #[inline]
739 fn ln_2() -> Self {
740 Self::from_re(<T as FloatConst>::LN_2())
741 }
742
743 #[inline]
744 fn ln_10() -> Self {
745 Self::from_re(<T as FloatConst>::LN_10())
746 }
747
748 #[inline]
749 fn is_sign_positive(&self) -> bool {
750 self.re.is_sign_positive()
751 }
752
753 #[inline]
754 fn is_sign_negative(&self) -> bool {
755 self.re.is_sign_negative()
756 }
757
758 #[inline]
760 fn max(self, other: Self) -> Self {
761 if other > self { other } else { self }
762 }
763
764 #[inline]
766 fn min(self, other: Self) -> Self {
767 if other < self { other } else { self }
768 }
769
770 #[inline]
772 fn clamp(self, min: Self, max: Self) -> Self {
773 if self < min {
774 min
775 } else if self > max {
776 max
777 } else {
778 self
779 }
780 }
781
782 #[inline]
783 fn min_value() -> Option<Self> {
784 Some(Self::from_re(T::min_value()))
785 }
786
787 #[inline]
788 fn max_value() -> Option<Self> {
789 Some(Self::from_re(T::max_value()))
790 }
791}