use rand::{RngExt, SeedableRng, rngs::StdRng};
use rayon::prelude::*;
use crate::superfile::vector::distance::l2_sq;
/// Offset added (with wrapping) to the caller-supplied seed before the RNG used by
/// `kmeans_with_assignments` is seeded.
const KMEANS_SEED_OFFSET: u64 = 7;
/// Clusters the row-major `vectors` into `k` groups and returns only the centroids.
///
/// This is a convenience wrapper around [`kmeans_with_assignments`] that discards the
/// per-row assignments. The result holds `k * dim` values, one `dim`-sized row per
/// centroid.
///
/// # Panics
///
/// Panics under the same conditions as [`kmeans_with_assignments`].
pub fn kmeans(vectors: &[f32], dim: usize, k: usize, iters: usize, seed: u64) -> Vec<f32> {
    let (centroids, _assignments) = kmeans_with_assignments(vectors, dim, k, iters, seed);
    centroids
}
/// Runs `iters` rounds of k-means over the row-major `vectors` and returns the centroids
/// together with the per-row cluster assignments.
///
/// `vectors` is interpreted as `n = vectors.len() / dim` rows of `dim` values each.
///
/// Algorithm, as implemented here:
/// 1. The RNG is seeded with `seed + KMEANS_SEED_OFFSET` (wrapping), so a given `seed`
///    always yields the same initial centroids.
/// 2. Each of the `k` initial centroids is a copy of a row drawn uniformly from the
///    input. Rows are drawn independently, so the same row may be picked more than once.
/// 3. Every iteration assigns each row to its nearest centroid according to `l2_sq`
///    (ties go to the lowest centroid index), then replaces each centroid with the mean
///    of its assigned rows. A centroid with no assigned rows keeps its previous value.
///
/// # Returns
///
/// `(centroids, assignments)` where `centroids` has `k * dim` values (row-major) and
/// `assignments` has `n` entries, each a centroid index in `0..k`.
///
/// The assignments are the ones computed at the start of the final iteration, i.e.
/// against the centroids as they were *before* the last mean update. With `iters == 0`
/// the centroids are the initial samples and every assignment is `0`.
///
/// # Panics
///
/// Panics if `dim == 0`, if `k == 0`, if `vectors.len()` is not a multiple of `dim`,
/// if there are no rows, if `k` exceeds the number of rows, or if `k` does not fit in
/// the `u32` used for cluster ids.
pub fn kmeans_with_assignments(
    vectors: &[f32],
    dim: usize,
    k: usize,
    iters: usize,
    seed: u64,
) -> (Vec<f32>, Vec<u32>) {
    assert!(dim > 0, "kmeans: dim must be > 0");
    assert!(k > 0, "kmeans: k must be > 0");
    assert_eq!(
        vectors.len() % dim,
        0,
        "kmeans: vectors len {} not multiple of dim {dim}",
        vectors.len()
    );
    let n = vectors.len() / dim;
    assert!(n > 0, "kmeans: at least one doc required");
    assert!(k <= n, "kmeans: k ({k}) > n_docs ({n})");
    // Cluster ids are stored as `u32`; reject a `k` that the `as u32` cast below would
    // silently truncate.
    assert!(
        k <= u32::MAX as usize,
        "kmeans: k ({k}) does not fit in u32 cluster ids"
    );

    // Seed the centroids with randomly chosen input rows (sampling with replacement).
    let mut rng = StdRng::seed_from_u64(seed.wrapping_add(KMEANS_SEED_OFFSET));
    let mut centroids = vec![0f32; k * dim];
    for slot in centroids.chunks_exact_mut(dim) {
        let idx = rng.random_range(0..n);
        slot.copy_from_slice(&vectors[idx * dim..(idx + 1) * dim]);
    }

    // One chunk of rows per worker thread for the accumulation step; this does not
    // change between iterations, so compute it once.
    let chunk_size = n.div_ceil(rayon::current_num_threads().max(1)).max(1);

    let mut assignments = vec![0u32; n];
    for _ in 0..iters {
        // Assignment step: nearest centroid for every row, computed in parallel.
        assignments = (0..n)
            .into_par_iter()
            .map(|d| {
                let v = &vectors[d * dim..(d + 1) * dim];
                let mut best = 0u32;
                let mut best_d = f32::INFINITY;
                for (c, cv) in centroids.chunks_exact(dim).enumerate() {
                    let dist = l2_sq(v, cv);
                    // Strict `<` keeps the lowest index on ties.
                    if dist < best_d {
                        best_d = dist;
                        best = c as u32;
                    }
                }
                best
            })
            .collect();

        // Update step: per-chunk partial sums (in f64) and counts, merged by `reduce`.
        let (sums, counts) = (0..n)
            .into_par_iter()
            .chunks(chunk_size)
            .map(|chunk| {
                let mut partial_sums = vec![0f64; k * dim];
                let mut partial_counts = vec![0u64; k];
                for d in chunk {
                    let cid = assignments[d] as usize;
                    partial_counts[cid] += 1;
                    let row = &vectors[d * dim..(d + 1) * dim];
                    let dst = &mut partial_sums[cid * dim..(cid + 1) * dim];
                    for (acc, &x) in dst.iter_mut().zip(row) {
                        *acc += f64::from(x);
                    }
                }
                (partial_sums, partial_counts)
            })
            .reduce(
                || (vec![0f64; k * dim], vec![0u64; k]),
                |mut acc, part| {
                    for (a, b) in acc.0.iter_mut().zip(&part.0) {
                        *a += *b;
                    }
                    for (a, b) in acc.1.iter_mut().zip(&part.1) {
                        *a += *b;
                    }
                    acc
                },
            );

        // Move each non-empty cluster's centroid to the mean of its rows; empty
        // clusters keep their previous centroid.
        for ((dst, src), &count) in centroids
            .chunks_exact_mut(dim)
            .zip(sums.chunks_exact(dim))
            .zip(&counts)
        {
            if count == 0 {
                continue;
            }
            let inv = 1.0 / count as f64;
            for (out, &sum) in dst.iter_mut().zip(src) {
                *out = (sum * inv) as f32;
            }
        }
    }
    (centroids, assignments)
}
pub(crate) fn assign_to_centroids(
vectors: &[f32],
centroids: &[f32],
dim: usize,
k: usize,
assignments: &mut [u32],
) {
assert!(dim > 0, "assign_to_centroids: dim must be > 0");
assert!(k > 0, "assign_to_centroids: k must be > 0");
assert_eq!(
vectors.len() % dim,
0,
"assign_to_centroids: vectors len {} not multiple of dim {dim}",
vectors.len()
);
assert_eq!(
centroids.len(),
k * dim,
"assign_to_centroids: centroids len {} != k*dim {}",
centroids.len(),
k * dim
);
let n = vectors.len() / dim;
assert_eq!(
assignments.len(),
n,
"assign_to_centroids: assignments len {} != n_docs {n}",
assignments.len()
);
if n == 0 {
return;
}
let new_assignments: Vec<u32> = (0..n)
.into_par_iter()
.map(|d| {
let v = &vectors[d * dim..(d + 1) * dim];
let mut best = 0u32;
let mut best_d = f32::INFINITY;
for c in 0..k {
let cv = ¢roids[c * dim..(c + 1) * dim];
let dist = l2_sq(v, cv);
if dist < best_d {
best_d = dist;
best = c as u32;
}
}
best
})
.collect();
assignments.copy_from_slice(&new_assignments);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `a` and `b` differ by less than `eps`.
    fn approx(a: f32, b: f32, eps: f32) -> bool {
        (a - b).abs() < eps
    }

    #[test]
    fn returns_k_centroids_of_dim_each() {
        let vectors: Vec<f32> = (0..800).map(|i| (i as f32) * 0.01).collect();
        let centroids = kmeans(&vectors, 8, 4, 5, 42);
        assert_eq!(centroids.len(), 4 * 8);
    }

    #[test]
    fn determinism_same_seed_same_centroids() {
        let vectors: Vec<f32> = (0..100 * 8).map(|i| (i as f32) * 0.01).collect();
        let c1 = kmeans(&vectors, 8, 4, 5, 12345);
        let c2 = kmeans(&vectors, 8, 4, 5, 12345);
        assert_eq!(c1, c2);
    }

    #[test]
    fn different_seeds_likely_different_centroids() {
        let vectors: Vec<f32> = (0..100 * 8).map(|i| (i as f32) * 0.01).collect();
        let c1 = kmeans(&vectors, 8, 4, 5, 1);
        let c2 = kmeans(&vectors, 8, 4, 5, 999);
        // Different seeds usually, but not necessarily, produce different centroids, so
        // inequality is deliberately not asserted; only the output shape is checked.
        assert_eq!(c1.len(), 4 * 8);
        assert_eq!(c2.len(), 4 * 8);
    }

    #[test]
    fn centroids_are_within_data_range() {
        let n = 200;
        let dim = 4;
        let vectors: Vec<f32> = (0..n * dim).map(|i| (i % 10) as f32).collect();
        let centroids = kmeans(&vectors, dim, 8, 5, 7);
        for &c in &centroids {
            assert!(
                (-0.001..=9.001).contains(&c),
                "centroid value {c} outside data range [0, 9]"
            );
        }
    }

    #[test]
    fn cluster_data_recovers_natural_centers() {
        let dim = 4;
        let centers = [
            [0.0f32, 0.0, 0.0, 0.0],
            [10.0, 10.0, 10.0, 10.0],
            [-10.0, -10.0, -10.0, -10.0],
        ];
        // 30 points per planted center, each perturbed by small deterministic noise.
        let mut vectors: Vec<f32> = Vec::new();
        for (cluster_idx, c) in centers.iter().enumerate() {
            for d in 0..30 {
                for (j, &cj) in c.iter().enumerate() {
                    let noise = ((cluster_idx * 30 + d + j) % 7) as f32 * 0.01 - 0.03;
                    vectors.push(cj + noise);
                }
            }
        }
        let centroids = kmeans(&vectors, dim, 3, 5, 42);
        for c in &centers {
            let mut best = f32::INFINITY;
            for ki in 0..3 {
                let cc = &centroids[ki * dim..(ki + 1) * dim];
                let d = (0..dim).map(|j| (c[j] - cc[j]).powi(2)).sum::<f32>().sqrt();
                if d < best {
                    best = d;
                }
            }
            assert!(
                best < 0.5,
                "no centroid within 0.5 of planted center {c:?} (closest = {best})"
            );
        }
    }

    #[test]
    fn k_equal_to_n_assigns_each_doc_its_own_cluster() {
        let dim = 2;
        let vectors = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let centroids = kmeans(&vectors, dim, 3, 5, 42);
        let input_pts: Vec<[f32; 2]> = (0..3)
            .map(|i| [vectors[i * 2], vectors[i * 2 + 1]])
            .collect();
        for ki in 0..3 {
            let c = [centroids[ki * 2], centroids[ki * 2 + 1]];
            let any_match = input_pts
                .iter()
                .any(|p| approx(p[0], c[0], 1e-3) && approx(p[1], c[1], 1e-3));
            assert!(any_match, "centroid {c:?} doesn't match any input point");
        }
    }

    #[test]
    fn assignments_are_valid_cluster_ids() {
        let vectors: Vec<f32> = (0..100 * 8).map(|i| (i as f32) * 0.01).collect();
        let (centroids, assignments) = kmeans_with_assignments(&vectors, 8, 4, 5, 42);
        assert_eq!(centroids.len(), 4 * 8);
        assert_eq!(assignments.len(), 100);
        assert!(assignments.iter().all(|&a| a < 4));
    }

    #[test]
    fn assign_to_centroids_picks_nearest() {
        let vectors = [0.0f32, 0.0, 10.0, 10.0, 0.5, -0.5];
        let centroids = [0.0f32, 0.0, 10.0, 10.0];
        let mut assignments = [u32::MAX; 3];
        assign_to_centroids(&vectors, &centroids, 2, 2, &mut assignments);
        assert_eq!(assignments, [0, 1, 0]);
    }

    #[test]
    #[should_panic(expected = "k must be > 0")]
    fn panics_on_zero_k() {
        kmeans(&[1.0; 8], 8, 0, 5, 0);
    }

    #[test]
    #[should_panic(expected = "> n_docs")]
    fn panics_on_k_greater_than_n() {
        // One row of dim 8, but five clusters requested.
        kmeans(&[1.0; 8], 8, 5, 5, 0);
    }

    #[test]
    #[should_panic(expected = "not multiple of dim")]
    fn panics_on_unaligned_input() {
        kmeans(&[1.0; 7], 8, 1, 5, 0);
    }
}