// b_spline/so3bspline.rs

1use std::collections::HashMap;
2use std::num::NonZero;
3use std::sync::Arc;
4
5use crate::common::{compute_base_coefficients, compute_blending_matrix};
6use crate::traits::*;
7use tiny_solver::loss_functions::HuberLoss;
8use tiny_solver::manifold::{AutoDiffManifold, Manifold};
9use tiny_solver::{self, GaussNewtonOptimizer, Optimizer, manifold::so3::SO3, na};
10
11pub struct RvecManifold;
12impl<T: na::RealField> AutoDiffManifold<T> for RvecManifold {
13    fn plus(&self, x: na::DVectorView<T>, delta: na::DVectorView<T>) -> na::DVector<T> {
14        (SO3::exp(x) * SO3::exp(delta)).log()
15    }
16
17    fn minus(&self, y: na::DVectorView<T>, x: na::DVectorView<T>) -> na::DVector<T> {
18        let y_so3 = SO3::from_vec(y);
19        let x_so3_inv = SO3::from_vec(x).inverse();
20        (x_so3_inv * y_so3).log()
21    }
22}
23
24impl Manifold for RvecManifold {
25    fn tangent_size(&self) -> NonZero<usize> {
26        NonZero::new(3).unwrap()
27    }
28}
29
30#[derive(Debug, Clone)]
31pub struct SO3Bspline<const N: usize> {
32    timestamp_start_ns: u64,
33    spacing_ns: u64,
34    pub knots: Vec<[f64; 3]>,
35    blending_matrix: na::DMatrix<f64>,
36    first_derivative_bases: na::DVector<f64>,
37}
38
39pub fn so3_from_u_and_knots<T: na::RealField>(
40    u: f64,
41    knots: &[na::DVector<T>],
42    blending_matrix: &na::DMatrix<f64>,
43) -> SO3<T> {
44    let n = knots.len();
45    let uv = na::DVector::from_fn(n, |i, _| u.powi(i as i32));
46    let kv = (blending_matrix * uv).cast::<T>();
47    let mut r = knots[0].to_so3();
48    for j in 1..n {
49        let k = kv[j].clone();
50        let r_j = knots[j].to_so3();
51        let r_j_minus_inv = knots[j - 1].to_so3().inverse();
52        let rj_vec = (r_j_minus_inv * r_j).log() * k;
53        r = r * rj_vec.to_so3();
54    }
55    r
56}
57
58struct RotationCost {
59    u: f64,
60    rvec: na::Vector3<f64>,
61    blending_matrix: na::DMatrix<f64>,
62}
63impl RotationCost {
64    pub fn new(u: f64, rvec: &[f64; 3], blending_matrix: &na::DMatrix<f64>) -> Self {
65        RotationCost {
66            u,
67            rvec: na::Vector3::new(rvec[0], rvec[1], rvec[2]),
68            blending_matrix: blending_matrix.clone(),
69        }
70    }
71}
72
73impl<T: na::RealField> tiny_solver::factors::Factor<T> for RotationCost {
74    fn residual_func(&self, params: &[na::DVector<T>]) -> na::DVector<T> {
75        let r = so3_from_u_and_knots(self.u, params, &self.blending_matrix);
76        let target = self.rvec.cast::<T>().to_dvec();
77        let target = SO3::exp(target.as_view());
78        (target.inverse() * r).log()
79    }
80}
81
82impl<const N: usize> SO3Bspline<N> {
83    fn fit(&mut self, timestamps_ns: &[u64], rvecs: &[[f64; 3]]) {
84        let mut problem = tiny_solver::Problem::new();
85        let mut initial_values = HashMap::new();
86        for (&t_ns, rvec) in timestamps_ns.iter().zip(rvecs) {
87            let (u, idx) = self.get_u_and_index(t_ns);
88            let cost = RotationCost::new(u, rvec, &self.blending_matrix());
89            let mut var_list = Vec::new();
90            for i in idx..idx + N {
91                let var_name = format!("r{}", i);
92                if !initial_values.contains_key(&var_name) {
93                    problem.set_variable_manifold(&var_name, Arc::new(RvecManifold));
94                    let rvec = self.knots[i].to_dvec();
95                    initial_values.insert(var_name.clone(), rvec);
96                }
97                var_list.push(var_name);
98            }
99            let var_list: Vec<_> = var_list.iter().map(|a| a.as_str()).collect();
100            problem.add_residual_block(
101                3,
102                &var_list,
103                Box::new(cost),
104                Some(Box::new(HuberLoss::new(0.1))),
105            );
106        }
107        let optimizer = GaussNewtonOptimizer::default();
108        let result = optimizer.optimize(&problem, &initial_values, None).unwrap();
109        self.knots = (0..self.knots.len())
110            .map(|i| {
111                let var_name = format!("r{}", i);
112                let knot = result.get(&var_name).unwrap();
113                [knot[0], knot[1], knot[2]]
114            })
115            .collect();
116    }
117    pub fn blending_matrix(&self) -> na::DMatrix<f64> {
118        self.blending_matrix.clone()
119    }
120
121    pub fn first_derivative_bases(&self) -> na::DVector<f64> {
122        self.first_derivative_bases.clone()
123    }
124
125    pub fn get_u_and_index(&self, timestamp_ns: u64) -> (f64, usize) {
126        // println!("req: {}", timestamp_ns);
127        let time_offset = timestamp_ns - self.timestamp_start_ns;
128        // println!("offset {}", time_offset);
129        // println!("mod {}", time_offset % self.spacing_ns);
130        let u = (time_offset % self.spacing_ns) as f64 / self.spacing_ns as f64;
131        // println!("u {}", u);
132        let idx = time_offset / self.spacing_ns;
133        (u, idx as usize)
134    }
135
136    pub fn get_rotation(&self, timestamp_ns: u64) -> SO3<f64> {
137        let (u, i) = self.get_u_and_index(timestamp_ns);
138        let knots: Vec<_> = (i..i + N).map(|idx| self.knots[idx].to_dvec()).collect();
139        so3_from_u_and_knots(u, &knots, &self.blending_matrix())
140    }
141
142    pub fn get_velocity(&self, timestamp_ns: u64) -> na::Vector3<f64> {
143        let (u, idx) = self.get_u_and_index(timestamp_ns);
144
145        let coeff = &self.blending_matrix * Self::base_coeffs_with_time::<0>(u);
146        let d_tn_s = self.spacing_ns as f64 / 1e9;
147        let dcoeff = 1.0 / d_tn_s * &self.blending_matrix * Self::base_coeffs_with_time::<1>(u);
148        let mut w = na::Vector3::<f64>::zeros();
149        for j in 0..(N - 1) {
150            let p0 = self.knots[idx + j].to_so3();
151            let p1 = self.knots[idx + j + 1].to_so3();
152            let r01 = p0.inverse() * p1;
153            let delta = r01.log();
154            let ww = ((-1.0 * &delta) * coeff[j + 1]).to_so3();
155            w = &ww * w.as_view();
156            w += delta * dcoeff[j + 1];
157        }
158        w
159    }
160
161    pub fn from_rotation_vectors(
162        timestamps_ns: &[u64],
163        rvecs: &[[f64; 3]],
164        spacing_ns: u64,
165    ) -> Self {
166        if timestamps_ns.len() < N {
167            panic!("timestams_ns should be larger than {}", N);
168        } else if timestamps_ns.len() != rvecs.len() {
169            panic!("timestameps_ns should have the same length as rvecs");
170        }
171        let mut prev_t = timestamps_ns[0];
172        for &t in timestamps_ns {
173            if t - prev_t > spacing_ns {
174                panic!("spacing should be larger than all time stamp steps");
175            }
176            prev_t = t;
177        }
178        let num_of_knots =
179            ((timestamps_ns.last().unwrap() - timestamps_ns[0]) / spacing_ns) as usize + N;
180        let timestamp_start_ns = timestamps_ns[0];
181
182        let mut current_idx = 0;
183        let mut knots = Vec::new();
184        for i in 0..num_of_knots {
185            let knot_time_ns = timestamp_start_ns + i as u64 * spacing_ns;
186            while current_idx < timestamps_ns.len() - 1 {
187                if timestamps_ns[current_idx + 1] < knot_time_ns {
188                    current_idx += 1;
189                } else {
190                    break;
191                }
192            }
193            knots.push(rvecs[current_idx]);
194        }
195
196        let mut bspline = SO3Bspline {
197            timestamp_start_ns,
198            spacing_ns,
199            knots,
200            blending_matrix: compute_blending_matrix(N, true),
201            first_derivative_bases: compute_base_coefficients(N).row(1).transpose(),
202        };
203        bspline.fit(timestamps_ns, rvecs);
204        bspline
205    }
206    fn base_coeffs_with_time<const DERIVATIVE: usize>(u: f64) -> na::DVector<f64> {
207        let mut res = na::DVector::zeros(N);
208        let base_coefficients = compute_base_coefficients(N);
209        if DERIVATIVE < N {
210            res[DERIVATIVE] = base_coefficients[(DERIVATIVE, DERIVATIVE)];
211            let mut ti = u;
212            for j in (DERIVATIVE + 1)..N {
213                res[j] = base_coefficients[(DERIVATIVE, j)] * ti;
214                ti *= u;
215            }
216        }
217        res
218    }
219}