randomforests/
lib.rs

1extern crate rand;
2
3use std::cmp::Eq;
4use std::collections::{HashMap, HashSet};
5use std::fmt;
6use std::hash::{Hash, Hasher};
7use rand::prelude::*;
8
9/// `DecisionTree` creation parameters.
10pub struct TreeConfig {
11    pub decision: String,
12    pub max_depth: usize,
13    pub min_count: usize,
14    pub entropy_threshold: f64,
15    pub impurity_method: fn(&String, &Dataset) -> f64
16}
17
18impl TreeConfig {
19    /// Create a default tree configuration.
20    pub fn new() -> TreeConfig {
21        return TreeConfig {
22            decision: "category".to_string(),
23            max_depth: 70,
24            min_count: 1,
25            entropy_threshold: 0.01,
26            impurity_method: entropy
27        };
28    }
29
30    pub fn new_gini() -> TreeConfig {
31        return TreeConfig {
32            decision: "category".to_string(),
33            max_depth: 70,
34            min_count: 1,
35            entropy_threshold: 0.01,
36            impurity_method: gini
37        };
38    }
39}
40
41impl Clone for TreeConfig {
42    fn clone(&self) -> TreeConfig {
43        TreeConfig {
44            decision: self.decision.clone(),
45            max_depth: 70,
46            min_count: 1,
47            entropy_threshold: 0.01,
48            impurity_method: self.impurity_method
49        }
50    }
51}
52
/// Value encapsulates an attribute's value as a data `String`.
///
/// Cloning, equality and hashing all delegate to the wrapped `data` string.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Value {
    pub data: String,
}

