use super::karcher::karcher_mean;
use super::pairwise::elastic_distance;
use super::set::align_to_target;
use super::srsf::srsf_single;
use crate::error::FdarError;
use crate::matrix::FdMatrix;
/// Tuning parameters shared by [`karcher_median`] and [`robust_karcher_mean`].
///
/// Use [`Default::default`] together with struct-update syntax to override
/// only the fields you care about.
#[derive(Debug, Clone, PartialEq)]
pub struct RobustKarcherConfig {
    /// Upper bound on the number of iterations.
    ///
    /// [`karcher_median`] uses it to cap its reweighting loop;
    /// [`robust_karcher_mean`] forwards it to the underlying Karcher mean.
    pub max_iter: usize,
    /// Convergence tolerance.
    ///
    /// [`karcher_median`] stops once the relative change of the estimate's
    /// SRSF drops below this value; it is also forwarded to the underlying
    /// Karcher mean.
    pub tol: f64,
    /// Value forwarded unchanged to the alignment, distance and Karcher mean
    /// routines (presumably the alignment regularisation weight; confirm
    /// against those functions).
    pub lambda: f64,
    /// Fraction of curves trimmed by [`robust_karcher_mean`]; must lie in
    /// `[0, 1)`. Not read by [`karcher_median`].
    pub trim_fraction: f64,
}
impl Default for RobustKarcherConfig {
fn default() -> Self {
Self {
max_iter: 20,
tol: 1e-3,
lambda: 0.0,
trim_fraction: 0.1,
}
}
}
/// Output of [`karcher_median`] and [`robust_karcher_mean`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct RobustKarcherResult {
    /// The robust template curve, sampled on the caller's `argvals`.
    pub mean: Vec<f64>,
    /// SRSF of [`mean`](Self::mean), computed on the same grid.
    pub mean_srsf: Vec<f64>,
    /// `gammas` output of aligning every input curve to the final template
    /// (presumably one warping function per curve; confirm against
    /// `align_to_target`).
    pub gammas: FdMatrix,
    /// All input curves after alignment to the final template.
    pub aligned_data: FdMatrix,
    /// One weight per input curve.
    ///
    /// For [`karcher_median`] these are the normalised inverse-distance
    /// weights of the last iteration; for [`robust_karcher_mean`] they are
    /// `1.0` for kept curves and `0.0` for trimmed ones.
    pub weights: Vec<f64>,
    /// Number of iterations reported by the estimator.
    pub n_iter: usize,
    /// Whether the estimator reported convergence.
    pub converged: bool,
}
/// Estimates a robust template curve by iteratively reweighted averaging of
/// aligned curves, using inverse elastic-distance weights.
///
/// The estimate is seeded with a single-iteration Karcher mean. Each
/// iteration then:
///
/// 1. computes the elastic distance from the current estimate to every curve,
/// 2. turns the distances into normalised weights proportional to
///    `1 / distance`,
/// 3. forms the weighted pointwise average of the curves as aligned to the
///    current estimate,
/// 4. re-aligns all curves to the new estimate.
///
/// Iteration stops when the relative change of the estimate's SRSF falls
/// below `config.tol`, or after `config.max_iter` iterations.
/// `config.trim_fraction` is not used by this function.
///
/// # Arguments
///
/// * `data` - curves stored one per row (`n` rows, `m` columns).
/// * `argvals` - evaluation grid; must have length `m`.
/// * `config` - iteration limit, tolerance and `lambda`.
///
/// # Errors
///
/// Returns [`FdarError::InvalidDimension`] if `argvals.len()` differs from
/// the number of columns of `data`, or if `data` has fewer than 2 rows.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn karcher_median(
    data: &FdMatrix,
    argvals: &[f64],
    config: &RobustKarcherConfig,
) -> Result<RobustKarcherResult, FdarError> {
    // Floor applied to distances so that a curve coinciding with the current
    // estimate does not cause a division by zero.
    // NOTE(review): `f64::max` ignores NaN, so a NaN distance is also clamped
    // to this floor and therefore receives the largest possible weight.
    const MIN_DISTANCE: f64 = 1e-10;

    let (n, m) = data.shape();
    validate_inputs(n, m, argvals)?;

    // Seed with a one-iteration Karcher mean and align everything to it.
    let mut current_mean = karcher_mean(data, argvals, 1, config.tol, config.lambda).mean;
    let mut alignment_result = align_to_target(data, &current_mean, argvals, config.lambda);

    // Uniform weights are what gets reported when `max_iter == 0`.
    let mut weights = vec![1.0 / n as f64; n];
    let mut converged = false;
    let mut n_iter = 0;

    for iter in 0..config.max_iter {
        n_iter = iter + 1;

        // Inverse-distance weights, normalised to sum to one.
        let raw_weights: Vec<f64> = (0..n)
            .map(|i| {
                let fi = data.row(i);
                let d = elastic_distance(&current_mean, &fi, argvals, config.lambda);
                1.0 / d.max(MIN_DISTANCE)
            })
            .collect();
        let w_sum: f64 = raw_weights.iter().sum();
        weights = raw_weights.iter().map(|&w| w / w_sum).collect();

        // Weighted pointwise average of the curves aligned to the current estimate.
        let mut new_mean = vec![0.0; m];
        for (i, &w) in weights.iter().enumerate() {
            for (j, acc) in new_mean.iter_mut().enumerate() {
                *acc += w * alignment_result.aligned_data[(i, j)];
            }
        }

        let old_srsf = srsf_single(&current_mean, argvals);
        let new_srsf = srsf_single(&new_mean, argvals);
        converged = relative_srsf_change(&old_srsf, &new_srsf) < config.tol;

        // The alignment is refreshed on every iteration, including the last
        // one, so the returned warps always refer to the returned estimate.
        current_mean = new_mean;
        alignment_result = align_to_target(data, &current_mean, argvals, config.lambda);

        if converged {
            break;
        }
    }

    let mean_srsf = srsf_single(&current_mean, argvals);
    Ok(RobustKarcherResult {
        mean: current_mean,
        mean_srsf,
        gammas: alignment_result.gammas,
        aligned_data: alignment_result.aligned_data,
        weights,
        n_iter,
        converged,
    })
}
/// Computes a trimmed Karcher mean: the curves farthest from an initial
/// Karcher mean are discarded and the mean is recomputed from the rest.
///
/// Steps:
///
/// 1. compute a Karcher mean of all curves,
/// 2. rank the curves by elastic distance to that mean,
/// 3. drop the `ceil(n * trim_fraction)` farthest curves, always keeping at
///    least 2,
/// 4. recompute the Karcher mean from the kept curves,
/// 5. align **all** input curves (including trimmed ones) to the result.
///
/// The returned `weights` are `1.0` for kept curves and `0.0` for trimmed
/// ones; `n_iter` and `converged` are those of the second Karcher mean.
///
/// # Arguments
///
/// * `data` - curves stored one per row (`n` rows, `m` columns).
/// * `argvals` - evaluation grid; must have length `m`.
/// * `config` - iteration limit, tolerance, `lambda` and trim fraction.
///
/// # Errors
///
/// * [`FdarError::InvalidDimension`] if `argvals.len()` differs from the
///   number of columns of `data`, or if `data` has fewer than 2 rows.
/// * [`FdarError::InvalidParameter`] if `config.trim_fraction` is not in
///   `[0, 1)` (this includes NaN).
#[must_use = "expensive computation whose result should not be discarded"]
pub fn robust_karcher_mean(
    data: &FdMatrix,
    argvals: &[f64],
    config: &RobustKarcherConfig,
) -> Result<RobustKarcherResult, FdarError> {
    let (n, m) = data.shape();
    validate_inputs(n, m, argvals)?;
    if !(0.0..1.0).contains(&config.trim_fraction) {
        return Err(FdarError::InvalidParameter {
            parameter: "trim_fraction",
            message: format!("must be in [0, 1), got {}", config.trim_fraction),
        });
    }

    let initial_mean = karcher_mean(data, argvals, config.max_iter, config.tol, config.lambda);

    // Pair each curve index with its distance to the initial mean. A NaN
    // distance is replaced by +infinity so that (a) the comparator below is a
    // genuine total order, which `sort_by` requires, and (b) such curves rank
    // as the farthest and are trimmed first instead of landing anywhere.
    let mut indexed_distances: Vec<(usize, f64)> = (0..n)
        .map(|i| {
            let fi = data.row(i);
            let d = elastic_distance(&initial_mean.mean, &fi, argvals, config.lambda);
            (i, if d.is_nan() { f64::INFINITY } else { d })
        })
        .collect();
    // No NaN remains, so `partial_cmp` always yields `Some`; the stable sort
    // keeps ties in original index order.
    indexed_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

    // `validate_inputs` guarantees n >= 2, so `n_keep` never exceeds `n`.
    let n_trim = ((n as f64) * config.trim_fraction).ceil() as usize;
    let n_keep = n.saturating_sub(n_trim).max(2);

    let kept_indices: Vec<usize> = indexed_distances
        .iter()
        .take(n_keep)
        .map(|&(i, _)| i)
        .collect();

    // Indicator weights: 1.0 for kept curves, 0.0 for trimmed ones.
    let mut weights = vec![0.0; n];
    for &idx in &kept_indices {
        weights[idx] = 1.0;
    }

    let kept_data = subset_rows_from_indices(data, &kept_indices);
    let robust_mean = karcher_mean(
        &kept_data,
        argvals,
        config.max_iter,
        config.tol,
        config.lambda,
    );

    // Align every input curve, trimmed or not, to the robust template.
    let final_alignment = align_to_target(data, &robust_mean.mean, argvals, config.lambda);
    let mean_srsf = srsf_single(&robust_mean.mean, argvals);
    Ok(RobustKarcherResult {
        mean: robust_mean.mean,
        mean_srsf,
        gammas: final_alignment.gammas,
        aligned_data: final_alignment.aligned_data,
        weights,
        n_iter: robust_mean.n_iter,
        converged: robust_mean.converged,
    })
}
/// Checks the preconditions shared by the robust estimators.
///
/// `n` and `m` are the row and column counts of the data matrix. The grid
/// length is checked first, then the row count.
///
/// # Errors
///
/// Returns [`FdarError::InvalidDimension`] when `argvals.len() != m`, or when
/// `n < 2`.
fn validate_inputs(n: usize, m: usize, argvals: &[f64]) -> Result<(), FdarError> {
    let grid_len = argvals.len();
    if grid_len != m {
        return Err(FdarError::InvalidDimension {
            parameter: "argvals",
            expected: m.to_string(),
            actual: grid_len.to_string(),
        });
    }
    if n >= 2 {
        return Ok(());
    }
    Err(FdarError::InvalidDimension {
        parameter: "data",
        expected: String::from("at least 2 rows"),
        actual: format!("{n} rows"),
    })
}
/// Returns `‖q_old − q_new‖₂ / max(‖q_old‖₂, 1e-10)`.
///
/// The norms are plain Euclidean norms of the sample vectors (no quadrature
/// weights). The floor on the denominator avoids division by zero when
/// `q_old` is identically zero. If the slices differ in length, the
/// difference is taken over the common prefix only.
fn relative_srsf_change(q_old: &[f64], q_new: &[f64]) -> f64 {
    const MIN_NORM: f64 = 1e-10;

    let squared_gap: f64 = q_old
        .iter()
        .zip(q_new)
        .map(|(prev, next)| (prev - next).powi(2))
        .sum();
    let squared_old: f64 = q_old.iter().map(|v| v * v).sum();

    squared_gap.sqrt() / squared_old.sqrt().max(MIN_NORM)
}
use crate::cv::subset_rows as subset_rows_from_indices;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::uniform_grid;

    /// Value at `t` of the sine test curve shifted by `phase`.
    fn shifted_sine(t: f64, phase: f64) -> f64 {
        ((t + phase) * 4.0).sin()
    }

    /// Euclidean distance between two sampled curves.
    fn pointwise_l2(a: &[f64], b: &[f64]) -> f64 {
        let squared: f64 = a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum();
        squared.sqrt()
    }

    /// Builds `n` sine curves on an `m`-point grid; curve `i` is shifted by
    /// a phase of `0.03 * i`.
    fn make_sine_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        // Entry (i, j) is stored at `i + j * n`, as `from_column_major` expects.
        let mut data_vec = vec![0.0; n * m];
        for j in 0..m {
            for i in 0..n {
                data_vec[i + j * n] = shifted_sine(t[j], 0.03 * i as f64);
            }
        }
        let data = FdMatrix::from_column_major(data_vec, n, m).unwrap();
        (data, t)
    }

    /// Builds 6 curves on a 20-point grid: rows 0..5 are sines shifted by
    /// `0.02 * i`, row 5 is a high-frequency, high-amplitude cosine outlier.
    fn make_contaminated_data() -> (FdMatrix, Vec<f64>) {
        let m = 20;
        let n = 6;
        let t = uniform_grid(m);
        let mut data_vec = vec![0.0; n * m];
        for j in 0..m {
            for i in 0..5 {
                data_vec[i + j * n] = shifted_sine(t[j], 0.02 * i as f64);
            }
            data_vec[5 + j * n] = (t[j] * 20.0).cos() * 5.0;
        }
        let data = FdMatrix::from_column_major(data_vec, n, m).unwrap();
        (data, t)
    }

    /// The median returns outputs of the expected sizes and iterates at least once.
    #[test]
    fn karcher_median_basic() {
        let (data, t) = make_sine_data(5, 20);
        let config = RobustKarcherConfig {
            max_iter: 5,
            ..Default::default()
        };

        let result = karcher_median(&data, &t, &config).unwrap();

        assert_eq!(result.mean.len(), 20);
        assert_eq!(result.mean_srsf.len(), 20);
        assert_eq!(result.gammas.shape(), (5, 20));
        assert_eq!(result.aligned_data.shape(), (5, 20));
        assert_eq!(result.weights.len(), 5);
        assert!(result.n_iter >= 1);
    }

    /// With one gross outlier, the median is no farther from the clean-data
    /// mean than the ordinary Karcher mean is.
    #[test]
    fn karcher_median_robust_to_outlier() {
        let (data, t) = make_contaminated_data();

        let std_mean = karcher_mean(&data, &t, 5, 1e-3, 0.0);

        let median_config = RobustKarcherConfig {
            max_iter: 5,
            ..Default::default()
        };
        let median_result = karcher_median(&data, &t, &median_config).unwrap();

        // Reference: Karcher mean of the five clean curves only.
        let clean_data = subset_rows_from_indices(&data, &[0, 1, 2, 3, 4]);
        let clean_mean = karcher_mean(&clean_data, &t, 5, 1e-3, 0.0);

        let d_std = pointwise_l2(&std_mean.mean, &clean_mean.mean);
        let d_median = pointwise_l2(&median_result.mean, &clean_mean.mean);
        assert!(
            d_median <= d_std + 1e-6,
            "median distance to clean ({d_median:.4}) should be <= standard mean distance ({d_std:.4})"
        );
    }

    /// Trimming 20% of six curves removes the outlier and keeps at least four curves.
    #[test]
    fn robust_trimmed_removes_outliers() {
        let (data, t) = make_contaminated_data();
        let config = RobustKarcherConfig {
            max_iter: 5,
            trim_fraction: 0.2,
            ..Default::default()
        };

        let result = robust_karcher_mean(&data, &t, &config).unwrap();

        assert!(
            result.weights[5] < 1e-10,
            "outlier weight should be 0, got {}",
            result.weights[5]
        );
        let n_kept: usize = result.weights.iter().filter(|&&w| w > 0.5).count();
        assert!(n_kept >= 4, "should keep at least 4 curves, got {n_kept}");
    }

    /// The default configuration carries the expected constants.
    #[test]
    fn robust_config_default() {
        let cfg = RobustKarcherConfig::default();
        assert_eq!(cfg.max_iter, 20);
        assert!((cfg.tol - 1e-3).abs() < f64::EPSILON);
        assert!((cfg.lambda - 0.0).abs() < f64::EPSILON);
        assert!((cfg.trim_fraction - 0.1).abs() < f64::EPSILON);
    }
}