1use crate::fdata::{deriv_1d, mean_1d};
15use crate::helpers::{
16 cumulative_trapz, gradient_uniform, l2_distance, linear_interp, simpsons_weights,
17};
18use crate::iter_maybe_parallel;
19use crate::matrix::FdMatrix;
20use crate::smoothing::nadaraya_watson;
21use crate::warping::{
22 exp_map_sphere, gam_to_psi, inv_exp_map_sphere, invert_gamma, l2_norm_l2, normalize_warp,
23 psi_to_gam,
24};
25#[cfg(feature = "parallel")]
26use rayon::iter::ParallelIterator;
27
/// Outcome of aligning one curve to another, as built by `elastic_align_pair`.
#[derive(Debug, Clone)]
pub struct AlignmentResult {
    /// Warping function evaluated on the argument grid (output of the DP alignment).
    pub gamma: Vec<f64>,
    /// The second input curve reparameterized by `gamma`.
    pub f_aligned: Vec<f64>,
    /// Weighted L2 distance between the SRSF of the first curve and the SRSF of
    /// `f_aligned`, using Simpson weights on the argument grid.
    pub distance: f64,
}
40
/// Outcome of aligning every row of a data matrix to one target curve
/// (see `align_to_target`).
#[derive(Debug, Clone)]
pub struct AlignmentSetResult {
    /// Row `i` holds the warping function found for curve `i`.
    pub gammas: FdMatrix,
    /// Row `i` holds curve `i` after reparameterization by its warp.
    pub aligned_data: FdMatrix,
    /// Entry `i` is the alignment distance reported for curve `i`.
    pub distances: Vec<f64>,
}
51
/// Output of `karcher_mean`.
#[derive(Debug, Clone)]
pub struct KarcherMeanResult {
    /// Mean curve, rebuilt from the centered mean SRSF via `srsf_inverse`.
    pub mean: Vec<f64>,
    /// Centered mean in SRSF space.
    pub mean_srsf: Vec<f64>,
    /// Row `i` is the (centered) warping function of curve `i`.
    pub gammas: FdMatrix,
    /// Input curves reparameterized by the rows of `gammas`.
    pub aligned_data: FdMatrix,
    /// Number of alignment iterations that were executed.
    pub n_iter: usize,
    /// `true` when the relative change of the mean SRSF dropped below the tolerance.
    pub converged: bool,
    /// Optional aligned SRSFs; `karcher_mean` in this file always stores `None`.
    pub aligned_srsfs: Option<FdMatrix>,
}
71
72fn karcher_sphere_step(mu: &mut Vec<f64>, psis: &[Vec<f64>], time: &[f64], step_size: f64) -> bool {
85 let m = mu.len();
86 let n = psis.len();
87 let mut vbar = vec![0.0; m];
88 for psi in psis {
89 let v = inv_exp_map_sphere(mu, psi, time);
90 for j in 0..m {
91 vbar[j] += v[j];
92 }
93 }
94 for j in 0..m {
95 vbar[j] /= n as f64;
96 }
97 if l2_norm_l2(&vbar, time) <= 1e-8 {
98 return true;
99 }
100 let scaled: Vec<f64> = vbar.iter().map(|&v| v * step_size).collect();
101 *mu = exp_map_sphere(mu, &scaled, time);
102 false
103}
104
105pub(crate) fn sqrt_mean_inverse(gammas: &FdMatrix, argvals: &[f64]) -> Vec<f64> {
106 let (n, m) = gammas.shape();
107 let t0 = argvals[0];
108 let t1 = argvals[m - 1];
109 let domain = t1 - t0;
110
111 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
112 let binsize = 1.0 / (m - 1) as f64;
113
114 let psis: Vec<Vec<f64>> = (0..n)
115 .map(|i| {
116 let gam_01: Vec<f64> = (0..m).map(|j| (gammas[(i, j)] - t0) / domain).collect();
117 gam_to_psi(&gam_01, binsize)
118 })
119 .collect();
120
121 let mut mu = vec![0.0; m];
122 for psi in &psis {
123 for j in 0..m {
124 mu[j] += psi[j];
125 }
126 }
127 for j in 0..m {
128 mu[j] /= n as f64;
129 }
130
131 for _ in 0..501 {
132 if karcher_sphere_step(&mut mu, &psis, &time, 0.3) {
133 break;
134 }
135 }
136
137 let gam_mu = psi_to_gam(&mu, &time);
138 let gam_inv = invert_gamma(&gam_mu, &time);
139 gam_inv.iter().map(|&g| t0 + g * domain).collect()
140}
141
142pub fn srsf_transform(data: &FdMatrix, argvals: &[f64]) -> FdMatrix {
155 let (n, m) = data.shape();
156 if n == 0 || m == 0 || argvals.len() != m {
157 return FdMatrix::zeros(n, m);
158 }
159
160 let deriv = deriv_1d(data, argvals, 1);
161
162 let mut result = FdMatrix::zeros(n, m);
163 for i in 0..n {
164 for j in 0..m {
165 let d = deriv[(i, j)];
166 result[(i, j)] = d.signum() * d.abs().sqrt();
167 }
168 }
169 result
170}
171
172pub fn srsf_inverse(q: &[f64], argvals: &[f64], f0: f64) -> Vec<f64> {
184 let m = q.len();
185 if m == 0 {
186 return Vec::new();
187 }
188
189 let integrand: Vec<f64> = q.iter().map(|&qi| qi * qi.abs()).collect();
191 let integral = cumulative_trapz(&integrand, argvals);
192
193 integral.iter().map(|&v| f0 + v).collect()
194}
195
196pub fn reparameterize_curve(f: &[f64], argvals: &[f64], gamma: &[f64]) -> Vec<f64> {
207 gamma
208 .iter()
209 .map(|&g| linear_interp(argvals, f, g))
210 .collect()
211}
212
213pub fn compose_warps(gamma1: &[f64], gamma2: &[f64], argvals: &[f64]) -> Vec<f64> {
220 gamma2
221 .iter()
222 .map(|&g| linear_interp(argvals, gamma1, g))
223 .collect()
224}
225
/// Greatest common divisor by the Euclidean algorithm (test-only helper).
///
/// `gcd(a, 0)` is `a`, so `gcd(0, 0)` is `0`.
#[cfg(test)]
fn gcd(a: usize, b: usize) -> usize {
    let mut x = a;
    let mut y = b;
    while y != 0 {
        let remainder = x % y;
        x = y;
        y = remainder;
    }
    x
}
238
/// Lists every pair `(i, j)` with `1 <= i, j <= nbhd_dim` and `gcd(i, j) == 1`,
/// ordered by `i` and then by `j` (test-only helper).
#[cfg(test)]
fn generate_coprime_nbhd(nbhd_dim: usize) -> Vec<(usize, usize)> {
    (1..=nbhd_dim)
        .flat_map(|i| (1..=nbhd_dim).map(move |j| (i, j)))
        .filter(|&(i, j)| gcd(i, j) == 1)
        .collect()
}
253
/// Step offsets `(dr, dc)` tried by `dp_relax_cell`: all 35 pairs with both components
/// in `1..=7` whose greatest common divisor is 1. Column gaps mark the omitted
/// non-coprime pairs.
#[rustfmt::skip]
const COPRIME_NBHD_7: [(usize, usize); 35] = [
    (1,1),(1,2),(1,3),(1,4),(1,5),(1,6),(1,7),
    (2,1),      (2,3),      (2,5),      (2,7),
    (3,1),(3,2),      (3,4),(3,5),      (3,7),
    (4,1),      (4,3),      (4,5),      (4,7),
    (5,1),(5,2),(5,3),(5,4),      (5,6),(5,7),
    (6,1),                  (6,5),      (6,7),
    (7,1),(7,2),(7,3),(7,4),(7,5),(7,6),
];
267
/// Cost of one dynamic-programming edge from grid node `(sr, sc)` to `(tr, tc)`.
///
/// The column interval `[sc, tc]` and the row interval `[sr, tr]` are both mapped to
/// `[0, 1]` and split into `tc - sc` and `tr - sr` equal pieces. For every overlap of
/// a `q1` piece with a `q2` piece the squared difference
/// `(q1 - sqrt(slope) * q2)^2` times the overlap length is accumulated, where
/// `slope` is the edge slope in `argvals` units. The sum is scaled by the column
/// span `argvals[tc] - argvals[sc]`.
///
/// Returns `f64::INFINITY` for an edge that spans no columns or no rows.
#[inline]
fn dp_edge_weight(
    q1: &[f64],
    q2: &[f64],
    argvals: &[f64],
    sc: usize,
    tc: usize,
    sr: usize,
    tr: usize,
) -> f64 {
    use std::cmp::Ordering;

    let cols = tc - sc;
    let rows = tr - sr;
    if cols == 0 || rows == 0 {
        return f64::INFINITY;
    }

    let row_span = argvals[tr] - argvals[sr];
    let col_span = argvals[tc] - argvals[sc];
    let sqrt_slope = (row_span / col_span).sqrt();
    let (cols_f, rows_f) = (cols as f64, rows as f64);

    let mut acc = 0.0;
    let (mut a, mut b) = (0usize, 0usize);
    while a < cols && b < rows {
        // Normalised extents of the current q1 piece (a) and q2 piece (b).
        let start_a = a as f64 / cols_f;
        let end_a = (a + 1) as f64 / cols_f;
        let start_b = b as f64 / rows_f;
        let end_b = (b + 1) as f64 / rows_f;

        let overlap = end_a.min(end_b) - start_a.max(start_b);
        if overlap > 0.0 {
            let gap = q1[sc + a] - sqrt_slope * q2[sr + b];
            acc += gap * gap * overlap;
        }

        // Advance whichever piece ends first; advance both when they end together.
        match end_a.partial_cmp(&end_b) {
            Some(Ordering::Less) => a += 1,
            Some(Ordering::Greater) => b += 1,
            _ => {
                a += 1;
                b += 1;
            }
        }
    }

    acc * col_span
}
329
/// Regularisation term for one dynamic-programming edge.
///
/// For `lambda > 0` returns `lambda * (slope - 1)^2 * dt`, where `dt` is the column
/// span `argvals[tc] - argvals[sc]` and `slope` is the row span divided by `dt`.
/// For any other `lambda` (zero, negative, NaN) returns `0.0` without touching `argvals`.
#[inline]
fn dp_lambda_penalty(
    argvals: &[f64],
    sc: usize,
    tc: usize,
    sr: usize,
    tr: usize,
    lambda: f64,
) -> f64 {
    if lambda <= 0.0 || lambda.is_nan() {
        return 0.0;
    }
    let width = argvals[tc] - argvals[sc];
    let slope = (argvals[tr] - argvals[sr]) / width;
    lambda * (slope - 1.0).powi(2) * width
}
348
/// Follows `parent` links from the bottom-right cell of an `nrows x ncols` grid
/// (stored row-major) back towards the origin and returns the visited cells as
/// `(row, col)` pairs in forward order.
///
/// The walk stops at cell 0 or at the first cell whose parent is `u32::MAX`.
fn dp_traceback(parent: &[u32], nrows: usize, ncols: usize) -> Vec<(usize, usize)> {
    let mut path = Vec::with_capacity(nrows + ncols);
    let mut cell = (nrows - 1) * ncols + (ncols - 1);
    path.push((cell / ncols, cell % ncols));
    while cell != 0 && parent[cell] != u32::MAX {
        cell = parent[cell] as usize;
        path.push((cell / ncols, cell % ncols));
    }
    path.reverse();
    path
}
365
366#[inline]
368fn dp_relax_cell<F>(
369 e: &mut [f64],
370 parent: &mut [u32],
371 ncols: usize,
372 tr: usize,
373 tc: usize,
374 edge_cost: &F,
375) where
376 F: Fn(usize, usize, usize, usize) -> f64,
377{
378 let idx = tr * ncols + tc;
379 for &(dr, dc) in &COPRIME_NBHD_7 {
380 if dr > tr || dc > tc {
381 continue;
382 }
383 let sr = tr - dr;
384 let sc = tc - dc;
385 let src_idx = sr * ncols + sc;
386 if e[src_idx] == f64::INFINITY {
387 continue;
388 }
389 let cost = e[src_idx] + edge_cost(sr, sc, tr, tc);
390 if cost < e[idx] {
391 e[idx] = cost;
392 parent[idx] = src_idx as u32;
393 }
394 }
395}
396
397fn dp_grid_solve<F>(nrows: usize, ncols: usize, edge_cost: F) -> Vec<(usize, usize)>
403where
404 F: Fn(usize, usize, usize, usize) -> f64,
405{
406 let mut e = vec![f64::INFINITY; nrows * ncols];
407 let mut parent = vec![u32::MAX; nrows * ncols];
408 e[0] = 0.0;
409
410 for tr in 0..nrows {
411 for tc in 0..ncols {
412 if tr == 0 && tc == 0 {
413 continue;
414 }
415 dp_relax_cell(&mut e, &mut parent, ncols, tr, tc, &edge_cost);
416 }
417 }
418
419 dp_traceback(&parent, nrows, ncols)
420}
421
422fn dp_path_to_gamma(path: &[(usize, usize)], argvals: &[f64]) -> Vec<f64> {
424 let path_tc: Vec<f64> = path.iter().map(|&(_, c)| argvals[c]).collect();
425 let path_tr: Vec<f64> = path.iter().map(|&(r, _)| argvals[r]).collect();
426 let mut gamma: Vec<f64> = argvals
427 .iter()
428 .map(|&t| linear_interp(&path_tc, &path_tr, t))
429 .collect();
430 normalize_warp(&mut gamma, argvals);
431 gamma
432}
433
434pub(crate) fn dp_alignment_core(q1: &[f64], q2: &[f64], argvals: &[f64], lambda: f64) -> Vec<f64> {
440 let m = argvals.len();
441 if m < 2 {
442 return argvals.to_vec();
443 }
444
445 let norm1 = q1.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
446 let norm2 = q2.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
447 let q1n: Vec<f64> = q1.iter().map(|&v| v / norm1).collect();
448 let q2n: Vec<f64> = q2.iter().map(|&v| v / norm2).collect();
449
450 let path = dp_grid_solve(m, m, |sr, sc, tr, tc| {
451 dp_edge_weight(&q1n, &q2n, argvals, sc, tc, sr, tr)
452 + dp_lambda_penalty(argvals, sc, tc, sr, tr, lambda)
453 });
454
455 dp_path_to_gamma(&path, argvals)
456}
457
458pub fn elastic_align_pair(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> AlignmentResult {
474 let m = f1.len();
475
476 let f1_mat = FdMatrix::from_slice(f1, 1, m).unwrap();
478 let f2_mat = FdMatrix::from_slice(f2, 1, m).unwrap();
479
480 let q1_mat = srsf_transform(&f1_mat, argvals);
481 let q2_mat = srsf_transform(&f2_mat, argvals);
482
483 let q1: Vec<f64> = q1_mat.row(0);
484 let q2: Vec<f64> = q2_mat.row(0);
485
486 let gamma = dp_alignment_core(&q1, &q2, argvals, lambda);
488
489 let f_aligned = reparameterize_curve(f2, argvals, &gamma);
491
492 let f_aligned_mat = FdMatrix::from_slice(&f_aligned, 1, m).unwrap();
494 let q_aligned_mat = srsf_transform(&f_aligned_mat, argvals);
495 let q_aligned: Vec<f64> = q_aligned_mat.row(0);
496
497 let weights = simpsons_weights(argvals);
498 let distance = l2_distance(&q1, &q_aligned, &weights);
499
500 AlignmentResult {
501 gamma,
502 f_aligned,
503 distance,
504 }
505}
506
507pub fn elastic_distance(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> f64 {
517 elastic_align_pair(f1, f2, argvals, lambda).distance
518}
519
520pub fn align_to_target(
531 data: &FdMatrix,
532 target: &[f64],
533 argvals: &[f64],
534 lambda: f64,
535) -> AlignmentSetResult {
536 let (n, m) = data.shape();
537
538 let results: Vec<AlignmentResult> = iter_maybe_parallel!(0..n)
539 .map(|i| {
540 let fi = data.row(i);
541 elastic_align_pair(target, &fi, argvals, lambda)
542 })
543 .collect();
544
545 let mut gammas = FdMatrix::zeros(n, m);
546 let mut aligned_data = FdMatrix::zeros(n, m);
547 let mut distances = Vec::with_capacity(n);
548
549 for (i, r) in results.into_iter().enumerate() {
550 for j in 0..m {
551 gammas[(i, j)] = r.gamma[j];
552 aligned_data[(i, j)] = r.f_aligned[j];
553 }
554 distances.push(r.distance);
555 }
556
557 AlignmentSetResult {
558 gammas,
559 aligned_data,
560 distances,
561 }
562}
563
564pub fn elastic_self_distance_matrix(data: &FdMatrix, argvals: &[f64], lambda: f64) -> FdMatrix {
579 let n = data.nrows();
580
581 let upper_vals: Vec<f64> = iter_maybe_parallel!(0..n)
582 .flat_map(|i| {
583 let fi = data.row(i);
584 ((i + 1)..n)
585 .map(|j| {
586 let fj = data.row(j);
587 elastic_distance(&fi, &fj, argvals, lambda)
588 })
589 .collect::<Vec<_>>()
590 })
591 .collect();
592
593 let mut dist = FdMatrix::zeros(n, n);
594 let mut idx = 0;
595 for i in 0..n {
596 for j in (i + 1)..n {
597 let d = upper_vals[idx];
598 dist[(i, j)] = d;
599 dist[(j, i)] = d;
600 idx += 1;
601 }
602 }
603 dist
604}
605
606pub fn elastic_cross_distance_matrix(
617 data1: &FdMatrix,
618 data2: &FdMatrix,
619 argvals: &[f64],
620 lambda: f64,
621) -> FdMatrix {
622 let n1 = data1.nrows();
623 let n2 = data2.nrows();
624
625 let vals: Vec<f64> = iter_maybe_parallel!(0..n1)
626 .flat_map(|i| {
627 let fi = data1.row(i);
628 (0..n2)
629 .map(|j| {
630 let fj = data2.row(j);
631 elastic_distance(&fi, &fj, argvals, lambda)
632 })
633 .collect::<Vec<_>>()
634 })
635 .collect();
636
637 let mut dist = FdMatrix::zeros(n1, n2);
638 for i in 0..n1 {
639 for j in 0..n2 {
640 dist[(i, j)] = vals[i * n2 + j];
641 }
642 }
643 dist
644}
645
/// Amplitude/phase decomposition of the difference between two curves
/// (see `elastic_decomposition`).
#[derive(Debug, Clone)]
pub struct DecompositionResult {
    /// Full pairwise alignment the two distances were derived from.
    pub alignment: AlignmentResult,
    /// Amplitude distance: the `distance` field of `alignment`.
    pub d_amplitude: f64,
    /// Phase distance: `crate::warping::phase_distance` of the alignment warp.
    pub d_phase: f64,
}
658
659pub fn elastic_decomposition(
669 f1: &[f64],
670 f2: &[f64],
671 argvals: &[f64],
672 lambda: f64,
673) -> DecompositionResult {
674 let alignment = elastic_align_pair(f1, f2, argvals, lambda);
675 let d_amplitude = alignment.distance;
676 let d_phase = crate::warping::phase_distance(&alignment.gamma, argvals);
677 DecompositionResult {
678 alignment,
679 d_amplitude,
680 d_phase,
681 }
682}
683
684pub fn amplitude_distance(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> f64 {
686 elastic_distance(f1, f2, argvals, lambda)
687}
688
689pub fn phase_distance_pair(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> f64 {
691 let alignment = elastic_align_pair(f1, f2, argvals, lambda);
692 crate::warping::phase_distance(&alignment.gamma, argvals)
693}
694
695pub fn phase_self_distance_matrix(data: &FdMatrix, argvals: &[f64], lambda: f64) -> FdMatrix {
697 let n = data.nrows();
698
699 let upper_vals: Vec<f64> = iter_maybe_parallel!(0..n)
700 .flat_map(|i| {
701 let fi = data.row(i);
702 ((i + 1)..n)
703 .map(|j| {
704 let fj = data.row(j);
705 phase_distance_pair(&fi, &fj, argvals, lambda)
706 })
707 .collect::<Vec<_>>()
708 })
709 .collect();
710
711 let mut dist = FdMatrix::zeros(n, n);
712 let mut idx = 0;
713 for i in 0..n {
714 for j in (i + 1)..n {
715 let d = upper_vals[idx];
716 dist[(i, j)] = d;
717 dist[(j, i)] = d;
718 idx += 1;
719 }
720 }
721 dist
722}
723
/// Pairwise amplitude distance matrix for the rows of `data`.
///
/// Thin alias: forwards unchanged to `elastic_self_distance_matrix`.
pub fn amplitude_self_distance_matrix(data: &FdMatrix, argvals: &[f64], lambda: f64) -> FdMatrix {
    elastic_self_distance_matrix(data, argvals, lambda)
}
728
/// Relative change between two iterates: the Euclidean norm of `q_old - q_new`
/// divided by the Euclidean norm of `q_old` (floored at `1e-10`).
///
/// Pairs are formed with `zip`, so extra trailing entries of the longer slice do not
/// enter the difference.
fn relative_change(q_old: &[f64], q_new: &[f64]) -> f64 {
    let step_sq: f64 = q_old
        .iter()
        .zip(q_new)
        .map(|(&before, &after)| (before - after).powi(2))
        .sum();
    let base_sq: f64 = q_old.iter().map(|&v| v * v).sum();
    step_sq.sqrt() / base_sq.sqrt().max(1e-10)
}
745
746fn srsf_single(f: &[f64], argvals: &[f64]) -> Vec<f64> {
748 let m = f.len();
749 let mat = FdMatrix::from_slice(f, 1, m).unwrap();
750 let q_mat = srsf_transform(&mat, argvals);
751 q_mat.row(0)
752}
753
754fn align_srsf_pair(q1: &[f64], q2: &[f64], argvals: &[f64], lambda: f64) -> (Vec<f64>, Vec<f64>) {
756 let gamma = dp_alignment_core(q1, q2, argvals, lambda);
757
758 let q2_warped = reparameterize_curve(q2, argvals, &gamma);
760
761 let m = gamma.len();
763 let mut gamma_dot = vec![0.0; m];
764 gamma_dot[0] = (gamma[1] - gamma[0]) / (argvals[1] - argvals[0]);
765 for j in 1..(m - 1) {
766 gamma_dot[j] = (gamma[j + 1] - gamma[j - 1]) / (argvals[j + 1] - argvals[j - 1]);
767 }
768 gamma_dot[m - 1] = (gamma[m - 1] - gamma[m - 2]) / (argvals[m - 1] - argvals[m - 2]);
769
770 let q2_aligned: Vec<f64> = q2_warped
772 .iter()
773 .zip(gamma_dot.iter())
774 .map(|(&q, &gd)| q * gd.max(0.0).sqrt())
775 .collect();
776
777 (gamma, q2_aligned)
778}
779
780fn accumulate_alignments(
809 results: &[(Vec<f64>, Vec<f64>)],
810 gammas: &mut FdMatrix,
811 m: usize,
812 n: usize,
813) -> Vec<f64> {
814 let mut mu_q_new = vec![0.0; m];
815 for (i, (gamma, q_aligned)) in results.iter().enumerate() {
816 for j in 0..m {
817 gammas[(i, j)] = gamma[j];
818 mu_q_new[j] += q_aligned[j];
819 }
820 }
821 for j in 0..m {
822 mu_q_new[j] /= n as f64;
823 }
824 mu_q_new
825}
826
827fn apply_stored_warps(data: &FdMatrix, gammas: &FdMatrix, argvals: &[f64]) -> FdMatrix {
829 let (n, m) = data.shape();
830 let mut aligned = FdMatrix::zeros(n, m);
831 for i in 0..n {
832 let fi = data.row(i);
833 let gamma: Vec<f64> = (0..m).map(|j| gammas[(i, j)]).collect();
834 let f_aligned = reparameterize_curve(&fi, argvals, &gamma);
835 for j in 0..m {
836 aligned[(i, j)] = f_aligned[j];
837 }
838 }
839 aligned
840}
841
842fn select_template(srsf_mat: &FdMatrix, data: &FdMatrix, argvals: &[f64]) -> (Vec<f64>, Vec<f64>) {
844 let (n, m) = srsf_mat.shape();
845 let mnq = mean_1d(srsf_mat);
846 let mut min_dist = f64::INFINITY;
847 let mut min_idx = 0;
848 for i in 0..n {
849 let dist_sq: f64 = (0..m).map(|j| (srsf_mat[(i, j)] - mnq[j]).powi(2)).sum();
850 if dist_sq < min_dist {
851 min_dist = dist_sq;
852 min_idx = i;
853 }
854 }
855 let _ = argvals; (srsf_mat.row(min_idx), data.row(min_idx))
857}
858
859fn pre_center_template(
861 data: &FdMatrix,
862 mu_q: &[f64],
863 mu: &[f64],
864 argvals: &[f64],
865 lambda: f64,
866) -> (Vec<f64>, Vec<f64>) {
867 let (n, m) = data.shape();
868 let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
869 .map(|i| {
870 let fi = data.row(i);
871 let qi = srsf_single(&fi, argvals);
872 align_srsf_pair(mu_q, &qi, argvals, lambda)
873 })
874 .collect();
875
876 let mut init_gammas = FdMatrix::zeros(n, m);
877 for (i, (gamma, _)) in align_results.iter().enumerate() {
878 for j in 0..m {
879 init_gammas[(i, j)] = gamma[j];
880 }
881 }
882
883 let gam_inv = sqrt_mean_inverse(&init_gammas, argvals);
884 let mu_new = reparameterize_curve(mu, argvals, &gam_inv);
885 let mu_q_new = srsf_single(&mu_new, argvals);
886 (mu_q_new, mu_new)
887}
888
889fn post_center_results(
891 data: &FdMatrix,
892 mu_q: &[f64],
893 final_gammas: &mut FdMatrix,
894 argvals: &[f64],
895) -> (Vec<f64>, Vec<f64>, FdMatrix) {
896 let (n, m) = data.shape();
897 let gam_inv = sqrt_mean_inverse(final_gammas, argvals);
898 let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
899 let gam_inv_dev = gradient_uniform(&gam_inv, h);
900
901 let mu_q_warped = reparameterize_curve(mu_q, argvals, &gam_inv);
902 let mu_q_centered: Vec<f64> = mu_q_warped
903 .iter()
904 .zip(gam_inv_dev.iter())
905 .map(|(&q, &gd)| q * gd.max(0.0).sqrt())
906 .collect();
907
908 for i in 0..n {
909 let gam_i: Vec<f64> = (0..m).map(|j| final_gammas[(i, j)]).collect();
910 let gam_centered = reparameterize_curve(&gam_i, argvals, &gam_inv);
911 for j in 0..m {
912 final_gammas[(i, j)] = gam_centered[j];
913 }
914 }
915
916 let initial_mean = mean_1d(data);
917 let mu = srsf_inverse(&mu_q_centered, argvals, initial_mean[0]);
918 let final_aligned = apply_stored_warps(data, final_gammas, argvals);
919 (mu, mu_q_centered, final_aligned)
920}
921
922pub fn karcher_mean(
923 data: &FdMatrix,
924 argvals: &[f64],
925 max_iter: usize,
926 tol: f64,
927 lambda: f64,
928) -> KarcherMeanResult {
929 let (n, m) = data.shape();
930
931 let srsf_mat = srsf_transform(data, argvals);
932 let (mut mu_q, mu) = select_template(&srsf_mat, data, argvals);
933 let (mu_q_c, mu_c) = pre_center_template(data, &mu_q, &mu, argvals, lambda);
934 mu_q = mu_q_c;
935 let mut mu = mu_c;
936
937 let mut converged = false;
938 let mut n_iter = 0;
939 let mut final_gammas = FdMatrix::zeros(n, m);
940
941 for iter in 0..max_iter {
942 n_iter = iter + 1;
943
944 let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
945 .map(|i| {
946 let fi = data.row(i);
947 let qi = srsf_single(&fi, argvals);
948 align_srsf_pair(&mu_q, &qi, argvals, lambda)
949 })
950 .collect();
951
952 let mu_q_new = accumulate_alignments(&align_results, &mut final_gammas, m, n);
953
954 let rel = relative_change(&mu_q, &mu_q_new);
955 if rel < tol {
956 converged = true;
957 mu_q = mu_q_new;
958 break;
959 }
960
961 mu_q = mu_q_new;
962 mu = srsf_inverse(&mu_q, argvals, mu[0]);
963 }
964
965 let (mu_final, mu_q_final, final_aligned) =
966 post_center_results(data, &mu_q, &mut final_gammas, argvals);
967
968 KarcherMeanResult {
969 mean: mu_final,
970 mean_srsf: mu_q_final,
971 gammas: final_gammas,
972 aligned_data: final_aligned,
973 n_iter,
974 converged,
975 aligned_srsfs: None,
976 }
977}
978
/// Tangent-space representation of aligned curves, as built by
/// `tsrvf_from_alignment` / `tsrvf_from_alignment_with_method`.
#[derive(Debug, Clone)]
pub struct TsrvfResult {
    /// Row `i` is the tangent vector of curve `i` at the unit-normalized mean SRSF.
    pub tangent_vectors: FdMatrix,
    /// Copy of the mean curve from the underlying `KarcherMeanResult`.
    pub mean: Vec<f64>,
    /// Kernel-smoothed mean SRSF (not normalized).
    pub mean_srsf: Vec<f64>,
    /// `l2_norm_l2` of `mean_srsf`, computed on a uniform `[0, 1]` grid.
    pub mean_srsf_norm: f64,
    /// Entry `i` is the norm of the smoothed aligned SRSF of curve `i` on that grid.
    pub srsf_norms: Vec<f64>,
    /// Entry `i` is the first sample of aligned curve `i`.
    pub initial_values: Vec<f64>,
    /// Copy of the warping functions from the underlying `KarcherMeanResult`.
    pub gammas: FdMatrix,
    /// Convergence flag copied from the underlying `KarcherMeanResult`.
    pub converged: bool,
}
1004
1005pub fn tsrvf_transform(
1016 data: &FdMatrix,
1017 argvals: &[f64],
1018 max_iter: usize,
1019 tol: f64,
1020 lambda: f64,
1021) -> TsrvfResult {
1022 let karcher = karcher_mean(data, argvals, max_iter, tol, lambda);
1023 tsrvf_from_alignment(&karcher, argvals)
1024}
1025
1026fn smooth_aligned_srsfs(srsf: &FdMatrix, m: usize) -> FdMatrix {
1032 let n = srsf.nrows();
1033 let time: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
1034 let bandwidth = 2.0 / (m - 1) as f64;
1035
1036 let mut smoothed = FdMatrix::zeros(n, m);
1037 for i in 0..n {
1038 let qi = srsf.row(i);
1039 let qi_smooth = nadaraya_watson(&time, &qi, &time, bandwidth, "gaussian");
1040 for j in 0..m {
1041 smoothed[(i, j)] = qi_smooth[j];
1042 }
1043 }
1044 smoothed
1045}
1046
1047pub fn tsrvf_from_alignment(karcher: &KarcherMeanResult, argvals: &[f64]) -> TsrvfResult {
1059 let (n, m) = karcher.aligned_data.shape();
1060
1061 let aligned_srsf = srsf_transform(&karcher.aligned_data, argvals);
1063
1064 let aligned_srsf = smooth_aligned_srsfs(&aligned_srsf, m);
1076
1077 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
1081 let bandwidth = 2.0 / (m - 1) as f64;
1082 let mean_srsf_smooth = nadaraya_watson(&time, &karcher.mean_srsf, &time, bandwidth, "gaussian");
1083 let mean_norm = l2_norm_l2(&mean_srsf_smooth, &time);
1084
1085 let mu_unit: Vec<f64> = if mean_norm > 1e-10 {
1086 mean_srsf_smooth.iter().map(|&q| q / mean_norm).collect()
1087 } else {
1088 vec![0.0; m]
1089 };
1090
1091 let srsf_norms: Vec<f64> = iter_maybe_parallel!(0..n)
1093 .map(|i| {
1094 let qi = aligned_srsf.row(i);
1095 l2_norm_l2(&qi, &time)
1096 })
1097 .collect();
1098
1099 let tangent_data: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
1100 .map(|i| {
1101 let qi = aligned_srsf.row(i);
1102 let qi_norm = srsf_norms[i];
1103
1104 if qi_norm < 1e-10 || mean_norm < 1e-10 {
1105 return vec![0.0; m];
1106 }
1107
1108 let qi_unit: Vec<f64> = qi.iter().map(|&q| q / qi_norm).collect();
1110
1111 inv_exp_map_sphere(&mu_unit, &qi_unit, &time)
1113 })
1114 .collect();
1115
1116 let mut tangent_vectors = FdMatrix::zeros(n, m);
1118 for i in 0..n {
1119 for j in 0..m {
1120 tangent_vectors[(i, j)] = tangent_data[i][j];
1121 }
1122 }
1123
1124 let initial_values: Vec<f64> = (0..n).map(|i| karcher.aligned_data[(i, 0)]).collect();
1127
1128 TsrvfResult {
1129 tangent_vectors,
1130 mean: karcher.mean.clone(),
1131 mean_srsf: mean_srsf_smooth,
1132 mean_srsf_norm: mean_norm,
1133 srsf_norms,
1134 initial_values,
1135 gammas: karcher.gammas.clone(),
1136 converged: karcher.converged,
1137 }
1138}
1139
1140pub fn tsrvf_inverse(tsrvf: &TsrvfResult, argvals: &[f64]) -> FdMatrix {
1152 let (n, m) = tsrvf.tangent_vectors.shape();
1153
1154 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
1155
1156 let mu_unit: Vec<f64> = if tsrvf.mean_srsf_norm > 1e-10 {
1158 tsrvf
1159 .mean_srsf
1160 .iter()
1161 .map(|&q| q / tsrvf.mean_srsf_norm)
1162 .collect()
1163 } else {
1164 vec![0.0; m]
1165 };
1166
1167 let curves: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
1168 .map(|i| {
1169 let vi = tsrvf.tangent_vectors.row(i);
1170
1171 let qi_unit = exp_map_sphere(&mu_unit, &vi, &time);
1173
1174 let qi: Vec<f64> = qi_unit.iter().map(|&q| q * tsrvf.srsf_norms[i]).collect();
1176
1177 srsf_inverse(&qi, argvals, tsrvf.initial_values[i])
1179 })
1180 .collect();
1181
1182 let mut result = FdMatrix::zeros(n, m);
1183 for i in 0..n {
1184 for j in 0..m {
1185 result[(i, j)] = curves[i][j];
1186 }
1187 }
1188 result
1189}
1190
/// Strategy used by `tsrvf_from_alignment_with_method` to obtain tangent vectors at
/// the mean.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TransportMethod {
    /// Default: delegate to `tsrvf_from_alignment` (inverse exponential map at the mean).
    #[default]
    LogMap,
    /// Transport with `parallel_transport_schilds`.
    SchildsLadder,
    /// Transport with `parallel_transport_pole`.
    PoleLadder,
}
1204
1205fn parallel_transport_schilds(v: &[f64], from: &[f64], to: &[f64], time: &[f64]) -> Vec<f64> {
1207 use crate::warping::{exp_map_sphere, inv_exp_map_sphere};
1208
1209 let v_norm = crate::warping::l2_norm_l2(v, time);
1210 if v_norm < 1e-10 {
1211 return vec![0.0; v.len()];
1212 }
1213
1214 let endpoint = exp_map_sphere(from, v, time);
1216
1217 let log_to_ep = inv_exp_map_sphere(to, &endpoint, time);
1219
1220 let half_log: Vec<f64> = log_to_ep.iter().map(|&x| 0.5 * x).collect();
1222 let midpoint = exp_map_sphere(to, &half_log, time);
1223
1224 let log_to_mid = inv_exp_map_sphere(to, &midpoint, time);
1226 log_to_mid.iter().map(|&x| 2.0 * x).collect()
1227}
1228
1229fn parallel_transport_pole(v: &[f64], from: &[f64], to: &[f64], time: &[f64]) -> Vec<f64> {
1231 use crate::warping::{exp_map_sphere, inv_exp_map_sphere};
1232
1233 let v_norm = crate::warping::l2_norm_l2(v, time);
1234 if v_norm < 1e-10 {
1235 return vec![0.0; v.len()];
1236 }
1237
1238 let neg_v: Vec<f64> = v.iter().map(|&x| -x).collect();
1240 let pole = exp_map_sphere(from, &neg_v, time);
1241
1242 let log_to_pole = inv_exp_map_sphere(to, &pole, time);
1244
1245 let half_log: Vec<f64> = log_to_pole.iter().map(|&x| 0.5 * x).collect();
1247 let midpoint = exp_map_sphere(to, &half_log, time);
1248
1249 let log_to_mid = inv_exp_map_sphere(to, &midpoint, time);
1251 log_to_mid.iter().map(|&x| -2.0 * x).collect()
1252}
1253
1254pub fn tsrvf_transform_with_method(
1258 data: &FdMatrix,
1259 argvals: &[f64],
1260 max_iter: usize,
1261 tol: f64,
1262 lambda: f64,
1263 method: TransportMethod,
1264) -> TsrvfResult {
1265 let karcher = karcher_mean(data, argvals, max_iter, tol, lambda);
1266 tsrvf_from_alignment_with_method(&karcher, argvals, method)
1267}
1268
1269pub fn tsrvf_from_alignment_with_method(
1276 karcher: &KarcherMeanResult,
1277 argvals: &[f64],
1278 method: TransportMethod,
1279) -> TsrvfResult {
1280 if method == TransportMethod::LogMap {
1281 return tsrvf_from_alignment(karcher, argvals);
1282 }
1283
1284 let (n, m) = karcher.aligned_data.shape();
1285 let aligned_srsf = srsf_transform(&karcher.aligned_data, argvals);
1286 let aligned_srsf = smooth_aligned_srsfs(&aligned_srsf, m);
1287 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
1288 let bandwidth = 2.0 / (m - 1) as f64;
1289 let mean_srsf_smooth = nadaraya_watson(&time, &karcher.mean_srsf, &time, bandwidth, "gaussian");
1290 let mean_norm = crate::warping::l2_norm_l2(&mean_srsf_smooth, &time);
1291
1292 let mu_unit: Vec<f64> = if mean_norm > 1e-10 {
1293 mean_srsf_smooth.iter().map(|&q| q / mean_norm).collect()
1294 } else {
1295 vec![0.0; m]
1296 };
1297
1298 let srsf_norms: Vec<f64> = iter_maybe_parallel!(0..n)
1299 .map(|i| {
1300 let qi = aligned_srsf.row(i);
1301 crate::warping::l2_norm_l2(&qi, &time)
1302 })
1303 .collect();
1304
1305 let tangent_data: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
1306 .map(|i| {
1307 let qi = aligned_srsf.row(i);
1308 let qi_norm = srsf_norms[i];
1309
1310 if qi_norm < 1e-10 || mean_norm < 1e-10 {
1311 return vec![0.0; m];
1312 }
1313
1314 let qi_unit: Vec<f64> = qi.iter().map(|&q| q / qi_norm).collect();
1315
1316 let v_at_qi = inv_exp_map_sphere(&qi_unit, &mu_unit, &time);
1318 let neg_v: Vec<f64> = v_at_qi.iter().map(|&x| -x).collect();
1319
1320 match method {
1322 TransportMethod::SchildsLadder => {
1323 parallel_transport_schilds(&neg_v, &qi_unit, &mu_unit, &time)
1324 }
1325 TransportMethod::PoleLadder => {
1326 parallel_transport_pole(&neg_v, &qi_unit, &mu_unit, &time)
1327 }
1328 TransportMethod::LogMap => unreachable!(),
1329 }
1330 })
1331 .collect();
1332
1333 let mut tangent_vectors = FdMatrix::zeros(n, m);
1334 for i in 0..n {
1335 for j in 0..m {
1336 tangent_vectors[(i, j)] = tangent_data[i][j];
1337 }
1338 }
1339
1340 let initial_values: Vec<f64> = (0..n).map(|i| karcher.aligned_data[(i, 0)]).collect();
1341
1342 TsrvfResult {
1343 tangent_vectors,
1344 mean: karcher.mean.clone(),
1345 mean_srsf: mean_srsf_smooth,
1346 mean_srsf_norm: mean_norm,
1347 srsf_norms,
1348 initial_values,
1349 gammas: karcher.gammas.clone(),
1350 converged: karcher.converged,
1351 }
1352}
1353
/// Diagnostics produced by `alignment_quality`.
#[derive(Debug, Clone)]
pub struct AlignmentQuality {
    /// `warp_complexity` of each curve's warping function.
    pub warp_complexity: Vec<f64>,
    /// Average of `warp_complexity`.
    pub mean_warp_complexity: f64,
    /// `warp_smoothness` of each curve's warping function.
    pub warp_smoothness: Vec<f64>,
    /// Average of `warp_smoothness`.
    pub mean_warp_smoothness: f64,
    /// Mean squared L2 distance of the original curves to their pointwise mean.
    pub total_variance: f64,
    /// Mean squared L2 distance of the aligned curves to their pointwise mean.
    pub amplitude_variance: f64,
    /// `total_variance - amplitude_variance`, clamped at zero.
    pub phase_variance: f64,
    /// `phase_variance / total_variance`, or `0.0` when the total is at most `1e-10`.
    pub phase_amplitude_ratio: f64,
    /// Per grid point: variance after alignment divided by variance before alignment
    /// (`1.0` where the original variance is at most `1e-15`).
    pub pointwise_variance_ratio: Vec<f64>,
    /// Average of `pointwise_variance_ratio` over the grid.
    pub mean_variance_reduction: f64,
}
1380
1381pub fn warp_complexity(gamma: &[f64], argvals: &[f64]) -> f64 {
1385 crate::warping::phase_distance(gamma, argvals)
1386}
1387
1388pub fn warp_smoothness(gamma: &[f64], argvals: &[f64]) -> f64 {
1390 let m = gamma.len();
1391 if m < 3 {
1392 return 0.0;
1393 }
1394
1395 let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
1396 let gam_prime = gradient_uniform(gamma, h);
1397 let gam_pprime = gradient_uniform(&gam_prime, h);
1398
1399 let integrand: Vec<f64> = gam_pprime.iter().map(|&g| g * g).collect();
1400 crate::helpers::trapz(&integrand, argvals)
1401}
1402
1403pub fn alignment_quality(
1410 data: &FdMatrix,
1411 karcher: &KarcherMeanResult,
1412 argvals: &[f64],
1413) -> AlignmentQuality {
1414 let (n, m) = data.shape();
1415 let weights = simpsons_weights(argvals);
1416
1417 let wc: Vec<f64> = (0..n)
1419 .map(|i| {
1420 let gamma: Vec<f64> = (0..m).map(|j| karcher.gammas[(i, j)]).collect();
1421 warp_complexity(&gamma, argvals)
1422 })
1423 .collect();
1424 let ws: Vec<f64> = (0..n)
1425 .map(|i| {
1426 let gamma: Vec<f64> = (0..m).map(|j| karcher.gammas[(i, j)]).collect();
1427 warp_smoothness(&gamma, argvals)
1428 })
1429 .collect();
1430
1431 let mean_wc = wc.iter().sum::<f64>() / n as f64;
1432 let mean_ws = ws.iter().sum::<f64>() / n as f64;
1433
1434 let orig_mean = crate::fdata::mean_1d(data);
1436
1437 let total_var: f64 = (0..n)
1439 .map(|i| {
1440 let fi = data.row(i);
1441 let d = l2_distance(&fi, &orig_mean, &weights);
1442 d * d
1443 })
1444 .sum::<f64>()
1445 / n as f64;
1446
1447 let aligned_mean = crate::fdata::mean_1d(&karcher.aligned_data);
1449
1450 let amp_var: f64 = (0..n)
1452 .map(|i| {
1453 let fi = karcher.aligned_data.row(i);
1454 let d = l2_distance(&fi, &aligned_mean, &weights);
1455 d * d
1456 })
1457 .sum::<f64>()
1458 / n as f64;
1459
1460 let phase_var = (total_var - amp_var).max(0.0);
1461 let ratio = if total_var > 1e-10 {
1462 phase_var / total_var
1463 } else {
1464 0.0
1465 };
1466
1467 let mut pw_ratio = vec![0.0; m];
1469 for j in 0..m {
1470 let col_orig = data.column(j);
1471 let mean_orig = col_orig.iter().sum::<f64>() / n as f64;
1472 let var_orig: f64 = col_orig
1473 .iter()
1474 .map(|&v| (v - mean_orig).powi(2))
1475 .sum::<f64>()
1476 / n as f64;
1477
1478 let col_aligned = karcher.aligned_data.column(j);
1479 let mean_aligned = col_aligned.iter().sum::<f64>() / n as f64;
1480 let var_aligned: f64 = col_aligned
1481 .iter()
1482 .map(|&v| (v - mean_aligned).powi(2))
1483 .sum::<f64>()
1484 / n as f64;
1485
1486 pw_ratio[j] = if var_orig > 1e-15 {
1487 var_aligned / var_orig
1488 } else {
1489 1.0
1490 };
1491 }
1492
1493 let mean_vr = pw_ratio.iter().sum::<f64>() / m as f64;
1494
1495 AlignmentQuality {
1496 warp_complexity: wc,
1497 mean_warp_complexity: mean_wc,
1498 warp_smoothness: ws,
1499 mean_warp_smoothness: mean_ws,
1500 total_variance: total_var,
1501 amplitude_variance: amp_var,
1502 phase_variance: phase_var,
1503 phase_amplitude_ratio: ratio,
1504 pointwise_variance_ratio: pw_ratio,
1505 mean_variance_reduction: mean_vr,
1506 }
1507}
1508
/// Enumerates index triplets `(i, j, k)` with `i < j < k < n` in lexicographic order.
///
/// `max_triplets == 0` means "no limit"; otherwise at most `max_triplets` triplets
/// are returned. Fewer than three indices yields an empty vector.
///
/// The total number of triplets is deliberately not precomputed: `n * (n - 1) * (n - 2)`
/// underflows for `n < 3` and can overflow for large `n`, while `take` already stops
/// when the enumeration is exhausted.
fn triplet_indices(n: usize, max_triplets: usize) -> Vec<(usize, usize, usize)> {
    let cap = if max_triplets > 0 {
        max_triplets
    } else {
        usize::MAX
    };
    (0..n)
        .flat_map(|i| ((i + 1)..n).flat_map(move |j| ((j + 1)..n).map(move |k| (i, j, k))))
        .take(cap)
        .collect()
}
1522
1523fn triplet_warp_deviation(
1525 data: &FdMatrix,
1526 argvals: &[f64],
1527 weights: &[f64],
1528 i: usize,
1529 j: usize,
1530 k: usize,
1531 lambda: f64,
1532) -> f64 {
1533 let fi = data.row(i);
1534 let fj = data.row(j);
1535 let fk = data.row(k);
1536 let rij = elastic_align_pair(&fi, &fj, argvals, lambda);
1537 let rjk = elastic_align_pair(&fj, &fk, argvals, lambda);
1538 let rik = elastic_align_pair(&fi, &fk, argvals, lambda);
1539 let composed = compose_warps(&rij.gamma, &rjk.gamma, argvals);
1540 l2_distance(&composed, &rik.gamma, weights)
1541}
1542
1543pub fn pairwise_consistency(
1554 data: &FdMatrix,
1555 argvals: &[f64],
1556 lambda: f64,
1557 max_triplets: usize,
1558) -> f64 {
1559 let n = data.nrows();
1560 if n < 3 {
1561 return 0.0;
1562 }
1563
1564 let weights = simpsons_weights(argvals);
1565 let triplets = triplet_indices(n, max_triplets);
1566 if triplets.is_empty() {
1567 return 0.0;
1568 }
1569
1570 let total_dev: f64 = triplets
1571 .iter()
1572 .map(|&(i, j, k)| triplet_warp_deviation(data, argvals, &weights, i, j, k, lambda))
1573 .sum();
1574 total_dev / triplets.len() as f64
1575}
1576
/// Outcome of a landmark-constrained pairwise alignment, as filled in by
/// `elastic_align_pair_constrained`.
#[derive(Debug, Clone)]
pub struct ConstrainedAlignmentResult {
    /// Warping function sampled on the evaluation grid.
    pub gamma: Vec<f64>,
    /// The second input curve reparameterized by `gamma`.
    pub f_aligned: Vec<f64>,
    /// Weighted L2 distance between the SRSF of the first curve and the SRSF of `f_aligned`.
    pub distance: f64,
    /// Landmark pairs kept as interior waypoints, reported as grid values after
    /// snapping; empty when no landmarks were supplied.
    pub enforced_landmarks: Vec<(f64, f64)>,
}
1591
/// Index of the grid node in `argvals` closest to `t_val`.
///
/// Ties go to the earliest node. Panics if `argvals` is empty.
fn snap_to_grid(t_val: f64, argvals: &[f64]) -> usize {
    let first_gap = (t_val - argvals[0]).abs();
    let (nearest, _) = argvals.iter().enumerate().skip(1).fold(
        (0, first_gap),
        |(best_idx, best_gap), (idx, &node)| {
            let gap = (t_val - node).abs();
            // Strict comparison keeps the earlier node on ties.
            if gap < best_gap {
                (idx, gap)
            } else {
                (best_idx, best_gap)
            }
        },
    );
    nearest
}
1605
1606fn dp_segment(
1611 q1: &[f64],
1612 q2: &[f64],
1613 argvals: &[f64],
1614 sc: usize,
1615 ec: usize,
1616 sr: usize,
1617 er: usize,
1618 lambda: f64,
1619) -> Vec<(usize, usize)> {
1620 let nc = ec - sc + 1;
1621 let nr = er - sr + 1;
1622
1623 if nc <= 1 || nr <= 1 {
1624 return vec![(sc, sr), (ec, er)];
1625 }
1626
1627 let path = dp_grid_solve(nr, nc, |local_sr, local_sc, local_tr, local_tc| {
1628 let gsr = sr + local_sr;
1629 let gsc = sc + local_sc;
1630 let gtr = sr + local_tr;
1631 let gtc = sc + local_tc;
1632 dp_edge_weight(q1, q2, argvals, gsc, gtc, gsr, gtr)
1633 + dp_lambda_penalty(argvals, gsc, gtc, gsr, gtr, lambda)
1634 });
1635
1636 path.iter().map(|&(lr, lc)| (sc + lc, sr + lr)).collect()
1638}
1639
1640fn build_constrained_waypoints(
1656 landmark_pairs: &[(f64, f64)],
1657 argvals: &[f64],
1658 m: usize,
1659) -> Vec<(usize, usize)> {
1660 let mut waypoints: Vec<(usize, usize)> = Vec::with_capacity(landmark_pairs.len() + 2);
1661 waypoints.push((0, 0));
1662 for &(tt, st) in landmark_pairs {
1663 let tc = snap_to_grid(tt, argvals);
1664 let tr = snap_to_grid(st, argvals);
1665 if let Some(&(prev_c, prev_r)) = waypoints.last() {
1666 if tc > prev_c && tr > prev_r {
1667 waypoints.push((tc, tr));
1668 }
1669 }
1670 }
1671 let last = m - 1;
1672 if let Some(&(prev_c, prev_r)) = waypoints.last() {
1673 if prev_c != last || prev_r != last {
1674 waypoints.push((last, last));
1675 }
1676 }
1677 waypoints
1678}
1679
1680fn segmented_dp_gamma(
1682 q1n: &[f64],
1683 q2n: &[f64],
1684 argvals: &[f64],
1685 waypoints: &[(usize, usize)],
1686 lambda: f64,
1687) -> Vec<f64> {
1688 let mut full_path_tc: Vec<f64> = Vec::new();
1689 let mut full_path_tr: Vec<f64> = Vec::new();
1690
1691 for seg in 0..(waypoints.len() - 1) {
1692 let (sc, sr) = waypoints[seg];
1693 let (ec, er) = waypoints[seg + 1];
1694 let segment_path = dp_segment(q1n, q2n, argvals, sc, ec, sr, er, lambda);
1695 let start = if seg > 0 { 1 } else { 0 };
1696 for &(tc, tr) in &segment_path[start..] {
1697 full_path_tc.push(argvals[tc]);
1698 full_path_tr.push(argvals[tr]);
1699 }
1700 }
1701
1702 let mut gamma: Vec<f64> = argvals
1703 .iter()
1704 .map(|&t| linear_interp(&full_path_tc, &full_path_tr, t))
1705 .collect();
1706 normalize_warp(&mut gamma, argvals);
1707 gamma
1708}
1709
1710pub fn elastic_align_pair_constrained(
1711 f1: &[f64],
1712 f2: &[f64],
1713 argvals: &[f64],
1714 landmark_pairs: &[(f64, f64)],
1715 lambda: f64,
1716) -> ConstrainedAlignmentResult {
1717 let m = f1.len();
1718
1719 if landmark_pairs.is_empty() {
1720 let r = elastic_align_pair(f1, f2, argvals, lambda);
1721 return ConstrainedAlignmentResult {
1722 gamma: r.gamma,
1723 f_aligned: r.f_aligned,
1724 distance: r.distance,
1725 enforced_landmarks: Vec::new(),
1726 };
1727 }
1728
1729 let f1_mat = FdMatrix::from_slice(f1, 1, m).unwrap();
1731 let f2_mat = FdMatrix::from_slice(f2, 1, m).unwrap();
1732 let q1_mat = srsf_transform(&f1_mat, argvals);
1733 let q2_mat = srsf_transform(&f2_mat, argvals);
1734 let q1: Vec<f64> = q1_mat.row(0);
1735 let q2: Vec<f64> = q2_mat.row(0);
1736 let norm1 = q1.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1737 let norm2 = q2.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1738 let q1n: Vec<f64> = q1.iter().map(|&v| v / norm1).collect();
1739 let q2n: Vec<f64> = q2.iter().map(|&v| v / norm2).collect();
1740
1741 let waypoints = build_constrained_waypoints(landmark_pairs, argvals, m);
1742 let gamma = segmented_dp_gamma(&q1n, &q2n, argvals, &waypoints, lambda);
1743
1744 let f_aligned = reparameterize_curve(f2, argvals, &gamma);
1745 let f_aligned_mat = FdMatrix::from_slice(&f_aligned, 1, m).unwrap();
1746 let q_aligned_mat = srsf_transform(&f_aligned_mat, argvals);
1747 let q_aligned: Vec<f64> = q_aligned_mat.row(0);
1748 let weights = simpsons_weights(argvals);
1749 let distance = l2_distance(&q1, &q_aligned, &weights);
1750
1751 let enforced: Vec<(f64, f64)> = waypoints[1..waypoints.len() - 1]
1752 .iter()
1753 .map(|&(tc, tr)| (argvals[tc], argvals[tr]))
1754 .collect();
1755
1756 ConstrainedAlignmentResult {
1757 gamma,
1758 f_aligned,
1759 distance,
1760 enforced_landmarks: enforced,
1761 }
1762}
1763
1764pub fn elastic_align_pair_with_landmarks(
1778 f1: &[f64],
1779 f2: &[f64],
1780 argvals: &[f64],
1781 kind: crate::landmark::LandmarkKind,
1782 min_prominence: f64,
1783 expected_count: usize,
1784 lambda: f64,
1785) -> ConstrainedAlignmentResult {
1786 let lm1 = crate::landmark::detect_landmarks(f1, argvals, kind, min_prominence);
1787 let lm2 = crate::landmark::detect_landmarks(f2, argvals, kind, min_prominence);
1788
1789 let n_match = if expected_count > 0 {
1791 expected_count.min(lm1.len()).min(lm2.len())
1792 } else {
1793 lm1.len().min(lm2.len())
1794 };
1795
1796 let pairs: Vec<(f64, f64)> = (0..n_match)
1797 .map(|i| (lm1[i].position, lm2[i].position))
1798 .collect();
1799
1800 elastic_align_pair_constrained(f1, f2, argvals, &pairs, lambda)
1801}
1802
1803use crate::matrix::FdCurveSet;
1806
/// Outcome of aligning two multi-dimensional curves, as filled in by
/// `elastic_align_pair_nd`.
#[derive(Debug, Clone)]
pub struct AlignmentResultNd {
    /// Warping function sampled on the evaluation grid; shared by all dimensions.
    pub gamma: Vec<f64>,
    /// The second curve reparameterized by `gamma`, one vector per dimension.
    pub f_aligned: Vec<Vec<f64>>,
    /// Square root of the summed squared per-dimension SRSF distances.
    pub distance: f64,
}
1817
1818#[inline]
1831fn srsf_scale_point(derivs: &[FdMatrix], result_dims: &mut [FdMatrix], i: usize, j: usize) {
1832 let d = derivs.len();
1833 let norm_sq: f64 = derivs.iter().map(|dd| dd[(i, j)].powi(2)).sum();
1834 let norm = norm_sq.sqrt();
1835 if norm < 1e-15 {
1836 for k in 0..d {
1837 result_dims[k][(i, j)] = 0.0;
1838 }
1839 } else {
1840 let scale = 1.0 / norm.sqrt();
1841 for k in 0..d {
1842 result_dims[k][(i, j)] = derivs[k][(i, j)] * scale;
1843 }
1844 }
1845}
1846
1847pub fn srsf_transform_nd(data: &FdCurveSet, argvals: &[f64]) -> FdCurveSet {
1848 let d = data.ndim();
1849 let n = data.ncurves();
1850 let m = data.npoints();
1851
1852 if d == 0 || n == 0 || m == 0 || argvals.len() != m {
1853 return FdCurveSet {
1854 dims: (0..d).map(|_| FdMatrix::zeros(n, m)).collect(),
1855 };
1856 }
1857
1858 let derivs: Vec<FdMatrix> = data
1859 .dims
1860 .iter()
1861 .map(|dim_mat| crate::fdata::deriv_1d(dim_mat, argvals, 1))
1862 .collect();
1863
1864 let mut result_dims: Vec<FdMatrix> = (0..d).map(|_| FdMatrix::zeros(n, m)).collect();
1865 for i in 0..n {
1866 for j in 0..m {
1867 srsf_scale_point(&derivs, &mut result_dims, i, j);
1868 }
1869 }
1870
1871 FdCurveSet { dims: result_dims }
1872}
1873
1874pub fn srsf_inverse_nd(q: &[Vec<f64>], argvals: &[f64], f0: &[f64]) -> Vec<Vec<f64>> {
1887 let d = q.len();
1888 if d == 0 {
1889 return Vec::new();
1890 }
1891 let m = q[0].len();
1892 if m == 0 {
1893 return vec![Vec::new(); d];
1894 }
1895
1896 let norms: Vec<f64> = (0..m)
1898 .map(|j| {
1899 let norm_sq: f64 = q.iter().map(|qk| qk[j].powi(2)).sum();
1900 norm_sq.sqrt()
1901 })
1902 .collect();
1903
1904 let mut result = Vec::with_capacity(d);
1906 for k in 0..d {
1907 let integrand: Vec<f64> = (0..m).map(|j| q[k][j] * norms[j]).collect();
1908 let integral = cumulative_trapz(&integrand, argvals);
1909 let curve: Vec<f64> = integral.iter().map(|&v| f0[k] + v).collect();
1910 result.push(curve);
1911 }
1912
1913 result
1914}
1915
1916fn dp_alignment_core_nd(
1921 q1: &[Vec<f64>],
1922 q2: &[Vec<f64>],
1923 argvals: &[f64],
1924 lambda: f64,
1925) -> Vec<f64> {
1926 let d = q1.len();
1927 let m = argvals.len();
1928 if m < 2 || d == 0 {
1929 return argvals.to_vec();
1930 }
1931
1932 if d == 1 {
1934 return dp_alignment_core(&q1[0], &q2[0], argvals, lambda);
1935 }
1936
1937 let q1n: Vec<Vec<f64>> = q1
1939 .iter()
1940 .map(|qk| {
1941 let norm = qk.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1942 qk.iter().map(|&v| v / norm).collect()
1943 })
1944 .collect();
1945 let q2n: Vec<Vec<f64>> = q2
1946 .iter()
1947 .map(|qk| {
1948 let norm = qk.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1949 qk.iter().map(|&v| v / norm).collect()
1950 })
1951 .collect();
1952
1953 let path = dp_grid_solve(m, m, |sr, sc, tr, tc| {
1954 let w: f64 = (0..d)
1955 .map(|k| dp_edge_weight(&q1n[k], &q2n[k], argvals, sc, tc, sr, tr))
1956 .sum();
1957 w + dp_lambda_penalty(argvals, sc, tc, sr, tr, lambda)
1958 });
1959
1960 dp_path_to_gamma(&path, argvals)
1961}
1962
1963pub fn elastic_align_pair_nd(
1974 f1: &FdCurveSet,
1975 f2: &FdCurveSet,
1976 argvals: &[f64],
1977 lambda: f64,
1978) -> AlignmentResultNd {
1979 let d = f1.ndim();
1980 let m = f1.npoints();
1981
1982 let q1_set = srsf_transform_nd(f1, argvals);
1984 let q2_set = srsf_transform_nd(f2, argvals);
1985
1986 let q1: Vec<Vec<f64>> = q1_set.dims.iter().map(|dm| dm.row(0)).collect();
1988 let q2: Vec<Vec<f64>> = q2_set.dims.iter().map(|dm| dm.row(0)).collect();
1989
1990 let gamma = dp_alignment_core_nd(&q1, &q2, argvals, lambda);
1992
1993 let f_aligned: Vec<Vec<f64>> = f2
1995 .dims
1996 .iter()
1997 .map(|dm| {
1998 let row = dm.row(0);
1999 reparameterize_curve(&row, argvals, &gamma)
2000 })
2001 .collect();
2002
2003 let f_aligned_set = {
2005 let dims: Vec<FdMatrix> = f_aligned
2006 .iter()
2007 .map(|fa| FdMatrix::from_slice(fa, 1, m).unwrap())
2008 .collect();
2009 FdCurveSet { dims }
2010 };
2011 let q_aligned = srsf_transform_nd(&f_aligned_set, argvals);
2012 let weights = simpsons_weights(argvals);
2013
2014 let mut dist_sq = 0.0;
2015 for k in 0..d {
2016 let q1k = q1_set.dims[k].row(0);
2017 let qak = q_aligned.dims[k].row(0);
2018 let d_k = l2_distance(&q1k, &qak, &weights);
2019 dist_sq += d_k * d_k;
2020 }
2021
2022 AlignmentResultNd {
2023 gamma,
2024 f_aligned,
2025 distance: dist_sq.sqrt(),
2026 }
2027}
2028
2029pub fn elastic_distance_nd(f1: &FdCurveSet, f2: &FdCurveSet, argvals: &[f64], lambda: f64) -> f64 {
2033 elastic_align_pair_nd(f1, f2, argvals, lambda).distance
2034}
2035
2036#[cfg(test)]
2039mod tests {
2040 use super::*;
2041 use crate::helpers::trapz;
2042 use crate::simulation::{sim_fundata, EFunType, EValType};
2043 use crate::warping::inner_product_l2;
2044
/// Test helper: `m` evenly spaced points `i / (m - 1)` for `i` in `0..m`.
fn uniform_grid(m: usize) -> Vec<f64> {
    let mut grid = Vec::with_capacity(m);
    for i in 0..m {
        grid.push(i as f64 / (m - 1) as f64);
    }
    grid
}
2048
2049 fn make_test_data(n: usize, m: usize, seed: u64) -> FdMatrix {
2050 let t = uniform_grid(m);
2051 sim_fundata(
2052 n,
2053 &t,
2054 3,
2055 EFunType::Fourier,
2056 EValType::Exponential,
2057 Some(seed),
2058 )
2059 }
2060
#[test]
fn test_cumulative_trapz_constant() {
    // The running integral of the constant 1 over [0, t] should equal t.
    let n_points = 50;
    let grid = uniform_grid(n_points);
    let ones = vec![1.0; n_points];
    let running = cumulative_trapz(&ones, &grid);
    assert!((running[0]).abs() < 1e-15, "cumulative_trapz(0) should be 0");
    for j in 1..n_points {
        let (upper, area) = (grid[j], running[j]);
        assert!(
            (area - upper).abs() < 1e-12,
            "∫₀^{:.3} 1 dt should be {:.3}, got {:.3}",
            upper,
            upper,
            area
        );
    }
}
2080
#[test]
fn test_cumulative_trapz_linear() {
    // The running integral of s over [0, t] is compared against t^2 / 2.
    let m = 100;
    let grid = uniform_grid(m);
    let integrand: Vec<f64> = grid.clone();
    let running = cumulative_trapz(&integrand, &grid);
    for j in 1..m {
        let upper = grid[j];
        let got = running[j];
        let expected = upper * upper / 2.0;
        assert!(
            (got - expected).abs() < 1e-4,
            "∫₀^{:.3} s ds: expected {expected:.6}, got {:.6}",
            upper,
            got
        );
    }
}
2098
#[test]
fn test_normalize_warp_fixes_boundaries() {
    // A constant "warp" must be pinned to the grid endpoints after normalization.
    let grid = uniform_grid(10);
    let mut warp = vec![0.1; 10];
    normalize_warp(&mut warp, &grid);
    for idx in [0, 9] {
        assert_eq!(warp[idx], grid[idx]);
    }
}
2109
#[test]
fn test_normalize_warp_enforces_monotonicity() {
    // The input dips at index 2; the normalized output must not.
    let grid = uniform_grid(5);
    let mut warp = vec![0.0, 0.5, 0.3, 0.8, 1.0];
    normalize_warp(&mut warp, &grid);
    let mut j = 1;
    while j < 5 {
        assert!(
            warp[j - 1] <= warp[j],
            "gamma should be monotone after normalization at j={j}"
        );
        j += 1;
    }
}
2122
#[test]
fn test_normalize_warp_identity_unchanged() {
    // Normalizing the identity warp should leave it untouched.
    let grid = uniform_grid(20);
    let mut warp = grid.clone();
    normalize_warp(&mut warp, &grid);
    (0..20).for_each(|j| {
        let drift = (warp[j] - grid[j]).abs();
        assert!(drift < 1e-15, "Identity warp should be unchanged");
    });
}
2135
#[test]
fn test_linear_interp_at_nodes() {
    // Querying exactly at a node must return that node's value.
    let x = vec![0.0, 1.0, 2.0, 3.0];
    let y = vec![0.0, 2.0, 4.0, 6.0];
    for (&node, &value) in x.iter().zip(y.iter()) {
        let got = linear_interp(&x, &y, node);
        assert!((got - value).abs() < 1e-12);
    }
}
2146
#[test]
fn test_linear_interp_midpoints() {
    // Halfway between nodes the result is the average of the neighbouring values.
    let x = vec![0.0, 1.0, 2.0];
    let y = vec![0.0, 2.0, 4.0];
    for (query, expected) in [(0.5, 1.0), (1.5, 3.0)] {
        let got = linear_interp(&x, &y, query);
        assert!((got - expected).abs() < 1e-12);
    }
}
2154
#[test]
fn test_linear_interp_clamp() {
    // Queries outside the node range should return the nearest end value.
    let x = vec![0.0, 1.0, 2.0];
    let y = vec![1.0, 3.0, 5.0];
    for (query, expected) in [(-1.0, 1.0), (3.0, 5.0)] {
        let got = linear_interp(&x, &y, query);
        assert!((got - expected).abs() < 1e-12);
    }
}
2162
#[test]
fn test_linear_interp_nonuniform_grid() {
    // 0.3 lies in the [0.1, 0.5] cell, where the slope is (5 - 1) / (0.5 - 0.1) = 10.
    let nodes = vec![0.0, 0.1, 0.5, 1.0];
    let values = vec![0.0, 1.0, 5.0, 10.0];
    let expected = 1.0 + 10.0 * (0.3 - 0.1);
    let val = linear_interp(&nodes, &values, 0.3);
    let gap = (val - expected).abs();
    assert!(
        gap < 1e-12,
        "Non-uniform interp: expected {expected}, got {val}"
    );
}
2175
#[test]
fn test_linear_interp_two_points() {
    // Smallest possible table: a single segment from (0, 3) to (1, 7).
    let x = vec![0.0, 1.0];
    let y = vec![3.0, 7.0];
    for (query, expected) in [(0.25, 4.0), (0.75, 6.0)] {
        let got = linear_interp(&x, &y, query);
        assert!((got - expected).abs() < 1e-12);
    }
}
2183
#[test]
fn test_srsf_transform_linear() {
    // f(t) = 2t: interior SRSF samples are compared against sqrt(2).
    let m = 50;
    let t = uniform_grid(m);
    let line: Vec<f64> = t.iter().map(|&ti| 2.0 * ti).collect();
    let mat = FdMatrix::from_slice(&line, 1, m).unwrap();
    let q: Vec<f64> = srsf_transform(&mat, &t).row(0);

    let expected = 2.0_f64.sqrt();
    // The two samples at each end are excluded from the check.
    for (offset, &qj) in q[2..(m - 2)].iter().enumerate() {
        let j = offset + 2;
        assert!(
            (qj - expected).abs() < 0.1,
            "q[{j}] = {}, expected ~{expected}",
            qj
        );
    }
}
2207
#[test]
fn test_srsf_transform_preserves_shape() {
    // The output matrix must have the same dimensions as the input.
    let (n_curves, n_points) = (10, 50);
    let curves = make_test_data(n_curves, n_points, 42);
    let grid = uniform_grid(n_points);
    let srsf = srsf_transform(&curves, &grid);
    assert_eq!(srsf.shape(), curves.shape());
}
2215
#[test]
fn test_srsf_transform_constant_is_zero() {
    // A flat curve should map to an all-zero SRSF.
    let m = 30;
    let t = uniform_grid(m);
    let flat = vec![5.0; m];
    let mat = FdMatrix::from_slice(&flat, 1, m).unwrap();
    let q: Vec<f64> = srsf_transform(&mat, &t).row(0);

    for (j, qj) in q[..m].iter().enumerate() {
        assert!(
            qj.abs() < 1e-10,
            "SRSF of constant should be 0, got q[{j}] = {}",
            qj
        );
    }
}
2234
#[test]
fn test_srsf_transform_negative_slope() {
    // f(t) = -3t: interior SRSF samples are compared against -sqrt(3).
    let m = 50;
    let t = uniform_grid(m);
    let falling: Vec<f64> = t.iter().map(|&ti| -3.0 * ti).collect();
    let mat = FdMatrix::from_slice(&falling, 1, m).unwrap();
    let q: Vec<f64> = srsf_transform(&mat, &t).row(0);

    let expected = -(3.0_f64.sqrt());
    for (offset, &qj) in q[2..(m - 2)].iter().enumerate() {
        let j = offset + 2;
        assert!(
            (qj - expected).abs() < 0.15,
            "q[{j}] = {}, expected ~{expected}",
            qj
        );
    }
}
2255
#[test]
fn test_srsf_transform_empty_input() {
    // A 0x0 matrix with an empty grid must come back as a 0x0 matrix.
    let no_grid: Vec<f64> = Vec::new();
    let empty = FdMatrix::zeros(0, 0);
    let shape = srsf_transform(&empty, &no_grid).shape();
    assert_eq!(shape, (0, 0));
}
2263
#[test]
fn test_srsf_transform_multiple_curves() {
    // Shape is preserved and every entry is finite for a small batch of curves.
    let m = 40;
    let t = uniform_grid(m);
    let data = make_test_data(5, m, 42);

    let q = srsf_transform(&data, &t);
    assert_eq!(q.shape(), (5, m));

    let cells = (0..5).flat_map(|i| (0..m).map(move |j| (i, j)));
    for (i, j) in cells {
        assert!(q[(i, j)].is_finite(), "SRSF should be finite at ({i},{j})");
    }
}
2280
#[test]
fn test_srsf_round_trip() {
    // Transform then invert; the interior of the curve should be recovered closely.
    let m = 100;
    let t = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;
    let f: Vec<f64> = t.iter().map(|&ti| (two_pi * ti).sin() + ti).collect();

    let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
    let q: Vec<f64> = srsf_transform(&mat, &t).row(0);
    let f_recon = srsf_inverse(&q, &t, f[0]);

    // Five samples at each end are left out of the comparison.
    let original = &f[5..(m - 5)];
    let rebuilt = &f_recon[5..(m - 5)];
    let mut max_err = 0.0_f64;
    for (a, b) in original.iter().zip(rebuilt.iter()) {
        max_err = max_err.max((a - b).abs());
    }

    assert!(
        max_err < 0.15,
        "Round-trip error too large: max_err = {max_err}"
    );
}
2311
#[test]
fn test_srsf_inverse_empty() {
    // Empty SRSF and empty grid should give an empty curve.
    let nothing: Vec<f64> = Vec::new();
    let curve = srsf_inverse(&nothing, &nothing, 0.0);
    assert!(curve.is_empty());
}
2319
#[test]
fn test_srsf_inverse_preserves_initial_value() {
    // The reconstructed curve must begin at the supplied starting value.
    let m = 50;
    let grid = uniform_grid(m);
    let unit_srsf = vec![1.0; m];
    let f0 = 3.15;
    let curve = srsf_inverse(&unit_srsf, &grid, f0);
    let offset = (curve[0] - f0).abs();
    assert!(offset < 1e-12, "srsf_inverse should start at f0");
}
2329
#[test]
fn test_srsf_round_trip_multiple_curves() {
    // Round-trip each simulated curve and bound the interior reconstruction error.
    let m = 80;
    let t = uniform_grid(m);
    let data = make_test_data(5, m, 99);
    let q_mat = srsf_transform(&data, &t);

    for i in 0..5 {
        let original = data.row(i);
        let srsf_row = q_mat.row(i);
        let rebuilt = srsf_inverse(&srsf_row, &t, original[0]);

        let mut max_err = 0.0_f64;
        let interior = original[5..(m - 5)].iter().zip(rebuilt[5..(m - 5)].iter());
        for (a, b) in interior {
            max_err = max_err.max((a - b).abs());
        }
        assert!(max_err < 0.3, "Round-trip curve {i}: max_err = {max_err}");
    }
}
2350
#[test]
fn test_reparameterize_identity_warp() {
    // Warping by the grid itself should reproduce the curve.
    let m = 50;
    let grid = uniform_grid(m);
    let parabola: Vec<f64> = grid.iter().map(|&ti| ti * ti).collect();

    let warped = reparameterize_curve(&parabola, &grid, &grid);
    (0..m).for_each(|j| {
        assert!(
            (warped[j] - parabola[j]).abs() < 1e-12,
            "Identity warp should return original at j={j}"
        );
    });
}
2368
#[test]
fn test_reparameterize_linear_warp() {
    // With f(t) = t, composing with gamma must give gamma back.
    let m = 50;
    let grid = uniform_grid(m);
    let identity_curve: Vec<f64> = grid.clone();
    let gamma: Vec<f64> = grid.iter().map(|&ti| ti * ti).collect();

    let warped = reparameterize_curve(&identity_curve, &grid, &gamma);

    let mut j = 0;
    while j < m {
        assert!(
            (warped[j] - gamma[j]).abs() < 1e-10,
            "f(gamma(t)) should be gamma(t) for f(t)=t at j={j}"
        );
        j += 1;
    }
}
2387
#[test]
fn test_reparameterize_sine_with_quadratic_warp() {
    // sin(pi t) warped by gamma(t) = t^2 is compared against sin(pi t^2).
    use std::f64::consts::PI;

    let m = 100;
    let grid = uniform_grid(m);
    let half_wave: Vec<f64> = grid.iter().map(|&ti| (PI * ti).sin()).collect();
    let gamma: Vec<f64> = grid.iter().map(|&ti| ti * ti).collect();
    let warped = reparameterize_curve(&half_wave, &grid, &gamma);

    for j in 0..m {
        let expected = (PI * gamma[j]).sin();
        let got = warped[j];
        assert!(
            (got - expected).abs() < 0.05,
            "sin(π γ(t)) at j={j}: expected {expected:.4}, got {:.4}",
            got
        );
    }
}
2410
#[test]
fn test_reparameterize_preserves_length() {
    // The warped curve must have as many samples as the grid.
    let m = 50;
    let grid = uniform_grid(m);
    let flat = vec![1.0; m];
    let gamma: Vec<f64> = grid.iter().map(|&ti| ti.sqrt()).collect();

    let n_out = reparameterize_curve(&flat, &grid, &gamma).len();
    assert_eq!(n_out, m);
}
2421
#[test]
fn test_compose_warps_identity() {
    // Composing with the identity warp on either side must leave gamma unchanged.
    let m = 50;
    let identity = uniform_grid(m);
    let gamma: Vec<f64> = identity.iter().map(|&ti| ti.sqrt()).collect();

    let left_identity = compose_warps(&identity, &gamma, &identity);
    (0..m).for_each(|j| {
        assert!(
            (left_identity[j] - gamma[j]).abs() < 1e-10,
            "id ∘ γ should be γ at j={j}"
        );
    });

    let right_identity = compose_warps(&gamma, &identity, &identity);
    (0..m).for_each(|j| {
        assert!(
            (right_identity[j] - gamma[j]).abs() < 1e-10,
            "γ ∘ id should be γ at j={j}"
        );
    });
}
2449
#[test]
fn test_compose_warps_associativity() {
    // (g1 ∘ g2) ∘ g3 and g1 ∘ (g2 ∘ g3) should agree up to interpolation error.
    let m = 50;
    let t = uniform_grid(m);
    let warp_from = |shape: fn(f64) -> f64| -> Vec<f64> { t.iter().map(|&ti| shape(ti)).collect() };
    let g1 = warp_from(|ti| ti.sqrt());
    let g2 = warp_from(|ti| ti * ti);
    let g3 = warp_from(|ti| 0.5 * ti + 0.5 * ti * ti);

    let g12 = compose_warps(&g1, &g2, &t);
    let left = compose_warps(&g12, &g3, &t);
    let g23 = compose_warps(&g2, &g3, &t);
    let right = compose_warps(&g1, &g23, &t);

    for j in 0..m {
        let (lhs, rhs) = (left[j], right[j]);
        assert!(
            (lhs - rhs).abs() < 0.05,
            "Composition should be roughly associative at j={j}: left={:.4}, right={:.4}",
            lhs,
            rhs
        );
    }
}
2474
#[test]
fn test_compose_warps_preserves_domain() {
    // The composed warp must still hit both ends of the grid.
    let m = 50;
    let grid = uniform_grid(m);
    let root: Vec<f64> = grid.iter().map(|&ti| ti.sqrt()).collect();
    let square: Vec<f64> = grid.iter().map(|&ti| ti * ti).collect();

    let composed = compose_warps(&root, &square, &grid);
    let start_gap = (composed[0] - grid[0]).abs();
    assert!(
        start_gap < 1e-10,
        "Composed warp should start at domain start"
    );
    let end_gap = (composed[m - 1] - grid[m - 1]).abs();
    assert!(end_gap < 1e-10, "Composed warp should end at domain end");
}
2492
#[test]
fn test_align_identical_curves() {
    // Aligning a curve with itself: distance near zero and warp near the identity.
    let m = 50;
    let t = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;
    let wave: Vec<f64> = t.iter().map(|&ti| (two_pi * ti).sin()).collect();

    let result = elastic_align_pair(&wave, &wave, &t, 0.0);

    assert!(
        result.distance < 0.1,
        "Distance between identical curves should be near 0, got {}",
        result.distance
    );

    (0..m).for_each(|j| {
        assert!(
            (result.gamma[j] - t[j]).abs() < 0.1,
            "Warp should be near identity at j={j}: gamma={}, t={}",
            result.gamma[j],
            t[j]
        );
    });
}
2523
#[test]
fn test_align_pair_valid_output() {
    // Output lengths match the grid, the distance is non-negative, the warp is monotone.
    let n_points = 50;
    let data = make_test_data(2, n_points, 42);
    let grid = uniform_grid(n_points);
    let (first, second) = (data.row(0), data.row(1));

    let result = elastic_align_pair(&first, &second, &grid, 0.0);

    assert_eq!(result.gamma.len(), 50);
    assert_eq!(result.f_aligned.len(), 50);
    assert!(result.distance >= 0.0);

    let warp = &result.gamma;
    for j in 1..50 {
        assert!(warp[j - 1] <= warp[j], "Warp should be monotone at j={j}");
    }
}
2545
#[test]
fn test_align_pair_warp_boundaries() {
    // The estimated warp must be pinned to both ends of the grid.
    let data = make_test_data(2, 50, 42);
    let grid = uniform_grid(50);
    let (first, second) = (data.row(0), data.row(1));

    let warp = elastic_align_pair(&first, &second, &grid, 0.0).gamma;
    let start_gap = (warp[0] - grid[0]).abs();
    assert!(start_gap < 1e-12, "Warp should start at domain start");
    let end_gap = (warp[49] - grid[49]).abs();
    assert!(end_gap < 1e-12, "Warp should end at domain end");
}
2563
#[test]
fn test_align_shifted_sine() {
    // Aligning a phase-shifted sine must not make the plain L2 distance noticeably worse.
    let m = 80;
    let t = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;
    let sine_with_shift =
        |shift: f64| -> Vec<f64> { t.iter().map(|&ti| (two_pi * (ti - shift)).sin()).collect() };
    let f1: Vec<f64> = t.iter().map(|&ti| (two_pi * ti).sin()).collect();
    let f2 = sine_with_shift(0.1);

    let weights = simpsons_weights(&t);
    let l2_before = l2_distance(&f1, &f2, &weights);
    let aligned = elastic_align_pair(&f1, &f2, &t, 0.0).f_aligned;
    let l2_after = l2_distance(&f1, &aligned, &weights);

    assert!(
        l2_after < l2_before + 0.01,
        "Alignment should not increase L2 distance: before={l2_before:.4}, after={l2_after:.4}"
    );
}
2588
#[test]
fn test_align_pair_aligned_curve_is_finite() {
    // No NaN or infinity may appear in the aligned curve.
    let data = make_test_data(2, 50, 77);
    let grid = uniform_grid(50);
    let (first, second) = (data.row(0), data.row(1));

    let aligned = elastic_align_pair(&first, &second, &grid, 0.0).f_aligned;
    for (j, value) in aligned[..50].iter().enumerate() {
        assert!(value.is_finite(), "Aligned curve should be finite at j={j}");
    }
}
2604
#[test]
fn test_align_pair_minimum_grid() {
    // Two grid points is the smallest input; it must still yield a valid result.
    let grid = vec![0.0, 1.0];
    let gentle = vec![0.0, 1.0];
    let steep = vec![0.0, 2.0];
    let outcome = elastic_align_pair(&gentle, &steep, &grid, 0.0);
    assert_eq!(outcome.gamma.len(), 2);
    assert_eq!(outcome.f_aligned.len(), 2);
    assert!(outcome.distance >= 0.0);
}
2616
#[test]
fn test_elastic_distance_symmetric() {
    // Swapping the arguments should change the distance only within a loose tolerance.
    let data = make_test_data(3, 50, 42);
    let grid = uniform_grid(50);
    let (first, second) = (data.row(0), data.row(1));

    let d12 = elastic_distance(&first, &second, &grid, 0.0);
    let d21 = elastic_distance(&second, &first, &grid, 0.0);

    let tolerance = d12.max(d21) * 0.3 + 0.01;
    assert!(
        (d12 - d21).abs() < tolerance,
        "Elastic distance should be roughly symmetric: d12={d12}, d21={d21}"
    );
}
2635
#[test]
fn test_elastic_distance_nonneg() {
    // Every ordered pair of curves, including a curve with itself, is checked.
    let data = make_test_data(3, 50, 42);
    let grid = uniform_grid(50);

    let pairs = (0..3).flat_map(|i| (0..3).map(move |j| (i, j)));
    for (i, j) in pairs {
        let from = data.row(i);
        let to = data.row(j);
        let d = elastic_distance(&from, &to, &grid, 0.0);
        assert!(d >= 0.0, "Elastic distance should be non-negative");
    }
}
2650
#[test]
fn test_elastic_distance_self_near_zero() {
    // The distance from a curve to itself should be close to zero.
    let data = make_test_data(3, 50, 42);
    let grid = uniform_grid(50);

    (0..3).for_each(|i| {
        let curve = data.row(i);
        let d = elastic_distance(&curve, &curve, &grid, 0.0);
        assert!(
            d < 0.1,
            "Self-distance should be near zero, got {d} for curve {i}"
        );
    });
}
2665
#[test]
fn test_elastic_distance_triangle_inequality() {
    // Triangle inequality with a fixed slack added to the right-hand side.
    let data = make_test_data(3, 50, 42);
    let grid = uniform_grid(50);
    let curves: Vec<Vec<f64>> = (0..3).map(|i| data.row(i)).collect();
    let dist = |a: usize, b: usize| elastic_distance(&curves[a], &curves[b], &grid, 0.0);

    let d01 = dist(0, 1);
    let d12 = dist(1, 2);
    let d02 = dist(0, 2);

    let slack = 0.5;
    let bound = d01 + d12 + slack;
    assert!(
        d02 <= bound,
        "Triangle inequality (relaxed): d02={d02:.4} > d01={d01:.4} + d12={d12:.4} + {slack}"
    );
}
2685
#[test]
fn test_elastic_distance_different_shapes_nonzero() {
    // A straight line and a parabola must be a measurable distance apart.
    let m = 50;
    let grid = uniform_grid(m);
    let line: Vec<f64> = grid.to_vec();
    let parabola: Vec<f64> = grid.iter().map(|&ti| ti * ti).collect();
    let d = elastic_distance(&line, &parabola, &grid, 0.0);
    assert!(
        d > 0.01,
        "Distance between different shapes should be > 0, got {d}"
    );
}
2699
#[test]
fn test_self_distance_matrix_symmetric() {
    // Expect a 5x5 matrix with a zero diagonal and mirrored off-diagonal entries.
    let data = make_test_data(5, 30, 42);
    let grid = uniform_grid(30);

    let dm = elastic_self_distance_matrix(&data, &grid, 0.0);
    let n = dm.nrows();
    assert_eq!(dm.shape(), (5, 5));

    (0..n).for_each(|i| {
        assert!(dm[(i, i)].abs() < 1e-12, "Diagonal should be zero");
    });

    let upper_triangle = (0..n).flat_map(|i| ((i + 1)..n).map(move |j| (i, j)));
    for (i, j) in upper_triangle {
        let mismatch = (dm[(i, j)] - dm[(j, i)]).abs();
        assert!(mismatch < 1e-12, "Matrix should be symmetric at ({i},{j})");
    }
}
2727
#[test]
fn test_self_distance_matrix_nonneg() {
    // No entry of the distance matrix may be negative.
    let data = make_test_data(4, 30, 42);
    let grid = uniform_grid(30);
    let dm = elastic_self_distance_matrix(&data, &grid, 0.0);

    let cells = (0..4).flat_map(|i| (0..4).map(move |j| (i, j)));
    for (i, j) in cells {
        assert!(
            dm[(i, j)] >= 0.0,
            "Distance matrix entries should be non-negative at ({i},{j})"
        );
    }
}
2743
#[test]
fn test_self_distance_matrix_single_curve() {
    // One curve gives a 1x1 matrix whose only entry is (numerically) zero.
    let single = make_test_data(1, 30, 42);
    let grid = uniform_grid(30);
    let dm = elastic_self_distance_matrix(&single, &grid, 0.0);
    assert_eq!(dm.shape(), (1, 1));
    let only_entry = dm[(0, 0)];
    assert!(only_entry.abs() < 1e-12);
}
2752
#[test]
fn test_self_distance_matrix_consistent_with_pairwise() {
    // Each upper-triangle entry must agree with a direct pairwise distance call.
    let data = make_test_data(4, 30, 42);
    let grid = uniform_grid(30);
    let dm = elastic_self_distance_matrix(&data, &grid, 0.0);

    let upper_triangle = (0..4).flat_map(|i| ((i + 1)..4).map(move |j| (i, j)));
    for (i, j) in upper_triangle {
        let from = data.row(i);
        let to = data.row(j);
        let d_direct = elastic_distance(&from, &to, &grid, 0.0);
        let entry = dm[(i, j)];
        assert!(
            (entry - d_direct).abs() < 1e-10,
            "Matrix entry ({i},{j})={:.6} should match pairwise {d_direct:.6}",
            entry
        );
    }
}
2774
#[test]
fn test_karcher_mean_identical_curves() {
    // Five copies of one sine curve: the mean has the grid length and the
    // iteration count stays within the limit.
    let m = 50;
    let t = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;
    let wave: Vec<f64> = t.iter().map(|&ti| (two_pi * ti).sin()).collect();

    let mut data = FdMatrix::zeros(5, m);
    for (j, &value) in wave.iter().enumerate() {
        for i in 0..5 {
            data[(i, j)] = value;
        }
    }

    let result = karcher_mean(&data, &t, 10, 1e-4, 0.0);

    assert_eq!(result.mean.len(), m);
    assert!(result.n_iter <= 10);
}
2799
#[test]
fn test_karcher_mean_output_shape() {
    // All outputs must be sized consistently with 15 curves on 50 points.
    let (n_curves, n_points) = (15, 50);
    let data = make_test_data(n_curves, n_points, 42);
    let grid = uniform_grid(n_points);

    let fit = karcher_mean(&data, &grid, 5, 1e-3, 0.0);

    assert_eq!(fit.mean.len(), 50);
    assert_eq!(fit.mean_srsf.len(), 50);
    assert_eq!(fit.gammas.shape(), (15, 50));
    assert_eq!(fit.aligned_data.shape(), (15, 50));
    assert!(fit.n_iter <= 5);
}
2813
#[test]
fn test_karcher_mean_warps_are_valid() {
    // Every estimated warp must be pinned at both ends and non-decreasing.
    let data = make_test_data(10, 40, 42);
    let t = uniform_grid(40);

    let result = karcher_mean(&data, &t, 5, 1e-3, 0.0);
    let warps = &result.gammas;

    for i in 0..10 {
        let start_gap = (warps[(i, 0)] - t[0]).abs();
        assert!(start_gap < 1e-10, "Warp {i} should start at domain start");
        let end_gap = (warps[(i, 39)] - t[39]).abs();
        assert!(end_gap < 1e-10, "Warp {i} should end at domain end");
        for j in 1..40 {
            assert!(
                warps[(i, j - 1)] <= warps[(i, j)],
                "Warp {i} should be monotone at j={j}"
            );
        }
    }
}
2840
#[test]
fn test_karcher_mean_aligned_data_is_finite() {
    // No NaN or infinity may appear in the aligned curves.
    let data = make_test_data(8, 40, 42);
    let grid = uniform_grid(40);
    let aligned = karcher_mean(&data, &grid, 5, 1e-3, 0.0).aligned_data;

    let cells = (0..8).flat_map(|i| (0..40).map(move |j| (i, j)));
    for (i, j) in cells {
        assert!(
            aligned[(i, j)].is_finite(),
            "Aligned data should be finite at ({i},{j})"
        );
    }
}
2856
#[test]
fn test_karcher_mean_srsf_is_finite() {
    // Both the mean SRSF and the mean curve must be free of NaN and infinity.
    let data = make_test_data(8, 40, 42);
    let grid = uniform_grid(40);
    let fit = karcher_mean(&data, &grid, 5, 1e-3, 0.0);

    (0..40).for_each(|j| {
        let srsf_value = fit.mean_srsf[j];
        assert!(srsf_value.is_finite(), "Mean SRSF should be finite at j={j}");
        let mean_value = fit.mean[j];
        assert!(mean_value.is_finite(), "Mean curve should be finite at j={j}");
    });
}
2874
#[test]
fn test_karcher_mean_single_iteration() {
    // With the iteration limit set to one, exactly one iteration is reported
    // and the mean is still well-formed.
    let data = make_test_data(10, 40, 42);
    let grid = uniform_grid(40);
    let fit = karcher_mean(&data, &grid, 1, 1e-10, 0.0);

    assert_eq!(fit.n_iter, 1);
    assert_eq!(fit.mean.len(), 40);
    for value in &fit.mean[..40] {
        assert!(value.is_finite());
    }
}
2888
#[test]
fn test_karcher_mean_convergence_not_premature() {
    // Phase-shifted sines: a very tight tolerance must keep iterating, while a
    // loose one must report convergence.
    let n = 10;
    let m = 50;
    let t = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;

    // Column-major layout: entry (i, j) lives at index i + j * n.
    let mut col_major: Vec<f64> = Vec::with_capacity(n * m);
    for j in 0..m {
        for i in 0..n {
            let shift = (i as f64 - 5.0) * 0.05;
            col_major.push((two_pi * (t[j] - shift)).sin());
        }
    }
    let data = FdMatrix::from_column_major(col_major, n, m).unwrap();

    let strict = karcher_mean(&data, &t, 20, 1e-15, 0.0);
    assert!(
        strict.n_iter > 2,
        "With tol=1e-15 the algorithm should iterate beyond 2, got n_iter={}",
        strict.n_iter
    );

    let loose = karcher_mean(&data, &t, 20, 1e-2, 0.0);
    assert!(loose.converged, "With tol=1e-2 the algorithm should converge");
}
2925
/// `align_to_target` must return 10x40 warps and aligned data plus ten
/// non-negative distances.
#[test]
fn test_align_to_target_valid() {
    let curves = make_test_data(10, 40, 42);
    let grid = uniform_grid(40);
    let reference = curves.row(0);

    let result = align_to_target(&curves, &reference, &grid, 0.0);

    let expected_shape = (10, 40);
    assert_eq!(result.gammas.shape(), expected_shape);
    assert_eq!(result.aligned_data.shape(), expected_shape);
    assert_eq!(result.distances.len(), 10);

    for d in result.distances.iter().copied() {
        assert!(d >= 0.0);
    }
}
2945
/// Aligning the data set to its own first row must give a small distance
/// for that row.
#[test]
fn test_align_to_target_self_near_zero() {
    let curves = make_test_data(5, 40, 42);
    let grid = uniform_grid(40);
    let reference = curves.row(0);

    let result = align_to_target(&curves, &reference, &grid, 0.0);

    // Row 0 is the target itself.
    let self_distance = result.distances[0];
    assert!(
        self_distance < 0.1,
        "Self-alignment distance should be near zero, got {}",
        self_distance
    );
}
2961
/// Each warp row produced by `align_to_target` must be non-decreasing.
#[test]
fn test_align_to_target_warps_are_monotone() {
    let curves = make_test_data(8, 40, 42);
    let grid = uniform_grid(40);
    let reference = curves.row(0);
    let result = align_to_target(&curves, &reference, &grid, 0.0);

    for i in 0..8 {
        for j in 1..40 {
            let current = result.gammas[(i, j)];
            let previous = result.gammas[(i, j - 1)];
            assert!(
                current >= previous,
                "Warp for curve {i} should be monotone at j={j}"
            );
        }
    }
}
2978
/// Every cell of the aligned data from `align_to_target` must be finite.
#[test]
fn test_align_to_target_aligned_data_finite() {
    let curves = make_test_data(6, 40, 42);
    let grid = uniform_grid(40);
    let reference = curves.row(0);
    let result = align_to_target(&curves, &reference, &grid, 0.0);

    // Row-major sweep over the 6x40 matrix.
    let cells = (0..6).flat_map(|i| (0..40).map(move |j| (i, j)));
    for (i, j) in cells {
        assert!(
            result.aligned_data[(i, j)].is_finite(),
            "Aligned data should be finite at ({i},{j})"
        );
    }
}
2995
/// The cross-distance matrix of a 3-curve and a 4-curve set must be 3x4
/// with non-negative entries.
#[test]
fn test_cross_distance_matrix_shape() {
    let first = make_test_data(3, 30, 42);
    let second = make_test_data(4, 30, 99);
    let grid = uniform_grid(30);

    let dm = elastic_cross_distance_matrix(&first, &second, &grid, 0.0);
    assert_eq!(dm.shape(), (3, 4));

    let cells = (0..3).flat_map(|i| (0..4).map(move |j| (i, j)));
    for cell in cells {
        assert!(dm[cell] >= 0.0);
    }
}
3014
/// The cross-distance matrix of a data set against itself must have a
/// near-zero diagonal.
#[test]
fn test_cross_distance_matrix_self_matches_self_matrix() {
    let curves = make_test_data(4, 30, 42);
    let grid = uniform_grid(30);

    let cross = elastic_cross_distance_matrix(&curves, &curves, &grid, 0.0);
    for i in 0..4 {
        let diagonal = cross[(i, i)];
        assert!(
            diagonal < 0.1,
            "Cross distance (self) diagonal should be near zero: got {}",
            diagonal
        );
    }
}
3030
/// Each cross-distance matrix entry must equal the direct `elastic_distance`
/// of the corresponding pair of rows.
#[test]
fn test_cross_distance_matrix_consistent_with_pairwise() {
    let data1 = make_test_data(3, 30, 42);
    let data2 = make_test_data(2, 30, 99);
    let t = uniform_grid(30);

    let dm = elastic_cross_distance_matrix(&data1, &data2, &t, 0.0);

    for i in 0..3 {
        let fi = data1.row(i);
        for j in 0..2 {
            let fj = data2.row(j);
            let d_direct = elastic_distance(&fi, &fj, &t, 0.0);
            let from_matrix = dm[(i, j)];
            assert!(
                (from_matrix - d_direct).abs() < 1e-10,
                "Cross matrix ({i},{j})={:.6} should match pairwise {d_direct:.6}",
                from_matrix
            );
        }
    }
}
3052
/// Aligning an SRSF to itself must give a warp close to the grid and a
/// small distance between the SRSF and its aligned version.
#[test]
fn test_align_srsf_pair_identity() {
    let m = 50;
    let t = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;
    let f: Vec<f64> = t.iter().map(|&ti| (two_pi * ti).sin()).collect();
    let q = srsf_single(&f, &t);

    let (gamma, q_aligned) = align_srsf_pair(&q, &q, &t, 0.0);

    // The warp is compared pointwise against the grid itself.
    for j in 0..m {
        let deviation = (gamma[j] - t[j]).abs();
        assert!(
            deviation < 0.15,
            "Self-SRSF alignment warp should be near identity at j={j}"
        );
    }

    let weights = simpsons_weights(&t);
    let dist = l2_distance(&q, &q_aligned, &weights);
    assert!(
        dist < 0.5,
        "Self-aligned SRSF distance should be small, got {dist}"
    );
}
3083
/// `srsf_single` on one curve must agree with row 0 of `srsf_transform`
/// applied to the same curve stored as a 1-row matrix.
#[test]
fn test_srsf_single_matches_matrix_version() {
    let m = 50;
    let t = uniform_grid(m);
    let f: Vec<f64> = t.iter().map(|&ti| ti * ti + ti).collect();

    let q_single = srsf_single(&f, &t);

    let as_matrix = FdMatrix::from_slice(&f, 1, m).unwrap();
    let q_from_mat = srsf_transform(&as_matrix, &t).row(0);

    for j in 0..m {
        let gap = (q_single[j] - q_from_mat[j]).abs();
        assert!(
            gap < 1e-12,
            "srsf_single should match srsf_transform at j={j}"
        );
    }
}
3105
/// Table-driven check of `gcd` on small inputs, including zero operands.
#[test]
fn test_gcd_basic() {
    // (a, b, expected gcd)
    let cases = [
        (1, 1, 1),
        (6, 4, 2),
        (7, 5, 1),
        (12, 8, 4),
        (7, 0, 7),
        (0, 5, 5),
    ];
    for (a, b, expected) in cases {
        assert_eq!(gcd(a, b), expected);
    }
}
3117
/// `generate_coprime_nbhd` must return 1 pair for size 1 and 35 pairs for size 7.
#[test]
fn test_coprime_nbhd_count() {
    // (neighbourhood size, expected number of pairs)
    for (size, expected_len) in [(1, 1), (7, 35)] {
        assert_eq!(generate_coprime_nbhd(size).len(), expected_len);
    }
}
3125
/// The generated size-7 neighbourhood must equal `COPRIME_NBHD_7` element by element.
#[test]
fn test_coprime_nbhd_matches_const() {
    let generated = generate_coprime_nbhd(7);
    assert_eq!(generated.len(), COPRIME_NBHD_7.len());
    // Lengths are equal at this point, so zipping visits every index.
    for (i, (pair, expected)) in generated.iter().zip(COPRIME_NBHD_7.iter()).enumerate() {
        assert_eq!(pair, expected, "mismatch at index {i}");
    }
}
3134
/// Every pair in `COPRIME_NBHD_7` must have gcd 1 and both components in 1..=7.
#[test]
fn test_coprime_nbhd_all_coprime() {
    let valid = 1..=7;
    for &(i, j) in COPRIME_NBHD_7.iter() {
        assert_eq!(gcd(i, j), 1, "({i},{j}) should be coprime");
        assert!(valid.contains(&i));
        assert!(valid.contains(&j));
    }
}
3143
/// `dp_edge_weight` between two identical constant SRSFs over the edge
/// (0, 1, 0, 1) must be zero.
#[test]
fn test_dp_edge_weight_diagonal() {
    let grid = uniform_grid(10);
    let q1 = vec![1.0; 10];
    let q2 = q1.clone();
    let weight = dp_edge_weight(&q1, &q2, &grid, 0, 1, 0, 1);
    assert!(weight.abs() < 1e-12, "identical SRSFs should have zero cost");
}
3156
/// `dp_edge_weight` for constant SRSFs 1 and 0 over the edge (0, 2, 0, 1)
/// on a 10-point grid must equal 2/9.
#[test]
fn test_dp_edge_weight_non_diagonal() {
    let t = uniform_grid(10);
    let all_ones = vec![1.0; 10];
    let all_zeros = vec![0.0; 10];
    let expected = 2.0 / 9.0;

    let w = dp_edge_weight(&all_ones, &all_zeros, &t, 0, 2, 0, 1);
    let error = (w - expected).abs();
    assert!(
        error < 1e-10,
        "dp_edge_weight (1,2): expected {expected}, got {w}"
    );
}
3172
/// An edge whose first or second index pair has equal endpoints must cost
/// `f64::INFINITY`.
#[test]
fn test_dp_edge_weight_zero_span() {
    let grid = uniform_grid(10);
    let q1 = vec![1.0; 10];
    let q2 = vec![1.0; 10];

    // One case per degenerate index pair.
    for (a0, a1, b0, b1) in [(3, 3, 0, 1), (0, 1, 3, 3)] {
        assert_eq!(
            dp_edge_weight(&q1, &q2, &grid, a0, a1, b0, b1),
            f64::INFINITY
        );
    }
}
3183
/// The distance reported by `elastic_align_pair` must not exceed the plain
/// L2 distance between the unaligned SRSFs (up to 1e-6).
#[test]
fn test_alignment_improves_distance() {
    let m = 50;
    let grid = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;
    let f1: Vec<f64> = grid.iter().map(|&x| (two_pi * x).sin()).collect();
    let f2: Vec<f64> = grid.iter().map(|&x| (two_pi * (x + 0.2)).sin()).collect();

    // Baseline: SRSF distance with no warping applied.
    let weights = simpsons_weights(&grid);
    let unaligned_srsf_dist = {
        let q1 = srsf_single(&f1, &grid);
        let q2 = srsf_single(&f2, &grid);
        l2_distance(&q1, &q2, &weights)
    };

    let aligned = elastic_align_pair(&f1, &f2, &grid, 0.0);

    assert!(
        aligned.distance <= unaligned_srsf_dist + 1e-6,
        "aligned SRSF dist ({}) should be <= unaligned SRSF dist ({})",
        aligned.distance,
        unaligned_srsf_dist
    );
}
3215
/// Aligning two identical constant curves must give a near-zero distance
/// and an aligned curve of the input length.
#[test]
fn test_alignment_constant_curves() {
    let m = 30;
    let grid = uniform_grid(m);
    let f1 = vec![5.0; m];
    let f2 = f1.clone();

    let aligned = elastic_align_pair(&f1, &f2, &grid, 0.0);
    assert!(
        aligned.distance < 0.01,
        "Constant curves: distance should be ~0"
    );
    assert_eq!(aligned.f_aligned.len(), m);
}
3232
/// The Karcher mean of five curves that are all constant 3.0 must stay
/// within 0.5 of 3.0 at every grid point.
#[test]
fn test_karcher_mean_constant_curves() {
    let m = 30;
    let t = uniform_grid(m);

    // Start from zeros and overwrite every cell with the constant.
    let mut data = FdMatrix::zeros(5, m);
    for cell in (0..5).flat_map(|i| (0..m).map(move |j| (i, j))) {
        data[cell] = 3.0;
    }

    let result = karcher_mean(&data, &t, 5, 1e-4, 0.0);
    for j in 0..m {
        let value = result.mean[j];
        assert!(
            (value - 3.0).abs() < 0.5,
            "Mean of constant curves should be near 3.0, got {} at j={j}",
            value
        );
    }
}
3253
/// `srsf_transform` must not panic on a curve containing a NaN and must
/// still return one row.
#[test]
fn test_nan_srsf_no_panic() {
    let m = 20;
    let grid = uniform_grid(m);
    // Constant 1.0 with a single NaN at index 5.
    let f: Vec<f64> = (0..m)
        .map(|j| if j == 5 { f64::NAN } else { 1.0 })
        .collect();
    let as_matrix = FdMatrix::from_slice(&f, 1, m).unwrap();
    let q = srsf_transform(&as_matrix, &grid);
    assert_eq!(q.nrows(), 1);
}
3265
/// `karcher_mean` on a single curve must return a finite mean of length `m`.
#[test]
fn test_n1_karcher_mean() {
    let m = 30;
    let grid = uniform_grid(m);
    let parabola: Vec<f64> = grid.iter().map(|&ti| ti * ti).collect();
    let data = FdMatrix::from_slice(&parabola, 1, m).unwrap();

    let result = karcher_mean(&data, &grid, 5, 1e-4, 0.0);
    assert_eq!(result.mean.len(), m);
    // The length was just pinned to m, so iterating covers indices 0..m.
    for value in result.mean.iter() {
        assert!(value.is_finite());
    }
}
3279
/// `elastic_distance` on a two-point grid must be finite and non-negative.
#[test]
fn test_two_point_grid() {
    let grid = vec![0.0, 1.0];
    let unit_ramp = vec![0.0, 1.0];
    let double_ramp = vec![0.0, 2.0];

    let d = elastic_distance(&unit_ramp, &double_ramp, &grid, 0.0);
    assert!(d >= 0.0);
    assert!(d.is_finite());
}
3289
/// `elastic_align_pair` on an unevenly spaced grid must return a warp of
/// grid length and a finite, non-negative distance.
#[test]
fn test_non_uniform_grid_alignment() {
    let grid: Vec<f64> = vec![0.0, 0.01, 0.05, 0.2, 0.5, 1.0];
    let m = grid.len();
    let f1: Vec<f64> = grid.iter().map(|ti| ti.sin()).collect();
    let f2: Vec<f64> = grid.iter().map(|ti| (*ti + 0.1).sin()).collect();

    let aligned = elastic_align_pair(&f1, &f2, &grid, 0.0);
    assert_eq!(aligned.gamma.len(), m);
    assert!(aligned.distance >= 0.0);
    assert!(aligned.distance.is_finite());
}
3302
/// Checks the dimensions of every field returned by `tsrvf_transform`.
#[test]
fn test_tsrvf_output_shape() {
    let m = 50;
    let n = 10;
    let grid = uniform_grid(m);
    let curves = make_test_data(n, m, 42);
    let result = tsrvf_transform(&curves, &grid, 5, 1e-3, 0.0);

    // Matrix-valued outputs.
    let expected_shape = (n, m);
    assert_eq!(
        result.tangent_vectors.shape(),
        expected_shape,
        "Tangent vectors should be n×m"
    );
    assert_eq!(result.gammas.shape(), expected_shape, "Gammas should be n×m");

    // Vector-valued outputs.
    assert_eq!(result.srsf_norms.len(), n, "Should have n SRSF norms");
    assert_eq!(result.mean.len(), m, "Mean should have m points");
    assert_eq!(result.mean_srsf.len(), m, "Mean SRSF should have m points");
}
3322
/// Tangent vectors, per-curve SRSF norms and the mean SRSF norm returned by
/// `tsrvf_transform` must all be finite.
#[test]
fn test_tsrvf_all_finite() {
    let m = 50;
    let n = 5;
    let grid = uniform_grid(m);
    let curves = make_test_data(n, m, 42);
    let result = tsrvf_transform(&curves, &grid, 5, 1e-3, 0.0);

    for i in 0..n {
        // Row i of the tangent vectors first, then that curve's norm.
        for j in 0..m {
            let entry = result.tangent_vectors[(i, j)];
            assert!(
                entry.is_finite(),
                "Tangent vector should be finite at ({i},{j})"
            );
        }
        let norm = result.srsf_norms[i];
        assert!(
            norm.is_finite(),
            "SRSF norm should be finite for curve {i}"
        );
    }

    assert!(
        result.mean_srsf_norm.is_finite(),
        "Mean SRSF norm should be finite"
    );
}
3347
/// Five copies of the same sine curve must yield tangent vectors whose
/// Euclidean norm is below 0.5.
#[test]
fn test_tsrvf_identical_curves_zero_tangent() {
    let m = 50;
    let grid = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;
    let curve: Vec<f64> = grid.iter().map(|&ti| (two_pi * ti).sin()).collect();

    // Column-major with 5 rows: each grid value is repeated once per row.
    let col_major: Vec<f64> = curve
        .iter()
        .flat_map(|&value| std::iter::repeat(value).take(5))
        .collect();
    let data = FdMatrix::from_column_major(col_major, 5, m).unwrap();
    let result = tsrvf_transform(&data, &grid, 10, 1e-4, 0.0);

    for i in 0..5 {
        let tv_norm_sq: f64 = (0..m).map(|j| result.tangent_vectors[(i, j)].powi(2)).sum();
        let tv_norm = tv_norm_sq.sqrt();
        assert!(
            tv_norm < 0.5,
            "Identical curves should have near-zero tangent vectors, got norm = {}",
            tv_norm
        );
    }
}
3376
/// The pointwise average of the tangent vectors must have Euclidean norm
/// below 1.0.
#[test]
fn test_tsrvf_mean_tangent_near_zero() {
    let m = 50;
    let n = 10;
    let grid = uniform_grid(m);
    let curves = make_test_data(n, m, 42);
    let result = tsrvf_transform(&curves, &grid, 10, 1e-3, 0.0);

    // Column-wise mean, accumulated over rows in index order starting from 0.0.
    let mean_tv: Vec<f64> = (0..m)
        .map(|j| {
            let column_sum =
                (0..n).fold(0.0_f64, |acc, i| acc + result.tangent_vectors[(i, j)]);
            column_sum / n as f64
        })
        .collect();

    let mean_norm: f64 = mean_tv.iter().map(|v| v * v).sum::<f64>().sqrt();
    assert!(
        mean_norm < 1.0,
        "Mean tangent vector should be near zero, got norm = {mean_norm}"
    );
}
3401
/// `tsrvf_from_alignment` on a precomputed Karcher mean must return n x m
/// tangent vectors and a positive mean SRSF norm.
#[test]
fn test_tsrvf_from_alignment() {
    let (n, m) = (5, 50);
    let grid = uniform_grid(m);
    let curves = make_test_data(n, m, 42);

    let karcher = karcher_mean(&curves, &grid, 5, 1e-3, 0.0);
    let tsrvf = tsrvf_from_alignment(&karcher, &grid);

    assert_eq!(tsrvf.tangent_vectors.shape(), (n, m));
    assert!(tsrvf.mean_srsf_norm > 0.0);
}
3413
/// `tsrvf_inverse` must give finite curves with the tangent-vector shape,
/// each starting at the stored initial value.
#[test]
fn test_tsrvf_round_trip() {
    let m = 50;
    let n = 5;
    let grid = uniform_grid(m);
    let curves = make_test_data(n, m, 42);
    let result = tsrvf_transform(&curves, &grid, 10, 1e-3, 0.0);
    let reconstructed = tsrvf_inverse(&result, &grid);

    assert_eq!(reconstructed.shape(), result.tangent_vectors.shape());

    // Finiteness of every reconstructed sample.
    for (i, j) in (0..n).flat_map(|i| (0..m).map(move |j| (i, j))) {
        assert!(
            reconstructed[(i, j)].is_finite(),
            "Reconstructed curve should be finite at ({i},{j})"
        );
    }

    // The first sample of each curve must reproduce its initial value.
    for i in 0..n {
        let expected = result.initial_values[i];
        let actual = reconstructed[(i, 0)];
        assert!(
            (actual - expected).abs() < 1e-6,
            "Curve {i} initial value: expected {}, got {}",
            expected,
            actual
        );
    }
}
3443
/// Curves with different vertical offsets must keep distinct initial values
/// through `tsrvf_transform` and `tsrvf_inverse`.
#[test]
fn test_tsrvf_initial_values_per_curve() {
    let m = 50;
    let n = 5;
    let t = uniform_grid(m);

    // Column-major sine curves; curve i is lifted by (i + 1) * 2.
    let col_major: Vec<f64> = (0..n * m)
        .map(|idx| {
            let (i, j) = (idx % n, idx / n);
            let offset = (i as f64 + 1.0) * 2.0;
            offset + (2.0 * std::f64::consts::PI * t[j]).sin()
        })
        .collect();
    let data = FdMatrix::from_column_major(col_major, n, m).unwrap();

    let result = tsrvf_transform(&data, &t, 15, 1e-4, 0.0);

    assert_eq!(result.initial_values.len(), n);
    let stored_all_equal = result
        .initial_values
        .windows(2)
        .all(|pair| (pair[0] - pair[1]).abs() < 1e-10);
    assert!(
        !stored_all_equal,
        "Initial values should differ per curve: {:?}",
        result.initial_values
    );

    let reconstructed = tsrvf_inverse(&result, &t);
    for i in 0..n {
        let rebuilt = reconstructed[(i, 0)];
        let stored = result.initial_values[i];
        assert!(
            (rebuilt - stored).abs() < 1e-6,
            "Curve {i}: reconstructed f(0) = {}, expected {}",
            rebuilt,
            stored
        );
    }

    // The reconstructed starting points themselves must also differ.
    let first_samples: Vec<f64> = (0..n).map(|i| reconstructed[(i, 0)]).collect();
    let rebuilt_all_equal = first_samples
        .windows(2)
        .all(|pair| (pair[0] - pair[1]).abs() < 0.1);
    assert!(
        !rebuilt_all_equal,
        "Reconstructed initial values must vary per curve: {:?}",
        first_samples
    );
}
3496
/// A single-curve input must yield a 1 x m tangent matrix whose row has
/// Euclidean norm below 0.5.
#[test]
fn test_tsrvf_single_curve() {
    let m = 50;
    let grid = uniform_grid(m);
    let curves = make_test_data(1, m, 42);
    let result = tsrvf_transform(&curves, &grid, 5, 1e-3, 0.0);
    assert_eq!(result.tangent_vectors.shape(), (1, m));

    let sum_of_squares: f64 = (0..m)
        .map(|j| result.tangent_vectors[(0, j)].powi(2))
        .sum::<f64>();
    let tv_norm = sum_of_squares.sqrt();
    assert!(
        tv_norm < 0.5,
        "Single curve tangent vector should be near zero, got {tv_norm}"
    );
}
3514
/// Three constant curves must not produce NaN tangent vectors.
#[test]
fn test_tsrvf_constant_curves() {
    let m = 30;
    let grid = uniform_grid(m);
    let data = FdMatrix::from_column_major(vec![5.0; 3 * m], 3, m).unwrap();
    let result = tsrvf_transform(&data, &grid, 5, 1e-3, 0.0);

    let cells = (0..3).flat_map(|i| (0..m).map(move |j| (i, j)));
    for cell in cells {
        let value = result.tangent_vectors[cell];
        assert!(
            !value.is_nan(),
            "Constant curves should not produce NaN tangent vectors"
        );
    }
}
3533
/// The norm of `inv_exp_map_sphere(psi1, psi2)` must equal the arccosine of
/// the inner product of the two unit-normalised functions.
#[test]
fn test_tsrvf_sphere_inv_exp_reference() {
    let m = 21;
    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();

    // First point: the constant function, scaled to unit norm.
    let constant = vec![1.0; m];
    let constant_norm = inner_product_l2(&constant, &constant, &time).max(0.0).sqrt();
    let psi1: Vec<f64> = constant.iter().map(|&v| v / constant_norm).collect();

    // Second point: a perturbed constant, also scaled to unit norm.
    let two_pi = 2.0 * std::f64::consts::PI;
    let perturbed: Vec<f64> = time.iter().map(|&t| 1.0 + 0.3 * (two_pi * t).sin()).collect();
    let perturbed_norm = inner_product_l2(&perturbed, &perturbed, &time).max(0.0).sqrt();
    let psi2: Vec<f64> = perturbed.iter().map(|&v| v / perturbed_norm).collect();

    let cos_theta = inner_product_l2(&psi1, &psi2, &time).clamp(-1.0, 1.0);
    let theta_expected = cos_theta.acos();

    let v = inv_exp_map_sphere(&psi1, &psi2, &time);
    let v_norm = inner_product_l2(&v, &v, &time).max(0.0).sqrt();

    assert!(
        (v_norm - theta_expected).abs() < 1e-10,
        "||v|| = {v_norm}, expected theta = {theta_expected}"
    );

    // Guard against a degenerate setup where the two points nearly coincide.
    let in_range = theta_expected > 0.01 && theta_expected < 1.0;
    assert!(in_range, "theta = {theta_expected} out of expected range");
}
3576
/// `exp_map_sphere` applied to the result of `inv_exp_map_sphere` must
/// recover the target point to within 1e-12 in L2.
#[test]
fn test_tsrvf_sphere_round_trip_reference() {
    let m = 21;
    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();

    let constant = vec![1.0; m];
    let constant_norm = inner_product_l2(&constant, &constant, &time).max(0.0).sqrt();
    let psi1: Vec<f64> = constant.iter().map(|&v| v / constant_norm).collect();

    let two_pi = 2.0 * std::f64::consts::PI;
    let perturbed: Vec<f64> = time.iter().map(|&t| 1.0 + 0.3 * (two_pi * t).sin()).collect();
    let perturbed_norm = inner_product_l2(&perturbed, &perturbed, &time).max(0.0).sqrt();
    let psi2: Vec<f64> = perturbed.iter().map(|&v| v / perturbed_norm).collect();

    // psi1 -> tangent vector -> back onto the sphere.
    let tangent = inv_exp_map_sphere(&psi1, &psi2, &time);
    let recovered = exp_map_sphere(&psi1, &tangent, &time);

    let squared_diff: Vec<f64> = psi2
        .iter()
        .zip(recovered.iter())
        .map(|(&a, &b)| (a - b).powi(2))
        .collect();
    let l2_err = trapz(&squared_diff, &time).max(0.0).sqrt();
    assert!(
        l2_err < 1e-12,
        "Round-trip L2 error = {l2_err:.2e}, expected < 1e-12"
    );
}
3609
/// With the last argument set to 0.0, `elastic_align_pair` must return a
/// non-negative distance and a warp of grid length.
#[test]
fn test_penalized_alignment_lambda_zero_matches_unpenalized() {
    let m = 50;
    let grid = uniform_grid(m);
    let curves = make_test_data(2, m, 42);
    let (f1, f2) = (curves.row(0), curves.row(1));

    let r0 = elastic_align_pair(&f1, &f2, &grid, 0.0);
    assert!(r0.distance >= 0.0);
    assert_eq!(r0.gamma.len(), m);
}
3625
/// The warp obtained with penalty 1.0 must deviate from the grid no more
/// (in summed squares, up to 1e-6) than the one obtained with penalty 0.0.
#[test]
fn test_penalized_alignment_smoother_warp() {
    let m = 80;
    let t = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;
    let f1: Vec<f64> = t.iter().map(|&ti| (two_pi * ti).sin()).collect();
    let f2: Vec<f64> = t.iter().map(|&ti| (two_pi * (ti - 0.15)).sin()).collect();

    let r_free = elastic_align_pair(&f1, &f2, &t, 0.0);
    let r_pen = elastic_align_pair(&f1, &f2, &t, 1.0);

    // Sum of squared deviations of a warp from the grid.
    let squared_deviation = |gamma: &[f64]| -> f64 {
        gamma
            .iter()
            .zip(t.iter())
            .map(|(g, ti)| (g - ti).powi(2))
            .sum()
    };
    let dev_free = squared_deviation(&r_free.gamma);
    let dev_pen = squared_deviation(&r_pen.gamma);

    assert!(
        dev_pen <= dev_free + 1e-6,
        "Penalized warp should be closer to identity: free={dev_free:.6}, pen={dev_pen:.6}"
    );
}
3661
/// With penalty 1000.0 the warp must stay within 0.05 of the grid at every
/// point.
#[test]
fn test_penalized_alignment_large_lambda_near_identity() {
    let m = 50;
    let t = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;
    let f1: Vec<f64> = t.iter().map(|&ti| (two_pi * ti).sin()).collect();
    let f2: Vec<f64> = t.iter().map(|&ti| (two_pi * (ti - 0.1)).sin()).collect();

    let aligned = elastic_align_pair(&f1, &f2, &t, 1000.0);

    // Largest pointwise distance between the warp and the grid.
    let mut max_dev = 0.0_f64;
    for (g, ti) in aligned.gamma.iter().zip(t.iter()) {
        max_dev = f64::max(max_dev, (g - ti).abs());
    }
    assert!(
        max_dev < 0.05,
        "Large lambda should give near-identity warp: max deviation = {max_dev}"
    );
}
3689
/// `karcher_mean` called with last argument 0.5 must return a finite mean
/// of length `m`.
#[test]
fn test_penalized_karcher_mean() {
    let m = 40;
    let grid = uniform_grid(m);
    let curves = make_test_data(10, m, 42);

    let result = karcher_mean(&curves, &grid, 5, 1e-3, 0.5);
    assert_eq!(result.mean.len(), m);
    // The length was just pinned to m, so iterating covers indices 0..m.
    for value in result.mean.iter() {
        assert!(value.is_finite());
    }
}
3702
/// Decomposing a curve against itself must give small amplitude and phase
/// distances.
#[test]
fn test_decomposition_identity_curves() {
    let m = 50;
    let grid = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;
    let f: Vec<f64> = grid.iter().map(|&ti| (two_pi * ti).sin()).collect();

    let result = elastic_decomposition(&f, &f, &grid, 0.0);

    let amplitude = result.d_amplitude;
    assert!(
        amplitude < 0.1,
        "Self-decomposition amplitude should be ~0, got {}",
        amplitude
    );

    let phase = result.d_phase;
    assert!(
        phase < 0.2,
        "Self-decomposition phase should be ~0, got {}",
        phase
    );
}
3726
/// Amplitude and phase distances from `elastic_decomposition` must both be
/// non-negative for a scaled, shifted sine pair.
#[test]
fn test_decomposition_pythagorean() {
    let m = 80;
    let grid = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;
    let f1: Vec<f64> = grid.iter().map(|&ti| (two_pi * ti).sin()).collect();
    let f2: Vec<f64> = grid
        .iter()
        .map(|&ti| 1.2 * (two_pi * (ti - 0.1)).sin())
        .collect();

    let result = elastic_decomposition(&f1, &f2, &grid, 0.0);
    assert!(result.d_amplitude >= 0.0);
    assert!(result.d_phase >= 0.0);
}
3748
/// A sine curve and its copy shifted by 0.15 must have phase distance
/// above 0.01.
#[test]
fn test_phase_distance_shifted_sine() {
    let m = 80;
    let grid = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;
    let f1: Vec<f64> = grid.iter().map(|&ti| (two_pi * ti).sin()).collect();
    let f2: Vec<f64> = grid.iter().map(|&ti| (two_pi * (ti - 0.15)).sin()).collect();

    let dp = phase_distance_pair(&f1, &f2, &grid, 0.0);
    assert!(
        dp > 0.01,
        "Phase distance of shifted curves should be > 0, got {dp}"
    );
}
3768
/// A sine curve and its copy scaled by 2 must have amplitude distance
/// above 0.01.
#[test]
fn test_amplitude_distance_scaled_curve() {
    let m = 80;
    let grid = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;
    let f1: Vec<f64> = grid.iter().map(|&ti| (two_pi * ti).sin()).collect();
    let f2: Vec<f64> = grid.iter().map(|&ti| 2.0 * (two_pi * ti).sin()).collect();

    let da = amplitude_distance(&f1, &f2, &grid, 0.0);
    assert!(
        da > 0.01,
        "Amplitude distance of scaled curves should be > 0, got {da}"
    );
}
3788
/// `phase_distance_pair` must be non-negative for every ordered pair of
/// rows, including a row against itself.
#[test]
fn test_phase_distance_nonneg() {
    let curves = make_test_data(4, 40, 42);
    let grid = uniform_grid(40);

    for i in 0..4 {
        let fi = curves.row(i);
        for j in 0..4 {
            let fj = curves.row(j);
            let dp = phase_distance_pair(&fi, &fj, &grid, 0.0);
            assert!(dp >= 0.0, "Phase distance should be non-negative");
        }
    }
}
3802
/// `parallel_transport_schilds` applied to the zero vector must return a
/// vector of (Euclidean) norm below 1e-6.
#[test]
fn test_schilds_ladder_zero_vector() {
    let m = 21;
    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();

    // Start point: the constant function divided by its `l2_norm_l2`.
    let constant = vec![1.0; m];
    let constant_norm = crate::warping::l2_norm_l2(&constant, &time);
    let from: Vec<f64> = constant.iter().map(|&v| v / constant_norm).collect();

    // End point: a perturbed constant, normalised the same way.
    let two_pi = 2.0 * std::f64::consts::PI;
    let perturbed: Vec<f64> = time.iter().map(|&t| 1.0 + 0.2 * (two_pi * t).sin()).collect();
    let perturbed_norm = crate::warping::l2_norm_l2(&perturbed, &time);
    let to: Vec<f64> = perturbed.iter().map(|&v| v / perturbed_norm).collect();

    let zero = vec![0.0; m];
    let transported = parallel_transport_schilds(&zero, &from, &to, &time);
    let result_norm: f64 = transported.iter().map(|v| v * v).sum::<f64>().sqrt();
    assert!(
        result_norm < 1e-6,
        "Transporting zero should give zero, got norm {result_norm}"
    );
}
3827
/// `parallel_transport_pole` applied to the zero vector must return a
/// vector of (Euclidean) norm below 1e-6.
#[test]
fn test_pole_ladder_zero_vector() {
    let m = 21;
    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();

    // Start point: the constant function divided by its `l2_norm_l2`.
    let constant = vec![1.0; m];
    let constant_norm = crate::warping::l2_norm_l2(&constant, &time);
    let from: Vec<f64> = constant.iter().map(|&v| v / constant_norm).collect();

    // End point: a perturbed constant, normalised the same way.
    let two_pi = 2.0 * std::f64::consts::PI;
    let perturbed: Vec<f64> = time.iter().map(|&t| 1.0 + 0.2 * (two_pi * t).sin()).collect();
    let perturbed_norm = crate::warping::l2_norm_l2(&perturbed, &time);
    let to: Vec<f64> = perturbed.iter().map(|&v| v / perturbed_norm).collect();

    let zero = vec![0.0; m];
    let transported = parallel_transport_pole(&zero, &from, &to, &time);
    let result_norm: f64 = transported.iter().map(|v| v * v).sum::<f64>().sqrt();
    assert!(
        result_norm < 1e-6,
        "Transporting zero should give zero, got norm {result_norm}"
    );
}
3850
/// The norm of a vector transported by `parallel_transport_schilds` must
/// stay within a relative change of 1.5 of the original norm.
#[test]
fn test_schilds_preserves_norm() {
    let m = 51;
    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
    let two_pi = 2.0 * std::f64::consts::PI;
    let four_pi = 4.0 * std::f64::consts::PI;

    let constant = vec![1.0; m];
    let constant_norm = crate::warping::l2_norm_l2(&constant, &time);
    let from: Vec<f64> = constant.iter().map(|&v| v / constant_norm).collect();

    let perturbed: Vec<f64> = time.iter().map(|&t| 1.0 + 0.15 * (two_pi * t).sin()).collect();
    let perturbed_norm = crate::warping::l2_norm_l2(&perturbed, &time);
    let to: Vec<f64> = perturbed.iter().map(|&v| v / perturbed_norm).collect();

    // Vector to transport and its norm before transport.
    let v: Vec<f64> = time.iter().map(|&t| 0.1 * (four_pi * t).cos()).collect();
    let v_norm = crate::warping::l2_norm_l2(&v, &time);

    let transported = parallel_transport_schilds(&v, &from, &to, &time);
    let t_norm = crate::warping::l2_norm_l2(&transported, &time);

    // The denominator is floored at 1e-10 to avoid dividing by zero.
    let relative_change = (t_norm - v_norm).abs() / v_norm.max(1e-10);
    assert!(
        relative_change < 1.5,
        "Schild's should roughly preserve norm: original={v_norm:.4}, transported={t_norm:.4}"
    );
}
3881
/// `tsrvf_transform_with_method` with `TransportMethod::LogMap` must
/// reproduce `tsrvf_transform` to within 1e-12 per entry.
#[test]
fn test_tsrvf_logmap_matches_original() {
    let m = 50;
    let n = 5;
    let grid = uniform_grid(m);
    let curves = make_test_data(n, m, 42);

    let baseline = tsrvf_transform(&curves, &grid, 5, 1e-3, 0.0);
    let via_logmap =
        tsrvf_transform_with_method(&curves, &grid, 5, 1e-3, 0.0, TransportMethod::LogMap);

    for (i, j) in (0..n).flat_map(|i| (0..m).map(move |j| (i, j))) {
        let gap = (baseline.tangent_vectors[(i, j)] - via_logmap.tangent_vectors[(i, j)]).abs();
        assert!(
            gap < 1e-12,
            "LogMap variant should match original at ({i},{j})"
        );
    }
}
3905
/// `tsrvf_transform_with_method` with `TransportMethod::SchildsLadder` must
/// return an n x m matrix of finite tangent vectors.
#[test]
fn test_tsrvf_with_schilds_produces_valid_result() {
    let m = 50;
    let n = 5;
    let grid = uniform_grid(m);
    let curves = make_test_data(n, m, 42);

    let result = tsrvf_transform_with_method(
        &curves,
        &grid,
        5,
        1e-3,
        0.0,
        TransportMethod::SchildsLadder,
    );

    assert_eq!(result.tangent_vectors.shape(), (n, m));
    for (i, j) in (0..n).flat_map(|i| (0..m).map(move |j| (i, j))) {
        assert!(
            result.tangent_vectors[(i, j)].is_finite(),
            "Schild's TSRVF should produce finite tangent vectors at ({i},{j})"
        );
    }
}
3926
/// The summed absolute difference between LogMap and Schild's-ladder
/// tangent vectors must be finite.
#[test]
fn test_transport_methods_differ() {
    let m = 50;
    let n = 5;
    let grid = uniform_grid(m);
    let curves = make_test_data(n, m, 42);
    let karcher = karcher_mean(&curves, &grid, 5, 1e-3, 0.0);

    let r_log = tsrvf_from_alignment_with_method(&karcher, &grid, TransportMethod::LogMap);
    let r_schilds =
        tsrvf_from_alignment_with_method(&karcher, &grid, TransportMethod::SchildsLadder);

    // Accumulate in row-major order, starting from 0.0.
    let total_diff = (0..n)
        .flat_map(|i| (0..m).map(move |j| (i, j)))
        .fold(0.0_f64, |acc, cell| {
            acc + (r_log.tangent_vectors[cell] - r_schilds.tangent_vectors[cell]).abs()
        });

    assert!(total_diff.is_finite());
}
3952
/// `warp_complexity` of the grid against itself must be below 1e-10.
#[test]
fn test_warp_complexity_identity_is_zero() {
    let grid = uniform_grid(50);
    // The identity warp is a copy of the grid.
    let identity_warp = grid.clone();

    let c = warp_complexity(&identity_warp, &grid);
    assert!(
        c < 1e-10,
        "Identity warp should have zero complexity, got {c}"
    );
}
3966
/// `warp_complexity` of the squared grid must exceed 0.01.
#[test]
fn test_warp_complexity_nonidentity_positive() {
    let grid = uniform_grid(50);
    let squared_warp: Vec<f64> = grid.iter().map(|&ti| ti * ti).collect();

    let c = warp_complexity(&squared_warp, &grid);
    assert!(
        c > 0.01,
        "Non-identity warp should have positive complexity, got {c}"
    );
}
3978
/// `warp_smoothness` of the grid against itself must be below 1e-6.
#[test]
fn test_warp_smoothness_identity_is_zero() {
    let grid = uniform_grid(50);
    // The identity warp is a copy of the grid.
    let identity_warp = grid.clone();

    let s = warp_smoothness(&identity_warp, &grid);
    assert!(
        s < 1e-6,
        "Identity warp (constant γ'=1, γ''=0) should have near-zero bending energy, got {s}"
    );
}
3990
/// Checks the lengths, signs and variance ordering of the fields returned
/// by `alignment_quality`.
#[test]
fn test_alignment_quality_basic() {
    let m = 50;
    let n = 8;
    let grid = uniform_grid(m);
    let curves = make_test_data(n, m, 42);
    let karcher = karcher_mean(&curves, &grid, 10, 1e-3, 0.0);
    let quality = alignment_quality(&curves, &karcher, &grid);

    // Per-curve and per-point diagnostics have the expected lengths.
    assert_eq!(quality.warp_complexity.len(), n);
    assert_eq!(quality.warp_smoothness.len(), n);
    assert_eq!(quality.pointwise_variance_ratio.len(), m);

    // Scalar summaries are non-negative.
    assert!(quality.total_variance >= 0.0);
    assert!(quality.amplitude_variance >= 0.0);
    assert!(quality.phase_variance >= 0.0);
    assert!(quality.mean_warp_complexity >= 0.0);
    assert!(quality.mean_warp_smoothness >= 0.0);

    // Amplitude variance may not exceed the total, up to rounding slack.
    let upper_bound = quality.total_variance + 1e-10;
    assert!(
        quality.amplitude_variance <= upper_bound,
        "Amplitude variance ({}) should be ≤ total variance ({})",
        quality.amplitude_variance,
        quality.total_variance
    );
}
4020
/// Five copies of one sine curve must give near-zero total variance and
/// near-zero mean warp complexity.
#[test]
fn test_alignment_quality_identical_curves() {
    let m = 50;
    let grid = uniform_grid(m);
    let two_pi = 2.0 * std::f64::consts::PI;
    let curve: Vec<f64> = grid.iter().map(|&ti| (two_pi * ti).sin()).collect();

    // Column-major with 5 rows: each grid value is repeated once per row.
    let col_major: Vec<f64> = curve
        .iter()
        .flat_map(|&value| std::iter::repeat(value).take(5))
        .collect();
    let data = FdMatrix::from_column_major(col_major, 5, m).unwrap();
    let karcher = karcher_mean(&data, &grid, 5, 1e-3, 0.0);
    let quality = alignment_quality(&data, &karcher, &grid);

    let total = quality.total_variance;
    assert!(
        total < 0.01,
        "Identical curves should have near-zero total variance, got {}",
        total
    );

    let complexity = quality.mean_warp_complexity;
    assert!(
        complexity < 0.1,
        "Identical curves should have near-zero warp complexity, got {}",
        complexity
    );
}
4051
/// `mean_variance_reduction` from `alignment_quality` must not exceed 1.5.
#[test]
fn test_alignment_quality_variance_reduction() {
    let (n, m) = (10, 50);
    let grid = uniform_grid(m);
    let curves = make_test_data(n, m, 42);
    let karcher = karcher_mean(&curves, &grid, 10, 1e-3, 0.0);
    let quality = alignment_quality(&curves, &karcher, &grid);

    let reduction = quality.mean_variance_reduction;
    assert!(
        reduction <= 1.5,
        "Mean variance reduction ratio should be ≤ ~1, got {}",
        reduction
    );
}
4068
/// `pairwise_consistency` on a 4-curve set must be finite and non-negative.
#[test]
fn test_pairwise_consistency_small() {
    let (n, m) = (4, 40);
    let grid = uniform_grid(m);
    let curves = make_test_data(n, m, 42);

    let consistency = pairwise_consistency(&curves, &grid, 0.0, 100);
    let is_valid = consistency.is_finite() && consistency >= 0.0;
    assert!(
        is_valid,
        "Pairwise consistency should be finite and non-negative, got {consistency}"
    );
}
4082
/// `srsf_transform_nd` on a curve set built with `FdCurveSet::from_1d` must
/// match `srsf_transform` on the same matrix to within 1e-10.
#[test]
fn test_srsf_nd_d1_matches_existing() {
    let m = 50;
    let grid = uniform_grid(m);
    let curves = make_test_data(3, m, 42);

    // Compute the reference before `from_1d` takes ownership of the matrix.
    let q_1d = srsf_transform(&curves, &grid);
    let q_nd = srsf_transform_nd(&FdCurveSet::from_1d(curves), &grid);

    assert_eq!(q_nd.ndim(), 1);
    for (i, j) in (0..3).flat_map(|i| (0..m).map(move |j| (i, j))) {
        let reference = q_1d[(i, j)];
        let candidate = q_nd.dims[0][(i, j)];
        assert!(
            (reference - candidate).abs() < 1e-10,
            "1D nd SRSF should match existing at ({i},{j}): {} vs {}",
            reference,
            candidate
        );
    }
}
4110
/// The SRSF of a two-dimensional constant curve must be zero in both
/// dimensions.
#[test]
fn test_srsf_nd_constant_is_zero() {
    let m = 30;
    let grid = uniform_grid(m);
    let dims: Vec<FdMatrix> = [3.0, -1.0]
        .into_iter()
        .map(|level| FdMatrix::from_column_major(vec![level; m], 1, m).unwrap())
        .collect();
    let data = FdCurveSet::from_dims(dims).unwrap();

    let q = srsf_transform_nd(&data, &grid);
    for k in 0..2 {
        for j in 0..m {
            let value = q.dims[k][(0, j)];
            assert!(
                value.abs() < 1e-10,
                "Constant curve SRSF should be zero, dim {k} at {j}: {}",
                value
            );
        }
    }
}
4131
/// For the straight line (2t, 3t) the SRSF at the midpoint must be close to
/// (2, 3) scaled by 13^(-1/4).
#[test]
fn test_srsf_nd_linear_r2() {
    let m = 51;
    let grid = uniform_grid(m);

    let xs: Vec<f64> = grid.iter().map(|&ti| 2.0 * ti).collect();
    let ys: Vec<f64> = grid.iter().map(|&ti| 3.0 * ti).collect();
    let dim0 = FdMatrix::from_slice(&xs, 1, m).unwrap();
    let dim1 = FdMatrix::from_slice(&ys, 1, m).unwrap();
    let data = FdCurveSet::from_dims(vec![dim0, dim1]).unwrap();

    let q = srsf_transform_nd(&data, &grid);
    let expected_scale = 1.0 / 13.0_f64.powf(0.25);
    let mid = m / 2;

    let (qx, expected_x) = (q.dims[0][(0, mid)], 2.0 * expected_scale);
    assert!(
        (qx - expected_x).abs() < 0.1,
        "q_x at midpoint: {} vs expected {}",
        qx,
        expected_x
    );

    let (qy, expected_y) = (q.dims[1][(0, mid)], 3.0 * expected_scale);
    assert!(
        (qy - expected_y).abs() < 0.1,
        "q_y at midpoint: {} vs expected {}",
        qy,
        expected_y
    );
}
4161
/// Round trip on a planar circle `(sin 2πt, cos 2πt)`: transform to the N-d
/// SRSF, invert it from the curve's first sample, and require the largest
/// reconstruction error to stay below 0.2. The two grid points nearest each
/// end of the domain are excluded from the comparison.
#[test]
fn test_srsf_nd_round_trip() {
    use std::f64::consts::PI;

    let n_points = 51;
    let grid = uniform_grid(n_points);
    let omega = 2.0 * PI;
    let vals_x: Vec<f64> = grid.iter().map(|&ti| (omega * ti).sin()).collect();
    let vals_y: Vec<f64> = grid.iter().map(|&ti| (omega * ti).cos()).collect();
    let curve = FdCurveSet::from_dims(vec![
        FdMatrix::from_slice(&vals_x, 1, n_points).unwrap(),
        FdMatrix::from_slice(&vals_y, 1, n_points).unwrap(),
    ])
    .unwrap();

    let q = srsf_transform_nd(&curve, &grid);
    let q_vecs: Vec<Vec<f64>> = q.dims.iter().map(|dm| dm.row(0)).collect();
    let start = vec![vals_x[0], vals_y[0]];
    let recon = srsf_inverse_nd(&q_vecs, &grid, &start);

    // Track the worst absolute deviation over both components.
    let mut max_err = 0.0_f64;
    for (k, orig) in [&vals_x, &vals_y].iter().enumerate() {
        for j in 2..(n_points - 2) {
            max_err = max_err.max((recon[k][j] - orig[j]).abs());
        }
    }
    assert!(
        max_err < 0.2,
        "SRSF round-trip max error should be small, got {max_err}"
    );
}
4193
/// Aligning a planar circle to itself should yield a distance below 0.5 and a
/// warping function that stays within 0.1 of the identity on the grid.
#[test]
fn test_align_nd_identical_near_zero() {
    use std::f64::consts::PI;

    let n_points = 50;
    let grid = uniform_grid(n_points);
    let omega = 2.0 * PI;

    // Samples one coordinate (`sin` or `cos`) of the circle on the grid.
    let component = |wave: fn(f64) -> f64| -> FdMatrix {
        let samples: Vec<f64> = grid.iter().map(|&ti| wave(omega * ti)).collect();
        FdMatrix::from_slice(&samples, 1, n_points).unwrap()
    };
    let curve = FdCurveSet::from_dims(vec![component(f64::sin), component(f64::cos)]).unwrap();

    let result = elastic_align_pair_nd(&curve, &curve, &grid, 0.0);
    assert!(
        result.distance < 0.5,
        "Self-alignment distance should be ~0, got {}",
        result.distance
    );

    // Largest pointwise gap between the warp and the identity.
    let mut max_dev = 0.0_f64;
    for (g, ti) in result.gamma.iter().zip(grid.iter()) {
        max_dev = max_dev.max((g - ti).abs());
    }
    assert!(
        max_dev < 0.1,
        "Self-alignment warp should be near identity, max dev = {max_dev}"
    );
}
4223
/// Aligns a planar circle against a copy delayed by 0.1 in the parameter.
/// The result must have a finite distance, two aligned components of grid
/// length, and a warp that deviates from the identity by more than 0.01.
#[test]
fn test_align_nd_shifted_r2() {
    use std::f64::consts::PI;

    let n_points = 60;
    let grid = uniform_grid(n_points);
    let omega = 2.0 * PI;

    // Builds the two-component curve (sin a, cos a) from precomputed angles.
    let circle = |angles: Vec<f64>| -> FdCurveSet {
        let xs: Vec<f64> = angles.iter().map(|&a| a.sin()).collect();
        let ys: Vec<f64> = angles.iter().map(|&a| a.cos()).collect();
        FdCurveSet::from_dims(vec![
            FdMatrix::from_slice(&xs, 1, n_points).unwrap(),
            FdMatrix::from_slice(&ys, 1, n_points).unwrap(),
        ])
        .unwrap()
    };
    let f1 = circle(grid.iter().map(|&ti| omega * ti).collect());
    let f2 = circle(grid.iter().map(|&ti| omega * (ti - 0.1)).collect());

    let result = elastic_align_pair_nd(&f1, &f2, &grid, 0.0);
    assert!(
        result.distance.is_finite(),
        "Distance should be finite, got {}",
        result.distance
    );
    assert_eq!(result.f_aligned.len(), 2);
    assert_eq!(result.f_aligned[0].len(), n_points);

    // Largest pointwise gap between the warp and the identity.
    let mut max_dev = 0.0_f64;
    for (g, ti) in result.gamma.iter().zip(grid.iter()) {
        max_dev = max_dev.max((g - ti).abs());
    }
    assert!(
        max_dev > 0.01,
        "Shifted curves should require non-trivial warp, max dev = {max_dev}"
    );
}
4268
/// With an empty landmark list, the constrained aligner must return the same
/// warp as `elastic_align_pair` (pointwise within 1e-10) and report no
/// enforced landmarks.
#[test]
fn test_constrained_no_landmarks_matches_unconstrained() {
    use std::f64::consts::PI;

    let n_points = 50;
    let grid = uniform_grid(n_points);
    let omega = 2.0 * PI;
    let f1: Vec<f64> = grid.iter().map(|&ti| (omega * ti).sin()).collect();
    let f2: Vec<f64> = grid.iter().map(|&ti| (omega * (ti - 0.1)).sin()).collect();

    let r_free = elastic_align_pair(&f1, &f2, &grid, 0.0);
    let r_const = elastic_align_pair_constrained(&f1, &f2, &grid, &[], 0.0);

    for j in 0..n_points {
        let gap = (r_free.gamma[j] - r_const.gamma[j]).abs();
        assert!(
            gap < 1e-10,
            "No-landmark constrained should match unconstrained at {j}"
        );
    }
    assert!(r_const.enforced_landmarks.is_empty());
}
4296
/// A single landmark pair (0.5, 0.5) must pin the warp: gamma at the grid
/// point nearest 0.5 has to be within 0.05 of 0.5, and exactly one enforced
/// landmark is reported.
#[test]
fn test_constrained_single_landmark_enforced() {
    use std::f64::consts::PI;

    let n_points = 60;
    let grid = uniform_grid(n_points);
    let omega = 2.0 * PI;
    let f1: Vec<f64> = grid.iter().map(|&ti| (omega * ti).sin()).collect();
    let f2: Vec<f64> = grid.iter().map(|&ti| (omega * (ti - 0.1)).sin()).collect();

    let result = elastic_align_pair_constrained(&f1, &f2, &grid, &[(0.5, 0.5)], 0.0);

    let mid_idx = snap_to_grid(0.5, &grid);
    let gamma_mid = result.gamma[mid_idx];
    assert!(
        (gamma_mid - 0.5).abs() < 0.05,
        "Constrained gamma at midpoint should be ~0.5, got {}",
        gamma_mid
    );
    assert_eq!(result.enforced_landmarks.len(), 1);
}
4322
/// Three landmark pairs on the diagonal (0.25, 0.5, 0.75) must each be
/// honoured: gamma at the snapped grid index of every target time has to lie
/// within 0.05 of the requested source time.
#[test]
fn test_constrained_multiple_landmarks() {
    use std::f64::consts::PI;

    let n_points = 80;
    let grid = uniform_grid(n_points);
    let omega = 4.0 * PI;
    let f1: Vec<f64> = grid.iter().map(|&ti| (omega * ti).sin()).collect();
    let f2: Vec<f64> = grid.iter().map(|&ti| (omega * (ti - 0.05)).sin()).collect();

    let landmarks = vec![(0.25, 0.25), (0.5, 0.5), (0.75, 0.75)];
    let result = elastic_align_pair_constrained(&f1, &f2, &grid, &landmarks, 0.0);

    for (tt, st) in landmarks.iter().copied() {
        let idx = snap_to_grid(tt, &grid);
        let warped = result.gamma[idx];
        assert!(
            (warped - st).abs() < 0.05,
            "Gamma at t={tt} should be ~{st}, got {}",
            warped
        );
    }
}
4349
/// The landmark-constrained warp must remain a valid warping function:
/// non-decreasing along the grid (up to 1e-10 slack) and fixed to the grid's
/// first and last values at the endpoints.
#[test]
fn test_constrained_monotone_gamma() {
    use std::f64::consts::PI;

    let m = 60;
    let t = uniform_grid(m);
    let omega = 2.0 * PI;
    let f1: Vec<f64> = t.iter().map(|&ti| (omega * ti).sin()).collect();
    let f2: Vec<f64> = t.iter().map(|&ti| (omega * (ti - 0.1)).sin()).collect();

    let anchors = [(0.3, 0.3), (0.7, 0.7)];
    let result = elastic_align_pair_constrained(&f1, &f2, &t, &anchors, 0.0);

    // Monotonicity: each value may undershoot its predecessor by at most 1e-10.
    for j in 1..m {
        let current = result.gamma[j];
        let previous = result.gamma[j - 1];
        assert!(
            current >= previous - 1e-10,
            "Gamma should be monotone: gamma[{}]={} < gamma[{}]={}",
            j,
            current,
            j - 1,
            previous
        );
    }

    // Boundary conditions: the warp starts and ends on the grid endpoints.
    assert!((result.gamma[0] - t[0]).abs() < 1e-10);
    assert!((result.gamma[m - 1] - t[m - 1]).abs() < 1e-10);
}
4380
/// Adding a landmark constraint must not produce a smaller distance than the
/// unconstrained alignment of the same pair (allowing 1e-6 of slack).
#[test]
fn test_constrained_distance_ge_unconstrained() {
    use std::f64::consts::PI;

    let n_points = 60;
    let grid = uniform_grid(n_points);
    let omega = 2.0 * PI;
    let f1: Vec<f64> = grid.iter().map(|&ti| (omega * ti).sin()).collect();
    let f2: Vec<f64> = grid.iter().map(|&ti| (omega * (ti - 0.15)).sin()).collect();

    let r_free = elastic_align_pair(&f1, &f2, &grid, 0.0);
    let r_const = elastic_align_pair_constrained(&f1, &f2, &grid, &[(0.5, 0.5)], 0.0);

    let lower_bound = r_free.distance - 1e-6;
    assert!(
        r_const.distance >= lower_bound,
        "Constrained distance ({}) should be >= unconstrained ({})",
        r_const.distance,
        r_free.distance
    );
}
4405
/// End-to-end check of `elastic_align_pair_with_landmarks` using
/// `LandmarkKind::Peak`: the output must have grid-length `gamma` and
/// `f_aligned`, a finite distance, and a non-decreasing warp (1e-10 slack).
#[test]
fn test_constrained_with_landmark_detection() {
    use std::f64::consts::PI;

    let n_points = 80;
    let grid = uniform_grid(n_points);
    let omega = 4.0 * PI;
    let f1: Vec<f64> = grid.iter().map(|&ti| (omega * ti).sin()).collect();
    let f2: Vec<f64> = grid.iter().map(|&ti| (omega * (ti - 0.05)).sin()).collect();

    let kind = crate::landmark::LandmarkKind::Peak;
    let result = elastic_align_pair_with_landmarks(&f1, &f2, &grid, kind, 0.1, 0, 0.0);

    assert_eq!(result.gamma.len(), n_points);
    assert_eq!(result.f_aligned.len(), n_points);
    assert!(result.distance.is_finite());
    for j in 1..n_points {
        let floor = result.gamma[j - 1] - 1e-10;
        assert!(
            result.gamma[j] >= floor,
            "Gamma should be monotone at j={j}"
        );
    }
}
4440
/// For the identity warp on a 101-point grid, the smoothed psi must stay
/// within 0.05 of 1.0 and within 0.05 of the raw psi. The first and last
/// `m / 20` grid points are left out of the comparison.
#[test]
fn test_gam_to_psi_smooth_identity() {
    use crate::warping::{gam_to_psi, gam_to_psi_smooth};

    let n_points = 101;
    let step = 1.0 / (n_points - 1) as f64;
    let identity: Vec<f64> = uniform_grid(n_points);
    let psi_raw = gam_to_psi(&identity, step);
    let psi_smooth = gam_to_psi_smooth(&identity, step);

    let skip = n_points / 20;
    for j in skip..(n_points - skip) {
        let smoothed = psi_smooth[j];
        assert!(
            (smoothed - 1.0).abs() < 0.05,
            "Smoothed psi of identity should be ~1.0, got {} at j={}",
            smoothed,
            j
        );
        assert!(
            (smoothed - psi_raw[j]).abs() < 0.05,
            "Smoothed and raw psi should agree on smooth warp at j={}",
            j
        );
    }
}
4469
/// Builds a piecewise-linear warp whose slope changes at t = 0.33 and
/// t = 0.67 (clamped to 1.0 and rescaled by its last value), then checks that
/// the largest jump between consecutive psi samples after smoothing is below
/// the raw one plus 0.01.
#[test]
fn test_gam_to_psi_smooth_reduces_spikes() {
    use crate::warping::{gam_to_psi, gam_to_psi_smooth};

    /// Largest absolute difference between neighbouring samples.
    fn largest_jump(psi: &[f64]) -> f64 {
        psi.windows(2)
            .map(|w| (w[1] - w[0]).abs())
            .fold(0.0_f64, f64::max)
    }

    let m = 101;
    let h = 1.0 / (m - 1) as f64;
    let argvals = uniform_grid(m);

    // Three linear pieces; every value is capped at 1.0.
    let mut gam: Vec<f64> = argvals[..m]
        .iter()
        .map(|&t| {
            let g = match t {
                x if x < 0.33 => x * 0.5 / 0.33,
                x if x < 0.67 => 0.5 + (x - 0.33) * 0.5 / 0.34 * 2.0,
                x => {
                    let base = 0.5 + 0.5 / 0.34 * 2.0 * 0.34;
                    (base + (x - 0.67) * 0.5 / 0.33).min(1.0)
                }
            };
            g.min(1.0)
        })
        .collect();

    // Rescale so the final value is 1; the floor guards the division.
    let gmax = gam[m - 1].max(1e-10);
    gam.iter_mut().for_each(|g| *g /= gmax);

    let psi_raw = gam_to_psi(&gam, h);
    let psi_smooth = gam_to_psi_smooth(&gam, h);
    let max_jump_raw: f64 = largest_jump(&psi_raw);
    let max_jump_smooth: f64 = largest_jump(&psi_smooth);
    assert!(
        max_jump_smooth < max_jump_raw + 0.01,
        "Smoothing should not increase max psi jump: raw={max_jump_raw:.4}, smooth={max_jump_smooth:.4}"
    );
}
4515
/// Smooths a sampled sine wave with `nadaraya_watson` ("gaussian" kernel,
/// bandwidth of two grid steps) and requires the correlation coefficient
/// between the original and smoothed samples to exceed 0.99.
#[test]
fn test_smooth_aligned_srsfs_preserves_shape() {
    use crate::smoothing::nadaraya_watson;
    use std::f64::consts::PI;

    let m = 101;
    let time: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
    let qi: Vec<f64> = time.iter().map(|&t| (2.0 * PI * t).sin()).collect();

    let bandwidth = 2.0 / (m - 1) as f64;
    let qi_smooth = nadaraya_watson(&time, &qi, &time, bandwidth, "gaussian");

    let mean_orig: f64 = qi.iter().sum::<f64>() / m as f64;
    let mean_smooth: f64 = qi_smooth.iter().sum::<f64>() / m as f64;

    // Accumulate covariance and both variances in one pass, in index order.
    let (cov, var_o, var_s) = (0..m).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(cov, var_o, var_s), j| {
            let d_orig = qi[j] - mean_orig;
            let d_smooth = qi_smooth[j] - mean_smooth;
            (
                cov + d_orig * d_smooth,
                var_o + d_orig * d_orig,
                var_s + d_smooth * d_smooth,
            )
        },
    );

    // The floor on the denominator avoids dividing by zero.
    let rho = cov / (var_o * var_s).sqrt().max(1e-10);
    assert!(
        rho > 0.99,
        "Smoothed SRSF should be highly correlated with original (rho={rho:.4})"
    );
}
4548
/// Runs `tsrvf_transform` on 10 synthetic curves and checks that no tangent
/// vector has an entry whose magnitude reaches 10 times that vector's RMS.
/// Vectors whose RMS is not above 1e-10 are skipped.
#[test]
fn test_tsrvf_tangent_vectors_no_spikes() {
    let m = 101;
    let grid = uniform_grid(m);
    let sample = make_test_data(10, m, 42);
    let result = tsrvf_transform(&sample, &grid, 5, 1e-3, 0.0);

    let n = result.tangent_vectors.shape().0;
    for i in 0..n {
        let vi = result.tangent_vectors.row(i);
        let sum_sq = vi.iter().map(|&v| v * v).sum::<f64>();
        let rms = (sum_sq / m as f64).sqrt();
        if rms > 1e-10 {
            let mut max_abs = 0.0_f64;
            for &v in vi.iter() {
                max_abs = max_abs.max(v.abs());
            }
            assert!(
                max_abs < 10.0 * rms,
                "Tangent vector {} has spike: max |v| = {max_abs:.4}, rms = {rms:.4}, ratio = {:.1}",
                i,
                max_abs / rms
            );
        }
    }
}
4573}