1use nalgebra::{
9 allocator::Allocator, DefaultAllocator, Dyn, Matrix, Matrix3, Matrix3x1, OMatrix, RealField,
10 VecStorage, U1, U3,
11};
12use num_traits::float::TotalOrder;
13
14use crate::{MvgError, Result};
15
/// Method used by [`align_points`] to estimate the similarity transform
/// between two corresponding 3D point sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Algorithm {
    /// Least-squares alignment from the SVD of the cross-covariance of the
    /// centered point sets.
    ///
    /// A sign flip on the last singular direction guarantees that the
    /// returned matrix is a proper rotation (determinant +1). The scale is
    /// derived from the singular values and the spread of the source points.
    KabschUmeyama,
    /// SVD alignment (`R = V Uᵀ`) combined with a robust scale estimate.
    ///
    /// The scale is the median ratio of distances between consecutive
    /// points. No reflection correction is applied, so the returned matrix
    /// may have determinant −1. Requires at least two points.
    RobustArun,
}
35
36pub fn align_points<T>(
51 x: &OMatrix<T, U3, Dyn>,
52 y: &OMatrix<T, U3, Dyn>,
53 algorithm: Algorithm,
54) -> Result<(T, Matrix3<T>, Matrix3x1<T>)>
55where
56 T: RealField + Copy + TotalOrder,
57{
58 let n = x.ncols();
59
60 if n != y.ncols() {
61 return Err(MvgError::InvalidShape);
62 }
63 if n < 1 {
64 return Err(MvgError::InvalidShape);
65 }
66
67 let mu_x = x.column_mean();
69 let mu_y = y.column_mean();
70
71 let x_center = x - bcast(&mu_x, n);
73 let y_center = y - bcast(&mu_y, n);
74
75 let (robust_scale, cov_xy) = match algorithm {
77 Algorithm::RobustArun => {
78 let dx = x.columns(1, n - 1) - x.columns(0, n - 1);
79 let dy = y.columns(1, n - 1) - y.columns(0, n - 1);
80 let dx = sqrt(&square(&dx).row_sum());
81 let dy = sqrt(&square(&dy).row_sum());
82 let scales = dy.component_div(&dx);
83
84 let scale = median(&scales).unwrap();
85
86 let x_centered_scaled = &x_center * scale;
87
88 let cov_xy = &x_centered_scaled * y_center.transpose();
89 (Some(scale), cov_xy)
90 }
91 Algorithm::KabschUmeyama => {
92 let cov_xy = (y_center * x_center.transpose()) / nalgebra::convert::<_, T>(n as f64);
93 (None, cov_xy)
94 }
95 };
96
97 const SVD_MAX_ITERATIONS: usize = 1_000_000;
99
100 let svd = if let Some(svd) = nalgebra::linalg::SVD::try_new(
101 cov_xy,
102 true,
103 true,
104 nalgebra::convert(1e-7),
105 SVD_MAX_ITERATIONS,
106 ) {
107 svd
108 } else {
109 return Err(MvgError::SvdFailed);
110 };
111 let u = svd.u.unwrap();
112 let d = svd.singular_values;
113 let vh = svd.v_t.unwrap();
114
115 let (c, r) = if let Some(scale) = robust_scale {
117 let v = vh.transpose();
118 let ut = u.transpose();
119 (scale, v * ut)
120 } else {
121 let mut s = nalgebra::Matrix3::<T>::identity();
122
123 if u.determinant() * vh.determinant() < nalgebra::convert(0.0) {
125 s[(2, 2)] = nalgebra::convert(-1.0);
126 }
127
128 let var_x = square(&x_center).row_sum().mean();
130 let c = (nalgebra::Matrix3::from_diagonal(&d) * s).trace() / var_x;
131 (c, u * s * vh)
132 };
133
134 let t = mu_y - (r * mu_x) * c;
136
137 Ok((c, r, t))
138}
139
140fn bcast<T, R>(m: &OMatrix<T, R, U1>, n: usize) -> OMatrix<T, R, Dyn>
141where
142 T: RealField + Copy,
143 R: nalgebra::DimName,
144 DefaultAllocator: Allocator<R>,
145{
146 let mut result = OMatrix::<T, R, Dyn>::zeros(n);
148 for i in 0..R::dim() {
149 for j in 0..n {
150 result[(i, j)] = m[(i, 0)];
151 }
152 }
153 result
154}
155
156fn sqrt<T, R, C>(m: &OMatrix<T, R, C>) -> OMatrix<T, R, C>
157where
158 T: RealField + Copy,
159 R: nalgebra::Dim,
160 C: nalgebra::Dim,
161 DefaultAllocator: Allocator<R, C>,
162{
163 let mut result = m.clone();
164 sqrt_in_place(&mut result);
165 result
166}
167
168fn sqrt_in_place<T, R, C>(m: &mut OMatrix<T, R, C>)
169where
170 T: RealField + Copy,
171 R: nalgebra::Dim,
172 C: nalgebra::Dim,
173 DefaultAllocator: Allocator<R, C>,
174{
175 for el in m.iter_mut() {
176 let val: T = *el;
177 *el = val.sqrt();
178 }
179}
180
181fn square<T, R, C>(m: &OMatrix<T, R, C>) -> OMatrix<T, R, C>
182where
183 T: RealField + Copy,
184 R: nalgebra::Dim,
185 C: nalgebra::Dim,
186 DefaultAllocator: Allocator<R, C>,
187{
188 m.component_mul(m)
189}
190
191fn median<T, C>(scales: &Matrix<T, U1, C, VecStorage<T, U1, C>>) -> Option<T>
192where
193 T: RealField + Copy + TotalOrder,
194 C: nalgebra::Dim,
195 DefaultAllocator: Allocator<U1, C>,
196{
197 let mut scales = scales.data.as_slice().to_vec(); scales.as_mut_slice().sort_by(|a, b| a.total_cmp(b));
200
201 let n = scales.len();
202 if n == 0 {
203 None
204 } else if n == 1 {
205 Some(scales[0])
206 } else if n % 2 == 0 {
207 let s1 = scales[n / 2 - 1];
208 let s2 = scales[n / 2];
209 Some((s1 + s2) * nalgebra::convert(0.5))
210 } else {
211 Some(scales[n / 2])
213 }
214}
215
/// Checks `median` on odd and even lengths, on unsorted input (so the sort is
/// actually exercised), and on the single-element and empty edge cases.
#[test]
fn test_median() {
    // Already-sorted, odd length.
    let mut a = OMatrix::<f64, U1, Dyn>::zeros(3);
    a[(0, 0)] = 1.0;
    a[(0, 1)] = 2.0;
    a[(0, 2)] = 3.0;
    assert_eq!(median(&a), Some(2.0));

    // Already-sorted, even length: mean of the two middle values.
    let mut a = OMatrix::<f64, U1, Dyn>::zeros(2);
    a[(0, 0)] = 1.0;
    a[(0, 1)] = 2.0;
    assert_eq!(median(&a), Some(1.5));

    // Unsorted inputs must give the same answer as their sorted versions.
    let a = OMatrix::<f64, U1, Dyn>::from_row_slice(&[3.0, 1.0, 2.0]);
    assert_eq!(median(&a), Some(2.0));
    let a = OMatrix::<f64, U1, Dyn>::from_row_slice(&[4.0, 1.0, 3.0, 2.0]);
    assert_eq!(median(&a), Some(2.5));

    // Single element.
    let a = OMatrix::<f64, U1, Dyn>::from_row_slice(&[7.0]);
    assert_eq!(median(&a), Some(7.0));

    // Empty input has no median.
    let a = OMatrix::<f64, U1, Dyn>::zeros(0);
    assert_eq!(median(&a), None);
}
229
/// `square` must square each entry independently.
#[test]
fn test_square() {
    let input = nalgebra::Matrix2::new(0., 1., 2., 3.);
    let expected = nalgebra::Matrix2::new(0., 1., 4., 9.);
    assert_eq!(square(&input), expected);
}
236
/// Checks `align_points` on two datasets:
///
/// 1. A noiseless synthetic transform that both algorithms must recover.
/// 2. A noisy real-world correspondence, checked against reference values
///    for `RobustArun`.
#[test]
fn test_align_points() {
    use nalgebra::{Matrix3, Vector3};

    // Eight source points, one per column.
    #[rustfmt::skip]
    let x1 = nalgebra::base::Matrix3xX::from_column_slice(&[
        3.36748406, 1.61036404, 3.55147255,
        3.58702265, 0.06676394, 3.64695356,
        0.28452026, -0.11188296, 3.78947735,
        0.25482713, 1.57828256, 3.6900808,
        3.54938525, 1.74057692, 5.13329681,
        3.6855626, 0.10335229, 5.26344841,
        0.25025385, -0.06146044, 5.57085135,
        0.20742481, 1.71073272, 5.41823085]);

    // Corresponding target points with measurement noise.
    #[rustfmt::skip]
    let x2_noisy = nalgebra::base::Matrix3xX::from_column_slice(&[
        3.048, 1.524, 1.524,
        3.048, 0.0, 1.524,
        0.0, 0.0, 1.524,
        0.0, 1.524, 1.524,
        3.048, 1.524, 0.0,
        3.048, 0.0, 0.0,
        0.0, 0.0, 0.0,
        0.0, 1.524, 0.0]);

    // Ground-truth similarity transform, applied to x1 to build a noiseless target.
    let c_true = 0.1;
    let r_true = nalgebra::geometry::Rotation3::from_euler_angles(
        std::f64::consts::FRAC_PI_4,
        0.0,
        0.0,
    )
    .into_inner();
    let t_true = Vector3::new(-0.2, 0.3, -0.4);
    let x2 = c_true * r_true * &x1 + bcast(&t_true, 8);

    // Both algorithms must recover the exact transform from noiseless data.
    for algorithm in [Algorithm::KabschUmeyama, Algorithm::RobustArun] {
        let (c, r, t) = align_points(&x1, &x2, algorithm).unwrap();

        approx::assert_abs_diff_eq!(c, c_true);
        approx::assert_abs_diff_eq!(r, r_true, epsilon = 1e-10);
        approx::assert_abs_diff_eq!(t, t_true, epsilon = 1e-10);
    }

    // Noisy data, RobustArun only. The reference matrix below has determinant
    // ≈ -1: the z ordering of the points is mirrored between x1 and x2_noisy,
    // and RobustArun applies no reflection correction.
    let (c, r, t) = align_points(&x1, &x2_noisy, Algorithm::RobustArun).unwrap();

    let c_noisy = 0.920734586302497;
    #[rustfmt::skip]
    let r_noisy = Matrix3::new(
        0.997554805278945, 0.03689676080610408, -0.05935519780863721,
        -0.04056669686950421, 0.9972599534887207, -0.06186217158144404,
        -0.05691004805816868, -0.06411875084319214, -0.9963182384260189,
    );
    let t_noisy = Vector3::new(
        -0.0013862645696010034,
        0.3279319869522358,
        5.0458138154244985,
    );

    approx::assert_abs_diff_eq!(c, c_noisy);
    approx::assert_abs_diff_eq!(r, r_noisy, epsilon = 1e-10);
    approx::assert_abs_diff_eq!(t, t_noisy, epsilon = 1e-10);
}