use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
#[cfg(not(target_arch = "wasm32"))]
use rayon::prelude::*;
#[cfg(test)]
use crate::vector::distance::squared_euclidean;
/// Tuning parameters for a k-means run.
///
/// Start from [`KMeansConfig::new`] (or [`Default::default`]) and refine the
/// value with the chainable `with_*` setters.
#[derive(Debug, Clone)]
pub struct KMeansConfig {
    /// Number of clusters (centroids) to fit.
    pub n_clusters: usize,
    /// Maximum number of assign/update iterations.
    pub max_iterations: usize,
    /// Relative inertia change below which a run is reported as converged.
    pub tolerance: f32,
    /// Seed for centroid initialisation; `None` seeds the generator from entropy.
    pub seed: Option<u64>,
}

impl Default for KMeansConfig {
    /// Returns 100 clusters, 25 iterations, a tolerance of `1e-4` and no seed.
    fn default() -> Self {
        Self {
            n_clusters: 100,
            max_iterations: 25,
            tolerance: 1e-4,
            seed: None,
        }
    }
}

impl KMeansConfig {
    /// Creates a configuration for `n_clusters` clusters; every other field
    /// takes its [`Default`] value.
    #[must_use]
    pub fn new(n_clusters: usize) -> Self {
        Self {
            n_clusters,
            ..Default::default()
        }
    }

