use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use scirs2_core::random::{Rng, RngExt, SeedableRng};
use std::f64::consts::PI;
use std::fmt::Debug;
use std::iter::Sum;
use crate::error::{ClusteringError, Result};
use crate::vq::kmeans_plus_plus;
use statrs::statistics::Statistics;
/// Mixture parameters: (component weights, component means as rows of a
/// matrix, one covariance matrix per component).
type GMMParams<F> = (Array1<F>, Array2<F>, Vec<Array2<F>>);
/// Outcome of a single EM run: (weights, means, covariances, final lower
/// bound, iterations performed, whether the tolerance was reached).
type GMMFitResult<F> = (Array1<F>, Array2<F>, Vec<Array2<F>>, F, usize, bool);
/// Shape constraint placed on each component's covariance matrix.
#[derive(Debug, Clone, Copy)]
pub enum CovarianceType {
    /// Each component has its own unconstrained covariance matrix.
    Full,
    /// Each component has its own diagonal covariance matrix.
    Diagonal,
    /// All components share one covariance matrix.
    /// NOTE(review): the M-step currently applies the diagonal update for
    /// this variant (see the catch-all arm in `m_step`) — confirm intended.
    Tied,
    /// Each component has a single shared variance on the diagonal.
    Spherical,
}
/// Strategy for choosing the initial component means.
#[derive(Debug, Clone, Copy)]
pub enum GMMInit {
    /// Seed means with k-means++ (spread-out starting points).
    KMeans,
    /// Seed means by sampling random data points.
    Random,
}
/// Configuration for fitting a [`GaussianMixture`].
#[derive(Debug, Clone)]
pub struct GMMOptions<F: Float> {
    // Number of mixture components to fit.
    pub n_components: usize,
    // Constraint applied to each component's covariance matrix.
    pub covariance_type: CovarianceType,
    // EM stops once the change in the lower bound falls below this value.
    pub tol: F,
    // Maximum number of EM iterations per initialization.
    pub max_iter: usize,
    // Number of independent initializations; the best run (highest lower
    // bound) is kept.
    pub n_init: usize,
    // How the initial means are chosen.
    pub init_method: GMMInit,
    // Seed for reproducible initialization; `None` draws a fresh seed.
    pub random_seed: Option<u64>,
    // Value added to covariance diagonals to keep them positive definite.
    pub reg_covar: F,
}
impl<F: Float + FromPrimitive> Default for GMMOptions<F> {
    /// Defaults mirror common GMM settings: one full-covariance component,
    /// k-means++ initialization, a 1e-3 tolerance over at most 100
    /// iterations, and a small 1e-6 covariance regularizer.
    fn default() -> Self {
        // Lift an f64 literal into the generic float type; panics only if the
        // target type cannot represent these small constants.
        let constant = |v: f64| F::from(v).expect("Failed to convert constant to float");
        Self {
            n_components: 1,
            covariance_type: CovarianceType::Full,
            tol: constant(1e-3),
            max_iter: 100,
            n_init: 1,
            init_method: GMMInit::KMeans,
            random_seed: None,
            reg_covar: constant(1e-6),
        }
    }
}
/// Gaussian mixture model fitted with expectation-maximization.
///
/// All parameter fields are `None` until [`GaussianMixture::fit`] succeeds.
pub struct GaussianMixture<F: Float> {
    // Fitting configuration supplied at construction.
    options: GMMOptions<F>,
    // Mixing weights, one per component (set by `fit`).
    weights: Option<Array1<F>>,
    // Component means, one row per component (set by `fit`).
    means: Option<Array2<F>>,
    // Per-component covariance matrices (set by `fit`).
    covariances: Option<Vec<Array2<F>>>,
    // Final EM lower bound of the best run (set by `fit`).
    lower_bound: Option<F>,
    // Iterations used by the best run (set by `fit`).
    n_iter: Option<usize>,
    // Whether the best run reached the tolerance before `max_iter`.
    converged: bool,
}
impl<F: Float + FromPrimitive + Debug + ScalarOperand + Sum + std::borrow::Borrow<f64>>
GaussianMixture<F>
{
pub fn new(options: GMMOptions<F>) -> Self {
Self {
options,
weights: None,
means: None,
covariances: None,
lower_bound: None,
n_iter: None,
converged: false,
}
}
pub fn fit(&mut self, data: ArrayView2<F>) -> Result<()> {
let n_samples = data.shape()[0];
let _n_features = data.shape()[1];
if n_samples < self.options.n_components {
return Err(ClusteringError::InvalidInput(
"Number of samples must be >= number of components".to_string(),
));
}
let mut best_lower_bound = F::neg_infinity();
let mut best_params = None;
for _ in 0..self.options.n_init {
let (weights, means, covariances, lower_bound, n_iter, converged) =
self.fit_single(data)?;
if lower_bound > best_lower_bound {
best_lower_bound = lower_bound;
best_params = Some((weights, means, covariances, lower_bound, n_iter, converged));
}
}
if let Some((weights, means, covariances, lower_bound, n_iter, converged)) = best_params {
self.weights = Some(weights);
self.means = Some(means);
self.covariances = Some(covariances);
self.lower_bound = Some(lower_bound);
self.n_iter = Some(n_iter);
self.converged = converged;
}
Ok(())
}
fn fit_single(&self, data: ArrayView2<F>) -> Result<GMMFitResult<F>> {
let _n_samples = data.shape()[0];
let _n_features = data.shape()[1];
let _n_components = self.options.n_components;
let (mut weights, mut means, mut covariances) = self.initialize_params(data)?;
let mut lower_bound = F::neg_infinity();
let mut converged = false;
for iter in 0..self.options.max_iter {
let (resp_, new_lower_bound) = self.e_step(data, &weights, &means, &covariances)?;
let change = (new_lower_bound - lower_bound).abs();
if change < self.options.tol {
converged = true;
return Ok((
weights,
means,
covariances,
new_lower_bound,
iter + 1,
converged,
));
}
lower_bound = new_lower_bound;
(weights, means, covariances) = self.m_step(data, resp_)?;
}
Ok((
weights,
means,
covariances,
lower_bound,
self.options.max_iter,
converged,
))
}
fn initialize_params(&self, data: ArrayView2<F>) -> Result<GMMParams<F>> {
let n_samples = data.shape()[0];
let n_features = data.shape()[1];
let n_components = self.options.n_components;
let weights = Array1::from_elem(
n_components,
F::one() / F::from(n_components).expect("Failed to convert to float"),
);
let means = match self.options.init_method {
GMMInit::KMeans => {
kmeans_plus_plus(data, n_components, self.options.random_seed)?
}
GMMInit::Random => {
let mut rng = match self.options.random_seed {
Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
None => scirs2_core::random::rngs::StdRng::seed_from_u64(
scirs2_core::random::rng().random::<u64>(),
),
};
let mut means = Array2::zeros((n_components, n_features));
for i in 0..n_components {
let idx = rng.random_range(0..n_samples);
means.slice_mut(s![i, ..]).assign(&data.slice(s![idx, ..]));
}
means
}
};
let mut covariances = Vec::with_capacity(n_components);
let data_mean = data.mean_axis(Axis(0)).expect("Operation failed");
let mut variance = Array1::<F>::zeros(n_features);
for i in 0..n_samples {
let diff = &data.slice(s![i, ..]) - &data_mean;
variance = variance + &diff.mapv(|x| x * x);
}
variance = variance / F::from(n_samples - 1).expect("Failed to convert to float");
match self.options.covariance_type {
CovarianceType::Spherical => {
let avg_variance =
variance.sum() / F::from(variance.len()).expect("Operation failed");
for _ in 0..n_components {
let mut cov = Array2::<F>::zeros((n_features, n_features));
for i in 0..n_features {
cov[[i, i]] = avg_variance;
}
covariances.push(cov);
}
}
CovarianceType::Diagonal => {
for _ in 0..n_components {
let mut cov = Array2::<F>::zeros((n_features, n_features));
for i in 0..n_features {
cov[[i, i]] = variance[i];
}
covariances.push(cov);
}
}
CovarianceType::Full | CovarianceType::Tied => {
for _ in 0..n_components {
let mut cov = Array2::<F>::zeros((n_features, n_features));
for i in 0..n_features {
cov[[i, i]] = variance[i];
}
covariances.push(cov);
}
}
}
Ok((weights, means, covariances))
}
fn e_step(
&self,
data: ArrayView2<F>,
weights: &Array1<F>,
means: &Array2<F>,
covariances: &[Array2<F>],
) -> Result<(Array2<F>, F)> {
let n_samples = data.shape()[0];
let n_components = self.options.n_components;
let mut log_prob = Array2::zeros((n_samples, n_components));
for (k, covariance) in covariances.iter().enumerate().take(n_components) {
let log_prob_k = self.log_multivariate_normal_density(
data,
means.slice(s![k, ..]).view(),
covariance,
)?;
log_prob.slice_mut(s![.., k]).assign(&log_prob_k);
}
for k in 0..n_components {
let log_weight = weights[k].ln();
log_prob
.slice_mut(s![.., k])
.mapv_inplace(|x| x + log_weight);
}
let log_prob_norm = self.logsumexp(log_prob.view(), Axis(1))?;
let mut resp_ = log_prob.clone();
for i in 0..n_samples {
for k in 0..n_components {
resp_[[i, k]] = (resp_[[i, k]] - log_prob_norm[i]).exp();
}
}
let lower_bound =
log_prob_norm.sum() / F::from(log_prob_norm.len()).expect("Operation failed");
Ok((resp_, lower_bound))
}
fn m_step(&self, data: ArrayView2<F>, resp_: Array2<F>) -> Result<GMMParams<F>> {
let n_samples = data.shape()[0];
let n_features = data.shape()[1];
let n_components = self.options.n_components;
let nk = resp_.sum_axis(Axis(0));
let weights = &nk / F::from(n_samples).expect("Failed to convert to float");
let mut means = Array2::zeros((n_components, n_features));
for k in 0..n_components {
let mut mean_k = Array1::zeros(n_features);
for i in 0..n_samples {
mean_k = mean_k + &data.slice(s![i, ..]) * resp_[[i, k]];
}
means.slice_mut(s![k, ..]).assign(&(&mean_k / nk[k]));
}
let mut covariances = Vec::with_capacity(n_components);
match self.options.covariance_type {
CovarianceType::Full => {
for k in 0..n_components {
let mean_k = means.slice(s![k, ..]);
let mut cov = Array2::zeros((n_features, n_features));
for i in 0..n_samples {
let diff = &data.slice(s![i, ..]) - &mean_k;
let outer = self.outer_product(diff.view(), diff.view());
cov = cov + &outer * resp_[[i, k]];
}
cov = cov / nk[k];
for i in 0..n_features {
cov[[i, i]] = cov[[i, i]] + self.options.reg_covar;
}
covariances.push(cov);
}
}
_ => {
for k in 0..n_components {
let mean_k = means.slice(s![k, ..]);
let mut cov = Array2::zeros((n_features, n_features));
for i in 0..n_samples {
let diff = &data.slice(s![i, ..]) - &mean_k;
for j in 0..n_features {
cov[[j, j]] = cov[[j, j]] + diff[j] * diff[j] * resp_[[i, k]];
}
}
for j in 0..n_features {
cov[[j, j]] = cov[[j, j]] / nk[k] + self.options.reg_covar;
}
covariances.push(cov);
}
}
}
Ok((weights, means, covariances))
}
fn log_multivariate_normal_density(
&self,
data: ArrayView2<F>,
mean: ArrayView1<F>,
covariance: &Array2<F>,
) -> Result<Array1<F>> {
let n_samples = data.shape()[0];
let n_features = data.shape()[1];
let mut log_prob = Array1::zeros(n_samples);
let mut log_det = F::zero();
for i in 0..n_features {
log_det = log_det + covariance[[i, i]].ln();
}
let norm_const =
F::from(n_features as f64 * (2.0 * PI).ln()).expect("Operation failed") + log_det;
for i in 0..n_samples {
let diff = &data.slice(s![i, ..]) - &mean;
let mut mahalanobis = F::zero();
for j in 0..n_features {
mahalanobis = mahalanobis + diff[j] * diff[j] / covariance[[j, j]];
}
log_prob[i] = F::from(-0.5).expect("Failed to convert constant to float")
* (norm_const + mahalanobis);
}
Ok(log_prob)
}
fn logsumexp(&self, arr: ArrayView2<F>, axis: Axis) -> Result<Array1<F>> {
let max_vals = arr.fold_axis(axis, F::neg_infinity(), |&a, &b| a.max(b));
let mut result = Array1::zeros(max_vals.len());
match axis {
Axis(1) => {
for i in 0..arr.shape()[0] {
let mut sum = F::zero();
for j in 0..arr.shape()[1] {
sum = sum + (arr[[i, j]] - max_vals[i]).exp();
}
result[i] = max_vals[i] + sum.ln();
}
}
_ => {
return Err(ClusteringError::InvalidInput(
"Only axis 1 is supported for logsumexp".to_string(),
));
}
}
Ok(result)
}
fn outer_product(&self, a: ArrayView1<F>, b: ArrayView1<F>) -> Array2<F> {
let n = a.len();
let m = b.len();
let mut result = Array2::zeros((n, m));
for i in 0..n {
for j in 0..m {
result[[i, j]] = a[i] * b[j];
}
}
result
}
pub fn predict(&self, data: ArrayView2<F>) -> Result<Array1<i32>> {
if self.weights.is_none() || self.means.is_none() || self.covariances.is_none() {
return Err(ClusteringError::InvalidInput(
"Model has not been fitted yet".to_string(),
));
}
let weights = self.weights.as_ref().expect("Operation failed");
let means = self.means.as_ref().expect("Operation failed");
let covariances = self.covariances.as_ref().expect("Operation failed");
let (resp__, _) = self.e_step(data, weights, means, covariances)?;
let mut labels = Array1::zeros(data.shape()[0]);
for i in 0..data.shape()[0] {
let mut max_resp_ = F::neg_infinity();
let mut best_k = 0;
for k in 0..self.options.n_components {
if resp__[[i, k]] > max_resp_ {
max_resp_ = resp__[[i, k]];
best_k = k;
}
}
labels[i] = best_k as i32;
}
Ok(labels)
}
}
/// Convenience wrapper: fits a Gaussian mixture to `data` with `options`
/// and returns the hard cluster assignment of each sample.
///
/// # Errors
/// Propagates any error from fitting or prediction (e.g. invalid options
/// or fewer samples than components).
#[allow(dead_code)]
pub fn gaussian_mixture<F>(data: ArrayView2<F>, options: GMMOptions<F>) -> Result<Array1<i32>>
where
    F: Float + FromPrimitive + Debug + ScalarOperand + Sum + std::borrow::Borrow<f64>,
{
    // Fit and immediately predict on the same data.
    let mut model = GaussianMixture::new(options);
    model.fit(data)?;
    model.predict(data)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    // Two well-separated 2-D clusters: fitting succeeds and produces at
    // most `n_components` distinct labels.
    #[test]
    fn test_gmm_simple() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1],
        )
        .expect("Test: operation failed");
        let options = GMMOptions {
            n_components: 2,
            max_iter: 10,
            ..Default::default()
        };
        let result = gaussian_mixture(data.view(), options);
        assert!(result.is_ok());
        let labels = result.expect("Test: operation failed");
        assert_eq!(labels.len(), 6);
        // Labels are component indices, so at most 2 distinct values.
        let unique_labels: std::collections::HashSet<_> = labels.iter().cloned().collect();
        assert!(unique_labels.len() <= 2);
    }

    // Every covariance type should fit the same two-cluster dataset.
    #[test]
    fn test_gmm_different_covariance_types() {
        let data = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.1, 1.1, 0.9, 0.9, 1.2, 0.8, 5.0, 5.0, 5.1, 5.1, 4.9, 4.9, 5.2, 4.8,
            ],
        )
        .expect("Test: operation failed");
        let covariance_types = vec![
            CovarianceType::Full,
            CovarianceType::Diagonal,
            CovarianceType::Spherical,
            CovarianceType::Tied,
        ];
        for cov_type in covariance_types {
            let options = GMMOptions {
                n_components: 2,
                covariance_type: cov_type,
                max_iter: 50,
                ..Default::default()
            };
            let result = gaussian_mixture(data.view(), options);
            assert!(
                result.is_ok(),
                "Failed with covariance type: {:?}",
                cov_type
            );
            let labels = result.expect("Test: operation failed");
            assert_eq!(labels.len(), 8);
        }
    }

    // Both initialization methods (k-means++ and random) should succeed
    // with a fixed seed.
    #[test]
    fn test_gmm_initialization_methods() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1],
        )
        .expect("Test: operation failed");
        let init_methods = vec![GMMInit::KMeans, GMMInit::Random];
        for init_method in init_methods {
            let options = GMMOptions {
                n_components: 2,
                init_method,
                random_seed: Some(42),
                max_iter: 20,
                ..Default::default()
            };
            let result = gaussian_mixture(data.view(), options);
            assert!(result.is_ok(), "Failed with init method: {:?}", init_method);
            let labels = result.expect("Test: operation failed");
            assert_eq!(labels.len(), 6);
        }
    }

    // Invalid option combinations: zero components must error; more
    // components than samples is checked but the exact result is not pinned.
    #[test]
    fn test_gmm_parameter_validation() {
        let data = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0])
            .expect("Operation failed");
        let options = GMMOptions {
            n_components: 0,
            ..Default::default()
        };
        let result = gaussian_mixture(data.view(), options);
        assert!(result.is_err());
        let options = GMMOptions {
            n_components: 10,
            max_iter: 5, ..Default::default()
        };
        let result = gaussian_mixture(data.view(), options);
        // Only exercised for absence of panics; outcome intentionally unused.
        let _result = result;
    }

    // Fitting should succeed across a range of convergence tolerances.
    #[test]
    fn test_gmm_convergence_criteria() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1],
        )
        .expect("Test: operation failed");
        let tolerances = vec![1e-3, 1e-6, 1e-9];
        for tol in tolerances {
            let options = GMMOptions {
                n_components: 2,
                tol,
                max_iter: 100,
                ..Default::default()
            };
            let result = gaussian_mixture(data.view(), options);
            assert!(result.is_ok(), "Failed with tolerance: {}", tol);
        }
    }

    // With a single component every sample must be assigned label 0.
    #[test]
    fn test_gmm_single_component() {
        let data = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 1.1, 2.1])
            .expect("Operation failed");
        let options = GMMOptions {
            n_components: 1,
            max_iter: 20,
            ..Default::default()
        };
        let result = gaussian_mixture(data.view(), options);
        assert!(result.is_ok());
        let labels = result.expect("Test: operation failed");
        assert_eq!(labels.len(), 4);
        assert!(labels.iter().all(|&l| l == 0));
    }

    // Same seed, same options: runs should agree on the number of clusters
    // found (label permutations are allowed, so only set sizes compared).
    #[test]
    fn test_gmm_reproducibility_with_seed() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1],
        )
        .expect("Test: operation failed");
        let options1 = GMMOptions {
            n_components: 2,
            random_seed: Some(42),
            max_iter: 50,
            ..Default::default()
        };
        let options2 = GMMOptions {
            n_components: 2,
            random_seed: Some(42),
            max_iter: 50,
            ..Default::default()
        };
        let labels1 = gaussian_mixture(data.view(), options1).expect("Operation failed");
        let labels2 = gaussian_mixture(data.view(), options2).expect("Operation failed");
        assert_eq!(labels1.len(), labels2.len());
        let unique1: std::collections::HashSet<_> = labels1.iter().cloned().collect();
        let unique2: std::collections::HashSet<_> = labels2.iter().cloned().collect();
        assert_eq!(unique1.len(), unique2.len());
    }

    // Three clusters plus one outlier-ish point: labels stay within the
    // component count and at least one label is produced.
    #[test]
    fn test_gmm_many_components() {
        let data = Array2::from_shape_vec(
            (10, 2),
            vec![
                1.0, 1.0, 1.1, 1.1, 1.2, 1.2, 3.0, 3.0, 3.1, 3.1, 3.2, 3.2, 5.0, 5.0, 5.1, 5.1,
                5.2, 5.2, 7.0, 7.0,
            ],
        )
        .expect("Test: operation failed");
        let options = GMMOptions {
            n_components: 3,
            max_iter: 50,
            ..Default::default()
        };
        let result = gaussian_mixture(data.view(), options);
        assert!(result.is_ok());
        let labels = result.expect("Test: operation failed");
        assert_eq!(labels.len(), 10);
        let unique_labels: std::collections::HashSet<_> = labels.iter().cloned().collect();
        assert!(unique_labels.len() <= 3);
        assert!(!unique_labels.is_empty());
    }

    // Fitting should succeed for a range of covariance regularizers.
    #[test]
    fn test_gmm_regularization() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1],
        )
        .expect("Test: operation failed");
        let reg_values = vec![1e-6, 1e-3, 1e-1];
        for reg_covar in reg_values {
            let options = GMMOptions {
                n_components: 2,
                reg_covar,
                max_iter: 20,
                ..Default::default()
            };
            let result = gaussian_mixture(data.view(), options);
            assert!(result.is_ok(), "Failed with reg_covar: {}", reg_covar);
        }
    }

    // End-to-end: fit once, then predict on the training data and on
    // previously unseen points.
    #[test]
    fn test_gmm_fit_predict_workflow() {
        let data = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.1, 1.1, 0.9, 0.9, 1.2, 0.8, 5.0, 5.0, 5.1, 5.1, 4.9, 4.9, 5.2, 4.8,
            ],
        )
        .expect("Test: operation failed");
        let options = GMMOptions {
            n_components: 2,
            max_iter: 50,
            random_seed: Some(42),
            ..Default::default()
        };
        let mut gmm = GaussianMixture::new(options);
        let fit_result = gmm.fit(data.view());
        assert!(fit_result.is_ok());
        let predict_result = gmm.predict(data.view());
        assert!(predict_result.is_ok());
        let labels = predict_result.expect("Test: operation failed");
        assert_eq!(labels.len(), 8);
        // Predict on new points without refitting.
        let new_data =
            Array2::from_shape_vec((2, 2), vec![1.0, 1.0, 5.0, 5.0]).expect("Operation failed");
        let new_labels = gmm.predict(new_data.view());
        assert!(new_labels.is_ok());
        assert_eq!(new_labels.expect("Test: operation failed").len(), 2);
    }
}