use serde::Serialize;
use serde::de::DeserializeOwned;
/// Outcome of a binary decision made by `TreeFunction::binarize`.
///
/// Throughout this file `Zero` selects the first of two branches (first
/// returned set, child `.0`) and `One` selects the second.
pub enum Binar {
/// Take the second branch.
One,
/// Take the first branch.
Zero,
}
/// A node of the binary tree: either a leaf carrying a payload of type `L`,
/// or an interior node carrying split parameters of type `P` and two subtrees.
#[derive(Serialize, Deserialize)]
enum Node<P, L> {
/// Terminal node; holds the value handed back by `DecisionTree::predict`.
Leaf(L),
/// Branching node.
Interior {
/// The two subtrees: `.0` is followed on `Binar::Zero`, `.1` on `Binar::One`.
children: Box<(Node<P, L>, Node<P, L>)>,
/// Parameters passed to `TreeFunction::binarize` to pick a child.
params: P,
},
}
/// Binary decision function used to route samples through a tree.
pub trait TreeFunction {
    /// Type of the samples being routed.
    type Data;
    /// Type of the parameters configuring a single decision.
    type Param;

    /// Decides which of the two branches `input` falls into under `param`.
    fn binarize(&self, param: &Self::Param, input: &Self::Data) -> Binar;

    /// Partitions `elements` by calling `binarize` on each of them.
    ///
    /// Returns `(zeros, ones)`: the samples classified as `Binar::Zero` and
    /// those classified as `Binar::One`, each in their original order.
    fn split_set<'a>(
        &self,
        param: &Self::Param,
        elements: &[&'a Self::Data],
    ) -> (Vec<&'a Self::Data>, Vec<&'a Self::Data>) {
        // Reserve half of the input for each side as a starting guess.
        let half = elements.len() / 2;
        let mut zeros = Vec::with_capacity(half);
        let mut ones = Vec::with_capacity(half);
        for &sample in elements {
            let bucket = match self.binarize(param, sample) {
                Binar::Zero => &mut zeros,
                Binar::One => &mut ones,
            };
            bucket.push(sample);
        }
        (zeros, ones)
    }
}
/// Hooks required by `TreeParameters::learn_tree` to grow a tree from
/// labelled samples, on top of the plain decision function.
pub trait TreeLearnFunctions: TreeFunction {
    /// Label type attached to every training sample.
    type Truth;
    /// Payload stored in each leaf of the learned tree.
    type LeafParam;
    /// Iterator over the candidate split parameters.
    type ParamIter: Iterator<Item = Self::Param>;
    /// Decision function kept inside the finished tree.
    type PredictFunction: TreeFunction<Data = Self::Data, Param = Self::Param>;

    /// Scores the split of a node into `set_l` and `set_r` produced by `param`
    /// at the given `depth`.
    ///
    /// `learn_tree` keeps the candidate with the smallest score and asserts
    /// that the score is not NaN.
    fn impurity(
        &self,
        param: &Self::Param,
        set_l: &[(&Self::Data, &Self::Truth)],
        set_r: &[(&Self::Data, &Self::Truth)],
        depth: usize,
    ) -> f64;

    /// Produces the candidate parameters tried at a node.
    fn param_set(&self) -> Self::ParamIter;

    /// Builds the leaf payload for the samples in `set`.
    fn comp_leaf_data(&self, set: &[(&Self::Data, &Self::Truth)]) -> Self::LeafParam;

