use std::collections::hash_map::Keys;
use std::collections::HashMap;
use std::collections::HashSet;
use std::iter::FromIterator;
use std::vec::Vec;
6
/// Co-occurrence counts of attributes and labels.
///
/// The outer map is keyed by attribute; each inner map is keyed by label and
/// stores how many times that attribute was added together with that label.
struct Attributes {
    attributes: HashMap<String, HashMap<String, i64>>,
}

impl Attributes {
    /// Creates an empty set of counters.
    pub fn new() -> Attributes {
        Attributes {
            attributes: HashMap::new(),
        }
    }

    /// Records one more occurrence of `attribute` under `label`.
    ///
    /// Both the attribute entry and the per-label counter are created on
    /// first use, with the counter starting at zero before the increment.
    fn add(&mut self, attribute: &str, label: &str) {
        let per_label = self.attributes.entry(attribute.to_owned()).or_default();
        *per_label.entry(label.to_owned()).or_insert(0) += 1;
    }

    /// Looks up how often `attribute` was added under `label`.
    ///
    /// Returns `(frequency, attribute_known)`:
    /// * `(Some(n), true)`  - the pair was added `n` times;
    /// * `(None, true)`     - the attribute exists, but never with `label`;
    /// * `(None, false)`    - the attribute was never added at all.
    ///
    /// This is a read-only query, so it only needs a shared borrow.
    fn get_frequency(&self, attribute: &str, label: &str) -> (Option<&i64>, bool) {
        match self.attributes.get(attribute) {
            Some(per_label) => (per_label.get(label), true),
            None => (None, false),
        }
    }
}
36
/// Per-label occurrence counters.
struct Labels {
    counts: HashMap<String, i64>,
}

impl Labels {
    /// Creates an empty counter table.
    pub fn new() -> Labels {
        Labels {
            counts: HashMap::new(),
        }
    }

    /// Increments the counter for `label`, creating it at zero if needed.
    fn add(&mut self, label: &str) {
        *self.counts.entry(label.to_owned()).or_insert(0) += 1;
    }

    /// Returns how many times `label` was added, or `None` if it never was.
    fn get_count(&self, label: &str) -> Option<&i64> {
        self.counts.get(label)
    }

