extern crate rand;
use std::cmp::Eq;
use std::collections::{HashMap, HashSet};
use std::fmt;
use std::hash::{Hash, Hasher};
use rand::prelude::*;
/// Settings that control how a `DecisionTree` is grown.
pub struct TreeConfig {
/// Name of the attribute whose value the leaves predict; it is never used as a split candidate.
pub decision: String,
/// Remaining depth budget; a node built with `max_depth == 0` becomes a leaf.
pub max_depth: usize,
/// A node whose dataset holds at most this many items becomes a leaf.
pub min_count: usize,
/// A node whose impurity is at or below this value becomes a leaf.
pub entropy_threshold: f64,
/// Impurity measure, called as `(attribute, dataset)`; the constructors install `entropy` or `gini`.
pub impurity_method: fn(&String, &Dataset) -> f64
}
impl TreeConfig {
pub fn new() -> TreeConfig {
return TreeConfig {
decision: "category".to_string(),
max_depth: 70,
min_count: 1,
entropy_threshold: 0.01,
impurity_method: entropy
};
}
pub fn new_gini() -> TreeConfig {
return TreeConfig {
decision: "category".to_string(),
max_depth: 70,
min_count: 1,
entropy_threshold: 0.01,
impurity_method: gini
};
}
}
/// Field-by-field copy of a configuration.
///
/// Every setting is taken from `self`, so a clone keeps the caller's
/// `max_depth`, `min_count` and `entropy_threshold` instead of falling back
/// to the defaults.
impl Clone for TreeConfig {
    fn clone(&self) -> TreeConfig {
        TreeConfig {
            decision: self.decision.clone(),
            max_depth: self.max_depth,
            min_count: self.min_count,
            entropy_threshold: self.entropy_threshold,
            // Function pointers are `Copy`, so this copies the pointer itself.
            impurity_method: self.impurity_method,
        }
    }
}
/// A single attribute value, stored as text.
///
/// Equality, hashing and cloning all operate on `data` alone.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Value {
    pub data: String,
}

