1extern crate rand;
2
3use std::cmp::Eq;
4use std::collections::{HashMap, HashSet};
5use std::fmt;
6use std::hash::{Hash, Hasher};
7use rand::prelude::*;
8
9pub struct TreeConfig {
11 pub decision: String,
12 pub max_depth: usize,
13 pub min_count: usize,
14 pub entropy_threshold: f64,
15 pub impurity_method: fn(&String, &Dataset) -> f64
16}
17
18impl TreeConfig {
19 pub fn new() -> TreeConfig {
21 return TreeConfig {
22 decision: "category".to_string(),
23 max_depth: 70,
24 min_count: 1,
25 entropy_threshold: 0.01,
26 impurity_method: entropy
27 };
28 }
29
30 pub fn new_gini() -> TreeConfig {
31 return TreeConfig {
32 decision: "category".to_string(),
33 max_depth: 70,
34 min_count: 1,
35 entropy_threshold: 0.01,
36 impurity_method: gini
37 };
38 }
39}
40
41impl Clone for TreeConfig {
42 fn clone(&self) -> TreeConfig {
43 TreeConfig {
44 decision: self.decision.clone(),
45 max_depth: 70,
46 min_count: 1,
47 entropy_threshold: 0.01,
48 impurity_method: self.impurity_method
49 }
50 }
51}
52
/// A single attribute value, stored as text.
///
/// Equality, hashing and cloning all operate on `data`, so the standard
/// derives are used instead of hand-written impls.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Value {
    /// Textual content of the value.
    pub data: String,
}

