rs_ml/classification/
naive_bayes.rs

1//! Naive Bayes classifiers
2
3use crate::{Axis, Estimator};
4use core::f64;
5use ndarray::{Array1, Array2};
6use std::f64::consts::PI;
7
8use super::Classifier;
9
/// Estimator to train a [`GaussianNB`] classifier.
///
/// A stateless marker type: `fit()` consumes a feature matrix of shape
/// `(n_samples, n_features)` together with one label per row, and computes
/// per-class feature means, variances, and class priors.
///
/// Example:
/// ```
/// use ndarray::{arr1, arr2};
/// use crate::rs_ml::Estimator;
/// use rs_ml::classification::naive_bayes::GaussianNBEstimator;
///
/// let features = arr2(&[
///     [0., 0.],
///     [0., 1.],
///     [1., 0.],
///     [1., 1.]
/// ]);
///
/// let labels = arr1(&[false, true, true, false]);
///
/// let naive_bayes_model = GaussianNBEstimator.fit(&(&features, labels)).unwrap();
/// ```
#[derive(Debug, Clone, Copy)]
pub struct GaussianNBEstimator;
31
/// Represents a fitted Gaussian Naive Bayes Classifier. Created with the `fit()` function implemented for [GaussianNBEstimator].
#[derive(Debug)]
pub struct GaussianNB<Label> {
    // Per-class, per-feature means; shape (n_classes, n_features).
    means: Array2<f64>,
    // Per-class, per-feature variances; shape (n_classes, n_features).
    vars: Array2<f64>,
    // Class prior probabilities (relative frequency in the training labels);
    // length n_classes.
    priors: Array1<f64>,
    // Distinct labels in first-appearance order; row `i` of `means`/`vars`
    // and `priors[i]` correspond to `labels[i]`.
    labels: Vec<Label>,
}
40
41impl<I, Label: Eq + Clone> Estimator<(&Array2<f64>, I)> for GaussianNBEstimator
42where
43    for<'a> &'a I: IntoIterator<Item = &'a Label>,
44{
45    type Estimator = GaussianNB<Label>;
46
47    fn fit(&self, input: &(&Array2<f64>, I)) -> Option<Self::Estimator> {
48        let (features, labels) = input;
49
50        let distinct_labels: Vec<_> = labels.into_iter().fold(vec![], |mut agg, curr| {
51            if agg.contains(curr) {
52                agg
53            } else {
54                agg.push(curr.clone());
55                agg
56            }
57        });
58
59        let nfeatures = features.ncols();
60        let nrows = features.nrows();
61
62        let mut means = Array2::zeros((distinct_labels.len(), nfeatures));
63        let mut vars = Array2::zeros((distinct_labels.len(), nfeatures));
64        let mut priors = Array1::zeros(distinct_labels.len());
65
66        for (idx, label) in distinct_labels.iter().enumerate() {
67            let indeces: Vec<usize> = labels
68                .into_iter()
69                .enumerate()
70                .filter_map(|(idx, l)| match l == label {
71                    true => Some(idx),
72                    false => None,
73                })
74                .collect();
75
76            let filtered_view = features.select(Axis(0), &indeces);
77            let c = filtered_view.nrows();
78
79            means
80                .row_mut(idx)
81                .assign(&filtered_view.mean_axis(Axis(0))?);
82            vars.row_mut(idx)
83                .assign(&filtered_view.var_axis(Axis(0), 1.0));
84            priors[idx] = c as f64 / nrows as f64;
85        }
86
87        Some(GaussianNB {
88            labels: distinct_labels,
89            means,
90            vars,
91            priors,
92        })
93    }
94}
95
96impl<Label: Clone> Classifier<Array2<f64>, Label> for GaussianNB<Label> {
97    fn labels(&self) -> &[Label] {
98        &self.labels
99    }
100
101    fn predict_proba(&self, arr: &Array2<f64>) -> Option<Array2<f64>> {
102        let broadcasted_means = self.means.view().insert_axis(Axis(1));
103        let broadcasted_vars = self.vars.view().insert_axis(Axis(1));
104        let broadcasted_log_priors = self.priors.view().insert_axis(Axis(1)).ln();
105
106        let log_likelihood = -0.5 * (&broadcasted_vars * 2.0 * PI).ln().sum_axis(Axis(2))
107            - 0.5 * ((arr - &broadcasted_means).pow2() / broadcasted_vars).sum_axis(Axis(2))
108            + broadcasted_log_priors;
109
110        let likelihood = log_likelihood.exp().t().to_owned();
111
112        let likelihood = &likelihood / &likelihood.sum_axis(Axis(1)).insert_axis(Axis(1));
113
114        Some(likelihood)
115    }
116}