use scirs2_core::ndarray::ArrayBase;
use scirs2_core::random::{rngs::StdRng, seq::SliceRandom, SeedableRng};
use std::collections::{HashMap, HashSet};
use crate::error::{MetricsError, Result};
/// Splits produced by [`nested_cross_validation`], one tuple per outer fold.
///
/// Each tuple is `(outer_train, outer_test, inner_splits)`. `inner_splits`
/// is a list of `(inner_train, inner_validation)` index pairs; every index in
/// those pairs is taken from the corresponding `outer_train`.
pub type NestedCVResult = Vec<(Vec<usize>, Vec<usize>, Vec<(Vec<usize>, Vec<usize>)>)>;
/// Builds the `(train, test)` index splits for k-fold cross-validation.
///
/// The indices `0..n` are optionally shuffled and then cut into `n_folds`
/// consecutive test folds. When `n` is not divisible by `n_folds`, the first
/// `n % n_folds` folds hold one extra sample. For every fold the training
/// indices are all remaining indices, in the same (possibly shuffled) order.
///
/// # Arguments
///
/// * `n` - number of samples.
/// * `n_folds` - number of folds to produce.
/// * `shuffle` - whether to shuffle the indices before cutting them.
/// * `random_seed` - seed for the shuffle; when `None`, the generator is
///   seeded from `scirs2_core::random::rng()`. Ignored when `shuffle` is false.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when `n <= 1`, `n_folds < 2`, or
/// `n_folds > n`.
#[allow(dead_code)]
pub fn k_fold_cross_validation(
    n: usize,
    n_folds: usize,
    shuffle: bool,
    random_seed: Option<u64>,
) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
    if n <= 1 {
        return Err(MetricsError::InvalidInput(
            "Number of samples must be greater than 1".to_string(),
        ));
    }
    if n_folds < 2 {
        return Err(MetricsError::InvalidInput(
            "Number of _folds must be at least 2".to_string(),
        ));
    }
    if n_folds > n {
        return Err(MetricsError::InvalidInput(format!(
            "Number of _folds ({}) cannot be greater than number of samples ({})",
            n_folds, n
        )));
    }

    let mut order: Vec<usize> = (0..n).collect();
    if shuffle {
        let mut generator = if let Some(seed) = random_seed {
            StdRng::seed_from_u64(seed)
        } else {
            let mut entropy = scirs2_core::random::rng();
            StdRng::from_rng(&mut entropy)
        };
        order.shuffle(&mut generator);
    }

    // Every fold gets `base` samples; the first `leftover` folds get one more.
    let base = n / n_folds;
    let leftover = n % n_folds;

    let mut splits = Vec::with_capacity(n_folds);
    let mut start = 0;
    for fold in 0..n_folds {
        let len = if fold < leftover { base + 1 } else { base };
        let (before, rest) = order.split_at(start);
        let (held_out, after) = rest.split_at(len);
        let train: Vec<usize> = before.iter().chain(after).copied().collect();
        splits.push((train, held_out.to_vec()));
        start += len;
    }
    Ok(splits)
}
/// Builds the splits for leave-one-out cross-validation.
///
/// Produces `n` splits. Split `i` uses index `i` as its single test sample
/// and every other index in `0..n`, in ascending order, as training data.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when `n <= 1`.
#[allow(dead_code)]
pub fn leave_one_out_cv(n: usize) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
    if n <= 1 {
        return Err(MetricsError::InvalidInput(
            "Number of samples must be greater than 1".to_string(),
        ));
    }

    let splits = (0..n)
        .map(|held_out| {
            // Everything before and after the held-out index, already ordered.
            let train: Vec<usize> = (0..held_out).chain(held_out + 1..n).collect();
            (train, vec![held_out])
        })
        .collect();
    Ok(splits)
}
/// Builds stratified k-fold splits so that every fold receives a share of
/// each class in `y`.
///
/// Sample indices are bucketed by class label in order of first appearance.
/// Within each class the (optionally shuffled) indices are dealt round-robin
/// to the folds, always starting at fold 0. Test indices are returned in that
/// assignment order; training indices are sorted ascending.
///
/// Classes are processed in first-appearance order rather than in hash-map
/// iteration order, so the result depends only on the inputs: the same `y`,
/// `n_folds`, `shuffle` and `random_seed` always give the same splits.
///
/// # Arguments
///
/// * `y` - class label of every sample.
/// * `n_folds` - number of folds to produce.
/// * `shuffle` - whether to shuffle the indices inside each class first.
/// * `random_seed` - seed for the shuffle; when `None`, the generator is
///   seeded from `scirs2_core::random::rng()`. Ignored when `shuffle` is false.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when `y` has at most one element,
/// `n_folds < 2`, `n_folds` exceeds the number of samples, or some class has
/// fewer than `n_folds` samples (the first such class, in order of
/// appearance, is reported).
#[allow(dead_code)]
pub fn stratified_k_fold<T>(
    y: &ArrayBase<impl scirs2_core::ndarray::Data<Elem = T>, impl scirs2_core::ndarray::Dimension>,
    n_folds: usize,
    shuffle: bool,
    random_seed: Option<u64>,
) -> Result<Vec<(Vec<usize>, Vec<usize>)>>
where
    T: Clone + std::hash::Hash + Eq + std::fmt::Debug,
{
    let n_samples = y.len();
    if n_samples <= 1 {
        return Err(MetricsError::InvalidInput(
            "Number of samples must be greater than 1".to_string(),
        ));
    }
    if n_folds < 2 {
        return Err(MetricsError::InvalidInput(
            "Number of _folds must be at least 2".to_string(),
        ));
    }
    if n_folds > n_samples {
        return Err(MetricsError::InvalidInput(format!(
            "Number of _folds ({}) cannot be greater than number of samples ({})",
            n_folds, n_samples
        )));
    }

    // `slot_of` maps a label to its position in `classes`; `classes` keeps the
    // labels in first-appearance order so every later loop is deterministic.
    let mut slot_of: HashMap<&T, usize> = HashMap::new();
    let mut classes: Vec<(&T, Vec<usize>)> = Vec::new();
    for (idx, label) in y.iter().enumerate() {
        let next_slot = classes.len();
        let slot = *slot_of.entry(label).or_insert(next_slot);
        if slot == next_slot {
            classes.push((label, Vec::new()));
        }
        classes[slot].1.push(idx);
    }

    for (label, members) in &classes {
        let class_size = members.len();
        if class_size < n_folds {
            return Err(MetricsError::InvalidInput(format!(
                "Class {:?} has only {} samples, which is less than n_folds={}",
                label, class_size, n_folds
            )));
        }
    }

    if shuffle {
        let mut rng = match random_seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => {
                let mut entropy = scirs2_core::random::rng();
                StdRng::from_rng(&mut entropy)
            }
        };
        for (_, members) in classes.iter_mut() {
            members.shuffle(&mut rng);
        }
    }

    // Deal each class round-robin across the folds.
    let mut folds: Vec<Vec<usize>> = vec![Vec::new(); n_folds];
    for (_, members) in &classes {
        for (pos, &idx) in members.iter().enumerate() {
            folds[pos % n_folds].push(idx);
        }
    }

    let splits = folds
        .iter()
        .enumerate()
        .map(|(held_out, test_fold)| {
            let mut train: Vec<usize> = folds
                .iter()
                .enumerate()
                .filter(|&(other, _)| other != held_out)
                .flat_map(|(_, fold)| fold.iter().copied())
                .collect();
            train.sort_unstable();
            (train, test_fold.clone())
        })
        .collect();
    Ok(splits)
}
/// Builds expanding-window `(train, test)` splits for ordered data.
///
/// The last test window ends at index `n`. Each split consumes
/// `gap + test_size` samples: the training range ends `gap` samples before
/// its test window, and consecutive test windows are also `gap` samples
/// apart. Training always starts at index 0 unless `max_train_size` is given,
/// in which case only the last `max_train_size` training samples are kept.
///
/// When `n_splits * (test_size + gap)` equals `n` exactly, the first split
/// has an empty training set.
///
/// # Arguments
///
/// * `n` - number of samples.
/// * `n_splits` - number of splits to produce.
/// * `test_size` - number of samples in every test window.
/// * `gap` - number of samples skipped between training data and test window.
/// * `max_train_size` - optional cap on the length of each training range.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when `n <= test_size`,
/// `test_size == 0`, `n_splits < 1`, or when `n_splits` windows of
/// `test_size + gap` samples do not fit into `n` samples (including the case
/// where that product overflows `usize`).
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub fn time_series_split(
    n: usize,
    n_splits: usize,
    test_size: usize,
    gap: usize,
    max_train_size: Option<usize>,
) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
    if n <= test_size {
        return Err(MetricsError::InvalidInput(format!(
            "Number of samples ({}) must be greater than test_size ({})",
            n, test_size
        )));
    }
    if test_size == 0 {
        return Err(MetricsError::InvalidInput(
            "test_size must be greater than 0".to_string(),
        ));
    }
    if n_splits < 1 {
        return Err(MetricsError::InvalidInput(
            "n_splits must be at least 1".to_string(),
        ));
    }

    // Every split, including the first, needs room for its gap and its test
    // window. Checked arithmetic turns overflow into a validation error
    // instead of a wrapped (and therefore wrong) size.
    let stride = test_size.checked_add(gap);
    let required = stride.and_then(|s| s.checked_mul(n_splits));
    let stride = match (stride, required) {
        (Some(stride), Some(required)) if required <= n => stride,
        _ => {
            return Err(MetricsError::InvalidInput(format!(
                "Cannot perform {} _splits with test_size={} and gap={} on {} samples",
                n_splits, test_size, gap, n
            )));
        }
    };

    // Allocate only after validation: `n_splits <= n` is now guaranteed.
    let mut splits = Vec::with_capacity(n_splits);
    for split in 0..n_splits {
        // Cannot underflow: (n_splits - split) * stride <= n_splits * stride <= n.
        let train_end = n - (n_splits - split) * stride;
        let test_start = train_end + gap;
        let train_start =
            max_train_size.map_or(0, |max_size| train_end.saturating_sub(max_size));
        let train_indices: Vec<usize> = (train_start..train_end).collect();
        let test_indices: Vec<usize> = (test_start..test_start + test_size).collect();
        splits.push((train_indices, test_indices));
    }
    Ok(splits)
}
/// Builds k-fold splits in which all samples of a group stay in the same fold.
///
/// Groups are collected in order of first appearance, ordered from largest to
/// smallest with a stable sort, and then assigned one by one to the fold that
/// currently holds the fewest samples (the lowest-numbered fold on a tie).
/// Because neither step depends on hash-map iteration order, the same input
/// always produces the same splits.
///
/// Test indices are returned in assignment order; training indices are sorted
/// ascending.
///
/// # Arguments
///
/// * `groups` - group label of every sample.
/// * `n_folds` - number of folds to produce.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when `groups` has at most one
/// element, `n_folds < 2`, or `n_folds` exceeds the number of distinct groups.
#[allow(dead_code)]
pub fn grouped_k_fold<T>(
    groups: &ArrayBase<
        impl scirs2_core::ndarray::Data<Elem = T>,
        impl scirs2_core::ndarray::Dimension,
    >,
    n_folds: usize,
) -> Result<Vec<(Vec<usize>, Vec<usize>)>>
where
    T: Clone + std::hash::Hash + Eq + std::fmt::Debug,
{
    let n_samples = groups.len();
    if n_samples <= 1 {
        return Err(MetricsError::InvalidInput(
            "Number of samples must be greater than 1".to_string(),
        ));
    }
    if n_folds < 2 {
        return Err(MetricsError::InvalidInput(
            "Number of _folds must be at least 2".to_string(),
        ));
    }

    // Distinct groups, remembered in the order they first occur.
    let mut seen: HashSet<&T> = HashSet::new();
    let mut group_order: Vec<&T> = Vec::new();
    for group in groups.iter() {
        if seen.insert(group) {
            group_order.push(group);
        }
    }
    let n_groups = group_order.len();
    if n_folds > n_groups {
        return Err(MetricsError::InvalidInput(format!(
            "Number of _folds ({}) cannot be greater than number of groups ({})",
            n_folds, n_groups
        )));
    }

    let mut group_indices: HashMap<&T, Vec<usize>> = HashMap::with_capacity(n_groups);
    for (i, group) in groups.iter().enumerate() {
        group_indices.entry(group).or_default().push(i);
    }

    // Pull the member lists out in first-appearance order. Every key in
    // `group_order` was inserted above, so the default is never used.
    let mut members_by_group: Vec<Vec<usize>> = group_order
        .into_iter()
        .map(|group| group_indices.remove(group).unwrap_or_default())
        .collect();
    // Stable sort, largest first: equal-sized groups keep first-appearance order.
    members_by_group.sort_by(|a, b| b.len().cmp(&a.len()));

    let mut folds: Vec<Vec<usize>> = vec![Vec::new(); n_folds];
    for members in members_by_group {
        // `min_by_key` yields the first minimum, i.e. the lowest fold on a tie.
        let target = folds
            .iter()
            .enumerate()
            .min_by_key(|&(_, fold)| fold.len())
            .map(|(idx, _)| idx)
            .expect("n_folds >= 2, so there is always at least one fold");
        folds[target].extend(members);
    }

    let splits = folds
        .iter()
        .enumerate()
        .map(|(held_out, test_fold)| {
            let mut train: Vec<usize> = folds
                .iter()
                .enumerate()
                .filter(|&(other, _)| other != held_out)
                .flat_map(|(_, fold)| fold.iter().copied())
                .collect();
            train.sort_unstable();
            (train, test_fold.clone())
        })
        .collect();
    Ok(splits)
}
/// Builds nested cross-validation splits: an outer k-fold, and for every
/// outer training set an inner k-fold over that training set only.
///
/// The outer splits come from [`k_fold_cross_validation`] on `0..n`. For each
/// outer fold, the inner k-fold is run over positions `0..outer_train.len()`
/// and its positions are translated back to the sample indices stored in
/// `outer_train`. When a seed is supplied, outer fold `i` uses
/// `random_seed + i` (wrapping) for its inner split.
///
/// # Arguments
///
/// * `n` - number of samples.
/// * `outer_n_folds` - number of outer folds.
/// * `inner_n_folds` - number of inner folds per outer training set.
/// * `shuffle` - forwarded to both the outer and the inner k-fold.
/// * `random_seed` - seed for the outer k-fold and base seed for the inner ones.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when `n <= outer_n_folds`,
/// `outer_n_folds < 2`, or `inner_n_folds < 2`, and propagates any error
/// returned by [`k_fold_cross_validation`].
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub fn nested_cross_validation(
    n: usize,
    outer_n_folds: usize,
    inner_n_folds: usize,
    shuffle: bool,
    random_seed: Option<u64>,
) -> Result<NestedCVResult> {
    if n <= outer_n_folds {
        return Err(MetricsError::InvalidInput(format!(
            "Number of samples ({}) must be greater than outer_n_folds ({})",
            n, outer_n_folds
        )));
    }
    if outer_n_folds < 2 {
        return Err(MetricsError::InvalidInput(
            "outer_n_folds must be at least 2".to_string(),
        ));
    }
    if inner_n_folds < 2 {
        return Err(MetricsError::InvalidInput(
            "inner_n_folds must be at least 2".to_string(),
        ));
    }

    let outer = k_fold_cross_validation(n, outer_n_folds, shuffle, random_seed)?;
    let mut nested: NestedCVResult = Vec::with_capacity(outer_n_folds);

    for (fold_no, (outer_train, outer_test)) in outer.into_iter().enumerate() {
        let inner_seed = random_seed.map(|seed| seed.wrapping_add(fold_no as u64));

        let inner_splits: Vec<(Vec<usize>, Vec<usize>)> = {
            // Inner folds index into `outer_train`; translate positions back
            // to the original sample indices.
            let to_sample_indices = |positions: Vec<usize>| -> Vec<usize> {
                positions.into_iter().map(|pos| outer_train[pos]).collect()
            };

            let inner =
                k_fold_cross_validation(outer_train.len(), inner_n_folds, shuffle, inner_seed)?;
            let mut remapped = Vec::with_capacity(inner.len());
            for (train_positions, val_positions) in inner {
                remapped.push((
                    to_sample_indices(train_positions),
                    to_sample_indices(val_positions),
                ));
            }
            remapped
        };

        nested.push((outer_train, outer_test, inner_splits));
    }
    Ok(nested)
}
// Unit tests for the split generators defined in this module.
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::array;
// 10 samples in 3 unshuffled folds: every split covers all 10 indices, test
// folds hold 3 or 4 samples, and the test folds together are exactly 0..10.
#[test]
fn test_k_fold_cross_validation() {
let splits = k_fold_cross_validation(10, 3, false, None).expect("Operation failed");
assert_eq!(splits.len(), 3);
for (train_indices, test_indices) in &splits {
assert_eq!(train_indices.len() + test_indices.len(), 10);
assert!(test_indices.len() >= 3); assert!(test_indices.len() <= 4); }
let mut all_test_indices = Vec::new();
for (_, test_indices) in &splits {
all_test_indices.extend_from_slice(test_indices);
}
all_test_indices.sort_unstable();
assert_eq!(all_test_indices, (0..10).collect::<Vec<_>>());
}
// 5 samples: 5 splits of 4 training samples and 1 test sample, with the test
// sample of split i being index i.
#[test]
fn test_leave_one_out_cv() {
let splits = leave_one_out_cv(5).expect("Operation failed");
assert_eq!(splits.len(), 5);
for (train_indices, test_indices) in &splits {
assert_eq!(train_indices.len(), 4);
assert_eq!(test_indices.len(), 1);
}
let test_indices: Vec<usize> = splits.iter().map(|(_, test)| test[0]).collect();
assert_eq!(test_indices, vec![0, 1, 2, 3, 4]);
}
// Classes of size 4, 3 and 6 over 3 folds: each test fold must contain
// 1-2 samples of class 0, exactly 1 of class 1 and exactly 2 of class 2, and
// the test folds together must cover every index once.
#[test]
fn test_stratified_k_fold() {
let y = array![0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2];
let splits = stratified_k_fold(&y, 3, false, None).expect("Operation failed");
assert_eq!(splits.len(), 3);
for (_, test_indices) in &splits {
// Count how many samples of each class landed in this test fold.
let mut class_counts = HashMap::new();
for &idx in test_indices {
let class = y[idx];
*class_counts.entry(class).or_insert(0) += 1;
}
assert!(class_counts.get(&0).map_or(0, |&c| c) >= 1);
assert!(class_counts.get(&0).map_or(0, |&c| c) <= 2);
assert!(class_counts.get(&1).map_or(0, |&c| c) >= 1);
assert!(class_counts.get(&1).map_or(0, |&c| c) <= 1);
assert!(class_counts.get(&2).map_or(0, |&c| c) >= 2);
assert!(class_counts.get(&2).map_or(0, |&c| c) <= 2);
}
let mut all_test_indices = Vec::new();
for (_, test_indices) in &splits {
all_test_indices.extend_from_slice(test_indices);
}
all_test_indices.sort_unstable();
assert_eq!(all_test_indices, (0..13).collect::<Vec<_>>());
}
// Checks the exact index layout for three configurations: no gap, a gap of
// one sample, and a training window capped at three samples.
#[test]
fn test_time_series_split() {
// No gap, unlimited training window.
let splits = time_series_split(10, 3, 2, 0, None).expect("Operation failed");
assert_eq!(splits.len(), 3);
let (train_indices, test_indices) = &splits[0];
assert_eq!(train_indices, &[0, 1, 2, 3]);
assert_eq!(test_indices, &[4, 5]);
let (train_indices, test_indices) = &splits[1];
assert_eq!(train_indices, &[0, 1, 2, 3, 4, 5]);
assert_eq!(test_indices, &[6, 7]);
let (train_indices, test_indices) = &splits[2];
assert_eq!(train_indices, &[0, 1, 2, 3, 4, 5, 6, 7]);
assert_eq!(test_indices, &[8, 9]);
// Gap of 1: index 3 is skipped between training data and test window.
let splits = time_series_split(12, 3, 2, 1, None).expect("Operation failed");
let (train_indices, test_indices) = &splits[0];
assert_eq!(train_indices, &[0, 1, 2]);
assert_eq!(test_indices, &[4, 5]);
// max_train_size = 3: only the last three training samples are kept.
let splits = time_series_split(10, 3, 2, 0, Some(3)).expect("Operation failed");
let (train_indices, test_indices) = &splits[0];
assert_eq!(train_indices, &[1, 2, 3]); assert_eq!(test_indices, &[4, 5]);
let (train_indices, test_indices) = &splits[1];
assert_eq!(train_indices, &[3, 4, 5]); assert_eq!(test_indices, &[6, 7]);
}
// Three groups over 3 folds: no group may appear in both the training and the
// test side of a split, and the test folds together must cover every index.
#[test]
fn test_grouped_k_fold() {
let groups = array!["A", "A", "A", "B", "B", "C", "C", "C"];
let splits = grouped_k_fold(&groups, 3).expect("Operation failed");
assert_eq!(splits.len(), 3);
for (train_indices, test_indices) in &splits {
let mut train_groups = HashSet::new();
let mut test_groups = HashSet::new();
for &idx in train_indices {
train_groups.insert(groups[idx]);
}
for &idx in test_indices {
test_groups.insert(groups[idx]);
}
for group in &test_groups {
assert!(!train_groups.contains(group));
}
}
let mut all_test_indices = Vec::new();
for (_, test_indices) in &splits {
all_test_indices.extend_from_slice(test_indices);
}
all_test_indices.sort_unstable();
assert_eq!(all_test_indices, (0..8).collect::<Vec<_>>());
}
// 20 samples, 5 outer folds, 3 inner folds, seeded shuffle: outer train/test
// are disjoint, inner indices come only from the outer training set, inner
// train/validation are disjoint and together span the outer training set, and
// the outer test folds cover every index.
#[test]
fn test_nested_cross_validation() {
let nested_cv =
nested_cross_validation(20, 5, 3, true, Some(42)).expect("Operation failed");
assert_eq!(nested_cv.len(), 5);
for (outer_train, outer_test, inner_splits) in &nested_cv {
for &test_idx in outer_test {
assert!(!outer_train.contains(&test_idx));
}
assert_eq!(inner_splits.len(), 3);
for (inner_train, inner_val) in inner_splits {
for &train_idx in inner_train {
assert!(outer_train.contains(&train_idx));
}
for &val_idx in inner_val {
assert!(outer_train.contains(&val_idx));
}
for &val_idx in inner_val {
assert!(!inner_train.contains(&val_idx));
}
assert_eq!(inner_train.len() + inner_val.len(), outer_train.len());
}
}
let mut all_test_indices = Vec::new();
for (_, outer_test_, _) in &nested_cv {
all_test_indices.extend_from_slice(outer_test_);
}
all_test_indices.sort_unstable();
assert_eq!(all_test_indices, (0..20).collect::<Vec<_>>());
}
}