linfa_clustering/gaussian_mixture/algorithm.rs

use crate::gaussian_mixture::errors::GmmError;
use crate::gaussian_mixture::hyperparams::{
    GmmCovarType, GmmInitMethod, GmmParams, GmmValidParams,
};
use crate::k_means::KMeans;
use linfa::{prelude::*, DatasetBase, Float};
use linfa_linalg::{cholesky::*, triangular::*};
use ndarray::{s, Array, Array1, Array2, Array3, ArrayBase, Axis, Data, Ix2, Ix3, Zip};
use ndarray_rand::rand::Rng;
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;
use ndarray_stats::QuantileExt;
use rand_xoshiro::Xoshiro256Plus;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
/// Gaussian Mixture Model (GMM) aims at clustering a dataset by finding normally
/// distributed sub-datasets (hence the Gaussian Mixture name).
///
/// GMM assumes all the data points are generated from a mixture of a number K
/// of Gaussian distributions with certain parameters.
/// The expectation-maximization (EM) algorithm is used to fit the GMM to the dataset
/// by parameterizing the weight, mean, and covariance of each cluster distribution.
///
/// This implementation is a port of the [scikit-learn 0.23.2 Gaussian Mixture](https://scikit-learn.org)
/// implementation.
///
/// ## The algorithm
///
/// The general idea is to maximize the likelihood (equivalently the log-likelihood),
/// that is, to maximize the probability that the dataset is drawn from our mixture of normal distributions.
///
/// After an initialization step, which can either use a random distribution or the result
/// of the [KMeans](KMeans) algorithm (the default value of the `init_method` parameter),
/// the core EM iterative algorithm for Gaussian Mixture is a fixed-point two-step algorithm:
///
/// 1. Expectation step: compute the expectation of the likelihood of the current Gaussian mixture model wrt the dataset.
/// 2. Maximization step: update the Gaussian parameters (weights, means and covariances) to maximize the likelihood.
///
/// We stop iterating when there is no significant change in the Gaussian parameters (controlled by the `tolerance` parameter) or
/// when we reach a maximum number of iterations (controlled by the `max_n_iterations` parameter).
/// As the initialization of the algorithm is subject to randomness, several initializations are performed (controlled by
/// the `n_runs` parameter).
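///
/// As a rough sketch in ad-hoc notation (not an API of this crate), for `K` components with
/// weights `w_k`, means `mu_k` and covariances `S_k`, EM maximizes the dataset log-likelihood:
///
/// ```text
/// log L = sum_i log( sum_k w_k * N(x_i | mu_k, S_k) )
/// ```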
///
/// ## Tutorial
///
/// Let's do a walkthrough of a training-predict-save example.
///
/// ```rust
/// use linfa::DatasetBase;
/// use linfa::prelude::*;
/// use linfa_clustering::{GmmValidParams, GaussianMixtureModel};
/// use linfa_datasets::generate;
/// use ndarray::{Axis, array, s, Zip};
/// use ndarray_rand::rand::SeedableRng;
/// use rand_xoshiro::Xoshiro256Plus;
/// use approx::assert_abs_diff_eq;
///
/// let mut rng = Xoshiro256Plus::seed_from_u64(42);
/// let expected_centroids = array![[0., 1.], [-10., 20.], [-1., 10.]];
/// let n = 200;
///
/// // We generate a dataset from points normally distributed around some distant centroids.
/// let dataset = DatasetBase::from(generate::blobs(n, &expected_centroids, &mut rng));
///
/// // Our GMM is expected to have a number of clusters equal to the number of centroids
/// // used to generate the dataset.
/// let n_clusters = expected_centroids.len_of(Axis(0));
///
/// // We fit the model from the dataset setting some options
/// let gmm = GaussianMixtureModel::params(n_clusters)
///             .n_runs(10)
///             .tolerance(1e-4)
///             .with_rng(rng)
///             .fit(&dataset).expect("GMM fitting");
///
/// // Then we can get dataset membership information, targets contain **cluster indexes**
/// // corresponding to the cluster information in the list of GMM means and covariances
/// let blobs_dataset = gmm.predict(dataset);
/// let DatasetBase {
///     records: _blobs_records,
///     targets: blobs_targets,
///     ..
/// } = blobs_dataset;
/// println!("GMM means = {:?}", gmm.means());
/// println!("GMM covariances = {:?}", gmm.covariances());
/// println!("GMM membership = {:?}", blobs_targets);
///
/// // We can also get the nearest cluster for a new point
/// let new_observation = DatasetBase::from(array![[-9., 20.5]]);
/// // Predict returns the **index** of the nearest cluster
/// let dataset = gmm.predict(new_observation);
/// // We can retrieve the actual centroid of the closest cluster using `.centroids()` (an alias of `.means()`)
/// let closest_centroid = &gmm.centroids().index_axis(Axis(0), dataset.targets()[0]);
/// ```
#[derive(Debug, PartialEq)]
pub struct GaussianMixtureModel<F: Float> {
    covar_type: GmmCovarType,
    weights: Array1<F>,
    means: Array2<F>,
    covariances: Array3<F>,
    precisions: Array3<F>,
    precisions_chol: Array3<F>,
}

impl<F: Float> Clone for GaussianMixtureModel<F> {
    fn clone(&self) -> Self {
        Self {
            covar_type: self.covar_type,
            weights: self.weights.to_owned(),
            means: self.means.to_owned(),
            covariances: self.covariances.to_owned(),
            precisions: self.precisions.to_owned(),
            precisions_chol: self.precisions_chol.to_owned(),
        }
    }
}

impl<F: Float> GaussianMixtureModel<F> {
    fn new<D: Data<Elem = F>, R: Rng + Clone, T>(
        hyperparameters: &GmmValidParams<F, R>,
        dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
        mut rng: R,
    ) -> Result<GaussianMixtureModel<F>, GmmError> {
        let observations = dataset.records().view();
        let n_samples = observations.nrows();

        // We initialize the responsibilities (n_samples, n_clusters) of each cluster,
        // that is, given a sample, the probabilities of each cluster being the source.
        // Responsibilities can be initialized either from a KMeans result or randomly.
        let resp = match hyperparameters.init_method() {
            GmmInitMethod::KMeans => {
                let model = KMeans::params_with_rng(hyperparameters.n_clusters(), rng)
                    .check()
                    .unwrap()
                    .fit(dataset)?;
                let mut resp = Array::<F, Ix2>::zeros((n_samples, hyperparameters.n_clusters()));
                for (k, idx) in model.predict(dataset.records()).iter().enumerate() {
                    resp[[k, *idx]] = F::cast(1.);
                }
                resp
            }
            GmmInitMethod::Random => {
                let mut resp = Array2::<f64>::random_using(
                    (n_samples, hyperparameters.n_clusters()),
                    Uniform::new(0., 1.),
                    &mut rng,
                );
                let totals = &resp.sum_axis(Axis(1)).insert_axis(Axis(0));
                resp = (resp.reversed_axes() / totals).reversed_axes();
                resp.mapv(F::cast)
            }
        };

        // We compute an initial GMM model from the dataset and initial responsibilities
        // with respect to the covariance specification.
        let (mut weights, means, covariances) = Self::estimate_gaussian_parameters(
            &observations,
            &resp,
            hyperparameters.covariance_type(),
            hyperparameters.reg_covariance(),
        )?;
        weights /= F::cast(n_samples);

        // GmmCovarType = full
        let precisions_chol = Self::compute_precisions_cholesky_full(&covariances)?;
        let precisions = Self::compute_precisions_full(&precisions_chol);

        Ok(GaussianMixtureModel {
            covar_type: *hyperparameters.covariance_type(),
            weights,
            means,
            covariances,
            precisions,
            precisions_chol,
        })
    }
}

impl<F: Float> GaussianMixtureModel<F> {
    pub fn params(n_clusters: usize) -> GmmParams<F, Xoshiro256Plus> {
        GmmParams::new(n_clusters)
    }

    pub fn params_with_rng<R: Rng + Clone>(n_clusters: usize, rng: R) -> GmmParams<F, R> {
        GmmParams::new_with_rng(n_clusters, rng)
    }

    pub fn weights(&self) -> &Array1<F> {
        &self.weights
    }

    pub fn means(&self) -> &Array2<F> {
        &self.means
    }

    pub fn covariances(&self) -> &Array3<F> {
        &self.covariances
    }

    pub fn precisions(&self) -> &Array3<F> {
        &self.precisions
    }

    pub fn centroids(&self) -> &Array2<F> {
        self.means()
    }

    #[allow(clippy::type_complexity)]
    fn estimate_gaussian_parameters<D: Data<Elem = F>>(
        observations: &ArrayBase<D, Ix2>,
        resp: &Array2<F>,
        _covar_type: &GmmCovarType,
        reg_covar: F,
    ) -> Result<(Array1<F>, Array2<F>, Array3<F>), GmmError> {
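        // `nk` is the per-cluster soft count, i.e. the sum of responsibilities over all samples.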
        let nk = resp.sum_axis(Axis(0));
        if nk.min()? < &(F::cast(10.) * F::epsilon()) {
            return Err(GmmError::EmptyCluster(format!(
                "Cluster #{} has no more points. Consider decreasing the number of clusters or changing the initialization.",
                nk.argmin()? + 1
            )));
        }

        let nk2 = nk.to_owned().insert_axis(Axis(1));
        let means = resp.t().dot(observations) / nk2;
        // GmmCovarType = Full
        let covariances =
            Self::estimate_gaussian_covariances_full(observations, resp, &nk, &means, reg_covar);
        Ok((nk, means, covariances))
    }

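    // Computes, for each cluster k, the responsibility-weighted covariance of the observations
    // centered on the cluster mean, adding `reg_covar` to the diagonal for numerical stability.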
    fn estimate_gaussian_covariances_full<D: Data<Elem = F>>(
        observations: &ArrayBase<D, Ix2>,
        resp: &Array2<F>,
        nk: &Array1<F>,
        means: &Array2<F>,
        reg_covar: F,
    ) -> Array3<F> {
        let n_clusters = means.nrows();
        let n_features = means.ncols();
        let mut covariances = Array::zeros((n_clusters, n_features, n_features));
        for k in 0..n_clusters {
            let diff = observations - &means.row(k);
            let m = &diff.t() * &resp.index_axis(Axis(1), k);
            let mut cov_k = m.dot(&diff) / nk[k];
            cov_k.diag_mut().mapv_inplace(|x| x + reg_covar);
            covariances.slice_mut(s![k, .., ..]).assign(&cov_k);
        }
        covariances
    }

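    // For each cluster, factor the covariance as L.L^t (Cholesky), solve L.X = I to get L^-1
    // and store its transpose, so that the precision can later be recovered as L^-t.L^-1.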
    fn compute_precisions_cholesky_full<D: Data<Elem = F>>(
        covariances: &ArrayBase<D, Ix3>,
    ) -> Result<Array3<F>, GmmError> {
        let n_clusters = covariances.shape()[0];
        let n_features = covariances.shape()[1];
        let mut precisions_chol = Array::zeros((n_clusters, n_features, n_features));
        for (k, covariance) in covariances.outer_iter().enumerate() {
            let sol = {
                let decomp = covariance.cholesky()?;
                decomp.solve_triangular_into(Array::eye(n_features), UPLO::Lower)?
            };

            precisions_chol.slice_mut(s![k, .., ..]).assign(&sol.t());
        }
        Ok(precisions_chol)
    }

    fn compute_precisions_full<D: Data<Elem = F>>(
        precisions_chol: &ArrayBase<D, Ix3>,
    ) -> Array3<F> {
        let mut precisions = Array3::zeros(precisions_chol.dim());
        for (k, prec_chol) in precisions_chol.outer_iter().enumerate() {
            precisions
                .slice_mut(s![k, .., ..])
                .assign(&prec_chol.dot(&prec_chol.t()));
        }
        precisions
    }

    // Refresh precisions value only at the end of the fitting procedure
    fn refresh_precisions_full(&mut self) {
        self.precisions = Self::compute_precisions_full(&self.precisions_chol);
    }

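    // Expectation step: computes the mean per-sample log-likelihood of the current model
    // and the log responsibilities for each sample.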
    fn e_step<D: Data<Elem = F>>(
        &self,
        observations: &ArrayBase<D, Ix2>,
    ) -> Result<(F, Array2<F>), GmmError> {
        let (log_prob_norm, log_resp) = self.estimate_log_prob_resp(observations);
        let log_mean = log_prob_norm.mean().unwrap();
        Ok((log_mean, log_resp))
    }

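    // Maximization step: re-estimates weights, means and covariances (and the derived
    // precision Cholesky factors) from the current responsibilities.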
    fn m_step<D: Data<Elem = F>>(
        &mut self,
        reg_covar: F,
        observations: &ArrayBase<D, Ix2>,
        log_resp: &Array2<F>,
    ) -> Result<(), GmmError> {
        let n_samples = observations.nrows();
        let (weights, means, covariances) = Self::estimate_gaussian_parameters(
            observations,
            &log_resp.mapv(|x| x.exp()),
            &self.covar_type,
            reg_covar,
        )?;
        self.means = means;
        self.weights = weights / F::cast(n_samples);
        self.covariances = covariances;
        // GmmCovarType = Full
        self.precisions_chol = Self::compute_precisions_cholesky_full(&self.covariances)?;
        Ok(())
    }

    // We keep method names and boundaries from the scikit-learn implementation,
    // which also handles Bayesian mixtures, hence the `_log_resp` argument below which is not used here.
    fn compute_lower_bound<D: Data<Elem = F>>(
        _log_resp: &ArrayBase<D, Ix2>,
        log_prob_norm: F,
    ) -> F {
        log_prob_norm
    }

    // Estimate log probabilities (log P(X)) and responsibilities for each sample.
    // Compute weighted log probabilities per component (log P(X)) and responsibilities
    // for each sample in X with respect to the current state of the model.
    fn estimate_log_prob_resp<D: Data<Elem = F>>(
        &self,
        observations: &ArrayBase<D, Ix2>,
    ) -> (Array1<F>, Array2<F>) {
        let weighted_log_prob = self.estimate_weighted_log_prob(observations);
        let log_prob_norm = weighted_log_prob
            .mapv(|x| x.exp())
            .sum_axis(Axis(1))
            .mapv(|x| x.ln());
        let log_resp = weighted_log_prob - log_prob_norm.to_owned().insert_axis(Axis(1));
        (log_prob_norm, log_resp)
    }

    // Estimate the weighted log probabilities of each sample wrt the model
    fn estimate_weighted_log_prob<D: Data<Elem = F>>(
        &self,
        observations: &ArrayBase<D, Ix2>,
    ) -> Array2<F> {
        self.estimate_log_prob(observations) + self.estimate_log_weights()
    }

    // Compute the log probabilities of each sample wrt the model, which is Gaussian
    fn estimate_log_prob<D: Data<Elem = F>>(&self, observations: &ArrayBase<D, Ix2>) -> Array2<F> {
        self.estimate_log_gaussian_prob(observations)
    }

    // Compute the log likelihood in the case of Gaussian probabilities
    // log(P(X|Mean, Precision)) = -0.5*(d*ln(2*PI) - ln(det(Precision)) + (X-Mean)^t.Precision.(X-Mean))
    fn estimate_log_gaussian_prob<D: Data<Elem = F>>(
        &self,
        observations: &ArrayBase<D, Ix2>,
    ) -> Array2<F> {
        let n_samples = observations.nrows();
        let n_features = observations.ncols();
        let means = self.means();
        let n_clusters = means.nrows();
        // GmmCovarType = full
        // log(det(precision_chol)) is half of log(det(precision))
        let log_det = Self::compute_log_det_cholesky_full(&self.precisions_chol, n_features);
        let mut log_prob: Array2<F> = Array::zeros((n_samples, n_clusters));
        Zip::indexed(means.rows())
            .and(self.precisions_chol.outer_iter())
            .for_each(|k, mu, prec_chol| {
                let diff = (&observations.to_owned() - &mu).dot(&prec_chol);
                log_prob
                    .slice_mut(s![.., k])
                    .assign(&diff.mapv(|v| v * v).sum_axis(Axis(1)))
            });
        log_prob.mapv(|v| {
            F::cast(-0.5) * (v + F::cast(n_features as f64 * f64::ln(2. * std::f64::consts::PI)))
        }) + log_det
    }

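    // Sum the log of the diagonal entries of each Cholesky factor: the diagonal is extracted by
    // flattening each (n_features, n_features) matrix and striding with step n_features + 1.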
    fn compute_log_det_cholesky_full<D: Data<Elem = F>>(
        matrix_chol: &ArrayBase<D, Ix3>,
        n_features: usize,
    ) -> Array1<F> {
        let n_clusters = matrix_chol.shape()[0];
        let log_diags = &matrix_chol
            .to_owned()
            .into_shape((n_clusters, n_features * n_features))
            .unwrap()
            .slice(s![.., ..; n_features+1])
            .to_owned()
            .mapv(|x| x.ln());
        log_diags.sum_axis(Axis(1))
    }

    fn estimate_log_weights(&self) -> Array1<F> {
        self.weights().mapv(|x| x.ln())
    }
}

impl<F: Float, R: Rng + Clone, D: Data<Elem = F>, T> Fit<ArrayBase<D, Ix2>, T, GmmError>
    for GmmValidParams<F, R>
{
    type Object = GaussianMixtureModel<F>;

    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object, GmmError> {
        let observations = dataset.records().view();
        let mut gmm = GaussianMixtureModel::<F>::new(self, dataset, self.rng())?;

        let mut max_lower_bound = -F::infinity();
        let mut best_params = None;
        let mut best_iter = None;

        let n_runs = self.n_runs();

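        // Run the EM loop `n_runs` times and keep the parameters reaching the best lower bound.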
        for _ in 0..n_runs {
            let mut lower_bound = -F::infinity();

            let mut converged_iter: Option<u64> = None;
            for n_iter in 0..self.max_n_iterations() {
                let prev_lower_bound = lower_bound;
                let (log_prob_norm, log_resp) = gmm.e_step(&observations)?;
                gmm.m_step(self.reg_covariance(), &observations, &log_resp)?;
                lower_bound =
                    GaussianMixtureModel::<F>::compute_lower_bound(&log_resp, log_prob_norm);
                let change = lower_bound - prev_lower_bound;
                if change.abs() < self.tolerance() {
                    converged_iter = Some(n_iter);
                    break;
                }
            }

            if lower_bound > max_lower_bound {
                max_lower_bound = lower_bound;
                gmm.refresh_precisions_full();
                best_params = Some(gmm.clone());
                best_iter = converged_iter;
            }
        }

        match best_iter {
            Some(_n_iter) => match best_params {
                Some(gmm) => Ok(gmm),
                _ => Err(GmmError::LowerBoundError(
                    "No lower bound improvement (-inf)".to_string(),
                )),
            },
            None => Err(GmmError::NotConverged(format!(
                "EM fitting algorithm {} did not converge. Try different init parameters, \
                            increase max_n_iterations or tolerance, or check for degenerate data.",
                (n_runs + 1)
            ))),
        }
    }
}

impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<usize>>
    for GaussianMixtureModel<F>
{
    fn predict_inplace(&self, observations: &ArrayBase<D, Ix2>, targets: &mut Array1<usize>) {
        assert_eq!(
            observations.nrows(),
            targets.len(),
            "The number of data points must match the number of output targets."
        );

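        // Assign each observation to the cluster with the highest responsibility.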
        let (_, log_resp) = self.estimate_log_prob_resp(observations);
        *targets = log_resp
            .mapv(F::exp)
            .map_axis(Axis(1), |row| row.argmax().unwrap());
    }

    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<usize> {
        Array1::zeros(x.nrows())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::{abs_diff_eq, assert_abs_diff_eq};
    use linfa_datasets::generate;
    use linfa_linalg::LinalgError;
    use linfa_linalg::Result as LAResult;
    use ndarray::{array, concatenate, ArrayView1, ArrayView2, Axis};
    use ndarray_rand::rand::prelude::ThreadRng;
    use ndarray_rand::rand::SeedableRng;
    use ndarray_rand::rand_distr::Normal;
    use ndarray_rand::rand_distr::{Distribution, StandardNormal};
    use ndarray_rand::RandomExt;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<GaussianMixtureModel<f64>>();
        has_autotraits::<GmmError>();
        has_autotraits::<GmmParams<f64, Xoshiro256Plus>>();
        has_autotraits::<GmmValidParams<f64, Xoshiro256Plus>>();
        has_autotraits::<GmmInitMethod>();
        has_autotraits::<GmmCovarType>();
    }

    pub struct MultivariateNormal {
        mean: Array1<f64>,
        /// Lower triangular matrix (Cholesky decomposition of the covariance matrix)
        lower: Array2<f64>,
    }
    impl MultivariateNormal {
        pub fn new(mean: &ArrayView1<f64>, covariance: &ArrayView2<f64>) -> LAResult<Self> {
            let lower = covariance.cholesky()?;
            Ok(MultivariateNormal {
                mean: mean.to_owned(),
                lower,
            })
        }
    }
    impl Distribution<Array1<f64>> for MultivariateNormal {
        fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Array1<f64> {
            // standard normal distribution
            let res = Array1::random_using(self.mean.shape()[0], StandardNormal, rng);
            // use the Cholesky decomposition to obtain a sample of our general multivariate normal
            self.mean.clone() + self.lower.view().dot(&res)
        }
    }

    #[test]
    fn test_gmm_fit() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let weights = array![0.5, 0.5];
        let means = array![[0., 0.], [5., 5.]];
        let covars = array![[[1., 0.8], [0.8, 1.]], [[1.0, -0.6], [-0.6, 1.0]]];
        let mvn1 =
            MultivariateNormal::new(&means.slice(s![0, ..]), &covars.slice(s![0, .., ..])).unwrap();
        let mvn2 =
            MultivariateNormal::new(&means.slice(s![1, ..]), &covars.slice(s![1, .., ..])).unwrap();

        let n = 500;
        let mut observations = Array2::zeros((2 * n, means.ncols()));
        for (i, mut row) in observations.rows_mut().into_iter().enumerate() {
            let sample = if i < n {
                mvn1.sample(&mut rng)
            } else {
                mvn2.sample(&mut rng)
            };
            row.assign(&sample);
        }
        let dataset = DatasetBase::from(observations);
        let gmm = GaussianMixtureModel::params(2)
            .with_rng(rng)
            .fit(&dataset)
            .expect("GMM fitting");

        // check weights
        let w = gmm.weights();
        assert_abs_diff_eq!(w, &weights, epsilon = 1e-1);
        // check means (since kmeans centroids are ordered randomly, we try matching both orderings)
        let m = gmm.means();
        assert!(
            abs_diff_eq!(means, &m, epsilon = 1e-1)
                || abs_diff_eq!(means, m.slice(s![..;-1, ..]), epsilon = 1e-1)
        );
        // check covariances
        let c = gmm.covariances();
        assert!(
            abs_diff_eq!(covars, &c, epsilon = 1e-1)
                || abs_diff_eq!(covars, c.slice(s![..;-1, .., ..]), epsilon = 1e-1)
        );
    }

    #[test]
    fn test_gmm_covariances() {
        let rng = rand_xoshiro::Xoshiro256Plus::seed_from_u64(123);

        let data_0 = ndarray::Array::random((500,), Normal::new(0., 0.5).unwrap());
        let data_1 = ndarray::Array::random((500,), Normal::new(1., 0.5).unwrap());
        let data_2 = ndarray::Array::random((500,), Normal::new(2., 0.5).unwrap());
        let data = ndarray::concatenate![ndarray::Axis(0), data_0, data_1, data_2];

        let data_2d = data.insert_axis(ndarray::Axis(1)).to_owned();
        let dataset = linfa::DatasetBase::from(data_2d);

        let gmm = GaussianMixtureModel::params(3)
            .n_runs(1)
            .tolerance(1e-4)
            .with_rng(rng)
            .max_n_iterations(500)
            .fit(&dataset)
            .expect("GMM fit");

        // expected results from scikit-learn 1.3.1
        let expected = array![[[0.22564062]], [[0.26204446]], [[0.23393885]]];
        let expected = Array::from_iter(expected.iter().cloned());
        let actual = gmm.covariances();
        let actual = Array::from_iter(actual.iter().cloned());
        assert_abs_diff_eq!(expected, actual, epsilon = 1e-1);
    }

    fn function_test_1d(x: &Array2<f64>) -> Array2<f64> {
        let mut y = Array2::zeros(x.dim());
        Zip::from(&mut y).and(x).for_each(|yi, &xi| {
            if xi < 0.4 {
                *yi = xi * xi;
            } else if (0.4..0.8).contains(&xi) {
                *yi = 10. * xi + 1.;
            } else {
                *yi = f64::sin(10. * xi);
            }
        });
        y
    }

    #[test]
    fn test_zeroed_reg_covar_failure() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let xt = Array2::random_using((50, 1), Uniform::new(0., 1.0), &mut rng);
        let yt = function_test_1d(&xt);
        let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();
        let dataset = DatasetBase::from(data);

        // Test that the Cholesky decomposition fails when reg_covariance is zero
        let gmm = GaussianMixtureModel::params(3)
            .reg_covariance(0.)
            .with_rng(rng.clone())
            .fit(&dataset);

        match gmm.expect_err("should generate an error when reg_covar is zero") {
            GmmError::LinalgError(e) => {
                assert!(matches!(e, LinalgError::NotPositiveDefinite));
            }
            e => panic!("should be a linear algebra error: {:?}", e),
        }
        // Test that it passes when the default value is used
        assert!(GaussianMixtureModel::params(3)
            .with_rng(rng)
            .fit(&dataset)
            .is_ok());
    }

    #[test]
    fn test_zeroed_reg_covar_const_failure() {
        // repeat values such that the covariance is zero
        let xt = Array2::ones((50, 1));
        let data = concatenate(Axis(1), &[xt.view(), xt.view()]).unwrap();
        let dataset = DatasetBase::from(data);

        // Test that the Cholesky decomposition fails when reg_covariance is zero
        let gmm = GaussianMixtureModel::params(1)
            .reg_covariance(0.)
            .fit(&dataset);

        gmm.expect_err("should generate an error when reg_covar is zero");

        // Test that it passes when the default value is used
        assert!(GaussianMixtureModel::params(1).fit(&dataset).is_ok());
    }

    #[test]
    fn test_centroids_prediction() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let expected_centroids = array![[0., 1.], [-10., 20.], [-1., 10.]];
        let n = 1000;
        let blobs = DatasetBase::from(generate::blobs(n, &expected_centroids, &mut rng));

        let n_clusters = expected_centroids.len_of(Axis(0));
        let gmm = GaussianMixtureModel::params(n_clusters)
            .with_rng(rng)
            .fit(&blobs)
            .expect("GMM fitting");

        let gmm_centroids = gmm.centroids();
        let memberships = gmm.predict(&expected_centroids);

        // check that the centroids used to generate the test dataset belong to the right predicted cluster
        for (i, expected_c) in expected_centroids.outer_iter().enumerate() {
            let closest_c = gmm_centroids.index_axis(Axis(0), memberships[i]);
            Zip::from(&closest_c)
                .and(&expected_c)
                .for_each(|a, b| assert_abs_diff_eq!(a, b, epsilon = 1.))
        }
    }

    #[test]
    fn test_invalid_n_runs() {
        assert!(
            GaussianMixtureModel::params(1)
                .n_runs(0)
                .fit(&DatasetBase::from(array![[0.]]))
                .is_err(),
            "n_runs must be strictly positive"
        );
    }

    #[test]
    fn test_invalid_tolerance() {
        assert!(
            GaussianMixtureModel::params(1)
                .tolerance(0.)
                .fit(&DatasetBase::from(array![[0.]]))
                .is_err(),
            "tolerance must be strictly positive"
        );
    }

    #[test]
    fn test_invalid_n_clusters() {
        assert!(
            GaussianMixtureModel::params(0)
                .fit(&DatasetBase::from(array![[0., 0.]]))
                .is_err(),
            "n_clusters must be strictly positive"
        );
    }

    #[test]
    fn test_invalid_reg_covariance() {
        assert!(
            GaussianMixtureModel::params(1)
                .reg_covariance(-1e-6)
                .fit(&DatasetBase::from(array![[0.]]))
                .is_err(),
            "reg_covariance must be positive"
        );
    }

    #[test]
    fn test_invalid_max_n_iterations() {
        assert!(
            GaussianMixtureModel::params(1)
                .max_n_iterations(0)
                .fit(&DatasetBase::from(array![[0.]]))
                .is_err(),
            "max_n_iterations must be strictly positive"
        );
    }

    fn fittable<T: Fit<Array2<f64>, (), GmmError>>(_: T) {}
    #[test]
    fn thread_rng_fittable() {
        fittable(GaussianMixtureModel::params_with_rng(
            1,
            ThreadRng::default(),
        ));
    }
}