pub struct GaussianMixtureModel<F: Float> { /* private fields */ }

Gaussian Mixture Model (GMM) aims to cluster a dataset by finding normally distributed sub-datasets (hence the Gaussian Mixture name).

GMM assumes all the data points are generated from a mixture of K Gaussian distributions with unknown parameters. The expectation-maximization (EM) algorithm is used to fit the GMM to the dataset by estimating the weight, mean, and covariance of each cluster distribution.

This implementation is a port of the scikit-learn 0.23.2 Gaussian Mixture implementation.

The algorithm

The general idea is to maximize the likelihood (equivalently, the log-likelihood), that is, the probability that the dataset was drawn from our mixture of normal distributions.
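For reference, the quantity being maximized is the standard GMM log-likelihood (this is the textbook formulation, not code from this crate): for samples x_1, ..., x_N and K components with weights \pi_k, means \mu_k and covariances \Sigma_k,

\log L(\pi, \mu, \Sigma) = \sum_{i=1}^{N} \log \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x_i \mid \mu_k, \Sigma_k)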

After an initialization step, which can either draw the starting parameters from a random distribution or take them from the result of the KMeans algorithm (the default value of the init_method parameter), the core EM algorithm for Gaussian Mixture iterates over two fixed-point steps:

  1. Expectation step: compute the expectation of the likelihood of the current Gaussian mixture model with respect to the dataset.
  2. Maximization step: update the Gaussian parameters (weights, means and covariances) to maximize that likelihood (the update formulas are sketched right after this list).
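
In the textbook EM formulation (given here for orientation; the crate's internals may differ in details such as covariance regularization), the expectation step computes the responsibility \gamma_{ik} of component k for sample x_i, and the maximization step re-estimates the parameters from these responsibilities:

\gamma_{ik} = \frac{\pi_k \, \mathcal{N}(x_i \mid \mu_k, \Sigma_k)}{\sum_{j=1}^{K} \pi_j \, \mathcal{N}(x_i \mid \mu_j, \Sigma_j)}

N_k = \sum_{i=1}^{N} \gamma_{ik}, \qquad \pi_k = \frac{N_k}{N}, \qquad \mu_k = \frac{1}{N_k} \sum_{i=1}^{N} \gamma_{ik}\, x_i, \qquad \Sigma_k = \frac{1}{N_k} \sum_{i=1}^{N} \gamma_{ik}\, (x_i - \mu_k)(x_i - \mu_k)^\top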

Iteration stops when there is no longer a significant change in the Gaussian parameters (controlled by the tolerance parameter) or when the maximum number of iterations is reached (controlled by the max_n_iterations parameter). As the initialization of the algorithm is subject to randomness, several initializations are performed (controlled by the n_runs parameter).
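
These parameters map onto the parameter builder. A minimal sketch: n_runs, tolerance and with_rng appear in the tutorial below, while the max_n_iterations and init_method builder methods and the GmmInitMethod enum are assumed here to match the parameter names (check the GmmParams documentation of your linfa-clustering version).

use linfa_clustering::{GaussianMixtureModel, GmmInitMethod};

// Hypothetical configuration exercising the tuning knobs discussed above;
// `max_n_iterations`, `init_method` and `GmmInitMethod` are assumed names.
let params = GaussianMixtureModel::<f64>::params(3)
    .init_method(GmmInitMethod::KMeans) // or GmmInitMethod::Random
    .tolerance(1e-4)                    // parameter-change threshold for convergence
    .max_n_iterations(500)              // cap on EM iterations per run
    .n_runs(10);                        // number of random restarts
// `params.fit(&dataset)` then returns the fitted model or an error.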

Tutorial

Let’s do a walkthrough of a train-and-predict example.

use linfa::DatasetBase;
use linfa::prelude::*;
use linfa_clustering::GaussianMixtureModel;
use linfa_datasets::generate;
use ndarray::{array, Axis};
use ndarray_rand::rand::SeedableRng;
use rand_xoshiro::Xoshiro256Plus;

let mut rng = Xoshiro256Plus::seed_from_u64(42);
let expected_centroids = array![[0., 1.], [-10., 20.], [-1., 10.]];
let n = 200;

// We generate a dataset from points normally distributed around some distant centroids.  
let dataset = DatasetBase::from(generate::blobs(n, &expected_centroids, &mut rng));

// Our GMM is expected to have a number of clusters equal to the number of centroids
// used to generate the dataset.
let n_clusters = expected_centroids.len_of(Axis(0));

// We fit the model from the dataset setting some options
let gmm = GaussianMixtureModel::params(n_clusters)
            .n_runs(10)
            .tolerance(1e-4)
            .with_rng(rng)
            .fit(&dataset).expect("GMM fitting");

// Then we can get dataset membership information: the targets contain the **cluster indices**
// matching the corresponding entries in the lists of GMM means and covariances.
let blobs_dataset = gmm.predict(dataset);
let DatasetBase {
    records: _blobs_records,
    targets: blobs_targets,
    ..
} = blobs_dataset;
println!("GMM means = {:?}", gmm.means());
println!("GMM covariances = {:?}", gmm.covariances());
println!("GMM membership = {:?}", blobs_targets);

// We can also get the nearest cluster for a new point
let new_observation = DatasetBase::from(array![[-9., 20.5]]);
// Predict returns the **index** of the nearest cluster
let dataset = gmm.predict(new_observation);
// We can retrieve the actual centroid of the closest cluster using `.centroids()` (an alias of `.means()`)
let closest_centroid = &gmm.centroids().index_axis(Axis(0), dataset.targets()[0]);
