pub mod dist;
pub mod init;
use ndarray::{Array1, Array2, ArrayBase, Data, Ix1, Ix2, Zip};
pub use crate::kmeans::init::KMeansInit;
/// Configuration for k-means clustering; train it with [`KMeans::fit`].
///
/// Build one with [`KMeans::new_random`] or [`KMeans::new_plusplus`] and adjust it with
/// the `with_*` builder methods.
pub struct KMeans {
    /// Number of clusters to find; the constructors reject 0.
    k_clusters: usize,
    /// Strategy used to pick the starting centroids.
    init_fn: KMeansInit,
    /// Upper bound on the number of assign/update rounds performed by `fit`.
    max_iter: usize,
    /// `fit` stops early once the shift between consecutive centroid matrices, as measured
    /// by `dist::naive_euclidean_sq`, falls below this value.
    tolerance: f64,
}
impl KMeans {
#[must_use]
pub fn new_random(k_clusters: usize) -> Self {
assert_ne!(k_clusters, 0, "k_clusters must be > 0");
Self {
k_clusters,
init_fn: KMeansInit::Forgy,
tolerance: 1e-4,
max_iter: 300,
}
}
#[must_use]
pub fn new_plusplus(k_clusters: usize) -> Self {
Self {
init_fn: KMeansInit::PlusPlus,
..Self::new_random(k_clusters)
}
}
#[must_use]
pub fn with_tolerance(mut self, tolerance: f64) -> Self {
assert!(tolerance > 0.0);
self.tolerance = tolerance;
self
}
#[must_use]
pub fn with_max_iter(mut self, max_iter: usize) -> Self {
assert_ne!(max_iter, 0);
self.max_iter = max_iter;
self
}
pub fn fit(&self, data: &ArrayBase<impl Data<Elem = f64> + Sync, Ix2>) -> KMeansFitted {
let mut rng = rand::rng();
let mut centroids = self.init_fn.run(self.k_clusters, data, &mut rng);
let mut memberships = Array1::zeros(data.nrows());
for _ in 0..self.max_iter {
assign_clusters(data, ¢roids, &mut memberships);
let new_centroids = {
let mut new_centroids = Array2::<f64>::zeros((self.k_clusters, data.ncols()));
let mut counts = Array1::<f64>::zeros(self.k_clusters);
for (point, &membership) in data.outer_iter().zip(&memberships) {
let mut centroid = new_centroids.row_mut(membership);
centroid += &point;
counts[membership] += 1.0;
}
for (mut new_centroid, count) in new_centroids.outer_iter_mut().zip(counts) {
if count > 0.0 {
new_centroid /= count;
}
}
new_centroids
};
let distance = dist::naive_euclidean_sq(¢roids, &new_centroids);
centroids = new_centroids;
if distance < self.tolerance {
break;
}
}
KMeansFitted { centroids }
}
pub fn fit_predict(
&self,
data: &ArrayBase<impl Data<Elem = f64> + Sync, Ix2>,
) -> Array1<usize> {
self.fit(data).predict(data)
}
}
/// A trained k-means model; holds the final centroids, one per row.
#[derive(Debug, Clone)]
pub struct KMeansFitted {
    /// Learned centroids; row `i` is the centroid of cluster `i`.
    centroids: Array2<f64>,
}
impl KMeansFitted {
    /// Returns the fitted centroids, one per row.
    #[must_use]
    pub fn centroids(&self) -> &Array2<f64> {
        &self.centroids
    }

    /// Writes the index of the nearest centroid for every row of `data` into
    /// `memberships`, reusing the caller's buffer instead of allocating.
    ///
    /// # Panics
    ///
    /// Panics if `memberships.len() != data.nrows()`.
    pub fn predict_inplace(
        &self,
        data: &ArrayBase<impl Data<Elem = f64> + Sync, Ix2>,
        memberships: &mut Array1<usize>,
    ) {
        assert_eq!(
            data.nrows(),
            memberships.len(),
            "memberships must have one entry per row of data"
        );
        assign_clusters(data, self.centroids(), memberships);
    }

    /// Returns the index of the nearest centroid for every row of `data`.
    #[must_use]
    pub fn predict(&self, data: &ArrayBase<impl Data<Elem = f64> + Sync, Ix2>) -> Array1<usize> {
        let mut memberships = Array1::zeros(data.nrows());
        self.predict_inplace(data, &mut memberships);
        memberships
    }
}
/// Stores, for each row of `data`, the index of its nearest row in `centroids` into the
/// matching slot of `memberships`.
///
/// Rows are processed in parallel with `Zip::par_for_each`. `ndarray::Zip` panics if
/// `memberships.len()` does not equal `data.nrows()`.
fn assign_clusters(
    data: &ArrayBase<impl Data<Elem = f64> + Sync, Ix2>,
    centroids: &ArrayBase<impl Data<Elem = f64> + Sync, Ix2>,
    memberships: &mut Array1<usize>,
) {
    let rows = data.outer_iter();
    let paired = Zip::from(rows).and(memberships);
    paired.par_for_each(|row, slot| {
        let nearest = closest_centroid(&row, centroids);
        *slot = nearest.0;
    });
}
/// Finds the row of `centroids` nearest to `point`.
///
/// Returns `(index, distance)`, where `distance` is the value reported by
/// `dist::euclidean_sq_precomputed` (presumably squared Euclidean distance). Ties keep the
/// lowest index; if every distance is NaN the result is `(0, f64::INFINITY)`.
///
/// # Panics
///
/// Panics if `point` or `centroids` is empty. This can happen from the public API, e.g.
/// fitting or predicting on data with zero columns, so it is a precondition rather than
/// an unreachable branch.
fn closest_centroid(
    point: &ArrayBase<impl Data<Elem = f64>, Ix1>,
    centroids: &ArrayBase<impl Data<Elem = f64>, Ix2>,
) -> (usize, f64) {
    assert!(
        !point.is_empty() && !centroids.is_empty(),
        "closest_centroid requires a non-empty point and non-empty centroids"
    );
    // `point . point` does not depend on the centroid, so compute it once and hand it to
    // the precomputed-distance helper for every comparison.
    let point_dot = point.dot(point);
    let mut best = (0, f64::INFINITY);
    for (c_idx, centroid) in centroids.outer_iter().enumerate() {
        let dist = dist::euclidean_sq_precomputed(point, point_dot, &centroid);
        // Strict `<`: the first (lowest-index) minimum wins and NaN never replaces it.
        if dist < best.1 {
            best = (c_idx, dist);
        }
    }
    best
}