fdars-core 0.13.0

Functional Data Analysis algorithms in Rust
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
//! Robust alternatives to the Karcher mean: median and trimmed mean.
//!
//! The standard Karcher mean is sensitive to outlier curves. This module
//! provides two robust alternatives:
//!
//! - [`karcher_median`] — Geometric median via iteratively reweighted
//!   Karcher mean (Weiszfeld algorithm on the elastic manifold).
//! - [`robust_karcher_mean`] — Trimmed Karcher mean that removes the
//!   most distant curves before averaging.

use super::karcher::karcher_mean;
use super::pairwise::elastic_distance;
use super::set::align_to_target;
use super::srsf::srsf_single;
use crate::error::FdarError;
use crate::matrix::FdMatrix;

/// Configuration for robust Karcher estimation.
///
/// Shared by [`karcher_median`] and [`robust_karcher_mean`];
/// `trim_fraction` is only consulted by the trimmed mean.
#[derive(Debug, Clone, PartialEq)]
pub struct RobustKarcherConfig {
    /// Maximum number of outer iterations.
    pub max_iter: usize,
    /// Convergence tolerance (relative change in SRSF).
    pub tol: f64,
    /// Roughness penalty for elastic alignment (0.0 = no penalty).
    pub lambda: f64,
    /// Fraction of most-distant curves to trim (for trimmed mean).
    /// Must lie in [0, 1); validated by [`robust_karcher_mean`].
    pub trim_fraction: f64,
}

impl Default for RobustKarcherConfig {
    fn default() -> Self {
        Self {
            max_iter: 20,
            tol: 1e-3,
            lambda: 0.0,
            trim_fraction: 0.1,
        }
    }
}

/// Result of robust Karcher estimation.
///
/// `n` is the number of curves and `m` the number of grid points of the
/// input data matrix.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct RobustKarcherResult {
    /// Robust mean/median curve (length m).
    pub mean: Vec<f64>,
    /// SRSF of the robust mean/median (length m).
    pub mean_srsf: Vec<f64>,
    /// Warping functions for all curves (n x m).
    pub gammas: FdMatrix,
    /// All curves aligned to the robust mean/median (n x m).
    pub aligned_data: FdMatrix,
    /// Per-curve weights: normalized 1/distance for the median,
    /// 0/1 keep-flags for the trimmed mean.
    pub weights: Vec<f64>,
    /// Number of iterations performed.
    pub n_iter: usize,
    /// Whether the algorithm converged.
    pub converged: bool,
}

