use crate::DType;
use numr::error::Result;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
use super::gmm::{CovarianceType, GmmInit};
/// Kind of prior placed on the mixture weights, selected through
/// [`BayesianGmmOptions::weight_concentration_prior_type`].
///
/// The variant names mirror scikit-learn's `weight_concentration_prior_type`
/// values (`"dirichlet_process"` / `"dirichlet_distribution"`); presumably the
/// fitting code interprets them the same way — TODO confirm against the
/// implementation.
///
/// `Hash` is derived alongside `Eq` so the prior kind can be used as a map/set
/// key (e.g. when caching results per configuration).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum WeightConcentrationPrior {
    /// Dirichlet-process prior. This is the `Default` variant.
    /// NOTE(review): in scikit-learn this is the stick-breaking (infinite
    /// mixture) formulation — assumed to match here, TODO confirm.
    #[default]
    DirichletProcess,
    /// Finite Dirichlet-distribution prior over the mixture weights.
    DirichletDistribution,
}
/// Configuration for [`BayesianGmmAlgorithms::bayesian_gmm_fit`].
///
/// Field names closely follow scikit-learn's `BayesianGaussianMixture`
/// parameters; baseline values come from the `Default` impl. Any semantics
/// noted below beyond what the field types show are assumptions to be confirmed
/// against the fitting implementation.
#[derive(Debug, Clone)]
pub struct BayesianGmmOptions {
    /// Number of mixture components.
    /// NOTE(review): in scikit-learn this is an upper bound (unused components get
    /// near-zero weight) — presumably the same here, TODO confirm.
    pub n_components: usize,
    /// Covariance parameterisation, shared with the non-Bayesian GMM (`super::gmm`).
    pub covariance_type: CovarianceType,
    /// Iteration cap; presumably applied per initialisation — TODO confirm.
    pub max_iter: usize,
    /// Convergence tolerance; presumably compared against the change in the lower
    /// bound between iterations — TODO confirm.
    pub tol: f64,
    /// Number of initialisations to run; presumably the best-scoring run is kept —
    /// TODO confirm.
    pub n_init: usize,
    /// Initialisation strategy, shared with the non-Bayesian GMM (`super::gmm`).
    pub init: GmmInit,
    /// Covariance regularisation; presumably added to the covariance diagonal for
    /// numerical stability — TODO confirm.
    pub reg_covar: f64,
    /// Which kind of prior is placed on the mixture weights.
    pub weight_concentration_prior_type: WeightConcentrationPrior,
    /// Concentration parameter of the weight prior.
    /// NOTE(review): `None` presumably means "choose a default" (scikit-learn uses
    /// `1 / n_components`) — TODO confirm.
    pub weight_concentration_prior: Option<f64>,
    /// Precision prior on the component means.
    /// NOTE(review): `None` presumably means "choose a default" (scikit-learn uses
    /// 1) — TODO confirm.
    pub mean_precision_prior: Option<f64>,
    /// Degrees-of-freedom prior of the (Wishart) precision prior.
    /// NOTE(review): `None` presumably means "choose a default" (scikit-learn uses
    /// `n_features`) — TODO confirm.
    pub degrees_of_freedom_prior: Option<f64>,
}
impl Default for BayesianGmmOptions {
fn default() -> Self {
Self {
n_components: 1,
covariance_type: CovarianceType::Full,
max_iter: 100,
tol: 1e-3,
n_init: 1,
init: GmmInit::KMeans,
reg_covar: 1e-6,
weight_concentration_prior_type: WeightConcentrationPrior::DirichletProcess,
weight_concentration_prior: None,
mean_precision_prior: None,
degrees_of_freedom_prior: None,
}
}
}
/// Result of [`BayesianGmmAlgorithms::bayesian_gmm_fit`]: fitted parameters,
/// variational posterior quantities, and convergence diagnostics.
///
/// NOTE(review): tensor shapes are not visible here. The comments below assume
/// the scikit-learn layout (`K` = components, `D` = features) — TODO confirm.
#[derive(Debug, Clone)]
pub struct BayesianGmmModel<R: Runtime<DType = DType>> {
    /// Mixture weights, one per component (assumed shape `[K]`).
    pub weights: Tensor<R>,
    /// Component means (assumed shape `[K, D]`).
    pub means: Tensor<R>,
    /// Component covariances. The layout presumably depends on
    /// [`BayesianGmmOptions::covariance_type`] — TODO confirm.
    pub covariances: Tensor<R>,
    /// Cholesky factors of the component precision matrices. The layout is
    /// presumably tied to the covariance type — TODO confirm.
    pub precisions_cholesky: Tensor<R>,
    /// Posterior parameters of the weight-concentration prior.
    /// NOTE(review): the shape presumably differs between the Dirichlet-process and
    /// Dirichlet-distribution priors — TODO confirm.
    pub weight_concentration: Tensor<R>,
    /// Posterior mean-precision parameters (assumed one per component).
    pub mean_precision: Tensor<R>,
    /// Posterior degrees of freedom (assumed one per component).
    pub degrees_of_freedom: Tensor<R>,
    /// Whether the fit met the convergence tolerance before hitting `max_iter`
    /// (presumably; TODO confirm).
    pub converged: bool,
    /// Number of iterations performed.
    pub n_iter: usize,
    /// Final lower-bound value; presumably the variational ELBO — TODO confirm.
    pub lower_bound: f64,
}
/// Operations for fitting and applying a variational Bayesian Gaussian mixture
/// model on a `numr` runtime `R`.
///
/// Every method returns a `numr::error::Result`. The failure conditions are
/// implementation-defined.
pub trait BayesianGmmAlgorithms<R: Runtime<DType = DType>> {
    /// Fits a Bayesian GMM to `data` using `options`.
    ///
    /// NOTE(review): `data` is presumably a `[n_samples, n_features]` matrix — TODO
    /// confirm.
    ///
    /// # Errors
    /// Returns `Err` if fitting fails (e.g. the underlying tensor operations fail).
    fn bayesian_gmm_fit(
        &self,
        data: &Tensor<R>,
        options: &BayesianGmmOptions,
    ) -> Result<BayesianGmmModel<R>>;
    /// Assigns each row of `data` to a component of `model`.
    ///
    /// NOTE(review): presumably returns one component index per sample (the argmax of
    /// [`Self::bayesian_gmm_predict_proba`]) — TODO confirm.
    ///
    /// # Errors
    /// Returns `Err` if the computation fails (e.g. `data` is incompatible with `model`).
    fn bayesian_gmm_predict(
        &self,
        model: &BayesianGmmModel<R>,
        data: &Tensor<R>,
    ) -> Result<Tensor<R>>;
    /// Computes per-component membership probabilities for each row of `data`.
    ///
    /// NOTE(review): presumably shape `[n_samples, n_components]` — TODO confirm.
    ///
    /// # Errors
    /// Returns `Err` if the computation fails (e.g. `data` is incompatible with `model`).
    fn bayesian_gmm_predict_proba(
        &self,
        model: &BayesianGmmModel<R>,
        data: &Tensor<R>,
    ) -> Result<Tensor<R>>;
    /// Scores `data` under `model`.
    ///
    /// NOTE(review): it is unclear from here whether this is a per-sample
    /// log-likelihood or a mean. The result is a `Tensor`, not an `f64` — TODO
    /// confirm.
    ///
    /// # Errors
    /// Returns `Err` if the computation fails (e.g. `data` is incompatible with `model`).
    fn bayesian_gmm_score(
        &self,
        model: &BayesianGmmModel<R>,
        data: &Tensor<R>,
    ) -> Result<Tensor<R>>;
}