Struct linfa_clustering::KMeans
K-means clustering aims to partition a set of unlabeled observations into clusters, where each observation belongs to the cluster with the nearest mean.
The mean of the points within a cluster is called the centroid.
Given a set of centroids, you can assign an observation to a cluster by choosing the nearest centroid.
We provide an implementation of the standard algorithm, also known as Lloyd's algorithm or naive K-means.
More details on the algorithm can be found in the next section.
The algorithm
K-means is an iterative algorithm: it progressively refines the choice of centroids.
It's guaranteed to converge, even though it might not find the optimal set of centroids: it can get stuck in a local minimum, and finding the global optimum is NP-hard!
There are three steps in the standard algorithm:
- initialisation step: how do we choose our initial set of centroids?
- assignment step: assign each observation to the nearest cluster (minimum distance between the observation and the cluster's centroid);
- update step: recompute the centroid of each cluster.
The initialisation step is a one-off, done at the very beginning.
Assignment and update are repeated in a loop until convergence is reached: either the Euclidean distance between the old and the new centroids falls below tolerance, or max_n_iterations is exceeded.
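The loop described above can be sketched in a few lines. This is a dependency-free illustration on 1-dimensional points with plain slices standing in for ndarray matrices; the hypothetical `lloyd` helper is not part of the crate's API, but its `tolerance` and `max_n_iterations` parameters mirror the hyperparameters discussed here:

```rust
// A sketch of Lloyd's algorithm on 1-D points, for illustration only.
fn lloyd(points: &[f64], mut centroids: Vec<f64>, tolerance: f64, max_n_iterations: usize) -> Vec<f64> {
    for _ in 0..max_n_iterations {
        // Assignment step: index of the nearest centroid for each point.
        let labels: Vec<usize> = points
            .iter()
            .map(|p| {
                centroids
                    .iter()
                    .enumerate()
                    .min_by(|(_, a), (_, b)| {
                        (p - *a).abs().partial_cmp(&(p - *b).abs()).unwrap()
                    })
                    .map(|(i, _)| i)
                    .unwrap()
            })
            .collect();
        // Update step: recompute each centroid as the mean of its cluster.
        let mut sums = vec![0.0; centroids.len()];
        let mut counts = vec![0usize; centroids.len()];
        for (p, &label) in points.iter().zip(&labels) {
            sums[label] += p;
            counts[label] += 1;
        }
        let new_centroids: Vec<f64> = centroids
            .iter()
            .zip(sums.iter().zip(&counts))
            .map(|(&old, (&sum, &n))| if n > 0 { sum / n as f64 } else { old })
            .collect();
        // Convergence check: Euclidean distance between old and new centroids.
        let shift: f64 = centroids
            .iter()
            .zip(&new_centroids)
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f64>()
            .sqrt();
        centroids = new_centroids;
        if shift < tolerance {
            break;
        }
    }
    centroids
}

fn main() {
    // Two well-separated blobs; the centroids converge to the blob means.
    let points = [0.0, 0.2, 10.0, 10.2];
    let centroids = lloyd(&points, vec![0.0, 10.0], 1e-3, 100);
    println!("{:?}", centroids); // roughly [0.1, 10.1]
}
```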
Parallelisation
The work performed by the assignment step does not require any coordination: the closest centroid for each point can be computed independently from the closest centroid for any of the remaining points.
This makes it a good candidate for parallel execution: KMeans::fit parallelises the assignment step thanks to the rayon feature in ndarray.
The update step requires a bit more coordination (computing a rolling mean in parallel) but it is still parallelisable. Nonetheless, our first attempts have not improved performance (most likely due to our strategy used to split work between threads), hence the update step is currently executed on a single thread.
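To illustrate why the assignment step needs no coordination, here is a rough, dependency-free sketch that splits the point set into chunks and assigns each chunk on its own thread using std::thread (the actual implementation relies on ndarray's rayon feature instead; `parallel_assignment` is a hypothetical helper, not part of the crate's API):

```rust
use std::thread;

// Each point's nearest-centroid search is independent of every other point,
// so the point set can be split into chunks processed on separate threads.
fn parallel_assignment(points: &[f64], centroids: &[f64], n_threads: usize) -> Vec<usize> {
    let chunk_size = ((points.len() + n_threads - 1) / n_threads).max(1);
    thread::scope(|s| {
        // Spawn one scoped thread per chunk; scoped threads may borrow
        // `points` and `centroids` without extra synchronisation.
        let handles: Vec<_> = points
            .chunks(chunk_size)
            .map(|chunk| {
                s.spawn(move || {
                    chunk
                        .iter()
                        .map(|p| {
                            // Index of the nearest centroid for this point.
                            centroids
                                .iter()
                                .enumerate()
                                .min_by(|(_, a), (_, b)| {
                                    (p - *a).abs().partial_cmp(&(p - *b).abs()).unwrap()
                                })
                                .map(|(i, _)| i)
                                .unwrap()
                        })
                        .collect::<Vec<usize>>()
                })
            })
            .collect();
        // Chunks were taken in order, so concatenating preserves point order.
        handles.into_iter().flat_map(|h| h.join().unwrap()).collect()
    })
}

fn main() {
    let labels = parallel_assignment(&[0.0, 9.0, 5.1, -0.5], &[0.0, 10.0], 2);
    println!("{:?}", labels); // [0, 1, 1, 0]
}
```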
Tutorial
Let's do a walkthrough of a training-predict-save example.
use linfa_clustering::{KMeansHyperParams, KMeans, generate_blobs};
use ndarray::{Axis, array, s};
use ndarray_rand::rand::SeedableRng;
use rand_isaac::Isaac64Rng;
use approx::assert_abs_diff_eq;

// Our random number generator, seeded for reproducibility
let seed = 42;
let mut rng = Isaac64Rng::seed_from_u64(seed);

// `expected_centroids` has shape `(n_centroids, n_features)`
// i.e. three points in the 2-dimensional plane
let expected_centroids = array![[0., 1.], [-10., 20.], [-1., 10.]];

// Let's generate a synthetic dataset: three blobs of observations
// (100 points each) centered around our `expected_centroids`
let observations = generate_blobs(100, &expected_centroids, &mut rng);

// Let's configure and run our K-means algorithm
// We use the builder pattern to specify the hyperparameters
// `n_clusters` is the only mandatory parameter.
// If you don't specify the others (e.g. `tolerance` or `max_n_iterations`)
// default values will be used.
let n_clusters = expected_centroids.len_of(Axis(0));
let hyperparams = KMeansHyperParams::new(n_clusters)
    .tolerance(1e-2)
    .build();

// Let's run the algorithm!
let model = KMeans::fit(hyperparams, &observations, &mut rng);

// Once we found our set of centroids, we can also assign new points to the nearest cluster
let new_observation = array![[-9., 20.5]];
// Predict returns the **index** of the nearest cluster
let closest_cluster_index = model.predict(&new_observation);
// We can retrieve the actual centroid of the closest cluster using `.centroids()`
let closest_centroid = &model.centroids().index_axis(Axis(0), closest_cluster_index[0]);

// The model can be serialised (and deserialised) to disk using serde
// We'll use the JSON format here for simplicity
let filename = "k_means_model.json";
let writer = std::fs::File::create(filename).expect("Failed to open file.");
serde_json::to_writer(writer, &model).expect("Failed to serialise model.");

let reader = std::fs::File::open(filename).expect("Failed to open file.");
let loaded_model: KMeans = serde_json::from_reader(reader).expect("Failed to deserialise model.");

assert_abs_diff_eq!(model.centroids(), loaded_model.centroids(), epsilon = 1e-10);
assert_eq!(model.hyperparameters(), loaded_model.hyperparameters());
Methods
impl KMeans
pub fn fit(
    hyperparameters: KMeansHyperParams,
    observations: &ArrayBase<impl Data<Elem = f64> + Sync, Ix2>,
    rng: &mut impl Rng
) -> Self
Given an input matrix observations, with shape (n_observations, n_features), fit identifies n_clusters centroids based on the training data distribution.
An instance of KMeans is returned.
pub fn predict(
&self,
observations: &ArrayBase<impl Data<Elem = f64>, Ix2>
) -> Array1<usize>
Given an input matrix observations, with shape (n_observations, n_features), predict returns, for each observation, the index of the closest cluster/centroid.
You can retrieve the centroid associated with an index using the centroids method.
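The index-to-centroid lookup can be pictured as follows, a hypothetical sketch with nested vectors standing in for the (n_centroids, n_features) ndarray matrix returned by centroids:

```rust
fn main() {
    // Three 2-D centroids, one row per centroid (hypothetical values).
    let centroids = vec![vec![0.0, 1.0], vec![-10.0, 20.0], vec![-1.0, 10.0]];
    // Indices as returned by `predict`, one per observation.
    let closest_cluster_index = vec![1_usize, 0, 2];
    // Retrieve the actual centroid associated with each index.
    let assigned: Vec<&Vec<f64>> = closest_cluster_index
        .iter()
        .map(|&i| &centroids[i])
        .collect();
    println!("{:?}", assigned[0]); // the centroid row at index 1
}
```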
pub fn centroids(&self) -> &Array2<f64>
Return the set of centroids as a 2-dimensional matrix with shape (n_centroids, n_features).
pub fn hyperparameters(&self) -> &KMeansHyperParams
Return the hyperparameters used to train this K-means model instance.
Trait Implementations
impl Clone for KMeans
impl PartialEq<KMeans> for KMeans
impl Debug for KMeans
impl StructuralPartialEq for KMeans
impl Serialize for KMeans
fn serialize<__S>(&self, __serializer: __S) -> Result<__S::Ok, __S::Error> where
    __S: Serializer,
impl<'de> Deserialize<'de> for KMeans
fn deserialize<__D>(__deserializer: __D) -> Result<Self, __D::Error> where
    __D: Deserializer<'de>,
Auto Trait Implementations
impl Send for KMeans
impl Sync for KMeans
impl Unpin for KMeans
impl UnwindSafe for KMeans
impl RefUnwindSafe for KMeans
Blanket Implementations
impl<T, U> Into<U> for T where
    U: From<T>,
impl<T> From<T> for T
impl<T> ToOwned for T where
    T: Clone,
type Owned = T
The resulting type after obtaining ownership.
fn to_owned(&self) -> T
fn clone_into(&self, target: &mut T)
impl<T, U> TryFrom<U> for T where
    U: Into<T>,
type Error = Infallible
The type returned in the event of a conversion error.
fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>
impl<T, U> TryInto<U> for T where
    U: TryFrom<T>,
type Error = <U as TryFrom<T>>::Error
The type returned in the event of a conversion error.
fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>
impl<T> Borrow<T> for T where
    T: ?Sized,
impl<T> BorrowMut<T> for T where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<T> Any for T where
    T: 'static + ?Sized,
impl<T> DeserializeOwned for T where
    T: for<'de> Deserialize<'de>,
impl<V, T> VZip<V> for T where
    V: MultiLane<T>,