use std::cmp::Ordering;
use crate::k_means::hyperparameters::{KMeansHyperParams, KMeansHyperParamsBuilder};
use crate::{
k_means::errors::{KMeansError, Result},
KMeansInit,
};
use linfa::{prelude::*, DatasetBase, Float};
use ndarray::{Array1, Array2, ArrayBase, Axis, Data, DataMut, Ix1, Ix2, Zip};
use ndarray_rand::rand::Rng;
use ndarray_rand::rand::SeedableRng;
use ndarray_stats::DeviationExt;
use rand_isaac::Isaac64Rng;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
/// A fitted K-Means clustering model.
///
/// Holds the cluster centroids (one centroid per row), the number of
/// observations counted for each cluster, and the inertia measured while
/// fitting.
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
#[derive(Clone, Debug, PartialEq)]
pub struct KMeans<F: Float> {
/// Centroid coordinates: row `i` is the centroid of cluster `i`.
centroids: Array2<F>,
/// Number of observations attributed to each cluster, stored as `F`.
/// Incremental fitting accumulates these counts across batches.
cluster_count: Array1<F>,
/// Sum of squared distances from each observation to its closest centroid,
/// divided by the number of samples (for incremental fitting: of the most
/// recently processed batch).
inertia: F,
}
impl<F: Float> KMeans<F> {
    /// Starts configuring a K-Means fit with `nclusters` clusters, using the
    /// default `Isaac64Rng` random number generator.
    pub fn params(nclusters: usize) -> KMeansHyperParamsBuilder<F, Isaac64Rng> {
        KMeansHyperParams::new(nclusters)
    }

    /// Starts configuring a K-Means fit with `nclusters` clusters, driven by the
    /// caller-supplied random number generator `rng`.
    pub fn params_with_rng<R: Rng + Clone>(
        nclusters: usize,
        rng: R,
    ) -> KMeansHyperParamsBuilder<F, R> {
        KMeansHyperParams::new_with_rng(nclusters, rng)
    }

    /// Borrows the centroid matrix (one centroid per row).
    pub fn centroids(&self) -> &Array2<F> {
        let centroid_matrix = &self.centroids;
        centroid_matrix
    }

    /// Borrows the per-cluster observation counts.
    pub fn cluster_count(&self) -> &Array1<F> {
        let counts = &self.cluster_count;
        counts
    }

