pub struct GaussianMixtureModel<F: Float> { /* private fields */ }
Expand description

Gaussian Mixture Model (GMM) aims at clustering a dataset by finding normally distributed sub-datasets (hence the Gaussian Mixture name).

GMM assumes all the data points are generated from a mixture of a number K of Gaussian distributions with certain parameters. The expectation-maximization (EM) algorithm is used to fit the GMM to the dataset by parameterizing the weight, mean, and covariance of each cluster distribution.

This implementation is a port of the scikit-learn 0.23.2 Gaussian Mixture implementation.

The algorithm

The general idea is to maximize the likelihood (equivalently, the log likelihood), that is, to maximize the probability that the dataset is drawn from our mixture of normal distributions.

After an initialization step, which can be either from a random distribution or from the result of the KMeans algorithm (the default value of the init_method parameter), the core EM iterative algorithm for Gaussian Mixture is a fixed-point two-step algorithm:

  1. Expectation step: compute the expectation of the likelihood of the current Gaussian mixture model with respect to the dataset.
  2. Maximization step: update the Gaussian parameters (weights, means and covariances) to maximize the likelihood.

We stop iterating when there is no significant change in the Gaussian parameters (controlled by the tolerance parameter) or when we reach a maximum number of iterations (controlled by the max_n_iterations parameter). As the initialization of the algorithm is subject to randomness, several initializations are performed (controlled by the n_runs parameter).

Tutorial

Let’s do a walkthrough of a training-predict-save example.

use linfa::DatasetBase;
use linfa::prelude::*;
use linfa_clustering::{GmmValidParams, GaussianMixtureModel};
use linfa_datasets::generate;
use ndarray::{Axis, array, s, Zip};
use ndarray_rand::rand::SeedableRng;
use rand_xoshiro::Xoshiro256Plus;
use approx::assert_abs_diff_eq;

let mut rng = Xoshiro256Plus::seed_from_u64(42);
let expected_centroids = array![[0., 1.], [-10., 20.], [-1., 10.]];
let n = 200;

// We generate a dataset from points normally distributed around some distant centroids.  
let dataset = DatasetBase::from(generate::blobs(n, &expected_centroids, &mut rng));

// Our GMM is expected to have a number of clusters equal to the number of centroids
// used to generate the dataset
let n_clusters = expected_centroids.len_of(Axis(0));

// We fit the model from the dataset setting some options
let gmm = GaussianMixtureModel::params(n_clusters)
            .n_runs(10)
            .tolerance(1e-4)
            .with_rng(rng)
            .fit(&dataset).expect("GMM fitting");

// Then we can get dataset membership information, targets contain **cluster indexes**
// corresponding to the cluster infos in the list of GMM means and covariances
let blobs_dataset = gmm.predict(dataset);
let DatasetBase {
    records: _blobs_records,
    targets: blobs_targets,
    ..
} = blobs_dataset;
println!("GMM means = {:?}", gmm.means());
println!("GMM covariances = {:?}", gmm.covariances());
println!("GMM membership = {:?}", blobs_targets);

// We can also get the nearest cluster for a new point
let new_observation = DatasetBase::from(array![[-9., 20.5]]);
// Predict returns the **index** of the nearest cluster
let dataset = gmm.predict(new_observation);
// We can retrieve the actual centroid of the closest cluster using `.centroids()` (alias of .means())
let closest_centroid = &gmm.centroids().index_axis(Axis(0), dataset.targets()[0]);

Implementations§

source§

impl<F: Float> GaussianMixtureModel<F>

source

pub fn params(n_clusters: usize) -> GmmParams<F, Xoshiro256Plus>

source

pub fn params_with_rng<R: Rng + Clone>( n_clusters: usize, rng: R ) -> GmmParams<F, R>

source

pub fn weights(&self) -> &Array1<F>

source

pub fn means(&self) -> &Array2<F>

source

