1pub use types::{ICPConfiguration, ICPConfigurationBuilder, ICPError, ICPResult, ICPSuccess};
25
26use nalgebra::{ComplexField, Isometry, Point, RealField, SimdRealField};
27use num_traits::{AsPrimitive, Bounded};
28
29use crate::{
30 kd_tree::KDTree,
31 point_clouds::find_nearest_neighbour_naive,
32 types::{AbstractIsometry, IsNan, IsometryAbstractor},
33 Sum, Vec,
34};
35
36use helpers::{calculate_mse, get_rotation_matrix_and_centroids};
37
38mod helpers;
39mod types;
40
41#[cfg_attr(
62 feature = "tracing",
63 tracing::instrument("ICP Algorithm Iteration", skip_all, level = "info")
64)]
65pub fn icp_iteration<T, const N: usize>(
66 points_a: &[Point<T, N>],
67 transformed_points: &mut [Point<T, N>],
68 points_b: &[Point<T, N>],
69 target_points_tree: Option<&KDTree<T, N>>,
70 current_transform: &mut Isometry<
71 T,
72 <IsometryAbstractor<T, N> as AbstractIsometry<T, N>>::RotType,
73 N,
74 >,
75 current_mse: &mut T,
76 config: &ICPConfiguration<T>,
77) -> Result<T, ICPError<T, N>>
78where
79 T: Bounded + Copy + Default + RealField + Sum + SimdRealField,
80 usize: AsPrimitive<T>,
81 IsometryAbstractor<T, N>: AbstractIsometry<T, N>,
82{
83 let closest_points = transformed_points.iter().try_fold(
84 Vec::with_capacity(transformed_points.len()),
85 |mut accumulator, transformed_point_a| {
86 accumulator.push(
87 target_points_tree
88 .and_then(|kd_tree| kd_tree.nearest(transformed_point_a))
89 .or_else(|| find_nearest_neighbour_naive(transformed_point_a, points_b))
90 .ok_or(ICPError::NoNearestNeighbour)?,
91 );
92
93 Ok(accumulator)
94 },
95 )?;
96
97 let (rot_mat, mean_a, mean_b) =
98 get_rotation_matrix_and_centroids(transformed_points, &closest_points);
99
100 *current_transform =
101 IsometryAbstractor::<T, N>::update_transform(current_transform, mean_a, mean_b, &rot_mat);
102
103 for (idx, point_a) in points_a.iter().enumerate() {
104 transformed_points[idx] = current_transform.transform_point(point_a);
105 }
106 let new_mse = calculate_mse(transformed_points, closest_points.as_slice());
107 log::trace!("New MSE: {new_mse}");
108
109 if config
111 .mse_absolute_threshold
112 .map(|thres| new_mse < thres)
113 .unwrap_or_default()
114 || <T as ComplexField>::abs(*current_mse - new_mse) < config.mse_interval_threshold
115 {
116 return Ok(new_mse);
117 }
118
119 *current_mse = new_mse;
120 Err(ICPError::IterationDidNotConverge((mean_a, mean_b)))
121}
122
123#[cfg_attr(
140 feature = "tracing",
141 tracing::instrument("Full ICP Algorithm", skip_all, level = "info")
142)]
143pub fn icp<T, const N: usize>(
144 points_a: &[Point<T, N>],
145 points_b: &[Point<T, N>],
146 config: ICPConfiguration<T>,
147) -> ICPResult<T, <IsometryAbstractor<T, N> as AbstractIsometry<T, N>>::RotType, N>
148where
149 T: Bounded + Copy + Default + IsNan + RealField + Sum,
150 usize: AsPrimitive<T>,
151 IsometryAbstractor<T, N>: AbstractIsometry<T, N>,
152{
153 if points_a.is_empty() {
154 return Err(ICPError::SourcePointCloudEmpty);
155 }
156
157 if points_b.is_empty() {
158 return Err(ICPError::TargetPointCloudEmpty);
159 }
160
161 if config.max_iterations == 0 {
162 return Err(ICPError::IterationNumIsZero);
163 }
164
165 if config.mse_interval_threshold <= T::default_epsilon() {
166 return Err(ICPError::MSEIntervalThreshold);
167 }
168
169 if config
170 .mse_absolute_threshold
171 .map(|thres| thres.is_nan() || thres <= T::default_epsilon())
172 .unwrap_or_default()
173 {
174 return Err(ICPError::MSEAbsoluteThreshold);
175 }
176
177 let mut points_to_transform = points_a.to_vec();
178 let target_points_tree = config.use_kd_tree.then_some(KDTree::from(points_b));
179 let mut current_transform = Isometry::identity();
180 let mut current_mse = <T as Bounded>::max_value();
181
182 for iteration_num in 0..config.max_iterations {
183 log::trace!(
184 "Running iteration number {iteration_num}/{}",
185 config.max_iterations
186 );
187 if let Ok(mse) = icp_iteration::<T, N>(
188 points_a,
189 &mut points_to_transform,
190 points_b,
191 target_points_tree.as_ref(),
192 &mut current_transform,
193 &mut current_mse,
194 &config,
195 ) {
196 log::trace!("Converged after {iteration_num} iterations with an MSE of {mse}");
197 return Ok(ICPSuccess {
198 transform: current_transform,
199 mse,
200 iteration_num,
201 });
202 }
203 }
204
205 Err(ICPError::AlrogithmDidNotConverge)
206}
207
// Generates thin, concretely-typed wrappers around the generic `icp` function.
// Only compiled when the `pregenerated` feature is enabled.
#[cfg(feature = "pregenerated")]
macro_rules! impl_icp_algorithm {
    // Inner arm: emits a single public function named `icp_<nd>d` (via `paste`'s
    // identifier concatenation) for the given float type, dimension and rotation type.
    // The generated function simply forwards its arguments to `super::icp`.
    ($precision:expr, $doc:tt, $nd:expr, $rot_type:expr) => {
        ::paste::paste! {
            #[doc = "A premade variant of the ICP algorithm function, in " $nd "D space and " $doc "-precision floats."]
            pub fn [<icp_$nd d>](points_a: &[Point<$precision, $nd>],
                points_b: &[Point<$precision, $nd>],
                config: ICPConfiguration<$precision>) -> ICPResult<$precision, $rot_type<$precision>, $nd> {
                super::icp(points_a, points_b, config)
            }
        }
    };

    // Outer arm: emits a `<doc>_precision` module containing the 2D (`UnitComplex`)
    // and 3D (`UnitQuaternion`) wrappers for the given float type, by invoking the
    // inner arm once per dimension.
    ($precision:expr, doc $doc:tt) => {
        ::paste::paste! {
            pub(super) mod [<$doc _precision>] {
                use nalgebra::{Point, UnitComplex, UnitQuaternion};
                use super::{ICPConfiguration, ICPResult};

                impl_icp_algorithm!($precision, $doc, 2, UnitComplex);
                impl_icp_algorithm!($precision, $doc, 3, UnitQuaternion);
            }
        }
    }
}

