use ferrolearn_core::error::FerroError;
use ferrolearn_core::traits::{Fit, Predict, Transform};
use ndarray::{Array1, Array2};
use num_traits::Float;
use rand::Rng;
use rand::SeedableRng;
use rand::rngs::StdRng;
/// Constraint placed on the covariance of the mixture components.
///
/// The variant also determines how the fitted covariances are laid out in
/// storage (see the per-variant notes). [`CovarianceType::Full`] is the
/// default, matching what `GaussianMixture::new` selects.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum CovarianceType {
    /// Every component owns a full `n_features x n_features` matrix; the
    /// matrices are stacked row-wise, one block per component.
    #[default]
    Full,
    /// A single full matrix shared by all components (stored once per
    /// component so it can be indexed like `Full`).
    Tied,
    /// Every component owns one variance per feature (a diagonal matrix).
    Diag,
    /// Every component owns a single variance shared by all features.
    Spherical,
}
/// Unfitted Gaussian mixture model configuration.
///
/// Holds only hyper-parameters; calling `fit` runs expectation-maximisation
/// and produces a `FittedGaussianMixture`.
#[derive(Debug, Clone)]
pub struct GaussianMixture<F> {
/// Number of mixture components. `fit` rejects `0`.
pub n_components: usize,
/// Covariance structure shared by all components of the model.
pub covariance_type: CovarianceType,
/// Maximum number of EM iterations performed by each run.
pub max_iter: usize,
/// Convergence threshold: a run stops once the absolute change of the mean
/// per-sample log-likelihood between two iterations drops below this value.
pub tol: F,
/// Number of independent EM runs; the run with the highest lower bound is
/// kept. `fit` rejects `0`.
pub n_init: usize,
/// Base seed for the random initialisation. Run `i` is seeded with
/// `seed + i` (wrapping). When `None`, `fit` uses a base seed of `0`.
pub random_state: Option<u64>,
}
impl<F: Float> GaussianMixture<F> {
    /// Creates a configuration with `n_components` components and the default
    /// hyper-parameters: full covariances, at most 100 iterations, a tolerance
    /// of `1e-3` (or `F::epsilon()` if that value is not representable), a
    /// single initialisation and no explicit seed.
    #[must_use]
    pub fn new(n_components: usize) -> Self {
        let default_tol = F::from(1e-3).unwrap_or_else(F::epsilon);
        Self {
            n_components,
            covariance_type: CovarianceType::Full,
            max_iter: 100,
            tol: default_tol,
            n_init: 1,
            random_state: None,
        }
    }

    /// Returns the configuration with the covariance structure set to `cov`.
    #[must_use]
    pub fn with_covariance_type(self, cov: CovarianceType) -> Self {
        Self {
            covariance_type: cov,
            ..self
        }
    }

    /// Returns the configuration with the per-run iteration cap set to
    /// `max_iter`.
    #[must_use]
    pub fn with_max_iter(self, max_iter: usize) -> Self {
        Self { max_iter, ..self }
    }

    /// Returns the configuration with the convergence tolerance set to `tol`.
    #[must_use]
    pub fn with_tol(self, tol: F) -> Self {
        Self { tol, ..self }
    }

    /// Returns the configuration with the number of EM restarts set to
    /// `n_init`.
    #[must_use]
    pub fn with_n_init(self, n_init: usize) -> Self {
        Self { n_init, ..self }
    }

    /// Returns the configuration with the base random seed set to `seed`.
    #[must_use]
    pub fn with_random_state(self, seed: u64) -> Self {
        Self {
            random_state: Some(seed),
            ..self
        }
    }
}
/// Result of fitting a `GaussianMixture` with expectation-maximisation.
#[derive(Debug, Clone)]
pub struct FittedGaussianMixture<F> {
/// Mixing weight of each component (length `n_components`).
pub weights_: Array1<F>,
/// Component means, one row per component: shape `(n_components, n_features)`.
pub means_: Array2<F>,
/// Component covariances. The layout depends on the covariance type:
/// `Full` and `Tied` store one `n_features x n_features` block per component
/// stacked row-wise, shape `(n_components * n_features, n_features)` (for
/// `Tied` every block is the same matrix); `Diag` stores per-feature
/// variances, shape `(n_components, n_features)`; `Spherical` stores one
/// variance per component, shape `(n_components, 1)`.
pub covariances_: Array2<F>,
/// Whether the change in the lower bound fell below the tolerance before the
/// iteration cap was reached.
pub converged_: bool,
/// Mean per-sample log-likelihood computed by the last E-step of the run.
pub lower_bound_: F,
// Covariance structure the model was fitted with; selects the layout above.
covariance_type_: CovarianceType,
// Number of columns of the training data; inputs must match it.
n_features_: usize,
}
impl<F: Float> FittedGaussianMixture<F> {
    /// Mixing weight of each component.
    #[must_use]
    pub fn weights(&self) -> &Array1<F> {
        &self.weights_
    }

    /// Component means, one row per component.
    #[must_use]
    pub fn means(&self) -> &Array2<F> {
        &self.means_
    }

