use crate::error::RagDriftError;
use crate::Result;
use ndarray::{Array2, ArrayView2, Axis};
use rand::distributions::WeightedIndex;
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
/// Output of [`kmeans`]: the fitted centroids together with per-point cluster assignments.
#[derive(Debug, Clone)]
pub struct KMeansResult {
    /// Cluster centres, one row per cluster (`k` rows) and one column per input feature.
    pub centroids: Array2<f32>,
    /// For each input row, in input order, the index of the centroid row it was assigned to.
    pub labels: Vec<usize>,
    /// Sum over all points of the squared Euclidean distance to their nearest centroid,
    /// as measured in the final assignment step; accumulated in `f64`.
    pub inertia: f64,
    /// Number of Lloyd iterations (assignment + centroid update) that were run; at most
    /// the `max_iters` passed to [`kmeans`].
    pub iters: usize,
}
pub fn kmeans(
points: ArrayView2<f32>,
k: usize,
max_iters: usize,
tol: f32,
seed: u64,
) -> Result<KMeansResult> {
let n = points.nrows();
let dim = points.ncols();
if k == 0 {
return Err(RagDriftError::InvalidConfig("kmeans: k must be > 0".into()));
}
if k > n {
return Err(RagDriftError::InvalidConfig(format!(
"kmeans: k={} exceeds sample count n={}",
k, n
)));
}
if n == 0 {
return Err(RagDriftError::InsufficientSamples {
required: 1,
got: 0,
context: "kmeans",
});
}
let mut rng = StdRng::seed_from_u64(seed);
let mut centroids = kmeans_pp_init(points, k, &mut rng);
let mut labels = vec![0_usize; n];
let mut iters = 0usize;
let mut inertia = 0.0_f64;
for it in 0..max_iters {
iters = it + 1;
inertia = 0.0;
for (i, row) in points.axis_iter(Axis(0)).enumerate() {
let (best, dist) = nearest_centroid(row.as_slice().unwrap(), ¢roids);
labels[i] = best;
inertia += dist as f64;
}
let mut new_centroids = Array2::<f32>::zeros((k, dim));
let mut counts = vec![0_u32; k];
for (i, row) in points.axis_iter(Axis(0)).enumerate() {
let c = labels[i];
counts[c] += 1;
for (out, v) in new_centroids.row_mut(c).iter_mut().zip(row.iter()) {
*out += *v;
}
}
for (c, &count) in counts.iter().enumerate() {
if count == 0 {
let idx = rng.gen_range(0..n);
new_centroids.row_mut(c).assign(&points.row(idx));
} else {
let inv = 1.0_f32 / count as f32;
for v in new_centroids.row_mut(c).iter_mut() {
*v *= inv;
}
}
}
let mut max_shift = 0.0_f32;
for c in 0..k {
let mut s = 0.0_f32;
for (a, b) in centroids.row(c).iter().zip(new_centroids.row(c).iter()) {
let d = *a - *b;
s += d * d;
}
if s > max_shift {
max_shift = s;
}
}
centroids = new_centroids;
if max_shift.sqrt() <= tol {
break;
}
}
Ok(KMeansResult {
centroids,
labels,
inertia,
iters,
})
}
/// Labels each row of `points` with the index of its nearest row in `centroids`,
/// using squared Euclidean distance. Ties go to the lowest centroid index.
///
/// Views in any memory layout are accepted: non-contiguous `points` (column-major or
/// strided) are copied into row-major order once instead of panicking.
///
/// If `centroids` has no rows, every point is labelled `0`.
///
/// # Panics
/// In debug builds, panics if `points` and `centroids` have different column counts.
/// In release builds a mismatch is not detected and only the shared leading columns
/// are compared.
pub fn assign(points: ArrayView2<f32>, centroids: ArrayView2<f32>) -> Vec<usize> {
    debug_assert_eq!(
        points.ncols(),
        centroids.ncols(),
        "assign: points and centroids have different dimensionality"
    );
    // `nearest_centroid` takes an owned matrix; the k x dim centroid copy is small.
    let centroids_owned: Array2<f32> = centroids.to_owned();
    // Rows must be contiguous slices for `nearest_centroid`.
    let points = points.as_standard_layout();
    points
        .axis_iter(Axis(0))
        .map(|row| {
            let row = row
                .as_slice()
                .expect("assign: rows are contiguous after as_standard_layout");
            nearest_centroid(row, &centroids_owned).0
        })
        .collect()
}
/// Finds the centroid row closest to `point` and returns `(row_index, squared_distance)`.
///
/// Distances are squared Euclidean, accumulated in `f32` over the zipped coordinates.
/// A candidate replaces the current best only if it is strictly closer, so ties keep
/// the lowest index and a NaN distance never wins. With no centroid rows, the result
/// is `(0, f32::INFINITY)`.
fn nearest_centroid(point: &[f32], centroids: &Array2<f32>) -> (usize, f32) {
    centroids
        .axis_iter(Axis(0))
        .enumerate()
        .map(|(idx, centre)| {
            let sq = point
                .iter()
                .zip(centre.iter())
                .fold(0.0_f32, |acc, (&p, &q)| {
                    let delta = p - q;
                    acc + delta * delta
                });
            (idx, sq)
        })
        .fold((0_usize, f32::INFINITY), |best, cand| {
            if cand.1 < best.1 {
                cand
            } else {
                best
            }
        })
}
/// Chooses `k` initial centroids from the rows of `points` using k-means++ seeding.
///
/// The first centroid is a uniformly random row. Each later centroid is a row drawn
/// with probability proportional to its squared distance to the nearest centroid
/// chosen so far.
///
/// If those distances cannot serve as sampling weights, the next centroid is drawn
/// uniformly instead of panicking. This happens when they sum to zero (every row
/// coincides with a chosen centroid) or to a non-finite value. A non-finite sum comes
/// from rows containing NaN, which keep an infinite running distance, or from squared
/// distances that overflow `f32`.
///
/// Callers must ensure `1 <= k <= points.nrows()`.
fn kmeans_pp_init(points: ArrayView2<f32>, k: usize, rng: &mut StdRng) -> Array2<f32> {
    let n = points.nrows();
    let dim = points.ncols();
    debug_assert!((1..=n).contains(&k), "kmeans_pp_init: need 1 <= k <= n");
    let mut centroids = Array2::<f32>::zeros((k, dim));
    let first = rng.gen_range(0..n);
    centroids.row_mut(0).assign(&points.row(first));
    // Squared distance from each row to its nearest already-chosen centroid.
    let mut min_dists = vec![f32::INFINITY; n];
    for ci in 1..k {
        // Only the most recently chosen centroid can lower the running minimum.
        let latest = centroids.row(ci - 1);
        for (min_d, row) in min_dists.iter_mut().zip(points.axis_iter(Axis(0))) {
            let d = row
                .iter()
                .zip(latest.iter())
                .fold(0.0_f32, |acc, (&a, &b)| {
                    let diff = a - b;
                    acc + diff * diff
                });
            if d < *min_d {
                *min_d = d;
            }
        }
        let total: f32 = min_dists.iter().sum();
        // A finite, positive total implies every weight is finite and non-negative,
        // which `WeightedIndex` requires (rand 0.8 panics on an infinite total).
        let weighted = if total.is_finite() && total > 0.0 {
            WeightedIndex::new(&min_dists).ok()
        } else {
            None
        };
        let pick = match weighted {
            Some(dist) => dist.sample(rng),
            None => rng.gen_range(0..n),
        };
        centroids.row_mut(ci).assign(&points.row(pick));
    }
    centroids
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    #[test]
    fn kmeans_separates_two_clusters() {
        // Rows 0..10 lie on a short strip starting at (0, 0); rows 10..20 on one at (10, 10).
        let points = Array2::<f32>::from_shape_fn((20, 2), |(row, col)| {
            let (base, offset) = if row < 10 {
                (0.0_f32, row)
            } else {
                (10.0_f32, row - 10)
            };
            if col == 0 {
                base + (offset as f32) * 0.01
            } else {
                base
            }
        });
        let fitted = kmeans(points.view(), 2, 50, 1e-6, 42).unwrap();
        assert_eq!(fitted.labels.len(), 20);
        let (low, high) = fitted.labels.split_at(10);
        let (low_label, high_label) = (low[0], high[0]);
        assert_ne!(low_label, high_label);
        assert!(low.iter().all(|&l| l == low_label));
        assert!(high.iter().all(|&l| l == high_label));
    }

    #[test]
    fn kmeans_rejects_k_larger_than_n() {
        let tiny = Array2::<f32>::zeros((3, 2));
        let outcome = kmeans(tiny.view(), 5, 10, 1e-6, 0);
        assert!(outcome.is_err());
    }

    #[test]
    fn assign_matches_kmeans_labels() {
        // Ten rows at the origin and ten at (10, 0).
        let points = Array2::<f32>::from_shape_fn((20, 2), |(row, col)| {
            if col == 0 && row >= 10 {
                10.0
            } else {
                0.0
            }
        });
        let fitted = kmeans(points.view(), 2, 50, 1e-6, 0).unwrap();
        let relabelled = assign(points.view(), fitted.centroids.view());
        assert_eq!(relabelled, fitted.labels);
    }
}