use std::f64::consts::PI;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use crate::error::{ClusteringError, Result};
/// Digamma function ψ(x), computed with the recurrence ψ(x) = ψ(x + 1) - 1/x
/// followed by the standard asymptotic series once the argument reaches 6.
///
/// Returns negative infinity for non-positive inputs (ψ has poles there);
/// callers in this module only pass positive arguments.
fn digamma(x: f64) -> f64 {
    if x <= 0.0 {
        return f64::NEG_INFINITY;
    }
    // Shift the argument upward until the asymptotic expansion is accurate,
    // accumulating the -1/x recurrence corrections along the way.
    let mut arg = x;
    let mut acc = 0.0;
    while arg < 6.0 {
        acc -= arg.recip();
        arg += 1.0;
    }
    // Asymptotic series:
    // ln(v) - 1/(2v) - 1/(12v^2) + 1/(120v^4) - 1/(252v^6).
    let t = (arg * arg).recip();
    acc + arg.ln() - 0.5 / arg - t * (1.0 / 12.0 - t * (1.0 / 120.0 - t / 252.0))
}
/// Numerically stable log(sum(exp(row))) using the max-shift trick.
///
/// Returns `NEG_INFINITY` for an empty slice or when every entry is
/// `NEG_INFINITY`, and `INFINITY` when any entry is `INFINITY` (the sum of
/// exponentials diverges). The previous version returned `NEG_INFINITY` for
/// BOTH infinite cases, silently flipping +inf to -inf.
fn logsumexp_row(row: &[f64]) -> f64 {
    let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    if max.is_infinite() {
        // +inf dominates the sum; -inf means the row is empty or all -inf.
        // Either way the answer is `max` itself (the old code wrongly
        // returned NEG_INFINITY for a +inf maximum, and the naive shifted
        // sum would produce NaN via (inf - inf)).
        return max;
    }
    let s: f64 = row.iter().map(|&v| (v - max).exp()).sum();
    max + s.ln()
}
/// Cholesky factorization A = L·Lᵀ of a symmetric matrix, returning the
/// lower-triangular factor L.
///
/// Non-positive pivots are clamped to a tiny positive value rather than
/// failing, so the routine tolerates inputs that are only numerically
/// positive semi-definite (the `Result` is kept for interface symmetry with
/// the other linear-algebra helpers).
fn cholesky(a: &Array2<f64>) -> Result<Array2<f64>> {
    let n = a.shape()[0];
    let mut l = Array2::<f64>::zeros((n, n));
    for row in 0..n {
        for col in 0..=row {
            // Entry minus the partial inner product of the rows built so far.
            let dot: f64 = (0..col).map(|k| l[[row, k]] * l[[col, k]]).sum();
            let mut s = a[[row, col]] - dot;
            if row == col {
                // Clamp so the factor stays real when the pivot underflows
                // or the matrix is not quite positive definite.
                if s <= 0.0 {
                    s = 1e-12;
                }
                l[[row, col]] = s.sqrt();
            } else {
                // Guard the division against a vanishing pivot.
                let pivot = l[[col, col]];
                l[[row, col]] = if pivot.abs() < 1e-15 { 0.0 } else { s / pivot };
            }
        }
    }
    Ok(l)
}
/// log(det(A)) for a symmetric positive-definite matrix via its Cholesky
/// factor: det(A) = prod(L[i][i])², hence log det = 2·Σ ln L[i][i].
fn log_det_pd(a: &Array2<f64>) -> Result<f64> {
    let l = cholesky(a)?;
    let sum_log_diag: f64 = (0..l.shape()[0]).map(|i| l[[i, i]].ln()).sum();
    Ok(2.0 * sum_log_diag)
}
/// Solves A·x = b given the lower-triangular Cholesky factor L of A
/// (A = L·Lᵀ): forward substitution for L·y = b, then back substitution for
/// Lᵀ·x = y.
///
/// Near-zero diagonal entries produce a zero solution component instead of
/// dividing, mirroring the pivot clamping performed in `cholesky`.
fn cholesky_solve(l: &Array2<f64>, b: ArrayView1<f64>) -> Array1<f64> {
    let n = l.shape()[0];
    // Forward pass: y = L⁻¹ b.
    let mut y = Array1::<f64>::zeros(n);
    for i in 0..n {
        let partial: f64 = (0..i).map(|k| l[[i, k]] * y[k]).sum();
        let rhs = b[i] - partial;
        let diag = l[[i, i]];
        y[i] = if diag.abs() < 1e-15 { 0.0 } else { rhs / diag };
    }
    // Backward pass: x = L⁻ᵀ y (note the transposed index l[[k, i]]).
    let mut x = Array1::<f64>::zeros(n);
    for i in (0..n).rev() {
        let partial: f64 = ((i + 1)..n).map(|k| l[[k, i]] * x[k]).sum();
        let rhs = y[i] - partial;
        let diag = l[[i, i]];
        x[i] = if diag.abs() < 1e-15 { 0.0 } else { rhs / diag };
    }
    x
}
/// Log-density of a multivariate normal N(mu, Sigma) at `x`, where `l` is
/// the lower-triangular Cholesky factor of Sigma (Sigma = L·Lᵀ).
///
/// log N(x) = -0.5·(d·ln(2π) + ln|Sigma| + (x-mu)ᵀ Sigma⁻¹ (x-mu)),
/// with ln|Sigma| = 2·Σ ln L[i][i].
fn log_mvn(x: ArrayView1<f64>, mu: ArrayView1<f64>, l: &Array2<f64>) -> f64 {
    let d = x.len() as f64;
    let diff: Array1<f64> = x
        .iter()
        .zip(mu.iter())
        .map(|(&xi, &mi)| xi - mi)
        .collect();
    // z = Sigma⁻¹ (x - mu), solved through the Cholesky factor.
    let z = cholesky_solve(l, diff.view());
    // Mahalanobis term (x-mu)ᵀ Sigma⁻¹ (x-mu) = (x-mu)ᵀ z.
    // BUG FIX: the previous code summed z_i², which equals
    // (x-mu)ᵀ Sigma⁻² (x-mu) and is only correct when Sigma is the identity.
    let maha: f64 = diff.iter().zip(z.iter()).map(|(&di, &zi)| di * zi).sum();
    let log_det_l: f64 = (0..l.shape()[0]).map(|i| l[[i, i]].ln()).sum::<f64>();
    -0.5 * (d * (2.0 * PI).ln() + 2.0 * log_det_l + maha)
}
/// k-means++ seeding: picks `k` initial centers from `data`, choosing each
/// subsequent center with probability proportional to its squared distance
/// from the nearest center chosen so far.
///
/// Randomness comes from a small deterministic 64-bit LCG driven by `seed`,
/// so results are fully reproducible.
fn kmeans_pp_init(data: ArrayView2<f64>, k: usize, seed: u64) -> Array2<f64> {
    let n = data.shape()[0];
    let d = data.shape()[1];
    // Knuth-style LCG; the top 53 state bits map to a uniform f64 in [0, 1).
    let step = |s: u64| s.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
    let next_unit = |s: &mut u64| -> f64 {
        *s = step(*s);
        (*s >> 11) as f64 / (1u64 << 53) as f64
    };
    let mut state = seed;
    let mut centers = Array2::<f64>::zeros((k, d));
    // First center: a (roughly) uniformly random row.
    state = step(state);
    centers.row_mut(0).assign(&data.row((state as usize) % n));
    for ci in 1..k {
        // Squared distance from every point to its nearest chosen center.
        let mut min_d2 = vec![f64::INFINITY; n];
        for (i, slot) in min_d2.iter_mut().enumerate() {
            for cj in 0..ci {
                let d2: f64 = data
                    .row(i)
                    .iter()
                    .zip(centers.row(cj).iter())
                    .map(|(&a, &b)| (a - b) * (a - b))
                    .sum();
                if d2 < *slot {
                    *slot = d2;
                }
            }
        }
        let total: f64 = min_d2.iter().sum();
        // Roulette-wheel sample an index proportional to min_d2; fall back to
        // the last index if rounding leaves the threshold positive.
        let mut threshold = next_unit(&mut state) * total;
        let mut chosen = n - 1;
        for (i, &w) in min_d2.iter().enumerate() {
            threshold -= w;
            if threshold <= 0.0 {
                chosen = i;
                break;
            }
        }
        centers.row_mut(ci).assign(&data.row(chosen));
    }
    centers
}
/// Fitted parameters of a Gaussian mixture model.
#[derive(Debug, Clone)]
pub struct GmmParams {
    /// Mixing weights, one per component; non-negative and summing to 1.
    pub weights: Array1<f64>,
    /// Component means, shape (n_components, n_features).
    pub means: Array2<f64>,
    /// Lower-triangular Cholesky factors of each component covariance.
    pub chol_covs: Vec<Array2<f64>>,
    /// Number of EM iterations actually performed.
    pub n_iter: usize,
    /// Whether the EM loop met the convergence tolerance.
    pub converged: bool,
    /// Mean per-sample log-likelihood at the last evaluated iteration.
    pub log_likelihood: f64,
}
impl GmmParams {
    /// Number of mixture components.
    pub fn n_components(&self) -> usize {
        self.weights.len()
    }

