use super::helper_function::{preliminary_check, validate_max_iterations, validate_tolerance};
use crate::error::ModelError;
use crate::math::squared_euclidean_distance_row;
use crate::{Deserialize, Serialize};
use indicatif::{ProgressBar, ProgressStyle};
use ndarray::{Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Data, Ix2};
use ndarray_rand::rand::rngs::StdRng;
use ndarray_rand::rand::{Rng, SeedableRng};
use ndarray_rand::rand::{RngCore, rng};
use rayon::prelude::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
use std::ops::AddAssign;
/// Sample-count threshold at or above which `fit` runs the assignment step
/// through rayon's parallel iterator instead of a sequential loop.
const KMEANS_PARALLEL_THRESHOLD: usize = 1000;
/// K-means clustering model (Lloyd's algorithm with k-means++ initialization).
///
/// Construct with [`KMeans::new`], train with `fit`, then query labels,
/// centroids, and inertia through the generated accessors.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct KMeans {
    // Number of clusters (k); validated > 0 in `new`.
    n_clusters: usize,
    // Upper bound on Lloyd iterations performed by `fit`.
    max_iter: usize,
    // Relative inertia-change tolerance used for the convergence test.
    tol: f64,
    // Optional RNG seed for reproducible centroid initialization.
    random_seed: Option<u64>,
    // Fitted cluster centers, shape (n_clusters, n_features); None until `fit`.
    centroids: Option<Array2<f64>>,
    // Cluster index assigned to each training sample; None until `fit`.
    labels: Option<Array1<usize>>,
    // Final sum of squared distances to assigned centroids; None until `fit`.
    inertia: Option<f64>,
    // Number of iterations actually executed by the last `fit`; None until `fit`.
    n_iter: Option<usize>,
}
impl Default for KMeans {
fn default() -> Self {
KMeans::new(8, 300, 1e-4, None).expect("Default KMeans parameters should be valid")
}
}
impl KMeans {
pub fn new(
n_clusters: usize,
max_iterations: usize,
tolerance: f64,
random_seed: Option<u64>,
) -> Result<Self, ModelError> {
if n_clusters == 0 {
return Err(ModelError::InputValidationError(
"n_clusters must be greater than 0".to_string(),
));
}
validate_max_iterations(max_iterations)?;
validate_tolerance(tolerance)?;
Ok(KMeans {
n_clusters,
max_iter: max_iterations,
tol: tolerance,
random_seed,
centroids: None,
labels: None,
inertia: None,
n_iter: None,
})
}
    // Read-only accessors generated by project macros — presumably `get_field!`
    // returns the field by value and `get_field_as_ref!` returns Option<&T>;
    // verify against the macro definitions.
    get_field!(get_n_clusters, n_clusters, usize);
    get_field!(get_max_iterations, max_iter, usize);
    get_field!(get_tolerance, tol, f64);
    get_field!(get_random_seed, random_seed, Option<u64>);
    get_field!(get_actual_iterations, n_iter, Option<usize>);
    get_field_as_ref!(get_labels, labels, Option<&Array1<usize>>);
    get_field!(get_inertia, inertia, Option<f64>);
    get_field_as_ref!(get_centroids, centroids, Option<&Array2<f64>>);
fn closest_centroid(&self, x: &ArrayView2<f64>) -> Result<(usize, f64), ModelError> {
let sample = x.row(0);
let centroids = self.centroids.as_ref().unwrap();
let mut min_dist = f64::MAX;
let mut min_idx = 0;
for (i, centroid) in centroids.outer_iter().enumerate() {
let dist = squared_euclidean_distance_row(&sample, ¢roid);
if dist < min_dist {
min_dist = dist;
min_idx = i;
}
}
Ok((min_idx, min_dist))
}
fn init_centroids<S>(&mut self, data: &ArrayBase<S, Ix2>) -> Result<(), ModelError>
where
S: Data<Elem = f64>,
{
let n_samples = data.shape()[0];
let n_features = data.shape()[1];
let mut centroids = Array2::<f64>::zeros((self.n_clusters, n_features));
let mut rng = match self.random_seed {
Some(seed) => StdRng::seed_from_u64(seed),
None => StdRng::seed_from_u64(rng().next_u64()),
};
let first_center_idx = rng.random_range(0..n_samples);
centroids.row_mut(0).assign(&data.row(first_center_idx));
for k in 1..self.n_clusters {
let distances: Vec<f64> = data
.outer_iter()
.into_par_iter()
.map(|sample| {
centroids
.rows()
.into_iter()
.take(k)
.map(|centroid| squared_euclidean_distance_row(&sample, ¢roid))
.collect::<Vec<_>>()
.into_iter()
.fold(f64::MAX, f64::min)
})
.collect();
let total_dist: f64 = distances.iter().sum();
if total_dist == 0.0 {
let random_idx = rng.random_range(0..n_samples);
centroids.row_mut(k).assign(&data.row(random_idx));
continue;
}
let mut cumulative_dist = 0.0;
let choice = rng.random::<f64>() * total_dist;
for (i, &dist) in distances.iter().enumerate() {
cumulative_dist += dist;
if cumulative_dist >= choice {
centroids.row_mut(k).assign(&data.row(i));
break;
}
}
}
self.centroids = Some(centroids);
Ok(())
}
    /// Fits the model on `data` using Lloyd's algorithm with k-means++
    /// initialization, storing centroids, per-sample labels, final inertia
    /// (sum of squared distances to assigned centroids), and the iteration
    /// count. Renders an indicatif progress bar while running.
    ///
    /// # Errors
    /// Returns `ModelError::InputValidationError` when `data` fails the
    /// shared preliminary checks or has fewer samples than `n_clusters`.
    pub fn fit<S>(&mut self, data: &ArrayBase<S, Ix2>) -> Result<&mut Self, ModelError>
    where
        S: Data<Elem = f64>,
    {
        preliminary_check(data, None)?;
        let n_samples = data.shape()[0];
        let n_features = data.shape()[1];
        if n_samples < self.n_clusters {
            return Err(ModelError::InputValidationError(
                "Number of samples is less than number of clusters".to_string(),
            ));
        }
        // k-means++ seeding of the initial centroids.
        self.init_centroids(data)?;
        let mut labels = Array1::<usize>::zeros(n_samples);
        let mut prev_inertia: Option<f64> = None;
        let mut iter_count = 0;
        // Accumulators for the update step: per-cluster coordinate sums and
        // member counts, reused (refilled with zeros) each iteration.
        let mut new_centroids = Array2::<f64>::zeros((self.n_clusters, n_features));
        let mut counts = vec![0usize; self.n_clusters];
        let progress_bar = ProgressBar::new(self.max_iter as u64);
        progress_bar.set_style(
            ProgressStyle::default_bar()
                .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} | Inertia: {msg}")
                .expect("Failed to set progress bar template")
                .progress_chars("█▓░"),
        );
        progress_bar.set_message(format!("{:.6}", f64::INFINITY));
        for i in 0..self.max_iter {
            new_centroids.fill(0.0);
            counts.fill(0);
            // Assignment step: nearest-centroid index and squared distance for
            // one sample, against the centroids from the previous iteration.
            let compute_assignments =
                |sample: ArrayView1<f64>| -> Result<(usize, f64), ModelError> {
                    let mut min_dist = f64::MAX;
                    let mut min_cluster = 0;
                    for (cluster_idx, centroid) in
                        self.centroids.as_ref().unwrap().outer_iter().enumerate()
                    {
                        let dist = squared_euclidean_distance_row(&sample, &centroid);
                        if dist < min_dist {
                            min_dist = dist;
                            min_cluster = cluster_idx;
                        }
                    }
                    Ok((min_cluster, min_dist))
                };
            // Parallelize only for large inputs; rayon's scheduling overhead
            // outweighs the work below the threshold.
            let results: Result<Vec<(usize, f64)>, ModelError> =
                if n_samples >= KMEANS_PARALLEL_THRESHOLD {
                    data.outer_iter()
                        .into_par_iter()
                        .map(compute_assignments)
                        .collect()
                } else {
                    data.outer_iter().map(compute_assignments).collect()
                };
            let results = results?;
            // Accumulate inertia plus per-cluster sums/counts for the update step.
            let mut inertia = 0.0;
            for (sample_idx, &(cluster_idx, dist)) in results.iter().enumerate() {
                labels[sample_idx] = cluster_idx;
                inertia += dist;
                let sample = data.row(sample_idx);
                new_centroids.row_mut(cluster_idx).add_assign(&sample);
                counts[cluster_idx] += 1;
            }
            progress_bar.set_message(format!("{:.6}", inertia));
            progress_bar.inc(1);
            // Relative convergence test, floored by `tol` to avoid a zero
            // threshold when `prev` is tiny. The check runs BEFORE the centroid
            // update, so on convergence the centroids that produced the final
            // labels/inertia are kept as the fitted centroids.
            if let Some(prev) = prev_inertia {
                if (prev - inertia).abs() < self.tol * prev.max(self.tol) {
                    iter_count = i + 1;
                    self.inertia = Some(inertia);
                    break;
                }
            }
            prev_inertia = Some(inertia);
            iter_count = i + 1;
            // Update step: turn each cluster's coordinate sum into a mean.
            // Empty clusters are left as zero rows and repaired below.
            new_centroids
                .outer_iter_mut()
                .into_par_iter()
                .enumerate()
                .for_each(|(idx, mut centroid_row)| {
                    if counts[idx] > 0 {
                        let count_f = counts[idx] as f64;
                        centroid_row.par_mapv_inplace(|x| x / count_f);
                    }
                });
            // Empty-cluster repair: reseed each empty cluster with the sample
            // that is globally farthest from its assigned centroid. NOTE(review):
            // multiple empty clusters would all receive the same sample; the
            // Result-wrapped try_fold never actually errors.
            for (cluster_idx, &count) in counts.iter().enumerate() {
                if count == 0 {
                    let result: Result<Option<usize>, ModelError> = results
                        .iter()
                        .enumerate()
                        .try_fold(
                            None,
                            |acc, (sample_idx, &(assigned_cluster, dist))| match acc {
                                None => Ok(Some((sample_idx, assigned_cluster, dist))),
                                Some((best_idx, best_cluster, best_dist)) => {
                                    if dist > best_dist {
                                        Ok(Some((sample_idx, assigned_cluster, dist)))
                                    } else {
                                        Ok(Some((best_idx, best_cluster, best_dist)))
                                    }
                                }
                            },
                        )
                        .map(|opt| opt.map(|(idx, _, _)| idx));
                    if let Some(farthest_idx) = result? {
                        new_centroids
                            .row_mut(cluster_idx)
                            .assign(&data.row(farthest_idx));
                    } else {
                        // No samples at all (unreachable after the validation
                        // above): keep the previous centroid unchanged.
                        new_centroids
                            .row_mut(cluster_idx)
                            .assign(&self.centroids.as_ref().unwrap().row(cluster_idx));
                    }
                }
            }
            self.centroids = Some(new_centroids);
            // The accumulator was moved into `self.centroids`, so a fresh
            // buffer is allocated for the next iteration.
            new_centroids = Array2::<f64>::zeros((self.n_clusters, n_features));
        }
        // If the loop exhausted max_iter without the convergence break,
        // `self.inertia` is still unset; fall back to the last computed value.
        let final_inertia = self.inertia.unwrap_or_else(|| prev_inertia.unwrap_or(0.0));
        let convergence_status = if iter_count < self.max_iter {
            "Converged"
        } else {
            "Max iterations"
        };
        progress_bar.finish_with_message(format!(
            "{:.6} | {} | Iterations: {}",
            final_inertia, convergence_status, iter_count
        ));
        self.labels = Some(labels);
        if self.inertia.is_none() {
            self.inertia = prev_inertia;
        }
        self.n_iter = Some(iter_count);
        println!(
            "\nKMeans clustering completed: {} samples, {} clusters, {} iterations, final inertia: {:.6}",
            n_samples, self.n_clusters, iter_count, final_inertia
        );
        Ok(self)
    }
