Skip to main content

fdars_core/alignment/
tsrvf.rs

1//! TSRVF (Transported SRSF) transforms and parallel transport methods.
2//!
3//! Maps aligned SRSFs to the tangent space of the Karcher mean on the Hilbert
4//! sphere. Tangent vectors live in a standard Euclidean space, enabling PCA,
5//! regression, and clustering on elastic-aligned curves.
6
7use super::karcher::karcher_mean;
8use super::srsf::{srsf_inverse, srsf_transform};
9use super::KarcherMeanResult;
10use crate::iter_maybe_parallel;
11use crate::matrix::FdMatrix;
12use crate::smoothing::nadaraya_watson;
13use crate::warping::{exp_map_sphere, inv_exp_map_sphere, l2_norm_l2};
14#[cfg(feature = "parallel")]
15use rayon::iter::ParallelIterator;
16
17// ─── Types ──────────────────────────────────────────────────────────────────
18
/// Result of the TSRVF transform.
///
/// Produced by [`tsrvf_transform`], [`tsrvf_from_alignment`] and their
/// `_with_method` variants; consumed by [`tsrvf_inverse`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct TsrvfResult {
    /// Tangent vectors (n × m), one row per curve, expressed in the tangent space
    /// at the unit-normalized, smoothed mean SRSF (`mean_srsf / mean_srsf_norm`).
    pub tangent_vectors: FdMatrix,
    /// Karcher mean curve (length m), copied from the alignment result.
    pub mean: Vec<f64>,
    /// SRSF of the Karcher mean after Nadaraya–Watson smoothing (Gaussian kernel,
    /// bandwidth = 2 grid spacings). This is the smoothed SRSF, not the raw one
    /// stored in the Karcher mean result.
    pub mean_srsf: Vec<f64>,
    /// L2 norm of `mean_srsf`, computed on a uniform grid over [0, 1].
    pub mean_srsf_norm: f64,
    /// Per-curve L2 norms (length n) of the smoothed aligned SRSFs, computed on a
    /// uniform grid over [0, 1]; used to rescale back off the unit sphere.
    pub srsf_norms: Vec<f64>,
    /// Per-curve initial values f_i(0) (length n), read from the aligned data,
    /// used as the starting value for SRSF inverse reconstruction.
    pub initial_values: Vec<f64>,
    /// Warping functions from Karcher mean computation (n × m).
    pub gammas: FdMatrix,
    /// Whether the Karcher mean iteration converged.
    pub converged: bool,
}
40
/// Method for transporting tangent vectors on the Hilbert sphere.
///
/// Selects how each curve's tangent vector at the mean SRSF is obtained in
/// [`tsrvf_from_alignment_with_method`] and [`tsrvf_transform_with_method`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum TransportMethod {
    /// Inverse exponential map (log map) from the mean to each curve — the default;
    /// identical to [`tsrvf_from_alignment`].
    #[default]
    LogMap,
    /// Transport `-log_qi(mu)` from each curve `qi` to the mean `mu` with a single
    /// rung of Schild's ladder (an approximation to parallel transport).
    SchildsLadder,
    /// Same as `SchildsLadder`, but using a single rung of the pole ladder.
    PoleLadder,
}
53
54// ─── Smoothing ──────────────────────────────────────────────────────────────
55
56/// Smooth aligned SRSFs to remove DP kink artifacts before TSRVF computation.
57///
58/// Uses Nadaraya-Watson kernel smoothing (Gaussian, bandwidth = 2 grid spacings)
59/// on each SRSF row. This removes the derivative spikes from DP warp kinks
60/// without affecting alignment results or the Karcher mean.
61fn smooth_aligned_srsfs(srsf: &FdMatrix, m: usize) -> FdMatrix {
62    let n = srsf.nrows();
63    let time: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
64    let bandwidth = 2.0 / (m - 1) as f64;
65
66    let mut smoothed = FdMatrix::zeros(n, m);
67    for i in 0..n {
68        let qi = srsf.row(i);
69        // nadaraya_watson cannot fail here: time/qi are non-empty (m > 0) and bandwidth > 0.
70        let qi_smooth = nadaraya_watson(&time, &qi, &time, bandwidth, "gaussian")
71            .expect("smoothing valid SRSF data should not fail");
72        for j in 0..m {
73            smoothed[(i, j)] = qi_smooth[j];
74        }
75    }
76    smoothed
77}
78
79// ─── Parallel Transport ─────────────────────────────────────────────────────
80
81/// Schild's ladder parallel transport of vector `v` from `from` to `to` on the sphere.
82pub(super) fn parallel_transport_schilds(
83    v: &[f64],
84    from: &[f64],
85    to: &[f64],
86    time: &[f64],
87) -> Vec<f64> {
88    let v_norm = l2_norm_l2(v, time);
89    if v_norm < 1e-10 {
90        return vec![0.0; v.len()];
91    }
92
93    // endpoint = exp_from(v)
94    let endpoint = exp_map_sphere(from, v, time);
95
96    // midpoint_v = log_to(endpoint) — vector at `to` pointing toward endpoint
97    let log_to_ep = inv_exp_map_sphere(to, &endpoint, time);
98
99    // midpoint = exp_to(0.5 * log_to_ep)
100    let half_log: Vec<f64> = log_to_ep.iter().map(|&x| 0.5 * x).collect();
101    let midpoint = exp_map_sphere(to, &half_log, time);
102
103    // transported = 2 * log_to(midpoint)
104    let log_to_mid = inv_exp_map_sphere(to, &midpoint, time);
105    log_to_mid.iter().map(|&x| 2.0 * x).collect()
106}
107
108/// Pole ladder parallel transport of vector `v` from `from` to `to` on the sphere.
109pub(super) fn parallel_transport_pole(
110    v: &[f64],
111    from: &[f64],
112    to: &[f64],
113    time: &[f64],
114) -> Vec<f64> {
115    let v_norm = l2_norm_l2(v, time);
116    if v_norm < 1e-10 {
117        return vec![0.0; v.len()];
118    }
119
120    // pole = exp_from(-v)
121    let neg_v: Vec<f64> = v.iter().map(|&x| -x).collect();
122    let pole = exp_map_sphere(from, &neg_v, time);
123
124    // midpoint_v = log_to(pole)
125    let log_to_pole = inv_exp_map_sphere(to, &pole, time);
126
127    // midpoint = exp_to(0.5 * log_to_pole)
128    let half_log: Vec<f64> = log_to_pole.iter().map(|&x| 0.5 * x).collect();
129    let midpoint = exp_map_sphere(to, &half_log, time);
130
131    // transported = -2 * log_to(midpoint)
132    let log_to_mid = inv_exp_map_sphere(to, &midpoint, time);
133    log_to_mid.iter().map(|&x| -2.0 * x).collect()
134}
135
136// ─── TSRVF Public API ───────────────────────────────────────────────────────
137
138/// Full TSRVF pipeline: compute Karcher mean, then transport SRSFs to tangent space.
139///
140/// # Arguments
141/// * `data` — Functional data matrix (n × m)
142/// * `argvals` — Evaluation points (length m)
143/// * `max_iter` — Maximum Karcher mean iterations
144/// * `tol` — Convergence tolerance for Karcher mean
145///
146/// # Returns
147/// [`TsrvfResult`] containing tangent vectors and associated metadata.
148pub fn tsrvf_transform(
149    data: &FdMatrix,
150    argvals: &[f64],
151    max_iter: usize,
152    tol: f64,
153    lambda: f64,
154) -> TsrvfResult {
155    let karcher = karcher_mean(data, argvals, max_iter, tol, lambda);
156    tsrvf_from_alignment(&karcher, argvals)
157}
158
159/// Compute TSRVF from a pre-computed Karcher mean alignment.
160///
161/// Avoids re-running the expensive Karcher mean computation when the alignment
162/// has already been computed.
163///
164/// # Arguments
165/// * `karcher` — Pre-computed Karcher mean result
166/// * `argvals` — Evaluation points (length m)
167///
168/// # Returns
169/// [`TsrvfResult`] containing tangent vectors and associated metadata.
170pub fn tsrvf_from_alignment(karcher: &KarcherMeanResult, argvals: &[f64]) -> TsrvfResult {
171    let (n, m) = karcher.aligned_data.shape();
172
173    // Step 1: Compute SRSFs of aligned data
174    let aligned_srsf = srsf_transform(&karcher.aligned_data, argvals);
175
176    // Step 1b: Smooth aligned SRSFs to remove DP kink artifacts.
177    //
178    // DP alignment produces piecewise-linear warps with kinks at grid transitions.
179    // When curves are reparameterized by these warps, the kinks propagate into the
180    // aligned curves' derivatives (SRSFs), creating spikes that dominate TSRVF
181    // tangent vectors and PCA.
182    //
183    // R's fdasrvf does not smooth here and suffers from the same spike artifacts.
184    // Python's fdasrsf mitigates this via spline smoothing (s=1e-4) in SqrtMean.
185    // We smooth the aligned SRSFs before tangent vector computation — this only
186    // affects TSRVF output and does not change the alignment or Karcher mean.
187    let aligned_srsf = smooth_aligned_srsfs(&aligned_srsf, m);
188
189    // Step 2: Smooth and normalize mean SRSF to unit sphere.
190    // The mean SRSF must be smoothed consistently with the aligned SRSFs
191    // so that a single curve (which IS the mean) produces a zero tangent vector.
192    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
193    let bandwidth = 2.0 / (m - 1) as f64;
194    let mean_srsf_smooth = nadaraya_watson(&time, &karcher.mean_srsf, &time, bandwidth, "gaussian")
195        .expect("smoothing valid mean SRSF should not fail");
196    let mean_norm = l2_norm_l2(&mean_srsf_smooth, &time);
197
198    let mu_unit: Vec<f64> = if mean_norm > 1e-10 {
199        mean_srsf_smooth.iter().map(|&q| q / mean_norm).collect()
200    } else {
201        vec![0.0; m]
202    };
203
204    // Step 3: For each aligned curve, compute tangent vector via inverse exponential map
205    let srsf_norms: Vec<f64> = iter_maybe_parallel!(0..n)
206        .map(|i| {
207            let qi = aligned_srsf.row(i);
208            l2_norm_l2(&qi, &time)
209        })
210        .collect();
211
212    let tangent_data: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
213        .map(|i| {
214            let qi = aligned_srsf.row(i);
215            let qi_norm = srsf_norms[i];
216
217            if qi_norm < 1e-10 || mean_norm < 1e-10 {
218                return vec![0.0; m];
219            }
220
221            // Normalize to unit sphere
222            let qi_unit: Vec<f64> = qi.iter().map(|&q| q / qi_norm).collect();
223
224            // Shooting vector from mu_unit to qi_unit
225            inv_exp_map_sphere(&mu_unit, &qi_unit, &time)
226        })
227        .collect();
228
229    // Assemble tangent vectors into FdMatrix
230    let mut tangent_vectors = FdMatrix::zeros(n, m);
231    for i in 0..n {
232        for j in 0..m {
233            tangent_vectors[(i, j)] = tangent_data[i][j];
234        }
235    }
236
237    // Store per-curve initial values for SRSF inverse reconstruction.
238    // Warping preserves f_i(0) since gamma(0) = 0.
239    let initial_values: Vec<f64> = (0..n).map(|i| karcher.aligned_data[(i, 0)]).collect();
240
241    TsrvfResult {
242        tangent_vectors,
243        mean: karcher.mean.clone(),
244        mean_srsf: mean_srsf_smooth,
245        mean_srsf_norm: mean_norm,
246        srsf_norms,
247        initial_values,
248        gammas: karcher.gammas.clone(),
249        converged: karcher.converged,
250    }
251}
252
253/// Reconstruct aligned curves from TSRVF tangent vectors.
254///
255/// Inverts the TSRVF transform: maps tangent vectors back to the Hilbert sphere
256/// via the exponential map, rescales, and reconstructs curves via SRSF inverse.
257///
258/// # Arguments
259/// * `tsrvf` — TSRVF result from [`tsrvf_transform`] or [`tsrvf_from_alignment`]
260/// * `argvals` — Evaluation points (length m)
261///
262/// # Returns
263/// FdMatrix of reconstructed aligned curves (n × m).
264pub fn tsrvf_inverse(tsrvf: &TsrvfResult, argvals: &[f64]) -> FdMatrix {
265    let (n, m) = tsrvf.tangent_vectors.shape();
266
267    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
268
269    // Normalize mean SRSF to unit sphere
270    let mu_unit: Vec<f64> = if tsrvf.mean_srsf_norm > 1e-10 {
271        tsrvf
272            .mean_srsf
273            .iter()
274            .map(|&q| q / tsrvf.mean_srsf_norm)
275            .collect()
276    } else {
277        vec![0.0; m]
278    };
279
280    let curves: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
281        .map(|i| {
282            let vi = tsrvf.tangent_vectors.row(i);
283
284            // Map back to sphere: exp_map(mu_unit, v_i)
285            let qi_unit = exp_map_sphere(&mu_unit, &vi, &time);
286
287            // Rescale by original norm
288            let qi: Vec<f64> = qi_unit.iter().map(|&q| q * tsrvf.srsf_norms[i]).collect();
289
290            // Reconstruct curve from SRSF using per-curve initial value
291            srsf_inverse(&qi, argvals, tsrvf.initial_values[i])
292        })
293        .collect();
294
295    let mut result = FdMatrix::zeros(n, m);
296    for i in 0..n {
297        for j in 0..m {
298            result[(i, j)] = curves[i][j];
299        }
300    }
301    result
302}
303
304/// Full TSRVF pipeline with configurable transport method.
305///
306/// Like [`tsrvf_transform`] but allows choosing the parallel transport method.
307pub fn tsrvf_transform_with_method(
308    data: &FdMatrix,
309    argvals: &[f64],
310    max_iter: usize,
311    tol: f64,
312    lambda: f64,
313    method: TransportMethod,
314) -> TsrvfResult {
315    let karcher = karcher_mean(data, argvals, max_iter, tol, lambda);
316    tsrvf_from_alignment_with_method(&karcher, argvals, method)
317}
318
319/// Compute TSRVF from a pre-computed Karcher mean with configurable transport.
320///
321/// - [`TransportMethod::LogMap`]: Uses `inv_exp_map(mu, qi)` directly (standard TSRVF).
322/// - [`TransportMethod::SchildsLadder`]: Computes `v = -log_qi(mu)`, then transports
323///   via Schild's ladder from qi to mu.
324/// - [`TransportMethod::PoleLadder`]: Same but via pole ladder.
325pub fn tsrvf_from_alignment_with_method(
326    karcher: &KarcherMeanResult,
327    argvals: &[f64],
328    method: TransportMethod,
329) -> TsrvfResult {
330    if method == TransportMethod::LogMap {
331        return tsrvf_from_alignment(karcher, argvals);
332    }
333
334    let (n, m) = karcher.aligned_data.shape();
335    let aligned_srsf = srsf_transform(&karcher.aligned_data, argvals);
336    let aligned_srsf = smooth_aligned_srsfs(&aligned_srsf, m);
337    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
338    let bandwidth = 2.0 / (m - 1) as f64;
339    let mean_srsf_smooth = nadaraya_watson(&time, &karcher.mean_srsf, &time, bandwidth, "gaussian")
340        .expect("smoothing valid mean SRSF should not fail");
341    let mean_norm = l2_norm_l2(&mean_srsf_smooth, &time);
342
343    let mu_unit: Vec<f64> = if mean_norm > 1e-10 {
344        mean_srsf_smooth.iter().map(|&q| q / mean_norm).collect()
345    } else {
346        vec![0.0; m]
347    };
348
349    let srsf_norms: Vec<f64> = iter_maybe_parallel!(0..n)
350        .map(|i| {
351            let qi = aligned_srsf.row(i);
352            l2_norm_l2(&qi, &time)
353        })
354        .collect();
355
356    let tangent_data: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
357        .map(|i| {
358            let qi = aligned_srsf.row(i);
359            let qi_norm = srsf_norms[i];
360
361            if qi_norm < 1e-10 || mean_norm < 1e-10 {
362                return vec![0.0; m];
363            }
364
365            let qi_unit: Vec<f64> = qi.iter().map(|&q| q / qi_norm).collect();
366
367            // Compute v = -log_qi(mu) — vector at qi pointing away from mu
368            let v_at_qi = inv_exp_map_sphere(&qi_unit, &mu_unit, &time);
369            let neg_v: Vec<f64> = v_at_qi.iter().map(|&x| -x).collect();
370
371            // Transport from qi to mu
372            match method {
373                TransportMethod::SchildsLadder => {
374                    parallel_transport_schilds(&neg_v, &qi_unit, &mu_unit, &time)
375                }
376                TransportMethod::PoleLadder => {
377                    parallel_transport_pole(&neg_v, &qi_unit, &mu_unit, &time)
378                }
379                TransportMethod::LogMap => unreachable!(),
380            }
381        })
382        .collect();
383
384    let mut tangent_vectors = FdMatrix::zeros(n, m);
385    for i in 0..n {
386        for j in 0..m {
387            tangent_vectors[(i, j)] = tangent_data[i][j];
388        }
389    }
390
391    let initial_values: Vec<f64> = (0..n).map(|i| karcher.aligned_data[(i, 0)]).collect();
392
393    TsrvfResult {
394        tangent_vectors,
395        mean: karcher.mean.clone(),
396        mean_srsf: mean_srsf_smooth,
397        mean_srsf_norm: mean_norm,
398        srsf_norms,
399        initial_values,
400        gammas: karcher.gammas.clone(),
401        converged: karcher.converged,
402    }
403}