1use super::dp_alignment_core;
8use super::pairwise::elastic_align_pair;
9use super::srsf::{reparameterize_curve, srsf_single};
10use crate::error::FdarError;
11use crate::helpers::{l2_distance, simpsons_weights};
12use crate::iter_maybe_parallel;
13use crate::matrix::FdMatrix;
14#[cfg(feature = "parallel")]
15use rayon::iter::ParallelIterator;
16
/// Outcome of aligning one closed curve onto another with
/// [`elastic_align_pair_closed`].
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq)]
pub struct ClosedAlignmentResult {
    /// Warping function found when aligning the circularly shifted `f2` onto `f1`.
    pub gamma: Vec<f64>,
    /// `f2` after the circular shift `optimal_rotation` and the warp `gamma`.
    pub f_aligned: Vec<f64>,
    /// Elastic distance reported by the pairwise alignment of the shifted `f2`.
    pub distance: f64,
    /// Number of samples by which `f2` was circularly shifted before warping.
    pub optimal_rotation: usize,
}
32
33#[derive(Debug, Clone, PartialEq)]
35#[non_exhaustive]
36pub struct ClosedKarcherMeanResult {
37 pub mean: Vec<f64>,
39 pub mean_srsf: Vec<f64>,
41 pub gammas: FdMatrix,
43 pub aligned_data: FdMatrix,
45 pub rotations: Vec<usize>,
47 pub n_iter: usize,
49 pub converged: bool,
51}
52
/// Circularly shift `f` to the left by `k` samples.
///
/// Element `j` of the result is `f[(j + k) % f.len()]`. `k` may exceed the
/// length (it is reduced modulo the length), and an empty slice yields an
/// empty vector.
fn circular_shift(f: &[f64], k: usize) -> Vec<f64> {
    if f.is_empty() {
        return Vec::new();
    }
    let mut rotated = f.to_vec();
    rotated.rotate_left(k % f.len());
    rotated
}
64
65fn find_best_rotation(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> (usize, f64) {
69 let m = f1.len();
70 if m < 2 {
71 return (0, 0.0);
72 }
73
74 let step_size = (m / 20).max(1);
76 let mut best_k = 0;
77 let mut best_dist = f64::INFINITY;
78
79 let mut k = 0;
80 while k < m {
81 let f2_rot = circular_shift(f2, k);
82 let q1 = srsf_single(f1, argvals);
83 let q2 = srsf_single(&f2_rot, argvals);
84 let gamma = dp_alignment_core(&q1, &q2, argvals, lambda);
85 let f_aligned = reparameterize_curve(&f2_rot, argvals, &gamma);
86 let q_aligned = srsf_single(&f_aligned, argvals);
87 let weights = simpsons_weights(argvals);
88 let dist = l2_distance(&q1, &q_aligned, &weights);
89
90 if dist < best_dist {
91 best_dist = dist;
92 best_k = k;
93 }
94 k += step_size;
95 }
96
97 let search_start = best_k.saturating_sub(step_size);
99 let search_end = (best_k + step_size).min(m);
100
101 for k in search_start..search_end {
102 if k % step_size == 0 {
103 continue; }
105 let f2_rot = circular_shift(f2, k);
106 let q1 = srsf_single(f1, argvals);
107 let q2 = srsf_single(&f2_rot, argvals);
108 let gamma = dp_alignment_core(&q1, &q2, argvals, lambda);
109 let f_aligned = reparameterize_curve(&f2_rot, argvals, &gamma);
110 let q_aligned = srsf_single(&f_aligned, argvals);
111 let weights = simpsons_weights(argvals);
112 let dist = l2_distance(&q1, &q_aligned, &weights);
113
114 if dist < best_dist {
115 best_dist = dist;
116 best_k = k;
117 }
118 }
119
120 (best_k, best_dist)
121}
122
123#[must_use = "expensive computation whose result should not be discarded"]
140pub fn elastic_align_pair_closed(
141 f1: &[f64],
142 f2: &[f64],
143 argvals: &[f64],
144 lambda: f64,
145) -> Result<ClosedAlignmentResult, FdarError> {
146 let m = f1.len();
147 if m != f2.len() || m != argvals.len() {
148 return Err(FdarError::InvalidDimension {
149 parameter: "f1/f2/argvals",
150 expected: format!("equal lengths, f1 has {m}"),
151 actual: format!("f2 has {}, argvals has {}", f2.len(), argvals.len()),
152 });
153 }
154 if m < 2 {
155 return Err(FdarError::InvalidDimension {
156 parameter: "f1",
157 expected: "length >= 2".to_string(),
158 actual: format!("length {m}"),
159 });
160 }
161
162 let (best_k, _) = find_best_rotation(f1, f2, argvals, lambda);
163
164 let f2_rotated = circular_shift(f2, best_k);
166 let result = elastic_align_pair(f1, &f2_rotated, argvals, lambda);
167
168 Ok(ClosedAlignmentResult {
169 gamma: result.gamma,
170 f_aligned: result.f_aligned,
171 distance: result.distance,
172 optimal_rotation: best_k,
173 })
174}
175
176#[must_use = "expensive computation whose result should not be discarded"]
189pub fn elastic_distance_closed(
190 f1: &[f64],
191 f2: &[f64],
192 argvals: &[f64],
193 lambda: f64,
194) -> Result<f64, FdarError> {
195 Ok(elastic_align_pair_closed(f1, f2, argvals, lambda)?.distance)
196}
197
198#[must_use = "expensive computation whose result should not be discarded"]
213pub fn karcher_mean_closed(
214 data: &FdMatrix,
215 argvals: &[f64],
216 max_iter: usize,
217 tol: f64,
218 lambda: f64,
219) -> Result<ClosedKarcherMeanResult, FdarError> {
220 let (n, m) = data.shape();
221 if m != argvals.len() {
222 return Err(FdarError::InvalidDimension {
223 parameter: "argvals",
224 expected: format!("length {m}"),
225 actual: format!("length {}", argvals.len()),
226 });
227 }
228 if m < 2 {
229 return Err(FdarError::InvalidDimension {
230 parameter: "data",
231 expected: "ncols >= 2".to_string(),
232 actual: format!("ncols = {m}"),
233 });
234 }
235 if n == 0 {
236 return Err(FdarError::InvalidDimension {
237 parameter: "data",
238 expected: "nrows > 0".to_string(),
239 actual: "nrows = 0".to_string(),
240 });
241 }
242
243 let mut mu: Vec<f64> = data.row(0);
245 let mut mu_q = srsf_single(&mu, argvals);
246
247 let mut gammas = FdMatrix::zeros(n, m);
248 let mut rotations = vec![0usize; n];
249 let mut converged = false;
250 let mut n_iter = 0;
251
252 for iter in 0..max_iter {
253 n_iter = iter + 1;
254
255 let align_results: Vec<(ClosedAlignmentResult, Vec<f64>)> = iter_maybe_parallel!(0..n)
257 .map(|i| {
258 let fi = data.row(i);
259 let res = elastic_align_pair_closed(&mu, &fi, argvals, lambda)
260 .expect("dimension invariant: all curves have length m");
261 let q_warped = srsf_single(&res.f_aligned, argvals);
262 (res, q_warped)
263 })
264 .collect();
265
266 let mut mu_q_new = vec![0.0; m];
268 for (i, (res, q_aligned)) in align_results.iter().enumerate() {
269 for j in 0..m {
270 gammas[(i, j)] = res.gamma[j];
271 mu_q_new[j] += q_aligned[j];
272 }
273 rotations[i] = res.optimal_rotation;
274 }
275 for j in 0..m {
276 mu_q_new[j] /= n as f64;
277 }
278
279 let diff_norm: f64 = mu_q
281 .iter()
282 .zip(mu_q_new.iter())
283 .map(|(&a, &b)| (a - b).powi(2))
284 .sum::<f64>()
285 .sqrt();
286 let old_norm: f64 = mu_q.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
287 let rel = diff_norm / old_norm;
288
289 mu_q = mu_q_new;
290
291 if rel < tol {
292 converged = true;
293 break;
294 }
295
296 mu = crate::alignment::srsf::srsf_inverse(&mu_q, argvals, mu[0]);
298 }
299
300 let mut aligned_data = FdMatrix::zeros(n, m);
302 for i in 0..n {
303 let fi = data.row(i);
304 let f_rotated = circular_shift(&fi, rotations[i]);
305 let gamma_i: Vec<f64> = (0..m).map(|j| gammas[(i, j)]).collect();
306 let f_aligned = reparameterize_curve(&f_rotated, argvals, &gamma_i);
307 for j in 0..m {
308 aligned_data[(i, j)] = f_aligned[j];
309 }
310 }
311
312 mu = crate::alignment::srsf::srsf_inverse(&mu_q, argvals, mu[0]);
314
315 Ok(ClosedKarcherMeanResult {
316 mean: mu,
317 mean_srsf: mu_q,
318 gammas,
319 aligned_data,
320 rotations,
321 n_iter,
322 converged,
323 })
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::uniform_grid;
    use std::f64::consts::PI;

    /// Evaluate `curve` at every point of `argvals`.
    fn sample(argvals: &[f64], curve: impl Fn(f64) -> f64) -> Vec<f64> {
        argvals.iter().map(|&t| curve(t)).collect()
    }

    #[test]
    fn closed_align_identity() {
        let argvals = uniform_grid(30);
        let f = sample(&argvals, |t| (2.0 * PI * t).sin());

        let result = elastic_align_pair_closed(&f, &f, &argvals, 0.0).unwrap();
        assert!(
            result.distance < 0.1,
            "identical closed curves should have near-zero distance, got {}",
            result.distance
        );
        assert_eq!(
            result.optimal_rotation, 0,
            "identical curves should need no rotation"
        );
    }

    #[test]
    fn closed_align_shifted() {
        let argvals = uniform_grid(40);
        let f1 = sample(&argvals, |t| (2.0 * PI * t).sin() + 0.5 * t);
        // Rotate the sampled curve by five grid points.
        let f2 = circular_shift(&f1, 5);

        let result = elastic_align_pair_closed(&f1, &f2, &argvals, 0.0).unwrap();
        assert!(
            result.distance < 1.0,
            "distance after closed alignment should be small, got {}",
            result.distance
        );
    }

    #[test]
    fn closed_distance_symmetric() {
        let argvals = uniform_grid(25);
        let f1 = sample(&argvals, |t| (2.0 * PI * t).sin());
        let f2 = sample(&argvals, |t| (2.0 * PI * t).cos());

        let d12 = elastic_distance_closed(&f1, &f2, &argvals, 0.0).unwrap();
        let d21 = elastic_distance_closed(&f2, &f1, &argvals, 0.0).unwrap();

        assert!(
            d12 >= 0.0 && d12.is_finite(),
            "d12 should be non-negative finite, got {d12}"
        );
        assert!(
            d21 >= 0.0 && d21.is_finite(),
            "d21 should be non-negative finite, got {d21}"
        );
        // The alignment is not exactly symmetric, but both directions should agree roughly.
        assert!(
            d12.max(d21) < 2.0 * d12.min(d21) + 0.5,
            "closed distances should be in comparable range: d12={d12:.4}, d21={d21:.4}"
        );
    }

    #[test]
    fn closed_karcher_mean_smoke() {
        let (n, m) = (5, 25);
        let max_iter = 10;
        let argvals = uniform_grid(m);

        // Column-major layout: entry (i, j) lives at index i + j * n, so
        // iterate columns in the outer loop and rows in the inner loop.
        let data_flat: Vec<f64> = (0..m)
            .flat_map(|j| {
                let t = argvals[j];
                (0..n).map(move |i| (2.0 * PI * (t + i as f64 * 0.1)).sin())
            })
            .collect();
        let data = FdMatrix::from_column_major(data_flat, n, m).unwrap();

        let result = karcher_mean_closed(&data, &argvals, max_iter, 1e-3, 0.0).unwrap();
        assert_eq!(result.mean.len(), m);
        assert_eq!(result.mean_srsf.len(), m);
        assert_eq!(result.gammas.shape(), (n, m));
        assert_eq!(result.aligned_data.shape(), (n, m));
        assert_eq!(result.rotations.len(), n);
        assert!(result.n_iter <= max_iter);
    }
}