Skip to main content

fdars_core/alignment/
mod.rs

1//! Elastic alignment and SRSF (Square-Root Slope Function) transforms.
2//!
3//! This module provides phase-amplitude separation for functional data via
4//! the elastic framework. Key capabilities:
5//!
6//! - [`srsf_transform`] / [`srsf_inverse`] — SRSF representation and reconstruction
7//! - [`elastic_align_pair`] — Pairwise curve alignment via dynamic programming
8//! - [`elastic_distance`] — Elastic (Fisher-Rao) distance between curves
9//! - [`align_to_target`] — Align a set of curves to a common target
10//! - [`karcher_mean`] — Karcher (Fréchet) mean in the elastic metric
11//! - [`elastic_self_distance_matrix`] / [`elastic_cross_distance_matrix`] — Distance matrices
12//! - [`reparameterize_curve`] / [`compose_warps`] — Warping utilities
13
14mod bayesian;
15mod closed;
16mod clustering;
17mod constrained;
18mod diagnostics;
19mod elastic_depth;
20mod fpns;
21mod generative;
22mod geodesic;
23mod karcher;
24mod lambda_cv;
25mod multires;
26mod nd;
27mod outlier;
28mod pairwise;
29mod partial_match;
30mod persistence;
31mod phase_boxplot;
32mod quality;
33mod robust_karcher;
34mod set;
35mod shape;
36mod shape_ci;
37mod srsf;
38mod transfer;
39mod tsrvf;
40mod warp_stats;
41
42#[cfg(test)]
43mod tests;
44
45// Re-export all public items so that `crate::alignment::X` continues to work.
46pub use bayesian::{bayesian_align_pair, BayesianAlignConfig, BayesianAlignmentResult};
47pub use closed::{
48    elastic_align_pair_closed, elastic_distance_closed, karcher_mean_closed, ClosedAlignmentResult,
49    ClosedKarcherMeanResult,
50};
51pub use clustering::{
52    cut_dendrogram, hierarchical_from_distances, kmedoids_from_distances, Dendrogram,
53    KMedoidsConfig, KMedoidsResult, Linkage,
54};
55pub use constrained::{
56    elastic_align_pair_constrained, elastic_align_pair_with_landmarks, ConstrainedAlignmentResult,
57};
58pub use diagnostics::{
59    diagnose_alignment, diagnose_pairwise, AlignmentDiagnostic, AlignmentDiagnosticSummary,
60    DiagnosticConfig,
61};
62pub use elastic_depth::{elastic_depth, ElasticDepthResult};
63pub use fpns::{horiz_fpns, FpnsResult};
64pub use generative::{gauss_model, joint_gauss_model, GenerativeModelResult};
65pub use geodesic::{curve_geodesic, curve_geodesic_nd, GeodesicPath, GeodesicPathNd};
66pub use karcher::karcher_mean;
67pub use lambda_cv::{lambda_cv, LambdaCvConfig, LambdaCvResult};
68pub use multires::{elastic_align_pair_multires, MultiresConfig};
69pub use nd::{
70    elastic_align_pair_nd, elastic_distance_nd, srsf_inverse_nd, srsf_transform_nd,
71    AlignmentResultNd,
72};
73pub use nd::{karcher_covariance_nd, karcher_mean_nd, pca_nd, KarcherMeanResultNd, PcaNdResult};
74pub use outlier::{elastic_outlier_detection, ElasticOutlierConfig, ElasticOutlierResult};
75pub use pairwise::{
76    amplitude_distance, amplitude_self_distance_matrix, elastic_align_pair,
77    elastic_align_pair_penalized, elastic_cross_distance_matrix, elastic_distance,
78    elastic_self_distance_matrix, phase_distance_pair, phase_self_distance_matrix, WarpPenaltyType,
79};
80pub use partial_match::{elastic_partial_match, PartialMatchConfig, PartialMatchResult};
81pub use persistence::{peak_persistence, PersistenceDiagramResult};
82pub use phase_boxplot::{phase_boxplot, PhaseBoxplot};
83pub use quality::{
84    alignment_quality, pairwise_consistency, warp_complexity, warp_smoothness, AlignmentQuality,
85};
86pub use robust_karcher::{
87    karcher_median, robust_karcher_mean, RobustKarcherConfig, RobustKarcherResult,
88};
89pub use set::{align_to_target, elastic_decomposition, DecompositionResult};
90pub use shape::{
91    orbit_representative, shape_distance, shape_mean, shape_self_distance_matrix,
92    OrbitRepresentative, ShapeDistanceResult, ShapeMeanResult, ShapeQuotient,
93};
94pub use shape_ci::{shape_confidence_interval, ShapeCiConfig, ShapeCiResult};
95pub use srsf::{
96    compose_warps, invert_warp, reparameterize_curve, srsf_inverse, srsf_transform,
97    warp_inverse_error,
98};
99pub use transfer::{transfer_alignment, TransferAlignConfig, TransferAlignResult};
100pub use tsrvf::{
101    tsrvf_from_alignment, tsrvf_from_alignment_with_method, tsrvf_inverse, tsrvf_transform,
102    tsrvf_transform_with_method, TransportMethod, TsrvfResult,
103};
104pub use warp_stats::{warp_statistics, WarpStatistics};
105
106// Re-export pub(crate) items so other crate modules can use them.
107pub(crate) use karcher::sqrt_mean_inverse;
108
109use crate::helpers::linear_interp;
110use crate::matrix::FdMatrix;
111use crate::warping::normalize_warp;
112
// ─── Types ──────────────────────────────────────────────────────────────────

