1use super::util::{Fit, Unfit};
2use std::collections::{HashMap, HashSet};
3
4#[derive(Debug)]
42pub struct StdNaiveBayes<State = Unfit> {
43 pub alpha: f64,
44 pub probability_of_class: HashMap<String, f64>,
45 pub probability_of_feat_by_class: HashMap<String, HashMap<String, f64>>,
46
47 state: std::marker::PhantomData<State>,
48}
49
50impl StdNaiveBayes {
51 pub fn new(alpha: f64) -> Self {
59 Self {
60 alpha,
61 probability_of_class: Default::default(),
62 probability_of_feat_by_class: Default::default(),
63
64 state: Default::default(),
65 }
66 }
67
68 pub fn fit(mut self, x: &Vec<Vec<f64>>, y: &Vec<String>) -> StdNaiveBayes<Fit> {
77 let mut y_counts: HashMap<String, i32> = HashMap::new();
78 for class in y {
79 let counter = y_counts.entry(class.to_string()).or_insert(0);
80 *counter += 1;
81 }
82
83 let total_rows = y.len() as f64;
84 let unique_classes: HashSet<String> = y.into_iter().cloned().collect();
85
86 for uniq_class in &unique_classes {
87 self.probability_of_class.insert(
88 uniq_class.to_string(),
89 *y_counts.get(uniq_class).unwrap() as f64 / total_rows,
90 );
91
92 let mut class_feat_probs: HashMap<String, f64> = HashMap::new();
93 let mut sum_of_feats_in_class = 0.0;
94 for (i, class) in y.iter().enumerate() {
95 if class == uniq_class {
96 for (j, feat_count) in x[i].iter().enumerate() {
97 let counter = class_feat_probs.entry(j.to_string()).or_insert(0.0);
98 *counter += *feat_count;
99 sum_of_feats_in_class += *feat_count;
100 }
101 }
102 }
103 sum_of_feats_in_class += self.alpha * x[0].len() as f64;
104
105 for (feat, count) in class_feat_probs.iter_mut() {
106 *count = (*count + self.alpha) / sum_of_feats_in_class;
107 }
108
109 self.probability_of_feat_by_class
110 .insert(uniq_class.to_string(), class_feat_probs);
111 }
112
113 StdNaiveBayes {
114 alpha: self.alpha,
115 probability_of_class: self.probability_of_class.clone(),
116 probability_of_feat_by_class: self.probability_of_feat_by_class.clone(),
117
118 state: std::marker::PhantomData::<Fit>,
119 }
120 }
121}
122
123impl StdNaiveBayes<Fit> {
124 pub fn predict(&self, x: &Vec<Vec<f64>>) -> Vec<String> {
135 let mut y_pred: Vec<String> = Vec::new();
136 let unique_classes: Vec<String> = self.probability_of_class.keys().cloned().collect();
137 let class_probabilities: Vec<f64> = self.probability_of_class.values().cloned().collect();
138 let small_number = 1e-9;
139
140 for row in x {
141 let mut row_probabilities: Vec<f64> = Vec::new();
142 for (i, class) in unique_classes.iter().enumerate() {
143 let mut log_sum = (class_probabilities[i] + small_number).ln();
144 for (j, feat_count) in row.iter().enumerate() {
145 if *feat_count > 0.0 {
146 let prob = self
147 .probability_of_feat_by_class
148 .get(class)
149 .unwrap()
150 .get(&j.to_string())
151 .unwrap();
152 log_sum += (*feat_count * (*prob + small_number).ln());
153 }
154 }
155 row_probabilities.push(log_sum);
156 }
157
158 let max_prob_index = row_probabilities
159 .iter()
160 .enumerate()
161 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
162 .unwrap()
163 .0;
164 y_pred.push(unique_classes[max_prob_index].to_string());
165 }
166
167 y_pred
168 }
169}
170
171pub struct GaussianNaiveBayes<State = Unfit> {
228 pub classes: Vec<String>,
229 pub probability_of_class: HashMap<String, f64>,
230 pub probability_of_feat_by_class: HashMap<String, Vec<(f64, f64)>>,
231
232 state: std::marker::PhantomData<State>,
233}
234
235impl GaussianNaiveBayes {
236 pub fn new() -> Self {
249 Self {
250 classes: Default::default(),
251 probability_of_class: Default::default(),
252 probability_of_feat_by_class: Default::default(),
253
254 state: Default::default(),
255 }
256 }
257
258 pub fn fit(mut self, x: &Vec<Vec<f64>>, y: &Vec<String>) -> GaussianNaiveBayes<Fit> {
295 let uniq_classes: Vec<String> = y
296 .clone()
297 .into_iter()
298 .collect::<HashSet<String>>()
299 .into_iter()
300 .collect::<Vec<String>>();
301
302 GaussianNaiveBayes {
303 probability_of_class: calculate_class_probability(&uniq_classes, y),
304 probability_of_feat_by_class: calculate_feature_probability(x, y, &uniq_classes),
305 classes: uniq_classes,
306
307 state: std::marker::PhantomData::<Fit>,
308 }
309 }
310}
311
312impl GaussianNaiveBayes<Fit> {
313 pub fn predict(&self, x: &Vec<Vec<f64>>) -> Vec<String> {
344 let mut predictions: Vec<String> = Vec::new();
345
346 for data in x.iter() {
347 let mut max_prob = f64::NEG_INFINITY;
348 let mut max_class = String::from("");
349
350 for class in &self.classes {
351 let mut class_prob = self.probability_of_class.get(class).unwrap().ln();
352
353 if let Some(feature_probs) = self.probability_of_feat_by_class.get(class) {
354 for (index, &(mean, std_dev)) in feature_probs.iter().enumerate() {
355 let feature_value = data[index];
356 let feature_prob = calculate_probability(feature_value, mean, std_dev);
357 class_prob += feature_prob.ln();
358 }
359 }
360
361 if class_prob > max_prob {
362 max_prob = class_prob;
363 max_class = class.clone();
364 }
365 }
366 predictions.push(max_class);
367 }
368
369 predictions
370 }
371}
372
/// Returns the arithmetic mean of `data`.
///
/// Returns NaN when `data` is empty (0.0 / 0.0).
fn calculate_mean(data: &[f64]) -> f64 {
    data.iter().sum::<f64>() / data.len() as f64
}
377
/// Returns the population standard deviation of `data` around `mean`:
/// `sqrt(sum((v - mean)^2) / n)`.
///
/// Divides by `n`, not `n - 1`. Returns NaN when `data` is empty.
fn calculate_std_dev(data: &[f64], mean: f64) -> f64 {
    let sum_of_squares: f64 = data
        .iter()
        .map(|&value| (value - mean) * (value - mean))
        .sum();
    (sum_of_squares / data.len() as f64).sqrt()
}
390
/// Evaluates the normal probability density N(x; mean, std_dev) at `x`.
///
/// `std_dev` is floored at `1e-9`. A zero standard deviation arises for features that are
/// constant within a class, and would otherwise give NaN (0/0). With the floor, the result
/// is a large finite density when `x == mean` and `0.0` elsewhere.
fn calculate_probability(x: f64, mean: f64, std_dev: f64) -> f64 {
    const MIN_STD_DEV: f64 = 1e-9;

    let variance = std_dev.max(MIN_STD_DEV).powi(2);
    let exponent = (-((x - mean).powi(2)) / (2.0 * variance)).exp();
    (1.0 / (2.0 * std::f64::consts::PI * variance).sqrt()) * exponent
}
395
/// Returns the prior probability of each class in `uniq_classes`, i.e. the fraction of
/// entries in `all_classes` equal to it.
///
/// Classes that never occur in `all_classes` get probability `0.0`. If `all_classes` is
/// empty, every probability is NaN (0 / 0).
fn calculate_class_probability(
    uniq_classes: &[String],
    all_classes: &[String],
) -> HashMap<String, f64> {
    let total = all_classes.len() as f64;

    let mut class_counts: HashMap<&str, f64> = HashMap::new();
    for class in all_classes {
        *class_counts.entry(class.as_str()).or_insert(0.0) += 1.0;
    }

    uniq_classes
        .iter()
        .map(|class| {
            let count = class_counts.get(class.as_str()).copied().unwrap_or(0.0);
            (class.clone(), count / total)
        })
        .collect()
}
419
/// Estimates, for each class in `uniq_classes`, the mean and population standard
/// deviation of every feature column over that class's rows of `x`.
///
/// Returns a map from class label to one `(mean, std_dev)` pair per feature, in column
/// order. The number of features is taken from the class's first row. Classes with no
/// rows are omitted, and an empty map is returned when `x` and `y` differ in length.
///
/// # Panics
/// Panics if a row of a class is shorter than that class's first row.
fn calculate_feature_probability(
    x: &Vec<Vec<f64>>,
    y: &Vec<String>,
    uniq_classes: &Vec<String>,
) -> HashMap<String, Vec<(f64, f64)>> {
    if x.len() != y.len() {
        return HashMap::new();
    }

    let mut feature_params: HashMap<String, Vec<(f64, f64)>> =
        HashMap::with_capacity(uniq_classes.len());

    for class in uniq_classes {
        // Borrow the rows belonging to this class instead of cloning them.
        let class_rows: Vec<&Vec<f64>> = x
            .iter()
            .zip(y)
            .filter(|(_, label)| *label == class)
            .map(|(row, _)| row)
            .collect();

        if class_rows.is_empty() {
            continue;
        }

        let n = class_rows.len() as f64;
        let params: Vec<(f64, f64)> = (0..class_rows[0].len())
            .map(|i| {
                let mean = class_rows.iter().map(|row| row[i]).sum::<f64>() / n;
                let variance = class_rows
                    .iter()
                    .map(|row| {
                        let diff = mean - row[i];
                        diff * diff
                    })
                    .sum::<f64>()
                    / n;
                (mean, variance.sqrt())
            })
            .collect();

        feature_params.insert(class.to_string(), params);
    }

    feature_params
}
471
#[cfg(test)]
mod calculation_functions_tests {
    use super::*;

