Skip to main content

fdars_core/
alignment.rs

1//! Elastic alignment and SRSF (Square-Root Slope Function) transforms.
2//!
3//! This module provides phase-amplitude separation for functional data via
4//! the elastic framework. Key capabilities:
5//!
6//! - [`srsf_transform`] / [`srsf_inverse`] — SRSF representation and reconstruction
7//! - [`elastic_align_pair`] — Pairwise curve alignment via dynamic programming
8//! - [`elastic_distance`] — Elastic (Fisher-Rao) distance between curves
9//! - [`align_to_target`] — Align a set of curves to a common target
10//! - [`karcher_mean`] — Karcher (Fréchet) mean in the elastic metric
11//! - [`elastic_self_distance_matrix`] / [`elastic_cross_distance_matrix`] — Distance matrices
12//! - [`reparameterize_curve`] / [`compose_warps`] — Warping utilities
13
14use crate::fdata::{deriv_1d, mean_1d};
15use crate::helpers::{
16    cumulative_trapz, gradient_uniform, l2_distance, linear_interp, simpsons_weights,
17};
18use crate::iter_maybe_parallel;
19use crate::matrix::FdMatrix;
20use crate::smoothing::nadaraya_watson;
21use crate::warping::{
22    exp_map_sphere, gam_to_psi, inv_exp_map_sphere, invert_gamma, l2_norm_l2, normalize_warp,
23    psi_to_gam,
24};
25#[cfg(feature = "parallel")]
26use rayon::iter::ParallelIterator;
27
28// ─── Types ──────────────────────────────────────────────────────────────────
29
/// Outcome of elastically aligning one curve onto another.
///
/// Returned by [`elastic_align_pair`].
#[derive(Debug, Clone)]
pub struct AlignmentResult {
    /// Optimal warp γ sampled on the evaluation grid; maps the domain onto itself.
    pub gamma: Vec<f64>,
    /// The curve being aligned, reparameterized by γ (i.e. `f2(γ(t))`).
    pub f_aligned: Vec<f64>,
    /// Elastic (SRSF-space) distance that remains after alignment.
    pub distance: f64,
}
40
41/// Result of aligning a set of curves to a common target.
42#[derive(Debug, Clone)]
43pub struct AlignmentSetResult {
44    /// Warping functions (n × m).
45    pub gammas: FdMatrix,
46    /// Aligned curves (n × m).
47    pub aligned_data: FdMatrix,
48    /// Elastic distances for each curve.
49    pub distances: Vec<f64>,
50}
51
52/// Result of the Karcher mean computation.
53#[derive(Debug, Clone)]
54pub struct KarcherMeanResult {
55    /// Karcher mean curve.
56    pub mean: Vec<f64>,
57    /// SRSF of the Karcher mean.
58    pub mean_srsf: Vec<f64>,
59    /// Final warping functions (n × m).
60    pub gammas: FdMatrix,
61    /// Curves aligned to the mean (n × m).
62    pub aligned_data: FdMatrix,
63    /// Number of iterations used.
64    pub n_iter: usize,
65    /// Whether the algorithm converged.
66    pub converged: bool,
67}
68
69// Private helpers are now in crate::helpers and crate::warping.
70
71/// Karcher mean of warping functions on the Hilbert sphere, then invert.
72/// Port of fdasrvf's `SqrtMeanInverse`.
73///
74/// Takes a matrix of warping functions (n × m) on the argvals domain,
75/// computes the Fréchet mean of their sqrt-derivative representations
76/// on the unit Hilbert sphere, converts back to a warping function,
77/// and returns its inverse (on the argvals domain).
78/// One Karcher iteration on the Hilbert sphere: compute mean shooting vector and update mu.
79///
80/// Returns `true` if converged (vbar norm ≤ threshold).
81fn karcher_sphere_step(mu: &mut Vec<f64>, psis: &[Vec<f64>], time: &[f64], step_size: f64) -> bool {
82    let m = mu.len();
83    let n = psis.len();
84    let mut vbar = vec![0.0; m];
85    for psi in psis {
86        let v = inv_exp_map_sphere(mu, psi, time);
87        for j in 0..m {
88            vbar[j] += v[j];
89        }
90    }
91    for j in 0..m {
92        vbar[j] /= n as f64;
93    }
94    if l2_norm_l2(&vbar, time) <= 1e-8 {
95        return true;
96    }
97    let scaled: Vec<f64> = vbar.iter().map(|&v| v * step_size).collect();
98    *mu = exp_map_sphere(mu, &scaled, time);
99    false
100}
101
102fn sqrt_mean_inverse(gammas: &FdMatrix, argvals: &[f64]) -> Vec<f64> {
103    let (n, m) = gammas.shape();
104    let t0 = argvals[0];
105    let t1 = argvals[m - 1];
106    let domain = t1 - t0;
107
108    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
109    let binsize = 1.0 / (m - 1) as f64;
110
111    let psis: Vec<Vec<f64>> = (0..n)
112        .map(|i| {
113            let gam_01: Vec<f64> = (0..m).map(|j| (gammas[(i, j)] - t0) / domain).collect();
114            gam_to_psi(&gam_01, binsize)
115        })
116        .collect();
117
118    let mut mu = vec![0.0; m];
119    for psi in &psis {
120        for j in 0..m {
121            mu[j] += psi[j];
122        }
123    }
124    for j in 0..m {
125        mu[j] /= n as f64;
126    }
127
128    for _ in 0..501 {
129        if karcher_sphere_step(&mut mu, &psis, &time, 0.3) {
130            break;
131        }
132    }
133
134    let gam_mu = psi_to_gam(&mu, &time);
135    let gam_inv = invert_gamma(&gam_mu, &time);
136    gam_inv.iter().map(|&g| t0 + g * domain).collect()
137}
138
139// ─── SRSF Transform and Inverse ─────────────────────────────────────────────
140
141/// Compute the Square-Root Slope Function (SRSF) transform.
142///
143/// For each curve f, the SRSF is: `q(t) = sign(f'(t)) * sqrt(|f'(t)|)`
144///
145/// # Arguments
146/// * `data` — Functional data matrix (n × m)
147/// * `argvals` — Evaluation points (length m)
148///
149/// # Returns
150/// FdMatrix of SRSFs with the same shape as input.
151pub fn srsf_transform(data: &FdMatrix, argvals: &[f64]) -> FdMatrix {
152    let (n, m) = data.shape();
153    if n == 0 || m == 0 || argvals.len() != m {
154        return FdMatrix::zeros(n, m);
155    }
156
157    let deriv = deriv_1d(data, argvals, 1);
158
159    let mut result = FdMatrix::zeros(n, m);
160    for i in 0..n {
161        for j in 0..m {
162            let d = deriv[(i, j)];
163            result[(i, j)] = d.signum() * d.abs().sqrt();
164        }
165    }
166    result
167}
168
169/// Reconstruct a curve from its SRSF representation.
170///
171/// Given SRSF q and initial value f0, reconstructs: `f(t) = f0 + ∫₀ᵗ q(s)|q(s)| ds`
172///
173/// # Arguments
174/// * `q` — SRSF values (length m)
175/// * `argvals` — Evaluation points (length m)
176/// * `f0` — Initial value f(argvals\[0\])
177///
178/// # Returns
179/// Reconstructed curve values.
180pub fn srsf_inverse(q: &[f64], argvals: &[f64], f0: f64) -> Vec<f64> {
181    let m = q.len();
182    if m == 0 {
183        return Vec::new();
184    }
185
186    // Integrand: q(s) * |q(s)|
187    let integrand: Vec<f64> = q.iter().map(|&qi| qi * qi.abs()).collect();
188    let integral = cumulative_trapz(&integrand, argvals);
189
190    integral.iter().map(|&v| f0 + v).collect()
191}
192
193// ─── Reparameterization ─────────────────────────────────────────────────────
194
195/// Reparameterize a curve by a warping function.
196///
197/// Computes `f(γ(t))` via linear interpolation.
198///
199/// # Arguments
200/// * `f` — Curve values (length m)
201/// * `argvals` — Evaluation points (length m)
202/// * `gamma` — Warping function values (length m)
203pub fn reparameterize_curve(f: &[f64], argvals: &[f64], gamma: &[f64]) -> Vec<f64> {
204    gamma
205        .iter()
206        .map(|&g| linear_interp(argvals, f, g))
207        .collect()
208}
209
210/// Compose two warping functions: `(γ₁ ∘ γ₂)(t) = γ₁(γ₂(t))`.
211///
212/// # Arguments
213/// * `gamma1` — Outer warping function (length m)
214/// * `gamma2` — Inner warping function (length m)
215/// * `argvals` — Evaluation points (length m)
216pub fn compose_warps(gamma1: &[f64], gamma2: &[f64], argvals: &[f64]) -> Vec<f64> {
217    gamma2
218        .iter()
219        .map(|&g| linear_interp(argvals, gamma1, g))
220        .collect()
221}
222
223// ─── Dynamic Programming Alignment ──────────────────────────────────────────
224// Faithful port of fdasrvf's DP algorithm (dp_grid.cpp / dp_nbhd.cpp).
225
/// Greatest common divisor, computed with the iterative Euclidean algorithm.
#[cfg(test)]
fn gcd(mut a: usize, mut b: usize) -> usize {
    while b != 0 {
        let remainder = a % b;
        a = b;
        b = remainder;
    }
    a
}

