use std::fmt;
use std::marker::PhantomData;
use data::{SampleDescription, TrainingData};
use split::{Split, SplitFinder};
/// One entry of a [`DeterministicTree`]'s flat node vector.
pub enum Node<T>
where
    T: SampleDescription,
{
    /// Placeholder for a slot that has not been filled in yet; the builder
    /// creates children in this state before growing them, and prediction
    /// panics if it ever reaches one.
    Invalid,
    /// Internal node. A sample whose split feature for `theta` is `<=`
    /// `threshold` continues at node index `left`, otherwise at `right`
    /// (both are indices into the same node vector).
    Split {
        theta: T::ThetaSplit,
        threshold: T::Feature,
        left: usize,
        right: usize,
    },
    /// Terminal node holding the parameters handed to `sample_predict`.
    Leaf(T::ThetaLeaf),
}
/// A binary decision tree stored as a flat vector of [`Node`]s.
///
/// Node 0 is the root; split nodes refer to their children by index into
/// the same vector.
pub struct DeterministicTree<Sample>
where
    Sample: SampleDescription,
{
    /// Every node of the tree, root first.
    nodes: Vec<Node<Sample>>,
}
impl<Sample> fmt::Debug for DeterministicTree<Sample>
where
    Sample: SampleDescription,
    Sample::ThetaLeaf: fmt::Debug,
    Sample::ThetaSplit: fmt::Debug,
    Sample::Feature: fmt::Debug,
{
    /// Writes the header `Tree:` and then the whole tree, starting at the
    /// root (node 0) with no branch prefix.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("Tree:")?;
        let mut branches: Vec<&str> = Vec::new();
        self.recursive_fmt(0, &mut branches, false, f)
    }
}
impl<Sample: SampleDescription> DeterministicTree<Sample>
where Sample::ThetaLeaf: fmt::Debug,
      Sample::ThetaSplit: fmt::Debug,
      Sample::Feature: fmt::Debug,
{
    /// Writes node `n` and its whole subtree, one node per line.
    ///
    /// Each line is a newline followed by the segments in `prefix`. Split
    /// nodes print as `(theta) <= threshold`, leaves as ` leaf`, and unfilled
    /// nodes as ` *** Invalid ***`. A split's children follow it, left first,
    /// each behind a ` +--` connector.
    ///
    /// `prefix` holds one segment per ancestor level: ` +--` for the branch
    /// into the current node, ` | ` for a level whose right sibling is still
    /// to be drawn, and three spaces for a level whose last child is being
    /// drawn. `bottom` is `true` when `n` is its parent's right (last) child.
    /// On success, `prefix` has the same length on return as on entry (its
    /// last segment may have been rewritten); the caller pops it.
    ///
    /// Panics if `n` or a child index is out of bounds of `self.nodes`.
    fn recursive_fmt(&self, n: usize, prefix: &mut Vec<&str>, bottom: bool, f: &mut fmt::Formatter) -> fmt::Result {
        let node = &self.nodes[n];
        writeln!(f)?;
        for p in prefix.iter() {
            write!(f, "{}", p)?;
        }
        if bottom {
            // Nothing is drawn below the last child at this level, so its
            // descendants get blank padding instead of a ` | ` rail. The
            // padding must be as wide as ` | ` (3 columns) so right subtrees
            // are indented exactly like left ones.
            prefix.pop();
            prefix.push("   ");
        }
        match *node {
            Node::Invalid => write!(f, " *** Invalid ***")?,
            Node::Leaf(ref l) => write!(f, " {:?}", l)?,
            Node::Split { ref theta, ref threshold, left, right } => {
                write!(f, "({:?}) <= {:?}", theta, threshold)?;
                // A trailing ` +--` means this split is a left child: keep a
                // vertical rail at its level so it reaches the right sibling
                // drawn after this subtree.
                if let Some(&" +--") = prefix.last() {
                    prefix.pop();
                    prefix.push(" | ");
                }
                prefix.push(" +--");
                self.recursive_fmt(left, prefix, false, f)?;
                prefix.pop();
                prefix.push(" +--");
                self.recursive_fmt(right, prefix, true, f)?;
                prefix.pop();
            }
        }
        Ok(())
    }
}
impl<Sample> DeterministicTree<Sample>
where Sample: SampleDescription
{
    /// Builds a tree directly from a prepared node vector (tests only).
    ///
    /// The nodes are used as-is: index 0 is the root and child indices are
    /// not validated here; `predict` panics if it follows one that is out of
    /// bounds.
    #[cfg(test)]
    pub(crate) fn new_with_nodes(nodes: Vec<Node<Sample>>) -> Self {
        DeterministicTree {
            nodes
        }
    }

    /// Predicts the output for `sample` by walking the tree from the root
    /// (node 0).
    ///
    /// At each split, `sample.sample_as_split_feature(theta)` is compared with
    /// the split's threshold: `<=` continues at the left child, otherwise at
    /// the right child. At a leaf, returns `sample.sample_predict(leaf)`.
    ///
    /// `TestingSample` may be a different type from the training sample type,
    /// provided it shares the split, leaf, and feature types.
    ///
    /// # Panics
    ///
    /// Panics if the tree has no nodes, if the walk reaches a `Node::Invalid`,
    /// or if a split refers to a child index outside the node vector.
    pub fn predict<TestingSample>(&self, sample: &TestingSample) -> TestingSample::Prediction
    where TestingSample: SampleDescription<ThetaSplit=Sample::ThetaSplit,
                                           ThetaLeaf=Sample::ThetaLeaf,
                                           Feature=Sample::Feature> + ?Sized,
    {
        // Checked indexing: a corrupt or dangling child index panics instead
        // of reading memory outside `self.nodes`.
        let mut index = 0;
        loop {
            match self.nodes[index] {
                Node::Split { ref theta, ref threshold, left, right } => {
                    index = if &sample.sample_as_split_feature(theta) <= threshold {
                        left
                    } else {
                        right
                    };
                }
                Node::Leaf(ref leaf) => return sample.sample_predict(leaf),
                Node::Invalid => panic!("Invalid node found. Tree may not be fully constructed."),
            }
        }
    }
}
/// Settings for growing a [`DeterministicTree`] from training data.
pub struct DeterministicTreeBuilder<SF, Sample>
where
    SF: SplitFinder,
    Sample: SampleDescription,
{
    /// Ties the builder to the sample type it trains on; holds no data.
    pub(crate) _p: PhantomData<Sample>,
    /// A node holding fewer samples than this becomes a leaf without
    /// attempting a split.
    pub(crate) min_samples_split: usize,
    /// A split is accepted only if both children receive at least this many
    /// samples.
    pub(crate) min_samples_leaf: usize,
    /// If set, nodes at this depth or deeper (root = 0) become leaves.
    pub(crate) max_depth: Option<usize>,
    /// If set to `Some(n)`, training runs on `bootstrap_resample(n)` of the
    /// input data instead of the data itself.
    pub(crate) bootstrap: Option<usize>,
    /// Strategy that proposes a split for a node's data.
    pub(crate) split_finder: SF,
}
impl<SF, Sample> DeterministicTreeBuilder<SF, Sample>
where SF: SplitFinder,
      Sample: SampleDescription
{
    /// Creates a builder that stops splitting nodes with fewer than
    /// `min_samples_split` samples and uses `split_finder` to propose splits.
    ///
    /// Defaults: `min_samples_leaf = 1`, no depth limit, no bootstrapping.
    #[must_use]
    pub fn new(min_samples_split: usize, split_finder: SF) -> Self {
        DeterministicTreeBuilder {
            min_samples_split,
            min_samples_leaf: 1,
            split_finder,
            max_depth: None,
            bootstrap: None,
            _p: PhantomData,
        }
    }

    /// Limits the tree depth: nodes at depth `md` or deeper (the root has
    /// depth 0) become leaves, so `md = 0` yields a single-leaf tree.
    #[must_use]
    pub fn with_max_depth(mut self, md: usize) -> Self {
        self.max_depth = Some(md);
        self
    }

    /// Sets the minimum number of samples each child of a split must receive.
    ///
    /// A proposed split that would leave fewer than `n` samples on either
    /// side is rejected and the node becomes a leaf instead. Defaults to 1.
    #[must_use]
    pub fn with_min_samples_leaf(mut self, n: usize) -> Self {
        self.min_samples_leaf = n;
        self
    }

    /// Makes [`fit`](Self::fit) train on `data.bootstrap_resample(n)` rather
    /// than on `data` itself.
    #[must_use]
    pub fn with_bootstrap(mut self, n: usize) -> Self {
        self.bootstrap = Some(n);
        self
    }

    /// Grows a tree on `data` using this builder's settings.
    ///
    /// `data` is borrowed mutably because it is handed to the split finder
    /// and to `partition_data`; its order may change (NOTE(review): depends
    /// on the `TrainingData` implementation).
    pub fn fit<Training>(&self, data: &mut Training) -> DeterministicTree<Sample>
    where Training: ?Sized + TrainingData<Sample>,
          [Sample]: TrainingData<Sample>
    {
        // The root occupies slot 0 and is filled in by `recursive_fit`.
        let mut nodes = vec![Node::Invalid];
        match self.bootstrap {
            None => self.recursive_fit(&mut nodes, data, 0, 0),
            Some(n) => {
                let mut bdat = data.bootstrap_resample(n);
                self.recursive_fit(&mut nodes, bdat.as_mut_slice(), 0, 0);
            }
        }
        DeterministicTree {
            nodes
        }
    }

    /// Fills slot `node` (at `depth`) from `data`, recursing into children.
    ///
    /// The node becomes a leaf when the depth limit is reached, when it holds
    /// fewer than `min_samples_split` samples, when no split is found, or when
    /// the split would leave a child with fewer than `min_samples_leaf`
    /// samples. Otherwise it becomes a split and both halves are grown.
    fn recursive_fit<Training>(&self,
                               nodes: &mut Vec<Node<Sample>>,
                               data: &mut Training,
                               node: usize,
                               depth: usize)
        where Training: ?Sized + TrainingData<Sample>
    {
        if let Some(md) = self.max_depth {
            if depth >= md {
                nodes[node] = Node::Leaf(data.train_leaf_predictor());
                return;
            }
        }

        if data.n_samples() < self.min_samples_split {
            nodes[node] = Node::Leaf(data.train_leaf_predictor());
            return;
        }

        let split = self.split_finder.find_split(data);
        if let Some(split) = split {
            let (left, right) = data.partition_data(&split);
            if left.n_samples() >= self.min_samples_leaf
                && right.n_samples() >= self.min_samples_leaf
            {
                let (l, r) = Self::split_node(nodes, node, split);
                self.recursive_fit(nodes, left, l, depth + 1);
                self.recursive_fit(nodes, right, r, depth + 1);
                return;
            }
        }

        // No usable split: summarize all of this node's data in a leaf.
        nodes[node] = Node::Leaf(data.train_leaf_predictor());
    }

    /// Turns slot `n` into a split node and appends two `Invalid` placeholder
    /// slots for its children, returning their `(left, right)` indices.
    fn split_node(nodes: &mut Vec<Node<Sample>>,
                  n: usize,
                  split: Split<Sample::ThetaSplit, Sample::Feature>)
                  -> (usize, usize)
    {
        let left = nodes.len();
        let right = left + 1;
        nodes.push(Node::Invalid);
        nodes.push(Node::Invalid);
        nodes[n] = Node::Split {
            theta: split.theta,
            threshold: split.threshold,
            left,
            right,
        };
        (left, right)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use split::BestRandomSplit;
    use testdata::Sample;

    // Eight 1-D samples in two groups: x = 0..=3 with targets alternating
    // 1/2, and x = 4..=7 with targets alternating 11/12. Expanded as an array
    // literal so each test gets its own mutable copy.
    macro_rules! two_group_samples {
        () => {
            [
                Sample::new(&[0.0], 1.0),
                Sample::new(&[1.0], 2.0),
                Sample::new(&[2.0], 1.0),
                Sample::new(&[3.0], 2.0),
                Sample::new(&[4.0], 11.0),
                Sample::new(&[5.0], 12.0),
                Sample::new(&[6.0], 11.0),
                Sample::new(&[7.0], 12.0),
            ]
        };
    }

    /// A fully grown tree reproduces every training target.
    #[test]
    fn tree() {
        let data: &mut [_] = &mut two_group_samples!();
        let builder = DeterministicTreeBuilder {
            split_finder: BestRandomSplit::new(1),
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_depth: None,
            bootstrap: None,
            _p: PhantomData,
        };
        let fitted = builder.fit(data);
        for s in data.iter() {
            assert_eq!(fitted.predict(s), s.y);
        }
    }

    /// A tree grown on a large bootstrap resample still fits every target.
    #[test]
    fn bootstrap() {
        let data: &mut [_] = &mut two_group_samples!();
        let fitted = DeterministicTreeBuilder::new(2, BestRandomSplit::new(1))
            .with_bootstrap(800)
            .fit(data);
        for s in data.iter() {
            assert_eq!(fitted.predict(s), s.y);
        }
    }

    /// Debug output of a one-split tree with a leaf and an unfilled child.
    #[test]
    fn fmt() {
        let shown: DeterministicTree<Sample<_, _>> = DeterministicTree {
            nodes: vec![
                Node::Split { theta: 1, threshold: 2.3, left: 1, right: 2 },
                Node::Leaf(4.5),
                Node::Invalid,
            ],
        };
        let expected = "Tree:\n(1) <= 2.3\n +-- 4.5\n +-- *** Invalid ***";
        assert_eq!(format!("{:?}", shown), expected);
    }
}