// nd_icp/icp.rs

// Created by Indraneel on 12/8/24
2
3use core::f32;
4
5use nalgebra::{Const, Dyn, OMatrix, U1};
6
7use crate::types::{Point, PointSet};
8
9/// N dimensional Icp
10///
11/// # Salient features
12///
13/// 1. Generic Type can be any n dimensional point
14/// 2. Implement Point Trait for your point
15/// 3. Uses SVD to find rotation and translation
16/// 3. Vectorised operations to control time complexity
17///
18/// # Examples
19///
20/// ```
21///
22/// let max_iterations = 100;
23/// let cost_change_threshold = 1e-5;
24/// let icp = Icp::new(
25///    model_point_set.clone(),
26///    max_iterations,
27///    cost_change_threshold,
28/// );
29/// let result = icp.register(&target_point_set);
30///
31///
32/// ```
33///
34/// # TODO:
35///
36/// 1. Outlier rejection of input data
37/// 2. Voxel binning to make finding correspondences faster
38pub struct Icp<T>
39where
40    T: Point + Copy,
41{
42    /// Model reference to register against
43    model_point_set: PointSet<T>,
44    /// Max number of iterations to run
45    max_iterations: i32,
46    /// Cost change threshold to terminate after
47    cost_change_threshold: f32,
48}
49
50impl<T> Icp<T>
51where
52    T: Point + Copy,
53{
54    pub fn new(
55        model_point_set: PointSet<T>,
56        max_iterations: i32,
57        cost_change_threshold: f32,
58    ) -> Self {
59        Self {
60            model_point_set,
61            max_iterations,
62            cost_change_threshold,
63        }
64    }
65
66    /// Finds closest point correspondences
67    /// between two sets of points
68    fn get_point_correspondences(
69        &self,
70        target_point_set: &OMatrix<f32, Dyn, Dyn>,
71    ) -> OMatrix<f32, Dyn, Dyn> {
72        let nrows = target_point_set.nrows();
73        let ncols = target_point_set.ncols();
74
75        let mut correspondence_matrix =
76            OMatrix::zeros_generic(nalgebra::Dyn(nrows), nalgebra::Dyn(ncols));
77
78        for (target_point_idx, target_point_mat) in target_point_set.row_iter().enumerate() {
79            let mut closest_point: Option<T> = None;
80            let mut closest_dist = f32::MAX;
81
82            let target_point: T = Point::from_matrix(&target_point_mat);
83
84            for model_point in self.model_point_set.points.iter() {
85                let point_dist = target_point.find_distance_squared(model_point);
86                if point_dist < closest_dist {
87                    closest_dist = point_dist;
88                    closest_point = Some(*model_point);
89                }
90            }
91
92            // Save point correspondence
93            correspondence_matrix
94                .row_mut(target_point_idx)
95                .copy_from_slice(&closest_point.expect("Closest point not found").to_vec());
96        }
97        correspondence_matrix
98    }
99
100    /// Converts a set of points to a Matrix with all points
101    /// stacked in rows
102    fn get_matrix_from_point_set(
103        &self,
104        point_set: &Vec<T>,
105        dimension: usize,
106    ) -> OMatrix<f32, Dyn, Dyn> {
107        let points_vec: Vec<Vec<f32>> = point_set.iter().map(|point| point.to_vec()).collect();
108        let points_vec_flattened: Vec<f32> = points_vec.into_iter().flatten().collect();
109        let target_mat: OMatrix<f32, Dyn, Dyn> = OMatrix::<f32, Dyn, Dyn>::from_row_slice(
110            point_set.len(),
111            dimension,
112            &points_vec_flattened,
113        );
114        target_mat
115    }
116
117    /// Converts a rotation and translation to its homogeneous matrix
118    /// representation
119    fn get_homogeneous_matrix(
120        &self,
121        translation: &OMatrix<f32, U1, Dyn>,
122        rotation: &OMatrix<f32, Dyn, Dyn>,
123        dimension: usize,
124    ) -> OMatrix<f32, Dyn, Dyn> {
125        // Start with an identity matrix
126        let mut homogeneous_matrix: OMatrix<f32, Dyn, Dyn> =
127            OMatrix::identity_generic(nalgebra::Dyn(dimension + 1), nalgebra::Dyn(dimension + 1));
128
129        // Assign the rotation part
130        homogeneous_matrix
131            .view_mut((0, 0), (dimension, dimension))
132            .copy_from(rotation);
133
134        // Assign the translation part
135        homogeneous_matrix
136            .view_mut((0, dimension), (dimension, 1))
137            .copy_from(&translation.transpose());
138
139        homogeneous_matrix
140    }
141
142    /// Calculates the mean squared error between two sets of points
143    fn icp_cost(
144        &self,
145        target_mat_no_mean: &OMatrix<f32, Dyn, Dyn>,
146        model_mat_no_mean: &OMatrix<f32, Dyn, Dyn>,
147        rotation: &OMatrix<f32, Dyn, Dyn>,
148    ) -> f32 {
149        // Rotate the target mat
150        let rotated_target_mat = (rotation * target_mat_no_mean.transpose()).transpose();
151
152        // calculate cost
153        let cost = model_mat_no_mean - rotated_target_mat;
154        cost.norm()
155    }
156
157    /// Subtracts the row-wise mean from a matrix and returns the resulting matrix and the mean.
158    fn center_point_cloud_about_mean(
159        &self,
160        matrix: &OMatrix<f32, Dyn, Dyn>,
161    ) -> (OMatrix<f32, Dyn, Dyn>, OMatrix<f32, Const<1>, Dyn>) {
162        let mean_row = matrix.row_mean();
163        let matrix_no_mean = OMatrix::from_rows(
164            &matrix
165                .row_iter()
166                .map(|row| row - mean_row.clone_owned())
167                .collect::<Vec<_>>(),
168        );
169        (matrix_no_mean, mean_row)
170    }
171
172    /// Applies a homogenous transformation to a matrix of points
173    fn transform_matrix(
174        &self,
175        matrix: &mut OMatrix<f32, Dyn, Dyn>,
176        homogeneous_transformation_matrix: &OMatrix<f32, Dyn, Dyn>,
177    ) {
178        let nrows = matrix.nrows();
179        let ncols = matrix.ncols();
180
181        let mut homogeneous_representation = matrix.clone_owned();
182        homogeneous_representation = homogeneous_representation.insert_column(ncols, 1.0);
183
184        // Apply the homogeneous transformation
185        let transformed_homogeneous_matrix =
186            homogeneous_transformation_matrix * homogeneous_representation.transpose();
187
188        *matrix = transformed_homogeneous_matrix
189            .transpose()
190            .view((0, 0), (nrows, ncols))
191            .into_owned();
192    }
193
194    /// ICP registration
195    ///
196    /// 1. Initialises the transformation given number of dimensions
197    /// 2. For each point
198    ///     - Find closest point in reference cloud
199    /// 3. Remove the means
200    /// 4. Find transformaton and rotation which will minimise the error
201    /// 5. Transform the target cloud
202    /// 6. Iterate until within error threshold or max iterations
203    pub fn register(&self, target_point_set: &PointSet<T>) -> OMatrix<f32, Dyn, Dyn> {
204        let dimension = target_point_set
205            .points
206            .iter()
207            .next()
208            .expect("Input set is empty")
209            .get_dimensions();
210
211        // Vectorise point sets
212        let mut target_mat = self.get_matrix_from_point_set(&target_point_set.points, dimension);
213
214        // Initialise transformation
215        let mut registration_matrix: OMatrix<f32, Dyn, Dyn> =
216            OMatrix::identity_generic(nalgebra::Dyn(dimension + 1), nalgebra::Dyn(dimension + 1));
217
218        // Begin iterations
219        let mut previous_cost = f32::MAX;
220        for iteration in 0..self.max_iterations {
221            // Find point correspondences
222            let correspondence_mat = self.get_point_correspondences(&target_mat);
223
224            // Remove the means from the point clouds for better rotation matrix calculation
225            let (correspondence_mat_no_mean, mean_correspondence_point) =
226                self.center_point_cloud_about_mean(&correspondence_mat);
227            let (target_mat_no_mean, mean_target_point) =
228                self.center_point_cloud_about_mean(&target_mat);
229
230            // Calculate cross covariance
231            let cross_covariance_mat =
232                correspondence_mat_no_mean.transpose() * target_mat_no_mean.clone();
233
234            // Find best rotation
235            let res = nalgebra::linalg::SVD::new(cross_covariance_mat, true, true);
236            let u = res.u.expect("Failed to calculate u matrix");
237            let vt = res.v_t.expect("Failed to calculate vt matrix");
238            let rotation = u * vt;
239
240            // Find translation
241            let translation = mean_correspondence_point
242                - (rotation.clone() * mean_target_point.transpose()).transpose();
243
244            let homogenous_mat = self.get_homogeneous_matrix(&translation, &rotation, dimension);
245            println!(
246                " r {} test {} homo {}",
247                rotation, translation, homogenous_mat
248            );
249
250            // Transform target cloud
251            self.transform_matrix(&mut target_mat, &homogenous_mat);
252
253            // Update registration matrix
254            registration_matrix *= homogenous_mat;
255
256            // Calculate cost
257            let icp_cost =
258                self.icp_cost(&target_mat_no_mean, &correspondence_mat_no_mean, &rotation);
259            println!(
260                "=== Finished iteration {} with cost {}",
261                iteration, icp_cost
262            );
263
264            // Check termination condition
265            if (previous_cost - icp_cost).abs() < self.cost_change_threshold {
266                println!(
267                    "Reached termination threshold of {} with {} exiting!",
268                    self.cost_change_threshold, previous_cost
269                );
270                break;
271            }
272            previous_cost = icp_cost;
273        }
274
275        registration_matrix
276    }
277}
278
#[cfg(test)]
mod tests {
    use crate::types::Point3D;

