1use crate::{Derivative, DualNum, DualNumFloat, DualStruct};
2use approx::{AbsDiffEq, RelativeEq, UlpsEq};
3use nalgebra::allocator::Allocator;
4use nalgebra::*;
5use num_traits::{Float, FloatConst, FromPrimitive, Inv, Num, One, Signed, Zero};
6use std::fmt;
7use std::iter::{Product, Sum};
8use std::marker::PhantomData;
9use std::ops::{
10 Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
11};
12
/// A dual number carrying first derivatives with respect to several variables.
///
/// `re` is the function value and `eps` the derivative part, stored as a
/// `Derivative<T, F, D, U1>` (one entry per variable, `D` of them).
#[derive(Clone, Debug)]
pub struct DualVec<T: DualNum<F>, F, D: Dim>
where
    DefaultAllocator: Allocator<D>,
{
    /// Real part (function value).
    pub re: T,
    /// Derivative part with respect to the `D` variables.
    pub eps: Derivative<T, F, D, U1>,
    // Marker for the inner float type `F`, which is otherwise not stored.
    f: PhantomData<F>,
}
25
// Marks `DualVec` as an `ndarray` scalar operand so it can be combined with
// arrays in scalar arithmetic. Only compiled with the `ndarray` feature.
#[cfg(feature = "ndarray")]
impl<T: DualNum<F>, F: DualNumFloat, D: Dim> ndarray::ScalarOperand for DualVec<T, F, D> where
    DefaultAllocator: Allocator<D>
{
}
31
// Statically sized dual vectors (`Const<N>`) are `Copy` whenever their scalar
// types are; no `Copy` impl is provided here for other dimensions (e.g. `Dyn`).
impl<T: DualNum<F> + Copy, F: Copy, const N: usize> Copy for DualVec<T, F, Const<N>> {}
33
/// Dual vector with a compile-time number `N` of variables.
pub type DualSVec<D, F, const N: usize> = DualVec<D, F, Const<N>>;
/// Dual vector whose number of variables is chosen at runtime.
pub type DualDVec<D, F> = DualVec<D, F, Dyn>;
/// `f32` dual vector with dimension `D`.
pub type DualVec32<D> = DualVec<f32, f32, D>;
/// `f64` dual vector with dimension `D`.
pub type DualVec64<D> = DualVec<f64, f64, D>;
/// `f32` dual vector with `N` variables known at compile time.
pub type DualSVec32<const N: usize> = DualVec<f32, f32, Const<N>>;
/// `f64` dual vector with `N` variables known at compile time.
pub type DualSVec64<const N: usize> = DualVec<f64, f64, Const<N>>;
/// `f32` dual vector with a runtime number of variables.
pub type DualDVec32 = DualVec<f32, f32, Dyn>;
/// `f64` dual vector with a runtime number of variables.
pub type DualDVec64 = DualVec<f64, f64, Dyn>;
42
43impl<T: DualNum<F>, F, D: Dim> DualVec<T, F, D>
44where
45 DefaultAllocator: Allocator<D>,
46{
47 #[inline]
49 pub fn new(re: T, eps: Derivative<T, F, D, U1>) -> Self {
50 Self {
51 re,
52 eps,
53 f: PhantomData,
54 }
55 }
56}
57
58impl<T: DualNum<F>, F, const N: usize> DualSVec<T, F, N> {
59 #[inline]
74 pub fn derivative(mut self, index: usize) -> Self {
75 self.eps = Derivative::derivative_generic(Const::<N>, U1, index);
76 self
77 }
78}
79
80impl<T: DualNum<F>, F> DualDVec<T, F> {
81 #[inline]
96 pub fn derivative(mut self, variables: usize, index: usize) -> Self {
97 self.eps = Derivative::derivative_generic(Dyn(variables), U1, index);
98 self
99 }
100}
101
102impl<T: DualNum<F> + Zero, F, D: Dim> DualVec<T, F, D>
103where
104 DefaultAllocator: Allocator<D>,
105{
106 #[inline]
108 pub fn from_re(re: T) -> Self {
109 Self::new(re, Derivative::none())
110 }
111}
112
113impl<T: DualNum<F>, F: Float, D: Dim> DualVec<T, F, D>
115where
116 DefaultAllocator: Allocator<D>,
117{
118 #[inline]
119 fn chain_rule(&self, f0: T, f1: T) -> Self {
120 Self::new(f0, &self.eps * f1)
121 }
122}
123
124impl<T: DualNum<F>, F: Float, D: Dim> Mul<&DualVec<T, F, D>> for &DualVec<T, F, D>
126where
127 DefaultAllocator: Allocator<D>,
128{
129 type Output = DualVec<T, F, D>;
130 #[inline]
131 fn mul(self, other: &DualVec<T, F, D>) -> Self::Output {
132 DualVec::new(
133 self.re.clone() * other.re.clone(),
134 &self.eps * other.re.clone() + &other.eps * self.re.clone(),
135 )
136 }
137}
138
139impl<T: DualNum<F>, F: Float, D: Dim> Div<&DualVec<T, F, D>> for &DualVec<T, F, D>
141where
142 DefaultAllocator: Allocator<D>,
143{
144 type Output = DualVec<T, F, D>;
145 #[inline]
146 fn div(self, other: &DualVec<T, F, D>) -> DualVec<T, F, D> {
147 let inv = other.re.recip();
148 DualVec::new(
149 self.re.clone() * inv.clone(),
150 (&self.eps * other.re.clone() - &other.eps * self.re.clone()) * inv.clone() * inv,
151 )
152 }
153}
154
155impl<T: DualNum<F>, F, D: Dim> fmt::Display for DualVec<T, F, D>
157where
158 DefaultAllocator: Allocator<D>,
159{
160 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
161 write!(f, "{}", self.re)?;
162 self.eps.fmt(f, "ε")
163 }
164}
165
// Macro-generated impls (the macros are defined elsewhere in the crate).
// NOTE(review): presumably `impl_first_derivatives!` provides the elementary
// functions built on `chain_rule`, and `impl_dual!` the remaining operator and
// dual-number trait impls; confirm against the macro definitions.
impl_first_derivatives!(DualVec, [eps], [D], [D]);
impl_dual!(DualVec, [eps], [D], [D]);
168
169impl<T, D: Dim> nalgebra::SimdValue for DualVec<T, T::Element, D>
188where
189 DefaultAllocator: Allocator<D>,
190 T: DualNum<T::Element> + SimdValue + Scalar,
191 T::Element: DualNum<T::Element> + Scalar,
192{
193 type Element = DualVec<T::Element, T::Element, D>;
202 type SimdBool = T::SimdBool;
203
204 const LANES: usize = T::LANES;
205
206 #[inline]
207 fn splat(val: Self::Element) -> Self {
208 let re = T::splat(val.re);
212 let eps = Derivative::splat(val.eps);
213 Self::new(re, eps)
214 }
215
216 #[inline]
217 fn extract(&self, i: usize) -> Self::Element {
218 let re = self.re.extract(i);
219 let eps = self.eps.extract(i);
220 Self::Element {
221 re,
222 eps,
223 f: PhantomData,
224 }
225 }
226
227 #[inline]
228 unsafe fn extract_unchecked(&self, i: usize) -> Self::Element {
229 let re = unsafe { self.re.extract_unchecked(i) };
230 let eps = unsafe { self.eps.extract_unchecked(i) };
231 Self::Element {
232 re,
233 eps,
234 f: PhantomData,
235 }
236 }
237
238 #[inline]
239 fn replace(&mut self, i: usize, val: Self::Element) {
240 self.re.replace(i, val.re);
241 self.eps.replace(i, val.eps);
242 }
243
244 #[inline]
245 unsafe fn replace_unchecked(&mut self, i: usize, val: Self::Element) {
246 unsafe { self.re.replace_unchecked(i, val.re) };
247 unsafe { self.eps.replace_unchecked(i, val.eps) };
248 }
249
250 #[inline]
251 fn select(self, cond: Self::SimdBool, other: Self) -> Self {
252 let re = self.re.select(cond, other.re);
253 let eps = self.eps.select(cond, other.eps);
254 Self::new(re, eps)
255 }
256}
257
258impl<T: DualNum<F> + PartialEq, F: Float, D: Dim> PartialEq for DualVec<T, F, D>
261where
262 DefaultAllocator: Allocator<D>,
263{
264 #[inline]
265 fn eq(&self, other: &Self) -> bool {
266 self.re.eq(&other.re)
267 }
268}
269impl<T: DualNum<F> + PartialOrd, F: Float, D: Dim> PartialOrd for DualVec<T, F, D>
272where
273 DefaultAllocator: Allocator<D>,
274{
275 #[inline]
276 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
277 self.re.partial_cmp(&other.re)
278 }
279}
280impl<T: DualNum<F> + approx::AbsDiffEq<Epsilon = T>, F: Float, D: Dim> approx::AbsDiffEq
283 for DualVec<T, F, D>
284where
285 DefaultAllocator: Allocator<D>,
286{
287 type Epsilon = Self;
288 #[inline]
289 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
290 self.re.abs_diff_eq(&other.re, epsilon.re)
291 }
292
293 #[inline]
294 fn default_epsilon() -> Self::Epsilon {
295 Self::from_re(T::default_epsilon())
296 }
297}
298impl<T: DualNum<F> + approx::RelativeEq<Epsilon = T>, F: Float, D: Dim> approx::RelativeEq
301 for DualVec<T, F, D>
302where
303 DefaultAllocator: Allocator<D>,
304{
305 #[inline]
306 fn default_max_relative() -> Self::Epsilon {
307 Self::from_re(T::default_max_relative())
308 }
309
310 #[inline]
311 fn relative_eq(
312 &self,
313 other: &Self,
314 epsilon: Self::Epsilon,
315 max_relative: Self::Epsilon,
316 ) -> bool {
317 self.re.relative_eq(&other.re, epsilon.re, max_relative.re)
318 }
319}
320impl<T: DualNum<F> + UlpsEq<Epsilon = T>, F: Float, D: Dim> UlpsEq for DualVec<T, F, D>
321where
322 DefaultAllocator: Allocator<D>,
323{
324 #[inline]
325 fn default_max_ulps() -> u32 {
326 T::default_max_ulps()
327 }
328
329 #[inline]
330 fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool {
331 T::ulps_eq(&self.re, &other.re, epsilon.re, max_ulps)
332 }
333}
334
// Marker impl: nalgebra's `Field` trait has no methods to provide here. It
// relies on the arithmetic and `SimdValue` impls defined for `DualVec`.
impl<T, D: Dim> nalgebra::Field for DualVec<T, T::Element, D>
where
    T: DualNum<T::Element> + SimdValue,
    T::Element: DualNum<T::Element> + Scalar + Float,
    DefaultAllocator: Allocator<D> + Allocator<U1, D> + Allocator<D, U1> + Allocator<D, D>,
{
}
342
343use simba::scalar::{SubsetOf, SupersetOf};
344
345impl<TSuper, FSuper, T, F, D: Dim> SubsetOf<DualVec<TSuper, FSuper, D>> for DualVec<T, F, D>
346where
347 TSuper: DualNum<FSuper> + SupersetOf<T>,
348 T: DualNum<F>,
349 DefaultAllocator: Allocator<D> + Allocator<U1, D> + Allocator<D, U1> + Allocator<D, D>,
350{
351 #[inline(always)]
352 fn to_superset(&self) -> DualVec<TSuper, FSuper, D> {
353 let re = TSuper::from_subset(&self.re);
354 let eps = Derivative::from_subset(&self.eps);
355 DualVec {
356 re,
357 eps,
358 f: PhantomData,
359 }
360 }
361 #[inline(always)]
362 fn from_superset(element: &DualVec<TSuper, FSuper, D>) -> Option<Self> {
363 let re = TSuper::to_subset(&element.re)?;
364 let eps = Derivative::to_subset(&element.eps)?;
365 Some(Self::new(re, eps))
366 }
367 #[inline(always)]
368 fn from_superset_unchecked(element: &DualVec<TSuper, FSuper, D>) -> Self {
369 let re = TSuper::to_subset_unchecked(&element.re);
370 let eps = Derivative::to_subset_unchecked(&element.eps);
371 Self::new(re, eps)
372 }
373 #[inline(always)]
374 fn is_in_subset(element: &DualVec<TSuper, FSuper, D>) -> bool {
375 TSuper::is_in_subset(&element.re)
376 && <Derivative<_, _, _, _> as SupersetOf<Derivative<_, _, _, _>>>::is_in_subset(
377 &element.eps,
378 )
379 }
380}
381
382impl<TSuper, FSuper, D: Dim> SupersetOf<f32> for DualVec<TSuper, FSuper, D>
383where
384 TSuper: DualNum<FSuper> + SupersetOf<f32>,
385 DefaultAllocator: Allocator<D> + Allocator<U1, D> + Allocator<D, U1> + Allocator<D, D>,
386{
387 #[inline(always)]
388 fn is_in_subset(&self) -> bool {
389 self.re.is_in_subset()
390 }
391
392 #[inline(always)]
393 fn to_subset_unchecked(&self) -> f32 {
394 self.re.to_subset_unchecked()
395 }
396
397 #[inline(always)]
398 fn from_subset(element: &f32) -> Self {
399 let re = TSuper::from_subset(element);
401 let eps = Derivative::none();
402 Self::new(re, eps)
403 }
404}
405
406impl<TSuper, FSuper, D: Dim> SupersetOf<f64> for DualVec<TSuper, FSuper, D>
407where
408 TSuper: DualNum<FSuper> + SupersetOf<f64>,
409 DefaultAllocator: Allocator<D> + Allocator<U1, D> + Allocator<D, U1> + Allocator<D, D>,
410{
411 #[inline(always)]
412 fn is_in_subset(&self) -> bool {
413 self.re.is_in_subset()
414 }
415
416 #[inline(always)]
417 fn to_subset_unchecked(&self) -> f64 {
418 self.re.to_subset_unchecked()
419 }
420
421 #[inline(always)]
422 fn from_subset(element: &f64) -> Self {
423 let re = TSuper::from_subset(element);
425 let eps = Derivative::none();
426 Self::new(re, eps)
427 }
428}
429
430use nalgebra::{ComplexField, RealField};
438impl<T, D: Dim> ComplexField for DualVec<T, T::Element, D>
440where
441 T: DualNum<T::Element> + SupersetOf<T> + AbsDiffEq<Epsilon = T> + Sync + Send,
442 T::Element: DualNum<T::Element> + Scalar + DualNumFloat + Sync + Send,
443 T: SupersetOf<T::Element>,
444 T: SupersetOf<f32>,
445 T: SupersetOf<f64>,
446 T: SimdPartialOrd + PartialOrd,
447 T: SimdValue<Element = T, SimdBool = bool>,
448 T: RelativeEq + UlpsEq + AbsDiffEq,
449 DefaultAllocator: Allocator<D> + Allocator<U1, D> + Allocator<D, U1> + Allocator<D, D>,
450 <DefaultAllocator as Allocator<D>>::Buffer<T>: Sync + Send,
451{
452 type RealField = Self;
453
454 #[inline]
455 fn from_real(re: Self::RealField) -> Self {
456 re
457 }
458
459 #[inline]
460 fn real(self) -> Self::RealField {
461 self
462 }
463
464 #[inline]
465 fn imaginary(self) -> Self::RealField {
466 Self::zero()
467 }
468
469 #[inline]
470 fn modulus(self) -> Self::RealField {
471 self.abs()
472 }
473
474 #[inline]
475 fn modulus_squared(self) -> Self::RealField {
476 &self * &self
477 }
478
479 #[inline]
480 fn argument(self) -> Self::RealField {
481 Self::zero()
482 }
483
484 #[inline]
485 fn norm1(self) -> Self::RealField {
486 self.abs()
487 }
488
489 #[inline]
490 fn scale(self, factor: Self::RealField) -> Self {
491 self * factor
492 }
493
494 #[inline]
495 fn unscale(self, factor: Self::RealField) -> Self {
496 self / factor
497 }
498
499 #[inline]
500 fn floor(self) -> Self {
501 panic!("called floor() on a dual number")
502 }
503
504 #[inline]
505 fn ceil(self) -> Self {
506 panic!("called ceil() on a dual number")
507 }
508
509 #[inline]
510 fn round(self) -> Self {
511 panic!("called round() on a dual number")
512 }
513
514 #[inline]
515 fn trunc(self) -> Self {
516 panic!("called trunc() on a dual number")
517 }
518
519 #[inline]
520 fn fract(self) -> Self {
521 panic!("called fract() on a dual number")
522 }
523
524 #[inline]
525 fn mul_add(self, a: Self, b: Self) -> Self {
526 DualNum::mul_add(&self, a, b)
527 }
528
529 #[inline]
530 fn abs(self) -> Self::RealField {
531 Signed::abs(&self)
532 }
533
534 #[inline]
535 fn hypot(self, other: Self) -> Self::RealField {
536 let sum_sq = self.powi(2) + other.powi(2);
537 DualNum::sqrt(&sum_sq)
538 }
539
540 #[inline]
541 fn recip(self) -> Self {
542 DualNum::recip(&self)
543 }
544
545 #[inline]
546 fn conjugate(self) -> Self {
547 self
548 }
549
550 #[inline]
551 fn sin(self) -> Self {
552 DualNum::sin(&self)
553 }
554
555 #[inline]
556 fn cos(self) -> Self {
557 DualNum::cos(&self)
558 }
559
560 #[inline]
561 fn sin_cos(self) -> (Self, Self) {
562 DualNum::sin_cos(&self)
563 }
564
565 #[inline]
566 fn tan(self) -> Self {
567 DualNum::tan(&self)
568 }
569
570 #[inline]
571 fn asin(self) -> Self {
572 DualNum::asin(&self)
573 }
574
575 #[inline]
576 fn acos(self) -> Self {
577 DualNum::acos(&self)
578 }
579
580 #[inline]
581 fn atan(self) -> Self {
582 DualNum::atan(&self)
583 }
584
585 #[inline]
586 fn sinh(self) -> Self {
587 DualNum::sinh(&self)
588 }
589
590 #[inline]
591 fn cosh(self) -> Self {
592 DualNum::cosh(&self)
593 }
594
595 #[inline]
596 fn tanh(self) -> Self {
597 DualNum::tanh(&self)
598 }
599
600 #[inline]
601 fn asinh(self) -> Self {
602 DualNum::asinh(&self)
603 }
604
605 #[inline]
606 fn acosh(self) -> Self {
607 DualNum::acosh(&self)
608 }
609
610 #[inline]
611 fn atanh(self) -> Self {
612 DualNum::atanh(&self)
613 }
614
615 #[inline]
616 fn log(self, base: Self::RealField) -> Self {
617 DualNum::ln(&self) / DualNum::ln(&base)
618 }
619
620 #[inline]
621 fn log2(self) -> Self {
622 DualNum::log2(&self)
623 }
624
625 #[inline]
626 fn log10(self) -> Self {
627 DualNum::log10(&self)
628 }
629
630 #[inline]
631 fn ln(self) -> Self {
632 DualNum::ln(&self)
633 }
634
635 #[inline]
636 fn ln_1p(self) -> Self {
637 DualNum::ln_1p(&self)
638 }
639
640 #[inline]
641 fn sqrt(self) -> Self {
642 DualNum::sqrt(&self)
643 }
644
645 #[inline]
646 fn exp(self) -> Self {
647 DualNum::exp(&self)
648 }
649
650 #[inline]
651 fn exp2(self) -> Self {
652 DualNum::exp2(&self)
653 }
654
655 #[inline]
656 fn exp_m1(self) -> Self {
657 DualNum::exp_m1(&self)
658 }
659
660 #[inline]
661 fn powi(self, n: i32) -> Self {
662 DualNum::powi(&self, n)
663 }
664
665 #[inline]
666 fn powf(self, n: Self::RealField) -> Self {
667 DualNum::powd(&self, n)
669 }
670
671 #[inline]
672 fn powc(self, n: Self) -> Self {
673 self.powf(n)
675 }
676
677 #[inline]
678 fn cbrt(self) -> Self {
679 DualNum::cbrt(&self)
680 }
681
682 #[inline]
683 fn is_finite(&self) -> bool {
684 self.re.is_finite()
685 }
686
687 #[inline]
688 fn try_sqrt(self) -> Option<Self> {
689 if self > Self::zero() {
690 Some(DualNum::sqrt(&self))
691 } else {
692 None
693 }
694 }
695}
696
697impl<T, D: Dim> RealField for DualVec<T, T::Element, D>
698where
699 T: DualNum<T::Element> + SupersetOf<T> + Sync + Send,
700 T::Element: DualNum<T::Element> + Scalar + DualNumFloat,
701 T: SupersetOf<T::Element>,
702 T: SupersetOf<f32>,
703 T: SupersetOf<f64>,
704 T: SimdPartialOrd + PartialOrd,
705 T: RelativeEq + AbsDiffEq<Epsilon = T>,
706 T: SimdValue<Element = T, SimdBool = bool>,
707 T: UlpsEq,
708 T: AbsDiffEq,
709 DefaultAllocator: Allocator<D> + Allocator<U1, D> + Allocator<D, U1> + Allocator<D, D>,
710 <DefaultAllocator as Allocator<D>>::Buffer<T>: Sync + Send,
711{
712 #[inline]
713 fn copysign(self, sign: Self) -> Self {
714 if sign.re.is_sign_positive() {
715 self.simd_abs()
716 } else {
717 -self.simd_abs()
718 }
719 }
720
721 #[inline]
722 fn atan2(self, other: Self) -> Self {
723 DualNum::atan2(&self, other)
724 }
725
726 #[inline]
727 fn pi() -> Self {
728 Self::from_re(<T as FloatConst>::PI())
729 }
730
731 #[inline]
732 fn two_pi() -> Self {
733 Self::from_re(<T as FloatConst>::TAU())
734 }
735
736 #[inline]
737 fn frac_pi_2() -> Self {
738 Self::from_re(<T as FloatConst>::FRAC_PI_4())
739 }
740
741 #[inline]
742 fn frac_pi_3() -> Self {
743 Self::from_re(<T as FloatConst>::FRAC_PI_3())
744 }
745
746 #[inline]
747 fn frac_pi_4() -> Self {
748 Self::from_re(<T as FloatConst>::FRAC_PI_4())
749 }
750
751 #[inline]
752 fn frac_pi_6() -> Self {
753 Self::from_re(<T as FloatConst>::FRAC_PI_6())
754 }
755
756 #[inline]
757 fn frac_pi_8() -> Self {
758 Self::from_re(<T as FloatConst>::FRAC_PI_8())
759 }
760
761 #[inline]
762 fn frac_1_pi() -> Self {
763 Self::from_re(<T as FloatConst>::FRAC_1_PI())
764 }
765
766 #[inline]
767 fn frac_2_pi() -> Self {
768 Self::from_re(<T as FloatConst>::FRAC_2_PI())
769 }
770
771 #[inline]
772 fn frac_2_sqrt_pi() -> Self {
773 Self::from_re(<T as FloatConst>::FRAC_2_SQRT_PI())
774 }
775
776 #[inline]
777 fn e() -> Self {
778 Self::from_re(<T as FloatConst>::E())
779 }
780
781 #[inline]
782 fn log2_e() -> Self {
783 Self::from_re(<T as FloatConst>::LOG2_E())
784 }
785
786 #[inline]
787 fn log10_e() -> Self {
788 Self::from_re(<T as FloatConst>::LOG10_E())
789 }
790
791 #[inline]
792 fn ln_2() -> Self {
793 Self::from_re(<T as FloatConst>::LN_2())
794 }
795
796 #[inline]
797 fn ln_10() -> Self {
798 Self::from_re(<T as FloatConst>::LN_10())
799 }
800
801 #[inline]
802 fn is_sign_positive(&self) -> bool {
803 self.re.is_sign_positive()
804 }
805
806 #[inline]
807 fn is_sign_negative(&self) -> bool {
808 self.re.is_sign_negative()
809 }
810
811 #[inline]
813 fn max(self, other: Self) -> Self {
814 if other > self { other } else { self }
815 }
816
817 #[inline]
819 fn min(self, other: Self) -> Self {
820 if other < self { other } else { self }
821 }
822
823 #[inline]
825 fn clamp(self, min: Self, max: Self) -> Self {
826 if self < min {
827 min
828 } else if self > max {
829 max
830 } else {
831 self
832 }
833 }
834
835 #[inline]
836 fn min_value() -> Option<Self> {
837 Some(Self::from_re(T::min_value()))
838 }
839
840 #[inline]
841 fn max_value() -> Option<Self> {
842 Some(Self::from_re(T::max_value()))
843 }
844}