impl fmt::Debug for Value {
    /// Render as `Value { data: <data> }`, with `data` written unquoted.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let data = &self.data;
        write!(f, "Value {{ data: {} }}", data)
    }
}
85
86struct CacheEntry {
87    attribute: String,
88    value: Value,
89}
90
91impl Eq for CacheEntry {}
92
93impl PartialEq for CacheEntry {
94    fn eq(&self, other: &CacheEntry) -> bool {
95        self.attribute == other.attribute && self.value == other.value
96    }
97}
98
99impl Hash for CacheEntry {
100    fn hash<H: Hasher>(&self, state: &mut H) {
101        self.attribute.hash(state);
102        self.value.hash(state);
103    }
104}
105
106/// A Map representing a collection of attributes (as keys) with their respective values (as `Value`s).
107pub type Item = HashMap<String, Value>;
108/// A collection of items represented as a `Vec<Item>`.
109pub type Dataset = Vec<Item>;
110
111struct Split {
112    gain: f64,
113    true_branch: Dataset,
114    false_branch: Dataset,
115    attribute: String,
116    pivot: Value,
117}
118
119/// find unique value count for a certain attribute.
120fn unique_values<'a>(attribute: &String, data: &'a Vec<Item>) -> HashMap<&'a Value, usize> {
121    let mut counter: HashMap<&Value, usize> = HashMap::new();
122    for mut item in data.into_iter() {
123        let value = item.get(attribute);
124        match value {
125            Some(v) => {
126                let current = counter.entry(v).or_insert(0);
127                *current += 1;
128            }
129            None => {}
130        }
131    }
132    return counter;
133}
134
135/// find the most frequent value for a certain attribute.
136fn value_frequency(attribute: &String, data: &Dataset) -> Option<Value> {
137    let unique = unique_values(attribute, data);
138    let mut most_frequent_count = 0;
139    let mut most_frequent_value: Option<Value> = None;
140    for (value, count) in unique.into_iter() {
141        if count > most_frequent_count {
142            let _v = value.clone();
143            most_frequent_count = count;
144            most_frequent_value = Some(_v);
145        }
146    }
147    return most_frequent_value;
148}
149
150fn calculate_split(attribute: &String, pivot: &Value, data: &Dataset) -> Split {
151    let mut true_branch = Dataset::new();
152    let mut false_branch = Dataset::new();
153
154    for item in data.into_iter() {
155        let value = item.get(attribute);
156        match value {
157            Some(v) => {
158                if v == pivot {
159                    true_branch.push(item.clone());
160                } else {
161                    false_branch.push(item.clone());
162                }
163            }
164            None => {}
165        }
166    }
167
168    return Split {
169        gain: 0.0,
170        true_branch,
171        false_branch,
172        attribute: "category".to_string(),
173        pivot: Value {
174            data: "".to_string(),
175        },
176    };
177}
178
179fn entropy(attribute: &String, data: &Dataset) -> f64 {
180    let counter = unique_values(attribute, data);
181    let size = data.len() as f64;
182    let mut impurity = 0.0;
183    for (_, count) in counter {
184        let p = count as f64 / size;
185        impurity += -p * p.log2();
186    }
187    return impurity;
188}
189
190fn gini(attribute: &String, data: &Dataset) -> f64 {
191    let counter = unique_values(attribute, data);
192    let size = data.len() as f64;
193    let mut impurity = 1.0;
194    for (_, count) in counter {
195        let p = count as f64 / size;
196        impurity += -p * p;
197    }
198    return impurity;
199}
200
201/// Implements a `DecisionTree`.
202pub struct DecisionTree {
203    decision: Option<Value>,
204    true_branch: Option<Box<DecisionTree>>,
205    false_branch: Option<Box<DecisionTree>>,
206    attribute: Option<String>,
207    pivot: Option<Value>,
208}
209
210impl DecisionTree {
211    /// Build a decision tree with a `String` attribute, a `TreeConfig` and a dataset (a `Vec<Item>`).
212    /// Return either a `Some(DecisionTree)`, if successful, or `None` if not.
213    pub fn build(
214        _attribute: String,
215        config: &TreeConfig,
216        data: &mut Dataset,
217    ) -> Option<Box<DecisionTree>> {
218        let data_size = data.len();
219
220        if config.max_depth == 0 || data_size <= config.min_count {
221            return Some(Box::new(DecisionTree {
222                decision: value_frequency(&config.decision, &data),
223                true_branch: None,
224                false_branch: None,
225                attribute: None,
226                pivot: None,
227            }));
228        }
229
230        let _impurity = (config.impurity_method)(&_attribute, data);
231
232        if _impurity <= config.entropy_threshold {
233            return Some(Box::new(DecisionTree {
234                decision: value_frequency(&config.decision, &data),
235                true_branch: None,
236                false_branch: None,
237                attribute: None,
238                pivot: None,
239            }));
240        }
241
242        let mut cache: HashSet<CacheEntry> = HashSet::new();
243
244        let _data = data.clone();
245
246        let mut best_split = Split {
247            gain: 0.0,
248            true_branch: Dataset::new(),
249            false_branch: Dataset::new(),
250            attribute: "category".to_string(),
251            pivot: Value {
252                data: "".to_string(),
253            },
254        };
255
256        for item in _data {
257            print!("Item: {:?}\n", item);
258
259            for attribute in item.keys() {
260                print!("\tAttribute: {:?}\n", attribute);
261
262                if *attribute == config.decision {
263                    continue;
264                }
265
266                let pivot = item.get(attribute).unwrap();
267
268                let cache_entry = CacheEntry {
269                    attribute: attribute.clone(),
270                    value: pivot.clone(),
271                };
272
273                if cache.contains(&cache_entry) {
274                    continue;
275                }
276
277                cache.insert(cache_entry);
278
279                let split = calculate_split(attribute, pivot, &data);
280                print!("\t\tdata = {:?}", data);
281                let _true_branch_entropy = entropy(attribute, &split.true_branch);
282                let _false_branch_entropy = entropy(attribute, &split.false_branch);
283                print!(
284                    "\tE(t) = {:?}, E(f) = {:?}\n",
285                    _true_branch_entropy, _false_branch_entropy
286                );
287
288                let new_entropy = (_true_branch_entropy * split.true_branch.len() as f64
289                    + _false_branch_entropy * split.false_branch.len() as f64)
290                    / (data_size as f64);
291
292                let gain = _impurity - new_entropy;
293
294                if gain > best_split.gain {
295                    best_split = split;
296                    best_split.gain = gain;
297                    best_split.attribute = attribute.clone();
298                    best_split.pivot = pivot.clone();
299                }
300            }
301        }
302
303        if best_split.gain > 0.0 {
304            let max_depth = config.max_depth - 1;
305            let mut true_branch_config = config.clone();
306            true_branch_config.max_depth = max_depth;
307            let mut false_branch_config = config.clone();
308            false_branch_config.max_depth = max_depth;
309            let tree = Some(Box::new(DecisionTree {
310                decision: None,
311                true_branch: DecisionTree::build(
312                    _attribute.clone(),
313                    &true_branch_config,
314                    &mut best_split.true_branch,
315                ),
316                false_branch: DecisionTree::build(
317                    _attribute.clone(),
318                    &false_branch_config,
319                    &mut best_split.false_branch,
320                ),
321                attribute: Some(best_split.attribute.clone()),
322                pivot: Some(best_split.pivot.clone()),
323            }));
324            return tree;
325        } else {
326            return Some(Box::new(DecisionTree {
327                decision: value_frequency(&config.decision, &data),
328                true_branch: None,
329                false_branch: None,
330                attribute: None,
331                pivot: None,
332            }));
333        }
334    }
335
336    /// Return the `DecisionTree` prediction for a question expressed as an `Item`.
337    pub fn predict(_tree: Option<Box<DecisionTree>>, item: Item) -> Option<Value> {
338        let mut tree = _tree;
339
340        loop {
341            if tree.is_some() {
342                let t = tree.unwrap();
343                let decision = t.decision.clone();
344                if decision.is_some() {
345                    return decision;
346                } else {
347                    let attribute = t.attribute.clone().unwrap();
348                    let value: Option<&Value> = item.get(&attribute);
349                    let pivot = t.pivot.clone();
350
351                    if value.is_some() && pivot.is_some() && *value.unwrap() == pivot.unwrap() {
352                        tree = t.true_branch;
353                    } else {
354                        tree = t.false_branch;
355                    }
356                }
357            }
358        }
359    }
360}
361
362fn sample_dataset(data: &Dataset, size: usize) -> Dataset {
363    let mut rng = rand::thread_rng();
364    let mut shuffled = data.clone();
365    shuffled.shuffle(&mut rng);
366    shuffled.resize(size, Item::new());
367    return shuffled;
368}
369
370/// Implements an ensemble of `DecisionTree`s.
371pub struct RandomForest {
372    trees: Vec<Option<Box<DecisionTree>>>
373}
374
375impl RandomForest {
376
377    /// Builds an ensemble of `DecisionTree` by passing the data as a `&Dataset`, the number of trees and
378    /// the data's subsample size.
379    pub fn build(attribute: String, config: TreeConfig, data: &Dataset, num_trees: usize, subsample_size: usize) -> RandomForest {
380
381        let mut trees:Vec<Option<Box<DecisionTree>>> = Vec::new();
382        for n in 0..num_trees {
383            let mut subsample = sample_dataset(data, subsample_size);
384            let tree_config = config.clone();
385            let tree = DecisionTree::build(attribute.clone(), &tree_config, &mut subsample);
386            trees.push(tree);
387        }
388        return RandomForest {
389            trees
390        }
391    }
392
393    /// Return the `RandomForest` prediction for a question expressed as an `Item`.
394    pub fn predict(rf: RandomForest, item: Item) -> HashMap<Value, usize> {
395        let mut results:HashMap<Value, usize> = HashMap::new();
396        for tree in rf.trees {
397            let value = DecisionTree::predict(tree, item.clone());
398            match value {
399                Some(v) => {
400                    let count = results.entry(v).or_insert(0);
401                    *count += 1;
402                },
403                None => {}
404            }
405        }
406        return results;
407    }
408
409}
410
#[cfg(test)]
mod test_treeconfig {
    use super::*;

