use crate::DType;
use numr::error::Result;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
/// Covariance parameterisation fitted for each mixture component.
///
/// The exact tensor layout produced for each variant is decided by the
/// `GmmAlgorithms` implementation — NOTE(review): confirm against the backend.
///
/// Defaults to [`CovarianceType::Full`]. `Hash` is derived (alongside `Eq`) so the
/// variant can be used as a key in hashed collections, e.g. when caching fits per
/// covariance family.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum CovarianceType {
    /// Full covariance (presumably one unrestricted matrix per component — confirm).
    /// This is the default variant.
    #[default]
    Full,
    /// Tied covariance (presumably one matrix shared by all components — confirm).
    Tied,
    /// Diagonal covariance (presumably per-feature variances per component — confirm).
    Diagonal,
    /// Spherical covariance (presumably a single variance per component — confirm).
    Spherical,
}
/// Strategy used to choose the initial mixture parameters before fitting.
///
/// Defaults to [`GmmInit::KMeans`]. `Hash` is derived (alongside `Eq`) so the
/// variant can be used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum GmmInit {
    /// Initialise from a k-means clustering of the data (by name — confirm against
    /// the implementation). This is the default variant.
    #[default]
    KMeans,
    /// Random initialisation (the sampling scheme is implementation-defined — confirm).
    Random,
}
/// Configuration passed to `GmmAlgorithms::gmm_fit`.
///
/// Start from `GmmOptions::default()` and override individual fields with
/// struct-update syntax, e.g. `GmmOptions { n_components: 3, ..Default::default() }`.
///
/// `PartialEq` compares every field, which lets callers compare or assert on option
/// sets. Only `PartialEq` is derived, not `Eq`: `tol` and `reg_covar` are `f64`, so
/// options holding NaN never compare equal.
#[derive(Debug, Clone, PartialEq)]
pub struct GmmOptions {
    /// Number of mixture components to fit.
    pub n_components: usize,
    /// Covariance parameterisation of the components.
    pub covariance_type: CovarianceType,
    /// Upper bound on fitting iterations (presumably per initialisation — confirm).
    pub max_iter: usize,
    /// Convergence tolerance (presumably on the change in lower bound between
    /// iterations — confirm).
    pub tol: f64,
    /// Number of initialisations to run (presumably the best-scoring fit is kept —
    /// confirm).
    pub n_init: usize,
    /// Strategy for choosing the initial parameters.
    pub init: GmmInit,
    /// Covariance regularisation (presumably added to covariance diagonals for
    /// numerical stability — confirm).
    pub reg_covar: f64,
}
impl Default for GmmOptions {
    /// Returns the baseline configuration: one component, `CovarianceType::Full`,
    /// `GmmInit::KMeans`, at most 100 iterations, `tol = 1e-3`, a single
    /// initialisation and `reg_covar = 1e-6`.
    fn default() -> Self {
        // Take the enum choices from the enums' own `#[default]` variants
        // (`Full` and `KMeans`) rather than repeating them here, so the two
        // definitions of "default" cannot drift apart.
        let covariance_type = CovarianceType::default();
        let init = GmmInit::default();

        Self {
            n_components: 1,
            covariance_type,
            max_iter: 100,
            tol: 1e-3,
            n_init: 1,
            init,
            reg_covar: 1e-6,
        }
    }
}
/// A fitted Gaussian mixture model, as returned by `GmmAlgorithms::gmm_fit`.
///
/// NOTE(review): the shapes and dtypes of the tensor fields are set by the backend
/// implementation and, presumably, by the chosen `CovarianceType`. This declaration
/// does not fix them, so confirm before relying on a particular layout.
#[derive(Debug, Clone)]
pub struct GmmModel<R: Runtime<DType = DType>> {
    /// Mixing weights of the components (presumably one per component, summing to
    /// 1 — confirm).
    pub weights: Tensor<R>,
    /// Component means (presumably `[n_components, n_features]` — confirm).
    pub means: Tensor<R>,
    /// Component covariances; the layout presumably depends on `CovarianceType` —
    /// confirm.
    pub covariances: Tensor<R>,
    /// Cholesky factors of the component precision matrices, by name; presumably
    /// cached for log-likelihood evaluation — confirm.
    pub precisions_cholesky: Tensor<R>,
    /// Whether fitting reported convergence (presumably within `max_iter` —
    /// confirm).
    pub converged: bool,
    /// Number of iterations performed during fitting (presumably for the best run
    /// when `n_init > 1` — confirm).
    pub n_iter: usize,
    /// Lower bound reached by fitting (presumably on the log-likelihood, for the
    /// best run — confirm).
    pub lower_bound: f64,
}
/// Operations for fitting and evaluating Gaussian mixture models on runtime `R`.
///
/// Every method is fallible and returns numr's `Result`.
/// NOTE(review): `data` is presumably a 2-D `[n_samples, n_features]` tensor
/// throughout — confirm against the implementations.
pub trait GmmAlgorithms<R: Runtime<DType = DType>> {
    /// Fits a mixture to `data` using the settings in `options` and returns the
    /// fitted [`GmmModel`].
    fn gmm_fit(&self, data: &Tensor<R>, options: &GmmOptions) -> Result<GmmModel<R>>;
    /// Presumably returns the most likely component index for each sample of
    /// `data` under `model` — confirm.
    fn gmm_predict(&self, model: &GmmModel<R>, data: &Tensor<R>) -> Result<Tensor<R>>;
    /// Presumably returns per-sample posterior probabilities of each component of
    /// `model` — confirm.
    fn gmm_predict_proba(&self, model: &GmmModel<R>, data: &Tensor<R>) -> Result<Tensor<R>>;
    /// Scores `data` under `model`; presumably the average per-sample
    /// log-likelihood, returned as a tensor — confirm.
    fn gmm_score(&self, model: &GmmModel<R>, data: &Tensor<R>) -> Result<Tensor<R>>;
    /// Presumably the Bayesian information criterion of `model` on `data`
    /// (lower is better), returned as a tensor — confirm.
    fn gmm_bic(&self, model: &GmmModel<R>, data: &Tensor<R>) -> Result<Tensor<R>>;
    /// Presumably the Akaike information criterion of `model` on `data`
    /// (lower is better), returned as a tensor — confirm.
    fn gmm_aic(&self, model: &GmmModel<R>, data: &Tensor<R>) -> Result<Tensor<R>>;
}