/// Outcome of elastically aligning one curve against another.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct AlignmentResult {
    /// Warping function γ; it maps the domain onto itself.
    pub gamma: Vec<f64>,
    /// The curve after reparameterization by `gamma`.
    pub f_aligned: Vec<f64>,
    /// Elastic distance between the curves once aligned.
    pub distance: f64,
}
126
127/// Result of aligning a set of curves to a common target.
128#[derive(Debug, Clone, PartialEq)]
129#[non_exhaustive]
130pub struct AlignmentSetResult {
131    /// Warping functions (n × m).
132    pub gammas: FdMatrix,
133    /// Aligned curves (n × m).
134    pub aligned_data: FdMatrix,
135    /// Elastic distances for each curve.
136    pub distances: Vec<f64>,
137}
138
139/// Result of the Karcher mean computation.
140#[derive(Debug, Clone, PartialEq)]
141#[non_exhaustive]
142pub struct KarcherMeanResult {
143    /// Karcher mean curve.
144    pub mean: Vec<f64>,
145    /// SRSF of the Karcher mean.
146    pub mean_srsf: Vec<f64>,
147    /// Final warping functions (n × m).
148    pub gammas: FdMatrix,
149    /// Curves aligned to the mean (n × m).
150    pub aligned_data: FdMatrix,
151    /// Number of iterations used.
152    pub n_iter: usize,
153    /// Whether the algorithm converged.
154    pub converged: bool,
155    /// Pre-computed SRSFs of aligned curves (n × m), if available.
156    /// When set, FPCA functions use these instead of recomputing from `aligned_data`.
157    pub aligned_srsfs: Option<FdMatrix>,
158}
159
160impl KarcherMeanResult {
161    /// Create a new `KarcherMeanResult`.
162    pub fn new(
163        mean: Vec<f64>,
164        mean_srsf: Vec<f64>,
165        gammas: FdMatrix,
166        aligned_data: FdMatrix,
167        n_iter: usize,
168        converged: bool,
169        aligned_srsfs: Option<FdMatrix>,
170    ) -> Self {
171        Self {
172            mean,
173            mean_srsf,
174            gammas,
175            aligned_data,
176            n_iter,
177            converged,
178            aligned_srsfs,
179        }
180    }
181}
182
183// ─── Trait: AlignmentOutput ─────────────────────────────────────────────────
184
185/// Common interface for alignment results, enabling interchangeable
186/// alignment methods in downstream analysis (elastic FPCA, regression, etc.).
187pub trait AlignmentOutput {
188    /// The estimated mean/template curve (length m).
189    fn mean(&self) -> &[f64];
190    /// The mean SRSF (length m).
191    fn mean_srsf(&self) -> &[f64];
192    /// The aligned curves (n × m).
193    fn aligned_data(&self) -> &FdMatrix;
194    /// The warping functions (n × m).
195    fn gammas(&self) -> &FdMatrix;
196    /// Whether the algorithm converged.
197    fn converged(&self) -> bool;
198    /// Number of iterations performed.
199    fn n_iter(&self) -> usize;
200}
201
202impl AlignmentOutput for KarcherMeanResult {
203    fn mean(&self) -> &[f64] {
204        &self.mean
205    }
206    fn mean_srsf(&self) -> &[f64] {
207        &self.mean_srsf
208    }
209    fn aligned_data(&self) -> &FdMatrix {
210        &self.aligned_data
211    }
212    fn gammas(&self) -> &FdMatrix {
213        &self.gammas
214    }
215    fn converged(&self) -> bool {
216        self.converged
217    }
218    fn n_iter(&self) -> usize {
219        self.n_iter
220    }
221}
222
223impl AlignmentOutput for RobustKarcherResult {
224    fn mean(&self) -> &[f64] {
225        &self.mean
226    }
227    fn mean_srsf(&self) -> &[f64] {
228        &self.mean_srsf
229    }
230    fn aligned_data(&self) -> &FdMatrix {
231        &self.aligned_data
232    }
233    fn gammas(&self) -> &FdMatrix {
234        &self.gammas
235    }
236    fn converged(&self) -> bool {
237        self.converged
238    }
239    fn n_iter(&self) -> usize {
240        self.n_iter
241    }
242}
243
244// ─── Conversions ───────────────────────────────────────────────────────────
245
246impl From<RobustKarcherResult> for KarcherMeanResult {
247    fn from(r: RobustKarcherResult) -> Self {
248        Self {
249            mean: r.mean,
250            mean_srsf: r.mean_srsf,
251            gammas: r.gammas,
252            aligned_data: r.aligned_data,
253            n_iter: r.n_iter,
254            converged: r.converged,
255            aligned_srsfs: None,
256        }
257    }
258}
259
// ─── Dynamic Programming Alignment ──────────────────────────────────────────
// Faithful port of fdasrvf's DP algorithm (dp_grid.cpp / dp_nbhd.cpp).