    /// Returns the inertia recorded when the model was fitted.
    pub fn inertia(&self) -> F {
        let recorded = self.inertia;
        recorded
    }
}
impl<F: Float, R: Rng + Clone + SeedableRng, D: Data<Elem = F>, T>
Fit<ArrayBase<D, Ix2>, T, KMeansError> for KMeansHyperParams<F, R>
{
type Object = KMeans<F>;
fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
let mut rng = self.rng();
let observations = dataset.records().view();
let n_samples = dataset.nsamples();
let mut min_inertia = F::infinity();
let mut best_centroids = None;
let mut best_iter = None;
let mut memberships = Array1::zeros(n_samples);
let mut dists = Array1::zeros(n_samples);
let n_runs = self.n_runs();
for _ in 0..n_runs {
let mut inertia = min_inertia;
let mut centroids = self
.init_method()
.run(self.n_clusters(), observations, &mut rng);
let mut converged_iter: Option<u64> = None;
for n_iter in 0..self.max_n_iterations() {
update_memberships_and_dists(
¢roids,
&observations,
&mut memberships,
&mut dists,
);
let new_centroids = compute_centroids(¢roids, &observations, &memberships);
inertia = dists.sum();
let distance = centroids
.sq_l2_dist(&new_centroids)
.expect("Failed to compute distance");
centroids = new_centroids;
if distance < self.tolerance() {
converged_iter = Some(n_iter);
break;
}
}
if inertia < min_inertia {
min_inertia = inertia;
best_centroids = Some(centroids.clone());
best_iter = converged_iter;
}
}
match best_iter {
Some(_n_iter) => match best_centroids {
Some(centroids) => {
let mut cluster_count = Array1::zeros(self.n_clusters());
memberships
.iter()
.for_each(|&c| cluster_count[c] += F::one());
Ok(KMeans {
centroids,
cluster_count,
inertia: min_inertia / F::cast(dataset.nsamples()),
})
}
_ => Err(KMeansError::InertiaError(
"No inertia improvement (-inf)".to_string(),
)),
},
None => Err(KMeansError::NotConverged(format!(
"KMeans fitting algorithm {} did not converge. Try different init parameters, \
or increase max_n_iterations, tolerance or check for degenerate data.",
(n_runs + 1)
))),
}
}
}
impl<'a, F: Float, R: Rng + Clone + SeedableRng, D: Data<Elem = F>, T>
    IncrementalFit<'a, ArrayBase<D, Ix2>, T> for KMeansHyperParams<F, R>
{
    type ObjectIn = Option<KMeans<F>>;
    type ObjectOut = (KMeans<F>, bool);

    /// Performs one incremental (batch-by-batch) K-Means update on `dataset`.
    ///
    /// If `model` is `None`, a starting model is created first:
    /// * precomputed centroids are used as-is;
    /// * otherwise the init method is run `n_runs` times, and the candidate with the
    ///   smallest summed squared distance to its closest centroid is kept.
    ///
    /// Every observation is then assigned to its closest centroid. Each centroid is
    /// moved towards its new members as a running mean, using the cumulative counts in
    /// `cluster_count`. The inertia is set to the mean squared distance of this batch.
    ///
    /// Returns the updated model, plus `true` when the squared L2 shift of the
    /// centroids is below `tolerance`.
    ///
    /// # Panics
    ///
    /// Panics if a new model must be initialised and `n_runs` is 0.
    fn fit_with(
        &self,
        model: Self::ObjectIn,
        dataset: &'a DatasetBase<ArrayBase<D, Ix2>, T>,
    ) -> Self::ObjectOut {
        let mut rng = self.rng();
        let observations = dataset.records().view();
        let n_samples = dataset.nsamples();

        let mut model = match model {
            Some(model) => model,
            None => {
                let centroids = if let KMeansInit::Precomputed(centroids) = self.init_method() {
                    centroids.clone()
                } else {
                    let mut dists = Array1::zeros(n_samples);
                    (0..self.n_runs())
                        .map(|_| {
                            let centroids =
                                self.init_method()
                                    .run(self.n_clusters(), observations, &mut rng);
                            update_min_dists(&centroids, &observations, &mut dists);
                            (centroids, dists.sum())
                        })
                        // The comparator never reports `Equal`, so on ties the later
                        // candidate wins.
                        .min_by(|(_, d1), (_, d2)| {
                            if d1 < d2 {
                                Ordering::Less
                            } else {
                                Ordering::Greater
                            }
                        })
                        .expect("n_runs must be at least 1 to initialise the centroids")
                        .0
                };
                KMeans {
                    centroids,
                    cluster_count: Array1::zeros(self.n_clusters()),
                    inertia: F::zero(),
                }
            }
        };

        let mut memberships = Array1::zeros(n_samples);
        let mut dists = Array1::zeros(n_samples);
        update_memberships_and_dists(
            &model.centroids,
            &observations,
            &mut memberships,
            &mut dists,
        );
        let new_centroids = compute_centroids_incremental(
            &observations,
            &memberships,
            &model.centroids,
            &mut model.cluster_count,
        );
        model.inertia = dists.sum() / F::cast(n_samples);
        let dist = model.centroids.sq_l2_dist(&new_centroids).unwrap();
        model.centroids = new_centroids;
        (model, dist < self.tolerance())
    }
}
impl<F: Float, R: Rng + SeedableRng + Clone> KMeansHyperParamsBuilder<F, R> {
    /// Builds the hyperparameters and fits a [`KMeans`] model on `dataset` in one step.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by fitting the built `KMeansHyperParams`.
    pub fn fit<D: Data<Elem = F>, T>(
        self,
        dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
    ) -> Result<KMeans<F>> {
        self.build().fit(dataset)
    }
}
impl<F: Float, D: Data<Elem = F>> Transformer<&ArrayBase<D, Ix2>, Array1<F>> for KMeans<F> {
    /// Maps each observation (row) to its squared L2 distance from the closest centroid.
    fn transform(&self, observations: &ArrayBase<D, Ix2>) -> Array1<F> {
        let rows = observations.view();
        let mut closest_sq_dists = Array1::zeros(rows.nrows());
        update_min_dists(&self.centroids, &rows, &mut closest_sq_dists);
        closest_sq_dists
    }
}
impl<F: Float, D: Data<Elem = F>> PredictRef<ArrayBase<D, Ix2>, Array1<usize>> for KMeans<F> {
    /// Assigns every observation (row) to the index of its closest centroid.
    fn predict_ref<'a>(&'a self, observations: &ArrayBase<D, Ix2>) -> Array1<usize> {
        let rows = observations.view();
        compute_cluster_memberships(&self.centroids, &rows)
    }
}
impl<F: Float, D: Data<Elem = F>> PredictRef<ArrayBase<D, Ix1>, usize> for KMeans<F> {
    /// Returns the index of the centroid closest to a single observation.
    fn predict_ref<'a>(&'a self, observation: &ArrayBase<D, Ix1>) -> usize {
        let (cluster_index, _sq_dist) = closest_centroid(&self.centroids, observation);
        cluster_index
    }
}
pub fn compute_inertia<F: Float>(
centroids: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
observations: &ArrayBase<impl Data<Elem = F>, Ix2>,
cluster_memberships: &ArrayBase<impl Data<Elem = usize>, Ix1>,
) -> F {
let mut dists = Array1::<F>::zeros(observations.nrows());
Zip::from(observations.genrows())
.and(cluster_memberships)
.and(&mut dists)
.par_apply(|observation, &cluster_membership, d| {
*d = centroids
.row(cluster_membership)
.sq_l2_dist(&observation)
.expect("Failed to compute distance");
});
dists.sum()
}
/// Recomputes centroids as the mean of their assigned observations, with the old
/// centroid counted as one extra member.
///
/// Each new centroid `i` is `(old_centroid_i + sum of observations assigned to i) /
/// (1 + number assigned)`. A cluster with no members therefore keeps its old centroid.
fn compute_centroids<F: Float>(
    old_centroids: &Array2<F>,
    observations: &ArrayBase<impl Data<Elem = F>, Ix2>,
    cluster_memberships: &ArrayBase<impl Data<Elem = usize>, Ix1>,
) -> Array2<F> {
    let (n_clusters, n_features) = (old_centroids.nrows(), observations.ncols());
    // Start every weight at 1 to account for the old centroid acting as a member.
    let mut weights = Array1::from_elem(n_clusters, 1usize);
    let mut sums = Array2::<F>::zeros((n_clusters, n_features));
    Zip::from(observations.genrows())
        .and(cluster_memberships)
        .apply(|observation, &cluster| {
            let mut running_sum = sums.row_mut(cluster);
            running_sum += &observation;
            weights[cluster] += 1;
        });
    // The old centroids are added after the observations, keeping the original
    // floating-point summation order.
    sums += old_centroids;
    for (mut row, &weight) in sums.genrows_mut().into_iter().zip(weights.iter()) {
        row /= F::cast(weight);
    }
    sums
}
/// Updates centroids with a running mean, one observation at a time.
///
/// For each observation assigned to cluster `c`, `counts[c]` is incremented and the
/// centroid moves by `(observation - centroid) / counts[c]`. `counts` is updated in
/// place so it can be carried across batches. The updated centroids are returned;
/// `old_centroids` is left untouched.
///
/// # Panics
///
/// Panics if a membership index is out of range, or if `cluster_memberships` does not
/// have one entry per observation row.
fn compute_centroids_incremental<F: Float>(
    observations: &ArrayBase<impl Data<Elem = F>, Ix2>,
    cluster_memberships: &ArrayBase<impl Data<Elem = usize>, Ix1>,
    old_centroids: &ArrayBase<impl Data<Elem = F>, Ix2>,
    counts: &mut ArrayBase<impl DataMut<Elem = F>, Ix1>,
) -> Array2<F> {
    let mut centroids = old_centroids.to_owned();
    Zip::from(observations.genrows())
        .and(cluster_memberships)
        .apply(|obs, &c| {
            counts[c] += F::one();
            let count = counts[c];
            // Same arithmetic as `centroid += (obs - centroid) / count`, but done in
            // place, without allocating a temporary array per observation.
            centroids
                .row_mut(c)
                .zip_mut_with(&obs, |coord, &x| *coord += (x - *coord) / count);
        });
    centroids
}
/// Writes, for every observation (row), the index of its closest centroid into
/// `cluster_memberships`.
///
/// # Panics
///
/// Panics if `cluster_memberships` does not have one entry per observation row.
pub(crate) fn update_cluster_memberships<F: Float>(
    centroids: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
    observations: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
    cluster_memberships: &mut ArrayBase<impl DataMut<Elem = usize>, Ix1>,
) {
    Zip::from(observations.axis_iter(Axis(0)))
        .and(cluster_memberships)
        .par_apply(|observation, cluster_membership| {
            *cluster_membership = closest_centroid(centroids, &observation).0
        });
}
/// Writes, for every observation (row), its squared L2 distance to the closest
/// centroid into `dists`.
///
/// # Panics
///
/// Panics if `dists` does not have one entry per observation row.
pub(crate) fn update_min_dists<F: Float>(
    centroids: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
    observations: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
    dists: &mut ArrayBase<impl DataMut<Elem = F>, Ix1>,
) {
    Zip::from(observations.axis_iter(Axis(0)))
        .and(dists)
        .par_apply(|observation, dist| *dist = closest_centroid(centroids, &observation).1);
}
/// For every observation (row), writes the index of its closest centroid into
/// `cluster_memberships` and the squared L2 distance to that centroid into `dists`.
///
/// # Panics
///
/// Panics if either output does not have one entry per observation row.
pub(crate) fn update_memberships_and_dists<F: Float>(
    centroids: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
    observations: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
    cluster_memberships: &mut ArrayBase<impl DataMut<Elem = usize>, Ix1>,
    dists: &mut ArrayBase<impl DataMut<Elem = F>, Ix1>,
) {
    Zip::from(observations.axis_iter(Axis(0)))
        .and(cluster_memberships)
        .and(dists)
        .par_apply(|observation, cluster_membership, dist| {
            let (closest, sq_dist) = closest_centroid(centroids, &observation);
            *cluster_membership = closest;
            *dist = sq_dist;
        });
}
/// Returns, for every observation (row), the index of its closest centroid.
fn compute_cluster_memberships<F: Float>(
    centroids: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
    observations: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
) -> Array1<usize> {
    let mut memberships = Array1::zeros(observations.nrows());
    update_cluster_memberships(centroids, observations, &mut memberships);
    memberships
}
/// Finds the centroid (row of `centroids`) closest to `observation`.
///
/// Returns `(index, squared L2 distance)`. On ties, the lowest index wins, because
/// only strictly smaller distances replace the current best.
///
/// # Panics
///
/// Panics if `centroids` has no rows, or if `observation` and a centroid differ in
/// length.
pub(crate) fn closest_centroid<F: Float>(
    centroids: &ArrayBase<impl Data<Elem = F>, Ix2>,
    observation: &ArrayBase<impl Data<Elem = F>, Ix1>,
) -> (usize, F) {
    let mut rows = centroids.genrows().into_iter().enumerate();
    let (mut closest_index, first_centroid) = rows
        .next()
        .expect("closest_centroid requires at least one centroid");
    let mut minimum_distance = first_centroid
        .sq_l2_dist(observation)
        .expect("Failed to compute distance");
    // The first centroid is already accounted for, so the loop starts from the second.
    for (centroid_index, centroid) in rows {
        let distance = centroid
            .sq_l2_dist(observation)
            .expect("Failed to compute distance");
        if distance < minimum_distance {
            closest_index = centroid_index;
            minimum_distance = distance;
        }
    }
    (closest_index, minimum_distance)
}
#[cfg(test)]
mod tests {
    use super::super::KMeansInit;
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::{array, concatenate, Array, Array1, Array2, Axis};
    use ndarray_rand::rand::SeedableRng;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;

