use crate::gaussian_mixture::errors::{GmmError, Result};
use crate::gaussian_mixture::hyperparams::{
GmmCovarType, GmmInitMethod, GmmParams, GmmValidParams,
};
use crate::k_means::KMeans;
use linfa::{prelude::*, DatasetBase, Float};
use linfa_linalg::{cholesky::*, triangular::*};
use ndarray::{s, Array, Array1, Array2, Array3, ArrayBase, Axis, Data, Ix2, Ix3, Zip};
use ndarray_rand::rand::Rng;
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;
use ndarray_stats::QuantileExt;
use rand_xoshiro::Xoshiro256Plus;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
/// Gaussian mixture model fitted via the EM algorithm.
///
/// Stores the per-component mixing `weights`, component `means`, full
/// covariance matrices, their inverses (`precisions`) and the Cholesky
/// factors of the inverses (`precisions_chol`) used when evaluating
/// log-probabilities. Only the full covariance type is handled by the
/// `*_full` helpers in this file.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, PartialEq)]
pub struct GaussianMixtureModel<F: Float> {
    /// Covariance structure selected at fit time.
    covar_type: GmmCovarType,
    /// Mixing weights, shape (n_clusters,).
    weights: Array1<F>,
    /// Component means, shape (n_clusters, n_features).
    means: Array2<F>,
    /// Component covariance matrices, shape (n_clusters, n_features, n_features).
    covariances: Array3<F>,
    /// Inverse covariance matrices, same shape as `covariances`.
    precisions: Array3<F>,
    /// Transposed inverses of the covariance Cholesky factors,
    /// consumed by `estimate_log_gaussian_prob`.
    precisions_chol: Array3<F>,
}
impl<F: Float> Clone for GaussianMixtureModel<F> {
fn clone(&self) -> Self {
Self {
covar_type: self.covar_type,
weights: self.weights.to_owned(),
means: self.means.to_owned(),
covariances: self.covariances.to_owned(),
precisions: self.precisions.to_owned(),
precisions_chol: self.precisions_chol.to_owned(),
}
}
}
impl<F: Float> GaussianMixtureModel<F> {
    /// Builds an initial model for `dataset`: constructs an
    /// (n_samples, n_clusters) responsibility matrix according to the
    /// configured init method, then derives the first
    /// weights/means/covariances and the associated precisions from it.
    ///
    /// # Errors
    /// Propagates errors from the inner KMeans fit, from the empty-cluster
    /// check in `estimate_gaussian_parameters`, and from the Cholesky
    /// decomposition of the initial covariances.
    fn new<D: Data<Elem = F>, R: Rng + Clone, T>(
        hyperparameters: &GmmValidParams<F, R>,
        dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
        mut rng: R,
    ) -> Result<GaussianMixtureModel<F>> {
        let observations = dataset.records().view();
        let n_samples = observations.nrows();
        let resp = match hyperparameters.init_method() {
            GmmInitMethod::KMeans => {
                // Hard assignments from a KMeans fit: responsibility 1 for
                // the predicted cluster of each sample, 0 elsewhere.
                let model = KMeans::params_with_rng(hyperparameters.n_clusters(), rng)
                    .check()
                    .unwrap()
                    .fit(dataset)?;
                let mut resp = Array::<F, Ix2>::zeros((n_samples, hyperparameters.n_clusters()));
                // `k` is the sample row, `idx` the predicted cluster column.
                for (k, idx) in model.predict(dataset.records()).iter().enumerate() {
                    resp[[k, *idx]] = F::cast(1.);
                }
                resp
            }
            GmmInitMethod::Random => {
                // Uniform random responsibilities, drawn in f64 and cast to
                // F at the end, normalized so every row sums to one.
                let mut resp = Array2::<f64>::random_using(
                    (n_samples, hyperparameters.n_clusters()),
                    Uniform::new(0., 1.),
                    &mut rng,
                );
                // Transpose so the (1, n_samples) row-sum vector broadcasts
                // over the (n_clusters, n_samples) view, then transpose back.
                let totals = &resp.sum_axis(Axis(1)).insert_axis(Axis(0));
                resp = (resp.reversed_axes() / totals).reversed_axes();
                resp.mapv(F::cast)
            }
        };
        let (mut weights, means, covariances) = Self::estimate_gaussian_parameters(
            &observations,
            &resp,
            hyperparameters.covariance_type(),
            hyperparameters.reg_covariance(),
        )?;
        // `estimate_gaussian_parameters` returns raw responsibility sums;
        // divide by the sample count to obtain mixing proportions.
        weights /= F::cast(n_samples);
        let precisions_chol = Self::compute_precisions_cholesky_full(&covariances)?;
        let precisions = Self::compute_precisions_full(&precisions_chol);
        Ok(GaussianMixtureModel {
            covar_type: *hyperparameters.covariance_type(),
            weights,
            means,
            covariances,
            precisions,
            precisions_chol,
        })
    }
}
impl<F: Float> GaussianMixtureModel<F> {
    /// Creates [`GmmParams`] with the default `Xoshiro256Plus` RNG.
    pub fn params(n_clusters: usize) -> GmmParams<F, Xoshiro256Plus> {
        GmmParams::new(n_clusters)
    }
    /// Creates [`GmmParams`] with a caller-supplied RNG.
    pub fn params_with_rng<R: Rng + Clone>(n_clusters: usize, rng: R) -> GmmParams<F, R> {
        GmmParams::new_with_rng(n_clusters, rng)
    }
    /// Mixing weights, shape (n_clusters,).
    pub fn weights(&self) -> &Array1<F> {
        &self.weights
    }
    /// Component means, shape (n_clusters, n_features).
    pub fn means(&self) -> &Array2<F> {
        &self.means
    }
    /// Component covariance matrices, shape (n_clusters, n_features, n_features).
    pub fn covariances(&self) -> &Array3<F> {
        &self.covariances
    }
    /// Inverse covariance matrices, same shape as `covariances`.
    pub fn precisions(&self) -> &Array3<F> {
        &self.precisions
    }
    /// Cluster centers — an alias for [`means`](Self::means).
    pub fn centroids(&self) -> &Array2<F> {
        self.means()
    }
fn estimate_gaussian_parameters<D: Data<Elem = F>>(
observations: &ArrayBase<D, Ix2>,
resp: &Array2<F>,
_covar_type: &GmmCovarType,
reg_covar: F,
) -> Result<(Array1<F>, Array2<F>, Array3<F>)> {
let nk = resp.sum_axis(Axis(0));
if nk.min()? < &(F::cast(10.) * F::epsilon()) {
return Err(GmmError::EmptyCluster(format!(
"Cluster #{} has no more point. Consider decreasing number of clusters or change initialization.",
nk.argmin()? + 1
)));
}
let nk2 = nk.to_owned().insert_axis(Axis(1));
let means = resp.t().dot(observations) / nk2;
let covariances =
Self::estimate_gaussian_covariances_full(observations, resp, &nk, &means, reg_covar);
Ok((nk, means, covariances))
}
fn estimate_gaussian_covariances_full<D: Data<Elem = F>>(
observations: &ArrayBase<D, Ix2>,
resp: &Array2<F>,
nk: &Array1<F>,
means: &Array2<F>,
reg_covar: F,
) -> Array3<F> {
let n_clusters = means.nrows();
let n_features = means.ncols();
let mut covariances = Array::zeros((n_clusters, n_features, n_features));
for k in 0..n_clusters {
let diff = observations - &means.row(k);
let m = &diff.t() * &resp.index_axis(Axis(1), k);
let mut cov_k = m.dot(&diff) / nk[k];
cov_k.diag_mut().mapv_inplace(|x| x + reg_covar);
covariances.slice_mut(s![k, .., ..]).assign(&cov_k);
}
covariances
}
fn compute_precisions_cholesky_full<D: Data<Elem = F>>(
covariances: &ArrayBase<D, Ix3>,
) -> Result<Array3<F>> {
let n_clusters = covariances.shape()[0];
let n_features = covariances.shape()[1];
let mut precisions_chol = Array::zeros((n_clusters, n_features, n_features));
for (k, covariance) in covariances.outer_iter().enumerate() {
let sol = {
let decomp = covariance.cholesky()?;
decomp.solve_triangular_into(Array::eye(n_features), UPLO::Lower)?
};
precisions_chol.slice_mut(s![k, .., ..]).assign(&sol.t());
}
Ok(precisions_chol)
}
fn compute_precisions_full<D: Data<Elem = F>>(
precisions_chol: &ArrayBase<D, Ix3>,
) -> Array3<F> {
let mut precisions = Array3::zeros(precisions_chol.dim());
for (k, prec_chol) in precisions_chol.outer_iter().enumerate() {
precisions
.slice_mut(s![k, .., ..])
.assign(&prec_chol.dot(&prec_chol.t()));
}
precisions
}
    /// Recomputes the full precision matrices from the current Cholesky
    /// factors (invoked on the best candidate model inside `fit` before it
    /// is snapshotted and returned).
    fn refresh_precisions_full(&mut self) {
        self.precisions = Self::compute_precisions_full(&self.precisions_chol);
    }
fn e_step<D: Data<Elem = F>>(
&self,
observations: &ArrayBase<D, Ix2>,
) -> Result<(F, Array2<F>)> {
let (log_prob_norm, log_resp) = self.estimate_log_prob_resp(observations);
let log_mean = log_prob_norm.mean().unwrap();
Ok((log_mean, log_resp))
}
fn m_step<D: Data<Elem = F>>(
&mut self,
reg_covar: F,
observations: &ArrayBase<D, Ix2>,
log_resp: &Array2<F>,
) -> Result<()> {
let n_samples = observations.nrows();
let (weights, means, covariances) = Self::estimate_gaussian_parameters(
observations,
&log_resp.mapv(|x| x.exp()),
&self.covar_type,
reg_covar,
)?;
self.means = means;
self.weights = weights / F::cast(n_samples);
self.precisions_chol = Self::compute_precisions_cholesky_full(&covariances)?;
Ok(())
}
    /// Lower bound used as the EM convergence criterion: for this full
    /// covariance implementation it is simply the mean log-likelihood from
    /// the E-step. `_log_resp` is unused — presumably kept in the signature
    /// for parity with other covariance types; confirm before removing.
    fn compute_lower_bound<D: Data<Elem = F>>(
        _log_resp: &ArrayBase<D, Ix2>,
        log_prob_norm: F,
    ) -> F {
        log_prob_norm
    }
fn estimate_log_prob_resp<D: Data<Elem = F>>(
&self,
observations: &ArrayBase<D, Ix2>,
) -> (Array1<F>, Array2<F>) {
let weighted_log_prob = self.estimate_weighted_log_prob(observations);
let log_prob_norm = weighted_log_prob
.mapv(|x| x.exp())
.sum_axis(Axis(1))
.mapv(|x| x.ln());
let log_resp = weighted_log_prob - log_prob_norm.to_owned().insert_axis(Axis(1));
(log_prob_norm, log_resp)
}
    /// Log of the weighted component densities, `log p(x | k) + log w_k`,
    /// shape (n_samples, n_clusters); the weight vector broadcasts over rows.
    fn estimate_weighted_log_prob<D: Data<Elem = F>>(
        &self,
        observations: &ArrayBase<D, Ix2>,
    ) -> Array2<F> {
        self.estimate_log_prob(observations) + self.estimate_log_weights()
    }
    /// Per-component log-densities; dispatches to the Gaussian implementation.
    fn estimate_log_prob<D: Data<Elem = F>>(&self, observations: &ArrayBase<D, Ix2>) -> Array2<F> {
        self.estimate_log_gaussian_prob(observations)
    }
fn estimate_log_gaussian_prob<D: Data<Elem = F>>(
&self,
observations: &ArrayBase<D, Ix2>,
) -> Array2<F> {
let n_samples = observations.nrows();
let n_features = observations.ncols();
let means = self.means();
let n_clusters = means.nrows();
let log_det = Self::compute_log_det_cholesky_full(&self.precisions_chol, n_features);
let mut log_prob: Array2<F> = Array::zeros((n_samples, n_clusters));
Zip::indexed(means.rows())
.and(self.precisions_chol.outer_iter())
.for_each(|k, mu, prec_chol| {
let diff = (&observations.to_owned() - &mu).dot(&prec_chol);
log_prob
.slice_mut(s![.., k])
.assign(&diff.mapv(|v| v * v).sum_axis(Axis(1)))
});
log_prob.mapv(|v| {
F::cast(-0.5) * (v + F::cast(n_features as f64 * f64::ln(2. * std::f64::consts::PI)))
}) + log_det
}
    /// Sum of the log-diagonal entries of each (n_features, n_features)
    /// Cholesky factor, i.e. the log-determinant of each triangular factor.
    ///
    /// Each factor is flattened to a row of length n_features^2; stepping
    /// through that row with stride `n_features + 1` visits exactly the
    /// diagonal elements.
    fn compute_log_det_cholesky_full<D: Data<Elem = F>>(
        matrix_chol: &ArrayBase<D, Ix3>,
        n_features: usize,
    ) -> Array1<F> {
        let n_clusters = matrix_chol.shape()[0];
        let log_diags = &matrix_chol
            .to_owned()
            .into_shape((n_clusters, n_features * n_features))
            .unwrap()
            .slice(s![.., ..; n_features+1])
            .to_owned()
            .mapv(|x| x.ln());
        log_diags.sum_axis(Axis(1))
    }
    /// Element-wise natural log of the mixing weights.
    fn estimate_log_weights(&self) -> Array1<F> {
        self.weights().mapv(|x| x.ln())
    }
}
impl<F: Float, R: Rng + Clone, D: Data<Elem = F>, T> Fit<ArrayBase<D, Ix2>, T, GmmError>
    for GmmValidParams<F, R>
{
    type Object = GaussianMixtureModel<F>;
    /// Fits the mixture with EM, repeated `n_runs` times, returning the
    /// parameters that achieved the highest lower bound.
    ///
    /// # Errors
    /// * `GmmError::NotConverged` when no run reached the tolerance within
    ///   `max_n_iterations`.
    /// * `GmmError::LowerBoundError` when the lower bound never improved
    ///   over -inf.
    /// * Any error bubbling up from initialization or an EM step.
    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
        let observations = dataset.records().view();
        // NOTE(review): the model is initialized once, outside the runs
        // loop, so each subsequent run continues EM from the previous run's
        // parameters rather than re-initializing — confirm this is intended.
        let mut gmm = GaussianMixtureModel::<F>::new(self, dataset, self.rng())?;
        let mut max_lower_bound = -F::infinity();
        let mut best_params = None;
        let mut best_iter = None;
        let n_runs = self.n_runs();
        for _ in 0..n_runs {
            let mut lower_bound = -F::infinity();
            let mut converged_iter: Option<u64> = None;
            for n_iter in 0..self.max_n_iterations() {
                let prev_lower_bound = lower_bound;
                let (log_prob_norm, log_resp) = gmm.e_step(&observations)?;
                gmm.m_step(self.reg_covariance(), &observations, &log_resp)?;
                lower_bound =
                    GaussianMixtureModel::<F>::compute_lower_bound(&log_resp, log_prob_norm);
                // Converged once the lower bound stops moving.
                let change = lower_bound - prev_lower_bound;
                if change.abs() < self.tolerance() {
                    converged_iter = Some(n_iter);
                    break;
                }
            }
            // Snapshot the best-scoring run; the precision matrices are
            // refreshed from the Cholesky factors before cloning.
            if lower_bound > max_lower_bound {
                max_lower_bound = lower_bound;
                gmm.refresh_precisions_full();
                best_params = Some(gmm.clone());
                best_iter = converged_iter;
            }
        }
        match best_iter {
            Some(_n_iter) => match best_params {
                Some(gmm) => Ok(gmm),
                _ => Err(GmmError::LowerBoundError(
                    "No lower bound improvement (-inf)".to_string(),
                )),
            },
            // NOTE(review): the message interpolates `n_runs + 1`, which
            // reads more like a run count than an algorithm id — verify.
            None => Err(GmmError::NotConverged(format!(
                "EM fitting algorithm {} did not converge. Try different init parameters, \
                 or increase max_n_iterations, tolerance or check for degenerate data.",
                (n_runs + 1)
            ))),
        }
    }
}
impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<usize>>
    for GaussianMixtureModel<F>
{
    /// Writes, for every observation, the index of the component with the
    /// highest posterior responsibility.
    fn predict_inplace(&self, observations: &ArrayBase<D, Ix2>, targets: &mut Array1<usize>) {
        assert_eq!(
            observations.nrows(),
            targets.len(),
            "The number of data points must match the number of output targets."
        );
        let (_, log_resp) = self.estimate_log_prob_resp(observations);
        // Exponentiate back to responsibilities, then take the row-wise argmax.
        let resp = log_resp.mapv(F::exp);
        *targets = resp.map_axis(Axis(1), |probs| probs.argmax().unwrap());
    }
    /// Allocates a zeroed target vector of matching length.
    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<usize> {
        Array1::from_elem(x.nrows(), 0)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use approx::{abs_diff_eq, assert_abs_diff_eq};
use linfa_datasets::generate;
use linfa_linalg::LinalgError;
use linfa_linalg::Result as LAResult;
use ndarray::{array, concatenate, ArrayView1, ArrayView2, Axis};
use ndarray_rand::rand::prelude::ThreadRng;
use ndarray_rand::rand::SeedableRng;
use ndarray_rand::rand_distr::{Distribution, StandardNormal};
#[test]
fn autotraits() {
fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
has_autotraits::<GaussianMixtureModel<f64>>();
has_autotraits::<GmmError>();
has_autotraits::<GmmParams<f64, Xoshiro256Plus>>();
has_autotraits::<GmmValidParams<f64, Xoshiro256Plus>>();
has_autotraits::<GmmInitMethod>();
has_autotraits::<GmmCovarType>();
}
    /// Multivariate normal sampler used to generate synthetic clusters:
    /// draws z ~ N(0, I) and returns `mean + L z`, where `L` is the lower
    /// Cholesky factor of the covariance.
    pub struct MultivariateNormal {
        pub mean: Array1<f64>,
        pub covariance: Array2<f64>,
        // Lower Cholesky factor of `covariance`, cached at construction.
        lower: Array2<f64>,
    }
    impl MultivariateNormal {
        /// Fails when `covariance` is not positive-definite.
        pub fn new(mean: &ArrayView1<f64>, covariance: &ArrayView2<f64>) -> LAResult<Self> {
            let lower = covariance.cholesky()?;
            Ok(MultivariateNormal {
                mean: mean.to_owned(),
                covariance: covariance.to_owned(),
                lower,
            })
        }
    }
    impl Distribution<Array1<f64>> for MultivariateNormal {
        fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Array1<f64> {
            // Standard normal vector transformed through the Cholesky factor.
            let res = Array1::random_using(self.mean.shape()[0], StandardNormal, rng);
            self.mean.clone() + self.lower.view().dot(&res)
        }
    }
#[test]
fn test_gmm_fit() {
let mut rng = Xoshiro256Plus::seed_from_u64(42);
let weights = array![0.5, 0.5];
let means = array![[0., 0.], [5., 5.]];
let covars = array![[[1., 0.8], [0.8, 1.]], [[1.0, -0.6], [-0.6, 1.0]]];
let mvn1 =
MultivariateNormal::new(&means.slice(s![0, ..]), &covars.slice(s![0, .., ..])).unwrap();
let mvn2 =
MultivariateNormal::new(&means.slice(s![1, ..]), &covars.slice(s![1, .., ..])).unwrap();
let n = 500;
let mut observations = Array2::zeros((2 * n, means.ncols()));
for (i, mut row) in observations.rows_mut().into_iter().enumerate() {
let sample = if i < n {
mvn1.sample(&mut rng)
} else {
mvn2.sample(&mut rng)
};
row.assign(&sample);
}
let dataset = DatasetBase::from(observations);
let gmm = GaussianMixtureModel::params(2)
.with_rng(rng)
.fit(&dataset)
.expect("GMM fitting");
let w = gmm.weights();
assert_abs_diff_eq!(w, &weights, epsilon = 1e-1);
let m = gmm.means();
assert!(
abs_diff_eq!(means, &m, epsilon = 1e-1)
|| abs_diff_eq!(means, m.slice(s![..;-1, ..]), epsilon = 1e-1)
);
let c = gmm.covariances();
assert!(
abs_diff_eq!(covars, &c, epsilon = 1e-1)
|| abs_diff_eq!(covars, c.slice(s![..;-1, .., ..]), epsilon = 1e-1)
);
}
fn function_test_1d(x: &Array2<f64>) -> Array2<f64> {
let mut y = Array2::zeros(x.dim());
Zip::from(&mut y).and(x).for_each(|yi, &xi| {
if xi < 0.4 {
*yi = xi * xi;
} else if (0.4..0.8).contains(&xi) {
*yi = 10. * xi + 1.;
} else {
*yi = f64::sin(10. * xi);
}
});
y
}
#[test]
fn test_zeroed_reg_covar_failure() {
let mut rng = Xoshiro256Plus::seed_from_u64(42);
let xt = Array2::random_using((50, 1), Uniform::new(0., 1.0), &mut rng);
let yt = function_test_1d(&xt);
let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();
let dataset = DatasetBase::from(data);
let gmm = GaussianMixtureModel::params(3)
.reg_covariance(0.)
.with_rng(rng.clone())
.fit(&dataset);
match gmm.expect_err("should generate an error with reg_covar being nul") {
GmmError::LinalgError(e) => {
assert!(matches!(e, LinalgError::NotPositiveDefinite));
}
e => panic!("should be a linear algebra error: {:?}", e),
}
assert!(GaussianMixtureModel::params(3)
.with_rng(rng)
.fit(&dataset)
.is_ok());
}
#[test]
fn test_zeroed_reg_covar_const_failure() {
let xt = Array2::ones((50, 1));
let data = concatenate(Axis(1), &[xt.view(), xt.view()]).unwrap();
let dataset = DatasetBase::from(data);
let gmm = GaussianMixtureModel::params(1)
.reg_covariance(0.)
.fit(&dataset);
gmm.expect_err("should generate an error with reg_covar being nul");
assert!(GaussianMixtureModel::params(1).fit(&dataset).is_ok());
}
#[test]
fn test_centroids_prediction() {
let mut rng = Xoshiro256Plus::seed_from_u64(42);
let expected_centroids = array![[0., 1.], [-10., 20.], [-1., 10.]];
let n = 1000;
let blobs = DatasetBase::from(generate::blobs(n, &expected_centroids, &mut rng));
let n_clusters = expected_centroids.len_of(Axis(0));
let gmm = GaussianMixtureModel::params(n_clusters)
.with_rng(rng)
.fit(&blobs)
.expect("GMM fitting");
let gmm_centroids = gmm.centroids();
let memberships = gmm.predict(&expected_centroids);
for (i, expected_c) in expected_centroids.outer_iter().enumerate() {
let closest_c = gmm_centroids.index_axis(Axis(0), memberships[i]);
Zip::from(&closest_c)
.and(&expected_c)
.for_each(|a, b| assert_abs_diff_eq!(a, b, epsilon = 1.))
}
}
#[test]
fn test_invalid_n_runs() {
assert!(
GaussianMixtureModel::params(1)
.n_runs(0)
.fit(&DatasetBase::from(array![[0.]]))
.is_err(),
"n_runs must be strictly positive"
);
}
#[test]
fn test_invalid_tolerance() {
assert!(
GaussianMixtureModel::params(1)
.tolerance(0.)
.fit(&DatasetBase::from(array![[0.]]))
.is_err(),
"tolerance must be strictly positive"
);
}
#[test]
fn test_invalid_n_clusters() {
assert!(
GaussianMixtureModel::params(0)
.fit(&DatasetBase::from(array![[0., 0.]]))
.is_err(),
"n_clusters must be strictly positive"
);
}
#[test]
fn test_invalid_reg_covariance() {
assert!(
GaussianMixtureModel::params(1)
.reg_covariance(-1e-6)
.fit(&DatasetBase::from(array![[0.]]))
.is_err(),
"reg_covariance must be positive"
);
}
#[test]
fn test_invalid_max_n_iterations() {
assert!(
GaussianMixtureModel::params(1)
.max_n_iterations(0)
.fit(&DatasetBase::from(array![[0.]]))
.is_err(),
"max_n_iterations must be stricly positive"
);
}
    // Compile-time helper: accepts anything implementing the Fit trait.
    fn fittable<T: Fit<Array2<f64>, (), GmmError>>(_: T) {}
    /// Checks that GMM params built with a non-seedable `ThreadRng` still
    /// satisfy the `Fit` trait bound.
    #[test]
    fn thread_rng_fittable() {
        fittable(GaussianMixtureModel::params_with_rng(
            1,
            ThreadRng::default(),
        ));
    }
}