naivebayes/
lib.rs

1use std::collections::HashMap;
2use std::collections::HashSet;
3use std::collections::hash_map::Keys;
4use std::iter::FromIterator;
5use std::vec::Vec;
6
/// Occurrence counts for every attribute (token), broken down by label.
///
/// The outer map is keyed by attribute; the inner map is keyed by label and
/// holds how many times `add` has been called for that (attribute, label) pair.
#[derive(Debug, Default)]
struct Attributes {
    attributes: HashMap<String, HashMap<String, i64>>,
}

impl Attributes {
    /// Creates an empty attribute table.
    pub fn new() -> Attributes {
        Attributes::default()
    }

    /// Records one more occurrence of `attribute` under `label`.
    ///
    /// Both the attribute entry and the per-label counter are created on
    /// first use; the counter starts at zero and is then incremented.
    fn add(&mut self, attribute: &str, label: &str) {
        let per_label = self.attributes.entry(attribute.to_owned()).or_default();
        *per_label.entry(label.to_owned()).or_insert(0) += 1;
    }

    /// Looks up how often `attribute` has been recorded under `label`.
    ///
    /// Returns `(count, attribute_known)`:
    ///
    /// * `(Some(n), true)`  - the pair has been recorded `n` times;
    /// * `(None, true)`     - the attribute exists, but never under `label`;
    /// * `(None, false)`    - the attribute has never been recorded at all.
    ///
    /// The lookup only reads the table, so it works through a shared reference.
    fn get_frequency(&self, attribute: &str, label: &str) -> (Option<&i64>, bool) {
        match self.attributes.get(attribute) {
            Some(per_label) => (per_label.get(label), true),
            None => (None, false),
        }
    }
}
36
/// Number of training samples seen for each label.
#[derive(Debug, Default)]
struct Labels {
    counts: HashMap<String, i64>,
}

impl Labels {
    /// Creates an empty label table.
    pub fn new() -> Labels {
        Labels::default()
    }

    /// Counts one more sample for `label`, creating the entry on first use.
    fn add(&mut self, label: &str) {
        *self.counts.entry(label.to_owned()).or_insert(0) += 1;
    }

    /// Returns how many times `label` has been added, or `None` if it has
    /// never been added.
    fn get_count(&self, label: &str) -> Option<&i64> {
        self.counts.get(label)
    }

