// fdars_core/alignment/robust_karcher.rs

//! Robust alternatives to the Karcher mean: median and trimmed mean.
//!
//! The standard Karcher mean is sensitive to outlier curves. This module
//! provides two robust alternatives:
//!
//! - [`karcher_median`] — Geometric median via iteratively reweighted
//!   Karcher mean (Weiszfeld algorithm on the elastic manifold).
//! - [`robust_karcher_mean`] — Trimmed Karcher mean that removes the
//!   most distant curves before averaging.

use super::karcher::karcher_mean;
use super::pairwise::elastic_distance;
use super::set::align_to_target;
use super::srsf::srsf_single;
use crate::error::FdarError;
use crate::matrix::FdMatrix;

/// Configuration shared by the robust Karcher estimators in this module.
///
/// [`karcher_median`] reads `max_iter`, `tol` and `lambda`;
/// [`robust_karcher_mean`] reads all four fields.
#[derive(Debug, Clone, PartialEq)]
pub struct RobustKarcherConfig {
    /// Maximum number of outer iterations.
    pub max_iter: usize,
    /// Convergence tolerance (relative change in SRSF).
    pub tol: f64,
    /// Roughness penalty for elastic alignment (0.0 = no penalty).
    pub lambda: f64,
    /// Fraction of most-distant curves to trim (for trimmed mean).
    ///
    /// [`robust_karcher_mean`] rejects values outside `[0, 1)`.
    pub trim_fraction: f64,
}

