Skip to main content

fdars_core/
alignment.rs

1//! Elastic alignment and SRSF (Square-Root Slope Function) transforms.
2//!
3//! This module provides phase-amplitude separation for functional data via
4//! the elastic framework. Key capabilities:
5//!
6//! - [`srsf_transform`] / [`srsf_inverse`] — SRSF representation and reconstruction
7//! - [`elastic_align_pair`] — Pairwise curve alignment via dynamic programming
8//! - [`elastic_distance`] — Elastic (Fisher-Rao) distance between curves
9//! - [`align_to_target`] — Align a set of curves to a common target
10//! - [`karcher_mean`] — Karcher (Fréchet) mean in the elastic metric
11//! - [`elastic_self_distance_matrix`] / [`elastic_cross_distance_matrix`] — Distance matrices
12//! - [`reparameterize_curve`] / [`compose_warps`] — Warping utilities
13
14use crate::fdata::{deriv_1d, mean_1d};
15use crate::helpers::{
16    cumulative_trapz, gradient_uniform, l2_distance, linear_interp, simpsons_weights,
17};
18use crate::iter_maybe_parallel;
19use crate::matrix::FdMatrix;
20use crate::smoothing::nadaraya_watson;
21use crate::warping::{
22    exp_map_sphere, gam_to_psi, inv_exp_map_sphere, invert_gamma, l2_norm_l2, normalize_warp,
23    psi_to_gam,
24};
25#[cfg(feature = "parallel")]
26use rayon::iter::ParallelIterator;
27
28// ─── Types ──────────────────────────────────────────────────────────────────
29
/// Outcome of elastically aligning one curve onto another.
///
/// Returned by [`elastic_align_pair`].
#[derive(Clone, Debug)]
pub struct AlignmentResult {
    /// Warping function γ sampled on the argvals grid; maps the domain onto itself.
    pub gamma: Vec<f64>,
    /// The second curve after reparameterization, i.e. `f2(γ(t))`.
    pub f_aligned: Vec<f64>,
    /// Elastic distance remaining after alignment (weighted L2 distance between the
    /// target's SRSF and the SRSF of `f_aligned`).
    pub distance: f64,
}
40
41/// Result of aligning a set of curves to a common target.
42#[derive(Debug, Clone)]
43pub struct AlignmentSetResult {
44    /// Warping functions (n × m).
45    pub gammas: FdMatrix,
46    /// Aligned curves (n × m).
47    pub aligned_data: FdMatrix,
48    /// Elastic distances for each curve.
49    pub distances: Vec<f64>,
50}
51
52/// Result of the Karcher mean computation.
53#[derive(Debug, Clone)]
54pub struct KarcherMeanResult {
55    /// Karcher mean curve.
56    pub mean: Vec<f64>,
57    /// SRSF of the Karcher mean.
58    pub mean_srsf: Vec<f64>,
59    /// Final warping functions (n × m).
60    pub gammas: FdMatrix,
61    /// Curves aligned to the mean (n × m).
62    pub aligned_data: FdMatrix,
63    /// Number of iterations used.
64    pub n_iter: usize,
65    /// Whether the algorithm converged.
66    pub converged: bool,
67    /// Pre-computed SRSFs of aligned curves (n × m), if available.
68    /// When set, FPCA functions use these instead of recomputing from `aligned_data`.
69    pub aligned_srsfs: Option<FdMatrix>,
70}
71
72// Private helpers are now in crate::helpers and crate::warping.
73
74/// Karcher mean of warping functions on the Hilbert sphere, then invert.
75/// Port of fdasrvf's `SqrtMeanInverse`.
76///
77/// Takes a matrix of warping functions (n × m) on the argvals domain,
78/// computes the Fréchet mean of their sqrt-derivative representations
79/// on the unit Hilbert sphere, converts back to a warping function,
80/// and returns its inverse (on the argvals domain).
81/// One Karcher iteration on the Hilbert sphere: compute mean shooting vector and update mu.
82///
83/// Returns `true` if converged (vbar norm ≤ threshold).
84fn karcher_sphere_step(mu: &mut Vec<f64>, psis: &[Vec<f64>], time: &[f64], step_size: f64) -> bool {
85    let m = mu.len();
86    let n = psis.len();
87    let mut vbar = vec![0.0; m];
88    for psi in psis {
89        let v = inv_exp_map_sphere(mu, psi, time);
90        for j in 0..m {
91            vbar[j] += v[j];
92        }
93    }
94    for j in 0..m {
95        vbar[j] /= n as f64;
96    }
97    if l2_norm_l2(&vbar, time) <= 1e-8 {
98        return true;
99    }
100    let scaled: Vec<f64> = vbar.iter().map(|&v| v * step_size).collect();
101    *mu = exp_map_sphere(mu, &scaled, time);
102    false
103}
104
105pub(crate) fn sqrt_mean_inverse(gammas: &FdMatrix, argvals: &[f64]) -> Vec<f64> {
106    let (n, m) = gammas.shape();
107    let t0 = argvals[0];
108    let t1 = argvals[m - 1];
109    let domain = t1 - t0;
110
111    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
112    let binsize = 1.0 / (m - 1) as f64;
113
114    let psis: Vec<Vec<f64>> = (0..n)
115        .map(|i| {
116            let gam_01: Vec<f64> = (0..m).map(|j| (gammas[(i, j)] - t0) / domain).collect();
117            gam_to_psi(&gam_01, binsize)
118        })
119        .collect();
120
121    let mut mu = vec![0.0; m];
122    for psi in &psis {
123        for j in 0..m {
124            mu[j] += psi[j];
125        }
126    }
127    for j in 0..m {
128        mu[j] /= n as f64;
129    }
130
131    for _ in 0..501 {
132        if karcher_sphere_step(&mut mu, &psis, &time, 0.3) {
133            break;
134        }
135    }
136
137    let gam_mu = psi_to_gam(&mu, &time);
138    let gam_inv = invert_gamma(&gam_mu, &time);
139    gam_inv.iter().map(|&g| t0 + g * domain).collect()
140}
141
142// ─── SRSF Transform and Inverse ─────────────────────────────────────────────
143
144/// Compute the Square-Root Slope Function (SRSF) transform.
145///
146/// For each curve f, the SRSF is: `q(t) = sign(f'(t)) * sqrt(|f'(t)|)`
147///
148/// # Arguments
149/// * `data` — Functional data matrix (n × m)
150/// * `argvals` — Evaluation points (length m)
151///
152/// # Returns
153/// FdMatrix of SRSFs with the same shape as input.
154pub fn srsf_transform(data: &FdMatrix, argvals: &[f64]) -> FdMatrix {
155    let (n, m) = data.shape();
156    if n == 0 || m == 0 || argvals.len() != m {
157        return FdMatrix::zeros(n, m);
158    }
159
160    let deriv = deriv_1d(data, argvals, 1);
161
162    let mut result = FdMatrix::zeros(n, m);
163    for i in 0..n {
164        for j in 0..m {
165            let d = deriv[(i, j)];
166            result[(i, j)] = d.signum() * d.abs().sqrt();
167        }
168    }
169    result
170}
171
172/// Reconstruct a curve from its SRSF representation.
173///
174/// Given SRSF q and initial value f0, reconstructs: `f(t) = f0 + ∫₀ᵗ q(s)|q(s)| ds`
175///
176/// # Arguments
177/// * `q` — SRSF values (length m)
178/// * `argvals` — Evaluation points (length m)
179/// * `f0` — Initial value f(argvals\[0\])
180///
181/// # Returns
182/// Reconstructed curve values.
183pub fn srsf_inverse(q: &[f64], argvals: &[f64], f0: f64) -> Vec<f64> {
184    let m = q.len();
185    if m == 0 {
186        return Vec::new();
187    }
188
189    // Integrand: q(s) * |q(s)|
190    let integrand: Vec<f64> = q.iter().map(|&qi| qi * qi.abs()).collect();
191    let integral = cumulative_trapz(&integrand, argvals);
192
193    integral.iter().map(|&v| f0 + v).collect()
194}
195
196// ─── Reparameterization ─────────────────────────────────────────────────────
197
198/// Reparameterize a curve by a warping function.
199///
200/// Computes `f(γ(t))` via linear interpolation.
201///
202/// # Arguments
203/// * `f` — Curve values (length m)
204/// * `argvals` — Evaluation points (length m)
205/// * `gamma` — Warping function values (length m)
206pub fn reparameterize_curve(f: &[f64], argvals: &[f64], gamma: &[f64]) -> Vec<f64> {
207    gamma
208        .iter()
209        .map(|&g| linear_interp(argvals, f, g))
210        .collect()
211}
212
213/// Compose two warping functions: `(γ₁ ∘ γ₂)(t) = γ₁(γ₂(t))`.
214///
215/// # Arguments
216/// * `gamma1` — Outer warping function (length m)
217/// * `gamma2` — Inner warping function (length m)
218/// * `argvals` — Evaluation points (length m)
219pub fn compose_warps(gamma1: &[f64], gamma2: &[f64], argvals: &[f64]) -> Vec<f64> {
220    gamma2
221        .iter()
222        .map(|&g| linear_interp(argvals, gamma1, g))
223        .collect()
224}
225
226// ─── Dynamic Programming Alignment ──────────────────────────────────────────
227// Faithful port of fdasrvf's DP algorithm (dp_grid.cpp / dp_nbhd.cpp).
228
/// Greatest common divisor via the iterative Euclidean algorithm (`gcd(a, 0) == a`).
#[cfg(test)]
fn gcd(a: usize, b: usize) -> usize {
    let (mut x, mut y) = (a, b);
    while y != 0 {
        let r = x % y;
        x = y;
        y = r;
    }
    x
}
238
/// Generate the coprime neighborhood: every `(i, j)` with `1 ≤ i, j ≤ nbhd_dim` and
/// `gcd(i, j) == 1`, listed row-major (`i` outer, `j` inner).
/// With `nbhd_dim = 7` this yields 35 pairs, matching fdasrvf's default neighborhood.
#[cfg(test)]
fn generate_coprime_nbhd(nbhd_dim: usize) -> Vec<(usize, usize)> {
    (1..=nbhd_dim)
        .flat_map(|i| (1..=nbhd_dim).map(move |j| (i, j)))
        .filter(|&(i, j)| gcd(i, j) == 1)
        .collect()
}
253
/// Pre-computed coprime neighborhood for nbhd_dim = 7 (fdasrvf default).
///
/// All 35 moves `(dr, dc)` with `1 ≤ dr, dc ≤ 7` and `gcd(dr, dc) = 1`, listed row-major
/// (`dr` outer, `dc` inner). `dr` is the row delta (q2 direction) and `dc` the column
/// delta (q1 direction).
///
/// The order matters: `dp_relax_cell` tries moves in this order and only accepts a
/// strictly cheaper cost, so on ties the earliest entry wins.
#[rustfmt::skip]
const COPRIME_NBHD_7: [(usize, usize); 35] = [
    // dr = 1
    (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7),
    // dr = 2
    (2, 1), (2, 3), (2, 5), (2, 7),
    // dr = 3
    (3, 1), (3, 2), (3, 4), (3, 5), (3, 7),
    // dr = 4
    (4, 1), (4, 3), (4, 5), (4, 7),
    // dr = 5
    (5, 1), (5, 2), (5, 3), (5, 4), (5, 6), (5, 7),
    // dr = 6
    (6, 1), (6, 5), (6, 7),
    // dr = 7
    (7, 1), (7, 2), (7, 3), (7, 4), (7, 5), (7, 6),
];
267
/// Compute the edge weight for a move from grid point (sr, sc) to (tr, tc).
///
/// Port of fdasrvf's `dp_edge_weight` for 1-D curves on a shared uniform grid.
/// Rows index q2 and columns index q1 (the fdasrvf convention). The move implies a
/// constant warp slope `γ' = (argvals[tr] - argvals[sr]) / (argvals[tc] - argvals[sc])`.
///
/// The `tc - sc` q1 samples and the `tr - sr` q2 samples are laid side by side on the
/// unit interval. Each overlapping pair of sub-intervals contributes
/// `(q1[·] - √γ' · q2[·])²` times its overlap length. The total is then scaled by the
/// q1-direction span `argvals[tc] - argvals[sc]`.
///
/// Returns `f64::INFINITY` for moves that do not advance in both directions.
#[inline]
fn dp_edge_weight(
    q1: &[f64],
    q2: &[f64],
    argvals: &[f64],
    sc: usize,
    tc: usize,
    sr: usize,
    tr: usize,
) -> f64 {
    let cols = tc - sc;
    let rows = tr - sr;
    if cols == 0 || rows == 0 {
        return f64::INFINITY;
    }

    let span = argvals[tc] - argvals[sc];
    let rslope = ((argvals[tr] - argvals[sr]) / span).sqrt();
    let (cols_f, rows_f) = (cols as f64, rows as f64);

    let mut total = 0.0;
    let (mut a, mut b) = (0usize, 0usize);
    while a < cols && b < rows {
        let end_a = (a + 1) as f64 / cols_f;
        let end_b = (b + 1) as f64 / rows_f;
        let start = (a as f64 / cols_f).max(b as f64 / rows_f);
        let overlap = end_a.min(end_b) - start;

        if overlap > 0.0 {
            let residual = q1[sc + a] - rslope * q2[sr + b];
            total += residual * residual * overlap;
        }

        // Step past whichever sub-interval closes first (both when they close together).
        if end_b < end_a {
            b += 1;
        } else if end_a < end_b {
            a += 1;
        } else {
            a += 1;
            b += 1;
        }
    }

    total * span
}
329
/// Compute the λ·(slope−1)²·dt penalty for a DP edge.
///
/// `dt = argvals[tc] - argvals[sc]` is the q1-direction span of the move, and
/// `slope = (argvals[tr] - argvals[sr]) / dt` is its warp slope. A `lambda` that is not
/// strictly positive (including NaN) disables the penalty and yields `0.0`.
#[inline]
fn dp_lambda_penalty(
    argvals: &[f64],
    sc: usize,
    tc: usize,
    sr: usize,
    tr: usize,
    lambda: f64,
) -> f64 {
    if lambda.is_nan() || lambda <= 0.0 {
        return 0.0;
    }
    let run = argvals[tc] - argvals[sc];
    let rise = argvals[tr] - argvals[sr];
    let excess = rise / run - 1.0;
    lambda * excess.powi(2) * run
}
348
/// Follow parent pointers from the bottom-right cell back toward the origin.
///
/// Cells are row-major indices into a `nrows × ncols` grid. The walk stops at cell 0, or
/// at a cell whose parent is the `u32::MAX` sentinel. The collected path is returned as
/// `(row, col)` pairs in forward order, ending at `(nrows-1, ncols-1)`.
fn dp_traceback(parent: &[u32], nrows: usize, ncols: usize) -> Vec<(usize, usize)> {
    let last = (nrows - 1) * ncols + (ncols - 1);
    let walk = std::iter::successors(Some(last), |&cell| {
        (cell != 0 && parent[cell] != u32::MAX).then(|| parent[cell] as usize)
    });

    let mut path = Vec::with_capacity(nrows + ncols);
    path.extend(walk.map(|cell| (cell / ncols, cell % ncols)));
    path.reverse();
    path
}
365
366/// Try to relax cell `(tr, tc)` from each coprime neighbor, updating cost and parent.
367#[inline]
368fn dp_relax_cell<F>(
369    e: &mut [f64],
370    parent: &mut [u32],
371    ncols: usize,
372    tr: usize,
373    tc: usize,
374    edge_cost: &F,
375) where
376    F: Fn(usize, usize, usize, usize) -> f64,
377{
378    let idx = tr * ncols + tc;
379    for &(dr, dc) in &COPRIME_NBHD_7 {
380        if dr > tr || dc > tc {
381            continue;
382        }
383        let sr = tr - dr;
384        let sc = tc - dc;
385        let src_idx = sr * ncols + sc;
386        if e[src_idx] == f64::INFINITY {
387            continue;
388        }
389        let cost = e[src_idx] + edge_cost(sr, sc, tr, tc);
390        if cost < e[idx] {
391            e[idx] = cost;
392            parent[idx] = src_idx as u32;
393        }
394    }
395}
396
397/// Shared DP grid fill + traceback using the coprime neighborhood.
398///
399/// `edge_cost(sr, sc, tr, tc)` returns the combined edge weight + penalty for
400/// a move from local (sr, sc) to local (tr, tc). Returns the raw local-index
401/// path from (0,0) to (nrows-1, ncols-1).
402fn dp_grid_solve<F>(nrows: usize, ncols: usize, edge_cost: F) -> Vec<(usize, usize)>
403where
404    F: Fn(usize, usize, usize, usize) -> f64,
405{
406    let mut e = vec![f64::INFINITY; nrows * ncols];
407    let mut parent = vec![u32::MAX; nrows * ncols];
408    e[0] = 0.0;
409
410    for tr in 0..nrows {
411        for tc in 0..ncols {
412            if tr == 0 && tc == 0 {
413                continue;
414            }
415            dp_relax_cell(&mut e, &mut parent, ncols, tr, tc, &edge_cost);
416        }
417    }
418
419    dp_traceback(&parent, nrows, ncols)
420}
421
422/// Convert a DP path (local row,col indices) to an interpolated+normalized gamma warp.
423fn dp_path_to_gamma(path: &[(usize, usize)], argvals: &[f64]) -> Vec<f64> {
424    let path_tc: Vec<f64> = path.iter().map(|&(_, c)| argvals[c]).collect();
425    let path_tr: Vec<f64> = path.iter().map(|&(r, _)| argvals[r]).collect();
426    let mut gamma: Vec<f64> = argvals
427        .iter()
428        .map(|&t| linear_interp(&path_tc, &path_tr, t))
429        .collect();
430    normalize_warp(&mut gamma, argvals);
431    gamma
432}
433
434/// Core DP alignment between two SRSFs on a grid.
435///
436/// Finds the optimal warping γ minimizing ‖q₁ - (q₂∘γ)√γ'‖².
437/// Uses fdasrvf's coprime neighborhood (nbhd_dim=7 → 35 move directions).
438/// SRSFs are L2-normalized before alignment (matching fdasrvf's `optimum.reparam`).
439pub(crate) fn dp_alignment_core(q1: &[f64], q2: &[f64], argvals: &[f64], lambda: f64) -> Vec<f64> {
440    let m = argvals.len();
441    if m < 2 {
442        return argvals.to_vec();
443    }
444
445    let norm1 = q1.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
446    let norm2 = q2.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
447    let q1n: Vec<f64> = q1.iter().map(|&v| v / norm1).collect();
448    let q2n: Vec<f64> = q2.iter().map(|&v| v / norm2).collect();
449
450    let path = dp_grid_solve(m, m, |sr, sc, tr, tc| {
451        dp_edge_weight(&q1n, &q2n, argvals, sc, tc, sr, tr)
452            + dp_lambda_penalty(argvals, sc, tc, sr, tr, lambda)
453    });
454
455    dp_path_to_gamma(&path, argvals)
456}
457
458// ─── Public Alignment Functions ─────────────────────────────────────────────
459
460/// Align curve f2 to curve f1 using the elastic framework.
461///
462/// Computes the optimal warping γ such that f2∘γ is as close as possible
463/// to f1 in the elastic (Fisher-Rao) metric.
464///
465/// # Arguments
466/// * `f1` — Target curve (length m)
467/// * `f2` — Curve to align (length m)
468/// * `argvals` — Evaluation points (length m)
469/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
470///
471/// # Returns
472/// [`AlignmentResult`] with warping function, aligned curve, and elastic distance.
473pub fn elastic_align_pair(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> AlignmentResult {
474    let m = f1.len();
475
476    // Build single-row FdMatrices for SRSF computation
477    let f1_mat = FdMatrix::from_slice(f1, 1, m).unwrap();
478    let f2_mat = FdMatrix::from_slice(f2, 1, m).unwrap();
479
480    let q1_mat = srsf_transform(&f1_mat, argvals);
481    let q2_mat = srsf_transform(&f2_mat, argvals);
482
483    let q1: Vec<f64> = q1_mat.row(0);
484    let q2: Vec<f64> = q2_mat.row(0);
485
486    // Find optimal warping via DP
487    let gamma = dp_alignment_core(&q1, &q2, argvals, lambda);
488
489    // Apply warping to f2
490    let f_aligned = reparameterize_curve(f2, argvals, &gamma);
491
492    // Compute elastic distance: L2 distance between q1 and aligned q2 SRSF
493    let f_aligned_mat = FdMatrix::from_slice(&f_aligned, 1, m).unwrap();
494    let q_aligned_mat = srsf_transform(&f_aligned_mat, argvals);
495    let q_aligned: Vec<f64> = q_aligned_mat.row(0);
496
497    let weights = simpsons_weights(argvals);
498    let distance = l2_distance(&q1, &q_aligned, &weights);
499
500    AlignmentResult {
501        gamma,
502        f_aligned,
503        distance,
504    }
505}
506
507/// Compute the elastic distance between two curves.
508///
509/// This is shorthand for aligning the pair and returning only the distance.
510///
511/// # Arguments
512/// * `f1` — First curve (length m)
513/// * `f2` — Second curve (length m)
514/// * `argvals` — Evaluation points (length m)
515/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
516pub fn elastic_distance(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> f64 {
517    elastic_align_pair(f1, f2, argvals, lambda).distance
518}
519
520/// Align all curves in `data` to a single target curve.
521///
522/// # Arguments
523/// * `data` — Functional data matrix (n × m)
524/// * `target` — Target curve to align to (length m)
525/// * `argvals` — Evaluation points (length m)
526/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
527///
528/// # Returns
529/// [`AlignmentSetResult`] with all warping functions, aligned curves, and distances.
530pub fn align_to_target(
531    data: &FdMatrix,
532    target: &[f64],
533    argvals: &[f64],
534    lambda: f64,
535) -> AlignmentSetResult {
536    let (n, m) = data.shape();
537
538    let results: Vec<AlignmentResult> = iter_maybe_parallel!(0..n)
539        .map(|i| {
540            let fi = data.row(i);
541            elastic_align_pair(target, &fi, argvals, lambda)
542        })
543        .collect();
544
545    let mut gammas = FdMatrix::zeros(n, m);
546    let mut aligned_data = FdMatrix::zeros(n, m);
547    let mut distances = Vec::with_capacity(n);
548
549    for (i, r) in results.into_iter().enumerate() {
550        for j in 0..m {
551            gammas[(i, j)] = r.gamma[j];
552            aligned_data[(i, j)] = r.f_aligned[j];
553        }
554        distances.push(r.distance);
555    }
556
557    AlignmentSetResult {
558        gammas,
559        aligned_data,
560        distances,
561    }
562}
563
564// ─── Distance Matrices ──────────────────────────────────────────────────────
565
566/// Compute the symmetric elastic distance matrix for a set of curves.
567///
568/// Uses upper-triangle computation with parallelism, following the
569/// `self_distance_matrix` pattern from `metric.rs`.
570///
571/// # Arguments
572/// * `data` — Functional data matrix (n × m)
573/// * `argvals` — Evaluation points (length m)
574/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
575///
576/// # Returns
577/// Symmetric n × n distance matrix.
578pub fn elastic_self_distance_matrix(data: &FdMatrix, argvals: &[f64], lambda: f64) -> FdMatrix {
579    let n = data.nrows();
580
581    let upper_vals: Vec<f64> = iter_maybe_parallel!(0..n)
582        .flat_map(|i| {
583            let fi = data.row(i);
584            ((i + 1)..n)
585                .map(|j| {
586                    let fj = data.row(j);
587                    elastic_distance(&fi, &fj, argvals, lambda)
588                })
589                .collect::<Vec<_>>()
590        })
591        .collect();
592
593    let mut dist = FdMatrix::zeros(n, n);
594    let mut idx = 0;
595    for i in 0..n {
596        for j in (i + 1)..n {
597            let d = upper_vals[idx];
598            dist[(i, j)] = d;
599            dist[(j, i)] = d;
600            idx += 1;
601        }
602    }
603    dist
604}
605
606/// Compute the elastic distance matrix between two sets of curves.
607///
608/// # Arguments
609/// * `data1` — First dataset (n1 × m)
610/// * `data2` — Second dataset (n2 × m)
611/// * `argvals` — Evaluation points (length m)
612/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
613///
614/// # Returns
615/// n1 × n2 distance matrix.
616pub fn elastic_cross_distance_matrix(
617    data1: &FdMatrix,
618    data2: &FdMatrix,
619    argvals: &[f64],
620    lambda: f64,
621) -> FdMatrix {
622    let n1 = data1.nrows();
623    let n2 = data2.nrows();
624
625    let vals: Vec<f64> = iter_maybe_parallel!(0..n1)
626        .flat_map(|i| {
627            let fi = data1.row(i);
628            (0..n2)
629                .map(|j| {
630                    let fj = data2.row(j);
631                    elastic_distance(&fi, &fj, argvals, lambda)
632                })
633                .collect::<Vec<_>>()
634        })
635        .collect();
636
637    let mut dist = FdMatrix::zeros(n1, n2);
638    for i in 0..n1 {
639        for j in 0..n2 {
640            dist[(i, j)] = vals[i * n2 + j];
641        }
642    }
643    dist
644}
645
646// ─── Phase-Amplitude Decomposition ──────────────────────────────────────────
647
648/// Result of elastic phase-amplitude decomposition.
649#[derive(Debug, Clone)]
650pub struct DecompositionResult {
651    /// Full alignment result.
652    pub alignment: AlignmentResult,
653    /// Amplitude distance: SRSF distance after alignment.
654    pub d_amplitude: f64,
655    /// Phase distance: geodesic distance of warp from identity.
656    pub d_phase: f64,
657}
658
659/// Perform elastic phase-amplitude decomposition of two curves.
660///
661/// Returns both the alignment result and the separate amplitude and phase distances.
662///
663/// # Arguments
664/// * `f1` — Target curve (length m)
665/// * `f2` — Curve to decompose against f1 (length m)
666/// * `argvals` — Evaluation points (length m)
667/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
668pub fn elastic_decomposition(
669    f1: &[f64],
670    f2: &[f64],
671    argvals: &[f64],
672    lambda: f64,
673) -> DecompositionResult {
674    let alignment = elastic_align_pair(f1, f2, argvals, lambda);
675    let d_amplitude = alignment.distance;
676    let d_phase = crate::warping::phase_distance(&alignment.gamma, argvals);
677    DecompositionResult {
678        alignment,
679        d_amplitude,
680        d_phase,
681    }
682}
683
684/// Compute the amplitude distance between two curves (= elastic distance after alignment).
685pub fn amplitude_distance(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> f64 {
686    elastic_distance(f1, f2, argvals, lambda)
687}
688
689/// Compute the phase distance between two curves (geodesic distance of optimal warp from identity).
690pub fn phase_distance_pair(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> f64 {
691    let alignment = elastic_align_pair(f1, f2, argvals, lambda);
692    crate::warping::phase_distance(&alignment.gamma, argvals)
693}
694
695/// Compute the symmetric phase distance matrix for a set of curves.
696pub fn phase_self_distance_matrix(data: &FdMatrix, argvals: &[f64], lambda: f64) -> FdMatrix {
697    let n = data.nrows();
698
699    let upper_vals: Vec<f64> = iter_maybe_parallel!(0..n)
700        .flat_map(|i| {
701            let fi = data.row(i);
702            ((i + 1)..n)
703                .map(|j| {
704                    let fj = data.row(j);
705                    phase_distance_pair(&fi, &fj, argvals, lambda)
706                })
707                .collect::<Vec<_>>()
708        })
709        .collect();
710
711    let mut dist = FdMatrix::zeros(n, n);
712    let mut idx = 0;
713    for i in 0..n {
714        for j in (i + 1)..n {
715            let d = upper_vals[idx];
716            dist[(i, j)] = d;
717            dist[(j, i)] = d;
718            idx += 1;
719        }
720    }
721    dist
722}
723
724/// Compute the symmetric amplitude distance matrix (= elastic self distance matrix).
725pub fn amplitude_self_distance_matrix(data: &FdMatrix, argvals: &[f64], lambda: f64) -> FdMatrix {
726    elastic_self_distance_matrix(data, argvals, lambda)
727}
728
729// ─── Karcher Mean ───────────────────────────────────────────────────────────
730
/// Relative change between two successive mean SRSFs.
///
/// Computes `‖q_new - q_old‖₂ / ‖q_old‖₂` with plain (unweighted, discrete) Euclidean
/// norms, the convergence metric used by R's fdasrvf `time_warping`. The denominator is
/// floored at 1e-10, and the zipped difference covers the shorter of the two slices.
fn relative_change(q_old: &[f64], q_new: &[f64]) -> f64 {
    let diff_sq: f64 = q_old.iter().zip(q_new).map(|(a, b)| (a - b).powi(2)).sum();
    let old_sq: f64 = q_old.iter().map(|v| v * v).sum();
    diff_sq.sqrt() / old_sq.sqrt().max(1e-10)
}
745
746/// Compute a single SRSF from a slice (single-row convenience).
747fn srsf_single(f: &[f64], argvals: &[f64]) -> Vec<f64> {
748    let m = f.len();
749    let mat = FdMatrix::from_slice(f, 1, m).unwrap();
750    let q_mat = srsf_transform(&mat, argvals);
751    q_mat.row(0)
752}
753
754/// Align a single SRSF q2 to q1 and return (gamma, aligned_q).
755fn align_srsf_pair(q1: &[f64], q2: &[f64], argvals: &[f64], lambda: f64) -> (Vec<f64>, Vec<f64>) {
756    let gamma = dp_alignment_core(q1, q2, argvals, lambda);
757
758    // Warp q2 by gamma and adjust by sqrt(gamma')
759    let q2_warped = reparameterize_curve(q2, argvals, &gamma);
760
761    // Compute gamma' via finite differences
762    let m = gamma.len();
763    let mut gamma_dot = vec![0.0; m];
764    gamma_dot[0] = (gamma[1] - gamma[0]) / (argvals[1] - argvals[0]);
765    for j in 1..(m - 1) {
766        gamma_dot[j] = (gamma[j + 1] - gamma[j - 1]) / (argvals[j + 1] - argvals[j - 1]);
767    }
768    gamma_dot[m - 1] = (gamma[m - 1] - gamma[m - 2]) / (argvals[m - 1] - argvals[m - 2]);
769
770    // q2_aligned = (q2 ∘ γ) * sqrt(γ')
771    let q2_aligned: Vec<f64> = q2_warped
772        .iter()
773        .zip(gamma_dot.iter())
774        .map(|(&q, &gd)| q * gd.max(0.0).sqrt())
775        .collect();
776
777    (gamma, q2_aligned)
778}
779
780/// Compute the Karcher (Fréchet) mean in the elastic metric.
781///
782/// Iteratively aligns all curves to the current mean estimate in SRSF space,
783/// computes the pointwise mean of aligned SRSFs, and reconstructs the mean curve.
784///
785/// # Arguments
786/// * `data` — Functional data matrix (n × m)
787/// * `argvals` — Evaluation points (length m)
788/// * `max_iter` — Maximum number of iterations
789/// * `tol` — Convergence tolerance for the SRSF mean
790///
791/// # Returns
792/// [`KarcherMeanResult`] with mean curve, warping functions, aligned data, and convergence info.
793///
794/// # Examples
795///
796/// ```
797/// use fdars_core::simulation::{sim_fundata, EFunType, EValType};
798/// use fdars_core::alignment::karcher_mean;
799///
800/// let t: Vec<f64> = (0..50).map(|i| i as f64 / 49.0).collect();
801/// let data = sim_fundata(20, &t, 3, EFunType::Fourier, EValType::Exponential, Some(42));
802///
803/// let result = karcher_mean(&data, &t, 20, 1e-4, 0.0);
804/// assert_eq!(result.mean.len(), 50);
805/// assert!(result.n_iter <= 20);
806/// ```
807/// Accumulate alignment results: store gammas and return the mean of aligned SRSFs.
808fn accumulate_alignments(
809    results: &[(Vec<f64>, Vec<f64>)],
810    gammas: &mut FdMatrix,
811    m: usize,
812    n: usize,
813) -> Vec<f64> {
814    let mut mu_q_new = vec![0.0; m];
815    for (i, (gamma, q_aligned)) in results.iter().enumerate() {
816        for j in 0..m {
817            gammas[(i, j)] = gamma[j];
818            mu_q_new[j] += q_aligned[j];
819        }
820    }
821    for j in 0..m {
822        mu_q_new[j] /= n as f64;
823    }
824    mu_q_new
825}
826
827/// Apply stored warps to original curves to produce aligned data.
828fn apply_stored_warps(data: &FdMatrix, gammas: &FdMatrix, argvals: &[f64]) -> FdMatrix {
829    let (n, m) = data.shape();
830    let mut aligned = FdMatrix::zeros(n, m);
831    for i in 0..n {
832        let fi = data.row(i);
833        let gamma: Vec<f64> = (0..m).map(|j| gammas[(i, j)]).collect();
834        let f_aligned = reparameterize_curve(&fi, argvals, &gamma);
835        for j in 0..m {
836            aligned[(i, j)] = f_aligned[j];
837        }
838    }
839    aligned
840}
841
842/// Select the SRSF closest to the pointwise mean as template. Returns (mu_q, mu_f).
843fn select_template(srsf_mat: &FdMatrix, data: &FdMatrix, argvals: &[f64]) -> (Vec<f64>, Vec<f64>) {
844    let (n, m) = srsf_mat.shape();
845    let mnq = mean_1d(srsf_mat);
846    let mut min_dist = f64::INFINITY;
847    let mut min_idx = 0;
848    for i in 0..n {
849        let dist_sq: f64 = (0..m).map(|j| (srsf_mat[(i, j)] - mnq[j]).powi(2)).sum();
850        if dist_sq < min_dist {
851            min_dist = dist_sq;
852            min_idx = i;
853        }
854    }
855    let _ = argvals; // kept for API consistency
856    (srsf_mat.row(min_idx), data.row(min_idx))
857}
858
859/// Pre-centering: align all curves to template, compute inverse mean warp, re-center.
860fn pre_center_template(
861    data: &FdMatrix,
862    mu_q: &[f64],
863    mu: &[f64],
864    argvals: &[f64],
865    lambda: f64,
866) -> (Vec<f64>, Vec<f64>) {
867    let (n, m) = data.shape();
868    let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
869        .map(|i| {
870            let fi = data.row(i);
871            let qi = srsf_single(&fi, argvals);
872            align_srsf_pair(mu_q, &qi, argvals, lambda)
873        })
874        .collect();
875
876    let mut init_gammas = FdMatrix::zeros(n, m);
877    for (i, (gamma, _)) in align_results.iter().enumerate() {
878        for j in 0..m {
879            init_gammas[(i, j)] = gamma[j];
880        }
881    }
882
883    let gam_inv = sqrt_mean_inverse(&init_gammas, argvals);
884    let mu_new = reparameterize_curve(mu, argvals, &gam_inv);
885    let mu_q_new = srsf_single(&mu_new, argvals);
886    (mu_q_new, mu_new)
887}
888
889/// Post-convergence centering: center mean SRSF and warps via SqrtMeanInverse.
890fn post_center_results(
891    data: &FdMatrix,
892    mu_q: &[f64],
893    final_gammas: &mut FdMatrix,
894    argvals: &[f64],
895) -> (Vec<f64>, Vec<f64>, FdMatrix) {
896    let (n, m) = data.shape();
897    let gam_inv = sqrt_mean_inverse(final_gammas, argvals);
898    let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
899    let gam_inv_dev = gradient_uniform(&gam_inv, h);
900
901    let mu_q_warped = reparameterize_curve(mu_q, argvals, &gam_inv);
902    let mu_q_centered: Vec<f64> = mu_q_warped
903        .iter()
904        .zip(gam_inv_dev.iter())
905        .map(|(&q, &gd)| q * gd.max(0.0).sqrt())
906        .collect();
907
908    for i in 0..n {
909        let gam_i: Vec<f64> = (0..m).map(|j| final_gammas[(i, j)]).collect();
910        let gam_centered = reparameterize_curve(&gam_i, argvals, &gam_inv);
911        for j in 0..m {
912            final_gammas[(i, j)] = gam_centered[j];
913        }
914    }
915
916    let initial_mean = mean_1d(data);
917    let mu = srsf_inverse(&mu_q_centered, argvals, initial_mean[0]);
918    let final_aligned = apply_stored_warps(data, final_gammas, argvals);
919    (mu, mu_q_centered, final_aligned)
920}
921
/// Keep every `factor`-th sample of `signal` (and the matching `argvals`
/// entries), always including the first and the last sample.
///
/// A `factor` of 0 or 1, or a signal with at most two samples, is returned
/// unchanged. Returns `(signal_coarse, argvals_coarse)`.
fn downsample_uniform(signal: &[f64], argvals: &[f64], factor: usize) -> (Vec<f64>, Vec<f64>) {
    let m = signal.len();
    if m <= 2 || factor <= 1 {
        return (signal.to_vec(), argvals.to_vec());
    }

    let last = m - 1;
    let mut picks: Vec<usize> = (0..m).step_by(factor).collect();
    // The stride may not land on the final sample; append it explicitly.
    if last % factor != 0 {
        picks.push(last);
    }

    picks.iter().map(|&k| (signal[k], argvals[k])).unzip()
}
941
942/// Upsample signal from coarse grid to fine grid via linear interpolation.
943fn upsample_to_fine(coarse: &[f64], argvals_coarse: &[f64], argvals_fine: &[f64]) -> Vec<f64> {
944    argvals_fine
945        .iter()
946        .map(|&t| linear_interp(argvals_coarse, coarse, t))
947        .collect()
948}
949
950pub fn karcher_mean(
951    data: &FdMatrix,
952    argvals: &[f64],
953    max_iter: usize,
954    tol: f64,
955    lambda: f64,
956) -> KarcherMeanResult {
957    let (n, m) = data.shape();
958
959    let srsf_mat = srsf_transform(data, argvals);
960    let (mut mu_q, mu) = select_template(&srsf_mat, data, argvals);
961    let (mu_q_c, mu_c) = pre_center_template(data, &mu_q, &mu, argvals, lambda);
962    mu_q = mu_q_c;
963    let mut mu = mu_c;
964
965    let mut converged = false;
966    let mut n_iter = 0;
967    let mut final_gammas = FdMatrix::zeros(n, m);
968
969    // Coarse-to-fine strategy: run initial iterations on downsampled grid
970    // Only worthwhile for large grids with enough iterations to split
971    let coarse_factor = if m > 50 && max_iter >= 10 { 4 } else { 1 };
972    let coarse_iters = if coarse_factor > 1 { max_iter / 2 } else { 0 };
973    let fine_iters = max_iter - coarse_iters;
974
975    // Phase 1: coarse iterations
976    if coarse_iters > 0 {
977        let (mu_q_coarse, argvals_coarse) = downsample_uniform(&mu_q, argvals, coarse_factor);
978        let m_c = argvals_coarse.len();
979        let mut mu_q_c = mu_q_coarse;
980
981        // Downsample all curves to coarse grid
982        let data_coarse: Vec<Vec<f64>> = (0..n)
983            .map(|i| {
984                let row = data.row(i);
985                downsample_uniform(&row, argvals, coarse_factor).0
986            })
987            .collect();
988
989        let mut coarse_gammas = FdMatrix::zeros(n, m_c);
990
991        for iter in 0..coarse_iters {
992            n_iter = iter + 1;
993
994            let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
995                .map(|i| {
996                    let qi = srsf_single(&data_coarse[i], &argvals_coarse);
997                    align_srsf_pair(&mu_q_c, &qi, &argvals_coarse, lambda)
998                })
999                .collect();
1000
1001            let mu_q_new = accumulate_alignments(&align_results, &mut coarse_gammas, m_c, n);
1002
1003            let rel = relative_change(&mu_q_c, &mu_q_new);
1004            if rel < tol {
1005                converged = true;
1006                mu_q_c = mu_q_new;
1007                break;
1008            }
1009
1010            mu_q_c = mu_q_new;
1011        }
1012
1013        // Upsample coarse mu_q to fine grid
1014        mu_q = upsample_to_fine(&mu_q_c, &argvals_coarse, argvals);
1015        mu = srsf_inverse(&mu_q, argvals, mu[0]);
1016    }
1017
1018    // Phase 2: fine iterations (or all iterations if m <= 50)
1019    if fine_iters > 0 {
1020        converged = false; // Fine phase must independently converge
1021    }
1022    let fine_start = n_iter;
1023    for iter in 0..fine_iters {
1024        n_iter = fine_start + iter + 1;
1025
1026        let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
1027            .map(|i| {
1028                let fi = data.row(i);
1029                let qi = srsf_single(&fi, argvals);
1030                align_srsf_pair(&mu_q, &qi, argvals, lambda)
1031            })
1032            .collect();
1033
1034        let mu_q_new = accumulate_alignments(&align_results, &mut final_gammas, m, n);
1035
1036        let rel = relative_change(&mu_q, &mu_q_new);
1037        if rel < tol {
1038            converged = true;
1039            mu_q = mu_q_new;
1040            break;
1041        }
1042
1043        mu_q = mu_q_new;
1044        mu = srsf_inverse(&mu_q, argvals, mu[0]);
1045    }
1046
1047    // If coarse converged but no fine iterations ran, do one fine pass for final_gammas
1048    if converged && fine_start > 0 {
1049        let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
1050            .map(|i| {
1051                let fi = data.row(i);
1052                let qi = srsf_single(&fi, argvals);
1053                align_srsf_pair(&mu_q, &qi, argvals, lambda)
1054            })
1055            .collect();
1056        let mu_q_new = accumulate_alignments(&align_results, &mut final_gammas, m, n);
1057        mu_q = mu_q_new;
1058    }
1059
1060    let (mu_final, mu_q_final, final_aligned) =
1061        post_center_results(data, &mu_q, &mut final_gammas, argvals);
1062
1063    KarcherMeanResult {
1064        mean: mu_final,
1065        mean_srsf: mu_q_final,
1066        gammas: final_gammas,
1067        aligned_data: final_aligned,
1068        n_iter,
1069        converged,
1070        aligned_srsfs: None,
1071    }
1072}
1073
1074// ─── TSRVF (Transported SRSF) ────────────────────────────────────────────────
1075// Maps aligned SRSFs to the tangent space of the Karcher mean on the Hilbert
1076// sphere. Tangent vectors live in a standard Euclidean space, enabling PCA,
1077// regression, and clustering on elastic-aligned curves.
1078
/// Result of the TSRVF transform.
///
/// Produced by [`tsrvf_transform`], [`tsrvf_from_alignment`] and their
/// `_with_method` variants; consumed by [`tsrvf_inverse`].
#[derive(Debug, Clone)]
pub struct TsrvfResult {
    /// Tangent vectors in Euclidean space (n × m), one row per curve, taken at
    /// the unit-normalized mean SRSF. A row is all zeros when the curve's (or
    /// the mean's) SRSF norm is below 1e-10.
    pub tangent_vectors: FdMatrix,
    /// Karcher mean curve (length m), copied from the Karcher mean result.
    pub mean: Vec<f64>,
    /// Gaussian-kernel-smoothed SRSF of the Karcher mean (length m). This is
    /// the smoothed version, not the raw mean SRSF of the Karcher result.
    pub mean_srsf: Vec<f64>,
    /// L2 norm of `mean_srsf`, computed on a uniform grid rescaled to [0, 1]
    /// rather than on `argvals`.
    pub mean_srsf_norm: f64,
    /// Per-curve norms of the smoothed aligned SRSFs (length n), on the same
    /// [0, 1] grid; used to rescale curves back off the unit sphere.
    pub srsf_norms: Vec<f64>,
    /// Per-curve initial values f_i(0) for SRSF inverse reconstruction (length n),
    /// read from the first column of the aligned data.
    pub initial_values: Vec<f64>,
    /// Warping functions from Karcher mean computation (n × m).
    pub gammas: FdMatrix,
    /// Whether the Karcher mean converged.
    pub converged: bool,
}
1099
1100/// Full TSRVF pipeline: compute Karcher mean, then transport SRSFs to tangent space.
1101///
1102/// # Arguments
1103/// * `data` — Functional data matrix (n × m)
1104/// * `argvals` — Evaluation points (length m)
1105/// * `max_iter` — Maximum Karcher mean iterations
1106/// * `tol` — Convergence tolerance for Karcher mean
1107///
1108/// # Returns
1109/// [`TsrvfResult`] containing tangent vectors and associated metadata.
1110pub fn tsrvf_transform(
1111    data: &FdMatrix,
1112    argvals: &[f64],
1113    max_iter: usize,
1114    tol: f64,
1115    lambda: f64,
1116) -> TsrvfResult {
1117    let karcher = karcher_mean(data, argvals, max_iter, tol, lambda);
1118    tsrvf_from_alignment(&karcher, argvals)
1119}
1120
1121/// Smooth aligned SRSFs to remove DP kink artifacts before TSRVF computation.
1122///
1123/// Uses Nadaraya-Watson kernel smoothing (Gaussian, bandwidth = 2 grid spacings)
1124/// on each SRSF row. This removes the derivative spikes from DP warp kinks
1125/// without affecting alignment results or the Karcher mean.
1126fn smooth_aligned_srsfs(srsf: &FdMatrix, m: usize) -> FdMatrix {
1127    let n = srsf.nrows();
1128    let time: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
1129    let bandwidth = 2.0 / (m - 1) as f64;
1130
1131    let mut smoothed = FdMatrix::zeros(n, m);
1132    for i in 0..n {
1133        let qi = srsf.row(i);
1134        let qi_smooth = nadaraya_watson(&time, &qi, &time, bandwidth, "gaussian");
1135        for j in 0..m {
1136            smoothed[(i, j)] = qi_smooth[j];
1137        }
1138    }
1139    smoothed
1140}
1141
1142/// Compute TSRVF from a pre-computed Karcher mean alignment.
1143///
1144/// Avoids re-running the expensive Karcher mean computation when the alignment
1145/// has already been computed.
1146///
1147/// # Arguments
1148/// * `karcher` — Pre-computed Karcher mean result
1149/// * `argvals` — Evaluation points (length m)
1150///
1151/// # Returns
1152/// [`TsrvfResult`] containing tangent vectors and associated metadata.
1153pub fn tsrvf_from_alignment(karcher: &KarcherMeanResult, argvals: &[f64]) -> TsrvfResult {
1154    let (n, m) = karcher.aligned_data.shape();
1155
1156    // Step 1: Compute SRSFs of aligned data
1157    let aligned_srsf = srsf_transform(&karcher.aligned_data, argvals);
1158
1159    // Step 1b: Smooth aligned SRSFs to remove DP kink artifacts.
1160    //
1161    // DP alignment produces piecewise-linear warps with kinks at grid transitions.
1162    // When curves are reparameterized by these warps, the kinks propagate into the
1163    // aligned curves' derivatives (SRSFs), creating spikes that dominate TSRVF
1164    // tangent vectors and PCA.
1165    //
1166    // R's fdasrvf does not smooth here and suffers from the same spike artifacts.
1167    // Python's fdasrsf mitigates this via spline smoothing (s=1e-4) in SqrtMean.
1168    // We smooth the aligned SRSFs before tangent vector computation — this only
1169    // affects TSRVF output and does not change the alignment or Karcher mean.
1170    let aligned_srsf = smooth_aligned_srsfs(&aligned_srsf, m);
1171
1172    // Step 2: Smooth and normalize mean SRSF to unit sphere.
1173    // The mean SRSF must be smoothed consistently with the aligned SRSFs
1174    // so that a single curve (which IS the mean) produces a zero tangent vector.
1175    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
1176    let bandwidth = 2.0 / (m - 1) as f64;
1177    let mean_srsf_smooth = nadaraya_watson(&time, &karcher.mean_srsf, &time, bandwidth, "gaussian");
1178    let mean_norm = l2_norm_l2(&mean_srsf_smooth, &time);
1179
1180    let mu_unit: Vec<f64> = if mean_norm > 1e-10 {
1181        mean_srsf_smooth.iter().map(|&q| q / mean_norm).collect()
1182    } else {
1183        vec![0.0; m]
1184    };
1185
1186    // Step 3: For each aligned curve, compute tangent vector via inverse exponential map
1187    let srsf_norms: Vec<f64> = iter_maybe_parallel!(0..n)
1188        .map(|i| {
1189            let qi = aligned_srsf.row(i);
1190            l2_norm_l2(&qi, &time)
1191        })
1192        .collect();
1193
1194    let tangent_data: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
1195        .map(|i| {
1196            let qi = aligned_srsf.row(i);
1197            let qi_norm = srsf_norms[i];
1198
1199            if qi_norm < 1e-10 || mean_norm < 1e-10 {
1200                return vec![0.0; m];
1201            }
1202
1203            // Normalize to unit sphere
1204            let qi_unit: Vec<f64> = qi.iter().map(|&q| q / qi_norm).collect();
1205
1206            // Shooting vector from mu_unit to qi_unit
1207            inv_exp_map_sphere(&mu_unit, &qi_unit, &time)
1208        })
1209        .collect();
1210
1211    // Assemble tangent vectors into FdMatrix
1212    let mut tangent_vectors = FdMatrix::zeros(n, m);
1213    for i in 0..n {
1214        for j in 0..m {
1215            tangent_vectors[(i, j)] = tangent_data[i][j];
1216        }
1217    }
1218
1219    // Store per-curve initial values for SRSF inverse reconstruction.
1220    // Warping preserves f_i(0) since gamma(0) = 0.
1221    let initial_values: Vec<f64> = (0..n).map(|i| karcher.aligned_data[(i, 0)]).collect();
1222
1223    TsrvfResult {
1224        tangent_vectors,
1225        mean: karcher.mean.clone(),
1226        mean_srsf: mean_srsf_smooth,
1227        mean_srsf_norm: mean_norm,
1228        srsf_norms,
1229        initial_values,
1230        gammas: karcher.gammas.clone(),
1231        converged: karcher.converged,
1232    }
1233}
1234
1235/// Reconstruct aligned curves from TSRVF tangent vectors.
1236///
1237/// Inverts the TSRVF transform: maps tangent vectors back to the Hilbert sphere
1238/// via the exponential map, rescales, and reconstructs curves via SRSF inverse.
1239///
1240/// # Arguments
1241/// * `tsrvf` — TSRVF result from [`tsrvf_transform`] or [`tsrvf_from_alignment`]
1242/// * `argvals` — Evaluation points (length m)
1243///
1244/// # Returns
1245/// FdMatrix of reconstructed aligned curves (n × m).
1246pub fn tsrvf_inverse(tsrvf: &TsrvfResult, argvals: &[f64]) -> FdMatrix {
1247    let (n, m) = tsrvf.tangent_vectors.shape();
1248
1249    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
1250
1251    // Normalize mean SRSF to unit sphere
1252    let mu_unit: Vec<f64> = if tsrvf.mean_srsf_norm > 1e-10 {
1253        tsrvf
1254            .mean_srsf
1255            .iter()
1256            .map(|&q| q / tsrvf.mean_srsf_norm)
1257            .collect()
1258    } else {
1259        vec![0.0; m]
1260    };
1261
1262    let curves: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
1263        .map(|i| {
1264            let vi = tsrvf.tangent_vectors.row(i);
1265
1266            // Map back to sphere: exp_map(mu_unit, v_i)
1267            let qi_unit = exp_map_sphere(&mu_unit, &vi, &time);
1268
1269            // Rescale by original norm
1270            let qi: Vec<f64> = qi_unit.iter().map(|&q| q * tsrvf.srsf_norms[i]).collect();
1271
1272            // Reconstruct curve from SRSF using per-curve initial value
1273            srsf_inverse(&qi, argvals, tsrvf.initial_values[i])
1274        })
1275        .collect();
1276
1277    let mut result = FdMatrix::zeros(n, m);
1278    for i in 0..n {
1279        for j in 0..m {
1280            result[(i, j)] = curves[i][j];
1281        }
1282    }
1283    result
1284}
1285
1286// ─── Parallel Transport Variants ─────────────────────────────────────────────
1287
/// Method for transporting tangent vectors on the Hilbert sphere.
///
/// Selects how [`tsrvf_from_alignment_with_method`] (and
/// [`tsrvf_transform_with_method`]) map each aligned SRSF into the tangent
/// space at the mean SRSF.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TransportMethod {
    /// Inverse exponential map (log map) — default, matches existing TSRVF behavior.
    /// Delegates to [`tsrvf_from_alignment`].
    #[default]
    LogMap,
    /// Schild's ladder approximation to parallel transport.
    SchildsLadder,
    /// Pole ladder approximation to parallel transport.
    PoleLadder,
}
1299
1300/// Schild's ladder parallel transport of vector `v` from `from` to `to` on the sphere.
1301fn parallel_transport_schilds(v: &[f64], from: &[f64], to: &[f64], time: &[f64]) -> Vec<f64> {
1302    use crate::warping::{exp_map_sphere, inv_exp_map_sphere};
1303
1304    let v_norm = crate::warping::l2_norm_l2(v, time);
1305    if v_norm < 1e-10 {
1306        return vec![0.0; v.len()];
1307    }
1308
1309    // endpoint = exp_from(v)
1310    let endpoint = exp_map_sphere(from, v, time);
1311
1312    // midpoint_v = log_to(endpoint) — vector at `to` pointing toward endpoint
1313    let log_to_ep = inv_exp_map_sphere(to, &endpoint, time);
1314
1315    // midpoint = exp_to(0.5 * log_to_ep)
1316    let half_log: Vec<f64> = log_to_ep.iter().map(|&x| 0.5 * x).collect();
1317    let midpoint = exp_map_sphere(to, &half_log, time);
1318
1319    // transported = 2 * log_to(midpoint)
1320    let log_to_mid = inv_exp_map_sphere(to, &midpoint, time);
1321    log_to_mid.iter().map(|&x| 2.0 * x).collect()
1322}
1323
1324/// Pole ladder parallel transport of vector `v` from `from` to `to` on the sphere.
1325fn parallel_transport_pole(v: &[f64], from: &[f64], to: &[f64], time: &[f64]) -> Vec<f64> {
1326    use crate::warping::{exp_map_sphere, inv_exp_map_sphere};
1327
1328    let v_norm = crate::warping::l2_norm_l2(v, time);
1329    if v_norm < 1e-10 {
1330        return vec![0.0; v.len()];
1331    }
1332
1333    // pole = exp_from(-v)
1334    let neg_v: Vec<f64> = v.iter().map(|&x| -x).collect();
1335    let pole = exp_map_sphere(from, &neg_v, time);
1336
1337    // midpoint_v = log_to(pole)
1338    let log_to_pole = inv_exp_map_sphere(to, &pole, time);
1339
1340    // midpoint = exp_to(0.5 * log_to_pole)
1341    let half_log: Vec<f64> = log_to_pole.iter().map(|&x| 0.5 * x).collect();
1342    let midpoint = exp_map_sphere(to, &half_log, time);
1343
1344    // transported = -2 * log_to(midpoint)
1345    let log_to_mid = inv_exp_map_sphere(to, &midpoint, time);
1346    log_to_mid.iter().map(|&x| -2.0 * x).collect()
1347}
1348
1349/// Full TSRVF pipeline with configurable transport method.
1350///
1351/// Like [`tsrvf_transform`] but allows choosing the parallel transport method.
1352pub fn tsrvf_transform_with_method(
1353    data: &FdMatrix,
1354    argvals: &[f64],
1355    max_iter: usize,
1356    tol: f64,
1357    lambda: f64,
1358    method: TransportMethod,
1359) -> TsrvfResult {
1360    let karcher = karcher_mean(data, argvals, max_iter, tol, lambda);
1361    tsrvf_from_alignment_with_method(&karcher, argvals, method)
1362}
1363
1364/// Compute TSRVF from a pre-computed Karcher mean with configurable transport.
1365///
1366/// - [`TransportMethod::LogMap`]: Uses `inv_exp_map(mu, qi)` directly (standard TSRVF).
1367/// - [`TransportMethod::SchildsLadder`]: Computes `v = -log_qi(mu)`, then transports
1368///   via Schild's ladder from qi to mu.
1369/// - [`TransportMethod::PoleLadder`]: Same but via pole ladder.
1370pub fn tsrvf_from_alignment_with_method(
1371    karcher: &KarcherMeanResult,
1372    argvals: &[f64],
1373    method: TransportMethod,
1374) -> TsrvfResult {
1375    if method == TransportMethod::LogMap {
1376        return tsrvf_from_alignment(karcher, argvals);
1377    }
1378
1379    let (n, m) = karcher.aligned_data.shape();
1380    let aligned_srsf = srsf_transform(&karcher.aligned_data, argvals);
1381    let aligned_srsf = smooth_aligned_srsfs(&aligned_srsf, m);
1382    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
1383    let bandwidth = 2.0 / (m - 1) as f64;
1384    let mean_srsf_smooth = nadaraya_watson(&time, &karcher.mean_srsf, &time, bandwidth, "gaussian");
1385    let mean_norm = crate::warping::l2_norm_l2(&mean_srsf_smooth, &time);
1386
1387    let mu_unit: Vec<f64> = if mean_norm > 1e-10 {
1388        mean_srsf_smooth.iter().map(|&q| q / mean_norm).collect()
1389    } else {
1390        vec![0.0; m]
1391    };
1392
1393    let srsf_norms: Vec<f64> = iter_maybe_parallel!(0..n)
1394        .map(|i| {
1395            let qi = aligned_srsf.row(i);
1396            crate::warping::l2_norm_l2(&qi, &time)
1397        })
1398        .collect();
1399
1400    let tangent_data: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
1401        .map(|i| {
1402            let qi = aligned_srsf.row(i);
1403            let qi_norm = srsf_norms[i];
1404
1405            if qi_norm < 1e-10 || mean_norm < 1e-10 {
1406                return vec![0.0; m];
1407            }
1408
1409            let qi_unit: Vec<f64> = qi.iter().map(|&q| q / qi_norm).collect();
1410
1411            // Compute v = -log_qi(mu) — vector at qi pointing away from mu
1412            let v_at_qi = inv_exp_map_sphere(&qi_unit, &mu_unit, &time);
1413            let neg_v: Vec<f64> = v_at_qi.iter().map(|&x| -x).collect();
1414
1415            // Transport from qi to mu
1416            match method {
1417                TransportMethod::SchildsLadder => {
1418                    parallel_transport_schilds(&neg_v, &qi_unit, &mu_unit, &time)
1419                }
1420                TransportMethod::PoleLadder => {
1421                    parallel_transport_pole(&neg_v, &qi_unit, &mu_unit, &time)
1422                }
1423                TransportMethod::LogMap => unreachable!(),
1424            }
1425        })
1426        .collect();
1427
1428    let mut tangent_vectors = FdMatrix::zeros(n, m);
1429    for i in 0..n {
1430        for j in 0..m {
1431            tangent_vectors[(i, j)] = tangent_data[i][j];
1432        }
1433    }
1434
1435    let initial_values: Vec<f64> = (0..n).map(|i| karcher.aligned_data[(i, 0)]).collect();
1436
1437    TsrvfResult {
1438        tangent_vectors,
1439        mean: karcher.mean.clone(),
1440        mean_srsf: mean_srsf_smooth,
1441        mean_srsf_norm: mean_norm,
1442        srsf_norms,
1443        initial_values,
1444        gammas: karcher.gammas.clone(),
1445        converged: karcher.converged,
1446    }
1447}
1448
1449// ─── Alignment Quality Metrics ───────────────────────────────────────────────
1450
/// Comprehensive alignment quality assessment.
///
/// Produced by [`alignment_quality`]. The total and amplitude variances are
/// squared weighted-L2 distances computed with Simpson quadrature weights over
/// `argvals`. Per-time-point variances divide by n (population variance).
#[derive(Debug, Clone)]
pub struct AlignmentQuality {
    /// Per-curve geodesic distance from warp to identity (see [`warp_complexity`]).
    pub warp_complexity: Vec<f64>,
    /// Mean warp complexity over all curves.
    pub mean_warp_complexity: f64,
    /// Per-curve bending energy ∫(γ'')² dt (see [`warp_smoothness`]).
    pub warp_smoothness: Vec<f64>,
    /// Mean warp smoothness (bending energy) over all curves.
    pub mean_warp_smoothness: f64,
    /// Total variance: (1/n) Σ ∫(f_i - mean_orig)² dt.
    pub total_variance: f64,
    /// Amplitude variance: (1/n) Σ ∫(f_i^aligned - mean_aligned)² dt.
    pub amplitude_variance: f64,
    /// Phase variance: total - amplitude (clamped ≥ 0).
    pub phase_variance: f64,
    /// Phase-to-total variance ratio, `phase_variance / total_variance`
    /// (0 when the total variance is ≤ 1e-10). Despite the field name, the
    /// denominator is the total variance, not the amplitude variance.
    pub phase_amplitude_ratio: f64,
    /// Pointwise ratio: aligned_var / orig_var per time point
    /// (1.0 where the original variance is ≤ 1e-15).
    pub pointwise_variance_ratio: Vec<f64>,
    /// Mean of `pointwise_variance_ratio` over the grid. Values below 1 mean
    /// alignment reduced the cross-sectional variance; lower means more reduction.
    pub mean_variance_reduction: f64,
}
1475
1476/// Compute warp complexity: geodesic distance from a warp to the identity.
1477///
1478/// This is `arccos(⟨ψ, ψ_id⟩)` on the Hilbert sphere.
1479pub fn warp_complexity(gamma: &[f64], argvals: &[f64]) -> f64 {
1480    crate::warping::phase_distance(gamma, argvals)
1481}
1482
1483/// Compute warp smoothness (bending energy): ∫(γ'')² dt.
1484pub fn warp_smoothness(gamma: &[f64], argvals: &[f64]) -> f64 {
1485    let m = gamma.len();
1486    if m < 3 {
1487        return 0.0;
1488    }
1489
1490    let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
1491    let gam_prime = gradient_uniform(gamma, h);
1492    let gam_pprime = gradient_uniform(&gam_prime, h);
1493
1494    let integrand: Vec<f64> = gam_pprime.iter().map(|&g| g * g).collect();
1495    crate::helpers::trapz(&integrand, argvals)
1496}
1497
1498/// Compute comprehensive alignment quality metrics.
1499///
1500/// # Arguments
1501/// * `data` — Original functional data (n × m)
1502/// * `karcher` — Pre-computed Karcher mean result
1503/// * `argvals` — Evaluation points (length m)
1504pub fn alignment_quality(
1505    data: &FdMatrix,
1506    karcher: &KarcherMeanResult,
1507    argvals: &[f64],
1508) -> AlignmentQuality {
1509    let (n, m) = data.shape();
1510    let weights = simpsons_weights(argvals);
1511
1512    // Per-curve warp complexity and smoothness
1513    let wc: Vec<f64> = (0..n)
1514        .map(|i| {
1515            let gamma: Vec<f64> = (0..m).map(|j| karcher.gammas[(i, j)]).collect();
1516            warp_complexity(&gamma, argvals)
1517        })
1518        .collect();
1519    let ws: Vec<f64> = (0..n)
1520        .map(|i| {
1521            let gamma: Vec<f64> = (0..m).map(|j| karcher.gammas[(i, j)]).collect();
1522            warp_smoothness(&gamma, argvals)
1523        })
1524        .collect();
1525
1526    let mean_wc = wc.iter().sum::<f64>() / n as f64;
1527    let mean_ws = ws.iter().sum::<f64>() / n as f64;
1528
1529    // Compute original mean
1530    let orig_mean = crate::fdata::mean_1d(data);
1531
1532    // Total variance
1533    let total_var: f64 = (0..n)
1534        .map(|i| {
1535            let fi = data.row(i);
1536            let d = l2_distance(&fi, &orig_mean, &weights);
1537            d * d
1538        })
1539        .sum::<f64>()
1540        / n as f64;
1541
1542    // Aligned mean
1543    let aligned_mean = crate::fdata::mean_1d(&karcher.aligned_data);
1544
1545    // Amplitude variance
1546    let amp_var: f64 = (0..n)
1547        .map(|i| {
1548            let fi = karcher.aligned_data.row(i);
1549            let d = l2_distance(&fi, &aligned_mean, &weights);
1550            d * d
1551        })
1552        .sum::<f64>()
1553        / n as f64;
1554
1555    let phase_var = (total_var - amp_var).max(0.0);
1556    let ratio = if total_var > 1e-10 {
1557        phase_var / total_var
1558    } else {
1559        0.0
1560    };
1561
1562    // Pointwise variance ratio
1563    let mut pw_ratio = vec![0.0; m];
1564    for j in 0..m {
1565        let col_orig = data.column(j);
1566        let mean_orig = col_orig.iter().sum::<f64>() / n as f64;
1567        let var_orig: f64 = col_orig
1568            .iter()
1569            .map(|&v| (v - mean_orig).powi(2))
1570            .sum::<f64>()
1571            / n as f64;
1572
1573        let col_aligned = karcher.aligned_data.column(j);
1574        let mean_aligned = col_aligned.iter().sum::<f64>() / n as f64;
1575        let var_aligned: f64 = col_aligned
1576            .iter()
1577            .map(|&v| (v - mean_aligned).powi(2))
1578            .sum::<f64>()
1579            / n as f64;
1580
1581        pw_ratio[j] = if var_orig > 1e-15 {
1582            var_aligned / var_orig
1583        } else {
1584            1.0
1585        };
1586    }
1587
1588    let mean_vr = pw_ratio.iter().sum::<f64>() / m as f64;
1589
1590    AlignmentQuality {
1591        warp_complexity: wc,
1592        mean_warp_complexity: mean_wc,
1593        warp_smoothness: ws,
1594        mean_warp_smoothness: mean_ws,
1595        total_variance: total_var,
1596        amplitude_variance: amp_var,
1597        phase_variance: phase_var,
1598        phase_amplitude_ratio: ratio,
1599        pointwise_variance_ratio: pw_ratio,
1600        mean_variance_reduction: mean_vr,
1601    }
1602}
1603
/// Generate triplet indices (i, j, k) with i < j < k in lexicographic order,
/// capped at `max_triplets` (0 = all).
///
/// Works for any `n`: fewer than three items simply yield no triplets.
fn triplet_indices(n: usize, max_triplets: usize) -> Vec<(usize, usize, usize)> {
    // `take` already stops at the natural end of the enumeration, so the
    // C(n, 3) count is not needed. Computing it as n*(n-1)*(n-2) underflows
    // for n < 2 and overflows for very large n.
    let cap = if max_triplets > 0 {
        max_triplets
    } else {
        usize::MAX
    };
    (0..n)
        .flat_map(|i| ((i + 1)..n).flat_map(move |j| ((j + 1)..n).map(move |k| (i, j, k))))
        .take(cap)
        .collect()
}
1617
1618/// Compute the warp deviation for one triplet: ‖γ_ij∘γ_jk − γ_ik‖_L2.
1619fn triplet_warp_deviation(
1620    data: &FdMatrix,
1621    argvals: &[f64],
1622    weights: &[f64],
1623    i: usize,
1624    j: usize,
1625    k: usize,
1626    lambda: f64,
1627) -> f64 {
1628    let fi = data.row(i);
1629    let fj = data.row(j);
1630    let fk = data.row(k);
1631    let rij = elastic_align_pair(&fi, &fj, argvals, lambda);
1632    let rjk = elastic_align_pair(&fj, &fk, argvals, lambda);
1633    let rik = elastic_align_pair(&fi, &fk, argvals, lambda);
1634    let composed = compose_warps(&rij.gamma, &rjk.gamma, argvals);
1635    l2_distance(&composed, &rik.gamma, weights)
1636}
1637
1638/// Measure pairwise alignment consistency via triplet checks.
1639///
1640/// For triplets (i,j,k), checks `γ_ij ∘ γ_jk ≈ γ_ik` by measuring the L2
1641/// deviation of the composed warp from the direct warp.
1642///
1643/// # Arguments
1644/// * `data` — Functional data (n × m)
1645/// * `argvals` — Evaluation points (length m)
1646/// * `lambda` — Penalty weight
1647/// * `max_triplets` — Maximum number of triplets to check (0 = all)
1648pub fn pairwise_consistency(
1649    data: &FdMatrix,
1650    argvals: &[f64],
1651    lambda: f64,
1652    max_triplets: usize,
1653) -> f64 {
1654    let n = data.nrows();
1655    if n < 3 {
1656        return 0.0;
1657    }
1658
1659    let weights = simpsons_weights(argvals);
1660    let triplets = triplet_indices(n, max_triplets);
1661    if triplets.is_empty() {
1662        return 0.0;
1663    }
1664
1665    let total_dev: f64 = triplets
1666        .iter()
1667        .map(|&(i, j, k)| triplet_warp_deviation(data, argvals, &weights, i, j, k, lambda))
1668        .sum();
1669    total_dev / triplets.len() as f64
1670}
1671
1672// ─── Landmark-Constrained Elastic Alignment ────────────────────────────────
1673
/// Result of landmark-constrained elastic alignment.
///
/// Returned by `elastic_align_pair_constrained`. When no landmark pairs are
/// supplied, the fields come from the unconstrained `elastic_align_pair` and
/// `enforced_landmarks` is empty.
#[derive(Debug, Clone)]
pub struct ConstrainedAlignmentResult {
    /// Optimal warping function (length m).
    pub gamma: Vec<f64>,
    /// Aligned curve f2∘γ (length m).
    pub f_aligned: Vec<f64>,
    /// Elastic distance after alignment.
    pub distance: f64,
    /// Enforced landmark pairs (snapped to grid): `(target_t, source_t)`.
    pub enforced_landmarks: Vec<(f64, f64)>,
}
1686
/// Index of the grid point in `argvals` nearest to `t_val`.
///
/// On ties the earliest index wins. `argvals` must be non-empty.
fn snap_to_grid(t_val: f64, argvals: &[f64]) -> usize {
    let dist = |a: f64| (t_val - a).abs();
    argvals
        .iter()
        .enumerate()
        .skip(1)
        .fold((0, dist(argvals[0])), |(best, best_d), (i, &a)| {
            let d = dist(a);
            if d < best_d {
                (i, d)
            } else {
                (best, best_d)
            }
        })
        .0
}
1700
1701/// Run DP on a rectangular sub-grid `[sc..=ec] × [sr..=er]`.
1702///
1703/// Uses global indices for `dp_edge_weight`. Returns the path segment
1704/// as a list of `(tc_idx, tr_idx)` pairs from start to end.
1705fn dp_segment(
1706    q1: &[f64],
1707    q2: &[f64],
1708    argvals: &[f64],
1709    sc: usize,
1710    ec: usize,
1711    sr: usize,
1712    er: usize,
1713    lambda: f64,
1714) -> Vec<(usize, usize)> {
1715    let nc = ec - sc + 1;
1716    let nr = er - sr + 1;
1717
1718    if nc <= 1 || nr <= 1 {
1719        return vec![(sc, sr), (ec, er)];
1720    }
1721
1722    let path = dp_grid_solve(nr, nc, |local_sr, local_sc, local_tr, local_tc| {
1723        let gsr = sr + local_sr;
1724        let gsc = sc + local_sc;
1725        let gtr = sr + local_tr;
1726        let gtc = sc + local_tc;
1727        dp_edge_weight(q1, q2, argvals, gsc, gtc, gsr, gtr)
1728            + dp_lambda_penalty(argvals, gsc, gtc, gsr, gtr, lambda)
1729    });
1730
1731    // Convert local indices to global
1732    path.iter().map(|&(lr, lc)| (sc + lc, sr + lr)).collect()
1733}
1734
1735/// Align f2 to f1 with landmark constraints.
1736///
1737/// Landmark pairs define waypoints on the DP grid. Between consecutive waypoints,
1738/// an independent smaller DP is run. The resulting warp passes through all landmarks.
1739///
1740/// # Arguments
1741/// * `f1` — Target curve (length m)
1742/// * `f2` — Curve to align (length m)
1743/// * `argvals` — Evaluation points (length m)
1744/// * `landmark_pairs` — `(target_t, source_t)` pairs in increasing order
1745/// * `lambda` — Penalty weight
1746///
1747/// # Returns
1748/// [`ConstrainedAlignmentResult`] with warp, aligned curve, and enforced landmarks.
1749/// Build DP waypoints from landmark pairs: snap to grid, deduplicate, add endpoints.
1750fn build_constrained_waypoints(
1751    landmark_pairs: &[(f64, f64)],
1752    argvals: &[f64],
1753    m: usize,
1754) -> Vec<(usize, usize)> {
1755    let mut waypoints: Vec<(usize, usize)> = Vec::with_capacity(landmark_pairs.len() + 2);
1756    waypoints.push((0, 0));
1757    for &(tt, st) in landmark_pairs {
1758        let tc = snap_to_grid(tt, argvals);
1759        let tr = snap_to_grid(st, argvals);
1760        if let Some(&(prev_c, prev_r)) = waypoints.last() {
1761            if tc > prev_c && tr > prev_r {
1762                waypoints.push((tc, tr));
1763            }
1764        }
1765    }
1766    let last = m - 1;
1767    if let Some(&(prev_c, prev_r)) = waypoints.last() {
1768        if prev_c != last || prev_r != last {
1769            waypoints.push((last, last));
1770        }
1771    }
1772    waypoints
1773}
1774
1775/// Run DP segments between consecutive waypoints and assemble into a gamma warp.
1776fn segmented_dp_gamma(
1777    q1n: &[f64],
1778    q2n: &[f64],
1779    argvals: &[f64],
1780    waypoints: &[(usize, usize)],
1781    lambda: f64,
1782) -> Vec<f64> {
1783    let mut full_path_tc: Vec<f64> = Vec::new();
1784    let mut full_path_tr: Vec<f64> = Vec::new();
1785
1786    for seg in 0..(waypoints.len() - 1) {
1787        let (sc, sr) = waypoints[seg];
1788        let (ec, er) = waypoints[seg + 1];
1789        let segment_path = dp_segment(q1n, q2n, argvals, sc, ec, sr, er, lambda);
1790        let start = if seg > 0 { 1 } else { 0 };
1791        for &(tc, tr) in &segment_path[start..] {
1792            full_path_tc.push(argvals[tc]);
1793            full_path_tr.push(argvals[tr]);
1794        }
1795    }
1796
1797    let mut gamma: Vec<f64> = argvals
1798        .iter()
1799        .map(|&t| linear_interp(&full_path_tc, &full_path_tr, t))
1800        .collect();
1801    normalize_warp(&mut gamma, argvals);
1802    gamma
1803}
1804
1805pub fn elastic_align_pair_constrained(
1806    f1: &[f64],
1807    f2: &[f64],
1808    argvals: &[f64],
1809    landmark_pairs: &[(f64, f64)],
1810    lambda: f64,
1811) -> ConstrainedAlignmentResult {
1812    let m = f1.len();
1813
1814    if landmark_pairs.is_empty() {
1815        let r = elastic_align_pair(f1, f2, argvals, lambda);
1816        return ConstrainedAlignmentResult {
1817            gamma: r.gamma,
1818            f_aligned: r.f_aligned,
1819            distance: r.distance,
1820            enforced_landmarks: Vec::new(),
1821        };
1822    }
1823
1824    // Compute & normalize SRSFs
1825    let f1_mat = FdMatrix::from_slice(f1, 1, m).unwrap();
1826    let f2_mat = FdMatrix::from_slice(f2, 1, m).unwrap();
1827    let q1_mat = srsf_transform(&f1_mat, argvals);
1828    let q2_mat = srsf_transform(&f2_mat, argvals);
1829    let q1: Vec<f64> = q1_mat.row(0);
1830    let q2: Vec<f64> = q2_mat.row(0);
1831    let norm1 = q1.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1832    let norm2 = q2.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
1833    let q1n: Vec<f64> = q1.iter().map(|&v| v / norm1).collect();
1834    let q2n: Vec<f64> = q2.iter().map(|&v| v / norm2).collect();
1835
1836    let waypoints = build_constrained_waypoints(landmark_pairs, argvals, m);
1837    let gamma = segmented_dp_gamma(&q1n, &q2n, argvals, &waypoints, lambda);
1838
1839    let f_aligned = reparameterize_curve(f2, argvals, &gamma);
1840    let f_aligned_mat = FdMatrix::from_slice(&f_aligned, 1, m).unwrap();
1841    let q_aligned_mat = srsf_transform(&f_aligned_mat, argvals);
1842    let q_aligned: Vec<f64> = q_aligned_mat.row(0);
1843    let weights = simpsons_weights(argvals);
1844    let distance = l2_distance(&q1, &q_aligned, &weights);
1845
1846    let enforced: Vec<(f64, f64)> = waypoints[1..waypoints.len() - 1]
1847        .iter()
1848        .map(|&(tc, tr)| (argvals[tc], argvals[tr]))
1849        .collect();
1850
1851    ConstrainedAlignmentResult {
1852        gamma,
1853        f_aligned,
1854        distance,
1855        enforced_landmarks: enforced,
1856    }
1857}
1858
1859/// Align f2 to f1 with automatic landmark detection and elastic constraints.
1860///
1861/// Detects landmarks in both curves, matches them, and uses the matches
1862/// as constraints for segmented DP alignment.
1863///
1864/// # Arguments
1865/// * `f1` — Target curve (length m)
1866/// * `f2` — Curve to align (length m)
1867/// * `argvals` — Evaluation points (length m)
1868/// * `kind` — Type of landmarks to detect
1869/// * `min_prominence` — Minimum prominence for landmark detection
1870/// * `expected_count` — Expected number of landmarks (0 = all detected)
1871/// * `lambda` — Penalty weight
1872pub fn elastic_align_pair_with_landmarks(
1873    f1: &[f64],
1874    f2: &[f64],
1875    argvals: &[f64],
1876    kind: crate::landmark::LandmarkKind,
1877    min_prominence: f64,
1878    expected_count: usize,
1879    lambda: f64,
1880) -> ConstrainedAlignmentResult {
1881    let lm1 = crate::landmark::detect_landmarks(f1, argvals, kind, min_prominence);
1882    let lm2 = crate::landmark::detect_landmarks(f2, argvals, kind, min_prominence);
1883
1884    // Match landmarks by order (take min count)
1885    let n_match = if expected_count > 0 {
1886        expected_count.min(lm1.len()).min(lm2.len())
1887    } else {
1888        lm1.len().min(lm2.len())
1889    };
1890
1891    let pairs: Vec<(f64, f64)> = (0..n_match)
1892        .map(|i| (lm1[i].position, lm2[i].position))
1893        .collect();
1894
1895    elastic_align_pair_constrained(f1, f2, argvals, &pairs, lambda)
1896}
1897
1898// ─── Multidimensional SRSF (R^d curves) ────────────────────────────────────
1899
1900use crate::matrix::FdCurveSet;
1901
/// Outcome of elastically aligning one R^d curve to another.
///
/// A single warp is shared by all d coordinate functions.
#[derive(Clone, Debug)]
pub struct AlignmentResultNd {
    /// Warping function γ on the evaluation grid (length m), common to every dimension.
    pub gamma: Vec<f64>,
    /// The warped curve, stored per dimension (d vectors of length m).
    pub f_aligned: Vec<Vec<f64>>,
    /// Elastic distance between the target and the aligned curve.
    pub distance: f64,
}
1912
1913/// Compute the SRSF transform for multidimensional (R^d) curves.
1914///
1915/// For f: \[0,1\] → R^d, the SRSF is q(t) = f'(t) / √‖f'(t)‖ where ‖·‖ is the
1916/// Euclidean norm in R^d. For d=1 this reduces to `sign(f') · √|f'|`.
1917///
1918/// # Arguments
1919/// * `data` — Set of n curves in R^d, each with m evaluation points
1920/// * `argvals` — Evaluation points (length m)
1921///
1922/// # Returns
1923/// `FdCurveSet` of SRSF values with the same shape as input.
1924/// Scale derivative vector at one point by 1/√‖f'‖, writing into result_dims.
1925#[inline]
1926fn srsf_scale_point(derivs: &[FdMatrix], result_dims: &mut [FdMatrix], i: usize, j: usize) {
1927    let d = derivs.len();
1928    let norm_sq: f64 = derivs.iter().map(|dd| dd[(i, j)].powi(2)).sum();
1929    let norm = norm_sq.sqrt();
1930    if norm < 1e-15 {
1931        for k in 0..d {
1932            result_dims[k][(i, j)] = 0.0;
1933        }
1934    } else {
1935        let scale = 1.0 / norm.sqrt();
1936        for k in 0..d {
1937            result_dims[k][(i, j)] = derivs[k][(i, j)] * scale;
1938        }
1939    }
1940}
1941
1942pub fn srsf_transform_nd(data: &FdCurveSet, argvals: &[f64]) -> FdCurveSet {
1943    let d = data.ndim();
1944    let n = data.ncurves();
1945    let m = data.npoints();
1946
1947    if d == 0 || n == 0 || m == 0 || argvals.len() != m {
1948        return FdCurveSet {
1949            dims: (0..d).map(|_| FdMatrix::zeros(n, m)).collect(),
1950        };
1951    }
1952
1953    let derivs: Vec<FdMatrix> = data
1954        .dims
1955        .iter()
1956        .map(|dim_mat| crate::fdata::deriv_1d(dim_mat, argvals, 1))
1957        .collect();
1958
1959    let mut result_dims: Vec<FdMatrix> = (0..d).map(|_| FdMatrix::zeros(n, m)).collect();
1960    for i in 0..n {
1961        for j in 0..m {
1962            srsf_scale_point(&derivs, &mut result_dims, i, j);
1963        }
1964    }
1965
1966    FdCurveSet { dims: result_dims }
1967}
1968
1969/// Reconstruct an R^d curve from its SRSF.
1970///
1971/// Given d-dimensional SRSF vectors and initial point f0, reconstructs:
1972/// `f_k(t) = f0_k + ∫₀ᵗ q_k(s) · ‖q(s)‖ ds` for each dimension k.
1973///
1974/// # Arguments
1975/// * `q` — SRSF: d vectors, each length m
1976/// * `argvals` — Evaluation points (length m)
1977/// * `f0` — Initial values in R^d (length d)
1978///
1979/// # Returns
1980/// Reconstructed curve: d vectors, each length m.
1981pub fn srsf_inverse_nd(q: &[Vec<f64>], argvals: &[f64], f0: &[f64]) -> Vec<Vec<f64>> {
1982    let d = q.len();
1983    if d == 0 {
1984        return Vec::new();
1985    }
1986    let m = q[0].len();
1987    if m == 0 {
1988        return vec![Vec::new(); d];
1989    }
1990
1991    // Compute ||q(t)|| at each time point
1992    let norms: Vec<f64> = (0..m)
1993        .map(|j| {
1994            let norm_sq: f64 = q.iter().map(|qk| qk[j].powi(2)).sum();
1995            norm_sq.sqrt()
1996        })
1997        .collect();
1998
1999    // For each dimension, integrand = q_k(t) * ||q(t)||
2000    let mut result = Vec::with_capacity(d);
2001    for k in 0..d {
2002        let integrand: Vec<f64> = (0..m).map(|j| q[k][j] * norms[j]).collect();
2003        let integral = cumulative_trapz(&integrand, argvals);
2004        let curve: Vec<f64> = integral.iter().map(|&v| f0[k] + v).collect();
2005        result.push(curve);
2006    }
2007
2008    result
2009}
2010
2011/// Core DP alignment for R^d SRSFs.
2012///
2013/// Same DP grid and coprime neighborhood as `dp_alignment_core`, but edge weight
2014/// is the sum of `dp_edge_weight` over d dimensions.
2015fn dp_alignment_core_nd(
2016    q1: &[Vec<f64>],
2017    q2: &[Vec<f64>],
2018    argvals: &[f64],
2019    lambda: f64,
2020) -> Vec<f64> {
2021    let d = q1.len();
2022    let m = argvals.len();
2023    if m < 2 || d == 0 {
2024        return argvals.to_vec();
2025    }
2026
2027    // For d=1, delegate to existing implementation for exact backward compat
2028    if d == 1 {
2029        return dp_alignment_core(&q1[0], &q2[0], argvals, lambda);
2030    }
2031
2032    // Normalize each dimension's SRSF to unit L2 norm
2033    let q1n: Vec<Vec<f64>> = q1
2034        .iter()
2035        .map(|qk| {
2036            let norm = qk.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
2037            qk.iter().map(|&v| v / norm).collect()
2038        })
2039        .collect();
2040    let q2n: Vec<Vec<f64>> = q2
2041        .iter()
2042        .map(|qk| {
2043            let norm = qk.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
2044            qk.iter().map(|&v| v / norm).collect()
2045        })
2046        .collect();
2047
2048    let path = dp_grid_solve(m, m, |sr, sc, tr, tc| {
2049        let w: f64 = (0..d)
2050            .map(|k| dp_edge_weight(&q1n[k], &q2n[k], argvals, sc, tc, sr, tr))
2051            .sum();
2052        w + dp_lambda_penalty(argvals, sc, tc, sr, tr, lambda)
2053    });
2054
2055    dp_path_to_gamma(&path, argvals)
2056}
2057
2058/// Align an R^d curve f2 to f1 using the elastic framework.
2059///
2060/// Finds the optimal warping γ (shared across all dimensions) such that
2061/// f2∘γ is as close as possible to f1 in the elastic metric.
2062///
2063/// # Arguments
2064/// * `f1` — Target curves (d dimensions)
2065/// * `f2` — Curves to align (d dimensions)
2066/// * `argvals` — Evaluation points (length m)
2067/// * `lambda` — Penalty weight (0.0 = no penalty)
2068pub fn elastic_align_pair_nd(
2069    f1: &FdCurveSet,
2070    f2: &FdCurveSet,
2071    argvals: &[f64],
2072    lambda: f64,
2073) -> AlignmentResultNd {
2074    let d = f1.ndim();
2075    let m = f1.npoints();
2076
2077    // Compute SRSFs
2078    let q1_set = srsf_transform_nd(f1, argvals);
2079    let q2_set = srsf_transform_nd(f2, argvals);
2080
2081    // Extract first curve from each dimension
2082    let q1: Vec<Vec<f64>> = q1_set.dims.iter().map(|dm| dm.row(0)).collect();
2083    let q2: Vec<Vec<f64>> = q2_set.dims.iter().map(|dm| dm.row(0)).collect();
2084
2085    // DP alignment using summed cost over dimensions
2086    let gamma = dp_alignment_core_nd(&q1, &q2, argvals, lambda);
2087
2088    // Apply warping to f2 in each dimension
2089    let f_aligned: Vec<Vec<f64>> = f2
2090        .dims
2091        .iter()
2092        .map(|dm| {
2093            let row = dm.row(0);
2094            reparameterize_curve(&row, argvals, &gamma)
2095        })
2096        .collect();
2097
2098    // Compute elastic distance: sum of squared L2 distances between aligned SRSFs
2099    let f_aligned_set = {
2100        let dims: Vec<FdMatrix> = f_aligned
2101            .iter()
2102            .map(|fa| FdMatrix::from_slice(fa, 1, m).unwrap())
2103            .collect();
2104        FdCurveSet { dims }
2105    };
2106    let q_aligned = srsf_transform_nd(&f_aligned_set, argvals);
2107    let weights = simpsons_weights(argvals);
2108
2109    let mut dist_sq = 0.0;
2110    for k in 0..d {
2111        let q1k = q1_set.dims[k].row(0);
2112        let qak = q_aligned.dims[k].row(0);
2113        let d_k = l2_distance(&q1k, &qak, &weights);
2114        dist_sq += d_k * d_k;
2115    }
2116
2117    AlignmentResultNd {
2118        gamma,
2119        f_aligned,
2120        distance: dist_sq.sqrt(),
2121    }
2122}
2123
2124/// Elastic distance between two R^d curves.
2125///
2126/// Aligns f2 to f1 and returns the post-alignment SRSF distance.
2127pub fn elastic_distance_nd(f1: &FdCurveSet, f2: &FdCurveSet, argvals: &[f64], lambda: f64) -> f64 {
2128    elastic_align_pair_nd(f1, f2, argvals, lambda).distance
2129}
2130
2131// ─── Tests ──────────────────────────────────────────────────────────────────
2132
2133#[cfg(test)]
2134mod tests {
2135    use super::*;
2136    use crate::helpers::trapz;
2137    use crate::simulation::{sim_fundata, EFunType, EValType};
2138    use crate::warping::inner_product_l2;
2139
2140    fn uniform_grid(m: usize) -> Vec<f64> {
2141        (0..m).map(|i| i as f64 / (m - 1) as f64).collect()
2142    }
2143
2144    fn make_test_data(n: usize, m: usize, seed: u64) -> FdMatrix {
2145        let t = uniform_grid(m);
2146        sim_fundata(
2147            n,
2148            &t,
2149            3,
2150            EFunType::Fourier,
2151            EValType::Exponential,
2152            Some(seed),
2153        )
2154    }
2155
2156    // ── cumulative_trapz ──
2157
2158    #[test]
2159    fn test_cumulative_trapz_constant() {
2160        // ∫₀ᵗ 1 dt = t
2161        let x = uniform_grid(50);
2162        let y = vec![1.0; 50];
2163        let result = cumulative_trapz(&y, &x);
2164        assert!((result[0]).abs() < 1e-15, "cumulative_trapz(0) should be 0");
2165        for j in 1..50 {
2166            assert!(
2167                (result[j] - x[j]).abs() < 1e-12,
2168                "∫₀^{:.3} 1 dt should be {:.3}, got {:.3}",
2169                x[j],
2170                x[j],
2171                result[j]
2172            );
2173        }
2174    }
2175
2176    #[test]
2177    fn test_cumulative_trapz_linear() {
2178        // ∫₀ᵗ s ds = t²/2
2179        let m = 100;
2180        let x = uniform_grid(m);
2181        let y: Vec<f64> = x.clone();
2182        let result = cumulative_trapz(&y, &x);
2183        for j in 1..m {
2184            let expected = x[j] * x[j] / 2.0;
2185            assert!(
2186                (result[j] - expected).abs() < 1e-4,
2187                "∫₀^{:.3} s ds: expected {expected:.6}, got {:.6}",
2188                x[j],
2189                result[j]
2190            );
2191        }
2192    }
2193
2194    // ── normalize_warp ──
2195
2196    #[test]
2197    fn test_normalize_warp_fixes_boundaries() {
2198        let t = uniform_grid(10);
2199        let mut gamma = vec![0.1; 10]; // constant, wrong boundaries
2200        normalize_warp(&mut gamma, &t);
2201        assert_eq!(gamma[0], t[0]);
2202        assert_eq!(gamma[9], t[9]);
2203    }
2204
2205    #[test]
2206    fn test_normalize_warp_enforces_monotonicity() {
2207        let t = uniform_grid(5);
2208        let mut gamma = vec![0.0, 0.5, 0.3, 0.8, 1.0]; // non-monotone at index 2
2209        normalize_warp(&mut gamma, &t);
2210        for j in 1..5 {
2211            assert!(
2212                gamma[j] >= gamma[j - 1],
2213                "gamma should be monotone after normalization at j={j}"
2214            );
2215        }
2216    }
2217
2218    #[test]
2219    fn test_normalize_warp_identity_unchanged() {
2220        let t = uniform_grid(20);
2221        let mut gamma = t.clone();
2222        normalize_warp(&mut gamma, &t);
2223        for j in 0..20 {
2224            assert!(
2225                (gamma[j] - t[j]).abs() < 1e-15,
2226                "Identity warp should be unchanged"
2227            );
2228        }
2229    }
2230
2231    // ── linear_interp ──
2232
2233    #[test]
2234    fn test_linear_interp_at_nodes() {
2235        let x = vec![0.0, 1.0, 2.0, 3.0];
2236        let y = vec![0.0, 2.0, 4.0, 6.0];
2237        for i in 0..x.len() {
2238            assert!((linear_interp(&x, &y, x[i]) - y[i]).abs() < 1e-12);
2239        }
2240    }
2241
2242    #[test]
2243    fn test_linear_interp_midpoints() {
2244        let x = vec![0.0, 1.0, 2.0];
2245        let y = vec![0.0, 2.0, 4.0];
2246        assert!((linear_interp(&x, &y, 0.5) - 1.0).abs() < 1e-12);
2247        assert!((linear_interp(&x, &y, 1.5) - 3.0).abs() < 1e-12);
2248    }
2249
2250    #[test]
2251    fn test_linear_interp_clamp() {
2252        let x = vec![0.0, 1.0, 2.0];
2253        let y = vec![1.0, 3.0, 5.0];
2254        assert!((linear_interp(&x, &y, -1.0) - 1.0).abs() < 1e-12);
2255        assert!((linear_interp(&x, &y, 3.0) - 5.0).abs() < 1e-12);
2256    }
2257
2258    #[test]
2259    fn test_linear_interp_nonuniform_grid() {
2260        let x = vec![0.0, 0.1, 0.5, 1.0];
2261        let y = vec![0.0, 1.0, 5.0, 10.0];
2262        // Between 0.1 and 0.5: slope = (5-1)/(0.5-0.1) = 10
2263        let val = linear_interp(&x, &y, 0.3);
2264        let expected = 1.0 + 10.0 * (0.3 - 0.1);
2265        assert!(
2266            (val - expected).abs() < 1e-12,
2267            "Non-uniform interp: expected {expected}, got {val}"
2268        );
2269    }
2270
2271    #[test]
2272    fn test_linear_interp_two_points() {
2273        let x = vec![0.0, 1.0];
2274        let y = vec![3.0, 7.0];
2275        assert!((linear_interp(&x, &y, 0.25) - 4.0).abs() < 1e-12);
2276        assert!((linear_interp(&x, &y, 0.75) - 6.0).abs() < 1e-12);
2277    }
2278
2279    // ── SRSF transform ──
2280
2281    #[test]
2282    fn test_srsf_transform_linear() {
2283        // f(t) = 2t: derivative = 2, SRSF = sqrt(2)
2284        let m = 50;
2285        let t = uniform_grid(m);
2286        let f: Vec<f64> = t.iter().map(|&ti| 2.0 * ti).collect();
2287        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
2288
2289        let q_mat = srsf_transform(&mat, &t);
2290        let q: Vec<f64> = q_mat.row(0);
2291
2292        let expected = 2.0_f64.sqrt();
2293        // Interior points should be close to sqrt(2)
2294        for j in 2..(m - 2) {
2295            assert!(
2296                (q[j] - expected).abs() < 0.1,
2297                "q[{j}] = {}, expected ~{expected}",
2298                q[j]
2299            );
2300        }
2301    }
2302
2303    #[test]
2304    fn test_srsf_transform_preserves_shape() {
2305        let data = make_test_data(10, 50, 42);
2306        let t = uniform_grid(50);
2307        let q = srsf_transform(&data, &t);
2308        assert_eq!(q.shape(), data.shape());
2309    }
2310
2311    #[test]
2312    fn test_srsf_transform_constant_is_zero() {
2313        // f(t) = 5 (constant): derivative = 0, SRSF = 0
2314        let m = 30;
2315        let t = uniform_grid(m);
2316        let f = vec![5.0; m];
2317        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
2318        let q_mat = srsf_transform(&mat, &t);
2319        let q: Vec<f64> = q_mat.row(0);
2320
2321        for j in 0..m {
2322            assert!(
2323                q[j].abs() < 1e-10,
2324                "SRSF of constant should be 0, got q[{j}] = {}",
2325                q[j]
2326            );
2327        }
2328    }
2329
2330    #[test]
2331    fn test_srsf_transform_negative_slope() {
2332        // f(t) = -3t: derivative = -3, SRSF = -sqrt(3)
2333        let m = 50;
2334        let t = uniform_grid(m);
2335        let f: Vec<f64> = t.iter().map(|&ti| -3.0 * ti).collect();
2336        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
2337
2338        let q_mat = srsf_transform(&mat, &t);
2339        let q: Vec<f64> = q_mat.row(0);
2340
2341        let expected = -(3.0_f64.sqrt());
2342        for j in 2..(m - 2) {
2343            assert!(
2344                (q[j] - expected).abs() < 0.15,
2345                "q[{j}] = {}, expected ~{expected}",
2346                q[j]
2347            );
2348        }
2349    }
2350
2351    #[test]
2352    fn test_srsf_transform_empty_input() {
2353        let data = FdMatrix::zeros(0, 0);
2354        let t: Vec<f64> = vec![];
2355        let q = srsf_transform(&data, &t);
2356        assert_eq!(q.shape(), (0, 0));
2357    }
2358
2359    #[test]
2360    fn test_srsf_transform_multiple_curves() {
2361        let m = 40;
2362        let t = uniform_grid(m);
2363        let data = make_test_data(5, m, 42);
2364
2365        let q = srsf_transform(&data, &t);
2366        assert_eq!(q.shape(), (5, m));
2367
2368        // Each row should have finite values
2369        for i in 0..5 {
2370            for j in 0..m {
2371                assert!(q[(i, j)].is_finite(), "SRSF should be finite at ({i},{j})");
2372            }
2373        }
2374    }
2375
2376    // ── SRSF inverse ──
2377
2378    #[test]
2379    fn test_srsf_round_trip() {
2380        let m = 100;
2381        let t = uniform_grid(m);
2382        // Use a smooth function
2383        let f: Vec<f64> = t
2384            .iter()
2385            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin() + ti)
2386            .collect();
2387
2388        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
2389        let q_mat = srsf_transform(&mat, &t);
2390        let q: Vec<f64> = q_mat.row(0);
2391
2392        let f_recon = srsf_inverse(&q, &t, f[0]);
2393
2394        // Check reconstruction is close (interior points, avoid boundary effects)
2395        let max_err: f64 = f[5..(m - 5)]
2396            .iter()
2397            .zip(f_recon[5..(m - 5)].iter())
2398            .map(|(a, b)| (a - b).abs())
2399            .fold(0.0_f64, f64::max);
2400
2401        assert!(
2402            max_err < 0.15,
2403            "Round-trip error too large: max_err = {max_err}"
2404        );
2405    }
2406
2407    #[test]
2408    fn test_srsf_inverse_empty() {
2409        let q: Vec<f64> = vec![];
2410        let t: Vec<f64> = vec![];
2411        let result = srsf_inverse(&q, &t, 0.0);
2412        assert!(result.is_empty());
2413    }
2414
2415    #[test]
2416    fn test_srsf_inverse_preserves_initial_value() {
2417        let m = 50;
2418        let t = uniform_grid(m);
2419        let q = vec![1.0; m]; // constant SRSF
2420        let f0 = 3.15;
2421        let f = srsf_inverse(&q, &t, f0);
2422        assert!((f[0] - f0).abs() < 1e-12, "srsf_inverse should start at f0");
2423    }
2424
2425    #[test]
2426    fn test_srsf_round_trip_multiple_curves() {
2427        let m = 80;
2428        let t = uniform_grid(m);
2429        let data = make_test_data(5, m, 99);
2430
2431        let q_mat = srsf_transform(&data, &t);
2432
2433        for i in 0..5 {
2434            let fi = data.row(i);
2435            let qi = q_mat.row(i);
2436            let f_recon = srsf_inverse(&qi, &t, fi[0]);
2437            let max_err: f64 = fi[5..(m - 5)]
2438                .iter()
2439                .zip(f_recon[5..(m - 5)].iter())
2440                .map(|(a, b)| (a - b).abs())
2441                .fold(0.0_f64, f64::max);
2442            assert!(max_err < 0.3, "Round-trip curve {i}: max_err = {max_err}");
2443        }
2444    }
2445
2446    // ── Reparameterize ──
2447
2448    #[test]
2449    fn test_reparameterize_identity_warp() {
2450        let m = 50;
2451        let t = uniform_grid(m);
2452        let f: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
2453
2454        // Identity warp: γ(t) = t
2455        let result = reparameterize_curve(&f, &t, &t);
2456        for j in 0..m {
2457            assert!(
2458                (result[j] - f[j]).abs() < 1e-12,
2459                "Identity warp should return original at j={j}"
2460            );
2461        }
2462    }
2463
2464    #[test]
2465    fn test_reparameterize_linear_warp() {
2466        let m = 50;
2467        let t = uniform_grid(m);
2468        // f(t) = t (linear), γ(t) = t^2 (quadratic warp on [0,1])
2469        let f: Vec<f64> = t.clone();
2470        let gamma: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
2471
2472        let result = reparameterize_curve(&f, &t, &gamma);
2473
2474        // f(γ(t)) = γ(t) = t^2 for a linear f(t) = t
2475        for j in 0..m {
2476            assert!(
2477                (result[j] - gamma[j]).abs() < 1e-10,
2478                "f(gamma(t)) should be gamma(t) for f(t)=t at j={j}"
2479            );
2480        }
2481    }
2482
2483    #[test]
2484    fn test_reparameterize_sine_with_quadratic_warp() {
2485        let m = 100;
2486        let t = uniform_grid(m);
2487        let f: Vec<f64> = t
2488            .iter()
2489            .map(|&ti| (std::f64::consts::PI * ti).sin())
2490            .collect();
2491        let gamma: Vec<f64> = t.iter().map(|&ti| ti * ti).collect(); // speeds up start
2492
2493        let result = reparameterize_curve(&f, &t, &gamma);
2494
2495        // f(γ(t)) = sin(π * t²); check a few known values
2496        for j in 0..m {
2497            let expected = (std::f64::consts::PI * gamma[j]).sin();
2498            assert!(
2499                (result[j] - expected).abs() < 0.05,
2500                "sin(π γ(t)) at j={j}: expected {expected:.4}, got {:.4}",
2501                result[j]
2502            );
2503        }
2504    }
2505
2506    #[test]
2507    fn test_reparameterize_preserves_length() {
2508        let m = 50;
2509        let t = uniform_grid(m);
2510        let f = vec![1.0; m];
2511        let gamma: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
2512
2513        let result = reparameterize_curve(&f, &t, &gamma);
2514        assert_eq!(result.len(), m);
2515    }
2516
2517    // ── Compose warps ──
2518
2519    #[test]
2520    fn test_compose_warps_identity() {
2521        let m = 50;
2522        let t = uniform_grid(m);
2523        // γ(t) = t^0.5 (a warp on [0,1])
2524        let gamma: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
2525
2526        // identity ∘ γ = γ
2527        let result = compose_warps(&t, &gamma, &t);
2528        for j in 0..m {
2529            assert!(
2530                (result[j] - gamma[j]).abs() < 1e-10,
2531                "id ∘ γ should be γ at j={j}"
2532            );
2533        }
2534
2535        // γ ∘ identity = γ
2536        let result2 = compose_warps(&gamma, &t, &t);
2537        for j in 0..m {
2538            assert!(
2539                (result2[j] - gamma[j]).abs() < 1e-10,
2540                "γ ∘ id should be γ at j={j}"
2541            );
2542        }
2543    }
2544
2545    #[test]
2546    fn test_compose_warps_associativity() {
2547        // (γ₁ ∘ γ₂) ∘ γ₃ ≈ γ₁ ∘ (γ₂ ∘ γ₃)
2548        let m = 50;
2549        let t = uniform_grid(m);
2550        let g1: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
2551        let g2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
2552        let g3: Vec<f64> = t.iter().map(|&ti| 0.5 * ti + 0.5 * ti * ti).collect();
2553
2554        let g12 = compose_warps(&g1, &g2, &t);
2555        let left = compose_warps(&g12, &g3, &t); // (g1∘g2) ∘ g3
2556
2557        let g23 = compose_warps(&g2, &g3, &t);
2558        let right = compose_warps(&g1, &g23, &t); // g1 ∘ (g2∘g3)
2559
2560        for j in 0..m {
2561            assert!(
2562                (left[j] - right[j]).abs() < 0.05,
2563                "Composition should be roughly associative at j={j}: left={:.4}, right={:.4}",
2564                left[j],
2565                right[j]
2566            );
2567        }
2568    }
2569
2570    #[test]
2571    fn test_compose_warps_preserves_domain() {
2572        let m = 50;
2573        let t = uniform_grid(m);
2574        let g1: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
2575        let g2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
2576
2577        let composed = compose_warps(&g1, &g2, &t);
2578        assert!(
2579            (composed[0] - t[0]).abs() < 1e-10,
2580            "Composed warp should start at domain start"
2581        );
2582        assert!(
2583            (composed[m - 1] - t[m - 1]).abs() < 1e-10,
2584            "Composed warp should end at domain end"
2585        );
2586    }
2587
2588    // ── Elastic align pair ──
2589
2590    #[test]
2591    fn test_align_identical_curves() {
2592        let m = 50;
2593        let t = uniform_grid(m);
2594        let f: Vec<f64> = t
2595            .iter()
2596            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
2597            .collect();
2598
2599        let result = elastic_align_pair(&f, &f, &t, 0.0);
2600
2601        // Distance should be near zero
2602        assert!(
2603            result.distance < 0.1,
2604            "Distance between identical curves should be near 0, got {}",
2605            result.distance
2606        );
2607
2608        // Warp should be near identity
2609        for j in 0..m {
2610            assert!(
2611                (result.gamma[j] - t[j]).abs() < 0.1,
2612                "Warp should be near identity at j={j}: gamma={}, t={}",
2613                result.gamma[j],
2614                t[j]
2615            );
2616        }
2617    }
2618
2619    #[test]
2620    fn test_align_pair_valid_output() {
2621        let data = make_test_data(2, 50, 42);
2622        let t = uniform_grid(50);
2623        let f1 = data.row(0);
2624        let f2 = data.row(1);
2625
2626        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2627
2628        assert_eq!(result.gamma.len(), 50);
2629        assert_eq!(result.f_aligned.len(), 50);
2630        assert!(result.distance >= 0.0);
2631
2632        // Warp should be monotone
2633        for j in 1..50 {
2634            assert!(
2635                result.gamma[j] >= result.gamma[j - 1],
2636                "Warp should be monotone at j={j}"
2637            );
2638        }
2639    }
2640
2641    #[test]
2642    fn test_align_pair_warp_boundaries() {
2643        let data = make_test_data(2, 50, 42);
2644        let t = uniform_grid(50);
2645        let f1 = data.row(0);
2646        let f2 = data.row(1);
2647
2648        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2649        assert!(
2650            (result.gamma[0] - t[0]).abs() < 1e-12,
2651            "Warp should start at domain start"
2652        );
2653        assert!(
2654            (result.gamma[49] - t[49]).abs() < 1e-12,
2655            "Warp should end at domain end"
2656        );
2657    }
2658
2659    #[test]
2660    fn test_align_shifted_sine() {
2661        // Two sines with a phase shift — alignment should reduce distance
2662        let m = 80;
2663        let t = uniform_grid(m);
2664        let f1: Vec<f64> = t
2665            .iter()
2666            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
2667            .collect();
2668        let f2: Vec<f64> = t
2669            .iter()
2670            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
2671            .collect();
2672
2673        let weights = simpsons_weights(&t);
2674        let l2_before = l2_distance(&f1, &f2, &weights);
2675        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2676        let l2_after = l2_distance(&f1, &result.f_aligned, &weights);
2677
2678        assert!(
2679            l2_after < l2_before + 0.01,
2680            "Alignment should not increase L2 distance: before={l2_before:.4}, after={l2_after:.4}"
2681        );
2682    }
2683
2684    #[test]
2685    fn test_align_pair_aligned_curve_is_finite() {
2686        let data = make_test_data(2, 50, 77);
2687        let t = uniform_grid(50);
2688        let f1 = data.row(0);
2689        let f2 = data.row(1);
2690
2691        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2692        for j in 0..50 {
2693            assert!(
2694                result.f_aligned[j].is_finite(),
2695                "Aligned curve should be finite at j={j}"
2696            );
2697        }
2698    }
2699
2700    #[test]
2701    fn test_align_pair_minimum_grid() {
2702        // Minimum viable grid: m = 2
2703        let t = vec![0.0, 1.0];
2704        let f1 = vec![0.0, 1.0];
2705        let f2 = vec![0.0, 2.0];
2706        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
2707        assert_eq!(result.gamma.len(), 2);
2708        assert_eq!(result.f_aligned.len(), 2);
2709        assert!(result.distance >= 0.0);
2710    }
2711
2712    // ── Elastic distance ──
2713
2714    #[test]
2715    fn test_elastic_distance_symmetric() {
2716        let data = make_test_data(3, 50, 42);
2717        let t = uniform_grid(50);
2718        let f1 = data.row(0);
2719        let f2 = data.row(1);
2720
2721        let d12 = elastic_distance(&f1, &f2, &t, 0.0);
2722        let d21 = elastic_distance(&f2, &f1, &t, 0.0);
2723
2724        // Should be approximately symmetric (DP is not perfectly symmetric)
2725        assert!(
2726            (d12 - d21).abs() < d12.max(d21) * 0.3 + 0.01,
2727            "Elastic distance should be roughly symmetric: d12={d12}, d21={d21}"
2728        );
2729    }
2730
2731    #[test]
2732    fn test_elastic_distance_nonneg() {
2733        let data = make_test_data(3, 50, 42);
2734        let t = uniform_grid(50);
2735
2736        for i in 0..3 {
2737            for j in 0..3 {
2738                let fi = data.row(i);
2739                let fj = data.row(j);
2740                let d = elastic_distance(&fi, &fj, &t, 0.0);
2741                assert!(d >= 0.0, "Elastic distance should be non-negative");
2742            }
2743        }
2744    }
2745
2746    #[test]
2747    fn test_elastic_distance_self_near_zero() {
2748        let data = make_test_data(3, 50, 42);
2749        let t = uniform_grid(50);
2750
2751        for i in 0..3 {
2752            let fi = data.row(i);
2753            let d = elastic_distance(&fi, &fi, &t, 0.0);
2754            assert!(
2755                d < 0.1,
2756                "Self-distance should be near zero, got {d} for curve {i}"
2757            );
2758        }
2759    }
2760
2761    #[test]
2762    fn test_elastic_distance_triangle_inequality() {
2763        let data = make_test_data(3, 50, 42);
2764        let t = uniform_grid(50);
2765        let f0 = data.row(0);
2766        let f1 = data.row(1);
2767        let f2 = data.row(2);
2768
2769        let d01 = elastic_distance(&f0, &f1, &t, 0.0);
2770        let d12 = elastic_distance(&f1, &f2, &t, 0.0);
2771        let d02 = elastic_distance(&f0, &f2, &t, 0.0);
2772
2773        // Relaxed triangle inequality (DP alignment is approximate)
2774        let slack = 0.5;
2775        assert!(
2776            d02 <= d01 + d12 + slack,
2777            "Triangle inequality (relaxed): d02={d02:.4} > d01={d01:.4} + d12={d12:.4} + {slack}"
2778        );
2779    }
2780
2781    #[test]
2782    fn test_elastic_distance_different_shapes_nonzero() {
2783        let m = 50;
2784        let t = uniform_grid(m);
2785        let f1: Vec<f64> = t.to_vec(); // linear
2786        let f2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect(); // quadratic
2787
2788        let d = elastic_distance(&f1, &f2, &t, 0.0);
2789        assert!(
2790            d > 0.01,
2791            "Distance between different shapes should be > 0, got {d}"
2792        );
2793    }
2794
2795    // ── Self distance matrix ──
2796
2797    #[test]
2798    fn test_self_distance_matrix_symmetric() {
2799        let data = make_test_data(5, 30, 42);
2800        let t = uniform_grid(30);
2801
2802        let dm = elastic_self_distance_matrix(&data, &t, 0.0);
2803        let n = dm.nrows();
2804
2805        assert_eq!(dm.shape(), (5, 5));
2806
2807        // Zero diagonal
2808        for i in 0..n {
2809            assert!(dm[(i, i)].abs() < 1e-12, "Diagonal should be zero");
2810        }
2811
2812        // Symmetric
2813        for i in 0..n {
2814            for j in (i + 1)..n {
2815                assert!(
2816                    (dm[(i, j)] - dm[(j, i)]).abs() < 1e-12,
2817                    "Matrix should be symmetric at ({i},{j})"
2818                );
2819            }
2820        }
2821    }
2822
2823    #[test]
2824    fn test_self_distance_matrix_nonneg() {
2825        let data = make_test_data(4, 30, 42);
2826        let t = uniform_grid(30);
2827        let dm = elastic_self_distance_matrix(&data, &t, 0.0);
2828
2829        for i in 0..4 {
2830            for j in 0..4 {
2831                assert!(
2832                    dm[(i, j)] >= 0.0,
2833                    "Distance matrix entries should be non-negative at ({i},{j})"
2834                );
2835            }
2836        }
2837    }
2838
2839    #[test]
2840    fn test_self_distance_matrix_single_curve() {
2841        let data = make_test_data(1, 30, 42);
2842        let t = uniform_grid(30);
2843        let dm = elastic_self_distance_matrix(&data, &t, 0.0);
2844        assert_eq!(dm.shape(), (1, 1));
2845        assert!(dm[(0, 0)].abs() < 1e-12);
2846    }
2847
2848    #[test]
2849    fn test_self_distance_matrix_consistent_with_pairwise() {
2850        let data = make_test_data(4, 30, 42);
2851        let t = uniform_grid(30);
2852
2853        let dm = elastic_self_distance_matrix(&data, &t, 0.0);
2854
2855        // Check a few entries match direct elastic_distance calls
2856        for i in 0..4 {
2857            for j in (i + 1)..4 {
2858                let fi = data.row(i);
2859                let fj = data.row(j);
2860                let d_direct = elastic_distance(&fi, &fj, &t, 0.0);
2861                assert!(
2862                    (dm[(i, j)] - d_direct).abs() < 1e-10,
2863                    "Matrix entry ({i},{j})={:.6} should match pairwise {d_direct:.6}",
2864                    dm[(i, j)]
2865                );
2866            }
2867        }
2868    }
2869
2870    // ── Karcher mean ──
2871
2872    #[test]
2873    fn test_karcher_mean_identical_curves() {
2874        let m = 50;
2875        let t = uniform_grid(m);
2876        let f: Vec<f64> = t
2877            .iter()
2878            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
2879            .collect();
2880
2881        // Create 5 identical curves
2882        let mut data = FdMatrix::zeros(5, m);
2883        for i in 0..5 {
2884            for j in 0..m {
2885                data[(i, j)] = f[j];
2886            }
2887        }
2888
2889        let result = karcher_mean(&data, &t, 10, 1e-4, 0.0);
2890
2891        assert_eq!(result.mean.len(), m);
2892        assert!(result.n_iter <= 10);
2893    }
2894
2895    #[test]
2896    fn test_karcher_mean_output_shape() {
2897        let data = make_test_data(15, 50, 42);
2898        let t = uniform_grid(50);
2899
2900        let result = karcher_mean(&data, &t, 5, 1e-3, 0.0);
2901
2902        assert_eq!(result.mean.len(), 50);
2903        assert_eq!(result.mean_srsf.len(), 50);
2904        assert_eq!(result.gammas.shape(), (15, 50));
2905        assert_eq!(result.aligned_data.shape(), (15, 50));
2906        assert!(result.n_iter <= 5);
2907    }
2908
2909    #[test]
2910    fn test_karcher_mean_warps_are_valid() {
2911        let data = make_test_data(10, 40, 42);
2912        let t = uniform_grid(40);
2913
2914        let result = karcher_mean(&data, &t, 5, 1e-3, 0.0);
2915
2916        for i in 0..10 {
2917            // Boundary values
2918            assert!(
2919                (result.gammas[(i, 0)] - t[0]).abs() < 1e-10,
2920                "Warp {i} should start at domain start"
2921            );
2922            assert!(
2923                (result.gammas[(i, 39)] - t[39]).abs() < 1e-10,
2924                "Warp {i} should end at domain end"
2925            );
2926            // Monotonicity
2927            for j in 1..40 {
2928                assert!(
2929                    result.gammas[(i, j)] >= result.gammas[(i, j - 1)],
2930                    "Warp {i} should be monotone at j={j}"
2931                );
2932            }
2933        }
2934    }
2935
2936    #[test]
2937    fn test_karcher_mean_aligned_data_is_finite() {
2938        let data = make_test_data(8, 40, 42);
2939        let t = uniform_grid(40);
2940        let result = karcher_mean(&data, &t, 5, 1e-3, 0.0);
2941
2942        for i in 0..8 {
2943            for j in 0..40 {
2944                assert!(
2945                    result.aligned_data[(i, j)].is_finite(),
2946                    "Aligned data should be finite at ({i},{j})"
2947                );
2948            }
2949        }
2950    }
2951
2952    #[test]
2953    fn test_karcher_mean_srsf_is_finite() {
2954        let data = make_test_data(8, 40, 42);
2955        let t = uniform_grid(40);
2956        let result = karcher_mean(&data, &t, 5, 1e-3, 0.0);
2957
2958        for j in 0..40 {
2959            assert!(
2960                result.mean_srsf[j].is_finite(),
2961                "Mean SRSF should be finite at j={j}"
2962            );
2963            assert!(
2964                result.mean[j].is_finite(),
2965                "Mean curve should be finite at j={j}"
2966            );
2967        }
2968    }
2969
2970    #[test]
2971    fn test_karcher_mean_single_iteration() {
2972        let data = make_test_data(10, 40, 42);
2973        let t = uniform_grid(40);
2974        let result = karcher_mean(&data, &t, 1, 1e-10, 0.0);
2975
2976        assert_eq!(result.n_iter, 1);
2977        assert_eq!(result.mean.len(), 40);
2978        // With only 1 iteration, still produces valid output
2979        for j in 0..40 {
2980            assert!(result.mean[j].is_finite());
2981        }
2982    }
2983
2984    #[test]
2985    fn test_karcher_mean_convergence_not_premature() {
2986        // The old convergence criterion (rel - prev_rel <= tol * prev_rel) was
2987        // always satisfied when the algorithm improved, causing premature exit
2988        // after 2 iterations. With the fix (rel < tol), the algorithm should
2989        // actually iterate until convergence or hitting the iteration cap.
2990        let n = 10;
2991        let m = 50;
2992        let t = uniform_grid(m);
2993
2994        // Create phase-shifted curves that genuinely need alignment
2995        let mut col_major = vec![0.0; n * m];
2996        for i in 0..n {
2997            let shift = (i as f64 - 5.0) * 0.05;
2998            for j in 0..m {
2999                col_major[i + j * n] = (2.0 * std::f64::consts::PI * (t[j] - shift)).sin();
3000            }
3001        }
3002        let data = FdMatrix::from_column_major(col_major, n, m).unwrap();
3003
3004        // With an impossibly tight tolerance, the algorithm should hit the
3005        // iteration cap rather than "converging" after 2 iterations.
3006        let result = karcher_mean(&data, &t, 20, 1e-15, 0.0);
3007        assert!(
3008            result.n_iter > 2,
3009            "With tol=1e-15 the algorithm should iterate beyond 2, got n_iter={}",
3010            result.n_iter
3011        );
3012
3013        // With a reasonable tolerance, it should converge and report so
3014        let result_loose = karcher_mean(&data, &t, 20, 1e-2, 0.0);
3015        assert!(
3016            result_loose.converged,
3017            "With tol=1e-2 the algorithm should converge"
3018        );
3019    }
3020
3021    // ── Align to target ──
3022
3023    #[test]
3024    fn test_align_to_target_valid() {
3025        let data = make_test_data(10, 40, 42);
3026        let t = uniform_grid(40);
3027        let target = data.row(0);
3028
3029        let result = align_to_target(&data, &target, &t, 0.0);
3030
3031        assert_eq!(result.gammas.shape(), (10, 40));
3032        assert_eq!(result.aligned_data.shape(), (10, 40));
3033        assert_eq!(result.distances.len(), 10);
3034
3035        // All distances should be non-negative
3036        for &d in &result.distances {
3037            assert!(d >= 0.0);
3038        }
3039    }
3040
3041    #[test]
3042    fn test_align_to_target_self_near_zero() {
3043        let data = make_test_data(5, 40, 42);
3044        let t = uniform_grid(40);
3045        let target = data.row(0);
3046
3047        let result = align_to_target(&data, &target, &t, 0.0);
3048
3049        // Distance of target to itself should be near zero
3050        assert!(
3051            result.distances[0] < 0.1,
3052            "Self-alignment distance should be near zero, got {}",
3053            result.distances[0]
3054        );
3055    }
3056
3057    #[test]
3058    fn test_align_to_target_warps_are_monotone() {
3059        let data = make_test_data(8, 40, 42);
3060        let t = uniform_grid(40);
3061        let target = data.row(0);
3062        let result = align_to_target(&data, &target, &t, 0.0);
3063
3064        for i in 0..8 {
3065            for j in 1..40 {
3066                assert!(
3067                    result.gammas[(i, j)] >= result.gammas[(i, j - 1)],
3068                    "Warp for curve {i} should be monotone at j={j}"
3069                );
3070            }
3071        }
3072    }
3073
3074    #[test]
3075    fn test_align_to_target_aligned_data_finite() {
3076        let data = make_test_data(6, 40, 42);
3077        let t = uniform_grid(40);
3078        let target = data.row(0);
3079        let result = align_to_target(&data, &target, &t, 0.0);
3080
3081        for i in 0..6 {
3082            for j in 0..40 {
3083                assert!(
3084                    result.aligned_data[(i, j)].is_finite(),
3085                    "Aligned data should be finite at ({i},{j})"
3086                );
3087            }
3088        }
3089    }
3090
3091    // ── Cross distance matrix ──
3092
3093    #[test]
3094    fn test_cross_distance_matrix_shape() {
3095        let data1 = make_test_data(3, 30, 42);
3096        let data2 = make_test_data(4, 30, 99);
3097        let t = uniform_grid(30);
3098
3099        let dm = elastic_cross_distance_matrix(&data1, &data2, &t, 0.0);
3100        assert_eq!(dm.shape(), (3, 4));
3101
3102        // All non-negative
3103        for i in 0..3 {
3104            for j in 0..4 {
3105                assert!(dm[(i, j)] >= 0.0);
3106            }
3107        }
3108    }
3109
3110    #[test]
3111    fn test_cross_distance_matrix_self_matches_self_matrix() {
3112        // cross_distance(data, data) should have zero diagonal (approximately)
3113        let data = make_test_data(4, 30, 42);
3114        let t = uniform_grid(30);
3115
3116        let cross = elastic_cross_distance_matrix(&data, &data, &t, 0.0);
3117        for i in 0..4 {
3118            assert!(
3119                cross[(i, i)] < 0.1,
3120                "Cross distance (self) diagonal should be near zero: got {}",
3121                cross[(i, i)]
3122            );
3123        }
3124    }
3125
3126    #[test]
3127    fn test_cross_distance_matrix_consistent_with_pairwise() {
3128        let data1 = make_test_data(3, 30, 42);
3129        let data2 = make_test_data(2, 30, 99);
3130        let t = uniform_grid(30);
3131
3132        let dm = elastic_cross_distance_matrix(&data1, &data2, &t, 0.0);
3133
3134        for i in 0..3 {
3135            for j in 0..2 {
3136                let fi = data1.row(i);
3137                let fj = data2.row(j);
3138                let d_direct = elastic_distance(&fi, &fj, &t, 0.0);
3139                assert!(
3140                    (dm[(i, j)] - d_direct).abs() < 1e-10,
3141                    "Cross matrix ({i},{j})={:.6} should match pairwise {d_direct:.6}",
3142                    dm[(i, j)]
3143                );
3144            }
3145        }
3146    }
3147
3148    // ── align_srsf_pair ──
3149
3150    #[test]
3151    fn test_align_srsf_pair_identity() {
3152        let m = 50;
3153        let t = uniform_grid(m);
3154        let f: Vec<f64> = t
3155            .iter()
3156            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3157            .collect();
3158        let q = srsf_single(&f, &t);
3159
3160        let (gamma, q_aligned) = align_srsf_pair(&q, &q, &t, 0.0);
3161
3162        // Warp should be near identity
3163        for j in 0..m {
3164            assert!(
3165                (gamma[j] - t[j]).abs() < 0.15,
3166                "Self-SRSF alignment warp should be near identity at j={j}"
3167            );
3168        }
3169
3170        // Aligned SRSF should be close to original
3171        let weights = simpsons_weights(&t);
3172        let dist = l2_distance(&q, &q_aligned, &weights);
3173        assert!(
3174            dist < 0.5,
3175            "Self-aligned SRSF distance should be small, got {dist}"
3176        );
3177    }
3178
3179    // ── srsf_single ──
3180
3181    #[test]
3182    fn test_srsf_single_matches_matrix_version() {
3183        let m = 50;
3184        let t = uniform_grid(m);
3185        let f: Vec<f64> = t.iter().map(|&ti| ti * ti + ti).collect();
3186
3187        let q_single = srsf_single(&f, &t);
3188
3189        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
3190        let q_mat = srsf_transform(&mat, &t);
3191        let q_from_mat = q_mat.row(0);
3192
3193        for j in 0..m {
3194            assert!(
3195                (q_single[j] - q_from_mat[j]).abs() < 1e-12,
3196                "srsf_single should match srsf_transform at j={j}"
3197            );
3198        }
3199    }
3200
3201    // ── gcd ──
3202
3203    #[test]
3204    fn test_gcd_basic() {
3205        assert_eq!(gcd(1, 1), 1);
3206        assert_eq!(gcd(6, 4), 2);
3207        assert_eq!(gcd(7, 5), 1);
3208        assert_eq!(gcd(12, 8), 4);
3209        assert_eq!(gcd(7, 0), 7);
3210        assert_eq!(gcd(0, 5), 5);
3211    }
3212
3213    // ── generate_coprime_nbhd ──
3214
3215    #[test]
3216    fn test_coprime_nbhd_count() {
3217        assert_eq!(generate_coprime_nbhd(1).len(), 1); // just (1,1)
3218        assert_eq!(generate_coprime_nbhd(7).len(), 35);
3219    }
3220
3221    #[test]
3222    fn test_coprime_nbhd_matches_const() {
3223        let generated = generate_coprime_nbhd(7);
3224        assert_eq!(generated.len(), COPRIME_NBHD_7.len());
3225        for (i, pair) in generated.iter().enumerate() {
3226            assert_eq!(*pair, COPRIME_NBHD_7[i], "mismatch at index {i}");
3227        }
3228    }
3229
3230    #[test]
3231    fn test_coprime_nbhd_all_coprime() {
3232        for &(i, j) in &COPRIME_NBHD_7 {
3233            assert_eq!(gcd(i, j), 1, "({i},{j}) should be coprime");
3234            assert!((1..=7).contains(&i));
3235            assert!((1..=7).contains(&j));
3236        }
3237    }
3238
3239    // ── dp_edge_weight ──
3240
3241    #[test]
3242    fn test_dp_edge_weight_diagonal() {
3243        // Diagonal move (1,1): weight = (q1[sc] - sqrt(1)*q2[sr])^2 * h
3244        let t = uniform_grid(10);
3245        let q1 = vec![1.0; 10];
3246        let q2 = vec![1.0; 10];
3247        // Identical SRSFs: weight should be 0
3248        let w = dp_edge_weight(&q1, &q2, &t, 0, 1, 0, 1);
3249        assert!(w.abs() < 1e-12, "identical SRSFs should have zero cost");
3250    }
3251
3252    #[test]
3253    fn test_dp_edge_weight_non_diagonal() {
3254        // Move (1,2): n1=2, n2=1, slope = h/(2h) = 0.5
3255        let t = uniform_grid(10);
3256        let q1 = vec![1.0; 10];
3257        let q2 = vec![0.0; 10];
3258        let w = dp_edge_weight(&q1, &q2, &t, 0, 2, 0, 1);
3259        // diff = q1[0] - sqrt(0.5)*q2[0] = 1.0 - 0 = 1.0
3260        // weight = 1.0^2 * 1.0 * (t[2]-t[0]) = 2/9
3261        let expected = 2.0 / 9.0;
3262        assert!(
3263            (w - expected).abs() < 1e-10,
3264            "dp_edge_weight (1,2): expected {expected}, got {w}"
3265        );
3266    }
3267
3268    #[test]
3269    fn test_dp_edge_weight_zero_span() {
3270        let t = uniform_grid(10);
3271        let q1 = vec![1.0; 10];
3272        let q2 = vec![1.0; 10];
3273        // n1=0: should return INFINITY
3274        assert_eq!(dp_edge_weight(&q1, &q2, &t, 3, 3, 0, 1), f64::INFINITY);
3275        // n2=0: should return INFINITY
3276        assert_eq!(dp_edge_weight(&q1, &q2, &t, 0, 1, 3, 3), f64::INFINITY);
3277    }
3278
3279    // ── DP alignment quality ──
3280
3281    #[test]
3282    fn test_alignment_improves_distance() {
3283        // Aligned SRSF distance should be less than unaligned SRSF distance
3284        let m = 50;
3285        let t = uniform_grid(m);
3286        let f1: Vec<f64> = t
3287            .iter()
3288            .map(|&x| (2.0 * std::f64::consts::PI * x).sin())
3289            .collect();
3290        // Use a larger shift so improvement is clear
3291        let f2: Vec<f64> = t
3292            .iter()
3293            .map(|&x| (2.0 * std::f64::consts::PI * (x + 0.2)).sin())
3294            .collect();
3295
3296        let q1 = srsf_single(&f1, &t);
3297        let q2 = srsf_single(&f2, &t);
3298        let weights = simpsons_weights(&t);
3299        let unaligned_srsf_dist = l2_distance(&q1, &q2, &weights);
3300
3301        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
3302
3303        assert!(
3304            result.distance <= unaligned_srsf_dist + 1e-6,
3305            "aligned SRSF dist ({}) should be <= unaligned SRSF dist ({})",
3306            result.distance,
3307            unaligned_srsf_dist
3308        );
3309    }
3310
3311    // ── Edge case: constant data ──
3312
3313    #[test]
3314    fn test_alignment_constant_curves() {
3315        let m = 30;
3316        let t = uniform_grid(m);
3317        let f1 = vec![5.0; m];
3318        let f2 = vec![5.0; m];
3319
3320        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
3321        assert!(
3322            result.distance < 0.01,
3323            "Constant curves: distance should be ~0"
3324        );
3325        assert_eq!(result.f_aligned.len(), m);
3326    }
3327
3328    #[test]
3329    fn test_karcher_mean_constant_curves() {
3330        let m = 30;
3331        let t = uniform_grid(m);
3332        let mut data = FdMatrix::zeros(5, m);
3333        for i in 0..5 {
3334            for j in 0..m {
3335                data[(i, j)] = 3.0;
3336            }
3337        }
3338
3339        let result = karcher_mean(&data, &t, 5, 1e-4, 0.0);
3340        for j in 0..m {
3341            assert!(
3342                (result.mean[j] - 3.0).abs() < 0.5,
3343                "Mean of constant curves should be near 3.0, got {} at j={j}",
3344                result.mean[j]
3345            );
3346        }
3347    }
3348
3349    #[test]
3350    fn test_nan_srsf_no_panic() {
3351        let m = 20;
3352        let t = uniform_grid(m);
3353        let mut f = vec![1.0; m];
3354        f[5] = f64::NAN;
3355        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
3356        let q = srsf_transform(&mat, &t);
3357        // Should not panic; NaN propagates
3358        assert_eq!(q.nrows(), 1);
3359    }
3360
3361    #[test]
3362    fn test_n1_karcher_mean() {
3363        let m = 30;
3364        let t = uniform_grid(m);
3365        let f: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
3366        let data = FdMatrix::from_slice(&f, 1, m).unwrap();
3367        let result = karcher_mean(&data, &t, 5, 1e-4, 0.0);
3368        assert_eq!(result.mean.len(), m);
3369        // With only 1 curve, the mean should be close to the original
3370        for j in 0..m {
3371            assert!(result.mean[j].is_finite());
3372        }
3373    }
3374
3375    #[test]
3376    fn test_two_point_grid() {
3377        let t = vec![0.0, 1.0];
3378        let f1 = vec![0.0, 1.0];
3379        let f2 = vec![0.0, 2.0];
3380        let d = elastic_distance(&f1, &f2, &t, 0.0);
3381        assert!(d >= 0.0);
3382        assert!(d.is_finite());
3383    }
3384
3385    #[test]
3386    fn test_non_uniform_grid_alignment() {
3387        // Non-uniform grid: points clustered near 0
3388        let t = vec![0.0, 0.01, 0.05, 0.2, 0.5, 1.0];
3389        let m = t.len();
3390        let f1: Vec<f64> = t.iter().map(|&ti: &f64| ti.sin()).collect();
3391        let f2: Vec<f64> = t.iter().map(|&ti: &f64| (ti + 0.1).sin()).collect();
3392        let result = elastic_align_pair(&f1, &f2, &t, 0.0);
3393        assert_eq!(result.gamma.len(), m);
3394        assert!(result.distance >= 0.0);
3395        assert!(result.distance.is_finite());
3396    }
3397
3398    // ── TSRVF tests ──
3399
3400    #[test]
3401    fn test_tsrvf_output_shape() {
3402        let m = 50;
3403        let n = 10;
3404        let t = uniform_grid(m);
3405        let data = make_test_data(n, m, 42);
3406        let result = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
3407        assert_eq!(
3408            result.tangent_vectors.shape(),
3409            (n, m),
3410            "Tangent vectors should be n×m"
3411        );
3412        assert_eq!(result.gammas.shape(), (n, m), "Gammas should be n×m");
3413        assert_eq!(result.srsf_norms.len(), n, "Should have n SRSF norms");
3414        assert_eq!(result.mean.len(), m, "Mean should have m points");
3415        assert_eq!(result.mean_srsf.len(), m, "Mean SRSF should have m points");
3416    }
3417
3418    #[test]
3419    fn test_tsrvf_all_finite() {
3420        let m = 50;
3421        let n = 5;
3422        let t = uniform_grid(m);
3423        let data = make_test_data(n, m, 42);
3424        let result = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
3425        for i in 0..n {
3426            for j in 0..m {
3427                assert!(
3428                    result.tangent_vectors[(i, j)].is_finite(),
3429                    "Tangent vector should be finite at ({i},{j})"
3430                );
3431            }
3432            assert!(
3433                result.srsf_norms[i].is_finite(),
3434                "SRSF norm should be finite for curve {i}"
3435            );
3436        }
3437        assert!(
3438            result.mean_srsf_norm.is_finite(),
3439            "Mean SRSF norm should be finite"
3440        );
3441    }
3442
3443    #[test]
3444    fn test_tsrvf_identical_curves_zero_tangent() {
3445        let m = 50;
3446        let t = uniform_grid(m);
3447        // Stack 5 identical curves
3448        let curve: Vec<f64> = t
3449            .iter()
3450            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3451            .collect();
3452        let mut col_major = vec![0.0; 5 * m];
3453        for i in 0..5 {
3454            for j in 0..m {
3455                col_major[i + j * 5] = curve[j];
3456            }
3457        }
3458        let data = FdMatrix::from_column_major(col_major, 5, m).unwrap();
3459        let result = tsrvf_transform(&data, &t, 10, 1e-4, 0.0);
3460
3461        // All tangent vectors should be approximately zero
3462        for i in 0..5 {
3463            let tv_norm_sq: f64 = (0..m).map(|j| result.tangent_vectors[(i, j)].powi(2)).sum();
3464            assert!(
3465                tv_norm_sq.sqrt() < 0.5,
3466                "Identical curves should have near-zero tangent vectors, got norm = {}",
3467                tv_norm_sq.sqrt()
3468            );
3469        }
3470    }
3471
3472    #[test]
3473    fn test_tsrvf_mean_tangent_near_zero() {
3474        let m = 50;
3475        let n = 10;
3476        let t = uniform_grid(m);
3477        let data = make_test_data(n, m, 42);
3478        let result = tsrvf_transform(&data, &t, 10, 1e-3, 0.0);
3479
3480        // Mean of tangent vectors should be approximately zero (property of Karcher mean)
3481        let mut mean_tv = vec![0.0; m];
3482        for i in 0..n {
3483            for j in 0..m {
3484                mean_tv[j] += result.tangent_vectors[(i, j)];
3485            }
3486        }
3487        for j in 0..m {
3488            mean_tv[j] /= n as f64;
3489        }
3490        let mean_norm: f64 = mean_tv.iter().map(|v| v * v).sum::<f64>().sqrt();
3491        assert!(
3492            mean_norm < 1.0,
3493            "Mean tangent vector should be near zero, got norm = {mean_norm}"
3494        );
3495    }
3496
3497    #[test]
3498    fn test_tsrvf_from_alignment() {
3499        let m = 50;
3500        let n = 5;
3501        let t = uniform_grid(m);
3502        let data = make_test_data(n, m, 42);
3503        let karcher = karcher_mean(&data, &t, 5, 1e-3, 0.0);
3504        let result = tsrvf_from_alignment(&karcher, &t);
3505        assert_eq!(result.tangent_vectors.shape(), (n, m));
3506        assert!(result.mean_srsf_norm > 0.0);
3507    }
3508
3509    #[test]
3510    fn test_tsrvf_round_trip() {
3511        let m = 50;
3512        let n = 5;
3513        let t = uniform_grid(m);
3514        let data = make_test_data(n, m, 42);
3515        let result = tsrvf_transform(&data, &t, 10, 1e-3, 0.0);
3516        let reconstructed = tsrvf_inverse(&result, &t);
3517
3518        assert_eq!(reconstructed.shape(), result.tangent_vectors.shape());
3519        // Reconstructed curves should have finite values
3520        for i in 0..n {
3521            for j in 0..m {
3522                assert!(
3523                    reconstructed[(i, j)].is_finite(),
3524                    "Reconstructed curve should be finite at ({i},{j})"
3525                );
3526            }
3527        }
3528        // Issue #12: per-curve initial values should be preserved
3529        for i in 0..n {
3530            assert!(
3531                (reconstructed[(i, 0)] - result.initial_values[i]).abs() < 1e-6,
3532                "Curve {i} initial value: expected {}, got {}",
3533                result.initial_values[i],
3534                reconstructed[(i, 0)]
3535            );
3536        }
3537    }
3538
3539    #[test]
3540    fn test_tsrvf_initial_values_per_curve() {
3541        // Issue #12: tsrvf_inverse must use per-curve initial values, not mean[0]
3542        let m = 50;
3543        let n = 5;
3544        let t = uniform_grid(m);
3545
3546        // Create curves with distinct initial values
3547        let mut col_major = vec![0.0; n * m];
3548        for i in 0..n {
3549            let offset = (i as f64 + 1.0) * 2.0; // offsets: 2, 4, 6, 8, 10
3550            for j in 0..m {
3551                col_major[i + j * n] = offset + (2.0 * std::f64::consts::PI * t[j]).sin();
3552            }
3553        }
3554        let data = FdMatrix::from_column_major(col_major, n, m).unwrap();
3555
3556        let result = tsrvf_transform(&data, &t, 15, 1e-4, 0.0);
3557
3558        // initial_values should differ per curve
3559        assert_eq!(result.initial_values.len(), n);
3560        let all_same = result
3561            .initial_values
3562            .windows(2)
3563            .all(|w| (w[0] - w[1]).abs() < 1e-10);
3564        assert!(
3565            !all_same,
3566            "Initial values should differ per curve: {:?}",
3567            result.initial_values
3568        );
3569
3570        // Reconstruct and check initial values are preserved
3571        let reconstructed = tsrvf_inverse(&result, &t);
3572        for i in 0..n {
3573            assert!(
3574                (reconstructed[(i, 0)] - result.initial_values[i]).abs() < 1e-6,
3575                "Curve {i}: reconstructed f(0) = {}, expected {}",
3576                reconstructed[(i, 0)],
3577                result.initial_values[i]
3578            );
3579        }
3580
3581        // Before the fix, all curves would have reconstructed[(i, 0)] ≈ mean[0]
3582        // Verify they are NOT all the same
3583        let recon_initials: Vec<f64> = (0..n).map(|i| reconstructed[(i, 0)]).collect();
3584        let all_recon_same = recon_initials.windows(2).all(|w| (w[0] - w[1]).abs() < 0.1);
3585        assert!(
3586            !all_recon_same,
3587            "Reconstructed initial values must vary per curve: {:?}",
3588            recon_initials
3589        );
3590    }
3591
3592    #[test]
3593    fn test_tsrvf_single_curve() {
3594        let m = 50;
3595        let t = uniform_grid(m);
3596        let data = make_test_data(1, m, 42);
3597        let result = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
3598        assert_eq!(result.tangent_vectors.shape(), (1, m));
3599        // Single curve → tangent vector should be zero (it IS the mean)
3600        let tv_norm: f64 = (0..m)
3601            .map(|j| result.tangent_vectors[(0, j)].powi(2))
3602            .sum::<f64>()
3603            .sqrt();
3604        assert!(
3605            tv_norm < 0.5,
3606            "Single curve tangent vector should be near zero, got {tv_norm}"
3607        );
3608    }
3609
3610    #[test]
3611    fn test_tsrvf_constant_curves() {
3612        let m = 30;
3613        let t = uniform_grid(m);
3614        // Constant curves → SRSF = 0, norms = 0
3615        let data = FdMatrix::from_column_major(vec![5.0; 3 * m], 3, m).unwrap();
3616        let result = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
3617        // Should not produce NaN or Inf
3618        for i in 0..3 {
3619            for j in 0..m {
3620                let v = result.tangent_vectors[(i, j)];
3621                assert!(
3622                    !v.is_nan(),
3623                    "Constant curves should not produce NaN tangent vectors"
3624                );
3625            }
3626        }
3627    }
3628
3629    // ── Reference-value tests (sphere geometry) ─────────────────────────────
3630
3631    #[test]
3632    fn test_tsrvf_sphere_inv_exp_reference() {
3633        // Analytical reference: known vectors on the Hilbert sphere
3634        // psi1 = constant (unit L2 norm), psi2 = 1 + 0.3*sin(2πt) (normalized)
3635        let m = 21;
3636        let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
3637
3638        // Construct unit vector psi1 (constant)
3639        let raw1 = vec![1.0; m];
3640        let norm1 = inner_product_l2(&raw1, &raw1, &time).max(0.0).sqrt();
3641        let psi1: Vec<f64> = raw1.iter().map(|&v| v / norm1).collect();
3642
3643        // Construct psi2 with sinusoidal perturbation
3644        let raw2: Vec<f64> = time
3645            .iter()
3646            .map(|&t| 1.0 + 0.3 * (2.0 * std::f64::consts::PI * t).sin())
3647            .collect();
3648        let norm2 = inner_product_l2(&raw2, &raw2, &time).max(0.0).sqrt();
3649        let psi2: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();
3650
3651        // Compute theta analytically
3652        let ip = inner_product_l2(&psi1, &psi2, &time).clamp(-1.0, 1.0);
3653        let theta_expected = ip.acos();
3654
3655        // Compute via inv_exp_map_sphere
3656        let v = inv_exp_map_sphere(&psi1, &psi2, &time);
3657        let v_norm = inner_product_l2(&v, &v, &time).max(0.0).sqrt();
3658
3659        // ||v|| should equal theta
3660        assert!(
3661            (v_norm - theta_expected).abs() < 1e-10,
3662            "||v|| = {v_norm}, expected theta = {theta_expected}"
3663        );
3664
3665        // theta should be small but non-zero (perturbation is mild)
3666        assert!(
3667            theta_expected > 0.01 && theta_expected < 1.0,
3668            "theta = {theta_expected} out of expected range"
3669        );
3670    }
3671
3672    #[test]
3673    fn test_tsrvf_sphere_round_trip_reference() {
3674        // Round-trip: exp_map(psi1, inv_exp_map(psi1, psi2)) should recover psi2
3675        let m = 21;
3676        let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
3677
3678        let raw1 = vec![1.0; m];
3679        let norm1 = inner_product_l2(&raw1, &raw1, &time).max(0.0).sqrt();
3680        let psi1: Vec<f64> = raw1.iter().map(|&v| v / norm1).collect();
3681
3682        let raw2: Vec<f64> = time
3683            .iter()
3684            .map(|&t| 1.0 + 0.3 * (2.0 * std::f64::consts::PI * t).sin())
3685            .collect();
3686        let norm2 = inner_product_l2(&raw2, &raw2, &time).max(0.0).sqrt();
3687        let psi2: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();
3688
3689        let v = inv_exp_map_sphere(&psi1, &psi2, &time);
3690        let recovered = exp_map_sphere(&psi1, &v, &time);
3691
3692        // L2 error between psi2 and recovered
3693        let diff: Vec<f64> = psi2
3694            .iter()
3695            .zip(recovered.iter())
3696            .map(|(&a, &b)| (a - b).powi(2))
3697            .collect();
3698        let l2_err = trapz(&diff, &time).max(0.0).sqrt();
3699        assert!(
3700            l2_err < 1e-12,
3701            "Round-trip L2 error = {l2_err:.2e}, expected < 1e-12"
3702        );
3703    }
3704
3705    // ── Penalized alignment ──
3706
3707    #[test]
3708    fn test_penalized_alignment_lambda_zero_matches_unpenalized() {
3709        let m = 50;
3710        let t = uniform_grid(m);
3711        let data = make_test_data(2, m, 42);
3712        let f1 = data.row(0);
3713        let f2 = data.row(1);
3714
3715        let r0 = elastic_align_pair(&f1, &f2, &t, 0.0);
3716        // lambda = 0.0 should produce the same result regardless
3717        assert!(r0.distance >= 0.0);
3718        assert_eq!(r0.gamma.len(), m);
3719    }
3720
3721    #[test]
3722    fn test_penalized_alignment_smoother_warp() {
3723        let m = 80;
3724        let t = uniform_grid(m);
3725        let f1: Vec<f64> = t
3726            .iter()
3727            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3728            .collect();
3729        let f2: Vec<f64> = t
3730            .iter()
3731            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.15)).sin())
3732            .collect();
3733
3734        let r_free = elastic_align_pair(&f1, &f2, &t, 0.0);
3735        let r_pen = elastic_align_pair(&f1, &f2, &t, 1.0);
3736
3737        // Measure warp deviation from identity
3738        let dev_free: f64 = r_free
3739            .gamma
3740            .iter()
3741            .zip(t.iter())
3742            .map(|(g, ti)| (g - ti).powi(2))
3743            .sum();
3744        let dev_pen: f64 = r_pen
3745            .gamma
3746            .iter()
3747            .zip(t.iter())
3748            .map(|(g, ti)| (g - ti).powi(2))
3749            .sum();
3750
3751        assert!(
3752            dev_pen <= dev_free + 1e-6,
3753            "Penalized warp should be closer to identity: free={dev_free:.6}, pen={dev_pen:.6}"
3754        );
3755    }
3756
3757    #[test]
3758    fn test_penalized_alignment_large_lambda_near_identity() {
3759        let m = 50;
3760        let t = uniform_grid(m);
3761        let f1: Vec<f64> = t
3762            .iter()
3763            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3764            .collect();
3765        let f2: Vec<f64> = t
3766            .iter()
3767            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
3768            .collect();
3769
3770        let r = elastic_align_pair(&f1, &f2, &t, 1000.0);
3771
3772        // With very large lambda, warp should be very close to identity
3773        let max_dev: f64 = r
3774            .gamma
3775            .iter()
3776            .zip(t.iter())
3777            .map(|(g, ti)| (g - ti).abs())
3778            .fold(0.0_f64, f64::max);
3779        assert!(
3780            max_dev < 0.05,
3781            "Large lambda should give near-identity warp: max deviation = {max_dev}"
3782        );
3783    }
3784
3785    #[test]
3786    fn test_penalized_karcher_mean() {
3787        let m = 40;
3788        let t = uniform_grid(m);
3789        let data = make_test_data(10, m, 42);
3790
3791        let result = karcher_mean(&data, &t, 5, 1e-3, 0.5);
3792        assert_eq!(result.mean.len(), m);
3793        for j in 0..m {
3794            assert!(result.mean[j].is_finite());
3795        }
3796    }
3797
3798    // ── Phase-amplitude decomposition ──
3799
3800    #[test]
3801    fn test_decomposition_identity_curves() {
3802        let m = 50;
3803        let t = uniform_grid(m);
3804        let f: Vec<f64> = t
3805            .iter()
3806            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3807            .collect();
3808
3809        let result = elastic_decomposition(&f, &f, &t, 0.0);
3810        assert!(
3811            result.d_amplitude < 0.1,
3812            "Self-decomposition amplitude should be ~0, got {}",
3813            result.d_amplitude
3814        );
3815        assert!(
3816            result.d_phase < 0.2,
3817            "Self-decomposition phase should be ~0, got {}",
3818            result.d_phase
3819        );
3820    }
3821
3822    #[test]
3823    fn test_decomposition_pythagorean() {
3824        // d_total² ≈ d_a² + d_φ² (approximately, for the Fisher-Rao metric)
3825        let m = 80;
3826        let t = uniform_grid(m);
3827        let f1: Vec<f64> = t
3828            .iter()
3829            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3830            .collect();
3831        let f2: Vec<f64> = t
3832            .iter()
3833            .map(|&ti| 1.2 * (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
3834            .collect();
3835
3836        let result = elastic_decomposition(&f1, &f2, &t, 0.0);
3837        let da = result.d_amplitude;
3838        let dp = result.d_phase;
3839        // Both should be non-negative
3840        assert!(da >= 0.0);
3841        assert!(dp >= 0.0);
3842    }
3843
3844    #[test]
3845    fn test_phase_distance_shifted_sine() {
3846        let m = 80;
3847        let t = uniform_grid(m);
3848        let f1: Vec<f64> = t
3849            .iter()
3850            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3851            .collect();
3852        let f2: Vec<f64> = t
3853            .iter()
3854            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.15)).sin())
3855            .collect();
3856
3857        let dp = phase_distance_pair(&f1, &f2, &t, 0.0);
3858        assert!(
3859            dp > 0.01,
3860            "Phase distance of shifted curves should be > 0, got {dp}"
3861        );
3862    }
3863
3864    #[test]
3865    fn test_amplitude_distance_scaled_curve() {
3866        let m = 80;
3867        let t = uniform_grid(m);
3868        let f1: Vec<f64> = t
3869            .iter()
3870            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
3871            .collect();
3872        let f2: Vec<f64> = t
3873            .iter()
3874            .map(|&ti| 2.0 * (2.0 * std::f64::consts::PI * ti).sin())
3875            .collect();
3876
3877        let da = amplitude_distance(&f1, &f2, &t, 0.0);
3878        assert!(
3879            da > 0.01,
3880            "Amplitude distance of scaled curves should be > 0, got {da}"
3881        );
3882    }
3883
3884    #[test]
3885    fn test_phase_distance_nonneg() {
3886        let data = make_test_data(4, 40, 42);
3887        let t = uniform_grid(40);
3888        for i in 0..4 {
3889            for j in 0..4 {
3890                let fi = data.row(i);
3891                let fj = data.row(j);
3892                let dp = phase_distance_pair(&fi, &fj, &t, 0.0);
3893                assert!(dp >= 0.0, "Phase distance should be non-negative");
3894            }
3895        }
3896    }
3897
3898    // ── Parallel transport ──
3899
3900    #[test]
3901    fn test_schilds_ladder_zero_vector() {
3902        let m = 21;
3903        let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
3904        let raw = vec![1.0; m];
3905        let norm = crate::warping::l2_norm_l2(&raw, &time);
3906        let from: Vec<f64> = raw.iter().map(|&v| v / norm).collect();
3907        let raw2: Vec<f64> = time
3908            .iter()
3909            .map(|&t| 1.0 + 0.2 * (2.0 * std::f64::consts::PI * t).sin())
3910            .collect();
3911        let norm2 = crate::warping::l2_norm_l2(&raw2, &time);
3912        let to: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();
3913
3914        let zero = vec![0.0; m];
3915        let result = parallel_transport_schilds(&zero, &from, &to, &time);
3916        let result_norm: f64 = result.iter().map(|v| v * v).sum::<f64>().sqrt();
3917        assert!(
3918            result_norm < 1e-6,
3919            "Transporting zero should give zero, got norm {result_norm}"
3920        );
3921    }
3922
3923    #[test]
3924    fn test_pole_ladder_zero_vector() {
3925        let m = 21;
3926        let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
3927        let raw = vec![1.0; m];
3928        let norm = crate::warping::l2_norm_l2(&raw, &time);
3929        let from: Vec<f64> = raw.iter().map(|&v| v / norm).collect();
3930        let raw2: Vec<f64> = time
3931            .iter()
3932            .map(|&t| 1.0 + 0.2 * (2.0 * std::f64::consts::PI * t).sin())
3933            .collect();
3934        let norm2 = crate::warping::l2_norm_l2(&raw2, &time);
3935        let to: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();
3936
3937        let zero = vec![0.0; m];
3938        let result = parallel_transport_pole(&zero, &from, &to, &time);
3939        let result_norm: f64 = result.iter().map(|v| v * v).sum::<f64>().sqrt();
3940        assert!(
3941            result_norm < 1e-6,
3942            "Transporting zero should give zero, got norm {result_norm}"
3943        );
3944    }
3945
3946    #[test]
3947    fn test_schilds_preserves_norm() {
3948        let m = 51;
3949        let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
3950        let raw = vec![1.0; m];
3951        let norm = crate::warping::l2_norm_l2(&raw, &time);
3952        let from: Vec<f64> = raw.iter().map(|&v| v / norm).collect();
3953        let raw2: Vec<f64> = time
3954            .iter()
3955            .map(|&t| 1.0 + 0.15 * (2.0 * std::f64::consts::PI * t).sin())
3956            .collect();
3957        let norm2 = crate::warping::l2_norm_l2(&raw2, &time);
3958        let to: Vec<f64> = raw2.iter().map(|&v| v / norm2).collect();
3959
3960        // Small tangent vector
3961        let v: Vec<f64> = time
3962            .iter()
3963            .map(|&t| 0.1 * (4.0 * std::f64::consts::PI * t).cos())
3964            .collect();
3965        let v_norm = crate::warping::l2_norm_l2(&v, &time);
3966
3967        let transported = parallel_transport_schilds(&v, &from, &to, &time);
3968        let t_norm = crate::warping::l2_norm_l2(&transported, &time);
3969
3970        // Norm should be approximately preserved (ladder methods are first-order)
3971        assert!(
3972            (t_norm - v_norm).abs() / v_norm.max(1e-10) < 1.5,
3973            "Schild's should roughly preserve norm: original={v_norm:.4}, transported={t_norm:.4}"
3974        );
3975    }
3976
3977    #[test]
3978    fn test_tsrvf_logmap_matches_original() {
3979        let m = 50;
3980        let n = 5;
3981        let t = uniform_grid(m);
3982        let data = make_test_data(n, m, 42);
3983
3984        let result_orig = tsrvf_transform(&data, &t, 5, 1e-3, 0.0);
3985        let result_logmap =
3986            tsrvf_transform_with_method(&data, &t, 5, 1e-3, 0.0, TransportMethod::LogMap);
3987
3988        // Should be identical (LogMap delegates to original)
3989        for i in 0..n {
3990            for j in 0..m {
3991                assert!(
3992                    (result_orig.tangent_vectors[(i, j)] - result_logmap.tangent_vectors[(i, j)])
3993                        .abs()
3994                        < 1e-12,
3995                    "LogMap variant should match original at ({i},{j})"
3996                );
3997            }
3998        }
3999    }
4000
4001    #[test]
4002    fn test_tsrvf_with_schilds_produces_valid_result() {
4003        let m = 50;
4004        let n = 5;
4005        let t = uniform_grid(m);
4006        let data = make_test_data(n, m, 42);
4007
4008        let result =
4009            tsrvf_transform_with_method(&data, &t, 5, 1e-3, 0.0, TransportMethod::SchildsLadder);
4010
4011        assert_eq!(result.tangent_vectors.shape(), (n, m));
4012        for i in 0..n {
4013            for j in 0..m {
4014                assert!(
4015                    result.tangent_vectors[(i, j)].is_finite(),
4016                    "Schild's TSRVF should produce finite tangent vectors at ({i},{j})"
4017                );
4018            }
4019        }
4020    }
4021
4022    #[test]
4023    fn test_transport_methods_differ() {
4024        let m = 50;
4025        let n = 5;
4026        let t = uniform_grid(m);
4027        let data = make_test_data(n, m, 42);
4028        let karcher = karcher_mean(&data, &t, 5, 1e-3, 0.0);
4029
4030        let r_log = tsrvf_from_alignment_with_method(&karcher, &t, TransportMethod::LogMap);
4031        let r_schilds =
4032            tsrvf_from_alignment_with_method(&karcher, &t, TransportMethod::SchildsLadder);
4033
4034        // Methods should produce different (but related) tangent vectors
4035        let mut total_diff = 0.0;
4036        for i in 0..n {
4037            for j in 0..m {
4038                total_diff +=
4039                    (r_log.tangent_vectors[(i, j)] - r_schilds.tangent_vectors[(i, j)]).abs();
4040            }
4041        }
4042
4043        // They should be non-zero different (unless all curves are identical to mean)
4044        // Just check both produce finite results
4045        assert!(total_diff.is_finite());
4046    }
4047
4048    // ── Alignment quality metrics ──
4049
4050    #[test]
4051    fn test_warp_complexity_identity_is_zero() {
4052        let m = 50;
4053        let t = uniform_grid(m);
4054        let identity = t.clone();
4055        let c = warp_complexity(&identity, &t);
4056        assert!(
4057            c < 1e-10,
4058            "Identity warp should have zero complexity, got {c}"
4059        );
4060    }
4061
4062    #[test]
4063    fn test_warp_complexity_nonidentity_positive() {
4064        let m = 50;
4065        let t = uniform_grid(m);
4066        let gamma: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
4067        let c = warp_complexity(&gamma, &t);
4068        assert!(
4069            c > 0.01,
4070            "Non-identity warp should have positive complexity, got {c}"
4071        );
4072    }
4073
4074    #[test]
4075    fn test_warp_smoothness_identity_is_zero() {
4076        let m = 50;
4077        let t = uniform_grid(m);
4078        let identity = t.clone();
4079        let s = warp_smoothness(&identity, &t);
4080        assert!(
4081            s < 1e-6,
4082            "Identity warp (constant γ'=1, γ''=0) should have near-zero bending energy, got {s}"
4083        );
4084    }
4085
4086    #[test]
4087    fn test_alignment_quality_basic() {
4088        let m = 50;
4089        let n = 8;
4090        let t = uniform_grid(m);
4091        let data = make_test_data(n, m, 42);
4092        let karcher = karcher_mean(&data, &t, 10, 1e-3, 0.0);
4093        let quality = alignment_quality(&data, &karcher, &t);
4094
4095        // Shape checks
4096        assert_eq!(quality.warp_complexity.len(), n);
4097        assert_eq!(quality.warp_smoothness.len(), n);
4098        assert_eq!(quality.pointwise_variance_ratio.len(), m);
4099
4100        // Non-negativity
4101        assert!(quality.total_variance >= 0.0);
4102        assert!(quality.amplitude_variance >= 0.0);
4103        assert!(quality.phase_variance >= 0.0);
4104        assert!(quality.mean_warp_complexity >= 0.0);
4105        assert!(quality.mean_warp_smoothness >= 0.0);
4106
4107        // Amplitude variance ≤ total variance
4108        assert!(
4109            quality.amplitude_variance <= quality.total_variance + 1e-10,
4110            "Amplitude variance ({}) should be ≤ total variance ({})",
4111            quality.amplitude_variance,
4112            quality.total_variance
4113        );
4114    }
4115
4116    #[test]
4117    fn test_alignment_quality_identical_curves() {
4118        let m = 50;
4119        let t = uniform_grid(m);
4120        let curve: Vec<f64> = t
4121            .iter()
4122            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
4123            .collect();
4124        let mut col_major = vec![0.0; 5 * m];
4125        for i in 0..5 {
4126            for j in 0..m {
4127                col_major[i + j * 5] = curve[j];
4128            }
4129        }
4130        let data = FdMatrix::from_column_major(col_major, 5, m).unwrap();
4131        let karcher = karcher_mean(&data, &t, 5, 1e-3, 0.0);
4132        let quality = alignment_quality(&data, &karcher, &t);
4133
4134        // Identical curves → near-zero variances and warp complexities
4135        assert!(
4136            quality.total_variance < 0.01,
4137            "Identical curves should have near-zero total variance, got {}",
4138            quality.total_variance
4139        );
4140        assert!(
4141            quality.mean_warp_complexity < 0.1,
4142            "Identical curves should have near-zero warp complexity, got {}",
4143            quality.mean_warp_complexity
4144        );
4145    }
4146
4147    #[test]
4148    fn test_alignment_quality_variance_reduction() {
4149        let m = 50;
4150        let n = 10;
4151        let t = uniform_grid(m);
4152        let data = make_test_data(n, m, 42);
4153        let karcher = karcher_mean(&data, &t, 10, 1e-3, 0.0);
4154        let quality = alignment_quality(&data, &karcher, &t);
4155
4156        // Mean variance ratio should be ≤ ~1 (alignment shouldn't increase variance)
4157        assert!(
4158            quality.mean_variance_reduction <= 1.5,
4159            "Mean variance reduction ratio should be ≤ ~1, got {}",
4160            quality.mean_variance_reduction
4161        );
4162    }
4163
4164    #[test]
4165    fn test_pairwise_consistency_small() {
4166        let m = 40;
4167        let n = 4;
4168        let t = uniform_grid(m);
4169        let data = make_test_data(n, m, 42);
4170
4171        let consistency = pairwise_consistency(&data, &t, 0.0, 100);
4172        assert!(
4173            consistency.is_finite() && consistency >= 0.0,
4174            "Pairwise consistency should be finite and non-negative, got {consistency}"
4175        );
4176    }
4177
4178    // ── Multidimensional SRSF ──
4179
4180    #[test]
4181    fn test_srsf_nd_d1_matches_existing() {
4182        let m = 50;
4183        let t = uniform_grid(m);
4184        let data = make_test_data(3, m, 42);
4185
4186        // 1D via existing function
4187        let q_1d = srsf_transform(&data, &t);
4188
4189        // 1D via nd function
4190        let data_nd = FdCurveSet::from_1d(data);
4191        let q_nd = srsf_transform_nd(&data_nd, &t);
4192
4193        assert_eq!(q_nd.ndim(), 1);
4194        for i in 0..3 {
4195            for j in 0..m {
4196                assert!(
4197                    (q_1d[(i, j)] - q_nd.dims[0][(i, j)]).abs() < 1e-10,
4198                    "1D nd SRSF should match existing at ({i},{j}): {} vs {}",
4199                    q_1d[(i, j)],
4200                    q_nd.dims[0][(i, j)]
4201                );
4202            }
4203        }
4204    }
4205
4206    #[test]
4207    fn test_srsf_nd_constant_is_zero() {
4208        let m = 30;
4209        let t = uniform_grid(m);
4210        // Constant R^2 curve: f(t) = (3.0, -1.0)
4211        let dim0 = FdMatrix::from_column_major(vec![3.0; m], 1, m).unwrap();
4212        let dim1 = FdMatrix::from_column_major(vec![-1.0; m], 1, m).unwrap();
4213        let data = FdCurveSet::from_dims(vec![dim0, dim1]).unwrap();
4214
4215        let q = srsf_transform_nd(&data, &t);
4216        for k in 0..2 {
4217            for j in 0..m {
4218                assert!(
4219                    q.dims[k][(0, j)].abs() < 1e-10,
4220                    "Constant curve SRSF should be zero, dim {k} at {j}: {}",
4221                    q.dims[k][(0, j)]
4222                );
4223            }
4224        }
4225    }
4226
4227    #[test]
4228    fn test_srsf_nd_linear_r2() {
4229        let m = 51;
4230        let t = uniform_grid(m);
4231        // f(t) = (2t, 3t) → f'(t) = (2, 3), ||f'|| = sqrt(13)
4232        // q(t) = (2, 3) / sqrt(sqrt(13)) = (2, 3) / 13^(1/4)
4233        let dim0 =
4234            FdMatrix::from_slice(&t.iter().map(|&ti| 2.0 * ti).collect::<Vec<_>>(), 1, m).unwrap();
4235        let dim1 =
4236            FdMatrix::from_slice(&t.iter().map(|&ti| 3.0 * ti).collect::<Vec<_>>(), 1, m).unwrap();
4237        let data = FdCurveSet::from_dims(vec![dim0, dim1]).unwrap();
4238
4239        let q = srsf_transform_nd(&data, &t);
4240        let expected_scale = 1.0 / 13.0_f64.powf(0.25);
4241        let mid = m / 2;
4242
4243        assert!(
4244            (q.dims[0][(0, mid)] - 2.0 * expected_scale).abs() < 0.1,
4245            "q_x at midpoint: {} vs expected {}",
4246            q.dims[0][(0, mid)],
4247            2.0 * expected_scale
4248        );
4249        assert!(
4250            (q.dims[1][(0, mid)] - 3.0 * expected_scale).abs() < 0.1,
4251            "q_y at midpoint: {} vs expected {}",
4252            q.dims[1][(0, mid)],
4253            3.0 * expected_scale
4254        );
4255    }
4256
4257    #[test]
4258    fn test_srsf_nd_round_trip() {
4259        let m = 51;
4260        let t = uniform_grid(m);
4261        // f(t) = (sin(2πt), cos(2πt))
4262        let pi2 = 2.0 * std::f64::consts::PI;
4263        let vals_x: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).sin()).collect();
4264        let vals_y: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).cos()).collect();
4265        let dim0 = FdMatrix::from_slice(&vals_x, 1, m).unwrap();
4266        let dim1 = FdMatrix::from_slice(&vals_y, 1, m).unwrap();
4267        let data = FdCurveSet::from_dims(vec![dim0, dim1]).unwrap();
4268
4269        let q = srsf_transform_nd(&data, &t);
4270        let q_vecs: Vec<Vec<f64>> = q.dims.iter().map(|dm| dm.row(0)).collect();
4271        let f0 = vec![vals_x[0], vals_y[0]];
4272        let recon = srsf_inverse_nd(&q_vecs, &t, &f0);
4273
4274        // Check reconstruction error (skip boundary points)
4275        let mut max_err = 0.0_f64;
4276        for k in 0..2 {
4277            let orig = if k == 0 { &vals_x } else { &vals_y };
4278            for j in 2..(m - 2) {
4279                let err = (recon[k][j] - orig[j]).abs();
4280                max_err = max_err.max(err);
4281            }
4282        }
4283        assert!(
4284            max_err < 0.2,
4285            "SRSF round-trip max error should be small, got {max_err}"
4286        );
4287    }
4288
4289    #[test]
4290    fn test_align_nd_identical_near_zero() {
4291        let m = 50;
4292        let t = uniform_grid(m);
4293        let pi2 = 2.0 * std::f64::consts::PI;
4294        let vals_x: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).sin()).collect();
4295        let vals_y: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).cos()).collect();
4296        let dim0 = FdMatrix::from_slice(&vals_x, 1, m).unwrap();
4297        let dim1 = FdMatrix::from_slice(&vals_y, 1, m).unwrap();
4298        let data = FdCurveSet::from_dims(vec![dim0, dim1]).unwrap();
4299
4300        let result = elastic_align_pair_nd(&data, &data, &t, 0.0);
4301        assert!(
4302            result.distance < 0.5,
4303            "Self-alignment distance should be ~0, got {}",
4304            result.distance
4305        );
4306        // Gamma should be near identity
4307        let max_dev: f64 = result
4308            .gamma
4309            .iter()
4310            .zip(t.iter())
4311            .map(|(g, ti)| (g - ti).abs())
4312            .fold(0.0_f64, f64::max);
4313        assert!(
4314            max_dev < 0.1,
4315            "Self-alignment warp should be near identity, max dev = {max_dev}"
4316        );
4317    }
4318
4319    #[test]
4320    fn test_align_nd_shifted_r2() {
4321        let m = 60;
4322        let t = uniform_grid(m);
4323        let pi2 = 2.0 * std::f64::consts::PI;
4324
4325        // f1(t) = (sin(2πt), cos(2πt))
4326        let f1x: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).sin()).collect();
4327        let f1y: Vec<f64> = t.iter().map(|&ti| (pi2 * ti).cos()).collect();
4328        let f1 = FdCurveSet::from_dims(vec![
4329            FdMatrix::from_slice(&f1x, 1, m).unwrap(),
4330            FdMatrix::from_slice(&f1y, 1, m).unwrap(),
4331        ])
4332        .unwrap();
4333
4334        // f2(t) = (sin(2π(t-0.1)), cos(2π(t-0.1))) — phase shifted
4335        let f2x: Vec<f64> = t.iter().map(|&ti| (pi2 * (ti - 0.1)).sin()).collect();
4336        let f2y: Vec<f64> = t.iter().map(|&ti| (pi2 * (ti - 0.1)).cos()).collect();
4337        let f2 = FdCurveSet::from_dims(vec![
4338            FdMatrix::from_slice(&f2x, 1, m).unwrap(),
4339            FdMatrix::from_slice(&f2y, 1, m).unwrap(),
4340        ])
4341        .unwrap();
4342
4343        let result = elastic_align_pair_nd(&f1, &f2, &t, 0.0);
4344        assert!(
4345            result.distance.is_finite(),
4346            "Distance should be finite, got {}",
4347            result.distance
4348        );
4349        assert_eq!(result.f_aligned.len(), 2);
4350        assert_eq!(result.f_aligned[0].len(), m);
4351        // Warp should deviate from identity (non-trivial alignment)
4352        let max_dev: f64 = result
4353            .gamma
4354            .iter()
4355            .zip(t.iter())
4356            .map(|(g, ti)| (g - ti).abs())
4357            .fold(0.0_f64, f64::max);
4358        assert!(
4359            max_dev > 0.01,
4360            "Shifted curves should require non-trivial warp, max dev = {max_dev}"
4361        );
4362    }
4363
4364    // ── Landmark-constrained alignment ──
4365
4366    #[test]
4367    fn test_constrained_no_landmarks_matches_unconstrained() {
4368        let m = 50;
4369        let t = uniform_grid(m);
4370        let f1: Vec<f64> = t
4371            .iter()
4372            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
4373            .collect();
4374        let f2: Vec<f64> = t
4375            .iter()
4376            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
4377            .collect();
4378
4379        let r_free = elastic_align_pair(&f1, &f2, &t, 0.0);
4380        let r_const = elastic_align_pair_constrained(&f1, &f2, &t, &[], 0.0);
4381
4382        // Should match unconstrained
4383        for j in 0..m {
4384            assert!(
4385                (r_free.gamma[j] - r_const.gamma[j]).abs() < 1e-10,
4386                "No-landmark constrained should match unconstrained at {j}"
4387            );
4388        }
4389        assert!(r_const.enforced_landmarks.is_empty());
4390    }
4391
4392    #[test]
4393    fn test_constrained_single_landmark_enforced() {
4394        let m = 60;
4395        let t = uniform_grid(m);
4396        let f1: Vec<f64> = t
4397            .iter()
4398            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
4399            .collect();
4400        let f2: Vec<f64> = t
4401            .iter()
4402            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
4403            .collect();
4404
4405        // Constrain midpoint: target_t=0.5 should map to source_t=0.5
4406        let result = elastic_align_pair_constrained(&f1, &f2, &t, &[(0.5, 0.5)], 0.0);
4407
4408        // Gamma at the midpoint should be close to 0.5
4409        let mid_idx = snap_to_grid(0.5, &t);
4410        assert!(
4411            (result.gamma[mid_idx] - 0.5).abs() < 0.05,
4412            "Constrained gamma at midpoint should be ~0.5, got {}",
4413            result.gamma[mid_idx]
4414        );
4415        assert_eq!(result.enforced_landmarks.len(), 1);
4416    }
4417
4418    #[test]
4419    fn test_constrained_multiple_landmarks() {
4420        let m = 80;
4421        let t = uniform_grid(m);
4422        let f1: Vec<f64> = t
4423            .iter()
4424            .map(|&ti| (4.0 * std::f64::consts::PI * ti).sin())
4425            .collect();
4426        let f2: Vec<f64> = t
4427            .iter()
4428            .map(|&ti| (4.0 * std::f64::consts::PI * (ti - 0.05)).sin())
4429            .collect();
4430
4431        let landmarks = vec![(0.25, 0.25), (0.5, 0.5), (0.75, 0.75)];
4432        let result = elastic_align_pair_constrained(&f1, &f2, &t, &landmarks, 0.0);
4433
4434        // Gamma should pass through (or near) each landmark
4435        for &(tt, st) in &landmarks {
4436            let idx = snap_to_grid(tt, &t);
4437            assert!(
4438                (result.gamma[idx] - st).abs() < 0.05,
4439                "Gamma at t={tt} should be ~{st}, got {}",
4440                result.gamma[idx]
4441            );
4442        }
4443    }
4444
4445    #[test]
4446    fn test_constrained_monotone_gamma() {
4447        let m = 60;
4448        let t = uniform_grid(m);
4449        let f1: Vec<f64> = t
4450            .iter()
4451            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
4452            .collect();
4453        let f2: Vec<f64> = t
4454            .iter()
4455            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
4456            .collect();
4457
4458        let result = elastic_align_pair_constrained(&f1, &f2, &t, &[(0.3, 0.3), (0.7, 0.7)], 0.0);
4459
4460        // Gamma should be non-decreasing
4461        for j in 1..m {
4462            assert!(
4463                result.gamma[j] >= result.gamma[j - 1] - 1e-10,
4464                "Gamma should be monotone: gamma[{}]={} < gamma[{}]={}",
4465                j,
4466                result.gamma[j],
4467                j - 1,
4468                result.gamma[j - 1]
4469            );
4470        }
4471        // Boundary conditions
4472        assert!((result.gamma[0] - t[0]).abs() < 1e-10);
4473        assert!((result.gamma[m - 1] - t[m - 1]).abs() < 1e-10);
4474    }
4475
4476    #[test]
4477    fn test_constrained_distance_ge_unconstrained() {
4478        let m = 60;
4479        let t = uniform_grid(m);
4480        let f1: Vec<f64> = t
4481            .iter()
4482            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
4483            .collect();
4484        let f2: Vec<f64> = t
4485            .iter()
4486            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.15)).sin())
4487            .collect();
4488
4489        let r_free = elastic_align_pair(&f1, &f2, &t, 0.0);
4490        let r_const = elastic_align_pair_constrained(&f1, &f2, &t, &[(0.5, 0.5)], 0.0);
4491
4492        // Constrained distance should be >= unconstrained (constraints reduce freedom)
4493        assert!(
4494            r_const.distance >= r_free.distance - 1e-6,
4495            "Constrained distance ({}) should be >= unconstrained ({})",
4496            r_const.distance,
4497            r_free.distance
4498        );
4499    }
4500
4501    #[test]
4502    fn test_constrained_with_landmark_detection() {
4503        let m = 80;
4504        let t = uniform_grid(m);
4505        let f1: Vec<f64> = t
4506            .iter()
4507            .map(|&ti| (4.0 * std::f64::consts::PI * ti).sin())
4508            .collect();
4509        let f2: Vec<f64> = t
4510            .iter()
4511            .map(|&ti| (4.0 * std::f64::consts::PI * (ti - 0.05)).sin())
4512            .collect();
4513
4514        let result = elastic_align_pair_with_landmarks(
4515            &f1,
4516            &f2,
4517            &t,
4518            crate::landmark::LandmarkKind::Peak,
4519            0.1,
4520            0,
4521            0.0,
4522        );
4523
4524        assert_eq!(result.gamma.len(), m);
4525        assert_eq!(result.f_aligned.len(), m);
4526        assert!(result.distance.is_finite());
4527        // Should be monotone
4528        for j in 1..m {
4529            assert!(
4530                result.gamma[j] >= result.gamma[j - 1] - 1e-10,
4531                "Gamma should be monotone at j={j}"
4532            );
4533        }
4534    }
4535
4536    // ── SRSF smoothing for TSRVF (Issue #13) ──
4537
4538    #[test]
4539    fn test_gam_to_psi_smooth_identity() {
4540        // Smoothed psi of identity warp should stay close to constant 1 in the interior.
4541        // Boundary points are biased by kernel smoothing (fewer neighbors), skip them.
4542        use crate::warping::{gam_to_psi, gam_to_psi_smooth};
4543        let m = 101;
4544        let h = 1.0 / (m - 1) as f64;
4545        let gam: Vec<f64> = uniform_grid(m);
4546        let psi_raw = gam_to_psi(&gam, h);
4547        let psi_smooth = gam_to_psi_smooth(&gam, h);
4548        // Check interior points (skip ~5% at each boundary)
4549        let skip = m / 20;
4550        for j in skip..(m - skip) {
4551            assert!(
4552                (psi_smooth[j] - 1.0).abs() < 0.05,
4553                "Smoothed psi of identity should be ~1.0, got {} at j={}",
4554                psi_smooth[j],
4555                j
4556            );
4557            assert!(
4558                (psi_smooth[j] - psi_raw[j]).abs() < 0.05,
4559                "Smoothed and raw psi should agree on smooth warp at j={}",
4560                j
4561            );
4562        }
4563    }
4564
4565    #[test]
4566    fn test_gam_to_psi_smooth_reduces_spikes() {
4567        // Create a kinky warp (simulating DP output with multiple slope changes)
4568        // and verify smoothing reduces psi spikes
4569        use crate::warping::{gam_to_psi, gam_to_psi_smooth};
4570        let m = 101;
4571        let h = 1.0 / (m - 1) as f64;
4572        let argvals = uniform_grid(m);
4573        // Multi-segment piecewise-linear warp with several kinks
4574        let mut gam: Vec<f64> = Vec::with_capacity(m);
4575        for j in 0..m {
4576            let t = argvals[j];
4577            // Three segments: slow (slope 0.5), fast (slope 2), slow (slope 0.5)
4578            let g = if t < 0.33 {
4579                t * 0.5 / 0.33
4580            } else if t < 0.67 {
4581                0.5 + (t - 0.33) * 0.5 / 0.34 * 2.0 // steeper
4582            } else {
4583                let base = 0.5 + 0.5 / 0.34 * 2.0 * 0.34; // ~1.5 but clamped
4584                (base + (t - 0.67) * 0.5 / 0.33).min(1.0)
4585            };
4586            gam.push(g.min(1.0));
4587        }
4588        // Normalize to [0,1]
4589        let gmax = gam[m - 1].max(1e-10);
4590        for g in &mut gam {
4591            *g /= gmax;
4592        }
4593        let psi_raw = gam_to_psi(&gam, h);
4594        let psi_smooth = gam_to_psi_smooth(&gam, h);
4595        // The raw psi should have jumps at kink points (slope transitions)
4596        let max_jump_raw: f64 = psi_raw
4597            .windows(2)
4598            .map(|w| (w[1] - w[0]).abs())
4599            .fold(0.0_f64, f64::max);
4600        let max_jump_smooth: f64 = psi_smooth
4601            .windows(2)
4602            .map(|w| (w[1] - w[0]).abs())
4603            .fold(0.0_f64, f64::max);
4604        // Smoothing should reduce the maximum jump in psi
4605        assert!(
4606            max_jump_smooth < max_jump_raw + 0.01,
4607            "Smoothing should not increase max psi jump: raw={max_jump_raw:.4}, smooth={max_jump_smooth:.4}"
4608        );
4609    }
4610
4611    #[test]
4612    fn test_smooth_aligned_srsfs_preserves_shape() {
4613        // Smoothing aligned SRSFs should preserve overall shape
4614        use crate::smoothing::nadaraya_watson;
4615        let m = 101;
4616        let time: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
4617        // Create a smooth SRSF (sine curve)
4618        let qi: Vec<f64> = time
4619            .iter()
4620            .map(|&t| (2.0 * std::f64::consts::PI * t).sin())
4621            .collect();
4622        let bandwidth = 2.0 / (m - 1) as f64;
4623        let qi_smooth = nadaraya_watson(&time, &qi, &time, bandwidth, "gaussian");
4624        // Correlation between original and smoothed should be very high
4625        let mean_orig: f64 = qi.iter().sum::<f64>() / m as f64;
4626        let mean_smooth: f64 = qi_smooth.iter().sum::<f64>() / m as f64;
4627        let mut cov = 0.0;
4628        let mut var_o = 0.0;
4629        let mut var_s = 0.0;
4630        for j in 0..m {
4631            let do_ = qi[j] - mean_orig;
4632            let ds = qi_smooth[j] - mean_smooth;
4633            cov += do_ * ds;
4634            var_o += do_ * do_;
4635            var_s += ds * ds;
4636        }
4637        let rho = cov / (var_o * var_s).sqrt().max(1e-10);
4638        assert!(
4639            rho > 0.99,
4640            "Smoothed SRSF should be highly correlated with original (rho={rho:.4})"
4641        );
4642    }
4643
4644    #[test]
4645    fn test_tsrvf_tangent_vectors_no_spikes() {
4646        // End-to-end: compute TSRVF tangent vectors, verify no element dominates.
4647        // The smooth_aligned_srsfs step in tsrvf_from_alignment removes DP kink
4648        // artifacts that would otherwise produce spike outliers in tangent vectors.
4649        let m = 101;
4650        let argvals = uniform_grid(m);
4651        let data = make_test_data(10, m, 42);
4652        let result = tsrvf_transform(&data, &argvals, 5, 1e-3, 0.0);
4653        let (n, _) = result.tangent_vectors.shape();
4654        for i in 0..n {
4655            let vi = result.tangent_vectors.row(i);
4656            let rms = (vi.iter().map(|&v| v * v).sum::<f64>() / m as f64).sqrt();
4657            if rms > 1e-10 {
4658                let max_abs = vi.iter().map(|&v| v.abs()).fold(0.0_f64, f64::max);
4659                assert!(
4660                    max_abs < 10.0 * rms,
4661                    "Tangent vector {} has spike: max |v| = {max_abs:.4}, rms = {rms:.4}, ratio = {:.1}",
4662                    i,
4663                    max_abs / rms
4664                );
4665            }
4666        }
4667    }
4668}