//! Mini-Batch K-Means clustering.
//!
//! [`MiniBatchKMeans`] is an online variant of K-Means that processes data in
//! small random mini-batches at each iteration, making it much faster than the
//! standard Lloyd's algorithm on large datasets while achieving comparable
//! clustering quality.
//!
//! # Algorithm
//!
//! 1. **Initialization**: Seeds centroids using k-Means++ (or uniformly at
//! random, depending on `init`).
//! 2. **Mini-batch loop**: For each iteration, sample a mini-batch of
//! `batch_size` points uniformly at random (without replacement, or with
//! replacement when batch_size > n_samples).
//! - Assign each point in the batch to its nearest centroid.
//! - Update each centroid using a per-centroid learning rate:
//! `η_c = 1 / count_c`, where `count_c` is the number of samples
//! assigned to centroid `c` so far (across all batches).
//! - New centroid: `c ← c + η_c * (x - c)` (running mean update).
//! 3. **Convergence**: Stop when the maximum centroid shift falls below `tol`
//! or `max_iter` batches have been processed.
//! 4. **Multi-start**: Repeat `n_init` times and keep the run with the lowest
//! final inertia.
//!
//! # Examples
//!
//! ```
//! use ferrolearn_cluster::MiniBatchKMeans;
//! use ferrolearn_core::{Fit, Predict, Transform};
//! use ndarray::Array2;
//!
//! let x = Array2::from_shape_vec((6, 2), vec![
//! 1.0, 1.0, 1.1, 1.0, 1.0, 1.1,
//! 5.0, 5.0, 5.1, 5.0, 5.0, 5.1,
//! ]).unwrap();
//!
//! let model = MiniBatchKMeans::<f64>::new(2);
//! let fitted = model.fit(&x, &()).unwrap();
//! let labels = fitted.predict(&x).unwrap();
//! assert_eq!(labels.len(), 6);
//! ```
//!
//! # `## REQ status`
//!
//! Binary (R-DEFER-2), translating `sklearn/cluster/_kmeans.py`
//! (`class MiniBatchKMeans(_BaseKMeans)`; `_mini_batch_step`,
//! `_mini_batch_convergence`, `partial_fit`; `_check_params_vs_input` base `:875`).
//! Design doc: `.design/cluster/mini_batch_kmeans.md`. Cites use ferrolearn symbol
//! anchors / sklearn `file:line` (commit 156ef14); expected values from the live
//! sklearn 1.5.2 oracle (R-CHAR-3). MiniBatchKMeans has a PyO3 consumer
//! (`_RsMiniBatchKMeans`, `ferrolearn-python/src/extras.rs`). Verify-and-document
//! unit: the `labels_` PARTITION (up to a label permutation, well-separated regime),
//! `predict`/`transform` contracts, the (thin) PyO3 marshalling, and the `n_init`
//! default match sklearn and SHIP. Exact `cluster_centers_`/`inertia_` VALUES +
//! `labels_` integers DIVERGE — ferrolearn runs `n_init` FULL per-center-LR runs with
//! max-shift convergence, vs sklearn's single best-of-init-trials + EWA-inertia early
//! stopping + low-count reassignment (#1050/#1054) on the numpy RNG (#1049).
//!
//! | REQ | Status | Evidence |
//! |---|---|---|
//! | REQ-1 (`labels_` PARTITION up-to-permutation, separable data) | SHIPPED | impl `Fit::fit` (k-means++/random init → mini-batch per-center-LR loop → best-of-`n_init` → final `fn assign_clusters_mb`) recovers sklearn's grouping on well-separated data. Consumers: PyO3 `RsMiniBatchKMeans::fit` (`extras.rs`) + crate re-export `pub use mini_batch_kmeans::{...}` (`lib.rs`). Guards: `green_req1_partition_two_blobs`, `green_req1_partition_three_blobs` in `tests/divergence_mini_batch_kmeans.rs` (canonicalized, live-oracle). Underclaim: PARTITION up-to-permutation only — `labels_` integers + `cluster_centers_`/`inertia_` VALUES (REQ-9) diverge. |
//! | REQ-2 (`predict` nearest-center contract + `predict==labels_`) | SHIPPED | impl `Predict::predict` (`fn assign_clusters_mb`) returns argmin-squared-euclidean center index; `fn fit` computes `labels_` via the same `assign_clusters_mb` against the FINAL centers, so `fit(X).predict(X) == labels_` holds (no KMeans-style swap bug). Consumers: PyO3 `RsMiniBatchKMeans::predict` + crate re-export. Guards: `green_req2_predict_equals_labels`. |
//! | REQ-3 (`transform` distance-to-centers contract) | SHIPPED | impl `Transform::transform` returns shape `(n_samples, n_clusters)`, column j = Euclidean distance to center j, `argmin(1) == predict`. Consumer: crate re-export (the thin PyO3 binding omits `transform` — REQ-11). Guards: `green_req3_transform_argmin_equals_predict`, `green_req3_transform_shape_nonneg`. Underclaim: CONTRACT only — distances track `cluster_centers_` (REQ-9). |
//! | REQ-4 (PyO3 binding marshalling) | SHIPPED | impl `#[pyclass(name="_RsMiniBatchKMeans")]` (`ferrolearn-python/src/extras.rs`) marshals `fit`/`predict` + `labels_`; registered in `ferrolearn-python/src/lib.rs`, wrapped `class MiniBatchKMeans(_ClusterWrapper)` in `python/ferrolearn/_extras.py`, exported in `__init__.py`. Verification: `maturin develop` + binding smoke test. Underclaim: the binding is THIN — it exposes only `fit`/`predict`/`labels_` (no `transform`/`cluster_centers_`/`inertia_`/`n_iter_`) and its ctor omits `n_init`/`batch_size`/`tol` (REQ-11); fitted VALUES inherit REQ-9. |
//! | REQ-5 (`n_init` constructor default = 1) | SHIPPED | impl `fn new` defaults `n_init: 1`, matching sklearn `MiniBatchKMeans(n_init="auto")` → 1 for the default `init="k-means++"` (`_kmeans.py:886-888`; live oracle `_n_init==1`). Guard: `pin_req5_n_init_default_is_one` (`MiniBatchKMeans::<f64>::new(3).n_init == 1`). Fixed #1047. (`batch_size=1024`/`max_iter=100`/`tol=0.0` defaults already matched.) |
//! | REQ-6 (ctor/fit surface max_no_improvement/reassignment_ratio/init_size/compute_labels/verbose/sample_weight + n_clusters=8 + error ABI) | NOT-STARTED | open prereq blocker #1048. sklearn `__init__` + `fit(sample_weight)` + `_check_params_vs_input` (`InvalidParameterError`); ferrolearn `MiniBatchKMeans<F>` has `n_clusters/batch_size/max_iter/tol/n_init/random_state/init` only + `FerroError` ABI + requires explicit `n_clusters`. |
//! | REQ-7 (`random_state` numpy-RNG parity) | NOT-STARTED | open prereq blocker #1049. sklearn `check_random_state` + numpy RNG; ferrolearn `StdRng`. Exact centers/labels/inertia cannot match (R-SUBSTRATE-5); blocks REQ-9. |
//! | REQ-8 (algorithm structure: init-trials-on-subsample + EWA convergence + early stop + `n_steps_`/`n_iter_`) | NOT-STARTED | open prereq blocker #1050. sklearn does ONE k-means++ init best-of-`n_init` INIT trials on an `init_size` subsample, then one mini-batch optimization with EWA-inertia convergence + `max_no_improvement` early stopping (`_mini_batch_step`/`_mini_batch_convergence`), exposing `n_steps_` AND `n_iter_`; ferrolearn runs `n_init` FULL runs with per-center-LR + max-shift<tol, exposes only `n_iter_`. |
//! | REQ-9 (`cluster_centers_`/`inertia_`/label-integers VALUE parity) | NOT-STARTED | open prereq blocker #1051. Exact values diverge via numpy-RNG (REQ-7), algorithm structure (REQ-8), and missing low-count reassignment (REQ-12). Gated on those. |
//! | REQ-10 (`partial_fit` online) | NOT-STARTED | open prereq blocker #1052. sklearn `partial_fit` processes one mini-batch per call; ferrolearn has none. |
//! | REQ-11 (PyO3 binding thin surface + dtype) | NOT-STARTED | open prereq blocker #1053. `RsMiniBatchKMeans` exposes only `fit`/`predict`/`labels_` + ctor `(n_clusters,max_iter,random_state)` — omits `transform`/`cluster_centers_`/`inertia_`/`n_iter_` and `n_init`/`batch_size`/`tol`; marshals `labels_` as `int64` not sklearn `int32` (R-DEFER-7 last layer). |
//! | REQ-12 (low-count cluster reassignment) | NOT-STARTED | open prereq blocker #1054. sklearn `_mini_batch_step` reassigns clusters whose counts fall below `reassignment_ratio*max_count` to random high-density points; ferrolearn `fn update_centers_mini_batch` has no reassignment. |
//! | REQ-13 (ferray substrate) | NOT-STARTED | open prereq blocker #1055. `mini_batch_kmeans.rs` imports `ndarray`/`num-traits`/`rand`/`rayon`, not `ferray-core`/`ferray::random` (R-SUBSTRATE-1/2; RNG entangled with REQ-7). |
//! | REQ-14 (reject non-finite input) | SHIPPED | `fn reject_non_finite` called at the top of `Fit::fit` (after the param/sample checks, before init/mini-batch) AND in `Predict::predict` (on the query X) rejects NaN AND infinity with `FerroError::InvalidParameter{name:"X"}`, mirroring sklearn's `_validate_data(force_all_finite=True)` default reached from `MiniBatchKMeans.fit` (`_kmeans.py:2073`) and `MiniBatchKMeans.predict`→`_check_test_data`, which raise `ValueError` (`validation.py:147-154`). Consumers: the existing `fit`/`predict` entries — PyO3 `RsMiniBatchKMeans::fit`/`::predict` (`extras.rs`) + crate re-export `pub use mini_batch_kmeans::{...}` (`lib.rs`). Pinned by `divergence_nonfinite_reject_spillover.rs` (`divergence_mini_batch_kmeans_fit_rejects_nan`/`_inf`) — live sklearn 1.5.2 raises, ferrolearn now `Err`. Finite input byte-identical (the module's oracle pins stay green). Closes #2286 for this estimator. |
use ferrolearn_core::error::FerroError;
use ferrolearn_core::traits::{Fit, Predict, Transform};
use ndarray::{Array1, Array2};
use num_traits::Float;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rayon::prelude::*;
// ---------------------------------------------------------------------------
// Init strategy
// ---------------------------------------------------------------------------
/// Centroid initialization strategy for [`MiniBatchKMeans`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MiniBatchKMeansInit {
/// k-Means++ initialization (default): pick centroids with probability
/// proportional to D(x)^2 to the nearest existing centroid.
KMeansPlusPlus,
/// Uniform random initialization: pick `n_clusters` samples uniformly.
Random,
}
// ---------------------------------------------------------------------------
// MiniBatchKMeans (unfitted)
// ---------------------------------------------------------------------------
/// Mini-Batch K-Means clustering configuration (unfitted).
///
/// Holds all hyperparameters. Call [`Fit::fit`] to run the algorithm and
/// produce a [`FittedMiniBatchKMeans`].
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct MiniBatchKMeans<F> {
/// Number of clusters to form.
pub n_clusters: usize,
/// Number of samples per mini-batch.
pub batch_size: usize,
/// Maximum number of mini-batch iterations per run.
pub max_iter: usize,
/// Convergence tolerance on the maximum centroid shift.
pub tol: F,
/// Number of independent runs. The best result (lowest inertia) is kept.
pub n_init: usize,
/// Optional random seed for reproducibility.
pub random_state: Option<u64>,
/// Centroid initialization strategy.
pub init: MiniBatchKMeansInit,
}
impl<F: Float> MiniBatchKMeans<F> {
/// Create a new `MiniBatchKMeans` with the given number of clusters.
///
/// Uses default values: `batch_size = 1024`, `max_iter = 100`,
/// `tol = 0.0` (no per-batch early stopping), `n_init = 1`,
/// `random_state = None`, `init = KMeansPlusPlus`.
///
/// These match scikit-learn's `MiniBatchKMeans` defaults: `batch_size`,
/// `max_iter`, and `tol` match directly, and sklearn's
/// `n_init="auto"` resolves to `1` for the default `init="k-means++"`
/// (`sklearn/cluster/_kmeans.py:886-888`).
#[must_use]
pub fn new(n_clusters: usize) -> Self {
Self {
n_clusters,
batch_size: 1024,
max_iter: 100,
tol: F::zero(),
n_init: 1,
random_state: None,
init: MiniBatchKMeansInit::KMeansPlusPlus,
}
}
/// Set the mini-batch size.
#[must_use]
pub fn with_batch_size(mut self, batch_size: usize) -> Self {
self.batch_size = batch_size;
self
}
/// Set the maximum number of mini-batch iterations.
#[must_use]
pub fn with_max_iter(mut self, max_iter: usize) -> Self {
self.max_iter = max_iter;
self
}
/// Set the convergence tolerance.
#[must_use]
pub fn with_tol(mut self, tol: F) -> Self {
self.tol = tol;
self
}
/// Set the number of independent runs.
#[must_use]
pub fn with_n_init(mut self, n_init: usize) -> Self {
self.n_init = n_init;
self
}
/// Set the random seed for reproducibility.
#[must_use]
pub fn with_random_state(mut self, seed: u64) -> Self {
self.random_state = Some(seed);
self
}
/// Set the initialization strategy.
#[must_use]
pub fn with_init(mut self, init: MiniBatchKMeansInit) -> Self {
self.init = init;
self
}
}
// ---------------------------------------------------------------------------
// FittedMiniBatchKMeans
// ---------------------------------------------------------------------------
/// A fitted Mini-Batch K-Means model.
///
/// Created by calling [`Fit::fit`] on a [`MiniBatchKMeans`].
/// Implements [`Predict`] (assign to nearest centroid) and [`Transform`]
/// (Euclidean distance to each centroid).
#[derive(Debug, Clone)]
pub struct FittedMiniBatchKMeans<F> {
/// Cluster center coordinates, shape `(n_clusters, n_features)`.
cluster_centers_: Array2<F>,
/// Cluster label for each training sample.
labels_: Array1<usize>,
/// Sum of squared distances of training samples to their closest centroid.
inertia_: F,
/// Number of mini-batch iterations in the best run.
n_iter_: usize,
}
impl<F: Float> FittedMiniBatchKMeans<F> {
/// Cluster center coordinates, shape `(n_clusters, n_features)`.
#[must_use]
pub fn cluster_centers(&self) -> &Array2<F> {
&self.cluster_centers_
}
/// Cluster label for each training sample.
#[must_use]
pub fn labels(&self) -> &Array1<usize> {
&self.labels_
}
/// Sum of squared distances of training samples to their nearest centroid.
#[must_use]
pub fn inertia(&self) -> F {
self.inertia_
}
/// Number of mini-batch iterations run in the best run.
#[must_use]
pub fn n_iter(&self) -> usize {
self.n_iter_
}
}
// ---------------------------------------------------------------------------
// Internal helpers
// ---------------------------------------------------------------------------
/// Reject `X` containing any non-finite value (NaN or infinity).
///
/// Mirrors sklearn's `MiniBatchKMeans.fit` →
/// `self._validate_data(X, accept_sparse="csr", dtype=[np.float64, np.float32], ...)`
/// (`sklearn/cluster/_kmeans.py:2073`) and `MiniBatchKMeans.predict` →
/// `_BaseKMeans._check_test_data` → `_validate_data`, both keeping the
/// `force_all_finite=True` default, which raises
/// `ValueError("Input X contains NaN.")` / `"... contains infinity ..."`
/// (`sklearn/utils/validation.py:147-154`). NaN AND infinity are both rejected.
/// Never panics (R-CODE-2).
fn reject_non_finite<F: Float>(x: &Array2<F>) -> Result<(), FerroError> {
if x.iter().any(|v| !v.is_finite()) {
return Err(FerroError::InvalidParameter {
name: "X".into(),
reason: "Input X contains NaN or infinity.".into(),
});
}
Ok(())
}
/// Compute the squared Euclidean distance between two slices.
#[inline]
fn squared_euclidean_mb<F: Float>(a: &[F], b: &[F]) -> F {
a.iter()
.zip(b.iter())
.fold(F::zero(), |acc, (&ai, &bi)| acc + (ai - bi) * (ai - bi))
}
/// Greedy k-Means++ initialization (Arthur & Vassilvitskii 2007 with the
/// scikit-learn-style multi-trial improvement).
///
/// At each step `2 + log(k)` candidate indices are sampled with probability
/// proportional to D(x)² and the candidate that minimises the resulting
/// total potential is kept. Plain (single-trial) KMeans++ frequently lands
/// in inferior local minima — measured at -0.16 ARI vs sklearn at
/// n=5000, k=8, p=20.
fn kmeans_plus_plus_mb<F: Float>(x: &Array2<F>, k: usize, rng: &mut StdRng) -> Array2<F> {
let n_samples = x.nrows();
let n_features = x.ncols();
let mut centers = Array2::zeros((k, n_features));
if n_samples == 0 {
return centers;
}
let first_idx = rng.random_range(0..n_samples);
centers.row_mut(0).assign(&x.row(first_idx));
if k == 1 {
return centers;
}
// Track squared distance from each sample to its nearest selected centre.
let mut min_dists = Array1::from_elem(n_samples, F::max_value());
{
let center0 = centers.row(0);
let center0_slice = center0.as_slice().unwrap_or(&[]);
for i in 0..n_samples {
min_dists[i] = squared_euclidean_mb(x.row(i).as_slice().unwrap_or(&[]), center0_slice);
}
}
// sklearn's `_kmeans_plusplus` uses `n_local_trials = 2 + int(log(k))`.
let n_trials = 2 + (k as f64).ln().floor() as usize;
let n_trials = n_trials.max(1);
for c in 1..k {
let total: F = min_dists.iter().fold(F::zero(), |acc, &d| acc + d);
if total <= F::zero() {
// All points coincide with chosen centres — pick any uniformly.
let idx = rng.random_range(0..n_samples);
centers.row_mut(c).assign(&x.row(idx));
continue;
}
// Sample `n_trials` candidate indices with prob ∝ D².
let mut best_candidate: usize = 0;
let mut best_potential: Option<F> = None;
let mut best_new_dists: Option<Array1<F>> = None;
for _ in 0..n_trials {
let threshold: F = F::from(rng.random::<f64>()).unwrap_or_else(F::zero) * total;
let mut cumsum = F::zero();
let mut candidate = n_samples - 1;
for i in 0..n_samples {
cumsum = cumsum + min_dists[i];
if cumsum >= threshold {
candidate = i;
break;
}
}
// Compute the potential (sum of min squared distances) if we
// committed to this candidate as the new centre.
let cand_slice = x.row(candidate).as_slice().unwrap_or(&[]).to_vec();
let mut new_dists = min_dists.clone();
let mut potential = F::zero();
for i in 0..n_samples {
let d = squared_euclidean_mb(x.row(i).as_slice().unwrap_or(&[]), &cand_slice);
if d < new_dists[i] {
new_dists[i] = d;
}
potential = potential + new_dists[i];
}
if best_potential.is_none_or(|bp| potential < bp) {
best_potential = Some(potential);
best_candidate = candidate;
best_new_dists = Some(new_dists);
}
}
centers.row_mut(c).assign(&x.row(best_candidate));
if let Some(d) = best_new_dists {
min_dists = d;
}
}
centers
}
/// Uniform random initialization.
fn random_init_mb<F: Float>(x: &Array2<F>, k: usize, rng: &mut StdRng) -> Array2<F> {
let n_samples = x.nrows();
let n_features = x.ncols();
let mut centers = Array2::zeros((k, n_features));
// Fisher-Yates shuffle for k picks.
let mut indices: Vec<usize> = (0..n_samples).collect();
for i in 0..k {
let j = rng.random_range(i..n_samples);
indices.swap(i, j);
centers.row_mut(i).assign(&x.row(indices[i]));
}
centers
}
/// Assign each sample to its nearest centroid (parallelized).
///
/// Returns `(labels, inertia)`.
fn assign_clusters_mb<F: Float + Send + Sync>(
x: &Array2<F>,
centers: &Array2<F>,
) -> (Array1<usize>, F) {
let n_samples = x.nrows();
let k = centers.nrows();
let results: Vec<(usize, F)> = (0..n_samples)
.into_par_iter()
.map(|i| {
let row = x.row(i);
let row_slice = row.as_slice().unwrap_or(&[]);
let mut best_label = 0;
let mut best_dist = F::max_value();
for c in 0..k {
let d = squared_euclidean_mb(row_slice, centers.row(c).as_slice().unwrap_or(&[]));
if d < best_dist {
best_dist = d;
best_label = c;
}
}
(best_label, best_dist)
})
.collect();
let mut labels = Array1::zeros(n_samples);
let mut inertia = F::zero();
for (i, (label, dist)) in results.into_iter().enumerate() {
labels[i] = label;
inertia = inertia + dist;
}
(labels, inertia)
}
/// Assign each sample in a batch to its nearest centroid.
///
/// Returns the labels for batch indices.
fn assign_batch<F: Float>(
x: &Array2<F>,
batch_indices: &[usize],
centers: &Array2<F>,
) -> Vec<usize> {
let k = centers.nrows();
batch_indices
.iter()
.map(|&i| {
let row = x.row(i);
let row_slice = row.as_slice().unwrap_or(&[]);
let mut best_label = 0;
let mut best_dist = F::max_value();
for c in 0..k {
let d = squared_euclidean_mb(row_slice, centers.row(c).as_slice().unwrap_or(&[]));
if d < best_dist {
best_dist = d;
best_label = c;
}
}
best_label
})
.collect()
}
/// One mini-batch update step.
///
/// Updates centroids in-place using the online learning-rate rule and
/// returns the maximum centroid shift.
fn update_centers_mini_batch<F: Float>(
x: &Array2<F>,
batch_indices: &[usize],
batch_labels: &[usize],
centers: &mut Array2<F>,
center_counts: &mut [usize],
) -> F {
let n_features = centers.ncols();
let k = centers.nrows();
// Clone current centers to measure shift afterward.
let old_centers = centers.clone();
// Apply per-sample update: c += (1/count_c) * (x - c)
for (&idx, &label) in batch_indices.iter().zip(batch_labels.iter()) {
center_counts[label] += 1;
let lr = F::one() / F::from(center_counts[label]).unwrap_or_else(F::one);
let x_row = x.row(idx);
for j in 0..n_features {
centers[[label, j]] = centers[[label, j]] + lr * (x_row[j] - centers[[label, j]]);
}
}
// Compute max centroid shift.
let mut max_shift = F::zero();
for c in 0..k {
let shift = squared_euclidean_mb(
centers.row(c).as_slice().unwrap_or(&[]),
old_centers.row(c).as_slice().unwrap_or(&[]),
)
.sqrt();
if shift > max_shift {
max_shift = shift;
}
}
max_shift
}
/// Sample `batch_size` indices from `0..n_samples` (with replacement when
/// batch_size > n_samples).
fn sample_batch_indices(n_samples: usize, batch_size: usize, rng: &mut StdRng) -> Vec<usize> {
if batch_size >= n_samples {
// Use all samples in a random order.
let mut indices: Vec<usize> = (0..n_samples).collect();
// Partial Fisher-Yates to shuffle.
for i in 0..n_samples {
let j = rng.random_range(i..n_samples);
indices.swap(i, j);
}
return indices;
}
// Sample without replacement using partial Fisher-Yates.
let mut pool: Vec<usize> = (0..n_samples).collect();
let mut result = Vec::with_capacity(batch_size);
for i in 0..batch_size {
let j = rng.random_range(i..n_samples);
pool.swap(i, j);
result.push(pool[i]);
}
result
}
// ---------------------------------------------------------------------------
// Trait implementations
// ---------------------------------------------------------------------------
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for MiniBatchKMeans<F> {
type Fitted = FittedMiniBatchKMeans<F>;
type Error = FerroError;
/// Fit the Mini-Batch K-Means model to the data.
///
/// Runs `n_init` independent runs with the configured initialization
/// strategy, keeping the result with the lowest final inertia.
///
/// # Errors
///
/// - [`FerroError::InvalidParameter`] if `n_clusters == 0`,
/// `batch_size == 0`, or `n_init == 0`.
/// - [`FerroError::InsufficientSamples`] if `n_samples < n_clusters`.
fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedMiniBatchKMeans<F>, FerroError> {
let (n_samples, n_features) = x.dim();
if self.n_clusters == 0 {
return Err(FerroError::InvalidParameter {
name: "n_clusters".into(),
reason: "must be at least 1".into(),
});
}
if self.batch_size == 0 {
return Err(FerroError::InvalidParameter {
name: "batch_size".into(),
reason: "must be at least 1".into(),
});
}
if self.n_init == 0 {
return Err(FerroError::InvalidParameter {
name: "n_init".into(),
reason: "must be at least 1".into(),
});
}
if n_samples == 0 {
return Err(FerroError::InsufficientSamples {
required: self.n_clusters,
actual: 0,
context: "MiniBatchKMeans requires at least n_clusters samples".into(),
});
}
if n_samples < self.n_clusters {
return Err(FerroError::InsufficientSamples {
required: self.n_clusters,
actual: n_samples,
context: "MiniBatchKMeans requires at least n_clusters samples".into(),
});
}
// Reject non-finite X up front (NaN AND Inf), mirroring sklearn's
// `_validate_data(force_all_finite=True)` reached from
// `MiniBatchKMeans.fit` (`_kmeans.py:2073`), which raises `ValueError`
// (R-DEV-1, R-CODE-2).
reject_non_finite(x)?;
let base_seed = self.random_state.unwrap_or(0);
let mut best_result: Option<FittedMiniBatchKMeans<F>> = None;
for run in 0..self.n_init {
let mut rng = StdRng::seed_from_u64(base_seed.wrapping_add(run as u64 * 1_000_003));
// Initialize centers.
let mut centers = match self.init {
MiniBatchKMeansInit::KMeansPlusPlus => {
kmeans_plus_plus_mb(x, self.n_clusters, &mut rng)
}
MiniBatchKMeansInit::Random => random_init_mb(x, self.n_clusters, &mut rng),
};
// Per-centroid sample counts (for learning rate).
let mut center_counts = vec![1usize; self.n_clusters];
let mut n_iter = 0usize;
for _iter in 0..self.max_iter {
let batch_indices = sample_batch_indices(n_samples, self.batch_size, &mut rng);
let batch_labels = assign_batch(x, &batch_indices, ¢ers);
let shift = update_centers_mini_batch(
x,
&batch_indices,
&batch_labels,
&mut centers,
&mut center_counts,
);
n_iter += 1;
if shift < self.tol {
break;
}
}
// Compute final labels and inertia on the full dataset.
let (labels, inertia) = assign_clusters_mb(x, ¢ers);
let _ = n_features; // used indirectly via x
let candidate = FittedMiniBatchKMeans {
cluster_centers_: centers,
labels_: labels,
inertia_: inertia,
n_iter_: n_iter,
};
match &best_result {
None => best_result = Some(candidate),
Some(best) => {
if candidate.inertia_ < best.inertia_ {
best_result = Some(candidate);
}
}
}
}
best_result.ok_or_else(|| FerroError::InvalidParameter {
name: "n_init".into(),
reason: "internal error: no runs completed".into(),
})
}
}
impl<F: Float + Send + Sync + 'static> MiniBatchKMeans<F> {
/// Fit on `x` and return labels for those samples in one call.
/// Equivalent to sklearn `ClusterMixin.fit_predict`.
///
/// # Errors
///
/// Forwards any error from [`Fit::fit`] / [`Predict::predict`].
pub fn fit_predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
let fitted = self.fit(x, &())?;
fitted.predict(x)
}
}
impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedMiniBatchKMeans<F> {
type Output = Array1<usize>;
type Error = FerroError;
/// Assign each sample to the nearest cluster centroid.
///
/// # Errors
///
/// Returns [`FerroError::ShapeMismatch`] if the number of features does not
/// match the fitted model.
fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
let n_features = x.ncols();
let expected_features = self.cluster_centers_.ncols();
if n_features != expected_features {
return Err(FerroError::ShapeMismatch {
expected: vec![expected_features],
actual: vec![n_features],
context: "FittedMiniBatchKMeans::predict".into(),
});
}
// Reject non-finite query X (NaN AND Inf), mirroring sklearn's
// `MiniBatchKMeans.predict` → `_check_test_data` → `_validate_data`
// (`force_all_finite=True`), which raises `ValueError`.
reject_non_finite(x)?;
let (labels, _) = assign_clusters_mb(x, &self.cluster_centers_);
Ok(labels)
}
}
impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedMiniBatchKMeans<F> {
type Output = Array2<F>;
type Error = FerroError;
/// Compute the Euclidean distance from each sample to each centroid.
///
/// Returns a matrix of shape `(n_samples, n_clusters)` where element
/// `[i, j]` is the Euclidean distance from sample `i` to centroid `j`.
///
/// # Errors
///
/// Returns [`FerroError::ShapeMismatch`] if the number of features does not
/// match the fitted model.
fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
let n_features = x.ncols();
let expected_features = self.cluster_centers_.ncols();
if n_features != expected_features {
return Err(FerroError::ShapeMismatch {
expected: vec![expected_features],
actual: vec![n_features],
context: "FittedMiniBatchKMeans::transform".into(),
});
}
let n_samples = x.nrows();
let k = self.cluster_centers_.nrows();
let distances: Vec<F> = (0..n_samples)
.into_par_iter()
.flat_map(|i| {
let row = x.row(i);
let row_slice = row.as_slice().unwrap_or(&[]);
(0..k)
.map(|c| {
squared_euclidean_mb(
row_slice,
self.cluster_centers_.row(c).as_slice().unwrap_or(&[]),
)
.sqrt()
})
.collect::<Vec<F>>()
})
.collect();
Array2::from_shape_vec((n_samples, k), distances).map_err(|_| {
FerroError::NumericalInstability {
message: "failed to construct distance matrix".into(),
}
})
}
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
/// Well-separated 2D blobs for testing.
fn make_blobs() -> Array2<f64> {
Array2::from_shape_vec(
(9, 2),
vec![
0.0, 0.0, 0.1, 0.1, -0.1, 0.1, // cluster near (0,0)
10.0, 10.0, 10.1, 10.1, 9.9, 10.1, // cluster near (10,10)
0.0, 10.0, 0.1, 10.1, -0.1, 9.9, // cluster near (0,10)
],
)
.unwrap()
}
#[test]
fn test_well_separated_blobs() {
let x = make_blobs();
let model = MiniBatchKMeans::<f64>::new(3)
.with_random_state(42)
.with_n_init(5)
.with_batch_size(9);
let fitted = model.fit(&x, &()).unwrap();
let labels = fitted.labels();
// Points in the same blob should share a label.
assert_eq!(labels[0], labels[1]);
assert_eq!(labels[0], labels[2]);
assert_eq!(labels[3], labels[4]);
assert_eq!(labels[3], labels[5]);
assert_eq!(labels[6], labels[7]);
assert_eq!(labels[6], labels[8]);
// Different blobs should have different labels.
assert_ne!(labels[0], labels[3]);
assert_ne!(labels[0], labels[6]);
assert_ne!(labels[3], labels[6]);
}
#[test]
fn test_cluster_centers_shape() {
let x = make_blobs();
let model = MiniBatchKMeans::<f64>::new(3).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
assert_eq!(fitted.cluster_centers().dim(), (3, 2));
}
#[test]
fn test_labels_length() {
let x = make_blobs();
let model = MiniBatchKMeans::<f64>::new(3).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
assert_eq!(fitted.labels().len(), 9);
}
#[test]
fn test_inertia_non_negative() {
let x = make_blobs();
let model = MiniBatchKMeans::<f64>::new(3).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
assert!(fitted.inertia() >= 0.0);
}
#[test]
fn test_n_iter_positive() {
let x = make_blobs();
let model = MiniBatchKMeans::<f64>::new(3).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
assert!(fitted.n_iter() >= 1);
}
#[test]
fn test_predict_on_new_data() {
let x = make_blobs();
let model = MiniBatchKMeans::<f64>::new(3)
.with_random_state(42)
.with_n_init(5)
.with_batch_size(9);
let fitted = model.fit(&x, &()).unwrap();
let new_x =
Array2::from_shape_vec((3, 2), vec![0.05, 0.05, 10.05, 10.05, 0.05, 10.05]).unwrap();
let new_labels = fitted.predict(&new_x).unwrap();
assert_eq!(new_labels.len(), 3);
// Each new point should be assigned to the correct cluster.
assert_eq!(new_labels[0], fitted.labels()[0]);
assert_eq!(new_labels[1], fitted.labels()[3]);
assert_eq!(new_labels[2], fitted.labels()[6]);
}
#[test]
fn test_transform_shape() {
let x = make_blobs();
let model = MiniBatchKMeans::<f64>::new(3).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
let dists = fitted.transform(&x).unwrap();
assert_eq!(dists.dim(), (9, 3));
}
#[test]
fn test_transform_distances_structure() {
let x = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 10.0, 0.0, 0.0, 10.0, 10.0, 10.0])
.unwrap();
let model = MiniBatchKMeans::<f64>::new(2)
.with_random_state(42)
.with_batch_size(4)
.with_n_init(5);
let fitted = model.fit(&x, &()).unwrap();
let dists = fitted.transform(&x).unwrap();
// Shape should be (n_samples, n_clusters).
assert_eq!(dists.dim(), (4, 2));
// Distance to own centroid should be smaller than distance to other centroid.
for i in 0..4 {
let own_cluster = fitted.labels()[i];
let other_cluster = 1 - own_cluster;
assert!(dists[[i, own_cluster]] <= dists[[i, other_cluster]] + 1e-10);
}
}
#[test]
fn test_reproducibility_with_seed() {
let x = make_blobs();
let model = MiniBatchKMeans::<f64>::new(3)
.with_random_state(99)
.with_batch_size(9);
let fitted1 = model.fit(&x, &()).unwrap();
let fitted2 = model.fit(&x, &()).unwrap();
assert_eq!(fitted1.labels(), fitted2.labels());
assert_relative_eq!(fitted1.inertia(), fitted2.inertia(), epsilon = 1e-12);
}
#[test]
fn test_random_init() {
let x = make_blobs();
let model = MiniBatchKMeans::<f64>::new(3)
.with_random_state(42)
.with_init(MiniBatchKMeansInit::Random)
.with_n_init(5)
.with_batch_size(9);
let fitted = model.fit(&x, &()).unwrap();
assert_eq!(fitted.cluster_centers().dim(), (3, 2));
assert!(fitted.inertia() >= 0.0);
}
#[test]
fn test_single_cluster() {
let x =
Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]).unwrap();
let model = MiniBatchKMeans::<f64>::new(1)
.with_random_state(42)
.with_batch_size(4);
let fitted = model.fit(&x, &()).unwrap();
for &label in fitted.labels() {
assert_eq!(label, 0);
}
}
#[test]
fn test_zero_clusters_error() {
let x = Array2::from_shape_vec((3, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]).unwrap();
assert!(MiniBatchKMeans::<f64>::new(0).fit(&x, &()).is_err());
}
#[test]
fn test_zero_batch_size_error() {
let x = Array2::from_shape_vec((3, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]).unwrap();
let model = MiniBatchKMeans::<f64>::new(2).with_batch_size(0);
assert!(model.fit(&x, &()).is_err());
}
#[test]
fn test_too_few_samples_error() {
let x = Array2::from_shape_vec((2, 2), vec![1.0, 1.0, 2.0, 2.0]).unwrap();
assert!(MiniBatchKMeans::<f64>::new(5).fit(&x, &()).is_err());
}
#[test]
fn test_empty_data_error() {
let x = Array2::<f64>::zeros((0, 2));
assert!(MiniBatchKMeans::<f64>::new(3).fit(&x, &()).is_err());
}
#[test]
fn test_predict_shape_mismatch_error() {
let x =
Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]).unwrap();
let model = MiniBatchKMeans::<f64>::new(2)
.with_random_state(42)
.with_batch_size(4);
let fitted = model.fit(&x, &()).unwrap();
let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
assert!(fitted.predict(&x_bad).is_err());
}
#[test]
fn test_transform_shape_mismatch_error() {
let x =
Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]).unwrap();
let model = MiniBatchKMeans::<f64>::new(2)
.with_random_state(42)
.with_batch_size(4);
let fitted = model.fit(&x, &()).unwrap();
let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
assert!(fitted.transform(&x_bad).is_err());
}
#[test]
fn test_f32_support() {
let x: Array2<f32> = Array2::from_shape_vec(
(6, 2),
vec![
0.0f32, 0.0, 0.1, 0.1, -0.1, 0.1, 10.0, 10.0, 10.1, 10.1, 9.9, 10.1,
],
)
.unwrap();
let model = MiniBatchKMeans::<f32>::new(2)
.with_random_state(42)
.with_batch_size(6);
let fitted = model.fit(&x, &()).unwrap();
assert_eq!(fitted.labels().len(), 6);
let preds = fitted.predict(&x).unwrap();
assert_eq!(preds.len(), 6);
}
#[test]
fn test_large_batch_size() {
// batch_size >= n_samples: should use all samples.
let x = make_blobs();
let model = MiniBatchKMeans::<f64>::new(3)
.with_random_state(7)
.with_batch_size(1000);
let fitted = model.fit(&x, &()).unwrap();
assert_eq!(fitted.cluster_centers().nrows(), 3);
}
#[test]
fn test_n_init_zero_error() {
let x = make_blobs();
let model = MiniBatchKMeans::<f64>::new(3).with_n_init(0);
assert!(model.fit(&x, &()).is_err());
}
#[test]
fn test_identical_points() {
let x =
Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]).unwrap();
let model = MiniBatchKMeans::<f64>::new(1)
.with_random_state(42)
.with_batch_size(4);
let fitted = model.fit(&x, &()).unwrap();
assert_relative_eq!(fitted.inertia(), 0.0, epsilon = 1e-10);
}
}