    use super::*;

    use rstest::*;

    /// ICP over three collinear model points, limited to a single iteration.
    #[fixture]
    fn icp_fixture() -> Icp<Point3D> {
        let max_iterations = 1;
        let cost_change_threshold = 1e-3;

        let model_point_set = PointSet {
            points: vec![
                Point3D {
                    x: 1.0,
                    y: 1.0,
                    z: 1.0,
                },
                Point3D {
                    x: 2.0,
                    y: 2.0,
                    z: 2.0,
                },
                Point3D {
                    x: 3.0,
                    y: 3.0,
                    z: 3.0,
                },
            ],
        };

        Icp::new(model_point_set, max_iterations, cost_change_threshold)
    }

    #[rstest]
    fn test_get_matrix_from_point_set(icp_fixture: Icp<Point3D>) {
        let expected_matrix = OMatrix::<f32, Dyn, Dyn>::from_row_slice(
            3,
            3,
            &[1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0],
        );

        let result_matrix =
            icp_fixture.get_matrix_from_point_set(&icp_fixture.model_point_set.points, 3);

        assert_eq!(result_matrix, expected_matrix);
        assert_eq!(result_matrix.row(0), expected_matrix.row(0));
        assert_eq!(result_matrix.row(1), expected_matrix.row(1));
        assert_eq!(result_matrix.row(2), expected_matrix.row(2));
    }

