rusty_machine/learning/
gmm.rs

1//! Gaussian Mixture Models
2//!
3//! Provides implementation of GMMs using the EM algorithm.
4//!
5//! # Usage
6//!
7//! ```
8//! use rusty_machine::linalg::Matrix;
9//! use rusty_machine::learning::gmm::{CovOption, GaussianMixtureModel};
10//! use rusty_machine::learning::UnSupModel;
11//!
12//! let inputs = Matrix::new(4, 2, vec![1.0, 2.0, -3.0, -3.0, 0.1, 1.5, -5.0, -2.5]);
13//! let test_inputs = Matrix::new(3, 2, vec![1.0, 2.0, 3.0, 2.9, -4.4, -2.5]);
14//!
15//! // Create gmm with k(=2) classes.
16//! let mut model = GaussianMixtureModel::new(2);
17//! model.set_max_iters(10);
18//! model.cov_option = CovOption::Diagonal;
19//!
20//! // Where inputs is a Matrix with features in columns.
21//! model.train(&inputs).unwrap();
22//!
23//! // Print the means and covariances of the GMM
24//! println!("{:?}", model.means());
25//! println!("{:?}", model.covariances());
26//!
27//! // Where test_inputs is a Matrix with features in columns.
28//! let post_probs = model.predict(&test_inputs).unwrap();
29//!
30//! // Probabilities that each point comes from each Gaussian.
31//! println!("{:?}", post_probs.data());
32//! ```
33use linalg::{Matrix, MatrixSlice, Vector, BaseMatrix, BaseMatrixMut, Axes};
34use rulinalg::utils;
35
36use learning::{LearningResult, UnSupModel};
37use learning::toolkit::rand_utils;
38use learning::error::{Error, ErrorKind};
39
/// Covariance options for GMMs.
///
/// Selects the structure of the covariance matrix estimated for every
/// mixture component:
///
/// - Full : The full covariance structure.
/// - Regularized : Adds a regularization constant to the covariance diagonal.
/// - Diagonal : Only the diagonal covariance structure.
///
/// The option can be compared with `==`, which makes it possible to
/// inspect or assert on the configuration of a model.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum CovOption {
    /// The full covariance structure.
    Full,
    /// Adds a regularization constant to the covariance diagonal.
    ///
    /// The wrapped value is the constant that is added.
    Regularized(f64),
    /// Only the diagonal covariance structure.
    Diagonal,
}
54
55
56/// A Gaussian Mixture Model
57#[derive(Debug)]
58pub struct GaussianMixtureModel {
59    comp_count: usize,
60    mix_weights: Vector<f64>,
61    model_means: Option<Matrix<f64>>,
62    model_covars: Option<Vec<Matrix<f64>>>,
63    log_lik: f64,
64    max_iters: usize,
65    /// The covariance options for the GMM.
66    pub cov_option: CovOption,
67}
68
69impl UnSupModel<Matrix<f64>, Matrix<f64>> for GaussianMixtureModel {
70    /// Train the model using inputs.
71    fn train(&mut self, inputs: &Matrix<f64>) -> LearningResult<()> {
72        let reg_value = if inputs.rows() > 1 {
73            1f64 / (inputs.rows() - 1) as f64
74        } else {
75            return Err(Error::new(ErrorKind::InvalidData, "Only one row of data provided."));
76        };
77
78        // Initialization:
79        let k = self.comp_count;
80
81        self.model_covars = {
82            let cov_mat = try!(self.initialize_covariances(inputs, reg_value));
83            Some(vec![cov_mat; k])
84        };
85
86        let random_rows: Vec<usize> =
87            rand_utils::reservoir_sample(&(0..inputs.rows()).collect::<Vec<usize>>(), k);
88        self.model_means = Some(inputs.select_rows(&random_rows));
89
90        for _ in 0..self.max_iters {
91            let log_lik_0 = self.log_lik;
92
93            let (weights, log_lik_1) = try!(self.membership_weights(inputs));
94
95            if (log_lik_1 - log_lik_0).abs() < 1e-15 {
96                break;
97            }
98
99            self.log_lik = log_lik_1;
100
101            self.update_params(inputs, weights);
102        }
103
104        Ok(())
105    }
106
107    /// Predict output from inputs.
108    fn predict(&self, inputs: &Matrix<f64>) -> LearningResult<Matrix<f64>> {
109        if let (&Some(_), &Some(_)) = (&self.model_means, &self.model_covars) {
110            Ok(try!(self.membership_weights(inputs)).0)
111        } else {
112            Err(Error::new_untrained())
113        }
114
115    }
116}
117
118impl GaussianMixtureModel {
119    /// Constructs a new Gaussian Mixture Model
120    ///
121    /// Defaults to 100 maximum iterations and
122    /// full covariance structure.
123    ///
124    /// # Examples
125    /// ```
126    /// use rusty_machine::learning::gmm::GaussianMixtureModel;
127    ///
128    /// let gmm = GaussianMixtureModel::new(3);
129    /// ```
130    pub fn new(k: usize) -> GaussianMixtureModel {
131        GaussianMixtureModel {
132            comp_count: k,
133            mix_weights: Vector::ones(k) / (k as f64),
134            model_means: None,
135            model_covars: None,
136            log_lik: 0f64,
137            max_iters: 100,
138            cov_option: CovOption::Full,
139        }
140    }
141
142    /// Constructs a new GMM with the specified prior mixture weights.
143    ///
144    /// The mixture weights must have the same length as the number of components.
145    /// Each element of the mixture weights must be non-negative.
146    ///
147    /// # Examples
148    ///
149    /// ```
150    /// use rusty_machine::learning::gmm::GaussianMixtureModel;
151    /// use rusty_machine::linalg::Vector;
152    ///
153    /// let mix_weights = Vector::new(vec![0.25, 0.25, 0.5]);
154    ///
155    /// let gmm = GaussianMixtureModel::with_weights(3, mix_weights).unwrap();
156    /// ```
157    ///
158    /// # Failures
159    ///
160    /// Fails if either of the following conditions are met:
161    ///
162    /// - Mixture weights do not have length k.
163    /// - Mixture weights have a negative entry.
164    pub fn with_weights(k: usize, mixture_weights: Vector<f64>) -> LearningResult<GaussianMixtureModel> {
165        if mixture_weights.size() != k {
166            Err(Error::new(ErrorKind::InvalidParameters, "Mixture weights must have length k."))
167        } else if mixture_weights.data().iter().any(|&x| x < 0f64) {
168            Err(Error::new(ErrorKind::InvalidParameters, "Mixture weights must have only non-negative entries.")) 
169        } else {
170            let sum = mixture_weights.sum();
171            let normalized_weights = mixture_weights / sum;
172
173            Ok(GaussianMixtureModel {
174                comp_count: k,
175                mix_weights: normalized_weights,
176                model_means: None,
177                model_covars: None,
178                log_lik: 0f64,
179                max_iters: 100,
180                cov_option: CovOption::Full,
181            })
182        }
183    }
184
185    /// The model means
186    ///
187    /// Returns an Option<&Matrix<f64>> containing
188    /// the model means. Each row represents
189    /// the mean of one of the Gaussians.
190    pub fn means(&self) -> Option<&Matrix<f64>> {
191        self.model_means.as_ref()
192    }
193
194    /// The model covariances
195    ///
196    /// Returns an Option<&Vec<Matrix<f64>>> containing
197    /// the model covariances. Each Matrix in the vector
198    /// is the covariance of one of the Gaussians.
199    pub fn covariances(&self) -> Option<&Vec<Matrix<f64>>> {
200        self.model_covars.as_ref()
201    }
202
203    /// The model mixture weights
204    ///
205    /// Returns a reference to the model mixture weights.
206    /// These are the weighted contributions of each underlying
207    /// Gaussian to the model distribution.
208    pub fn mixture_weights(&self) -> &Vector<f64> {
209        &self.mix_weights
210    }
211
212    /// Sets the max number of iterations for the EM algorithm.
213    ///
214    /// # Examples
215    ///
216    /// ```
217    /// use rusty_machine::learning::gmm::GaussianMixtureModel;
218    ///
219    /// let mut gmm = GaussianMixtureModel::new(2);
220    /// gmm.set_max_iters(5);
221    /// ```
222    pub fn set_max_iters(&mut self, iters: usize) {
223        self.max_iters = iters;
224    }
225
226    fn initialize_covariances(&self, inputs: &Matrix<f64>, reg_value: f64) -> LearningResult<Matrix<f64>> {
227        match self.cov_option {
228            CovOption::Diagonal => {
229                let variance = try!(inputs.variance(Axes::Row));
230                Ok(Matrix::from_diag(&variance.data()) * reg_value.sqrt())
231            }
232
233            CovOption::Full | CovOption::Regularized(_) => {
234                let means = inputs.mean(Axes::Row);
235                let mut cov_mat = Matrix::zeros(inputs.cols(), inputs.cols());
236                for (j, row) in cov_mat.iter_rows_mut().enumerate() {
237                    for (k, elem) in row.iter_mut().enumerate() {
238                        *elem = inputs.iter_rows().map(|r| {
239                            (r[j] - means[j]) * (r[k] - means[k])
240                        }).sum::<f64>();
241                    }
242                }
243                cov_mat *= reg_value;
244                if let CovOption::Regularized(eps) = self.cov_option {
245                    cov_mat += Matrix::<f64>::identity(cov_mat.cols()) * eps;
246                }
247                Ok(cov_mat)
248            }
249        }
250    }
251
252    fn membership_weights(&self, inputs: &Matrix<f64>) -> LearningResult<(Matrix<f64>, f64)> {
253        let n = inputs.rows();
254
255        let mut member_weights_data = Vec::with_capacity(n * self.comp_count);
256
257        // We compute the determinants and inverses now
258        let mut cov_sqrt_dets = Vec::with_capacity(self.comp_count);
259        let mut cov_invs = Vec::with_capacity(self.comp_count);
260
261        if let Some(ref covars) = self.model_covars {
262            for cov in covars {
263                // TODO: combine these. We compute det to get the inverse.
264                let covar_det = cov.det();
265                let covar_inv = try!(cov.inverse().map_err(Error::from));
266
267                cov_sqrt_dets.push(covar_det.sqrt());
268                cov_invs.push(covar_inv);
269            }
270        }
271
272        let mut log_lik = 0f64;
273
274        // Now we compute the membership weights
275        if let Some(ref means) = self.model_means {
276            for i in 0..n {
277                let mut pdfs = Vec::with_capacity(self.comp_count);
278                let x_i = MatrixSlice::from_matrix(inputs, [i, 0], 1, inputs.cols());
279
280                for j in 0..self.comp_count {
281                    let mu_j = MatrixSlice::from_matrix(means, [j, 0], 1, means.cols());
282                    let diff = x_i - mu_j;
283
284                    let pdf = (&diff * &cov_invs[j] * diff.transpose() * -0.5).into_vec()[0]
285                        .exp() / cov_sqrt_dets[j];
286                    pdfs.push(pdf);
287                }
288
289                let weighted_pdf_sum = utils::dot(&pdfs, self.mix_weights.data());
290
291                for (idx, pdf) in pdfs.iter().enumerate() {
292                    member_weights_data.push(self.mix_weights[idx] * pdf / (weighted_pdf_sum));
293                }
294
295                log_lik += weighted_pdf_sum.ln();
296            }
297        }
298
299        Ok((Matrix::new(n, self.comp_count, member_weights_data), log_lik))
300    }
301
302    fn update_params(&mut self, inputs: &Matrix<f64>, membership_weights: Matrix<f64>) {
303        let n = membership_weights.rows();
304        let d = inputs.cols();
305
306        let sum_weights = membership_weights.sum_rows();
307
308        self.mix_weights = &sum_weights / (n as f64);
309
310        let mut new_means = membership_weights.transpose() * inputs;
311
312        for (mean, w) in new_means.iter_rows_mut().zip(sum_weights.data().iter()) {
313            for m in mean.iter_mut() {
314                *m /= *w;
315            }
316        }
317
318        let mut new_covs = Vec::with_capacity(self.comp_count);
319
320        for k in 0..self.comp_count {
321            let mut cov_mat = Matrix::zeros(d, d);
322            let new_means_k = MatrixSlice::from_matrix(&new_means, [k, 0], 1, d);
323
324            for i in 0..n {
325                let inputs_i = MatrixSlice::from_matrix(inputs, [i, 0], 1, d);
326                let diff = inputs_i - new_means_k;
327                cov_mat += self.compute_cov(diff, membership_weights[[i, k]]);
328            }
329
330            if let CovOption::Regularized(eps) = self.cov_option {
331                cov_mat += Matrix::<f64>::identity(cov_mat.cols()) * eps;
332            }
333
334            new_covs.push(cov_mat / sum_weights[k]);
335
336        }
337
338        self.model_means = Some(new_means);
339        self.model_covars = Some(new_covs);
340    }
341
342    fn compute_cov(&self, diff: Matrix<f64>, weight: f64) -> Matrix<f64> {
343        match self.cov_option {
344            CovOption::Full | CovOption::Regularized(_) => (diff.transpose() * diff) * weight,
345            CovOption::Diagonal => Matrix::from_diag(&diff.elemul(&diff).into_vec()) * weight,
346        }
347    }
348}
349
#[cfg(test)]
mod tests {
    use super::GaussianMixtureModel;
    use linalg::Vector;

    /// A freshly constructed model exposes no means.
    #[test]
    fn test_means_none() {
        let untrained = GaussianMixtureModel::new(5);

        assert!(untrained.means().is_none());
    }

    /// A freshly constructed model exposes no covariances.
    #[test]
    fn test_covars_none() {
        let untrained = GaussianMixtureModel::new(5);

        assert!(untrained.covariances().is_none());
    }

    /// A negative mixture weight is rejected.
    #[test]
    fn test_negative_mixtures() {
        let weights = Vector::new(vec![-0.25, 0.75, 0.5]);

        assert!(GaussianMixtureModel::with_weights(3, weights).is_err());
    }

    /// A weight vector whose length differs from k is rejected.
    #[test]
    fn test_wrong_length_mixtures() {
        let weights = Vector::new(vec![0.1, 0.25, 0.75, 0.5]);

        assert!(GaussianMixtureModel::with_weights(3, weights).is_err());
    }
}