1use crate::{Derivative, DualNum, DualNumFloat, DualStruct};
2use nalgebra::allocator::Allocator;
3use nalgebra::*;
4use num_traits::{Float, FloatConst, FromPrimitive, Inv, Num, One, Signed, Zero};
5use std::fmt;
6use std::iter::{Product, Sum};
7use std::marker::PhantomData;
8use std::ops::{
9 Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
10};
11
12#[derive(Clone, Debug)]
14pub struct DualVec<T: DualNum<F>, F, D: Dim>
15where
16 DefaultAllocator: Allocator<D>,
17{
18 pub re: T,
20 pub eps: Derivative<T, F, D, U1>,
22 f: PhantomData<F>,
23}
24
/// Allows a `DualVec` to be used as the scalar operand in ndarray
/// array-scalar arithmetic (only compiled with the `ndarray` feature).
#[cfg(feature = "ndarray")]
impl<T, F, D> ndarray::ScalarOperand for DualVec<T, F, D>
where
    T: DualNum<F>,
    F: DualNumFloat,
    D: Dim,
    DefaultAllocator: Allocator<D>,
{
}
30
31impl<T: DualNum<F> + Copy, F: Copy, const N: usize> Copy for DualVec<T, F, Const<N>> {}
32
33pub type DualSVec<D, F, const N: usize> = DualVec<D, F, Const<N>>;
34pub type DualDVec<D, F> = DualVec<D, F, Dyn>;
35pub type DualVec32<D> = DualVec<f32, f32, D>;
36pub type DualVec64<D> = DualVec<f64, f64, D>;
37pub type DualSVec32<const N: usize> = DualVec<f32, f32, Const<N>>;
38pub type DualSVec64<const N: usize> = DualVec<f64, f64, Const<N>>;
39pub type DualDVec32 = DualVec<f32, f32, Dyn>;
40pub type DualDVec64 = DualVec<f64, f64, Dyn>;
41
42impl<T: DualNum<F>, F, D: Dim> DualVec<T, F, D>
43where
44 DefaultAllocator: Allocator<D>,
45{
46 #[inline]
48 pub fn new(re: T, eps: Derivative<T, F, D, U1>) -> Self {
49 Self {
50 re,
51 eps,
52 f: PhantomData,
53 }
54 }
55}
56
57impl<T: DualNum<F>, F, const N: usize> DualSVec<T, F, N> {
58 #[inline]
73 pub fn derivative(mut self, index: usize) -> Self {
74 self.eps = Derivative::derivative_generic(Const::<N>, U1, index);
75 self
76 }
77}
78
79impl<T: DualNum<F>, F> DualDVec<T, F> {
80 #[inline]
95 pub fn derivative(mut self, variables: usize, index: usize) -> Self {
96 self.eps = Derivative::derivative_generic(Dyn(variables), U1, index);
97 self
98 }
99}
100
101impl<T: DualNum<F> + Zero, F, D: Dim> DualVec<T, F, D>
102where
103 DefaultAllocator: Allocator<D>,
104{
105 #[inline]
107 pub fn from_re(re: T) -> Self {
108 Self::new(re, Derivative::none())
109 }
110}
111
112impl<T: DualNum<F>, F: Float, D: Dim> DualVec<T, F, D>
114where
115 DefaultAllocator: Allocator<D>,
116{
117 #[inline]
118 fn chain_rule(&self, f0: T, f1: T) -> Self {
119 Self::new(f0, &self.eps * f1)
120 }
121}
122
123impl<T: DualNum<F>, F: Float, D: Dim> Mul<&DualVec<T, F, D>> for &DualVec<T, F, D>
125where
126 DefaultAllocator: Allocator<D>,
127{
128 type Output = DualVec<T, F, D>;
129 #[inline]
130 fn mul(self, other: &DualVec<T, F, D>) -> Self::Output {
131 DualVec::new(
132 self.re.clone() * other.re.clone(),
133 &self.eps * other.re.clone() + &other.eps * self.re.clone(),
134 )
135 }
136}
137
138impl<T: DualNum<F>, F: Float, D: Dim> Div<&DualVec<T, F, D>> for &DualVec<T, F, D>
140where
141 DefaultAllocator: Allocator<D>,
142{
143 type Output = DualVec<T, F, D>;
144 #[inline]
145 fn div(self, other: &DualVec<T, F, D>) -> DualVec<T, F, D> {
146 let inv = other.re.recip();
147 DualVec::new(
148 self.re.clone() * inv.clone(),
149 (&self.eps * other.re.clone() - &other.eps * self.re.clone()) * inv.clone() * inv,
150 )
151 }
152}
153
154impl<T: DualNum<F>, F, D: Dim> fmt::Display for DualVec<T, F, D>
156where
157 DefaultAllocator: Allocator<D>,
158{
159 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
160 write!(f, "{}", self.re)?;
161 self.eps.fmt(f, "ε")
162 }
163}
164
// The remaining trait and operator impls for `DualVec` come from crate-level
// macros defined elsewhere; their expansions are not visible here.
//
// Judging by the names, these macros generate:
// - first-derivative helpers,
// - the dual-number trait impls,
// - the nalgebra integration.
// NOTE(review): confirm against the macro definitions.
//
// In each call, `[eps]` appears to list the derivative field(s) and the two `[D]`
// lists the dimension generics. TODO confirm.
//
// The private `chain_rule` helper is not called in the visible code, so it is
// presumably used by these expansions.
impl_first_derivatives!(DualVec, [eps], [D], [D]);
impl_dual!(DualVec, [eps], [D], [D]);
impl_nalgebra!(DualVec, [eps], [D], [D]);