// Instantiate the wrappers for `f32` (module `single_precision`).
#[cfg(feature = "pregenerated")]
impl_icp_algorithm!(f32, doc single);
// Instantiate the wrappers for `f64` (module `double_precision`).
#[cfg(feature = "pregenerated")]
impl_icp_algorithm!(f64, doc double);
238
#[cfg(test)]
mod tests {
    use nalgebra::{Isometry2, Isometry3, UnitComplex, Vector2, Vector3};

    use crate::{
        array,
        point_clouds::{generate_point_cloud, transform_point_cloud},
    };

    use super::*;

    /// Each input-validation guard of `icp` must report its dedicated error variant.
    #[test]
    fn test_icp_errors() {
        let cloud = generate_point_cloud(10, array::from_fn(|_| -15.0..=15.0));
        let builder = ICPConfiguration::builder();

        // Empty source cloud.
        let mut outcome: ICPResult<f32, UnitComplex<f32>, 2> =
            icp(&[], cloud.as_slice(), builder.build());
        assert_eq!(outcome.unwrap_err(), ICPError::SourcePointCloudEmpty);

        // Empty target cloud.
        outcome = icp(cloud.as_slice(), &[], builder.build());
        assert_eq!(outcome.unwrap_err(), ICPError::TargetPointCloudEmpty);

        // Zero iterations allowed.
        outcome = icp(
            cloud.as_slice(),
            cloud.as_slice(),
            builder.with_max_iterations(0).build(),
        );
        assert_eq!(outcome.unwrap_err(), ICPError::IterationNumIsZero);

        // Interval threshold of zero.
        outcome = icp(
            cloud.as_slice(),
            cloud.as_slice(),
            builder.with_mse_interval_threshold(0.0).build(),
        );
        assert_eq!(outcome.unwrap_err(), ICPError::MSEIntervalThreshold);

        // Absolute threshold of zero.
        outcome = icp(
            cloud.as_slice(),
            cloud.as_slice(),
            builder.with_absolute_mse_threshold(Some(0.0)).build(),
        );
        assert_eq!(outcome.unwrap_err(), ICPError::MSEAbsoluteThreshold);
    }

