ducky_learn/
naive_bayes.rs

1use super::util::{Fit, Unfit};
2use std::collections::{HashMap, HashSet};
3
4/// Implementation of a standard Naive Bayes classifier.
5///
6/// This classifier uses Laplace smoothing, the degree of which can be controlled with the `alpha` parameter.
7///
8/// # Parameters
9/// - `alpha`: The Laplace smoothing factor.
10/// - `probability_of_class`: HashMap storing the probabilities of each class.
11/// - `probability_of_feat_by_class`: HashMap storing the probabilities of each feature given a class.
12/// - `state`: PhantomData indicating whether the classifier has been fit.
13///
14/// # Type parameters
15/// - `State`: Indicates whether the classifier has been fit. Can either be `Fit` or `Unfit`.
16///
17/// # Example
18///
19/// ```
20/// use ducky_learn::naive_bayes::StdNaiveBayes;
21///
22/// // Define train and test data
23/// let x_train: Vec<Vec<f64>> = vec![
24///     vec![1.0, 2.0, 3.0],
25///     vec![2.0, 3.0, 4.0],
26///     vec![3.0, 4.0, 5.0],
27/// ];
28/// let y_train: Vec<String> = vec!["class1".to_string(), "class2".to_string(), "class1".to_string()];
29///
30/// let x_test: Vec<Vec<f64>> = vec![
31///     vec![1.5, 2.5, 3.5],
32///     vec![2.5, 3.5, 4.5],
33/// ];
34///
35/// let mut nb = StdNaiveBayes::new(1.0);
36/// let nb = nb.fit(&x_train, &y_train);
37/// let y_pred = nb.predict(&x_test);
38///
39/// // y_pred will hold the predicted classes for x_test
40/// ```
41#[derive(Debug)]
42pub struct StdNaiveBayes<State = Unfit> {
43    pub alpha: f64,
44    pub probability_of_class: HashMap<String, f64>,
45    pub probability_of_feat_by_class: HashMap<String, HashMap<String, f64>>,
46
47    state: std::marker::PhantomData<State>,
48}
49
50impl StdNaiveBayes {
51    /// Constructs a new, unfitted `StdNaiveBayes` classifier with a specified alpha value.
52    ///
53    /// # Parameters
54    /// - `alpha`: The Laplace smoothing factor.
55    ///
56    /// # Returns
57    /// A new `StdNaiveBayes` instance.
58    pub fn new(alpha: f64) -> Self {
59        Self {
60            alpha,
61            probability_of_class: Default::default(),
62            probability_of_feat_by_class: Default::default(),
63
64            state: Default::default(),
65        }
66    }
67
68    /// Fits the `StdNaiveBayes` classifier to the training data.
69    ///
70    /// # Parameters
71    /// - `x`: The training data.
72    /// - `y`: The target values.
73    ///
74    /// # Returns
75    /// The fitted `StdNaiveBayes` classifier.
76    pub fn fit(mut self, x: &Vec<Vec<f64>>, y: &Vec<String>) -> StdNaiveBayes<Fit> {
77        let mut y_counts: HashMap<String, i32> = HashMap::new();
78        for class in y {
79            let counter = y_counts.entry(class.to_string()).or_insert(0);
80            *counter += 1;
81        }
82
83        let total_rows = y.len() as f64;
84        let unique_classes: HashSet<String> = y.into_iter().cloned().collect();
85
86        for uniq_class in &unique_classes {
87            self.probability_of_class.insert(
88                uniq_class.to_string(),
89                *y_counts.get(uniq_class).unwrap() as f64 / total_rows,
90            );
91
92            let mut class_feat_probs: HashMap<String, f64> = HashMap::new();
93            let mut sum_of_feats_in_class = 0.0;
94            for (i, class) in y.iter().enumerate() {
95                if class == uniq_class {
96                    for (j, feat_count) in x[i].iter().enumerate() {
97                        let counter = class_feat_probs.entry(j.to_string()).or_insert(0.0);
98                        *counter += *feat_count;
99                        sum_of_feats_in_class += *feat_count;
100                    }
101                }
102            }
103            sum_of_feats_in_class += self.alpha * x[0].len() as f64;
104
105            for (feat, count) in class_feat_probs.iter_mut() {
106                *count = (*count + self.alpha) / sum_of_feats_in_class;
107            }
108
109            self.probability_of_feat_by_class
110                .insert(uniq_class.to_string(), class_feat_probs);
111        }
112
113        StdNaiveBayes {
114            alpha: self.alpha,
115            probability_of_class: self.probability_of_class.clone(),
116            probability_of_feat_by_class: self.probability_of_feat_by_class.clone(),
117
118            state: std::marker::PhantomData::<Fit>,
119        }
120    }
121}
122
123impl StdNaiveBayes<Fit> {
124    /// Predicts the target values for the given data.
125    ///
126    /// # Parameters
127    /// - `x`: The data to predict target values for.
128    ///
129    /// # Returns
130    /// The predicted target values.
131    ///
132    /// # Panics
133    /// This function will panic if the classifier has not been fit.
134    pub fn predict(&self, x: &Vec<Vec<f64>>) -> Vec<String> {
135        let mut y_pred: Vec<String> = Vec::new();
136        let unique_classes: Vec<String> = self.probability_of_class.keys().cloned().collect();
137        let class_probabilities: Vec<f64> = self.probability_of_class.values().cloned().collect();
138        let small_number = 1e-9;
139
140        for row in x {
141            let mut row_probabilities: Vec<f64> = Vec::new();
142            for (i, class) in unique_classes.iter().enumerate() {
143                let mut log_sum = (class_probabilities[i] + small_number).ln();
144                for (j, feat_count) in row.iter().enumerate() {
145                    if *feat_count > 0.0 {
146                        let prob = self
147                            .probability_of_feat_by_class
148                            .get(class)
149                            .unwrap()
150                            .get(&j.to_string())
151                            .unwrap();
152                        log_sum += (*feat_count * (*prob + small_number).ln());
153                    }
154                }
155                row_probabilities.push(log_sum);
156            }
157
158            let max_prob_index = row_probabilities
159                .iter()
160                .enumerate()
161                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
162                .unwrap()
163                .0;
164            y_pred.push(unique_classes[max_prob_index].to_string());
165        }
166
167        y_pred
168    }
169}
170
171/// The `GaussianNaiveBayes` struct represents a Gaussian Naive Bayes classifier.
172///
173/// A Gaussian Naive Bayes classifier is a type of probabilistic machine learning model
174/// used for classification tasks. The Gaussian version assumes the features that it is
175/// learning from are distributed normally.
176///
177/// This struct has two possible states: `Unfit` and `Fit`. An `Unfit` model is one
178/// that has not yet been trained on data, while a `Fit` model has been trained and
179/// can be used for making predictions.
180///
181/// # Fields
182///
183/// * `classes` - A vector of unique class labels (targets) that the model may predict.
184///
185/// * `probability_of_class` - A hashmap where keys are the class labels and the values
186///   are the corresponding prior probabilities of each class.
187///
188/// * `probability_of_feat_by_class` - A hashmap where keys are the class labels and
189///   the values are vectors of tuples. Each tuple represents the mean and standard
190///   deviation of a particular feature for that class.
191///
192/// * `state` - A marker for the model's state. This is either `Unfit` (for a newly
193///   instantiated model) or `Fit` (for a model that has been trained on data).
194///
195/// # Examples
196///
197/// ```
198/// use std::collections::HashMap;
199/// use ducky_learn::naive_bayes::GaussianNaiveBayes;
200///
201/// let model = GaussianNaiveBayes::new();
202///
203/// let x_train: Vec<Vec<f64>> = vec![
204///     vec![1.0, 2.0],
205///     vec![2.0, 3.0],
206///     vec![3.0, 4.0],
207///     vec![4.0, 5.0],
208/// ];
209/// let y_train: Vec<String> = vec![
210///     "class1".to_string(),
211///     "class2".to_string(),
212///     "class1".to_string(),
213///     "class2".to_string(),
214/// ];
215///
216/// let model = model.fit(&x_train, &y_train);
217///
218/// let x_test: Vec<Vec<f64>> = vec![
219///     vec![1.5, 2.5],
220///     vec![3.5, 4.5],
221/// ];
222///
223/// let predictions = model.predict(&x_test);
224///
225/// println!("{:?}", predictions);
226/// ```
227pub struct GaussianNaiveBayes<State = Unfit> {
228    pub classes: Vec<String>,
229    pub probability_of_class: HashMap<String, f64>,
230    pub probability_of_feat_by_class: HashMap<String, Vec<(f64, f64)>>,
231
232    state: std::marker::PhantomData<State>,
233}
234
235impl GaussianNaiveBayes {
236    /// Creates a new `GaussianNaiveBayes` instance with an `Unfit` state.
237    ///
238    /// # Returns
239    ///
240    /// * `Self` - A new instance of `GaussianNaiveBayes`.
241    ///
242    /// # Examples
243    ///
244    /// ```
245    /// use ducky_learn::naive_bayes::GaussianNaiveBayes;
246    /// let model = GaussianNaiveBayes::new();
247    /// ```
248    pub fn new() -> Self {
249        Self {
250            classes: Default::default(),
251            probability_of_class: Default::default(),
252            probability_of_feat_by_class: Default::default(),
253
254            state: Default::default(),
255        }
256    }
257
258    /// Fits the model on the provided dataset, updating the model's state to `Fit`.
259    ///
260    /// # Arguments
261    ///
262    /// * `x` - A reference to a vector of vectors, where each inner vector represents
263    ///   the features of a data point.
264    ///
265    /// * `y` - A reference to a vector of class labels for each data point in `x`.
266    ///
267    /// # Returns
268    ///
269    /// * `GaussianNaiveBayes<Fit>` - The same model instance with updated fields
270    ///   and state set to `Fit`.
271    ///
272    /// # Examples
273    ///
274    /// ```
275    /// use ducky_learn::naive_bayes::GaussianNaiveBayes;
276    ///
277    /// let model = GaussianNaiveBayes::new();
278    ///
279    /// let x_train: Vec<Vec<f64>> = vec![
280    ///     vec![1.0, 2.0],
281    ///     vec![2.0, 3.0],
282    ///     vec![3.0, 4.0],
283    ///     vec![4.0, 5.0],
284    /// ];
285    /// let y_train: Vec<String> = vec![
286    ///     "class1".to_string(),
287    ///     "class2".to_string(),
288    ///     "class1".to_string(),
289    ///     "class2".to_string(),
290    /// ];
291    ///
292    /// let model = model.fit(&x_train, &y_train);
293    /// ```
294    pub fn fit(mut self, x: &Vec<Vec<f64>>, y: &Vec<String>) -> GaussianNaiveBayes<Fit> {
295        let uniq_classes: Vec<String> = y
296            .clone()
297            .into_iter()
298            .collect::<HashSet<String>>()
299            .into_iter()
300            .collect::<Vec<String>>();
301
302        GaussianNaiveBayes {
303            probability_of_class: calculate_class_probability(&uniq_classes, y),
304            probability_of_feat_by_class: calculate_feature_probability(x, y, &uniq_classes),
305            classes: uniq_classes,
306
307            state: std::marker::PhantomData::<Fit>,
308        }
309    }
310}
311
312impl GaussianNaiveBayes<Fit> {
313    /// Predicts the class of the provided data points.
314    ///
315    /// # Arguments
316    ///
317    /// * `x` - A reference to a vector of vectors, where each inner vector represents
318    ///   the features of a data point.
319    ///
320    /// # Returns
321    ///
322    /// * `Vec<String>` - A vector of predicted class labels for each data point in `x`.
323    ///
324    /// # Examples
325    ///
326    /// ```
327    /// use ducky_learn::naive_bayes::GaussianNaiveBayes;
328    ///
329    /// let model = GaussianNaiveBayes::new().fit(
330    ///     &vec![vec![0.1, 0.5], vec![0.6, 0.6]],
331    ///     &vec!["class1".to_string(), "class2".to_string()]
332    /// );
333    ///
334    /// let x_test: Vec<Vec<f64>> = vec![
335    ///     vec![1.5, 2.5],
336    ///     vec![3.5, 4.5],
337    /// ];
338    ///
339    /// let predictions = model.predict(&x_test);
340    ///
341    /// println!("{:?}", predictions);
342    /// ```
343    pub fn predict(&self, x: &Vec<Vec<f64>>) -> Vec<String> {
344        let mut predictions: Vec<String> = Vec::new();
345
346        for data in x.iter() {
347            let mut max_prob = f64::NEG_INFINITY;
348            let mut max_class = String::from("");
349
350            for class in &self.classes {
351                let mut class_prob = self.probability_of_class.get(class).unwrap().ln();
352
353                if let Some(feature_probs) = self.probability_of_feat_by_class.get(class) {
354                    for (index, &(mean, std_dev)) in feature_probs.iter().enumerate() {
355                        let feature_value = data[index];
356                        let feature_prob = calculate_probability(feature_value, mean, std_dev);
357                        class_prob += feature_prob.ln();
358                    }
359                }
360
361                if class_prob > max_prob {
362                    max_prob = class_prob;
363                    max_class = class.clone();
364                }
365            }
366            predictions.push(max_class);
367        }
368
369        predictions
370    }
371}
372
/// Arithmetic mean of `data`.
///
/// Returns `NaN` for an empty slice (0.0 / 0.0).
fn calculate_mean(data: &[f64]) -> f64 {
    data.iter().sum::<f64>() / data.len() as f64
}
377
/// Population standard deviation of `data` around a precomputed `mean`.
///
/// The sum of squared deviations is divided by `n` (not `n - 1`). Returns `NaN`
/// for an empty slice.
fn calculate_std_dev(data: &[f64], mean: f64) -> f64 {
    let sum_of_squares: f64 = data
        .iter()
        .map(|&value| (value - mean) * (value - mean))
        .sum();

    (sum_of_squares / data.len() as f64).sqrt()
}
390
/// Normal probability density of `x` under `N(mean, std_dev²)`.
///
/// A tiny variance term (1e-9) is added so that a zero `std_dev` yields a finite
/// density instead of `NaN`. A zero `std_dev` arises, for example, from a feature
/// that is constant within a class; the raw formula would evaluate 0/0 and ∞·0.
/// For ordinary standard deviations the term is negligible.
fn calculate_probability(x: f64, mean: f64, std_dev: f64) -> f64 {
    const VAR_SMOOTHING: f64 = 1e-9;

    let variance = std_dev * std_dev + VAR_SMOOTHING;
    let exponent = -(x - mean).powi(2) / (2.0 * variance);
    exponent.exp() / (2.0 * std::f64::consts::PI * variance).sqrt()
}
395
/// Prior probability of each class in `uniq_classes`: its share of `all_classes`.
///
/// A class that never occurs in `all_classes` gets `0.0`. If `all_classes` is
/// empty, every probability is `NaN` (0 / 0).
fn calculate_class_probability(
    uniq_classes: &[String],
    all_classes: &[String],
) -> HashMap<String, f64> {
    let total = all_classes.len() as f64;

    // Count occurrences in one pass, borrowing the labels instead of cloning them.
    let mut class_counts: HashMap<&str, usize> = HashMap::new();
    for class in all_classes {
        *class_counts.entry(class.as_str()).or_insert(0) += 1;
    }

    uniq_classes
        .iter()
        .map(|class| {
            let count = class_counts.get(class.as_str()).copied().unwrap_or(0);
            (class.clone(), count as f64 / total)
        })
        .collect()
}
419
420fn calculate_feature_probability(
421    x: &Vec<Vec<f64>>,
422    y: &Vec<String>,
423    uniq_classes: &Vec<String>,
424) -> HashMap<String, Vec<(f64, f64)>> {
425    let mut return_feature_prob: HashMap<String, Vec<(f64, f64)>> = HashMap::new();
426
427    if x.len() != y.len() {
428        return HashMap::new();
429    }
430
431    for class in uniq_classes {
432        let x_class: Vec<_> = x
433            .iter()
434            .zip(y)
435            .filter_map(|(x, y)| if y == class { Some(x.clone()) } else { None })
436            .collect();
437
438        if x_class.is_empty() {
439            continue;
440        }
441
442        let num_features = x_class[0].len();
443
444        for i in 0..num_features {
445            let feature_values: Vec<_> = x_class.iter().map(|features| features[i]).collect();
446
447            // calculate the mean
448            let mean: f64 = feature_values.iter().sum::<f64>() / feature_values.len() as f64;
449
450            // calculate the standard deviation
451            let variance: f64 = feature_values
452                .iter()
453                .map(|value| {
454                    let diff = mean - *value;
455                    diff * diff
456                })
457                .sum::<f64>()
458                / feature_values.len() as f64;
459
460            let std_dev = variance.sqrt();
461
462            return_feature_prob
463                .entry(class.to_string())
464                .or_insert_with(|| Vec::with_capacity(num_features))
465                .push((mean, std_dev));
466        }
467    }
468
469    return_feature_prob
470}
471
#[cfg(test)]
mod calculation_functions_tests {
    use super::*;