    /// Iterates over every label that has been added, in arbitrary order.
    fn get_labels(&self) -> Keys<'_, String, i64> {
        self.counts.keys()
    }

    /// Returns the sum of all label counters (zero when nothing was added).
    fn get_total(&self) -> i64 {
        self.counts.values().sum()
    }
}
65
66struct Model {
67 labels: Labels,
68 attributes: Attributes,
69}
70
71impl Model {
72 pub fn new() -> Model {
73 Model {
74 labels: Labels::new(),
75 attributes: Attributes::new(),
76 }
77 }
78 fn train(&mut self, data: &Vec<String>, label: &String) {
79 self.labels.add(label);
80 for attribute in data {
81 self.attributes.add(attribute, label);
82 }
83 }
84}
85
86pub struct NaiveBayes {
87 model: Model,
88 minimum_probability: f64,
89 minimum_log_probability: f64
90}
91
92impl NaiveBayes {
93 pub fn new() -> NaiveBayes {
95 NaiveBayes {
96 model: Model::new(),
97 minimum_probability: 1e-9,
98 minimum_log_probability: -100.0,
99 }
100 }
101
102 fn prior(&mut self, label: &String) -> Option<f64> {
103 let total = *(&self.model.labels.get_total()) as f64;
104 let label = &self.model.labels.get_count(label);
105 if label.is_some() && total > 0.0 {
106 return Some(*label.unwrap() as f64 / total);
107 } else {
108 return None;
109 }
110 }
111
112 fn log_prior(&mut self, label: &String) -> Option<f64> {
113 let total = *(&self.model.labels.get_total()) as f64;
114 let label = &self.model.labels.get_count(label);
115 if label.is_some() && total > 0.0 {
116 return Some((*label.unwrap() as f64).ln() - total.ln());
117 } else {
118 return None;
119 }
120 }
121
122 fn calculate_attr_prob(&mut self, attribute: &String, label: &String) -> Option<f64> {
123 match self.model.attributes.get_frequency(attribute, label) {
124 (Some(frequency), true) => match self.model.labels.get_count(label) {
125 Some(count) => return Some((*frequency as f64) / (*count as f64)),
126 None => return None,
127 },
128 (None, true) => return Some(self.minimum_probability),
129 (None, false) => return None,
130 (Some(_), false) => None,
131 }
132 }
133
134 fn calculate_attr_log_prob(&mut self, attribute: &String, label: &String) -> Option<f64> {
135 match self.model.attributes.get_frequency(attribute, label) {
136 (Some(frequency), true) => match self.model.labels.get_count(label) {
137 Some(count) => return Some((*frequency as f64).ln() - (*count as f64).ln()),
138 None => return None,
139 },
140 (None, true) => return Some(self.minimum_log_probability),
141 (None, false) => return None,
142 (Some(_), false) => None,
143 }
144 }
145
146 fn label_prob(&mut self, label: &String, attrs: &HashSet<String>) -> Vec<f64> {
147 let mut probs: Vec<f64> = Vec::new();
148 for attr in attrs {
149 match self.calculate_attr_prob(attr, label) {
150 Some(p) => {
151 probs.push(p);
152 }
153 None => {}
154 }
155 }
156 return probs;
157 }
158
159 fn label_log_prob(&mut self, label: &String, attrs: &HashSet<String>) -> Vec<f64> {
160 let mut probs: Vec<f64> = Vec::new();
161 for attr in attrs {
162 match self.calculate_attr_log_prob(attr, label) {
163 Some(p) => {
164 probs.push(p);
165 }
166 None => {}
167 }
168 }
169 return probs;
170 }
171
172 pub fn train(&mut self, data: &Vec<String>, label: &String) {
174 self.model.train(data, label);
175 }
176
177 pub fn classify(&mut self, data: &Vec<String>) -> HashMap<String, f64> {
180 let attribute_set: HashSet<String> = HashSet::from_iter(data.iter().cloned());
181 let mut result: HashMap<String, f64> = HashMap::new();
182 let labels: HashSet<String> =
183 HashSet::from_iter(self.model.labels.get_labels().into_iter().cloned());
184 for label in labels {
185 let p = self.label_prob(&label, &attribute_set);
186 let p_iter = p.into_iter().fold(1.0, |acc, x| acc * x);
187 let _value = result
188 .entry(label.to_string())
189 .or_insert(p_iter * self.prior(&label).unwrap());
190 }
191
192 return result;
193 }
194
195 pub fn log_classify(&mut self, data: &Vec<String>) -> HashMap<String, f64> {
198 let attribute_set: HashSet<String> = HashSet::from_iter(data.iter().cloned());
199 let mut result: HashMap<String, f64> = HashMap::new();
200 let labels: HashSet<String> =
201 HashSet::from_iter(self.model.labels.get_labels().into_iter().cloned());
202 for label in labels {
203 let p = self.label_log_prob(&label, &attribute_set);
204 let max = p.iter().cloned().fold(-1./0. , f64::max);
205 let p_iter = p.into_iter().fold(0.0, |acc, x| acc + (x - max).exp());
206 let _value = result
207 .entry(label.to_string())
208 .or_insert(max + p_iter.ln() + self.log_prior(&label).unwrap());
209 }
210
211 return result;
212 }
213}
214
/// Unit tests for `Attributes`.
#[cfg(test)]
mod test_attributes {
    use super::*;

    /// A single `add` yields a frequency of one for that attribute/label
    /// pair, and the attribute is reported as known.
    #[test]
    fn attribute_add() {
        let attribute = String::from("rust");
        let label = String::from("naive");
        let mut model = Attributes::new();
        model.add(&attribute, &label);

        let (frequency, attribute_known) = model.get_frequency(&attribute, &label);
        assert_eq!(frequency, Some(&1));
        assert!(attribute_known);
    }

    /// Looking up a pair in an empty table yields no frequency and reports
    /// the attribute as unknown.
    #[test]
    fn get_non_existing() {
        let mut model = Attributes::new();

        let (frequency, attribute_known) =
            model.get_frequency(&String::from("rust"), &String::from("naive"));
        assert_eq!(frequency, None);
        assert!(!attribute_known);
    }
}
244
/// Unit tests for `Labels`.
#[cfg(test)]
mod test_labels {
    use super::*;

