use std::usize;
use prelude::*;
use trees::decision_tree;
use multiclass::OneVsRestWrapper;
use utils::EncodableRng;
use rand;
use rand::{SeedableRng, StdRng};
use rand::distributions::{IndependentSample, Range};
/// Hyperparameters for a `RandomForest` model: the per-tree settings
/// applied to every tree, plus the number of trees to grow.
#[derive(RustcEncodable, RustcDecodable)]
pub struct Hyperparameters {
// Settings cloned into every tree of the forest.
tree_hyperparameters: decision_tree::Hyperparameters,
// Number of trees in the ensemble.
num_trees: usize,
// Forest-level RNG; used to derive per-tree seeds and bootstrap samples.
rng: EncodableRng
}
impl Hyperparameters {
    /// Create a new hyperparameter set from per-tree settings and the
    /// number of trees to grow.
    pub fn new(tree_hyperparameters: decision_tree::Hyperparameters,
               num_trees: usize)
               -> Hyperparameters {
        Hyperparameters {
            tree_hyperparameters: tree_hyperparameters,
            num_trees: num_trees,
            rng: EncodableRng::new(),
        }
    }

    /// Set the random number generator used to seed the individual trees
    /// and draw bootstrap samples.
    pub fn rng(&mut self, rng: rand::StdRng) -> &mut Hyperparameters {
        self.rng.rng = rng;
        self
    }

    /// Build an (unfitted) `RandomForest` from these hyperparameters.
    pub fn build(&self) -> RandomForest {
        let mut seed_rng = self.rng.clone();

        let trees = (0..self.num_trees)
            .map(|_| {
                // Each tree gets its own hyperparameters, reseeded with ten
                // fresh draws from the forest-level RNG so the trees differ.
                let range = Range::new(0, usize::MAX);
                let seed = (0..10)
                    .map(|_| range.ind_sample(&mut seed_rng.rng))
                    .collect::<Vec<_>>();

                let mut tree_params = self.tree_hyperparameters.clone();
                tree_params.rng(SeedableRng::from_seed(&seed[..]));
                tree_params.build()
            })
            .collect::<Vec<_>>();

        RandomForest {
            trees: trees,
            rng: self.rng.clone(),
        }
    }

    /// Build a one-vs-rest wrapper around the forest for multiclass problems.
    pub fn one_vs_rest(&mut self) -> OneVsRestWrapper<RandomForest> {
        OneVsRestWrapper::new(self.build())
    }
}
/// A random forest model: an ensemble of decision trees whose
/// decision functions are averaged at prediction time.
#[derive(RustcEncodable, RustcDecodable)]
#[derive(Clone)]
pub struct RandomForest {
// The component trees; fit on bootstrap resamples of the input.
trees: Vec<decision_tree::DecisionTree>,
// RNG used to draw bootstrap indices during `fit`.
rng: EncodableRng
}
impl SupervisedModel<Array> for RandomForest {
    /// Fit every tree on its own bootstrap resample of `(X, y)`.
    fn fit(&mut self, X: &Array, y: &Array) -> Result<(), &'static str> {
        // Advance a local copy of the RNG; the stored state is only
        // committed once every tree has fit successfully.
        let mut forest_rng = self.rng.clone();

        for tree in &mut self.trees {
            let sample = RandomForest::bootstrap_indices(X.rows(), &mut forest_rng.rng);
            try!(tree.fit(&X.get_rows(&sample), &y.get_rows(&sample)));
        }

        self.rng = forest_rng;
        Ok(())
    }

    /// Average the decision functions of all trees.
    fn decision_function(&self, X: &Array) -> Result<Array, &'static str> {
        let mut predictions = Array::zeros(X.rows(), 1);

        for tree in &self.trees {
            predictions.add_inplace(&try!(tree.decision_function(X)));
        }

        predictions.div_inplace(self.trees.len() as f32);
        Ok(predictions)
    }
}
impl SupervisedModel<SparseRowArray> for RandomForest {
    /// Fit every tree on its own bootstrap resample of `(X, y)`.
    fn fit(&mut self, X: &SparseRowArray, y: &Array) -> Result<(), &'static str> {
        // Advance a local copy of the RNG; the stored state is only
        // committed once every tree has fit successfully.
        let mut forest_rng = self.rng.clone();