/// Debug output prints the text without surrounding quotes,
/// e.g. `Value { data: rust }`.
impl fmt::Debug for Value {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("Value { data: ")?;
        f.write_str(&self.data)?;
        f.write_str(" }")
    }
}
struct CacheEntry {
attribute: String,
value: Value,
}
impl Eq for CacheEntry {}
impl PartialEq for CacheEntry {
fn eq(&self, other: &CacheEntry) -> bool {
self.attribute == other.attribute && self.value == other.value
}
}
impl Hash for CacheEntry {
fn hash<H: Hasher>(&self, state: &mut H) {
self.attribute.hash(state);
self.value.hash(state);
}
}
/// One record: a map from attribute name to that attribute's `Value`.
pub type Item = HashMap<String, Value>;
/// An ordered collection of records.
pub type Dataset = Vec<Item>;
/// A candidate binary partition of a dataset on one attribute/value pair.
struct Split {
/// Impurity reduction credited to this split; `DecisionTree::build` assigns it for the best candidate.
gain: f64,
/// Copies of the items whose value for the tested attribute equals the pivot.
true_branch: Dataset,
/// Copies of the items that did not match the pivot.
false_branch: Dataset,
/// Name of the tested attribute; `DecisionTree::build` assigns it for the best candidate.
attribute: String,
/// Value the attribute is compared against; `DecisionTree::build` assigns it for the best candidate.
pivot: Value,
}
/// Counts how many items in `data` carry each distinct value of `attribute`.
///
/// Items that do not have the attribute are ignored. The keys of the
/// returned map borrow the values stored inside `data`.
fn unique_values<'a>(attribute: &String, data: &'a Vec<Item>) -> HashMap<&'a Value, usize> {
    data.iter()
        .filter_map(|item| item.get(attribute))
        .fold(HashMap::new(), |mut tally, value| {
            *tally.entry(value).or_insert(0) += 1;
            tally
        })
}
/// Returns the most frequent value of `attribute` in `data`.
///
/// Returns `None` when no item carries the attribute (including when `data`
/// is empty). When several values share the highest count, the one whose
/// text sorts first is returned, so the result does not depend on hash-map
/// iteration order and is stable between runs.
fn value_frequency(attribute: &String, data: &Dataset) -> Option<Value> {
    unique_values(attribute, data)
        .into_iter()
        .max_by(|(left_value, left_count), (right_value, right_count)| {
            left_count
                .cmp(right_count)
                // Reversed text comparison: on equal counts the smaller text wins.
                .then_with(|| right_value.data.cmp(&left_value.data))
        })
        .map(|(value, _)| value.clone())
}
/// Partitions `data` on whether an item's `attribute` equals `pivot`.
///
/// Matching items are copied into `true_branch`; every other item goes into
/// `false_branch`, including items that lack the attribute altogether. That
/// mirrors `DecisionTree::predict`, which sends an item without the
/// attribute down the false branch. Item order is preserved in both branches.
///
/// The returned split records the tested `attribute` and `pivot`; `gain`
/// starts at `0.0` and is left for the caller to fill in.
fn calculate_split(attribute: &String, pivot: &Value, data: &Dataset) -> Split {
    let (true_branch, false_branch): (Dataset, Dataset) = data
        .iter()
        .cloned()
        .partition(|item| item.get(attribute) == Some(pivot));
    Split {
        gain: 0.0,
        true_branch,
        false_branch,
        attribute: attribute.clone(),
        pivot: pivot.clone(),
    }
}
/// Shannon entropy (base 2) of the values `attribute` takes across `data`.
///
/// Each distinct value contributes `-p * log2(p)`, where `p` is its count
/// divided by the total number of items in `data`. Returns `0.0` when no
/// item carries the attribute.
fn entropy(attribute: &String, data: &Dataset) -> f64 {
    let total = data.len() as f64;
    unique_values(attribute, data)
        .into_iter()
        .map(|(_, occurrences)| occurrences as f64 / total)
        .fold(0.0, |sum, p| sum - p * p.log2())
}
/// Gini impurity of the values `attribute` takes across `data`.
///
/// Starts from `1.0` and subtracts `p * p` for each distinct value, where
/// `p` is its count divided by the total number of items in `data`. The
/// result is therefore `1.0` when no item carries the attribute.
fn gini(attribute: &String, data: &Dataset) -> f64 {
    let total = data.len() as f64;
    unique_values(attribute, data)
        .into_iter()
        .map(|(_, occurrences)| occurrences as f64 / total)
        .fold(1.0, |remaining, p| remaining - p * p)
}
/// A node of a binary decision tree; a node is either a leaf carrying a
/// decision or an internal node carrying an `attribute == pivot` test.
pub struct DecisionTree {
/// Predicted value at a leaf; `None` on internal nodes, and on a leaf whose data had no value for the decision attribute.
decision: Option<Value>,
/// Subtree followed when the item's value for `attribute` equals `pivot`.
true_branch: Option<Box<DecisionTree>>,
/// Subtree followed in every other case.
false_branch: Option<Box<DecisionTree>>,
/// Attribute tested at an internal node; `None` on leaves.
attribute: Option<String>,
/// Value the attribute is compared against at an internal node; `None` on leaves.
pivot: Option<Value>,
}
impl DecisionTree {
pub fn build(
_attribute: String,
config: &TreeConfig,
data: &mut Dataset,
) -> Option<Box<DecisionTree>> {
let data_size = data.len();
if config.max_depth == 0 || data_size <= config.min_count {
return Some(Box::new(DecisionTree {
decision: value_frequency(&config.decision, &data),
true_branch: None,
false_branch: None,
attribute: None,
pivot: None,
}));
}
let _impurity = (config.impurity_method)(&_attribute, data);
if _impurity <= config.entropy_threshold {
return Some(Box::new(DecisionTree {
decision: value_frequency(&config.decision, &data),
true_branch: None,
false_branch: None,
attribute: None,
pivot: None,
}));
}
let mut cache: HashSet<CacheEntry> = HashSet::new();
let _data = data.clone();
let mut best_split = Split {
gain: 0.0,
true_branch: Dataset::new(),
false_branch: Dataset::new(),
attribute: "category".to_string(),
pivot: Value {
data: "".to_string(),
},
};
for item in _data {
print!("Item: {:?}\n", item);
for attribute in item.keys() {
print!("\tAttribute: {:?}\n", attribute);
if *attribute == config.decision {
continue;
}
let pivot = item.get(attribute).unwrap();
let cache_entry = CacheEntry {
attribute: attribute.clone(),
value: pivot.clone(),
};
if cache.contains(&cache_entry) {
continue;
}
cache.insert(cache_entry);
let split = calculate_split(attribute, pivot, &data);
print!("\t\tdata = {:?}", data);
let _true_branch_entropy = entropy(attribute, &split.true_branch);
let _false_branch_entropy = entropy(attribute, &split.false_branch);
print!(
"\tE(t) = {:?}, E(f) = {:?}\n",
_true_branch_entropy, _false_branch_entropy
);
let new_entropy = (_true_branch_entropy * split.true_branch.len() as f64
+ _false_branch_entropy * split.false_branch.len() as f64)
/ (data_size as f64);
let gain = _impurity - new_entropy;
if gain > best_split.gain {
best_split = split;
best_split.gain = gain;
best_split.attribute = attribute.clone();
best_split.pivot = pivot.clone();
}
}
}
if best_split.gain > 0.0 {
let max_depth = config.max_depth - 1;
let mut true_branch_config = config.clone();
true_branch_config.max_depth = max_depth;
let mut false_branch_config = config.clone();
false_branch_config.max_depth = max_depth;
let tree = Some(Box::new(DecisionTree {
decision: None,
true_branch: DecisionTree::build(
_attribute.clone(),
&true_branch_config,
&mut best_split.true_branch,
),
false_branch: DecisionTree::build(
_attribute.clone(),
&false_branch_config,
&mut best_split.false_branch,
),
attribute: Some(best_split.attribute.clone()),
pivot: Some(best_split.pivot.clone()),
}));
return tree;
} else {
return Some(Box::new(DecisionTree {
decision: value_frequency(&config.decision, &data),
true_branch: None,
false_branch: None,
attribute: None,
pivot: None,
}));
}
}
pub fn predict(_tree: Option<Box<DecisionTree>>, item: Item) -> Option<Value> {
let mut tree = _tree;
loop {
if tree.is_some() {
let t = tree.unwrap();
let decision = t.decision.clone();
if decision.is_some() {
return decision;
} else {
let attribute = t.attribute.clone().unwrap();
let value: Option<&Value> = item.get(&attribute);
let pivot = t.pivot.clone();
if value.is_some() && pivot.is_some() && *value.unwrap() == pivot.unwrap() {
tree = t.true_branch;
} else {
tree = t.false_branch;
}
}
}
}
}
}
/// Draws a random sample of `data` without replacement.
///
/// A copy of `data` is shuffled and cut down to its first `size` items. When
/// `size` is larger than `data.len()` the whole shuffled copy is returned;
/// the sample is never padded with empty items.
fn sample_dataset(data: &Dataset, size: usize) -> Dataset {
    let mut rng = rand::thread_rng();
    let mut shuffled = data.clone();
    shuffled.shuffle(&mut rng);
    // `truncate` is a no-op when `size` is at least the current length.
    shuffled.truncate(size);
    shuffled
}
/// A collection of decision trees, each built from its own random subsample.
pub struct RandomForest {
// One entry per tree, stored exactly as returned by `DecisionTree::build`.
trees: Vec<Option<Box<DecisionTree>>>
}
impl RandomForest {
    /// Builds `num_trees` decision trees, each from a fresh random subsample
    /// of `data` drawn with `sample_dataset(data, subsample_size)`.
    ///
    /// * `attribute` - passed to `DecisionTree::build` as the attribute whose
    ///   impurity is measured.
    /// * `config` - used unchanged for every tree, so the caller's depth,
    ///   size and threshold settings apply to the whole forest.
    pub fn build(attribute: String, config: TreeConfig, data: &Dataset, num_trees: usize, subsample_size: usize) -> RandomForest {
        let trees: Vec<Option<Box<DecisionTree>>> = (0..num_trees)
            .map(|_| {
                let mut subsample = sample_dataset(data, subsample_size);
                DecisionTree::build(attribute.clone(), &config, &mut subsample)
            })
            .collect();
        RandomForest { trees }
    }

