Skip to main content

numdiff/automatic_differentiation/
dual_vector.rs

1use crate::automatic_differentiation::dual::Dual;
2use linalg_traits::{Scalar, Vector};
3
4/// Trait to create a vector of dual numbers.
5pub trait DualVector<S, V>
6where
7    S: Scalar,
8    V: Vector<S>,
9{
10    /// Convert this vector of scalars to a vector of dual numbers.
11    ///
12    /// # Returns
13    ///
14    /// A copy of this vector with each element converted to a dual number (with dual part `0.0`).
15    fn to_dual_vector(self) -> V::VectorT<Dual>;
16}
17
18impl<S, V> DualVector<S, V> for V
19where
20    S: Scalar,
21    V: Vector<S>,
22{
23    fn to_dual_vector(self) -> V::VectorT<Dual> {
24        let mut vec_dual = V::VectorT::new_with_length(self.len());
25        for i in 0..self.len() {
26            vec_dual.vset(i, Dual::new(self.vget(i).to_f64().unwrap(), 0.0));
27        }
28        vec_dual
29    }
30}
31
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::{dvector, SVector};
    use ndarray::array;

    /// `Vec<f64>` input: every element gains a dual part of `0.0`.
    #[test]
    fn test_vec() {
        let scalars = vec![1.0, 2.0, 3.0];
        let expected = vec![
            Dual::new(1.0, 0.0),
            Dual::new(2.0, 0.0),
            Dual::new(3.0, 0.0),
        ];
        let actual = scalars.to_dual_vector();
        assert_eq!(actual, expected);
    }

    /// Dynamically-sized nalgebra vector input.
    #[test]
    fn test_nalgebra_dvector() {
        let scalars = dvector![1.0, 2.0, 3.0];
        let expected = dvector![
            Dual::new(1.0, 0.0),
            Dual::new(2.0, 0.0),
            Dual::new(3.0, 0.0)
        ];
        let actual = scalars.to_dual_vector();
        assert_eq!(actual, expected);
    }

    /// Statically-sized nalgebra vector input.
    #[test]
    fn test_nalgebra_svector() {
        let scalars = SVector::<f64, 3>::from_row_slice(&[1.0, 2.0, 3.0]);
        let expected = SVector::<Dual, 3>::from_row_slice(&[
            Dual::new(1.0, 0.0),
            Dual::new(2.0, 0.0),
            Dual::new(3.0, 0.0),
        ]);
        let actual = scalars.to_dual_vector();
        assert_eq!(actual, expected);
    }

    /// One-dimensional ndarray input.
    #[test]
    fn test_ndarray_array1() {
        let scalars = array![1.0, 2.0, 3.0];
        // NOTE(review): unlike the other tests, the expected value here is a `Vec` rather than
        // the input's own container type (`array![...]`) -- confirm this is intended.
        let expected = vec![
            Dual::new(1.0, 0.0),
            Dual::new(2.0, 0.0),
            Dual::new(3.0, 0.0),
        ];
        let actual = scalars.to_dual_vector();
        assert_eq!(actual, expected);
    }
}