pub fn covariances(&self) -> &Array3<F>

source

pub fn precisions(&self) -> &Array3<F>

source

pub fn centroids(&self) -> &Array2<F>

Trait Implementations§

source§

impl<F: Float> Clone for GaussianMixtureModel<F>

source§

fn clone(&self) -> Self

Returns a copy of the value. Read more
1.0.0 · source§

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more
source§

impl<F: Debug + Float> Debug for GaussianMixtureModel<F>

source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter. Read more
source§

impl<F: PartialEq + Float> PartialEq for GaussianMixtureModel<F>

source§

fn eq(&self, other: &GaussianMixtureModel<F>) -> bool

This method tests for self and other values to be equal, and is used by ==.
1.0.0 · source§

fn ne(&self, other: &Rhs) -> bool

This method tests for !=. The default implementation is almost always sufficient, and should not be overridden without very good reason.
source§

impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Dim<[usize; 2]>>, ArrayBase<OwnedRepr<usize>, Dim<[usize; 1]>>> for GaussianMixtureModel<F>

source§

fn predict_inplace( &self, observations: &ArrayBase<D, Ix2>, targets: &mut Array1<usize> )

Predict something in place
source§

fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<usize>

Create targets that predict_inplace works with.
source§

impl<F: Float> StructuralPartialEq for GaussianMixtureModel<F>

Auto Trait Implementations§

Blanket Implementations§

source§

impl<T> Any for T where T: 'static + ?Sized,

source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
source§

impl<T> Borrow<T> for T where T: ?Sized,

source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
source§

impl<T> BorrowMut<T> for T where T: ?Sized,

source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
source§

impl<T> From<T> for T

source§

fn from(t: T) -> T

Returns the argument unchanged.

source§

impl<T, U> Into<U> for T where U: From<T>,

source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

§

impl<T> Pointable for T

§

const ALIGN: usize = _

The alignment of pointer.
§

type Init = T

The type for initializers.
§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a with the given initializer. Read more
§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
source§

impl<'a, F, D, DM, T, O> Predict<&'a ArrayBase<D, DM>, T> for O where D: Data<Elem = F>, DM: Dimension, O: PredictInplace<ArrayBase<D, DM>, T>,

source§

fn predict(&self, records: &'a ArrayBase<D, DM>) -> T

source§

impl<'a, F, R, T, S, O> Predict<&'a DatasetBase<R, T>, S> for O where R: Records<Elem = F>, O: PredictInplace<R, S>,

source§

fn predict(&self, ds: &'a DatasetBase<R, T>) -> S

source§

impl<F, D, E, T, O> Predict<ArrayBase<D, Dim<[usize; 2]>>, DatasetBase<ArrayBase<D, Dim<[usize; 2]>>, T>> for O where D: Data<Elem = F>, T: AsTargets<Elem = E>, O: PredictInplace<ArrayBase<D, Dim<[usize; 2]>>, T>,

source§

fn predict( &self, records: ArrayBase<D, Dim<[usize; 2]>> ) -> DatasetBase<ArrayBase<D, Dim<[usize; 2]>>, T>

source§

impl<F, R, T, E, S, O> Predict<DatasetBase<R, T>, DatasetBase<R, S>> for O where R: Records<Elem = F>, S: AsTargets<Elem = E>, O: PredictInplace<R, S>,

source§

fn predict(&self, ds: DatasetBase<R, T>) -> DatasetBase<R, S>

source§

impl<T> ToOwned for T where T: Clone,

§

type Owned = T

The resulting type after obtaining ownership.
source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more
source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more
source§

impl<T, U> TryFrom<U> for T where U: Into<T>,

§

type Error = Infallible

The type returned in the event of a conversion error.
source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
source§

impl<T, U> TryInto<U> for T where U: TryFrom<T>,

§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
§

impl<V, T> VZip<V> for T where V: MultiLane<T>,

§

fn vzip(self) -> V