    /// Component covariances in the storage layout of the fitted covariance
    /// type.
    #[must_use]
    pub fn covariances(&self) -> &Array2<F> {
        &self.covariances_
    }

    /// Whether EM met the tolerance before running out of iterations.
    #[must_use]
    pub fn converged(&self) -> bool {
        self.converged_
    }

    /// Mean per-sample log-likelihood from the last E-step.
    #[must_use]
    pub fn lower_bound(&self) -> F {
        self.lower_bound_
    }

    /// Bayesian information criterion: `-2 * L * n + p * ln(n)`, where `L` is
    /// the stored lower bound, `n` is `n_samples` and `p` is the number of
    /// free parameters.
    ///
    /// `n_samples` is presumably the number of rows the model was fitted on;
    /// the value is not checked. If it cannot be represented in `F`, `1` is
    /// used instead.
    #[must_use]
    pub fn bic(&self, n_samples: usize) -> F {
        let n = F::from(n_samples).unwrap_or_else(F::one);
        let log_n = n.ln();
        let params = F::from(self.n_free_params()).unwrap_or_else(F::one);
        -F::from(2.0).unwrap() * self.lower_bound_ * n + params * log_n
    }

    /// Akaike information criterion: `-2 * L * n + 2 * p`, with the same
    /// symbols and fallbacks as [`Self::bic`].
    #[must_use]
    pub fn aic(&self, n_samples: usize) -> F {
        let n = F::from(n_samples).unwrap_or_else(F::one);
        let two = F::from(2.0).unwrap();
        let params = F::from(self.n_free_params()).unwrap_or_else(F::one);
        -two * self.lower_bound_ * n + two * params
    }

    /// Counts the free parameters: means, covariance entries and `k - 1`
    /// independent mixing weights.
    fn n_free_params(&self) -> usize {
        let k = self.weights_.len();
        let d = self.n_features_;
        let cov_params = match self.covariance_type_ {
            CovarianceType::Full => k * d * (d + 1) / 2,
            CovarianceType::Tied => d * (d + 1) / 2,
            CovarianceType::Diag => k * d,
            CovarianceType::Spherical => k,
        };
        // `weights_` is a public field and may be empty; avoid `0 - 1` underflow.
        k * d + cov_params + k.saturating_sub(1)
    }

