Skip to main content

fdars_core/
alignment.rs

1//! Elastic alignment and SRSF (Square-Root Slope Function) transforms.
2//!
3//! This module provides phase-amplitude separation for functional data via
4//! the elastic framework. Key capabilities:
5//!
6//! - [`srsf_transform`] / [`srsf_inverse`] — SRSF representation and reconstruction
7//! - [`elastic_align_pair`] — Pairwise curve alignment via dynamic programming
8//! - [`elastic_distance`] — Elastic (Fisher-Rao) distance between curves
9//! - [`align_to_target`] — Align a set of curves to a common target
10//! - [`karcher_mean`] — Karcher (Fréchet) mean in the elastic metric
11//! - [`elastic_self_distance_matrix`] / [`elastic_cross_distance_matrix`] — Distance matrices
12//! - [`reparameterize_curve`] / [`compose_warps`] — Warping utilities
13
14use crate::fdata::{deriv_1d, mean_1d};
15use crate::helpers::{
16    cumulative_trapz, gradient_uniform, l2_distance, linear_interp, simpsons_weights,
17};
18use crate::iter_maybe_parallel;
19use crate::matrix::FdMatrix;
20use crate::warping::{
21    exp_map_sphere, gam_to_psi, inv_exp_map_sphere, invert_gamma, l2_norm_l2, normalize_warp,
22    psi_to_gam,
23};
24#[cfg(feature = "parallel")]
25use rayon::iter::ParallelIterator;
26
27// ─── Types ──────────────────────────────────────────────────────────────────
28
/// Result of aligning one curve to another.
///
/// Produced by [`elastic_align_pair`]: `f_aligned` is the second curve
/// resampled along `gamma`.
#[derive(Debug, Clone)]
pub struct AlignmentResult {
    /// Warping function γ mapping the domain to itself, sampled at the
    /// evaluation points.
    pub gamma: Vec<f64>,
    /// The aligned (reparameterized) curve, i.e. `f2(γ(t))`.
    pub f_aligned: Vec<f64>,
    /// Elastic distance after alignment: the quadrature-weighted L2 distance
    /// between the target's SRSF and the SRSF of `f_aligned`.
    pub distance: f64,
}
39
/// Result of aligning a set of curves to a common target.
///
/// Returned by [`align_to_target`]. Row `i` of each matrix and entry `i` of
/// `distances` belong to row `i` of the input data.
#[derive(Debug, Clone)]
pub struct AlignmentSetResult {
    /// Warping functions (n × m), one per row.
    pub gammas: FdMatrix,
    /// Aligned curves (n × m), one per row.
    pub aligned_data: FdMatrix,
    /// Elastic distance of each curve to the target after alignment.
    pub distances: Vec<f64>,
}
50
/// Result of the Karcher mean computation.
///
/// Returned by [`karcher_mean`].
#[derive(Debug, Clone)]
pub struct KarcherMeanResult {
    /// Karcher mean curve.
    pub mean: Vec<f64>,
    /// SRSF of the Karcher mean.
    pub mean_srsf: Vec<f64>,
    /// Final warping functions (n × m), one row per input curve.
    pub gammas: FdMatrix,
    /// Curves aligned to the mean (n × m): each input row resampled along
    /// the matching row of `gammas`.
    pub aligned_data: FdMatrix,
    /// Number of iterations used.
    pub n_iter: usize,
    /// Whether the stopping rule fired before `max_iter` was exhausted.
    pub converged: bool,
}
67
68// Private helpers are now in crate::helpers and crate::warping.
69
70/// Karcher mean of warping functions on the Hilbert sphere, then invert.
71/// Port of fdasrvf's `SqrtMeanInverse`.
72///
73/// Takes a matrix of warping functions (n × m) on the argvals domain,
74/// computes the Fréchet mean of their sqrt-derivative representations
75/// on the unit Hilbert sphere, converts back to a warping function,
76/// and returns its inverse (on the argvals domain).
77/// One Karcher iteration on the Hilbert sphere: compute mean shooting vector and update mu.
78///
79/// Returns `true` if converged (vbar norm ≤ threshold).
80fn karcher_sphere_step(mu: &mut Vec<f64>, psis: &[Vec<f64>], time: &[f64], step_size: f64) -> bool {
81    let m = mu.len();
82    let n = psis.len();
83    let mut vbar = vec![0.0; m];
84    for psi in psis {
85        let v = inv_exp_map_sphere(mu, psi, time);
86        for j in 0..m {
87            vbar[j] += v[j];
88        }
89    }
90    for j in 0..m {
91        vbar[j] /= n as f64;
92    }
93    if l2_norm_l2(&vbar, time) <= 1e-8 {
94        return true;
95    }
96    let scaled: Vec<f64> = vbar.iter().map(|&v| v * step_size).collect();
97    *mu = exp_map_sphere(mu, &scaled, time);
98    false
99}
100
101fn sqrt_mean_inverse(gammas: &FdMatrix, argvals: &[f64]) -> Vec<f64> {
102    let (n, m) = gammas.shape();
103    let t0 = argvals[0];
104    let t1 = argvals[m - 1];
105    let domain = t1 - t0;
106
107    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
108    let binsize = 1.0 / (m - 1) as f64;
109
110    let psis: Vec<Vec<f64>> = (0..n)
111        .map(|i| {
112            let gam_01: Vec<f64> = (0..m).map(|j| (gammas[(i, j)] - t0) / domain).collect();
113            gam_to_psi(&gam_01, binsize)
114        })
115        .collect();
116
117    let mut mu = vec![0.0; m];
118    for psi in &psis {
119        for j in 0..m {
120            mu[j] += psi[j];
121        }
122    }
123    for j in 0..m {
124        mu[j] /= n as f64;
125    }
126
127    for _ in 0..501 {
128        if karcher_sphere_step(&mut mu, &psis, &time, 0.3) {
129            break;
130        }
131    }
132
133    let gam_mu = psi_to_gam(&mu, &time);
134    let gam_inv = invert_gamma(&gam_mu, &time);
135    gam_inv.iter().map(|&g| t0 + g * domain).collect()
136}
137
138// ─── SRSF Transform and Inverse ─────────────────────────────────────────────
139
140/// Compute the Square-Root Slope Function (SRSF) transform.
141///
142/// For each curve f, the SRSF is: `q(t) = sign(f'(t)) * sqrt(|f'(t)|)`
143///
144/// # Arguments
145/// * `data` — Functional data matrix (n × m)
146/// * `argvals` — Evaluation points (length m)
147///
148/// # Returns
149/// FdMatrix of SRSFs with the same shape as input.
150pub fn srsf_transform(data: &FdMatrix, argvals: &[f64]) -> FdMatrix {
151    let (n, m) = data.shape();
152    if n == 0 || m == 0 || argvals.len() != m {
153        return FdMatrix::zeros(n, m);
154    }
155
156    let deriv = deriv_1d(data, argvals, 1);
157
158    let mut result = FdMatrix::zeros(n, m);
159    for i in 0..n {
160        for j in 0..m {
161            let d = deriv[(i, j)];
162            result[(i, j)] = d.signum() * d.abs().sqrt();
163        }
164    }
165    result
166}
167
168/// Reconstruct a curve from its SRSF representation.
169///
170/// Given SRSF q and initial value f0, reconstructs: `f(t) = f0 + ∫₀ᵗ q(s)|q(s)| ds`
171///
172/// # Arguments
173/// * `q` — SRSF values (length m)
174/// * `argvals` — Evaluation points (length m)
175/// * `f0` — Initial value f(argvals\[0\])
176///
177/// # Returns
178/// Reconstructed curve values.
179pub fn srsf_inverse(q: &[f64], argvals: &[f64], f0: f64) -> Vec<f64> {
180    let m = q.len();
181    if m == 0 {
182        return Vec::new();
183    }
184
185    // Integrand: q(s) * |q(s)|
186    let integrand: Vec<f64> = q.iter().map(|&qi| qi * qi.abs()).collect();
187    let integral = cumulative_trapz(&integrand, argvals);
188
189    integral.iter().map(|&v| f0 + v).collect()
190}
191
192// ─── Reparameterization ─────────────────────────────────────────────────────
193
194/// Reparameterize a curve by a warping function.
195///
196/// Computes `f(γ(t))` via linear interpolation.
197///
198/// # Arguments
199/// * `f` — Curve values (length m)
200/// * `argvals` — Evaluation points (length m)
201/// * `gamma` — Warping function values (length m)
202pub fn reparameterize_curve(f: &[f64], argvals: &[f64], gamma: &[f64]) -> Vec<f64> {
203    gamma
204        .iter()
205        .map(|&g| linear_interp(argvals, f, g))
206        .collect()
207}
208
209/// Compose two warping functions: `(γ₁ ∘ γ₂)(t) = γ₁(γ₂(t))`.
210///
211/// # Arguments
212/// * `gamma1` — Outer warping function (length m)
213/// * `gamma2` — Inner warping function (length m)
214/// * `argvals` — Evaluation points (length m)
215pub fn compose_warps(gamma1: &[f64], gamma2: &[f64], argvals: &[f64]) -> Vec<f64> {
216    gamma2
217        .iter()
218        .map(|&g| linear_interp(argvals, gamma1, g))
219        .collect()
220}
221
222// ─── Dynamic Programming Alignment ──────────────────────────────────────────
223// Faithful port of fdasrvf's DP algorithm (dp_grid.cpp / dp_nbhd.cpp).
224
/// Greatest common divisor via the (iterative) Euclidean algorithm.
#[cfg(test)]
fn gcd(a: usize, b: usize) -> usize {
    let (mut x, mut y) = (a, b);
    while y != 0 {
        let remainder = x % y;
        x = y;
        y = remainder;
    }
    x
}
234
/// Build the coprime neighborhood: every `(i, j)` with `1 ≤ i, j ≤ nbhd_dim`
/// and `gcd(i, j) = 1`, in row-major order (i outer, j inner).
/// For `nbhd_dim = 7` this yields the 35 pairs of fdasrvf's default.
#[cfg(test)]
fn generate_coprime_nbhd(nbhd_dim: usize) -> Vec<(usize, usize)> {
    (1..=nbhd_dim)
        .flat_map(|i| (1..=nbhd_dim).map(move |j| (i, j)))
        .filter(|&(i, j)| gcd(i, j) == 1)
        .collect()
}
249
/// Pre-computed coprime neighborhood for nbhd_dim = 7 (fdasrvf default).
///
/// Every `(dr, dc)` with `1 ≤ dr, dc ≤ 7` and `gcd(dr, dc) = 1`, listed in
/// row-major order (dr outer, dc inner). The order matters: the DP relaxation
/// keeps the first neighbor that reaches a strictly lower cost, so ties are
/// broken by position in this table.
/// `dr` = row delta (q2 direction), `dc` = column delta (q1 direction).
#[rustfmt::skip]
const COPRIME_NBHD_7: [(usize, usize); 35] = [
    // dr = 1: every dc
    (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7),
    // dr = 2: odd dc only
    (2, 1), (2, 3), (2, 5), (2, 7),
    // dr = 3: dc not divisible by 3
    (3, 1), (3, 2), (3, 4), (3, 5), (3, 7),
    // dr = 4: odd dc only
    (4, 1), (4, 3), (4, 5), (4, 7),
    // dr = 5: every dc except 5
    (5, 1), (5, 2), (5, 3), (5, 4), (5, 6), (5, 7),
    // dr = 6: dc coprime to both 2 and 3
    (6, 1), (6, 5), (6, 7),
    // dr = 7: every dc except 7
    (7, 1), (7, 2), (7, 3), (7, 4), (7, 5), (7, 6),
];
263
/// Compute the edge weight for a move from grid point (sr, sc) to (tr, tc).
///
/// Port of fdasrvf's `dp_edge_weight` for 1-D curves on a shared uniform grid.
/// - Rows index q2, columns index q1 (fdasrvf convention).
/// - The segment slope is
///   `γ' = (argvals[tr] - argvals[sr]) / (argvals[tc] - argvals[sc])`.
/// - The column span and the row span are cut into equal index fractions
///   (`1/(tc-sc)` and `1/(tr-sr)`). Walking the merged breakpoints adds
///   `(q1[col] - √γ' · q2[row])² · Δ` for each piece. The total is then
///   scaled by the column span `argvals[tc] - argvals[sc]`.
///   The index-fraction walk matches fdasrvf's time-based walk only when
///   `argvals` is uniformly spaced.
///
/// Returns `f64::INFINITY` for a degenerate move, i.e. a zero column or row
/// delta. Callers must pass `sc ≤ tc` and `sr ≤ tr`.
#[inline]
fn dp_edge_weight(
    q1: &[f64],
    q2: &[f64],
    argvals: &[f64],
    sc: usize,
    tc: usize,
    sr: usize,
    tr: usize,
) -> f64 {
    let cols = tc - sc;
    let rows = tr - sr;
    if cols == 0 || rows == 0 {
        return f64::INFINITY;
    }

    let col_span = argvals[tc] - argvals[sc];
    let root_slope = ((argvals[tr] - argvals[sr]) / col_span).sqrt();
    let (cols_f, rows_f) = (cols as f64, rows as f64);

    let mut acc = 0.0;
    let (mut c, mut r) = (0usize, 0usize);
    while c < cols && r < rows {
        let c_end = (c + 1) as f64 / cols_f;
        let r_end = (r + 1) as f64 / rows_f;
        let piece = c_end.min(r_end) - (c as f64 / cols_f).max(r as f64 / rows_f);

        if piece > 0.0 {
            let resid = q1[sc + c] - root_slope * q2[sr + r];
            acc += resid * resid * piece;
        }

        // Step past whichever sub-interval closes first (both on a tie).
        match c_end.partial_cmp(&r_end) {
            Some(std::cmp::Ordering::Less) => c += 1,
            Some(std::cmp::Ordering::Greater) => r += 1,
            _ => {
                c += 1;
                r += 1;
            }
        }
    }

    acc * col_span
}
325
/// Compute the `λ·(slope−1)²·dt` penalty for a DP edge.
///
/// `dt` is the column span `argvals[tc] - argvals[sc]` and `slope` is the row
/// span divided by it. Any `lambda` that is not strictly positive (including
/// NaN) disables the penalty and yields 0.
#[inline]
fn dp_lambda_penalty(
    argvals: &[f64],
    sc: usize,
    tc: usize,
    sr: usize,
    tr: usize,
    lambda: f64,
) -> f64 {
    let active = lambda > 0.0;
    if !active {
        return 0.0;
    }
    let step = argvals[tc] - argvals[sc];
    let rise = argvals[tr] - argvals[sr];
    let excess = rise / step - 1.0;
    lambda * excess.powi(2) * step
}
344
/// Follow parent pointers from the bottom-right cell back toward `(0, 0)`.
///
/// `parent[k]` is the flat (row-major) index of the predecessor of cell `k`,
/// or `u32::MAX` when `k` has none. The walk stops at cell 0 or at the first
/// missing predecessor. Returns the visited cells as `(row, col)` pairs,
/// ordered from start to end.
fn dp_traceback(parent: &[u32], nrows: usize, ncols: usize) -> Vec<(usize, usize)> {
    let mut cells = Vec::with_capacity(nrows + ncols);
    let mut at = (nrows - 1) * ncols + (ncols - 1);
    cells.push((at / ncols, at % ncols));

    while at != 0 {
        let prev = parent[at];
        if prev == u32::MAX {
            break;
        }
        at = prev as usize;
        cells.push((at / ncols, at % ncols));
    }

    cells.reverse();
    cells
}
361
362/// Try to relax cell `(tr, tc)` from each coprime neighbor, updating cost and parent.
363#[inline]
364fn dp_relax_cell<F>(
365    e: &mut [f64],
366    parent: &mut [u32],
367    ncols: usize,
368    tr: usize,
369    tc: usize,
370    edge_cost: &F,
371) where
372    F: Fn(usize, usize, usize, usize) -> f64,
373{
374    let idx = tr * ncols + tc;
375    for &(dr, dc) in &COPRIME_NBHD_7 {
376        if dr > tr || dc > tc {
377            continue;
378        }
379        let sr = tr - dr;
380        let sc = tc - dc;
381        let src_idx = sr * ncols + sc;
382        if e[src_idx] == f64::INFINITY {
383            continue;
384        }
385        let cost = e[src_idx] + edge_cost(sr, sc, tr, tc);
386        if cost < e[idx] {
387            e[idx] = cost;
388            parent[idx] = src_idx as u32;
389        }
390    }
391}
392
393/// Shared DP grid fill + traceback using the coprime neighborhood.
394///
395/// `edge_cost(sr, sc, tr, tc)` returns the combined edge weight + penalty for
396/// a move from local (sr, sc) to local (tr, tc). Returns the raw local-index
397/// path from (0,0) to (nrows-1, ncols-1).
398fn dp_grid_solve<F>(nrows: usize, ncols: usize, edge_cost: F) -> Vec<(usize, usize)>
399where
400    F: Fn(usize, usize, usize, usize) -> f64,
401{
402    let mut e = vec![f64::INFINITY; nrows * ncols];
403    let mut parent = vec![u32::MAX; nrows * ncols];
404    e[0] = 0.0;
405
406    for tr in 0..nrows {
407        for tc in 0..ncols {
408            if tr == 0 && tc == 0 {
409                continue;
410            }
411            dp_relax_cell(&mut e, &mut parent, ncols, tr, tc, &edge_cost);
412        }
413    }
414
415    dp_traceback(&parent, nrows, ncols)
416}
417
418/// Convert a DP path (local row,col indices) to an interpolated+normalized gamma warp.
419fn dp_path_to_gamma(path: &[(usize, usize)], argvals: &[f64]) -> Vec<f64> {
420    let path_tc: Vec<f64> = path.iter().map(|&(_, c)| argvals[c]).collect();
421    let path_tr: Vec<f64> = path.iter().map(|&(r, _)| argvals[r]).collect();
422    let mut gamma: Vec<f64> = argvals
423        .iter()
424        .map(|&t| linear_interp(&path_tc, &path_tr, t))
425        .collect();
426    normalize_warp(&mut gamma, argvals);
427    gamma
428}
429
430/// Core DP alignment between two SRSFs on a grid.
431///
432/// Finds the optimal warping γ minimizing ‖q₁ - (q₂∘γ)√γ'‖².
433/// Uses fdasrvf's coprime neighborhood (nbhd_dim=7 → 35 move directions).
434/// SRSFs are L2-normalized before alignment (matching fdasrvf's `optimum.reparam`).
435fn dp_alignment_core(q1: &[f64], q2: &[f64], argvals: &[f64], lambda: f64) -> Vec<f64> {
436    let m = argvals.len();
437    if m < 2 {
438        return argvals.to_vec();
439    }
440
441    let norm1 = q1.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
442    let norm2 = q2.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
443    let q1n: Vec<f64> = q1.iter().map(|&v| v / norm1).collect();
444    let q2n: Vec<f64> = q2.iter().map(|&v| v / norm2).collect();
445
446    let path = dp_grid_solve(m, m, |sr, sc, tr, tc| {
447        dp_edge_weight(&q1n, &q2n, argvals, sc, tc, sr, tr)
448            + dp_lambda_penalty(argvals, sc, tc, sr, tr, lambda)
449    });
450
451    dp_path_to_gamma(&path, argvals)
452}
453
454// ─── Public Alignment Functions ─────────────────────────────────────────────
455
456/// Align curve f2 to curve f1 using the elastic framework.
457///
458/// Computes the optimal warping γ such that f2∘γ is as close as possible
459/// to f1 in the elastic (Fisher-Rao) metric.
460///
461/// # Arguments
462/// * `f1` — Target curve (length m)
463/// * `f2` — Curve to align (length m)
464/// * `argvals` — Evaluation points (length m)
465/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
466///
467/// # Returns
468/// [`AlignmentResult`] with warping function, aligned curve, and elastic distance.
469pub fn elastic_align_pair(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> AlignmentResult {
470    let m = f1.len();
471
472    // Build single-row FdMatrices for SRSF computation
473    let f1_mat = FdMatrix::from_slice(f1, 1, m).unwrap();
474    let f2_mat = FdMatrix::from_slice(f2, 1, m).unwrap();
475
476    let q1_mat = srsf_transform(&f1_mat, argvals);
477    let q2_mat = srsf_transform(&f2_mat, argvals);
478
479    let q1: Vec<f64> = q1_mat.row(0);
480    let q2: Vec<f64> = q2_mat.row(0);
481
482    // Find optimal warping via DP
483    let gamma = dp_alignment_core(&q1, &q2, argvals, lambda);
484
485    // Apply warping to f2
486    let f_aligned = reparameterize_curve(f2, argvals, &gamma);
487
488    // Compute elastic distance: L2 distance between q1 and aligned q2 SRSF
489    let f_aligned_mat = FdMatrix::from_slice(&f_aligned, 1, m).unwrap();
490    let q_aligned_mat = srsf_transform(&f_aligned_mat, argvals);
491    let q_aligned: Vec<f64> = q_aligned_mat.row(0);
492
493    let weights = simpsons_weights(argvals);
494    let distance = l2_distance(&q1, &q_aligned, &weights);
495
496    AlignmentResult {
497        gamma,
498        f_aligned,
499        distance,
500    }
501}
502
503/// Compute the elastic distance between two curves.
504///
505/// This is shorthand for aligning the pair and returning only the distance.
506///
507/// # Arguments
508/// * `f1` — First curve (length m)
509/// * `f2` — Second curve (length m)
510/// * `argvals` — Evaluation points (length m)
511/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
512pub fn elastic_distance(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> f64 {
513    elastic_align_pair(f1, f2, argvals, lambda).distance
514}
515
516/// Align all curves in `data` to a single target curve.
517///
518/// # Arguments
519/// * `data` — Functional data matrix (n × m)
520/// * `target` — Target curve to align to (length m)
521/// * `argvals` — Evaluation points (length m)
522/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
523///
524/// # Returns
525/// [`AlignmentSetResult`] with all warping functions, aligned curves, and distances.
526pub fn align_to_target(
527    data: &FdMatrix,
528    target: &[f64],
529    argvals: &[f64],
530    lambda: f64,
531) -> AlignmentSetResult {
532    let (n, m) = data.shape();
533
534    let results: Vec<AlignmentResult> = iter_maybe_parallel!(0..n)
535        .map(|i| {
536            let fi = data.row(i);
537            elastic_align_pair(target, &fi, argvals, lambda)
538        })
539        .collect();
540
541    let mut gammas = FdMatrix::zeros(n, m);
542    let mut aligned_data = FdMatrix::zeros(n, m);
543    let mut distances = Vec::with_capacity(n);
544
545    for (i, r) in results.into_iter().enumerate() {
546        for j in 0..m {
547            gammas[(i, j)] = r.gamma[j];
548            aligned_data[(i, j)] = r.f_aligned[j];
549        }
550        distances.push(r.distance);
551    }
552
553    AlignmentSetResult {
554        gammas,
555        aligned_data,
556        distances,
557    }
558}
559
560// ─── Distance Matrices ──────────────────────────────────────────────────────
561
562/// Compute the symmetric elastic distance matrix for a set of curves.
563///
564/// Uses upper-triangle computation with parallelism, following the
565/// `self_distance_matrix` pattern from `metric.rs`.
566///
567/// # Arguments
568/// * `data` — Functional data matrix (n × m)
569/// * `argvals` — Evaluation points (length m)
570/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
571///
572/// # Returns
573/// Symmetric n × n distance matrix.
574pub fn elastic_self_distance_matrix(data: &FdMatrix, argvals: &[f64], lambda: f64) -> FdMatrix {
575    let n = data.nrows();
576
577    let upper_vals: Vec<f64> = iter_maybe_parallel!(0..n)
578        .flat_map(|i| {
579            let fi = data.row(i);
580            ((i + 1)..n)
581                .map(|j| {
582                    let fj = data.row(j);
583                    elastic_distance(&fi, &fj, argvals, lambda)
584                })
585                .collect::<Vec<_>>()
586        })
587        .collect();
588
589    let mut dist = FdMatrix::zeros(n, n);
590    let mut idx = 0;
591    for i in 0..n {
592        for j in (i + 1)..n {
593            let d = upper_vals[idx];
594            dist[(i, j)] = d;
595            dist[(j, i)] = d;
596            idx += 1;
597        }
598    }
599    dist
600}
601
602/// Compute the elastic distance matrix between two sets of curves.
603///
604/// # Arguments
605/// * `data1` — First dataset (n1 × m)
606/// * `data2` — Second dataset (n2 × m)
607/// * `argvals` — Evaluation points (length m)
608/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
609///
610/// # Returns
611/// n1 × n2 distance matrix.
612pub fn elastic_cross_distance_matrix(
613    data1: &FdMatrix,
614    data2: &FdMatrix,
615    argvals: &[f64],
616    lambda: f64,
617) -> FdMatrix {
618    let n1 = data1.nrows();
619    let n2 = data2.nrows();
620
621    let vals: Vec<f64> = iter_maybe_parallel!(0..n1)
622        .flat_map(|i| {
623            let fi = data1.row(i);
624            (0..n2)
625                .map(|j| {
626                    let fj = data2.row(j);
627                    elastic_distance(&fi, &fj, argvals, lambda)
628                })
629                .collect::<Vec<_>>()
630        })
631        .collect();
632
633    let mut dist = FdMatrix::zeros(n1, n2);
634    for i in 0..n1 {
635        for j in 0..n2 {
636            dist[(i, j)] = vals[i * n2 + j];
637        }
638    }
639    dist
640}
641
642// ─── Phase-Amplitude Decomposition ──────────────────────────────────────────
643
/// Result of elastic phase-amplitude decomposition.
///
/// Returned by [`elastic_decomposition`].
#[derive(Debug, Clone)]
pub struct DecompositionResult {
    /// Full alignment result of the second curve against the first.
    pub alignment: AlignmentResult,
    /// Amplitude distance: SRSF distance after alignment
    /// (the same value as `alignment.distance`).
    pub d_amplitude: f64,
    /// Phase distance: geodesic distance of the optimal warp from identity,
    /// as computed by `crate::warping::phase_distance`.
    pub d_phase: f64,
}
654
655/// Perform elastic phase-amplitude decomposition of two curves.
656///
657/// Returns both the alignment result and the separate amplitude and phase distances.
658///
659/// # Arguments
660/// * `f1` — Target curve (length m)
661/// * `f2` — Curve to decompose against f1 (length m)
662/// * `argvals` — Evaluation points (length m)
663/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
664pub fn elastic_decomposition(
665    f1: &[f64],
666    f2: &[f64],
667    argvals: &[f64],
668    lambda: f64,
669) -> DecompositionResult {
670    let alignment = elastic_align_pair(f1, f2, argvals, lambda);
671    let d_amplitude = alignment.distance;
672    let d_phase = crate::warping::phase_distance(&alignment.gamma, argvals);
673    DecompositionResult {
674        alignment,
675        d_amplitude,
676        d_phase,
677    }
678}
679
680/// Compute the amplitude distance between two curves (= elastic distance after alignment).
681pub fn amplitude_distance(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> f64 {
682    elastic_distance(f1, f2, argvals, lambda)
683}
684
685/// Compute the phase distance between two curves (geodesic distance of optimal warp from identity).
686pub fn phase_distance_pair(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> f64 {
687    let alignment = elastic_align_pair(f1, f2, argvals, lambda);
688    crate::warping::phase_distance(&alignment.gamma, argvals)
689}
690
691/// Compute the symmetric phase distance matrix for a set of curves.
692pub fn phase_self_distance_matrix(data: &FdMatrix, argvals: &[f64], lambda: f64) -> FdMatrix {
693    let n = data.nrows();
694
695    let upper_vals: Vec<f64> = iter_maybe_parallel!(0..n)
696        .flat_map(|i| {
697            let fi = data.row(i);
698            ((i + 1)..n)
699                .map(|j| {
700                    let fj = data.row(j);
701                    phase_distance_pair(&fi, &fj, argvals, lambda)
702                })
703                .collect::<Vec<_>>()
704        })
705        .collect();
706
707    let mut dist = FdMatrix::zeros(n, n);
708    let mut idx = 0;
709    for i in 0..n {
710        for j in (i + 1)..n {
711            let d = upper_vals[idx];
712            dist[(i, j)] = d;
713            dist[(j, i)] = d;
714            idx += 1;
715        }
716    }
717    dist
718}
719
720/// Compute the symmetric amplitude distance matrix (= elastic self distance matrix).
721pub fn amplitude_self_distance_matrix(data: &FdMatrix, argvals: &[f64], lambda: f64) -> FdMatrix {
722    elastic_self_distance_matrix(data, argvals, lambda)
723}
724
725// ─── Karcher Mean ───────────────────────────────────────────────────────────
726
/// Relative change between successive mean SRSFs.
///
/// Returns `‖q_new - q_old‖₂ / ‖q_old‖₂` using plain (unweighted) discrete
/// Euclidean norms. The denominator is floored at 1e-10. If the slices
/// differ in length, only the common prefix enters the numerator.
fn relative_change(q_old: &[f64], q_new: &[f64]) -> f64 {
    let squared_diffs = q_old
        .iter()
        .zip(q_new.iter())
        .map(|(&a, &b)| (a - b).powi(2));
    let numerator = squared_diffs.sum::<f64>().sqrt();
    let denominator = q_old.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
    numerator / denominator
}
741
742/// Compute a single SRSF from a slice (single-row convenience).
743fn srsf_single(f: &[f64], argvals: &[f64]) -> Vec<f64> {
744    let m = f.len();
745    let mat = FdMatrix::from_slice(f, 1, m).unwrap();
746    let q_mat = srsf_transform(&mat, argvals);
747    q_mat.row(0)
748}
749
750/// Align a single SRSF q2 to q1 and return (gamma, aligned_q).
751fn align_srsf_pair(q1: &[f64], q2: &[f64], argvals: &[f64], lambda: f64) -> (Vec<f64>, Vec<f64>) {
752    let gamma = dp_alignment_core(q1, q2, argvals, lambda);
753
754    // Warp q2 by gamma and adjust by sqrt(gamma')
755    let q2_warped = reparameterize_curve(q2, argvals, &gamma);
756
757    // Compute gamma' via finite differences
758    let m = gamma.len();
759    let mut gamma_dot = vec![0.0; m];
760    gamma_dot[0] = (gamma[1] - gamma[0]) / (argvals[1] - argvals[0]);
761    for j in 1..(m - 1) {
762        gamma_dot[j] = (gamma[j + 1] - gamma[j - 1]) / (argvals[j + 1] - argvals[j - 1]);
763    }
764    gamma_dot[m - 1] = (gamma[m - 1] - gamma[m - 2]) / (argvals[m - 1] - argvals[m - 2]);
765
766    // q2_aligned = (q2 ∘ γ) * sqrt(γ')
767    let q2_aligned: Vec<f64> = q2_warped
768        .iter()
769        .zip(gamma_dot.iter())
770        .map(|(&q, &gd)| q * gd.max(0.0).sqrt())
771        .collect();
772
773    (gamma, q2_aligned)
774}
775
776/// Compute the Karcher (Fréchet) mean in the elastic metric.
777///
778/// Iteratively aligns all curves to the current mean estimate in SRSF space,
779/// computes the pointwise mean of aligned SRSFs, and reconstructs the mean curve.
780///
781/// # Arguments
782/// * `data` — Functional data matrix (n × m)
783/// * `argvals` — Evaluation points (length m)
784/// * `max_iter` — Maximum number of iterations
785/// * `tol` — Convergence tolerance for the SRSF mean
786///
787/// # Returns
788/// [`KarcherMeanResult`] with mean curve, warping functions, aligned data, and convergence info.
789///
790/// # Examples
791///
792/// ```
793/// use fdars_core::simulation::{sim_fundata, EFunType, EValType};
794/// use fdars_core::alignment::karcher_mean;
795///
796/// let t: Vec<f64> = (0..50).map(|i| i as f64 / 49.0).collect();
797/// let data = sim_fundata(20, &t, 3, EFunType::Fourier, EValType::Exponential, Some(42));
798///
799/// let result = karcher_mean(&data, &t, 20, 1e-4, 0.0);
800/// assert_eq!(result.mean.len(), 50);
801/// assert!(result.n_iter <= 20);
802/// ```
803/// Accumulate alignment results: store gammas and return the mean of aligned SRSFs.
804fn accumulate_alignments(
805    results: &[(Vec<f64>, Vec<f64>)],
806    gammas: &mut FdMatrix,
807    m: usize,
808    n: usize,
809) -> Vec<f64> {
810    let mut mu_q_new = vec![0.0; m];
811    for (i, (gamma, q_aligned)) in results.iter().enumerate() {
812        for j in 0..m {
813            gammas[(i, j)] = gamma[j];
814            mu_q_new[j] += q_aligned[j];
815        }
816    }
817    for j in 0..m {
818        mu_q_new[j] /= n as f64;
819    }
820    mu_q_new
821}
822
823/// Apply stored warps to original curves to produce aligned data.
824fn apply_stored_warps(data: &FdMatrix, gammas: &FdMatrix, argvals: &[f64]) -> FdMatrix {
825    let (n, m) = data.shape();
826    let mut aligned = FdMatrix::zeros(n, m);
827    for i in 0..n {
828        let fi = data.row(i);
829        let gamma: Vec<f64> = (0..m).map(|j| gammas[(i, j)]).collect();
830        let f_aligned = reparameterize_curve(&fi, argvals, &gamma);
831        for j in 0..m {
832            aligned[(i, j)] = f_aligned[j];
833        }
834    }
835    aligned
836}
837
838/// Select the SRSF closest to the pointwise mean as template. Returns (mu_q, mu_f).
839fn select_template(srsf_mat: &FdMatrix, data: &FdMatrix, argvals: &[f64]) -> (Vec<f64>, Vec<f64>) {
840    let (n, m) = srsf_mat.shape();
841    let mnq = mean_1d(srsf_mat);
842    let mut min_dist = f64::INFINITY;
843    let mut min_idx = 0;
844    for i in 0..n {
845        let dist_sq: f64 = (0..m).map(|j| (srsf_mat[(i, j)] - mnq[j]).powi(2)).sum();
846        if dist_sq < min_dist {
847            min_dist = dist_sq;
848            min_idx = i;
849        }
850    }
851    let _ = argvals; // kept for API consistency
852    (srsf_mat.row(min_idx), data.row(min_idx))
853}
854
855/// Pre-centering: align all curves to template, compute inverse mean warp, re-center.
856fn pre_center_template(
857    data: &FdMatrix,
858    mu_q: &[f64],
859    mu: &[f64],
860    argvals: &[f64],
861    lambda: f64,
862) -> (Vec<f64>, Vec<f64>) {
863    let (n, m) = data.shape();
864    let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
865        .map(|i| {
866            let fi = data.row(i);
867            let qi = srsf_single(&fi, argvals);
868            align_srsf_pair(mu_q, &qi, argvals, lambda)
869        })
870        .collect();
871
872    let mut init_gammas = FdMatrix::zeros(n, m);
873    for (i, (gamma, _)) in align_results.iter().enumerate() {
874        for j in 0..m {
875            init_gammas[(i, j)] = gamma[j];
876        }
877    }
878
879    let gam_inv = sqrt_mean_inverse(&init_gammas, argvals);
880    let mu_new = reparameterize_curve(mu, argvals, &gam_inv);
881    let mu_q_new = srsf_single(&mu_new, argvals);
882    (mu_q_new, mu_new)
883}
884
885/// Post-convergence centering: center mean SRSF and warps via SqrtMeanInverse.
886fn post_center_results(
887    data: &FdMatrix,
888    mu_q: &[f64],
889    final_gammas: &mut FdMatrix,
890    argvals: &[f64],
891) -> (Vec<f64>, Vec<f64>, FdMatrix) {
892    let (n, m) = data.shape();
893    let gam_inv = sqrt_mean_inverse(final_gammas, argvals);
894    let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
895    let gam_inv_dev = gradient_uniform(&gam_inv, h);
896
897    let mu_q_warped = reparameterize_curve(mu_q, argvals, &gam_inv);
898    let mu_q_centered: Vec<f64> = mu_q_warped
899        .iter()
900        .zip(gam_inv_dev.iter())
901        .map(|(&q, &gd)| q * gd.max(0.0).sqrt())
902        .collect();
903
904    for i in 0..n {
905        let gam_i: Vec<f64> = (0..m).map(|j| final_gammas[(i, j)]).collect();
906        let gam_centered = reparameterize_curve(&gam_i, argvals, &gam_inv);
907        for j in 0..m {
908            final_gammas[(i, j)] = gam_centered[j];
909        }
910    }
911
912    let initial_mean = mean_1d(data);
913    let mu = srsf_inverse(&mu_q_centered, argvals, initial_mean[0]);
914    let final_aligned = apply_stored_warps(data, final_gammas, argvals);
915    (mu, mu_q_centered, final_aligned)
916}
917
918pub fn karcher_mean(
919    data: &FdMatrix,
920    argvals: &[f64],
921    max_iter: usize,
922    tol: f64,
923    lambda: f64,
924) -> KarcherMeanResult {
925    let (n, m) = data.shape();
926
927    let srsf_mat = srsf_transform(data, argvals);
928    let (mut mu_q, mu) = select_template(&srsf_mat, data, argvals);
929    let (mu_q_c, mu_c) = pre_center_template(data, &mu_q, &mu, argvals, lambda);
930    mu_q = mu_q_c;
931    let mut mu = mu_c;
932
933    let mut converged = false;
934    let mut n_iter = 0;
935    let mut final_gammas = FdMatrix::zeros(n, m);
936    let mut prev_rel = 0.0_f64;
937
938    for iter in 0..max_iter {
939        n_iter = iter + 1;
940
941        let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
942            .map(|i| {
943                let fi = data.row(i);
944                let qi = srsf_single(&fi, argvals);
945                align_srsf_pair(&mu_q, &qi, argvals, lambda)
946            })
947            .collect();
948
949        let mu_q_new = accumulate_alignments(&align_results, &mut final_gammas, m, n);
950
951        let rel = relative_change(&mu_q, &mu_q_new);
952        if rel < f64::EPSILON || (iter > 0 && rel - prev_rel <= tol * prev_rel) {
953            converged = true;
954            mu_q = mu_q_new;
955            break;
956        }
957        prev_rel = rel;
958
959        mu_q = mu_q_new;
960        mu = srsf_inverse(&mu_q, argvals, mu[0]);
961    }
962
963    let (mu_final, mu_q_final, final_aligned) =
964        post_center_results(data, &mu_q, &mut final_gammas, argvals);
965
966    KarcherMeanResult {
967        mean: mu_final,
968        mean_srsf: mu_q_final,
969        gammas: final_gammas,
970        aligned_data: final_aligned,
971        n_iter,
972        converged,
973    }
974}
975
976// ─── TSRVF (Transported SRSF) ────────────────────────────────────────────────
977// Maps aligned SRSFs to the tangent space of the Karcher mean on the Hilbert
978// sphere. Tangent vectors live in a standard Euclidean space, enabling PCA,
979// regression, and clustering on elastic-aligned curves.
980
/// Result of the TSRVF (transported SRSF) transform.
///
/// Produced by [`tsrvf_transform`], [`tsrvf_from_alignment`] and their
/// `_with_method` variants; consumed by [`tsrvf_inverse`].
#[derive(Debug, Clone)]
pub struct TsrvfResult {
    /// Tangent vectors in Euclidean space (n × m): one row per curve, each a
    /// tangent vector at the unit-normalized mean SRSF.
    pub tangent_vectors: FdMatrix,
    /// Karcher mean curve (length m).
    pub mean: Vec<f64>,
    /// SRSF of the Karcher mean (length m).
    pub mean_srsf: Vec<f64>,
    /// L2 norm of the mean SRSF, computed on a uniform grid over \[0, 1\]
    /// (not on `argvals`); used to project the mean onto the unit sphere.
    pub mean_srsf_norm: f64,
    /// Per-curve aligned SRSF norms (length n), computed on the same \[0, 1\]
    /// grid; the inverse transform rescales by these.
    pub srsf_norms: Vec<f64>,
    /// Warping functions from Karcher mean computation (n × m).
    pub gammas: FdMatrix,
    /// Whether the Karcher mean iteration reported convergence.
    pub converged: bool,
}
999
1000/// Full TSRVF pipeline: compute Karcher mean, then transport SRSFs to tangent space.
1001///
1002/// # Arguments
1003/// * `data` — Functional data matrix (n × m)
1004/// * `argvals` — Evaluation points (length m)
1005/// * `max_iter` — Maximum Karcher mean iterations
1006/// * `tol` — Convergence tolerance for Karcher mean
1007///
1008/// # Returns
1009/// [`TsrvfResult`] containing tangent vectors and associated metadata.
1010pub fn tsrvf_transform(
1011    data: &FdMatrix,
1012    argvals: &[f64],
1013    max_iter: usize,
1014    tol: f64,
1015    lambda: f64,
1016) -> TsrvfResult {
1017    let karcher = karcher_mean(data, argvals, max_iter, tol, lambda);
1018    tsrvf_from_alignment(&karcher, argvals)
1019}
1020
1021/// Compute TSRVF from a pre-computed Karcher mean alignment.
1022///
1023/// Avoids re-running the expensive Karcher mean computation when the alignment
1024/// has already been computed.
1025///
1026/// # Arguments
1027/// * `karcher` — Pre-computed Karcher mean result
1028/// * `argvals` — Evaluation points (length m)
1029///
1030/// # Returns
1031/// [`TsrvfResult`] containing tangent vectors and associated metadata.
1032pub fn tsrvf_from_alignment(karcher: &KarcherMeanResult, argvals: &[f64]) -> TsrvfResult {
1033    let (n, m) = karcher.aligned_data.shape();
1034
1035    // Step 1: Compute SRSFs of aligned data
1036    let aligned_srsf = srsf_transform(&karcher.aligned_data, argvals);
1037
1038    // Step 2: Normalize mean SRSF to unit sphere
1039    // Work on [0,1] for sphere geometry
1040    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
1041    let mean_norm = l2_norm_l2(&karcher.mean_srsf, &time);
1042
1043    let mu_unit: Vec<f64> = if mean_norm > 1e-10 {
1044        karcher.mean_srsf.iter().map(|&q| q / mean_norm).collect()
1045    } else {
1046        vec![0.0; m]
1047    };
1048
1049    // Step 3: For each aligned curve, compute tangent vector via inverse exponential map
1050    let srsf_norms: Vec<f64> = iter_maybe_parallel!(0..n)
1051        .map(|i| {
1052            let qi = aligned_srsf.row(i);
1053            l2_norm_l2(&qi, &time)
1054        })
1055        .collect();
1056
1057    let tangent_data: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
1058        .map(|i| {
1059            let qi = aligned_srsf.row(i);
1060            let qi_norm = srsf_norms[i];
1061
1062            if qi_norm < 1e-10 || mean_norm < 1e-10 {
1063                return vec![0.0; m];
1064            }
1065
1066            // Normalize to unit sphere
1067            let qi_unit: Vec<f64> = qi.iter().map(|&q| q / qi_norm).collect();
1068
1069            // Shooting vector from mu_unit to qi_unit
1070            inv_exp_map_sphere(&mu_unit, &qi_unit, &time)
1071        })
1072        .collect();
1073
1074    // Assemble tangent vectors into FdMatrix
1075    let mut tangent_vectors = FdMatrix::zeros(n, m);
1076    for i in 0..n {
1077        for j in 0..m {
1078            tangent_vectors[(i, j)] = tangent_data[i][j];
1079        }
1080    }
1081
1082    TsrvfResult {
1083        tangent_vectors,
1084        mean: karcher.mean.clone(),
1085        mean_srsf: karcher.mean_srsf.clone(),
1086        mean_srsf_norm: mean_norm,
1087        srsf_norms,
1088        gammas: karcher.gammas.clone(),
1089        converged: karcher.converged,
1090    }
1091}
1092
1093/// Reconstruct aligned curves from TSRVF tangent vectors.
1094///
1095/// Inverts the TSRVF transform: maps tangent vectors back to the Hilbert sphere
1096/// via the exponential map, rescales, and reconstructs curves via SRSF inverse.
1097///
1098/// # Arguments
1099/// * `tsrvf` — TSRVF result from [`tsrvf_transform`] or [`tsrvf_from_alignment`]
1100/// * `argvals` — Evaluation points (length m)
1101///
1102/// # Returns
1103/// FdMatrix of reconstructed aligned curves (n × m).
1104pub fn tsrvf_inverse(tsrvf: &TsrvfResult, argvals: &[f64]) -> FdMatrix {
1105    let (n, m) = tsrvf.tangent_vectors.shape();
1106
1107    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
1108
1109    // Normalize mean SRSF to unit sphere
1110    let mu_unit: Vec<f64> = if tsrvf.mean_srsf_norm > 1e-10 {
1111        tsrvf
1112            .mean_srsf
1113            .iter()
1114            .map(|&q| q / tsrvf.mean_srsf_norm)
1115            .collect()
1116    } else {
1117        vec![0.0; m]
1118    };
1119
1120    let curves: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
1121        .map(|i| {
1122            let vi = tsrvf.tangent_vectors.row(i);
1123
1124            // Map back to sphere: exp_map(mu_unit, v_i)
1125            let qi_unit = exp_map_sphere(&mu_unit, &vi, &time);
1126
1127            // Rescale by original norm
1128            let qi: Vec<f64> = qi_unit.iter().map(|&q| q * tsrvf.srsf_norms[i]).collect();
1129
1130            // Reconstruct curve from SRSF
1131            srsf_inverse(&qi, argvals, tsrvf.mean[0])
1132        })
1133        .collect();
1134
1135    let mut result = FdMatrix::zeros(n, m);
1136    for i in 0..n {
1137        for j in 0..m {
1138            result[(i, j)] = curves[i][j];
1139        }
1140    }
1141    result
1142}
1143
1144// ─── Parallel Transport Variants ─────────────────────────────────────────────
1145
/// Method for transporting tangent vectors on the Hilbert sphere.
///
/// Selects how [`tsrvf_from_alignment_with_method`] moves each unit-normalized
/// aligned SRSF into the tangent space at the unit-normalized mean SRSF.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TransportMethod {
    /// Inverse exponential map (log map) — default, matches existing TSRVF behavior.
    ///
    /// Handled by delegating to [`tsrvf_from_alignment`].
    #[default]
    LogMap,
    /// Schild's ladder approximation to parallel transport.
    ///
    /// Transports `-log_qi(mu)` from each curve's SRSF `qi` to the mean `mu`.
    SchildsLadder,
    /// Pole ladder approximation to parallel transport.
    ///
    /// Transports the same vector between the same points as `SchildsLadder`.
    PoleLadder,
}
1157
1158/// Schild's ladder parallel transport of vector `v` from `from` to `to` on the sphere.
1159fn parallel_transport_schilds(v: &[f64], from: &[f64], to: &[f64], time: &[f64]) -> Vec<f64> {
1160    use crate::warping::{exp_map_sphere, inv_exp_map_sphere};
1161
1162    let v_norm = crate::warping::l2_norm_l2(v, time);
1163    if v_norm < 1e-10 {
1164        return vec![0.0; v.len()];
1165    }
1166
1167    // endpoint = exp_from(v)
1168    let endpoint = exp_map_sphere(from, v, time);
1169
1170    // midpoint_v = log_to(endpoint) — vector at `to` pointing toward endpoint
1171    let log_to_ep = inv_exp_map_sphere(to, &endpoint, time);
1172
1173    // midpoint = exp_to(0.5 * log_to_ep)
1174    let half_log: Vec<f64> = log_to_ep.iter().map(|&x| 0.5 * x).collect();
1175    let midpoint = exp_map_sphere(to, &half_log, time);
1176
1177    // transported = 2 * log_to(midpoint)
1178    let log_to_mid = inv_exp_map_sphere(to, &midpoint, time);
1179    log_to_mid.iter().map(|&x| 2.0 * x).collect()
1180}
1181
1182/// Pole ladder parallel transport of vector `v` from `from` to `to` on the sphere.
1183fn parallel_transport_pole(v: &[f64], from: &[f64], to: &[f64], time: &[f64]) -> Vec<f64> {
1184    use crate::warping::{exp_map_sphere, inv_exp_map_sphere};
1185
1186    let v_norm = crate::warping::l2_norm_l2(v, time);
1187    if v_norm < 1e-10 {
1188        return vec![0.0; v.len()];
1189    }
1190
1191    // pole = exp_from(-v)
1192    let neg_v: Vec<f64> = v.iter().map(|&x| -x).collect();
1193    let pole = exp_map_sphere(from, &neg_v, time);
1194
1195    // midpoint_v = log_to(pole)
1196    let log_to_pole = inv_exp_map_sphere(to, &pole, time);
1197
1198    // midpoint = exp_to(0.5 * log_to_pole)
1199    let half_log: Vec<f64> = log_to_pole.iter().map(|&x| 0.5 * x).collect();
1200    let midpoint = exp_map_sphere(to, &half_log, time);
1201
1202    // transported = -2 * log_to(midpoint)
1203    let log_to_mid = inv_exp_map_sphere(to, &midpoint, time);
1204    log_to_mid.iter().map(|&x| -2.0 * x).collect()
1205}
1206
1207/// Full TSRVF pipeline with configurable transport method.
1208///
1209/// Like [`tsrvf_transform`] but allows choosing the parallel transport method.
1210pub fn tsrvf_transform_with_method(
1211    data: &FdMatrix,
1212    argvals: &[f64],
1213    max_iter: usize,
1214    tol: f64,
1215    lambda: f64,
1216    method: TransportMethod,
1217) -> TsrvfResult {
1218    let karcher = karcher_mean(data, argvals, max_iter, tol, lambda);
1219    tsrvf_from_alignment_with_method(&karcher, argvals, method)
1220}
1221
1222/// Compute TSRVF from a pre-computed Karcher mean with configurable transport.
1223///
1224/// - [`TransportMethod::LogMap`]: Uses `inv_exp_map(mu, qi)` directly (standard TSRVF).
1225/// - [`TransportMethod::SchildsLadder`]: Computes `v = -log_qi(mu)`, then transports
1226///   via Schild's ladder from qi to mu.
1227/// - [`TransportMethod::PoleLadder`]: Same but via pole ladder.
1228pub fn tsrvf_from_alignment_with_method(
1229    karcher: &KarcherMeanResult,
1230    argvals: &[f64],
1231    method: TransportMethod,
1232) -> TsrvfResult {
1233    if method == TransportMethod::LogMap {
1234        return tsrvf_from_alignment(karcher, argvals);
1235    }
1236
1237    let (n, m) = karcher.aligned_data.shape();
1238    let aligned_srsf = srsf_transform(&karcher.aligned_data, argvals);
1239    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
1240    let mean_norm = crate::warping::l2_norm_l2(&karcher.mean_srsf, &time);
1241
1242    let mu_unit: Vec<f64> = if mean_norm > 1e-10 {
1243        karcher.mean_srsf.iter().map(|&q| q / mean_norm).collect()
1244    } else {
1245        vec![0.0; m]
1246    };
1247
1248    let srsf_norms: Vec<f64> = iter_maybe_parallel!(0..n)
1249        .map(|i| {
1250            let qi = aligned_srsf.row(i);
1251            crate::warping::l2_norm_l2(&qi, &time)
1252        })
1253        .collect();
1254
1255    let tangent_data: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
1256        .map(|i| {
1257            let qi = aligned_srsf.row(i);
1258            let qi_norm = srsf_norms[i];
1259
1260            if qi_norm < 1e-10 || mean_norm < 1e-10 {
1261                return vec![0.0; m];
1262            }
1263
1264            let qi_unit: Vec<f64> = qi.iter().map(|&q| q / qi_norm).collect();
1265
1266            // Compute v = -log_qi(mu) — vector at qi pointing away from mu
1267            let v_at_qi = inv_exp_map_sphere(&qi_unit, &mu_unit, &time);
1268            let neg_v: Vec<f64> = v_at_qi.iter().map(|&x| -x).collect();
1269
1270            // Transport from qi to mu
1271            match method {
1272                TransportMethod::SchildsLadder => {
1273                    parallel_transport_schilds(&neg_v, &qi_unit, &mu_unit, &time)
1274                }
1275                TransportMethod::PoleLadder => {
1276                    parallel_transport_pole(&neg_v, &qi_unit, &mu_unit, &time)
1277                }
1278                TransportMethod::LogMap => unreachable!(),
1279            }
1280        })
1281        .collect();
1282
1283    let mut tangent_vectors = FdMatrix::zeros(n, m);
1284    for i in 0..n {
1285        for j in 0..m {
1286            tangent_vectors[(i, j)] = tangent_data[i][j];
1287        }
1288    }
1289
1290    TsrvfResult {
1291        tangent_vectors,
1292        mean: karcher.mean.clone(),
1293        mean_srsf: karcher.mean_srsf.clone(),
1294        mean_srsf_norm: mean_norm,
1295        srsf_norms,
1296        gammas: karcher.gammas.clone(),
1297        converged: karcher.converged,
1298    }
1299}
1300
1301// ─── Alignment Quality Metrics ───────────────────────────────────────────────
1302
/// Comprehensive alignment quality assessment, as produced by
/// [`alignment_quality`].
#[derive(Debug, Clone)]
pub struct AlignmentQuality {
    /// Per-curve geodesic distance from warp to identity (see [`warp_complexity`]).
    pub warp_complexity: Vec<f64>,
    /// Mean warp complexity over all curves.
    pub mean_warp_complexity: f64,
    /// Per-curve bending energy ∫(γ'')² dt (see [`warp_smoothness`]).
    pub warp_smoothness: Vec<f64>,
    /// Mean warp smoothness (bending energy) over all curves.
    pub mean_warp_smoothness: f64,
    /// Total variance: (1/n) Σ ∫(f_i - mean_orig)² dt (Simpson-weighted L2).
    pub total_variance: f64,
    /// Amplitude variance: (1/n) Σ ∫(f_i^aligned - mean_aligned)² dt.
    pub amplitude_variance: f64,
    /// Phase variance: total - amplitude (clamped ≥ 0).
    pub phase_variance: f64,
    /// Phase-to-total variance ratio, `phase_variance / total_variance`
    /// (0 when the total variance is ≤ 1e-10). Despite the field name, the
    /// denominator is the total variance, not the amplitude variance.
    pub phase_amplitude_ratio: f64,
    /// Pointwise ratio: aligned_var / orig_var per time point (population
    /// variances; 1.0 where the original variance is ≤ 1e-15).
    pub pointwise_variance_ratio: Vec<f64>,
    /// Mean of `pointwise_variance_ratio`. Values below 1 indicate that
    /// alignment reduced the cross-sectional variance.
    pub mean_variance_reduction: f64,
}
1327
1328/// Compute warp complexity: geodesic distance from a warp to the identity.
1329///
1330/// This is `arccos(⟨ψ, ψ_id⟩)` on the Hilbert sphere.
1331pub fn warp_complexity(gamma: &[f64], argvals: &[f64]) -> f64 {
1332    crate::warping::phase_distance(gamma, argvals)
1333}
1334
1335/// Compute warp smoothness (bending energy): ∫(γ'')² dt.
1336pub fn warp_smoothness(gamma: &[f64], argvals: &[f64]) -> f64 {
1337    let m = gamma.len();
1338    if m < 3 {
1339        return 0.0;
1340    }
1341
1342    let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
1343    let gam_prime = gradient_uniform(gamma, h);
1344    let gam_pprime = gradient_uniform(&gam_prime, h);
1345
1346    let integrand: Vec<f64> = gam_pprime.iter().map(|&g| g * g).collect();
1347    crate::helpers::trapz(&integrand, argvals)
1348}
1349
1350/// Compute comprehensive alignment quality metrics.
1351///
1352/// # Arguments
1353/// * `data` — Original functional data (n × m)
1354/// * `karcher` — Pre-computed Karcher mean result
1355/// * `argvals` — Evaluation points (length m)
1356pub fn alignment_quality(
1357    data: &FdMatrix,
1358    karcher: &KarcherMeanResult,
1359    argvals: &[f64],
1360) -> AlignmentQuality {
1361    let (n, m) = data.shape();
1362    let weights = simpsons_weights(argvals);
1363
1364    // Per-curve warp complexity and smoothness
1365    let wc: Vec<f64> = (0..n)
1366        .map(|i| {
1367            let gamma: Vec<f64> = (0..m).map(|j| karcher.gammas[(i, j)]).collect();
1368            warp_complexity(&gamma, argvals)
1369        })
1370        .collect();
1371    let ws: Vec<f64> = (0..n)
1372        .map(|i| {
1373            let gamma: Vec<f64> = (0..m).map(|j| karcher.gammas[(i, j)]).collect();
1374            warp_smoothness(&gamma, argvals)
1375        })
1376        .collect();
1377
1378    let mean_wc = wc.iter().sum::<f64>() / n as f64;
1379    let mean_ws = ws.iter().sum::<f64>() / n as f64;
1380
1381    // Compute original mean
1382    let orig_mean = crate::fdata::mean_1d(data);
1383
1384    // Total variance
1385    let total_var: f64 = (0..n)
1386        .map(|i| {
1387            let fi = data.row(i);
1388            let d = l2_distance(&fi, &orig_mean, &weights);
1389            d * d
1390        })
1391        .sum::<f64>()
1392        / n as f64;
1393
1394    // Aligned mean
1395    let aligned_mean = crate::fdata::mean_1d(&karcher.aligned_data);
1396
1397    // Amplitude variance
1398    let amp_var: f64 = (0..n)
1399        .map(|i| {
1400            let fi = karcher.aligned_data.row(i);
1401            let d = l2_distance(&fi, &aligned_mean, &weights);
1402            d * d
1403        })
1404        .sum::<f64>()
1405        / n as f64;
1406
1407    let phase_var = (total_var - amp_var).max(0.0);
1408    let ratio = if total_var > 1e-10 {
1409        phase_var / total_var
1410    } else {
1411        0.0
1412    };
1413
1414    // Pointwise variance ratio
1415    let mut pw_ratio = vec![0.0; m];
1416    for j in 0..m {
1417        let col_orig = data.column(j);
1418        let mean_orig = col_orig.iter().sum::<f64>() / n as f64;
1419        let var_orig: f64 = col_orig
1420            .iter()
1421            .map(|&v| (v - mean_orig).powi(2))
1422            .sum::<f64>()
1423            / n as f64;
1424
1425        let col_aligned = karcher.aligned_data.column(j);
1426        let mean_aligned = col_aligned.iter().sum::<f64>() / n as f64;
1427        let var_aligned: f64 = col_aligned
1428            .iter()
1429            .map(|&v| (v - mean_aligned).powi(2))
1430            .sum::<f64>()
1431            / n as f64;
1432
1433        pw_ratio[j] = if var_orig > 1e-15 {
1434            var_aligned / var_orig
1435        } else {
1436            1.0
1437        };
1438    }
1439
1440    let mean_vr = pw_ratio.iter().sum::<f64>() / m as f64;
1441
1442    AlignmentQuality {
1443        warp_complexity: wc,
1444        mean_warp_complexity: mean_wc,
1445        warp_smoothness: ws,
1446        mean_warp_smoothness: mean_ws,
1447        total_variance: total_var,
1448        amplitude_variance: amp_var,
1449        phase_variance: phase_var,
1450        phase_amplitude_ratio: ratio,
1451        pointwise_variance_ratio: pw_ratio,
1452        mean_variance_reduction: mean_vr,
1453    }
1454}
1455
/// Generate triplet indices `(i, j, k)` with `i < j < k < n` in lexicographic
/// order, capped at `max_triplets` (0 = all).
///
/// The iterator already yields exactly C(n, 3) items, so no explicit count is
/// needed for the cap. Computing `n * (n - 1) * (n - 2)` would underflow for
/// `n < 2` and overflow for very large `n`.
fn triplet_indices(n: usize, max_triplets: usize) -> Vec<(usize, usize, usize)> {
    let cap = if max_triplets > 0 {
        max_triplets
    } else {
        usize::MAX
    };
    (0..n)
        .flat_map(|i| ((i + 1)..n).flat_map(move |j| ((j + 1)..n).map(move |k| (i, j, k))))
        .take(cap)
        .collect()
}
1469
1470/// Compute the warp deviation for one triplet: ‖γ_ij∘γ_jk − γ_ik‖_L2.
1471fn triplet_warp_deviation(
1472    data: &FdMatrix,
1473    argvals: &[f64],
1474    weights: &[f64],
1475    i: usize,
1476    j: usize,
1477    k: usize,
1478    lambda: f64,
1479) -> f64 {
1480    let fi = data.row(i);
1481    let fj = data.row(j);
1482    let fk = data.row(k);
1483    let rij = elastic_align_pair(&fi, &fj, argvals, lambda);
1484    let rjk = elastic_align_pair(&fj, &fk, argvals, lambda);
1485    let rik = elastic_align_pair(&fi, &fk, argvals, lambda);
1486    let composed = compose_warps(&rij.gamma, &rjk.gamma, argvals);
1487    l2_distance(&composed, &rik.gamma, weights)
1488}
1489
1490/// Measure pairwise alignment consistency via triplet checks.
1491///
1492/// For triplets (i,j,k), checks `γ_ij ∘ γ_jk ≈ γ_ik` by measuring the L2
1493/// deviation of the composed warp from the direct warp.
1494///
1495/// # Arguments
1496/// * `data` — Functional data (n × m)
1497/// * `argvals` — Evaluation points (length m)
1498/// * `lambda` — Penalty weight
1499/// * `max_triplets` — Maximum number of triplets to check (0 = all)
1500pub fn pairwise_consistency(
1501    data: &FdMatrix,
1502    argvals: &[f64],
1503    lambda: f64,
1504    max_triplets: usize,
1505) -> f64 {
1506    let n = data.nrows();
1507    if n < 3 {
1508        return 0.0;
1509    }
1510
1511    let weights = simpsons_weights(argvals);
1512    let triplets = triplet_indices(n, max_triplets);
1513    if triplets.is_empty() {
1514        return 0.0;
1515    }
1516
1517    let total_dev: f64 = triplets
1518        .iter()
1519        .map(|&(i, j, k)| triplet_warp_deviation(data, argvals, &weights, i, j, k, lambda))
1520        .sum();
1521    total_dev / triplets.len() as f64
1522}
1523
1524// ─── Landmark-Constrained Elastic Alignment ────────────────────────────────
1525
/// Result of landmark-constrained elastic alignment.
///
/// Returned by [`elastic_align_pair_constrained`] and
/// [`elastic_align_pair_with_landmarks`].
#[derive(Debug, Clone)]
pub struct ConstrainedAlignmentResult {
    /// Optimal warping function (length m).
    pub gamma: Vec<f64>,
    /// Aligned curve f2∘γ (length m).
    pub f_aligned: Vec<f64>,
    /// Elastic distance after alignment. On the constrained path this is the
    /// Simpson-weighted L2 distance between the SRSF of the target and the
    /// SRSF of the aligned curve.
    pub distance: f64,
    /// Enforced landmark pairs (snapped to grid): `(target_t, source_t)`.
    ///
    /// Landmarks discarded while building the DP waypoints (for example, ones
    /// not increasing relative to the previous landmark) are not listed.
    /// Empty when no landmarks were supplied.
    pub enforced_landmarks: Vec<(f64, f64)>,
}
1538
/// Index of the grid point in `argvals` closest to `t_val`.
///
/// On ties the earliest index wins. Panics if `argvals` is empty.
fn snap_to_grid(t_val: f64, argvals: &[f64]) -> usize {
    let seed = (0usize, (t_val - argvals[0]).abs());
    argvals
        .iter()
        .enumerate()
        .skip(1)
        .fold(seed, |(best_idx, best_gap), (idx, &a)| {
            let gap = (t_val - a).abs();
            if gap < best_gap {
                (idx, gap)
            } else {
                (best_idx, best_gap)
            }
        })
        .0
}
1552
/// Run DP on a rectangular sub-grid `[sc..=ec] × [sr..=er]`.
///
/// Uses global indices for `dp_edge_weight`. Returns the path segment
/// as a list of `(tc_idx, tr_idx)` pairs from start to end.
///
/// Index convention (matching how waypoints are built from landmark pairs):
/// * `sc`/`ec` index the target (`q1`) axis;
/// * `sr`/`er` index the source (`q2`) axis.
///
/// A degenerate sub-grid (a single column or a single row) skips the DP and
/// returns just its two corners `(sc, sr)` and `(ec, er)`.
fn dp_segment(
    q1: &[f64],
    q2: &[f64],
    argvals: &[f64],
    sc: usize,
    ec: usize,
    sr: usize,
    er: usize,
    lambda: f64,
) -> Vec<(usize, usize)> {
    // Sub-grid extent; requires `ec >= sc` and `er >= sr` (usize would underflow).
    let nc = ec - sc + 1;
    let nr = er - sr + 1;

    if nc <= 1 || nr <= 1 {
        return vec![(sc, sr), (ec, er)];
    }

    // `dp_grid_solve` works on an nr × nc grid in local coordinates. The closure
    // receives local (row, col) indices of an edge's start and end (per the
    // parameter names) and scores it at the corresponding global indices.
    let path = dp_grid_solve(nr, nc, |local_sr, local_sc, local_tr, local_tc| {
        let gsr = sr + local_sr;
        let gsc = sc + local_sc;
        let gtr = sr + local_tr;
        let gtc = sc + local_tc;
        dp_edge_weight(q1, q2, argvals, gsc, gtc, gsr, gtr)
            + dp_lambda_penalty(argvals, gsc, gtc, gsr, gtr, lambda)
    });

    // Convert local indices to global
    // (local (row, col) -> global (col, row) = (tc_idx, tr_idx)).
    path.iter().map(|&(lr, lc)| (sc + lc, sr + lr)).collect()
}
1586
1587/// Align f2 to f1 with landmark constraints.
1588///
1589/// Landmark pairs define waypoints on the DP grid. Between consecutive waypoints,
1590/// an independent smaller DP is run. The resulting warp passes through all landmarks.
1591///
1592/// # Arguments
1593/// * `f1` — Target curve (length m)
1594/// * `f2` — Curve to align (length m)
1595/// * `argvals` — Evaluation points (length m)
1596/// * `landmark_pairs` — `(target_t, source_t)` pairs in increasing order
1597/// * `lambda` — Penalty weight
1598///
1599/// # Returns
1600/// [`ConstrainedAlignmentResult`] with warp, aligned curve, and enforced landmarks.
1601/// Build DP waypoints from landmark pairs: snap to grid, deduplicate, add endpoints.
1602fn build_constrained_waypoints(
1603    landmark_pairs: &[(f64, f64)],
1604    argvals: &[f64],
1605    m: usize,
1606) -> Vec<(usize, usize)> {
1607    let mut waypoints: Vec<(usize, usize)> = Vec::with_capacity(landmark_pairs.len() + 2);
1608    waypoints.push((0, 0));
1609    for &(tt, st) in landmark_pairs {
1610        let tc = snap_to_grid(tt, argvals);
1611        let tr = snap_to_grid(st, argvals);
1612        if let Some(&(prev_c, prev_r)) = waypoints.last() {
1613            if tc > prev_c && tr > prev_r {
1614                waypoints.push((tc, tr));
1615            }
1616        }
1617    }
1618    let last = m - 1;
1619    if let Some(&(prev_c, prev_r)) = waypoints.last() {
1620        if prev_c != last || prev_r != last {
1621            waypoints.push((last, last));
1622        }
1623    }
1624    waypoints
1625}
1626
1627/// Run DP segments between consecutive waypoints and assemble into a gamma warp.
1628fn segmented_dp_gamma(
1629    q1n: &[f64],
1630    q2n: &[f64],
1631    argvals: &[f64],
1632    waypoints: &[(usize, usize)],
1633    lambda: f64,
1634) -> Vec<f64> {
1635    let mut full_path_tc: Vec<f64> = Vec::new();
1636    let mut full_path_tr: Vec<f64> = Vec::new();
1637
1638    for seg in 0..(waypoints.len() - 1) {
1639        let (sc, sr) = waypoints[seg];
1640        let (ec, er) = waypoints[seg + 1];
1641        let segment_path = dp_segment(q1n, q2n, argvals, sc, ec, sr, er, lambda);
1642        let start = if seg > 0 { 1 } else { 0 };
1643        for &(tc, tr) in &segment_path[start..] {
1644            full_path_tc.push(argvals[tc]);
1645            full_path_tr.push(argvals[tr]);
1646        }
1647    }
1648
1649    let mut gamma: Vec<f64> = argvals
1650        .iter()
1651        .map(|&t| linear_interp(&full_path_tc, &full_path_tr, t))
1652        .collect();
1653    normalize_warp(&mut gamma, argvals);
1654    gamma
1655}
1656
1657pub fn elastic_align_pair_constrained(
1658    f1: &[f64],
1659    f2: &[f64],
1660    argvals: &[f64],
1661    landmark_pairs: &[(f64, f64)],
1662    lambda: f64,
1663) -> ConstrainedAlignmentResult {
1664    let m = f1.len();
1665
1666    if landmark_pairs.is_empty() {
1667        let r = elastic_align_pair(f1, f2, argvals, lambda);
1668        return ConstrainedAlignmentResult {
1669            gamma: r.gamma,
1670            f_aligned: r.f_aligned,
1671            distance: r.distance,
1672            enforced_landmarks: Vec::new(),
1673        };
1674    }
1675
1676    // Compute & normalize SRSFs
1677    let f1_mat = FdMatrix::from_slice(f1, 1, m).unwrap();
1678    let f2_mat = FdMatrix::from_slice(f2, 1, m).unwrap();
1679    let q1_mat = srsf_transform(&f1_mat, argvals);
1680    let q2_mat = srsf_transform(&f2_mat, argvals);
1681    let q1: Vec<f64> = q1_mat.row(0);
1682    let q2: Vec<f64> = q2_mat.row(0);
1683    let norm1 = q1.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1684    let norm2 = q2.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1685    let q1n: Vec<f64> = q1.iter().map(|&v| v / norm1).collect();
1686    let q2n: Vec<f64> = q2.iter().map(|&v| v / norm2).collect();
1687
1688    let waypoints = build_constrained_waypoints(landmark_pairs, argvals, m);
1689    let gamma = segmented_dp_gamma(&q1n, &q2n, argvals, &waypoints, lambda);
1690
1691    let f_aligned = reparameterize_curve(f2, argvals, &gamma);
1692    let f_aligned_mat = FdMatrix::from_slice(&f_aligned, 1, m).unwrap();
1693    let q_aligned_mat = srsf_transform(&f_aligned_mat, argvals);
1694    let q_aligned: Vec<f64> = q_aligned_mat.row(0);
1695    let weights = simpsons_weights(argvals);
1696    let distance = l2_distance(&q1, &q_aligned, &weights);
1697
1698    let enforced: Vec<(f64, f64)> = waypoints[1..waypoints.len() - 1]
1699        .iter()
1700        .map(|&(tc, tr)| (argvals[tc], argvals[tr]))
1701        .collect();
1702
1703    ConstrainedAlignmentResult {
1704        gamma,
1705        f_aligned,
1706        distance,
1707        enforced_landmarks: enforced,
1708    }
1709}
1710
1711/// Align f2 to f1 with automatic landmark detection and elastic constraints.
1712///
1713/// Detects landmarks in both curves, matches them, and uses the matches
1714/// as constraints for segmented DP alignment.
1715///
1716/// # Arguments
1717/// * `f1` — Target curve (length m)
1718/// * `f2` — Curve to align (length m)
1719/// * `argvals` — Evaluation points (length m)
1720/// * `kind` — Type of landmarks to detect
1721/// * `min_prominence` — Minimum prominence for landmark detection
1722/// * `expected_count` — Expected number of landmarks (0 = all detected)
1723/// * `lambda` — Penalty weight
1724pub fn elastic_align_pair_with_landmarks(
1725    f1: &[f64],
1726    f2: &[f64],
1727    argvals: &[f64],
1728    kind: crate::landmark::LandmarkKind,
1729    min_prominence: f64,
1730    expected_count: usize,
1731    lambda: f64,
1732) -> ConstrainedAlignmentResult {
1733    let lm1 = crate::landmark::detect_landmarks(f1, argvals, kind, min_prominence);
1734    let lm2 = crate::landmark::detect_landmarks(f2, argvals, kind, min_prominence);
1735
1736    // Match landmarks by order (take min count)
1737    let n_match = if expected_count > 0 {
1738        expected_count.min(lm1.len()).min(lm2.len())
1739    } else {
1740        lm1.len().min(lm2.len())
1741    };
1742
1743    let pairs: Vec<(f64, f64)> = (0..n_match)
1744        .map(|i| (lm1[i].position, lm2[i].position))
1745        .collect();
1746
1747    elastic_align_pair_constrained(f1, f2, argvals, &pairs, lambda)
1748}
1749
1750// ─── Multidimensional SRSF (R^d curves) ────────────────────────────────────
1751
1752use crate::matrix::FdCurveSet;
1753
/// Result of aligning multidimensional (R^d) curves.
#[derive(Debug, Clone)]
pub struct AlignmentResultNd {
    /// Optimal warping function (length m), shared by all d coordinates.
    pub gamma: Vec<f64>,
    /// Aligned curve: d vectors (one per coordinate), each of length m.
    pub f_aligned: Vec<Vec<f64>>,
    /// Elastic distance after alignment.
    pub distance: f64,
}
1764
1765/// Compute the SRSF transform for multidimensional (R^d) curves.
1766///
1767/// For f: \[0,1\] → R^d, the SRSF is q(t) = f'(t) / √‖f'(t)‖ where ‖·‖ is the
1768/// Euclidean norm in R^d. For d=1 this reduces to `sign(f') · √|f'|`.
1769///
1770/// # Arguments
1771/// * `data` — Set of n curves in R^d, each with m evaluation points
1772/// * `argvals` — Evaluation points (length m)
1773///
1774/// # Returns
1775/// `FdCurveSet` of SRSF values with the same shape as input.
1776/// Scale derivative vector at one point by 1/√‖f'‖, writing into result_dims.
1777#[inline]
1778fn srsf_scale_point(derivs: &[FdMatrix], result_dims: &mut [FdMatrix], i: usize, j: usize) {
1779    let d = derivs.len();
1780    let norm_sq: f64 = derivs.iter().map(|dd| dd[(i, j)].powi(2)).sum();
1781    let norm = norm_sq.sqrt();
1782    if norm < 1e-15 {
1783        for k in 0..d {
1784            result_dims[k][(i, j)] = 0.0;
1785        }
1786    } else {
1787        let scale = 1.0 / norm.sqrt();
1788        for k in 0..d {
1789            result_dims[k][(i, j)] = derivs[k][(i, j)] * scale;
1790        }
1791    }
1792}
1793
1794pub fn srsf_transform_nd(data: &FdCurveSet, argvals: &[f64]) -> FdCurveSet {
1795    let d = data.ndim();
1796    let n = data.ncurves();
1797    let m = data.npoints();
1798
1799    if d == 0 || n == 0 || m == 0 || argvals.len() != m {
1800        return FdCurveSet {
1801            dims: (0..d).map(|_| FdMatrix::zeros(n, m)).collect(),
1802        };
1803    }
1804
1805    let derivs: Vec<FdMatrix> = data
1806        .dims
1807        .iter()
1808        .map(|dim_mat| crate::fdata::deriv_1d(dim_mat, argvals, 1))
1809        .collect();
1810
1811    let mut result_dims: Vec<FdMatrix> = (0..d).map(|_| FdMatrix::zeros(n, m)).collect();
1812    for i in 0..n {
1813        for j in 0..m {
1814            srsf_scale_point(&derivs, &mut result_dims, i, j);
1815        }
1816    }
1817
1818    FdCurveSet { dims: result_dims }
1819}
1820
1821/// Reconstruct an R^d curve from its SRSF.
1822///
1823/// Given d-dimensional SRSF vectors and initial point f0, reconstructs:
1824/// `f_k(t) = f0_k + ∫₀ᵗ q_k(s) · ‖q(s)‖ ds` for each dimension k.
1825///
1826/// # Arguments
1827/// * `q` — SRSF: d vectors, each length m
1828/// * `argvals` — Evaluation points (length m)
1829/// * `f0` — Initial values in R^d (length d)
1830///
1831/// # Returns
1832/// Reconstructed curve: d vectors, each length m.
1833pub fn srsf_inverse_nd(q: &[Vec<f64>], argvals: &[f64], f0: &[f64]) -> Vec<Vec<f64>> {
1834    let d = q.len();
1835    if d == 0 {
1836        return Vec::new();
1837    }
1838    let m = q[0].len();
1839    if m == 0 {
1840        return vec![Vec::new(); d];
1841    }
1842
1843    // Compute ||q(t)|| at each time point
1844    let norms: Vec<f64> = (0..m)
1845        .map(|j| {
1846            let norm_sq: f64 = q.iter().map(|qk| qk[j].powi(2)).sum();
1847            norm_sq.sqrt()
1848        })
1849        .collect();
1850
1851    // For each dimension, integrand = q_k(t) * ||q(t)||
1852    let mut result = Vec::with_capacity(d);
1853    for k in 0..d {
1854        let integrand: Vec<f64> = (0..m).map(|j| q[k][j] * norms[j]).collect();
1855        let integral = cumulative_trapz(&integrand, argvals);
1856        let curve: Vec<f64> = integral.iter().map(|&v| f0[k] + v).collect();
1857        result.push(curve);
1858    }
1859
1860    result
1861}
1862
1863/// Core DP alignment for R^d SRSFs.
1864///
1865/// Same DP grid and coprime neighborhood as `dp_alignment_core`, but edge weight
1866/// is the sum of `dp_edge_weight` over d dimensions.
1867fn dp_alignment_core_nd(
1868    q1: &[Vec<f64>],
1869    q2: &[Vec<f64>],
1870    argvals: &[f64],
1871    lambda: f64,
1872) -> Vec<f64> {
1873    let d = q1.len();
1874    let m = argvals.len();
1875    if m < 2 || d == 0 {
1876        return argvals.to_vec();
1877    }
1878
1879    // For d=1, delegate to existing implementation for exact backward compat
1880    if d == 1 {
1881        return dp_alignment_core(&q1[0], &q2[0], argvals, lambda);
1882    }
1883
1884    // Normalize each dimension's SRSF to unit L2 norm
1885    let q1n: Vec<Vec<f64>> = q1
1886        .iter()
1887        .map(|qk| {
1888            let norm = qk.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1889            qk.iter().map(|&v| v / norm).collect()
1890        })
1891        .collect();
1892    let q2n: Vec<Vec<f64>> = q2
1893        .iter()
1894        .map(|qk| {
1895            let norm = qk.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1896            qk.iter().map(|&v| v / norm).collect()
1897        })
1898        .collect();
1899
1900    let path = dp_grid_solve(m, m, |sr, sc, tr, tc| {
1901        let w: f64 = (0..d)
1902            .map(|k| dp_edge_weight(&q1n[k], &q2n[k], argvals, sc, tc, sr, tr))
1903            .sum();
1904        w + dp_lambda_penalty(argvals, sc, tc, sr, tr, lambda)
1905    });
1906
1907    dp_path_to_gamma(&path, argvals)
1908}
1909
1910/// Align an R^d curve f2 to f1 using the elastic framework.
1911///
1912/// Finds the optimal warping γ (shared across all dimensions) such that
1913/// f2∘γ is as close as possible to f1 in the elastic metric.
1914///
1915/// # Arguments
1916/// * `f1` — Target curves (d dimensions)
1917/// * `f2` — Curves to align (d dimensions)
1918/// * `argvals` — Evaluation points (length m)
1919/// * `lambda` — Penalty weight (0.0 = no penalty)
1920pub fn elastic_align_pair_nd(
1921    f1: &FdCurveSet,
1922    f2: &FdCurveSet,
1923    argvals: &[f64],
1924    lambda: f64,
1925) -> AlignmentResultNd {
1926    let d = f1.ndim();
1927    let m = f1.npoints();
1928
1929    // Compute SRSFs
1930    let q1_set = srsf_transform_nd(f1, argvals);
1931    let q2_set = srsf_transform_nd(f2, argvals);
1932
1933    // Extract first curve from each dimension
1934    let q1: Vec<Vec<f64>> = q1_set.dims.iter().map(|dm| dm.row(0)).collect();
1935    let q2: Vec<Vec<f64>> = q2_set.dims.iter().map(|dm| dm.row(0)).collect();
1936
1937    // DP alignment using summed cost over dimensions
1938    let gamma = dp_alignment_core_nd(&q1, &q2, argvals, lambda);
1939
1940    // Apply warping to f2 in each dimension
1941    let f_aligned: Vec<Vec<f64>> = f2
1942        .dims
1943        .iter()
1944        .map(|dm| {
1945            let row = dm.row(0);
1946            reparameterize_curve(&row, argvals, &gamma)
1947        })
1948        .collect();
1949
1950    // Compute elastic distance: sum of squared L2 distances between aligned SRSFs
1951    let f_aligned_set = {
1952        let dims: Vec<FdMatrix> = f_aligned
1953            .iter()
1954            .map(|fa| FdMatrix::from_slice(fa, 1, m).unwrap())
1955            .collect();
1956        FdCurveSet { dims }
1957    };
1958    let q_aligned = srsf_transform_nd(&f_aligned_set, argvals);
1959    let weights = simpsons_weights(argvals);
1960
1961    let mut dist_sq = 0.0;
1962    for k in 0..d {
1963        let q1k = q1_set.dims[k].row(0);
1964        let qak = q_aligned.dims[k].row(0);
1965        let d_k = l2_distance(&q1k, &qak, &weights);
1966        dist_sq += d_k * d_k;
1967    }
1968
1969    AlignmentResultNd {
1970        gamma,
1971        f_aligned,
1972        distance: dist_sq.sqrt(),
1973    }
1974}
1975
1976/// Elastic distance between two R^d curves.
1977///
1978/// Aligns f2 to f1 and returns the post-alignment SRSF distance.
1979pub fn elastic_distance_nd(f1: &FdCurveSet, f2: &FdCurveSet, argvals: &[f64], lambda: f64) -> f64 {
1980    elastic_align_pair_nd(f1, f2, argvals, lambda).distance
1981}
1982
1983// ─── Tests ──────────────────────────────────────────────────────────────────
1984
1985#[cfg(test)]
1986mod tests {
1987    use super::*;
1988    use crate::helpers::trapz;
1989    use crate::simulation::{sim_fundata, EFunType, EValType};
1990    use crate::warping::inner_product_l2;
1991
1992    fn uniform_grid(m: usize) -> Vec<f64> {
1993        (0..m).map(|i| i as f64 / (m - 1) as f64).collect()
1994    }
1995
1996    fn make_test_data(n: usize, m: usize, seed: u64) -> FdMatrix {
1997        let t = uniform_grid(m);
1998        sim_fundata(
1999            n,
2000            &t,
2001            3,
2002            EFunType::Fourier,
2003            EValType::Exponential,
2004            Some(seed),
2005        )
2006    }
2007
2008    // ── cumulative_trapz ──
2009
2010    #[test]
2011    fn test_cumulative_trapz_constant() {
2012        // ∫₀ᵗ 1 dt = t
2013        let x = uniform_grid(50);
2014        let y = vec![1.0; 50];
2015        let result = cumulative_trapz(&y, &x);
2016        assert!((result[0]).abs() < 1e-15, "cumulative_trapz(0) should be 0");
2017        for j in 1..50 {
2018            assert!(
2019                (result[j] - x[j]).abs() < 1e-12,
2020                "∫₀^{:.3} 1 dt should be {:.3}, got {:.3}",
2021                x[j],
2022                x[j],
2023                result[j]
2024            );
2025        }
2026    }
2027
2028    #[test]
2029    fn test_cumulative_trapz_linear() {
2030        // ∫₀ᵗ s ds = t²/2
2031        let m = 100;
2032        let x = uniform_grid(m);
2033        let y: Vec<f64> = x.clone();
2034        let result = cumulative_trapz(&y, &x);
2035        for j in 1..m {
2036            let expected = x[j] * x[j] / 2.0;
2037            assert!(
2038                (result[j] - expected).abs() < 1e-4,
2039                "∫₀^{:.3} s ds: expected {expected:.6}, got {:.6}",
2040                x[j],
2041                result[j]
2042            );
2043        }
2044    }
2045
2046    // ── normalize_warp ──
2047
2048    #[test]
2049    fn test_normalize_warp_fixes_boundaries() {
2050        let t = uniform_grid(10);
2051        let mut gamma = vec![0.1; 10]; // constant, wrong boundaries
2052        normalize_warp(&mut gamma, &t);
2053        assert_eq!(gamma[0], t[0]);
2054        assert_eq!(gamma[9], t[9]);
2055    }
2056
2057    #[test]
2058    fn test_normalize_warp_enforces_monotonicity() {
2059        let t = uniform_grid(5);
2060        let mut gamma = vec![0.0, 0.5, 0.3, 0.8, 1.0]; // non-monotone at index 2
2061        normalize_warp(&mut gamma, &t);
2062        for j in 1..5 {
2063            assert!(
2064                gamma[j] >= gamma[j - 1],
2065                "gamma should be monotone after normalization at j={j}"
2066            );
2067        }
2068    }
2069
2070    #[test]
2071    fn test_normalize_warp_identity_unchanged() {
2072        let t = uniform_grid(20);
2073        let mut gamma = t.clone();
2074        normalize_warp(&mut gamma, &t);
2075        for j in 0..20 {
2076            assert!(
2077                (gamma[j] - t[j]).abs() < 1e-15,
2078                "Identity warp should be unchanged"
2079            );
2080        }
2081    }
2082
2083    // ── linear_interp ──
2084
2085    #[test]
2086    fn test_linear_interp_at_nodes() {
2087        let x = vec![0.0, 1.0, 2.0, 3.0];
2088        let y = vec![0.0, 2.0, 4.0, 6.0];
2089        for i in 0..x.len() {
2090            assert!((linear_interp(&x, &y, x[i]) - y[i]).abs() < 1e-12);
2091        }
2092    }
2093
2094    #[test]
2095    fn test_linear_interp_midpoints() {
2096        let x = vec![0.0, 1.0, 2.0];
2097        let y = vec![0.0, 2.0, 4.0];
2098        assert!((linear_interp(&x, &y, 0.5) - 1.0).abs() < 1e-12);
2099        assert!((linear_interp(&x, &y, 1.5) - 3.0).abs() < 1e-12);
2100    }
2101
2102    #[test]
2103    fn test_linear_interp_clamp() {
2104        let x = vec![0.0, 1.0, 2.0];
2105        let y = vec![1.0, 3.0, 5.0];
2106        assert!((linear_interp(&x, &y, -1.0) - 1.0).abs() < 1e-12);
2107        assert!((linear_interp(&x, &y, 3.0) - 5.0).abs() < 1e-12);
2108    }
2109
2110    #[test]
2111    fn test_linear_interp_nonuniform_grid() {
2112        let x = vec![0.0, 0.1, 0.5, 1.0];
2113        let y = vec![0.0, 1.0, 5.0, 10.0];
2114        // Between 0.1 and 0.5: slope = (5-1)/(0.5-0.1) = 10
2115        let val = linear_interp(&x, &y, 0.3);
2116        let expected = 1.0 + 10.0 * (0.3 - 0.1);
2117        assert!(
2118            (val - expected).abs() < 1e-12,
2119            "Non-uniform interp: expected {expected}, got {val}"
2120        );
2121    }
2122
2123    #[test]
2124    fn test_linear_interp_two_points() {
2125        let x = vec![0.0, 1.0];
2126        let y = vec![3.0, 7.0];
2127        assert!((linear_interp(&x, &y, 0.25) - 4.0).abs() < 1e-12);
2128        assert!((linear_interp(&x, &y, 0.75) - 6.0).abs() < 1e-12);
2129    }
2130
2131    // ── SRSF transform ──
2132
2133    #[test]
2134    fn test_srsf_transform_linear() {
2135        // f(t) = 2t: derivative = 2, SRSF = sqrt(2)
2136        let m = 50;
2137        let t = uniform_grid(m);
2138        let f: Vec<f64> = t.iter().map(|&ti| 2.0 * ti).collect();
2139        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
2140
2141        let q_mat = srsf_transform(&mat, &t);
2142        let q: Vec<f64> = q_mat.row(0);
2143
2144        let expected = 2.0_f64.sqrt();
2145        // Interior points should be close to sqrt(2)
2146        for j in 2..(m - 2) {
2147            assert!(
2148                (q[j] - expected).abs() < 0.1,
2149                "q[{j}] = {}, expected ~{expected}",
2150                q[j]
2151            );
2152        }
2153    }
2154
2155    #[test]
2156    fn test_srsf_transform_preserves_shape() {
2157        let data = make_test_data(10, 50, 42);
2158        let t = uniform_grid(50);
2159        let q = srsf_transform(&data, &t);
2160        assert_eq!(q.shape(), data.shape());
2161    }
2162
2163    #[test]
2164    fn test_srsf_transform_constant_is_zero() {
2165        // f(t) = 5 (constant): derivative = 0, SRSF = 0
2166        let m = 30;
2167        let t = uniform_grid(m);
2168        let f = vec![5.0; m];
2169        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
2170        let q_mat = srsf_transform(&mat, &t);
2171        let q: Vec<f64> = q_mat.row(0);
2172
2173        for j in 0..m {
2174            assert!(
2175                q[j].abs() < 1e-10,
2176                "SRSF of constant should be 0, got q[{j}] = {}",
2177                q[j]
2178            );
2179        }
2180    }
2181
2182    #[test]
2183    fn test_srsf_transform_negative_slope() {
2184        // f(t) = -3t: derivative = -3, SRSF = -sqrt(3)
2185        let m = 50;
2186        let t = uniform_grid(m);
2187        let f: Vec<f64> = t.iter().map(|&ti| -3.0 * ti).collect();
2188        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
2189
2190        let q_mat = srsf_transform(&mat, &t);
2191        let q: Vec<f64> = q_mat.row(0);
2192
2193        let expected = -(3.0_f64.sqrt());
2194        for j in 2..(m - 2) {
2195            assert!(
2196                (q[j] - expected).abs() < 0.15,
2197                "q[{j}] = {}, expected ~{expected}",
2198                q[j]
2199            );
2200        }
2201    }
2202
2203    #[test]
2204    fn test_srsf_transform_empty_input() {
2205        let data = FdMatrix::zeros(0, 0);
2206        let t: Vec<f64> = vec![];
2207        let q = srsf_transform(&data, &t);
2208        assert_eq!(q.shape(), (0, 0));
2209    }
2210
2211    #[test]
2212    fn test_srsf_transform_multiple_curves() {
2213        let m = 40;
2214        let t = uniform_grid(m);
2215        let data = make_test_data(5, m, 42);
2216
2217        let q = srsf_transform(&data, &t);
2218        assert_eq!(q.shape(), (5, m));
2219
2220        // Each row should have finite values
2221        for i in 0..5 {
2222            for j in 0..m {
2223                assert!(q[(i, j)].is_finite(), "SRSF should be finite at ({i},{j})");
2224            }
2225        }
2226    }
2227
2228    // ── SRSF inverse ──
2229
2230    #[test]
2231    fn test_srsf_round_trip() {
2232        let m = 100;
2233        let t = uniform_grid(m);
2234        // Use a smooth function
2235        let f: Vec<f64> = t
2236            .iter()
2237            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin() + ti)
2238            .collect();
2239
2240        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
2241        let q_mat = srsf_transform(&mat, &t);
2242        let q: Vec<f64> = q_mat.row(0);
2243
2244        let f_recon = srsf_inverse(&q, &t, f[0]);
2245
2246        // Check reconstruction is close (interior points, avoid boundary effects)
2247        let max_err: f64 = f[5..(m - 5)]
2248            .iter()
2249            .zip(f_recon[5..(m - 5)].iter())
2250            .map(|(a, b)| (a - b).abs())
2251            .fold(0.0_f64, f64::max);
2252
2253        assert!(
2254            max_err < 0.15,
2255            "Round-trip error too large: max_err = {max_err}"
2256        );
2257    }
2258
2259    #[test]
2260    fn test_srsf_inverse_empty() {
2261        let q: Vec<f64> = vec![];
2262        let t: Vec<f64> = vec![];
2263        let result = srsf_inverse(&q, &t, 0.0);
2264        assert!(result.is_empty());
2265    }
2266
2267    #[test]
2268    fn test_srsf_inverse_preserves_initial_value() {
2269        let m = 50;
2270        let t = uniform_grid(m);
2271        let q = vec![1.0; m]; // constant SRSF
2272        let f0 = 3.15;
2273        let f = srsf_inverse(&q, &t, f0);
2274        assert!((f[0] - f0).abs() < 1e-12, "srsf_inverse should start at f0");
2275    }
2276
2277    #[test]
2278    fn test_srsf_round_trip_multiple_curves() {
2279        let m = 80;
2280        let t = uniform_grid(m);
2281        let data = make_test_data(5, m, 99);
2282
2283        let q_mat = srsf_transform(&data, &t);
2284
2285        for i in 0..5 {
2286            let fi = data.row(i);
2287            let qi = q_mat.row(i);
2288            let f_recon = srsf_inverse(&qi, &t, fi[0]);
2289            let max_err: f64 = fi[5..(m - 5)]
2290                .iter()
2291                .zip(f_recon[5..(m - 5)].iter())
2292                .map(|(a, b)| (a - b).abs())
2293                .fold(0.0_f64, f64::max);
2294            assert!(max_err < 0.3, "Round-trip curve {i}: max_err = {max_err}");
2295        }
2296    }
2297
2298    // ── Reparameterize ──
2299
2300    #[test]
2301    fn test_reparameterize_identity_warp() {
2302        let m = 50;
2303        let t = uniform_grid(m);
2304        let f: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
2305
2306        // Identity warp: γ(t) = t
2307        let result = reparameterize_curve(&f, &t, &t);
2308        for j in 0..m {
2309            assert!(
2310                (result[j] - f[j]).abs() < 1e-12,
2311                "Identity warp should return original at j={j}"
2312            );
2313        }
2314    }
2315
2316    #[test]
2317    fn test_reparameterize_linear_warp() {
2318        let m = 50;
2319        let t = uniform_grid(m);
2320        // f(t) = t (linear), γ(t) = t^2 (quadratic warp on [0,1])
2321        let f: Vec<f64> = t.clone();
2322        let gamma: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
2323
2324        let result = reparameterize_curve(&f, &t, &gamma);
2325
2326        // f(γ(t)) = γ(t) = t^2 for a linear f(t) = t
2327        for j in 0..m {
2328            assert!(
2329                (result[j] - gamma[j]).abs() < 1e-10,
2330                "f(gamma(t)) should be gamma(t) for f(t)=t at j={j}"
2331            );
2332        }
2333    }
2334
2335    #[test]
2336    fn test_reparameterize_sine_with_quadratic_warp() {
2337        let m = 100;
2338        let t = uniform_grid(m);
2339        let f: Vec<f64> = t
2340            .iter()
2341            .map(|&ti| (std::f64::consts::PI * ti).sin())
2342            .collect();
2343        let gamma: Vec<f64> = t.iter().map(|&ti| ti * ti).collect(); // speeds up start
2344
2345        let result = reparameterize_curve(&f, &t, &gamma);
2346
2347        // f(γ(t)) = sin(π * t²); check a few known values
2348        for j in 0..m {
2349            let expected = (std::f64::consts::PI * gamma[j]).sin();
2350            assert!(
2351                (result[j] - expected).abs() < 0.05,
2352                "sin(π γ(t)) at j={j}: expected {expected:.4}, got {:.4}",
2353                result[j]
2354            );
2355        }
2356    }
2357
2358    #[test]
2359    fn test_reparameterize_preserves_length() {
2360        let m = 50;
2361        let t = uniform_grid(m);
2362        let f = vec![1.0; m];
2363        let gamma: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
2364
2365        let result = reparameterize_curve(&f, &t, &gamma);
2366        assert_eq!(result.len(), m);
2367    }
2368
2369    // ── Compose warps ──
2370
2371    #[test]
2372    fn test_compose_warps_identity() {
2373        let m = 50;
2374        let t = uniform_grid(m);
2375        // γ(t) = t^0.5 (a warp on [0,1])
2376        let gamma: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
2377
2378        // identity ∘ γ = γ
2379        let result = compose_warps(&t, &gamma, &t);
2380        for j in 0..m {
2381            assert!(
2382                (result[j] - gamma[j]).abs() < 1e-10,
2383                "id ∘ γ should be γ at j={j}"
2384            );
2385        }
2386
2387        // γ ∘ identity = γ
2388        let result2 = compose_warps(&gamma, &t, &t);
2389        for j in 0..m {
2390            assert!(
2391                (result2[j] - gamma[j]).abs() < 1e-10,
2392                "γ ∘ id should be γ at j={j}"
2393            );
2394        }
2395    }
2396
2397    #[test]
2398    fn test_compose_warps_associativity() {
2399        // (γ₁ ∘ γ₂) ∘ γ₃ ≈ γ₁ ∘ (γ₂ ∘ γ₃)
2400        let m = 50;
2401        let t = uniform_grid(m);
2402        let g1: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
2403        let g2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
2404        let g3: Vec<f64> = t.iter().map(|&ti| 0.5 * ti + 0.5 * ti * ti).collect();
2405
2406        let g12 = compose_warps(&g1, &g2, &t);
2407        let left = compose_warps(&g12, &g3, &t); // (g1∘g2) ∘ g3
2408
2409        let g23 = compose_warps(&g2, &g3, &t);
2410        let right = compose_warps(&g1, &g23, &t); // g1 ∘ (g2∘g3)
2411
2412        for j in 0..m {
2413            assert!(
2414                (left[j] - right[j]).abs() < 0.05,
2415                "Composition should be roughly associative at j={j}: left={:.4}, right={:.4}",
2416                left[j],
2417                right[j]
2418            );
2419        }
2420    }
2421
2422    #[test]
2423    fn test_compose_warps_preserves_domain() {
2424        let m = 50;
2425        let t = uniform_grid(m);
2426        let g1: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
2427        let g2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
2428
2429        let composed = compose_warps(&g1, &g2, &t);
2430        assert!(
2431            (composed[0] - t[0]).abs() < 1e-10,
2432            "Composed warp should start at domain start"
2433        );
2434        assert!(
2435            (composed[m - 1] - t[m - 1]).abs() < 1e-10,
2436            "Composed warp should end at domain end"
2437        );
2438    }
2439
2440    // ── Elastic align pair ──
2441
2442    #[test]
2443    fn test_align_identical_curves() {
2444        let m = 50;
2445        let t = uniform_grid(m);
2446        let f: Vec<f64> = t
2447            .iter()
2448            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
2449            .collect();
2450
2451        let result = elastic_align_pair(&f, &f, &t, 0.0);
2452
2453        // Distance should be near zero
2454        assert!(
2455            result.distance < 0.1,
2456            "Distance between identical curves should be near 0, got {}",
2457            result.distance
2458        );
2459
2460        // Warp should be near identity
2461        for j in 0..m {
2462            assert!(
2463                (result.gamma[j] - t[j]).abs() < 0.1,
2464                "Warp should be near identity at j={j}: gamma={}, t={}",
2465                result.gamma[j],
2466                t[j]
2467            );
2468        }
2469    }
2470
2471    #[test]
2472    fn test_align_pair_valid_output() {
2473        let data = make_test_data(2, 50, 42);
2474        let t = uniform_grid(50);
2475        let f1 = data.row(0);
2476        let f2 = data.row(1);
2477
2478        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2479
2480        assert_eq!(result.gamma.len(), 50);
2481        assert_eq!(result.f_aligned.len(), 50);
2482        assert!(result.distance >= 0.0);
2483
2484        // Warp should be monotone
2485        for j in 1..50 {
2486            assert!(
2487                result.gamma[j] >= result.gamma[j - 1],
2488                "Warp should be monotone at j={j}"
2489            );
2490        }
2491    }
2492
2493    #[test]
2494    fn test_align_pair_warp_boundaries() {
2495        let data = make_test_data(2, 50, 42);
2496        let t = uniform_grid(50);
2497        let f1 = data.row(0);
2498        let f2 = data.row(1);
2499
2500        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2501        assert!(
2502            (result.gamma[0] - t[0]).abs() < 1e-12,
2503            "Warp should start at domain start"
2504        );
2505        assert!(
2506            (result.gamma[49] - t[49]).abs() < 1e-12,
2507            "Warp should end at domain end"
2508        );
2509    }
2510
2511    #[test]
2512    fn test_align_shifted_sine() {
2513        // Two sines with a phase shift — alignment should reduce distance
2514        let m = 80;
2515        let t = uniform_grid(m);
2516        let f1: Vec<f64> = t
2517            .iter()
2518            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
2519            .collect();
2520        let f2: Vec<f64> = t
2521            .iter()
2522            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
2523            .collect();
2524
2525        let weights = simpsons_weights(&t);
2526        let l2_before = l2_distance(&f1, &f2, &weights);
2527        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2528        let l2_after = l2_distance(&f1, &result.f_aligned, &weights);
2529
2530        assert!(
2531            l2_after < l2_before + 0.01,
2532            "Alignment should not increase L2 distance: before={l2_before:.4}, after={l2_after:.4}"
2533        );
2534    }
2535
2536    #[test]
2537    fn test_align_pair_aligned_curve_is_finite() {
2538        let data = make_test_data(2, 50, 77);
2539        let t = uniform_grid(50);
2540        let f1 = data.row(0);
2541        let f2 = data.row(1);
2542
2543        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2544        for j in 0..50 {
2545            assert!(
2546                result.f_aligned[j].is_finite(),
2547                "Aligned curve should be finite at j={j}"
2548            );
2549        }
2550    }
2551
2552    #[test]
2553    fn test_align_pair_minimum_grid() {
2554        // Minimum viable grid: m = 2
2555        let t = vec![0.0, 1.0];
2556        let f1 = vec![0.0, 1.0];
2557        let f2 = vec![0.0, 2.0];
2558        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2559        assert_eq!(result.gamma.len(), 2);
2560        assert_eq!(result.f_aligned.len(), 2);
2561        assert!(result.distance >= 0.0);
2562    }
2563
2564    // ── Elastic distance ──
2565
2566    #[test]
2567    fn test_elastic_distance_symmetric() {
2568        let data = make_test_data(3, 50, 42);
2569        let t = uniform_grid(50);
2570        let f1 = data.row(0);
2571        let f2 = data.row(1);
2572
2573        let d12 = elastic_distance(&f1, &f2, &t, 0.0);
2574        let d21 = elastic_distance(&f2, &f1, &t, 0.0);
2575
2576        // Should be approximately symmetric (DP is not perfectly symmetric)
2577        assert!(
2578            (d12 - d21).abs() < d12.max(d21) * 0.3 + 0.01,
2579            "Elastic distance should be roughly symmetric: d12={d12}, d21={d21}"
2580        );
2581    }
2582
2583    #[test]
2584    fn test_elastic_distance_nonneg() {
2585        let data = make_test_data(3, 50, 42);
2586        let t = uniform_grid(50);
2587
2588        for i in 0..3 {
2589            for j in 0..3 {
2590                let fi = data.row(i);
2591                let fj = data.row(j);
2592                let d = elastic_distance(&fi, &fj, &t, 0.0);
2593                assert!(d >= 0.0, "Elastic distance should be non-negative");
2594            }
2595        }
2596    }
2597
2598    #[test]
2599    fn test_elastic_distance_self_near_zero() {
2600        let data = make_test_data(3, 50, 42);
2601        let t = uniform_grid(50);
2602
2603        for i in 0..3 {
2604            let fi = data.row(i);
2605            let d = elastic_distance(&fi, &fi, &t, 0.0);
2606            assert!(
2607                d < 0.1,
2608                "Self-distance should be near zero, got {d} for curve {i}"
2609            );
2610        }
2611    }
2612
2613    #[test]
2614    fn test_elastic_distance_triangle_inequality() {
2615        let data = make_test_data(3, 50, 42);
2616        let t = uniform_grid(50);
2617        let f0 = data.row(0);
2618        let f1 = data.row(1);
2619        let f2 = data.row(2);
2620
2621        let d01 = elastic_distance(&f0, &f1, &t, 0.0);
2622        let d12 = elastic_distance(&f1, &f2, &t, 0.0);
2623        let d02 = elastic_distance(&f0, &f2, &t, 0.0);
2624
2625        // Relaxed triangle inequality (DP alignment is approximate)
2626        let slack = 0.5;
2627        assert!(
2628            d02 <= d01 + d12 + slack,
2629            "Triangle inequality (relaxed): d02={d02:.4} > d01={d01:.4} + d12={d12:.4} + {slack}"
2630        );
2631    }
2632
2633    #[test]
2634    fn test_elastic_distance_different_shapes_nonzero() {
2635        let m = 50;
2636        let t = uniform_grid(m);
2637        let f1: Vec<f64> = t.to_vec(); // linear
2638        let f2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect(); // quadratic
2639
2640        let d = elastic_distance(&f1, &f2, &t, 0.0);
2641        assert!(
2642            d > 0.01,
2643            "Distance between different shapes should be > 0, got {d}"
2644        );
2645    }
2646
2647    // ── Self distance matrix ──
2648
2649    #[test]
2650    fn test_self_distance_matrix_symmetric() {
2651        let data = make_test_data(5, 30, 42);
2652        let t = uniform_grid(30);
2653
2654        let dm = elastic_self_distance_matrix(&data, &t, 0.0);
2655        let n = dm.nrows();
2656
2657        assert_eq!(dm.shape(), (5, 5));
2658
2659        // Zero diagonal
2660        for i in 0..n {
2661            assert!(dm[(i, i)].abs() < 1e-12, "Diagonal should be zero");
2662        }
2663
2664        // Symmetric
2665        for i in 0..n {
2666            for j in (i + 1)..n {
2667                assert!(
2668                    (dm[(i, j)] - dm[(j, i)]).abs() < 1e-12,
2669                    "Matrix should be symmetric at ({i},{j})"
2670                );
2671            }
2672        }
2673    }
2674
2675    #[test]
2676    fn test_self_distance_matrix_nonneg() {
2677        let data = make_test_data(4, 30, 42);
2678        let t = uniform_grid(30);
2679        let dm = elastic_self_distance_matrix(&data, &t, 0.0);
2680
2681        for i in 0..4 {
2682            for j in 0..4 {
2683                assert!(
2684                    dm[(i, j)] >= 0.0,
2685                    "Distance matrix entries should be non-negative at ({i},{j})"
2686                );
2687            }
2688        }
2689    }
2690
2691    #[test]
2692    fn test_self_distance_matrix_single_curve() {
2693        let data = make_test_data(1, 30, 42);
2694        let t = uniform_grid(30);
2695        let dm = elastic_self_distance_matrix(&data, &t, 0.0);
2696        assert_eq!(dm.shape(), (1, 1));
2697        assert!(dm[(0, 0)].abs() < 1e-12);
2698    }
2699
2700    #[test]
2701    fn test_self_distance_matrix_consistent_with_pairwise() {
2702        let data = make_test_data(4, 30, 42);
2703        let t = uniform_grid(30);
2704
2705        let dm = elastic_self_distance_matrix(&data, &t, 0.0);
2706
2707        // Check a few entries match direct elastic_distance calls
2708        for i in 0..4 {
2709            for j in (i + 1)..4 {
2710                let fi = data.row(i);
2711                let fj = data.row(j);
2712                let d_direct = elastic_distance(&fi, &fj, &t, 0.0);
2713                assert!(
2714                    (dm[(i, j)] - d_direct).abs() < 1e-10,
2715                    "Matrix entry ({i},{j})={:.6} should match pairwise {d_direct:.6}",
2716                    dm[(i, j)]
2717                );
2718            }
2719        }
2720    }
2721
2722    // ── Karcher mean ──
2723
2724    #[test]
2725    fn test_karcher_mean_identical_curves() {
2726        let m = 50;
2727        let t = uniform_grid(m);
2728        let f: Vec<f64> = t
2729            .iter()
2730            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
2731            .collect();
2732
2733        // Create 5 identical curves
2734        let mut data = FdMatrix::zeros(5, m);
2735        for i in 0..5 {
2736            for j in 0..m {
2737                data[(i, j)] = f[j];
2738            }
2739        }
2740
2741        let result = karcher_mean(&data, &t, 10, 1e-4, 0.0);
2742
2743        assert_eq!(result.mean.len(), m);
2744        assert!(result.n_iter <= 10);
2745    }
2746
2747    #[test]
2748    fn test_karcher_mean_output_shape() {
2749        let data = make_test_data(15, 50, 42);
2750        let t = uniform_grid(50);
2751
2752        let result = karcher_mean(&data, &t, 5, 1e-3, 0.0);
2753
2754        assert_eq!(result.mean.len(), 50);
2755        assert_eq!(result.mean_srsf.len(), 50);
2756        assert_eq!(result.gammas.shape(), (15, 50));
2757        assert_eq!(result.aligned_data.shape(), (15, 50));
2758        assert!(result.n_iter <= 5);
2759    }
2760
2761    #[test]
2762    fn test_karcher_mean_warps_are_valid() {
2763        let data = make_test_data(10, 40, 42);
2764        let t = uniform_grid(40);
2765
2766        let result = karcher_mean(&data, &t, 5, 1e-3, 0.0);
2767
2768        for i in 0..10 {
2769            // Boundary values
2770            assert!(
2771                (result.gammas[(i, 0)] - t[0]).abs() < 1e-10,
2772                "Warp {i} should start at domain start"
2773            );
2774            assert!(
2775                (result.gammas[(i, 39)] - t[39]).abs() < 1e-10,
2776                "Warp {i} should end at domain end"
2777            );
2778            // Monotonicity
2779            for j in 1..40 {
2780                assert!(
2781                    result.gammas[(i, j)] >= result.gammas[(i, j - 1)],
2782                    "Warp {i} should be monotone at j={j}"
2783                );
2784            }
2785        }
2786    }
2787
2788    #[test]
2789    fn test_karcher_mean_aligned_data_is_finite() {
2790        let data = make_test_data(8, 40, 42);
2791        let t = uniform_grid(40);
2792        let result = karcher_mean(&data, &t, 5, 1e-3, 0.0);
2793
2794        for i in 0..8 {
2795            for j in 0..40 {
2796                assert!(
2797                    result.aligned_data[(i, j)].is_finite(),
2798                    "Aligned data should be finite at ({i},{j})"
2799                );
2800            }
2801        }
2802    }
2803
2804    #[test]
2805    fn test_karcher_mean_srsf_is_finite() {
2806        let data = make_test_data(8, 40, 42);
2807        let t = uniform_grid(40);
2808        let result = karcher_mean(&data, &t, 5, 1e-3, 0.0);
2809
2810        for j in 0..40 {
2811            assert!(
2812                result.mean_srsf[j].is_finite(),
2813                "Mean SRSF should be finite at j={j}"
2814            );
2815            assert!(
2816                result.mean[j].is_finite(),
2817                "Mean curve should be finite at j={j}"
2818            );
2819        }
2820    }
2821
2822    #[test]
2823    fn test_karcher_mean_single_iteration() {
2824        let data = make_test_data(10, 40, 42);
2825        let t = uniform_grid(40);
2826        let result = karcher_mean(&data, &t, 1, 1e-10, 0.0);
2827
2828        assert_eq!(result.n_iter, 1);
2829        assert_eq!(result.mean.len(), 40);
2830        // With only 1 iteration, still produces valid output
2831        for j in 0..40 {
2832            assert!(result.mean[j].is_finite());
2833        }
2834    }
2835
2836    // ── Align to target ──
2837
2838    #[test]
2839    fn test_align_to_target_valid() {
2840        let data = make_test_data(10, 40, 42);
2841        let t = uniform_grid(40);
2842        let target = data.row(0);
2843
2844        let result = align_to_target(&data, &target, &t, 0.0);
2845
2846        assert_eq!(result.gammas.shape(), (10, 40));
2847        assert_eq!(result.aligned_data.shape(), (10, 40));
2848        assert_eq!(result.distances.len(), 10);
2849
2850        // All distances should be non-negative
2851        for &d in &result.distances {
2852            assert!(d >= 0.0);
2853        }
2854    }
2855
2856    #[test]
2857    fn test_align_to_target_self_near_zero() {
2858        let data = make_test_data(5, 40, 42);
2859        let t = uniform_grid(40);
2860        let target = data.row(0);
2861
2862        let result = align_to_target(&data, &target, &t, 0.0);
2863
2864        // Distance of target to itself should be near zero
2865        assert!(
2866            result.distances[0] < 0.1,
2867            "Self-alignment distance should be near zero, got {}",
2868            result.distances[0]
2869        );
2870    }
2871
2872    #[test]
2873    fn test_align_to_target_warps_are_monotone() {
2874        let data = make_test_data(8, 40, 42);
2875        let t = uniform_grid(40);
2876        let target = data.row(0);
2877        let result = align_to_target(&data, &target, &t, 0.0);
2878
2879        for i in 0..8 {
2880            for j in 1..40 {
2881                assert!(
2882                    result.gammas[(i, j)] >= result.gammas[(i, j - 1)],
2883                    "Warp for curve {i} should be monotone at j={j}"
2884                );
2885            }
2886        }
2887    }
2888
2889    #[test]
2890    fn test_align_to_target_aligned_data_finite() {
2891        let data = make_test_data(6, 40, 42);
2892        let t = uniform_grid(40);
2893        let target = data.row(0);
2894        let result = align_to_target(&data, &target, &t, 0.0);
2895
2896        for i in 0..6 {
2897            for j in 0..40 {
2898                assert!(
2899                    result.aligned_data[(i, j)].is_finite(),
2900                    "Aligned data should be finite at ({i},{j})"
2901                );
2902            }
2903        }
2904    }
2905
2906    // ── Cross distance matrix ──
2907
2908    #[test]
2909    fn test_cross_distance_matrix_shape() {
2910        let data1 = make_test_data(3, 30, 42);
2911        let data2 = make_test_data(4, 30, 99);
2912        let t = uniform_grid(30);
2913
2914        let dm = elastic_cross_distance_matrix(&data1, &data2, &t, 0.0);
2915        assert_eq!(dm.shape(), (3, 4));
2916
2917        // All non-negative
2918        for i in 0..3 {
2919            for j in 0..4 {
2920                assert!(dm[(i, j)] >= 0.0);
2921            }
2922        }
2923    }
2924
2925    #[test]
2926    fn test_cross_distance_matrix_self_matches_self_matrix() {
2927        // cross_distance(data, data) should have zero diagonal (approximately)
2928        let data = make_test_data(4, 30, 42);
2929        let t = uniform_grid(30);
2930
2931        let cross = elastic_cross_distance_matrix(&data, &data, &t, 0.0);
2932        for i in 0..4 {
2933            assert!(
2934                cross[(i, i)] < 0.1,
2935                "Cross distance (self) diagonal should be near zero: got {}",
2936                cross[(i, i)]
2937            );
2938        }
2939    }
2940
2941    #[test]
2942    fn test_cross_distance_matrix_consistent_with_pairwise() {
2943        let data1 = make_test_data(3, 30, 42);
2944        let data2 = make_test_data(2, 30, 99);
2945        let t = uniform_grid(30);
2946
2947        let dm = elastic_cross_distance_matrix(&data1, &data2, &t, 0.0);
2948
2949        for i in 0..3 {
2950            for j in 0..2 {
2951                let fi = data1.row(i);
2952                let fj = data2.row(j);
2953                let d_direct = elastic_distance(&fi, &fj, &t, 0.0);
2954                assert!(
2955                    (dm[(i, j)] - d_direct).abs() < 1e-10,
2956                    "Cross matrix ({i},{j})={:.6} should match pairwise {d_direct:.6}",
2957                    dm[(i, j)]
2958                );
2959            }
2960        }
2961    }
2962
2963    // ── align_srsf_pair ──
2964
2965    #[test]
2966    fn test_align_srsf_pair_identity() {
2967        let m = 50;
2968        let t = uniform_grid(m);
2969        let f: Vec<f64> = t
2970            .iter()
2971            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
2972            .collect();
2973        let q = srsf_single(&f, &t);
2974
2975        let (gamma, q_aligned) = align_srsf_pair(&q, &q, &t, 0.0);
2976
2977        // Warp should be near identity
2978        for j in 0..m {
2979            assert!(
2980                (gamma[j] - t[j]).abs() < 0.15,
2981                "Self-SRSF alignment warp should be near identity at j={j}"
2982            );
2983        }
2984
2985        // Aligned SRSF should be close to original
2986        let weights = simpsons_weights(&t);
2987        let dist = l2_distance(&q, &q_aligned, &weights);
2988        assert!(
2989            dist < 0.5,
2990            "Self-aligned SRSF distance should be small, got {dist}"
2991        );
2992    }
2993
2994    // ── srsf_single ──
2995
2996    #[test]
2997    fn test_srsf_single_matches_matrix_version() {
2998        let m = 50;
2999        let t = uniform_grid(m);
3000        let f: Vec<f64> = t.iter().map(|&ti| ti * ti + ti).collect();
3001
3002        let q_single = srsf_single(&f, &t);
3003
3004        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
3005        let q_mat = srsf_transform(&mat, &t);
3006        let q_from_mat = q_mat.row(0);
3007
3008        for j in 0..m {
3009            assert!(
3010                (q_single[j] - q_from_mat[j]).abs() < 1e-12,
3011                "srsf_single should match srsf_transform at j={j}"
3012            );
3013        }
3014    }
3015
3016    // ── gcd ──
3017
3018    #[test]
3019    fn test_gcd_basic() {
3020        assert_eq!(gcd(1, 1), 1);
3021        assert_eq!(gcd(6, 4), 2);
3022        assert_eq!(gcd(7, 5), 1);
3023        assert_eq!(gcd(12, 8), 4);
3024        assert_eq!(gcd(7, 0), 7);
3025        assert_eq!(gcd(0, 5), 5);
3026    }
3027
3028    // ── generate_coprime_nbhd ──
3029
3030    #[test]
3031    fn test_coprime_nbhd_count() {
3032        assert_eq!(generate_coprime_nbhd(1).len(), 1); // just (1,1)
3033        assert_eq!(generate_coprime_nbhd(7).len(), 35);
3034    }
3035
3036    #[test]
3037    fn test_coprime_nbhd_matches_const() {
3038        let generated = generate_coprime_nbhd(7);
3039        assert_eq!(generated.len(), COPRIME_NBHD_7.len());
3040        for (i, pair) in generated.iter().enumerate() {
3041            assert_eq!(*pair, COPRIME_NBHD_7[i], "mismatch at index {i}");
3042        }
3043    }
3044
3045    #[test]
3046    fn test_coprime_nbhd_all_coprime() {
3047        for &(i, j) in &COPRIME_NBHD_7 {
3048            assert_eq!(gcd(i, j), 1, "({i},{j}) should be coprime");
3049            assert!((1..=7).contains(&i));
3050            assert!((1..=7).contains(&j));
3051        }
3052    }
3053
3054    // ── dp_edge_weight ──
3055
3056    #[test]
3057    fn test_dp_edge_weight_diagonal() {
3058        // Diagonal move (1,1): weight = (q1[sc] - sqrt(1)*q2[sr])^2 * h
3059        let t = uniform_grid(10);
3060        let q1 = vec![1.0; 10];
3061        let q2 = vec![1.0; 10];
3062        // Identical SRSFs: weight should be 0
3063        let w = dp_edge_weight(&q1, &q2, &t, 0, 1, 0, 1);
3064        assert!(w.abs() < 1e-12, "identical SRSFs should have zero cost");
3065    }
3066
3067    #[test]
3068    fn test_dp_edge_weight_non_diagonal() {
3069        // Move (1,2): n1=2, n2=1, slope = h/(2h) = 0.5
3070        let t = uniform_grid(10);
3071        let q1 = vec![1.0; 10];
3072        let q2 = vec![0.0; 10];
3073        let w = dp_edge_weight(&q1, &q2, &t, 0, 2, 0, 1);
3074        // diff = q1[0] - sqrt(0.5)*q2[0] = 1.0 - 0 = 1.0
3075        // weight = 1.0^2 * 1.0 * (t[2]-t[0]) = 2/9
3076        let expected = 2.0 / 9.0;
3077        assert!(
3078            (w - expected).abs() < 1e-10,
3079            "dp_edge_weight (1,2): expected {expected}, got {w}"
3080        );
3081    }
3082
3083    #[test]
3084    fn test_dp_edge_weight_zero_span() {
3085        let t = uniform_grid(10);
3086        let q1 = vec![1.0; 10];
3087        let q2 = vec![1.0; 10];
3088        // n1=0: should return INFINITY
3089        assert_eq!(dp_edge_weight(&q1, &q2, &t, 3, 3, 0, 1), f64::INFINITY);
3090        // n2=0: should return INFINITY
3091        assert_eq!(dp_edge_weight(&q1, &q2, &t, 0, 1, 3, 3), f64::INFINITY);
3092    }
3093
3094    // ── DP alignment quality ──
3095
3096    #[test]
3097    fn test_alignment_improves_distance() {
3098        // Aligned SRSF distance should be less than unaligned SRSF distance
3099        let m = 50;
3100        let t = uniform_grid(m);
3101        let f1: Vec<f64> = t
3102            .iter()
3103            .map(|&x| (2.0 * std::f64::consts::PI * x).sin())
3104            .collect();
3105        // Use a larger shift so improvement is clear
3106        let f2: Vec<f64> = t
3107            .iter()
3108            .map(|&x| (2.0 * std::f64::consts::PI * (x + 0.2)).sin())
3109            .collect();
3110
3111        let q1 = srsf_single(&f1, &t);
3112        let q2 = srsf_single(&f2, &t);
3113        let weights = simpsons_weights(&t);
3114        let unaligned_srsf_dist = l2_distance(&q1, &q2, &weights);
3115
3116        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
3117
3118        assert!(
3119            result.distance <= unaligned_srsf_dist + 1e-6,
3120            "aligned SRSF dist ({}) should be <= unaligned SRSF dist ({})",
3121            result.distance,
3122            unaligned_srsf_dist
3123        );
3124    }
3125
3126    // ── Edge case: constant data ──
3127
3128    #[test]
3129    fn test_alignment_constant_curves() {
3130        let m = 30;
3131        let t = uniform_grid(m);
3132        let f1 = vec![5.0; m];
3133        let f2 = vec![5.0; m];
3134
3135        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
3136        assert!(
3137            result.distance < 0.01,
3138            "Constant curves: distance should be ~0"
3139        );
3140        assert_eq!(result.f_aligned.len(), m);
3141    }
3142
3143    #[test]
3144    fn test_karcher_mean_constant_curves() {
3145        let m = 30;
3146        let t = uniform_grid(m);
3147        let mut data = FdMatrix::zeros(5, m);
3148        for i in 0..5 {
3149            for j in 0..m {
3150                data[(i, j)] = 3.0;
3151            }
3152        }
3153
3154        let result = karcher_mean(&data, &t, 5, 1e-4, 0.0);
3155        for j in 0..m {
3156            assert!(
3157                (result.mean[j] - 3.0).abs() < 0.5,
3158                "Mean of constant curves should be near 3.0, got {} at j={j}",
3159                result.mean[j]
3160            );
3161        }
3162    }
3163
3164    #[test]
3165    fn test_nan_srsf_no_panic() {
3166        let m = 20;
3167        let t = uniform_grid(m);
3168        let mut f = vec![1.0; m];
3169        f[5] = f64::NAN;
3170        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
3171        let q = srsf_transform(&mat, &t);
3172        // Should not panic; NaN propagates
3173        assert_eq!(q.nrows(), 1);
3174    }
3175
3176    #[test]
3177    fn test_n1_karcher_mean() {
3178        let m = 30;
3179        let t = uniform_grid(m);
3180        let f: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
3181        let data = FdMatrix::from_slice(&f, 1, m).unwrap();
3182        let result = karcher_mean(&data, &t, 5, 1e-4, 0.0);
3183        assert_eq!(result.mean.len(), m);
3184        // With only 1 curve, the mean should be close to the original
3185        for j in 0..m {
3186            assert!(result.mean[j].is_finite());
3187        }
3188    }
3189
3190    #[test]
3191    fn test_two_point_grid() {
3192        let t = vec![0.0, 1.0];
3193        let f1 = vec![0.0, 1.0];
3194        let f2 = vec![0.0, 2.0];
3195        let d = elastic_distance(&f1, &f2, &t, 0.0);
3196        assert!(d >= 0.0);
3197        assert!(d.is_finite());
3198    }
3199
3200    #[test]
3201    fn test_non_uniform_grid_alignment() {
3202        // Non-uniform grid: points clustered near 0
3203        let t = vec![0.0, 0.01, 0.05, 0.2, 0.5, 1.0];
3204        let m = t.len();
3205        let f1: Vec<f64> = t.iter().map(|&ti: &f64| ti.sin()).collect();
3206        let f2: Vec<f64> = t.iter().map(|&ti: &f64| (ti + 0.1).sin()).collect();
3207        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
3208        assert_eq!(result.gamma.len(), m);
3209        assert!(result.distance >= 0.0);
3210        assert!(result.distance.is_finite());
3211    }
3212
3213    // ── TSRVF tests ──
3214
3215    #[test]
3216    fn test_tsrvf_output_shape() {
3217        let m = 50;
3218        let n = 10;
3219        let t = uniform_grid(m);
3220        let data = make_test_data(n, m, 42);
3221        let result = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
3222        assert_eq!(
3223            result.tangent_vectors.shape(),
3224            (n, m),
3225            "Tangent vectors should be n×m"
3226        );
3227        assert_eq!(result.gammas.shape(), (n, m), "Gammas should be n×m");
3228        assert_eq!(result.srsf_norms.len(), n, "Should have n SRSF norms");
3229        assert_eq!(result.mean.len(), m, "Mean should have m points");
3230        assert_eq!(result.mean_srsf.len(), m, "Mean SRSF should have m points");
3231    }
3232
3233    #[test]
3234    fn test_tsrvf_all_finite() {
3235        let m = 50;
3236        let n = 5;
3237        let t = uniform_grid(m);
3238        let data = make_test_data(n, m, 42);
3239        let result = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
3240        for i in 0..n {
3241            for j in 0..m {
3242                assert!(
3243                    result.tangent_vectors[(i, j)].is_finite(),
3244                    "Tangent vector should be finite at ({i},{j})"
3245                );
3246            }
3247            assert!(
3248                result.srsf_norms[i].is_finite(),
3249                "SRSF norm should be finite for curve {i}"
3250            );
3251        }
3252        assert!(
3253            result.mean_srsf_norm.is_finite(),
3254            "Mean SRSF norm should be finite"
3255        );
3256    }
3257
3258    #[test]
3259    fn test_tsrvf_identical_curves_zero_tangent() {
3260        let m = 50;
3261        let t = uniform_grid(m);
3262        // Stack 5 identical curves
3263        let curve: Vec<f64> = t
3264            .iter()
3265            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3266            .collect();
3267        let mut col_major = vec![0.0; 5 * m];
3268        for i in 0..5 {
3269            for j in 0..m {
3270                col_major[i + j * 5] = curve[j];
3271            }
3272        }
3273        let data = FdMatrix::from_column_major(col_major, 5, m).unwrap();
3274        let result = tsrvf_transform(&data, &t, 10, 1e-4, 0.0);
3275
3276        // All tangent vectors should be approximately zero
3277        for i in 0..5 {
3278            let tv_norm_sq: f64 = (0..m).map(|j| result.tangent_vectors[(i, j)].powi(2)).sum();
3279            assert!(
3280                tv_norm_sq.sqrt() < 0.5,
3281                "Identical curves should have near-zero tangent vectors, got norm = {}",
3282                tv_norm_sq.sqrt()
3283            );
3284        }
3285    }
3286
3287    #[test]
3288    fn test_tsrvf_mean_tangent_near_zero() {
3289        let m = 50;
3290        let n = 10;
3291        let t = uniform_grid(m);
3292        let data = make_test_data(n, m, 42);
3293        let result = tsrvf_transform(&data, &t, 10, 1e-3, 0.0);
3294
3295        // Mean of tangent vectors should be approximately zero (property of Karcher mean)
3296        let mut mean_tv = vec![0.0; m];
3297        for i in 0..n {
3298            for j in 0..m {
3299                mean_tv[j] += result.tangent_vectors[(i, j)];
3300            }
3301        }
3302        for j in 0..m {
3303            mean_tv[j] /= n as f64;
3304        }
3305        let mean_norm: f64 = mean_tv.iter().map(|v| v * v).sum::<f64>().sqrt();
3306        assert!(
3307            mean_norm < 1.0,
3308            "Mean tangent vector should be near zero, got norm = {mean_norm}"
3309        );
3310    }
3311
3312    #[test]
3313    fn test_tsrvf_from_alignment() {
3314        let m = 50;
3315        let n = 5;
3316        let t = uniform_grid(m);
3317        let data = make_test_data(n, m, 42);
3318        let karcher = karcher_mean(&data, &t, 5, 1e-3, 0.0);
3319        let result = tsrvf_from_alignment(&karcher, &t);
3320        assert_eq!(result.tangent_vectors.shape(), (n, m));
3321        assert!(result.mean_srsf_norm > 0.0);
3322    }
3323
3324    #[test]
3325    fn test_tsrvf_round_trip() {
3326        let m = 50;
3327        let n = 5;
3328        let t = uniform_grid(m);
3329        let data = make_test_data(n, m, 42);
3330        let result = tsrvf_transform(&data, &t, 10, 1e-3, 0.0);
3331        let reconstructed = tsrvf_inverse(&result, &t);
3332
3333        assert_eq!(reconstructed.shape(), result.tangent_vectors.shape());
3334        // Reconstructed curves should have finite values
3335        for i in 0..n {
3336            for j in 0..m {
3337                assert!(
3338                    reconstructed[(i, j)].is_finite(),
3339                    "Reconstructed curve should be finite at ({i},{j})"
3340                );
3341            }
3342        }
3343    }
3344
3345    #[test]
3346    fn test_tsrvf_single_curve() {
3347        let m = 50;
3348        let t = uniform_grid(m);
3349        let data = make_test_data(1, m, 42);
3350        let result = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
3351        assert_eq!(result.tangent_vectors.shape(), (1, m));
3352        // Single curve → tangent vector should be zero (it IS the mean)
3353        let tv_norm: f64 = (0..m)
3354            .map(|j| result.tangent_vectors[(0, j)].powi(2))
3355            .sum::<f64>()
3356            .sqrt();
3357        assert!(
3358            tv_norm < 0.5,
3359            "Single curve tangent vector should be near zero, got {tv_norm}"
3360        );
3361    }
3362
3363    #[test]
3364    fn test_tsrvf_constant_curves() {
3365        let m = 30;
3366        let t = uniform_grid(m);
3367        // Constant curves → SRSF = 0, norms = 0
3368        let data = FdMatrix::from_column_major(vec![5.0; 3 * m], 3, m).unwrap();
3369        let result = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
3370        // Should not produce NaN or Inf
3371        for i in 0..3 {
3372            for j in 0..m {
3373                let v = result.tangent_vectors[(i, j)];
3374                assert!(
3375                    !v.is_nan(),
3376                    "Constant curves should not produce NaN tangent vectors"
3377                );
3378            }
3379        }
3380    }
3381
3382    // ── Reference-value tests (sphere geometry) ─────────────────────────────
3383
3384    #[test]
3385    fn test_tsrvf_sphere_inv_exp_reference() {
3386        // Analytical reference: known vectors on the Hilbert sphere
3387        // psi1 = constant (unit L2 norm), psi2 = 1 + 0.3*sin(2πt) (normalized)
3388        let m = 21;
3389        let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
3390
3391        // Construct unit vector psi1 (constant)
3392        let raw1 = vec![1.0; m];
3393        let norm1 = inner_product_l2(&raw1, &raw1, &time).max(0.0).sqrt();
3394        let psi1: Vec<f64> = raw1.iter().map(|&v| v / norm1).collect();
3395
3396        // Construct psi2 with sinusoidal perturbation
3397        let raw2: Vec<f64> = time
3398            .iter()
3399            .map(|&t| 1.0 + 0.3 * (2.0 * std::f64::consts::PI * t).sin())
3400            .collect();
3401        let norm2 = inner_product_l2(&raw2, &raw2, &time).max(0.0).sqrt();
3402        let psi2: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();
3403
3404        // Compute theta analytically
3405        let ip = inner_product_l2(&psi1, &psi2, &time).clamp(-1.0, 1.0);
3406        let theta_expected = ip.acos();
3407
3408        // Compute via inv_exp_map_sphere
3409        let v = inv_exp_map_sphere(&psi1, &psi2, &time);
3410        let v_norm = inner_product_l2(&v, &v, &time).max(0.0).sqrt();
3411
3412        // ||v|| should equal theta
3413        assert!(
3414            (v_norm - theta_expected).abs() < 1e-10,
3415            "||v|| = {v_norm}, expected theta = {theta_expected}"
3416        );
3417
3418        // theta should be small but non-zero (perturbation is mild)
3419        assert!(
3420            theta_expected > 0.01 && theta_expected < 1.0,
3421            "theta = {theta_expected} out of expected range"
3422        );
3423    }
3424
3425    #[test]
3426    fn test_tsrvf_sphere_round_trip_reference() {
3427        // Round-trip: exp_map(psi1, inv_exp_map(psi1, psi2)) should recover psi2
3428        let m = 21;
3429        let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
3430
3431        let raw1 = vec![1.0; m];
3432        let norm1 = inner_product_l2(&raw1, &raw1, &time).max(0.0).sqrt();
3433        let psi1: Vec<f64> = raw1.iter().map(|&v| v / norm1).collect();
3434
3435        let raw2: Vec<f64> = time
3436            .iter()
3437            .map(|&t| 1.0 + 0.3 * (2.0 * std::f64::consts::PI * t).sin())
3438            .collect();
3439        let norm2 = inner_product_l2(&raw2, &raw2, &time).max(0.0).sqrt();
3440        let psi2: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();
3441
3442        let v = inv_exp_map_sphere(&psi1, &psi2, &time);
3443        let recovered = exp_map_sphere(&psi1, &v, &time);
3444
3445        // L2 error between psi2 and recovered
3446        let diff: Vec<f64> = psi2
3447            .iter()
3448            .zip(recovered.iter())
3449            .map(|(&a, &b)| (a - b).powi(2))
3450            .collect();
3451        let l2_err = trapz(&diff, &time).max(0.0).sqrt();
3452        assert!(
3453            l2_err < 1e-12,
3454            "Round-trip L2 error = {l2_err:.2e}, expected < 1e-12"
3455        );
3456    }
3457
3458    // ── Penalized alignment ──
3459
3460    #[test]
3461    fn test_penalized_alignment_lambda_zero_matches_unpenalized() {
3462        let m = 50;
3463        let t = uniform_grid(m);
3464        let data = make_test_data(2, m, 42);
3465        let f1 = data.row(0);
3466        let f2 = data.row(1);
3467
3468        let r0 = elastic_align_pair(&f1, &f2, &t, 0.0);
3469        // lambda = 0.0 should produce the same result regardless
3470        assert!(r0.distance >= 0.0);
3471        assert_eq!(r0.gamma.len(), m);
3472    }
3473
3474    #[test]
3475    fn test_penalized_alignment_smoother_warp() {
3476        let m = 80;
3477        let t = uniform_grid(m);
3478        let f1: Vec<f64> = t
3479            .iter()
3480            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3481            .collect();
3482        let f2: Vec<f64> = t
3483            .iter()
3484            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.15)).sin())
3485            .collect();
3486
3487        let r_free = elastic_align_pair(&f1, &f2, &t, 0.0);
3488        let r_pen = elastic_align_pair(&f1, &f2, &t, 1.0);
3489
3490        // Measure warp deviation from identity
3491        let dev_free: f64 = r_free
3492            .gamma
3493            .iter()
3494            .zip(t.iter())
3495            .map(|(g, ti)| (g - ti).powi(2))
3496            .sum();
3497        let dev_pen: f64 = r_pen
3498            .gamma
3499            .iter()
3500            .zip(t.iter())
3501            .map(|(g, ti)| (g - ti).powi(2))
3502            .sum();
3503
3504        assert!(
3505            dev_pen <= dev_free + 1e-6,
3506            "Penalized warp should be closer to identity: free={dev_free:.6}, pen={dev_pen:.6}"
3507        );
3508    }
3509
3510    #[test]
3511    fn test_penalized_alignment_large_lambda_near_identity() {
3512        let m = 50;
3513        let t = uniform_grid(m);
3514        let f1: Vec<f64> = t
3515            .iter()
3516            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3517            .collect();
3518        let f2: Vec<f64> = t
3519            .iter()
3520            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
3521            .collect();
3522
3523        let r = elastic_align_pair(&f1, &f2, &t, 1000.0);
3524
3525        // With very large lambda, warp should be very close to identity
3526        let max_dev: f64 = r
3527            .gamma
3528            .iter()
3529            .zip(t.iter())
3530            .map(|(g, ti)| (g - ti).abs())
3531            .fold(0.0_f64, f64::max);
3532        assert!(
3533            max_dev < 0.05,
3534            "Large lambda should give near-identity warp: max deviation = {max_dev}"
3535        );
3536    }
3537
3538    #[test]
3539    fn test_penalized_karcher_mean() {
3540        let m = 40;
3541        let t = uniform_grid(m);
3542        let data = make_test_data(10, m, 42);
3543
3544        let result = karcher_mean(&data, &t, 5, 1e-3, 0.5);
3545        assert_eq!(result.mean.len(), m);
3546        for j in 0..m {
3547            assert!(result.mean[j].is_finite());
3548        }
3549    }
3550
3551    // ── Phase-amplitude decomposition ──
3552
3553    #[test]
3554    fn test_decomposition_identity_curves() {
3555        let m = 50;
3556        let t = uniform_grid(m);
3557        let f: Vec<f64> = t
3558            .iter()
3559            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3560            .collect();
3561
3562        let result = elastic_decomposition(&f, &f, &t, 0.0);
3563        assert!(
3564            result.d_amplitude < 0.1,
3565            "Self-decomposition amplitude should be ~0, got {}",
3566            result.d_amplitude
3567        );
3568        assert!(
3569            result.d_phase < 0.2,
3570            "Self-decomposition phase should be ~0, got {}",
3571            result.d_phase
3572        );
3573    }
3574
3575    #[test]
3576    fn test_decomposition_pythagorean() {
3577        // d_total² ≈ d_a² + d_φ² (approximately, for the Fisher-Rao metric)
3578        let m = 80;
3579        let t = uniform_grid(m);
3580        let f1: Vec<f64> = t
3581            .iter()
3582            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3583            .collect();
3584        let f2: Vec<f64> = t
3585            .iter()
3586            .map(|&ti| 1.2 * (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
3587            .collect();
3588
3589        let result = elastic_decomposition(&f1, &f2, &t, 0.0);
3590        let da = result.d_amplitude;
3591        let dp = result.d_phase;
3592        // Both should be non-negative
3593        assert!(da >= 0.0);
3594        assert!(dp >= 0.0);
3595    }
3596
3597    #[test]
3598    fn test_phase_distance_shifted_sine() {
3599        let m = 80;
3600        let t = uniform_grid(m);
3601        let f1: Vec<f64> = t
3602            .iter()
3603            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3604            .collect();
3605        let f2: Vec<f64> = t
3606            .iter()
3607            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.15)).sin())
3608            .collect();
3609
3610        let dp = phase_distance_pair(&f1, &f2, &t, 0.0);
3611        assert!(
3612            dp > 0.01,
3613            "Phase distance of shifted curves should be > 0, got {dp}"
3614        );
3615    }
3616
3617    #[test]
3618    fn test_amplitude_distance_scaled_curve() {
3619        let m = 80;
3620        let t = uniform_grid(m);
3621        let f1: Vec<f64> = t
3622            .iter()
3623            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3624            .collect();
3625        let f2: Vec<f64> = t
3626            .iter()
3627            .map(|&ti| 2.0 * (2.0 * std::f64::consts::PI * ti).sin())
3628            .collect();
3629
3630        let da = amplitude_distance(&f1, &f2, &t, 0.0);
3631        assert!(
3632            da > 0.01,
3633            "Amplitude distance of scaled curves should be > 0, got {da}"
3634        );
3635    }
3636
3637    #[test]
3638    fn test_phase_distance_nonneg() {
3639        let data = make_test_data(4, 40, 42);
3640        let t = uniform_grid(40);
3641        for i in 0..4 {
3642            for j in 0..4 {
3643                let fi = data.row(i);
3644                let fj = data.row(j);
3645                let dp = phase_distance_pair(&fi, &fj, &t, 0.0);
3646                assert!(dp >= 0.0, "Phase distance should be non-negative");
3647            }
3648        }
3649    }
3650
3651    // ── Parallel transport ──
3652
3653    #[test]
3654    fn test_schilds_ladder_zero_vector() {
3655        let m = 21;
3656        let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
3657        let raw = vec![1.0; m];
3658        let norm = crate::warping::l2_norm_l2(&raw, &time);
3659        let from: Vec<f64> = raw.iter().map(|&v| v / norm).collect();
3660        let raw2: Vec<f64> = time
3661            .iter()
3662            .map(|&t| 1.0 + 0.2 * (2.0 * std::f64::consts::PI * t).sin())
3663            .collect();
3664        let norm2 = crate::warping::l2_norm_l2(&raw2, &time);
3665        let to: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();
3666
3667        let zero = vec![0.0; m];
3668        let result = parallel_transport_schilds(&zero, &from, &to, &time);
3669        let result_norm: f64 = result.iter().map(|v| v * v).sum::<f64>().sqrt();
3670        assert!(
3671            result_norm < 1e-6,
3672            "Transporting zero should give zero, got norm {result_norm}"
3673        );
3674    }
3675
3676    #[test]
3677    fn test_pole_ladder_zero_vector() {
3678        let m = 21;
3679        let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
3680        let raw = vec![1.0; m];
3681        let norm = crate::warping::l2_norm_l2(&raw, &time);
3682        let from: Vec<f64> = raw.iter().map(|&v| v / norm).collect();
3683        let raw2: Vec<f64> = time
3684            .iter()
3685            .map(|&t| 1.0 + 0.2 * (2.0 * std::f64::consts::PI * t).sin())
3686            .collect();
3687        let norm2 = crate::warping::l2_norm_l2(&raw2, &time);
3688        let to: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();
3689
3690        let zero = vec![0.0; m];
3691        let result = parallel_transport_pole(&zero, &from, &to, &time);
3692        let result_norm: f64 = result.iter().map(|v| v * v).sum::<f64>().sqrt();
3693        assert!(
3694            result_norm < 1e-6,
3695            "Transporting zero should give zero, got norm {result_norm}"
3696        );
3697    }
3698
3699    #[test]
3700    fn test_schilds_preserves_norm() {
3701        let m = 51;
3702        let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
3703        let raw = vec![1.0; m];
3704        let norm = crate::warping::l2_norm_l2(&raw, &time);
3705        let from: Vec<f64> = raw.iter().map(|&v| v / norm).collect();
3706        let raw2: Vec<f64> = time
3707            .iter()
3708            .map(|&t| 1.0 + 0.15 * (2.0 * std::f64::consts::PI * t).sin())
3709            .collect();
3710        let norm2 = crate::warping::l2_norm_l2(&raw2, &time);
3711        let to: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();
3712
3713        // Small tangent vector
3714        let v: Vec<f64> = time
3715            .iter()
3716            .map(|&t| 0.1 * (4.0 * std::f64::consts::PI * t).cos())
3717            .collect();
3718        let v_norm = crate::warping::l2_norm_l2(&v, &time);
3719
3720        let transported = parallel_transport_schilds(&v, &from, &to, &time);
3721        let t_norm = crate::warping::l2_norm_l2(&transported, &time);
3722
3723        // Norm should be approximately preserved (ladder methods are first-order)
3724        assert!(
3725            (t_norm - v_norm).abs() / v_norm.max(1e-10) < 1.5,
3726            "Schild's should roughly preserve norm: original={v_norm:.4}, transported={t_norm:.4}"
3727        );
3728    }
3729
3730    #[test]
3731    fn test_tsrvf_logmap_matches_original() {
3732        let m = 50;
3733        let n = 5;
3734        let t = uniform_grid(m);
3735        let data = make_test_data(n, m, 42);
3736
3737        let result_orig = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
3738        let result_logmap =
3739            tsrvf_transform_with_method(&data, &t, 5, 1e-3, 0.0, TransportMethod::LogMap);
3740
3741        // Should be identical (LogMap delegates to original)
3742        for i in 0..n {
3743            for j in 0..m {
3744                assert!(
3745                    (result_orig.tangent_vectors[(i, j)] - result_logmap.tangent_vectors[(i, j)])
3746                        .abs()
3747                        < 1e-12,
3748                    "LogMap variant should match original at ({i},{j})"
3749                );
3750            }
3751        }
3752    }
3753
3754    #[test]
3755    fn test_tsrvf_with_schilds_produces_valid_result() {
3756        let m = 50;
3757        let n = 5;
3758        let t = uniform_grid(m);
3759        let data = make_test_data(n, m, 42);
3760
3761        let result =
3762            tsrvf_transform_with_method(&data, &t, 5, 1e-3, 0.0, TransportMethod::SchildsLadder);
3763
3764        assert_eq!(result.tangent_vectors.shape(), (n, m));
3765        for i in 0..n {
3766            for j in 0..m {
3767                assert!(
3768                    result.tangent_vectors[(i, j)].is_finite(),
3769                    "Schild's TSRVF should produce finite tangent vectors at ({i},{j})"
3770                );
3771            }
3772        }
3773    }
3774
3775    #[test]
3776    fn test_transport_methods_differ() {
3777        let m = 50;
3778        let n = 5;
3779        let t = uniform_grid(m);
3780        let data = make_test_data(n, m, 42);
3781        let karcher = karcher_mean(&data, &t, 5, 1e-3, 0.0);
3782
3783        let r_log = tsrvf_from_alignment_with_method(&karcher, &t, TransportMethod::LogMap);
3784        let r_schilds =
3785            tsrvf_from_alignment_with_method(&karcher, &t, TransportMethod::SchildsLadder);
3786
3787        // Methods should produce different (but related) tangent vectors
3788        let mut total_diff = 0.0;
3789        for i in 0..n {
3790            for j in 0..m {
3791                total_diff +=
3792                    (r_log.tangent_vectors[(i, j)] - r_schilds.tangent_vectors[(i, j)]).abs();
3793            }
3794        }
3795
3796        // They should be non-zero different (unless all curves are identical to mean)
3797        // Just check both produce finite results
3798        assert!(total_diff.is_finite());
3799    }
3800
3801    // ── Alignment quality metrics ──
3802
3803    #[test]
3804    fn test_warp_complexity_identity_is_zero() {
3805        let m = 50;
3806        let t = uniform_grid(m);
3807        let identity = t.clone();
3808        let c = warp_complexity(&identity, &t);
3809        assert!(
3810            c < 1e-10,
3811            "Identity warp should have zero complexity, got {c}"
3812        );
3813    }
3814
3815    #[test]
3816    fn test_warp_complexity_nonidentity_positive() {
3817        let m = 50;
3818        let t = uniform_grid(m);
3819        let gamma: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
3820        let c = warp_complexity(&gamma, &t);
3821        assert!(
3822            c > 0.01,
3823            "Non-identity warp should have positive complexity, got {c}"
3824        );
3825    }
3826
3827    #[test]
3828    fn test_warp_smoothness_identity_is_zero() {
3829        let m = 50;
3830        let t = uniform_grid(m);
3831        let identity = t.clone();
3832        let s = warp_smoothness(&identity, &t);
3833        assert!(
3834            s < 1e-6,
3835            "Identity warp (constant γ'=1, γ''=0) should have near-zero bending energy, got {s}"
3836        );
3837    }
3838
3839    #[test]
3840    fn test_alignment_quality_basic() {
3841        let m = 50;
3842        let n = 8;
3843        let t = uniform_grid(m);
3844        let data = make_test_data(n, m, 42);
3845        let karcher = karcher_mean(&data, &t, 10, 1e-3, 0.0);
3846        let quality = alignment_quality(&data, &karcher, &t);
3847
3848        // Shape checks
3849        assert_eq!(quality.warp_complexity.len(), n);
3850        assert_eq!(quality.warp_smoothness.len(), n);
3851        assert_eq!(quality.pointwise_variance_ratio.len(), m);
3852
3853        // Non-negativity
3854        assert!(quality.total_variance >= 0.0);
3855        assert!(quality.amplitude_variance >= 0.0);
3856        assert!(quality.phase_variance >= 0.0);
3857        assert!(quality.mean_warp_complexity >= 0.0);
3858        assert!(quality.mean_warp_smoothness >= 0.0);
3859
3860        // Amplitude variance ≤ total variance
3861        assert!(
3862            quality.amplitude_variance <= quality.total_variance + 1e-10,
3863            "Amplitude variance ({}) should be ≤ total variance ({})",
3864            quality.amplitude_variance,
3865            quality.total_variance
3866        );
3867    }
3868
3869    #[test]
3870    fn test_alignment_quality_identical_curves() {
3871        let m = 50;
3872        let t = uniform_grid(m);
3873        let curve: Vec<f64> = t
3874            .iter()
3875            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3876            .collect();
3877        let mut col_major = vec![0.0; 5 * m];
3878        for i in 0..5 {
3879            for j in 0..m {
3880                col_major[i + j * 5] = curve[j];
3881            }
3882        }
3883        let data = FdMatrix::from_column_major(col_major, 5, m).unwrap();
3884        let karcher = karcher_mean(&data, &t, 5, 1e-3, 0.0);
3885        let quality = alignment_quality(&data, &karcher, &t);
3886
3887        // Identical curves → near-zero variances and warp complexities
3888        assert!(
3889            quality.total_variance < 0.01,
3890            "Identical curves should have near-zero total variance, got {}",
3891            quality.total_variance
3892        );
3893        assert!(
3894            quality.mean_warp_complexity < 0.1,
3895            "Identical curves should have near-zero warp complexity, got {}",
3896            quality.mean_warp_complexity
3897        );
3898    }
3899
3900    #[test]
3901    fn test_alignment_quality_variance_reduction() {
3902        let m = 50;
3903        let n = 10;
3904        let t = uniform_grid(m);
3905        let data = make_test_data(n, m, 42);
3906        let karcher = karcher_mean(&data, &t, 10, 1e-3, 0.0);
3907        let quality = alignment_quality(&data, &karcher, &t);
3908
3909        // Mean variance ratio should be ≤ ~1 (alignment shouldn't increase variance)
3910        assert!(
3911            quality.mean_variance_reduction <= 1.5,
3912            "Mean variance reduction ratio should be ≤ ~1, got {}",
3913            quality.mean_variance_reduction
3914        );
3915    }
3916
3917    #[test]
3918    fn test_pairwise_consistency_small() {
3919        let m = 40;
3920        let n = 4;
3921        let t = uniform_grid(m);
3922        let data = make_test_data(n, m, 42);
3923
3924        let consistency = pairwise_consistency(&data, &t, 0.0, 100);
3925        assert!(
3926            consistency.is_finite() && consistency >= 0.0,
3927            "Pairwise consistency should be finite and non-negative, got {consistency}"
3928        );
3929    }
3930
3931    // ── Multidimensional SRSF ──
3932
3933    #[test]
3934    fn test_srsf_nd_d1_matches_existing() {
3935        let m = 50;
3936        let t = uniform_grid(m);
3937        let data = make_test_data(3, m, 42);
3938
3939        // 1D via existing function
3940        let q_1d = srsf_transform(&data, &t);
3941
3942        // 1D via nd function
3943        let data_nd = FdCurveSet::from_1d(data);
3944        let q_nd = srsf_transform_nd(&data_nd, &t);
3945
3946        assert_eq!(q_nd.ndim(), 1);
3947        for i in 0..3 {
3948            for j in 0..m {
3949                assert!(
3950                    (q_1d[(i, j)] - q_nd.dims[0][(i, j)]).abs() < 1e-10,
3951                    "1D nd SRSF should match existing at ({i},{j}): {} vs {}",
3952                    q_1d[(i, j)],
3953                    q_nd.dims[0][(i, j)]
3954                );
3955            }
3956        }
3957    }
3958
3959    #[test]
3960    fn test_srsf_nd_constant_is_zero() {
3961        let m = 30;
3962        let t = uniform_grid(m);
3963        // Constant R^2 curve: f(t) = (3.0, -1.0)
3964        let dim0 = FdMatrix::from_column_major(vec![3.0; m], 1, m).unwrap();
3965        let dim1 = FdMatrix::from_column_major(vec![-1.0; m], 1, m).unwrap();
3966        let data = FdCurveSet::from_dims(vec![dim0, dim1]).unwrap();
3967
3968        let q = srsf_transform_nd(&data, &t);
3969        for k in 0..2 {
3970            for j in 0..m {
3971                assert!(
3972                    q.dims[k][(0, j)].abs() < 1e-10,
3973                    "Constant curve SRSF should be zero, dim {k} at {j}: {}",
3974                    q.dims[k][(0, j)]
3975                );
3976            }
3977        }
3978    }
3979
3980    #[test]
3981    fn test_srsf_nd_linear_r2() {
3982        let m = 51;
3983        let t = uniform_grid(m);
3984        // f(t) = (2t, 3t) → f'(t) = (2, 3), ||f'|| = sqrt(13)
3985        // q(t) = (2, 3) / sqrt(sqrt(13)) = (2, 3) / 13^(1/4)
3986        let dim0 =
3987            FdMatrix::from_slice(&t.iter().map(|&ti| 2.0 * ti).collect::<Vec<_>>(), 1, m).unwrap();
3988        let dim1 =
3989            FdMatrix::from_slice(&t.iter().map(|&ti| 3.0 * ti).collect::<Vec<_>>(), 1, m).unwrap();
3990        let data = FdCurveSet::from_dims(vec![dim0, dim1]).unwrap();
3991
3992        let q = srsf_transform_nd(&data, &t);
3993        let expected_scale = 1.0 / 13.0_f64.powf(0.25);
3994        let mid = m / 2;
3995
3996        assert!(
3997            (q.dims[0][(0, mid)] - 2.0 * expected_scale).abs() < 0.1,
3998            "q_x at midpoint: {} vs expected {}",
3999            q.dims[0][(0, mid)],
4000            2.0 * expected_scale
4001        );
4002        assert!(
4003            (q.dims[1][(0, mid)] - 3.0 * expected_scale).abs() < 0.1,
4004            "q_y at midpoint: {} vs expected {}",
4005            q.dims[1][(0, mid)],
4006            3.0 * expected_scale
4007        );
4008    }
4009
4010    #[test]
4011    fn test_srsf_nd_round_trip() {
4012        let m = 51;
4013        let t = uniform_grid(m);
4014        // f(t) = (sin(2πt), cos(2πt))
4015        let pi2 = 2.0 * std::f64::consts::PI;
4016        let vals_x: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).sin()).collect();
4017        let vals_y: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).cos()).collect();
4018        let dim0 = FdMatrix::from_slice(&vals_x, 1, m).unwrap();
4019        let dim1 = FdMatrix::from_slice(&vals_y, 1, m).unwrap();
4020        let data = FdCurveSet::from_dims(vec![dim0, dim1]).unwrap();
4021
4022        let q = srsf_transform_nd(&data, &t);
4023        let q_vecs: Vec<Vec<f64>> = q.dims.iter().map(|dm| dm.row(0)).collect();
4024        let f0 = vec![vals_x[0], vals_y[0]];
4025        let recon = srsf_inverse_nd(&q_vecs, &t, &f0);
4026
4027        // Check reconstruction error (skip boundary points)
4028        let mut max_err = 0.0_f64;
4029        for k in 0..2 {
4030            let orig = if k == 0 { &vals_x } else { &vals_y };
4031            for j in 2..(m - 2) {
4032                let err = (recon[k][j] - orig[j]).abs();
4033                max_err = max_err.max(err);
4034            }
4035        }
4036        assert!(
4037            max_err < 0.2,
4038            "SRSF round-trip max error should be small, got {max_err}"
4039        );
4040    }
4041
4042    #[test]
4043    fn test_align_nd_identical_near_zero() {
4044        let m = 50;
4045        let t = uniform_grid(m);
4046        let pi2 = 2.0 * std::f64::consts::PI;
4047        let vals_x: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).sin()).collect();
4048        let vals_y: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).cos()).collect();
4049        let dim0 = FdMatrix::from_slice(&vals_x, 1, m).unwrap();
4050        let dim1 = FdMatrix::from_slice(&vals_y, 1, m).unwrap();
4051        let data = FdCurveSet::from_dims(vec![dim0, dim1]).unwrap();
4052
4053        let result = elastic_align_pair_nd(&data, &data, &t, 0.0);
4054        assert!(
4055            result.distance < 0.5,
4056            "Self-alignment distance should be ~0, got {}",
4057            result.distance
4058        );
4059        // Gamma should be near identity
4060        let max_dev: f64 = result
4061            .gamma
4062            .iter()
4063            .zip(t.iter())
4064            .map(|(g, ti)| (g - ti).abs())
4065            .fold(0.0_f64, f64::max);
4066        assert!(
4067            max_dev < 0.1,
4068            "Self-alignment warp should be near identity, max dev = {max_dev}"
4069        );
4070    }
4071
4072    #[test]
4073    fn test_align_nd_shifted_r2() {
4074        let m = 60;
4075        let t = uniform_grid(m);
4076        let pi2 = 2.0 * std::f64::consts::PI;
4077
4078        // f1(t) = (sin(2πt), cos(2πt))
4079        let f1x: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).sin()).collect();
4080        let f1y: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).cos()).collect();
4081        let f1 = FdCurveSet::from_dims(vec![
4082            FdMatrix::from_slice(&f1x, 1, m).unwrap(),
4083            FdMatrix::from_slice(&f1y, 1, m).unwrap(),
4084        ])
4085        .unwrap();
4086
4087        // f2(t) = (sin(2π(t-0.1)), cos(2π(t-0.1))) — phase shifted
4088        let f2x: Vec<f64> = t.iter().map(|&ti| (pi2 * (ti - 0.1)).sin()).collect();
4089        let f2y: Vec<f64> = t.iter().map(|&ti| (pi2 * (ti - 0.1)).cos()).collect();
4090        let f2 = FdCurveSet::from_dims(vec![
4091            FdMatrix::from_slice(&f2x, 1, m).unwrap(),
4092            FdMatrix::from_slice(&f2y, 1, m).unwrap(),
4093        ])
4094        .unwrap();
4095
4096        let result = elastic_align_pair_nd(&f1, &f2, &t, 0.0);
4097        assert!(
4098            result.distance.is_finite(),
4099            "Distance should be finite, got {}",
4100            result.distance
4101        );
4102        assert_eq!(result.f_aligned.len(), 2);
4103        assert_eq!(result.f_aligned[0].len(), m);
4104        // Warp should deviate from identity (non-trivial alignment)
4105        let max_dev: f64 = result
4106            .gamma
4107            .iter()
4108            .zip(t.iter())
4109            .map(|(g, ti)| (g - ti).abs())
4110            .fold(0.0_f64, f64::max);
4111        assert!(
4112            max_dev > 0.01,
4113            "Shifted curves should require non-trivial warp, max dev = {max_dev}"
4114        );
4115    }
4116
4117    // ── Landmark-constrained alignment ──
4118
4119    #[test]
4120    fn test_constrained_no_landmarks_matches_unconstrained() {
4121        let m = 50;
4122        let t = uniform_grid(m);
4123        let f1: Vec<f64> = t
4124            .iter()
4125            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
4126            .collect();
4127        let f2: Vec<f64> = t
4128            .iter()
4129            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
4130            .collect();
4131
4132        let r_free = elastic_align_pair(&f1, &f2, &t, 0.0);
4133        let r_const = elastic_align_pair_constrained(&f1, &f2, &t, &[], 0.0);
4134
4135        // Should match unconstrained
4136        for j in 0..m {
4137            assert!(
4138                (r_free.gamma[j] - r_const.gamma[j]).abs() < 1e-10,
4139                "No-landmark constrained should match unconstrained at {j}"
4140            );
4141        }
4142        assert!(r_const.enforced_landmarks.is_empty());
4143    }
4144
4145    #[test]
4146    fn test_constrained_single_landmark_enforced() {
4147        let m = 60;
4148        let t = uniform_grid(m);
4149        let f1: Vec<f64> = t
4150            .iter()
4151            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
4152            .collect();
4153        let f2: Vec<f64> = t
4154            .iter()
4155            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
4156            .collect();
4157
4158        // Constrain midpoint: target_t=0.5 should map to source_t=0.5
4159        let result = elastic_align_pair_constrained(&f1, &f2, &t, &[(0.5, 0.5)], 0.0);
4160
4161        // Gamma at the midpoint should be close to 0.5
4162        let mid_idx = snap_to_grid(0.5, &t);
4163        assert!(
4164            (result.gamma[mid_idx] - 0.5).abs() < 0.05,
4165            "Constrained gamma at midpoint should be ~0.5, got {}",
4166            result.gamma[mid_idx]
4167        );
4168        assert_eq!(result.enforced_landmarks.len(), 1);
4169    }
4170
4171    #[test]
4172    fn test_constrained_multiple_landmarks() {
4173        let m = 80;
4174        let t = uniform_grid(m);
4175        let f1: Vec<f64> = t
4176            .iter()
4177            .map(|&ti| (4.0 * std::f64::consts::PI * ti).sin())
4178            .collect();
4179        let f2: Vec<f64> = t
4180            .iter()
4181            .map(|&ti| (4.0 * std::f64::consts::PI * (ti - 0.05)).sin())
4182            .collect();
4183
4184        let landmarks = vec![(0.25, 0.25), (0.5, 0.5), (0.75, 0.75)];
4185        let result = elastic_align_pair_constrained(&f1, &f2, &t, &landmarks, 0.0);
4186
4187        // Gamma should pass through (or near) each landmark
4188        for &(tt, st) in &landmarks {
4189            let idx = snap_to_grid(tt, &t);
4190            assert!(
4191                (result.gamma[idx] - st).abs() < 0.05,
4192                "Gamma at t={tt} should be ~{st}, got {}",
4193                result.gamma[idx]
4194            );
4195        }
4196    }
4197
4198    #[test]
4199    fn test_constrained_monotone_gamma() {
4200        let m = 60;
4201        let t = uniform_grid(m);
4202        let f1: Vec<f64> = t
4203            .iter()
4204            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
4205            .collect();
4206        let f2: Vec<f64> = t
4207            .iter()
4208            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
4209            .collect();
4210
4211        let result = elastic_align_pair_constrained(&f1, &f2, &t, &[(0.3, 0.3), (0.7, 0.7)], 0.0);
4212
4213        // Gamma should be non-decreasing
4214        for j in 1..m {
4215            assert!(
4216                result.gamma[j] >= result.gamma[j - 1] - 1e-10,
4217                "Gamma should be monotone: gamma[{}]={} < gamma[{}]={}",
4218                j,
4219                result.gamma[j],
4220                j - 1,
4221                result.gamma[j - 1]
4222            );
4223        }
4224        // Boundary conditions
4225        assert!((result.gamma[0] - t[0]).abs() < 1e-10);
4226        assert!((result.gamma[m - 1] - t[m - 1]).abs() < 1e-10);
4227    }
4228
4229    #[test]
4230    fn test_constrained_distance_ge_unconstrained() {
4231        let m = 60;
4232        let t = uniform_grid(m);
4233        let f1: Vec<f64> = t
4234            .iter()
4235            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
4236            .collect();
4237        let f2: Vec<f64> = t
4238            .iter()
4239            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.15)).sin())
4240            .collect();
4241
4242        let r_free = elastic_align_pair(&f1, &f2, &t, 0.0);
4243        let r_const = elastic_align_pair_constrained(&f1, &f2, &t, &[(0.5, 0.5)], 0.0);
4244
4245        // Constrained distance should be >= unconstrained (constraints reduce freedom)
4246        assert!(
4247            r_const.distance >= r_free.distance - 1e-6,
4248            "Constrained distance ({}) should be >= unconstrained ({})",
4249            r_const.distance,
4250            r_free.distance
4251        );
4252    }
4253
4254    #[test]
4255    fn test_constrained_with_landmark_detection() {
4256        let m = 80;
4257        let t = uniform_grid(m);
4258        let f1: Vec<f64> = t
4259            .iter()
4260            .map(|&ti| (4.0 * std::f64::consts::PI * ti).sin())
4261            .collect();
4262        let f2: Vec<f64> = t
4263            .iter()
4264            .map(|&ti| (4.0 * std::f64::consts::PI * (ti - 0.05)).sin())
4265            .collect();
4266
4267        let result = elastic_align_pair_with_landmarks(
4268            &f1,
4269            &f2,
4270            &t,
4271            crate::landmark::LandmarkKind::Peak,
4272            0.1,
4273            0,
4274            0.0,
4275        );
4276
4277        assert_eq!(result.gamma.len(), m);
4278        assert_eq!(result.f_aligned.len(), m);
4279        assert!(result.distance.is_finite());
4280        // Should be monotone
4281        for j in 1..m {
4282            assert!(
4283                result.gamma[j] >= result.gamma[j - 1] - 1e-10,
4284                "Gamma should be monotone at j={j}"
4285            );
4286        }
4287    }
4288}