impl fmt::Debug for Value {
    /// Formats as `Value { data: <text> }`; the text is written without the
    /// quotes a derived `Debug` would add, so this impl stays hand-written.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "Value {{ data: {} }}", self.data)
    }
}
85
86struct CacheEntry {
87 attribute: String,
88 value: Value,
89}
90
91impl Eq for CacheEntry {}
92
93impl PartialEq for CacheEntry {
94 fn eq(&self, other: &CacheEntry) -> bool {
95 self.attribute == other.attribute && self.value == other.value
96 }
97}
98
99impl Hash for CacheEntry {
100 fn hash<H: Hasher>(&self, state: &mut H) {
101 self.attribute.hash(state);
102 self.value.hash(state);
103 }
104}
105
106pub type Item = HashMap<String, Value>;
108pub type Dataset = Vec<Item>;
110
111struct Split {
112 gain: f64,
113 true_branch: Dataset,
114 false_branch: Dataset,
115 attribute: String,
116 pivot: Value,
117}
118
119fn unique_values<'a>(attribute: &String, data: &'a Vec<Item>) -> HashMap<&'a Value, usize> {
121 let mut counter: HashMap<&Value, usize> = HashMap::new();
122 for mut item in data.into_iter() {
123 let value = item.get(attribute);
124 match value {
125 Some(v) => {
126 let current = counter.entry(v).or_insert(0);
127 *current += 1;
128 }
129 None => {}
130 }
131 }
132 return counter;
133}
134
135fn value_frequency(attribute: &String, data: &Dataset) -> Option<Value> {
137 let unique = unique_values(attribute, data);
138 let mut most_frequent_count = 0;
139 let mut most_frequent_value: Option<Value> = None;
140 for (value, count) in unique.into_iter() {
141 if count > most_frequent_count {
142 let _v = value.clone();
143 most_frequent_count = count;
144 most_frequent_value = Some(_v);
145 }
146 }
147 return most_frequent_value;
148}
149
150fn calculate_split(attribute: &String, pivot: &Value, data: &Dataset) -> Split {
151 let mut true_branch = Dataset::new();
152 let mut false_branch = Dataset::new();
153
154 for item in data.into_iter() {
155 let value = item.get(attribute);
156 match value {
157 Some(v) => {
158 if v == pivot {
159 true_branch.push(item.clone());
160 } else {
161 false_branch.push(item.clone());
162 }
163 }
164 None => {}
165 }
166 }
167
168 return Split {
169 gain: 0.0,
170 true_branch,
171 false_branch,
172 attribute: "category".to_string(),
173 pivot: Value {
174 data: "".to_string(),
175 },
176 };
177}
178
179fn entropy(attribute: &String, data: &Dataset) -> f64 {
180 let counter = unique_values(attribute, data);
181 let size = data.len() as f64;
182 let mut impurity = 0.0;
183 for (_, count) in counter {
184 let p = count as f64 / size;
185 impurity += -p * p.log2();
186 }
187 return impurity;
188}
189
190fn gini(attribute: &String, data: &Dataset) -> f64 {
191 let counter = unique_values(attribute, data);
192 let size = data.len() as f64;
193 let mut impurity = 1.0;
194 for (_, count) in counter {
195 let p = count as f64 / size;
196 impurity += -p * p;
197 }
198 return impurity;
199}
200
201pub struct DecisionTree {
203 decision: Option<Value>,
204 true_branch: Option<Box<DecisionTree>>,
205 false_branch: Option<Box<DecisionTree>>,
206 attribute: Option<String>,
207 pivot: Option<Value>,
208}
209
210impl DecisionTree {
211 pub fn build(
214 _attribute: String,
215 config: &TreeConfig,
216 data: &mut Dataset,
217 ) -> Option<Box<DecisionTree>> {
218 let data_size = data.len();
219
220 if config.max_depth == 0 || data_size <= config.min_count {
221 return Some(Box::new(DecisionTree {
222 decision: value_frequency(&config.decision, &data),
223 true_branch: None,
224 false_branch: None,
225 attribute: None,
226 pivot: None,
227 }));
228 }
229
230 let _impurity = (config.impurity_method)(&_attribute, data);
231
232 if _impurity <= config.entropy_threshold {
233 return Some(Box::new(DecisionTree {
234 decision: value_frequency(&config.decision, &data),
235 true_branch: None,
236 false_branch: None,
237 attribute: None,
238 pivot: None,
239 }));
240 }
241
242 let mut cache: HashSet<CacheEntry> = HashSet::new();
243
244 let _data = data.clone();
245
246 let mut best_split = Split {
247 gain: 0.0,
248 true_branch: Dataset::new(),
249 false_branch: Dataset::new(),
250 attribute: "category".to_string(),
251 pivot: Value {
252 data: "".to_string(),
253 },
254 };
255
256 for item in _data {
257 print!("Item: {:?}\n", item);
258
259 for attribute in item.keys() {
260 print!("\tAttribute: {:?}\n", attribute);
261
262 if *attribute == config.decision {
263 continue;
264 }
265
266 let pivot = item.get(attribute).unwrap();
267
268 let cache_entry = CacheEntry {
269 attribute: attribute.clone(),
270 value: pivot.clone(),
271 };
272
273 if cache.contains(&cache_entry) {
274 continue;
275 }
276
277 cache.insert(cache_entry);
278
279 let split = calculate_split(attribute, pivot, &data);
280 print!("\t\tdata = {:?}", data);
281 let _true_branch_entropy = entropy(attribute, &split.true_branch);
282 let _false_branch_entropy = entropy(attribute, &split.false_branch);
283 print!(
284 "\tE(t) = {:?}, E(f) = {:?}\n",
285 _true_branch_entropy, _false_branch_entropy
286 );
287
288 let new_entropy = (_true_branch_entropy * split.true_branch.len() as f64
289 + _false_branch_entropy * split.false_branch.len() as f64)
290 / (data_size as f64);
291
292 let gain = _impurity - new_entropy;
293
294 if gain > best_split.gain {
295 best_split = split;
296 best_split.gain = gain;
297 best_split.attribute = attribute.clone();
298 best_split.pivot = pivot.clone();
299 }
300 }
301 }
302
303 if best_split.gain > 0.0 {
304 let max_depth = config.max_depth - 1;
305 let mut true_branch_config = config.clone();
306 true_branch_config.max_depth = max_depth;
307 let mut false_branch_config = config.clone();
308 false_branch_config.max_depth = max_depth;
309 let tree = Some(Box::new(DecisionTree {
310 decision: None,
311 true_branch: DecisionTree::build(
312 _attribute.clone(),
313 &true_branch_config,
314 &mut best_split.true_branch,
315 ),
316 false_branch: DecisionTree::build(
317 _attribute.clone(),
318 &false_branch_config,
319 &mut best_split.false_branch,
320 ),
321 attribute: Some(best_split.attribute.clone()),
322 pivot: Some(best_split.pivot.clone()),
323 }));
324 return tree;
325 } else {
326 return Some(Box::new(DecisionTree {
327 decision: value_frequency(&config.decision, &data),
328 true_branch: None,
329 false_branch: None,
330 attribute: None,
331 pivot: None,
332 }));
333 }
334 }
335
336 pub fn predict(_tree: Option<Box<DecisionTree>>, item: Item) -> Option<Value> {
338 let mut tree = _tree;
339
340 loop {
341 if tree.is_some() {
342 let t = tree.unwrap();
343 let decision = t.decision.clone();
344 if decision.is_some() {
345 return decision;
346 } else {
347 let attribute = t.attribute.clone().unwrap();
348 let value: Option<&Value> = item.get(&attribute);
349 let pivot = t.pivot.clone();
350
351 if value.is_some() && pivot.is_some() && *value.unwrap() == pivot.unwrap() {
352 tree = t.true_branch;
353 } else {
354 tree = t.false_branch;
355 }
356 }
357 }
358 }
359 }
360}
361
362fn sample_dataset(data: &Dataset, size: usize) -> Dataset {
363 let mut rng = rand::thread_rng();
364 let mut shuffled = data.clone();
365 shuffled.shuffle(&mut rng);
366 shuffled.resize(size, Item::new());
367 return shuffled;
368}
369
370pub struct RandomForest {
372 trees: Vec<Option<Box<DecisionTree>>>
373}
374
375impl RandomForest {
376
377 pub fn build(attribute: String, config: TreeConfig, data: &Dataset, num_trees: usize, subsample_size: usize) -> RandomForest {
380
381 let mut trees:Vec<Option<Box<DecisionTree>>> = Vec::new();
382 for n in 0..num_trees {
383 let mut subsample = sample_dataset(data, subsample_size);
384 let tree_config = config.clone();
385 let tree = DecisionTree::build(attribute.clone(), &tree_config, &mut subsample);
386 trees.push(tree);
387 }
388 return RandomForest {
389 trees
390 }
391 }
392
393 pub fn predict(rf: RandomForest, item: Item) -> HashMap<Value, usize> {
395 let mut results:HashMap<Value, usize> = HashMap::new();
396 for tree in rf.trees {
397 let value = DecisionTree::predict(tree, item.clone());
398 match value {
399 Some(v) => {
400 let count = results.entry(v).or_insert(0);
401 *count += 1;
402 },
403 None => {}
404 }
405 }
406 return results;
407 }
408
409}
410
#[cfg(test)]
mod test_treeconfig {
    use super::*;

