1use crate::fdata::{deriv_1d, mean_1d};
15use crate::helpers::{
16 cumulative_trapz, gradient_uniform, l2_distance, linear_interp, simpsons_weights,
17};
18use crate::iter_maybe_parallel;
19use crate::matrix::FdMatrix;
20use crate::warping::{
21 exp_map_sphere, gam_to_psi, inv_exp_map_sphere, invert_gamma, l2_norm_l2, normalize_warp,
22 psi_to_gam,
23};
24#[cfg(feature = "parallel")]
25use rayon::iter::ParallelIterator;
26
/// Result of elastically aligning one curve to another, as produced by `elastic_align_pair`.
#[derive(Debug, Clone)]
pub struct AlignmentResult {
    /// Warping function γ sampled on `argvals`; the aligned curve is `f2 ∘ γ`.
    pub gamma: Vec<f64>,
    /// The second curve reparameterized by `gamma`, i.e. `f2(γ(t))` on `argvals`.
    pub f_aligned: Vec<f64>,
    /// Simpson-weighted L2 distance between the SRSF of `f1` and the SRSF of `f_aligned`.
    pub distance: f64,
}
39
/// Result of aligning every row of a data matrix to a common target (`align_to_target`).
#[derive(Debug, Clone)]
pub struct AlignmentSetResult {
    /// Warping functions, one row per input curve, sampled on `argvals`.
    pub gammas: FdMatrix,
    /// Input curves reparameterized by their warps, one row per curve.
    pub aligned_data: FdMatrix,
    /// Elastic distance of each curve to the target, in row order.
    pub distances: Vec<f64>,
}
50
/// Output of `karcher_mean`: the elastic (Karcher) mean and the alignment of every curve to it.
#[derive(Debug, Clone)]
pub struct KarcherMeanResult {
    /// Mean curve, rebuilt from `mean_srsf` starting at the average first value of the input curves.
    pub mean: Vec<f64>,
    /// SRSF of the mean after the final centering step.
    pub mean_srsf: Vec<f64>,
    /// Centered warping functions, one row per input curve.
    pub gammas: FdMatrix,
    /// Input curves reparameterized by their rows of `gammas`.
    pub aligned_data: FdMatrix,
    /// Number of alignment iterations performed.
    pub n_iter: usize,
    /// Whether the iteration stopped on its convergence test rather than by exhausting `max_iter`.
    pub converged: bool,
}
67
68fn karcher_sphere_step(mu: &mut Vec<f64>, psis: &[Vec<f64>], time: &[f64], step_size: f64) -> bool {
81 let m = mu.len();
82 let n = psis.len();
83 let mut vbar = vec![0.0; m];
84 for psi in psis {
85 let v = inv_exp_map_sphere(mu, psi, time);
86 for j in 0..m {
87 vbar[j] += v[j];
88 }
89 }
90 for j in 0..m {
91 vbar[j] /= n as f64;
92 }
93 if l2_norm_l2(&vbar, time) <= 1e-8 {
94 return true;
95 }
96 let scaled: Vec<f64> = vbar.iter().map(|&v| v * step_size).collect();
97 *mu = exp_map_sphere(mu, &scaled, time);
98 false
99}
100
101fn sqrt_mean_inverse(gammas: &FdMatrix, argvals: &[f64]) -> Vec<f64> {
102 let (n, m) = gammas.shape();
103 let t0 = argvals[0];
104 let t1 = argvals[m - 1];
105 let domain = t1 - t0;
106
107 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
108 let binsize = 1.0 / (m - 1) as f64;
109
110 let psis: Vec<Vec<f64>> = (0..n)
111 .map(|i| {
112 let gam_01: Vec<f64> = (0..m).map(|j| (gammas[(i, j)] - t0) / domain).collect();
113 gam_to_psi(&gam_01, binsize)
114 })
115 .collect();
116
117 let mut mu = vec![0.0; m];
118 for psi in &psis {
119 for j in 0..m {
120 mu[j] += psi[j];
121 }
122 }
123 for j in 0..m {
124 mu[j] /= n as f64;
125 }
126
127 for _ in 0..501 {
128 if karcher_sphere_step(&mut mu, &psis, &time, 0.3) {
129 break;
130 }
131 }
132
133 let gam_mu = psi_to_gam(&mu, &time);
134 let gam_inv = invert_gamma(&gam_mu, &time);
135 gam_inv.iter().map(|&g| t0 + g * domain).collect()
136}
137
138pub fn srsf_transform(data: &FdMatrix, argvals: &[f64]) -> FdMatrix {
151 let (n, m) = data.shape();
152 if n == 0 || m == 0 || argvals.len() != m {
153 return FdMatrix::zeros(n, m);
154 }
155
156 let deriv = deriv_1d(data, argvals, 1);
157
158 let mut result = FdMatrix::zeros(n, m);
159 for i in 0..n {
160 for j in 0..m {
161 let d = deriv[(i, j)];
162 result[(i, j)] = d.signum() * d.abs().sqrt();
163 }
164 }
165 result
166}
167
168pub fn srsf_inverse(q: &[f64], argvals: &[f64], f0: f64) -> Vec<f64> {
180 let m = q.len();
181 if m == 0 {
182 return Vec::new();
183 }
184
185 let integrand: Vec<f64> = q.iter().map(|&qi| qi * qi.abs()).collect();
187 let integral = cumulative_trapz(&integrand, argvals);
188
189 integral.iter().map(|&v| f0 + v).collect()
190}
191
192pub fn reparameterize_curve(f: &[f64], argvals: &[f64], gamma: &[f64]) -> Vec<f64> {
203 gamma
204 .iter()
205 .map(|&g| linear_interp(argvals, f, g))
206 .collect()
207}
208
209pub fn compose_warps(gamma1: &[f64], gamma2: &[f64], argvals: &[f64]) -> Vec<f64> {
216 gamma2
217 .iter()
218 .map(|&g| linear_interp(argvals, gamma1, g))
219 .collect()
220}
221
/// Greatest common divisor by the Euclidean algorithm (`gcd(a, 0) == a`).
#[cfg(test)]
fn gcd(a: usize, b: usize) -> usize {
    let (mut x, mut y) = (a, b);
    while y != 0 {
        let r = x % y;
        x = y;
        y = r;
    }
    x
}
234
/// All pairs `(i, j)` with `1 <= i, j <= nbhd_dim` and `gcd(i, j) == 1`, in row-major order.
/// Used by tests to cross-check the hard-coded `COPRIME_NBHD_7` table.
#[cfg(test)]
fn generate_coprime_nbhd(nbhd_dim: usize) -> Vec<(usize, usize)> {
    (1..=nbhd_dim)
        .flat_map(|i| (1..=nbhd_dim).map(move |j| (i, j)))
        .filter(|&(i, j)| gcd(i, j) == 1)
        .collect()
}
249
/// DP step neighbourhood: every `(dr, dc)` with `1 <= dr, dc <= 7` and `gcd(dr, dc) == 1`,
/// listed in row-major order (35 entries).
///
/// Non-coprime steps are omitted because they are repetitions of a shorter step.
#[rustfmt::skip]
const COPRIME_NBHD_7: [(usize, usize); 35] = [
    // dr = 1
    (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7),
    // dr = 2
    (2, 1), (2, 3), (2, 5), (2, 7),
    // dr = 3
    (3, 1), (3, 2), (3, 4), (3, 5), (3, 7),
    // dr = 4
    (4, 1), (4, 3), (4, 5), (4, 7),
    // dr = 5
    (5, 1), (5, 2), (5, 3), (5, 4), (5, 6), (5, 7),
    // dr = 6
    (6, 1), (6, 5), (6, 7),
    // dr = 7
    (7, 1), (7, 2), (7, 3), (7, 4), (7, 5), (7, 6),
];
263
/// Cost of the DP edge from grid cell `(sr, sc)` to `(tr, tc)`.
///
/// Columns index `q1` and rows index `q2`. The column run `sc..tc` and the row run `sr..tr` are
/// both mapped onto `[0, 1]`, split into `tc - sc` and `tr - sr` uniform pieces respectively.
/// The squared residual `q1 - sqrt(slope) · q2`, constant on each piece, is integrated over the
/// common refinement of the two partitions.
///
/// `slope` is the ratio of the row span to the column span in `argvals` units. The integral is
/// scaled by the column span.
///
/// A step that does not advance in both directions costs `f64::INFINITY`.
#[inline]
fn dp_edge_weight(
    q1: &[f64],
    q2: &[f64],
    argvals: &[f64],
    sc: usize,
    tc: usize,
    sr: usize,
    tr: usize,
) -> f64 {
    let cols = tc - sc;
    let rows = tr - sr;
    if cols == 0 || rows == 0 {
        return f64::INFINITY;
    }

    let span = argvals[tc] - argvals[sc];
    let root_slope = ((argvals[tr] - argvals[sr]) / span).sqrt();
    let (cols_f, rows_f) = (cols as f64, rows as f64);

    let mut acc = 0.0;
    let (mut a, mut b) = (0usize, 0usize);
    while a < cols && b < rows {
        let end_a = (a + 1) as f64 / cols_f;
        let end_b = (b + 1) as f64 / rows_f;
        let start = (a as f64 / cols_f).max(b as f64 / rows_f);
        let overlap = end_a.min(end_b) - start;
        if overlap > 0.0 {
            let resid = q1[sc + a] - root_slope * q2[sr + b];
            acc += resid * resid * overlap;
        }
        // Advance whichever piece ends first (both when they end together).
        let step_a = end_a <= end_b;
        let step_b = end_b <= end_a;
        if step_a {
            a += 1;
        }
        if step_b {
            b += 1;
        }
    }

    acc * span
}
325
/// Roughness penalty for a DP edge: `lambda · (slope - 1)² · Δt`.
///
/// `Δt` is the column span in `argvals` units and `slope` is the row span divided by `Δt`. The
/// penalty is `0.0` unless `lambda > 0.0`; this includes a NaN `lambda`.
#[inline]
fn dp_lambda_penalty(
    argvals: &[f64],
    sc: usize,
    tc: usize,
    sr: usize,
    tr: usize,
    lambda: f64,
) -> f64 {
    let active = lambda > 0.0;
    if !active {
        return 0.0;
    }
    let run = argvals[tc] - argvals[sc];
    let rise = argvals[tr] - argvals[sr];
    let excess = rise / run - 1.0;
    lambda * excess.powi(2) * run
}
344
/// Walks parent links back from the bottom-right cell of an `nrows × ncols` grid.
///
/// Returns the visited `(row, col)` cells in forward order. The walk stops at cell 0 or at a cell
/// with no parent (`u32::MAX`). An unreachable end cell therefore yields a one-element path.
fn dp_traceback(parent: &[u32], nrows: usize, ncols: usize) -> Vec<(usize, usize)> {
    let mut cells = Vec::with_capacity(nrows + ncols);
    let mut node = (nrows - 1) * ncols + (ncols - 1);
    loop {
        cells.push((node / ncols, node % ncols));
        if node == 0 {
            break;
        }
        match parent[node] {
            u32::MAX => break,
            prev => node = prev as usize,
        }
    }
    cells.reverse();
    cells
}
361
362#[inline]
364fn dp_relax_cell<F>(
365 e: &mut [f64],
366 parent: &mut [u32],
367 ncols: usize,
368 tr: usize,
369 tc: usize,
370 edge_cost: &F,
371) where
372 F: Fn(usize, usize, usize, usize) -> f64,
373{
374 let idx = tr * ncols + tc;
375 for &(dr, dc) in &COPRIME_NBHD_7 {
376 if dr > tr || dc > tc {
377 continue;
378 }
379 let sr = tr - dr;
380 let sc = tc - dc;
381 let src_idx = sr * ncols + sc;
382 if e[src_idx] == f64::INFINITY {
383 continue;
384 }
385 let cost = e[src_idx] + edge_cost(sr, sc, tr, tc);
386 if cost < e[idx] {
387 e[idx] = cost;
388 parent[idx] = src_idx as u32;
389 }
390 }
391}
392
393fn dp_grid_solve<F>(nrows: usize, ncols: usize, edge_cost: F) -> Vec<(usize, usize)>
399where
400 F: Fn(usize, usize, usize, usize) -> f64,
401{
402 let mut e = vec![f64::INFINITY; nrows * ncols];
403 let mut parent = vec![u32::MAX; nrows * ncols];
404 e[0] = 0.0;
405
406 for tr in 0..nrows {
407 for tc in 0..ncols {
408 if tr == 0 && tc == 0 {
409 continue;
410 }
411 dp_relax_cell(&mut e, &mut parent, ncols, tr, tc, &edge_cost);
412 }
413 }
414
415 dp_traceback(&parent, nrows, ncols)
416}
417
418fn dp_path_to_gamma(path: &[(usize, usize)], argvals: &[f64]) -> Vec<f64> {
420 let path_tc: Vec<f64> = path.iter().map(|&(_, c)| argvals[c]).collect();
421 let path_tr: Vec<f64> = path.iter().map(|&(r, _)| argvals[r]).collect();
422 let mut gamma: Vec<f64> = argvals
423 .iter()
424 .map(|&t| linear_interp(&path_tc, &path_tr, t))
425 .collect();
426 normalize_warp(&mut gamma, argvals);
427 gamma
428}
429
430fn dp_alignment_core(q1: &[f64], q2: &[f64], argvals: &[f64], lambda: f64) -> Vec<f64> {
436 let m = argvals.len();
437 if m < 2 {
438 return argvals.to_vec();
439 }
440
441 let norm1 = q1.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
442 let norm2 = q2.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
443 let q1n: Vec<f64> = q1.iter().map(|&v| v / norm1).collect();
444 let q2n: Vec<f64> = q2.iter().map(|&v| v / norm2).collect();
445
446 let path = dp_grid_solve(m, m, |sr, sc, tr, tc| {
447 dp_edge_weight(&q1n, &q2n, argvals, sc, tc, sr, tr)
448 + dp_lambda_penalty(argvals, sc, tc, sr, tr, lambda)
449 });
450
451 dp_path_to_gamma(&path, argvals)
452}
453
454pub fn elastic_align_pair(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> AlignmentResult {
470 let m = f1.len();
471
472 let f1_mat = FdMatrix::from_slice(f1, 1, m).unwrap();
474 let f2_mat = FdMatrix::from_slice(f2, 1, m).unwrap();
475
476 let q1_mat = srsf_transform(&f1_mat, argvals);
477 let q2_mat = srsf_transform(&f2_mat, argvals);
478
479 let q1: Vec<f64> = q1_mat.row(0);
480 let q2: Vec<f64> = q2_mat.row(0);
481
482 let gamma = dp_alignment_core(&q1, &q2, argvals, lambda);
484
485 let f_aligned = reparameterize_curve(f2, argvals, &gamma);
487
488 let f_aligned_mat = FdMatrix::from_slice(&f_aligned, 1, m).unwrap();
490 let q_aligned_mat = srsf_transform(&f_aligned_mat, argvals);
491 let q_aligned: Vec<f64> = q_aligned_mat.row(0);
492
493 let weights = simpsons_weights(argvals);
494 let distance = l2_distance(&q1, &q_aligned, &weights);
495
496 AlignmentResult {
497 gamma,
498 f_aligned,
499 distance,
500 }
501}
502
503pub fn elastic_distance(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> f64 {
513 elastic_align_pair(f1, f2, argvals, lambda).distance
514}
515
516pub fn align_to_target(
527 data: &FdMatrix,
528 target: &[f64],
529 argvals: &[f64],
530 lambda: f64,
531) -> AlignmentSetResult {
532 let (n, m) = data.shape();
533
534 let results: Vec<AlignmentResult> = iter_maybe_parallel!(0..n)
535 .map(|i| {
536 let fi = data.row(i);
537 elastic_align_pair(target, &fi, argvals, lambda)
538 })
539 .collect();
540
541 let mut gammas = FdMatrix::zeros(n, m);
542 let mut aligned_data = FdMatrix::zeros(n, m);
543 let mut distances = Vec::with_capacity(n);
544
545 for (i, r) in results.into_iter().enumerate() {
546 for j in 0..m {
547 gammas[(i, j)] = r.gamma[j];
548 aligned_data[(i, j)] = r.f_aligned[j];
549 }
550 distances.push(r.distance);
551 }
552
553 AlignmentSetResult {
554 gammas,
555 aligned_data,
556 distances,
557 }
558}
559
560pub fn elastic_self_distance_matrix(data: &FdMatrix, argvals: &[f64], lambda: f64) -> FdMatrix {
575 let n = data.nrows();
576
577 let upper_vals: Vec<f64> = iter_maybe_parallel!(0..n)
578 .flat_map(|i| {
579 let fi = data.row(i);
580 ((i + 1)..n)
581 .map(|j| {
582 let fj = data.row(j);
583 elastic_distance(&fi, &fj, argvals, lambda)
584 })
585 .collect::<Vec<_>>()
586 })
587 .collect();
588
589 let mut dist = FdMatrix::zeros(n, n);
590 let mut idx = 0;
591 for i in 0..n {
592 for j in (i + 1)..n {
593 let d = upper_vals[idx];
594 dist[(i, j)] = d;
595 dist[(j, i)] = d;
596 idx += 1;
597 }
598 }
599 dist
600}
601
602pub fn elastic_cross_distance_matrix(
613 data1: &FdMatrix,
614 data2: &FdMatrix,
615 argvals: &[f64],
616 lambda: f64,
617) -> FdMatrix {
618 let n1 = data1.nrows();
619 let n2 = data2.nrows();
620
621 let vals: Vec<f64> = iter_maybe_parallel!(0..n1)
622 .flat_map(|i| {
623 let fi = data1.row(i);
624 (0..n2)
625 .map(|j| {
626 let fj = data2.row(j);
627 elastic_distance(&fi, &fj, argvals, lambda)
628 })
629 .collect::<Vec<_>>()
630 })
631 .collect();
632
633 let mut dist = FdMatrix::zeros(n1, n2);
634 for i in 0..n1 {
635 for j in 0..n2 {
636 dist[(i, j)] = vals[i * n2 + j];
637 }
638 }
639 dist
640}
641
/// Amplitude/phase split of the elastic comparison of two curves (`elastic_decomposition`).
#[derive(Debug, Clone)]
pub struct DecompositionResult {
    /// Full pairwise alignment result (warp, aligned curve, distance).
    pub alignment: AlignmentResult,
    /// Amplitude distance; equal to `alignment.distance`.
    pub d_amplitude: f64,
    /// Phase distance of `alignment.gamma`, as computed by `crate::warping::phase_distance`.
    pub d_phase: f64,
}
654
655pub fn elastic_decomposition(
665 f1: &[f64],
666 f2: &[f64],
667 argvals: &[f64],
668 lambda: f64,
669) -> DecompositionResult {
670 let alignment = elastic_align_pair(f1, f2, argvals, lambda);
671 let d_amplitude = alignment.distance;
672 let d_phase = crate::warping::phase_distance(&alignment.gamma, argvals);
673 DecompositionResult {
674 alignment,
675 d_amplitude,
676 d_phase,
677 }
678}
679
680pub fn amplitude_distance(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> f64 {
682 elastic_distance(f1, f2, argvals, lambda)
683}
684
685pub fn phase_distance_pair(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> f64 {
687 let alignment = elastic_align_pair(f1, f2, argvals, lambda);
688 crate::warping::phase_distance(&alignment.gamma, argvals)
689}
690
691pub fn phase_self_distance_matrix(data: &FdMatrix, argvals: &[f64], lambda: f64) -> FdMatrix {
693 let n = data.nrows();
694
695 let upper_vals: Vec<f64> = iter_maybe_parallel!(0..n)
696 .flat_map(|i| {
697 let fi = data.row(i);
698 ((i + 1)..n)
699 .map(|j| {
700 let fj = data.row(j);
701 phase_distance_pair(&fi, &fj, argvals, lambda)
702 })
703 .collect::<Vec<_>>()
704 })
705 .collect();
706
707 let mut dist = FdMatrix::zeros(n, n);
708 let mut idx = 0;
709 for i in 0..n {
710 for j in (i + 1)..n {
711 let d = upper_vals[idx];
712 dist[(i, j)] = d;
713 dist[(j, i)] = d;
714 idx += 1;
715 }
716 }
717 dist
718}
719
720pub fn amplitude_self_distance_matrix(data: &FdMatrix, argvals: &[f64], lambda: f64) -> FdMatrix {
722 elastic_self_distance_matrix(data, argvals, lambda)
723}
724
/// Relative Euclidean change `‖q_new - q_old‖ / ‖q_old‖`, with the denominator floored at
/// `1e-10`.
///
/// Pairs are formed by `zip`, so extra trailing entries of the longer slice are ignored in the
/// numerator.
fn relative_change(q_old: &[f64], q_new: &[f64]) -> f64 {
    let squared = |x: f64| x * x;
    let diff_norm = q_old
        .iter()
        .zip(q_new)
        .map(|(a, b)| squared(a - b))
        .sum::<f64>()
        .sqrt();
    let reference = q_old
        .iter()
        .map(|&v| squared(v))
        .sum::<f64>()
        .sqrt()
        .max(1e-10);
    diff_norm / reference
}
741
742fn srsf_single(f: &[f64], argvals: &[f64]) -> Vec<f64> {
744 let m = f.len();
745 let mat = FdMatrix::from_slice(f, 1, m).unwrap();
746 let q_mat = srsf_transform(&mat, argvals);
747 q_mat.row(0)
748}
749
750fn align_srsf_pair(q1: &[f64], q2: &[f64], argvals: &[f64], lambda: f64) -> (Vec<f64>, Vec<f64>) {
752 let gamma = dp_alignment_core(q1, q2, argvals, lambda);
753
754 let q2_warped = reparameterize_curve(q2, argvals, &gamma);
756
757 let m = gamma.len();
759 let mut gamma_dot = vec![0.0; m];
760 gamma_dot[0] = (gamma[1] - gamma[0]) / (argvals[1] - argvals[0]);
761 for j in 1..(m - 1) {
762 gamma_dot[j] = (gamma[j + 1] - gamma[j - 1]) / (argvals[j + 1] - argvals[j - 1]);
763 }
764 gamma_dot[m - 1] = (gamma[m - 1] - gamma[m - 2]) / (argvals[m - 1] - argvals[m - 2]);
765
766 let q2_aligned: Vec<f64> = q2_warped
768 .iter()
769 .zip(gamma_dot.iter())
770 .map(|(&q, &gd)| q * gd.max(0.0).sqrt())
771 .collect();
772
773 (gamma, q2_aligned)
774}
775
776fn accumulate_alignments(
805 results: &[(Vec<f64>, Vec<f64>)],
806 gammas: &mut FdMatrix,
807 m: usize,
808 n: usize,
809) -> Vec<f64> {
810 let mut mu_q_new = vec![0.0; m];
811 for (i, (gamma, q_aligned)) in results.iter().enumerate() {
812 for j in 0..m {
813 gammas[(i, j)] = gamma[j];
814 mu_q_new[j] += q_aligned[j];
815 }
816 }
817 for j in 0..m {
818 mu_q_new[j] /= n as f64;
819 }
820 mu_q_new
821}
822
823fn apply_stored_warps(data: &FdMatrix, gammas: &FdMatrix, argvals: &[f64]) -> FdMatrix {
825 let (n, m) = data.shape();
826 let mut aligned = FdMatrix::zeros(n, m);
827 for i in 0..n {
828 let fi = data.row(i);
829 let gamma: Vec<f64> = (0..m).map(|j| gammas[(i, j)]).collect();
830 let f_aligned = reparameterize_curve(&fi, argvals, &gamma);
831 for j in 0..m {
832 aligned[(i, j)] = f_aligned[j];
833 }
834 }
835 aligned
836}
837
838fn select_template(srsf_mat: &FdMatrix, data: &FdMatrix, argvals: &[f64]) -> (Vec<f64>, Vec<f64>) {
840 let (n, m) = srsf_mat.shape();
841 let mnq = mean_1d(srsf_mat);
842 let mut min_dist = f64::INFINITY;
843 let mut min_idx = 0;
844 for i in 0..n {
845 let dist_sq: f64 = (0..m).map(|j| (srsf_mat[(i, j)] - mnq[j]).powi(2)).sum();
846 if dist_sq < min_dist {
847 min_dist = dist_sq;
848 min_idx = i;
849 }
850 }
851 let _ = argvals; (srsf_mat.row(min_idx), data.row(min_idx))
853}
854
855fn pre_center_template(
857 data: &FdMatrix,
858 mu_q: &[f64],
859 mu: &[f64],
860 argvals: &[f64],
861 lambda: f64,
862) -> (Vec<f64>, Vec<f64>) {
863 let (n, m) = data.shape();
864 let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
865 .map(|i| {
866 let fi = data.row(i);
867 let qi = srsf_single(&fi, argvals);
868 align_srsf_pair(mu_q, &qi, argvals, lambda)
869 })
870 .collect();
871
872 let mut init_gammas = FdMatrix::zeros(n, m);
873 for (i, (gamma, _)) in align_results.iter().enumerate() {
874 for j in 0..m {
875 init_gammas[(i, j)] = gamma[j];
876 }
877 }
878
879 let gam_inv = sqrt_mean_inverse(&init_gammas, argvals);
880 let mu_new = reparameterize_curve(mu, argvals, &gam_inv);
881 let mu_q_new = srsf_single(&mu_new, argvals);
882 (mu_q_new, mu_new)
883}
884
885fn post_center_results(
887 data: &FdMatrix,
888 mu_q: &[f64],
889 final_gammas: &mut FdMatrix,
890 argvals: &[f64],
891) -> (Vec<f64>, Vec<f64>, FdMatrix) {
892 let (n, m) = data.shape();
893 let gam_inv = sqrt_mean_inverse(final_gammas, argvals);
894 let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
895 let gam_inv_dev = gradient_uniform(&gam_inv, h);
896
897 let mu_q_warped = reparameterize_curve(mu_q, argvals, &gam_inv);
898 let mu_q_centered: Vec<f64> = mu_q_warped
899 .iter()
900 .zip(gam_inv_dev.iter())
901 .map(|(&q, &gd)| q * gd.max(0.0).sqrt())
902 .collect();
903
904 for i in 0..n {
905 let gam_i: Vec<f64> = (0..m).map(|j| final_gammas[(i, j)]).collect();
906 let gam_centered = reparameterize_curve(&gam_i, argvals, &gam_inv);
907 for j in 0..m {
908 final_gammas[(i, j)] = gam_centered[j];
909 }
910 }
911
912 let initial_mean = mean_1d(data);
913 let mu = srsf_inverse(&mu_q_centered, argvals, initial_mean[0]);
914 let final_aligned = apply_stored_warps(data, final_gammas, argvals);
915 (mu, mu_q_centered, final_aligned)
916}
917
918pub fn karcher_mean(
919 data: &FdMatrix,
920 argvals: &[f64],
921 max_iter: usize,
922 tol: f64,
923 lambda: f64,
924) -> KarcherMeanResult {
925 let (n, m) = data.shape();
926
927 let srsf_mat = srsf_transform(data, argvals);
928 let (mut mu_q, mu) = select_template(&srsf_mat, data, argvals);
929 let (mu_q_c, mu_c) = pre_center_template(data, &mu_q, &mu, argvals, lambda);
930 mu_q = mu_q_c;
931 let mut mu = mu_c;
932
933 let mut converged = false;
934 let mut n_iter = 0;
935 let mut final_gammas = FdMatrix::zeros(n, m);
936 let mut prev_rel = 0.0_f64;
937
938 for iter in 0..max_iter {
939 n_iter = iter + 1;
940
941 let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
942 .map(|i| {
943 let fi = data.row(i);
944 let qi = srsf_single(&fi, argvals);
945 align_srsf_pair(&mu_q, &qi, argvals, lambda)
946 })
947 .collect();
948
949 let mu_q_new = accumulate_alignments(&align_results, &mut final_gammas, m, n);
950
951 let rel = relative_change(&mu_q, &mu_q_new);
952 if rel < f64::EPSILON || (iter > 0 && rel - prev_rel <= tol * prev_rel) {
953 converged = true;
954 mu_q = mu_q_new;
955 break;
956 }
957 prev_rel = rel;
958
959 mu_q = mu_q_new;
960 mu = srsf_inverse(&mu_q, argvals, mu[0]);
961 }
962
963 let (mu_final, mu_q_final, final_aligned) =
964 post_center_results(data, &mu_q, &mut final_gammas, argvals);
965
966 KarcherMeanResult {
967 mean: mu_final,
968 mean_srsf: mu_q_final,
969 gammas: final_gammas,
970 aligned_data: final_aligned,
971 n_iter,
972 converged,
973 }
974}
975
/// Transported SRVF (TSRVF) representation of a set of aligned curves.
#[derive(Debug, Clone)]
pub struct TsrvfResult {
    /// One tangent vector per curve (rows), expressed at the unit-normalized mean SRSF.
    pub tangent_vectors: FdMatrix,
    /// Karcher mean curve, copied from the alignment result.
    pub mean: Vec<f64>,
    /// SRSF of the Karcher mean.
    pub mean_srsf: Vec<f64>,
    /// L2 norm of `mean_srsf`, measured on a uniform `[0, 1]` grid with one point per sample.
    pub mean_srsf_norm: f64,
    /// L2 norm of each aligned curve's SRSF on the same `[0, 1]` grid; `tsrvf_inverse` uses it to
    /// undo the unit normalization.
    pub srsf_norms: Vec<f64>,
    /// Warping functions from the alignment, one row per curve.
    pub gammas: FdMatrix,
    /// Whether the underlying Karcher mean iteration converged.
    pub converged: bool,
}
999
1000pub fn tsrvf_transform(
1011 data: &FdMatrix,
1012 argvals: &[f64],
1013 max_iter: usize,
1014 tol: f64,
1015 lambda: f64,
1016) -> TsrvfResult {
1017 let karcher = karcher_mean(data, argvals, max_iter, tol, lambda);
1018 tsrvf_from_alignment(&karcher, argvals)
1019}
1020
1021pub fn tsrvf_from_alignment(karcher: &KarcherMeanResult, argvals: &[f64]) -> TsrvfResult {
1033 let (n, m) = karcher.aligned_data.shape();
1034
1035 let aligned_srsf = srsf_transform(&karcher.aligned_data, argvals);
1037
1038 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
1041 let mean_norm = l2_norm_l2(&karcher.mean_srsf, &time);
1042
1043 let mu_unit: Vec<f64> = if mean_norm > 1e-10 {
1044 karcher.mean_srsf.iter().map(|&q| q / mean_norm).collect()
1045 } else {
1046 vec![0.0; m]
1047 };
1048
1049 let srsf_norms: Vec<f64> = iter_maybe_parallel!(0..n)
1051 .map(|i| {
1052 let qi = aligned_srsf.row(i);
1053 l2_norm_l2(&qi, &time)
1054 })
1055 .collect();
1056
1057 let tangent_data: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
1058 .map(|i| {
1059 let qi = aligned_srsf.row(i);
1060 let qi_norm = srsf_norms[i];
1061
1062 if qi_norm < 1e-10 || mean_norm < 1e-10 {
1063 return vec![0.0; m];
1064 }
1065
1066 let qi_unit: Vec<f64> = qi.iter().map(|&q| q / qi_norm).collect();
1068
1069 inv_exp_map_sphere(&mu_unit, &qi_unit, &time)
1071 })
1072 .collect();
1073
1074 let mut tangent_vectors = FdMatrix::zeros(n, m);
1076 for i in 0..n {
1077 for j in 0..m {
1078 tangent_vectors[(i, j)] = tangent_data[i][j];
1079 }
1080 }
1081
1082 TsrvfResult {
1083 tangent_vectors,
1084 mean: karcher.mean.clone(),
1085 mean_srsf: karcher.mean_srsf.clone(),
1086 mean_srsf_norm: mean_norm,
1087 srsf_norms,
1088 gammas: karcher.gammas.clone(),
1089 converged: karcher.converged,
1090 }
1091}
1092
1093pub fn tsrvf_inverse(tsrvf: &TsrvfResult, argvals: &[f64]) -> FdMatrix {
1105 let (n, m) = tsrvf.tangent_vectors.shape();
1106
1107 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
1108
1109 let mu_unit: Vec<f64> = if tsrvf.mean_srsf_norm > 1e-10 {
1111 tsrvf
1112 .mean_srsf
1113 .iter()
1114 .map(|&q| q / tsrvf.mean_srsf_norm)
1115 .collect()
1116 } else {
1117 vec![0.0; m]
1118 };
1119
1120 let curves: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
1121 .map(|i| {
1122 let vi = tsrvf.tangent_vectors.row(i);
1123
1124 let qi_unit = exp_map_sphere(&mu_unit, &vi, &time);
1126
1127 let qi: Vec<f64> = qi_unit.iter().map(|&q| q * tsrvf.srsf_norms[i]).collect();
1129
1130 srsf_inverse(&qi, argvals, tsrvf.mean[0])
1132 })
1133 .collect();
1134
1135 let mut result = FdMatrix::zeros(n, m);
1136 for i in 0..n {
1137 for j in 0..m {
1138 result[(i, j)] = curves[i][j];
1139 }
1140 }
1141 result
1142}
1143
/// How per-curve SRSFs are brought into the tangent space at the Karcher mean when computing
/// the TSRVF.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TransportMethod {
    /// Inverse exponential (log) map taken directly at the mean; the default.
    #[default]
    LogMap,
    /// Single-rung Schild's ladder approximation of parallel transport.
    SchildsLadder,
    /// Single-rung pole ladder approximation of parallel transport.
    PoleLadder,
}
1157
1158fn parallel_transport_schilds(v: &[f64], from: &[f64], to: &[f64], time: &[f64]) -> Vec<f64> {
1160 use crate::warping::{exp_map_sphere, inv_exp_map_sphere};
1161
1162 let v_norm = crate::warping::l2_norm_l2(v, time);
1163 if v_norm < 1e-10 {
1164 return vec![0.0; v.len()];
1165 }
1166
1167 let endpoint = exp_map_sphere(from, v, time);
1169
1170 let log_to_ep = inv_exp_map_sphere(to, &endpoint, time);
1172
1173 let half_log: Vec<f64> = log_to_ep.iter().map(|&x| 0.5 * x).collect();
1175 let midpoint = exp_map_sphere(to, &half_log, time);
1176
1177 let log_to_mid = inv_exp_map_sphere(to, &midpoint, time);
1179 log_to_mid.iter().map(|&x| 2.0 * x).collect()
1180}
1181
1182fn parallel_transport_pole(v: &[f64], from: &[f64], to: &[f64], time: &[f64]) -> Vec<f64> {
1184 use crate::warping::{exp_map_sphere, inv_exp_map_sphere};
1185
1186 let v_norm = crate::warping::l2_norm_l2(v, time);
1187 if v_norm < 1e-10 {
1188 return vec![0.0; v.len()];
1189 }
1190
1191 let neg_v: Vec<f64> = v.iter().map(|&x| -x).collect();
1193 let pole = exp_map_sphere(from, &neg_v, time);
1194
1195 let log_to_pole = inv_exp_map_sphere(to, &pole, time);
1197
1198 let half_log: Vec<f64> = log_to_pole.iter().map(|&x| 0.5 * x).collect();
1200 let midpoint = exp_map_sphere(to, &half_log, time);
1201
1202 let log_to_mid = inv_exp_map_sphere(to, &midpoint, time);
1204 log_to_mid.iter().map(|&x| -2.0 * x).collect()
1205}
1206
1207pub fn tsrvf_transform_with_method(
1211 data: &FdMatrix,
1212 argvals: &[f64],
1213 max_iter: usize,
1214 tol: f64,
1215 lambda: f64,
1216 method: TransportMethod,
1217) -> TsrvfResult {
1218 let karcher = karcher_mean(data, argvals, max_iter, tol, lambda);
1219 tsrvf_from_alignment_with_method(&karcher, argvals, method)
1220}
1221
1222pub fn tsrvf_from_alignment_with_method(
1229 karcher: &KarcherMeanResult,
1230 argvals: &[f64],
1231 method: TransportMethod,
1232) -> TsrvfResult {
1233 if method == TransportMethod::LogMap {
1234 return tsrvf_from_alignment(karcher, argvals);
1235 }
1236
1237 let (n, m) = karcher.aligned_data.shape();
1238 let aligned_srsf = srsf_transform(&karcher.aligned_data, argvals);
1239 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
1240 let mean_norm = crate::warping::l2_norm_l2(&karcher.mean_srsf, &time);
1241
1242 let mu_unit: Vec<f64> = if mean_norm > 1e-10 {
1243 karcher.mean_srsf.iter().map(|&q| q / mean_norm).collect()
1244 } else {
1245 vec![0.0; m]
1246 };
1247
1248 let srsf_norms: Vec<f64> = iter_maybe_parallel!(0..n)
1249 .map(|i| {
1250 let qi = aligned_srsf.row(i);
1251 crate::warping::l2_norm_l2(&qi, &time)
1252 })
1253 .collect();
1254
1255 let tangent_data: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
1256 .map(|i| {
1257 let qi = aligned_srsf.row(i);
1258 let qi_norm = srsf_norms[i];
1259
1260 if qi_norm < 1e-10 || mean_norm < 1e-10 {
1261 return vec![0.0; m];
1262 }
1263
1264 let qi_unit: Vec<f64> = qi.iter().map(|&q| q / qi_norm).collect();
1265
1266 let v_at_qi = inv_exp_map_sphere(&qi_unit, &mu_unit, &time);
1268 let neg_v: Vec<f64> = v_at_qi.iter().map(|&x| -x).collect();
1269
1270 match method {
1272 TransportMethod::SchildsLadder => {
1273 parallel_transport_schilds(&neg_v, &qi_unit, &mu_unit, &time)
1274 }
1275 TransportMethod::PoleLadder => {
1276 parallel_transport_pole(&neg_v, &qi_unit, &mu_unit, &time)
1277 }
1278 TransportMethod::LogMap => unreachable!(),
1279 }
1280 })
1281 .collect();
1282
1283 let mut tangent_vectors = FdMatrix::zeros(n, m);
1284 for i in 0..n {
1285 for j in 0..m {
1286 tangent_vectors[(i, j)] = tangent_data[i][j];
1287 }
1288 }
1289
1290 TsrvfResult {
1291 tangent_vectors,
1292 mean: karcher.mean.clone(),
1293 mean_srsf: karcher.mean_srsf.clone(),
1294 mean_srsf_norm: mean_norm,
1295 srsf_norms,
1296 gammas: karcher.gammas.clone(),
1297 converged: karcher.converged,
1298 }
1299}
1300
/// Diagnostics describing how much an alignment changed the data (`alignment_quality`).
#[derive(Debug, Clone)]
pub struct AlignmentQuality {
    /// Per-curve warp complexity (`warp_complexity`, i.e. phase distance of each warp).
    pub warp_complexity: Vec<f64>,
    /// Mean of `warp_complexity`.
    pub mean_warp_complexity: f64,
    /// Per-curve bending energy `∫ (γ'')² dt` of each warp (`warp_smoothness`).
    pub warp_smoothness: Vec<f64>,
    /// Mean of `warp_smoothness`.
    pub mean_warp_smoothness: f64,
    /// Mean squared Simpson-weighted L2 distance of the original curves to their pointwise mean.
    pub total_variance: f64,
    /// Same quantity computed for the aligned curves.
    pub amplitude_variance: f64,
    /// `max(total_variance - amplitude_variance, 0)`.
    pub phase_variance: f64,
    /// `phase_variance / total_variance`, or `0.0` when `total_variance <= 1e-10`.
    pub phase_amplitude_ratio: f64,
    /// Per grid point: aligned cross-sectional variance divided by original variance (`1.0` where
    /// the original variance is below `1e-15`).
    pub pointwise_variance_ratio: Vec<f64>,
    /// Mean of `pointwise_variance_ratio`.
    pub mean_variance_reduction: f64,
}
1327
1328pub fn warp_complexity(gamma: &[f64], argvals: &[f64]) -> f64 {
1332 crate::warping::phase_distance(gamma, argvals)
1333}
1334
1335pub fn warp_smoothness(gamma: &[f64], argvals: &[f64]) -> f64 {
1337 let m = gamma.len();
1338 if m < 3 {
1339 return 0.0;
1340 }
1341
1342 let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
1343 let gam_prime = gradient_uniform(gamma, h);
1344 let gam_pprime = gradient_uniform(&gam_prime, h);
1345
1346 let integrand: Vec<f64> = gam_pprime.iter().map(|&g| g * g).collect();
1347 crate::helpers::trapz(&integrand, argvals)
1348}
1349
1350pub fn alignment_quality(
1357 data: &FdMatrix,
1358 karcher: &KarcherMeanResult,
1359 argvals: &[f64],
1360) -> AlignmentQuality {
1361 let (n, m) = data.shape();
1362 let weights = simpsons_weights(argvals);
1363
1364 let wc: Vec<f64> = (0..n)
1366 .map(|i| {
1367 let gamma: Vec<f64> = (0..m).map(|j| karcher.gammas[(i, j)]).collect();
1368 warp_complexity(&gamma, argvals)
1369 })
1370 .collect();
1371 let ws: Vec<f64> = (0..n)
1372 .map(|i| {
1373 let gamma: Vec<f64> = (0..m).map(|j| karcher.gammas[(i, j)]).collect();
1374 warp_smoothness(&gamma, argvals)
1375 })
1376 .collect();
1377
1378 let mean_wc = wc.iter().sum::<f64>() / n as f64;
1379 let mean_ws = ws.iter().sum::<f64>() / n as f64;
1380
1381 let orig_mean = crate::fdata::mean_1d(data);
1383
1384 let total_var: f64 = (0..n)
1386 .map(|i| {
1387 let fi = data.row(i);
1388 let d = l2_distance(&fi, &orig_mean, &weights);
1389 d * d
1390 })
1391 .sum::<f64>()
1392 / n as f64;
1393
1394 let aligned_mean = crate::fdata::mean_1d(&karcher.aligned_data);
1396
1397 let amp_var: f64 = (0..n)
1399 .map(|i| {
1400 let fi = karcher.aligned_data.row(i);
1401 let d = l2_distance(&fi, &aligned_mean, &weights);
1402 d * d
1403 })
1404 .sum::<f64>()
1405 / n as f64;
1406
1407 let phase_var = (total_var - amp_var).max(0.0);
1408 let ratio = if total_var > 1e-10 {
1409 phase_var / total_var
1410 } else {
1411 0.0
1412 };
1413
1414 let mut pw_ratio = vec![0.0; m];
1416 for j in 0..m {
1417 let col_orig = data.column(j);
1418 let mean_orig = col_orig.iter().sum::<f64>() / n as f64;
1419 let var_orig: f64 = col_orig
1420 .iter()
1421 .map(|&v| (v - mean_orig).powi(2))
1422 .sum::<f64>()
1423 / n as f64;
1424
1425 let col_aligned = karcher.aligned_data.column(j);
1426 let mean_aligned = col_aligned.iter().sum::<f64>() / n as f64;
1427 let var_aligned: f64 = col_aligned
1428 .iter()
1429 .map(|&v| (v - mean_aligned).powi(2))
1430 .sum::<f64>()
1431 / n as f64;
1432
1433 pw_ratio[j] = if var_orig > 1e-15 {
1434 var_aligned / var_orig
1435 } else {
1436 1.0
1437 };
1438 }
1439
1440 let mean_vr = pw_ratio.iter().sum::<f64>() / m as f64;
1441
1442 AlignmentQuality {
1443 warp_complexity: wc,
1444 mean_warp_complexity: mean_wc,
1445 warp_smoothness: ws,
1446 mean_warp_smoothness: mean_ws,
1447 total_variance: total_var,
1448 amplitude_variance: amp_var,
1449 phase_variance: phase_var,
1450 phase_amplitude_ratio: ratio,
1451 pointwise_variance_ratio: pw_ratio,
1452 mean_variance_reduction: mean_vr,
1453 }
1454}
1455
/// Enumerates index triplets `(i, j, k)` with `i < j < k < n` in lexicographic order.
///
/// At most `max_triplets` are returned; `0` means no limit.
///
/// `take` already stops once the enumeration is exhausted, so the number of triplets `C(n, 3)`
/// never needs to be computed. Computing it underflowed for `n < 2` and could overflow for very
/// large `n`.
fn triplet_indices(n: usize, max_triplets: usize) -> Vec<(usize, usize, usize)> {
    let cap = if max_triplets > 0 {
        max_triplets
    } else {
        usize::MAX
    };
    (0..n)
        .flat_map(|i| ((i + 1)..n).flat_map(move |j| ((j + 1)..n).map(move |k| (i, j, k))))
        .take(cap)
        .collect()
}
1469
1470fn triplet_warp_deviation(
1472 data: &FdMatrix,
1473 argvals: &[f64],
1474 weights: &[f64],
1475 i: usize,
1476 j: usize,
1477 k: usize,
1478 lambda: f64,
1479) -> f64 {
1480 let fi = data.row(i);
1481 let fj = data.row(j);
1482 let fk = data.row(k);
1483 let rij = elastic_align_pair(&fi, &fj, argvals, lambda);
1484 let rjk = elastic_align_pair(&fj, &fk, argvals, lambda);
1485 let rik = elastic_align_pair(&fi, &fk, argvals, lambda);
1486 let composed = compose_warps(&rij.gamma, &rjk.gamma, argvals);
1487 l2_distance(&composed, &rik.gamma, weights)
1488}
1489
1490pub fn pairwise_consistency(
1501 data: &FdMatrix,
1502 argvals: &[f64],
1503 lambda: f64,
1504 max_triplets: usize,
1505) -> f64 {
1506 let n = data.nrows();
1507 if n < 3 {
1508 return 0.0;
1509 }
1510
1511 let weights = simpsons_weights(argvals);
1512 let triplets = triplet_indices(n, max_triplets);
1513 if triplets.is_empty() {
1514 return 0.0;
1515 }
1516
1517 let total_dev: f64 = triplets
1518 .iter()
1519 .map(|&(i, j, k)| triplet_warp_deviation(data, argvals, &weights, i, j, k, lambda))
1520 .sum();
1521 total_dev / triplets.len() as f64
1522}
1523
/// Result of a landmark-constrained elastic alignment.
#[derive(Debug, Clone)]
pub struct ConstrainedAlignmentResult {
    /// Warping function sampled on `argvals`.
    pub gamma: Vec<f64>,
    /// Second curve reparameterized by `gamma`.
    pub f_aligned: Vec<f64>,
    /// Elastic distance after alignment.
    pub distance: f64,
    /// Landmark pairs that were actually enforced. NOTE(review): presumably the pairs that
    /// survive grid snapping and monotonicity filtering in `build_constrained_waypoints`;
    /// confirm in `elastic_align_pair_constrained`.
    pub enforced_landmarks: Vec<(f64, f64)>,
}
1538
/// Index of the grid point in `argvals` nearest to `t_val`; ties resolve to the lowest index.
///
/// # Panics
/// Panics if `argvals` is empty.
fn snap_to_grid(t_val: f64, argvals: &[f64]) -> usize {
    let dist = |a: f64| (t_val - a).abs();
    argvals
        .iter()
        .enumerate()
        .skip(1)
        .fold((0usize, dist(argvals[0])), |(best_idx, best_d), (idx, &a)| {
            let d = dist(a);
            if d < best_d {
                (idx, d)
            } else {
                (best_idx, best_d)
            }
        })
        .0
}
1552
1553fn dp_segment(
1558 q1: &[f64],
1559 q2: &[f64],
1560 argvals: &[f64],
1561 sc: usize,
1562 ec: usize,
1563 sr: usize,
1564 er: usize,
1565 lambda: f64,
1566) -> Vec<(usize, usize)> {
1567 let nc = ec - sc + 1;
1568 let nr = er - sr + 1;
1569
1570 if nc <= 1 || nr <= 1 {
1571 return vec![(sc, sr), (ec, er)];
1572 }
1573
1574 let path = dp_grid_solve(nr, nc, |local_sr, local_sc, local_tr, local_tc| {
1575 let gsr = sr + local_sr;
1576 let gsc = sc + local_sc;
1577 let gtr = sr + local_tr;
1578 let gtc = sc + local_tc;
1579 dp_edge_weight(q1, q2, argvals, gsc, gtc, gsr, gtr)
1580 + dp_lambda_penalty(argvals, gsc, gtc, gsr, gtr, lambda)
1581 });
1582
1583 path.iter().map(|&(lr, lc)| (sc + lc, sr + lr)).collect()
1585}
1586
1587fn build_constrained_waypoints(
1603 landmark_pairs: &[(f64, f64)],
1604 argvals: &[f64],
1605 m: usize,
1606) -> Vec<(usize, usize)> {
1607 let mut waypoints: Vec<(usize, usize)> = Vec::with_capacity(landmark_pairs.len() + 2);
1608 waypoints.push((0, 0));
1609 for &(tt, st) in landmark_pairs {
1610 let tc = snap_to_grid(tt, argvals);
1611 let tr = snap_to_grid(st, argvals);
1612 if let Some(&(prev_c, prev_r)) = waypoints.last() {
1613 if tc > prev_c && tr > prev_r {
1614 waypoints.push((tc, tr));
1615 }
1616 }
1617 }
1618 let last = m - 1;
1619 if let Some(&(prev_c, prev_r)) = waypoints.last() {
1620 if prev_c != last || prev_r != last {
1621 waypoints.push((last, last));
1622 }
1623 }
1624 waypoints
1625}
1626
1627fn segmented_dp_gamma(
1629 q1n: &[f64],
1630 q2n: &[f64],
1631 argvals: &[f64],
1632 waypoints: &[(usize, usize)],
1633 lambda: f64,
1634) -> Vec<f64> {
1635 let mut full_path_tc: Vec<f64> = Vec::new();
1636 let mut full_path_tr: Vec<f64> = Vec::new();
1637
1638 for seg in 0..(waypoints.len() - 1) {
1639 let (sc, sr) = waypoints[seg];
1640 let (ec, er) = waypoints[seg + 1];
1641 let segment_path = dp_segment(q1n, q2n, argvals, sc, ec, sr, er, lambda);
1642 let start = if seg > 0 { 1 } else { 0 };
1643 for &(tc, tr) in &segment_path[start..] {
1644 full_path_tc.push(argvals[tc]);
1645 full_path_tr.push(argvals[tr]);
1646 }
1647 }
1648
1649 let mut gamma: Vec<f64> = argvals
1650 .iter()
1651 .map(|&t| linear_interp(&full_path_tc, &full_path_tr, t))
1652 .collect();
1653 normalize_warp(&mut gamma, argvals);
1654 gamma
1655}
1656
1657pub fn elastic_align_pair_constrained(
1658 f1: &[f64],
1659 f2: &[f64],
1660 argvals: &[f64],
1661 landmark_pairs: &[(f64, f64)],
1662 lambda: f64,
1663) -> ConstrainedAlignmentResult {
1664 let m = f1.len();
1665
1666 if landmark_pairs.is_empty() {
1667 let r = elastic_align_pair(f1, f2, argvals, lambda);
1668 return ConstrainedAlignmentResult {
1669 gamma: r.gamma,
1670 f_aligned: r.f_aligned,
1671 distance: r.distance,
1672 enforced_landmarks: Vec::new(),
1673 };
1674 }
1675
1676 let f1_mat = FdMatrix::from_slice(f1, 1, m).unwrap();
1678 let f2_mat = FdMatrix::from_slice(f2, 1, m).unwrap();
1679 let q1_mat = srsf_transform(&f1_mat, argvals);
1680 let q2_mat = srsf_transform(&f2_mat, argvals);
1681 let q1: Vec<f64> = q1_mat.row(0);
1682 let q2: Vec<f64> = q2_mat.row(0);
1683 let norm1 = q1.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1684 let norm2 = q2.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1685 let q1n: Vec<f64> = q1.iter().map(|&v| v / norm1).collect();
1686 let q2n: Vec<f64> = q2.iter().map(|&v| v / norm2).collect();
1687
1688 let waypoints = build_constrained_waypoints(landmark_pairs, argvals, m);
1689 let gamma = segmented_dp_gamma(&q1n, &q2n, argvals, &waypoints, lambda);
1690
1691 let f_aligned = reparameterize_curve(f2, argvals, &gamma);
1692 let f_aligned_mat = FdMatrix::from_slice(&f_aligned, 1, m).unwrap();
1693 let q_aligned_mat = srsf_transform(&f_aligned_mat, argvals);
1694 let q_aligned: Vec<f64> = q_aligned_mat.row(0);
1695 let weights = simpsons_weights(argvals);
1696 let distance = l2_distance(&q1, &q_aligned, &weights);
1697
1698 let enforced: Vec<(f64, f64)> = waypoints[1..waypoints.len() - 1]
1699 .iter()
1700 .map(|&(tc, tr)| (argvals[tc], argvals[tr]))
1701 .collect();
1702
1703 ConstrainedAlignmentResult {
1704 gamma,
1705 f_aligned,
1706 distance,
1707 enforced_landmarks: enforced,
1708 }
1709}
1710
1711pub fn elastic_align_pair_with_landmarks(
1725 f1: &[f64],
1726 f2: &[f64],
1727 argvals: &[f64],
1728 kind: crate::landmark::LandmarkKind,
1729 min_prominence: f64,
1730 expected_count: usize,
1731 lambda: f64,
1732) -> ConstrainedAlignmentResult {
1733 let lm1 = crate::landmark::detect_landmarks(f1, argvals, kind, min_prominence);
1734 let lm2 = crate::landmark::detect_landmarks(f2, argvals, kind, min_prominence);
1735
1736 let n_match = if expected_count > 0 {
1738 expected_count.min(lm1.len()).min(lm2.len())
1739 } else {
1740 lm1.len().min(lm2.len())
1741 };
1742
1743 let pairs: Vec<(f64, f64)> = (0..n_match)
1744 .map(|i| (lm1[i].position, lm2[i].position))
1745 .collect();
1746
1747 elastic_align_pair_constrained(f1, f2, argvals, &pairs, lambda)
1748}
1749
1750use crate::matrix::FdCurveSet;
1753
/// Result of aligning one multidimensional curve to another.
#[derive(Debug, Clone)]
pub struct AlignmentResultNd {
    /// Warping function on `argvals`, shared by all dimensions.
    pub gamma: Vec<f64>,
    /// The second curve after warping, one `Vec` per dimension.
    pub f_aligned: Vec<Vec<f64>>,
    /// Elastic distance: square root of the summed squared per-dimension SRSF
    /// L2 distances.
    pub distance: f64,
}
1764
1765#[inline]
1778fn srsf_scale_point(derivs: &[FdMatrix], result_dims: &mut [FdMatrix], i: usize, j: usize) {
1779 let d = derivs.len();
1780 let norm_sq: f64 = derivs.iter().map(|dd| dd[(i, j)].powi(2)).sum();
1781 let norm = norm_sq.sqrt();
1782 if norm < 1e-15 {
1783 for k in 0..d {
1784 result_dims[k][(i, j)] = 0.0;
1785 }
1786 } else {
1787 let scale = 1.0 / norm.sqrt();
1788 for k in 0..d {
1789 result_dims[k][(i, j)] = derivs[k][(i, j)] * scale;
1790 }
1791 }
1792}
1793
1794pub fn srsf_transform_nd(data: &FdCurveSet, argvals: &[f64]) -> FdCurveSet {
1795 let d = data.ndim();
1796 let n = data.ncurves();
1797 let m = data.npoints();
1798
1799 if d == 0 || n == 0 || m == 0 || argvals.len() != m {
1800 return FdCurveSet {
1801 dims: (0..d).map(|_| FdMatrix::zeros(n, m)).collect(),
1802 };
1803 }
1804
1805 let derivs: Vec<FdMatrix> = data
1806 .dims
1807 .iter()
1808 .map(|dim_mat| crate::fdata::deriv_1d(dim_mat, argvals, 1))
1809 .collect();
1810
1811 let mut result_dims: Vec<FdMatrix> = (0..d).map(|_| FdMatrix::zeros(n, m)).collect();
1812 for i in 0..n {
1813 for j in 0..m {
1814 srsf_scale_point(&derivs, &mut result_dims, i, j);
1815 }
1816 }
1817
1818 FdCurveSet { dims: result_dims }
1819}
1820
1821pub fn srsf_inverse_nd(q: &[Vec<f64>], argvals: &[f64], f0: &[f64]) -> Vec<Vec<f64>> {
1834 let d = q.len();
1835 if d == 0 {
1836 return Vec::new();
1837 }
1838 let m = q[0].len();
1839 if m == 0 {
1840 return vec![Vec::new(); d];
1841 }
1842
1843 let norms: Vec<f64> = (0..m)
1845 .map(|j| {
1846 let norm_sq: f64 = q.iter().map(|qk| qk[j].powi(2)).sum();
1847 norm_sq.sqrt()
1848 })
1849 .collect();
1850
1851 let mut result = Vec::with_capacity(d);
1853 for k in 0..d {
1854 let integrand: Vec<f64> = (0..m).map(|j| q[k][j] * norms[j]).collect();
1855 let integral = cumulative_trapz(&integrand, argvals);
1856 let curve: Vec<f64> = integral.iter().map(|&v| f0[k] + v).collect();
1857 result.push(curve);
1858 }
1859
1860 result
1861}
1862
1863fn dp_alignment_core_nd(
1868 q1: &[Vec<f64>],
1869 q2: &[Vec<f64>],
1870 argvals: &[f64],
1871 lambda: f64,
1872) -> Vec<f64> {
1873 let d = q1.len();
1874 let m = argvals.len();
1875 if m < 2 || d == 0 {
1876 return argvals.to_vec();
1877 }
1878
1879 if d == 1 {
1881 return dp_alignment_core(&q1[0], &q2[0], argvals, lambda);
1882 }
1883
1884 let q1n: Vec<Vec<f64>> = q1
1886 .iter()
1887 .map(|qk| {
1888 let norm = qk.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1889 qk.iter().map(|&v| v / norm).collect()
1890 })
1891 .collect();
1892 let q2n: Vec<Vec<f64>> = q2
1893 .iter()
1894 .map(|qk| {
1895 let norm = qk.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1896 qk.iter().map(|&v| v / norm).collect()
1897 })
1898 .collect();
1899
1900 let path = dp_grid_solve(m, m, |sr, sc, tr, tc| {
1901 let w: f64 = (0..d)
1902 .map(|k| dp_edge_weight(&q1n[k], &q2n[k], argvals, sc, tc, sr, tr))
1903 .sum();
1904 w + dp_lambda_penalty(argvals, sc, tc, sr, tr, lambda)
1905 });
1906
1907 dp_path_to_gamma(&path, argvals)
1908}
1909
1910pub fn elastic_align_pair_nd(
1921 f1: &FdCurveSet,
1922 f2: &FdCurveSet,
1923 argvals: &[f64],
1924 lambda: f64,
1925) -> AlignmentResultNd {
1926 let d = f1.ndim();
1927 let m = f1.npoints();
1928
1929 let q1_set = srsf_transform_nd(f1, argvals);
1931 let q2_set = srsf_transform_nd(f2, argvals);
1932
1933 let q1: Vec<Vec<f64>> = q1_set.dims.iter().map(|dm| dm.row(0)).collect();
1935 let q2: Vec<Vec<f64>> = q2_set.dims.iter().map(|dm| dm.row(0)).collect();
1936
1937 let gamma = dp_alignment_core_nd(&q1, &q2, argvals, lambda);
1939
1940 let f_aligned: Vec<Vec<f64>> = f2
1942 .dims
1943 .iter()
1944 .map(|dm| {
1945 let row = dm.row(0);
1946 reparameterize_curve(&row, argvals, &gamma)
1947 })
1948 .collect();
1949
1950 let f_aligned_set = {
1952 let dims: Vec<FdMatrix> = f_aligned
1953 .iter()
1954 .map(|fa| FdMatrix::from_slice(fa, 1, m).unwrap())
1955 .collect();
1956 FdCurveSet { dims }
1957 };
1958 let q_aligned = srsf_transform_nd(&f_aligned_set, argvals);
1959 let weights = simpsons_weights(argvals);
1960
1961 let mut dist_sq = 0.0;
1962 for k in 0..d {
1963 let q1k = q1_set.dims[k].row(0);
1964 let qak = q_aligned.dims[k].row(0);
1965 let d_k = l2_distance(&q1k, &qak, &weights);
1966 dist_sq += d_k * d_k;
1967 }
1968
1969 AlignmentResultNd {
1970 gamma,
1971 f_aligned,
1972 distance: dist_sq.sqrt(),
1973 }
1974}
1975
1976pub fn elastic_distance_nd(f1: &FdCurveSet, f2: &FdCurveSet, argvals: &[f64], lambda: f64) -> f64 {
1980 elastic_align_pair_nd(f1, f2, argvals, lambda).distance
1981}
1982
1983#[cfg(test)]
1986mod tests {
1987 use super::*;
1988 use crate::helpers::trapz;
1989 use crate::simulation::{sim_fundata, EFunType, EValType};
1990 use crate::warping::inner_product_l2;
1991
1992 fn uniform_grid(m: usize) -> Vec<f64> {
1993 (0..m).map(|i| i as f64 / (m - 1) as f64).collect()
1994 }
1995
1996 fn make_test_data(n: usize, m: usize, seed: u64) -> FdMatrix {
1997 let t = uniform_grid(m);
1998 sim_fundata(
1999 n,
2000 &t,
2001 3,
2002 EFunType::Fourier,
2003 EValType::Exponential,
2004 Some(seed),
2005 )
2006 }
2007
2008 #[test]
2011 fn test_cumulative_trapz_constant() {
2012 let x = uniform_grid(50);
2014 let y = vec![1.0; 50];
2015 let result = cumulative_trapz(&y, &x);
2016 assert!((result[0]).abs() < 1e-15, "cumulative_trapz(0) should be 0");
2017 for j in 1..50 {
2018 assert!(
2019 (result[j] - x[j]).abs() < 1e-12,
2020 "∫₀^{:.3} 1 dt should be {:.3}, got {:.3}",
2021 x[j],
2022 x[j],
2023 result[j]
2024 );
2025 }
2026 }
2027
2028 #[test]
2029 fn test_cumulative_trapz_linear() {
2030 let m = 100;
2032 let x = uniform_grid(m);
2033 let y: Vec<f64> = x.clone();
2034 let result = cumulative_trapz(&y, &x);
2035 for j in 1..m {
2036 let expected = x[j] * x[j] / 2.0;
2037 assert!(
2038 (result[j] - expected).abs() < 1e-4,
2039 "∫₀^{:.3} s ds: expected {expected:.6}, got {:.6}",
2040 x[j],
2041 result[j]
2042 );
2043 }
2044 }
2045
2046 #[test]
2049 fn test_normalize_warp_fixes_boundaries() {
2050 let t = uniform_grid(10);
2051 let mut gamma = vec![0.1; 10]; normalize_warp(&mut gamma, &t);
2053 assert_eq!(gamma[0], t[0]);
2054 assert_eq!(gamma[9], t[9]);
2055 }
2056
2057 #[test]
2058 fn test_normalize_warp_enforces_monotonicity() {
2059 let t = uniform_grid(5);
2060 let mut gamma = vec![0.0, 0.5, 0.3, 0.8, 1.0]; normalize_warp(&mut gamma, &t);
2062 for j in 1..5 {
2063 assert!(
2064 gamma[j] >= gamma[j - 1],
2065 "gamma should be monotone after normalization at j={j}"
2066 );
2067 }
2068 }
2069
2070 #[test]
2071 fn test_normalize_warp_identity_unchanged() {
2072 let t = uniform_grid(20);
2073 let mut gamma = t.clone();
2074 normalize_warp(&mut gamma, &t);
2075 for j in 0..20 {
2076 assert!(
2077 (gamma[j] - t[j]).abs() < 1e-15,
2078 "Identity warp should be unchanged"
2079 );
2080 }
2081 }
2082
2083 #[test]
2086 fn test_linear_interp_at_nodes() {
2087 let x = vec![0.0, 1.0, 2.0, 3.0];
2088 let y = vec![0.0, 2.0, 4.0, 6.0];
2089 for i in 0..x.len() {
2090 assert!((linear_interp(&x, &y, x[i]) - y[i]).abs() < 1e-12);
2091 }
2092 }
2093
2094 #[test]
2095 fn test_linear_interp_midpoints() {
2096 let x = vec![0.0, 1.0, 2.0];
2097 let y = vec![0.0, 2.0, 4.0];
2098 assert!((linear_interp(&x, &y, 0.5) - 1.0).abs() < 1e-12);
2099 assert!((linear_interp(&x, &y, 1.5) - 3.0).abs() < 1e-12);
2100 }
2101
2102 #[test]
2103 fn test_linear_interp_clamp() {
2104 let x = vec![0.0, 1.0, 2.0];
2105 let y = vec![1.0, 3.0, 5.0];
2106 assert!((linear_interp(&x, &y, -1.0) - 1.0).abs() < 1e-12);
2107 assert!((linear_interp(&x, &y, 3.0) - 5.0).abs() < 1e-12);
2108 }
2109
2110 #[test]
2111 fn test_linear_interp_nonuniform_grid() {
2112 let x = vec![0.0, 0.1, 0.5, 1.0];
2113 let y = vec![0.0, 1.0, 5.0, 10.0];
2114 let val = linear_interp(&x, &y, 0.3);
2116 let expected = 1.0 + 10.0 * (0.3 - 0.1);
2117 assert!(
2118 (val - expected).abs() < 1e-12,
2119 "Non-uniform interp: expected {expected}, got {val}"
2120 );
2121 }
2122
2123 #[test]
2124 fn test_linear_interp_two_points() {
2125 let x = vec![0.0, 1.0];
2126 let y = vec![3.0, 7.0];
2127 assert!((linear_interp(&x, &y, 0.25) - 4.0).abs() < 1e-12);
2128 assert!((linear_interp(&x, &y, 0.75) - 6.0).abs() < 1e-12);
2129 }
2130
2131 #[test]
2134 fn test_srsf_transform_linear() {
2135 let m = 50;
2137 let t = uniform_grid(m);
2138 let f: Vec<f64> = t.iter().map(|&ti| 2.0 * ti).collect();
2139 let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
2140
2141 let q_mat = srsf_transform(&mat, &t);
2142 let q: Vec<f64> = q_mat.row(0);
2143
2144 let expected = 2.0_f64.sqrt();
2145 for j in 2..(m - 2) {
2147 assert!(
2148 (q[j] - expected).abs() < 0.1,
2149 "q[{j}] = {}, expected ~{expected}",
2150 q[j]
2151 );
2152 }
2153 }
2154
2155 #[test]
2156 fn test_srsf_transform_preserves_shape() {
2157 let data = make_test_data(10, 50, 42);
2158 let t = uniform_grid(50);
2159 let q = srsf_transform(&data, &t);
2160 assert_eq!(q.shape(), data.shape());
2161 }
2162
2163 #[test]
2164 fn test_srsf_transform_constant_is_zero() {
2165 let m = 30;
2167 let t = uniform_grid(m);
2168 let f = vec![5.0; m];
2169 let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
2170 let q_mat = srsf_transform(&mat, &t);
2171 let q: Vec<f64> = q_mat.row(0);
2172
2173 for j in 0..m {
2174 assert!(
2175 q[j].abs() < 1e-10,
2176 "SRSF of constant should be 0, got q[{j}] = {}",
2177 q[j]
2178 );
2179 }
2180 }
2181
2182 #[test]
2183 fn test_srsf_transform_negative_slope() {
2184 let m = 50;
2186 let t = uniform_grid(m);
2187 let f: Vec<f64> = t.iter().map(|&ti| -3.0 * ti).collect();
2188 let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
2189
2190 let q_mat = srsf_transform(&mat, &t);
2191 let q: Vec<f64> = q_mat.row(0);
2192
2193 let expected = -(3.0_f64.sqrt());
2194 for j in 2..(m - 2) {
2195 assert!(
2196 (q[j] - expected).abs() < 0.15,
2197 "q[{j}] = {}, expected ~{expected}",
2198 q[j]
2199 );
2200 }
2201 }
2202
2203 #[test]
2204 fn test_srsf_transform_empty_input() {
2205 let data = FdMatrix::zeros(0, 0);
2206 let t: Vec<f64> = vec![];
2207 let q = srsf_transform(&data, &t);
2208 assert_eq!(q.shape(), (0, 0));
2209 }
2210
2211 #[test]
2212 fn test_srsf_transform_multiple_curves() {
2213 let m = 40;
2214 let t = uniform_grid(m);
2215 let data = make_test_data(5, m, 42);
2216
2217 let q = srsf_transform(&data, &t);
2218 assert_eq!(q.shape(), (5, m));
2219
2220 for i in 0..5 {
2222 for j in 0..m {
2223 assert!(q[(i, j)].is_finite(), "SRSF should be finite at ({i},{j})");
2224 }
2225 }
2226 }
2227
2228 #[test]
2231 fn test_srsf_round_trip() {
2232 let m = 100;
2233 let t = uniform_grid(m);
2234 let f: Vec<f64> = t
2236 .iter()
2237 .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin() + ti)
2238 .collect();
2239
2240 let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
2241 let q_mat = srsf_transform(&mat, &t);
2242 let q: Vec<f64> = q_mat.row(0);
2243
2244 let f_recon = srsf_inverse(&q, &t, f[0]);
2245
2246 let max_err: f64 = f[5..(m - 5)]
2248 .iter()
2249 .zip(f_recon[5..(m - 5)].iter())
2250 .map(|(a, b)| (a - b).abs())
2251 .fold(0.0_f64, f64::max);
2252
2253 assert!(
2254 max_err < 0.15,
2255 "Round-trip error too large: max_err = {max_err}"
2256 );
2257 }
2258
2259 #[test]
2260 fn test_srsf_inverse_empty() {
2261 let q: Vec<f64> = vec![];
2262 let t: Vec<f64> = vec![];
2263 let result = srsf_inverse(&q, &t, 0.0);
2264 assert!(result.is_empty());
2265 }
2266
2267 #[test]
2268 fn test_srsf_inverse_preserves_initial_value() {
2269 let m = 50;
2270 let t = uniform_grid(m);
2271 let q = vec![1.0; m]; let f0 = 3.15;
2273 let f = srsf_inverse(&q, &t, f0);
2274 assert!((f[0] - f0).abs() < 1e-12, "srsf_inverse should start at f0");
2275 }
2276
2277 #[test]
2278 fn test_srsf_round_trip_multiple_curves() {
2279 let m = 80;
2280 let t = uniform_grid(m);
2281 let data = make_test_data(5, m, 99);
2282
2283 let q_mat = srsf_transform(&data, &t);
2284
2285 for i in 0..5 {
2286 let fi = data.row(i);
2287 let qi = q_mat.row(i);
2288 let f_recon = srsf_inverse(&qi, &t, fi[0]);
2289 let max_err: f64 = fi[5..(m - 5)]
2290 .iter()
2291 .zip(f_recon[5..(m - 5)].iter())
2292 .map(|(a, b)| (a - b).abs())
2293 .fold(0.0_f64, f64::max);
2294 assert!(max_err < 0.3, "Round-trip curve {i}: max_err = {max_err}");
2295 }
2296 }
2297
2298 #[test]
2301 fn test_reparameterize_identity_warp() {
2302 let m = 50;
2303 let t = uniform_grid(m);
2304 let f: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
2305
2306 let result = reparameterize_curve(&f, &t, &t);
2308 for j in 0..m {
2309 assert!(
2310 (result[j] - f[j]).abs() < 1e-12,
2311 "Identity warp should return original at j={j}"
2312 );
2313 }
2314 }
2315
2316 #[test]
2317 fn test_reparameterize_linear_warp() {
2318 let m = 50;
2319 let t = uniform_grid(m);
2320 let f: Vec<f64> = t.clone();
2322 let gamma: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
2323
2324 let result = reparameterize_curve(&f, &t, &gamma);
2325
2326 for j in 0..m {
2328 assert!(
2329 (result[j] - gamma[j]).abs() < 1e-10,
2330 "f(gamma(t)) should be gamma(t) for f(t)=t at j={j}"
2331 );
2332 }
2333 }
2334
2335 #[test]
2336 fn test_reparameterize_sine_with_quadratic_warp() {
2337 let m = 100;
2338 let t = uniform_grid(m);
2339 let f: Vec<f64> = t
2340 .iter()
2341 .map(|&ti| (std::f64::consts::PI * ti).sin())
2342 .collect();
2343 let gamma: Vec<f64> = t.iter().map(|&ti| ti * ti).collect(); let result = reparameterize_curve(&f, &t, &gamma);
2346
2347 for j in 0..m {
2349 let expected = (std::f64::consts::PI * gamma[j]).sin();
2350 assert!(
2351 (result[j] - expected).abs() < 0.05,
2352 "sin(π γ(t)) at j={j}: expected {expected:.4}, got {:.4}",
2353 result[j]
2354 );
2355 }
2356 }
2357
2358 #[test]
2359 fn test_reparameterize_preserves_length() {
2360 let m = 50;
2361 let t = uniform_grid(m);
2362 let f = vec![1.0; m];
2363 let gamma: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
2364
2365 let result = reparameterize_curve(&f, &t, &gamma);
2366 assert_eq!(result.len(), m);
2367 }
2368
2369 #[test]
2372 fn test_compose_warps_identity() {
2373 let m = 50;
2374 let t = uniform_grid(m);
2375 let gamma: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
2377
2378 let result = compose_warps(&t, &gamma, &t);
2380 for j in 0..m {
2381 assert!(
2382 (result[j] - gamma[j]).abs() < 1e-10,
2383 "id ∘ γ should be γ at j={j}"
2384 );
2385 }
2386
2387 let result2 = compose_warps(&gamma, &t, &t);
2389 for j in 0..m {
2390 assert!(
2391 (result2[j] - gamma[j]).abs() < 1e-10,
2392 "γ ∘ id should be γ at j={j}"
2393 );
2394 }
2395 }
2396
2397 #[test]
2398 fn test_compose_warps_associativity() {
2399 let m = 50;
2401 let t = uniform_grid(m);
2402 let g1: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
2403 let g2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
2404 let g3: Vec<f64> = t.iter().map(|&ti| 0.5 * ti + 0.5 * ti * ti).collect();
2405
2406 let g12 = compose_warps(&g1, &g2, &t);
2407 let left = compose_warps(&g12, &g3, &t); let g23 = compose_warps(&g2, &g3, &t);
2410 let right = compose_warps(&g1, &g23, &t); for j in 0..m {
2413 assert!(
2414 (left[j] - right[j]).abs() < 0.05,
2415 "Composition should be roughly associative at j={j}: left={:.4}, right={:.4}",
2416 left[j],
2417 right[j]
2418 );
2419 }
2420 }
2421
2422 #[test]
2423 fn test_compose_warps_preserves_domain() {
2424 let m = 50;
2425 let t = uniform_grid(m);
2426 let g1: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
2427 let g2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
2428
2429 let composed = compose_warps(&g1, &g2, &t);
2430 assert!(
2431 (composed[0] - t[0]).abs() < 1e-10,
2432 "Composed warp should start at domain start"
2433 );
2434 assert!(
2435 (composed[m - 1] - t[m - 1]).abs() < 1e-10,
2436 "Composed warp should end at domain end"
2437 );
2438 }
2439
2440 #[test]
2443 fn test_align_identical_curves() {
2444 let m = 50;
2445 let t = uniform_grid(m);
2446 let f: Vec<f64> = t
2447 .iter()
2448 .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
2449 .collect();
2450
2451 let result = elastic_align_pair(&f, &f, &t, 0.0);
2452
2453 assert!(
2455 result.distance < 0.1,
2456 "Distance between identical curves should be near 0, got {}",
2457 result.distance
2458 );
2459
2460 for j in 0..m {
2462 assert!(
2463 (result.gamma[j] - t[j]).abs() < 0.1,
2464 "Warp should be near identity at j={j}: gamma={}, t={}",
2465 result.gamma[j],
2466 t[j]
2467 );
2468 }
2469 }
2470
2471 #[test]
2472 fn test_align_pair_valid_output() {
2473 let data = make_test_data(2, 50, 42);
2474 let t = uniform_grid(50);
2475 let f1 = data.row(0);
2476 let f2 = data.row(1);
2477
2478 let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2479
2480 assert_eq!(result.gamma.len(), 50);
2481 assert_eq!(result.f_aligned.len(), 50);
2482 assert!(result.distance >= 0.0);
2483
2484 for j in 1..50 {
2486 assert!(
2487 result.gamma[j] >= result.gamma[j - 1],
2488 "Warp should be monotone at j={j}"
2489 );
2490 }
2491 }
2492
2493 #[test]
2494 fn test_align_pair_warp_boundaries() {
2495 let data = make_test_data(2, 50, 42);
2496 let t = uniform_grid(50);
2497 let f1 = data.row(0);
2498 let f2 = data.row(1);
2499
2500 let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2501 assert!(
2502 (result.gamma[0] - t[0]).abs() < 1e-12,
2503 "Warp should start at domain start"
2504 );
2505 assert!(
2506 (result.gamma[49] - t[49]).abs() < 1e-12,
2507 "Warp should end at domain end"
2508 );
2509 }
2510
2511 #[test]
2512 fn test_align_shifted_sine() {
2513 let m = 80;
2515 let t = uniform_grid(m);
2516 let f1: Vec<f64> = t
2517 .iter()
2518 .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
2519 .collect();
2520 let f2: Vec<f64> = t
2521 .iter()
2522 .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
2523 .collect();
2524
2525 let weights = simpsons_weights(&t);
2526 let l2_before = l2_distance(&f1, &f2, &weights);
2527 let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2528 let l2_after = l2_distance(&f1, &result.f_aligned, &weights);
2529
2530 assert!(
2531 l2_after < l2_before + 0.01,
2532 "Alignment should not increase L2 distance: before={l2_before:.4}, after={l2_after:.4}"
2533 );
2534 }
2535
2536 #[test]
2537 fn test_align_pair_aligned_curve_is_finite() {
2538 let data = make_test_data(2, 50, 77);
2539 let t = uniform_grid(50);
2540 let f1 = data.row(0);
2541 let f2 = data.row(1);
2542
2543 let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2544 for j in 0..50 {
2545 assert!(
2546 result.f_aligned[j].is_finite(),
2547 "Aligned curve should be finite at j={j}"
2548 );
2549 }
2550 }
2551
2552 #[test]
2553 fn test_align_pair_minimum_grid() {
2554 let t = vec![0.0, 1.0];
2556 let f1 = vec![0.0, 1.0];
2557 let f2 = vec![0.0, 2.0];
2558 let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2559 assert_eq!(result.gamma.len(), 2);
2560 assert_eq!(result.f_aligned.len(), 2);
2561 assert!(result.distance >= 0.0);
2562 }
2563
2564 #[test]
2567 fn test_elastic_distance_symmetric() {
2568 let data = make_test_data(3, 50, 42);
2569 let t = uniform_grid(50);
2570 let f1 = data.row(0);
2571 let f2 = data.row(1);
2572
2573 let d12 = elastic_distance(&f1, &f2, &t, 0.0);
2574 let d21 = elastic_distance(&f2, &f1, &t, 0.0);
2575
2576 assert!(
2578 (d12 - d21).abs() < d12.max(d21) * 0.3 + 0.01,
2579 "Elastic distance should be roughly symmetric: d12={d12}, d21={d21}"
2580 );
2581 }
2582
2583 #[test]
2584 fn test_elastic_distance_nonneg() {
2585 let data = make_test_data(3, 50, 42);
2586 let t = uniform_grid(50);
2587
2588 for i in 0..3 {
2589 for j in 0..3 {
2590 let fi = data.row(i);
2591 let fj = data.row(j);
2592 let d = elastic_distance(&fi, &fj, &t, 0.0);
2593 assert!(d >= 0.0, "Elastic distance should be non-negative");
2594 }
2595 }
2596 }
2597
2598 #[test]
2599 fn test_elastic_distance_self_near_zero() {
2600 let data = make_test_data(3, 50, 42);
2601 let t = uniform_grid(50);
2602
2603 for i in 0..3 {
2604 let fi = data.row(i);
2605 let d = elastic_distance(&fi, &fi, &t, 0.0);
2606 assert!(
2607 d < 0.1,
2608 "Self-distance should be near zero, got {d} for curve {i}"
2609 );
2610 }
2611 }
2612
2613 #[test]
2614 fn test_elastic_distance_triangle_inequality() {
2615 let data = make_test_data(3, 50, 42);
2616 let t = uniform_grid(50);
2617 let f0 = data.row(0);
2618 let f1 = data.row(1);
2619 let f2 = data.row(2);
2620
2621 let d01 = elastic_distance(&f0, &f1, &t, 0.0);
2622 let d12 = elastic_distance(&f1, &f2, &t, 0.0);
2623 let d02 = elastic_distance(&f0, &f2, &t, 0.0);
2624
2625 let slack = 0.5;
2627 assert!(
2628 d02 <= d01 + d12 + slack,
2629 "Triangle inequality (relaxed): d02={d02:.4} > d01={d01:.4} + d12={d12:.4} + {slack}"
2630 );
2631 }
2632
2633 #[test]
2634 fn test_elastic_distance_different_shapes_nonzero() {
2635 let m = 50;
2636 let t = uniform_grid(m);
2637 let f1: Vec<f64> = t.to_vec(); let f2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect(); let d = elastic_distance(&f1, &f2, &t, 0.0);
2641 assert!(
2642 d > 0.01,
2643 "Distance between different shapes should be > 0, got {d}"
2644 );
2645 }
2646
2647 #[test]
2650 fn test_self_distance_matrix_symmetric() {
2651 let data = make_test_data(5, 30, 42);
2652 let t = uniform_grid(30);
2653
2654 let dm = elastic_self_distance_matrix(&data, &t, 0.0);
2655 let n = dm.nrows();
2656
2657 assert_eq!(dm.shape(), (5, 5));
2658
2659 for i in 0..n {
2661 assert!(dm[(i, i)].abs() < 1e-12, "Diagonal should be zero");
2662 }
2663
2664 for i in 0..n {
2666 for j in (i + 1)..n {
2667 assert!(
2668 (dm[(i, j)] - dm[(j, i)]).abs() < 1e-12,
2669 "Matrix should be symmetric at ({i},{j})"
2670 );
2671 }
2672 }
2673 }
2674
2675 #[test]
2676 fn test_self_distance_matrix_nonneg() {
2677 let data = make_test_data(4, 30, 42);
2678 let t = uniform_grid(30);
2679 let dm = elastic_self_distance_matrix(&data, &t, 0.0);
2680
2681 for i in 0..4 {
2682 for j in 0..4 {
2683 assert!(
2684 dm[(i, j)] >= 0.0,
2685 "Distance matrix entries should be non-negative at ({i},{j})"
2686 );
2687 }
2688 }
2689 }
2690
2691 #[test]
2692 fn test_self_distance_matrix_single_curve() {
2693 let data = make_test_data(1, 30, 42);
2694 let t = uniform_grid(30);
2695 let dm = elastic_self_distance_matrix(&data, &t, 0.0);
2696 assert_eq!(dm.shape(), (1, 1));
2697 assert!(dm[(0, 0)].abs() < 1e-12);
2698 }
2699
2700 #[test]
2701 fn test_self_distance_matrix_consistent_with_pairwise() {
2702 let data = make_test_data(4, 30, 42);
2703 let t = uniform_grid(30);
2704
2705 let dm = elastic_self_distance_matrix(&data, &t, 0.0);
2706
2707 for i in 0..4 {
2709 for j in (i + 1)..4 {
2710 let fi = data.row(i);
2711 let fj = data.row(j);
2712 let d_direct = elastic_distance(&fi, &fj, &t, 0.0);
2713 assert!(
2714 (dm[(i, j)] - d_direct).abs() < 1e-10,
2715 "Matrix entry ({i},{j})={:.6} should match pairwise {d_direct:.6}",
2716 dm[(i, j)]
2717 );
2718 }
2719 }
2720 }
2721
2722 #[test]
2725 fn test_karcher_mean_identical_curves() {
2726 let m = 50;
2727 let t = uniform_grid(m);
2728 let f: Vec<f64> = t
2729 .iter()
2730 .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
2731 .collect();
2732
2733 let mut data = FdMatrix::zeros(5, m);
2735 for i in 0..5 {
2736 for j in 0..m {
2737 data[(i, j)] = f[j];
2738 }
2739 }
2740
2741 let result = karcher_mean(&data, &t, 10, 1e-4, 0.0);
2742
2743 assert_eq!(result.mean.len(), m);
2744 assert!(result.n_iter <= 10);
2745 }
2746
2747 #[test]
2748 fn test_karcher_mean_output_shape() {
2749 let data = make_test_data(15, 50, 42);
2750 let t = uniform_grid(50);
2751
2752 let result = karcher_mean(&data, &t, 5, 1e-3, 0.0);
2753
2754 assert_eq!(result.mean.len(), 50);
2755 assert_eq!(result.mean_srsf.len(), 50);
2756 assert_eq!(result.gammas.shape(), (15, 50));
2757 assert_eq!(result.aligned_data.shape(), (15, 50));
2758 assert!(result.n_iter <= 5);
2759 }
2760
2761 #[test]
2762 fn test_karcher_mean_warps_are_valid() {
2763 let data = make_test_data(10, 40, 42);
2764 let t = uniform_grid(40);
2765
2766 let result = karcher_mean(&data, &t, 5, 1e-3, 0.0);
2767
2768 for i in 0..10 {
2769 assert!(
2771 (result.gammas[(i, 0)] - t[0]).abs() < 1e-10,
2772 "Warp {i} should start at domain start"
2773 );
2774 assert!(
2775 (result.gammas[(i, 39)] - t[39]).abs() < 1e-10,
2776 "Warp {i} should end at domain end"
2777 );
2778 for j in 1..40 {
2780 assert!(
2781 result.gammas[(i, j)] >= result.gammas[(i, j - 1)],
2782 "Warp {i} should be monotone at j={j}"
2783 );
2784 }
2785 }
2786 }
2787
2788 #[test]
2789 fn test_karcher_mean_aligned_data_is_finite() {
2790 let data = make_test_data(8, 40, 42);
2791 let t = uniform_grid(40);
2792 let result = karcher_mean(&data, &t, 5, 1e-3, 0.0);
2793
2794 for i in 0..8 {
2795 for j in 0..40 {
2796 assert!(
2797 result.aligned_data[(i, j)].is_finite(),
2798 "Aligned data should be finite at ({i},{j})"
2799 );
2800 }
2801 }
2802 }
2803
2804 #[test]
2805 fn test_karcher_mean_srsf_is_finite() {
2806 let data = make_test_data(8, 40, 42);
2807 let t = uniform_grid(40);
2808 let result = karcher_mean(&data, &t, 5, 1e-3, 0.0);
2809
2810 for j in 0..40 {
2811 assert!(
2812 result.mean_srsf[j].is_finite(),
2813 "Mean SRSF should be finite at j={j}"
2814 );
2815 assert!(
2816 result.mean[j].is_finite(),
2817 "Mean curve should be finite at j={j}"
2818 );
2819 }
2820 }
2821
2822 #[test]
2823 fn test_karcher_mean_single_iteration() {
2824 let data = make_test_data(10, 40, 42);
2825 let t = uniform_grid(40);
2826 let result = karcher_mean(&data, &t, 1, 1e-10, 0.0);
2827
2828 assert_eq!(result.n_iter, 1);
2829 assert_eq!(result.mean.len(), 40);
2830 for j in 0..40 {
2832 assert!(result.mean[j].is_finite());
2833 }
2834 }
2835
2836 #[test]
2839 fn test_align_to_target_valid() {
2840 let data = make_test_data(10, 40, 42);
2841 let t = uniform_grid(40);
2842 let target = data.row(0);
2843
2844 let result = align_to_target(&data, &target, &t, 0.0);
2845
2846 assert_eq!(result.gammas.shape(), (10, 40));
2847 assert_eq!(result.aligned_data.shape(), (10, 40));
2848 assert_eq!(result.distances.len(), 10);
2849
2850 for &d in &result.distances {
2852 assert!(d >= 0.0);
2853 }
2854 }
2855
2856 #[test]
2857 fn test_align_to_target_self_near_zero() {
2858 let data = make_test_data(5, 40, 42);
2859 let t = uniform_grid(40);
2860 let target = data.row(0);
2861
2862 let result = align_to_target(&data, &target, &t, 0.0);
2863
2864 assert!(
2866 result.distances[0] < 0.1,
2867 "Self-alignment distance should be near zero, got {}",
2868 result.distances[0]
2869 );
2870 }
2871
2872 #[test]
2873 fn test_align_to_target_warps_are_monotone() {
2874 let data = make_test_data(8, 40, 42);
2875 let t = uniform_grid(40);
2876 let target = data.row(0);
2877 let result = align_to_target(&data, &target, &t, 0.0);
2878
2879 for i in 0..8 {
2880 for j in 1..40 {
2881 assert!(
2882 result.gammas[(i, j)] >= result.gammas[(i, j - 1)],
2883 "Warp for curve {i} should be monotone at j={j}"
2884 );
2885 }
2886 }
2887 }
2888
2889 #[test]
2890 fn test_align_to_target_aligned_data_finite() {
2891 let data = make_test_data(6, 40, 42);
2892 let t = uniform_grid(40);
2893 let target = data.row(0);
2894 let result = align_to_target(&data, &target, &t, 0.0);
2895
2896 for i in 0..6 {
2897 for j in 0..40 {
2898 assert!(
2899 result.aligned_data[(i, j)].is_finite(),
2900 "Aligned data should be finite at ({i},{j})"
2901 );
2902 }
2903 }
2904 }
2905
2906 #[test]
2909 fn test_cross_distance_matrix_shape() {
2910 let data1 = make_test_data(3, 30, 42);
2911 let data2 = make_test_data(4, 30, 99);
2912 let t = uniform_grid(30);
2913
2914 let dm = elastic_cross_distance_matrix(&data1, &data2, &t, 0.0);
2915 assert_eq!(dm.shape(), (3, 4));
2916
2917 for i in 0..3 {
2919 for j in 0..4 {
2920 assert!(dm[(i, j)] >= 0.0);
2921 }
2922 }
2923 }
2924
2925 #[test]
2926 fn test_cross_distance_matrix_self_matches_self_matrix() {
2927 let data = make_test_data(4, 30, 42);
2929 let t = uniform_grid(30);
2930
2931 let cross = elastic_cross_distance_matrix(&data, &data, &t, 0.0);
2932 for i in 0..4 {
2933 assert!(
2934 cross[(i, i)] < 0.1,
2935 "Cross distance (self) diagonal should be near zero: got {}",
2936 cross[(i, i)]
2937 );
2938 }
2939 }
2940
2941 #[test]
2942 fn test_cross_distance_matrix_consistent_with_pairwise() {
2943 let data1 = make_test_data(3, 30, 42);
2944 let data2 = make_test_data(2, 30, 99);
2945 let t = uniform_grid(30);
2946
2947 let dm = elastic_cross_distance_matrix(&data1, &data2, &t, 0.0);
2948
2949 for i in 0..3 {
2950 for j in 0..2 {
2951 let fi = data1.row(i);
2952 let fj = data2.row(j);
2953 let d_direct = elastic_distance(&fi, &fj, &t, 0.0);
2954 assert!(
2955 (dm[(i, j)] - d_direct).abs() < 1e-10,
2956 "Cross matrix ({i},{j})={:.6} should match pairwise {d_direct:.6}",
2957 dm[(i, j)]
2958 );
2959 }
2960 }
2961 }
2962
2963 #[test]
2966 fn test_align_srsf_pair_identity() {
2967 let m = 50;
2968 let t = uniform_grid(m);
2969 let f: Vec<f64> = t
2970 .iter()
2971 .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
2972 .collect();
2973 let q = srsf_single(&f, &t);
2974
2975 let (gamma, q_aligned) = align_srsf_pair(&q, &q, &t, 0.0);
2976
2977 for j in 0..m {
2979 assert!(
2980 (gamma[j] - t[j]).abs() < 0.15,
2981 "Self-SRSF alignment warp should be near identity at j={j}"
2982 );
2983 }
2984
2985 let weights = simpsons_weights(&t);
2987 let dist = l2_distance(&q, &q_aligned, &weights);
2988 assert!(
2989 dist < 0.5,
2990 "Self-aligned SRSF distance should be small, got {dist}"
2991 );
2992 }
2993
2994 #[test]
2997 fn test_srsf_single_matches_matrix_version() {
2998 let m = 50;
2999 let t = uniform_grid(m);
3000 let f: Vec<f64> = t.iter().map(|&ti| ti * ti + ti).collect();
3001
3002 let q_single = srsf_single(&f, &t);
3003
3004 let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
3005 let q_mat = srsf_transform(&mat, &t);
3006 let q_from_mat = q_mat.row(0);
3007
3008 for j in 0..m {
3009 assert!(
3010 (q_single[j] - q_from_mat[j]).abs() < 1e-12,
3011 "srsf_single should match srsf_transform at j={j}"
3012 );
3013 }
3014 }
3015
3016 #[test]
3019 fn test_gcd_basic() {
3020 assert_eq!(gcd(1, 1), 1);
3021 assert_eq!(gcd(6, 4), 2);
3022 assert_eq!(gcd(7, 5), 1);
3023 assert_eq!(gcd(12, 8), 4);
3024 assert_eq!(gcd(7, 0), 7);
3025 assert_eq!(gcd(0, 5), 5);
3026 }
3027
3028 #[test]
3031 fn test_coprime_nbhd_count() {
3032 assert_eq!(generate_coprime_nbhd(1).len(), 1); assert_eq!(generate_coprime_nbhd(7).len(), 35);
3034 }
3035
3036 #[test]
3037 fn test_coprime_nbhd_matches_const() {
3038 let generated = generate_coprime_nbhd(7);
3039 assert_eq!(generated.len(), COPRIME_NBHD_7.len());
3040 for (i, pair) in generated.iter().enumerate() {
3041 assert_eq!(*pair, COPRIME_NBHD_7[i], "mismatch at index {i}");
3042 }
3043 }
3044
3045 #[test]
3046 fn test_coprime_nbhd_all_coprime() {
3047 for &(i, j) in &COPRIME_NBHD_7 {
3048 assert_eq!(gcd(i, j), 1, "({i},{j}) should be coprime");
3049 assert!((1..=7).contains(&i));
3050 assert!((1..=7).contains(&j));
3051 }
3052 }
3053
3054 #[test]
3057 fn test_dp_edge_weight_diagonal() {
3058 let t = uniform_grid(10);
3060 let q1 = vec![1.0; 10];
3061 let q2 = vec![1.0; 10];
3062 let w = dp_edge_weight(&q1, &q2, &t, 0, 1, 0, 1);
3064 assert!(w.abs() < 1e-12, "identical SRSFs should have zero cost");
3065 }
3066
3067 #[test]
3068 fn test_dp_edge_weight_non_diagonal() {
3069 let t = uniform_grid(10);
3071 let q1 = vec![1.0; 10];
3072 let q2 = vec![0.0; 10];
3073 let w = dp_edge_weight(&q1, &q2, &t, 0, 2, 0, 1);
3074 let expected = 2.0 / 9.0;
3077 assert!(
3078 (w - expected).abs() < 1e-10,
3079 "dp_edge_weight (1,2): expected {expected}, got {w}"
3080 );
3081 }
3082
3083 #[test]
3084 fn test_dp_edge_weight_zero_span() {
3085 let t = uniform_grid(10);
3086 let q1 = vec![1.0; 10];
3087 let q2 = vec![1.0; 10];
3088 assert_eq!(dp_edge_weight(&q1, &q2, &t, 3, 3, 0, 1), f64::INFINITY);
3090 assert_eq!(dp_edge_weight(&q1, &q2, &t, 0, 1, 3, 3), f64::INFINITY);
3092 }
3093
3094 #[test]
3097 fn test_alignment_improves_distance() {
3098 let m = 50;
3100 let t = uniform_grid(m);
3101 let f1: Vec<f64> = t
3102 .iter()
3103 .map(|&x| (2.0 * std::f64::consts::PI * x).sin())
3104 .collect();
3105 let f2: Vec<f64> = t
3107 .iter()
3108 .map(|&x| (2.0 * std::f64::consts::PI * (x + 0.2)).sin())
3109 .collect();
3110
3111 let q1 = srsf_single(&f1, &t);
3112 let q2 = srsf_single(&f2, &t);
3113 let weights = simpsons_weights(&t);
3114 let unaligned_srsf_dist = l2_distance(&q1, &q2, &weights);
3115
3116 let result = elastic_align_pair(&f1, &f2, &t, 0.0);
3117
3118 assert!(
3119 result.distance <= unaligned_srsf_dist + 1e-6,
3120 "aligned SRSF dist ({}) should be <= unaligned SRSF dist ({})",
3121 result.distance,
3122 unaligned_srsf_dist
3123 );
3124 }
3125
3126 #[test]
3129 fn test_alignment_constant_curves() {
3130 let m = 30;
3131 let t = uniform_grid(m);
3132 let f1 = vec![5.0; m];
3133 let f2 = vec![5.0; m];
3134
3135 let result = elastic_align_pair(&f1, &f2, &t, 0.0);
3136 assert!(
3137 result.distance < 0.01,
3138 "Constant curves: distance should be ~0"
3139 );
3140 assert_eq!(result.f_aligned.len(), m);
3141 }
3142
3143 #[test]
3144 fn test_karcher_mean_constant_curves() {
3145 let m = 30;
3146 let t = uniform_grid(m);
3147 let mut data = FdMatrix::zeros(5, m);
3148 for i in 0..5 {
3149 for j in 0..m {
3150 data[(i, j)] = 3.0;
3151 }
3152 }
3153
3154 let result = karcher_mean(&data, &t, 5, 1e-4, 0.0);
3155 for j in 0..m {
3156 assert!(
3157 (result.mean[j] - 3.0).abs() < 0.5,
3158 "Mean of constant curves should be near 3.0, got {} at j={j}",
3159 result.mean[j]
3160 );
3161 }
3162 }
3163
3164 #[test]
3165 fn test_nan_srsf_no_panic() {
3166 let m = 20;
3167 let t = uniform_grid(m);
3168 let mut f = vec![1.0; m];
3169 f[5] = f64::NAN;
3170 let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
3171 let q = srsf_transform(&mat, &t);
3172 assert_eq!(q.nrows(), 1);
3174 }
3175
3176 #[test]
3177 fn test_n1_karcher_mean() {
3178 let m = 30;
3179 let t = uniform_grid(m);
3180 let f: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
3181 let data = FdMatrix::from_slice(&f, 1, m).unwrap();
3182 let result = karcher_mean(&data, &t, 5, 1e-4, 0.0);
3183 assert_eq!(result.mean.len(), m);
3184 for j in 0..m {
3186 assert!(result.mean[j].is_finite());
3187 }
3188 }
3189
3190 #[test]
3191 fn test_two_point_grid() {
3192 let t = vec![0.0, 1.0];
3193 let f1 = vec![0.0, 1.0];
3194 let f2 = vec![0.0, 2.0];
3195 let d = elastic_distance(&f1, &f2, &t, 0.0);
3196 assert!(d >= 0.0);
3197 assert!(d.is_finite());
3198 }
3199
3200 #[test]
3201 fn test_non_uniform_grid_alignment() {
3202 let t = vec![0.0, 0.01, 0.05, 0.2, 0.5, 1.0];
3204 let m = t.len();
3205 let f1: Vec<f64> = t.iter().map(|&ti: &f64| ti.sin()).collect();
3206 let f2: Vec<f64> = t.iter().map(|&ti: &f64| (ti + 0.1).sin()).collect();
3207 let result = elastic_align_pair(&f1, &f2, &t, 0.0);
3208 assert_eq!(result.gamma.len(), m);
3209 assert!(result.distance >= 0.0);
3210 assert!(result.distance.is_finite());
3211 }
3212
3213 #[test]
3216 fn test_tsrvf_output_shape() {
3217 let m = 50;
3218 let n = 10;
3219 let t = uniform_grid(m);
3220 let data = make_test_data(n, m, 42);
3221 let result = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
3222 assert_eq!(
3223 result.tangent_vectors.shape(),
3224 (n, m),
3225 "Tangent vectors should be n×m"
3226 );
3227 assert_eq!(result.gammas.shape(), (n, m), "Gammas should be n×m");
3228 assert_eq!(result.srsf_norms.len(), n, "Should have n SRSF norms");
3229 assert_eq!(result.mean.len(), m, "Mean should have m points");
3230 assert_eq!(result.mean_srsf.len(), m, "Mean SRSF should have m points");
3231 }
3232
3233 #[test]
3234 fn test_tsrvf_all_finite() {
3235 let m = 50;
3236 let n = 5;
3237 let t = uniform_grid(m);
3238 let data = make_test_data(n, m, 42);
3239 let result = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
3240 for i in 0..n {
3241 for j in 0..m {
3242 assert!(
3243 result.tangent_vectors[(i, j)].is_finite(),
3244 "Tangent vector should be finite at ({i},{j})"
3245 );
3246 }
3247 assert!(
3248 result.srsf_norms[i].is_finite(),
3249 "SRSF norm should be finite for curve {i}"
3250 );
3251 }
3252 assert!(
3253 result.mean_srsf_norm.is_finite(),
3254 "Mean SRSF norm should be finite"
3255 );
3256 }
3257
3258 #[test]
3259 fn test_tsrvf_identical_curves_zero_tangent() {
3260 let m = 50;
3261 let t = uniform_grid(m);
3262 let curve: Vec<f64> = t
3264 .iter()
3265 .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3266 .collect();
3267 let mut col_major = vec![0.0; 5 * m];
3268 for i in 0..5 {
3269 for j in 0..m {
3270 col_major[i + j * 5] = curve[j];
3271 }
3272 }
3273 let data = FdMatrix::from_column_major(col_major, 5, m).unwrap();
3274 let result = tsrvf_transform(&data, &t, 10, 1e-4, 0.0);
3275
3276 for i in 0..5 {
3278 let tv_norm_sq: f64 = (0..m).map(|j| result.tangent_vectors[(i, j)].powi(2)).sum();
3279 assert!(
3280 tv_norm_sq.sqrt() < 0.5,
3281 "Identical curves should have near-zero tangent vectors, got norm = {}",
3282 tv_norm_sq.sqrt()
3283 );
3284 }
3285 }
3286
3287 #[test]
3288 fn test_tsrvf_mean_tangent_near_zero() {
3289 let m = 50;
3290 let n = 10;
3291 let t = uniform_grid(m);
3292 let data = make_test_data(n, m, 42);
3293 let result = tsrvf_transform(&data, &t, 10, 1e-3, 0.0);
3294
3295 let mut mean_tv = vec![0.0; m];
3297 for i in 0..n {
3298 for j in 0..m {
3299 mean_tv[j] += result.tangent_vectors[(i, j)];
3300 }
3301 }
3302 for j in 0..m {
3303 mean_tv[j] /= n as f64;
3304 }
3305 let mean_norm: f64 = mean_tv.iter().map(|v| v * v).sum::<f64>().sqrt();
3306 assert!(
3307 mean_norm < 1.0,
3308 "Mean tangent vector should be near zero, got norm = {mean_norm}"
3309 );
3310 }
3311
3312 #[test]
3313 fn test_tsrvf_from_alignment() {
3314 let m = 50;
3315 let n = 5;
3316 let t = uniform_grid(m);
3317 let data = make_test_data(n, m, 42);
3318 let karcher = karcher_mean(&data, &t, 5, 1e-3, 0.0);
3319 let result = tsrvf_from_alignment(&karcher, &t);
3320 assert_eq!(result.tangent_vectors.shape(), (n, m));
3321 assert!(result.mean_srsf_norm > 0.0);
3322 }
3323
3324 #[test]
3325 fn test_tsrvf_round_trip() {
3326 let m = 50;
3327 let n = 5;
3328 let t = uniform_grid(m);
3329 let data = make_test_data(n, m, 42);
3330 let result = tsrvf_transform(&data, &t, 10, 1e-3, 0.0);
3331 let reconstructed = tsrvf_inverse(&result, &t);
3332
3333 assert_eq!(reconstructed.shape(), result.tangent_vectors.shape());
3334 for i in 0..n {
3336 for j in 0..m {
3337 assert!(
3338 reconstructed[(i, j)].is_finite(),
3339 "Reconstructed curve should be finite at ({i},{j})"
3340 );
3341 }
3342 }
3343 }
3344
3345 #[test]
3346 fn test_tsrvf_single_curve() {
3347 let m = 50;
3348 let t = uniform_grid(m);
3349 let data = make_test_data(1, m, 42);
3350 let result = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
3351 assert_eq!(result.tangent_vectors.shape(), (1, m));
3352 let tv_norm: f64 = (0..m)
3354 .map(|j| result.tangent_vectors[(0, j)].powi(2))
3355 .sum::<f64>()
3356 .sqrt();
3357 assert!(
3358 tv_norm < 0.5,
3359 "Single curve tangent vector should be near zero, got {tv_norm}"
3360 );
3361 }
3362
3363 #[test]
3364 fn test_tsrvf_constant_curves() {
3365 let m = 30;
3366 let t = uniform_grid(m);
3367 let data = FdMatrix::from_column_major(vec![5.0; 3 * m], 3, m).unwrap();
3369 let result = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
3370 for i in 0..3 {
3372 for j in 0..m {
3373 let v = result.tangent_vectors[(i, j)];
3374 assert!(
3375 !v.is_nan(),
3376 "Constant curves should not produce NaN tangent vectors"
3377 );
3378 }
3379 }
3380 }
3381
3382 #[test]
3385 fn test_tsrvf_sphere_inv_exp_reference() {
3386 let m = 21;
3389 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
3390
3391 let raw1 = vec![1.0; m];
3393 let norm1 = inner_product_l2(&raw1, &raw1, &time).max(0.0).sqrt();
3394 let psi1: Vec<f64> = raw1.iter().map(|&v| v / norm1).collect();
3395
3396 let raw2: Vec<f64> = time
3398 .iter()
3399 .map(|&t| 1.0 + 0.3 * (2.0 * std::f64::consts::PI * t).sin())
3400 .collect();
3401 let norm2 = inner_product_l2(&raw2, &raw2, &time).max(0.0).sqrt();
3402 let psi2: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();
3403
3404 let ip = inner_product_l2(&psi1, &psi2, &time).clamp(-1.0, 1.0);
3406 let theta_expected = ip.acos();
3407
3408 let v = inv_exp_map_sphere(&psi1, &psi2, &time);
3410 let v_norm = inner_product_l2(&v, &v, &time).max(0.0).sqrt();
3411
3412 assert!(
3414 (v_norm - theta_expected).abs() < 1e-10,
3415 "||v|| = {v_norm}, expected theta = {theta_expected}"
3416 );
3417
3418 assert!(
3420 theta_expected > 0.01 && theta_expected < 1.0,
3421 "theta = {theta_expected} out of expected range"
3422 );
3423 }
3424
3425 #[test]
3426 fn test_tsrvf_sphere_round_trip_reference() {
3427 let m = 21;
3429 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
3430
3431 let raw1 = vec![1.0; m];
3432 let norm1 = inner_product_l2(&raw1, &raw1, &time).max(0.0).sqrt();
3433 let psi1: Vec<f64> = raw1.iter().map(|&v| v / norm1).collect();
3434
3435 let raw2: Vec<f64> = time
3436 .iter()
3437 .map(|&t| 1.0 + 0.3 * (2.0 * std::f64::consts::PI * t).sin())
3438 .collect();
3439 let norm2 = inner_product_l2(&raw2, &raw2, &time).max(0.0).sqrt();
3440 let psi2: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();
3441
3442 let v = inv_exp_map_sphere(&psi1, &psi2, &time);
3443 let recovered = exp_map_sphere(&psi1, &v, &time);
3444
3445 let diff: Vec<f64> = psi2
3447 .iter()
3448 .zip(recovered.iter())
3449 .map(|(&a, &b)| (a - b).powi(2))
3450 .collect();
3451 let l2_err = trapz(&diff, &time).max(0.0).sqrt();
3452 assert!(
3453 l2_err < 1e-12,
3454 "Round-trip L2 error = {l2_err:.2e}, expected < 1e-12"
3455 );
3456 }
3457
3458 #[test]
3461 fn test_penalized_alignment_lambda_zero_matches_unpenalized() {
3462 let m = 50;
3463 let t = uniform_grid(m);
3464 let data = make_test_data(2, m, 42);
3465 let f1 = data.row(0);
3466 let f2 = data.row(1);
3467
3468 let r0 = elastic_align_pair(&f1, &f2, &t, 0.0);
3469 assert!(r0.distance >= 0.0);
3471 assert_eq!(r0.gamma.len(), m);
3472 }
3473
3474 #[test]
3475 fn test_penalized_alignment_smoother_warp() {
3476 let m = 80;
3477 let t = uniform_grid(m);
3478 let f1: Vec<f64> = t
3479 .iter()
3480 .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3481 .collect();
3482 let f2: Vec<f64> = t
3483 .iter()
3484 .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.15)).sin())
3485 .collect();
3486
3487 let r_free = elastic_align_pair(&f1, &f2, &t, 0.0);
3488 let r_pen = elastic_align_pair(&f1, &f2, &t, 1.0);
3489
3490 let dev_free: f64 = r_free
3492 .gamma
3493 .iter()
3494 .zip(t.iter())
3495 .map(|(g, ti)| (g - ti).powi(2))
3496 .sum();
3497 let dev_pen: f64 = r_pen
3498 .gamma
3499 .iter()
3500 .zip(t.iter())
3501 .map(|(g, ti)| (g - ti).powi(2))
3502 .sum();
3503
3504 assert!(
3505 dev_pen <= dev_free + 1e-6,
3506 "Penalized warp should be closer to identity: free={dev_free:.6}, pen={dev_pen:.6}"
3507 );
3508 }
3509
3510 #[test]
3511 fn test_penalized_alignment_large_lambda_near_identity() {
3512 let m = 50;
3513 let t = uniform_grid(m);
3514 let f1: Vec<f64> = t
3515 .iter()
3516 .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3517 .collect();
3518 let f2: Vec<f64> = t
3519 .iter()
3520 .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
3521 .collect();
3522
3523 let r = elastic_align_pair(&f1, &f2, &t, 1000.0);
3524
3525 let max_dev: f64 = r
3527 .gamma
3528 .iter()
3529 .zip(t.iter())
3530 .map(|(g, ti)| (g - ti).abs())
3531 .fold(0.0_f64, f64::max);
3532 assert!(
3533 max_dev < 0.05,
3534 "Large lambda should give near-identity warp: max deviation = {max_dev}"
3535 );
3536 }
3537
3538 #[test]
3539 fn test_penalized_karcher_mean() {
3540 let m = 40;
3541 let t = uniform_grid(m);
3542 let data = make_test_data(10, m, 42);
3543
3544 let result = karcher_mean(&data, &t, 5, 1e-3, 0.5);
3545 assert_eq!(result.mean.len(), m);
3546 for j in 0..m {
3547 assert!(result.mean[j].is_finite());
3548 }
3549 }
3550
3551 #[test]
3554 fn test_decomposition_identity_curves() {
3555 let m = 50;
3556 let t = uniform_grid(m);
3557 let f: Vec<f64> = t
3558 .iter()
3559 .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3560 .collect();
3561
3562 let result = elastic_decomposition(&f, &f, &t, 0.0);
3563 assert!(
3564 result.d_amplitude < 0.1,
3565 "Self-decomposition amplitude should be ~0, got {}",
3566 result.d_amplitude
3567 );
3568 assert!(
3569 result.d_phase < 0.2,
3570 "Self-decomposition phase should be ~0, got {}",
3571 result.d_phase
3572 );
3573 }
3574
3575 #[test]
3576 fn test_decomposition_pythagorean() {
3577 let m = 80;
3579 let t = uniform_grid(m);
3580 let f1: Vec<f64> = t
3581 .iter()
3582 .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3583 .collect();
3584 let f2: Vec<f64> = t
3585 .iter()
3586 .map(|&ti| 1.2 * (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
3587 .collect();
3588
3589 let result = elastic_decomposition(&f1, &f2, &t, 0.0);
3590 let da = result.d_amplitude;
3591 let dp = result.d_phase;
3592 assert!(da >= 0.0);
3594 assert!(dp >= 0.0);
3595 }
3596
3597 #[test]
3598 fn test_phase_distance_shifted_sine() {
3599 let m = 80;
3600 let t = uniform_grid(m);
3601 let f1: Vec<f64> = t
3602 .iter()
3603 .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3604 .collect();
3605 let f2: Vec<f64> = t
3606 .iter()
3607 .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.15)).sin())
3608 .collect();
3609
3610 let dp = phase_distance_pair(&f1, &f2, &t, 0.0);
3611 assert!(
3612 dp > 0.01,
3613 "Phase distance of shifted curves should be > 0, got {dp}"
3614 );
3615 }
3616
3617 #[test]
3618 fn test_amplitude_distance_scaled_curve() {
3619 let m = 80;
3620 let t = uniform_grid(m);
3621 let f1: Vec<f64> = t
3622 .iter()
3623 .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3624 .collect();
3625 let f2: Vec<f64> = t
3626 .iter()
3627 .map(|&ti| 2.0 * (2.0 * std::f64::consts::PI * ti).sin())
3628 .collect();
3629
3630 let da = amplitude_distance(&f1, &f2, &t, 0.0);
3631 assert!(
3632 da > 0.01,
3633 "Amplitude distance of scaled curves should be > 0, got {da}"
3634 );
3635 }
3636
3637 #[test]
3638 fn test_phase_distance_nonneg() {
3639 let data = make_test_data(4, 40, 42);
3640 let t = uniform_grid(40);
3641 for i in 0..4 {
3642 for j in 0..4 {
3643 let fi = data.row(i);
3644 let fj = data.row(j);
3645 let dp = phase_distance_pair(&fi, &fj, &t, 0.0);
3646 assert!(dp >= 0.0, "Phase distance should be non-negative");
3647 }
3648 }
3649 }
3650
3651 #[test]
3654 fn test_schilds_ladder_zero_vector() {
3655 let m = 21;
3656 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
3657 let raw = vec![1.0; m];
3658 let norm = crate::warping::l2_norm_l2(&raw, &time);
3659 let from: Vec<f64> = raw.iter().map(|&v| v / norm).collect();
3660 let raw2: Vec<f64> = time
3661 .iter()
3662 .map(|&t| 1.0 + 0.2 * (2.0 * std::f64::consts::PI * t).sin())
3663 .collect();
3664 let norm2 = crate::warping::l2_norm_l2(&raw2, &time);
3665 let to: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();
3666
3667 let zero = vec![0.0; m];
3668 let result = parallel_transport_schilds(&zero, &from, &to, &time);
3669 let result_norm: f64 = result.iter().map(|v| v * v).sum::<f64>().sqrt();
3670 assert!(
3671 result_norm < 1e-6,
3672 "Transporting zero should give zero, got norm {result_norm}"
3673 );
3674 }
3675
3676 #[test]
3677 fn test_pole_ladder_zero_vector() {
3678 let m = 21;
3679 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
3680 let raw = vec![1.0; m];
3681 let norm = crate::warping::l2_norm_l2(&raw, &time);
3682 let from: Vec<f64> = raw.iter().map(|&v| v / norm).collect();
3683 let raw2: Vec<f64> = time
3684 .iter()
3685 .map(|&t| 1.0 + 0.2 * (2.0 * std::f64::consts::PI * t).sin())
3686 .collect();
3687 let norm2 = crate::warping::l2_norm_l2(&raw2, &time);
3688 let to: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();
3689
3690 let zero = vec![0.0; m];
3691 let result = parallel_transport_pole(&zero, &from, &to, &time);
3692 let result_norm: f64 = result.iter().map(|v| v * v).sum::<f64>().sqrt();
3693 assert!(
3694 result_norm < 1e-6,
3695 "Transporting zero should give zero, got norm {result_norm}"
3696 );
3697 }
3698
3699 #[test]
3700 fn test_schilds_preserves_norm() {
3701 let m = 51;
3702 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
3703 let raw = vec![1.0; m];
3704 let norm = crate::warping::l2_norm_l2(&raw, &time);
3705 let from: Vec<f64> = raw.iter().map(|&v| v / norm).collect();
3706 let raw2: Vec<f64> = time
3707 .iter()
3708 .map(|&t| 1.0 + 0.15 * (2.0 * std::f64::consts::PI * t).sin())
3709 .collect();
3710 let norm2 = crate::warping::l2_norm_l2(&raw2, &time);
3711 let to: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();
3712
3713 let v: Vec<f64> = time
3715 .iter()
3716 .map(|&t| 0.1 * (4.0 * std::f64::consts::PI * t).cos())
3717 .collect();
3718 let v_norm = crate::warping::l2_norm_l2(&v, &time);
3719
3720 let transported = parallel_transport_schilds(&v, &from, &to, &time);
3721 let t_norm = crate::warping::l2_norm_l2(&transported, &time);
3722
3723 assert!(
3725 (t_norm - v_norm).abs() / v_norm.max(1e-10) < 1.5,
3726 "Schild's should roughly preserve norm: original={v_norm:.4}, transported={t_norm:.4}"
3727 );
3728 }
3729
3730 #[test]
3731 fn test_tsrvf_logmap_matches_original() {
3732 let m = 50;
3733 let n = 5;
3734 let t = uniform_grid(m);
3735 let data = make_test_data(n, m, 42);
3736
3737 let result_orig = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
3738 let result_logmap =
3739 tsrvf_transform_with_method(&data, &t, 5, 1e-3, 0.0, TransportMethod::LogMap);
3740
3741 for i in 0..n {
3743 for j in 0..m {
3744 assert!(
3745 (result_orig.tangent_vectors[(i, j)] - result_logmap.tangent_vectors[(i, j)])
3746 .abs()
3747 < 1e-12,
3748 "LogMap variant should match original at ({i},{j})"
3749 );
3750 }
3751 }
3752 }
3753
3754 #[test]
3755 fn test_tsrvf_with_schilds_produces_valid_result() {
3756 let m = 50;
3757 let n = 5;
3758 let t = uniform_grid(m);
3759 let data = make_test_data(n, m, 42);
3760
3761 let result =
3762 tsrvf_transform_with_method(&data, &t, 5, 1e-3, 0.0, TransportMethod::SchildsLadder);
3763
3764 assert_eq!(result.tangent_vectors.shape(), (n, m));
3765 for i in 0..n {
3766 for j in 0..m {
3767 assert!(
3768 result.tangent_vectors[(i, j)].is_finite(),
3769 "Schild's TSRVF should produce finite tangent vectors at ({i},{j})"
3770 );
3771 }
3772 }
3773 }
3774
3775 #[test]
3776 fn test_transport_methods_differ() {
3777 let m = 50;
3778 let n = 5;
3779 let t = uniform_grid(m);
3780 let data = make_test_data(n, m, 42);
3781 let karcher = karcher_mean(&data, &t, 5, 1e-3, 0.0);
3782
3783 let r_log = tsrvf_from_alignment_with_method(&karcher, &t, TransportMethod::LogMap);
3784 let r_schilds =
3785 tsrvf_from_alignment_with_method(&karcher, &t, TransportMethod::SchildsLadder);
3786
3787 let mut total_diff = 0.0;
3789 for i in 0..n {
3790 for j in 0..m {
3791 total_diff +=
3792 (r_log.tangent_vectors[(i, j)] - r_schilds.tangent_vectors[(i, j)]).abs();
3793 }
3794 }
3795
3796 assert!(total_diff.is_finite());
3799 }
3800
3801 #[test]
3804 fn test_warp_complexity_identity_is_zero() {
3805 let m = 50;
3806 let t = uniform_grid(m);
3807 let identity = t.clone();
3808 let c = warp_complexity(&identity, &t);
3809 assert!(
3810 c < 1e-10,
3811 "Identity warp should have zero complexity, got {c}"
3812 );
3813 }
3814
3815 #[test]
3816 fn test_warp_complexity_nonidentity_positive() {
3817 let m = 50;
3818 let t = uniform_grid(m);
3819 let gamma: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
3820 let c = warp_complexity(&gamma, &t);
3821 assert!(
3822 c > 0.01,
3823 "Non-identity warp should have positive complexity, got {c}"
3824 );
3825 }
3826
3827 #[test]
3828 fn test_warp_smoothness_identity_is_zero() {
3829 let m = 50;
3830 let t = uniform_grid(m);
3831 let identity = t.clone();
3832 let s = warp_smoothness(&identity, &t);
3833 assert!(
3834 s < 1e-6,
3835 "Identity warp (constant γ'=1, γ''=0) should have near-zero bending energy, got {s}"
3836 );
3837 }
3838
3839 #[test]
3840 fn test_alignment_quality_basic() {
3841 let m = 50;
3842 let n = 8;
3843 let t = uniform_grid(m);
3844 let data = make_test_data(n, m, 42);
3845 let karcher = karcher_mean(&data, &t, 10, 1e-3, 0.0);
3846 let quality = alignment_quality(&data, &karcher, &t);
3847
3848 assert_eq!(quality.warp_complexity.len(), n);
3850 assert_eq!(quality.warp_smoothness.len(), n);
3851 assert_eq!(quality.pointwise_variance_ratio.len(), m);
3852
3853 assert!(quality.total_variance >= 0.0);
3855 assert!(quality.amplitude_variance >= 0.0);
3856 assert!(quality.phase_variance >= 0.0);
3857 assert!(quality.mean_warp_complexity >= 0.0);
3858 assert!(quality.mean_warp_smoothness >= 0.0);
3859
3860 assert!(
3862 quality.amplitude_variance <= quality.total_variance + 1e-10,
3863 "Amplitude variance ({}) should be ≤ total variance ({})",
3864 quality.amplitude_variance,
3865 quality.total_variance
3866 );
3867 }
3868
3869 #[test]
3870 fn test_alignment_quality_identical_curves() {
3871 let m = 50;
3872 let t = uniform_grid(m);
3873 let curve: Vec<f64> = t
3874 .iter()
3875 .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3876 .collect();
3877 let mut col_major = vec![0.0; 5 * m];
3878 for i in 0..5 {
3879 for j in 0..m {
3880 col_major[i + j * 5] = curve[j];
3881 }
3882 }
3883 let data = FdMatrix::from_column_major(col_major, 5, m).unwrap();
3884 let karcher = karcher_mean(&data, &t, 5, 1e-3, 0.0);
3885 let quality = alignment_quality(&data, &karcher, &t);
3886
3887 assert!(
3889 quality.total_variance < 0.01,
3890 "Identical curves should have near-zero total variance, got {}",
3891 quality.total_variance
3892 );
3893 assert!(
3894 quality.mean_warp_complexity < 0.1,
3895 "Identical curves should have near-zero warp complexity, got {}",
3896 quality.mean_warp_complexity
3897 );
3898 }
3899
3900 #[test]
3901 fn test_alignment_quality_variance_reduction() {
3902 let m = 50;
3903 let n = 10;
3904 let t = uniform_grid(m);
3905 let data = make_test_data(n, m, 42);
3906 let karcher = karcher_mean(&data, &t, 10, 1e-3, 0.0);
3907 let quality = alignment_quality(&data, &karcher, &t);
3908
3909 assert!(
3911 quality.mean_variance_reduction <= 1.5,
3912 "Mean variance reduction ratio should be ≤ ~1, got {}",
3913 quality.mean_variance_reduction
3914 );
3915 }
3916
3917 #[test]
3918 fn test_pairwise_consistency_small() {
3919 let m = 40;
3920 let n = 4;
3921 let t = uniform_grid(m);
3922 let data = make_test_data(n, m, 42);
3923
3924 let consistency = pairwise_consistency(&data, &t, 0.0, 100);
3925 assert!(
3926 consistency.is_finite() && consistency >= 0.0,
3927 "Pairwise consistency should be finite and non-negative, got {consistency}"
3928 );
3929 }
3930
3931 #[test]
3934 fn test_srsf_nd_d1_matches_existing() {
3935 let m = 50;
3936 let t = uniform_grid(m);
3937 let data = make_test_data(3, m, 42);
3938
3939 let q_1d = srsf_transform(&data, &t);
3941
3942 let data_nd = FdCurveSet::from_1d(data);
3944 let q_nd = srsf_transform_nd(&data_nd, &t);
3945
3946 assert_eq!(q_nd.ndim(), 1);
3947 for i in 0..3 {
3948 for j in 0..m {
3949 assert!(
3950 (q_1d[(i, j)] - q_nd.dims[0][(i, j)]).abs() < 1e-10,
3951 "1D nd SRSF should match existing at ({i},{j}): {} vs {}",
3952 q_1d[(i, j)],
3953 q_nd.dims[0][(i, j)]
3954 );
3955 }
3956 }
3957 }
3958
3959 #[test]
3960 fn test_srsf_nd_constant_is_zero() {
3961 let m = 30;
3962 let t = uniform_grid(m);
3963 let dim0 = FdMatrix::from_column_major(vec![3.0; m], 1, m).unwrap();
3965 let dim1 = FdMatrix::from_column_major(vec![-1.0; m], 1, m).unwrap();
3966 let data = FdCurveSet::from_dims(vec![dim0, dim1]).unwrap();
3967
3968 let q = srsf_transform_nd(&data, &t);
3969 for k in 0..2 {
3970 for j in 0..m {
3971 assert!(
3972 q.dims[k][(0, j)].abs() < 1e-10,
3973 "Constant curve SRSF should be zero, dim {k} at {j}: {}",
3974 q.dims[k][(0, j)]
3975 );
3976 }
3977 }
3978 }
3979
3980 #[test]
3981 fn test_srsf_nd_linear_r2() {
3982 let m = 51;
3983 let t = uniform_grid(m);
3984 let dim0 =
3987 FdMatrix::from_slice(&t.iter().map(|&ti| 2.0 * ti).collect::<Vec<_>>(), 1, m).unwrap();
3988 let dim1 =
3989 FdMatrix::from_slice(&t.iter().map(|&ti| 3.0 * ti).collect::<Vec<_>>(), 1, m).unwrap();
3990 let data = FdCurveSet::from_dims(vec![dim0, dim1]).unwrap();
3991
3992 let q = srsf_transform_nd(&data, &t);
3993 let expected_scale = 1.0 / 13.0_f64.powf(0.25);
3994 let mid = m / 2;
3995
3996 assert!(
3997 (q.dims[0][(0, mid)] - 2.0 * expected_scale).abs() < 0.1,
3998 "q_x at midpoint: {} vs expected {}",
3999 q.dims[0][(0, mid)],
4000 2.0 * expected_scale
4001 );
4002 assert!(
4003 (q.dims[1][(0, mid)] - 3.0 * expected_scale).abs() < 0.1,
4004 "q_y at midpoint: {} vs expected {}",
4005 q.dims[1][(0, mid)],
4006 3.0 * expected_scale
4007 );
4008 }
4009
4010 #[test]
4011 fn test_srsf_nd_round_trip() {
4012 let m = 51;
4013 let t = uniform_grid(m);
4014 let pi2 = 2.0 * std::f64::consts::PI;
4016 let vals_x: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).sin()).collect();
4017 let vals_y: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).cos()).collect();
4018 let dim0 = FdMatrix::from_slice(&vals_x, 1, m).unwrap();
4019 let dim1 = FdMatrix::from_slice(&vals_y, 1, m).unwrap();
4020 let data = FdCurveSet::from_dims(vec![dim0, dim1]).unwrap();
4021
4022 let q = srsf_transform_nd(&data, &t);
4023 let q_vecs: Vec<Vec<f64>> = q.dims.iter().map(|dm| dm.row(0)).collect();
4024 let f0 = vec![vals_x[0], vals_y[0]];
4025 let recon = srsf_inverse_nd(&q_vecs, &t, &f0);
4026
4027 let mut max_err = 0.0_f64;
4029 for k in 0..2 {
4030 let orig = if k == 0 { &vals_x } else { &vals_y };
4031 for j in 2..(m - 2) {
4032 let err = (recon[k][j] - orig[j]).abs();
4033 max_err = max_err.max(err);
4034 }
4035 }
4036 assert!(
4037 max_err < 0.2,
4038 "SRSF round-trip max error should be small, got {max_err}"
4039 );
4040 }
4041
4042 #[test]
4043 fn test_align_nd_identical_near_zero() {
4044 let m = 50;
4045 let t = uniform_grid(m);
4046 let pi2 = 2.0 * std::f64::consts::PI;
4047 let vals_x: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).sin()).collect();
4048 let vals_y: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).cos()).collect();
4049 let dim0 = FdMatrix::from_slice(&vals_x, 1, m).unwrap();
4050 let dim1 = FdMatrix::from_slice(&vals_y, 1, m).unwrap();
4051 let data = FdCurveSet::from_dims(vec![dim0, dim1]).unwrap();
4052
4053 let result = elastic_align_pair_nd(&data, &data, &t, 0.0);
4054 assert!(
4055 result.distance < 0.5,
4056 "Self-alignment distance should be ~0, got {}",
4057 result.distance
4058 );
4059 let max_dev: f64 = result
4061 .gamma
4062 .iter()
4063 .zip(t.iter())
4064 .map(|(g, ti)| (g - ti).abs())
4065 .fold(0.0_f64, f64::max);
4066 assert!(
4067 max_dev < 0.1,
4068 "Self-alignment warp should be near identity, max dev = {max_dev}"
4069 );
4070 }
4071
4072 #[test]
4073 fn test_align_nd_shifted_r2() {
4074 let m = 60;
4075 let t = uniform_grid(m);
4076 let pi2 = 2.0 * std::f64::consts::PI;
4077
4078 let f1x: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).sin()).collect();
4080 let f1y: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).cos()).collect();
4081 let f1 = FdCurveSet::from_dims(vec![
4082 FdMatrix::from_slice(&f1x, 1, m).unwrap(),
4083 FdMatrix::from_slice(&f1y, 1, m).unwrap(),
4084 ])
4085 .unwrap();
4086
4087 let f2x: Vec<f64> = t.iter().map(|&ti| (pi2 * (ti - 0.1)).sin()).collect();
4089 let f2y: Vec<f64> = t.iter().map(|&ti| (pi2 * (ti - 0.1)).cos()).collect();
4090 let f2 = FdCurveSet::from_dims(vec![
4091 FdMatrix::from_slice(&f2x, 1, m).unwrap(),
4092 FdMatrix::from_slice(&f2y, 1, m).unwrap(),
4093 ])
4094 .unwrap();
4095
4096 let result = elastic_align_pair_nd(&f1, &f2, &t, 0.0);
4097 assert!(
4098 result.distance.is_finite(),
4099 "Distance should be finite, got {}",
4100 result.distance
4101 );
4102 assert_eq!(result.f_aligned.len(), 2);
4103 assert_eq!(result.f_aligned[0].len(), m);
4104 let max_dev: f64 = result
4106 .gamma
4107 .iter()
4108 .zip(t.iter())
4109 .map(|(g, ti)| (g - ti).abs())
4110 .fold(0.0_f64, f64::max);
4111 assert!(
4112 max_dev > 0.01,
4113 "Shifted curves should require non-trivial warp, max dev = {max_dev}"
4114 );
4115 }
4116
4117 #[test]
4120 fn test_constrained_no_landmarks_matches_unconstrained() {
4121 let m = 50;
4122 let t = uniform_grid(m);
4123 let f1: Vec<f64> = t
4124 .iter()
4125 .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
4126 .collect();
4127 let f2: Vec<f64> = t
4128 .iter()
4129 .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
4130 .collect();
4131
4132 let r_free = elastic_align_pair(&f1, &f2, &t, 0.0);
4133 let r_const = elastic_align_pair_constrained(&f1, &f2, &t, &[], 0.0);
4134
4135 for j in 0..m {
4137 assert!(
4138 (r_free.gamma[j] - r_const.gamma[j]).abs() < 1e-10,
4139 "No-landmark constrained should match unconstrained at {j}"
4140 );
4141 }
4142 assert!(r_const.enforced_landmarks.is_empty());
4143 }
4144
4145 #[test]
4146 fn test_constrained_single_landmark_enforced() {
4147 let m = 60;
4148 let t = uniform_grid(m);
4149 let f1: Vec<f64> = t
4150 .iter()
4151 .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
4152 .collect();
4153 let f2: Vec<f64> = t
4154 .iter()
4155 .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
4156 .collect();
4157
4158 let result = elastic_align_pair_constrained(&f1, &f2, &t, &[(0.5, 0.5)], 0.0);
4160
4161 let mid_idx = snap_to_grid(0.5, &t);
4163 assert!(
4164 (result.gamma[mid_idx] - 0.5).abs() < 0.05,
4165 "Constrained gamma at midpoint should be ~0.5, got {}",
4166 result.gamma[mid_idx]
4167 );
4168 assert_eq!(result.enforced_landmarks.len(), 1);
4169 }
4170
4171 #[test]
4172 fn test_constrained_multiple_landmarks() {
4173 let m = 80;
4174 let t = uniform_grid(m);
4175 let f1: Vec<f64> = t
4176 .iter()
4177 .map(|&ti| (4.0 * std::f64::consts::PI * ti).sin())
4178 .collect();
4179 let f2: Vec<f64> = t
4180 .iter()
4181 .map(|&ti| (4.0 * std::f64::consts::PI * (ti - 0.05)).sin())
4182 .collect();
4183
4184 let landmarks = vec![(0.25, 0.25), (0.5, 0.5), (0.75, 0.75)];
4185 let result = elastic_align_pair_constrained(&f1, &f2, &t, &landmarks, 0.0);
4186
4187 for &(tt, st) in &landmarks {
4189 let idx = snap_to_grid(tt, &t);
4190 assert!(
4191 (result.gamma[idx] - st).abs() < 0.05,
4192 "Gamma at t={tt} should be ~{st}, got {}",
4193 result.gamma[idx]
4194 );
4195 }
4196 }
4197
4198 #[test]
4199 fn test_constrained_monotone_gamma() {
4200 let m = 60;
4201 let t = uniform_grid(m);
4202 let f1: Vec<f64> = t
4203 .iter()
4204 .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
4205 .collect();
4206 let f2: Vec<f64> = t
4207 .iter()
4208 .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
4209 .collect();
4210
4211 let result = elastic_align_pair_constrained(&f1, &f2, &t, &[(0.3, 0.3), (0.7, 0.7)], 0.0);
4212
4213 for j in 1..m {
4215 assert!(
4216 result.gamma[j] >= result.gamma[j - 1] - 1e-10,
4217 "Gamma should be monotone: gamma[{}]={} < gamma[{}]={}",
4218 j,
4219 result.gamma[j],
4220 j - 1,
4221 result.gamma[j - 1]
4222 );
4223 }
4224 assert!((result.gamma[0] - t[0]).abs() < 1e-10);
4226 assert!((result.gamma[m - 1] - t[m - 1]).abs() < 1e-10);
4227 }
4228
4229 #[test]
4230 fn test_constrained_distance_ge_unconstrained() {
4231 let m = 60;
4232 let t = uniform_grid(m);
4233 let f1: Vec<f64> = t
4234 .iter()
4235 .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
4236 .collect();
4237 let f2: Vec<f64> = t
4238 .iter()
4239 .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.15)).sin())
4240 .collect();
4241
4242 let r_free = elastic_align_pair(&f1, &f2, &t, 0.0);
4243 let r_const = elastic_align_pair_constrained(&f1, &f2, &t, &[(0.5, 0.5)], 0.0);
4244
4245 assert!(
4247 r_const.distance >= r_free.distance - 1e-6,
4248 "Constrained distance ({}) should be >= unconstrained ({})",
4249 r_const.distance,
4250 r_free.distance
4251 );
4252 }
4253
4254 #[test]
4255 fn test_constrained_with_landmark_detection() {
4256 let m = 80;
4257 let t = uniform_grid(m);
4258 let f1: Vec<f64> = t
4259 .iter()
4260 .map(|&ti| (4.0 * std::f64::consts::PI * ti).sin())
4261 .collect();
4262 let f2: Vec<f64> = t
4263 .iter()
4264 .map(|&ti| (4.0 * std::f64::consts::PI * (ti - 0.05)).sin())
4265 .collect();
4266
4267 let result = elastic_align_pair_with_landmarks(
4268 &f1,
4269 &f2,
4270 &t,
4271 crate::landmark::LandmarkKind::Peak,
4272 0.1,
4273 0,
4274 0.0,
4275 );
4276
4277 assert_eq!(result.gamma.len(), m);
4278 assert_eq!(result.f_aligned.len(), m);
4279 assert!(result.distance.is_finite());
4280 for j in 1..m {
4282 assert!(
4283 result.gamma[j] >= result.gamma[j - 1] - 1e-10,
4284 "Gamma should be monotone at j={j}"
4285 );
4286 }
4287 }
4288}