Skip to main content

fdars_core/
alignment.rs

1//! Elastic alignment and SRSF (Square-Root Slope Function) transforms.
2//!
3//! This module provides phase-amplitude separation for functional data via
4//! the elastic framework. Key capabilities:
5//!
6//! - [`srsf_transform`] / [`srsf_inverse`] — SRSF representation and reconstruction
7//! - [`elastic_align_pair`] — Pairwise curve alignment via dynamic programming
8//! - [`elastic_distance`] — Elastic (Fisher-Rao) distance between curves
9//! - [`align_to_target`] — Align a set of curves to a common target
10//! - [`karcher_mean`] — Karcher (Fréchet) mean in the elastic metric
11//! - [`elastic_self_distance_matrix`] / [`elastic_cross_distance_matrix`] — Distance matrices
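//! - [`reparameterize_curve`] / [`compose_warps`] — Warping utilities
//! - [`elastic_decomposition`] / [`amplitude_distance`] / [`phase_distance_pair`] — Phase-amplitude decomposition and distances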
13
14use crate::fdata::{deriv_1d, mean_1d};
15use crate::helpers::{
16    cumulative_trapz, gradient_uniform, l2_distance, linear_interp, simpsons_weights,
17};
18use crate::iter_maybe_parallel;
19use crate::matrix::FdMatrix;
20use crate::smoothing::nadaraya_watson;
21use crate::warping::{
22    exp_map_sphere, gam_to_psi, inv_exp_map_sphere, invert_gamma, l2_norm_l2, normalize_warp,
23    psi_to_gam,
24};
25#[cfg(feature = "parallel")]
26use rayon::iter::ParallelIterator;
27
28// ─── Types ──────────────────────────────────────────────────────────────────
29
/// Outcome of elastically aligning a single curve onto a reference curve.
#[derive(Clone, Debug)]
pub struct AlignmentResult {
    /// Warping function γ (a reparameterization of the domain onto itself),
    /// sampled on the same grid as the input curves.
    pub gamma: Vec<f64>,
    /// The input curve after reparameterization by γ.
    pub f_aligned: Vec<f64>,
    /// Elastic distance remaining once the curves are aligned.
    pub distance: f64,
}
40
41/// Result of aligning a set of curves to a common target.
42#[derive(Debug, Clone)]
43pub struct AlignmentSetResult {
44    /// Warping functions (n × m).
45    pub gammas: FdMatrix,
46    /// Aligned curves (n × m).
47    pub aligned_data: FdMatrix,
48    /// Elastic distances for each curve.
49    pub distances: Vec<f64>,
50}
51
52/// Result of the Karcher mean computation.
53#[derive(Debug, Clone)]
54pub struct KarcherMeanResult {
55    /// Karcher mean curve.
56    pub mean: Vec<f64>,
57    /// SRSF of the Karcher mean.
58    pub mean_srsf: Vec<f64>,
59    /// Final warping functions (n × m).
60    pub gammas: FdMatrix,
61    /// Curves aligned to the mean (n × m).
62    pub aligned_data: FdMatrix,
63    /// Number of iterations used.
64    pub n_iter: usize,
65    /// Whether the algorithm converged.
66    pub converged: bool,
67    /// Pre-computed SRSFs of aligned curves (n × m), if available.
68    /// When set, FPCA functions use these instead of recomputing from `aligned_data`.
69    pub aligned_srsfs: Option<FdMatrix>,
70}
71
// Shared numerical helpers (interpolation, integration, sphere maps) live in
// `crate::helpers` and `crate::warping`.
73
74/// Karcher mean of warping functions on the Hilbert sphere, then invert.
75/// Port of fdasrvf's `SqrtMeanInverse`.
76///
77/// Takes a matrix of warping functions (n × m) on the argvals domain,
78/// computes the Fréchet mean of their sqrt-derivative representations
79/// on the unit Hilbert sphere, converts back to a warping function,
80/// and returns its inverse (on the argvals domain).
81/// One Karcher iteration on the Hilbert sphere: compute mean shooting vector and update mu.
82///
83/// Returns `true` if converged (vbar norm ≤ threshold).
84fn karcher_sphere_step(mu: &mut Vec<f64>, psis: &[Vec<f64>], time: &[f64], step_size: f64) -> bool {
85    let m = mu.len();
86    let n = psis.len();
87    let mut vbar = vec![0.0; m];
88    for psi in psis {
89        let v = inv_exp_map_sphere(mu, psi, time);
90        for j in 0..m {
91            vbar[j] += v[j];
92        }
93    }
94    for j in 0..m {
95        vbar[j] /= n as f64;
96    }
97    if l2_norm_l2(&vbar, time) <= 1e-8 {
98        return true;
99    }
100    let scaled: Vec<f64> = vbar.iter().map(|&v| v * step_size).collect();
101    *mu = exp_map_sphere(mu, &scaled, time);
102    false
103}
104
105pub(crate) fn sqrt_mean_inverse(gammas: &FdMatrix, argvals: &[f64]) -> Vec<f64> {
106    let (n, m) = gammas.shape();
107    let t0 = argvals[0];
108    let t1 = argvals[m - 1];
109    let domain = t1 - t0;
110
111    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
112    let binsize = 1.0 / (m - 1) as f64;
113
114    let psis: Vec<Vec<f64>> = (0..n)
115        .map(|i| {
116            let gam_01: Vec<f64> = (0..m).map(|j| (gammas[(i, j)] - t0) / domain).collect();
117            gam_to_psi(&gam_01, binsize)
118        })
119        .collect();
120
121    let mut mu = vec![0.0; m];
122    for psi in &psis {
123        for j in 0..m {
124            mu[j] += psi[j];
125        }
126    }
127    for j in 0..m {
128        mu[j] /= n as f64;
129    }
130
131    for _ in 0..501 {
132        if karcher_sphere_step(&mut mu, &psis, &time, 0.3) {
133            break;
134        }
135    }
136
137    let gam_mu = psi_to_gam(&mu, &time);
138    let gam_inv = invert_gamma(&gam_mu, &time);
139    gam_inv.iter().map(|&g| t0 + g * domain).collect()
140}
141
142// ─── SRSF Transform and Inverse ─────────────────────────────────────────────
143
144/// Compute the Square-Root Slope Function (SRSF) transform.
145///
146/// For each curve f, the SRSF is: `q(t) = sign(f'(t)) * sqrt(|f'(t)|)`
147///
148/// # Arguments
149/// * `data` — Functional data matrix (n × m)
150/// * `argvals` — Evaluation points (length m)
151///
152/// # Returns
153/// FdMatrix of SRSFs with the same shape as input.
154pub fn srsf_transform(data: &FdMatrix, argvals: &[f64]) -> FdMatrix {
155    let (n, m) = data.shape();
156    if n == 0 || m == 0 || argvals.len() != m {
157        return FdMatrix::zeros(n, m);
158    }
159
160    let deriv = deriv_1d(data, argvals, 1);
161
162    let mut result = FdMatrix::zeros(n, m);
163    for i in 0..n {
164        for j in 0..m {
165            let d = deriv[(i, j)];
166            result[(i, j)] = d.signum() * d.abs().sqrt();
167        }
168    }
169    result
170}
171
172/// Reconstruct a curve from its SRSF representation.
173///
174/// Given SRSF q and initial value f0, reconstructs: `f(t) = f0 + ∫₀ᵗ q(s)|q(s)| ds`
175///
176/// # Arguments
177/// * `q` — SRSF values (length m)
178/// * `argvals` — Evaluation points (length m)
179/// * `f0` — Initial value f(argvals\[0\])
180///
181/// # Returns
182/// Reconstructed curve values.
183pub fn srsf_inverse(q: &[f64], argvals: &[f64], f0: f64) -> Vec<f64> {
184    let m = q.len();
185    if m == 0 {
186        return Vec::new();
187    }
188
189    // Integrand: q(s) * |q(s)|
190    let integrand: Vec<f64> = q.iter().map(|&qi| qi * qi.abs()).collect();
191    let integral = cumulative_trapz(&integrand, argvals);
192
193    integral.iter().map(|&v| f0 + v).collect()
194}
195
196// ─── Reparameterization ─────────────────────────────────────────────────────
197
198/// Reparameterize a curve by a warping function.
199///
200/// Computes `f(γ(t))` via linear interpolation.
201///
202/// # Arguments
203/// * `f` — Curve values (length m)
204/// * `argvals` — Evaluation points (length m)
205/// * `gamma` — Warping function values (length m)
206pub fn reparameterize_curve(f: &[f64], argvals: &[f64], gamma: &[f64]) -> Vec<f64> {
207    gamma
208        .iter()
209        .map(|&g| linear_interp(argvals, f, g))
210        .collect()
211}
212
213/// Compose two warping functions: `(γ₁ ∘ γ₂)(t) = γ₁(γ₂(t))`.
214///
215/// # Arguments
216/// * `gamma1` — Outer warping function (length m)
217/// * `gamma2` — Inner warping function (length m)
218/// * `argvals` — Evaluation points (length m)
219pub fn compose_warps(gamma1: &[f64], gamma2: &[f64], argvals: &[f64]) -> Vec<f64> {
220    gamma2
221        .iter()
222        .map(|&g| linear_interp(argvals, gamma1, g))
223        .collect()
224}
225
226// ─── Dynamic Programming Alignment ──────────────────────────────────────────
227// Faithful port of fdasrvf's DP algorithm (dp_grid.cpp / dp_nbhd.cpp).
228
/// Greatest common divisor via the iterative Euclidean algorithm.
/// `gcd(x, 0) == x`; in particular `gcd(0, 0) == 0`.
#[cfg(test)]
fn gcd(a: usize, b: usize) -> usize {
    let mut x = a;
    let mut y = b;
    while y != 0 {
        let rem = x % y;
        x = y;
        y = rem;
    }
    x
}
238
/// Enumerate the coprime DP move set: every `(i, j)` with `1 ≤ i, j ≤ nbhd_dim`
/// and `gcd(i, j) == 1`, in row-major order (`i` outer, `j` inner).
/// With `nbhd_dim = 7` this yields the 35 pairs fdasrvf uses by default.
#[cfg(test)]
fn generate_coprime_nbhd(nbhd_dim: usize) -> Vec<(usize, usize)> {
    (1..=nbhd_dim)
        .flat_map(|i| (1..=nbhd_dim).map(move |j| (i, j)))
        .filter(|&(i, j)| gcd(i, j) == 1)
        .collect()
}
253
/// Coprime DP move set for `nbhd_dim = 7` (the fdasrvf default), stored as a
/// table for the DP; `generate_coprime_nbhd` exists only under `cfg(test)`.
///
/// Holds every `(dr, dc)` with `1 ≤ dr, dc ≤ 7` and `gcd(dr, dc) = 1` — 35
/// pairs — ordered by `dr`, then by `dc`. `dr` is the row step (q2 direction)
/// and `dc` the column step (q1 direction). Order matters: `dp_relax_cell` only
/// replaces a predecessor with a strictly cheaper one, so ties go to the entry
/// listed first.
#[rustfmt::skip]
const COPRIME_NBHD_7: [(usize, usize); 35] = [
    /* dr = 1 */ (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7),
    /* dr = 2 */ (2, 1), (2, 3), (2, 5), (2, 7),
    /* dr = 3 */ (3, 1), (3, 2), (3, 4), (3, 5), (3, 7),
    /* dr = 4 */ (4, 1), (4, 3), (4, 5), (4, 7),
    /* dr = 5 */ (5, 1), (5, 2), (5, 3), (5, 4), (5, 6), (5, 7),
    /* dr = 6 */ (6, 1), (6, 5), (6, 7),
    /* dr = 7 */ (7, 1), (7, 2), (7, 3), (7, 4), (7, 5), (7, 6),
];
267
/// Cost of one DP move from grid point `(sr, sc)` to `(tr, tc)`.
///
/// Port of fdasrvf's `dp_edge_weight` for 1-D curves on a shared uniform grid.
/// Rows index `q2` and columns index `q1`, matching the fdasrvf convention.
///
/// Along the move the warp has constant slope
/// `γ' = (argvals[tr] - argvals[sr]) / (argvals[tc] - argvals[sc])`.
/// Both the `tc - sc` column steps and the `tr - sr` row steps are mapped onto
/// `[0, 1]`. Their combined breakpoints cut that interval into pieces, and on
/// each piece the integrand `(q1[col] - √γ' · q2[row])²` is constant. The
/// piecewise sum is rescaled by the column-direction length
/// `argvals[tc] - argvals[sc]`.
///
/// Returns `f64::INFINITY` for a move that does not advance in both directions.
#[inline]
fn dp_edge_weight(
    q1: &[f64],
    q2: &[f64],
    argvals: &[f64],
    sc: usize,
    tc: usize,
    sr: usize,
    tr: usize,
) -> f64 {
    let cols = tc - sc;
    let rows = tr - sr;
    if cols == 0 || rows == 0 {
        return f64::INFINITY;
    }

    let col_span = argvals[tc] - argvals[sc];
    let root_slope = ((argvals[tr] - argvals[sr]) / col_span).sqrt();
    let (cols_f, rows_f) = (cols as f64, rows as f64);

    let mut total = 0.0;
    // `a` walks the q1 (column) steps, `b` the q2 (row) steps.
    let (mut a, mut b) = (0usize, 0usize);
    while a < cols && b < rows {
        let a_end = (a + 1) as f64 / cols_f;
        let b_end = (b + 1) as f64 / rows_f;
        let start = (a as f64 / cols_f).max(b as f64 / rows_f);
        let piece = a_end.min(b_end) - start;
        if piece > 0.0 {
            let resid = q1[sc + a] - root_slope * q2[sr + b];
            total += resid * resid * piece;
        }
        // Breakpoints are compared exactly: equal fractions such as 1/3 and 2/6
        // divide to the same f64, so coinciding breakpoints advance both walkers.
        let (a_done, b_done) = (a_end <= b_end, b_end <= a_end);
        if a_done {
            a += 1;
        }
        if b_done {
            b += 1;
        }
    }

    total * col_span
}
329
/// Roughness penalty `λ · (γ' - 1)² · Δt` for one DP move.
///
/// `Δt = argvals[tc] - argvals[sc]` is the column-direction length of the move,
/// and `γ'` is the move's slope (row length over column length). The penalty is
/// zero whenever `lambda` is not strictly positive, including NaN.
#[inline]
fn dp_lambda_penalty(
    argvals: &[f64],
    sc: usize,
    tc: usize,
    sr: usize,
    tr: usize,
    lambda: f64,
) -> f64 {
    if lambda > 0.0 {
        let col_len = argvals[tc] - argvals[sc];
        let row_len = argvals[tr] - argvals[sr];
        let deviation = row_len / col_len - 1.0;
        let squared = deviation * deviation;
        lambda * squared * col_len
    } else {
        0.0
    }
}
348
/// Follow parent pointers back from the bottom-right cell of the DP grid.
///
/// `parent` is a row-major `nrows × ncols` array of flat predecessor indices,
/// with `u32::MAX` meaning "no predecessor". The walk stops at cell 0 or at the
/// first cell without a predecessor. The visited `(row, col)` cells are
/// returned in forward order, ending at `(nrows - 1, ncols - 1)`; when the end
/// cell is connected to the origin, the path starts at `(0, 0)`.
fn dp_traceback(parent: &[u32], nrows: usize, ncols: usize) -> Vec<(usize, usize)> {
    let to_cell = |flat: usize| (flat / ncols, flat % ncols);
    let mut node = (nrows - 1) * ncols + (ncols - 1);
    let mut cells = Vec::with_capacity(nrows + ncols);
    cells.push(to_cell(node));
    while node != 0 && parent[node] != u32::MAX {
        node = parent[node] as usize;
        cells.push(to_cell(node));
    }
    cells.reverse();
    cells
}
365
366/// Try to relax cell `(tr, tc)` from each coprime neighbor, updating cost and parent.
367#[inline]
368fn dp_relax_cell<F>(
369    e: &mut [f64],
370    parent: &mut [u32],
371    ncols: usize,
372    tr: usize,
373    tc: usize,
374    edge_cost: &F,
375) where
376    F: Fn(usize, usize, usize, usize) -> f64,
377{
378    let idx = tr * ncols + tc;
379    for &(dr, dc) in &COPRIME_NBHD_7 {
380        if dr > tr || dc > tc {
381            continue;
382        }
383        let sr = tr - dr;
384        let sc = tc - dc;
385        let src_idx = sr * ncols + sc;
386        if e[src_idx] == f64::INFINITY {
387            continue;
388        }
389        let cost = e[src_idx] + edge_cost(sr, sc, tr, tc);
390        if cost < e[idx] {
391            e[idx] = cost;
392            parent[idx] = src_idx as u32;
393        }
394    }
395}
396
397/// Shared DP grid fill + traceback using the coprime neighborhood.
398///
399/// `edge_cost(sr, sc, tr, tc)` returns the combined edge weight + penalty for
400/// a move from local (sr, sc) to local (tr, tc). Returns the raw local-index
401/// path from (0,0) to (nrows-1, ncols-1).
402fn dp_grid_solve<F>(nrows: usize, ncols: usize, edge_cost: F) -> Vec<(usize, usize)>
403where
404    F: Fn(usize, usize, usize, usize) -> f64,
405{
406    let mut e = vec![f64::INFINITY; nrows * ncols];
407    let mut parent = vec![u32::MAX; nrows * ncols];
408    e[0] = 0.0;
409
410    for tr in 0..nrows {
411        for tc in 0..ncols {
412            if tr == 0 && tc == 0 {
413                continue;
414            }
415            dp_relax_cell(&mut e, &mut parent, ncols, tr, tc, &edge_cost);
416        }
417    }
418
419    dp_traceback(&parent, nrows, ncols)
420}
421
422/// Convert a DP path (local row,col indices) to an interpolated+normalized gamma warp.
423fn dp_path_to_gamma(path: &[(usize, usize)], argvals: &[f64]) -> Vec<f64> {
424    let path_tc: Vec<f64> = path.iter().map(|&(_, c)| argvals[c]).collect();
425    let path_tr: Vec<f64> = path.iter().map(|&(r, _)| argvals[r]).collect();
426    let mut gamma: Vec<f64> = argvals
427        .iter()
428        .map(|&t| linear_interp(&path_tc, &path_tr, t))
429        .collect();
430    normalize_warp(&mut gamma, argvals);
431    gamma
432}
433
434/// Core DP alignment between two SRSFs on a grid.
435///
436/// Finds the optimal warping γ minimizing ‖q₁ - (q₂∘γ)√γ'‖².
437/// Uses fdasrvf's coprime neighborhood (nbhd_dim=7 → 35 move directions).
438/// SRSFs are L2-normalized before alignment (matching fdasrvf's `optimum.reparam`).
439pub(crate) fn dp_alignment_core(q1: &[f64], q2: &[f64], argvals: &[f64], lambda: f64) -> Vec<f64> {
440    let m = argvals.len();
441    if m < 2 {
442        return argvals.to_vec();
443    }
444
445    let norm1 = q1.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
446    let norm2 = q2.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
447    let q1n: Vec<f64> = q1.iter().map(|&v| v / norm1).collect();
448    let q2n: Vec<f64> = q2.iter().map(|&v| v / norm2).collect();
449
450    let path = dp_grid_solve(m, m, |sr, sc, tr, tc| {
451        dp_edge_weight(&q1n, &q2n, argvals, sc, tc, sr, tr)
452            + dp_lambda_penalty(argvals, sc, tc, sr, tr, lambda)
453    });
454
455    dp_path_to_gamma(&path, argvals)
456}
457
458// ─── Public Alignment Functions ─────────────────────────────────────────────
459
460/// Align curve f2 to curve f1 using the elastic framework.
461///
462/// Computes the optimal warping γ such that f2∘γ is as close as possible
463/// to f1 in the elastic (Fisher-Rao) metric.
464///
465/// # Arguments
466/// * `f1` — Target curve (length m)
467/// * `f2` — Curve to align (length m)
468/// * `argvals` — Evaluation points (length m)
469/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
470///
471/// # Returns
472/// [`AlignmentResult`] with warping function, aligned curve, and elastic distance.
473pub fn elastic_align_pair(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> AlignmentResult {
474    let m = f1.len();
475
476    // Build single-row FdMatrices for SRSF computation
477    let f1_mat = FdMatrix::from_slice(f1, 1, m).unwrap();
478    let f2_mat = FdMatrix::from_slice(f2, 1, m).unwrap();
479
480    let q1_mat = srsf_transform(&f1_mat, argvals);
481    let q2_mat = srsf_transform(&f2_mat, argvals);
482
483    let q1: Vec<f64> = q1_mat.row(0);
484    let q2: Vec<f64> = q2_mat.row(0);
485
486    // Find optimal warping via DP
487    let gamma = dp_alignment_core(&q1, &q2, argvals, lambda);
488
489    // Apply warping to f2
490    let f_aligned = reparameterize_curve(f2, argvals, &gamma);
491
492    // Compute elastic distance: L2 distance between q1 and aligned q2 SRSF
493    let f_aligned_mat = FdMatrix::from_slice(&f_aligned, 1, m).unwrap();
494    let q_aligned_mat = srsf_transform(&f_aligned_mat, argvals);
495    let q_aligned: Vec<f64> = q_aligned_mat.row(0);
496
497    let weights = simpsons_weights(argvals);
498    let distance = l2_distance(&q1, &q_aligned, &weights);
499
500    AlignmentResult {
501        gamma,
502        f_aligned,
503        distance,
504    }
505}
506
507/// Compute the elastic distance between two curves.
508///
509/// This is shorthand for aligning the pair and returning only the distance.
510///
511/// # Arguments
512/// * `f1` — First curve (length m)
513/// * `f2` — Second curve (length m)
514/// * `argvals` — Evaluation points (length m)
515/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
516pub fn elastic_distance(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> f64 {
517    elastic_align_pair(f1, f2, argvals, lambda).distance
518}
519
520/// Align all curves in `data` to a single target curve.
521///
522/// # Arguments
523/// * `data` — Functional data matrix (n × m)
524/// * `target` — Target curve to align to (length m)
525/// * `argvals` — Evaluation points (length m)
526/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
527///
528/// # Returns
529/// [`AlignmentSetResult`] with all warping functions, aligned curves, and distances.
530pub fn align_to_target(
531    data: &FdMatrix,
532    target: &[f64],
533    argvals: &[f64],
534    lambda: f64,
535) -> AlignmentSetResult {
536    let (n, m) = data.shape();
537
538    let results: Vec<AlignmentResult> = iter_maybe_parallel!(0..n)
539        .map(|i| {
540            let fi = data.row(i);
541            elastic_align_pair(target, &fi, argvals, lambda)
542        })
543        .collect();
544
545    let mut gammas = FdMatrix::zeros(n, m);
546    let mut aligned_data = FdMatrix::zeros(n, m);
547    let mut distances = Vec::with_capacity(n);
548
549    for (i, r) in results.into_iter().enumerate() {
550        for j in 0..m {
551            gammas[(i, j)] = r.gamma[j];
552            aligned_data[(i, j)] = r.f_aligned[j];
553        }
554        distances.push(r.distance);
555    }
556
557    AlignmentSetResult {
558        gammas,
559        aligned_data,
560        distances,
561    }
562}
563
564// ─── Distance Matrices ──────────────────────────────────────────────────────
565
566/// Compute the symmetric elastic distance matrix for a set of curves.
567///
568/// Uses upper-triangle computation with parallelism, following the
569/// `self_distance_matrix` pattern from `metric.rs`.
570///
571/// # Arguments
572/// * `data` — Functional data matrix (n × m)
573/// * `argvals` — Evaluation points (length m)
574/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
575///
576/// # Returns
577/// Symmetric n × n distance matrix.
578pub fn elastic_self_distance_matrix(data: &FdMatrix, argvals: &[f64], lambda: f64) -> FdMatrix {
579    let n = data.nrows();
580
581    let upper_vals: Vec<f64> = iter_maybe_parallel!(0..n)
582        .flat_map(|i| {
583            let fi = data.row(i);
584            ((i + 1)..n)
585                .map(|j| {
586                    let fj = data.row(j);
587                    elastic_distance(&fi, &fj, argvals, lambda)
588                })
589                .collect::<Vec<_>>()
590        })
591        .collect();
592
593    let mut dist = FdMatrix::zeros(n, n);
594    let mut idx = 0;
595    for i in 0..n {
596        for j in (i + 1)..n {
597            let d = upper_vals[idx];
598            dist[(i, j)] = d;
599            dist[(j, i)] = d;
600            idx += 1;
601        }
602    }
603    dist
604}
605
606/// Compute the elastic distance matrix between two sets of curves.
607///
608/// # Arguments
609/// * `data1` — First dataset (n1 × m)
610/// * `data2` — Second dataset (n2 × m)
611/// * `argvals` — Evaluation points (length m)
612/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
613///
614/// # Returns
615/// n1 × n2 distance matrix.
616pub fn elastic_cross_distance_matrix(
617    data1: &FdMatrix,
618    data2: &FdMatrix,
619    argvals: &[f64],
620    lambda: f64,
621) -> FdMatrix {
622    let n1 = data1.nrows();
623    let n2 = data2.nrows();
624
625    let vals: Vec<f64> = iter_maybe_parallel!(0..n1)
626        .flat_map(|i| {
627            let fi = data1.row(i);
628            (0..n2)
629                .map(|j| {
630                    let fj = data2.row(j);
631                    elastic_distance(&fi, &fj, argvals, lambda)
632                })
633                .collect::<Vec<_>>()
634        })
635        .collect();
636
637    let mut dist = FdMatrix::zeros(n1, n2);
638    for i in 0..n1 {
639        for j in 0..n2 {
640            dist[(i, j)] = vals[i * n2 + j];
641        }
642    }
643    dist
644}
645
646// ─── Phase-Amplitude Decomposition ──────────────────────────────────────────
647
648/// Result of elastic phase-amplitude decomposition.
649#[derive(Debug, Clone)]
650pub struct DecompositionResult {
651    /// Full alignment result.
652    pub alignment: AlignmentResult,
653    /// Amplitude distance: SRSF distance after alignment.
654    pub d_amplitude: f64,
655    /// Phase distance: geodesic distance of warp from identity.
656    pub d_phase: f64,
657}
658
659/// Perform elastic phase-amplitude decomposition of two curves.
660///
661/// Returns both the alignment result and the separate amplitude and phase distances.
662///
663/// # Arguments
664/// * `f1` — Target curve (length m)
665/// * `f2` — Curve to decompose against f1 (length m)
666/// * `argvals` — Evaluation points (length m)
667/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
668pub fn elastic_decomposition(
669    f1: &[f64],
670    f2: &[f64],
671    argvals: &[f64],
672    lambda: f64,
673) -> DecompositionResult {
674    let alignment = elastic_align_pair(f1, f2, argvals, lambda);
675    let d_amplitude = alignment.distance;
676    let d_phase = crate::warping::phase_distance(&alignment.gamma, argvals);
677    DecompositionResult {
678        alignment,
679        d_amplitude,
680        d_phase,
681    }
682}
683
684/// Compute the amplitude distance between two curves (= elastic distance after alignment).
685pub fn amplitude_distance(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> f64 {
686    elastic_distance(f1, f2, argvals, lambda)
687}
688
689/// Compute the phase distance between two curves (geodesic distance of optimal warp from identity).
690pub fn phase_distance_pair(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> f64 {
691    let alignment = elastic_align_pair(f1, f2, argvals, lambda);
692    crate::warping::phase_distance(&alignment.gamma, argvals)
693}
694
695/// Compute the symmetric phase distance matrix for a set of curves.
696pub fn phase_self_distance_matrix(data: &FdMatrix, argvals: &[f64], lambda: f64) -> FdMatrix {
697    let n = data.nrows();
698
699    let upper_vals: Vec<f64> = iter_maybe_parallel!(0..n)
700        .flat_map(|i| {
701            let fi = data.row(i);
702            ((i + 1)..n)
703                .map(|j| {
704                    let fj = data.row(j);
705                    phase_distance_pair(&fi, &fj, argvals, lambda)
706                })
707                .collect::<Vec<_>>()
708        })
709        .collect();
710
711    let mut dist = FdMatrix::zeros(n, n);
712    let mut idx = 0;
713    for i in 0..n {
714        for j in (i + 1)..n {
715            let d = upper_vals[idx];
716            dist[(i, j)] = d;
717            dist[(j, i)] = d;
718            idx += 1;
719        }
720    }
721    dist
722}
723
724/// Compute the symmetric amplitude distance matrix (= elastic self distance matrix).
725pub fn amplitude_self_distance_matrix(data: &FdMatrix, argvals: &[f64], lambda: f64) -> FdMatrix {
726    elastic_self_distance_matrix(data, argvals, lambda)
727}
728
729// ─── Karcher Mean ───────────────────────────────────────────────────────────
730
/// Relative change between successive mean SRSFs.
///
/// Computes `‖q_new - q_old‖₂ / max(‖q_old‖₂, 1e-10)` with the plain
/// (unweighted) discrete L2 norm, which is the convergence metric of R's
/// fdasrvf `time_warping`. Only the overlapping prefix is compared if the
/// lengths differ.
fn relative_change(q_old: &[f64], q_new: &[f64]) -> f64 {
    let step_sq: f64 = q_new
        .iter()
        .zip(q_old)
        .map(|(&b, &a)| (a - b).powi(2))
        .sum();
    let base_sq: f64 = q_old.iter().map(|&v| v * v).sum();
    step_sq.sqrt() / base_sq.sqrt().max(1e-10)
}
745
746/// Compute a single SRSF from a slice (single-row convenience).
747fn srsf_single(f: &[f64], argvals: &[f64]) -> Vec<f64> {
748    let m = f.len();
749    let mat = FdMatrix::from_slice(f, 1, m).unwrap();
750    let q_mat = srsf_transform(&mat, argvals);
751    q_mat.row(0)
752}
753
754/// Align a single SRSF q2 to q1 and return (gamma, aligned_q).
755fn align_srsf_pair(q1: &[f64], q2: &[f64], argvals: &[f64], lambda: f64) -> (Vec<f64>, Vec<f64>) {
756    let gamma = dp_alignment_core(q1, q2, argvals, lambda);
757
758    // Warp q2 by gamma and adjust by sqrt(gamma')
759    let q2_warped = reparameterize_curve(q2, argvals, &gamma);
760
761    // Compute gamma' via finite differences
762    let m = gamma.len();
763    let mut gamma_dot = vec![0.0; m];
764    gamma_dot[0] = (gamma[1] - gamma[0]) / (argvals[1] - argvals[0]);
765    for j in 1..(m - 1) {
766        gamma_dot[j] = (gamma[j + 1] - gamma[j - 1]) / (argvals[j + 1] - argvals[j - 1]);
767    }
768    gamma_dot[m - 1] = (gamma[m - 1] - gamma[m - 2]) / (argvals[m - 1] - argvals[m - 2]);
769
770    // q2_aligned = (q2 ∘ γ) * sqrt(γ')
771    let q2_aligned: Vec<f64> = q2_warped
772        .iter()
773        .zip(gamma_dot.iter())
774        .map(|(&q, &gd)| q * gd.max(0.0).sqrt())
775        .collect();
776
777    (gamma, q2_aligned)
778}
779
780/// Compute the Karcher (Fréchet) mean in the elastic metric.
781///
782/// Iteratively aligns all curves to the current mean estimate in SRSF space,
783/// computes the pointwise mean of aligned SRSFs, and reconstructs the mean curve.
784///
785/// # Arguments
786/// * `data` — Functional data matrix (n × m)
787/// * `argvals` — Evaluation points (length m)
788/// * `max_iter` — Maximum number of iterations
789/// * `tol` — Convergence tolerance for the SRSF mean
790///
791/// # Returns
792/// [`KarcherMeanResult`] with mean curve, warping functions, aligned data, and convergence info.
793///
794/// # Examples
795///
796/// ```
797/// use fdars_core::simulation::{sim_fundata, EFunType, EValType};
798/// use fdars_core::alignment::karcher_mean;
799///
800/// let t: Vec<f64> = (0..50).map(|i| i as f64 / 49.0).collect();
801/// let data = sim_fundata(20, &t, 3, EFunType::Fourier, EValType::Exponential, Some(42));
802///
803/// let result = karcher_mean(&data, &t, 20, 1e-4, 0.0);
804/// assert_eq!(result.mean.len(), 50);
805/// assert!(result.n_iter <= 20);
806/// ```
807/// Accumulate alignment results: store gammas and return the mean of aligned SRSFs.
808fn accumulate_alignments(
809    results: &[(Vec<f64>, Vec<f64>)],
810    gammas: &mut FdMatrix,
811    m: usize,
812    n: usize,
813) -> Vec<f64> {
814    let mut mu_q_new = vec![0.0; m];
815    for (i, (gamma, q_aligned)) in results.iter().enumerate() {
816        for j in 0..m {
817            gammas[(i, j)] = gamma[j];
818            mu_q_new[j] += q_aligned[j];
819        }
820    }
821    for j in 0..m {
822        mu_q_new[j] /= n as f64;
823    }
824    mu_q_new
825}
826
827/// Apply stored warps to original curves to produce aligned data.
828fn apply_stored_warps(data: &FdMatrix, gammas: &FdMatrix, argvals: &[f64]) -> FdMatrix {
829    let (n, m) = data.shape();
830    let mut aligned = FdMatrix::zeros(n, m);
831    for i in 0..n {
832        let fi = data.row(i);
833        let gamma: Vec<f64> = (0..m).map(|j| gammas[(i, j)]).collect();
834        let f_aligned = reparameterize_curve(&fi, argvals, &gamma);
835        for j in 0..m {
836            aligned[(i, j)] = f_aligned[j];
837        }
838    }
839    aligned
840}
841
842/// Select the SRSF closest to the pointwise mean as template. Returns (mu_q, mu_f).
843fn select_template(srsf_mat: &FdMatrix, data: &FdMatrix, argvals: &[f64]) -> (Vec<f64>, Vec<f64>) {
844    let (n, m) = srsf_mat.shape();
845    let mnq = mean_1d(srsf_mat);
846    let mut min_dist = f64::INFINITY;
847    let mut min_idx = 0;
848    for i in 0..n {
849        let dist_sq: f64 = (0..m).map(|j| (srsf_mat[(i, j)] - mnq[j]).powi(2)).sum();
850        if dist_sq < min_dist {
851            min_dist = dist_sq;
852            min_idx = i;
853        }
854    }
855    let _ = argvals; // kept for API consistency
856    (srsf_mat.row(min_idx), data.row(min_idx))
857}
858
859/// Pre-centering: align all curves to template, compute inverse mean warp, re-center.
860fn pre_center_template(
861    data: &FdMatrix,
862    mu_q: &[f64],
863    mu: &[f64],
864    argvals: &[f64],
865    lambda: f64,
866) -> (Vec<f64>, Vec<f64>) {
867    let (n, m) = data.shape();
868    let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
869        .map(|i| {
870            let fi = data.row(i);
871            let qi = srsf_single(&fi, argvals);
872            align_srsf_pair(mu_q, &qi, argvals, lambda)
873        })
874        .collect();
875
876    let mut init_gammas = FdMatrix::zeros(n, m);
877    for (i, (gamma, _)) in align_results.iter().enumerate() {
878        for j in 0..m {
879            init_gammas[(i, j)] = gamma[j];
880        }
881    }
882
883    let gam_inv = sqrt_mean_inverse(&init_gammas, argvals);
884    let mu_new = reparameterize_curve(mu, argvals, &gam_inv);
885    let mu_q_new = srsf_single(&mu_new, argvals);
886    (mu_q_new, mu_new)
887}
888
889/// Post-convergence centering: center mean SRSF and warps via SqrtMeanInverse.
890fn post_center_results(
891    data: &FdMatrix,
892    mu_q: &[f64],
893    final_gammas: &mut FdMatrix,
894    argvals: &[f64],
895) -> (Vec<f64>, Vec<f64>, FdMatrix) {
896    let (n, m) = data.shape();
897    let gam_inv = sqrt_mean_inverse(final_gammas, argvals);
898    let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
899    let gam_inv_dev = gradient_uniform(&gam_inv, h);
900
901    let mu_q_warped = reparameterize_curve(mu_q, argvals, &gam_inv);
902    let mu_q_centered: Vec<f64> = mu_q_warped
903        .iter()
904        .zip(gam_inv_dev.iter())
905        .map(|(&q, &gd)| q * gd.max(0.0).sqrt())
906        .collect();
907
908    for i in 0..n {
909        let gam_i: Vec<f64> = (0..m).map(|j| final_gammas[(i, j)]).collect();
910        let gam_centered = reparameterize_curve(&gam_i, argvals, &gam_inv);
911        for j in 0..m {
912            final_gammas[(i, j)] = gam_centered[j];
913        }
914    }
915
916    let initial_mean = mean_1d(data);
917    let mu = srsf_inverse(&mu_q_centered, argvals, initial_mean[0]);
918    let final_aligned = apply_stored_warps(data, final_gammas, argvals);
919    (mu, mu_q_centered, final_aligned)
920}
921
922pub fn karcher_mean(
923    data: &FdMatrix,
924    argvals: &[f64],
925    max_iter: usize,
926    tol: f64,
927    lambda: f64,
928) -> KarcherMeanResult {
929    let (n, m) = data.shape();
930
931    let srsf_mat = srsf_transform(data, argvals);
932    let (mut mu_q, mu) = select_template(&srsf_mat, data, argvals);
933    let (mu_q_c, mu_c) = pre_center_template(data, &mu_q, &mu, argvals, lambda);
934    mu_q = mu_q_c;
935    let mut mu = mu_c;
936
937    let mut converged = false;
938    let mut n_iter = 0;
939    let mut final_gammas = FdMatrix::zeros(n, m);
940
941    for iter in 0..max_iter {
942        n_iter = iter + 1;
943
944        let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
945            .map(|i| {
946                let fi = data.row(i);
947                let qi = srsf_single(&fi, argvals);
948                align_srsf_pair(&mu_q, &qi, argvals, lambda)
949            })
950            .collect();
951
952        let mu_q_new = accumulate_alignments(&align_results, &mut final_gammas, m, n);
953
954        let rel = relative_change(&mu_q, &mu_q_new);
955        if rel < tol {
956            converged = true;
957            mu_q = mu_q_new;
958            break;
959        }
960
961        mu_q = mu_q_new;
962        mu = srsf_inverse(&mu_q, argvals, mu[0]);
963    }
964
965    let (mu_final, mu_q_final, final_aligned) =
966        post_center_results(data, &mu_q, &mut final_gammas, argvals);
967
968    KarcherMeanResult {
969        mean: mu_final,
970        mean_srsf: mu_q_final,
971        gammas: final_gammas,
972        aligned_data: final_aligned,
973        n_iter,
974        converged,
975        aligned_srsfs: None,
976    }
977}
978
979// ─── TSRVF (Transported SRSF) ────────────────────────────────────────────────
980// Maps aligned SRSFs to the tangent space of the Karcher mean on the Hilbert
981// sphere. Tangent vectors live in a standard Euclidean space, enabling PCA,
982// regression, and clustering on elastic-aligned curves.
983
/// Result of the TSRVF transform.
///
/// Produced by [`tsrvf_from_alignment`] / [`tsrvf_from_alignment_with_method`]
/// and consumed by [`tsrvf_inverse`].
#[derive(Debug, Clone)]
pub struct TsrvfResult {
    /// Tangent vectors at the unit-normalized mean SRSF, one row per curve (n × m).
    pub tangent_vectors: FdMatrix,
    /// Karcher mean curve (length m), copied from the Karcher mean result.
    pub mean: Vec<f64>,
    /// Kernel-smoothed SRSF of the Karcher mean (length m).
    ///
    /// The tangent vectors are computed against this smoothed version, not the
    /// raw `mean_srsf` of the Karcher result.
    pub mean_srsf: Vec<f64>,
    /// L2 norm of `mean_srsf`, computed on a uniform grid over [0, 1].
    pub mean_srsf_norm: f64,
    /// Per-curve L2 norms of the smoothed aligned SRSFs (length n).
    ///
    /// Computed on the same [0, 1] grid; used to rescale in [`tsrvf_inverse`].
    pub srsf_norms: Vec<f64>,
    /// Per-curve values of the aligned curves at the first grid point (length n).
    ///
    /// Used as the integration constant for SRSF inverse reconstruction.
    pub initial_values: Vec<f64>,
    /// Warping functions from Karcher mean computation (n × m).
    pub gammas: FdMatrix,
    /// Whether the Karcher mean converged.
    pub converged: bool,
}
1004
1005/// Full TSRVF pipeline: compute Karcher mean, then transport SRSFs to tangent space.
1006///
1007/// # Arguments
1008/// * `data` — Functional data matrix (n × m)
1009/// * `argvals` — Evaluation points (length m)
1010/// * `max_iter` — Maximum Karcher mean iterations
1011/// * `tol` — Convergence tolerance for Karcher mean
1012///
1013/// # Returns
1014/// [`TsrvfResult`] containing tangent vectors and associated metadata.
1015pub fn tsrvf_transform(
1016    data: &FdMatrix,
1017    argvals: &[f64],
1018    max_iter: usize,
1019    tol: f64,
1020    lambda: f64,
1021) -> TsrvfResult {
1022    let karcher = karcher_mean(data, argvals, max_iter, tol, lambda);
1023    tsrvf_from_alignment(&karcher, argvals)
1024}
1025
1026/// Smooth aligned SRSFs to remove DP kink artifacts before TSRVF computation.
1027///
1028/// Uses Nadaraya-Watson kernel smoothing (Gaussian, bandwidth = 2 grid spacings)
1029/// on each SRSF row. This removes the derivative spikes from DP warp kinks
1030/// without affecting alignment results or the Karcher mean.
1031fn smooth_aligned_srsfs(srsf: &FdMatrix, m: usize) -> FdMatrix {
1032    let n = srsf.nrows();
1033    let time: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
1034    let bandwidth = 2.0 / (m - 1) as f64;
1035
1036    let mut smoothed = FdMatrix::zeros(n, m);
1037    for i in 0..n {
1038        let qi = srsf.row(i);
1039        let qi_smooth = nadaraya_watson(&time, &qi, &time, bandwidth, "gaussian");
1040        for j in 0..m {
1041            smoothed[(i, j)] = qi_smooth[j];
1042        }
1043    }
1044    smoothed
1045}
1046
1047/// Compute TSRVF from a pre-computed Karcher mean alignment.
1048///
1049/// Avoids re-running the expensive Karcher mean computation when the alignment
1050/// has already been computed.
1051///
1052/// # Arguments
1053/// * `karcher` — Pre-computed Karcher mean result
1054/// * `argvals` — Evaluation points (length m)
1055///
1056/// # Returns
1057/// [`TsrvfResult`] containing tangent vectors and associated metadata.
1058pub fn tsrvf_from_alignment(karcher: &KarcherMeanResult, argvals: &[f64]) -> TsrvfResult {
1059    let (n, m) = karcher.aligned_data.shape();
1060
1061    // Step 1: Compute SRSFs of aligned data
1062    let aligned_srsf = srsf_transform(&karcher.aligned_data, argvals);
1063
1064    // Step 1b: Smooth aligned SRSFs to remove DP kink artifacts.
1065    //
1066    // DP alignment produces piecewise-linear warps with kinks at grid transitions.
1067    // When curves are reparameterized by these warps, the kinks propagate into the
1068    // aligned curves' derivatives (SRSFs), creating spikes that dominate TSRVF
1069    // tangent vectors and PCA.
1070    //
1071    // R's fdasrvf does not smooth here and suffers from the same spike artifacts.
1072    // Python's fdasrsf mitigates this via spline smoothing (s=1e-4) in SqrtMean.
1073    // We smooth the aligned SRSFs before tangent vector computation — this only
1074    // affects TSRVF output and does not change the alignment or Karcher mean.
1075    let aligned_srsf = smooth_aligned_srsfs(&aligned_srsf, m);
1076
1077    // Step 2: Smooth and normalize mean SRSF to unit sphere.
1078    // The mean SRSF must be smoothed consistently with the aligned SRSFs
1079    // so that a single curve (which IS the mean) produces a zero tangent vector.
1080    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
1081    let bandwidth = 2.0 / (m - 1) as f64;
1082    let mean_srsf_smooth = nadaraya_watson(&time, &karcher.mean_srsf, &time, bandwidth, "gaussian");
1083    let mean_norm = l2_norm_l2(&mean_srsf_smooth, &time);
1084
1085    let mu_unit: Vec<f64> = if mean_norm > 1e-10 {
1086        mean_srsf_smooth.iter().map(|&q| q / mean_norm).collect()
1087    } else {
1088        vec![0.0; m]
1089    };
1090
1091    // Step 3: For each aligned curve, compute tangent vector via inverse exponential map
1092    let srsf_norms: Vec<f64> = iter_maybe_parallel!(0..n)
1093        .map(|i| {
1094            let qi = aligned_srsf.row(i);
1095            l2_norm_l2(&qi, &time)
1096        })
1097        .collect();
1098
1099    let tangent_data: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
1100        .map(|i| {
1101            let qi = aligned_srsf.row(i);
1102            let qi_norm = srsf_norms[i];
1103
1104            if qi_norm < 1e-10 || mean_norm < 1e-10 {
1105                return vec![0.0; m];
1106            }
1107
1108            // Normalize to unit sphere
1109            let qi_unit: Vec<f64> = qi.iter().map(|&q| q / qi_norm).collect();
1110
1111            // Shooting vector from mu_unit to qi_unit
1112            inv_exp_map_sphere(&mu_unit, &qi_unit, &time)
1113        })
1114        .collect();
1115
1116    // Assemble tangent vectors into FdMatrix
1117    let mut tangent_vectors = FdMatrix::zeros(n, m);
1118    for i in 0..n {
1119        for j in 0..m {
1120            tangent_vectors[(i, j)] = tangent_data[i][j];
1121        }
1122    }
1123
1124    // Store per-curve initial values for SRSF inverse reconstruction.
1125    // Warping preserves f_i(0) since gamma(0) = 0.
1126    let initial_values: Vec<f64> = (0..n).map(|i| karcher.aligned_data[(i, 0)]).collect();
1127
1128    TsrvfResult {
1129        tangent_vectors,
1130        mean: karcher.mean.clone(),
1131        mean_srsf: mean_srsf_smooth,
1132        mean_srsf_norm: mean_norm,
1133        srsf_norms,
1134        initial_values,
1135        gammas: karcher.gammas.clone(),
1136        converged: karcher.converged,
1137    }
1138}
1139
1140/// Reconstruct aligned curves from TSRVF tangent vectors.
1141///
1142/// Inverts the TSRVF transform: maps tangent vectors back to the Hilbert sphere
1143/// via the exponential map, rescales, and reconstructs curves via SRSF inverse.
1144///
1145/// # Arguments
1146/// * `tsrvf` — TSRVF result from [`tsrvf_transform`] or [`tsrvf_from_alignment`]
1147/// * `argvals` — Evaluation points (length m)
1148///
1149/// # Returns
1150/// FdMatrix of reconstructed aligned curves (n × m).
1151pub fn tsrvf_inverse(tsrvf: &TsrvfResult, argvals: &[f64]) -> FdMatrix {
1152    let (n, m) = tsrvf.tangent_vectors.shape();
1153
1154    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
1155
1156    // Normalize mean SRSF to unit sphere
1157    let mu_unit: Vec<f64> = if tsrvf.mean_srsf_norm > 1e-10 {
1158        tsrvf
1159            .mean_srsf
1160            .iter()
1161            .map(|&q| q / tsrvf.mean_srsf_norm)
1162            .collect()
1163    } else {
1164        vec![0.0; m]
1165    };
1166
1167    let curves: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
1168        .map(|i| {
1169            let vi = tsrvf.tangent_vectors.row(i);
1170
1171            // Map back to sphere: exp_map(mu_unit, v_i)
1172            let qi_unit = exp_map_sphere(&mu_unit, &vi, &time);
1173
1174            // Rescale by original norm
1175            let qi: Vec<f64> = qi_unit.iter().map(|&q| q * tsrvf.srsf_norms[i]).collect();
1176
1177            // Reconstruct curve from SRSF using per-curve initial value
1178            srsf_inverse(&qi, argvals, tsrvf.initial_values[i])
1179        })
1180        .collect();
1181
1182    let mut result = FdMatrix::zeros(n, m);
1183    for i in 0..n {
1184        for j in 0..m {
1185            result[(i, j)] = curves[i][j];
1186        }
1187    }
1188    result
1189}
1190
1191// ─── Parallel Transport Variants ─────────────────────────────────────────────
1192
/// Method for mapping aligned SRSFs into the tangent space at the mean on the
/// Hilbert sphere (see [`tsrvf_from_alignment_with_method`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TransportMethod {
    /// Inverse exponential map (log map) at the mean.
    ///
    /// This is the default and is what [`tsrvf_from_alignment`] computes.
    #[default]
    LogMap,
    /// Schild's ladder approximation to parallel transport.
    SchildsLadder,
    /// Pole ladder approximation to parallel transport.
    PoleLadder,
}
1204
1205/// Schild's ladder parallel transport of vector `v` from `from` to `to` on the sphere.
1206fn parallel_transport_schilds(v: &[f64], from: &[f64], to: &[f64], time: &[f64]) -> Vec<f64> {
1207    use crate::warping::{exp_map_sphere, inv_exp_map_sphere};
1208
1209    let v_norm = crate::warping::l2_norm_l2(v, time);
1210    if v_norm < 1e-10 {
1211        return vec![0.0; v.len()];
1212    }
1213
1214    // endpoint = exp_from(v)
1215    let endpoint = exp_map_sphere(from, v, time);
1216
1217    // midpoint_v = log_to(endpoint) — vector at `to` pointing toward endpoint
1218    let log_to_ep = inv_exp_map_sphere(to, &endpoint, time);
1219
1220    // midpoint = exp_to(0.5 * log_to_ep)
1221    let half_log: Vec<f64> = log_to_ep.iter().map(|&x| 0.5 * x).collect();
1222    let midpoint = exp_map_sphere(to, &half_log, time);
1223
1224    // transported = 2 * log_to(midpoint)
1225    let log_to_mid = inv_exp_map_sphere(to, &midpoint, time);
1226    log_to_mid.iter().map(|&x| 2.0 * x).collect()
1227}
1228
1229/// Pole ladder parallel transport of vector `v` from `from` to `to` on the sphere.
1230fn parallel_transport_pole(v: &[f64], from: &[f64], to: &[f64], time: &[f64]) -> Vec<f64> {
1231    use crate::warping::{exp_map_sphere, inv_exp_map_sphere};
1232
1233    let v_norm = crate::warping::l2_norm_l2(v, time);
1234    if v_norm < 1e-10 {
1235        return vec![0.0; v.len()];
1236    }
1237
1238    // pole = exp_from(-v)
1239    let neg_v: Vec<f64> = v.iter().map(|&x| -x).collect();
1240    let pole = exp_map_sphere(from, &neg_v, time);
1241
1242    // midpoint_v = log_to(pole)
1243    let log_to_pole = inv_exp_map_sphere(to, &pole, time);
1244
1245    // midpoint = exp_to(0.5 * log_to_pole)
1246    let half_log: Vec<f64> = log_to_pole.iter().map(|&x| 0.5 * x).collect();
1247    let midpoint = exp_map_sphere(to, &half_log, time);
1248
1249    // transported = -2 * log_to(midpoint)
1250    let log_to_mid = inv_exp_map_sphere(to, &midpoint, time);
1251    log_to_mid.iter().map(|&x| -2.0 * x).collect()
1252}
1253
1254/// Full TSRVF pipeline with configurable transport method.
1255///
1256/// Like [`tsrvf_transform`] but allows choosing the parallel transport method.
1257pub fn tsrvf_transform_with_method(
1258    data: &FdMatrix,
1259    argvals: &[f64],
1260    max_iter: usize,
1261    tol: f64,
1262    lambda: f64,
1263    method: TransportMethod,
1264) -> TsrvfResult {
1265    let karcher = karcher_mean(data, argvals, max_iter, tol, lambda);
1266    tsrvf_from_alignment_with_method(&karcher, argvals, method)
1267}
1268
1269/// Compute TSRVF from a pre-computed Karcher mean with configurable transport.
1270///
1271/// - [`TransportMethod::LogMap`]: Uses `inv_exp_map(mu, qi)` directly (standard TSRVF).
1272/// - [`TransportMethod::SchildsLadder`]: Computes `v = -log_qi(mu)`, then transports
1273///   via Schild's ladder from qi to mu.
1274/// - [`TransportMethod::PoleLadder`]: Same but via pole ladder.
1275pub fn tsrvf_from_alignment_with_method(
1276    karcher: &KarcherMeanResult,
1277    argvals: &[f64],
1278    method: TransportMethod,
1279) -> TsrvfResult {
1280    if method == TransportMethod::LogMap {
1281        return tsrvf_from_alignment(karcher, argvals);
1282    }
1283
1284    let (n, m) = karcher.aligned_data.shape();
1285    let aligned_srsf = srsf_transform(&karcher.aligned_data, argvals);
1286    let aligned_srsf = smooth_aligned_srsfs(&aligned_srsf, m);
1287    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
1288    let bandwidth = 2.0 / (m - 1) as f64;
1289    let mean_srsf_smooth = nadaraya_watson(&time, &karcher.mean_srsf, &time, bandwidth, "gaussian");
1290    let mean_norm = crate::warping::l2_norm_l2(&mean_srsf_smooth, &time);
1291
1292    let mu_unit: Vec<f64> = if mean_norm > 1e-10 {
1293        mean_srsf_smooth.iter().map(|&q| q / mean_norm).collect()
1294    } else {
1295        vec![0.0; m]
1296    };
1297
1298    let srsf_norms: Vec<f64> = iter_maybe_parallel!(0..n)
1299        .map(|i| {
1300            let qi = aligned_srsf.row(i);
1301            crate::warping::l2_norm_l2(&qi, &time)
1302        })
1303        .collect();
1304
1305    let tangent_data: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
1306        .map(|i| {
1307            let qi = aligned_srsf.row(i);
1308            let qi_norm = srsf_norms[i];
1309
1310            if qi_norm < 1e-10 || mean_norm < 1e-10 {
1311                return vec![0.0; m];
1312            }
1313
1314            let qi_unit: Vec<f64> = qi.iter().map(|&q| q / qi_norm).collect();
1315
1316            // Compute v = -log_qi(mu) — vector at qi pointing away from mu
1317            let v_at_qi = inv_exp_map_sphere(&qi_unit, &mu_unit, &time);
1318            let neg_v: Vec<f64> = v_at_qi.iter().map(|&x| -x).collect();
1319
1320            // Transport from qi to mu
1321            match method {
1322                TransportMethod::SchildsLadder => {
1323                    parallel_transport_schilds(&neg_v, &qi_unit, &mu_unit, &time)
1324                }
1325                TransportMethod::PoleLadder => {
1326                    parallel_transport_pole(&neg_v, &qi_unit, &mu_unit, &time)
1327                }
1328                TransportMethod::LogMap => unreachable!(),
1329            }
1330        })
1331        .collect();
1332
1333    let mut tangent_vectors = FdMatrix::zeros(n, m);
1334    for i in 0..n {
1335        for j in 0..m {
1336            tangent_vectors[(i, j)] = tangent_data[i][j];
1337        }
1338    }
1339
1340    let initial_values: Vec<f64> = (0..n).map(|i| karcher.aligned_data[(i, 0)]).collect();
1341
1342    TsrvfResult {
1343        tangent_vectors,
1344        mean: karcher.mean.clone(),
1345        mean_srsf: mean_srsf_smooth,
1346        mean_srsf_norm: mean_norm,
1347        srsf_norms,
1348        initial_values,
1349        gammas: karcher.gammas.clone(),
1350        converged: karcher.converged,
1351    }
1352}
1353
1354// ─── Alignment Quality Metrics ───────────────────────────────────────────────
1355
/// Comprehensive alignment quality assessment, as produced by [`alignment_quality`].
#[derive(Debug, Clone)]
pub struct AlignmentQuality {
    /// Per-curve geodesic distance from warp to identity (see [`warp_complexity`]).
    pub warp_complexity: Vec<f64>,
    /// Mean of `warp_complexity` over all curves.
    pub mean_warp_complexity: f64,
    /// Per-curve bending energy ∫(γ'')² dt (see [`warp_smoothness`]).
    pub warp_smoothness: Vec<f64>,
    /// Mean of `warp_smoothness` over all curves (bending energy).
    pub mean_warp_smoothness: f64,
    /// Total variance: (1/n) Σ ∫(f_i - mean_orig)² dt.
    ///
    /// Integrated with Simpson quadrature weights.
    pub total_variance: f64,
    /// Amplitude variance: (1/n) Σ ∫(f_i^aligned - mean_aligned)² dt.
    pub amplitude_variance: f64,
    /// Phase variance: total - amplitude (clamped ≥ 0).
    pub phase_variance: f64,
    /// `phase_variance / total_variance`, or 0 when the total variance is ≤ 1e-10.
    ///
    /// Despite the field name, the denominator is the total variance, not the
    /// amplitude variance.
    pub phase_amplitude_ratio: f64,
    /// Pointwise ratio aligned_var / orig_var per time point.
    ///
    /// Uses population variances; the ratio is 1 wherever orig_var ≤ 1e-15.
    pub pointwise_variance_ratio: Vec<f64>,
    /// Mean of `pointwise_variance_ratio`.
    ///
    /// Values below 1 mean alignment reduced the cross-sectional variance.
    pub mean_variance_reduction: f64,
}
1380
1381/// Compute warp complexity: geodesic distance from a warp to the identity.
1382///
1383/// This is `arccos(⟨ψ, ψ_id⟩)` on the Hilbert sphere.
1384pub fn warp_complexity(gamma: &[f64], argvals: &[f64]) -> f64 {
1385    crate::warping::phase_distance(gamma, argvals)
1386}
1387
1388/// Compute warp smoothness (bending energy): ∫(γ'')² dt.
1389pub fn warp_smoothness(gamma: &[f64], argvals: &[f64]) -> f64 {
1390    let m = gamma.len();
1391    if m < 3 {
1392        return 0.0;
1393    }
1394
1395    let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
1396    let gam_prime = gradient_uniform(gamma, h);
1397    let gam_pprime = gradient_uniform(&gam_prime, h);
1398
1399    let integrand: Vec<f64> = gam_pprime.iter().map(|&g| g * g).collect();
1400    crate::helpers::trapz(&integrand, argvals)
1401}
1402
1403/// Compute comprehensive alignment quality metrics.
1404///
1405/// # Arguments
1406/// * `data` — Original functional data (n × m)
1407/// * `karcher` — Pre-computed Karcher mean result
1408/// * `argvals` — Evaluation points (length m)
1409pub fn alignment_quality(
1410    data: &FdMatrix,
1411    karcher: &KarcherMeanResult,
1412    argvals: &[f64],
1413) -> AlignmentQuality {
1414    let (n, m) = data.shape();
1415    let weights = simpsons_weights(argvals);
1416
1417    // Per-curve warp complexity and smoothness
1418    let wc: Vec<f64> = (0..n)
1419        .map(|i| {
1420            let gamma: Vec<f64> = (0..m).map(|j| karcher.gammas[(i, j)]).collect();
1421            warp_complexity(&gamma, argvals)
1422        })
1423        .collect();
1424    let ws: Vec<f64> = (0..n)
1425        .map(|i| {
1426            let gamma: Vec<f64> = (0..m).map(|j| karcher.gammas[(i, j)]).collect();
1427            warp_smoothness(&gamma, argvals)
1428        })
1429        .collect();
1430
1431    let mean_wc = wc.iter().sum::<f64>() / n as f64;
1432    let mean_ws = ws.iter().sum::<f64>() / n as f64;
1433
1434    // Compute original mean
1435    let orig_mean = crate::fdata::mean_1d(data);
1436
1437    // Total variance
1438    let total_var: f64 = (0..n)
1439        .map(|i| {
1440            let fi = data.row(i);
1441            let d = l2_distance(&fi, &orig_mean, &weights);
1442            d * d
1443        })
1444        .sum::<f64>()
1445        / n as f64;
1446
1447    // Aligned mean
1448    let aligned_mean = crate::fdata::mean_1d(&karcher.aligned_data);
1449
1450    // Amplitude variance
1451    let amp_var: f64 = (0..n)
1452        .map(|i| {
1453            let fi = karcher.aligned_data.row(i);
1454            let d = l2_distance(&fi, &aligned_mean, &weights);
1455            d * d
1456        })
1457        .sum::<f64>()
1458        / n as f64;
1459
1460    let phase_var = (total_var - amp_var).max(0.0);
1461    let ratio = if total_var > 1e-10 {
1462        phase_var / total_var
1463    } else {
1464        0.0
1465    };
1466
1467    // Pointwise variance ratio
1468    let mut pw_ratio = vec![0.0; m];
1469    for j in 0..m {
1470        let col_orig = data.column(j);
1471        let mean_orig = col_orig.iter().sum::<f64>() / n as f64;
1472        let var_orig: f64 = col_orig
1473            .iter()
1474            .map(|&v| (v - mean_orig).powi(2))
1475            .sum::<f64>()
1476            / n as f64;
1477
1478        let col_aligned = karcher.aligned_data.column(j);
1479        let mean_aligned = col_aligned.iter().sum::<f64>() / n as f64;
1480        let var_aligned: f64 = col_aligned
1481            .iter()
1482            .map(|&v| (v - mean_aligned).powi(2))
1483            .sum::<f64>()
1484            / n as f64;
1485
1486        pw_ratio[j] = if var_orig > 1e-15 {
1487            var_aligned / var_orig
1488        } else {
1489            1.0
1490        };
1491    }
1492
1493    let mean_vr = pw_ratio.iter().sum::<f64>() / m as f64;
1494
1495    AlignmentQuality {
1496        warp_complexity: wc,
1497        mean_warp_complexity: mean_wc,
1498        warp_smoothness: ws,
1499        mean_warp_smoothness: mean_ws,
1500        total_variance: total_var,
1501        amplitude_variance: amp_var,
1502        phase_variance: phase_var,
1503        phase_amplitude_ratio: ratio,
1504        pointwise_variance_ratio: pw_ratio,
1505        mean_variance_reduction: mean_vr,
1506    }
1507}
1508
/// Generate index triplets `(i, j, k)` with `i < j < k < n`.
///
/// Triplets come in lexicographic order. At most `max_triplets` are returned;
/// `0` means no cap. An empty vector is returned when `n < 3`.
///
/// The cap is applied lazily with `take`, so the total count
/// `n·(n−1)·(n−2)/6` is never computed. That product underflows `usize` for
/// `n < 2` and can overflow for very large `n`.
fn triplet_indices(n: usize, max_triplets: usize) -> Vec<(usize, usize, usize)> {
    let cap = if max_triplets == 0 {
        usize::MAX
    } else {
        max_triplets
    };
    (0..n)
        .flat_map(|i| ((i + 1)..n).flat_map(move |j| ((j + 1)..n).map(move |k| (i, j, k))))
        .take(cap)
        .collect()
}
1522
1523/// Compute the warp deviation for one triplet: ‖γ_ij∘γ_jk − γ_ik‖_L2.
1524fn triplet_warp_deviation(
1525    data: &FdMatrix,
1526    argvals: &[f64],
1527    weights: &[f64],
1528    i: usize,
1529    j: usize,
1530    k: usize,
1531    lambda: f64,
1532) -> f64 {
1533    let fi = data.row(i);
1534    let fj = data.row(j);
1535    let fk = data.row(k);
1536    let rij = elastic_align_pair(&fi, &fj, argvals, lambda);
1537    let rjk = elastic_align_pair(&fj, &fk, argvals, lambda);
1538    let rik = elastic_align_pair(&fi, &fk, argvals, lambda);
1539    let composed = compose_warps(&rij.gamma, &rjk.gamma, argvals);
1540    l2_distance(&composed, &rik.gamma, weights)
1541}
1542
1543/// Measure pairwise alignment consistency via triplet checks.
1544///
1545/// For triplets (i,j,k), checks `γ_ij ∘ γ_jk ≈ γ_ik` by measuring the L2
1546/// deviation of the composed warp from the direct warp.
1547///
1548/// # Arguments
1549/// * `data` — Functional data (n × m)
1550/// * `argvals` — Evaluation points (length m)
1551/// * `lambda` — Penalty weight
1552/// * `max_triplets` — Maximum number of triplets to check (0 = all)
1553pub fn pairwise_consistency(
1554    data: &FdMatrix,
1555    argvals: &[f64],
1556    lambda: f64,
1557    max_triplets: usize,
1558) -> f64 {
1559    let n = data.nrows();
1560    if n < 3 {
1561        return 0.0;
1562    }
1563
1564    let weights = simpsons_weights(argvals);
1565    let triplets = triplet_indices(n, max_triplets);
1566    if triplets.is_empty() {
1567        return 0.0;
1568    }
1569
1570    let total_dev: f64 = triplets
1571        .iter()
1572        .map(|&(i, j, k)| triplet_warp_deviation(data, argvals, &weights, i, j, k, lambda))
1573        .sum();
1574    total_dev / triplets.len() as f64
1575}
1576
1577// ─── Landmark-Constrained Elastic Alignment ────────────────────────────────
1578
/// Result of landmark-constrained elastic alignment
/// (see [`elastic_align_pair_constrained`]).
#[derive(Debug, Clone)]
pub struct ConstrainedAlignmentResult {
    /// Optimal warping function (length m).
    pub gamma: Vec<f64>,
    /// Aligned curve f2∘γ (length m).
    pub f_aligned: Vec<f64>,
    /// Elastic distance after alignment.
    ///
    /// On the constrained path this is the Simpson-weighted L2 distance between
    /// the SRSF of the target and the SRSF of the aligned curve.
    pub distance: f64,
    /// Landmark pairs that were actually enforced, snapped to grid values:
    /// `(target_t, source_t)`.
    ///
    /// Rejected pairs and the fixed endpoints are not included.
    pub enforced_landmarks: Vec<(f64, f64)>,
}
1591
/// Index of the grid point in `argvals` closest to `t_val` (absolute difference).
///
/// On ties the lowest index wins. A NaN `t_val` yields index 0. Panics if
/// `argvals` is empty.
fn snap_to_grid(t_val: f64, argvals: &[f64]) -> usize {
    let seed = (0_usize, (t_val - argvals[0]).abs());
    let (nearest, _) = argvals
        .iter()
        .enumerate()
        .skip(1)
        .fold(seed, |(best_idx, best_dist), (idx, &grid_t)| {
            let dist = (t_val - grid_t).abs();
            if dist < best_dist {
                (idx, dist)
            } else {
                (best_idx, best_dist)
            }
        });
    nearest
}
1605
1606/// Run DP on a rectangular sub-grid `[sc..=ec] × [sr..=er]`.
1607///
1608/// Uses global indices for `dp_edge_weight`. Returns the path segment
1609/// as a list of `(tc_idx, tr_idx)` pairs from start to end.
1610fn dp_segment(
1611    q1: &[f64],
1612    q2: &[f64],
1613    argvals: &[f64],
1614    sc: usize,
1615    ec: usize,
1616    sr: usize,
1617    er: usize,
1618    lambda: f64,
1619) -> Vec<(usize, usize)> {
1620    let nc = ec - sc + 1;
1621    let nr = er - sr + 1;
1622
1623    if nc <= 1 || nr <= 1 {
1624        return vec![(sc, sr), (ec, er)];
1625    }
1626
1627    let path = dp_grid_solve(nr, nc, |local_sr, local_sc, local_tr, local_tc| {
1628        let gsr = sr + local_sr;
1629        let gsc = sc + local_sc;
1630        let gtr = sr + local_tr;
1631        let gtc = sc + local_tc;
1632        dp_edge_weight(q1, q2, argvals, gsc, gtc, gsr, gtr)
1633            + dp_lambda_penalty(argvals, gsc, gtc, gsr, gtr, lambda)
1634    });
1635
1636    // Convert local indices to global
1637    path.iter().map(|&(lr, lc)| (sc + lc, sr + lr)).collect()
1638}
1639
1640/// Align f2 to f1 with landmark constraints.
1641///
1642/// Landmark pairs define waypoints on the DP grid. Between consecutive waypoints,
1643/// an independent smaller DP is run. The resulting warp passes through all landmarks.
1644///
1645/// # Arguments
1646/// * `f1` — Target curve (length m)
1647/// * `f2` — Curve to align (length m)
1648/// * `argvals` — Evaluation points (length m)
1649/// * `landmark_pairs` — `(target_t, source_t)` pairs in increasing order
1650/// * `lambda` — Penalty weight
1651///
1652/// # Returns
1653/// [`ConstrainedAlignmentResult`] with warp, aligned curve, and enforced landmarks.
1654/// Build DP waypoints from landmark pairs: snap to grid, deduplicate, add endpoints.
1655fn build_constrained_waypoints(
1656    landmark_pairs: &[(f64, f64)],
1657    argvals: &[f64],
1658    m: usize,
1659) -> Vec<(usize, usize)> {
1660    let mut waypoints: Vec<(usize, usize)> = Vec::with_capacity(landmark_pairs.len() + 2);
1661    waypoints.push((0, 0));
1662    for &(tt, st) in landmark_pairs {
1663        let tc = snap_to_grid(tt, argvals);
1664        let tr = snap_to_grid(st, argvals);
1665        if let Some(&(prev_c, prev_r)) = waypoints.last() {
1666            if tc > prev_c && tr > prev_r {
1667                waypoints.push((tc, tr));
1668            }
1669        }
1670    }
1671    let last = m - 1;
1672    if let Some(&(prev_c, prev_r)) = waypoints.last() {
1673        if prev_c != last || prev_r != last {
1674            waypoints.push((last, last));
1675        }
1676    }
1677    waypoints
1678}
1679
1680/// Run DP segments between consecutive waypoints and assemble into a gamma warp.
1681fn segmented_dp_gamma(
1682    q1n: &[f64],
1683    q2n: &[f64],
1684    argvals: &[f64],
1685    waypoints: &[(usize, usize)],
1686    lambda: f64,
1687) -> Vec<f64> {
1688    let mut full_path_tc: Vec<f64> = Vec::new();
1689    let mut full_path_tr: Vec<f64> = Vec::new();
1690
1691    for seg in 0..(waypoints.len() - 1) {
1692        let (sc, sr) = waypoints[seg];
1693        let (ec, er) = waypoints[seg + 1];
1694        let segment_path = dp_segment(q1n, q2n, argvals, sc, ec, sr, er, lambda);
1695        let start = if seg > 0 { 1 } else { 0 };
1696        for &(tc, tr) in &segment_path[start..] {
1697            full_path_tc.push(argvals[tc]);
1698            full_path_tr.push(argvals[tr]);
1699        }
1700    }
1701
1702    let mut gamma: Vec<f64> = argvals
1703        .iter()
1704        .map(|&t| linear_interp(&full_path_tc, &full_path_tr, t))
1705        .collect();
1706    normalize_warp(&mut gamma, argvals);
1707    gamma
1708}
1709
1710pub fn elastic_align_pair_constrained(
1711    f1: &[f64],
1712    f2: &[f64],
1713    argvals: &[f64],
1714    landmark_pairs: &[(f64, f64)],
1715    lambda: f64,
1716) -> ConstrainedAlignmentResult {
1717    let m = f1.len();
1718
1719    if landmark_pairs.is_empty() {
1720        let r = elastic_align_pair(f1, f2, argvals, lambda);
1721        return ConstrainedAlignmentResult {
1722            gamma: r.gamma,
1723            f_aligned: r.f_aligned,
1724            distance: r.distance,
1725            enforced_landmarks: Vec::new(),
1726        };
1727    }
1728
1729    // Compute & normalize SRSFs
1730    let f1_mat = FdMatrix::from_slice(f1, 1, m).unwrap();
1731    let f2_mat = FdMatrix::from_slice(f2, 1, m).unwrap();
1732    let q1_mat = srsf_transform(&f1_mat, argvals);
1733    let q2_mat = srsf_transform(&f2_mat, argvals);
1734    let q1: Vec<f64> = q1_mat.row(0);
1735    let q2: Vec<f64> = q2_mat.row(0);
1736    let norm1 = q1.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1737    let norm2 = q2.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1738    let q1n: Vec<f64> = q1.iter().map(|&v| v / norm1).collect();
1739    let q2n: Vec<f64> = q2.iter().map(|&v| v / norm2).collect();
1740
1741    let waypoints = build_constrained_waypoints(landmark_pairs, argvals, m);
1742    let gamma = segmented_dp_gamma(&q1n, &q2n, argvals, &waypoints, lambda);
1743
1744    let f_aligned = reparameterize_curve(f2, argvals, &gamma);
1745    let f_aligned_mat = FdMatrix::from_slice(&f_aligned, 1, m).unwrap();
1746    let q_aligned_mat = srsf_transform(&f_aligned_mat, argvals);
1747    let q_aligned: Vec<f64> = q_aligned_mat.row(0);
1748    let weights = simpsons_weights(argvals);
1749    let distance = l2_distance(&q1, &q_aligned, &weights);
1750
1751    let enforced: Vec<(f64, f64)> = waypoints[1..waypoints.len() - 1]
1752        .iter()
1753        .map(|&(tc, tr)| (argvals[tc], argvals[tr]))
1754        .collect();
1755
1756    ConstrainedAlignmentResult {
1757        gamma,
1758        f_aligned,
1759        distance,
1760        enforced_landmarks: enforced,
1761    }
1762}
1763
1764/// Align f2 to f1 with automatic landmark detection and elastic constraints.
1765///
1766/// Detects landmarks in both curves, matches them, and uses the matches
1767/// as constraints for segmented DP alignment.
1768///
1769/// # Arguments
1770/// * `f1` — Target curve (length m)
1771/// * `f2` — Curve to align (length m)
1772/// * `argvals` — Evaluation points (length m)
1773/// * `kind` — Type of landmarks to detect
1774/// * `min_prominence` — Minimum prominence for landmark detection
1775/// * `expected_count` — Expected number of landmarks (0 = all detected)
1776/// * `lambda` — Penalty weight
1777pub fn elastic_align_pair_with_landmarks(
1778    f1: &[f64],
1779    f2: &[f64],
1780    argvals: &[f64],
1781    kind: crate::landmark::LandmarkKind,
1782    min_prominence: f64,
1783    expected_count: usize,
1784    lambda: f64,
1785) -> ConstrainedAlignmentResult {
1786    let lm1 = crate::landmark::detect_landmarks(f1, argvals, kind, min_prominence);
1787    let lm2 = crate::landmark::detect_landmarks(f2, argvals, kind, min_prominence);
1788
1789    // Match landmarks by order (take min count)
1790    let n_match = if expected_count > 0 {
1791        expected_count.min(lm1.len()).min(lm2.len())
1792    } else {
1793        lm1.len().min(lm2.len())
1794    };
1795
1796    let pairs: Vec<(f64, f64)> = (0..n_match)
1797        .map(|i| (lm1[i].position, lm2[i].position))
1798        .collect();
1799
1800    elastic_align_pair_constrained(f1, f2, argvals, &pairs, lambda)
1801}
1802
1803// ─── Multidimensional SRSF (R^d curves) ────────────────────────────────────
1804
1805use crate::matrix::FdCurveSet;
1806
/// Outcome of elastically aligning one multidimensional (R^d) curve to another.
///
/// A single warping function is shared by every coordinate dimension; the
/// aligned curve stores one vector per dimension.
#[derive(Clone, Debug)]
pub struct AlignmentResultNd {
    /// Optimal warping function (length m), shared by all d dimensions.
    pub gamma: Vec<f64>,
    /// Warped curve, stored per dimension: d vectors, each of length m.
    pub f_aligned: Vec<Vec<f64>>,
    /// Elastic distance measured after alignment.
    pub distance: f64,
}
1817
1818/// Compute the SRSF transform for multidimensional (R^d) curves.
1819///
1820/// For f: \[0,1\] → R^d, the SRSF is q(t) = f'(t) / √‖f'(t)‖ where ‖·‖ is the
1821/// Euclidean norm in R^d. For d=1 this reduces to `sign(f') · √|f'|`.
1822///
1823/// # Arguments
1824/// * `data` — Set of n curves in R^d, each with m evaluation points
1825/// * `argvals` — Evaluation points (length m)
1826///
1827/// # Returns
1828/// `FdCurveSet` of SRSF values with the same shape as input.
1829/// Scale derivative vector at one point by 1/√‖f'‖, writing into result_dims.
1830#[inline]
1831fn srsf_scale_point(derivs: &[FdMatrix], result_dims: &mut [FdMatrix], i: usize, j: usize) {
1832    let d = derivs.len();
1833    let norm_sq: f64 = derivs.iter().map(|dd| dd[(i, j)].powi(2)).sum();
1834    let norm = norm_sq.sqrt();
1835    if norm < 1e-15 {
1836        for k in 0..d {
1837            result_dims[k][(i, j)] = 0.0;
1838        }
1839    } else {
1840        let scale = 1.0 / norm.sqrt();
1841        for k in 0..d {
1842            result_dims[k][(i, j)] = derivs[k][(i, j)] * scale;
1843        }
1844    }
1845}
1846
1847pub fn srsf_transform_nd(data: &FdCurveSet, argvals: &[f64]) -> FdCurveSet {
1848    let d = data.ndim();
1849    let n = data.ncurves();
1850    let m = data.npoints();
1851
1852    if d == 0 || n == 0 || m == 0 || argvals.len() != m {
1853        return FdCurveSet {
1854            dims: (0..d).map(|_| FdMatrix::zeros(n, m)).collect(),
1855        };
1856    }
1857
1858    let derivs: Vec<FdMatrix> = data
1859        .dims
1860        .iter()
1861        .map(|dim_mat| crate::fdata::deriv_1d(dim_mat, argvals, 1))
1862        .collect();
1863
1864    let mut result_dims: Vec<FdMatrix> = (0..d).map(|_| FdMatrix::zeros(n, m)).collect();
1865    for i in 0..n {
1866        for j in 0..m {
1867            srsf_scale_point(&derivs, &mut result_dims, i, j);
1868        }
1869    }
1870
1871    FdCurveSet { dims: result_dims }
1872}
1873
1874/// Reconstruct an R^d curve from its SRSF.
1875///
1876/// Given d-dimensional SRSF vectors and initial point f0, reconstructs:
1877/// `f_k(t) = f0_k + ∫₀ᵗ q_k(s) · ‖q(s)‖ ds` for each dimension k.
1878///
1879/// # Arguments
1880/// * `q` — SRSF: d vectors, each length m
1881/// * `argvals` — Evaluation points (length m)
1882/// * `f0` — Initial values in R^d (length d)
1883///
1884/// # Returns
1885/// Reconstructed curve: d vectors, each length m.
1886pub fn srsf_inverse_nd(q: &[Vec<f64>], argvals: &[f64], f0: &[f64]) -> Vec<Vec<f64>> {
1887    let d = q.len();
1888    if d == 0 {
1889        return Vec::new();
1890    }
1891    let m = q[0].len();
1892    if m == 0 {
1893        return vec![Vec::new(); d];
1894    }
1895
1896    // Compute ||q(t)|| at each time point
1897    let norms: Vec<f64> = (0..m)
1898        .map(|j| {
1899            let norm_sq: f64 = q.iter().map(|qk| qk[j].powi(2)).sum();
1900            norm_sq.sqrt()
1901        })
1902        .collect();
1903
1904    // For each dimension, integrand = q_k(t) * ||q(t)||
1905    let mut result = Vec::with_capacity(d);
1906    for k in 0..d {
1907        let integrand: Vec<f64> = (0..m).map(|j| q[k][j] * norms[j]).collect();
1908        let integral = cumulative_trapz(&integrand, argvals);
1909        let curve: Vec<f64> = integral.iter().map(|&v| f0[k] + v).collect();
1910        result.push(curve);
1911    }
1912
1913    result
1914}
1915
1916/// Core DP alignment for R^d SRSFs.
1917///
1918/// Same DP grid and coprime neighborhood as `dp_alignment_core`, but edge weight
1919/// is the sum of `dp_edge_weight` over d dimensions.
1920fn dp_alignment_core_nd(
1921    q1: &[Vec<f64>],
1922    q2: &[Vec<f64>],
1923    argvals: &[f64],
1924    lambda: f64,
1925) -> Vec<f64> {
1926    let d = q1.len();
1927    let m = argvals.len();
1928    if m < 2 || d == 0 {
1929        return argvals.to_vec();
1930    }
1931
1932    // For d=1, delegate to existing implementation for exact backward compat
1933    if d == 1 {
1934        return dp_alignment_core(&q1[0], &q2[0], argvals, lambda);
1935    }
1936
1937    // Normalize each dimension's SRSF to unit L2 norm
1938    let q1n: Vec<Vec<f64>> = q1
1939        .iter()
1940        .map(|qk| {
1941            let norm = qk.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1942            qk.iter().map(|&v| v / norm).collect()
1943        })
1944        .collect();
1945    let q2n: Vec<Vec<f64>> = q2
1946        .iter()
1947        .map(|qk| {
1948            let norm = qk.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1949            qk.iter().map(|&v| v / norm).collect()
1950        })
1951        .collect();
1952
1953    let path = dp_grid_solve(m, m, |sr, sc, tr, tc| {
1954        let w: f64 = (0..d)
1955            .map(|k| dp_edge_weight(&q1n[k], &q2n[k], argvals, sc, tc, sr, tr))
1956            .sum();
1957        w + dp_lambda_penalty(argvals, sc, tc, sr, tr, lambda)
1958    });
1959
1960    dp_path_to_gamma(&path, argvals)
1961}
1962
1963/// Align an R^d curve f2 to f1 using the elastic framework.
1964///
1965/// Finds the optimal warping γ (shared across all dimensions) such that
1966/// f2∘γ is as close as possible to f1 in the elastic metric.
1967///
1968/// # Arguments
1969/// * `f1` — Target curves (d dimensions)
1970/// * `f2` — Curves to align (d dimensions)
1971/// * `argvals` — Evaluation points (length m)
1972/// * `lambda` — Penalty weight (0.0 = no penalty)
1973pub fn elastic_align_pair_nd(
1974    f1: &FdCurveSet,
1975    f2: &FdCurveSet,
1976    argvals: &[f64],
1977    lambda: f64,
1978) -> AlignmentResultNd {
1979    let d = f1.ndim();
1980    let m = f1.npoints();
1981
1982    // Compute SRSFs
1983    let q1_set = srsf_transform_nd(f1, argvals);
1984    let q2_set = srsf_transform_nd(f2, argvals);
1985
1986    // Extract first curve from each dimension
1987    let q1: Vec<Vec<f64>> = q1_set.dims.iter().map(|dm| dm.row(0)).collect();
1988    let q2: Vec<Vec<f64>> = q2_set.dims.iter().map(|dm| dm.row(0)).collect();
1989
1990    // DP alignment using summed cost over dimensions
1991    let gamma = dp_alignment_core_nd(&q1, &q2, argvals, lambda);
1992
1993    // Apply warping to f2 in each dimension
1994    let f_aligned: Vec<Vec<f64>> = f2
1995        .dims
1996        .iter()
1997        .map(|dm| {
1998            let row = dm.row(0);
1999            reparameterize_curve(&row, argvals, &gamma)
2000        })
2001        .collect();
2002
2003    // Compute elastic distance: sum of squared L2 distances between aligned SRSFs
2004    let f_aligned_set = {
2005        let dims: Vec<FdMatrix> = f_aligned
2006            .iter()
2007            .map(|fa| FdMatrix::from_slice(fa, 1, m).unwrap())
2008            .collect();
2009        FdCurveSet { dims }
2010    };
2011    let q_aligned = srsf_transform_nd(&f_aligned_set, argvals);
2012    let weights = simpsons_weights(argvals);
2013
2014    let mut dist_sq = 0.0;
2015    for k in 0..d {
2016        let q1k = q1_set.dims[k].row(0);
2017        let qak = q_aligned.dims[k].row(0);
2018        let d_k = l2_distance(&q1k, &qak, &weights);
2019        dist_sq += d_k * d_k;
2020    }
2021
2022    AlignmentResultNd {
2023        gamma,
2024        f_aligned,
2025        distance: dist_sq.sqrt(),
2026    }
2027}
2028
2029/// Elastic distance between two R^d curves.
2030///
2031/// Aligns f2 to f1 and returns the post-alignment SRSF distance.
2032pub fn elastic_distance_nd(f1: &FdCurveSet, f2: &FdCurveSet, argvals: &[f64], lambda: f64) -> f64 {
2033    elastic_align_pair_nd(f1, f2, argvals, lambda).distance
2034}
2035
2036// ─── Tests ──────────────────────────────────────────────────────────────────
2037
2038#[cfg(test)]
2039mod tests {
2040    use super::*;
2041    use crate::helpers::trapz;
2042    use crate::simulation::{sim_fundata, EFunType, EValType};
2043    use crate::warping::inner_product_l2;
2044
2045    fn uniform_grid(m: usize) -> Vec<f64> {
2046        (0..m).map(|i| i as f64 / (m - 1) as f64).collect()
2047    }
2048
2049    fn make_test_data(n: usize, m: usize, seed: u64) -> FdMatrix {
2050        let t = uniform_grid(m);
2051        sim_fundata(
2052            n,
2053            &t,
2054            3,
2055            EFunType::Fourier,
2056            EValType::Exponential,
2057            Some(seed),
2058        )
2059    }
2060
2061    // ── cumulative_trapz ──
2062
2063    #[test]
2064    fn test_cumulative_trapz_constant() {
2065        // ∫₀ᵗ 1 dt = t
2066        let x = uniform_grid(50);
2067        let y = vec![1.0; 50];
2068        let result = cumulative_trapz(&y, &x);
2069        assert!((result[0]).abs() < 1e-15, "cumulative_trapz(0) should be 0");
2070        for j in 1..50 {
2071            assert!(
2072                (result[j] - x[j]).abs() < 1e-12,
2073                "∫₀^{:.3} 1 dt should be {:.3}, got {:.3}",
2074                x[j],
2075                x[j],
2076                result[j]
2077            );
2078        }
2079    }
2080
2081    #[test]
2082    fn test_cumulative_trapz_linear() {
2083        // ∫₀ᵗ s ds = t²/2
2084        let m = 100;
2085        let x = uniform_grid(m);
2086        let y: Vec<f64> = x.clone();
2087        let result = cumulative_trapz(&y, &x);
2088        for j in 1..m {
2089            let expected = x[j] * x[j] / 2.0;
2090            assert!(
2091                (result[j] - expected).abs() < 1e-4,
2092                "∫₀^{:.3} s ds: expected {expected:.6}, got {:.6}",
2093                x[j],
2094                result[j]
2095            );
2096        }
2097    }
2098
2099    // ── normalize_warp ──
2100
2101    #[test]
2102    fn test_normalize_warp_fixes_boundaries() {
2103        let t = uniform_grid(10);
2104        let mut gamma = vec![0.1; 10]; // constant, wrong boundaries
2105        normalize_warp(&mut gamma, &t);
2106        assert_eq!(gamma[0], t[0]);
2107        assert_eq!(gamma[9], t[9]);
2108    }
2109
2110    #[test]
2111    fn test_normalize_warp_enforces_monotonicity() {
2112        let t = uniform_grid(5);
2113        let mut gamma = vec![0.0, 0.5, 0.3, 0.8, 1.0]; // non-monotone at index 2
2114        normalize_warp(&mut gamma, &t);
2115        for j in 1..5 {
2116            assert!(
2117                gamma[j] >= gamma[j - 1],
2118                "gamma should be monotone after normalization at j={j}"
2119            );
2120        }
2121    }
2122
2123    #[test]
2124    fn test_normalize_warp_identity_unchanged() {
2125        let t = uniform_grid(20);
2126        let mut gamma = t.clone();
2127        normalize_warp(&mut gamma, &t);
2128        for j in 0..20 {
2129            assert!(
2130                (gamma[j] - t[j]).abs() < 1e-15,
2131                "Identity warp should be unchanged"
2132            );
2133        }
2134    }
2135
2136    // ── linear_interp ──
2137
2138    #[test]
2139    fn test_linear_interp_at_nodes() {
2140        let x = vec![0.0, 1.0, 2.0, 3.0];
2141        let y = vec![0.0, 2.0, 4.0, 6.0];
2142        for i in 0..x.len() {
2143            assert!((linear_interp(&x, &y, x[i]) - y[i]).abs() < 1e-12);
2144        }
2145    }
2146
2147    #[test]
2148    fn test_linear_interp_midpoints() {
2149        let x = vec![0.0, 1.0, 2.0];
2150        let y = vec![0.0, 2.0, 4.0];
2151        assert!((linear_interp(&x, &y, 0.5) - 1.0).abs() < 1e-12);
2152        assert!((linear_interp(&x, &y, 1.5) - 3.0).abs() < 1e-12);
2153    }
2154
2155    #[test]
2156    fn test_linear_interp_clamp() {
2157        let x = vec![0.0, 1.0, 2.0];
2158        let y = vec![1.0, 3.0, 5.0];
2159        assert!((linear_interp(&x, &y, -1.0) - 1.0).abs() < 1e-12);
2160        assert!((linear_interp(&x, &y, 3.0) - 5.0).abs() < 1e-12);
2161    }
2162
2163    #[test]
2164    fn test_linear_interp_nonuniform_grid() {
2165        let x = vec![0.0, 0.1, 0.5, 1.0];
2166        let y = vec![0.0, 1.0, 5.0, 10.0];
2167        // Between 0.1 and 0.5: slope = (5-1)/(0.5-0.1) = 10
2168        let val = linear_interp(&x, &y, 0.3);
2169        let expected = 1.0 + 10.0 * (0.3 - 0.1);
2170        assert!(
2171            (val - expected).abs() < 1e-12,
2172            "Non-uniform interp: expected {expected}, got {val}"
2173        );
2174    }
2175
2176    #[test]
2177    fn test_linear_interp_two_points() {
2178        let x = vec![0.0, 1.0];
2179        let y = vec![3.0, 7.0];
2180        assert!((linear_interp(&x, &y, 0.25) - 4.0).abs() < 1e-12);
2181        assert!((linear_interp(&x, &y, 0.75) - 6.0).abs() < 1e-12);
2182    }
2183
2184    // ── SRSF transform ──
2185
2186    #[test]
2187    fn test_srsf_transform_linear() {
2188        // f(t) = 2t: derivative = 2, SRSF = sqrt(2)
2189        let m = 50;
2190        let t = uniform_grid(m);
2191        let f: Vec<f64> = t.iter().map(|&ti| 2.0 * ti).collect();
2192        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
2193
2194        let q_mat = srsf_transform(&mat, &t);
2195        let q: Vec<f64> = q_mat.row(0);
2196
2197        let expected = 2.0_f64.sqrt();
2198        // Interior points should be close to sqrt(2)
2199        for j in 2..(m - 2) {
2200            assert!(
2201                (q[j] - expected).abs() < 0.1,
2202                "q[{j}] = {}, expected ~{expected}",
2203                q[j]
2204            );
2205        }
2206    }
2207
2208    #[test]
2209    fn test_srsf_transform_preserves_shape() {
2210        let data = make_test_data(10, 50, 42);
2211        let t = uniform_grid(50);
2212        let q = srsf_transform(&data, &t);
2213        assert_eq!(q.shape(), data.shape());
2214    }
2215
2216    #[test]
2217    fn test_srsf_transform_constant_is_zero() {
2218        // f(t) = 5 (constant): derivative = 0, SRSF = 0
2219        let m = 30;
2220        let t = uniform_grid(m);
2221        let f = vec![5.0; m];
2222        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
2223        let q_mat = srsf_transform(&mat, &t);
2224        let q: Vec<f64> = q_mat.row(0);
2225
2226        for j in 0..m {
2227            assert!(
2228                q[j].abs() < 1e-10,
2229                "SRSF of constant should be 0, got q[{j}] = {}",
2230                q[j]
2231            );
2232        }
2233    }
2234
2235    #[test]
2236    fn test_srsf_transform_negative_slope() {
2237        // f(t) = -3t: derivative = -3, SRSF = -sqrt(3)
2238        let m = 50;
2239        let t = uniform_grid(m);
2240        let f: Vec<f64> = t.iter().map(|&ti| -3.0 * ti).collect();
2241        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
2242
2243        let q_mat = srsf_transform(&mat, &t);
2244        let q: Vec<f64> = q_mat.row(0);
2245
2246        let expected = -(3.0_f64.sqrt());
2247        for j in 2..(m - 2) {
2248            assert!(
2249                (q[j] - expected).abs() < 0.15,
2250                "q[{j}] = {}, expected ~{expected}",
2251                q[j]
2252            );
2253        }
2254    }
2255
2256    #[test]
2257    fn test_srsf_transform_empty_input() {
2258        let data = FdMatrix::zeros(0, 0);
2259        let t: Vec<f64> = vec![];
2260        let q = srsf_transform(&data, &t);
2261        assert_eq!(q.shape(), (0, 0));
2262    }
2263
2264    #[test]
2265    fn test_srsf_transform_multiple_curves() {
2266        let m = 40;
2267        let t = uniform_grid(m);
2268        let data = make_test_data(5, m, 42);
2269
2270        let q = srsf_transform(&data, &t);
2271        assert_eq!(q.shape(), (5, m));
2272
2273        // Each row should have finite values
2274        for i in 0..5 {
2275            for j in 0..m {
2276                assert!(q[(i, j)].is_finite(), "SRSF should be finite at ({i},{j})");
2277            }
2278        }
2279    }
2280
2281    // ── SRSF inverse ──
2282
2283    #[test]
2284    fn test_srsf_round_trip() {
2285        let m = 100;
2286        let t = uniform_grid(m);
2287        // Use a smooth function
2288        let f: Vec<f64> = t
2289            .iter()
2290            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin() + ti)
2291            .collect();
2292
2293        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
2294        let q_mat = srsf_transform(&mat, &t);
2295        let q: Vec<f64> = q_mat.row(0);
2296
2297        let f_recon = srsf_inverse(&q, &t, f[0]);
2298
2299        // Check reconstruction is close (interior points, avoid boundary effects)
2300        let max_err: f64 = f[5..(m - 5)]
2301            .iter()
2302            .zip(f_recon[5..(m - 5)].iter())
2303            .map(|(a, b)| (a - b).abs())
2304            .fold(0.0_f64, f64::max);
2305
2306        assert!(
2307            max_err < 0.15,
2308            "Round-trip error too large: max_err = {max_err}"
2309        );
2310    }
2311
2312    #[test]
2313    fn test_srsf_inverse_empty() {
2314        let q: Vec<f64> = vec![];
2315        let t: Vec<f64> = vec![];
2316        let result = srsf_inverse(&q, &t, 0.0);
2317        assert!(result.is_empty());
2318    }
2319
2320    #[test]
2321    fn test_srsf_inverse_preserves_initial_value() {
2322        let m = 50;
2323        let t = uniform_grid(m);
2324        let q = vec![1.0; m]; // constant SRSF
2325        let f0 = 3.15;
2326        let f = srsf_inverse(&q, &t, f0);
2327        assert!((f[0] - f0).abs() < 1e-12, "srsf_inverse should start at f0");
2328    }
2329
2330    #[test]
2331    fn test_srsf_round_trip_multiple_curves() {
2332        let m = 80;
2333        let t = uniform_grid(m);
2334        let data = make_test_data(5, m, 99);
2335
2336        let q_mat = srsf_transform(&data, &t);
2337
2338        for i in 0..5 {
2339            let fi = data.row(i);
2340            let qi = q_mat.row(i);
2341            let f_recon = srsf_inverse(&qi, &t, fi[0]);
2342            let max_err: f64 = fi[5..(m - 5)]
2343                .iter()
2344                .zip(f_recon[5..(m - 5)].iter())
2345                .map(|(a, b)| (a - b).abs())
2346                .fold(0.0_f64, f64::max);
2347            assert!(max_err < 0.3, "Round-trip curve {i}: max_err = {max_err}");
2348        }
2349    }
2350
2351    // ── Reparameterize ──
2352
2353    #[test]
2354    fn test_reparameterize_identity_warp() {
2355        let m = 50;
2356        let t = uniform_grid(m);
2357        let f: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
2358
2359        // Identity warp: γ(t) = t
2360        let result = reparameterize_curve(&f, &t, &t);
2361        for j in 0..m {
2362            assert!(
2363                (result[j] - f[j]).abs() < 1e-12,
2364                "Identity warp should return original at j={j}"
2365            );
2366        }
2367    }
2368
2369    #[test]
2370    fn test_reparameterize_linear_warp() {
2371        let m = 50;
2372        let t = uniform_grid(m);
2373        // f(t) = t (linear), γ(t) = t^2 (quadratic warp on [0,1])
2374        let f: Vec<f64> = t.clone();
2375        let gamma: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
2376
2377        let result = reparameterize_curve(&f, &t, &gamma);
2378
2379        // f(γ(t)) = γ(t) = t^2 for a linear f(t) = t
2380        for j in 0..m {
2381            assert!(
2382                (result[j] - gamma[j]).abs() < 1e-10,
2383                "f(gamma(t)) should be gamma(t) for f(t)=t at j={j}"
2384            );
2385        }
2386    }
2387
2388    #[test]
2389    fn test_reparameterize_sine_with_quadratic_warp() {
2390        let m = 100;
2391        let t = uniform_grid(m);
2392        let f: Vec<f64> = t
2393            .iter()
2394            .map(|&ti| (std::f64::consts::PI * ti).sin())
2395            .collect();
2396        let gamma: Vec<f64> = t.iter().map(|&ti| ti * ti).collect(); // speeds up start
2397
2398        let result = reparameterize_curve(&f, &t, &gamma);
2399
2400        // f(γ(t)) = sin(π * t²); check a few known values
2401        for j in 0..m {
2402            let expected = (std::f64::consts::PI * gamma[j]).sin();
2403            assert!(
2404                (result[j] - expected).abs() < 0.05,
2405                "sin(π γ(t)) at j={j}: expected {expected:.4}, got {:.4}",
2406                result[j]
2407            );
2408        }
2409    }
2410
2411    #[test]
2412    fn test_reparameterize_preserves_length() {
2413        let m = 50;
2414        let t = uniform_grid(m);
2415        let f = vec![1.0; m];
2416        let gamma: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
2417
2418        let result = reparameterize_curve(&f, &t, &gamma);
2419        assert_eq!(result.len(), m);
2420    }
2421
2422    // ── Compose warps ──
2423
2424    #[test]
2425    fn test_compose_warps_identity() {
2426        let m = 50;
2427        let t = uniform_grid(m);
2428        // γ(t) = t^0.5 (a warp on [0,1])
2429        let gamma: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
2430
2431        // identity ∘ γ = γ
2432        let result = compose_warps(&t, &gamma, &t);
2433        for j in 0..m {
2434            assert!(
2435                (result[j] - gamma[j]).abs() < 1e-10,
2436                "id ∘ γ should be γ at j={j}"
2437            );
2438        }
2439
2440        // γ ∘ identity = γ
2441        let result2 = compose_warps(&gamma, &t, &t);
2442        for j in 0..m {
2443            assert!(
2444                (result2[j] - gamma[j]).abs() < 1e-10,
2445                "γ ∘ id should be γ at j={j}"
2446            );
2447        }
2448    }
2449
2450    #[test]
2451    fn test_compose_warps_associativity() {
2452        // (γ₁ ∘ γ₂) ∘ γ₃ ≈ γ₁ ∘ (γ₂ ∘ γ₃)
2453        let m = 50;
2454        let t = uniform_grid(m);
2455        let g1: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
2456        let g2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
2457        let g3: Vec<f64> = t.iter().map(|&ti| 0.5 * ti + 0.5 * ti * ti).collect();
2458
2459        let g12 = compose_warps(&g1, &g2, &t);
2460        let left = compose_warps(&g12, &g3, &t); // (g1∘g2) ∘ g3
2461
2462        let g23 = compose_warps(&g2, &g3, &t);
2463        let right = compose_warps(&g1, &g23, &t); // g1 ∘ (g2∘g3)
2464
2465        for j in 0..m {
2466            assert!(
2467                (left[j] - right[j]).abs() < 0.05,
2468                "Composition should be roughly associative at j={j}: left={:.4}, right={:.4}",
2469                left[j],
2470                right[j]
2471            );
2472        }
2473    }
2474
2475    #[test]
2476    fn test_compose_warps_preserves_domain() {
2477        let m = 50;
2478        let t = uniform_grid(m);
2479        let g1: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
2480        let g2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
2481
2482        let composed = compose_warps(&g1, &g2, &t);
2483        assert!(
2484            (composed[0] - t[0]).abs() < 1e-10,
2485            "Composed warp should start at domain start"
2486        );
2487        assert!(
2488            (composed[m - 1] - t[m - 1]).abs() < 1e-10,
2489            "Composed warp should end at domain end"
2490        );
2491    }
2492
2493    // ── Elastic align pair ──
2494
2495    #[test]
2496    fn test_align_identical_curves() {
2497        let m = 50;
2498        let t = uniform_grid(m);
2499        let f: Vec<f64> = t
2500            .iter()
2501            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
2502            .collect();
2503
2504        let result = elastic_align_pair(&f, &f, &t, 0.0);
2505
2506        // Distance should be near zero
2507        assert!(
2508            result.distance < 0.1,
2509            "Distance between identical curves should be near 0, got {}",
2510            result.distance
2511        );
2512
2513        // Warp should be near identity
2514        for j in 0..m {
2515            assert!(
2516                (result.gamma[j] - t[j]).abs() < 0.1,
2517                "Warp should be near identity at j={j}: gamma={}, t={}",
2518                result.gamma[j],
2519                t[j]
2520            );
2521        }
2522    }
2523
2524    #[test]
2525    fn test_align_pair_valid_output() {
2526        let data = make_test_data(2, 50, 42);
2527        let t = uniform_grid(50);
2528        let f1 = data.row(0);
2529        let f2 = data.row(1);
2530
2531        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2532
2533        assert_eq!(result.gamma.len(), 50);
2534        assert_eq!(result.f_aligned.len(), 50);
2535        assert!(result.distance >= 0.0);
2536
2537        // Warp should be monotone
2538        for j in 1..50 {
2539            assert!(
2540                result.gamma[j] >= result.gamma[j - 1],
2541                "Warp should be monotone at j={j}"
2542            );
2543        }
2544    }
2545
2546    #[test]
2547    fn test_align_pair_warp_boundaries() {
2548        let data = make_test_data(2, 50, 42);
2549        let t = uniform_grid(50);
2550        let f1 = data.row(0);
2551        let f2 = data.row(1);
2552
2553        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2554        assert!(
2555            (result.gamma[0] - t[0]).abs() < 1e-12,
2556            "Warp should start at domain start"
2557        );
2558        assert!(
2559            (result.gamma[49] - t[49]).abs() < 1e-12,
2560            "Warp should end at domain end"
2561        );
2562    }
2563
2564    #[test]
2565    fn test_align_shifted_sine() {
2566        // Two sines with a phase shift — alignment should reduce distance
2567        let m = 80;
2568        let t = uniform_grid(m);
2569        let f1: Vec<f64> = t
2570            .iter()
2571            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
2572            .collect();
2573        let f2: Vec<f64> = t
2574            .iter()
2575            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
2576            .collect();
2577
2578        let weights = simpsons_weights(&t);
2579        let l2_before = l2_distance(&f1, &f2, &weights);
2580        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2581        let l2_after = l2_distance(&f1, &result.f_aligned, &weights);
2582
2583        assert!(
2584            l2_after < l2_before + 0.01,
2585            "Alignment should not increase L2 distance: before={l2_before:.4}, after={l2_after:.4}"
2586        );
2587    }
2588
2589    #[test]
2590    fn test_align_pair_aligned_curve_is_finite() {
2591        let data = make_test_data(2, 50, 77);
2592        let t = uniform_grid(50);
2593        let f1 = data.row(0);
2594        let f2 = data.row(1);
2595
2596        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2597        for j in 0..50 {
2598            assert!(
2599                result.f_aligned[j].is_finite(),
2600                "Aligned curve should be finite at j={j}"
2601            );
2602        }
2603    }
2604
2605    #[test]
2606    fn test_align_pair_minimum_grid() {
2607        // Minimum viable grid: m = 2
2608        let t = vec![0.0, 1.0];
2609        let f1 = vec![0.0, 1.0];
2610        let f2 = vec![0.0, 2.0];
2611        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2612        assert_eq!(result.gamma.len(), 2);
2613        assert_eq!(result.f_aligned.len(), 2);
2614        assert!(result.distance >= 0.0);
2615    }
2616
2617    // ── Elastic distance ──
2618
2619    #[test]
2620    fn test_elastic_distance_symmetric() {
2621        let data = make_test_data(3, 50, 42);
2622        let t = uniform_grid(50);
2623        let f1 = data.row(0);
2624        let f2 = data.row(1);
2625
2626        let d12 = elastic_distance(&f1, &f2, &t, 0.0);
2627        let d21 = elastic_distance(&f2, &f1, &t, 0.0);
2628
2629        // Should be approximately symmetric (DP is not perfectly symmetric)
2630        assert!(
2631            (d12 - d21).abs() < d12.max(d21) * 0.3 + 0.01,
2632            "Elastic distance should be roughly symmetric: d12={d12}, d21={d21}"
2633        );
2634    }
2635
2636    #[test]
2637    fn test_elastic_distance_nonneg() {
2638        let data = make_test_data(3, 50, 42);
2639        let t = uniform_grid(50);
2640
2641        for i in 0..3 {
2642            for j in 0..3 {
2643                let fi = data.row(i);
2644                let fj = data.row(j);
2645                let d = elastic_distance(&fi, &fj, &t, 0.0);
2646                assert!(d >= 0.0, "Elastic distance should be non-negative");
2647            }
2648        }
2649    }
2650
2651    #[test]
2652    fn test_elastic_distance_self_near_zero() {
2653        let data = make_test_data(3, 50, 42);
2654        let t = uniform_grid(50);
2655
2656        for i in 0..3 {
2657            let fi = data.row(i);
2658            let d = elastic_distance(&fi, &fi, &t, 0.0);
2659            assert!(
2660                d < 0.1,
2661                "Self-distance should be near zero, got {d} for curve {i}"
2662            );
2663        }
2664    }
2665
2666    #[test]
2667    fn test_elastic_distance_triangle_inequality() {
2668        let data = make_test_data(3, 50, 42);
2669        let t = uniform_grid(50);
2670        let f0 = data.row(0);
2671        let f1 = data.row(1);
2672        let f2 = data.row(2);
2673
2674        let d01 = elastic_distance(&f0, &f1, &t, 0.0);
2675        let d12 = elastic_distance(&f1, &f2, &t, 0.0);
2676        let d02 = elastic_distance(&f0, &f2, &t, 0.0);
2677
2678        // Relaxed triangle inequality (DP alignment is approximate)
2679        let slack = 0.5;
2680        assert!(
2681            d02 <= d01 + d12 + slack,
2682            "Triangle inequality (relaxed): d02={d02:.4} > d01={d01:.4} + d12={d12:.4} + {slack}"
2683        );
2684    }
2685
2686    #[test]
2687    fn test_elastic_distance_different_shapes_nonzero() {
2688        let m = 50;
2689        let t = uniform_grid(m);
2690        let f1: Vec<f64> = t.to_vec(); // linear
2691        let f2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect(); // quadratic
2692
2693        let d = elastic_distance(&f1, &f2, &t, 0.0);
2694        assert!(
2695            d > 0.01,
2696            "Distance between different shapes should be > 0, got {d}"
2697        );
2698    }
2699
2700    // ── Self distance matrix ──
2701
2702    #[test]
2703    fn test_self_distance_matrix_symmetric() {
2704        let data = make_test_data(5, 30, 42);
2705        let t = uniform_grid(30);
2706
2707        let dm = elastic_self_distance_matrix(&data, &t, 0.0);
2708        let n = dm.nrows();
2709
2710        assert_eq!(dm.shape(), (5, 5));
2711
2712        // Zero diagonal
2713        for i in 0..n {
2714            assert!(dm[(i, i)].abs() < 1e-12, "Diagonal should be zero");
2715        }
2716
2717        // Symmetric
2718        for i in 0..n {
2719            for j in (i + 1)..n {
2720                assert!(
2721                    (dm[(i, j)] - dm[(j, i)]).abs() < 1e-12,
2722                    "Matrix should be symmetric at ({i},{j})"
2723                );
2724            }
2725        }
2726    }
2727
2728    #[test]
2729    fn test_self_distance_matrix_nonneg() {
2730        let data = make_test_data(4, 30, 42);
2731        let t = uniform_grid(30);
2732        let dm = elastic_self_distance_matrix(&data, &t, 0.0);
2733
2734        for i in 0..4 {
2735            for j in 0..4 {
2736                assert!(
2737                    dm[(i, j)] >= 0.0,
2738                    "Distance matrix entries should be non-negative at ({i},{j})"
2739                );
2740            }
2741        }
2742    }
2743
2744    #[test]
2745    fn test_self_distance_matrix_single_curve() {
2746        let data = make_test_data(1, 30, 42);
2747        let t = uniform_grid(30);
2748        let dm = elastic_self_distance_matrix(&data, &t, 0.0);
2749        assert_eq!(dm.shape(), (1, 1));
2750        assert!(dm[(0, 0)].abs() < 1e-12);
2751    }
2752
2753    #[test]
2754    fn test_self_distance_matrix_consistent_with_pairwise() {
2755        let data = make_test_data(4, 30, 42);
2756        let t = uniform_grid(30);
2757
2758        let dm = elastic_self_distance_matrix(&data, &t, 0.0);
2759
2760        // Check a few entries match direct elastic_distance calls
2761        for i in 0..4 {
2762            for j in (i + 1)..4 {
2763                let fi = data.row(i);
2764                let fj = data.row(j);
2765                let d_direct = elastic_distance(&fi, &fj, &t, 0.0);
2766                assert!(
2767                    (dm[(i, j)] - d_direct).abs() < 1e-10,
2768                    "Matrix entry ({i},{j})={:.6} should match pairwise {d_direct:.6}",
2769                    dm[(i, j)]
2770                );
2771            }
2772        }
2773    }
2774
2775    // ── Karcher mean ──
2776
2777    #[test]
2778    fn test_karcher_mean_identical_curves() {
2779        let m = 50;
2780        let t = uniform_grid(m);
2781        let f: Vec<f64> = t
2782            .iter()
2783            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
2784            .collect();
2785
2786        // Create 5 identical curves
2787        let mut data = FdMatrix::zeros(5, m);
2788        for i in 0..5 {
2789            for j in 0..m {
2790                data[(i, j)] = f[j];
2791            }
2792        }
2793
2794        let result = karcher_mean(&data, &t, 10, 1e-4, 0.0);
2795
2796        assert_eq!(result.mean.len(), m);
2797        assert!(result.n_iter <= 10);
2798    }
2799
2800    #[test]
2801    fn test_karcher_mean_output_shape() {
2802        let data = make_test_data(15, 50, 42);
2803        let t = uniform_grid(50);
2804
2805        let result = karcher_mean(&data, &t, 5, 1e-3, 0.0);
2806
2807        assert_eq!(result.mean.len(), 50);
2808        assert_eq!(result.mean_srsf.len(), 50);
2809        assert_eq!(result.gammas.shape(), (15, 50));
2810        assert_eq!(result.aligned_data.shape(), (15, 50));
2811        assert!(result.n_iter <= 5);
2812    }
2813
2814    #[test]
2815    fn test_karcher_mean_warps_are_valid() {
2816        let data = make_test_data(10, 40, 42);
2817        let t = uniform_grid(40);
2818
2819        let result = karcher_mean(&data, &t, 5, 1e-3, 0.0);
2820
2821        for i in 0..10 {
2822            // Boundary values
2823            assert!(
2824                (result.gammas[(i, 0)] - t[0]).abs() < 1e-10,
2825                "Warp {i} should start at domain start"
2826            );
2827            assert!(
2828                (result.gammas[(i, 39)] - t[39]).abs() < 1e-10,
2829                "Warp {i} should end at domain end"
2830            );
2831            // Monotonicity
2832            for j in 1..40 {
2833                assert!(
2834                    result.gammas[(i, j)] >= result.gammas[(i, j - 1)],
2835                    "Warp {i} should be monotone at j={j}"
2836                );
2837            }
2838        }
2839    }
2840
2841    #[test]
2842    fn test_karcher_mean_aligned_data_is_finite() {
2843        let data = make_test_data(8, 40, 42);
2844        let t = uniform_grid(40);
2845        let result = karcher_mean(&data, &t, 5, 1e-3, 0.0);
2846
2847        for i in 0..8 {
2848            for j in 0..40 {
2849                assert!(
2850                    result.aligned_data[(i, j)].is_finite(),
2851                    "Aligned data should be finite at ({i},{j})"
2852                );
2853            }
2854        }
2855    }
2856
2857    #[test]
2858    fn test_karcher_mean_srsf_is_finite() {
2859        let data = make_test_data(8, 40, 42);
2860        let t = uniform_grid(40);
2861        let result = karcher_mean(&data, &t, 5, 1e-3, 0.0);
2862
2863        for j in 0..40 {
2864            assert!(
2865                result.mean_srsf[j].is_finite(),
2866                "Mean SRSF should be finite at j={j}"
2867            );
2868            assert!(
2869                result.mean[j].is_finite(),
2870                "Mean curve should be finite at j={j}"
2871            );
2872        }
2873    }
2874
2875    #[test]
2876    fn test_karcher_mean_single_iteration() {
2877        let data = make_test_data(10, 40, 42);
2878        let t = uniform_grid(40);
2879        let result = karcher_mean(&data, &t, 1, 1e-10, 0.0);
2880
2881        assert_eq!(result.n_iter, 1);
2882        assert_eq!(result.mean.len(), 40);
2883        // With only 1 iteration, still produces valid output
2884        for j in 0..40 {
2885            assert!(result.mean[j].is_finite());
2886        }
2887    }
2888
2889    #[test]
2890    fn test_karcher_mean_convergence_not_premature() {
2891        // The old convergence criterion (rel - prev_rel <= tol * prev_rel) was
2892        // always satisfied when the algorithm improved, causing premature exit
2893        // after 2 iterations. With the fix (rel < tol), the algorithm should
2894        // actually iterate until convergence or hitting the iteration cap.
2895        let n = 10;
2896        let m = 50;
2897        let t = uniform_grid(m);
2898
2899        // Create phase-shifted curves that genuinely need alignment
2900        let mut col_major = vec![0.0; n * m];
2901        for i in 0..n {
2902            let shift = (i as f64 - 5.0) * 0.05;
2903            for j in 0..m {
2904                col_major[i + j * n] = (2.0 * std::f64::consts::PI * (t[j] - shift)).sin();
2905            }
2906        }
2907        let data = FdMatrix::from_column_major(col_major, n, m).unwrap();
2908
2909        // With an impossibly tight tolerance, the algorithm should hit the
2910        // iteration cap rather than "converging" after 2 iterations.
2911        let result = karcher_mean(&data, &t, 20, 1e-15, 0.0);
2912        assert!(
2913            result.n_iter > 2,
2914            "With tol=1e-15 the algorithm should iterate beyond 2, got n_iter={}",
2915            result.n_iter
2916        );
2917
2918        // With a reasonable tolerance, it should converge and report so
2919        let result_loose = karcher_mean(&data, &t, 20, 1e-2, 0.0);
2920        assert!(
2921            result_loose.converged,
2922            "With tol=1e-2 the algorithm should converge"
2923        );
2924    }
2925
2926    // ── Align to target ──
2927
2928    #[test]
2929    fn test_align_to_target_valid() {
2930        let data = make_test_data(10, 40, 42);
2931        let t = uniform_grid(40);
2932        let target = data.row(0);
2933
2934        let result = align_to_target(&data, &target, &t, 0.0);
2935
2936        assert_eq!(result.gammas.shape(), (10, 40));
2937        assert_eq!(result.aligned_data.shape(), (10, 40));
2938        assert_eq!(result.distances.len(), 10);
2939
2940        // All distances should be non-negative
2941        for &d in &result.distances {
2942            assert!(d >= 0.0);
2943        }
2944    }
2945
2946    #[test]
2947    fn test_align_to_target_self_near_zero() {
2948        let data = make_test_data(5, 40, 42);
2949        let t = uniform_grid(40);
2950        let target = data.row(0);
2951
2952        let result = align_to_target(&data, &target, &t, 0.0);
2953
2954        // Distance of target to itself should be near zero
2955        assert!(
2956            result.distances[0] < 0.1,
2957            "Self-alignment distance should be near zero, got {}",
2958            result.distances[0]
2959        );
2960    }
2961
2962    #[test]
2963    fn test_align_to_target_warps_are_monotone() {
2964        let data = make_test_data(8, 40, 42);
2965        let t = uniform_grid(40);
2966        let target = data.row(0);
2967        let result = align_to_target(&data, &target, &t, 0.0);
2968
2969        for i in 0..8 {
2970            for j in 1..40 {
2971                assert!(
2972                    result.gammas[(i, j)] >= result.gammas[(i, j - 1)],
2973                    "Warp for curve {i} should be monotone at j={j}"
2974                );
2975            }
2976        }
2977    }
2978
2979    #[test]
2980    fn test_align_to_target_aligned_data_finite() {
2981        let data = make_test_data(6, 40, 42);
2982        let t = uniform_grid(40);
2983        let target = data.row(0);
2984        let result = align_to_target(&data, &target, &t, 0.0);
2985
2986        for i in 0..6 {
2987            for j in 0..40 {
2988                assert!(
2989                    result.aligned_data[(i, j)].is_finite(),
2990                    "Aligned data should be finite at ({i},{j})"
2991                );
2992            }
2993        }
2994    }
2995
2996    // ── Cross distance matrix ──
2997
2998    #[test]
2999    fn test_cross_distance_matrix_shape() {
3000        let data1 = make_test_data(3, 30, 42);
3001        let data2 = make_test_data(4, 30, 99);
3002        let t = uniform_grid(30);
3003
3004        let dm = elastic_cross_distance_matrix(&data1, &data2, &t, 0.0);
3005        assert_eq!(dm.shape(), (3, 4));
3006
3007        // All non-negative
3008        for i in 0..3 {
3009            for j in 0..4 {
3010                assert!(dm[(i, j)] >= 0.0);
3011            }
3012        }
3013    }
3014
3015    #[test]
3016    fn test_cross_distance_matrix_self_matches_self_matrix() {
3017        // cross_distance(data, data) should have zero diagonal (approximately)
3018        let data = make_test_data(4, 30, 42);
3019        let t = uniform_grid(30);
3020
3021        let cross = elastic_cross_distance_matrix(&data, &data, &t, 0.0);
3022        for i in 0..4 {
3023            assert!(
3024                cross[(i, i)] < 0.1,
3025                "Cross distance (self) diagonal should be near zero: got {}",
3026                cross[(i, i)]
3027            );
3028        }
3029    }
3030
3031    #[test]
3032    fn test_cross_distance_matrix_consistent_with_pairwise() {
3033        let data1 = make_test_data(3, 30, 42);
3034        let data2 = make_test_data(2, 30, 99);
3035        let t = uniform_grid(30);
3036
3037        let dm = elastic_cross_distance_matrix(&data1, &data2, &t, 0.0);
3038
3039        for i in 0..3 {
3040            for j in 0..2 {
3041                let fi = data1.row(i);
3042                let fj = data2.row(j);
3043                let d_direct = elastic_distance(&fi, &fj, &t, 0.0);
3044                assert!(
3045                    (dm[(i, j)] - d_direct).abs() < 1e-10,
3046                    "Cross matrix ({i},{j})={:.6} should match pairwise {d_direct:.6}",
3047                    dm[(i, j)]
3048                );
3049            }
3050        }
3051    }
3052
3053    // ── align_srsf_pair ──
3054
3055    #[test]
3056    fn test_align_srsf_pair_identity() {
3057        let m = 50;
3058        let t = uniform_grid(m);
3059        let f: Vec<f64> = t
3060            .iter()
3061            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3062            .collect();
3063        let q = srsf_single(&f, &t);
3064
3065        let (gamma, q_aligned) = align_srsf_pair(&q, &q, &t, 0.0);
3066
3067        // Warp should be near identity
3068        for j in 0..m {
3069            assert!(
3070                (gamma[j] - t[j]).abs() < 0.15,
3071                "Self-SRSF alignment warp should be near identity at j={j}"
3072            );
3073        }
3074
3075        // Aligned SRSF should be close to original
3076        let weights = simpsons_weights(&t);
3077        let dist = l2_distance(&q, &q_aligned, &weights);
3078        assert!(
3079            dist < 0.5,
3080            "Self-aligned SRSF distance should be small, got {dist}"
3081        );
3082    }
3083
3084    // ── srsf_single ──
3085
3086    #[test]
3087    fn test_srsf_single_matches_matrix_version() {
3088        let m = 50;
3089        let t = uniform_grid(m);
3090        let f: Vec<f64> = t.iter().map(|&ti| ti * ti + ti).collect();
3091
3092        let q_single = srsf_single(&f, &t);
3093
3094        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
3095        let q_mat = srsf_transform(&mat, &t);
3096        let q_from_mat = q_mat.row(0);
3097
3098        for j in 0..m {
3099            assert!(
3100                (q_single[j] - q_from_mat[j]).abs() < 1e-12,
3101                "srsf_single should match srsf_transform at j={j}"
3102            );
3103        }
3104    }
3105
3106    // ── gcd ──
3107
3108    #[test]
3109    fn test_gcd_basic() {
3110        assert_eq!(gcd(1, 1), 1);
3111        assert_eq!(gcd(6, 4), 2);
3112        assert_eq!(gcd(7, 5), 1);
3113        assert_eq!(gcd(12, 8), 4);
3114        assert_eq!(gcd(7, 0), 7);
3115        assert_eq!(gcd(0, 5), 5);
3116    }
3117
3118    // ── generate_coprime_nbhd ──
3119
3120    #[test]
3121    fn test_coprime_nbhd_count() {
3122        assert_eq!(generate_coprime_nbhd(1).len(), 1); // just (1,1)
3123        assert_eq!(generate_coprime_nbhd(7).len(), 35);
3124    }
3125
3126    #[test]
3127    fn test_coprime_nbhd_matches_const() {
3128        let generated = generate_coprime_nbhd(7);
3129        assert_eq!(generated.len(), COPRIME_NBHD_7.len());
3130        for (i, pair) in generated.iter().enumerate() {
3131            assert_eq!(*pair, COPRIME_NBHD_7[i], "mismatch at index {i}");
3132        }
3133    }
3134
3135    #[test]
3136    fn test_coprime_nbhd_all_coprime() {
3137        for &(i, j) in &COPRIME_NBHD_7 {
3138            assert_eq!(gcd(i, j), 1, "({i},{j}) should be coprime");
3139            assert!((1..=7).contains(&i));
3140            assert!((1..=7).contains(&j));
3141        }
3142    }
3143
3144    // ── dp_edge_weight ──
3145
3146    #[test]
3147    fn test_dp_edge_weight_diagonal() {
3148        // Diagonal move (1,1): weight = (q1[sc] - sqrt(1)*q2[sr])^2 * h
3149        let t = uniform_grid(10);
3150        let q1 = vec![1.0; 10];
3151        let q2 = vec![1.0; 10];
3152        // Identical SRSFs: weight should be 0
3153        let w = dp_edge_weight(&q1, &q2, &t, 0, 1, 0, 1);
3154        assert!(w.abs() < 1e-12, "identical SRSFs should have zero cost");
3155    }
3156
3157    #[test]
3158    fn test_dp_edge_weight_non_diagonal() {
3159        // Move (1,2): n1=2, n2=1, slope = h/(2h) = 0.5
3160        let t = uniform_grid(10);
3161        let q1 = vec![1.0; 10];
3162        let q2 = vec![0.0; 10];
3163        let w = dp_edge_weight(&q1, &q2, &t, 0, 2, 0, 1);
3164        // diff = q1[0] - sqrt(0.5)*q2[0] = 1.0 - 0 = 1.0
3165        // weight = 1.0^2 * 1.0 * (t[2]-t[0]) = 2/9
3166        let expected = 2.0 / 9.0;
3167        assert!(
3168            (w - expected).abs() < 1e-10,
3169            "dp_edge_weight (1,2): expected {expected}, got {w}"
3170        );
3171    }
3172
3173    #[test]
3174    fn test_dp_edge_weight_zero_span() {
3175        let t = uniform_grid(10);
3176        let q1 = vec![1.0; 10];
3177        let q2 = vec![1.0; 10];
3178        // n1=0: should return INFINITY
3179        assert_eq!(dp_edge_weight(&q1, &q2, &t, 3, 3, 0, 1), f64::INFINITY);
3180        // n2=0: should return INFINITY
3181        assert_eq!(dp_edge_weight(&q1, &q2, &t, 0, 1, 3, 3), f64::INFINITY);
3182    }
3183
3184    // ── DP alignment quality ──
3185
3186    #[test]
3187    fn test_alignment_improves_distance() {
3188        // Aligned SRSF distance should be less than unaligned SRSF distance
3189        let m = 50;
3190        let t = uniform_grid(m);
3191        let f1: Vec<f64> = t
3192            .iter()
3193            .map(|&x| (2.0 * std::f64::consts::PI * x).sin())
3194            .collect();
3195        // Use a larger shift so improvement is clear
3196        let f2: Vec<f64> = t
3197            .iter()
3198            .map(|&x| (2.0 * std::f64::consts::PI * (x + 0.2)).sin())
3199            .collect();
3200
3201        let q1 = srsf_single(&f1, &t);
3202        let q2 = srsf_single(&f2, &t);
3203        let weights = simpsons_weights(&t);
3204        let unaligned_srsf_dist = l2_distance(&q1, &q2, &weights);
3205
3206        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
3207
3208        assert!(
3209            result.distance <= unaligned_srsf_dist + 1e-6,
3210            "aligned SRSF dist ({}) should be <= unaligned SRSF dist ({})",
3211            result.distance,
3212            unaligned_srsf_dist
3213        );
3214    }
3215
3216    // ── Edge case: constant data ──
3217
3218    #[test]
3219    fn test_alignment_constant_curves() {
3220        let m = 30;
3221        let t = uniform_grid(m);
3222        let f1 = vec![5.0; m];
3223        let f2 = vec![5.0; m];
3224
3225        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
3226        assert!(
3227            result.distance < 0.01,
3228            "Constant curves: distance should be ~0"
3229        );
3230        assert_eq!(result.f_aligned.len(), m);
3231    }
3232
3233    #[test]
3234    fn test_karcher_mean_constant_curves() {
3235        let m = 30;
3236        let t = uniform_grid(m);
3237        let mut data = FdMatrix::zeros(5, m);
3238        for i in 0..5 {
3239            for j in 0..m {
3240                data[(i, j)] = 3.0;
3241            }
3242        }
3243
3244        let result = karcher_mean(&data, &t, 5, 1e-4, 0.0);
3245        for j in 0..m {
3246            assert!(
3247                (result.mean[j] - 3.0).abs() < 0.5,
3248                "Mean of constant curves should be near 3.0, got {} at j={j}",
3249                result.mean[j]
3250            );
3251        }
3252    }
3253
3254    #[test]
3255    fn test_nan_srsf_no_panic() {
3256        let m = 20;
3257        let t = uniform_grid(m);
3258        let mut f = vec![1.0; m];
3259        f[5] = f64::NAN;
3260        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
3261        let q = srsf_transform(&mat, &t);
3262        // Should not panic; NaN propagates
3263        assert_eq!(q.nrows(), 1);
3264    }
3265
3266    #[test]
3267    fn test_n1_karcher_mean() {
3268        let m = 30;
3269        let t = uniform_grid(m);
3270        let f: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
3271        let data = FdMatrix::from_slice(&f, 1, m).unwrap();
3272        let result = karcher_mean(&data, &t, 5, 1e-4, 0.0);
3273        assert_eq!(result.mean.len(), m);
3274        // With only 1 curve, the mean should be close to the original
3275        for j in 0..m {
3276            assert!(result.mean[j].is_finite());
3277        }
3278    }
3279
3280    #[test]
3281    fn test_two_point_grid() {
3282        let t = vec![0.0, 1.0];
3283        let f1 = vec![0.0, 1.0];
3284        let f2 = vec![0.0, 2.0];
3285        let d = elastic_distance(&f1, &f2, &t, 0.0);
3286        assert!(d >= 0.0);
3287        assert!(d.is_finite());
3288    }
3289
3290    #[test]
3291    fn test_non_uniform_grid_alignment() {
3292        // Non-uniform grid: points clustered near 0
3293        let t = vec![0.0, 0.01, 0.05, 0.2, 0.5, 1.0];
3294        let m = t.len();
3295        let f1: Vec<f64> = t.iter().map(|&ti: &f64| ti.sin()).collect();
3296        let f2: Vec<f64> = t.iter().map(|&ti: &f64| (ti + 0.1).sin()).collect();
3297        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
3298        assert_eq!(result.gamma.len(), m);
3299        assert!(result.distance >= 0.0);
3300        assert!(result.distance.is_finite());
3301    }
3302
3303    // ── TSRVF tests ──
3304
3305    #[test]
3306    fn test_tsrvf_output_shape() {
3307        let m = 50;
3308        let n = 10;
3309        let t = uniform_grid(m);
3310        let data = make_test_data(n, m, 42);
3311        let result = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
3312        assert_eq!(
3313            result.tangent_vectors.shape(),
3314            (n, m),
3315            "Tangent vectors should be n×m"
3316        );
3317        assert_eq!(result.gammas.shape(), (n, m), "Gammas should be n×m");
3318        assert_eq!(result.srsf_norms.len(), n, "Should have n SRSF norms");
3319        assert_eq!(result.mean.len(), m, "Mean should have m points");
3320        assert_eq!(result.mean_srsf.len(), m, "Mean SRSF should have m points");
3321    }
3322
3323    #[test]
3324    fn test_tsrvf_all_finite() {
3325        let m = 50;
3326        let n = 5;
3327        let t = uniform_grid(m);
3328        let data = make_test_data(n, m, 42);
3329        let result = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
3330        for i in 0..n {
3331            for j in 0..m {
3332                assert!(
3333                    result.tangent_vectors[(i, j)].is_finite(),
3334                    "Tangent vector should be finite at ({i},{j})"
3335                );
3336            }
3337            assert!(
3338                result.srsf_norms[i].is_finite(),
3339                "SRSF norm should be finite for curve {i}"
3340            );
3341        }
3342        assert!(
3343            result.mean_srsf_norm.is_finite(),
3344            "Mean SRSF norm should be finite"
3345        );
3346    }
3347
3348    #[test]
3349    fn test_tsrvf_identical_curves_zero_tangent() {
3350        let m = 50;
3351        let t = uniform_grid(m);
3352        // Stack 5 identical curves
3353        let curve: Vec<f64> = t
3354            .iter()
3355            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3356            .collect();
3357        let mut col_major = vec![0.0; 5 * m];
3358        for i in 0..5 {
3359            for j in 0..m {
3360                col_major[i + j * 5] = curve[j];
3361            }
3362        }
3363        let data = FdMatrix::from_column_major(col_major, 5, m).unwrap();
3364        let result = tsrvf_transform(&data, &t, 10, 1e-4, 0.0);
3365
3366        // All tangent vectors should be approximately zero
3367        for i in 0..5 {
3368            let tv_norm_sq: f64 = (0..m).map(|j| result.tangent_vectors[(i, j)].powi(2)).sum();
3369            assert!(
3370                tv_norm_sq.sqrt() < 0.5,
3371                "Identical curves should have near-zero tangent vectors, got norm = {}",
3372                tv_norm_sq.sqrt()
3373            );
3374        }
3375    }
3376
3377    #[test]
3378    fn test_tsrvf_mean_tangent_near_zero() {
3379        let m = 50;
3380        let n = 10;
3381        let t = uniform_grid(m);
3382        let data = make_test_data(n, m, 42);
3383        let result = tsrvf_transform(&data, &t, 10, 1e-3, 0.0);
3384
3385        // Mean of tangent vectors should be approximately zero (property of Karcher mean)
3386        let mut mean_tv = vec![0.0; m];
3387        for i in 0..n {
3388            for j in 0..m {
3389                mean_tv[j] += result.tangent_vectors[(i, j)];
3390            }
3391        }
3392        for j in 0..m {
3393            mean_tv[j] /= n as f64;
3394        }
3395        let mean_norm: f64 = mean_tv.iter().map(|v| v * v).sum::<f64>().sqrt();
3396        assert!(
3397            mean_norm < 1.0,
3398            "Mean tangent vector should be near zero, got norm = {mean_norm}"
3399        );
3400    }
3401
3402    #[test]
3403    fn test_tsrvf_from_alignment() {
3404        let m = 50;
3405        let n = 5;
3406        let t = uniform_grid(m);
3407        let data = make_test_data(n, m, 42);
3408        let karcher = karcher_mean(&data, &t, 5, 1e-3, 0.0);
3409        let result = tsrvf_from_alignment(&karcher, &t);
3410        assert_eq!(result.tangent_vectors.shape(), (n, m));
3411        assert!(result.mean_srsf_norm > 0.0);
3412    }
3413
3414    #[test]
3415    fn test_tsrvf_round_trip() {
3416        let m = 50;
3417        let n = 5;
3418        let t = uniform_grid(m);
3419        let data = make_test_data(n, m, 42);
3420        let result = tsrvf_transform(&data, &t, 10, 1e-3, 0.0);
3421        let reconstructed = tsrvf_inverse(&result, &t);
3422
3423        assert_eq!(reconstructed.shape(), result.tangent_vectors.shape());
3424        // Reconstructed curves should have finite values
3425        for i in 0..n {
3426            for j in 0..m {
3427                assert!(
3428                    reconstructed[(i, j)].is_finite(),
3429                    "Reconstructed curve should be finite at ({i},{j})"
3430                );
3431            }
3432        }
3433        // Issue #12: per-curve initial values should be preserved
3434        for i in 0..n {
3435            assert!(
3436                (reconstructed[(i, 0)] - result.initial_values[i]).abs() < 1e-6,
3437                "Curve {i} initial value: expected {}, got {}",
3438                result.initial_values[i],
3439                reconstructed[(i, 0)]
3440            );
3441        }
3442    }
3443
3444    #[test]
3445    fn test_tsrvf_initial_values_per_curve() {
3446        // Issue #12: tsrvf_inverse must use per-curve initial values, not mean[0]
3447        let m = 50;
3448        let n = 5;
3449        let t = uniform_grid(m);
3450
3451        // Create curves with distinct initial values
3452        let mut col_major = vec![0.0; n * m];
3453        for i in 0..n {
3454            let offset = (i as f64 + 1.0) * 2.0; // offsets: 2, 4, 6, 8, 10
3455            for j in 0..m {
3456                col_major[i + j * n] = offset + (2.0 * std::f64::consts::PI * t[j]).sin();
3457            }
3458        }
3459        let data = FdMatrix::from_column_major(col_major, n, m).unwrap();
3460
3461        let result = tsrvf_transform(&data, &t, 15, 1e-4, 0.0);
3462
3463        // initial_values should differ per curve
3464        assert_eq!(result.initial_values.len(), n);
3465        let all_same = result
3466            .initial_values
3467            .windows(2)
3468            .all(|w| (w[0] - w[1]).abs() < 1e-10);
3469        assert!(
3470            !all_same,
3471            "Initial values should differ per curve: {:?}",
3472            result.initial_values
3473        );
3474
3475        // Reconstruct and check initial values are preserved
3476        let reconstructed = tsrvf_inverse(&result, &t);
3477        for i in 0..n {
3478            assert!(
3479                (reconstructed[(i, 0)] - result.initial_values[i]).abs() < 1e-6,
3480                "Curve {i}: reconstructed f(0) = {}, expected {}",
3481                reconstructed[(i, 0)],
3482                result.initial_values[i]
3483            );
3484        }
3485
3486        // Before the fix, all curves would have reconstructed[(i, 0)] ≈ mean[0]
3487        // Verify they are NOT all the same
3488        let recon_initials: Vec<f64> = (0..n).map(|i| reconstructed[(i, 0)]).collect();
3489        let all_recon_same = recon_initials.windows(2).all(|w| (w[0] - w[1]).abs() < 0.1);
3490        assert!(
3491            !all_recon_same,
3492            "Reconstructed initial values must vary per curve: {:?}",
3493            recon_initials
3494        );
3495    }
3496
3497    #[test]
3498    fn test_tsrvf_single_curve() {
3499        let m = 50;
3500        let t = uniform_grid(m);
3501        let data = make_test_data(1, m, 42);
3502        let result = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
3503        assert_eq!(result.tangent_vectors.shape(), (1, m));
3504        // Single curve → tangent vector should be zero (it IS the mean)
3505        let tv_norm: f64 = (0..m)
3506            .map(|j| result.tangent_vectors[(0, j)].powi(2))
3507            .sum::<f64>()
3508            .sqrt();
3509        assert!(
3510            tv_norm < 0.5,
3511            "Single curve tangent vector should be near zero, got {tv_norm}"
3512        );
3513    }
3514
3515    #[test]
3516    fn test_tsrvf_constant_curves() {
3517        let m = 30;
3518        let t = uniform_grid(m);
3519        // Constant curves → SRSF = 0, norms = 0
3520        let data = FdMatrix::from_column_major(vec![5.0; 3 * m], 3, m).unwrap();
3521        let result = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
3522        // Should not produce NaN or Inf
3523        for i in 0..3 {
3524            for j in 0..m {
3525                let v = result.tangent_vectors[(i, j)];
3526                assert!(
3527                    !v.is_nan(),
3528                    "Constant curves should not produce NaN tangent vectors"
3529                );
3530            }
3531        }
3532    }
3533
3534    // ── Reference-value tests (sphere geometry) ─────────────────────────────
3535
3536    #[test]
3537    fn test_tsrvf_sphere_inv_exp_reference() {
3538        // Analytical reference: known vectors on the Hilbert sphere
3539        // psi1 = constant (unit L2 norm), psi2 = 1 + 0.3*sin(2πt) (normalized)
3540        let m = 21;
3541        let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
3542
3543        // Construct unit vector psi1 (constant)
3544        let raw1 = vec![1.0; m];
3545        let norm1 = inner_product_l2(&raw1, &raw1, &time).max(0.0).sqrt();
3546        let psi1: Vec<f64> = raw1.iter().map(|&v| v / norm1).collect();
3547
3548        // Construct psi2 with sinusoidal perturbation
3549        let raw2: Vec<f64> = time
3550            .iter()
3551            .map(|&t| 1.0 + 0.3 * (2.0 * std::f64::consts::PI * t).sin())
3552            .collect();
3553        let norm2 = inner_product_l2(&raw2, &raw2, &time).max(0.0).sqrt();
3554        let psi2: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();
3555
3556        // Compute theta analytically
3557        let ip = inner_product_l2(&psi1, &psi2, &time).clamp(-1.0, 1.0);
3558        let theta_expected = ip.acos();
3559
3560        // Compute via inv_exp_map_sphere
3561        let v = inv_exp_map_sphere(&psi1, &psi2, &time);
3562        let v_norm = inner_product_l2(&v, &v, &time).max(0.0).sqrt();
3563
3564        // ||v|| should equal theta
3565        assert!(
3566            (v_norm - theta_expected).abs() < 1e-10,
3567            "||v|| = {v_norm}, expected theta = {theta_expected}"
3568        );
3569
3570        // theta should be small but non-zero (perturbation is mild)
3571        assert!(
3572            theta_expected > 0.01 && theta_expected < 1.0,
3573            "theta = {theta_expected} out of expected range"
3574        );
3575    }
3576
3577    #[test]
3578    fn test_tsrvf_sphere_round_trip_reference() {
3579        // Round-trip: exp_map(psi1, inv_exp_map(psi1, psi2)) should recover psi2
3580        let m = 21;
3581        let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
3582
3583        let raw1 = vec![1.0; m];
3584        let norm1 = inner_product_l2(&raw1, &raw1, &time).max(0.0).sqrt();
3585        let psi1: Vec<f64> = raw1.iter().map(|&v| v / norm1).collect();
3586
3587        let raw2: Vec<f64> = time
3588            .iter()
3589            .map(|&t| 1.0 + 0.3 * (2.0 * std::f64::consts::PI * t).sin())
3590            .collect();
3591        let norm2 = inner_product_l2(&raw2, &raw2, &time).max(0.0).sqrt();
3592        let psi2: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();
3593
3594        let v = inv_exp_map_sphere(&psi1, &psi2, &time);
3595        let recovered = exp_map_sphere(&psi1, &v, &time);
3596
3597        // L2 error between psi2 and recovered
3598        let diff: Vec<f64> = psi2
3599            .iter()
3600            .zip(recovered.iter())
3601            .map(|(&a, &b)| (a - b).powi(2))
3602            .collect();
3603        let l2_err = trapz(&diff, &time).max(0.0).sqrt();
3604        assert!(
3605            l2_err < 1e-12,
3606            "Round-trip L2 error = {l2_err:.2e}, expected < 1e-12"
3607        );
3608    }
3609
3610    // ── Penalized alignment ──
3611
3612    #[test]
3613    fn test_penalized_alignment_lambda_zero_matches_unpenalized() {
3614        let m = 50;
3615        let t = uniform_grid(m);
3616        let data = make_test_data(2, m, 42);
3617        let f1 = data.row(0);
3618        let f2 = data.row(1);
3619
3620        let r0 = elastic_align_pair(&f1, &f2, &t, 0.0);
3621        // lambda = 0.0 should produce the same result regardless
3622        assert!(r0.distance >= 0.0);
3623        assert_eq!(r0.gamma.len(), m);
3624    }
3625
3626    #[test]
3627    fn test_penalized_alignment_smoother_warp() {
3628        let m = 80;
3629        let t = uniform_grid(m);
3630        let f1: Vec<f64> = t
3631            .iter()
3632            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3633            .collect();
3634        let f2: Vec<f64> = t
3635            .iter()
3636            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.15)).sin())
3637            .collect();
3638
3639        let r_free = elastic_align_pair(&f1, &f2, &t, 0.0);
3640        let r_pen = elastic_align_pair(&f1, &f2, &t, 1.0);
3641
3642        // Measure warp deviation from identity
3643        let dev_free: f64 = r_free
3644            .gamma
3645            .iter()
3646            .zip(t.iter())
3647            .map(|(g, ti)| (g - ti).powi(2))
3648            .sum();
3649        let dev_pen: f64 = r_pen
3650            .gamma
3651            .iter()
3652            .zip(t.iter())
3653            .map(|(g, ti)| (g - ti).powi(2))
3654            .sum();
3655
3656        assert!(
3657            dev_pen <= dev_free + 1e-6,
3658            "Penalized warp should be closer to identity: free={dev_free:.6}, pen={dev_pen:.6}"
3659        );
3660    }
3661
3662    #[test]
3663    fn test_penalized_alignment_large_lambda_near_identity() {
3664        let m = 50;
3665        let t = uniform_grid(m);
3666        let f1: Vec<f64> = t
3667            .iter()
3668            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3669            .collect();
3670        let f2: Vec<f64> = t
3671            .iter()
3672            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
3673            .collect();
3674
3675        let r = elastic_align_pair(&f1, &f2, &t, 1000.0);
3676
3677        // With very large lambda, warp should be very close to identity
3678        let max_dev: f64 = r
3679            .gamma
3680            .iter()
3681            .zip(t.iter())
3682            .map(|(g, ti)| (g - ti).abs())
3683            .fold(0.0_f64, f64::max);
3684        assert!(
3685            max_dev < 0.05,
3686            "Large lambda should give near-identity warp: max deviation = {max_dev}"
3687        );
3688    }
3689
3690    #[test]
3691    fn test_penalized_karcher_mean() {
3692        let m = 40;
3693        let t = uniform_grid(m);
3694        let data = make_test_data(10, m, 42);
3695
3696        let result = karcher_mean(&data, &t, 5, 1e-3, 0.5);
3697        assert_eq!(result.mean.len(), m);
3698        for j in 0..m {
3699            assert!(result.mean[j].is_finite());
3700        }
3701    }
3702
3703    // ── Phase-amplitude decomposition ──
3704
3705    #[test]
3706    fn test_decomposition_identity_curves() {
3707        let m = 50;
3708        let t = uniform_grid(m);
3709        let f: Vec<f64> = t
3710            .iter()
3711            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3712            .collect();
3713
3714        let result = elastic_decomposition(&f, &f, &t, 0.0);
3715        assert!(
3716            result.d_amplitude < 0.1,
3717            "Self-decomposition amplitude should be ~0, got {}",
3718            result.d_amplitude
3719        );
3720        assert!(
3721            result.d_phase < 0.2,
3722            "Self-decomposition phase should be ~0, got {}",
3723            result.d_phase
3724        );
3725    }
3726
3727    #[test]
3728    fn test_decomposition_pythagorean() {
3729        // d_total² ≈ d_a² + d_φ² (approximately, for the Fisher-Rao metric)
3730        let m = 80;
3731        let t = uniform_grid(m);
3732        let f1: Vec<f64> = t
3733            .iter()
3734            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3735            .collect();
3736        let f2: Vec<f64> = t
3737            .iter()
3738            .map(|&ti| 1.2 * (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
3739            .collect();
3740
3741        let result = elastic_decomposition(&f1, &f2, &t, 0.0);
3742        let da = result.d_amplitude;
3743        let dp = result.d_phase;
3744        // Both should be non-negative
3745        assert!(da >= 0.0);
3746        assert!(dp >= 0.0);
3747    }
3748
3749    #[test]
3750    fn test_phase_distance_shifted_sine() {
3751        let m = 80;
3752        let t = uniform_grid(m);
3753        let f1: Vec<f64> = t
3754            .iter()
3755            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3756            .collect();
3757        let f2: Vec<f64> = t
3758            .iter()
3759            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.15)).sin())
3760            .collect();
3761
3762        let dp = phase_distance_pair(&f1, &f2, &t, 0.0);
3763        assert!(
3764            dp > 0.01,
3765            "Phase distance of shifted curves should be > 0, got {dp}"
3766        );
3767    }
3768
3769    #[test]
3770    fn test_amplitude_distance_scaled_curve() {
3771        let m = 80;
3772        let t = uniform_grid(m);
3773        let f1: Vec<f64> = t
3774            .iter()
3775            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3776            .collect();
3777        let f2: Vec<f64> = t
3778            .iter()
3779            .map(|&ti| 2.0 * (2.0 * std::f64::consts::PI * ti).sin())
3780            .collect();
3781
3782        let da = amplitude_distance(&f1, &f2, &t, 0.0);
3783        assert!(
3784            da > 0.01,
3785            "Amplitude distance of scaled curves should be > 0, got {da}"
3786        );
3787    }
3788
3789    #[test]
3790    fn test_phase_distance_nonneg() {
3791        let data = make_test_data(4, 40, 42);
3792        let t = uniform_grid(40);
3793        for i in 0..4 {
3794            for j in 0..4 {
3795                let fi = data.row(i);
3796                let fj = data.row(j);
3797                let dp = phase_distance_pair(&fi, &fj, &t, 0.0);
3798                assert!(dp >= 0.0, "Phase distance should be non-negative");
3799            }
3800        }
3801    }
3802
3803    // ── Parallel transport ──
3804
3805    #[test]
3806    fn test_schilds_ladder_zero_vector() {
3807        let m = 21;
3808        let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
3809        let raw = vec![1.0; m];
3810        let norm = crate::warping::l2_norm_l2(&raw, &time);
3811        let from: Vec<f64> = raw.iter().map(|&v| v / norm).collect();
3812        let raw2: Vec<f64> = time
3813            .iter()
3814            .map(|&t| 1.0 + 0.2 * (2.0 * std::f64::consts::PI * t).sin())
3815            .collect();
3816        let norm2 = crate::warping::l2_norm_l2(&raw2, &time);
3817        let to: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();
3818
3819        let zero = vec![0.0; m];
3820        let result = parallel_transport_schilds(&zero, &from, &to, &time);
3821        let result_norm: f64 = result.iter().map(|v| v * v).sum::<f64>().sqrt();
3822        assert!(
3823            result_norm < 1e-6,
3824            "Transporting zero should give zero, got norm {result_norm}"
3825        );
3826    }
3827
3828    #[test]
3829    fn test_pole_ladder_zero_vector() {
3830        let m = 21;
3831        let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
3832        let raw = vec![1.0; m];
3833        let norm = crate::warping::l2_norm_l2(&raw, &time);
3834        let from: Vec<f64> = raw.iter().map(|&v| v / norm).collect();
3835        let raw2: Vec<f64> = time
3836            .iter()
3837            .map(|&t| 1.0 + 0.2 * (2.0 * std::f64::consts::PI * t).sin())
3838            .collect();
3839        let norm2 = crate::warping::l2_norm_l2(&raw2, &time);
3840        let to: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();
3841
3842        let zero = vec![0.0; m];
3843        let result = parallel_transport_pole(&zero, &from, &to, &time);
3844        let result_norm: f64 = result.iter().map(|v| v * v).sum::<f64>().sqrt();
3845        assert!(
3846            result_norm < 1e-6,
3847            "Transporting zero should give zero, got norm {result_norm}"
3848        );
3849    }
3850
3851    #[test]
3852    fn test_schilds_preserves_norm() {
3853        let m = 51;
3854        let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
3855        let raw = vec![1.0; m];
3856        let norm = crate::warping::l2_norm_l2(&raw, &time);
3857        let from: Vec<f64> = raw.iter().map(|&v| v / norm).collect();
3858        let raw2: Vec<f64> = time
3859            .iter()
3860            .map(|&t| 1.0 + 0.15 * (2.0 * std::f64::consts::PI * t).sin())
3861            .collect();
3862        let norm2 = crate::warping::l2_norm_l2(&raw2, &time);
3863        let to: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();
3864
3865        // Small tangent vector
3866        let v: Vec<f64> = time
3867            .iter()
3868            .map(|&t| 0.1 * (4.0 * std::f64::consts::PI * t).cos())
3869            .collect();
3870        let v_norm = crate::warping::l2_norm_l2(&v, &time);
3871
3872        let transported = parallel_transport_schilds(&v, &from, &to, &time);
3873        let t_norm = crate::warping::l2_norm_l2(&transported, &time);
3874
3875        // Norm should be approximately preserved (ladder methods are first-order)
3876        assert!(
3877            (t_norm - v_norm).abs() / v_norm.max(1e-10) < 1.5,
3878            "Schild's should roughly preserve norm: original={v_norm:.4}, transported={t_norm:.4}"
3879        );
3880    }
3881
3882    #[test]
3883    fn test_tsrvf_logmap_matches_original() {
3884        let m = 50;
3885        let n = 5;
3886        let t = uniform_grid(m);
3887        let data = make_test_data(n, m, 42);
3888
3889        let result_orig = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
3890        let result_logmap =
3891            tsrvf_transform_with_method(&data, &t, 5, 1e-3, 0.0, TransportMethod::LogMap);
3892
3893        // Should be identical (LogMap delegates to original)
3894        for i in 0..n {
3895            for j in 0..m {
3896                assert!(
3897                    (result_orig.tangent_vectors[(i, j)] - result_logmap.tangent_vectors[(i, j)])
3898                        .abs()
3899                        < 1e-12,
3900                    "LogMap variant should match original at ({i},{j})"
3901                );
3902            }
3903        }
3904    }
3905
3906    #[test]
3907    fn test_tsrvf_with_schilds_produces_valid_result() {
3908        let m = 50;
3909        let n = 5;
3910        let t = uniform_grid(m);
3911        let data = make_test_data(n, m, 42);
3912
3913        let result =
3914            tsrvf_transform_with_method(&data, &t, 5, 1e-3, 0.0, TransportMethod::SchildsLadder);
3915
3916        assert_eq!(result.tangent_vectors.shape(), (n, m));
3917        for i in 0..n {
3918            for j in 0..m {
3919                assert!(
3920                    result.tangent_vectors[(i, j)].is_finite(),
3921                    "Schild's TSRVF should produce finite tangent vectors at ({i},{j})"
3922                );
3923            }
3924        }
3925    }
3926
3927    #[test]
3928    fn test_transport_methods_differ() {
3929        let m = 50;
3930        let n = 5;
3931        let t = uniform_grid(m);
3932        let data = make_test_data(n, m, 42);
3933        let karcher = karcher_mean(&data, &t, 5, 1e-3, 0.0);
3934
3935        let r_log = tsrvf_from_alignment_with_method(&karcher, &t, TransportMethod::LogMap);
3936        let r_schilds =
3937            tsrvf_from_alignment_with_method(&karcher, &t, TransportMethod::SchildsLadder);
3938
3939        // Methods should produce different (but related) tangent vectors
3940        let mut total_diff = 0.0;
3941        for i in 0..n {
3942            for j in 0..m {
3943                total_diff +=
3944                    (r_log.tangent_vectors[(i, j)] - r_schilds.tangent_vectors[(i, j)]).abs();
3945            }
3946        }
3947
3948        // They should be non-zero different (unless all curves are identical to mean)
3949        // Just check both produce finite results
3950        assert!(total_diff.is_finite());
3951    }
3952
3953    // ── Alignment quality metrics ──
3954
3955    #[test]
3956    fn test_warp_complexity_identity_is_zero() {
3957        let m = 50;
3958        let t = uniform_grid(m);
3959        let identity = t.clone();
3960        let c = warp_complexity(&identity, &t);
3961        assert!(
3962            c < 1e-10,
3963            "Identity warp should have zero complexity, got {c}"
3964        );
3965    }
3966
3967    #[test]
3968    fn test_warp_complexity_nonidentity_positive() {
3969        let m = 50;
3970        let t = uniform_grid(m);
3971        let gamma: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
3972        let c = warp_complexity(&gamma, &t);
3973        assert!(
3974            c > 0.01,
3975            "Non-identity warp should have positive complexity, got {c}"
3976        );
3977    }
3978
3979    #[test]
3980    fn test_warp_smoothness_identity_is_zero() {
3981        let m = 50;
3982        let t = uniform_grid(m);
3983        let identity = t.clone();
3984        let s = warp_smoothness(&identity, &t);
3985        assert!(
3986            s < 1e-6,
3987            "Identity warp (constant γ'=1, γ''=0) should have near-zero bending energy, got {s}"
3988        );
3989    }
3990
3991    #[test]
3992    fn test_alignment_quality_basic() {
3993        let m = 50;
3994        let n = 8;
3995        let t = uniform_grid(m);
3996        let data = make_test_data(n, m, 42);
3997        let karcher = karcher_mean(&data, &t, 10, 1e-3, 0.0);
3998        let quality = alignment_quality(&data, &karcher, &t);
3999
4000        // Shape checks
4001        assert_eq!(quality.warp_complexity.len(), n);
4002        assert_eq!(quality.warp_smoothness.len(), n);
4003        assert_eq!(quality.pointwise_variance_ratio.len(), m);
4004
4005        // Non-negativity
4006        assert!(quality.total_variance >= 0.0);
4007        assert!(quality.amplitude_variance >= 0.0);
4008        assert!(quality.phase_variance >= 0.0);
4009        assert!(quality.mean_warp_complexity >= 0.0);
4010        assert!(quality.mean_warp_smoothness >= 0.0);
4011
4012        // Amplitude variance ≤ total variance
4013        assert!(
4014            quality.amplitude_variance <= quality.total_variance + 1e-10,
4015            "Amplitude variance ({}) should be ≤ total variance ({})",
4016            quality.amplitude_variance,
4017            quality.total_variance
4018        );
4019    }
4020
4021    #[test]
4022    fn test_alignment_quality_identical_curves() {
4023        let m = 50;
4024        let t = uniform_grid(m);
4025        let curve: Vec<f64> = t
4026            .iter()
4027            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
4028            .collect();
4029        let mut col_major = vec![0.0; 5 * m];
4030        for i in 0..5 {
4031            for j in 0..m {
4032                col_major[i + j * 5] = curve[j];
4033            }
4034        }
4035        let data = FdMatrix::from_column_major(col_major, 5, m).unwrap();
4036        let karcher = karcher_mean(&data, &t, 5, 1e-3, 0.0);
4037        let quality = alignment_quality(&data, &karcher, &t);
4038
4039        // Identical curves → near-zero variances and warp complexities
4040        assert!(
4041            quality.total_variance < 0.01,
4042            "Identical curves should have near-zero total variance, got {}",
4043            quality.total_variance
4044        );
4045        assert!(
4046            quality.mean_warp_complexity < 0.1,
4047            "Identical curves should have near-zero warp complexity, got {}",
4048            quality.mean_warp_complexity
4049        );
4050    }
4051
4052    #[test]
4053    fn test_alignment_quality_variance_reduction() {
4054        let m = 50;
4055        let n = 10;
4056        let t = uniform_grid(m);
4057        let data = make_test_data(n, m, 42);
4058        let karcher = karcher_mean(&data, &t, 10, 1e-3, 0.0);
4059        let quality = alignment_quality(&data, &karcher, &t);
4060
4061        // Mean variance ratio should be ≤ ~1 (alignment shouldn't increase variance)
4062        assert!(
4063            quality.mean_variance_reduction <= 1.5,
4064            "Mean variance reduction ratio should be ≤ ~1, got {}",
4065            quality.mean_variance_reduction
4066        );
4067    }
4068
4069    #[test]
4070    fn test_pairwise_consistency_small() {
4071        let m = 40;
4072        let n = 4;
4073        let t = uniform_grid(m);
4074        let data = make_test_data(n, m, 42);
4075
4076        let consistency = pairwise_consistency(&data, &t, 0.0, 100);
4077        assert!(
4078            consistency.is_finite() && consistency >= 0.0,
4079            "Pairwise consistency should be finite and non-negative, got {consistency}"
4080        );
4081    }
4082
4083    // ── Multidimensional SRSF ──
4084
4085    #[test]
4086    fn test_srsf_nd_d1_matches_existing() {
4087        let m = 50;
4088        let t = uniform_grid(m);
4089        let data = make_test_data(3, m, 42);
4090
4091        // 1D via existing function
4092        let q_1d = srsf_transform(&data, &t);
4093
4094        // 1D via nd function
4095        let data_nd = FdCurveSet::from_1d(data);
4096        let q_nd = srsf_transform_nd(&data_nd, &t);
4097
4098        assert_eq!(q_nd.ndim(), 1);
4099        for i in 0..3 {
4100            for j in 0..m {
4101                assert!(
4102                    (q_1d[(i, j)] - q_nd.dims[0][(i, j)]).abs() < 1e-10,
4103                    "1D nd SRSF should match existing at ({i},{j}): {} vs {}",
4104                    q_1d[(i, j)],
4105                    q_nd.dims[0][(i, j)]
4106                );
4107            }
4108        }
4109    }
4110
4111    #[test]
4112    fn test_srsf_nd_constant_is_zero() {
4113        let m = 30;
4114        let t = uniform_grid(m);
4115        // Constant R^2 curve: f(t) = (3.0, -1.0)
4116        let dim0 = FdMatrix::from_column_major(vec![3.0; m], 1, m).unwrap();
4117        let dim1 = FdMatrix::from_column_major(vec![-1.0; m], 1, m).unwrap();
4118        let data = FdCurveSet::from_dims(vec![dim0, dim1]).unwrap();
4119
4120        let q = srsf_transform_nd(&data, &t);
4121        for k in 0..2 {
4122            for j in 0..m {
4123                assert!(
4124                    q.dims[k][(0, j)].abs() < 1e-10,
4125                    "Constant curve SRSF should be zero, dim {k} at {j}: {}",
4126                    q.dims[k][(0, j)]
4127                );
4128            }
4129        }
4130    }
4131
4132    #[test]
4133    fn test_srsf_nd_linear_r2() {
4134        let m = 51;
4135        let t = uniform_grid(m);
4136        // f(t) = (2t, 3t) → f'(t) = (2, 3), ||f'|| = sqrt(13)
4137        // q(t) = (2, 3) / sqrt(sqrt(13)) = (2, 3) / 13^(1/4)
4138        let dim0 =
4139            FdMatrix::from_slice(&t.iter().map(|&ti| 2.0 * ti).collect::<Vec<_>>(), 1, m).unwrap();
4140        let dim1 =
4141            FdMatrix::from_slice(&t.iter().map(|&ti| 3.0 * ti).collect::<Vec<_>>(), 1, m).unwrap();
4142        let data = FdCurveSet::from_dims(vec![dim0, dim1]).unwrap();
4143
4144        let q = srsf_transform_nd(&data, &t);
4145        let expected_scale = 1.0 / 13.0_f64.powf(0.25);
4146        let mid = m / 2;
4147
4148        assert!(
4149            (q.dims[0][(0, mid)] - 2.0 * expected_scale).abs() < 0.1,
4150            "q_x at midpoint: {} vs expected {}",
4151            q.dims[0][(0, mid)],
4152            2.0 * expected_scale
4153        );
4154        assert!(
4155            (q.dims[1][(0, mid)] - 3.0 * expected_scale).abs() < 0.1,
4156            "q_y at midpoint: {} vs expected {}",
4157            q.dims[1][(0, mid)],
4158            3.0 * expected_scale
4159        );
4160    }
4161
4162    #[test]
4163    fn test_srsf_nd_round_trip() {
4164        let m = 51;
4165        let t = uniform_grid(m);
4166        // f(t) = (sin(2πt), cos(2πt))
4167        let pi2 = 2.0 * std::f64::consts::PI;
4168        let vals_x: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).sin()).collect();
4169        let vals_y: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).cos()).collect();
4170        let dim0 = FdMatrix::from_slice(&vals_x, 1, m).unwrap();
4171        let dim1 = FdMatrix::from_slice(&vals_y, 1, m).unwrap();
4172        let data = FdCurveSet::from_dims(vec![dim0, dim1]).unwrap();
4173
4174        let q = srsf_transform_nd(&data, &t);
4175        let q_vecs: Vec<Vec<f64>> = q.dims.iter().map(|dm| dm.row(0)).collect();
4176        let f0 = vec![vals_x[0], vals_y[0]];
4177        let recon = srsf_inverse_nd(&q_vecs, &t, &f0);
4178
4179        // Check reconstruction error (skip boundary points)
4180        let mut max_err = 0.0_f64;
4181        for k in 0..2 {
4182            let orig = if k == 0 { &vals_x } else { &vals_y };
4183            for j in 2..(m - 2) {
4184                let err = (recon[k][j] - orig[j]).abs();
4185                max_err = max_err.max(err);
4186            }
4187        }
4188        assert!(
4189            max_err < 0.2,
4190            "SRSF round-trip max error should be small, got {max_err}"
4191        );
4192    }
4193
4194    #[test]
4195    fn test_align_nd_identical_near_zero() {
4196        let m = 50;
4197        let t = uniform_grid(m);
4198        let pi2 = 2.0 * std::f64::consts::PI;
4199        let vals_x: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).sin()).collect();
4200        let vals_y: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).cos()).collect();
4201        let dim0 = FdMatrix::from_slice(&vals_x, 1, m).unwrap();
4202        let dim1 = FdMatrix::from_slice(&vals_y, 1, m).unwrap();
4203        let data = FdCurveSet::from_dims(vec![dim0, dim1]).unwrap();
4204
4205        let result = elastic_align_pair_nd(&data, &data, &t, 0.0);
4206        assert!(
4207            result.distance < 0.5,
4208            "Self-alignment distance should be ~0, got {}",
4209            result.distance
4210        );
4211        // Gamma should be near identity
4212        let max_dev: f64 = result
4213            .gamma
4214            .iter()
4215            .zip(t.iter())
4216            .map(|(g, ti)| (g - ti).abs())
4217            .fold(0.0_f64, f64::max);
4218        assert!(
4219            max_dev < 0.1,
4220            "Self-alignment warp should be near identity, max dev = {max_dev}"
4221        );
4222    }
4223
4224    #[test]
4225    fn test_align_nd_shifted_r2() {
4226        let m = 60;
4227        let t = uniform_grid(m);
4228        let pi2 = 2.0 * std::f64::consts::PI;
4229
4230        // f1(t) = (sin(2πt), cos(2πt))
4231        let f1x: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).sin()).collect();
4232        let f1y: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).cos()).collect();
4233        let f1 = FdCurveSet::from_dims(vec![
4234            FdMatrix::from_slice(&f1x, 1, m).unwrap(),
4235            FdMatrix::from_slice(&f1y, 1, m).unwrap(),
4236        ])
4237        .unwrap();
4238
4239        // f2(t) = (sin(2π(t-0.1)), cos(2π(t-0.1))) — phase shifted
4240        let f2x: Vec<f64> = t.iter().map(|&ti| (pi2 * (ti - 0.1)).sin()).collect();
4241        let f2y: Vec<f64> = t.iter().map(|&ti| (pi2 * (ti - 0.1)).cos()).collect();
4242        let f2 = FdCurveSet::from_dims(vec![
4243            FdMatrix::from_slice(&f2x, 1, m).unwrap(),
4244            FdMatrix::from_slice(&f2y, 1, m).unwrap(),
4245        ])
4246        .unwrap();
4247
4248        let result = elastic_align_pair_nd(&f1, &f2, &t, 0.0);
4249        assert!(
4250            result.distance.is_finite(),
4251            "Distance should be finite, got {}",
4252            result.distance
4253        );
4254        assert_eq!(result.f_aligned.len(), 2);
4255        assert_eq!(result.f_aligned[0].len(), m);
4256        // Warp should deviate from identity (non-trivial alignment)
4257        let max_dev: f64 = result
4258            .gamma
4259            .iter()
4260            .zip(t.iter())
4261            .map(|(g, ti)| (g - ti).abs())
4262            .fold(0.0_f64, f64::max);
4263        assert!(
4264            max_dev > 0.01,
4265            "Shifted curves should require non-trivial warp, max dev = {max_dev}"
4266        );
4267    }
4268
4269    // ── Landmark-constrained alignment ──
4270
4271    #[test]
4272    fn test_constrained_no_landmarks_matches_unconstrained() {
4273        let m = 50;
4274        let t = uniform_grid(m);
4275        let f1: Vec<f64> = t
4276            .iter()
4277            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
4278            .collect();
4279        let f2: Vec<f64> = t
4280            .iter()
4281            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
4282            .collect();
4283
4284        let r_free = elastic_align_pair(&f1, &f2, &t, 0.0);
4285        let r_const = elastic_align_pair_constrained(&f1, &f2, &t, &[], 0.0);
4286
4287        // Should match unconstrained
4288        for j in 0..m {
4289            assert!(
4290                (r_free.gamma[j] - r_const.gamma[j]).abs() < 1e-10,
4291                "No-landmark constrained should match unconstrained at {j}"
4292            );
4293        }
4294        assert!(r_const.enforced_landmarks.is_empty());
4295    }
4296
4297    #[test]
4298    fn test_constrained_single_landmark_enforced() {
4299        let m = 60;
4300        let t = uniform_grid(m);
4301        let f1: Vec<f64> = t
4302            .iter()
4303            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
4304            .collect();
4305        let f2: Vec<f64> = t
4306            .iter()
4307            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
4308            .collect();
4309
4310        // Constrain midpoint: target_t=0.5 should map to source_t=0.5
4311        let result = elastic_align_pair_constrained(&f1, &f2, &t, &[(0.5, 0.5)], 0.0);
4312
4313        // Gamma at the midpoint should be close to 0.5
4314        let mid_idx = snap_to_grid(0.5, &t);
4315        assert!(
4316            (result.gamma[mid_idx] - 0.5).abs() < 0.05,
4317            "Constrained gamma at midpoint should be ~0.5, got {}",
4318            result.gamma[mid_idx]
4319        );
4320        assert_eq!(result.enforced_landmarks.len(), 1);
4321    }
4322
4323    #[test]
4324    fn test_constrained_multiple_landmarks() {
4325        let m = 80;
4326        let t = uniform_grid(m);
4327        let f1: Vec<f64> = t
4328            .iter()
4329            .map(|&ti| (4.0 * std::f64::consts::PI * ti).sin())
4330            .collect();
4331        let f2: Vec<f64> = t
4332            .iter()
4333            .map(|&ti| (4.0 * std::f64::consts::PI * (ti - 0.05)).sin())
4334            .collect();
4335
4336        let landmarks = vec![(0.25, 0.25), (0.5, 0.5), (0.75, 0.75)];
4337        let result = elastic_align_pair_constrained(&f1, &f2, &t, &landmarks, 0.0);
4338
4339        // Gamma should pass through (or near) each landmark
4340        for &(tt, st) in &landmarks {
4341            let idx = snap_to_grid(tt, &t);
4342            assert!(
4343                (result.gamma[idx] - st).abs() < 0.05,
4344                "Gamma at t={tt} should be ~{st}, got {}",
4345                result.gamma[idx]
4346            );
4347        }
4348    }
4349
4350    #[test]
4351    fn test_constrained_monotone_gamma() {
4352        let m = 60;
4353        let t = uniform_grid(m);
4354        let f1: Vec<f64> = t
4355            .iter()
4356            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
4357            .collect();
4358        let f2: Vec<f64> = t
4359            .iter()
4360            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
4361            .collect();
4362
4363        let result = elastic_align_pair_constrained(&f1, &f2, &t, &[(0.3, 0.3), (0.7, 0.7)], 0.0);
4364
4365        // Gamma should be non-decreasing
4366        for j in 1..m {
4367            assert!(
4368                result.gamma[j] >= result.gamma[j - 1] - 1e-10,
4369                "Gamma should be monotone: gamma[{}]={} < gamma[{}]={}",
4370                j,
4371                result.gamma[j],
4372                j - 1,
4373                result.gamma[j - 1]
4374            );
4375        }
4376        // Boundary conditions
4377        assert!((result.gamma[0] - t[0]).abs() < 1e-10);
4378        assert!((result.gamma[m - 1] - t[m - 1]).abs() < 1e-10);
4379    }
4380
4381    #[test]
4382    fn test_constrained_distance_ge_unconstrained() {
4383        let m = 60;
4384        let t = uniform_grid(m);
4385        let f1: Vec<f64> = t
4386            .iter()
4387            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
4388            .collect();
4389        let f2: Vec<f64> = t
4390            .iter()
4391            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.15)).sin())
4392            .collect();
4393
4394        let r_free = elastic_align_pair(&f1, &f2, &t, 0.0);
4395        let r_const = elastic_align_pair_constrained(&f1, &f2, &t, &[(0.5, 0.5)], 0.0);
4396
4397        // Constrained distance should be >= unconstrained (constraints reduce freedom)
4398        assert!(
4399            r_const.distance >= r_free.distance - 1e-6,
4400            "Constrained distance ({}) should be >= unconstrained ({})",
4401            r_const.distance,
4402            r_free.distance
4403        );
4404    }
4405
4406    #[test]
4407    fn test_constrained_with_landmark_detection() {
4408        let m = 80;
4409        let t = uniform_grid(m);
4410        let f1: Vec<f64> = t
4411            .iter()
4412            .map(|&ti| (4.0 * std::f64::consts::PI * ti).sin())
4413            .collect();
4414        let f2: Vec<f64> = t
4415            .iter()
4416            .map(|&ti| (4.0 * std::f64::consts::PI * (ti - 0.05)).sin())
4417            .collect();
4418
4419        let result = elastic_align_pair_with_landmarks(
4420            &f1,
4421            &f2,
4422            &t,
4423            crate::landmark::LandmarkKind::Peak,
4424            0.1,
4425            0,
4426            0.0,
4427        );
4428
4429        assert_eq!(result.gamma.len(), m);
4430        assert_eq!(result.f_aligned.len(), m);
4431        assert!(result.distance.is_finite());
4432        // Should be monotone
4433        for j in 1..m {
4434            assert!(
4435                result.gamma[j] >= result.gamma[j - 1] - 1e-10,
4436                "Gamma should be monotone at j={j}"
4437            );
4438        }
4439    }
4440
4441    // ── SRSF smoothing for TSRVF (Issue #13) ──
4442
4443    #[test]
4444    fn test_gam_to_psi_smooth_identity() {
4445        // Smoothed psi of identity warp should stay close to constant 1 in the interior.
4446        // Boundary points are biased by kernel smoothing (fewer neighbors), skip them.
4447        use crate::warping::{gam_to_psi, gam_to_psi_smooth};
4448        let m = 101;
4449        let h = 1.0 / (m - 1) as f64;
4450        let gam: Vec<f64> = uniform_grid(m);
4451        let psi_raw = gam_to_psi(&gam, h);
4452        let psi_smooth = gam_to_psi_smooth(&gam, h);
4453        // Check interior points (skip ~5% at each boundary)
4454        let skip = m / 20;
4455        for j in skip..(m - skip) {
4456            assert!(
4457                (psi_smooth[j] - 1.0).abs() < 0.05,
4458                "Smoothed psi of identity should be ~1.0, got {} at j={}",
4459                psi_smooth[j],
4460                j
4461            );
4462            assert!(
4463                (psi_smooth[j] - psi_raw[j]).abs() < 0.05,
4464                "Smoothed and raw psi should agree on smooth warp at j={}",
4465                j
4466            );
4467        }
4468    }
4469
4470    #[test]
4471    fn test_gam_to_psi_smooth_reduces_spikes() {
4472        // Create a kinky warp (simulating DP output with multiple slope changes)
4473        // and verify smoothing reduces psi spikes
4474        use crate::warping::{gam_to_psi, gam_to_psi_smooth};
4475        let m = 101;
4476        let h = 1.0 / (m - 1) as f64;
4477        let argvals = uniform_grid(m);
4478        // Multi-segment piecewise-linear warp with several kinks
4479        let mut gam: Vec<f64> = Vec::with_capacity(m);
4480        for j in 0..m {
4481            let t = argvals[j];
4482            // Three segments: slow (slope 0.5), fast (slope 2), slow (slope 0.5)
4483            let g = if t < 0.33 {
4484                t * 0.5 / 0.33
4485            } else if t < 0.67 {
4486                0.5 + (t - 0.33) * 0.5 / 0.34 * 2.0 // steeper
4487            } else {
4488                let base = 0.5 + 0.5 / 0.34 * 2.0 * 0.34; // ~1.5 but clamped
4489                (base + (t - 0.67) * 0.5 / 0.33).min(1.0)
4490            };
4491            gam.push(g.min(1.0));
4492        }
4493        // Normalize to [0,1]
4494        let gmax = gam[m - 1].max(1e-10);
4495        for g in &mut gam {
4496            *g /= gmax;
4497        }
4498        let psi_raw = gam_to_psi(&gam, h);
4499        let psi_smooth = gam_to_psi_smooth(&gam, h);
4500        // The raw psi should have jumps at kink points (slope transitions)
4501        let max_jump_raw: f64 = psi_raw
4502            .windows(2)
4503            .map(|w| (w[1] - w[0]).abs())
4504            .fold(0.0_f64, f64::max);
4505        let max_jump_smooth: f64 = psi_smooth
4506            .windows(2)
4507            .map(|w| (w[1] - w[0]).abs())
4508            .fold(0.0_f64, f64::max);
4509        // Smoothing should reduce the maximum jump in psi
4510        assert!(
4511            max_jump_smooth < max_jump_raw + 0.01,
4512            "Smoothing should not increase max psi jump: raw={max_jump_raw:.4}, smooth={max_jump_smooth:.4}"
4513        );
4514    }
4515
4516    #[test]
4517    fn test_smooth_aligned_srsfs_preserves_shape() {
4518        // Smoothing aligned SRSFs should preserve overall shape
4519        use crate::smoothing::nadaraya_watson;
4520        let m = 101;
4521        let time: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
4522        // Create a smooth SRSF (sine curve)
4523        let qi: Vec<f64> = time
4524            .iter()
4525            .map(|&t| (2.0 * std::f64::consts::PI * t).sin())
4526            .collect();
4527        let bandwidth = 2.0 / (m - 1) as f64;
4528        let qi_smooth = nadaraya_watson(&time, &qi, &time, bandwidth, "gaussian");
4529        // Correlation between original and smoothed should be very high
4530        let mean_orig: f64 = qi.iter().sum::<f64>() / m as f64;
4531        let mean_smooth: f64 = qi_smooth.iter().sum::<f64>() / m as f64;
4532        let mut cov = 0.0;
4533        let mut var_o = 0.0;
4534        let mut var_s = 0.0;
4535        for j in 0..m {
4536            let do_ = qi[j] - mean_orig;
4537            let ds = qi_smooth[j] - mean_smooth;
4538            cov += do_ * ds;
4539            var_o += do_ * do_;
4540            var_s += ds * ds;
4541        }
4542        let rho = cov / (var_o * var_s).sqrt().max(1e-10);
4543        assert!(
4544            rho > 0.99,
4545            "Smoothed SRSF should be highly correlated with original (rho={rho:.4})"
4546        );
4547    }
4548
4549    #[test]
4550    fn test_tsrvf_tangent_vectors_no_spikes() {
4551        // End-to-end: compute TSRVF tangent vectors, verify no element dominates.
4552        // The smooth_aligned_srsfs step in tsrvf_from_alignment removes DP kink
4553        // artifacts that would otherwise produce spike outliers in tangent vectors.
4554        let m = 101;
4555        let argvals = uniform_grid(m);
4556        let data = make_test_data(10, m, 42);
4557        let result = tsrvf_transform(&data, &argvals, 5, 1e-3, 0.0);
4558        let (n, _) = result.tangent_vectors.shape();
4559        for i in 0..n {
4560            let vi = result.tangent_vectors.row(i);
4561            let rms = (vi.iter().map(|&v| v * v).sum::<f64>() / m as f64).sqrt();
4562            if rms > 1e-10 {
4563                let max_abs = vi.iter().map(|&v| v.abs()).fold(0.0_f64, f64::max);
4564                assert!(
4565                    max_abs < 10.0 * rms,
4566                    "Tangent vector {} has spike: max |v| = {max_abs:.4}, rms = {rms:.4}, ratio = {:.1}",
4567                    i,
4568                    max_abs / rms
4569                );
4570            }
4571        }
4572    }
4573}