rs_ml/classification/
naive_bayes.rs

1//! Naive Bayes classifiers
2
3use crate::{Axis, Estimatable, Estimator};
4use core::f64;
5use ndarray::{Array1, Array2};
6use std::{f64::consts::PI, marker::PhantomData};
7
8use super::{ClassificationDataSet, Classifier};
9
/// Estimator that fits a [`GaussianNB`] classifier from a labelled data set.
///
/// The estimator itself carries no configuration; calling `fit()` on it with a
/// `ClassificationDataSet` produces the fitted model (or `None` when fitting
/// fails, e.g. for a data set without any records).
///
/// # Examples
///
/// ```
/// use ndarray::{arr1, arr2};
/// use rs_ml::Estimator;
/// use rs_ml::classification::{ClassificationRecord, ClassificationDataSet};
/// use rs_ml::classification::naive_bayes::GaussianNBEstimator;
///
/// let features = arr2(&[
///     [0., 0.],
///     [0., 1.],
///     [1., 0.],
///     [1., 1.]
/// ]);
///
/// let labels = arr1(&[false, true, true, false]);
///
/// let records: Vec<_> = features
///     .rows()
///     .into_iter()
///     .zip(labels)
///     .map(|(row, label)| ClassificationRecord::from((row.to_owned(), label)))
///     .collect();
///
/// let dataset = ClassificationDataSet::from(records);
/// let model = GaussianNBEstimator.fit(&dataset).unwrap();
/// ```
#[derive(Debug, Clone, Copy)]
pub struct GaussianNBEstimator;
40
/// Represents a fitted Gaussian Naive Bayes Classifier. Created with the `fit()` function implemented for [`GaussianNBEstimator`].
///
/// All per-class statistics are indexed by the position of the class in `labels`:
/// row `i` of `means`/`vars` and entry `i` of `priors` describe `labels[i]`.
#[derive(Debug)]
pub struct GaussianNB<Input, Label> {
    /// Zero-sized marker tying the model to the record type it was fitted on;
    /// no `Input` value is stored.
    _input: PhantomData<Input>,
    /// Per-class feature means: one row per distinct label, one column per feature.
    means: Array2<f64>,
    /// Per-class feature variances (computed by `fit()` via `var_axis(Axis(0), 1.0)`),
    /// laid out like `means`.
    vars: Array2<f64>,
    /// Share of the training records that carried each label.
    priors: Array1<f64>,
    /// Distinct labels, in the order they were first encountered during fitting.
    labels: Vec<Label>,
}
50
51impl<Input: Estimatable, Label: PartialEq + Clone> Estimator<ClassificationDataSet<Input, Label>>
52    for GaussianNBEstimator
53{
54    type Estimator = GaussianNB<Input, Label>;
55
56    fn fit(&self, input: &ClassificationDataSet<Input, Label>) -> Option<Self::Estimator> {
57        let distinct_labels: Vec<_> =
58            input
59                .get_labels()
60                .into_iter()
61                .fold(vec![], |mut agg, curr| {
62                    if agg.contains(curr) {
63                        agg
64                    } else {
65                        agg.push(curr.clone());
66                        agg
67                    }
68                });
69
70        let features_vec: Vec<_> = input
71            .get_features()
72            .iter()
73            .map(|i| i.prepare_for_estimation())
74            .collect();
75
76        let nrows = features_vec.len();
77        let nfeatures = features_vec.first()?.len();
78
79        let flat_shapes: Vec<f64> = features_vec
80            .iter()
81            .flat_map(|record| record.into_iter())
82            .copied()
83            .collect();
84
85        let features = Array2::from_shape_vec((nrows, nfeatures), flat_shapes).ok()?;
86
87        let mut means = Array2::zeros((distinct_labels.len(), nfeatures));
88        let mut vars = Array2::zeros((distinct_labels.len(), nfeatures));
89        let mut priors = Array1::zeros(distinct_labels.len());
90
91        for (idx, label) in distinct_labels.iter().enumerate() {
92            let indeces: Vec<usize> = input
93                .get_labels()
94                .into_iter()
95                .enumerate()
96                .filter_map(|(idx, l)| match l == label {
97                    true => Some(idx),
98                    false => None,
99                })
100                .collect();
101
102            let filtered_view = features.select(Axis(0), &indeces);
103            let c = filtered_view.nrows();
104
105            means
106                .row_mut(idx)
107                .assign(&filtered_view.mean_axis(Axis(0))?);
108            vars.row_mut(idx)
109                .assign(&filtered_view.var_axis(Axis(0), 1.0));
110            priors[idx] = c as f64 / nrows as f64;
111        }
112
113        Some(GaussianNB {
114            _input: PhantomData,
115            labels: distinct_labels,
116            means,
117            vars,
118            priors,
119        })
120    }
121}
122
123impl<Input: Estimatable, Label: Clone> Classifier<Input, Label> for GaussianNB<Input, Label> {
124    fn labels(&self) -> &[Label] {
125        &self.labels
126    }
127
128    fn predict_proba<I>(&self, arr: I) -> Option<Array2<f64>>
129    where
130        I: Iterator<Item = Input>,
131    {
132        let col_count = self.labels.len();
133
134        let likelihoods: Vec<_> = arr
135            .map(|record| {
136                let arr_record = record.prepare_for_estimation();
137                let mut log_likelihood = -0.5 * (&self.vars.view() * 2.0 * PI).ln();
138                log_likelihood =
139                    log_likelihood - 0.5 * ((arr_record - &self.means).pow2() / self.vars.view());
140
141                log_likelihood = log_likelihood + self.priors.view().insert_axis(Axis(1)).ln();
142
143                let likelihood = log_likelihood.sum_axis(Axis(1)).exp().to_owned();
144
145                &likelihood / likelihood.sum()
146            })
147            .flat_map(|likelihoods| likelihoods.into_iter())
148            .collect();
149
150        let row_count = likelihoods.len() / col_count;
151
152        Array2::from_shape_vec((row_count, col_count), likelihoods).ok()
153    }
154}