use std::f64;
use rand::{thread_rng, Rng};
use criterion::SplitCriterion;
use data::{SampleDescription, TrainingData};
use split_between::SplitBetween;
/// A candidate binary split of a dataset.
///
/// `theta` selects what is being tested (it is what the split finders pass to
/// `feature_bounds`, `sort_data` and `sample_as_split_feature`), and
/// `threshold` is the feature value the split is placed at.
#[derive(Debug)]
pub struct Split<Theta, Threshold> {
    /// Selector of the feature the split operates on.
    pub theta: Theta,
    /// Feature value at which the data is divided.
    pub threshold: Threshold,
}
/// Strategy for choosing a [`Split`] of a training set.
pub trait SplitFinder {
    /// Searches `data` for a split.
    ///
    /// `data` is borrowed mutably because the search may need to call
    /// mutating dataset operations (the finders in this file use
    /// `sort_data` and `partition_data`).
    ///
    /// Returns `None` when the finder does not come up with a split.
    fn find_split<Sample, Training>(
        &self,
        data: &mut Training,
    ) -> Option<Split<Sample::ThetaSplit, Sample::Feature>>
    where
        Sample: SampleDescription,
        Training: ?Sized + TrainingData<Sample>;
}
/// Split finder that evaluates randomly drawn candidate splits and keeps the
/// best-scoring one.
#[derive(Debug, Clone)]
pub struct BestRandomSplit {
    /// Upper bound on the number of random candidates drawn per search.
    n_splits: usize,
}

impl BestRandomSplit {
    /// Creates a finder that draws at most `n_splits` random candidates per
    /// call to `find_split`.
    pub fn new(n_splits: usize) -> Self {
        BestRandomSplit { n_splits }
    }
}
impl SplitFinder for BestRandomSplit {
    /// Draws up to `n_splits` random (feature, threshold) candidates and
    /// returns the last one whose score did not exceed the best score so far.
    ///
    /// A candidate's score is
    /// `(left.get_weighted() + right.get_weighted()) / n_samples`, where
    /// `left`/`right` come from `partition_data`. The score to beat starts at
    /// the criterion of the unsplit data. Returns `None` if no candidate
    /// reached that bar.
    fn find_split<Sample, Training>(
        &self,
        data: &mut Training,
    ) -> Option<Split<Sample::ThetaSplit, Sample::Feature>>
    where
        Sample: SampleDescription,
        Training: ?Sized + TrainingData<Sample>,
    {
        let total = data.n_samples() as f64;
        let mut lowest = Training::Criterion::from_dataset(data).get();
        let mut chosen = None;
        let mut rng = thread_rng();

        for _attempt in 0..self.n_splits {
            let theta = data.gen_split_feature();
            let (lo, hi) = data.feature_bounds(&theta);
            // An empty value range leaves nothing to sample a threshold from;
            // the attempt is still counted.
            if hi <= lo {
                continue;
            }

            let threshold = rng.gen_range(lo, hi);
            let candidate = Split { theta, threshold };

            let (left, right) = data.partition_data(&candidate);
            let left_part = Training::Criterion::from_dataset(left).get_weighted();
            let right_part = Training::Criterion::from_dataset(right).get_weighted();
            let score = (left_part + right_part) / total;

            // Ties are accepted, so a later candidate replaces an equally
            // good earlier one.
            if score <= lowest {
                lowest = score;
                chosen = Some(candidate);
            }

            // A score of exactly zero cannot be improved on. This check runs
            // only after a candidate has been evaluated.
            if lowest == 0.0 {
                break;
            }
        }

        chosen
    }
}
/// Split finder that searches for the best threshold on a randomly ordered
/// subset of the dataset's features.
#[derive(Debug, Clone)]
pub struct BestSplitRandomFeature {
    /// Number of features examined before the search is allowed to stop.
    /// Further features are examined only while no split has been found.
    n_features: usize,
}