    /// Piecewise 1-D test function: quadratic below 0.4, linear on [0.4, 0.8),
    /// sinusoidal from 0.8 upwards.
    fn function_test_1d(x: &Array2<f64>) -> Array2<f64> {
        let mut y = Array2::zeros(x.dim());
        Zip::from(&mut y).and(x).apply(|yi, &xi| {
            if xi < 0.4 {
                *yi = xi * xi;
            } else if xi < 0.8 {
                *yi = 3. * xi + 1.;
            } else {
                *yi = f64::sin(10. * xi);
            }
        });
        y
    }

    #[test]
    fn test_n_runs() {
        let mut rng = Isaac64Rng::seed_from_u64(42);
        let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1));
        let yt = function_test_1d(&xt);
        let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();
        for init in &[
            KMeansInit::Random,
            KMeansInit::KMeansPlusPlus,
            KMeansInit::KMeansPara,
        ] {
            let dataset = DatasetBase::from(data.clone());
            let model = KMeans::params_with_rng(3, rng.clone())
                .n_runs(1)
                .init_method(init.clone())
                .fit(&dataset)
                .expect("KMeans fitted");
            let clusters = model.predict(dataset);
            let inertia = compute_inertia(model.centroids(), &clusters.records, &clusters.targets);
            let total_dist = model.transform(&clusters.records.view()).sum();
            assert_abs_diff_eq!(inertia, total_dist);
            let dataset2 = DatasetBase::from(clusters.records().clone());
            let model2 = KMeans::params_with_rng(3, rng.clone())
                .init_method(init.clone())
                .fit(&dataset2)
                .expect("KMeans fitted");
            let clusters2 = model2.predict(dataset2);
            let inertia2 =
                compute_inertia(model2.centroids(), &clusters2.records, &clusters2.targets);
            let total_dist2 = model2.transform(&clusters2.records.view()).sum();
            assert_abs_diff_eq!(inertia2, total_dist2);
            if *init == KMeansInit::Random {
                assert!(inertia2 <= inertia);
            }
        }
    }

    #[test]
    fn compute_centroids_works() {
        let cluster_size = 100;
        let n_features = 4;
        let cluster_1: Array2<f64> =
            Array::random((cluster_size, n_features), Uniform::new(-100., 100.));
        let memberships_1 = Array1::zeros(cluster_size);
        let expected_centroid_1 = cluster_1.sum_axis(Axis(0)) / (cluster_size + 1) as f64;
        let cluster_2: Array2<f64> =
            Array::random((cluster_size, n_features), Uniform::new(-100., 100.));
        let memberships_2 = Array1::ones(cluster_size);
        let expected_centroid_2 = cluster_2.sum_axis(Axis(0)) / (cluster_size + 1) as f64;
        let observations = concatenate(Axis(0), &[cluster_1.view(), cluster_2.view()]).unwrap();
        let memberships =
            concatenate(Axis(0), &[memberships_1.view(), memberships_2.view()]).unwrap();
        let old_centroids = Array2::zeros((2, n_features));
        let centroids = compute_centroids(&old_centroids, &observations, &memberships);
        assert_abs_diff_eq!(
            centroids.index_axis(Axis(0), 0),
            expected_centroid_1,
            epsilon = 1e-5
        );
        assert_abs_diff_eq!(
            centroids.index_axis(Axis(0), 1),
            expected_centroid_2,
            epsilon = 1e-5
        );
        assert_eq!(centroids.len_of(Axis(0)), 2);
    }

    #[test]
    fn test_compute_extra_centroids() {
        let observations = array![[1.0, 2.0]];
        let memberships = array![0];
        let old_centroids = Array2::ones((2, 2));
        let centroids = compute_centroids(&old_centroids, &observations, &memberships);
        assert_abs_diff_eq!(centroids, array![[1.0, 1.5], [1.0, 1.0]]);
    }

    #[test]
    fn nothing_is_closer_than_self() {
        let n_centroids = 20;
        let n_features = 5;
        let mut rng = Isaac64Rng::seed_from_u64(42);
        let centroids: Array2<f64> = Array::random_using(
            (n_centroids, n_features),
            Uniform::new(-100., 100.),
            &mut rng,
        );
        let expected_memberships: Vec<usize> = (0..n_centroids).collect();
        assert_eq!(
            compute_cluster_memberships(&centroids, &centroids),
            Array1::from(expected_memberships)
        );
    }

    #[test]
    fn oracle_test_for_closest_centroid() {
        let centroids = array![[0., 0.], [1., 2.], [20., 0.], [0., 20.],];
        let observations = array![[1., 0.5], [20., 2.], [20., 0.], [7., 20.],];
        let memberships = array![0, 2, 2, 3];
        assert_eq!(
            compute_cluster_memberships(&centroids, &observations),
            memberships
        );
    }

    #[test]
    fn test_compute_centroids_incremental() {
        let observations = array![[-1.0, -3.0], [0., 0.], [3., 5.], [5., 5.]];
        let memberships = array![0, 0, 1, 1];
        let centroids = array![[-1., -1.], [3., 4.], [7., 8.]];
        let mut counts = array![3.0, 0.0, 1.0];
        let centroids =
            compute_centroids_incremental(&observations, &memberships, &centroids, &mut counts);
        assert_abs_diff_eq!(centroids, array![[-4. / 5., -6. / 5.], [4., 5.], [7., 8.]]);
        assert_abs_diff_eq!(counts, array![5., 2., 1.]);
    }

    #[test]
    fn test_incremental_kmeans() {
        let dataset1 = DatasetBase::from(array![[-1.0, -3.0], [0., 0.], [3., 5.], [5., 5.]]);
        let dataset2 = DatasetBase::from(array![[-5.0, -5.0], [0., 0.], [10., 10.]]);
        let model = KMeans {
            centroids: array![[-1., -1.], [3., 4.], [7., 8.]],
            cluster_count: array![0., 0., 0.],
            inertia: 0.0,
        };
        let rng = Isaac64Rng::seed_from_u64(45);
        let params = KMeans::params_with_rng(3, rng).tolerance(100.0).build();
        let (model, converged) = params.fit_with(Some(model), &dataset1);
        assert_abs_diff_eq!(model.centroids(), &array![[-0.5, -1.5], [4., 5.], [7., 8.]]);
        assert!(converged);
        let (model, converged) = params.fit_with(Some(model), &dataset2);
        assert_abs_diff_eq!(
            model.centroids(),
            &array![[-6. / 4., -8. / 4.], [4., 5.], [10., 10.]]
        );
        assert!(converged);
    }
}