rs_ml/classification/
naive_bayes.rs

1//! Naive Bayes classifiers
2
3use crate::{Axis, Classifier};
4use core::f64;
5use ndarray::{Array1, Array2};
6
7/// Gaussian Naive Bayes Classifier
8#[derive(Debug)]
9pub struct GaussianNB<Label> {
10    means: Array2<f64>,
11    std_devs: Array2<f64>,
12    posteriors: Array1<f64>,
13    labels: Vec<Label>,
14}
15
16impl<Label: Eq + Clone> Classifier<Array2<f64>, Label> for GaussianNB<Label> {
17    fn fit<I>(arr: &Array2<f64>, y: I) -> Option<GaussianNB<Label>>
18    where
19        for<'a> &'a I: IntoIterator<Item = &'a Label>,
20    {
21        let labels: Vec<Label> = y.into_iter().fold(vec![], |mut agg, curr| {
22            if agg.contains(curr) {
23                agg
24            } else {
25                agg.push(curr.clone());
26                agg
27            }
28        });
29
30        let nrows = arr.nrows();
31
32        let mut means = Array2::zeros((labels.len(), arr.ncols()));
33        let mut std_devs = Array2::zeros((labels.len(), arr.ncols()));
34        let mut posteriors = Array1::zeros(labels.len());
35
36        for (idx, label) in labels.iter().enumerate() {
37            let indeces: Vec<usize> = y
38                .into_iter()
39                .enumerate()
40                .filter_map(|(idx, l)| match l == label {
41                    true => Some(idx),
42                    false => None,
43                })
44                .collect();
45
46            let filtered_view = arr.select(Axis(0), &indeces);
47            let c = filtered_view.nrows();
48
49            means
50                .row_mut(idx)
51                .assign(&filtered_view.mean_axis(Axis(0))?);
52            std_devs
53                .row_mut(idx)
54                .assign(&filtered_view.std_axis(Axis(0), (c - 1) as f64));
55            posteriors[idx] = c as f64 / nrows as f64;
56        }
57
58        Some(GaussianNB {
59            labels,
60            means,
61            std_devs,
62            posteriors,
63        })
64    }
65
66    fn labels(&self) -> &[Label] {
67        &self.labels
68    }
69
70    fn predict_proba(&self, arr: &Array2<f64>) -> Option<Array2<f64>> {
71        let root_2pi = f64::sqrt(2. * f64::consts::PI);
72        let broadcasted_means = self.means.view().insert_axis(Axis(1));
73        let broadcasted_stddev = self.std_devs.view().insert_axis(Axis(1));
74        let broadcasted_posteriors = self.posteriors.view().insert_axis(Axis(1));
75
76        let p1 = -(arr - &broadcasted_means).pow2() / (2. * broadcasted_stddev.pow2());
77        let p2 = (&broadcasted_stddev * root_2pi).recip();
78
79        let p = (p2 * p1.exp()).product_axis(Axis(2)) * broadcasted_posteriors;
80
81        Some(p.t().to_owned())
82    }
83}