Skip to main content

fdars_core/
alignment.rs

1//! Elastic alignment and SRSF (Square-Root Slope Function) transforms.
2//!
3//! This module provides phase-amplitude separation for functional data via
4//! the elastic framework. Key capabilities:
5//!
6//! - [`srsf_transform`] / [`srsf_inverse`] — SRSF representation and reconstruction
7//! - [`elastic_align_pair`] — Pairwise curve alignment via dynamic programming
8//! - [`elastic_distance`] — Elastic (Fisher-Rao) distance between curves
9//! - [`align_to_target`] — Align a set of curves to a common target
10//! - [`karcher_mean`] — Karcher (Fréchet) mean in the elastic metric
11//! - [`elastic_self_distance_matrix`] / [`elastic_cross_distance_matrix`] — Distance matrices
12//! - [`reparameterize_curve`] / [`compose_warps`] — Warping utilities
13
14use crate::fdata::{deriv_1d, mean_1d};
15use crate::helpers::{l2_distance, simpsons_weights};
16use crate::iter_maybe_parallel;
17use crate::matrix::FdMatrix;
18#[cfg(feature = "parallel")]
19use rayon::iter::ParallelIterator;
20
21// ─── Types ──────────────────────────────────────────────────────────────────
22
/// Outcome of elastically aligning a single curve onto a target curve.
#[derive(Debug, Clone)]
pub struct AlignmentResult {
    /// Warping function γ, sampled on the evaluation grid, mapping the domain onto itself.
    pub gamma: Vec<f64>,
    /// The warped curve, i.e. the input curve evaluated at γ.
    pub f_aligned: Vec<f64>,
    /// Elastic distance remaining between the target and the warped curve.
    pub distance: f64,
}
33
34/// Result of aligning a set of curves to a common target.
35#[derive(Debug, Clone)]
36pub struct AlignmentSetResult {
37    /// Warping functions (n × m).
38    pub gammas: FdMatrix,
39    /// Aligned curves (n × m).
40    pub aligned_data: FdMatrix,
41    /// Elastic distances for each curve.
42    pub distances: Vec<f64>,
43}
44
45/// Result of the Karcher mean computation.
46#[derive(Debug, Clone)]
47pub struct KarcherMeanResult {
48    /// Karcher mean curve.
49    pub mean: Vec<f64>,
50    /// SRSF of the Karcher mean.
51    pub mean_srsf: Vec<f64>,
52    /// Final warping functions (n × m).
53    pub gammas: FdMatrix,
54    /// Curves aligned to the mean (n × m).
55    pub aligned_data: FdMatrix,
56    /// Number of iterations used.
57    pub n_iter: usize,
58    /// Whether the algorithm converged.
59    pub converged: bool,
60}
61
62// ─── Private helpers ────────────────────────────────────────────────────────
63
/// Piecewise-linear interpolation of the table `(x, y)` at the query point `t`.
///
/// `x` is expected to be sorted in non-decreasing order, with `y` at least as long as `x`.
/// Queries at or beyond either end of `x` return the matching end value of `y`
/// (constant extrapolation). A query landing exactly on a knot returns that knot's value.
/// Panics if `x` is empty or `t` is NaN (the comparison below unwraps `partial_cmp`).
fn linear_interp(x: &[f64], y: &[f64], t: f64) -> f64 {
    if t <= x[0] {
        return y[0];
    }
    let last = x.len() - 1;
    if t >= x[last] {
        return y[last];
    }

    // `Err(k)` is the insertion point, so x[k - 1] < t < x[k]; exact hits return directly.
    let k = match x.binary_search_by(|probe| probe.partial_cmp(&t).unwrap()) {
        Ok(hit) => return y[hit],
        Err(insert_at) => insert_at,
    };

    let (xa, xb) = (x[k - 1], x[k]);
    let (ya, yb) = (y[k - 1], y[k]);
    ya + (yb - ya) * (t - xa) / (xb - xa)
}
86
/// Running trapezoidal integral of `y` over the abscissae `x`.
///
/// Element `k` of the result is the integral from `x[0]` to `x[k]`, so the first
/// element is always `0.0`. The output has the same length as `y`.
fn cumulative_trapz(y: &[f64], x: &[f64]) -> Vec<f64> {
    let mut acc = Vec::with_capacity(y.len());
    let mut running = 0.0_f64;
    for (k, &yk) in y.iter().enumerate() {
        if k > 0 {
            running += 0.5 * (yk + y[k - 1]) * (x[k] - x[k - 1]);
        }
        acc.push(running);
    }
    acc
}
96
/// Ensure γ is a valid warping: monotone non-decreasing, with correct boundary values.
///
/// After this call:
/// - `gamma[0] == argvals[0]` and `gamma[n - 1] == argvals[n - 1]` (with `n = gamma.len()`),
/// - every value lies in `[argvals[0], argvals[n - 1]]`,
/// - values never decrease from one index to the next.
///
/// Interior values outside the domain are clamped into it before enforcing monotonicity.
/// Without that step, an interior overshoot would be propagated by the running maximum
/// and would push the final value past `argvals[n - 1]`.
/// `min`/`max` are used rather than `clamp`, so a NaN entry becomes the lower bound
/// instead of panicking.
fn normalize_warp(gamma: &mut [f64], argvals: &[f64]) {
    let n = gamma.len();
    if n == 0 {
        return;
    }
    let lo = argvals[0];
    let hi = argvals[n - 1];

    // Fix boundaries and pull stray interior values back into the domain.
    gamma[0] = lo;
    gamma[n - 1] = hi;
    for g in gamma.iter_mut() {
        let clamped = (*g).max(lo).min(hi);
        *g = clamped;
    }

    // Enforce monotonicity with a running maximum.
    for i in 1..n {
        if gamma[i] < gamma[i - 1] {
            gamma[i] = gamma[i - 1];
        }
    }
}
115
116// ─── Sphere Geometry for Warping Functions ──────────────────────────────────
117// Implements the Hilbert sphere representation of warping functions used by
118// fdasrvf's `SqrtMeanInverse`: psi(t) = sqrt(gamma'(t)).
119
/// Definite trapezoidal integral of `y` over the abscissae `x`.
///
/// Returns `0.0` when `y` has fewer than two samples.
fn trapz(y: &[f64], x: &[f64]) -> f64 {
    (1..y.len()).fold(0.0, |total, k| {
        total + 0.5 * (y[k] + y[k - 1]) * (x[k] - x[k - 1])
    })
}
128
/// First derivative of samples `y` taken with uniform spacing `h`.
///
/// Uses one-sided differences at the two ends and centered differences inside.
/// For fewer than two samples it returns zeros of the same length.
fn gradient_uniform(y: &[f64], h: f64) -> Vec<f64> {
    let len = y.len();
    if len < 2 {
        return vec![0.0; len];
    }
    let last = len - 1;
    (0..len)
        .map(|i| {
            if i == 0 {
                (y[1] - y[0]) / h
            } else if i == last {
                (y[last] - y[last - 1]) / h
            } else {
                (y[i + 1] - y[i - 1]) / (2.0 * h)
            }
        })
        .collect()
}
143
144/// Convert warping function to Hilbert sphere representation: psi = sqrt(gamma').
145fn gam_to_psi(gam: &[f64], h: f64) -> Vec<f64> {
146    gradient_uniform(gam, h)
147        .iter()
148        .map(|&g| g.max(0.0).sqrt())
149        .collect()
150}
151
152/// Convert psi back to warping function: gamma = cumtrapz(psi^2), normalized to [0,1].
153fn psi_to_gam(psi: &[f64], time: &[f64]) -> Vec<f64> {
154    let psi_sq: Vec<f64> = psi.iter().map(|&p| p * p).collect();
155    let gam = cumulative_trapz(&psi_sq, time);
156    let min_val = gam.iter().cloned().fold(f64::INFINITY, f64::min);
157    let max_val = gam.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
158    let range = (max_val - min_val).max(1e-10);
159    gam.iter().map(|&v| (v - min_val) / range).collect()
160}
161
162/// L2 inner product: integral(psi1 * psi2 dt) via trapezoidal rule.
163fn inner_product_l2(psi1: &[f64], psi2: &[f64], time: &[f64]) -> f64 {
164    let prod: Vec<f64> = psi1.iter().zip(psi2.iter()).map(|(&a, &b)| a * b).collect();
165    trapz(&prod, time)
166}
167
/// L2 norm `sqrt(∫ psi² dt)`, evaluated with the trapezoidal rule over `time`.
///
/// A (numerically) negative integral is floored at zero before the square root.
fn l2_norm_l2(psi: &[f64], time: &[f64]) -> f64 {
    let squared: Vec<f64> = psi.iter().map(|&p| p * p).collect();
    let mut energy = 0.0_f64;
    for k in 1..squared.len() {
        energy += 0.5 * (squared[k] + squared[k - 1]) * (time[k] - time[k - 1]);
    }
    energy.max(0.0).sqrt()
}
172
/// Inverse exponential (log) map on the unit Hilbert sphere.
///
/// Returns the tangent vector at `mu` that points toward `psi`, with length equal to
/// the geodesic angle `theta = acos(<mu, psi>)`. The inner product is a trapezoidal
/// integral over `time`, clamped to `[-1, 1]`. When `theta < 1e-10` the points are
/// treated as identical and a zero vector is returned.
fn inv_exp_map_sphere(mu: &[f64], psi: &[f64], time: &[f64]) -> Vec<f64> {
    // <mu, psi> via the trapezoidal rule.
    let pointwise: Vec<f64> = mu.iter().zip(psi).map(|(a, b)| a * b).collect();
    let mut ip = 0.0_f64;
    for k in 1..pointwise.len() {
        ip += 0.5 * (pointwise[k] + pointwise[k - 1]) * (time[k] - time[k - 1]);
    }

    let theta = ip.clamp(-1.0, 1.0).acos();
    if theta < 1e-10 {
        return vec![0.0; mu.len()];
    }

    let scale = theta / theta.sin();
    let cos_theta = theta.cos();
    mu.iter()
        .zip(psi)
        .map(|(&m, &p)| scale * (p - cos_theta * m))
        .collect()
}
189
/// Exponential map on the unit Hilbert sphere.
///
/// Follows the geodesic that starts at `psi` in direction `v`, for a distance equal to
/// `‖v‖`. The norm is the trapezoidal L2 norm over `time`. For `‖v‖ < 1e-10`, `psi` is
/// returned unchanged.
fn exp_map_sphere(psi: &[f64], v: &[f64], time: &[f64]) -> Vec<f64> {
    // ‖v‖ via the trapezoidal rule on v².
    let v_sq: Vec<f64> = v.iter().map(|&x| x * x).collect();
    let mut energy = 0.0_f64;
    for k in 1..v_sq.len() {
        energy += 0.5 * (v_sq[k] + v_sq[k - 1]) * (time[k] - time[k - 1]);
    }
    let v_norm = energy.max(0.0).sqrt();

    if v_norm < 1e-10 {
        return psi.to_vec();
    }

    let cos_n = v_norm.cos();
    let sin_n = v_norm.sin();
    psi.iter()
        .zip(v)
        .map(|(&p, &vi)| cos_n * p + sin_n * vi / v_norm)
        .collect()
}
205
206/// Invert a warping function: find gamma_inv such that gamma_inv(gamma(t)) = t.
207/// `gam` and `time` are both on [0,1].
208fn invert_gamma(gam: &[f64], time: &[f64]) -> Vec<f64> {
209    let n = time.len();
210    // Interpolate (gam -> time) at query points time
211    // i.e., for each t in time, find s such that gam(s) = t, return s
212    let mut gam_inv: Vec<f64> = time.iter().map(|&t| linear_interp(gam, time, t)).collect();
213    gam_inv[0] = time[0];
214    gam_inv[n - 1] = time[n - 1];
215    gam_inv
216}
217
218/// Karcher mean of warping functions on the Hilbert sphere, then invert.
219/// Port of fdasrvf's `SqrtMeanInverse`.
220///
221/// Takes a matrix of warping functions (n × m) on the argvals domain,
222/// computes the Fréchet mean of their sqrt-derivative representations
223/// on the unit Hilbert sphere, converts back to a warping function,
224/// and returns its inverse (on the argvals domain).
225fn sqrt_mean_inverse(gammas: &FdMatrix, argvals: &[f64]) -> Vec<f64> {
226    let (n, m) = gammas.shape();
227    let t0 = argvals[0];
228    let t1 = argvals[m - 1];
229    let domain = t1 - t0;
230
231    // Work on [0,1] internally
232    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
233    let binsize = 1.0 / (m - 1) as f64;
234
235    // Convert each gamma to psi = sqrt(gamma') on the unit sphere
236    let mut psis: Vec<Vec<f64>> = Vec::with_capacity(n);
237    for i in 0..n {
238        let gam_01: Vec<f64> = (0..m).map(|j| (gammas[(i, j)] - t0) / domain).collect();
239        psis.push(gam_to_psi(&gam_01, binsize));
240    }
241
242    // Initialize mu as pointwise mean of psis
243    let mut mu = vec![0.0; m];
244    for psi in &psis {
245        for j in 0..m {
246            mu[j] += psi[j];
247        }
248    }
249    for j in 0..m {
250        mu[j] /= n as f64;
251    }
252
253    // Karcher mean iteration on the Hilbert sphere
254    let step_size = 0.3;
255    let max_iter = 501;
256
257    for _ in 0..max_iter {
258        // Compute mean shooting vector (Karcher gradient)
259        let mut vbar = vec![0.0; m];
260        for psi in &psis {
261            let v = inv_exp_map_sphere(&mu, psi, &time);
262            for j in 0..m {
263                vbar[j] += v[j];
264            }
265        }
266        for j in 0..m {
267            vbar[j] /= n as f64;
268        }
269
270        // Convergence check
271        if l2_norm_l2(&vbar, &time) <= 1e-8 {
272            break;
273        }
274
275        // Move mu along geodesic
276        let scaled: Vec<f64> = vbar.iter().map(|&v| v * step_size).collect();
277        mu = exp_map_sphere(&mu, &scaled, &time);
278    }
279
280    // Convert mean psi back to warping function, then invert
281    let gam_mu = psi_to_gam(&mu, &time);
282    let gam_inv = invert_gamma(&gam_mu, &time);
283
284    // Scale back to argvals domain
285    gam_inv.iter().map(|&g| t0 + g * domain).collect()
286}
287
288// ─── SRSF Transform and Inverse ─────────────────────────────────────────────
289
290/// Compute the Square-Root Slope Function (SRSF) transform.
291///
292/// For each curve f, the SRSF is: `q(t) = sign(f'(t)) * sqrt(|f'(t)|)`
293///
294/// # Arguments
295/// * `data` — Functional data matrix (n × m)
296/// * `argvals` — Evaluation points (length m)
297///
298/// # Returns
299/// FdMatrix of SRSFs with the same shape as input.
300pub fn srsf_transform(data: &FdMatrix, argvals: &[f64]) -> FdMatrix {
301    let (n, m) = data.shape();
302    if n == 0 || m == 0 || argvals.len() != m {
303        return FdMatrix::zeros(n, m);
304    }
305
306    let deriv = deriv_1d(data, argvals, 1);
307
308    let mut result = FdMatrix::zeros(n, m);
309    for i in 0..n {
310        for j in 0..m {
311            let d = deriv[(i, j)];
312            result[(i, j)] = d.signum() * d.abs().sqrt();
313        }
314    }
315    result
316}
317
/// Rebuild a curve from its SRSF representation.
///
/// Computes `f(t) = f0 + ∫₀ᵗ q(s)|q(s)| ds`, with the integral accumulated by the
/// trapezoidal rule over `argvals`.
///
/// # Arguments
/// * `q` — SRSF values (length m)
/// * `argvals` — Evaluation points (length m)
/// * `f0` — Value of the curve at `argvals\[0\]`
///
/// # Returns
/// The reconstructed curve (length m); empty when `q` is empty.
pub fn srsf_inverse(q: &[f64], argvals: &[f64], f0: f64) -> Vec<f64> {
    if q.is_empty() {
        return Vec::new();
    }

    let mut curve = Vec::with_capacity(q.len());
    let mut integral = 0.0_f64;
    let mut prev = q[0] * q[0].abs();
    curve.push(f0 + integral);
    for k in 1..q.len() {
        // Integrand q|q| at the current point, combined with the previous one.
        let cur = q[k] * q[k].abs();
        integral += 0.5 * (cur + prev) * (argvals[k] - argvals[k - 1]);
        curve.push(f0 + integral);
        prev = cur;
    }
    curve
}
341
342// ─── Reparameterization ─────────────────────────────────────────────────────
343
344/// Reparameterize a curve by a warping function.
345///
346/// Computes `f(γ(t))` via linear interpolation.
347///
348/// # Arguments
349/// * `f` — Curve values (length m)
350/// * `argvals` — Evaluation points (length m)
351/// * `gamma` — Warping function values (length m)
352pub fn reparameterize_curve(f: &[f64], argvals: &[f64], gamma: &[f64]) -> Vec<f64> {
353    gamma
354        .iter()
355        .map(|&g| linear_interp(argvals, f, g))
356        .collect()
357}
358
359/// Compose two warping functions: `(γ₁ ∘ γ₂)(t) = γ₁(γ₂(t))`.
360///
361/// # Arguments
362/// * `gamma1` — Outer warping function (length m)
363/// * `gamma2` — Inner warping function (length m)
364/// * `argvals` — Evaluation points (length m)
365pub fn compose_warps(gamma1: &[f64], gamma2: &[f64], argvals: &[f64]) -> Vec<f64> {
366    gamma2
367        .iter()
368        .map(|&g| linear_interp(argvals, gamma1, g))
369        .collect()
370}
371
372// ─── Dynamic Programming Alignment ──────────────────────────────────────────
373// Faithful port of fdasrvf's DP algorithm (dp_grid.cpp / dp_nbhd.cpp).
374
/// Greatest common divisor, computed with the iterative Euclidean algorithm.
#[cfg(test)]
fn gcd(a: usize, b: usize) -> usize {
    let mut x = a;
    let mut y = b;
    while y != 0 {
        let r = x % y;
        x = y;
        y = r;
    }
    x
}
384
/// Build the coprime move set: every `(i, j)` with `1 ≤ i, j ≤ nbhd_dim` and
/// `gcd(i, j) = 1`, in row-major order (by `i`, then `j`).
/// For `nbhd_dim = 7` this yields 35 pairs, matching fdasrvf's default neighborhood.
#[cfg(test)]
fn generate_coprime_nbhd(nbhd_dim: usize) -> Vec<(usize, usize)> {
    (1..=nbhd_dim)
        .flat_map(|i| (1..=nbhd_dim).map(move |j| (i, j)))
        .filter(|&(i, j)| gcd(i, j) == 1)
        .collect()
}
399
/// Coprime DP move set for neighborhood size 7 (the fdasrvf default).
///
/// Holds every `(dr, dc)` with `1 ≤ dr, dc ≤ 7` and `gcd(dr, dc) = 1`, listed by `dr`
/// and then `dc`. `dr` steps along rows (q2 direction) and `dc` along columns
/// (q1 direction). The order is significant: the DP keeps the first move that attains
/// the minimum cost.
#[rustfmt::skip]
const COPRIME_NBHD_7: [(usize, usize); 35] = [
    (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7),
    (2, 1), (2, 3), (2, 5), (2, 7),
    (3, 1), (3, 2), (3, 4), (3, 5), (3, 7),
    (4, 1), (4, 3), (4, 5), (4, 7),
    (5, 1), (5, 2), (5, 3), (5, 4), (5, 6), (5, 7),
    (6, 1), (6, 5), (6, 7),
    (7, 1), (7, 2), (7, 3), (7, 4), (7, 5), (7, 6),
];
413
/// Compute the edge weight for a move from grid point (sr, sc) to (tr, tc).
///
/// Port of fdasrvf's `dp_edge_weight` for 1-D curves on a shared uniform grid.
/// - Rows = q2 indices, columns = q1 indices (matching fdasrvf convention).
/// - `slope = (argvals[tr] - argvals[sr]) / (argvals[tc] - argvals[sc])` = γ'
/// - Walks through sub-intervals synchronized at both curves' breakpoints,
///   accumulating `(q1[idx1] - √slope · q2[idx2])² · dt`.
///
/// # Arguments
/// * `q1`, `q2` — SRSFs sampled on `argvals` (indexed by column and row respectively)
/// * `argvals` — Shared evaluation grid
/// * `sc`, `tc` — Source / target column (q1) indices; callers pass `tc >= sc`
///   (the difference is computed in `usize`)
/// * `sr`, `tr` — Source / target row (q2) indices; callers pass `tr >= sr`
///
/// # Returns
/// The squared discrepancy integrated along the straight segment, scaled by the q1
/// span `argvals[tc] - argvals[sc]`. Returns `f64::INFINITY` when the move has zero
/// extent in either direction.
///
/// The sub-interval breakpoints are placed at the fractions `i / n1` and `i / n2` of
/// the segment, i.e. the grid is treated as uniformly spaced. On a non-uniform grid,
/// only `slope` and the final scale factor use the actual `argvals`.
#[inline]
fn dp_edge_weight(
    q1: &[f64],
    q2: &[f64],
    argvals: &[f64],
    sc: usize,
    tc: usize,
    sr: usize,
    tr: usize,
) -> f64 {
    // Number of grid cells spanned in each direction.
    let n1 = tc - sc;
    let n2 = tr - sr;
    if n1 == 0 || n2 == 0 {
        return f64::INFINITY;
    }

    let slope = (argvals[tr] - argvals[sr]) / (argvals[tc] - argvals[sc]);
    let rslope = slope.sqrt();

    // Walk through sub-intervals synchronized at breakpoints of both curves
    let mut weight = 0.0;
    let mut i1 = 0usize; // sub-interval index in q1 direction
    let mut i2 = 0usize; // sub-interval index in q2 direction

    while i1 < n1 && i2 < n2 {
        // Current sub-interval boundaries as fractions of the total span
        let left1 = i1 as f64 / n1 as f64;
        let right1 = (i1 + 1) as f64 / n1 as f64;
        let left2 = i2 as f64 / n2 as f64;
        let right2 = (i2 + 1) as f64 / n2 as f64;

        // Overlap of the current q1 and q2 sub-intervals (fraction of the segment).
        let left = left1.max(left2);
        let right = right1.min(right2);
        let dt = right - left;

        if dt > 0.0 {
            // Piecewise-constant SRSF values on each side of the overlap.
            let diff = q1[sc + i1] - rslope * q2[sr + i2];
            weight += diff * diff * dt;
        }

        // Advance whichever sub-interval ends first; if both end together, advance both.
        // For the coprime (n1, n2) moves in COPRIME_NBHD_7, the fractions only coincide
        // at the segment end.
        if right1 < right2 {
            i1 += 1;
        } else if right2 < right1 {
            i2 += 1;
        } else {
            i1 += 1;
            i2 += 1;
        }
    }

    // Scale by the span in q1 direction: `dt` above is a fraction of the segment, so this
    // converts the sum into an integral over the q1 axis.
    weight * (argvals[tc] - argvals[sc])
}
475
/// Core DP alignment between two SRSFs on a grid.
///
/// Finds the optimal warping γ minimizing ‖q₁ - (q₂∘γ)√γ'‖².
/// Uses fdasrvf's coprime neighborhood (nbhd_dim=7 → 35 move directions).
/// SRSFs are L2-normalized before alignment (matching fdasrvf's `optimum.reparam`).
///
/// # Arguments
/// * `q1` — Target SRSF (indexed by grid column)
/// * `q2` — SRSF to be warped (indexed by grid row)
/// * `argvals` — Shared evaluation grid (length m)
///
/// # Returns
/// γ sampled on `argvals` and passed through `normalize_warp`. For fewer than two grid
/// points, `argvals` is returned unchanged.
///
/// Memory is O(m²) (full cost and parent tables). Time is O(35 · m²) edge evaluations.
fn dp_alignment_core(q1: &[f64], q2: &[f64], argvals: &[f64]) -> Vec<f64> {
    let m = argvals.len();
    if m < 2 {
        return argvals.to_vec();
    }

    // Normalize SRSFs to unit L2 norm (matching fdasrvf's optimum.reparam)
    // NOTE: this is the unweighted sum-of-squares norm (no grid-spacing weights),
    // floored at 1e-10 so an all-zero SRSF does not divide by zero.
    let norm1 = q1.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
    let norm2 = q2.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
    let q1n: Vec<f64> = q1.iter().map(|&v| v / norm1).collect();
    let q2n: Vec<f64> = q2.iter().map(|&v| v / norm2).collect();

    // Full m×m cost table and parent pointers
    // Rows = q2 index, Columns = q1 index (matching fdasrvf)
    // Cell (r, c) lives at r * m + c. Parents are stored as u32 with u32::MAX meaning
    // "no predecessor".
    // NOTE(review): `src_idx as u32` below would silently truncate if m * m > u32::MAX.
    let mut e = vec![f64::INFINITY; m * m];
    let mut parent = vec![u32::MAX; m * m];
    e[0] = 0.0;

    // Row 0 and column 0 (other than the origin) are never filled, so every finite-cost
    // path starts at (0, 0).
    for tr in 1..m {
        for tc in 1..m {
            let idx = tr * m + tc;
            for &(dr, dc) in &COPRIME_NBHD_7 {
                if dr > tr || dc > tc {
                    continue;
                }
                let sr = tr - dr;
                let sc = tc - dc;
                let src_idx = sr * m + sc;
                // Skip unreachable source cells.
                if e[src_idx] == f64::INFINITY {
                    continue;
                }
                let w = dp_edge_weight(&q1n, &q2n, argvals, sc, tc, sr, tr);
                let cost = e[src_idx] + w;
                // Strict `<`: on ties the earliest move in COPRIME_NBHD_7 order wins.
                if cost < e[idx] {
                    e[idx] = cost;
                    parent[idx] = src_idx as u32;
                }
            }
        }
    }

    // Traceback from (m-1, m-1) to (0, 0) using parent pointers
    let mut path_tc = Vec::with_capacity(2 * m);
    let mut path_tr = Vec::with_capacity(2 * m);
    let mut cur = (m - 1) * m + (m - 1);
    loop {
        let tr = cur / m;
        let tc = cur % m;
        path_tc.push(argvals[tc]);
        path_tr.push(argvals[tr]);
        if cur == 0 {
            break;
        }
        // Defensive stop if a cell has no recorded predecessor. The corner is reachable
        // by repeated (1, 1) moves, so this should only trigger for non-finite costs
        // (e.g. NaN inputs).
        if parent[cur] == u32::MAX {
            break;
        }
        cur = parent[cur] as usize;
    }

    // Reverse to forward order
    path_tc.reverse();
    path_tr.reverse();

    // Re-interpolate gamma onto the argvals grid
    // path_tc = t values (q1 side), path_tr = gamma values (q2 side)
    // Every move has dr, dc >= 1, so path_tc is strictly increasing.
    let mut gamma: Vec<f64> = argvals
        .iter()
        .map(|&t| linear_interp(&path_tc, &path_tr, t))
        .collect();

    normalize_warp(&mut gamma, argvals);
    gamma
}
554
555// ─── Public Alignment Functions ─────────────────────────────────────────────
556
557/// Align curve f2 to curve f1 using the elastic framework.
558///
559/// Computes the optimal warping γ such that f2∘γ is as close as possible
560/// to f1 in the elastic (Fisher-Rao) metric.
561///
562/// # Arguments
563/// * `f1` — Target curve (length m)
564/// * `f2` — Curve to align (length m)
565/// * `argvals` — Evaluation points (length m)
566///
567/// # Returns
568/// [`AlignmentResult`] with warping function, aligned curve, and elastic distance.
569pub fn elastic_align_pair(f1: &[f64], f2: &[f64], argvals: &[f64]) -> AlignmentResult {
570    let m = f1.len();
571
572    // Build single-row FdMatrices for SRSF computation
573    let f1_mat = FdMatrix::from_slice(f1, 1, m).unwrap();
574    let f2_mat = FdMatrix::from_slice(f2, 1, m).unwrap();
575
576    let q1_mat = srsf_transform(&f1_mat, argvals);
577    let q2_mat = srsf_transform(&f2_mat, argvals);
578
579    let q1: Vec<f64> = q1_mat.row(0);
580    let q2: Vec<f64> = q2_mat.row(0);
581
582    // Find optimal warping via DP
583    let gamma = dp_alignment_core(&q1, &q2, argvals);
584
585    // Apply warping to f2
586    let f_aligned = reparameterize_curve(f2, argvals, &gamma);
587
588    // Compute elastic distance: L2 distance between q1 and aligned q2 SRSF
589    let f_aligned_mat = FdMatrix::from_slice(&f_aligned, 1, m).unwrap();
590    let q_aligned_mat = srsf_transform(&f_aligned_mat, argvals);
591    let q_aligned: Vec<f64> = q_aligned_mat.row(0);
592
593    let weights = simpsons_weights(argvals);
594    let distance = l2_distance(&q1, &q_aligned, &weights);
595
596    AlignmentResult {
597        gamma,
598        f_aligned,
599        distance,
600    }
601}
602
603/// Compute the elastic distance between two curves.
604///
605/// This is shorthand for aligning the pair and returning only the distance.
606///
607/// # Arguments
608/// * `f1` — First curve (length m)
609/// * `f2` — Second curve (length m)
610/// * `argvals` — Evaluation points (length m)
611pub fn elastic_distance(f1: &[f64], f2: &[f64], argvals: &[f64]) -> f64 {
612    elastic_align_pair(f1, f2, argvals).distance
613}
614
615/// Align all curves in `data` to a single target curve.
616///
617/// # Arguments
618/// * `data` — Functional data matrix (n × m)
619/// * `target` — Target curve to align to (length m)
620/// * `argvals` — Evaluation points (length m)
621///
622/// # Returns
623/// [`AlignmentSetResult`] with all warping functions, aligned curves, and distances.
624pub fn align_to_target(data: &FdMatrix, target: &[f64], argvals: &[f64]) -> AlignmentSetResult {
625    let (n, m) = data.shape();
626
627    let results: Vec<AlignmentResult> = iter_maybe_parallel!(0..n)
628        .map(|i| {
629            let fi = data.row(i);
630            elastic_align_pair(target, &fi, argvals)
631        })
632        .collect();
633
634    let mut gammas = FdMatrix::zeros(n, m);
635    let mut aligned_data = FdMatrix::zeros(n, m);
636    let mut distances = Vec::with_capacity(n);
637
638    for (i, r) in results.into_iter().enumerate() {
639        for j in 0..m {
640            gammas[(i, j)] = r.gamma[j];
641            aligned_data[(i, j)] = r.f_aligned[j];
642        }
643        distances.push(r.distance);
644    }
645
646    AlignmentSetResult {
647        gammas,
648        aligned_data,
649        distances,
650    }
651}
652
653// ─── Distance Matrices ──────────────────────────────────────────────────────
654
655/// Compute the symmetric elastic distance matrix for a set of curves.
656///
657/// Uses upper-triangle computation with parallelism, following the
658/// `self_distance_matrix` pattern from `metric.rs`.
659///
660/// # Arguments
661/// * `data` — Functional data matrix (n × m)
662/// * `argvals` — Evaluation points (length m)
663///
664/// # Returns
665/// Symmetric n × n distance matrix.
666pub fn elastic_self_distance_matrix(data: &FdMatrix, argvals: &[f64]) -> FdMatrix {
667    let n = data.nrows();
668
669    let upper_vals: Vec<f64> = iter_maybe_parallel!(0..n)
670        .flat_map(|i| {
671            let fi = data.row(i);
672            ((i + 1)..n)
673                .map(|j| {
674                    let fj = data.row(j);
675                    elastic_distance(&fi, &fj, argvals)
676                })
677                .collect::<Vec<_>>()
678        })
679        .collect();
680
681    let mut dist = FdMatrix::zeros(n, n);
682    let mut idx = 0;
683    for i in 0..n {
684        for j in (i + 1)..n {
685            let d = upper_vals[idx];
686            dist[(i, j)] = d;
687            dist[(j, i)] = d;
688            idx += 1;
689        }
690    }
691    dist
692}
693
694/// Compute the elastic distance matrix between two sets of curves.
695///
696/// # Arguments
697/// * `data1` — First dataset (n1 × m)
698/// * `data2` — Second dataset (n2 × m)
699/// * `argvals` — Evaluation points (length m)
700///
701/// # Returns
702/// n1 × n2 distance matrix.
703pub fn elastic_cross_distance_matrix(
704    data1: &FdMatrix,
705    data2: &FdMatrix,
706    argvals: &[f64],
707) -> FdMatrix {
708    let n1 = data1.nrows();
709    let n2 = data2.nrows();
710
711    let vals: Vec<f64> = iter_maybe_parallel!(0..n1)
712        .flat_map(|i| {
713            let fi = data1.row(i);
714            (0..n2)
715                .map(|j| {
716                    let fj = data2.row(j);
717                    elastic_distance(&fi, &fj, argvals)
718                })
719                .collect::<Vec<_>>()
720        })
721        .collect();
722
723    let mut dist = FdMatrix::zeros(n1, n2);
724    for i in 0..n1 {
725        for j in 0..n2 {
726            dist[(i, j)] = vals[i * n2 + j];
727        }
728    }
729    dist
730}
731
732// ─── Karcher Mean ───────────────────────────────────────────────────────────
733
/// Relative change between two successive mean SRSFs.
///
/// Evaluates `‖q_new - q_old‖₂ / ‖q_old‖₂` with plain (unweighted) discrete L2 norms,
/// the convergence metric used by R's fdasrvf `time_warping`. The denominator is
/// floored at `1e-10`.
fn relative_change(q_old: &[f64], q_new: &[f64]) -> f64 {
    let numerator = q_old
        .iter()
        .zip(q_new)
        .map(|(a, b)| (a - b).powi(2))
        .sum::<f64>()
        .sqrt();
    let denominator = q_old.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
    numerator / denominator
}
748
749/// Compute a single SRSF from a slice (single-row convenience).
750fn srsf_single(f: &[f64], argvals: &[f64]) -> Vec<f64> {
751    let m = f.len();
752    let mat = FdMatrix::from_slice(f, 1, m).unwrap();
753    let q_mat = srsf_transform(&mat, argvals);
754    q_mat.row(0)
755}
756
757/// Align a single SRSF q2 to q1 and return (gamma, aligned_q).
758fn align_srsf_pair(q1: &[f64], q2: &[f64], argvals: &[f64]) -> (Vec<f64>, Vec<f64>) {
759    let gamma = dp_alignment_core(q1, q2, argvals);
760
761    // Warp q2 by gamma and adjust by sqrt(gamma')
762    let q2_warped = reparameterize_curve(q2, argvals, &gamma);
763
764    // Compute gamma' via finite differences
765    let m = gamma.len();
766    let mut gamma_dot = vec![0.0; m];
767    gamma_dot[0] = (gamma[1] - gamma[0]) / (argvals[1] - argvals[0]);
768    for j in 1..(m - 1) {
769        gamma_dot[j] = (gamma[j + 1] - gamma[j - 1]) / (argvals[j + 1] - argvals[j - 1]);
770    }
771    gamma_dot[m - 1] = (gamma[m - 1] - gamma[m - 2]) / (argvals[m - 1] - argvals[m - 2]);
772
773    // q2_aligned = (q2 ∘ γ) * sqrt(γ')
774    let q2_aligned: Vec<f64> = q2_warped
775        .iter()
776        .zip(gamma_dot.iter())
777        .map(|(&q, &gd)| q * gd.max(0.0).sqrt())
778        .collect();
779
780    (gamma, q2_aligned)
781}
782
783/// Compute the Karcher (Fréchet) mean in the elastic metric.
784///
785/// Iteratively aligns all curves to the current mean estimate in SRSF space,
786/// computes the pointwise mean of aligned SRSFs, and reconstructs the mean curve.
787///
788/// # Arguments
789/// * `data` — Functional data matrix (n × m)
790/// * `argvals` — Evaluation points (length m)
791/// * `max_iter` — Maximum number of iterations
792/// * `tol` — Convergence tolerance for the SRSF mean
793///
794/// # Returns
795/// [`KarcherMeanResult`] with mean curve, warping functions, aligned data, and convergence info.
796///
797/// # Examples
798///
799/// ```
800/// use fdars_core::simulation::{sim_fundata, EFunType, EValType};
801/// use fdars_core::alignment::karcher_mean;
802///
803/// let t: Vec<f64> = (0..50).map(|i| i as f64 / 49.0).collect();
804/// let data = sim_fundata(20, &t, 3, EFunType::Fourier, EValType::Exponential, Some(42));
805///
806/// let result = karcher_mean(&data, &t, 20, 1e-4);
807/// assert_eq!(result.mean.len(), 50);
808/// assert!(result.n_iter <= 20);
809/// ```
810/// Accumulate alignment results: store gammas and return the mean of aligned SRSFs.
811fn accumulate_alignments(
812    results: &[(Vec<f64>, Vec<f64>)],
813    gammas: &mut FdMatrix,
814    m: usize,
815    n: usize,
816) -> Vec<f64> {
817    let mut mu_q_new = vec![0.0; m];
818    for (i, (gamma, q_aligned)) in results.iter().enumerate() {
819        for j in 0..m {
820            gammas[(i, j)] = gamma[j];
821            mu_q_new[j] += q_aligned[j];
822        }
823    }
824    for j in 0..m {
825        mu_q_new[j] /= n as f64;
826    }
827    mu_q_new
828}
829
830/// Apply stored warps to original curves to produce aligned data.
831fn apply_stored_warps(data: &FdMatrix, gammas: &FdMatrix, argvals: &[f64]) -> FdMatrix {
832    let (n, m) = data.shape();
833    let mut aligned = FdMatrix::zeros(n, m);
834    for i in 0..n {
835        let fi = data.row(i);
836        let gamma: Vec<f64> = (0..m).map(|j| gammas[(i, j)]).collect();
837        let f_aligned = reparameterize_curve(&fi, argvals, &gamma);
838        for j in 0..m {
839            aligned[(i, j)] = f_aligned[j];
840        }
841    }
842    aligned
843}
844
845pub fn karcher_mean(
846    data: &FdMatrix,
847    argvals: &[f64],
848    max_iter: usize,
849    tol: f64,
850) -> KarcherMeanResult {
851    let (n, m) = data.shape();
852
853    // Step 1: Compute SRSFs and select closest observed SRSF to the mean as template
854    let srsf_mat = srsf_transform(data, argvals);
855    let mnq = mean_1d(&srsf_mat);
856    let mut min_dist = f64::INFINITY;
857    let mut min_idx = 0;
858    for i in 0..n {
859        let dist_sq: f64 = (0..m).map(|j| (srsf_mat[(i, j)] - mnq[j]).powi(2)).sum();
860        if dist_sq < min_dist {
861            min_dist = dist_sq;
862            min_idx = i;
863        }
864    }
865    let mut mu_q = srsf_mat.row(min_idx);
866    let mut mu = data.row(min_idx);
867
868    // Step 2: Pre-iteration centering with SqrtMeanInverse
869    // Align all curves to the selected template, then center the template
870    {
871        let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
872            .map(|i| {
873                let fi = data.row(i);
874                let qi = srsf_single(&fi, argvals);
875                align_srsf_pair(&mu_q, &qi, argvals)
876            })
877            .collect();
878
879        let mut init_gammas = FdMatrix::zeros(n, m);
880        for (i, (gamma, _)) in align_results.iter().enumerate() {
881            for j in 0..m {
882                init_gammas[(i, j)] = gamma[j];
883            }
884        }
885
886        // Center: compute inverse mean warp, apply to template
887        let gam_inv = sqrt_mean_inverse(&init_gammas, argvals);
888        mu = reparameterize_curve(&mu, argvals, &gam_inv);
889        mu_q = srsf_single(&mu, argvals);
890    }
891
892    // Step 3: Main Karcher mean iteration
893    let mut converged = false;
894    let mut n_iter = 0;
895    let mut final_gammas = FdMatrix::zeros(n, m);
896    let mut prev_rel = 0.0_f64;
897
898    for iter in 0..max_iter {
899        n_iter = iter + 1;
900
901        let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
902            .map(|i| {
903                let fi = data.row(i);
904                let qi = srsf_single(&fi, argvals);
905                align_srsf_pair(&mu_q, &qi, argvals)
906            })
907            .collect();
908
909        let mu_q_new = accumulate_alignments(&align_results, &mut final_gammas, m, n);
910
911        let rel = relative_change(&mu_q, &mu_q_new);
912        if rel < f64::EPSILON || (iter > 0 && rel - prev_rel <= tol * prev_rel) {
913            converged = true;
914            mu_q = mu_q_new;
915            break;
916        }
917        prev_rel = rel;
918
919        mu_q = mu_q_new;
920        mu = srsf_inverse(&mu_q, argvals, mu[0]);
921    }
922
923    // Step 4: Post-convergence centering with SqrtMeanInverse
924    let gam_inv = sqrt_mean_inverse(&final_gammas, argvals);
925    let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
926    let gam_inv_dev = gradient_uniform(&gam_inv, h);
927
928    // Center the mean SRSF: (mu_q ∘ gamI) * sqrt(gamI')
929    let mu_q_warped = reparameterize_curve(&mu_q, argvals, &gam_inv);
930    mu_q = mu_q_warped
931        .iter()
932        .zip(gam_inv_dev.iter())
933        .map(|(&q, &gd)| q * gd.max(0.0).sqrt())
934        .collect();
935
936    // Center each curve's warp: gamma_centered = gamma ∘ gamI
937    for i in 0..n {
938        let gam_i: Vec<f64> = (0..m).map(|j| final_gammas[(i, j)]).collect();
939        let gam_centered = reparameterize_curve(&gam_i, argvals, &gam_inv);
940        for j in 0..m {
941            final_gammas[(i, j)] = gam_centered[j];
942        }
943    }
944
945    // Reconstruct mean curve from centered SRSF
946    let initial_mean = mean_1d(data);
947    mu = srsf_inverse(&mu_q, argvals, initial_mean[0]);
948    let final_aligned = apply_stored_warps(data, &final_gammas, argvals);
949
950    KarcherMeanResult {
951        mean: mu,
952        mean_srsf: mu_q,
953        gammas: final_gammas,
954        aligned_data: final_aligned,
955        n_iter,
956        converged,
957    }
958}
959
960// ─── Tests ──────────────────────────────────────────────────────────────────
961
962#[cfg(test)]
963mod tests {
964    use super::*;
965    use crate::simulation::{sim_fundata, EFunType, EValType};
966
967    fn uniform_grid(m: usize) -> Vec<f64> {
968        (0..m).map(|i| i as f64 / (m - 1) as f64).collect()
969    }
970
971    fn make_test_data(n: usize, m: usize, seed: u64) -> FdMatrix {
972        let t = uniform_grid(m);
973        sim_fundata(
974            n,
975            &t,
976            3,
977            EFunType::Fourier,
978            EValType::Exponential,
979            Some(seed),
980        )
981    }
982
983    // ── cumulative_trapz ──
984
985    #[test]
986    fn test_cumulative_trapz_constant() {
987        // ∫₀ᵗ 1 dt = t
988        let x = uniform_grid(50);
989        let y = vec![1.0; 50];
990        let result = cumulative_trapz(&y, &x);
991        assert!((result[0]).abs() < 1e-15, "cumulative_trapz(0) should be 0");
992        for j in 1..50 {
993            assert!(
994                (result[j] - x[j]).abs() < 1e-12,
995                "∫₀^{:.3} 1 dt should be {:.3}, got {:.3}",
996                x[j],
997                x[j],
998                result[j]
999            );
1000        }
1001    }
1002
1003    #[test]
1004    fn test_cumulative_trapz_linear() {
1005        // ∫₀ᵗ s ds = t²/2
1006        let m = 100;
1007        let x = uniform_grid(m);
1008        let y: Vec<f64> = x.clone();
1009        let result = cumulative_trapz(&y, &x);
1010        for j in 1..m {
1011            let expected = x[j] * x[j] / 2.0;
1012            assert!(
1013                (result[j] - expected).abs() < 1e-4,
1014                "∫₀^{:.3} s ds: expected {expected:.6}, got {:.6}",
1015                x[j],
1016                result[j]
1017            );
1018        }
1019    }
1020
1021    // ── normalize_warp ──
1022
1023    #[test]
1024    fn test_normalize_warp_fixes_boundaries() {
1025        let t = uniform_grid(10);
1026        let mut gamma = vec![0.1; 10]; // constant, wrong boundaries
1027        normalize_warp(&mut gamma, &t);
1028        assert_eq!(gamma[0], t[0]);
1029        assert_eq!(gamma[9], t[9]);
1030    }
1031
1032    #[test]
1033    fn test_normalize_warp_enforces_monotonicity() {
1034        let t = uniform_grid(5);
1035        let mut gamma = vec![0.0, 0.5, 0.3, 0.8, 1.0]; // non-monotone at index 2
1036        normalize_warp(&mut gamma, &t);
1037        for j in 1..5 {
1038            assert!(
1039                gamma[j] >= gamma[j - 1],
1040                "gamma should be monotone after normalization at j={j}"
1041            );
1042        }
1043    }
1044
1045    #[test]
1046    fn test_normalize_warp_identity_unchanged() {
1047        let t = uniform_grid(20);
1048        let mut gamma = t.clone();
1049        normalize_warp(&mut gamma, &t);
1050        for j in 0..20 {
1051            assert!(
1052                (gamma[j] - t[j]).abs() < 1e-15,
1053                "Identity warp should be unchanged"
1054            );
1055        }
1056    }
1057
1058    // ── linear_interp ──
1059
1060    #[test]
1061    fn test_linear_interp_at_nodes() {
1062        let x = vec![0.0, 1.0, 2.0, 3.0];
1063        let y = vec![0.0, 2.0, 4.0, 6.0];
1064        for i in 0..x.len() {
1065            assert!((linear_interp(&x, &y, x[i]) - y[i]).abs() < 1e-12);
1066        }
1067    }
1068
1069    #[test]
1070    fn test_linear_interp_midpoints() {
1071        let x = vec![0.0, 1.0, 2.0];
1072        let y = vec![0.0, 2.0, 4.0];
1073        assert!((linear_interp(&x, &y, 0.5) - 1.0).abs() < 1e-12);
1074        assert!((linear_interp(&x, &y, 1.5) - 3.0).abs() < 1e-12);
1075    }
1076
1077    #[test]
1078    fn test_linear_interp_clamp() {
1079        let x = vec![0.0, 1.0, 2.0];
1080        let y = vec![1.0, 3.0, 5.0];
1081        assert!((linear_interp(&x, &y, -1.0) - 1.0).abs() < 1e-12);
1082        assert!((linear_interp(&x, &y, 3.0) - 5.0).abs() < 1e-12);
1083    }
1084
1085    #[test]
1086    fn test_linear_interp_nonuniform_grid() {
1087        let x = vec![0.0, 0.1, 0.5, 1.0];
1088        let y = vec![0.0, 1.0, 5.0, 10.0];
1089        // Between 0.1 and 0.5: slope = (5-1)/(0.5-0.1) = 10
1090        let val = linear_interp(&x, &y, 0.3);
1091        let expected = 1.0 + 10.0 * (0.3 - 0.1);
1092        assert!(
1093            (val - expected).abs() < 1e-12,
1094            "Non-uniform interp: expected {expected}, got {val}"
1095        );
1096    }
1097
1098    #[test]
1099    fn test_linear_interp_two_points() {
1100        let x = vec![0.0, 1.0];
1101        let y = vec![3.0, 7.0];
1102        assert!((linear_interp(&x, &y, 0.25) - 4.0).abs() < 1e-12);
1103        assert!((linear_interp(&x, &y, 0.75) - 6.0).abs() < 1e-12);
1104    }
1105
1106    // ── SRSF transform ──
1107
1108    #[test]
1109    fn test_srsf_transform_linear() {
1110        // f(t) = 2t: derivative = 2, SRSF = sqrt(2)
1111        let m = 50;
1112        let t = uniform_grid(m);
1113        let f: Vec<f64> = t.iter().map(|&ti| 2.0 * ti).collect();
1114        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
1115
1116        let q_mat = srsf_transform(&mat, &t);
1117        let q: Vec<f64> = q_mat.row(0);
1118
1119        let expected = 2.0_f64.sqrt();
1120        // Interior points should be close to sqrt(2)
1121        for j in 2..(m - 2) {
1122            assert!(
1123                (q[j] - expected).abs() < 0.1,
1124                "q[{j}] = {}, expected ~{expected}",
1125                q[j]
1126            );
1127        }
1128    }
1129
1130    #[test]
1131    fn test_srsf_transform_preserves_shape() {
1132        let data = make_test_data(10, 50, 42);
1133        let t = uniform_grid(50);
1134        let q = srsf_transform(&data, &t);
1135        assert_eq!(q.shape(), data.shape());
1136    }
1137
1138    #[test]
1139    fn test_srsf_transform_constant_is_zero() {
1140        // f(t) = 5 (constant): derivative = 0, SRSF = 0
1141        let m = 30;
1142        let t = uniform_grid(m);
1143        let f = vec![5.0; m];
1144        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
1145        let q_mat = srsf_transform(&mat, &t);
1146        let q: Vec<f64> = q_mat.row(0);
1147
1148        for j in 0..m {
1149            assert!(
1150                q[j].abs() < 1e-10,
1151                "SRSF of constant should be 0, got q[{j}] = {}",
1152                q[j]
1153            );
1154        }
1155    }
1156
1157    #[test]
1158    fn test_srsf_transform_negative_slope() {
1159        // f(t) = -3t: derivative = -3, SRSF = -sqrt(3)
1160        let m = 50;
1161        let t = uniform_grid(m);
1162        let f: Vec<f64> = t.iter().map(|&ti| -3.0 * ti).collect();
1163        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
1164
1165        let q_mat = srsf_transform(&mat, &t);
1166        let q: Vec<f64> = q_mat.row(0);
1167
1168        let expected = -(3.0_f64.sqrt());
1169        for j in 2..(m - 2) {
1170            assert!(
1171                (q[j] - expected).abs() < 0.15,
1172                "q[{j}] = {}, expected ~{expected}",
1173                q[j]
1174            );
1175        }
1176    }
1177
1178    #[test]
1179    fn test_srsf_transform_empty_input() {
1180        let data = FdMatrix::zeros(0, 0);
1181        let t: Vec<f64> = vec![];
1182        let q = srsf_transform(&data, &t);
1183        assert_eq!(q.shape(), (0, 0));
1184    }
1185
1186    #[test]
1187    fn test_srsf_transform_multiple_curves() {
1188        let m = 40;
1189        let t = uniform_grid(m);
1190        let data = make_test_data(5, m, 42);
1191
1192        let q = srsf_transform(&data, &t);
1193        assert_eq!(q.shape(), (5, m));
1194
1195        // Each row should have finite values
1196        for i in 0..5 {
1197            for j in 0..m {
1198                assert!(q[(i, j)].is_finite(), "SRSF should be finite at ({i},{j})");
1199            }
1200        }
1201    }
1202
1203    // ── SRSF inverse ──
1204
1205    #[test]
1206    fn test_srsf_round_trip() {
1207        let m = 100;
1208        let t = uniform_grid(m);
1209        // Use a smooth function
1210        let f: Vec<f64> = t
1211            .iter()
1212            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin() + ti)
1213            .collect();
1214
1215        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
1216        let q_mat = srsf_transform(&mat, &t);
1217        let q: Vec<f64> = q_mat.row(0);
1218
1219        let f_recon = srsf_inverse(&q, &t, f[0]);
1220
1221        // Check reconstruction is close (interior points, avoid boundary effects)
1222        let max_err: f64 = f[5..(m - 5)]
1223            .iter()
1224            .zip(f_recon[5..(m - 5)].iter())
1225            .map(|(a, b)| (a - b).abs())
1226            .fold(0.0_f64, f64::max);
1227
1228        assert!(
1229            max_err < 0.15,
1230            "Round-trip error too large: max_err = {max_err}"
1231        );
1232    }
1233
1234    #[test]
1235    fn test_srsf_inverse_empty() {
1236        let q: Vec<f64> = vec![];
1237        let t: Vec<f64> = vec![];
1238        let result = srsf_inverse(&q, &t, 0.0);
1239        assert!(result.is_empty());
1240    }
1241
1242    #[test]
1243    fn test_srsf_inverse_preserves_initial_value() {
1244        let m = 50;
1245        let t = uniform_grid(m);
1246        let q = vec![1.0; m]; // constant SRSF
1247        let f0 = 3.15;
1248        let f = srsf_inverse(&q, &t, f0);
1249        assert!((f[0] - f0).abs() < 1e-12, "srsf_inverse should start at f0");
1250    }
1251
1252    #[test]
1253    fn test_srsf_round_trip_multiple_curves() {
1254        let m = 80;
1255        let t = uniform_grid(m);
1256        let data = make_test_data(5, m, 99);
1257
1258        let q_mat = srsf_transform(&data, &t);
1259
1260        for i in 0..5 {
1261            let fi = data.row(i);
1262            let qi = q_mat.row(i);
1263            let f_recon = srsf_inverse(&qi, &t, fi[0]);
1264            let max_err: f64 = fi[5..(m - 5)]
1265                .iter()
1266                .zip(f_recon[5..(m - 5)].iter())
1267                .map(|(a, b)| (a - b).abs())
1268                .fold(0.0_f64, f64::max);
1269            assert!(max_err < 0.3, "Round-trip curve {i}: max_err = {max_err}");
1270        }
1271    }
1272
1273    // ── Reparameterize ──
1274
1275    #[test]
1276    fn test_reparameterize_identity_warp() {
1277        let m = 50;
1278        let t = uniform_grid(m);
1279        let f: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
1280
1281        // Identity warp: γ(t) = t
1282        let result = reparameterize_curve(&f, &t, &t);
1283        for j in 0..m {
1284            assert!(
1285                (result[j] - f[j]).abs() < 1e-12,
1286                "Identity warp should return original at j={j}"
1287            );
1288        }
1289    }
1290
1291    #[test]
1292    fn test_reparameterize_linear_warp() {
1293        let m = 50;
1294        let t = uniform_grid(m);
1295        // f(t) = t (linear), γ(t) = t^2 (quadratic warp on [0,1])
1296        let f: Vec<f64> = t.clone();
1297        let gamma: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
1298
1299        let result = reparameterize_curve(&f, &t, &gamma);
1300
1301        // f(γ(t)) = γ(t) = t^2 for a linear f(t) = t
1302        for j in 0..m {
1303            assert!(
1304                (result[j] - gamma[j]).abs() < 1e-10,
1305                "f(gamma(t)) should be gamma(t) for f(t)=t at j={j}"
1306            );
1307        }
1308    }
1309
1310    #[test]
1311    fn test_reparameterize_sine_with_quadratic_warp() {
1312        let m = 100;
1313        let t = uniform_grid(m);
1314        let f: Vec<f64> = t
1315            .iter()
1316            .map(|&ti| (std::f64::consts::PI * ti).sin())
1317            .collect();
1318        let gamma: Vec<f64> = t.iter().map(|&ti| ti * ti).collect(); // speeds up start
1319
1320        let result = reparameterize_curve(&f, &t, &gamma);
1321
1322        // f(γ(t)) = sin(π * t²); check a few known values
1323        for j in 0..m {
1324            let expected = (std::f64::consts::PI * gamma[j]).sin();
1325            assert!(
1326                (result[j] - expected).abs() < 0.05,
1327                "sin(π γ(t)) at j={j}: expected {expected:.4}, got {:.4}",
1328                result[j]
1329            );
1330        }
1331    }
1332
1333    #[test]
1334    fn test_reparameterize_preserves_length() {
1335        let m = 50;
1336        let t = uniform_grid(m);
1337        let f = vec![1.0; m];
1338        let gamma: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
1339
1340        let result = reparameterize_curve(&f, &t, &gamma);
1341        assert_eq!(result.len(), m);
1342    }
1343
1344    // ── Compose warps ──
1345
1346    #[test]
1347    fn test_compose_warps_identity() {
1348        let m = 50;
1349        let t = uniform_grid(m);
1350        // γ(t) = t^0.5 (a warp on [0,1])
1351        let gamma: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
1352
1353        // identity ∘ γ = γ
1354        let result = compose_warps(&t, &gamma, &t);
1355        for j in 0..m {
1356            assert!(
1357                (result[j] - gamma[j]).abs() < 1e-10,
1358                "id ∘ γ should be γ at j={j}"
1359            );
1360        }
1361
1362        // γ ∘ identity = γ
1363        let result2 = compose_warps(&gamma, &t, &t);
1364        for j in 0..m {
1365            assert!(
1366                (result2[j] - gamma[j]).abs() < 1e-10,
1367                "γ ∘ id should be γ at j={j}"
1368            );
1369        }
1370    }
1371
1372    #[test]
1373    fn test_compose_warps_associativity() {
1374        // (γ₁ ∘ γ₂) ∘ γ₃ ≈ γ₁ ∘ (γ₂ ∘ γ₃)
1375        let m = 50;
1376        let t = uniform_grid(m);
1377        let g1: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
1378        let g2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
1379        let g3: Vec<f64> = t.iter().map(|&ti| 0.5 * ti + 0.5 * ti * ti).collect();
1380
1381        let g12 = compose_warps(&g1, &g2, &t);
1382        let left = compose_warps(&g12, &g3, &t); // (g1∘g2) ∘ g3
1383
1384        let g23 = compose_warps(&g2, &g3, &t);
1385        let right = compose_warps(&g1, &g23, &t); // g1 ∘ (g2∘g3)
1386
1387        for j in 0..m {
1388            assert!(
1389                (left[j] - right[j]).abs() < 0.05,
1390                "Composition should be roughly associative at j={j}: left={:.4}, right={:.4}",
1391                left[j],
1392                right[j]
1393            );
1394        }
1395    }
1396
1397    #[test]
1398    fn test_compose_warps_preserves_domain() {
1399        let m = 50;
1400        let t = uniform_grid(m);
1401        let g1: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
1402        let g2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
1403
1404        let composed = compose_warps(&g1, &g2, &t);
1405        assert!(
1406            (composed[0] - t[0]).abs() < 1e-10,
1407            "Composed warp should start at domain start"
1408        );
1409        assert!(
1410            (composed[m - 1] - t[m - 1]).abs() < 1e-10,
1411            "Composed warp should end at domain end"
1412        );
1413    }
1414
1415    // ── Elastic align pair ──
1416
1417    #[test]
1418    fn test_align_identical_curves() {
1419        let m = 50;
1420        let t = uniform_grid(m);
1421        let f: Vec<f64> = t
1422            .iter()
1423            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
1424            .collect();
1425
1426        let result = elastic_align_pair(&f, &f, &t);
1427
1428        // Distance should be near zero
1429        assert!(
1430            result.distance < 0.1,
1431            "Distance between identical curves should be near 0, got {}",
1432            result.distance
1433        );
1434
1435        // Warp should be near identity
1436        for j in 0..m {
1437            assert!(
1438                (result.gamma[j] - t[j]).abs() < 0.1,
1439                "Warp should be near identity at j={j}: gamma={}, t={}",
1440                result.gamma[j],
1441                t[j]
1442            );
1443        }
1444    }
1445
1446    #[test]
1447    fn test_align_pair_valid_output() {
1448        let data = make_test_data(2, 50, 42);
1449        let t = uniform_grid(50);
1450        let f1 = data.row(0);
1451        let f2 = data.row(1);
1452
1453        let result = elastic_align_pair(&f1, &f2, &t);
1454
1455        assert_eq!(result.gamma.len(), 50);
1456        assert_eq!(result.f_aligned.len(), 50);
1457        assert!(result.distance >= 0.0);
1458
1459        // Warp should be monotone
1460        for j in 1..50 {
1461            assert!(
1462                result.gamma[j] >= result.gamma[j - 1],
1463                "Warp should be monotone at j={j}"
1464            );
1465        }
1466    }
1467
1468    #[test]
1469    fn test_align_pair_warp_boundaries() {
1470        let data = make_test_data(2, 50, 42);
1471        let t = uniform_grid(50);
1472        let f1 = data.row(0);
1473        let f2 = data.row(1);
1474
1475        let result = elastic_align_pair(&f1, &f2, &t);
1476        assert!(
1477            (result.gamma[0] - t[0]).abs() < 1e-12,
1478            "Warp should start at domain start"
1479        );
1480        assert!(
1481            (result.gamma[49] - t[49]).abs() < 1e-12,
1482            "Warp should end at domain end"
1483        );
1484    }
1485
1486    #[test]
1487    fn test_align_shifted_sine() {
1488        // Two sines with a phase shift — alignment should reduce distance
1489        let m = 80;
1490        let t = uniform_grid(m);
1491        let f1: Vec<f64> = t
1492            .iter()
1493            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
1494            .collect();
1495        let f2: Vec<f64> = t
1496            .iter()
1497            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
1498            .collect();
1499
1500        let weights = simpsons_weights(&t);
1501        let l2_before = l2_distance(&f1, &f2, &weights);
1502        let result = elastic_align_pair(&f1, &f2, &t);
1503        let l2_after = l2_distance(&f1, &result.f_aligned, &weights);
1504
1505        assert!(
1506            l2_after < l2_before + 0.01,
1507            "Alignment should not increase L2 distance: before={l2_before:.4}, after={l2_after:.4}"
1508        );
1509    }
1510
1511    #[test]
1512    fn test_align_pair_aligned_curve_is_finite() {
1513        let data = make_test_data(2, 50, 77);
1514        let t = uniform_grid(50);
1515        let f1 = data.row(0);
1516        let f2 = data.row(1);
1517
1518        let result = elastic_align_pair(&f1, &f2, &t);
1519        for j in 0..50 {
1520            assert!(
1521                result.f_aligned[j].is_finite(),
1522                "Aligned curve should be finite at j={j}"
1523            );
1524        }
1525    }
1526
1527    #[test]
1528    fn test_align_pair_minimum_grid() {
1529        // Minimum viable grid: m = 2
1530        let t = vec![0.0, 1.0];
1531        let f1 = vec![0.0, 1.0];
1532        let f2 = vec![0.0, 2.0];
1533        let result = elastic_align_pair(&f1, &f2, &t);
1534        assert_eq!(result.gamma.len(), 2);
1535        assert_eq!(result.f_aligned.len(), 2);
1536        assert!(result.distance >= 0.0);
1537    }
1538
1539    // ── Elastic distance ──
1540
1541    #[test]
1542    fn test_elastic_distance_symmetric() {
1543        let data = make_test_data(3, 50, 42);
1544        let t = uniform_grid(50);
1545        let f1 = data.row(0);
1546        let f2 = data.row(1);
1547
1548        let d12 = elastic_distance(&f1, &f2, &t);
1549        let d21 = elastic_distance(&f2, &f1, &t);
1550
1551        // Should be approximately symmetric (DP is not perfectly symmetric)
1552        assert!(
1553            (d12 - d21).abs() < d12.max(d21) * 0.3 + 0.01,
1554            "Elastic distance should be roughly symmetric: d12={d12}, d21={d21}"
1555        );
1556    }
1557
1558    #[test]
1559    fn test_elastic_distance_nonneg() {
1560        let data = make_test_data(3, 50, 42);
1561        let t = uniform_grid(50);
1562
1563        for i in 0..3 {
1564            for j in 0..3 {
1565                let fi = data.row(i);
1566                let fj = data.row(j);
1567                let d = elastic_distance(&fi, &fj, &t);
1568                assert!(d >= 0.0, "Elastic distance should be non-negative");
1569            }
1570        }
1571    }
1572
1573    #[test]
1574    fn test_elastic_distance_self_near_zero() {
1575        let data = make_test_data(3, 50, 42);
1576        let t = uniform_grid(50);
1577
1578        for i in 0..3 {
1579            let fi = data.row(i);
1580            let d = elastic_distance(&fi, &fi, &t);
1581            assert!(
1582                d < 0.1,
1583                "Self-distance should be near zero, got {d} for curve {i}"
1584            );
1585        }
1586    }
1587
1588    #[test]
1589    fn test_elastic_distance_triangle_inequality() {
1590        let data = make_test_data(3, 50, 42);
1591        let t = uniform_grid(50);
1592        let f0 = data.row(0);
1593        let f1 = data.row(1);
1594        let f2 = data.row(2);
1595
1596        let d01 = elastic_distance(&f0, &f1, &t);
1597        let d12 = elastic_distance(&f1, &f2, &t);
1598        let d02 = elastic_distance(&f0, &f2, &t);
1599
1600        // Relaxed triangle inequality (DP alignment is approximate)
1601        let slack = 0.5;
1602        assert!(
1603            d02 <= d01 + d12 + slack,
1604            "Triangle inequality (relaxed): d02={d02:.4} > d01={d01:.4} + d12={d12:.4} + {slack}"
1605        );
1606    }
1607
1608    #[test]
1609    fn test_elastic_distance_different_shapes_nonzero() {
1610        let m = 50;
1611        let t = uniform_grid(m);
1612        let f1: Vec<f64> = t.to_vec(); // linear
1613        let f2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect(); // quadratic
1614
1615        let d = elastic_distance(&f1, &f2, &t);
1616        assert!(
1617            d > 0.01,
1618            "Distance between different shapes should be > 0, got {d}"
1619        );
1620    }
1621
1622    // ── Self distance matrix ──
1623
1624    #[test]
1625    fn test_self_distance_matrix_symmetric() {
1626        let data = make_test_data(5, 30, 42);
1627        let t = uniform_grid(30);
1628
1629        let dm = elastic_self_distance_matrix(&data, &t);
1630        let n = dm.nrows();
1631
1632        assert_eq!(dm.shape(), (5, 5));
1633
1634        // Zero diagonal
1635        for i in 0..n {
1636            assert!(dm[(i, i)].abs() < 1e-12, "Diagonal should be zero");
1637        }
1638
1639        // Symmetric
1640        for i in 0..n {
1641            for j in (i + 1)..n {
1642                assert!(
1643                    (dm[(i, j)] - dm[(j, i)]).abs() < 1e-12,
1644                    "Matrix should be symmetric at ({i},{j})"
1645                );
1646            }
1647        }
1648    }
1649
1650    #[test]
1651    fn test_self_distance_matrix_nonneg() {
1652        let data = make_test_data(4, 30, 42);
1653        let t = uniform_grid(30);
1654        let dm = elastic_self_distance_matrix(&data, &t);
1655
1656        for i in 0..4 {
1657            for j in 0..4 {
1658                assert!(
1659                    dm[(i, j)] >= 0.0,
1660                    "Distance matrix entries should be non-negative at ({i},{j})"
1661                );
1662            }
1663        }
1664    }
1665
1666    #[test]
1667    fn test_self_distance_matrix_single_curve() {
1668        let data = make_test_data(1, 30, 42);
1669        let t = uniform_grid(30);
1670        let dm = elastic_self_distance_matrix(&data, &t);
1671        assert_eq!(dm.shape(), (1, 1));
1672        assert!(dm[(0, 0)].abs() < 1e-12);
1673    }
1674
1675    #[test]
1676    fn test_self_distance_matrix_consistent_with_pairwise() {
1677        let data = make_test_data(4, 30, 42);
1678        let t = uniform_grid(30);
1679
1680        let dm = elastic_self_distance_matrix(&data, &t);
1681
1682        // Check a few entries match direct elastic_distance calls
1683        for i in 0..4 {
1684            for j in (i + 1)..4 {
1685                let fi = data.row(i);
1686                let fj = data.row(j);
1687                let d_direct = elastic_distance(&fi, &fj, &t);
1688                assert!(
1689                    (dm[(i, j)] - d_direct).abs() < 1e-10,
1690                    "Matrix entry ({i},{j})={:.6} should match pairwise {d_direct:.6}",
1691                    dm[(i, j)]
1692                );
1693            }
1694        }
1695    }
1696
1697    // ── Karcher mean ──
1698
1699    #[test]
1700    fn test_karcher_mean_identical_curves() {
1701        let m = 50;
1702        let t = uniform_grid(m);
1703        let f: Vec<f64> = t
1704            .iter()
1705            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
1706            .collect();
1707
1708        // Create 5 identical curves
1709        let mut data = FdMatrix::zeros(5, m);
1710        for i in 0..5 {
1711            for j in 0..m {
1712                data[(i, j)] = f[j];
1713            }
1714        }
1715
1716        let result = karcher_mean(&data, &t, 10, 1e-4);
1717
1718        assert_eq!(result.mean.len(), m);
1719        assert!(result.n_iter <= 10);
1720    }
1721
1722    #[test]
1723    fn test_karcher_mean_output_shape() {
1724        let data = make_test_data(15, 50, 42);
1725        let t = uniform_grid(50);
1726
1727        let result = karcher_mean(&data, &t, 5, 1e-3);
1728
1729        assert_eq!(result.mean.len(), 50);
1730        assert_eq!(result.mean_srsf.len(), 50);
1731        assert_eq!(result.gammas.shape(), (15, 50));
1732        assert_eq!(result.aligned_data.shape(), (15, 50));
1733        assert!(result.n_iter <= 5);
1734    }
1735
1736    #[test]
1737    fn test_karcher_mean_warps_are_valid() {
1738        let data = make_test_data(10, 40, 42);
1739        let t = uniform_grid(40);
1740
1741        let result = karcher_mean(&data, &t, 5, 1e-3);
1742
1743        for i in 0..10 {
1744            // Boundary values
1745            assert!(
1746                (result.gammas[(i, 0)] - t[0]).abs() < 1e-10,
1747                "Warp {i} should start at domain start"
1748            );
1749            assert!(
1750                (result.gammas[(i, 39)] - t[39]).abs() < 1e-10,
1751                "Warp {i} should end at domain end"
1752            );
1753            // Monotonicity
1754            for j in 1..40 {
1755                assert!(
1756                    result.gammas[(i, j)] >= result.gammas[(i, j - 1)],
1757                    "Warp {i} should be monotone at j={j}"
1758                );
1759            }
1760        }
1761    }
1762
1763    #[test]
1764    fn test_karcher_mean_aligned_data_is_finite() {
1765        let data = make_test_data(8, 40, 42);
1766        let t = uniform_grid(40);
1767        let result = karcher_mean(&data, &t, 5, 1e-3);
1768
1769        for i in 0..8 {
1770            for j in 0..40 {
1771                assert!(
1772                    result.aligned_data[(i, j)].is_finite(),
1773                    "Aligned data should be finite at ({i},{j})"
1774                );
1775            }
1776        }
1777    }
1778
1779    #[test]
1780    fn test_karcher_mean_srsf_is_finite() {
1781        let data = make_test_data(8, 40, 42);
1782        let t = uniform_grid(40);
1783        let result = karcher_mean(&data, &t, 5, 1e-3);
1784
1785        for j in 0..40 {
1786            assert!(
1787                result.mean_srsf[j].is_finite(),
1788                "Mean SRSF should be finite at j={j}"
1789            );
1790            assert!(
1791                result.mean[j].is_finite(),
1792                "Mean curve should be finite at j={j}"
1793            );
1794        }
1795    }
1796
1797    #[test]
1798    fn test_karcher_mean_single_iteration() {
1799        let data = make_test_data(10, 40, 42);
1800        let t = uniform_grid(40);
1801        let result = karcher_mean(&data, &t, 1, 1e-10);
1802
1803        assert_eq!(result.n_iter, 1);
1804        assert_eq!(result.mean.len(), 40);
1805        // With only 1 iteration, still produces valid output
1806        for j in 0..40 {
1807            assert!(result.mean[j].is_finite());
1808        }
1809    }
1810
1811    // ── Align to target ──
1812
1813    #[test]
1814    fn test_align_to_target_valid() {
1815        let data = make_test_data(10, 40, 42);
1816        let t = uniform_grid(40);
1817        let target = data.row(0);
1818
1819        let result = align_to_target(&data, &target, &t);
1820
1821        assert_eq!(result.gammas.shape(), (10, 40));
1822        assert_eq!(result.aligned_data.shape(), (10, 40));
1823        assert_eq!(result.distances.len(), 10);
1824
1825        // All distances should be non-negative
1826        for &d in &result.distances {
1827            assert!(d >= 0.0);
1828        }
1829    }
1830
1831    #[test]
1832    fn test_align_to_target_self_near_zero() {
1833        let data = make_test_data(5, 40, 42);
1834        let t = uniform_grid(40);
1835        let target = data.row(0);
1836
1837        let result = align_to_target(&data, &target, &t);
1838
1839        // Distance of target to itself should be near zero
1840        assert!(
1841            result.distances[0] < 0.1,
1842            "Self-alignment distance should be near zero, got {}",
1843            result.distances[0]
1844        );
1845    }
1846
1847    #[test]
1848    fn test_align_to_target_warps_are_monotone() {
1849        let data = make_test_data(8, 40, 42);
1850        let t = uniform_grid(40);
1851        let target = data.row(0);
1852        let result = align_to_target(&data, &target, &t);
1853
1854        for i in 0..8 {
1855            for j in 1..40 {
1856                assert!(
1857                    result.gammas[(i, j)] >= result.gammas[(i, j - 1)],
1858                    "Warp for curve {i} should be monotone at j={j}"
1859                );
1860            }
1861        }
1862    }
1863
1864    #[test]
1865    fn test_align_to_target_aligned_data_finite() {
1866        let data = make_test_data(6, 40, 42);
1867        let t = uniform_grid(40);
1868        let target = data.row(0);
1869        let result = align_to_target(&data, &target, &t);
1870
1871        for i in 0..6 {
1872            for j in 0..40 {
1873                assert!(
1874                    result.aligned_data[(i, j)].is_finite(),
1875                    "Aligned data should be finite at ({i},{j})"
1876                );
1877            }
1878        }
1879    }
1880
1881    // ── Cross distance matrix ──
1882
1883    #[test]
1884    fn test_cross_distance_matrix_shape() {
1885        let data1 = make_test_data(3, 30, 42);
1886        let data2 = make_test_data(4, 30, 99);
1887        let t = uniform_grid(30);
1888
1889        let dm = elastic_cross_distance_matrix(&data1, &data2, &t);
1890        assert_eq!(dm.shape(), (3, 4));
1891
1892        // All non-negative
1893        for i in 0..3 {
1894            for j in 0..4 {
1895                assert!(dm[(i, j)] >= 0.0);
1896            }
1897        }
1898    }
1899
1900    #[test]
1901    fn test_cross_distance_matrix_self_matches_self_matrix() {
1902        // cross_distance(data, data) should have zero diagonal (approximately)
1903        let data = make_test_data(4, 30, 42);
1904        let t = uniform_grid(30);
1905
1906        let cross = elastic_cross_distance_matrix(&data, &data, &t);
1907        for i in 0..4 {
1908            assert!(
1909                cross[(i, i)] < 0.1,
1910                "Cross distance (self) diagonal should be near zero: got {}",
1911                cross[(i, i)]
1912            );
1913        }
1914    }
1915
1916    #[test]
1917    fn test_cross_distance_matrix_consistent_with_pairwise() {
1918        let data1 = make_test_data(3, 30, 42);
1919        let data2 = make_test_data(2, 30, 99);
1920        let t = uniform_grid(30);
1921
1922        let dm = elastic_cross_distance_matrix(&data1, &data2, &t);
1923
1924        for i in 0..3 {
1925            for j in 0..2 {
1926                let fi = data1.row(i);
1927                let fj = data2.row(j);
1928                let d_direct = elastic_distance(&fi, &fj, &t);
1929                assert!(
1930                    (dm[(i, j)] - d_direct).abs() < 1e-10,
1931                    "Cross matrix ({i},{j})={:.6} should match pairwise {d_direct:.6}",
1932                    dm[(i, j)]
1933                );
1934            }
1935        }
1936    }
1937
1938    // ── align_srsf_pair ──
1939
1940    #[test]
1941    fn test_align_srsf_pair_identity() {
1942        let m = 50;
1943        let t = uniform_grid(m);
1944        let f: Vec<f64> = t
1945            .iter()
1946            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
1947            .collect();
1948        let q = srsf_single(&f, &t);
1949
1950        let (gamma, q_aligned) = align_srsf_pair(&q, &q, &t);
1951
1952        // Warp should be near identity
1953        for j in 0..m {
1954            assert!(
1955                (gamma[j] - t[j]).abs() < 0.15,
1956                "Self-SRSF alignment warp should be near identity at j={j}"
1957            );
1958        }
1959
1960        // Aligned SRSF should be close to original
1961        let weights = simpsons_weights(&t);
1962        let dist = l2_distance(&q, &q_aligned, &weights);
1963        assert!(
1964            dist < 0.5,
1965            "Self-aligned SRSF distance should be small, got {dist}"
1966        );
1967    }
1968
1969    // ── srsf_single ──
1970
1971    #[test]
1972    fn test_srsf_single_matches_matrix_version() {
1973        let m = 50;
1974        let t = uniform_grid(m);
1975        let f: Vec<f64> = t.iter().map(|&ti| ti * ti + ti).collect();
1976
1977        let q_single = srsf_single(&f, &t);
1978
1979        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
1980        let q_mat = srsf_transform(&mat, &t);
1981        let q_from_mat = q_mat.row(0);
1982
1983        for j in 0..m {
1984            assert!(
1985                (q_single[j] - q_from_mat[j]).abs() < 1e-12,
1986                "srsf_single should match srsf_transform at j={j}"
1987            );
1988        }
1989    }
1990
1991    // ── gcd ──
1992
1993    #[test]
1994    fn test_gcd_basic() {
1995        assert_eq!(gcd(1, 1), 1);
1996        assert_eq!(gcd(6, 4), 2);
1997        assert_eq!(gcd(7, 5), 1);
1998        assert_eq!(gcd(12, 8), 4);
1999        assert_eq!(gcd(7, 0), 7);
2000        assert_eq!(gcd(0, 5), 5);
2001    }
2002
2003    // ── generate_coprime_nbhd ──
2004
2005    #[test]
2006    fn test_coprime_nbhd_count() {
2007        assert_eq!(generate_coprime_nbhd(1).len(), 1); // just (1,1)
2008        assert_eq!(generate_coprime_nbhd(7).len(), 35);
2009    }
2010
2011    #[test]
2012    fn test_coprime_nbhd_matches_const() {
2013        let generated = generate_coprime_nbhd(7);
2014        assert_eq!(generated.len(), COPRIME_NBHD_7.len());
2015        for (i, pair) in generated.iter().enumerate() {
2016            assert_eq!(*pair, COPRIME_NBHD_7[i], "mismatch at index {i}");
2017        }
2018    }
2019
2020    #[test]
2021    fn test_coprime_nbhd_all_coprime() {
2022        for &(i, j) in &COPRIME_NBHD_7 {
2023            assert_eq!(gcd(i, j), 1, "({i},{j}) should be coprime");
2024            assert!((1..=7).contains(&i));
2025            assert!((1..=7).contains(&j));
2026        }
2027    }
2028
2029    // ── dp_edge_weight ──
2030
2031    #[test]
2032    fn test_dp_edge_weight_diagonal() {
2033        // Diagonal move (1,1): weight = (q1[sc] - sqrt(1)*q2[sr])^2 * h
2034        let t = uniform_grid(10);
2035        let q1 = vec![1.0; 10];
2036        let q2 = vec![1.0; 10];
2037        // Identical SRSFs: weight should be 0
2038        let w = dp_edge_weight(&q1, &q2, &t, 0, 1, 0, 1);
2039        assert!(w.abs() < 1e-12, "identical SRSFs should have zero cost");
2040    }
2041
2042    #[test]
2043    fn test_dp_edge_weight_non_diagonal() {
2044        // Move (1,2): n1=2, n2=1, slope = h/(2h) = 0.5
2045        let t = uniform_grid(10);
2046        let q1 = vec![1.0; 10];
2047        let q2 = vec![0.0; 10];
2048        let w = dp_edge_weight(&q1, &q2, &t, 0, 2, 0, 1);
2049        // diff = q1[0] - sqrt(0.5)*q2[0] = 1.0 - 0 = 1.0
2050        // weight = 1.0^2 * 1.0 * (t[2]-t[0]) = 2/9
2051        let expected = 2.0 / 9.0;
2052        assert!(
2053            (w - expected).abs() < 1e-10,
2054            "dp_edge_weight (1,2): expected {expected}, got {w}"
2055        );
2056    }
2057
2058    #[test]
2059    fn test_dp_edge_weight_zero_span() {
2060        let t = uniform_grid(10);
2061        let q1 = vec![1.0; 10];
2062        let q2 = vec![1.0; 10];
2063        // n1=0: should return INFINITY
2064        assert_eq!(dp_edge_weight(&q1, &q2, &t, 3, 3, 0, 1), f64::INFINITY);
2065        // n2=0: should return INFINITY
2066        assert_eq!(dp_edge_weight(&q1, &q2, &t, 0, 1, 3, 3), f64::INFINITY);
2067    }
2068
2069    // ── DP alignment quality ──
2070
2071    #[test]
2072    fn test_alignment_improves_distance() {
2073        // Aligned SRSF distance should be less than unaligned SRSF distance
2074        let m = 50;
2075        let t = uniform_grid(m);
2076        let f1: Vec<f64> = t
2077            .iter()
2078            .map(|&x| (2.0 * std::f64::consts::PI * x).sin())
2079            .collect();
2080        // Use a larger shift so improvement is clear
2081        let f2: Vec<f64> = t
2082            .iter()
2083            .map(|&x| (2.0 * std::f64::consts::PI * (x + 0.2)).sin())
2084            .collect();
2085
2086        let q1 = srsf_single(&f1, &t);
2087        let q2 = srsf_single(&f2, &t);
2088        let weights = simpsons_weights(&t);
2089        let unaligned_srsf_dist = l2_distance(&q1, &q2, &weights);
2090
2091        let result = elastic_align_pair(&f1, &f2, &t);
2092
2093        assert!(
2094            result.distance <= unaligned_srsf_dist + 1e-6,
2095            "aligned SRSF dist ({}) should be <= unaligned SRSF dist ({})",
2096            result.distance,
2097            unaligned_srsf_dist
2098        );
2099    }
2100
2101    // ── Edge case: constant data ──
2102
2103    #[test]
2104    fn test_alignment_constant_curves() {
2105        let m = 30;
2106        let t = uniform_grid(m);
2107        let f1 = vec![5.0; m];
2108        let f2 = vec![5.0; m];
2109
2110        let result = elastic_align_pair(&f1, &f2, &t);
2111        assert!(
2112            result.distance < 0.01,
2113            "Constant curves: distance should be ~0"
2114        );
2115        assert_eq!(result.f_aligned.len(), m);
2116    }
2117
2118    #[test]
2119    fn test_karcher_mean_constant_curves() {
2120        let m = 30;
2121        let t = uniform_grid(m);
2122        let mut data = FdMatrix::zeros(5, m);
2123        for i in 0..5 {
2124            for j in 0..m {
2125                data[(i, j)] = 3.0;
2126            }
2127        }
2128
2129        let result = karcher_mean(&data, &t, 5, 1e-4);
2130        for j in 0..m {
2131            assert!(
2132                (result.mean[j] - 3.0).abs() < 0.5,
2133                "Mean of constant curves should be near 3.0, got {} at j={j}",
2134                result.mean[j]
2135            );
2136        }
2137    }
2138}