    #[test]
    fn test_calculate_class_probability() {
        let uniq_classes = vec![
            "class1".to_string(),
            "class2".to_string(),
            "class3".to_string(),
        ];
        let all_classes = vec![
            "class1".to_string(),
            "class2".to_string(),
            "class2".to_string(),
            "class3".to_string(),
            "class3".to_string(),
            "class3".to_string(),
        ];
        let probabilities = calculate_class_probability(&uniq_classes, &all_classes);

        // `.abs()` matters: without it any value below the target would pass.
        assert!((probabilities["class1"] - 1.0 / 6.0).abs() < f64::EPSILON);
        assert!((probabilities["class2"] - 2.0 / 6.0).abs() < f64::EPSILON);
        assert!((probabilities["class3"] - 3.0 / 6.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_calculate_class_probability_unseen_class_is_zero() {
        let uniq_classes = vec!["class1".to_string(), "missing".to_string()];
        let all_classes = vec!["class1".to_string(), "class1".to_string()];
        let probabilities = calculate_class_probability(&uniq_classes, &all_classes);

        assert_eq!(probabilities["class1"], 1.0);
        assert_eq!(probabilities["missing"], 0.0);
    }

    #[test]
    fn test_calculate_class_probability_sum_to_one() {
        let uniq_classes = vec![
            "class1".to_string(),
            "class2".to_string(),
            "class3".to_string(),
        ];
        let all_classes = vec![
            "class1".to_string(),
            "class2".to_string(),
            "class2".to_string(),
            "class3".to_string(),
            "class3".to_string(),
            "class3".to_string(),
        ];
        let probabilities = calculate_class_probability(&uniq_classes, &all_classes);

        let sum: f64 = probabilities.values().sum();

        assert!((1.0 - sum).abs() < 1e-12);
    }

    #[test]
    fn test_calculate_feature_probability() {
        let uniq_classes = vec!["class1".to_string(), "class2".to_string()];
        let y = vec![
            "class1".to_string(),
            "class2".to_string(),
            "class1".to_string(),
            "class2".to_string(),
        ];
        let x = vec![
            vec![1.0, 2.0],
            vec![2.0, 2.0],
            vec![2.0, 3.0],
            vec![3.0, 3.0],
        ];

        let feature_probabilities = calculate_feature_probability(&x, &y, &uniq_classes);

        let class1_probabilities = &feature_probabilities["class1"];
        assert!((class1_probabilities[0].0 - 1.5).abs() < f64::EPSILON);
        assert!((class1_probabilities[0].1 - 0.5).abs() < f64::EPSILON);
        assert!((class1_probabilities[1].0 - 2.5).abs() < f64::EPSILON);
        assert!((class1_probabilities[1].1 - 0.5).abs() < f64::EPSILON);

        let class2_probabilities = &feature_probabilities["class2"];
        assert!((class2_probabilities[0].0 - 2.5).abs() < f64::EPSILON);
        assert!((class2_probabilities[0].1 - 0.5).abs() < f64::EPSILON);
        assert!((class2_probabilities[1].0 - 2.5).abs() < f64::EPSILON);
        assert!((class2_probabilities[1].1 - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_calculate_feature_probability_no_data() {
        let uniq_classes = vec!["class1".to_string(), "class2".to_string()];
        let y: Vec<String> = vec![];
        let x: Vec<Vec<f64>> = vec![];

        let feature_probabilities = calculate_feature_probability(&x, &y, &uniq_classes);

        assert!(feature_probabilities.is_empty());
    }

    #[test]
    fn test_calculate_feature_probability_same_feature_values() {
        let uniq_classes = vec!["class1".to_string(), "class2".to_string()];
        let y = vec![
            "class1".to_string(),
            "class1".to_string(),
            "class2".to_string(),
            "class2".to_string(),
        ];
        let x = vec![
            vec![2.0, 2.0],
            vec![2.0, 2.0],
            vec![2.0, 2.0],
            vec![2.0, 2.0],
        ];

        let feature_probabilities = calculate_feature_probability(&x, &y, &uniq_classes);

        for class in ["class1", "class2"] {
            let stats = &feature_probabilities[class];
            assert_eq!(stats.len(), 2);
            for &(mean, std_dev) in stats {
                assert!((mean - 2.0).abs() < f64::EPSILON);
                assert!(std_dev.abs() < f64::EPSILON);
            }
        }
    }

    #[test]
    fn test_calculate_feature_probability_mismatched_lengths() {
        let uniq_classes = vec!["class1".to_string(), "class2".to_string()];
        let y = vec!["class1".to_string(), "class2".to_string()];
        let x: Vec<Vec<f64>> = vec![];

        let feature_probabilities = calculate_feature_probability(&x, &y, &uniq_classes);

        assert!(feature_probabilities.is_empty());
    }

    #[test]
    fn test_calculate_mean() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(calculate_mean(&data), 3.0);
    }

    #[test]
    fn test_calculate_mean_empty_is_nan() {
        let data: Vec<f64> = Vec::new();
        assert!(calculate_mean(&data).is_nan());
    }

    #[test]
    fn test_calculate_std_dev() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mean = calculate_mean(&data);
        assert!((calculate_std_dev(&data, mean) - 1.414213).abs() < 0.00001);
    }

    #[test]
    fn test_calculate_probability() {
        let x = 2.0;
        let mean = 2.0;
        let std_dev = 1.0;
        assert!((calculate_probability(x, mean, std_dev) - 0.398942).abs() < 0.00001);
    }
}
631
#[cfg(test)]
mod naive_bayes_tests {
    use super::*;

    #[test]
    fn test_fit_std() {
        let model = StdNaiveBayes::new(1.0);

        let x: Vec<Vec<f64>> = vec![
            vec![1.0, 2.0, 3.0],
            vec![2.0, 3.0, 1.0],
            vec![3.0, 1.0, 2.0],
        ];

        let y: Vec<String> = vec![
            "class1".to_string(),
            "class2".to_string(),
            "class1".to_string(),
        ];

        let model = model.fit(&x, &y);

        assert!((model.probability_of_class["class1"] - 2.0 / 3.0).abs() < 1e-9);
        assert!((model.probability_of_class["class2"] - 1.0 / 3.0).abs() < 1e-9);

        // Laplace smoothing:
        // (column sum + alpha) / (class total + alpha * n_features)
        // class1 column sums [4, 3, 5], total 12 -> denominator 15.
        let class1 = &model.probability_of_feat_by_class["class1"];
        assert!((class1["0"] - 5.0 / 15.0).abs() < 1e-12);
        assert!((class1["1"] - 4.0 / 15.0).abs() < 1e-12);
        assert!((class1["2"] - 6.0 / 15.0).abs() < 1e-12);

        // class2 column sums [2, 3, 1], total 6 -> denominator 9.
        let class2 = &model.probability_of_feat_by_class["class2"];
        assert!((class2["0"] - 3.0 / 9.0).abs() < 1e-12);
        assert!((class2["1"] - 4.0 / 9.0).abs() < 1e-12);
        assert!((class2["2"] - 2.0 / 9.0).abs() < 1e-12);
    }

    #[test]
    fn test_predict_std() {
        let model = StdNaiveBayes::new(1.0);

        let x: Vec<Vec<f64>> = vec![
            vec![1.0, 2.0, 3.0, 1.0, 2.0],
            vec![2.0, 3.0, 4.0, 2.0, 3.0],
            vec![4.0, 4.0, 5.0, 4.0, 4.0],
            vec![5.0, 5.0, 6.0, 5.0, 5.0],
            vec![1.0, 1.0, 1.0, 1.0, 1.0],
        ];

        let y: Vec<String> = vec![
            "class1".to_string(),
            "class1".to_string(),
            "class2".to_string(),
            "class2".to_string(),
            "class1".to_string(),
        ];

        let model = model.fit(&x, &y);

        let x_test: Vec<Vec<f64>> =
            vec![vec![1.5, 2.5, 3.5, 1.5, 2.5], vec![5.5, 4.5, 5.5, 4.5, 4.5]];

        let predictions = model.predict(&x_test);

        assert_eq!(predictions, vec!["class1", "class2"]);
    }

    #[test]
    fn test_new_gaus() {
        let model: GaussianNaiveBayes = GaussianNaiveBayes::new();

        assert!(model.classes.is_empty());
        assert!(model.probability_of_class.is_empty());
        assert!(model.probability_of_feat_by_class.is_empty());
    }

    #[test]
    fn test_fit_gaus() {
        let model: GaussianNaiveBayes = GaussianNaiveBayes::new();
        let x = vec![
            vec![2.0, 1.0],
            vec![3.0, 2.0],
            vec![2.5, 1.5],
            vec![4.0, 3.0],
        ];
        let y = vec![
            "class1".to_string(),
            "class1".to_string(),
            "class2".to_string(),
            "class2".to_string(),
        ];
        let model = model.fit(&x, &y);

        assert_eq!(model.classes.len(), 2);
        assert!(model.classes.contains(&"class1".to_string()));
        assert!(model.classes.contains(&"class2".to_string()));

        assert_eq!(model.probability_of_class.len(), 2);
        assert!(model.probability_of_class.contains_key("class1"));
        assert!(model.probability_of_class.contains_key("class2"));

        assert_eq!(model.probability_of_feat_by_class.len(), 2);
        assert!(model.probability_of_feat_by_class.contains_key("class1"));
        assert!(model.probability_of_feat_by_class.contains_key("class2"));
    }

    #[test]
    fn test_predict_gaus() {
        let model: GaussianNaiveBayes = GaussianNaiveBayes::new();
        let x = vec![
            vec![2.0, 1.0],
            vec![3.0, 2.0],
            vec![2.5, 1.5],
            vec![4.0, 3.0],
        ];
        let y = vec![
            "class1".to_string(),
            "class1".to_string(),
            "class2".to_string(),
            "class2".to_string(),
        ];
        let model = model.fit(&x, &y);

        let x_test = vec![vec![2.0, 1.0], vec![4.0, 3.0]];

        let predictions = model.predict(&x_test);
        assert_eq!(predictions.len(), x_test.len());
        assert_eq!(predictions[0], "class1");
        assert_eq!(predictions[1], "class2");
    }
}