Skip to main content

fdars_core/alignment/
closed.rs

1//! Alignment for periodic/closed curves where the starting point can shift.
2//!
3//! Closed curves have a natural ambiguity in the choice of starting point.
4//! These functions optimize over circular rotations of the parameterization
5//! in addition to the standard elastic warping.
6
7use super::dp_alignment_core;
8use super::pairwise::elastic_align_pair;
9use super::srsf::{reparameterize_curve, srsf_single};
10use crate::error::FdarError;
11use crate::helpers::{l2_distance, simpsons_weights};
12use crate::iter_maybe_parallel;
13use crate::matrix::FdMatrix;
14#[cfg(feature = "parallel")]
15use rayon::iter::ParallelIterator;
16
17// ─── Types ──────────────────────────────────────────────────────────────────
18
/// Result of aligning one closed curve to another.
///
/// Produced by [`elastic_align_pair_closed`]: the source curve is first circularly
/// shifted by `optimal_rotation`, and the shifted curve is then elastically aligned
/// to the target.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ClosedAlignmentResult {
    /// Warping function on the domain, applied to the source curve *after* it has
    /// been circularly shifted by `optimal_rotation`.
    pub gamma: Vec<f64>,
    /// The aligned source curve (rotated by `optimal_rotation`, then reparameterized).
    pub f_aligned: Vec<f64>,
    /// Elastic distance after alignment.
    pub distance: f64,
    /// Optimal circular shift index for the source curve: sample `j` of the rotated
    /// source is `f2[(j + optimal_rotation) % m]`.
    pub optimal_rotation: usize,
}
32
/// Result of computing the Karcher mean for closed curves.
///
/// Produced by [`karcher_mean_closed`]. Row `i` of `gammas` and `aligned_data`, and
/// entry `i` of `rotations`, describe how row `i` of the input data was aligned in the
/// last alignment pass.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ClosedKarcherMeanResult {
    /// Karcher mean curve (length m).
    pub mean: Vec<f64>,
    /// SRSF of the Karcher mean (length m).
    pub mean_srsf: Vec<f64>,
    /// Final warping functions (n x m), one row per input curve.
    pub gammas: FdMatrix,
    /// Curves aligned to the mean (n x m): each input curve circularly shifted by its
    /// rotation and then warped by its row of `gammas`.
    pub aligned_data: FdMatrix,
    /// Per-curve optimal rotations (circular shift indices).
    pub rotations: Vec<usize>,
    /// Number of iterations used.
    pub n_iter: usize,
    /// Whether the algorithm converged (relative SRSF change fell below the tolerance).
    pub converged: bool,
}
52
53// ─── Helpers ────────────────────────────────────────────────────────────────
54
/// Circularly shift a curve by `k` positions to the left.
///
/// Element `j` of the result is `f[(j + k) % f.len()]`. Shift amounts larger than the
/// curve length wrap around. An empty input yields an empty output.
fn circular_shift(f: &[f64], k: usize) -> Vec<f64> {
    let mut rotated = f.to_vec();
    if !rotated.is_empty() {
        let len = rotated.len();
        // `k % len < len`, so `rotate_left` cannot panic.
        rotated.rotate_left(k % len);
    }
    rotated
}
64
65/// Find the best circular rotation of `f2` to match `f1`, using a coarse-then-fine strategy.
66///
67/// Returns `(best_rotation, best_distance)`.
68fn find_best_rotation(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> (usize, f64) {
69    let m = f1.len();
70    if m < 2 {
71        return (0, 0.0);
72    }
73
74    // Coarse search: try m/step_size evenly spaced rotations
75    let step_size = (m / 20).max(1);
76    let mut best_k = 0;
77    let mut best_dist = f64::INFINITY;
78
79    let mut k = 0;
80    while k < m {
81        let f2_rot = circular_shift(f2, k);
82        let q1 = srsf_single(f1, argvals);
83        let q2 = srsf_single(&f2_rot, argvals);
84        let gamma = dp_alignment_core(&q1, &q2, argvals, lambda);
85        let f_aligned = reparameterize_curve(&f2_rot, argvals, &gamma);
86        let q_aligned = srsf_single(&f_aligned, argvals);
87        let weights = simpsons_weights(argvals);
88        let dist = l2_distance(&q1, &q_aligned, &weights);
89
90        if dist < best_dist {
91            best_dist = dist;
92            best_k = k;
93        }
94        k += step_size;
95    }
96
97    // Fine search: refine around best coarse rotation
98    let search_start = best_k.saturating_sub(step_size);
99    let search_end = (best_k + step_size).min(m);
100
101    for k in search_start..search_end {
102        if k % step_size == 0 {
103            continue; // already checked in coarse pass
104        }
105        let f2_rot = circular_shift(f2, k);
106        let q1 = srsf_single(f1, argvals);
107        let q2 = srsf_single(&f2_rot, argvals);
108        let gamma = dp_alignment_core(&q1, &q2, argvals, lambda);
109        let f_aligned = reparameterize_curve(&f2_rot, argvals, &gamma);
110        let q_aligned = srsf_single(&f_aligned, argvals);
111        let weights = simpsons_weights(argvals);
112        let dist = l2_distance(&q1, &q_aligned, &weights);
113
114        if dist < best_dist {
115            best_dist = dist;
116            best_k = k;
117        }
118    }
119
120    (best_k, best_dist)
121}
122
123// ─── Public Functions ───────────────────────────────────────────────────────
124
125/// Align closed curve `f2` to closed curve `f1` with rotation search.
126///
127/// For closed (periodic) curves, the starting point is arbitrary. This function
128/// searches over circular shifts of `f2` to find the rotation that minimizes
129/// the elastic distance, then performs full elastic alignment at that rotation.
130///
131/// # Arguments
132/// * `f1` - Target curve (length m)
133/// * `f2` - Curve to align (length m)
134/// * `argvals` - Evaluation points (length m)
135/// * `lambda` - Penalty weight on warp deviation from identity (0.0 = no penalty)
136///
137/// # Errors
138/// Returns `FdarError::InvalidDimension` if lengths do not match or m < 2.
139#[must_use = "expensive computation whose result should not be discarded"]
140pub fn elastic_align_pair_closed(
141    f1: &[f64],
142    f2: &[f64],
143    argvals: &[f64],
144    lambda: f64,
145) -> Result<ClosedAlignmentResult, FdarError> {
146    let m = f1.len();
147    if m != f2.len() || m != argvals.len() {
148        return Err(FdarError::InvalidDimension {
149            parameter: "f1/f2/argvals",
150            expected: format!("equal lengths, f1 has {m}"),
151            actual: format!("f2 has {}, argvals has {}", f2.len(), argvals.len()),
152        });
153    }
154    if m < 2 {
155        return Err(FdarError::InvalidDimension {
156            parameter: "f1",
157            expected: "length >= 2".to_string(),
158            actual: format!("length {m}"),
159        });
160    }
161
162    let (best_k, _) = find_best_rotation(f1, f2, argvals, lambda);
163
164    // Full alignment at the best rotation
165    let f2_rotated = circular_shift(f2, best_k);
166    let result = elastic_align_pair(f1, &f2_rotated, argvals, lambda);
167
168    Ok(ClosedAlignmentResult {
169        gamma: result.gamma,
170        f_aligned: result.f_aligned,
171        distance: result.distance,
172        optimal_rotation: best_k,
173    })
174}
175
176/// Compute the elastic distance between two closed curves.
177///
178/// Optimizes over circular rotations of `f2` to find the minimum elastic distance.
179///
180/// # Arguments
181/// * `f1` - First curve (length m)
182/// * `f2` - Second curve (length m)
183/// * `argvals` - Evaluation points (length m)
184/// * `lambda` - Penalty weight on warp deviation from identity (0.0 = no penalty)
185///
186/// # Errors
187/// Returns `FdarError::InvalidDimension` if lengths do not match or m < 2.
188#[must_use = "expensive computation whose result should not be discarded"]
189pub fn elastic_distance_closed(
190    f1: &[f64],
191    f2: &[f64],
192    argvals: &[f64],
193    lambda: f64,
194) -> Result<f64, FdarError> {
195    Ok(elastic_align_pair_closed(f1, f2, argvals, lambda)?.distance)
196}
197
198/// Compute the Karcher (Frechet) mean for closed curves.
199///
200/// Uses the standard Karcher mean iteration but with [`elastic_align_pair_closed`]
201/// at each step, tracking per-curve optimal rotations.
202///
203/// # Arguments
204/// * `data` - Functional data matrix (n x m)
205/// * `argvals` - Evaluation points (length m)
206/// * `max_iter` - Maximum number of iterations
207/// * `tol` - Convergence tolerance for the SRSF mean
208/// * `lambda` - Penalty weight on warp deviation from identity (0.0 = no penalty)
209///
210/// # Errors
211/// Returns `FdarError::InvalidDimension` if dimensions are inconsistent or m < 2.
212#[must_use = "expensive computation whose result should not be discarded"]
213pub fn karcher_mean_closed(
214    data: &FdMatrix,
215    argvals: &[f64],
216    max_iter: usize,
217    tol: f64,
218    lambda: f64,
219) -> Result<ClosedKarcherMeanResult, FdarError> {
220    let (n, m) = data.shape();
221    if m != argvals.len() {
222        return Err(FdarError::InvalidDimension {
223            parameter: "argvals",
224            expected: format!("length {m}"),
225            actual: format!("length {}", argvals.len()),
226        });
227    }
228    if m < 2 {
229        return Err(FdarError::InvalidDimension {
230            parameter: "data",
231            expected: "ncols >= 2".to_string(),
232            actual: format!("ncols = {m}"),
233        });
234    }
235    if n == 0 {
236        return Err(FdarError::InvalidDimension {
237            parameter: "data",
238            expected: "nrows > 0".to_string(),
239            actual: "nrows = 0".to_string(),
240        });
241    }
242
243    // Initialize mean as the first curve
244    let mut mu: Vec<f64> = data.row(0);
245    let mut mu_q = srsf_single(&mu, argvals);
246
247    let mut gammas = FdMatrix::zeros(n, m);
248    let mut rotations = vec![0usize; n];
249    let mut converged = false;
250    let mut n_iter = 0;
251
252    for iter in 0..max_iter {
253        n_iter = iter + 1;
254
255        // Align all curves to current mean using closed alignment
256        let align_results: Vec<(ClosedAlignmentResult, Vec<f64>)> = iter_maybe_parallel!(0..n)
257            .map(|i| {
258                let fi = data.row(i);
259                let res = elastic_align_pair_closed(&mu, &fi, argvals, lambda)
260                    .expect("dimension invariant: all curves have length m");
261                let q_warped = srsf_single(&res.f_aligned, argvals);
262                (res, q_warped)
263            })
264            .collect();
265
266        // Accumulate and compute new mean SRSF
267        let mut mu_q_new = vec![0.0; m];
268        for (i, (res, q_aligned)) in align_results.iter().enumerate() {
269            for j in 0..m {
270                gammas[(i, j)] = res.gamma[j];
271                mu_q_new[j] += q_aligned[j];
272            }
273            rotations[i] = res.optimal_rotation;
274        }
275        for j in 0..m {
276            mu_q_new[j] /= n as f64;
277        }
278
279        // Check convergence
280        let diff_norm: f64 = mu_q
281            .iter()
282            .zip(mu_q_new.iter())
283            .map(|(&a, &b)| (a - b).powi(2))
284            .sum::<f64>()
285            .sqrt();
286        let old_norm: f64 = mu_q.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
287        let rel = diff_norm / old_norm;
288
289        mu_q = mu_q_new;
290
291        if rel < tol {
292            converged = true;
293            break;
294        }
295
296        // Reconstruct mean curve from SRSF
297        mu = crate::alignment::srsf::srsf_inverse(&mu_q, argvals, mu[0]);
298    }
299
300    // Compute final aligned data
301    let mut aligned_data = FdMatrix::zeros(n, m);
302    for i in 0..n {
303        let fi = data.row(i);
304        let f_rotated = circular_shift(&fi, rotations[i]);
305        let gamma_i: Vec<f64> = (0..m).map(|j| gammas[(i, j)]).collect();
306        let f_aligned = reparameterize_curve(&f_rotated, argvals, &gamma_i);
307        for j in 0..m {
308            aligned_data[(i, j)] = f_aligned[j];
309        }
310    }
311
312    // Reconstruct final mean from SRSF
313    mu = crate::alignment::srsf::srsf_inverse(&mu_q, argvals, mu[0]);
314
315    Ok(ClosedKarcherMeanResult {
316        mean: mu,
317        mean_srsf: mu_q,
318        gammas,
319        aligned_data,
320        rotations,
321        n_iter,
322        converged,
323    })
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::uniform_grid;
    use std::f64::consts::PI;

    /// Evaluate `g` at every grid point.
    fn sample(argvals: &[f64], g: impl Fn(f64) -> f64) -> Vec<f64> {
        argvals.iter().map(|&t| g(t)).collect()
    }

    #[test]
    fn closed_align_identity() {
        let argvals = uniform_grid(30);
        let f = sample(&argvals, |t| (2.0 * PI * t).sin());

        let result = elastic_align_pair_closed(&f, &f, &argvals, 0.0).unwrap();
        assert!(
            result.distance < 0.1,
            "identical closed curves should have near-zero distance, got {}",
            result.distance
        );
        assert_eq!(
            result.optimal_rotation, 0,
            "identical curves should need no rotation"
        );
    }

    #[test]
    fn closed_align_shifted() {
        // The linear trend breaks periodicity, so the shift is unambiguous.
        let argvals = uniform_grid(40);
        let f1 = sample(&argvals, |t| (2.0 * PI * t).sin() + 0.5 * t);
        // A small circular shift keeps recovery reliable.
        let f2 = circular_shift(&f1, 5);

        let result = elastic_align_pair_closed(&f1, &f2, &argvals, 0.0).unwrap();
        assert!(
            result.distance < 1.0,
            "distance after closed alignment should be small, got {}",
            result.distance
        );
    }

    #[test]
    fn closed_distance_symmetric() {
        let argvals = uniform_grid(25);
        let f1 = sample(&argvals, |t| (2.0 * PI * t).sin());
        let f2 = sample(&argvals, |t| (2.0 * PI * t).cos());

        let d12 = elastic_distance_closed(&f1, &f2, &argvals, 0.0).unwrap();
        let d21 = elastic_distance_closed(&f2, &f1, &argvals, 0.0).unwrap();

        for (d, msg) in [(d12, "d12"), (d21, "d21")] {
            assert!(
                d >= 0.0 && d.is_finite(),
                "{msg} should be non-negative finite, got {d}"
            );
        }
        // The discrete rotation search makes closed alignment only approximately
        // symmetric, so only require the two distances to be comparable.
        assert!(
            d12.max(d21) < 2.0 * d12.min(d21) + 0.5,
            "closed distances should be in comparable range: d12={d12:.4}, d21={d21:.4}"
        );
    }

    #[test]
    fn closed_karcher_mean_smoke() {
        let (n, m) = (5, 25);
        let argvals = uniform_grid(m);

        // Five phase-shifted sine curves in column-major order: entry (i, j) is stored
        // at index i + j * n.
        let data_flat: Vec<f64> = (0..m)
            .flat_map(|j| {
                let t = argvals[j];
                (0..n).map(move |i| (2.0 * PI * (t + i as f64 * 0.1)).sin())
            })
            .collect();
        let data = FdMatrix::from_column_major(data_flat, n, m).unwrap();

        let result = karcher_mean_closed(&data, &argvals, 10, 1e-3, 0.0).unwrap();
        assert_eq!(result.mean.len(), m);
        assert_eq!(result.mean_srsf.len(), m);
        assert_eq!(result.gammas.shape(), (n, m));
        assert_eq!(result.aligned_data.shape(), (n, m));
        assert_eq!(result.rotations.len(), n);
        assert!(result.n_iter <= 10);
    }
}