    /// Asks every tree for its decision on `item` and tallies the votes.
    ///
    /// Returns a map from each predicted value to the number of trees that
    /// chose it; trees that yield no decision do not vote. The forest is
    /// consumed.
    pub fn predict(rf: RandomForest, item: Item) -> HashMap<Value, usize> {
        let mut votes: HashMap<Value, usize> = HashMap::new();
        for tree in rf.trees {
            // `DecisionTree::predict` takes the item by value, hence one copy per tree.
            if let Some(label) = DecisionTree::predict(tree, item.clone()) {
                *votes.entry(label).or_insert(0) += 1;
            }
        }
        votes
    }
}
#[cfg(test)]
mod test_treeconfig {
    use super::*;

    /// `TreeConfig::new` must expose the documented default settings.
    #[test]
    fn create_empty_defaults() {
        let TreeConfig {
            decision,
            max_depth,
            min_count,
            ..
        } = TreeConfig::new();
        assert_eq!(decision, String::from("category"));
        assert_eq!(max_depth, 70);
        assert_eq!(min_count, 1);
    }
}
#[cfg(test)]
mod test_dataset {
    use super::*;

    /// Wraps a string slice in a `Value`.
    fn text(data: &str) -> Value {
        Value {
            data: data.to_string(),
        }
    }

