1use super::karcher::karcher_mean;
12use super::pairwise::elastic_distance;
13use super::set::align_to_target;
14use super::srsf::srsf_single;
15use crate::error::FdarError;
16use crate::matrix::FdMatrix;
17
/// Tuning parameters shared by [`karcher_median`] and [`robust_karcher_mean`].
#[derive(Debug, Clone, PartialEq)]
pub struct RobustKarcherConfig {
    /// Iteration cap: bounds the reweighting loop in [`karcher_median`] and is
    /// forwarded to `karcher_mean` by [`robust_karcher_mean`].
    pub max_iter: usize,
    /// Convergence tolerance. [`karcher_median`] stops once the relative SRSF
    /// change of its template falls below this value; it is also forwarded to
    /// `karcher_mean`.
    pub tol: f64,
    /// Passed unchanged to the alignment, distance and mean routines.
    /// NOTE(review): presumably a warping-roughness penalty weight — confirm
    /// against `align_to_target` / `elastic_distance`.
    pub lambda: f64,
    /// Fraction of curves discarded by [`robust_karcher_mean`]; must lie in
    /// `[0, 1)`. Not read by [`karcher_median`].
    pub trim_fraction: f64,
}
30
31impl Default for RobustKarcherConfig {
32 fn default() -> Self {
33 Self {
34 max_iter: 20,
35 tol: 1e-3,
36 lambda: 0.0,
37 trim_fraction: 0.1,
38 }
39 }
40}
41
/// Output of [`karcher_median`] and [`robust_karcher_mean`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct RobustKarcherResult {
    /// Final robust template curve, sampled on the caller's `argvals`.
    pub mean: Vec<f64>,
    /// Result of `srsf_single` applied to [`mean`](Self::mean).
    pub mean_srsf: Vec<f64>,
    /// `gammas` field of the final `align_to_target` call over all input curves.
    /// NOTE(review): presumably one warping function per row — confirm against
    /// `align_to_target`.
    pub gammas: FdMatrix,
    /// `aligned_data` field of the same final `align_to_target` call.
    pub aligned_data: FdMatrix,
    /// One weight per input curve. [`karcher_median`] reports the normalised
    /// inverse-distance weights of its last iteration; [`robust_karcher_mean`]
    /// reports `1.0` for kept curves and `0.0` for trimmed ones.
    pub weights: Vec<f64>,
    /// Number of iterations performed.
    pub n_iter: usize,
    /// Whether the convergence criterion was met before the iteration cap.
    pub converged: bool,
}
61
62#[must_use = "expensive computation whose result should not be discarded"]
86pub fn karcher_median(
87 data: &FdMatrix,
88 argvals: &[f64],
89 config: &RobustKarcherConfig,
90) -> Result<RobustKarcherResult, FdarError> {
91 let (n, m) = data.shape();
92 validate_inputs(n, m, argvals)?;
93
94 let init = karcher_mean(data, argvals, 1, config.tol, config.lambda);
96 let mut current_mean = init.mean;
97
98 let mut converged = false;
99 let mut n_iter = 0;
100 let mut weights = vec![1.0 / n as f64; n];
101 let mut alignment_result = align_to_target(data, ¤t_mean, argvals, config.lambda);
102
103 for iter in 0..config.max_iter {
105 n_iter = iter + 1;
106
107 let distances: Vec<f64> = (0..n)
109 .map(|i| {
110 let fi = data.row(i);
111 elastic_distance(¤t_mean, &fi, argvals, config.lambda)
112 })
113 .collect();
114
115 let epsilon = 1e-10;
117 let raw_weights: Vec<f64> = distances.iter().map(|&d| 1.0 / d.max(epsilon)).collect();
118 let w_sum: f64 = raw_weights.iter().sum();
119 weights = raw_weights.iter().map(|&w| w / w_sum).collect();
120
121 let mut new_mean = vec![0.0; m];
123 for i in 0..n {
124 for j in 0..m {
125 new_mean[j] += weights[i] * alignment_result.aligned_data[(i, j)];
126 }
127 }
128
129 let old_srsf = srsf_single(¤t_mean, argvals);
131 let new_srsf = srsf_single(&new_mean, argvals);
132 let rel = relative_srsf_change(&old_srsf, &new_srsf);
133
134 current_mean = new_mean;
135
136 if rel < config.tol {
137 converged = true;
138 alignment_result = align_to_target(data, ¤t_mean, argvals, config.lambda);
140 break;
141 }
142
143 alignment_result = align_to_target(data, ¤t_mean, argvals, config.lambda);
145 }
146
147 let mean_srsf = srsf_single(¤t_mean, argvals);
148
149 Ok(RobustKarcherResult {
150 mean: current_mean,
151 mean_srsf,
152 gammas: alignment_result.gammas,
153 aligned_data: alignment_result.aligned_data,
154 weights,
155 n_iter,
156 converged,
157 })
158}
159
160#[must_use = "expensive computation whose result should not be discarded"]
177pub fn robust_karcher_mean(
178 data: &FdMatrix,
179 argvals: &[f64],
180 config: &RobustKarcherConfig,
181) -> Result<RobustKarcherResult, FdarError> {
182 let (n, m) = data.shape();
183 validate_inputs(n, m, argvals)?;
184
185 if !(0.0..1.0).contains(&config.trim_fraction) {
186 return Err(FdarError::InvalidParameter {
187 parameter: "trim_fraction",
188 message: format!("must be in [0, 1), got {}", config.trim_fraction),
189 });
190 }
191
192 let initial_mean = karcher_mean(data, argvals, config.max_iter, config.tol, config.lambda);
194
195 let distances: Vec<f64> = (0..n)
197 .map(|i| {
198 let fi = data.row(i);
199 elastic_distance(&initial_mean.mean, &fi, argvals, config.lambda)
200 })
201 .collect();
202
203 let n_trim = ((n as f64) * config.trim_fraction).ceil() as usize;
205 let n_keep = n.saturating_sub(n_trim).max(2); let mut indexed_distances: Vec<(usize, f64)> =
208 distances.iter().enumerate().map(|(i, &d)| (i, d)).collect();
209 indexed_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
210
211 let kept_indices: Vec<usize> = indexed_distances
212 .iter()
213 .take(n_keep)
214 .map(|&(i, _)| i)
215 .collect();
216
217 let mut weights = vec![0.0; n];
219 for &idx in &kept_indices {
220 weights[idx] = 1.0;
221 }
222
223 let kept_data = subset_rows_from_indices(data, &kept_indices);
225 let robust_mean = karcher_mean(
226 &kept_data,
227 argvals,
228 config.max_iter,
229 config.tol,
230 config.lambda,
231 );
232
233 let final_alignment = align_to_target(data, &robust_mean.mean, argvals, config.lambda);
235
236 let mean_srsf = srsf_single(&robust_mean.mean, argvals);
237
238 Ok(RobustKarcherResult {
239 mean: robust_mean.mean,
240 mean_srsf,
241 gammas: final_alignment.gammas,
242 aligned_data: final_alignment.aligned_data,
243 weights,
244 n_iter: robust_mean.n_iter,
245 converged: robust_mean.converged,
246 })
247}
248
249fn validate_inputs(n: usize, m: usize, argvals: &[f64]) -> Result<(), FdarError> {
251 if argvals.len() != m {
252 return Err(FdarError::InvalidDimension {
253 parameter: "argvals",
254 expected: format!("{m}"),
255 actual: format!("{}", argvals.len()),
256 });
257 }
258 if n < 2 {
259 return Err(FdarError::InvalidDimension {
260 parameter: "data",
261 expected: "at least 2 rows".to_string(),
262 actual: format!("{n} rows"),
263 });
264 }
265 Ok(())
266}
267
/// Relative Euclidean change between two SRSF vectors:
/// `‖q_old − q_new‖₂ / max(‖q_old‖₂, 1e-10)`.
///
/// The floor on the denominator avoids division by zero when `q_old` is
/// (numerically) zero. If the slices differ in length, only the common prefix
/// contributes to the numerator.
fn relative_srsf_change(q_old: &[f64], q_new: &[f64]) -> f64 {
    const NORM_FLOOR: f64 = 1e-10;

    let squared_gap: f64 = q_old
        .iter()
        .zip(q_new)
        .map(|(prev, next)| (prev - next).powi(2))
        .sum();
    let squared_reference: f64 = q_old.iter().map(|v| v * v).sum();

    squared_gap.sqrt() / squared_reference.sqrt().max(NORM_FLOOR)
}
279
280fn subset_rows_from_indices(data: &FdMatrix, indices: &[usize]) -> FdMatrix {
282 let m = data.ncols();
283 let n_new = indices.len();
284 let mut result = FdMatrix::zeros(n_new, m);
285 for (new_i, &old_i) in indices.iter().enumerate() {
286 for j in 0..m {
287 result[(new_i, j)] = data[(old_i, j)];
288 }
289 }
290 result
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::uniform_grid;

    /// Builds `n` sine curves on a uniform grid of `m` points; curve `i` is
    /// shifted by a phase of `0.03 * i`. Returns the data and the grid.
    fn make_sine_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let grid = uniform_grid(m);
        let mut values = vec![0.0; n * m];
        for j in 0..m {
            for i in 0..n {
                let phase = 0.03 * i as f64;
                // Column-major layout: element (i, j) lives at `i + j * n`.
                values[i + j * n] = ((grid[j] + phase) * 4.0).sin();
            }
        }
        let data = FdMatrix::from_column_major(values, n, m).unwrap();
        (data, grid)
    }

    /// Builds six curves on a uniform grid of `m` points: rows 0–4 are sines
    /// with phase `0.02 * i`, row 5 is a high-frequency, high-amplitude cosine
    /// acting as an outlier. Returns the data and the grid.
    fn make_contaminated_data(m: usize) -> (FdMatrix, Vec<f64>) {
        const N_CLEAN: usize = 5;
        let n = N_CLEAN + 1;
        let grid = uniform_grid(m);
        let mut values = vec![0.0; n * m];
        for j in 0..m {
            for i in 0..N_CLEAN {
                let phase = 0.02 * i as f64;
                values[i + j * n] = ((grid[j] + phase) * 4.0).sin();
            }
            values[N_CLEAN + j * n] = (grid[j] * 20.0).cos() * 5.0;
        }
        let data = FdMatrix::from_column_major(values, n, m).unwrap();
        (data, grid)
    }

    /// Euclidean distance between two sampled curves.
    fn pointwise_l2(a: &[f64], b: &[f64]) -> f64 {
        let squared: f64 = a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum();
        squared.sqrt()
    }

    #[test]
    fn karcher_median_basic() {
        let (data, t) = make_sine_data(5, 20);
        let config = RobustKarcherConfig {
            max_iter: 5,
            ..Default::default()
        };

        let result = karcher_median(&data, &t, &config).unwrap();

        // Every output has the expected dimensions.
        assert_eq!(result.mean.len(), 20);
        assert_eq!(result.mean_srsf.len(), 20);
        assert_eq!(result.gammas.shape(), (5, 20));
        assert_eq!(result.aligned_data.shape(), (5, 20));
        assert_eq!(result.weights.len(), 5);
        assert!(result.n_iter >= 1);
    }

    #[test]
    fn karcher_median_robust_to_outlier() {
        let (data, t) = make_contaminated_data(20);

        // Reference: ordinary Karcher mean on the contaminated data ...
        let std_mean = karcher_mean(&data, &t, 5, 1e-3, 0.0);
        // ... versus the robust median on the same data.
        let median_config = RobustKarcherConfig {
            max_iter: 5,
            ..Default::default()
        };
        let median_result = karcher_median(&data, &t, &median_config).unwrap();

        // Ground truth: Karcher mean of the five clean curves only.
        let clean_data = subset_rows_from_indices(&data, &[0, 1, 2, 3, 4]);
        let clean_mean = karcher_mean(&clean_data, &t, 5, 1e-3, 0.0);

        let d_std = pointwise_l2(&std_mean.mean, &clean_mean.mean);
        let d_median = pointwise_l2(&median_result.mean, &clean_mean.mean);
        assert!(
            d_median <= d_std + 1e-6,
            "median distance to clean ({d_median:.4}) should be <= standard mean distance ({d_std:.4})"
        );
    }

    #[test]
    fn robust_trimmed_removes_outliers() {
        let (data, t) = make_contaminated_data(20);
        let config = RobustKarcherConfig {
            max_iter: 5,
            trim_fraction: 0.2,
            ..Default::default()
        };

        let result = robust_karcher_mean(&data, &t, &config).unwrap();

        // The outlier (row 5) must have been trimmed.
        assert!(
            result.weights[5] < 1e-10,
            "outlier weight should be 0, got {}",
            result.weights[5]
        );

        let n_kept: usize = result.weights.iter().filter(|&&w| w > 0.5).count();
        assert!(n_kept >= 4, "should keep at least 4 curves, got {n_kept}");
    }

    #[test]
    fn robust_config_default() {
        let cfg = RobustKarcherConfig::default();
        assert_eq!(cfg.max_iter, 20);
        assert!((cfg.tol - 1e-3).abs() < f64::EPSILON);
        assert!((cfg.lambda - 0.0).abs() < f64::EPSILON);
        assert!((cfg.trim_fraction - 0.1).abs() < f64::EPSILON);
    }
}