/// Coprime move set for `nbhd_dim = 7` (the fdasrvf default), precomputed.
///
/// Holds every `(dr, dc)` with `1 ≤ dr, dc ≤ 7` and `gcd(dr, dc) = 1`, in lexicographic
/// order. `dr` is the row step (q2 direction) and `dc` the column step (q1 direction).
/// The order matters: `dp_relax_cell` only replaces a cost on a strict improvement, so
/// among equal-cost predecessors the earliest entry here wins.
#[rustfmt::skip]
const COPRIME_NBHD_7: [(usize, usize); 35] = [
    // dr = 1
    (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7),
    // dr = 2
    (2, 1), (2, 3), (2, 5), (2, 7),
    // dr = 3
    (3, 1), (3, 2), (3, 4), (3, 5), (3, 7),
    // dr = 4
    (4, 1), (4, 3), (4, 5), (4, 7),
    // dr = 5
    (5, 1), (5, 2), (5, 3), (5, 4), (5, 6), (5, 7),
    // dr = 6
    (6, 1), (6, 5), (6, 7),
    // dr = 7
    (7, 1), (7, 2), (7, 3), (7, 4), (7, 5), (7, 6),
];
276
277/// Compute the edge weight for a move from grid point (sr, sc) to (tr, tc).
278///
279/// Port of fdasrvf's `dp_edge_weight` for 1-D curves on a shared uniform grid.
280/// - Rows = q2 indices, columns = q1 indices (matching fdasrvf convention).
281/// - `slope = (argvals[tr] - argvals[sr]) / (argvals[tc] - argvals[sc])` = γ'
282/// - Walks through sub-intervals synchronized at both curves' breakpoints,
283///   accumulating `(q1[idx1] - √slope · q2[idx2])² · dt`.
284#[inline]
285pub(super) fn dp_edge_weight(
286    q1: &[f64],
287    q2: &[f64],
288    argvals: &[f64],
289    sc: usize,
290    tc: usize,
291    sr: usize,
292    tr: usize,
293) -> f64 {
294    let n1 = tc - sc;
295    let n2 = tr - sr;
296    if n1 == 0 || n2 == 0 {
297        return f64::INFINITY;
298    }
299
300    let slope = (argvals[tr] - argvals[sr]) / (argvals[tc] - argvals[sc]);
301    let rslope = slope.sqrt();
302
303    // Walk through sub-intervals synchronized at breakpoints of both curves
304    let mut weight = 0.0;
305    let mut i1 = 0usize; // sub-interval index in q1 direction
306    let mut i2 = 0usize; // sub-interval index in q2 direction
307
308    while i1 < n1 && i2 < n2 {
309        // Current sub-interval boundaries as fractions of the total span
310        let left1 = i1 as f64 / n1 as f64;
311        let right1 = (i1 + 1) as f64 / n1 as f64;
312        let left2 = i2 as f64 / n2 as f64;
313        let right2 = (i2 + 1) as f64 / n2 as f64;
314
315        let left = left1.max(left2);
316        let right = right1.min(right2);
317        let dt = right - left;
318
319        if dt > 0.0 {
320            let diff = q1[sc + i1] - rslope * q2[sr + i2];
321            weight += diff * diff * dt;
322        }
323
324        // Advance whichever sub-interval ends first
325        if right1 < right2 {
326            i1 += 1;
327        } else if right2 < right1 {
328            i2 += 1;
329        } else {
330            i1 += 1;
331            i2 += 1;
332        }
333    }
334
335    // Scale by the span in q1 direction
336    weight * (argvals[tc] - argvals[sc])
337}
338
339/// Compute the λ·(slope−1)²·dt penalty for a DP edge.
340#[inline]
341pub(super) fn dp_lambda_penalty(
342    argvals: &[f64],
343    sc: usize,
344    tc: usize,
345    sr: usize,
346    tr: usize,
347    lambda: f64,
348) -> f64 {
349    if lambda > 0.0 {
350        let dt = argvals[tc] - argvals[sc];
351        let slope = (argvals[tr] - argvals[sr]) / dt;
352        lambda * (slope - 1.0).powi(2) * dt
353    } else {
354        0.0
355    }
356}
357
/// Follow parent pointers from the bottom-right cell back toward the origin.
///
/// `parent` is a row-major (`ncols`-wide) array of flat predecessor indices, with
/// `u32::MAX` meaning "no predecessor". The walk stops at cell 0 or at the first cell
/// without a predecessor. The result lists `(row, col)` pairs in forward order. It starts
/// at `(0, 0)` when the end cell was reached from the origin; if the end cell was never
/// relaxed, the result is just `[(nrows - 1, ncols - 1)]`.
fn dp_traceback(parent: &[u32], nrows: usize, ncols: usize) -> Vec<(usize, usize)> {
    let mut node = (nrows - 1) * ncols + (ncols - 1);
    let mut trail = Vec::with_capacity(nrows + ncols);
    trail.push((node / ncols, node % ncols));
    while node != 0 && parent[node] != u32::MAX {
        node = parent[node] as usize;
        trail.push((node / ncols, node % ncols));
    }
    trail.reverse();
    trail
}
374
375/// Try to relax cell `(tr, tc)` from each coprime neighbor, updating cost and parent.
376#[inline]
377fn dp_relax_cell<F>(
378    e: &mut [f64],
379    parent: &mut [u32],
380    ncols: usize,
381    tr: usize,
382    tc: usize,
383    edge_cost: &F,
384) where
385    F: Fn(usize, usize, usize, usize) -> f64,
386{
387    let idx = tr * ncols + tc;
388    for &(dr, dc) in &COPRIME_NBHD_7 {
389        if dr > tr || dc > tc {
390            continue;
391        }
392        let sr = tr - dr;
393        let sc = tc - dc;
394        let src_idx = sr * ncols + sc;
395        if e[src_idx] == f64::INFINITY {
396            continue;
397        }
398        let cost = e[src_idx] + edge_cost(sr, sc, tr, tc);
399        if cost < e[idx] {
400            e[idx] = cost;
401            parent[idx] = src_idx as u32;
402        }
403    }
404}
405
406/// Shared DP grid fill + traceback using the coprime neighborhood.
407///
408/// `edge_cost(sr, sc, tr, tc)` returns the combined edge weight + penalty for
409/// a move from local (sr, sc) to local (tr, tc). Returns the raw local-index
410/// path from (0,0) to (nrows-1, ncols-1).
411pub(super) fn dp_grid_solve<F>(nrows: usize, ncols: usize, edge_cost: F) -> Vec<(usize, usize)>
412where
413    F: Fn(usize, usize, usize, usize) -> f64,
414{
415    let mut e = vec![f64::INFINITY; nrows * ncols];
416    let mut parent = vec![u32::MAX; nrows * ncols];
417    e[0] = 0.0;
418
419    for tr in 0..nrows {
420        for tc in 0..ncols {
421            if tr == 0 && tc == 0 {
422                continue;
423            }
424            dp_relax_cell(&mut e, &mut parent, ncols, tr, tc, &edge_cost);
425        }
426    }
427
428    dp_traceback(&parent, nrows, ncols)
429}
430
431/// Convert a DP path (local row,col indices) to an interpolated+normalized gamma warp.
432pub(super) fn dp_path_to_gamma(path: &[(usize, usize)], argvals: &[f64]) -> Vec<f64> {
433    let path_tc: Vec<f64> = path.iter().map(|&(_, c)| argvals[c]).collect();
434    let path_tr: Vec<f64> = path.iter().map(|&(r, _)| argvals[r]).collect();
435    let mut gamma: Vec<f64> = argvals
436        .iter()
437        .map(|&t| linear_interp(&path_tc, &path_tr, t))
438        .collect();
439    normalize_warp(&mut gamma, argvals);
440    gamma
441}
442
443/// Core DP alignment between two SRSFs on a grid.
444///
445/// Finds the optimal warping γ minimizing ‖q₁ - (q₂∘γ)√γ'‖².
446/// Uses fdasrvf's coprime neighborhood (nbhd_dim=7 → 35 move directions).
447/// SRSFs are L2-normalized before alignment (matching fdasrvf's `optimum.reparam`).
448pub(crate) fn dp_alignment_core(q1: &[f64], q2: &[f64], argvals: &[f64], lambda: f64) -> Vec<f64> {
449    let m = argvals.len();
450    if m < 2 {
451        return argvals.to_vec();
452    }
453
454    let norm1 = q1.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
455    let norm2 = q2.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
456    let q1n: Vec<f64> = q1.iter().map(|&v| v / norm1).collect();
457    let q2n: Vec<f64> = q2.iter().map(|&v| v / norm2).collect();
458
459    let path = dp_grid_solve(m, m, |sr, sc, tr, tc| {
460        dp_edge_weight(&q1n, &q2n, argvals, sc, tc, sr, tr)
461            + dp_lambda_penalty(argvals, sc, tc, sr, tr, lambda)
462    });
463
464    dp_path_to_gamma(&path, argvals)
465}
466
467/// Greatest common divisor (Euclidean algorithm). Used only in tests.
468#[cfg(test)]
469pub(super) fn gcd(a: usize, b: usize) -> usize {
470    if b == 0 {
471        a
472    } else {
473        gcd(b, a % b)
474    }
475}
476
/// All pairs `(i, j)` with `1 ≤ i, j ≤ nbhd_dim` and `gcd(i, j) = 1`, in row-major order.
///
/// Test-only helper; `nbhd_dim = 7` yields the 35 pairs of fdasrvf's default neighbourhood.
#[cfg(test)]
pub(super) fn generate_coprime_nbhd(nbhd_dim: usize) -> Vec<(usize, usize)> {
    (1..=nbhd_dim)
        .flat_map(|i| (1..=nbhd_dim).map(move |j| (i, j)))
        .filter(|&(i, j)| gcd(i, j) == 1)
        .collect()
}