use ferrolearn_core::error::FerroError;
use ferrolearn_core::traits::{Fit, Predict, Transform};
use ndarray::{Array1, Array2};
use num_traits::Float;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rayon::prelude::*;
/// Strategy used to choose the initial cluster centers before the mini-batch
/// updates start.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MiniBatchKMeansInit {
/// Distance-weighted seeding: the first center is a uniformly random row and
/// each later center is drawn with probability proportional to its squared
/// distance from the nearest center chosen so far.
KMeansPlusPlus,
/// Seed the centers with `n_clusters` distinct rows drawn uniformly at random.
Random,
}
/// Unfitted mini-batch k-means clusterer: the hyper-parameters consumed by `fit`.
#[derive(Debug, Clone)]
pub struct MiniBatchKMeans<F> {
/// Number of clusters (and therefore centers) to form. `fit` rejects `0`.
pub n_clusters: usize,
/// Number of samples drawn for each mini-batch update. `fit` rejects `0`;
/// values at or above the sample count use every sample in each iteration.
pub batch_size: usize,
/// Upper bound on the number of mini-batch iterations per run.
pub max_iter: usize,
/// A run stops early once the largest center movement produced by one
/// mini-batch drops below this value.
pub tol: F,
/// Number of independently seeded runs; the run with the lowest inertia is
/// kept. `fit` rejects `0`.
pub n_init: usize,
/// Base seed for the random number generator. `fit` treats `None` as seed `0`.
pub random_state: Option<u64>,
/// How the initial centers are chosen.
pub init: MiniBatchKMeansInit,
}
impl<F: Float> MiniBatchKMeans<F> {
    /// Creates a clusterer for `n_clusters` clusters with the default settings:
    /// batches of 100 samples, at most 300 iterations, a tolerance of `1e-4`,
    /// 3 runs, no explicit seed and k-means++ initialisation.
    ///
    /// If `1e-4` cannot be converted to `F`, the tolerance falls back to
    /// `F::epsilon()`.
    #[must_use]
    pub fn new(n_clusters: usize) -> Self {
        let default_tol = F::from(1e-4).unwrap_or_else(F::epsilon);
        Self {
            n_clusters,
            batch_size: 100,
            max_iter: 300,
            tol: default_tol,
            n_init: 3,
            random_state: None,
            init: MiniBatchKMeansInit::KMeansPlusPlus,
        }
    }

    /// Returns the clusterer with the mini-batch size replaced.
    #[must_use]
    pub fn with_batch_size(self, batch_size: usize) -> Self {
        Self { batch_size, ..self }
    }

    /// Returns the clusterer with the per-run iteration cap replaced.
    #[must_use]
    pub fn with_max_iter(self, max_iter: usize) -> Self {
        Self { max_iter, ..self }
    }

    /// Returns the clusterer with the convergence tolerance replaced.
    #[must_use]
    pub fn with_tol(self, tol: F) -> Self {
        Self { tol, ..self }
    }

    /// Returns the clusterer with the number of independent runs replaced.
    #[must_use]
    pub fn with_n_init(self, n_init: usize) -> Self {
        Self { n_init, ..self }
    }

    /// Returns the clusterer with an explicit random seed.
    #[must_use]
    pub fn with_random_state(self, seed: u64) -> Self {
        Self {
            random_state: Some(seed),
            ..self
        }
    }

