use crate::{ClusteringError, Result};
use ndarray::{Array1, Array2, ArrayView2};
#[derive(Debug, Clone, Copy)]
pub enum CovarianceType {
Full,
Tied,
Diag,
Spherical,
}
#[derive(Debug, Clone, Copy)]
pub enum Criterion {
BIC,
AIC,
}
pub struct GaussianMixtureBuilder {
n_components: usize,
covariance_type: CovarianceType,
em_tolerance: f64,
max_iter: usize,
n_init: usize,
}
impl GaussianMixtureBuilder {
pub fn new(n_components: usize) -> Self {
Self {
n_components,
covariance_type: CovarianceType::Full,
em_tolerance: 1e-3,
max_iter: 100,
n_init: 5,
}
}
pub fn covariance_type(mut self, cov_type: CovarianceType) -> Self {
self.covariance_type = cov_type;
self
}
pub fn em_tolerance(mut self, tol: f64) -> Self {
self.em_tolerance = tol;
self
}
pub fn max_iter(mut self, max_iter: usize) -> Self {
self.max_iter = max_iter;
self
}
pub fn n_init(mut self, n_init: usize) -> Self {
self.n_init = n_init;
self
}
pub fn build(self) -> GaussianMixture {
GaussianMixture {
n_components: self.n_components,
covariance_type: self.covariance_type,
em_tolerance: self.em_tolerance,
max_iter: self.max_iter,
n_init: self.n_init,
}
}
}
pub struct GaussianMixture {
n_components: usize,
covariance_type: CovarianceType,
em_tolerance: f64,
max_iter: usize,
n_init: usize,
}
impl GaussianMixture {
pub fn builder(n_components: usize) -> GaussianMixtureBuilder {
GaussianMixtureBuilder::new(n_components)
}
pub fn new(n_components: usize) -> Self {
Self::builder(n_components).build()
}
pub fn fit(&mut self, data: &ArrayView2<f64>) -> Result<GMMResult> {
unimplemented!("GaussianMixture::fit")
}
pub fn predict_proba(&self, data: &ArrayView2<f64>) -> Result<Array2<f64>> {
unimplemented!("GaussianMixture::predict_proba")
}
pub fn score_samples(&self, data: &ArrayView2<f64>) -> Result<Array1<f64>> {
unimplemented!("GaussianMixture::score_samples")
}
}
pub struct GMMResult {
pub means: Array2<f64>,
pub covariances: Vec<Array2<f64>>,
pub weights: Array1<f64>,
pub converged: bool,
pub n_iter: usize,
}