    /// Returns the configuration with `max_iterations` replaced.
    // `#[must_use]`: the setter consumes `self`, so dropping the result
    // silently discards the whole configuration.
    #[must_use]
    pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
        self.max_iterations = max_iterations;
        self
    }

    /// Returns the configuration with `tolerance` replaced.
    #[must_use]
    pub fn with_tolerance(mut self, tolerance: f32) -> Self {
        self.tolerance = tolerance;
        self
    }

    /// Returns the configuration with a fixed initialisation seed.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }
}
/// Outcome of a k-means run as returned by `kmeans` / `kmeans_parallel`.
#[derive(Debug, Clone)]
pub struct KMeansResult {
/// Final centroids stored back to back: centroid `c` occupies
/// `centroids[c * dimensions..(c + 1) * dimensions]`.
pub centroids: Vec<f32>,
/// For each input vector, the index of the centroid it was assigned to.
pub assignments: Vec<u32>,
/// Sum over all vectors of the `distance_fn` value to the assigned centroid.
pub inertia: f32,
/// Number of loop iterations that were started (0 when `max_iterations` is 0).
pub iterations: usize,
/// `true` when the relative inertia change fell below the configured
/// tolerance before the iteration limit was reached.
pub converged: bool,
}
/// Clusters `n` vectors into `config.n_clusters` groups on the calling thread.
///
/// `vectors` is a flat buffer holding `n` vectors of `dimensions` values each,
/// back to back. Centroids are seeded by `kmeans_plus_plus_init`, then the
/// function alternates assignment and centroid-update passes until the
/// relative inertia change drops below `config.tolerance` or
/// `config.max_iterations` passes have run.
///
/// `distance_fn` scores a vector against a centroid; smaller means closer.
///
/// # Errors
///
/// * [`KMeansError::NotEnoughVectors`] when `n < config.n_clusters`.
/// * [`KMeansError::DimensionMismatch`] when `vectors.len() != n * dimensions`.
// NOTE(review): `n_clusters == 0` is not rejected here; together with `n == 0`
// the initial sampling range in `kmeans_plus_plus_init` is empty — confirm
// whether callers can pass it.
pub fn kmeans(
    vectors: &[f32],
    n: usize,
    dimensions: usize,
    config: &KMeansConfig,
    distance_fn: fn(&[f32], &[f32]) -> f32,
) -> Result<KMeansResult, KMeansError> {
    let k = config.n_clusters;
    if n < k {
        return Err(KMeansError::NotEnoughVectors { n, k });
    }
    let expected = n * dimensions;
    if vectors.len() != expected {
        return Err(KMeansError::DimensionMismatch {
            expected,
            got: vectors.len(),
        });
    }

    let mut centroids = kmeans_plus_plus_init(vectors, n, dimensions, k, distance_fn, config.seed);
    let mut assignments = vec![0u32; n];
    let mut prev_inertia = f32::INFINITY;
    let mut iterations = 0;
    // Holds the inertia of the pass that triggered convergence, if any.
    let mut converged_inertia = None;

    for iter in 0..config.max_iterations {
        iterations = iter + 1;
        let inertia = assign_to_centroids(
            vectors,
            n,
            dimensions,
            &centroids,
            k,
            &mut assignments,
            distance_fn,
        );
        // Relative change; the denominator is clamped so tiny inertias do not
        // blow the ratio up.
        let inertia_change = (prev_inertia - inertia).abs() / inertia.max(1.0);
        if inertia_change < config.tolerance {
            converged_inertia = Some(inertia);
            break;
        }
        prev_inertia = inertia;
        update_centroids(vectors, n, dimensions, &assignments, k, &mut centroids);
    }

    let converged = converged_inertia.is_some();
    let inertia = match converged_inertia {
        // The centroids have not moved since the converging pass, so its
        // assignments and inertia are already final.
        Some(inertia) => inertia,
        // The loop ended right after a centroid update (or never ran):
        // assignments must be refreshed against the current centroids.
        None => assign_to_centroids(
            vectors,
            n,
            dimensions,
            &centroids,
            k,
            &mut assignments,
            distance_fn,
        ),
    };

    Ok(KMeansResult {
        centroids,
        assignments,
        inertia,
        iterations,
        converged,
    })
}
/// Picks `k` starting centroids from `vectors` by distance-weighted sampling.
///
/// The first centroid is a uniformly random input vector. Every further
/// centroid is drawn with probability proportional to a per-vector weight:
/// the square of `|distance_fn(vector, centroid)|`, minimised over the
/// centroids chosen so far.
///
/// `seed` makes the draw reproducible; `None` seeds the generator from entropy.
///
/// Returns the chosen centroids back to back in one flat buffer.
///
/// # Panics
///
/// Panics if `n == 0` (the sampling range is empty) or if `vectors` holds
/// fewer than `n * dimensions` values.
fn kmeans_plus_plus_init(
    vectors: &[f32],
    n: usize,
    dimensions: usize,
    k: usize,
    distance_fn: fn(&[f32], &[f32]) -> f32,
    seed: Option<u64>,
) -> Vec<f32> {
    let mut rng: StdRng = match seed {
        Some(s) => StdRng::seed_from_u64(s),
        None => StdRng::from_entropy(),
    };
    let mut centroids = Vec::with_capacity(k * dimensions);
    let first_offset = rng.gen_range(0..n) * dimensions;
    centroids.extend_from_slice(&vectors[first_offset..first_offset + dimensions]);

    // Weight of each vector: squared distance to its nearest chosen centroid.
    let mut min_dists = vec![f32::INFINITY; n];
    for c in 1..k {
        // Only the centroid added last can lower a weight, so one comparison
        // per vector keeps `min_dists` up to date.
        let prev_cent_offset = (c - 1) * dimensions;
        let prev_centroid = &centroids[prev_cent_offset..prev_cent_offset + dimensions];
        let mut total_dist = 0.0f32;
        for (i, min_dist) in min_dists.iter_mut().enumerate() {
            let vec_offset = i * dimensions;
            let vec = &vectors[vec_offset..vec_offset + dimensions];
            let abs_dist = distance_fn(vec, prev_centroid).abs();
            *min_dist = (*min_dist).min(abs_dist * abs_dist);
            total_dist += *min_dist;
        }

        // Walk the cumulative weights until the random threshold is used up.
        let mut r = rng.gen::<f32>() * total_dist;
        // Fallback for the case where f32 rounding leaves `r` positive after
        // the whole scan: take the last vector that still carries weight
        // rather than index 0, which may already be a centroid.
        let mut selected_idx = min_dists.iter().rposition(|&w| w > 0.0).unwrap_or(0);
        for (i, &weight) in min_dists.iter().enumerate() {
            r -= weight;
            if r <= 0.0 {
                selected_idx = i;
                break;
            }
        }
        let selected_offset = selected_idx * dimensions;
        centroids.extend_from_slice(&vectors[selected_offset..selected_offset + dimensions]);
    }
    centroids
}
/// Assigns each of the first `n` vectors to its closest centroid.
///
/// `vectors` and `centroids` are flat buffers of `dimensions`-sized rows.
/// For every vector the centroid with the smallest `distance_fn` value wins;
/// on ties the lowest centroid index is kept. The winning index is written to
/// the matching slot of `assignments`.
///
/// Returns the inertia: the sum of the winning distances.
fn assign_to_centroids(
    vectors: &[f32],
    n: usize,
    dimensions: usize,
    centroids: &[f32],
    k: usize,
    assignments: &mut [u32],
    distance_fn: fn(&[f32], &[f32]) -> f32,
) -> f32 {
    let mut inertia = 0.0;
    for (i, assignment) in assignments.iter_mut().enumerate().take(n) {
        let vec_offset = i * dimensions;
        let vec = &vectors[vec_offset..vec_offset + dimensions];
        let mut best_cluster = 0;
        let mut best_dist = f32::INFINITY;
        for c in 0..k {
            let cent_offset = c * dimensions;
            let centroid = &centroids[cent_offset..cent_offset + dimensions];
            let dist = distance_fn(vec, centroid);
            // Strict `<` keeps the first centroid among equally close ones.
            if dist < best_dist {
                best_dist = dist;
                best_cluster = c;
            }
        }
        *assignment = best_cluster as u32;
        inertia += best_dist;
    }
    inertia
}
/// Moves every non-empty cluster's centroid to the mean of its members.
///
/// Only the first `n` entries of `assignments` are read. A cluster without
/// members keeps whatever centroid it had on entry.
fn update_centroids(
    vectors: &[f32],
    n: usize,
    dimensions: usize,
    assignments: &[u32],
    k: usize,
    centroids: &mut [f32],
) {
    let mut sums = vec![0.0f32; k * dimensions];
    let mut members = vec![0u32; k];

    // Accumulate per-cluster component sums and member counts.
    for (row, &label) in assignments.iter().take(n).enumerate() {
        let label = label as usize;
        let source = &vectors[row * dimensions..(row + 1) * dimensions];
        let target = &mut sums[label * dimensions..(label + 1) * dimensions];
        for (acc, &value) in target.iter_mut().zip(source) {
            *acc += value;
        }
        members[label] += 1;
    }

    // Turn the sums into means, leaving empty clusters untouched.
    for (cluster, &count) in members.iter().enumerate() {
        if count == 0 {
            continue;
        }
        let span = cluster * dimensions..(cluster + 1) * dimensions;
        let divisor = count as f32;
        for (centroid, &sum) in centroids[span.clone()].iter_mut().zip(&sums[span]) {
            *centroid = sum / divisor;
        }
    }
}
/// Replaces the centroid of every member-less cluster with a random input vector.
///
/// A cluster counts as empty when its entry in `cluster_counts` is zero. The
/// replacement vector is drawn uniformly from the first `n` vectors using
/// `rand::thread_rng()`, so the draw does not depend on any configured seed.
// No caller is visible in this file, hence the `dead_code` allowance.
#[allow(dead_code)]
fn reinitialize_empty_clusters(
    vectors: &[f32],
    n: usize,
    dimensions: usize,
    cluster_counts: &[u32],
    centroids: &mut [f32],
) {
    let mut rng = rand::thread_rng();
    let empty_clusters = cluster_counts
        .iter()
        .enumerate()
        .filter(|&(_, &count)| count == 0)
        .map(|(cluster, _)| cluster);
    for cluster in empty_clusters {
        let source = rng.gen_range(0..n) * dimensions;
        let target = cluster * dimensions;
        centroids[target..target + dimensions]
            .copy_from_slice(&vectors[source..source + dimensions]);
    }
}
/// Sizing threshold for the parallel path: `kmeans_parallel` hands inputs with
/// fewer than twice this many vectors to the sequential `kmeans`.
const MIN_VECTORS_PER_THREAD: usize = 1000;
/// Variant of `kmeans` that runs the assignment and centroid-update passes
/// through the `*_parallel` helpers.
///
/// Inputs with fewer than `MIN_VECTORS_PER_THREAD * 2` vectors are delegated
/// to the sequential `kmeans`. Centroid seeding uses the same
/// `kmeans_plus_plus_init` as the sequential path. On `wasm32` targets the
/// `*_parallel` helpers are sequential stand-ins.
///
/// Arguments and the returned [`KMeansResult`] have the same meaning as for
/// `kmeans`.
///
/// # Errors
///
/// * [`KMeansError::NotEnoughVectors`] when `n < config.n_clusters`.
/// * [`KMeansError::DimensionMismatch`] when `vectors.len() != n * dimensions`.
pub fn kmeans_parallel(
    vectors: &[f32],
    n: usize,
    dimensions: usize,
    config: &KMeansConfig,
    distance_fn: fn(&[f32], &[f32]) -> f32,
) -> Result<KMeansResult, KMeansError> {
    // Small inputs are not worth the scheduling overhead; `kmeans` performs
    // its own validation.
    if n < MIN_VECTORS_PER_THREAD * 2 {
        return kmeans(vectors, n, dimensions, config, distance_fn);
    }

    let k = config.n_clusters;
    if n < k {
        return Err(KMeansError::NotEnoughVectors { n, k });
    }
    let expected = n * dimensions;
    if vectors.len() != expected {
        return Err(KMeansError::DimensionMismatch {
            expected,
            got: vectors.len(),
        });
    }

    let mut centroids = kmeans_plus_plus_init(vectors, n, dimensions, k, distance_fn, config.seed);
    let mut assignments = vec![0u32; n];
    let mut prev_inertia = f32::INFINITY;
    let mut iterations = 0;
    // Holds the inertia of the pass that triggered convergence, if any.
    let mut converged_inertia = None;

    for iter in 0..config.max_iterations {
        iterations = iter + 1;
        let inertia = assign_to_centroids_parallel(
            vectors,
            n,
            dimensions,
            &centroids,
            k,
            &mut assignments,
            distance_fn,
        );
        let inertia_change = (prev_inertia - inertia).abs() / inertia.max(1.0);
        if inertia_change < config.tolerance {
            converged_inertia = Some(inertia);
            break;
        }
        prev_inertia = inertia;
        update_centroids_parallel(vectors, n, dimensions, &assignments, k, &mut centroids);
    }

    let converged = converged_inertia.is_some();
    let inertia = match converged_inertia {
        // The centroids have not moved since the converging pass, so its
        // assignments and inertia are already final.
        Some(inertia) => inertia,
        // The loop ended right after a centroid update (or never ran):
        // assignments must be refreshed against the current centroids.
        None => assign_to_centroids_parallel(
            vectors,
            n,
            dimensions,
            &centroids,
            k,
            &mut assignments,
            distance_fn,
        ),
    };

    Ok(KMeansResult {
        centroids,
        assignments,
        inertia,
        iterations,
        converged,
    })
}
#[cfg(not(target_arch = "wasm32"))]
fn assign_to_centroids_parallel(
vectors: &[f32],
n: usize,
dimensions: usize,
centroids: &[f32],
k: usize,
assignments: &mut [u32],
distance_fn: fn(&[f32], &[f32]) -> f32,
) -> f32 {
assignments
.par_iter_mut()
.enumerate()
.take(n)
.map(|(i, assignment)| {
let vec_offset = i * dimensions;
let vec = &vectors[vec_offset..vec_offset + dimensions];
let mut best_cluster = 0u32;
let mut best_dist = f32::INFINITY;
for c in 0..k {
let cent_offset = c * dimensions;
let centroid = ¢roids[cent_offset..cent_offset + dimensions];
let dist = distance_fn(vec, centroid);
if dist < best_dist {
best_dist = dist;
best_cluster = c as u32;
}
}
*assignment = best_cluster;
best_dist
})
.sum()
}
#[cfg(target_arch = "wasm32")]
/// Sequential stand-in for the assignment pass: gives each of the first `n`
/// vectors the index of its closest centroid.
///
/// `vectors` and `centroids` are flat buffers of `dimensions`-sized rows. On
/// ties the lowest centroid index is kept.
///
/// Returns the inertia: the sum of the winning distances.
fn assign_to_centroids_parallel(
    vectors: &[f32],
    n: usize,
    dimensions: usize,
    centroids: &[f32],
    k: usize,
    assignments: &mut [u32],
    distance_fn: fn(&[f32], &[f32]) -> f32,
) -> f32 {
    let mut inertia = 0.0f32;
    for i in 0..n {
        let vec_offset = i * dimensions;
        let vec = &vectors[vec_offset..vec_offset + dimensions];
        let mut best_cluster = 0u32;
        let mut best_dist = f32::INFINITY;
        for c in 0..k {
            let cent_offset = c * dimensions;
            let centroid = &centroids[cent_offset..cent_offset + dimensions];
            let dist = distance_fn(vec, centroid);
            // Strict `<` keeps the first centroid among equally close ones.
            if dist < best_dist {
                best_dist = dist;
                best_cluster = c as u32;
            }
        }
        assignments[i] = best_cluster;
        inertia += best_dist;
    }
    inertia
}
/// Rayon-backed centroid update: moves every non-empty cluster's centroid to
/// the mean of its members.
///
/// Each rayon fold accumulates its own per-cluster sums and member counts;
/// the partial results are then added together and divided. A cluster without
/// members keeps whatever centroid it had on entry.
#[cfg(not(target_arch = "wasm32"))]
fn update_centroids_parallel(
    vectors: &[f32],
    n: usize,
    dimensions: usize,
    assignments: &[u32],
    k: usize,
    centroids: &mut [f32],
) {
    // Fresh accumulator: zeroed component sums plus zeroed member counts.
    let zeroed = || (vec![0.0f32; k * dimensions], vec![0u32; k]);

    let (sums, members) = (0..n)
        .into_par_iter()
        .fold(zeroed, |(mut sums, mut members), row| {
            let label = assignments[row] as usize;
            let source = &vectors[row * dimensions..(row + 1) * dimensions];
            let target = &mut sums[label * dimensions..(label + 1) * dimensions];
            for (acc, &value) in target.iter_mut().zip(source) {
                *acc += value;
            }
            members[label] += 1;
            (sums, members)
        })
        .reduce(zeroed, |(mut sums, mut members), (other_sums, other_members)| {
            for (acc, extra) in sums.iter_mut().zip(other_sums) {
                *acc += extra;
            }
            for (acc, extra) in members.iter_mut().zip(other_members) {
                *acc += extra;
            }
            (sums, members)
        });

    for (cluster, &count) in members.iter().enumerate() {
        if count == 0 {
            continue;
        }
        let span = cluster * dimensions..(cluster + 1) * dimensions;
        let divisor = count as f32;
        for (centroid, &sum) in centroids[span.clone()].iter_mut().zip(&sums[span]) {
            *centroid = sum / divisor;
        }
    }
}
#[cfg(target_arch = "wasm32")]
/// Sequential stand-in for the centroid update: moves every non-empty
/// cluster's centroid to the mean of its members.
///
/// A cluster without members keeps whatever centroid it had on entry.
fn update_centroids_parallel(
    vectors: &[f32],
    n: usize,
    dimensions: usize,
    assignments: &[u32],
    k: usize,
    centroids: &mut [f32],
) {
    let mut totals = vec![0.0f32; k * dimensions];
    let mut sizes = vec![0u32; k];

    // Accumulate per-cluster component sums and member counts.
    for (row, &label) in assignments[..n].iter().enumerate() {
        let label = label as usize;
        let source = &vectors[row * dimensions..(row + 1) * dimensions];
        let target = &mut totals[label * dimensions..(label + 1) * dimensions];
        for (acc, &value) in target.iter_mut().zip(source) {
            *acc += value;
        }
        sizes[label] += 1;
    }

    // Turn the sums into means, leaving empty clusters untouched.
    let occupied = sizes.iter().enumerate().filter(|&(_, &size)| size != 0);
    for (cluster, &size) in occupied {
        let span = cluster * dimensions..(cluster + 1) * dimensions;
        let divisor = size as f32;
        for (centroid, &total) in centroids[span.clone()].iter_mut().zip(&totals[span]) {
            *centroid = total / divisor;
        }
    }
}
/// Distance-weighted centroid seeding with the per-vector distance refresh
/// spread across rayon workers.
///
/// The first centroid is a uniformly random input vector. Every further
/// centroid is drawn with probability proportional to a per-vector weight:
/// the square of `|distance_fn(vector, centroid)|`, minimised over the
/// centroids chosen so far. The weights are summed and scanned sequentially.
///
/// `seed` makes the draw reproducible; `None` seeds the generator from entropy.
///
/// # Panics
///
/// Panics if `n == 0` (the sampling range is empty) or if `vectors` holds
/// fewer than `n * dimensions` values.
// No caller is visible in this file, hence the `dead_code` allowance.
#[allow(dead_code)]
#[cfg(not(target_arch = "wasm32"))]
fn kmeans_plus_plus_init_parallel(
    vectors: &[f32],
    n: usize,
    dimensions: usize,
    k: usize,
    distance_fn: fn(&[f32], &[f32]) -> f32,
    seed: Option<u64>,
) -> Vec<f32> {
    let mut rng: StdRng = match seed {
        Some(s) => StdRng::seed_from_u64(s),
        None => StdRng::from_entropy(),
    };
    let mut centroids = Vec::with_capacity(k * dimensions);
    let first_offset = rng.gen_range(0..n) * dimensions;
    centroids.extend_from_slice(&vectors[first_offset..first_offset + dimensions]);

    // Weight of each vector: squared distance to its nearest chosen centroid.
    let mut min_dists = vec![f32::INFINITY; n];
    for c in 1..k {
        // Only the centroid added last can lower a weight.
        let prev_cent_offset = (c - 1) * dimensions;
        let prev_centroid = &centroids[prev_cent_offset..prev_cent_offset + dimensions];
        let new_dists: Vec<f32> = (0..n)
            .into_par_iter()
            .map(|i| {
                let vec_offset = i * dimensions;
                let vec = &vectors[vec_offset..vec_offset + dimensions];
                let abs_dist = distance_fn(vec, prev_centroid).abs();
                min_dists[i].min(abs_dist * abs_dist)
            })
            .collect();

        let mut total_dist = 0.0f32;
        for (slot, dist) in min_dists.iter_mut().zip(new_dists) {
            *slot = dist;
            total_dist += dist;
        }

        // Walk the cumulative weights until the random threshold is used up.
        let mut r = rng.gen::<f32>() * total_dist;
        // Fallback for the case where f32 rounding leaves `r` positive after
        // the whole scan: take the last vector that still carries weight
        // rather than index 0, which may already be a centroid.
        let mut selected_idx = min_dists.iter().rposition(|&w| w > 0.0).unwrap_or(0);
        for (i, &weight) in min_dists.iter().enumerate() {
            r -= weight;
            if r <= 0.0 {
                selected_idx = i;
                break;
            }
        }
        let selected_offset = selected_idx * dimensions;
        centroids.extend_from_slice(&vectors[selected_offset..selected_offset + dimensions]);
    }
    centroids
}
/// Sequential stand-in for the distance-weighted centroid seeding on
/// `wasm32` targets.
///
/// The first centroid is a uniformly random input vector. Every further
/// centroid is drawn with probability proportional to a per-vector weight:
/// the square of `|distance_fn(vector, centroid)|`, minimised over the
/// centroids chosen so far.
///
/// `seed` makes the draw reproducible; `None` seeds the generator from entropy.
///
/// # Panics
///
/// Panics if `n == 0` (the sampling range is empty) or if `vectors` holds
/// fewer than `n * dimensions` values.
// No caller is visible in this file, hence the `dead_code` allowance.
#[allow(dead_code)]
#[cfg(target_arch = "wasm32")]
fn kmeans_plus_plus_init_parallel(
    vectors: &[f32],
    n: usize,
    dimensions: usize,
    k: usize,
    distance_fn: fn(&[f32], &[f32]) -> f32,
    seed: Option<u64>,
) -> Vec<f32> {
    let mut rng: StdRng = match seed {
        Some(s) => StdRng::seed_from_u64(s),
        None => StdRng::from_entropy(),
    };
    let mut centroids = Vec::with_capacity(k * dimensions);
    let first_offset = rng.gen_range(0..n) * dimensions;
    centroids.extend_from_slice(&vectors[first_offset..first_offset + dimensions]);

    // Weight of each vector: squared distance to its nearest chosen centroid.
    let mut min_dists = vec![f32::INFINITY; n];
    for c in 1..k {
        // Only the centroid added last can lower a weight.
        let prev_cent_offset = (c - 1) * dimensions;
        let prev_centroid = &centroids[prev_cent_offset..prev_cent_offset + dimensions];
        for (i, min_dist) in min_dists.iter_mut().enumerate() {
            let vec_offset = i * dimensions;
            let vec = &vectors[vec_offset..vec_offset + dimensions];
            let abs_dist = distance_fn(vec, prev_centroid).abs();
            let candidate = abs_dist * abs_dist;
            if candidate < *min_dist {
                *min_dist = candidate;
            }
        }

        let mut total_dist = 0.0f32;
        for dist in &min_dists {
            total_dist += *dist;
        }

        // Walk the cumulative weights until the random threshold is used up.
        let mut r = rng.gen::<f32>() * total_dist;
        // Fallback for the case where f32 rounding leaves `r` positive after
        // the whole scan: take the last vector that still carries weight
        // rather than index 0, which may already be a centroid.
        let mut selected_idx = min_dists.iter().rposition(|&w| w > 0.0).unwrap_or(0);
        for (i, &weight) in min_dists.iter().enumerate() {
            r -= weight;
            if r <= 0.0 {
                selected_idx = i;
                break;
            }
        }
        let selected_offset = selected_idx * dimensions;
        centroids.extend_from_slice(&vectors[selected_offset..selected_offset + dimensions]);
    }
    centroids
}
/// Reasons a k-means run is rejected before any clustering work starts.
#[derive(Debug, Clone)]
pub enum KMeansError {
    /// Fewer input vectors (`n`) than requested clusters (`k`).
    NotEnoughVectors { n: usize, k: usize },
    /// The flat input buffer holds `got` values where `expected`
    /// (`n * dimensions`) were required.
    DimensionMismatch { expected: usize, got: usize },
}

