Skip to main content

fdars_core/alignment/
mod.rs

1//! Elastic alignment and SRSF (Square-Root Slope Function) transforms.
2//!
3//! This module provides phase-amplitude separation for functional data via
4//! the elastic framework. Key capabilities:
5//!
6//! - [`srsf_transform`] / [`srsf_inverse`] — SRSF representation and reconstruction
7//! - [`elastic_align_pair`] — Pairwise curve alignment via dynamic programming
8//! - [`elastic_distance`] — Elastic (Fisher-Rao) distance between curves
9//! - [`align_to_target`] — Align a set of curves to a common target
10//! - [`karcher_mean`] — Karcher (Fréchet) mean in the elastic metric
11//! - [`elastic_self_distance_matrix`] / [`elastic_cross_distance_matrix`] — Distance matrices
12//! - [`reparameterize_curve`] / [`compose_warps`] — Warping utilities
13
14mod constrained;
15mod karcher;
16mod nd;
17mod pairwise;
18mod quality;
19mod set;
20mod srsf;
21mod tsrvf;
22
23#[cfg(test)]
24mod tests;
25
26// Re-export all public items so that `crate::alignment::X` continues to work.
27pub use constrained::{
28    elastic_align_pair_constrained, elastic_align_pair_with_landmarks, ConstrainedAlignmentResult,
29};
30pub use karcher::karcher_mean;
31pub use nd::{
32    elastic_align_pair_nd, elastic_distance_nd, srsf_inverse_nd, srsf_transform_nd,
33    AlignmentResultNd,
34};
35pub use pairwise::{
36    amplitude_distance, amplitude_self_distance_matrix, elastic_align_pair,
37    elastic_cross_distance_matrix, elastic_distance, elastic_self_distance_matrix,
38    phase_distance_pair, phase_self_distance_matrix,
39};
40pub use quality::{
41    alignment_quality, pairwise_consistency, warp_complexity, warp_smoothness, AlignmentQuality,
42};
43pub use set::{align_to_target, elastic_decomposition, DecompositionResult};
44pub use srsf::{compose_warps, reparameterize_curve, srsf_inverse, srsf_transform};
45pub use tsrvf::{
46    tsrvf_from_alignment, tsrvf_from_alignment_with_method, tsrvf_inverse, tsrvf_transform,
47    tsrvf_transform_with_method, TransportMethod, TsrvfResult,
48};
49
50// Re-export pub(crate) items so other crate modules can use them.
51pub(crate) use karcher::sqrt_mean_inverse;
52
53use crate::helpers::linear_interp;
54use crate::matrix::FdMatrix;
55use crate::warping::normalize_warp;
56
57// ─── Types ──────────────────────────────────────────────────────────────────
58
/// Outcome of elastically aligning a single curve to another one.
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub struct AlignmentResult {
    /// Warping function γ, a reparameterization of the domain onto itself.
    pub gamma: Vec<f64>,
    /// The curve after alignment, i.e. after reparameterization.
    pub f_aligned: Vec<f64>,
    /// Elastic distance between the curves once they are aligned.
    pub distance: f64,
}
70
71/// Result of aligning a set of curves to a common target.
72#[derive(Debug, Clone, PartialEq)]
73#[non_exhaustive]
74pub struct AlignmentSetResult {
75    /// Warping functions (n × m).
76    pub gammas: FdMatrix,
77    /// Aligned curves (n × m).
78    pub aligned_data: FdMatrix,
79    /// Elastic distances for each curve.
80    pub distances: Vec<f64>,
81}
82
83/// Result of the Karcher mean computation.
84#[derive(Debug, Clone, PartialEq)]
85#[non_exhaustive]
86pub struct KarcherMeanResult {
87    /// Karcher mean curve.
88    pub mean: Vec<f64>,
89    /// SRSF of the Karcher mean.
90    pub mean_srsf: Vec<f64>,
91    /// Final warping functions (n × m).
92    pub gammas: FdMatrix,
93    /// Curves aligned to the mean (n × m).
94    pub aligned_data: FdMatrix,
95    /// Number of iterations used.
96    pub n_iter: usize,
97    /// Whether the algorithm converged.
98    pub converged: bool,
99    /// Pre-computed SRSFs of aligned curves (n × m), if available.
100    /// When set, FPCA functions use these instead of recomputing from `aligned_data`.
101    pub aligned_srsfs: Option<FdMatrix>,
102}
103
104impl KarcherMeanResult {
105    /// Create a new `KarcherMeanResult`.
106    pub fn new(
107        mean: Vec<f64>,
108        mean_srsf: Vec<f64>,
109        gammas: FdMatrix,
110        aligned_data: FdMatrix,
111        n_iter: usize,
112        converged: bool,
113        aligned_srsfs: Option<FdMatrix>,
114    ) -> Self {
115        Self {
116            mean,
117            mean_srsf,
118            gammas,
119            aligned_data,
120            n_iter,
121            converged,
122            aligned_srsfs,
123        }
124    }
125}
126
127// ─── Dynamic Programming Alignment ──────────────────────────────────────────
128// Faithful port of fdasrvf's DP algorithm (dp_grid.cpp / dp_nbhd.cpp).
129
/// Coprime move set for the DP grid, equal to fdasrvf's default `nbhd_dim = 7`.
///
/// Holds every step `(dr, dc)` with `1 ≤ dr, dc ≤ 7` and `gcd(dr, dc) = 1`
/// (35 entries), listed row-major: ascending `dr`, then ascending `dc`.
/// `dr` advances the row (q2) index and `dc` the column (q1) index. Steps with
/// a common factor are left out because each is a multiple of a shorter step.
///
/// The order is significant: `dp_relax_cell` only replaces a cost on a strict
/// improvement, so on ties the earliest entry in this table wins.
#[rustfmt::skip]
const COPRIME_NBHD_7: [(usize, usize); 35] = [
    // dr = 1
    (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7),
    // dr = 2
    (2, 1), (2, 3), (2, 5), (2, 7),
    // dr = 3
    (3, 1), (3, 2), (3, 4), (3, 5), (3, 7),
    // dr = 4
    (4, 1), (4, 3), (4, 5), (4, 7),
    // dr = 5
    (5, 1), (5, 2), (5, 3), (5, 4), (5, 6), (5, 7),
    // dr = 6
    (6, 1), (6, 5), (6, 7),
    // dr = 7
    (7, 1), (7, 2), (7, 3), (7, 4), (7, 5), (7, 6),
];
143
144/// Compute the edge weight for a move from grid point (sr, sc) to (tr, tc).
145///
146/// Port of fdasrvf's `dp_edge_weight` for 1-D curves on a shared uniform grid.
147/// - Rows = q2 indices, columns = q1 indices (matching fdasrvf convention).
148/// - `slope = (argvals[tr] - argvals[sr]) / (argvals[tc] - argvals[sc])` = γ'
149/// - Walks through sub-intervals synchronized at both curves' breakpoints,
150///   accumulating `(q1[idx1] - √slope · q2[idx2])² · dt`.
151#[inline]
152pub(super) fn dp_edge_weight(
153    q1: &[f64],
154    q2: &[f64],
155    argvals: &[f64],
156    sc: usize,
157    tc: usize,
158    sr: usize,
159    tr: usize,
160) -> f64 {
161    let n1 = tc - sc;
162    let n2 = tr - sr;
163    if n1 == 0 || n2 == 0 {
164        return f64::INFINITY;
165    }
166
167    let slope = (argvals[tr] - argvals[sr]) / (argvals[tc] - argvals[sc]);
168    let rslope = slope.sqrt();
169
170    // Walk through sub-intervals synchronized at breakpoints of both curves
171    let mut weight = 0.0;
172    let mut i1 = 0usize; // sub-interval index in q1 direction
173    let mut i2 = 0usize; // sub-interval index in q2 direction
174
175    while i1 < n1 && i2 < n2 {
176        // Current sub-interval boundaries as fractions of the total span
177        let left1 = i1 as f64 / n1 as f64;
178        let right1 = (i1 + 1) as f64 / n1 as f64;
179        let left2 = i2 as f64 / n2 as f64;
180        let right2 = (i2 + 1) as f64 / n2 as f64;
181
182        let left = left1.max(left2);
183        let right = right1.min(right2);
184        let dt = right - left;
185
186        if dt > 0.0 {
187            let diff = q1[sc + i1] - rslope * q2[sr + i2];
188            weight += diff * diff * dt;
189        }
190
191        // Advance whichever sub-interval ends first
192        if right1 < right2 {
193            i1 += 1;
194        } else if right2 < right1 {
195            i2 += 1;
196        } else {
197            i1 += 1;
198            i2 += 1;
199        }
200    }
201
202    // Scale by the span in q1 direction
203    weight * (argvals[tc] - argvals[sc])
204}
205
206/// Compute the λ·(slope−1)²·dt penalty for a DP edge.
207#[inline]
208pub(super) fn dp_lambda_penalty(
209    argvals: &[f64],
210    sc: usize,
211    tc: usize,
212    sr: usize,
213    tr: usize,
214    lambda: f64,
215) -> f64 {
216    if lambda > 0.0 {
217        let dt = argvals[tc] - argvals[sc];
218        let slope = (argvals[tr] - argvals[sr]) / dt;
219        lambda * (slope - 1.0).powi(2) * dt
220    } else {
221        0.0
222    }
223}
224
/// Recover the DP path by following parent pointers back from the last cell.
///
/// `parent` is a row-major `nrows × ncols` table of flat cell indices, with
/// `u32::MAX` meaning "no parent". The walk starts at `(nrows - 1, ncols - 1)`
/// and stops at cell `0` or at the first cell that has no parent.
///
/// The result is returned origin-first as `(row, col)` pairs. When the last
/// cell is reachable, the path runs from `(0, 0)` to `(nrows - 1, ncols - 1)`.
fn dp_traceback(parent: &[u32], nrows: usize, ncols: usize) -> Vec<(usize, usize)> {
    let last = (nrows - 1) * ncols + (ncols - 1);
    let step_back = |cell: &usize| -> Option<usize> {
        if *cell == 0 {
            return None;
        }
        match parent[*cell] {
            u32::MAX => None,
            prev => Some(prev as usize),
        }
    };

    let mut cells: Vec<(usize, usize)> = std::iter::successors(Some(last), step_back)
        .map(|cell| (cell / ncols, cell % ncols))
        .collect();
    cells.reverse();
    cells
}
241
242/// Try to relax cell `(tr, tc)` from each coprime neighbor, updating cost and parent.
243#[inline]
244fn dp_relax_cell<F>(
245    e: &mut [f64],
246    parent: &mut [u32],
247    ncols: usize,
248    tr: usize,
249    tc: usize,
250    edge_cost: &F,
251) where
252    F: Fn(usize, usize, usize, usize) -> f64,
253{
254    let idx = tr * ncols + tc;
255    for &(dr, dc) in &COPRIME_NBHD_7 {
256        if dr > tr || dc > tc {
257            continue;
258        }
259        let sr = tr - dr;
260        let sc = tc - dc;
261        let src_idx = sr * ncols + sc;
262        if e[src_idx] == f64::INFINITY {
263            continue;
264        }
265        let cost = e[src_idx] + edge_cost(sr, sc, tr, tc);
266        if cost < e[idx] {
267            e[idx] = cost;
268            parent[idx] = src_idx as u32;
269        }
270    }
271}
272
273/// Shared DP grid fill + traceback using the coprime neighborhood.
274///
275/// `edge_cost(sr, sc, tr, tc)` returns the combined edge weight + penalty for
276/// a move from local (sr, sc) to local (tr, tc). Returns the raw local-index
277/// path from (0,0) to (nrows-1, ncols-1).
278pub(super) fn dp_grid_solve<F>(nrows: usize, ncols: usize, edge_cost: F) -> Vec<(usize, usize)>
279where
280    F: Fn(usize, usize, usize, usize) -> f64,
281{
282    let mut e = vec![f64::INFINITY; nrows * ncols];
283    let mut parent = vec![u32::MAX; nrows * ncols];
284    e[0] = 0.0;
285
286    for tr in 0..nrows {
287        for tc in 0..ncols {
288            if tr == 0 && tc == 0 {
289                continue;
290            }
291            dp_relax_cell(&mut e, &mut parent, ncols, tr, tc, &edge_cost);
292        }
293    }
294
295    dp_traceback(&parent, nrows, ncols)
296}
297
298/// Convert a DP path (local row,col indices) to an interpolated+normalized gamma warp.
299pub(super) fn dp_path_to_gamma(path: &[(usize, usize)], argvals: &[f64]) -> Vec<f64> {
300    let path_tc: Vec<f64> = path.iter().map(|&(_, c)| argvals[c]).collect();
301    let path_tr: Vec<f64> = path.iter().map(|&(r, _)| argvals[r]).collect();
302    let mut gamma: Vec<f64> = argvals
303        .iter()
304        .map(|&t| linear_interp(&path_tc, &path_tr, t))
305        .collect();
306    normalize_warp(&mut gamma, argvals);
307    gamma
308}
309
310/// Core DP alignment between two SRSFs on a grid.
311///
312/// Finds the optimal warping γ minimizing ‖q₁ - (q₂∘γ)√γ'‖².
313/// Uses fdasrvf's coprime neighborhood (nbhd_dim=7 → 35 move directions).
314/// SRSFs are L2-normalized before alignment (matching fdasrvf's `optimum.reparam`).
315pub(crate) fn dp_alignment_core(q1: &[f64], q2: &[f64], argvals: &[f64], lambda: f64) -> Vec<f64> {
316    let m = argvals.len();
317    if m < 2 {
318        return argvals.to_vec();
319    }
320
321    let norm1 = q1.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
322    let norm2 = q2.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
323    let q1n: Vec<f64> = q1.iter().map(|&v| v / norm1).collect();
324    let q2n: Vec<f64> = q2.iter().map(|&v| v / norm2).collect();
325
326    let path = dp_grid_solve(m, m, |sr, sc, tr, tc| {
327        dp_edge_weight(&q1n, &q2n, argvals, sc, tc, sr, tr)
328            + dp_lambda_penalty(argvals, sc, tc, sr, tr, lambda)
329    });
330
331    dp_path_to_gamma(&path, argvals)
332}
333
/// Greatest common divisor via the iterative Euclidean algorithm. Used only in tests.
///
/// `gcd(a, 0)` is `a`, so `gcd(0, 0)` is `0`.
#[cfg(test)]
pub(super) fn gcd(a: usize, b: usize) -> usize {
    let (mut x, mut y) = (a, b);
    while y != 0 {
        let r = x % y;
        x = y;
        y = r;
    }
    x
}
343
/// Build the coprime DP neighborhood for an arbitrary `nbhd_dim` (test-only helper).
///
/// Returns every pair `(i, j)` with `1 ≤ i, j ≤ nbhd_dim` and `gcd(i, j) = 1`,
/// ordered by `i` and then by `j`. For `nbhd_dim = 7` this yields the same 35
/// pairs, in the same order, as `COPRIME_NBHD_7`, which matches fdasrvf's
/// default.
#[cfg(test)]
pub(super) fn generate_coprime_nbhd(nbhd_dim: usize) -> Vec<(usize, usize)> {
    (1..=nbhd_dim)
        .flat_map(|i| {
            (1..=nbhd_dim)
                .filter(move |&j| gcd(i, j) == 1)
                .map(move |j| (i, j))
        })
        .collect()
}