    #[rstest]
    fn test_get_point_correspondences(icp_fixture: Icp<Point3D>) {
        let target_matrix = OMatrix::<f32, Dyn, Dyn>::from_row_slice(
            3,
            3,
            &[1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0],
        );

        let correspondence_mat = icp_fixture.get_point_correspondences(&target_matrix);

        assert_eq!(correspondence_mat, target_matrix);

        // Shuffled order: every row must still find its identical model point
        let target_shuffled = OMatrix::<f32, Dyn, Dyn>::from_row_slice(
            3,
            3,
            &[3.0, 3.0, 3.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0],
        );

        let correspondence_mat = icp_fixture.get_point_correspondences(&target_shuffled);

        assert_eq!(correspondence_mat, target_shuffled);
    }

    #[rstest]
    fn test_center_point_cloud_about_mean(icp_fixture: Icp<Point3D>) {
        let model_matrix =
            icp_fixture.get_matrix_from_point_set(&icp_fixture.model_point_set.points, 3);

        let expected_mean = OMatrix::<f32, Dyn, Dyn>::from_row_slice(1, 3, &[2.0, 2.0, 2.0]);

        let expected_matrix = OMatrix::<f32, Dyn, Dyn>::from_row_slice(
            3,
            3,
            &[
                -1.0, -1.0, -1.0, // First row - mean
                0.0, 0.0, 0.0, // Second row - mean
                1.0, 1.0, 1.0, // Third row - mean
            ],
        );

        let (result_matrix, result_mean) = icp_fixture.center_point_cloud_about_mean(&model_matrix);

        assert_eq!(result_matrix, expected_matrix);
        assert_eq!(result_mean, expected_mean);

        let another_matrix = OMatrix::<f32, Dyn, Dyn>::from_row_slice(
            3,
            3,
            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
        );

        let expected_mean = OMatrix::<f32, Dyn, Dyn>::from_row_slice(1, 3, &[4.0, 5.0, 6.0]);

        let expected_matrix = OMatrix::<f32, Dyn, Dyn>::from_row_slice(
            3,
            3,
            &[-3.0, -3.0, -3.0, 0.0, 0.0, 0.0, 3.0, 3.0, 3.0],
        );

        let (result_matrix, result_mean) =
            icp_fixture.center_point_cloud_about_mean(&another_matrix);

        assert_eq!(result_matrix, expected_matrix);
        assert_eq!(result_mean, expected_mean);
    }