impl std::fmt::Display for KMeansError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::NotEnoughVectors { n, k } => {
                write!(f, "Not enough vectors: {n} < {k} clusters")
            }
            Self::DimensionMismatch { expected, got } => {
                write!(f, "Dimension mismatch: expected {expected}, got {got}")
            }
        }
    }
}

impl std::error::Error for KMeansError {}
// Unit tests for the configuration builder, the sequential `kmeans`, the
// error messages, and `kmeans_parallel`. All clustering tests use
// `squared_euclidean` as the distance function.
#[cfg(test)]
mod tests {
use super::*;
// Default configuration exposes the documented cluster and iteration counts.
#[test]
fn test_kmeans_config_default() {
let config = KMeansConfig::default();
assert_eq!(config.n_clusters, 100);
assert_eq!(config.max_iterations, 25);
}
// Chained `with_*` setters each land in the matching field.
#[test]
fn test_kmeans_config_builder() {
let config = KMeansConfig::new(50)
.with_max_iterations(10)
.with_tolerance(1e-3)
.with_seed(42);
assert_eq!(config.n_clusters, 50);
assert_eq!(config.max_iterations, 10);
assert_eq!(config.seed, Some(42));
}
// Two noisy groups of 50 3-d points: checks output sizes and that the
// iteration count stays within the configured limit.
#[test]
fn test_kmeans_simple() {
let mut vectors = Vec::new();
for _ in 0..50 {
vectors.extend_from_slice(&[1.0 + rand::random::<f32>() * 0.1, 0.0, 0.0]);
}
for _ in 0..50 {
vectors.extend_from_slice(&[0.0, 1.0 + rand::random::<f32>() * 0.1, 0.0]);
}
let config = KMeansConfig::new(2).with_seed(42);
let result = kmeans(&vectors, 100, 3, &config, squared_euclidean).expect("expected value");
assert_eq!(result.centroids.len(), 2 * 3);
assert_eq!(result.assignments.len(), 100);
assert!(result.iterations <= config.max_iterations);
}
// One vector but two requested clusters must be rejected.
#[test]
fn test_kmeans_not_enough_vectors() {
let vectors = vec![1.0, 2.0, 3.0];
let config = KMeansConfig::new(2);
let result = kmeans(&vectors, 1, 3, &config, squared_euclidean);
assert!(matches!(result, Err(KMeansError::NotEnoughVectors { .. })));
}
// Four values cannot hold two 3-d vectors, so the length check must fail.
#[test]
fn test_kmeans_dimension_mismatch() {
let vectors = vec![1.0, 2.0, 3.0, 4.0];
let config = KMeansConfig::new(1);
let result = kmeans(&vectors, 2, 3, &config, squared_euclidean);
assert!(matches!(result, Err(KMeansError::DimensionMismatch { .. })));
}
// Two groups of identical points: passes when the run reports convergence
// or needed at most 10 iterations.
#[test]
fn test_kmeans_convergence() {
let mut vectors = Vec::new();
for _ in 0..100 {
vectors.extend_from_slice(&[0.0, 0.0]);
}
for _ in 0..100 {
vectors.extend_from_slice(&[10.0, 10.0]);
}
let config = KMeansConfig::new(2).with_seed(42).with_tolerance(1e-6);
let result = kmeans(&vectors, 200, 2, &config, squared_euclidean).expect("expected value");
assert!(result.converged || result.iterations <= 10);
}
// Two tight pairs far apart: each pair shares a label and the pairs differ.
#[test]
fn test_kmeans_assignments() {
let vectors = vec![
0.0, 0.0, 0.1, 0.1, 10.0, 10.0, 10.1, 10.1, ];
let config = KMeansConfig::new(2).with_seed(42);
let result = kmeans(&vectors, 4, 2, &config, squared_euclidean).expect("expected value");
assert_eq!(result.assignments[0], result.assignments[1]);
assert_eq!(result.assignments[2], result.assignments[3]);
assert_ne!(result.assignments[0], result.assignments[2]);
}
// Display output of both error variants mentions the numbers they carry.
#[test]
fn test_error_display() {
let err1 = KMeansError::NotEnoughVectors { n: 5, k: 10 };
assert!(err1.to_string().contains("5"));
assert!(err1.to_string().contains("10"));
let err2 = KMeansError::DimensionMismatch {
expected: 100,
got: 50,
};
assert!(err2.to_string().contains("100"));
assert!(err2.to_string().contains("50"));
}
// With only four vectors `kmeans_parallel` takes its small-input branch
// (below `MIN_VECTORS_PER_THREAD * 2`) and must still separate the pairs.
#[test]
fn test_kmeans_parallel_fallback_small_dataset() {
let vectors = vec![
0.0, 0.0, 0.1, 0.1, 10.0, 10.0, 10.1, 10.1, ];
let config = KMeansConfig::new(2).with_seed(42);
let result =
kmeans_parallel(&vectors, 4, 2, &config, squared_euclidean).expect("expected value");
assert_eq!(result.assignments[0], result.assignments[1]);
assert_eq!(result.assignments[2], result.assignments[3]);
assert_ne!(result.assignments[0], result.assignments[2]);
}
// 5000 vectors is above the small-input threshold: checks output sizes and
// that the inertia is finite and non-negative.
#[test]
fn test_kmeans_parallel_large_dataset() {
let n = 5000; let dims = 16;
let k = 10;
let mut vectors = Vec::with_capacity(n * dims);
for i in 0..n {
for d in 0..dims {
let cluster_center = (i % k) as f32 * 10.0;
vectors.push(cluster_center + (d as f32) * 0.01 + (i as f32) * 0.0001);
}
}
let config = KMeansConfig::new(k).with_seed(42).with_max_iterations(15);
let result =
kmeans_parallel(&vectors, n, dims, &config, squared_euclidean).expect("expected value");
assert_eq!(result.centroids.len(), k * dims);
assert_eq!(result.assignments.len(), n);
assert!(result.inertia.is_finite());
assert!(result.inertia >= 0.0);
}
// Same data and seed through both entry points: output sizes match and the
// final inertias agree within a factor of two.
#[test]
fn test_kmeans_parallel_vs_sequential_consistency() {
let n = 3000;
let dims = 8;
let k = 5;
let mut vectors = Vec::with_capacity(n * dims);
for i in 0..n {
let cluster = i % k;
for d in 0..dims {
vectors.push((cluster * 100 + d) as f32 + rand::random::<f32>() * 0.1);
}
}
let config = KMeansConfig::new(k).with_seed(123).with_max_iterations(20);
let result_par =
kmeans_parallel(&vectors, n, dims, &config, squared_euclidean).expect("expected value");
let result_seq = kmeans(&vectors, n, dims, &config, squared_euclidean).expect("expected value");
assert_eq!(result_par.centroids.len(), result_seq.centroids.len());
assert_eq!(result_par.assignments.len(), result_seq.assignments.len());
let ratio = result_par.inertia / result_seq.inertia;
assert!(ratio > 0.5 && ratio < 2.0, "Inertia ratio: {ratio}");
}
// Four groups 100 units apart: every resulting cluster must hold more than
// half of one group's vectors.
#[test]
fn test_kmeans_parallel_well_separated_clusters() {
let n = 4000;
let dims = 4;
let k = 4;
let vectors_per_cluster = n / k;
let mut vectors = Vec::with_capacity(n * dims);
for cluster in 0..k {
let center = cluster as f32 * 100.0;
for _ in 0..vectors_per_cluster {
for d in 0..dims {
vectors.push(center + (d as f32) * 0.1 + rand::random::<f32>() * 0.5);
}
}
}
let config = KMeansConfig::new(k).with_seed(456).with_max_iterations(25);
let result =
kmeans_parallel(&vectors, n, dims, &config, squared_euclidean).expect("expected value");
let mut cluster_counts = vec![0usize; k];
for &assignment in &result.assignments {
cluster_counts[assignment as usize] += 1;
}
for count in &cluster_counts {
assert!(
*count > vectors_per_cluster / 2,
"Cluster has too few vectors: {count}"
);
}
}
// Three groups of identical points: passes when the run reports
// convergence or used fewer than 20 iterations.
#[test]
fn test_kmeans_parallel_convergence() {
let n = 2500;
let dims = 6;
let k = 3;
let mut vectors = Vec::with_capacity(n * dims);
for i in 0..n {
let cluster = i % k;
for _ in 0..dims {
vectors.push(cluster as f32 * 1000.0);
}
}
let config = KMeansConfig::new(k)
.with_seed(789)
.with_max_iterations(50)
.with_tolerance(1e-6);
let result =
kmeans_parallel(&vectors, n, dims, &config, squared_euclidean).expect("expected value");
assert!(
result.converged || result.iterations < 20,
"Should converge quickly with perfect clusters, iterations: {}",
result.iterations
);
}
}