rusty_ai/bayes/gaussian.rs

use crate::{
    data::dataset::{Dataset, RealNumber, WholeNumber},
    metrics::confusion::ClassificationMetrics,
};
use nalgebra::{DMatrix, DVector};
use std::{
    collections::{HashMap, HashSet},
    error::Error,
};

/// Implementation of a Gaussian Naive Bayes classifier.
///
/// This struct represents a Gaussian Naive Bayes classifier. It is used to fit a training dataset
/// and make predictions on new data points. The classifier assumes that the features are
/// conditionally independent given the class and that each feature follows a Gaussian
/// distribution within each class.
///
/// # Type Parameters
///
/// * `XT`: The type of the input features.
/// * `YT`: The type of the target labels.
///
/// # Example
///
/// ```
/// use rusty_ai::bayes::gaussian::GaussianNB;
/// use rusty_ai::data::dataset::Dataset;
/// use nalgebra::{DMatrix, DVector};
///
/// // Create a new Gaussian Naive Bayes classifier
/// let mut classifier = GaussianNB::new();
///
/// // Create a training dataset with two samples per class, so every class has a
/// // well-defined sample variance
/// let x = DMatrix::from_row_slice(4, 2, &[1.0, 2.0, 1.5, 2.5, 7.0, 8.0, 7.5, 8.5]);
/// let y = DVector::from_vec(vec![0, 0, 1, 1]);
/// let dataset = Dataset::new(x, y);
///
/// // Fit the classifier to the training dataset
/// let result = classifier.fit(&dataset);
/// assert!(result.is_ok());
///
/// // Make predictions on new data points
/// let x_test = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 7.0, 8.0]);
/// let predictions = classifier.predict(&x_test);
/// assert!(predictions.is_ok());
/// ```
#[derive(Clone, Debug)]
pub struct GaussianNB<XT: RealNumber, YT: WholeNumber> {
    class_freq: HashMap<YT, XT>,
    class_mean: HashMap<YT, DVector<XT>>,
    class_variance: HashMap<YT, DVector<XT>>,
}

impl<XT: RealNumber, YT: WholeNumber> ClassificationMetrics<YT> for GaussianNB<XT, YT> {}

impl<XT: RealNumber, YT: WholeNumber> Default for GaussianNB<XT, YT> {
    fn default() -> Self {
        Self::new()
    }
}

impl<XT: RealNumber, YT: WholeNumber> GaussianNB<XT, YT> {
    /// Creates a new Gaussian Naive Bayes classifier.
    ///
    /// This function initializes the classifier with empty class frequency, mean, and variance
    /// maps.
    ///
    /// # Returns
    ///
    /// A new instance of `GaussianNB`.
    pub fn new() -> Self {
        Self {
            class_freq: HashMap::new(),
            class_mean: HashMap::new(),
            class_variance: HashMap::new(),
        }
    }

    /// Returns a reference to the class frequency map.
    ///
    /// This function returns a reference to the map that stores the relative frequency (the
    /// empirical prior) of each class in the training dataset.
    ///
    /// # Returns
    ///
    /// A reference to the class frequency map.
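    ///
    /// # Example
    ///
    /// A minimal sketch (four training samples, two per class) showing that the stored values
    /// are relative class frequencies:
    ///
    /// ```
    /// use rusty_ai::bayes::gaussian::GaussianNB;
    /// use rusty_ai::data::dataset::Dataset;
    /// use nalgebra::{DMatrix, DVector};
    ///
    /// let mut clf = GaussianNB::<f64, i32>::new();
    /// let x = DMatrix::from_row_slice(4, 1, &[1.0, 2.0, 7.0, 8.0]);
    /// let y = DVector::from_vec(vec![0, 0, 1, 1]);
    /// clf.fit(&Dataset::new(x, y)).unwrap();
    ///
    /// let freq = *clf.class_freq().get(&0).unwrap();
    /// assert!((freq - 0.5).abs() < 1e-9);
    /// ```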
    pub fn class_freq(&self) -> &HashMap<YT, XT> {
        &self.class_freq
    }

    /// Returns a reference to the class mean map.
    ///
    /// This function returns a reference to the map that stores the mean values of each feature
    /// for each class in the training dataset.
    ///
    /// # Returns
    ///
    /// A reference to the class mean map.
    pub fn class_mean(&self) -> &HashMap<YT, DVector<XT>> {
        &self.class_mean
    }

    /// Returns a reference to the class variance map.
    ///
    /// This function returns a reference to the map that stores the variance values of each
    /// feature for each class in the training dataset.
    ///
    /// # Returns
    ///
    /// A reference to the class variance map.
    pub fn class_variance(&self) -> &HashMap<YT, DVector<XT>> {
        &self.class_variance
    }

    /// Fits the classifier to a training dataset.
    ///
    /// This function fits the classifier to the provided training dataset. It calculates the
    /// class frequency, mean, and variance for each class in the dataset.
    ///
    /// # Arguments
    ///
    /// * `dataset` - The training dataset to fit the classifier to.
    ///
    /// # Returns
    ///
    /// A `Result` indicating whether the fitting process was successful or an error occurred.
    pub fn fit(&mut self, dataset: &Dataset<XT, YT>) -> Result<String, Box<dyn Error>> {
        let (x, y) = dataset.into_parts();
        let classes = y.iter().cloned().collect::<HashSet<_>>();

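        // For every class in the training targets, estimate its relative frequency (the prior)
        // and the per-feature mean and variance of the rows belonging to that class.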
        for class in classes {
            let class_mask = y.map(|label| label == class);
            let class_indices = class_mask
                .iter()
                .enumerate()
                .filter(|&(_, &value)| value)
                .map(|(index, _)| index)
                .collect::<Vec<_>>();
            let x_class = x.select_rows(class_indices.as_slice());

            let mean = DVector::from_fn(x_class.ncols(), |col, _| {
                self.mean(&x_class.column(col).into_owned())
            });
            let variance = DVector::from_fn(x_class.ncols(), |col, _| {
                self.variance(&x_class.column(col).into_owned())
            });

            let freq =
                XT::from_usize(class_indices.len()).unwrap() / XT::from_usize(x.nrows()).unwrap();

            self.class_freq.insert(class, freq);
            self.class_mean.insert(class, mean);
            self.class_variance.insert(class, variance);
        }
        Ok("Finished fitting".into())
    }

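    /// Arithmetic mean of a single feature column.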
    fn mean(&self, x: &DVector<XT>) -> XT {
        let zero = XT::from_f64(0.0).unwrap();
        let sum: XT = x.fold(zero, |acc, x| acc + x);

        sum / XT::from_usize(x.len()).unwrap()
    }

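    /// Unbiased sample variance of a single feature column (divides by `n - 1`), so a column
    /// with a single sample does not have a well-defined variance.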
    fn variance(&self, x: &DVector<XT>) -> XT {
        let mean = self.mean(x);
        let zero = XT::from_f64(0.0).unwrap();
        let numerator = x.fold(zero, |acc, x| acc + (x - mean) * (x - mean));

        numerator / XT::from_usize(x.len() - 1).unwrap()
    }

    /// Computes the (unnormalized) log-posterior of every class for a single sample and
    /// returns the class with the highest score.
    fn predict_single(&self, x: &DVector<XT>) -> Result<YT, Box<dyn Error>> {
        let mut max_log_likelihood = XT::from_f64(f64::NEG_INFINITY).unwrap();
        let mut max_class = YT::from_i8(0).unwrap();

        for class in self.class_freq.keys() {
            let mean = self
                .class_mean
                .get(class)
                .ok_or(format!("Mean for class {:?} wasn't calculated.", class))?;
            let variance = self
                .class_variance
                .get(class)
                .ok_or(format!("Variance for class {:?} wasn't calculated.", class))?;
            // Small constant added to every variance so that constant features
            // (zero variance) don't cause a division by zero or `ln(0)`.
            let variance_epsilon =
                DVector::<XT>::from_element(variance.len(), XT::from_f64(1e-9).unwrap());
            let smoothed_variance = variance + &variance_epsilon;

            // log p(class | x) ∝ -0.5 * Σ (x_i - μ_i)² / σ_i² - 0.5 * Σ ln(σ_i²) + ln(prior);
            // the constant -0.5 * d * ln(2π) is identical for every class and is dropped.
            let neg_half = XT::from_f64(-0.5).unwrap();
            let log_likelihood = neg_half
                * (x - mean)
                    .component_mul(&(x - mean))
                    .component_div(&smoothed_variance)
                    .sum()
                + neg_half * smoothed_variance.map(|v| v.ln()).sum()
                + self
                    .class_freq
                    .get(class)
                    .ok_or(format!("Frequency of class {:?} wasn't obtained.", class))?
                    .ln();

            if log_likelihood > max_log_likelihood {
                max_log_likelihood = log_likelihood;
                max_class = *class;
            }
        }
        Ok(max_class)
    }

    /// Predicts the class labels for a given matrix of feature vectors.
    ///
    /// This function predicts a class label for every row of the input matrix (one sample per
    /// row) using the fitted Gaussian Naive Bayes classifier and returns the labels as a
    /// vector, in the same order as the rows.
    ///
    /// # Arguments
    ///
    /// * `x` - The matrix of feature vectors to predict the class labels for, one sample per row.
    ///
    /// # Returns
    ///
    /// A `Result` containing the vector of predicted class labels or an error if the prediction
    /// process fails.
    pub fn predict(&self, x: &DMatrix<XT>) -> Result<DVector<YT>, Box<dyn Error>> {
        let mut y_pred = Vec::new();

        for i in 0..x.nrows() {
            let x_row = x.row(i).into_owned().transpose();
            let class = self.predict_single(&x_row)?;
            y_pred.push(class);
        }

        Ok(DVector::from_vec(y_pred))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_new() {
        let clf = GaussianNB::<f64, i32>::new();

        assert!(clf.class_freq.is_empty());
        assert!(clf.class_mean.is_empty());
        assert!(clf.class_variance.is_empty());
    }

    #[test]
    fn test_model_fit() {
        let mut clf = GaussianNB::<f64, i32>::new();

        let x = DMatrix::from_row_slice(
            4,
            3,
            &[
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        );
        let y = DVector::from_column_slice(&[0, 0, 1, 1]);
        let dataset = Dataset::new(x, y);

        let _ = clf.fit(&dataset);

        assert_abs_diff_eq!(*clf.class_freq.get(&0).unwrap(), 0.5, epsilon = 1e-7);
        assert_abs_diff_eq!(*clf.class_freq.get(&1).unwrap(), 0.5, epsilon = 1e-7);
    }

    #[test]
    fn test_predictions() {
        let mut clf = GaussianNB::<f64, i32>::new();

        let x = DMatrix::from_row_slice(
            4,
            3,
            &[
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        );
        let y = DVector::from_column_slice(&[0, 0, 1, 1]);
        let dataset = Dataset::new(x, y);

        let _ = clf.fit(&dataset);

        let test_x = DMatrix::from_row_slice(2, 3, &[2.0, 3.0, 4.0, 6.0, 7.0, 8.0]);

        let pred_y = clf.predict(&test_x).unwrap();

        assert_eq!(pred_y, DVector::from_column_slice(&[0, 1]));
    }

    #[test]
    fn test_empty_data() {
        let mut clf = GaussianNB::<f64, i32>::new();
        let empty_x = DMatrix::<f64>::zeros(0, 0);
        let empty_y = DVector::<i32>::zeros(0);
        let empty_pred_y = clf.predict(&empty_x).unwrap();
        assert_eq!(empty_pred_y.len(), 0);
        let dataset = Dataset::new(empty_x, empty_y);

        let _ = clf.fit(&dataset);
        assert_eq!(clf.class_freq.len(), 0);
        assert_eq!(clf.class_mean.len(), 0);
        assert_eq!(clf.class_variance.len(), 0);
    }

    #[test]
    fn test_single_class() {
        let mut clf = GaussianNB::<f64, i32>::new();

        let x = DMatrix::from_row_slice(3, 2, &[1.0, 2.0, 2.0, 3.0, 3.0, 4.0]);
        let y = DVector::from_column_slice(&[0, 0, 0]);
        let dataset = Dataset::new(x, y);

        let _ = clf.fit(&dataset);

        assert_eq!(clf.class_freq.len(), 1);
        assert_eq!(clf.class_mean.len(), 1);
        assert_eq!(clf.class_variance.len(), 1);

        let test_x = DMatrix::from_row_slice(2, 2, &[1.5, 2.5, 2.5, 3.5]);

        let pred_y = clf.predict(&test_x).unwrap();

        assert_eq!(pred_y, DVector::from_column_slice(&[0, 0]));
    }

    #[test]
    fn test_predict_with_constant_feature() {
        let mut clf = GaussianNB::<f64, i32>::new();

        let x = DMatrix::from_row_slice(4, 2, &[0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
        let y = DVector::from_vec(vec![0, 0, 1, 1]);

        let x_new = DMatrix::from_row_slice(2, 2, &[0.0, 1.0, 1.0, 1.0]);
        let dataset = Dataset::new(x, y);

        let _ = clf.fit(&dataset);

        let y_hat = clf.predict(&x_new).unwrap();

        assert_eq!(y_hat.len(), 2);
        assert_eq!(y_hat[0], 0);
        assert_eq!(y_hat[1], 1);
    }

    #[test]
    fn test_gaussian_nb() {
        let mut clf = GaussianNB::<f64, i32>::new();

        let x = DMatrix::from_row_slice(
            4,
            3,
            &[
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        );
        let y = DVector::from_column_slice(&[0, 0, 1, 1]);
        let dataset = Dataset::new(x, y);

        let _ = clf.fit(&dataset);

        assert_abs_diff_eq!(*clf.class_freq.get(&0).unwrap(), 0.5, epsilon = 1e-7);
        assert_abs_diff_eq!(*clf.class_freq.get(&1).unwrap(), 0.5, epsilon = 1e-7);

        let test_x = DMatrix::from_row_slice(2, 3, &[2.0, 3.0, 4.0, 6.0, 7.0, 8.0]);

        let pred_y = clf.predict(&test_x).unwrap();

        assert_eq!(pred_y, DVector::from_column_slice(&[0, 1]));
    }
}