rs_ml/classification/
naive_bayes.rs1use crate::{Axis, Estimatable, Estimator};
4use core::f64;
5use ndarray::{Array1, Array2};
6use std::{f64::consts::PI, marker::PhantomData};
7
8use super::{ClassificationDataSet, Classifier};
9
/// Stateless marker type whose `Estimator` implementation fits a Gaussian
/// naive Bayes classifier to a labelled data set.
///
/// The type carries no configuration, so it is freely copyable.
#[derive(Clone, Copy, Debug)]
pub struct GaussianNBEstimator;
40
41#[derive(Debug)]
43pub struct GaussianNB<Input, Label> {
44 _input: PhantomData<Input>,
45 means: Array2<f64>,
46 vars: Array2<f64>,
47 priors: Array1<f64>,
48 labels: Vec<Label>,
49}
50
51impl<Input: Estimatable, Label: PartialEq + Clone> Estimator<ClassificationDataSet<Input, Label>>
52 for GaussianNBEstimator
53{
54 type Estimator = GaussianNB<Input, Label>;
55
56 fn fit(&self, input: &ClassificationDataSet<Input, Label>) -> Option<Self::Estimator> {
57 let distinct_labels: Vec<_> =
58 input
59 .get_labels()
60 .into_iter()
61 .fold(vec![], |mut agg, curr| {
62 if agg.contains(curr) {
63 agg
64 } else {
65 agg.push(curr.clone());
66 agg
67 }
68 });
69
70 let features_vec: Vec<_> = input
71 .get_features()
72 .iter()
73 .map(|i| i.prepare_for_estimation())
74 .collect();
75
76 let nrows = features_vec.len();
77 let nfeatures = features_vec.first()?.len();
78
79 let flat_shapes: Vec<f64> = features_vec
80 .iter()
81 .flat_map(|record| record.into_iter())
82 .copied()
83 .collect();
84
85 let features = Array2::from_shape_vec((nrows, nfeatures), flat_shapes).ok()?;
86
87 let mut means = Array2::zeros((distinct_labels.len(), nfeatures));
88 let mut vars = Array2::zeros((distinct_labels.len(), nfeatures));
89 let mut priors = Array1::zeros(distinct_labels.len());
90
91 for (idx, label) in distinct_labels.iter().enumerate() {
92 let indeces: Vec<usize> = input
93 .get_labels()
94 .into_iter()
95 .enumerate()
96 .filter_map(|(idx, l)| match l == label {
97 true => Some(idx),
98 false => None,
99 })
100 .collect();
101
102 let filtered_view = features.select(Axis(0), &indeces);
103 let c = filtered_view.nrows();
104
105 means
106 .row_mut(idx)
107 .assign(&filtered_view.mean_axis(Axis(0))?);
108 vars.row_mut(idx)
109 .assign(&filtered_view.var_axis(Axis(0), 1.0));
110 priors[idx] = c as f64 / nrows as f64;
111 }
112
113 Some(GaussianNB {
114 _input: PhantomData,
115 labels: distinct_labels,
116 means,
117 vars,
118 priors,
119 })
120 }
121}
122
123impl<Input: Estimatable, Label: Clone> Classifier<Input, Label> for GaussianNB<Input, Label> {
124 fn labels(&self) -> &[Label] {
125 &self.labels
126 }
127
128 fn predict_proba<I>(&self, arr: I) -> Option<Array2<f64>>
129 where
130 I: Iterator<Item = Input>,
131 {
132 let col_count = self.labels.len();
133
134 let likelihoods: Vec<_> = arr
135 .map(|record| {
136 let arr_record = record.prepare_for_estimation();
137 let mut log_likelihood = -0.5 * (&self.vars.view() * 2.0 * PI).ln();
138 log_likelihood =
139 log_likelihood - 0.5 * ((arr_record - &self.means).pow2() / self.vars.view());
140
141 log_likelihood = log_likelihood + self.priors.view().insert_axis(Axis(1)).ln();
142
143 let likelihood = log_likelihood.sum_axis(Axis(1)).exp().to_owned();
144
145 &likelihood / likelihood.sum()
146 })
147 .flat_map(|likelihoods| likelihoods.into_iter())
148 .collect();
149
150 let row_count = likelihoods.len() / col_count;
151
152 Array2::from_shape_vec((row_count, col_count), likelihoods).ok()
153 }
154}