use crate::error::{Error, Result};
/// Builds the train/test index pairs for k-fold cross-validation.
///
/// The indices `0..n_samples` are optionally shuffled and then cut into
/// `n_folds` contiguous folds. Fold sizes differ by at most one: the first
/// `n_samples % n_folds` folds receive one extra index. Each fold serves once
/// as the test set, while every other index, kept in its (possibly shuffled)
/// order, forms the matching training set.
///
/// # Arguments
///
/// * `n_samples` - Number of samples to split; indices run from `0` to `n_samples - 1`.
/// * `n_folds` - Number of folds to produce; must be at least 2.
/// * `shuffle` - When `true`, the indices are shuffled before being cut into folds.
/// * `seed` - Seed for the shuffle; `42` is used when `None`. Ignored unless `shuffle` is `true`.
///
/// # Returns
///
/// One `(train_indices, test_indices)` pair per fold, in fold order.
///
/// # Errors
///
/// * `Error::InvalidInput` when `n_folds < 2`.
/// * `Error::InsufficientData` when `n_samples < n_folds`.
pub fn create_kfold_splits(
    n_samples: usize,
    n_folds: usize,
    shuffle: bool,
    seed: Option<u64>,
) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
    if n_folds < 2 {
        return Err(Error::InvalidInput("n_folds must be at least 2".to_string()));
    }
    if n_samples < n_folds {
        return Err(Error::InsufficientData {
            required: n_folds,
            available: n_samples,
        });
    }

    let mut indices: Vec<usize> = (0..n_samples).collect();
    if shuffle {
        fisher_yates_shuffle(&mut indices, seed.unwrap_or(42));
    }

    let fold_size = n_samples / n_folds;
    let remainder = n_samples % n_folds;

    let mut splits: Vec<(Vec<usize>, Vec<usize>)> = Vec::with_capacity(n_folds);
    let mut start = 0;
    for fold_idx in 0..n_folds {
        // The first `remainder` folds each take one of the leftover samples.
        let size = fold_size + usize::from(fold_idx < remainder);
        let end = start + size;

        let test_indices = indices[start..end].to_vec();

        // Every fold is a contiguous slice of `indices`, so the training set is
        // simply what lies before and after it. This yields the same elements
        // in the same order as filtering out the test indices one by one, but
        // without a linear membership scan per index.
        let mut train_indices = Vec::with_capacity(n_samples - size);
        train_indices.extend_from_slice(&indices[..start]);
        train_indices.extend_from_slice(&indices[end..]);

        splits.push((train_indices, test_indices));
        start = end;
    }

    Ok(splits)
}
/// Shuffles `indices` in place with a Fisher–Yates pass driven by a seeded
/// generator.
///
/// The walk goes from the last position down to position 1; each position is
/// swapped with a position drawn from `0..=position`. The generator is seeded
/// only from `seed`, so the same seed applied to a slice of the same length
/// always produces the same permutation.
///
/// An empty slice is left untouched.
pub fn fisher_yates_shuffle(indices: &mut [usize], seed: u64) {
    let len = indices.len();
    if len == 0 {
        return;
    }

    let mut generator = Lcg::new(seed);
    let mut upper = len - 1;
    while upper >= 1 {
        let pick = generator.next_usize(0, upper);
        indices.swap(upper, pick);
        upper -= 1;
    }
}
/// Minimal linear congruential generator that drives the seeded shuffle.
///
/// Not suitable for cryptographic use. `next` emits only the low 32 bits of
/// the state, and those bits depend solely on the low 32 bits of the seed, so
/// seeds that differ only in their upper 32 bits yield the same `next` stream.
#[derive(Debug, Clone)]
struct Lcg {
    state: u64,
}

impl Lcg {
    /// Creates a generator whose initial state is `seed + 1` (wrapping).
    fn new(seed: u64) -> Self {
        Lcg { state: seed.wrapping_add(1) }
    }

    /// Advances the generator by one step and returns the full 64-bit state.
    #[inline]
    fn step(&mut self) -> u64 {
        const A: u64 = 1103515245;
        const C: u64 = 12345;
        self.state = self.state.wrapping_mul(A).wrapping_add(C);
        self.state
    }

