//! K-Means clustering with k-Means++ initialization.
//!
//! This module provides [`KMeans`], an unsupervised clustering algorithm
//! that partitions data into `k` clusters by minimizing within-cluster
//! sum of squared distances (inertia). The implementation uses Lloyd's
//! algorithm with k-Means++ initialization for smart centroid seeding.
//!
//! # Algorithm
//!
//! 1. **k-Means++ initialization**: pick the first center uniformly at random,
//! then pick each subsequent center from candidates drawn with probability
//! proportional to D(x)² (squared distance to the nearest existing center),
//! keeping the candidate that yields the lowest potential (greedy variant).
//! 2. **Lloyd's algorithm**: alternate between assigning samples to the nearest
//! centroid and recomputing centroids as the mean of their assigned samples.
//! 3. **Multi-start**: repeat `n_init` times and keep the result with the
//! lowest inertia.
//!
//! The assignment step runs serially for small inputs and is parallelized with
//! Rayon once `n_samples * n_features` reaches 100,000.
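//!
//! The D(x)² draw in step 1 can be sketched with a cumulative-sum scan. This is
//! a minimal std-only illustration, not this module's code: `pick_weighted` is a
//! hypothetical helper, and the uniform draw `u` is passed in rather than taken
//! from an RNG so that the result is deterministic. It assumes a non-empty
//! `weights` slice.
//!
//! ```rust
//! /// Pick an index with probability proportional to `weights[i]`, given a
//! /// uniform draw `u` in `[0, 1)`. Hypothetical helper, for illustration only.
//! fn pick_weighted(weights: &[f64], u: f64) -> usize {
//!     let total: f64 = weights.iter().sum();
//!     let threshold = u * total;
//!     let mut cumsum = 0.0;
//!     for (i, &w) in weights.iter().enumerate() {
//!         cumsum += w;
//!         // The first index whose running sum reaches the threshold wins.
//!         if cumsum >= threshold {
//!             return i;
//!         }
//!     }
//!     // Fall back to the last index if rounding left the threshold unreached.
//!     weights.len() - 1
//! }
//! ```
//!
//! With squared distances `[0.0, 1.0, 4.0, 5.0]`, a draw of `u = 0.3` gives a
//! threshold of `3.0` and selects index `2`; `u = 0.9` gives `9.0` and selects
//! index `3`. Samples far from every existing center are picked more often.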
//!
//! # Examples
//!
//! ```
//! use ferrolearn_cluster::KMeans;
//! use ferrolearn_core::{Fit, Predict, Transform};
//! use ndarray::Array2;
//!
//! let x = Array2::from_shape_vec((6, 2), vec![
//! 1.0, 1.0, 1.1, 1.0, 1.0, 1.1,
//! 5.0, 5.0, 5.1, 5.0, 5.0, 5.1,
//! ]).unwrap();
//!
//! let model = KMeans::<f64>::new(2);
//! let fitted = model.fit(&x, &()).unwrap();
//! let labels = fitted.predict(&x).unwrap();
//! assert_eq!(labels.len(), 6);
//! ```
//!
//! # `## REQ status`
//!
//! Binary (R-DEFER-2), translating `sklearn/cluster/_kmeans.py`
//! (`class KMeans(_BaseKMeans)` `:1196`; `__init__` `:1388`; `_kmeans_single_lloyd`
//! `:631`; `_tolerance` `:286`; `_check_params_vs_input` `:875`). Design doc:
//! `.design/cluster/kmeans.md`. Cites use ferrolearn symbol anchors / sklearn
//! `file:line` (commit 156ef14); expected values from the live sklearn 1.5.2 oracle
//! (R-CHAR-3). KMeans has a REAL CPython consumer: `_RsKMeans`
//! (`ferrolearn-python/src/clusterers.rs`) → `ferrolearn.KMeans`. Verify-and-document
//! unit: the `labels_` PARTITION (up to a label permutation, well-separated regime),
//! `predict`/`transform` contracts, the PyO3 marshalling, the labels_/inertia_↔centers
//! consistency, and the `n_init` default match sklearn and SHIP. Exact
//! `cluster_centers_`/`inertia_` VALUES + `labels_` integers + `n_iter_` DIVERGE —
//! blocked by numpy-RNG init parity (#1039), the convergence criterion + relative tol
//! (#1036), and empty-cluster relocation (#1040).
//!
//! | REQ | Status | Evidence |
//! |---|---|---|
//! | REQ-1 (`labels_` PARTITION up-to-permutation, separable data) | SHIPPED | impl `Fit::fit` (greedy k-means++ `fn kmeans_plus_plus` → Lloyd `fn assign_clusters_into`/`fn recompute_centroids_into` → best-of-`n_init`) recovers sklearn's grouping on well-separated data. Consumers: PyO3 `RsKMeans::fit` (`clusterers.rs`) + crate re-export `pub use kmeans::{FittedKMeans, KMeans}` (`lib.rs`). Guards: `green_req1_two_blob_partition`, `green_req1_three_blob_partition` in `tests/divergence_kmeans.rs` (canonicalized, live-oracle). Underclaim: PARTITION up-to-permutation only — `labels_` integers + `cluster_centers_`/`inertia_` VALUES (REQ-9) diverge. |
//! | REQ-2 (`predict` nearest-center contract) | SHIPPED | impl `Predict::predict` (`fn nearest_center`) returns argmin-squared-euclidean center index + `FerroError::ShapeMismatch`; `transform(X).argmin(1) == predict(X)`. Consumers: PyO3 `RsKMeans::predict` + crate re-export. Guard: `green_req2_predict_is_transform_argmin`. |
//! | REQ-3 (`transform` distance-to-centers contract) | SHIPPED | impl `Transform::transform` returns shape `(n_samples, n_clusters)`, column j = Euclidean distance to center j (= sklearn `_BaseKMeans.transform`). Consumers: PyO3 `RsKMeans::transform` + crate re-export. Guard: `green_req3_transform_shape_and_nonneg`. Underclaim: CONTRACT only — distances track `cluster_centers_` (REQ-9). |
//! | REQ-4 (PyO3 binding marshalling) | SHIPPED | impl `#[pyclass(name="_RsKMeans")] RsKMeans` (`ferrolearn-python/src/clusterers.rs`) marshals `fit`/`predict`/`transform` + `cluster_centers_`/`labels_`/`inertia_`/`n_iter_` getters; registered in `ferrolearn-python/src/lib.rs`, wrapped `class KMeans(...)` in `python/ferrolearn/_clusterers.py`, exported in `__init__.py`. Verification: `maturin develop` + `pytest tests/ -q`. Underclaim: marshalling/shape contract — the binding's `n_init=10` signature default + `int64` label dtype diverge (REQ-12); fitted VALUES inherit REQ-9. |
//! | REQ-6 (`labels_`/`inertia_` consistency with final centers) | SHIPPED | impl `Fit::fit` runs a final E-step (`inertia = assign_clusters_into(&mut labels, x, &centers)`) after the Lloyd loop so `labels_`/`inertia_` match the post-swap `cluster_centers_`, mirroring sklearn's post-loop E-step re-run (`_kmeans.py:605-625`). Invariant `fit(X).predict(X) == labels_` now holds. Guard: `pin_req6_predict_equals_labels`. Fixed #1037. |
//! | REQ-14 (`n_init` constructor default = 1) | SHIPPED | impl `fn new` defaults `n_init: 1`, matching sklearn `n_init="auto"` → 1 for the default `init="k-means++"` (`_kmeans.py:886-896`). Guard: `pin_req14_n_init_default_is_one`. Fixed #1045. (The PyO3 binding's `n_init=10` signature default is the separate REQ-12.) |
//! | REQ-5 (convergence criterion + relative tol) | NOT-STARTED | open prereq blocker #1036. sklearn converges on label-no-change OR `(center_shift**2).sum() <= mean(var(X))*tol` (`_kmeans.py:286-294,586-601`); ferrolearn uses absolute `tol` + `max_shift < tol` MAX-euclidean-shift (`fn recompute_centroids_into` + `fn fit`). Different threshold/reduction → different `n_iter_` + stop point. |
//! | REQ-7 (`init` param + random/array/callable + exact k-means++) | NOT-STARTED | open prereq blocker #1038. sklearn `init ∈ {"k-means++","random"}|callable|array` default `"k-means++"` (`_kmeans.py:1391`); ferrolearn has NO `init` param (always greedy k-means++, `fn kmeans_plus_plus`). Default matches; param surface + non-default inits missing; exact k-means++ diverges (numpy RNG, REQ-8). |
//! | REQ-8 (`random_state` numpy-RNG parity) | NOT-STARTED | open prereq blocker #1039. sklearn `check_random_state` + numpy RNG; ferrolearn `StdRng::seed_from_u64` (`fn fit`). Different RNG → exact centers/labels/inertia/n_iter cannot match. Depends on a ferray `random` analog (R-SUBSTRATE-5); blocks REQ-9. |
//! | REQ-9 (centers/inertia/label-integers/n_iter VALUE parity) | NOT-STARTED | open prereq blocker #1040. Exact values diverge via numpy-RNG (REQ-8), convergence + relative tol (REQ-5), and empty-cluster relocation — sklearn moves an emptied center to the farthest sample (`_relocate_empty_clusters_dense`), ferrolearn keeps the old center (`fn recompute_centroids_into` else-branch). Gated on REQ-5/REQ-8. |
//! | REQ-10 (ctor/fit surface init/algorithm/copy_x/verbose/sample_weight + n_clusters=8 + error ABI) | NOT-STARTED | open prereq blocker #1041. sklearn `__init__` (`_kmeans.py:1387-1411`) + `fit(sample_weight)` + `_check_params_vs_input` (`InvalidParameterError`, `:875-908`) + `n_features_in_`; ferrolearn `KMeans<F>` has `n_clusters/max_iter/tol/n_init/random_state` only + `FerroError` ABI. |
//! | REQ-11 (`score` + `fit_transform`) | NOT-STARTED | open prereq blocker #1042. sklearn `KMeans.score(X) = -inertia` (`_kmeans.py:1156-1184`) + `fit_transform`; ferrolearn `FittedKMeans` has neither (only `fn fit_predict`); `RsKMeans` has no `score`/`fit_transform`. |
//! | REQ-12 (ferrolearn-python binding `n_init` default + `labels_` dtype) | NOT-STARTED | open prereq blocker #1043. PyO3 `RsKMeans::new` signature `n_init=10` + Python `ferrolearn.KMeans` `n_init=10` diverge from sklearn's effective `1` for k-means++ (R-DEFER-7 last layer); binding marshals `labels_` to `int64`, not sklearn `int32`. |
//! | REQ-13 (ferray substrate) | NOT-STARTED | open prereq blocker #1044. `kmeans.rs` imports `ndarray`/`num-traits`/`rand`/`rayon`, not `ferray-core`/`ferray::linalg`/`ferray::random` (R-SUBSTRATE-1/2; RNG entangled with REQ-8). |
//! | REQ-15 (reject non-finite input) | SHIPPED | `fn reject_non_finite` called at the top of `Fit::fit` (after the param/sample checks, before k-means++/Lloyd) AND in `Predict::predict` (on the query X) rejects NaN AND infinity with `FerroError::InvalidParameter{name:"X"}`, mirroring sklearn's `_validate_data(force_all_finite=True)` default reached from `KMeans.fit` (`_kmeans.py:1464`) and `KMeans.predict`→`_check_test_data` (`:950`), which raise `ValueError` (`validation.py:147-154`). Consumers: the existing `fit`/`predict` entries — PyO3 `RsKMeans::fit`/`::predict` (`clusterers.rs`) + crate re-export `pub use kmeans::{FittedKMeans, KMeans}` (`lib.rs`). Pinned by `divergence_nonfinite_reject.rs` (`divergence_kmeans_fit_rejects_nan/_inf`, `divergence_kmeans_predict_rejects_nan`) — live sklearn 1.5.2 raises, ferrolearn now `Err`. Finite input byte-identical (the module's oracle pins stay green). |
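//!
//! The two stopping rules contrasted in REQ-5 can be sketched side by side. This
//! is a minimal std-only illustration under stated assumptions: both function
//! names are hypothetical, `shifts` holds each center's Euclidean movement in one
//! iteration, and the sklearn-side rule leaves out the separate
//! label-no-change check.
//!
//! ```rust
//! /// Absolute rule: stop when the largest per-center shift is below `tol`.
//! fn converged_max_shift(shifts: &[f64], tol: f64) -> bool {
//!     let max_shift = shifts.iter().copied().fold(0.0_f64, f64::max);
//!     max_shift < tol
//! }
//!
//! /// Relative rule: stop when the sum of squared shifts is at most `tol`
//! /// scaled by the mean per-feature variance of the data.
//! fn converged_relative(shifts: &[f64], mean_feature_variance: f64, tol: f64) -> bool {
//!     let total_sq_shift: f64 = shifts.iter().map(|s| s * s).sum();
//!     total_sq_shift <= mean_feature_variance * tol
//! }
//! ```
//!
//! For shifts `[0.05, 0.02]` and `tol = 1e-4`, the absolute rule keeps iterating,
//! while the relative rule stops when the mean per-feature variance is `100.0`
//! (threshold `0.01` against a squared-shift sum of about `0.0029`). The same
//! data can therefore yield different iteration counts under the two rules.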
use ferrolearn_core::error::FerroError;
use ferrolearn_core::traits::{Fit, Predict, Transform};
use ndarray::{Array1, Array2};
use num_traits::Float;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rayon::prelude::*;
/// K-Means clustering configuration (unfitted).
///
/// Holds hyperparameters for the k-Means algorithm. Call [`Fit::fit`]
/// to run the algorithm and produce a [`FittedKMeans`].
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct KMeans<F> {
/// Number of clusters to form.
pub n_clusters: usize,
/// Maximum number of Lloyd iterations per run.
pub max_iter: usize,
/// Convergence tolerance: the algorithm stops when the maximum
/// centroid movement is less than this value.
pub tol: F,
/// Number of independent runs with different initializations.
/// The result with the lowest inertia is kept.
pub n_init: usize,
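/// Optional random seed. Run `i` is seeded with `seed + i` (wrapping); when
/// `None`, the base seed is `0`, so fits are deterministic either way.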
pub random_state: Option<u64>,
}
impl<F: Float> KMeans<F> {
/// Create a new `KMeans` with the given number of clusters.
///
/// Uses default values: `max_iter = 300`, `tol = 1e-4`,
/// `n_init = 1`, `random_state = None`.
///
/// `n_init = 1` mirrors scikit-learn's `n_init="auto"`, which resolves to
/// `1` for the default `init="k-means++"`
/// (`sklearn/cluster/_kmeans.py:886-888`; docstring `:359-361`). ferrolearn
/// always uses k-means++ seeding, so the sklearn-matching default is `1`.
#[must_use]
pub fn new(n_clusters: usize) -> Self {
Self {
n_clusters,
max_iter: 300,
tol: F::from(1e-4).unwrap_or_else(F::epsilon),
n_init: 1,
random_state: None,
}
}
/// Set the maximum number of iterations.
#[must_use]
pub fn with_max_iter(mut self, max_iter: usize) -> Self {
self.max_iter = max_iter;
self
}
/// Set the convergence tolerance.
#[must_use]
pub fn with_tol(mut self, tol: F) -> Self {
self.tol = tol;
self
}
/// Set the number of independent runs.
#[must_use]
pub fn with_n_init(mut self, n_init: usize) -> Self {
self.n_init = n_init;
self
}
/// Set the random seed for reproducibility.
#[must_use]
pub fn with_random_state(mut self, seed: u64) -> Self {
self.random_state = Some(seed);
self
}
}
/// Fitted K-Means model.
///
/// Stores the learned cluster centers, labels, inertia, and iteration count.
/// Implements [`Predict`] to assign new samples to clusters and [`Transform`]
/// to compute distances to each centroid.
#[derive(Debug, Clone)]
pub struct FittedKMeans<F> {
/// Cluster center coordinates, shape `(n_clusters, n_features)`.
cluster_centers_: Array2<F>,
/// Cluster label for each training sample.
labels_: Array1<usize>,
/// Sum of squared distances of samples to their closest cluster center.
inertia_: F,
/// Number of iterations run in the best run.
n_iter_: usize,
}
impl<F: Float> FittedKMeans<F> {
/// Return the cluster centers, shape `(n_clusters, n_features)`.
#[must_use]
pub fn cluster_centers(&self) -> &Array2<F> {
&self.cluster_centers_
}
/// Return the cluster labels for the training data.
#[must_use]
pub fn labels(&self) -> &Array1<usize> {
&self.labels_
}
/// Return the inertia (sum of squared distances to nearest centroid).
#[must_use]
pub fn inertia(&self) -> F {
self.inertia_
}
/// Return the number of iterations in the best run.
#[must_use]
pub fn n_iter(&self) -> usize {
self.n_iter_
}
}
/// Reject `X` containing any non-finite value (NaN or infinity).
///
/// Mirrors sklearn's `_validate_data(..., force_all_finite=True)` (the default),
/// which `KMeans.fit` (`sklearn/cluster/_kmeans.py:1464`) and `KMeans.predict` →
/// `_check_test_data` (`:950`) reach: both raise
/// `ValueError("Input X contains NaN.")` / `"... contains infinity ..."`
/// (`sklearn/utils/validation.py:147-154`). Clustering has no missing-value
/// support, so NaN AND infinity are both rejected. Never panics (R-CODE-2).
fn reject_non_finite<F: Float>(x: &Array2<F>) -> Result<(), FerroError> {
if x.iter().any(|v| !v.is_finite()) {
return Err(FerroError::InvalidParameter {
name: "X".into(),
reason: "Input X contains NaN or infinity.".into(),
});
}
Ok(())
}
/// Compute the squared Euclidean distance between two slices.
fn squared_euclidean<F: Float>(a: &[F], b: &[F]) -> F {
a.iter()
.zip(b.iter())
.fold(F::zero(), |acc, (&ai, &bi)| acc + (ai - bi) * (ai - bi))
}
/// Greedy k-Means++ initialization (Arthur & Vassilvitskii 2007 with the
/// scikit-learn-style multi-trial improvement). At each pick, sample
/// `2 + floor(ln(k))` candidates with probability proportional to D(x)² and
/// keep the one minimising the resulting potential.
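///
/// The candidate count per pick can be checked in isolation. A minimal sketch,
/// where `n_local_trials` is a hypothetical name for the expression used in the
/// function body:
///
/// ```rust
/// /// Number of candidates drawn per center: 2 plus the floored natural log of `k`.
/// fn n_local_trials(k: usize) -> usize {
///     (2 + (k as f64).ln().floor() as usize).max(1)
/// }
/// ```
///
/// `k = 3` gives 3 trials, `k = 8` gives 4, and `k = 100` gives 6, so the extra
/// work per pick grows only logarithmically with the number of clusters.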
fn kmeans_plus_plus<F: Float>(x: &Array2<F>, k: usize, rng: &mut StdRng) -> Array2<F> {
let n_samples = x.nrows();
let n_features = x.ncols();
let mut centers = Array2::zeros((k, n_features));
if n_samples == 0 {
return centers;
}
let first_idx = rng.random_range(0..n_samples);
centers.row_mut(0).assign(&x.row(first_idx));
if k == 1 {
return centers;
}
// Distance from each sample to its nearest selected centre.
let mut min_dists = Array1::from_elem(n_samples, F::max_value());
{
let center0 = centers.row(0);
let center0_slice = center0.as_slice().unwrap_or(&[]);
for i in 0..n_samples {
min_dists[i] = squared_euclidean(x.row(i).as_slice().unwrap_or(&[]), center0_slice);
}
}
let n_trials = (2 + (k as f64).ln().floor() as usize).max(1);
for c in 1..k {
let total: F = min_dists.iter().fold(F::zero(), |acc, &d| acc + d);
if total <= F::zero() {
let idx = rng.random_range(0..n_samples);
centers.row_mut(c).assign(&x.row(idx));
continue;
}
let mut best_candidate: usize = 0;
let mut best_potential: Option<F> = None;
let mut best_new_dists: Option<Array1<F>> = None;
for _ in 0..n_trials {
let threshold: F = F::from(rng.random::<f64>()).unwrap_or_else(F::zero) * total;
let mut cumsum = F::zero();
let mut candidate = n_samples - 1;
for i in 0..n_samples {
cumsum = cumsum + min_dists[i];
if cumsum >= threshold {
candidate = i;
break;
}
}
let cand_slice = x.row(candidate).as_slice().unwrap_or(&[]).to_vec();
let mut new_dists = min_dists.clone();
let mut potential = F::zero();
for i in 0..n_samples {
let d = squared_euclidean(x.row(i).as_slice().unwrap_or(&[]), &cand_slice);
if d < new_dists[i] {
new_dists[i] = d;
}
potential = potential + new_dists[i];
}
if best_potential.is_none_or(|bp| potential < bp) {
best_potential = Some(potential);
best_candidate = candidate;
best_new_dists = Some(new_dists);
}
}
centers.row_mut(c).assign(&x.row(best_candidate));
if let Some(d) = best_new_dists {
min_dists = d;
}
}
centers
}
/// Minimum work units (samples * features) before we parallelize.
///
/// Rayon's fork/join overhead is not amortized when the per-task work is
/// small. At 1K samples x 10 features (10K work units), the serial path
/// is faster. At 10K samples x 100 features (1M work units), parallelism
/// wins comfortably.
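///
/// The cutoff can be checked with a one-line predicate. A minimal sketch;
/// `should_parallelize` is a hypothetical helper that mirrors the
/// `work < PARALLEL_WORK_THRESHOLD` test with the constant inlined:
///
/// ```rust
/// /// True when the input is large enough to take the parallel path.
/// fn should_parallelize(n_samples: usize, n_features: usize) -> bool {
///     n_samples * n_features >= 100_000
/// }
/// ```
///
/// 1,000 samples with 10 features (10,000 work units) stay serial, while
/// 10,000 samples with 100 features (1,000,000 work units) go parallel.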
const PARALLEL_WORK_THRESHOLD: usize = 100_000;
/// Assign each sample to its nearest centroid.
///
/// Returns `(labels, inertia)`. Uses serial iteration for small inputs
/// and Rayon parallelism for larger ones.
fn assign_clusters<F: Float + Send + Sync>(
x: &Array2<F>,
centers: &Array2<F>,
) -> (Array1<usize>, F) {
let n_samples = x.nrows();
let mut labels = Array1::zeros(n_samples);
let inertia = assign_clusters_into(&mut labels, x, centers);
(labels, inertia)
}
/// Assign each sample to its nearest centroid, writing into a pre-allocated
/// labels array. Returns the inertia.
fn assign_clusters_into<F: Float + Send + Sync>(
labels: &mut Array1<usize>,
x: &Array2<F>,
centers: &Array2<F>,
) -> F {
let work = x.nrows() * x.ncols();
if work < PARALLEL_WORK_THRESHOLD {
assign_serial(labels, x, centers)
} else {
assign_parallel(labels, x, centers)
}
}
/// Find the nearest center for a single row.
#[inline]
fn nearest_center<F: Float>(row_slice: &[F], centers: &Array2<F>) -> (usize, F) {
let k = centers.nrows();
let mut best_label = 0;
let mut best_dist = F::max_value();
for c in 0..k {
let center_row = centers.row(c);
let center_slice = center_row.as_slice().unwrap_or(&[]);
let d = squared_euclidean(row_slice, center_slice);
if d < best_dist {
best_dist = d;
best_label = c;
}
}
(best_label, best_dist)
}
/// Serial assignment — no thread-pool overhead.
fn assign_serial<F: Float>(labels: &mut Array1<usize>, x: &Array2<F>, centers: &Array2<F>) -> F {
let n_samples = x.nrows();
let mut inertia = F::zero();
for i in 0..n_samples {
let row = x.row(i);
let row_slice = row.as_slice().unwrap_or(&[]);
let (label, dist) = nearest_center(row_slice, centers);
labels[i] = label;
inertia = inertia + dist;
}
inertia
}
/// Parallel assignment using Rayon `par_chunks_mut` for cache-friendly access.
fn assign_parallel<F: Float + Send + Sync>(
labels: &mut Array1<usize>,
x: &Array2<F>,
centers: &Array2<F>,
) -> F {
let n_samples = x.nrows();
let labels_slice = labels.as_slice_mut().unwrap();
let chunk_size = (n_samples / rayon::current_num_threads()).max(64);
labels_slice
.par_chunks_mut(chunk_size)
.enumerate()
.map(|(chunk_idx, chunk)| {
let start = chunk_idx * chunk_size;
let mut local_inertia = F::zero();
for (local_i, label) in chunk.iter_mut().enumerate() {
let i = start + local_i;
let row = x.row(i);
let row_slice = row.as_slice().unwrap_or(&[]);
let (best_label, dist) = nearest_center(row_slice, centers);
*label = best_label;
local_inertia = local_inertia + dist;
}
local_inertia
})
.reduce(F::zero, |a, b| a + b)
}
/// Recompute centroids as the mean of assigned samples, writing into
/// pre-allocated buffers. A cluster with no assigned samples keeps its
/// previous center.
///
/// Returns the largest Euclidean distance moved by any centroid.
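///
/// The update can be sketched without `ndarray`. This is a minimal std-only
/// illustration, not the function below: `recompute` is a hypothetical helper
/// that allocates its output instead of reusing buffers, and it assumes at
/// least one center and labels that are valid indices into `old`.
///
/// ```rust
/// /// Returns the new centers and the largest Euclidean shift of any center.
/// fn recompute(
///     points: &[Vec<f64>],
///     labels: &[usize],
///     old: &[Vec<f64>],
/// ) -> (Vec<Vec<f64>>, f64) {
///     let k = old.len();
///     let d = old[0].len();
///     let mut centers = vec![vec![0.0; d]; k];
///     let mut counts = vec![0usize; k];
///     // Accumulate per-cluster coordinate sums and member counts.
///     for (p, &label) in points.iter().zip(labels) {
///         counts[label] += 1;
///         for j in 0..d {
///             centers[label][j] += p[j];
///         }
///     }
///     let mut max_sq_shift = 0.0_f64;
///     for c in 0..k {
///         if counts[c] > 0 {
///             for j in 0..d {
///                 centers[c][j] /= counts[c] as f64;
///             }
///         } else {
///             // Empty cluster: keep the previous center.
///             centers[c] = old[c].clone();
///         }
///         let sq: f64 = centers[c]
///             .iter()
///             .zip(&old[c])
///             .map(|(a, b)| (a - b) * (a - b))
///             .sum();
///         max_sq_shift = max_sq_shift.max(sq);
///     }
///     (centers, max_sq_shift.sqrt())
/// }
/// ```
///
/// With points `[0, 0]` and `[2, 0]` both labelled `0` and old centers
/// `[0, 0]` and `[9, 9]`, center 0 moves to `[1, 0]`, the empty cluster 1 stays
/// at `[9, 9]`, and the returned shift is `1.0`.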
fn recompute_centroids_into<F: Float>(
new_centers: &mut Array2<F>,
counts: &mut [F],
x: &Array2<F>,
labels: &Array1<usize>,
n_features: usize,
old_centers: &Array2<F>,
) -> F {
let k = new_centers.nrows();
new_centers.fill(F::zero());
counts.iter_mut().for_each(|c| *c = F::zero());
for (i, &label) in labels.iter().enumerate() {
counts[label] = counts[label] + F::one();
for j in 0..n_features {
new_centers[[label, j]] = new_centers[[label, j]] + x[[i, j]];
}
}
// Divide by count; if a cluster is empty, keep the old center.
for c in 0..k {
if counts[c] > F::zero() {
for j in 0..n_features {
new_centers[[c, j]] = new_centers[[c, j]] / counts[c];
}
} else {
new_centers.row_mut(c).assign(&old_centers.row(c));
}
}
// Compute maximum centroid movement.
let mut max_shift = F::zero();
for c in 0..k {
let shift = squared_euclidean(
new_centers.row(c).as_slice().unwrap_or(&[]),
old_centers.row(c).as_slice().unwrap_or(&[]),
);
if shift > max_shift {
max_shift = shift;
}
}
max_shift.sqrt()
}
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for KMeans<F> {
type Fitted = FittedKMeans<F>;
type Error = FerroError;
/// Fit the k-Means model to the data.
///
/// Runs Lloyd's algorithm `n_init` times with k-Means++ initialization,
/// keeping the result with the lowest inertia.
///
/// # Errors
///
/// Returns [`FerroError::InvalidParameter`] if `n_clusters` or `n_init` is
/// zero, or if `x` contains NaN or infinity.
/// Returns [`FerroError::InsufficientSamples`] if the number of samples
/// is less than `n_clusters`.
fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedKMeans<F>, FerroError> {
let (n_samples, n_features) = x.dim();
// Validate parameters.
if self.n_clusters == 0 {
return Err(FerroError::InvalidParameter {
name: "n_clusters".into(),
reason: "must be at least 1".into(),
});
}
if n_samples == 0 {
return Err(FerroError::InsufficientSamples {
required: self.n_clusters,
actual: 0,
context: "KMeans requires at least n_clusters samples".into(),
});
}
if n_samples < self.n_clusters {
return Err(FerroError::InsufficientSamples {
required: self.n_clusters,
actual: n_samples,
context: "KMeans requires at least n_clusters samples".into(),
});
}
if self.n_init == 0 {
return Err(FerroError::InvalidParameter {
name: "n_init".into(),
reason: "must be at least 1".into(),
});
}
// Reject non-finite X up front (NaN AND Inf), mirroring sklearn's
// `_validate_data(..., force_all_finite=True)` (the default) reached from
// `KMeans.fit` (`sklearn/cluster/_kmeans.py:1464`), which raises
// `ValueError("Input X contains NaN.")` (`sklearn/utils/validation.py:147-154`).
reject_non_finite(x)?;
let base_seed = self.random_state.unwrap_or(0);
let mut best_result: Option<FittedKMeans<F>> = None;
// Pre-allocate reusable buffers for the Lloyd loop.
let mut labels = Array1::zeros(n_samples);
let mut new_centers = Array2::zeros((self.n_clusters, n_features));
let mut counts = vec![F::zero(); self.n_clusters];
for run in 0..self.n_init {
let mut rng = StdRng::seed_from_u64(base_seed.wrapping_add(run as u64));
// k-Means++ initialization.
let mut centers = kmeans_plus_plus(x, self.n_clusters, &mut rng);
let mut n_iter = 0;
for iter in 0..self.max_iter {
// Assign step (serial or parallel depending on size). The
// inertia is recomputed in the final E-step below, so the
// per-iteration value is not retained here.
let _ = assign_clusters_into(&mut labels, x, &centers);
// Recompute centroids using pre-allocated buffers.
let max_shift = recompute_centroids_into(
&mut new_centers,
&mut counts,
x,
&labels,
n_features,
&centers,
);
std::mem::swap(&mut centers, &mut new_centers);
n_iter = iter + 1;
// Check convergence.
if max_shift < self.tol {
break;
}
}
// Final E-step: re-assign labels to the converged centers so that
// labels_/inertia_ are consistent with cluster_centers_ (sklearn
// _kmeans.py:605-623). Without this, `cluster_centers_` is one M-step
// ahead of `labels`/`inertia` and `predict(X)` can disagree with
// `labels_`. n_iter_ is intentionally left as the loop count (REQ-9).
let inertia = assign_clusters_into(&mut labels, x, &centers);
let candidate = FittedKMeans {
cluster_centers_: centers,
labels_: labels.clone(),
inertia_: inertia,
n_iter_: n_iter,
};
match &best_result {
None => best_result = Some(candidate),
Some(best) => {
if candidate.inertia_ < best.inertia_ {
best_result = Some(candidate);
}
}
}
}
// Invariant: n_init >= 1 is validated above, so best_result is always Some.
best_result.ok_or_else(|| FerroError::InvalidParameter {
name: "n_init".into(),
reason: "internal error: no runs completed".into(),
})
}
}
impl<F: Float + Send + Sync + 'static> KMeans<F> {
/// Fit on `x` and return the cluster labels for those same samples in
/// one call. Equivalent to sklearn `ClusterMixin.fit_predict`.
///
/// # Errors
///
/// Forwards any error from [`Fit::fit`] / [`Predict::predict`].
pub fn fit_predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
let fitted = self.fit(x, &())?;
fitted.predict(x)
}
}
impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedKMeans<F> {
type Output = Array1<usize>;
type Error = FerroError;
/// Assign each sample to the nearest cluster centroid.
///
/// # Errors
///
/// Returns [`FerroError::ShapeMismatch`] if the number of features
/// does not match the fitted model.
/// Returns [`FerroError::InvalidParameter`] if `x` contains NaN or infinity.
fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
let n_features = x.ncols();
let expected_features = self.cluster_centers_.ncols();
if n_features != expected_features {
return Err(FerroError::ShapeMismatch {
expected: vec![expected_features],
actual: vec![n_features],
context: "number of features must match fitted KMeans model".into(),
});
}
// Reject non-finite query X (NaN AND Inf), mirroring sklearn's
// `KMeans.predict` → `_check_test_data` → `_validate_data`
// (`sklearn/cluster/_kmeans.py:1091`/`:950`), which raises `ValueError`.
reject_non_finite(x)?;
let (labels, _inertia) = assign_clusters(x, &self.cluster_centers_);
Ok(labels)
}
}
impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedKMeans<F> {
type Output = Array2<F>;
type Error = FerroError;
/// Compute the distance from each sample to each cluster centroid.
///
/// Returns a matrix of shape `(n_samples, n_clusters)` where element
/// `[i, j]` is the Euclidean distance from sample `i` to centroid `j`.
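///
/// The relationship between this distance matrix and `predict` (the nearest
/// center is the argmin of a row) can be sketched for a single sample. This is
/// a minimal std-only illustration; `distances_to_centers` and `argmin` are
/// hypothetical helpers, not part of this crate, and `argmin` assumes a
/// non-empty row.
///
/// ```rust
/// /// Euclidean distance from one sample to every center.
/// fn distances_to_centers(sample: &[f64], centers: &[Vec<f64>]) -> Vec<f64> {
///     centers
///         .iter()
///         .map(|c| {
///             sample
///                 .iter()
///                 .zip(c)
///                 .map(|(a, b)| (a - b) * (a - b))
///                 .sum::<f64>()
///                 .sqrt()
///         })
///         .collect()
/// }
///
/// /// Index of the smallest entry; ties resolve to the lowest index.
/// fn argmin(row: &[f64]) -> usize {
///     let mut best = 0;
///     for (i, &d) in row.iter().enumerate() {
///         if d < row[best] {
///             best = i;
///         }
///     }
///     best
/// }
/// ```
///
/// For the sample `[3, 4]` and centers `[0, 0]` and `[3, 0]`, the distances are
/// `[5.0, 4.0]`, so the nearest center is index `1`.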
///
/// # Errors
///
/// Returns [`FerroError::ShapeMismatch`] if the number of features
/// does not match the fitted model.
fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
let n_features = x.ncols();
let expected_features = self.cluster_centers_.ncols();
if n_features != expected_features {
return Err(FerroError::ShapeMismatch {
expected: vec![expected_features],
actual: vec![n_features],
context: "number of features must match fitted KMeans model".into(),
});
}
let n_samples = x.nrows();
let k = self.cluster_centers_.nrows();
let centers = &self.cluster_centers_;
let mut distances = vec![F::zero(); n_samples * k];
let work = n_samples * n_features;
if work < PARALLEL_WORK_THRESHOLD {
for i in 0..n_samples {
let row = x.row(i);
let row_slice = row.as_slice().unwrap_or(&[]);
for c in 0..k {
let center = centers.row(c);
let cs = center.as_slice().unwrap_or(&[]);
distances[i * k + c] = squared_euclidean(row_slice, cs).sqrt();
}
}
} else {
distances
.par_chunks_mut(k)
.enumerate()
.for_each(|(i, chunk)| {
let row = x.row(i);
let row_slice = row.as_slice().unwrap_or(&[]);
for (c, slot) in chunk.iter_mut().enumerate() {
let center = centers.row(c);
let cs = center.as_slice().unwrap_or(&[]);
*slot = squared_euclidean(row_slice, cs).sqrt();
}
});
}
Array2::from_shape_vec((n_samples, k), distances).map_err(|_| {
FerroError::NumericalInstability {
message: "failed to construct distance matrix".into(),
}
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
/// Create well-separated 2D blobs for testing.
fn make_blobs() -> Array2<f64> {
Array2::from_shape_vec(
(9, 2),
vec![
// Cluster 0 near (0, 0)
0.0, 0.0, 0.1, 0.1, -0.1, 0.1,
// Cluster 1 near (10, 10)
10.0, 10.0, 10.1, 10.1, 9.9, 10.1,
// Cluster 2 near (0, 10)
0.0, 10.0, 0.1, 10.1, -0.1, 9.9,
],
)
.unwrap()
}
#[test]
fn test_well_separated_blobs() {
let x = make_blobs();
let model = KMeans::<f64>::new(3).with_random_state(42).with_n_init(5);
let fitted = model.fit(&x, &()).unwrap();
let labels = fitted.labels();
// Points in the same blob should have the same label.
assert_eq!(labels[0], labels[1]);
assert_eq!(labels[0], labels[2]);
assert_eq!(labels[3], labels[4]);
assert_eq!(labels[3], labels[5]);
assert_eq!(labels[6], labels[7]);
assert_eq!(labels[6], labels[8]);
// Different blobs should have different labels.
assert_ne!(labels[0], labels[3]);
assert_ne!(labels[0], labels[6]);
assert_ne!(labels[3], labels[6]);
}
#[test]
fn test_convergence() {
let x = make_blobs();
let model = KMeans::<f64>::new(3)
.with_random_state(42)
.with_max_iter(1000)
.with_n_init(1);
let fitted = model.fit(&x, &()).unwrap();
// Well-separated blobs should converge well before max_iter.
assert!(fitted.n_iter() < 100);
}
#[test]
fn test_n_init_picks_best() {
let x = make_blobs();
// With n_init=1, we might get a suboptimal result.
let model_1 = KMeans::<f64>::new(3).with_random_state(42).with_n_init(1);
let fitted_1 = model_1.fit(&x, &()).unwrap();
// With n_init=10, we should get at least as good (usually better).
let model_10 = KMeans::<f64>::new(3).with_random_state(42).with_n_init(10);
let fitted_10 = model_10.fit(&x, &()).unwrap();
// The n_init=10 run should have inertia <= n_init=1 run.
assert!(fitted_10.inertia() <= fitted_1.inertia() + 1e-10);
}
#[test]
fn test_kmeans_pp_initialization_deterministic() {
let x = make_blobs();
let model = KMeans::<f64>::new(3).with_random_state(123).with_n_init(1);
let fitted1 = model.fit(&x, &()).unwrap();
let fitted2 = model.fit(&x, &()).unwrap();
// Same seed should produce same result.
assert_eq!(fitted1.labels(), fitted2.labels());
assert_relative_eq!(fitted1.inertia(), fitted2.inertia(), epsilon = 1e-12);
}
#[test]
fn test_reproducibility_with_seed() {
let x = make_blobs();
let model = KMeans::<f64>::new(3).with_random_state(99);
let fitted1 = model.fit(&x, &()).unwrap();
let fitted2 = model.fit(&x, &()).unwrap();
assert_eq!(fitted1.labels(), fitted2.labels());
}
#[test]
fn test_predict_on_new_data() {
let x = make_blobs();
let model = KMeans::<f64>::new(3).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
// Predict on new points near each cluster.
let new_x =
Array2::from_shape_vec((3, 2), vec![0.05, 0.05, 10.05, 10.05, 0.05, 10.05]).unwrap();
let new_labels = fitted.predict(&new_x).unwrap();
// New points near cluster 0 center.
let label_near_origin = new_labels[0];
// Should match the training label of the origin cluster.
assert_eq!(label_near_origin, fitted.labels()[0]);
let label_near_10_10 = new_labels[1];
assert_eq!(label_near_10_10, fitted.labels()[3]);
let label_near_0_10 = new_labels[2];
assert_eq!(label_near_0_10, fitted.labels()[6]);
}
#[test]
fn test_transform_distances() {
let x = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 10.0, 0.0, 0.0, 10.0, 10.0, 10.0])
.unwrap();
let model = KMeans::<f64>::new(2).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
let dists = fitted.transform(&x).unwrap();
// Shape should be (n_samples, n_clusters).
assert_eq!(dists.dim(), (4, 2));
// Distance to own centroid should be small, distance to other should be large.
for i in 0..4 {
let own_cluster = fitted.labels()[i];
let other_cluster = 1 - own_cluster;
assert!(dists[[i, own_cluster]] < dists[[i, other_cluster]]);
}
}
#[test]
fn test_transform_shape() {
let x = make_blobs();
let model = KMeans::<f64>::new(3).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
let dists = fitted.transform(&x).unwrap();
assert_eq!(dists.dim(), (9, 3));
}
#[test]
fn test_cluster_centers_shape() {
let x = make_blobs();
let model = KMeans::<f64>::new(3).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
assert_eq!(fitted.cluster_centers().dim(), (3, 2));
}
#[test]
fn test_inertia_non_negative() {
let x = make_blobs();
let model = KMeans::<f64>::new(3).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
assert!(fitted.inertia() >= 0.0);
}
#[test]
fn test_k_equals_n_samples() {
let x = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 5.0, 5.0, 10.0, 10.0]).unwrap();
let model = KMeans::<f64>::new(3).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
// Each point should be its own cluster, inertia should be ~0.
assert_relative_eq!(fitted.inertia(), 0.0, epsilon = 1e-10);
// All labels should be distinct.
let labels = fitted.labels();
assert_ne!(labels[0], labels[1]);
assert_ne!(labels[0], labels[2]);
assert_ne!(labels[1], labels[2]);
}
#[test]
fn test_single_cluster() {
let x =
Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]).unwrap();
let model = KMeans::<f64>::new(1).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
// All points should be in cluster 0.
for &label in fitted.labels() {
assert_eq!(label, 0);
}
// Center should be the mean.
let center = fitted.cluster_centers().row(0);
assert_relative_eq!(center[0], 2.5, epsilon = 1e-10);
assert_relative_eq!(center[1], 2.5, epsilon = 1e-10);
}
#[test]
fn test_single_sample() {
let x = Array2::from_shape_vec((1, 2), vec![5.0, 5.0]).unwrap();
let model = KMeans::<f64>::new(1).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
assert_eq!(fitted.labels().len(), 1);
assert_eq!(fitted.labels()[0], 0);
assert_relative_eq!(fitted.inertia(), 0.0, epsilon = 1e-10);
}
#[test]
fn test_k_greater_than_n_samples() {
let x = Array2::from_shape_vec((2, 2), vec![1.0, 1.0, 2.0, 2.0]).unwrap();
let model = KMeans::<f64>::new(5);
let result = model.fit(&x, &());
assert!(result.is_err());
}
#[test]
fn test_zero_clusters() {
let x = Array2::from_shape_vec((3, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]).unwrap();
let model = KMeans::<f64>::new(0);
let result = model.fit(&x, &());
assert!(result.is_err());
}
#[test]
fn test_empty_data() {
let x = Array2::<f64>::zeros((0, 2));
let model = KMeans::<f64>::new(3);
let result = model.fit(&x, &());
assert!(result.is_err());
}
#[test]
fn test_predict_shape_mismatch() {
let x =
Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]).unwrap();
let model = KMeans::<f64>::new(2).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
// Wrong number of features.
let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
let result = fitted.predict(&x_bad);
assert!(result.is_err());
}
#[test]
fn test_transform_shape_mismatch() {
let x =
Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]).unwrap();
let model = KMeans::<f64>::new(2).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
let result = fitted.transform(&x_bad);
assert!(result.is_err());
}
#[test]
fn test_f32_support() {
let x = Array2::from_shape_vec(
(6, 2),
vec![
0.0f32, 0.0, 0.1, 0.1, -0.1, 0.1, 10.0, 10.0, 10.1, 10.1, 9.9, 10.1,
],
)
.unwrap();
let model = KMeans::<f32>::new(2).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
assert_eq!(fitted.labels().len(), 6);
let preds = fitted.predict(&x).unwrap();
assert_eq!(preds.len(), 6);
}
#[test]
fn test_two_clusters_on_line() {
// Points on a line: cluster at x=0 and x=100.
let x = Array2::from_shape_vec((6, 1), vec![0.0, 0.1, -0.1, 100.0, 100.1, 99.9]).unwrap();
let model = KMeans::<f64>::new(2).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
let labels = fitted.labels();
// First three should be same cluster.
assert_eq!(labels[0], labels[1]);
assert_eq!(labels[0], labels[2]);
// Last three should be same cluster.
assert_eq!(labels[3], labels[4]);
assert_eq!(labels[3], labels[5]);
// Different clusters.
assert_ne!(labels[0], labels[3]);
}
#[test]
fn test_identical_points() {
let x =
Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]).unwrap();
let model = KMeans::<f64>::new(1).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
assert_relative_eq!(fitted.inertia(), 0.0, epsilon = 1e-10);
}
}