    /// `TreeConfig::new` exposes the documented default decision, depth and count.
    #[test]
    fn create_empty_defaults() {
        let TreeConfig {
            decision,
            max_depth,
            min_count,
            ..
        } = TreeConfig::new();
        assert_eq!(decision, "category".to_string());
        assert_eq!(max_depth, 70);
        assert_eq!(min_count, 1);
    }
}
423
#[cfg(test)]
mod test_dataset {
    use super::*;

    /// Build an `Item` from `(attribute, value)` pairs.
    fn item(pairs: &[(&str, &str)]) -> Item {
        pairs
            .iter()
            .map(|&(attribute, data)| {
                (
                    attribute.to_string(),
                    Value {
                        data: data.to_string(),
                    },
                )
            })
            .collect()
    }

    /// Three rows: two `rust`/`static` and one `python`/`dynamic`.
    fn languages() -> Dataset {
        vec![
            item(&[("lang", "rust"), ("typing", "static")]),
            item(&[("lang", "python"), ("typing", "dynamic")]),
            item(&[("lang", "rust"), ("typing", "static")]),
        ]
    }

    #[test]
    fn unique() {
        let dataset = languages();
        let unique = unique_values(&"lang".to_string(), &dataset);
        assert_eq!(unique.len(), 2);
        assert_eq!(
            unique.get(&Value {
                data: "rust".to_string()
            }),
            Some(&2)
        );
        assert_eq!(
            unique.get(&Value {
                data: "python".to_string()
            }),
            Some(&1)
        );
    }

    #[test]
    fn most_frequent() {
        let dataset = languages();
        let most = value_frequency(&"lang".to_string(), &dataset);
        assert_eq!(
            most,
            Some(Value {
                data: "rust".to_string()
            })
        );
    }

    #[test]
    fn test_sample() {
        let dataset: Dataset = (1..10).map(|i| item(&[("id", &i.to_string())])).collect();
        let shuffled = sample_dataset(&dataset, 5);
        assert_eq!(shuffled.len(), 5);
    }
}
563
#[cfg(test)]
mod test_cacheentry {
    use super::*;

