1use super::karcher::karcher_mean;
7use super::pairwise::{elastic_align_pair, elastic_self_distance_matrix};
8use super::srsf::srsf_single;
9use super::{AlignmentResult, KarcherMeanResult};
10use crate::error::FdarError;
11use crate::helpers::simpsons_weights;
12use crate::matrix::FdMatrix;
13use crate::warping::l2_norm_l2;
14
/// Which nuisance transformations are factored out before curves are compared.
///
/// Every variant quotients out reparameterization (time warping); the longer
/// variants additionally remove vertical translation, and translation plus scale.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ShapeQuotient {
    /// Reparameterization (warping) only; the curve values are left untouched.
    /// This is the default.
    #[default]
    Reparameterization,
    /// Reparameterization and vertical translation (the curve is centered).
    ReparameterizationTranslation,
    /// Reparameterization, vertical translation, and scale.
    ReparameterizationTranslationScale,
}
31
/// Canonical element of a curve's orbit under a [`ShapeQuotient`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct OrbitRepresentative {
    /// The curve after the selected nuisance transformations were removed.
    pub representative: Vec<f64>,
    /// SRSF (square-root slope function) of `representative`.
    pub representative_srsf: Vec<f64>,
    /// Warping function associated with the representative; `orbit_representative`
    /// fills this with the evaluation grid itself (identity warp).
    pub gamma: Vec<f64>,
    /// Integral mean subtracted from the curve (`0.0` if translation is kept).
    pub translation: f64,
    /// Positive factor the centered curve was divided by (`1.0` if scale is kept).
    pub scale: f64,
}
47
/// Outcome of comparing two curves modulo a [`ShapeQuotient`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ShapeDistanceResult {
    /// Elastic distance between the two preprocessed curves.
    pub distance: f64,
    /// Warping function produced by the pairwise alignment.
    pub gamma: Vec<f64>,
    /// The preprocessed second curve after alignment to the first.
    pub f2_aligned: Vec<f64>,
}
59
60#[derive(Debug, Clone, PartialEq)]
62#[non_exhaustive]
63pub struct ShapeMeanResult {
64 pub mean: Vec<f64>,
66 pub mean_srsf: Vec<f64>,
68 pub gammas: FdMatrix,
70 pub aligned_data: FdMatrix,
72 pub n_iter: usize,
74 pub converged: bool,
76}
77
78fn integral_mean(f: &[f64], argvals: &[f64]) -> f64 {
82 let w = simpsons_weights(argvals);
83 let total_w: f64 = w.iter().sum();
84 if total_w <= 0.0 {
85 return 0.0;
86 }
87 let wsum: f64 = f.iter().zip(w.iter()).map(|(&fi, &wi)| fi * wi).sum();
88 wsum / total_w
89}
90
91fn preprocess_curve(f: &[f64], argvals: &[f64], quotient: ShapeQuotient) -> (Vec<f64>, f64, f64) {
95 let mut curve = f.to_vec();
96 let mut translation = 0.0;
97 let mut scale = 1.0;
98
99 match quotient {
100 ShapeQuotient::Reparameterization => {
101 }
103 ShapeQuotient::ReparameterizationTranslation => {
104 let mean_val = integral_mean(&curve, argvals);
106 translation = mean_val;
107 for v in &mut curve {
108 *v -= mean_val;
109 }
110 }
111 ShapeQuotient::ReparameterizationTranslationScale => {
112 let mean_val = integral_mean(&curve, argvals);
114 translation = mean_val;
115 for v in &mut curve {
116 *v -= mean_val;
117 }
118
119 let q = srsf_single(&curve, argvals);
121 let m = argvals.len();
123 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1).max(1) as f64).collect();
124 let norm = l2_norm_l2(&q, &time);
125
126 if norm > 1e-10 {
127 scale = norm;
128 for v in &mut curve {
129 *v /= norm;
130 }
131 }
132 }
133 }
134
135 (curve, translation, scale)
136}
137
138fn preprocess_data(data: &FdMatrix, argvals: &[f64], quotient: ShapeQuotient) -> FdMatrix {
140 let (n, m) = data.shape();
141 let mut result = FdMatrix::zeros(n, m);
142 for i in 0..n {
143 let row = data.row(i);
144 let (processed, _, _) = preprocess_curve(&row, argvals, quotient);
145 for j in 0..m {
146 result[(i, j)] = processed[j];
147 }
148 }
149 result
150}
151
152pub fn orbit_representative(
167 f: &[f64],
168 argvals: &[f64],
169 quotient: ShapeQuotient,
170) -> Result<OrbitRepresentative, FdarError> {
171 let m = f.len();
172 if m != argvals.len() {
173 return Err(FdarError::InvalidDimension {
174 parameter: "f",
175 expected: format!("length {}", argvals.len()),
176 actual: format!("length {m}"),
177 });
178 }
179 if m < 2 {
180 return Err(FdarError::InvalidDimension {
181 parameter: "f",
182 expected: "length >= 2".to_string(),
183 actual: format!("length {m}"),
184 });
185 }
186
187 let (representative, translation, scale) = preprocess_curve(f, argvals, quotient);
188 let representative_srsf = srsf_single(&representative, argvals);
189 let gamma = argvals.to_vec(); Ok(OrbitRepresentative {
192 representative,
193 representative_srsf,
194 gamma,
195 translation,
196 scale,
197 })
198}
199
200#[must_use = "expensive computation whose result should not be discarded"]
215pub fn shape_distance(
216 f1: &[f64],
217 f2: &[f64],
218 argvals: &[f64],
219 quotient: ShapeQuotient,
220 lambda: f64,
221) -> Result<ShapeDistanceResult, FdarError> {
222 let m = f1.len();
223 if m != f2.len() || m != argvals.len() {
224 return Err(FdarError::InvalidDimension {
225 parameter: "f1/f2",
226 expected: format!("matching lengths == argvals.len() ({})", argvals.len()),
227 actual: format!("f1.len()={}, f2.len()={}", f1.len(), f2.len()),
228 });
229 }
230 if m < 2 {
231 return Err(FdarError::InvalidDimension {
232 parameter: "f1",
233 expected: "length >= 2".to_string(),
234 actual: format!("length {m}"),
235 });
236 }
237
238 let (f1_pre, _, _) = preprocess_curve(f1, argvals, quotient);
239 let (f2_pre, _, _) = preprocess_curve(f2, argvals, quotient);
240
241 let AlignmentResult {
242 gamma,
243 f_aligned,
244 distance,
245 } = elastic_align_pair(&f1_pre, &f2_pre, argvals, lambda);
246
247 Ok(ShapeDistanceResult {
248 distance,
249 gamma,
250 f2_aligned: f_aligned,
251 })
252}
253
254#[must_use = "expensive computation whose result should not be discarded"]
268pub fn shape_self_distance_matrix(
269 data: &FdMatrix,
270 argvals: &[f64],
271 quotient: ShapeQuotient,
272 lambda: f64,
273) -> Result<FdMatrix, FdarError> {
274 let (_n, m) = data.shape();
275 if argvals.len() != m {
276 return Err(FdarError::InvalidDimension {
277 parameter: "argvals",
278 expected: format!("{m}"),
279 actual: format!("{}", argvals.len()),
280 });
281 }
282
283 let preprocessed = preprocess_data(data, argvals, quotient);
284 Ok(elastic_self_distance_matrix(&preprocessed, argvals, lambda))
285}
286
287#[must_use = "expensive computation whose result should not be discarded"]
304pub fn shape_mean(
305 data: &FdMatrix,
306 argvals: &[f64],
307 quotient: ShapeQuotient,
308 lambda: f64,
309 max_iter: usize,
310 tol: f64,
311) -> Result<ShapeMeanResult, FdarError> {
312 let (n, m) = data.shape();
313 if argvals.len() != m {
314 return Err(FdarError::InvalidDimension {
315 parameter: "argvals",
316 expected: format!("{m}"),
317 actual: format!("{}", argvals.len()),
318 });
319 }
320 if n < 1 {
321 return Err(FdarError::InvalidDimension {
322 parameter: "data",
323 expected: "at least 1 row".to_string(),
324 actual: format!("{n} rows"),
325 });
326 }
327
328 let preprocessed = preprocess_data(data, argvals, quotient);
329 let KarcherMeanResult {
330 mean,
331 mean_srsf,
332 gammas,
333 aligned_data,
334 n_iter,
335 converged,
336 ..
337 } = karcher_mean(&preprocessed, argvals, max_iter, tol, lambda);
338
339 Ok(ShapeMeanResult {
340 mean,
341 mean_srsf,
342 gammas,
343 aligned_data,
344 n_iter,
345 converged,
346 })
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simulation::{sim_fundata, EFunType, EValType};
    use crate::test_helpers::uniform_grid;

    /// Simulates `n` Fourier-basis curves (fixed seed) on a uniform grid of `m` points.
    fn make_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let grid = uniform_grid(m);
        let curves = sim_fundata(
            n,
            &grid,
            3,
            EFunType::Fourier,
            EValType::Exponential,
            Some(99),
        );
        (curves, grid)
    }

    /// Evaluates `g` at every point of `grid`.
    fn sample(grid: &[f64], g: impl Fn(f64) -> f64) -> Vec<f64> {
        grid.iter().map(|&x| g(x)).collect()
    }

    /// Cosine similarity of two vectors; 1.0 if either is numerically zero.
    fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
        let euclid = |v: &[f64]| v.iter().map(|&x| x * x).sum::<f64>().sqrt();
        let dot: f64 = a.iter().zip(b).map(|(&x, &y)| x * y).sum();
        let (norm_a, norm_b) = (euclid(a), euclid(b));
        if norm_a > 1e-10 && norm_b > 1e-10 {
            dot / (norm_a * norm_b)
        } else {
            1.0
        }
    }

    /// Runs `shape_mean` on the standard 6 x 25 simulated sample.
    fn simulated_shape_mean(quotient: ShapeQuotient) -> ShapeMeanResult {
        let (data, t) = make_data(6, 25);
        shape_mean(&data, &t, quotient, 0.0, 5, 1e-2).unwrap()
    }

    #[test]
    fn orbit_representative_reparam_only() {
        let t = uniform_grid(30);
        let f = sample(&t, |x| (x * 6.0).sin());
        let rep = orbit_representative(&f, &t, ShapeQuotient::Reparameterization).unwrap();
        assert_eq!(rep.representative.len(), 30);
        for (got, want) in rep.representative.iter().zip(&f) {
            assert!(
                (got - want).abs() < 1e-12,
                "reparameterization-only orbit should not change the curve"
            );
        }
        assert!(rep.translation.abs() < f64::EPSILON);
        assert!((rep.scale - 1.0).abs() < f64::EPSILON);
        assert_eq!(rep.gamma, t);
    }

    #[test]
    fn orbit_representative_translation() {
        let t = uniform_grid(30);
        let offset = 5.0;
        let f = sample(&t, |x| (x * 6.0).sin() + offset);
        let rep =
            orbit_representative(&f, &t, ShapeQuotient::ReparameterizationTranslation).unwrap();
        let mean_after = integral_mean(&rep.representative, &t);
        assert!(
            mean_after.abs() < 1e-10,
            "translation quotient should center the curve, mean={mean_after}"
        );
    }

    #[test]
    fn orbit_representative_translation_scale() {
        let t = uniform_grid(50);
        let quotient = ShapeQuotient::ReparameterizationTranslationScale;

        let f = sample(&t, |x| 10.0 * (x * 4.0).sin() + 3.0);
        let rep = orbit_representative(&f, &t, quotient).unwrap();
        assert!(rep.scale > 0.0, "scale factor should be positive");

        // Same shape at twice the amplitude.
        let f2 = sample(&t, |x| 20.0 * (x * 4.0).sin() + 3.0);
        let rep2 = orbit_representative(&f2, &t, quotient).unwrap();

        let corr = cosine_similarity(&rep.representative, &rep2.representative);
        assert!(
            corr > 0.99,
            "scaled curves should have nearly identical representatives, corr={corr}"
        );
    }

    #[test]
    fn orbit_representative_length_mismatch() {
        let t = uniform_grid(30);
        let f = vec![1.0; 20];
        assert!(orbit_representative(&f, &t, ShapeQuotient::Reparameterization).is_err());
    }

    #[test]
    fn orbit_representative_too_short() {
        let f = vec![1.0];
        let t = vec![0.0];
        assert!(orbit_representative(&f, &t, ShapeQuotient::Reparameterization).is_err());
    }

    #[test]
    fn shape_distance_identical_curves() {
        let t = uniform_grid(30);
        let f = sample(&t, |x| (x * 6.0).sin());
        let result = shape_distance(&f, &f, &t, ShapeQuotient::Reparameterization, 0.0).unwrap();
        assert!(
            result.distance < 0.1,
            "distance between identical curves should be near zero, got {}",
            result.distance
        );
        assert_eq!(result.gamma.len(), 30);
        assert_eq!(result.f2_aligned.len(), 30);
    }

    #[test]
    fn shape_distance_translated_curves() {
        let t = uniform_grid(30);
        let f1 = sample(&t, |x| (x * 6.0).sin());
        let f2 = sample(&t, |x| (x * 6.0).sin() + 5.0);

        let distance_under = |quotient| {
            shape_distance(&f1, &f2, &t, quotient, 0.0)
                .unwrap()
                .distance
        };
        let d_no_trans = distance_under(ShapeQuotient::Reparameterization);
        let d_trans = distance_under(ShapeQuotient::ReparameterizationTranslation);

        assert!(
            d_trans < d_no_trans + 0.01,
            "translation quotient should not increase distance: d_trans={}, d_no_trans={}",
            d_trans,
            d_no_trans
        );
    }

    #[test]
    fn shape_distance_length_mismatch() {
        let t = uniform_grid(30);
        let f1 = vec![0.0; 30];
        let f2 = vec![0.0; 20];
        assert!(shape_distance(&f1, &f2, &t, ShapeQuotient::Reparameterization, 0.0).is_err());
    }

    #[test]
    fn shape_distance_matrix_smoke() {
        let (data, t) = make_data(5, 20);
        let dmat =
            shape_self_distance_matrix(&data, &t, ShapeQuotient::Reparameterization, 0.0).unwrap();
        assert_eq!(dmat.shape(), (5, 5));
        for i in 0..5 {
            assert!(
                dmat[(i, i)].abs() < 1e-10,
                "diagonal should be zero, got {}",
                dmat[(i, i)]
            );
            for j in (i + 1)..5 {
                assert!(
                    (dmat[(i, j)] - dmat[(j, i)]).abs() < 1e-10,
                    "distance matrix should be symmetric"
                );
            }
        }
    }

    #[test]
    fn shape_distance_matrix_argvals_mismatch() {
        let (data, _) = make_data(5, 20);
        let bad_t = uniform_grid(15);
        let outcome =
            shape_self_distance_matrix(&data, &bad_t, ShapeQuotient::Reparameterization, 0.0);
        assert!(outcome.is_err());
    }

    #[test]
    fn shape_mean_smoke() {
        let result = simulated_shape_mean(ShapeQuotient::Reparameterization);
        assert_eq!(result.mean.len(), 25);
        assert_eq!(result.mean_srsf.len(), 25);
        assert_eq!(result.gammas.shape(), (6, 25));
        assert_eq!(result.aligned_data.shape(), (6, 25));
        assert!(result.n_iter >= 1);
    }

    #[test]
    fn shape_mean_translation_quotient() {
        let result = simulated_shape_mean(ShapeQuotient::ReparameterizationTranslation);
        assert_eq!(result.mean.len(), 25);
    }

    #[test]
    fn shape_mean_full_quotient() {
        let result = simulated_shape_mean(ShapeQuotient::ReparameterizationTranslationScale);
        assert_eq!(result.mean.len(), 25);
    }

    #[test]
    fn shape_mean_argvals_mismatch() {
        let (data, _) = make_data(5, 25);
        let bad_t = uniform_grid(15);
        let outcome = shape_mean(
            &data,
            &bad_t,
            ShapeQuotient::Reparameterization,
            0.0,
            5,
            1e-2,
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn default_quotient() {
        assert_eq!(ShapeQuotient::default(), ShapeQuotient::Reparameterization);
    }
}