impl BestSplitRandomFeature {
    /// Creates a finder that examines `n_features` randomly chosen features
    /// per call to `find_split`.
    pub fn new(n_features: usize) -> Self {
        BestSplitRandomFeature { n_features }
    }
}
impl SplitFinder for BestSplitRandomFeature {
    /// Shuffles all split features and runs the per-feature threshold search
    /// on them in that order, keeping the best-scoring split.
    ///
    /// The search stops once `n_features` features have been examined and a
    /// split has been found, or as soon as the best score is exactly zero.
    /// If the first `n_features` features yield no split, it keeps going
    /// through the remaining features until one does.
    ///
    /// # Panics
    ///
    /// Panics if `data.all_split_features()` returns `None`.
    fn find_split<Sample, Training>(
        &self,
        data: &mut Training,
    ) -> Option<Split<Sample::ThetaSplit, Sample::Feature>>
    where
        Sample: SampleDescription,
        Training: ?Sized + TrainingData<Sample>,
    {
        let mut lowest = Training::Criterion::from_dataset(data).get();
        let mut chosen = None;

        let mut candidates: Vec<_> = data
            .all_split_features()
            .expect("Dataset does not support iteration over features.")
            .collect();
        thread_rng().shuffle(&mut candidates);

        for (n_checked, theta) in candidates.into_iter().enumerate() {
            if lowest == 0.0 {
                break;
            }
            // The feature budget only ends the search once a split exists.
            if n_checked >= self.n_features && chosen.is_some() {
                break;
            }

            let (score, split) = find_best_split_for_feature(data, &theta);
            if score <= lowest {
                lowest = score;
                chosen = split;
            }
        }

        chosen
    }
}
/// Split finder that searches every split feature of the dataset for the
/// best threshold. It carries no configuration.
#[derive(Debug, Clone, Default)]
pub struct BestSplit {}

impl BestSplit {
    /// Creates the finder. Equivalent to `BestSplit::default()`.
    pub fn new() -> Self {
        BestSplit {}
    }
}
impl SplitFinder for BestSplit {
    /// Runs the per-feature threshold search on every split feature and
    /// returns the best-scoring split.
    ///
    /// The score to beat starts at the criterion of the unsplit data. Ties
    /// are accepted, so a later feature replaces an equally good earlier one.
    /// The search ends early once the best score is exactly zero. Returns
    /// `None` if no feature produced an acceptable split.
    ///
    /// # Panics
    ///
    /// Panics if `data.all_split_features()` returns `None`.
    fn find_split<Sample, Training>(
        &self,
        data: &mut Training,
    ) -> Option<Split<Sample::ThetaSplit, Sample::Feature>>
    where
        Sample: SampleDescription,
        Training: ?Sized + TrainingData<Sample>,
    {
        let mut lowest = Training::Criterion::from_dataset(data).get();
        let mut chosen = None;

        let all_features = data
            .all_split_features()
            .expect("Dataset does not support iteration over features.");

        for theta in all_features {
            // Nothing can score below zero, so stop looking.
            if lowest == 0.0 {
                break;
            }

            let (score, split) = find_best_split_for_feature(data, &theta);
            if score <= lowest {
                lowest = score;
                chosen = split;
            }
        }

        chosen
    }
}
/// Finds the best threshold for a single split feature.
///
/// Sorts `data` with `sort_data(theta)`, then visits the samples in that
/// order while maintaining two criterion accumulators: one for the samples
/// not yet visited and one for those already visited. Every time the feature
/// value changes from one sample to the next, the split between the two
/// values is scored as
/// `(unvisited.get_weighted() + visited.get_weighted()) / n_samples`.
///
/// Returns the lowest score together with its split. Ties are resolved in
/// favor of the later boundary. If the feature never changes value (or the
/// dataset is empty), the result is `(f64::INFINITY, None)`.
fn find_best_split_for_feature<Sample, Training>(
    data: &mut Training,
    theta: &Sample::ThetaSplit,
) -> (f64, Option<Split<Sample::ThetaSplit, Sample::Feature>>)
where
    Sample: SampleDescription,
    Training: ?Sized + TrainingData<Sample>,
{
    data.sort_data(theta);

    // The sample count does not change while visiting, so compute it once
    // instead of once per visited sample.
    let n_samples = data.n_samples() as f64;

    let mut best_criterion = f64::INFINITY;
    let mut best_split = None;

    // Starts out covering the whole dataset and shrinks as samples are visited.
    let mut unvisited_crit = Training::Criterion::from_dataset(data);
    // Starts out empty and grows as samples are visited.
    let mut visited_crit = Training::Criterion::new();

    // Feature value of the most recent boundary sample, `None` before the
    // first sample has been seen.
    let mut prev_sf: Option<Sample::Feature> = None;

    data.visit_samples(|sample| {
        let sf = sample.sample_as_split_feature(theta);

        // Equal feature values cannot be separated by a threshold, so such
        // samples are only moved between the accumulators.
        let same_as_prev = match prev_sf {
            Some(ref psf) => psf == &sf,
            None => false,
        };

        if !same_as_prev {
            // Score the partition "everything visited so far" versus "this
            // sample and everything after it".
            let criterion =
                (unvisited_crit.get_weighted() + visited_crit.get_weighted()) / n_samples;
            if criterion <= best_criterion {
                // Before the first sample there is no lower neighbor to put a
                // threshold between, hence no candidate split.
                if let Some(ref psf) = prev_sf {
                    best_criterion = criterion;
                    best_split = Some(Split {
                        theta: theta.clone(),
                        threshold: psf.split_between(&sf),
                    });
                }
            }
        }

        unvisited_crit.remove_sample(sample);
        visited_crit.add_sample(sample);

        // Only a changed value replaces the remembered one, so the stored
        // value is always the first sample of its run of equal values.
        if !same_as_prev {
            prev_sf = Some(sf);
        }
    });

    (best_criterion, best_split)
}
#[cfg(test)]
mod tests {
    use super::*;
    use split::BestRandomSplit;
    use testdata::Sample;