    /// Builds a `Labels` table by adding each name in order.
    fn labels_from(names: &[&str]) -> Labels {
        let mut labels = Labels::new();
        for name in names {
            labels.add(&name.to_string());
        }
        labels
    }

    #[test]
    fn label_add() {
        let mut labels = labels_from(&["rust"]);
        assert_eq!(labels.get_count(&String::from("rust")), Some(&1));
    }

    #[test]
    fn label_get_nonexistent() {
        let mut labels = Labels::new();
        assert_eq!(labels.get_count(&String::from("rust")), None);
    }

    #[test]
    fn get_labels() {
        let mut labels = labels_from(&["rust"]);
        assert_eq!(labels.get_labels().len(), 1);
        assert_eq!(labels.get_labels().next().unwrap(), "rust");
    }

    /// Adding the same label twice keeps one key and counts two.
    #[test]
    fn get_counts() {
        let mut labels = labels_from(&["rust", "rust"]);
        assert_eq!(labels.get_labels().len(), 1);
        assert_eq!(labels.get_count(&String::from("rust")), Some(&2));
    }

    #[test]
    fn get_nonexistent_counts() {
        let mut labels = Labels::new();
        assert_eq!(labels.get_labels().len(), 0);
        assert_eq!(labels.get_count(&String::from("rust")), None);
    }

    #[test]
    fn get_nonexistent_total() {
        let mut labels = Labels::new();
        assert_eq!(labels.get_total(), 0);
    }

    /// The total counts every `add`, including repeats of the same label.
    #[test]
    fn get_total() {
        let mut labels = labels_from(&["rust", "rust", "naive", "bayes"]);
        assert_eq!(labels.get_total(), 4);
    }
}
304
/// Unit tests for `NaiveBayes`.
#[cfg(test)]
mod test_naive_bayes {
    use super::*;
    use std::f64::consts::LN_2;

    const UP: &str = "👍";
    const DOWN: &str = "👎";

    /// Converts string literals into the owned strings the classifier takes.
    fn strings(words: &[&str]) -> Vec<String> {
        words.iter().map(|word| word.to_string()).collect()
    }

    /// Classifier trained on a single `UP` example.
    fn trained_on_up() -> NaiveBayes {
        let mut nb = NaiveBayes::new();
        nb.train(&strings(&["rust", "naive", "bayes"]), &UP.to_string());
        nb
    }

    /// Classifier trained on one `UP` and one `DOWN` example with disjoint
    /// attributes.
    fn trained_on_both() -> NaiveBayes {
        let mut nb = trained_on_up();
        nb.train(&strings(&["golang", "java", "javascript"]), &DOWN.to_string());
        nb
    }

    /// Compares computed floats with a relative tolerance instead of exact
    /// bit equality.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() <= 1e-12 * expected.abs(),
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_prior() {
        let mut nb = trained_on_up();
        assert_eq!(nb.prior(&UP.to_string()), Some(1.0));
    }

    #[test]
    fn test_log_prior() {
        let mut nb = trained_on_up();
        assert_eq!(nb.log_prior(&UP.to_string()), Some(0.0));
    }

    #[test]
    fn test_prior_nonexistent() {
        let mut nb = trained_on_up();
        assert_eq!(nb.prior(&DOWN.to_string()), None);
    }

    /// Only "rust" is known to the model: it was seen with `UP` (probability
    /// one) and never with `DOWN` (minimum probability); both priors are 1/2.
    #[test]
    fn test_classification() {
        let mut nb = trained_on_both();
        let classes = nb.classify(&strings(&["rust", "scala", "c++"]));

        assert_eq!(classes.len(), 2);
        assert_close(classes[UP], 0.5);
        assert_close(classes[DOWN], 0.0000000005);
    }

    /// Log-space version of `test_classification`: zero and the minimum
    /// log-probability of -100, each plus the log-prior of -ln 2.
    #[test]
    fn test_log_classification() {
        let mut nb = trained_on_both();
        let classes = nb.log_classify(&strings(&["rust", "scala", "c++"]));

        assert_eq!(classes.len(), 2);
        assert_close(classes[UP], -LN_2);
        assert_close(classes[DOWN], -100.0 - LN_2);
    }
}