/// Compute the Karcher median via the Weiszfeld algorithm on the elastic manifold.
///
/// The geometric median minimizes the sum of elastic distances to all curves,
/// rather than the sum of squared distances (as with the mean). This makes it
/// robust to outlier curves.
///
/// # Algorithm
/// 1. Initialize with standard Karcher mean (1 iteration) as starting point.
/// 2. Iterative Weiszfeld loop:
///    a. Compute elastic distances from the current median estimate.
///    b. Set weights w_i = 1 / max(d_i, epsilon), normalize.
///    c. Compute weighted pointwise mean of the aligned curves.
///    d. Re-align all curves to the updated estimate.
///    e. Check convergence (relative change in SRSF).
///
/// # Arguments
/// * `data`    — Functional data matrix (n x m).
/// * `argvals` — Evaluation points (length m).
/// * `config`  — Configuration parameters.
///
/// # Errors
/// Returns [`FdarError::InvalidDimension`] if `argvals` length does not match `m`
/// or `n < 2`.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn karcher_median(
    data: &FdMatrix,
    argvals: &[f64],
    config: &RobustKarcherConfig,
) -> Result<RobustKarcherResult, FdarError> {
    let (n, m) = data.shape();
    validate_inputs(n, m, argvals)?;

    // Floor on distances so the Weiszfeld weights 1/d stay finite when a
    // curve coincides with the current median estimate.
    const EPSILON: f64 = 1e-10;

    // Step 1: Initialize with a quick Karcher mean (1 iteration).
    let init = karcher_mean(data, argvals, 1, config.tol, config.lambda);
    let mut current_mean = init.mean;

    let mut converged = false;
    let mut n_iter = 0;
    let mut weights = vec![1.0 / n as f64; n];
    let mut alignment_result = align_to_target(data, &current_mean, argvals, config.lambda);

    // Step 2: Weiszfeld iterations.
    for iter in 0..config.max_iter {
        n_iter = iter + 1;

        // Elastic distances from the current median estimate.
        let distances: Vec<f64> = (0..n)
            .map(|i| {
                let fi = data.row(i);
                elastic_distance(&current_mean, &fi, argvals, config.lambda)
            })
            .collect();

        // Weights: w_i = 1 / max(d_i, EPSILON), normalized to sum to 1.
        let raw_weights: Vec<f64> = distances.iter().map(|&d| 1.0 / d.max(EPSILON)).collect();
        let w_sum: f64 = raw_weights.iter().sum();
        weights = raw_weights.iter().map(|&w| w / w_sum).collect();

        // Weighted pointwise mean of the curves aligned to the current estimate.
        let mut new_mean = vec![0.0; m];
        for i in 0..n {
            for j in 0..m {
                new_mean[j] += weights[i] * alignment_result.aligned_data[(i, j)];
            }
        }

        // Convergence criterion: relative change between successive SRSFs.
        let old_srsf = srsf_single(&current_mean, argvals);
        let new_srsf = srsf_single(&new_mean, argvals);
        let rel = relative_srsf_change(&old_srsf, &new_srsf);

        current_mean = new_mean;

        // Re-align to the updated median. Doing this before the convergence
        // check means the alignment is already final when we break, avoiding
        // the duplicated align_to_target call a separate converged branch
        // would need.
        alignment_result = align_to_target(data, &current_mean, argvals, config.lambda);

        if rel < config.tol {
            converged = true;
            break;
        }
    }

    let mean_srsf = srsf_single(&current_mean, argvals);

    Ok(RobustKarcherResult {
        mean: current_mean,
        mean_srsf,
        gammas: alignment_result.gammas,
        aligned_data: alignment_result.aligned_data,
        weights,
        n_iter,
        converged,
    })
}

/// Compute a trimmed Karcher mean by removing the most distant curves.
///
/// Computes the standard Karcher mean, identifies and removes the top
/// `trim_fraction` of curves by elastic distance, then recomputes the
/// Karcher mean on the remaining curves. All curves (including trimmed
/// ones) are re-aligned to the robust mean for the final output.
///
/// # Arguments
/// * `data`    — Functional data matrix (n x m).
/// * `argvals` — Evaluation points (length m).
/// * `config`  — Configuration parameters.
///
/// # Errors
/// Returns [`FdarError::InvalidDimension`] if `argvals` length does not match `m`
/// or `n < 2`.
/// Returns [`FdarError::InvalidParameter`] if `trim_fraction` is not in \[0, 1).
#[must_use = "expensive computation whose result should not be discarded"]
pub fn robust_karcher_mean(
    data: &FdMatrix,
    argvals: &[f64],
    config: &RobustKarcherConfig,
) -> Result<RobustKarcherResult, FdarError> {
    let (n, m) = data.shape();
    validate_inputs(n, m, argvals)?;

    if !(0.0..1.0).contains(&config.trim_fraction) {
        return Err(FdarError::InvalidParameter {
            parameter: "trim_fraction",
            message: format!("must be in [0, 1), got {}", config.trim_fraction),
        });
    }

    // Step 1: Compute standard Karcher mean.
    let initial_mean = karcher_mean(data, argvals, config.max_iter, config.tol, config.lambda);

    // Step 2: Compute elastic distances from the mean.
    let distances: Vec<f64> = (0..n)
        .map(|i| {
            let fi = data.row(i);
            elastic_distance(&initial_mean.mean, &fi, argvals, config.lambda)
        })
        .collect();

    // Step 3: Sort by distance (ascending), identify curves to trim.
    // `ceil` guarantees at least one curve is trimmed for any positive fraction.
    let n_trim = ((n as f64) * config.trim_fraction).ceil() as usize;
    let n_keep = n.saturating_sub(n_trim).max(2); // Keep at least 2 curves.

    let mut indexed_distances: Vec<(usize, f64)> = distances.iter().copied().enumerate().collect();
    // `total_cmp` is a NaN-safe total order (NaN sorts last); the previous
    // partial_cmp + unwrap_or(Equal) left NaN positions unspecified, which
    // could keep a degenerate curve while trimming a valid one.
    indexed_distances.sort_by(|a, b| a.1.total_cmp(&b.1));

    // The n_keep closest curves survive the trim.
    let kept_indices: Vec<usize> = indexed_distances
        .iter()
        .take(n_keep)
        .map(|&(i, _)| i)
        .collect();

    // Step 4: Set 0/1 keep-flag weights.
    let mut weights = vec![0.0; n];
    for &idx in &kept_indices {
        weights[idx] = 1.0;
    }

    // Step 5: Recompute Karcher mean on the kept subset.
    let kept_data = subset_rows_from_indices(data, &kept_indices);
    let robust_mean = karcher_mean(
        &kept_data,
        argvals,
        config.max_iter,
        config.tol,
        config.lambda,
    );

    // Step 6: Re-align ALL curves (including trimmed) to the robust mean,
    // so the output matrices keep the original row count and order.
    let final_alignment = align_to_target(data, &robust_mean.mean, argvals, config.lambda);

    let mean_srsf = srsf_single(&robust_mean.mean, argvals);

    Ok(RobustKarcherResult {
        mean: robust_mean.mean,
        mean_srsf,
        gammas: final_alignment.gammas,
        aligned_data: final_alignment.aligned_data,
        weights,
        n_iter: robust_mean.n_iter,
        converged: robust_mean.converged,
    })
}

