1use crate::{Derivative, DualNum, DualNumFloat, DualStruct};
2use approx::{AbsDiffEq, RelativeEq, UlpsEq};
3use nalgebra::allocator::Allocator;
4use nalgebra::*;
5use num_traits::{Float, FloatConst, FromPrimitive, Inv, Num, One, Signed, Zero};
6use std::fmt;
7use std::iter::{Product, Sum};
8use std::marker::PhantomData;
9use std::ops::{
10 Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
11};
12
/// A dual number whose derivative part is a column of `D` entries, so several
/// first-order directional/partial derivatives are carried alongside the value.
///
/// * `T` – type of the real part and of each derivative entry.
/// * `F` – underlying floating-point type (only tracked via `PhantomData`).
/// * `D` – number of derivative directions (`Const<N>` or `Dyn`).
#[derive(Clone, Debug)]
pub struct DualVec<T: DualNum<F>, F, D: Dim>
where
    DefaultAllocator: Allocator<D>,
{
    /// Real (value) part.
    pub re: T,
    /// Derivative part, stored as a `D`×1 `Derivative`.
    pub eps: Derivative<T, F, D, U1>,
    /// Zero-sized marker that ties the otherwise unused parameter `F` to the type.
    f: PhantomData<F>,
}
25
// With the `ndarray` feature enabled, lets `DualVec` be used as the scalar
// operand in ndarray array-scalar arithmetic. This is a marker impl with no methods.
#[cfg(feature = "ndarray")]
impl<T: DualNum<F>, F: DualNumFloat, D: Dim> ndarray::ScalarOperand for DualVec<T, F, D> where
    DefaultAllocator: Allocator<D>
{
}
31
// `Copy` is only provided for compile-time sized derivative parts (`Const<N>`);
// the `Dyn` variant is left `Clone`-only. Presumably this is because its storage
// is heap-allocated; confirm against `Derivative`.
impl<T: DualNum<F> + Copy, F: Copy, const N: usize> Copy for DualVec<T, F, Const<N>> {}
33
/// `DualVec` with a compile-time number `N` of derivative directions.
pub type DualSVec<D, F, const N: usize> = DualVec<D, F, Const<N>>;
/// `DualVec` with a run-time (dynamic) number of derivative directions.
pub type DualDVec<D, F> = DualVec<D, F, Dyn>;
/// `DualVec` over plain `f32` values.
pub type DualVec32<D> = DualVec<f32, f32, D>;
/// `DualVec` over plain `f64` values.
pub type DualVec64<D> = DualVec<f64, f64, D>;
/// Statically sized `DualVec` over `f32`.
pub type DualSVec32<const N: usize> = DualVec<f32, f32, Const<N>>;
/// Statically sized `DualVec` over `f64`.
pub type DualSVec64<const N: usize> = DualVec<f64, f64, Const<N>>;
/// Dynamically sized `DualVec` over `f32`.
pub type DualDVec32 = DualVec<f32, f32, Dyn>;
/// Dynamically sized `DualVec` over `f64`.
pub type DualDVec64 = DualVec<f64, f64, Dyn>;
42
43impl<T: DualNum<F>, F, D: Dim> DualVec<T, F, D>
44where
45 DefaultAllocator: Allocator<D>,
46{
47 #[inline]
49 pub fn new(re: T, eps: Derivative<T, F, D, U1>) -> Self {
50 Self {
51 re,
52 eps,
53 f: PhantomData,
54 }
55 }
56}
57
58impl<T: DualNum<F>, F, const N: usize> DualSVec<T, F, N> {
59 #[inline]
74 pub fn derivative(mut self, index: usize) -> Self {
75 self.eps = Derivative::derivative_generic(Const::<N>, U1, index);
76 self
77 }
78}
79
80impl<T: DualNum<F>, F> DualDVec<T, F> {
81 #[inline]
96 pub fn derivative(mut self, variables: usize, index: usize) -> Self {
97 self.eps = Derivative::derivative_generic(Dyn(variables), U1, index);
98 self
99 }
100}
101
102impl<T: DualNum<F> + Zero, F, D: Dim> DualVec<T, F, D>
103where
104 DefaultAllocator: Allocator<D>,
105{
106 #[inline]
108 pub fn from_re(re: T) -> Self {
109 Self::new(re, Derivative::none())
110 }
111}
112
113impl<T: DualNum<F>, F: Float, D: Dim> DualVec<T, F, D>
115where
116 DefaultAllocator: Allocator<D>,
117{
118 #[inline]
119 fn chain_rule(&self, f0: T, f1: T) -> Self {
120 Self::new(f0, &self.eps * f1)
121 }
122}
123
124impl<T: DualNum<F>, F: Float, D: Dim> Mul<&DualVec<T, F, D>> for &DualVec<T, F, D>
126where
127 DefaultAllocator: Allocator<D>,
128{
129 type Output = DualVec<T, F, D>;
130 #[inline]
131 fn mul(self, other: &DualVec<T, F, D>) -> Self::Output {
132 DualVec::new(
133 self.re.clone() * other.re.clone(),
134 &self.eps * other.re.clone() + &other.eps * self.re.clone(),
135 )
136 }
137}
138
139impl<T: DualNum<F>, F: Float, D: Dim> Div<&DualVec<T, F, D>> for &DualVec<T, F, D>
141where
142 DefaultAllocator: Allocator<D>,
143{
144 type Output = DualVec<T, F, D>;
145 #[inline]
146 fn div(self, other: &DualVec<T, F, D>) -> DualVec<T, F, D> {
147 let inv = other.re.recip();
148 DualVec::new(
149 self.re.clone() * inv.clone(),
150 (&self.eps * other.re.clone() - &other.eps * self.re.clone()) * inv.clone() * inv,
151 )
152 }
153}
154
155impl<T: DualNum<F>, F, D: Dim> fmt::Display for DualVec<T, F, D>
157where
158 DefaultAllocator: Allocator<D>,
159{
160 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
161 write!(f, "{}", self.re)?;
162 self.eps.fmt(f, "ε")
163 }
164}
165
// Macro-generated impls shared with the crate's other dual-number types; the
// macros are defined elsewhere. Judging by their names (confirm in the macro
// definitions), `impl_first_derivatives!` supplies the elementary functions
// (likely via `chain_rule`) and `impl_dual!` supplies the arithmetic / num-traits
// plumbing. For `DualVec`, `eps` is the only derivative field that is listed.
impl_first_derivatives!(DualVec, [eps], [D], [D]);
impl_dual!(DualVec, [eps], [D], [D]);
168
169impl<T, D: Dim> nalgebra::SimdValue for DualVec<T, T::Element, D>
188where
189 DefaultAllocator: Allocator<D>,
190 T: DualNum<T::Element> + SimdValue + Scalar,
191 T::Element: DualNum<T::Element> + Scalar,
192{
193 type Element = DualVec<T::Element, T::Element, D>;
202 type SimdBool = T::SimdBool;
203
204 const LANES: usize = T::LANES;
205
206 #[inline]
207 fn splat(val: Self::Element) -> Self {
208 let re = T::splat(val.re);
212 let eps = Derivative::splat(val.eps);
213 Self::new(re, eps)
214 }
215
216 #[inline]
217 fn extract(&self, i: usize) -> Self::Element {
218 let re = self.re.extract(i);
219 let eps = self.eps.extract(i);
220 Self::Element {
221 re,
222 eps,
223 f: PhantomData,
224 }
225 }
226
227 #[inline]
228 unsafe fn extract_unchecked(&self, i: usize) -> Self::Element {
229 let re = unsafe { self.re.extract_unchecked(i) };
230 let eps = unsafe { self.eps.extract_unchecked(i) };
231 Self::Element {
232 re,
233 eps,
234 f: PhantomData,
235 }
236 }
237
238 #[inline]
239 fn replace(&mut self, i: usize, val: Self::Element) {
240 self.re.replace(i, val.re);
241 self.eps.replace(i, val.eps);
242 }
243
244 #[inline]
245 unsafe fn replace_unchecked(&mut self, i: usize, val: Self::Element) {
246 unsafe { self.re.replace_unchecked(i, val.re) };
247 unsafe { self.eps.replace_unchecked(i, val.eps) };
248 }
249
250 #[inline]
251 fn select(self, cond: Self::SimdBool, other: Self) -> Self {
252 let re = self.re.select(cond, other.re);
253 let eps = self.eps.select(cond, other.eps);
254 Self::new(re, eps)
255 }
256}
257
258impl<T: DualNum<F> + approx::AbsDiffEq<Epsilon = T>, F: Float, D: Dim> approx::AbsDiffEq
261 for DualVec<T, F, D>
262where
263 DefaultAllocator: Allocator<D>,
264{
265 type Epsilon = Self;
266 #[inline]
267 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
268 self.re.abs_diff_eq(&other.re, epsilon.re)
269 }
270
271 #[inline]
272 fn default_epsilon() -> Self::Epsilon {
273 Self::from_re(T::default_epsilon())
274 }
275}
276impl<T: DualNum<F> + approx::RelativeEq<Epsilon = T>, F: Float, D: Dim> approx::RelativeEq
279 for DualVec<T, F, D>
280where
281 DefaultAllocator: Allocator<D>,
282{
283 #[inline]
284 fn default_max_relative() -> Self::Epsilon {
285 Self::from_re(T::default_max_relative())
286 }
287
288 #[inline]
289 fn relative_eq(
290 &self,
291 other: &Self,
292 epsilon: Self::Epsilon,
293 max_relative: Self::Epsilon,
294 ) -> bool {
295 self.re.relative_eq(&other.re, epsilon.re, max_relative.re)
296 }
297}
298impl<T: DualNum<F> + UlpsEq<Epsilon = T>, F: Float, D: Dim> UlpsEq for DualVec<T, F, D>
299where
300 DefaultAllocator: Allocator<D>,
301{
302 #[inline]
303 fn default_max_ulps() -> u32 {
304 T::default_max_ulps()
305 }
306
307 #[inline]
308 fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool {
309 T::ulps_eq(&self.re, &other.re, epsilon.re, max_ulps)
310 }
311}
312
// Marker impl: declares `DualVec` an nalgebra `Field` (the body is empty). The
// field operations themselves come from the arithmetic impls elsewhere in this
// module and its macros.
impl<T, D: Dim> nalgebra::Field for DualVec<T, T::Element, D>
where
    T: DualNum<T::Element> + SimdValue,
    T::Element: DualNum<T::Element> + Scalar + Float,
    DefaultAllocator: Allocator<D> + Allocator<U1, D> + Allocator<D, U1> + Allocator<D, D>,
{
}
320
321use simba::scalar::{SubsetOf, SupersetOf};
322
323impl<TSuper, FSuper, T, F, D: Dim> SubsetOf<DualVec<TSuper, FSuper, D>> for DualVec<T, F, D>
324where
325 TSuper: DualNum<FSuper> + SupersetOf<T>,
326 T: DualNum<F>,
327 DefaultAllocator: Allocator<D> + Allocator<U1, D> + Allocator<D, U1> + Allocator<D, D>,
328{
329 #[inline(always)]
330 fn to_superset(&self) -> DualVec<TSuper, FSuper, D> {
331 let re = TSuper::from_subset(&self.re);
332 let eps = Derivative::from_subset(&self.eps);
333 DualVec {
334 re,
335 eps,
336 f: PhantomData,
337 }
338 }
339 #[inline(always)]
340 fn from_superset(element: &DualVec<TSuper, FSuper, D>) -> Option<Self> {
341 let re = TSuper::to_subset(&element.re)?;
342 let eps = Derivative::to_subset(&element.eps)?;
343 Some(Self::new(re, eps))
344 }
345 #[inline(always)]
346 fn from_superset_unchecked(element: &DualVec<TSuper, FSuper, D>) -> Self {
347 let re = TSuper::to_subset_unchecked(&element.re);
348 let eps = Derivative::to_subset_unchecked(&element.eps);
349 Self::new(re, eps)
350 }
351 #[inline(always)]
352 fn is_in_subset(element: &DualVec<TSuper, FSuper, D>) -> bool {
353 TSuper::is_in_subset(&element.re)
354 && <Derivative<_, _, _, _> as SupersetOf<Derivative<_, _, _, _>>>::is_in_subset(
355 &element.eps,
356 )
357 }
358}
359
360impl<TSuper, FSuper, D: Dim> SupersetOf<f32> for DualVec<TSuper, FSuper, D>
361where
362 TSuper: DualNum<FSuper> + SupersetOf<f32>,
363 DefaultAllocator: Allocator<D> + Allocator<U1, D> + Allocator<D, U1> + Allocator<D, D>,
364{
365 #[inline(always)]
366 fn is_in_subset(&self) -> bool {
367 self.re.is_in_subset()
368 }
369
370 #[inline(always)]
371 fn to_subset_unchecked(&self) -> f32 {
372 self.re.to_subset_unchecked()
373 }
374
375 #[inline(always)]
376 fn from_subset(element: &f32) -> Self {
377 let re = TSuper::from_subset(element);
379 let eps = Derivative::none();
380 Self::new(re, eps)
381 }
382}
383
384impl<TSuper, FSuper, D: Dim> SupersetOf<f64> for DualVec<TSuper, FSuper, D>
385where
386 TSuper: DualNum<FSuper> + SupersetOf<f64>,
387 DefaultAllocator: Allocator<D> + Allocator<U1, D> + Allocator<D, U1> + Allocator<D, D>,
388{
389 #[inline(always)]
390 fn is_in_subset(&self) -> bool {
391 self.re.is_in_subset()
392 }
393
394 #[inline(always)]
395 fn to_subset_unchecked(&self) -> f64 {
396 self.re.to_subset_unchecked()
397 }
398
399 #[inline(always)]
400 fn from_subset(element: &f64) -> Self {
401 let re = TSuper::from_subset(element);
403 let eps = Derivative::none();
404 Self::new(re, eps)
405 }
406}
407
408use nalgebra::{ComplexField, RealField};
/// nalgebra `ComplexField` for a *real* dual number.
///
/// `RealField = Self`, so the imaginary part and the argument are always zero.
/// Most methods delegate to the corresponding `DualNum` function.
/// Rounding-style operations (`floor`, `ceil`, `round`, `trunc`, `fract`) are not
/// supported and panic.
impl<T, D: Dim> ComplexField for DualVec<T, T::Element, D>
where
    T: DualNum<T::Element> + SupersetOf<T> + AbsDiffEq<Epsilon = T> + Sync + Send,
    T::Element: DualNum<T::Element> + Scalar + DualNumFloat + Sync + Send,
    T: SupersetOf<T::Element>,
    T: SupersetOf<f32>,
    T: SupersetOf<f64>,
    T: SimdPartialOrd + PartialOrd,
    T: SimdValue<Element = T, SimdBool = bool>,
    T: RelativeEq + UlpsEq + AbsDiffEq,
    DefaultAllocator: Allocator<D> + Allocator<U1, D> + Allocator<D, U1> + Allocator<D, D>,
    <DefaultAllocator as Allocator<D>>::Buffer<T>: Sync + Send,
{
    type RealField = Self;

    // Real and complex views coincide, so these conversions are the identity.
    #[inline]
    fn from_real(re: Self::RealField) -> Self {
        re
    }

    #[inline]
    fn real(self) -> Self::RealField {
        self
    }

    // A real dual number has no imaginary component.
    #[inline]
    fn imaginary(self) -> Self::RealField {
        Self::zero()
    }

    #[inline]
    fn modulus(self) -> Self::RealField {
        self.abs()
    }

    // |z|² for a real value is simply self * self (derivative included).
    #[inline]
    fn modulus_squared(self) -> Self::RealField {
        &self * &self
    }

    // Always zero. NOTE(review): strictly the argument of a negative real is π;
    // returning 0 matches the original code and is kept as-is.
    #[inline]
    fn argument(self) -> Self::RealField {
        Self::zero()
    }

    #[inline]
    fn norm1(self) -> Self::RealField {
        self.abs()
    }

    #[inline]
    fn scale(self, factor: Self::RealField) -> Self {
        self * factor
    }

    #[inline]
    fn unscale(self, factor: Self::RealField) -> Self {
        self / factor
    }

    // Rounding operations have no meaningful derivative here; they panic.
    #[inline]
    fn floor(self) -> Self {
        panic!("called floor() on a dual number")
    }

    #[inline]
    fn ceil(self) -> Self {
        panic!("called ceil() on a dual number")
    }

    #[inline]
    fn round(self) -> Self {
        panic!("called round() on a dual number")
    }

    #[inline]
    fn trunc(self) -> Self {
        panic!("called trunc() on a dual number")
    }

    #[inline]
    fn fract(self) -> Self {
        panic!("called fract() on a dual number")
    }

    // Explicit `DualNum::` paths are used from here on to pick the dual-number
    // implementation rather than another trait method with the same name.
    #[inline]
    fn mul_add(self, a: Self, b: Self) -> Self {
        DualNum::mul_add(&self, a, b)
    }

    #[inline]
    fn abs(self) -> Self::RealField {
        Signed::abs(&self)
    }

    // sqrt(self² + other²), computed directly (no overflow-avoiding rescaling).
    #[inline]
    fn hypot(self, other: Self) -> Self::RealField {
        let sum_sq = self.powi(2) + other.powi(2);
        DualNum::sqrt(&sum_sq)
    }

    #[inline]
    fn recip(self) -> Self {
        DualNum::recip(&self)
    }

    // Conjugation is the identity for real values.
    #[inline]
    fn conjugate(self) -> Self {
        self
    }

    #[inline]
    fn sin(self) -> Self {
        DualNum::sin(&self)
    }

    #[inline]
    fn cos(self) -> Self {
        DualNum::cos(&self)
    }

    #[inline]
    fn sin_cos(self) -> (Self, Self) {
        DualNum::sin_cos(&self)
    }

    #[inline]
    fn tan(self) -> Self {
        DualNum::tan(&self)
    }

    #[inline]
    fn asin(self) -> Self {
        DualNum::asin(&self)
    }

    #[inline]
    fn acos(self) -> Self {
        DualNum::acos(&self)
    }

    #[inline]
    fn atan(self) -> Self {
        DualNum::atan(&self)
    }

    #[inline]
    fn sinh(self) -> Self {
        DualNum::sinh(&self)
    }

    #[inline]
    fn cosh(self) -> Self {
        DualNum::cosh(&self)
    }

    #[inline]
    fn tanh(self) -> Self {
        DualNum::tanh(&self)
    }

    #[inline]
    fn asinh(self) -> Self {
        DualNum::asinh(&self)
    }

    #[inline]
    fn acosh(self) -> Self {
        DualNum::acosh(&self)
    }

    #[inline]
    fn atanh(self) -> Self {
        DualNum::atanh(&self)
    }

    // Change of base: log_b(x) = ln(x) / ln(b); derivatives of both are propagated.
    #[inline]
    fn log(self, base: Self::RealField) -> Self {
        DualNum::ln(&self) / DualNum::ln(&base)
    }

    #[inline]
    fn log2(self) -> Self {
        DualNum::log2(&self)
    }

    #[inline]
    fn log10(self) -> Self {
        DualNum::log10(&self)
    }

    #[inline]
    fn ln(self) -> Self {
        DualNum::ln(&self)
    }

    #[inline]
    fn ln_1p(self) -> Self {
        DualNum::ln_1p(&self)
    }

    #[inline]
    fn sqrt(self) -> Self {
        DualNum::sqrt(&self)
    }

    #[inline]
    fn exp(self) -> Self {
        DualNum::exp(&self)
    }

    #[inline]
    fn exp2(self) -> Self {
        DualNum::exp2(&self)
    }

    #[inline]
    fn exp_m1(self) -> Self {
        DualNum::exp_m1(&self)
    }

    #[inline]
    fn powi(self, n: i32) -> Self {
        DualNum::powi(&self, n)
    }

    // Dual-valued exponent: delegates to `DualNum::powd`.
    #[inline]
    fn powf(self, n: Self::RealField) -> Self {
        DualNum::powd(&self, n)
    }

    // Complex power coincides with `powf` because `Self` is real.
    #[inline]
    fn powc(self, n: Self) -> Self {
        self.powf(n)
    }

    #[inline]
    fn cbrt(self) -> Self {
        DualNum::cbrt(&self)
    }

    // Only the real part is inspected; derivative entries are not checked.
    #[inline]
    fn is_finite(&self) -> bool {
        self.re.is_finite()
    }

    // Returns Some only for strictly positive values; zero yields None.
    // NOTE(review): `>` on DualVec presumably compares real parts (PartialOrd impl
    // not visible here). Excluding zero avoids sqrt's infinite derivative at 0,
    // but differs from plain reals; confirm this is intended.
    #[inline]
    fn try_sqrt(self) -> Option<Self> {
        if self > Self::zero() {
            Some(DualNum::sqrt(&self))
        } else {
            None
        }
    }
}
674
675impl<T, D: Dim> RealField for DualVec<T, T::Element, D>
676where
677 T: DualNum<T::Element> + SupersetOf<T> + Sync + Send,
678 T::Element: DualNum<T::Element> + Scalar + DualNumFloat,
679 T: SupersetOf<T::Element>,
680 T: SupersetOf<f32>,
681 T: SupersetOf<f64>,
682 T: SimdPartialOrd + PartialOrd,
683 T: RelativeEq + AbsDiffEq<Epsilon = T>,
684 T: SimdValue<Element = T, SimdBool = bool>,
685 T: UlpsEq,
686 T: AbsDiffEq,
687 DefaultAllocator: Allocator<D> + Allocator<U1, D> + Allocator<D, U1> + Allocator<D, D>,
688 <DefaultAllocator as Allocator<D>>::Buffer<T>: Sync + Send,
689{
690 #[inline]
691 fn copysign(self, sign: Self) -> Self {
692 if sign.re.is_sign_positive() {
693 self.simd_abs()
694 } else {
695 -self.simd_abs()
696 }
697 }
698
699 #[inline]
700 fn atan2(self, other: Self) -> Self {
701 DualNum::atan2(&self, other)
702 }
703
704 #[inline]
705 fn pi() -> Self {
706 Self::from_re(<T as FloatConst>::PI())
707 }
708
709 #[inline]
710 fn two_pi() -> Self {
711 Self::from_re(<T as FloatConst>::TAU())
712 }
713
714 #[inline]
715 fn frac_pi_2() -> Self {
716 Self::from_re(<T as FloatConst>::FRAC_PI_4())
717 }
718
719 #[inline]
720 fn frac_pi_3() -> Self {
721 Self::from_re(<T as FloatConst>::FRAC_PI_3())
722 }
723
724 #[inline]
725 fn frac_pi_4() -> Self {
726 Self::from_re(<T as FloatConst>::FRAC_PI_4())
727 }
728
729 #[inline]
730 fn frac_pi_6() -> Self {
731 Self::from_re(<T as FloatConst>::FRAC_PI_6())
732 }
733
734 #[inline]
735 fn frac_pi_8() -> Self {
736 Self::from_re(<T as FloatConst>::FRAC_PI_8())
737 }
738
739 #[inline]
740 fn frac_1_pi() -> Self {
741 Self::from_re(<T as FloatConst>::FRAC_1_PI())
742 }
743
744 #[inline]
745 fn frac_2_pi() -> Self {
746 Self::from_re(<T as FloatConst>::FRAC_2_PI())
747 }
748
749 #[inline]
750 fn frac_2_sqrt_pi() -> Self {
751 Self::from_re(<T as FloatConst>::FRAC_2_SQRT_PI())
752 }
753
754 #[inline]
755 fn e() -> Self {
756 Self::from_re(<T as FloatConst>::E())
757 }
758
759 #[inline]
760 fn log2_e() -> Self {
761 Self::from_re(<T as FloatConst>::LOG2_E())
762 }
763
764 #[inline]
765 fn log10_e() -> Self {
766 Self::from_re(<T as FloatConst>::LOG10_E())
767 }
768
769 #[inline]
770 fn ln_2() -> Self {
771 Self::from_re(<T as FloatConst>::LN_2())
772 }
773
774 #[inline]
775 fn ln_10() -> Self {
776 Self::from_re(<T as FloatConst>::LN_10())
777 }
778
779 #[inline]
780 fn is_sign_positive(&self) -> bool {
781 self.re.is_sign_positive()
782 }
783
784 #[inline]
785 fn is_sign_negative(&self) -> bool {
786 self.re.is_sign_negative()
787 }
788
789 #[inline]
791 fn max(self, other: Self) -> Self {
792 if other > self { other } else { self }
793 }
794
795 #[inline]
797 fn min(self, other: Self) -> Self {
798 if other < self { other } else { self }
799 }
800
801 #[inline]
803 fn clamp(self, min: Self, max: Self) -> Self {
804 if self < min {
805 min
806 } else if self > max {
807 max
808 } else {
809 self
810 }
811 }
812
813 #[inline]
814 fn min_value() -> Option<Self> {
815 Some(Self::from_re(T::min_value()))
816 }
817
818 #[inline]
819 fn max_value() -> Option<Self> {
820 Some(Self::from_re(T::max_value()))
821 }
822}