smartcore 0.5.0

Machine Learning in Rust.
Documentation
//! # KFold
//!
//! Defines the [`KFold`] k-fold cross-validator, which splits the rows of a
//! matrix into `n_splits` train/test partitions.
use std::fmt::{Debug, Display};

use crate::linalg::basic::arrays::Array2;
use crate::model_selection::BaseKFold;
use crate::rand_custom::get_rng_impl;
use rand::seq::SliceRandom;

/// K-fold cross-validator.
///
/// Splits the rows of a matrix into `n_splits` folds. Each fold is used once as
/// the test set, and the remaining rows form the training set. Build one with
/// [`KFold::default`] plus the `with_*` setters, or with a struct literal.
#[derive(Debug, Clone)]
pub struct KFold {
    /// Number of folds. Must be at least 2 (checked when splitting).
    pub n_splits: usize,
    /// Whether to shuffle the row indices before cutting them into folds.
    pub shuffle: bool,
    /// Seed for the shuffling RNG. It controls the ordering of the shuffled
    /// indices when `shuffle` is `true` and is ignored otherwise.
    pub seed: Option<u64>,
}

impl KFold {
    /// Partitions the row indices of `x` into `n_splits` test folds.
    ///
    /// The row indices `0..n_samples` are optionally shuffled (when `self.shuffle`
    /// is `true`, using an RNG obtained from `get_rng_impl(self.seed)`). They are
    /// then cut into consecutive chunks. Every fold gets `n_samples / n_splits`
    /// rows, and the first `n_samples % n_splits` folds get one extra row, so
    /// fold sizes differ by at most one.
    ///
    /// # Panics
    ///
    /// Panics if `self.n_splits` is zero (integer division by zero). Callers are
    /// expected to validate `n_splits` first, as `split` does.
    fn test_indices<T: Debug + Display + Copy + Sized, M: Array2<T>>(
        &self,
        x: &M,
    ) -> Vec<Vec<usize>> {
        // Number of samples (rows) in the matrix.
        let n_samples: usize = x.shape().0;

        let mut indices: Vec<usize> = (0..n_samples).collect();
        if self.shuffle {
            // Build the RNG only when it is used. With `seed == None`, constructing
            // it presumably draws fresh entropy, which is wasted work without shuffling.
            let mut rng = get_rng_impl(self.seed);
            indices.shuffle(&mut rng);
        }

        let base_size = n_samples / self.n_splits;
        // The remainder is spread one row at a time over the leading folds.
        let n_larger = n_samples % self.n_splits;

        let mut folds: Vec<Vec<usize>> = Vec::with_capacity(self.n_splits);
        let mut start: usize = 0;
        for fold in 0..self.n_splits {
            let size = if fold < n_larger { base_size + 1 } else { base_size };
            let stop = start + size;
            folds.push(indices[start..stop].to_vec());
            start = stop;
        }
        folds
    }

    /// Converts the folds produced by `test_indices` into boolean masks over the rows of `x`.
    ///
    /// Mask `k` has one entry per row of `x` and is `true` exactly at the rows
    /// that belong to test fold `k`.
    fn test_masks<T: Debug + Display + Copy + Sized, M: Array2<T>>(&self, x: &M) -> Vec<Vec<bool>> {
        // Hoisted: the row count is the same for every mask.
        let n_samples = x.shape().0;
        self.test_indices(x)
            .into_iter()
            .map(|fold| {
                let mut mask = vec![false; n_samples];
                for i in fold {
                    mask[i] = true;
                }
                mask
            })
            .collect()
    }
}

impl Default for KFold {
    /// Returns a three-fold splitter with shuffling enabled and no fixed seed.
    fn default() -> KFold {
        let n_splits = 3;
        KFold {
            seed: None,
            shuffle: true,
            n_splits,
        }
    }
}

impl KFold {
    /// Number of folds. Must be at least 2.
    pub fn with_n_splits(mut self, n_splits: usize) -> Self {
        self.n_splits = n_splits;
        self
    }
    /// Whether to shuffle the data before splitting into batches
    pub fn with_shuffle(mut self, shuffle: bool) -> Self {
        self.shuffle = shuffle;
        self
    }

    /// When shuffle is True, random_state affects the ordering of the indices.
    pub fn with_seed(mut self, seed: Option<u64>) -> Self {
        self.seed = seed;
        self
    }
}