/// Validate common input dimensions: `argvals` must match the grid size `m`,
/// and at least two curves are required.
fn validate_inputs(n: usize, m: usize, argvals: &[f64]) -> Result<(), FdarError> {
    if argvals.len() != m {
        Err(FdarError::InvalidDimension {
            parameter: "argvals",
            expected: format!("{m}"),
            actual: format!("{}", argvals.len()),
        })
    } else if n < 2 {
        Err(FdarError::InvalidDimension {
            parameter: "data",
            expected: "at least 2 rows".to_string(),
            actual: format!("{n} rows"),
        })
    } else {
        Ok(())
    }
}

/// Relative L2 change between successive SRSF vectors.
///
/// Returns ||q_new - q_old|| / max(||q_old||, 1e-10); the floor on the
/// denominator guards against a zero-norm previous iterate.
fn relative_srsf_change(q_old: &[f64], q_new: &[f64]) -> f64 {
    let mut diff_sq = 0.0_f64;
    for (&a, &b) in q_old.iter().zip(q_new.iter()) {
        let d = a - b;
        diff_sq += d * d;
    }
    let old_sq: f64 = q_old.iter().map(|&v| v * v).sum();
    diff_sq.sqrt() / old_sq.sqrt().max(1e-10)
}

use crate::cv::subset_rows as subset_rows_from_indices;

