use core::fmt;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KmeansError {
InvalidDimension { dim: usize },
InvalidDataLength { len: usize, dim: usize },
EmptyData,
InvalidClusterCount { k: usize, n_points: usize },
ClusterCountOverflow { k: usize },
}
impl fmt::Display for KmeansError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::InvalidDimension { dim } => write!(f, "dim must be positive, got {dim}"),
Self::InvalidDataLength { len, dim } => {
write!(f, "data length {len} must be a multiple of dim {dim}")
}
Self::EmptyData => write!(f, "data must not be empty"),
Self::InvalidClusterCount { k, n_points } => {
write!(f, "k={k} out of range 1..={n_points}")
}
Self::ClusterCountOverflow { k } => {
write!(f, "k={k} exceeds u32::MAX (assignments use u32)")
}
}
}
}
impl std::error::Error for KmeansError {}
#[derive(Debug, Clone)]
pub struct KmeansResult {
pub assignments: Vec<u32>,
pub centroids: Vec<f64>,
pub dim: usize,
pub iterations: usize,
}
impl KmeansResult {
pub fn centroid(&self, i: usize) -> &[f64] {
&self.centroids[i * self.dim..(i + 1) * self.dim]
}
pub fn k(&self) -> usize {
self.centroids.len() / self.dim
}
}
pub fn kmeans(
data: &[f64],
dim: usize,
k: usize,
max_iter: usize,
) -> Result<KmeansResult, KmeansError> {
if dim == 0 {
return Err(KmeansError::InvalidDimension { dim });
}
if data.len() % dim != 0 {
return Err(KmeansError::InvalidDataLength {
len: data.len(),
dim,
});
}
let n = data.len() / dim;
if n == 0 {
return Err(KmeansError::EmptyData);
}
if k == 0 || k > n {
return Err(KmeansError::InvalidClusterCount { k, n_points: n });
}
if k > u32::MAX as usize {
return Err(KmeansError::ClusterCountOverflow { k });
}
let mut centroids = kmeanspp_init(data, dim, k);
let mut assignments = vec![0u32; n];
let mut accum = vec![0.0f64; k * dim];
let mut counts = vec![0u32; k];
let mut iter_count = 0;
for iteration in 0..max_iter {
iter_count = iteration + 1;
let changed = assign_step(data, dim, ¢roids, &mut assignments);
if !changed && iteration > 0 {
break;
}
accum.fill(0.0);
counts.fill(0);
for i in 0..n {
let c = assignments[i] as usize;
counts[c] += 1;
let pt = &data[i * dim..(i + 1) * dim];
let acc = &mut accum[c * dim..(c + 1) * dim];
for j in 0..dim {
acc[j] += pt[j];
}
}
for c in 0..k {
if counts[c] > 0 {
let cnt = counts[c] as f64;
let cent = &mut centroids[c * dim..(c + 1) * dim];
for j in 0..dim {
cent[j] = accum[c * dim + j] / cnt;
}
}
}
}
Ok(KmeansResult {
assignments,
centroids,
dim,
iterations: iter_count,
})
}
fn kmeanspp_init(data: &[f64], d: usize, k: usize) -> Vec<f64> {
let n = data.len() / d;
let mut centroids = Vec::with_capacity(k * d);
let mut mean = vec![0.0f64; d];
for pt in data.chunks_exact(d) {
for j in 0..d {
mean[j] += pt[j];
}
}
let n_f = n as f64;
for v in &mut mean {
*v /= n_f;
}
let first = (0..n)
.min_by(|&a, &b| {
let da = dist_sq(&data[a * d..(a + 1) * d], &mean);
let db = dist_sq(&data[b * d..(b + 1) * d], &mean);
da.total_cmp(&db)
})
.unwrap();
centroids.extend_from_slice(&data[first * d..(first + 1) * d]);
let mut min_dists = vec![f64::INFINITY; n];
for c_idx in 1..k {
let last_start = (c_idx - 1) * d;
let last = ¢roids[last_start..last_start + d];
for i in 0..n {
let d2 = dist_sq(&data[i * d..(i + 1) * d], last);
if d2 < min_dists[i] {
min_dists[i] = d2;
}
}
let next = (0..n)
.max_by(|&a, &b| min_dists[a].total_cmp(&min_dists[b]))
.unwrap();
centroids.extend_from_slice(&data[next * d..(next + 1) * d]);
}
centroids
}
#[cfg(not(feature = "parallel"))]
fn assign_step(data: &[f64], d: usize, centroids: &[f64], assignments: &mut [u32]) -> bool {
let mut changed = false;
for (i, a) in assignments.iter_mut().enumerate() {
let pt = &data[i * d..(i + 1) * d];
let nearest = nearest_centroid(pt, centroids);
if *a != nearest {
*a = nearest;
changed = true;
}
}
changed
}
#[cfg(feature = "parallel")]
fn assign_step(data: &[f64], d: usize, centroids: &[f64], assignments: &mut [u32]) -> bool {
use rayon::prelude::*;
use std::sync::atomic::{AtomicBool, Ordering};
let changed = AtomicBool::new(false);
assignments.par_iter_mut().enumerate().for_each(|(i, a)| {
let pt = &data[i * d..(i + 1) * d];
let nearest = nearest_centroid(pt, centroids);
if *a != nearest {
*a = nearest;
changed.store(true, Ordering::Relaxed);
}
});
changed.load(Ordering::Relaxed)
}
fn nearest_centroid(point: &[f64], centroids: &[f64]) -> u32 {
let d = point.len();
debug_assert!(centroids.len() % d == 0);
let mut best = 0u32;
let mut best_dist = f64::INFINITY;
for (i, c) in centroids.chunks_exact(d).enumerate() {
let dist = dist_sq(point, c);
if dist < best_dist {
best_dist = dist;
best = u32::try_from(i).expect("k validated <= u32::MAX");
}
}
best
}
#[inline]
fn dist_sq(a: &[f64], b: &[f64]) -> f64 {
debug_assert_eq!(a.len(), b.len());
let mut sum = 0.0;
for i in 0..a.len() {
let d = a[i] - b[i];
sum += d * d;
}
sum
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn two_clusters_1d() {
let data = [0.0, 1.0, 2.0, 10.0, 11.0, 12.0];
let result = kmeans(&data, 1, 2, 100).expect("valid clustering must succeed");
assert_eq!(result.assignments.len(), 6);
assert_eq!(result.assignments[0], result.assignments[1]);
assert_eq!(result.assignments[1], result.assignments[2]);
assert_eq!(result.assignments[3], result.assignments[4]);
assert_eq!(result.assignments[4], result.assignments[5]);
assert_ne!(result.assignments[0], result.assignments[3]);
}
#[test]
fn three_clusters_2d() {
#[rustfmt::skip]
let data = [
0.0, 0.1,
0.1, 0.0,
-0.1, 0.0,
10.0, 0.1,
10.1, 0.0,
9.9, 0.0,
5.0, 10.0,
5.1, 10.1,
4.9, 9.9,
];
let result = kmeans(&data, 2, 3, 100).expect("valid clustering must succeed");
assert_eq!(result.assignments.len(), 9);
assert_eq!(result.centroids.len(), 3 * 2);
assert_eq!(result.k(), 3);
assert_eq!(result.assignments[0], result.assignments[1]);
assert_eq!(result.assignments[1], result.assignments[2]);
assert_eq!(result.assignments[3], result.assignments[4]);
assert_eq!(result.assignments[4], result.assignments[5]);
assert_eq!(result.assignments[6], result.assignments[7]);
assert_eq!(result.assignments[7], result.assignments[8]);
let a = result.assignments[0];
let b = result.assignments[3];
let c = result.assignments[6];
assert_ne!(a, b);
assert_ne!(b, c);
assert_ne!(a, c);
}
#[test]
fn single_cluster() {
let data = [1.0, 2.0, 1.1, 2.1, 0.9, 1.9];
let result = kmeans(&data, 2, 1, 100).expect("valid clustering must succeed");
assert!(result.assignments.iter().all(|&a| a == 0));
assert_eq!(result.centroids.len(), 2);
assert_eq!(result.k(), 1);
}
#[test]
fn k_equals_n() {
let data = [0.0, 10.0, 20.0];
let result = kmeans(&data, 1, 3, 100).expect("valid clustering must succeed");
let mut ids: Vec<u32> = result.assignments.clone();
ids.sort();
ids.dedup();
assert_eq!(ids.len(), 3);
}
#[test]
fn converges_in_few_iterations() {
let data = [0.0, 0.0, 100.0, 100.0];
let result = kmeans(&data, 1, 2, 100).expect("valid clustering must succeed");
assert!(
result.iterations <= 5,
"took {} iterations",
result.iterations
);
}
#[test]
fn centroid_accessor() {
let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let result = kmeans(&data, 2, 2, 100).expect("valid clustering must succeed");
for i in 0..result.k() {
assert_eq!(result.centroid(i).len(), 2);
}
}
#[test]
fn invalid_inputs_return_errors() {
let data = [1.0, 2.0, 3.0, 4.0];
assert_eq!(
kmeans(&data, 0, 2, 100).expect_err("zero dimension must be rejected"),
KmeansError::InvalidDimension { dim: 0 }
);
assert_eq!(
kmeans(&data[..3], 2, 2, 100).expect_err("ragged data must be rejected"),
KmeansError::InvalidDataLength { len: 3, dim: 2 }
);
assert_eq!(
kmeans(&[], 2, 1, 100).expect_err("empty data must be rejected"),
KmeansError::EmptyData
);
assert_eq!(
kmeans(&data, 2, 0, 100).expect_err("zero clusters must be rejected"),
KmeansError::InvalidClusterCount { k: 0, n_points: 2 }
);
}
}