    /// Dimensionality of the feature space.
    pub fn n_features(&self) -> usize {
        self.means.shape()[1]
    }

    /// Posterior responsibilities p(component | sample) for every sample,
    /// returned as an (n_samples, n_components) matrix whose rows sum to 1.
    pub fn predict_proba(&self, data: ArrayView2<f64>) -> Result<Array2<f64>> {
        let n = data.shape()[0];
        let k = self.n_components();
        let mut resp = Array2::<f64>::zeros((n, k));
        for i in 0..n {
            // Joint log density log w_c + log N(x_i; mu_c, Sigma_c);
            // zero-weight components are excluded via -inf.
            let joint: Vec<f64> = (0..k)
                .map(|c| {
                    if self.weights[c] <= 0.0 {
                        f64::NEG_INFINITY
                    } else {
                        self.weights[c].ln()
                            + log_mvn(data.row(i), self.means.row(c), &self.chol_covs[c])
                    }
                })
                .collect();
            let norm = logsumexp_row(&joint);
            for (c, &lj) in joint.iter().enumerate() {
                resp[[i, c]] = (lj - norm).exp();
            }
        }
        Ok(resp)
    }

    /// Hard assignment: index of the most probable component per sample.
    pub fn predict(&self, data: ArrayView2<f64>) -> Result<Array1<usize>> {
        let proba = self.predict_proba(data)?;
        let (n, k) = (proba.shape()[0], proba.shape()[1]);
        let mut labels = Array1::<usize>::zeros(n);
        for i in 0..n {
            // Ties keep the lowest index (strict > comparison).
            let mut arg_best = 0;
            for c in 1..k {
                if proba[[i, c]] > proba[[i, arg_best]] {
                    arg_best = c;
                }
            }
            labels[i] = arg_best;
        }
        Ok(labels)
    }