pub fn predict<S>(&self, data: &ArrayBase<S, Ix2>) -> Result<Array1<usize>, ModelError>
where
S: Data<Elem = f64>,
{
if self.centroids.is_none() {
return Err(ModelError::NotFitted);
}
if data.is_empty() {
return Err(ModelError::InputValidationError(
"Cannot predict on empty dataset".to_string(),
));
}
if data.iter().any(|&val| !val.is_finite()) {
return Err(ModelError::InputValidationError(
"Input data contains NaN or infinite values".to_string(),
));
}
let n_features = data.shape()[1];
let expected_features = self.centroids.as_ref().unwrap().shape()[1];
if n_features != expected_features {
return Err(ModelError::InputValidationError(format!(
"Feature dimension mismatch: expected {}, got {}",
expected_features, n_features
)));
}
let labels: Result<Vec<usize>, ModelError> = data
.outer_iter()
.into_par_iter()
.map(|sample| {
let sample_shaped = sample.to_shape((1, n_features)).map_err(|_| {
ModelError::InputValidationError(
"Failed to reshape sample during prediction".to_string(),
)
})?;
let sample_view = sample_shaped.view();
let (closest_idx, _) = self.closest_centroid(&sample_view)?;
Ok(closest_idx)
})
.collect();
Ok(Array1::from(labels?))
}
pub fn fit_predict<S>(&mut self, data: &ArrayBase<S, Ix2>) -> Result<Array1<usize>, ModelError>
where
S: Data<Elem = f64>,
{
self.fit(data)?;
Ok(self.labels.clone().unwrap())
}
    // Presumably generates serde-based `save`/`load` persistence methods for
    // `KMeans` — verify against the macro definition.
    model_save_and_load_methods!(KMeans);
}