//! Gaussian naive Bayes classifier (`rs_ml/classification/naive_bayes.rs`).
use crate::{Axis, Classifier};
use core::f64;
use ndarray::{Array1, Array2};

/// Gaussian naive Bayes classifier.
///
/// Each label is modelled with an independent normal distribution per feature.
/// Row `i` of `means` / `std_devs` and entry `i` of `posteriors` all belong to
/// `labels[i]`. `fit` builds an instance from training data.
#[derive(Debug)]
pub struct GaussianNB<Label> {
    /// Per-label feature means, shape `(n_labels, n_features)`.
    means: Array2<f64>,
    /// Per-label feature spread (standard deviation) as estimated by `fit`,
    /// shape `(n_labels, n_features)`.
    std_devs: Array2<f64>,
    /// Class *prior* probabilities despite the field name: the fraction of
    /// training samples that carry each label.
    posteriors: Array1<f64>,
    /// Distinct training labels, in order of first appearance in `y`.
    labels: Vec<Label>,
}

impl<Label: Eq + Clone> Classifier<Array2<f64>, Label> for GaussianNB<Label> {
    /// Fits the model to `arr` (one sample per row) with targets `y`.
    ///
    /// For every distinct label this estimates the per-feature mean and
    /// standard deviation (maximum-likelihood, i.e. `ddof = 0`) and the class
    /// prior (fraction of samples with that label).
    ///
    /// Returns `None` when `arr` has no rows or when `y.len()` differs from
    /// the number of rows.
    fn fit(arr: &Array2<f64>, y: &[Label]) -> Option<GaussianNB<Label>> {
        // Portion of the largest feature variance added to every variance so
        // constant features / single-sample classes don't yield a zero std
        // (which would make every density 0/0). Same default as scikit-learn.
        const VAR_SMOOTHING: f64 = 1e-9;

        let nrows = arr.nrows();
        if nrows == 0 || y.len() != nrows {
            return None;
        }

        // Distinct labels in first-seen order; `Label` is only `Eq`, so a
        // linear `contains` is the best we can do without a `Hash` bound.
        let mut labels: Vec<Label> = Vec::new();
        for label in y {
            if !labels.contains(label) {
                labels.push(label.clone());
            }
        }

        let epsilon = VAR_SMOOTHING
            * arr
                .var_axis(Axis(0), 0.0)
                .fold(0.0_f64, |acc, &v| acc.max(v));

        let n_features = arr.ncols();
        let mut means = Array2::zeros((labels.len(), n_features));
        let mut std_devs = Array2::zeros((labels.len(), n_features));
        let mut posteriors = Array1::zeros(labels.len());

        for (idx, label) in labels.iter().enumerate() {
            let indices: Vec<usize> = y
                .iter()
                .enumerate()
                .filter(|(_, l)| *l == label)
                .map(|(row, _)| row)
                .collect();

            let class_rows = arr.select(Axis(0), &indices);
            let count = class_rows.nrows();

            // `count >= 1` because every label was taken from `y`, so the
            // mean exists; `?` only guards the API's `Option`.
            means
                .row_mut(idx)
                .assign(&class_rows.mean_axis(Axis(0))?);
            // NOTE: the second argument of `var_axis`/`std_axis` is `ddof`,
            // not the divisor; `ddof = 0` gives the maximum-likelihood variance.
            std_devs.row_mut(idx).assign(
                &class_rows
                    .var_axis(Axis(0), 0.0)
                    .mapv(|var| (var + epsilon).sqrt()),
            );
            posteriors[idx] = count as f64 / nrows as f64;
        }

        Some(GaussianNB {
            labels,
            means,
            std_devs,
            posteriors,
        })
    }

    /// Distinct labels seen during `fit`, in the column order used by
    /// `predict_proba`.
    fn labels(&self) -> &[Label] {
        &self.labels
    }

    /// Returns a `(n_samples, n_labels)` matrix of class posterior
    /// probabilities; each row sums to 1 (column `k` ↔ `labels()[k]`).
    ///
    /// Densities are combined in log space and normalized with the
    /// log-sum-exp trick, so many features do not underflow every class to 0.
    ///
    /// Returns `None` if `arr` does not have the number of features the model
    /// was fitted on. Rows become NaN only if a fitted std dev is zero, which
    /// can happen only when every training feature was constant.
    fn predict_proba(&self, arr: &Array2<f64>) -> Option<Array2<f64>> {
        if arr.ncols() != self.means.ncols() {
            return None;
        }

        let ln_2pi = (2. * f64::consts::PI).ln();
        let mut joint = Array2::<f64>::zeros((arr.nrows(), self.labels.len()));

        // log p(c) + Σ_f log N(x_f; μ_cf, σ_cf) for every sample / class pair.
        for (sample, mut out) in arr.rows().into_iter().zip(joint.rows_mut()) {
            for (k, slot) in out.iter_mut().enumerate() {
                let log_likelihood: f64 = sample
                    .iter()
                    .zip(self.means.row(k))
                    .zip(self.std_devs.row(k))
                    .map(|((&x, &mu), &sigma)| {
                        let z = (x - mu) / sigma;
                        -0.5 * (z * z + ln_2pi) - sigma.ln()
                    })
                    .sum();
                *slot = self.posteriors[k].ln() + log_likelihood;
            }
        }

        // Softmax per row: subtracting the row max before `exp` keeps the
        // largest term at exp(0) = 1 so the normalizer never underflows.
        for mut row in joint.rows_mut() {
            if let Some(max) = row.iter().copied().reduce(f64::max) {
                row.mapv_inplace(|v| (v - max).exp());
                let total = row.sum();
                row /= total;
            }
        }

        Some(joint)
    }
}