Skip to main content

fdars_core/alignment/
karcher.rs

1//! Karcher (Frechet) mean computation in the elastic metric.
2
3use super::set::apply_stored_warps;
4use super::srsf::{reparameterize_curve, srsf_inverse, srsf_transform};
5use super::{dp_alignment_core, KarcherMeanResult};
6use crate::fdata::mean_1d;
7use crate::helpers::{gradient_uniform, linear_interp};
8use crate::iter_maybe_parallel;
9use crate::matrix::FdMatrix;
10use crate::warping::{
11    exp_map_sphere, gam_to_psi, inv_exp_map_sphere, invert_gamma, l2_norm_l2, psi_to_gam,
12};
13#[cfg(feature = "parallel")]
14use rayon::iter::ParallelIterator;
15
// `srsf_single` lives in the srsf module; it is imported privately (not re-exported) for internal use
17use super::srsf::srsf_single;
18
19// ─── Helpers ─────────────────────────────────────────────────────────────────
20
21/// One Karcher iteration on the Hilbert sphere: compute mean shooting vector and update mu.
22///
23/// Returns `true` if converged (vbar norm ≤ threshold).
24fn karcher_sphere_step(mu: &mut Vec<f64>, psis: &[Vec<f64>], time: &[f64], step_size: f64) -> bool {
25    let m = mu.len();
26    let n = psis.len();
27    let mut vbar = vec![0.0; m];
28    for psi in psis {
29        let v = inv_exp_map_sphere(mu, psi, time);
30        for j in 0..m {
31            vbar[j] += v[j];
32        }
33    }
34    for j in 0..m {
35        vbar[j] /= n as f64;
36    }
37    if l2_norm_l2(&vbar, time) <= 1e-8 {
38        return true;
39    }
40    let scaled: Vec<f64> = vbar.iter().map(|&v| v * step_size).collect();
41    *mu = exp_map_sphere(mu, &scaled, time);
42    false
43}
44
45/// Karcher mean of warping functions on the Hilbert sphere, then invert.
46/// Port of fdasrvf's `SqrtMeanInverse`.
47pub(crate) fn sqrt_mean_inverse(gammas: &FdMatrix, argvals: &[f64]) -> Vec<f64> {
48    let (n, m) = gammas.shape();
49    let t0 = argvals[0];
50    let t1 = argvals[m - 1];
51    let domain = t1 - t0;
52
53    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
54    let binsize = 1.0 / (m - 1) as f64;
55
56    let psis: Vec<Vec<f64>> = (0..n)
57        .map(|i| {
58            let gam_01: Vec<f64> = (0..m).map(|j| (gammas[(i, j)] - t0) / domain).collect();
59            gam_to_psi(&gam_01, binsize)
60        })
61        .collect();
62
63    let mut mu = vec![0.0; m];
64    for psi in &psis {
65        for j in 0..m {
66            mu[j] += psi[j];
67        }
68    }
69    for j in 0..m {
70        mu[j] /= n as f64;
71    }
72
73    for _ in 0..501 {
74        if karcher_sphere_step(&mut mu, &psis, &time, 0.3) {
75            break;
76        }
77    }
78
79    let gam_mu = psi_to_gam(&mu, &time);
80    let gam_inv = invert_gamma(&gam_mu, &time);
81    gam_inv.iter().map(|&g| t0 + g * domain).collect()
82}
83
/// Relative change between two successive mean SRSFs.
///
/// Computes `‖q_new − q_old‖₂ / max(‖q_old‖₂, 1e-10)` with the plain
/// (unweighted) discrete L2 norm. This is the convergence metric used by R's
/// fdasrvf `time_warping`. The difference only covers the overlapping prefix
/// of the two slices; the denominator uses all of `q_old`.
fn relative_change(q_old: &[f64], q_new: &[f64]) -> f64 {
    // Floor on the denominator so an all-zero previous mean cannot divide by zero.
    const MIN_NORM: f64 = 1e-10;

    let squared_diff: f64 = q_old
        .iter()
        .zip(q_new.iter())
        .map(|(&a, &b)| (a - b).powi(2))
        .sum();
    let squared_old: f64 = q_old.iter().map(|&v| v * v).sum();

    let denominator = squared_old.sqrt().max(MIN_NORM);
    squared_diff.sqrt() / denominator
}
98
99/// Align a single SRSF q2 to q1 and return (gamma, aligned_q).
100pub(super) fn align_srsf_pair(
101    q1: &[f64],
102    q2: &[f64],
103    argvals: &[f64],
104    lambda: f64,
105) -> (Vec<f64>, Vec<f64>) {
106    let gamma = dp_alignment_core(q1, q2, argvals, lambda);
107
108    // Warp q2 by gamma and adjust by sqrt(gamma')
109    let q2_warped = reparameterize_curve(q2, argvals, &gamma);
110
111    // Compute gamma' via finite differences
112    let m = gamma.len();
113    let mut gamma_dot = vec![0.0; m];
114    gamma_dot[0] = (gamma[1] - gamma[0]) / (argvals[1] - argvals[0]);
115    for j in 1..(m - 1) {
116        gamma_dot[j] = (gamma[j + 1] - gamma[j - 1]) / (argvals[j + 1] - argvals[j - 1]);
117    }
118    gamma_dot[m - 1] = (gamma[m - 1] - gamma[m - 2]) / (argvals[m - 1] - argvals[m - 2]);
119
120    // q2_aligned = (q2 ∘ γ) * sqrt(γ')
121    let q2_aligned: Vec<f64> = q2_warped
122        .iter()
123        .zip(gamma_dot.iter())
124        .map(|(&q, &gd)| q * gd.max(0.0).sqrt())
125        .collect();
126
127    (gamma, q2_aligned)
128}
129
130/// Accumulate alignment results: store gammas and return the mean of aligned SRSFs.
131fn accumulate_alignments(
132    results: &[(Vec<f64>, Vec<f64>)],
133    gammas: &mut FdMatrix,
134    m: usize,
135    n: usize,
136) -> Vec<f64> {
137    let mut mu_q_new = vec![0.0; m];
138    for (i, (gamma, q_aligned)) in results.iter().enumerate() {
139        for j in 0..m {
140            gammas[(i, j)] = gamma[j];
141            mu_q_new[j] += q_aligned[j];
142        }
143    }
144    for j in 0..m {
145        mu_q_new[j] /= n as f64;
146    }
147    mu_q_new
148}
149
150/// Select the SRSF closest to the pointwise mean as template. Returns (mu_q, mu_f).
151fn select_template(srsf_mat: &FdMatrix, data: &FdMatrix, argvals: &[f64]) -> (Vec<f64>, Vec<f64>) {
152    let (n, m) = srsf_mat.shape();
153    let mnq = mean_1d(srsf_mat);
154    let mut min_dist = f64::INFINITY;
155    let mut min_idx = 0;
156    for i in 0..n {
157        let dist_sq: f64 = (0..m).map(|j| (srsf_mat[(i, j)] - mnq[j]).powi(2)).sum();
158        if dist_sq < min_dist {
159            min_dist = dist_sq;
160            min_idx = i;
161        }
162    }
163    let _ = argvals; // kept for API consistency
164    (srsf_mat.row(min_idx), data.row(min_idx))
165}
166
167/// Pre-centering: align all curves to template, compute inverse mean warp, re-center.
168fn pre_center_template(
169    data: &FdMatrix,
170    mu_q: &[f64],
171    mu: &[f64],
172    argvals: &[f64],
173    lambda: f64,
174) -> (Vec<f64>, Vec<f64>) {
175    let (n, m) = data.shape();
176    let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
177        .map(|i| {
178            let fi = data.row(i);
179            let qi = srsf_single(&fi, argvals);
180            align_srsf_pair(mu_q, &qi, argvals, lambda)
181        })
182        .collect();
183
184    let mut init_gammas = FdMatrix::zeros(n, m);
185    for (i, (gamma, _)) in align_results.iter().enumerate() {
186        for j in 0..m {
187            init_gammas[(i, j)] = gamma[j];
188        }
189    }
190
191    let gam_inv = sqrt_mean_inverse(&init_gammas, argvals);
192    let mu_new = reparameterize_curve(mu, argvals, &gam_inv);
193    let mu_q_new = srsf_single(&mu_new, argvals);
194    (mu_q_new, mu_new)
195}
196
197/// Post-convergence centering: center mean SRSF and warps via SqrtMeanInverse.
198fn post_center_results(
199    data: &FdMatrix,
200    mu_q: &[f64],
201    final_gammas: &mut FdMatrix,
202    argvals: &[f64],
203) -> (Vec<f64>, Vec<f64>, FdMatrix) {
204    let (n, m) = data.shape();
205    let gam_inv = sqrt_mean_inverse(final_gammas, argvals);
206    let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
207    let gam_inv_dev = gradient_uniform(&gam_inv, h);
208
209    let mu_q_warped = reparameterize_curve(mu_q, argvals, &gam_inv);
210    let mu_q_centered: Vec<f64> = mu_q_warped
211        .iter()
212        .zip(gam_inv_dev.iter())
213        .map(|(&q, &gd)| q * gd.max(0.0).sqrt())
214        .collect();
215
216    for i in 0..n {
217        let gam_i: Vec<f64> = (0..m).map(|j| final_gammas[(i, j)]).collect();
218        let gam_centered = reparameterize_curve(&gam_i, argvals, &gam_inv);
219        for j in 0..m {
220            final_gammas[(i, j)] = gam_centered[j];
221        }
222    }
223
224    let initial_mean = mean_1d(data);
225    let mu = srsf_inverse(&mu_q_centered, argvals, initial_mean[0]);
226    let final_aligned = apply_stored_warps(data, final_gammas, argvals);
227    (mu, mu_q_centered, final_aligned)
228}
229
/// Downsample `signal` and `argvals` by keeping every `factor`-th sample,
/// always including the first and last points.
///
/// Returns `(signal, argvals)` on the reduced grid. Inputs with `factor <= 1`
/// or at most two samples are returned as unchanged copies.
fn downsample_uniform(signal: &[f64], argvals: &[f64], factor: usize) -> (Vec<f64>, Vec<f64>) {
    let len = signal.len();
    if len <= 2 || factor <= 1 {
        return (signal.to_vec(), argvals.to_vec());
    }
    let last = len - 1;
    // Every `factor`-th index, plus the final index when the stride misses it.
    (0..len)
        .filter(|&i| i % factor == 0 || i == last)
        .map(|i| (signal[i], argvals[i]))
        .unzip()
}
249
250/// Upsample signal from coarse grid to fine grid via linear interpolation.
251fn upsample_to_fine(coarse: &[f64], argvals_coarse: &[f64], argvals_fine: &[f64]) -> Vec<f64> {
252    argvals_fine
253        .iter()
254        .map(|&t| linear_interp(argvals_coarse, coarse, t))
255        .collect()
256}
257
258// ─── Karcher Mean ───────────────────────────────────────────────────────────
259
260/// Compute the Karcher (Frechet) mean in the elastic metric.
261///
262/// Iteratively aligns all curves to the current mean estimate in SRSF space,
263/// computes the pointwise mean of aligned SRSFs, and reconstructs the mean curve.
264///
265/// # Arguments
266/// * `data` — Functional data matrix (n × m)
267/// * `argvals` — Evaluation points (length m)
268/// * `max_iter` — Maximum number of iterations
269/// * `tol` — Convergence tolerance for the SRSF mean
270///
271/// # Returns
272/// [`KarcherMeanResult`] with mean curve, warping functions, aligned data, and convergence info.
273///
274/// # Examples
275///
276/// ```
277/// use fdars_core::simulation::{sim_fundata, EFunType, EValType};
278/// use fdars_core::alignment::karcher_mean;
279///
280/// let t: Vec<f64> = (0..50).map(|i| i as f64 / 49.0).collect();
281/// let data = sim_fundata(20, &t, 3, EFunType::Fourier, EValType::Exponential, Some(42));
282///
283/// let result = karcher_mean(&data, &t, 20, 1e-4, 0.0);
284/// assert_eq!(result.mean.len(), 50);
285/// assert!(result.n_iter <= 20);
286/// ```
287#[must_use = "expensive computation whose result should not be discarded"]
288pub fn karcher_mean(
289    data: &FdMatrix,
290    argvals: &[f64],
291    max_iter: usize,
292    tol: f64,
293    lambda: f64,
294) -> KarcherMeanResult {
295    let (n, m) = data.shape();
296
297    let srsf_mat = srsf_transform(data, argvals);
298    let (mut mu_q, mu) = select_template(&srsf_mat, data, argvals);
299    let (mu_q_c, mu_c) = pre_center_template(data, &mu_q, &mu, argvals, lambda);
300    mu_q = mu_q_c;
301    let mut mu = mu_c;
302
303    let mut converged = false;
304    let mut n_iter = 0;
305    let mut final_gammas = FdMatrix::zeros(n, m);
306
307    // Coarse-to-fine strategy: run initial iterations on downsampled grid
308    // Only worthwhile for large grids with enough iterations to split
309    let coarse_factor = if m > 50 && max_iter >= 10 { 4 } else { 1 };
310    let coarse_iters = if coarse_factor > 1 { max_iter / 2 } else { 0 };
311    let fine_iters = max_iter - coarse_iters;
312
313    // Phase 1: coarse iterations
314    if coarse_iters > 0 {
315        let (mu_q_coarse, argvals_coarse) = downsample_uniform(&mu_q, argvals, coarse_factor);
316        let m_c = argvals_coarse.len();
317        let mut mu_q_c = mu_q_coarse;
318
319        // Downsample all curves to coarse grid
320        let data_coarse: Vec<Vec<f64>> = (0..n)
321            .map(|i| {
322                let row = data.row(i);
323                downsample_uniform(&row, argvals, coarse_factor).0
324            })
325            .collect();
326
327        let mut coarse_gammas = FdMatrix::zeros(n, m_c);
328
329        for iter in 0..coarse_iters {
330            n_iter = iter + 1;
331
332            let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
333                .map(|i| {
334                    let qi = srsf_single(&data_coarse[i], &argvals_coarse);
335                    align_srsf_pair(&mu_q_c, &qi, &argvals_coarse, lambda)
336                })
337                .collect();
338
339            let mu_q_new = accumulate_alignments(&align_results, &mut coarse_gammas, m_c, n);
340
341            let rel = relative_change(&mu_q_c, &mu_q_new);
342            if rel < tol {
343                converged = true;
344                mu_q_c = mu_q_new;
345                break;
346            }
347
348            mu_q_c = mu_q_new;
349        }
350
351        // Upsample coarse mu_q to fine grid
352        mu_q = upsample_to_fine(&mu_q_c, &argvals_coarse, argvals);
353        mu = srsf_inverse(&mu_q, argvals, mu[0]);
354    }
355
356    // Phase 2: fine iterations (or all iterations if m <= 50)
357    if fine_iters > 0 {
358        converged = false; // Fine phase must independently converge
359    }
360    let fine_start = n_iter;
361    for iter in 0..fine_iters {
362        n_iter = fine_start + iter + 1;
363
364        let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
365            .map(|i| {
366                let fi = data.row(i);
367                let qi = srsf_single(&fi, argvals);
368                align_srsf_pair(&mu_q, &qi, argvals, lambda)
369            })
370            .collect();
371
372        let mu_q_new = accumulate_alignments(&align_results, &mut final_gammas, m, n);
373
374        let rel = relative_change(&mu_q, &mu_q_new);
375        if rel < tol {
376            converged = true;
377            mu_q = mu_q_new;
378            break;
379        }
380
381        mu_q = mu_q_new;
382        mu = srsf_inverse(&mu_q, argvals, mu[0]);
383    }
384
385    // If coarse converged but no fine iterations ran, do one fine pass for final_gammas
386    if converged && fine_start > 0 {
387        let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
388            .map(|i| {
389                let fi = data.row(i);
390                let qi = srsf_single(&fi, argvals);
391                align_srsf_pair(&mu_q, &qi, argvals, lambda)
392            })
393            .collect();
394        let mu_q_new = accumulate_alignments(&align_results, &mut final_gammas, m, n);
395        mu_q = mu_q_new;
396    }
397
398    let (mu_final, mu_q_final, final_aligned) =
399        post_center_results(data, &mu_q, &mut final_gammas, argvals);
400
401    KarcherMeanResult {
402        mean: mu_final,
403        mean_srsf: mu_q_final,
404        gammas: final_gammas,
405        aligned_data: final_aligned,
406        n_iter,
407        converged,
408        aligned_srsfs: None,
409    }
410}