// linfa_clustering/k_means/algorithm.rs

1use std::cmp::Ordering;
2use std::fmt::Debug;
3
4use crate::k_means::{KMeansParams, KMeansValidParams};
5use crate::IncrKMeansError;
6use crate::{k_means::errors::KMeansError, KMeansInit};
7use linfa::{prelude::*, DatasetBase, Float};
8use linfa_nn::distance::{Distance, L2Dist};
9use ndarray::{Array1, Array2, ArrayBase, Axis, Data, DataMut, Ix1, Ix2, Zip};
10use ndarray_rand::rand::{Rng, SeedableRng};
11use rand_xoshiro::Xoshiro256Plus;
12
13#[cfg(feature = "serde")]
14use serde_crate::{Deserialize, Serialize};
15
16#[cfg_attr(
17    feature = "serde",
18    derive(Serialize, Deserialize),
19    serde(crate = "serde_crate")
20)]
21#[derive(Clone, Debug, PartialEq)]
22/// K-means clustering aims to partition a set of unlabeled observations into clusters,
23/// where each observation belongs to the cluster with the nearest mean.
24///
25/// The mean of the points within a cluster is called *centroid*.
26///
27/// Given the set of centroids, you can assign an observation to a cluster
28/// choosing the nearest centroid.
29///
30/// We provide a modified version of the _standard algorithm_ (also known as Lloyd's Algorithm),
31/// called m_k-means, which uses a slightly modified update step to avoid problems with empty
32/// clusters. We also provide an incremental version of the algorithm that runs on smaller batches
33/// of input data.
34///
35/// More details on the algorithm can be found in the next section or
36/// [here](https://en.wikipedia.org/wiki/K-means_clustering). Details on m_k-means can be found
37/// [here](https://www.researchgate.net/publication/228414762_A_Modified_k-means_Algorithm_to_Avoid_Empty_Clusters).
38///
39/// ## Standard algorithm
40///
41/// K-means is an iterative algorithm: it progressively refines the choice of centroids.
42///
43/// It's guaranteed to converge, even though it might not find the optimal set of centroids
44/// (unfortunately it can get stuck in a local minimum, finding the optimal minimum is NP-hard!).
45///
46/// There are three steps in the standard algorithm:
47/// - initialisation step: select initial centroids using one of our provided algorithms.
48/// - assignment step: assign each observation to the nearest cluster
49///                    (minimum distance between the observation and the cluster's centroid);
50/// - update step: recompute the centroid of each cluster.
51///
52/// The initialisation step is a one-off, done at the very beginning.
53/// Assignment and update are repeated in a loop until convergence is reached (either the
54/// euclidean distance between the old and the new clusters is below `tolerance` or
55/// we exceed the `max_n_iterations`).
56///
57/// ## Incremental Algorithm
58///
59/// In addition to the standard algorithm, we also provide an incremental version of K-means known
60/// as Mini-Batch K-means. In this algorithm, the dataset is divided into small batches, and the
61/// assignment and update steps are performed on each batch instead of the entire dataset. The
62/// update step also takes previous update steps into account when updating the centroids.
63///
64/// Due to using smaller batches, Mini-Batch K-means takes significantly less time to execute than
65/// the standard K-means algorithm, although it may yield slightly worse centroids.
66///
67/// More details on Mini-Batch K-means can be found [here](https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf).
68///
69/// ## Parallelisation
70///
71/// The work performed by the assignment step does not require any coordination:
72/// the closest centroid for each point can be computed independently from the
73/// closest centroid for any of the remaining points.
74///
75/// This makes it a good candidate for parallel execution: `KMeans::fit` parallelises the
76/// assignment step thanks to the `rayon` feature in `ndarray`.
77///
78/// The update step requires a bit more coordination (computing a rolling mean in
79/// parallel) but it is still parallelisable.
80/// Nonetheless, our first attempts have not improved performance
81/// (most likely due to our strategy used to split work between threads), hence
82/// the update step is currently executed on a single thread.
83///
84/// ## Tutorial
85///
86/// Let's do a walkthrough of a training-predict-save example.
87///
88/// ```
89/// use linfa::DatasetBase;
90/// use linfa::traits::{Fit, FitWith, Predict};
91/// use linfa_clustering::{KMeansParams, KMeans, IncrKMeansError};
92/// use linfa_datasets::generate;
93/// use ndarray::{Axis, array, s};
94/// use ndarray_rand::rand::SeedableRng;
95/// use rand_xoshiro::Xoshiro256Plus;
96/// use approx::assert_abs_diff_eq;
97///
98/// // Our random number generator, seeded for reproducibility
99/// let seed = 42;
100/// let mut rng = Xoshiro256Plus::seed_from_u64(seed);
101///
102/// // `expected_centroids` has shape `(n_centroids, n_features)`
103/// // i.e. three points in the 2-dimensional plane
104/// let expected_centroids = array![[0., 1.], [-10., 20.], [-1., 10.]];
105/// // Let's generate a synthetic dataset: three blobs of observations
106/// // (100 points each) centered around our `expected_centroids`
107/// let data = generate::blobs(100, &expected_centroids, &mut rng);
108/// let n_clusters = expected_centroids.len_of(Axis(0));
109///
110/// // Standard K-means
111/// {
112///     let observations = DatasetBase::from(data.clone());
113///     // Let's configure and run our K-means algorithm
114///     // We use the builder pattern to specify the hyperparameters
115///     // `n_clusters` is the only mandatory parameter.
116///     // If you don't specify the others (e.g. `n_runs`, `tolerance`, `max_n_iterations`)
117///     // default values will be used.
118///     let model = KMeans::params_with_rng(n_clusters, rng.clone())
119///         .tolerance(1e-2)
120///         .fit(&observations)
121///         .expect("KMeans fitted");
122///
123///     // Once we found our set of centroids, we can also assign new points to the nearest cluster
124///     let new_observation = DatasetBase::from(array![[-9., 20.5]]);
125///     // Predict returns the **index** of the nearest cluster
126///     let dataset = model.predict(new_observation);
127///     // We can retrieve the actual centroid of the closest cluster using `.centroids()`
128///     let closest_centroid = &model.centroids().index_axis(Axis(0), dataset.targets()[0]);
129///     assert_abs_diff_eq!(closest_centroid.to_owned(), &array![-10., 20.], epsilon = 1e-1);
130/// }
131///
132/// // Incremental K-means
133/// {
134///     let batch_size = 100;
135///     // Shuffling the dataset is one way of ensuring that the batches contain random points from
136///     // the dataset, which is required for the algorithm to work properly
137///     let observations = DatasetBase::from(data.clone()).shuffle(&mut rng);
138///
139///     let n_clusters = expected_centroids.nrows();
140///     let clf = KMeans::params_with_rng(n_clusters, rng.clone()).tolerance(1e-3);
141///
142///     // Repeatedly run fit_with on every batch in the dataset until we have converged
143///     let model = observations
144///         .sample_chunks(batch_size)
145///         .cycle()
146///         .try_fold(None, |current, batch| {
147///             match clf.fit_with(current, &batch) {
148///                 // Early stop condition for the kmeans loop
149///                 Ok(model) => Err(model),
150///                 // Continue running if not converged
151///                 Err(IncrKMeansError::NotConverged(model)) => Ok(Some(model)),
152///                 Err(err) => panic!("unexpected kmeans error: {}", err),
153///             }
154///         })
155///         .unwrap_err();
156///
157///     let new_observation = DatasetBase::from(array![[-9., 20.5]]);
158///     let dataset = model.predict(new_observation);
159///     let closest_centroid = &model.centroids().index_axis(Axis(0), dataset.targets()[0]);
160///     assert_abs_diff_eq!(closest_centroid.to_owned(), &array![-10., 20.], epsilon = 1e-1);
161/// }
162/// ```
163///
164/*///
165/// // The model can be serialised (and deserialised) to disk using serde
166/// // We'll use the JSON format here for simplicity
167/// let filename = "k_means_model.json";
168/// let writer = std::fs::File::create(filename).expect("Failed to open file.");
169/// serde_json::to_writer(writer, &model).expect("Failed to serialise model.");
170///
171/// let reader = std::fs::File::open(filename).expect("Failed to open file.");
172/// let loaded_model: KMeans<f64> = serde_json::from_reader(reader).expect("Failed to deserialise model");
173///
174/// assert_abs_diff_eq!(model.centroids(), loaded_model.centroids(), epsilon = 1e-10);
175/// assert_eq!(model.hyperparameters(), loaded_model.hyperparameters());
176/// ```
177*/
178pub struct KMeans<F: Float, D: Distance<F>> {
179    centroids: Array2<F>,
180    cluster_count: Array1<F>,
181    inertia: F,
182    dist_fn: D,
183}
184
185impl<F: Float> KMeans<F, L2Dist> {
186    pub fn params(nclusters: usize) -> KMeansParams<F, Xoshiro256Plus, L2Dist> {
187        KMeansParams::new(nclusters, Xoshiro256Plus::seed_from_u64(42), L2Dist)
188    }
189
190    pub fn params_with_rng<R: Rng>(nclusters: usize, rng: R) -> KMeansParams<F, R, L2Dist> {
191        KMeansParams::new(nclusters, rng, L2Dist)
192    }
193}
194
195impl<F: Float, D: Distance<F>> KMeans<F, D> {
196    pub fn params_with<R: Rng>(nclusters: usize, rng: R, dist_fn: D) -> KMeansParams<F, R, D> {
197        KMeansParams::new(nclusters, rng, dist_fn)
198    }
199
200    /// Return the set of centroids as a 2-dimensional matrix with shape
201    /// `(n_centroids, n_features)`.
202    pub fn centroids(&self) -> &Array2<F> {
203        &self.centroids
204    }
205
206    /// Return the number of training points belonging to each cluster
207    pub fn cluster_count(&self) -> &Array1<F> {
208        &self.cluster_count
209    }
210
211    /// Return the sum of distances between each training point and its closest centroid, averaged
212    /// across all training points.  When training incrementally, this value is computed on the
213    /// most recent batch.
214    pub fn inertia(&self) -> F {
215        self.inertia
216    }
217}
218
219impl<F: Float, R: Rng + Clone, DA: Data<Elem = F>, T, D: Distance<F>>
220    Fit<ArrayBase<DA, Ix2>, T, KMeansError> for KMeansValidParams<F, R, D>
221{
222    type Object = KMeans<F, D>;
223
224    /// Given an input matrix `observations`, with shape `(n_observations, n_features)`,
225    /// `fit` identifies `n_clusters` centroids based on the training data distribution.
226    ///
227    /// An instance of `KMeans` is returned.
228    ///
229    fn fit(
230        &self,
231        dataset: &DatasetBase<ArrayBase<DA, Ix2>, T>,
232    ) -> Result<Self::Object, KMeansError> {
233        let mut rng = self.rng().clone();
234        let observations = dataset.records().view();
235        let n_samples = dataset.nsamples();
236
237        let mut min_inertia = F::infinity();
238        let mut best_centroids = None;
239        let mut memberships = Array1::zeros(n_samples);
240        let mut dists = Array1::zeros(n_samples);
241
242        let n_runs = self.n_runs();
243
244        for _ in 0..n_runs {
245            let mut centroids =
246                self.init_method()
247                    .run(self.dist_fn(), self.n_clusters(), observations, &mut rng);
248            let mut n_iter = 0;
249            let inertia = loop {
250                update_memberships_and_dists(
251                    self.dist_fn(),
252                    &centroids,
253                    &observations,
254                    &mut memberships,
255                    &mut dists,
256                );
257                let new_centroids = compute_centroids(&centroids, &observations, &memberships);
258                let distance = self
259                    .dist_fn()
260                    .distance(centroids.view(), new_centroids.view());
261                centroids = new_centroids;
262                n_iter += 1;
263                if distance < self.tolerance() || n_iter == self.max_n_iterations() {
264                    break dists.sum();
265                }
266            };
267
268            // We keep the centroids which minimize the inertia (defined as the sum of
269            // the squared distances of the closest centroid for all observations)
270            // over the n runs of the KMeans algorithm.
271            if inertia < min_inertia {
272                min_inertia = inertia;
273                best_centroids = Some(centroids.clone());
274            }
275        }
276
277        match best_centroids {
278            Some(centroids) => {
279                let mut cluster_count = Array1::zeros(self.n_clusters());
280                memberships
281                    .iter()
282                    .for_each(|&c| cluster_count[c] += F::one());
283                Ok(KMeans {
284                    centroids,
285                    cluster_count,
286                    inertia: min_inertia / F::cast(dataset.nsamples()),
287                    dist_fn: self.dist_fn().clone(),
288                })
289            }
290            _ => Err(KMeansError::InertiaError),
291        }
292    }
293}
294
295impl<'a, F: Float + Debug, R: Rng + Clone, DA: Data<Elem = F>, T, D: 'a + Distance<F> + Debug>
296    FitWith<'a, ArrayBase<DA, Ix2>, T, IncrKMeansError<KMeans<F, D>>>
297    for KMeansValidParams<F, R, D>
298{
299    type ObjectIn = Option<KMeans<F, D>>;
300    type ObjectOut = KMeans<F, D>;
301
302    /// Performs a single batch update of the Mini-Batch K-means algorithm.
303    ///
304    /// Given an input matrix `observations`, with shape `(n_batch, n_features)` and a previous
305    /// `KMeans` model, the model's centroids are updated with the input matrix. If `model` is
306    /// `None`, then it's initialized using the specified initialization algorithm. The return
307    /// value consists of the updated model and a `bool` value that indicates whether the algorithm
308    /// has converged.
309    fn fit_with(
310        &self,
311        model: Self::ObjectIn,
312        dataset: &'a DatasetBase<ArrayBase<DA, Ix2>, T>,
313    ) -> Result<Self::ObjectOut, IncrKMeansError<Self::ObjectOut>> {
314        let observations = dataset.records().view();
315        let n_samples = dataset.nsamples();
316
317        let mut model = match model {
318            Some(model) => model,
319            None => {
320                let centroids = if let KMeansInit::Precomputed(centroids) = self.init_method() {
321                    // If using precomputed centroids, don't run the init algorithm multiple times
322                    centroids.clone()
323                } else {
324                    let mut rng = self.rng().clone();
325                    let mut dists = Array1::zeros(n_samples);
326                    // Initial centroids derived from the first batch by running the init algorithm
327                    // n_runs times and taking the centroids with the lowest inertia
328                    (0..self.n_runs())
329                        .map(|_| {
330                            let centroids = self.init_method().run(
331                                self.dist_fn(),
332                                self.n_clusters(),
333                                observations,
334                                &mut rng,
335                            );
336                            update_min_dists(self.dist_fn(), &centroids, &observations, &mut dists);
337                            (centroids, dists.sum())
338                        })
339                        .min_by(|(_, d1), (_, d2)| {
340                            if d1 < d2 {
341                                Ordering::Less
342                            } else {
343                                Ordering::Greater
344                            }
345                        })
346                        .unwrap()
347                        .0
348                };
349                KMeans {
350                    centroids,
351                    cluster_count: Array1::zeros(self.n_clusters()),
352                    inertia: F::zero(),
353                    dist_fn: self.dist_fn().clone(),
354                }
355            }
356        };
357
358        let mut memberships = Array1::zeros(n_samples);
359        let mut dists = Array1::zeros(n_samples);
360        update_memberships_and_dists(
361            self.dist_fn(),
362            &model.centroids,
363            &observations,
364            &mut memberships,
365            &mut dists,
366        );
367        let new_centroids = compute_centroids_incremental(
368            &observations,
369            &memberships,
370            &model.centroids,
371            &mut model.cluster_count,
372        );
373        model.inertia = dists.sum() / F::cast(n_samples);
374        let dist = self
375            .dist_fn()
376            .distance(model.centroids.view(), new_centroids.view());
377        model.centroids = new_centroids;
378
379        if dist < self.tolerance() {
380            Ok(model)
381        } else {
382            Err(IncrKMeansError::NotConverged(model))
383        }
384    }
385}
386
387impl<F: Float, DA: Data<Elem = F>, D: Distance<F>> Transformer<&ArrayBase<DA, Ix2>, Array1<F>>
388    for KMeans<F, D>
389{
390    /// Given an input matrix `observations`, with shape `(n_observations, n_features)`,
391    /// `transform` returns, for each observation, its squared distance to its centroid.
392    fn transform(&self, observations: &ArrayBase<DA, Ix2>) -> Array1<F> {
393        let mut dists = Array1::zeros(observations.nrows());
394        update_min_dists(
395            &self.dist_fn,
396            &self.centroids,
397            &observations.view(),
398            &mut dists,
399        );
400        dists
401    }
402}
403
404impl<F: Float, DA: Data<Elem = F>, D: Distance<F>> PredictInplace<ArrayBase<DA, Ix2>, Array1<usize>>
405    for KMeans<F, D>
406{
407    /// Given an input matrix `observations`, with shape `(n_observations, n_features)`,
408    /// `predict` returns, for each observation, the index of the closest cluster/centroid.
409    ///
410    /// You can retrieve the centroid associated to an index using the
411    /// [`centroids` method](#method.centroids).
412    fn predict_inplace(&self, observations: &ArrayBase<DA, Ix2>, memberships: &mut Array1<usize>) {
413        assert_eq!(
414            observations.nrows(),
415            memberships.len(),
416            "The number of data points must match the number of memberships."
417        );
418
419        update_cluster_memberships(
420            &self.dist_fn,
421            &self.centroids,
422            &observations.view(),
423            memberships,
424        );
425    }
426
427    fn default_target(&self, x: &ArrayBase<DA, Ix2>) -> Array1<usize> {
428        Array1::zeros(x.nrows())
429    }
430}
431
432impl<F: Float, DA: Data<Elem = F>, D: Distance<F>> PredictInplace<ArrayBase<DA, Ix1>, usize>
433    for KMeans<F, D>
434{
435    /// Given one input observation, return the index of its closest cluster
436    ///
437    /// You can retrieve the centroid associated to an index using the
438    /// [`centroids` method](#method.centroids).
439    fn predict_inplace(&self, observation: &ArrayBase<DA, Ix1>, membership: &mut usize) {
440        *membership = closest_centroid(&self.dist_fn, &self.centroids, observation).0;
441    }
442
443    fn default_target(&self, _x: &ArrayBase<DA, Ix1>) -> usize {
444        0
445    }
446}
447
448/// K-means is an iterative algorithm.
449/// We will perform the assignment and update steps until we are satisfied
450/// (according to our convergence criteria).
451///
452/// `compute_centroids` returns a 2-dimensional array,
453/// where the i-th row corresponds to the i-th cluster.
454fn compute_centroids<F: Float>(
455    old_centroids: &Array2<F>,
456    // (n_observations, n_features)
457    observations: &ArrayBase<impl Data<Elem = F>, Ix2>,
458    // (n_observations,)
459    cluster_memberships: &ArrayBase<impl Data<Elem = usize>, Ix1>,
460) -> Array2<F> {
461    let n_clusters = old_centroids.nrows();
462    let mut counts: Array1<usize> = Array1::ones(n_clusters);
463    let mut centroids = Array2::zeros((n_clusters, observations.ncols()));
464
465    Zip::from(observations.rows())
466        .and(cluster_memberships)
467        .for_each(|observation, &cluster_membership| {
468            let mut centroid = centroids.row_mut(cluster_membership);
469            centroid += &observation;
470            counts[cluster_membership] += 1;
471        });
472    // m_k-means: Treat the old centroid like another point in the cluster
473    centroids += old_centroids;
474
475    Zip::from(centroids.rows_mut())
476        .and(&counts)
477        .for_each(|mut centroid, &cnt| centroid /= F::cast(cnt));
478    centroids
479}
480
481/// Returns new centroids which has the moving average of all observations in each cluster added to
482/// the old centroids.
483/// Updates `counts` with the number of observations in each cluster.
484fn compute_centroids_incremental<F: Float>(
485    observations: &ArrayBase<impl Data<Elem = F>, Ix2>,
486    cluster_memberships: &ArrayBase<impl Data<Elem = usize>, Ix1>,
487    old_centroids: &ArrayBase<impl Data<Elem = F>, Ix2>,
488    counts: &mut ArrayBase<impl DataMut<Elem = F>, Ix1>,
489) -> Array2<F> {
490    let mut centroids = old_centroids.to_owned();
491    // We can parallelize this
492    Zip::from(observations.rows())
493        .and(cluster_memberships)
494        .for_each(|obs, &c| {
495            // Computes centroids[c] += (observation - centroids[c]) / counts[c]
496            // If cluster is empty for this batch, then this wouldn't even be called, so no
497            // chance of getting NaN.
498            counts[c] += F::one();
499            let shift = (&obs - &centroids.row(c)) / counts[c];
500            let mut centroid = centroids.row_mut(c);
501            centroid += &shift;
502        });
503    centroids
504}
505
506// Update `cluster_memberships` with the index of the cluster each observation belongs to.
507pub(crate) fn update_cluster_memberships<F: Float, D: Distance<F>>(
508    dist_fn: &D,
509    centroids: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
510    observations: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
511    cluster_memberships: &mut ArrayBase<impl DataMut<Elem = usize>, Ix1>,
512) {
513    Zip::from(observations.axis_iter(Axis(0)))
514        .and(cluster_memberships)
515        .par_for_each(|observation, cluster_membership| {
516            *cluster_membership = closest_centroid(dist_fn, centroids, &observation).0
517        });
518}
519
520// Updates `dists` with the distance of each observation from its closest centroid.
521pub(crate) fn update_min_dists<F: Float, D: Distance<F>>(
522    dist_fn: &D,
523    centroids: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
524    observations: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
525    dists: &mut ArrayBase<impl DataMut<Elem = F>, Ix1>,
526) {
527    Zip::from(observations.axis_iter(Axis(0)))
528        .and(dists)
529        .par_for_each(|observation, dist| {
530            *dist = closest_centroid(dist_fn, centroids, &observation).1
531        });
532}
533
534// Efficient combination of `update_cluster_memberships` and `update_min_dists`.
535pub(crate) fn update_memberships_and_dists<F: Float, D: Distance<F>>(
536    dist_fn: &D,
537    centroids: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
538    observations: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
539    cluster_memberships: &mut ArrayBase<impl DataMut<Elem = usize>, Ix1>,
540    dists: &mut ArrayBase<impl DataMut<Elem = F>, Ix1>,
541) {
542    Zip::from(observations.axis_iter(Axis(0)))
543        .and(cluster_memberships)
544        .and(dists)
545        .par_for_each(|observation, cluster_membership, dist| {
546            let (m, d) = closest_centroid(dist_fn, centroids, &observation);
547            *cluster_membership = m;
548            *dist = d;
549        });
550}
551
552/// Given a matrix of centroids with shape (n_centroids, n_features) and an observation,
553/// return the index of the closest centroid (the index of the corresponding row in `centroids`).
554pub(crate) fn closest_centroid<F: Float, D: Distance<F>>(
555    dist_fn: &D,
556    // (n_centroids, n_features)
557    centroids: &ArrayBase<impl Data<Elem = F>, Ix2>,
558    // (n_features)
559    observation: &ArrayBase<impl Data<Elem = F>, Ix1>,
560) -> (usize, F) {
561    let iterator = centroids.rows().into_iter();
562
563    let first_centroid = centroids.row(0);
564    let (mut closest_index, mut minimum_distance) = (
565        0,
566        dist_fn.rdistance(first_centroid.view(), observation.view()),
567    );
568
569    for (centroid_index, centroid) in iterator.enumerate() {
570        let distance = dist_fn.rdistance(centroid.view(), observation.view());
571        if distance < minimum_distance {
572            closest_index = centroid_index;
573            minimum_distance = distance;
574        }
575    }
576    (closest_index, minimum_distance)
577}
578
579#[cfg(test)]
580mod tests {
581    use super::super::KMeansInit;
582    use super::*;
583    use crate::KMeansParamsError;
584    use approx::assert_abs_diff_eq;
585    use linfa_nn::distance::L1Dist;
586    use ndarray::{array, concatenate, Array, Array1, Array2, Axis};
587    use ndarray_rand::rand::prelude::ThreadRng;
588    use ndarray_rand::rand::SeedableRng;
589    use ndarray_rand::rand_distr::Uniform;
590    use ndarray_rand::RandomExt;
591
592    #[test]
593    fn autotraits() {
594        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
595        has_autotraits::<KMeans<f64, L2Dist>>();
596        has_autotraits::<KMeansParamsError>();
597        has_autotraits::<KMeansError>();
598        has_autotraits::<IncrKMeansError<String>>();
599    }
600
601    fn function_test_1d(x: &Array2<f64>) -> Array2<f64> {
602        let mut y = Array2::zeros(x.dim());
603        Zip::from(&mut y).and(x).for_each(|yi, &xi| {
604            if xi < 0.4 {
605                *yi = xi * xi;
606            } else if (0.4..0.8).contains(&xi) {
607                *yi = 3. * xi + 1.;
608            } else {
609                *yi = f64::sin(10. * xi);
610            }
611        });
612        y
613    }
614
615    macro_rules! calc_inertia {
616        ($dist:expr, $centroids:expr, $obs:expr, $memberships:expr) => {
617            $obs.rows()
618                .into_iter()
619                .zip($memberships.iter())
620                .map(|(row, &c)| $dist.rdistance(row.view(), $centroids.row(c).view()))
621                .sum::<f64>()
622        };
623    }
624
625    macro_rules! calc_memberships {
626        ($dist:expr, $centroids:expr, $obs:expr) => {{
627            let mut memberships = Array1::zeros($obs.nrows());
628            update_cluster_memberships(&$dist, &$centroids, &$obs, &mut memberships);
629            memberships
630        }};
631    }
632
633    #[test]
634    fn test_min_dists() {
635        let centroids = array![[0.0, 1.0], [40.0, 10.0]];
636        let observations = array![[3.0, 4.0], [1.0, 3.0], [25.0, 15.0]];
637        let mut dists = Array1::zeros(observations.nrows());
638
639        update_min_dists(&L2Dist, &centroids, &observations, &mut dists);
640        assert_abs_diff_eq!(dists, array![18.0, 5.0, 250.0]);
641        update_min_dists(&L1Dist, &centroids, &observations, &mut dists);
642        assert_abs_diff_eq!(dists, array![6.0, 3.0, 20.0]);
643    }
644
645    fn test_n_runs<D: Distance<f64>>(dist_fn: D) {
646        let mut rng = Xoshiro256Plus::seed_from_u64(42);
647        let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1));
648        let yt = function_test_1d(&xt);
649        let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();
650
651        for init in &[
652            KMeansInit::Random,
653            KMeansInit::KMeansPlusPlus,
654            KMeansInit::KMeansPara,
655        ] {
656            // First clustering with one iteration
657            let dataset = DatasetBase::from(data.clone());
658            let model = KMeans::params_with(3, rng.clone(), dist_fn.clone())
659                .n_runs(1)
660                .init_method(init.clone())
661                .fit(&dataset)
662                .expect("KMeans fitted");
663            let clusters = model.predict(dataset);
664            let inertia = calc_inertia!(
665                dist_fn,
666                model.centroids(),
667                clusters.records,
668                clusters.targets
669            );
670            let total_dist = model.transform(&clusters.records.view()).sum();
671            assert_abs_diff_eq!(inertia, total_dist, epsilon = 1e-5);
672
673            let single_cluster: usize = model.predict(&data.row(0));
674            assert_abs_diff_eq!(single_cluster, clusters.targets[0]);
675
676            // Second clustering with 10 iterations (default)
677            let dataset2 = DatasetBase::from(clusters.records().clone());
678            let model2 = KMeans::params_with(3, rng.clone(), dist_fn.clone())
679                .init_method(init.clone())
680                .fit(&dataset2)
681                .expect("KMeans fitted");
682            let clusters2 = model2.predict(dataset2);
683            let inertia2 = calc_inertia!(
684                dist_fn,
685                model2.centroids(),
686                clusters2.records,
687                clusters2.targets
688            );
689            let total_dist2 = model2.transform(&clusters2.records.view()).sum();
690            assert_abs_diff_eq!(inertia2, total_dist2, epsilon = 1e-5);
691
692            // Check we improve inertia (only really makes a difference for random init)
693            if *init == KMeansInit::Random {
694                assert!(inertia2 <= inertia);
695            }
696        }
697    }
698
699    #[test]
700    fn test_n_runs_l2dist() {
701        test_n_runs(L2Dist);
702    }
703
704    #[test]
705    fn test_n_runs_l1dist() {
706        test_n_runs(L1Dist);
707    }
708
709    #[test]
710    fn compute_centroids_works() {
711        let cluster_size = 100;
712        let n_features = 4;
713
714        // Let's setup a synthetic set of observations, composed of two clusters with known means
715        let cluster_1: Array2<f64> =
716            Array::random((cluster_size, n_features), Uniform::new(-100., 100.));
717        let memberships_1 = Array1::zeros(cluster_size);
718        let expected_centroid_1 = cluster_1.sum_axis(Axis(0)) / (cluster_size + 1) as f64;
719
720        let cluster_2: Array2<f64> =
721            Array::random((cluster_size, n_features), Uniform::new(-100., 100.));
722        let memberships_2 = Array1::ones(cluster_size);
723        let expected_centroid_2 = cluster_2.sum_axis(Axis(0)) / (cluster_size + 1) as f64;
724
725        // `concatenate` combines arrays along a given axis: https://docs.rs/ndarray/0.13.0/ndarray/fn.concatenate.html
726        let observations = concatenate(Axis(0), &[cluster_1.view(), cluster_2.view()]).unwrap();
727        let memberships =
728            concatenate(Axis(0), &[memberships_1.view(), memberships_2.view()]).unwrap();
729
730        // Does it work?
731        let old_centroids = Array2::zeros((2, n_features));
732        let centroids = compute_centroids(&old_centroids, &observations, &memberships);
733        assert_abs_diff_eq!(
734            centroids.index_axis(Axis(0), 0),
735            expected_centroid_1,
736            epsilon = 1e-5
737        );
738        assert_abs_diff_eq!(
739            centroids.index_axis(Axis(0), 1),
740            expected_centroid_2,
741            epsilon = 1e-5
742        );
743
744        assert_eq!(centroids.len_of(Axis(0)), 2);
745    }
746
747    #[test]
748    fn test_compute_extra_centroids() {
749        let observations = array![[1.0, 2.0]];
750        let memberships = array![0];
751        // Should return an average of 0 for empty clusters
752        let old_centroids = Array2::ones((2, 2));
753        let centroids = compute_centroids(&old_centroids, &observations, &memberships);
754        assert_abs_diff_eq!(centroids, array![[1.0, 1.5], [1.0, 1.0]]);
755    }
756
757    #[test]
758    // An observation is closest to itself.
759    fn nothing_is_closer_than_self() {
760        let n_centroids = 20;
761        let n_features = 5;
762        let mut rng = Xoshiro256Plus::seed_from_u64(42);
763        let centroids: Array2<f64> = Array::random_using(
764            (n_centroids, n_features),
765            Uniform::new(-100., 100.),
766            &mut rng,
767        );
768
769        let expected_memberships = (0..n_centroids).collect::<Array1<_>>();
770        assert_eq!(
771            calc_memberships!(L2Dist, centroids, centroids),
772            expected_memberships
773        );
774        assert_eq!(
775            calc_memberships!(L1Dist, centroids, centroids),
776            expected_memberships
777        );
778    }
779
780    #[test]
781    fn oracle_test_for_closest_centroid() {
782        let centroids = array![[0., 0.], [1., 2.], [20., 0.], [0., 20.],];
783        let observations = array![[1., 0.6], [20., 2.], [20., 0.], [7., 20.],];
784        let l2_memberships = array![0, 2, 2, 3];
785        let l1_memberships = array![1, 2, 2, 3];
786
787        assert_eq!(
788            calc_memberships!(L2Dist, centroids, observations),
789            l2_memberships
790        );
791        assert_eq!(
792            calc_memberships!(L1Dist, centroids, observations),
793            l1_memberships
794        );
795    }
796
797    #[test]
798    fn test_compute_centroids_incremental() {
799        let observations = array![[-1.0, -3.0], [0., 0.], [3., 5.], [5., 5.]];
800        let memberships = array![0, 0, 1, 1];
801        let centroids = array![[-1., -1.], [3., 4.], [7., 8.]];
802        let mut counts = array![3.0, 0.0, 1.0];
803        let centroids =
804            compute_centroids_incremental(&observations, &memberships, &centroids, &mut counts);
805
806        assert_abs_diff_eq!(centroids, array![[-4. / 5., -6. / 5.], [4., 5.], [7., 8.]]);
807        assert_abs_diff_eq!(counts, array![5., 2., 1.]);
808    }
809
810    #[test]
811    fn test_incremental_kmeans() {
812        let dataset1 = DatasetBase::from(array![[-1.0, -3.0], [0., 0.], [3., 5.], [5., 5.]]);
813        let dataset2 = DatasetBase::from(array![[-5.0, -5.0], [0., 0.], [10., 10.]]);
814        let model = KMeans {
815            centroids: array![[-1., -1.], [3., 4.], [7., 8.]],
816            cluster_count: array![0., 0., 0.],
817            inertia: 0.0,
818            dist_fn: L2Dist,
819        };
820        let rng = Xoshiro256Plus::seed_from_u64(45);
821        let params = KMeans::params_with_rng(3, rng).tolerance(100.0);
822
823        // Should converge on first try
824        let model = params.fit_with(Some(model), &dataset1).unwrap();
825        assert_abs_diff_eq!(model.centroids(), &array![[-0.5, -1.5], [4., 5.], [7., 8.]]);
826
827        let model = params.fit_with(Some(model), &dataset2).unwrap();
828        assert_abs_diff_eq!(
829            model.centroids(),
830            &array![[-6. / 4., -8. / 4.], [4., 5.], [10., 10.]]
831        );
832    }
833
834    #[test]
835    fn test_tolerance() {
836        let rng = Xoshiro256Plus::seed_from_u64(45);
837        // The "correct" centroid for the dataset is [6, 6], so the centroid distance from the
838        // initial centroid in the first iteration should be around 8.48. With a tolerance of 8.5,
839        // KMeans should converge on first iteration.
840        let params = KMeans::params_with_rng(1, rng)
841            .tolerance(8.5)
842            .init_method(KMeansInit::Precomputed(array![[0., 0.]]));
843        let data = DatasetBase::from(array![[1., 1.], [11., 11.]]);
844        assert!(params.fit_with(None, &data).is_ok());
845    }
846
847    #[test]
848    fn test_max_n_iterations() {
849        let mut rng = Xoshiro256Plus::seed_from_u64(42);
850        let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1));
851        let yt = function_test_1d(&xt);
852        let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();
853        let dataset = DatasetBase::from(data.clone());
854        // For data created using the above rng and seed, for 6 clusters, it would take 8 iterations to converge.
855        // However, when specifying max_n_iterations as 5, the algorithm should stop early gracefully.
856        let _model = KMeans::params_with(6, rng.clone(), L2Dist)
857            .n_runs(1)
858            .max_n_iterations(5)
859            .init_method(KMeansInit::Random)
860            .fit(&dataset)
861            .expect("KMeans fitted");
862    }
863
864    fn fittable<T: Fit<Array2<f64>, (), KMeansError>>(_: T) {}
865    #[test]
866    fn thread_rng_fittable() {
867        fittable(KMeans::params_with_rng(1, ThreadRng::default()));
868    }
869}