num_dual/datatypes/
dual.rs1use crate::{DualNum, DualNumFloat, DualStruct};
2use num_traits::{Float, FloatConst, FromPrimitive, Inv, Num, One, Signed, Zero};
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5use std::fmt;
6use std::iter::{Product, Sum};
7use std::marker::PhantomData;
8use std::ops::{
9 Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
10};
11
12#[derive(Copy, Clone, Debug)]
14#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
15pub struct Dual<T: DualNum<F>, F> {
16 pub re: T,
18 pub eps: T,
20 #[cfg_attr(feature = "serde", serde(skip))]
21 f: PhantomData<F>,
22}
23
/// Marks `Dual` as an `ndarray` scalar operand (only with the `ndarray` feature).
#[cfg(feature = "ndarray")]
impl<T, F> ndarray::ScalarOperand for Dual<T, F>
where
    T: DualNum<F>,
    F: DualNumFloat,
{
}
26
27pub type Dual32 = Dual<f32, f32>;
28pub type Dual64 = Dual<f64, f64>;
29
30impl<T: DualNum<F>, F> Dual<T, F> {
31 #[inline]
33 pub fn new(re: T, eps: T) -> Self {
34 Self {
35 re,
36 eps,
37 f: PhantomData,
38 }
39 }
40}
41
42impl<T: DualNum<F> + Zero, F> Dual<T, F> {
43 #[inline]
45 pub fn from_re(re: T) -> Self {
46 Self::new(re, T::zero())
47 }
48}
49
50impl<T: DualNum<F> + One, F> Dual<T, F> {
51 #[inline]
59 pub fn derivative(mut self) -> Self {
60 self.eps = T::one();
61 self
62 }
63}
64
65impl<T: DualNum<F>, F: Float> Dual<T, F> {
67 #[inline]
68 fn chain_rule(&self, f0: T, f1: T) -> Self {
69 Self::new(f0, self.eps.clone() * f1)
70 }
71}
72
73impl<T: DualNum<F>, F: Float> Mul<&Dual<T, F>> for &Dual<T, F> {
75 type Output = Dual<T, F>;
76 #[inline]
77 fn mul(self, other: &Dual<T, F>) -> Self::Output {
78 Dual::new(
79 self.re.clone() * other.re.clone(),
80 self.eps.clone() * other.re.clone() + other.eps.clone() * self.re.clone(),
81 )
82 }
83}
84
85impl<T: DualNum<F>, F: Float> Div<&Dual<T, F>> for &Dual<T, F> {
87 type Output = Dual<T, F>;
88 #[inline]
89 fn div(self, other: &Dual<T, F>) -> Dual<T, F> {
90 let inv = other.re.recip();
91 Dual::new(
92 self.re.clone() * inv.clone(),
93 (self.eps.clone() * other.re.clone() - other.eps.clone() * self.re.clone())
94 * inv.clone()
95 * inv,
96 )
97 }
98}
99
100impl<T: DualNum<F>, F> fmt::Display for Dual<T, F> {
102 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
103 write!(f, "{} + {}ε", self.re, self.eps)
104 }
105}
106
107impl_first_derivatives!(Dual, [eps]);
108impl_dual!(Dual, [eps]);
109impl_nalgebra!(Dual, [eps]);