rs_ml/classification/
naive_bayes.rs1use crate::{Axis, Estimator};
4use core::f64;
5use ndarray::{Array1, Array2};
6use std::f64::consts::PI;
7
8use super::Classifier;
9
/// Configuration object used to train a [`GaussianNB`] classifier through [`Estimator::fit`].
///
/// It has no fields and therefore no hyper-parameters, so it is freely `Copy`.
#[derive(Clone, Copy, Debug)]
pub struct GaussianNBEstimator;
31
32#[derive(Debug)]
34pub struct GaussianNB<Label> {
35 means: Array2<f64>,
36 vars: Array2<f64>,
37 priors: Array1<f64>,
38 labels: Vec<Label>,
39}
40
41impl<I, Label: Eq + Clone> Estimator<(&Array2<f64>, I)> for GaussianNBEstimator
42where
43 for<'a> &'a I: IntoIterator<Item = &'a Label>,
44{
45 type Estimator = GaussianNB<Label>;
46
47 fn fit(&self, input: &(&Array2<f64>, I)) -> Option<Self::Estimator> {
48 let (features, labels) = input;
49
50 let distinct_labels: Vec<_> = labels.into_iter().fold(vec![], |mut agg, curr| {
51 if agg.contains(curr) {
52 agg
53 } else {
54 agg.push(curr.clone());
55 agg
56 }
57 });
58
59 let nfeatures = features.ncols();
60 let nrows = features.nrows();
61
62 let mut means = Array2::zeros((distinct_labels.len(), nfeatures));
63 let mut vars = Array2::zeros((distinct_labels.len(), nfeatures));
64 let mut priors = Array1::zeros(distinct_labels.len());
65
66 for (idx, label) in distinct_labels.iter().enumerate() {
67 let indeces: Vec<usize> = labels
68 .into_iter()
69 .enumerate()
70 .filter_map(|(idx, l)| match l == label {
71 true => Some(idx),
72 false => None,
73 })
74 .collect();
75
76 let filtered_view = features.select(Axis(0), &indeces);
77 let c = filtered_view.nrows();
78
79 means
80 .row_mut(idx)
81 .assign(&filtered_view.mean_axis(Axis(0))?);
82 vars.row_mut(idx)
83 .assign(&filtered_view.var_axis(Axis(0), 1.0));
84 priors[idx] = c as f64 / nrows as f64;
85 }
86
87 Some(GaussianNB {
88 labels: distinct_labels,
89 means,
90 vars,
91 priors,
92 })
93 }
94}
95
96impl<Label: Clone> Classifier<Array2<f64>, Label> for GaussianNB<Label> {
97 fn labels(&self) -> &[Label] {
98 &self.labels
99 }
100
101 fn predict_proba(&self, arr: &Array2<f64>) -> Option<Array2<f64>> {
102 let broadcasted_means = self.means.view().insert_axis(Axis(1));
103 let broadcasted_vars = self.vars.view().insert_axis(Axis(1));
104 let broadcasted_log_priors = self.priors.view().insert_axis(Axis(1)).ln();
105
106 let log_likelihood = -0.5 * (&broadcasted_vars * 2.0 * PI).ln().sum_axis(Axis(2))
107 - 0.5 * ((arr - &broadcasted_means).pow2() / broadcasted_vars).sum_axis(Axis(2))
108 + broadcasted_log_priors;
109
110 let likelihood = log_likelihood.exp().t().to_owned();
111
112 let likelihood = &likelihood / &likelihood.sum_axis(Axis(1)).insert_axis(Axis(1));
113
114 Some(likelihood)
115 }
116}