    /// Advances the generator and returns the low 32 bits of the new state.
    #[inline]
    fn next(&mut self) -> u32 {
        (self.step() & 0xFFFF_FFFF) as u32
    }

    /// Builds a 64-bit value from the upper halves of two consecutive states.
    ///
    /// Only used for ranges too wide for a single 32-bit draw. The upper state
    /// bits are used because the low 32 bits of consecutive states fully
    /// determine one another and therefore cannot cover a 64-bit range.
    fn next_u64(&mut self) -> u64 {
        let high = self.step() >> 32;
        let low = self.step() >> 32;
        (high << 32) | low
    }

    /// Returns a value in the inclusive range `min..=max` using rejection
    /// sampling, so every value in the range is equally likely with respect to
    /// the underlying draws.
    ///
    /// When `max <= min` the range holds at most one value and `min` is
    /// returned without consuming any generator output.
    ///
    /// Ranges of at most `u32::MAX` values are served from single 32-bit
    /// draws. Wider ranges are served from 64-bit draws instead of truncating
    /// the range to 32 bits, which would divide by zero for a range of exactly
    /// 2^32 values and leave the upper part of larger ranges unreachable.
    #[inline]
    fn next_usize(&mut self, min: usize, max: usize) -> usize {
        if max <= min {
            return min;
        }

        // Work with `span = range - 1` so that a range covering every `usize`
        // value cannot overflow. The cast is lossless: `usize` is at most
        // 64 bits wide on supported targets.
        let span = (max - min) as u64;

        if span < u64::from(u32::MAX) {
            // `span + 1 <= u32::MAX`, so the narrowing cast cannot truncate.
            let range = (span + 1) as u32;
            // Largest multiple of `range` that fits; draws at or above it are
            // rejected to avoid modulo bias.
            let threshold = (u32::MAX / range) * range;
            loop {
                let val = self.next();
                if val < threshold {
                    return min + (val % range) as usize;
                }
            }
        }

        if span == u64::MAX {
            // The range is all of a 64-bit `usize` (so `min` is 0): every
            // 64-bit draw is already a valid result.
            return self.next_u64() as usize;
        }

        let range = span + 1;
        let threshold = (u64::MAX / range) * range;
        loop {
            let val = self.next_u64();
            if val < threshold {
                // `val % range <= span <= usize::MAX - min`, so neither the
                // cast nor the addition can overflow.
                return min + (val % range) as usize;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Ten samples over three folds: the first fold takes the one leftover sample.
    #[test]
    fn test_create_kfold_splits_basic() {
        let splits = create_kfold_splits(10, 3, false, None).unwrap();
        assert_eq!(splits.len(), 3);

        // (train length, test length) expected for each fold, in order.
        let expected_sizes = [(6usize, 4usize), (7, 3), (7, 3)];
        for ((train, test), (train_len, test_len)) in splits.iter().zip(expected_sizes) {
            assert_eq!(test.len(), test_len);
            assert_eq!(train.len(), train_len);
        }
    }

    /// When the sample count divides evenly, all folds have the same size.
    #[test]
    fn test_create_kfold_splits_no_remainder() {
        let splits = create_kfold_splits(9, 3, false, None).unwrap();
        assert_eq!(splits.len(), 3);
        assert!(splits
            .iter()
            .all(|(train, test)| test.len() == 3 && train.len() == 6));
    }

    /// Across all folds, every sample index appears in exactly one test set.
    #[test]
    fn test_create_kfold_splits_all_samples_used() {
        let n_samples = 20;
        let splits = create_kfold_splits(n_samples, 5, false, None).unwrap();

        let mut seen: Vec<usize> = splits
            .iter()
            .flat_map(|(_, test)| test.iter().copied())
            .collect();
        seen.sort();

        assert_eq!(seen, (0..n_samples).collect::<Vec<usize>>());
    }

    /// No index is shared between the train and test sets of a fold.
    #[test]
    fn test_create_kfold_splits_no_overlap() {
        let splits = create_kfold_splits(15, 5, false, None).unwrap();
        for (train, test) in &splits {
            assert!(test.iter().all(|t| !train.contains(t)));
        }
    }

    /// Shuffled splits repeat for an identical seed and change with another seed.
    #[test]
    fn test_create_kfold_splits_with_shuffle() {
        let first = create_kfold_splits(20, 5, true, Some(42)).unwrap();
        let repeat = create_kfold_splits(20, 5, true, Some(42)).unwrap();
        assert_eq!(first.len(), repeat.len());
        for ((train_a, test_a), (train_b, test_b)) in first.iter().zip(repeat.iter()) {
            assert_eq!(train_a, train_b);
            assert_eq!(test_a, test_b);
        }

        let other_seed = create_kfold_splits(20, 5, true, Some(123)).unwrap();
        assert_ne!(first[0].1, other_seed[0].1);
    }

    /// Without shuffling the folds follow the natural index order.
    #[test]
    fn test_create_kfold_splits_no_shuffle_reproducible() {
        let first = create_kfold_splits(10, 3, false, None).unwrap();
        let second = create_kfold_splits(10, 3, false, None).unwrap();
        assert_eq!(first[0].1, vec![0, 1, 2, 3]);
        assert_eq!(first, second);
    }

    /// A single fold is rejected.
    #[test]
    fn test_create_kfold_splits_invalid_folds() {
        assert!(create_kfold_splits(10, 1, false, None).is_err());
    }

    /// Asking for more folds than samples reports both counts in the error.
    #[test]
    fn test_create_kfold_splits_insufficient_samples() {
        let result = create_kfold_splits(5, 10, false, None);
        assert!(result.is_err());
        if let Err(Error::InsufficientData { required, available }) = result {
            assert_eq!(required, 10);
            assert_eq!(available, 5);
        } else {
            panic!("Expected InsufficientData error");
        }
    }

    /// The same seed yields the same permutation.
    #[test]
    fn test_fisher_yates_shuffle_deterministic() {
        let mut left: Vec<usize> = (0..10).collect();
        let mut right: Vec<usize> = (0..10).collect();
        fisher_yates_shuffle(&mut left, 42);
        fisher_yates_shuffle(&mut right, 42);
        assert_eq!(left, right);
    }

    /// Different seeds yield different permutations for this input.
    #[test]
    fn test_fisher_yates_shuffle_different_seeds() {
        let mut left: Vec<usize> = (0..10).collect();
        let mut right: Vec<usize> = (0..10).collect();
        fisher_yates_shuffle(&mut left, 42);
        fisher_yates_shuffle(&mut right, 123);
        assert_ne!(left, right);
    }

    /// Shuffling only reorders: the multiset of elements is unchanged.
    #[test]
    fn test_fisher_yates_shuffle_permutation() {
        let original: Vec<usize> = (0..10).collect();
        let mut shuffled = original.clone();
        fisher_yates_shuffle(&mut shuffled, 42);

        let mut sorted_shuffled = shuffled.clone();
        let mut sorted_original = original.clone();
        sorted_shuffled.sort();
        sorted_original.sort();
        assert_eq!(sorted_shuffled, sorted_original);
    }

    /// Shuffling an empty slice is a no-op.
    #[test]
    fn test_fisher_yates_shuffle_empty() {
        let mut nothing: Vec<usize> = Vec::new();
        fisher_yates_shuffle(&mut nothing, 42);
        assert!(nothing.is_empty());
    }

    /// A single-element slice keeps its only element.
    #[test]
    fn test_fisher_yates_shuffle_single() {
        let mut lone = vec![42];
        fisher_yates_shuffle(&mut lone, 42);
        assert_eq!(lone, vec![42]);
    }

    /// Draws never exceed the requested upper bound.
    #[test]
    fn test_lcg_range() {
        let mut rng = Lcg::new(42);
        for _ in 0..1000 {
            assert!(rng.next_usize(0, 99) <= 99);
        }
    }

    /// Each bucket receives within 20% of its expected share of draws.
    #[test]
    fn test_lcg_uniform_distribution() {
        const DRAWS: usize = 10000;
        const BUCKETS: usize = 10;

        let mut rng = Lcg::new(42);
        let mut histogram = [0usize; BUCKETS];
        for _ in 0..DRAWS {
            histogram[rng.next_usize(0, BUCKETS - 1)] += 1;
        }

        let expected = DRAWS / BUCKETS;
        let lower = expected * 80 / 100;
        let upper = expected * 120 / 100;
        for count in histogram {
            assert!(
                count > lower && count < upper,
                "Count {} is outside expected range {}",
                count,
                expected
            );
        }
    }
}