use rand::seq::SliceRandom;
use rand::SeedableRng;
/// Row indices assigned to each subset of a holdout split.
#[derive(Debug, Clone)]
pub struct HoldoutSplit {
    /// Row indices belonging to the training subset.
    pub train: Vec<usize>,
    /// Row indices belonging to the validation subset.
    pub validation: Vec<usize>,
    /// Row indices belonging to the calibration subset.
    pub calibration: Vec<usize>,
}

impl HoldoutSplit {
    /// Number of row indices in the training subset.
    pub fn train_len(&self) -> usize {
        self.train.len()
    }

    /// Number of row indices in the validation subset.
    pub fn val_len(&self) -> usize {
        self.validation.len()
    }

    /// Number of row indices in the calibration subset.
    pub fn calib_len(&self) -> usize {
        self.calibration.len()
    }
}
/// Row indices partitioned into folds for a k-fold split.
#[derive(Debug, Clone)]
pub struct KFoldSplit {
    /// One vector of row indices per fold.
    pub folds: Vec<Vec<usize>>,
}

impl KFoldSplit {
    /// Number of folds held by this split.
    pub fn k(&self) -> usize {
        self.folds.len()
    }

    /// Builds the `(train, validation)` index pair for one fold.
    ///
    /// The validation indices are a copy of fold `fold_idx`; the training
    /// indices are the concatenation of every other fold. Both vectors are
    /// returned sorted in ascending order.
    ///
    /// # Panics
    ///
    /// Panics with `"fold_idx out of range"` when `fold_idx >= self.k()`.
    pub fn get_fold(&self, fold_idx: usize) -> (Vec<usize>, Vec<usize>) {
        assert!(fold_idx < self.folds.len(), "fold_idx out of range");

        // Gather every fold except the held-out one.
        let mut train: Vec<usize> = Vec::new();
        for (idx, fold) in self.folds.iter().enumerate() {
            if idx != fold_idx {
                train.extend_from_slice(fold);
            }
        }
        train.sort_unstable();

        let mut validation = self.folds[fold_idx].to_vec();
        validation.sort_unstable();

        (train, validation)
    }
}
/// Randomly partitions the row indices `0..num_rows` into train, validation
/// and calibration subsets.
///
/// The indices are shuffled with a `StdRng` seeded from `seed`, then consumed
/// in order:
///
/// 1. calibration takes `ceil(num_rows * calibration_ratio)` indices;
/// 2. validation takes
///    `ceil(remaining * validation_ratio / (1 - calibration_ratio))` indices,
///    i.e. `validation_ratio` is rescaled to the rows left after calibration;
/// 3. train receives everything that is left.
///
/// A ratio that is not greater than zero (including NaN) produces an empty
/// subset. Each requested count is clamped to the number of rows still
/// available, so ratios that add up to more than `1.0` (or rounding that
/// pushes a count past the remaining rows) shrink the later subsets instead
/// of panicking.
///
/// Every returned vector is sorted in ascending order.
pub fn split_holdout(
    num_rows: usize,
    validation_ratio: f32,
    calibration_ratio: f32,
    seed: u64,
) -> HoldoutSplit {
    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
    let mut indices: Vec<usize> = (0..num_rows).collect();
    indices.shuffle(&mut rng);

    // Float-to-usize `as` casts saturate (NaN/negative -> 0, too large ->
    // usize::MAX), so clamping to the rows still available is sufficient to
    // keep the `drain` ranges in bounds.
    let n_calibration = if calibration_ratio > 0.0 {
        let wanted = ((num_rows as f32) * calibration_ratio).ceil() as usize;
        wanted.min(indices.len())
    } else {
        0
    };
    let mut calibration: Vec<usize> = indices.drain(..n_calibration).collect();

    let n_validation = if validation_ratio > 0.0 {
        // Rescale the ratio because `indices` no longer holds the calibration rows.
        let wanted = ((indices.len() as f32) * validation_ratio / (1.0 - calibration_ratio))
            .ceil() as usize;
        wanted.min(indices.len())
    } else {
        0
    };
    let mut validation: Vec<usize> = indices.drain(..n_validation).collect();

    let mut train = indices;
    train.sort_unstable();
    validation.sort_unstable();
    calibration.sort_unstable();

    HoldoutSplit {
        train,
        validation,
        calibration,
    }
}
/// Randomly partitions the row indices `0..num_rows` into `k` folds.
///
/// The indices are shuffled with a `StdRng` seeded from `seed` and then cut
/// into `k` consecutive chunks. Each fold holds `num_rows / k` indices, and
/// the first `num_rows % k` folds hold one more. Every fold is sorted in
/// ascending order.
///
/// # Panics
///
/// Panics with `"K-fold requires at least 2 folds"` when `k < 2`.
pub fn split_kfold(num_rows: usize, k: usize, seed: u64) -> KFoldSplit {
    assert!(k >= 2, "K-fold requires at least 2 folds");

    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
    let mut shuffled: Vec<usize> = (0..num_rows).collect();
    shuffled.shuffle(&mut rng);

    let base_len = num_rows / k;
    let leftover = num_rows % k;

    let mut folds = Vec::with_capacity(k);
    // Peel one fold at a time off the front of the shuffled indices.
    let mut rest: &[usize] = &shuffled;
    for fold_no in 0..k {
        // The first `leftover` folds absorb the rows that do not divide evenly.
        let len = base_len + usize::from(fold_no < leftover);
        let (head, tail) = rest.split_at(len);
        rest = tail;

        let mut fold = head.to_vec();
        fold.sort_unstable();
        folds.push(fold);
    }

    KFoldSplit { folds }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// True when each element is strictly greater than the one before it.
    fn strictly_increasing(values: &[usize]) -> bool {
        values.windows(2).all(|pair| pair[0] < pair[1])
    }

    /// Collects a slice of indices into a set for overlap checks.
    fn as_set(values: &[usize]) -> HashSet<usize> {
        values.iter().copied().collect()
    }

    // Two-way split: sizes add up and validation is roughly 20% of the rows.
    #[test]
    fn test_split_holdout_basic() {
        let split = split_holdout(1000, 0.2, 0.0, 42);
        let n_val = split.val_len();
        assert_eq!(split.train_len() + n_val, 1000);
        assert!((190..=210).contains(&n_val));
        assert!(split.calibration.is_empty());
    }

    // Three-way split: every row is used exactly once across the subsets.
    #[test]
    fn test_split_holdout_three_way() {
        let split = split_holdout(1000, 0.2, 0.1, 42);
        assert_eq!(
            split.train_len() + split.val_len() + split.calib_len(),
            1000
        );

        let train_set = as_set(&split.train);
        let val_set = as_set(&split.validation);
        let calib_set = as_set(&split.calibration);
        assert!(train_set.is_disjoint(&val_set));
        assert!(train_set.is_disjoint(&calib_set));
        assert!(val_set.is_disjoint(&calib_set));
    }

    // Every subset comes back in strictly ascending order.
    #[test]
    fn test_split_holdout_sorted() {
        let split = split_holdout(1000, 0.2, 0.1, 42);
        assert!(strictly_increasing(&split.train));
        assert!(strictly_increasing(&split.validation));
        assert!(strictly_increasing(&split.calibration));
    }

    // Same seed reproduces the split; a different seed changes it.
    #[test]
    fn test_split_holdout_deterministic() {
        let first = split_holdout(1000, 0.2, 0.0, 42);
        let repeat = split_holdout(1000, 0.2, 0.0, 42);
        assert_eq!(first.train, repeat.train);
        assert_eq!(first.validation, repeat.validation);

        let other_seed = split_holdout(1000, 0.2, 0.0, 43);
        assert_ne!(first.train, other_seed.train);
    }

    // The folds together cover all 100 distinct rows.
    #[test]
    fn test_split_kfold_basic() {
        let split = split_kfold(100, 5, 42);
        assert_eq!(split.k(), 5);

        let covered: HashSet<usize> = split.folds.iter().flatten().copied().collect();
        assert_eq!(covered.len(), 100);
    }

    // No row index appears in more than one fold.
    #[test]
    fn test_split_kfold_disjoint() {
        let split = split_kfold(100, 5, 42);
        let sets: Vec<HashSet<usize>> = split.folds.iter().map(|fold| as_set(fold)).collect();
        for i in 0..5 {
            for j in (i + 1)..5 {
                assert!(sets[i].is_disjoint(&sets[j]), "Folds {} and {} overlap", i, j);
            }
        }
    }

    // Even division gives equal folds; a remainder is spread one row per fold.
    #[test]
    fn test_split_kfold_fold_sizes() {
        let even = split_kfold(100, 5, 42);
        assert!(even.folds.iter().all(|fold| fold.len() == 20));

        let uneven = split_kfold(103, 5, 42);
        let sizes: Vec<usize> = uneven.folds.iter().map(Vec::len).collect();
        assert_eq!(sizes.iter().sum::<usize>(), 103);
        assert!(sizes.iter().all(|&size| (20..=21).contains(&size)));
    }

    // Each (train, validation) pair is a sorted, disjoint cover of all rows.
    #[test]
    fn test_split_kfold_get_fold() {
        let split = split_kfold(100, 5, 42);
        for fold_idx in 0..5 {
            let (train, val) = split.get_fold(fold_idx);
            assert_eq!(train.len() + val.len(), 100);
            assert_eq!(val.len(), 20);
            assert!(as_set(&train).is_disjoint(&as_set(&val)));
            assert!(strictly_increasing(&train));
            assert!(strictly_increasing(&val));
        }
    }

    // Same seed reproduces the folds; a different seed changes them.
    #[test]
    fn test_split_kfold_deterministic() {
        let first = split_kfold(100, 5, 42);
        let repeat = split_kfold(100, 5, 42);
        assert_eq!(&first.folds[..5], &repeat.folds[..5]);

        let other_seed = split_kfold(100, 5, 43);
        assert_ne!(first.folds[0], other_seed.folds[0]);
    }

    // Each fold comes back in strictly ascending order.
    #[test]
    fn test_split_kfold_sorted() {
        let split = split_kfold(100, 5, 42);
        assert!(split.folds.iter().all(|fold| strictly_increasing(fold)));
    }

    // Fewer than two folds is rejected.
    #[test]
    #[should_panic(expected = "at least 2 folds")]
    fn test_split_kfold_invalid_k() {
        let _ = split_kfold(100, 1, 42);
    }

    // Zero ratios leave every row in the training subset.
    #[test]
    fn test_split_holdout_no_validation() {
        let split = split_holdout(1000, 0.0, 0.0, 42);
        assert_eq!(split.train_len(), 1000);
        assert_eq!(split.val_len(), 0);
        assert_eq!(split.calib_len(), 0);
    }
}