use crate::{TrainError, TrainResult};
use scirs2_core::random::{SeedableRng, StdRng};
use std::collections::HashMap;
/// Common interface for cross-validation splitting strategies.
pub trait CrossValidationSplit {
    /// Number of folds this strategy produces.
    fn num_splits(&self) -> usize;
    /// Returns `(train_indices, validation_indices)` for the given `fold`
    /// over `n_samples` samples.
    ///
    /// # Errors
    /// Implementations return `TrainError::InvalidParameter` for out-of-range
    /// folds or insufficient samples.
    fn get_split(&self, fold: usize, n_samples: usize) -> TrainResult<(Vec<usize>, Vec<usize>)>;
}
/// K-fold cross-validation: partitions samples into `n_splits` folds,
/// each serving once as the validation set.
#[derive(Debug, Clone)]
pub struct KFold {
    /// Number of folds; at least 2 (enforced by [`KFold::new`]).
    pub n_splits: usize,
    /// Whether to shuffle sample order before splitting (default: `false`).
    pub shuffle: bool,
    /// Seed for the deterministic shuffle when `shuffle` is true (default: 42).
    pub random_seed: u64,
}
impl KFold {
    /// Creates a `KFold` splitter with `n_splits` folds, no shuffling,
    /// and a default seed of 42.
    ///
    /// # Errors
    /// Returns `TrainError::InvalidParameter` when `n_splits < 2`.
    pub fn new(n_splits: usize) -> TrainResult<Self> {
        match n_splits {
            0 | 1 => Err(TrainError::InvalidParameter(
                "n_splits must be at least 2".to_string(),
            )),
            _ => Ok(Self {
                n_splits,
                shuffle: false,
                random_seed: 42,
            }),
        }
    }

    /// Enables shuffling with the given RNG seed (builder-style).
    pub fn with_shuffle(mut self, seed: u64) -> Self {
        self.shuffle = true;
        self.random_seed = seed;
        self
    }
}
impl CrossValidationSplit for KFold {
    fn num_splits(&self) -> usize {
        self.n_splits
    }

    /// Returns `(train, validation)` index vectors for `fold`.
    ///
    /// Samples are split into `n_splits` contiguous folds; the first
    /// `n_samples % n_splits` folds each receive one extra sample. When
    /// `shuffle` is set, a seeded Fisher-Yates pass reorders the indices first.
    ///
    /// # Errors
    /// Returns `TrainError::InvalidParameter` when `fold >= n_splits`.
    fn get_split(&self, fold: usize, n_samples: usize) -> TrainResult<(Vec<usize>, Vec<usize>)> {
        if fold >= self.n_splits {
            return Err(TrainError::InvalidParameter(format!(
                "fold {} is out of range [0, {})",
                fold, self.n_splits
            )));
        }
        let mut order: Vec<usize> = (0..n_samples).collect();
        if self.shuffle {
            // Seeded Fisher-Yates shuffle for reproducible splits.
            let mut rng = StdRng::seed_from_u64(self.random_seed);
            for i in (1..n_samples).rev() {
                let j = rng.gen_range(0..=i);
                order.swap(i, j);
            }
        }
        // Fold boundaries in closed form: the k-th boundary is
        // k * base + min(k, extra), since the first `extra` folds get +1.
        let base = n_samples / self.n_splits;
        let extra = n_samples % self.n_splits;
        let val_start = fold * base + fold.min(extra);
        let val_end = val_start + base + usize::from(fold < extra);
        let val: Vec<usize> = order[val_start..val_end].to_vec();
        let mut train = Vec::with_capacity(n_samples - val.len());
        train.extend_from_slice(&order[..val_start]);
        train.extend_from_slice(&order[val_end..]);
        Ok((train, val))
    }
}
/// Stratified k-fold cross-validation: each fold preserves the per-class
/// label proportions of the full dataset (approximately).
#[derive(Debug, Clone)]
pub struct StratifiedKFold {
    /// Number of folds; at least 2 (enforced by [`StratifiedKFold::new`]).
    pub n_splits: usize,
    /// Whether to shuffle within each class before splitting (default: `true`).
    pub shuffle: bool,
    /// Seed for the deterministic per-class shuffle (default: 42).
    pub random_seed: u64,
}
impl StratifiedKFold {
    /// Creates a stratified k-fold splitter with shuffling enabled and a
    /// default seed of 42.
    ///
    /// # Errors
    /// Returns `TrainError::InvalidParameter` when `n_splits < 2`.
    pub fn new(n_splits: usize) -> TrainResult<Self> {
        if n_splits < 2 {
            return Err(TrainError::InvalidParameter(
                "n_splits must be at least 2".to_string(),
            ));
        }
        Ok(Self {
            n_splits,
            shuffle: true,
            random_seed: 42,
        })
    }