    #[test]
    fn test_calculate_class_probability() {
        let uniq_classes = vec![
            "class1".to_string(),
            "class2".to_string(),
            "class3".to_string(),
        ];
        let all_classes = vec![
            "class1".to_string(),
            "class2".to_string(),
            "class2".to_string(),
            "class3".to_string(),
            "class3".to_string(),
            "class3".to_string(),
        ];
        let probabilities = calculate_class_probability(&uniq_classes, &all_classes);

        // `.abs()` makes these two-sided: without it any value below the target passed.
        assert!((probabilities["class1"] - 1.0 / 6.0).abs() < f64::EPSILON);
        assert!((probabilities["class2"] - 2.0 / 6.0).abs() < f64::EPSILON);
        assert!((probabilities["class3"] - 3.0 / 6.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_calculate_class_probability_sum_to_one() {
        let uniq_classes = vec![
            "class1".to_string(),
            "class2".to_string(),
            "class3".to_string(),
        ];
        let all_classes = vec![
            "class1".to_string(),
            "class2".to_string(),
            "class2".to_string(),
            "class3".to_string(),
            "class3".to_string(),
            "class3".to_string(),
        ];
        let probabilities = calculate_class_probability(&uniq_classes, &all_classes);

        let sum: f64 = probabilities.values().sum();

        assert!((1.0 - sum).abs() < f64::EPSILON);
    }

    #[test]
    fn test_calculate_feature_probability() {
        let uniq_classes = vec!["class1".to_string(), "class2".to_string()];
        let y = vec![
            "class1".to_string(),
            "class2".to_string(),
            "class1".to_string(),
            "class2".to_string(),
        ];
        let x = vec![
            vec![1.0, 2.0],
            vec![2.0, 2.0],
            vec![2.0, 3.0],
            vec![3.0, 3.0],
        ];

        let feature_probabilities = calculate_feature_probability(&x, &y, &uniq_classes);

        let class1_probabilities = feature_probabilities.get("class1").unwrap();
        assert!((class1_probabilities[0].0 - 1.5).abs() < f64::EPSILON);
        assert!((class1_probabilities[0].1 - 0.5).abs() < f64::EPSILON);
        assert!((class1_probabilities[1].0 - 2.5).abs() < f64::EPSILON);
        assert!((class1_probabilities[1].1 - 0.5).abs() < f64::EPSILON);

        let class2_probabilities = feature_probabilities.get("class2").unwrap();
        assert!((class2_probabilities[0].0 - 2.5).abs() < f64::EPSILON);
        assert!((class2_probabilities[0].1 - 0.5).abs() < f64::EPSILON);
        assert!((class2_probabilities[1].0 - 2.5).abs() < f64::EPSILON);
        assert!((class2_probabilities[1].1 - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_calculate_feature_probability_no_data() {
        let uniq_classes = vec!["class1".to_string(), "class2".to_string()];
        let y = vec![];
        let x = vec![];

        let feature_probabilities = calculate_feature_probability(&x, &y, &uniq_classes);

        assert!(feature_probabilities.is_empty());
    }

    #[test]
    fn test_calculate_feature_probability_same_feature_values() {
        let uniq_classes = vec!["class1".to_string(), "class2".to_string()];
        let y = vec![
            "class1".to_string(),
            "class1".to_string(),
            "class2".to_string(),
            "class2".to_string(),
        ];
        let x = vec![
            vec![2.0, 2.0],
            vec![2.0, 2.0],
            vec![2.0, 2.0],
            vec![2.0, 2.0],
        ];

        let feature_probabilities = calculate_feature_probability(&x, &y, &uniq_classes);

        let class1_probabilities = feature_probabilities.get("class1").unwrap();
        assert!((class1_probabilities[0].0 - 2.0).abs() < f64::EPSILON);
        assert!((class1_probabilities[0].1 - 0.0).abs() < f64::EPSILON);
        assert!((class1_probabilities[1].0 - 2.0).abs() < f64::EPSILON);
        assert!((class1_probabilities[1].1 - 0.0).abs() < f64::EPSILON);

        let class2_probabilities = feature_probabilities.get("class2").unwrap();
        assert!((class2_probabilities[0].0 - 2.0).abs() < f64::EPSILON);
        assert!((class2_probabilities[0].1 - 0.0).abs() < f64::EPSILON);
        assert!((class2_probabilities[1].0 - 2.0).abs() < f64::EPSILON);
        assert!((class2_probabilities[1].1 - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_calculate_feature_probability_mismatched_lengths() {
        let uniq_classes = vec!["class1".to_string(), "class2".to_string()];
        let y = vec!["class1".to_string(), "class2".to_string()];
        let x = vec![];

        let feature_probabilities = calculate_feature_probability(&x, &y, &uniq_classes);

        assert!(feature_probabilities.is_empty());
    }

    #[test]
    fn test_calculate_mean() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(calculate_mean(&data), 3.0);
    }

    #[test]
    fn test_calculate_std_dev() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mean = calculate_mean(&data);
        assert!((calculate_std_dev(&data, mean) - 1.414213).abs() < 0.00001);
    }

    #[test]
    fn test_calculate_probability() {
        let x = 2.0;
        let mean = 2.0;
        let std_dev = 1.0;
        assert!((calculate_probability(x, mean, std_dev) - 0.398942).abs() < 0.00001);
    }
}
631
#[cfg(test)]
mod naive_bayes_tests {
    use super::*;

    #[test]
    fn test_fit_std() {
        let model = StdNaiveBayes::new(1.0);

        let x: Vec<Vec<f64>> = vec![
            vec![1.0, 2.0, 3.0],
            vec![2.0, 3.0, 1.0],
            vec![3.0, 1.0, 2.0],
        ];

        let y: Vec<String> = vec![
            "class1".to_string(),
            "class2".to_string(),
            "class1".to_string(),
        ];

        let model = model.fit(&x, &y);

        assert!((model.probability_of_class.get("class1").unwrap() - 2.0 / 3.0).abs() < 1e-9);
        assert!((model.probability_of_class.get("class2").unwrap() - 1.0 / 3.0).abs() < 1e-9);

        // class1 column sums [4, 3, 5] (total 12); class2 [2, 3, 1] (total 6); alpha = 1, 3 features.
        let expected = [
            ("class1", [5.0 / 15.0, 4.0 / 15.0, 6.0 / 15.0]),
            ("class2", [3.0 / 9.0, 4.0 / 9.0, 2.0 / 9.0]),
        ];
        for (class, probs) in expected {
            let fitted = model.probability_of_feat_by_class.get(class).unwrap();
            assert_eq!(fitted.len(), probs.len());
            for (j, p) in probs.iter().enumerate() {
                assert!((fitted[&j.to_string()] - p).abs() < 1e-12);
            }
        }
    }

    #[test]
    fn test_predict_std() {
        let model = StdNaiveBayes::new(1.0);

        let x: Vec<Vec<f64>> = vec![
            vec![1.0, 2.0, 3.0, 1.0, 2.0],
            vec![2.0, 3.0, 4.0, 2.0, 3.0],
            vec![4.0, 4.0, 5.0, 4.0, 4.0],
            vec![5.0, 5.0, 6.0, 5.0, 5.0],
            vec![1.0, 1.0, 1.0, 1.0, 1.0],
        ];

        let y: Vec<String> = vec![
            "class1".to_string(),
            "class1".to_string(),
            "class2".to_string(),
            "class2".to_string(),
            "class1".to_string(),
        ];

        let model = model.fit(&x, &y);

        let x_test: Vec<Vec<f64>> =
            vec![vec![1.5, 2.5, 3.5, 1.5, 2.5], vec![5.5, 4.5, 5.5, 4.5, 4.5]];

        let predictions = model.predict(&x_test);

        assert_eq!(predictions, vec!["class1", "class2"]);
    }

    #[test]
    fn test_new_gaus() {
        let model: GaussianNaiveBayes = GaussianNaiveBayes::new();

        assert_eq!(model.classes.len(), 0);
        assert_eq!(model.probability_of_class.len(), 0);
        assert_eq!(model.probability_of_feat_by_class.len(), 0);
    }

    #[test]
    fn test_fit_gaus() {
        let model: GaussianNaiveBayes = GaussianNaiveBayes::new();
        let x = vec![
            vec![2.0, 1.0],
            vec![3.0, 2.0],
            vec![2.5, 1.5],
            vec![4.0, 3.0],
        ];
        let y = vec![
            "class1".to_string(),
            "class1".to_string(),
            "class2".to_string(),
            "class2".to_string(),
        ];
        let model = model.fit(&x, &y);

        assert_eq!(model.classes.len(), 2);
        assert!(model.classes.contains(&"class1".to_string()));
        assert!(model.classes.contains(&"class2".to_string()));

        assert_eq!(model.probability_of_class.len(), 2);
        assert!(model
            .probability_of_class
            .contains_key(&"class1".to_string()));
        assert!(model
            .probability_of_class
            .contains_key(&"class2".to_string()));

        assert_eq!(model.probability_of_feat_by_class.len(), 2);
        assert!(model
            .probability_of_feat_by_class
            .contains_key(&"class1".to_string()));
        assert!(model
            .probability_of_feat_by_class
            .contains_key(&"class2".to_string()));
    }

    #[test]
    fn test_predict_gaus() {
        let model: GaussianNaiveBayes = GaussianNaiveBayes::new();
        let x = vec![
            vec![2.0, 1.0],
            vec![3.0, 2.0],
            vec![2.5, 1.5],
            vec![4.0, 3.0],
        ];
        let y = vec![
            "class1".to_string(),
            "class1".to_string(),
            "class2".to_string(),
            "class2".to_string(),
        ];
        let model = model.fit(&x, &y);

        let x_test = vec![vec![2.0, 1.0], vec![4.0, 3.0]];

        let predictions = model.predict(&x_test);
        assert_eq!(predictions.len(), x_test.len());
        assert_eq!(predictions[0], "class1");
        assert_eq!(predictions[1], "class2");
    }
}
759}