Skip to main content

entrenar/eval/evaluator/
kfold.rs

//! K-Fold cross-validation splitter.
2
/// K-Fold cross-validation splitter.
///
/// Partitions `n_samples` sample indices into `n_splits` contiguous folds,
/// optionally after a seeded shuffle. Each fold is used once as the test set
/// while the remaining folds form the training set.
///
/// Shuffling is enabled by default with seed `42`; see [`KFold::with_seed`]
/// and [`KFold::without_shuffle`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct KFold {
    /// Number of folds to produce.
    n_splits: usize,
    /// Whether indices are shuffled before being partitioned into folds.
    shuffle: bool,
    /// Seed for the deterministic shuffle (unused when `shuffle` is false).
    seed: u64,
}

impl KFold {
    /// Create a new splitter producing `n_splits` folds, with shuffling
    /// enabled and seed `42`.
    ///
    /// `n_splits` is validated when [`KFold::split`] is called.
    #[must_use]
    pub fn new(n_splits: usize) -> Self {
        Self { n_splits, shuffle: true, seed: 42 }
    }

    /// Set the random seed used for shuffling.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }

    /// Disable shuffling so folds are taken over `0..n_samples` in order.
    #[must_use]
    pub fn without_shuffle(mut self) -> Self {
        self.shuffle = false;
        self
    }

    /// Generate `(train_indices, test_indices)` pairs, one per fold.
    ///
    /// Folds are contiguous slices of the (optionally shuffled) index list.
    /// Every fold receives `n_samples / n_splits` test indices, and the first
    /// `n_samples % n_splits` folds receive one extra. Each index in
    /// `0..n_samples` appears in exactly one test set. The train set of a
    /// fold is every index not in its test set. If `n_splits > n_samples`,
    /// the trailing folds have empty test sets.
    ///
    /// # Panics
    ///
    /// Panics if `n_splits` is zero.
    #[must_use]
    pub fn split(&self, n_samples: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
        assert!(
            self.n_splits > 0,
            "KFold::split: n_splits must be at least 1, got 0"
        );

        let indices = self.ordered_indices(n_samples);

        let fold_size = n_samples / self.n_splits;
        let remainder = n_samples % self.n_splits;

        let mut folds = Vec::with_capacity(self.n_splits);
        let mut start = 0;

        for i in 0..self.n_splits {
            // The first `remainder` folds each absorb one leftover sample.
            let end = start + fold_size + usize::from(i < remainder);

            let test_indices = indices[start..end].to_vec();
            let train_indices = [&indices[..start], &indices[end..]].concat();

            folds.push((train_indices, test_indices));
            start = end;
        }

        folds
    }

    /// Return `0..n_samples`, Fisher–Yates shuffled with a seeded LCG when
    /// shuffling is enabled, otherwise in ascending order.
    fn ordered_indices(&self, n_samples: usize) -> Vec<usize> {
        let mut indices: Vec<usize> = (0..n_samples).collect();
        if !self.shuffle {
            return indices;
        }

        // 64-bit LCG (multiplier 6364136223846793005, increment 1), using the
        // high bits of the state. These constants and the bit extraction must
        // stay fixed so that a given seed keeps producing the same folds; the
        // small modulo bias of `% (i + 1)` is accepted for that reason.
        let mut rng_state = self.seed;
        for i in (1..n_samples).rev() {
            rng_state = rng_state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1);
            let j = (rng_state >> 33) as usize % (i + 1);
            indices.swap(i, j);
        }
        indices
    }
}