    /// `TreeConfig::new` must expose the stock decision name, depth and count.
    #[test]
    fn create_empty_defaults() {
        let TreeConfig {
            decision,
            max_depth,
            min_count,
            ..
        } = TreeConfig::new();
        assert_eq!(decision, "category");
        assert_eq!(max_depth, 70);
        assert_eq!(min_count, 1);
    }
}
423
#[cfg(test)]
mod test_dataset {
    use super::*;

    /// Wraps a string slice in a `Value`.
    fn value(text: &str) -> Value {
        Value {
            data: text.to_string(),
        }
    }

    /// Builds an item carrying a `lang` and a `typing` attribute.
    fn language(lang: &str, typing: &str) -> Item {
        let mut item = Item::new();
        item.insert("lang".to_string(), value(lang));
        item.insert("typing".to_string(), value(typing));
        item
    }

    /// Two `rust` items around one `python` item.
    fn languages() -> Dataset {
        vec![
            language("rust", "static"),
            language("python", "dynamic"),
            language("rust", "static"),
        ]
    }

    #[test]
    fn unique() {
        let dataset = languages();
        let counts = unique_values(&"lang".to_string(), &dataset);
        assert_eq!(counts.len(), 2);
        assert_eq!(counts.get(&value("rust")), Some(&2));
        assert_eq!(counts.get(&value("python")), Some(&1));
    }

    #[test]
    fn most_frequent() {
        let dataset = languages();
        let winner = value_frequency(&"lang".to_string(), &dataset);
        assert!(winner.is_some());
        assert_eq!(winner.unwrap(), value("rust"));
    }

    #[test]
    fn test_sample() {
        let dataset: Dataset = (1..10)
            .map(|id| {
                let mut item = Item::new();
                item.insert("id".to_string(), value(&id.to_string()));
                item
            })
            .collect();
        let sampled = sample_dataset(&dataset, 5);
        assert_eq!(sampled.len(), 5);
    }
}
563
#[cfg(test)]
mod test_cacheentry {
    use super::*;