31impl Default for RobustKarcherConfig {
32    fn default() -> Self {
33        Self {
34            max_iter: 20,
35            tol: 1e-3,
36            lambda: 0.0,
37            trim_fraction: 0.1,
38        }
39    }
40}
41
42/// Result of robust Karcher estimation.
43#[derive(Debug, Clone, PartialEq)]
44#[non_exhaustive]
45pub struct RobustKarcherResult {
46    /// Robust mean/median curve.
47    pub mean: Vec<f64>,
48    /// SRSF of the robust mean/median.
49    pub mean_srsf: Vec<f64>,
50    /// Warping functions for all curves (n x m).
51    pub gammas: FdMatrix,
52    /// All curves aligned to the robust mean/median (n x m).
53    pub aligned_data: FdMatrix,
54    /// Per-curve weights (1/distance for median, 0/1 for trimmed).
55    pub weights: Vec<f64>,
56    /// Number of iterations performed.
57    pub n_iter: usize,
58    /// Whether the algorithm converged.
59    pub converged: bool,
60}
61
62/// Compute the Karcher median via the Weiszfeld algorithm on the elastic manifold.
63///
64/// The geometric median minimizes the sum of elastic distances to all curves,
65/// rather than the sum of squared distances (as with the mean). This makes it
66/// robust to outlier curves.
67///
68/// # Algorithm
69/// 1. Initialize with standard Karcher mean (1 iteration) as starting point.
70/// 2. Iterative Weiszfeld loop:
71///    a. Align all curves to the current median estimate.
72///    b. Compute elastic distances.
73///    c. Set weights w_i = 1 / max(d_i, epsilon), normalize.
74///    d. Compute weighted pointwise mean of aligned curves.
75///    e. Check convergence (relative change in SRSF).
76///
77/// # Arguments
78/// * `data`    — Functional data matrix (n x m).
79/// * `argvals` — Evaluation points (length m).
80/// * `config`  — Configuration parameters.
81///
82/// # Errors
83/// Returns [`FdarError::InvalidDimension`] if `argvals` length does not match `m`
84/// or `n < 2`.
85#[must_use = "expensive computation whose result should not be discarded"]
86pub fn karcher_median(
87    data: &FdMatrix,
88    argvals: &[f64],
89    config: &RobustKarcherConfig,
90) -> Result<RobustKarcherResult, FdarError> {
91    let (n, m) = data.shape();
92    validate_inputs(n, m, argvals)?;
93
94    // Step 1: Initialize with a quick Karcher mean (1 iteration).
95    let init = karcher_mean(data, argvals, 1, config.tol, config.lambda);
96    let mut current_mean = init.mean;
97
98    let mut converged = false;
99    let mut n_iter = 0;
100    let mut weights = vec![1.0 / n as f64; n];
101    let mut alignment_result = align_to_target(data, &current_mean, argvals, config.lambda);
102
103    // Step 2: Weiszfeld iterations.
104    for iter in 0..config.max_iter {
105        n_iter = iter + 1;
106
107        // Compute elastic distances.
108        let distances: Vec<f64> = (0..n)
109            .map(|i| {
110                let fi = data.row(i);
111                elastic_distance(&current_mean, &fi, argvals, config.lambda)
112            })
113            .collect();
114
115        // Compute weights: w_i = 1 / max(d_i, epsilon).
116        let epsilon = 1e-10;
117        let raw_weights: Vec<f64> = distances.iter().map(|&d| 1.0 / d.max(epsilon)).collect();
118        let w_sum: f64 = raw_weights.iter().sum();
119        weights = raw_weights.iter().map(|&w| w / w_sum).collect();
120
121        // Weighted pointwise mean of aligned curves.
122        let mut new_mean = vec![0.0; m];
123        for i in 0..n {
124            for j in 0..m {
125                new_mean[j] += weights[i] * alignment_result.aligned_data[(i, j)];
126            }
127        }
128
129        // Check convergence.
130        let old_srsf = srsf_single(&current_mean, argvals);
131        let new_srsf = srsf_single(&new_mean, argvals);
132        let rel = relative_srsf_change(&old_srsf, &new_srsf);
133
134        current_mean = new_mean;
135
136        if rel < config.tol {
137            converged = true;
138            // Final alignment to converged median.
139            alignment_result = align_to_target(data, &current_mean, argvals, config.lambda);
140            break;
141        }
142
143        // Re-align to updated median.
144        alignment_result = align_to_target(data, &current_mean, argvals, config.lambda);
145    }
146
147    let mean_srsf = srsf_single(&current_mean, argvals);
148
149    Ok(RobustKarcherResult {
150        mean: current_mean,
151        mean_srsf,
152        gammas: alignment_result.gammas,
153        aligned_data: alignment_result.aligned_data,
154        weights,
155        n_iter,
156        converged,
157    })
158}
159
160/// Compute a trimmed Karcher mean by removing the most distant curves.
161///
162/// Computes the standard Karcher mean, identifies and removes the top
163/// `trim_fraction` of curves by elastic distance, then recomputes the
164/// Karcher mean on the remaining curves. All curves (including trimmed
165/// ones) are re-aligned to the robust mean for the final output.
166///
167/// # Arguments
168/// * `data`    — Functional data matrix (n x m).
169/// * `argvals` — Evaluation points (length m).
170/// * `config`  — Configuration parameters.
171///
172/// # Errors
173/// Returns [`FdarError::InvalidDimension`] if `argvals` length does not match `m`
174/// or `n < 2`.
175/// Returns [`FdarError::InvalidParameter`] if `trim_fraction` is not in \[0, 1).
176#[must_use = "expensive computation whose result should not be discarded"]
177pub fn robust_karcher_mean(
178    data: &FdMatrix,
179    argvals: &[f64],
180    config: &RobustKarcherConfig,
181) -> Result<RobustKarcherResult, FdarError> {
182    let (n, m) = data.shape();
183    validate_inputs(n, m, argvals)?;
184
185    if !(0.0..1.0).contains(&config.trim_fraction) {
186        return Err(FdarError::InvalidParameter {
187            parameter: "trim_fraction",
188            message: format!("must be in [0, 1), got {}", config.trim_fraction),
189        });
190    }
191
192    // Step 1: Compute standard Karcher mean.
193    let initial_mean = karcher_mean(data, argvals, config.max_iter, config.tol, config.lambda);
194
195    // Step 2: Compute elastic distances from the mean.
196    let distances: Vec<f64> = (0..n)
197        .map(|i| {
198            let fi = data.row(i);
199            elastic_distance(&initial_mean.mean, &fi, argvals, config.lambda)
200        })
201        .collect();
202
203    // Step 3: Sort by distance, identify curves to trim.
204    let n_trim = ((n as f64) * config.trim_fraction).ceil() as usize;
205    let n_keep = n.saturating_sub(n_trim).max(2); // Keep at least 2 curves.
206
207    let mut indexed_distances: Vec<(usize, f64)> =
208        distances.iter().enumerate().map(|(i, &d)| (i, d)).collect();
209    indexed_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
210
211    let kept_indices: Vec<usize> = indexed_distances
212        .iter()
213        .take(n_keep)
214        .map(|&(i, _)| i)
215        .collect();
216
217    // Step 4: Set weights.
218    let mut weights = vec![0.0; n];
219    for &idx in &kept_indices {
220        weights[idx] = 1.0;
221    }
222
223    // Step 5: Recompute Karcher mean on the kept subset.
224    let kept_data = subset_rows_from_indices(data, &kept_indices);
225    let robust_mean = karcher_mean(
226        &kept_data,
227        argvals,
228        config.max_iter,
229        config.tol,
230        config.lambda,
231    );
232
233    // Step 6: Re-align ALL curves (including trimmed) to the robust mean.
234    let final_alignment = align_to_target(data, &robust_mean.mean, argvals, config.lambda);
235
236    let mean_srsf = srsf_single(&robust_mean.mean, argvals);
237
238    Ok(RobustKarcherResult {
239        mean: robust_mean.mean,
240        mean_srsf,
241        gammas: final_alignment.gammas,
242        aligned_data: final_alignment.aligned_data,
243        weights,
244        n_iter: robust_mean.n_iter,
245        converged: robust_mean.converged,
246    })
247}
248
249/// Validate common input dimensions.
250fn validate_inputs(n: usize, m: usize, argvals: &[f64]) -> Result<(), FdarError> {
251    if argvals.len() != m {
252        return Err(FdarError::InvalidDimension {
253            parameter: "argvals",
254            expected: format!("{m}"),
255            actual: format!("{}", argvals.len()),
256        });
257    }
258    if n < 2 {
259        return Err(FdarError::InvalidDimension {
260            parameter: "data",
261            expected: "at least 2 rows".to_string(),
262            actual: format!("{n} rows"),
263        });
264    }
265    Ok(())
266}
267
/// Compute relative change between successive SRSFs.
///
/// Returns `||q_old - q_new|| / max(||q_old||, 1e-10)` using the Euclidean
/// norm over the samples. The floor on the denominator keeps the ratio
/// finite when `q_old` is (numerically) zero.
///
/// If the slices differ in length, only the common prefix contributes to the
/// numerator, while the denominator uses all of `q_old`.
fn relative_srsf_change(q_old: &[f64], q_new: &[f64]) -> f64 {
    // Smallest denominator allowed; guards against division by zero.
    const NORM_FLOOR: f64 = 1e-10;

    let diff_norm: f64 = q_old
        .iter()
        .zip(q_new.iter())
        .map(|(&a, &b)| (a - b).powi(2))
        .sum::<f64>()
        .sqrt();
    let old_norm: f64 = q_old
        .iter()
        .map(|&v| v * v)
        .sum::<f64>()
        .sqrt()
        .max(NORM_FLOOR);
    diff_norm / old_norm
}

