use rand::distributions::range::SampleRange;
use rand::thread_rng;
use array_ops::{Partition, resample};
use criterion::SplitCriterion;
use split::Split;
use split_between::SplitBetween;
/// Describes how a single sample exposes its target, its split features and
/// its leaf prediction.
///
/// The associated types tie a sample to the parameter types used by
/// `TrainingData` and `DataSet` in this file.
pub trait SampleDescription {
/// Parameters selecting a split feature; passed by reference to
/// `sample_as_split_feature`.
type ThetaSplit: Clone;
/// Parameters of a leaf predictor; passed by reference to `sample_predict`.
type ThetaLeaf: Clone;
/// Value a sample yields for a given `ThetaSplit`. It must be partially
/// ordered because samples are compared and sorted by it.
type Feature: PartialOrd + SampleRange + SplitBetween;
/// Type of the value returned by `target`.
type Target;
/// Type of the value returned by `sample_predict`.
type Prediction;
/// Returns the target value of this sample.
fn target(&self) -> Self::Target;
/// Returns the feature value of this sample for the split parameters `theta`.
fn sample_as_split_feature(&self, theta: &Self::ThetaSplit) -> Self::Feature;
/// Returns the prediction for this sample given the leaf parameters `w`.
fn sample_predict(&self, w: &Self::ThetaLeaf) -> Self::Prediction;
}
/// A `DataSet` that additionally provides what is needed to choose splits and
/// fit leaf predictors for samples of type `Sample`.
pub trait TrainingData<Sample: SampleDescription>: DataSet<Sample> {
    /// Split criterion operating on the samples' target type.
    type Criterion: SplitCriterion<Sample::Target>;

    /// Returns the number of samples.
    fn n_samples(&self) -> usize;

    /// Produces one set of split-feature parameters.
    ///
    /// NOTE(review): presumably drawn at random -- confirm against implementors.
    fn gen_split_feature(&self) -> Sample::ThetaSplit;

    /// Optionally returns an iterator over split-feature parameters.
    ///
    /// The default implementation returns `None`; implementors may override it.
    /// NOTE(review): presumably meant to enumerate every candidate split
    /// feature when that is feasible -- confirm.
    fn all_split_features(&self) -> Option<Box<Iterator<Item = Sample::ThetaSplit>>> {
        None
    }

    /// Returns leaf-predictor parameters derived from this data.
    fn train_leaf_predictor(&self) -> Sample::ThetaLeaf;

    /// Returns a pair of feature values for the split parameters `theta`.
    ///
    /// NOTE(review): presumably the (lower, upper) bounds of that feature over
    /// the data -- confirm the ordering and inclusiveness with implementors.
    fn feature_bounds(&self, theta: &Sample::ThetaSplit) -> (Sample::Feature, Sample::Feature);
}
/// Operations on a collection of samples of type `Sample`.
pub trait DataSet<Sample: SampleDescription> {
    /// Rearranges the data according to `split` and returns the two resulting
    /// parts as disjoint mutable borrows of `self`.
    fn partition_data(
        &mut self,
        split: &Split<Sample::ThetaSplit, Sample::Feature>,
    ) -> (&mut Self, &mut Self);

    /// Sorts the data in place by the feature selected by `theta`.
    fn sort_data(&mut self, theta: &Sample::ThetaSplit);

    /// Returns a newly allocated vector of samples resampled from this data.
    ///
    /// NOTE(review): presumably `n` samples drawn with replacement -- confirm
    /// against the implementation of the resampling helper.
    fn bootstrap_resample(&self, n: usize) -> Vec<Sample>;

    /// Calls `visitor` with a reference to each sample.
    fn visit_samples<F>(&self, visitor: F)
    where
        F: FnMut(&Sample);
}
/// `DataSet` implementation for a plain slice of samples; partitioning and
/// sorting operate in place on the slice.
impl<Sample> DataSet<Sample> for [Sample]
where
    Sample: SampleDescription + Clone,
{
    /// Partitions the slice around `split` and returns the two halves.
    ///
    /// The predicate handed to `partition` is "feature for `split.theta` is
    /// `<= split.threshold`"; the slice is then cut at the index that call
    /// returns.
    /// NOTE(review): assumes `Partition::partition` moves the elements that
    /// satisfy the predicate to the front and returns their count -- confirm
    /// in `array_ops`.
    fn partition_data(
        &mut self,
        split: &Split<Sample::ThetaSplit, Sample::Feature>,
    ) -> (&mut Self, &mut Self) {
        let i = self.partition(|sample| {
            sample.sample_as_split_feature(&split.theta) <= split.threshold
        });
        self.split_at_mut(i)
    }

    /// Sorts the samples in place by their feature value for `theta`.
    ///
    /// Uses an unstable sort, so samples with equal feature values may end up
    /// in any relative order.
    ///
    /// # Panics
    ///
    /// Panics if two feature values cannot be ordered, i.e. `partial_cmp`
    /// returns `None` (for floating-point features this happens with NaN).
    fn sort_data(&mut self, theta: &Sample::ThetaSplit) {
        self.sort_unstable_by(|a, b| {
            let fa = a.sample_as_split_feature(theta);
            let fb = b.sample_as_split_feature(theta);
            // An incomparable pair is treated as a broken invariant of the
            // input data rather than a recoverable error.
            fa.partial_cmp(&fb).expect(
                "Could not compare samples (this is likely caused by a NaN feature)",
            )
        })
    }

    /// Resamples the slice via `array_ops::resample`, using the thread-local
    /// random number generator.
    ///
    /// NOTE(review): presumably returns `n` samples drawn with replacement --
    /// confirm in `array_ops::resample`.
    fn bootstrap_resample(&self, n: usize) -> Vec<Sample> {
        resample(self, n, &mut thread_rng())
    }

    /// Calls `visitor` on every sample, in slice order.
    fn visit_samples<F: FnMut(&Sample)>(&self, mut visitor: F) {
        for sample in self.iter() {
            visitor(sample);
        }
    }
}