use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::prelude::*;
use scirs2_core::random::Distribution;
use sklears_core::error::Result;
use sklears_core::types::{Float, Int};
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq)]
pub enum AdvancedBayesianStrategy {
    EmpiricalBayes,
    Hierarchical,
    VariationalBayes,
    MCMCSampling,
    ConjugatePrior,
}

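/// Empirical Bayes estimator for class-prior hyperparameters.
///
/// Fits a Dirichlet concentration to the observed class frequencies by
/// moment matching, iterating until the hyperparameters move by less than
/// `tolerance`. A minimal usage sketch (illustrative only, hence `ignore`):
///
/// ```ignore
/// let y = array![0, 0, 1, 2];
/// let mut eb = EmpiricalBayesEstimator::new().with_max_iter(50);
/// eb.fit_classification(&y)?;
/// let alphas = eb.hyperparameters().unwrap();
/// ```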
#[derive(Debug, Clone)]
pub struct EmpiricalBayesEstimator {
    pub max_iter: usize,
    pub tolerance: Float,
    pub random_state: Option<u64>,
    pub hyperparameters_: Option<Array1<Float>>,
    pub log_likelihood_: Option<Vec<Float>>,
}

impl EmpiricalBayesEstimator {
    pub fn new() -> Self {
        Self {
            max_iter: 100,
            tolerance: 1e-6,
            random_state: None,
            hyperparameters_: None,
            log_likelihood_: None,
        }
    }

    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    pub fn with_tolerance(mut self, tolerance: Float) -> Self {
        self.tolerance = tolerance;
        self
    }

    pub fn with_random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    pub fn fit_classification(&mut self, y: &Array1<Int>) -> Result<()> {
        let mut class_counts: HashMap<Int, usize> = HashMap::new();
        for &label in y.iter() {
            *class_counts.entry(label).or_insert(0) += 1;
        }

        let mut classes: Vec<Int> = class_counts.keys().copied().collect();
        classes.sort();
        let n_classes = classes.len();
        let n_samples = y.len() as Float;

        // Class counts are fixed for a given `y`, so compute them once.
        let mut expected_counts = Array1::<Float>::zeros(n_classes);
        for &label in y.iter() {
            let class_idx = classes.iter().position(|&c| c == label).unwrap();
            expected_counts[class_idx] += 1.0;
        }

        let mut hyperparams = Array1::ones(n_classes);
        let mut log_likelihoods = Vec::new();

        for _iter in 0..self.max_iter {
            let old_hyperparams = hyperparams.clone();

            let observed_props: Array1<Float> = expected_counts.mapv(|x| x / n_samples);
            let mean_prop = observed_props.mean().unwrap();
            let variance_sum: Float = observed_props.mapv(|p| (p - mean_prop).powi(2)).sum();

            // Moment matching: estimate the Dirichlet concentration from the
            // mean and variance of the observed class proportions.
            let concentration = if variance_sum > 0.0 {
                let variance_mean = variance_sum / (n_classes as Float);
                let alpha_sum = mean_prop * (1.0 - mean_prop) / variance_mean - 1.0;
                alpha_sum.max(0.1)
            } else {
                1.0
            };

            hyperparams = observed_props.mapv(|p| p * concentration);

            let log_likelihood = self.compute_log_likelihood(&hyperparams, &expected_counts);
            log_likelihoods.push(log_likelihood);

            let param_diff: Float = (&hyperparams - &old_hyperparams).mapv(|x| x.abs()).sum();
            if param_diff < self.tolerance {
                break;
            }
        }

        self.hyperparameters_ = Some(hyperparams);
        self.log_likelihood_ = Some(log_likelihoods);
        Ok(())
    }

    fn compute_log_likelihood(&self, hyperparams: &Array1<Float>, counts: &Array1<Float>) -> Float {
        let alpha_sum = hyperparams.sum();
        let count_sum = counts.sum();

        // Dirichlet-multinomial marginal log-likelihood, up to the multinomial
        // coefficient (constant in the hyperparameters).
        let mut log_likelihood = 0.0;
        for (&alpha, &count) in hyperparams.iter().zip(counts.iter()) {
            log_likelihood += gamma_ln(alpha + count) - gamma_ln(alpha);
        }
        log_likelihood += gamma_ln(alpha_sum) - gamma_ln(alpha_sum + count_sum);
        log_likelihood
    }

    pub fn hyperparameters(&self) -> Option<&Array1<Float>> {
        self.hyperparameters_.as_ref()
    }

    pub fn log_likelihood_evolution(&self) -> Option<&Vec<Float>> {
        self.log_likelihood_.as_ref()
    }
}

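/// Hierarchical Bayes estimator that pools information across groups.
///
/// Each group gets its own class-probability estimate, while the global
/// class frequencies act as shared hyperparameters. Usage sketch
/// (illustrative only):
///
/// ```ignore
/// let y = array![0, 0, 1, 1];
/// let groups = array![1, 1, 2, 2];
/// let mut hb = HierarchicalBayesEstimator::new().with_groups(groups);
/// hb.fit_classification(&y)?;
/// let global = hb.global_hyperparameters().unwrap();
/// ```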
#[derive(Debug, Clone)]
pub struct HierarchicalBayesEstimator {
    pub groups: Option<Array1<Int>>,
    pub global_hyperparams_: Option<Array1<Float>>,
    pub group_params_: Option<HashMap<Int, Array1<Float>>>,
    pub random_state: Option<u64>,
}

impl HierarchicalBayesEstimator {
    pub fn new() -> Self {
        Self {
            groups: None,
            global_hyperparams_: None,
            group_params_: None,
            random_state: None,
        }
    }

    pub fn with_groups(mut self, groups: Array1<Int>) -> Self {
        self.groups = Some(groups);
        self
    }

    pub fn with_random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    pub fn fit_classification(&mut self, y: &Array1<Int>) -> Result<()> {
        let groups = self.groups.as_ref().ok_or_else(|| {
            sklears_core::error::SklearsError::InvalidInput(
                "Group assignments must be provided".to_string(),
            )
        })?;

        if groups.len() != y.len() {
            return Err(sklears_core::error::SklearsError::InvalidInput(
                "Groups and labels must have the same length".to_string(),
            ));
        }

        let mut class_counts: HashMap<Int, usize> = HashMap::new();
        for &label in y.iter() {
            *class_counts.entry(label).or_insert(0) += 1;
        }
        let mut classes: Vec<Int> = class_counts.keys().copied().collect();
        classes.sort();
        let n_classes = classes.len();

        let mut unique_groups: Vec<Int> = groups.iter().copied().collect();
        unique_groups.sort();
        unique_groups.dedup();

        let mut group_params = HashMap::new();
        let mut global_counts = Array1::<Float>::zeros(n_classes);

        for &group in &unique_groups {
            let mut group_class_counts = Array1::<Float>::zeros(n_classes);
            let mut group_total = 0;

            for (&label, &group_id) in y.iter().zip(groups.iter()) {
                if group_id == group {
                    let class_idx = classes.iter().position(|&c| c == label).unwrap();
                    group_class_counts[class_idx] += 1.0;
                    global_counts[class_idx] += 1.0;
                    group_total += 1;
                }
            }

            if group_total > 0 {
                let group_probs = group_class_counts.mapv(|x| x / (group_total as Float));
                group_params.insert(group, group_probs);
            }
        }

        let global_total = global_counts.sum();
        let global_hyperparams = if global_total > 0.0 {
            global_counts.mapv(|x| x / global_total)
        } else {
            Array1::ones(n_classes) / (n_classes as Float)
        };

        self.global_hyperparams_ = Some(global_hyperparams);
        self.group_params_ = Some(group_params);
        Ok(())
    }

    pub fn global_hyperparameters(&self) -> Option<&Array1<Float>> {
        self.global_hyperparams_.as_ref()
    }

    pub fn group_parameters(&self) -> Option<&HashMap<Int, Array1<Float>>> {
        self.group_params_.as_ref()
    }
}

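/// Variational Bayes estimator with a Dirichlet mean-field posterior.
///
/// The conjugate update sets each variational parameter to `count + 1.0`
/// (a uniform Dirichlet(1) prior), so the coordinate ascent converges after
/// a single update; the loop and ELBO trace are kept for API uniformity.
/// Usage sketch (illustrative only):
///
/// ```ignore
/// let y = array![0, 0, 1, 2];
/// let mut vb = VariationalBayesEstimator::new().with_tolerance(1e-4);
/// vb.fit_classification(&y)?;
/// let q = vb.variational_parameters().unwrap(); // normalized to sum to 1
/// ```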
#[derive(Debug, Clone)]
pub struct VariationalBayesEstimator {
    pub max_iter: usize,
    pub tolerance: Float,
    pub variational_params_: Option<Array1<Float>>,
    pub elbo_: Option<Vec<Float>>,
    pub random_state: Option<u64>,
}

impl VariationalBayesEstimator {
    pub fn new() -> Self {
        Self {
            max_iter: 100,
            tolerance: 1e-6,
            variational_params_: None,
            elbo_: None,
            random_state: None,
        }
    }

    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    pub fn with_tolerance(mut self, tolerance: Float) -> Self {
        self.tolerance = tolerance;
        self
    }

    pub fn fit_classification(&mut self, y: &Array1<Int>) -> Result<()> {
        let mut class_counts: HashMap<Int, usize> = HashMap::new();
        for &label in y.iter() {
            *class_counts.entry(label).or_insert(0) += 1;
        }

        let mut classes: Vec<Int> = class_counts.keys().copied().collect();
        classes.sort();
        let n_classes = classes.len();

        let mut q_params = Array1::ones(n_classes);
        let mut elbo_values = Vec::new();

        for _iter in 0..self.max_iter {
            let old_params = q_params.clone();

            // Conjugate coordinate-ascent update: posterior Dirichlet
            // parameter = observed count + uniform prior pseudo-count of 1.
            let mut new_params = Array1::<Float>::zeros(n_classes);
            for (i, &class) in classes.iter().enumerate() {
                let count = *class_counts.get(&class).unwrap() as Float;
                new_params[i] = count + 1.0;
            }

            q_params = new_params;

            let elbo = self.compute_elbo(&q_params, &class_counts, &classes);
            elbo_values.push(elbo);

            let param_diff: Float = (&q_params - &old_params).mapv(|x| x.abs()).sum();
            if param_diff < self.tolerance {
                break;
            }
        }

        // Normalize so the stored parameters are expected class probabilities.
        let param_sum = q_params.sum();
        q_params = q_params.mapv(|x| x / param_sum);

        self.variational_params_ = Some(q_params);
        self.elbo_ = Some(elbo_values);
        Ok(())
    }

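    // Note: this is a simplified ELBO-style surrogate (expected log-likelihood
    // under the normalized variational mean, plus a raw log-parameter term),
    // not the full Dirichlet ELBO with digamma and entropy terms.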
    fn compute_elbo(
        &self,
        params: &Array1<Float>,
        counts: &HashMap<Int, usize>,
        classes: &[Int],
    ) -> Float {
        let mut elbo = 0.0;
        let param_sum = params.sum();

        for (i, &class) in classes.iter().enumerate() {
            let count = *counts.get(&class).unwrap() as Float;
            if count > 0.0 {
                elbo += count * (params[i] / param_sum).ln();
            }
        }

        for &param in params.iter() {
            if param > 0.0 {
                elbo += param.ln();
            }
        }

        elbo
    }

    pub fn variational_parameters(&self) -> Option<&Array1<Float>> {
        self.variational_params_.as_ref()
    }

    pub fn elbo_evolution(&self) -> Option<&Vec<Float>> {
        self.elbo_.as_ref()
    }
}

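/// MCMC estimator for class probabilities under a Dirichlet(1, ..., 1) prior.
///
/// Because the Dirichlet prior is conjugate to the categorical likelihood,
/// each iteration draws an exact posterior sample, so `burn_in` and `thin`
/// are not strictly needed for mixing but are kept for a conventional MCMC
/// interface. Usage sketch (illustrative only):
///
/// ```ignore
/// let y = array![0, 0, 1, 2];
/// let mut mcmc = MCMCBayesEstimator::new()
///     .with_n_samples(500)
///     .with_random_state(42);
/// mcmc.fit_classification(&y)?;
/// let mean = mcmc.posterior_mean().unwrap();
/// let (lo, hi) = mcmc.credible_interval(0.05).unwrap();
/// ```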
#[derive(Debug, Clone)]
pub struct MCMCBayesEstimator {
    pub n_samples: usize,
    pub burn_in: usize,
    pub thin: usize,
    pub samples_: Option<Array2<Float>>,
    pub random_state: Option<u64>,
}

impl MCMCBayesEstimator {
    pub fn new() -> Self {
        Self {
            n_samples: 1000,
            burn_in: 200,
            thin: 1,
            samples_: None,
            random_state: None,
        }
    }

    pub fn with_n_samples(mut self, n_samples: usize) -> Self {
        self.n_samples = n_samples;
        self
    }

    pub fn with_burn_in(mut self, burn_in: usize) -> Self {
        self.burn_in = burn_in;
        self
    }

    pub fn with_random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

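    // The Dirichlet posterior below is available in closed form, so the
    // "MCMC" loop actually produces i.i.d. draws: Dirichlet(1 + counts),
    // sampled by normalizing independent Gamma variates.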
    pub fn fit_classification(&mut self, y: &Array1<Int>) -> Result<()> {
        let mut class_counts: HashMap<Int, usize> = HashMap::new();
        for &label in y.iter() {
            *class_counts.entry(label).or_insert(0) += 1;
        }

        let mut classes: Vec<Int> = class_counts.keys().copied().collect();
        classes.sort();
        let n_classes = classes.len();

        // Fall back to a fixed seed when no random state is given, so runs
        // are reproducible by default.
        let mut rng = if let Some(seed) = self.random_state {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::seed_from_u64(0)
        };

        // Posterior Dirichlet parameters (uniform prior + observed counts);
        // these do not change across iterations.
        let mut alpha_posterior = Array1::<Float>::ones(n_classes);
        for (i, &class) in classes.iter().enumerate() {
            let count = *class_counts.get(&class).unwrap() as Float;
            alpha_posterior[i] += count;
        }

        let total_samples = self.burn_in + self.n_samples * self.thin;
        let mut samples = Array2::<Float>::zeros((self.n_samples, n_classes));

        for iter in 0..total_samples {
            // Draw Dirichlet(alpha_posterior) by normalizing Gamma samples.
            let mut gamma_samples = Array1::<Float>::zeros(n_classes);
            for i in 0..n_classes {
                let gamma_dist = Gamma::new(alpha_posterior[i], 1.0).unwrap();
                gamma_samples[i] = gamma_dist.sample(&mut rng);
            }

            let gamma_sum = gamma_samples.sum();
            let theta = gamma_samples.mapv(|x| x / gamma_sum);

            // Keep every `thin`-th draw after burn-in.
            if iter >= self.burn_in && (iter - self.burn_in) % self.thin == 0 {
                let sample_idx = (iter - self.burn_in) / self.thin;
                if sample_idx < self.n_samples {
                    for j in 0..n_classes {
                        samples[[sample_idx, j]] = theta[j];
                    }
                }
            }
        }

        self.samples_ = Some(samples);
        Ok(())
    }

    pub fn samples(&self) -> Option<&Array2<Float>> {
        self.samples_.as_ref()
    }

    pub fn posterior_mean(&self) -> Option<Array1<Float>> {
        self.samples_
            .as_ref()
            .map(|samples| samples.mean_axis(Axis(0)).unwrap())
    }

    pub fn posterior_std(&self) -> Option<Array1<Float>> {
        self.samples_
            .as_ref()
            .map(|samples| samples.std_axis(Axis(0), 0.0))
    }

    pub fn credible_interval(&self, alpha: Float) -> Option<(Array1<Float>, Array1<Float>)> {
        let samples = self.samples_.as_ref()?;
        let n_classes = samples.ncols();
        let mut lower = Array1::<Float>::zeros(n_classes);
        let mut upper = Array1::<Float>::zeros(n_classes);

        // Empirical (1 - alpha) interval from the sorted per-class samples.
        for i in 0..n_classes {
            let mut column: Vec<Float> = samples.column(i).to_vec();
            column.sort_by(|a, b| a.partial_cmp(b).unwrap());

            let lower_idx = ((alpha / 2.0) * (column.len() as Float)) as usize;
            let upper_idx = ((1.0 - alpha / 2.0) * (column.len() as Float)) as usize;

            lower[i] = column[lower_idx.min(column.len() - 1)];
            upper[i] = column[upper_idx.min(column.len() - 1)];
        }

        Some((lower, upper))
    }
}

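/// Natural log of the gamma function: exact `ln((n - 1)!)` for integer
/// arguments up to 10, Stirling's approximation otherwise. The Stirling
/// branch is only a rough approximation for small non-integer arguments.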
fn gamma_ln(x: Float) -> Float {
    if x <= 0.0 {
        Float::INFINITY
    } else if x.fract() == 0.0 && x <= 10.0 {
        // Exact: ln(Gamma(n)) = ln((n - 1)!) for small positive integers.
        let n = x as usize;
        if n == 1 {
            0.0
        } else {
            (1..n).map(|i| (i as Float).ln()).sum()
        }
    } else {
        // Stirling's approximation.
        (x - 0.5) * x.ln() - x + 0.5 * (2.0 * std::f64::consts::PI).ln()
    }
}

impl Default for EmpiricalBayesEstimator {
    fn default() -> Self {
        Self::new()
    }
}

impl Default for HierarchicalBayesEstimator {
    fn default() -> Self {
        Self::new()
    }
}

impl Default for VariationalBayesEstimator {
    fn default() -> Self {
        Self::new()
    }
}

impl Default for MCMCBayesEstimator {
    fn default() -> Self {
        Self::new()
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_empirical_bayes_basic() {
        let y = array![0, 0, 0, 1, 1, 2];
        let mut estimator = EmpiricalBayesEstimator::new().with_random_state(42);

        let result = estimator.fit_classification(&y);
        assert!(result.is_ok());

        let hyperparams = estimator.hyperparameters().unwrap();
        assert_eq!(hyperparams.len(), 3);

        for &param in hyperparams.iter() {
            assert!(param > 0.0);
        }

        // Class 0 is the majority class, so its hyperparameter should be largest.
        assert!(hyperparams[0] >= hyperparams[1]);
        assert!(hyperparams[0] >= hyperparams[2]);
    }

    #[test]
    fn test_hierarchical_bayes_basic() {
        let y = array![0, 0, 1, 1, 0, 1];
        let groups = array![1, 1, 1, 2, 2, 2];
        let mut estimator = HierarchicalBayesEstimator::new()
            .with_groups(groups)
            .with_random_state(42);

        let result = estimator.fit_classification(&y);
        assert!(result.is_ok());

        let global_params = estimator.global_hyperparameters().unwrap();
        assert_eq!(global_params.len(), 2);

        let group_params = estimator.group_parameters().unwrap();
        assert_eq!(group_params.len(), 2);

        assert!(group_params.contains_key(&1));
        assert!(group_params.contains_key(&2));
    }

    #[test]
    fn test_variational_bayes_basic() {
        let y = array![0, 0, 0, 1, 1, 2];
        let mut estimator = VariationalBayesEstimator::new()
            .with_max_iter(50)
            .with_tolerance(1e-4);

        let result = estimator.fit_classification(&y);
        assert!(result.is_ok());

        let params = estimator.variational_parameters().unwrap();
        assert_eq!(params.len(), 3);

        let sum: Float = params.sum();
        assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-10);

        let elbo = estimator.elbo_evolution().unwrap();
        assert!(!elbo.is_empty());
    }

    #[test]
    fn test_mcmc_bayes_basic() {
        let y = array![0, 0, 0, 1, 1, 2];
        let mut estimator = MCMCBayesEstimator::new()
            .with_n_samples(100)
            .with_burn_in(20)
            .with_random_state(42);

        let result = estimator.fit_classification(&y);
        assert!(result.is_ok());

        let samples = estimator.samples().unwrap();
        assert_eq!(samples.nrows(), 100);
        assert_eq!(samples.ncols(), 3);

        for i in 0..samples.nrows() {
            let row_sum: Float = samples.row(i).sum();
            assert_abs_diff_eq!(row_sum, 1.0, epsilon = 1e-10);
        }

        let mean = estimator.posterior_mean().unwrap();
        assert_eq!(mean.len(), 3);

        let (lower, upper) = estimator.credible_interval(0.05).unwrap();
        assert_eq!(lower.len(), 3);
        assert_eq!(upper.len(), 3);

        for i in 0..3 {
            assert!(lower[i] <= upper[i]);
        }
    }

    #[test]
    fn test_gamma_ln_function() {
        assert_abs_diff_eq!(gamma_ln(1.0), 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(gamma_ln(2.0), 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(gamma_ln(3.0), (2.0f64).ln(), epsilon = 1e-10);

        let result = gamma_ln(10.0);
        assert!(result > 0.0);
        assert!(result.is_finite());
    }
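
    // Sanity check: per-group estimates and the global hyperparameters
    // produced by HierarchicalBayesEstimator should each sum to 1.
    #[test]
    fn test_hierarchical_probability_vectors() {
        let y = array![0, 0, 1, 1, 0, 1];
        let groups = array![1, 1, 1, 2, 2, 2];
        let mut estimator = HierarchicalBayesEstimator::new().with_groups(groups);
        estimator.fit_classification(&y).unwrap();

        let global = estimator.global_hyperparameters().unwrap();
        assert_abs_diff_eq!(global.sum(), 1.0, epsilon = 1e-10);

        for probs in estimator.group_parameters().unwrap().values() {
            assert_abs_diff_eq!(probs.sum(), 1.0, epsilon = 1e-10);
        }
    }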
}