use ferrolearn_core::error::FerroError;
use ferrolearn_core::traits::{Fit, Predict};
use ndarray::{Array1, Array2};
use num_traits::Float;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::marker::PhantomData;
/// Rule deciding which existing cluster is bisected next.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BisectingStrategy {
    /// Among clusters that can still be split, bisect the one with the most samples.
    LargestCluster,
    /// Among clusters that can still be split, bisect the one with the largest sum of
    /// squared distances from its members to its centre.
    LargestSSE,
}
/// Configuration for bisecting k-means clustering.
///
/// Fitting starts from one cluster holding every sample and repeatedly splits a cluster in
/// two with 2-means, until `n_clusters` clusters exist or no cluster can be split further.
/// Construct with [`BisectingKMeans::new`] and adjust with the `with_*` methods.
#[derive(Clone, Debug)]
pub struct BisectingKMeans<F> {
    /// Requested number of clusters.
    n_clusters: usize,
    /// Upper bound on assignment/update rounds in each 2-means run.
    max_iter: usize,
    /// Number of seeded 2-means runs per bisection; the run with the lowest SSE is kept.
    n_init: usize,
    /// Base RNG seed; when `None`, fitting uses 0.
    random_state: Option<u64>,
    /// Rule choosing the next cluster to split.
    bisecting_strategy: BisectingStrategy,
    /// Carries the float type `F` without storing a value of it.
    _marker: PhantomData<F>,
}
impl<F: Float> BisectingKMeans<F> {
    /// Builds an estimator that targets `n_clusters` clusters.
    ///
    /// Defaults: `max_iter = 300`, `n_init = 10`, no fixed seed, and
    /// [`BisectingStrategy::LargestCluster`].
    #[must_use]
    pub fn new(n_clusters: usize) -> Self {
        let max_iter = 300;
        let n_init = 10;
        Self {
            n_clusters,
            max_iter,
            n_init,
            random_state: None,
            bisecting_strategy: BisectingStrategy::LargestCluster,
            _marker: PhantomData,
        }
    }

    /// Caps the number of assignment/update rounds in each 2-means run.
    #[must_use]
    pub fn with_max_iter(self, max_iter: usize) -> Self {
        Self { max_iter, ..self }
    }

    /// Sets how many seeded 2-means runs are tried per bisection.
    #[must_use]
    pub fn with_n_init(self, n_init: usize) -> Self {
        Self { n_init, ..self }
    }

    /// Fixes the base RNG seed, making fitting reproducible.
    #[must_use]
    pub fn with_random_state(self, seed: u64) -> Self {
        Self {
            random_state: Some(seed),
            ..self
        }
    }

    /// Chooses the rule used to pick which cluster is split next.
    #[must_use]
    pub fn with_bisecting_strategy(self, strategy: BisectingStrategy) -> Self {
        Self {
            bisecting_strategy: strategy,
            ..self
        }
    }
}
/// Result of fitting [`BisectingKMeans`]: the final clusters and training assignments.
#[derive(Clone, Debug)]
pub struct FittedBisectingKMeans<F> {
    /// One row per final cluster, one column per feature.
    cluster_centers_: Array2<F>,
    /// Cluster index (a row of `cluster_centers_`) for each training sample.
    labels_: Array1<isize>,
    /// Sum of the per-cluster squared-distance totals recorded during fitting.
    inertia_: F,
}
impl<F: Float> FittedBisectingKMeans<F> {
    /// Centre of each final cluster, one row per cluster.
    #[must_use]
    pub fn cluster_centers(&self) -> &Array2<F> {
        &self.cluster_centers_
    }

    /// Cluster index assigned to each training sample, in sample order.
    #[must_use]
    pub fn labels(&self) -> &Array1<isize> {
        &self.labels_
    }

    /// Total squared distance of the training samples to their cluster centres.
    #[must_use]
    pub fn inertia(&self) -> F {
        self.inertia_
    }