    /// Sets the RNG seed used for the per-class shuffle (builder-style).
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.random_seed = seed;
        self
    }

    /// Splits sample indices into `(train, validation)` for `fold`, keeping
    /// each class of `labels` proportionally represented in every fold.
    ///
    /// # Errors
    /// Returns `TrainError::InvalidParameter` when `fold >= n_splits`.
    pub fn get_stratified_split(
        &self,
        fold: usize,
        labels: &[usize],
    ) -> TrainResult<(Vec<usize>, Vec<usize>)> {
        if fold >= self.n_splits {
            return Err(TrainError::InvalidParameter(format!(
                "fold {} is out of range [0, {})",
                fold, self.n_splits
            )));
        }
        // Group sample indices by class label.
        let mut class_indices: HashMap<usize, Vec<usize>> = HashMap::new();
        for (i, &label) in labels.iter().enumerate() {
            class_indices.entry(label).or_default().push(i);
        }
        // BUG FIX: HashMap iteration order is unspecified, so both the output
        // index order and the order in which the seeded RNG stream was consumed
        // across classes varied run-to-run -- identical seeds did NOT give
        // identical splits. Sort classes by label to make splits reproducible.
        let mut classes: Vec<(usize, Vec<usize>)> = class_indices.into_iter().collect();
        classes.sort_unstable_by_key(|&(label, _)| label);
        if self.shuffle {
            let mut rng = StdRng::seed_from_u64(self.random_seed);
            for (_, indices) in classes.iter_mut() {
                // Seeded Fisher-Yates shuffle within each class.
                for i in (1..indices.len()).rev() {
                    let j = rng.gen_range(0..=i);
                    indices.swap(i, j);
                }
            }
        }
        let mut train_indices = Vec::new();
        let mut val_indices = Vec::new();
        for (_, indices) in &classes {
            let class_size = indices.len();
            // Fold `fold` takes a contiguous slice of this class; the first
            // `class_size % n_splits` folds each take one extra sample.
            let base = class_size / self.n_splits;
            let extra = class_size % self.n_splits;
            let val_start = fold * base + fold.min(extra);
            let val_end = val_start + base + usize::from(fold < extra);
            val_indices.extend_from_slice(&indices[val_start..val_end]);
            train_indices.extend_from_slice(&indices[..val_start]);
            train_indices.extend_from_slice(&indices[val_end..]);
        }
        Ok((train_indices, val_indices))
    }
}
impl CrossValidationSplit for StratifiedKFold {
    fn num_splits(&self) -> usize {
        self.n_splits
    }

    /// Without real labels available through this trait, synthesizes cyclic
    /// pseudo-labels (`i % n_splits`) and delegates to
    /// [`StratifiedKFold::get_stratified_split`].
    fn get_split(&self, fold: usize, n_samples: usize) -> TrainResult<(Vec<usize>, Vec<usize>)> {
        let mut pseudo_labels = Vec::with_capacity(n_samples);
        for i in 0..n_samples {
            pseudo_labels.push(i % self.n_splits);
        }
        self.get_stratified_split(fold, &pseudo_labels)
    }
}
/// Time-series cross-validation: training indices always precede validation
/// indices, avoiding temporal leakage.
#[derive(Debug, Clone)]
pub struct TimeSeriesSplit {
    /// Number of folds; at least 2 (enforced by [`TimeSeriesSplit::new`]).
    pub n_splits: usize,
    /// Minimum required training-set size per fold, if set.
    pub min_train_size: Option<usize>,
    /// Maximum training-window size (rolling window), if set.
    pub max_train_size: Option<usize>,
}
impl TimeSeriesSplit {
    /// Creates a forward-chaining time-series splitter with `n_splits` folds
    /// and no training-window constraints.
    ///
    /// # Errors
    /// Returns `TrainError::InvalidParameter` when `n_splits < 2`.
    pub fn new(n_splits: usize) -> TrainResult<Self> {
        if n_splits >= 2 {
            Ok(Self {
                n_splits,
                min_train_size: None,
                max_train_size: None,
            })
        } else {
            Err(TrainError::InvalidParameter(
                "n_splits must be at least 2".to_string(),
            ))
        }
    }