    /// Returns an iterator over every distinct label added so far, in the
    /// map's (unspecified) iteration order.
    fn get_labels(&self) -> Keys<'_, String, i64> {
        self.counts.keys()
    }

    /// Returns the sum of all per-label counts; `0` when no label was added.
    fn get_total(&self) -> i64 {
        self.counts.values().sum()
    }
}
65
66struct Model {
67    labels: Labels,
68    attributes: Attributes,
69}
70
71impl Model {
72    pub fn new() -> Model {
73        Model {
74            labels: Labels::new(),
75            attributes: Attributes::new(),
76        }
77    }
78    fn train(&mut self, data: &Vec<String>, label: &String) {
79        self.labels.add(label);
80        for attribute in data {
81            self.attributes.add(attribute, label);
82        }
83    }
84}
85
86pub struct NaiveBayes {
87    model: Model,
88    minimum_probability: f64,
89    minimum_log_probability: f64
90}
91
92impl NaiveBayes {
93    /// creates a new instance of a `NaiveBayes` classifier.
94    pub fn new() -> NaiveBayes {
95        NaiveBayes {
96            model: Model::new(),
97            minimum_probability: 1e-9,
98            minimum_log_probability: -100.0,
99        }
100    }
101
102    fn prior(&mut self, label: &String) -> Option<f64> {
103        let total = *(&self.model.labels.get_total()) as f64;
104        let label = &self.model.labels.get_count(label);
105        if label.is_some() && total > 0.0 {
106            return Some(*label.unwrap() as f64 / total);
107        } else {
108            return None;
109        }
110    }
111
112    fn log_prior(&mut self, label: &String) -> Option<f64> {
113        let total = *(&self.model.labels.get_total()) as f64;
114        let label = &self.model.labels.get_count(label);
115        if label.is_some() && total > 0.0 {
116            return Some((*label.unwrap() as f64).ln() - total.ln());
117        } else {
118            return None;
119        }
120    }
121
122    fn calculate_attr_prob(&mut self, attribute: &String, label: &String) -> Option<f64> {
123        match self.model.attributes.get_frequency(attribute, label) {
124            (Some(frequency), true) => match self.model.labels.get_count(label) {
125                Some(count) => return Some((*frequency as f64) / (*count as f64)),
126                None => return None,
127            },
128            (None, true) => return Some(self.minimum_probability),
129            (None, false) => return None,
130            (Some(_), false) => None,
131        }
132    }
133
134    fn calculate_attr_log_prob(&mut self, attribute: &String, label: &String) -> Option<f64> {
135        match self.model.attributes.get_frequency(attribute, label) {
136            (Some(frequency), true) => match self.model.labels.get_count(label) {
137                Some(count) => return Some((*frequency as f64).ln() - (*count as f64).ln()),
138                None => return None,
139            },
140            (None, true) => return Some(self.minimum_log_probability),
141            (None, false) => return None,
142            (Some(_), false) => None,
143        }
144    }
145
146    fn label_prob(&mut self, label: &String, attrs: &HashSet<String>) -> Vec<f64> {
147        let mut probs: Vec<f64> = Vec::new();
148        for attr in attrs {
149            match self.calculate_attr_prob(attr, label) {
150                Some(p) => {
151                    probs.push(p);
152                }
153                None => {}
154            }
155        }
156        return probs;
157    }
158
159    fn label_log_prob(&mut self, label: &String, attrs: &HashSet<String>) -> Vec<f64> {
160        let mut probs: Vec<f64> = Vec::new();
161        for attr in attrs {
162            match self.calculate_attr_log_prob(attr, label) {
163                Some(p) => {
164                    probs.push(p);
165                }
166                None => {}
167            }
168        }
169        return probs;
170    }
171
172    /// trains the model with a `Vec<String>` of tokens, associating it with a `String` label.
173    pub fn train(&mut self, data: &Vec<String>, label: &String) {
174        self.model.train(data, label);
175    }
176
177    /// classify a `Vec<String>` of tokens returning a map of tokens and probabilities
178    /// as keys and values, respectively.
179    pub fn classify(&mut self, data: &Vec<String>) -> HashMap<String, f64> {
180        let attribute_set: HashSet<String> = HashSet::from_iter(data.iter().cloned());
181        let mut result: HashMap<String, f64> = HashMap::new();
182        let labels: HashSet<String> =
183            HashSet::from_iter(self.model.labels.get_labels().into_iter().cloned());
184        for label in labels {
185            let p = self.label_prob(&label, &attribute_set);
186            let p_iter = p.into_iter().fold(1.0, |acc, x| acc * x);
187            let _value = result
188                .entry(label.to_string())
189                .or_insert(p_iter * self.prior(&label).unwrap());
190        }
191
192        return result;
193    }
194
195    /// classify a `Vec<String>` of tokens returning a map of tokens and log-probabilities
196    /// as keys and values, respectively. Using `log_classify` may prevent underflows.
197    pub fn log_classify(&mut self, data: &Vec<String>) -> HashMap<String, f64> {
198        let attribute_set: HashSet<String> = HashSet::from_iter(data.iter().cloned());
199        let mut result: HashMap<String, f64> = HashMap::new();
200        let labels: HashSet<String> =
201            HashSet::from_iter(self.model.labels.get_labels().into_iter().cloned());
202        for label in labels {
203            let p = self.label_log_prob(&label, &attribute_set);
204            let max = p.iter().cloned().fold(-1./0. /* inf */, f64::max);
205            let p_iter = p.into_iter().fold(0.0, |acc, x| acc + (x - max).exp());
206            let _value = result
207                .entry(label.to_string())
208                .or_insert(max + p_iter.ln() + self.log_prior(&label).unwrap());
209        }
210
211        return result;
212    }
213}
214
#[cfg(test)]
mod test_attributes {
    use super::*;

    /// Recording a pair once yields a frequency of exactly one.
    #[test]
    fn attribute_add() {
        let attribute = "rust".to_string();
        let label = "naive".to_string();

        let mut model = Attributes::new();
        model.add(&attribute, &label);

        let (frequency, _) = model.get_frequency(&attribute, &label);
        assert_eq!(frequency, Some(&1));
    }

    /// Looking up a pair in an empty table yields no frequency.
    #[test]
    fn get_non_existing() {
        let attribute = "rust".to_string();
        let label = "naive".to_string();

        let mut model = Attributes::new();

        let (frequency, _) = model.get_frequency(&attribute, &label);
        assert_eq!(frequency, None);
    }
}
244
#[cfg(test)]
mod test_labels {
    use super::*;

