1use crate::fdata::{deriv_1d, mean_1d};
15use crate::helpers::{l2_distance, simpsons_weights};
16use crate::iter_maybe_parallel;
17use crate::matrix::FdMatrix;
18#[cfg(feature = "parallel")]
19use rayon::iter::ParallelIterator;
20
/// Result of aligning one curve to another with [`elastic_align_pair`].
///
/// `PartialEq` is derived (all fields are plain `f64` data) so results can be
/// compared directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct AlignmentResult {
    /// Warping function γ sampled on `argvals`; the aligned curve is `f2 ∘ γ`.
    pub gamma: Vec<f64>,
    /// The second curve re-parameterized by `gamma`, sampled on `argvals`.
    pub f_aligned: Vec<f64>,
    /// Simpson-weighted L2 distance between the SRSF of the first curve and the
    /// SRSF of `f_aligned`.
    pub distance: f64,
}
33
/// Result of aligning every row of a data matrix to a common target
/// (see `align_to_target`).
#[derive(Debug, Clone)]
pub struct AlignmentSetResult {
    /// Warping functions, one row per input curve (n x m), each sampled on `argvals`.
    pub gammas: FdMatrix,
    /// Input curves re-parameterized by the matching row of `gammas` (n x m).
    pub aligned_data: FdMatrix,
    /// Per-curve elastic distance to the target, in row order.
    pub distances: Vec<f64>,
}
44
/// Output of `karcher_mean`.
#[derive(Debug, Clone)]
pub struct KarcherMeanResult {
    /// Mean curve, rebuilt from `mean_srsf` with `srsf_inverse`, starting at the
    /// first value of `mean_1d(data)`.
    pub mean: Vec<f64>,
    /// Mean SRSF after the final centering step.
    pub mean_srsf: Vec<f64>,
    /// Centered warping functions, one row per input curve (n x m).
    pub gammas: FdMatrix,
    /// Input curves re-parameterized by the matching row of `gammas`.
    pub aligned_data: FdMatrix,
    /// Number of outer alignment iterations performed.
    pub n_iter: usize,
    /// Whether the loop stopped via its convergence test rather than by
    /// exhausting `max_iter`.
    pub converged: bool,
}
61
/// Piecewise-linear interpolation of the points `(x[k], y[k])` at `t`.
///
/// `x` must be sorted in non-decreasing order and `y` must be at least as long
/// as `x`. Values of `t` outside `[x[0], x[last]]` are clamped to the end
/// values (constant extrapolation). A NaN `t` yields NaN, and NaN entries in
/// `x` produce NaN/garbage values rather than a panic.
///
/// # Panics
/// Panics if `x` is empty.
fn linear_interp(x: &[f64], y: &[f64], t: f64) -> f64 {
    if t.is_nan() {
        return f64::NAN;
    }
    if t <= x[0] {
        return y[0];
    }
    let last = x.len() - 1;
    if last == 0 || t >= x[last] {
        return y[last];
    }

    // First index whose abscissa is >= t. Unlike `binary_search_by` with
    // `partial_cmp().unwrap()`, `partition_point` cannot panic on NaN; the clamp
    // keeps the bracketing pair in bounds even if `x` is not well ordered.
    let idx = x.partition_point(|&v| v < t).clamp(1, last);
    if x[idx] == t {
        return y[idx];
    }

    let (t0, t1) = (x[idx - 1], x[idx]);
    let (y0, y1) = (y[idx - 1], y[idx]);
    y0 + (y1 - y0) * (t - t0) / (t1 - t0)
}
86
/// Running trapezoid-rule integral of `y` over the abscissae `x`.
///
/// Entry `k` is the integral from `x[0]` to `x[k]`, so entry 0 is always `0.0`.
/// The output has `y.len()` entries; `x` must be at least as long as `y`.
fn cumulative_trapz(y: &[f64], x: &[f64]) -> Vec<f64> {
    let mut running = Vec::with_capacity(y.len());
    let mut area = 0.0;
    for k in 0..y.len() {
        if k > 0 {
            area += 0.5 * (y[k] + y[k - 1]) * (x[k] - x[k - 1]);
        }
        running.push(area);
    }
    running
}
96
/// Repairs a discretized warping function in place so it is a valid warp on
/// the grid `argvals`.
///
/// After the call (for non-empty `gamma` without NaNs):
/// * `gamma[0] == argvals[0]` and `gamma[n - 1] == argvals[n - 1]`;
/// * `gamma` is non-decreasing;
/// * no value exceeds `argvals[n - 1]`.
///
/// `argvals` must be at least as long as `gamma`. An empty `gamma` is left untouched.
fn normalize_warp(gamma: &mut [f64], argvals: &[f64]) {
    let n = gamma.len();
    if n == 0 {
        return;
    }

    let start = argvals[0];
    let end = argvals[n - 1];
    gamma[0] = start;
    gamma[n - 1] = end;

    // Forward pass: raise any value that dips below its predecessor.
    for i in 1..n {
        if gamma[i] < gamma[i - 1] {
            gamma[i] = gamma[i - 1];
        }
    }

    // An interior value above `end` would have been carried into the last
    // sample by the forward pass, breaking the endpoint constraint. Capping at
    // `end` restores it and keeps the sequence non-decreasing.
    for g in gamma.iter_mut() {
        if *g > end {
            *g = end;
        }
    }
}
115
/// Trapezoid-rule integral of `y` sampled at `x`.
///
/// Returns `0.0` for fewer than two samples; `x` must be at least as long as `y`.
fn trapz(y: &[f64], x: &[f64]) -> f64 {
    (1..y.len()).fold(0.0, |total, k| {
        total + 0.5 * (y[k] + y[k - 1]) * (x[k] - x[k - 1])
    })
}
128
/// Finite-difference derivative of `y` on a uniform grid with spacing `h`.
///
/// Forward difference at the first sample, backward difference at the last,
/// central differences in between. Inputs with fewer than two samples yield a
/// zero vector of the same length.
fn gradient_uniform(y: &[f64], h: f64) -> Vec<f64> {
    let len = y.len();
    if len < 2 {
        return vec![0.0; len];
    }
    let last = len - 1;
    (0..len)
        .map(|i| {
            if i == 0 {
                (y[1] - y[0]) / h
            } else if i == last {
                (y[last] - y[last - 1]) / h
            } else {
                (y[i + 1] - y[i - 1]) / (2.0 * h)
            }
        })
        .collect()
}
143
144fn gam_to_psi(gam: &[f64], h: f64) -> Vec<f64> {
146 gradient_uniform(gam, h)
147 .iter()
148 .map(|&g| g.max(0.0).sqrt())
149 .collect()
150}
151
152fn psi_to_gam(psi: &[f64], time: &[f64]) -> Vec<f64> {
154 let psi_sq: Vec<f64> = psi.iter().map(|&p| p * p).collect();
155 let gam = cumulative_trapz(&psi_sq, time);
156 let min_val = gam.iter().cloned().fold(f64::INFINITY, f64::min);
157 let max_val = gam.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
158 let range = (max_val - min_val).max(1e-10);
159 gam.iter().map(|&v| (v - min_val) / range).collect()
160}
161
162fn inner_product_l2(psi1: &[f64], psi2: &[f64], time: &[f64]) -> f64 {
164 let prod: Vec<f64> = psi1.iter().zip(psi2.iter()).map(|(&a, &b)| a * b).collect();
165 trapz(&prod, time)
166}
167
168fn l2_norm_l2(psi: &[f64], time: &[f64]) -> f64 {
170 inner_product_l2(psi, psi, time).max(0.0).sqrt()
171}
172
173fn inv_exp_map_sphere(mu: &[f64], psi: &[f64], time: &[f64]) -> Vec<f64> {
176 let ip = inner_product_l2(mu, psi, time).clamp(-1.0, 1.0);
177 let theta = ip.acos();
178 if theta < 1e-10 {
179 vec![0.0; mu.len()]
180 } else {
181 let coeff = theta / theta.sin();
182 let cos_theta = theta.cos();
183 mu.iter()
184 .zip(psi.iter())
185 .map(|(&m, &p)| coeff * (p - cos_theta * m))
186 .collect()
187 }
188}
189
190fn exp_map_sphere(psi: &[f64], v: &[f64], time: &[f64]) -> Vec<f64> {
193 let v_norm = l2_norm_l2(v, time);
194 if v_norm < 1e-10 {
195 psi.to_vec()
196 } else {
197 let cos_n = v_norm.cos();
198 let sin_n = v_norm.sin();
199 psi.iter()
200 .zip(v.iter())
201 .map(|(&p, &vi)| cos_n * p + sin_n * vi / v_norm)
202 .collect()
203 }
204}
205
206fn invert_gamma(gam: &[f64], time: &[f64]) -> Vec<f64> {
209 let n = time.len();
210 let mut gam_inv: Vec<f64> = time.iter().map(|&t| linear_interp(gam, time, t)).collect();
213 gam_inv[0] = time[0];
214 gam_inv[n - 1] = time[n - 1];
215 gam_inv
216}
217
218fn sqrt_mean_inverse(gammas: &FdMatrix, argvals: &[f64]) -> Vec<f64> {
226 let (n, m) = gammas.shape();
227 let t0 = argvals[0];
228 let t1 = argvals[m - 1];
229 let domain = t1 - t0;
230
231 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
233 let binsize = 1.0 / (m - 1) as f64;
234
235 let mut psis: Vec<Vec<f64>> = Vec::with_capacity(n);
237 for i in 0..n {
238 let gam_01: Vec<f64> = (0..m).map(|j| (gammas[(i, j)] - t0) / domain).collect();
239 psis.push(gam_to_psi(&gam_01, binsize));
240 }
241
242 let mut mu = vec![0.0; m];
244 for psi in &psis {
245 for j in 0..m {
246 mu[j] += psi[j];
247 }
248 }
249 for j in 0..m {
250 mu[j] /= n as f64;
251 }
252
253 let step_size = 0.3;
255 let max_iter = 501;
256
257 for _ in 0..max_iter {
258 let mut vbar = vec![0.0; m];
260 for psi in &psis {
261 let v = inv_exp_map_sphere(&mu, psi, &time);
262 for j in 0..m {
263 vbar[j] += v[j];
264 }
265 }
266 for j in 0..m {
267 vbar[j] /= n as f64;
268 }
269
270 if l2_norm_l2(&vbar, &time) <= 1e-8 {
272 break;
273 }
274
275 let scaled: Vec<f64> = vbar.iter().map(|&v| v * step_size).collect();
277 mu = exp_map_sphere(&mu, &scaled, &time);
278 }
279
280 let gam_mu = psi_to_gam(&mu, &time);
282 let gam_inv = invert_gamma(&gam_mu, &time);
283
284 gam_inv.iter().map(|&g| t0 + g * domain).collect()
286}
287
288pub fn srsf_transform(data: &FdMatrix, argvals: &[f64]) -> FdMatrix {
301 let (n, m) = data.shape();
302 if n == 0 || m == 0 || argvals.len() != m {
303 return FdMatrix::zeros(n, m);
304 }
305
306 let deriv = deriv_1d(data, argvals, 1);
307
308 let mut result = FdMatrix::zeros(n, m);
309 for i in 0..n {
310 for j in 0..m {
311 let d = deriv[(i, j)];
312 result[(i, j)] = d.signum() * d.abs().sqrt();
313 }
314 }
315 result
316}
317
318pub fn srsf_inverse(q: &[f64], argvals: &[f64], f0: f64) -> Vec<f64> {
330 let m = q.len();
331 if m == 0 {
332 return Vec::new();
333 }
334
335 let integrand: Vec<f64> = q.iter().map(|&qi| qi * qi.abs()).collect();
337 let integral = cumulative_trapz(&integrand, argvals);
338
339 integral.iter().map(|&v| f0 + v).collect()
340}
341
342pub fn reparameterize_curve(f: &[f64], argvals: &[f64], gamma: &[f64]) -> Vec<f64> {
353 gamma
354 .iter()
355 .map(|&g| linear_interp(argvals, f, g))
356 .collect()
357}
358
359pub fn compose_warps(gamma1: &[f64], gamma2: &[f64], argvals: &[f64]) -> Vec<f64> {
366 gamma2
367 .iter()
368 .map(|&g| linear_interp(argvals, gamma1, g))
369 .collect()
370}
371
/// Greatest common divisor by the Euclidean algorithm (`gcd(a, 0) == a`).
#[cfg(test)]
fn gcd(a: usize, b: usize) -> usize {
    let mut x = a;
    let mut y = b;
    while y != 0 {
        let remainder = x % y;
        x = y;
        y = remainder;
    }
    x
}
384
/// All pairs `(i, j)` with `1 <= i, j <= nbhd_dim` and `gcd(i, j) == 1`,
/// ordered row by row (`i` outer, `j` inner).
#[cfg(test)]
fn generate_coprime_nbhd(nbhd_dim: usize) -> Vec<(usize, usize)> {
    (1..=nbhd_dim)
        .flat_map(|i| (1..=nbhd_dim).map(move |j| (i, j)))
        .filter(|&(i, j)| gcd(i, j) == 1)
        .collect()
}
399
/// Step sizes `(dr, dc)` that `dp_alignment_core` tries when relaxing a DP node:
/// every pair with `1 <= dr, dc <= 7` and `gcd(dr, dc) == 1`, listed row by row
/// (35 pairs). It is the same list, in the same order, that
/// `generate_coprime_nbhd(7)` produces.
///
/// Non-coprime steps are multiples of a shorter step and are presumably omitted
/// because the DP can compose them from shorter ones.
#[rustfmt::skip]
const COPRIME_NBHD_7: [(usize, usize); 35] = [
    (1,1),(1,2),(1,3),(1,4),(1,5),(1,6),(1,7),
    (2,1), (2,3), (2,5), (2,7),
    (3,1),(3,2), (3,4),(3,5), (3,7),
    (4,1), (4,3), (4,5), (4,7),
    (5,1),(5,2),(5,3),(5,4), (5,6),(5,7),
    (6,1), (6,5), (6,7),
    (7,1),(7,2),(7,3),(7,4),(7,5),(7,6),
];
413
/// Cost of the DP edge from node `(sr, sc)` to node `(tr, tc)`.
///
/// The `q1` samples `sc..tc` and the `q2` samples `sr..tr` are each mapped onto
/// `[0, 1]`. Over the common refinement of the two partitions, the squared
/// difference `q1 - sqrt(slope) * q2` is integrated. Here
/// `slope = Δargvals(q2 side) / Δargvals(q1 side)`. The sum is then scaled by
/// the `q1`-side interval length.
///
/// Degenerate edges (zero length on either side) cost `f64::INFINITY`.
#[inline]
fn dp_edge_weight(
    q1: &[f64],
    q2: &[f64],
    argvals: &[f64],
    sc: usize,
    tc: usize,
    sr: usize,
    tr: usize,
) -> f64 {
    let len1 = tc - sc;
    let len2 = tr - sr;
    if len1 == 0 || len2 == 0 {
        return f64::INFINITY;
    }

    let span = argvals[tc] - argvals[sc];
    let root_slope = ((argvals[tr] - argvals[sr]) / span).sqrt();
    let (steps1, steps2) = (len1 as f64, len2 as f64);

    let mut total = 0.0;
    let (mut a, mut b) = (0usize, 0usize);
    while a < len1 && b < len2 {
        let end1 = (a + 1) as f64 / steps1;
        let end2 = (b + 1) as f64 / steps2;
        let start = (a as f64 / steps1).max(b as f64 / steps2);
        let overlap = end1.min(end2) - start;

        if overlap > 0.0 {
            let resid = q1[sc + a] - root_slope * q2[sr + b];
            total += resid * resid * overlap;
        }

        // Advance whichever sub-interval ends first (both when they end together).
        match end1.partial_cmp(&end2) {
            Some(std::cmp::Ordering::Less) => a += 1,
            Some(std::cmp::Ordering::Greater) => b += 1,
            _ => {
                a += 1;
                b += 1;
            }
        }
    }

    total * span
}
475
/// Dynamic-programming search for the warp that best aligns SRSF `q2` to SRSF `q1`.
///
/// Both SRSFs are first scaled to unit Euclidean norm over the grid (norms
/// floored at 1e-10). The DP runs on an m x m node grid where node `(tr, tc)`
/// pairs `argvals[tr]` (position along `q2`) with `argvals[tc]` (position along
/// `q1`). Edges use the steps in `COPRIME_NBHD_7` and cost `dp_edge_weight`.
///
/// The cheapest path from `(0, 0)` to `(m - 1, m - 1)` is traced back and
/// linearly interpolated onto `argvals`. Along the path
/// `gamma(argvals[tc]) = argvals[tr]`, so `q2 ∘ gamma` is what gets matched
/// against `q1`. The result is passed through `normalize_warp`.
///
/// Grids with fewer than two points return `argvals` unchanged (identity warp).
///
/// Cost: m^2 nodes x 35 candidate steps each; memory is two m*m tables
/// (f64 costs, u32 parents).
fn dp_alignment_core(q1: &[f64], q2: &[f64], argvals: &[f64]) -> Vec<f64> {
    let m = argvals.len();
    if m < 2 {
        return argvals.to_vec();
    }

    // Scale-normalize so the DP cost does not depend on curve amplitude.
    let norm1 = q1.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
    let norm2 = q2.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
    let q1n: Vec<f64> = q1.iter().map(|&v| v / norm1).collect();
    let q2n: Vec<f64> = q2.iter().map(|&v| v / norm2).collect();

    // e[tr * m + tc] = best cost to reach node (tr, tc); parent = predecessor index.
    // Parents are stored as u32 with u32::MAX meaning "none", which assumes
    // m * m - 1 < u32::MAX (i.e. m <= 65535).
    let mut e = vec![f64::INFINITY; m * m];
    let mut parent = vec![u32::MAX; m * m];
    e[0] = 0.0;

    // Every step has dr, dc >= 1, so row 0 and column 0 (other than the origin)
    // are never reachable and are skipped.
    for tr in 1..m {
        for tc in 1..m {
            let idx = tr * m + tc;
            for &(dr, dc) in &COPRIME_NBHD_7 {
                if dr > tr || dc > tc {
                    continue;
                }
                let sr = tr - dr;
                let sc = tc - dc;
                let src_idx = sr * m + sc;
                if e[src_idx] == f64::INFINITY {
                    continue;
                }
                let w = dp_edge_weight(&q1n, &q2n, argvals, sc, tc, sr, tr);
                let cost = e[src_idx] + w;
                // A NaN cost never compares less, so NaN edges are never taken.
                if cost < e[idx] {
                    e[idx] = cost;
                    parent[idx] = src_idx as u32;
                }
            }
        }
    }

    // Trace back from the end node. If it was never reached (e.g. every cost was
    // NaN), the loop stops immediately and the path holds only that node.
    let mut path_tc = Vec::with_capacity(2 * m);
    let mut path_tr = Vec::with_capacity(2 * m);
    let mut cur = (m - 1) * m + (m - 1);
    loop {
        let tr = cur / m;
        let tc = cur % m;
        path_tc.push(argvals[tc]);
        path_tr.push(argvals[tr]);
        if cur == 0 {
            break;
        }
        if parent[cur] == u32::MAX {
            break;
        }
        cur = parent[cur] as usize;
    }

    // Restore start-to-end order; path_tc is then strictly increasing (dc >= 1),
    // as linear_interp requires of its abscissae.
    path_tc.reverse();
    path_tr.reverse();

    let mut gamma: Vec<f64> = argvals
        .iter()
        .map(|&t| linear_interp(&path_tc, &path_tr, t))
        .collect();

    normalize_warp(&mut gamma, argvals);
    gamma
}
554
555pub fn elastic_align_pair(f1: &[f64], f2: &[f64], argvals: &[f64]) -> AlignmentResult {
570 let m = f1.len();
571
572 let f1_mat = FdMatrix::from_slice(f1, 1, m).unwrap();
574 let f2_mat = FdMatrix::from_slice(f2, 1, m).unwrap();
575
576 let q1_mat = srsf_transform(&f1_mat, argvals);
577 let q2_mat = srsf_transform(&f2_mat, argvals);
578
579 let q1: Vec<f64> = q1_mat.row(0);
580 let q2: Vec<f64> = q2_mat.row(0);
581
582 let gamma = dp_alignment_core(&q1, &q2, argvals);
584
585 let f_aligned = reparameterize_curve(f2, argvals, &gamma);
587
588 let f_aligned_mat = FdMatrix::from_slice(&f_aligned, 1, m).unwrap();
590 let q_aligned_mat = srsf_transform(&f_aligned_mat, argvals);
591 let q_aligned: Vec<f64> = q_aligned_mat.row(0);
592
593 let weights = simpsons_weights(argvals);
594 let distance = l2_distance(&q1, &q_aligned, &weights);
595
596 AlignmentResult {
597 gamma,
598 f_aligned,
599 distance,
600 }
601}
602
603pub fn elastic_distance(f1: &[f64], f2: &[f64], argvals: &[f64]) -> f64 {
612 elastic_align_pair(f1, f2, argvals).distance
613}
614
615pub fn align_to_target(data: &FdMatrix, target: &[f64], argvals: &[f64]) -> AlignmentSetResult {
625 let (n, m) = data.shape();
626
627 let results: Vec<AlignmentResult> = iter_maybe_parallel!(0..n)
628 .map(|i| {
629 let fi = data.row(i);
630 elastic_align_pair(target, &fi, argvals)
631 })
632 .collect();
633
634 let mut gammas = FdMatrix::zeros(n, m);
635 let mut aligned_data = FdMatrix::zeros(n, m);
636 let mut distances = Vec::with_capacity(n);
637
638 for (i, r) in results.into_iter().enumerate() {
639 for j in 0..m {
640 gammas[(i, j)] = r.gamma[j];
641 aligned_data[(i, j)] = r.f_aligned[j];
642 }
643 distances.push(r.distance);
644 }
645
646 AlignmentSetResult {
647 gammas,
648 aligned_data,
649 distances,
650 }
651}
652
653pub fn elastic_self_distance_matrix(data: &FdMatrix, argvals: &[f64]) -> FdMatrix {
667 let n = data.nrows();
668
669 let upper_vals: Vec<f64> = iter_maybe_parallel!(0..n)
670 .flat_map(|i| {
671 let fi = data.row(i);
672 ((i + 1)..n)
673 .map(|j| {
674 let fj = data.row(j);
675 elastic_distance(&fi, &fj, argvals)
676 })
677 .collect::<Vec<_>>()
678 })
679 .collect();
680
681 let mut dist = FdMatrix::zeros(n, n);
682 let mut idx = 0;
683 for i in 0..n {
684 for j in (i + 1)..n {
685 let d = upper_vals[idx];
686 dist[(i, j)] = d;
687 dist[(j, i)] = d;
688 idx += 1;
689 }
690 }
691 dist
692}
693
694pub fn elastic_cross_distance_matrix(
704 data1: &FdMatrix,
705 data2: &FdMatrix,
706 argvals: &[f64],
707) -> FdMatrix {
708 let n1 = data1.nrows();
709 let n2 = data2.nrows();
710
711 let vals: Vec<f64> = iter_maybe_parallel!(0..n1)
712 .flat_map(|i| {
713 let fi = data1.row(i);
714 (0..n2)
715 .map(|j| {
716 let fj = data2.row(j);
717 elastic_distance(&fi, &fj, argvals)
718 })
719 .collect::<Vec<_>>()
720 })
721 .collect();
722
723 let mut dist = FdMatrix::zeros(n1, n2);
724 for i in 0..n1 {
725 for j in 0..n2 {
726 dist[(i, j)] = vals[i * n2 + j];
727 }
728 }
729 dist
730}
731
/// Relative Euclidean change `‖q_old - q_new‖ / ‖q_old‖` (denominator floored
/// at 1e-10). Compares only the overlapping prefix of the two slices.
fn relative_change(q_old: &[f64], q_new: &[f64]) -> f64 {
    let diff_sq: f64 = q_old
        .iter()
        .zip(q_new)
        .map(|(a, b)| (a - b).powi(2))
        .sum();
    let old_sq: f64 = q_old.iter().map(|v| v * v).sum();
    let denom = old_sq.sqrt().max(1e-10);
    diff_sq.sqrt() / denom
}
748
749fn srsf_single(f: &[f64], argvals: &[f64]) -> Vec<f64> {
751 let m = f.len();
752 let mat = FdMatrix::from_slice(f, 1, m).unwrap();
753 let q_mat = srsf_transform(&mat, argvals);
754 q_mat.row(0)
755}
756
757fn align_srsf_pair(q1: &[f64], q2: &[f64], argvals: &[f64]) -> (Vec<f64>, Vec<f64>) {
759 let gamma = dp_alignment_core(q1, q2, argvals);
760
761 let q2_warped = reparameterize_curve(q2, argvals, &gamma);
763
764 let m = gamma.len();
766 let mut gamma_dot = vec![0.0; m];
767 gamma_dot[0] = (gamma[1] - gamma[0]) / (argvals[1] - argvals[0]);
768 for j in 1..(m - 1) {
769 gamma_dot[j] = (gamma[j + 1] - gamma[j - 1]) / (argvals[j + 1] - argvals[j - 1]);
770 }
771 gamma_dot[m - 1] = (gamma[m - 1] - gamma[m - 2]) / (argvals[m - 1] - argvals[m - 2]);
772
773 let q2_aligned: Vec<f64> = q2_warped
775 .iter()
776 .zip(gamma_dot.iter())
777 .map(|(&q, &gd)| q * gd.max(0.0).sqrt())
778 .collect();
779
780 (gamma, q2_aligned)
781}
782
783fn accumulate_alignments(
812 results: &[(Vec<f64>, Vec<f64>)],
813 gammas: &mut FdMatrix,
814 m: usize,
815 n: usize,
816) -> Vec<f64> {
817 let mut mu_q_new = vec![0.0; m];
818 for (i, (gamma, q_aligned)) in results.iter().enumerate() {
819 for j in 0..m {
820 gammas[(i, j)] = gamma[j];
821 mu_q_new[j] += q_aligned[j];
822 }
823 }
824 for j in 0..m {
825 mu_q_new[j] /= n as f64;
826 }
827 mu_q_new
828}
829
830fn apply_stored_warps(data: &FdMatrix, gammas: &FdMatrix, argvals: &[f64]) -> FdMatrix {
832 let (n, m) = data.shape();
833 let mut aligned = FdMatrix::zeros(n, m);
834 for i in 0..n {
835 let fi = data.row(i);
836 let gamma: Vec<f64> = (0..m).map(|j| gammas[(i, j)]).collect();
837 let f_aligned = reparameterize_curve(&fi, argvals, &gamma);
838 for j in 0..m {
839 aligned[(i, j)] = f_aligned[j];
840 }
841 }
842 aligned
843}
844
/// Computes an elastic (Karcher) mean of the curves in `data`. It alternates
/// DP alignment of every curve to the current template SRSF with averaging of
/// the aligned SRSFs.
///
/// Steps, as implemented below:
/// 1. Template: the curve whose SRSF is closest (squared Euclidean distance over
///    the grid) to the pointwise mean SRSF.
/// 2. Pre-centering: align every curve to the template, then warp the template
///    by the inverse mean warp from `sqrt_mean_inverse`.
/// 3. Up to `max_iter` rounds: align all curves to `mu_q` and replace `mu_q`
///    with the mean of the aligned SRSFs. May stop early (see NOTE in the loop).
/// 4. Centering: compose every stored warp with the inverse mean warp and
///    transform the mean SRSF to match. Then rebuild the mean curve and the
///    aligned data.
///
/// # Arguments
/// * `data` - n x m matrix, one curve per row (read via `data.row(i)`).
/// * `argvals` - the m grid points.
/// * `max_iter` - maximum number of outer iterations.
/// * `tol` - tolerance used in the convergence test.
///
/// NOTE(review): with n == 0, `srsf_mat.row(0)` is called on an empty matrix,
/// and with m < 2, `(m - 1)` underflows or divides by zero. Presumably both are
/// unsupported inputs — TODO confirm and guard.
pub fn karcher_mean(
    data: &FdMatrix,
    argvals: &[f64],
    max_iter: usize,
    tol: f64,
) -> KarcherMeanResult {
    let (n, m) = data.shape();

    // Step 1: pick the curve whose SRSF is nearest the pointwise mean SRSF.
    let srsf_mat = srsf_transform(data, argvals);
    let mnq = mean_1d(&srsf_mat);
    let mut min_dist = f64::INFINITY;
    let mut min_idx = 0;
    for i in 0..n {
        let dist_sq: f64 = (0..m).map(|j| (srsf_mat[(i, j)] - mnq[j]).powi(2)).sum();
        if dist_sq < min_dist {
            min_dist = dist_sq;
            min_idx = i;
        }
    }
    let mut mu_q = srsf_mat.row(min_idx);
    let mut mu = data.row(min_idx);

    // Step 2: center the template so it is not biased toward its own timing.
    {
        let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
            .map(|i| {
                let fi = data.row(i);
                let qi = srsf_single(&fi, argvals);
                align_srsf_pair(&mu_q, &qi, argvals)
            })
            .collect();

        let mut init_gammas = FdMatrix::zeros(n, m);
        for (i, (gamma, _)) in align_results.iter().enumerate() {
            for j in 0..m {
                init_gammas[(i, j)] = gamma[j];
            }
        }

        let gam_inv = sqrt_mean_inverse(&init_gammas, argvals);
        mu = reparameterize_curve(&mu, argvals, &gam_inv);
        mu_q = srsf_single(&mu, argvals);
    }

    let mut converged = false;
    let mut n_iter = 0;
    let mut final_gammas = FdMatrix::zeros(n, m);
    let mut prev_rel = 0.0_f64;

    // Step 3: alternate alignment and averaging.
    for iter in 0..max_iter {
        n_iter = iter + 1;

        let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
            .map(|i| {
                let fi = data.row(i);
                let qi = srsf_single(&fi, argvals);
                align_srsf_pair(&mu_q, &qi, argvals)
            })
            .collect();

        let mu_q_new = accumulate_alignments(&align_results, &mut final_gammas, m, n);

        let rel = relative_change(&mu_q, &mu_q_new);
        // NOTE(review): `rel - prev_rel <= tol * prev_rel` is `rel <= (1 + tol) * prev_rel`.
        // From the second iteration on, the loop therefore stops as soon as the
        // relative change does not grow by more than a fraction `tol`. That
        // includes the case where it is still shrinking. If the intent was "stop
        // once the change stops decreasing", the sign looks inverted
        // (`prev_rel - rel <= tol * prev_rel`); if it was a plain threshold,
        // `rel < tol`. Confirm the intended criterion.
        if rel < f64::EPSILON || (iter > 0 && rel - prev_rel <= tol * prev_rel) {
            converged = true;
            mu_q = mu_q_new;
            break;
        }
        prev_rel = rel;

        mu_q = mu_q_new;
        // NOTE(review): after this point `mu` is only read as `mu[0]` here, and
        // srsf_inverse returns `mu[0]` unchanged as its first element; `mu` is
        // recomputed from scratch after the loop. This line appears to have no
        // effect on the result — confirm before removing.
        mu = srsf_inverse(&mu_q, argvals, mu[0]);
    }

    // Step 4: center the warps so their mean is (approximately) the identity.
    let gam_inv = sqrt_mean_inverse(&final_gammas, argvals);
    // gradient_uniform uses a single spacing h taken from the endpoints, so this
    // derivative assumes `argvals` is uniformly spaced.
    let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
    let gam_inv_dev = gradient_uniform(&gam_inv, h);

    // SRSF group action: (q ∘ γ⁻¹) · sqrt((γ⁻¹)').
    let mu_q_warped = reparameterize_curve(&mu_q, argvals, &gam_inv);
    mu_q = mu_q_warped
        .iter()
        .zip(gam_inv_dev.iter())
        .map(|(&q, &gd)| q * gd.max(0.0).sqrt())
        .collect();

    // Each stored warp becomes γ_i ∘ γ⁻¹.
    for i in 0..n {
        let gam_i: Vec<f64> = (0..m).map(|j| final_gammas[(i, j)]).collect();
        let gam_centered = reparameterize_curve(&gam_i, argvals, &gam_inv);
        for j in 0..m {
            final_gammas[(i, j)] = gam_centered[j];
        }
    }

    // Anchor the mean curve at the first value of mean_1d(data) (presumably the
    // pointwise cross-sectional mean — confirm against mean_1d).
    let initial_mean = mean_1d(data);
    mu = srsf_inverse(&mu_q, argvals, initial_mean[0]);
    let final_aligned = apply_stored_warps(data, &final_gammas, argvals);

    KarcherMeanResult {
        mean: mu,
        mean_srsf: mu_q,
        gammas: final_gammas,
        aligned_data: final_aligned,
        n_iter,
        converged,
    }
}
959
960#[cfg(test)]
963mod tests {
964 use super::*;
965 use crate::simulation::{sim_fundata, EFunType, EValType};
966
967 fn uniform_grid(m: usize) -> Vec<f64> {
968 (0..m).map(|i| i as f64 / (m - 1) as f64).collect()
969 }
970
971 fn make_test_data(n: usize, m: usize, seed: u64) -> FdMatrix {
972 let t = uniform_grid(m);
973 sim_fundata(
974 n,
975 &t,
976 3,
977 EFunType::Fourier,
978 EValType::Exponential,
979 Some(seed),
980 )
981 }
982
983 #[test]
986 fn test_cumulative_trapz_constant() {
987 let x = uniform_grid(50);
989 let y = vec![1.0; 50];
990 let result = cumulative_trapz(&y, &x);
991 assert!((result[0]).abs() < 1e-15, "cumulative_trapz(0) should be 0");
992 for j in 1..50 {
993 assert!(
994 (result[j] - x[j]).abs() < 1e-12,
995 "∫₀^{:.3} 1 dt should be {:.3}, got {:.3}",
996 x[j],
997 x[j],
998 result[j]
999 );
1000 }
1001 }
1002
1003 #[test]
1004 fn test_cumulative_trapz_linear() {
1005 let m = 100;
1007 let x = uniform_grid(m);
1008 let y: Vec<f64> = x.clone();
1009 let result = cumulative_trapz(&y, &x);
1010 for j in 1..m {
1011 let expected = x[j] * x[j] / 2.0;
1012 assert!(
1013 (result[j] - expected).abs() < 1e-4,
1014 "∫₀^{:.3} s ds: expected {expected:.6}, got {:.6}",
1015 x[j],
1016 result[j]
1017 );
1018 }
1019 }
1020
1021 #[test]
1024 fn test_normalize_warp_fixes_boundaries() {
1025 let t = uniform_grid(10);
1026 let mut gamma = vec![0.1; 10]; normalize_warp(&mut gamma, &t);
1028 assert_eq!(gamma[0], t[0]);
1029 assert_eq!(gamma[9], t[9]);
1030 }
1031
1032 #[test]
1033 fn test_normalize_warp_enforces_monotonicity() {
1034 let t = uniform_grid(5);
1035 let mut gamma = vec![0.0, 0.5, 0.3, 0.8, 1.0]; normalize_warp(&mut gamma, &t);
1037 for j in 1..5 {
1038 assert!(
1039 gamma[j] >= gamma[j - 1],
1040 "gamma should be monotone after normalization at j={j}"
1041 );
1042 }
1043 }
1044
1045 #[test]
1046 fn test_normalize_warp_identity_unchanged() {
1047 let t = uniform_grid(20);
1048 let mut gamma = t.clone();
1049 normalize_warp(&mut gamma, &t);
1050 for j in 0..20 {
1051 assert!(
1052 (gamma[j] - t[j]).abs() < 1e-15,
1053 "Identity warp should be unchanged"
1054 );
1055 }
1056 }
1057
1058 #[test]
1061 fn test_linear_interp_at_nodes() {
1062 let x = vec![0.0, 1.0, 2.0, 3.0];
1063 let y = vec![0.0, 2.0, 4.0, 6.0];
1064 for i in 0..x.len() {
1065 assert!((linear_interp(&x, &y, x[i]) - y[i]).abs() < 1e-12);
1066 }
1067 }
1068
1069 #[test]
1070 fn test_linear_interp_midpoints() {
1071 let x = vec![0.0, 1.0, 2.0];
1072 let y = vec![0.0, 2.0, 4.0];
1073 assert!((linear_interp(&x, &y, 0.5) - 1.0).abs() < 1e-12);
1074 assert!((linear_interp(&x, &y, 1.5) - 3.0).abs() < 1e-12);
1075 }
1076
1077 #[test]
1078 fn test_linear_interp_clamp() {
1079 let x = vec![0.0, 1.0, 2.0];
1080 let y = vec![1.0, 3.0, 5.0];
1081 assert!((linear_interp(&x, &y, -1.0) - 1.0).abs() < 1e-12);
1082 assert!((linear_interp(&x, &y, 3.0) - 5.0).abs() < 1e-12);
1083 }
1084
1085 #[test]
1086 fn test_linear_interp_nonuniform_grid() {
1087 let x = vec![0.0, 0.1, 0.5, 1.0];
1088 let y = vec![0.0, 1.0, 5.0, 10.0];
1089 let val = linear_interp(&x, &y, 0.3);
1091 let expected = 1.0 + 10.0 * (0.3 - 0.1);
1092 assert!(
1093 (val - expected).abs() < 1e-12,
1094 "Non-uniform interp: expected {expected}, got {val}"
1095 );
1096 }
1097
1098 #[test]
1099 fn test_linear_interp_two_points() {
1100 let x = vec![0.0, 1.0];
1101 let y = vec![3.0, 7.0];
1102 assert!((linear_interp(&x, &y, 0.25) - 4.0).abs() < 1e-12);
1103 assert!((linear_interp(&x, &y, 0.75) - 6.0).abs() < 1e-12);
1104 }
1105
1106 #[test]
1109 fn test_srsf_transform_linear() {
1110 let m = 50;
1112 let t = uniform_grid(m);
1113 let f: Vec<f64> = t.iter().map(|&ti| 2.0 * ti).collect();
1114 let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
1115
1116 let q_mat = srsf_transform(&mat, &t);
1117 let q: Vec<f64> = q_mat.row(0);
1118
1119 let expected = 2.0_f64.sqrt();
1120 for j in 2..(m - 2) {
1122 assert!(
1123 (q[j] - expected).abs() < 0.1,
1124 "q[{j}] = {}, expected ~{expected}",
1125 q[j]
1126 );
1127 }
1128 }
1129
1130 #[test]
1131 fn test_srsf_transform_preserves_shape() {
1132 let data = make_test_data(10, 50, 42);
1133 let t = uniform_grid(50);
1134 let q = srsf_transform(&data, &t);
1135 assert_eq!(q.shape(), data.shape());
1136 }
1137
1138 #[test]
1139 fn test_srsf_transform_constant_is_zero() {
1140 let m = 30;
1142 let t = uniform_grid(m);
1143 let f = vec![5.0; m];
1144 let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
1145 let q_mat = srsf_transform(&mat, &t);
1146 let q: Vec<f64> = q_mat.row(0);
1147
1148 for j in 0..m {
1149 assert!(
1150 q[j].abs() < 1e-10,
1151 "SRSF of constant should be 0, got q[{j}] = {}",
1152 q[j]
1153 );
1154 }
1155 }
1156
1157 #[test]
1158 fn test_srsf_transform_negative_slope() {
1159 let m = 50;
1161 let t = uniform_grid(m);
1162 let f: Vec<f64> = t.iter().map(|&ti| -3.0 * ti).collect();
1163 let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
1164
1165 let q_mat = srsf_transform(&mat, &t);
1166 let q: Vec<f64> = q_mat.row(0);
1167
1168 let expected = -(3.0_f64.sqrt());
1169 for j in 2..(m - 2) {
1170 assert!(
1171 (q[j] - expected).abs() < 0.15,
1172 "q[{j}] = {}, expected ~{expected}",
1173 q[j]
1174 );
1175 }
1176 }
1177
1178 #[test]
1179 fn test_srsf_transform_empty_input() {
1180 let data = FdMatrix::zeros(0, 0);
1181 let t: Vec<f64> = vec![];
1182 let q = srsf_transform(&data, &t);
1183 assert_eq!(q.shape(), (0, 0));
1184 }
1185
1186 #[test]
1187 fn test_srsf_transform_multiple_curves() {
1188 let m = 40;
1189 let t = uniform_grid(m);
1190 let data = make_test_data(5, m, 42);
1191
1192 let q = srsf_transform(&data, &t);
1193 assert_eq!(q.shape(), (5, m));
1194
1195 for i in 0..5 {
1197 for j in 0..m {
1198 assert!(q[(i, j)].is_finite(), "SRSF should be finite at ({i},{j})");
1199 }
1200 }
1201 }
1202
1203 #[test]
1206 fn test_srsf_round_trip() {
1207 let m = 100;
1208 let t = uniform_grid(m);
1209 let f: Vec<f64> = t
1211 .iter()
1212 .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin() + ti)
1213 .collect();
1214
1215 let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
1216 let q_mat = srsf_transform(&mat, &t);
1217 let q: Vec<f64> = q_mat.row(0);
1218
1219 let f_recon = srsf_inverse(&q, &t, f[0]);
1220
1221 let max_err: f64 = f[5..(m - 5)]
1223 .iter()
1224 .zip(f_recon[5..(m - 5)].iter())
1225 .map(|(a, b)| (a - b).abs())
1226 .fold(0.0_f64, f64::max);
1227
1228 assert!(
1229 max_err < 0.15,
1230 "Round-trip error too large: max_err = {max_err}"
1231 );
1232 }
1233
1234 #[test]
1235 fn test_srsf_inverse_empty() {
1236 let q: Vec<f64> = vec![];
1237 let t: Vec<f64> = vec![];
1238 let result = srsf_inverse(&q, &t, 0.0);
1239 assert!(result.is_empty());
1240 }
1241
1242 #[test]
1243 fn test_srsf_inverse_preserves_initial_value() {
1244 let m = 50;
1245 let t = uniform_grid(m);
1246 let q = vec![1.0; m]; let f0 = 3.15;
1248 let f = srsf_inverse(&q, &t, f0);
1249 assert!((f[0] - f0).abs() < 1e-12, "srsf_inverse should start at f0");
1250 }
1251
1252 #[test]
1253 fn test_srsf_round_trip_multiple_curves() {
1254 let m = 80;
1255 let t = uniform_grid(m);
1256 let data = make_test_data(5, m, 99);
1257
1258 let q_mat = srsf_transform(&data, &t);
1259
1260 for i in 0..5 {
1261 let fi = data.row(i);
1262 let qi = q_mat.row(i);
1263 let f_recon = srsf_inverse(&qi, &t, fi[0]);
1264 let max_err: f64 = fi[5..(m - 5)]
1265 .iter()
1266 .zip(f_recon[5..(m - 5)].iter())
1267 .map(|(a, b)| (a - b).abs())
1268 .fold(0.0_f64, f64::max);
1269 assert!(max_err < 0.3, "Round-trip curve {i}: max_err = {max_err}");
1270 }
1271 }
1272
1273 #[test]
1276 fn test_reparameterize_identity_warp() {
1277 let m = 50;
1278 let t = uniform_grid(m);
1279 let f: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
1280
1281 let result = reparameterize_curve(&f, &t, &t);
1283 for j in 0..m {
1284 assert!(
1285 (result[j] - f[j]).abs() < 1e-12,
1286 "Identity warp should return original at j={j}"
1287 );
1288 }
1289 }
1290
1291 #[test]
1292 fn test_reparameterize_linear_warp() {
1293 let m = 50;
1294 let t = uniform_grid(m);
1295 let f: Vec<f64> = t.clone();
1297 let gamma: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
1298
1299 let result = reparameterize_curve(&f, &t, &gamma);
1300
1301 for j in 0..m {
1303 assert!(
1304 (result[j] - gamma[j]).abs() < 1e-10,
1305 "f(gamma(t)) should be gamma(t) for f(t)=t at j={j}"
1306 );
1307 }
1308 }
1309
1310 #[test]
1311 fn test_reparameterize_sine_with_quadratic_warp() {
1312 let m = 100;
1313 let t = uniform_grid(m);
1314 let f: Vec<f64> = t
1315 .iter()
1316 .map(|&ti| (std::f64::consts::PI * ti).sin())
1317 .collect();
1318 let gamma: Vec<f64> = t.iter().map(|&ti| ti * ti).collect(); let result = reparameterize_curve(&f, &t, &gamma);
1321
1322 for j in 0..m {
1324 let expected = (std::f64::consts::PI * gamma[j]).sin();
1325 assert!(
1326 (result[j] - expected).abs() < 0.05,
1327 "sin(π γ(t)) at j={j}: expected {expected:.4}, got {:.4}",
1328 result[j]
1329 );
1330 }
1331 }
1332
1333 #[test]
1334 fn test_reparameterize_preserves_length() {
1335 let m = 50;
1336 let t = uniform_grid(m);
1337 let f = vec![1.0; m];
1338 let gamma: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
1339
1340 let result = reparameterize_curve(&f, &t, &gamma);
1341 assert_eq!(result.len(), m);
1342 }
1343
1344 #[test]
1347 fn test_compose_warps_identity() {
1348 let m = 50;
1349 let t = uniform_grid(m);
1350 let gamma: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
1352
1353 let result = compose_warps(&t, &gamma, &t);
1355 for j in 0..m {
1356 assert!(
1357 (result[j] - gamma[j]).abs() < 1e-10,
1358 "id ∘ γ should be γ at j={j}"
1359 );
1360 }
1361
1362 let result2 = compose_warps(&gamma, &t, &t);
1364 for j in 0..m {
1365 assert!(
1366 (result2[j] - gamma[j]).abs() < 1e-10,
1367 "γ ∘ id should be γ at j={j}"
1368 );
1369 }
1370 }
1371
1372 #[test]
1373 fn test_compose_warps_associativity() {
1374 let m = 50;
1376 let t = uniform_grid(m);
1377 let g1: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
1378 let g2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
1379 let g3: Vec<f64> = t.iter().map(|&ti| 0.5 * ti + 0.5 * ti * ti).collect();
1380
1381 let g12 = compose_warps(&g1, &g2, &t);
1382 let left = compose_warps(&g12, &g3, &t); let g23 = compose_warps(&g2, &g3, &t);
1385 let right = compose_warps(&g1, &g23, &t); for j in 0..m {
1388 assert!(
1389 (left[j] - right[j]).abs() < 0.05,
1390 "Composition should be roughly associative at j={j}: left={:.4}, right={:.4}",
1391 left[j],
1392 right[j]
1393 );
1394 }
1395 }
1396
1397 #[test]
1398 fn test_compose_warps_preserves_domain() {
1399 let m = 50;
1400 let t = uniform_grid(m);
1401 let g1: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
1402 let g2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
1403
1404 let composed = compose_warps(&g1, &g2, &t);
1405 assert!(
1406 (composed[0] - t[0]).abs() < 1e-10,
1407 "Composed warp should start at domain start"
1408 );
1409 assert!(
1410 (composed[m - 1] - t[m - 1]).abs() < 1e-10,
1411 "Composed warp should end at domain end"
1412 );
1413 }
1414
1415 #[test]
1418 fn test_align_identical_curves() {
1419 let m = 50;
1420 let t = uniform_grid(m);
1421 let f: Vec<f64> = t
1422 .iter()
1423 .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
1424 .collect();
1425
1426 let result = elastic_align_pair(&f, &f, &t);
1427
1428 assert!(
1430 result.distance < 0.1,
1431 "Distance between identical curves should be near 0, got {}",
1432 result.distance
1433 );
1434
1435 for j in 0..m {
1437 assert!(
1438 (result.gamma[j] - t[j]).abs() < 0.1,
1439 "Warp should be near identity at j={j}: gamma={}, t={}",
1440 result.gamma[j],
1441 t[j]
1442 );
1443 }
1444 }
1445
1446 #[test]
1447 fn test_align_pair_valid_output() {
1448 let data = make_test_data(2, 50, 42);
1449 let t = uniform_grid(50);
1450 let f1 = data.row(0);
1451 let f2 = data.row(1);
1452
1453 let result = elastic_align_pair(&f1, &f2, &t);
1454
1455 assert_eq!(result.gamma.len(), 50);
1456 assert_eq!(result.f_aligned.len(), 50);
1457 assert!(result.distance >= 0.0);
1458
1459 for j in 1..50 {
1461 assert!(
1462 result.gamma[j] >= result.gamma[j - 1],
1463 "Warp should be monotone at j={j}"
1464 );
1465 }
1466 }
1467
1468 #[test]
1469 fn test_align_pair_warp_boundaries() {
1470 let data = make_test_data(2, 50, 42);
1471 let t = uniform_grid(50);
1472 let f1 = data.row(0);
1473 let f2 = data.row(1);
1474
1475 let result = elastic_align_pair(&f1, &f2, &t);
1476 assert!(
1477 (result.gamma[0] - t[0]).abs() < 1e-12,
1478 "Warp should start at domain start"
1479 );
1480 assert!(
1481 (result.gamma[49] - t[49]).abs() < 1e-12,
1482 "Warp should end at domain end"
1483 );
1484 }
1485
1486 #[test]
1487 fn test_align_shifted_sine() {
1488 let m = 80;
1490 let t = uniform_grid(m);
1491 let f1: Vec<f64> = t
1492 .iter()
1493 .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
1494 .collect();
1495 let f2: Vec<f64> = t
1496 .iter()
1497 .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
1498 .collect();
1499
1500 let weights = simpsons_weights(&t);
1501 let l2_before = l2_distance(&f1, &f2, &weights);
1502 let result = elastic_align_pair(&f1, &f2, &t);
1503 let l2_after = l2_distance(&f1, &result.f_aligned, &weights);
1504
1505 assert!(
1506 l2_after < l2_before + 0.01,
1507 "Alignment should not increase L2 distance: before={l2_before:.4}, after={l2_after:.4}"
1508 );
1509 }
1510
1511 #[test]
1512 fn test_align_pair_aligned_curve_is_finite() {
1513 let data = make_test_data(2, 50, 77);
1514 let t = uniform_grid(50);
1515 let f1 = data.row(0);
1516 let f2 = data.row(1);
1517
1518 let result = elastic_align_pair(&f1, &f2, &t);
1519 for j in 0..50 {
1520 assert!(
1521 result.f_aligned[j].is_finite(),
1522 "Aligned curve should be finite at j={j}"
1523 );
1524 }
1525 }
1526
#[test]
fn test_align_pair_minimum_grid() {
    // Two grid points: the smallest domain the pairwise aligner is asked to handle.
    let t = vec![0.0, 1.0];
    let (f1, f2) = (vec![0.0, 1.0], vec![0.0, 2.0]);

    let result = elastic_align_pair(&f1, &f2, &t);
    let expected_len = t.len();
    assert_eq!(result.gamma.len(), expected_len);
    assert_eq!(result.f_aligned.len(), expected_len);
    assert!(result.distance >= 0.0);
}
1538
#[test]
fn test_elastic_distance_symmetric() {
    let data = make_test_data(3, 50, 42);
    let t = uniform_grid(50);
    let (a, b) = (data.row(0), data.row(1));

    let d12 = elastic_distance(&a, &b, &t);
    let d21 = elastic_distance(&b, &a, &t);

    // The discretized DP alignment is only approximately symmetric, so allow
    // 30% relative slack plus a small absolute floor.
    let tolerance = d12.max(d21) * 0.3 + 0.01;
    assert!(
        (d12 - d21).abs() < tolerance,
        "Elastic distance should be roughly symmetric: d12={d12}, d21={d21}"
    );
}
1557
#[test]
fn test_elastic_distance_nonneg() {
    let data = make_test_data(3, 50, 42);
    let t = uniform_grid(50);

    // Check every ordered pair, including each curve against itself.
    let pairs = (0..3).flat_map(|i| (0..3).map(move |j| (i, j)));
    for (i, j) in pairs {
        let d = elastic_distance(&data.row(i), &data.row(j), &t);
        assert!(d >= 0.0, "Elastic distance should be non-negative");
    }
}
1572
#[test]
fn test_elastic_distance_self_near_zero() {
    let data = make_test_data(3, 50, 42);
    let t = uniform_grid(50);

    // Aligning a curve with itself should cost (almost) nothing.
    (0..3).for_each(|i| {
        let curve = data.row(i);
        let d = elastic_distance(&curve, &curve, &t);
        assert!(
            d < 0.1,
            "Self-distance should be near zero, got {d} for curve {i}"
        );
    });
}
1587
#[test]
fn test_elastic_distance_triangle_inequality() {
    let data = make_test_data(3, 50, 42);
    let t = uniform_grid(50);
    let curves: Vec<_> = (0..3).map(|i| data.row(i)).collect();
    let dist = |a: usize, b: usize| elastic_distance(&curves[a], &curves[b], &t);

    let d01 = dist(0, 1);
    let d12 = dist(1, 2);
    let d02 = dist(0, 2);

    // The DP-discretized distance satisfies the triangle inequality only
    // approximately, hence the generous additive slack.
    let slack = 0.5;
    assert!(
        d02 <= d01 + d12 + slack,
        "Triangle inequality (relaxed): d02={d02:.4} > d01={d01:.4} + d12={d12:.4} + {slack}"
    );
}
1607
#[test]
fn test_elastic_distance_different_shapes_nonzero() {
    let m = 50;
    let t = uniform_grid(m);

    // NOTE: a ramp `t` versus a parabola `t^2` is NOT a genuinely different
    // shape pair: t^2 = ramp(gamma(t)) with the valid warp gamma(t) = t^2, so
    // their continuous elastic distance is zero and any positive value would
    // come only from DP slope limits / discretization.
    //
    // Instead compare a monotone ramp with one full sine period. The ramp's
    // SRSF is strictly positive, whereas the sine's SRSF is negative on its
    // descending half. Warping preserves the SRSF L2 norm of that negative
    // part, so no warp can remove it, and the distance stays well above the
    // threshold.
    let ramp: Vec<f64> = t.to_vec();
    let sine: Vec<f64> = t
        .iter()
        .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
        .collect();

    let d = elastic_distance(&ramp, &sine, &t);
    assert!(
        d > 0.01,
        "Distance between different shapes should be > 0, got {d}"
    );
}
1621
#[test]
fn test_self_distance_matrix_symmetric() {
    let data = make_test_data(5, 30, 42);
    let t = uniform_grid(30);

    let dm = elastic_self_distance_matrix(&data, &t);
    let n = dm.nrows();
    assert_eq!(dm.shape(), (5, 5));

    // Zero self-distances on the diagonal ...
    for k in 0..n {
        assert!(dm[(k, k)].abs() < 1e-12, "Diagonal should be zero");
    }

    // ... and mirror-equal entries across it.
    let upper = (0..n).flat_map(|i| ((i + 1)..n).map(move |j| (i, j)));
    for (i, j) in upper {
        assert!(
            (dm[(i, j)] - dm[(j, i)]).abs() < 1e-12,
            "Matrix should be symmetric at ({i},{j})"
        );
    }
}
1649
#[test]
fn test_self_distance_matrix_nonneg() {
    let data = make_test_data(4, 30, 42);
    let t = uniform_grid(30);
    let dm = elastic_self_distance_matrix(&data, &t);

    // Every entry of the 4x4 matrix, diagonal included, must be >= 0.
    let cells = (0..4).flat_map(|i| (0..4).map(move |j| (i, j)));
    for (i, j) in cells {
        assert!(
            dm[(i, j)] >= 0.0,
            "Distance matrix entries should be non-negative at ({i},{j})"
        );
    }
}
1665
#[test]
fn test_self_distance_matrix_single_curve() {
    // A single curve yields a 1x1 matrix holding only its zero self-distance.
    let dm = elastic_self_distance_matrix(&make_test_data(1, 30, 42), &uniform_grid(30));
    assert_eq!(dm.shape(), (1, 1));
    assert!(dm[(0, 0)].abs() < 1e-12);
}
1674
#[test]
fn test_self_distance_matrix_consistent_with_pairwise() {
    let data = make_test_data(4, 30, 42);
    let t = uniform_grid(30);
    let dm = elastic_self_distance_matrix(&data, &t);

    // Each upper-triangle entry must reproduce a direct pairwise computation.
    let upper = (0..4).flat_map(|i| ((i + 1)..4).map(move |j| (i, j)));
    for (i, j) in upper {
        let d_direct = elastic_distance(&data.row(i), &data.row(j), &t);
        let entry = dm[(i, j)];
        assert!(
            (entry - d_direct).abs() < 1e-10,
            "Matrix entry ({i},{j})={:.6} should match pairwise {d_direct:.6}",
            entry
        );
    }
}
1696
#[test]
fn test_karcher_mean_identical_curves() {
    let m = 50;
    let t = uniform_grid(m);
    let f: Vec<f64> = t
        .iter()
        .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
        .collect();

    // Five identical copies of one sine period: there is no phase or
    // amplitude variability to average out.
    let mut data = FdMatrix::zeros(5, m);
    for i in 0..5 {
        for j in 0..m {
            data[(i, j)] = f[j];
        }
    }

    let result = karcher_mean(&data, &t, 10, 1e-4);

    assert_eq!(result.mean.len(), m);
    assert!(result.n_iter <= 10);

    // The mean of identical curves should reproduce the common curve, up to
    // DP-warp discretization and SRSF round-trip error. The tolerance matches
    // the one used for the constant-curve Karcher test.
    for j in 0..m {
        assert!(
            result.mean[j].is_finite(),
            "Mean should be finite at j={j}"
        );
        assert!(
            (result.mean[j] - f[j]).abs() < 0.5,
            "Mean of identical curves should match the common curve at j={j}: got {}, expected {}",
            result.mean[j],
            f[j]
        );
    }
}
1721
#[test]
fn test_karcher_mean_output_shape() {
    let result = karcher_mean(&make_test_data(15, 50, 42), &uniform_grid(50), 5, 1e-3);

    // Per-grid-point outputs have one value per sample ...
    assert_eq!(result.mean.len(), 50);
    assert_eq!(result.mean_srsf.len(), 50);
    // ... per-curve outputs are (n_curves, n_points) ...
    assert_eq!(result.gammas.shape(), (15, 50));
    assert_eq!(result.aligned_data.shape(), (15, 50));
    // ... and the iteration count respects the cap.
    assert!(result.n_iter <= 5);
}
1735
#[test]
fn test_karcher_mean_warps_are_valid() {
    let data = make_test_data(10, 40, 42);
    let t = uniform_grid(40);
    let result = karcher_mean(&data, &t, 5, 1e-3);
    let (first, last) = (0, 39);

    // A valid warp fixes both endpoints and never decreases.
    for i in 0..10 {
        let gamma = |j: usize| result.gammas[(i, j)];
        assert!(
            (gamma(first) - t[first]).abs() < 1e-10,
            "Warp {i} should start at domain start"
        );
        assert!(
            (gamma(last) - t[last]).abs() < 1e-10,
            "Warp {i} should end at domain end"
        );
        for j in 1..40 {
            assert!(
                gamma(j) >= gamma(j - 1),
                "Warp {i} should be monotone at j={j}"
            );
        }
    }
}
1762
#[test]
fn test_karcher_mean_aligned_data_is_finite() {
    let data = make_test_data(8, 40, 42);
    let t = uniform_grid(40);
    let result = karcher_mean(&data, &t, 5, 1e-3);

    // Scan every (curve, sample) cell of the aligned output.
    let cells = (0..8).flat_map(|i| (0..40).map(move |j| (i, j)));
    for (i, j) in cells {
        assert!(
            result.aligned_data[(i, j)].is_finite(),
            "Aligned data should be finite at ({i},{j})"
        );
    }
}
1778
#[test]
fn test_karcher_mean_srsf_is_finite() {
    let data = make_test_data(8, 40, 42);
    let t = uniform_grid(40);
    let result = karcher_mean(&data, &t, 5, 1e-3);

    // Walk the mean SRSF and the reconstructed mean curve side by side.
    let paired = result.mean_srsf[..40].iter().zip(&result.mean[..40]);
    for (j, (q, f)) in paired.enumerate() {
        assert!(q.is_finite(), "Mean SRSF should be finite at j={j}");
        assert!(f.is_finite(), "Mean curve should be finite at j={j}");
    }
}
1796
#[test]
fn test_karcher_mean_single_iteration() {
    let data = make_test_data(10, 40, 42);
    let t = uniform_grid(40);
    // Cap the algorithm at one iteration with a very tight tolerance.
    let result = karcher_mean(&data, &t, 1, 1e-10);

    assert_eq!(result.n_iter, 1);
    assert_eq!(result.mean.len(), 40);
    assert!(result.mean[..40].iter().all(|v| v.is_finite()));
}
1810
#[test]
fn test_align_to_target_valid() {
    let data = make_test_data(10, 40, 42);
    let t = uniform_grid(40);
    let target = data.row(0);
    let result = align_to_target(&data, &target, &t);

    // One warp, one aligned curve and one distance per input curve.
    assert_eq!(result.gammas.shape(), (10, 40));
    assert_eq!(result.aligned_data.shape(), (10, 40));
    assert_eq!(result.distances.len(), 10);
    assert!(result.distances.iter().all(|&d| d >= 0.0));
}
1830
#[test]
fn test_align_to_target_self_near_zero() {
    let data = make_test_data(5, 40, 42);
    let t = uniform_grid(40);
    // Row 0 doubles as the target, so its own alignment cost should vanish.
    let target = data.row(0);

    let result = align_to_target(&data, &target, &t);
    let self_distance = result.distances[0];

    assert!(
        self_distance < 0.1,
        "Self-alignment distance should be near zero, got {}",
        self_distance
    );
}
1846
#[test]
fn test_align_to_target_warps_are_monotone() {
    let data = make_test_data(8, 40, 42);
    let t = uniform_grid(40);
    let target = data.row(0);
    let result = align_to_target(&data, &target, &t);

    // Each consecutive pair (j - 1, j) of every warp must be non-decreasing.
    let steps = (0..8).flat_map(|i| (1..40).map(move |j| (i, j)));
    for (i, j) in steps {
        assert!(
            result.gammas[(i, j)] >= result.gammas[(i, j - 1)],
            "Warp for curve {i} should be monotone at j={j}"
        );
    }
}
1863
#[test]
fn test_align_to_target_aligned_data_finite() {
    let data = make_test_data(6, 40, 42);
    let t = uniform_grid(40);
    let target = data.row(0);
    let result = align_to_target(&data, &target, &t);

    // Scan every (curve, sample) cell of the aligned output.
    let cells = (0..6).flat_map(|i| (0..40).map(move |j| (i, j)));
    for (i, j) in cells {
        assert!(
            result.aligned_data[(i, j)].is_finite(),
            "Aligned data should be finite at ({i},{j})"
        );
    }
}
1880
#[test]
fn test_cross_distance_matrix_shape() {
    let data1 = make_test_data(3, 30, 42);
    let data2 = make_test_data(4, 30, 99);
    let t = uniform_grid(30);

    // Rows index the first sample, columns the second.
    let dm = elastic_cross_distance_matrix(&data1, &data2, &t);
    assert_eq!(dm.shape(), (3, 4));

    let all_nonneg = (0..3).all(|i| (0..4).all(|j| dm[(i, j)] >= 0.0));
    assert!(all_nonneg);
}
1899
#[test]
fn test_cross_distance_matrix_self_matches_self_matrix() {
    let data = make_test_data(4, 30, 42);
    let t = uniform_grid(30);

    let cross = elastic_cross_distance_matrix(&data, &data, &t);
    let self_dm = elastic_self_distance_matrix(&data, &t);
    assert_eq!(cross.shape(), self_dm.shape());

    for i in 0..4 {
        // The cross diagonal is an actual self-alignment, so it is only
        // near zero, whereas the self matrix stores an exact zero there.
        assert!(
            cross[(i, i)] < 0.1,
            "Cross distance (self) diagonal should be near zero: got {}",
            cross[(i, i)]
        );

        // Off the diagonal, compare only the upper triangle. The self matrix
        // mirrors (i, j) into (j, i), while the cross matrix computes
        // d(f_j, f_i) separately, and DP alignment is not exactly symmetric.
        for j in (i + 1)..4 {
            assert!(
                (cross[(i, j)] - self_dm[(i, j)]).abs() < 1e-9,
                "Cross matrix ({i},{j})={:.6} should match self matrix {:.6}",
                cross[(i, j)],
                self_dm[(i, j)]
            );
        }
    }
}
1915
#[test]
fn test_cross_distance_matrix_consistent_with_pairwise() {
    let data1 = make_test_data(3, 30, 42);
    let data2 = make_test_data(2, 30, 99);
    let t = uniform_grid(30);

    let dm = elastic_cross_distance_matrix(&data1, &data2, &t);

    // Every cell must equal the direct distance from row i of data1 to row j of data2.
    let cells = (0..3).flat_map(|i| (0..2).map(move |j| (i, j)));
    for (i, j) in cells {
        let d_direct = elastic_distance(&data1.row(i), &data2.row(j), &t);
        let entry = dm[(i, j)];
        assert!(
            (entry - d_direct).abs() < 1e-10,
            "Cross matrix ({i},{j})={:.6} should match pairwise {d_direct:.6}",
            entry
        );
    }
}
1937
#[test]
fn test_align_srsf_pair_identity() {
    let m = 50;
    let t = uniform_grid(m);
    let tau = 2.0 * std::f64::consts::PI;
    let f: Vec<f64> = t.iter().map(|&x| (tau * x).sin()).collect();
    let q = srsf_single(&f, &t);

    // Aligning an SRSF to itself should give roughly the identity warp ...
    let (gamma, q_aligned) = align_srsf_pair(&q, &q, &t);
    for (j, (&g, &tj)) in gamma[..m].iter().zip(&t[..m]).enumerate() {
        assert!(
            (g - tj).abs() < 0.15,
            "Self-SRSF alignment warp should be near identity at j={j}"
        );
    }

    // ... and leave the SRSF nearly unchanged in weighted L2.
    let weights = simpsons_weights(&t);
    let dist = l2_distance(&q, &q_aligned, &weights);
    assert!(
        dist < 0.5,
        "Self-aligned SRSF distance should be small, got {dist}"
    );
}
1968
#[test]
fn test_srsf_single_matches_matrix_version() {
    let m = 50;
    let t = uniform_grid(m);
    let f: Vec<f64> = t.iter().map(|&x| x * x + x).collect();

    let q_single = srsf_single(&f, &t);

    // Wrap the same curve as a 1 x m matrix and push it through the batch transform.
    let single_row = FdMatrix::from_slice(&f, 1, m).unwrap();
    let q_mat = srsf_transform(&single_row, &t);
    let q_from_mat = q_mat.row(0);

    for j in 0..m {
        let diff = (q_single[j] - q_from_mat[j]).abs();
        assert!(
            diff < 1e-12,
            "srsf_single should match srsf_transform at j={j}"
        );
    }
}
1990
#[test]
fn test_gcd_basic() {
    // ((a, b), gcd(a, b)); the last two rows cover a zero operand on either side.
    let cases = [
        ((1, 1), 1),
        ((6, 4), 2),
        ((7, 5), 1),
        ((12, 8), 4),
        ((7, 0), 7),
        ((0, 5), 5),
    ];
    for ((a, b), expected) in cases {
        assert_eq!(gcd(a, b), expected, "gcd({a}, {b})");
    }
}
2002
#[test]
fn test_coprime_nbhd_count() {
    // n = 1 leaves only (1, 1); n = 7 gives the 35 coprime pairs in [1, 7]^2.
    for (n, expected) in [(1, 1), (7, 35)] {
        assert_eq!(generate_coprime_nbhd(n).len(), expected);
    }
}
2010
#[test]
fn test_coprime_nbhd_matches_const() {
    // The generated neighbourhood must match the precomputed table, including order.
    let generated = generate_coprime_nbhd(7);
    assert_eq!(generated.len(), COPRIME_NBHD_7.len());
    for (i, (got, want)) in generated.iter().zip(COPRIME_NBHD_7.iter()).enumerate() {
        assert_eq!(got, want, "mismatch at index {i}");
    }
}
2019
#[test]
fn test_coprime_nbhd_all_coprime() {
    // Each table entry is a coprime step pair with both components in 1..=7.
    for &(i, j) in COPRIME_NBHD_7.iter() {
        assert_eq!(gcd(i, j), 1, "({i},{j}) should be coprime");
        let in_range = 1..=7;
        assert!(in_range.contains(&i));
        assert!(in_range.contains(&j));
    }
}
2028
#[test]
fn test_dp_edge_weight_diagonal() {
    let t = uniform_grid(10);
    // Two identical constant SRSFs evaluated over the index spans (0, 1) / (0, 1).
    let flat = vec![1.0; 10];
    let same = flat.clone();

    let w = dp_edge_weight(&flat, &same, &t, 0, 1, 0, 1);
    assert!(w.abs() < 1e-12, "identical SRSFs should have zero cost");
}
2041
#[test]
fn test_dp_edge_weight_non_diagonal() {
    let t = uniform_grid(10);
    // q1 == 1 everywhere, q2 == 0 everywhere, so only q1's span contributes.
    let ones = vec![1.0; 10];
    let zeros = vec![0.0; 10];

    let w = dp_edge_weight(&ones, &zeros, &t, 0, 2, 0, 1);
    // NOTE(review): presumably the integral of 1^2 over the two grid intervals
    // [t0, t2], i.e. 2 * (1/9) on a 10-point grid over [0, 1] — confirm.
    let expected = 2.0 / 9.0;
    assert!(
        (w - expected).abs() < 1e-10,
        "dp_edge_weight (1,2): expected {expected}, got {w}"
    );
}
2057
#[test]
fn test_dp_edge_weight_zero_span() {
    let t = uniform_grid(10);
    let q1 = vec![1.0; 10];
    let q2 = vec![1.0; 10];

    // A degenerate (empty) span on either side is an infeasible edge.
    for (a, b, c, d) in [(3, 3, 0, 1), (0, 1, 3, 3)] {
        assert_eq!(dp_edge_weight(&q1, &q2, &t, a, b, c, d), f64::INFINITY);
    }
}
2068
#[test]
fn test_alignment_improves_distance() {
    let m = 50;
    let t = uniform_grid(m);
    let tau = 2.0 * std::f64::consts::PI;
    let reference: Vec<f64> = t.iter().map(|&x| (tau * x).sin()).collect();
    let advanced: Vec<f64> = t.iter().map(|&x| (tau * (x + 0.2)).sin()).collect();

    // Baseline: SRSF distance with no warping at all.
    let q_ref = srsf_single(&reference, &t);
    let q_adv = srsf_single(&advanced, &t);
    let weights = simpsons_weights(&t);
    let baseline = l2_distance(&q_ref, &q_adv, &weights);

    // The optimal warp may only do as well or better than the identity warp.
    let aligned = elastic_align_pair(&reference, &advanced, &t);
    assert!(
        aligned.distance <= baseline + 1e-6,
        "aligned SRSF dist ({}) should be <= unaligned SRSF dist ({})",
        aligned.distance,
        baseline
    );
}
2100
#[test]
fn test_alignment_constant_curves() {
    let m = 30;
    let t = uniform_grid(m);
    // Flat curves have a zero derivative (and hence SRSF), so there is nothing to align.
    let (f1, f2) = (vec![5.0; m], vec![5.0; m]);

    let result = elastic_align_pair(&f1, &f2, &t);
    assert!(
        result.distance < 0.01,
        "Constant curves: distance should be ~0"
    );
    assert_eq!(result.f_aligned.len(), m);
}
2117
#[test]
fn test_karcher_mean_constant_curves() {
    let m = 30;
    let t = uniform_grid(m);
    // Five flat curves at level 3.0; storage order is irrelevant for constant data.
    let values = vec![3.0; 5 * m];
    let data = FdMatrix::from_slice(&values, 5, m).unwrap();

    let result = karcher_mean(&data, &t, 5, 1e-4);
    for (j, &value) in result.mean[..m].iter().enumerate() {
        assert!(
            (value - 3.0).abs() < 0.5,
            "Mean of constant curves should be near 3.0, got {} at j={j}",
            value
        );
    }
}
2138}