1use core::f32;
4
5use nalgebra::{Const, Dyn, OMatrix, U1};
6
7use crate::types::{Point, PointSet};
8
9pub struct Icp<T>
39where
40 T: Point + Copy,
41{
42 model_point_set: PointSet<T>,
44 max_iterations: i32,
46 cost_change_threshold: f32,
48}
49
50impl<T> Icp<T>
51where
52 T: Point + Copy,
53{
54 pub fn new(
55 model_point_set: PointSet<T>,
56 max_iterations: i32,
57 cost_change_threshold: f32,
58 ) -> Self {
59 Self {
60 model_point_set,
61 max_iterations,
62 cost_change_threshold,
63 }
64 }
65
66 fn get_point_correspondences(
69 &self,
70 target_point_set: &OMatrix<f32, Dyn, Dyn>,
71 ) -> OMatrix<f32, Dyn, Dyn> {
72 let nrows = target_point_set.nrows();
73 let ncols = target_point_set.ncols();
74
75 let mut correspondence_matrix =
76 OMatrix::zeros_generic(nalgebra::Dyn(nrows), nalgebra::Dyn(ncols));
77
78 for (target_point_idx, target_point_mat) in target_point_set.row_iter().enumerate() {
79 let mut closest_point: Option<T> = None;
80 let mut closest_dist = f32::MAX;
81
82 let target_point: T = Point::from_matrix(&target_point_mat);
83
84 for model_point in self.model_point_set.points.iter() {
85 let point_dist = target_point.find_distance_squared(model_point);
86 if point_dist < closest_dist {
87 closest_dist = point_dist;
88 closest_point = Some(*model_point);
89 }
90 }
91
92 correspondence_matrix
94 .row_mut(target_point_idx)
95 .copy_from_slice(&closest_point.expect("Closest point not found").to_vec());
96 }
97 correspondence_matrix
98 }
99
100 fn get_matrix_from_point_set(
103 &self,
104 point_set: &Vec<T>,
105 dimension: usize,
106 ) -> OMatrix<f32, Dyn, Dyn> {
107 let points_vec: Vec<Vec<f32>> = point_set.iter().map(|point| point.to_vec()).collect();
108 let points_vec_flattened: Vec<f32> = points_vec.into_iter().flatten().collect();
109 let target_mat: OMatrix<f32, Dyn, Dyn> = OMatrix::<f32, Dyn, Dyn>::from_row_slice(
110 point_set.len(),
111 dimension,
112 &points_vec_flattened,
113 );
114 target_mat
115 }
116
117 fn get_homogeneous_matrix(
120 &self,
121 translation: &OMatrix<f32, U1, Dyn>,
122 rotation: &OMatrix<f32, Dyn, Dyn>,
123 dimension: usize,
124 ) -> OMatrix<f32, Dyn, Dyn> {
125 let mut homogeneous_matrix: OMatrix<f32, Dyn, Dyn> =
127 OMatrix::identity_generic(nalgebra::Dyn(dimension + 1), nalgebra::Dyn(dimension + 1));
128
129 homogeneous_matrix
131 .view_mut((0, 0), (dimension, dimension))
132 .copy_from(rotation);
133
134 homogeneous_matrix
136 .view_mut((0, dimension), (dimension, 1))
137 .copy_from(&translation.transpose());
138
139 homogeneous_matrix
140 }
141
142 fn icp_cost(
144 &self,
145 target_mat_no_mean: &OMatrix<f32, Dyn, Dyn>,
146 model_mat_no_mean: &OMatrix<f32, Dyn, Dyn>,
147 rotation: &OMatrix<f32, Dyn, Dyn>,
148 ) -> f32 {
149 let rotated_target_mat = (rotation * target_mat_no_mean.transpose()).transpose();
151
152 let cost = model_mat_no_mean - rotated_target_mat;
154 cost.norm()
155 }
156
157 fn center_point_cloud_about_mean(
159 &self,
160 matrix: &OMatrix<f32, Dyn, Dyn>,
161 ) -> (OMatrix<f32, Dyn, Dyn>, OMatrix<f32, Const<1>, Dyn>) {
162 let mean_row = matrix.row_mean();
163 let matrix_no_mean = OMatrix::from_rows(
164 &matrix
165 .row_iter()
166 .map(|row| row - mean_row.clone_owned())
167 .collect::<Vec<_>>(),
168 );
169 (matrix_no_mean, mean_row)
170 }
171
172 fn transform_matrix(
174 &self,
175 matrix: &mut OMatrix<f32, Dyn, Dyn>,
176 homogeneous_transformation_matrix: &OMatrix<f32, Dyn, Dyn>,
177 ) {
178 let nrows = matrix.nrows();
179 let ncols = matrix.ncols();
180
181 let mut homogeneous_representation = matrix.clone_owned();
182 homogeneous_representation = homogeneous_representation.insert_column(ncols, 1.0);
183
184 let transformed_homogeneous_matrix =
186 homogeneous_transformation_matrix * homogeneous_representation.transpose();
187
188 *matrix = transformed_homogeneous_matrix
189 .transpose()
190 .view((0, 0), (nrows, ncols))
191 .into_owned();
192 }
193
194 pub fn register(&self, target_point_set: &PointSet<T>) -> OMatrix<f32, Dyn, Dyn> {
204 let dimension = target_point_set
205 .points
206 .iter()
207 .next()
208 .expect("Input set is empty")
209 .get_dimensions();
210
211 let mut target_mat = self.get_matrix_from_point_set(&target_point_set.points, dimension);
213
214 let mut registration_matrix: OMatrix<f32, Dyn, Dyn> =
216 OMatrix::identity_generic(nalgebra::Dyn(dimension + 1), nalgebra::Dyn(dimension + 1));
217
218 let mut previous_cost = f32::MAX;
220 for iteration in 0..self.max_iterations {
221 let correspondence_mat = self.get_point_correspondences(&target_mat);
223
224 let (correspondence_mat_no_mean, mean_correspondence_point) =
226 self.center_point_cloud_about_mean(&correspondence_mat);
227 let (target_mat_no_mean, mean_target_point) =
228 self.center_point_cloud_about_mean(&target_mat);
229
230 let cross_covariance_mat =
232 correspondence_mat_no_mean.transpose() * target_mat_no_mean.clone();
233
234 let res = nalgebra::linalg::SVD::new(cross_covariance_mat, true, true);
236 let u = res.u.expect("Failed to calculate u matrix");
237 let vt = res.v_t.expect("Failed to calculate vt matrix");
238 let rotation = u * vt;
239
240 let translation = mean_correspondence_point
242 - (rotation.clone() * mean_target_point.transpose()).transpose();
243
244 let homogenous_mat = self.get_homogeneous_matrix(&translation, &rotation, dimension);
245 println!(
246 " r {} test {} homo {}",
247 rotation, translation, homogenous_mat
248 );
249
250 self.transform_matrix(&mut target_mat, &homogenous_mat);
252
253 registration_matrix *= homogenous_mat;
255
256 let icp_cost =
258 self.icp_cost(&target_mat_no_mean, &correspondence_mat_no_mean, &rotation);
259 println!(
260 "=== Finished iteration {} with cost {}",
261 iteration, icp_cost
262 );
263
264 if (previous_cost - icp_cost).abs() < self.cost_change_threshold {
266 println!(
267 "Reached termination threshold of {} with {} exiting!",
268 self.cost_change_threshold, previous_cost
269 );
270 break;
271 }
272 previous_cost = icp_cost;
273 }
274
275 registration_matrix
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Point3D;
    use rstest::*;

    /// Row-major helper for building dynamically sized `f32` matrices in assertions.
    fn mat(rows: usize, cols: usize, row_major: &[f32]) -> OMatrix<f32, Dyn, Dyn> {
        OMatrix::<f32, Dyn, Dyn>::from_row_slice(rows, cols, row_major)
    }

    /// ICP over three model points on the main diagonal: (1,1,1), (2,2,2), (3,3,3).
    /// Configured for a single iteration with a 1e-3 cost-change threshold.
    #[fixture]
    fn icp_fixture() -> Icp<Point3D> {
        let model_point_set = PointSet {
            points: [1.0, 2.0, 3.0]
                .iter()
                .map(|&c| Point3D { x: c, y: c, z: c })
                .collect(),
        };

        Icp::new(model_point_set, 1, 1e-3)
    }

    #[rstest]
    fn test_get_matrix_from_point_set(icp_fixture: Icp<Point3D>) {
        let expected = mat(3, 3, &[1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0]);

        let actual =
            icp_fixture.get_matrix_from_point_set(&icp_fixture.model_point_set.points, 3);

        // Whole-matrix equality first, then row by row for a more targeted failure.
        assert_eq!(actual, expected);
        for row in 0..3 {
            assert_eq!(actual.row(row), expected.row(row));
        }
    }

    #[rstest]
    fn test_get_point_correspondences(icp_fixture: Icp<Point3D>) {
        // Targets that coincide with model points (in any order) must map onto themselves.
        let in_order = mat(3, 3, &[1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0]);
        let shuffled = mat(3, 3, &[3.0, 3.0, 3.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0]);

        for target in [&in_order, &shuffled] {
            assert_eq!(icp_fixture.get_point_correspondences(target), *target);
        }
    }

    #[rstest]
    fn test_center_point_cloud_about_mean(icp_fixture: Icp<Point3D>) {
        // (input, expected centered matrix, expected mean row)
        let cases = [
            (
                icp_fixture.get_matrix_from_point_set(&icp_fixture.model_point_set.points, 3),
                mat(3, 3, &[-1.0, -1.0, -1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0]),
                mat(1, 3, &[2.0, 2.0, 2.0]),
            ),
            (
                mat(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]),
                mat(3, 3, &[-3.0, -3.0, -3.0, 0.0, 0.0, 0.0, 3.0, 3.0, 3.0]),
                mat(1, 3, &[4.0, 5.0, 6.0]),
            ),
        ];

        for (input, expected_centered, expected_mean) in cases {
            let (centered, mean) = icp_fixture.center_point_cloud_about_mean(&input);
            assert_eq!(centered, expected_centered);
            assert_eq!(mean, expected_mean);
        }
    }

    #[rstest]
    fn test_get_homogeneous_matrix(icp_fixture: Icp<Point3D>) {
        let rotation = OMatrix::<f32, Dyn, Dyn>::identity_generic(Dyn(3), Dyn(3));
        let translation = OMatrix::<f32, U1, Dyn>::from_row_slice(&[1.0, 2.0, 3.0]);

        // Identity rotation in the top-left block, translation in the last column.
        let expected = mat(
            4,
            4,
            &[
                1.0, 0.0, 0.0, 1.0, //
                0.0, 1.0, 0.0, 2.0, //
                0.0, 0.0, 1.0, 3.0, //
                0.0, 0.0, 0.0, 1.0,
            ],
        );

        assert_eq!(
            icp_fixture.get_homogeneous_matrix(&translation, &rotation, 3),
            expected
        );
    }

    #[rstest]
    fn test_transform_matrix(icp_fixture: Icp<Point3D>) {
        let mut points = mat(2, 2, &[1.0, 2.0, 3.0, 4.0]);

        // The identity transform leaves the points untouched.
        let identity = mat(3, 3, &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
        icp_fixture.transform_matrix(&mut points, &identity);
        assert_eq!(points, mat(2, 2, &[1.0, 2.0, 3.0, 4.0]));

        // A pure translation by (1, 2), applied on top of the previous (unchanged) result.
        let shift = mat(3, 3, &[1.0, 0.0, 1.0, 0.0, 1.0, 2.0, 0.0, 0.0, 1.0]);
        icp_fixture.transform_matrix(&mut points, &shift);
        assert_eq!(points, mat(2, 2, &[2.0, 4.0, 4.0, 6.0]));
    }

    #[rstest]
    fn test_icp_registration(icp_fixture: Icp<Point3D>) {
        // The same diagonal points as the model, with (3, 3, 3) duplicated; the test
        // expects (almost) no translation in the resulting registration.
        let target_point_set = PointSet {
            points: [1.0, 2.0, 3.0, 3.0]
                .iter()
                .map(|&c| Point3D { x: c, y: c, z: c })
                .collect(),
        };

        let registration = icp_fixture.register(&target_point_set);

        // The translation lives in the last column of the homogeneous matrix.
        for row in 0..3 {
            assert!(registration[(row, 3)].abs() < 1e-3);
        }
    }
}