    /// Computes the unnormalised log-responsibilities
    /// `ln(weight_k) + ln N(x_n | mean_k, cov_k)` for every sample `n` (rows)
    /// and component `k` (columns).
    ///
    /// # Errors
    ///
    /// Returns `FerroError::ShapeMismatch` when `x` does not have the number
    /// of columns seen during fitting, and propagates the numerical errors of
    /// the Cholesky-based helpers for `Full`/`Tied` covariances.
    fn log_responsibilities(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let n_samples = x.nrows();
        let n_features = x.ncols();
        let k = self.weights_.len();
        if n_features != self.n_features_ {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features_],
                actual: vec![n_features],
                context: "number of features must match fitted GaussianMixture".into(),
            });
        }
        let mut log_resp = Array2::zeros((n_samples, k));
        let half = F::from(0.5).unwrap();
        let two_pi = F::from(std::f64::consts::TAU).unwrap();
        for ki in 0..k {
            let log_w = self.weights_[ki].ln();
            let mean = self.means_.row(ki);
            match self.covariance_type_ {
                CovarianceType::Full | CovarianceType::Tied => {
                    let cov_offset = ki * n_features;
                    // Copy the block once per component instead of once per sample.
                    let cov_block = self
                        .covariances_
                        .slice(ndarray::s![cov_offset..cov_offset + n_features, ..])
                        .to_owned();
                    let (_, log_norm) = log_det_and_norm_full(&cov_block, n_features, two_pi)?;
                    let mut diff = vec![F::zero(); n_features];
                    for ni in 0..n_samples {
                        for (j, slot) in diff.iter_mut().enumerate() {
                            *slot = x[[ni, j]] - mean[j];
                        }
                        let maha = mahalanobis_full(&diff, &cov_block, n_features)?;
                        // `log_norm` already contains the `-0.5 * ln|cov|` term, so
                        // only the Mahalanobis term is subtracted here (as in the
                        // `Diag` and `Spherical` branches).
                        log_resp[[ni, ki]] = log_w + log_norm - half * maha;
                    }
                }
                CovarianceType::Diag => {
                    let variances = self.covariances_.row(ki);
                    let log_det = variances.iter().fold(F::zero(), |acc, &v| acc + v.ln());
                    let log_norm = -F::from(n_features as f64 / 2.0).unwrap() * two_pi.ln()
                        - half * log_det;
                    for ni in 0..n_samples {
                        let maha: F = (0..n_features).fold(F::zero(), |acc, j| {
                            let d = x[[ni, j]] - mean[j];
                            acc + d * d / variances[j]
                        });
                        log_resp[[ni, ki]] = log_w + log_norm - half * maha;
                    }
                }
                CovarianceType::Spherical => {
                    let var = self.covariances_[[ki, 0]];
                    let log_det = F::from(n_features as f64).unwrap() * var.ln();
                    let log_norm = -F::from(n_features as f64 / 2.0).unwrap() * two_pi.ln()
                        - half * log_det;
                    for ni in 0..n_samples {
                        let sq: F = (0..n_features).fold(F::zero(), |acc, j| {
                            let d = x[[ni, j]] - mean[j];
                            acc + d * d
                        });
                        let maha = sq / var;
                        log_resp[[ni, ki]] = log_w + log_norm - half * maha;
                    }
                }
            }
        }
        Ok(log_resp)
    }
}
/// Returns `(ln|cov|, log_norm)` for a full `d x d` covariance matrix, where
/// `log_norm = -(d / 2) * ln(two_pi) - 0.5 * ln|cov|` is the logarithm of the
/// Gaussian normalising constant.
///
/// The determinant is obtained from the Cholesky factor produced by
/// `cholesky`: its logarithm is twice the sum of the logs of the factor's
/// diagonal.
///
/// # Errors
///
/// Propagates the error of `cholesky` and returns
/// `FerroError::NumericalInstability` when a diagonal entry of the factor is
/// not strictly positive.
fn log_det_and_norm_full<F: Float>(
    cov: &Array2<F>,
    d: usize,
    two_pi: F,
) -> Result<(F, F), FerroError> {
    let factor = cholesky(cov, d)?;

    let mut half_log_det = F::zero();
    for idx in 0..d {
        let pivot = factor[[idx, idx]];
        if pivot <= F::zero() {
            return Err(FerroError::NumericalInstability {
                message: "covariance matrix is not positive definite".into(),
            });
        }
        half_log_det = half_log_det + pivot.ln();
    }
    let log_det = half_log_det + half_log_det;

    let half_dim = F::from(d as f64 / 2.0).unwrap();
    let half = F::from(0.5).unwrap();
    let log_norm = -half_dim * two_pi.ln() - half * log_det;
    Ok((log_det, log_norm))
}
/// Squared Mahalanobis distance `diff^T * cov^-1 * diff` for a full `d x d`
/// covariance matrix.
///
/// The matrix is factorised with `cholesky` into a lower-triangular `L`; the
/// distance is the squared norm of the solution `y` of `L * y = diff`, found
/// by forward substitution.
///
/// # Errors
///
/// Propagates the error of `cholesky` and returns
/// `FerroError::NumericalInstability` when a diagonal entry of the factor is
/// exactly zero.
fn mahalanobis_full<F: Float>(diff: &[F], cov: &Array2<F>, d: usize) -> Result<F, FerroError> {
    let factor = cholesky(cov, d)?;
    let mut solved: Vec<F> = Vec::with_capacity(d);
    let mut squared_norm = F::zero();
    for row in 0..d {
        // `solved` holds y[0..row], i.e. exactly the columns left of the diagonal.
        let residual = solved
            .iter()
            .enumerate()
            .fold(diff[row], |acc, (col, &y)| acc - factor[[row, col]] * y);
        let pivot = factor[[row, row]];
        if pivot == F::zero() {
            return Err(FerroError::NumericalInstability {
                message: "covariance matrix has zero diagonal in Cholesky".into(),
            });
        }
        let y = residual / pivot;
        squared_norm = squared_norm + y * y;
        solved.push(y);
    }
    Ok(squared_norm)
}
/// Computes the lower-triangular Cholesky factor `L` of `cov + reg * I`, so
/// that `L * L^T` equals the regularised matrix.
///
/// Only the lower triangle of the leading `d x d` part of `cov` is read. A
/// ridge of `1e-6` (or `F::epsilon()` if that value is not representable in
/// `F`) is added to every diagonal entry before factorising.
///
/// # Errors
///
/// Returns `FerroError::NumericalInstability` when a pivot is not strictly
/// positive or is NaN (the regularised matrix is not positive definite, or the
/// input contained NaN), or when a previously computed diagonal entry is zero.
fn cholesky<F: Float>(cov: &Array2<F>, d: usize) -> Result<Array2<F>, FerroError> {
    let reg = F::from(1e-6).unwrap_or_else(F::epsilon);
    let mut l = Array2::zeros((d, d));
    for i in 0..d {
        for j in 0..=i {
            let mut s = cov[[i, j]];
            if i == j {
                s = s + reg;
            }
            for p in 0..j {
                s = s - l[[i, p]] * l[[j, p]];
            }
            if i == j {
                // `s <= 0` alone is false for NaN, which would silently produce a
                // NaN factor; any NaN in row `i` reaches this pivot, so reject it
                // here together with non-positive pivots.
                if s.is_nan() || s <= F::zero() {
                    return Err(FerroError::NumericalInstability {
                        message: format!("covariance not positive-definite at diagonal [{i},{i}]"),
                    });
                }
                l[[i, j]] = s.sqrt();
            } else {
                if l[[j, j]] == F::zero() {
                    return Err(FerroError::NumericalInstability {
                        message: "Cholesky: zero diagonal element".into(),
                    });
                }
                l[[i, j]] = s / l[[j, j]];
            }
        }
    }
    Ok(l)
}
/// Row-wise log-sum-exp normalisation.
///
/// For every row of `log_resp` the function computes
/// `lse = ln(sum_k exp(log_resp[n, k]))`, shifting by the row maximum before
/// exponentiating. It returns the matrix with `lse` subtracted from each row,
/// together with the vector of `lse` values.
fn log_sum_exp_rows<F: Float>(log_resp: &Array2<F>) -> (Array2<F>, Array1<F>) {
    let (n_samples, k) = log_resp.dim();
    let mut normalised = Array2::zeros((n_samples, k));
    let mut log_probs = Array1::zeros(n_samples);

    for n in 0..n_samples {
        let row = log_resp.row(n);

        // Largest entry of the row (stays at -inf for an empty row).
        let mut peak = F::neg_infinity();
        for &value in row.iter() {
            if value > peak {
                peak = value;
            }
        }

        let mut shifted_sum = F::zero();
        for &value in row.iter() {
            shifted_sum = shifted_sum + (value - peak).exp();
        }

        let lse = peak + shifted_sum.ln();
        log_probs[n] = lse;
        for (ki, &value) in row.iter().enumerate() {
            normalised[[n, ki]] = value - lse;
        }
    }

    (normalised, log_probs)
}
/// Draws `k` initial component means from the rows of `x`.
///
/// Each mean is a uniformly chosen row of `x` (rows may repeat) with an
/// independent uniform jitter from `[-1e-4, 1e-4)` added to every coordinate.
/// A jitter that cannot be represented in `F` is replaced by zero.
fn init_means<F: Float>(x: &Array2<F>, k: usize, rng: &mut StdRng) -> Array2<F> {
    let (n_samples, n_features) = x.dim();
    let mut means = Array2::zeros((k, n_features));
    for ki in 0..k {
        // The row index is drawn before the per-feature jitters of this component.
        let picked = rng.random_range(0..n_samples);
        for j in 0..n_features {
            let jitter: f64 = rng.random_range(-1e-4..1e-4);
            let offset = F::from(jitter).unwrap_or_else(F::zero);
            means[[ki, j]] = x[[picked, j]] + offset;
        }
    }
    means
}
/// Returns the `n_features x n_features` identity matrix used as the starting
/// covariance of a component.
fn init_full_cov<F: Float>(n_features: usize) -> Array2<F> {
    let unit = F::from(1.0).unwrap_or_else(F::one);
    Array2::from_shape_fn((n_features, n_features), |(row, col)| {
        if row == col {
            unit
        } else {
            F::zero()
        }
    })
}
#[allow(clippy::too_many_lines)]
fn run_em<F: Float>(
x: &Array2<F>,
n_components: usize,
covariance_type: CovarianceType,
max_iter: usize,
tol: F,
rng: &mut StdRng,
) -> Result<FittedGaussianMixture<F>, FerroError> {
let n_samples = x.nrows();
let n_features = x.ncols();
let k = n_components;
let mut weights = Array1::from_elem(k, F::from(1.0 / k as f64).unwrap());
let mut means = init_means(x, k, rng);
let mut covariances: Array2<F> = match covariance_type {
CovarianceType::Full => {
let mut c = Array2::zeros((k * n_features, n_features));
for ki in 0..k {
let block = init_full_cov(n_features);
let offset = ki * n_features;
c.slice_mut(ndarray::s![offset..offset + n_features, ..])
.assign(&block);
}
c
}
CovarianceType::Tied => {
let block = init_full_cov(n_features);
let mut c = Array2::zeros((k * n_features, n_features));
for ki in 0..k {
let offset = ki * n_features;
c.slice_mut(ndarray::s![offset..offset + n_features, ..])
.assign(&block);
}
c
}
CovarianceType::Diag => Array2::from_elem((k, n_features), F::one()),
CovarianceType::Spherical => Array2::from_elem((k, 1), F::one()),
};
let mut prev_ll = F::neg_infinity();
let mut converged = false;
for _iter in 0..max_iter {
let tmp = FittedGaussianMixture {
weights_: weights.clone(),
means_: means.clone(),
covariances_: covariances.clone(),
converged_: false,
lower_bound_: prev_ll,
covariance_type_: covariance_type,
n_features_: n_features,
};
let log_resp_raw = tmp.log_responsibilities(x)?;
let (log_resp, log_probs) = log_sum_exp_rows(&log_resp_raw);
let ll: F =
log_probs.iter().fold(F::zero(), |acc, &v| acc + v) / F::from(n_samples).unwrap();
if (ll - prev_ll).abs() < tol {
converged = true;
prev_ll = ll;
break;
}
prev_ll = ll;
let resp: Array2<F> = log_resp.mapv(num_traits::Float::exp);
let nk: Array1<F> = (0..k)
.map(|ki| resp.column(ki).iter().fold(F::zero(), |acc, &v| acc + v))
.collect::<Array1<F>>();
let reg_nk = F::from(10.0 * f64::EPSILON).unwrap();
let total: F = nk.iter().fold(F::zero(), |acc, &v| acc + v);
for ki in 0..k {
weights[ki] = (nk[ki] + reg_nk) / (total + F::from(k).unwrap() * reg_nk);
}
for ki in 0..k {
let nki = nk[ki] + reg_nk;
for j in 0..n_features {
let s: F = (0..n_samples).fold(F::zero(), |acc, n| acc + resp[[n, ki]] * x[[n, j]]);
means[[ki, j]] = s / nki;
}
}
match covariance_type {
CovarianceType::Full => {
for ki in 0..k {
let nki = nk[ki] + reg_nk;
let offset = ki * n_features;
let mut cov_k = Array2::<F>::zeros((n_features, n_features));
for n in 0..n_samples {
let r = resp[[n, ki]];
for i in 0..n_features {
let di = x[[n, i]] - means[[ki, i]];
for j in 0..=i {
let dj = x[[n, j]] - means[[ki, j]];
cov_k[[i, j]] = cov_k[[i, j]] + r * di * dj;
}
}
}
for i in 0..n_features {
cov_k[[i, i]] = cov_k[[i, i]] / nki;
for j in 0..i {
cov_k[[i, j]] = cov_k[[i, j]] / nki;
cov_k[[j, i]] = cov_k[[i, j]];
}
}
covariances
.slice_mut(ndarray::s![offset..offset + n_features, ..])
.assign(&cov_k);
}
}
CovarianceType::Tied => {
let mut cov_tied = Array2::<F>::zeros((n_features, n_features));
let total_nk = nk.iter().fold(F::zero(), |acc, &v| acc + v) + reg_nk;
for ki in 0..k {
let nki = nk[ki];
for n in 0..n_samples {
let r = resp[[n, ki]];
for i in 0..n_features {
let di = x[[n, i]] - means[[ki, i]];
for j in 0..=i {
let dj = x[[n, j]] - means[[ki, j]];
cov_tied[[i, j]] = cov_tied[[i, j]] + r * di * dj;
let _ = nki; }
}
}
}
for i in 0..n_features {
cov_tied[[i, i]] = cov_tied[[i, i]] / total_nk;
for j in 0..i {
cov_tied[[i, j]] = cov_tied[[i, j]] / total_nk;
cov_tied[[j, i]] = cov_tied[[i, j]];
}
}
for ki in 0..k {
let offset = ki * n_features;
covariances
.slice_mut(ndarray::s![offset..offset + n_features, ..])
.assign(&cov_tied);
}
}
CovarianceType::Diag => {
for ki in 0..k {
let nki = nk[ki] + reg_nk;
for j in 0..n_features {
let s: F = (0..n_samples).fold(F::zero(), |acc, n| {
let d = x[[n, j]] - means[[ki, j]];
acc + resp[[n, ki]] * d * d
});
let var = s / nki;
covariances[[ki, j]] = if var < F::from(1e-6).unwrap() {
F::from(1e-6).unwrap()
} else {
var
};
}
}
}
CovarianceType::Spherical => {
for ki in 0..k {
let nki = nk[ki] + reg_nk;
let d_f = F::from(n_features as f64).unwrap();
let s: F = (0..n_samples).fold(F::zero(), |acc, n| {
let sq: F = (0..n_features).fold(F::zero(), |a, j| {
let d = x[[n, j]] - means[[ki, j]];
a + d * d
});
acc + resp[[n, ki]] * sq
});
let var = s / (nki * d_f);
covariances[[ki, 0]] = if var < F::from(1e-6).unwrap() {
F::from(1e-6).unwrap()
} else {
var
};
}
}
}
}
Ok(FittedGaussianMixture {
weights_: weights,
means_: means,
covariances_: covariances,
converged_: converged,
lower_bound_: prev_ll,
covariance_type_: covariance_type,
n_features_: n_features,
})
}
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for GaussianMixture<F> {
    type Fitted = FittedGaussianMixture<F>;
    type Error = FerroError;

    /// Fits the mixture to the rows of `x` and returns the best of `n_init`
    /// EM runs, judged by the highest lower bound.
    ///
    /// Run `i` is seeded with `random_state + i` (wrapping); a missing
    /// `random_state` is treated as `0`, so fitting is always deterministic.
    ///
    /// # Errors
    ///
    /// * `FerroError::InvalidParameter` when `n_components` or `n_init` is `0`.
    /// * `FerroError::InsufficientSamples` when `x` has fewer rows than
    ///   `n_components` (which includes an empty `x`).
    /// * Any error raised by an individual EM run.
    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedGaussianMixture<F>, FerroError> {
        if self.n_components == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: "must be at least 1".into(),
            });
        }
        if self.n_init == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_init".into(),
                reason: "must be at least 1".into(),
            });
        }
        // `n_components >= 1` here, so this single check also covers empty input.
        let n_samples = x.nrows();
        if n_samples < self.n_components {
            return Err(FerroError::InsufficientSamples {
                required: self.n_components,
                actual: n_samples,
                context: "GaussianMixture requires at least n_components samples".into(),
            });
        }

        let base_seed = self.random_state.unwrap_or(0);
        let mut best: Option<FittedGaussianMixture<F>> = None;
        for run in 0..self.n_init {
            let mut rng = StdRng::seed_from_u64(base_seed.wrapping_add(run as u64));
            let candidate = run_em(
                x,
                self.n_components,
                self.covariance_type,
                self.max_iter,
                self.tol,
                &mut rng,
            )?;
            // Comparisons with NaN are always false, so a NaN bound from an
            // earlier run would otherwise never be replaced by a valid one.
            let improves = best.as_ref().map_or(true, |current| {
                candidate.lower_bound_ > current.lower_bound_
                    || (current.lower_bound_.is_nan() && !candidate.lower_bound_.is_nan())
            });
            if improves {
                best = Some(candidate);
            }
        }

        // Unreachable in practice because `n_init >= 1`; kept instead of a panic.
        best.ok_or_else(|| FerroError::InvalidParameter {
            name: "n_init".into(),
            reason: "internal error: no EM runs completed".into(),
        })
    }
}
impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedGaussianMixture<F> {
type Output = Array1<usize>;
type Error = FerroError;
fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
let resp = self.transform(x)?;
let labels: Array1<usize> = resp
.rows()
.into_iter()
.map(|row| {
row.iter()
.enumerate()
.fold((0usize, F::neg_infinity()), |(best_k, best_v), (ki, &v)| {
if v > best_v {
(ki, v)
} else {
(best_k, best_v)
}
})
.0
})
.collect();
Ok(labels)
}
}
impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedGaussianMixture<F> {
    type Output = Array2<F>;
    type Error = FerroError;

    /// Returns the posterior responsibilities of every component for every row
    /// of `x`: one row per sample, one column per component.
    ///
    /// # Errors
    ///
    /// Propagates the errors of the log-responsibility computation (feature
    /// count mismatch or numerical instability).
    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let unnormalised = self.log_responsibilities(x)?;
        let (log_posterior, _log_likelihoods) = log_sum_exp_rows(&unnormalised);
        let posterior = log_posterior.mapv(|log_p| log_p.exp());
        Ok(posterior)
    }
}
// Unit tests for the Gaussian mixture estimator: parameter validation,
// fitting with each covariance type, prediction, transformation and the
// information criteria.
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
// Twelve 2-D points: six around (0, 0) followed by six around (10, 10).
fn make_two_blobs() -> Array2<f64> {
Array2::from_shape_vec(
(12, 2),
vec![
0.0, 0.0, 0.1, 0.0, 0.0, 0.1, -0.1, 0.0, 0.0, -0.1, 0.1, 0.1, 10.0, 10.0, 10.1,
10.0, 10.0, 10.1, 9.9, 10.0, 10.0, 9.9, 10.1, 10.1,
],
)
.unwrap()
}
// Nine 2-D points: three around (0, 0), three around (10, 10) and three
// around (0, 10), in that order.
fn make_three_blobs() -> Array2<f64> {
Array2::from_shape_vec(
(9, 2),
vec![
0.0, 0.0, 0.1, 0.1, -0.1, 0.1, 10.0, 10.0, 10.1, 10.1, 9.9, 10.1, 0.0, 10.0, 0.1,
10.1, -0.1, 9.9,
],
)
.unwrap()
}
// `new` must set the documented default hyper-parameters.
#[test]
fn test_new_defaults() {
let gmm = GaussianMixture::<f64>::new(3);
assert_eq!(gmm.n_components, 3);
assert_eq!(gmm.covariance_type, CovarianceType::Full);
assert_eq!(gmm.max_iter, 100);
assert_eq!(gmm.n_init, 1);
assert!(gmm.random_state.is_none());
}
// Every builder method must store its argument in the matching field.
#[test]
fn test_builder_methods() {
let gmm = GaussianMixture::<f64>::new(2)
.with_covariance_type(CovarianceType::Diag)
.with_max_iter(50)
.with_tol(1e-6)
.with_n_init(3)
.with_random_state(7);
assert_eq!(gmm.covariance_type, CovarianceType::Diag);
assert_eq!(gmm.max_iter, 50);
assert_eq!(gmm.n_init, 3);
assert_eq!(gmm.random_state, Some(7));
}
// Zero components is an invalid configuration.
#[test]
fn test_zero_components_error() {
let x = make_two_blobs();
let result = GaussianMixture::<f64>::new(0).fit(&x, &());
assert!(result.is_err());
}
// Zero initialisations is an invalid configuration.
#[test]
fn test_zero_n_init_error() {
let x = make_two_blobs();
let result = GaussianMixture::<f64>::new(2).with_n_init(0).fit(&x, &());
assert!(result.is_err());
}
// Fitting on a matrix without rows must fail.
#[test]
fn test_empty_data_error() {
let x = Array2::<f64>::zeros((0, 2));
let result = GaussianMixture::<f64>::new(2).fit(&x, &());
assert!(result.is_err());
}
// Fitting must fail when there are fewer samples than components.
#[test]
fn test_more_components_than_samples_error() {
let x = Array2::from_shape_vec((2, 2), vec![1.0, 1.0, 2.0, 2.0]).unwrap();
let result = GaussianMixture::<f64>::new(5).fit(&x, &());
assert!(result.is_err());
}
// Predicting on data with a different number of columns must fail.
#[test]
fn test_predict_feature_mismatch_error() {
let x = make_two_blobs();
let fitted = GaussianMixture::<f64>::new(2)
.with_random_state(42)
.fit(&x, &())
.unwrap();
let bad = Array2::from_shape_vec((3, 5), vec![0.0; 15]).unwrap();
assert!(fitted.predict(&bad).is_err());
}
// Transforming data with a different number of columns must fail.
#[test]
fn test_transform_feature_mismatch_error() {
let x = make_two_blobs();
let fitted = GaussianMixture::<f64>::new(2)
.with_random_state(42)
.fit(&x, &())
.unwrap();
let bad = Array2::from_shape_vec((3, 5), vec![0.0; 15]).unwrap();
assert!(fitted.transform(&bad).is_err());
}
// Full covariance: parameter shapes are correct and the weights sum to one.
#[test]
fn test_fit_two_blobs_full_covariance() {
let x = make_two_blobs();
let fitted = GaussianMixture::<f64>::new(2)
.with_random_state(42)
.with_max_iter(200)
.fit(&x, &())
.unwrap();
assert_eq!(fitted.weights().len(), 2);
assert_eq!(fitted.means().dim(), (2, 2));
let w_sum: f64 = fitted.weights().iter().sum();
assert_relative_eq!(w_sum, 1.0, epsilon = 1e-6);
}
// Diagonal covariance is stored as (n_components, n_features).
#[test]
fn test_fit_diag_covariance() {
let x = make_two_blobs();
let fitted = GaussianMixture::<f64>::new(2)
.with_covariance_type(CovarianceType::Diag)
.with_random_state(42)
.with_max_iter(200)
.fit(&x, &())
.unwrap();
assert_eq!(fitted.covariances().dim(), (2, 2));
}
// Spherical covariance is stored as (n_components, 1).
#[test]
fn test_fit_spherical_covariance() {
let x = make_two_blobs();
let fitted = GaussianMixture::<f64>::new(2)
.with_covariance_type(CovarianceType::Spherical)
.with_random_state(42)
.with_max_iter(200)
.fit(&x, &())
.unwrap();
assert_eq!(fitted.covariances().dim(), (2, 1));
}
// Tied covariance is stored as one block per component:
// (n_components * n_features, n_features).
#[test]
fn test_fit_tied_covariance() {
let x = make_two_blobs();
let fitted = GaussianMixture::<f64>::new(2)
.with_covariance_type(CovarianceType::Tied)
.with_random_state(42)
.with_max_iter(200)
.fit(&x, &())
.unwrap();
assert_eq!(fitted.covariances().dim(), (4, 2));
}
// A single component must receive (almost) all of the weight.
#[test]
fn test_single_component() {
let x = make_two_blobs();
let fitted = GaussianMixture::<f64>::new(1)
.with_random_state(0)
.fit(&x, &())
.unwrap();
assert_eq!(fitted.weights().len(), 1);
assert_relative_eq!(fitted.weights()[0], 1.0, epsilon = 1e-6);
}
// Two fits with the same seed must reach the same lower bound.
#[test]
fn test_reproducibility() {
let x = make_two_blobs();
let gmm = GaussianMixture::<f64>::new(2).with_random_state(123);
let f1 = gmm.fit(&x, &()).unwrap();
let f2 = gmm.fit(&x, &()).unwrap();
assert_relative_eq!(f1.lower_bound(), f2.lower_bound(), epsilon = 1e-10);
}
// More initialisations (same base seed) must not yield a worse lower bound.
#[test]
fn test_n_init_picks_best() {
let x = make_two_blobs();
let f1 = GaussianMixture::<f64>::new(2)
.with_random_state(0)
.with_n_init(1)
.fit(&x, &())
.unwrap();
let f5 = GaussianMixture::<f64>::new(2)
.with_random_state(0)
.with_n_init(5)
.fit(&x, &())
.unwrap();
assert!(f5.lower_bound() >= f1.lower_bound() - 1e-6);
}
// `predict` returns one label per input row.
#[test]
fn test_predict_shape() {
let x = make_two_blobs();
let fitted = GaussianMixture::<f64>::new(2)
.with_random_state(42)
.with_max_iter(200)
.fit(&x, &())
.unwrap();
let labels = fitted.predict(&x).unwrap();
assert_eq!(labels.len(), x.nrows());
}
// Every predicted label must be a valid component index.
#[test]
fn test_predict_valid_range() {
let x = make_two_blobs();
let fitted = GaussianMixture::<f64>::new(2)
.with_random_state(42)
.with_max_iter(200)
.fit(&x, &())
.unwrap();
let labels = fitted.predict(&x).unwrap();
for &l in &labels {
assert!(l < 2, "label {l} out of range");
}
}
// Points of the same blob share a label; the two blobs get different labels.
#[test]
fn test_predict_well_separated_clusters() {
let x = make_two_blobs();
let fitted = GaussianMixture::<f64>::new(2)
.with_random_state(42)
.with_max_iter(300)
.fit(&x, &())
.unwrap();
let labels = fitted.predict(&x).unwrap();
assert_eq!(labels[0], labels[1]);
assert_eq!(labels[0], labels[2]);
assert_eq!(labels[6], labels[7]);
assert_eq!(labels[6], labels[8]);
assert_ne!(labels[0], labels[6]);
}
// Three blobs with three components and several restarts: labels agree
// within a blob and the first blob differs from the other two.
#[test]
fn test_predict_three_blobs() {
let x = make_three_blobs();
let fitted = GaussianMixture::<f64>::new(3)
.with_random_state(7)
.with_max_iter(300)
.with_n_init(3)
.fit(&x, &())
.unwrap();
let labels = fitted.predict(&x).unwrap();
assert_eq!(labels[0], labels[1]);
assert_eq!(labels[3], labels[4]);
assert_eq!(labels[6], labels[7]);
assert_ne!(labels[0], labels[3]);
assert_ne!(labels[0], labels[6]);
}
// `transform` returns an (n_samples, n_components) matrix.
#[test]
fn test_transform_shape() {
let x = make_two_blobs();
let fitted = GaussianMixture::<f64>::new(2)
.with_random_state(42)
.with_max_iter(200)
.fit(&x, &())
.unwrap();
let resp = fitted.transform(&x).unwrap();
assert_eq!(resp.dim(), (12, 2));
}
// Responsibilities of each sample must sum to one.
#[test]
fn test_transform_rows_sum_to_one() {
let x = make_two_blobs();
let fitted = GaussianMixture::<f64>::new(2)
.with_random_state(42)
.with_max_iter(200)
.fit(&x, &())
.unwrap();
let resp = fitted.transform(&x).unwrap();
for row in resp.rows() {
let s: f64 = row.iter().sum();
assert_relative_eq!(s, 1.0, epsilon = 1e-6);
}
}
// Responsibilities must lie in [0, 1] (with a small numerical slack).
#[test]
fn test_transform_values_in_0_1() {
let x = make_two_blobs();
let fitted = GaussianMixture::<f64>::new(2)
.with_random_state(42)
.with_max_iter(200)
.fit(&x, &())
.unwrap();
let resp = fitted.transform(&x).unwrap();
for &v in &resp {
assert!((0.0..=1.0 + 1e-10).contains(&v));
}
}
// BIC of a fitted model must be a finite number.
#[test]
fn test_bic_finite() {
let x = make_two_blobs();
let fitted = GaussianMixture::<f64>::new(2)
.with_random_state(42)
.fit(&x, &())
.unwrap();
let bic = fitted.bic(x.nrows());
assert!(bic.is_finite(), "BIC should be finite");
}
// AIC of a fitted model must be a finite number.
#[test]
fn test_aic_finite() {
let x = make_two_blobs();
let fitted = GaussianMixture::<f64>::new(2)
.with_random_state(42)
.fit(&x, &())
.unwrap();
let aic = fitted.aic(x.nrows());
assert!(aic.is_finite(), "AIC should be finite");
}
// On two-blob data, a five-component model must have a higher BIC than a
// two-component model.
#[test]
fn test_bic_increases_with_more_components_on_two_blobs() {
let x = make_two_blobs();
let bic2 = GaussianMixture::<f64>::new(2)
.with_random_state(42)
.with_max_iter(200)
.fit(&x, &())
.unwrap()
.bic(x.nrows());
let bic5 = GaussianMixture::<f64>::new(5)
.with_random_state(42)
.with_max_iter(200)
.fit(&x, &())
.unwrap()
.bic(x.nrows());
assert!(bic2 < bic5, "bic2={bic2} bic5={bic5}");
}
// The estimator must also work with `f32` data.
#[test]
fn test_f32_support() {
let x = Array2::<f32>::from_shape_vec(
(8, 2),
vec![
0.0, 0.0, 0.1, 0.0, 0.0, 0.1, -0.1, 0.1, 10.0, 10.0, 10.1, 10.0, 10.0, 10.1, 9.9,
10.1,
],
)
.unwrap();
let fitted = GaussianMixture::<f32>::new(2)
.with_random_state(0)
.with_max_iter(200)
.fit(&x, &())
.unwrap();
let labels = fitted.predict(&x).unwrap();
assert_eq!(labels.len(), 8);
}
// Accessors expose the fitted parameters with the expected sizes.
#[test]
fn test_accessor_methods() {
let x = make_two_blobs();
let fitted = GaussianMixture::<f64>::new(2)
.with_random_state(1)
.fit(&x, &())
.unwrap();
assert_eq!(fitted.weights().len(), 2);
assert_eq!(fitted.means().nrows(), 2);
assert!(fitted.lower_bound().is_finite());
let _ = fitted.converged();
}
// The lower bound of a fitted model must be a finite number.
#[test]
fn test_lower_bound_finite() {
let x = make_two_blobs();
let fitted = GaussianMixture::<f64>::new(2)
.with_random_state(42)
.fit(&x, &())
.unwrap();
assert!(fitted.lower_bound().is_finite());
}
}