use super::tree::*;
use rand;
use rand::Rng;
use serde::Serialize;
use serde::de::DeserializeOwned;
use rayon::prelude::*;
/// Strategy for reducing the leaf results gathered from every tree in a
/// forest into a single scalar prediction.
pub trait VotingMethod<L> {
/// Combines the per-tree leaf values in `tree_results` into one `f64`.
fn voting(&self, tree_results: &[&L]) -> f64;
}
/// Adapter that exposes a leaf value as a probability-like `f64` so voting
/// methods can aggregate it numerically.
pub trait FromGetProbability {
/// Returns this value converted to an `f64`.
fn probability(&self) -> f64;
}
/// Blanket implementation: every `Copy` type that converts losslessly into
/// `f64` (i.e. `f64: From<T>`) can serve as a probability source.
impl<T> FromGetProbability for T
where
    f64: From<T>,
    T: Copy,
{
    fn probability(&self) -> f64 {
        // `T: Copy` lets us dereference; `Into` mirrors the `From` bound.
        (*self).into()
    }
}
/// Voting method that averages the probabilities reported by all trees.
pub struct AverageVoting;
/// Averages the per-tree probabilities.
///
/// NOTE: an empty `tree_results` slice yields `NaN` (0.0 / 0.0), identical
/// to the behaviour callers have always observed.
impl<L> VotingMethod<L> for AverageVoting
where
    L: FromGetProbability,
{
    fn voting(&self, tree_results: &[&L]) -> f64 {
        let total: f64 = tree_results.iter().map(|l| l.probability()).sum();
        total / (tree_results.len() as f64)
    }
}
/// A collection of independently trained decision trees.
///
/// The serde bounds are forwarded to the contained `DecisionTree`s, so the
/// forest is (de)serializable exactly when its trees are.
// NOTE(review): only `serde::Serialize` and `serde::de::DeserializeOwned`
// are imported above; the `Deserialize` derive presumably resolves via a
// crate-root `#[macro_use] extern crate serde_derive` — confirm.
#[derive(Serialize, Deserialize)]
#[serde(bound(serialize = "DecisionTree<L,F>: Serialize"))]
#[serde(bound(deserialize = "DecisionTree<L,F>: DeserializeOwned"))]
pub struct RandomForest<L, F>
where
F: TreeFunction,
{
// The trained trees; a prediction queries each one and combines the answers.
subtrees: Vec<DecisionTree<L, F>>,
}
impl<L, F> RandomForest<L, F>
where
    F: TreeFunction,
{
    /// Queries every tree in the forest with `input` and collects the leaf
    /// values of those trees that produced a prediction.
    pub fn forest_predictions(&self, input: &F::Data) -> Vec<&L> {
        self.subtrees
            .iter()
            .filter_map(|tree| tree.predict(input))
            .collect()
    }

    /// Predicts a value for `input` by letting `voting_method` combine the
    /// per-tree leaf results.
    ///
    /// Returns `None` when no tree produced a prediction. (Previously an
    /// empty prediction set was still passed to the voting method, so e.g.
    /// `AverageVoting` returned `Some(NaN)` from a 0.0 / 0.0 division.)
    pub fn predict<V>(&self, input: &F::Data, voting_method: V) -> Option<f64>
    where
        V: VotingMethod<L>,
    {
        let predictions = self.forest_predictions(input);
        if predictions.is_empty() {
            None
        } else {
            Some(voting_method.voting(&predictions))
        }
    }
}
impl<L, F> RandomForest<L, F>
where
    F: TreeFunction + Send + Sync,
    <F as TreeFunction>::Param: Send + Sync,
    <F as TreeFunction>::Data: Send + Sync,
    L: Send + Sync,
{
    /// Parallel version of `forest_predictions`: queries all trees on
    /// rayon's thread pool and gathers the successful predictions.
    ///
    /// The hand-rolled `fold`/`reduce` pair is replaced by rayon's
    /// `collect`, which performs the same split-and-append gathering
    /// internally (and drops the misleading per-split capacity hints).
    pub fn forest_predictions_parallel(&self, input: &F::Data) -> Vec<&L> {
        self.subtrees
            .par_iter()
            .filter_map(|tree| tree.predict(input))
            .collect()
    }
}
/// Configuration for training a `RandomForest` by bagging: each tree is
/// trained on its own random subset of the training data.
pub struct RandomForestLearnParam<LearnF>
where
LearnF: TreeLearnFunctions,
{
// Parameters forwarded to each individual tree's training.
pub tree_param: TreeParameters,
// How many trees to train.
pub number_of_trees: usize,
// Size of the random sample (drawn with replacement) used per tree.
pub size_of_subset_per_training: usize,
// The learn-function implementation used to grow each tree.
pub learn_function: LearnF,
}
impl<LearnF> RandomForestLearnParam<LearnF>
where
    LearnF: TreeLearnFunctions + Copy,
{
    /// Creates training parameters with default `TreeParameters`.
    pub fn new(
        number_of_trees: usize,
        size_of_subset_per_training: usize,
        learnf: LearnF,
    ) -> RandomForestLearnParam<LearnF> {
        RandomForestLearnParam {
            tree_param: TreeParameters::new(),
            number_of_trees,
            size_of_subset_per_training,
            learn_function: learnf,
        }
    }

    /// Trains `number_of_trees` trees, each on a bootstrap sample (drawn
    /// with replacement) of `size_of_subset_per_training` elements from
    /// `train_set`.
    ///
    /// Returns `None` when a sample would have to be drawn from an empty
    /// `train_set` (previously this panicked inside `gen_range` on the
    /// empty range `0..0`).
    pub fn train_forest(
        self,
        train_set: &[(&LearnF::Data, &LearnF::Truth)],
    ) -> Option<RandomForest<LearnF::LeafParam, LearnF::PredictFunction>> {
        // Sampling only happens when both counts are non-zero; in that case
        // an empty train_set cannot be sampled, so report failure.
        if train_set.is_empty() && self.number_of_trees > 0
            && self.size_of_subset_per_training > 0
        {
            return None;
        }
        let mut rng = rand::thread_rng();
        let mut subtrees = Vec::with_capacity(self.number_of_trees);
        // One bootstrap buffer, reused across trees to avoid reallocating.
        let mut subset = Vec::with_capacity(self.size_of_subset_per_training);
        for _ in 0..self.number_of_trees {
            subset.clear();
            for _ in 0..self.size_of_subset_per_training {
                subset.push(train_set[rng.gen_range(0, train_set.len())]);
            }
            subtrees.push(self.tree_param.learn_tree(self.learn_function, &subset[..]));
        }
        Some(RandomForest { subtrees })
    }
}
impl<LearnF> RandomForestLearnParam<LearnF>
where
    LearnF: TreeLearnFunctions + Copy + Send + Sync,
    LearnF::PredictFunction: Send + Sync,
    LearnF::Truth: Send + Sync,
    LearnF::LeafParam: Send + Sync,
    LearnF::Data: Send + Sync,
    LearnF::Param: Send + Sync,
{
    /// Parallel version of `train_forest`: trains the trees across rayon's
    /// thread pool, each task drawing its own bootstrap sample.
    ///
    /// Fixes:
    /// - returns `None` instead of panicking in `gen_range` when a sample
    ///   would have to be drawn from an empty `train_set`;
    /// - the manual `fold`/`reduce` gathering (whose buffers were sized
    ///   with `subset_size` — the wrong quantity for a list of trees) is
    ///   replaced by rayon's `collect`.
    pub fn train_forest_parallel(
        self,
        train_set: &[(&LearnF::Data, &LearnF::Truth)],
    ) -> Option<RandomForest<LearnF::LeafParam, LearnF::PredictFunction>> {
        if train_set.is_empty() && self.number_of_trees > 0
            && self.size_of_subset_per_training > 0
        {
            return None;
        }
        let subset_size = self.size_of_subset_per_training;
        let subtrees = (0..self.number_of_trees)
            .into_par_iter()
            .map(|_| {
                // Each task uses its own thread-local RNG and sample buffer.
                let mut rng = rand::thread_rng();
                let mut subset = Vec::with_capacity(subset_size);
                for _ in 0..subset_size {
                    subset.push(train_set[rng.gen_range(0, train_set.len())]);
                }
                self.tree_param.learn_tree(self.learn_function, &subset[..])
            })
            .collect();
        Some(RandomForest { subtrees })
    }
}