use super::karcher::karcher_mean;
use super::pairwise::{elastic_distance, elastic_self_distance_matrix};
use super::KarcherMeanResult;
use crate::cv::subset_rows;
use crate::error::FdarError;
use crate::matrix::FdMatrix;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
/// Tuning parameters for elastic k-means clustering of functional data.
#[derive(Debug, Clone, PartialEq)]
pub struct ElasticClusterConfig {
    /// Number of clusters to form.
    pub k: usize,
    /// Elastic-distance penalty weight, passed through to the metric.
    pub lambda: f64,
    /// Maximum number of assign/update k-means iterations.
    pub max_iter: usize,
    /// Convergence tolerance.
    /// NOTE(review): not consulted by `elastic_kmeans`, which stops only on
    /// exact label agreement — confirm whether a tolerance-based stop was
    /// intended.
    pub tol: f64,
    /// Iteration cap for each per-cluster Karcher mean computation.
    pub karcher_max_iter: usize,
    /// Convergence tolerance for each Karcher mean computation.
    pub karcher_tol: f64,
    /// Seed for the deterministic k-means++ initialization RNG.
    pub seed: u64,
}
impl Default for ElasticClusterConfig {
fn default() -> Self {
Self {
k: 2,
lambda: 0.0,
max_iter: 20,
tol: 1e-4,
karcher_max_iter: 15,
karcher_tol: 1e-3,
seed: 42,
}
}
}
/// Algorithm / linkage selector for elastic clustering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum ElasticClusterMethod {
    /// Elastic k-means (the default). When passed to
    /// `elastic_hierarchical` this variant is treated the same as
    /// single linkage.
    #[default]
    KMeans,
    /// Hierarchical clustering, single (minimum-distance) linkage.
    HierarchicalSingle,
    /// Hierarchical clustering, complete (maximum-distance) linkage.
    HierarchicalComplete,
    /// Hierarchical clustering, average linkage weighted by cluster sizes.
    HierarchicalAverage,
}
/// Output of [`elastic_kmeans`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ElasticClusterResult {
    /// Cluster index (`0..k`) assigned to each input curve.
    pub labels: Vec<usize>,
    /// Karcher-mean template curve for each cluster.
    pub centers: Vec<KarcherMeanResult>,
    /// Per-cluster sum of elastic distances from members to their center.
    pub within_distances: Vec<f64>,
    /// Sum of `within_distances` over all clusters.
    pub total_within_distance: f64,
    /// Number of assign/update iterations actually performed.
    pub n_iter: usize,
    /// Whether the labels stabilized before `max_iter` was exhausted.
    pub converged: bool,
}
/// Output of [`elastic_hierarchical`]: an agglomerative merge history.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ElasticDendrogram {
    /// `(i, j, distance)` triples in merge order, with `i < j` the
    /// representative indices of the two merged clusters (`i` survives).
    pub merges: Vec<(usize, usize, f64)>,
    /// Pairwise elastic distance matrix of the original observations.
    pub distance_matrix: FdMatrix,
}
/// k-means++ seeding over a precomputed distance matrix.
///
/// The first center is drawn uniformly at random; each subsequent center
/// is drawn with probability proportional to the squared distance to its
/// nearest already-chosen center. If every remaining weight is zero
/// (duplicate curves), the first index not yet chosen is used instead.
fn kmeans_pp_init(dist_mat: &FdMatrix, k: usize, rng: &mut StdRng) -> Vec<usize> {
    let n = dist_mat.nrows();
    let mut centers = Vec::with_capacity(k);
    let first = rng.gen_range(0..n);
    centers.push(first);
    // Squared distance from every point to its nearest chosen center.
    let mut min_dist_sq: Vec<f64> = (0..n).map(|i| dist_mat[(i, first)].powi(2)).collect();
    for _ in 1..k {
        let total: f64 = min_dist_sq.iter().sum();
        if total <= 0.0 {
            // Degenerate weights: fall back to any unused observation.
            if let Some(idx) = (0..n).find(|i| !centers.contains(i)) {
                centers.push(idx);
            }
        } else {
            // Inverse-CDF sampling proportional to min_dist_sq.
            let threshold = rng.gen::<f64>() * total;
            let mut cum = 0.0;
            let mut chosen = n - 1;
            for (i, &w) in min_dist_sq.iter().enumerate() {
                cum += w;
                if cum >= threshold {
                    chosen = i;
                    break;
                }
            }
            centers.push(chosen);
        }
        // Refresh nearest-center distances against the newest center.
        let newest = *centers.last().unwrap();
        for (i, slot) in min_dist_sq.iter_mut().enumerate() {
            let d2 = dist_mat[(i, newest)].powi(2);
            if d2 < *slot {
                *slot = d2;
            }
        }
    }
    centers
}
/// Pick a donor observation for an empty cluster: the member of the
/// largest current cluster whose average distance to the rest of that
/// cluster is greatest, i.e. its most outlying curve.
fn reassign_empty_cluster(labels: &[usize], dist_mat: &FdMatrix) -> usize {
    let n = labels.len();
    // Tally cluster sizes and locate the most populous cluster.
    let max_label = labels.iter().copied().max().unwrap_or(0);
    let mut counts = vec![0usize; max_label + 1];
    labels.iter().for_each(|&l| counts[l] += 1);
    let largest = counts
        .iter()
        .enumerate()
        .max_by_key(|&(_, &cnt)| cnt)
        .map_or(0, |(c, _)| c);
    let members: Vec<usize> = (0..n).filter(|&i| labels[i] == largest).collect();
    // Among those members, return the one with the highest mean distance
    // to the group (the zero self-distance is included in the average, as
    // in the original definition).
    let denom = members.len() as f64;
    let mut best = (members[0], -1.0_f64);
    for &i in &members {
        let avg: f64 = members.iter().map(|&j| dist_mat[(i, j)]).sum::<f64>() / denom;
        if avg > best.1 {
            best = (i, avg);
        }
    }
    best.0
}
/// Elastic k-means clustering of the curves in `data`.
///
/// Initialization is k-means++ over the precomputed pairwise elastic
/// distance matrix, seeded deterministically from `config.seed`. Each
/// iteration recomputes per-cluster Karcher means and reassigns every
/// curve to its nearest center under the elastic distance with penalty
/// `config.lambda`. Convergence is declared when two consecutive label
/// assignments are identical.
///
/// NOTE(review): `config.tol` is never consulted here — confirm whether a
/// tolerance-based stopping rule was intended.
///
/// # Errors
/// * `InvalidParameter` if `config.k` is zero or exceeds the number of
///   curves.
/// * `InvalidDimension` if `argvals.len()` differs from the number of
///   evaluation points per curve.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn elastic_kmeans(
    data: &FdMatrix,
    argvals: &[f64],
    config: &ElasticClusterConfig,
) -> Result<ElasticClusterResult, FdarError> {
    let (n, m) = data.shape();
    if config.k < 1 {
        return Err(FdarError::InvalidParameter {
            parameter: "k",
            message: "k must be >= 1".to_string(),
        });
    }
    if config.k > n {
        return Err(FdarError::InvalidParameter {
            parameter: "k",
            message: format!("k ({}) must be <= n ({})", config.k, n),
        });
    }
    if argvals.len() != m {
        return Err(FdarError::InvalidDimension {
            parameter: "argvals",
            expected: format!("{m}"),
            actual: format!("{}", argvals.len()),
        });
    }
    let k = config.k;
    // All pairwise elastic distances; used for k-means++ seeding and for
    // empty-cluster handling. Center-to-point distances are computed on
    // demand each iteration.
    let dist_mat = elastic_self_distance_matrix(data, argvals, config.lambda);
    let mut rng = StdRng::seed_from_u64(config.seed);
    let center_indices = kmeans_pp_init(&dist_mat, k, &mut rng);
    // Initial assignment: nearest seed curve (strict `<` keeps the
    // earliest seed on ties).
    let mut labels = vec![0usize; n];
    for i in 0..n {
        let mut best_d = f64::INFINITY;
        for (c, &ci) in center_indices.iter().enumerate() {
            let d = dist_mat[(i, ci)];
            if d < best_d {
                best_d = d;
                labels[i] = c;
            }
        }
    }
    let mut converged = false;
    let mut n_iter = 0;
    let mut centers: Vec<KarcherMeanResult> = Vec::with_capacity(k);
    for iter in 0..config.max_iter {
        n_iter = iter + 1;
        // Update step: Karcher mean of each cluster's member curves.
        centers = compute_cluster_centers(data, argvals, &labels, k, &dist_mat, config);
        // Assignment step: nearest center under the elastic distance.
        let new_labels: Vec<usize> = (0..n)
            .map(|i| {
                let fi = data.row(i);
                let mut best_d = f64::INFINITY;
                let mut best_c = 0;
                for (c, center) in centers.iter().enumerate() {
                    let d = elastic_distance(&fi, &center.mean, argvals, config.lambda);
                    if d < best_d {
                        best_d = d;
                        best_c = c;
                    }
                }
                best_c
            })
            .collect();
        if new_labels == labels {
            converged = true;
            labels = new_labels;
            break;
        }
        labels = new_labels;
    }
    // If the iteration budget ran out, `centers` still reflects the
    // previous labels; recompute so centers match the final assignment.
    if !converged {
        centers = compute_cluster_centers(data, argvals, &labels, k, &dist_mat, config);
    }
    // Within-cluster dispersion: sum of member-to-center distances.
    let mut within_distances = vec![0.0; k];
    for i in 0..n {
        let fi = data.row(i);
        let c = labels[i];
        let d = elastic_distance(&fi, &centers[c].mean, argvals, config.lambda);
        within_distances[c] += d;
    }
    let total_within_distance: f64 = within_distances.iter().sum();
    Ok(ElasticClusterResult {
        labels,
        centers,
        within_distances,
        total_within_distance,
        n_iter,
        converged,
    })
}
/// Recompute the Karcher-mean center of each of the `k` clusters.
///
/// A cluster with no members is seeded with the single curve picked by
/// `reassign_empty_cluster`, and its mean is run for one iteration only;
/// populated clusters use the full `karcher_max_iter` budget.
fn compute_cluster_centers(
    data: &FdMatrix,
    argvals: &[f64],
    labels: &[usize],
    k: usize,
    dist_mat: &FdMatrix,
    config: &ElasticClusterConfig,
) -> Vec<KarcherMeanResult> {
    let n = data.nrows();
    (0..k)
        .map(|c| {
            let members: Vec<usize> = (0..n).filter(|&i| labels[i] == c).collect();
            // Choose the row set and iteration budget for this cluster.
            let (rows, iters) = if members.is_empty() {
                (vec![reassign_empty_cluster(labels, dist_mat)], 1)
            } else {
                (members, config.karcher_max_iter)
            };
            let sub = subset_rows(data, &rows);
            karcher_mean(&sub, argvals, iters, config.karcher_tol, config.lambda)
        })
        .collect()
}
/// Agglomerative hierarchical clustering under the elastic distance.
///
/// Starts from singleton clusters and greedily merges the closest active
/// pair `n - 1` times, recording `(i, j, distance)` for each merge with
/// `i < j`; cluster `i` survives and `j` is retired. Inter-cluster
/// distances are updated according to the chosen linkage rule.
///
/// NOTE(review): passing `ElasticClusterMethod::KMeans` here silently
/// behaves as single linkage — confirm that this fallback is intended.
///
/// # Errors
/// * `InvalidDimension` if `argvals.len()` differs from the number of
///   evaluation points, or if fewer than two curves are supplied.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn elastic_hierarchical(
    data: &FdMatrix,
    argvals: &[f64],
    method: ElasticClusterMethod,
    lambda: f64,
) -> Result<ElasticDendrogram, FdarError> {
    let (n, m) = data.shape();
    if argvals.len() != m {
        return Err(FdarError::InvalidDimension {
            parameter: "argvals",
            expected: format!("{m}"),
            actual: format!("{}", argvals.len()),
        });
    }
    if n < 2 {
        return Err(FdarError::InvalidDimension {
            parameter: "data",
            expected: "at least 2 rows".to_string(),
            actual: format!("{n} rows"),
        });
    }
    let dist_mat = elastic_self_distance_matrix(data, argvals, lambda);
    // `active[i]` marks whether index i still represents a live cluster;
    // `cluster_dist` is a mutable working copy of the distance matrix that
    // gets overwritten by the linkage updates below.
    let mut active = vec![true; n];
    let mut cluster_sizes = vec![1usize; n];
    let mut cluster_dist = FdMatrix::zeros(n, n);
    for i in 0..n {
        for j in 0..n {
            cluster_dist[(i, j)] = dist_mat[(i, j)];
        }
    }
    let mut merges: Vec<(usize, usize, f64)> = Vec::with_capacity(n - 1);
    for _ in 0..(n - 1) {
        // Exhaustive scan for the closest pair of active clusters.
        let mut min_d = f64::INFINITY;
        let mut min_i = 0;
        let mut min_j = 1;
        for i in 0..n {
            if !active[i] {
                continue;
            }
            for j in (i + 1)..n {
                if !active[j] {
                    continue;
                }
                if cluster_dist[(i, j)] < min_d {
                    min_d = cluster_dist[(i, j)];
                    min_i = i;
                    min_j = j;
                }
            }
        }
        merges.push((min_i, min_j, min_d));
        let size_i = cluster_sizes[min_i];
        let size_j = cluster_sizes[min_j];
        // Linkage update: distance from the merged cluster (kept at index
        // min_i) to every other active cluster k.
        for k in 0..n {
            if !active[k] || k == min_i || k == min_j {
                continue;
            }
            let d_ik = cluster_dist[(min_i.min(k), min_i.max(k))];
            let d_jk = cluster_dist[(min_j.min(k), min_j.max(k))];
            let new_d = match method {
                ElasticClusterMethod::HierarchicalSingle | ElasticClusterMethod::KMeans => {
                    d_ik.min(d_jk)
                }
                ElasticClusterMethod::HierarchicalComplete => d_ik.max(d_jk),
                ElasticClusterMethod::HierarchicalAverage => {
                    // Size-weighted average linkage.
                    (d_ik * size_i as f64 + d_jk * size_j as f64) / (size_i + size_j) as f64
                }
            };
            // Keep the working matrix symmetric so later ordered reads
            // (lo, hi) stay valid.
            let (lo, hi) = (min_i.min(k), min_i.max(k));
            cluster_dist[(lo, hi)] = new_d;
            cluster_dist[(hi, lo)] = new_d;
        }
        cluster_sizes[min_i] = size_i + size_j;
        active[min_j] = false;
    }
    Ok(ElasticDendrogram {
        merges,
        distance_matrix: dist_mat,
    })
}
/// Cut a dendrogram produced by [`elastic_hierarchical`] into `k` flat
/// clusters.
///
/// Replays the first `n - k` merges, then renumbers the surviving cluster
/// ids to the contiguous range `0..k`, ordered by ascending original
/// representative index.
///
/// # Errors
/// Returns `InvalidParameter` when `k` is zero or exceeds the number of
/// observations.
pub fn cut_dendrogram(dendrogram: &ElasticDendrogram, k: usize) -> Result<Vec<usize>, FdarError> {
    let n = dendrogram.distance_matrix.nrows();
    if k < 1 {
        return Err(FdarError::InvalidParameter {
            parameter: "k",
            message: "k must be >= 1".to_string(),
        });
    }
    if k > n {
        return Err(FdarError::InvalidParameter {
            parameter: "k",
            message: format!("k ({k}) must be <= n ({n})"),
        });
    }
    // Every observation starts in its own cluster; replaying a merge
    // relabels the absorbed cluster with the survivor's current id.
    let mut cluster_of: Vec<usize> = (0..n).collect();
    for &(ci, cj, _) in dendrogram.merges.iter().take(n - k) {
        let (target, source) = (cluster_of[ci], cluster_of[cj]);
        cluster_of
            .iter_mut()
            .filter(|label| **label == source)
            .for_each(|label| *label = target);
    }
    // Compress the surviving ids to 0..k, preserving ascending order.
    let mut survivors = cluster_of.clone();
    survivors.sort_unstable();
    survivors.dedup();
    Ok(cluster_of
        .iter()
        .map(|c| survivors.binary_search(c).expect("label must be present"))
        .collect())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simulation::{sim_fundata, EFunType, EValType};
    use crate::test_helpers::uniform_grid;

    /// Simulate `n` Fourier-basis curves on a uniform grid of `m` points
    /// with a fixed seed, so every test is deterministic.
    fn make_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let data = sim_fundata(n, &t, 3, EFunType::Fourier, EValType::Exponential, Some(42));
        (data, t)
    }

    // k-means with k = 2 produces consistently sized, non-negative output.
    #[test]
    fn kmeans_smoke() {
        let (data, t) = make_data(8, 20);
        let config = ElasticClusterConfig {
            k: 2,
            max_iter: 3,
            karcher_max_iter: 3,
            ..Default::default()
        };
        let result = elastic_kmeans(&data, &t, &config).unwrap();
        assert_eq!(result.labels.len(), 8);
        assert_eq!(result.centers.len(), 2);
        assert_eq!(result.within_distances.len(), 2);
        assert!(result.total_within_distance >= 0.0);
        assert!(result.n_iter >= 1);
    }

    // With k = 1 every curve must land in the single cluster 0.
    #[test]
    fn kmeans_single_cluster() {
        let (data, t) = make_data(5, 20);
        let config = ElasticClusterConfig {
            k: 1,
            max_iter: 3,
            karcher_max_iter: 3,
            ..Default::default()
        };
        let result = elastic_kmeans(&data, &t, &config).unwrap();
        assert!(result.labels.iter().all(|&l| l == 0));
        assert_eq!(result.centers.len(), 1);
    }

    // k > n is rejected with an error.
    #[test]
    fn kmeans_k_too_large() {
        let (data, t) = make_data(3, 20);
        let config = ElasticClusterConfig {
            k: 5,
            ..Default::default()
        };
        assert!(elastic_kmeans(&data, &t, &config).is_err());
    }

    // k = 0 is rejected with an error.
    #[test]
    fn kmeans_k_zero() {
        let (data, t) = make_data(5, 20);
        let config = ElasticClusterConfig {
            k: 0,
            ..Default::default()
        };
        assert!(elastic_kmeans(&data, &t, &config).is_err());
    }

    // Single linkage yields n - 1 merges with monotone merge heights.
    #[test]
    fn hierarchical_single_smoke() {
        let (data, t) = make_data(5, 20);
        let dendro =
            elastic_hierarchical(&data, &t, ElasticClusterMethod::HierarchicalSingle, 0.0).unwrap();
        assert_eq!(dendro.merges.len(), 4);
        for w in dendro.merges.windows(2) {
            assert!(
                w[1].2 >= w[0].2 - 1e-10,
                "single linkage should be non-decreasing"
            );
        }
    }

    // Complete linkage also yields n - 1 merges.
    #[test]
    fn hierarchical_complete_smoke() {
        let (data, t) = make_data(5, 20);
        let dendro =
            elastic_hierarchical(&data, &t, ElasticClusterMethod::HierarchicalComplete, 0.0)
                .unwrap();
        assert_eq!(dendro.merges.len(), 4);
    }

    // Average linkage also yields n - 1 merges.
    #[test]
    fn hierarchical_average_smoke() {
        let (data, t) = make_data(5, 20);
        let dendro =
            elastic_hierarchical(&data, &t, ElasticClusterMethod::HierarchicalAverage, 0.0)
                .unwrap();
        assert_eq!(dendro.merges.len(), 4);
    }

    // A single curve cannot be hierarchically clustered.
    #[test]
    fn hierarchical_too_few_curves() {
        let t = uniform_grid(20);
        let curve: Vec<f64> = t.iter().map(|&x| x.sin()).collect();
        let data = FdMatrix::from_slice(&curve, 1, 20).unwrap();
        assert!(
            elastic_hierarchical(&data, &t, ElasticClusterMethod::HierarchicalSingle, 0.0).is_err()
        );
    }

    // Cutting at k = n assigns each observation its own label 0..n-1.
    #[test]
    fn cut_dendrogram_all_singletons() {
        let (data, t) = make_data(5, 20);
        let dendro =
            elastic_hierarchical(&data, &t, ElasticClusterMethod::HierarchicalSingle, 0.0).unwrap();
        let labels = cut_dendrogram(&dendro, 5).unwrap();
        let mut sorted = labels.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, vec![0, 1, 2, 3, 4]);
    }

    // Cutting at k = 1 collapses everything into cluster 0.
    #[test]
    fn cut_dendrogram_one_cluster() {
        let (data, t) = make_data(5, 20);
        let dendro =
            elastic_hierarchical(&data, &t, ElasticClusterMethod::HierarchicalSingle, 0.0).unwrap();
        let labels = cut_dendrogram(&dendro, 1).unwrap();
        assert!(labels.iter().all(|&l| l == 0));
    }

    // Cutting with k > n is rejected.
    #[test]
    fn cut_dendrogram_k_too_large() {
        let (data, t) = make_data(5, 20);
        let dendro =
            elastic_hierarchical(&data, &t, ElasticClusterMethod::HierarchicalSingle, 0.0).unwrap();
        assert!(cut_dendrogram(&dendro, 10).is_err());
    }

    // A two-cluster cut yields exactly two distinct labels.
    #[test]
    fn cut_dendrogram_two_clusters() {
        let (data, t) = make_data(6, 20);
        let dendro =
            elastic_hierarchical(&data, &t, ElasticClusterMethod::HierarchicalSingle, 0.0).unwrap();
        let labels = cut_dendrogram(&dendro, 2).unwrap();
        assert_eq!(labels.len(), 6);
        let unique: std::collections::HashSet<usize> = labels.iter().copied().collect();
        assert_eq!(unique.len(), 2);
    }

    // Pin the documented default configuration values.
    #[test]
    fn default_config_values() {
        let cfg = ElasticClusterConfig::default();
        assert_eq!(cfg.k, 2);
        assert!((cfg.lambda - 0.0).abs() < f64::EPSILON);
        assert_eq!(cfg.max_iter, 20);
        assert!((cfg.tol - 1e-4).abs() < f64::EPSILON);
        assert_eq!(cfg.karcher_max_iter, 15);
        assert!((cfg.karcher_tol - 1e-3).abs() < f64::EPSILON);
        assert_eq!(cfg.seed, 42);
    }

    // The derived default method is KMeans.
    #[test]
    fn default_method() {
        assert_eq!(
            ElasticClusterMethod::default(),
            ElasticClusterMethod::KMeans
        );
    }
}