Skip to main content

fdars_core/alignment/
pairwise.rs

1//! Pairwise elastic alignment, distance computation, and distance matrices.
2
3use super::srsf::{reparameterize_curve, srsf_single, srsf_transform};
4use super::{dp_alignment_core, AlignmentResult};
5use crate::helpers::{l2_distance, simpsons_weights};
6use crate::iter_maybe_parallel;
7use crate::matrix::FdMatrix;
8#[cfg(feature = "parallel")]
9use rayon::iter::ParallelIterator;
10
11// ─── Public Alignment Functions ─────────────────────────────────────────────
12
13/// Align curve f2 to curve f1 using the elastic framework.
14///
15/// Computes the optimal warping γ such that f2∘γ is as close as possible
16/// to f1 in the elastic (Fisher-Rao) metric.
17///
18/// # Arguments
19/// * `f1` — Target curve (length m)
20/// * `f2` — Curve to align (length m)
21/// * `argvals` — Evaluation points (length m)
22/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
23///
24/// # Returns
25/// [`AlignmentResult`] with warping function, aligned curve, and elastic distance.
26///
27/// # Examples
28///
29/// ```
30/// use fdars_core::alignment::elastic_align_pair;
31///
32/// let argvals: Vec<f64> = (0..20).map(|i| i as f64 / 19.0).collect();
33/// let f1: Vec<f64> = argvals.iter().map(|&t| (t * 6.0).sin()).collect();
34/// let f2: Vec<f64> = argvals.iter().map(|&t| ((t + 0.1) * 6.0).sin()).collect();
35/// let result = elastic_align_pair(&f1, &f2, &argvals, 0.0);
36/// assert_eq!(result.f_aligned.len(), 20);
37/// assert!(result.distance >= 0.0);
38/// ```
39#[must_use = "expensive computation whose result should not be discarded"]
40pub fn elastic_align_pair(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> AlignmentResult {
41    let q1 = srsf_single(f1, argvals);
42    let q2 = srsf_single(f2, argvals);
43    elastic_align_pair_from_srsf(f2, &q1, &q2, argvals, lambda)
44}
45
46/// Compute the elastic distance between two curves.
47///
48/// This is shorthand for aligning the pair and returning only the distance.
49///
50/// # Arguments
51/// * `f1` — First curve (length m)
52/// * `f2` — Second curve (length m)
53/// * `argvals` — Evaluation points (length m)
54/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
55///
56/// # Examples
57///
58/// ```
59/// use fdars_core::alignment::elastic_distance;
60///
61/// let argvals: Vec<f64> = (0..20).map(|i| i as f64 / 19.0).collect();
62/// let f1: Vec<f64> = argvals.iter().map(|&t| (t * 6.0).sin()).collect();
63/// let f2: Vec<f64> = argvals.iter().map(|&t| ((t + 0.1) * 6.0).sin()).collect();
64/// let d = elastic_distance(&f1, &f2, &argvals, 0.0);
65/// assert!(d >= 0.0);
66/// ```
67#[must_use = "expensive computation whose result should not be discarded"]
68pub fn elastic_distance(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> f64 {
69    elastic_align_pair(f1, f2, argvals, lambda).distance
70}
71
72// ─── Internal Helpers with Pre-computed SRSFs ────────────────────────────────
73
74/// Align curve f2 to curve f1 given their pre-computed SRSFs.
75///
76/// This avoids redundant SRSF computation when calling from distance matrix
77/// routines where the same curve's SRSF would otherwise be recomputed for
78/// every pair.
79fn elastic_align_pair_from_srsf(
80    f2: &[f64],
81    q1: &[f64],
82    q2: &[f64],
83    argvals: &[f64],
84    lambda: f64,
85) -> AlignmentResult {
86    // Find optimal warping via DP
87    let gamma = dp_alignment_core(q1, q2, argvals, lambda);
88
89    // Apply warping to f2
90    let f_aligned = reparameterize_curve(f2, argvals, &gamma);
91
92    // Compute elastic distance: L2 distance between q1 and aligned q2 SRSF
93    let q_aligned = srsf_single(&f_aligned, argvals);
94
95    let weights = simpsons_weights(argvals);
96    let distance = l2_distance(q1, &q_aligned, &weights);
97
98    AlignmentResult {
99        gamma,
100        f_aligned,
101        distance,
102    }
103}
104
105/// Compute elastic distance given a raw curve f2, and pre-computed SRSFs q1, q2.
106///
107/// The raw f2 is needed to reparameterize before computing the aligned SRSF.
108fn elastic_distance_from_srsf(
109    f2: &[f64],
110    q1: &[f64],
111    q2: &[f64],
112    argvals: &[f64],
113    lambda: f64,
114) -> f64 {
115    let gamma = dp_alignment_core(q1, q2, argvals, lambda);
116    let f_aligned = reparameterize_curve(f2, argvals, &gamma);
117    let q_aligned = srsf_single(&f_aligned, argvals);
118    let weights = simpsons_weights(argvals);
119    l2_distance(q1, &q_aligned, &weights)
120}
121
122// ─── Distance Matrices ──────────────────────────────────────────────────────
123
124/// Compute the symmetric elastic distance matrix for a set of curves.
125///
126/// Pre-computes SRSF transforms for all curves once (O(n)) instead of
127/// recomputing each curve's SRSF for every pair (O(n²)).
128///
129/// Uses upper-triangle computation with parallelism, following the
130/// `self_distance_matrix` pattern from `metric.rs`.
131///
132/// # Arguments
133/// * `data` — Functional data matrix (n × m)
134/// * `argvals` — Evaluation points (length m)
135/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
136///
137/// # Returns
138/// Symmetric n × n distance matrix.
139pub fn elastic_self_distance_matrix(data: &FdMatrix, argvals: &[f64], lambda: f64) -> FdMatrix {
140    let n = data.nrows();
141
142    // Pre-compute all SRSF transforms once
143    let srsfs = srsf_transform(data, argvals);
144
145    let upper_vals: Vec<f64> = iter_maybe_parallel!(0..n)
146        .flat_map(|i| {
147            let qi = srsfs.row(i);
148            ((i + 1)..n)
149                .map(|j| {
150                    let fj = data.row(j);
151                    let qj = srsfs.row(j);
152                    elastic_distance_from_srsf(&fj, &qi, &qj, argvals, lambda)
153                })
154                .collect::<Vec<_>>()
155        })
156        .collect();
157
158    let mut dist = FdMatrix::zeros(n, n);
159    let mut idx = 0;
160    for i in 0..n {
161        for j in (i + 1)..n {
162            let d = upper_vals[idx];
163            dist[(i, j)] = d;
164            dist[(j, i)] = d;
165            idx += 1;
166        }
167    }
168    dist
169}
170
171/// Compute the elastic distance matrix between two sets of curves.
172///
173/// Pre-computes SRSF transforms for both datasets once instead of
174/// recomputing each curve's SRSF for every pair.
175///
176/// # Arguments
177/// * `data1` — First dataset (n1 × m)
178/// * `data2` — Second dataset (n2 × m)
179/// * `argvals` — Evaluation points (length m)
180/// * `lambda` — Penalty weight on warp deviation from identity (0.0 = no penalty)
181///
182/// # Returns
183/// n1 × n2 distance matrix.
184pub fn elastic_cross_distance_matrix(
185    data1: &FdMatrix,
186    data2: &FdMatrix,
187    argvals: &[f64],
188    lambda: f64,
189) -> FdMatrix {
190    let n1 = data1.nrows();
191    let n2 = data2.nrows();
192
193    // Pre-compute all SRSF transforms once for both datasets
194    let srsfs1 = srsf_transform(data1, argvals);
195    let srsfs2 = srsf_transform(data2, argvals);
196
197    let vals: Vec<f64> = iter_maybe_parallel!(0..n1)
198        .flat_map(|i| {
199            let qi = srsfs1.row(i);
200            (0..n2)
201                .map(|j| {
202                    let fj = data2.row(j);
203                    let qj = srsfs2.row(j);
204                    elastic_distance_from_srsf(&fj, &qi, &qj, argvals, lambda)
205                })
206                .collect::<Vec<_>>()
207        })
208        .collect();
209
210    let mut dist = FdMatrix::zeros(n1, n2);
211    for i in 0..n1 {
212        for j in 0..n2 {
213            dist[(i, j)] = vals[i * n2 + j];
214        }
215    }
216    dist
217}
218
219/// Compute the amplitude distance between two curves (= elastic distance after alignment).
220pub fn amplitude_distance(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> f64 {
221    elastic_distance(f1, f2, argvals, lambda)
222}
223
224/// Compute the phase distance between two curves (geodesic distance of optimal warp from identity).
225pub fn phase_distance_pair(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> f64 {
226    let alignment = elastic_align_pair(f1, f2, argvals, lambda);
227    crate::warping::phase_distance(&alignment.gamma, argvals)
228}
229
230/// Compute the symmetric phase distance matrix for a set of curves.
231pub fn phase_self_distance_matrix(data: &FdMatrix, argvals: &[f64], lambda: f64) -> FdMatrix {
232    let n = data.nrows();
233
234    let upper_vals: Vec<f64> = iter_maybe_parallel!(0..n)
235        .flat_map(|i| {
236            let fi = data.row(i);
237            ((i + 1)..n)
238                .map(|j| {
239                    let fj = data.row(j);
240                    phase_distance_pair(&fi, &fj, argvals, lambda)
241                })
242                .collect::<Vec<_>>()
243        })
244        .collect();
245
246    let mut dist = FdMatrix::zeros(n, n);
247    let mut idx = 0;
248    for i in 0..n {
249        for j in (i + 1)..n {
250            let d = upper_vals[idx];
251            dist[(i, j)] = d;
252            dist[(j, i)] = d;
253            idx += 1;
254        }
255    }
256    dist
257}
258
259/// Compute the symmetric amplitude distance matrix (= elastic self distance matrix).
260pub fn amplitude_self_distance_matrix(data: &FdMatrix, argvals: &[f64], lambda: f64) -> FdMatrix {
261    elastic_self_distance_matrix(data, argvals, lambda)
262}
263
264// ─── Higher-Order Warp Penalties ─────────────────────────────────────────────
265
/// Selects how the warping function is regularized during alignment.
///
/// Every variant starts from the standard DP alignment. `FirstOrder` stops
/// there; `SecondOrder` and `Combined` additionally post-process the warp
/// with iterative Tikhonov (curvature) smoothing. The variants differ only in
/// where the smoothing weight comes from — see
/// [`elastic_align_pair_penalized`].
#[derive(Clone, Copy, Debug, Default, PartialEq)]
#[non_exhaustive]
pub enum WarpPenaltyType {
    /// Standard first-order penalty only: lambda * (gamma' - 1)^2 * dt.
    /// No smoothing pass is applied.
    #[default]
    FirstOrder,
    /// Curvature penalty: DP alignment followed by iterative smoothing whose
    /// weight is derived from the alignment's `lambda` (at least 0.01).
    SecondOrder,
    /// First- and second-order together: DP alignment followed by smoothing
    /// whose weight is `second_order_weight`.
    Combined {
        /// Weight of the curvature smoothing step. Expected to be > 0;
        /// values below 1e-6 are raised to 1e-6 before use.
        second_order_weight: f64,
    },
}
287
288/// Number of Tikhonov smoothing iterations for second-order penalty.
289const TIKHONOV_ITERS: usize = 8;
290
291/// Apply Tikhonov curvature smoothing to a warping function.
292///
293/// Iteratively smooths toward the identity warp using Laplacian smoothing,
294/// which reduces high-frequency curvature while preserving monotonicity
295/// and boundary conditions. The smoothing weight `alpha` (clamped to [0,1])
296/// controls how much each iteration pulls interior points toward the
297/// midpoint of their neighbors.
298fn tikhonov_smooth_gamma(gamma: &[f64], argvals: &[f64], alpha: f64, n_iter: usize) -> Vec<f64> {
299    let m = gamma.len();
300    if m < 3 || alpha <= 0.0 {
301        return gamma.to_vec();
302    }
303
304    // Clamp effective weight to a stable range.
305    let w = alpha.min(0.5);
306
307    let mut gam = gamma.to_vec();
308
309    for _ in 0..n_iter {
310        let prev = gam.clone();
311
312        // Laplacian smoothing: move each interior point toward the
313        // midpoint of its neighbors, weighted by w.
314        for j in 1..m - 1 {
315            let mid = (prev[j - 1] + prev[j + 1]) / 2.0;
316            gam[j] = prev[j] + w * (mid - prev[j]);
317        }
318
319        // Enforce boundary conditions.
320        gam[0] = argvals[0];
321        gam[m - 1] = argvals[m - 1];
322
323        // Enforce monotonicity.
324        crate::warping::normalize_warp(&mut gam, argvals);
325    }
326
327    gam
328}
329
330/// Align two curves with a configurable penalty type.
331///
332/// For [`WarpPenaltyType::FirstOrder`], this delegates directly to
333/// [`elastic_align_pair`]. For [`WarpPenaltyType::SecondOrder`] and
334/// [`WarpPenaltyType::Combined`], runs the standard DP alignment first,
335/// then applies iterative Tikhonov smoothing to the warping function to
336/// reduce curvature (gamma'') while preserving alignment quality.
337///
338/// # Arguments
339/// * `f1` — Target curve (length m)
340/// * `f2` — Curve to align (length m)
341/// * `argvals` — Evaluation points (length m)
342/// * `lambda` — First-order penalty weight (passed to DP alignment)
343/// * `penalty_type` — Which penalty type to apply
344///
345/// # Returns
346/// [`AlignmentResult`] with warping function, aligned curve, and elastic distance.
347///
348/// # Examples
349///
350/// ```
351/// use fdars_core::alignment::{elastic_align_pair_penalized, WarpPenaltyType};
352///
353/// let argvals: Vec<f64> = (0..30).map(|i| i as f64 / 29.0).collect();
354/// let f1: Vec<f64> = argvals.iter().map(|&t| (t * 6.0).sin()).collect();
355/// let f2: Vec<f64> = argvals.iter().map(|&t| ((t + 0.1) * 6.0).sin()).collect();
356///
357/// // Standard first-order
358/// let r1 = elastic_align_pair_penalized(&f1, &f2, &argvals, 0.0, WarpPenaltyType::FirstOrder);
359/// assert!(r1.distance >= 0.0);
360///
361/// // Second-order smoothing
362/// let r2 = elastic_align_pair_penalized(&f1, &f2, &argvals, 0.0, WarpPenaltyType::SecondOrder);
363/// assert!(r2.distance >= 0.0);
364/// ```
365#[must_use = "expensive computation whose result should not be discarded"]
366pub fn elastic_align_pair_penalized(
367    f1: &[f64],
368    f2: &[f64],
369    argvals: &[f64],
370    lambda: f64,
371    penalty_type: WarpPenaltyType,
372) -> AlignmentResult {
373    // Step 1: Run standard first-order DP alignment.
374    let initial = elastic_align_pair(f1, f2, argvals, lambda);
375
376    let smoothing_alpha = match penalty_type {
377        WarpPenaltyType::FirstOrder => return initial,
378        WarpPenaltyType::SecondOrder => lambda.max(0.01),
379        WarpPenaltyType::Combined {
380            second_order_weight,
381        } => second_order_weight.max(1e-6),
382    };
383
384    // Step 2: Apply Tikhonov curvature smoothing to the warping function.
385    let gamma_smooth =
386        tikhonov_smooth_gamma(&initial.gamma, argvals, smoothing_alpha, TIKHONOV_ITERS);
387
388    // Step 3: Recompute aligned curve and distance with smoothed gamma.
389    let f_aligned = reparameterize_curve(f2, argvals, &gamma_smooth);
390    let q1 = srsf_single(f1, argvals);
391    let q_aligned = srsf_single(&f_aligned, argvals);
392    let weights = simpsons_weights(argvals);
393    let distance = l2_distance(&q1, &q_aligned, &weights);
394
395    AlignmentResult {
396        gamma: gamma_smooth,
397        f_aligned,
398        distance,
399    }
400}
401
#[cfg(test)]
mod tests {
    use super::*;

    /// `n` equally spaced points from 0 to 1 inclusive (tests pass n >= 2).
    fn uniform_grid(n: usize) -> Vec<f64> {
        let mut grid = Vec::with_capacity(n);
        for i in 0..n {
            grid.push(i as f64 / (n - 1) as f64);
        }
        grid
    }

    /// Samples `sin((t + shift) * freq)` at every point of `argvals`.
    fn shifted_sine(argvals: &[f64], shift: f64, freq: f64) -> Vec<f64> {
        argvals
            .iter()
            .map(|&t| ((t + shift) * freq).sin())
            .collect()
    }

    /// Asserts that `gamma` never decreases by more than 1e-12 between
    /// consecutive samples.
    fn assert_monotone(gamma: &[f64]) {
        for (k, pair) in gamma.windows(2).enumerate() {
            let j = k + 1;
            assert!(
                pair[1] >= pair[0] - 1e-12,
                "gamma should be monotone at j={j}"
            );
        }
    }

    /// Discrete bending energy of a warp: sum of squared second differences
    /// scaled by the local grid spacing.
    fn bending_energy(g: &[f64], argvals: &[f64]) -> f64 {
        let mut energy = 0.0;
        for (k, window) in g.windows(3).enumerate() {
            // `window` is centred on index k + 1.
            let dt = argvals[k + 2] - argvals[k];
            if dt > 0.0 {
                let d2 = (window[2] - 2.0 * window[1] + window[0]) / (dt / 2.0).powi(2);
                energy += d2 * d2 * dt / 2.0;
            }
        }
        energy
    }

    #[test]
    fn penalized_first_order_matches_standard() {
        let argvals = uniform_grid(30);
        let f1 = shifted_sine(&argvals, 0.0, 6.0);
        let f2 = shifted_sine(&argvals, 0.1, 6.0);

        let standard = elastic_align_pair(&f1, &f2, &argvals, 0.0);
        let penalized =
            elastic_align_pair_penalized(&f1, &f2, &argvals, 0.0, WarpPenaltyType::FirstOrder);

        assert_eq!(standard.gamma, penalized.gamma);
        assert_eq!(standard.f_aligned, penalized.f_aligned);
        assert!((standard.distance - penalized.distance).abs() < 1e-12);
    }

    #[test]
    fn second_order_produces_valid_warp() {
        let argvals = uniform_grid(30);
        let f1 = shifted_sine(&argvals, 0.0, 6.0);
        let f2 = shifted_sine(&argvals, 0.15, 6.0);

        let result =
            elastic_align_pair_penalized(&f1, &f2, &argvals, 0.1, WarpPenaltyType::SecondOrder);

        let m = argvals.len();
        assert_eq!(result.gamma.len(), m);
        assert_eq!(result.f_aligned.len(), m);
        assert!(result.distance >= 0.0);

        // Warp should be monotone non-decreasing.
        assert_monotone(&result.gamma);

        // Boundary conditions.
        assert!((result.gamma[0] - argvals[0]).abs() < 1e-12);
        assert!((result.gamma[m - 1] - argvals[m - 1]).abs() < 1e-12);
    }

    #[test]
    fn combined_penalty_produces_valid_warp() {
        let argvals = uniform_grid(25);
        let f1 = shifted_sine(&argvals, 0.0, 4.0);
        let f2 = shifted_sine(&argvals, 0.1, 4.0);

        let penalty = WarpPenaltyType::Combined {
            second_order_weight: 0.1,
        };
        let result = elastic_align_pair_penalized(&f1, &f2, &argvals, 0.05, penalty);

        assert_eq!(result.gamma.len(), argvals.len());
        assert!(result.distance >= 0.0);

        // Monotonicity.
        assert_monotone(&result.gamma);
    }

    #[test]
    fn second_order_smoother_curvature() {
        let argvals = uniform_grid(40);
        let f1 = shifted_sine(&argvals, 0.0, 8.0);
        let f2 = shifted_sine(&argvals, 0.2, 8.0);

        let first_order = elastic_align_pair(&f1, &f2, &argvals, 0.0);
        let second_order =
            elastic_align_pair_penalized(&f1, &f2, &argvals, 0.0, WarpPenaltyType::SecondOrder);

        let be_first = bending_energy(&first_order.gamma, &argvals);
        let be_second = bending_energy(&second_order.gamma, &argvals);

        // Second-order penalty should reduce bending energy (or at least not
        // increase it much if the first-order warp is already smooth).
        assert!(
            be_second <= be_first + 1e-6,
            "second-order should reduce bending: first={be_first:.4}, second={be_second:.4}"
        );
    }

    #[test]
    fn warp_penalty_type_default_is_first_order() {
        assert_eq!(WarpPenaltyType::default(), WarpPenaltyType::FirstOrder);
    }
}