    /// Mean per-sample log-likelihood of `data` under the fitted model.
    pub fn score(&self, data: ArrayView2<f64>) -> Result<f64> {
        let n = data.shape()[0];
        let k = self.n_components();
        let total: f64 = (0..n)
            .map(|i| {
                let terms: Vec<f64> = (0..k)
                    .filter(|&c| self.weights[c] > 0.0)
                    .map(|c| {
                        self.weights[c].ln()
                            + log_mvn(data.row(i), self.means.row(c), &self.chol_covs[c])
                    })
                    .collect();
                logsumexp_row(&terms)
            })
            .sum();
        Ok(total / n as f64)
    }

    /// Free-parameter count: (k-1) weights, k·d means, and k symmetric d×d
    /// covariances contributing d(d+1)/2 entries each.
    fn n_free_params(&self) -> usize {
        let k = self.n_components();
        let d = self.n_features();
        (k - 1) + k * d + k * (d * (d + 1) / 2)
    }

    /// Bayesian information criterion: -2·ln L + p·ln n (lower is better).
    pub fn bic(&self, data: ArrayView2<f64>) -> Result<f64> {
        let n = data.shape()[0] as f64;
        let total_ll = self.score(data)? * n;
        Ok(-2.0 * total_ll + self.n_free_params() as f64 * n.ln())
    }

    /// Akaike information criterion: -2·ln L + 2p (lower is better).
    pub fn aic(&self, data: ArrayView2<f64>) -> Result<f64> {
        let n = data.shape()[0] as f64;
        let total_ll = self.score(data)? * n;
        Ok(-2.0 * total_ll + 2.0 * self.n_free_params() as f64)
    }
}
/// Stateless namespace for fitting Gaussian mixture models with EM.
pub struct GaussianMixtureModel;
impl GaussianMixtureModel {
    /// Fits an `n_components`-component Gaussian mixture to `data` with EM.
    ///
    /// Initialization is deterministic: k-means++ seeding (fixed seed 42)
    /// followed by one hard-assignment M-step. Iteration stops when the mean
    /// per-sample log-likelihood changes by less than `tol`, or after
    /// `max_iter` iterations.
    ///
    /// # Errors
    /// Returns `ClusteringError::InvalidInput` when `n_components == 0`,
    /// when there are fewer samples than components, or when the data has
    /// zero features.
    pub fn fit(
        data: ArrayView2<f64>,
        n_components: usize,
        max_iter: usize,
        tol: f64,
    ) -> Result<GmmParams> {
        let n = data.shape()[0];
        let d = data.shape()[1];
        let k = n_components;
        if k == 0 {
            return Err(ClusteringError::InvalidInput(
                "n_components must be >= 1".to_string(),
            ));
        }
        if n < k {
            return Err(ClusteringError::InvalidInput(
                "n_samples must be >= n_components".to_string(),
            ));
        }
        if d == 0 {
            return Err(ClusteringError::InvalidInput(
                "n_features must be >= 1".to_string(),
            ));
        }
        // Diagonal ridge keeps covariance estimates positive definite.
        let reg = 1e-6_f64;
        let init_means = kmeans_pp_init(data, k, 42);
        // Hard-assign each sample to its nearest seed center to build the
        // initial responsibility matrix.
        let mut resp = Array2::<f64>::zeros((n, k));
        for i in 0..n {
            let mut best_c = 0;
            let mut best_d = f64::INFINITY;
            for c in 0..k {
                let d2: f64 = data
                    .row(i)
                    .iter()
                    .zip(init_means.row(c).iter())
                    .map(|(&a, &b)| (a - b) * (a - b))
                    .sum();
                if d2 < best_d {
                    best_d = d2;
                    best_c = c;
                }
            }
            resp[[i, best_c]] = 1.0;
        }
        let (mut weights, mut means, mut chol_covs) =
            Self::m_step(data, resp.view(), k, d, reg)?;
        let mut prev_ll = f64::NEG_INFINITY;
        let mut n_iter = 0;
        let mut converged = false;
        for iter in 0..max_iter {
            n_iter = iter + 1;
            // The E-step already computes every per-row normalizer, so the
            // mean log-likelihood comes back as a free by-product instead of
            // being recomputed from scratch (the old code evaluated every
            // log density twice per iteration).
            let (new_resp, ll) = Self::e_step(data, &weights, &means, &chol_covs, k)?;
            resp = new_resp;
            let done = (ll - prev_ll).abs() < tol;
            prev_ll = ll;
            // One M-step call serves both the converged and ongoing paths
            // (previously duplicated in each branch).
            let (w, m, c) = Self::m_step(data, resp.view(), k, d, reg)?;
            weights = w;
            means = m;
            chol_covs = c;
            if done {
                converged = true;
                break;
            }
        }
        Ok(GmmParams {
            weights,
            means,
            chol_covs,
            n_iter,
            converged,
            log_likelihood: prev_ll,
        })
    }

