use crate::gaussian_mixture::errors::GmmError;
use crate::gaussian_mixture::hyperparams::{
    GmmCovarType, GmmInitMethod, GmmParams, GmmValidParams,
};
use crate::k_means::KMeans;
use linfa::{prelude::*, DatasetBase, Float};
use linfa_linalg::{cholesky::*, triangular::*};
use ndarray::{s, Array, Array1, Array2, Array3, ArrayBase, Axis, Data, Ix2, Ix3, Zip};
use ndarray_rand::rand::Rng;
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;
use ndarray_stats::QuantileExt;
use rand_xoshiro::Xoshiro256Plus;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

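/// Gaussian Mixture Model (GMM) fitted with the Expectation-Maximization (EM)
/// algorithm. The model assumes the observations are drawn from a mixture of
/// Gaussian components; fitting estimates the mixture weights together with
/// the mean and full covariance matrix of every component.
///
/// A minimal usage sketch (marked `ignore` since it is illustrative only; the
/// toy data are an assumption, the calls mirror the tests at the bottom of
/// this module):
///
/// ```rust,ignore
/// use linfa::prelude::*;
/// use ndarray::array;
///
/// // Two well-separated groups of 2-D points.
/// let observations = array![[0.0, 0.0], [0.1, -0.1], [5.0, 5.0], [5.1, 4.9]];
/// let dataset = DatasetBase::from(observations);
///
/// // Fit a two-component mixture with default hyperparameters.
/// let gmm = GaussianMixtureModel::params(2)
///     .fit(&dataset)
///     .expect("GMM fitting");
///
/// // Hard assignments for new points (argmax of the responsibilities).
/// let memberships = gmm.predict(&array![[0.0, 0.1], [5.0, 5.0]]);
/// ```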
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, PartialEq)]
pub struct GaussianMixtureModel<F: Float> {
    covar_type: GmmCovarType,
    weights: Array1<F>,
    means: Array2<F>,
    covariances: Array3<F>,
    precisions: Array3<F>,
    precisions_chol: Array3<F>,
}

impl<F: Float> Clone for GaussianMixtureModel<F> {
    fn clone(&self) -> Self {
        Self {
            covar_type: self.covar_type,
            weights: self.weights.to_owned(),
            means: self.means.to_owned(),
            covariances: self.covariances.to_owned(),
            precisions: self.precisions.to_owned(),
            precisions_chol: self.precisions_chol.to_owned(),
        }
    }
}

impl<F: Float> GaussianMixtureModel<F> {
    fn new<D: Data<Elem = F>, R: Rng + Clone, T>(
        hyperparameters: &GmmValidParams<F, R>,
        dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
        mut rng: R,
    ) -> Result<GaussianMixtureModel<F>, GmmError> {
        let observations = dataset.records().view();
        let n_samples = observations.nrows();

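        // Initialize the responsibilities: either one-hot assignments taken
        // from a k-means fit, or uniform random values normalized so that each
        // row sums to one.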
        let resp = match hyperparameters.init_method() {
            GmmInitMethod::KMeans => {
                let model = KMeans::params_with_rng(hyperparameters.n_clusters(), rng)
                    .check()
                    .unwrap()
                    .fit(dataset)?;
                let mut resp = Array::<F, Ix2>::zeros((n_samples, hyperparameters.n_clusters()));
                for (k, idx) in model.predict(dataset.records()).iter().enumerate() {
                    resp[[k, *idx]] = F::cast(1.);
                }
                resp
            }
            GmmInitMethod::Random => {
                let mut resp = Array2::<f64>::random_using(
                    (n_samples, hyperparameters.n_clusters()),
                    Uniform::new(0., 1.),
                    &mut rng,
                );
                let totals = &resp.sum_axis(Axis(1)).insert_axis(Axis(0));
                resp = (resp.reversed_axes() / totals).reversed_axes();
                resp.mapv(F::cast)
            }
        };

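        // Derive the initial mixture parameters from the responsibilities; the
        // summed responsibilities become weights after dividing by the number
        // of samples.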
        let (mut weights, means, covariances) = Self::estimate_gaussian_parameters(
            &observations,
            &resp,
            hyperparameters.covariance_type(),
            hyperparameters.reg_covariance(),
        )?;
        weights /= F::cast(n_samples);

        let precisions_chol = Self::compute_precisions_cholesky_full(&covariances)?;
        let precisions = Self::compute_precisions_full(&precisions_chol);

        Ok(GaussianMixtureModel {
            covar_type: *hyperparameters.covariance_type(),
            weights,
            means,
            covariances,
            precisions,
            precisions_chol,
        })
    }
}

impl<F: Float> GaussianMixtureModel<F> {
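    /// Returns a hyperparameter builder for `n_clusters` components, using the
    /// default `Xoshiro256Plus` random number generator.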
    pub fn params(n_clusters: usize) -> GmmParams<F, Xoshiro256Plus> {
        GmmParams::new(n_clusters)
    }

    pub fn params_with_rng<R: Rng + Clone>(n_clusters: usize, rng: R) -> GmmParams<F, R> {
        GmmParams::new_with_rng(n_clusters, rng)
    }

    pub fn weights(&self) -> &Array1<F> {
        &self.weights
    }

    pub fn means(&self) -> &Array2<F> {
        &self.means
    }

    pub fn covariances(&self) -> &Array3<F> {
        &self.covariances
    }

    pub fn precisions(&self) -> &Array3<F> {
        &self.precisions
    }

    pub fn centroids(&self) -> &Array2<F> {
        self.means()
    }

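    // Given responsibilities `resp` (one row per sample, one column per
    // cluster), compute the summed responsibilities `nk`, the means and the
    // full covariance matrices. Fails if a cluster has (numerically) no
    // assigned points.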
    #[allow(clippy::type_complexity)]
    fn estimate_gaussian_parameters<D: Data<Elem = F>>(
        observations: &ArrayBase<D, Ix2>,
        resp: &Array2<F>,
        _covar_type: &GmmCovarType,
        reg_covar: F,
    ) -> Result<(Array1<F>, Array2<F>, Array3<F>), GmmError> {
        let nk = resp.sum_axis(Axis(0));
        if nk.min()? < &(F::cast(10.) * F::epsilon()) {
            return Err(GmmError::EmptyCluster(format!(
                "Cluster #{} has no points left. Consider decreasing the number of clusters or changing the initialization.",
                nk.argmin()? + 1
            )));
        }

        let nk2 = nk.to_owned().insert_axis(Axis(1));
        let means = resp.t().dot(observations) / nk2;
        let covariances =
            Self::estimate_gaussian_covariances_full(observations, resp, &nk, &means, reg_covar);
        Ok((nk, means, covariances))
    }

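    // Weighted covariance of cluster k:
    //   cov_k = diff_k^T * diag(resp[.., k]) * diff_k / nk[k] + reg_covar * I
    // with diff_k = observations - mean_k; the reg_covar term on the diagonal
    // guards against singular covariance matrices.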
    fn estimate_gaussian_covariances_full<D: Data<Elem = F>>(
        observations: &ArrayBase<D, Ix2>,
        resp: &Array2<F>,
        nk: &Array1<F>,
        means: &Array2<F>,
        reg_covar: F,
    ) -> Array3<F> {
        let n_clusters = means.nrows();
        let n_features = means.ncols();
        let mut covariances = Array::zeros((n_clusters, n_features, n_features));
        for k in 0..n_clusters {
            let diff = observations - &means.row(k);
            let m = &diff.t() * &resp.index_axis(Axis(1), k);
            let mut cov_k = m.dot(&diff) / nk[k];
            cov_k.diag_mut().mapv_inplace(|x| x + reg_covar);
            covariances.slice_mut(s![k, .., ..]).assign(&cov_k);
        }
        covariances
    }

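    // Compute the Cholesky factor of each precision matrix by inverting the
    // covariance Cholesky factor L: solving L * sol = I gives sol = inv(L),
    // and the precision Cholesky factor is sol^T.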
    fn compute_precisions_cholesky_full<D: Data<Elem = F>>(
        covariances: &ArrayBase<D, Ix3>,
    ) -> Result<Array3<F>, GmmError> {
        let n_clusters = covariances.shape()[0];
        let n_features = covariances.shape()[1];
        let mut precisions_chol = Array::zeros((n_clusters, n_features, n_features));
        for (k, covariance) in covariances.outer_iter().enumerate() {
            let sol = {
                let decomp = covariance.cholesky()?;
                decomp.solve_triangular_into(Array::eye(n_features), UPLO::Lower)?
            };

            precisions_chol.slice_mut(s![k, .., ..]).assign(&sol.t());
        }
        Ok(precisions_chol)
    }

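    // Reassemble each precision matrix from its Cholesky factor:
    // precision_k = prec_chol_k * prec_chol_k^T.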
    fn compute_precisions_full<D: Data<Elem = F>>(
        precisions_chol: &ArrayBase<D, Ix3>,
    ) -> Array3<F> {
        let mut precisions = Array3::zeros(precisions_chol.dim());
        for (k, prec_chol) in precisions_chol.outer_iter().enumerate() {
            precisions
                .slice_mut(s![k, .., ..])
                .assign(&prec_chol.dot(&prec_chol.t()));
        }
        precisions
    }

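    // Recompute the cached precision matrices from the Cholesky factors;
    // called during fitting whenever the current run improves on the best
    // lower bound.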
    fn refresh_precisions_full(&mut self) {
        self.precisions = Self::compute_precisions_full(&self.precisions_chol);
    }

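    // E-step: compute the log responsibilities and the mean log-probability of
    // the observations under the current mixture.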
    fn e_step<D: Data<Elem = F>>(
        &self,
        observations: &ArrayBase<D, Ix2>,
    ) -> Result<(F, Array2<F>), GmmError> {
        let (log_prob_norm, log_resp) = self.estimate_log_prob_resp(observations);
        let log_mean = log_prob_norm.mean().unwrap();
        Ok((log_mean, log_resp))
    }

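    // M-step: re-estimate weights, means and covariances from the current
    // responsibilities and refresh the precision Cholesky factors.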
    fn m_step<D: Data<Elem = F>>(
        &mut self,
        reg_covar: F,
        observations: &ArrayBase<D, Ix2>,
        log_resp: &Array2<F>,
    ) -> Result<(), GmmError> {
        let n_samples = observations.nrows();
        let (weights, means, covariances) = Self::estimate_gaussian_parameters(
            observations,
            &log_resp.mapv(|x| x.exp()),
            &self.covar_type,
            reg_covar,
        )?;
        self.means = means;
        self.weights = weights / F::cast(n_samples);
        self.covariances = covariances;
        self.precisions_chol = Self::compute_precisions_cholesky_full(&self.covariances)?;
        Ok(())
    }

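    // With full covariances the EM lower bound is just the mean
    // log-probability returned by the E-step; the responsibilities are unused.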
    fn compute_lower_bound<D: Data<Elem = F>>(
        _log_resp: &ArrayBase<D, Ix2>,
        log_prob_norm: F,
    ) -> F {
        log_prob_norm
    }

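    // For each sample, compute log p(x) = ln(sum_k exp(weighted_log_prob_k))
    // and the log responsibilities log_resp = weighted_log_prob - log p(x).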
    fn estimate_log_prob_resp<D: Data<Elem = F>>(
        &self,
        observations: &ArrayBase<D, Ix2>,
    ) -> (Array1<F>, Array2<F>) {
        let weighted_log_prob = self.estimate_weighted_log_prob(observations);
        let log_prob_norm = weighted_log_prob
            .mapv(|x| x.exp())
            .sum_axis(Axis(1))
            .mapv(|x| x.ln());
        let log_resp = weighted_log_prob - log_prob_norm.to_owned().insert_axis(Axis(1));
        (log_prob_norm, log_resp)
    }

    fn estimate_weighted_log_prob<D: Data<Elem = F>>(
        &self,
        observations: &ArrayBase<D, Ix2>,
    ) -> Array2<F> {
        self.estimate_log_prob(observations) + self.estimate_log_weights()
    }

    fn estimate_log_prob<D: Data<Elem = F>>(&self, observations: &ArrayBase<D, Ix2>) -> Array2<F> {
        self.estimate_log_gaussian_prob(observations)
    }

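    // Log-density of each sample under each component, using the precision
    // Cholesky factors:
    //   ln N(x | mu_k, Sigma_k)
    //     = -0.5 * (d * ln(2 * pi) + ||(x - mu_k) prec_chol_k||^2) + ln|prec_chol_k|
    // where d is the number of features.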
    fn estimate_log_gaussian_prob<D: Data<Elem = F>>(
        &self,
        observations: &ArrayBase<D, Ix2>,
    ) -> Array2<F> {
        let n_samples = observations.nrows();
        let n_features = observations.ncols();
        let means = self.means();
        let n_clusters = means.nrows();
        let log_det = Self::compute_log_det_cholesky_full(&self.precisions_chol, n_features);
        let mut log_prob: Array2<F> = Array::zeros((n_samples, n_clusters));
        Zip::indexed(means.rows())
            .and(self.precisions_chol.outer_iter())
            .for_each(|k, mu, prec_chol| {
                let diff = (&observations.to_owned() - &mu).dot(&prec_chol);
                log_prob
                    .slice_mut(s![.., k])
                    .assign(&diff.mapv(|v| v * v).sum_axis(Axis(1)))
            });
        log_prob.mapv(|v| {
            F::cast(-0.5) * (v + F::cast(n_features as f64 * f64::ln(2. * std::f64::consts::PI)))
        }) + log_det
    }

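    // The log-determinant of a triangular Cholesky factor is the sum of the
    // logs of its diagonal entries; the diagonals are extracted by flattening
    // each (d, d) factor to length d * d and taking every (d + 1)-th element.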
    fn compute_log_det_cholesky_full<D: Data<Elem = F>>(
        matrix_chol: &ArrayBase<D, Ix3>,
        n_features: usize,
    ) -> Array1<F> {
        let n_clusters = matrix_chol.shape()[0];
        let log_diags = &matrix_chol
            .to_owned()
            .into_shape((n_clusters, n_features * n_features))
            .unwrap()
            .slice(s![.., ..; n_features + 1])
            .to_owned()
            .mapv(|x| x.ln());
        log_diags.sum_axis(Axis(1))
    }

    fn estimate_log_weights(&self) -> Array1<F> {
        self.weights().mapv(|x| x.ln())
    }
}

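// Fit the GMM with the EM algorithm: alternate E- and M-steps until the lower
// bound improves by less than `tolerance`, keeping the best model over
// `n_runs` restarts.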
impl<F: Float, R: Rng + Clone, D: Data<Elem = F>, T> Fit<ArrayBase<D, Ix2>, T, GmmError>
    for GmmValidParams<F, R>
{
    type Object = GaussianMixtureModel<F>;

    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object, GmmError> {
        let observations = dataset.records().view();
        let mut gmm = GaussianMixtureModel::<F>::new(self, dataset, self.rng())?;

        let mut max_lower_bound = -F::infinity();
        let mut best_params = None;
        let mut best_iter = None;

        let n_runs = self.n_runs();

        for _ in 0..n_runs {
            let mut lower_bound = -F::infinity();

            let mut converged_iter: Option<u64> = None;
            for n_iter in 0..self.max_n_iterations() {
                let prev_lower_bound = lower_bound;
                let (log_prob_norm, log_resp) = gmm.e_step(&observations)?;
                gmm.m_step(self.reg_covariance(), &observations, &log_resp)?;
                lower_bound =
                    GaussianMixtureModel::<F>::compute_lower_bound(&log_resp, log_prob_norm);
                let change = lower_bound - prev_lower_bound;
                if change.abs() < self.tolerance() {
                    converged_iter = Some(n_iter);
                    break;
                }
            }

            if lower_bound > max_lower_bound {
                max_lower_bound = lower_bound;
                gmm.refresh_precisions_full();
                best_params = Some(gmm.clone());
                best_iter = converged_iter;
            }
        }

        match best_iter {
            Some(_n_iter) => match best_params {
                Some(gmm) => Ok(gmm),
                _ => Err(GmmError::LowerBoundError(
                    "No lower bound improvement (-inf)".to_string(),
                )),
            },
            None => Err(GmmError::NotConverged(format!(
                "EM fitting algorithm {} did not converge. Try different init parameters, \
                increase max_n_iterations or tolerance, or check for degenerate data.",
                (n_runs + 1)
            ))),
        }
    }
}

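// Predict the most likely component for each sample by taking the argmax of
// the posterior responsibilities.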
impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<usize>>
    for GaussianMixtureModel<F>
{
    fn predict_inplace(&self, observations: &ArrayBase<D, Ix2>, targets: &mut Array1<usize>) {
        assert_eq!(
            observations.nrows(),
            targets.len(),
            "The number of data points must match the number of output targets."
        );

        let (_, log_resp) = self.estimate_log_prob_resp(observations);
        *targets = log_resp
            .mapv(F::exp)
            .map_axis(Axis(1), |row| row.argmax().unwrap());
    }

    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<usize> {
        Array1::zeros(x.nrows())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::{abs_diff_eq, assert_abs_diff_eq};
    use linfa_datasets::generate;
    use linfa_linalg::LinalgError;
    use linfa_linalg::Result as LAResult;
    use ndarray::{array, concatenate, ArrayView1, ArrayView2, Axis};
    use ndarray_rand::rand::prelude::ThreadRng;
    use ndarray_rand::rand::SeedableRng;
    use ndarray_rand::rand_distr::Normal;
    use ndarray_rand::rand_distr::{Distribution, StandardNormal};
    use ndarray_rand::RandomExt;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<GaussianMixtureModel<f64>>();
        has_autotraits::<GmmError>();
        has_autotraits::<GmmParams<f64, Xoshiro256Plus>>();
        has_autotraits::<GmmValidParams<f64, Xoshiro256Plus>>();
        has_autotraits::<GmmInitMethod>();
        has_autotraits::<GmmCovarType>();
    }

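    // Test helper: sample from a multivariate normal distribution by drawing a
    // standard normal vector z and mapping it through the Cholesky factor of
    // the covariance, x = mean + lower * z.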
    pub struct MultivariateNormal {
        mean: Array1<f64>,
        lower: Array2<f64>,
    }
    impl MultivariateNormal {
        pub fn new(mean: &ArrayView1<f64>, covariance: &ArrayView2<f64>) -> LAResult<Self> {
            let lower = covariance.cholesky()?;
            Ok(MultivariateNormal {
                mean: mean.to_owned(),
                lower,
            })
        }
    }
    impl Distribution<Array1<f64>> for MultivariateNormal {
        fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Array1<f64> {
            let res = Array1::random_using(self.mean.shape()[0], StandardNormal, rng);
            self.mean.clone() + self.lower.view().dot(&res)
        }
    }

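    // Draw samples from two known Gaussians and check that the fitted mixture
    // recovers the weights, means and covariances (up to the arbitrary order
    // of the components).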
    #[test]
    fn test_gmm_fit() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let weights = array![0.5, 0.5];
        let means = array![[0., 0.], [5., 5.]];
        let covars = array![[[1., 0.8], [0.8, 1.]], [[1.0, -0.6], [-0.6, 1.0]]];
        let mvn1 =
            MultivariateNormal::new(&means.slice(s![0, ..]), &covars.slice(s![0, .., ..])).unwrap();
        let mvn2 =
            MultivariateNormal::new(&means.slice(s![1, ..]), &covars.slice(s![1, .., ..])).unwrap();

        let n = 500;
        let mut observations = Array2::zeros((2 * n, means.ncols()));
        for (i, mut row) in observations.rows_mut().into_iter().enumerate() {
            let sample = if i < n {
                mvn1.sample(&mut rng)
            } else {
                mvn2.sample(&mut rng)
            };
            row.assign(&sample);
        }
        let dataset = DatasetBase::from(observations);
        let gmm = GaussianMixtureModel::params(2)
            .with_rng(rng)
            .fit(&dataset)
            .expect("GMM fitting");

        let w = gmm.weights();
        assert_abs_diff_eq!(w, &weights, epsilon = 1e-1);
        let m = gmm.means();
        assert!(
            abs_diff_eq!(means, &m, epsilon = 1e-1)
                || abs_diff_eq!(means, m.slice(s![..;-1, ..]), epsilon = 1e-1)
        );
        let c = gmm.covariances();
        assert!(
            abs_diff_eq!(covars, &c, epsilon = 1e-1)
                || abs_diff_eq!(covars, c.slice(s![..;-1, .., ..]), epsilon = 1e-1)
        );
    }

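    // Fit three 1-D components generated with sigma = 0.5, i.e. variance 0.25,
    // and compare the estimated covariances against hard-coded reference
    // values.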
    #[test]
    fn test_gmm_covariances() {
        let rng = rand_xoshiro::Xoshiro256Plus::seed_from_u64(123);

        let data_0 = ndarray::Array::random((500,), Normal::new(0., 0.5).unwrap());
        let data_1 = ndarray::Array::random((500,), Normal::new(1., 0.5).unwrap());
        let data_2 = ndarray::Array::random((500,), Normal::new(2., 0.5).unwrap());
        let data = ndarray::concatenate![ndarray::Axis(0), data_0, data_1, data_2];

        let data_2d = data.insert_axis(ndarray::Axis(1)).to_owned();
        let dataset = linfa::DatasetBase::from(data_2d);

        let gmm = GaussianMixtureModel::params(3)
            .n_runs(1)
            .tolerance(1e-4)
            .with_rng(rng)
            .max_n_iterations(500)
            .fit(&dataset)
            .expect("GMM fit");

        let expected = array![[[0.22564062]], [[0.26204446]], [[0.23393885]]];
        let expected = Array::from_iter(expected.iter().cloned());
        let actual = gmm.covariances();
        let actual = Array::from_iter(actual.iter().cloned());
        assert_abs_diff_eq!(expected, actual, epsilon = 1e-1);
    }

    fn function_test_1d(x: &Array2<f64>) -> Array2<f64> {
        let mut y = Array2::zeros(x.dim());
        Zip::from(&mut y).and(x).for_each(|yi, &xi| {
            if xi < 0.4 {
                *yi = xi * xi;
            } else if (0.4..0.8).contains(&xi) {
                *yi = 10. * xi + 1.;
            } else {
                *yi = f64::sin(10. * xi);
            }
        });
        y
    }

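    // With reg_covariance = 0, a degenerate cluster covariance can lose
    // positive definiteness and the Cholesky decomposition should fail.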
    #[test]
    fn test_zeroed_reg_covar_failure() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let xt = Array2::random_using((50, 1), Uniform::new(0., 1.0), &mut rng);
        let yt = function_test_1d(&xt);
        let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();
        let dataset = DatasetBase::from(data);

        let gmm = GaussianMixtureModel::params(3)
            .reg_covariance(0.)
            .with_rng(rng.clone())
            .fit(&dataset);

        match gmm.expect_err("should generate an error with reg_covar being zero") {
            GmmError::LinalgError(e) => {
                assert!(matches!(e, LinalgError::NotPositiveDefinite));
            }
            e => panic!("should be a linear algebra error: {:?}", e),
        }
        assert!(GaussianMixtureModel::params(3)
            .with_rng(rng)
            .fit(&dataset)
            .is_ok());
    }

    #[test]
    fn test_zeroed_reg_covar_const_failure() {
        let xt = Array2::ones((50, 1));
        let data = concatenate(Axis(1), &[xt.view(), xt.view()]).unwrap();
        let dataset = DatasetBase::from(data);

        let gmm = GaussianMixtureModel::params(1)
            .reg_covariance(0.)
            .fit(&dataset);

        gmm.expect_err("should generate an error with reg_covar being zero");

        assert!(GaussianMixtureModel::params(1).fit(&dataset).is_ok());
    }

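    // The fitted centroids should lie close to the generating blob centers:
    // predicting on the expected centroids must map each one to a component
    // whose mean is within epsilon of it.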
    #[test]
    fn test_centroids_prediction() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let expected_centroids = array![[0., 1.], [-10., 20.], [-1., 10.]];
        let n = 1000;
        let blobs = DatasetBase::from(generate::blobs(n, &expected_centroids, &mut rng));

        let n_clusters = expected_centroids.len_of(Axis(0));
        let gmm = GaussianMixtureModel::params(n_clusters)
            .with_rng(rng)
            .fit(&blobs)
            .expect("GMM fitting");

        let gmm_centroids = gmm.centroids();
        let memberships = gmm.predict(&expected_centroids);

        for (i, expected_c) in expected_centroids.outer_iter().enumerate() {
            let closest_c = gmm_centroids.index_axis(Axis(0), memberships[i]);
            Zip::from(&closest_c)
                .and(&expected_c)
                .for_each(|a, b| assert_abs_diff_eq!(a, b, epsilon = 1.))
        }
    }

    #[test]
    fn test_invalid_n_runs() {
        assert!(
            GaussianMixtureModel::params(1)
                .n_runs(0)
                .fit(&DatasetBase::from(array![[0.]]))
                .is_err(),
            "n_runs must be strictly positive"
        );
    }

    #[test]
    fn test_invalid_tolerance() {
        assert!(
            GaussianMixtureModel::params(1)
                .tolerance(0.)
                .fit(&DatasetBase::from(array![[0.]]))
                .is_err(),
            "tolerance must be strictly positive"
        );
    }

    #[test]
    fn test_invalid_n_clusters() {
        assert!(
            GaussianMixtureModel::params(0)
                .fit(&DatasetBase::from(array![[0., 0.]]))
                .is_err(),
            "n_clusters must be strictly positive"
        );
    }

    #[test]
    fn test_invalid_reg_covariance() {
        assert!(
            GaussianMixtureModel::params(1)
                .reg_covariance(-1e-6)
                .fit(&DatasetBase::from(array![[0.]]))
                .is_err(),
            "reg_covariance must be positive"
        );
    }

    #[test]
    fn test_invalid_max_n_iterations() {
        assert!(
            GaussianMixtureModel::params(1)
                .max_n_iterations(0)
                .fit(&DatasetBase::from(array![[0.]]))
                .is_err(),
            "max_n_iterations must be strictly positive"
        );
    }

    fn fittable<T: Fit<Array2<f64>, (), GmmError>>(_: T) {}
    #[test]
    fn thread_rng_fittable() {
        fittable(GaussianMixtureModel::params_with_rng(
            1,
            ThreadRng::default(),
        ));
    }
}