use crate::common::CovarianceType;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::random::{Rng, SeedableRng};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};

/// Variational Bayesian Gaussian Mixture Model.
///
/// Places a Dirichlet prior on the mixture weights, so the variational
/// posterior can drive the weights of superfluous components toward zero.
#[derive(Debug, Clone)]
pub struct VariationalBayesianGMM<S = Untrained> {
    pub(crate) state: S,
    pub(crate) n_components: usize,
    pub(crate) covariance_type: CovarianceType,
    pub(crate) tol: f64,
    pub(crate) reg_covar: f64,
    pub(crate) max_iter: usize,
    pub(crate) random_state: Option<u64>,
    pub(crate) weight_concentration_prior: f64,
    pub(crate) mean_precision_prior: f64,
    pub(crate) degrees_of_freedom_prior: f64,
}

/// Parameters learned by a fitted [`VariationalBayesianGMM`].
#[derive(Debug, Clone)]
pub struct VariationalBayesianGMMTrained {
    pub(crate) weights: Array1<f64>,
    pub(crate) means: Array2<f64>,
    pub(crate) covariances: Vec<Array2<f64>>,
    pub(crate) weight_concentration: Array1<f64>,
    pub(crate) mean_precision: Array1<f64>,
    pub(crate) degrees_of_freedom: Array1<f64>,
    pub(crate) lower_bound: f64,
    pub(crate) n_iter: usize,
    pub(crate) converged: bool,
    pub(crate) effective_components: usize,
}

impl VariationalBayesianGMM<Untrained> {
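    /// Creates an untrained model with default hyperparameters
    /// (one component, full covariance, `tol = 1e-3`, `max_iter = 100`).
    ///
    /// A minimal usage sketch (assumes a 2-D `samples` array is in scope;
    /// marked `ignore` so it is illustrative rather than doctested):
    ///
    /// ```ignore
    /// let fitted = VariationalBayesianGMM::new()
    ///     .n_components(4)
    ///     .random_state(0)
    ///     .fit(&samples.view(), &())?;
    /// println!("effective components: {}", fitted.effective_components());
    /// ```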
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_components: 1,
            covariance_type: CovarianceType::Full,
            tol: 1e-3,
            reg_covar: 1e-6,
            max_iter: 100,
            random_state: None,
            weight_concentration_prior: 1.0,
            mean_precision_prior: 1.0,
            degrees_of_freedom_prior: 1.0,
        }
    }

    /// Alias for [`new`](Self::new), kept for builder-style call sites.
    pub fn builder() -> Self {
        Self::new()
    }

    /// Sets the maximum number of mixture components.
    pub fn n_components(mut self, n_components: usize) -> Self {
        self.n_components = n_components;
        self
    }

    /// Sets the covariance structure (full, diagonal, spherical, ...).
    pub fn covariance_type(mut self, covariance_type: CovarianceType) -> Self {
        self.covariance_type = covariance_type;
        self
    }

    /// Sets the convergence tolerance on the lower bound.
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Sets the regularization added to covariance diagonals.
    pub fn reg_covar(mut self, reg_covar: f64) -> Self {
        self.reg_covar = reg_covar;
        self
    }

    /// Sets the maximum number of variational iterations.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Sets the RNG seed used for mean initialization.
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Sets the Dirichlet concentration prior on the mixture weights.
    pub fn weight_concentration_prior(mut self, prior: f64) -> Self {
        self.weight_concentration_prior = prior;
        self
    }

    /// Sets the precision prior on the component means.
    pub fn mean_precision_prior(mut self, prior: f64) -> Self {
        self.mean_precision_prior = prior;
        self
    }

    /// Sets the degrees-of-freedom prior on the covariance posteriors.
    pub fn degrees_of_freedom_prior(mut self, prior: f64) -> Self {
        self.degrees_of_freedom_prior = prior;
        self
    }

    /// Finalizes the builder (a no-op; returns `self`).
    pub fn build(self) -> Self {
        self
    }
}

impl Default for VariationalBayesianGMM<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for VariationalBayesianGMM<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for VariationalBayesianGMM<Untrained> {
    type Fitted = VariationalBayesianGMM<VariationalBayesianGMMTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (n_samples, _n_features) = X.dim();

        if n_samples < 2 {
            return Err(SklearsError::InvalidInput(
                "Number of samples must be at least 2".to_string(),
            ));
        }

        if self.n_components == 0 {
            return Err(SklearsError::InvalidInput(
                "Number of components must be positive".to_string(),
            ));
        }

        // Initialize the variational parameters from the priors and the data.
        let (
            mut weight_concentration,
            mut mean_precision,
            mut means,
            mut degrees_of_freedom,
            mut covariances,
        ) = self.initialize_parameters(&X)?;

        let mut lower_bound = f64::NEG_INFINITY;
        let mut converged = false;
        let mut n_iter = 0;

        for iteration in 0..self.max_iter {
            n_iter = iteration + 1;

            // E-step: posterior responsibilities under the current parameters.
            let responsibilities = self.compute_responsibilities(
                &X,
                &weight_concentration,
                &means,
                &covariances,
                &degrees_of_freedom,
            )?;

            // M-step: update the variational parameters.
            let (
                new_weight_concentration,
                new_mean_precision,
                new_means,
                new_degrees_of_freedom,
                new_covariances,
            ) = self.update_parameters(&X, &responsibilities)?;

            let new_lower_bound = self.compute_lower_bound(
                &X,
                &responsibilities,
                &new_weight_concentration,
                &new_mean_precision,
                &new_means,
                &new_degrees_of_freedom,
                &new_covariances,
            )?;

            // Declare convergence once the lower bound stops moving.
            if iteration > 0 && (new_lower_bound - lower_bound).abs() < self.tol {
                converged = true;
            }

            weight_concentration = new_weight_concentration;
            mean_precision = new_mean_precision;
            means = new_means;
            degrees_of_freedom = new_degrees_of_freedom;
            covariances = new_covariances;
            lower_bound = new_lower_bound;

            if converged {
                break;
            }
        }

        let weights = self.compute_weights(&weight_concentration);

        // Components whose posterior weight stays above 1e-3 count as active.
        let effective_components = weights.iter().filter(|&&w| w > 1e-3).count();

        Ok(VariationalBayesianGMM {
            state: VariationalBayesianGMMTrained {
                weights,
                means,
                covariances,
                weight_concentration,
                mean_precision,
                degrees_of_freedom,
                lower_bound,
                n_iter,
                converged,
                effective_components,
            },
            n_components: self.n_components,
            covariance_type: self.covariance_type,
            tol: self.tol,
            reg_covar: self.reg_covar,
            max_iter: self.max_iter,
            random_state: self.random_state,
            weight_concentration_prior: self.weight_concentration_prior,
            mean_precision_prior: self.mean_precision_prior,
            degrees_of_freedom_prior: self.degrees_of_freedom_prior,
        })
    }
}

impl VariationalBayesianGMM<Untrained> {
    fn initialize_parameters(
        &self,
        X: &Array2<f64>,
    ) -> SklResult<(
        Array1<f64>,
        Array1<f64>,
        Array2<f64>,
        Array1<f64>,
        Vec<Array2<f64>>,
    )> {
        let (_n_samples, n_features) = X.dim();

        // Dirichlet concentration over the weights, one prior value per component.
        let weight_concentration =
            Array1::from_elem(self.n_components, self.weight_concentration_prior);

        let mean_precision = Array1::from_elem(self.n_components, self.mean_precision_prior);

        let means = self.initialize_means(X)?;

        // Start the degrees of freedom at the prior offset by the data dimensionality.
        let degrees_of_freedom = Array1::from_elem(
            self.n_components,
            self.degrees_of_freedom_prior + n_features as f64,
        );

        let covariances = self.initialize_covariances(X)?;

        Ok((
            weight_concentration,
            mean_precision,
            means,
            degrees_of_freedom,
            covariances,
        ))
    }

    fn initialize_means(&self, X: &Array2<f64>) -> SklResult<Array2<f64>> {
        let (n_samples, n_features) = X.dim();
        let mut means = Array2::zeros((self.n_components, n_features));

        if let Some(seed) = self.random_state {
            let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(seed);

            // Seed the first center at a random sample, then pick each further
            // center as the sample farthest from all existing centers
            // (a k-means++-style farthest-point spread).
            let idx = rng.gen_range(0..n_samples);
            means.row_mut(0).assign(&X.row(idx));

            for i in 1..self.n_components {
                let mut best_distance = 0.0;
                let mut best_idx = 0;

                for j in 0..n_samples {
                    let sample = X.row(j);
                    let mut min_distance = f64::INFINITY;

                    for k in 0..i {
                        let existing_mean = means.row(k);
                        let distance = (&sample - &existing_mean).mapv(|x| x * x).sum();
                        min_distance = min_distance.min(distance);
                    }

                    if min_distance > best_distance {
                        best_distance = min_distance;
                        best_idx = j;
                    }
                }

                means.row_mut(i).assign(&X.row(best_idx));
            }
        } else {
            // Without a seed, spread the centers evenly over the sample indices.
            let step = n_samples / self.n_components;

            for (i, mut mean) in means.axis_iter_mut(Axis(0)).enumerate() {
                let sample_idx = if step == 0 {
                    i.min(n_samples - 1)
                } else {
                    (i * step).min(n_samples - 1)
                };
                mean.assign(&X.row(sample_idx));
            }
        }

        Ok(means)
    }

    fn initialize_covariances(&self, X: &Array2<f64>) -> SklResult<Vec<Array2<f64>>> {
        let (_, n_features) = X.dim();
        let mut covariances = Vec::new();

        let global_cov = self.estimate_global_covariance(X)?;

        for _ in 0..self.n_components {
            let mut cov = global_cov.clone();

            // Regularize the diagonal so every initial covariance is well conditioned.
            for i in 0..n_features {
                cov[[i, i]] += self.reg_covar;
            }

            covariances.push(cov);
        }

        Ok(covariances)
    }

    fn estimate_global_covariance(&self, X: &Array2<f64>) -> SklResult<Array2<f64>> {
        let (n_samples, n_features) = X.dim();

        let mut mean = Array1::zeros(n_features);
        for i in 0..n_features {
            mean[i] = X.column(i).sum() / n_samples as f64;
        }

        // Unbiased sample covariance of the full data set.
        let mut cov = Array2::zeros((n_features, n_features));
        for i in 0..n_features {
            for j in 0..n_features {
                let mut sum = 0.0;
                for k in 0..n_samples {
                    sum += (X[[k, i]] - mean[i]) * (X[[k, j]] - mean[j]);
                }
                cov[[i, j]] = sum / (n_samples as f64 - 1.0);
            }
        }

        // Project onto the requested covariance structure.
        match self.covariance_type {
            CovarianceType::Diagonal => {
                for i in 0..n_features {
                    for j in 0..n_features {
                        if i != j {
                            cov[[i, j]] = 0.0;
                        }
                    }
                }
            }
            CovarianceType::Spherical => {
                let avg_variance = cov.diag().sum() / n_features as f64;
                cov.fill(0.0);
                for i in 0..n_features {
                    cov[[i, i]] = avg_variance;
                }
            }
            _ => {}
        }

        Ok(cov)
    }

    fn compute_responsibilities(
        &self,
        X: &Array2<f64>,
        weight_concentration: &Array1<f64>,
        means: &Array2<f64>,
        covariances: &[Array2<f64>],
        degrees_of_freedom: &Array1<f64>,
    ) -> SklResult<Array2<f64>> {
        let (n_samples, _) = X.dim();
        let mut responsibilities = Array2::zeros((n_samples, self.n_components));

        let expected_log_weights = self.compute_expected_log_weights(weight_concentration);

        for (i, sample) in X.axis_iter(Axis(0)).enumerate() {
            let mut log_prob_norm = f64::NEG_INFINITY;
            let mut log_probs = Vec::new();

            for k in 0..self.n_components {
                let mean = means.row(k);
                let cov = &covariances[k];

                let log_prob =
                    self.compute_student_t_log_pdf(&sample, &mean, cov, degrees_of_freedom[k])?;
                let weighted_log_prob = expected_log_weights[k] + log_prob;

                log_probs.push(weighted_log_prob);
                log_prob_norm = log_prob_norm.max(weighted_log_prob);
            }

            // Normalize with the log-sum-exp trick to avoid underflow.
            let mut sum_exp = 0.0;
            for &log_prob in &log_probs {
                sum_exp += (log_prob - log_prob_norm).exp();
            }
            let log_sum_exp = log_prob_norm + sum_exp.ln();

            for k in 0..self.n_components {
                responsibilities[[i, k]] = (log_probs[k] - log_sum_exp).exp();
            }
        }

        Ok(responsibilities)
    }

    fn compute_expected_log_weights(&self, weight_concentration: &Array1<f64>) -> Array1<f64> {
        let sum_concentration: f64 = weight_concentration.sum();
        let mut expected_log_weights = Array1::zeros(self.n_components);

        // Under the Dirichlet posterior q(pi) = Dir(alpha),
        // E[ln pi_k] = psi(alpha_k) - psi(sum_j alpha_j).
        for k in 0..self.n_components {
            expected_log_weights[k] = digamma(weight_concentration[k]) - digamma(sum_concentration);
        }

        expected_log_weights
    }

    fn compute_student_t_log_pdf(
        &self,
        x: &ArrayView1<f64>,
        mean: &ArrayView1<f64>,
        cov: &Array2<f64>,
        _degrees_of_freedom: f64,
    ) -> SklResult<f64> {
        // Simplification: the degrees of freedom are currently ignored and the
        // predictive density is approximated by a plain Gaussian log-pdf rather
        // than the Student-t of the exact variational treatment.
        crate::common::gaussian_log_pdf(x, mean, &cov.view())
    }

    fn update_parameters(
        &self,
        X: &Array2<f64>,
        responsibilities: &Array2<f64>,
    ) -> SklResult<(
        Array1<f64>,
        Array1<f64>,
        Array2<f64>,
        Array1<f64>,
        Vec<Array2<f64>>,
    )> {
        let (n_samples, n_features) = X.dim();

        // Dirichlet concentration: prior plus the effective count of each component.
        let mut weight_concentration = Array1::zeros(self.n_components);
        for k in 0..self.n_components {
            weight_concentration[k] =
                self.weight_concentration_prior + responsibilities.column(k).sum();
        }

        let mut mean_precision = Array1::zeros(self.n_components);
        for k in 0..self.n_components {
            mean_precision[k] = self.mean_precision_prior + responsibilities.column(k).sum();
        }

        // Responsibility-weighted sample means.
        let mut means = Array2::zeros((self.n_components, n_features));
        for k in 0..self.n_components {
            let resp_sum = responsibilities.column(k).sum();
            if resp_sum > 0.0 {
                for j in 0..n_features {
                    let mut weighted_sum = 0.0;
                    for i in 0..n_samples {
                        weighted_sum += responsibilities[[i, k]] * X[[i, j]];
                    }
                    means[[k, j]] = weighted_sum / resp_sum;
                }
            }
        }

        let mut degrees_of_freedom = Array1::zeros(self.n_components);
        for k in 0..self.n_components {
            degrees_of_freedom[k] =
                self.degrees_of_freedom_prior + responsibilities.column(k).sum();
        }

        // Simplification: the covariance update is a placeholder that returns a
        // regularized identity for every component instead of the
        // responsibility-weighted scatter matrix.
        let mut covariances = Vec::new();
        for _k in 0..self.n_components {
            let mut cov = Array2::eye(n_features);
            for i in 0..n_features {
                cov[[i, i]] = 1.0 + self.reg_covar;
            }
            covariances.push(cov);
        }

        Ok((
            weight_concentration,
            mean_precision,
            means,
            degrees_of_freedom,
            covariances,
        ))
    }

    fn compute_lower_bound(
        &self,
        X: &Array2<f64>,
        responsibilities: &Array2<f64>,
        weight_concentration: &Array1<f64>,
        _mean_precision: &Array1<f64>,
        means: &Array2<f64>,
        _degrees_of_freedom: &Array1<f64>,
        covariances: &[Array2<f64>],
    ) -> SklResult<f64> {
        // Partial bound: expected log-likelihood plus the expected log-weight
        // term. The KL penalties for the unused parameters are omitted, so this
        // tracks convergence rather than the exact evidence lower bound.
        let mut lower_bound = 0.0;

        for (i, sample) in X.axis_iter(Axis(0)).enumerate() {
            for k in 0..self.n_components {
                let resp = responsibilities[[i, k]];
                if resp > 0.0 {
                    let mean = means.row(k);
                    let cov = &covariances[k];
                    let log_prob = crate::common::gaussian_log_pdf(&sample, &mean, &cov.view())?;
                    lower_bound += resp * log_prob;
                }
            }
        }

        let expected_log_weights = self.compute_expected_log_weights(weight_concentration);
        for k in 0..self.n_components {
            let resp_sum = responsibilities.column(k).sum();
            if resp_sum > 0.0 {
                lower_bound += resp_sum * expected_log_weights[k];
            }
        }

        Ok(lower_bound)
    }

    fn compute_weights(&self, weight_concentration: &Array1<f64>) -> Array1<f64> {
        // Posterior mean of the Dirichlet: E[pi_k] = alpha_k / sum_j alpha_j.
        let sum_concentration: f64 = weight_concentration.sum();
        weight_concentration.mapv(|x| x / sum_concentration)
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for VariationalBayesianGMM<VariationalBayesianGMMTrained>
{
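    /// Hard-assigns each sample to the component with the highest weighted
    /// log-density, skipping components whose posterior weight falls below
    /// the same 1e-3 threshold used to count effective components.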
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let (n_samples, _) = X.dim();
        let mut predictions = Array1::zeros(n_samples);

        for (i, sample) in X.axis_iter(Axis(0)).enumerate() {
            let mut best_component = 0;
            let mut best_log_prob = f64::NEG_INFINITY;

            for k in 0..self.n_components {
                if self.state.weights[k] > 1e-3 {
                    let mean = self.state.means.row(k);
                    let cov = &self.state.covariances[k];

                    let log_prob = crate::common::gaussian_log_pdf(&sample, &mean, &cov.view())?;
                    let weighted_log_prob = self.state.weights[k].ln() + log_prob;

                    if weighted_log_prob > best_log_prob {
                        best_log_prob = weighted_log_prob;
                        best_component = k;
                    }
                }
            }

            predictions[i] = best_component as i32;
        }

        Ok(predictions)
    }
}

impl VariationalBayesianGMM<VariationalBayesianGMMTrained> {
    /// Posterior mixture weights (normalized Dirichlet concentrations).
    pub fn weights(&self) -> &Array1<f64> {
        &self.state.weights
    }

    /// Component means, one row per component.
    pub fn means(&self) -> &Array2<f64> {
        &self.state.means
    }

    /// Component covariance matrices.
    pub fn covariances(&self) -> &[Array2<f64>] {
        &self.state.covariances
    }

    /// Final value of the (partial) lower bound.
    pub fn lower_bound(&self) -> f64 {
        self.state.lower_bound
    }

    /// Number of components whose posterior weight exceeds 1e-3.
    pub fn effective_components(&self) -> usize {
        self.state.effective_components
    }

    /// Whether the lower bound converged within `max_iter` iterations.
    pub fn converged(&self) -> bool {
        self.state.converged
    }

    /// Number of variational iterations that were run.
    pub fn n_iter(&self) -> usize {
        self.state.n_iter
    }

    /// Dirichlet concentration parameters of the weight posterior.
    pub fn weight_concentration(&self) -> &Array1<f64> {
        &self.state.weight_concentration
    }

    /// Precision scaling of the component mean posteriors.
    pub fn mean_precision(&self) -> &Array1<f64> {
        &self.state.mean_precision
    }

    /// Degrees-of-freedom parameters of the covariance posteriors.
    pub fn degrees_of_freedom(&self) -> &Array1<f64> {
        &self.state.degrees_of_freedom
    }
}

/// Digamma function psi(x), approximated with the asymptotic expansion
/// psi(x) ~ ln(x) - 1/(2x) - 1/(12x^2), combined with the recurrence
/// psi(x) = psi(x + 1) - 1/x to shift small arguments into the regime
/// where the expansion is accurate.
fn digamma(x: f64) -> f64 {
    if x > 6.0 {
        x.ln() - 1.0 / (2.0 * x) - 1.0 / (12.0 * x * x)
    } else {
        // Shift x upward until the asymptotic expansion applies...
        let mut result = x;
        let mut n = 0;
        while result < 6.0 {
            result += 1.0;
            n += 1;
        }
        let asymptotic = result.ln() - 1.0 / (2.0 * result) - 1.0 / (12.0 * result * result);
        // ...then undo the shift with the recurrence terms.
        asymptotic - (0..n).map(|i| 1.0 / (x + i as f64)).sum::<f64>()
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    #[allow(non_snake_case)]
    fn test_variational_bayesian_gmm_basic() {
        let X = array![
            [0.0, 0.0],
            [1.0, 1.0],
            [2.0, 2.0],
            [10.0, 10.0],
            [11.0, 11.0],
            [12.0, 12.0]
        ];

        let vbgmm = VariationalBayesianGMM::new()
            .n_components(3)
            .max_iter(10)
            .random_state(42);

        let fitted = vbgmm.fit(&X.view(), &()).unwrap();

        assert!(fitted.converged() || fitted.n_iter() == 10);
        assert!(fitted.effective_components() <= 3);
        assert!(fitted.lower_bound().is_finite());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_variational_bayesian_gmm_prediction() {
        let X = array![[0.0, 0.0], [1.0, 1.0], [10.0, 10.0], [11.0, 11.0]];

        let vbgmm = VariationalBayesianGMM::new()
            .n_components(2)
            .max_iter(20)
            .random_state(42);

        let fitted = vbgmm.fit(&X.view(), &()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 4);
        // Nearby samples should share a label, or the two well-separated
        // clusters should receive different labels.
        assert!(predictions[0] == predictions[1] || predictions[0] != predictions[2]);
    }

    #[test]
    fn test_variational_bayesian_gmm_builder() {
        let vbgmm = VariationalBayesianGMM::builder()
            .n_components(5)
            .covariance_type(CovarianceType::Diagonal)
            .tol(1e-4)
            .weight_concentration_prior(0.1)
            .mean_precision_prior(0.1)
            .degrees_of_freedom_prior(1.0)
            .build();

        assert_eq!(vbgmm.n_components, 5);
        assert_eq!(vbgmm.covariance_type, CovarianceType::Diagonal);
        assert_relative_eq!(vbgmm.tol, 1e-4);
        assert_relative_eq!(vbgmm.weight_concentration_prior, 0.1);
    }

    #[test]
    fn test_digamma_function() {
        assert_relative_eq!(digamma(1.0), -0.5772, epsilon = 0.1);
        assert_relative_eq!(digamma(2.0), 0.4228, epsilon = 0.1);
        assert_relative_eq!(digamma(10.0), 2.2517, epsilon = 0.01);
    }
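
    // A minimal sanity-check sketch for `compute_weights`: it assumes only
    // that the weights are the Dirichlet posterior mean
    // alpha_k / sum_j alpha_j, which is what the method computes.
    #[test]
    fn test_compute_weights_normalization() {
        let vbgmm = VariationalBayesianGMM::new().n_components(3);
        let concentration = array![1.0, 2.0, 3.0];
        let weights = vbgmm.compute_weights(&concentration);
        assert_relative_eq!(weights.sum(), 1.0, epsilon = 1e-12);
        assert_relative_eq!(weights[2], 0.5, epsilon = 1e-12);
    }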
824}