Skip to main content

fdars_core/
alignment.rs

1//! Elastic alignment and SRSF (Square-Root Slope Function) transforms.
2//!
3//! This module provides phase-amplitude separation for functional data via
4//! the elastic framework. Key capabilities:
5//!
6//! - [`srsf_transform`] / [`srsf_inverse`] — SRSF representation and reconstruction
7//! - [`elastic_align_pair`] — Pairwise curve alignment via dynamic programming
8//! - [`elastic_distance`] — Elastic (Fisher-Rao) distance between curves
9//! - [`align_to_target`] — Align a set of curves to a common target
10//! - [`karcher_mean`] — Karcher (Fréchet) mean in the elastic metric
11//! - [`elastic_self_distance_matrix`] / [`elastic_cross_distance_matrix`] — Distance matrices
12//! - [`reparameterize_curve`] / [`compose_warps`] — Warping utilities
13
14use crate::fdata::{deriv_1d, mean_1d};
15use crate::helpers::{l2_distance, simpsons_weights};
16use crate::iter_maybe_parallel;
17use crate::matrix::FdMatrix;
18#[cfg(feature = "parallel")]
19use rayon::iter::ParallelIterator;
20
21// ─── Types ──────────────────────────────────────────────────────────────────
22
/// Outcome of warping a single curve onto a reference curve.
#[derive(Clone, Debug)]
pub struct AlignmentResult {
    /// Warping function γ sampled on the evaluation grid (maps the domain onto itself).
    pub gamma: Vec<f64>,
    /// The input curve after reparameterization by γ.
    pub f_aligned: Vec<f64>,
    /// Elastic distance between the reference and the aligned curve.
    pub distance: f64,
}
33
34/// Result of aligning a set of curves to a common target.
35#[derive(Debug, Clone)]
36pub struct AlignmentSetResult {
37    /// Warping functions (n × m).
38    pub gammas: FdMatrix,
39    /// Aligned curves (n × m).
40    pub aligned_data: FdMatrix,
41    /// Elastic distances for each curve.
42    pub distances: Vec<f64>,
43}
44
45/// Result of the Karcher mean computation.
46#[derive(Debug, Clone)]
47pub struct KarcherMeanResult {
48    /// Karcher mean curve.
49    pub mean: Vec<f64>,
50    /// SRSF of the Karcher mean.
51    pub mean_srsf: Vec<f64>,
52    /// Final warping functions (n × m).
53    pub gammas: FdMatrix,
54    /// Curves aligned to the mean (n × m).
55    pub aligned_data: FdMatrix,
56    /// Number of iterations used.
57    pub n_iter: usize,
58    /// Whether the algorithm converged.
59    pub converged: bool,
60}
61
62// ─── Private helpers ────────────────────────────────────────────────────────
63
/// Piecewise-linear interpolation of the table `(x, y)` at the query point `t`.
///
/// `x` is expected to be non-empty, free of NaN, and sorted in non-decreasing
/// order. Queries at or beyond either end are clamped to the corresponding end
/// value (constant extrapolation). When `t` coincides with a knot, that knot's
/// `y` value is returned; for a repeated knot, which duplicate is used is
/// unspecified (it comes from `binary_search_by`).
///
/// A NaN query returns NaN. If the table violates the ordering/NaN
/// precondition such that no bracketing interval exists, NaN is returned
/// rather than indexing out of bounds.
///
/// # Panics
/// Panics if `x` is empty or `y` is shorter than `x`.
fn linear_interp(x: &[f64], y: &[f64], t: f64) -> f64 {
    // NaN is unordered: it would fail both range checks below and reach the
    // binary search, whose comparator must not see an unordered value.
    if t.is_nan() {
        return f64::NAN;
    }
    if t <= x[0] {
        return y[0];
    }
    let last = x.len() - 1;
    if t >= x[last] {
        return y[last];
    }

    // `total_cmp` is a total order, so the comparator itself can never panic.
    let idx = match x.binary_search_by(|v| v.total_cmp(&t)) {
        Ok(i) => return y[i],
        // For a well-formed table, x[i-1] < t < x[i] with 1 <= i <= last.
        Err(i) if (1..=last).contains(&i) => i,
        // Only reachable when `x` is unsorted or contains NaN.
        Err(_) => return f64::NAN,
    };

    let t0 = x[idx - 1];
    let t1 = x[idx];
    let y0 = y[idx - 1];
    let y1 = y[idx];
    y0 + (y1 - y0) * (t - t0) / (t1 - t0)
}
86
/// Running trapezoid-rule integral of `y` sampled at abscissae `x`.
///
/// Entry `k` of the output approximates the integral from `x[0]` to `x[k]`;
/// the first entry is always `0.0`, and the output has the same length as `y`.
/// `x` must be at least as long as `y`.
fn cumulative_trapz(y: &[f64], x: &[f64]) -> Vec<f64> {
    if y.is_empty() {
        return Vec::new();
    }
    let mut running = 0.0;
    std::iter::once(0.0)
        .chain((1..y.len()).map(|k| {
            running += 0.5 * (y[k] + y[k - 1]) * (x[k] - x[k - 1]);
            running
        }))
        .collect()
}
96
/// Ensure γ is a valid warping: monotone non-decreasing, confined to the
/// domain `[argvals[0], argvals[n-1]]`, with correct boundary values.
///
/// Interior values are first clamped into the domain, so the subsequent
/// monotone pass can never push the last value above `argvals[n-1]`. NaN
/// entries are mapped to the lower bound and then lifted by the monotone
/// pass to the preceding value.
///
/// # Panics
/// Panics if `argvals` is shorter than `gamma`.
fn normalize_warp(gamma: &mut [f64], argvals: &[f64]) {
    let n = gamma.len();
    if n == 0 {
        return;
    }

    let lo = argvals[0];
    let hi = argvals[n - 1];

    // Fix boundaries
    gamma[0] = lo;
    gamma[n - 1] = hi;

    // Confine to the domain. `f64::max`/`min` return the non-NaN operand, so a
    // NaN entry becomes `lo` here (and `max`/`min` never panic, unlike `clamp`).
    for g in gamma.iter_mut() {
        *g = g.max(lo).min(hi);
    }

    // Enforce monotonicity; all values are now <= hi, so gamma[n-1] stays hi.
    for i in 1..n {
        if gamma[i] < gamma[i - 1] {
            gamma[i] = gamma[i - 1];
        }
    }
}
115
116// ─── SRSF Transform and Inverse ─────────────────────────────────────────────
117
118/// Compute the Square-Root Slope Function (SRSF) transform.
119///
120/// For each curve f, the SRSF is: `q(t) = sign(f'(t)) * sqrt(|f'(t)|)`
121///
122/// # Arguments
123/// * `data` — Functional data matrix (n × m)
124/// * `argvals` — Evaluation points (length m)
125///
126/// # Returns
127/// FdMatrix of SRSFs with the same shape as input.
128pub fn srsf_transform(data: &FdMatrix, argvals: &[f64]) -> FdMatrix {
129    let (n, m) = data.shape();
130    if n == 0 || m == 0 || argvals.len() != m {
131        return FdMatrix::zeros(n, m);
132    }
133
134    let deriv = deriv_1d(data, argvals, 1);
135
136    let mut result = FdMatrix::zeros(n, m);
137    for i in 0..n {
138        for j in 0..m {
139            let d = deriv[(i, j)];
140            result[(i, j)] = d.signum() * d.abs().sqrt();
141        }
142    }
143    result
144}
145
/// Rebuild a curve from its SRSF.
///
/// Evaluates `f(t) = f0 + ∫ q(s)|q(s)| ds` (integral taken from `argvals[0]`)
/// with the cumulative trapezoid rule, so `f(argvals[0]) == f0`.
///
/// # Arguments
/// * `q` — SRSF values (length m)
/// * `argvals` — Evaluation points (at least length m)
/// * `f0` — Value of the curve at `argvals[0]`
///
/// # Returns
/// The reconstructed curve (length m); empty when `q` is empty.
pub fn srsf_inverse(q: &[f64], argvals: &[f64], f0: f64) -> Vec<f64> {
    if q.is_empty() {
        return Vec::new();
    }

    // q·|q| recovers the (signed) derivative f'.
    let slope = |v: f64| v * v.abs();

    let mut running = 0.0;
    let mut curve = Vec::with_capacity(q.len());
    curve.push(f0 + running);
    for k in 1..q.len() {
        running += 0.5 * (slope(q[k]) + slope(q[k - 1])) * (argvals[k] - argvals[k - 1]);
        curve.push(f0 + running);
    }
    curve
}
169
170// ─── Reparameterization ─────────────────────────────────────────────────────
171
172/// Reparameterize a curve by a warping function.
173///
174/// Computes `f(γ(t))` via linear interpolation.
175///
176/// # Arguments
177/// * `f` — Curve values (length m)
178/// * `argvals` — Evaluation points (length m)
179/// * `gamma` — Warping function values (length m)
180pub fn reparameterize_curve(f: &[f64], argvals: &[f64], gamma: &[f64]) -> Vec<f64> {
181    gamma
182        .iter()
183        .map(|&g| linear_interp(argvals, f, g))
184        .collect()
185}
186
187/// Compose two warping functions: `(γ₁ ∘ γ₂)(t) = γ₁(γ₂(t))`.
188///
189/// # Arguments
190/// * `gamma1` — Outer warping function (length m)
191/// * `gamma2` — Inner warping function (length m)
192/// * `argvals` — Evaluation points (length m)
193pub fn compose_warps(gamma1: &[f64], gamma2: &[f64], argvals: &[f64]) -> Vec<f64> {
194    gamma2
195        .iter()
196        .map(|&g| linear_interp(argvals, gamma1, g))
197        .collect()
198}
199
200// ─── Dynamic Programming Alignment ──────────────────────────────────────────
201
202/// Convert a DP traceback path into a warping function sampled at argvals.
203fn path_to_gamma(path: &[(usize, usize)], argvals: &[f64], grid: &[f64]) -> Vec<f64> {
204    if path.is_empty() {
205        return argvals.to_vec();
206    }
207
208    // Extract the warping from the path: γ maps grid[path[k].0] -> grid[path[k].1]
209    let path_t: Vec<f64> = path.iter().map(|&(i, _)| grid[i]).collect();
210    let path_g: Vec<f64> = path.iter().map(|&(_, j)| grid[j]).collect();
211
212    // Interpolate to get γ at each argval
213    let mut gamma: Vec<f64> = argvals
214        .iter()
215        .map(|&t| linear_interp(&path_t, &path_g, t))
216        .collect();
217
218    normalize_warp(&mut gamma, argvals);
219    gamma
220}
221
/// Relax cell `(i, j)` of the DP table from its three predecessors.
///
/// Costs considered:
/// * diagonal from `(i-1, j-1)`: `(q1_i - q2_j·√(dt_j/dt_i))² · dt_i`
/// * horizontal from `(i, j-1)`: `(q1_i - q2_j)² · dt_j`
/// * vertical from `(i-1, j)`: `(q1_i - q2_j)² · dt_i`
///
/// The cheapest total is written to `curr_row[j]` and the move code
/// (0 = diagonal, 1 = horizontal, 2 = vertical) to `trace[trace_off + j]`.
/// Ties prefer diagonal, then horizontal.
///
/// NOTE(review): only the diagonal move applies the √slope correction; the
/// horizontal/vertical moves reuse the unscaled residual — confirm intended.
#[inline]
fn dp_pick_best(
    prev_row: &[f64],
    curr_row: &mut [f64],
    trace: &mut [u8],
    q1_i: f64,
    q2_j: f64,
    dt_i: f64,
    dt_j: f64,
    j: usize,
    trace_off: usize,
) {
    let resid_diag = q1_i - q2_j * (dt_j / dt_i).sqrt();
    let resid_axis = q1_i - q2_j;

    let diag = prev_row[j - 1] + resid_diag * resid_diag * dt_i;
    let horiz = curr_row[j - 1] + resid_axis * resid_axis * dt_j;
    let vert = prev_row[j] + resid_axis * resid_axis * dt_i;

    let (best, code) = if diag <= horiz && diag <= vert {
        (diag, 0)
    } else if horiz <= vert {
        (horiz, 1)
    } else {
        (vert, 2)
    };
    curr_row[j] = best;
    trace[trace_off + j] = code;
}
258
/// Walk the move table backwards from `(m-1, m-1)` to `(0, 0)`.
///
/// `trace` is an m × m row-major table of move codes
/// (0 = diagonal, 1 = horizontal, anything else = vertical). The visited
/// cells are returned in forward order, starting at `(0, 0)`.
fn dp_traceback(trace: &[u8], m: usize) -> Vec<(usize, usize)> {
    let mut steps = Vec::with_capacity(2 * m);
    let (mut row, mut col) = (m - 1, m - 1);

    loop {
        steps.push((row, col));
        if row == 0 && col == 0 {
            break;
        }
        let (d_row, d_col) = match trace[row * m + col] {
            0 => (1, 1),
            1 => (0, 1),
            _ => (1, 0),
        };
        row -= d_row;
        col -= d_col;
    }

    steps.reverse();
    steps
}
280
281/// Core DP alignment between two SRSFs on a grid.
282///
283/// Finds the optimal warping γ minimizing ‖q₁ - (q₂∘γ)√γ'‖².
284fn dp_alignment_core(q1: &[f64], q2: &[f64], argvals: &[f64]) -> Vec<f64> {
285    let m = argvals.len();
286    if m < 2 {
287        return argvals.to_vec();
288    }
289
290    let grid = argvals;
291    let mut prev_row = vec![f64::MAX; m];
292    let mut curr_row = vec![f64::MAX; m];
293    // 0 = diagonal, 1 = horizontal, 2 = vertical
294    let mut trace = vec![0u8; m * m];
295
296    // First row: can only come from left
297    prev_row[0] = 0.0;
298    for j in 1..m {
299        let dt = grid[j] - grid[j - 1];
300        let val = q1[0] - q2[j];
301        prev_row[j] = prev_row[j - 1] + val * val * dt;
302        trace[j] = 1;
303    }
304
305    // Fill remaining rows
306    for i in 1..m {
307        let dt_i = grid[i] - grid[i - 1];
308        let val = q1[i] - q2[0];
309        curr_row[0] = prev_row[0] + val * val * dt_i;
310        trace[i * m] = 2;
311
312        let trace_off = i * m;
313        for j in 1..m {
314            let dt_j = grid[j] - grid[j - 1];
315            dp_pick_best(
316                &prev_row,
317                &mut curr_row,
318                &mut trace,
319                q1[i],
320                q2[j],
321                dt_i,
322                dt_j,
323                j,
324                trace_off,
325            );
326        }
327
328        std::mem::swap(&mut prev_row, &mut curr_row);
329    }
330
331    let path = dp_traceback(&trace, m);
332    path_to_gamma(&path, argvals, grid)
333}
334
335// ─── Public Alignment Functions ─────────────────────────────────────────────
336
337/// Align curve f2 to curve f1 using the elastic framework.
338///
339/// Computes the optimal warping γ such that f2∘γ is as close as possible
340/// to f1 in the elastic (Fisher-Rao) metric.
341///
342/// # Arguments
343/// * `f1` — Target curve (length m)
344/// * `f2` — Curve to align (length m)
345/// * `argvals` — Evaluation points (length m)
346///
347/// # Returns
348/// [`AlignmentResult`] with warping function, aligned curve, and elastic distance.
349pub fn elastic_align_pair(f1: &[f64], f2: &[f64], argvals: &[f64]) -> AlignmentResult {
350    let m = f1.len();
351
352    // Build single-row FdMatrices for SRSF computation
353    let f1_mat = FdMatrix::from_slice(f1, 1, m).unwrap();
354    let f2_mat = FdMatrix::from_slice(f2, 1, m).unwrap();
355
356    let q1_mat = srsf_transform(&f1_mat, argvals);
357    let q2_mat = srsf_transform(&f2_mat, argvals);
358
359    let q1: Vec<f64> = q1_mat.row(0);
360    let q2: Vec<f64> = q2_mat.row(0);
361
362    // Find optimal warping via DP
363    let gamma = dp_alignment_core(&q1, &q2, argvals);
364
365    // Apply warping to f2
366    let f_aligned = reparameterize_curve(f2, argvals, &gamma);
367
368    // Compute elastic distance: L2 distance between q1 and aligned q2 SRSF
369    let f_aligned_mat = FdMatrix::from_slice(&f_aligned, 1, m).unwrap();
370    let q_aligned_mat = srsf_transform(&f_aligned_mat, argvals);
371    let q_aligned: Vec<f64> = q_aligned_mat.row(0);
372
373    let weights = simpsons_weights(argvals);
374    let distance = l2_distance(&q1, &q_aligned, &weights);
375
376    AlignmentResult {
377        gamma,
378        f_aligned,
379        distance,
380    }
381}
382
383/// Compute the elastic distance between two curves.
384///
385/// This is shorthand for aligning the pair and returning only the distance.
386///
387/// # Arguments
388/// * `f1` — First curve (length m)
389/// * `f2` — Second curve (length m)
390/// * `argvals` — Evaluation points (length m)
391pub fn elastic_distance(f1: &[f64], f2: &[f64], argvals: &[f64]) -> f64 {
392    elastic_align_pair(f1, f2, argvals).distance
393}
394
395/// Align all curves in `data` to a single target curve.
396///
397/// # Arguments
398/// * `data` — Functional data matrix (n × m)
399/// * `target` — Target curve to align to (length m)
400/// * `argvals` — Evaluation points (length m)
401///
402/// # Returns
403/// [`AlignmentSetResult`] with all warping functions, aligned curves, and distances.
404pub fn align_to_target(data: &FdMatrix, target: &[f64], argvals: &[f64]) -> AlignmentSetResult {
405    let (n, m) = data.shape();
406
407    let results: Vec<AlignmentResult> = iter_maybe_parallel!(0..n)
408        .map(|i| {
409            let fi = data.row(i);
410            elastic_align_pair(target, &fi, argvals)
411        })
412        .collect();
413
414    let mut gammas = FdMatrix::zeros(n, m);
415    let mut aligned_data = FdMatrix::zeros(n, m);
416    let mut distances = Vec::with_capacity(n);
417
418    for (i, r) in results.into_iter().enumerate() {
419        for j in 0..m {
420            gammas[(i, j)] = r.gamma[j];
421            aligned_data[(i, j)] = r.f_aligned[j];
422        }
423        distances.push(r.distance);
424    }
425
426    AlignmentSetResult {
427        gammas,
428        aligned_data,
429        distances,
430    }
431}
432
433// ─── Distance Matrices ──────────────────────────────────────────────────────
434
435/// Compute the symmetric elastic distance matrix for a set of curves.
436///
437/// Uses upper-triangle computation with parallelism, following the
438/// `self_distance_matrix` pattern from `metric.rs`.
439///
440/// # Arguments
441/// * `data` — Functional data matrix (n × m)
442/// * `argvals` — Evaluation points (length m)
443///
444/// # Returns
445/// Symmetric n × n distance matrix.
446pub fn elastic_self_distance_matrix(data: &FdMatrix, argvals: &[f64]) -> FdMatrix {
447    let n = data.nrows();
448
449    let upper_vals: Vec<f64> = iter_maybe_parallel!(0..n)
450        .flat_map(|i| {
451            let fi = data.row(i);
452            ((i + 1)..n)
453                .map(|j| {
454                    let fj = data.row(j);
455                    elastic_distance(&fi, &fj, argvals)
456                })
457                .collect::<Vec<_>>()
458        })
459        .collect();
460
461    let mut dist = FdMatrix::zeros(n, n);
462    let mut idx = 0;
463    for i in 0..n {
464        for j in (i + 1)..n {
465            let d = upper_vals[idx];
466            dist[(i, j)] = d;
467            dist[(j, i)] = d;
468            idx += 1;
469        }
470    }
471    dist
472}
473
474/// Compute the elastic distance matrix between two sets of curves.
475///
476/// # Arguments
477/// * `data1` — First dataset (n1 × m)
478/// * `data2` — Second dataset (n2 × m)
479/// * `argvals` — Evaluation points (length m)
480///
481/// # Returns
482/// n1 × n2 distance matrix.
483pub fn elastic_cross_distance_matrix(
484    data1: &FdMatrix,
485    data2: &FdMatrix,
486    argvals: &[f64],
487) -> FdMatrix {
488    let n1 = data1.nrows();
489    let n2 = data2.nrows();
490
491    let vals: Vec<f64> = iter_maybe_parallel!(0..n1)
492        .flat_map(|i| {
493            let fi = data1.row(i);
494            (0..n2)
495                .map(|j| {
496                    let fj = data2.row(j);
497                    elastic_distance(&fi, &fj, argvals)
498                })
499                .collect::<Vec<_>>()
500        })
501        .collect();
502
503    let mut dist = FdMatrix::zeros(n1, n2);
504    for i in 0..n1 {
505        for j in 0..n2 {
506            dist[(i, j)] = vals[i * n2 + j];
507        }
508    }
509    dist
510}
511
512// ─── Karcher Mean ───────────────────────────────────────────────────────────
513
514/// Check convergence of the Karcher mean iteration.
515fn mean_has_converged(q_old: &[f64], q_new: &[f64], weights: &[f64], tol: f64) -> bool {
516    let dist = l2_distance(q_old, q_new, weights);
517    dist < tol
518}
519
520/// Compute a single SRSF from a slice (single-row convenience).
521fn srsf_single(f: &[f64], argvals: &[f64]) -> Vec<f64> {
522    let m = f.len();
523    let mat = FdMatrix::from_slice(f, 1, m).unwrap();
524    let q_mat = srsf_transform(&mat, argvals);
525    q_mat.row(0)
526}
527
528/// Align a single SRSF q2 to q1 and return (gamma, aligned_q).
529fn align_srsf_pair(q1: &[f64], q2: &[f64], argvals: &[f64]) -> (Vec<f64>, Vec<f64>) {
530    let gamma = dp_alignment_core(q1, q2, argvals);
531
532    // Warp q2 by gamma and adjust by sqrt(gamma')
533    let q2_warped = reparameterize_curve(q2, argvals, &gamma);
534
535    // Compute gamma' via finite differences
536    let m = gamma.len();
537    let mut gamma_dot = vec![0.0; m];
538    gamma_dot[0] = (gamma[1] - gamma[0]) / (argvals[1] - argvals[0]);
539    for j in 1..(m - 1) {
540        gamma_dot[j] = (gamma[j + 1] - gamma[j - 1]) / (argvals[j + 1] - argvals[j - 1]);
541    }
542    gamma_dot[m - 1] = (gamma[m - 1] - gamma[m - 2]) / (argvals[m - 1] - argvals[m - 2]);
543
544    // q2_aligned = (q2 ∘ γ) * sqrt(γ')
545    let q2_aligned: Vec<f64> = q2_warped
546        .iter()
547        .zip(gamma_dot.iter())
548        .map(|(&q, &gd)| q * gd.max(0.0).sqrt())
549        .collect();
550
551    (gamma, q2_aligned)
552}
553
554/// Compute the Karcher (Fréchet) mean in the elastic metric.
555///
556/// Iteratively aligns all curves to the current mean estimate in SRSF space,
557/// computes the pointwise mean of aligned SRSFs, and reconstructs the mean curve.
558///
559/// # Arguments
560/// * `data` — Functional data matrix (n × m)
561/// * `argvals` — Evaluation points (length m)
562/// * `max_iter` — Maximum number of iterations
563/// * `tol` — Convergence tolerance for the SRSF mean
564///
565/// # Returns
566/// [`KarcherMeanResult`] with mean curve, warping functions, aligned data, and convergence info.
567///
568/// # Examples
569///
570/// ```
571/// use fdars_core::simulation::{sim_fundata, EFunType, EValType};
572/// use fdars_core::alignment::karcher_mean;
573///
574/// let t: Vec<f64> = (0..50).map(|i| i as f64 / 49.0).collect();
575/// let data = sim_fundata(20, &t, 3, EFunType::Fourier, EValType::Exponential, Some(42));
576///
577/// let result = karcher_mean(&data, &t, 20, 1e-4);
578/// assert_eq!(result.mean.len(), 50);
579/// assert!(result.n_iter <= 20);
580/// ```
581/// Accumulate alignment results: store gammas and return the mean of aligned SRSFs.
582fn accumulate_alignments(
583    results: &[(Vec<f64>, Vec<f64>)],
584    gammas: &mut FdMatrix,
585    m: usize,
586    n: usize,
587) -> Vec<f64> {
588    let mut mu_q_new = vec![0.0; m];
589    for (i, (gamma, q_aligned)) in results.iter().enumerate() {
590        for j in 0..m {
591            gammas[(i, j)] = gamma[j];
592            mu_q_new[j] += q_aligned[j];
593        }
594    }
595    for j in 0..m {
596        mu_q_new[j] /= n as f64;
597    }
598    mu_q_new
599}
600
601/// Apply stored warps to original curves to produce aligned data.
602fn apply_stored_warps(data: &FdMatrix, gammas: &FdMatrix, argvals: &[f64]) -> FdMatrix {
603    let (n, m) = data.shape();
604    let mut aligned = FdMatrix::zeros(n, m);
605    for i in 0..n {
606        let fi = data.row(i);
607        let gamma: Vec<f64> = (0..m).map(|j| gammas[(i, j)]).collect();
608        let f_aligned = reparameterize_curve(&fi, argvals, &gamma);
609        for j in 0..m {
610            aligned[(i, j)] = f_aligned[j];
611        }
612    }
613    aligned
614}
615
616pub fn karcher_mean(
617    data: &FdMatrix,
618    argvals: &[f64],
619    max_iter: usize,
620    tol: f64,
621) -> KarcherMeanResult {
622    let (n, m) = data.shape();
623    let weights = simpsons_weights(argvals);
624
625    let mut mu = mean_1d(data);
626    let mut mu_q = srsf_single(&mu, argvals);
627
628    let mut converged = false;
629    let mut n_iter = 0;
630    let mut final_gammas = FdMatrix::zeros(n, m);
631
632    for iter in 0..max_iter {
633        n_iter = iter + 1;
634
635        let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
636            .map(|i| {
637                let fi = data.row(i);
638                let qi = srsf_single(&fi, argvals);
639                align_srsf_pair(&mu_q, &qi, argvals)
640            })
641            .collect();
642
643        let mu_q_new = accumulate_alignments(&align_results, &mut final_gammas, m, n);
644
645        if mean_has_converged(&mu_q, &mu_q_new, &weights, tol) {
646            converged = true;
647            mu_q = mu_q_new;
648            break;
649        }
650
651        mu_q = mu_q_new;
652        mu = srsf_inverse(&mu_q, argvals, mu[0]);
653    }
654
655    let initial_mean = mean_1d(data);
656    mu = srsf_inverse(&mu_q, argvals, initial_mean[0]);
657    let final_aligned = apply_stored_warps(data, &final_gammas, argvals);
658
659    KarcherMeanResult {
660        mean: mu,
661        mean_srsf: mu_q,
662        gammas: final_gammas,
663        aligned_data: final_aligned,
664        n_iter,
665        converged,
666    }
667}
668
669// ─── Tests ──────────────────────────────────────────────────────────────────
670
671#[cfg(test)]
672mod tests {
673    use super::*;
674    use crate::simulation::{sim_fundata, EFunType, EValType};
675
676    fn uniform_grid(m: usize) -> Vec<f64> {
677        (0..m).map(|i| i as f64 / (m - 1) as f64).collect()
678    }
679
680    fn make_test_data(n: usize, m: usize, seed: u64) -> FdMatrix {
681        let t = uniform_grid(m);
682        sim_fundata(
683            n,
684            &t,
685            3,
686            EFunType::Fourier,
687            EValType::Exponential,
688            Some(seed),
689        )
690    }
691
692    // ── cumulative_trapz ──
693
694    #[test]
695    fn test_cumulative_trapz_constant() {
696        // ∫₀ᵗ 1 dt = t
697        let x = uniform_grid(50);
698        let y = vec![1.0; 50];
699        let result = cumulative_trapz(&y, &x);
700        assert!((result[0]).abs() < 1e-15, "cumulative_trapz(0) should be 0");
701        for j in 1..50 {
702            assert!(
703                (result[j] - x[j]).abs() < 1e-12,
704                "∫₀^{:.3} 1 dt should be {:.3}, got {:.3}",
705                x[j],
706                x[j],
707                result[j]
708            );
709        }
710    }
711
712    #[test]
713    fn test_cumulative_trapz_linear() {
714        // ∫₀ᵗ s ds = t²/2
715        let m = 100;
716        let x = uniform_grid(m);
717        let y: Vec<f64> = x.clone();
718        let result = cumulative_trapz(&y, &x);
719        for j in 1..m {
720            let expected = x[j] * x[j] / 2.0;
721            assert!(
722                (result[j] - expected).abs() < 1e-4,
723                "∫₀^{:.3} s ds: expected {expected:.6}, got {:.6}",
724                x[j],
725                result[j]
726            );
727        }
728    }
729
730    // ── normalize_warp ──
731
732    #[test]
733    fn test_normalize_warp_fixes_boundaries() {
734        let t = uniform_grid(10);
735        let mut gamma = vec![0.1; 10]; // constant, wrong boundaries
736        normalize_warp(&mut gamma, &t);
737        assert_eq!(gamma[0], t[0]);
738        assert_eq!(gamma[9], t[9]);
739    }
740
741    #[test]
742    fn test_normalize_warp_enforces_monotonicity() {
743        let t = uniform_grid(5);
744        let mut gamma = vec![0.0, 0.5, 0.3, 0.8, 1.0]; // non-monotone at index 2
745        normalize_warp(&mut gamma, &t);
746        for j in 1..5 {
747            assert!(
748                gamma[j] >= gamma[j - 1],
749                "gamma should be monotone after normalization at j={j}"
750            );
751        }
752    }
753
754    #[test]
755    fn test_normalize_warp_identity_unchanged() {
756        let t = uniform_grid(20);
757        let mut gamma = t.clone();
758        normalize_warp(&mut gamma, &t);
759        for j in 0..20 {
760            assert!(
761                (gamma[j] - t[j]).abs() < 1e-15,
762                "Identity warp should be unchanged"
763            );
764        }
765    }
766
767    // ── linear_interp ──
768
769    #[test]
770    fn test_linear_interp_at_nodes() {
771        let x = vec![0.0, 1.0, 2.0, 3.0];
772        let y = vec![0.0, 2.0, 4.0, 6.0];
773        for i in 0..x.len() {
774            assert!((linear_interp(&x, &y, x[i]) - y[i]).abs() < 1e-12);
775        }
776    }
777
778    #[test]
779    fn test_linear_interp_midpoints() {
780        let x = vec![0.0, 1.0, 2.0];
781        let y = vec![0.0, 2.0, 4.0];
782        assert!((linear_interp(&x, &y, 0.5) - 1.0).abs() < 1e-12);
783        assert!((linear_interp(&x, &y, 1.5) - 3.0).abs() < 1e-12);
784    }
785
786    #[test]
787    fn test_linear_interp_clamp() {
788        let x = vec![0.0, 1.0, 2.0];
789        let y = vec![1.0, 3.0, 5.0];
790        assert!((linear_interp(&x, &y, -1.0) - 1.0).abs() < 1e-12);
791        assert!((linear_interp(&x, &y, 3.0) - 5.0).abs() < 1e-12);
792    }
793
794    #[test]
795    fn test_linear_interp_nonuniform_grid() {
796        let x = vec![0.0, 0.1, 0.5, 1.0];
797        let y = vec![0.0, 1.0, 5.0, 10.0];
798        // Between 0.1 and 0.5: slope = (5-1)/(0.5-0.1) = 10
799        let val = linear_interp(&x, &y, 0.3);
800        let expected = 1.0 + 10.0 * (0.3 - 0.1);
801        assert!(
802            (val - expected).abs() < 1e-12,
803            "Non-uniform interp: expected {expected}, got {val}"
804        );
805    }
806
807    #[test]
808    fn test_linear_interp_two_points() {
809        let x = vec![0.0, 1.0];
810        let y = vec![3.0, 7.0];
811        assert!((linear_interp(&x, &y, 0.25) - 4.0).abs() < 1e-12);
812        assert!((linear_interp(&x, &y, 0.75) - 6.0).abs() < 1e-12);
813    }
814
815    // ── SRSF transform ──
816
817    #[test]
818    fn test_srsf_transform_linear() {
819        // f(t) = 2t: derivative = 2, SRSF = sqrt(2)
820        let m = 50;
821        let t = uniform_grid(m);
822        let f: Vec<f64> = t.iter().map(|&ti| 2.0 * ti).collect();
823        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
824
825        let q_mat = srsf_transform(&mat, &t);
826        let q: Vec<f64> = q_mat.row(0);
827
828        let expected = 2.0_f64.sqrt();
829        // Interior points should be close to sqrt(2)
830        for j in 2..(m - 2) {
831            assert!(
832                (q[j] - expected).abs() < 0.1,
833                "q[{j}] = {}, expected ~{expected}",
834                q[j]
835            );
836        }
837    }
838
839    #[test]
840    fn test_srsf_transform_preserves_shape() {
841        let data = make_test_data(10, 50, 42);
842        let t = uniform_grid(50);
843        let q = srsf_transform(&data, &t);
844        assert_eq!(q.shape(), data.shape());
845    }
846
847    #[test]
848    fn test_srsf_transform_constant_is_zero() {
849        // f(t) = 5 (constant): derivative = 0, SRSF = 0
850        let m = 30;
851        let t = uniform_grid(m);
852        let f = vec![5.0; m];
853        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
854        let q_mat = srsf_transform(&mat, &t);
855        let q: Vec<f64> = q_mat.row(0);
856
857        for j in 0..m {
858            assert!(
859                q[j].abs() < 1e-10,
860                "SRSF of constant should be 0, got q[{j}] = {}",
861                q[j]
862            );
863        }
864    }
865
866    #[test]
867    fn test_srsf_transform_negative_slope() {
868        // f(t) = -3t: derivative = -3, SRSF = -sqrt(3)
869        let m = 50;
870        let t = uniform_grid(m);
871        let f: Vec<f64> = t.iter().map(|&ti| -3.0 * ti).collect();
872        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
873
874        let q_mat = srsf_transform(&mat, &t);
875        let q: Vec<f64> = q_mat.row(0);
876
877        let expected = -(3.0_f64.sqrt());
878        for j in 2..(m - 2) {
879            assert!(
880                (q[j] - expected).abs() < 0.15,
881                "q[{j}] = {}, expected ~{expected}",
882                q[j]
883            );
884        }
885    }
886
887    #[test]
888    fn test_srsf_transform_empty_input() {
889        let data = FdMatrix::zeros(0, 0);
890        let t: Vec<f64> = vec![];
891        let q = srsf_transform(&data, &t);
892        assert_eq!(q.shape(), (0, 0));
893    }
894
895    #[test]
896    fn test_srsf_transform_multiple_curves() {
897        let m = 40;
898        let t = uniform_grid(m);
899        let data = make_test_data(5, m, 42);
900
901        let q = srsf_transform(&data, &t);
902        assert_eq!(q.shape(), (5, m));
903
904        // Each row should have finite values
905        for i in 0..5 {
906            for j in 0..m {
907                assert!(q[(i, j)].is_finite(), "SRSF should be finite at ({i},{j})");
908            }
909        }
910    }
911
912    // ── SRSF inverse ──
913
914    #[test]
915    fn test_srsf_round_trip() {
916        let m = 100;
917        let t = uniform_grid(m);
918        // Use a smooth function
919        let f: Vec<f64> = t
920            .iter()
921            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin() + ti)
922            .collect();
923
924        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
925        let q_mat = srsf_transform(&mat, &t);
926        let q: Vec<f64> = q_mat.row(0);
927
928        let f_recon = srsf_inverse(&q, &t, f[0]);
929
930        // Check reconstruction is close (interior points, avoid boundary effects)
931        let max_err: f64 = f[5..(m - 5)]
932            .iter()
933            .zip(f_recon[5..(m - 5)].iter())
934            .map(|(a, b)| (a - b).abs())
935            .fold(0.0_f64, f64::max);
936
937        assert!(
938            max_err < 0.15,
939            "Round-trip error too large: max_err = {max_err}"
940        );
941    }
942
943    #[test]
944    fn test_srsf_inverse_empty() {
945        let q: Vec<f64> = vec![];
946        let t: Vec<f64> = vec![];
947        let result = srsf_inverse(&q, &t, 0.0);
948        assert!(result.is_empty());
949    }
950
951    #[test]
952    fn test_srsf_inverse_preserves_initial_value() {
953        let m = 50;
954        let t = uniform_grid(m);
955        let q = vec![1.0; m]; // constant SRSF
956        let f0 = 3.15;
957        let f = srsf_inverse(&q, &t, f0);
958        assert!((f[0] - f0).abs() < 1e-12, "srsf_inverse should start at f0");
959    }
960
961    #[test]
962    fn test_srsf_round_trip_multiple_curves() {
963        let m = 80;
964        let t = uniform_grid(m);
965        let data = make_test_data(5, m, 99);
966
967        let q_mat = srsf_transform(&data, &t);
968
969        for i in 0..5 {
970            let fi = data.row(i);
971            let qi = q_mat.row(i);
972            let f_recon = srsf_inverse(&qi, &t, fi[0]);
973            let max_err: f64 = fi[5..(m - 5)]
974                .iter()
975                .zip(f_recon[5..(m - 5)].iter())
976                .map(|(a, b)| (a - b).abs())
977                .fold(0.0_f64, f64::max);
978            assert!(max_err < 0.3, "Round-trip curve {i}: max_err = {max_err}");
979        }
980    }
981
982    // ── Reparameterize ──
983
984    #[test]
985    fn test_reparameterize_identity_warp() {
986        let m = 50;
987        let t = uniform_grid(m);
988        let f: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
989
990        // Identity warp: γ(t) = t
991        let result = reparameterize_curve(&f, &t, &t);
992        for j in 0..m {
993            assert!(
994                (result[j] - f[j]).abs() < 1e-12,
995                "Identity warp should return original at j={j}"
996            );
997        }
998    }
999
1000    #[test]
1001    fn test_reparameterize_linear_warp() {
1002        let m = 50;
1003        let t = uniform_grid(m);
1004        // f(t) = t (linear), γ(t) = t^2 (quadratic warp on [0,1])
1005        let f: Vec<f64> = t.clone();
1006        let gamma: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
1007
1008        let result = reparameterize_curve(&f, &t, &gamma);
1009
1010        // f(γ(t)) = γ(t) = t^2 for a linear f(t) = t
1011        for j in 0..m {
1012            assert!(
1013                (result[j] - gamma[j]).abs() < 1e-10,
1014                "f(gamma(t)) should be gamma(t) for f(t)=t at j={j}"
1015            );
1016        }
1017    }
1018
1019    #[test]
1020    fn test_reparameterize_sine_with_quadratic_warp() {
1021        let m = 100;
1022        let t = uniform_grid(m);
1023        let f: Vec<f64> = t
1024            .iter()
1025            .map(|&ti| (std::f64::consts::PI * ti).sin())
1026            .collect();
1027        let gamma: Vec<f64> = t.iter().map(|&ti| ti * ti).collect(); // speeds up start
1028
1029        let result = reparameterize_curve(&f, &t, &gamma);
1030
1031        // f(γ(t)) = sin(π * t²); check a few known values
1032        for j in 0..m {
1033            let expected = (std::f64::consts::PI * gamma[j]).sin();
1034            assert!(
1035                (result[j] - expected).abs() < 0.05,
1036                "sin(π γ(t)) at j={j}: expected {expected:.4}, got {:.4}",
1037                result[j]
1038            );
1039        }
1040    }
1041
1042    #[test]
1043    fn test_reparameterize_preserves_length() {
1044        let m = 50;
1045        let t = uniform_grid(m);
1046        let f = vec![1.0; m];
1047        let gamma: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
1048
1049        let result = reparameterize_curve(&f, &t, &gamma);
1050        assert_eq!(result.len(), m);
1051    }
1052
1053    // ── Compose warps ──
1054
1055    #[test]
1056    fn test_compose_warps_identity() {
1057        let m = 50;
1058        let t = uniform_grid(m);
1059        // γ(t) = t^0.5 (a warp on [0,1])
1060        let gamma: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
1061
1062        // identity ∘ γ = γ
1063        let result = compose_warps(&t, &gamma, &t);
1064        for j in 0..m {
1065            assert!(
1066                (result[j] - gamma[j]).abs() < 1e-10,
1067                "id ∘ γ should be γ at j={j}"
1068            );
1069        }
1070
1071        // γ ∘ identity = γ
1072        let result2 = compose_warps(&gamma, &t, &t);
1073        for j in 0..m {
1074            assert!(
1075                (result2[j] - gamma[j]).abs() < 1e-10,
1076                "γ ∘ id should be γ at j={j}"
1077            );
1078        }
1079    }
1080
1081    #[test]
1082    fn test_compose_warps_associativity() {
1083        // (γ₁ ∘ γ₂) ∘ γ₃ ≈ γ₁ ∘ (γ₂ ∘ γ₃)
1084        let m = 50;
1085        let t = uniform_grid(m);
1086        let g1: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
1087        let g2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
1088        let g3: Vec<f64> = t.iter().map(|&ti| 0.5 * ti + 0.5 * ti * ti).collect();
1089
1090        let g12 = compose_warps(&g1, &g2, &t);
1091        let left = compose_warps(&g12, &g3, &t); // (g1∘g2) ∘ g3
1092
1093        let g23 = compose_warps(&g2, &g3, &t);
1094        let right = compose_warps(&g1, &g23, &t); // g1 ∘ (g2∘g3)
1095
1096        for j in 0..m {
1097            assert!(
1098                (left[j] - right[j]).abs() < 0.05,
1099                "Composition should be roughly associative at j={j}: left={:.4}, right={:.4}",
1100                left[j],
1101                right[j]
1102            );
1103        }
1104    }
1105
1106    #[test]
1107    fn test_compose_warps_preserves_domain() {
1108        let m = 50;
1109        let t = uniform_grid(m);
1110        let g1: Vec<f64> = t.iter().map(|&ti| ti.sqrt()).collect();
1111        let g2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
1112
1113        let composed = compose_warps(&g1, &g2, &t);
1114        assert!(
1115            (composed[0] - t[0]).abs() < 1e-10,
1116            "Composed warp should start at domain start"
1117        );
1118        assert!(
1119            (composed[m - 1] - t[m - 1]).abs() < 1e-10,
1120            "Composed warp should end at domain end"
1121        );
1122    }
1123
1124    // ── Elastic align pair ──
1125
1126    #[test]
1127    fn test_align_identical_curves() {
1128        let m = 50;
1129        let t = uniform_grid(m);
1130        let f: Vec<f64> = t
1131            .iter()
1132            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
1133            .collect();
1134
1135        let result = elastic_align_pair(&f, &f, &t);
1136
1137        // Distance should be near zero
1138        assert!(
1139            result.distance < 0.1,
1140            "Distance between identical curves should be near 0, got {}",
1141            result.distance
1142        );
1143
1144        // Warp should be near identity
1145        for j in 0..m {
1146            assert!(
1147                (result.gamma[j] - t[j]).abs() < 0.1,
1148                "Warp should be near identity at j={j}: gamma={}, t={}",
1149                result.gamma[j],
1150                t[j]
1151            );
1152        }
1153    }
1154
1155    #[test]
1156    fn test_align_pair_valid_output() {
1157        let data = make_test_data(2, 50, 42);
1158        let t = uniform_grid(50);
1159        let f1 = data.row(0);
1160        let f2 = data.row(1);
1161
1162        let result = elastic_align_pair(&f1, &f2, &t);
1163
1164        assert_eq!(result.gamma.len(), 50);
1165        assert_eq!(result.f_aligned.len(), 50);
1166        assert!(result.distance >= 0.0);
1167
1168        // Warp should be monotone
1169        for j in 1..50 {
1170            assert!(
1171                result.gamma[j] >= result.gamma[j - 1],
1172                "Warp should be monotone at j={j}"
1173            );
1174        }
1175    }
1176
1177    #[test]
1178    fn test_align_pair_warp_boundaries() {
1179        let data = make_test_data(2, 50, 42);
1180        let t = uniform_grid(50);
1181        let f1 = data.row(0);
1182        let f2 = data.row(1);
1183
1184        let result = elastic_align_pair(&f1, &f2, &t);
1185        assert!(
1186            (result.gamma[0] - t[0]).abs() < 1e-12,
1187            "Warp should start at domain start"
1188        );
1189        assert!(
1190            (result.gamma[49] - t[49]).abs() < 1e-12,
1191            "Warp should end at domain end"
1192        );
1193    }
1194
1195    #[test]
1196    fn test_align_shifted_sine() {
1197        // Two sines with a phase shift — alignment should reduce distance
1198        let m = 80;
1199        let t = uniform_grid(m);
1200        let f1: Vec<f64> = t
1201            .iter()
1202            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
1203            .collect();
1204        let f2: Vec<f64> = t
1205            .iter()
1206            .map(|&ti| (2.0 * std::f64::consts::PI * (ti - 0.1)).sin())
1207            .collect();
1208
1209        let weights = simpsons_weights(&t);
1210        let l2_before = l2_distance(&f1, &f2, &weights);
1211        let result = elastic_align_pair(&f1, &f2, &t);
1212        let l2_after = l2_distance(&f1, &result.f_aligned, &weights);
1213
1214        assert!(
1215            l2_after < l2_before + 0.01,
1216            "Alignment should not increase L2 distance: before={l2_before:.4}, after={l2_after:.4}"
1217        );
1218    }
1219
1220    #[test]
1221    fn test_align_pair_aligned_curve_is_finite() {
1222        let data = make_test_data(2, 50, 77);
1223        let t = uniform_grid(50);
1224        let f1 = data.row(0);
1225        let f2 = data.row(1);
1226
1227        let result = elastic_align_pair(&f1, &f2, &t);
1228        for j in 0..50 {
1229            assert!(
1230                result.f_aligned[j].is_finite(),
1231                "Aligned curve should be finite at j={j}"
1232            );
1233        }
1234    }
1235
1236    #[test]
1237    fn test_align_pair_minimum_grid() {
1238        // Minimum viable grid: m = 2
1239        let t = vec![0.0, 1.0];
1240        let f1 = vec![0.0, 1.0];
1241        let f2 = vec![0.0, 2.0];
1242        let result = elastic_align_pair(&f1, &f2, &t);
1243        assert_eq!(result.gamma.len(), 2);
1244        assert_eq!(result.f_aligned.len(), 2);
1245        assert!(result.distance >= 0.0);
1246    }
1247
1248    // ── Elastic distance ──
1249
1250    #[test]
1251    fn test_elastic_distance_symmetric() {
1252        let data = make_test_data(3, 50, 42);
1253        let t = uniform_grid(50);
1254        let f1 = data.row(0);
1255        let f2 = data.row(1);
1256
1257        let d12 = elastic_distance(&f1, &f2, &t);
1258        let d21 = elastic_distance(&f2, &f1, &t);
1259
1260        // Should be approximately symmetric (DP is not perfectly symmetric)
1261        assert!(
1262            (d12 - d21).abs() < d12.max(d21) * 0.3 + 0.01,
1263            "Elastic distance should be roughly symmetric: d12={d12}, d21={d21}"
1264        );
1265    }
1266
1267    #[test]
1268    fn test_elastic_distance_nonneg() {
1269        let data = make_test_data(3, 50, 42);
1270        let t = uniform_grid(50);
1271
1272        for i in 0..3 {
1273            for j in 0..3 {
1274                let fi = data.row(i);
1275                let fj = data.row(j);
1276                let d = elastic_distance(&fi, &fj, &t);
1277                assert!(d >= 0.0, "Elastic distance should be non-negative");
1278            }
1279        }
1280    }
1281
1282    #[test]
1283    fn test_elastic_distance_self_near_zero() {
1284        let data = make_test_data(3, 50, 42);
1285        let t = uniform_grid(50);
1286
1287        for i in 0..3 {
1288            let fi = data.row(i);
1289            let d = elastic_distance(&fi, &fi, &t);
1290            assert!(
1291                d < 0.1,
1292                "Self-distance should be near zero, got {d} for curve {i}"
1293            );
1294        }
1295    }
1296
1297    #[test]
1298    fn test_elastic_distance_triangle_inequality() {
1299        let data = make_test_data(3, 50, 42);
1300        let t = uniform_grid(50);
1301        let f0 = data.row(0);
1302        let f1 = data.row(1);
1303        let f2 = data.row(2);
1304
1305        let d01 = elastic_distance(&f0, &f1, &t);
1306        let d12 = elastic_distance(&f1, &f2, &t);
1307        let d02 = elastic_distance(&f0, &f2, &t);
1308
1309        // Relaxed triangle inequality (DP alignment is approximate)
1310        let slack = 0.5;
1311        assert!(
1312            d02 <= d01 + d12 + slack,
1313            "Triangle inequality (relaxed): d02={d02:.4} > d01={d01:.4} + d12={d12:.4} + {slack}"
1314        );
1315    }
1316
1317    #[test]
1318    fn test_elastic_distance_different_shapes_nonzero() {
1319        let m = 50;
1320        let t = uniform_grid(m);
1321        let f1: Vec<f64> = t.to_vec(); // linear
1322        let f2: Vec<f64> = t.iter().map(|&ti| ti * ti).collect(); // quadratic
1323
1324        let d = elastic_distance(&f1, &f2, &t);
1325        assert!(
1326            d > 0.01,
1327            "Distance between different shapes should be > 0, got {d}"
1328        );
1329    }
1330
1331    // ── Self distance matrix ──
1332
1333    #[test]
1334    fn test_self_distance_matrix_symmetric() {
1335        let data = make_test_data(5, 30, 42);
1336        let t = uniform_grid(30);
1337
1338        let dm = elastic_self_distance_matrix(&data, &t);
1339        let n = dm.nrows();
1340
1341        assert_eq!(dm.shape(), (5, 5));
1342
1343        // Zero diagonal
1344        for i in 0..n {
1345            assert!(dm[(i, i)].abs() < 1e-12, "Diagonal should be zero");
1346        }
1347
1348        // Symmetric
1349        for i in 0..n {
1350            for j in (i + 1)..n {
1351                assert!(
1352                    (dm[(i, j)] - dm[(j, i)]).abs() < 1e-12,
1353                    "Matrix should be symmetric at ({i},{j})"
1354                );
1355            }
1356        }
1357    }
1358
1359    #[test]
1360    fn test_self_distance_matrix_nonneg() {
1361        let data = make_test_data(4, 30, 42);
1362        let t = uniform_grid(30);
1363        let dm = elastic_self_distance_matrix(&data, &t);
1364
1365        for i in 0..4 {
1366            for j in 0..4 {
1367                assert!(
1368                    dm[(i, j)] >= 0.0,
1369                    "Distance matrix entries should be non-negative at ({i},{j})"
1370                );
1371            }
1372        }
1373    }
1374
1375    #[test]
1376    fn test_self_distance_matrix_single_curve() {
1377        let data = make_test_data(1, 30, 42);
1378        let t = uniform_grid(30);
1379        let dm = elastic_self_distance_matrix(&data, &t);
1380        assert_eq!(dm.shape(), (1, 1));
1381        assert!(dm[(0, 0)].abs() < 1e-12);
1382    }
1383
1384    #[test]
1385    fn test_self_distance_matrix_consistent_with_pairwise() {
1386        let data = make_test_data(4, 30, 42);
1387        let t = uniform_grid(30);
1388
1389        let dm = elastic_self_distance_matrix(&data, &t);
1390
1391        // Check a few entries match direct elastic_distance calls
1392        for i in 0..4 {
1393            for j in (i + 1)..4 {
1394                let fi = data.row(i);
1395                let fj = data.row(j);
1396                let d_direct = elastic_distance(&fi, &fj, &t);
1397                assert!(
1398                    (dm[(i, j)] - d_direct).abs() < 1e-10,
1399                    "Matrix entry ({i},{j})={:.6} should match pairwise {d_direct:.6}",
1400                    dm[(i, j)]
1401                );
1402            }
1403        }
1404    }
1405
1406    // ── Karcher mean ──
1407
1408    #[test]
1409    fn test_karcher_mean_identical_curves() {
1410        let m = 50;
1411        let t = uniform_grid(m);
1412        let f: Vec<f64> = t
1413            .iter()
1414            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
1415            .collect();
1416
1417        // Create 5 identical curves
1418        let mut data = FdMatrix::zeros(5, m);
1419        for i in 0..5 {
1420            for j in 0..m {
1421                data[(i, j)] = f[j];
1422            }
1423        }
1424
1425        let result = karcher_mean(&data, &t, 10, 1e-4);
1426
1427        assert_eq!(result.mean.len(), m);
1428        assert!(result.n_iter <= 10);
1429    }
1430
1431    #[test]
1432    fn test_karcher_mean_output_shape() {
1433        let data = make_test_data(15, 50, 42);
1434        let t = uniform_grid(50);
1435
1436        let result = karcher_mean(&data, &t, 5, 1e-3);
1437
1438        assert_eq!(result.mean.len(), 50);
1439        assert_eq!(result.mean_srsf.len(), 50);
1440        assert_eq!(result.gammas.shape(), (15, 50));
1441        assert_eq!(result.aligned_data.shape(), (15, 50));
1442        assert!(result.n_iter <= 5);
1443    }
1444
1445    #[test]
1446    fn test_karcher_mean_warps_are_valid() {
1447        let data = make_test_data(10, 40, 42);
1448        let t = uniform_grid(40);
1449
1450        let result = karcher_mean(&data, &t, 5, 1e-3);
1451
1452        for i in 0..10 {
1453            // Boundary values
1454            assert!(
1455                (result.gammas[(i, 0)] - t[0]).abs() < 1e-10,
1456                "Warp {i} should start at domain start"
1457            );
1458            assert!(
1459                (result.gammas[(i, 39)] - t[39]).abs() < 1e-10,
1460                "Warp {i} should end at domain end"
1461            );
1462            // Monotonicity
1463            for j in 1..40 {
1464                assert!(
1465                    result.gammas[(i, j)] >= result.gammas[(i, j - 1)],
1466                    "Warp {i} should be monotone at j={j}"
1467                );
1468            }
1469        }
1470    }
1471
1472    #[test]
1473    fn test_karcher_mean_aligned_data_is_finite() {
1474        let data = make_test_data(8, 40, 42);
1475        let t = uniform_grid(40);
1476        let result = karcher_mean(&data, &t, 5, 1e-3);
1477
1478        for i in 0..8 {
1479            for j in 0..40 {
1480                assert!(
1481                    result.aligned_data[(i, j)].is_finite(),
1482                    "Aligned data should be finite at ({i},{j})"
1483                );
1484            }
1485        }
1486    }
1487
1488    #[test]
1489    fn test_karcher_mean_srsf_is_finite() {
1490        let data = make_test_data(8, 40, 42);
1491        let t = uniform_grid(40);
1492        let result = karcher_mean(&data, &t, 5, 1e-3);
1493
1494        for j in 0..40 {
1495            assert!(
1496                result.mean_srsf[j].is_finite(),
1497                "Mean SRSF should be finite at j={j}"
1498            );
1499            assert!(
1500                result.mean[j].is_finite(),
1501                "Mean curve should be finite at j={j}"
1502            );
1503        }
1504    }
1505
1506    #[test]
1507    fn test_karcher_mean_single_iteration() {
1508        let data = make_test_data(10, 40, 42);
1509        let t = uniform_grid(40);
1510        let result = karcher_mean(&data, &t, 1, 1e-10);
1511
1512        assert_eq!(result.n_iter, 1);
1513        assert_eq!(result.mean.len(), 40);
1514        // With only 1 iteration, still produces valid output
1515        for j in 0..40 {
1516            assert!(result.mean[j].is_finite());
1517        }
1518    }
1519
1520    // ── Align to target ──
1521
1522    #[test]
1523    fn test_align_to_target_valid() {
1524        let data = make_test_data(10, 40, 42);
1525        let t = uniform_grid(40);
1526        let target = data.row(0);
1527
1528        let result = align_to_target(&data, &target, &t);
1529
1530        assert_eq!(result.gammas.shape(), (10, 40));
1531        assert_eq!(result.aligned_data.shape(), (10, 40));
1532        assert_eq!(result.distances.len(), 10);
1533
1534        // All distances should be non-negative
1535        for &d in &result.distances {
1536            assert!(d >= 0.0);
1537        }
1538    }
1539
1540    #[test]
1541    fn test_align_to_target_self_near_zero() {
1542        let data = make_test_data(5, 40, 42);
1543        let t = uniform_grid(40);
1544        let target = data.row(0);
1545
1546        let result = align_to_target(&data, &target, &t);
1547
1548        // Distance of target to itself should be near zero
1549        assert!(
1550            result.distances[0] < 0.1,
1551            "Self-alignment distance should be near zero, got {}",
1552            result.distances[0]
1553        );
1554    }
1555
1556    #[test]
1557    fn test_align_to_target_warps_are_monotone() {
1558        let data = make_test_data(8, 40, 42);
1559        let t = uniform_grid(40);
1560        let target = data.row(0);
1561        let result = align_to_target(&data, &target, &t);
1562
1563        for i in 0..8 {
1564            for j in 1..40 {
1565                assert!(
1566                    result.gammas[(i, j)] >= result.gammas[(i, j - 1)],
1567                    "Warp for curve {i} should be monotone at j={j}"
1568                );
1569            }
1570        }
1571    }
1572
1573    #[test]
1574    fn test_align_to_target_aligned_data_finite() {
1575        let data = make_test_data(6, 40, 42);
1576        let t = uniform_grid(40);
1577        let target = data.row(0);
1578        let result = align_to_target(&data, &target, &t);
1579
1580        for i in 0..6 {
1581            for j in 0..40 {
1582                assert!(
1583                    result.aligned_data[(i, j)].is_finite(),
1584                    "Aligned data should be finite at ({i},{j})"
1585                );
1586            }
1587        }
1588    }
1589
1590    // ── Cross distance matrix ──
1591
1592    #[test]
1593    fn test_cross_distance_matrix_shape() {
1594        let data1 = make_test_data(3, 30, 42);
1595        let data2 = make_test_data(4, 30, 99);
1596        let t = uniform_grid(30);
1597
1598        let dm = elastic_cross_distance_matrix(&data1, &data2, &t);
1599        assert_eq!(dm.shape(), (3, 4));
1600
1601        // All non-negative
1602        for i in 0..3 {
1603            for j in 0..4 {
1604                assert!(dm[(i, j)] >= 0.0);
1605            }
1606        }
1607    }
1608
1609    #[test]
1610    fn test_cross_distance_matrix_self_matches_self_matrix() {
1611        // cross_distance(data, data) should have zero diagonal (approximately)
1612        let data = make_test_data(4, 30, 42);
1613        let t = uniform_grid(30);
1614
1615        let cross = elastic_cross_distance_matrix(&data, &data, &t);
1616        for i in 0..4 {
1617            assert!(
1618                cross[(i, i)] < 0.1,
1619                "Cross distance (self) diagonal should be near zero: got {}",
1620                cross[(i, i)]
1621            );
1622        }
1623    }
1624
1625    #[test]
1626    fn test_cross_distance_matrix_consistent_with_pairwise() {
1627        let data1 = make_test_data(3, 30, 42);
1628        let data2 = make_test_data(2, 30, 99);
1629        let t = uniform_grid(30);
1630
1631        let dm = elastic_cross_distance_matrix(&data1, &data2, &t);
1632
1633        for i in 0..3 {
1634            for j in 0..2 {
1635                let fi = data1.row(i);
1636                let fj = data2.row(j);
1637                let d_direct = elastic_distance(&fi, &fj, &t);
1638                assert!(
1639                    (dm[(i, j)] - d_direct).abs() < 1e-10,
1640                    "Cross matrix ({i},{j})={:.6} should match pairwise {d_direct:.6}",
1641                    dm[(i, j)]
1642                );
1643            }
1644        }
1645    }
1646
1647    // ── align_srsf_pair ──
1648
1649    #[test]
1650    fn test_align_srsf_pair_identity() {
1651        let m = 50;
1652        let t = uniform_grid(m);
1653        let f: Vec<f64> = t
1654            .iter()
1655            .map(|&ti| (2.0 * std::f64::consts::PI * ti).sin())
1656            .collect();
1657        let q = srsf_single(&f, &t);
1658
1659        let (gamma, q_aligned) = align_srsf_pair(&q, &q, &t);
1660
1661        // Warp should be near identity
1662        for j in 0..m {
1663            assert!(
1664                (gamma[j] - t[j]).abs() < 0.15,
1665                "Self-SRSF alignment warp should be near identity at j={j}"
1666            );
1667        }
1668
1669        // Aligned SRSF should be close to original
1670        let weights = simpsons_weights(&t);
1671        let dist = l2_distance(&q, &q_aligned, &weights);
1672        assert!(
1673            dist < 0.5,
1674            "Self-aligned SRSF distance should be small, got {dist}"
1675        );
1676    }
1677
1678    // ── srsf_single ──
1679
1680    #[test]
1681    fn test_srsf_single_matches_matrix_version() {
1682        let m = 50;
1683        let t = uniform_grid(m);
1684        let f: Vec<f64> = t.iter().map(|&ti| ti * ti + ti).collect();
1685
1686        let q_single = srsf_single(&f, &t);
1687
1688        let mat = FdMatrix::from_slice(&f, 1, m).unwrap();
1689        let q_mat = srsf_transform(&mat, &t);
1690        let q_from_mat = q_mat.row(0);
1691
1692        for j in 0..m {
1693            assert!(
1694                (q_single[j] - q_from_mat[j]).abs() < 1e-12,
1695                "srsf_single should match srsf_transform at j={j}"
1696            );
1697        }
1698    }
1699
1700    // ── dp_traceback ──
1701
1702    #[test]
1703    fn test_dp_traceback_all_diagonal() {
1704        // Trace matrix with all diagonal (0) moves
1705        let m = 5;
1706        let trace = vec![0u8; m * m];
1707        let path = dp_traceback(&trace, m);
1708        assert_eq!(path.first(), Some(&(0, 0)));
1709        assert_eq!(path.last(), Some(&(m - 1, m - 1)));
1710        assert_eq!(path.len(), m);
1711    }
1712
1713    // ── Edge case: constant data ──
1714
1715    #[test]
1716    fn test_alignment_constant_curves() {
1717        let m = 30;
1718        let t = uniform_grid(m);
1719        let f1 = vec![5.0; m];
1720        let f2 = vec![5.0; m];
1721
1722        let result = elastic_align_pair(&f1, &f2, &t);
1723        assert!(
1724            result.distance < 0.01,
1725            "Constant curves: distance should be ~0"
1726        );
1727        assert_eq!(result.f_aligned.len(), m);
1728    }
1729
1730    #[test]
1731    fn test_karcher_mean_constant_curves() {
1732        let m = 30;
1733        let t = uniform_grid(m);
1734        let mut data = FdMatrix::zeros(5, m);
1735        for i in 0..5 {
1736            for j in 0..m {
1737                data[(i, j)] = 3.0;
1738            }
1739        }
1740
1741        let result = karcher_mean(&data, &t, 5, 1e-4);
1742        for j in 0..m {
1743            assert!(
1744                (result.mean[j] - 3.0).abs() < 0.5,
1745                "Mean of constant curves should be near 3.0, got {} at j={j}",
1746                result.mean[j]
1747            );
1748        }
1749    }
1750}