1use crate::fdata::{deriv_1d, mean_1d};
15use crate::helpers::{
16 cumulative_trapz, gradient_uniform, l2_distance, linear_interp, simpsons_weights,
17};
18use crate::iter_maybe_parallel;
19use crate::matrix::FdMatrix;
20use crate::smoothing::nadaraya_watson;
21use crate::warping::{
22 exp_map_sphere, gam_to_psi, inv_exp_map_sphere, invert_gamma, l2_norm_l2, normalize_warp,
23 psi_to_gam,
24};
25#[cfg(feature = "parallel")]
26use rayon::iter::ParallelIterator;
27
/// Result of elastically aligning one curve to another (produced by `elastic_align_pair`).
#[derive(Debug, Clone)]
pub struct AlignmentResult {
    /// Warping function evaluated on the `argvals` grid; the second curve is
    /// re-sampled at these values to obtain `f_aligned`.
    pub gamma: Vec<f64>,
    /// The second curve re-parameterized by `gamma`.
    pub f_aligned: Vec<f64>,
    /// Simpson-weighted L2 distance between the SRSF of the first curve and the
    /// SRSF of `f_aligned` (the elastic / amplitude distance).
    pub distance: f64,
}
40
/// Result of aligning every row of a data matrix to a common target (produced by
/// `align_to_target`).
#[derive(Debug, Clone)]
pub struct AlignmentSetResult {
    /// Row `i` is the warping function that aligns curve `i` to the target.
    pub gammas: FdMatrix,
    /// Row `i` is curve `i` re-parameterized by its warp.
    pub aligned_data: FdMatrix,
    /// Elastic distance from the target to each curve, in row order.
    pub distances: Vec<f64>,
}
51
/// Output of the elastic Karcher-mean computation (`karcher_mean`).
#[derive(Debug, Clone)]
pub struct KarcherMeanResult {
    /// Mean curve in function space, rebuilt from `mean_srsf` via `srsf_inverse`.
    pub mean: Vec<f64>,
    /// Mean in SRSF space after the final centring step.
    pub mean_srsf: Vec<f64>,
    /// Row `i` is the (centred) warping function applied to curve `i`.
    pub gammas: FdMatrix,
    /// Row `i` is curve `i` re-parameterized by its row of `gammas`.
    pub aligned_data: FdMatrix,
    /// Number of alignment passes that were run.
    pub n_iter: usize,
    /// `true` if the relative change of the mean SRSF fell below the tolerance
    /// before the iteration budget ran out.
    pub converged: bool,
}
68
69fn karcher_sphere_step(mu: &mut Vec<f64>, psis: &[Vec<f64>], time: &[f64], step_size: f64) -> bool {
82 let m = mu.len();
83 let n = psis.len();
84 let mut vbar = vec![0.0; m];
85 for psi in psis {
86 let v = inv_exp_map_sphere(mu, psi, time);
87 for j in 0..m {
88 vbar[j] += v[j];
89 }
90 }
91 for j in 0..m {
92 vbar[j] /= n as f64;
93 }
94 if l2_norm_l2(&vbar, time) <= 1e-8 {
95 return true;
96 }
97 let scaled: Vec<f64> = vbar.iter().map(|&v| v * step_size).collect();
98 *mu = exp_map_sphere(mu, &scaled, time);
99 false
100}
101
/// Inverse of the Karcher mean of a set of warping functions, expressed on the
/// original domain `[argvals[0], argvals[m - 1]]`.
///
/// Steps:
/// 1. Each row of `gammas` is rescaled to `[0, 1]` and mapped with `gam_to_psi`
///    (presumably the square-root-slope representation on the sphere; confirm
///    in `warping`).
/// 2. The arithmetic mean of those representations seeds a fixed-step (0.3)
///    Karcher iteration (`karcher_sphere_step`). It runs for at most 501 steps,
///    stopping early once the mean shooting vector's norm is <= 1e-8.
/// 3. The mean is converted back with `psi_to_gam`, inverted with
///    `invert_gamma`, and rescaled to the original domain.
///
/// Requires `m >= 2` (the unit grid divides by `m - 1`) and a non-degenerate
/// domain (`argvals[m - 1] != argvals[0]`). Otherwise the code panics on the
/// `m - 1` arithmetic or produces non-finite values.
fn sqrt_mean_inverse(gammas: &FdMatrix, argvals: &[f64]) -> Vec<f64> {
    let (n, m) = gammas.shape();
    let t0 = argvals[0];
    let t1 = argvals[m - 1];
    let domain = t1 - t0;

    // Uniform grid on [0, 1] on which the sphere computations are carried out.
    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
    let binsize = 1.0 / (m - 1) as f64;

    let psis: Vec<Vec<f64>> = (0..n)
        .map(|i| {
            // Rescale warp i from the original domain to [0, 1].
            let gam_01: Vec<f64> = (0..m).map(|j| (gammas[(i, j)] - t0) / domain).collect();
            gam_to_psi(&gam_01, binsize)
        })
        .collect();

    // Initial guess: pointwise arithmetic mean of the psi functions.
    let mut mu = vec![0.0; m];
    for psi in &psis {
        for j in 0..m {
            mu[j] += psi[j];
        }
    }
    for j in 0..m {
        mu[j] /= n as f64;
    }

    // Fixed iteration budget with early exit once the gradient vanishes.
    for _ in 0..501 {
        if karcher_sphere_step(&mut mu, &psis, &time, 0.3) {
            break;
        }
    }

    let gam_mu = psi_to_gam(&mu, &time);
    let gam_inv = invert_gamma(&gam_mu, &time);
    // Map the inverse warp from [0, 1] back to the original domain.
    gam_inv.iter().map(|&g| t0 + g * domain).collect()
}
138
139pub fn srsf_transform(data: &FdMatrix, argvals: &[f64]) -> FdMatrix {
152 let (n, m) = data.shape();
153 if n == 0 || m == 0 || argvals.len() != m {
154 return FdMatrix::zeros(n, m);
155 }
156
157 let deriv = deriv_1d(data, argvals, 1);
158
159 let mut result = FdMatrix::zeros(n, m);
160 for i in 0..n {
161 for j in 0..m {
162 let d = deriv[(i, j)];
163 result[(i, j)] = d.signum() * d.abs().sqrt();
164 }
165 }
166 result
167}
168
169pub fn srsf_inverse(q: &[f64], argvals: &[f64], f0: f64) -> Vec<f64> {
181 let m = q.len();
182 if m == 0 {
183 return Vec::new();
184 }
185
186 let integrand: Vec<f64> = q.iter().map(|&qi| qi * qi.abs()).collect();
188 let integral = cumulative_trapz(&integrand, argvals);
189
190 integral.iter().map(|&v| f0 + v).collect()
191}
192
193pub fn reparameterize_curve(f: &[f64], argvals: &[f64], gamma: &[f64]) -> Vec<f64> {
204 gamma
205 .iter()
206 .map(|&g| linear_interp(argvals, f, g))
207 .collect()
208}
209
210pub fn compose_warps(gamma1: &[f64], gamma2: &[f64], argvals: &[f64]) -> Vec<f64> {
217 gamma2
218 .iter()
219 .map(|&g| linear_interp(argvals, gamma1, g))
220 .collect()
221}
222
/// Greatest common divisor by Euclid's algorithm (`gcd(a, 0) == a`).
#[cfg(test)]
fn gcd(a: usize, b: usize) -> usize {
    let (mut x, mut y) = (a, b);
    while y != 0 {
        let rem = x % y;
        x = y;
        y = rem;
    }
    x
}
235
/// All offsets `(i, j)` with `1 <= i, j <= nbhd_dim` and `gcd(i, j) == 1`, in
/// row-major order (`i` outer, `j` inner).
///
/// With `nbhd_dim == 7` this yields the same 35 pairs, in the same order, as
/// `COPRIME_NBHD_7`.
#[cfg(test)]
fn generate_coprime_nbhd(nbhd_dim: usize) -> Vec<(usize, usize)> {
    (1..=nbhd_dim)
        .flat_map(|i| (1..=nbhd_dim).map(move |j| (i, j)))
        .filter(|&(i, j)| gcd(i, j) == 1)
        .collect()
}
250
/// Step offsets `(dr, dc)` allowed between consecutive DP path nodes.
///
/// These are all pairs with `1 <= dr, dc <= 7` and `gcd(dr, dc) == 1`, listed
/// row-major. Non-coprime steps are omitted because they are covered by repeated
/// shorter steps in the same direction. The order matters: `dp_relax_cell`
/// keeps the first minimum it finds.
#[rustfmt::skip]
const COPRIME_NBHD_7: [(usize, usize); 35] = [
    (1,1),(1,2),(1,3),(1,4),(1,5),(1,6),(1,7),
    (2,1),      (2,3),      (2,5),      (2,7),
    (3,1),(3,2),      (3,4),(3,5),      (3,7),
    (4,1),      (4,3),      (4,5),      (4,7),
    (5,1),(5,2),(5,3),(5,4),      (5,6),(5,7),
    (6,1),                  (6,5),      (6,7),
    (7,1),(7,2),(7,3),(7,4),(7,5),(7,6),
];
264
/// Cost of one straight DP edge from grid cell `(sr, sc)` to `(tr, tc)`.
///
/// Columns index the time axis and `q1`; rows index the warped values and `q2`.
/// Along the edge the warp is linear with slope
/// `(argvals[tr] - argvals[sr]) / (argvals[tc] - argvals[sc])`. The cost
/// approximates `∫ (q1(t) - sqrt(slope) * q2(γ(t)))^2 dt` over
/// `[argvals[sc], argvals[tc]]`.
///
/// How the integral is evaluated:
/// - Both SRSFs are treated as piecewise constant, using the sample at the left
///   end of each grid interval.
/// - The `tc - sc` intervals of `q1` and the `tr - sr` intervals of `q2` are
///   each rescaled onto `[0, 1]` and intersected.
/// - The approximation is most faithful on a (near-)uniform grid.
///
/// Returns `f64::INFINITY` if the edge does not advance in both directions.
/// Assumes `sc <= tc` and `sr <= tr`; otherwise the subtractions underflow.
#[inline]
fn dp_edge_weight(
    q1: &[f64],
    q2: &[f64],
    argvals: &[f64],
    sc: usize,
    tc: usize,
    sr: usize,
    tr: usize,
) -> f64 {
    let n1 = tc - sc;
    let n2 = tr - sr;
    if n1 == 0 || n2 == 0 {
        return f64::INFINITY;
    }

    let slope = (argvals[tr] - argvals[sr]) / (argvals[tc] - argvals[sc]);
    let rslope = slope.sqrt();

    let mut weight = 0.0;
    // Walk the merged partition of [0, 1] formed by the n1 pieces of the q1 span
    // and the n2 pieces of the q2 span.
    let mut i1 = 0usize;
    let mut i2 = 0usize;
    while i1 < n1 && i2 < n2 {
        // Current piece of each span, in unit-interval coordinates.
        let left1 = i1 as f64 / n1 as f64;
        let right1 = (i1 + 1) as f64 / n1 as f64;
        let left2 = i2 as f64 / n2 as f64;
        let right2 = (i2 + 1) as f64 / n2 as f64;

        // Overlap of the two current pieces.
        let left = left1.max(left2);
        let right = right1.min(right2);
        let dt = right - left;

        if dt > 0.0 {
            let diff = q1[sc + i1] - rslope * q2[sr + i2];
            weight += diff * diff * dt;
        }

        // Advance whichever piece ends first (both when they end together).
        // Equal fractions compare exactly: IEEE division is correctly rounded.
        if right1 < right2 {
            i1 += 1;
        } else if right2 < right1 {
            i2 += 1;
        } else {
            i1 += 1;
            i2 += 1;
        }
    }

    // Undo the unit-interval rescaling of the time axis.
    weight * (argvals[tc] - argvals[sc])
}
326
/// Roughness penalty for a DP edge from `(sr, sc)` to `(tr, tc)`.
///
/// Returns `lambda * (slope - 1)^2 * dt`, where `dt = argvals[tc] - argvals[sc]`
/// and `slope` is the warp slope along the edge. The penalty therefore grows as
/// the warp departs from the identity. It is `0.0` whenever `lambda` is not
/// strictly positive (including NaN).
#[inline]
fn dp_lambda_penalty(
    argvals: &[f64],
    sc: usize,
    tc: usize,
    sr: usize,
    tr: usize,
    lambda: f64,
) -> f64 {
    let active = lambda > 0.0;
    if !active {
        return 0.0;
    }
    let span_t = argvals[tc] - argvals[sc];
    let gradient = (argvals[tr] - argvals[sr]) / span_t;
    let excess = gradient - 1.0;
    lambda * excess.powi(2) * span_t
}
345
/// Recover the DP path by following `parent` links back from the bottom-right
/// cell of an `nrows × ncols` row-major grid.
///
/// Walking stops at cell 0 or at a cell without a parent (`u32::MAX`). The
/// path is returned from start to end as `(row, col)` pairs.
fn dp_traceback(parent: &[u32], nrows: usize, ncols: usize) -> Vec<(usize, usize)> {
    let mut node = (nrows - 1) * ncols + (ncols - 1);
    let mut backwards = Vec::with_capacity(nrows + ncols);
    backwards.push((node / ncols, node % ncols));
    while node != 0 && parent[node] != u32::MAX {
        node = parent[node] as usize;
        backwards.push((node / ncols, node % ncols));
    }
    backwards.reverse();
    backwards
}
362
/// Relax grid cell `(tr, tc)` from every predecessor allowed by `COPRIME_NBHD_7`.
///
/// `e` holds the best accumulated cost per cell and `parent` the row-major index
/// of the predecessor that achieved it. Both arrays are row-major with `ncols`
/// columns. For each offset `(dr, dc)`:
/// - predecessors outside the grid, or with infinite cost, are skipped;
/// - otherwise `e[src] + edge_cost(sr, sc, tr, tc)` replaces the target's cost
///   if it is strictly smaller.
#[inline]
fn dp_relax_cell<F>(
    e: &mut [f64],
    parent: &mut [u32],
    ncols: usize,
    tr: usize,
    tc: usize,
    edge_cost: &F,
) where
    F: Fn(usize, usize, usize, usize) -> f64,
{
    let idx = tr * ncols + tc;
    for &(dr, dc) in &COPRIME_NBHD_7 {
        // Predecessor would lie above or left of the grid.
        if dr > tr || dc > tc {
            continue;
        }
        let sr = tr - dr;
        let sc = tc - dc;
        let src_idx = sr * ncols + sc;
        // Unreached predecessors cannot improve the target.
        if e[src_idx] == f64::INFINITY {
            continue;
        }
        let cost = e[src_idx] + edge_cost(sr, sc, tr, tc);
        // Strict `<`: on ties, the earliest offset in COPRIME_NBHD_7 wins.
        if cost < e[idx] {
            e[idx] = cost;
            // NOTE(review): `as u32` silently truncates for grids with more than
            // u32::MAX cells; assumed never to happen for realistic sizes.
            parent[idx] = src_idx as u32;
        }
    }
}
393
394fn dp_grid_solve<F>(nrows: usize, ncols: usize, edge_cost: F) -> Vec<(usize, usize)>
400where
401 F: Fn(usize, usize, usize, usize) -> f64,
402{
403 let mut e = vec![f64::INFINITY; nrows * ncols];
404 let mut parent = vec![u32::MAX; nrows * ncols];
405 e[0] = 0.0;
406
407 for tr in 0..nrows {
408 for tc in 0..ncols {
409 if tr == 0 && tc == 0 {
410 continue;
411 }
412 dp_relax_cell(&mut e, &mut parent, ncols, tr, tc, &edge_cost);
413 }
414 }
415
416 dp_traceback(&parent, nrows, ncols)
417}
418
419fn dp_path_to_gamma(path: &[(usize, usize)], argvals: &[f64]) -> Vec<f64> {
421 let path_tc: Vec<f64> = path.iter().map(|&(_, c)| argvals[c]).collect();
422 let path_tr: Vec<f64> = path.iter().map(|&(r, _)| argvals[r]).collect();
423 let mut gamma: Vec<f64> = argvals
424 .iter()
425 .map(|&t| linear_interp(&path_tc, &path_tr, t))
426 .collect();
427 normalize_warp(&mut gamma, argvals);
428 gamma
429}
430
431fn dp_alignment_core(q1: &[f64], q2: &[f64], argvals: &[f64], lambda: f64) -> Vec<f64> {
437 let m = argvals.len();
438 if m < 2 {
439 return argvals.to_vec();
440 }
441
442 let norm1 = q1.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
443 let norm2 = q2.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
444 let q1n: Vec<f64> = q1.iter().map(|&v| v / norm1).collect();
445 let q2n: Vec<f64> = q2.iter().map(|&v| v / norm2).collect();
446
447 let path = dp_grid_solve(m, m, |sr, sc, tr, tc| {
448 dp_edge_weight(&q1n, &q2n, argvals, sc, tc, sr, tr)
449 + dp_lambda_penalty(argvals, sc, tc, sr, tr, lambda)
450 });
451
452 dp_path_to_gamma(&path, argvals)
453}
454
455pub fn elastic_align_pair(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> AlignmentResult {
471 let m = f1.len();
472
473 let f1_mat = FdMatrix::from_slice(f1, 1, m).unwrap();
475 let f2_mat = FdMatrix::from_slice(f2, 1, m).unwrap();
476
477 let q1_mat = srsf_transform(&f1_mat, argvals);
478 let q2_mat = srsf_transform(&f2_mat, argvals);
479
480 let q1: Vec<f64> = q1_mat.row(0);
481 let q2: Vec<f64> = q2_mat.row(0);
482
483 let gamma = dp_alignment_core(&q1, &q2, argvals, lambda);
485
486 let f_aligned = reparameterize_curve(f2, argvals, &gamma);
488
489 let f_aligned_mat = FdMatrix::from_slice(&f_aligned, 1, m).unwrap();
491 let q_aligned_mat = srsf_transform(&f_aligned_mat, argvals);
492 let q_aligned: Vec<f64> = q_aligned_mat.row(0);
493
494 let weights = simpsons_weights(argvals);
495 let distance = l2_distance(&q1, &q_aligned, &weights);
496
497 AlignmentResult {
498 gamma,
499 f_aligned,
500 distance,
501 }
502}
503
504pub fn elastic_distance(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> f64 {
514 elastic_align_pair(f1, f2, argvals, lambda).distance
515}
516
517pub fn align_to_target(
528 data: &FdMatrix,
529 target: &[f64],
530 argvals: &[f64],
531 lambda: f64,
532) -> AlignmentSetResult {
533 let (n, m) = data.shape();
534
535 let results: Vec<AlignmentResult> = iter_maybe_parallel!(0..n)
536 .map(|i| {
537 let fi = data.row(i);
538 elastic_align_pair(target, &fi, argvals, lambda)
539 })
540 .collect();
541
542 let mut gammas = FdMatrix::zeros(n, m);
543 let mut aligned_data = FdMatrix::zeros(n, m);
544 let mut distances = Vec::with_capacity(n);
545
546 for (i, r) in results.into_iter().enumerate() {
547 for j in 0..m {
548 gammas[(i, j)] = r.gamma[j];
549 aligned_data[(i, j)] = r.f_aligned[j];
550 }
551 distances.push(r.distance);
552 }
553
554 AlignmentSetResult {
555 gammas,
556 aligned_data,
557 distances,
558 }
559}
560
561pub fn elastic_self_distance_matrix(data: &FdMatrix, argvals: &[f64], lambda: f64) -> FdMatrix {
576 let n = data.nrows();
577
578 let upper_vals: Vec<f64> = iter_maybe_parallel!(0..n)
579 .flat_map(|i| {
580 let fi = data.row(i);
581 ((i + 1)..n)
582 .map(|j| {
583 let fj = data.row(j);
584 elastic_distance(&fi, &fj, argvals, lambda)
585 })
586 .collect::<Vec<_>>()
587 })
588 .collect();
589
590 let mut dist = FdMatrix::zeros(n, n);
591 let mut idx = 0;
592 for i in 0..n {
593 for j in (i + 1)..n {
594 let d = upper_vals[idx];
595 dist[(i, j)] = d;
596 dist[(j, i)] = d;
597 idx += 1;
598 }
599 }
600 dist
601}
602
603pub fn elastic_cross_distance_matrix(
614 data1: &FdMatrix,
615 data2: &FdMatrix,
616 argvals: &[f64],
617 lambda: f64,
618) -> FdMatrix {
619 let n1 = data1.nrows();
620 let n2 = data2.nrows();
621
622 let vals: Vec<f64> = iter_maybe_parallel!(0..n1)
623 .flat_map(|i| {
624 let fi = data1.row(i);
625 (0..n2)
626 .map(|j| {
627 let fj = data2.row(j);
628 elastic_distance(&fi, &fj, argvals, lambda)
629 })
630 .collect::<Vec<_>>()
631 })
632 .collect();
633
634 let mut dist = FdMatrix::zeros(n1, n2);
635 for i in 0..n1 {
636 for j in 0..n2 {
637 dist[(i, j)] = vals[i * n2 + j];
638 }
639 }
640 dist
641}
642
/// Amplitude/phase split of the elastic comparison between two curves
/// (produced by `elastic_decomposition`).
#[derive(Debug, Clone)]
pub struct DecompositionResult {
    /// Full pairwise alignment of the second curve to the first.
    pub alignment: AlignmentResult,
    /// Amplitude distance; equal to `alignment.distance`.
    pub d_amplitude: f64,
    /// Phase distance of `alignment.gamma` (`warping::phase_distance`).
    pub d_phase: f64,
}
655
656pub fn elastic_decomposition(
666 f1: &[f64],
667 f2: &[f64],
668 argvals: &[f64],
669 lambda: f64,
670) -> DecompositionResult {
671 let alignment = elastic_align_pair(f1, f2, argvals, lambda);
672 let d_amplitude = alignment.distance;
673 let d_phase = crate::warping::phase_distance(&alignment.gamma, argvals);
674 DecompositionResult {
675 alignment,
676 d_amplitude,
677 d_phase,
678 }
679}
680
681pub fn amplitude_distance(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> f64 {
683 elastic_distance(f1, f2, argvals, lambda)
684}
685
686pub fn phase_distance_pair(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> f64 {
688 let alignment = elastic_align_pair(f1, f2, argvals, lambda);
689 crate::warping::phase_distance(&alignment.gamma, argvals)
690}
691
692pub fn phase_self_distance_matrix(data: &FdMatrix, argvals: &[f64], lambda: f64) -> FdMatrix {
694 let n = data.nrows();
695
696 let upper_vals: Vec<f64> = iter_maybe_parallel!(0..n)
697 .flat_map(|i| {
698 let fi = data.row(i);
699 ((i + 1)..n)
700 .map(|j| {
701 let fj = data.row(j);
702 phase_distance_pair(&fi, &fj, argvals, lambda)
703 })
704 .collect::<Vec<_>>()
705 })
706 .collect();
707
708 let mut dist = FdMatrix::zeros(n, n);
709 let mut idx = 0;
710 for i in 0..n {
711 for j in (i + 1)..n {
712 let d = upper_vals[idx];
713 dist[(i, j)] = d;
714 dist[(j, i)] = d;
715 idx += 1;
716 }
717 }
718 dist
719}
720
721pub fn amplitude_self_distance_matrix(data: &FdMatrix, argvals: &[f64], lambda: f64) -> FdMatrix {
723 elastic_self_distance_matrix(data, argvals, lambda)
724}
725
/// Relative Euclidean change `‖q_new - q_old‖ / ‖q_old‖`.
///
/// The denominator is floored at `1e-10` so an all-zero `q_old` does not divide
/// by zero. Only the common prefix of the two slices is compared.
fn relative_change(q_old: &[f64], q_new: &[f64]) -> f64 {
    let squared_gap = q_old
        .iter()
        .zip(q_new)
        .map(|(a, b)| (a - b).powi(2))
        .sum::<f64>();
    let reference = q_old.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
    squared_gap.sqrt() / reference
}
742
743fn srsf_single(f: &[f64], argvals: &[f64]) -> Vec<f64> {
745 let m = f.len();
746 let mat = FdMatrix::from_slice(f, 1, m).unwrap();
747 let q_mat = srsf_transform(&mat, argvals);
748 q_mat.row(0)
749}
750
751fn align_srsf_pair(q1: &[f64], q2: &[f64], argvals: &[f64], lambda: f64) -> (Vec<f64>, Vec<f64>) {
753 let gamma = dp_alignment_core(q1, q2, argvals, lambda);
754
755 let q2_warped = reparameterize_curve(q2, argvals, &gamma);
757
758 let m = gamma.len();
760 let mut gamma_dot = vec![0.0; m];
761 gamma_dot[0] = (gamma[1] - gamma[0]) / (argvals[1] - argvals[0]);
762 for j in 1..(m - 1) {
763 gamma_dot[j] = (gamma[j + 1] - gamma[j - 1]) / (argvals[j + 1] - argvals[j - 1]);
764 }
765 gamma_dot[m - 1] = (gamma[m - 1] - gamma[m - 2]) / (argvals[m - 1] - argvals[m - 2]);
766
767 let q2_aligned: Vec<f64> = q2_warped
769 .iter()
770 .zip(gamma_dot.iter())
771 .map(|(&q, &gd)| q * gd.max(0.0).sqrt())
772 .collect();
773
774 (gamma, q2_aligned)
775}
776
777fn accumulate_alignments(
806 results: &[(Vec<f64>, Vec<f64>)],
807 gammas: &mut FdMatrix,
808 m: usize,
809 n: usize,
810) -> Vec<f64> {
811 let mut mu_q_new = vec![0.0; m];
812 for (i, (gamma, q_aligned)) in results.iter().enumerate() {
813 for j in 0..m {
814 gammas[(i, j)] = gamma[j];
815 mu_q_new[j] += q_aligned[j];
816 }
817 }
818 for j in 0..m {
819 mu_q_new[j] /= n as f64;
820 }
821 mu_q_new
822}
823
824fn apply_stored_warps(data: &FdMatrix, gammas: &FdMatrix, argvals: &[f64]) -> FdMatrix {
826 let (n, m) = data.shape();
827 let mut aligned = FdMatrix::zeros(n, m);
828 for i in 0..n {
829 let fi = data.row(i);
830 let gamma: Vec<f64> = (0..m).map(|j| gammas[(i, j)]).collect();
831 let f_aligned = reparameterize_curve(&fi, argvals, &gamma);
832 for j in 0..m {
833 aligned[(i, j)] = f_aligned[j];
834 }
835 }
836 aligned
837}
838
839fn select_template(srsf_mat: &FdMatrix, data: &FdMatrix, argvals: &[f64]) -> (Vec<f64>, Vec<f64>) {
841 let (n, m) = srsf_mat.shape();
842 let mnq = mean_1d(srsf_mat);
843 let mut min_dist = f64::INFINITY;
844 let mut min_idx = 0;
845 for i in 0..n {
846 let dist_sq: f64 = (0..m).map(|j| (srsf_mat[(i, j)] - mnq[j]).powi(2)).sum();
847 if dist_sq < min_dist {
848 min_dist = dist_sq;
849 min_idx = i;
850 }
851 }
852 let _ = argvals; (srsf_mat.row(min_idx), data.row(min_idx))
854}
855
856fn pre_center_template(
858 data: &FdMatrix,
859 mu_q: &[f64],
860 mu: &[f64],
861 argvals: &[f64],
862 lambda: f64,
863) -> (Vec<f64>, Vec<f64>) {
864 let (n, m) = data.shape();
865 let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
866 .map(|i| {
867 let fi = data.row(i);
868 let qi = srsf_single(&fi, argvals);
869 align_srsf_pair(mu_q, &qi, argvals, lambda)
870 })
871 .collect();
872
873 let mut init_gammas = FdMatrix::zeros(n, m);
874 for (i, (gamma, _)) in align_results.iter().enumerate() {
875 for j in 0..m {
876 init_gammas[(i, j)] = gamma[j];
877 }
878 }
879
880 let gam_inv = sqrt_mean_inverse(&init_gammas, argvals);
881 let mu_new = reparameterize_curve(mu, argvals, &gam_inv);
882 let mu_q_new = srsf_single(&mu_new, argvals);
883 (mu_q_new, mu_new)
884}
885
886fn post_center_results(
888 data: &FdMatrix,
889 mu_q: &[f64],
890 final_gammas: &mut FdMatrix,
891 argvals: &[f64],
892) -> (Vec<f64>, Vec<f64>, FdMatrix) {
893 let (n, m) = data.shape();
894 let gam_inv = sqrt_mean_inverse(final_gammas, argvals);
895 let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
896 let gam_inv_dev = gradient_uniform(&gam_inv, h);
897
898 let mu_q_warped = reparameterize_curve(mu_q, argvals, &gam_inv);
899 let mu_q_centered: Vec<f64> = mu_q_warped
900 .iter()
901 .zip(gam_inv_dev.iter())
902 .map(|(&q, &gd)| q * gd.max(0.0).sqrt())
903 .collect();
904
905 for i in 0..n {
906 let gam_i: Vec<f64> = (0..m).map(|j| final_gammas[(i, j)]).collect();
907 let gam_centered = reparameterize_curve(&gam_i, argvals, &gam_inv);
908 for j in 0..m {
909 final_gammas[(i, j)] = gam_centered[j];
910 }
911 }
912
913 let initial_mean = mean_1d(data);
914 let mu = srsf_inverse(&mu_q_centered, argvals, initial_mean[0]);
915 let final_aligned = apply_stored_warps(data, final_gammas, argvals);
916 (mu, mu_q_centered, final_aligned)
917}
918
919pub fn karcher_mean(
920 data: &FdMatrix,
921 argvals: &[f64],
922 max_iter: usize,
923 tol: f64,
924 lambda: f64,
925) -> KarcherMeanResult {
926 let (n, m) = data.shape();
927
928 let srsf_mat = srsf_transform(data, argvals);
929 let (mut mu_q, mu) = select_template(&srsf_mat, data, argvals);
930 let (mu_q_c, mu_c) = pre_center_template(data, &mu_q, &mu, argvals, lambda);
931 mu_q = mu_q_c;
932 let mut mu = mu_c;
933
934 let mut converged = false;
935 let mut n_iter = 0;
936 let mut final_gammas = FdMatrix::zeros(n, m);
937
938 for iter in 0..max_iter {
939 n_iter = iter + 1;
940
941 let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
942 .map(|i| {
943 let fi = data.row(i);
944 let qi = srsf_single(&fi, argvals);
945 align_srsf_pair(&mu_q, &qi, argvals, lambda)
946 })
947 .collect();
948
949 let mu_q_new = accumulate_alignments(&align_results, &mut final_gammas, m, n);
950
951 let rel = relative_change(&mu_q, &mu_q_new);
952 if rel < tol {
953 converged = true;
954 mu_q = mu_q_new;
955 break;
956 }
957
958 mu_q = mu_q_new;
959 mu = srsf_inverse(&mu_q, argvals, mu[0]);
960 }
961
962 let (mu_final, mu_q_final, final_aligned) =
963 post_center_results(data, &mu_q, &mut final_gammas, argvals);
964
965 KarcherMeanResult {
966 mean: mu_final,
967 mean_srsf: mu_q_final,
968 gammas: final_gammas,
969 aligned_data: final_aligned,
970 n_iter,
971 converged,
972 }
973}
974
/// Transported SRSF (TSRVF) representation of a set of aligned curves.
///
/// Tangent vectors live at the unit-normalised, smoothed mean SRSF, on the
/// normalised `[0, 1]` time grid.
#[derive(Debug, Clone)]
pub struct TsrvfResult {
    /// Row `i` is the tangent vector representing curve `i`.
    pub tangent_vectors: FdMatrix,
    /// Karcher mean curve in function space.
    pub mean: Vec<f64>,
    /// Smoothed mean SRSF (before unit normalisation).
    pub mean_srsf: Vec<f64>,
    /// L2 norm of `mean_srsf` on the `[0, 1]` grid.
    pub mean_srsf_norm: f64,
    /// L2 norm of each curve's smoothed aligned SRSF, used to undo the unit
    /// normalisation on reconstruction.
    pub srsf_norms: Vec<f64>,
    /// First sample of each aligned curve, used as the integration constant on
    /// reconstruction.
    pub initial_values: Vec<f64>,
    /// Warping functions from the underlying Karcher alignment.
    pub gammas: FdMatrix,
    /// Whether the underlying Karcher iteration converged.
    pub converged: bool,
}
1000
1001pub fn tsrvf_transform(
1012 data: &FdMatrix,
1013 argvals: &[f64],
1014 max_iter: usize,
1015 tol: f64,
1016 lambda: f64,
1017) -> TsrvfResult {
1018 let karcher = karcher_mean(data, argvals, max_iter, tol, lambda);
1019 tsrvf_from_alignment(&karcher, argvals)
1020}
1021
1022fn smooth_aligned_srsfs(srsf: &FdMatrix, m: usize) -> FdMatrix {
1028 let n = srsf.nrows();
1029 let time: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
1030 let bandwidth = 2.0 / (m - 1) as f64;
1031
1032 let mut smoothed = FdMatrix::zeros(n, m);
1033 for i in 0..n {
1034 let qi = srsf.row(i);
1035 let qi_smooth = nadaraya_watson(&time, &qi, &time, bandwidth, "gaussian");
1036 for j in 0..m {
1037 smoothed[(i, j)] = qi_smooth[j];
1038 }
1039 }
1040 smoothed
1041}
1042
1043pub fn tsrvf_from_alignment(karcher: &KarcherMeanResult, argvals: &[f64]) -> TsrvfResult {
1055 let (n, m) = karcher.aligned_data.shape();
1056
1057 let aligned_srsf = srsf_transform(&karcher.aligned_data, argvals);
1059
1060 let aligned_srsf = smooth_aligned_srsfs(&aligned_srsf, m);
1072
1073 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
1077 let bandwidth = 2.0 / (m - 1) as f64;
1078 let mean_srsf_smooth = nadaraya_watson(&time, &karcher.mean_srsf, &time, bandwidth, "gaussian");
1079 let mean_norm = l2_norm_l2(&mean_srsf_smooth, &time);
1080
1081 let mu_unit: Vec<f64> = if mean_norm > 1e-10 {
1082 mean_srsf_smooth.iter().map(|&q| q / mean_norm).collect()
1083 } else {
1084 vec![0.0; m]
1085 };
1086
1087 let srsf_norms: Vec<f64> = iter_maybe_parallel!(0..n)
1089 .map(|i| {
1090 let qi = aligned_srsf.row(i);
1091 l2_norm_l2(&qi, &time)
1092 })
1093 .collect();
1094
1095 let tangent_data: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
1096 .map(|i| {
1097 let qi = aligned_srsf.row(i);
1098 let qi_norm = srsf_norms[i];
1099
1100 if qi_norm < 1e-10 || mean_norm < 1e-10 {
1101 return vec![0.0; m];
1102 }
1103
1104 let qi_unit: Vec<f64> = qi.iter().map(|&q| q / qi_norm).collect();
1106
1107 inv_exp_map_sphere(&mu_unit, &qi_unit, &time)
1109 })
1110 .collect();
1111
1112 let mut tangent_vectors = FdMatrix::zeros(n, m);
1114 for i in 0..n {
1115 for j in 0..m {
1116 tangent_vectors[(i, j)] = tangent_data[i][j];
1117 }
1118 }
1119
1120 let initial_values: Vec<f64> = (0..n).map(|i| karcher.aligned_data[(i, 0)]).collect();
1123
1124 TsrvfResult {
1125 tangent_vectors,
1126 mean: karcher.mean.clone(),
1127 mean_srsf: mean_srsf_smooth,
1128 mean_srsf_norm: mean_norm,
1129 srsf_norms,
1130 initial_values,
1131 gammas: karcher.gammas.clone(),
1132 converged: karcher.converged,
1133 }
1134}
1135
1136pub fn tsrvf_inverse(tsrvf: &TsrvfResult, argvals: &[f64]) -> FdMatrix {
1148 let (n, m) = tsrvf.tangent_vectors.shape();
1149
1150 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
1151
1152 let mu_unit: Vec<f64> = if tsrvf.mean_srsf_norm > 1e-10 {
1154 tsrvf
1155 .mean_srsf
1156 .iter()
1157 .map(|&q| q / tsrvf.mean_srsf_norm)
1158 .collect()
1159 } else {
1160 vec![0.0; m]
1161 };
1162
1163 let curves: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
1164 .map(|i| {
1165 let vi = tsrvf.tangent_vectors.row(i);
1166
1167 let qi_unit = exp_map_sphere(&mu_unit, &vi, &time);
1169
1170 let qi: Vec<f64> = qi_unit.iter().map(|&q| q * tsrvf.srsf_norms[i]).collect();
1172
1173 srsf_inverse(&qi, argvals, tsrvf.initial_values[i])
1175 })
1176 .collect();
1177
1178 let mut result = FdMatrix::zeros(n, m);
1179 for i in 0..n {
1180 for j in 0..m {
1181 result[(i, j)] = curves[i][j];
1182 }
1183 }
1184 result
1185}
1186
/// How aligned SRSFs are carried into the tangent space at the mean when
/// computing TSRVF coordinates.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TransportMethod {
    /// Inverse exponential (log) map at the mean, applied directly (default).
    #[default]
    LogMap,
    /// Transport via Schild's ladder (`parallel_transport_schilds`).
    SchildsLadder,
    /// Transport via the pole ladder (`parallel_transport_pole`).
    PoleLadder,
}
1200
1201fn parallel_transport_schilds(v: &[f64], from: &[f64], to: &[f64], time: &[f64]) -> Vec<f64> {
1203 use crate::warping::{exp_map_sphere, inv_exp_map_sphere};
1204
1205 let v_norm = crate::warping::l2_norm_l2(v, time);
1206 if v_norm < 1e-10 {
1207 return vec![0.0; v.len()];
1208 }
1209
1210 let endpoint = exp_map_sphere(from, v, time);
1212
1213 let log_to_ep = inv_exp_map_sphere(to, &endpoint, time);
1215
1216 let half_log: Vec<f64> = log_to_ep.iter().map(|&x| 0.5 * x).collect();
1218 let midpoint = exp_map_sphere(to, &half_log, time);
1219
1220 let log_to_mid = inv_exp_map_sphere(to, &midpoint, time);
1222 log_to_mid.iter().map(|&x| 2.0 * x).collect()
1223}
1224
1225fn parallel_transport_pole(v: &[f64], from: &[f64], to: &[f64], time: &[f64]) -> Vec<f64> {
1227 use crate::warping::{exp_map_sphere, inv_exp_map_sphere};
1228
1229 let v_norm = crate::warping::l2_norm_l2(v, time);
1230 if v_norm < 1e-10 {
1231 return vec![0.0; v.len()];
1232 }
1233
1234 let neg_v: Vec<f64> = v.iter().map(|&x| -x).collect();
1236 let pole = exp_map_sphere(from, &neg_v, time);
1237
1238 let log_to_pole = inv_exp_map_sphere(to, &pole, time);
1240
1241 let half_log: Vec<f64> = log_to_pole.iter().map(|&x| 0.5 * x).collect();
1243 let midpoint = exp_map_sphere(to, &half_log, time);
1244
1245 let log_to_mid = inv_exp_map_sphere(to, &midpoint, time);
1247 log_to_mid.iter().map(|&x| -2.0 * x).collect()
1248}
1249
1250pub fn tsrvf_transform_with_method(
1254 data: &FdMatrix,
1255 argvals: &[f64],
1256 max_iter: usize,
1257 tol: f64,
1258 lambda: f64,
1259 method: TransportMethod,
1260) -> TsrvfResult {
1261 let karcher = karcher_mean(data, argvals, max_iter, tol, lambda);
1262 tsrvf_from_alignment_with_method(&karcher, argvals, method)
1263}
1264
1265pub fn tsrvf_from_alignment_with_method(
1272 karcher: &KarcherMeanResult,
1273 argvals: &[f64],
1274 method: TransportMethod,
1275) -> TsrvfResult {
1276 if method == TransportMethod::LogMap {
1277 return tsrvf_from_alignment(karcher, argvals);
1278 }
1279
1280 let (n, m) = karcher.aligned_data.shape();
1281 let aligned_srsf = srsf_transform(&karcher.aligned_data, argvals);
1282 let aligned_srsf = smooth_aligned_srsfs(&aligned_srsf, m);
1283 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
1284 let bandwidth = 2.0 / (m - 1) as f64;
1285 let mean_srsf_smooth = nadaraya_watson(&time, &karcher.mean_srsf, &time, bandwidth, "gaussian");
1286 let mean_norm = crate::warping::l2_norm_l2(&mean_srsf_smooth, &time);
1287
1288 let mu_unit: Vec<f64> = if mean_norm > 1e-10 {
1289 mean_srsf_smooth.iter().map(|&q| q / mean_norm).collect()
1290 } else {
1291 vec![0.0; m]
1292 };
1293
1294 let srsf_norms: Vec<f64> = iter_maybe_parallel!(0..n)
1295 .map(|i| {
1296 let qi = aligned_srsf.row(i);
1297 crate::warping::l2_norm_l2(&qi, &time)
1298 })
1299 .collect();
1300
1301 let tangent_data: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
1302 .map(|i| {
1303 let qi = aligned_srsf.row(i);
1304 let qi_norm = srsf_norms[i];
1305
1306 if qi_norm < 1e-10 || mean_norm < 1e-10 {
1307 return vec![0.0; m];
1308 }
1309
1310 let qi_unit: Vec<f64> = qi.iter().map(|&q| q / qi_norm).collect();
1311
1312 let v_at_qi = inv_exp_map_sphere(&qi_unit, &mu_unit, &time);
1314 let neg_v: Vec<f64> = v_at_qi.iter().map(|&x| -x).collect();
1315
1316 match method {
1318 TransportMethod::SchildsLadder => {
1319 parallel_transport_schilds(&neg_v, &qi_unit, &mu_unit, &time)
1320 }
1321 TransportMethod::PoleLadder => {
1322 parallel_transport_pole(&neg_v, &qi_unit, &mu_unit, &time)
1323 }
1324 TransportMethod::LogMap => unreachable!(),
1325 }
1326 })
1327 .collect();
1328
1329 let mut tangent_vectors = FdMatrix::zeros(n, m);
1330 for i in 0..n {
1331 for j in 0..m {
1332 tangent_vectors[(i, j)] = tangent_data[i][j];
1333 }
1334 }
1335
1336 let initial_values: Vec<f64> = (0..n).map(|i| karcher.aligned_data[(i, 0)]).collect();
1337
1338 TsrvfResult {
1339 tangent_vectors,
1340 mean: karcher.mean.clone(),
1341 mean_srsf: mean_srsf_smooth,
1342 mean_srsf_norm: mean_norm,
1343 srsf_norms,
1344 initial_values,
1345 gammas: karcher.gammas.clone(),
1346 converged: karcher.converged,
1347 }
1348}
1349
/// Diagnostics describing how much a Karcher alignment changed the data
/// (produced by `alignment_quality`).
#[derive(Debug, Clone)]
pub struct AlignmentQuality {
    /// Per-curve phase distance of the warp from the identity (`warp_complexity`).
    pub warp_complexity: Vec<f64>,
    /// Mean of `warp_complexity`.
    pub mean_warp_complexity: f64,
    /// Per-curve integrated squared second derivative of the warp (`warp_smoothness`).
    pub warp_smoothness: Vec<f64>,
    /// Mean of `warp_smoothness`.
    pub mean_warp_smoothness: f64,
    /// Mean squared weighted L2 distance of the original curves to their mean.
    pub total_variance: f64,
    /// Mean squared weighted L2 distance of the aligned curves to their mean.
    pub amplitude_variance: f64,
    /// `max(total_variance - amplitude_variance, 0)`.
    pub phase_variance: f64,
    /// `phase_variance / total_variance` (0 when the total is ≤ 1e-10).
    /// NOTE(review): despite its name this is a phase-to-*total* ratio.
    pub phase_amplitude_ratio: f64,
    /// Per grid point: aligned cross-sectional variance divided by original
    /// variance (1.0 where the original variance is ≤ 1e-15).
    pub pointwise_variance_ratio: Vec<f64>,
    /// Mean of `pointwise_variance_ratio`; values below 1 mean alignment reduced
    /// pointwise spread.
    pub mean_variance_reduction: f64,
}
1376
1377pub fn warp_complexity(gamma: &[f64], argvals: &[f64]) -> f64 {
1381 crate::warping::phase_distance(gamma, argvals)
1382}
1383
1384pub fn warp_smoothness(gamma: &[f64], argvals: &[f64]) -> f64 {
1386 let m = gamma.len();
1387 if m < 3 {
1388 return 0.0;
1389 }
1390
1391 let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
1392 let gam_prime = gradient_uniform(gamma, h);
1393 let gam_pprime = gradient_uniform(&gam_prime, h);
1394
1395 let integrand: Vec<f64> = gam_pprime.iter().map(|&g| g * g).collect();
1396 crate::helpers::trapz(&integrand, argvals)
1397}
1398
1399pub fn alignment_quality(
1406 data: &FdMatrix,
1407 karcher: &KarcherMeanResult,
1408 argvals: &[f64],
1409) -> AlignmentQuality {
1410 let (n, m) = data.shape();
1411 let weights = simpsons_weights(argvals);
1412
1413 let wc: Vec<f64> = (0..n)
1415 .map(|i| {
1416 let gamma: Vec<f64> = (0..m).map(|j| karcher.gammas[(i, j)]).collect();
1417 warp_complexity(&gamma, argvals)
1418 })
1419 .collect();
1420 let ws: Vec<f64> = (0..n)
1421 .map(|i| {
1422 let gamma: Vec<f64> = (0..m).map(|j| karcher.gammas[(i, j)]).collect();
1423 warp_smoothness(&gamma, argvals)
1424 })
1425 .collect();
1426
1427 let mean_wc = wc.iter().sum::<f64>() / n as f64;
1428 let mean_ws = ws.iter().sum::<f64>() / n as f64;
1429
1430 let orig_mean = crate::fdata::mean_1d(data);
1432
1433 let total_var: f64 = (0..n)
1435 .map(|i| {
1436 let fi = data.row(i);
1437 let d = l2_distance(&fi, &orig_mean, &weights);
1438 d * d
1439 })
1440 .sum::<f64>()
1441 / n as f64;
1442
1443 let aligned_mean = crate::fdata::mean_1d(&karcher.aligned_data);
1445
1446 let amp_var: f64 = (0..n)
1448 .map(|i| {
1449 let fi = karcher.aligned_data.row(i);
1450 let d = l2_distance(&fi, &aligned_mean, &weights);
1451 d * d
1452 })
1453 .sum::<f64>()
1454 / n as f64;
1455
1456 let phase_var = (total_var - amp_var).max(0.0);
1457 let ratio = if total_var > 1e-10 {
1458 phase_var / total_var
1459 } else {
1460 0.0
1461 };
1462
1463 let mut pw_ratio = vec![0.0; m];
1465 for j in 0..m {
1466 let col_orig = data.column(j);
1467 let mean_orig = col_orig.iter().sum::<f64>() / n as f64;
1468 let var_orig: f64 = col_orig
1469 .iter()
1470 .map(|&v| (v - mean_orig).powi(2))
1471 .sum::<f64>()
1472 / n as f64;
1473
1474 let col_aligned = karcher.aligned_data.column(j);
1475 let mean_aligned = col_aligned.iter().sum::<f64>() / n as f64;
1476 let var_aligned: f64 = col_aligned
1477 .iter()
1478 .map(|&v| (v - mean_aligned).powi(2))
1479 .sum::<f64>()
1480 / n as f64;
1481
1482 pw_ratio[j] = if var_orig > 1e-15 {
1483 var_aligned / var_orig
1484 } else {
1485 1.0
1486 };
1487 }
1488
1489 let mean_vr = pw_ratio.iter().sum::<f64>() / m as f64;
1490
1491 AlignmentQuality {
1492 warp_complexity: wc,
1493 mean_warp_complexity: mean_wc,
1494 warp_smoothness: ws,
1495 mean_warp_smoothness: mean_ws,
1496 total_variance: total_var,
1497 amplitude_variance: amp_var,
1498 phase_variance: phase_var,
1499 phase_amplitude_ratio: ratio,
1500 pointwise_variance_ratio: pw_ratio,
1501 mean_variance_reduction: mean_vr,
1502 }
1503}
1504
/// All index triplets `(i, j, k)` with `i < j < k < n`, in lexicographic order.
///
/// The list is truncated to the first `max_triplets` entries when
/// `max_triplets > 0`; `0` means "no limit".
///
/// The iterator already stops after the last triplet, so no `C(n, 3)` bound is
/// needed. Computing that bound as `n * (n - 1) * (n - 2) / 6` underflows for
/// `n < 2` (a panic in debug builds) and can overflow for large `n`.
fn triplet_indices(n: usize, max_triplets: usize) -> Vec<(usize, usize, usize)> {
    let cap = if max_triplets > 0 {
        max_triplets
    } else {
        usize::MAX
    };
    (0..n)
        .flat_map(|i| ((i + 1)..n).flat_map(move |j| ((j + 1)..n).map(move |k| (i, j, k))))
        .take(cap)
        .collect()
}
1518
1519fn triplet_warp_deviation(
1521 data: &FdMatrix,
1522 argvals: &[f64],
1523 weights: &[f64],
1524 i: usize,
1525 j: usize,
1526 k: usize,
1527 lambda: f64,
1528) -> f64 {
1529 let fi = data.row(i);
1530 let fj = data.row(j);
1531 let fk = data.row(k);
1532 let rij = elastic_align_pair(&fi, &fj, argvals, lambda);
1533 let rjk = elastic_align_pair(&fj, &fk, argvals, lambda);
1534 let rik = elastic_align_pair(&fi, &fk, argvals, lambda);
1535 let composed = compose_warps(&rij.gamma, &rjk.gamma, argvals);
1536 l2_distance(&composed, &rik.gamma, weights)
1537}
1538
1539pub fn pairwise_consistency(
1550 data: &FdMatrix,
1551 argvals: &[f64],
1552 lambda: f64,
1553 max_triplets: usize,
1554) -> f64 {
1555 let n = data.nrows();
1556 if n < 3 {
1557 return 0.0;
1558 }
1559
1560 let weights = simpsons_weights(argvals);
1561 let triplets = triplet_indices(n, max_triplets);
1562 if triplets.is_empty() {
1563 return 0.0;
1564 }
1565
1566 let total_dev: f64 = triplets
1567 .iter()
1568 .map(|&(i, j, k)| triplet_warp_deviation(data, argvals, &weights, i, j, k, lambda))
1569 .sum();
1570 total_dev / triplets.len() as f64
1571}
1572
/// Result of an elastic alignment of `f2` onto `f1` in which the warp is constrained to
/// pass through landmark correspondences.
#[derive(Debug, Clone)]
pub struct ConstrainedAlignmentResult {
    /// Warping function sampled on `argvals`; `f_aligned` is `f2` reparameterized by it.
    pub gamma: Vec<f64>,
    /// `f2` reparameterized by `gamma` (via `reparameterize_curve`).
    pub f_aligned: Vec<f64>,
    /// Elastic distance between `f1` and the aligned `f2`; for the constrained path this is
    /// the Simpson-weighted L2 distance between their SRSFs.
    pub distance: f64,
    /// Landmark pairs that were actually used as waypoints, as grid-snapped values
    /// `(time on f1, time on f2)`. Landmarks rejected while building waypoints are absent.
    pub enforced_landmarks: Vec<(f64, f64)>,
}
1587
/// Returns the index of the grid point in `argvals` closest to `t_val`.
///
/// On an exact tie the earlier index wins. Values outside the grid snap to the nearest end.
///
/// # Panics
/// Panics if `argvals` is empty.
fn snap_to_grid(t_val: f64, argvals: &[f64]) -> usize {
    let distance_to = |a: f64| (t_val - a).abs();
    let seed = (0usize, distance_to(argvals[0]));
    let (nearest, _) = argvals
        .iter()
        .enumerate()
        .skip(1)
        .fold(seed, |(best_idx, best_d), (idx, &a)| {
            let d = distance_to(a);
            // Strict `<` keeps the first minimum on ties.
            if d < best_d {
                (idx, d)
            } else {
                (best_idx, best_d)
            }
        });
    nearest
}
1601
1602fn dp_segment(
1607 q1: &[f64],
1608 q2: &[f64],
1609 argvals: &[f64],
1610 sc: usize,
1611 ec: usize,
1612 sr: usize,
1613 er: usize,
1614 lambda: f64,
1615) -> Vec<(usize, usize)> {
1616 let nc = ec - sc + 1;
1617 let nr = er - sr + 1;
1618
1619 if nc <= 1 || nr <= 1 {
1620 return vec![(sc, sr), (ec, er)];
1621 }
1622
1623 let path = dp_grid_solve(nr, nc, |local_sr, local_sc, local_tr, local_tc| {
1624 let gsr = sr + local_sr;
1625 let gsc = sc + local_sc;
1626 let gtr = sr + local_tr;
1627 let gtc = sc + local_tc;
1628 dp_edge_weight(q1, q2, argvals, gsc, gtc, gsr, gtr)
1629 + dp_lambda_penalty(argvals, gsc, gtc, gsr, gtr, lambda)
1630 });
1631
1632 path.iter().map(|&(lr, lc)| (sc + lc, sr + lr)).collect()
1634}
1635
1636fn build_constrained_waypoints(
1652 landmark_pairs: &[(f64, f64)],
1653 argvals: &[f64],
1654 m: usize,
1655) -> Vec<(usize, usize)> {
1656 let mut waypoints: Vec<(usize, usize)> = Vec::with_capacity(landmark_pairs.len() + 2);
1657 waypoints.push((0, 0));
1658 for &(tt, st) in landmark_pairs {
1659 let tc = snap_to_grid(tt, argvals);
1660 let tr = snap_to_grid(st, argvals);
1661 if let Some(&(prev_c, prev_r)) = waypoints.last() {
1662 if tc > prev_c && tr > prev_r {
1663 waypoints.push((tc, tr));
1664 }
1665 }
1666 }
1667 let last = m - 1;
1668 if let Some(&(prev_c, prev_r)) = waypoints.last() {
1669 if prev_c != last || prev_r != last {
1670 waypoints.push((last, last));
1671 }
1672 }
1673 waypoints
1674}
1675
1676fn segmented_dp_gamma(
1678 q1n: &[f64],
1679 q2n: &[f64],
1680 argvals: &[f64],
1681 waypoints: &[(usize, usize)],
1682 lambda: f64,
1683) -> Vec<f64> {
1684 let mut full_path_tc: Vec<f64> = Vec::new();
1685 let mut full_path_tr: Vec<f64> = Vec::new();
1686
1687 for seg in 0..(waypoints.len() - 1) {
1688 let (sc, sr) = waypoints[seg];
1689 let (ec, er) = waypoints[seg + 1];
1690 let segment_path = dp_segment(q1n, q2n, argvals, sc, ec, sr, er, lambda);
1691 let start = if seg > 0 { 1 } else { 0 };
1692 for &(tc, tr) in &segment_path[start..] {
1693 full_path_tc.push(argvals[tc]);
1694 full_path_tr.push(argvals[tr]);
1695 }
1696 }
1697
1698 let mut gamma: Vec<f64> = argvals
1699 .iter()
1700 .map(|&t| linear_interp(&full_path_tc, &full_path_tr, t))
1701 .collect();
1702 normalize_warp(&mut gamma, argvals);
1703 gamma
1704}
1705
1706pub fn elastic_align_pair_constrained(
1707 f1: &[f64],
1708 f2: &[f64],
1709 argvals: &[f64],
1710 landmark_pairs: &[(f64, f64)],
1711 lambda: f64,
1712) -> ConstrainedAlignmentResult {
1713 let m = f1.len();
1714
1715 if landmark_pairs.is_empty() {
1716 let r = elastic_align_pair(f1, f2, argvals, lambda);
1717 return ConstrainedAlignmentResult {
1718 gamma: r.gamma,
1719 f_aligned: r.f_aligned,
1720 distance: r.distance,
1721 enforced_landmarks: Vec::new(),
1722 };
1723 }
1724
1725 let f1_mat = FdMatrix::from_slice(f1, 1, m).unwrap();
1727 let f2_mat = FdMatrix::from_slice(f2, 1, m).unwrap();
1728 let q1_mat = srsf_transform(&f1_mat, argvals);
1729 let q2_mat = srsf_transform(&f2_mat, argvals);
1730 let q1: Vec<f64> = q1_mat.row(0);
1731 let q2: Vec<f64> = q2_mat.row(0);
1732 let norm1 = q1.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1733 let norm2 = q2.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1734 let q1n: Vec<f64> = q1.iter().map(|&v| v / norm1).collect();
1735 let q2n: Vec<f64> = q2.iter().map(|&v| v / norm2).collect();
1736
1737 let waypoints = build_constrained_waypoints(landmark_pairs, argvals, m);
1738 let gamma = segmented_dp_gamma(&q1n, &q2n, argvals, &waypoints, lambda);
1739
1740 let f_aligned = reparameterize_curve(f2, argvals, &gamma);
1741 let f_aligned_mat = FdMatrix::from_slice(&f_aligned, 1, m).unwrap();
1742 let q_aligned_mat = srsf_transform(&f_aligned_mat, argvals);
1743 let q_aligned: Vec<f64> = q_aligned_mat.row(0);
1744 let weights = simpsons_weights(argvals);
1745 let distance = l2_distance(&q1, &q_aligned, &weights);
1746
1747 let enforced: Vec<(f64, f64)> = waypoints[1..waypoints.len() - 1]
1748 .iter()
1749 .map(|&(tc, tr)| (argvals[tc], argvals[tr]))
1750 .collect();
1751
1752 ConstrainedAlignmentResult {
1753 gamma,
1754 f_aligned,
1755 distance,
1756 enforced_landmarks: enforced,
1757 }
1758}
1759
1760pub fn elastic_align_pair_with_landmarks(
1774 f1: &[f64],
1775 f2: &[f64],
1776 argvals: &[f64],
1777 kind: crate::landmark::LandmarkKind,
1778 min_prominence: f64,
1779 expected_count: usize,
1780 lambda: f64,
1781) -> ConstrainedAlignmentResult {
1782 let lm1 = crate::landmark::detect_landmarks(f1, argvals, kind, min_prominence);
1783 let lm2 = crate::landmark::detect_landmarks(f2, argvals, kind, min_prominence);
1784
1785 let n_match = if expected_count > 0 {
1787 expected_count.min(lm1.len()).min(lm2.len())
1788 } else {
1789 lm1.len().min(lm2.len())
1790 };
1791
1792 let pairs: Vec<(f64, f64)> = (0..n_match)
1793 .map(|i| (lm1[i].position, lm2[i].position))
1794 .collect();
1795
1796 elastic_align_pair_constrained(f1, f2, argvals, &pairs, lambda)
1797}
1798
1799use crate::matrix::FdCurveSet;
1802
/// Result of elastically aligning a multivariate curve `f2` (one vector per dimension) onto `f1`.
#[derive(Debug, Clone)]
pub struct AlignmentResultNd {
    /// Single warping function sampled on `argvals`, applied to every dimension of `f2`.
    pub gamma: Vec<f64>,
    /// Aligned curve: entry `k` is dimension `k` of `f2` reparameterized by `gamma`.
    pub f_aligned: Vec<Vec<f64>>,
    /// Square root of the sum over dimensions of the squared Simpson-weighted L2 distances
    /// between the SRSF components of `f1` and of the aligned `f2`.
    pub distance: f64,
}
1813
1814#[inline]
1827fn srsf_scale_point(derivs: &[FdMatrix], result_dims: &mut [FdMatrix], i: usize, j: usize) {
1828 let d = derivs.len();
1829 let norm_sq: f64 = derivs.iter().map(|dd| dd[(i, j)].powi(2)).sum();
1830 let norm = norm_sq.sqrt();
1831 if norm < 1e-15 {
1832 for k in 0..d {
1833 result_dims[k][(i, j)] = 0.0;
1834 }
1835 } else {
1836 let scale = 1.0 / norm.sqrt();
1837 for k in 0..d {
1838 result_dims[k][(i, j)] = derivs[k][(i, j)] * scale;
1839 }
1840 }
1841}
1842
1843pub fn srsf_transform_nd(data: &FdCurveSet, argvals: &[f64]) -> FdCurveSet {
1844 let d = data.ndim();
1845 let n = data.ncurves();
1846 let m = data.npoints();
1847
1848 if d == 0 || n == 0 || m == 0 || argvals.len() != m {
1849 return FdCurveSet {
1850 dims: (0..d).map(|_| FdMatrix::zeros(n, m)).collect(),
1851 };
1852 }
1853
1854 let derivs: Vec<FdMatrix> = data
1855 .dims
1856 .iter()
1857 .map(|dim_mat| crate::fdata::deriv_1d(dim_mat, argvals, 1))
1858 .collect();
1859
1860 let mut result_dims: Vec<FdMatrix> = (0..d).map(|_| FdMatrix::zeros(n, m)).collect();
1861 for i in 0..n {
1862 for j in 0..m {
1863 srsf_scale_point(&derivs, &mut result_dims, i, j);
1864 }
1865 }
1866
1867 FdCurveSet { dims: result_dims }
1868}
1869
1870pub fn srsf_inverse_nd(q: &[Vec<f64>], argvals: &[f64], f0: &[f64]) -> Vec<Vec<f64>> {
1883 let d = q.len();
1884 if d == 0 {
1885 return Vec::new();
1886 }
1887 let m = q[0].len();
1888 if m == 0 {
1889 return vec![Vec::new(); d];
1890 }
1891
1892 let norms: Vec<f64> = (0..m)
1894 .map(|j| {
1895 let norm_sq: f64 = q.iter().map(|qk| qk[j].powi(2)).sum();
1896 norm_sq.sqrt()
1897 })
1898 .collect();
1899
1900 let mut result = Vec::with_capacity(d);
1902 for k in 0..d {
1903 let integrand: Vec<f64> = (0..m).map(|j| q[k][j] * norms[j]).collect();
1904 let integral = cumulative_trapz(&integrand, argvals);
1905 let curve: Vec<f64> = integral.iter().map(|&v| f0[k] + v).collect();
1906 result.push(curve);
1907 }
1908
1909 result
1910}
1911
1912fn dp_alignment_core_nd(
1917 q1: &[Vec<f64>],
1918 q2: &[Vec<f64>],
1919 argvals: &[f64],
1920 lambda: f64,
1921) -> Vec<f64> {
1922 let d = q1.len();
1923 let m = argvals.len();
1924 if m < 2 || d == 0 {
1925 return argvals.to_vec();
1926 }
1927
1928 if d == 1 {
1930 return dp_alignment_core(&q1[0], &q2[0], argvals, lambda);
1931 }
1932
1933 let q1n: Vec<Vec<f64>> = q1
1935 .iter()
1936 .map(|qk| {
1937 let norm = qk.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1938 qk.iter().map(|&v| v / norm).collect()
1939 })
1940 .collect();
1941 let q2n: Vec<Vec<f64>> = q2
1942 .iter()
1943 .map(|qk| {
1944 let norm = qk.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1945 qk.iter().map(|&v| v / norm).collect()
1946 })
1947 .collect();
1948
1949 let path = dp_grid_solve(m, m, |sr, sc, tr, tc| {
1950 let w: f64 = (0..d)
1951 .map(|k| dp_edge_weight(&q1n[k], &q2n[k], argvals, sc, tc, sr, tr))
1952 .sum();
1953 w + dp_lambda_penalty(argvals, sc, tc, sr, tr, lambda)
1954 });
1955
1956 dp_path_to_gamma(&path, argvals)
1957}
1958
1959pub fn elastic_align_pair_nd(
1970 f1: &FdCurveSet,
1971 f2: &FdCurveSet,
1972 argvals: &[f64],
1973 lambda: f64,
1974) -> AlignmentResultNd {
1975 let d = f1.ndim();
1976 let m = f1.npoints();
1977
1978 let q1_set = srsf_transform_nd(f1, argvals);
1980 let q2_set = srsf_transform_nd(f2, argvals);
1981
1982 let q1: Vec<Vec<f64>> = q1_set.dims.iter().map(|dm| dm.row(0)).collect();
1984 let q2: Vec<Vec<f64>> = q2_set.dims.iter().map(|dm| dm.row(0)).collect();
1985
1986 let gamma = dp_alignment_core_nd(&q1, &q2, argvals, lambda);
1988
1989 let f_aligned: Vec<Vec<f64>> = f2
1991 .dims
1992 .iter()
1993 .map(|dm| {
1994 let row = dm.row(0);
1995 reparameterize_curve(&row, argvals, &gamma)
1996 })
1997 .collect();
1998
1999 let f_aligned_set = {
2001 let dims: Vec<FdMatrix> = f_aligned
2002 .iter()
2003 .map(|fa| FdMatrix::from_slice(fa, 1, m).unwrap())
2004 .collect();
2005 FdCurveSet { dims }
2006 };
2007 let q_aligned = srsf_transform_nd(&f_aligned_set, argvals);
2008 let weights = simpsons_weights(argvals);
2009
2010 let mut dist_sq = 0.0;
2011 for k in 0..d {
2012 let q1k = q1_set.dims[k].row(0);
2013 let qak = q_aligned.dims[k].row(0);
2014 let d_k = l2_distance(&q1k, &qak, &weights);
2015 dist_sq += d_k * d_k;
2016 }
2017
2018 AlignmentResultNd {
2019 gamma,
2020 f_aligned,
2021 distance: dist_sq.sqrt(),
2022 }
2023}
2024
2025pub fn elastic_distance_nd(f1: &FdCurveSet, f2: &FdCurveSet, argvals: &[f64], lambda: f64) -> f64 {
2029 elastic_align_pair_nd(f1, f2, argvals, lambda).distance
2030}
2031
2032#[cfg(test)]
2035mod tests {
2036 use super::*;
2037 use crate::helpers::trapz;
2038 use crate::simulation::{sim_fundata, EFunType, EValType};
2039 use crate::warping::inner_product_l2;
2040
2041 fn uniform_grid(m: usize) -> Vec<f64> {
2042 (0..m).map(|i| i as f64 / (m - 1) as f64).collect()
2043 }
2044
2045 fn make_test_data(n: usize, m: usize, seed: u64) -> FdMatrix {
2046 let t = uniform_grid(m);
2047 sim_fundata(
2048 n,
2049 &t,
2050 3,
2051 EFunType::Fourier,
2052 EValType::Exponential,
2053 Some(seed),
2054 )
2055 }
2056
2057 #[test]
2060 fn test_cumulative_trapz_constant() {
2061 let x = uniform_grid(50);
2063 let y = vec![1.0; 50];
2064 let result = cumulative_trapz(&y, &x);
2065 assert!((result[0]).abs() < 1e-15, "cumulative_trapz(0) should be 0");
2066 for j in 1..50 {
2067 assert!(
2068 (result[j] - x[j]).abs() < 1e-12,
2069 "∫₀^{:.3} 1 dt should be {:.3}, got {:.3}",
2070 x[j],
2071 x[j],
2072 result[j]
2073 );
2074 }
2075 }
2076
2077 #[test]
2078 fn test_cumulative_trapz_linear() {
2079 let m = 100;
2081 let x = uniform_grid(m);
2082 let y: Vec<f64> = x.clone();
2083 let result = cumulative_trapz(&y, &x);
2084 for j in 1..m {
2085 let expected = x[j] * x[j] / 2.0;
2086 assert!(
2087 (result[j] - expected).abs() < 1e-4,
2088 "∫₀^{:.3} s ds: expected {expected:.6}, got {:.6}",
2089 x[j],
2090 result[j]
2091 );
2092 }
2093 }
2094
2095 #[test]
2098 fn test_normalize_warp_fixes_boundaries() {
2099 let t = uniform_grid(10);
2100 let mut gamma = vec![0.1; 10]; normalize_warp(&mut gamma, &t);
2102 assert_eq!(gamma[0], t[0]);
2103 assert_eq!(gamma[9], t[9]);
2104 }
2105
2106 #[test]
2107 fn test_normalize_warp_enforces_monotonicity() {
2108 let t = uniform_grid(5);
2109 let mut gamma = vec![0.0, 0.5, 0.3, 0.8, 1.0]; normalize_warp(&mut gamma, &t);
2111 for j in 1..5 {
2112 assert!(
2113 gamma[j] >= gamma[j - 1],
2114 "gamma should be monotone after normalization at j={j}"
2115 );
2116 }
2117 }
2118
2119 #[test]
2120 fn test_normalize_warp_identity_unchanged() {
2121 let t = uniform_grid(20);
2122 let mut gamma = t.clone();
2123 normalize_warp(&mut gamma, &t);
2124 for j in 0..20 {
2125 assert!(
2126 (gamma[j] - t[j]).abs() < 1e-15,
2127 "Identity warp should be unchanged"
2128 );
2129 }
2130 }
2131
2132 #[test]
2135 fn test_linear_interp_at_nodes() {
2136 let x = vec![0.0, 1.0, 2.0, 3.0];
2137 let y = vec![0.0, 2.0, 4.0, 6.0];
2138 for i in 0..x.len() {
2139 assert!((linear_interp(&x, &y, x[i]) - y[i]).abs() < 1e-12);
2140 }
2141 }
2142
2143 #[test]
2144 fn test_linear_interp_midpoints() {
2145 let x = vec![0.0, 1.0, 2.0];
2146 let y = vec![0.0, 2.0, 4.0];
2147 assert!((linear_interp(&x, &y, 0.5) - 1.0).abs() < 1e-12);
2148 assert!((linear_interp(&x, &y, 1.5) - 3.0).abs() < 1e-12);
2149 }
2150
2151 #[test]
2152 fn test_linear_interp_clamp() {
2153 let x = vec![0.0, 1.0, 2.0];
2154 let y = vec![1.0, 3.0, 5.0];
2155 assert!((linear_interp(&x, &y, -1.0) - 1.0).abs() < 1e-12);
2156 assert!((linear_interp(&x, &y, 3.0) - 5.0).abs() < 1e-12);
2157 }
2158
2159 #[test]
2160 fn test_linear_interp_nonuniform_grid() {
2161 let x = vec![0.0, 0.1, 0.5, 1.0];
2162 let y = vec![0.0, 1.0, 5.0, 10.0];
2163 let val = linear_interp(&x, &y, 0.3);
2165 let expected = 1.0 + 10.0 * (0.3 - 0.1);
2166 assert!(
2167 (val - expected).abs() < 1e-12,
2168 "Non-uniform interp: expected {expected}, got {val}"
2169 );
2170 }
2171
2172 #[test]
2173 fn test_linear_interp_two_points() {
2174 let x = vec![0.0, 1.0];
2175 let y = vec![3.0, 7.0];
2176 assert!((linear_interp(&x, &y, 0.25) - 4.0).abs() < 1e-12);
2177 assert!((linear_interp(&x, &y, 0.75) - 6.0).abs() < 1e-12);
2178 }
2179
2180 #[test]
2183 fn test_srsf_transform_linear() {
2184 let m = 50;
2186 let t = uniform_grid(m);
2187 let f: Vec<f64> = t.iter().map(|&ti| 2.0 * ti).collect();
2188 let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
2189
2190 let q_mat = srsf_transform(&mat, &t);
2191 let q: Vec<f64> = q_mat.row(0);
2192
2193 let expected = 2.0_f64.sqrt();
2194 for j in 2..(m - 2) {
2196 assert!(
2197 (q[j] - expected).abs() < 0.1,
2198 "q[{j}] = {}, expected ~{expected}",
2199 q[j]
2200 );
2201 }
2202 }
2203
2204 #[test]
2205 fn test_srsf_transform_preserves_shape() {
2206 let data = make_test_data(10, 50, 42);
2207 let t = uniform_grid(50);
2208 let q = srsf_transform(&data, &t);
2209 assert_eq!(q.shape(), data.shape());
2210 }
2211
2212 #[test]
2213 fn test_srsf_transform_constant_is_zero() {
2214 let m = 30;
2216 let t = uniform_grid(m);
2217 let f = vec![5.0; m];
2218 let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
2219 let q_mat = srsf_transform(&mat, &t);
2220 let q: Vec<f64> = q_mat.row(0);
2221
2222 for j in 0..m {
2223 assert!(
2224 q[j].abs() < 1e-10,
2225 "SRSF of constant should be 0, got q[{j}] = {}",
2226 q[j]
2227 );
2228 }
2229 }
2230
2231 #[test]
2232 fn test_srsf_transform_negative_slope() {
2233 let m = 50;
2235 let t = uniform_grid(m);
2236 let f: Vec<f64> = t.iter().map(|&ti| -3.0 * ti).collect();
2237 let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
2238
2239 let q_mat = srsf_transform(&mat, &t);
2240 let q: Vec<f64> = q_mat.row(0);
2241
2242 let expected = -(3.0_f64.sqrt());
2243 for j in 2..(m - 2) {
2244 assert!(
2245 (q[j] - expected).abs() < 0.15,
2246 "q[{j}] = {}, expected ~{expected}",
2247 q[j]
2248 );
2249 }
2250 }
2251
2252 #[test]
2253 fn test_srsf_transform_empty_input() {
2254 let data = FdMatrix::zeros(0, 0);
2255 let t: Vec<f64> = vec![];
2256 let q = srsf_transform(&data, &t);
2257 assert_eq!(q.shape(), (0, 0));
2258 }
2259
2260 #[test]
2261 fn test_srsf_transform_multiple_curves() {
2262 let m = 40;
2263 let t = uniform_grid(m);
2264 let data = make_test_data(5, m, 42);
2265
2266 let q = srsf_transform(&data, &t);
2267 assert_eq!(q.shape(), (5, m));
2268
2269 for i in 0..5 {
2271 for j in 0..m {
2272 assert!(q[(i, j)].is_finite(), "SRSF should be finite at ({i},{j})");
2273 }
2274 }
2275 }
2276
2277 #[test]
2280 fn test_srsf_round_trip() {
2281 let m = 100;
2282 let t = uniform_grid(m);
2283 let f: Vec<f64> = t
2285 .iter()
2286 .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin() + ti)
2287 .collect();
2288
2289 let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
2290 let q_mat = srsf_transform(&mat, &t);
2291 let q: Vec<f64> = q_mat.row(0);
2292
2293 let f_recon = srsf_inverse(&q, &t, f[0]);
2294
2295 let max_err: f64 = f[5..(m - 5)]
2297 .iter()
2298 .zip(f_recon[5..(m - 5)].iter())
2299 .map(|(a, b)| (a - b).abs())
2300 .fold(0.0_f64, f64::max);
2301
2302 assert!(
2303 max_err < 0.15,
2304 "Round-trip error too large: max_err = {max_err}"
2305 );
2306 }
2307
2308 #[test]
2309 fn test_srsf_inverse_empty() {
2310 let q: Vec<f64> = vec![];
2311 let t: Vec<f64> = vec![];
2312 let result = srsf_inverse(&q, &t, 0.0);
2313 assert!(result.is_empty());
2314 }
2315
2316 #[test]
2317 fn test_srsf_inverse_preserves_initial_value() {
2318 let m = 50;
2319 let t = uniform_grid(m);
2320 let q = vec![1.0; m]; let f0 = 3.15;
2322 let f = srsf_inverse(&q, &t, f0);
2323 assert!((f[0] - f0).abs() < 1e-12, "srsf_inverse should start at f0");
2324 }
2325
2326 #[test]
2327 fn test_srsf_round_trip_multiple_curves() {
2328 let m = 80;
2329 let t = uniform_grid(m);
2330 let data = make_test_data(5, m, 99);
2331
2332 let q_mat = srsf_transform(&data, &t);
2333
2334 for i in 0..5 {
2335 let fi = data.row(i);
2336 let qi = q_mat.row(i);
2337 let f_recon = srsf_inverse(&qi, &t, fi[0]);
2338 let max_err: f64 = fi[5..(m - 5)]
2339 .iter()
2340 .zip(f_recon[5..(m - 5)].iter())
2341 .map(|(a, b)| (a - b).abs())
2342 .fold(0.0_f64, f64::max);
2343 assert!(max_err < 0.3, "Round-trip curve {i}: max_err = {max_err}");
2344 }
2345 }
2346
2347 #[test]
2350 fn test_reparameterize_identity_warp() {
2351 let m = 50;
2352 let t = uniform_grid(m);
2353 let f: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
2354
2355 let result = reparameterize_curve(&f, &t, &t);
2357 for j in 0..m {
2358 assert!(
2359 (result[j] - f[j]).abs() < 1e-12,
2360 "Identity warp should return original at j={j}"
2361 );
2362 }
2363 }
2364
2365 #[test]
2366 fn test_reparameterize_linear_warp() {
2367 let m = 50;
2368 let t = uniform_grid(m);
2369 let f: Vec<f64> = t.clone();
2371 let gamma: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
2372
2373 let result = reparameterize_curve(&f, &t, &gamma);
2374
2375 for j in 0..m {
2377 assert!(
2378 (result[j] - gamma[j]).abs() < 1e-10,
2379 "f(gamma(t)) should be gamma(t) for f(t)=t at j={j}"
2380 );
2381 }
2382 }
2383
2384 #[test]
2385 fn test_reparameterize_sine_with_quadratic_warp() {
2386 let m = 100;
2387 let t = uniform_grid(m);
2388 let f: Vec<f64> = t
2389 .iter()
2390 .map(|&ti| (std::f64::consts::PI * ti).sin())
2391 .collect();
2392 let gamma: Vec<f64> = t.iter().map(|&ti| ti * ti).collect(); let result = reparameterize_curve(&f, &t, &gamma);
2395
2396 for j in 0..m {
2398 let expected = (std::f64::consts::PI * gamma[j]).sin();
2399 assert!(
2400 (result[j] - expected).abs() < 0.05,
2401 "sin(π γ(t)) at j={j}: expected {expected:.4}, got {:.4}",
2402 result[j]
2403 );
2404 }
2405 }
2406
2407 #[test]
2408 fn test_reparameterize_preserves_length() {
2409 let m = 50;
2410 let t = uniform_grid(m);
2411 let f = vec![1.0; m];
2412 let gamma: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
2413
2414 let result = reparameterize_curve(&f, &t, &gamma);
2415 assert_eq!(result.len(), m);
2416 }
2417
2418 #[test]
2421 fn test_compose_warps_identity() {
2422 let m = 50;
2423 let t = uniform_grid(m);
2424 let gamma: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
2426
2427 let result = compose_warps(&t, &gamma, &t);
2429 for j in 0..m {
2430 assert!(
2431 (result[j] - gamma[j]).abs() < 1e-10,
2432 "id ∘ γ should be γ at j={j}"
2433 );
2434 }
2435
2436 let result2 = compose_warps(&gamma, &t, &t);
2438 for j in 0..m {
2439 assert!(
2440 (result2[j] - gamma[j]).abs() < 1e-10,
2441 "γ ∘ id should be γ at j={j}"
2442 );
2443 }
2444 }
2445
2446 #[test]
2447 fn test_compose_warps_associativity() {
2448 let m = 50;
2450 let t = uniform_grid(m);
2451 let g1: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
2452 let g2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
2453 let g3: Vec<f64> = t.iter().map(|&ti| 0.5 * ti + 0.5 * ti * ti).collect();
2454
2455 let g12 = compose_warps(&g1, &g2, &t);
2456 let left = compose_warps(&g12, &g3, &t); let g23 = compose_warps(&g2, &g3, &t);
2459 let right = compose_warps(&g1, &g23, &t); for j in 0..m {
2462 assert!(
2463 (left[j] - right[j]).abs() < 0.05,
2464 "Composition should be roughly associative at j={j}: left={:.4}, right={:.4}",
2465 left[j],
2466 right[j]
2467 );
2468 }
2469 }
2470
2471 #[test]
2472 fn test_compose_warps_preserves_domain() {
2473 let m = 50;
2474 let t = uniform_grid(m);
2475 let g1: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
2476 let g2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
2477
2478 let composed = compose_warps(&g1, &g2, &t);
2479 assert!(
2480 (composed[0] - t[0]).abs() < 1e-10,
2481 "Composed warp should start at domain start"
2482 );
2483 assert!(
2484 (composed[m - 1] - t[m - 1]).abs() < 1e-10,
2485 "Composed warp should end at domain end"
2486 );
2487 }
2488
2489 #[test]
2492 fn test_align_identical_curves() {
2493 let m = 50;
2494 let t = uniform_grid(m);
2495 let f: Vec<f64> = t
2496 .iter()
2497 .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
2498 .collect();
2499
2500 let result = elastic_align_pair(&f, &f, &t, 0.0);
2501
2502 assert!(
2504 result.distance < 0.1,
2505 "Distance between identical curves should be near 0, got {}",
2506 result.distance
2507 );
2508
2509 for j in 0..m {
2511 assert!(
2512 (result.gamma[j] - t[j]).abs() < 0.1,
2513 "Warp should be near identity at j={j}: gamma={}, t={}",
2514 result.gamma[j],
2515 t[j]
2516 );
2517 }
2518 }
2519
2520 #[test]
2521 fn test_align_pair_valid_output() {
2522 let data = make_test_data(2, 50, 42);
2523 let t = uniform_grid(50);
2524 let f1 = data.row(0);
2525 let f2 = data.row(1);
2526
2527 let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2528
2529 assert_eq!(result.gamma.len(), 50);
2530 assert_eq!(result.f_aligned.len(), 50);
2531 assert!(result.distance >= 0.0);
2532
2533 for j in 1..50 {
2535 assert!(
2536 result.gamma[j] >= result.gamma[j - 1],
2537 "Warp should be monotone at j={j}"
2538 );
2539 }
2540 }
2541
2542 #[test]
2543 fn test_align_pair_warp_boundaries() {
2544 let data = make_test_data(2, 50, 42);
2545 let t = uniform_grid(50);
2546 let f1 = data.row(0);
2547 let f2 = data.row(1);
2548
2549 let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2550 assert!(
2551 (result.gamma[0] - t[0]).abs() < 1e-12,
2552 "Warp should start at domain start"
2553 );
2554 assert!(
2555 (result.gamma[49] - t[49]).abs() < 1e-12,
2556 "Warp should end at domain end"
2557 );
2558 }
2559
2560 #[test]
2561 fn test_align_shifted_sine() {
2562 let m = 80;
2564 let t = uniform_grid(m);
2565 let f1: Vec<f64> = t
2566 .iter()
2567 .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
2568 .collect();
2569 let f2: Vec<f64> = t
2570 .iter()
2571 .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
2572 .collect();
2573
2574 let weights = simpsons_weights(&t);
2575 let l2_before = l2_distance(&f1, &f2, &weights);
2576 let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2577 let l2_after = l2_distance(&f1, &result.f_aligned, &weights);
2578
2579 assert!(
2580 l2_after < l2_before + 0.01,
2581 "Alignment should not increase L2 distance: before={l2_before:.4}, after={l2_after:.4}"
2582 );
2583 }
2584
2585 #[test]
2586 fn test_align_pair_aligned_curve_is_finite() {
2587 let data = make_test_data(2, 50, 77);
2588 let t = uniform_grid(50);
2589 let f1 = data.row(0);
2590 let f2 = data.row(1);
2591
2592 let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2593 for j in 0..50 {
2594 assert!(
2595 result.f_aligned[j].is_finite(),
2596 "Aligned curve should be finite at j={j}"
2597 );
2598 }
2599 }
2600
2601 #[test]
2602 fn test_align_pair_minimum_grid() {
2603 let t = vec![0.0, 1.0];
2605 let f1 = vec![0.0, 1.0];
2606 let f2 = vec![0.0, 2.0];
2607 let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2608 assert_eq!(result.gamma.len(), 2);
2609 assert_eq!(result.f_aligned.len(), 2);
2610 assert!(result.distance >= 0.0);
2611 }
2612
2613 #[test]
2616 fn test_elastic_distance_symmetric() {
2617 let data = make_test_data(3, 50, 42);
2618 let t = uniform_grid(50);
2619 let f1 = data.row(0);
2620 let f2 = data.row(1);
2621
2622 let d12 = elastic_distance(&f1, &f2, &t, 0.0);
2623 let d21 = elastic_distance(&f2, &f1, &t, 0.0);
2624
2625 assert!(
2627 (d12 - d21).abs() < d12.max(d21) * 0.3 + 0.01,
2628 "Elastic distance should be roughly symmetric: d12={d12}, d21={d21}"
2629 );
2630 }
2631
2632 #[test]
2633 fn test_elastic_distance_nonneg() {
2634 let data = make_test_data(3, 50, 42);
2635 let t = uniform_grid(50);
2636
2637 for i in 0..3 {
2638 for j in 0..3 {
2639 let fi = data.row(i);
2640 let fj = data.row(j);
2641 let d = elastic_distance(&fi, &fj, &t, 0.0);
2642 assert!(d >= 0.0, "Elastic distance should be non-negative");
2643 }
2644 }
2645 }
2646
2647 #[test]
2648 fn test_elastic_distance_self_near_zero() {
2649 let data = make_test_data(3, 50, 42);
2650 let t = uniform_grid(50);
2651
2652 for i in 0..3 {
2653 let fi = data.row(i);
2654 let d = elastic_distance(&fi, &fi, &t, 0.0);
2655 assert!(
2656 d < 0.1,
2657 "Self-distance should be near zero, got {d} for curve {i}"
2658 );
2659 }
2660 }
2661
2662 #[test]
2663 fn test_elastic_distance_triangle_inequality() {
2664 let data = make_test_data(3, 50, 42);
2665 let t = uniform_grid(50);
2666 let f0 = data.row(0);
2667 let f1 = data.row(1);
2668 let f2 = data.row(2);
2669
2670 let d01 = elastic_distance(&f0, &f1, &t, 0.0);
2671 let d12 = elastic_distance(&f1, &f2, &t, 0.0);
2672 let d02 = elastic_distance(&f0, &f2, &t, 0.0);
2673
2674 let slack = 0.5;
2676 assert!(
2677 d02 <= d01 + d12 + slack,
2678 "Triangle inequality (relaxed): d02={d02:.4} > d01={d01:.4} + d12={d12:.4} + {slack}"
2679 );
2680 }
2681
2682 #[test]
2683 fn test_elastic_distance_different_shapes_nonzero() {
2684 let m = 50;
2685 let t = uniform_grid(m);
2686 let f1: Vec<f64> = t.to_vec(); let f2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect(); let d = elastic_distance(&f1, &f2, &t, 0.0);
2690 assert!(
2691 d > 0.01,
2692 "Distance between different shapes should be > 0, got {d}"
2693 );
2694 }
2695
2696 #[test]
2699 fn test_self_distance_matrix_symmetric() {
2700 let data = make_test_data(5, 30, 42);
2701 let t = uniform_grid(30);
2702
2703 let dm = elastic_self_distance_matrix(&data, &t, 0.0);
2704 let n = dm.nrows();
2705
2706 assert_eq!(dm.shape(), (5, 5));
2707
2708 for i in 0..n {
2710 assert!(dm[(i, i)].abs() < 1e-12, "Diagonal should be zero");
2711 }
2712
2713 for i in 0..n {
2715 for j in (i + 1)..n {
2716 assert!(
2717 (dm[(i, j)] - dm[(j, i)]).abs() < 1e-12,
2718 "Matrix should be symmetric at ({i},{j})"
2719 );
2720 }
2721 }
2722 }
2723
2724 #[test]
2725 fn test_self_distance_matrix_nonneg() {
2726 let data = make_test_data(4, 30, 42);
2727 let t = uniform_grid(30);
2728 let dm = elastic_self_distance_matrix(&data, &t, 0.0);
2729
2730 for i in 0..4 {
2731 for j in 0..4 {
2732 assert!(
2733 dm[(i, j)] >= 0.0,
2734 "Distance matrix entries should be non-negative at ({i},{j})"
2735 );
2736 }
2737 }
2738 }
2739
2740 #[test]
2741 fn test_self_distance_matrix_single_curve() {
2742 let data = make_test_data(1, 30, 42);
2743 let t = uniform_grid(30);
2744 let dm = elastic_self_distance_matrix(&data, &t, 0.0);
2745 assert_eq!(dm.shape(), (1, 1));
2746 assert!(dm[(0, 0)].abs() < 1e-12);
2747 }
2748
2749 #[test]
2750 fn test_self_distance_matrix_consistent_with_pairwise() {
2751 let data = make_test_data(4, 30, 42);
2752 let t = uniform_grid(30);
2753
2754 let dm = elastic_self_distance_matrix(&data, &t, 0.0);
2755
2756 for i in 0..4 {
2758 for j in (i + 1)..4 {
2759 let fi = data.row(i);
2760 let fj = data.row(j);
2761 let d_direct = elastic_distance(&fi, &fj, &t, 0.0);
2762 assert!(
2763 (dm[(i, j)] - d_direct).abs() < 1e-10,
2764 "Matrix entry ({i},{j})={:.6} should match pairwise {d_direct:.6}",
2765 dm[(i, j)]
2766 );
2767 }
2768 }
2769 }
2770
2771 #[test]
2774 fn test_karcher_mean_identical_curves() {
2775 let m = 50;
2776 let t = uniform_grid(m);
2777 let f: Vec<f64> = t
2778 .iter()
2779 .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
2780 .collect();
2781
2782 let mut data = FdMatrix::zeros(5, m);
2784 for i in 0..5 {
2785 for j in 0..m {
2786 data[(i, j)] = f[j];
2787 }
2788 }
2789
2790 let result = karcher_mean(&data, &t, 10, 1e-4, 0.0);
2791
2792 assert_eq!(result.mean.len(), m);
2793 assert!(result.n_iter <= 10);
2794 }
2795
2796 #[test]
2797 fn test_karcher_mean_output_shape() {
2798 let data = make_test_data(15, 50, 42);
2799 let t = uniform_grid(50);
2800
2801 let result = karcher_mean(&data, &t, 5, 1e-3, 0.0);
2802
2803 assert_eq!(result.mean.len(), 50);
2804 assert_eq!(result.mean_srsf.len(), 50);
2805 assert_eq!(result.gammas.shape(), (15, 50));
2806 assert_eq!(result.aligned_data.shape(), (15, 50));
2807 assert!(result.n_iter <= 5);
2808 }
2809
2810 #[test]
2811 fn test_karcher_mean_warps_are_valid() {
2812 let data = make_test_data(10, 40, 42);
2813 let t = uniform_grid(40);
2814
2815 let result = karcher_mean(&data, &t, 5, 1e-3, 0.0);
2816
2817 for i in 0..10 {
2818 assert!(
2820 (result.gammas[(i, 0)] - t[0]).abs() < 1e-10,
2821 "Warp {i} should start at domain start"
2822 );
2823 assert!(
2824 (result.gammas[(i, 39)] - t[39]).abs() < 1e-10,
2825 "Warp {i} should end at domain end"
2826 );
2827 for j in 1..40 {
2829 assert!(
2830 result.gammas[(i, j)] >= result.gammas[(i, j - 1)],
2831 "Warp {i} should be monotone at j={j}"
2832 );
2833 }
2834 }
2835 }
2836
2837 #[test]
2838 fn test_karcher_mean_aligned_data_is_finite() {
2839 let data = make_test_data(8, 40, 42);
2840 let t = uniform_grid(40);
2841 let result = karcher_mean(&data, &t, 5, 1e-3, 0.0);
2842
2843 for i in 0..8 {
2844 for j in 0..40 {
2845 assert!(
2846 result.aligned_data[(i, j)].is_finite(),
2847 "Aligned data should be finite at ({i},{j})"
2848 );
2849 }
2850 }
2851 }
2852
2853 #[test]
2854 fn test_karcher_mean_srsf_is_finite() {
2855 let data = make_test_data(8, 40, 42);
2856 let t = uniform_grid(40);
2857 let result = karcher_mean(&data, &t, 5, 1e-3, 0.0);
2858
2859 for j in 0..40 {
2860 assert!(
2861 result.mean_srsf[j].is_finite(),
2862 "Mean SRSF should be finite at j={j}"
2863 );
2864 assert!(
2865 result.mean[j].is_finite(),
2866 "Mean curve should be finite at j={j}"
2867 );
2868 }
2869 }
2870
2871 #[test]
2872 fn test_karcher_mean_single_iteration() {
2873 let data = make_test_data(10, 40, 42);
2874 let t = uniform_grid(40);
2875 let result = karcher_mean(&data, &t, 1, 1e-10, 0.0);
2876
2877 assert_eq!(result.n_iter, 1);
2878 assert_eq!(result.mean.len(), 40);
2879 for j in 0..40 {
2881 assert!(result.mean[j].is_finite());
2882 }
2883 }
2884
#[test]
fn test_karcher_mean_convergence_not_premature() {
    // Ten sines with phase shifts in steps of 0.05: real phase variability that the
    // Karcher iteration has to work through.
    let n = 10;
    let m = 50;
    let t = uniform_grid(m);

    // Column-major layout: element (i, j) lives at flat index i + j * n.
    let col_major: Vec<f64> = (0..n * m)
        .map(|idx| {
            let (i, j) = (idx % n, idx / n);
            let shift = (i as f64 - 5.0) * 0.05;
            (2.0 * std::f64::consts::PI * (t[j] - shift)).sin()
        })
        .collect();
    let data = FdMatrix::from_column_major(col_major, n, m).unwrap();

    // A practically unreachable tolerance must not be reported as met after 1-2 steps.
    let strict = karcher_mean(&data, &t, 20, 1e-15, 0.0);
    assert!(
        strict.n_iter > 2,
        "With tol=1e-15 the algorithm should iterate beyond 2, got n_iter={}",
        strict.n_iter
    );

    // A loose tolerance, in contrast, should be reached within the 20-iteration budget.
    let loose = karcher_mean(&data, &t, 20, 1e-2, 0.0);
    assert!(
        loose.converged,
        "With tol=1e-2 the algorithm should converge"
    );
}
2921
#[test]
fn test_align_to_target_valid() {
    let (n, m) = (10, 40);
    let data = make_test_data(n, m, 42);
    let t = uniform_grid(m);
    // The first curve serves as the fixed template.
    let target = data.row(0);

    let result = align_to_target(&data, &target, &t, 0.0);

    // One warp and one aligned curve per input row, one distance per curve.
    assert_eq!(result.gammas.shape(), (n, m));
    assert_eq!(result.aligned_data.shape(), (n, m));
    assert_eq!(result.distances.len(), n);

    // Distances are never negative.
    result.distances.iter().for_each(|&d| assert!(d >= 0.0));
}
2941
#[test]
fn test_align_to_target_self_near_zero() {
    let m = 40;
    let data = make_test_data(5, m, 42);
    let t = uniform_grid(m);
    let target = data.row(0);

    let result = align_to_target(&data, &target, &t, 0.0);

    // Row 0 is the target itself, so its distance should be only numerical residue.
    let self_dist = result.distances[0];
    assert!(
        self_dist < 0.1,
        "Self-alignment distance should be near zero, got {}",
        self_dist
    );
}
2957
#[test]
fn test_align_to_target_warps_are_monotone() {
    let (n, m) = (8, 40);
    let data = make_test_data(n, m, 42);
    let t = uniform_grid(m);
    let target = data.row(0);
    let result = align_to_target(&data, &target, &t, 0.0);

    // Every warping function must be non-decreasing along the grid.
    for i in 0..n {
        for j in 1..m {
            let (prev, cur) = (result.gammas[(i, j - 1)], result.gammas[(i, j)]);
            assert!(
                cur >= prev,
                "Warp for curve {i} should be monotone at j={j}"
            );
        }
    }
}
2974
#[test]
fn test_align_to_target_aligned_data_finite() {
    let (n, m) = (6, 40);
    let data = make_test_data(n, m, 42);
    let t = uniform_grid(m);
    let target = data.row(0);
    let result = align_to_target(&data, &target, &t, 0.0);

    // No NaN/inf may appear anywhere in the aligned curves.
    for i in 0..n {
        for j in 0..m {
            let value = result.aligned_data[(i, j)];
            assert!(
                value.is_finite(),
                "Aligned data should be finite at ({i},{j})"
            );
        }
    }
}
2991
#[test]
fn test_cross_distance_matrix_shape() {
    let m = 30;
    let t = uniform_grid(m);
    let data1 = make_test_data(3, m, 42);
    let data2 = make_test_data(4, m, 99);

    // Rows index curves of `data1`, columns index curves of `data2`.
    let dm = elastic_cross_distance_matrix(&data1, &data2, &t, 0.0);
    assert_eq!(dm.shape(), (3, 4));

    // Every entry is a distance and therefore non-negative.
    let (rows, cols) = dm.shape();
    for i in 0..rows {
        for j in 0..cols {
            assert!(dm[(i, j)] >= 0.0);
        }
    }
}
3010
#[test]
fn test_cross_distance_matrix_self_matches_self_matrix() {
    // Cross-distances of a set against itself. Only the diagonal is checked here:
    // each curve compared with itself must give a (near) zero distance.
    let n = 4;
    let data = make_test_data(n, 30, 42);
    let t = uniform_grid(30);

    let cross = elastic_cross_distance_matrix(&data, &data, &t, 0.0);
    for i in 0..n {
        let diag = cross[(i, i)];
        assert!(
            diag < 0.1,
            "Cross distance (self) diagonal should be near zero: got {}",
            diag
        );
    }
}
3026
#[test]
fn test_cross_distance_matrix_consistent_with_pairwise() {
    let m = 30;
    let t = uniform_grid(m);
    let data1 = make_test_data(3, m, 42);
    let data2 = make_test_data(2, m, 99);

    let dm = elastic_cross_distance_matrix(&data1, &data2, &t, 0.0);

    // Each matrix entry must equal the standalone pairwise elastic distance.
    for i in 0..3 {
        let fi = data1.row(i);
        for j in 0..2 {
            let fj = data2.row(j);
            let d_direct = elastic_distance(&fi, &fj, &t, 0.0);
            let d_matrix = dm[(i, j)];
            assert!(
                (d_matrix - d_direct).abs() < 1e-10,
                "Cross matrix ({i},{j})={:.6} should match pairwise {d_direct:.6}",
                d_matrix
            );
        }
    }
}
3048
#[test]
fn test_align_srsf_pair_identity() {
    let m = 50;
    let t = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;
    let f: Vec<f64> = t.iter().map(|&ti| (two_pi * ti).sin()).collect();
    let q = srsf_single(&f, &t);

    // Aligning an SRSF to itself should give (approximately) the identity warp...
    let (gamma, q_aligned) = align_srsf_pair(&q, &q, &t, 0.0);
    for j in 0..m {
        let deviation = (gamma[j] - t[j]).abs();
        assert!(
            deviation < 0.15,
            "Self-SRSF alignment warp should be near identity at j={j}"
        );
    }

    // ...and leave the SRSF essentially unchanged (Simpson-weighted L2 distance).
    let weights = simpsons_weights(&t);
    let dist = l2_distance(&q, &q_aligned, &weights);
    assert!(
        dist < 0.5,
        "Self-aligned SRSF distance should be small, got {dist}"
    );
}
3079
#[test]
fn test_srsf_single_matches_matrix_version() {
    let m = 50;
    let t = uniform_grid(m);
    // A smooth polynomial curve f(t) = t² + t.
    let f: Vec<f64> = t.iter().map(|&ti| ti * ti + ti).collect();

    // Single-curve SRSF...
    let q_single = srsf_single(&f, &t);
    // ...must agree with row 0 of the matrix SRSF of a 1×m matrix holding the same curve.
    let q_from_mat = srsf_transform(&FdMatrix::from_slice(&f, 1, m).unwrap(), &t).row(0);

    for j in 0..m {
        assert!(
            (q_single[j] - q_from_mat[j]).abs() < 1e-12,
            "srsf_single should match srsf_transform at j={j}"
        );
    }
}
3101
#[test]
fn test_gcd_basic() {
    // (a, b, expected gcd), including the zero-argument cases gcd(a, 0) = a, gcd(0, b) = b.
    let cases = [(1, 1, 1), (6, 4, 2), (7, 5, 1), (12, 8, 4), (7, 0, 7), (0, 5, 5)];
    for (a, b, expected) in cases {
        assert_eq!(gcd(a, b), expected);
    }
}
3113
#[test]
fn test_coprime_nbhd_count() {
    // {1}² contains a single coprime pair, (1, 1).
    assert_eq!(generate_coprime_nbhd(1).len(), 1);
    // Coprime pairs in {1..=7}²: 2·(φ(1) + … + φ(7)) − 1 = 2·18 − 1 = 35,
    // where φ is Euler's totient and (1, 1) is counted once.
    assert_eq!(generate_coprime_nbhd(7).len(), 35);
}
3121
#[test]
fn test_coprime_nbhd_matches_const() {
    // The runtime generator must reproduce the precomputed table exactly, element by
    // element and in the same order.
    let generated = generate_coprime_nbhd(7);
    assert_eq!(generated.len(), COPRIME_NBHD_7.len());
    for (i, (pair, expected)) in generated.iter().zip(COPRIME_NBHD_7.iter()).enumerate() {
        assert_eq!(*pair, *expected, "mismatch at index {i}");
    }
}
3130
#[test]
fn test_coprime_nbhd_all_coprime() {
    // Every entry of the precomputed table is a coprime pair with both components in 1..=7.
    COPRIME_NBHD_7.iter().for_each(|&(i, j)| {
        assert_eq!(gcd(i, j), 1, "({i},{j}) should be coprime");
        assert!((1..=7).contains(&i));
        assert!((1..=7).contains(&j));
    });
}
3139
#[test]
fn test_dp_edge_weight_diagonal() {
    let t = uniform_grid(10);
    // The same constant SRSF on both sides: the edge should cost nothing.
    let (q_left, q_right) = (vec![1.0; 10], vec![1.0; 10]);
    // Unit step 0 -> 1 in both index pairs (argument meaning inferred from the call
    // shape; NOTE(review): confirm against `dp_edge_weight`).
    let w = dp_edge_weight(&q_left, &q_right, &t, 0, 1, 0, 1);
    assert!(w.abs() < 1e-12, "identical SRSFs should have zero cost");
}
3152
#[test]
fn test_dp_edge_weight_non_diagonal() {
    let t = uniform_grid(10);
    let ones = vec![1.0; 10];
    let zeros = vec![0.0; 10];
    // Non-diagonal edge: index span 0..2 in the first pair, 0..1 in the second.
    let w = dp_edge_weight(&ones, &zeros, &t, 0, 2, 0, 1);
    // Reference value. 2/9 is two steps of the 1/9 spacing a 10-point grid on [0, 1]
    // would have. NOTE(review): derivation assumed, confirm against `dp_edge_weight`.
    let expected = 2.0 / 9.0;
    assert!(
        (w - expected).abs() < 1e-10,
        "dp_edge_weight (1,2): expected {expected}, got {w}"
    );
}
3168
#[test]
fn test_dp_edge_weight_zero_span() {
    let t = uniform_grid(10);
    let (q1, q2) = (vec![1.0; 10], vec![1.0; 10]);
    // An edge with an empty span (start index == end index) must yield infinite cost,
    // whichever of the two index pairs is degenerate.
    for (a0, a1, b0, b1) in [(3, 3, 0, 1), (0, 1, 3, 3)] {
        assert_eq!(dp_edge_weight(&q1, &q2, &t, a0, a1, b0, b1), f64::INFINITY);
    }
}
3179
#[test]
fn test_alignment_improves_distance() {
    let m = 50;
    let t = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;
    // f2 is f1 advanced by a phase of 0.2.
    let f1: Vec<f64> = t.iter().map(|&x| (two_pi * x).sin()).collect();
    let f2: Vec<f64> = t.iter().map(|&x| (two_pi * (x + 0.2)).sin()).collect();

    // Baseline: L2 distance between the raw SRSFs, i.e. with no warping applied.
    let weights = simpsons_weights(&t);
    let (q1, q2) = (srsf_single(&f1, &t), srsf_single(&f2, &t));
    let unaligned_srsf_dist = l2_distance(&q1, &q2, &weights);

    // The optimised warp is expected to do no worse than no warp (tiny slack allowed).
    let result = elastic_align_pair(&f1, &f2, &t, 0.0);
    assert!(
        result.distance <= unaligned_srsf_dist + 1e-6,
        "aligned SRSF dist ({}) should be <= unaligned SRSF dist ({})",
        result.distance,
        unaligned_srsf_dist
    );
}
3211
#[test]
fn test_alignment_constant_curves() {
    // Two identical flat curves (zero derivative, hence zero SRSF).
    let m = 30;
    let t = uniform_grid(m);
    let flat = vec![5.0; m];

    let result = elastic_align_pair(&flat, &flat, &t, 0.0);
    assert!(
        result.distance < 0.01,
        "Constant curves: distance should be ~0"
    );
    assert_eq!(result.f_aligned.len(), m);
}
3228
#[test]
fn test_karcher_mean_constant_curves() {
    let (n, m) = (5, 30);
    let t = uniform_grid(m);
    // n identical flat curves at height 3.
    let data = FdMatrix::from_column_major(vec![3.0; n * m], n, m).unwrap();

    let result = karcher_mean(&data, &t, 5, 1e-4, 0.0);
    // Composing a constant with any warp leaves it unchanged, so the mean should stay near 3.
    for j in 0..m {
        let mean_j = result.mean[j];
        assert!(
            (mean_j - 3.0).abs() < 0.5,
            "Mean of constant curves should be near 3.0, got {} at j={j}",
            mean_j
        );
    }
}
3249
#[test]
fn test_nan_srsf_no_panic() {
    // A single NaN sample must not make the SRSF transform panic, and the output must
    // keep the input's 1×m shape. Whether NaN spreads into neighbouring values is not
    // checked here.
    let m = 20;
    let t = uniform_grid(m);
    let mut f = vec![1.0; m];
    f[5] = f64::NAN;
    let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
    let q = srsf_transform(&mat, &t);
    // Check both dimensions, not just the row count.
    assert_eq!(q.shape(), (1, m));
}
3261
#[test]
fn test_n1_karcher_mean() {
    // Degenerate sample of a single curve f(t) = t²: the Karcher mean must still come
    // back with the right length and finite values.
    let m = 30;
    let t = uniform_grid(m);
    let curve: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
    let single = FdMatrix::from_slice(&curve, 1, m).unwrap();

    let result = karcher_mean(&single, &t, 5, 1e-4, 0.0);
    assert_eq!(result.mean.len(), m);
    (0..m).for_each(|j| assert!(result.mean[j].is_finite()));
}
3275
#[test]
fn test_two_point_grid() {
    // Smallest possible grid: two samples. f2 has twice f1's slope.
    let t = vec![0.0, 1.0];
    let (f1, f2) = (vec![0.0, 1.0], vec![0.0, 2.0]);

    let d = elastic_distance(&f1, &f2, &t, 0.0);
    assert!(d >= 0.0);
    assert!(d.is_finite());
}
3285
#[test]
fn test_non_uniform_grid_alignment() {
    // Deliberately irregular spacing: dense near 0, sparse near 1.
    let t: Vec<f64> = vec![0.0, 0.01, 0.05, 0.2, 0.5, 1.0];
    let m = t.len();
    let f1: Vec<f64> = t.iter().map(|ti| ti.sin()).collect();
    let f2: Vec<f64> = t.iter().map(|ti| (ti + 0.1).sin()).collect();

    let result = elastic_align_pair(&f1, &f2, &t, 0.0);
    assert_eq!(result.gamma.len(), m);
    assert!(result.distance >= 0.0);
    assert!(result.distance.is_finite());
}
3298
#[test]
fn test_tsrvf_output_shape() {
    let (m, n) = (50, 10);
    let t = uniform_grid(m);
    let data = make_test_data(n, m, 42);
    let result = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);

    // Per-curve outputs are n×m matrices or length-n vectors; mean quantities have length m.
    assert_eq!(result.tangent_vectors.shape(), (n, m), "Tangent vectors should be n×m");
    assert_eq!(result.gammas.shape(), (n, m), "Gammas should be n×m");
    assert_eq!(result.srsf_norms.len(), n, "Should have n SRSF norms");
    assert_eq!(result.mean.len(), m, "Mean should have m points");
    assert_eq!(result.mean_srsf.len(), m, "Mean SRSF should have m points");
}
3318
#[test]
fn test_tsrvf_all_finite() {
    let (m, n) = (50, 5);
    let t = uniform_grid(m);
    let data = make_test_data(n, m, 42);
    let result = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);

    for i in 0..n {
        // First the whole tangent vector of curve i, then its SRSF norm.
        for j in 0..m {
            let v = result.tangent_vectors[(i, j)];
            assert!(v.is_finite(), "Tangent vector should be finite at ({i},{j})");
        }
        let norm = result.srsf_norms[i];
        assert!(norm.is_finite(), "SRSF norm should be finite for curve {i}");
    }
    let mean_norm = result.mean_srsf_norm;
    assert!(mean_norm.is_finite(), "Mean SRSF norm should be finite");
}
3343
#[test]
fn test_tsrvf_identical_curves_zero_tangent() {
    let m = 50;
    let n = 5;
    let t = uniform_grid(m);
    let curve: Vec<f64> = t
        .iter()
        .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
        .collect();
    // n copies of the same curve in column-major order ((i, j) at i + j * n).
    let col_major: Vec<f64> = curve
        .iter()
        .flat_map(|&v| std::iter::repeat(v).take(n))
        .collect();
    let data = FdMatrix::from_column_major(col_major, n, m).unwrap();
    let result = tsrvf_transform(&data, &t, 10, 1e-4, 0.0);

    // With no variability, every (unweighted, Euclidean) tangent-vector norm should be small.
    for i in 0..n {
        let tv_norm = (0..m)
            .map(|j| result.tangent_vectors[(i, j)].powi(2))
            .sum::<f64>()
            .sqrt();
        assert!(
            tv_norm < 0.5,
            "Identical curves should have near-zero tangent vectors, got norm = {}",
            tv_norm
        );
    }
}
3372
#[test]
fn test_tsrvf_mean_tangent_near_zero() {
    let (m, n) = (50, 10);
    let t = uniform_grid(m);
    let data = make_test_data(n, m, 42);
    let result = tsrvf_transform(&data, &t, 10, 1e-3, 0.0);

    // Pointwise average of the n tangent vectors. The test expects it to be small,
    // since the vectors are presumably taken around the estimated mean.
    let mean_tv: Vec<f64> = (0..m)
        .map(|j| (0..n).map(|i| result.tangent_vectors[(i, j)]).sum::<f64>() / n as f64)
        .collect();
    let mean_norm: f64 = mean_tv.iter().map(|v| v * v).sum::<f64>().sqrt();
    assert!(
        mean_norm < 1.0,
        "Mean tangent vector should be near zero, got norm = {mean_norm}"
    );
}
3397
#[test]
fn test_tsrvf_from_alignment() {
    let (m, n) = (50, 5);
    let t = uniform_grid(m);
    let data = make_test_data(n, m, 42);
    // Build the TSRVF from an existing Karcher-mean alignment instead of re-aligning.
    let karcher = karcher_mean(&data, &t, 5, 1e-3, 0.0);
    let result = tsrvf_from_alignment(&karcher, &t);

    assert_eq!(result.tangent_vectors.shape(), (n, m));
    assert!(result.mean_srsf_norm > 0.0);
}
3409
#[test]
fn test_tsrvf_round_trip() {
    let (m, n) = (50, 5);
    let t = uniform_grid(m);
    let data = make_test_data(n, m, 42);
    let result = tsrvf_transform(&data, &t, 10, 1e-3, 0.0);
    let reconstructed = tsrvf_inverse(&result, &t);

    assert_eq!(reconstructed.shape(), result.tangent_vectors.shape());

    // Pass 1: every reconstructed sample is a finite number.
    for i in 0..n {
        for j in 0..m {
            let value = reconstructed[(i, j)];
            assert!(
                value.is_finite(),
                "Reconstructed curve should be finite at ({i},{j})"
            );
        }
    }
    // Pass 2: each reconstructed curve starts at its stored initial value.
    for i in 0..n {
        let (got, want) = (reconstructed[(i, 0)], result.initial_values[i]);
        assert!(
            (got - want).abs() < 1e-6,
            "Curve {i} initial value: expected {}, got {}",
            want,
            got
        );
    }
}
3439
#[test]
fn test_tsrvf_initial_values_per_curve() {
    // Each curve must keep its *own* starting value f_i(0) through the TSRVF and its
    // inverse. A single shared value would collapse the vertical offsets.
    let m = 50;
    let n = 5;
    let t = uniform_grid(m);

    // Same sine shape, shifted vertically by 2, 4, 6, 8, 10.
    // Column-major layout: element (i, j) lives at i + j * n.
    let mut col_major = vec![0.0; n * m];
    for i in 0..n {
        let offset = (i as f64 + 1.0) * 2.0;
        for j in 0..m {
            col_major[i + j * n] = offset + (2.0 * std::f64::consts::PI * t[j]).sin();
        }
    }
    let data = FdMatrix::from_column_major(col_major, n, m).unwrap();

    let result = tsrvf_transform(&data, &t, 15, 1e-4, 0.0);

    // One initial value per curve, and they must not all coincide.
    assert_eq!(result.initial_values.len(), n);
    let all_same = result
        .initial_values
        .windows(2)
        .all(|w| (w[0] - w[1]).abs() < 1e-10);
    assert!(
        !all_same,
        "Initial values should differ per curve: {:?}",
        result.initial_values
    );

    // The inverse must restore each curve's own initial value...
    let reconstructed = tsrvf_inverse(&result, &t);
    for i in 0..n {
        assert!(
            (reconstructed[(i, 0)] - result.initial_values[i]).abs() < 1e-6,
            "Curve {i}: reconstructed f(0) = {}, expected {}",
            reconstructed[(i, 0)],
            result.initial_values[i]
        );
    }

    // ...so the reconstructed starting points also differ from curve to curve
    // (offsets are 2 apart, far above the 0.1 threshold).
    let recon_initials: Vec<f64> = (0..n).map(|i| reconstructed[(i, 0)]).collect();
    let all_recon_same = recon_initials.windows(2).all(|w| (w[0] - w[1]).abs() < 0.1);
    assert!(
        !all_recon_same,
        "Reconstructed initial values must vary per curve: {:?}",
        recon_initials
    );
}
3492
#[test]
fn test_tsrvf_single_curve() {
    let m = 50;
    let t = uniform_grid(m);
    let data = make_test_data(1, m, 42);
    let result = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
    assert_eq!(result.tangent_vectors.shape(), (1, m));

    // With one curve there is nothing to deviate from, so its tangent vector is expected
    // to (nearly) vanish. The norm is the plain Euclidean one over grid values.
    let mut sq_sum = 0.0;
    for j in 0..m {
        sq_sum += result.tangent_vectors[(0, j)].powi(2);
    }
    let tv_norm: f64 = sq_sum.sqrt();
    assert!(
        tv_norm < 0.5,
        "Single curve tangent vector should be near zero, got {tv_norm}"
    );
}
3510
#[test]
fn test_tsrvf_constant_curves() {
    let (n, m) = (3, 30);
    let t = uniform_grid(m);
    // n copies of f(t) = 5: zero derivative and hence zero SRSF, a degenerate input.
    let data = FdMatrix::from_column_major(vec![5.0; n * m], n, m).unwrap();
    let result = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);

    // Only NaN is rejected here.
    for i in 0..n {
        for j in 0..m {
            let v = result.tangent_vectors[(i, j)];
            assert!(
                !v.is_nan(),
                "Constant curves should not produce NaN tangent vectors"
            );
        }
    }
}
3529
#[test]
fn test_tsrvf_sphere_inv_exp_reference() {
    // Reference check for the inverse exponential map on the unit L2 sphere: the norm of
    // log_{psi1}(psi2) must equal the geodesic angle arccos<psi1, psi2>.
    let m = 21;
    // m evenly spaced points on [0, 1].
    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();

    // psi1: the constant function, scaled to unit L2 norm. `.max(0.0)` guards the sqrt
    // against a tiny negative inner product from rounding.
    let raw1 = vec![1.0; m];
    let norm1 = inner_product_l2(&raw1, &raw1, &time).max(0.0).sqrt();
    let psi1: Vec<f64> = raw1.iter().map(|&v| v / norm1).collect();

    // psi2: 1 + 0.3·sin(2πt) (strictly positive), also scaled to unit L2 norm.
    let raw2: Vec<f64> = time
        .iter()
        .map(|&t| 1.0 + 0.3 * (2.0 * std::f64::consts::PI * t).sin())
        .collect();
    let norm2 = inner_product_l2(&raw2, &raw2, &time).max(0.0).sqrt();
    let psi2: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();

    // Expected geodesic angle. The clamp keeps acos inside its domain despite rounding.
    let ip = inner_product_l2(&psi1, &psi2, &time).clamp(-1.0, 1.0);
    let theta_expected = ip.acos();

    // Shooting vector from psi1 towards psi2 and its L2 length.
    let v = inv_exp_map_sphere(&psi1, &psi2, &time);
    let v_norm = inner_product_l2(&v, &v, &time).max(0.0).sqrt();

    assert!(
        (v_norm - theta_expected).abs() < 1e-10,
        "||v|| = {v_norm}, expected theta = {theta_expected}"
    );

    // Sanity check on the fixture itself: the two points are distinct but not far apart.
    assert!(
        theta_expected > 0.01 && theta_expected < 1.0,
        "theta = {theta_expected} out of expected range"
    );
}
3572
#[test]
fn test_tsrvf_sphere_round_trip_reference() {
    // Round trip on the unit L2 sphere: exp_{psi1}(log_{psi1}(psi2)) must reproduce psi2
    // to near machine precision.
    let m = 21;
    // m evenly spaced points on [0, 1].
    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();

    // psi1: the constant function, scaled to unit L2 norm.
    let raw1 = vec![1.0; m];
    let norm1 = inner_product_l2(&raw1, &raw1, &time).max(0.0).sqrt();
    let psi1: Vec<f64> = raw1.iter().map(|&v| v / norm1).collect();

    // psi2: 1 + 0.3·sin(2πt), also scaled to unit L2 norm.
    let raw2: Vec<f64> = time
        .iter()
        .map(|&t| 1.0 + 0.3 * (2.0 * std::f64::consts::PI * t).sin())
        .collect();
    let norm2 = inner_product_l2(&raw2, &raw2, &time).max(0.0).sqrt();
    let psi2: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();

    // Log-map to the tangent space at psi1, then exp-map back.
    let v = inv_exp_map_sphere(&psi1, &psi2, &time);
    let recovered = exp_map_sphere(&psi1, &v, &time);

    // L2 error of the round trip: squared pointwise differences integrated with `trapz`.
    let diff: Vec<f64> = psi2
        .iter()
        .zip(recovered.iter())
        .map(|(&a, &b)| (a - b).powi(2))
        .collect();
    let l2_err = trapz(&diff, &time).max(0.0).sqrt();
    assert!(
        l2_err < 1e-12,
        "Round-trip L2 error = {l2_err:.2e}, expected < 1e-12"
    );
}
3605
#[test]
fn test_penalized_alignment_lambda_zero_matches_unpenalized() {
    // lambda = 0 (no warp penalty). Only basic sanity of the result is checked:
    // a non-negative distance and one warp value per grid point.
    let m = 50;
    let t = uniform_grid(m);
    let data = make_test_data(2, m, 42);

    let r0 = elastic_align_pair(&data.row(0), &data.row(1), &t, 0.0);
    assert!(r0.distance >= 0.0);
    assert_eq!(r0.gamma.len(), m);
}
3621
#[test]
fn test_penalized_alignment_smoother_warp() {
    let m = 80;
    let t = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;
    // f2 is f1 delayed by 0.15.
    let f1: Vec<f64> = t.iter().map(|&ti| (two_pi * ti).sin()).collect();
    let f2: Vec<f64> = t.iter().map(|&ti| (two_pi * (ti - 0.15)).sin()).collect();

    let r_free = elastic_align_pair(&f1, &f2, &t, 0.0);
    let r_pen = elastic_align_pair(&f1, &f2, &t, 1.0);

    // Sum of squared pointwise deviations of a warp from the identity gamma(t) = t.
    let deviation_from_identity = |gamma: &[f64]| -> f64 {
        gamma.iter().zip(t.iter()).map(|(g, ti)| (g - ti).powi(2)).sum()
    };
    let dev_free = deviation_from_identity(&r_free.gamma);
    let dev_pen = deviation_from_identity(&r_pen.gamma);

    // The penalty pulls the warp towards the identity.
    assert!(
        dev_pen <= dev_free + 1e-6,
        "Penalized warp should be closer to identity: free={dev_free:.6}, pen={dev_pen:.6}"
    );
}
3657
#[test]
fn test_penalized_alignment_large_lambda_near_identity() {
    let m = 50;
    let t = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;
    let f1: Vec<f64> = t.iter().map(|&ti| (two_pi * ti).sin()).collect();
    let f2: Vec<f64> = t.iter().map(|&ti| (two_pi * (ti - 0.1)).sin()).collect();

    // A very large penalty should make departing from gamma(t) = t prohibitively costly.
    let r = elastic_align_pair(&f1, &f2, &t, 1000.0);

    // Sup-norm distance between the warp and the identity.
    let mut max_dev = 0.0_f64;
    for (g, ti) in r.gamma.iter().zip(t.iter()) {
        max_dev = max_dev.max((g - ti).abs());
    }
    assert!(
        max_dev < 0.05,
        "Large lambda should give near-identity warp: max deviation = {max_dev}"
    );
}
3685
#[test]
fn test_penalized_karcher_mean() {
    // The Karcher mean with a non-zero warp penalty (lambda = 0.5) must still be
    // well-formed: right length, finite values.
    let m = 40;
    let data = make_test_data(10, m, 42);
    let t = uniform_grid(m);

    let result = karcher_mean(&data, &t, 5, 1e-3, 0.5);
    assert_eq!(result.mean.len(), m);
    (0..m).for_each(|j| assert!(result.mean[j].is_finite()));
}
3698
#[test]
fn test_decomposition_identity_curves() {
    let m = 50;
    let t = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;
    let f: Vec<f64> = t.iter().map(|&ti| (two_pi * ti).sin()).collect();

    // Decomposing a curve against itself: both the amplitude and phase parts should vanish.
    let decomposition = elastic_decomposition(&f, &f, &t, 0.0);
    let amp = decomposition.d_amplitude;
    let phase = decomposition.d_phase;
    assert!(
        amp < 0.1,
        "Self-decomposition amplitude should be ~0, got {}",
        amp
    );
    assert!(
        phase < 0.2,
        "Self-decomposition phase should be ~0, got {}",
        phase
    );
}
3722
#[test]
fn test_decomposition_pythagorean() {
    let m = 80;
    let t = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;
    // f2 differs from f1 in both amplitude (×1.2) and phase (delay 0.1).
    let f1: Vec<f64> = t.iter().map(|&ti| (two_pi * ti).sin()).collect();
    let f2: Vec<f64> = t
        .iter()
        .map(|&ti| 1.2 * (two_pi * (ti - 0.1)).sin())
        .collect();

    // Only non-negativity of both components is asserted. Despite the test name, no
    // Pythagorean relation between them is checked here.
    let result = elastic_decomposition(&f1, &f2, &t, 0.0);
    let (da, dp) = (result.d_amplitude, result.d_phase);
    assert!(da >= 0.0);
    assert!(dp >= 0.0);
}
3744
#[test]
fn test_phase_distance_shifted_sine() {
    let m = 80;
    let t = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;
    let f1: Vec<f64> = t.iter().map(|&ti| (two_pi * ti).sin()).collect();
    // Same shape delayed by 0.15: the curves differ (mostly) in phase.
    let f2: Vec<f64> = t.iter().map(|&ti| (two_pi * (ti - 0.15)).sin()).collect();

    let dp = phase_distance_pair(&f1, &f2, &t, 0.0);
    assert!(
        dp > 0.01,
        "Phase distance of shifted curves should be > 0, got {dp}"
    );
}
3764
#[test]
fn test_amplitude_distance_scaled_curve() {
    let m = 80;
    let t = uniform_grid(m);
    let base: Vec<f64> = t
        .iter()
        .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
        .collect();
    // Same phase, doubled amplitude.
    let doubled: Vec<f64> = base.iter().map(|&v| 2.0 * v).collect();

    let da = amplitude_distance(&base, &doubled, &t, 0.0);
    assert!(
        da > 0.01,
        "Amplitude distance of scaled curves should be > 0, got {da}"
    );
}
3784
#[test]
fn test_phase_distance_nonneg() {
    let (n, m) = (4, 40);
    let data = make_test_data(n, m, 42);
    let t = uniform_grid(m);

    // Every ordered pair, including i == j, must give a non-negative phase distance.
    for (i, j) in (0..n).flat_map(|i| (0..n).map(move |j| (i, j))) {
        let dp = phase_distance_pair(&data.row(i), &data.row(j), &t, 0.0);
        assert!(dp >= 0.0, "Phase distance should be non-negative");
    }
}
3798
#[test]
fn test_schilds_ladder_zero_vector() {
    let m = 21;
    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
    // Scale a positive function to unit L2 norm so it lies on the unit sphere.
    let normalize = |raw: Vec<f64>| -> Vec<f64> {
        let norm = crate::warping::l2_norm_l2(&raw, &time);
        raw.iter().map(|&v| v / norm).collect()
    };
    let from = normalize(vec![1.0; m]);
    let to = normalize(
        time.iter()
            .map(|&t| 1.0 + 0.2 * (2.0 * std::f64::consts::PI * t).sin())
            .collect(),
    );

    // Transporting the zero tangent vector is expected to give (numerically) zero.
    let zero = vec![0.0; m];
    let result = parallel_transport_schilds(&zero, &from, &to, &time);
    let result_norm: f64 = result.iter().map(|v| v * v).sum::<f64>().sqrt();
    assert!(
        result_norm < 1e-6,
        "Transporting zero should give zero, got norm {result_norm}"
    );
}
3823
#[test]
fn test_pole_ladder_zero_vector() {
    let m = 21;
    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
    // Scale a positive function to unit L2 norm so it lies on the unit sphere.
    let normalize = |raw: Vec<f64>| -> Vec<f64> {
        let norm = crate::warping::l2_norm_l2(&raw, &time);
        raw.iter().map(|&v| v / norm).collect()
    };
    let from = normalize(vec![1.0; m]);
    let to = normalize(
        time.iter()
            .map(|&t| 1.0 + 0.2 * (2.0 * std::f64::consts::PI * t).sin())
            .collect(),
    );

    // Transporting the zero tangent vector is expected to give (numerically) zero.
    let zero = vec![0.0; m];
    let result = parallel_transport_pole(&zero, &from, &to, &time);
    let result_norm: f64 = result.iter().map(|v| v * v).sum::<f64>().sqrt();
    assert!(
        result_norm < 1e-6,
        "Transporting zero should give zero, got norm {result_norm}"
    );
}
3846
#[test]
fn test_schilds_preserves_norm() {
    let m = 51;
    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
    // Scale a positive function to unit L2 norm so it lies on the unit sphere.
    let normalize = |raw: Vec<f64>| -> Vec<f64> {
        let norm = crate::warping::l2_norm_l2(&raw, &time);
        raw.iter().map(|&v| v / norm).collect()
    };
    let from = normalize(vec![1.0; m]);
    let to = normalize(
        time.iter()
            .map(|&t| 1.0 + 0.15 * (2.0 * std::f64::consts::PI * t).sin())
            .collect(),
    );

    // Tangent vector to transport, and its L2 norm before transport.
    let v: Vec<f64> = time
        .iter()
        .map(|&t| 0.1 * (4.0 * std::f64::consts::PI * t).cos())
        .collect();
    let v_norm = crate::warping::l2_norm_l2(&v, &time);

    let transported = parallel_transport_schilds(&v, &from, &to, &time);
    let t_norm = crate::warping::l2_norm_l2(&transported, &time);

    // Exact parallel transport preserves norms. Schild's ladder only approximates it,
    // so just a loose relative bound is enforced.
    let rel_change = (t_norm - v_norm).abs() / v_norm.max(1e-10);
    assert!(
        rel_change < 1.5,
        "Schild's should roughly preserve norm: original={v_norm:.4}, transported={t_norm:.4}"
    );
}
3877
#[test]
fn test_tsrvf_logmap_matches_original() {
    let (m, n) = (50, 5);
    let t = uniform_grid(m);
    let data = make_test_data(n, m, 42);

    // The plain entry point and the explicit LogMap variant must agree to ~1e-12.
    let result_orig = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
    let result_logmap =
        tsrvf_transform_with_method(&data, &t, 5, 1e-3, 0.0, TransportMethod::LogMap);

    for i in 0..n {
        for j in 0..m {
            let a = result_orig.tangent_vectors[(i, j)];
            let b = result_logmap.tangent_vectors[(i, j)];
            assert!(
                (a - b).abs() < 1e-12,
                "LogMap variant should match original at ({i},{j})"
            );
        }
    }
}
3901
#[test]
fn test_tsrvf_with_schilds_produces_valid_result() {
    let (m, n) = (50, 5);
    let t = uniform_grid(m);
    let data = make_test_data(n, m, 42);

    let method = TransportMethod::SchildsLadder;
    let result = tsrvf_transform_with_method(&data, &t, 5, 1e-3, 0.0, method);

    // Correct shape and no NaN/inf anywhere in the tangent vectors.
    assert_eq!(result.tangent_vectors.shape(), (n, m));
    for (i, j) in (0..n).flat_map(|i| (0..m).map(move |j| (i, j))) {
        assert!(
            result.tangent_vectors[(i, j)].is_finite(),
            "Schild's TSRVF should produce finite tangent vectors at ({i},{j})"
        );
    }
}
3922
#[test]
fn test_transport_methods_differ() {
    let (m, n) = (50, 5);
    let t = uniform_grid(m);
    let data = make_test_data(n, m, 42);
    // Share one alignment so that only the transport method varies between the runs.
    let karcher = karcher_mean(&data, &t, 5, 1e-3, 0.0);

    let r_log = tsrvf_from_alignment_with_method(&karcher, &t, TransportMethod::LogMap);
    let r_schilds =
        tsrvf_from_alignment_with_method(&karcher, &t, TransportMethod::SchildsLadder);

    // Despite the test name, no minimum difference is required: only the total absolute
    // difference between the two sets of tangent vectors must be finite.
    let total_diff: f64 = (0..n)
        .flat_map(|i| (0..m).map(move |j| (i, j)))
        .map(|(i, j)| (r_log.tangent_vectors[(i, j)] - r_schilds.tangent_vectors[(i, j)]).abs())
        .sum();

    assert!(total_diff.is_finite());
}
3948
#[test]
fn test_warp_complexity_identity_is_zero() {
    // The identity warp gamma(t) = t does no time-warping, so its complexity is zero.
    let t = uniform_grid(50);
    let c = warp_complexity(&t, &t);
    assert!(
        c < 1e-10,
        "Identity warp should have zero complexity, got {c}"
    );
}
3962
#[test]
fn test_warp_complexity_nonidentity_positive() {
    let t = uniform_grid(50);
    // gamma(t) = t² is a genuine (non-identity) warp, so its complexity must be positive.
    let quadratic: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
    let c = warp_complexity(&quadratic, &t);
    assert!(
        c > 0.01,
        "Non-identity warp should have positive complexity, got {c}"
    );
}
3974
#[test]
fn test_warp_smoothness_identity_is_zero() {
    // gamma(t) = t has constant slope 1 and no curvature, hence (numerically) zero
    // bending energy.
    let t = uniform_grid(50);
    let s = warp_smoothness(&t, &t);
    assert!(
        s < 1e-6,
        "Identity warp (constant γ'=1, γ''=0) should have near-zero bending energy, got {s}"
    );
}
3986
#[test]
fn test_alignment_quality_basic() {
    let (m, n) = (50, 8);
    let t = uniform_grid(m);
    let data = make_test_data(n, m, 42);
    let karcher = karcher_mean(&data, &t, 10, 1e-3, 0.0);
    let quality = alignment_quality(&data, &karcher, &t);

    // Per-curve diagnostics have length n; the pointwise ratio has one entry per grid point.
    assert_eq!(quality.warp_complexity.len(), n);
    assert_eq!(quality.warp_smoothness.len(), n);
    assert_eq!(quality.pointwise_variance_ratio.len(), m);

    // All variance and warp summaries are non-negative.
    assert!(quality.total_variance >= 0.0);
    assert!(quality.amplitude_variance >= 0.0);
    assert!(quality.phase_variance >= 0.0);
    assert!(quality.mean_warp_complexity >= 0.0);
    assert!(quality.mean_warp_smoothness >= 0.0);

    // The amplitude part of the variance cannot exceed the total (up to rounding).
    let (amp, total) = (quality.amplitude_variance, quality.total_variance);
    assert!(
        amp <= total + 1e-10,
        "Amplitude variance ({}) should be ≤ total variance ({})",
        amp,
        total
    );
}
4016
#[test]
fn test_alignment_quality_identical_curves() {
    let m = 50;
    let n = 5;
    let t = uniform_grid(m);
    let curve: Vec<f64> = t
        .iter()
        .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
        .collect();
    // n copies of the same curve in column-major order ((i, j) at i + j * n).
    let col_major: Vec<f64> = curve
        .iter()
        .flat_map(|&v| std::iter::repeat(v).take(n))
        .collect();
    let data = FdMatrix::from_column_major(col_major, n, m).unwrap();
    let karcher = karcher_mean(&data, &t, 5, 1e-3, 0.0);
    let quality = alignment_quality(&data, &karcher, &t);

    // No variability at all: total variance and warping effort should both be ~0.
    let total = quality.total_variance;
    assert!(
        total < 0.01,
        "Identical curves should have near-zero total variance, got {}",
        total
    );
    let complexity = quality.mean_warp_complexity;
    assert!(
        complexity < 0.1,
        "Identical curves should have near-zero warp complexity, got {}",
        complexity
    );
}
4047
#[test]
fn test_alignment_quality_variance_reduction() {
    let (m, n) = (50, 10);
    let t = uniform_grid(m);
    let data = make_test_data(n, m, 42);
    let karcher = karcher_mean(&data, &t, 10, 1e-3, 0.0);
    let quality = alignment_quality(&data, &karcher, &t);

    // Presumably an aligned-to-unaligned variance ratio averaged over the grid
    // (NOTE(review): confirm in `alignment_quality`). The bound allows slack above 1.
    let ratio = quality.mean_variance_reduction;
    assert!(
        ratio <= 1.5,
        "Mean variance reduction ratio should be ≤ ~1, got {}",
        ratio
    );
}
4064
#[test]
fn test_pairwise_consistency_small() {
    let (m, n) = (40, 4);
    let t = uniform_grid(m);
    let data = make_test_data(n, m, 42);

    // lambda = 0. The last argument (100) presumably caps the amount of work examined.
    // NOTE(review): confirm its meaning against `pairwise_consistency`.
    let consistency = pairwise_consistency(&data, &t, 0.0, 100);
    let well_formed = consistency.is_finite() && consistency >= 0.0;
    assert!(
        well_formed,
        "Pairwise consistency should be finite and non-negative, got {consistency}"
    );
}
4078
#[test]
fn test_srsf_nd_d1_matches_existing() {
    let m = 50;
    let t = uniform_grid(m);
    let data = make_test_data(3, m, 42);

    // Reference: the dedicated 1-D SRSF transform (computed before `data` is moved).
    let q_1d = srsf_transform(&data, &t);
    // The same data wrapped as a one-dimensional curve set.
    let q_nd = srsf_transform_nd(&FdCurveSet::from_1d(data), &t);

    assert_eq!(q_nd.ndim(), 1);
    let q_nd_only = &q_nd.dims[0];
    for i in 0..3 {
        for j in 0..m {
            let (a, b) = (q_1d[(i, j)], q_nd_only[(i, j)]);
            assert!(
                (a - b).abs() < 1e-10,
                "1D nd SRSF should match existing at ({i},{j}): {} vs {}",
                a,
                b
            );
        }
    }
}
4106
#[test]
fn test_srsf_nd_constant_is_zero() {
    let m = 30;
    let t = uniform_grid(m);
    // A constant curve in R²: (3, -1) at every grid point. Its derivative, and so its
    // SRSF, is zero in both coordinates.
    let dims: Vec<FdMatrix> = [3.0, -1.0]
        .iter()
        .map(|&c| FdMatrix::from_column_major(vec![c; m], 1, m).unwrap())
        .collect();
    let data = FdCurveSet::from_dims(dims).unwrap();

    let q = srsf_transform_nd(&data, &t);
    for k in 0..2 {
        let qk = &q.dims[k];
        for j in 0..m {
            let value = qk[(0, j)];
            assert!(
                value.abs() < 1e-10,
                "Constant curve SRSF should be zero, dim {k} at {j}: {}",
                value
            );
        }
    }
}
4127
#[test]
fn test_srsf_nd_linear_r2() {
    let m = 51;
    let t = uniform_grid(m);
    // Straight line f(t) = (2t, 3t) in R². Here f' = (2, 3) and |f'| = sqrt(13), so the
    // SRSF q = f' / sqrt(|f'|) is the constant (2, 3) / 13^(1/4).
    let line_dim = |slope: f64| {
        let vals: Vec<f64> = t.iter().map(|&ti| slope * ti).collect();
        FdMatrix::from_slice(&vals, 1, m).unwrap()
    };
    let data = FdCurveSet::from_dims(vec![line_dim(2.0), line_dim(3.0)]).unwrap();

    let q = srsf_transform_nd(&data, &t);
    let expected_scale = 1.0 / 13.0_f64.powf(0.25);
    // Compare at the interior midpoint, away from the endpoints.
    let mid = m / 2;

    let (qx, want_x) = (q.dims[0][(0, mid)], 2.0 * expected_scale);
    assert!(
        (qx - want_x).abs() < 0.1,
        "q_x at midpoint: {} vs expected {}",
        qx,
        want_x
    );
    let (qy, want_y) = (q.dims[1][(0, mid)], 3.0 * expected_scale);
    assert!(
        (qy - want_y).abs() < 0.1,
        "q_y at midpoint: {} vs expected {}",
        qy,
        want_y
    );
}
4157
/// Round trip through the SRSF representation. The curve is the unit circle
/// (sin 2πt, cos 2πt). It is transformed, then inverted from its true starting
/// point, and should be recovered to within 0.2 on interior samples.
#[test]
fn test_srsf_nd_round_trip() {
    let m = 51;
    let t = uniform_grid(m);
    let omega = 2.0 * std::f64::consts::PI;
    let vals_x: Vec<f64> = t.iter().map(|&ti| (omega * ti).sin()).collect();
    let vals_y: Vec<f64> = t.iter().map(|&ti| (omega * ti).cos()).collect();
    let data = FdCurveSet::from_dims(vec![
        FdMatrix::from_slice(&vals_x, 1, m).unwrap(),
        FdMatrix::from_slice(&vals_y, 1, m).unwrap(),
    ])
    .unwrap();

    let q = srsf_transform_nd(&data, &t);
    let q_rows: Vec<Vec<f64>> = q.dims.iter().map(|dm| dm.row(0)).collect();
    let start = vec![vals_x[0], vals_y[0]];
    let recon = srsf_inverse_nd(&q_rows, &t, &start);

    // Two samples at each end are excluded from the error measurement.
    let originals = [&vals_x, &vals_y];
    let mut max_err = 0.0_f64;
    for (k, orig) in originals.iter().enumerate() {
        for j in 2..(m - 2) {
            max_err = max_err.max((recon[k][j] - orig[j]).abs());
        }
    }
    assert!(
        max_err < 0.2,
        "SRSF round-trip max error should be small, got {max_err}"
    );
}
4189
/// Aligning an R^2 curve (the unit circle) with itself should give a distance
/// below 0.5 and a warping function within 0.1 of the identity.
#[test]
fn test_align_nd_identical_near_zero() {
    let m = 50;
    let t = uniform_grid(m);
    let omega = 2.0 * std::f64::consts::PI;
    let (xs, ys): (Vec<f64>, Vec<f64>) = t
        .iter()
        .map(|&ti| {
            let phase = omega * ti;
            (phase.sin(), phase.cos())
        })
        .unzip();
    let curve = FdCurveSet::from_dims(vec![
        FdMatrix::from_slice(&xs, 1, m).unwrap(),
        FdMatrix::from_slice(&ys, 1, m).unwrap(),
    ])
    .unwrap();

    let result = elastic_align_pair_nd(&curve, &curve, &t, 0.0);
    assert!(
        result.distance < 0.5,
        "Self-alignment distance should be ~0, got {}",
        result.distance
    );

    let mut max_dev = 0.0_f64;
    for (g, ti) in result.gamma.iter().zip(&t) {
        max_dev = max_dev.max((g - ti).abs());
    }
    assert!(
        max_dev < 0.1,
        "Self-alignment warp should be near identity, max dev = {max_dev}"
    );
}
4219
/// Aligns two copies of the unit circle, one phase-shifted by 0.1. The result
/// must have a finite distance and one aligned component of length m per
/// dimension. The warp must also depart visibly (> 0.01) from the identity.
#[test]
fn test_align_nd_shifted_r2() {
    let m = 60;
    let t = uniform_grid(m);
    let pi2 = 2.0 * std::f64::consts::PI;
    let circle = |shift: f64| {
        let xs: Vec<f64> = t.iter().map(|&ti| (pi2 * (ti - shift)).sin()).collect();
        let ys: Vec<f64> = t.iter().map(|&ti| (pi2 * (ti - shift)).cos()).collect();
        FdCurveSet::from_dims(vec![
            FdMatrix::from_slice(&xs, 1, m).unwrap(),
            FdMatrix::from_slice(&ys, 1, m).unwrap(),
        ])
        .unwrap()
    };
    let f1 = circle(0.0);
    let f2 = circle(0.1);

    let result = elastic_align_pair_nd(&f1, &f2, &t, 0.0);
    assert!(
        result.distance.is_finite(),
        "Distance should be finite, got {}",
        result.distance
    );
    assert_eq!(result.f_aligned.len(), 2);
    assert_eq!(result.f_aligned[0].len(), m);

    let max_dev = result
        .gamma
        .iter()
        .zip(&t)
        .map(|(g, ti)| (g - ti).abs())
        .fold(0.0_f64, |acc, d| acc.max(d));
    assert!(
        max_dev > 0.01,
        "Shifted curves should require non-trivial warp, max dev = {max_dev}"
    );
}
4264
/// With an empty landmark list the constrained aligner has nothing to enforce.
/// It must therefore reproduce the unconstrained warp (to 1e-10) and report no
/// enforced landmarks.
#[test]
fn test_constrained_no_landmarks_matches_unconstrained() {
    let m = 50;
    let t = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;
    let sine = |shift: f64| -> Vec<f64> {
        t.iter().map(|&ti| (two_pi * (ti - shift)).sin()).collect()
    };
    let f1 = sine(0.0);
    let f2 = sine(0.1);

    let r_free = elastic_align_pair(&f1, &f2, &t, 0.0);
    let r_const = elastic_align_pair_constrained(&f1, &f2, &t, &[], 0.0);

    (0..m).for_each(|j| {
        let gap = (r_free.gamma[j] - r_const.gamma[j]).abs();
        assert!(
            gap < 1e-10,
            "No-landmark constrained should match unconstrained at {j}"
        );
    });
    assert!(r_const.enforced_landmarks.is_empty());
}
4292
/// A single landmark pins gamma(0.5) = 0.5 while the inputs differ by a 0.1
/// phase shift. The pin must hold within 0.05 at the index returned by
/// `snap_to_grid`, and exactly one landmark must be reported as enforced.
#[test]
fn test_constrained_single_landmark_enforced() {
    let m = 60;
    let t = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;
    let f1: Vec<f64> = t.iter().map(|&ti| (two_pi * ti).sin()).collect();
    let f2: Vec<f64> = t.iter().map(|&ti| (two_pi * (ti - 0.1)).sin()).collect();

    let (target_t, target_gamma) = (0.5, 0.5);
    let result =
        elastic_align_pair_constrained(&f1, &f2, &t, &[(target_t, target_gamma)], 0.0);

    let mid_idx = snap_to_grid(0.5, &t);
    let got = result.gamma[mid_idx];
    assert!(
        (got - 0.5).abs() < 0.05,
        "Constrained gamma at midpoint should be ~0.5, got {}",
        got
    );
    assert_eq!(result.enforced_landmarks.len(), 1);
}
4318
/// Three evenly spaced landmarks are placed on a pair of two-period sinusoids
/// offset by 0.05. Each landmark must be respected within 0.05 at its
/// snapped grid index.
#[test]
fn test_constrained_multiple_landmarks() {
    let m = 80;
    let t = uniform_grid(m);
    let four_pi = 4.0 * std::f64::consts::PI;
    let wave = |shift: f64| -> Vec<f64> {
        t.iter().map(|&ti| (four_pi * (ti - shift)).sin()).collect()
    };
    let (f1, f2) = (wave(0.0), wave(0.05));

    let landmarks = vec![(0.25, 0.25), (0.5, 0.5), (0.75, 0.75)];
    let result = elastic_align_pair_constrained(&f1, &f2, &t, &landmarks, 0.0);

    for &(tt, st) in landmarks.iter() {
        let idx = snap_to_grid(tt, &t);
        let got = result.gamma[idx];
        assert!(
            (got - st).abs() < 0.05,
            "Gamma at t={tt} should be ~{st}, got {}",
            got
        );
    }
}
4345
/// Two interior landmarks must not break the basic warp properties. The
/// constrained gamma stays non-decreasing (1e-10 slack) and keeps both grid
/// endpoints fixed.
#[test]
fn test_constrained_monotone_gamma() {
    let m = 60;
    let t = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;
    let f1: Vec<f64> = t.iter().map(|&ti| (two_pi * ti).sin()).collect();
    let f2: Vec<f64> = t.iter().map(|&ti| (two_pi * (ti - 0.1)).sin()).collect();

    let result =
        elastic_align_pair_constrained(&f1, &f2, &t, &[(0.3, 0.3), (0.7, 0.7)], 0.0);

    let gamma = &result.gamma;
    for j in 1..m {
        let (prev, curr) = (gamma[j - 1], gamma[j]);
        assert!(
            curr >= prev - 1e-10,
            "Gamma should be monotone: gamma[{}]={} < gamma[{}]={}",
            j,
            curr,
            j - 1,
            prev
        );
    }
    // Kept verbatim: without a custom message, assert! reports the expression text.
    assert!((result.gamma[0] - t[0]).abs() < 1e-10);
    assert!((result.gamma[m - 1] - t[m - 1]).abs() < 1e-10);
}
4376
/// A landmark can only shrink the set of admissible warps. The constrained
/// distance must therefore not fall below the unconstrained optimum
/// (1e-6 slack).
#[test]
fn test_constrained_distance_ge_unconstrained() {
    let m = 60;
    let t = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;
    let sine = |shift: f64| -> Vec<f64> {
        t.iter().map(|&ti| (two_pi * (ti - shift)).sin()).collect()
    };
    let (f1, f2) = (sine(0.0), sine(0.15));

    let r_free = elastic_align_pair(&f1, &f2, &t, 0.0);
    let r_const = elastic_align_pair_constrained(&f1, &f2, &t, &[(0.5, 0.5)], 0.0);

    let (free_dist, const_dist) = (r_free.distance, r_const.distance);
    assert!(
        const_dist >= free_dist - 1e-6,
        "Constrained distance ({}) should be >= unconstrained ({})",
        const_dist,
        free_dist
    );
}
4401
/// End-to-end check of `elastic_align_pair_with_landmarks` with
/// `LandmarkKind::Peak`. The outputs must have length m, the distance must be
/// finite, and the warp must be non-decreasing (1e-10 slack).
#[test]
fn test_constrained_with_landmark_detection() {
    let m = 80;
    let t = uniform_grid(m);
    let four_pi = 4.0 * std::f64::consts::PI;
    let f1: Vec<f64> = t.iter().map(|&ti| (four_pi * ti).sin()).collect();
    let f2: Vec<f64> = t.iter().map(|&ti| (four_pi * (ti - 0.05)).sin()).collect();

    let kind = crate::landmark::LandmarkKind::Peak;
    let result = elastic_align_pair_with_landmarks(&f1, &f2, &t, kind, 0.1, 0, 0.0);

    assert_eq!(result.gamma.len(), m);
    assert_eq!(result.f_aligned.len(), m);
    assert!(result.distance.is_finite());
    (1..m).for_each(|j| {
        assert!(
            result.gamma[j] >= result.gamma[j - 1] - 1e-10,
            "Gamma should be monotone at j={j}"
        );
    });
}
4436
/// For the identity warp gamma(t) = t, the smoothed psi must sit within 0.05
/// of 1.0 and within 0.05 of the raw psi. The outer `m / 20` samples at each
/// end are not checked.
#[test]
fn test_gam_to_psi_smooth_identity() {
    use crate::warping::{gam_to_psi, gam_to_psi_smooth};
    let m = 101;
    let h = 1.0 / (m - 1) as f64;
    let gam: Vec<f64> = uniform_grid(m);
    let psi_raw = gam_to_psi(&gam, h);
    let psi_smooth = gam_to_psi_smooth(&gam, h);

    let skip = m / 20;
    let interior = skip..(m - skip);
    for j in interior {
        let smoothed = psi_smooth[j];
        assert!(
            (smoothed - 1.0).abs() < 0.05,
            "Smoothed psi of identity should be ~1.0, got {} at j={}",
            smoothed,
            j
        );
        assert!(
            (smoothed - psi_raw[j]).abs() < 0.05,
            "Smoothed and raw psi should agree on smooth warp at j={}",
            j
        );
    }
}
4465
/// Checks that `gam_to_psi_smooth` does not make the sharpest psi jump of a
/// kinked warp worse than `gam_to_psi` does (0.01 slack). It also checks that
/// the fixture really contains a sharp jump, so the comparison is not vacuous.
///
/// The fixture has three parts:
/// - gamma rises with slope 0.5 / 0.33 up to 0.5 at t = 0.33;
/// - it then climbs with the much steeper slope 0.5 / 0.34 * 2 (about 2.94)
///   until it is clipped at 1.0, around t = 0.5;
/// - it stays flat at 1.0 for the rest of the domain.
///
/// The corner where the steep segment meets the flat one is where psi
/// (presumably sqrt of gamma') changes abruptly.
#[test]
fn test_gam_to_psi_smooth_reduces_spikes() {
    use crate::warping::{gam_to_psi, gam_to_psi_smooth};

    /// Largest absolute difference between neighbouring psi samples.
    fn max_jump(psi: &[f64]) -> f64 {
        psi.windows(2)
            .map(|w| (w[1] - w[0]).abs())
            .fold(0.0_f64, f64::max)
    }

    let m = 101;
    let h = 1.0 / (m - 1) as f64;
    let argvals = uniform_grid(m);
    // Every t >= 0.33 goes through the steep branch and the clip. The steep
    // segment already exceeds 1.0 before t = 0.67, so the clip alone produces
    // the flat tail. gamma[m - 1] is exactly 1.0, so no end-point
    // renormalisation is needed either.
    let gam: Vec<f64> = argvals
        .iter()
        .map(|&t| {
            if t < 0.33 {
                t * 0.5 / 0.33
            } else {
                (0.5 + (t - 0.33) * 0.5 / 0.34 * 2.0).min(1.0)
            }
        })
        .collect();

    let psi_raw = gam_to_psi(&gam, h);
    let psi_smooth = gam_to_psi_smooth(&gam, h);
    let max_jump_raw = max_jump(&psi_raw);
    let max_jump_smooth = max_jump(&psi_smooth);

    // Guard against a fixture regression that would make the comparison trivial.
    assert!(
        max_jump_raw > 0.1,
        "Fixture should produce a sharp psi jump, raw max jump = {max_jump_raw:.4}"
    );
    assert!(
        max_jump_smooth < max_jump_raw + 0.01,
        "Smoothing should not increase max psi jump: raw={max_jump_raw:.4}, smooth={max_jump_smooth:.4}"
    );
}
4511
/// Nadaraya–Watson smoothing with the "gaussian" kernel and a bandwidth of two
/// grid steps should keep a one-period sine nearly unchanged in shape. The
/// Pearson correlation with the original must exceed 0.99.
#[test]
fn test_smooth_aligned_srsfs_preserves_shape() {
    use crate::smoothing::nadaraya_watson;
    let m = 101;
    let denom = (m - 1) as f64;
    let time: Vec<f64> = (0..m).map(|j| j as f64 / denom).collect();
    let qi: Vec<f64> = time
        .iter()
        .map(|&t| (2.0 * std::f64::consts::PI * t).sin())
        .collect();
    let bandwidth = 2.0 / denom;
    let qi_smooth = nadaraya_watson(&time, &qi, &time, bandwidth, "gaussian");

    let mean_orig: f64 = qi.iter().sum::<f64>() / m as f64;
    let mean_smooth: f64 = qi_smooth.iter().sum::<f64>() / m as f64;

    // Accumulate the covariance and both variances in one sequential pass.
    let (cov, var_o, var_s) = (0..m).fold((0.0, 0.0, 0.0), |(c, vo, vs), j| {
        let do_ = qi[j] - mean_orig;
        let ds = qi_smooth[j] - mean_smooth;
        (c + do_ * ds, vo + do_ * do_, vs + ds * ds)
    });
    let rho = cov / (var_o * var_s).sqrt().max(1e-10);
    assert!(
        rho > 0.99,
        "Smoothed SRSF should be highly correlated with original (rho={rho:.4})"
    );
}
4544
/// Spike detector for TSRVF tangent vectors. In every row whose RMS exceeds
/// 1e-10, no single entry may exceed 10x the RMS in absolute value.
#[test]
fn test_tsrvf_tangent_vectors_no_spikes() {
    let m = 101;
    let argvals = uniform_grid(m);
    let data = make_test_data(10, m, 42);
    let result = tsrvf_transform(&data, &argvals, 5, 1e-3, 0.0);

    let tangents = &result.tangent_vectors;
    let (n, _) = tangents.shape();
    for i in 0..n {
        let vi = tangents.row(i);
        let energy: f64 = vi.iter().map(|&v| v * v).sum();
        let rms = (energy / m as f64).sqrt();
        // Rows that are numerically zero carry no shape information; skip them.
        if rms > 1e-10 {
            let max_abs = vi.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
            assert!(
                max_abs < 10.0 * rms,
                "Tangent vector {} has spike: max |v| = {max_abs:.4}, rms = {rms:.4}, ratio = {:.1}",
                i,
                max_abs / rms
            );
        }
    }
}
4569}