rs_ml/classification/
naive_bayes.rs

1//! Naive Bayes classifiers
2
3use crate::{Axis, Estimator};
4use core::f64;
5use ndarray::{Array1, Array2};
6use std::f64::consts::PI;
7
8use super::Classifier;
9
/// Estimator for Gaussian Naive Bayes.
///
/// This is a stateless marker type: it carries no hyper-parameters and only
/// serves as the receiver of the `Estimator::fit` implementation that produces
/// a fitted [`GaussianNB`] classifier.
#[derive(Debug, Clone, Copy, Default)]
pub struct GaussianNBEstimator;
12
13/// Gaussian Naive Bayes Classifier
14#[derive(Debug)]
15pub struct GaussianNB<Label> {
16    means: Array2<f64>,
17    vars: Array2<f64>,
18    priors: Array1<f64>,
19    labels: Vec<Label>,
20}
21
22impl<I, Label: Eq + Clone> Estimator<(&Array2<f64>, I)> for GaussianNBEstimator
23where
24    for<'a> &'a I: IntoIterator<Item = &'a Label>,
25{
26    type Estimator = GaussianNB<Label>;
27
28    fn fit(&self, input: &(&Array2<f64>, I)) -> Option<Self::Estimator> {
29        let (features, labels) = input;
30
31        let distinct_labels: Vec<_> = labels.into_iter().fold(vec![], |mut agg, curr| {
32            if agg.contains(curr) {
33                agg
34            } else {
35                agg.push(curr.clone());
36                agg
37            }
38        });
39
40        let nfeatures = features.ncols();
41        let nrows = features.nrows();
42
43        let mut means = Array2::zeros((distinct_labels.len(), nfeatures));
44        let mut vars = Array2::zeros((distinct_labels.len(), nfeatures));
45        let mut priors = Array1::zeros(distinct_labels.len());
46
47        for (idx, label) in distinct_labels.iter().enumerate() {
48            let indeces: Vec<usize> = labels
49                .into_iter()
50                .enumerate()
51                .filter_map(|(idx, l)| match l == label {
52                    true => Some(idx),
53                    false => None,
54                })
55                .collect();
56
57            let filtered_view = features.select(Axis(0), &indeces);
58            let c = filtered_view.nrows();
59
60            means
61                .row_mut(idx)
62                .assign(&filtered_view.mean_axis(Axis(0))?);
63            vars.row_mut(idx)
64                .assign(&filtered_view.var_axis(Axis(0), 1.0));
65            priors[idx] = c as f64 / nrows as f64;
66        }
67
68        Some(GaussianNB {
69            labels: distinct_labels,
70            means,
71            vars,
72            priors,
73        })
74    }
75}
76
77impl<Label: Clone> Classifier<Array2<f64>, Label> for GaussianNB<Label> {
78    fn labels(&self) -> &[Label] {
79        &self.labels
80    }
81
82    fn predict_proba(&self, arr: &Array2<f64>) -> Option<Array2<f64>> {
83        let broadcasted_means = self.means.view().insert_axis(Axis(1));
84        let broadcasted_vars = self.vars.view().insert_axis(Axis(1));
85        let broadcasted_log_priors = self.priors.view().insert_axis(Axis(1)).ln();
86
87        let mut log_likelihood = -0.5 * (&broadcasted_vars * 2.0 * PI).ln().sum_axis(Axis(2));
88        log_likelihood = log_likelihood
89            - 0.5 * ((arr - &broadcasted_means).pow2() / broadcasted_vars).sum_axis(Axis(2))
90            + broadcasted_log_priors;
91
92        Some(log_likelihood.exp().t().to_owned())
93    }
94}