use ferrolearn_core::error::FerroError;
use ferrolearn_core::traits::{Fit, Predict, Transform};
use ndarray::{Array1, Array2};
use num_traits::Float;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rayon::prelude::*;
#[derive(Debug, Clone)]
pub struct KMeans<F> {
pub n_clusters: usize,
pub max_iter: usize,
pub tol: F,
pub n_init: usize,
pub random_state: Option<u64>,
}
impl<F: Float> KMeans<F> {
#[must_use]
pub fn new(n_clusters: usize) -> Self {
Self {
n_clusters,
max_iter: 300,
tol: F::from(1e-4).unwrap_or_else(F::epsilon),
n_init: 10,
random_state: None,
}
}
#[must_use]
pub fn with_max_iter(mut self, max_iter: usize) -> Self {
self.max_iter = max_iter;
self
}
#[must_use]
pub fn with_tol(mut self, tol: F) -> Self {
self.tol = tol;
self
}
#[must_use]
pub fn with_n_init(mut self, n_init: usize) -> Self {
self.n_init = n_init;
self
}
#[must_use]
pub fn with_random_state(mut self, seed: u64) -> Self {
self.random_state = Some(seed);
self
}
}
#[derive(Debug, Clone)]
pub struct FittedKMeans<F> {
cluster_centers_: Array2<F>,
labels_: Array1<usize>,
inertia_: F,
n_iter_: usize,
}
impl<F: Float> FittedKMeans<F> {
#[must_use]
pub fn cluster_centers(&self) -> &Array2<F> {
&self.cluster_centers_
}
#[must_use]
pub fn labels(&self) -> &Array1<usize> {
&self.labels_
}
#[must_use]
pub fn inertia(&self) -> F {
self.inertia_
}
#[must_use]
pub fn n_iter(&self) -> usize {
self.n_iter_
}
}
fn squared_euclidean<F: Float>(a: &[F], b: &[F]) -> F {
a.iter()
.zip(b.iter())
.fold(F::zero(), |acc, (&ai, &bi)| acc + (ai - bi) * (ai - bi))
}
fn kmeans_plus_plus<F: Float>(x: &Array2<F>, k: usize, rng: &mut StdRng) -> Array2<F> {
let n_samples = x.nrows();
let n_features = x.ncols();
let mut centers = Array2::zeros((k, n_features));
if n_samples == 0 {
return centers;
}
let first_idx = rng.random_range(0..n_samples);
centers.row_mut(0).assign(&x.row(first_idx));
if k == 1 {
return centers;
}
let mut min_dists = Array1::from_elem(n_samples, F::max_value());
{
let center0 = centers.row(0);
let center0_slice = center0.as_slice().unwrap_or(&[]);
for i in 0..n_samples {
min_dists[i] = squared_euclidean(
x.row(i).as_slice().unwrap_or(&[]),
center0_slice,
);
}
}
let n_trials = (2 + (k as f64).ln().floor() as usize).max(1);
for c in 1..k {
let total: F = min_dists.iter().fold(F::zero(), |acc, &d| acc + d);
if total <= F::zero() {
let idx = rng.random_range(0..n_samples);
centers.row_mut(c).assign(&x.row(idx));
continue;
}
let mut best_candidate: usize = 0;
let mut best_potential: Option<F> = None;
let mut best_new_dists: Option<Array1<F>> = None;
for _ in 0..n_trials {
let threshold: F =
F::from(rng.random::<f64>()).unwrap_or_else(F::zero) * total;
let mut cumsum = F::zero();
let mut candidate = n_samples - 1;
for i in 0..n_samples {
cumsum = cumsum + min_dists[i];
if cumsum >= threshold {
candidate = i;
break;
}
}
let cand_slice = x.row(candidate).as_slice().unwrap_or(&[]).to_vec();
let mut new_dists = min_dists.clone();
let mut potential = F::zero();
for i in 0..n_samples {
let d = squared_euclidean(
x.row(i).as_slice().unwrap_or(&[]),
&cand_slice,
);
if d < new_dists[i] {
new_dists[i] = d;
}
potential = potential + new_dists[i];
}
if best_potential.is_none_or(|bp| potential < bp) {
best_potential = Some(potential);
best_candidate = candidate;
best_new_dists = Some(new_dists);
}
}
centers.row_mut(c).assign(&x.row(best_candidate));
if let Some(d) = best_new_dists {
min_dists = d;
}
}
centers
}
const PARALLEL_WORK_THRESHOLD: usize = 100_000;
fn assign_clusters<F: Float + Send + Sync>(
x: &Array2<F>,
centers: &Array2<F>,
) -> (Array1<usize>, F) {
let n_samples = x.nrows();
let mut labels = Array1::zeros(n_samples);
let inertia = assign_clusters_into(&mut labels, x, centers);
(labels, inertia)
}
fn assign_clusters_into<F: Float + Send + Sync>(
labels: &mut Array1<usize>,
x: &Array2<F>,
centers: &Array2<F>,
) -> F {
let work = x.nrows() * x.ncols();
if work < PARALLEL_WORK_THRESHOLD {
assign_serial(labels, x, centers)
} else {
assign_parallel(labels, x, centers)
}
}
#[inline]
fn nearest_center<F: Float>(row_slice: &[F], centers: &Array2<F>) -> (usize, F) {
let k = centers.nrows();
let mut best_label = 0;
let mut best_dist = F::max_value();
for c in 0..k {
let center_row = centers.row(c);
let center_slice = center_row.as_slice().unwrap_or(&[]);
let d = squared_euclidean(row_slice, center_slice);
if d < best_dist {
best_dist = d;
best_label = c;
}
}
(best_label, best_dist)
}
fn assign_serial<F: Float>(labels: &mut Array1<usize>, x: &Array2<F>, centers: &Array2<F>) -> F {
let n_samples = x.nrows();
let mut inertia = F::zero();
for i in 0..n_samples {
let row = x.row(i);
let row_slice = row.as_slice().unwrap_or(&[]);
let (label, dist) = nearest_center(row_slice, centers);
labels[i] = label;
inertia = inertia + dist;
}
inertia
}
fn assign_parallel<F: Float + Send + Sync>(
labels: &mut Array1<usize>,
x: &Array2<F>,
centers: &Array2<F>,
) -> F {
let n_samples = x.nrows();
let labels_slice = labels.as_slice_mut().unwrap();
let chunk_size = (n_samples / rayon::current_num_threads()).max(64);
labels_slice
.par_chunks_mut(chunk_size)
.enumerate()
.map(|(chunk_idx, chunk)| {
let start = chunk_idx * chunk_size;
let mut local_inertia = F::zero();
for (local_i, label) in chunk.iter_mut().enumerate() {
let i = start + local_i;
let row = x.row(i);
let row_slice = row.as_slice().unwrap_or(&[]);
let (best_label, dist) = nearest_center(row_slice, centers);
*label = best_label;
local_inertia = local_inertia + dist;
}
local_inertia
})
.reduce(F::zero, |a, b| a + b)
}
fn recompute_centroids_into<F: Float>(
new_centers: &mut Array2<F>,
counts: &mut [F],
x: &Array2<F>,
labels: &Array1<usize>,
n_features: usize,
old_centers: &Array2<F>,
) -> F {
let k = new_centers.nrows();
new_centers.fill(F::zero());
counts.iter_mut().for_each(|c| *c = F::zero());
for (i, &label) in labels.iter().enumerate() {
counts[label] = counts[label] + F::one();
for j in 0..n_features {
new_centers[[label, j]] = new_centers[[label, j]] + x[[i, j]];
}
}
for c in 0..k {
if counts[c] > F::zero() {
for j in 0..n_features {
new_centers[[c, j]] = new_centers[[c, j]] / counts[c];
}
} else {
new_centers.row_mut(c).assign(&old_centers.row(c));
}
}
let mut max_shift = F::zero();
for c in 0..k {
let shift = squared_euclidean(
new_centers.row(c).as_slice().unwrap_or(&[]),
old_centers.row(c).as_slice().unwrap_or(&[]),
);
if shift > max_shift {
max_shift = shift;
}
}
max_shift.sqrt()
}
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for KMeans<F> {
type Fitted = FittedKMeans<F>;
type Error = FerroError;
fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedKMeans<F>, FerroError> {
let (n_samples, n_features) = x.dim();
if self.n_clusters == 0 {
return Err(FerroError::InvalidParameter {
name: "n_clusters".into(),
reason: "must be at least 1".into(),
});
}
if n_samples == 0 {
return Err(FerroError::InsufficientSamples {
required: self.n_clusters,
actual: 0,
context: "KMeans requires at least n_clusters samples".into(),
});
}
if n_samples < self.n_clusters {
return Err(FerroError::InsufficientSamples {
required: self.n_clusters,
actual: n_samples,
context: "KMeans requires at least n_clusters samples".into(),
});
}
if self.n_init == 0 {
return Err(FerroError::InvalidParameter {
name: "n_init".into(),
reason: "must be at least 1".into(),
});
}
let base_seed = self.random_state.unwrap_or(0);
let mut best_result: Option<FittedKMeans<F>> = None;
let mut labels = Array1::zeros(n_samples);
let mut new_centers = Array2::zeros((self.n_clusters, n_features));
let mut counts = vec![F::zero(); self.n_clusters];
for run in 0..self.n_init {
let mut rng = StdRng::seed_from_u64(base_seed.wrapping_add(run as u64));
let mut centers = kmeans_plus_plus(x, self.n_clusters, &mut rng);
let mut n_iter = 0;
let mut inertia = F::max_value();
for iter in 0..self.max_iter {
inertia = assign_clusters_into(&mut labels, x, ¢ers);
let max_shift = recompute_centroids_into(
&mut new_centers,
&mut counts,
x,
&labels,
n_features,
¢ers,
);
std::mem::swap(&mut centers, &mut new_centers);
n_iter = iter + 1;
if max_shift < self.tol {
break;
}
}
let candidate = FittedKMeans {
cluster_centers_: centers,
labels_: labels.clone(),
inertia_: inertia,
n_iter_: n_iter,
};
match &best_result {
None => best_result = Some(candidate),
Some(best) => {
if candidate.inertia_ < best.inertia_ {
best_result = Some(candidate);
}
}
}
}
best_result.ok_or_else(|| FerroError::InvalidParameter {
name: "n_init".into(),
reason: "internal error: no runs completed".into(),
})
}
}
impl<F: Float + Send + Sync + 'static> KMeans<F> {
pub fn fit_predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
let fitted = self.fit(x, &())?;
fitted.predict(x)
}
}
impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedKMeans<F> {
type Output = Array1<usize>;
type Error = FerroError;
fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
let n_features = x.ncols();
let expected_features = self.cluster_centers_.ncols();
if n_features != expected_features {
return Err(FerroError::ShapeMismatch {
expected: vec![expected_features],
actual: vec![n_features],
context: "number of features must match fitted KMeans model".into(),
});
}
let (labels, _inertia) = assign_clusters(x, &self.cluster_centers_);
Ok(labels)
}
}
impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedKMeans<F> {
type Output = Array2<F>;
type Error = FerroError;
fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
let n_features = x.ncols();
let expected_features = self.cluster_centers_.ncols();
if n_features != expected_features {
return Err(FerroError::ShapeMismatch {
expected: vec![expected_features],
actual: vec![n_features],
context: "number of features must match fitted KMeans model".into(),
});
}
let n_samples = x.nrows();
let k = self.cluster_centers_.nrows();
let centers = &self.cluster_centers_;
let mut distances = vec![F::zero(); n_samples * k];
let work = n_samples * n_features;
if work < PARALLEL_WORK_THRESHOLD {
for i in 0..n_samples {
let row = x.row(i);
let row_slice = row.as_slice().unwrap_or(&[]);
for c in 0..k {
let center = centers.row(c);
let cs = center.as_slice().unwrap_or(&[]);
distances[i * k + c] = squared_euclidean(row_slice, cs).sqrt();
}
}
} else {
distances
.par_chunks_mut(k)
.enumerate()
.for_each(|(i, chunk)| {
let row = x.row(i);
let row_slice = row.as_slice().unwrap_or(&[]);
for (c, slot) in chunk.iter_mut().enumerate() {
let center = centers.row(c);
let cs = center.as_slice().unwrap_or(&[]);
*slot = squared_euclidean(row_slice, cs).sqrt();
}
});
}
Array2::from_shape_vec((n_samples, k), distances).map_err(|_| {
FerroError::NumericalInstability {
message: "failed to construct distance matrix".into(),
}
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
fn make_blobs() -> Array2<f64> {
Array2::from_shape_vec(
(9, 2),
vec![
0.0, 0.0, 0.1, 0.1, -0.1, 0.1, 10.0, 10.0, 10.1, 10.1, 9.9, 10.1, 0.0, 10.0, 0.1, 10.1, -0.1, 9.9,
],
)
.unwrap()
}
#[test]
fn test_well_separated_blobs() {
let x = make_blobs();
let model = KMeans::<f64>::new(3).with_random_state(42).with_n_init(5);
let fitted = model.fit(&x, &()).unwrap();
let labels = fitted.labels();
assert_eq!(labels[0], labels[1]);
assert_eq!(labels[0], labels[2]);
assert_eq!(labels[3], labels[4]);
assert_eq!(labels[3], labels[5]);
assert_eq!(labels[6], labels[7]);
assert_eq!(labels[6], labels[8]);
assert_ne!(labels[0], labels[3]);
assert_ne!(labels[0], labels[6]);
assert_ne!(labels[3], labels[6]);
}
#[test]
fn test_convergence() {
let x = make_blobs();
let model = KMeans::<f64>::new(3)
.with_random_state(42)
.with_max_iter(1000)
.with_n_init(1);
let fitted = model.fit(&x, &()).unwrap();
assert!(fitted.n_iter() < 100);
}
#[test]
fn test_n_init_picks_best() {
let x = make_blobs();
let model_1 = KMeans::<f64>::new(3).with_random_state(42).with_n_init(1);
let fitted_1 = model_1.fit(&x, &()).unwrap();
let model_10 = KMeans::<f64>::new(3).with_random_state(42).with_n_init(10);
let fitted_10 = model_10.fit(&x, &()).unwrap();
assert!(fitted_10.inertia() <= fitted_1.inertia() + 1e-10);
}
#[test]
fn test_kmeans_pp_initialization_deterministic() {
let x = make_blobs();
let model = KMeans::<f64>::new(3).with_random_state(123).with_n_init(1);
let fitted1 = model.fit(&x, &()).unwrap();
let fitted2 = model.fit(&x, &()).unwrap();
assert_eq!(fitted1.labels(), fitted2.labels());
assert_relative_eq!(fitted1.inertia(), fitted2.inertia(), epsilon = 1e-12);
}
#[test]
fn test_reproducibility_with_seed() {
let x = make_blobs();
let model = KMeans::<f64>::new(3).with_random_state(99);
let fitted1 = model.fit(&x, &()).unwrap();
let fitted2 = model.fit(&x, &()).unwrap();
assert_eq!(fitted1.labels(), fitted2.labels());
}
#[test]
fn test_predict_on_new_data() {
let x = make_blobs();
let model = KMeans::<f64>::new(3).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
let new_x =
Array2::from_shape_vec((3, 2), vec![0.05, 0.05, 10.05, 10.05, 0.05, 10.05]).unwrap();
let new_labels = fitted.predict(&new_x).unwrap();
let label_near_origin = new_labels[0];
assert_eq!(label_near_origin, fitted.labels()[0]);
let label_near_10_10 = new_labels[1];
assert_eq!(label_near_10_10, fitted.labels()[3]);
let label_near_0_10 = new_labels[2];
assert_eq!(label_near_0_10, fitted.labels()[6]);
}
#[test]
fn test_transform_distances() {
let x = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 10.0, 0.0, 0.0, 10.0, 10.0, 10.0])
.unwrap();
let model = KMeans::<f64>::new(2).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
let dists = fitted.transform(&x).unwrap();
assert_eq!(dists.dim(), (4, 2));
for i in 0..4 {
let own_cluster = fitted.labels()[i];
let other_cluster = 1 - own_cluster;
assert!(dists[[i, own_cluster]] < dists[[i, other_cluster]]);
}
}
#[test]
fn test_transform_shape() {
let x = make_blobs();
let model = KMeans::<f64>::new(3).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
let dists = fitted.transform(&x).unwrap();
assert_eq!(dists.dim(), (9, 3));
}
#[test]
fn test_cluster_centers_shape() {
let x = make_blobs();
let model = KMeans::<f64>::new(3).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
assert_eq!(fitted.cluster_centers().dim(), (3, 2));
}
#[test]
fn test_inertia_non_negative() {
let x = make_blobs();
let model = KMeans::<f64>::new(3).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
assert!(fitted.inertia() >= 0.0);
}
#[test]
fn test_k_equals_n_samples() {
let x = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 5.0, 5.0, 10.0, 10.0]).unwrap();
let model = KMeans::<f64>::new(3).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
assert_relative_eq!(fitted.inertia(), 0.0, epsilon = 1e-10);
let labels = fitted.labels();
assert_ne!(labels[0], labels[1]);
assert_ne!(labels[0], labels[2]);
assert_ne!(labels[1], labels[2]);
}
#[test]
fn test_single_cluster() {
let x =
Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]).unwrap();
let model = KMeans::<f64>::new(1).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
for &label in fitted.labels() {
assert_eq!(label, 0);
}
let center = fitted.cluster_centers().row(0);
assert_relative_eq!(center[0], 2.5, epsilon = 1e-10);
assert_relative_eq!(center[1], 2.5, epsilon = 1e-10);
}
#[test]
fn test_single_sample() {
let x = Array2::from_shape_vec((1, 2), vec![5.0, 5.0]).unwrap();
let model = KMeans::<f64>::new(1).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
assert_eq!(fitted.labels().len(), 1);
assert_eq!(fitted.labels()[0], 0);
assert_relative_eq!(fitted.inertia(), 0.0, epsilon = 1e-10);
}
#[test]
fn test_k_greater_than_n_samples() {
let x = Array2::from_shape_vec((2, 2), vec![1.0, 1.0, 2.0, 2.0]).unwrap();
let model = KMeans::<f64>::new(5);
let result = model.fit(&x, &());
assert!(result.is_err());
}
#[test]
fn test_zero_clusters() {
let x = Array2::from_shape_vec((3, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]).unwrap();
let model = KMeans::<f64>::new(0);
let result = model.fit(&x, &());
assert!(result.is_err());
}
#[test]
fn test_empty_data() {
let x = Array2::<f64>::zeros((0, 2));
let model = KMeans::<f64>::new(3);
let result = model.fit(&x, &());
assert!(result.is_err());
}
#[test]
fn test_predict_shape_mismatch() {
let x =
Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]).unwrap();
let model = KMeans::<f64>::new(2).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
let result = fitted.predict(&x_bad);
assert!(result.is_err());
}
#[test]
fn test_transform_shape_mismatch() {
let x =
Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]).unwrap();
let model = KMeans::<f64>::new(2).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
let result = fitted.transform(&x_bad);
assert!(result.is_err());
}
#[test]
fn test_f32_support() {
let x = Array2::from_shape_vec(
(6, 2),
vec![
0.0f32, 0.0, 0.1, 0.1, -0.1, 0.1, 10.0, 10.0, 10.1, 10.1, 9.9, 10.1,
],
)
.unwrap();
let model = KMeans::<f32>::new(2).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
assert_eq!(fitted.labels().len(), 6);
let preds = fitted.predict(&x).unwrap();
assert_eq!(preds.len(), 6);
}
#[test]
fn test_two_clusters_on_line() {
let x = Array2::from_shape_vec((6, 1), vec![0.0, 0.1, -0.1, 100.0, 100.1, 99.9]).unwrap();
let model = KMeans::<f64>::new(2).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
let labels = fitted.labels();
assert_eq!(labels[0], labels[1]);
assert_eq!(labels[0], labels[2]);
assert_eq!(labels[3], labels[4]);
assert_eq!(labels[3], labels[5]);
assert_ne!(labels[0], labels[3]);
}
#[test]
fn test_identical_points() {
let x =
Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]).unwrap();
let model = KMeans::<f64>::new(1).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
assert_relative_eq!(fitted.inertia(), 0.0, epsilon = 1e-10);
}
}