// entrenar/eval/evaluator/kfold.rs

/// K-fold cross-validation splitter configuration.
///
/// Holds the settings used to divide a range of sample indices into
/// train/test folds.
#[derive(Clone, Debug)]
pub struct KFold {
    /// Number of folds to produce.
    n_splits: usize,
    /// Whether sample indices are shuffled before being divided into folds.
    shuffle: bool,
    /// Seed for the deterministic shuffle; only consulted when `shuffle` is set.
    seed: u64,
}
10
11impl KFold {
12 pub fn new(n_splits: usize) -> Self {
14 Self { n_splits, shuffle: true, seed: 42 }
15 }
16
17 pub fn with_seed(mut self, seed: u64) -> Self {
19 self.seed = seed;
20 self
21 }
22
23 pub fn without_shuffle(mut self) -> Self {
25 self.shuffle = false;
26 self
27 }
28
29 pub fn split(&self, n_samples: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
31 let mut indices: Vec<usize> = (0..n_samples).collect();
32
33 if self.shuffle {
34 let mut rng_state = self.seed;
36 for i in (1..n_samples).rev() {
37 rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
38 let j = (rng_state >> 33) as usize % (i + 1);
39 indices.swap(i, j);
40 }
41 }
42
43 let fold_size = n_samples / self.n_splits;
44 let remainder = n_samples % self.n_splits;
45
46 let mut folds = Vec::with_capacity(self.n_splits);
47 let mut start = 0;
48
49 for i in 0..self.n_splits {
50 let extra = usize::from(i < remainder);
51 let end = start + fold_size + extra;
52
53 let test_indices: Vec<usize> = indices[start..end].to_vec();
54 let train_indices: Vec<usize> =
55 indices[..start].iter().chain(indices[end..].iter()).copied().collect();
56
57 folds.push((train_indices, test_indices));
58 start = end;
59 }
60
61 folds
62 }
63}