Skip to main content

fdars_core/alignment/
shape.rs

//! Elastic shape analysis: quotient space operations and orbit representatives.
//!
//! Extends elastic alignment to work in quotient spaces where curves are
//! considered equivalent up to reparameterization, translation, and/or scaling.

6use super::karcher::karcher_mean;
7use super::pairwise::{elastic_align_pair, elastic_self_distance_matrix};
8use super::srsf::srsf_single;
9use super::{AlignmentResult, KarcherMeanResult};
10use crate::error::FdarError;
11use crate::helpers::simpsons_weights;
12use crate::matrix::FdMatrix;
13use crate::warping::l2_norm_l2;
14
15// ─── Types ──────────────────────────────────────────────────────────────────
16
/// Selects the quotient space used for elastic shape comparisons.
///
/// Each variant names the group of transformations that is factored out before
/// two curves are compared. Reparameterization (domain warping) is always
/// removed; the later variants additionally remove a vertical offset and an
/// overall scale.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum ShapeQuotient {
    /// Factor out reparameterization only; this is the plain elastic distance.
    #[default]
    Reparameterization,
    /// Factor out reparameterization and a constant vertical shift.
    ReparameterizationTranslation,
    /// Factor out reparameterization, a constant vertical shift, and scale.
    ReparameterizationTranslationScale,
}
31
/// Canonical member of a shape orbit, plus the transformation that produced it.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct OrbitRepresentative {
    /// Curve after quotient pre-processing (centering and/or scaling).
    pub representative: Vec<f64>,
    /// Square-root slope function (SRSF) of `representative`.
    pub representative_srsf: Vec<f64>,
    /// Warping function; `orbit_representative` always reports the identity warp.
    pub gamma: Vec<f64>,
    /// Vertical offset subtracted from the input curve (`0.0` if none).
    pub translation: f64,
    /// Factor the centered curve was divided by (`1.0` if it was not rescaled).
    pub scale: f64,
}
47
/// Outcome of an elastic shape-distance computation between two curves.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ShapeDistanceResult {
    /// Distance between the two curves in the selected quotient space.
    pub distance: f64,
    /// Optimal warping function found by the alignment (length m).
    pub gamma: Vec<f64>,
    /// Second curve, after quotient pre-processing, warped into alignment with the first.
    pub f2_aligned: Vec<f64>,
}
59
60/// Result of computing the shape mean of a set of curves.
61#[derive(Debug, Clone, PartialEq)]
62#[non_exhaustive]
63pub struct ShapeMeanResult {
64    /// Shape mean curve.
65    pub mean: Vec<f64>,
66    /// SRSF of the shape mean.
67    pub mean_srsf: Vec<f64>,
68    /// Warping functions (n x m).
69    pub gammas: FdMatrix,
70    /// Curves aligned to the mean (n x m).
71    pub aligned_data: FdMatrix,
72    /// Number of iterations used.
73    pub n_iter: usize,
74    /// Whether the algorithm converged.
75    pub converged: bool,
76}
77
78// ─── Pre-processing Helpers ─────────────────────────────────────────────────
79
80/// Compute the integral mean of a curve using Simpson's weights.
81fn integral_mean(f: &[f64], argvals: &[f64]) -> f64 {
82    let w = simpsons_weights(argvals);
83    let total_w: f64 = w.iter().sum();
84    if total_w <= 0.0 {
85        return 0.0;
86    }
87    let wsum: f64 = f.iter().zip(w.iter()).map(|(&fi, &wi)| fi * wi).sum();
88    wsum / total_w
89}
90
91/// Pre-process a curve according to the quotient type.
92///
93/// Returns `(processed_curve, translation, scale)`.
94fn preprocess_curve(f: &[f64], argvals: &[f64], quotient: ShapeQuotient) -> (Vec<f64>, f64, f64) {
95    let mut curve = f.to_vec();
96    let mut translation = 0.0;
97    let mut scale = 1.0;
98
99    match quotient {
100        ShapeQuotient::Reparameterization => {
101            // No pre-processing needed.
102        }
103        ShapeQuotient::ReparameterizationTranslation => {
104            // Subtract integral mean.
105            let mean_val = integral_mean(&curve, argvals);
106            translation = mean_val;
107            for v in &mut curve {
108                *v -= mean_val;
109            }
110        }
111        ShapeQuotient::ReparameterizationTranslationScale => {
112            // Subtract integral mean, then scale by SRSF L2 norm.
113            let mean_val = integral_mean(&curve, argvals);
114            translation = mean_val;
115            for v in &mut curve {
116                *v -= mean_val;
117            }
118
119            // Compute L2 norm of the SRSF for scale normalization.
120            let q = srsf_single(&curve, argvals);
121            // Use a uniform [0,1] time grid for the L2 norm computation.
122            let m = argvals.len();
123            let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1).max(1) as f64).collect();
124            let norm = l2_norm_l2(&q, &time);
125
126            if norm > 1e-10 {
127                scale = norm;
128                for v in &mut curve {
129                    *v /= norm;
130                }
131            }
132        }
133    }
134
135    (curve, translation, scale)
136}
137
138/// Pre-process all rows of a data matrix according to the quotient type.
139fn preprocess_data(data: &FdMatrix, argvals: &[f64], quotient: ShapeQuotient) -> FdMatrix {
140    let (n, m) = data.shape();
141    let mut result = FdMatrix::zeros(n, m);
142    for i in 0..n {
143        let row = data.row(i);
144        let (processed, _, _) = preprocess_curve(&row, argvals, quotient);
145        for j in 0..m {
146            result[(i, j)] = processed[j];
147        }
148    }
149    result
150}
151
152// ─── Public API ─────────────────────────────────────────────────────────────
153
154/// Compute the canonical orbit representative of a curve.
155///
156/// Applies the quotient transformations (centering, scaling) and computes
157/// the SRSF of the result. The warping function is the identity.
158///
159/// # Arguments
160/// * `f`        — Curve values (length m).
161/// * `argvals`  — Evaluation points (length m).
162/// * `quotient` — Which transformations to factor out.
163///
164/// # Errors
165/// Returns [`FdarError::InvalidDimension`] if lengths do not match or `m < 2`.
166pub fn orbit_representative(
167    f: &[f64],
168    argvals: &[f64],
169    quotient: ShapeQuotient,
170) -> Result<OrbitRepresentative, FdarError> {
171    let m = f.len();
172    if m != argvals.len() {
173        return Err(FdarError::InvalidDimension {
174            parameter: "f",
175            expected: format!("length {}", argvals.len()),
176            actual: format!("length {m}"),
177        });
178    }
179    if m < 2 {
180        return Err(FdarError::InvalidDimension {
181            parameter: "f",
182            expected: "length >= 2".to_string(),
183            actual: format!("length {m}"),
184        });
185    }
186
187    let (representative, translation, scale) = preprocess_curve(f, argvals, quotient);
188    let representative_srsf = srsf_single(&representative, argvals);
189    let gamma = argvals.to_vec(); // identity warp
190
191    Ok(OrbitRepresentative {
192        representative,
193        representative_srsf,
194        gamma,
195        translation,
196        scale,
197    })
198}
199
200/// Compute the elastic shape distance between two curves.
201///
202/// Pre-processes both curves according to the quotient type, then computes
203/// the elastic distance after optimal alignment.
204///
205/// # Arguments
206/// * `f1`       — First curve (length m).
207/// * `f2`       — Second curve (length m).
208/// * `argvals`  — Evaluation points (length m).
209/// * `quotient` — Which transformations to factor out.
210/// * `lambda`   — Roughness penalty for alignment.
211///
212/// # Errors
213/// Returns [`FdarError::InvalidDimension`] if lengths do not match or `m < 2`.
214#[must_use = "expensive computation whose result should not be discarded"]
215pub fn shape_distance(
216    f1: &[f64],
217    f2: &[f64],
218    argvals: &[f64],
219    quotient: ShapeQuotient,
220    lambda: f64,
221) -> Result<ShapeDistanceResult, FdarError> {
222    let m = f1.len();
223    if m != f2.len() || m != argvals.len() {
224        return Err(FdarError::InvalidDimension {
225            parameter: "f1/f2",
226            expected: format!("matching lengths == argvals.len() ({})", argvals.len()),
227            actual: format!("f1.len()={}, f2.len()={}", f1.len(), f2.len()),
228        });
229    }
230    if m < 2 {
231        return Err(FdarError::InvalidDimension {
232            parameter: "f1",
233            expected: "length >= 2".to_string(),
234            actual: format!("length {m}"),
235        });
236    }
237
238    let (f1_pre, _, _) = preprocess_curve(f1, argvals, quotient);
239    let (f2_pre, _, _) = preprocess_curve(f2, argvals, quotient);
240
241    let AlignmentResult {
242        gamma,
243        f_aligned,
244        distance,
245    } = elastic_align_pair(&f1_pre, &f2_pre, argvals, lambda);
246
247    Ok(ShapeDistanceResult {
248        distance,
249        gamma,
250        f2_aligned: f_aligned,
251    })
252}
253
254/// Compute the pairwise shape distance matrix for a set of curves.
255///
256/// Pre-processes all curves according to the quotient type, then delegates to
257/// the elastic self-distance matrix computation.
258///
259/// # Arguments
260/// * `data`     — Functional data matrix (n x m).
261/// * `argvals`  — Evaluation points (length m).
262/// * `quotient` — Which transformations to factor out.
263/// * `lambda`   — Roughness penalty for alignment.
264///
265/// # Errors
266/// Returns [`FdarError::InvalidDimension`] if `argvals` length does not match `m`.
267#[must_use = "expensive computation whose result should not be discarded"]
268pub fn shape_self_distance_matrix(
269    data: &FdMatrix,
270    argvals: &[f64],
271    quotient: ShapeQuotient,
272    lambda: f64,
273) -> Result<FdMatrix, FdarError> {
274    let (_n, m) = data.shape();
275    if argvals.len() != m {
276        return Err(FdarError::InvalidDimension {
277            parameter: "argvals",
278            expected: format!("{m}"),
279            actual: format!("{}", argvals.len()),
280        });
281    }
282
283    let preprocessed = preprocess_data(data, argvals, quotient);
284    Ok(elastic_self_distance_matrix(&preprocessed, argvals, lambda))
285}
286
287/// Compute the Karcher (Frechet) mean in the elastic shape space.
288///
289/// Pre-processes all curves according to the quotient type, then computes
290/// the Karcher mean on the preprocessed data.
291///
292/// # Arguments
293/// * `data`     — Functional data matrix (n x m).
294/// * `argvals`  — Evaluation points (length m).
295/// * `quotient` — Which transformations to factor out.
296/// * `lambda`   — Roughness penalty for alignment.
297/// * `max_iter` — Maximum number of Karcher iterations.
298/// * `tol`      — Convergence tolerance.
299///
300/// # Errors
301/// Returns [`FdarError::InvalidDimension`] if `argvals` length does not match `m`
302/// or `n < 1`.
303#[must_use = "expensive computation whose result should not be discarded"]
304pub fn shape_mean(
305    data: &FdMatrix,
306    argvals: &[f64],
307    quotient: ShapeQuotient,
308    lambda: f64,
309    max_iter: usize,
310    tol: f64,
311) -> Result<ShapeMeanResult, FdarError> {
312    let (n, m) = data.shape();
313    if argvals.len() != m {
314        return Err(FdarError::InvalidDimension {
315            parameter: "argvals",
316            expected: format!("{m}"),
317            actual: format!("{}", argvals.len()),
318        });
319    }
320    if n < 1 {
321        return Err(FdarError::InvalidDimension {
322            parameter: "data",
323            expected: "at least 1 row".to_string(),
324            actual: format!("{n} rows"),
325        });
326    }
327
328    let preprocessed = preprocess_data(data, argvals, quotient);
329    let KarcherMeanResult {
330        mean,
331        mean_srsf,
332        gammas,
333        aligned_data,
334        n_iter,
335        converged,
336        ..
337    } = karcher_mean(&preprocessed, argvals, max_iter, tol, lambda);
338
339    Ok(ShapeMeanResult {
340        mean,
341        mean_srsf,
342        gammas,
343        aligned_data,
344        n_iter,
345        converged,
346    })
347}
348
349// ─── Tests ──────────────────────────────────────────────────────────────────
350
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simulation::{sim_fundata, EFunType, EValType};
    use crate::test_helpers::uniform_grid;

    /// Simulated Fourier-basis sample of `n` curves on a uniform grid of `m` points (seed 99).
    fn make_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let grid = uniform_grid(m);
        let sample = sim_fundata(n, &grid, 3, EFunType::Fourier, EValType::Exponential, Some(99));
        (sample, grid)
    }

    /// Evaluate `g` at every point of `grid`.
    fn sample_on(grid: &[f64], g: impl Fn(f64) -> f64) -> Vec<f64> {
        grid.iter().map(|&x| g(x)).collect()
    }

    /// Cosine similarity of two vectors; near-zero vectors count as perfectly similar.
    fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
        let dot: f64 = a.iter().zip(b).map(|(&x, &y)| x * y).sum();
        let euclid = |v: &[f64]| v.iter().map(|&x| x * x).sum::<f64>().sqrt();
        let (na, nb) = (euclid(a), euclid(b));
        if na > 1e-10 && nb > 1e-10 {
            dot / (na * nb)
        } else {
            1.0
        }
    }

    // ── orbit_representative ──

    #[test]
    fn orbit_representative_reparam_only() {
        let t = uniform_grid(30);
        let f = sample_on(&t, |x| (x * 6.0).sin());
        let rep = orbit_representative(&f, &t, ShapeQuotient::Reparameterization).unwrap();
        // No quotient transformation: the curve must come back untouched.
        assert_eq!(rep.representative.len(), 30);
        for (got, want) in rep.representative.iter().zip(&f) {
            assert!(
                (got - want).abs() < 1e-12,
                "reparameterization-only orbit should not change the curve"
            );
        }
        assert!(rep.translation.abs() < f64::EPSILON);
        assert!((rep.scale - 1.0).abs() < f64::EPSILON);
        assert_eq!(rep.gamma, t);
    }

    #[test]
    fn orbit_representative_translation() {
        let t = uniform_grid(30);
        let offset = 5.0;
        let f = sample_on(&t, |x| (x * 6.0).sin() + offset);
        let rep =
            orbit_representative(&f, &t, ShapeQuotient::ReparameterizationTranslation).unwrap();
        // Centering removes the integral mean.
        let mean_after = integral_mean(&rep.representative, &t);
        assert!(
            mean_after.abs() < 1e-10,
            "translation quotient should center the curve, mean={mean_after}"
        );
    }

    #[test]
    fn orbit_representative_translation_scale() {
        let t = uniform_grid(50);
        let quotient = ShapeQuotient::ReparameterizationTranslationScale;

        let f = sample_on(&t, |x| 10.0 * (x * 4.0).sin() + 3.0);
        let rep = orbit_representative(&f, &t, quotient).unwrap();
        assert!(rep.scale > 0.0, "scale factor should be positive");

        // The same shape at twice the amplitude should land on a proportional representative.
        let f2 = sample_on(&t, |x| 20.0 * (x * 4.0).sin() + 3.0);
        let rep2 = orbit_representative(&f2, &t, quotient).unwrap();

        let corr = cosine_similarity(&rep.representative, &rep2.representative);
        assert!(
            corr > 0.99,
            "scaled curves should have nearly identical representatives, corr={corr}"
        );
    }

    #[test]
    fn orbit_representative_length_mismatch() {
        let t = uniform_grid(30);
        let f = vec![1.0; 20];
        assert!(orbit_representative(&f, &t, ShapeQuotient::Reparameterization).is_err());
    }

    #[test]
    fn orbit_representative_too_short() {
        let (f, t) = (vec![1.0], vec![0.0]);
        assert!(orbit_representative(&f, &t, ShapeQuotient::Reparameterization).is_err());
    }

    // ── shape_distance ──

    #[test]
    fn shape_distance_identical_curves() {
        let t = uniform_grid(30);
        let f = sample_on(&t, |x| (x * 6.0).sin());
        let result = shape_distance(&f, &f, &t, ShapeQuotient::Reparameterization, 0.0).unwrap();
        assert!(
            result.distance < 0.1,
            "distance between identical curves should be near zero, got {}",
            result.distance
        );
        assert_eq!(result.gamma.len(), 30);
        assert_eq!(result.f2_aligned.len(), 30);
    }

    #[test]
    fn shape_distance_translated_curves() {
        let t = uniform_grid(30);
        let f1 = sample_on(&t, |x| (x * 6.0).sin());
        let f2 = sample_on(&t, |x| (x * 6.0).sin() + 5.0);

        let distance_with = |quotient| shape_distance(&f1, &f2, &t, quotient, 0.0).unwrap();
        // Without the translation quotient the offset contributes to the distance;
        // with it, the offset is factored out.
        let d_no_trans = distance_with(ShapeQuotient::Reparameterization);
        let d_trans = distance_with(ShapeQuotient::ReparameterizationTranslation);

        assert!(
            d_trans.distance < d_no_trans.distance + 0.01,
            "translation quotient should not increase distance: d_trans={}, d_no_trans={}",
            d_trans.distance,
            d_no_trans.distance
        );
    }

    #[test]
    fn shape_distance_length_mismatch() {
        let t = uniform_grid(30);
        let (f1, f2) = (vec![0.0; 30], vec![0.0; 20]);
        assert!(shape_distance(&f1, &f2, &t, ShapeQuotient::Reparameterization, 0.0).is_err());
    }

    // ── shape_self_distance_matrix ──

    #[test]
    fn shape_distance_matrix_smoke() {
        let n = 5;
        let (data, t) = make_data(n, 20);
        let dmat =
            shape_self_distance_matrix(&data, &t, ShapeQuotient::Reparameterization, 0.0).unwrap();
        assert_eq!(dmat.shape(), (n, n));

        // Zero diagonal.
        for i in 0..n {
            assert!(
                dmat[(i, i)].abs() < 1e-10,
                "diagonal should be zero, got {}",
                dmat[(i, i)]
            );
        }
        // Symmetric off-diagonal entries.
        for (i, j) in (0..n).flat_map(|i| ((i + 1)..n).map(move |j| (i, j))) {
            assert!(
                (dmat[(i, j)] - dmat[(j, i)]).abs() < 1e-10,
                "distance matrix should be symmetric"
            );
        }
    }

    #[test]
    fn shape_distance_matrix_argvals_mismatch() {
        let (data, _) = make_data(5, 20);
        let bad_t = uniform_grid(15);
        let outcome =
            shape_self_distance_matrix(&data, &bad_t, ShapeQuotient::Reparameterization, 0.0);
        assert!(outcome.is_err());
    }

    // ── shape_mean ──

    #[test]
    fn shape_mean_smoke() {
        let (data, t) = make_data(6, 25);
        let result =
            shape_mean(&data, &t, ShapeQuotient::Reparameterization, 0.0, 5, 1e-2).unwrap();
        assert_eq!(result.mean.len(), 25);
        assert_eq!(result.mean_srsf.len(), 25);
        assert_eq!(result.gammas.shape(), (6, 25));
        assert_eq!(result.aligned_data.shape(), (6, 25));
        assert!(result.n_iter >= 1);
    }

    #[test]
    fn shape_mean_translation_quotient() {
        let (data, t) = make_data(6, 25);
        let quotient = ShapeQuotient::ReparameterizationTranslation;
        let result = shape_mean(&data, &t, quotient, 0.0, 5, 1e-2).unwrap();
        assert_eq!(result.mean.len(), 25);
    }

    #[test]
    fn shape_mean_full_quotient() {
        let (data, t) = make_data(6, 25);
        let quotient = ShapeQuotient::ReparameterizationTranslationScale;
        let result = shape_mean(&data, &t, quotient, 0.0, 5, 1e-2).unwrap();
        assert_eq!(result.mean.len(), 25);
    }

    #[test]
    fn shape_mean_argvals_mismatch() {
        let (data, _) = make_data(5, 25);
        let bad_t = uniform_grid(15);
        let outcome = shape_mean(
            &data,
            &bad_t,
            ShapeQuotient::Reparameterization,
            0.0,
            5,
            1e-2,
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn default_quotient() {
        assert_eq!(ShapeQuotient::default(), ShapeQuotient::Reparameterization);
    }
}