    /// E-step: posterior responsibilities for every (sample, component) pair,
    /// plus the mean per-sample log-likelihood under the current parameters.
    fn e_step(
        data: ArrayView2<f64>,
        weights: &Array1<f64>,
        means: &Array2<f64>,
        chol_covs: &[Array2<f64>],
        k: usize,
    ) -> Result<(Array2<f64>, f64)> {
        let n = data.shape()[0];
        let mut resp = Array2::<f64>::zeros((n, k));
        let mut total_ll = 0.0;
        for i in 0..n {
            // Joint log density per component; dead components contribute
            // nothing via -inf (exp(-inf) = 0 under the logsumexp shift).
            let joint: Vec<f64> = (0..k)
                .map(|c| {
                    if weights[c] <= 0.0 {
                        f64::NEG_INFINITY
                    } else {
                        weights[c].ln() + log_mvn(data.row(i), means.row(c), &chol_covs[c])
                    }
                })
                .collect();
            let lse = logsumexp_row(&joint);
            // The per-row normalizer IS log p(x_i); accumulate it here.
            total_ll += lse;
            for c in 0..k {
                resp[[i, c]] = (joint[c] - lse).exp();
            }
        }
        Ok((resp, total_ll / n as f64))
    }

    /// M-step: maximum-likelihood re-estimation of weights, means, and
    /// covariance Cholesky factors from the soft assignments `resp`.
    fn m_step(
        data: ArrayView2<f64>,
        resp: ArrayView2<f64>,
        k: usize,
        d: usize,
        reg: f64,
    ) -> Result<(Array1<f64>, Array2<f64>, Vec<Array2<f64>>)> {
        let n = data.shape()[0];
        // Effective sample count per component, floored to avoid division
        // by zero for empty components.
        let nk: Vec<f64> = (0..k)
            .map(|c| (0..n).map(|i| resp[[i, c]]).sum::<f64>().max(1e-10))
            .collect();
        let total_n: f64 = nk.iter().sum();
        let weights: Array1<f64> = nk.iter().map(|&nkc| nkc / total_n).collect();
        // Responsibility-weighted means.
        let mut means = Array2::<f64>::zeros((k, d));
        for c in 0..k {
            for i in 0..n {
                for f in 0..d {
                    means[[c, f]] += resp[[i, c]] * data[[i, f]];
                }
            }
            for f in 0..d {
                means[[c, f]] /= nk[c];
            }
        }
        // Weighted covariance per component (upper triangle mirrored), with
        // a ridge of `reg` on the diagonal before factorizing.
        let mut chol_covs = Vec::with_capacity(k);
        for c in 0..k {
            let mut cov = Array2::<f64>::zeros((d, d));
            for i in 0..n {
                for f1 in 0..d {
                    let diff_f1 = data[[i, f1]] - means[[c, f1]];
                    for f2 in f1..d {
                        let diff_f2 = data[[i, f2]] - means[[c, f2]];
                        let v = resp[[i, c]] * diff_f1 * diff_f2 / nk[c];
                        cov[[f1, f2]] += v;
                        if f2 != f1 {
                            cov[[f2, f1]] += v;
                        }
                    }
                }
            }
            for f in 0..d {
                cov[[f, f]] += reg;
            }
            chol_covs.push(cholesky(&cov)?);
        }
        Ok((weights, means, chol_covs))
    }
}
/// Result of fitting a truncated Dirichlet-process mixture model.
#[derive(Debug, Clone)]
pub struct DpmmResult {
    /// Expected stick-breaking mixture weights, one per truncated component.
    pub stick_weights: Array1<f64>,
    /// Variational posterior component means, shape (truncation, n_features).
    pub means: Array2<f64>,
    /// Per-component flag: expected weight cleared the activity threshold.
    pub active: Vec<bool>,
    /// Evidence lower bound at the last evaluated iteration.
    pub elbo: f64,
    /// Number of variational iterations performed.
    pub n_iter: usize,
    /// Whether the ELBO change fell below the tolerance.
    pub converged: bool,
    // Diagonal Cholesky factors of the per-component covariances
    // (the variational model treats covariances as diagonal).
    chol_covs: Vec<Array2<f64>>,
    // Cached count of `true` entries in `active`.
    n_active: usize,
}
impl DpmmResult {
    /// Truncation level (total number of components, active or not).
    pub fn n_components(&self) -> usize {
        self.stick_weights.len()
    }