280use crate::cv::subset_rows as subset_rows_from_indices;
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::uniform_grid;

    /// Build `n` sine curves sampled at the `m` points of `uniform_grid(m)`;
    /// curve `i` is phase-shifted by `0.03 * i`.
    fn make_sine_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let mut data_vec = vec![0.0; n * m];
        for i in 0..n {
            let phase = 0.03 * i as f64;
            for j in 0..m {
                // Column-major layout: element (i, j) lives at i + j * n.
                data_vec[i + j * n] = ((t[j] + phase) * 4.0).sin();
            }
        }
        let data = FdMatrix::from_column_major(data_vec, n, m).unwrap();
        (data, t)
    }

    /// Build six curves on a 20-point grid: rows 0..5 are clean sine curves
    /// with slight phase shifts (`0.02 * i`), row 5 is an extreme outlier.
    fn make_outlier_data() -> (FdMatrix, Vec<f64>) {
        let m = 20;
        let t = uniform_grid(m);
        let n = 6;
        let mut data_vec = vec![0.0; n * m];

        // 5 clean curves (slight phase shifts).
        for i in 0..5 {
            let phase = 0.02 * i as f64;
            for j in 0..m {
                data_vec[i + j * n] = ((t[j] + phase) * 4.0).sin();
            }
        }
        // 1 extreme outlier.
        for j in 0..m {
            data_vec[5 + j * n] = (t[j] * 20.0).cos() * 5.0;
        }
        let data = FdMatrix::from_column_major(data_vec, n, m).unwrap();
        (data, t)
    }

    #[test]
    fn karcher_median_basic() {
        let (data, t) = make_sine_data(5, 20);
        let config = RobustKarcherConfig {
            max_iter: 5,
            ..Default::default()
        };
        let result = karcher_median(&data, &t, &config).unwrap();
        assert_eq!(result.mean.len(), 20);
        assert_eq!(result.mean_srsf.len(), 20);
        assert_eq!(result.gammas.shape(), (5, 20));
        assert_eq!(result.aligned_data.shape(), (5, 20));
        assert_eq!(result.weights.len(), 5);
        assert!(result.n_iter >= 1);
    }

    #[test]
    fn karcher_median_robust_to_outlier() {
        let (data, t) = make_outlier_data();

        // Compute standard mean and median.
        let std_mean = karcher_mean(&data, &t, 5, 1e-3, 0.0);
        let median_config = RobustKarcherConfig {
            max_iter: 5,
            ..Default::default()
        };
        let median_result = karcher_median(&data, &t, &median_config).unwrap();

        // Compute a clean reference (mean of just the clean curves).
        let clean_data = subset_rows_from_indices(&data, &[0, 1, 2, 3, 4]);
        let clean_mean = karcher_mean(&clean_data, &t, 5, 1e-3, 0.0);

        // Median should be closer to the clean mean than the standard mean is.
        let d_std = pointwise_l2(&std_mean.mean, &clean_mean.mean);
        let d_median = pointwise_l2(&median_result.mean, &clean_mean.mean);
        assert!(
            d_median <= d_std + 1e-6,
            "median distance to clean ({d_median:.4}) should be <= standard mean distance ({d_std:.4})"
        );
    }

    #[test]
    fn robust_trimmed_removes_outliers() {
        let (data, t) = make_outlier_data();

        let config = RobustKarcherConfig {
            max_iter: 5,
            // ceil(6 * 0.2) = ceil(1.2) = 2 curves are trimmed, 4 are kept.
            trim_fraction: 0.2,
            ..Default::default()
        };
        let result = robust_karcher_mean(&data, &t, &config).unwrap();

        // The outlier (index 5) should have weight 0.0.
        assert!(
            result.weights[5] < 1e-10,
            "outlier weight should be 0, got {}",
            result.weights[5]
        );

        // The curves that were not trimmed carry weight 1.0.
        let n_kept: usize = result.weights.iter().filter(|&&w| w > 0.5).count();
        assert!(n_kept >= 4, "should keep at least 4 curves, got {n_kept}");
    }

    #[test]
    fn robust_config_default() {
        let cfg = RobustKarcherConfig::default();
        assert_eq!(cfg.max_iter, 20);
        assert!((cfg.tol - 1e-3).abs() < f64::EPSILON);
        assert!((cfg.lambda - 0.0).abs() < f64::EPSILON);
        assert!((cfg.trim_fraction - 0.1).abs() < f64::EPSILON);
    }

    /// Simple pointwise L2 distance between two curves.
    fn pointwise_l2(a: &[f64], b: &[f64]) -> f64 {
        a.iter()
            .zip(b.iter())
            .map(|(&x, &y)| (x - y).powi(2))
            .sum::<f64>()
            .sqrt()
    }
}