    /// Requires at least `size` training samples per fold (builder-style).
    pub fn with_min_train_size(mut self, size: usize) -> Self {
        self.min_train_size = Some(size);
        self
    }

    /// Caps the training window at the most recent `size` samples (builder-style).
    pub fn with_max_train_size(mut self, size: usize) -> Self {
        self.max_train_size = Some(size);
        self
    }
}
impl CrossValidationSplit for TimeSeriesSplit {
    fn num_splits(&self) -> usize {
        self.n_splits
    }

    /// Forward-chaining split: fold `k` validates on the `(k+1)`-th segment
    /// of size `n_samples / (n_splits + 1)` and trains on the samples before
    /// it (optionally windowed by `max_train_size`), so training data always
    /// precedes validation data in time.
    ///
    /// # Errors
    /// Returns `TrainError::InvalidParameter` when `fold >= n_splits`, when
    /// there are too few samples to form a validation segment, when the
    /// effective training window is shorter than `min_train_size`, or when
    /// the training window is empty.
    fn get_split(&self, fold: usize, n_samples: usize) -> TrainResult<(Vec<usize>, Vec<usize>)> {
        if fold >= self.n_splits {
            return Err(TrainError::InvalidParameter(format!(
                "fold {} is out of range [0, {})",
                fold, self.n_splits
            )));
        }
        // n_splits + 1 segments: the first segment is always training-only.
        let test_size = n_samples / (self.n_splits + 1);
        if test_size == 0 {
            return Err(TrainError::InvalidParameter(
                "Not enough samples for time series split".to_string(),
            ));
        }
        let val_start = (fold + 1) * test_size;
        let val_end = ((fold + 2) * test_size).min(n_samples);
        let train_end = val_start;
        // Rolling window: keep only the most recent `max_train_size` samples.
        let train_start = self
            .max_train_size
            .map_or(0, |max_size| train_end.saturating_sub(max_size));
        // BUG FIX: previously `min_train_size` was silently ignored whenever
        // `max_train_size` was also set (the if/else-if chain short-circuited).
        // Validate the effective training window in all cases; when no max is
        // set this is the same `train_end < min_size` check as before.
        if let Some(min_size) = self.min_train_size {
            if train_end - train_start < min_size {
                return Err(TrainError::InvalidParameter(
                    "Not enough samples for min_train_size".to_string(),
                ));
            }
        }
        let train_indices: Vec<usize> = (train_start..train_end).collect();
        let val_indices: Vec<usize> = (val_start..val_end).collect();
        if train_indices.is_empty() {
            return Err(TrainError::InvalidParameter(
                "Training set is empty for this fold".to_string(),
            ));
        }
        Ok((train_indices, val_indices))
    }
}
/// Leave-one-out cross-validation: each fold validates on exactly one sample
/// and trains on all the others.
#[derive(Debug, Clone, Default)]
pub struct LeaveOneOut;
impl LeaveOneOut {
    /// Creates a leave-one-out splitter (equivalent to `Default::default()`).
    pub fn new() -> Self {
        Self
    }
}
impl CrossValidationSplit for LeaveOneOut {
    /// The true fold count equals the sample count, which is not known here,
    /// so this reports `usize::MAX` as an "unbounded" sentinel.
    fn num_splits(&self) -> usize {
        usize::MAX
    }

