use crate::gaussian_mixture::errors::{GmmError, Result};
use crate::gaussian_mixture::hyperparameters::{GmmCovarType, GmmHyperParams, GmmInitMethod};
use crate::k_means::KMeans;
use linfa::{
dataset::{Dataset, Targets},
traits::*,
Float,
};
use ndarray::{s, Array, Array1, Array2, Array3, ArrayBase, Axis, Data, Ix2, Ix3, Zip};
use ndarray_linalg::{cholesky::*, triangular::*, Lapack, Scalar};
use ndarray_rand::rand::Rng;
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;
use ndarray_stats::QuantileExt;
use rand_isaac::Isaac64Rng;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
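/// Gaussian mixture model fitted by expectation-maximization (EM).
///
/// For each of the `n_clusters` components the model stores its mixing
/// weight, mean, full covariance matrix, precision matrix (the inverse
/// covariance) and the Cholesky factor of the precision matrix; the
/// log-likelihood computations below work on the Cholesky factors.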
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serdecrate")
)]
#[derive(Debug, PartialEq)]
pub struct GaussianMixtureModel<F: Float> {
covar_type: GmmCovarType,
weights: Array1<F>,
means: Array2<F>,
covariances: Array3<F>,
precisions: Array3<F>,
precisions_chol: Array3<F>,
}
impl<F: Float> Clone for GaussianMixtureModel<F> {
fn clone(&self) -> Self {
Self {
covar_type: self.covar_type,
weights: self.weights.to_owned(),
means: self.means.to_owned(),
covariances: self.covariances.to_owned(),
precisions: self.precisions.to_owned(),
precisions_chol: self.precisions_chol.to_owned(),
}
}
}
impl<F: Float + Lapack + Scalar> GaussianMixtureModel<F> {
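    /// Returns the set of tunable hyperparameters for a mixture of
    /// `n_clusters` components, seeded with the default `Isaac64Rng`.
    ///
    /// A minimal usage sketch, mirroring `test_gmm_fit` below (`dataset`
    /// and `rng` are assumed to be built as in that test):
    ///
    /// ```ignore
    /// let gmm = GaussianMixtureModel::params(2)
    ///     .with_rng(rng)
    ///     .fit(&dataset)
    ///     .expect("GMM fitting");
    /// let memberships = gmm.predict(dataset.records());
    /// ```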
pub fn params(n_clusters: usize) -> GmmHyperParams<F, Isaac64Rng> {
GmmHyperParams::new(n_clusters)
}
pub fn weights(&self) -> &Array1<F> {
&self.weights
}
pub fn means(&self) -> &Array2<F> {
&self.means
}
pub fn covariances(&self) -> &Array3<F> {
&self.covariances
}
pub fn precisions(&self) -> &Array3<F> {
&self.precisions
}
pub fn centroids(&self) -> &Array2<F> {
self.means()
}
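    // Builds an initial model. Responsibilities come either from a KMeans
    // fit (a one-hot row per sample) or from row-normalized uniform random
    // values; the Gaussian parameters are then estimated from them.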
fn new<D: Data<Elem = F>, R: Rng + Clone, T: Targets>(
hyperparameters: &GmmHyperParams<F, R>,
dataset: &Dataset<ArrayBase<D, Ix2>, T>,
mut rng: R,
) -> Result<GaussianMixtureModel<F>> {
let observations = dataset.records().view();
let n_samples = observations.nrows();
let resp = match hyperparameters.init_method() {
GmmInitMethod::KMeans => {
let model = KMeans::params_with_rng(hyperparameters.n_clusters(), rng)
.build()
.fit(&dataset)?;
let mut resp = Array::<F, Ix2>::zeros((n_samples, hyperparameters.n_clusters()));
                for (n, idx) in model.predict(dataset.records()).iter().enumerate() {
                    resp[[n, *idx]] = F::from(1.).unwrap();
}
resp
}
GmmInitMethod::Random => {
let mut resp = Array2::<f64>::random_using(
(n_samples, hyperparameters.n_clusters()),
Uniform::new(0., 1.),
&mut rng,
);
let totals = &resp.sum_axis(Axis(1)).insert_axis(Axis(0));
resp = (resp.reversed_axes() / totals).reversed_axes();
resp.mapv(|v| F::from(v).unwrap())
}
};
let (mut weights, means, covariances) = Self::estimate_gaussian_parameters(
&observations,
&resp,
hyperparameters.covariance_type(),
hyperparameters.reg_covariance(),
)?;
weights /= F::from(n_samples).unwrap();
let precisions_chol = Self::compute_precisions_cholesky_full(&covariances)?;
let precisions = Self::compute_precisions_full(&precisions_chol);
Ok(GaussianMixtureModel {
covar_type: *hyperparameters.covariance_type(),
weights,
means,
covariances,
precisions,
precisions_chol,
})
}
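    // Estimates the unnormalized weights `nk`, the means and the covariances
    // from the responsibilities: nk[k] = Σ_i resp[i, k] and
    // means[k] = Σ_i resp[i, k] x_i / nk[k]. Errors out if some cluster has
    // effectively no responsibility mass left.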
fn estimate_gaussian_parameters<D: Data<Elem = F>>(
observations: &ArrayBase<D, Ix2>,
resp: &Array2<F>,
_covar_type: &GmmCovarType,
reg_covar: F,
) -> Result<(Array1<F>, Array2<F>, Array3<F>)> {
let nk = resp.sum_axis(Axis(0));
if nk.min().unwrap() < &(F::from(10.).unwrap() * F::epsilon()) {
return Err(GmmError::EmptyCluster(format!(
"Cluster #{} has no more point. Consider decreasing number of clusters or change initialization.",
nk.argmin().unwrap() + 1
)));
}
let nk2 = nk.to_owned().insert_axis(Axis(1));
let means = resp.t().dot(observations) / nk2;
let covariances =
Self::estimate_gaussian_covariances_full(&observations, resp, &nk, &means, reg_covar);
Ok((nk, means, covariances))
}
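    // Full covariance of each component:
    // Σ_k = Σ_i resp[i, k] (x_i - μ_k)(x_i - μ_k)^T / nk[k] + reg_covar · I.
    // The reg_covar term on the diagonal keeps Σ_k positive definite.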
fn estimate_gaussian_covariances_full<D: Data<Elem = F>>(
observations: &ArrayBase<D, Ix2>,
resp: &Array2<F>,
nk: &Array1<F>,
means: &Array2<F>,
reg_covar: F,
) -> Array3<F> {
let n_clusters = means.nrows();
let n_features = means.ncols();
let mut covariances = Array::zeros((n_clusters, n_features, n_features));
for k in 0..n_clusters {
let diff = observations - &means.row(k);
let m = &diff.t() * &resp.index_axis(Axis(1), k);
let mut cov_k = m.dot(&diff) / nk[k];
cov_k.diag_mut().mapv_inplace(|x| x + reg_covar);
covariances.slice_mut(s![k, .., ..]).assign(&cov_k);
}
covariances
}
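    // With Σ_k = L_k L_k^T (lower Cholesky factor), the triangular solve
    // L_k X = I yields X = L_k^{-1}; its transpose L_k^{-T} is stored so
    // that L_k^{-T} L_k^{-1} = Σ_k^{-1}. Fails with a Cholesky error if a
    // covariance matrix is not positive definite.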
fn compute_precisions_cholesky_full<D: Data<Elem = F>>(
covariances: &ArrayBase<D, Ix3>,
) -> Result<Array3<F>> {
let n_clusters = covariances.shape()[0];
let n_features = covariances.shape()[1];
let mut precisions_chol = Array::zeros((n_clusters, n_features, n_features));
for (k, covariance) in covariances.outer_iter().enumerate() {
let cov_chol = covariance.cholesky(UPLO::Lower)?;
let sol =
cov_chol.solve_triangular(UPLO::Lower, Diag::NonUnit, &Array::eye(n_features))?;
precisions_chol.slice_mut(s![k, .., ..]).assign(&sol.t());
}
Ok(precisions_chol)
}
fn compute_precisions_full<D: Data<Elem = F>>(
precisions_chol: &ArrayBase<D, Ix3>,
) -> Array3<F> {
let mut precisions = Array3::zeros(precisions_chol.dim());
for (k, prec_chol) in precisions_chol.outer_iter().enumerate() {
precisions
.slice_mut(s![k, .., ..])
.assign(&prec_chol.dot(&prec_chol.t()));
}
precisions
}
fn refresh_precisions_full(&mut self) {
self.precisions = Self::compute_precisions_full(&self.precisions_chol);
}
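    // E-step: returns the mean per-sample log-likelihood (the lower bound
    // tracked by the EM loop) and the log-responsibilities.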
fn e_step<D: Data<Elem = F>>(
&self,
observations: &ArrayBase<D, Ix2>,
) -> Result<(F, Array2<F>)> {
let (log_prob_norm, log_resp) = self.estimate_log_prob_resp(&observations);
let log_mean = log_prob_norm.mean().unwrap();
Ok((log_mean, log_resp))
}
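    // M-step: re-estimates weights, means and covariances from the
    // responsibilities and refreshes the precision Cholesky factors.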
fn m_step<D: Data<Elem = F>>(
&mut self,
reg_covar: F,
observations: &ArrayBase<D, Ix2>,
log_resp: &Array2<F>,
) -> Result<()> {
let n_samples = observations.nrows();
let (weights, means, covariances) = Self::estimate_gaussian_parameters(
&observations,
&log_resp.mapv(|v| v.exp()),
&self.covar_type,
reg_covar,
)?;
self.means = means;
self.weights = weights / F::from(n_samples).unwrap();
self.precisions_chol = Self::compute_precisions_cholesky_full(&covariances)?;
Ok(())
}
fn compute_lower_bound<D: Data<Elem = F>>(
_log_resp: &ArrayBase<D, Ix2>,
log_prob_norm: F,
) -> F {
log_prob_norm
}
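    // Computes per-sample ln P(x_i) = ln Σ_k exp(ln w_k + ln N(x_i | μ_k, Σ_k))
    // and the log-responsibilities obtained by subtracting it from the
    // weighted log-probabilities. Note the normalizer uses a direct exp/ln
    // rather than a max-shifted log-sum-exp.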
fn estimate_log_prob_resp<D: Data<Elem = F>>(
&self,
observations: &ArrayBase<D, Ix2>,
) -> (Array1<F>, Array2<F>) {
let weighted_log_prob = self.estimate_weighted_log_prob(&observations);
let log_prob_norm = weighted_log_prob
.mapv(|v| v.exp())
.sum_axis(Axis(1))
.mapv(|v| v.ln());
let log_resp = weighted_log_prob - log_prob_norm.to_owned().insert_axis(Axis(1));
(log_prob_norm, log_resp)
}
fn estimate_weighted_log_prob<D: Data<Elem = F>>(
&self,
observations: &ArrayBase<D, Ix2>,
) -> Array2<F> {
self.estimate_log_prob(&observations) + self.estimate_log_weights()
}
fn estimate_log_prob<D: Data<Elem = F>>(&self, observations: &ArrayBase<D, Ix2>) -> Array2<F> {
self.estimate_log_gaussian_prob(&observations)
}
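    // Log-density of every sample under every component:
    // ln N(x | μ_k, Σ_k) = -0.5 (d ln(2π) + ‖(x - μ_k) P_k‖²) + ln|det P_k|,
    // where P_k is the stored precision Cholesky factor and d = n_features.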
fn estimate_log_gaussian_prob<D: Data<Elem = F>>(
&self,
observations: &ArrayBase<D, Ix2>,
) -> Array2<F> {
let n_samples = observations.nrows();
let n_features = observations.ncols();
let means = self.means();
let n_clusters = means.nrows();
let log_det = Self::compute_log_det_cholesky_full(&self.precisions_chol, n_features);
let mut log_prob: Array2<F> = Array::zeros((n_samples, n_clusters));
Zip::indexed(means.genrows())
.and(self.precisions_chol.outer_iter())
.apply(|k, mu, prec_chol| {
let diff = (&observations.to_owned() - &mu).dot(&prec_chol);
log_prob
.slice_mut(s![.., k])
.assign(&diff.mapv(|v| v * v).sum_axis(Axis(1)))
});
log_prob.mapv(|v| {
F::from(-0.5).unwrap()
* (v + F::from(n_features as f64 * f64::ln(2. * std::f64::consts::PI)).unwrap())
}) + log_det
}
fn compute_log_det_cholesky_full<D: Data<Elem = F>>(
matrix_chol: &ArrayBase<D, Ix3>,
n_features: usize,
) -> Array1<F> {
let n_clusters = matrix_chol.shape()[0];
let log_diags = &matrix_chol
.to_owned()
.into_shape((n_clusters, n_features * n_features))
.unwrap()
.slice(s![.., ..; n_features+1])
.to_owned()
.mapv(|v| v.ln());
log_diags.sum_axis(Axis(1))
}
fn estimate_log_weights(&self) -> Array1<F> {
self.weights().mapv(|v| v.ln())
}
}
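// EM fitting: runs `n_runs` passes of at most `max_n_iterations` E/M
// iterations each, stops a pass early once the change in the lower bound
// drops below `tolerance`, and keeps the model with the best lower bound.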
impl<'a, F: Float + Lapack + Scalar, R: Rng + Clone, D: Data<Elem = F>, T: Targets>
Fit<'a, ArrayBase<D, Ix2>, T> for GmmHyperParams<F, R>
{
type Object = Result<GaussianMixtureModel<F>>;
fn fit(&self, dataset: &Dataset<ArrayBase<D, Ix2>, T>) -> Self::Object {
self.validate()?;
let observations = dataset.records().view();
let mut gmm = GaussianMixtureModel::<F>::new(self, dataset, self.rng())?;
let mut max_lower_bound = -F::infinity();
let mut best_params = None;
let mut best_iter = None;
let n_runs = self.n_runs();
for _ in 0..n_runs {
let mut lower_bound = -F::infinity();
let mut converged_iter: Option<u64> = None;
for n_iter in 0..self.max_n_iterations() {
let prev_lower_bound = lower_bound;
let (log_prob_norm, log_resp) = gmm.e_step(&observations)?;
gmm.m_step(self.reg_covariance(), &observations, &log_resp)?;
lower_bound =
GaussianMixtureModel::<F>::compute_lower_bound(&log_resp, log_prob_norm);
let change = lower_bound - prev_lower_bound;
if change.abs() < self.tolerance() {
converged_iter = Some(n_iter);
break;
}
}
if lower_bound > max_lower_bound {
max_lower_bound = lower_bound;
gmm.refresh_precisions_full();
best_params = Some(gmm.clone());
best_iter = converged_iter;
}
}
match best_iter {
Some(_n_iter) => match best_params {
Some(gmm) => Ok(gmm),
_ => Err(GmmError::LowerBoundError(
"No lower bound improvement (-inf)".to_string(),
)),
},
            None => Err(GmmError::NotConverged(format!(
                "EM fitting algorithm did not converge in {} run(s). Try different \
                init parameters, increase max_n_iterations or tolerance, or check \
                for degenerate data.",
                n_runs
            ))),
}
}
}
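// Hard cluster assignment: each observation is mapped to the component
// with the highest responsibility (argmax of the posterior probabilities).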
impl<F: Float + Lapack + Scalar, D: Data<Elem = F>> Predict<&ArrayBase<D, Ix2>, Array1<usize>>
for GaussianMixtureModel<F>
{
fn predict(&self, observations: &ArrayBase<D, Ix2>) -> Array1<usize> {
let (_, log_resp) = self.estimate_log_prob_resp(&observations);
log_resp
.mapv(|v| v.exp())
.map_axis(Axis(1), |row| row.argmax().unwrap())
}
}
impl<F: Float + Lapack + Scalar, D: Data<Elem = F>, T: Targets>
Predict<Dataset<ArrayBase<D, Ix2>, T>, Dataset<ArrayBase<D, Ix2>, Array1<usize>>>
for GaussianMixtureModel<F>
{
fn predict(
&self,
dataset: Dataset<ArrayBase<D, Ix2>, T>,
) -> Dataset<ArrayBase<D, Ix2>, Array1<usize>> {
let predicted = self.predict(dataset.records());
dataset.with_targets(predicted)
}
}
#[cfg(test)]
mod tests {
extern crate openblas_src;
use super::*;
use crate::generate_blobs;
use approx::assert_abs_diff_eq;
use ndarray::{array, stack, ArrayView1, ArrayView2, Axis};
use ndarray_linalg::error::LinalgError;
use ndarray_linalg::error::Result as LAResult;
use ndarray_rand::rand::SeedableRng;
use ndarray_rand::rand_distr::{Distribution, StandardNormal};
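    // Test helper: samples from N(mean, covariance) via x = mean + L·z,
    // where L is the lower Cholesky factor of the covariance and z is
    // drawn from a standard normal.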
pub struct MultivariateNormal {
pub mean: Array1<f64>,
pub covariance: Array2<f64>,
lower: Array2<f64>,
}
impl MultivariateNormal {
pub fn new(mean: &ArrayView1<f64>, covariance: &ArrayView2<f64>) -> LAResult<Self> {
            let lower = covariance.cholesky(UPLO::Lower)?;
Ok(MultivariateNormal {
mean: mean.to_owned(),
covariance: covariance.to_owned(),
lower,
})
}
}
impl Distribution<Array1<f64>> for MultivariateNormal {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Array1<f64> {
let res = Array1::random_using(self.mean.shape()[0], StandardNormal, rng);
self.mean.clone() + self.lower.view().dot(&res)
}
}
#[test]
fn test_gmm_fit() {
let mut rng = Isaac64Rng::seed_from_u64(42);
let weights = array![0.5, 0.5];
let means = array![[0., 0.], [5., 5.]];
let covars = array![[[1., 0.8], [0.8, 1.]], [[1.0, -0.6], [-0.6, 1.0]]];
let mvn1 =
MultivariateNormal::new(&means.slice(s![0, ..]), &covars.slice(s![0, .., ..])).unwrap();
let mvn2 =
MultivariateNormal::new(&means.slice(s![1, ..]), &covars.slice(s![1, .., ..])).unwrap();
let n = 500;
let mut observations = Array2::zeros((2 * n, means.ncols()));
for (i, mut row) in observations.genrows_mut().into_iter().enumerate() {
let sample = if i < n {
mvn1.sample(&mut rng)
} else {
mvn2.sample(&mut rng)
};
row.assign(&sample);
}
let dataset = Dataset::from(observations);
let gmm = GaussianMixtureModel::params(2)
.with_rng(rng)
.fit(&dataset)
.expect("GMM fitting");
let w = gmm.weights();
assert_abs_diff_eq!(w, &weights, epsilon = 1e-1);
assert_abs_diff_eq!(gmm.means(), &means, epsilon = 1e-1);
assert_abs_diff_eq!(gmm.covariances(), &covars, epsilon = 1e-1);
}
fn function_test_1d(x: &Array2<f64>) -> Array2<f64> {
let mut y = Array2::zeros(x.dim());
Zip::from(&mut y).and(x).apply(|yi, &xi| {
if xi < 0.4 {
*yi = xi * xi;
} else if xi >= 0.4 && xi < 0.8 {
*yi = 3. * xi + 1.;
} else {
*yi = f64::sin(10. * xi);
}
});
y
}
#[test]
fn test_zeroed_reg_covar_failure() {
let mut rng = Isaac64Rng::seed_from_u64(42);
let xt = Array2::random_using((50, 1), Uniform::new(0., 1.), &mut rng);
let yt = function_test_1d(&xt);
let data = stack(Axis(1), &[xt.view(), yt.view()]).unwrap();
let dataset = Dataset::from(data);
let gmm = GaussianMixtureModel::params(3)
.with_reg_covariance(0.)
.with_rng(rng.clone())
.fit(&dataset);
assert!(
match gmm.expect_err("should generate an error with reg_covar being nul") {
GmmError::LinalgError(e) => match e {
LinalgError::Lapack { return_code: 2 } => true,
_ => panic!("should be a lapack error 2"),
},
_ => panic!("should be a linear algebra error"),
}
);
assert!(GaussianMixtureModel::params(3)
.with_rng(rng)
.fit(&dataset)
.is_ok());
}
#[test]
fn test_zeroed_reg_covar_const_failure() {
let xt = Array2::ones((50, 1));
let data = stack(Axis(1), &[xt.view(), xt.view()]).unwrap();
let dataset = Dataset::from(data);
let gmm = GaussianMixtureModel::params(1)
.with_reg_covariance(0.)
.fit(&dataset);
assert!(
match gmm.expect_err("should generate an error with reg_covar being nul") {
GmmError::LinalgError(e) => match e {
LinalgError::Lapack { return_code: 1 } => true,
_ => panic!("should be a lapack error 1"),
},
_ => panic!("should be a linear algebra error"),
}
);
assert!(GaussianMixtureModel::params(1).fit(&dataset).is_ok());
}
#[test]
fn test_centroids_prediction() {
let mut rng = Isaac64Rng::seed_from_u64(42);
let expected_centroids = array![[0., 1.], [-10., 20.], [-1., 10.]];
let n = 1000;
let blobs = Dataset::from(generate_blobs(n, &expected_centroids, &mut rng));
let n_clusters = expected_centroids.len_of(Axis(0));
let gmm = GaussianMixtureModel::params(n_clusters)
.with_rng(rng)
.fit(&blobs)
.expect("GMM fitting");
let gmm_centroids = gmm.centroids();
let memberships = gmm.predict(&expected_centroids);
for (i, expected_c) in expected_centroids.outer_iter().enumerate() {
let closest_c = gmm_centroids.index_axis(Axis(0), memberships[i]);
Zip::from(&closest_c)
.and(&expected_c)
.apply(|a, b| assert_abs_diff_eq!(a, b, epsilon = 1.))
}
}
#[test]
fn test_invalid_n_runs() {
assert!(
GaussianMixtureModel::params(1)
.with_n_runs(0)
.fit(&Dataset::from(array![[0.]]))
.is_err(),
"n_runs must be strictly positive"
);
}
#[test]
fn test_invalid_tolerance() {
assert!(
GaussianMixtureModel::params(1)
.with_tolerance(0.)
.fit(&Dataset::from(array![[0.]]))
.is_err(),
"tolerance must be strictly positive"
);
}
#[test]
fn test_invalid_n_clusters() {
assert!(
GaussianMixtureModel::params(0)
.fit(&Dataset::from(array![[0., 0.]]))
.is_err(),
"n_clusters must be strictly positive"
);
}
#[test]
fn test_invalid_reg_covariance() {
assert!(
GaussianMixtureModel::params(1)
.with_reg_covariance(-1e-6)
.fit(&Dataset::from(array![[0.]]))
.is_err(),
"reg_covariance must be positive"
);
}
#[test]
fn test_invalid_max_n_iterations() {
assert!(
GaussianMixtureModel::params(1)
.with_max_n_iterations(0)
.fit(&Dataset::from(array![[0.]]))
.is_err(),
"max_n_iterations must be stricly positive"
);
}
}