Skip to main content

scirs2_cluster/
soft_clustering.rs

1//! Probabilistic soft clustering algorithms.
2//!
3//! This module provides two complementary soft-clustering algorithms:
4//!
5//! * [`GaussianMixtureModel`] – parametric EM-based GMM with BIC/AIC model
6//!   selection, k-means++ initialisation, and full soft-assignment output.
7//!
8//! * [`DirichletProcessMixtureModel`] – nonparametric Bayesian mixture with
9//!   stick-breaking process and variational mean-field inference that
10//!   automatically infers the number of active components.
11//!
12//! # Example – GMM
13//!
14//! ```rust
15//! use scirs2_cluster::soft_clustering::{GaussianMixtureModel, GmmParams};
16//! use scirs2_core::ndarray::Array2;
17//!
18//! let data = Array2::from_shape_vec((6, 2), vec![
19//!     1.0, 2.0,  1.2, 1.8,  0.8, 1.9,
20//!     4.0, 5.0,  4.2, 4.8,  3.9, 5.1,
21//! ]).expect("operation should succeed");
22//!
23//! let params = GaussianMixtureModel::fit(data.view(), 2, 100, 1e-4).expect("operation should succeed");
24//! let proba  = params.predict_proba(data.view()).expect("operation should succeed");
25//! assert_eq!(proba.shape(), [6, 2]);
26//! ```
27
28use std::f64::consts::PI;
29
30use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
31
32use crate::error::{ClusteringError, Result};
33
34// ═══════════════════════════════════════════════════════════════════════════
35// Internal maths helpers
36// ═══════════════════════════════════════════════════════════════════════════
37
/// Digamma function ψ(x) for positive `x`.
///
/// The argument is first shifted upward with the recurrence
/// ψ(x) = ψ(x + 1) − 1/x until it is at least 6, where the asymptotic
/// (Stirling-type) expansion
/// ψ(z) ≈ ln z − 1/(2z) − 1/(12z²) + 1/(120z⁴) − 1/(252z⁶)
/// is accurate.
///
/// Returns `-∞` for `x <= 0`; a NaN argument yields NaN.
fn digamma(x: f64) -> f64 {
    if x <= 0.0 {
        return f64::NEG_INFINITY;
    }

    // Recurrence shift: accumulate the -1/z corrections while moving z up.
    let mut z = x;
    let mut shift = 0.0;
    while z < 6.0 {
        shift -= 1.0 / z;
        z += 1.0;
    }

    // Asymptotic series tail in powers of 1/z².
    let z2_inv = 1.0 / (z * z);
    let tail = z2_inv * (1.0 / 12.0 - z2_inv * (1.0 / 120.0 - z2_inv / 252.0));
    shift + (z.ln() - 0.5 / z) - tail
}
56
/// Numerically stable `ln(Σ exp(v))` over a slice of log-values.
///
/// The maximum is factored out before exponentiating so large magnitudes
/// neither overflow nor underflow.
///
/// Special cases:
/// * an empty slice, or one whose entries are all `-∞` (a sum of zero
///   probabilities), gives `-∞`;
/// * any `+∞` entry gives `+∞` (naively `exp(+∞ − +∞)` would be NaN).
fn logsumexp_row(row: &[f64]) -> f64 {
    let max = row.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if max == f64::NEG_INFINITY {
        return f64::NEG_INFINITY;
    }
    if max == f64::INFINITY {
        return f64::INFINITY;
    }
    let s: f64 = row.iter().map(|&v| (v - max).exp()).sum();
    max + s.ln()
}
66
67/// Cholesky decomposition of an `(n × n)` symmetric positive-definite matrix.
68/// Returns the lower-triangular factor `L` s.t. `A = L L^T`.
69fn cholesky(a: &Array2<f64>) -> Result<Array2<f64>> {
70    let n = a.shape()[0];
71    let mut l = Array2::<f64>::zeros((n, n));
72    for i in 0..n {
73        for j in 0..=i {
74            let mut s = a[[i, j]];
75            for k in 0..j {
76                s -= l[[i, k]] * l[[j, k]];
77            }
78            if i == j {
79                if s <= 0.0 {
80                    s = 1e-12;
81                }
82                l[[i, j]] = s.sqrt();
83            } else if l[[j, j]].abs() < 1e-15 {
84                l[[i, j]] = 0.0;
85            } else {
86                l[[i, j]] = s / l[[j, j]];
87            }
88        }
89    }
90    Ok(l)
91}
92
93/// Log-determinant of a positive-definite matrix via Cholesky.
94fn log_det_pd(a: &Array2<f64>) -> Result<f64> {
95    let l = cholesky(a)?;
96    let n = l.shape()[0];
97    let mut log_det = 0.0;
98    for i in 0..n {
99        log_det += 2.0 * l[[i, i]].ln();
100    }
101    Ok(log_det)
102}
103
104/// Solve `L L^T x = b` (Cholesky back-substitution).
105fn cholesky_solve(l: &Array2<f64>, b: ArrayView1<f64>) -> Array1<f64> {
106    let n = l.shape()[0];
107    let mut y = Array1::<f64>::zeros(n);
108    // Forward substitution L y = b
109    for i in 0..n {
110        let mut s = b[i];
111        for k in 0..i {
112            s -= l[[i, k]] * y[k];
113        }
114        y[i] = if l[[i, i]].abs() < 1e-15 {
115            0.0
116        } else {
117            s / l[[i, i]]
118        };
119    }
120    // Back substitution L^T x = y
121    let mut x = Array1::<f64>::zeros(n);
122    for i in (0..n).rev() {
123        let mut s = y[i];
124        for k in (i + 1)..n {
125            s -= l[[k, i]] * x[k];
126        }
127        x[i] = if l[[i, i]].abs() < 1e-15 {
128            0.0
129        } else {
130            s / l[[i, i]]
131        };
132    }
133    x
134}
135
136/// Log of the multivariate Gaussian pdf N(x | mu, Sigma).
137/// Uses the Cholesky factor `l` of Sigma.
138fn log_mvn(x: ArrayView1<f64>, mu: ArrayView1<f64>, l: &Array2<f64>) -> f64 {
139    let d = x.len() as f64;
140    let diff: Array1<f64> = x.iter().zip(mu.iter()).map(|(&xi, &mi)| xi - mi).collect();
141    let z = cholesky_solve(l, diff.view());
142    let maha: f64 = z.iter().map(|&v| v * v).sum();
143    let log_det_l: f64 = (0..l.shape()[0]).map(|i| l[[i, i]].ln()).sum::<f64>();
144    -0.5 * (d * (2.0 * PI).ln() + 2.0 * log_det_l + maha)
145}
146
147/// K-means++ initialisation; returns `(n_components × n_features)` centroid array.
148fn kmeans_pp_init(data: ArrayView2<f64>, k: usize, seed: u64) -> Array2<f64> {
149    let n = data.shape()[0];
150    let d = data.shape()[1];
151
152    let mut rng_state = seed;
153    let lcg = |s: u64| {
154        s.wrapping_mul(6364136223846793005)
155            .wrapping_add(1442695040888963407)
156    };
157    let rand_f64 = |s: &mut u64| -> f64 {
158        *s = lcg(*s);
159        (*s >> 11) as f64 / (1u64 << 53) as f64
160    };
161
162    let mut centers = Array2::<f64>::zeros((k, d));
163    // First centre: random row
164    rng_state = lcg(rng_state);
165    let first = (rng_state as usize) % n;
166    centers.row_mut(0).assign(&data.row(first));
167
168    for ci in 1..k {
169        // For each point compute min distance^2 to chosen centres
170        let mut dists = Vec::with_capacity(n);
171        let mut sum_d = 0.0;
172        for i in 0..n {
173            let mut min_d2 = f64::INFINITY;
174            for cj in 0..ci {
175                let d2: f64 = data
176                    .row(i)
177                    .iter()
178                    .zip(centers.row(cj).iter())
179                    .map(|(&a, &b)| (a - b) * (a - b))
180                    .sum();
181                if d2 < min_d2 {
182                    min_d2 = d2;
183                }
184            }
185            dists.push(min_d2);
186            sum_d += min_d2;
187        }
188        // Sample proportionally to distance
189        let mut u = rand_f64(&mut rng_state) * sum_d;
190        let mut chosen = n - 1;
191        for (i, &d_i) in dists.iter().enumerate() {
192            u -= d_i;
193            if u <= 0.0 {
194                chosen = i;
195                break;
196            }
197        }
198        centers.row_mut(ci).assign(&data.row(chosen));
199    }
200    centers
201}
202
203// ═══════════════════════════════════════════════════════════════════════════
204// GMM – fitted parameters bundle
205// ═══════════════════════════════════════════════════════════════════════════
206
207/// Fitted parameters of a Gaussian Mixture Model.
208///
209/// All heavy computation is done in [`GaussianMixtureModel::fit`]; this
210/// struct is a pure data container that also exposes `predict_proba`,
211/// `predict`, `score`, `bic`, and `aic`.
212#[derive(Debug, Clone)]
213pub struct GmmParams {
214    /// Mixture weights, shape `(k,)`.
215    pub weights: Array1<f64>,
216    /// Component means, shape `(k, d)`.
217    pub means: Array2<f64>,
218    /// Cholesky factors of component covariances, each shape `(d, d)`.
219    pub chol_covs: Vec<Array2<f64>>,
220    /// Number of EM iterations performed.
221    pub n_iter: usize,
222    /// Whether the EM converged.
223    pub converged: bool,
224    /// Log-likelihood per sample at convergence.
225    pub log_likelihood: f64,
226}
227
228impl GmmParams {
229    /// Number of mixture components.
230    pub fn n_components(&self) -> usize {
231        self.weights.len()
232    }
233
234    /// Feature dimension.
235    pub fn n_features(&self) -> usize {
236        self.means.shape()[1]
237    }
238
239    /// Compute soft assignments (posterior responsibilities).
240    ///
241    /// Returns an `(n_samples, k)` array where each row sums to 1.
242    pub fn predict_proba(&self, data: ArrayView2<f64>) -> Result<Array2<f64>> {
243        let n = data.shape()[0];
244        let k = self.n_components();
245        let mut log_resp = Array2::<f64>::zeros((n, k));
246
247        for i in 0..n {
248            for c in 0..k {
249                if self.weights[c] <= 0.0 {
250                    log_resp[[i, c]] = f64::NEG_INFINITY;
251                    continue;
252                }
253                log_resp[[i, c]] = self.weights[c].ln()
254                    + log_mvn(data.row(i), self.means.row(c), &self.chol_covs[c]);
255            }
256            // Normalise in log space
257            let row: Vec<f64> = (0..k).map(|c| log_resp[[i, c]]).collect();
258            let lse = logsumexp_row(&row);
259            for c in 0..k {
260                log_resp[[i, c]] = (log_resp[[i, c]] - lse).exp();
261            }
262        }
263        Ok(log_resp)
264    }
265
266    /// Hard cluster assignments: argmax of `predict_proba`.
267    pub fn predict(&self, data: ArrayView2<f64>) -> Result<Array1<usize>> {
268        let proba = self.predict_proba(data)?;
269        let n = proba.shape()[0];
270        let k = proba.shape()[1];
271        let mut labels = Array1::<usize>::zeros(n);
272        for i in 0..n {
273            let mut best = 0;
274            let mut best_p = proba[[i, 0]];
275            for c in 1..k {
276                if proba[[i, c]] > best_p {
277                    best_p = proba[[i, c]];
278                    best = c;
279                }
280            }
281            labels[i] = best;
282        }
283        Ok(labels)
284    }
285
286    /// Mean log-likelihood over `data`.
287    pub fn score(&self, data: ArrayView2<f64>) -> Result<f64> {
288        let n = data.shape()[0];
289        let k = self.n_components();
290        let mut total_ll = 0.0;
291        for i in 0..n {
292            let mut log_terms: Vec<f64> = Vec::with_capacity(k);
293            for c in 0..k {
294                if self.weights[c] > 0.0 {
295                    log_terms.push(
296                        self.weights[c].ln()
297                            + log_mvn(data.row(i), self.means.row(c), &self.chol_covs[c]),
298                    );
299                }
300            }
301            total_ll += logsumexp_row(&log_terms);
302        }
303        Ok(total_ll / n as f64)
304    }
305
306    /// Number of free parameters in the model.
307    ///
308    /// For full-covariance GMM:
309    ///   k - 1  (weights) + k*d (means) + k*(d*(d+1)/2) (covariances)
310    fn n_free_params(&self) -> usize {
311        let k = self.n_components();
312        let d = self.n_features();
313        (k - 1) + k * d + k * (d * (d + 1) / 2)
314    }
315
316    /// Bayesian Information Criterion.
317    ///
318    /// BIC = -2 * log_likelihood * n_samples + p * ln(n_samples)
319    pub fn bic(&self, data: ArrayView2<f64>) -> Result<f64> {
320        let n = data.shape()[0] as f64;
321        let ll = self.score(data)? * n;
322        let p = self.n_free_params() as f64;
323        Ok(-2.0 * ll + p * n.ln())
324    }
325
326    /// Akaike Information Criterion.
327    ///
328    /// AIC = -2 * log_likelihood * n_samples + 2 * p
329    pub fn aic(&self, data: ArrayView2<f64>) -> Result<f64> {
330        let n = data.shape()[0] as f64;
331        let ll = self.score(data)? * n;
332        let p = self.n_free_params() as f64;
333        Ok(-2.0 * ll + 2.0 * p)
334    }
335}
336
337// ═══════════════════════════════════════════════════════════════════════════
338// GaussianMixtureModel – EM algorithm
339// ═══════════════════════════════════════════════════════════════════════════
340
341/// Gaussian Mixture Model with Expectation-Maximisation training.
342///
343/// Initialises with k-means++ for robustness, then iterates E / M steps
344/// until log-likelihood change drops below `tol` or `max_iter` is reached.
345/// Covariance matrices are full (unconstrained) with a regularisation term
346/// added to the diagonal to ensure numerical stability.
347pub struct GaussianMixtureModel;
348
349impl GaussianMixtureModel {
350    /// Fit a GMM to `data`.
351    ///
352    /// # Arguments
353    ///
354    /// * `data`         – `(n_samples, n_features)` input array.
355    /// * `n_components` – Number of Gaussian components `k`.
356    /// * `max_iter`     – Maximum EM iterations.
357    /// * `tol`          – Convergence tolerance on mean log-likelihood change.
358    ///
359    /// # Returns
360    ///
361    /// A [`GmmParams`] bundle with the fitted parameters.
362    pub fn fit(
363        data: ArrayView2<f64>,
364        n_components: usize,
365        max_iter: usize,
366        tol: f64,
367    ) -> Result<GmmParams> {
368        let n = data.shape()[0];
369        let d = data.shape()[1];
370        let k = n_components;
371
372        if k == 0 {
373            return Err(ClusteringError::InvalidInput(
374                "n_components must be >= 1".to_string(),
375            ));
376        }
377        if n < k {
378            return Err(ClusteringError::InvalidInput(
379                "n_samples must be >= n_components".to_string(),
380            ));
381        }
382        if d == 0 {
383            return Err(ClusteringError::InvalidInput(
384                "n_features must be >= 1".to_string(),
385            ));
386        }
387
388        let reg = 1e-6_f64;
389
390        // ── Initialise with k-means++ ────────────────────────────────────
391        let init_means = kmeans_pp_init(data, k, 42);
392
393        // Initial responsibilities: hard-assign each point to nearest centre
394        let mut resp = Array2::<f64>::zeros((n, k));
395        for i in 0..n {
396            let mut best_c = 0;
397            let mut best_d = f64::INFINITY;
398            for c in 0..k {
399                let d2: f64 = data
400                    .row(i)
401                    .iter()
402                    .zip(init_means.row(c).iter())
403                    .map(|(&a, &b)| (a - b) * (a - b))
404                    .sum();
405                if d2 < best_d {
406                    best_d = d2;
407                    best_c = c;
408                }
409            }
410            resp[[i, best_c]] = 1.0;
411        }
412
413        // Initial M-step from hard assignments
414        let (mut weights, mut means, mut chol_covs) = Self::m_step(data, resp.view(), k, d, reg)?;
415
416        let mut prev_ll = f64::NEG_INFINITY;
417        let mut n_iter = 0;
418        let mut converged = false;
419
420        for iter in 0..max_iter {
421            n_iter = iter + 1;
422
423            // ── E-step ──────────────────────────────────────────────────
424            resp = Self::e_step(data, &weights, &means, &chol_covs, k)?;
425
426            // ── Compute log-likelihood ───────────────────────────────────
427            let ll = Self::mean_log_likelihood(data, &weights, &means, &chol_covs, k);
428
429            if (ll - prev_ll).abs() < tol {
430                converged = true;
431                prev_ll = ll;
432                // One more M-step to stay consistent
433                let (w, m, c) = Self::m_step(data, resp.view(), k, d, reg)?;
434                weights = w;
435                means = m;
436                chol_covs = c;
437                break;
438            }
439            prev_ll = ll;
440
441            // ── M-step ──────────────────────────────────────────────────
442            let (w, m, c) = Self::m_step(data, resp.view(), k, d, reg)?;
443            weights = w;
444            means = m;
445            chol_covs = c;
446        }
447
448        Ok(GmmParams {
449            weights,
450            means,
451            chol_covs,
452            n_iter,
453            converged,
454            log_likelihood: prev_ll,
455        })
456    }
457
458    // ── Private helpers ─────────────────────────────────────────────────────
459
460    fn e_step(
461        data: ArrayView2<f64>,
462        weights: &Array1<f64>,
463        means: &Array2<f64>,
464        chol_covs: &[Array2<f64>],
465        k: usize,
466    ) -> Result<Array2<f64>> {
467        let n = data.shape()[0];
468        let mut log_resp = Array2::<f64>::zeros((n, k));
469
470        for i in 0..n {
471            for c in 0..k {
472                if weights[c] <= 0.0 {
473                    log_resp[[i, c]] = f64::NEG_INFINITY;
474                    continue;
475                }
476                log_resp[[i, c]] =
477                    weights[c].ln() + log_mvn(data.row(i), means.row(c), &chol_covs[c]);
478            }
479            let row: Vec<f64> = (0..k).map(|c| log_resp[[i, c]]).collect();
480            let lse = logsumexp_row(&row);
481            for c in 0..k {
482                log_resp[[i, c]] = (log_resp[[i, c]] - lse).exp();
483            }
484        }
485        Ok(log_resp)
486    }
487
488    fn m_step(
489        data: ArrayView2<f64>,
490        resp: ArrayView2<f64>,
491        k: usize,
492        d: usize,
493        reg: f64,
494    ) -> Result<(Array1<f64>, Array2<f64>, Vec<Array2<f64>>)> {
495        let n = data.shape()[0];
496
497        // Effective counts: N_k = sum_i r_{ik}
498        let nk: Vec<f64> = (0..k)
499            .map(|c| (0..n).map(|i| resp[[i, c]]).sum::<f64>().max(1e-10))
500            .collect();
501
502        let total_n: f64 = nk.iter().sum();
503        let weights: Array1<f64> = nk.iter().map(|&nkc| nkc / total_n).collect();
504
505        // Means: mu_k = (1/N_k) sum_i r_{ik} x_i
506        let mut means = Array2::<f64>::zeros((k, d));
507        for c in 0..k {
508            for i in 0..n {
509                for f in 0..d {
510                    means[[c, f]] += resp[[i, c]] * data[[i, f]];
511                }
512            }
513            for f in 0..d {
514                means[[c, f]] /= nk[c];
515            }
516        }
517
518        // Covariances: Sigma_k = (1/N_k) sum_i r_{ik} (x_i - mu_k)(x_i - mu_k)^T + reg*I
519        let mut chol_covs = Vec::with_capacity(k);
520        for c in 0..k {
521            let mut cov = Array2::<f64>::zeros((d, d));
522            for i in 0..n {
523                for f1 in 0..d {
524                    let diff_f1 = data[[i, f1]] - means[[c, f1]];
525                    for f2 in f1..d {
526                        let diff_f2 = data[[i, f2]] - means[[c, f2]];
527                        let v = resp[[i, c]] * diff_f1 * diff_f2 / nk[c];
528                        cov[[f1, f2]] += v;
529                        if f2 != f1 {
530                            cov[[f2, f1]] += v;
531                        }
532                    }
533                }
534            }
535            for f in 0..d {
536                cov[[f, f]] += reg;
537            }
538            let l = cholesky(&cov)?;
539            chol_covs.push(l);
540        }
541
542        Ok((weights, means, chol_covs))
543    }
544
545    fn mean_log_likelihood(
546        data: ArrayView2<f64>,
547        weights: &Array1<f64>,
548        means: &Array2<f64>,
549        chol_covs: &[Array2<f64>],
550        k: usize,
551    ) -> f64 {
552        let n = data.shape()[0];
553        let mut total = 0.0;
554        for i in 0..n {
555            let mut log_terms: Vec<f64> = Vec::with_capacity(k);
556            for c in 0..k {
557                if weights[c] > 0.0 {
558                    log_terms
559                        .push(weights[c].ln() + log_mvn(data.row(i), means.row(c), &chol_covs[c]));
560                }
561            }
562            total += logsumexp_row(&log_terms);
563        }
564        total / n as f64
565    }
566}
567
568// ═══════════════════════════════════════════════════════════════════════════
569// Dirichlet Process Mixture Model
570// ═══════════════════════════════════════════════════════════════════════════
571
572/// Fitted state returned by [`DirichletProcessMixtureModel::fit`].
573#[derive(Debug, Clone)]
574pub struct DpmmResult {
575    /// Variational mean stick-breaking weights (truncated at `T` components).
576    pub stick_weights: Array1<f64>,
577    /// Posterior mean of component means, shape `(T, d)`.
578    pub means: Array2<f64>,
579    /// Active component mask: `active[t]` is `true` if component `t` has
580    /// meaningful responsibility in the fitted data.
581    pub active: Vec<bool>,
582    /// ELBO (variational lower bound) at convergence.
583    pub elbo: f64,
584    /// Number of variational EM iterations.
585    pub n_iter: usize,
586    /// Whether the variational EM converged.
587    pub converged: bool,
588    // Cholesky factors for prediction (diagonal Gaussian per component)
589    chol_covs: Vec<Array2<f64>>,
590    n_active: usize,
591}
592
593impl DpmmResult {
594    /// Number of truncation components.
595    pub fn n_components(&self) -> usize {
596        self.stick_weights.len()
597    }
598
599    /// Number of components with non-negligible weight.
600    pub fn n_active_components(&self) -> usize {
601        self.n_active
602    }
603
604    /// Soft assignments: `(n_samples, T)` responsibility matrix.
605    pub fn predict_proba(&self, data: ArrayView2<f64>) -> Result<Array2<f64>> {
606        let n = data.shape()[0];
607        let t = self.n_components();
608        let mut log_resp = Array2::<f64>::zeros((n, t));
609        for i in 0..n {
610            for c in 0..t {
611                let w = self.stick_weights[c];
612                if w <= 0.0 || !self.active[c] {
613                    log_resp[[i, c]] = f64::NEG_INFINITY;
614                    continue;
615                }
616                log_resp[[i, c]] =
617                    w.ln() + log_mvn(data.row(i), self.means.row(c), &self.chol_covs[c]);
618            }
619            let row: Vec<f64> = (0..t).map(|c| log_resp[[i, c]]).collect();
620            let lse = logsumexp_row(&row);
621            for c in 0..t {
622                log_resp[[i, c]] = (log_resp[[i, c]] - lse).exp();
623            }
624        }
625        Ok(log_resp)
626    }
627
628    /// Hard cluster assignments (ignoring inactive components).
629    pub fn predict(&self, data: ArrayView2<f64>) -> Result<Array1<usize>> {
630        let proba = self.predict_proba(data)?;
631        let n = proba.shape()[0];
632        let t = proba.shape()[1];
633        let mut labels = Array1::<usize>::zeros(n);
634        for i in 0..n {
635            let mut best = 0;
636            let mut best_p = proba[[i, 0]];
637            for c in 1..t {
638                if proba[[i, c]] > best_p {
639                    best_p = proba[[i, c]];
640                    best = c;
641                }
642            }
643            labels[i] = best;
644        }
645        Ok(labels)
646    }
647}
648
/// Dirichlet Process Mixture Model via truncated stick-breaking variational inference.
///
/// The model assumes a DP with concentration `alpha`, truncated at
/// `T = truncation` components. Each component is a Gaussian with a
/// *diagonal* precision matrix (one precision per feature), under a
/// Gaussian–Wishart prior restricted to the diagonal. Mean-field variational
/// inference alternates between updating:
///
/// 1. Component responsibilities (E-like step).
/// 2. Posterior stick-breaking parameters (M-like step for the DP prior).
/// 3. Posterior component parameters (M-like step for the Gaussian likelihood).
///
/// The number of *effective* components is inferred from the data. After
/// fitting, a component counts as active when its expected mixture weight
/// exceeds `activity_threshold / truncation`.
pub struct DirichletProcessMixtureModel {
    /// DP concentration parameter (larger => more components).
    pub alpha: f64,
    /// Truncation level `T`: the maximum number of mixture components.
    pub truncation: usize,
    /// Maximum variational EM iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the absolute ELBO change between iterations.
    pub tol: f64,
    /// Activity threshold. A component is active when its expected weight
    /// exceeds `activity_threshold / truncation`.
    pub activity_threshold: f64,
}
672
673impl DirichletProcessMixtureModel {
674    /// Create a new DPMM estimator.
675    pub fn new(alpha: f64, truncation: usize) -> Self {
676        Self {
677            alpha,
678            truncation,
679            max_iter: 200,
680            tol: 1e-4,
681            activity_threshold: 1e-2,
682        }
683    }
684
685    /// Fit the DPMM to `data`.
686    pub fn fit(&self, data: ArrayView2<f64>) -> Result<DpmmResult> {
687        let n = data.shape()[0];
688        let d = data.shape()[1];
689        let t = self.truncation;
690
691        if n == 0 || d == 0 {
692            return Err(ClusteringError::InvalidInput(
693                "Data must be non-empty".to_string(),
694            ));
695        }
696        if t < 1 {
697            return Err(ClusteringError::InvalidInput(
698                "truncation must be >= 1".to_string(),
699            ));
700        }
701
702        let reg = 1e-6_f64;
703        let alpha = self.alpha;
704
705        // ── Initialise via k-means++ ─────────────────────────────────────
706        let k_init = t.min(n);
707        let init_means = kmeans_pp_init(data, k_init, 7);
708
709        // Responsibilities: hard-assign to nearest centre
710        let mut phi = Array2::<f64>::zeros((n, t));
711        for i in 0..n {
712            let mut best_c = 0;
713            let mut best_d = f64::INFINITY;
714            for c in 0..k_init {
715                let d2: f64 = data
716                    .row(i)
717                    .iter()
718                    .zip(init_means.row(c).iter())
719                    .map(|(&a, &b)| (a - b) * (a - b))
720                    .sum();
721                if d2 < best_d {
722                    best_d = d2;
723                    best_c = c;
724                }
725            }
726            phi[[i, best_c]] = 1.0;
727        }
728
729        // Variational parameters
730        // Stick-breaking: gamma_k = (a_k, b_k) for Beta(a_k, b_k)
731        let mut a_gamma = Array1::<f64>::from_elem(t, 1.0);
732        let mut b_gamma = Array1::<f64>::from_elem(t, alpha);
733
734        // Gaussian posteriors: (m_k, beta_k, nu_k, W_k_diag) — use diagonal
735        let mut m = Array2::<f64>::zeros((t, d)); // posterior mean
736        let mut beta_k = Array1::<f64>::from_elem(t, 1.0); // precision scale
737        let mut nu_k = Array1::<f64>::from_elem(t, d as f64 + 1.0); // dof
738        let mut w_k = Array2::<f64>::from_elem((t, d), 1.0); // diagonal Wishart
739
740        // Copy init_means
741        for c in 0..k_init {
742            for f in 0..d {
743                m[[c, f]] = init_means[[c, f]];
744            }
745        }
746
747        let mut prev_elbo = f64::NEG_INFINITY;
748        let mut n_iter = 0;
749        let mut converged = false;
750
751        for iter in 0..self.max_iter {
752            n_iter = iter + 1;
753
754            // ── E-step: update phi (responsibilities) ────────────────────
755            // E[log pi_k] from stick-breaking
756            let mut e_log_pi = Array1::<f64>::zeros(t);
757            let mut cumsum_b = 0.0;
758            for k in 0..t {
759                let e_log_v_k = digamma(a_gamma[k]) - digamma(a_gamma[k] + b_gamma[k]);
760                let e_log_1mv_k = digamma(b_gamma[k]) - digamma(a_gamma[k] + b_gamma[k]);
761                e_log_pi[k] = e_log_v_k + cumsum_b;
762                cumsum_b += e_log_1mv_k;
763            }
764
765            // E[log |Lambda_k|] = sum_f (digamma((nu_k + 1 - f) / 2) + ln(2 * W_kf))
766            // For diagonal Wishart this simplifies to:
767            let e_log_lam: Vec<f64> = (0..t)
768                .map(|k| {
769                    (0..d)
770                        .map(|f| {
771                            let dof_f = (nu_k[k] + 1.0 - f as f64) / 2.0;
772                            digamma(dof_f.max(0.5)) + (2.0 * w_k[[k, f]]).ln()
773                        })
774                        .sum::<f64>()
775                })
776                .collect();
777
778            for i in 0..n {
779                let mut log_rho = Vec::with_capacity(t);
780                for k in 0..t {
781                    // E[||x_i - mu_k||^2 * Lambda_k] under diagonal Wishart-Gaussian
782                    let trace_term: f64 = (0..d)
783                        .map(|f| {
784                            nu_k[k] * w_k[[k, f]] * (data[[i, f]] - m[[k, f]]).powi(2)
785                                + 1.0 / beta_k[k]
786                        })
787                        .sum();
788                    log_rho.push(
789                        e_log_pi[k] + 0.5 * e_log_lam[k]
790                            - 0.5 * d as f64 * (2.0 * PI).ln()
791                            - 0.5 * trace_term,
792                    );
793                }
794                let lse = logsumexp_row(&log_rho);
795                for k in 0..t {
796                    phi[[i, k]] = (log_rho[k] - lse).exp();
797                }
798            }
799
800            // ── M-step: update stick-breaking parameters ─────────────────
801            let nk: Vec<f64> = (0..t)
802                .map(|k| (0..n).map(|i| phi[[i, k]]).sum::<f64>().max(1e-10))
803                .collect();
804
805            for k in 0..t {
806                let sum_after: f64 = nk[(k + 1)..].iter().sum();
807                a_gamma[k] = 1.0 + nk[k];
808                b_gamma[k] = alpha + sum_after;
809            }
810
811            // ── M-step: update Gaussian posteriors ───────────────────────
812            for k in 0..t {
813                let beta_0 = 1.0;
814                let nu_0 = d as f64 + 1.0;
815
816                // Update beta and m
817                beta_k[k] = beta_0 + nk[k];
818                let mut x_bar = vec![0.0_f64; d];
819                for i in 0..n {
820                    for f in 0..d {
821                        x_bar[f] += phi[[i, k]] * data[[i, f]];
822                    }
823                }
824                for f in 0..d {
825                    x_bar[f] /= nk[k];
826                    m[[k, f]] = (beta_0 * 0.0 + nk[k] * x_bar[f]) / beta_k[k];
827                }
828
829                // Update nu
830                nu_k[k] = nu_0 + nk[k];
831
832                // Update W (diagonal precision matrix)
833                for f in 0..d {
834                    let mut scatter = 0.0;
835                    for i in 0..n {
836                        scatter += phi[[i, k]] * (data[[i, f]] - x_bar[f]).powi(2);
837                    }
838                    let bc_correction = beta_0 * nk[k] / beta_k[k] * x_bar[f].powi(2);
839                    w_k[[k, f]] = 1.0 / (1.0 / (1.0 + reg) + scatter + bc_correction);
840                }
841            }
842
843            // ── Compute ELBO (simplified) ─────────────────────────────────
844            let elbo = Self::compute_elbo(
845                data, &phi, &a_gamma, &b_gamma, &m, &beta_k, &nu_k, &w_k, alpha, n, d, t,
846            );
847
848            if (elbo - prev_elbo).abs() < self.tol {
849                converged = true;
850                prev_elbo = elbo;
851                break;
852            }
853            prev_elbo = elbo;
854        }
855
856        // ── Convert variational parameters to summary ────────────────────
857        // Compute expected stick weights via E[V_k]
858        let mut expected_weights = Array1::<f64>::zeros(t);
859        let mut log_remaining: f64 = 0.0;
860        for k in 0..t {
861            let e_v_k = a_gamma[k] / (a_gamma[k] + b_gamma[k]);
862            expected_weights[k] = e_v_k * log_remaining.exp();
863            log_remaining += (1.0 - e_v_k).ln();
864        }
865
866        let active: Vec<bool> = (0..t)
867            .map(|k| expected_weights[k] > self.activity_threshold / t as f64)
868            .collect();
869        let n_active = active.iter().filter(|&&a| a).count();
870
871        // Build diagonal Cholesky factors for prediction
872        let mut chol_covs = Vec::with_capacity(t);
873        for k in 0..t {
874            // Posterior predictive covariance ≈ diag(1 / (nu_k * w_k))
875            let mut cov = Array2::<f64>::zeros((d, d));
876            for f in 0..d {
877                let var = (1.0 / (nu_k[k] * w_k[[k, f]])).max(reg);
878                cov[[f, f]] = var.sqrt(); // store Cholesky (diagonal)
879            }
880            chol_covs.push(cov);
881        }
882
883        let final_means = m.clone();
884
885        Ok(DpmmResult {
886            stick_weights: expected_weights,
887            means: final_means,
888            active,
889            elbo: prev_elbo,
890            n_iter,
891            converged,
892            chol_covs,
893            n_active,
894        })
895    }
896
897    /// Simplified ELBO estimate for convergence monitoring.
898    #[allow(clippy::too_many_arguments)]
899    fn compute_elbo(
900        data: ArrayView2<f64>,
901        phi: &Array2<f64>,
902        a_gamma: &Array1<f64>,
903        b_gamma: &Array1<f64>,
904        m: &Array2<f64>,
905        beta_k: &Array1<f64>,
906        nu_k: &Array1<f64>,
907        w_k: &Array2<f64>,
908        alpha: f64,
909        n: usize,
910        d: usize,
911        t: usize,
912    ) -> f64 {
913        // E[log p(X | Z, params)] approximated via responsibilities
914        let mut ll = 0.0;
915        for i in 0..n {
916            for k in 0..t {
917                if phi[[i, k]] < 1e-15 {
918                    continue;
919                }
920                let log_norm = -(d as f64) / 2.0 * (2.0 * PI).ln();
921                let neg_quad: f64 = -(0..d)
922                    .map(|f| nu_k[k] * w_k[[k, f]] * (data[[i, f]] - m[[k, f]]).powi(2))
923                    .sum::<f64>()
924                    / 2.0;
925                let e_log_lam: f64 = (0..d)
926                    .map(|f| {
927                        let dof_f = (nu_k[k] + 1.0 - f as f64) / 2.0;
928                        digamma(dof_f.max(0.5)) + (2.0 * w_k[[k, f]]).ln()
929                    })
930                    .sum::<f64>()
931                    / 2.0;
932                ll += phi[[i, k]] * (log_norm + e_log_lam + neg_quad);
933            }
934        }
935
936        // E[log p(Z | pi)] - E[log q(Z)]
937        let mut z_term = 0.0;
938        for i in 0..n {
939            for k in 0..t {
940                let phi_ik = phi[[i, k]];
941                if phi_ik > 1e-15 {
942                    z_term -= phi_ik * phi_ik.ln(); // entropy
943                }
944            }
945        }
946
947        // DP prior contribution (simplified)
948        let dp_term: f64 = (0..t)
949            .map(|k| (alpha - 1.0) * (digamma(b_gamma[k]) - digamma(a_gamma[k] + b_gamma[k])))
950            .sum();
951
952        // Beta variational entropy
953        let beta_entropy: f64 = (0..t)
954            .map(|k| {
955                let ab = a_gamma[k] + b_gamma[k];
956                let ent = (beta_k[k]).ln() - (a_gamma[k] - 1.0) * digamma(a_gamma[k]) + (ab).ln()
957                    - (b_gamma[k] - 1.0) * digamma(b_gamma[k])
958                    + digamma(ab);
959                ent
960            })
961            .sum();
962
963        ll + z_term + dp_term + beta_entropy
964    }
965}
966
967// ─── Tests ───────────────────────────────────────────────────────────────────
968
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Twelve 2-D points: rows 0..6 cluster around (1, 1), rows 6..12 around (5, 5).
    fn two_cluster_data() -> Array2<f64> {
        Array2::from_shape_vec(
            (12, 2),
            vec![
                1.0, 1.0, 1.1, 0.9, 0.9, 1.1, 1.0, 1.0, 0.8, 1.2, 1.2, 0.8, 5.0, 5.0, 5.1, 4.9,
                4.9, 5.1, 5.0, 5.0, 4.8, 5.2, 5.2, 4.8,
            ],
        )
        .expect("data")
    }

    #[test]
    fn test_gmm_fit_basic() {
        let data = two_cluster_data();
        let params = GaussianMixtureModel::fit(data.view(), 2, 100, 1e-4).expect("gmm fit");
        assert_eq!(params.n_components(), 2);
        assert_eq!(params.n_features(), 2);
        assert!(params.converged || params.n_iter > 0);
    }

    #[test]
    fn test_gmm_predict_proba() {
        let data = two_cluster_data();
        let params = GaussianMixtureModel::fit(data.view(), 2, 100, 1e-4).expect("gmm fit");
        let proba = params.predict_proba(data.view()).expect("predict_proba");
        assert_eq!(proba.shape(), [12, 2]);
        // Each row should sum to 1
        for i in 0..12 {
            let row_sum: f64 = (0..2).map(|c| proba[[i, c]]).sum();
            assert!((row_sum - 1.0).abs() < 1e-6, "row {i} sums to {row_sum}");
        }
    }

    #[test]
    fn test_gmm_predict_hard() {
        let data = two_cluster_data();
        let params = GaussianMixtureModel::fit(data.view(), 2, 100, 1e-4).expect("gmm fit");
        let labels = params.predict(data.view()).expect("predict");
        assert_eq!(labels.len(), 12);
        // The two well-separated groups must each share one label, and the
        // labels of the two groups must differ.
        let first = labels[0];
        let second = labels[6];
        assert!(labels.iter().take(6).all(|&l| l == first), "labels {labels:?}");
        assert!(labels.iter().skip(6).all(|&l| l == second), "labels {labels:?}");
        assert_ne!(first, second, "labels {labels:?}");
    }

    #[test]
    fn test_gmm_score_finite() {
        let data = two_cluster_data();
        let params = GaussianMixtureModel::fit(data.view(), 2, 100, 1e-4).expect("gmm fit");
        let score = params.score(data.view()).expect("score");
        assert!(score.is_finite(), "score must be finite, got {score}");
    }

    #[test]
    fn test_gmm_bic_aic() {
        let data = two_cluster_data();
        let params = GaussianMixtureModel::fit(data.view(), 2, 100, 1e-4).expect("gmm fit");
        let bic = params.bic(data.view()).expect("bic");
        let aic = params.aic(data.view()).expect("aic");
        assert!(bic.is_finite());
        assert!(aic.is_finite());
        // Under the textbook definitions BIC >= AIC once ln(n) >= 2 (n >= e^2 ≈ 7.4);
        // only finiteness is checked here.
    }

    #[test]
    fn test_gmm_k1_trivial() {
        let data = two_cluster_data();
        let params = GaussianMixtureModel::fit(data.view(), 1, 50, 1e-4).expect("gmm k=1");
        let labels = params.predict(data.view()).expect("predict k=1");
        // All labels should be 0
        assert!(labels.iter().all(|&l| l == 0));
    }

    #[test]
    fn test_gmm_invalid_k() {
        let data = two_cluster_data();
        let result = GaussianMixtureModel::fit(data.view(), 0, 50, 1e-4);
        assert!(result.is_err());
    }

    #[test]
    fn test_dpmm_fit_basic() {
        let data = two_cluster_data();
        let model = DirichletProcessMixtureModel::new(1.0, 6);
        let result = model.fit(data.view()).expect("dpmm fit");
        assert_eq!(result.n_components(), 6);
        assert!(result.n_iter > 0);
        // The convergence monitor must produce a usable number.
        assert!(result.elbo.is_finite(), "elbo must be finite, got {}", result.elbo);
        // At least one active component
        assert!(result.n_active_components() >= 1);
    }

    #[test]
    fn test_dpmm_predict_proba() {
        let data = two_cluster_data();
        let model = DirichletProcessMixtureModel::new(1.0, 4);
        let result = model.fit(data.view()).expect("dpmm fit");
        let proba = result.predict_proba(data.view()).expect("proba");
        assert_eq!(proba.shape()[0], 12);
        assert_eq!(proba.shape()[1], 4);
        for i in 0..12 {
            let row_sum: f64 = (0..4).map(|c| proba[[i, c]]).sum();
            assert!((row_sum - 1.0).abs() < 1e-5, "row {i} sum {row_sum}");
        }
    }

    #[test]
    fn test_dpmm_predict_hard() {
        let data = two_cluster_data();
        let model = DirichletProcessMixtureModel::new(1.0, 4);
        let result = model.fit(data.view()).expect("dpmm fit");
        let labels = result.predict(data.view()).expect("predict");
        assert_eq!(labels.len(), 12);
    }

    #[test]
    fn test_dpmm_alpha_concentration() {
        // Higher alpha => more active components expected
        let data = two_cluster_data();
        let model_low = DirichletProcessMixtureModel::new(0.01, 8);
        let model_high = DirichletProcessMixtureModel::new(10.0, 8);
        let r_low = model_low.fit(data.view()).expect("low alpha");
        let r_high = model_high.fit(data.view()).expect("high alpha");
        assert!(r_high.n_active_components() >= r_low.n_active_components());
    }
}