use std::collections::HashMap;
use std::collections::HashSet;
use super::util::{Fit, Unfit};

/// Implementation of a standard (multinomial) Naive Bayes classifier over count-valued features.
///
/// This classifier applies Laplace smoothing, controlled by the `alpha` parameter: each
/// per-class feature probability is estimated as `(count + alpha) / (class_total + alpha * n_features)`,
/// so no feature probability is ever exactly zero.
///
/// # Fields
/// - `alpha`: The Laplace smoothing factor.
/// - `probability_of_class`: Maps each class label to its prior probability.
/// - `probability_of_feat_by_class`: Maps each class label to a map from feature index to smoothed feature probability.
/// - `state`: Private `PhantomData` marker recording whether the classifier has been fit.
///
/// # Type parameters
/// - `State`: Indicates whether the classifier has been fit. Can either be `Fit` or `Unfit`.
///
/// # Example
///
/// ```
/// use ducky_learn::naive_bayes::StdNaiveBayes;
///
/// // Define train and test data
/// let x_train: Vec<Vec<f64>> = vec![
///     vec![1.0, 2.0, 3.0],
///     vec![2.0, 3.0, 4.0],
///     vec![3.0, 4.0, 5.0],
/// ];
/// let y_train: Vec<String> = vec!["class1".to_string(), "class2".to_string(), "class1".to_string()];
///
/// let x_test: Vec<Vec<f64>> = vec![
///     vec![1.5, 2.5, 3.5],
///     vec![2.5, 3.5, 4.5],
/// ];
///
/// let nb = StdNaiveBayes::new(1.0);
/// let nb = nb.fit(&x_train, &y_train);
/// let y_pred = nb.predict(&x_test);
///
/// // y_pred will hold the predicted classes for x_test
/// ```
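///
/// Because `predict` is only implemented for `StdNaiveBayes<Fit>`, calling it
/// before `fit` is rejected at compile time:
///
/// ```compile_fail
/// use ducky_learn::naive_bayes::StdNaiveBayes;
///
/// let nb = StdNaiveBayes::new(1.0);
/// // Compile error: `predict` only exists on a fitted classifier.
/// let _ = nb.predict(&vec![vec![1.0, 2.0]]);
/// ```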

#[derive(Debug)]
pub struct StdNaiveBayes<State = Unfit> {
    pub alpha: f64,
    pub probability_of_class: HashMap<String, f64>,
    pub probability_of_feat_by_class: HashMap<String, HashMap<String, f64>>,

    state: std::marker::PhantomData<State>,
}

impl StdNaiveBayes {

    /// Constructs a new, unfitted `StdNaiveBayes` classifier with a specified alpha value.
    ///
    /// # Parameters
    /// - `alpha`: The Laplace smoothing factor.
    ///
    /// # Returns
    /// A new `StdNaiveBayes` instance.
    pub fn new(alpha: f64) -> Self {
        Self {
            alpha,
            probability_of_class: HashMap::new(),
            probability_of_feat_by_class: HashMap::new(),

            state: Default::default(),
        }
    }

    /// Fits the `StdNaiveBayes` classifier to the training data.
    ///
    /// # Parameters
    /// - `x`: The training data, one row of feature counts per sample.
    /// - `y`: The class label for each row of `x`.
    ///
    /// # Returns
    /// The fitted classifier. `fit` consumes the unfit classifier and returns a
    /// `StdNaiveBayes<Fit>`, unlocking `predict`.
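    ///
    /// # Example
    ///
    /// A minimal sketch over two classes; feature values are treated as counts,
    /// as in a multinomial model:
    ///
    /// ```
    /// use ducky_learn::naive_bayes::StdNaiveBayes;
    ///
    /// let x_train = vec![
    ///     vec![2.0, 0.0, 1.0],
    ///     vec![0.0, 3.0, 1.0],
    /// ];
    /// let y_train = vec!["a".to_string(), "b".to_string()];
    ///
    /// let nb = StdNaiveBayes::new(1.0).fit(&x_train, &y_train);
    /// assert!(nb.probability_of_class.contains_key("a"));
    /// ```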
    pub fn fit(mut self, x: &[Vec<f64>], y: &[String]) -> StdNaiveBayes<Fit> {
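        // Count how many training rows belong to each class.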
        let mut y_counts: HashMap<String, i32> = HashMap::new();
        for class in y {
            let counter = y_counts.entry(class.to_string()).or_insert(0);
            *counter += 1;
        }

        let total_rows = y.len() as f64;
        let unique_classes: HashSet<String> = y.iter().cloned().collect();

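        // For each class: store its prior P(class), then accumulate feature
        // counts and convert them to Laplace-smoothed likelihoods P(feat | class).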
        for uniq_class in &unique_classes {
            self.probability_of_class.insert(
                uniq_class.to_string(),
                *y_counts.get(uniq_class).unwrap() as f64 / total_rows,
            );

            let mut class_feat_probs: HashMap<String, f64> = HashMap::new();
            let mut sum_of_feats_in_class = 0.0;
            for (i, class) in y.iter().enumerate() {
                if class == uniq_class {
                    for (j, feat_count) in x[i].iter().enumerate() {
                        let counter = class_feat_probs.entry(j.to_string()).or_insert(0.0);
                        *counter += *feat_count;
                        sum_of_feats_in_class += *feat_count;
                    }
                }
            }
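            // Laplace smoothing: act as if every feature had been seen `alpha`
            // extra times in this class, so no probability is ever zero.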
            sum_of_feats_in_class += self.alpha * x[0].len() as f64;

            for count in class_feat_probs.values_mut() {
                *count = (*count + self.alpha) / sum_of_feats_in_class;
            }

            self.probability_of_feat_by_class.insert(uniq_class.to_string(), class_feat_probs);
        }

        StdNaiveBayes {
            alpha: self.alpha,
            probability_of_class: self.probability_of_class,
            probability_of_feat_by_class: self.probability_of_feat_by_class,

            state: std::marker::PhantomData::<Fit>,
        }
    }
}

impl StdNaiveBayes<Fit> {

    /// Predicts the target values for the given data.
    ///
    /// # Parameters
    /// - `x`: The rows of feature counts to classify.
    ///
    /// # Returns
    /// The predicted class label for each row of `x`.
    ///
    /// # Panics
    /// Panics if a row of `x` is longer than the training rows (an unseen feature
    /// index has no stored probability), or if a computed score is `NaN`.
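    ///
    /// # Example
    ///
    /// A minimal round trip; with one dominant feature per class, the prediction
    /// follows the matching training label:
    ///
    /// ```
    /// use ducky_learn::naive_bayes::StdNaiveBayes;
    ///
    /// let x_train = vec![vec![3.0, 0.0], vec![0.0, 3.0]];
    /// let y_train = vec!["a".to_string(), "b".to_string()];
    /// let nb = StdNaiveBayes::new(1.0).fit(&x_train, &y_train);
    ///
    /// let x_test = vec![vec![4.0, 0.0]];
    /// assert_eq!(nb.predict(&x_test), vec!["a".to_string()]);
    /// ```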
    pub fn predict(&self, x: &[Vec<f64>]) -> Vec<String> {
        let mut y_pred: Vec<String> = Vec::new();
        let unique_classes: Vec<String> = self.probability_of_class.keys().cloned().collect();
        let class_probabilities: Vec<f64> = self.probability_of_class.values().cloned().collect();
        let small_number = 1e-9;

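        // Score each row against every class in log space (log prior plus
        // count-weighted log likelihoods), then pick the highest-scoring class.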
        for row in x {
            let mut row_probabilities: Vec<f64> = Vec::new();
            for (i, class) in unique_classes.iter().enumerate() {
                let mut log_sum = (class_probabilities[i] + small_number).ln();
                for (j, feat_count) in row.iter().enumerate() {
                    // Zero counts contribute nothing to the log score, so skip them.
                    if *feat_count > 0.0 {
                        let prob = self
                            .probability_of_feat_by_class
                            .get(class)
                            .unwrap()
                            .get(&j.to_string())
                            .unwrap();
                        log_sum += *feat_count * (*prob + small_number).ln();
                    }
                }
                row_probabilities.push(log_sum);
            }

            let max_prob_index = row_probabilities
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .unwrap()
                .0;
            y_pred.push(unique_classes[max_prob_index].to_string());
        }

        y_pred
    }
}
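
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal smoke test: two well-separated classes of count features;
    // `predict` should recover each training row's own label.
    #[test]
    fn fit_predict_recovers_training_labels() {
        let x_train = vec![
            vec![5.0, 0.0, 1.0],
            vec![0.0, 5.0, 1.0],
        ];
        let y_train = vec!["a".to_string(), "b".to_string()];

        let nb = StdNaiveBayes::new(1.0).fit(&x_train, &y_train);
        assert_eq!(nb.predict(&x_train), y_train);
    }
}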