use rand::prelude::*;
use colored::Colorize;
use crate::Sample;
use std::iter::Iterator;
/// Column width used to right-align the train/test row counts in verbose output.
const WIDTH: usize = 9;

/// Iterator over `(train, test)` splits of a borrowed [`Sample`] for cross-validation.
///
/// Construct with [`CrossValidation::new`], configure through the builder-style
/// methods, then iterate: every call to `next` yields one fold.
pub struct CrossValidation<'a> {
    // --- data -------------------------------------------------------------
    /// The data being split; borrowed for as long as the iterator lives.
    sample: &'a Sample,
    /// Row permutation handed to `Sample::split` (identity unless `shuffle` was called).
    ix: Vec<usize>,

    // --- configuration ----------------------------------------------------
    /// Number of rows that go into the training part of each fold.
    train_size: usize,
    /// Total number of folds the iterator yields.
    n_folds: usize,
    /// Seed for the RNG used by `shuffle`.
    seed: u64,
    /// When `true`, a summary line is printed for every fold.
    verbose: bool,

    // --- iteration state --------------------------------------------------
    /// Zero-based index of the next fold to produce.
    current_fold: usize,
}
impl<'a> CrossValidation<'a> {
    /// Creates a cross-validation iterator over `sample`.
    ///
    /// Defaults: 5 folds, a training ratio of `0.8`, seed `1234`, verbose
    /// output off, and rows visited in their original order (call
    /// [`Self::shuffle`] to permute them).
    #[inline]
    pub fn new(sample: &'a Sample) -> Self {
        let n_sample = sample.shape().0;
        // `as usize` truncates toward zero, so the training part is rounded
        // down and the held-out part absorbs the remainder.
        let train_size = (n_sample as f64 * 0.8) as usize;
        let ix = (0..n_sample).collect::<Vec<_>>();
        Self {
            current_fold: 0,
            n_folds: 5,
            seed: 1234,
            verbose: false,
            train_size,
            sample,
            ix,
        }
    }

    /// Sets the fraction of rows used for training in each fold.
    ///
    /// The number of training rows is `floor(ratio * n_sample)`. The
    /// remaining `n_sample - train_size` rows determine the size of each
    /// held-out (test) window.
    ///
    /// # Panics
    /// Panics unless `0 < ratio < 1`. NaN also panics.
    #[inline]
    pub fn train_ratio(mut self, ratio: f64) -> Self {
        assert!(
            0f64 < ratio && ratio < 1f64,
            "Training ratio should be in `(0, 1)`."
        );
        let n_sample = self.sample.shape().0 as f64;
        self.train_size = (ratio * n_sample) as usize;
        self
    }

    /// Sets how many folds the iterator yields.
    ///
    /// The size of each test window comes from the training ratio, not from
    /// `n_folds`. The windows therefore tile the whole sample only when
    /// `n_folds * (n_sample - train_size) == n_sample`.
    #[inline]
    pub fn n_folds(mut self, n_folds: usize) -> Self {
        self.n_folds = n_folds;
        self
    }

    /// Sets the seed used by [`Self::shuffle`].
    ///
    /// The seed is read when `shuffle` is called, so set it *before* shuffling.
    #[inline]
    pub fn seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }

    /// Enables or disables the per-fold summary line printed by the iterator.
    #[inline]
    pub fn verbose(mut self, verbose: bool) -> Self {
        self.verbose = verbose;
        self
    }

    /// Permutes the row order with a `StdRng` seeded from the current seed.
    ///
    /// The shuffle is applied to the current order. Calling this twice does
    /// not reproduce the single-shuffle order.
    #[inline]
    pub fn shuffle(mut self) -> Self {
        let mut rng = StdRng::seed_from_u64(self.seed);
        self.ix.shuffle(&mut rng);
        self
    }

    /// Returns the `(train, test)` pair for fold `i`.
    ///
    /// The held-out window is `[i * test_size, (i + 1) * test_size)`, clamped
    /// to the sample length. NOTE(review): this assumes `Sample::split` uses
    /// `ix[start..end]` as the test rows and the rest as training rows.
    /// Confirm against its definition.
    #[inline]
    fn fold_at(&self, i: usize) -> (Sample, Sample) {
        let sample_size = self.sample.shape().0;
        let test_size = sample_size.saturating_sub(self.train_size);
        // Clamp the start as well as the end. Because `train_size` is rounded
        // down, `n_folds * test_size` can exceed `sample_size`. With the
        // defaults and 11 rows, test_size is 3 and fold 4 would start at 12.
        // Without this clamp, `start > end` would reach `split`.
        // Windows that begin past the end are now empty instead of inverted.
        let start = i.saturating_mul(test_size).min(sample_size);
        let end = start.saturating_add(test_size).min(sample_size);
        self.sample.split(&self.ix, start, end)
    }
}
impl<'a> Iterator for CrossValidation<'a> {
    /// A `(train, test)` pair of samples for one fold.
    type Item = (Sample, Sample);

    /// Yields the next fold, or `None` once `n_folds` folds have been produced.
    ///
    /// When verbose output is enabled, prints one colored line per fold. The
    /// line holds the 1-based fold number and the train and test row counts.
    fn next(&mut self) -> Option<Self::Item> {
        if self.current_fold >= self.n_folds {
            return None;
        }
        let output = self.fold_at(self.current_fold);
        self.current_fold += 1;
        if self.verbose {
            let train_size = output.0.shape().0;
            let test_size = output.1.shape().0;
            // `current_fold` was already incremented, so the printed number is 1-based.
            println!(
                "{} {} {}",
                format!(" [{: >3}'th fold]", self.current_fold).bold().red(),
                format!("[TRAIN {:>WIDTH$}]", train_size).bold().green(),
                format!("[TEST {:>WIDTH$}]", test_size).bold().yellow(),
            );
        }
        Some(output)
    }

    /// Reports the exact number of folds still to be yielded.
    ///
    /// This lets `collect` preallocate and makes [`ExactSizeIterator::len`] available.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.n_folds.saturating_sub(self.current_fold);
        (remaining, Some(remaining))
    }
}

/// `next` returns `Some` exactly while `current_fold < n_folds`, so
/// `size_hint` above is exact.
impl ExactSizeIterator for CrossValidation<'_> {}