// scirs2_spatial/procrustes.rs
use crate::error::{SpatialError, SpatialResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};

13#[allow(dead_code)]
15fn check_array_finite(array: &ArrayView2<'_, f64>, name: &str) -> SpatialResult<()> {
16 for value in array.iter() {
17 if !value.is_finite() {
18 return Err(SpatialError::ValueError(format!(
19 "Array '{name}' contains non-finite values"
20 )));
21 }
22 }
23 Ok(())
24}
25
26#[derive(Debug, Clone)]
28pub struct ProcrustesParams {
29 pub scale: f64,
31 pub rotation: Array2<f64>,
33 pub translation: Array1<f64>,
35}
36
37impl ProcrustesParams {
38 pub fn transform(&self, points: &ArrayView2<'_, f64>) -> Array2<f64> {
48 let mut result = points.to_owned() * self.scale;
50 result = result.dot(&self.rotation.t());
51
52 for mut row in result.rows_mut() {
54 for (i, val) in row.iter_mut().enumerate() {
55 *val += self.translation[i];
56 }
57 }
58
59 result
60 }
61}
62
63#[allow(dead_code)]
86pub fn procrustes(
87 data1: &ArrayView2<'_, f64>,
88 data2: &ArrayView2<'_, f64>,
89) -> SpatialResult<(Array2<f64>, Array2<f64>, f64)> {
90 check_array_finite(data1, "data1")?;
92 check_array_finite(data2, "data2")?;
93
94 if data1.shape() != data2.shape() {
95 return Err(SpatialError::DimensionError(format!(
96 "Input arrays must have the same shape. Got {:?} and {:?}",
97 data1.shape(),
98 data2.shape()
99 )));
100 }
101
102 let (n_points, n_dims) = (data1.nrows(), data1.ncols());
103
104 if n_points == 0 || n_dims == 0 {
105 return Err(SpatialError::DimensionError(
106 "Input arrays cannot be empty".to_string(),
107 ));
108 }
109
110 let mean1 = data1.mean_axis(Axis(0)).unwrap();
112 let mean2 = data2.mean_axis(Axis(0)).unwrap();
113
114 let mut centered1 = data1.to_owned();
115 let mut centered2 = data2.to_owned();
116
117 for mut row in centered1.rows_mut() {
118 for (i, val) in row.iter_mut().enumerate() {
119 *val -= mean1[i];
120 }
121 }
122
123 for mut row in centered2.rows_mut() {
124 for (i, val) in row.iter_mut().enumerate() {
125 *val -= mean2[i];
126 }
127 }
128
129 let _h = centered1.t().dot(¢ered2);
131
132 let result = procrustes_basic_impl(¢ered1.view(), ¢ered2.view(), &mean1, &mean2)?;
135
136 Ok(result)
137}
138
139#[allow(dead_code)]
141fn procrustes_basic_impl(
142 centered1: &ArrayView2<'_, f64>,
143 centered2: &ArrayView2<'_, f64>,
144 _mean1: &Array1<f64>,
145 mean2: &Array1<f64>,
146) -> SpatialResult<(Array2<f64>, Array2<f64>, f64)> {
147 let n_points = centered1.nrows() as f64;
148
149 let norm1_sq: f64 = centered1.iter().map(|x| x * x).sum();
151 let norm2_sq: f64 = centered2.iter().map(|x| x * x).sum();
152
153 let norm1 = (norm1_sq / n_points).sqrt();
154 let norm2 = (norm2_sq / n_points).sqrt();
155
156 let scale1 = if norm1 > 1e-10 { 1.0 / norm1 } else { 1.0 };
158 let scale2 = if norm2 > 1e-10 { 1.0 / norm2 } else { 1.0 };
159
160 let scaled1 = centered1 * scale1;
161 let scaled2 = centered2 * scale2;
162
163 let mut transformed1 = scaled1.to_owned();
166 let transformed2 = scaled2.to_owned();
167
168 for mut row in transformed1.rows_mut() {
170 for (i, val) in row.iter_mut().enumerate() {
171 *val += mean2[i];
172 }
173 }
174
175 let diff = &transformed1 - &transformed2;
177 let disparity: f64 = diff.iter().map(|x| x * x).sum();
178 let normalized_disparity = disparity / n_points;
179
180 Ok((transformed1, transformed2, normalized_disparity))
181}
182
183#[allow(dead_code)]
209pub fn procrustes_extended(
210 data1: &ArrayView2<'_, f64>,
211 data2: &ArrayView2<'_, f64>,
212 scaling: bool,
213 _reflection: bool,
214 translation: bool,
215) -> SpatialResult<(Array2<f64>, ProcrustesParams, f64)> {
216 check_array_finite(data1, "data1")?;
218 check_array_finite(data2, "data2")?;
219
220 if data1.shape() != data2.shape() {
221 return Err(SpatialError::DimensionError(format!(
222 "Input arrays must have the same shape. Got {:?} and {:?}",
223 data1.shape(),
224 data2.shape()
225 )));
226 }
227
228 let (n_points, n_dims) = (data1.nrows(), data1.ncols());
229
230 if n_points == 0 || n_dims == 0 {
231 return Err(SpatialError::DimensionError(
232 "Input arrays cannot be empty".to_string(),
233 ));
234 }
235
236 let mut scale = 1.0;
238 let rotation = Array2::eye(n_dims);
239 let mut translation_vec = Array1::zeros(n_dims);
240
241 let (centered1, centered2, mean1, mean2) = if translation {
243 let mean1 = data1.mean_axis(Axis(0)).unwrap();
244 let mean2 = data2.mean_axis(Axis(0)).unwrap();
245
246 let mut centered1 = data1.to_owned();
247 let mut centered2 = data2.to_owned();
248
249 for mut row in centered1.rows_mut() {
250 for (i, val) in row.iter_mut().enumerate() {
251 *val -= mean1[i];
252 }
253 }
254
255 for mut row in centered2.rows_mut() {
256 for (i, val) in row.iter_mut().enumerate() {
257 *val -= mean2[i];
258 }
259 }
260
261 (centered1, centered2, mean1, mean2)
262 } else {
263 (
264 data1.to_owned(),
265 data2.to_owned(),
266 Array1::zeros(n_dims),
267 Array1::zeros(n_dims),
268 )
269 };
270
271 if scaling {
273 let norm1_sq: f64 = centered1.iter().map(|x| x * x).sum();
274 let norm2_sq: f64 = centered2.iter().map(|x| x * x).sum();
275
276 let norm1 = (norm1_sq / n_points as f64).sqrt();
277 let norm2 = (norm2_sq / n_points as f64).sqrt();
278
279 if norm1 > 1e-10 && norm2 > 1e-10 {
280 scale = norm2 / norm1;
281 }
282 }
283
284 if translation {
289 for i in 0..n_dims {
290 translation_vec[i] = mean2[i] - scale * mean1[i];
291 }
292 }
293
294 let mut transformed = centered1 * scale;
296 transformed = transformed.dot(&rotation);
297
298 if translation {
299 for mut row in transformed.rows_mut() {
300 for (i, val) in row.iter_mut().enumerate() {
301 *val += translation_vec[i];
302 }
303 }
304 }
305
306 let target = if translation {
308 data2.to_owned()
309 } else {
310 centered2
311 };
312
313 let diff = &transformed - ⌖
314 let disparity: f64 = diff.iter().map(|x| x * x).sum();
315 let normalized_disparity = disparity / n_points as f64;
316
317 let params = ProcrustesParams {
318 scale,
319 rotation,
320 translation: translation_vec,
321 };
322
323 Ok((transformed, params, normalized_disparity))
324}