/// Generate the coprime neighborhood: every (i, j) with 1 ≤ i, j ≤ nbhd_dim and
/// gcd(i, j) = 1, in row-major order (i outer, j inner).
/// With nbhd_dim = 7 this produces 35 pairs, matching fdasrvf's default.
#[cfg(test)]
fn generate_coprime_nbhd(nbhd_dim: usize) -> Vec<(usize, usize)> {
    (1..=nbhd_dim)
        .flat_map(|i| (1..=nbhd_dim).map(move |j| (i, j)))
        .filter(|&(i, j)| gcd(i, j) == 1)
        .collect()
}
250
/// Coprime move set for nbhd_dim = 7 (fdasrvf's default), computed ahead of time.
///
/// Holds every (dr, dc) with 1 ≤ dr, dc ≤ 7 and gcd(dr, dc) = 1, in row-major
/// order. dr is the row step (q2 direction) and dc the column step (q1
/// direction). The order matters: the DP keeps the first of equally cheap moves.
#[rustfmt::skip]
const COPRIME_NBHD_7: [(usize, usize); 35] = [
    /* dr = 1 */ (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7),
    /* dr = 2 */ (2, 1), (2, 3), (2, 5), (2, 7),
    /* dr = 3 */ (3, 1), (3, 2), (3, 4), (3, 5), (3, 7),
    /* dr = 4 */ (4, 1), (4, 3), (4, 5), (4, 7),
    /* dr = 5 */ (5, 1), (5, 2), (5, 3), (5, 4), (5, 6), (5, 7),
    /* dr = 6 */ (6, 1), (6, 5), (6, 7),
    /* dr = 7 */ (7, 1), (7, 2), (7, 3), (7, 4), (7, 5), (7, 6),
];
264
/// Edge weight of a DP move from grid point (sr, sc) to (tr, tc).
///
/// Port of fdasrvf's `dp_edge_weight` for 1-D curves on a shared uniform grid.
/// - Rows index q2 and columns index q1 (fdasrvf convention).
/// - The segment has slope `γ' = (argvals[tr] - argvals[sr]) / (argvals[tc] - argvals[sc])`.
/// - The segment is split at the sample breakpoints of *both* curves. Each
///   piece contributes `(q1[idx1] - √γ' · q2[idx2])² · dt`, where `dt` is the
///   piece's length as a fraction of the segment.
/// - The sum is scaled by the q1-direction span.
///
/// A move with zero extent in either direction returns `f64::INFINITY`.
#[inline]
fn dp_edge_weight(
    q1: &[f64],
    q2: &[f64],
    argvals: &[f64],
    sc: usize,
    tc: usize,
    sr: usize,
    tr: usize,
) -> f64 {
    let steps_q1 = tc - sc;
    let steps_q2 = tr - sr;
    if steps_q1 == 0 || steps_q2 == 0 {
        return f64::INFINITY;
    }

    let span_q1 = argvals[tc] - argvals[sc];
    let root_slope = ((argvals[tr] - argvals[sr]) / span_q1).sqrt();
    let (n1, n2) = (steps_q1 as f64, steps_q2 as f64);

    // `a` / `b` index the current sample interval along q1 / q2.
    let mut total = 0.0;
    let (mut a, mut b) = (0usize, 0usize);
    while a < steps_q1 && b < steps_q2 {
        let end_a = (a + 1) as f64 / n1;
        let end_b = (b + 1) as f64 / n2;
        let start = (a as f64 / n1).max(b as f64 / n2);
        let overlap = end_a.min(end_b) - start;

        if overlap > 0.0 {
            let resid = q1[sc + a] - root_slope * q2[sr + b];
            total += resid * resid * overlap;
        }

        // Step past whichever interval closes first (both when they close together).
        let advance_a = end_a <= end_b;
        let advance_b = end_b <= end_a;
        if advance_a {
            a += 1;
        }
        if advance_b {
            b += 1;
        }
    }

    total * span_q1
}
326
/// Roughness penalty `λ·(slope − 1)²·dt` for a DP move from (sr, sc) to (tr, tc).
///
/// `dt` is the q1-direction span and `slope` the move's γ'. Any `lambda` that is
/// not strictly positive (including NaN) disables the penalty and yields 0.
#[inline]
fn dp_lambda_penalty(
    argvals: &[f64],
    sc: usize,
    tc: usize,
    sr: usize,
    tr: usize,
    lambda: f64,
) -> f64 {
    if lambda > 0.0 {
        let span = argvals[tc] - argvals[sc];
        let deviation = (argvals[tr] - argvals[sr]) / span - 1.0;
        return lambda * deviation.powi(2) * span;
    }
    0.0
}
345
/// Follow parent pointers from the bottom-right cell back towards the origin.
///
/// `parent` is a row-major `nrows × ncols` array of flat cell indices. `u32::MAX`
/// marks "no parent". The walk stops at cell 0 or at a cell without a parent.
/// The visited cells are returned as `(row, col)` pairs in forward order, so a
/// fully connected path runs from `(0,0)` to `(nrows-1, ncols-1)`.
fn dp_traceback(parent: &[u32], nrows: usize, ncols: usize) -> Vec<(usize, usize)> {
    let mut cell = (nrows - 1) * ncols + (ncols - 1);
    let mut backwards = Vec::with_capacity(nrows + ncols);
    backwards.push((cell / ncols, cell % ncols));
    while cell != 0 && parent[cell] != u32::MAX {
        cell = parent[cell] as usize;
        backwards.push((cell / ncols, cell % ncols));
    }
    backwards.reverse();
    backwards
}
362
363/// Try to relax cell `(tr, tc)` from each coprime neighbor, updating cost and parent.
364#[inline]
365fn dp_relax_cell<F>(
366    e: &mut [f64],
367    parent: &mut [u32],
368    ncols: usize,
369    tr: usize,
370    tc: usize,
371    edge_cost: &F,
372) where
373    F: Fn(usize, usize, usize, usize) -> f64,
374{
375    let idx = tr * ncols + tc;
376    for &(dr, dc) in &COPRIME_NBHD_7 {
377        if dr > tr || dc > tc {
378            continue;
379        }
380        let sr = tr - dr;
381        let sc = tc - dc;
382        let src_idx = sr * ncols + sc;
383        if e[src_idx] == f64::INFINITY {
384            continue;
385        }
386        let cost = e[src_idx] + edge_cost(sr, sc, tr, tc);
387        if cost < e[idx] {
388            e[idx] = cost;
389            parent[idx] = src_idx as u32;
390        }
391    }
392}
393
394/// Shared DP grid fill + traceback using the coprime neighborhood.
395///
396/// `edge_cost(sr, sc, tr, tc)` returns the combined edge weight + penalty for
397/// a move from local (sr, sc) to local (tr, tc). Returns the raw local-index
398/// path from (0,0) to (nrows-1, ncols-1).
399fn dp_grid_solve<F>(nrows: usize, ncols: usize, edge_cost: F) -> Vec<(usize, usize)>
400where
401    F: Fn(usize, usize, usize, usize) -> f64,
402{
403    let mut e = vec![f64::INFINITY; nrows * ncols];
404    let mut parent = vec![u32::MAX; nrows * ncols];
405    e[0] = 0.0;
406
407    for tr in 0..nrows {
408        for tc in 0..ncols {
409            if tr == 0 && tc == 0 {
410                continue;
411            }
412            dp_relax_cell(&mut e, &mut parent, ncols, tr, tc, &edge_cost);
413        }
414    }
415
416    dp_traceback(&parent, nrows, ncols)
417}
418
419/// Convert a DP path (local row,col indices) to an interpolated+normalized gamma warp.
420fn dp_path_to_gamma(path: &[(usize, usize)], argvals: &[f64]) -> Vec<f64> {
421    let path_tc: Vec<f64> = path.iter().map(|&(_, c)| argvals[c]).collect();
422    let path_tr: Vec<f64> = path.iter().map(|&(r, _)| argvals[r]).collect();
423    let mut gamma: Vec<f64> = argvals
424        .iter()
425        .map(|&t| linear_interp(&path_tc, &path_tr, t))
426        .collect();
427    normalize_warp(&mut gamma, argvals);
428    gamma
429}
430
431/// Core DP alignment between two SRSFs on a grid.
432///
433/// Finds the optimal warping γ minimizing ‖q₁ - (q₂∘γ)√γ'‖².
434/// Uses fdasrvf's coprime neighborhood (nbhd_dim=7 → 35 move directions).
435/// SRSFs are L2-normalized before alignment (matching fdasrvf's `optimum.reparam`).
436fn dp_alignment_core(q1: &[f64], q2: &[f64], argvals: &[f64], lambda: f64) -> Vec<f64> {
437    let m = argvals.len();
438    if m < 2 {
439        return argvals.to_vec();
440    }
441
442    let norm1 = q1.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
443    let norm2 = q2.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
444    let q1n: Vec<f64> = q1.iter().map(|&v| v / norm1).collect();
445    let q2n: Vec<f64> = q2.iter().map(|&v| v / norm2).collect();
446
447    let path = dp_grid_solve(m, m, |sr, sc, tr, tc| {
448        dp_edge_weight(&q1n, &q2n, argvals, sc, tc, sr, tr)
449            + dp_lambda_penalty(argvals, sc, tc, sr, tr, lambda)
450    });
451
452    dp_path_to_gamma(&path, argvals)
453}
454
455// ─── Public Alignment Functions ─────────────────────────────────────────────
456
457/// Align curve f2 to curve f1 using the elastic framework.
458///
459/// Computes the optimal warping γ such that f2∘γ is as close as possible
460/// to f1 in the elastic (Fisher-Rao) metric.
461///
462/// # Arguments
463/// * `f1` — Target curve (length m)
464/// * `f2` — Curve to align (length m)
465/// * `argvals` — Evaluation points (length m)
466/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
467///
468/// # Returns
469/// [`AlignmentResult`] with warping function, aligned curve, and elastic distance.
470pub fn elastic_align_pair(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> AlignmentResult {
471    let m = f1.len();
472
473    // Build single-row FdMatrices for SRSF computation
474    let f1_mat = FdMatrix::from_slice(f1, 1, m).unwrap();
475    let f2_mat = FdMatrix::from_slice(f2, 1, m).unwrap();
476
477    let q1_mat = srsf_transform(&f1_mat, argvals);
478    let q2_mat = srsf_transform(&f2_mat, argvals);
479
480    let q1: Vec<f64> = q1_mat.row(0);
481    let q2: Vec<f64> = q2_mat.row(0);
482
483    // Find optimal warping via DP
484    let gamma = dp_alignment_core(&q1, &q2, argvals, lambda);
485
486    // Apply warping to f2
487    let f_aligned = reparameterize_curve(f2, argvals, &gamma);
488
489    // Compute elastic distance: L2 distance between q1 and aligned q2 SRSF
490    let f_aligned_mat = FdMatrix::from_slice(&f_aligned, 1, m).unwrap();
491    let q_aligned_mat = srsf_transform(&f_aligned_mat, argvals);
492    let q_aligned: Vec<f64> = q_aligned_mat.row(0);
493
494    let weights = simpsons_weights(argvals);
495    let distance = l2_distance(&q1, &q_aligned, &weights);
496
497    AlignmentResult {
498        gamma,
499        f_aligned,
500        distance,
501    }
502}
503
504/// Compute the elastic distance between two curves.
505///
506/// This is shorthand for aligning the pair and returning only the distance.
507///
508/// # Arguments
509/// * `f1` — First curve (length m)
510/// * `f2` — Second curve (length m)
511/// * `argvals` — Evaluation points (length m)
512/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
513pub fn elastic_distance(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> f64 {
514    elastic_align_pair(f1, f2, argvals, lambda).distance
515}
516
517/// Align all curves in `data` to a single target curve.
518///
519/// # Arguments
520/// * `data` — Functional data matrix (n × m)
521/// * `target` — Target curve to align to (length m)
522/// * `argvals` — Evaluation points (length m)
523/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
524///
525/// # Returns
526/// [`AlignmentSetResult`] with all warping functions, aligned curves, and distances.
527pub fn align_to_target(
528    data: &FdMatrix,
529    target: &[f64],
530    argvals: &[f64],
531    lambda: f64,
532) -> AlignmentSetResult {
533    let (n, m) = data.shape();
534
535    let results: Vec<AlignmentResult> = iter_maybe_parallel!(0..n)
536        .map(|i| {
537            let fi = data.row(i);
538            elastic_align_pair(target, &fi, argvals, lambda)
539        })
540        .collect();
541
542    let mut gammas = FdMatrix::zeros(n, m);
543    let mut aligned_data = FdMatrix::zeros(n, m);
544    let mut distances = Vec::with_capacity(n);
545
546    for (i, r) in results.into_iter().enumerate() {
547        for j in 0..m {
548            gammas[(i, j)] = r.gamma[j];
549            aligned_data[(i, j)] = r.f_aligned[j];
550        }
551        distances.push(r.distance);
552    }
553
554    AlignmentSetResult {
555        gammas,
556        aligned_data,
557        distances,
558    }
559}
560
561// ─── Distance Matrices ──────────────────────────────────────────────────────
562
563/// Compute the symmetric elastic distance matrix for a set of curves.
564///
565/// Uses upper-triangle computation with parallelism, following the
566/// `self_distance_matrix` pattern from `metric.rs`.
567///
568/// # Arguments
569/// * `data` — Functional data matrix (n × m)
570/// * `argvals` — Evaluation points (length m)
571/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
572///
573/// # Returns
574/// Symmetric n × n distance matrix.
575pub fn elastic_self_distance_matrix(data: &FdMatrix, argvals: &[f64], lambda: f64) -> FdMatrix {
576    let n = data.nrows();
577
578    let upper_vals: Vec<f64> = iter_maybe_parallel!(0..n)
579        .flat_map(|i| {
580            let fi = data.row(i);
581            ((i + 1)..n)
582                .map(|j| {
583                    let fj = data.row(j);
584                    elastic_distance(&fi, &fj, argvals, lambda)
585                })
586                .collect::<Vec<_>>()
587        })
588        .collect();
589
590    let mut dist = FdMatrix::zeros(n, n);
591    let mut idx = 0;
592    for i in 0..n {
593        for j in (i + 1)..n {
594            let d = upper_vals[idx];
595            dist[(i, j)] = d;
596            dist[(j, i)] = d;
597            idx += 1;
598        }
599    }
600    dist
601}
602
603/// Compute the elastic distance matrix between two sets of curves.
604///
605/// # Arguments
606/// * `data1` — First dataset (n1 × m)
607/// * `data2` — Second dataset (n2 × m)
608/// * `argvals` — Evaluation points (length m)
609/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
610///
611/// # Returns
612/// n1 × n2 distance matrix.
613pub fn elastic_cross_distance_matrix(
614    data1: &FdMatrix,
615    data2: &FdMatrix,
616    argvals: &[f64],
617    lambda: f64,
618) -> FdMatrix {
619    let n1 = data1.nrows();
620    let n2 = data2.nrows();
621
622    let vals: Vec<f64> = iter_maybe_parallel!(0..n1)
623        .flat_map(|i| {
624            let fi = data1.row(i);
625            (0..n2)
626                .map(|j| {
627                    let fj = data2.row(j);
628                    elastic_distance(&fi, &fj, argvals, lambda)
629                })
630                .collect::<Vec<_>>()
631        })
632        .collect();
633
634    let mut dist = FdMatrix::zeros(n1, n2);
635    for i in 0..n1 {
636        for j in 0..n2 {
637            dist[(i, j)] = vals[i * n2 + j];
638        }
639    }
640    dist
641}
642
643// ─── Phase-Amplitude Decomposition ──────────────────────────────────────────
644
645/// Result of elastic phase-amplitude decomposition.
646#[derive(Debug, Clone)]
647pub struct DecompositionResult {
648    /// Full alignment result.
649    pub alignment: AlignmentResult,
650    /// Amplitude distance: SRSF distance after alignment.
651    pub d_amplitude: f64,
652    /// Phase distance: geodesic distance of warp from identity.
653    pub d_phase: f64,
654}
655
656/// Perform elastic phase-amplitude decomposition of two curves.
657///
658/// Returns both the alignment result and the separate amplitude and phase distances.
659///
660/// # Arguments
661/// * `f1` — Target curve (length m)
662/// * `f2` — Curve to decompose against f1 (length m)
663/// * `argvals` — Evaluation points (length m)
664/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
665pub fn elastic_decomposition(
666    f1: &[f64],
667    f2: &[f64],
668    argvals: &[f64],
669    lambda: f64,
670) -> DecompositionResult {
671    let alignment = elastic_align_pair(f1, f2, argvals, lambda);
672    let d_amplitude = alignment.distance;
673    let d_phase = crate::warping::phase_distance(&alignment.gamma, argvals);
674    DecompositionResult {
675        alignment,
676        d_amplitude,
677        d_phase,
678    }
679}
680
681/// Compute the amplitude distance between two curves (= elastic distance after alignment).
682pub fn amplitude_distance(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> f64 {
683    elastic_distance(f1, f2, argvals, lambda)
684}
685
686/// Compute the phase distance between two curves (geodesic distance of optimal warp from identity).
687pub fn phase_distance_pair(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> f64 {
688    let alignment = elastic_align_pair(f1, f2, argvals, lambda);
689    crate::warping::phase_distance(&alignment.gamma, argvals)
690}
691
692/// Compute the symmetric phase distance matrix for a set of curves.
693pub fn phase_self_distance_matrix(data: &FdMatrix, argvals: &[f64], lambda: f64) -> FdMatrix {
694    let n = data.nrows();
695
696    let upper_vals: Vec<f64> = iter_maybe_parallel!(0..n)
697        .flat_map(|i| {
698            let fi = data.row(i);
699            ((i + 1)..n)
700                .map(|j| {
701                    let fj = data.row(j);
702                    phase_distance_pair(&fi, &fj, argvals, lambda)
703                })
704                .collect::<Vec<_>>()
705        })
706        .collect();
707
708    let mut dist = FdMatrix::zeros(n, n);
709    let mut idx = 0;
710    for i in 0..n {
711        for j in (i + 1)..n {
712            let d = upper_vals[idx];
713            dist[(i, j)] = d;
714            dist[(j, i)] = d;
715            idx += 1;
716        }
717    }
718    dist
719}
720
721/// Compute the symmetric amplitude distance matrix (= elastic self distance matrix).
722pub fn amplitude_self_distance_matrix(data: &FdMatrix, argvals: &[f64], lambda: f64) -> FdMatrix {
723    elastic_self_distance_matrix(data, argvals, lambda)
724}
725
726// ─── Karcher Mean ───────────────────────────────────────────────────────────
727
/// Relative change between successive mean SRSFs.
///
/// Computes `‖q_new - q_old‖₂ / max(‖q_old‖₂, 1e-10)` with the plain (unweighted)
/// discrete L2 norm, as in R fdasrvf's `time_warping` convergence check.
/// Slices of different lengths are compared over their common prefix.
fn relative_change(q_old: &[f64], q_new: &[f64]) -> f64 {
    let diff_sq: f64 = q_old
        .iter()
        .zip(q_new)
        .map(|(a, b)| (a - b).powi(2))
        .sum();
    let old_sq: f64 = q_old.iter().map(|v| v * v).sum();
    diff_sq.sqrt() / old_sq.sqrt().max(1e-10)
}
742
743/// Compute a single SRSF from a slice (single-row convenience).
744fn srsf_single(f: &[f64], argvals: &[f64]) -> Vec<f64> {
745    let m = f.len();
746    let mat = FdMatrix::from_slice(f, 1, m).unwrap();
747    let q_mat = srsf_transform(&mat, argvals);
748    q_mat.row(0)
749}
750
751/// Align a single SRSF q2 to q1 and return (gamma, aligned_q).
752fn align_srsf_pair(q1: &[f64], q2: &[f64], argvals: &[f64], lambda: f64) -> (Vec<f64>, Vec<f64>) {
753    let gamma = dp_alignment_core(q1, q2, argvals, lambda);
754
755    // Warp q2 by gamma and adjust by sqrt(gamma')
756    let q2_warped = reparameterize_curve(q2, argvals, &gamma);
757
758    // Compute gamma' via finite differences
759    let m = gamma.len();
760    let mut gamma_dot = vec![0.0; m];
761    gamma_dot[0] = (gamma[1] - gamma[0]) / (argvals[1] - argvals[0]);
762    for j in 1..(m - 1) {
763        gamma_dot[j] = (gamma[j + 1] - gamma[j - 1]) / (argvals[j + 1] - argvals[j - 1]);
764    }
765    gamma_dot[m - 1] = (gamma[m - 1] - gamma[m - 2]) / (argvals[m - 1] - argvals[m - 2]);
766
767    // q2_aligned = (q2 ∘ γ) * sqrt(γ')
768    let q2_aligned: Vec<f64> = q2_warped
769        .iter()
770        .zip(gamma_dot.iter())
771        .map(|(&q, &gd)| q * gd.max(0.0).sqrt())
772        .collect();
773
774    (gamma, q2_aligned)
775}
776
777/// Compute the Karcher (Fréchet) mean in the elastic metric.
778///
779/// Iteratively aligns all curves to the current mean estimate in SRSF space,
780/// computes the pointwise mean of aligned SRSFs, and reconstructs the mean curve.
781///
782/// # Arguments
783/// * `data` — Functional data matrix (n × m)
784/// * `argvals` — Evaluation points (length m)
785/// * `max_iter` — Maximum number of iterations
786/// * `tol` — Convergence tolerance for the SRSF mean
787///
788/// # Returns
789/// [`KarcherMeanResult`] with mean curve, warping functions, aligned data, and convergence info.
790///
791/// # Examples
792///
793/// ```
794/// use fdars_core::simulation::{sim_fundata, EFunType, EValType};
795/// use fdars_core::alignment::karcher_mean;
796///
797/// let t: Vec<f64> = (0..50).map(|i| i as f64 / 49.0).collect();
798/// let data = sim_fundata(20, &t, 3, EFunType::Fourier, EValType::Exponential, Some(42));
799///
800/// let result = karcher_mean(&data, &t, 20, 1e-4, 0.0);
801/// assert_eq!(result.mean.len(), 50);
802/// assert!(result.n_iter <= 20);
803/// ```
804/// Accumulate alignment results: store gammas and return the mean of aligned SRSFs.
805fn accumulate_alignments(
806    results: &[(Vec<f64>, Vec<f64>)],
807    gammas: &mut FdMatrix,
808    m: usize,
809    n: usize,
810) -> Vec<f64> {
811    let mut mu_q_new = vec![0.0; m];
812    for (i, (gamma, q_aligned)) in results.iter().enumerate() {
813        for j in 0..m {
814            gammas[(i, j)] = gamma[j];
815            mu_q_new[j] += q_aligned[j];
816        }
817    }
818    for j in 0..m {
819        mu_q_new[j] /= n as f64;
820    }
821    mu_q_new
822}
823
824/// Apply stored warps to original curves to produce aligned data.
825fn apply_stored_warps(data: &FdMatrix, gammas: &FdMatrix, argvals: &[f64]) -> FdMatrix {
826    let (n, m) = data.shape();
827    let mut aligned = FdMatrix::zeros(n, m);
828    for i in 0..n {
829        let fi = data.row(i);
830        let gamma: Vec<f64> = (0..m).map(|j| gammas[(i, j)]).collect();
831        let f_aligned = reparameterize_curve(&fi, argvals, &gamma);
832        for j in 0..m {
833            aligned[(i, j)] = f_aligned[j];
834        }
835    }
836    aligned
837}
838
839/// Select the SRSF closest to the pointwise mean as template. Returns (mu_q, mu_f).
840fn select_template(srsf_mat: &FdMatrix, data: &FdMatrix, argvals: &[f64]) -> (Vec<f64>, Vec<f64>) {
841    let (n, m) = srsf_mat.shape();
842    let mnq = mean_1d(srsf_mat);
843    let mut min_dist = f64::INFINITY;
844    let mut min_idx = 0;
845    for i in 0..n {
846        let dist_sq: f64 = (0..m).map(|j| (srsf_mat[(i, j)] - mnq[j]).powi(2)).sum();
847        if dist_sq < min_dist {
848            min_dist = dist_sq;
849            min_idx = i;
850        }
851    }
852    let _ = argvals; // kept for API consistency
853    (srsf_mat.row(min_idx), data.row(min_idx))
854}
855
856/// Pre-centering: align all curves to template, compute inverse mean warp, re-center.
857fn pre_center_template(
858    data: &FdMatrix,
859    mu_q: &[f64],
860    mu: &[f64],
861    argvals: &[f64],
862    lambda: f64,
863) -> (Vec<f64>, Vec<f64>) {
864    let (n, m) = data.shape();
865    let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
866        .map(|i| {
867            let fi = data.row(i);
868            let qi = srsf_single(&fi, argvals);
869            align_srsf_pair(mu_q, &qi, argvals, lambda)
870        })
871        .collect();
872
873    let mut init_gammas = FdMatrix::zeros(n, m);
874    for (i, (gamma, _)) in align_results.iter().enumerate() {
875        for j in 0..m {
876            init_gammas[(i, j)] = gamma[j];
877        }
878    }
879
880    let gam_inv = sqrt_mean_inverse(&init_gammas, argvals);
881    let mu_new = reparameterize_curve(mu, argvals, &gam_inv);
882    let mu_q_new = srsf_single(&mu_new, argvals);
883    (mu_q_new, mu_new)
884}
885
886/// Post-convergence centering: center mean SRSF and warps via SqrtMeanInverse.
887fn post_center_results(
888    data: &FdMatrix,
889    mu_q: &[f64],
890    final_gammas: &mut FdMatrix,
891    argvals: &[f64],
892) -> (Vec<f64>, Vec<f64>, FdMatrix) {
893    let (n, m) = data.shape();
894    let gam_inv = sqrt_mean_inverse(final_gammas, argvals);
895    let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
896    let gam_inv_dev = gradient_uniform(&gam_inv, h);
897
898    let mu_q_warped = reparameterize_curve(mu_q, argvals, &gam_inv);
899    let mu_q_centered: Vec<f64> = mu_q_warped
900        .iter()
901        .zip(gam_inv_dev.iter())
902        .map(|(&q, &gd)| q * gd.max(0.0).sqrt())
903        .collect();
904
905    for i in 0..n {
906        let gam_i: Vec<f64> = (0..m).map(|j| final_gammas[(i, j)]).collect();
907        let gam_centered = reparameterize_curve(&gam_i, argvals, &gam_inv);
908        for j in 0..m {
909            final_gammas[(i, j)] = gam_centered[j];
910        }
911    }
912
913    let initial_mean = mean_1d(data);
914    let mu = srsf_inverse(&mu_q_centered, argvals, initial_mean[0]);
915    let final_aligned = apply_stored_warps(data, final_gammas, argvals);
916    (mu, mu_q_centered, final_aligned)
917}
918
919pub fn karcher_mean(
920    data: &FdMatrix,
921    argvals: &[f64],
922    max_iter: usize,
923    tol: f64,
924    lambda: f64,
925) -> KarcherMeanResult {
926    let (n, m) = data.shape();
927
928    let srsf_mat = srsf_transform(data, argvals);
929    let (mut mu_q, mu) = select_template(&srsf_mat, data, argvals);
930    let (mu_q_c, mu_c) = pre_center_template(data, &mu_q, &mu, argvals, lambda);
931    mu_q = mu_q_c;
932    let mut mu = mu_c;
933
934    let mut converged = false;
935    let mut n_iter = 0;
936    let mut final_gammas = FdMatrix::zeros(n, m);
937
938    for iter in 0..max_iter {
939        n_iter = iter + 1;
940
941        let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
942            .map(|i| {
943                let fi = data.row(i);
944                let qi = srsf_single(&fi, argvals);
945                align_srsf_pair(&mu_q, &qi, argvals, lambda)
946            })
947            .collect();
948
949        let mu_q_new = accumulate_alignments(&align_results, &mut final_gammas, m, n);
950
951        let rel = relative_change(&mu_q, &mu_q_new);
952        if rel < tol {
953            converged = true;
954            mu_q = mu_q_new;
955            break;
956        }
957
958        mu_q = mu_q_new;
959        mu = srsf_inverse(&mu_q, argvals, mu[0]);
960    }
961
962    let (mu_final, mu_q_final, final_aligned) =
963        post_center_results(data, &mu_q, &mut final_gammas, argvals);
964
965    KarcherMeanResult {
966        mean: mu_final,
967        mean_srsf: mu_q_final,
968        gammas: final_gammas,
969        aligned_data: final_aligned,
970        n_iter,
971        converged,
972    }
973}
974
975// ─── TSRVF (Transported SRSF) ────────────────────────────────────────────────
976// Maps aligned SRSFs to the tangent space of the Karcher mean on the Hilbert
977// sphere. Tangent vectors live in a standard Euclidean space, enabling PCA,
978// regression, and clustering on elastic-aligned curves.
979
980/// Result of the TSRVF transform.
981#[derive(Debug, Clone)]
982pub struct TsrvfResult {
983    /// Tangent vectors in Euclidean space (n × m).
984    pub tangent_vectors: FdMatrix,
985    /// Karcher mean curve (length m).
986    pub mean: Vec<f64>,
987    /// SRSF of the Karcher mean (length m).
988    pub mean_srsf: Vec<f64>,
989    /// L2 norm of the mean SRSF.
990    pub mean_srsf_norm: f64,
991    /// Per-curve aligned SRSF norms (length n).
992    pub srsf_norms: Vec<f64>,
993    /// Per-curve initial values f_i(0) for SRSF inverse reconstruction (length n).
994    pub initial_values: Vec<f64>,
995    /// Warping functions from Karcher mean computation (n × m).
996    pub gammas: FdMatrix,
997    /// Whether the Karcher mean converged.
998    pub converged: bool,
999}
1000
1001/// Full TSRVF pipeline: compute Karcher mean, then transport SRSFs to tangent space.
1002///
1003/// # Arguments
1004/// * `data` — Functional data matrix (n × m)
1005/// * `argvals` — Evaluation points (length m)
1006/// * `max_iter` — Maximum Karcher mean iterations
1007/// * `tol` — Convergence tolerance for Karcher mean
1008///
1009/// # Returns
1010/// [`TsrvfResult`] containing tangent vectors and associated metadata.
1011pub fn tsrvf_transform(
1012    data: &FdMatrix,
1013    argvals: &[f64],
1014    max_iter: usize,
1015    tol: f64,
1016    lambda: f64,
1017) -> TsrvfResult {
1018    let karcher = karcher_mean(data, argvals, max_iter, tol, lambda);
1019    tsrvf_from_alignment(&karcher, argvals)
1020}
1021
1022/// Smooth aligned SRSFs to remove DP kink artifacts before TSRVF computation.
1023///
1024/// Uses Nadaraya-Watson kernel smoothing (Gaussian, bandwidth = 2 grid spacings)
1025/// on each SRSF row. This removes the derivative spikes from DP warp kinks
1026/// without affecting alignment results or the Karcher mean.
1027fn smooth_aligned_srsfs(srsf: &FdMatrix, m: usize) -> FdMatrix {
1028    let n = srsf.nrows();
1029    let time: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
1030    let bandwidth = 2.0 / (m - 1) as f64;
1031
1032    let mut smoothed = FdMatrix::zeros(n, m);
1033    for i in 0..n {
1034        let qi = srsf.row(i);
1035        let qi_smooth = nadaraya_watson(&time, &qi, &time, bandwidth, "gaussian");
1036        for j in 0..m {
1037            smoothed[(i, j)] = qi_smooth[j];
1038        }
1039    }
1040    smoothed
1041}
1042
1043/// Compute TSRVF from a pre-computed Karcher mean alignment.
1044///
1045/// Avoids re-running the expensive Karcher mean computation when the alignment
1046/// has already been computed.
1047///
1048/// # Arguments
1049/// * `karcher` — Pre-computed Karcher mean result
1050/// * `argvals` — Evaluation points (length m)
1051///
1052/// # Returns
1053/// [`TsrvfResult`] containing tangent vectors and associated metadata.
1054pub fn tsrvf_from_alignment(karcher: &KarcherMeanResult, argvals: &[f64]) -> TsrvfResult {
1055    let (n, m) = karcher.aligned_data.shape();
1056
1057    // Step 1: Compute SRSFs of aligned data
1058    let aligned_srsf = srsf_transform(&karcher.aligned_data, argvals);
1059
1060    // Step 1b: Smooth aligned SRSFs to remove DP kink artifacts.
1061    //
1062    // DP alignment produces piecewise-linear warps with kinks at grid transitions.
1063    // When curves are reparameterized by these warps, the kinks propagate into the
1064    // aligned curves' derivatives (SRSFs), creating spikes that dominate TSRVF
1065    // tangent vectors and PCA.
1066    //
1067    // R's fdasrvf does not smooth here and suffers from the same spike artifacts.
1068    // Python's fdasrsf mitigates this via spline smoothing (s=1e-4) in SqrtMean.
1069    // We smooth the aligned SRSFs before tangent vector computation — this only
1070    // affects TSRVF output and does not change the alignment or Karcher mean.
1071    let aligned_srsf = smooth_aligned_srsfs(&aligned_srsf, m);
1072
1073    // Step 2: Smooth and normalize mean SRSF to unit sphere.
1074    // The mean SRSF must be smoothed consistently with the aligned SRSFs
1075    // so that a single curve (which IS the mean) produces a zero tangent vector.
1076    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
1077    let bandwidth = 2.0 / (m - 1) as f64;
1078    let mean_srsf_smooth = nadaraya_watson(&time, &karcher.mean_srsf, &time, bandwidth, "gaussian");
1079    let mean_norm = l2_norm_l2(&mean_srsf_smooth, &time);
1080
1081    let mu_unit: Vec<f64> = if mean_norm > 1e-10 {
1082        mean_srsf_smooth.iter().map(|&q| q / mean_norm).collect()
1083    } else {
1084        vec![0.0; m]
1085    };
1086
1087    // Step 3: For each aligned curve, compute tangent vector via inverse exponential map
1088    let srsf_norms: Vec<f64> = iter_maybe_parallel!(0..n)
1089        .map(|i| {
1090            let qi = aligned_srsf.row(i);
1091            l2_norm_l2(&qi, &time)
1092        })
1093        .collect();
1094
1095    let tangent_data: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
1096        .map(|i| {
1097            let qi = aligned_srsf.row(i);
1098            let qi_norm = srsf_norms[i];
1099
1100            if qi_norm < 1e-10 || mean_norm < 1e-10 {
1101                return vec![0.0; m];
1102            }
1103
1104            // Normalize to unit sphere
1105            let qi_unit: Vec<f64> = qi.iter().map(|&q| q / qi_norm).collect();
1106
1107            // Shooting vector from mu_unit to qi_unit
1108            inv_exp_map_sphere(&mu_unit, &qi_unit, &time)
1109        })
1110        .collect();
1111
1112    // Assemble tangent vectors into FdMatrix
1113    let mut tangent_vectors = FdMatrix::zeros(n, m);
1114    for i in 0..n {
1115        for j in 0..m {
1116            tangent_vectors[(i, j)] = tangent_data[i][j];
1117        }
1118    }
1119
1120    // Store per-curve initial values for SRSF inverse reconstruction.
1121    // Warping preserves f_i(0) since gamma(0) = 0.
1122    let initial_values: Vec<f64> = (0..n).map(|i| karcher.aligned_data[(i, 0)]).collect();
1123
1124    TsrvfResult {
1125        tangent_vectors,
1126        mean: karcher.mean.clone(),
1127        mean_srsf: mean_srsf_smooth,
1128        mean_srsf_norm: mean_norm,
1129        srsf_norms,
1130        initial_values,
1131        gammas: karcher.gammas.clone(),
1132        converged: karcher.converged,
1133    }
1134}
1135
1136/// Reconstruct aligned curves from TSRVF tangent vectors.
1137///
1138/// Inverts the TSRVF transform: maps tangent vectors back to the Hilbert sphere
1139/// via the exponential map, rescales, and reconstructs curves via SRSF inverse.
1140///
1141/// # Arguments
1142/// * `tsrvf` — TSRVF result from [`tsrvf_transform`] or [`tsrvf_from_alignment`]
1143/// * `argvals` — Evaluation points (length m)
1144///
1145/// # Returns
1146/// FdMatrix of reconstructed aligned curves (n × m).
1147pub fn tsrvf_inverse(tsrvf: &TsrvfResult, argvals: &[f64]) -> FdMatrix {
1148    let (n, m) = tsrvf.tangent_vectors.shape();
1149
1150    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
1151
1152    // Normalize mean SRSF to unit sphere
1153    let mu_unit: Vec<f64> = if tsrvf.mean_srsf_norm > 1e-10 {
1154        tsrvf
1155            .mean_srsf
1156            .iter()
1157            .map(|&q| q / tsrvf.mean_srsf_norm)
1158            .collect()
1159    } else {
1160        vec![0.0; m]
1161    };
1162
1163    let curves: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
1164        .map(|i| {
1165            let vi = tsrvf.tangent_vectors.row(i);
1166
1167            // Map back to sphere: exp_map(mu_unit, v_i)
1168            let qi_unit = exp_map_sphere(&mu_unit, &vi, &time);
1169
1170            // Rescale by original norm
1171            let qi: Vec<f64> = qi_unit.iter().map(|&q| q * tsrvf.srsf_norms[i]).collect();
1172
1173            // Reconstruct curve from SRSF using per-curve initial value
1174            srsf_inverse(&qi, argvals, tsrvf.initial_values[i])
1175        })
1176        .collect();
1177
1178    let mut result = FdMatrix::zeros(n, m);
1179    for i in 0..n {
1180        for j in 0..m {
1181            result[(i, j)] = curves[i][j];
1182        }
1183    }
1184    result
1185}
1186
1187// ─── Parallel Transport Variants ─────────────────────────────────────────────
1188
/// Method for transporting tangent vectors on the Hilbert sphere.
///
/// Selects how [`tsrvf_from_alignment_with_method`] carries each aligned,
/// unit-normalized SRSF into the tangent space at the normalized mean SRSF.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TransportMethod {
    /// Inverse exponential map (log map) — default, matches existing TSRVF behavior.
    ///
    /// Delegates to [`tsrvf_from_alignment`].
    #[default]
    LogMap,
    /// Schild's ladder approximation to parallel transport.
    ///
    /// Transports `-log_{q_i}(μ)` from each curve `q_i` to the mean `μ`.
    SchildsLadder,
    /// Pole ladder approximation to parallel transport.
    ///
    /// Transports `-log_{q_i}(μ)` from each curve `q_i` to the mean `μ`.
    PoleLadder,
}
1200
1201/// Schild's ladder parallel transport of vector `v` from `from` to `to` on the sphere.
1202fn parallel_transport_schilds(v: &[f64], from: &[f64], to: &[f64], time: &[f64]) -> Vec<f64> {
1203    use crate::warping::{exp_map_sphere, inv_exp_map_sphere};
1204
1205    let v_norm = crate::warping::l2_norm_l2(v, time);
1206    if v_norm < 1e-10 {
1207        return vec![0.0; v.len()];
1208    }
1209
1210    // endpoint = exp_from(v)
1211    let endpoint = exp_map_sphere(from, v, time);
1212
1213    // midpoint_v = log_to(endpoint) — vector at `to` pointing toward endpoint
1214    let log_to_ep = inv_exp_map_sphere(to, &endpoint, time);
1215
1216    // midpoint = exp_to(0.5 * log_to_ep)
1217    let half_log: Vec<f64> = log_to_ep.iter().map(|&x| 0.5 * x).collect();
1218    let midpoint = exp_map_sphere(to, &half_log, time);
1219
1220    // transported = 2 * log_to(midpoint)
1221    let log_to_mid = inv_exp_map_sphere(to, &midpoint, time);
1222    log_to_mid.iter().map(|&x| 2.0 * x).collect()
1223}
1224
1225/// Pole ladder parallel transport of vector `v` from `from` to `to` on the sphere.
1226fn parallel_transport_pole(v: &[f64], from: &[f64], to: &[f64], time: &[f64]) -> Vec<f64> {
1227    use crate::warping::{exp_map_sphere, inv_exp_map_sphere};
1228
1229    let v_norm = crate::warping::l2_norm_l2(v, time);
1230    if v_norm < 1e-10 {
1231        return vec![0.0; v.len()];
1232    }
1233
1234    // pole = exp_from(-v)
1235    let neg_v: Vec<f64> = v.iter().map(|&x| -x).collect();
1236    let pole = exp_map_sphere(from, &neg_v, time);
1237
1238    // midpoint_v = log_to(pole)
1239    let log_to_pole = inv_exp_map_sphere(to, &pole, time);
1240
1241    // midpoint = exp_to(0.5 * log_to_pole)
1242    let half_log: Vec<f64> = log_to_pole.iter().map(|&x| 0.5 * x).collect();
1243    let midpoint = exp_map_sphere(to, &half_log, time);
1244
1245    // transported = -2 * log_to(midpoint)
1246    let log_to_mid = inv_exp_map_sphere(to, &midpoint, time);
1247    log_to_mid.iter().map(|&x| -2.0 * x).collect()
1248}
1249
1250/// Full TSRVF pipeline with configurable transport method.
1251///
1252/// Like [`tsrvf_transform`] but allows choosing the parallel transport method.
1253pub fn tsrvf_transform_with_method(
1254    data: &FdMatrix,
1255    argvals: &[f64],
1256    max_iter: usize,
1257    tol: f64,
1258    lambda: f64,
1259    method: TransportMethod,
1260) -> TsrvfResult {
1261    let karcher = karcher_mean(data, argvals, max_iter, tol, lambda);
1262    tsrvf_from_alignment_with_method(&karcher, argvals, method)
1263}
1264
1265/// Compute TSRVF from a pre-computed Karcher mean with configurable transport.
1266///
1267/// - [`TransportMethod::LogMap`]: Uses `inv_exp_map(mu, qi)` directly (standard TSRVF).
1268/// - [`TransportMethod::SchildsLadder`]: Computes `v = -log_qi(mu)`, then transports
1269///   via Schild's ladder from qi to mu.
1270/// - [`TransportMethod::PoleLadder`]: Same but via pole ladder.
1271pub fn tsrvf_from_alignment_with_method(
1272    karcher: &KarcherMeanResult,
1273    argvals: &[f64],
1274    method: TransportMethod,
1275) -> TsrvfResult {
1276    if method == TransportMethod::LogMap {
1277        return tsrvf_from_alignment(karcher, argvals);
1278    }
1279
1280    let (n, m) = karcher.aligned_data.shape();
1281    let aligned_srsf = srsf_transform(&karcher.aligned_data, argvals);
1282    let aligned_srsf = smooth_aligned_srsfs(&aligned_srsf, m);
1283    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
1284    let bandwidth = 2.0 / (m - 1) as f64;
1285    let mean_srsf_smooth = nadaraya_watson(&time, &karcher.mean_srsf, &time, bandwidth, "gaussian");
1286    let mean_norm = crate::warping::l2_norm_l2(&mean_srsf_smooth, &time);
1287
1288    let mu_unit: Vec<f64> = if mean_norm > 1e-10 {
1289        mean_srsf_smooth.iter().map(|&q| q / mean_norm).collect()
1290    } else {
1291        vec![0.0; m]
1292    };
1293
1294    let srsf_norms: Vec<f64> = iter_maybe_parallel!(0..n)
1295        .map(|i| {
1296            let qi = aligned_srsf.row(i);
1297            crate::warping::l2_norm_l2(&qi, &time)
1298        })
1299        .collect();
1300
1301    let tangent_data: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
1302        .map(|i| {
1303            let qi = aligned_srsf.row(i);
1304            let qi_norm = srsf_norms[i];
1305
1306            if qi_norm < 1e-10 || mean_norm < 1e-10 {
1307                return vec![0.0; m];
1308            }
1309
1310            let qi_unit: Vec<f64> = qi.iter().map(|&q| q / qi_norm).collect();
1311
1312            // Compute v = -log_qi(mu) — vector at qi pointing away from mu
1313            let v_at_qi = inv_exp_map_sphere(&qi_unit, &mu_unit, &time);
1314            let neg_v: Vec<f64> = v_at_qi.iter().map(|&x| -x).collect();
1315
1316            // Transport from qi to mu
1317            match method {
1318                TransportMethod::SchildsLadder => {
1319                    parallel_transport_schilds(&neg_v, &qi_unit, &mu_unit, &time)
1320                }
1321                TransportMethod::PoleLadder => {
1322                    parallel_transport_pole(&neg_v, &qi_unit, &mu_unit, &time)
1323                }
1324                TransportMethod::LogMap => unreachable!(),
1325            }
1326        })
1327        .collect();
1328
1329    let mut tangent_vectors = FdMatrix::zeros(n, m);
1330    for i in 0..n {
1331        for j in 0..m {
1332            tangent_vectors[(i, j)] = tangent_data[i][j];
1333        }
1334    }
1335
1336    let initial_values: Vec<f64> = (0..n).map(|i| karcher.aligned_data[(i, 0)]).collect();
1337
1338    TsrvfResult {
1339        tangent_vectors,
1340        mean: karcher.mean.clone(),
1341        mean_srsf: mean_srsf_smooth,
1342        mean_srsf_norm: mean_norm,
1343        srsf_norms,
1344        initial_values,
1345        gammas: karcher.gammas.clone(),
1346        converged: karcher.converged,
1347    }
1348}
1349
1350// ─── Alignment Quality Metrics ───────────────────────────────────────────────
1351
/// Comprehensive alignment quality assessment.
///
/// Produced by [`alignment_quality`]. Integrated variances are squared L2
/// distances computed with Simpson weights on the grid. Pointwise quantities
/// use plain `1/n` variances per grid column.
#[derive(Debug, Clone)]
pub struct AlignmentQuality {
    /// Per-curve geodesic distance from warp to identity.
    pub warp_complexity: Vec<f64>,
    /// Mean warp complexity.
    pub mean_warp_complexity: f64,
    /// Per-curve bending energy ∫(γ'')² dt.
    pub warp_smoothness: Vec<f64>,
    /// Mean warp smoothness (bending energy).
    pub mean_warp_smoothness: f64,
    /// Total variance: (1/n) Σ ∫(f_i - mean_orig)² dt.
    pub total_variance: f64,
    /// Amplitude variance: (1/n) Σ ∫(f_i^aligned - mean_aligned)² dt.
    pub amplitude_variance: f64,
    /// Phase variance: total - amplitude (clamped ≥ 0).
    pub phase_variance: f64,
    /// Phase-to-**total** variance ratio, `phase_variance / total_variance`
    /// (0 when the total variance is ≤ 1e-10). Despite the field name, the
    /// denominator is the total variance, not the amplitude variance.
    pub phase_amplitude_ratio: f64,
    /// Pointwise ratio: aligned_var / orig_var per time point
    /// (1.0 where the original variance is ≤ 1e-15).
    pub pointwise_variance_ratio: Vec<f64>,
    /// Mean of `pointwise_variance_ratio` over the grid. Despite the name this
    /// is a ratio, not `1 - ratio`: values below 1 mean alignment reduced the
    /// pointwise variance.
    pub mean_variance_reduction: f64,
}
1376
1377/// Compute warp complexity: geodesic distance from a warp to the identity.
1378///
1379/// This is `arccos(⟨ψ, ψ_id⟩)` on the Hilbert sphere.
1380pub fn warp_complexity(gamma: &[f64], argvals: &[f64]) -> f64 {
1381    crate::warping::phase_distance(gamma, argvals)
1382}
1383
1384/// Compute warp smoothness (bending energy): ∫(γ'')² dt.
1385pub fn warp_smoothness(gamma: &[f64], argvals: &[f64]) -> f64 {
1386    let m = gamma.len();
1387    if m < 3 {
1388        return 0.0;
1389    }
1390
1391    let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
1392    let gam_prime = gradient_uniform(gamma, h);
1393    let gam_pprime = gradient_uniform(&gam_prime, h);
1394
1395    let integrand: Vec<f64> = gam_pprime.iter().map(|&g| g * g).collect();
1396    crate::helpers::trapz(&integrand, argvals)
1397}
1398
1399/// Compute comprehensive alignment quality metrics.
1400///
1401/// # Arguments
1402/// * `data` — Original functional data (n × m)
1403/// * `karcher` — Pre-computed Karcher mean result
1404/// * `argvals` — Evaluation points (length m)
1405pub fn alignment_quality(
1406    data: &FdMatrix,
1407    karcher: &KarcherMeanResult,
1408    argvals: &[f64],
1409) -> AlignmentQuality {
1410    let (n, m) = data.shape();
1411    let weights = simpsons_weights(argvals);
1412
1413    // Per-curve warp complexity and smoothness
1414    let wc: Vec<f64> = (0..n)
1415        .map(|i| {
1416            let gamma: Vec<f64> = (0..m).map(|j| karcher.gammas[(i, j)]).collect();
1417            warp_complexity(&gamma, argvals)
1418        })
1419        .collect();
1420    let ws: Vec<f64> = (0..n)
1421        .map(|i| {
1422            let gamma: Vec<f64> = (0..m).map(|j| karcher.gammas[(i, j)]).collect();
1423            warp_smoothness(&gamma, argvals)
1424        })
1425        .collect();
1426
1427    let mean_wc = wc.iter().sum::<f64>() / n as f64;
1428    let mean_ws = ws.iter().sum::<f64>() / n as f64;
1429
1430    // Compute original mean
1431    let orig_mean = crate::fdata::mean_1d(data);
1432
1433    // Total variance
1434    let total_var: f64 = (0..n)
1435        .map(|i| {
1436            let fi = data.row(i);
1437            let d = l2_distance(&fi, &orig_mean, &weights);
1438            d * d
1439        })
1440        .sum::<f64>()
1441        / n as f64;
1442
1443    // Aligned mean
1444    let aligned_mean = crate::fdata::mean_1d(&karcher.aligned_data);
1445
1446    // Amplitude variance
1447    let amp_var: f64 = (0..n)
1448        .map(|i| {
1449            let fi = karcher.aligned_data.row(i);
1450            let d = l2_distance(&fi, &aligned_mean, &weights);
1451            d * d
1452        })
1453        .sum::<f64>()
1454        / n as f64;
1455
1456    let phase_var = (total_var - amp_var).max(0.0);
1457    let ratio = if total_var > 1e-10 {
1458        phase_var / total_var
1459    } else {
1460        0.0
1461    };
1462
1463    // Pointwise variance ratio
1464    let mut pw_ratio = vec![0.0; m];
1465    for j in 0..m {
1466        let col_orig = data.column(j);
1467        let mean_orig = col_orig.iter().sum::<f64>() / n as f64;
1468        let var_orig: f64 = col_orig
1469            .iter()
1470            .map(|&v| (v - mean_orig).powi(2))
1471            .sum::<f64>()
1472            / n as f64;
1473
1474        let col_aligned = karcher.aligned_data.column(j);
1475        let mean_aligned = col_aligned.iter().sum::<f64>() / n as f64;
1476        let var_aligned: f64 = col_aligned
1477            .iter()
1478            .map(|&v| (v - mean_aligned).powi(2))
1479            .sum::<f64>()
1480            / n as f64;
1481
1482        pw_ratio[j] = if var_orig > 1e-15 {
1483            var_aligned / var_orig
1484        } else {
1485            1.0
1486        };
1487    }
1488
1489    let mean_vr = pw_ratio.iter().sum::<f64>() / m as f64;
1490
1491    AlignmentQuality {
1492        warp_complexity: wc,
1493        mean_warp_complexity: mean_wc,
1494        warp_smoothness: ws,
1495        mean_warp_smoothness: mean_ws,
1496        total_variance: total_var,
1497        amplitude_variance: amp_var,
1498        phase_variance: phase_var,
1499        phase_amplitude_ratio: ratio,
1500        pointwise_variance_ratio: pw_ratio,
1501        mean_variance_reduction: mean_vr,
1502    }
1503}
1504
/// Generate triplet indices (i,j,k) with i<j<k, capped at `max_triplets` (0 = all).
///
/// Triplets are produced in lexicographic order. When capped, the first
/// `max_triplets` triplets in that order are returned; this is not a random
/// sample. Returns an empty vector for `n < 3`.
///
/// The triplet count is never computed up front. That avoids the `usize`
/// underflow of `n * (n - 1) * (n - 2)` for `n < 2`; `take` already stops at
/// the end of the sequence.
fn triplet_indices(n: usize, max_triplets: usize) -> Vec<(usize, usize, usize)> {
    let cap = if max_triplets > 0 {
        max_triplets
    } else {
        usize::MAX
    };
    (0..n)
        .flat_map(|i| ((i + 1)..n).flat_map(move |j| ((j + 1)..n).map(move |k| (i, j, k))))
        .take(cap)
        .collect()
}
1518
1519/// Compute the warp deviation for one triplet: ‖γ_ij∘γ_jk − γ_ik‖_L2.
1520fn triplet_warp_deviation(
1521    data: &FdMatrix,
1522    argvals: &[f64],
1523    weights: &[f64],
1524    i: usize,
1525    j: usize,
1526    k: usize,
1527    lambda: f64,
1528) -> f64 {
1529    let fi = data.row(i);
1530    let fj = data.row(j);
1531    let fk = data.row(k);
1532    let rij = elastic_align_pair(&fi, &fj, argvals, lambda);
1533    let rjk = elastic_align_pair(&fj, &fk, argvals, lambda);
1534    let rik = elastic_align_pair(&fi, &fk, argvals, lambda);
1535    let composed = compose_warps(&rij.gamma, &rjk.gamma, argvals);
1536    l2_distance(&composed, &rik.gamma, weights)
1537}
1538
1539/// Measure pairwise alignment consistency via triplet checks.
1540///
1541/// For triplets (i,j,k), checks `γ_ij ∘ γ_jk ≈ γ_ik` by measuring the L2
1542/// deviation of the composed warp from the direct warp.
1543///
1544/// # Arguments
1545/// * `data` — Functional data (n × m)
1546/// * `argvals` — Evaluation points (length m)
1547/// * `lambda` — Penalty weight
1548/// * `max_triplets` — Maximum number of triplets to check (0 = all)
1549pub fn pairwise_consistency(
1550    data: &FdMatrix,
1551    argvals: &[f64],
1552    lambda: f64,
1553    max_triplets: usize,
1554) -> f64 {
1555    let n = data.nrows();
1556    if n < 3 {
1557        return 0.0;
1558    }
1559
1560    let weights = simpsons_weights(argvals);
1561    let triplets = triplet_indices(n, max_triplets);
1562    if triplets.is_empty() {
1563        return 0.0;
1564    }
1565
1566    let total_dev: f64 = triplets
1567        .iter()
1568        .map(|&(i, j, k)| triplet_warp_deviation(data, argvals, &weights, i, j, k, lambda))
1569        .sum();
1570    total_dev / triplets.len() as f64
1571}
1572
1573// ─── Landmark-Constrained Elastic Alignment ────────────────────────────────
1574
/// Result of landmark-constrained elastic alignment.
///
/// Returned by [`elastic_align_pair_constrained`] and
/// [`elastic_align_pair_with_landmarks`].
#[derive(Debug, Clone)]
pub struct ConstrainedAlignmentResult {
    /// Optimal warping function (length m).
    pub gamma: Vec<f64>,
    /// Aligned curve f2∘γ (length m).
    pub f_aligned: Vec<f64>,
    /// Elastic distance after alignment.
    pub distance: f64,
    /// Enforced landmark pairs (snapped to grid): `(target_t, source_t)`.
    ///
    /// Lists only the landmark pairs that were actually used as DP waypoints.
    /// Pairs rejected while building the waypoints are omitted, and the list
    /// is empty when no landmarks were supplied.
    pub enforced_landmarks: Vec<(f64, f64)>,
}
1587
/// Snap a time value to the nearest grid point index.
///
/// Linear scan over `argvals`; on ties the lower index wins, and a NaN
/// `t_val` maps to index 0. Panics if `argvals` is empty.
fn snap_to_grid(t_val: f64, argvals: &[f64]) -> usize {
    let first_gap = (t_val - argvals[0]).abs();
    argvals
        .iter()
        .enumerate()
        .skip(1)
        .fold((0, first_gap), |(best_idx, best_gap), (idx, &grid_t)| {
            let gap = (t_val - grid_t).abs();
            if gap < best_gap {
                (idx, gap)
            } else {
                (best_idx, best_gap)
            }
        })
        .0
}
1601
1602/// Run DP on a rectangular sub-grid `[sc..=ec] × [sr..=er]`.
1603///
1604/// Uses global indices for `dp_edge_weight`. Returns the path segment
1605/// as a list of `(tc_idx, tr_idx)` pairs from start to end.
1606fn dp_segment(
1607    q1: &[f64],
1608    q2: &[f64],
1609    argvals: &[f64],
1610    sc: usize,
1611    ec: usize,
1612    sr: usize,
1613    er: usize,
1614    lambda: f64,
1615) -> Vec<(usize, usize)> {
1616    let nc = ec - sc + 1;
1617    let nr = er - sr + 1;
1618
1619    if nc <= 1 || nr <= 1 {
1620        return vec![(sc, sr), (ec, er)];
1621    }
1622
1623    let path = dp_grid_solve(nr, nc, |local_sr, local_sc, local_tr, local_tc| {
1624        let gsr = sr + local_sr;
1625        let gsc = sc + local_sc;
1626        let gtr = sr + local_tr;
1627        let gtc = sc + local_tc;
1628        dp_edge_weight(q1, q2, argvals, gsc, gtc, gsr, gtr)
1629            + dp_lambda_penalty(argvals, gsc, gtc, gsr, gtr, lambda)
1630    });
1631
1632    // Convert local indices to global
1633    path.iter().map(|&(lr, lc)| (sc + lc, sr + lr)).collect()
1634}
1635
1636/// Align f2 to f1 with landmark constraints.
1637///
1638/// Landmark pairs define waypoints on the DP grid. Between consecutive waypoints,
1639/// an independent smaller DP is run. The resulting warp passes through all landmarks.
1640///
1641/// # Arguments
1642/// * `f1` — Target curve (length m)
1643/// * `f2` — Curve to align (length m)
1644/// * `argvals` — Evaluation points (length m)
1645/// * `landmark_pairs` — `(target_t, source_t)` pairs in increasing order
1646/// * `lambda` — Penalty weight
1647///
1648/// # Returns
1649/// [`ConstrainedAlignmentResult`] with warp, aligned curve, and enforced landmarks.
1650/// Build DP waypoints from landmark pairs: snap to grid, deduplicate, add endpoints.
1651fn build_constrained_waypoints(
1652    landmark_pairs: &[(f64, f64)],
1653    argvals: &[f64],
1654    m: usize,
1655) -> Vec<(usize, usize)> {
1656    let mut waypoints: Vec<(usize, usize)> = Vec::with_capacity(landmark_pairs.len() + 2);
1657    waypoints.push((0, 0));
1658    for &(tt, st) in landmark_pairs {
1659        let tc = snap_to_grid(tt, argvals);
1660        let tr = snap_to_grid(st, argvals);
1661        if let Some(&(prev_c, prev_r)) = waypoints.last() {
1662            if tc > prev_c && tr > prev_r {
1663                waypoints.push((tc, tr));
1664            }
1665        }
1666    }
1667    let last = m - 1;
1668    if let Some(&(prev_c, prev_r)) = waypoints.last() {
1669        if prev_c != last || prev_r != last {
1670            waypoints.push((last, last));
1671        }
1672    }
1673    waypoints
1674}
1675
1676/// Run DP segments between consecutive waypoints and assemble into a gamma warp.
1677fn segmented_dp_gamma(
1678    q1n: &[f64],
1679    q2n: &[f64],
1680    argvals: &[f64],
1681    waypoints: &[(usize, usize)],
1682    lambda: f64,
1683) -> Vec<f64> {
1684    let mut full_path_tc: Vec<f64> = Vec::new();
1685    let mut full_path_tr: Vec<f64> = Vec::new();
1686
1687    for seg in 0..(waypoints.len() - 1) {
1688        let (sc, sr) = waypoints[seg];
1689        let (ec, er) = waypoints[seg + 1];
1690        let segment_path = dp_segment(q1n, q2n, argvals, sc, ec, sr, er, lambda);
1691        let start = if seg > 0 { 1 } else { 0 };
1692        for &(tc, tr) in &segment_path[start..] {
1693            full_path_tc.push(argvals[tc]);
1694            full_path_tr.push(argvals[tr]);
1695        }
1696    }
1697
1698    let mut gamma: Vec<f64> = argvals
1699        .iter()
1700        .map(|&t| linear_interp(&full_path_tc, &full_path_tr, t))
1701        .collect();
1702    normalize_warp(&mut gamma, argvals);
1703    gamma
1704}
1705
1706pub fn elastic_align_pair_constrained(
1707    f1: &[f64],
1708    f2: &[f64],
1709    argvals: &[f64],
1710    landmark_pairs: &[(f64, f64)],
1711    lambda: f64,
1712) -> ConstrainedAlignmentResult {
1713    let m = f1.len();
1714
1715    if landmark_pairs.is_empty() {
1716        let r = elastic_align_pair(f1, f2, argvals, lambda);
1717        return ConstrainedAlignmentResult {
1718            gamma: r.gamma,
1719            f_aligned: r.f_aligned,
1720            distance: r.distance,
1721            enforced_landmarks: Vec::new(),
1722        };
1723    }
1724
1725    // Compute & normalize SRSFs
1726    let f1_mat = FdMatrix::from_slice(f1, 1, m).unwrap();
1727    let f2_mat = FdMatrix::from_slice(f2, 1, m).unwrap();
1728    let q1_mat = srsf_transform(&f1_mat, argvals);
1729    let q2_mat = srsf_transform(&f2_mat, argvals);
1730    let q1: Vec<f64> = q1_mat.row(0);
1731    let q2: Vec<f64> = q2_mat.row(0);
1732    let norm1 = q1.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1733    let norm2 = q2.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1734    let q1n: Vec<f64> = q1.iter().map(|&v| v / norm1).collect();
1735    let q2n: Vec<f64> = q2.iter().map(|&v| v / norm2).collect();
1736
1737    let waypoints = build_constrained_waypoints(landmark_pairs, argvals, m);
1738    let gamma = segmented_dp_gamma(&q1n, &q2n, argvals, &waypoints, lambda);
1739
1740    let f_aligned = reparameterize_curve(f2, argvals, &gamma);
1741    let f_aligned_mat = FdMatrix::from_slice(&f_aligned, 1, m).unwrap();
1742    let q_aligned_mat = srsf_transform(&f_aligned_mat, argvals);
1743    let q_aligned: Vec<f64> = q_aligned_mat.row(0);
1744    let weights = simpsons_weights(argvals);
1745    let distance = l2_distance(&q1, &q_aligned, &weights);
1746
1747    let enforced: Vec<(f64, f64)> = waypoints[1..waypoints.len() - 1]
1748        .iter()
1749        .map(|&(tc, tr)| (argvals[tc], argvals[tr]))
1750        .collect();
1751
1752    ConstrainedAlignmentResult {
1753        gamma,
1754        f_aligned,
1755        distance,
1756        enforced_landmarks: enforced,
1757    }
1758}
1759
1760/// Align f2 to f1 with automatic landmark detection and elastic constraints.
1761///
1762/// Detects landmarks in both curves, matches them, and uses the matches
1763/// as constraints for segmented DP alignment.
1764///
1765/// # Arguments
1766/// * `f1` — Target curve (length m)
1767/// * `f2` — Curve to align (length m)
1768/// * `argvals` — Evaluation points (length m)
1769/// * `kind` — Type of landmarks to detect
1770/// * `min_prominence` — Minimum prominence for landmark detection
1771/// * `expected_count` — Expected number of landmarks (0 = all detected)
1772/// * `lambda` — Penalty weight
1773pub fn elastic_align_pair_with_landmarks(
1774    f1: &[f64],
1775    f2: &[f64],
1776    argvals: &[f64],
1777    kind: crate::landmark::LandmarkKind,
1778    min_prominence: f64,
1779    expected_count: usize,
1780    lambda: f64,
1781) -> ConstrainedAlignmentResult {
1782    let lm1 = crate::landmark::detect_landmarks(f1, argvals, kind, min_prominence);
1783    let lm2 = crate::landmark::detect_landmarks(f2, argvals, kind, min_prominence);
1784
1785    // Match landmarks by order (take min count)
1786    let n_match = if expected_count > 0 {
1787        expected_count.min(lm1.len()).min(lm2.len())
1788    } else {
1789        lm1.len().min(lm2.len())
1790    };
1791
1792    let pairs: Vec<(f64, f64)> = (0..n_match)
1793        .map(|i| (lm1[i].position, lm2[i].position))
1794        .collect();
1795
1796    elastic_align_pair_constrained(f1, f2, argvals, &pairs, lambda)
1797}
1798
1799// ─── Multidimensional SRSF (R^d curves) ────────────────────────────────────
1800
1801use crate::matrix::FdCurveSet;
1802
/// Result of aligning multidimensional (R^d) curves.
///
/// Produced by [`elastic_align_pair_nd`]; a single warp is shared by every
/// coordinate of the curve.
#[derive(Debug, Clone)]
pub struct AlignmentResultNd {
    /// Optimal warping function (length m), same for all dimensions.
    pub gamma: Vec<f64>,
    /// Aligned curve f2∘γ: d vectors (one per dimension), each length m.
    pub f_aligned: Vec<Vec<f64>>,
    /// Elastic distance after alignment: the square root of the summed squared
    /// per-dimension L2 distances between the target SRSF and the SRSF of the
    /// aligned curve.
    pub distance: f64,
}
1813
1814/// Compute the SRSF transform for multidimensional (R^d) curves.
1815///
1816/// For f: \[0,1\] → R^d, the SRSF is q(t) = f'(t) / √‖f'(t)‖ where ‖·‖ is the
1817/// Euclidean norm in R^d. For d=1 this reduces to `sign(f') · √|f'|`.
1818///
1819/// # Arguments
1820/// * `data` — Set of n curves in R^d, each with m evaluation points
1821/// * `argvals` — Evaluation points (length m)
1822///
1823/// # Returns
1824/// `FdCurveSet` of SRSF values with the same shape as input.
1825/// Scale derivative vector at one point by 1/√‖f'‖, writing into result_dims.
1826#[inline]
1827fn srsf_scale_point(derivs: &[FdMatrix], result_dims: &mut [FdMatrix], i: usize, j: usize) {
1828    let d = derivs.len();
1829    let norm_sq: f64 = derivs.iter().map(|dd| dd[(i, j)].powi(2)).sum();
1830    let norm = norm_sq.sqrt();
1831    if norm < 1e-15 {
1832        for k in 0..d {
1833            result_dims[k][(i, j)] = 0.0;
1834        }
1835    } else {
1836        let scale = 1.0 / norm.sqrt();
1837        for k in 0..d {
1838            result_dims[k][(i, j)] = derivs[k][(i, j)] * scale;
1839        }
1840    }
1841}
1842
1843pub fn srsf_transform_nd(data: &FdCurveSet, argvals: &[f64]) -> FdCurveSet {
1844    let d = data.ndim();
1845    let n = data.ncurves();
1846    let m = data.npoints();
1847
1848    if d == 0 || n == 0 || m == 0 || argvals.len() != m {
1849        return FdCurveSet {
1850            dims: (0..d).map(|_| FdMatrix::zeros(n, m)).collect(),
1851        };
1852    }
1853
1854    let derivs: Vec<FdMatrix> = data
1855        .dims
1856        .iter()
1857        .map(|dim_mat| crate::fdata::deriv_1d(dim_mat, argvals, 1))
1858        .collect();
1859
1860    let mut result_dims: Vec<FdMatrix> = (0..d).map(|_| FdMatrix::zeros(n, m)).collect();
1861    for i in 0..n {
1862        for j in 0..m {
1863            srsf_scale_point(&derivs, &mut result_dims, i, j);
1864        }
1865    }
1866
1867    FdCurveSet { dims: result_dims }
1868}
1869
1870/// Reconstruct an R^d curve from its SRSF.
1871///
1872/// Given d-dimensional SRSF vectors and initial point f0, reconstructs:
1873/// `f_k(t) = f0_k + ∫₀ᵗ q_k(s) · ‖q(s)‖ ds` for each dimension k.
1874///
1875/// # Arguments
1876/// * `q` — SRSF: d vectors, each length m
1877/// * `argvals` — Evaluation points (length m)
1878/// * `f0` — Initial values in R^d (length d)
1879///
1880/// # Returns
1881/// Reconstructed curve: d vectors, each length m.
1882pub fn srsf_inverse_nd(q: &[Vec<f64>], argvals: &[f64], f0: &[f64]) -> Vec<Vec<f64>> {
1883    let d = q.len();
1884    if d == 0 {
1885        return Vec::new();
1886    }
1887    let m = q[0].len();
1888    if m == 0 {
1889        return vec![Vec::new(); d];
1890    }
1891
1892    // Compute ||q(t)|| at each time point
1893    let norms: Vec<f64> = (0..m)
1894        .map(|j| {
1895            let norm_sq: f64 = q.iter().map(|qk| qk[j].powi(2)).sum();
1896            norm_sq.sqrt()
1897        })
1898        .collect();
1899
1900    // For each dimension, integrand = q_k(t) * ||q(t)||
1901    let mut result = Vec::with_capacity(d);
1902    for k in 0..d {
1903        let integrand: Vec<f64> = (0..m).map(|j| q[k][j] * norms[j]).collect();
1904        let integral = cumulative_trapz(&integrand, argvals);
1905        let curve: Vec<f64> = integral.iter().map(|&v| f0[k] + v).collect();
1906        result.push(curve);
1907    }
1908
1909    result
1910}
1911
1912/// Core DP alignment for R^d SRSFs.
1913///
1914/// Same DP grid and coprime neighborhood as `dp_alignment_core`, but edge weight
1915/// is the sum of `dp_edge_weight` over d dimensions.
1916fn dp_alignment_core_nd(
1917    q1: &[Vec<f64>],
1918    q2: &[Vec<f64>],
1919    argvals: &[f64],
1920    lambda: f64,
1921) -> Vec<f64> {
1922    let d = q1.len();
1923    let m = argvals.len();
1924    if m < 2 || d == 0 {
1925        return argvals.to_vec();
1926    }
1927
1928    // For d=1, delegate to existing implementation for exact backward compat
1929    if d == 1 {
1930        return dp_alignment_core(&q1[0], &q2[0], argvals, lambda);
1931    }
1932
1933    // Normalize each dimension's SRSF to unit L2 norm
1934    let q1n: Vec<Vec<f64>> = q1
1935        .iter()
1936        .map(|qk| {
1937            let norm = qk.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1938            qk.iter().map(|&v| v / norm).collect()
1939        })
1940        .collect();
1941    let q2n: Vec<Vec<f64>> = q2
1942        .iter()
1943        .map(|qk| {
1944            let norm = qk.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1945            qk.iter().map(|&v| v / norm).collect()
1946        })
1947        .collect();
1948
1949    let path = dp_grid_solve(m, m, |sr, sc, tr, tc| {
1950        let w: f64 = (0..d)
1951            .map(|k| dp_edge_weight(&q1n[k], &q2n[k], argvals, sc, tc, sr, tr))
1952            .sum();
1953        w + dp_lambda_penalty(argvals, sc, tc, sr, tr, lambda)
1954    });
1955
1956    dp_path_to_gamma(&path, argvals)
1957}
1958
1959/// Align an R^d curve f2 to f1 using the elastic framework.
1960///
1961/// Finds the optimal warping γ (shared across all dimensions) such that
1962/// f2∘γ is as close as possible to f1 in the elastic metric.
1963///
1964/// # Arguments
1965/// * `f1` — Target curves (d dimensions)
1966/// * `f2` — Curves to align (d dimensions)
1967/// * `argvals` — Evaluation points (length m)
1968/// * `lambda` — Penalty weight (0.0 = no penalty)
1969pub fn elastic_align_pair_nd(
1970    f1: &FdCurveSet,
1971    f2: &FdCurveSet,
1972    argvals: &[f64],
1973    lambda: f64,
1974) -> AlignmentResultNd {
1975    let d = f1.ndim();
1976    let m = f1.npoints();
1977
1978    // Compute SRSFs
1979    let q1_set = srsf_transform_nd(f1, argvals);
1980    let q2_set = srsf_transform_nd(f2, argvals);
1981
1982    // Extract first curve from each dimension
1983    let q1: Vec<Vec<f64>> = q1_set.dims.iter().map(|dm| dm.row(0)).collect();
1984    let q2: Vec<Vec<f64>> = q2_set.dims.iter().map(|dm| dm.row(0)).collect();
1985
1986    // DP alignment using summed cost over dimensions
1987    let gamma = dp_alignment_core_nd(&q1, &q2, argvals, lambda);
1988
1989    // Apply warping to f2 in each dimension
1990    let f_aligned: Vec<Vec<f64>> = f2
1991        .dims
1992        .iter()
1993        .map(|dm| {
1994            let row = dm.row(0);
1995            reparameterize_curve(&row, argvals, &gamma)
1996        })
1997        .collect();
1998
1999    // Compute elastic distance: sum of squared L2 distances between aligned SRSFs
2000    let f_aligned_set = {
2001        let dims: Vec<FdMatrix> = f_aligned
2002            .iter()
2003            .map(|fa| FdMatrix::from_slice(fa, 1, m).unwrap())
2004            .collect();
2005        FdCurveSet { dims }
2006    };
2007    let q_aligned = srsf_transform_nd(&f_aligned_set, argvals);
2008    let weights = simpsons_weights(argvals);
2009
2010    let mut dist_sq = 0.0;
2011    for k in 0..d {
2012        let q1k = q1_set.dims[k].row(0);
2013        let qak = q_aligned.dims[k].row(0);
2014        let d_k = l2_distance(&q1k, &qak, &weights);
2015        dist_sq += d_k * d_k;
2016    }
2017
2018    AlignmentResultNd {
2019        gamma,
2020        f_aligned,
2021        distance: dist_sq.sqrt(),
2022    }
2023}
2024
2025/// Elastic distance between two R^d curves.
2026///
2027/// Aligns f2 to f1 and returns the post-alignment SRSF distance.
2028pub fn elastic_distance_nd(f1: &FdCurveSet, f2: &FdCurveSet, argvals: &[f64], lambda: f64) -> f64 {
2029    elastic_align_pair_nd(f1, f2, argvals, lambda).distance
2030}
2031
2032// ─── Tests ──────────────────────────────────────────────────────────────────
2033
2034#[cfg(test)]
2035mod tests {
2036    use super::*;
2037    use crate::helpers::trapz;
2038    use crate::simulation::{sim_fundata, EFunType, EValType};
2039    use crate::warping::inner_product_l2;
2040
2041    fn uniform_grid(m: usize) -> Vec<f64> {
2042        (0..m).map(|i| i as f64 / (m - 1) as f64).collect()
2043    }
2044
2045    fn make_test_data(n: usize, m: usize, seed: u64) -> FdMatrix {
2046        let t = uniform_grid(m);
2047        sim_fundata(
2048            n,
2049            &t,
2050            3,
2051            EFunType::Fourier,
2052            EValType::Exponential,
2053            Some(seed),
2054        )
2055    }
2056
2057    // ── cumulative_trapz ──
2058
2059    #[test]
2060    fn test_cumulative_trapz_constant() {
2061        // ∫₀ᵗ 1 dt = t
2062        let x = uniform_grid(50);
2063        let y = vec![1.0; 50];
2064        let result = cumulative_trapz(&y, &x);
2065        assert!((result[0]).abs() < 1e-15, "cumulative_trapz(0) should be 0");
2066        for j in 1..50 {
2067            assert!(
2068                (result[j] - x[j]).abs() < 1e-12,
2069                "∫₀^{:.3} 1 dt should be {:.3}, got {:.3}",
2070                x[j],
2071                x[j],
2072                result[j]
2073            );
2074        }
2075    }
2076
2077    #[test]
2078    fn test_cumulative_trapz_linear() {
2079        // ∫₀ᵗ s ds = t²/2
2080        let m = 100;
2081        let x = uniform_grid(m);
2082        let y: Vec<f64> = x.clone();
2083        let result = cumulative_trapz(&y, &x);
2084        for j in 1..m {
2085            let expected = x[j] * x[j] / 2.0;
2086            assert!(
2087                (result[j] - expected).abs() < 1e-4,
2088                "∫₀^{:.3} s ds: expected {expected:.6}, got {:.6}",
2089                x[j],
2090                result[j]
2091            );
2092        }
2093    }
2094
2095    // ── normalize_warp ──
2096
2097    #[test]
2098    fn test_normalize_warp_fixes_boundaries() {
2099        let t = uniform_grid(10);
2100        let mut gamma = vec![0.1; 10]; // constant, wrong boundaries
2101        normalize_warp(&mut gamma, &t);
2102        assert_eq!(gamma[0], t[0]);
2103        assert_eq!(gamma[9], t[9]);
2104    }
2105
2106    #[test]
2107    fn test_normalize_warp_enforces_monotonicity() {
2108        let t = uniform_grid(5);
2109        let mut gamma = vec![0.0, 0.5, 0.3, 0.8, 1.0]; // non-monotone at index 2
2110        normalize_warp(&mut gamma, &t);
2111        for j in 1..5 {
2112            assert!(
2113                gamma[j] >= gamma[j - 1],
2114                "gamma should be monotone after normalization at j={j}"
2115            );
2116        }
2117    }
2118
2119    #[test]
2120    fn test_normalize_warp_identity_unchanged() {
2121        let t = uniform_grid(20);
2122        let mut gamma = t.clone();
2123        normalize_warp(&mut gamma, &t);
2124        for j in 0..20 {
2125            assert!(
2126                (gamma[j] - t[j]).abs() < 1e-15,
2127                "Identity warp should be unchanged"
2128            );
2129        }
2130    }
2131
2132    // ── linear_interp ──
2133
2134    #[test]
2135    fn test_linear_interp_at_nodes() {
2136        let x = vec![0.0, 1.0, 2.0, 3.0];
2137        let y = vec![0.0, 2.0, 4.0, 6.0];
2138        for i in 0..x.len() {
2139            assert!((linear_interp(&x, &y, x[i]) - y[i]).abs() < 1e-12);
2140        }
2141    }
2142
2143    #[test]
2144    fn test_linear_interp_midpoints() {
2145        let x = vec![0.0, 1.0, 2.0];
2146        let y = vec![0.0, 2.0, 4.0];
2147        assert!((linear_interp(&x, &y, 0.5) - 1.0).abs() < 1e-12);
2148        assert!((linear_interp(&x, &y, 1.5) - 3.0).abs() < 1e-12);
2149    }
2150
2151    #[test]
2152    fn test_linear_interp_clamp() {
2153        let x = vec![0.0, 1.0, 2.0];
2154        let y = vec![1.0, 3.0, 5.0];
2155        assert!((linear_interp(&x, &y, -1.0) - 1.0).abs() < 1e-12);
2156        assert!((linear_interp(&x, &y, 3.0) - 5.0).abs() < 1e-12);
2157    }
2158
2159    #[test]
2160    fn test_linear_interp_nonuniform_grid() {
2161        let x = vec![0.0, 0.1, 0.5, 1.0];
2162        let y = vec![0.0, 1.0, 5.0, 10.0];
2163        // Between 0.1 and 0.5: slope = (5-1)/(0.5-0.1) = 10
2164        let val = linear_interp(&x, &y, 0.3);
2165        let expected = 1.0 + 10.0 * (0.3 - 0.1);
2166        assert!(
2167            (val - expected).abs() < 1e-12,
2168            "Non-uniform interp: expected {expected}, got {val}"
2169        );
2170    }
2171
2172    #[test]
2173    fn test_linear_interp_two_points() {
2174        let x = vec![0.0, 1.0];
2175        let y = vec![3.0, 7.0];
2176        assert!((linear_interp(&x, &y, 0.25) - 4.0).abs() < 1e-12);
2177        assert!((linear_interp(&x, &y, 0.75) - 6.0).abs() < 1e-12);
2178    }
2179
2180    // ── SRSF transform ──
2181
2182    #[test]
2183    fn test_srsf_transform_linear() {
2184        // f(t) = 2t: derivative = 2, SRSF = sqrt(2)
2185        let m = 50;
2186        let t = uniform_grid(m);
2187        let f: Vec<f64> = t.iter().map(|&ti| 2.0 * ti).collect();
2188        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
2189
2190        let q_mat = srsf_transform(&mat, &t);
2191        let q: Vec<f64> = q_mat.row(0);
2192
2193        let expected = 2.0_f64.sqrt();
2194        // Interior points should be close to sqrt(2)
2195        for j in 2..(m - 2) {
2196            assert!(
2197                (q[j] - expected).abs() < 0.1,
2198                "q[{j}] = {}, expected ~{expected}",
2199                q[j]
2200            );
2201        }
2202    }
2203
2204    #[test]
2205    fn test_srsf_transform_preserves_shape() {
2206        let data = make_test_data(10, 50, 42);
2207        let t = uniform_grid(50);
2208        let q = srsf_transform(&data, &t);
2209        assert_eq!(q.shape(), data.shape());
2210    }
2211
2212    #[test]
2213    fn test_srsf_transform_constant_is_zero() {
2214        // f(t) = 5 (constant): derivative = 0, SRSF = 0
2215        let m = 30;
2216        let t = uniform_grid(m);
2217        let f = vec![5.0; m];
2218        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
2219        let q_mat = srsf_transform(&mat, &t);
2220        let q: Vec<f64> = q_mat.row(0);
2221
2222        for j in 0..m {
2223            assert!(
2224                q[j].abs() < 1e-10,
2225                "SRSF of constant should be 0, got q[{j}] = {}",
2226                q[j]
2227            );
2228        }
2229    }
2230
2231    #[test]
2232    fn test_srsf_transform_negative_slope() {
2233        // f(t) = -3t: derivative = -3, SRSF = -sqrt(3)
2234        let m = 50;
2235        let t = uniform_grid(m);
2236        let f: Vec<f64> = t.iter().map(|&ti| -3.0 * ti).collect();
2237        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
2238
2239        let q_mat = srsf_transform(&mat, &t);
2240        let q: Vec<f64> = q_mat.row(0);
2241
2242        let expected = -(3.0_f64.sqrt());
2243        for j in 2..(m - 2) {
2244            assert!(
2245                (q[j] - expected).abs() < 0.15,
2246                "q[{j}] = {}, expected ~{expected}",
2247                q[j]
2248            );
2249        }
2250    }
2251
2252    #[test]
2253    fn test_srsf_transform_empty_input() {
2254        let data = FdMatrix::zeros(0, 0);
2255        let t: Vec<f64> = vec![];
2256        let q = srsf_transform(&data, &t);
2257        assert_eq!(q.shape(), (0, 0));
2258    }
2259
2260    #[test]
2261    fn test_srsf_transform_multiple_curves() {
2262        let m = 40;
2263        let t = uniform_grid(m);
2264        let data = make_test_data(5, m, 42);
2265
2266        let q = srsf_transform(&data, &t);
2267        assert_eq!(q.shape(), (5, m));
2268
2269        // Each row should have finite values
2270        for i in 0..5 {
2271            for j in 0..m {
2272                assert!(q[(i, j)].is_finite(), "SRSF should be finite at ({i},{j})");
2273            }
2274        }
2275    }
2276
2277    // ── SRSF inverse ──
2278
2279    #[test]
2280    fn test_srsf_round_trip() {
2281        let m = 100;
2282        let t = uniform_grid(m);
2283        // Use a smooth function
2284        let f: Vec<f64> = t
2285            .iter()
2286            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin() + ti)
2287            .collect();
2288
2289        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
2290        let q_mat = srsf_transform(&mat, &t);
2291        let q: Vec<f64> = q_mat.row(0);
2292
2293        let f_recon = srsf_inverse(&q, &t, f[0]);
2294
2295        // Check reconstruction is close (interior points, avoid boundary effects)
2296        let max_err: f64 = f[5..(m - 5)]
2297            .iter()
2298            .zip(f_recon[5..(m - 5)].iter())
2299            .map(|(a, b)| (a - b).abs())
2300            .fold(0.0_f64, f64::max);
2301
2302        assert!(
2303            max_err < 0.15,
2304            "Round-trip error too large: max_err = {max_err}"
2305        );
2306    }
2307
2308    #[test]
2309    fn test_srsf_inverse_empty() {
2310        let q: Vec<f64> = vec![];
2311        let t: Vec<f64> = vec![];
2312        let result = srsf_inverse(&q, &t, 0.0);
2313        assert!(result.is_empty());
2314    }
2315
2316    #[test]
2317    fn test_srsf_inverse_preserves_initial_value() {
2318        let m = 50;
2319        let t = uniform_grid(m);
2320        let q = vec![1.0; m]; // constant SRSF
2321        let f0 = 3.15;
2322        let f = srsf_inverse(&q, &t, f0);
2323        assert!((f[0] - f0).abs() < 1e-12, "srsf_inverse should start at f0");
2324    }
2325
2326    #[test]
2327    fn test_srsf_round_trip_multiple_curves() {
2328        let m = 80;
2329        let t = uniform_grid(m);
2330        let data = make_test_data(5, m, 99);
2331
2332        let q_mat = srsf_transform(&data, &t);
2333
2334        for i in 0..5 {
2335            let fi = data.row(i);
2336            let qi = q_mat.row(i);
2337            let f_recon = srsf_inverse(&qi, &t, fi[0]);
2338            let max_err: f64 = fi[5..(m - 5)]
2339                .iter()
2340                .zip(f_recon[5..(m - 5)].iter())
2341                .map(|(a, b)| (a - b).abs())
2342                .fold(0.0_f64, f64::max);
2343            assert!(max_err < 0.3, "Round-trip curve {i}: max_err = {max_err}");
2344        }
2345    }
2346
2347    // ── Reparameterize ──
2348
2349    #[test]
2350    fn test_reparameterize_identity_warp() {
2351        let m = 50;
2352        let t = uniform_grid(m);
2353        let f: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
2354
2355        // Identity warp: γ(t) = t
2356        let result = reparameterize_curve(&f, &t, &t);
2357        for j in 0..m {
2358            assert!(
2359                (result[j] - f[j]).abs() < 1e-12,
2360                "Identity warp should return original at j={j}"
2361            );
2362        }
2363    }
2364
2365    #[test]
2366    fn test_reparameterize_linear_warp() {
2367        let m = 50;
2368        let t = uniform_grid(m);
2369        // f(t) = t (linear), γ(t) = t^2 (quadratic warp on [0,1])
2370        let f: Vec<f64> = t.clone();
2371        let gamma: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
2372
2373        let result = reparameterize_curve(&f, &t, &gamma);
2374
2375        // f(γ(t)) = γ(t) = t^2 for a linear f(t) = t
2376        for j in 0..m {
2377            assert!(
2378                (result[j] - gamma[j]).abs() < 1e-10,
2379                "f(gamma(t)) should be gamma(t) for f(t)=t at j={j}"
2380            );
2381        }
2382    }
2383
2384    #[test]
2385    fn test_reparameterize_sine_with_quadratic_warp() {
2386        let m = 100;
2387        let t = uniform_grid(m);
2388        let f: Vec<f64> = t
2389            .iter()
2390            .map(|&ti| (std::f64::consts::PI * ti).sin())
2391            .collect();
2392        let gamma: Vec<f64> = t.iter().map(|&ti| ti * ti).collect(); // speeds up start
2393
2394        let result = reparameterize_curve(&f, &t, &gamma);
2395
2396        // f(γ(t)) = sin(π * t²); check a few known values
2397        for j in 0..m {
2398            let expected = (std::f64::consts::PI * gamma[j]).sin();
2399            assert!(
2400                (result[j] - expected).abs() < 0.05,
2401                "sin(π γ(t)) at j={j}: expected {expected:.4}, got {:.4}",
2402                result[j]
2403            );
2404        }
2405    }
2406
2407    #[test]
2408    fn test_reparameterize_preserves_length() {
2409        let m = 50;
2410        let t = uniform_grid(m);
2411        let f = vec![1.0; m];
2412        let gamma: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
2413
2414        let result = reparameterize_curve(&f, &t, &gamma);
2415        assert_eq!(result.len(), m);
2416    }
2417
2418    // ── Compose warps ──
2419
2420    #[test]
2421    fn test_compose_warps_identity() {
2422        let m = 50;
2423        let t = uniform_grid(m);
2424        // γ(t) = t^0.5 (a warp on [0,1])
2425        let gamma: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
2426
2427        // identity ∘ γ = γ
2428        let result = compose_warps(&t, &gamma, &t);
2429        for j in 0..m {
2430            assert!(
2431                (result[j] - gamma[j]).abs() < 1e-10,
2432                "id ∘ γ should be γ at j={j}"
2433            );
2434        }
2435
2436        // γ ∘ identity = γ
2437        let result2 = compose_warps(&gamma, &t, &t);
2438        for j in 0..m {
2439            assert!(
2440                (result2[j] - gamma[j]).abs() < 1e-10,
2441                "γ ∘ id should be γ at j={j}"
2442            );
2443        }
2444    }
2445
2446    #[test]
2447    fn test_compose_warps_associativity() {
2448        // (γ₁ ∘ γ₂) ∘ γ₃ ≈ γ₁ ∘ (γ₂ ∘ γ₃)
2449        let m = 50;
2450        let t = uniform_grid(m);
2451        let g1: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
2452        let g2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
2453        let g3: Vec<f64> = t.iter().map(|&ti| 0.5 * ti + 0.5 * ti * ti).collect();
2454
2455        let g12 = compose_warps(&g1, &g2, &t);
2456        let left = compose_warps(&g12, &g3, &t); // (g1∘g2) ∘ g3
2457
2458        let g23 = compose_warps(&g2, &g3, &t);
2459        let right = compose_warps(&g1, &g23, &t); // g1 ∘ (g2∘g3)
2460
2461        for j in 0..m {
2462            assert!(
2463                (left[j] - right[j]).abs() < 0.05,
2464                "Composition should be roughly associative at j={j}: left={:.4}, right={:.4}",
2465                left[j],
2466                right[j]
2467            );
2468        }
2469    }
2470
2471    #[test]
2472    fn test_compose_warps_preserves_domain() {
2473        let m = 50;
2474        let t = uniform_grid(m);
2475        let g1: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
2476        let g2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
2477
2478        let composed = compose_warps(&g1, &g2, &t);
2479        assert!(
2480            (composed[0] - t[0]).abs() < 1e-10,
2481            "Composed warp should start at domain start"
2482        );
2483        assert!(
2484            (composed[m - 1] - t[m - 1]).abs() < 1e-10,
2485            "Composed warp should end at domain end"
2486        );
2487    }
2488
2489    // ── Elastic align pair ──
2490
2491    #[test]
2492    fn test_align_identical_curves() {
2493        let m = 50;
2494        let t = uniform_grid(m);
2495        let f: Vec<f64> = t
2496            .iter()
2497            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
2498            .collect();
2499
2500        let result = elastic_align_pair(&f, &f, &t, 0.0);
2501
2502        // Distance should be near zero
2503        assert!(
2504            result.distance < 0.1,
2505            "Distance between identical curves should be near 0, got {}",
2506            result.distance
2507        );
2508
2509        // Warp should be near identity
2510        for j in 0..m {
2511            assert!(
2512                (result.gamma[j] - t[j]).abs() < 0.1,
2513                "Warp should be near identity at j={j}: gamma={}, t={}",
2514                result.gamma[j],
2515                t[j]
2516            );
2517        }
2518    }
2519
2520    #[test]
2521    fn test_align_pair_valid_output() {
2522        let data = make_test_data(2, 50, 42);
2523        let t = uniform_grid(50);
2524        let f1 = data.row(0);
2525        let f2 = data.row(1);
2526
2527        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2528
2529        assert_eq!(result.gamma.len(), 50);
2530        assert_eq!(result.f_aligned.len(), 50);
2531        assert!(result.distance >= 0.0);
2532
2533        // Warp should be monotone
2534        for j in 1..50 {
2535            assert!(
2536                result.gamma[j] >= result.gamma[j - 1],
2537                "Warp should be monotone at j={j}"
2538            );
2539        }
2540    }
2541
2542    #[test]
2543    fn test_align_pair_warp_boundaries() {
2544        let data = make_test_data(2, 50, 42);
2545        let t = uniform_grid(50);
2546        let f1 = data.row(0);
2547        let f2 = data.row(1);
2548
2549        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2550        assert!(
2551            (result.gamma[0] - t[0]).abs() < 1e-12,
2552            "Warp should start at domain start"
2553        );
2554        assert!(
2555            (result.gamma[49] - t[49]).abs() < 1e-12,
2556            "Warp should end at domain end"
2557        );
2558    }
2559
2560    #[test]
2561    fn test_align_shifted_sine() {
2562        // Two sines with a phase shift — alignment should reduce distance
2563        let m = 80;
2564        let t = uniform_grid(m);
2565        let f1: Vec<f64> = t
2566            .iter()
2567            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
2568            .collect();
2569        let f2: Vec<f64> = t
2570            .iter()
2571            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
2572            .collect();
2573
2574        let weights = simpsons_weights(&t);
2575        let l2_before = l2_distance(&f1, &f2, &weights);
2576        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2577        let l2_after = l2_distance(&f1, &result.f_aligned, &weights);
2578
2579        assert!(
2580            l2_after < l2_before + 0.01,
2581            "Alignment should not increase L2 distance: before={l2_before:.4}, after={l2_after:.4}"
2582        );
2583    }
2584
2585    #[test]
2586    fn test_align_pair_aligned_curve_is_finite() {
2587        let data = make_test_data(2, 50, 77);
2588        let t = uniform_grid(50);
2589        let f1 = data.row(0);
2590        let f2 = data.row(1);
2591
2592        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2593        for j in 0..50 {
2594            assert!(
2595                result.f_aligned[j].is_finite(),
2596                "Aligned curve should be finite at j={j}"
2597            );
2598        }
2599    }
2600
2601    #[test]
2602    fn test_align_pair_minimum_grid() {
2603        // Minimum viable grid: m = 2
2604        let t = vec![0.0, 1.0];
2605        let f1 = vec![0.0, 1.0];
2606        let f2 = vec![0.0, 2.0];
2607        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2608        assert_eq!(result.gamma.len(), 2);
2609        assert_eq!(result.f_aligned.len(), 2);
2610        assert!(result.distance >= 0.0);
2611    }
2612
2613    // ── Elastic distance ──
2614
2615    #[test]
2616    fn test_elastic_distance_symmetric() {
2617        let data = make_test_data(3, 50, 42);
2618        let t = uniform_grid(50);
2619        let f1 = data.row(0);
2620        let f2 = data.row(1);
2621
2622        let d12 = elastic_distance(&f1, &f2, &t, 0.0);
2623        let d21 = elastic_distance(&f2, &f1, &t, 0.0);
2624
2625        // Should be approximately symmetric (DP is not perfectly symmetric)
2626        assert!(
2627            (d12 - d21).abs() < d12.max(d21) * 0.3 + 0.01,
2628            "Elastic distance should be roughly symmetric: d12={d12}, d21={d21}"
2629        );
2630    }
2631
2632    #[test]
2633    fn test_elastic_distance_nonneg() {
2634        let data = make_test_data(3, 50, 42);
2635        let t = uniform_grid(50);
2636
2637        for i in 0..3 {
2638            for j in 0..3 {
2639                let fi = data.row(i);
2640                let fj = data.row(j);
2641                let d = elastic_distance(&fi, &fj, &t, 0.0);
2642                assert!(d >= 0.0, "Elastic distance should be non-negative");
2643            }
2644        }
2645    }
2646
2647    #[test]
2648    fn test_elastic_distance_self_near_zero() {
2649        let data = make_test_data(3, 50, 42);
2650        let t = uniform_grid(50);
2651
2652        for i in 0..3 {
2653            let fi = data.row(i);
2654            let d = elastic_distance(&fi, &fi, &t, 0.0);
2655            assert!(
2656                d < 0.1,
2657                "Self-distance should be near zero, got {d} for curve {i}"
2658            );
2659        }
2660    }
2661
2662    #[test]
2663    fn test_elastic_distance_triangle_inequality() {
2664        let data = make_test_data(3, 50, 42);
2665        let t = uniform_grid(50);
2666        let f0 = data.row(0);
2667        let f1 = data.row(1);
2668        let f2 = data.row(2);
2669
2670        let d01 = elastic_distance(&f0, &f1, &t, 0.0);
2671        let d12 = elastic_distance(&f1, &f2, &t, 0.0);
2672        let d02 = elastic_distance(&f0, &f2, &t, 0.0);
2673
2674        // Relaxed triangle inequality (DP alignment is approximate)
2675        let slack = 0.5;
2676        assert!(
2677            d02 <= d01 + d12 + slack,
2678            "Triangle inequality (relaxed): d02={d02:.4} > d01={d01:.4} + d12={d12:.4} + {slack}"
2679        );
2680    }
2681
2682    #[test]
2683    fn test_elastic_distance_different_shapes_nonzero() {
2684        let m = 50;
2685        let t = uniform_grid(m);
2686        let f1: Vec<f64> = t.to_vec(); // linear
2687        let f2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect(); // quadratic
2688
2689        let d = elastic_distance(&f1, &f2, &t, 0.0);
2690        assert!(
2691            d > 0.01,
2692            "Distance between different shapes should be > 0, got {d}"
2693        );
2694    }
2695
2696    // ── Self distance matrix ──
2697
2698    #[test]
2699    fn test_self_distance_matrix_symmetric() {
2700        let data = make_test_data(5, 30, 42);
2701        let t = uniform_grid(30);
2702
2703        let dm = elastic_self_distance_matrix(&data, &t, 0.0);
2704        let n = dm.nrows();
2705
2706        assert_eq!(dm.shape(), (5, 5));
2707
2708        // Zero diagonal
2709        for i in 0..n {
2710            assert!(dm[(i, i)].abs() < 1e-12, "Diagonal should be zero");
2711        }
2712
2713        // Symmetric
2714        for i in 0..n {
2715            for j in (i + 1)..n {
2716                assert!(
2717                    (dm[(i, j)] - dm[(j, i)]).abs() < 1e-12,
2718                    "Matrix should be symmetric at ({i},{j})"
2719                );
2720            }
2721        }
2722    }
2723
2724    #[test]
2725    fn test_self_distance_matrix_nonneg() {
2726        let data = make_test_data(4, 30, 42);
2727        let t = uniform_grid(30);
2728        let dm = elastic_self_distance_matrix(&data, &t, 0.0);
2729
2730        for i in 0..4 {
2731            for j in 0..4 {
2732                assert!(
2733                    dm[(i, j)] >= 0.0,
2734                    "Distance matrix entries should be non-negative at ({i},{j})"
2735                );
2736            }
2737        }
2738    }
2739
2740    #[test]
2741    fn test_self_distance_matrix_single_curve() {
2742        let data = make_test_data(1, 30, 42);
2743        let t = uniform_grid(30);
2744        let dm = elastic_self_distance_matrix(&data, &t, 0.0);
2745        assert_eq!(dm.shape(), (1, 1));
2746        assert!(dm[(0, 0)].abs() < 1e-12);
2747    }
2748
2749    #[test]
2750    fn test_self_distance_matrix_consistent_with_pairwise() {
2751        let data = make_test_data(4, 30, 42);
2752        let t = uniform_grid(30);
2753
2754        let dm = elastic_self_distance_matrix(&data, &t, 0.0);
2755
2756        // Check a few entries match direct elastic_distance calls
2757        for i in 0..4 {
2758            for j in (i + 1)..4 {
2759                let fi = data.row(i);
2760                let fj = data.row(j);
2761                let d_direct = elastic_distance(&fi, &fj, &t, 0.0);
2762                assert!(
2763                    (dm[(i, j)] - d_direct).abs() < 1e-10,
2764                    "Matrix entry ({i},{j})={:.6} should match pairwise {d_direct:.6}",
2765                    dm[(i, j)]
2766                );
2767            }
2768        }
2769    }
2770
2771    // ── Karcher mean ──
2772
2773    #[test]
2774    fn test_karcher_mean_identical_curves() {
2775        let m = 50;
2776        let t = uniform_grid(m);
2777        let f: Vec<f64> = t
2778            .iter()
2779            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
2780            .collect();
2781
2782        // Create 5 identical curves
2783        let mut data = FdMatrix::zeros(5, m);
2784        for i in 0..5 {
2785            for j in 0..m {
2786                data[(i, j)] = f[j];
2787            }
2788        }
2789
2790        let result = karcher_mean(&data, &t, 10, 1e-4, 0.0);
2791
2792        assert_eq!(result.mean.len(), m);
2793        assert!(result.n_iter <= 10);
2794    }
2795
2796    #[test]
2797    fn test_karcher_mean_output_shape() {
2798        let data = make_test_data(15, 50, 42);
2799        let t = uniform_grid(50);
2800
2801        let result = karcher_mean(&data, &t, 5, 1e-3, 0.0);
2802
2803        assert_eq!(result.mean.len(), 50);
2804        assert_eq!(result.mean_srsf.len(), 50);
2805        assert_eq!(result.gammas.shape(), (15, 50));
2806        assert_eq!(result.aligned_data.shape(), (15, 50));
2807        assert!(result.n_iter <= 5);
2808    }
2809
2810    #[test]
2811    fn test_karcher_mean_warps_are_valid() {
2812        let data = make_test_data(10, 40, 42);
2813        let t = uniform_grid(40);
2814
2815        let result = karcher_mean(&data, &t, 5, 1e-3, 0.0);
2816
2817        for i in 0..10 {
2818            // Boundary values
2819            assert!(
2820                (result.gammas[(i, 0)] - t[0]).abs() < 1e-10,
2821                "Warp {i} should start at domain start"
2822            );
2823            assert!(
2824                (result.gammas[(i, 39)] - t[39]).abs() < 1e-10,
2825                "Warp {i} should end at domain end"
2826            );
2827            // Monotonicity
2828            for j in 1..40 {
2829                assert!(
2830                    result.gammas[(i, j)] >= result.gammas[(i, j - 1)],
2831                    "Warp {i} should be monotone at j={j}"
2832                );
2833            }
2834        }
2835    }
2836
2837    #[test]
2838    fn test_karcher_mean_aligned_data_is_finite() {
2839        let data = make_test_data(8, 40, 42);
2840        let t = uniform_grid(40);
2841        let result = karcher_mean(&data, &t, 5, 1e-3, 0.0);
2842
2843        for i in 0..8 {
2844            for j in 0..40 {
2845                assert!(
2846                    result.aligned_data[(i, j)].is_finite(),
2847                    "Aligned data should be finite at ({i},{j})"
2848                );
2849            }
2850        }
2851    }
2852
2853    #[test]
2854    fn test_karcher_mean_srsf_is_finite() {
2855        let data = make_test_data(8, 40, 42);
2856        let t = uniform_grid(40);
2857        let result = karcher_mean(&data, &t, 5, 1e-3, 0.0);
2858
2859        for j in 0..40 {
2860            assert!(
2861                result.mean_srsf[j].is_finite(),
2862                "Mean SRSF should be finite at j={j}"
2863            );
2864            assert!(
2865                result.mean[j].is_finite(),
2866                "Mean curve should be finite at j={j}"
2867            );
2868        }
2869    }
2870
2871    #[test]
2872    fn test_karcher_mean_single_iteration() {
2873        let data = make_test_data(10, 40, 42);
2874        let t = uniform_grid(40);
2875        let result = karcher_mean(&data, &t, 1, 1e-10, 0.0);
2876
2877        assert_eq!(result.n_iter, 1);
2878        assert_eq!(result.mean.len(), 40);
2879        // With only 1 iteration, still produces valid output
2880        for j in 0..40 {
2881            assert!(result.mean[j].is_finite());
2882        }
2883    }
2884
2885    #[test]
2886    fn test_karcher_mean_convergence_not_premature() {
2887        // The old convergence criterion (rel - prev_rel <= tol * prev_rel) was
2888        // always satisfied when the algorithm improved, causing premature exit
2889        // after 2 iterations. With the fix (rel < tol), the algorithm should
2890        // actually iterate until convergence or hitting the iteration cap.
2891        let n = 10;
2892        let m = 50;
2893        let t = uniform_grid(m);
2894
2895        // Create phase-shifted curves that genuinely need alignment
2896        let mut col_major = vec![0.0; n * m];
2897        for i in 0..n {
2898            let shift = (i as f64 - 5.0) * 0.05;
2899            for j in 0..m {
2900                col_major[i + j * n] = (2.0 * std::f64::consts::PI * (t[j] - shift)).sin();
2901            }
2902        }
2903        let data = FdMatrix::from_column_major(col_major, n, m).unwrap();
2904
2905        // With an impossibly tight tolerance, the algorithm should hit the
2906        // iteration cap rather than "converging" after 2 iterations.
2907        let result = karcher_mean(&data, &t, 20, 1e-15, 0.0);
2908        assert!(
2909            result.n_iter > 2,
2910            "With tol=1e-15 the algorithm should iterate beyond 2, got n_iter={}",
2911            result.n_iter
2912        );
2913
2914        // With a reasonable tolerance, it should converge and report so
2915        let result_loose = karcher_mean(&data, &t, 20, 1e-2, 0.0);
2916        assert!(
2917            result_loose.converged,
2918            "With tol=1e-2 the algorithm should converge"
2919        );
2920    }
2921
2922    // ── Align to target ──
2923
2924    #[test]
2925    fn test_align_to_target_valid() {
2926        let data = make_test_data(10, 40, 42);
2927        let t = uniform_grid(40);
2928        let target = data.row(0);
2929
2930        let result = align_to_target(&data, &target, &t, 0.0);
2931
2932        assert_eq!(result.gammas.shape(), (10, 40));
2933        assert_eq!(result.aligned_data.shape(), (10, 40));
2934        assert_eq!(result.distances.len(), 10);
2935
2936        // All distances should be non-negative
2937        for &d in &result.distances {
2938            assert!(d >= 0.0);
2939        }
2940    }
2941
2942    #[test]
2943    fn test_align_to_target_self_near_zero() {
2944        let data = make_test_data(5, 40, 42);
2945        let t = uniform_grid(40);
2946        let target = data.row(0);
2947
2948        let result = align_to_target(&data, &target, &t, 0.0);
2949
2950        // Distance of target to itself should be near zero
2951        assert!(
2952            result.distances[0] < 0.1,
2953            "Self-alignment distance should be near zero, got {}",
2954            result.distances[0]
2955        );
2956    }
2957
2958    #[test]
2959    fn test_align_to_target_warps_are_monotone() {
2960        let data = make_test_data(8, 40, 42);
2961        let t = uniform_grid(40);
2962        let target = data.row(0);
2963        let result = align_to_target(&data, &target, &t, 0.0);
2964
2965        for i in 0..8 {
2966            for j in 1..40 {
2967                assert!(
2968                    result.gammas[(i, j)] >= result.gammas[(i, j - 1)],
2969                    "Warp for curve {i} should be monotone at j={j}"
2970                );
2971            }
2972        }
2973    }
2974
2975    #[test]
2976    fn test_align_to_target_aligned_data_finite() {
2977        let data = make_test_data(6, 40, 42);
2978        let t = uniform_grid(40);
2979        let target = data.row(0);
2980        let result = align_to_target(&data, &target, &t, 0.0);
2981
2982        for i in 0..6 {
2983            for j in 0..40 {
2984                assert!(
2985                    result.aligned_data[(i, j)].is_finite(),
2986                    "Aligned data should be finite at ({i},{j})"
2987                );
2988            }
2989        }
2990    }
2991
2992    // ── Cross distance matrix ──
2993
2994    #[test]
2995    fn test_cross_distance_matrix_shape() {
2996        let data1 = make_test_data(3, 30, 42);
2997        let data2 = make_test_data(4, 30, 99);
2998        let t = uniform_grid(30);
2999
3000        let dm = elastic_cross_distance_matrix(&data1, &data2, &t, 0.0);
3001        assert_eq!(dm.shape(), (3, 4));
3002
3003        // All non-negative
3004        for i in 0..3 {
3005            for j in 0..4 {
3006                assert!(dm[(i, j)] >= 0.0);
3007            }
3008        }
3009    }
3010
3011    #[test]
3012    fn test_cross_distance_matrix_self_matches_self_matrix() {
3013        // cross_distance(data, data) should have zero diagonal (approximately)
3014        let data = make_test_data(4, 30, 42);
3015        let t = uniform_grid(30);
3016
3017        let cross = elastic_cross_distance_matrix(&data, &data, &t, 0.0);
3018        for i in 0..4 {
3019            assert!(
3020                cross[(i, i)] < 0.1,
3021                "Cross distance (self) diagonal should be near zero: got {}",
3022                cross[(i, i)]
3023            );
3024        }
3025    }
3026
3027    #[test]
3028    fn test_cross_distance_matrix_consistent_with_pairwise() {
3029        let data1 = make_test_data(3, 30, 42);
3030        let data2 = make_test_data(2, 30, 99);
3031        let t = uniform_grid(30);
3032
3033        let dm = elastic_cross_distance_matrix(&data1, &data2, &t, 0.0);
3034
3035        for i in 0..3 {
3036            for j in 0..2 {
3037                let fi = data1.row(i);
3038                let fj = data2.row(j);
3039                let d_direct = elastic_distance(&fi, &fj, &t, 0.0);
3040                assert!(
3041                    (dm[(i, j)] - d_direct).abs() < 1e-10,
3042                    "Cross matrix ({i},{j})={:.6} should match pairwise {d_direct:.6}",
3043                    dm[(i, j)]
3044                );
3045            }
3046        }
3047    }
3048
3049    // ── align_srsf_pair ──
3050
3051    #[test]
3052    fn test_align_srsf_pair_identity() {
3053        let m = 50;
3054        let t = uniform_grid(m);
3055        let f: Vec<f64> = t
3056            .iter()
3057            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3058            .collect();
3059        let q = srsf_single(&f, &t);
3060
3061        let (gamma, q_aligned) = align_srsf_pair(&q, &q, &t, 0.0);
3062
3063        // Warp should be near identity
3064        for j in 0..m {
3065            assert!(
3066                (gamma[j] - t[j]).abs() < 0.15,
3067                "Self-SRSF alignment warp should be near identity at j={j}"
3068            );
3069        }
3070
3071        // Aligned SRSF should be close to original
3072        let weights = simpsons_weights(&t);
3073        let dist = l2_distance(&q, &q_aligned, &weights);
3074        assert!(
3075            dist < 0.5,
3076            "Self-aligned SRSF distance should be small, got {dist}"
3077        );
3078    }
3079
3080    // ── srsf_single ──
3081
3082    #[test]
3083    fn test_srsf_single_matches_matrix_version() {
3084        let m = 50;
3085        let t = uniform_grid(m);
3086        let f: Vec<f64> = t.iter().map(|&ti| ti * ti + ti).collect();
3087
3088        let q_single = srsf_single(&f, &t);
3089
3090        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
3091        let q_mat = srsf_transform(&mat, &t);
3092        let q_from_mat = q_mat.row(0);
3093
3094        for j in 0..m {
3095            assert!(
3096                (q_single[j] - q_from_mat[j]).abs() < 1e-12,
3097                "srsf_single should match srsf_transform at j={j}"
3098            );
3099        }
3100    }
3101
3102    // ── gcd ──
3103
3104    #[test]
3105    fn test_gcd_basic() {
3106        assert_eq!(gcd(1, 1), 1);
3107        assert_eq!(gcd(6, 4), 2);
3108        assert_eq!(gcd(7, 5), 1);
3109        assert_eq!(gcd(12, 8), 4);
3110        assert_eq!(gcd(7, 0), 7);
3111        assert_eq!(gcd(0, 5), 5);
3112    }
3113
3114    // ── generate_coprime_nbhd ──
3115
3116    #[test]
3117    fn test_coprime_nbhd_count() {
3118        assert_eq!(generate_coprime_nbhd(1).len(), 1); // just (1,1)
3119        assert_eq!(generate_coprime_nbhd(7).len(), 35);
3120    }
3121
3122    #[test]
3123    fn test_coprime_nbhd_matches_const() {
3124        let generated = generate_coprime_nbhd(7);
3125        assert_eq!(generated.len(), COPRIME_NBHD_7.len());
3126        for (i, pair) in generated.iter().enumerate() {
3127            assert_eq!(*pair, COPRIME_NBHD_7[i], "mismatch at index {i}");
3128        }
3129    }
3130
3131    #[test]
3132    fn test_coprime_nbhd_all_coprime() {
3133        for &(i, j) in &COPRIME_NBHD_7 {
3134            assert_eq!(gcd(i, j), 1, "({i},{j}) should be coprime");
3135            assert!((1..=7).contains(&i));
3136            assert!((1..=7).contains(&j));
3137        }
3138    }
3139
3140    // ── dp_edge_weight ──
3141
3142    #[test]
3143    fn test_dp_edge_weight_diagonal() {
3144        // Diagonal move (1,1): weight = (q1[sc] - sqrt(1)*q2[sr])^2 * h
3145        let t = uniform_grid(10);
3146        let q1 = vec![1.0; 10];
3147        let q2 = vec![1.0; 10];
3148        // Identical SRSFs: weight should be 0
3149        let w = dp_edge_weight(&q1, &q2, &t, 0, 1, 0, 1);
3150        assert!(w.abs() < 1e-12, "identical SRSFs should have zero cost");
3151    }
3152
3153    #[test]
3154    fn test_dp_edge_weight_non_diagonal() {
3155        // Move (1,2): n1=2, n2=1, slope = h/(2h) = 0.5
3156        let t = uniform_grid(10);
3157        let q1 = vec![1.0; 10];
3158        let q2 = vec![0.0; 10];
3159        let w = dp_edge_weight(&q1, &q2, &t, 0, 2, 0, 1);
3160        // diff = q1[0] - sqrt(0.5)*q2[0] = 1.0 - 0 = 1.0
3161        // weight = 1.0^2 * 1.0 * (t[2]-t[0]) = 2/9
3162        let expected = 2.0 / 9.0;
3163        assert!(
3164            (w - expected).abs() < 1e-10,
3165            "dp_edge_weight (1,2): expected {expected}, got {w}"
3166        );
3167    }
3168
3169    #[test]
3170    fn test_dp_edge_weight_zero_span() {
3171        let t = uniform_grid(10);
3172        let q1 = vec![1.0; 10];
3173        let q2 = vec![1.0; 10];
3174        // n1=0: should return INFINITY
3175        assert_eq!(dp_edge_weight(&q1, &q2, &t, 3, 3, 0, 1), f64::INFINITY);
3176        // n2=0: should return INFINITY
3177        assert_eq!(dp_edge_weight(&q1, &q2, &t, 0, 1, 3, 3), f64::INFINITY);
3178    }
3179
3180    // ── DP alignment quality ──
3181
3182    #[test]
3183    fn test_alignment_improves_distance() {
3184        // Aligned SRSF distance should be less than unaligned SRSF distance
3185        let m = 50;
3186        let t = uniform_grid(m);
3187        let f1: Vec<f64> = t
3188            .iter()
3189            .map(|&x| (2.0 * std::f64::consts::PI * x).sin())
3190            .collect();
3191        // Use a larger shift so improvement is clear
3192        let f2: Vec<f64> = t
3193            .iter()
3194            .map(|&x| (2.0 * std::f64::consts::PI * (x + 0.2)).sin())
3195            .collect();
3196
3197        let q1 = srsf_single(&f1, &t);
3198        let q2 = srsf_single(&f2, &t);
3199        let weights = simpsons_weights(&t);
3200        let unaligned_srsf_dist = l2_distance(&q1, &q2, &weights);
3201
3202        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
3203
3204        assert!(
3205            result.distance <= unaligned_srsf_dist + 1e-6,
3206            "aligned SRSF dist ({}) should be <= unaligned SRSF dist ({})",
3207            result.distance,
3208            unaligned_srsf_dist
3209        );
3210    }
3211
3212    // ── Edge case: constant data ──
3213
3214    #[test]
3215    fn test_alignment_constant_curves() {
3216        let m = 30;
3217        let t = uniform_grid(m);
3218        let f1 = vec![5.0; m];
3219        let f2 = vec![5.0; m];
3220
3221        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
3222        assert!(
3223            result.distance < 0.01,
3224            "Constant curves: distance should be ~0"
3225        );
3226        assert_eq!(result.f_aligned.len(), m);
3227    }
3228
3229    #[test]
3230    fn test_karcher_mean_constant_curves() {
3231        let m = 30;
3232        let t = uniform_grid(m);
3233        let mut data = FdMatrix::zeros(5, m);
3234        for i in 0..5 {
3235            for j in 0..m {
3236                data[(i, j)] = 3.0;
3237            }
3238        }
3239
3240        let result = karcher_mean(&data, &t, 5, 1e-4, 0.0);
3241        for j in 0..m {
3242            assert!(
3243                (result.mean[j] - 3.0).abs() < 0.5,
3244                "Mean of constant curves should be near 3.0, got {} at j={j}",
3245                result.mean[j]
3246            );
3247        }
3248    }
3249
3250    #[test]
3251    fn test_nan_srsf_no_panic() {
3252        let m = 20;
3253        let t = uniform_grid(m);
3254        let mut f = vec![1.0; m];
3255        f[5] = f64::NAN;
3256        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
3257        let q = srsf_transform(&mat, &t);
3258        // Should not panic; NaN propagates
3259        assert_eq!(q.nrows(), 1);
3260    }
3261
3262    #[test]
3263    fn test_n1_karcher_mean() {
3264        let m = 30;
3265        let t = uniform_grid(m);
3266        let f: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
3267        let data = FdMatrix::from_slice(&f, 1, m).unwrap();
3268        let result = karcher_mean(&data, &t, 5, 1e-4, 0.0);
3269        assert_eq!(result.mean.len(), m);
3270        // With only 1 curve, the mean should be close to the original
3271        for j in 0..m {
3272            assert!(result.mean[j].is_finite());
3273        }
3274    }
3275
3276    #[test]
3277    fn test_two_point_grid() {
3278        let t = vec![0.0, 1.0];
3279        let f1 = vec![0.0, 1.0];
3280        let f2 = vec![0.0, 2.0];
3281        let d = elastic_distance(&f1, &f2, &t, 0.0);
3282        assert!(d >= 0.0);
3283        assert!(d.is_finite());
3284    }
3285
3286    #[test]
3287    fn test_non_uniform_grid_alignment() {
3288        // Non-uniform grid: points clustered near 0
3289        let t = vec![0.0, 0.01, 0.05, 0.2, 0.5, 1.0];
3290        let m = t.len();
3291        let f1: Vec<f64> = t.iter().map(|&ti: &f64| ti.sin()).collect();
3292        let f2: Vec<f64> = t.iter().map(|&ti: &f64| (ti + 0.1).sin()).collect();
3293        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
3294        assert_eq!(result.gamma.len(), m);
3295        assert!(result.distance >= 0.0);
3296        assert!(result.distance.is_finite());
3297    }
3298
3299    // ── TSRVF tests ──
3300
3301    #[test]
3302    fn test_tsrvf_output_shape() {
3303        let m = 50;
3304        let n = 10;
3305        let t = uniform_grid(m);
3306        let data = make_test_data(n, m, 42);
3307        let result = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
3308        assert_eq!(
3309            result.tangent_vectors.shape(),
3310            (n, m),
3311            "Tangent vectors should be n×m"
3312        );
3313        assert_eq!(result.gammas.shape(), (n, m), "Gammas should be n×m");
3314        assert_eq!(result.srsf_norms.len(), n, "Should have n SRSF norms");
3315        assert_eq!(result.mean.len(), m, "Mean should have m points");
3316        assert_eq!(result.mean_srsf.len(), m, "Mean SRSF should have m points");
3317    }
3318
3319    #[test]
3320    fn test_tsrvf_all_finite() {
3321        let m = 50;
3322        let n = 5;
3323        let t = uniform_grid(m);
3324        let data = make_test_data(n, m, 42);
3325        let result = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
3326        for i in 0..n {
3327            for j in 0..m {
3328                assert!(
3329                    result.tangent_vectors[(i, j)].is_finite(),
3330                    "Tangent vector should be finite at ({i},{j})"
3331                );
3332            }
3333            assert!(
3334                result.srsf_norms[i].is_finite(),
3335                "SRSF norm should be finite for curve {i}"
3336            );
3337        }
3338        assert!(
3339            result.mean_srsf_norm.is_finite(),
3340            "Mean SRSF norm should be finite"
3341        );
3342    }
3343
3344    #[test]
3345    fn test_tsrvf_identical_curves_zero_tangent() {
3346        let m = 50;
3347        let t = uniform_grid(m);
3348        // Stack 5 identical curves
3349        let curve: Vec<f64> = t
3350            .iter()
3351            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3352            .collect();
3353        let mut col_major = vec![0.0; 5 * m];
3354        for i in 0..5 {
3355            for j in 0..m {
3356                col_major[i + j * 5] = curve[j];
3357            }
3358        }
3359        let data = FdMatrix::from_column_major(col_major, 5, m).unwrap();
3360        let result = tsrvf_transform(&data, &t, 10, 1e-4, 0.0);
3361
3362        // All tangent vectors should be approximately zero
3363        for i in 0..5 {
3364            let tv_norm_sq: f64 = (0..m).map(|j| result.tangent_vectors[(i, j)].powi(2)).sum();
3365            assert!(
3366                tv_norm_sq.sqrt() < 0.5,
3367                "Identical curves should have near-zero tangent vectors, got norm = {}",
3368                tv_norm_sq.sqrt()
3369            );
3370        }
3371    }
3372
3373    #[test]
3374    fn test_tsrvf_mean_tangent_near_zero() {
3375        let m = 50;
3376        let n = 10;
3377        let t = uniform_grid(m);
3378        let data = make_test_data(n, m, 42);
3379        let result = tsrvf_transform(&data, &t, 10, 1e-3, 0.0);
3380
3381        // Mean of tangent vectors should be approximately zero (property of Karcher mean)
3382        let mut mean_tv = vec![0.0; m];
3383        for i in 0..n {
3384            for j in 0..m {
3385                mean_tv[j] += result.tangent_vectors[(i, j)];
3386            }
3387        }
3388        for j in 0..m {
3389            mean_tv[j] /= n as f64;
3390        }
3391        let mean_norm: f64 = mean_tv.iter().map(|v| v * v).sum::<f64>().sqrt();
3392        assert!(
3393            mean_norm < 1.0,
3394            "Mean tangent vector should be near zero, got norm = {mean_norm}"
3395        );
3396    }
3397
3398    #[test]
3399    fn test_tsrvf_from_alignment() {
3400        let m = 50;
3401        let n = 5;
3402        let t = uniform_grid(m);
3403        let data = make_test_data(n, m, 42);
3404        let karcher = karcher_mean(&data, &t, 5, 1e-3, 0.0);
3405        let result = tsrvf_from_alignment(&karcher, &t);
3406        assert_eq!(result.tangent_vectors.shape(), (n, m));
3407        assert!(result.mean_srsf_norm > 0.0);
3408    }
3409
3410    #[test]
3411    fn test_tsrvf_round_trip() {
3412        let m = 50;
3413        let n = 5;
3414        let t = uniform_grid(m);
3415        let data = make_test_data(n, m, 42);
3416        let result = tsrvf_transform(&data, &t, 10, 1e-3, 0.0);
3417        let reconstructed = tsrvf_inverse(&result, &t);
3418
3419        assert_eq!(reconstructed.shape(), result.tangent_vectors.shape());
3420        // Reconstructed curves should have finite values
3421        for i in 0..n {
3422            for j in 0..m {
3423                assert!(
3424                    reconstructed[(i, j)].is_finite(),
3425                    "Reconstructed curve should be finite at ({i},{j})"
3426                );
3427            }
3428        }
3429        // Issue #12: per-curve initial values should be preserved
3430        for i in 0..n {
3431            assert!(
3432                (reconstructed[(i, 0)] - result.initial_values[i]).abs() < 1e-6,
3433                "Curve {i} initial value: expected {}, got {}",
3434                result.initial_values[i],
3435                reconstructed[(i, 0)]
3436            );
3437        }
3438    }
3439
3440    #[test]
3441    fn test_tsrvf_initial_values_per_curve() {
3442        // Issue #12: tsrvf_inverse must use per-curve initial values, not mean[0]
3443        let m = 50;
3444        let n = 5;
3445        let t = uniform_grid(m);
3446
3447        // Create curves with distinct initial values
3448        let mut col_major = vec![0.0; n * m];
3449        for i in 0..n {
3450            let offset = (i as f64 + 1.0) * 2.0; // offsets: 2, 4, 6, 8, 10
3451            for j in 0..m {
3452                col_major[i + j * n] = offset + (2.0 * std::f64::consts::PI * t[j]).sin();
3453            }
3454        }
3455        let data = FdMatrix::from_column_major(col_major, n, m).unwrap();
3456
3457        let result = tsrvf_transform(&data, &t, 15, 1e-4, 0.0);
3458
3459        // initial_values should differ per curve
3460        assert_eq!(result.initial_values.len(), n);
3461        let all_same = result
3462            .initial_values
3463            .windows(2)
3464            .all(|w| (w[0] - w[1]).abs() < 1e-10);
3465        assert!(
3466            !all_same,
3467            "Initial values should differ per curve: {:?}",
3468            result.initial_values
3469        );
3470
3471        // Reconstruct and check initial values are preserved
3472        let reconstructed = tsrvf_inverse(&result, &t);
3473        for i in 0..n {
3474            assert!(
3475                (reconstructed[(i, 0)] - result.initial_values[i]).abs() < 1e-6,
3476                "Curve {i}: reconstructed f(0) = {}, expected {}",
3477                reconstructed[(i, 0)],
3478                result.initial_values[i]
3479            );
3480        }
3481
3482        // Before the fix, all curves would have reconstructed[(i, 0)] ≈ mean[0]
3483        // Verify they are NOT all the same
3484        let recon_initials: Vec<f64> = (0..n).map(|i| reconstructed[(i, 0)]).collect();
3485        let all_recon_same = recon_initials.windows(2).all(|w| (w[0] - w[1]).abs() < 0.1);
3486        assert!(
3487            !all_recon_same,
3488            "Reconstructed initial values must vary per curve: {:?}",
3489            recon_initials
3490        );
3491    }
3492
3493    #[test]
3494    fn test_tsrvf_single_curve() {
3495        let m = 50;
3496        let t = uniform_grid(m);
3497        let data = make_test_data(1, m, 42);
3498        let result = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
3499        assert_eq!(result.tangent_vectors.shape(), (1, m));
3500        // Single curve → tangent vector should be zero (it IS the mean)
3501        let tv_norm: f64 = (0..m)
3502            .map(|j| result.tangent_vectors[(0, j)].powi(2))
3503            .sum::<f64>()
3504            .sqrt();
3505        assert!(
3506            tv_norm < 0.5,
3507            "Single curve tangent vector should be near zero, got {tv_norm}"
3508        );
3509    }
3510
3511    #[test]
3512    fn test_tsrvf_constant_curves() {
3513        let m = 30;
3514        let t = uniform_grid(m);
3515        // Constant curves → SRSF = 0, norms = 0
3516        let data = FdMatrix::from_column_major(vec![5.0; 3 * m], 3, m).unwrap();
3517        let result = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
3518        // Should not produce NaN or Inf
3519        for i in 0..3 {
3520            for j in 0..m {
3521                let v = result.tangent_vectors[(i, j)];
3522                assert!(
3523                    !v.is_nan(),
3524                    "Constant curves should not produce NaN tangent vectors"
3525                );
3526            }
3527        }
3528    }
3529
3530    // ── Reference-value tests (sphere geometry) ─────────────────────────────
3531
3532    #[test]
3533    fn test_tsrvf_sphere_inv_exp_reference() {
3534        // Analytical reference: known vectors on the Hilbert sphere
3535        // psi1 = constant (unit L2 norm), psi2 = 1 + 0.3*sin(2πt) (normalized)
3536        let m = 21;
3537        let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
3538
3539        // Construct unit vector psi1 (constant)
3540        let raw1 = vec![1.0; m];
3541        let norm1 = inner_product_l2(&raw1, &raw1, &time).max(0.0).sqrt();
3542        let psi1: Vec<f64> = raw1.iter().map(|&v| v / norm1).collect();
3543
3544        // Construct psi2 with sinusoidal perturbation
3545        let raw2: Vec<f64> = time
3546            .iter()
3547            .map(|&t| 1.0 + 0.3 * (2.0 * std::f64::consts::PI * t).sin())
3548            .collect();
3549        let norm2 = inner_product_l2(&raw2, &raw2, &time).max(0.0).sqrt();
3550        let psi2: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();
3551
3552        // Compute theta analytically
3553        let ip = inner_product_l2(&psi1, &psi2, &time).clamp(-1.0, 1.0);
3554        let theta_expected = ip.acos();
3555
3556        // Compute via inv_exp_map_sphere
3557        let v = inv_exp_map_sphere(&psi1, &psi2, &time);
3558        let v_norm = inner_product_l2(&v, &v, &time).max(0.0).sqrt();
3559
3560        // ||v|| should equal theta
3561        assert!(
3562            (v_norm - theta_expected).abs() < 1e-10,
3563            "||v|| = {v_norm}, expected theta = {theta_expected}"
3564        );
3565
3566        // theta should be small but non-zero (perturbation is mild)
3567        assert!(
3568            theta_expected > 0.01 && theta_expected < 1.0,
3569            "theta = {theta_expected} out of expected range"
3570        );
3571    }
3572
3573    #[test]
3574    fn test_tsrvf_sphere_round_trip_reference() {
3575        // Round-trip: exp_map(psi1, inv_exp_map(psi1, psi2)) should recover psi2
3576        let m = 21;
3577        let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
3578
3579        let raw1 = vec![1.0; m];
3580        let norm1 = inner_product_l2(&raw1, &raw1, &time).max(0.0).sqrt();
3581        let psi1: Vec<f64> = raw1.iter().map(|&v| v / norm1).collect();
3582
3583        let raw2: Vec<f64> = time
3584            .iter()
3585            .map(|&t| 1.0 + 0.3 * (2.0 * std::f64::consts::PI * t).sin())
3586            .collect();
3587        let norm2 = inner_product_l2(&raw2, &raw2, &time).max(0.0).sqrt();
3588        let psi2: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();
3589
3590        let v = inv_exp_map_sphere(&psi1, &psi2, &time);
3591        let recovered = exp_map_sphere(&psi1, &v, &time);
3592
3593        // L2 error between psi2 and recovered
3594        let diff: Vec<f64> = psi2
3595            .iter()
3596            .zip(recovered.iter())
3597            .map(|(&a, &b)| (a - b).powi(2))
3598            .collect();
3599        let l2_err = trapz(&diff, &time).max(0.0).sqrt();
3600        assert!(
3601            l2_err < 1e-12,
3602            "Round-trip L2 error = {l2_err:.2e}, expected < 1e-12"
3603        );
3604    }
3605
3606    // ── Penalized alignment ──
3607
3608    #[test]
3609    fn test_penalized_alignment_lambda_zero_matches_unpenalized() {
3610        let m = 50;
3611        let t = uniform_grid(m);
3612        let data = make_test_data(2, m, 42);
3613        let f1 = data.row(0);
3614        let f2 = data.row(1);
3615
3616        let r0 = elastic_align_pair(&f1, &f2, &t, 0.0);
3617        // lambda = 0.0 should produce the same result regardless
3618        assert!(r0.distance >= 0.0);
3619        assert_eq!(r0.gamma.len(), m);
3620    }
3621
3622    #[test]
3623    fn test_penalized_alignment_smoother_warp() {
3624        let m = 80;
3625        let t = uniform_grid(m);
3626        let f1: Vec<f64> = t
3627            .iter()
3628            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3629            .collect();
3630        let f2: Vec<f64> = t
3631            .iter()
3632            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.15)).sin())
3633            .collect();
3634
3635        let r_free = elastic_align_pair(&f1, &f2, &t, 0.0);
3636        let r_pen = elastic_align_pair(&f1, &f2, &t, 1.0);
3637
3638        // Measure warp deviation from identity
3639        let dev_free: f64 = r_free
3640            .gamma
3641            .iter()
3642            .zip(t.iter())
3643            .map(|(g, ti)| (g - ti).powi(2))
3644            .sum();
3645        let dev_pen: f64 = r_pen
3646            .gamma
3647            .iter()
3648            .zip(t.iter())
3649            .map(|(g, ti)| (g - ti).powi(2))
3650            .sum();
3651
3652        assert!(
3653            dev_pen <= dev_free + 1e-6,
3654            "Penalized warp should be closer to identity: free={dev_free:.6}, pen={dev_pen:.6}"
3655        );
3656    }
3657
3658    #[test]
3659    fn test_penalized_alignment_large_lambda_near_identity() {
3660        let m = 50;
3661        let t = uniform_grid(m);
3662        let f1: Vec<f64> = t
3663            .iter()
3664            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3665            .collect();
3666        let f2: Vec<f64> = t
3667            .iter()
3668            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
3669            .collect();
3670
3671        let r = elastic_align_pair(&f1, &f2, &t, 1000.0);
3672
3673        // With very large lambda, warp should be very close to identity
3674        let max_dev: f64 = r
3675            .gamma
3676            .iter()
3677            .zip(t.iter())
3678            .map(|(g, ti)| (g - ti).abs())
3679            .fold(0.0_f64, f64::max);
3680        assert!(
3681            max_dev < 0.05,
3682            "Large lambda should give near-identity warp: max deviation = {max_dev}"
3683        );
3684    }
3685
3686    #[test]
3687    fn test_penalized_karcher_mean() {
3688        let m = 40;
3689        let t = uniform_grid(m);
3690        let data = make_test_data(10, m, 42);
3691
3692        let result = karcher_mean(&data, &t, 5, 1e-3, 0.5);
3693        assert_eq!(result.mean.len(), m);
3694        for j in 0..m {
3695            assert!(result.mean[j].is_finite());
3696        }
3697    }
3698
3699    // ── Phase-amplitude decomposition ──
3700
3701    #[test]
3702    fn test_decomposition_identity_curves() {
3703        let m = 50;
3704        let t = uniform_grid(m);
3705        let f: Vec<f64> = t
3706            .iter()
3707            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3708            .collect();
3709
3710        let result = elastic_decomposition(&f, &f, &t, 0.0);
3711        assert!(
3712            result.d_amplitude < 0.1,
3713            "Self-decomposition amplitude should be ~0, got {}",
3714            result.d_amplitude
3715        );
3716        assert!(
3717            result.d_phase < 0.2,
3718            "Self-decomposition phase should be ~0, got {}",
3719            result.d_phase
3720        );
3721    }
3722
3723    #[test]
3724    fn test_decomposition_pythagorean() {
3725        // d_total² ≈ d_a² + d_φ² (approximately, for the Fisher-Rao metric)
3726        let m = 80;
3727        let t = uniform_grid(m);
3728        let f1: Vec<f64> = t
3729            .iter()
3730            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3731            .collect();
3732        let f2: Vec<f64> = t
3733            .iter()
3734            .map(|&ti| 1.2 * (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
3735            .collect();
3736
3737        let result = elastic_decomposition(&f1, &f2, &t, 0.0);
3738        let da = result.d_amplitude;
3739        let dp = result.d_phase;
3740        // Both should be non-negative
3741        assert!(da >= 0.0);
3742        assert!(dp >= 0.0);
3743    }
3744
3745    #[test]
3746    fn test_phase_distance_shifted_sine() {
3747        let m = 80;
3748        let t = uniform_grid(m);
3749        let f1: Vec<f64> = t
3750            .iter()
3751            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3752            .collect();
3753        let f2: Vec<f64> = t
3754            .iter()
3755            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.15)).sin())
3756            .collect();
3757
3758        let dp = phase_distance_pair(&f1, &f2, &t, 0.0);
3759        assert!(
3760            dp > 0.01,
3761            "Phase distance of shifted curves should be > 0, got {dp}"
3762        );
3763    }
3764
3765    #[test]
3766    fn test_amplitude_distance_scaled_curve() {
3767        let m = 80;
3768        let t = uniform_grid(m);
3769        let f1: Vec<f64> = t
3770            .iter()
3771            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3772            .collect();
3773        let f2: Vec<f64> = t
3774            .iter()
3775            .map(|&ti| 2.0 * (2.0 * std::f64::consts::PI * ti).sin())
3776            .collect();
3777
3778        let da = amplitude_distance(&f1, &f2, &t, 0.0);
3779        assert!(
3780            da > 0.01,
3781            "Amplitude distance of scaled curves should be > 0, got {da}"
3782        );
3783    }
3784
3785    #[test]
3786    fn test_phase_distance_nonneg() {
3787        let data = make_test_data(4, 40, 42);
3788        let t = uniform_grid(40);
3789        for i in 0..4 {
3790            for j in 0..4 {
3791                let fi = data.row(i);
3792                let fj = data.row(j);
3793                let dp = phase_distance_pair(&fi, &fj, &t, 0.0);
3794                assert!(dp >= 0.0, "Phase distance should be non-negative");
3795            }
3796        }
3797    }
3798
3799    // ── Parallel transport ──
3800
3801    #[test]
3802    fn test_schilds_ladder_zero_vector() {
3803        let m = 21;
3804        let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
3805        let raw = vec![1.0; m];
3806        let norm = crate::warping::l2_norm_l2(&raw, &time);
3807        let from: Vec<f64> = raw.iter().map(|&v| v / norm).collect();
3808        let raw2: Vec<f64> = time
3809            .iter()
3810            .map(|&t| 1.0 + 0.2 * (2.0 * std::f64::consts::PI * t).sin())
3811            .collect();
3812        let norm2 = crate::warping::l2_norm_l2(&raw2, &time);
3813        let to: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();
3814
3815        let zero = vec![0.0; m];
3816        let result = parallel_transport_schilds(&zero, &from, &to, &time);
3817        let result_norm: f64 = result.iter().map(|v| v * v).sum::<f64>().sqrt();
3818        assert!(
3819            result_norm < 1e-6,
3820            "Transporting zero should give zero, got norm {result_norm}"
3821        );
3822    }
3823
3824    #[test]
3825    fn test_pole_ladder_zero_vector() {
3826        let m = 21;
3827        let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
3828        let raw = vec![1.0; m];
3829        let norm = crate::warping::l2_norm_l2(&raw, &time);
3830        let from: Vec<f64> = raw.iter().map(|&v| v / norm).collect();
3831        let raw2: Vec<f64> = time
3832            .iter()
3833            .map(|&t| 1.0 + 0.2 * (2.0 * std::f64::consts::PI * t).sin())
3834            .collect();
3835        let norm2 = crate::warping::l2_norm_l2(&raw2, &time);
3836        let to: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();
3837
3838        let zero = vec![0.0; m];
3839        let result = parallel_transport_pole(&zero, &from, &to, &time);
3840        let result_norm: f64 = result.iter().map(|v| v * v).sum::<f64>().sqrt();
3841        assert!(
3842            result_norm < 1e-6,
3843            "Transporting zero should give zero, got norm {result_norm}"
3844        );
3845    }
3846
3847    #[test]
3848    fn test_schilds_preserves_norm() {
3849        let m = 51;
3850        let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
3851        let raw = vec![1.0; m];
3852        let norm = crate::warping::l2_norm_l2(&raw, &time);
3853        let from: Vec<f64> = raw.iter().map(|&v| v / norm).collect();
3854        let raw2: Vec<f64> = time
3855            .iter()
3856            .map(|&t| 1.0 + 0.15 * (2.0 * std::f64::consts::PI * t).sin())
3857            .collect();
3858        let norm2 = crate::warping::l2_norm_l2(&raw2, &time);
3859        let to: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();
3860
3861        // Small tangent vector
3862        let v: Vec<f64> = time
3863            .iter()
3864            .map(|&t| 0.1 * (4.0 * std::f64::consts::PI * t).cos())
3865            .collect();
3866        let v_norm = crate::warping::l2_norm_l2(&v, &time);
3867
3868        let transported = parallel_transport_schilds(&v, &from, &to, &time);
3869        let t_norm = crate::warping::l2_norm_l2(&transported, &time);
3870
3871        // Norm should be approximately preserved (ladder methods are first-order)
3872        assert!(
3873            (t_norm - v_norm).abs() / v_norm.max(1e-10) < 1.5,
3874            "Schild's should roughly preserve norm: original={v_norm:.4}, transported={t_norm:.4}"
3875        );
3876    }
3877
3878    #[test]
3879    fn test_tsrvf_logmap_matches_original() {
3880        let m = 50;
3881        let n = 5;
3882        let t = uniform_grid(m);
3883        let data = make_test_data(n, m, 42);
3884
3885        let result_orig = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
3886        let result_logmap =
3887            tsrvf_transform_with_method(&data, &t, 5, 1e-3, 0.0, TransportMethod::LogMap);
3888
3889        // Should be identical (LogMap delegates to original)
3890        for i in 0..n {
3891            for j in 0..m {
3892                assert!(
3893                    (result_orig.tangent_vectors[(i, j)] - result_logmap.tangent_vectors[(i, j)])
3894                        .abs()
3895                        < 1e-12,
3896                    "LogMap variant should match original at ({i},{j})"
3897                );
3898            }
3899        }
3900    }
3901
3902    #[test]
3903    fn test_tsrvf_with_schilds_produces_valid_result() {
3904        let m = 50;
3905        let n = 5;
3906        let t = uniform_grid(m);
3907        let data = make_test_data(n, m, 42);
3908
3909        let result =
3910            tsrvf_transform_with_method(&data, &t, 5, 1e-3, 0.0, TransportMethod::SchildsLadder);
3911
3912        assert_eq!(result.tangent_vectors.shape(), (n, m));
3913        for i in 0..n {
3914            for j in 0..m {
3915                assert!(
3916                    result.tangent_vectors[(i, j)].is_finite(),
3917                    "Schild's TSRVF should produce finite tangent vectors at ({i},{j})"
3918                );
3919            }
3920        }
3921    }
3922
3923    #[test]
3924    fn test_transport_methods_differ() {
3925        let m = 50;
3926        let n = 5;
3927        let t = uniform_grid(m);
3928        let data = make_test_data(n, m, 42);
3929        let karcher = karcher_mean(&data, &t, 5, 1e-3, 0.0);
3930
3931        let r_log = tsrvf_from_alignment_with_method(&karcher, &t, TransportMethod::LogMap);
3932        let r_schilds =
3933            tsrvf_from_alignment_with_method(&karcher, &t, TransportMethod::SchildsLadder);
3934
3935        // Methods should produce different (but related) tangent vectors
3936        let mut total_diff = 0.0;
3937        for i in 0..n {
3938            for j in 0..m {
3939                total_diff +=
3940                    (r_log.tangent_vectors[(i, j)] - r_schilds.tangent_vectors[(i, j)]).abs();
3941            }
3942        }
3943
3944        // They should be non-zero different (unless all curves are identical to mean)
3945        // Just check both produce finite results
3946        assert!(total_diff.is_finite());
3947    }
3948
3949    // ── Alignment quality metrics ──
3950
3951    #[test]
3952    fn test_warp_complexity_identity_is_zero() {
3953        let m = 50;
3954        let t = uniform_grid(m);
3955        let identity = t.clone();
3956        let c = warp_complexity(&identity, &t);
3957        assert!(
3958            c < 1e-10,
3959            "Identity warp should have zero complexity, got {c}"
3960        );
3961    }
3962
3963    #[test]
3964    fn test_warp_complexity_nonidentity_positive() {
3965        let m = 50;
3966        let t = uniform_grid(m);
3967        let gamma: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
3968        let c = warp_complexity(&gamma, &t);
3969        assert!(
3970            c > 0.01,
3971            "Non-identity warp should have positive complexity, got {c}"
3972        );
3973    }
3974
3975    #[test]
3976    fn test_warp_smoothness_identity_is_zero() {
3977        let m = 50;
3978        let t = uniform_grid(m);
3979        let identity = t.clone();
3980        let s = warp_smoothness(&identity, &t);
3981        assert!(
3982            s < 1e-6,
3983            "Identity warp (constant γ'=1, γ''=0) should have near-zero bending energy, got {s}"
3984        );
3985    }
3986
3987    #[test]
3988    fn test_alignment_quality_basic() {
3989        let m = 50;
3990        let n = 8;
3991        let t = uniform_grid(m);
3992        let data = make_test_data(n, m, 42);
3993        let karcher = karcher_mean(&data, &t, 10, 1e-3, 0.0);
3994        let quality = alignment_quality(&data, &karcher, &t);
3995
3996        // Shape checks
3997        assert_eq!(quality.warp_complexity.len(), n);
3998        assert_eq!(quality.warp_smoothness.len(), n);
3999        assert_eq!(quality.pointwise_variance_ratio.len(), m);
4000
4001        // Non-negativity
4002        assert!(quality.total_variance >= 0.0);
4003        assert!(quality.amplitude_variance >= 0.0);
4004        assert!(quality.phase_variance >= 0.0);
4005        assert!(quality.mean_warp_complexity >= 0.0);
4006        assert!(quality.mean_warp_smoothness >= 0.0);
4007
4008        // Amplitude variance ≤ total variance
4009        assert!(
4010            quality.amplitude_variance <= quality.total_variance + 1e-10,
4011            "Amplitude variance ({}) should be ≤ total variance ({})",
4012            quality.amplitude_variance,
4013            quality.total_variance
4014        );
4015    }
4016
4017    #[test]
4018    fn test_alignment_quality_identical_curves() {
4019        let m = 50;
4020        let t = uniform_grid(m);
4021        let curve: Vec<f64> = t
4022            .iter()
4023            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
4024            .collect();
4025        let mut col_major = vec![0.0; 5 * m];
4026        for i in 0..5 {
4027            for j in 0..m {
4028                col_major[i + j * 5] = curve[j];
4029            }
4030        }
4031        let data = FdMatrix::from_column_major(col_major, 5, m).unwrap();
4032        let karcher = karcher_mean(&data, &t, 5, 1e-3, 0.0);
4033        let quality = alignment_quality(&data, &karcher, &t);
4034
4035        // Identical curves → near-zero variances and warp complexities
4036        assert!(
4037            quality.total_variance < 0.01,
4038            "Identical curves should have near-zero total variance, got {}",
4039            quality.total_variance
4040        );
4041        assert!(
4042            quality.mean_warp_complexity < 0.1,
4043            "Identical curves should have near-zero warp complexity, got {}",
4044            quality.mean_warp_complexity
4045        );
4046    }
4047
4048    #[test]
4049    fn test_alignment_quality_variance_reduction() {
4050        let m = 50;
4051        let n = 10;
4052        let t = uniform_grid(m);
4053        let data = make_test_data(n, m, 42);
4054        let karcher = karcher_mean(&data, &t, 10, 1e-3, 0.0);
4055        let quality = alignment_quality(&data, &karcher, &t);
4056
4057        // Mean variance ratio should be ≤ ~1 (alignment shouldn't increase variance)
4058        assert!(
4059            quality.mean_variance_reduction <= 1.5,
4060            "Mean variance reduction ratio should be ≤ ~1, got {}",
4061            quality.mean_variance_reduction
4062        );
4063    }
4064
4065    #[test]
4066    fn test_pairwise_consistency_small() {
4067        let m = 40;
4068        let n = 4;
4069        let t = uniform_grid(m);
4070        let data = make_test_data(n, m, 42);
4071
4072        let consistency = pairwise_consistency(&data, &t, 0.0, 100);
4073        assert!(
4074            consistency.is_finite() && consistency >= 0.0,
4075            "Pairwise consistency should be finite and non-negative, got {consistency}"
4076        );
4077    }
4078
4079    // ── Multidimensional SRSF ──
4080
4081    #[test]
4082    fn test_srsf_nd_d1_matches_existing() {
4083        let m = 50;
4084        let t = uniform_grid(m);
4085        let data = make_test_data(3, m, 42);
4086
4087        // 1D via existing function
4088        let q_1d = srsf_transform(&data, &t);
4089
4090        // 1D via nd function
4091        let data_nd = FdCurveSet::from_1d(data);
4092        let q_nd = srsf_transform_nd(&data_nd, &t);
4093
4094        assert_eq!(q_nd.ndim(), 1);
4095        for i in 0..3 {
4096            for j in 0..m {
4097                assert!(
4098                    (q_1d[(i, j)] - q_nd.dims[0][(i, j)]).abs() < 1e-10,
4099                    "1D nd SRSF should match existing at ({i},{j}): {} vs {}",
4100                    q_1d[(i, j)],
4101                    q_nd.dims[0][(i, j)]
4102                );
4103            }
4104        }
4105    }
4106
4107    #[test]
4108    fn test_srsf_nd_constant_is_zero() {
4109        let m = 30;
4110        let t = uniform_grid(m);
4111        // Constant R^2 curve: f(t) = (3.0, -1.0)
4112        let dim0 = FdMatrix::from_column_major(vec![3.0; m], 1, m).unwrap();
4113        let dim1 = FdMatrix::from_column_major(vec![-1.0; m], 1, m).unwrap();
4114        let data = FdCurveSet::from_dims(vec![dim0, dim1]).unwrap();
4115
4116        let q = srsf_transform_nd(&data, &t);
4117        for k in 0..2 {
4118            for j in 0..m {
4119                assert!(
4120                    q.dims[k][(0, j)].abs() < 1e-10,
4121                    "Constant curve SRSF should be zero, dim {k} at {j}: {}",
4122                    q.dims[k][(0, j)]
4123                );
4124            }
4125        }
4126    }
4127
4128    #[test]
4129    fn test_srsf_nd_linear_r2() {
4130        let m = 51;
4131        let t = uniform_grid(m);
4132        // f(t) = (2t, 3t) → f'(t) = (2, 3), ||f'|| = sqrt(13)
4133        // q(t) = (2, 3) / sqrt(sqrt(13)) = (2, 3) / 13^(1/4)
4134        let dim0 =
4135            FdMatrix::from_slice(&t.iter().map(|&ti| 2.0 * ti).collect::<Vec<_>>(), 1, m).unwrap();
4136        let dim1 =
4137            FdMatrix::from_slice(&t.iter().map(|&ti| 3.0 * ti).collect::<Vec<_>>(), 1, m).unwrap();
4138        let data = FdCurveSet::from_dims(vec![dim0, dim1]).unwrap();
4139
4140        let q = srsf_transform_nd(&data, &t);
4141        let expected_scale = 1.0 / 13.0_f64.powf(0.25);
4142        let mid = m / 2;
4143
4144        assert!(
4145            (q.dims[0][(0, mid)] - 2.0 * expected_scale).abs() < 0.1,
4146            "q_x at midpoint: {} vs expected {}",
4147            q.dims[0][(0, mid)],
4148            2.0 * expected_scale
4149        );
4150        assert!(
4151            (q.dims[1][(0, mid)] - 3.0 * expected_scale).abs() < 0.1,
4152            "q_y at midpoint: {} vs expected {}",
4153            q.dims[1][(0, mid)],
4154            3.0 * expected_scale
4155        );
4156    }
4157
4158    #[test]
4159    fn test_srsf_nd_round_trip() {
4160        let m = 51;
4161        let t = uniform_grid(m);
4162        // f(t) = (sin(2πt), cos(2πt))
4163        let pi2 = 2.0 * std::f64::consts::PI;
4164        let vals_x: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).sin()).collect();
4165        let vals_y: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).cos()).collect();
4166        let dim0 = FdMatrix::from_slice(&vals_x, 1, m).unwrap();
4167        let dim1 = FdMatrix::from_slice(&vals_y, 1, m).unwrap();
4168        let data = FdCurveSet::from_dims(vec![dim0, dim1]).unwrap();
4169
4170        let q = srsf_transform_nd(&data, &t);
4171        let q_vecs: Vec<Vec<f64>> = q.dims.iter().map(|dm| dm.row(0)).collect();
4172        let f0 = vec![vals_x[0], vals_y[0]];
4173        let recon = srsf_inverse_nd(&q_vecs, &t, &f0);
4174
4175        // Check reconstruction error (skip boundary points)
4176        let mut max_err = 0.0_f64;
4177        for k in 0..2 {
4178            let orig = if k == 0 { &vals_x } else { &vals_y };
4179            for j in 2..(m - 2) {
4180                let err = (recon[k][j] - orig[j]).abs();
4181                max_err = max_err.max(err);
4182            }
4183        }
4184        assert!(
4185            max_err < 0.2,
4186            "SRSF round-trip max error should be small, got {max_err}"
4187        );
4188    }
4189
4190    #[test]
4191    fn test_align_nd_identical_near_zero() {
4192        let m = 50;
4193        let t = uniform_grid(m);
4194        let pi2 = 2.0 * std::f64::consts::PI;
4195        let vals_x: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).sin()).collect();
4196        let vals_y: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).cos()).collect();
4197        let dim0 = FdMatrix::from_slice(&vals_x, 1, m).unwrap();
4198        let dim1 = FdMatrix::from_slice(&vals_y, 1, m).unwrap();
4199        let data = FdCurveSet::from_dims(vec![dim0, dim1]).unwrap();
4200
4201        let result = elastic_align_pair_nd(&data, &data, &t, 0.0);
4202        assert!(
4203            result.distance < 0.5,
4204            "Self-alignment distance should be ~0, got {}",
4205            result.distance
4206        );
4207        // Gamma should be near identity
4208        let max_dev: f64 = result
4209            .gamma
4210            .iter()
4211            .zip(t.iter())
4212            .map(|(g, ti)| (g - ti).abs())
4213            .fold(0.0_f64, f64::max);
4214        assert!(
4215            max_dev < 0.1,
4216            "Self-alignment warp should be near identity, max dev = {max_dev}"
4217        );
4218    }
4219
4220    #[test]
4221    fn test_align_nd_shifted_r2() {
4222        let m = 60;
4223        let t = uniform_grid(m);
4224        let pi2 = 2.0 * std::f64::consts::PI;
4225
4226        // f1(t) = (sin(2πt), cos(2πt))
4227        let f1x: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).sin()).collect();
4228        let f1y: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).cos()).collect();
4229        let f1 = FdCurveSet::from_dims(vec![
4230            FdMatrix::from_slice(&f1x, 1, m).unwrap(),
4231            FdMatrix::from_slice(&f1y, 1, m).unwrap(),
4232        ])
4233        .unwrap();
4234
4235        // f2(t) = (sin(2π(t-0.1)), cos(2π(t-0.1))) — phase shifted
4236        let f2x: Vec<f64> = t.iter().map(|&ti| (pi2 * (ti - 0.1)).sin()).collect();
4237        let f2y: Vec<f64> = t.iter().map(|&ti| (pi2 * (ti - 0.1)).cos()).collect();
4238        let f2 = FdCurveSet::from_dims(vec![
4239            FdMatrix::from_slice(&f2x, 1, m).unwrap(),
4240            FdMatrix::from_slice(&f2y, 1, m).unwrap(),
4241        ])
4242        .unwrap();
4243
4244        let result = elastic_align_pair_nd(&f1, &f2, &t, 0.0);
4245        assert!(
4246            result.distance.is_finite(),
4247            "Distance should be finite, got {}",
4248            result.distance
4249        );
4250        assert_eq!(result.f_aligned.len(), 2);
4251        assert_eq!(result.f_aligned[0].len(), m);
4252        // Warp should deviate from identity (non-trivial alignment)
4253        let max_dev: f64 = result
4254            .gamma
4255            .iter()
4256            .zip(t.iter())
4257            .map(|(g, ti)| (g - ti).abs())
4258            .fold(0.0_f64, f64::max);
4259        assert!(
4260            max_dev > 0.01,
4261            "Shifted curves should require non-trivial warp, max dev = {max_dev}"
4262        );
4263    }
4264
4265    // ── Landmark-constrained alignment ──
4266
4267    #[test]
4268    fn test_constrained_no_landmarks_matches_unconstrained() {
4269        let m = 50;
4270        let t = uniform_grid(m);
4271        let f1: Vec<f64> = t
4272            .iter()
4273            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
4274            .collect();
4275        let f2: Vec<f64> = t
4276            .iter()
4277            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
4278            .collect();
4279
4280        let r_free = elastic_align_pair(&f1, &f2, &t, 0.0);
4281        let r_const = elastic_align_pair_constrained(&f1, &f2, &t, &[], 0.0);
4282
4283        // Should match unconstrained
4284        for j in 0..m {
4285            assert!(
4286                (r_free.gamma[j] - r_const.gamma[j]).abs() < 1e-10,
4287                "No-landmark constrained should match unconstrained at {j}"
4288            );
4289        }
4290        assert!(r_const.enforced_landmarks.is_empty());
4291    }
4292
4293    #[test]
4294    fn test_constrained_single_landmark_enforced() {
4295        let m = 60;
4296        let t = uniform_grid(m);
4297        let f1: Vec<f64> = t
4298            .iter()
4299            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
4300            .collect();
4301        let f2: Vec<f64> = t
4302            .iter()
4303            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
4304            .collect();
4305
4306        // Constrain midpoint: target_t=0.5 should map to source_t=0.5
4307        let result = elastic_align_pair_constrained(&f1, &f2, &t, &[(0.5, 0.5)], 0.0);
4308
4309        // Gamma at the midpoint should be close to 0.5
4310        let mid_idx = snap_to_grid(0.5, &t);
4311        assert!(
4312            (result.gamma[mid_idx] - 0.5).abs() < 0.05,
4313            "Constrained gamma at midpoint should be ~0.5, got {}",
4314            result.gamma[mid_idx]
4315        );
4316        assert_eq!(result.enforced_landmarks.len(), 1);
4317    }
4318
4319    #[test]
4320    fn test_constrained_multiple_landmarks() {
4321        let m = 80;
4322        let t = uniform_grid(m);
4323        let f1: Vec<f64> = t
4324            .iter()
4325            .map(|&ti| (4.0 * std::f64::consts::PI * ti).sin())
4326            .collect();
4327        let f2: Vec<f64> = t
4328            .iter()
4329            .map(|&ti| (4.0 * std::f64::consts::PI * (ti - 0.05)).sin())
4330            .collect();
4331
4332        let landmarks = vec![(0.25, 0.25), (0.5, 0.5), (0.75, 0.75)];
4333        let result = elastic_align_pair_constrained(&f1, &f2, &t, &landmarks, 0.0);
4334
4335        // Gamma should pass through (or near) each landmark
4336        for &(tt, st) in &landmarks {
4337            let idx = snap_to_grid(tt, &t);
4338            assert!(
4339                (result.gamma[idx] - st).abs() < 0.05,
4340                "Gamma at t={tt} should be ~{st}, got {}",
4341                result.gamma[idx]
4342            );
4343        }
4344    }
4345
4346    #[test]
4347    fn test_constrained_monotone_gamma() {
4348        let m = 60;
4349        let t = uniform_grid(m);
4350        let f1: Vec<f64> = t
4351            .iter()
4352            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
4353            .collect();
4354        let f2: Vec<f64> = t
4355            .iter()
4356            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
4357            .collect();
4358
4359        let result = elastic_align_pair_constrained(&f1, &f2, &t, &[(0.3, 0.3), (0.7, 0.7)], 0.0);
4360
4361        // Gamma should be non-decreasing
4362        for j in 1..m {
4363            assert!(
4364                result.gamma[j] >= result.gamma[j - 1] - 1e-10,
4365                "Gamma should be monotone: gamma[{}]={} < gamma[{}]={}",
4366                j,
4367                result.gamma[j],
4368                j - 1,
4369                result.gamma[j - 1]
4370            );
4371        }
4372        // Boundary conditions
4373        assert!((result.gamma[0] - t[0]).abs() < 1e-10);
4374        assert!((result.gamma[m - 1] - t[m - 1]).abs() < 1e-10);
4375    }
4376
4377    #[test]
4378    fn test_constrained_distance_ge_unconstrained() {
4379        let m = 60;
4380        let t = uniform_grid(m);
4381        let f1: Vec<f64> = t
4382            .iter()
4383            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
4384            .collect();
4385        let f2: Vec<f64> = t
4386            .iter()
4387            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.15)).sin())
4388            .collect();
4389
4390        let r_free = elastic_align_pair(&f1, &f2, &t, 0.0);
4391        let r_const = elastic_align_pair_constrained(&f1, &f2, &t, &[(0.5, 0.5)], 0.0);
4392
4393        // Constrained distance should be >= unconstrained (constraints reduce freedom)
4394        assert!(
4395            r_const.distance >= r_free.distance - 1e-6,
4396            "Constrained distance ({}) should be >= unconstrained ({})",
4397            r_const.distance,
4398            r_free.distance
4399        );
4400    }
4401
4402    #[test]
4403    fn test_constrained_with_landmark_detection() {
4404        let m = 80;
4405        let t = uniform_grid(m);
4406        let f1: Vec<f64> = t
4407            .iter()
4408            .map(|&ti| (4.0 * std::f64::consts::PI * ti).sin())
4409            .collect();
4410        let f2: Vec<f64> = t
4411            .iter()
4412            .map(|&ti| (4.0 * std::f64::consts::PI * (ti - 0.05)).sin())
4413            .collect();
4414
4415        let result = elastic_align_pair_with_landmarks(
4416            &f1,
4417            &f2,
4418            &t,
4419            crate::landmark::LandmarkKind::Peak,
4420            0.1,
4421            0,
4422            0.0,
4423        );
4424
4425        assert_eq!(result.gamma.len(), m);
4426        assert_eq!(result.f_aligned.len(), m);
4427        assert!(result.distance.is_finite());
4428        // Should be monotone
4429        for j in 1..m {
4430            assert!(
4431                result.gamma[j] >= result.gamma[j - 1] - 1e-10,
4432                "Gamma should be monotone at j={j}"
4433            );
4434        }
4435    }
4436
4437    // ── SRSF smoothing for TSRVF (Issue #13) ──
4438
4439    #[test]
4440    fn test_gam_to_psi_smooth_identity() {
4441        // Smoothed psi of identity warp should stay close to constant 1 in the interior.
4442        // Boundary points are biased by kernel smoothing (fewer neighbors), skip them.
4443        use crate::warping::{gam_to_psi, gam_to_psi_smooth};
4444        let m = 101;
4445        let h = 1.0 / (m - 1) as f64;
4446        let gam: Vec<f64> = uniform_grid(m);
4447        let psi_raw = gam_to_psi(&gam, h);
4448        let psi_smooth = gam_to_psi_smooth(&gam, h);
4449        // Check interior points (skip ~5% at each boundary)
4450        let skip = m / 20;
4451        for j in skip..(m - skip) {
4452            assert!(
4453                (psi_smooth[j] - 1.0).abs() < 0.05,
4454                "Smoothed psi of identity should be ~1.0, got {} at j={}",
4455                psi_smooth[j],
4456                j
4457            );
4458            assert!(
4459                (psi_smooth[j] - psi_raw[j]).abs() < 0.05,
4460                "Smoothed and raw psi should agree on smooth warp at j={}",
4461                j
4462            );
4463        }
4464    }
4465
4466    #[test]
4467    fn test_gam_to_psi_smooth_reduces_spikes() {
4468        // Create a kinky warp (simulating DP output with multiple slope changes)
4469        // and verify smoothing reduces psi spikes
4470        use crate::warping::{gam_to_psi, gam_to_psi_smooth};
4471        let m = 101;
4472        let h = 1.0 / (m - 1) as f64;
4473        let argvals = uniform_grid(m);
4474        // Multi-segment piecewise-linear warp with several kinks
4475        let mut gam: Vec<f64> = Vec::with_capacity(m);
4476        for j in 0..m {
4477            let t = argvals[j];
4478            // Three segments: slow (slope 0.5), fast (slope 2), slow (slope 0.5)
4479            let g = if t < 0.33 {
4480                t * 0.5 / 0.33
4481            } else if t < 0.67 {
4482                0.5 + (t - 0.33) * 0.5 / 0.34 * 2.0 // steeper
4483            } else {
4484                let base = 0.5 + 0.5 / 0.34 * 2.0 * 0.34; // ~1.5 but clamped
4485                (base + (t - 0.67) * 0.5 / 0.33).min(1.0)
4486            };
4487            gam.push(g.min(1.0));
4488        }
4489        // Normalize to [0,1]
4490        let gmax = gam[m - 1].max(1e-10);
4491        for g in &mut gam {
4492            *g /= gmax;
4493        }
4494        let psi_raw = gam_to_psi(&gam, h);
4495        let psi_smooth = gam_to_psi_smooth(&gam, h);
4496        // The raw psi should have jumps at kink points (slope transitions)
4497        let max_jump_raw: f64 = psi_raw
4498            .windows(2)
4499            .map(|w| (w[1] - w[0]).abs())
4500            .fold(0.0_f64, f64::max);
4501        let max_jump_smooth: f64 = psi_smooth
4502            .windows(2)
4503            .map(|w| (w[1] - w[0]).abs())
4504            .fold(0.0_f64, f64::max);
4505        // Smoothing should reduce the maximum jump in psi
4506        assert!(
4507            max_jump_smooth < max_jump_raw + 0.01,
4508            "Smoothing should not increase max psi jump: raw={max_jump_raw:.4}, smooth={max_jump_smooth:.4}"
4509        );
4510    }
4511
4512    #[test]
4513    fn test_smooth_aligned_srsfs_preserves_shape() {
4514        // Smoothing aligned SRSFs should preserve overall shape
4515        use crate::smoothing::nadaraya_watson;
4516        let m = 101;
4517        let time: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
4518        // Create a smooth SRSF (sine curve)
4519        let qi: Vec<f64> = time
4520            .iter()
4521            .map(|&t| (2.0 * std::f64::consts::PI * t).sin())
4522            .collect();
4523        let bandwidth = 2.0 / (m - 1) as f64;
4524        let qi_smooth = nadaraya_watson(&time, &qi, &time, bandwidth, "gaussian");
4525        // Correlation between original and smoothed should be very high
4526        let mean_orig: f64 = qi.iter().sum::<f64>() / m as f64;
4527        let mean_smooth: f64 = qi_smooth.iter().sum::<f64>() / m as f64;
4528        let mut cov = 0.0;
4529        let mut var_o = 0.0;
4530        let mut var_s = 0.0;
4531        for j in 0..m {
4532            let do_ = qi[j] - mean_orig;
4533            let ds = qi_smooth[j] - mean_smooth;
4534            cov += do_ * ds;
4535            var_o += do_ * do_;
4536            var_s += ds * ds;
4537        }
4538        let rho = cov / (var_o * var_s).sqrt().max(1e-10);
4539        assert!(
4540            rho > 0.99,
4541            "Smoothed SRSF should be highly correlated with original (rho={rho:.4})"
4542        );
4543    }
4544
4545    #[test]
4546    fn test_tsrvf_tangent_vectors_no_spikes() {
4547        // End-to-end: compute TSRVF tangent vectors, verify no element dominates.
4548        // The smooth_aligned_srsfs step in tsrvf_from_alignment removes DP kink
4549        // artifacts that would otherwise produce spike outliers in tangent vectors.
4550        let m = 101;
4551        let argvals = uniform_grid(m);
4552        let data = make_test_data(10, m, 42);
4553        let result = tsrvf_transform(&data, &argvals, 5, 1e-3, 0.0);
4554        let (n, _) = result.tangent_vectors.shape();
4555        for i in 0..n {
4556            let vi = result.tangent_vectors.row(i);
4557            let rms = (vi.iter().map(|&v| v * v).sum::<f64>() / m as f64).sqrt();
4558            if rms > 1e-10 {
4559                let max_abs = vi.iter().map(|&v| v.abs()).fold(0.0_f64, f64::max);
4560                assert!(
4561                    max_abs < 10.0 * rms,
4562                    "Tangent vector {} has spike: max |v| = {max_abs:.4}, rms = {rms:.4}, ratio = {:.1}",
4563                    i,
4564                    max_abs / rms
4565                );
4566            }
4567        }
4568    }
4569}