Skip to main content

fdars_core/alignment/
tsrvf.rs

1//! TSRVF (Transported SRSF) transforms and parallel transport methods.
2//!
3//! Maps aligned SRSFs to the tangent space of the Karcher mean on the Hilbert
4//! sphere. Tangent vectors live in a standard Euclidean space, enabling PCA,
5//! regression, and clustering on elastic-aligned curves.
6
7use super::karcher::karcher_mean;
8use super::srsf::{srsf_inverse, srsf_transform};
9use super::KarcherMeanResult;
10use crate::iter_maybe_parallel;
11use crate::matrix::FdMatrix;
12use crate::smoothing::nadaraya_watson;
13use crate::warping::{exp_map_sphere, inv_exp_map_sphere, l2_norm_l2};
14#[cfg(feature = "parallel")]
15use rayon::iter::ParallelIterator;
16
17// ─── Types ──────────────────────────────────────────────────────────────────
18
19/// Result of the TSRVF transform.
20#[derive(Debug, Clone, PartialEq)]
21pub struct TsrvfResult {
22    /// Tangent vectors in Euclidean space (n × m).
23    pub tangent_vectors: FdMatrix,
24    /// Karcher mean curve (length m).
25    pub mean: Vec<f64>,
26    /// SRSF of the Karcher mean (length m).
27    pub mean_srsf: Vec<f64>,
28    /// L2 norm of the mean SRSF.
29    pub mean_srsf_norm: f64,
30    /// Per-curve aligned SRSF norms (length n).
31    pub srsf_norms: Vec<f64>,
32    /// Per-curve initial values f_i(0) for SRSF inverse reconstruction (length n).
33    pub initial_values: Vec<f64>,
34    /// Warping functions from Karcher mean computation (n × m).
35    pub gammas: FdMatrix,
36    /// Whether the Karcher mean converged.
37    pub converged: bool,
38}
39
/// Strategy for carrying a curve's tangent vector into the tangent space at the
/// Karcher mean on the Hilbert sphere.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum TransportMethod {
    /// Inverse exponential (log) map taken at the mean. This is the default and
    /// reproduces the classic TSRVF construction.
    #[default]
    LogMap,
    /// Approximate parallel transport using Schild's ladder.
    SchildsLadder,
    /// Approximate parallel transport using the pole ladder.
    PoleLadder,
}
51
52// ─── Smoothing ──────────────────────────────────────────────────────────────
53
54/// Smooth aligned SRSFs to remove DP kink artifacts before TSRVF computation.
55///
56/// Uses Nadaraya-Watson kernel smoothing (Gaussian, bandwidth = 2 grid spacings)
57/// on each SRSF row. This removes the derivative spikes from DP warp kinks
58/// without affecting alignment results or the Karcher mean.
59fn smooth_aligned_srsfs(srsf: &FdMatrix, m: usize) -> FdMatrix {
60    let n = srsf.nrows();
61    let time: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
62    let bandwidth = 2.0 / (m - 1) as f64;
63
64    let mut smoothed = FdMatrix::zeros(n, m);
65    for i in 0..n {
66        let qi = srsf.row(i);
67        let qi_smooth = nadaraya_watson(&time, &qi, &time, bandwidth, "gaussian");
68        for j in 0..m {
69            smoothed[(i, j)] = qi_smooth[j];
70        }
71    }
72    smoothed
73}
74
75// ─── Parallel Transport ─────────────────────────────────────────────────────
76
77/// Schild's ladder parallel transport of vector `v` from `from` to `to` on the sphere.
78pub(super) fn parallel_transport_schilds(
79    v: &[f64],
80    from: &[f64],
81    to: &[f64],
82    time: &[f64],
83) -> Vec<f64> {
84    let v_norm = l2_norm_l2(v, time);
85    if v_norm < 1e-10 {
86        return vec![0.0; v.len()];
87    }
88
89    // endpoint = exp_from(v)
90    let endpoint = exp_map_sphere(from, v, time);
91
92    // midpoint_v = log_to(endpoint) — vector at `to` pointing toward endpoint
93    let log_to_ep = inv_exp_map_sphere(to, &endpoint, time);
94
95    // midpoint = exp_to(0.5 * log_to_ep)
96    let half_log: Vec<f64> = log_to_ep.iter().map(|&x| 0.5 * x).collect();
97    let midpoint = exp_map_sphere(to, &half_log, time);
98
99    // transported = 2 * log_to(midpoint)
100    let log_to_mid = inv_exp_map_sphere(to, &midpoint, time);
101    log_to_mid.iter().map(|&x| 2.0 * x).collect()
102}
103
104/// Pole ladder parallel transport of vector `v` from `from` to `to` on the sphere.
105pub(super) fn parallel_transport_pole(
106    v: &[f64],
107    from: &[f64],
108    to: &[f64],
109    time: &[f64],
110) -> Vec<f64> {
111    let v_norm = l2_norm_l2(v, time);
112    if v_norm < 1e-10 {
113        return vec![0.0; v.len()];
114    }
115
116    // pole = exp_from(-v)
117    let neg_v: Vec<f64> = v.iter().map(|&x| -x).collect();
118    let pole = exp_map_sphere(from, &neg_v, time);
119
120    // midpoint_v = log_to(pole)
121    let log_to_pole = inv_exp_map_sphere(to, &pole, time);
122
123    // midpoint = exp_to(0.5 * log_to_pole)
124    let half_log: Vec<f64> = log_to_pole.iter().map(|&x| 0.5 * x).collect();
125    let midpoint = exp_map_sphere(to, &half_log, time);
126
127    // transported = -2 * log_to(midpoint)
128    let log_to_mid = inv_exp_map_sphere(to, &midpoint, time);
129    log_to_mid.iter().map(|&x| -2.0 * x).collect()
130}
131
132// ─── TSRVF Public API ───────────────────────────────────────────────────────
133
134/// Full TSRVF pipeline: compute Karcher mean, then transport SRSFs to tangent space.
135///
136/// # Arguments
137/// * `data` — Functional data matrix (n × m)
138/// * `argvals` — Evaluation points (length m)
139/// * `max_iter` — Maximum Karcher mean iterations
140/// * `tol` — Convergence tolerance for Karcher mean
141///
142/// # Returns
143/// [`TsrvfResult`] containing tangent vectors and associated metadata.
144pub fn tsrvf_transform(
145    data: &FdMatrix,
146    argvals: &[f64],
147    max_iter: usize,
148    tol: f64,
149    lambda: f64,
150) -> TsrvfResult {
151    let karcher = karcher_mean(data, argvals, max_iter, tol, lambda);
152    tsrvf_from_alignment(&karcher, argvals)
153}
154
155/// Compute TSRVF from a pre-computed Karcher mean alignment.
156///
157/// Avoids re-running the expensive Karcher mean computation when the alignment
158/// has already been computed.
159///
160/// # Arguments
161/// * `karcher` — Pre-computed Karcher mean result
162/// * `argvals` — Evaluation points (length m)
163///
164/// # Returns
165/// [`TsrvfResult`] containing tangent vectors and associated metadata.
166pub fn tsrvf_from_alignment(karcher: &KarcherMeanResult, argvals: &[f64]) -> TsrvfResult {
167    let (n, m) = karcher.aligned_data.shape();
168
169    // Step 1: Compute SRSFs of aligned data
170    let aligned_srsf = srsf_transform(&karcher.aligned_data, argvals);
171
172    // Step 1b: Smooth aligned SRSFs to remove DP kink artifacts.
173    //
174    // DP alignment produces piecewise-linear warps with kinks at grid transitions.
175    // When curves are reparameterized by these warps, the kinks propagate into the
176    // aligned curves' derivatives (SRSFs), creating spikes that dominate TSRVF
177    // tangent vectors and PCA.
178    //
179    // R's fdasrvf does not smooth here and suffers from the same spike artifacts.
180    // Python's fdasrsf mitigates this via spline smoothing (s=1e-4) in SqrtMean.
181    // We smooth the aligned SRSFs before tangent vector computation — this only
182    // affects TSRVF output and does not change the alignment or Karcher mean.
183    let aligned_srsf = smooth_aligned_srsfs(&aligned_srsf, m);
184
185    // Step 2: Smooth and normalize mean SRSF to unit sphere.
186    // The mean SRSF must be smoothed consistently with the aligned SRSFs
187    // so that a single curve (which IS the mean) produces a zero tangent vector.
188    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
189    let bandwidth = 2.0 / (m - 1) as f64;
190    let mean_srsf_smooth = nadaraya_watson(&time, &karcher.mean_srsf, &time, bandwidth, "gaussian");
191    let mean_norm = l2_norm_l2(&mean_srsf_smooth, &time);
192
193    let mu_unit: Vec<f64> = if mean_norm > 1e-10 {
194        mean_srsf_smooth.iter().map(|&q| q / mean_norm).collect()
195    } else {
196        vec![0.0; m]
197    };
198
199    // Step 3: For each aligned curve, compute tangent vector via inverse exponential map
200    let srsf_norms: Vec<f64> = iter_maybe_parallel!(0..n)
201        .map(|i| {
202            let qi = aligned_srsf.row(i);
203            l2_norm_l2(&qi, &time)
204        })
205        .collect();
206
207    let tangent_data: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
208        .map(|i| {
209            let qi = aligned_srsf.row(i);
210            let qi_norm = srsf_norms[i];
211
212            if qi_norm < 1e-10 || mean_norm < 1e-10 {
213                return vec![0.0; m];
214            }
215
216            // Normalize to unit sphere
217            let qi_unit: Vec<f64> = qi.iter().map(|&q| q / qi_norm).collect();
218
219            // Shooting vector from mu_unit to qi_unit
220            inv_exp_map_sphere(&mu_unit, &qi_unit, &time)
221        })
222        .collect();
223
224    // Assemble tangent vectors into FdMatrix
225    let mut tangent_vectors = FdMatrix::zeros(n, m);
226    for i in 0..n {
227        for j in 0..m {
228            tangent_vectors[(i, j)] = tangent_data[i][j];
229        }
230    }
231
232    // Store per-curve initial values for SRSF inverse reconstruction.
233    // Warping preserves f_i(0) since gamma(0) = 0.
234    let initial_values: Vec<f64> = (0..n).map(|i| karcher.aligned_data[(i, 0)]).collect();
235
236    TsrvfResult {
237        tangent_vectors,
238        mean: karcher.mean.clone(),
239        mean_srsf: mean_srsf_smooth,
240        mean_srsf_norm: mean_norm,
241        srsf_norms,
242        initial_values,
243        gammas: karcher.gammas.clone(),
244        converged: karcher.converged,
245    }
246}
247
248/// Reconstruct aligned curves from TSRVF tangent vectors.
249///
250/// Inverts the TSRVF transform: maps tangent vectors back to the Hilbert sphere
251/// via the exponential map, rescales, and reconstructs curves via SRSF inverse.
252///
253/// # Arguments
254/// * `tsrvf` — TSRVF result from [`tsrvf_transform`] or [`tsrvf_from_alignment`]
255/// * `argvals` — Evaluation points (length m)
256///
257/// # Returns
258/// FdMatrix of reconstructed aligned curves (n × m).
259pub fn tsrvf_inverse(tsrvf: &TsrvfResult, argvals: &[f64]) -> FdMatrix {
260    let (n, m) = tsrvf.tangent_vectors.shape();
261
262    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
263
264    // Normalize mean SRSF to unit sphere
265    let mu_unit: Vec<f64> = if tsrvf.mean_srsf_norm > 1e-10 {
266        tsrvf
267            .mean_srsf
268            .iter()
269            .map(|&q| q / tsrvf.mean_srsf_norm)
270            .collect()
271    } else {
272        vec![0.0; m]
273    };
274
275    let curves: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
276        .map(|i| {
277            let vi = tsrvf.tangent_vectors.row(i);
278
279            // Map back to sphere: exp_map(mu_unit, v_i)
280            let qi_unit = exp_map_sphere(&mu_unit, &vi, &time);
281
282            // Rescale by original norm
283            let qi: Vec<f64> = qi_unit.iter().map(|&q| q * tsrvf.srsf_norms[i]).collect();
284
285            // Reconstruct curve from SRSF using per-curve initial value
286            srsf_inverse(&qi, argvals, tsrvf.initial_values[i])
287        })
288        .collect();
289
290    let mut result = FdMatrix::zeros(n, m);
291    for i in 0..n {
292        for j in 0..m {
293            result[(i, j)] = curves[i][j];
294        }
295    }
296    result
297}
298
299/// Full TSRVF pipeline with configurable transport method.
300///
301/// Like [`tsrvf_transform`] but allows choosing the parallel transport method.
302pub fn tsrvf_transform_with_method(
303    data: &FdMatrix,
304    argvals: &[f64],
305    max_iter: usize,
306    tol: f64,
307    lambda: f64,
308    method: TransportMethod,
309) -> TsrvfResult {
310    let karcher = karcher_mean(data, argvals, max_iter, tol, lambda);
311    tsrvf_from_alignment_with_method(&karcher, argvals, method)
312}
313
314/// Compute TSRVF from a pre-computed Karcher mean with configurable transport.
315///
316/// - [`TransportMethod::LogMap`]: Uses `inv_exp_map(mu, qi)` directly (standard TSRVF).
317/// - [`TransportMethod::SchildsLadder`]: Computes `v = -log_qi(mu)`, then transports
318///   via Schild's ladder from qi to mu.
319/// - [`TransportMethod::PoleLadder`]: Same but via pole ladder.
320pub fn tsrvf_from_alignment_with_method(
321    karcher: &KarcherMeanResult,
322    argvals: &[f64],
323    method: TransportMethod,
324) -> TsrvfResult {
325    if method == TransportMethod::LogMap {
326        return tsrvf_from_alignment(karcher, argvals);
327    }
328
329    let (n, m) = karcher.aligned_data.shape();
330    let aligned_srsf = srsf_transform(&karcher.aligned_data, argvals);
331    let aligned_srsf = smooth_aligned_srsfs(&aligned_srsf, m);
332    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
333    let bandwidth = 2.0 / (m - 1) as f64;
334    let mean_srsf_smooth = nadaraya_watson(&time, &karcher.mean_srsf, &time, bandwidth, "gaussian");
335    let mean_norm = l2_norm_l2(&mean_srsf_smooth, &time);
336
337    let mu_unit: Vec<f64> = if mean_norm > 1e-10 {
338        mean_srsf_smooth.iter().map(|&q| q / mean_norm).collect()
339    } else {
340        vec![0.0; m]
341    };
342
343    let srsf_norms: Vec<f64> = iter_maybe_parallel!(0..n)
344        .map(|i| {
345            let qi = aligned_srsf.row(i);
346            l2_norm_l2(&qi, &time)
347        })
348        .collect();
349
350    let tangent_data: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
351        .map(|i| {
352            let qi = aligned_srsf.row(i);
353            let qi_norm = srsf_norms[i];
354
355            if qi_norm < 1e-10 || mean_norm < 1e-10 {
356                return vec![0.0; m];
357            }
358
359            let qi_unit: Vec<f64> = qi.iter().map(|&q| q / qi_norm).collect();
360
361            // Compute v = -log_qi(mu) — vector at qi pointing away from mu
362            let v_at_qi = inv_exp_map_sphere(&qi_unit, &mu_unit, &time);
363            let neg_v: Vec<f64> = v_at_qi.iter().map(|&x| -x).collect();
364
365            // Transport from qi to mu
366            match method {
367                TransportMethod::SchildsLadder => {
368                    parallel_transport_schilds(&neg_v, &qi_unit, &mu_unit, &time)
369                }
370                TransportMethod::PoleLadder => {
371                    parallel_transport_pole(&neg_v, &qi_unit, &mu_unit, &time)
372                }
373                TransportMethod::LogMap => unreachable!(),
374            }
375        })
376        .collect();
377
378    let mut tangent_vectors = FdMatrix::zeros(n, m);
379    for i in 0..n {
380        for j in 0..m {
381            tangent_vectors[(i, j)] = tangent_data[i][j];
382        }
383    }
384
385    let initial_values: Vec<f64> = (0..n).map(|i| karcher.aligned_data[(i, 0)]).collect();
386
387    TsrvfResult {
388        tangent_vectors,
389        mean: karcher.mean.clone(),
390        mean_srsf: mean_srsf_smooth,
391        mean_srsf_norm: mean_norm,
392        srsf_norms,
393        initial_values,
394        gammas: karcher.gammas.clone(),
395        converged: karcher.converged,
396    }
397}