    /// Build a `CacheEntry` from an attribute name and a value string.
    fn entry(attribute: &str, data: &str) -> CacheEntry {
        CacheEntry {
            attribute: attribute.to_string(),
            value: Value {
                data: data.to_string(),
            },
        }
    }

    #[test]
    fn equal_entries_identity() {
        assert!(entry("attribute", "data") == entry("attribute", "data"));
    }

    #[test]
    fn equal_entries_hash() {
        let set: HashSet<CacheEntry> = vec![entry("attribute", "data"), entry("attribute", "data")]
            .into_iter()
            .collect();
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn diff_entries_identity() {
        assert!(entry("attribute", "data") != entry("other attribute", "data"));
    }

    #[test]
    fn diff_entries_hash() {
        let set: HashSet<CacheEntry> = vec![
            entry("attribute", "data"),
            entry("other attribute", "data"),
        ]
        .into_iter()
        .collect();
        assert_eq!(set.len(), 2);
    }
}
642
#[cfg(test)]
mod test_decisiontree {
    use super::*;

    /// Build an `Item` from `(attribute, value)` pairs.
    fn item(pairs: &[(&str, &str)]) -> Item {
        pairs
            .iter()
            .map(|&(attribute, data)| {
                (
                    attribute.to_string(),
                    Value {
                        data: data.to_string(),
                    },
                )
            })
            .collect()
    }

    /// Default configuration predicting the `lang` attribute.
    fn lang_config() -> TreeConfig {
        let mut config = TreeConfig::new();
        config.decision = "lang".to_string();
        config
    }

    /// A dataset no larger than `min_count` becomes a single leaf: without items
    /// there is no decision; with one item the decision is that item's value.
    #[test]
    fn decision_less_mincount() {
        let config = lang_config();

        let mut empty = Dataset::new();
        let tree = DecisionTree::build("lang".to_string(), &config, &mut empty).unwrap();
        assert_eq!(tree.decision, None);

        let mut single = vec![item(&[("lang", "rust"), ("typing", "static")])];
        let tree = DecisionTree::build("lang".to_string(), &config, &mut single).unwrap();
        assert_eq!(
            tree.decision,
            Some(Value {
                data: "rust".to_string()
            })
        );
        assert!(tree.true_branch.is_none() && tree.false_branch.is_none());
    }

    /// With more items than `min_count` and mixed languages, the root splits on
    /// `typing` instead of deciding.
    #[test]
    fn decision_more_mincount() {
        let mut dataset = vec![
            item(&[("lang", "rust"), ("typing", "static")]),
            item(&[("lang", "python"), ("typing", "dynamic")]),
            item(&[("lang", "rust"), ("typing", "static")]),
        ];
        let config = lang_config();
        let t = DecisionTree::build("lang".to_string(), &config, &mut dataset).unwrap();
        assert_eq!(t.decision, None);
        assert_eq!(t.attribute, Some("typing".to_string()));
    }

    #[test]
    fn decision_prediction() {
        let config = lang_config();
        // `predict` consumes the tree, so build a fresh one per question.
        let build = || {
            let mut dataset = vec![
                item(&[("lang", "rust"), ("typing", "static")]),
                item(&[("lang", "python"), ("typing", "dynamic")]),
            ];
            DecisionTree::build("lang".to_string(), &config, &mut dataset)
        };

        let answer = DecisionTree::predict(build(), item(&[("typing", "dynamic")]));
        assert_eq!(answer.unwrap().data, "python");

        let answer = DecisionTree::predict(build(), item(&[("typing", "static")]));
        assert_eq!(answer.unwrap().data, "rust");
    }
}
769
#[cfg(test)]
mod test_randomforest {
    use super::*;

    /// Build an `Item` from `(attribute, value)` pairs.
    fn item(pairs: &[(&str, &str)]) -> Item {
        pairs
            .iter()
            .map(|&(attribute, data)| {
                (
                    attribute.to_string(),
                    Value {
                        data: data.to_string(),
                    },
                )
            })
            .collect()
    }

    #[test]
    fn forest_prediction() {
        let dataset: Dataset = vec![
            item(&[("lang", "rust"), ("typing", "static")]),
            item(&[("lang", "python"), ("typing", "dynamic")]),
            item(&[("lang", "haskell"), ("typing", "static")]),
        ];
        let mut config = TreeConfig::new();
        config.decision = "lang".to_string();
        let forest = RandomForest::build("lang".to_string(), config, &dataset, 100, 3);

        let answer = RandomForest::predict(forest, item(&[("typing", "static")]));

        // Every tree separates the dynamically typed language from the static
        // ones, so each of the 100 trees votes for a statically typed language.
        assert_eq!(answer.values().sum::<usize>(), 100);
        assert!(!answer.contains_key(&Value {
            data: "python".to_string()
        }));
    }
}