    /// Builds an item carrying the `lang` and `typing` attributes.
    fn language(lang: &str, typing: &str) -> Item {
        let mut item = Item::new();
        item.insert("lang".to_string(), text(lang));
        item.insert("typing".to_string(), text(typing));
        item
    }

    /// Fixture shared by the counting tests: rust, python, rust.
    fn languages() -> Dataset {
        vec![
            language("rust", "static"),
            language("python", "dynamic"),
            language("rust", "static"),
        ]
    }

    /// Distinct values are counted per attribute.
    #[test]
    fn unique() {
        let dataset = languages();
        let unique = unique_values(&"lang".to_string(), &dataset);
        assert_eq!(unique.len(), 2);
        assert_eq!(unique.get(&text("rust")).cloned(), Some(2));
        assert_eq!(unique.get(&text("python")).cloned(), Some(1));
    }

    /// The value seen most often wins.
    #[test]
    fn most_frequent() {
        let dataset = languages();
        let winner = value_frequency(&"lang".to_string(), &dataset);
        assert_eq!(winner, Some(text("rust")));
    }

    /// Sampling returns exactly the requested number of items.
    #[test]
    fn test_sample() {
        let dataset: Dataset = (1..10)
            .map(|i| {
                let mut item = Item::new();
                item.insert("id".to_string(), text(&i.to_string()));
                item
            })
            .collect();
        let shuffled = sample_dataset(&dataset, 5);
        assert_eq!(shuffled.len(), 5);
    }
}
#[cfg(test)]
mod test_cacheentry {
    use super::*;

    /// Builds a cache entry from an attribute name and a value text.
    fn entry(attribute: &str, data: &str) -> CacheEntry {
        CacheEntry {
            attribute: attribute.to_string(),
            value: Value {
                data: data.to_string(),
            },
        }
    }

    /// Collects entries into a set and reports how many distinct ones remain.
    fn distinct(entries: Vec<CacheEntry>) -> usize {
        let mut set: HashSet<CacheEntry> = HashSet::new();
        for item in entries {
            set.insert(item);
        }
        set.len()
    }

    /// Entries with the same attribute and value compare equal.
    #[test]
    fn equal_entries_identity() {
        assert!(entry("attribute", "data") == entry("attribute", "data"));
    }

    /// Equal entries collapse to a single set member.
    #[test]
    fn equal_entries_hash() {
        let entries = vec![entry("attribute", "data"), entry("attribute", "data")];
        assert_eq!(distinct(entries), 1);
    }

