1use super::karcher::karcher_mean;
12use super::pairwise::elastic_distance;
13use super::set::align_to_target;
14use super::srsf::srsf_single;
15use crate::error::FdarError;
16use crate::matrix::FdMatrix;
17
/// Settings shared by [`karcher_median`] and [`robust_karcher_mean`].
#[derive(Debug, Clone, PartialEq)]
pub struct RobustKarcherConfig {
    /// Iteration budget: caps the reweighting loop in [`karcher_median`] and is
    /// forwarded to `karcher_mean` by [`robust_karcher_mean`].
    pub max_iter: usize,
    /// Convergence tolerance: [`karcher_median`] stops once the relative SRSF
    /// change between successive estimates drops below this value; it is also
    /// forwarded to `karcher_mean`.
    pub tol: f64,
    /// Passed unchanged to `karcher_mean`, `elastic_distance` and
    /// `align_to_target`.
    // NOTE(review): presumably the warping roughness penalty — confirm against
    // those functions.
    pub lambda: f64,
    /// Fraction of curves dropped by [`robust_karcher_mean`]; must lie in
    /// `[0, 1)`. Not read by [`karcher_median`].
    pub trim_fraction: f64,
}
30
31impl Default for RobustKarcherConfig {
32 fn default() -> Self {
33 Self {
34 max_iter: 20,
35 tol: 1e-3,
36 lambda: 0.0,
37 trim_fraction: 0.1,
38 }
39 }
40}
41
/// Output of [`karcher_median`] and [`robust_karcher_mean`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct RobustKarcherResult {
    /// Final robust centre estimate, sampled on the caller's `argvals`.
    pub mean: Vec<f64>,
    /// `srsf_single` applied to `mean`.
    pub mean_srsf: Vec<f64>,
    /// `gammas` produced by `align_to_target` when every input curve
    /// (including trimmed ones) is aligned to `mean`.
    // NOTE(review): presumably one warping function per row — confirm against
    // `align_to_target`.
    pub gammas: FdMatrix,
    /// `aligned_data` produced by that same final alignment.
    pub aligned_data: FdMatrix,
    /// One entry per input curve. [`karcher_median`] stores the normalised
    /// inverse-distance weights from its last reweighting pass (uniform `1/n`
    /// if no pass ran); [`robust_karcher_mean`] stores `1.0` for kept curves
    /// and `0.0` for trimmed ones.
    pub weights: Vec<f64>,
    /// Iterations performed: the reweighting passes run by [`karcher_median`],
    /// or the count reported by `karcher_mean` on the kept curves.
    pub n_iter: usize,
    /// Whether the tolerance was met before the iteration budget ran out.
    pub converged: bool,
}
61
62#[must_use = "expensive computation whose result should not be discarded"]
86pub fn karcher_median(
87 data: &FdMatrix,
88 argvals: &[f64],
89 config: &RobustKarcherConfig,
90) -> Result<RobustKarcherResult, FdarError> {
91 let (n, m) = data.shape();
92 validate_inputs(n, m, argvals)?;
93
94 let init = karcher_mean(data, argvals, 1, config.tol, config.lambda);
96 let mut current_mean = init.mean;
97
98 let mut converged = false;
99 let mut n_iter = 0;
100 let mut weights = vec![1.0 / n as f64; n];
101 let mut alignment_result = align_to_target(data, ¤t_mean, argvals, config.lambda);
102
103 for iter in 0..config.max_iter {
105 n_iter = iter + 1;
106
107 let distances: Vec<f64> = (0..n)
109 .map(|i| {
110 let fi = data.row(i);
111 elastic_distance(¤t_mean, &fi, argvals, config.lambda)
112 })
113 .collect();
114
115 let epsilon = 1e-10;
117 let raw_weights: Vec<f64> = distances.iter().map(|&d| 1.0 / d.max(epsilon)).collect();
118 let w_sum: f64 = raw_weights.iter().sum();
119 weights = raw_weights.iter().map(|&w| w / w_sum).collect();
120
121 let mut new_mean = vec![0.0; m];
123 for i in 0..n {
124 for j in 0..m {
125 new_mean[j] += weights[i] * alignment_result.aligned_data[(i, j)];
126 }
127 }
128
129 let old_srsf = srsf_single(¤t_mean, argvals);
131 let new_srsf = srsf_single(&new_mean, argvals);
132 let rel = relative_srsf_change(&old_srsf, &new_srsf);
133
134 current_mean = new_mean;
135
136 if rel < config.tol {
137 converged = true;
138 alignment_result = align_to_target(data, ¤t_mean, argvals, config.lambda);
140 break;
141 }
142
143 alignment_result = align_to_target(data, ¤t_mean, argvals, config.lambda);
145 }
146
147 let mean_srsf = srsf_single(¤t_mean, argvals);
148
149 Ok(RobustKarcherResult {
150 mean: current_mean,
151 mean_srsf,
152 gammas: alignment_result.gammas,
153 aligned_data: alignment_result.aligned_data,
154 weights,
155 n_iter,
156 converged,
157 })
158}
159
160#[must_use = "expensive computation whose result should not be discarded"]
177pub fn robust_karcher_mean(
178 data: &FdMatrix,
179 argvals: &[f64],
180 config: &RobustKarcherConfig,
181) -> Result<RobustKarcherResult, FdarError> {
182 let (n, m) = data.shape();
183 validate_inputs(n, m, argvals)?;
184
185 if !(0.0..1.0).contains(&config.trim_fraction) {
186 return Err(FdarError::InvalidParameter {
187 parameter: "trim_fraction",
188 message: format!("must be in [0, 1), got {}", config.trim_fraction),
189 });
190 }
191
192 let initial_mean = karcher_mean(data, argvals, config.max_iter, config.tol, config.lambda);
194
195 let distances: Vec<f64> = (0..n)
197 .map(|i| {
198 let fi = data.row(i);
199 elastic_distance(&initial_mean.mean, &fi, argvals, config.lambda)
200 })
201 .collect();
202
203 let n_trim = ((n as f64) * config.trim_fraction).ceil() as usize;
205 let n_keep = n.saturating_sub(n_trim).max(2); let mut indexed_distances: Vec<(usize, f64)> =
208 distances.iter().enumerate().map(|(i, &d)| (i, d)).collect();
209 indexed_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
210
211 let kept_indices: Vec<usize> = indexed_distances
212 .iter()
213 .take(n_keep)
214 .map(|&(i, _)| i)
215 .collect();
216
217 let mut weights = vec![0.0; n];
219 for &idx in &kept_indices {
220 weights[idx] = 1.0;
221 }
222
223 let kept_data = subset_rows_from_indices(data, &kept_indices);
225 let robust_mean = karcher_mean(
226 &kept_data,
227 argvals,
228 config.max_iter,
229 config.tol,
230 config.lambda,
231 );
232
233 let final_alignment = align_to_target(data, &robust_mean.mean, argvals, config.lambda);
235
236 let mean_srsf = srsf_single(&robust_mean.mean, argvals);
237
238 Ok(RobustKarcherResult {
239 mean: robust_mean.mean,
240 mean_srsf,
241 gammas: final_alignment.gammas,
242 aligned_data: final_alignment.aligned_data,
243 weights,
244 n_iter: robust_mean.n_iter,
245 converged: robust_mean.converged,
246 })
247}
248
249fn validate_inputs(n: usize, m: usize, argvals: &[f64]) -> Result<(), FdarError> {
251 if argvals.len() != m {
252 return Err(FdarError::InvalidDimension {
253 parameter: "argvals",
254 expected: format!("{m}"),
255 actual: format!("{}", argvals.len()),
256 });
257 }
258 if n < 2 {
259 return Err(FdarError::InvalidDimension {
260 parameter: "data",
261 expected: "at least 2 rows".to_string(),
262 actual: format!("{n} rows"),
263 });
264 }
265 Ok(())
266}
267
/// Relative Euclidean change from `q_old` to `q_new`:
/// `‖q_old − q_new‖₂ / max(‖q_old‖₂, 1e-10)`.
///
/// The floor on the denominator keeps the ratio finite when `q_old` is
/// (numerically) zero. If the slices differ in length, the difference is taken
/// over the common prefix only, while the denominator uses all of `q_old`.
fn relative_srsf_change(q_old: &[f64], q_new: &[f64]) -> f64 {
    const NORM_FLOOR: f64 = 1e-10;

    let squared_gaps = q_old.iter().zip(q_new).map(|(old, new)| (old - new).powi(2));
    let squared_olds = q_old.iter().map(|v| v * v);

    let numerator = squared_gaps.sum::<f64>().sqrt();
    let denominator = squared_olds.sum::<f64>().sqrt().max(NORM_FLOOR);
    numerator / denominator
}
279
280use crate::cv::subset_rows as subset_rows_from_indices;
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::uniform_grid;

    /// Euclidean distance between two sampled curves.
    fn pointwise_l2(a: &[f64], b: &[f64]) -> f64 {
        let squared: f64 = a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum();
        squared.sqrt()
    }

    /// Packs `n` curves into a column-major matrix; `value(i, t_j)` gives
    /// curve `i` evaluated at grid point `t_j`.
    fn build_curves(n: usize, grid: &[f64], value: impl Fn(usize, f64) -> f64) -> FdMatrix {
        let m = grid.len();
        let mut buf = Vec::with_capacity(n * m);
        // Column-major layout: all curves at grid point 0, then grid point 1, ...
        for &tj in grid {
            for i in 0..n {
                buf.push(value(i, tj));
            }
        }
        FdMatrix::from_column_major(buf, n, m).unwrap()
    }

    /// `n` sine curves on an `m`-point grid, each phase-shifted by `0.03 * i`.
    fn make_sine_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let data = build_curves(n, &t[..m], |i, tj| ((tj + 0.03 * i as f64) * 4.0).sin());
        (data, t)
    }

    /// Five gently shifted sine curves (rows 0-4) plus one large,
    /// high-frequency cosine outlier (row 5), all on a 20-point grid.
    fn make_outlier_data() -> (FdMatrix, Vec<f64>) {
        let (n, m) = (6, 20);
        let t = uniform_grid(m);
        let data = build_curves(n, &t[..m], |i, tj| {
            if i < 5 {
                ((tj + 0.02 * i as f64) * 4.0).sin()
            } else {
                (tj * 20.0).cos() * 5.0
            }
        });
        (data, t)
    }

    #[test]
    fn karcher_median_basic() {
        let (data, t) = make_sine_data(5, 20);
        let config = RobustKarcherConfig {
            max_iter: 5,
            ..Default::default()
        };

        let result = karcher_median(&data, &t, &config).unwrap();

        assert_eq!(result.mean.len(), 20);
        assert_eq!(result.mean_srsf.len(), 20);
        assert_eq!(result.gammas.shape(), (5, 20));
        assert_eq!(result.aligned_data.shape(), (5, 20));
        assert_eq!(result.weights.len(), 5);
        assert!(result.n_iter >= 1);
    }

    #[test]
    fn karcher_median_robust_to_outlier() {
        let (data, t) = make_outlier_data();

        // Reference: the ordinary mean of the five clean curves only.
        let clean_data = subset_rows_from_indices(&data, &[0, 1, 2, 3, 4]);
        let clean_mean = karcher_mean(&clean_data, &t, 5, 1e-3, 0.0);

        let std_mean = karcher_mean(&data, &t, 5, 1e-3, 0.0);
        let median_config = RobustKarcherConfig {
            max_iter: 5,
            ..Default::default()
        };
        let median_result = karcher_median(&data, &t, &median_config).unwrap();

        let d_std = pointwise_l2(&std_mean.mean, &clean_mean.mean);
        let d_median = pointwise_l2(&median_result.mean, &clean_mean.mean);
        assert!(
            d_median <= d_std + 1e-6,
            "median distance to clean ({d_median:.4}) should be <= standard mean distance ({d_std:.4})"
        );
    }

    #[test]
    fn robust_trimmed_removes_outliers() {
        let (data, t) = make_outlier_data();
        let config = RobustKarcherConfig {
            max_iter: 5,
            trim_fraction: 0.2,
            ..Default::default()
        };

        let result = robust_karcher_mean(&data, &t, &config).unwrap();

        assert!(
            result.weights[5] < 1e-10,
            "outlier weight should be 0, got {}",
            result.weights[5]
        );

        let n_kept: usize = result.weights.iter().filter(|&&w| w > 0.5).count();
        assert!(n_kept >= 4, "should keep at least 4 curves, got {n_kept}");
    }

    #[test]
    fn robust_config_default() {
        let RobustKarcherConfig {
            max_iter,
            tol,
            lambda,
            trim_fraction,
        } = RobustKarcherConfig::default();

        assert_eq!(max_iter, 20);
        assert!((tol - 1e-3).abs() < f64::EPSILON);
        assert!((lambda - 0.0).abs() < f64::EPSILON);
        assert!((trim_fraction - 0.1).abs() < f64::EPSILON);
    }
}