    /// `BestRandomSplit` finds the only possible split on a two-sample set
    /// and, given enough attempts, the clean split on a larger set.
    #[test]
    fn best_random_split() {
        // One feature, two samples: any threshold between them is valid.
        let two_points: &mut [_] = &mut [
            Sample::new(&[0.0], 1.0),
            Sample::new(&[1.0], 2.0),
        ];
        let single_try = BestRandomSplit::new(1);
        let found = single_try.find_split(two_points).unwrap();
        assert_eq!(found.theta, 0);
        assert!(found.threshold >= 0.0 && found.threshold <= 1.0);

        // The target jumps from {1, 2} to {11, 12} between values 3 and 4 of
        // the second column; the first column alternates independently.
        let stepped: &mut [_] = &mut [
            Sample::new(&[41.0, 0.0], 1.0),
            Sample::new(&[41.0, 1.0], 2.0),
            Sample::new(&[43.0, 2.0], 1.0),
            Sample::new(&[43.0, 3.0], 2.0),
            Sample::new(&[41.0, 4.0], 11.0),
            Sample::new(&[41.0, 5.0], 12.0),
            Sample::new(&[43.0, 6.0], 11.0),
            Sample::new(&[43.0, 7.0], 12.0),
        ];
        let many_tries = BestRandomSplit::new(100);
        let found = many_tries.find_split(stepped).unwrap();
        assert_eq!(found.theta, 1);
        assert!(found.threshold >= 3.0 && found.threshold <= 4.0);
    }

    /// `BestSplit` picks the third column, where the target jumps between
    /// values 3 and 4; the first column is constant.
    #[test]
    fn best_split() {
        let samples: &mut [_] = &mut [
            Sample::new(&[0.0, 41.0, 0.0], 1.0),
            Sample::new(&[0.0, 41.0, 1.0], 2.0),
            Sample::new(&[0.0, 43.0, 2.0], 1.0),
            Sample::new(&[0.0, 43.0, 3.0], 2.0),
            Sample::new(&[0.0, 41.0, 4.0], 11.0),
            Sample::new(&[0.0, 41.0, 5.0], 12.0),
            Sample::new(&[0.0, 43.0, 6.0], 11.0),
            Sample::new(&[0.0, 43.0, 7.0], 12.0),
            Sample::new(&[0.0, 42.0, 8.0], 11.0),
            Sample::new(&[0.0, 42.0, 9.0], 12.0),
        ];
        let finder = BestSplit::new();
        let found = finder.find_split(samples).unwrap();
        assert_eq!(found.theta, 2);
        assert!(found.threshold >= 3.0 && found.threshold <= 4.0);
    }

    /// With a budget of three features, `BestSplitRandomFeature` examines
    /// every column of the three-column set and agrees with `best_split`.
    #[test]
    fn best_split_random_feature() {
        let samples: &mut [_] = &mut [
            Sample::new(&[0.0, 41.0, 0.0], 1.0),
            Sample::new(&[0.0, 41.0, 1.0], 2.0),
            Sample::new(&[0.0, 43.0, 2.0], 1.0),
            Sample::new(&[0.0, 43.0, 3.0], 2.0),
            Sample::new(&[0.0, 41.0, 4.0], 11.0),
            Sample::new(&[0.0, 41.0, 5.0], 12.0),
            Sample::new(&[0.0, 43.0, 6.0], 11.0),
            Sample::new(&[0.0, 43.0, 7.0], 12.0),
            Sample::new(&[0.0, 42.0, 8.0], 11.0),
            Sample::new(&[0.0, 42.0, 9.0], 12.0),
        ];
        let finder = BestSplitRandomFeature::new(3);
        let found = finder.find_split(samples).unwrap();
        assert_eq!(found.theta, 2);
        assert!(found.threshold >= 3.0 && found.threshold <= 4.0);
    }
}