    /// Number of clusters actually produced. This can be fewer than requested when
    /// fitting ran out of splittable clusters.
    #[must_use]
    pub fn n_clusters(&self) -> usize {
        let (rows, _cols) = self.cluster_centers_.dim();
        rows
    }
}
/// Squared Euclidean distance between two equally long vectors.
///
/// Elements are paired positionally. If one slice is longer, its extra elements are
/// ignored.
fn squared_euclidean<F: Float>(a: &[F], b: &[F]) -> F {
    let mut total = F::zero();
    for (&lhs, &rhs) in a.iter().zip(b) {
        let diff = lhs - rhs;
        total = total + diff * diff;
    }
    total
}
/// Working state for one cluster while bisecting.
#[derive(Clone, Debug)]
struct ClusterInfo<F> {
    /// Row indices, into the training matrix, of the cluster's members.
    indices: Vec<usize>,
    /// Cluster centre, one entry per feature.
    center: Vec<F>,
    /// Sum of squared distances from the members to `center`.
    sse: F,
}
/// Returns the mean of the rows of `x` listed in `indices`, together with the sum of
/// squared Euclidean distances from those rows to that mean.
///
/// An empty `indices` yields an all-zero centre of length `x.ncols()` and zero SSE.
/// Rows are read through their element iterators rather than borrowed as slices, so `x`
/// may use any memory layout (C-order, Fortran-order, strided).
///
/// # Panics
/// Panics if an index is out of bounds for `x`, or if the member count cannot be
/// represented in `F`.
fn compute_cluster_stats<F: Float>(x: &Array2<F>, indices: &[usize]) -> (Vec<F>, F) {
    let mut center = vec![F::zero(); x.ncols()];
    if indices.is_empty() {
        return (center, F::zero());
    }

    for &idx in indices {
        for (c, &v) in center.iter_mut().zip(x.row(idx).iter()) {
            *c = *c + v;
        }
    }
    let count = F::from(indices.len()).expect("cluster size must be representable in F");
    for c in center.iter_mut() {
        *c = *c / count;
    }

    // Accumulate each row's distance separately before adding it to the total.
    let sse = indices.iter().fold(F::zero(), |total, &idx| {
        let dist = x
            .row(idx)
            .iter()
            .zip(&center)
            .fold(F::zero(), |acc, (&v, &c)| acc + (v - c) * (v - c));
        total + dist
    });
    (center, sse)
}
/// Splits the samples `indices` (rows of `x`) into two groups with Lloyd's 2-means.
///
/// Both centres are seeded from member rows drawn with `rng`. Two different positions
/// can hold identical coordinates, which would make every sample tie. When the second
/// draw coincides with the first in value and some member differs, the second seed is
/// re-drawn among the differing members. Iteration stops when no assignment changes or
/// after `max_iter` rounds.
///
/// A sample joins group 0 when its squared distance to centre 0 is `<=` its distance to
/// centre 1; otherwise it joins group 1. A group left with no members keeps its previous
/// centre.
///
/// Returns `(labels, center0, center1, sse)`:
/// - `labels[i]` (0 or 1) is the group of `indices[i]`.
/// - The centres have `x.ncols()` entries.
/// - `sse` is the total squared distance of the members to their group centre.
///
/// An empty `indices` yields empty labels, zero centres and zero SSE. Rows are read
/// through element iterators, so `x` may use any memory layout.
fn run_2means<F: Float>(
    x: &Array2<F>,
    indices: &[usize],
    max_iter: usize,
    rng: &mut StdRng,
) -> (Vec<usize>, Vec<F>, Vec<F>, F) {
    let n = indices.len();
    let n_features = x.ncols();
    if n == 0 {
        return (
            Vec::new(),
            vec![F::zero(); n_features],
            vec![F::zero(); n_features],
            F::zero(),
        );
    }

    // Squared distance from training row `sample` to `center`, computed from the row's
    // element iterator.
    let sq_dist = |sample: usize, center: &[F]| -> F {
        x.row(sample)
            .iter()
            .zip(center)
            .fold(F::zero(), |acc, (&a, &b)| acc + (a - b) * (a - b))
    };

    let idx0 = rng.random_range(0..n);
    let mut idx1 = rng.random_range(0..n);
    if n > 1 {
        while idx1 == idx0 {
            idx1 = rng.random_range(0..n);
        }
    }
    let seed_row = x.row(indices[idx0]);
    if x.row(indices[idx1]) == seed_row {
        let distinct: Vec<usize> = (0..n)
            .filter(|&i| x.row(indices[i]) != seed_row)
            .collect();
        if !distinct.is_empty() {
            idx1 = distinct[rng.random_range(0..distinct.len())];
        }
    }
    let mut centers = [seed_row.to_vec(), x.row(indices[idx1]).to_vec()];

    let mut labels = vec![0usize; n];
    for _ in 0..max_iter {
        // Assignment step.
        let mut changed = false;
        for (label, &sample) in labels.iter_mut().zip(indices) {
            let d0 = sq_dist(sample, centers[0].as_slice());
            let d1 = sq_dist(sample, centers[1].as_slice());
            let new_label = if d0 <= d1 { 0 } else { 1 };
            if new_label != *label {
                *label = new_label;
                changed = true;
            }
        }
        if !changed {
            break;
        }

        // Update step: each centre becomes the mean of its group.
        let mut sums = [vec![F::zero(); n_features], vec![F::zero(); n_features]];
        let mut counts = [F::zero(); 2];
        for (&label, &sample) in labels.iter().zip(indices) {
            counts[label] = counts[label] + F::one();
            for (acc, &v) in sums[label].iter_mut().zip(x.row(sample).iter()) {
                *acc = *acc + v;
            }
        }
        for ((center, sum), &count) in centers.iter_mut().zip(sums).zip(&counts) {
            if count > F::zero() {
                *center = sum.into_iter().map(|s| s / count).collect();
            }
        }
    }

    let sse = labels
        .iter()
        .zip(indices)
        .fold(F::zero(), |acc, (&label, &sample)| {
            acc + sq_dist(sample, centers[label].as_slice())
        });
    let [center0, center1] = centers;
    (labels, center0, center1, sse)
}
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for BisectingKMeans<F> {
type Fitted = FittedBisectingKMeans<F>;
type Error = FerroError;
fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedBisectingKMeans<F>, FerroError> {
let (n_samples, n_features) = x.dim();
if self.n_clusters == 0 {
return Err(FerroError::InvalidParameter {
name: "n_clusters".into(),
reason: "must be at least 1".into(),
});
}
if self.n_init == 0 {
return Err(FerroError::InvalidParameter {
name: "n_init".into(),
reason: "must be at least 1".into(),
});
}
if n_samples == 0 {
return Err(FerroError::InsufficientSamples {
required: self.n_clusters,
actual: 0,
context: "BisectingKMeans requires at least n_clusters samples".into(),
});
}
if n_samples < self.n_clusters {
return Err(FerroError::InsufficientSamples {
required: self.n_clusters,
actual: n_samples,
context: "BisectingKMeans requires at least n_clusters samples".into(),
});
}
let all_indices: Vec<usize> = (0..n_samples).collect();
let (center, sse) = compute_cluster_stats(x, &all_indices);
let mut clusters: Vec<ClusterInfo<F>> = vec![ClusterInfo {
indices: all_indices,
center,
sse,
}];
let base_seed = self.random_state.unwrap_or(0);
let mut split_count: u64 = 0;
while clusters.len() < self.n_clusters {
let target_idx = match self.bisecting_strategy {
BisectingStrategy::LargestCluster => {
clusters
.iter()
.enumerate()
.filter(|(_, c)| c.indices.len() >= 2)
.max_by_key(|(_, c)| c.indices.len())
.map(|(i, _)| i)
}
BisectingStrategy::LargestSSE => {
clusters
.iter()
.enumerate()
.filter(|(_, c)| c.indices.len() >= 2)
.max_by(|(_, a), (_, b)| {
a.sse.partial_cmp(&b.sse).unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(i, _)| i)
}
};
let target_idx = match target_idx {
Some(i) => i,
None => break, };
let target = &clusters[target_idx];
let target_indices = &target.indices;
let mut best_labels: Option<Vec<usize>> = None;
let mut best_c0: Vec<F> = Vec::new();
let mut best_c1: Vec<F> = Vec::new();
let mut best_sse = F::max_value();
for run in 0..self.n_init {
let seed = base_seed
.wrapping_add(split_count.wrapping_mul(1000))
.wrapping_add(run as u64);
let mut rng = StdRng::seed_from_u64(seed);
let (labels, c0, c1, sse) =
run_2means(x, target_indices, self.max_iter, &mut rng);
if sse < best_sse {
best_sse = sse;
best_labels = Some(labels);
best_c0 = c0;
best_c1 = c1;
}
}
split_count += 1;
let best_labels = best_labels.unwrap();
let mut indices0 = Vec::new();
let mut indices1 = Vec::new();
for (li, &sample_idx) in target_indices.iter().enumerate() {
if best_labels[li] == 0 {
indices0.push(sample_idx);
} else {
indices1.push(sample_idx);
}
}
let sse0 = if indices0.is_empty() {
F::zero()
} else {
let mut sse = F::zero();
for &idx in &indices0 {
let row = x.row(idx);
let row_slice = row.as_slice().unwrap_or(&[]);
sse = sse + squared_euclidean(row_slice, &best_c0);
}
sse
};
let sse1 = if indices1.is_empty() {
F::zero()
} else {
let mut sse = F::zero();
for &idx in &indices1 {
let row = x.row(idx);
let row_slice = row.as_slice().unwrap_or(&[]);
sse = sse + squared_euclidean(row_slice, &best_c1);
}
sse
};
clusters.remove(target_idx);
if !indices0.is_empty() {
clusters.push(ClusterInfo {
indices: indices0,
center: best_c0,
sse: sse0,
});
}
if !indices1.is_empty() {
clusters.push(ClusterInfo {
indices: indices1,
center: best_c1,
sse: sse1,
});
}
}
let n_final_clusters = clusters.len();
let mut cluster_centers = Array2::zeros((n_final_clusters, n_features));
let mut labels = Array1::from_elem(n_samples, 0isize);
let mut total_inertia = F::zero();
for (ci, cluster) in clusters.iter().enumerate() {
for j in 0..n_features {
cluster_centers[[ci, j]] = cluster.center[j];
}
for &idx in &cluster.indices {
labels[idx] = ci as isize;
}
total_inertia = total_inertia + cluster.sse;
}
Ok(FittedBisectingKMeans {
cluster_centers_: cluster_centers,
labels_: labels,
inertia_: total_inertia,
})
}
}
impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedBisectingKMeans<F> {
    type Output = Array1<isize>;
    type Error = FerroError;

    /// Assigns each row of `x` to the nearest fitted centre by squared Euclidean distance.
    ///
    /// Ties go to the lowest centre index. A row whose distances are all NaN, or all at
    /// least `F::max_value()`, gets label 0. Input of any memory layout is accepted.
    ///
    /// # Errors
    /// `ShapeMismatch` if `x` does not have as many columns as the fitted centres.
    fn predict(&self, x: &Array2<F>) -> Result<Array1<isize>, FerroError> {
        let n_features = x.ncols();
        let expected_features = self.cluster_centers_.ncols();
        if n_features != expected_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![expected_features],
                actual: vec![n_features],
                context: "number of features must match fitted BisectingKMeans model".into(),
            });
        }

        // Normalise layouts so every row can be borrowed as a contiguous slice. Borrowing
        // `as_slice().unwrap_or(&[])` on a non-contiguous row would silently give distance
        // 0. No copy is made when the array is already in standard layout.
        let x = x.as_standard_layout();
        let centers = self.cluster_centers_.as_standard_layout();

        let labels = x
            .outer_iter()
            .map(|sample| {
                let sample = sample
                    .as_slice()
                    .expect("rows of a standard-layout array are contiguous");
                let mut best_label = 0isize;
                let mut best_dist = F::max_value();
                for (c, center) in centers.outer_iter().enumerate() {
                    let center = center
                        .as_slice()
                        .expect("rows of a standard-layout array are contiguous");
                    let dist = squared_euclidean(sample, center);
                    if dist < best_dist {
                        best_dist = dist;
                        best_label = c as isize;
                    }
                }
                best_label
            })
            .collect();
        Ok(labels)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Nine 2-D points in three tight, well-separated groups of three:
    /// rows 0-2 near (0, 0), rows 3-5 near (10, 10), rows 6-8 near (0, 10).
    fn make_blobs() -> Array2<f64> {
        let points = [
            [0.0, 0.0],
            [0.1, 0.1],
            [-0.1, 0.1],
            [10.0, 10.0],
            [10.1, 10.1],
            [9.9, 10.1],
            [0.0, 10.0],
            [0.1, 10.1],
            [-0.1, 9.9],
        ];
        let flat: Vec<f64> = points.iter().flatten().copied().collect();
        Array2::from_shape_vec((9, 2), flat).unwrap()
    }

    /// Fits a seeded `k`-cluster model on the blob data.
    fn fit_blobs(k: usize, seed: u64) -> FittedBisectingKMeans<f64> {
        BisectingKMeans::<f64>::new(k)
            .with_random_state(seed)
            .fit(&make_blobs(), &())
            .unwrap()
    }

    /// Asserts every listed sample carries the same label as the first one listed.
    fn assert_same_label(labels: &Array1<isize>, members: &[usize]) {
        for &m in &members[1..] {
            assert_eq!(labels[members[0]], labels[m]);
        }
    }

    #[test]
    fn test_well_separated_blobs() {
        let fitted = fit_blobs(3, 42);
        let labels = fitted.labels();
        assert_same_label(labels, &[0, 1, 2]);
        assert_same_label(labels, &[3, 4, 5]);
        assert_same_label(labels, &[6, 7, 8]);
        assert_ne!(labels[0], labels[3]);
        assert_ne!(labels[0], labels[6]);
        assert_ne!(labels[3], labels[6]);
    }

    #[test]
    fn test_two_clusters() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                0.0, 0.0, 0.5, 0.0, 0.0, 0.5, 0.5, 0.5, //
                10.0, 10.0, 10.5, 10.0, 10.0, 10.5, 10.5, 10.5,
            ],
        )
        .unwrap();
        let fitted = BisectingKMeans::<f64>::new(2)
            .with_random_state(42)
            .fit(&x, &())
            .unwrap();
        let labels = fitted.labels();
        assert_same_label(labels, &[0, 1, 2, 3]);
        assert_same_label(labels, &[4, 5, 6, 7]);
        assert_ne!(labels[0], labels[4]);
    }

    #[test]
    fn test_single_cluster() {
        let x = Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0])
            .unwrap();
        let fitted = BisectingKMeans::<f64>::new(1)
            .with_random_state(42)
            .fit(&x, &())
            .unwrap();
        assert!(fitted.labels().iter().all(|&label| label == 0));
        assert_eq!(fitted.n_clusters(), 1);
    }

    #[test]
    fn test_predict_assigns_correctly() {
        let fitted = fit_blobs(3, 42);
        let predicted = fitted.predict(&make_blobs()).unwrap();
        assert_eq!(predicted, *fitted.labels());
    }

    #[test]
    fn test_predict_new_data() {
        let fitted = fit_blobs(3, 42);
        let probes =
            Array2::from_shape_vec((3, 2), vec![0.05, 0.05, 10.05, 10.05, 0.05, 10.05]).unwrap();
        let predicted = fitted.predict(&probes).unwrap();
        let train = fitted.labels();
        for (probe, anchor) in [(0usize, 0usize), (1, 3), (2, 6)] {
            assert_eq!(predicted[probe], train[anchor]);
        }
    }

    #[test]
    fn test_predict_shape_mismatch() {
        let fitted = fit_blobs(3, 42);
        let wrong_width =
            Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&wrong_width).is_err());
    }

    #[test]
    fn test_inertia_non_negative() {
        assert!(fit_blobs(3, 42).inertia() >= 0.0);
    }

    #[test]
    fn test_cluster_centers_shape() {
        assert_eq!(fit_blobs(3, 42).cluster_centers().dim(), (3, 2));
    }

    #[test]
    fn test_largest_sse_strategy() {
        let fitted = BisectingKMeans::<f64>::new(3)
            .with_random_state(42)
            .with_bisecting_strategy(BisectingStrategy::LargestSSE)
            .fit(&make_blobs(), &())
            .unwrap();
        assert_eq!(fitted.n_clusters(), 3);
        let labels = fitted.labels();
        assert_same_label(labels, &[0, 1, 2]);
        assert_same_label(labels, &[3, 4, 5]);
        assert_same_label(labels, &[6, 7, 8]);
    }

    #[test]
    fn test_reproducibility() {
        let x = make_blobs();
        let model = BisectingKMeans::<f64>::new(3).with_random_state(123);
        let first = model.fit(&x, &()).unwrap();
        let second = model.fit(&x, &()).unwrap();
        assert_eq!(first.labels(), second.labels());
    }

    #[test]
    fn test_zero_clusters() {
        assert!(BisectingKMeans::<f64>::new(0).fit(&make_blobs(), &()).is_err());
    }

    #[test]
    fn test_k_greater_than_n_samples() {
        let x = Array2::from_shape_vec((2, 2), vec![1.0, 1.0, 2.0, 2.0]).unwrap();
        assert!(BisectingKMeans::<f64>::new(5).fit(&x, &()).is_err());
    }

    #[test]
    fn test_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        assert!(BisectingKMeans::<f64>::new(3).fit(&x, &()).is_err());
    }

    #[test]
    fn test_invalid_n_init() {
        let model = BisectingKMeans::<f64>::new(3).with_n_init(0);
        assert!(model.fit(&make_blobs(), &()).is_err());
    }

    #[test]
    fn test_single_sample() {
        let x = Array2::from_shape_vec((1, 2), vec![5.0, 5.0]).unwrap();
        let fitted = BisectingKMeans::<f64>::new(1)
            .with_random_state(42)
            .fit(&x, &())
            .unwrap();
        assert_eq!(fitted.labels().len(), 1);
        assert_eq!(fitted.labels()[0], 0);
        assert_eq!(fitted.n_clusters(), 1);
    }

    #[test]
    fn test_f32_support() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![
                0.0f32, 0.0, 0.1, 0.1, -0.1, 0.1, 10.0, 10.0, 10.1, 10.1, 9.9, 10.1,
            ],
        )
        .unwrap();
        let fitted = BisectingKMeans::<f32>::new(2)
            .with_random_state(42)
            .fit(&x, &())
            .unwrap();
        assert_eq!(fitted.labels().len(), 6);
        assert_eq!(fitted.n_clusters(), 2);
    }

    #[test]
    fn test_identical_points() {
        let x = Array2::<f64>::from_elem((4, 2), 1.0);
        let fitted = BisectingKMeans::<f64>::new(1)
            .with_random_state(42)
            .fit(&x, &())
            .unwrap();
        assert_eq!(fitted.n_clusters(), 1);
    }

    #[test]
    fn test_labels_in_range() {
        let fitted = fit_blobs(3, 42);
        let k = fitted.n_clusters() as isize;
        assert!(fitted.labels().iter().all(|label| (0..k).contains(label)));
    }

    #[test]
    fn test_n_init_picks_best() {
        let x = make_blobs();
        let inertia_with = |n_init: usize| {
            BisectingKMeans::<f64>::new(3)
                .with_random_state(42)
                .with_n_init(n_init)
                .fit(&x, &())
                .unwrap()
                .inertia()
        };
        let single_run = inertia_with(1);
        let best_of_ten = inertia_with(10);
        assert!(best_of_ten <= single_run + 1e-6);
    }

    #[test]
    fn test_k_equals_n_samples() {
        let x = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 5.0, 5.0, 10.0, 10.0]).unwrap();
        let fitted = BisectingKMeans::<f64>::new(3)
            .with_random_state(42)
            .fit(&x, &())
            .unwrap();
        assert_eq!(fitted.n_clusters(), 3);
        let labels = fitted.labels();
        assert_ne!(labels[0], labels[1]);
        assert_ne!(labels[0], labels[2]);
        assert_ne!(labels[1], labels[2]);
    }
}