use std::fmt::{Debug, Display};
use crate::linalg::basic::arrays::Array2;
use crate::model_selection::BaseKFold;
use crate::rand_custom::get_rng_impl;
use rand::seq::SliceRandom;
/// K-fold cross-validation splitter configuration.
///
/// Rows are divided into `n_splits` contiguous folds (optionally after shuffling);
/// each fold serves once as the test set while the remaining rows form the train set.
#[derive(Debug, Clone)]
pub struct KFold {
    /// Number of folds. `split` panics when this is smaller than 2.
    pub n_splits: usize,
    /// Whether row indices are shuffled before being assigned to folds.
    pub shuffle: bool,
    /// Optional seed forwarded to the crate's RNG factory when shuffling.
    pub seed: Option<u64>,
}
impl KFold {
    /// Partitions the row indices of `x` into `self.n_splits` test folds.
    ///
    /// When `self.shuffle` is set, the indices `0..n_samples` are shuffled first using an
    /// RNG obtained from `get_rng_impl(self.seed)`. Folds are then taken as consecutive
    /// runs. The first `n_samples % n_splits` folds receive one extra row, so fold sizes
    /// differ by at most one.
    ///
    /// # Panics
    ///
    /// Panics if `self.n_splits` is zero.
    fn test_indices<T: Debug + Display + Copy + Sized, M: Array2<T>>(
        &self,
        x: &M,
    ) -> Vec<Vec<usize>> {
        assert!(self.n_splits > 0, "Number of splits must be positive");
        let n_samples: usize = x.shape().0;
        let mut indices: Vec<usize> = (0..n_samples).collect();
        if self.shuffle {
            // Only construct the RNG when it is actually needed.
            let mut rng = get_rng_impl(self.seed);
            indices.shuffle(&mut rng);
        }

        let base_size = n_samples / self.n_splits;
        let n_larger = n_samples % self.n_splits;
        let mut folds: Vec<Vec<usize>> = Vec::with_capacity(self.n_splits);
        let mut start: usize = 0;
        for fold in 0..self.n_splits {
            // The leading `n_larger` folds absorb the remainder, one extra row each.
            let stop = start + base_size + usize::from(fold < n_larger);
            folds.push(indices[start..stop].to_vec());
            start = stop;
        }
        folds
    }

    /// Converts each fold from `test_indices` into a boolean mask over all rows of `x`.
    ///
    /// `mask[i]` is `true` exactly when row `i` belongs to that fold's test set.
    fn test_masks<T: Debug + Display + Copy + Sized, M: Array2<T>>(&self, x: &M) -> Vec<Vec<bool>> {
        let n_samples = x.shape().0;
        self.test_indices(x)
            .into_iter()
            .map(|fold| {
                let mut mask = vec![false; n_samples];
                for row in fold {
                    mask[row] = true;
                }
                mask
            })
            .collect()
    }
}
impl Default for KFold {
fn default() -> KFold {
KFold {
n_splits: 3,
shuffle: true,
seed: Option::None,
}
}
}
impl KFold {
    /// Returns this configuration with the number of folds set to `n_splits`.
    ///
    /// `split` panics when the value is smaller than 2.
    #[must_use]
    pub fn with_n_splits(mut self, n_splits: usize) -> Self {
        self.n_splits = n_splits;
        self
    }

    /// Returns this configuration with row shuffling enabled or disabled.
    #[must_use]
    pub fn with_shuffle(mut self, shuffle: bool) -> Self {
        self.shuffle = shuffle;
        self
    }

    /// Returns this configuration with the given optional shuffle seed.
    #[must_use]
    pub fn with_seed(mut self, seed: Option<u64>) -> Self {
        self.seed = seed;
        self
    }
}
/// Iterator over `(train_indices, test_indices)` pairs, one per fold.
///
/// For each fold, the test indices are the rows whose mask entry is `true`. The train
/// indices are all remaining rows. Both lists are in ascending order.
#[derive(Debug, Clone)]
pub struct KFoldIter {
    /// Row indices; only the length is used to enumerate row positions.
    indices: Vec<usize>,
    /// Per-fold boolean test masks. `next` pops from the back, so the constructor is
    /// expected to store them in reverse fold order.
    test_indices: Vec<Vec<bool>>,
}