    /// Validates on sample `fold` alone; every other index goes to training.
    ///
    /// # Errors
    /// Returns `TrainError::InvalidParameter` when `fold >= n_samples`.
    fn get_split(&self, fold: usize, n_samples: usize) -> TrainResult<(Vec<usize>, Vec<usize>)> {
        if fold >= n_samples {
            return Err(TrainError::InvalidParameter(format!(
                "fold {} is out of range [0, {})",
                fold, n_samples
            )));
        }
        let train_indices: Vec<usize> = (0..n_samples).filter(|&i| i != fold).collect();
        Ok((train_indices, vec![fold]))
    }
}
/// Accumulates per-fold scores and named metrics from a cross-validation run.
#[derive(Debug, Clone)]
pub struct CrossValidationResults {
    /// Primary score recorded for each fold, in insertion order.
    pub fold_scores: Vec<f64>,
    /// Named auxiliary metrics recorded for each fold.
    pub fold_metrics: Vec<HashMap<String, f64>>,
}
impl CrossValidationResults {
    /// Creates an empty result set.
    pub fn new() -> Self {
        Self {
            fold_scores: Vec::new(),
            fold_metrics: Vec::new(),
        }
    }

    /// Records one fold's primary score together with its metric map.
    pub fn add_fold(&mut self, score: f64, metrics: HashMap<String, f64>) {
        self.fold_scores.push(score);
        self.fold_metrics.push(metrics);
    }

    /// Mean of the fold scores; `0.0` when no folds have been recorded.
    pub fn mean_score(&self) -> f64 {
        let n = self.fold_scores.len();
        if n == 0 {
            0.0
        } else {
            self.fold_scores.iter().sum::<f64>() / n as f64
        }
    }

    /// Sample standard deviation (n-1 denominator) of the fold scores;
    /// `0.0` with fewer than two folds.
    pub fn std_score(&self) -> f64 {
        let n = self.fold_scores.len();
        if n < 2 {
            return 0.0;
        }
        let mean = self.mean_score();
        let sum_sq: f64 = self
            .fold_scores
            .iter()
            .map(|&s| (s - mean) * (s - mean))
            .sum();
        (sum_sq / (n - 1) as f64).sqrt()
    }

    /// Mean of `metric_name` across the folds that recorded it;
    /// `None` when no fold has that metric.
    pub fn mean_metric(&self, metric_name: &str) -> Option<f64> {
        let values: Vec<f64> = self
            .fold_metrics
            .iter()
            .filter_map(|m| m.get(metric_name).copied())
            .collect();
        if values.is_empty() {
            None
        } else {
            Some(values.iter().sum::<f64>() / values.len() as f64)
        }
    }

    /// Number of folds recorded so far.
    pub fn num_folds(&self) -> usize {
        self.fold_scores.len()
    }
}
impl Default for CrossValidationResults {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // KFold: train/val are disjoint and together cover all sample indices.
    #[test]
    fn test_kfold_basic() {
        let kfold = KFold::new(3).expect("unwrap");
        assert_eq!(kfold.num_splits(), 3);
        let (train, val) = kfold.get_split(0, 10).expect("unwrap");
        assert!(!train.is_empty());
        assert!(!val.is_empty());
        for &idx in &val {
            assert!(!train.contains(&idx));
        }
        let mut all_indices = train.clone();
        all_indices.extend(&val);
        all_indices.sort();
        assert_eq!(all_indices, (0..10).collect::<Vec<_>>());
    }

    // Shuffled KFold is deterministic: same seed yields identical splits.
    #[test]
    fn test_kfold_with_shuffle() {
        let kfold = KFold::new(3).expect("unwrap").with_shuffle(42);
        let (train1, val1) = kfold.get_split(0, 10).expect("unwrap");
        let (train2, val2) = kfold.get_split(0, 10).expect("unwrap");
        assert_eq!(train1, train2);
        assert_eq!(val1, val2);
    }

    // KFold rejects n_splits < 2 and out-of-range fold indices.
    #[test]
    fn test_kfold_invalid() {
        assert!(KFold::new(1).is_err());
        let kfold = KFold::new(3).expect("unwrap");
        assert!(kfold.get_split(5, 10).is_err());
    }

    // StratifiedKFold: validation set is non-empty for balanced classes.
    #[test]
    fn test_stratified_kfold() {
        let skfold = StratifiedKFold::new(3).expect("unwrap");
        assert_eq!(skfold.num_splits(), 3);
        let labels = vec![0, 0, 0, 1, 1, 1, 2, 2, 2];
        let (_train, val) = skfold.get_stratified_split(0, &labels).expect("unwrap");
        let mut val_classes: Vec<usize> = val.iter().map(|&i| labels[i]).collect();
        val_classes.sort();
        val_classes.dedup();
        assert!(!val.is_empty());
    }

