// rs_ml/classification/naive_bayes.rs
use crate::{Axis, Classifier};
use core::f64;
use ndarray::{Array1, Array2};

7#[derive(Debug)]
9pub struct GaussianNB<Label> {
10 means: Array2<f64>,
11 std_devs: Array2<f64>,
12 posteriors: Array1<f64>,
13 labels: Vec<Label>,
14}
15
16impl<Label: Eq + Clone> Classifier<Array2<f64>, Label> for GaussianNB<Label> {
17 fn fit<I>(arr: &Array2<f64>, y: I) -> Option<GaussianNB<Label>>
18 where
19 for<'a> &'a I: IntoIterator<Item = &'a Label>,
20 {
21 let labels: Vec<Label> = y.into_iter().fold(vec![], |mut agg, curr| {
22 if agg.contains(curr) {
23 agg
24 } else {
25 agg.push(curr.clone());
26 agg
27 }
28 });
29
30 let nrows = arr.nrows();
31
32 let mut means = Array2::zeros((labels.len(), arr.ncols()));
33 let mut std_devs = Array2::zeros((labels.len(), arr.ncols()));
34 let mut posteriors = Array1::zeros(labels.len());
35
36 for (idx, label) in labels.iter().enumerate() {
37 let indeces: Vec<usize> = y
38 .into_iter()
39 .enumerate()
40 .filter_map(|(idx, l)| match l == label {
41 true => Some(idx),
42 false => None,
43 })
44 .collect();
45
46 let filtered_view = arr.select(Axis(0), &indeces);
47 let c = filtered_view.nrows();
48
49 means
50 .row_mut(idx)
51 .assign(&filtered_view.mean_axis(Axis(0))?);
52 std_devs
53 .row_mut(idx)
54 .assign(&filtered_view.std_axis(Axis(0), (c - 1) as f64));
55 posteriors[idx] = c as f64 / nrows as f64;
56 }
57
58 Some(GaussianNB {
59 labels,
60 means,
61 std_devs,
62 posteriors,
63 })
64 }
65
66 fn labels(&self) -> &[Label] {
67 &self.labels
68 }
69
70 fn predict_proba(&self, arr: &Array2<f64>) -> Option<Array2<f64>> {
71 let root_2pi = f64::sqrt(2. * f64::consts::PI);
72 let broadcasted_means = self.means.view().insert_axis(Axis(1));
73 let broadcasted_stddev = self.std_devs.view().insert_axis(Axis(1));
74 let broadcasted_posteriors = self.posteriors.view().insert_axis(Axis(1));
75
76 let p1 = -(arr - &broadcasted_means).pow2() / (2. * broadcasted_stddev.pow2());
77 let p2 = (&broadcasted_stddev * root_2pi).recip();
78
79 let p = (p2 * p1.exp()).product_axis(Axis(2)) * broadcasted_posteriors;
80
81 Some(p.t().to_owned())
82 }
83}