use super::traits::Clustering;
use crate::error::{Error, Result};
use ndarray::{Array1, Array2};
use rand::prelude::*;
#[cfg(feature = "parallel")]
#[allow(unused_imports)]
use rayon::prelude::*;
#[derive(Debug, Clone)]
pub struct KmeansElkan {
k: usize,
max_iter: usize,
tol: f64,
seed: Option<u64>,
track_stats: bool,
}
#[derive(Debug, Clone, Default)]
pub struct ElkanStats {
pub iterations: usize,
pub distance_computations: u64,
pub distances_skipped: u64,
pub skip_fraction: f64,
pub centroid_distances: u64,
}
impl KmeansElkan {
pub fn new(k: usize) -> Self {
Self {
k,
max_iter: 100,
tol: 1e-4,
seed: None,
track_stats: false,
}
}
pub fn with_max_iter(mut self, max_iter: usize) -> Self {
self.max_iter = max_iter;
self
}
pub fn with_tol(mut self, tol: f64) -> Self {
self.tol = tol;
self
}
pub fn with_seed(mut self, seed: u64) -> Self {
self.seed = Some(seed);
self
}
pub fn with_stats(mut self, track: bool) -> Self {
self.track_stats = track;
self
}
pub fn fit_predict_with_stats(
&self,
data: &[Vec<f32>],
) -> Result<(Vec<usize>, Option<ElkanStats>)> {
if data.is_empty() {
return Err(Error::EmptyInput);
}
let n = data.len();
let d = data[0].len();
if self.k > n {
return Err(Error::InvalidClusterCount {
requested: self.k,
n_items: n,
});
}
let mut flat: Vec<f32> = Vec::with_capacity(n * d);
for point in data {
if point.len() != d {
return Err(Error::DimensionMismatch {
expected: d,
found: point.len(),
});
}
flat.extend(point);
}
let data_arr =
Array2::from_shape_vec((n, d), flat).map_err(|e| Error::Other(e.to_string()))?;
let mut rng: Box<dyn RngCore> = match self.seed {
Some(s) => Box::new(StdRng::seed_from_u64(s)),
None => Box::new(rand::rng()),
};
let mut centroids = self.init_centroids_pp(&data_arr, &mut rng);
let mut labels = vec![0usize; n];
let mut upper = vec![f32::MAX; n]; let mut lower = vec![vec![0.0f32; self.k]; n]; let mut centroid_half_dist = Array2::<f32>::zeros((self.k, self.k));
let mut stats = ElkanStats::default();
for i in 0..n {
let point = data_arr.row(i);
let mut best_cluster = 0;
let mut best_dist = f32::MAX;
#[allow(clippy::needless_range_loop)] for j in 0..self.k {
let dist = squared_distance(&point, ¢roids.row(j)).sqrt();
lower[i][j] = dist;
stats.distance_computations += 1;
if dist < best_dist {
best_dist = dist;
best_cluster = j;
}
}
labels[i] = best_cluster;
upper[i] = best_dist;
}
for i in 0..self.k {
for j in (i + 1)..self.k {
let dist = squared_distance(¢roids.row(i), ¢roids.row(j)).sqrt();
centroid_half_dist[[i, j]] = dist / 2.0;
centroid_half_dist[[j, i]] = dist / 2.0;
stats.centroid_distances += 1;
}
}
let mut s = self.compute_s(¢roid_half_dist);
for _iter in 0..self.max_iter {
stats.iterations += 1;
for i in 0..n {
let a_i = labels[i];
if upper[i] > s[a_i] {
for j in 0..self.k {
if j == a_i {
continue;
}
if upper[i] > lower[i][j] && upper[i] > centroid_half_dist[[a_i, j]] {
let point = data_arr.row(i);
let d_a = squared_distance(&point, ¢roids.row(a_i)).sqrt();
upper[i] = d_a;
lower[i][a_i] = d_a;
stats.distance_computations += 1;
if d_a > lower[i][j] && d_a > centroid_half_dist[[a_i, j]] {
let d_j = squared_distance(&point, ¢roids.row(j)).sqrt();
lower[i][j] = d_j;
stats.distance_computations += 1;
if d_j < d_a {
labels[i] = j;
upper[i] = d_j;
}
} else {
stats.distances_skipped += 1;
}
} else {
stats.distances_skipped += 1;
}
}
} else {
stats.distances_skipped += (self.k - 1) as u64;
}
}
let mut new_centroids = Array2::<f32>::zeros((self.k, d));
let mut counts = vec![0usize; self.k];
for i in 0..n {
let k = labels[i];
for j in 0..d {
new_centroids[[k, j]] += data_arr[[i, j]];
}
counts[k] += 1;
}
for k in 0..self.k {
if counts[k] > 0 {
for j in 0..d {
new_centroids[[k, j]] /= counts[k] as f32;
}
} else {
let idx = rng.random_range(0..n);
new_centroids.row_mut(k).assign(&data_arr.row(idx));
}
}
let mut centroid_shift = Array1::<f32>::zeros(self.k);
for k in 0..self.k {
centroid_shift[k] =
squared_distance(¢roids.row(k), &new_centroids.row(k)).sqrt();
}
for i in 0..n {
let a_i = labels[i];
upper[i] += centroid_shift[a_i];
for j in 0..self.k {
lower[i][j] = (lower[i][j] - centroid_shift[j]).max(0.0);
}
}
let max_shift: f32 = centroid_shift.iter().cloned().fold(0.0f32, f32::max);
centroids = new_centroids;
for i in 0..self.k {
for j in (i + 1)..self.k {
let dist = squared_distance(¢roids.row(i), ¢roids.row(j)).sqrt();
centroid_half_dist[[i, j]] = dist / 2.0;
centroid_half_dist[[j, i]] = dist / 2.0;
stats.centroid_distances += 1;
}
}
s = self.compute_s(¢roid_half_dist);
if max_shift < self.tol as f32 {
break;
}
}
let total_potential = stats.distance_computations + stats.distances_skipped;
stats.skip_fraction = if total_potential > 0 {
stats.distances_skipped as f64 / total_potential as f64
} else {
0.0
};
let final_stats = if self.track_stats { Some(stats) } else { None };
Ok((labels, final_stats))
}
fn init_centroids_pp(&self, data: &Array2<f32>, rng: &mut impl Rng) -> Array2<f32> {
let n = data.nrows();
let d = data.ncols();
let mut centroids = Array2::zeros((self.k, d));
let first = rng.random_range(0..n);
centroids.row_mut(0).assign(&data.row(first));
for i in 1..self.k {
let mut distances: Vec<f32> = Vec::with_capacity(n);
for j in 0..n {
let point = data.row(j);
let min_dist = (0..i)
.map(|c| squared_distance(&point, ¢roids.row(c)))
.fold(f32::MAX, f32::min);
distances.push(min_dist);
}
let total: f32 = distances.iter().sum();
if total == 0.0 {
let idx = rng.random_range(0..n);
centroids.row_mut(i).assign(&data.row(idx));
continue;
}
let threshold = rng.random::<f32>() * total;
let mut cumsum = 0.0;
let mut selected = 0;
for (j, &d) in distances.iter().enumerate() {
cumsum += d;
if cumsum >= threshold {
selected = j;
break;
}
}
centroids.row_mut(i).assign(&data.row(selected));
}
centroids
}
fn compute_s(&self, centroid_half_dist: &Array2<f32>) -> Vec<f32> {
let mut s = vec![f32::MAX; self.k];
for j in 0..self.k {
for jp in 0..self.k {
if jp != j && centroid_half_dist[[j, jp]] < s[j] {
s[j] = centroid_half_dist[[j, jp]];
}
}
}
s
}
}
#[inline]
fn squared_distance(a: &ndarray::ArrayView1<'_, f32>, b: &ndarray::ArrayView1<'_, f32>) -> f32 {
a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum()
}
impl Clustering for KmeansElkan {
fn fit_predict(&self, data: &[Vec<f32>]) -> Result<Vec<usize>> {
self.fit_predict_with_stats(data).map(|(labels, _)| labels)
}
fn n_clusters(&self) -> usize {
self.k
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_elkan_basic() {
let data = vec![
vec![0.0, 0.0],
vec![0.1, 0.1],
vec![10.0, 10.0],
vec![10.1, 10.1],
];
let kmeans = KmeansElkan::new(2).with_seed(42);
let labels = kmeans.fit_predict(&data).unwrap();
assert_eq!(labels[0], labels[1]);
assert_eq!(labels[2], labels[3]);
assert_ne!(labels[0], labels[2]);
}
#[test]
fn test_elkan_with_stats() {
let mut data = Vec::new();
for _ in 0..100 {
data.push(vec![0.0, 0.0]);
}
for _ in 0..100 {
data.push(vec![10.0, 10.0]);
}
for _ in 0..100 {
data.push(vec![20.0, 20.0]);
}
let kmeans = KmeansElkan::new(3).with_seed(42).with_stats(true);
let (labels, stats) = kmeans.fit_predict_with_stats(&data).unwrap();
let stats = stats.unwrap();
assert!(
stats.skip_fraction > 0.3,
"Expected significant skipping, got {}%",
stats.skip_fraction * 100.0
);
println!(
"Elkan stats: {} iterations, {} computed, {} skipped ({:.1}%)",
stats.iterations,
stats.distance_computations,
stats.distances_skipped,
stats.skip_fraction * 100.0
);
assert_eq!(labels.len(), 300);
}
#[test]
fn test_elkan_matches_lloyd() {
use super::super::kmeans::Kmeans;
let data = vec![
vec![0.0, 0.0],
vec![0.1, 0.0],
vec![0.0, 0.1],
vec![10.0, 10.0],
vec![10.1, 10.0],
vec![10.0, 10.1],
];
let lloyd = Kmeans::new(2).with_seed(42);
let elkan = KmeansElkan::new(2).with_seed(42);
let lloyd_labels = lloyd.fit_predict(&data).unwrap();
let elkan_labels = elkan.fit_predict(&data).unwrap();
for i in 0..3 {
for j in 0..3 {
assert_eq!(
lloyd_labels[i] == lloyd_labels[j],
elkan_labels[i] == elkan_labels[j],
"Mismatch at ({}, {})",
i,
j
);
}
}
}
#[test]
fn test_elkan_high_k() {
let n = 500;
let k = 50;
let d = 10;
let data: Vec<Vec<f32>> = (0..n)
.map(|i| (0..d).map(|j| ((i * j) as f32).sin()).collect())
.collect();
let kmeans = KmeansElkan::new(k)
.with_seed(42)
.with_stats(true)
.with_max_iter(20);
let (labels, stats) = kmeans.fit_predict_with_stats(&data).unwrap();
let stats = stats.unwrap();
assert!(
stats.skip_fraction > 0.5,
"Expected high skip fraction for k={}, got {:.1}%",
k,
stats.skip_fraction * 100.0
);
assert_eq!(labels.len(), n);
for &label in &labels {
assert!(label < k);
}
}
#[test]
fn test_elkan_deterministic() {
let data = vec![
vec![0.0, 0.0],
vec![0.1, 0.1],
vec![10.0, 10.0],
vec![10.1, 10.1],
];
let kmeans1 = KmeansElkan::new(2).with_seed(42);
let kmeans2 = KmeansElkan::new(2).with_seed(42);
let labels1 = kmeans1.fit_predict(&data).unwrap();
let labels2 = kmeans2.fit_predict(&data).unwrap();
assert_eq!(labels1, labels2);
}
#[test]
fn test_elkan_empty_input() {
let data: Vec<Vec<f32>> = vec![];
let kmeans = KmeansElkan::new(2);
assert!(kmeans.fit_predict(&data).is_err());
}
#[test]
fn test_elkan_k_larger_than_n() {
let data = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
let kmeans = KmeansElkan::new(5);
assert!(kmeans.fit_predict(&data).is_err());
}
}