use super::distance::{DistanceMetric, SquaredEuclidean};
use super::flat::DataRef;
use super::util;
use crate::error::{Error, Result};
use rand::prelude::*;
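/// Builder-style configuration for Lloyd's k-means with k-means++ style
/// seeding. Construct with [`Kmeans::new`] (squared-Euclidean metric) or
/// [`Kmeans::with_metric`], adjust options via the `with_*` methods, then
/// call [`Kmeans::fit`] or [`Kmeans::fit_predict`].
///
/// A minimal sketch of the intended flow, assuming row-major `Vec<Vec<f32>>`
/// input as used by the tests below (compile-checked only, hence `ignore`):
///
/// ```ignore
/// let data = vec![vec![0.0, 0.0], vec![0.1, 0.1], vec![10.0, 10.0]];
/// let labels = Kmeans::new(2).with_seed(42).fit_predict(&data)?;
/// assert_eq!(labels.len(), data.len());
/// ```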
#[derive(Debug, Clone)]
pub struct Kmeans<D: DistanceMetric = SquaredEuclidean> {
k: usize,
max_iter: usize,
tol: f64,
seed: Option<u64>,
seeding_alpha: f32,
metric: D,
init_centroids: Option<Vec<Vec<f32>>>,
}
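/// The result of a successful [`Kmeans::fit`]: final `centroids`, one label
/// per input row, the number of Lloyd iterations executed (`iters`), and the
/// within-cluster sum of squares recorded after each iteration
/// (`inertia_trace`, one entry per iteration).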
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct KmeansFit<D: DistanceMetric = SquaredEuclidean> {
pub centroids: Vec<Vec<f32>>,
pub labels: Vec<usize>,
pub iters: usize,
pub inertia_trace: Vec<f32>,
metric: D,
}
impl<D: DistanceMetric> KmeansFit<D> {
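    /// Assigns each row of `data` to its nearest centroid under the fitted
    /// metric. Returns `Error::EmptyInput` for zero rows and
    /// `Error::DimensionMismatch` if any row's dimension differs from the
    /// centroids'.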
pub fn predict(&self, data: &(impl DataRef + ?Sized)) -> Result<Vec<usize>> {
if data.n() == 0 {
return Err(Error::EmptyInput);
}
if self.centroids.is_empty() {
return Err(Error::InvalidParameter {
name: "centroids",
message: "must be non-empty",
});
}
let d = self.centroids[0].len();
let mut out = Vec::with_capacity(data.n());
for i in 0..data.n() {
let point = data.row(i);
if point.len() != d {
return Err(Error::DimensionMismatch {
expected: d,
found: point.len(),
});
}
out.push(util::assign_nearest(point, &self.centroids, &self.metric));
}
Ok(out)
}
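    /// Within-cluster sum of squares over `data`, using the stored `labels`.
    /// This assumes `data` is the training set the model was fitted on: row
    /// `i` is paired with `labels[i]`, so passing other data (or a different
    /// row order) gives meaningless results, or panics if `data` has more
    /// rows than stored labels.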
pub fn wcss(&self, data: &(impl DataRef + ?Sized)) -> f32 {
(0..data.n())
.map(|i| {
self.metric
.distance(data.row(i), &self.centroids[self.labels[i]])
})
.sum()
}
}
impl Kmeans<SquaredEuclidean> {
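    /// Creates a configuration with the default squared-Euclidean metric.
    /// Defaults: `max_iter = 100`, `tol = 1e-4`, `seeding_alpha = 2.0`,
    /// unseeded RNG. Panics if `k == 0`.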
pub fn new(k: usize) -> Self {
assert!(k > 0, "k must be at least 1");
Self {
k,
max_iter: 100,
tol: 1e-4,
seed: None,
seeding_alpha: 2.0,
metric: SquaredEuclidean,
init_centroids: None,
}
}
}
impl<D: DistanceMetric> Kmeans<D> {
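    /// Like [`Kmeans::new`], but with a caller-supplied distance metric
    /// (see the `Euclidean` and `CosineDistance` tests below). Panics if
    /// `k == 0`.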
pub fn with_metric(k: usize, metric: D) -> Self {
assert!(k > 0, "k must be at least 1");
Self {
k,
max_iter: 100,
tol: 1e-4,
seed: None,
seeding_alpha: 2.0,
metric,
init_centroids: None,
}
}
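    /// Sets the `alpha` parameter forwarded to the k-means++ style seeding
    /// (default: 2.0).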
pub fn with_seeding_alpha(mut self, alpha: f32) -> Self {
self.seeding_alpha = alpha;
self
}
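    /// Supplies explicit initial centroids (warm start), skipping seeding.
    /// `fit` asserts that `centroids.len() == k`.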
pub fn with_centroids(mut self, centroids: Vec<Vec<f32>>) -> Self {
self.init_centroids = Some(centroids);
self
}
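    /// Sets the maximum number of Lloyd iterations (default: 100).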
pub fn with_max_iter(mut self, max_iter: usize) -> Self {
self.max_iter = max_iter;
self
}
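    /// Sets the relative convergence tolerance (default: 1e-4). `fit` scales
    /// it by the data's mean variance and by `k`, so convergence behaviour is
    /// insensitive to the absolute scale of the input.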
pub fn with_tol(mut self, tol: f64) -> Self {
self.tol = tol;
self
}
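    /// Seeds the internal RNG so that fits are deterministic.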
pub fn with_seed(mut self, seed: u64) -> Self {
self.seed = Some(seed);
self
}
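    /// Runs Lloyd's algorithm on `data` and returns a [`KmeansFit`].
    ///
    /// Steps: validate the input (non-empty, consistent dimensions, finite
    /// values, `k <= n`), seed centroids with k-means++ style sampling unless
    /// warm-started via [`Kmeans::with_centroids`], then alternate assignment
    /// and mean-update passes until the total centroid shift drops below the
    /// scaled tolerance or `max_iter` is reached. Assignment dispatches to
    /// BLAS, GPU, expanded-form, or Hamerly/geometric kernels depending on
    /// enabled features and problem size; empty clusters are re-seeded from
    /// the farthest member of the largest cluster.
    ///
    /// A minimal fit-then-predict sketch (`train` and `new_points` are
    /// placeholder names for row-major data, compile-checked only):
    ///
    /// ```ignore
    /// let fit = Kmeans::new(2).with_seed(42).fit(&train)?;
    /// let labels = fit.predict(&new_points)?;
    /// ```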
pub fn fit(&self, data: &(impl DataRef + ?Sized)) -> Result<KmeansFit<D>> {
if data.n() == 0 {
return Err(Error::EmptyInput);
}
if self.k == 0 {
return Err(Error::InvalidParameter {
name: "k",
message: "must be at least 1",
});
}
let n = data.n();
let d = data.d();
if d == 0 {
return Err(Error::InvalidParameter {
name: "dimension",
message: "must be at least 1",
});
}
if self.k > n {
return Err(Error::InvalidClusterCount {
requested: self.k,
n_items: n,
});
}
for i in 0..n {
if data.row(i).len() != d {
return Err(Error::DimensionMismatch {
expected: d,
found: data.row(i).len(),
});
}
}
util::validate_finite(data)?;
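        // Scale the user tolerance by the data's mean variance and by k so
        // the convergence check is invariant to the absolute scale of the
        // input (see the scaling tests below).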
let effective_tol = (self.tol * util::mean_variance(data) * self.k as f64) as f32;
let mut rng = match self.seed {
Some(s) => StdRng::seed_from_u64(s),
None => StdRng::from_os_rng(),
};
let mut centroids = if let Some(ref init) = self.init_centroids {
assert_eq!(
init.len(),
self.k,
"init_centroids length ({}) must equal k ({})",
init.len(),
self.k
);
init.clone()
} else {
util::kmeanspp_init(data, self.k, &self.metric, self.seeding_alpha, &mut rng)
};
let mut labels = vec![0usize; n];
let mut new_centroids = vec![vec![0.0f32; d]; self.k];
let mut counts = vec![0usize; self.k];
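        // Per-point upper/lower distance bounds and per-centroid shifts for
        // Hamerly-style pruning of distance computations.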
let mut upper_bounds = vec![f32::MAX; n];
let mut lower_bounds = vec![0.0f32; n];
let mut centroid_shifts = vec![0.0f32; self.k];
let mut sums_f64 = vec![vec![0.0f64; d]; self.k];
let mut flat_buf: Vec<f32> = Vec::with_capacity(self.k * d);
let mut inertia_trace: Vec<f32> = Vec::with_capacity(self.max_iter);
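        // Metrics supporting the expanded form allow assignment via
        // ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2, with point norms
        // precomputed once.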
let use_expanded = self.metric.supports_expanded_form();
let data_norms: Vec<f32> = if use_expanded {
util::squared_norms(data)
} else {
Vec::new()
};
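        // Optional accelerated assignment backends, gated on feature flags
        // and a minimum problem size (n * k).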
#[cfg(feature = "gpu")]
let gpu_assigner = if self.metric.supports_expanded_form() && n * self.k >= 500_000 {
let data_flat = super::gpu::flatten(data);
super::gpu::GpuAssigner::new(&data_flat, n, self.k, d)
} else {
None
};
#[cfg(feature = "blas")]
let use_blas = n * self.k >= 100_000 && self.metric.supports_expanded_form();
#[cfg(feature = "blas")]
let blas_data = if use_blas {
let fd = super::flat::FlatMatrix::from_data(data);
let xn = fd.row_norms_sq();
Some((fd, xn))
} else {
None
};
let mut iters = 0usize;
for iter in 0..self.max_iter {
iters = iter + 1;
for c in &mut new_centroids {
c.fill(0.0);
}
counts.fill(0);
#[cfg(feature = "blas")]
let blas_used = if iter == 0 && use_blas {
if let Some((ref fd, ref xn)) = blas_data {
                    let fc = super::flat::FlatMatrix::from_data(&centroids);
let cn = fc.row_norms_sq();
let (new_labels, new_upper) = fd.blas_assign(&fc, xn, &cn);
labels.copy_from_slice(&new_labels);
for i in 0..n {
upper_bounds[i] = new_upper[i];
lower_bounds[i] = 0.0;
}
true
} else {
false
}
} else {
false
};
#[cfg(not(feature = "blas"))]
let blas_used = false;
#[cfg(feature = "gpu")]
let gpu_used = if !blas_used {
if let Some(ref assigner) = gpu_assigner {
                    let centroids_flat = super::gpu::flatten(&centroids);
                    let gpu_labels = assigner.assign(&centroids_flat);
labels.copy_from_slice(&gpu_labels);
true
} else {
false
}
} else {
false
};
#[cfg(not(feature = "gpu"))]
let gpu_used = blas_used;
if !gpu_used {
let expanded_used = if iter == 0 && use_expanded {
                    let centroid_norms = util::squared_norms(&centroids);
                    #[cfg(feature = "parallel")]
                    let (new_labels, new_upper, new_lower) = util::assign_expanded_parallel(
                        data,
                        &centroids,
                        &data_norms,
                        &centroid_norms,
                    );
                    #[cfg(not(feature = "parallel"))]
                    let (new_labels, new_upper, new_lower) =
                        util::assign_expanded(data, &centroids, &data_norms, &centroid_norms);
labels.copy_from_slice(&new_labels);
upper_bounds.copy_from_slice(&new_upper);
lower_bounds.copy_from_slice(&new_lower);
true
} else {
false
};
if !expanded_used {
#[cfg(feature = "parallel")]
{
                        util::hamerly_assign_parallel(
                            data,
                            &centroids,
                            &mut labels,
                            &mut upper_bounds,
                            &mut lower_bounds,
                            &centroid_shifts,
                            &self.metric,
                            iter == 0,
                            &mut flat_buf,
                        );
}
#[cfg(not(feature = "parallel"))]
if self.k <= 64 {
                    util::geometric_assign(
                        data,
                        &centroids,
                        &mut labels,
                        &centroid_shifts,
                        &self.metric,
                        iter == 0,
                    );
} else {
                    util::hamerly_assign(
                        data,
                        &centroids,
                        &mut labels,
                        &mut upper_bounds,
                        &mut lower_bounds,
                        &centroid_shifts,
                        &self.metric,
                        iter == 0,
                        &mut flat_buf,
                    );
}
}
}
for s in &mut sums_f64 {
s.fill(0.0);
}
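            // Accumulate cluster sums in f64 to limit rounding error when
            // averaging many f32 rows.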
            #[allow(clippy::needless_range_loop)]
            for i in 0..n {
let k = labels[i];
let row = data.row(i);
for j in 0..d {
sums_f64[k][j] += row[j] as f64;
}
counts[k] += 1;
}
for k in 0..self.k {
if counts[k] > 0 {
let divisor = counts[k] as f64;
for j in 0..d {
new_centroids[k][j] = (sums_f64[k][j] / divisor) as f32;
}
            } else {
                // Empty cluster: re-seed from the point farthest from the
                // mean of the largest cluster. The mean is rebuilt from
                // sums_f64 here because new_centroids[largest] has not been
                // written yet when largest comes later in iteration order.
                let largest = counts
                    .iter()
                    .enumerate()
                    .max_by_key(|(_, &c)| c)
                    .map(|(idx, _)| idx)
                    .unwrap_or(0);
                let divisor = counts[largest].max(1) as f64;
                let largest_mean: Vec<f32> = sums_f64[largest]
                    .iter()
                    .map(|&s| (s / divisor) as f32)
                    .collect();
                let mut farthest_idx = 0;
                let mut farthest_dist = -1.0f32;
                for (i, &label) in labels.iter().enumerate() {
                    if label == largest {
                        let dist = self.metric.distance(data.row(i), &largest_mean);
                        if dist > farthest_dist {
                            farthest_dist = dist;
                            farthest_idx = i;
                        }
                    }
                }
                new_centroids[k] = data.row(farthest_idx).to_vec();
            }
}
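            // Metrics such as cosine distance keep centroids on the unit
            // sphere (see cosine_centroids_are_normalized below).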
if self.metric.normalize_centroids() {
for c in &mut new_centroids {
let norm: f32 = c.iter().map(|&x| x * x).sum::<f32>().sqrt();
if norm > f32::EPSILON {
for val in c.iter_mut() {
*val /= norm;
}
}
}
}
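            // Total centroid movement this iteration drives both Hamerly's
            // bound updates and the convergence check.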
            let mut shift = 0.0f32;
            for k in 0..self.k {
                // `moved` avoids shadowing the dimension `d` defined above.
                let moved = self.metric.distance(&centroids[k], &new_centroids[k]);
                centroid_shifts[k] = moved;
                shift += moved;
            }
std::mem::swap(&mut centroids, &mut new_centroids);
let wcss: f32 = labels
.iter()
.enumerate()
                .map(|(i, &l)| self.metric.distance(data.row(i), &centroids[l]))
.sum();
inertia_trace.push(wcss);
if shift < effective_tol {
break;
}
}
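        // Final pass: relabel against the final centroids so that
        // `KmeansFit::predict` on the training data reproduces `labels`
        // exactly (see predict_matches_fit_labels).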
for (i, label) in labels.iter_mut().enumerate() {
            *label = util::assign_nearest(data.row(i), &centroids, &self.metric);
}
Ok(KmeansFit {
centroids,
labels,
iters,
inertia_trace,
metric: self.metric.clone(),
})
}
}
impl<D: DistanceMetric> Kmeans<D> {
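    /// Convenience wrapper around [`Kmeans::fit`] that returns only the
    /// per-row labels.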
pub fn fit_predict(&self, data: &(impl DataRef + ?Sized)) -> Result<Vec<usize>> {
Ok(self.fit(data)?.labels)
}
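    /// Returns the configured number of clusters `k`.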
pub fn n_clusters(&self) -> usize {
self.k
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::cluster::distance::Euclidean;
#[test]
fn test_kmeans_basic() {
let data = vec![
vec![0.0, 0.0],
vec![0.1, 0.1],
vec![10.0, 10.0],
vec![10.1, 10.1],
];
let kmeans = Kmeans::new(2).with_seed(42);
let labels = kmeans.fit_predict(&data).unwrap();
assert_eq!(labels[0], labels[1]);
assert_eq!(labels[2], labels[3]);
assert_ne!(labels[0], labels[2]);
}
#[test]
fn test_kmeans_all_points_assigned() {
let data: Vec<Vec<f32>> = (0..50)
.map(|i| vec![i as f32 * 0.1, (i % 5) as f32])
.collect();
let kmeans = Kmeans::new(5).with_seed(123);
let labels = kmeans.fit_predict(&data).unwrap();
assert_eq!(labels.len(), data.len());
for &label in &labels {
assert!(label < 5, "label {} out of range", label);
}
}
#[test]
fn test_kmeans_k_equals_n() {
let data = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]];
let kmeans = Kmeans::new(3).with_seed(42);
let labels = kmeans.fit_predict(&data).unwrap();
let unique: std::collections::HashSet<_> = labels.iter().collect();
assert_eq!(unique.len(), 3);
}
#[test]
fn test_kmeans_deterministic_with_seed() {
let data = vec![
vec![0.0, 0.0],
vec![0.1, 0.1],
vec![10.0, 10.0],
vec![10.1, 10.1],
];
let kmeans1 = Kmeans::new(2).with_seed(42);
let kmeans2 = Kmeans::new(2).with_seed(42);
let labels1 = kmeans1.fit_predict(&data).unwrap();
let labels2 = kmeans2.fit_predict(&data).unwrap();
assert_eq!(labels1, labels2, "same seed should give same result");
}
#[test]
fn test_kmeans_scaling_invariant() {
let data = vec![
vec![0.0, 0.0],
vec![0.1, 0.1],
vec![10.0, 10.0],
vec![10.1, 10.1],
];
let scaled: Vec<Vec<f32>> = data
.iter()
.map(|v| v.iter().map(|x| x * 100.0).collect())
.collect();
let kmeans1 = Kmeans::new(2).with_seed(42);
let kmeans2 = Kmeans::new(2).with_seed(42);
let labels1 = kmeans1.fit_predict(&data).unwrap();
let labels2 = kmeans2.fit_predict(&scaled).unwrap();
assert_eq!(labels1[0], labels1[1]);
assert_eq!(labels2[0], labels2[1]);
assert_eq!(labels1[2], labels1[3]);
assert_eq!(labels2[2], labels2[3]);
assert_ne!(labels1[0], labels1[2]);
assert_ne!(labels2[0], labels2[2]);
}
#[test]
fn test_kmeans_empty_input_error() {
let data: Vec<Vec<f32>> = vec![];
let kmeans = Kmeans::new(2);
let result = kmeans.fit_predict(&data);
assert!(result.is_err());
}
#[test]
fn test_kmeans_alpha_seeding() {
let data = vec![
vec![0.0, 0.0],
vec![1.0, 1.0],
vec![10.0, 10.0],
vec![11.0, 11.0],
];
let kmeans = Kmeans::new(2).with_seed(42).with_seeding_alpha(4.0);
let labels = kmeans.fit_predict(&data).unwrap();
assert_eq!(labels[0], labels[1]);
assert_eq!(labels[2], labels[3]);
assert_ne!(labels[0], labels[2]);
}
#[test]
fn test_kmeans_with_euclidean() {
let data = vec![
vec![0.0, 0.0],
vec![0.1, 0.1],
vec![10.0, 10.0],
vec![10.1, 10.1],
];
let kmeans = Kmeans::with_metric(2, Euclidean).with_seed(42);
let labels = kmeans.fit_predict(&data).unwrap();
assert_eq!(labels[0], labels[1]);
assert_eq!(labels[2], labels[3]);
assert_ne!(labels[0], labels[2]);
}
#[test]
fn test_kmeans_fit_predict_with_custom_metric() {
let data = vec![
vec![0.0, 0.0],
vec![0.1, 0.1],
vec![10.0, 10.0],
vec![10.1, 10.1],
];
let kmeans = Kmeans::with_metric(2, Euclidean).with_seed(42);
let fit = kmeans.fit(&data).unwrap();
let new_data = vec![vec![0.05, 0.05], vec![10.05, 10.05]];
let predicted = fit.predict(&new_data).unwrap();
assert_ne!(predicted[0], predicted[1]);
}
#[test]
fn nan_input_rejected() {
let data = vec![vec![0.0, f32::NAN], vec![1.0, 1.0]];
let result = Kmeans::new(2).with_seed(42).fit_predict(&data);
assert!(result.is_err());
}
#[test]
fn inf_input_rejected() {
let data = vec![vec![0.0, 0.0], vec![1.0, f32::INFINITY]];
let result = Kmeans::new(2).with_seed(42).fit_predict(&data);
assert!(result.is_err());
}
#[test]
fn all_identical_points() {
let data = vec![vec![5.0, 5.0]; 10];
let fit = Kmeans::new(2).with_seed(42).fit(&data).unwrap();
assert!(
fit.iters <= 3,
"expected fast convergence, got {} iters",
fit.iters
);
}
#[test]
fn k1_centroid_equals_mean() {
let data = vec![vec![0.0, 0.0], vec![2.0, 4.0], vec![4.0, 8.0]];
let fit = Kmeans::new(1).with_seed(42).fit(&data).unwrap();
let centroid = &fit.centroids[0];
assert!(
(centroid[0] - 2.0).abs() < 1e-4,
"mean[0] should be 2.0, got {}",
centroid[0]
);
assert!(
(centroid[1] - 4.0).abs() < 1e-4,
"mean[1] should be 4.0, got {}",
centroid[1]
);
}
#[test]
fn self_identity_oracle() {
let data = vec![
vec![0.0, 0.0],
vec![0.1, 0.1],
vec![10.0, 10.0],
vec![10.1, 10.1],
];
let fit = Kmeans::new(2).with_seed(42).fit(&data).unwrap();
let predicted = fit.predict(&fit.centroids).unwrap();
for (k, &label) in predicted.iter().enumerate() {
assert_eq!(label, k, "centroid {k} should map to cluster {k}");
}
}
#[test]
fn scalar_data_d1() {
let data = vec![vec![0.0], vec![0.1], vec![10.0], vec![10.1]];
let labels = Kmeans::new(2).with_seed(42).fit_predict(&data).unwrap();
assert_eq!(labels[0], labels[1]);
assert_ne!(labels[0], labels[2]);
}
#[test]
fn cosine_centroids_are_normalized() {
use crate::cluster::distance::CosineDistance;
let data = vec![
vec![1.0, 0.1],
vec![2.0, 0.2],
vec![0.1, 1.0],
vec![0.2, 2.0],
];
let fit = Kmeans::with_metric(2, CosineDistance)
.with_seed(42)
.fit(&data)
.unwrap();
for (k, c) in fit.centroids.iter().enumerate() {
let norm: f32 = c.iter().map(|&x| x * x).sum::<f32>().sqrt();
assert!(
(norm - 1.0).abs() < 1e-4,
"centroid {k} should be unit-normalized, got norm={norm}"
);
}
}
#[test]
fn large_k_stress() {
use rand::prelude::*;
let mut rng = StdRng::seed_from_u64(42);
let data: Vec<Vec<f32>> = (0..5000)
.map(|_| (0..16).map(|_| rng.random::<f32>()).collect())
.collect();
let labels = Kmeans::new(100)
.with_max_iter(5)
.with_seed(42)
.fit_predict(&data)
.unwrap();
assert_eq!(labels.len(), 5000);
for &l in &labels {
assert!(l < 100);
}
}
#[test]
fn empty_cluster_reinit() {
let data = vec![
vec![0.0, 0.0],
vec![0.01, 0.01],
vec![10.0, 0.0],
vec![10.01, 0.01],
vec![0.0, 10.0],
vec![0.01, 10.01],
];
let fit = Kmeans::new(4).with_seed(42).fit(&data).unwrap();
assert_eq!(fit.centroids.len(), 4);
assert_eq!(fit.labels.len(), 6);
}
#[test]
fn high_dim_few_points() {
let data = vec![
vec![0.0; 200],
{
let mut v = vec![0.0; 200];
v[0] = 1.0;
v
},
vec![10.0; 200],
{
let mut v = vec![10.0; 200];
v[0] = 11.0;
v
},
];
let labels = Kmeans::new(2).with_seed(42).fit_predict(&data).unwrap();
assert_eq!(labels[0], labels[1]);
assert_eq!(labels[2], labels[3]);
assert_ne!(labels[0], labels[2]);
}
#[test]
fn single_point_k1() {
let data = vec![vec![42.0, 7.0]];
let fit = Kmeans::new(1).fit(&data).unwrap();
assert_eq!(fit.centroids.len(), 1);
assert!((fit.centroids[0][0] - 42.0).abs() < 1e-6);
assert!((fit.centroids[0][1] - 7.0).abs() < 1e-6);
}
#[test]
fn wcss_better_than_random() {
use rand::prelude::*;
let mut rng = StdRng::seed_from_u64(42);
let data: Vec<Vec<f32>> = (0..200)
.map(|i| {
let center = if i < 100 { 0.0 } else { 20.0 };
vec![center + rng.random::<f32>(), center + rng.random::<f32>()]
})
.collect();
let fit = Kmeans::new(2).with_seed(42).fit(&data).unwrap();
let kmeans_wcss = fit.wcss(&data);
let mut rng2 = StdRng::seed_from_u64(99);
let idx_a = rng2.random_range(0..data.len());
let idx_b = loop {
let idx = rng2.random_range(0..data.len());
if idx != idx_a {
break idx;
}
};
let rand_centroids = [&data[idx_a], &data[idx_b]];
let random_wcss: f32 = data
.iter()
.map(|p| {
let d0 = SquaredEuclidean.distance(p, rand_centroids[0]);
let d1 = SquaredEuclidean.distance(p, rand_centroids[1]);
d0.min(d1)
})
.sum();
assert!(
kmeans_wcss < random_wcss,
"k-means WCSS ({kmeans_wcss}) should be less than random-centroid baseline ({random_wcss})"
);
}
#[test]
fn warm_start_convergence() {
let data = vec![
vec![0.0, 0.0],
vec![0.1, 0.1],
vec![10.0, 10.0],
vec![10.1, 10.1],
];
let fit1 = Kmeans::new(2).with_seed(42).fit(&data).unwrap();
let fit2 = Kmeans::new(2)
.with_centroids(fit1.centroids.clone())
.fit(&data)
.unwrap();
assert!(
fit2.iters <= 2,
"warm-start should converge fast, got {} iters",
fit2.iters
);
}
#[test]
fn centroids_approximate_means() {
let data = vec![
vec![0.0, 0.0],
vec![0.1, 0.1],
vec![0.2, 0.0],
vec![10.0, 10.0],
vec![10.1, 10.1],
vec![10.2, 10.0],
];
let fit = Kmeans::new(2).with_seed(42).fit(&data).unwrap();
for k in 0..2 {
let members: Vec<&Vec<f32>> = data
.iter()
.zip(fit.labels.iter())
.filter(|(_, &l)| l == k)
.map(|(p, _)| p)
.collect();
if members.is_empty() {
continue;
}
let d = members[0].len();
for j in 0..d {
let mean: f32 = members.iter().map(|p| p[j]).sum::<f32>() / members.len() as f32;
assert!(
(fit.centroids[k][j] - mean).abs() < 0.5,
"centroid[{k}][{j}] = {}, expected ~{mean}",
fit.centroids[k][j]
);
}
}
}
#[test]
fn predict_matches_fit_labels() {
use rand::prelude::*;
let mut rng = StdRng::seed_from_u64(42);
let data: Vec<Vec<f32>> = (0..100)
.map(|_| vec![rng.random::<f32>() * 10.0, rng.random::<f32>() * 10.0])
.collect();
let fit = Kmeans::new(5).with_seed(42).fit(&data).unwrap();
let predicted = fit.predict(&data).unwrap();
assert_eq!(
fit.labels, predicted,
"predict on training data must match fit labels"
);
}
#[test]
fn idempotent_refit() {
let data = vec![
vec![0.0, 0.0],
vec![0.1, 0.1],
vec![10.0, 10.0],
vec![10.1, 10.1],
vec![20.0, 20.0],
vec![20.1, 20.1],
];
let fit1 = Kmeans::new(3).with_seed(42).fit(&data).unwrap();
let fit2 = Kmeans::new(3)
.with_centroids(fit1.centroids.clone())
.fit(&data)
.unwrap();
assert_eq!(fit1.labels, fit2.labels, "refit should produce same labels");
}
#[test]
fn extreme_scale_small() {
let data = vec![
vec![0.0, 0.0],
vec![0.1, 0.1],
vec![10.0, 10.0],
vec![10.1, 10.1],
];
let scaled: Vec<Vec<f32>> = data
.iter()
.map(|v| v.iter().map(|&x| x * 1e-6).collect())
.collect();
let labels1 = Kmeans::new(2).with_seed(42).fit_predict(&data).unwrap();
let labels2 = Kmeans::new(2).with_seed(42).fit_predict(&scaled).unwrap();
assert_eq!(labels1[0] == labels1[1], labels2[0] == labels2[1]);
assert_eq!(labels1[2] == labels1[3], labels2[2] == labels2[3]);
assert_ne!(labels1[0], labels1[2]);
assert_ne!(labels2[0], labels2[2]);
}
#[test]
fn precision_at_scale() {
use rand::prelude::*;
let mut rng = StdRng::seed_from_u64(42);
let n = 10000;
let data: Vec<Vec<f32>> = (0..n)
.map(|i| {
let center = if i < n / 2 { 0.0 } else { 10.0 };
vec![
center + rng.random::<f32>() * 0.1,
center + rng.random::<f32>() * 0.1,
]
})
.collect();
let fit = Kmeans::new(2).with_seed(42).fit(&data).unwrap();
for c in &fit.centroids {
let near_zero = (c[0] - 0.05).abs() < 0.1 && (c[1] - 0.05).abs() < 0.1;
let near_ten = (c[0] - 10.05).abs() < 0.1 && (c[1] - 10.05).abs() < 0.1;
assert!(
near_zero || near_ten,
"centroid {:?} should be near (0.05, 0.05) or (10.05, 10.05)",
c
);
}
}
}
#[cfg(test)]
mod proptests {
use super::*;
use proptest::prelude::*;
fn arb_data(max_n: usize, d: usize) -> impl Strategy<Value = Vec<Vec<f32>>> {
proptest::collection::vec(
proptest::collection::vec(-100.0f32..100.0, d..=d),
3..=max_n,
)
}
proptest! {
#[test]
fn labels_in_range(data in arb_data(50, 4)) {
let k = 3.min(data.len());
let labels = Kmeans::new(k).with_seed(42).with_max_iter(5)
.fit_predict(&data).unwrap();
prop_assert_eq!(labels.len(), data.len());
for &l in &labels {
prop_assert!(l < k, "label {} out of range [0, {})", l, k);
}
}
#[test]
fn predict_consistent(data in arb_data(30, 3)) {
let k = 2.min(data.len());
let fit = Kmeans::new(k).with_seed(42).with_max_iter(5)
.fit(&data).unwrap();
let predicted = fit.predict(&data).unwrap();
prop_assert_eq!(&fit.labels, &predicted);
}
#[test]
fn wcss_nonneg(data in arb_data(30, 3)) {
let k = 2.min(data.len());
let fit = Kmeans::new(k).with_seed(42).with_max_iter(5)
.fit(&data).unwrap();
let wcss = fit.wcss(&data);
prop_assert!(wcss >= 0.0, "WCSS must be >= 0, got {}", wcss);
prop_assert!(wcss.is_finite(), "WCSS must be finite");
}
#[test]
fn inertia_monotone_decreasing(data in arb_data(30, 3)) {
let k = 2.min(data.len());
let fit = Kmeans::new(k).with_seed(42).with_max_iter(20)
.fit(&data).unwrap();
let trace = &fit.inertia_trace;
prop_assert!(!trace.is_empty(), "inertia trace must not be empty");
for w in trace.windows(2) {
prop_assert!(w[1] <= w[0] + 1e-5,
"inertia increased: {} -> {}", w[0], w[1]);
}
}
#[test]
fn inertia_trace_length(data in arb_data(20, 3)) {
let k = 2.min(data.len());
let fit = Kmeans::new(k).with_seed(42).with_max_iter(10)
.fit(&data).unwrap();
prop_assert_eq!(fit.inertia_trace.len(), fit.iters,
"trace length {} != iters {}", fit.inertia_trace.len(), fit.iters);
}
}
}