    /// Number of components whose expected weight cleared the activity
    /// threshold during fitting.
    pub fn n_active_components(&self) -> usize {
        self.n_active
    }

    /// Posterior responsibilities over all truncated components; inactive or
    /// zero-weight components get probability 0. Rows sum to 1.
    pub fn predict_proba(&self, data: ArrayView2<f64>) -> Result<Array2<f64>> {
        let n = data.shape()[0];
        let t = self.n_components();
        let mut resp = Array2::<f64>::zeros((n, t));
        for i in 0..n {
            // Weighted log density per component; pruned components are
            // excluded via -inf before normalization.
            let joint: Vec<f64> = (0..t)
                .map(|c| {
                    let w = self.stick_weights[c];
                    if w <= 0.0 || !self.active[c] {
                        f64::NEG_INFINITY
                    } else {
                        w.ln() + log_mvn(data.row(i), self.means.row(c), &self.chol_covs[c])
                    }
                })
                .collect();
            let norm = logsumexp_row(&joint);
            for (c, &lj) in joint.iter().enumerate() {
                resp[[i, c]] = (lj - norm).exp();
            }
        }
        Ok(resp)
    }

    /// Hard assignment: the most probable component for each sample.
    pub fn predict(&self, data: ArrayView2<f64>) -> Result<Array1<usize>> {
        let proba = self.predict_proba(data)?;
        let (n, t) = (proba.shape()[0], proba.shape()[1]);
        let mut labels = Array1::<usize>::zeros(n);
        for i in 0..n {
            // Ties keep the lowest index (strict > comparison).
            let mut arg_best = 0;
            for c in 1..t {
                if proba[[i, c]] > proba[[i, arg_best]] {
                    arg_best = c;
                }
            }
            labels[i] = arg_best;
        }
        Ok(labels)
    }
}
/// Truncated variational Dirichlet-process Gaussian mixture model.
pub struct DirichletProcessMixtureModel {
    /// DP concentration parameter; larger values favor more clusters.
    pub alpha: f64,
    /// Stick-breaking truncation level (maximum number of components).
    pub truncation: usize,
    /// Maximum number of variational iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the change in ELBO between iterations.
    pub tol: f64,
    /// Minimum expected weight (scaled by 1/truncation) for a component to
    /// count as active.
    pub activity_threshold: f64,
}
impl DirichletProcessMixtureModel {
    /// Creates a model with concentration `alpha` and stick-breaking
    /// truncation `truncation`; iteration/convergence settings receive
    /// defaults that can be overridden through the public fields.
    pub fn new(alpha: f64, truncation: usize) -> Self {
        Self {
            alpha,
            truncation,
            max_iter: 200,
            tol: 1e-4,
            activity_threshold: 1e-2,
        }
    }

