braid_mvg/
align_points.rs

// Copyright 2016-2025 Andrew D. Straw.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT
// or http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
7
8use nalgebra::{
9    allocator::Allocator, DefaultAllocator, Dyn, Matrix, Matrix3, Matrix3x1, OMatrix, RealField,
10    VecStorage, U1, U3,
11};
12use num_traits::float::TotalOrder;
13
14use crate::{MvgError, Result};
15
/// Algorithm selection for point cloud alignment.
///
/// This enum specifies which algorithm [`align_points`] uses when aligning two
/// sets of 3D points. Different algorithms have different robustness
/// characteristics and computational requirements.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Algorithm {
    /// The Kabsch-Umeyama algorithm for point set alignment.
    ///
    /// This is a classic algorithm that finds the optimal similarity transformation
    /// (scale, rotation, translation) between two point sets. It's mathematically
    /// elegant but can be sensitive to outliers.
    KabschUmeyama,
    /// A robustly-scaled variant of the Arun, Huang, and Blostein algorithm.
    ///
    /// This algorithm provides more robust scaling estimation compared to
    /// Kabsch-Umeyama, making it more suitable for real-world data that may
    /// contain noise or outliers.
    RobustArun,
}
35
36/// Find the linear transformation that converts 3D points `x` as close as
37/// possible to points `y`.
38///
39/// The best (scale, rotation, translation) are returned.
40///
41/// The Kabsch-Umeyama implementation is based on that in
42/// <https://github.com/clementinboittiaux/umeyama-python/blob/main/umeyama.py>.
43///
44/// The robust Arun implementation is based on that in
45/// <https://github.com/strawlab/MultiCamSelfCal/blob/main/MultiCamSelfCal/CoreFunctions/estsimt.m>.
46/// That code claims to be an implementation of the Arun, Huang, and Blostein
47/// algorithm, but contains an extra bit to determine scaling which works
48/// differently, and in my experience is more robust than, the Kabsch-Umeyama
49/// algorithm.
50pub fn align_points<T>(
51    x: &OMatrix<T, U3, Dyn>,
52    y: &OMatrix<T, U3, Dyn>,
53    algorithm: Algorithm,
54) -> Result<(T, Matrix3<T>, Matrix3x1<T>)>
55where
56    T: RealField + Copy + TotalOrder,
57{
58    let n = x.ncols();
59
60    if n != y.ncols() {
61        return Err(MvgError::InvalidShape);
62    }
63    if n < 1 {
64        return Err(MvgError::InvalidShape);
65    }
66
67    // Find centroids.
68    let mu_x = x.column_mean();
69    let mu_y = y.column_mean();
70
71    // Move points to center.
72    let x_center = x - bcast(&mu_x, n);
73    let y_center = y - bcast(&mu_y, n);
74
75    // Covariance of X,Y
76    let (robust_scale, cov_xy) = match algorithm {
77        Algorithm::RobustArun => {
78            let dx = x.columns(1, n - 1) - x.columns(0, n - 1);
79            let dy = y.columns(1, n - 1) - y.columns(0, n - 1);
80            let dx = sqrt(&square(&dx).row_sum());
81            let dy = sqrt(&square(&dy).row_sum());
82            let scales = dy.component_div(&dx);
83
84            let scale = median(&scales).unwrap();
85
86            let x_centered_scaled = &x_center * scale;
87
88            let cov_xy = &x_centered_scaled * y_center.transpose();
89            (Some(scale), cov_xy)
90        }
91        Algorithm::KabschUmeyama => {
92            let cov_xy = (y_center * x_center.transpose()) / nalgebra::convert::<_, T>(n as f64);
93            (None, cov_xy)
94        }
95    };
96
97    // Decomposition of covariance matrix.
98    const SVD_MAX_ITERATIONS: usize = 1_000_000;
99
100    let svd = if let Some(svd) = nalgebra::linalg::SVD::try_new(
101        cov_xy,
102        true,
103        true,
104        nalgebra::convert(1e-7),
105        SVD_MAX_ITERATIONS,
106    ) {
107        svd
108    } else {
109        return Err(MvgError::SvdFailed);
110    };
111    let u = svd.u.unwrap();
112    let d = svd.singular_values;
113    let vh = svd.v_t.unwrap();
114
115    // Generate rotation matrix
116    let (c, r) = if let Some(scale) = robust_scale {
117        let v = vh.transpose();
118        let ut = u.transpose();
119        (scale, v * ut)
120    } else {
121        let mut s = nalgebra::Matrix3::<T>::identity();
122
123        // Are the points reflected?
124        if u.determinant() * vh.determinant() < nalgebra::convert(0.0) {
125            s[(2, 2)] = nalgebra::convert(-1.0);
126        }
127
128        // Variance of X
129        let var_x = square(&x_center).row_sum().mean();
130        let c = (nalgebra::Matrix3::from_diagonal(&d) * s).trace() / var_x;
131        (c, u * s * vh)
132    };
133
134    // Translation
135    let t = mu_y - (r * mu_x) * c;
136
137    Ok((c, r, t))
138}
139
140fn bcast<T, R>(m: &OMatrix<T, R, U1>, n: usize) -> OMatrix<T, R, Dyn>
141where
142    T: RealField + Copy,
143    R: nalgebra::DimName,
144    DefaultAllocator: Allocator<R>,
145{
146    // this is far from efficient
147    let mut result = OMatrix::<T, R, Dyn>::zeros(n);
148    for i in 0..R::dim() {
149        for j in 0..n {
150            result[(i, j)] = m[(i, 0)];
151        }
152    }
153    result
154}
155
156fn sqrt<T, R, C>(m: &OMatrix<T, R, C>) -> OMatrix<T, R, C>
157where
158    T: RealField + Copy,
159    R: nalgebra::Dim,
160    C: nalgebra::Dim,
161    DefaultAllocator: Allocator<R, C>,
162{
163    let mut result = m.clone();
164    sqrt_in_place(&mut result);
165    result
166}
167
168fn sqrt_in_place<T, R, C>(m: &mut OMatrix<T, R, C>)
169where
170    T: RealField + Copy,
171    R: nalgebra::Dim,
172    C: nalgebra::Dim,
173    DefaultAllocator: Allocator<R, C>,
174{
175    for el in m.iter_mut() {
176        let val: T = *el;
177        *el = val.sqrt();
178    }
179}
180
181fn square<T, R, C>(m: &OMatrix<T, R, C>) -> OMatrix<T, R, C>
182where
183    T: RealField + Copy,
184    R: nalgebra::Dim,
185    C: nalgebra::Dim,
186    DefaultAllocator: Allocator<R, C>,
187{
188    m.component_mul(m)
189}
190
191fn median<T, C>(scales: &Matrix<T, U1, C, VecStorage<T, U1, C>>) -> Option<T>
192where
193    T: RealField + Copy + TotalOrder,
194    C: nalgebra::Dim,
195    DefaultAllocator: Allocator<U1, C>,
196{
197    let mut scales = scales.data.as_slice().to_vec(); // clone data to vec
198
199    scales.as_mut_slice().sort_by(|a, b| a.total_cmp(b));
200
201    let n = scales.len();
202    if n == 0 {
203        None
204    } else if n == 1 {
205        Some(scales[0])
206    } else if n % 2 == 0 {
207        let s1 = scales[n / 2 - 1];
208        let s2 = scales[n / 2];
209        Some((s1 + s2) * nalgebra::convert(0.5))
210    } else {
211        // odd
212        Some(scales[n / 2])
213    }
214}
215
#[test]
fn test_median() {
    // Odd length: the middle element.
    let mut a = OMatrix::<f64, U1, Dyn>::zeros(3);
    a[(0, 0)] = 1.0;
    a[(0, 1)] = 2.0;
    a[(0, 2)] = 3.0;
    assert_eq!(median(&a), Some(2.0));

    // Even length: the mean of the two middle elements.
    let mut a = OMatrix::<f64, U1, Dyn>::zeros(2);
    a[(0, 0)] = 1.0;
    a[(0, 1)] = 2.0;
    assert_eq!(median(&a), Some(1.5));

    // Unsorted input: the values must be sorted before the middle is taken.
    let a = OMatrix::<f64, U1, Dyn>::from_row_slice(&[5.0, 1.0, 4.0, 2.0, 3.0]);
    assert_eq!(median(&a), Some(3.0));
    let a = OMatrix::<f64, U1, Dyn>::from_row_slice(&[4.0, 1.0, 3.0, 2.0]);
    assert_eq!(median(&a), Some(2.5));

    // Single element and empty input.
    let a = OMatrix::<f64, U1, Dyn>::from_row_slice(&[7.0]);
    assert_eq!(median(&a), Some(7.0));
    let a = OMatrix::<f64, U1, Dyn>::zeros(0);
    assert_eq!(median(&a), None);
}
229
#[test]
fn test_square() {
    let input = nalgebra::Matrix2::new(0.0, 1.0, 2.0, 3.0);
    let expected = nalgebra::Matrix2::new(0.0, 1.0, 4.0, 9.0);
    assert_eq!(square(&input), expected);
}
236
#[test]
fn test_align_points() {
    use nalgebra::{Matrix3, Vector3};

    // Each line of three values below is one point (one column), because the
    // matrices are built with `from_column_slice()`.
    #[rustfmt::skip]
    let x1 = nalgebra::base::Matrix3xX::from_column_slice(&[
        3.36748406,1.61036404,3.55147255,
        3.58702265,0.06676394,3.64695356,
        0.28452026,-0.11188296,3.78947735,
        0.25482713,1.57828256,3.6900808,
        3.54938525,1.74057692,5.13329681,
        3.6855626,0.10335229,5.26344841,
        0.25025385,-0.06146044,5.57085135,
        0.20742481,1.71073272,5.41823085]);

    #[rustfmt::skip]
    let x2_noisy = nalgebra::base::Matrix3xX::from_column_slice(&[
        3.048,1.524,1.524,
        3.048,0.0,1.524,
        0.0,0.0,1.524,
        0.0,1.524,1.524,
        3.048,1.524,0.0,
        3.048,0.0,0.0,
        0.0,0.0,0.0,
        0.0,1.524,0.0]);

    for algorithm in [Algorithm::KabschUmeyama, Algorithm::RobustArun] {
        // Test in noise-free conditions with generated data.
        let c_expected = 0.1;
        let r_expected = *nalgebra::geometry::Rotation3::from_euler_angles(
            std::f64::consts::FRAC_PI_4,
            0.0,
            0.0,
        )
        .matrix();
        let t_expected = Vector3::new(-0.2, 0.3, -0.4);

        let x2 = c_expected * r_expected * &x1 + bcast(&t_expected, 8);

        let (c, r, t) = align_points(&x1, &x2, algorithm).unwrap();

        approx::assert_abs_diff_eq!(c, c_expected);
        approx::assert_abs_diff_eq!(r, r_expected, epsilon = 1e-10);
        approx::assert_abs_diff_eq!(t, t_expected, epsilon = 1e-10);
    }

    // Test on some real data which seems problematic for Kabsch-Umeyama, using
    // the robust scale option.
    let (c, r, t) = align_points(&x1, &x2_noisy, Algorithm::RobustArun).unwrap();

    // These values were generated by running on this data using `estsimt()` in
    // `align.py` from flydra. Note that the expected `r` has determinant of
    // approximately -1 (a reflection): the robust Arun variant, like the
    // reference, does not correct for reflections.
    let c_expected = 0.920734586302497;
    #[rustfmt::skip]
    let r_expected = Matrix3::new(
        0.997554805278945, 0.03689676080610408, -0.05935519780863721,
        -0.04056669686950421, 0.9972599534887207, -0.06186217158144404,
        -0.05691004805816868, -0.06411875084319214, -0.9963182384260189,
    );
    let t_expected = Vector3::new(
        -0.0013862645696010034,
        0.3279319869522358,
        5.0458138154244985,
    );

    approx::assert_abs_diff_eq!(c, c_expected);
    approx::assert_abs_diff_eq!(r, r_expected, epsilon = 1e-10);
    approx::assert_abs_diff_eq!(t, t_expected, epsilon = 1e-10);
}
320
321    // println!("pp\n{}", &pp);
322}