use rand::seq::SliceRandom;
use rand::thread_rng;
/// Tuning parameters for a k-means clustering run.
///
/// Build one with [`KMeansConfig::new`] and refine it with the chainable
/// `with_*` methods:
///
/// ```ignore
/// let config = KMeansConfig::new(3).with_max_iterations(100).with_tolerance(1e-6);
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct KMeansConfig {
    /// Number of clusters to produce.
    pub k: usize,
    /// Upper bound on the number of assign/update iterations.
    pub max_iterations: usize,
    /// Convergence threshold: iteration stops once every centroid moves less
    /// than this Euclidean distance in a single update.
    pub tolerance: f64,
}

impl KMeansConfig {
    /// Iteration cap applied by [`KMeansConfig::new`].
    pub const DEFAULT_MAX_ITERATIONS: usize = 300;

    /// Convergence tolerance applied by [`KMeansConfig::new`].
    pub const DEFAULT_TOLERANCE: f64 = 1e-4;

    /// Creates a configuration for `k` clusters with the default iteration
    /// cap and tolerance.
    ///
    /// `k` is not validated here; it is checked against the dataset size when
    /// the configuration is used.
    #[must_use]
    pub fn new(k: usize) -> Self {
        Self {
            k,
            max_iterations: Self::DEFAULT_MAX_ITERATIONS,
            tolerance: Self::DEFAULT_TOLERANCE,
        }
    }