/// An iterator over `(train_indices, test_indices)` pairs, one per fold.
///
/// Each item partitions the row positions `0..indices.len()`. The test part is
/// the positions flagged `true` in that fold's mask, and the train part is
/// every other position. Both parts are in ascending order.
#[derive(Debug)]
pub struct KFoldIter {
    /// Row indices of the matrix being split; only its length is consulted.
    indices: Vec<usize>,
    /// One boolean test mask per remaining fold. Folds are consumed from the
    /// end of this vector (via `pop`), so the first fold to yield is stored last.
    test_indices: Vec<Vec<bool>>,
}

impl Iterator for KFoldIter {
    type Item = (Vec<usize>, Vec<usize>);

    fn next(&mut self) -> Option<(Vec<usize>, Vec<usize>)> {
        let test_mask = self.test_indices.pop()?;
        // A single pass splits the row positions by the mask. `partition`
        // preserves ascending order on both sides, as two filter passes would.
        let (test_index, train_index): (Vec<usize>, Vec<usize>) =
            (0..self.indices.len()).partition(|&idx| test_mask[idx]);
        Some((train_index, test_index))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Every remaining mask yields exactly one item.
        let remaining = self.test_indices.len();
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for KFoldIter {}

/// [`BaseKFold`] implementation for [`KFold`]. Each split holds out one fold
/// of rows as the test set and trains on the rest.
impl BaseKFold for KFold {
    type Output = KFoldIter;

    fn n_splits(&self) -> usize {
        self.n_splits
    }

    /// Builds an iterator over the `n_splits` train/test partitions of the rows of `x`.
    ///
    /// # Panics
    ///
    /// Panics if `n_splits` is less than 2.
    fn split<T: Debug + Display + Copy + Sized, M: Array2<T>>(&self, x: &M) -> Self::Output {
        assert!(
            self.n_splits >= 2,
            "Number of splits is too small: {}",
            self.n_splits
        );
        // `KFoldIter` pops masks off the end, so store them last-fold-first.
        let test_indices: Vec<Vec<bool>> = self.test_masks(x).into_iter().rev().collect();
        KFoldIter {
            indices: (0..x.shape().0).collect(),
            test_indices,
        }
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::linalg::basic::matrix::DenseMatrix;

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn run_kfold_return_test_indices_simple() {
        let k = KFold {
            n_splits: 3,
            shuffle: false,
            seed: Option::None,
        };
        let x: DenseMatrix<f64> = DenseMatrix::rand(33, 100);
        let test_indices = k.test_indices(&x);

        assert_eq!(test_indices[0], (0..11).collect::<Vec<usize>>());
        assert_eq!(test_indices[1], (11..22).collect::<Vec<usize>>());
        assert_eq!(test_indices[2], (22..33).collect::<Vec<usize>>());
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn run_kfold_return_test_indices_odd() {
        let k = KFold {
            n_splits: 3,
            shuffle: false,
            seed: Option::None,
        };
        let x: DenseMatrix<f64> = DenseMatrix::rand(34, 100);
        let test_indices = k.test_indices(&x);

        assert_eq!(test_indices[0], (0..12).collect::<Vec<usize>>());
        assert_eq!(test_indices[1], (12..23).collect::<Vec<usize>>());
        assert_eq!(test_indices[2], (23..34).collect::<Vec<usize>>());
    }

    /// 35 rows over 3 folds leaves a remainder of 2, so the first two folds
    /// each receive one extra row.
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn run_kfold_return_test_indices_remainder_two() {
        let k = KFold {
            n_splits: 3,
            shuffle: false,
            seed: Option::None,
        };
        let x: DenseMatrix<f64> = DenseMatrix::rand(35, 10);
        let test_indices = k.test_indices(&x);

        assert_eq!(test_indices.len(), 3);
        assert_eq!(test_indices[0], (0..12).collect::<Vec<usize>>());
        assert_eq!(test_indices[1], (12..24).collect::<Vec<usize>>());
        assert_eq!(test_indices[2], (24..35).collect::<Vec<usize>>());
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn run_kfold_return_test_mask_simple() {
        let k = KFold {
            n_splits: 2,
            shuffle: false,
            seed: Option::None,
        };
        let x: DenseMatrix<f64> = DenseMatrix::rand(22, 100);
        let test_masks = k.test_masks(&x);

        assert_eq!(test_masks.len(), 2);
        assert!(test_masks.iter().all(|mask| mask.len() == 22));

        assert!(test_masks[0][..11].iter().all(|&t| t));
        assert!(test_masks[0][11..].iter().all(|&t| !t));

        assert!(test_masks[1][..11].iter().all(|&t| !t));
        assert!(test_masks[1][11..].iter().all(|&t| t));
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn run_kfold_return_split_simple() {
        let k = KFold {
            n_splits: 2,
            shuffle: false,
            seed: Option::None,
        };
        let x: DenseMatrix<f64> = DenseMatrix::rand(22, 100);
        let train_test_splits: Vec<(Vec<usize>, Vec<usize>)> = k.split(&x).collect();

        assert_eq!(train_test_splits[0].1, (0..11).collect::<Vec<usize>>());
        assert_eq!(train_test_splits[0].0, (11..22).collect::<Vec<usize>>());
        assert_eq!(train_test_splits[1].0, (0..11).collect::<Vec<usize>>());
        assert_eq!(train_test_splits[1].1, (11..22).collect::<Vec<usize>>());
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn run_kfold_return_split_simple_shuffle() {
        let k = KFold {
            n_splits: 2,
            ..KFold::default()
        };
        let x: DenseMatrix<f64> = DenseMatrix::rand(23, 100);
        let train_test_splits: Vec<(Vec<usize>, Vec<usize>)> = k.split(&x).collect();

        assert_eq!(train_test_splits[0].1.len(), 12_usize);
        assert_eq!(train_test_splits[0].0.len(), 11_usize);
        assert_eq!(train_test_splits[1].0.len(), 12_usize);
        assert_eq!(train_test_splits[1].1.len(), 11_usize);

        // Whatever the shuffle, each split must be a disjoint cover of all rows.
        for (train, test) in &train_test_splits {
            let mut all: Vec<usize> = train.iter().chain(test.iter()).copied().collect();
            all.sort_unstable();
            assert_eq!(all, (0..23).collect::<Vec<usize>>());
        }
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn numpy_parity_test() {
        let k = KFold {
            n_splits: 3,
            shuffle: false,
            seed: Option::None,
        };
        let x: DenseMatrix<f64> = DenseMatrix::rand(10, 4);
        let expected: Vec<(Vec<usize>, Vec<usize>)> = vec![
            (vec![4, 5, 6, 7, 8, 9], vec![0, 1, 2, 3]),
            (vec![0, 1, 2, 3, 7, 8, 9], vec![4, 5, 6]),
            (vec![0, 1, 2, 3, 4, 5, 6], vec![7, 8, 9]),
        ];
        let splits: Vec<(Vec<usize>, Vec<usize>)> = k.split(&x).collect();
        // `zip` stops at the shorter side, so check the fold count explicitly.
        assert_eq!(splits.len(), expected.len());
        for ((train, test), (expected_train, expected_test)) in splits.into_iter().zip(expected) {
            assert_eq!(test, expected_test);
            assert_eq!(train, expected_train);
        }
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn numpy_parity_test_shuffle() {
        let k = KFold {
            n_splits: 3,
            ..KFold::default()
        };
        let x: DenseMatrix<f64> = DenseMatrix::rand(10, 4);
        let expected: Vec<(Vec<usize>, Vec<usize>)> = vec![
            (vec![4, 5, 6, 7, 8, 9], vec![0, 1, 2, 3]),
            (vec![0, 1, 2, 3, 7, 8, 9], vec![4, 5, 6]),
            (vec![0, 1, 2, 3, 4, 5, 6], vec![7, 8, 9]),
        ];
        let splits: Vec<(Vec<usize>, Vec<usize>)> = k.split(&x).collect();
        // `zip` stops at the shorter side, so check the fold count explicitly.
        assert_eq!(splits.len(), expected.len());
        for ((train, test), (expected_train, expected_test)) in splits.into_iter().zip(expected) {
            assert_eq!(test.len(), expected_test.len());
            assert_eq!(train.len(), expected_train.len());
        }
    }

    #[test]
    #[should_panic(expected = "Number of splits is too small")]
    fn run_kfold_split_rejects_single_fold() {
        let k = KFold {
            n_splits: 1,
            shuffle: false,
            seed: Option::None,
        };
        let x: DenseMatrix<f64> = DenseMatrix::rand(10, 4);
        let _ = k.split(&x);
    }
}