    /// Builds a cache entry for `attribute` whose value is the text `"data"`.
    fn entry_for(attribute: &str) -> CacheEntry {
        CacheEntry {
            attribute: attribute.to_string(),
            value: Value {
                data: "data".to_string(),
            },
        }
    }

    /// Number of distinct entries left after inserting both into a set.
    fn distinct_count(first: CacheEntry, second: CacheEntry) -> usize {
        let mut set: HashSet<CacheEntry> = HashSet::new();
        set.insert(first);
        set.insert(second);
        set.len()
    }

    #[test]
    fn equal_entries_identity() {
        assert!(entry_for("attribute") == entry_for("attribute"));
    }

    #[test]
    fn equal_entries_hash() {
        assert_eq!(
            distinct_count(entry_for("attribute"), entry_for("attribute")),
            1
        );
    }

    #[test]
    fn diff_entries_identity() {
        assert!(entry_for("attribute") != entry_for("other attribute"));
    }

    #[test]
    fn diff_entries_hash() {
        assert_eq!(
            distinct_count(entry_for("attribute"), entry_for("other attribute")),
            2
        );
    }
}
642
#[cfg(test)]
mod test_decisiontree {
    use super::*;

    /// Wraps a string slice in a `Value`.
    fn value(text: &str) -> Value {
        Value {
            data: text.to_string(),
        }
    }

    /// Builds an item carrying a `lang` and a `typing` attribute.
    fn language(lang: &str, typing: &str) -> Item {
        let mut item = Item::new();
        item.insert("lang".to_string(), value(lang));
        item.insert("typing".to_string(), value(typing));
        item
    }

    /// Default configuration that predicts the `lang` attribute.
    fn lang_config() -> TreeConfig {
        let mut config = TreeConfig::new();
        config.decision = "lang".to_string();
        config
    }

    /// A dataset no larger than `min_count` becomes a leaf that decides on
    /// its only value. The item has to be pushed for that to be exercised.
    #[test]
    fn decision_less_mincount() {
        let mut dataset = vec![language("rust", "static")];
        let tree = DecisionTree::build("lang".to_string(), &lang_config(), &mut dataset);
        let root = tree.unwrap();
        assert_eq!(root.decision, Some(value("rust")));
    }

    /// A larger, impure dataset yields an inner node, which has no decision.
    #[test]
    fn decision_more_mincount() {
        let mut dataset = vec![
            language("rust", "static"),
            language("python", "dynamic"),
            language("rust", "static"),
        ];
        let tree = DecisionTree::build("lang".to_string(), &lang_config(), &mut dataset);
        let root = tree.unwrap();
        assert_eq!(root.decision, None);
    }

    #[test]
    fn decision_prediction() {
        let mut dataset = vec![language("rust", "static"), language("python", "dynamic")];
        let tree = DecisionTree::build("lang".to_string(), &lang_config(), &mut dataset);
        let mut question = Item::new();
        question.insert("typing".to_string(), value("dynamic"));
        let answer = DecisionTree::predict(tree, question);
        assert_eq!(answer.unwrap().data, "python");
    }
}
769
#[cfg(test)]
mod test_randomforest {
    use super::*;

    /// Wraps a string slice in a `Value`.
    fn value(text: &str) -> Value {
        Value {
            data: text.to_string(),
        }
    }

    /// Builds an item carrying a `lang` and a `typing` attribute.
    fn language(lang: &str, typing: &str) -> Item {
        let mut item = Item::new();
        item.insert("lang".to_string(), value(lang));
        item.insert("typing".to_string(), value(typing));
        item
    }

    /// Every tree sees all three items, so a statically typed question can
    /// only be answered with one of the two statically typed languages, and
    /// each of the 100 trees casts exactly one vote.
    #[test]
    fn forest_prediction() {
        let dataset = vec![
            language("rust", "static"),
            language("python", "dynamic"),
            language("haskell", "static"),
        ];
        let mut config = TreeConfig::new();
        config.decision = "lang".to_string();
        let forest = RandomForest::build("lang".to_string(), config, &dataset, 100, 3);

        let mut question = Item::new();
        question.insert("typing".to_string(), value("static"));
        let answer = RandomForest::predict(forest, question);

        assert_eq!(answer.values().sum::<usize>(), 100);
        assert!(!answer.contains_key(&value("python")));
        assert!(answer
            .keys()
            .all(|vote| *vote == value("rust") || *vote == value("haskell")));
    }
}