1use std::collections::HashMap;
2use std::num::NonZero;
3use std::sync::Arc;
4
5use crate::common::{compute_base_coefficients, compute_blending_matrix};
6use crate::traits::*;
7use tiny_solver::loss_functions::HuberLoss;
8use tiny_solver::manifold::{AutoDiffManifold, Manifold};
9use tiny_solver::{self, GaussNewtonOptimizer, Optimizer, manifold::so3::SO3, na};
10
11pub struct RvecManifold;
12impl<T: na::RealField> AutoDiffManifold<T> for RvecManifold {
13 fn plus(&self, x: na::DVectorView<T>, delta: na::DVectorView<T>) -> na::DVector<T> {
14 (SO3::exp(x) * SO3::exp(delta)).log()
15 }
16
17 fn minus(&self, y: na::DVectorView<T>, x: na::DVectorView<T>) -> na::DVector<T> {
18 let y_so3 = SO3::from_vec(y);
19 let x_so3_inv = SO3::from_vec(x).inverse();
20 (x_so3_inv * y_so3).log()
21 }
22}
23
24impl Manifold for RvecManifold {
25 fn tangent_size(&self) -> NonZero<usize> {
26 NonZero::new(3).unwrap()
27 }
28}
29
/// Uniform B-spline of order `N` on SO(3) whose knots are stored as rotation vectors.
///
/// Knots are spaced every `spacing_ns` nanoseconds starting at `timestamp_start_ns`.
/// Rotations are evaluated in cumulative form by `so3_from_u_and_knots`.
#[derive(Debug, Clone)]
pub struct SO3Bspline<const N: usize> {
    /// Timestamp (ns) subtracted from queries in `get_u_and_index`, i.e. the spline start.
    timestamp_start_ns: u64,
    /// Distance between consecutive knots, in nanoseconds.
    spacing_ns: u64,
    /// Knot rotations as 3-element rotation vectors (converted with `to_so3()` when used).
    pub knots: Vec<[f64; 3]>,
    /// Blending matrix applied to the monomial basis `[1, u, u^2, ...]`; built with
    /// `compute_blending_matrix(N, true)` in `from_rotation_vectors`.
    blending_matrix: na::DMatrix<f64>,
    /// Row 1 of `compute_base_coefficients(N)`, the row that
    /// `base_coeffs_with_time::<1>` reads for first-derivative weights.
    first_derivative_bases: na::DVector<f64>,
}
38
39pub fn so3_from_u_and_knots<T: na::RealField>(
40 u: f64,
41 knots: &[na::DVector<T>],
42 blending_matrix: &na::DMatrix<f64>,
43) -> SO3<T> {
44 let n = knots.len();
45 let uv = na::DVector::from_fn(n, |i, _| u.powi(i as i32));
46 let kv = (blending_matrix * uv).cast::<T>();
47 let mut r = knots[0].to_so3();
48 for j in 1..n {
49 let k = kv[j].clone();
50 let r_j = knots[j].to_so3();
51 let r_j_minus_inv = knots[j - 1].to_so3().inverse();
52 let rj_vec = (r_j_minus_inv * r_j).log() * k;
53 r = r * rj_vec.to_so3();
54 }
55 r
56}
57
58struct RotationCost {
59 u: f64,
60 rvec: na::Vector3<f64>,
61 blending_matrix: na::DMatrix<f64>,
62}
63impl RotationCost {
64 pub fn new(u: f64, rvec: &[f64; 3], blending_matrix: &na::DMatrix<f64>) -> Self {
65 RotationCost {
66 u,
67 rvec: na::Vector3::new(rvec[0], rvec[1], rvec[2]),
68 blending_matrix: blending_matrix.clone(),
69 }
70 }
71}
72
73impl<T: na::RealField> tiny_solver::factors::Factor<T> for RotationCost {
74 fn residual_func(&self, params: &[na::DVector<T>]) -> na::DVector<T> {
75 let r = so3_from_u_and_knots(self.u, params, &self.blending_matrix);
76 let target = self.rvec.cast::<T>().to_dvec();
77 let target = SO3::exp(target.as_view());
78 (target.inverse() * r).log()
79 }
80}
81
82impl<const N: usize> SO3Bspline<N> {
83 fn fit(&mut self, timestamps_ns: &[u64], rvecs: &[[f64; 3]]) {
84 let mut problem = tiny_solver::Problem::new();
85 let mut initial_values = HashMap::new();
86 for (&t_ns, rvec) in timestamps_ns.iter().zip(rvecs) {
87 let (u, idx) = self.get_u_and_index(t_ns);
88 let cost = RotationCost::new(u, rvec, &self.blending_matrix());
89 let mut var_list = Vec::new();
90 for i in idx..idx + N {
91 let var_name = format!("r{}", i);
92 if !initial_values.contains_key(&var_name) {
93 problem.set_variable_manifold(&var_name, Arc::new(RvecManifold));
94 let rvec = self.knots[i].to_dvec();
95 initial_values.insert(var_name.clone(), rvec);
96 }
97 var_list.push(var_name);
98 }
99 let var_list: Vec<_> = var_list.iter().map(|a| a.as_str()).collect();
100 problem.add_residual_block(
101 3,
102 &var_list,
103 Box::new(cost),
104 Some(Box::new(HuberLoss::new(0.1))),
105 );
106 }
107 let optimizer = GaussNewtonOptimizer::default();
108 let result = optimizer.optimize(&problem, &initial_values, None).unwrap();
109 self.knots = (0..self.knots.len())
110 .map(|i| {
111 let var_name = format!("r{}", i);
112 let knot = result.get(&var_name).unwrap();
113 [knot[0], knot[1], knot[2]]
114 })
115 .collect();
116 }
117 pub fn blending_matrix(&self) -> na::DMatrix<f64> {
118 self.blending_matrix.clone()
119 }
120
121 pub fn first_derivative_bases(&self) -> na::DVector<f64> {
122 self.first_derivative_bases.clone()
123 }
124
125 pub fn get_u_and_index(&self, timestamp_ns: u64) -> (f64, usize) {
126 let time_offset = timestamp_ns - self.timestamp_start_ns;
128 let u = (time_offset % self.spacing_ns) as f64 / self.spacing_ns as f64;
131 let idx = time_offset / self.spacing_ns;
133 (u, idx as usize)
134 }
135
136 pub fn get_rotation(&self, timestamp_ns: u64) -> SO3<f64> {
137 let (u, i) = self.get_u_and_index(timestamp_ns);
138 let knots: Vec<_> = (i..i + N).map(|idx| self.knots[idx].to_dvec()).collect();
139 so3_from_u_and_knots(u, &knots, &self.blending_matrix())
140 }
141
142 pub fn get_velocity(&self, timestamp_ns: u64) -> na::Vector3<f64> {
143 let (u, idx) = self.get_u_and_index(timestamp_ns);
144
145 let coeff = &self.blending_matrix * Self::base_coeffs_with_time::<0>(u);
146 let d_tn_s = self.spacing_ns as f64 / 1e9;
147 let dcoeff = 1.0 / d_tn_s * &self.blending_matrix * Self::base_coeffs_with_time::<1>(u);
148 let mut w = na::Vector3::<f64>::zeros();
149 for j in 0..(N - 1) {
150 let p0 = self.knots[idx + j].to_so3();
151 let p1 = self.knots[idx + j + 1].to_so3();
152 let r01 = p0.inverse() * p1;
153 let delta = r01.log();
154 let ww = ((-1.0 * &delta) * coeff[j + 1]).to_so3();
155 w = &ww * w.as_view();
156 w += delta * dcoeff[j + 1];
157 }
158 w
159 }
160
161 pub fn from_rotation_vectors(
162 timestamps_ns: &[u64],
163 rvecs: &[[f64; 3]],
164 spacing_ns: u64,
165 ) -> Self {
166 if timestamps_ns.len() < N {
167 panic!("timestams_ns should be larger than {}", N);
168 } else if timestamps_ns.len() != rvecs.len() {
169 panic!("timestameps_ns should have the same length as rvecs");
170 }
171 let mut prev_t = timestamps_ns[0];
172 for &t in timestamps_ns {
173 if t - prev_t > spacing_ns {
174 panic!("spacing should be larger than all time stamp steps");
175 }
176 prev_t = t;
177 }
178 let num_of_knots =
179 ((timestamps_ns.last().unwrap() - timestamps_ns[0]) / spacing_ns) as usize + N;
180 let timestamp_start_ns = timestamps_ns[0];
181
182 let mut current_idx = 0;
183 let mut knots = Vec::new();
184 for i in 0..num_of_knots {
185 let knot_time_ns = timestamp_start_ns + i as u64 * spacing_ns;
186 while current_idx < timestamps_ns.len() - 1 {
187 if timestamps_ns[current_idx + 1] < knot_time_ns {
188 current_idx += 1;
189 } else {
190 break;
191 }
192 }
193 knots.push(rvecs[current_idx]);
194 }
195
196 let mut bspline = SO3Bspline {
197 timestamp_start_ns,
198 spacing_ns,
199 knots,
200 blending_matrix: compute_blending_matrix(N, true),
201 first_derivative_bases: compute_base_coefficients(N).row(1).transpose(),
202 };
203 bspline.fit(timestamps_ns, rvecs);
204 bspline
205 }
206 fn base_coeffs_with_time<const DERIVATIVE: usize>(u: f64) -> na::DVector<f64> {
207 let mut res = na::DVector::zeros(N);
208 let base_coefficients = compute_base_coefficients(N);
209 if DERIVATIVE < N {
210 res[DERIVATIVE] = base_coefficients[(DERIVATIVE, DERIVATIVE)];
211 let mut ti = u;
212 for j in (DERIVATIVE + 1)..N {
213 res[j] = base_coefficients[(DERIVATIVE, j)] * ti;
214 ti *= u;
215 }
216 }
217 res
218 }
219}