Struct KMeans

pub struct KMeans<F: Float, D: Distance<F>> { /* private fields */ }

K-means clustering aims to partition a set of unlabeled observations into clusters, where each observation belongs to the cluster with the nearest mean.

The mean of the points within a cluster is called the centroid.

Given the set of centroids, you can assign an observation to a cluster by choosing the nearest centroid.

We provide a modified version of the standard algorithm (also known as Lloyd’s Algorithm), called m_k-means, which uses a slightly modified update step to avoid problems with empty clusters. We also provide an incremental version of the algorithm that runs on smaller batches of input data.

More details on the algorithm can be found in the next section or here. Details on m_k-means can be found here.

§Standard algorithm

K-means is an iterative algorithm: it progressively refines the choice of centroids.

It’s guaranteed to converge, even though it might not find the optimal set of centroids: it can get stuck in a local minimum, and finding the global optimum is NP-hard!

There are three steps in the standard algorithm:

  • initialisation step: select the initial centroids using one of our provided algorithms;
  • assignment step: assign each observation to the nearest cluster (minimum distance between the observation and the cluster’s centroid);
  • update step: recompute the centroid of each cluster.

The initialisation step is a one-off, done at the very beginning. Assignment and update are repeated in a loop until convergence is reached: either the Euclidean distance between the old and the new centroids falls below tolerance, or we exceed max_n_iterations.
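
To make the loop concrete, below is a minimal, self-contained sketch of Lloyd's iteration over ndarray arrays. It is illustrative only, not linfa's internal implementation: it skips centroid initialisation, multiple runs (n_runs), and the m_k-means handling of empty clusters.

use ndarray::{array, Array1, Array2, Axis};

// One Lloyd iteration: assign each point to its nearest centroid,
// then recompute each centroid as the mean of its assigned points.
fn lloyd_step(points: &Array2<f64>, centroids: &Array2<f64>) -> Array2<f64> {
    let mut sums = Array2::<f64>::zeros(centroids.raw_dim());
    let mut counts = Array1::<f64>::zeros(centroids.nrows());
    for p in points.axis_iter(Axis(0)) {
        // Assignment step: index of the closest centroid (squared L2 distance)
        let nearest = centroids
            .axis_iter(Axis(0))
            .map(|c| (&p - &c).mapv(|d| d * d).sum())
            .enumerate()
            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(i, _)| i)
            .unwrap();
        let mut row = sums.row_mut(nearest);
        row += &p;
        counts[nearest] += 1.0;
    }
    // Update step: mean of the assigned points (a real implementation must
    // also deal with empty clusters, which is what m_k-means is for)
    for (mut row, &n) in sums.axis_iter_mut(Axis(0)).zip(counts.iter()) {
        if n > 0.0 {
            row /= n;
        }
    }
    sums
}

fn main() {
    let points = array![[0., 0.], [0., 1.], [10., 10.], [10., 11.]];
    let mut centroids = array![[1., 1.], [9., 9.]];
    for _ in 0..300 {
        let next = lloyd_step(&points, &centroids);
        // Convergence check: distance between old and new centroids vs tolerance
        let shift = (&next - &centroids).mapv(|d| d * d).sum().sqrt();
        centroids = next;
        if shift < 1e-4 {
            break;
        }
    }
    println!("{centroids}"); // ~[[0, 0.5], [10, 10.5]]
}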

§Incremental Algorithm

In addition to the standard algorithm, we also provide an incremental version of K-means known as Mini-Batch K-means. In this algorithm, the dataset is divided into small batches, and the assignment and update steps are performed on each batch instead of the entire dataset. The update step also takes previous update steps into account when updating the centroids.

Due to using smaller batches, Mini-Batch K-means takes significantly less time to execute than the standard K-means algorithm, although it may yield slightly worse centroids.

More details on Mini-Batch K-means can be found here.
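
As an illustration of that update step, here is a sketch of the commonly used per-centroid learning-rate scheme: a rate of 1 / (points seen so far) preserves the effect of earlier batches. This is an assumption-laden sketch, not linfa's exact code; fit_with handles all of this internally.

use ndarray::{array, Array1, Array2, ArrayView1, Axis};

// Index of the centroid closest to `x` (squared L2 distance).
fn nearest(centroids: &Array2<f64>, x: &ArrayView1<f64>) -> usize {
    centroids
        .axis_iter(Axis(0))
        .map(|c| (x - &c).mapv(|d| d * d).sum())
        .enumerate()
        .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
        .map(|(j, _)| j)
        .unwrap()
}

// One mini-batch update: each centroid moves towards the points assigned
// to it with a learning rate of 1 / (points seen so far), so the result
// of previous update steps is preserved.
fn mini_batch_step(batch: &Array2<f64>, centroids: &mut Array2<f64>, counts: &mut Array1<f64>) {
    for x in batch.axis_iter(Axis(0)) {
        let j = nearest(centroids, &x);
        counts[j] += 1.0;
        let eta = 1.0 / counts[j];
        // c <- (1 - eta) * c + eta * x
        centroids
            .row_mut(j)
            .zip_mut_with(&x, |c, &x| *c = (1.0 - eta) * *c + eta * x);
    }
}

fn main() {
    // Counts would normally persist across batches; with fresh counts the
    // first point assigned to each centroid replaces it outright (eta = 1)
    let mut centroids = array![[1., 1.], [9., 9.]];
    let mut counts = Array1::<f64>::zeros(2);
    let batch = array![[0., 0.], [10., 10.]];
    mini_batch_step(&batch, &mut centroids, &mut counts);
    println!("{centroids}"); // [[0, 0], [10, 10]]
}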

§Parallelisation

The work performed by the assignment step does not require any coordination: the closest centroid for each point can be computed independently from the closest centroid for any of the remaining points.

This makes it a good candidate for parallel execution: KMeans::fit parallelises the assignment step thanks to the rayon feature in ndarray.

The update step requires a bit more coordination (computing a rolling mean in parallel), but it is still parallelisable. Nonetheless, our first attempts have not improved performance (most likely due to the strategy used to split work between threads), hence the update step is currently executed on a single thread.
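
For illustration, a parallel assignment step can be written with ndarray's parallel iterators. This sketch assumes ndarray is compiled with its rayon feature; it is not linfa's actual code.

use ndarray::parallel::prelude::*;
use ndarray::{array, Array2, Axis};

// Parallel assignment step: each row's nearest centroid is computed
// independently, so the rows can be distributed across rayon's thread pool.
fn assign_parallel(points: &Array2<f64>, centroids: &Array2<f64>) -> Vec<usize> {
    points
        .axis_iter(Axis(0))
        .into_par_iter()
        .map(|p| {
            centroids
                .axis_iter(Axis(0))
                .map(|c| (&p - &c).mapv(|d| d * d).sum())
                .enumerate()
                .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                .map(|(j, _)| j)
                .unwrap()
        })
        .collect()
}

fn main() {
    let points = array![[0., 0.], [10., 10.], [1., 1.]];
    let centroids = array![[0., 0.], [10., 10.]];
    assert_eq!(assign_parallel(&points, &centroids), vec![0, 1, 0]);
}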

§Tutorial

Let’s do a walkthrough of a training and prediction example.

use linfa::DatasetBase;
use linfa::traits::{Fit, FitWith, Predict};
use linfa_clustering::{KMeansParams, KMeans, IncrKMeansError};
use linfa_datasets::generate;
use ndarray::{Axis, array, s};
use ndarray_rand::rand::SeedableRng;
use rand_xoshiro::Xoshiro256Plus;
use approx::assert_abs_diff_eq;

// Our random number generator, seeded for reproducibility
let seed = 42;
let mut rng = Xoshiro256Plus::seed_from_u64(seed);

// `expected_centroids` has shape `(n_centroids, n_features)`
// i.e. three points in the 2-dimensional plane
let expected_centroids = array![[0., 1.], [-10., 20.], [-1., 10.]];
// Let's generate a synthetic dataset: three blobs of observations
// (100 points each) centered around our `expected_centroids`
let data = generate::blobs(100, &expected_centroids, &mut rng);
let n_clusters = expected_centroids.len_of(Axis(0));

// Standard K-means
{
    let observations = DatasetBase::from(data.clone());
    // Let's configure and run our K-means algorithm
    // We use the builder pattern to specify the hyperparameters
    // `n_clusters` is the only mandatory parameter.
    // If you don't specify the others (e.g. `n_runs`, `tolerance`, `max_n_iterations`)
    // default values will be used.
    let model = KMeans::params_with_rng(n_clusters, rng.clone())
        .tolerance(1e-2)
        .fit(&observations)
        .expect("KMeans fitted");

    // Once we've found our set of centroids, we can also assign new points to the nearest cluster
    let new_observation = DatasetBase::from(array![[-9., 20.5]]);
    // Predict returns the **index** of the nearest cluster
    let dataset = model.predict(new_observation);
    // We can retrieve the actual centroid of the closest cluster using `.centroids()`
    let closest_centroid = &model.centroids().index_axis(Axis(0), dataset.targets()[0]);
    assert_abs_diff_eq!(closest_centroid.to_owned(), &array![-10., 20.], epsilon = 1e-1);
}

// Incremental K-means
{
    let batch_size = 100;
    // Shuffling the dataset is one way of ensuring that the batches contain random points from
    // the dataset, which is required for the algorithm to work properly
    let observations = DatasetBase::from(data.clone()).shuffle(&mut rng);

    let n_clusters = expected_centroids.nrows();
    let clf = KMeans::params_with_rng(n_clusters, rng.clone()).tolerance(1e-3);

    // Repeatedly run fit_with on every batch in the dataset until we have converged
    let model = observations
        .sample_chunks(batch_size)
        .cycle()
        .try_fold(None, |current, batch| {
            match clf.fit_with(current, &batch) {
                // Early stop condition: `try_fold` only continues on `Ok`, so a
                // converged model is returned through `Err` to break out of the loop
                Ok(model) => Err(model),
                // Continue running if not converged
                Err(IncrKMeansError::NotConverged(model)) => Ok(Some(model)),
                Err(err) => panic!("unexpected kmeans error: {}", err),
            }
        })
        .unwrap_err();

    let new_observation = DatasetBase::from(array![[-9., 20.5]]);
    let dataset = model.predict(new_observation);
    let closest_centroid = &model.centroids().index_axis(Axis(0), dataset.targets()[0]);
    assert_abs_diff_eq!(closest_centroid.to_owned(), &array![-10., 20.], epsilon = 1e-1);
}

Implementations§

impl<F: Float> KMeans<F, L2Dist>

pub fn params(nclusters: usize) -> KMeansParams<F, Xoshiro256Plus, L2Dist>

pub fn params_with_rng<R: Rng>(nclusters: usize, rng: R) -> KMeansParams<F, R, L2Dist>

impl<F: Float, D: Distance<F>> KMeans<F, D>

pub fn params_with<R: Rng>(nclusters: usize, rng: R, dist_fn: D) -> KMeansParams<F, R, D>
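
As a sketch, a custom distance can be plugged in through dist_fn; the example below assumes the linfa_nn crate (which defines the Distance trait and L1Dist) is available as a direct dependency.

use linfa::traits::Fit;
use linfa::DatasetBase;
use linfa_clustering::KMeans;
use linfa_nn::distance::L1Dist;
use ndarray::array;
use ndarray_rand::rand::SeedableRng;
use rand_xoshiro::Xoshiro256Plus;

// Cluster under the L1 (Manhattan) metric instead of the default L2
let rng = Xoshiro256Plus::seed_from_u64(42);
let observations = DatasetBase::from(array![[0., 0.], [1., 0.], [10., 10.], [11., 10.]]);
let model = KMeans::params_with(2, rng, L1Dist)
    .fit(&observations)
    .expect("KMeans fitted");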

pub fn centroids(&self) -> &Array2<F>

Return the set of centroids as a 2-dimensional matrix with shape (n_centroids, n_features).

pub fn cluster_count(&self) -> &Array1<F>

Return the number of training points belonging to each cluster.

pub fn inertia(&self) -> F

Return the average distance between each training point and its closest centroid, computed over all training points. When training incrementally, this value is computed on the most recent batch.

Trait Implementations§

impl<F: Clone + Float, D: Clone + Distance<F>> Clone for KMeans<F, D>

fn clone(&self) -> KMeans<F, D>

Returns a duplicate of the value. Read more

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more

impl<F: Debug + Float, D: Debug + Distance<F>> Debug for KMeans<F, D>

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter. Read more

impl<F: PartialEq + Float, D: PartialEq + Distance<F>> PartialEq for KMeans<F, D>

fn eq(&self, other: &KMeans<F, D>) -> bool

Tests for self and other values to be equal, and is used by ==.

fn ne(&self, other: &Rhs) -> bool

Tests for !=. The default implementation is almost always sufficient, and should not be overridden without very good reason.

impl<F: Float, DA: Data<Elem = F>, D: Distance<F>> PredictInplace<ArrayBase<DA, Dim<[usize; 1]>>, usize> for KMeans<F, D>

fn predict_inplace(&self, observation: &ArrayBase<DA, Ix1>, membership: &mut usize)

Given one input observation, return the index of its closest cluster.

You can retrieve the centroid associated with an index using the centroids method.
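
A brief sketch of calling this method directly, with made-up data (in practice you would usually go through Predict::predict, as in the tutorial):

use linfa::traits::{Fit, PredictInplace};
use linfa::DatasetBase;
use linfa_clustering::KMeans;
use ndarray::array;

let data = DatasetBase::from(array![[0., 0.], [0., 1.], [10., 10.], [10., 11.]]);
let model = KMeans::params(2).fit(&data).expect("KMeans fitted");

// Write the cluster index of a single observation into an existing `usize`
let observation = array![9.5, 10.5];
let mut membership = 0usize;
model.predict_inplace(&observation, &mut membership);
let centroid = model.centroids().row(membership);
println!("nearest centroid: {centroid}");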

fn default_target(&self, _x: &ArrayBase<DA, Ix1>) -> usize

Create targets that predict_inplace works with.

impl<F: Float, DA: Data<Elem = F>, D: Distance<F>> PredictInplace<ArrayBase<DA, Dim<[usize; 2]>>, ArrayBase<OwnedRepr<usize>, Dim<[usize; 1]>>> for KMeans<F, D>

fn predict_inplace(&self, observations: &ArrayBase<DA, Ix2>, memberships: &mut Array1<usize>)

Given an input matrix observations, with shape (n_observations, n_features), predict returns, for each observation, the index of the closest cluster/centroid.

You can retrieve the centroid associated with an index using the centroids method.
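
A sketch of the batch form, again with made-up data; preallocating the memberships buffer and filling it in place is essentially what Predict::predict does for you via default_target:

use linfa::traits::{Fit, PredictInplace};
use linfa::DatasetBase;
use linfa_clustering::KMeans;
use ndarray::{array, Array1};

let data = DatasetBase::from(array![[0., 0.], [0., 1.], [10., 10.], [10., 11.]]);
let model = KMeans::params(2).fit(&data).expect("KMeans fitted");

// Preallocate the memberships buffer once and fill it in place
let observations = array![[0.5, 0.5], [9.5, 10.5]];
let mut memberships = Array1::zeros(observations.nrows());
model.predict_inplace(&observations, &mut memberships);
println!("{memberships}");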

fn default_target(&self, x: &ArrayBase<DA, Ix2>) -> Array1<usize>

Create targets that predict_inplace works with.

impl<F: Float, DA: Data<Elem = F>, D: Distance<F>> Transformer<&ArrayBase<DA, Dim<[usize; 2]>>, ArrayBase<OwnedRepr<F>, Dim<[usize; 1]>>> for KMeans<F, D>

fn transform(&self, observations: &ArrayBase<DA, Ix2>) -> Array1<F>

Given an input matrix observations, with shape (n_observations, n_features), transform returns, for each observation, its squared distance to the closest centroid.
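
This lets a fitted model act as a scoring step, e.g. to flag points that are far from every centroid. A sketch with made-up data:

use linfa::traits::{Fit, Transformer};
use linfa::DatasetBase;
use linfa_clustering::KMeans;
use ndarray::array;

let data = DatasetBase::from(array![[0., 0.], [0., 1.], [10., 10.], [10., 11.]]);
let model = KMeans::params(2).fit(&data).expect("KMeans fitted");

// Squared distance from each observation to its closest centroid:
// small for inliers, large for points far from every centroid
let observations = array![[0.2, 0.4], [50., 50.]];
let dists = model.transform(&observations);
assert!(dists[0] < dists[1]);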

impl<F: Float, D: Distance<F>> StructuralPartialEq for KMeans<F, D>

Auto Trait Implementations§

impl<F, D> Freeze for KMeans<F, D>
where F: Freeze, D: Freeze,

impl<F, D> RefUnwindSafe for KMeans<F, D>

impl<F, D> Send for KMeans<F, D>

impl<F, D> Sync for KMeans<F, D>

impl<F, D> Unpin for KMeans<F, D>

impl<F, D> UnwindSafe for KMeans<F, D>

Blanket Implementations§

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more

impl<T> CloneToUninit for T
where T: Clone,

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest. Read more

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more

impl<T> Pointable for T

const ALIGN: usize

The alignment of the pointer.

type Init = T

The type for initializers.

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a value with the given initializer. Read more

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more

impl<'a, F, D, DM, T, O> Predict<&'a ArrayBase<D, DM>, T> for O
where D: Data<Elem = F>, DM: Dimension, O: PredictInplace<ArrayBase<D, DM>, T>,

fn predict(&self, records: &'a ArrayBase<D, DM>) -> T

impl<'a, F, R, T, S, O> Predict<&'a DatasetBase<R, T>, S> for O
where R: Records<Elem = F>, O: PredictInplace<R, S>,

fn predict(&self, ds: &'a DatasetBase<R, T>) -> S

impl<F, D, E, T, O> Predict<ArrayBase<D, Dim<[usize; 2]>>, DatasetBase<ArrayBase<D, Dim<[usize; 2]>>, T>> for O
where D: Data<Elem = F>, T: AsTargets<Elem = E>, O: PredictInplace<ArrayBase<D, Dim<[usize; 2]>>, T>,

fn predict(&self, records: ArrayBase<D, Dim<[usize; 2]>>) -> DatasetBase<ArrayBase<D, Dim<[usize; 2]>>, T>

impl<F, R, T, E, S, O> Predict<DatasetBase<R, T>, DatasetBase<R, S>> for O
where R: Records<Elem = F>, S: AsTargets<Elem = E>, O: PredictInplace<R, S>,

fn predict(&self, ds: DatasetBase<R, T>) -> DatasetBase<R, S>

impl<T> ToOwned for T
where T: Clone,

type Owned = T

The resulting type after obtaining ownership.

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V