Struct linfa_clustering::GaussianMixtureModel
pub struct GaussianMixtureModel<F: Float> { /* private fields */ }
Gaussian Mixture Model (GMM) aims at clustering a dataset by finding normally distributed sub-datasets (hence the name Gaussian Mixture).
GMM assumes all the data points are generated from a mixture of K Gaussian distributions, each with its own parameters. The Expectation-Maximization (EM) algorithm is used to fit the GMM to the dataset by estimating the weight, mean and covariance of each cluster distribution.
This implementation is a port of the scikit-learn 0.23.2 Gaussian Mixture implementation.
The algorithm
The general idea is to maximize the likelihood (equivalently, the log-likelihood), i.e. the probability that the dataset is drawn from our mixture of normal distributions.
The algorithm starts with an initialization step, which picks the initial cluster parameters either at random or from the result of the KMeans algorithm (KMeans being the default value of the init_method parameter).
The core EM algorithm for Gaussian Mixture is an iterative, fixed-point, two-step procedure:
- Expectation step: compute the expected likelihood of the current Gaussian mixture model with respect to the dataset.
- Maximization step: update the Gaussian parameters (weights, means and covariances) to maximize that likelihood.
Iteration stops when the Gaussian parameters no longer change significantly (controlled by the tolerance parameter) or when the maximum number of iterations is reached (controlled by the max_n_iterations parameter).
As the initialization of the algorithm is subject to randomness, several runs are performed (controlled by the n_runs parameter).
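The two EM steps above can be sketched in plain Rust for a one-dimensional, two-component mixture. This is an illustrative, stdlib-only sketch (the names `fit_gmm_1d`, `gaussian_pdf` and the end-point initialization are assumptions made here); the linfa implementation operates on ndarray matrices, supports full covariance matrices, and initializes with KMeans by default.

```rust
use std::f64::consts::PI;

// One-dimensional Gaussian density N(x; mean, var).
fn gaussian_pdf(x: f64, mean: f64, var: f64) -> f64 {
    (-(x - mean).powi(2) / (2.0 * var)).exp() / (2.0 * PI * var).sqrt()
}

// EM for a two-component, 1-D mixture. Returns (weights, means, variances).
fn fit_gmm_1d(
    data: &[f64],
    max_n_iterations: usize,
    tolerance: f64,
) -> ([f64; 2], [f64; 2], [f64; 2]) {
    // Crude illustrative initialization: one component at each end of the
    // data (linfa defaults to a KMeans initialization instead).
    let mut weights = [0.5, 0.5];
    let mut means = [data[0], data[data.len() - 1]];
    let mut vars = [1.0, 1.0];
    let mut prev_ll = f64::NEG_INFINITY;

    for _ in 0..max_n_iterations {
        // E-step: responsibilities r[i][k] = P(component k | x_i).
        let resp: Vec<[f64; 2]> = data
            .iter()
            .map(|&x| {
                let p0 = weights[0] * gaussian_pdf(x, means[0], vars[0]);
                let p1 = weights[1] * gaussian_pdf(x, means[1], vars[1]);
                [p0 / (p0 + p1), p1 / (p0 + p1)]
            })
            .collect();

        // M-step: re-estimate each component from its responsibilities.
        for k in 0..2 {
            let nk: f64 = resp.iter().map(|r| r[k]).sum();
            weights[k] = nk / data.len() as f64;
            means[k] = resp.iter().zip(data).map(|(r, &x)| r[k] * x).sum::<f64>() / nk;
            vars[k] = (resp
                .iter()
                .zip(data)
                .map(|(r, &x)| r[k] * (x - means[k]).powi(2))
                .sum::<f64>()
                / nk)
                .max(1e-6); // floor the variance to keep densities finite
        }

        // Stop when the log-likelihood no longer changes significantly,
        // mirroring the role of the `tolerance` parameter.
        let ll: f64 = data
            .iter()
            .map(|&x| {
                (weights[0] * gaussian_pdf(x, means[0], vars[0])
                    + weights[1] * gaussian_pdf(x, means[1], vars[1]))
                .ln()
            })
            .sum();
        if (ll - prev_ll).abs() < tolerance {
            break;
        }
        prev_ll = ll;
    }
    (weights, means, vars)
}

fn main() {
    // Two well-separated groups around 0.0 and 10.0.
    let data = [-0.5, 0.0, 0.5, 9.5, 10.0, 10.5];
    let (weights, means, vars) = fit_gmm_1d(&data, 100, 1e-6);
    println!("weights = {:?}", weights);
    println!("means = {:?}", means);
    println!("variances = {:?}", vars);
}
```

On this toy data the two recovered means settle near the two group centers, with roughly equal weights.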
Tutorial
Let’s do a walkthrough of a training-predict-save example.
use linfa::DatasetBase;
use linfa::prelude::*;
use linfa_clustering::{GmmValidParams, GaussianMixtureModel};
use linfa_datasets::generate;
use ndarray::{Axis, array, s, Zip};
use ndarray_rand::rand::SeedableRng;
use rand_xoshiro::Xoshiro256Plus;
use approx::assert_abs_diff_eq;

let mut rng = Xoshiro256Plus::seed_from_u64(42);
let expected_centroids = array![[0., 1.], [-10., 20.], [-1., 10.]];
let n = 200;

// We generate a dataset from points normally distributed around some distant centroids.
let dataset = DatasetBase::from(generate::blobs(n, &expected_centroids, &mut rng));

// Our GMM is expected to have a number of clusters equal to the number of centroids
// used to generate the dataset.
let n_clusters = expected_centroids.len_of(Axis(0));

// We fit the model to the dataset, setting some options.
let gmm = GaussianMixtureModel::params(n_clusters)
    .n_runs(10)
    .tolerance(1e-4)
    .with_rng(rng)
    .fit(&dataset).expect("GMM fitting");

// Then we can get dataset membership information; targets contain **cluster indexes**
// corresponding to the cluster infos in the list of GMM means and covariances.
let blobs_dataset = gmm.predict(dataset);
let DatasetBase {
    records: _blobs_records,
    targets: blobs_targets,
    ..
} = blobs_dataset;
println!("GMM means = {:?}", gmm.means());
println!("GMM covariances = {:?}", gmm.covariances());
println!("GMM membership = {:?}", blobs_targets);

// We can also get the nearest cluster for a new point.
let new_observation = DatasetBase::from(array![[-9., 20.5]]);
// Predict returns the **index** of the nearest cluster.
let dataset = gmm.predict(new_observation);
// We can retrieve the actual centroid of the closest cluster using `.centroids()`
// (an alias of `.means()`).
let closest_centroid = &gmm.centroids().index_axis(Axis(0), dataset.targets()[0]);
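Conceptually, the index that `predict` returns is the component with the highest posterior responsibility for the point, i.e. the k maximizing weight_k · N(x; mean_k, covariance_k). A stdlib-only, one-dimensional sketch of that argmax (the names `predict_1d` and `gaussian_pdf` are illustrative, not part of the linfa API):

```rust
use std::f64::consts::PI;

// One-dimensional Gaussian density N(x; mean, var).
fn gaussian_pdf(x: f64, mean: f64, var: f64) -> f64 {
    (-(x - mean).powi(2) / (2.0 * var)).exp() / (2.0 * PI * var).sqrt()
}

// Index of the component maximizing weight_k * N(x; mean_k, var_k),
// i.e. the component with the highest posterior responsibility for x.
fn predict_1d(x: f64, weights: &[f64], means: &[f64], vars: &[f64]) -> usize {
    (0..weights.len())
        .max_by(|&a, &b| {
            let pa = weights[a] * gaussian_pdf(x, means[a], vars[a]);
            let pb = weights[b] * gaussian_pdf(x, means[b], vars[b]);
            pa.partial_cmp(&pb).unwrap()
        })
        .unwrap()
}

fn main() {
    // A (hypothetical) fitted two-component mixture around 0.0 and 10.0.
    let (weights, means, vars) = ([0.4, 0.6], [0.0, 10.0], [1.0, 1.0]);
    // 9.2 is far more probable under the second component, so we get index 1.
    println!("{}", predict_1d(9.2, &weights, &means, &vars));
}
```

Note that the component weight enters the comparison too: for a point equidistant from two means, the heavier component wins.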
Implementations
impl<F: Float> GaussianMixtureModel<F>
pub fn params(n_clusters: usize) -> GmmParams<F, Xoshiro256Plus>
pub fn params_with_rng<R: Rng + Clone>(n_clusters: usize, rng: R) -> GmmParams<F, R>
pub fn weights(&self) -> &Array1<F>
pub fn means(&self) -> &Array2<F>
pub fn covariances(&self) -> &Array3<F>
pub fn precisions(&self) -> &Array3<F>
pub fn centroids(&self) -> &Array2<F>
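As in the scikit-learn implementation this ports, the precision matrices returned by `precisions()` are the inverses of the covariance matrices; Gaussian density evaluation needs the inverse covariance, so storing precisions avoids re-inverting at every prediction. A stdlib-only sketch of that relation for a single 2×2 matrix (`inverse_2x2` and `matmul_2x2` are illustrative helpers, not linfa API):

```rust
// Invert a 2x2 covariance matrix to obtain its precision matrix.
fn inverse_2x2(m: [[f64; 2]; 2]) -> [[f64; 2]; 2] {
    let det = m[0][0] * m[1][1] - m[0][1] * m[1][0];
    [
        [m[1][1] / det, -m[0][1] / det],
        [-m[1][0] / det, m[0][0] / det],
    ]
}

// 2x2 matrix product, used to check covariance * precision = identity.
fn matmul_2x2(a: [[f64; 2]; 2], b: [[f64; 2]; 2]) -> [[f64; 2]; 2] {
    let mut c = [[0.0; 2]; 2];
    for i in 0..2 {
        for j in 0..2 {
            for k in 0..2 {
                c[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    c
}

fn main() {
    let covariance = [[2.0, 0.5], [0.5, 1.0]];
    let precision = inverse_2x2(covariance);
    // The product is the identity matrix, up to floating-point error.
    let product = matmul_2x2(covariance, precision);
    println!("covariance * precision = {:?}", product);
}
```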
Trait Implementations
impl<F: Float> Clone for GaussianMixtureModel<F>
impl<F: Debug + Float> Debug for GaussianMixtureModel<F>
impl<F: PartialEq + Float> PartialEq<GaussianMixtureModel<F>> for GaussianMixtureModel<F>
fn eq(&self, other: &GaussianMixtureModel<F>) -> bool
This method tests for self and other values to be equal, and is used by ==.
fn ne(&self, other: &GaussianMixtureModel<F>) -> bool
This method tests for !=.
impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Dim<[usize; 2]>>, ArrayBase<OwnedRepr<usize>, Dim<[usize; 1]>>> for GaussianMixtureModel<F>
impl<F: Float> StructuralPartialEq for GaussianMixtureModel<F>
Auto Trait Implementations
impl<F> RefUnwindSafe for GaussianMixtureModel<F> where F: RefUnwindSafe
impl<F> Send for GaussianMixtureModel<F>
impl<F> Sync for GaussianMixtureModel<F>
impl<F> Unpin for GaussianMixtureModel<F>
impl<F> UnwindSafe for GaussianMixtureModel<F> where F: RefUnwindSafe
Blanket Implementations
impl<T> BorrowMut<T> for T where T: ?Sized
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.