        for tree in &mut self.trees {
            let sample = RandomForest::bootstrap_indices(X.rows(), &mut forest_rng.rng);
            // Trees consume column-major sparse data, so convert the resample.
            let resampled = SparseColumnArray::from(&X.get_rows(&sample));
            try!(tree.fit(&resampled, &y.get_rows(&sample)));
        }

        self.rng = forest_rng;
        Ok(())
    }

    /// Average the decision functions of all trees.
    fn decision_function(&self, X: &SparseRowArray) -> Result<Array, &'static str> {
        let mut predictions = Array::zeros(X.rows(), 1);
        // Convert once up front; every tree reads the same column-major view.
        let columnar = SparseColumnArray::from(X);

        for tree in &self.trees {
            predictions.add_inplace(&try!(tree.decision_function(&columnar)));
        }

        predictions.div_inplace(self.trees.len() as f32);
        Ok(predictions)
    }
}
impl RandomForest {
    /// Return a reference to the forest's component trees.
    pub fn trees(&self) -> &Vec<decision_tree::DecisionTree> {
        &self.trees
    }

    /// Draw `num_indices` row indices uniformly with replacement from
    /// `[0, num_indices)` — a standard bootstrap resample.
    fn bootstrap_indices(num_indices: usize, rng: &mut rand::StdRng) -> Vec<usize> {
        let range = Range::new(0, num_indices);

        let mut indices = Vec::with_capacity(num_indices);
        for _ in 0..num_indices {
            indices.push(range.ind_sample(rng));
        }

        indices
    }
}
#[cfg(test)]
mod tests {
use prelude::*;
use trees::decision_tree;
use cross_validation::cross_validation::CrossValidation;
use datasets::iris::load_data;
use metrics::accuracy_score;
use multiclass::OneVsRestWrapper;
use super::*;
use rand::{StdRng, SeedableRng};
use bincode;
#[cfg(feature = "all_tests")]
use csv;
// Dense iris data: 10-fold cross-validated accuracy must exceed 0.96.
// All RNGs are seeded with a fixed value for reproducibility.
#[test]
fn test_random_forest_iris() {
let (data, target) = load_data();
let mut test_accuracy = 0.0;
let no_splits = 10;
let mut cv = CrossValidation::new(data.rows(),
no_splits);
cv.set_rng(StdRng::from_seed(&[100]));
for (train_idx, test_idx) in cv {
let x_train = data.get_rows(&train_idx);
let x_test = data.get_rows(&test_idx);
let y_train = target.get_rows(&train_idx);
let mut tree_params = decision_tree::Hyperparameters::new(data.cols());
tree_params.min_samples_split(10)
.max_features(4)
.rng(StdRng::from_seed(&[100]));
let mut model = Hyperparameters::new(tree_params, 10)
.rng(StdRng::from_seed(&[100]))
.one_vs_rest();
model.fit(&x_train, &y_train).unwrap();
let test_prediction = model.predict(&x_test).unwrap();
// Accumulate fold accuracies; averaged after the loop.
test_accuracy += accuracy_score(
&target.get_rows(&test_idx),
&test_prediction);
}
test_accuracy /= no_splits as f32;
println!("Accuracy {}", test_accuracy);
assert!(test_accuracy > 0.96);
}
// Same as the dense test above, but feeds the model SparseRowArray
// inputs to exercise the sparse SupervisedModel impl.
#[test]
fn test_random_forest_iris_sparse() {
let (data, target) = load_data();
let mut test_accuracy = 0.0;
let no_splits = 10;
let mut cv = CrossValidation::new(data.rows(),
no_splits);
cv.set_rng(StdRng::from_seed(&[100]));
for (train_idx, test_idx) in cv {
let x_train = SparseRowArray::from(&data.get_rows(&train_idx));
let x_test = SparseRowArray::from(&data.get_rows(&test_idx));
let y_train = target.get_rows(&train_idx);
let mut tree_params = decision_tree::Hyperparameters::new(data.cols());
tree_params.min_samples_split(10)
.max_features(4)
.rng(StdRng::from_seed(&[100]));
let mut model = Hyperparameters::new(tree_params, 10)
.rng(StdRng::from_seed(&[100]))
.one_vs_rest();
model.fit(&x_train, &y_train).unwrap();
let test_prediction = model.predict(&x_test).unwrap();
test_accuracy += accuracy_score(
&target.get_rows(&test_idx),
&test_prediction);
}
test_accuracy /= no_splits as f32;
println!("Accuracy {}", test_accuracy);
assert!(test_accuracy > 0.96);
}
// Slow end-to-end test on the newsgroups text dataset; only built with
// the "all_tests" feature. Checks train accuracy (overfit is expected).
#[test]
#[cfg(feature = "all_tests")]
fn test_random_forest_newsgroups() {
extern crate time;
use feature_extraction::dict_vectorizer::*;
let mut rdr = csv::Reader::from_file("./test_data/newsgroups/data.csv")
.unwrap()
.has_headers(false);
let mut vectorizer = DictVectorizer::new();
let mut target = Vec::new();
// Each CSV record is (label, document); tokens are whitespace-split
// and counted into the vectorizer.
for (row, record) in rdr.decode().enumerate() {
let (y, data): (f32, String) = record.unwrap();
for token in data.split_whitespace() {
vectorizer.partial_fit(row, token, 1.0);
}
target.push(y);
}
let target = Array::from(target);
let X = vectorizer.transform();
let no_splits = 2;
let mut test_accuracy = 0.0;
let mut train_accuracy = 0.0;
let mut cv = CrossValidation::new(X.rows(),
no_splits);
cv.set_rng(StdRng::from_seed(&[100]));
for (train_idx, test_idx) in cv {
let x_train = X.get_rows(&train_idx);
let x_test = X.get_rows(&test_idx);
let y_train = target.get_rows(&train_idx);
let mut tree_params = decision_tree::Hyperparameters::new(X.cols());
tree_params.min_samples_split(5)
.rng(StdRng::from_seed(&[100]));
let mut model = Hyperparameters::new(
tree_params, 20)
.one_vs_rest();
let start = time::precise_time_ns();
model.fit(&x_train, &y_train).unwrap();
println!("Elapsed {}", time::precise_time_ns() - start);
let y_hat = model.predict(&x_test).unwrap();
let y_hat_train = model.predict(&x_train).unwrap();
test_accuracy += accuracy_score(
&target.get_rows(&test_idx),
&y_hat);
train_accuracy += accuracy_score(
&target.get_rows(&train_idx),
&y_hat_train);
}
test_accuracy /= no_splits as f32;
train_accuracy /= no_splits as f32;
println!("{}", test_accuracy);
println!("train accuracy {}", train_accuracy);
assert!(train_accuracy > 0.95);
}
// Round-trips a fitted one-vs-rest forest through bincode and checks the
// decoded model still predicts with > 0.96 cross-validated accuracy.
#[test]
fn serialization() {
let (data, target) = load_data();
let mut test_accuracy = 0.0;
let no_splits = 10;
let mut cv = CrossValidation::new(data.rows(),
no_splits);
cv.set_rng(StdRng::from_seed(&[100]));
for (train_idx, test_idx) in cv {
let x_train = data.get_rows(&train_idx);
let x_test = data.get_rows(&test_idx);
let y_train = target.get_rows(&train_idx);
let mut tree_params = decision_tree::Hyperparameters::new(data.cols());
tree_params.min_samples_split(10)
.max_features(4)
.rng(StdRng::from_seed(&[100]));
let mut model = Hyperparameters::new(tree_params, 10)
.rng(StdRng::from_seed(&[100]))
.one_vs_rest();
model.fit(&x_train, &y_train).unwrap();
// Encode then decode; predictions come from the decoded copy.
let encoded = bincode::rustc_serialize::encode(&model,
bincode::SizeLimit::Infinite).unwrap();
let decoded: OneVsRestWrapper<RandomForest> = bincode::rustc_serialize::decode(&encoded).unwrap();
let test_prediction = decoded.predict(&x_test).unwrap();
test_accuracy += accuracy_score(
&target.get_rows(&test_idx),
&test_prediction);
}
test_accuracy /= no_splits as f32;
println!("Accuracy {}", test_accuracy);
assert!(test_accuracy > 0.96);
}
}