    /// Fits the truncated DP mixture with mean-field variational inference.
    ///
    /// The variational family factorizes into Beta posteriors over the
    /// stick-breaking fractions (`a_gamma`, `b_gamma`), per-component
    /// Gaussian/precision posteriors (`m`, `beta_k`, `nu_k`, `w_k` — the
    /// precision model is DIAGONAL: one Wishart-style scale per feature),
    /// and categorical assignment probabilities `phi`.
    ///
    /// # Errors
    /// Returns `ClusteringError::InvalidInput` for empty data or a zero
    /// truncation level.
    pub fn fit(&self, data: ArrayView2<f64>) -> Result<DpmmResult> {
        let n = data.shape()[0];
        let d = data.shape()[1];
        let t = self.truncation;
        if n == 0 || d == 0 {
            return Err(ClusteringError::InvalidInput(
                "Data must be non-empty".to_string(),
            ));
        }
        if t < 1 {
            return Err(ClusteringError::InvalidInput(
                "truncation must be >= 1".to_string(),
            ));
        }
        let reg = 1e-6_f64;
        let alpha = self.alpha;
        // Seed at most `t` distinct components from k-means++ centers.
        let k_init = t.min(n);
        let init_means = kmeans_pp_init(data, k_init, 7);
        // phi[i][k]: responsibility of component k for sample i, initialized
        // as hard assignments to the nearest seed center.
        let mut phi = Array2::<f64>::zeros((n, t));
        for i in 0..n {
            let mut best_c = 0;
            let mut best_d = f64::INFINITY;
            for c in 0..k_init {
                let d2: f64 = data
                    .row(i)
                    .iter()
                    .zip(init_means.row(c).iter())
                    .map(|(&a, &b)| (a - b) * (a - b))
                    .sum();
                if d2 < best_d {
                    best_d = d2;
                    best_c = c;
                }
            }
            phi[[i, best_c]] = 1.0;
        }
        // Beta posterior parameters for the stick-breaking fractions v_k.
        let mut a_gamma = Array1::<f64>::from_elem(t, 1.0);
        let mut b_gamma = Array1::<f64>::from_elem(t, alpha);
        // Gaussian posterior means `m`, scale counts `beta_k`, degrees of
        // freedom `nu_k`, and per-feature precision scales `w_k`.
        let mut m = Array2::<f64>::zeros((t, d)); let mut beta_k = Array1::<f64>::from_elem(t, 1.0); let mut nu_k = Array1::<f64>::from_elem(t, d as f64 + 1.0); let mut w_k = Array2::<f64>::from_elem((t, d), 1.0);
        for c in 0..k_init {
            for f in 0..d {
                m[[c, f]] = init_means[[c, f]];
            }
        }
        let mut prev_elbo = f64::NEG_INFINITY;
        let mut n_iter = 0;
        let mut converged = false;
        for iter in 0..self.max_iter {
            n_iter = iter + 1;
            // E[log pi_k] under stick breaking:
            // E[log v_k] + sum_{j<k} E[log(1 - v_j)].
            let mut e_log_pi = Array1::<f64>::zeros(t);
            let mut cumsum_b = 0.0;
            for k in 0..t {
                let e_log_v_k = digamma(a_gamma[k]) - digamma(a_gamma[k] + b_gamma[k]);
                let e_log_1mv_k = digamma(b_gamma[k]) - digamma(a_gamma[k] + b_gamma[k]);
                e_log_pi[k] = e_log_v_k + cumsum_b;
                cumsum_b += e_log_1mv_k;
            }
            // E[log |Lambda_k|] summed over features; the digamma argument is
            // clamped at 0.5 to stay in the function's domain for small dof.
            let e_log_lam: Vec<f64> = (0..t)
                .map(|k| {
                    (0..d)
                        .map(|f| {
                            let dof_f = (nu_k[k] + 1.0 - f as f64) / 2.0;
                            digamma(dof_f.max(0.5)) + (2.0 * w_k[[k, f]]).ln()
                        })
                        .sum::<f64>()
                })
                .collect();
            // Update assignments phi: softmax of the expected log joint.
            for i in 0..n {
                let mut log_rho = Vec::with_capacity(t);
                for k in 0..t {
                    // Expected Mahalanobis term plus the 1/beta_k mean-
                    // uncertainty correction, per feature.
                    let trace_term: f64 = (0..d)
                        .map(|f| {
                            nu_k[k] * w_k[[k, f]] * (data[[i, f]] - m[[k, f]]).powi(2)
                                + 1.0 / beta_k[k]
                        })
                        .sum();
                    log_rho.push(e_log_pi[k] + 0.5 * e_log_lam[k]
                        - 0.5 * d as f64 * (2.0 * PI).ln()
                        - 0.5 * trace_term);
                }
                let lse = logsumexp_row(&log_rho);
                for k in 0..t {
                    phi[[i, k]] = (log_rho[k] - lse).exp();
                }
            }
            // Effective counts per component, floored to avoid division by 0.
            let nk: Vec<f64> = (0..t)
                .map(|k| (0..n).map(|i| phi[[i, k]]).sum::<f64>().max(1e-10))
                .collect();
            // Stick-breaking Beta updates:
            // a_k = 1 + N_k, b_k = alpha + sum_{j>k} N_j.
            for k in 0..t {
                let sum_after: f64 = nk[(k + 1)..].iter().sum();
                a_gamma[k] = 1.0 + nk[k];
                b_gamma[k] = alpha + sum_after;
            }
            // Gaussian/precision posterior updates (prior mean is the zero
            // vector, beta_0 = 1, nu_0 = d + 1).
            for k in 0..t {
                let beta_0 = 1.0;
                let nu_0 = d as f64 + 1.0;
                beta_k[k] = beta_0 + nk[k];
                // Responsibility-weighted sample mean of component k.
                let mut x_bar = vec![0.0_f64; d];
                for i in 0..n {
                    for f in 0..d {
                        x_bar[f] += phi[[i, k]] * data[[i, f]];
                    }
                }
                for f in 0..d {
                    x_bar[f] /= nk[k];
                    // `beta_0 * 0.0` spells out the zero prior mean.
                    m[[k, f]] = (beta_0 * 0.0 + nk[k] * x_bar[f]) / beta_k[k];
                }
                nu_k[k] = nu_0 + nk[k];
                for f in 0..d {
                    // Weighted scatter of feature f around the component mean.
                    let mut scatter = 0.0;
                    for i in 0..n {
                        scatter += phi[[i, k]] * (data[[i, f]] - x_bar[f]).powi(2);
                    }
                    // Prior-mean correction of the normal-Wishart update
                    // (simplifies because the prior mean is zero).
                    let bc_correction = beta_0 * nk[k] / beta_k[k] * x_bar[f].powi(2);
                    // Inverse-scale update; 1/(1+reg) plays the role of the
                    // prior inverse scale W_0^{-1}.
                    w_k[[k, f]] = 1.0 / (1.0 / (1.0 + reg) + scatter + bc_correction);
                }
            }
            let elbo = Self::compute_elbo(
                data,
                &phi,
                &a_gamma,
                &b_gamma,
                &m,
                &beta_k,
                &nu_k,
                &w_k,
                alpha,
                n,
                d,
                t,
            );
            // Stop when the ELBO (used purely as a convergence monitor here)
            // stops moving.
            if (elbo - prev_elbo).abs() < self.tol {
                converged = true;
                prev_elbo = elbo;
                break;
            }
            prev_elbo = elbo;
        }
        // Expected mixture weights from the stick-breaking construction:
        // E[pi_k] = E[v_k] * prod_{j<k} (1 - E[v_j]), tracked in log space.
        let mut expected_weights = Array1::<f64>::zeros(t);
        let mut log_remaining: f64 = 0.0;
        for k in 0..t {
            let e_v_k = a_gamma[k] / (a_gamma[k] + b_gamma[k]);
            expected_weights[k] = e_v_k * log_remaining.exp();
            log_remaining += (1.0 - e_v_k).ln();
        }
        // A component is active when its weight clears the truncation-scaled
        // threshold.
        let active: Vec<bool> = (0..t)
            .map(|k| expected_weights[k] > self.activity_threshold / t as f64)
            .collect();
        let n_active = active.iter().filter(|&&a| a).count();
        // Diagonal covariance Cholesky factors from the expected precisions
        // nu_k * w_kf: for a diagonal matrix the factor is elementwise sqrt.
        let mut chol_covs = Vec::with_capacity(t);
        for k in 0..t {
            let mut cov = Array2::<f64>::zeros((d, d));
            for f in 0..d {
                let var = (1.0 / (nu_k[k] * w_k[[k, f]])).max(reg);
                cov[[f, f]] = var.sqrt(); }
            chol_covs.push(cov);
        }
        let final_means = m.clone();
        Ok(DpmmResult {
            stick_weights: expected_weights,
            means: final_means,
            active,
            elbo: prev_elbo,
            n_iter,
            converged,
            chol_covs,
            n_active,
        })
    }