    /// With a single allowed iteration on a cloud rotated by 90 degrees and translated,
    /// `icp` is expected to report `AlrogithmDidNotConverge`.
    #[test]
    fn test_no_convegence() {
        let source_cloud = generate_point_cloud(1000, array::from_fn(|_| -15.0..=15.0));
        let ground_truth = Isometry2::new(Vector2::new(-12.5, 7.3), 90.0f32.to_radians());
        let target_cloud = transform_point_cloud(&source_cloud, ground_truth);

        let config = ICPConfiguration::builder()
            .with_max_iterations(1)
            .with_mse_interval_threshold(0.001)
            .build();

        let outcome = icp(source_cloud.as_slice(), target_cloud.as_slice(), config);
        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err(), ICPError::AlrogithmDidNotConverge);
    }

    /// With an absolute MSE threshold configured, the run must succeed with an MSE
    /// below that threshold.
    #[test]
    fn test_icp_absolute_threshold() {
        let source_cloud = generate_point_cloud(100, array::from_fn(|_| -15.0..=15.0));
        let ground_truth = Isometry2::new(Vector2::new(-0.8, 1.3), 0.1);
        let target_cloud = transform_point_cloud(&source_cloud, ground_truth);

        let config = ICPConfiguration::builder()
            .with_max_iterations(10)
            .with_absolute_mse_threshold(Some(0.1))
            .with_mse_interval_threshold(0.001)
            .build();

        let outcome = icp(source_cloud.as_slice(), target_cloud.as_slice(), config);
        assert!(outcome.is_ok());
        assert!(outcome.unwrap().mse < 0.1);
    }

    /// 2D registration using the naive nearest-neighbour search.
    #[test]
    fn test_icp_2d() {
        let source_cloud = generate_point_cloud(100, array::from_fn(|_| -15.0..=15.0));
        let ground_truth = Isometry2::new(Vector2::new(-0.8, 1.3), 0.1);
        let target_cloud = transform_point_cloud(&source_cloud, ground_truth);

        let config = ICPConfiguration::builder()
            .with_max_iterations(10)
            .with_mse_interval_threshold(0.01)
            .build();

        let outcome = icp(source_cloud.as_slice(), target_cloud.as_slice(), config);
        assert!(outcome.is_ok());
        assert!(outcome.unwrap().mse < 0.01);
    }

    /// 2D registration with the KD-tree enabled.
    #[test]
    fn test_icp_2d_with_kd() {
        let source_cloud = generate_point_cloud(100, array::from_fn(|_| -15.0..=15.0));
        let offset = Vector2::new(-0.8, 1.3);
        let ground_truth = Isometry2::new(offset, 0.1);
        let target_cloud = transform_point_cloud(&source_cloud, ground_truth);

        let config = ICPConfiguration::builder()
            .with_kd_tree(true)
            .with_max_iterations(50)
            .with_mse_interval_threshold(0.01)
            .build();

        let outcome = icp(source_cloud.as_slice(), target_cloud.as_slice(), config);
        assert!(outcome.is_ok());
        assert!(outcome.unwrap().mse < 0.01);
    }

    /// 3D registration using the naive nearest-neighbour search.
    #[test]
    fn test_icp_3d() {
        let source_cloud = generate_point_cloud(500, array::from_fn(|_| -15.0..=15.0));
        let ground_truth = Isometry3::new(
            Vector3::new(-0.8, 1.3, 0.2),
            Vector3::new(0.1, 0.2, -0.21),
        );
        let target_cloud = transform_point_cloud(&source_cloud, ground_truth);

        let config = ICPConfiguration::builder()
            .with_max_iterations(50)
            .with_mse_interval_threshold(0.01)
            .build();

        let outcome = icp(source_cloud.as_slice(), target_cloud.as_slice(), config);
        assert!(outcome.is_ok());
        assert!(outcome.unwrap().mse < 0.05);
    }

    /// 3D registration with the KD-tree enabled.
    #[test]
    fn test_icp_3d_with_kd() {
        let source_cloud = generate_point_cloud(500, array::from_fn(|_| -15.0..=15.0));
        let ground_truth = Isometry3::new(
            Vector3::new(-0.8, 1.3, 0.2),
            Vector3::new(0.1, 0.2, -0.21),
        );
        let target_cloud = transform_point_cloud(&source_cloud, ground_truth);

        let config = ICPConfiguration::builder()
            .with_kd_tree(true)
            .with_max_iterations(50)
            .with_mse_interval_threshold(0.01)
            .build();

        let outcome = icp(source_cloud.as_slice(), target_cloud.as_slice(), config);
        assert!(outcome.is_ok());
        assert!(outcome.unwrap().mse < 0.05);
    }
}