sophus_core/calculus/maps/curves.rs

1use crate::linalg::SMat;
2use crate::prelude::*;
3use nalgebra::SVector;
4
5/// A smooth curve in ℝ.
6///
7/// This is a function which takes a scalar and returns a scalar:
8///
9///  f: ℝ -> ℝ
10pub struct ScalarValuedCurve<S: IsScalar<BATCH>, const BATCH: usize> {
11    phantom: core::marker::PhantomData<S>,
12}
13
14impl<S: IsScalar<BATCH>, const BATCH: usize> ScalarValuedCurve<S, BATCH> {
15    /// Finite difference quotient of the scalar-valued curve.
16    ///
17    /// The derivative is also a scalar.
18    pub fn sym_diff_quotient<TFn>(curve: TFn, a: S, h: f64) -> S
19    where
20        TFn: Fn(S) -> S,
21    {
22        let hh = S::from_f64(h);
23        (curve(a.clone() + hh.clone()) - curve(a - hh)) / S::from_f64(2.0 * h)
24    }
25}
26
27impl<D: IsDualScalar<BATCH>, const BATCH: usize> ScalarValuedCurve<D, BATCH> {
28    /// Auto differentiation of the scalar-valued curve.
29    pub fn fw_autodiff<TFn>(curve: TFn, a: D::RealScalar) -> D::RealScalar
30    where
31        TFn: Fn(D) -> D,
32    {
33        curve(D::new_with_dij(a))
34            .dij_val()
35            .clone()
36            .unwrap()
37            .get([0, 0])
38    }
39}
40
41/// A smooth curve in ℝʳ.
42///
43/// This is a function which takes a scalar and returns a vector:
44///
45///   f: ℝ -> ℝʳ
46pub struct VectorValuedCurve<S: IsScalar<BATCH>, const BATCH: usize> {
47    phantom: core::marker::PhantomData<S>,
48}
49
50impl<S: IsScalar<BATCH>, const BATCH: usize> VectorValuedCurve<S, BATCH> {
51    /// Finite difference quotient of the vector-valued curve.
52    ///
53    /// The derivative is also a vector.
54    pub fn sym_diff_quotient<TFn, const ROWS: usize>(curve: TFn, a: S, h: f64) -> S::Vector<ROWS>
55    where
56        TFn: Fn(S) -> S::Vector<ROWS>,
57    {
58        let hh = S::from_f64(h);
59        (curve(a.clone() + hh.clone()) - curve(a - hh)).scaled(S::from_f64(1.0 / (2.0 * h)))
60    }
61}
62
63impl<D: IsDualScalar<BATCH>, const BATCH: usize> VectorValuedCurve<D, BATCH> {
64    /// Auto differentiation of the vector-valued curve.
65    pub fn fw_autodiff<TFn, const ROWS: usize>(
66        curve: TFn,
67        a: D::RealScalar,
68    ) -> SVector<D::RealScalar, ROWS>
69    where
70        TFn: Fn(D) -> D::Vector<ROWS>,
71        D::Vector<ROWS>: IsDualVector<D, ROWS, BATCH>,
72    {
73        curve(D::new_with_dij(a)).dij_val().unwrap().get([0, 0])
74    }
75}
76
77/// A smooth curve in ℝʳ x ℝᶜ.
78///
79/// This is a function which takes a scalar and returns a matrix:
80///   f: ℝ -> ℝʳ x ℝᶜ
81pub struct MatrixValuedCurve<S: IsScalar<BATCH>, const BATCH: usize> {
82    phantom: core::marker::PhantomData<S>,
83}
84
85impl<S: IsScalar<BATCH>, const BATCH: usize> MatrixValuedCurve<S, BATCH> {
86    /// Finite difference quotient of the matrix-valued curve.
87    ///
88    /// The derivative is also a matrix.
89    pub fn sym_diff_quotient<TFn, const ROWS: usize, const COLS: usize>(
90        curve: TFn,
91        a: S,
92        h: f64,
93    ) -> S::Matrix<ROWS, COLS>
94    where
95        TFn: Fn(S) -> S::Matrix<ROWS, COLS>,
96    {
97        let hh = S::from_f64(h);
98        (curve(a.clone() + hh.clone()) - curve(a - hh)).scaled(S::from_f64(1.0 / (2.0 * h)))
99    }
100}
101
102impl<D: IsDualScalar<BATCH>, const BATCH: usize> MatrixValuedCurve<D, BATCH> {
103    /// Auto differentiation of the matrix-valued curve.
104    pub fn fw_autodiff<TFn, const ROWS: usize, const COLS: usize>(
105        curve: TFn,
106        a: D::RealScalar,
107    ) -> SMat<<D as IsScalar<BATCH>>::RealScalar, ROWS, COLS>
108    where
109        TFn: Fn(D) -> D::Matrix<ROWS, COLS>,
110        D::Matrix<ROWS, COLS>: IsDualMatrix<D, ROWS, COLS, BATCH>,
111    {
112        curve(D::new_with_dij(a)).dij_val().unwrap().get([0, 0])
113    }
114}
115
#[test]
fn curve_test() {
    use crate::calculus::dual::DualScalar;
    use crate::linalg::scalar::IsScalar;
    use crate::linalg::vector::IsVector;
    use crate::linalg::EPS_F64;

    #[cfg(feature = "simd")]
    use crate::calculus::dual::DualBatchScalar;
    #[cfg(feature = "simd")]
    use crate::linalg::BatchScalarF64;

    // Test curves, generic over the scalar type so that each one can be evaluated
    // both on real scalars (finite differences) and on dual scalars (autodiff).

    // f(x) = x^2
    fn square_fn<S: IsScalar<BATCH>, const BATCH: usize>(x: S) -> S {
        x.clone() * x
    }

    // f(x) = [cos(x), sin(x)]
    fn trig_fn<S: IsScalar<BATCH>, const BATCH: usize>(x: S) -> S::Vector<2> {
        S::Vector::<2>::from_array([x.clone().cos(), x.sin()])
    }

    // f(x) = [[ cos(x), sin(x), 0],
    //         [-sin(x), cos(x), 0]]
    fn fn_x<S: IsScalar<BATCH>, const BATCH: usize>(x: S) -> S::Matrix<2, 3> {
        let sin = x.clone().sin();
        let cos = x.clone().cos();

        S::Matrix::from_array2([
            [cos.clone(), sin.clone(), S::from_f64(0.0)],
            [-sin, cos, S::from_f64(0.0)],
        ])
    }

    trait CurveTest {
        fn run_curve_test();
    }

    // For one (real scalar, dual scalar) pair, checks central finite differences
    // against forward-mode autodiff at x = 0.0, 0.1, ..., 0.9 for every test curve.
    macro_rules! def_curve_test_template {
        ($batch:literal, $scalar:ty, $dual_scalar:ty) => {
            impl CurveTest for $dual_scalar {
                fn run_curve_test() {
                    let samples = || (0..10).map(|i| <$scalar>::from_f64(0.1 * (i as f64)));

                    for a in samples() {
                        let finite_diff = ScalarValuedCurve::<$scalar, $batch>::sym_diff_quotient(
                            square_fn,
                            a.clone(),
                            EPS_F64,
                        );
                        let auto_grad =
                            ScalarValuedCurve::<$dual_scalar, $batch>::fw_autodiff(square_fn, a);
                        approx::assert_abs_diff_eq!(finite_diff, auto_grad, epsilon = 0.0001);
                    }

                    for a in samples() {
                        let finite_diff = VectorValuedCurve::<$scalar, $batch>::sym_diff_quotient(
                            trig_fn,
                            a.clone(),
                            EPS_F64,
                        );
                        let auto_grad =
                            VectorValuedCurve::<$dual_scalar, $batch>::fw_autodiff(trig_fn, a);
                        approx::assert_abs_diff_eq!(finite_diff, auto_grad, epsilon = 0.0001);
                    }

                    for a in samples() {
                        let finite_diff = MatrixValuedCurve::<$scalar, $batch>::sym_diff_quotient(
                            fn_x,
                            a.clone(),
                            EPS_F64,
                        );
                        let auto_grad =
                            MatrixValuedCurve::<$dual_scalar, $batch>::fw_autodiff(fn_x, a);
                        approx::assert_abs_diff_eq!(finite_diff, auto_grad, epsilon = 0.0001);
                    }
                }
            }
        };
    }

    def_curve_test_template!(1, f64, DualScalar);
    #[cfg(feature = "simd")]
    def_curve_test_template!(2, BatchScalarF64<2>, DualBatchScalar<2>);
    #[cfg(feature = "simd")]
    def_curve_test_template!(4, BatchScalarF64<4>, DualBatchScalar<4>);
    #[cfg(feature = "simd")]
    def_curve_test_template!(8, BatchScalarF64<8>, DualBatchScalar<8>);

    DualScalar::run_curve_test();
    #[cfg(feature = "simd")]
    DualBatchScalar::<2>::run_curve_test();
    #[cfg(feature = "simd")]
    DualBatchScalar::<4>::run_curve_test();
    #[cfg(feature = "simd")]
    DualBatchScalar::<8>::run_curve_test();
}