1use nalgebra::{
25 ArrayStorage, ClosedAddAssign, ClosedDivAssign, ClosedSubAssign, Const, Matrix, Point, SMatrix,
26 Scalar, Vector,
27};
28use num_traits::{AsPrimitive, NumOps, Zero};
29
30use crate::{array, point_clouds::calculate_point_cloud_center, utils::distance_squared, Sum};
31
32#[inline]
45#[cfg_attr(
46 feature = "tracing",
47 tracing::instrument("Calculate MSE", skip_all, level = "debug")
48)]
49pub(crate) fn calculate_mse<T, const N: usize>(
50 transformed_points_a: &[Point<T, N>],
51 closest_points_in_b: &[Point<T, N>],
52) -> T
53where
54 T: Copy + Default + NumOps + Scalar + Sum,
55{
56 transformed_points_a
57 .iter()
58 .zip(closest_points_in_b.iter())
59 .map(|(transformed_a, closest_point_in_b)| {
60 distance_squared(transformed_a, closest_point_in_b)
61 })
62 .sum()
63}
64
65#[inline]
78#[cfg_attr(
79 feature = "tracing",
80 tracing::instrument("Calculate Outer Product", skip_all, level = "trace")
81)]
82pub(crate) fn outer_product<T, const N: usize>(
83 point_a: &Vector<T, Const<N>, ArrayStorage<T, N, 1>>,
84 point_b: &Vector<T, Const<N>, ArrayStorage<T, N, 1>>,
85) -> SMatrix<T, N, N>
86where
87 T: NumOps + Copy,
88{
89 Matrix::from_data(ArrayStorage(array::from_fn(|a_idx| {
90 array::from_fn(|b_idx| point_a.data.0[0][a_idx] * point_b.data.0[0][b_idx])
91 })))
92}
93
94#[inline]
113#[cfg_attr(
114 feature = "tracing",
115 tracing::instrument("Estimate Transform And Means", skip_all, level = "debug")
116)]
117pub(crate) fn get_rotation_matrix_and_centroids<T, const N: usize>(
118 transformed_points_a: &[Point<T, N>],
119 closest_points: &[Point<T, N>],
120) -> (SMatrix<T, N, N>, Point<T, N>, Point<T, N>)
121where
122 T: ClosedAddAssign + ClosedDivAssign + ClosedSubAssign + Copy + NumOps + Scalar + Zero,
123 usize: AsPrimitive<T>,
124{
125 let (mean_transformed_a, mean_closest) = (
126 calculate_point_cloud_center(transformed_points_a),
127 calculate_point_cloud_center(closest_points),
128 );
129
130 let rot_mat = transformed_points_a.iter().zip(closest_points.iter()).fold(
131 Matrix::from_array_storage(ArrayStorage([[T::zero(); N]; N])),
132 |rot_mat, (transformed_point_a, closest_point)| {
133 let a_distance_from_centroid = transformed_point_a - mean_transformed_a;
134 let closest_point_distance_from_centroid = closest_point - mean_closest;
135 rot_mat
136 + outer_product(
137 &a_distance_from_centroid,
138 &closest_point_distance_from_centroid,
139 )
140 },
141 );
142
143 (rot_mat, mean_transformed_a, mean_closest)
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::{Matrix3, Point3, Vector3};

    #[test]
    fn test_calculate_mean() {
        let cloud: [Point<f64, 3>; 3] = [
            Point3::new(1.0, 2.0, 3.0),
            Point3::new(4.0, 5.0, 6.0),
            Point3::new(7.0, 8.0, 9.0),
        ];

        assert_eq!(
            calculate_point_cloud_center(&cloud),
            Point3::new(4.0, 5.0, 6.0),
            "The mean point was not calculated correctly."
        );
    }

    #[test]
    fn test_calculate_mse() {
        let transformed_points_a: [Point<f64, 3>; 3] = [
            Point3::new(1.0, 2.0, 3.0),
            Point3::new(4.0, 4.0, 4.0),
            Point3::new(7.0, 7.0, 7.0),
        ];
        let points_b: [Point<f64, 3>; 3] = [
            Point3::new(1.0, 1.0, 1.0),
            Point3::new(4.0, 5.0, 6.0),
            Point3::new(8.0, 8.0, 8.0),
        ];

        // Per-pair squared distances are 5, 5 and 3; `calculate_mse` reports their sum.
        let expected_sum_of_squares = 13.0;

        assert_eq!(
            calculate_mse(&transformed_points_a, &points_b),
            expected_sum_of_squares,
            "The calculated MSE does not match the expected value."
        );
    }

    #[test]
    fn test_outer_product() {
        let point_a = Vector3::new(1.0, 2.0, 3.0);
        let point_b = Vector3::new(4.0, 5.0, 6.0);

        // Entry (i, j) is `point_a[j] * point_b[i]`. `Matrix3::new` takes row-major arguments.
        let expected = Matrix3::new(
            4.0, 8.0, 12.0, //
            5.0, 10.0, 15.0, //
            6.0, 12.0, 18.0,
        );

        assert_eq!(
            outer_product(&point_a, &point_b),
            expected,
            "The calculated outer product does not match the expected value."
        );
    }

    #[test]
    fn test_get_rotation_matrix_and_centroids() {
        let points_a: [Point<f64, 3>; 3] = [
            Point3::new(6.0, 4.0, 20.0),
            Point3::new(100.0, 60.0, 3.0),
            Point3::new(5.0, 20.0, 10.0),
        ];
        let points_b: [Point<f64, 3>; 3] = [
            Point3::new(40.0, 22.0, 12.0),
            Point3::new(10.0, 14.0, 10.0),
            Point3::new(7.0, 30.0, 20.0),
        ];

        let (rot_mat, mean_a, mean_b) = get_rotation_matrix_and_centroids(&points_a, &points_b);

        assert_eq!(
            mean_a,
            Point3::new(37.0, 28.0, 11.0),
            "The calculated mean of points_a does not match the expected value."
        );
        assert_eq!(
            mean_b,
            Point3::new(19.0, 22.0, 14.0),
            "The calculated mean of points_b does not match the expected value."
        );

        // Row-major: entry (i, j) = Σ (b - mean_b)[i] * (a - mean_a)[j].
        let expected_rot_mat = Matrix3::new(
            -834.0, -696.0, 273.0, //
            -760.0, -320.0, 56.0, //
            -382.0, -128.0, 8.0,
        );

        assert_eq!(
            rot_mat,
            expected_rot_mat,
            "The calculated rotation matrix does not match the expected value."
        );
    }
}