Skip to main content

fdars_core/alignment/
geodesic.rs

1//! Geodesic interpolation between curves in the elastic metric.
2
3use super::dp_alignment_core;
4use super::nd::{elastic_align_pair_nd, srsf_transform_nd};
5use super::srsf::{reparameterize_curve, srsf_inverse, srsf_single};
6use crate::error::FdarError;
7use crate::helpers::{l2_distance, simpsons_weights};
8use crate::matrix::{FdCurveSet, FdMatrix};
9use crate::warping::{exp_map_sphere, gam_to_psi, inv_exp_map_sphere, normalize_warp, psi_to_gam};
10
11/// A geodesic path between two 1-D curves in the elastic metric.
12#[derive(Debug, Clone, PartialEq)]
13#[non_exhaustive]
14pub struct GeodesicPath {
15    /// Interpolated curves (n_points x m).
16    pub curves: FdMatrix,
17    /// Interpolated warps (n_points x m).
18    pub warps: FdMatrix,
19    /// Elastic distance from f1 at each interpolation point.
20    pub distances: Vec<f64>,
21    /// Parameter values t in \[0, 1\].
22    pub parameter_values: Vec<f64>,
23}
24
25/// A geodesic path between two N-D curves in the elastic metric.
26#[derive(Debug, Clone, PartialEq)]
27#[non_exhaustive]
28pub struct GeodesicPathNd {
29    /// Interpolated curves per dimension: d `FdMatrix`es, each (n_points x m).
30    pub curves: Vec<FdMatrix>,
31    /// Interpolated warps (n_points x m).
32    pub warps: FdMatrix,
33    /// Elastic distance from f1 at each interpolation point.
34    pub distances: Vec<f64>,
35    /// Parameter values t in \[0, 1\].
36    pub parameter_values: Vec<f64>,
37}
38
39/// Compute the geodesic path between two 1-D curves in the elastic metric.
40///
41/// The path is parameterized by `n_points` values in \[0, 1\]. At t=0 the path
42/// is at `f1`; at t=1 it is at the aligned version of `f2`. Interpolation
43/// proceeds separately in amplitude (linear in SRSF space) and phase
44/// (geodesic on the Hilbert sphere of warping functions).
45///
46/// # Arguments
47/// * `f1`       - First curve (length m).
48/// * `f2`       - Second curve (length m).
49/// * `argvals`  - Evaluation grid (length m).
50/// * `n_points` - Number of interpolation points (>= 2).
51/// * `lambda`   - Alignment penalty (0 = no penalty).
52///
53/// # Errors
54/// Returns `FdarError::InvalidDimension` if lengths mismatch or are < 2.
55/// Returns `FdarError::InvalidParameter` if `n_points < 2`.
56#[must_use = "expensive computation whose result should not be discarded"]
57pub fn curve_geodesic(
58    f1: &[f64],
59    f2: &[f64],
60    argvals: &[f64],
61    n_points: usize,
62    lambda: f64,
63) -> Result<GeodesicPath, FdarError> {
64    let m = f1.len();
65
66    // ── Validation ──────────────────────────────────────────────────────
67    if m < 2 {
68        return Err(FdarError::InvalidDimension {
69            parameter: "f1",
70            expected: "length >= 2".to_string(),
71            actual: format!("length {m}"),
72        });
73    }
74    if f2.len() != m {
75        return Err(FdarError::InvalidDimension {
76            parameter: "f2",
77            expected: format!("length {m}"),
78            actual: format!("length {}", f2.len()),
79        });
80    }
81    if argvals.len() != m {
82        return Err(FdarError::InvalidDimension {
83            parameter: "argvals",
84            expected: format!("length {m}"),
85            actual: format!("length {}", argvals.len()),
86        });
87    }
88    if n_points < 2 {
89        return Err(FdarError::InvalidParameter {
90            parameter: "n_points",
91            message: format!("must be >= 2, got {n_points}"),
92        });
93    }
94
95    // ── Align f2 to f1 ─────────────────────────────────────────────────
96    let q1 = srsf_single(f1, argvals);
97    let q2 = srsf_single(f2, argvals);
98    let gamma = dp_alignment_core(&q1, &q2, argvals, lambda);
99    let f2_aligned = reparameterize_curve(f2, argvals, &gamma);
100    let q2a = srsf_single(&f2_aligned, argvals);
101
102    // ── Phase geodesic setup ────────────────────────────────────────────
103    let t0 = argvals[0];
104    let domain = argvals[m - 1] - t0;
105    let time_01: Vec<f64> = (0..m).map(|j| (j as f64) / (m - 1) as f64).collect();
106    let binsize = 1.0 / (m - 1) as f64;
107
108    let gamma_01: Vec<f64> = gamma.iter().map(|&g| (g - t0) / domain).collect();
109    let psi = gam_to_psi(&gamma_01, binsize);
110    let psi_id = gam_to_psi(&time_01, binsize);
111    let v = inv_exp_map_sphere(&psi_id, &psi, &time_01);
112
113    // ── Integration weights for distance computation ────────────────────
114    let weights = simpsons_weights(argvals);
115
116    // ── Interpolation ───────────────────────────────────────────────────
117    let parameter_values: Vec<f64> = (0..n_points)
118        .map(|k| k as f64 / (n_points - 1) as f64)
119        .collect();
120
121    let mut curves = FdMatrix::zeros(n_points, m);
122    let mut warps = FdMatrix::zeros(n_points, m);
123    let mut distances = Vec::with_capacity(n_points);
124
125    for (k, &t_k) in parameter_values.iter().enumerate() {
126        // Phase: geodesic on the Hilbert sphere
127        let scaled_v: Vec<f64> = v.iter().map(|&vi| t_k * vi).collect();
128        let psi_k = exp_map_sphere(&psi_id, &scaled_v, &time_01);
129        let mut gamma_k_01 = psi_to_gam(&psi_k, &time_01);
130        // Rescale from [0,1] to original domain
131        for j in 0..m {
132            gamma_k_01[j] = t0 + gamma_k_01[j] * domain;
133        }
134        normalize_warp(&mut gamma_k_01, argvals);
135
136        // Amplitude: linear interpolation in SRSF space
137        let q_k: Vec<f64> = (0..m).map(|j| (1.0 - t_k) * q1[j] + t_k * q2a[j]).collect();
138
139        // Reconstruct curve from SRSF
140        let f0_k = (1.0 - t_k) * f1[0] + t_k * f2_aligned[0];
141        let f_k = srsf_inverse(&q_k, argvals, f0_k);
142
143        // L2 distance from q1 to q_k
144        let dist = l2_distance(&q1, &q_k, &weights);
145
146        for j in 0..m {
147            curves[(k, j)] = f_k[j];
148            warps[(k, j)] = gamma_k_01[j];
149        }
150        distances.push(dist);
151    }
152
153    Ok(GeodesicPath {
154        curves,
155        warps,
156        distances,
157        parameter_values,
158    })
159}
160
161/// Compute the geodesic path between two N-D curves in the elastic metric.
162///
163/// Similar to [`curve_geodesic`], but for multidimensional (R^d) curves.
164/// The warping function is shared across all dimensions.
165///
166/// # Arguments
167/// * `f1`       - First curve set (d dimensions, 1 curve each).
168/// * `f2`       - Second curve set (d dimensions, 1 curve each).
169/// * `argvals`  - Evaluation grid (length m).
170/// * `n_points` - Number of interpolation points (>= 2).
171/// * `lambda`   - Alignment penalty (0 = no penalty).
172///
173/// # Errors
174/// Returns `FdarError::InvalidDimension` if curve sets have inconsistent dimensions.
175/// Returns `FdarError::InvalidParameter` if `n_points < 2`.
176#[must_use = "expensive computation whose result should not be discarded"]
177pub fn curve_geodesic_nd(
178    f1: &FdCurveSet,
179    f2: &FdCurveSet,
180    argvals: &[f64],
181    n_points: usize,
182    lambda: f64,
183) -> Result<GeodesicPathNd, FdarError> {
184    let d = f1.ndim();
185    let m = f1.npoints();
186
187    // ── Validation ──────────────────────────────────────────────────────
188    if d == 0 {
189        return Err(FdarError::InvalidDimension {
190            parameter: "f1",
191            expected: "ndim >= 1".to_string(),
192            actual: "ndim 0".to_string(),
193        });
194    }
195    if f2.ndim() != d {
196        return Err(FdarError::InvalidDimension {
197            parameter: "f2",
198            expected: format!("ndim {d}"),
199            actual: format!("ndim {}", f2.ndim()),
200        });
201    }
202    if f2.npoints() != m {
203        return Err(FdarError::InvalidDimension {
204            parameter: "f2",
205            expected: format!("{m} points"),
206            actual: format!("{} points", f2.npoints()),
207        });
208    }
209    if m < 2 {
210        return Err(FdarError::InvalidDimension {
211            parameter: "f1",
212            expected: "npoints >= 2".to_string(),
213            actual: format!("npoints {m}"),
214        });
215    }
216    if argvals.len() != m {
217        return Err(FdarError::InvalidDimension {
218            parameter: "argvals",
219            expected: format!("length {m}"),
220            actual: format!("length {}", argvals.len()),
221        });
222    }
223    if n_points < 2 {
224        return Err(FdarError::InvalidParameter {
225            parameter: "n_points",
226            message: format!("must be >= 2, got {n_points}"),
227        });
228    }
229
230    // ── Align f2 to f1 ─────────────────────────────────────────────────
231    let result = elastic_align_pair_nd(f1, f2, argvals, lambda);
232    let gamma = &result.gamma;
233
234    // SRSFs of f1 and aligned f2 per dimension
235    let q1_set = srsf_transform_nd(f1, argvals);
236    let f2_aligned_set = {
237        let dims: Vec<FdMatrix> = result
238            .f_aligned
239            .iter()
240            .map(|fa| FdMatrix::from_slice(fa, 1, m).expect("dimension invariant"))
241            .collect();
242        FdCurveSet { dims }
243    };
244    let q2a_set = srsf_transform_nd(&f2_aligned_set, argvals);
245
246    let q1: Vec<Vec<f64>> = q1_set.dims.iter().map(|dm| dm.row(0)).collect();
247    let q2a: Vec<Vec<f64>> = q2a_set.dims.iter().map(|dm| dm.row(0)).collect();
248
249    // ── Phase geodesic setup (shared across dimensions) ─────────────────
250    let t0 = argvals[0];
251    let domain = argvals[m - 1] - t0;
252    let time_01: Vec<f64> = (0..m).map(|j| (j as f64) / (m - 1) as f64).collect();
253    let binsize = 1.0 / (m - 1) as f64;
254
255    let gamma_01: Vec<f64> = gamma.iter().map(|&g| (g - t0) / domain).collect();
256    let psi = gam_to_psi(&gamma_01, binsize);
257    let psi_id = gam_to_psi(&time_01, binsize);
258    let v = inv_exp_map_sphere(&psi_id, &psi, &time_01);
259
260    // ── Integration weights ─────────────────────────────────────────────
261    let weights = simpsons_weights(argvals);
262
263    // ── Interpolation ───────────────────────────────────────────────────
264    let parameter_values: Vec<f64> = (0..n_points)
265        .map(|k| k as f64 / (n_points - 1) as f64)
266        .collect();
267
268    let mut dim_curves: Vec<FdMatrix> = (0..d).map(|_| FdMatrix::zeros(n_points, m)).collect();
269    let mut warps_mat = FdMatrix::zeros(n_points, m);
270    let mut distances = Vec::with_capacity(n_points);
271
272    for (k, &t_k) in parameter_values.iter().enumerate() {
273        // Phase: geodesic on the Hilbert sphere
274        let scaled_v: Vec<f64> = v.iter().map(|&vi| t_k * vi).collect();
275        let psi_k = exp_map_sphere(&psi_id, &scaled_v, &time_01);
276        let mut gamma_k_01 = psi_to_gam(&psi_k, &time_01);
277        for j in 0..m {
278            gamma_k_01[j] = t0 + gamma_k_01[j] * domain;
279        }
280        normalize_warp(&mut gamma_k_01, argvals);
281
282        for j in 0..m {
283            warps_mat[(k, j)] = gamma_k_01[j];
284        }
285
286        // Per-dimension amplitude interpolation + reconstruction
287        let mut dist_sq = 0.0;
288        for dd in 0..d {
289            let q_k: Vec<f64> = (0..m)
290                .map(|j| (1.0 - t_k) * q1[dd][j] + t_k * q2a[dd][j])
291                .collect();
292
293            let f0_k = (1.0 - t_k) * f1.dims[dd][(0, 0)] + t_k * result.f_aligned[dd][0];
294            let f_k = srsf_inverse(&q_k, argvals, f0_k);
295
296            let d_k = l2_distance(&q1[dd], &q_k, &weights);
297            dist_sq += d_k * d_k;
298
299            for j in 0..m {
300                dim_curves[dd][(k, j)] = f_k[j];
301            }
302        }
303        distances.push(dist_sq.sqrt());
304    }
305
306    Ok(GeodesicPathNd {
307        curves: dim_curves,
308        warps: warps_mat,
309        distances,
310        parameter_values,
311    })
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::uniform_grid;
    use std::f64::consts::PI;

    /// Samples `sin(2*pi*t)` on `grid`.
    fn sine_wave(grid: &[f64]) -> Vec<f64> {
        grid.iter().map(|&ti| (2.0 * PI * ti).sin()).collect()
    }

    /// Largest pointwise absolute difference between two sampled curves.
    fn max_abs_diff(a: &[f64], b: &[f64]) -> f64 {
        a.iter()
            .zip(b.iter())
            .map(|(&x, &y)| (x - y).abs())
            .fold(0.0_f64, f64::max)
    }

    #[test]
    fn geodesic_endpoints_match() {
        let m = 51;
        let t = uniform_grid(m);
        let f1 = sine_wave(&t);
        let f2: Vec<f64> = t.iter().map(|&ti| (2.0 * PI * ti).cos()).collect();

        let path = curve_geodesic(&f1, &f2, &t, 5, 0.0).unwrap();

        // The first row of the path (t=0) should stay close to f1.
        let max_diff_start = max_abs_diff(&path.curves.row(0), &f1);
        assert!(
            max_diff_start < 0.5,
            "At t=0 curve should approximate f1, max diff = {max_diff_start}"
        );

        // The final row (t=1) corresponds to the aligned f2, not f2 itself,
        // so only its length is checked here.
        let final_row = path.parameter_values.len() - 1;
        let last_curve = path.curves.row(final_row);
        assert_eq!(last_curve.len(), m);
    }

    #[test]
    fn geodesic_distances_nonneg() {
        let m = 41;
        let t = uniform_grid(m);
        let f1 = sine_wave(&t);
        let f2: Vec<f64> = t
            .iter()
            .map(|&ti| 0.5 * (4.0 * PI * ti).sin())
            .collect();

        let path = curve_geodesic(&f1, &f2, &t, 6, 0.0).unwrap();
        for (k, &d) in path.distances.iter().enumerate() {
            assert!(d >= 0.0, "Distance at k={k} should be >= 0, got {d}");
        }
    }

    #[test]
    fn geodesic_identical_curves() {
        let m = 41;
        let t = uniform_grid(m);
        let f = sine_wave(&t);

        let path = curve_geodesic(&f, &f, &t, 4, 0.0).unwrap();

        // Every step of the path should remain close to f.
        for k in 0..path.parameter_values.len() {
            let max_diff = max_abs_diff(&path.curves.row(k), &f);
            assert!(
                max_diff < 0.5,
                "Identical curve geodesic: curve at k={k} deviates by {max_diff}"
            );
        }

        // The reported distances should stay small.
        for (k, &d) in path.distances.iter().enumerate() {
            assert!(
                d < 1.0,
                "Identical curve geodesic: distance at k={k} = {d}, expected near 0"
            );
        }
    }

    #[test]
    fn geodesic_nd_dimensions() {
        let m = 31;
        let t = uniform_grid(m);
        let d = 2;
        let n_points = 4;

        // f1 traces (sin, cos); f2 traces (t^2, t).
        let f1x = sine_wave(&t);
        let f1y: Vec<f64> = t.iter().map(|&ti| (2.0 * PI * ti).cos()).collect();
        let f2x: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
        let f2y: Vec<f64> = t.to_vec();

        let f1 = FdCurveSet::from_dims(vec![
            FdMatrix::from_slice(&f1x, 1, m).unwrap(),
            FdMatrix::from_slice(&f1y, 1, m).unwrap(),
        ])
        .unwrap();
        let f2 = FdCurveSet::from_dims(vec![
            FdMatrix::from_slice(&f2x, 1, m).unwrap(),
            FdMatrix::from_slice(&f2y, 1, m).unwrap(),
        ])
        .unwrap();

        let path = curve_geodesic_nd(&f1, &f2, &t, n_points, 0.0).unwrap();

        assert_eq!(path.curves.len(), d, "Should have d dimension matrices");
        for (dd, dim_mat) in path.curves.iter().enumerate() {
            assert_eq!(
                dim_mat.shape(),
                (n_points, m),
                "Dimension {dd} matrix shape mismatch"
            );
        }
        assert_eq!(path.warps.shape(), (n_points, m));
        assert_eq!(path.distances.len(), n_points);
        assert_eq!(path.parameter_values.len(), n_points);
    }
}