Skip to main content

fdars_core/alignment/
robust_karcher.rs

//! Robust alternatives to the Karcher mean: median and trimmed mean.
//!
//! The standard Karcher mean is sensitive to outlier curves. This module
//! provides two robust alternatives:
//!
//! - [`karcher_median`] — Geometric median via an iteratively reweighted
//!   Karcher mean (Weiszfeld algorithm on the elastic manifold).
//! - [`robust_karcher_mean`] — Trimmed Karcher mean that removes the
//!   most distant curves before averaging.
10
11use super::karcher::karcher_mean;
12use super::pairwise::elastic_distance;
13use super::set::align_to_target;
14use super::srsf::srsf_single;
15use crate::error::FdarError;
16use crate::matrix::FdMatrix;
17
/// Tuning parameters shared by [`karcher_median`] and [`robust_karcher_mean`].
#[derive(Debug, Clone, PartialEq)]
pub struct RobustKarcherConfig {
    /// Upper bound on iterations: the Weiszfeld loop of the median, and the
    /// iteration limit handed to the Karcher mean by the trimmed estimator.
    pub max_iter: usize,
    /// Convergence threshold on the relative change of the SRSF between
    /// successive estimates.
    pub tol: f64,
    /// Roughness penalty forwarded to elastic alignment and elastic
    /// distance computations (0.0 = no penalty).
    pub lambda: f64,
    /// Fraction of the most distant curves to discard in the trimmed mean.
    /// Must lie in \[0, 1); not read by the median.
    pub trim_fraction: f64,
}
30
31impl Default for RobustKarcherConfig {
32    fn default() -> Self {
33        Self {
34            max_iter: 20,
35            tol: 1e-3,
36            lambda: 0.0,
37            trim_fraction: 0.1,
38        }
39    }
40}
41
42/// Result of robust Karcher estimation.
43#[derive(Debug, Clone, PartialEq)]
44#[non_exhaustive]
45pub struct RobustKarcherResult {
46    /// Robust mean/median curve.
47    pub mean: Vec<f64>,
48    /// SRSF of the robust mean/median.
49    pub mean_srsf: Vec<f64>,
50    /// Warping functions for all curves (n x m).
51    pub gammas: FdMatrix,
52    /// All curves aligned to the robust mean/median (n x m).
53    pub aligned_data: FdMatrix,
54    /// Per-curve weights (1/distance for median, 0/1 for trimmed).
55    pub weights: Vec<f64>,
56    /// Number of iterations performed.
57    pub n_iter: usize,
58    /// Whether the algorithm converged.
59    pub converged: bool,
60}
61
62/// Compute the Karcher median via the Weiszfeld algorithm on the elastic manifold.
63///
64/// The geometric median minimizes the sum of elastic distances to all curves,
65/// rather than the sum of squared distances (as with the mean). This makes it
66/// robust to outlier curves.
67///
68/// # Algorithm
69/// 1. Initialize with standard Karcher mean (1 iteration) as starting point.
70/// 2. Iterative Weiszfeld loop:
71///    a. Align all curves to the current median estimate.
72///    b. Compute elastic distances.
73///    c. Set weights w_i = 1 / max(d_i, epsilon), normalize.
74///    d. Compute weighted pointwise mean of aligned curves.
75///    e. Check convergence (relative change in SRSF).
76///
77/// # Arguments
78/// * `data`    — Functional data matrix (n x m).
79/// * `argvals` — Evaluation points (length m).
80/// * `config`  — Configuration parameters.
81///
82/// # Errors
83/// Returns [`FdarError::InvalidDimension`] if `argvals` length does not match `m`
84/// or `n < 2`.
85#[must_use = "expensive computation whose result should not be discarded"]
86pub fn karcher_median(
87    data: &FdMatrix,
88    argvals: &[f64],
89    config: &RobustKarcherConfig,
90) -> Result<RobustKarcherResult, FdarError> {
91    let (n, m) = data.shape();
92    validate_inputs(n, m, argvals)?;
93
94    // Step 1: Initialize with a quick Karcher mean (1 iteration).
95    let init = karcher_mean(data, argvals, 1, config.tol, config.lambda);
96    let mut current_mean = init.mean;
97
98    let mut converged = false;
99    let mut n_iter = 0;
100    let mut weights = vec![1.0 / n as f64; n];
101    let mut alignment_result = align_to_target(data, &current_mean, argvals, config.lambda);
102
103    // Step 2: Weiszfeld iterations.
104    for iter in 0..config.max_iter {
105        n_iter = iter + 1;
106
107        // Compute elastic distances.
108        let distances: Vec<f64> = (0..n)
109            .map(|i| {
110                let fi = data.row(i);
111                elastic_distance(&current_mean, &fi, argvals, config.lambda)
112            })
113            .collect();
114
115        // Compute weights: w_i = 1 / max(d_i, epsilon).
116        let epsilon = 1e-10;
117        let raw_weights: Vec<f64> = distances.iter().map(|&d| 1.0 / d.max(epsilon)).collect();
118        let w_sum: f64 = raw_weights.iter().sum();
119        weights = raw_weights.iter().map(|&w| w / w_sum).collect();
120
121        // Weighted pointwise mean of aligned curves.
122        let mut new_mean = vec![0.0; m];
123        for i in 0..n {
124            for j in 0..m {
125                new_mean[j] += weights[i] * alignment_result.aligned_data[(i, j)];
126            }
127        }
128
129        // Check convergence.
130        let old_srsf = srsf_single(&current_mean, argvals);
131        let new_srsf = srsf_single(&new_mean, argvals);
132        let rel = relative_srsf_change(&old_srsf, &new_srsf);
133
134        current_mean = new_mean;
135
136        if rel < config.tol {
137            converged = true;
138            // Final alignment to converged median.
139            alignment_result = align_to_target(data, &current_mean, argvals, config.lambda);
140            break;
141        }
142
143        // Re-align to updated median.
144        alignment_result = align_to_target(data, &current_mean, argvals, config.lambda);
145    }
146
147    let mean_srsf = srsf_single(&current_mean, argvals);
148
149    Ok(RobustKarcherResult {
150        mean: current_mean,
151        mean_srsf,
152        gammas: alignment_result.gammas,
153        aligned_data: alignment_result.aligned_data,
154        weights,
155        n_iter,
156        converged,
157    })
158}
159
160/// Compute a trimmed Karcher mean by removing the most distant curves.
161///
162/// Computes the standard Karcher mean, identifies and removes the top
163/// `trim_fraction` of curves by elastic distance, then recomputes the
164/// Karcher mean on the remaining curves. All curves (including trimmed
165/// ones) are re-aligned to the robust mean for the final output.
166///
167/// # Arguments
168/// * `data`    — Functional data matrix (n x m).
169/// * `argvals` — Evaluation points (length m).
170/// * `config`  — Configuration parameters.
171///
172/// # Errors
173/// Returns [`FdarError::InvalidDimension`] if `argvals` length does not match `m`
174/// or `n < 2`.
175/// Returns [`FdarError::InvalidParameter`] if `trim_fraction` is not in \[0, 1).
176#[must_use = "expensive computation whose result should not be discarded"]
177pub fn robust_karcher_mean(
178    data: &FdMatrix,
179    argvals: &[f64],
180    config: &RobustKarcherConfig,
181) -> Result<RobustKarcherResult, FdarError> {
182    let (n, m) = data.shape();
183    validate_inputs(n, m, argvals)?;
184
185    if !(0.0..1.0).contains(&config.trim_fraction) {
186        return Err(FdarError::InvalidParameter {
187            parameter: "trim_fraction",
188            message: format!("must be in [0, 1), got {}", config.trim_fraction),
189        });
190    }
191
192    // Step 1: Compute standard Karcher mean.
193    let initial_mean = karcher_mean(data, argvals, config.max_iter, config.tol, config.lambda);
194
195    // Step 2: Compute elastic distances from the mean.
196    let distances: Vec<f64> = (0..n)
197        .map(|i| {
198            let fi = data.row(i);
199            elastic_distance(&initial_mean.mean, &fi, argvals, config.lambda)
200        })
201        .collect();
202
203    // Step 3: Sort by distance, identify curves to trim.
204    let n_trim = ((n as f64) * config.trim_fraction).ceil() as usize;
205    let n_keep = n.saturating_sub(n_trim).max(2); // Keep at least 2 curves.
206
207    let mut indexed_distances: Vec<(usize, f64)> =
208        distances.iter().enumerate().map(|(i, &d)| (i, d)).collect();
209    indexed_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
210
211    let kept_indices: Vec<usize> = indexed_distances
212        .iter()
213        .take(n_keep)
214        .map(|&(i, _)| i)
215        .collect();
216
217    // Step 4: Set weights.
218    let mut weights = vec![0.0; n];
219    for &idx in &kept_indices {
220        weights[idx] = 1.0;
221    }
222
223    // Step 5: Recompute Karcher mean on the kept subset.
224    let kept_data = subset_rows_from_indices(data, &kept_indices);
225    let robust_mean = karcher_mean(
226        &kept_data,
227        argvals,
228        config.max_iter,
229        config.tol,
230        config.lambda,
231    );
232
233    // Step 6: Re-align ALL curves (including trimmed) to the robust mean.
234    let final_alignment = align_to_target(data, &robust_mean.mean, argvals, config.lambda);
235
236    let mean_srsf = srsf_single(&robust_mean.mean, argvals);
237
238    Ok(RobustKarcherResult {
239        mean: robust_mean.mean,
240        mean_srsf,
241        gammas: final_alignment.gammas,
242        aligned_data: final_alignment.aligned_data,
243        weights,
244        n_iter: robust_mean.n_iter,
245        converged: robust_mean.converged,
246    })
247}
248
249/// Validate common input dimensions.
250fn validate_inputs(n: usize, m: usize, argvals: &[f64]) -> Result<(), FdarError> {
251    if argvals.len() != m {
252        return Err(FdarError::InvalidDimension {
253            parameter: "argvals",
254            expected: format!("{m}"),
255            actual: format!("{}", argvals.len()),
256        });
257    }
258    if n < 2 {
259        return Err(FdarError::InvalidDimension {
260            parameter: "data",
261            expected: "at least 2 rows".to_string(),
262            actual: format!("{n} rows"),
263        });
264    }
265    Ok(())
266}
267
/// Relative L2 change from `q_old` to `q_new`: `||q_old - q_new|| / ||q_old||`.
///
/// The denominator is floored at `1e-10` so an all-zero `q_old` does not
/// divide by zero. If the slices differ in length, the difference is taken
/// over the common prefix only.
fn relative_srsf_change(q_old: &[f64], q_new: &[f64]) -> f64 {
    let squared_diff: f64 = q_old
        .iter()
        .zip(q_new)
        .map(|(old, new)| (old - new).powi(2))
        .sum();
    let squared_old: f64 = q_old.iter().map(|v| v * v).sum();

    let numerator = squared_diff.sqrt();
    let denominator = squared_old.sqrt().max(1e-10);
    numerator / denominator
}
279
280/// Extract a subset of rows from an FdMatrix by index.
281fn subset_rows_from_indices(data: &FdMatrix, indices: &[usize]) -> FdMatrix {
282    let m = data.ncols();
283    let n_new = indices.len();
284    let mut result = FdMatrix::zeros(n_new, m);
285    for (new_i, &old_i) in indices.iter().enumerate() {
286        for j in 0..m {
287            result[(new_i, j)] = data[(old_i, j)];
288        }
289    }
290    result
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::uniform_grid;

    /// Euclidean distance between two sampled curves.
    fn pointwise_l2(a: &[f64], b: &[f64]) -> f64 {
        let sum_sq: f64 = a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum();
        sum_sq.sqrt()
    }

    /// `n` sine curves on an `m`-point grid, each shifted in phase by `0.03 * i`.
    fn make_sine_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        // Filled in column-major order: all curves at grid point 0, then 1, ...
        let mut data_vec = Vec::with_capacity(n * m);
        for &tj in t.iter().take(m) {
            for i in 0..n {
                let phase = 0.03 * i as f64;
                data_vec.push(((tj + phase) * 4.0).sin());
            }
        }
        let data = FdMatrix::from_column_major(data_vec, n, m).unwrap();
        (data, t)
    }

    /// Six curves on a 20-point grid: rows 0..5 are clean, slightly
    /// phase-shifted sines; row 5 is a high-frequency, high-amplitude outlier.
    fn make_outlier_data() -> (FdMatrix, Vec<f64>) {
        let (n, m) = (6, 20);
        let t = uniform_grid(m);
        let mut data_vec = Vec::with_capacity(n * m);
        for &tj in t.iter().take(m) {
            // 5 clean curves (slight phase shifts).
            for i in 0..n - 1 {
                let phase = 0.02 * i as f64;
                data_vec.push(((tj + phase) * 4.0).sin());
            }
            // 1 extreme outlier.
            data_vec.push((tj * 20.0).cos() * 5.0);
        }
        let data = FdMatrix::from_column_major(data_vec, n, m).unwrap();
        (data, t)
    }

    #[test]
    fn karcher_median_basic() {
        let (data, t) = make_sine_data(5, 20);
        let config = RobustKarcherConfig {
            max_iter: 5,
            ..Default::default()
        };
        let result = karcher_median(&data, &t, &config).unwrap();
        assert_eq!(result.mean.len(), 20);
        assert_eq!(result.mean_srsf.len(), 20);
        assert_eq!(result.gammas.shape(), (5, 20));
        assert_eq!(result.aligned_data.shape(), (5, 20));
        assert_eq!(result.weights.len(), 5);
        assert!(result.n_iter >= 1);
    }

    #[test]
    fn karcher_median_robust_to_outlier() {
        let (data, t) = make_outlier_data();

        // Standard mean and median on the contaminated sample.
        let std_mean = karcher_mean(&data, &t, 5, 1e-3, 0.0);
        let median_config = RobustKarcherConfig {
            max_iter: 5,
            ..Default::default()
        };
        let median_result = karcher_median(&data, &t, &median_config).unwrap();

        // Clean reference: mean of just the clean curves.
        let clean_data = subset_rows_from_indices(&data, &[0, 1, 2, 3, 4]);
        let clean_mean = karcher_mean(&clean_data, &t, 5, 1e-3, 0.0);

        // Median should be closer to the clean mean than the standard mean is.
        let d_std = pointwise_l2(&std_mean.mean, &clean_mean.mean);
        let d_median = pointwise_l2(&median_result.mean, &clean_mean.mean);
        assert!(
            d_median <= d_std + 1e-6,
            "median distance to clean ({d_median:.4}) should be <= standard mean distance ({d_std:.4})"
        );
    }

    #[test]
    fn robust_trimmed_removes_outliers() {
        let (data, t) = make_outlier_data();

        let config = RobustKarcherConfig {
            max_iter: 5,
            trim_fraction: 0.2, // ceil(6 * 0.2) = 2 curves trimmed.
            ..Default::default()
        };
        let result = robust_karcher_mean(&data, &t, &config).unwrap();

        // The outlier (index 5) should have weight 0.0.
        assert!(
            result.weights[5] < 1e-10,
            "outlier weight should be 0, got {}",
            result.weights[5]
        );

        // The remaining kept curves carry weight 1.0.
        let n_kept: usize = result.weights.iter().filter(|&&w| w > 0.5).count();
        assert!(n_kept >= 4, "should keep at least 4 curves, got {n_kept}");
    }

    #[test]
    fn robust_config_default() {
        let cfg = RobustKarcherConfig::default();
        assert_eq!(cfg.max_iter, 20);
        assert!((cfg.tol - 1e-3).abs() < f64::EPSILON);
        assert!((cfg.lambda - 0.0).abs() < f64::EPSILON);
        assert!((cfg.trim_fraction - 0.1).abs() < f64::EPSILON);
    }
}