    /// Entries with different attributes compare unequal.
    #[test]
    fn diff_entries_identity() {
        assert!(entry("attribute", "data") != entry("other attribute", "data"));
    }

    /// Different entries stay separate set members.
    #[test]
    fn diff_entries_hash() {
        let entries = vec![entry("attribute", "data"), entry("other attribute", "data")];
        assert_eq!(distinct(entries), 2);
    }
}
#[cfg(test)]
mod test_decisiontree {
    use super::*;

    /// Wraps a string slice in a `Value`.
    fn text(data: &str) -> Value {
        Value {
            data: data.to_string(),
        }
    }

    /// Builds an item carrying the `lang` and `typing` attributes.
    fn language(lang: &str, typing: &str) -> Item {
        let mut item = Item::new();
        item.insert("lang".to_string(), text(lang));
        item.insert("typing".to_string(), text(typing));
        item
    }

    /// Default configuration with `lang` as the attribute to predict.
    fn lang_config() -> TreeConfig {
        let mut config = TreeConfig::new();
        config.decision = "lang".to_string();
        config
    }

    /// An empty dataset yields a leaf that has nothing to predict.
    #[test]
    fn decision_empty_dataset() {
        let mut dataset = Dataset::new();
        let tree = DecisionTree::build("lang".to_string(), &lang_config(), &mut dataset);
        let t = tree.unwrap();
        assert_eq!(t.decision, None);
    }

    /// A dataset no larger than `min_count` becomes a leaf predicting its
    /// only value.
    #[test]
    fn decision_less_mincount() {
        let mut dataset = vec![language("rust", "static")];
        let tree = DecisionTree::build("lang".to_string(), &lang_config(), &mut dataset);
        let t = tree.unwrap();
        assert_eq!(t.decision, Some(text("rust")));
    }

    /// A larger, mixed dataset is split, so the root carries no decision.
    #[test]
    fn decision_more_mincount() {
        let mut dataset = vec![
            language("rust", "static"),
            language("python", "dynamic"),
            language("rust", "static"),
        ];
        let tree = DecisionTree::build("lang".to_string(), &lang_config(), &mut dataset);
        let t = tree.unwrap();
        assert_eq!(t.decision, None);
    }

    /// The tree predicts the language that goes with the given typing.
    #[test]
    fn decision_prediction() {
        let mut dataset = vec![language("rust", "static"), language("python", "dynamic")];
        let tree = DecisionTree::build("lang".to_string(), &lang_config(), &mut dataset);
        let mut question = Item::new();
        question.insert("typing".to_string(), text("dynamic"));
        let answer = DecisionTree::predict(tree, question);
        assert_eq!(answer.unwrap().data, "python");
    }
}
#[cfg(test)]
mod test_randomforest {
    use super::*;

    /// Wraps a string slice in a `Value`.
    fn text(data: &str) -> Value {
        Value {
            data: data.to_string(),
        }
    }

    /// Builds an item carrying the `lang` and `typing` attributes.
    fn language(lang: &str, typing: &str) -> Item {
        let mut item = Item::new();
        item.insert("lang".to_string(), text(lang));
        item.insert("typing".to_string(), text(typing));
        item
    }

    /// Every tree votes, and no tree answers a statically typed question
    /// with the only dynamically typed language.
    #[test]
    fn forest_prediction() {
        let dataset = vec![
            language("rust", "static"),
            language("python", "dynamic"),
            language("haskell", "static"),
        ];
        let mut config = TreeConfig::new();
        config.decision = "lang".to_string();
        let forest = RandomForest::build("lang".to_string(), config, &dataset, 100, 3);
        let mut question = Item::new();
        question.insert("typing".to_string(), text("static"));
        let answer = RandomForest::predict(forest, question);
        // Each subsample is a permutation of the full dataset, so every tree
        // separates "python" from the two statically typed languages.
        assert_eq!(answer.values().sum::<usize>(), 100);
        assert!(!answer.contains_key(&text("python")));
    }
}