    /// Returns the clusterer with the center initialisation strategy replaced.
    #[must_use]
    pub fn with_init(self, init: MiniBatchKMeansInit) -> Self {
        Self { init, ..self }
    }
}
/// Result of fitting [`MiniBatchKMeans`]: the learned centers plus summary
/// statistics of the run that was kept.
#[derive(Debug, Clone)]
pub struct FittedMiniBatchKMeans<F> {
// One row per cluster, one column per feature.
cluster_centers_: Array2<F>,
// Index of the nearest center for every training sample.
labels_: Array1<usize>,
// Sum of squared distances from each training sample to its nearest center.
inertia_: F,
// Number of mini-batch iterations the kept run performed.
n_iter_: usize,
}
impl<F: Float> FittedMiniBatchKMeans<F> {
/// Learned cluster centers, one row per cluster and one column per feature.
#[must_use]
pub fn cluster_centers(&self) -> &Array2<F> {
&self.cluster_centers_
}
/// Index of the nearest center for each sample passed to `fit`, in row order.
#[must_use]
pub fn labels(&self) -> &Array1<usize> {
&self.labels_
}
/// Sum of squared distances from each training sample to its nearest center.
#[must_use]
pub fn inertia(&self) -> F {
self.inertia_
}
/// Number of mini-batch iterations performed by the run that was kept.
#[must_use]
pub fn n_iter(&self) -> usize {
self.n_iter_
}
}
/// Squared Euclidean distance between `a` and `b`.
///
/// The slices are walked pairwise; if their lengths differ, the surplus
/// elements of the longer slice are ignored.
#[inline]
fn squared_euclidean_mb<F: Float>(a: &[F], b: &[F]) -> F {
    let mut total = F::zero();
    for (&lhs, &rhs) in a.iter().zip(b) {
        let diff = lhs - rhs;
        total = total + diff * diff;
    }
    total
}
/// Picks `k` initial centers from the rows of `x` using k-means++ seeding.
///
/// The first center is a uniformly random row. Each following center is drawn
/// with probability proportional to a row's squared distance to the nearest
/// center chosen so far. When every row coincides with a chosen center (total
/// weight zero) a uniformly random row is used instead.
///
/// `x` may have any memory layout. The caller must pass `k >= 1` and an `x`
/// with at least one row; the function panics otherwise.
fn kmeans_plus_plus_mb<F: Float>(x: &Array2<F>, k: usize, rng: &mut StdRng) -> Array2<F> {
    // Rows are handed to the distance helper as slices, and `as_slice` only
    // succeeds for contiguous rows. Without this normalisation a column-major
    // input would make every distance zero and silently turn the seeding into
    // uniform sampling. Borrows when already row-major, copies once otherwise.
    let x = x.as_standard_layout();
    let n_samples = x.nrows();
    let n_features = x.ncols();

    let mut centers = Array2::zeros((k, n_features));
    let first_idx = rng.random_range(0..n_samples);
    centers.row_mut(0).assign(&x.row(first_idx));

    // Squared distance from every row to its nearest center chosen so far.
    let mut min_dists = Array1::from_elem(n_samples, F::max_value());

    for c in 1..k {
        // Only the most recently added center can lower a row's nearest distance.
        let prev_center = centers.row(c - 1);
        let prev_slice = prev_center.as_slice().unwrap_or(&[]);
        for (i, nearest) in min_dists.iter_mut().enumerate() {
            let d = squared_euclidean_mb(x.row(i).as_slice().unwrap_or(&[]), prev_slice);
            if d < *nearest {
                *nearest = d;
            }
        }

        let total: F = min_dists.iter().fold(F::zero(), |acc, &d| acc + d);
        let chosen = if total == F::zero() {
            // Every row sits on an existing center: no weights to sample from.
            rng.random_range(0..n_samples)
        } else {
            // Inverse-CDF sampling over the distance weights.
            let threshold: F = F::from(rng.random::<f64>()).unwrap_or(F::zero()) * total;
            let mut cumsum = F::zero();
            min_dists
                .iter()
                .position(|&d| {
                    cumsum = cumsum + d;
                    cumsum >= threshold
                })
                // Rounding can leave the running sum just short of the threshold.
                .unwrap_or(n_samples - 1)
        };
        centers.row_mut(c).assign(&x.row(chosen));
    }
    centers
}
/// Picks `k` distinct rows of `x` uniformly at random as initial centers.
///
/// Runs the first `k` steps of a Fisher-Yates shuffle over the row indices, so
/// no row index is selected twice. The caller must ensure `k` does not exceed
/// the number of rows.
fn random_init_mb<F: Float>(x: &Array2<F>, k: usize, rng: &mut StdRng) -> Array2<F> {
    let n_samples = x.nrows();
    let mut order: Vec<usize> = (0..n_samples).collect();
    let mut centers: Array2<F> = Array2::zeros((k, x.ncols()));

    for (slot, mut center) in centers.outer_iter_mut().enumerate() {
        // Swap a not-yet-used index into `slot`, then copy that row.
        let pick = rng.random_range(slot..n_samples);
        order.swap(slot, pick);
        center.assign(&x.row(order[slot]));
    }
    centers
}
/// Assigns every row of `x` to its nearest center.
///
/// Returns the label of each row (index of the nearest row of `centers`, the
/// lowest index winning ties) together with the inertia, i.e. the sum of the
/// squared distances from each row to its assigned center. Rows are processed
/// in parallel; the inertia is accumulated sequentially in row order so the
/// result is deterministic.
///
/// Both arrays may have any memory layout.
fn assign_clusters_mb<F: Float + Send + Sync>(
    x: &Array2<F>,
    centers: &Array2<F>,
) -> (Array1<usize>, F) {
    // The distance helper works on slices, which only exist for contiguous
    // rows. Previously a column-major input fell back to an empty slice and
    // every distance became zero; normalise the layout instead (a borrow when
    // the data is already row-major, a single copy otherwise).
    let x_std = x.as_standard_layout();
    let centers_std = centers.as_standard_layout();
    let samples = x_std.view();
    let center_rows = centers_std.view();
    let k = center_rows.nrows();

    let nearest: Vec<(usize, F)> = (0..samples.nrows())
        .into_par_iter()
        .map(|i| {
            let row = samples.row(i);
            let row_slice = row.as_slice().unwrap_or(&[]);
            let mut best = (0usize, F::max_value());
            for c in 0..k {
                let d =
                    squared_euclidean_mb(row_slice, center_rows.row(c).as_slice().unwrap_or(&[]));
                if d < best.1 {
                    best = (c, d);
                }
            }
            best
        })
        .collect();

    let inertia = nearest.iter().fold(F::zero(), |acc, &(_, d)| acc + d);
    let labels: Array1<usize> = nearest.into_iter().map(|(label, _)| label).collect();
    (labels, inertia)
}
/// Returns, for each index in `batch_indices`, the label of the center nearest
/// to that row of `x` (the lowest index winning ties).
///
/// Both arrays may have any memory layout.
fn assign_batch<F: Float>(
    x: &Array2<F>,
    batch_indices: &[usize],
    centers: &Array2<F>,
) -> Vec<usize> {
    // Center rows must be contiguous to be viewed as slices; this is a plain
    // borrow whenever `centers` is already row-major.
    let centers = centers.as_standard_layout();

    batch_indices
        .iter()
        .map(|&i| {
            let row = x.row(i);
            // Borrow contiguous rows directly and copy only strided ones.
            // Falling back to an empty slice here used to make every distance
            // zero, so all samples were assigned to center 0.
            let strided_copy;
            let row_slice: &[F] = match row.as_slice() {
                Some(contiguous) => contiguous,
                None => {
                    strided_copy = row.to_vec();
                    &strided_copy
                }
            };

            let mut best_label = 0;
            let mut best_dist = F::max_value();
            for (c, center) in centers.outer_iter().enumerate() {
                let d = squared_euclidean_mb(row_slice, center.as_slice().unwrap_or(&[]));
                if d < best_dist {
                    best_dist = d;
                    best_label = c;
                }
            }
            best_label
        })
        .collect()
}
/// Applies one mini-batch update to `centers` and returns the largest
/// Euclidean distance any center moved.
///
/// For every `(sample, label)` pair the per-center count is incremented and the
/// center is pulled towards the sample with learning rate `1 / count`, so each
/// center tracks the running mean of the samples assigned to it.
/// `batch_indices` and `batch_labels` are walked pairwise, and `center_counts`
/// must hold one entry per center.
fn update_centers_mini_batch<F: Float>(
    x: &Array2<F>,
    batch_indices: &[usize],
    batch_labels: &[usize],
    centers: &mut Array2<F>,
    center_counts: &mut [usize],
) -> F {
    // Snapshot taken before the update so the movement can be measured afterwards.
    let before = centers.clone();

    for (&sample, &cluster) in batch_indices.iter().zip(batch_labels) {
        center_counts[cluster] += 1;
        let step = F::one() / F::from(center_counts[cluster]).unwrap_or(F::one());
        let point = x.row(sample);
        let mut center = centers.row_mut(cluster);
        for (j, coord) in center.iter_mut().enumerate() {
            *coord = *coord + step * (point[j] - *coord);
        }
    }

    (0..centers.nrows())
        .map(|c| {
            squared_euclidean_mb(
                centers.row(c).as_slice().unwrap_or(&[]),
                before.row(c).as_slice().unwrap_or(&[]),
            )
            .sqrt()
        })
        .fold(F::zero(), |largest, shift| {
            if shift > largest {
                shift
            } else {
                largest
            }
        })
}
/// Draws `min(batch_size, n_samples)` distinct sample indices uniformly at
/// random, in random order.
///
/// Performs a partial Fisher-Yates shuffle of `0..n_samples` and returns the
/// shuffled prefix; when `batch_size >= n_samples` that prefix is a full
/// permutation of all indices.
fn sample_batch_indices(n_samples: usize, batch_size: usize, rng: &mut StdRng) -> Vec<usize> {
    let draw = batch_size.min(n_samples);
    let mut pool: Vec<usize> = (0..n_samples).collect();
    for slot in 0..draw {
        let pick = rng.random_range(slot..n_samples);
        pool.swap(slot, pick);
    }
    // Positions below `draw` are final: later swaps only touch higher slots.
    pool.truncate(draw);
    pool
}
/// Unsupervised fitting: the target argument is the unit type and is ignored.
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for MiniBatchKMeans<F> {
    type Fitted = FittedMiniBatchKMeans<F>;
    type Error = FerroError;

    /// Clusters the rows of `x` with mini-batch k-means.
    ///
    /// Performs `n_init` independently seeded runs. Each run initialises the
    /// centers with the configured strategy, then repeatedly samples a
    /// mini-batch, assigns it to the nearest centers and nudges those centers
    /// towards the batch, stopping after `max_iter` iterations or as soon as
    /// the largest center movement falls below `tol`. The run with the lowest
    /// inertia over the full data set is returned.
    ///
    /// Runs are seeded from `random_state`, with `None` treated as seed `0`,
    /// so repeated fits with the same configuration use the same seeds.
    ///
    /// # Errors
    ///
    /// * [`FerroError::InvalidParameter`] if `n_clusters`, `batch_size` or
    ///   `n_init` is zero.
    /// * [`FerroError::InsufficientSamples`] if `x` has fewer rows than
    ///   `n_clusters` (including no rows at all).
    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedMiniBatchKMeans<F>, FerroError> {
        if self.n_clusters == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_clusters".into(),
                reason: "must be at least 1".into(),
            });
        }
        if self.batch_size == 0 {
            return Err(FerroError::InvalidParameter {
                name: "batch_size".into(),
                reason: "must be at least 1".into(),
            });
        }
        if self.n_init == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_init".into(),
                reason: "must be at least 1".into(),
            });
        }

        let n_samples = x.nrows();
        // `n_clusters >= 1` at this point, so this also rejects empty input.
        if n_samples < self.n_clusters {
            return Err(FerroError::InsufficientSamples {
                required: self.n_clusters,
                actual: n_samples,
                context: "MiniBatchKMeans requires at least n_clusters samples".into(),
            });
        }

        // The helpers view rows as slices, which requires row-major storage.
        // Copy once up front for any other layout so they never fall back to
        // empty slices (which would make every distance zero).
        let row_major;
        let x: &Array2<F> = if x.is_standard_layout() {
            x
        } else {
            row_major = x.as_standard_layout().into_owned();
            &row_major
        };

        let base_seed = self.random_state.unwrap_or(0);
        let mut best_result: Option<FittedMiniBatchKMeans<F>> = None;

        for run in 0..self.n_init {
            // Spread the per-run seeds apart; wrapping keeps huge `n_init`
            // values from overflowing in debug builds.
            let run_seed = base_seed.wrapping_add((run as u64).wrapping_mul(1_000_003));
            let mut rng = StdRng::seed_from_u64(run_seed);

            let mut centers = match self.init {
                MiniBatchKMeansInit::KMeansPlusPlus => {
                    kmeans_plus_plus_mb(x, self.n_clusters, &mut rng)
                }
                MiniBatchKMeansInit::Random => random_init_mb(x, self.n_clusters, &mut rng),
            };
            // Each initial center counts as one already-seen sample.
            let mut center_counts = vec![1usize; self.n_clusters];
            let mut n_iter = 0usize;

            for _ in 0..self.max_iter {
                let batch_indices = sample_batch_indices(n_samples, self.batch_size, &mut rng);
                let batch_labels = assign_batch(x, &batch_indices, &centers);
                let shift = update_centers_mini_batch(
                    x,
                    &batch_indices,
                    &batch_labels,
                    &mut centers,
                    &mut center_counts,
                );
                n_iter += 1;
                if shift < self.tol {
                    break;
                }
            }

            // Score the run on the full data set, not just the last batch.
            let (labels, inertia) = assign_clusters_mb(x, &centers);
            let candidate = FittedMiniBatchKMeans {
                cluster_centers_: centers,
                labels_: labels,
                inertia_: inertia,
                n_iter_: n_iter,
            };

            let improves = best_result
                .as_ref()
                .map_or(true, |best| candidate.inertia_ < best.inertia_);
            if improves {
                best_result = Some(candidate);
            }
        }

        best_result.ok_or_else(|| FerroError::InvalidParameter {
            name: "n_init".into(),
            reason: "internal error: no runs completed".into(),
        })
    }
}
/// Nearest-center prediction for new samples.
impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedMiniBatchKMeans<F> {
    type Output = Array1<usize>;
    type Error = FerroError;

    /// Returns, for every row of `x`, the index of the nearest fitted center.
    ///
    /// # Errors
    ///
    /// [`FerroError::ShapeMismatch`] if `x` does not have the same number of
    /// columns as the fitted centers.
    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
        let expected = self.cluster_centers_.ncols();
        let actual = x.ncols();
        if actual == expected {
            let (labels, _inertia) = assign_clusters_mb(x, &self.cluster_centers_);
            return Ok(labels);
        }
        Err(FerroError::ShapeMismatch {
            expected: vec![expected],
            actual: vec![actual],
            context: "FittedMiniBatchKMeans::predict".into(),
        })
    }
}
/// Maps samples into cluster-distance space.
impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedMiniBatchKMeans<F> {
    type Output = Array2<F>;
    type Error = FerroError;

    /// Returns an `(n_samples, n_clusters)` matrix whose entry `[i, c]` is the
    /// Euclidean distance from row `i` of `x` to center `c`.
    ///
    /// `x` may have any memory layout; rows are processed in parallel.
    ///
    /// # Errors
    ///
    /// * [`FerroError::ShapeMismatch`] if `x` does not have the same number of
    ///   columns as the fitted centers.
    /// * [`FerroError::NumericalInstability`] if the distance matrix cannot be
    ///   assembled from the computed values.
    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let n_features = x.ncols();
        let expected_features = self.cluster_centers_.ncols();
        if n_features != expected_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![expected_features],
                actual: vec![n_features],
                context: "FittedMiniBatchKMeans::transform".into(),
            });
        }

        // Rows are passed to the distance helper as slices, which only exist
        // for contiguous rows. A column-major `x` used to fall back to empty
        // slices and produce an all-zero matrix; normalise the layout instead
        // (a borrow when already row-major, a single copy otherwise).
        let x_std = x.as_standard_layout();
        let centers_std = self.cluster_centers_.as_standard_layout();
        let samples = x_std.view();
        let centers = centers_std.view();

        let n_samples = samples.nrows();
        let k = centers.nrows();

        let distances: Vec<F> = (0..n_samples)
            .into_par_iter()
            .flat_map(|i| {
                let row = samples.row(i);
                let row_slice = row.as_slice().unwrap_or(&[]);
                (0..k)
                    .map(|c| {
                        squared_euclidean_mb(row_slice, centers.row(c).as_slice().unwrap_or(&[]))
                            .sqrt()
                    })
                    .collect::<Vec<F>>()
            })
            .collect();

        Array2::from_shape_vec((n_samples, k), distances).map_err(|_| {
            FerroError::NumericalInstability {
                message: "failed to construct distance matrix".into(),
            }
        })
    }
}
// Unit tests for fitting, prediction, transformation and parameter validation
// of the mini-batch k-means estimator.
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
// Nine 2-D points forming three tight groups of three, near (0, 0), (10, 10)
// and (0, 10). Rows 0-2, 3-5 and 6-8 belong to the respective groups.
fn make_blobs() -> Array2<f64> {
Array2::from_shape_vec(
(9, 2),
vec![
0.0, 0.0, 0.1, 0.1, -0.1, 0.1, 10.0, 10.0, 10.1, 10.1, 9.9, 10.1, 0.0, 10.0, 0.1, 10.1, -0.1, 9.9, ],
)
.unwrap()
}
// Points of the same blob must share a label and different blobs must not.
#[test]
fn test_well_separated_blobs() {
let x = make_blobs();
let model = MiniBatchKMeans::<f64>::new(3)
.with_random_state(42)
.with_n_init(5)
.with_batch_size(9);
let fitted = model.fit(&x, &()).unwrap();
let labels = fitted.labels();
assert_eq!(labels[0], labels[1]);
assert_eq!(labels[0], labels[2]);
assert_eq!(labels[3], labels[4]);
assert_eq!(labels[3], labels[5]);
assert_eq!(labels[6], labels[7]);
assert_eq!(labels[6], labels[8]);
assert_ne!(labels[0], labels[3]);
assert_ne!(labels[0], labels[6]);
assert_ne!(labels[3], labels[6]);
}
// The center matrix has one row per cluster and one column per feature.
#[test]
fn test_cluster_centers_shape() {
let x = make_blobs();
let model = MiniBatchKMeans::<f64>::new(3).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
assert_eq!(fitted.cluster_centers().dim(), (3, 2));
}
// One label is produced per training sample.
#[test]
fn test_labels_length() {
let x = make_blobs();
let model = MiniBatchKMeans::<f64>::new(3).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
assert_eq!(fitted.labels().len(), 9);
}
// Inertia is a sum of squared distances and can never be negative.
#[test]
fn test_inertia_non_negative() {
let x = make_blobs();
let model = MiniBatchKMeans::<f64>::new(3).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
assert!(fitted.inertia() >= 0.0);
}
// At least one mini-batch iteration is always performed.
#[test]
fn test_n_iter_positive() {
let x = make_blobs();
let model = MiniBatchKMeans::<f64>::new(3).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
assert!(fitted.n_iter() >= 1);
}
// New points close to a blob receive the label that blob got during fitting.
#[test]
fn test_predict_on_new_data() {
let x = make_blobs();
let model = MiniBatchKMeans::<f64>::new(3)
.with_random_state(42)
.with_n_init(5)
.with_batch_size(9);
let fitted = model.fit(&x, &()).unwrap();
let new_x =
Array2::from_shape_vec((3, 2), vec![0.05, 0.05, 10.05, 10.05, 0.05, 10.05]).unwrap();
let new_labels = fitted.predict(&new_x).unwrap();
assert_eq!(new_labels.len(), 3);
assert_eq!(new_labels[0], fitted.labels()[0]);
assert_eq!(new_labels[1], fitted.labels()[3]);
assert_eq!(new_labels[2], fitted.labels()[6]);
}
// `transform` yields one row per sample and one column per cluster.
#[test]
fn test_transform_shape() {
let x = make_blobs();
let model = MiniBatchKMeans::<f64>::new(3).with_random_state(42);
let fitted = model.fit(&x, &()).unwrap();
let dists = fitted.transform(&x).unwrap();
assert_eq!(dists.dim(), (9, 3));
}
// With two clusters, each sample is no farther from its own center than from
// the other one.
#[test]
fn test_transform_distances_structure() {
let x = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 10.0, 0.0, 0.0, 10.0, 10.0, 10.0])
.unwrap();
let model = MiniBatchKMeans::<f64>::new(2)
.with_random_state(42)
.with_batch_size(4)
.with_n_init(5);
let fitted = model.fit(&x, &()).unwrap();
let dists = fitted.transform(&x).unwrap();
assert_eq!(dists.dim(), (4, 2));
for i in 0..4 {
let own_cluster = fitted.labels()[i];
let other_cluster = 1 - own_cluster;
assert!(dists[[i, own_cluster]] <= dists[[i, other_cluster]] + 1e-10);
}
}
// Fitting twice with the same seed gives identical labels and inertia.
#[test]
fn test_reproducibility_with_seed() {
let x = make_blobs();
let model = MiniBatchKMeans::<f64>::new(3)
.with_random_state(99)
.with_batch_size(9);
let fitted1 = model.fit(&x, &()).unwrap();
let fitted2 = model.fit(&x, &()).unwrap();
assert_eq!(fitted1.labels(), fitted2.labels());
assert_relative_eq!(fitted1.inertia(), fitted2.inertia(), epsilon = 1e-12);
}
// The random initialisation strategy also produces a well-formed model.
#[test]
fn test_random_init() {
let x = make_blobs();
let model = MiniBatchKMeans::<f64>::new(3)
.with_random_state(42)
.with_init(MiniBatchKMeansInit::Random)
.with_n_init(5)
.with_batch_size(9);
let fitted = model.fit(&x, &()).unwrap();
assert_eq!(fitted.cluster_centers().dim(), (3, 2));
assert!(fitted.inertia() >= 0.0);
}
// With a single cluster every sample gets label 0.
#[test]
fn test_single_cluster() {
let x =
Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]).unwrap();
let model = MiniBatchKMeans::<f64>::new(1)
.with_random_state(42)
.with_batch_size(4);
let fitted = model.fit(&x, &()).unwrap();
for &label in fitted.labels().iter() {
assert_eq!(label, 0);
}
}
// `n_clusters == 0` is rejected.
#[test]
fn test_zero_clusters_error() {
let x = Array2::from_shape_vec((3, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]).unwrap();
assert!(MiniBatchKMeans::<f64>::new(0).fit(&x, &()).is_err());
}
// `batch_size == 0` is rejected.
#[test]
fn test_zero_batch_size_error() {
let x = Array2::from_shape_vec((3, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]).unwrap();
let model = MiniBatchKMeans::<f64>::new(2).with_batch_size(0);
assert!(model.fit(&x, &()).is_err());
}
// Asking for more clusters than there are samples is rejected.
#[test]
fn test_too_few_samples_error() {
let x = Array2::from_shape_vec((2, 2), vec![1.0, 1.0, 2.0, 2.0]).unwrap();
assert!(MiniBatchKMeans::<f64>::new(5).fit(&x, &()).is_err());
}
// An input without rows is rejected.
#[test]
fn test_empty_data_error() {
let x = Array2::<f64>::zeros((0, 2));
assert!(MiniBatchKMeans::<f64>::new(3).fit(&x, &()).is_err());
}
// `predict` rejects input whose feature count differs from the fitted centers.
#[test]
fn test_predict_shape_mismatch_error() {
let x =
Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]).unwrap();
let model = MiniBatchKMeans::<f64>::new(2)
.with_random_state(42)
.with_batch_size(4);
let fitted = model.fit(&x, &()).unwrap();
let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
assert!(fitted.predict(&x_bad).is_err());
}
// `transform` rejects input whose feature count differs from the fitted centers.
#[test]
fn test_transform_shape_mismatch_error() {
let x =
Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]).unwrap();
let model = MiniBatchKMeans::<f64>::new(2)
.with_random_state(42)
.with_batch_size(4);
let fitted = model.fit(&x, &()).unwrap();
let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
assert!(fitted.transform(&x_bad).is_err());
}
// The estimator is generic over the float type: `f32` data fits and predicts.
#[test]
fn test_f32_support() {
let x: Array2<f32> = Array2::from_shape_vec(
(6, 2),
vec![
0.0f32, 0.0, 0.1, 0.1, -0.1, 0.1, 10.0, 10.0, 10.1, 10.1, 9.9, 10.1,
],
)
.unwrap();
let model = MiniBatchKMeans::<f32>::new(2)
.with_random_state(42)
.with_batch_size(6);
let fitted = model.fit(&x, &()).unwrap();
assert_eq!(fitted.labels().len(), 6);
let preds = fitted.predict(&x).unwrap();
assert_eq!(preds.len(), 6);
}
// A batch size larger than the data set is accepted.
#[test]
fn test_large_batch_size() {
let x = make_blobs();
let model = MiniBatchKMeans::<f64>::new(3)
.with_random_state(7)
.with_batch_size(1000);
let fitted = model.fit(&x, &()).unwrap();
assert_eq!(fitted.cluster_centers().nrows(), 3);
}
// `n_init == 0` is rejected.
#[test]
fn test_n_init_zero_error() {
let x = make_blobs();
let model = MiniBatchKMeans::<f64>::new(3).with_n_init(0);
assert!(model.fit(&x, &()).is_err());
}
// Identical points collapse onto the single center, giving zero inertia.
#[test]
fn test_identical_points() {
let x =
Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]).unwrap();
let model = MiniBatchKMeans::<f64>::new(1)
.with_random_state(42)
.with_batch_size(4);
let fitted = model.fit(&x, &()).unwrap();
assert_relative_eq!(fitted.inertia(), 0.0, epsilon = 1e-10);
}
}