    /// Simplified evidence lower bound used only as a convergence monitor.
    ///
    /// NOTE(review): several standard ELBO terms are missing or approximate —
    /// the Gaussian-Wishart prior/entropy contributions are only partially
    /// represented, and `beta_entropy` mixes `beta_k` (a Gaussian scale
    /// count) into what otherwise resembles a Beta-distribution entropy.
    /// The absolute value is therefore not a true ELBO; only its
    /// stabilization across iterations matters here. Verify against a
    /// reference derivation before using the value for model comparison.
    #[allow(clippy::too_many_arguments)]
    fn compute_elbo(
        data: ArrayView2<f64>,
        phi: &Array2<f64>,
        a_gamma: &Array1<f64>,
        b_gamma: &Array1<f64>,
        m: &Array2<f64>,
        beta_k: &Array1<f64>,
        nu_k: &Array1<f64>,
        w_k: &Array2<f64>,
        alpha: f64,
        n: usize,
        d: usize,
        t: usize,
    ) -> f64 {
        // Expected log-likelihood of the data under the soft assignments;
        // negligible responsibilities are skipped for speed.
        let mut ll = 0.0;
        for i in 0..n {
            for k in 0..t {
                if phi[[i, k]] < 1e-15 {
                    continue;
                }
                let log_norm = -(d as f64) / 2.0 * (2.0 * PI).ln();
                let neg_quad: f64 = -(0..d)
                    .map(|f| nu_k[k] * w_k[[k, f]] * (data[[i, f]] - m[[k, f]]).powi(2))
                    .sum::<f64>()
                    / 2.0;
                let e_log_lam: f64 = (0..d)
                    .map(|f| {
                        let dof_f = (nu_k[k] + 1.0 - f as f64) / 2.0;
                        digamma(dof_f.max(0.5)) + (2.0 * w_k[[k, f]]).ln()
                    })
                    .sum::<f64>()
                    / 2.0;
                ll += phi[[i, k]] * (log_norm + e_log_lam + neg_quad);
            }
        }
        // Entropy of the categorical assignment posteriors q(z).
        let mut z_term = 0.0;
        for i in 0..n {
            for k in 0..t {
                let phi_ik = phi[[i, k]];
                if phi_ik > 1e-15 {
                    z_term -= phi_ik * phi_ik.ln(); }
            }
        }
        // Stick-breaking prior contribution: (alpha - 1) * E[log(1 - v_k)].
        let dp_term: f64 = (0..t)
            .map(|k| (alpha - 1.0) * (digamma(b_gamma[k]) - digamma(a_gamma[k] + b_gamma[k])))
            .sum();
        // Rough entropy-like term for the Beta posteriors; see NOTE above —
        // the `beta_k` factor looks misplaced.
        let beta_entropy: f64 = (0..t)
            .map(|k| {
                let ab = a_gamma[k] + b_gamma[k];
                let ent = (beta_k[k]).ln() - (a_gamma[k] - 1.0) * digamma(a_gamma[k])
                    + (ab).ln()
                    - (b_gamma[k] - 1.0) * digamma(b_gamma[k])
                    + digamma(ab);
                ent
            })
            .sum();
        ll + z_term + dp_term + beta_entropy
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// 12 points in 2-D forming two tight, well-separated clusters
    /// centered near (1, 1) and (5, 5).
    fn two_cluster_data() -> Array2<f64> {
        let flat = vec![
            1.0, 1.0, 1.1, 0.9, 0.9, 1.1, 1.0, 1.0, 0.8, 1.2, 1.2, 0.8,
            5.0, 5.0, 5.1, 4.9, 4.9, 5.1, 5.0, 5.0, 4.8, 5.2, 5.2, 4.8,
        ];
        Array2::from_shape_vec((12, 2), flat).expect("data")
    }

    #[test]
    fn test_gmm_fit_basic() {
        let data = two_cluster_data();
        let fitted =
            GaussianMixtureModel::fit(data.view(), 2, 100, 1e-4).expect("gmm fit");
        assert_eq!(fitted.n_components(), 2);
        assert_eq!(fitted.n_features(), 2);
        assert!(fitted.converged || fitted.n_iter > 0);
    }

    #[test]
    fn test_gmm_predict_proba() {
        let data = two_cluster_data();
        let fitted =
            GaussianMixtureModel::fit(data.view(), 2, 100, 1e-4).expect("gmm fit");
        let proba = fitted.predict_proba(data.view()).expect("predict_proba");
        assert_eq!(proba.shape(), [12, 2]);
        // Each responsibility row must be a probability distribution.
        for i in 0..12 {
            let row_sum: f64 = proba.row(i).iter().sum();
            assert!((row_sum - 1.0).abs() < 1e-6, "row {i} sums to {row_sum}");
        }
    }

    #[test]
    fn test_gmm_predict_hard() {
        let data = two_cluster_data();
        let fitted =
            GaussianMixtureModel::fit(data.view(), 2, 100, 1e-4).expect("gmm fit");
        let labels = fitted.predict(data.view()).expect("predict");
        assert_eq!(labels.len(), 12);
        // With k = 2 there can be at most two distinct labels.
        let unique: std::collections::HashSet<_> = labels.iter().copied().collect();
        assert!(unique.len() <= 2);
    }

    #[test]
    fn test_gmm_score_finite() {
        let data = two_cluster_data();
        let fitted =
            GaussianMixtureModel::fit(data.view(), 2, 100, 1e-4).expect("gmm fit");
        let score = fitted.score(data.view()).expect("score");
        assert!(score.is_finite(), "score must be finite, got {score}");
    }

    #[test]
    fn test_gmm_bic_aic() {
        let data = two_cluster_data();
        let fitted =
            GaussianMixtureModel::fit(data.view(), 2, 100, 1e-4).expect("gmm fit");
        let bic = fitted.bic(data.view()).expect("bic");
        let aic = fitted.aic(data.view()).expect("aic");
        assert!(bic.is_finite());
        assert!(aic.is_finite());
    }

    #[test]
    fn test_gmm_k1_trivial() {
        let data = two_cluster_data();
        let fitted =
            GaussianMixtureModel::fit(data.view(), 1, 50, 1e-4).expect("gmm k=1");
        let labels = fitted.predict(data.view()).expect("predict k=1");
        // A single component must absorb every sample.
        assert!(labels.iter().all(|&l| l == 0));
    }

    #[test]
    fn test_gmm_invalid_k() {
        let data = two_cluster_data();
        assert!(GaussianMixtureModel::fit(data.view(), 0, 50, 1e-4).is_err());
    }

    #[test]
    fn test_dpmm_fit_basic() {
        let data = two_cluster_data();
        let fitted = DirichletProcessMixtureModel::new(1.0, 6)
            .fit(data.view())
            .expect("dpmm fit");
        assert_eq!(fitted.n_components(), 6);
        assert!(fitted.n_iter > 0);
        assert!(fitted.n_active_components() >= 1);
    }

    #[test]
    fn test_dpmm_predict_proba() {
        let data = two_cluster_data();
        let fitted = DirichletProcessMixtureModel::new(1.0, 4)
            .fit(data.view())
            .expect("dpmm fit");
        let proba = fitted.predict_proba(data.view()).expect("proba");
        assert_eq!(proba.shape()[0], 12);
        assert_eq!(proba.shape()[1], 4);
        for i in 0..12 {
            let row_sum: f64 = proba.row(i).iter().sum();
            assert!((row_sum - 1.0).abs() < 1e-5, "row {i} sum {row_sum}");
        }
    }

    #[test]
    fn test_dpmm_predict_hard() {
        let data = two_cluster_data();
        let fitted = DirichletProcessMixtureModel::new(1.0, 4)
            .fit(data.view())
            .expect("dpmm fit");
        let labels = fitted.predict(data.view()).expect("predict");
        assert_eq!(labels.len(), 12);
    }

    #[test]
    fn test_dpmm_alpha_concentration() {
        let data = two_cluster_data();
        // A larger concentration parameter should never yield fewer active
        // components than a much smaller one.
        let r_low = DirichletProcessMixtureModel::new(0.01, 8)
            .fit(data.view())
            .expect("low alpha");
        let r_high = DirichletProcessMixtureModel::new(10.0, 8)
            .fit(data.view())
            .expect("high alpha");
        assert!(r_high.n_active_components() >= r_low.n_active_components());
    }
}