rs_ml/classification/naive_bayes.rs

//! Naive Bayes classifiers
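//!
//! A Gaussian Naive Bayes model treats every feature as an independent,
//! per-class Gaussian. For a record with features `x_j`, a fitted [`GaussianNB`]
//! scores each class `c` as
//! `ln P(c) + sum_j [ -0.5 * ln(2 * pi * var_cj) - 0.5 * (x_j - mean_cj)^2 / var_cj ]`
//! and normalises the exponentiated scores into per-class probabilities.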

use crate::{Axis, Estimator};
use core::f64;
use ndarray::{Array1, Array2};
use std::f64::consts::PI;

use super::{ClassificationDataSet, Classifier};

/// Estimator to train a [`GaussianNB`] classifier.
///
/// Example:
/// ```
/// use ndarray::{arr1, arr2};
/// use rs_ml::Estimator;
/// use rs_ml::classification::{ClassificationRecord, ClassificationDataSet};
/// use rs_ml::classification::naive_bayes::GaussianNBEstimator;
///
/// let features = arr2(&[
///     [0., 0.],
///     [0., 1.],
///     [1., 0.],
///     [1., 1.]
/// ]);
///
/// let labels = arr1(&[false, true, true, false]);
///
/// let records: Vec<_> = features
///     .rows()
///     .into_iter()
///     .zip(labels)
///     .map(|(row, label)| ClassificationRecord::from((row.to_owned(), label)))
///     .collect();
///
/// let dataset = ClassificationDataSet::from(records);
/// let model = GaussianNBEstimator.fit(&dataset).unwrap();
/// ```
#[derive(Debug, Clone, Copy)]
pub struct GaussianNBEstimator;

/// A fitted Gaussian Naive Bayes classifier, created by calling `fit()` on a
/// [`GaussianNBEstimator`].
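///
/// Probabilities for new records are obtained through the [`Classifier`] trait.
/// A minimal sketch (assuming `Classifier` is exported from
/// `rs_ml::classification`, and reusing the dataset from the
/// [`GaussianNBEstimator`] example):
///
/// ```
/// # use ndarray::{arr1, arr2};
/// # use rs_ml::Estimator;
/// # use rs_ml::classification::{ClassificationRecord, ClassificationDataSet, Classifier};
/// # use rs_ml::classification::naive_bayes::GaussianNBEstimator;
/// # let features = arr2(&[[0., 0.], [0., 1.], [1., 0.], [1., 1.]]);
/// # let labels = arr1(&[false, true, true, false]);
/// # let records: Vec<_> = features
/// #     .rows()
/// #     .into_iter()
/// #     .zip(labels)
/// #     .map(|(row, label)| ClassificationRecord::from((row.to_owned(), label)))
/// #     .collect();
/// # let model = GaussianNBEstimator
/// #     .fit(&ClassificationDataSet::from(records))
/// #     .unwrap();
/// // One row per input record, one column per entry of `model.labels()`.
/// let proba = model.predict_proba(vec![arr1(&[0., 1.])].into_iter()).unwrap();
/// assert_eq!(proba.shape(), &[1, 2]);
/// ```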
#[derive(Debug)]
pub struct GaussianNB<Label> {
    /// Per-class feature means, shape `(n_classes, n_features)`.
    means: Array2<f64>,
    /// Per-class feature variances, shape `(n_classes, n_features)`.
    vars: Array2<f64>,
    /// Class prior probabilities, shape `(n_classes,)`.
    priors: Array1<f64>,
    /// Distinct labels, in the column order used by `predict_proba`.
    labels: Vec<Label>,
}

impl<Label: PartialEq + Clone> Estimator<ClassificationDataSet<Array1<f64>, Label>>
    for GaussianNBEstimator
{
    type Estimator = GaussianNB<Label>;

    fn fit(&self, input: &ClassificationDataSet<Array1<f64>, Label>) -> Option<Self::Estimator> {
        // Collect the distinct labels in order of first appearance.
        let distinct_labels: Vec<Label> =
            input
                .get_labels()
                .into_iter()
                .fold(vec![], |mut agg, curr| {
                    if !agg.contains(curr) {
                        agg.push(curr.clone());
                    }
                    agg
                });

        let features_vec = input.get_features();
        let nrows = features_vec.len();
        let nfeatures = features_vec.first()?.len();

        // Flatten the per-record feature arrays into one row-major buffer.
        let flat_features: Vec<f64> = features_vec
            .iter()
            .flat_map(|record| record.into_iter())
            .copied()
            .collect();

        let features = Array2::from_shape_vec((nrows, nfeatures), flat_features).ok()?;

        let mut means = Array2::zeros((distinct_labels.len(), nfeatures));
        let mut vars = Array2::zeros((distinct_labels.len(), nfeatures));
        let mut priors = Array1::zeros(distinct_labels.len());

        for (idx, label) in distinct_labels.iter().enumerate() {
            // Rows of the feature matrix that belong to this label.
            let indices: Vec<usize> = input
                .get_labels()
                .into_iter()
                .enumerate()
                .filter_map(|(idx, l)| if l == label { Some(idx) } else { None })
                .collect();

            let filtered_view = features.select(Axis(0), &indices);
            let c = filtered_view.nrows();

            // Per-feature mean, unbiased (ddof = 1) variance, and empirical prior
            // for this class.
            means
                .row_mut(idx)
                .assign(&filtered_view.mean_axis(Axis(0))?);
            vars.row_mut(idx)
                .assign(&filtered_view.var_axis(Axis(0), 1.0));
            priors[idx] = c as f64 / nrows as f64;
        }

        Some(GaussianNB {
            labels: distinct_labels,
            means,
            vars,
            priors,
        })
    }
}

impl<Label: Clone> Classifier<Array1<f64>, Label> for GaussianNB<Label> {
    fn labels(&self) -> &[Label] {
        &self.labels
    }

    fn predict_proba<I>(&self, arr: I) -> Option<Array2<f64>>
    where
        I: Iterator<Item = Array1<f64>>,
    {
        let col_count = self.labels.len();

        let likelihoods: Vec<_> = arr
            .map(|record| {
                // Per-class, per-feature Gaussian log density:
                // -0.5 * ln(2 * pi * var) - 0.5 * (x - mean)^2 / var.
                // `&self.means - &record` broadcasts the record across classes; the
                // sign flip relative to (x - mean) is irrelevant once squared.
                let log_likelihood = -0.5 * (&self.vars * 2.0 * PI).ln()
                    - 0.5 * ((&self.means - &record).pow2() / &self.vars);

                // Sum over features, then add the log prior once per class.
                let joint = log_likelihood.sum_axis(Axis(1)) + self.priors.view().ln();

                // Shift by the maximum before exponentiating (log-sum-exp trick) so
                // the normalisation below does not underflow to 0/0.
                let max = joint.fold(f64::NEG_INFINITY, |acc, &v| acc.max(v));
                let likelihood = (joint - max).exp();

                &likelihood / likelihood.sum()
            })
            .flat_map(|class_probabilities| class_probabilities.into_iter())
            .collect();

        let row_count = likelihoods.len() / col_count;

        Array2::from_shape_vec((row_count, col_count), likelihoods).ok()
    }
}
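
// A minimal smoke-test sketch (not part of the original file): it assumes the
// `ClassificationRecord`/`ClassificationDataSet` construction API shown in the
// doc examples above, and that `ClassificationRecord` lives in `crate::classification`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::classification::ClassificationRecord;
    use ndarray::{arr1, arr2};

    #[test]
    fn gaussian_nb_rows_are_normalised() {
        let features = arr2(&[[0., 0.], [0., 1.], [1., 0.], [1., 1.]]);
        let labels = arr1(&[false, true, true, false]);

        let records: Vec<_> = features
            .rows()
            .into_iter()
            .zip(labels)
            .map(|(row, label)| ClassificationRecord::from((row.to_owned(), label)))
            .collect();

        let model = GaussianNBEstimator
            .fit(&ClassificationDataSet::from(records))
            .unwrap();

        let proba = model
            .predict_proba(vec![arr1(&[0., 1.]), arr1(&[1., 1.])].into_iter())
            .unwrap();

        // One row per record, one column per label, each row summing to 1.
        assert_eq!(proba.shape(), &[2, 2]);
        for row in proba.rows() {
            assert!((row.sum() - 1.0).abs() < 1e-9);
        }
    }
}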