rs_ml/classification/
naive_bayes.rs1use crate::{Axis, Estimator};
4use core::f64;
5use ndarray::{Array1, Array2};
6use std::f64::consts::PI;
7
8use super::Classifier;
9
/// Unfitted Gaussian naive Bayes estimator.
///
/// The type carries no state; fitting it through its `Estimator`
/// implementation produces a [`GaussianNB`] model.
#[derive(Debug, Clone, Copy, Default)]
pub struct GaussianNBEstimator;
12
13#[derive(Debug)]
15pub struct GaussianNB<Label> {
16 means: Array2<f64>,
17 vars: Array2<f64>,
18 priors: Array1<f64>,
19 labels: Vec<Label>,
20}
21
22impl<I, Label: Eq + Clone> Estimator<(&Array2<f64>, I)> for GaussianNBEstimator
23where
24 for<'a> &'a I: IntoIterator<Item = &'a Label>,
25{
26 type Estimator = GaussianNB<Label>;
27
28 fn fit(&self, input: &(&Array2<f64>, I)) -> Option<Self::Estimator> {
29 let (features, labels) = input;
30
31 let distinct_labels: Vec<_> = labels.into_iter().fold(vec![], |mut agg, curr| {
32 if agg.contains(curr) {
33 agg
34 } else {
35 agg.push(curr.clone());
36 agg
37 }
38 });
39
40 let nfeatures = features.ncols();
41 let nrows = features.nrows();
42
43 let mut means = Array2::zeros((distinct_labels.len(), nfeatures));
44 let mut vars = Array2::zeros((distinct_labels.len(), nfeatures));
45 let mut priors = Array1::zeros(distinct_labels.len());
46
47 for (idx, label) in distinct_labels.iter().enumerate() {
48 let indeces: Vec<usize> = labels
49 .into_iter()
50 .enumerate()
51 .filter_map(|(idx, l)| match l == label {
52 true => Some(idx),
53 false => None,
54 })
55 .collect();
56
57 let filtered_view = features.select(Axis(0), &indeces);
58 let c = filtered_view.nrows();
59
60 means
61 .row_mut(idx)
62 .assign(&filtered_view.mean_axis(Axis(0))?);
63 vars.row_mut(idx)
64 .assign(&filtered_view.var_axis(Axis(0), 1.0));
65 priors[idx] = c as f64 / nrows as f64;
66 }
67
68 Some(GaussianNB {
69 labels: distinct_labels,
70 means,
71 vars,
72 priors,
73 })
74 }
75}
76
77impl<Label: Clone> Classifier<Array2<f64>, Label> for GaussianNB<Label> {
78 fn labels(&self) -> &[Label] {
79 &self.labels
80 }
81
82 fn predict_proba(&self, arr: &Array2<f64>) -> Option<Array2<f64>> {
83 let broadcasted_means = self.means.view().insert_axis(Axis(1));
84 let broadcasted_vars = self.vars.view().insert_axis(Axis(1));
85 let broadcasted_log_priors = self.priors.view().insert_axis(Axis(1)).ln();
86
87 let mut log_likelihood = -0.5 * (&broadcasted_vars * 2.0 * PI).ln().sum_axis(Axis(2));
88 log_likelihood = log_likelihood
89 - 0.5 * ((arr - &broadcasted_means).pow2() / broadcasted_vars).sum_axis(Axis(2))
90 + broadcasted_log_priors;
91
92 Some(log_likelihood.exp().t().to_owned())
93 }
94}