    /// Partitions labelled samples by calling `binarize` on the data part.
    ///
    /// Returns `(zeros, ones)`; each side keeps the input order and every
    /// sample stays paired with its label.
    fn split_truth_set<'a>(
        &self,
        param: &Self::Param,
        elements: &[(&'a Self::Data, &'a Self::Truth)],
    ) -> (
        Vec<(&'a Self::Data, &'a Self::Truth)>,
        Vec<(&'a Self::Data, &'a Self::Truth)>,
    ) {
        // Reserve half of the input for each side as a starting guess.
        let half = elements.len() / 2;
        let mut zeros = Vec::with_capacity(half);
        let mut ones = Vec::with_capacity(half);
        for &(sample, truth) in elements {
            let bucket = match self.binarize(param, sample) {
                Binar::Zero => &mut zeros,
                Binar::One => &mut ones,
            };
            bucket.push((sample, truth));
        }
        (zeros, ones)
    }

    /// Returns `true` to turn the current node into a leaf without searching
    /// for a split. The default implementation never stops early.
    // The default body ignores both arguments; they exist for implementors.
    #[allow(unused_variables)]
    fn early_stop(&self, depth: usize, elements: &[(&Self::Data, &Self::Truth)]) -> bool {
        false
    }

    /// Consumes the learner and yields the decision function stored in the tree.
    fn as_predict_learn_func(self) -> Self::PredictFunction;
}
/// Settings handle for tree learning; it currently has no fields.
///
/// `learn_tree` is invoked on a value of this type, and the value is stored
/// in the resulting `DecisionTree`.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct TreeParameters {}
impl TreeParameters {
pub fn new() -> TreeParameters {
TreeParameters {}
}
}
/// A binary decision tree with leaf payloads of type `L`, evaluated with the
/// decision function `F`.
///
/// The serde bounds are spelled out by hand: serializing requires `F`,
/// `F::Param` and `L` to implement `Serialize`, and deserializing requires
/// them to implement `DeserializeOwned`.
#[derive(Serialize, Deserialize)]
#[serde(bound(serialize = "F: Serialize,F::Param: Serialize, L: Serialize"))]
#[serde(bound(deserialize = "F: DeserializeOwned, F::Param: DeserializeOwned,L: DeserializeOwned"))]
pub struct DecisionTree<L, F>
where
F: TreeFunction,
{
// `root` is the top node (`predict` returns `None` when it is absent),
// `functions` chooses the branch at every interior node, and `params` is
// the `TreeParameters` value the tree was learned with.
root: Option<Node<F::Param, L>>, functions: F, pub params: TreeParameters,
}
impl<L, F> DecisionTree<L, F>
where
F: TreeFunction,
{
pub fn predict<'a>(&'a self, input: &F::Data) -> Option<&'a L> {
if self.root.is_none() {
return None;
}
let mut node: &Node<F::Param, L> = self.root.as_ref().unwrap();
'l: loop {
match *node {
Node::Leaf(ref param) => return Some(param),
Node::Interior {
ref children,
ref params,
} => {
match self.functions.binarize(params, input) {
Binar::Zero => node = &(*children).0,
Binar::One => node = &(*children).1,
}
}
}
}
}
}
impl TreeParameters {
    /// Learns a decision tree from `train_set` using the hooks in `learn_func`.
    ///
    /// Every node tries each parameter from `learn_func.param_set()`, skips
    /// candidates that leave one side empty, and keeps the one with the
    /// smallest `impurity`. When several candidates tie, the first one wins.
    /// A node becomes a leaf when `early_stop` returns `true` or when no
    /// candidate separates its samples.
    ///
    /// An empty `train_set` yields a tree without a root, for which
    /// `DecisionTree::predict` returns `None`.
    ///
    /// # Panics
    ///
    /// Panics if `learn_func.impurity` returns NaN.
    pub fn learn_tree<F>(
        self,
        learn_func: F,
        train_set: &[(&F::Data, &F::Truth)],
    ) -> DecisionTree<F::LeafParam, F::PredictFunction>
    where
        F: TreeLearnFunctions,
    {
        /// Recursively builds the subtree for the non-empty `subset` at `depth`.
        fn learn_tree_intern<F>(
            depth: usize,
            subset: &[(&F::Data, &F::Truth)],
            learn_func: &F,
        ) -> Node<F::Param, F::LeafParam>
        where
            F: TreeLearnFunctions,
        {
            use std::f64;

            assert!(!subset.is_empty(), "cannot build a tree node from an empty sample set");
            if learn_func.early_stop(depth, subset) {
                return Node::Leaf(learn_func.comp_leaf_data(subset));
            }

            // Best candidate so far together with the partition it produced.
            let mut best_impurity = f64::INFINITY;
            let mut best = None;
            for param in learn_func.param_set() {
                let (left, right) = learn_func.split_truth_set(&param, subset);
                // A split that sends everything to one side separates nothing.
                if left.is_empty() || right.is_empty() {
                    continue;
                }
                let impurity = learn_func.impurity(&param, &left[..], &right[..], depth);
                assert!(!impurity.is_nan(), "impurity must not be NaN");
                // Strict comparison: on ties the earlier candidate is kept.
                if impurity < best_impurity {
                    best_impurity = impurity;
                    best = Some((param, left, right));
                }
            }

            match best {
                Some((param, left, right)) => Node::Interior {
                    children: Box::new((
                        learn_tree_intern(depth + 1, &left[..], learn_func),
                        learn_tree_intern(depth + 1, &right[..], learn_func),
                    )),
                    params: param,
                },
                None => Node::Leaf(learn_func.comp_leaf_data(subset)),
            }
        }

        // Nothing can be learned from zero samples; produce a root-less tree
        // instead of tripping the non-empty assertion above.
        let root = if train_set.is_empty() {
            None
        } else {
            Some(learn_tree_intern(0, train_set, &learn_func))
        };
        DecisionTree {
            root: root,
            functions: learn_func.as_predict_learn_func(),
            params: self,
        }
    }
}