impl Iterator for KFoldIter {
    type Item = (Vec<usize>, Vec<usize>);

    fn next(&mut self) -> Option<(Vec<usize>, Vec<usize>)> {
        let mask = self.test_indices.pop()?;
        // Single pass over row positions: masked rows go to the test set, the rest to
        // the training set. `partition` preserves the ascending order of the range.
        let (test, train): (Vec<usize>, Vec<usize>) =
            (0..self.indices.len()).partition(|&row| mask[row]);
        Some((train, test))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exactly one item remains per stored mask.
        let remaining = self.test_indices.len();
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for KFoldIter {}
impl BaseKFold for KFold {
    type Output = KFoldIter;

    /// Returns the configured number of folds.
    fn n_splits(&self) -> usize {
        self.n_splits
    }

    /// Splits the rows of `x` into `n_splits` train/test partitions.
    ///
    /// The returned iterator yields `(train_indices, test_indices)` for each fold in turn.
    ///
    /// # Panics
    ///
    /// Panics if `n_splits < 2`. Also panics if `n_splits` exceeds the number of rows in
    /// `x`, which would otherwise produce folds with an empty test set.
    fn split<T: Debug + Display + Copy + Sized, M: Array2<T>>(&self, x: &M) -> Self::Output {
        if self.n_splits < 2 {
            panic!("Number of splits is too small: {}", self.n_splits);
        }
        let n_samples: usize = x.shape().0;
        if self.n_splits > n_samples {
            panic!(
                "Cannot have number of splits {} greater than the number of samples: {}",
                self.n_splits, n_samples
            );
        }
        let indices: Vec<usize> = (0..n_samples).collect();
        let mut test_indices = self.test_masks(x);
        // `KFoldIter::next` pops from the back, so reverse to yield folds in order.
        test_indices.reverse();
        KFoldIter {
            indices,
            test_indices,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::basic::matrix::DenseMatrix;

    /// Builds an unshuffled `KFold` with the given number of splits.
    fn ordered(n_splits: usize) -> KFold {
        KFold {
            n_splits,
            shuffle: false,
            seed: Option::None,
        }
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn run_kfold_return_test_indices_simple() {
        let k = ordered(3);
        let x: DenseMatrix<f64> = DenseMatrix::rand(33, 100);
        let test_indices = k.test_indices(&x);
        assert_eq!(test_indices[0], (0..11).collect::<Vec<usize>>());
        assert_eq!(test_indices[1], (11..22).collect::<Vec<usize>>());
        assert_eq!(test_indices[2], (22..33).collect::<Vec<usize>>());
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn run_kfold_return_test_indices_odd() {
        let k = ordered(3);
        let x: DenseMatrix<f64> = DenseMatrix::rand(34, 100);
        let test_indices = k.test_indices(&x);
        assert_eq!(test_indices[0], (0..12).collect::<Vec<usize>>());
        assert_eq!(test_indices[1], (12..23).collect::<Vec<usize>>());
        assert_eq!(test_indices[2], (23..34).collect::<Vec<usize>>());
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn run_kfold_return_test_mask_simple() {
        let k = ordered(2);
        let x: DenseMatrix<f64> = DenseMatrix::rand(22, 100);
        let test_masks = k.test_masks(&x);
        assert!(test_masks[0][0..11].iter().all(|&t| t));
        assert!(test_masks[0][11..22].iter().all(|&t| !t));
        assert!(test_masks[1][0..11].iter().all(|&t| !t));
        assert!(test_masks[1][11..22].iter().all(|&t| t));
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn run_kfold_return_split_simple() {
        let k = ordered(2);
        let x: DenseMatrix<f64> = DenseMatrix::rand(22, 100);
        let train_test_splits: Vec<(Vec<usize>, Vec<usize>)> = k.split(&x).collect();
        assert_eq!(train_test_splits[0].1, (0..11).collect::<Vec<usize>>());
        assert_eq!(train_test_splits[0].0, (11..22).collect::<Vec<usize>>());
        assert_eq!(train_test_splits[1].0, (0..11).collect::<Vec<usize>>());
        assert_eq!(train_test_splits[1].1, (11..22).collect::<Vec<usize>>());
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn run_kfold_return_split_simple_shuffle() {
        let k = KFold {
            n_splits: 2,
            ..KFold::default()
        };
        let x: DenseMatrix<f64> = DenseMatrix::rand(23, 100);
        let train_test_splits: Vec<(Vec<usize>, Vec<usize>)> = k.split(&x).collect();
        assert_eq!(train_test_splits[0].1.len(), 12_usize);
        assert_eq!(train_test_splits[0].0.len(), 11_usize);
        assert_eq!(train_test_splits[1].0.len(), 12_usize);
        assert_eq!(train_test_splits[1].1.len(), 11_usize);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn numpy_parity_test() {
        let k = ordered(3);
        let x: DenseMatrix<f64> = DenseMatrix::rand(10, 4);
        let expected: Vec<(Vec<usize>, Vec<usize>)> = vec![
            (vec![4, 5, 6, 7, 8, 9], vec![0, 1, 2, 3]),
            (vec![0, 1, 2, 3, 7, 8, 9], vec![4, 5, 6]),
            (vec![0, 1, 2, 3, 4, 5, 6], vec![7, 8, 9]),
        ];
        for ((train, test), (expected_train, expected_test)) in k.split(&x).zip(expected) {
            assert_eq!(test, expected_test);
            assert_eq!(train, expected_train);
        }
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn numpy_parity_test_shuffle() {
        let k = KFold {
            n_splits: 3,
            ..KFold::default()
        };
        let x: DenseMatrix<f64> = DenseMatrix::rand(10, 4);
        let expected: Vec<(Vec<usize>, Vec<usize>)> = vec![
            (vec![4, 5, 6, 7, 8, 9], vec![0, 1, 2, 3]),
            (vec![0, 1, 2, 3, 7, 8, 9], vec![4, 5, 6]),
            (vec![0, 1, 2, 3, 4, 5, 6], vec![7, 8, 9]),
        ];
        for ((train, test), (expected_train, expected_test)) in k.split(&x).zip(expected) {
            assert_eq!(test.len(), expected_test.len());
            assert_eq!(train.len(), expected_train.len());
        }
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn run_kfold_builder_sets_fields() {
        let k = KFold::default()
            .with_n_splits(5)
            .with_shuffle(false)
            .with_seed(Some(1));
        assert_eq!(k.n_splits, 5);
        assert!(!k.shuffle);
        assert_eq!(k.seed, Some(1));
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn run_kfold_seeded_shuffle_is_reproducible_and_partitions() {
        let n_samples = 10;
        let k = KFold::default().with_n_splits(4).with_seed(Some(42));
        let x: DenseMatrix<f64> = DenseMatrix::rand(n_samples, 3);
        let first: Vec<(Vec<usize>, Vec<usize>)> = k.split(&x).collect();
        let second: Vec<(Vec<usize>, Vec<usize>)> = k.split(&x).collect();
        assert_eq!(first, second);
        assert_eq!(first.len(), 4);

        // Every row must appear in exactly one test fold, and never in its own train set.
        let mut seen = vec![0usize; n_samples];
        for (train, test) in &first {
            assert_eq!(train.len() + test.len(), n_samples);
            assert!(train.iter().all(|i| !test.contains(i)));
            for &i in test {
                seen[i] += 1;
            }
        }
        assert!(seen.iter().all(|&count| count == 1));
    }
}