#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::uniform_grid;

    /// Build an n x m matrix of phase-shifted sine curves on a uniform grid.
    /// Data is filled in column-major order: entry (i, j) lives at i + j * n.
    fn make_sine_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let mut data_vec = vec![0.0; n * m];
        for i in 0..n {
            // Small per-curve phase shift so curves differ but stay similar.
            let phase = 0.03 * i as f64;
            for j in 0..m {
                data_vec[i + j * n] = ((t[j] + phase) * 4.0).sin();
            }
        }
        let data = FdMatrix::from_column_major(data_vec, n, m).unwrap();
        (data, t)
    }

    // Smoke test: output shapes and counts match the 5 x 20 input.
    #[test]
    fn karcher_median_basic() {
        let (data, t) = make_sine_data(5, 20);
        let config = RobustKarcherConfig {
            max_iter: 5,
            ..Default::default()
        };
        let result = karcher_median(&data, &t, &config).unwrap();
        assert_eq!(result.mean.len(), 20);
        assert_eq!(result.mean_srsf.len(), 20);
        assert_eq!(result.gammas.shape(), (5, 20));
        assert_eq!(result.aligned_data.shape(), (5, 20));
        assert_eq!(result.weights.len(), 5);
        assert!(result.n_iter >= 1);
    }

    // The median should be pulled less toward an extreme outlier than the
    // standard Karcher mean is.
    #[test]
    fn karcher_median_robust_to_outlier() {
        let m = 20;
        let t = uniform_grid(m);
        let n = 6;
        let mut data_vec = vec![0.0; n * m];

        // 5 clean curves (slight phase shifts), column-major layout.
        for i in 0..5 {
            let phase = 0.02 * i as f64;
            for j in 0..m {
                data_vec[i + j * n] = ((t[j] + phase) * 4.0).sin();
            }
        }
        // 1 extreme outlier: high-frequency, large-amplitude cosine.
        for j in 0..m {
            data_vec[5 + j * n] = (t[j] * 20.0).cos() * 5.0;
        }
        let data = FdMatrix::from_column_major(data_vec, n, m).unwrap();

        // Compute standard mean and median.
        let std_mean = karcher_mean(&data, &t, 5, 1e-3, 0.0);
        let median_config = RobustKarcherConfig {
            max_iter: 5,
            ..Default::default()
        };
        let median_result = karcher_median(&data, &t, &median_config).unwrap();

        // Compute a clean reference (mean of just the clean curves).
        let clean_data = subset_rows_from_indices(&data, &[0, 1, 2, 3, 4]);
        let clean_mean = karcher_mean(&clean_data, &t, 5, 1e-3, 0.0);

        // Median should be closer to the clean mean than the standard mean is.
        let d_std = pointwise_l2(&std_mean.mean, &clean_mean.mean);
        let d_median = pointwise_l2(&median_result.mean, &clean_mean.mean);
        assert!(
            d_median <= d_std + 1e-6,
            "median distance to clean ({d_median:.4}) should be <= standard mean distance ({d_std:.4})"
        );
    }

    // The trimmed mean should assign weight 0 to the most distant curve
    // while keeping the bulk of the clean curves.
    #[test]
    fn robust_trimmed_removes_outliers() {
        let m = 20;
        let t = uniform_grid(m);
        let n = 6;
        let mut data_vec = vec![0.0; n * m];

        // 5 clean curves, column-major layout.
        for i in 0..5 {
            let phase = 0.02 * i as f64;
            for j in 0..m {
                data_vec[i + j * n] = ((t[j] + phase) * 4.0).sin();
            }
        }
        // 1 extreme outlier (row index 5).
        for j in 0..m {
            data_vec[5 + j * n] = (t[j] * 20.0).cos() * 5.0;
        }
        let data = FdMatrix::from_column_major(data_vec, n, m).unwrap();

        let config = RobustKarcherConfig {
            max_iter: 5,
            trim_fraction: 0.2, // Trim top 20% (= 2 curves out of 6).
            ..Default::default()
        };
        let result = robust_karcher_mean(&data, &t, &config).unwrap();

        // The outlier (index 5) should have weight 0.0.
        assert!(
            result.weights[5] < 1e-10,
            "outlier weight should be 0, got {}",
            result.weights[5]
        );

        // At least some curves should have weight 1.0.
        let n_kept: usize = result.weights.iter().filter(|&&w| w > 0.5).count();
        assert!(n_kept >= 4, "should keep at least 4 curves, got {n_kept}");
    }

    // Pin the documented default configuration values.
    #[test]
    fn robust_config_default() {
        let cfg = RobustKarcherConfig::default();
        assert_eq!(cfg.max_iter, 20);
        assert!((cfg.tol - 1e-3).abs() < f64::EPSILON);
        assert!((cfg.lambda - 0.0).abs() < f64::EPSILON);
        assert!((cfg.trim_fraction - 0.1).abs() < f64::EPSILON);
    }

    /// Simple pointwise L2 distance between two curves.
    fn pointwise_l2(a: &[f64], b: &[f64]) -> f64 {
        a.iter()
            .zip(b.iter())
            .map(|(&x, &y)| (x - y).powi(2))
            .sum::<f64>()
            .sqrt()
    }
}