    #[rstest]
    fn test_get_homogeneous_matrix(icp_fixture: Icp<Point3D>) {
        // Identity rotation keeps the expected matrix easy to read
        let rotation = OMatrix::<f32, Dyn, Dyn>::identity_generic(Dyn(3), Dyn(3));

        let translation = OMatrix::<f32, U1, Dyn>::from_row_slice(&[1.0, 2.0, 3.0]);

        let expected_matrix = OMatrix::<f32, Dyn, Dyn>::from_row_slice(
            4,
            4,
            &[
                1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 2.0, 0.0, 0.0, 1.0, 3.0, 0.0, 0.0, 0.0, 1.0,
            ],
        );

        let result_matrix = icp_fixture.get_homogeneous_matrix(&translation, &rotation, 3);

        assert_eq!(result_matrix, expected_matrix);
    }

    #[rstest]
    fn test_transform_matrix(icp_fixture: Icp<Point3D>) {
        let mut matrix = OMatrix::<f32, Dyn, Dyn>::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let transformation = OMatrix::<f32, Dyn, Dyn>::from_row_slice(
            3,
            3,
            &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
        );

        let expected = OMatrix::<f32, Dyn, Dyn>::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);

        icp_fixture.transform_matrix(&mut matrix, &transformation);
        assert_eq!(matrix, expected);

        // Pure translation by (1, 2)
        let transformation = OMatrix::<f32, Dyn, Dyn>::from_row_slice(
            3,
            3,
            &[1.0, 0.0, 1.0, 0.0, 1.0, 2.0, 0.0, 0.0, 1.0],
        );

        let expected = OMatrix::<f32, Dyn, Dyn>::from_row_slice(2, 2, &[2.0, 4.0, 4.0, 6.0]);

        icp_fixture.transform_matrix(&mut matrix, &transformation);
        assert_eq!(matrix, expected);
    }

    #[rstest]
    fn test_icp_registration(icp_fixture: Icp<Point3D>) {
        // Target lies on the same line as the model (with a duplicated point),
        // so the registration must not move points along that line.
        let target_point_set = PointSet {
            points: vec![
                Point3D {
                    x: 1.0,
                    y: 1.0,
                    z: 1.0,
                },
                Point3D {
                    x: 2.0,
                    y: 2.0,
                    z: 2.0,
                },
                Point3D {
                    x: 3.0,
                    y: 3.0,
                    z: 3.0,
                },
                Point3D {
                    x: 3.0,
                    y: 3.0,
                    z: 3.0,
                },
            ],
        };

        let result = icp_fixture.register(&target_point_set);

        // Translation column is (approximately) zero
        assert!(result[(0, 3)].abs() < 1e-3);
        assert!(result[(1, 3)].abs() < 1e-3);
        assert!(result[(2, 3)].abs() < 1e-3);

        // Bottom row of a homogeneous transform is [0, 0, 0, 1]
        for col in 0..3 {
            assert!(result[(3, col)].abs() < 1e-6);
        }
        assert!((result[(3, 3)] - 1.0).abs() < 1e-6);

        // A point on the shared line is mapped onto itself
        let on_line = OMatrix::<f32, Dyn, Dyn>::from_row_slice(4, 1, &[2.0, 2.0, 2.0, 1.0]);
        let mapped = &result * &on_line;
        for row in 0..4 {
            assert!((mapped[(row, 0)] - on_line[(row, 0)]).abs() < 1e-3);
        }
    }
}