    /// Returns the configuration with `max_iterations` replaced.
    ///
    /// The method consumes `self`, so the returned value must be used;
    /// discarding it silently loses the whole configuration.
    #[must_use]
    pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
        self.max_iterations = max_iterations;
        self
    }

    /// Returns the configuration with `tolerance` replaced.
    ///
    /// The method consumes `self`, so the returned value must be used;
    /// discarding it silently loses the whole configuration.
    #[must_use]
    pub fn with_tolerance(mut self, tolerance: f64) -> Self {
        self.tolerance = tolerance;
        self
    }
}
/// Clusters `data` into `config.k` groups using Lloyd's k-means iteration.
///
/// Initial centroids are `config.k` points drawn at random from `data`, so
/// results can differ between calls. Each iteration assigns every point to
/// its nearest centroid (squared Euclidean distance) and then moves each
/// centroid to the mean of its assigned points. A cluster that ends up with no
/// points keeps its previous centroid.
///
/// Iteration stops when `config.max_iterations` is reached, when no
/// assignment changed during a pass (after the first), or when every centroid
/// moved less than `config.tolerance` in the last update.
///
/// # Returns
///
/// `(assignments, centroids)`, where `assignments[i]` is the cluster index of
/// `data[i]` and `centroids` holds `config.k` vectors of the data's
/// dimensionality.
///
/// # Panics
///
/// Panics if `data` is empty, if the points have zero dimensions, if the
/// points do not all have the same number of dimensions, or if `config.k` is
/// zero or larger than `data.len()`.
pub fn kmeans(data: &[Vec<f64>], config: &KMeansConfig) -> (Vec<usize>, Vec<Vec<f64>>) {
    if data.is_empty() {
        panic!("Empty dataset provided.");
    }
    let n = data.len();
    let dim = data[0].len();
    if dim == 0 {
        panic!("Data points must have at least one dimension.");
    }
    // Ragged input would otherwise be silently truncated by the zipped
    // distance and accumulation loops, producing meaningless centroids.
    if let Some((idx, point)) = data.iter().enumerate().find(|(_, p)| p.len() != dim) {
        panic!(
            "Data point {} has {} dimensions, expected {}",
            idx,
            point.len(),
            dim
        );
    }
    if config.k == 0 || config.k > n {
        panic!(
            "Invalid number of clusters k = {} for dataset of size {}",
            config.k, n
        );
    }

    let mut rng = thread_rng();
    let mut centroids: Vec<Vec<f64>> = data
        .choose_multiple(&mut rng, config.k)
        .cloned()
        .collect();
    let mut assignments = vec![0_usize; n];
    // Compare squared distances throughout so no square root is needed.
    let tolerance_sq = config.tolerance * config.tolerance;

    for iteration in 0..config.max_iterations {
        // Assignment step: move each point to its nearest centroid. Starting
        // from the current label means ties keep the existing assignment.
        let mut changed = false;
        for (point, assigned) in data.iter().zip(assignments.iter_mut()) {
            let mut best_cluster = *assigned;
            let mut best_dist = distance_sq(point, &centroids[best_cluster]);
            for (cluster_idx, centroid) in centroids.iter().enumerate() {
                let dist = distance_sq(point, centroid);
                if dist < best_dist {
                    best_dist = dist;
                    best_cluster = cluster_idx;
                }
            }
            if best_cluster != *assigned {
                *assigned = best_cluster;
                changed = true;
            }
        }

        // Update step: accumulate per-cluster coordinate sums and sizes.
        let mut sums = vec![vec![0.0_f64; dim]; config.k];
        let mut counts = vec![0_usize; config.k];
        for (point, &cluster) in data.iter().zip(&assignments) {
            counts[cluster] += 1;
            for (sum, &value) in sums[cluster].iter_mut().zip(point) {
                *sum += value;
            }
        }

        let mut max_centroid_shift_sq = 0.0_f64;
        for ((centroid, sum), &count) in centroids.iter_mut().zip(sums).zip(&counts) {
            // An empty cluster has no mean; leave its centroid where it is.
            if count == 0 {
                continue;
            }
            let divisor = count as f64;
            let new_centroid: Vec<f64> = sum.iter().map(|&s| s / divisor).collect();
            let shift_sq = distance_sq(centroid.as_slice(), &new_centroid);
            if shift_sq > max_centroid_shift_sq {
                max_centroid_shift_sq = shift_sq;
            }
            *centroid = new_centroid;
        }

        // On the first pass the "previous" assignments are the arbitrary
        // all-zero initialisation, so an unchanged labelling does not prove a
        // fixed point (e.g. duplicate points picked as seeds). Only trust the
        // no-change signal from the second pass onwards.
        let assignments_stable = iteration > 0 && !changed;
        if assignments_stable || max_centroid_shift_sq < tolerance_sq {
            break;
        }
    }

    (assignments, centroids)
}
/// Returns the squared Euclidean distance between `a` and `b`.
///
/// Coordinates are paired positionally; if the slices differ in length the
/// extra trailing coordinates of the longer one are ignored.
fn distance_sq(a: &[f64], b: &[f64]) -> f64 {
    let mut total = 0.0_f64;
    for (lhs, rhs) in a.iter().zip(b) {
        let delta = lhs - rhs;
        total += delta.powi(2);
    }
    total
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty dataset must be rejected with the dedicated panic message.
    #[test]
    #[should_panic(expected = "Empty dataset provided.")]
    fn test_empty_data() {
        let config = KMeansConfig::new(3);
        let data: Vec<Vec<f64>> = vec![];
        let _ = kmeans(&data, &config);
    }

    /// Requesting more clusters than there are points must be rejected.
    #[test]
    #[should_panic(expected = "Invalid number of clusters")]
    fn test_invalid_k() {
        let config = KMeansConfig::new(5);
        let data = vec![vec![1.0, 2.0], vec![2.0, 3.0]];
        let _ = kmeans(&data, &config);
    }

    /// A small, valid run returns one label per point and `k` centroids of
    /// the input dimensionality. Initialisation is random, so only the shape
    /// of the result is checked, not the exact clustering.
    #[test]
    fn test_basic_run() {
        let data = vec![
            vec![1.0, 2.0],
            vec![1.5, 1.8],
            vec![5.0, 8.0],
            vec![8.0, 8.0],
        ];
        let config = KMeansConfig::new(2)
            .with_max_iterations(50)
            .with_tolerance(1e-4);
        let (assignments, centroids) = kmeans(&data, &config);
        assert_eq!(assignments.len(), data.len());
        assert_eq!(centroids.len(), 2);
        for c in &centroids {
            assert_eq!(c.len(), 2);
        }
        // Every label must refer to one of the two returned centroids.
        assert!(assignments.iter().all(|&cluster| cluster < 2));
    }
}