    // TimeSeriesSplit: all training indices precede all validation indices.
    #[test]
    fn test_time_series_split() {
        let ts_split = TimeSeriesSplit::new(3).expect("unwrap");
        assert_eq!(ts_split.num_splits(), 3);
        let (train, val) = ts_split.get_split(0, 10).expect("unwrap");
        if !train.is_empty() && !val.is_empty() {
            assert!(train.iter().max().expect("unwrap") < val.iter().min().expect("unwrap"));
        }
    }

    // max_train_size caps the rolling training window.
    #[test]
    fn test_time_series_split_with_window() {
        let ts_split = TimeSeriesSplit::new(3)
            .expect("unwrap")
            .with_min_train_size(2)
            .with_max_train_size(5);
        let (train, val) = ts_split.get_split(1, 20).expect("unwrap");
        assert!(train.len() <= 5);
        assert!(!val.is_empty());
    }

    // Too few samples (test_size == 0) and out-of-range folds are rejected.
    #[test]
    fn test_time_series_split_invalid() {
        let ts_split = TimeSeriesSplit::new(3).expect("unwrap");
        assert!(ts_split.get_split(0, 2).is_err());
        assert!(ts_split.get_split(5, 10).is_err());
    }

    // LeaveOneOut: validation holds exactly the fold index; train has the rest.
    #[test]
    fn test_leave_one_out() {
        let loo = LeaveOneOut::new();
        let (train, val) = loo.get_split(0, 5).expect("unwrap");
        assert_eq!(val.len(), 1);
        assert_eq!(train.len(), 4);
        assert_eq!(val[0], 0);
        let (train, val) = loo.get_split(3, 5).expect("unwrap");
        assert_eq!(val[0], 3);
        assert_eq!(train.len(), 4);
    }

    // LeaveOneOut rejects fold >= n_samples.
    #[test]
    fn test_leave_one_out_invalid() {
        let loo = LeaveOneOut::new();
        assert!(loo.get_split(5, 5).is_err());
    }

    // Aggregation: mean/std of scores and mean of a named metric.
    #[test]
    fn test_cv_results() {
        let mut results = CrossValidationResults::new();
        let mut metrics1 = HashMap::new();
        metrics1.insert("accuracy".to_string(), 0.9);
        results.add_fold(0.85, metrics1);
        let mut metrics2 = HashMap::new();
        metrics2.insert("accuracy".to_string(), 0.95);
        results.add_fold(0.90, metrics2);
        let mut metrics3 = HashMap::new();
        metrics3.insert("accuracy".to_string(), 0.92);
        results.add_fold(0.88, metrics3);
        assert_eq!(results.num_folds(), 3);
        let mean = results.mean_score();
        assert!((mean - 0.8766666).abs() < 1e-6);
        let std = results.std_score();
        assert!(std > 0.0);
        let mean_acc = results.mean_metric("accuracy").expect("unwrap");
        assert!((mean_acc - 0.923333).abs() < 1e-5);
    }

    // Empty results: defined fallbacks (0.0 scores, None metric).
    #[test]
    fn test_cv_results_empty() {
        let results = CrossValidationResults::new();
        assert_eq!(results.mean_score(), 0.0);
        assert_eq!(results.std_score(), 0.0);
        assert_eq!(results.num_folds(), 0);
        assert!(results.mean_metric("accuracy").is_none());
    }

    // Across all folds, validation sets partition the full index range.
    #[test]
    fn test_kfold_all_folds() {
        let kfold = KFold::new(5).expect("unwrap");
        let n_samples = 20;
        let mut all_val_indices = Vec::new();
        for fold in 0..5 {
            let (_, val) = kfold.get_split(fold, n_samples).expect("unwrap");
            all_val_indices.extend(val);
        }
        all_val_indices.sort();
        assert_eq!(all_val_indices.len(), n_samples);
        assert_eq!(all_val_indices, (0..n_samples).collect::<Vec<_>>());
    }
}