    /// Builds a `Labels` to which every entry of `names` was added once, in order.
    fn labels_with(names: &[&str]) -> Labels {
        let mut labels = Labels::new();
        for name in names {
            labels.add(&name.to_string());
        }
        labels
    }

    /// A label added once has a count of one.
    #[test]
    fn label_add() {
        let mut labels = labels_with(&["rust"]);
        assert_eq!(labels.get_count(&"rust".to_string()), Some(&1));
    }

    /// A label that was never added has no count.
    #[test]
    fn label_get_nonexistent() {
        let mut labels = Labels::new();
        assert_eq!(labels.get_count(&"rust".to_string()), None);
    }

    /// The label iterator exposes exactly the labels that were added.
    #[test]
    fn get_labels() {
        let mut labels = labels_with(&["rust"]);
        assert_eq!(labels.get_labels().len(), 1);
        assert_eq!(
            labels.get_labels().last().map(String::as_str),
            Some("rust")
        );
    }

    /// Adding the same label twice keeps one entry with a count of two.
    #[test]
    fn get_counts() {
        let mut labels = labels_with(&["rust", "rust"]);
        assert_eq!(labels.get_labels().len(), 1);
        assert_eq!(labels.get_count(&"rust".to_string()), Some(&2));
    }

    /// An empty table has no labels and no counts.
    #[test]
    fn get_nonexistent_counts() {
        let mut labels = Labels::new();
        assert_eq!(labels.get_labels().len(), 0);
        assert_eq!(labels.get_count(&"rust".to_string()), None);
    }

    /// The total of an empty table is zero.
    #[test]
    fn get_nonexistent_total() {
        let mut labels = Labels::new();
        assert_eq!(labels.get_total(), 0);
    }

    /// The total is the sum of all per-label counts.
    #[test]
    fn get_total() {
        let mut labels = labels_with(&["rust", "rust", "naive", "bayes"]);
        assert_eq!(labels.get_total(), 4);
    }
}
304
#[cfg(test)]
mod test_naive_bayes {
    use super::*;
    use std::f64::consts::LN_2;

    /// Turns a list of string literals into owned tokens.
    fn tokens(words: &[&str]) -> Vec<String> {
        words.iter().map(|word| word.to_string()).collect()
    }

    /// A classifier trained on a single 👍 sample.
    fn trained_on_thumbs_up() -> NaiveBayes {
        let mut nb = NaiveBayes::new();
        nb.model
            .train(&tokens(&["rust", "naive", "bayes"]), &"👍".to_string());
        nb
    }

    /// A classifier trained on one 👍 sample followed by one 👎 sample.
    fn trained_on_both() -> NaiveBayes {
        let mut nb = trained_on_thumbs_up();
        nb.model.train(
            &tokens(&["golang", "java", "javascript"]),
            &"👎".to_string(),
        );
        nb
    }

    /// With a single trained label, its prior is 1.
    #[test]
    fn test_prior() {
        let mut nb = trained_on_thumbs_up();
        assert_eq!(nb.prior(&"👍".to_string()), Some(1.0));
    }

    /// With a single trained label, its log-prior is 0.
    #[test]
    fn test_log_prior() {
        let mut nb = trained_on_thumbs_up();
        assert_eq!(nb.log_prior(&"👍".to_string()), Some(0.0));
    }

    /// A label that was never trained has no prior.
    #[test]
    fn test_prior_nonexistent() {
        let mut nb = trained_on_thumbs_up();
        assert_eq!(nb.prior(&"👎".to_string()), None);
    }

    /// Only "rust" is known to the model; it was seen under 👍 but not 👎.
    #[test]
    fn test_classification() {
        let mut nb = trained_on_both();

        let classes = nb.classify(&tokens(&["rust", "scala", "c++"]));

        assert_eq!(classes.get("👍"), Some(&0.5));
        assert_eq!(classes.get("👎"), Some(&0.0000000005));
        print!("{:?}", classes);
    }

    /// Log-space variant of `test_classification`.
    #[test]
    fn test_log_classification() {
        let mut nb = trained_on_both();

        let classes = nb.log_classify(&tokens(&["rust", "scala", "c++"]));

        assert_eq!(classes.get("👍"), Some(&(-LN_2)));
        assert_eq!(classes.get("👎"), Some(&-100.69314718055995));
        print!("{:?}", classes);
    }
}