use super::{EvaluationConfig, Evaluator, MetricType, ModelBuilder};
use crate::data::Dataset;
use crate::error::{NeuralError, Result};
use crate::layers::Layer;
use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::{Float, FromPrimitive, NumAssign};
use scirs2_core::random::SeedableRng;
use std::collections::HashMap;
use std::fmt::Debug;
/// Strategy used to partition a dataset into cross-validation folds.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CrossValidationStrategy {
    /// Standard k-fold: the data is split into `k` folds of near-equal size.
    KFold(usize),
    /// k-fold that preserves per-class sample proportions in every fold.
    StratifiedKFold(usize),
    /// One fold per sample; each sample is the validation set exactly once.
    LeaveOneOut,
    /// Hold out `p` samples per fold (see `create_folds`: implemented as
    /// sequential blocks of `p`, not all C(n, p) combinations).
    LeavePOut(usize),
    /// `n` independent random splits with the given validation fraction
    /// (second field, expected in the open interval (0, 1)).
    ShuffleSplit(usize, f64),
}
/// Configuration for a cross-validation run.
#[derive(Debug, Clone)]
pub struct CrossValidationConfig {
    /// How to split the dataset into folds.
    pub strategy: CrossValidationStrategy,
    /// Whether to shuffle sample indices before splitting (KFold and
    /// StratifiedKFold honor this; ShuffleSplit always shuffles).
    pub shuffle: bool,
    /// Seed for deterministic shuffling; `None` uses a fresh RNG.
    pub random_seed: Option<u64>,
    /// Batch size forwarded to the evaluator.
    pub batch_size: usize,
    /// Worker count forwarded to the evaluator.
    pub num_workers: usize,
    /// Metrics to compute on each fold's validation subset.
    pub metrics: Vec<MetricType>,
    /// Verbosity level; values > 0 print per-fold progress and a summary.
    pub verbose: usize,
}
impl Default for CrossValidationConfig {
fn default() -> Self {
Self {
strategy: CrossValidationStrategy::KFold(5),
shuffle: true,
random_seed: None,
batch_size: 32,
num_workers: 0,
metrics: vec![MetricType::Loss, MetricType::Accuracy],
verbose: 1,
}
}
}
/// A single train/validation split, expressed as indices into the source dataset.
#[derive(Debug)]
pub struct CrossValidationFold {
    /// Indices of the samples used for training in this fold.
    pub train_indices: Vec<usize>,
    /// Indices of the samples held out for validation in this fold.
    pub val_indices: Vec<usize>,
}
/// Orchestrates cross-validation: creates folds from a dataset and evaluates
/// a freshly built model on each fold's validation subset.
pub struct CrossValidator<
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + std::fmt::Display + Send + Sync,
> {
    /// Fold-creation and evaluation settings.
    pub config: CrossValidationConfig,
    // Evaluator built once in `new` from `config` (shuffle disabled there;
    // shuffling happens during fold creation instead).
    evaluator: Evaluator<F>,
}
struct DatasetSubset<
F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
> {
data: Vec<(
scirs2_core::ndarray::Array<F, scirs2_core::ndarray::IxDyn>,
scirs2_core::ndarray::Array<F, scirs2_core::ndarray::IxDyn>,
)>,
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static>
    DatasetSubset<F>
{
    /// Materialize the samples at `indices` from `dataset` into an owned subset.
    ///
    /// Propagates the source dataset's error from the first failing fetch.
    fn new(dataset: &dyn Dataset<F>, indices: &[usize]) -> Result<Self> {
        // `collect` into `Result<Vec<_>>` short-circuits on the first error,
        // exactly like an explicit loop with `?` would.
        let data = indices
            .iter()
            .map(|&idx| dataset.get(idx))
            .collect::<Result<Vec<_>>>()?;
        Ok(Self { data })
    }
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static>
Dataset<F> for DatasetSubset<F>
{
fn len(&self) -> usize {
self.data.len()
}
fn get(
&self,
idx: usize,
) -> Result<(
scirs2_core::ndarray::Array<F, scirs2_core::ndarray::IxDyn>,
scirs2_core::ndarray::Array<F, scirs2_core::ndarray::IxDyn>,
)> {
if idx >= self.data.len() {
return Err(crate::error::NeuralError::InferenceError(format!(
"Index out of bounds: {} >= {}",
idx,
self.data.len()
)));
}
Ok((self.data[idx].0.clone(), self.data[idx].1.clone()))
}
fn box_clone(&self) -> Box<dyn Dataset<F> + Send + Sync> {
let cloned_data = self.data.clone();
Box::new(Self { data: cloned_data })
}
}
impl<
F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + std::fmt::Display + Send + Sync,
> CrossValidator<F>
{
/// Build a cross-validator from `config`.
///
/// The internal evaluator is derived from the same configuration; its
/// shuffle flag is forced off because shuffling (when enabled) happens
/// once, during fold creation.
pub fn new(config: CrossValidationConfig) -> Result<Self> {
    let evaluator = Evaluator::new(EvaluationConfig {
        batch_size: config.batch_size,
        shuffle: false, // folds are shuffled up front, not per evaluation
        num_workers: config.num_workers,
        metrics: config.metrics.clone(),
        steps: None,
        verbose: config.verbose,
    })?;
    Ok(Self { config, evaluator })
}
/// Partition `dataset` into train/validation folds according to the
/// configured [`CrossValidationStrategy`].
///
/// Returns one [`CrossValidationFold`] per split.
///
/// # Errors
/// Returns a `ValidationError` for invalid strategy parameters
/// (`k < 2`, `p == 0`, `p >= n_samples`, `test_size` outside (0, 1))
/// or when the dataset is too small for the requested split.
pub fn create_folds(&self, dataset: &dyn Dataset<F>) -> Result<Vec<CrossValidationFold>> {
    let n_samples = dataset.len();
    match self.config.strategy {
        CrossValidationStrategy::KFold(k) => {
            if k < 2 {
                return Err(NeuralError::ValidationError(
                    "k must be at least 2".to_string(),
                ));
            }
            if n_samples < k {
                return Err(NeuralError::ValidationError(format!(
                    "Dataset size ({}) must be at least equal to k ({})",
                    n_samples, k
                )));
            }
            // Optionally shuffle sample indices once; every fold is then a
            // contiguous slice of this (shuffled) index vector.
            let mut indices: Vec<usize> = (0..n_samples).collect();
            if self.config.shuffle {
                use scirs2_core::random::seq::SliceRandom;
                if let Some(seed) = self.config.random_seed {
                    let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(seed);
                    indices.shuffle(&mut rng);
                } else {
                    let mut rng = scirs2_core::random::rng();
                    indices.shuffle(&mut rng);
                }
            }
            // Distribute the remainder: the first `n_samples % k` folds get
            // one extra sample, so fold sizes differ by at most 1.
            let fold_size = n_samples / k;
            let remainder = n_samples % k;
            let mut folds = Vec::with_capacity(k);
            let mut start = 0;
            for i in 0..k {
                let fold_size_adjusted = if i < remainder {
                    fold_size + 1
                } else {
                    fold_size
                };
                let end = start + fold_size_adjusted;
                let val_indices = indices[start..end].to_vec();
                // Training set = everything outside [start, end).
                let mut train_indices = Vec::with_capacity(n_samples - val_indices.len());
                train_indices.extend_from_slice(&indices[..start]);
                train_indices.extend_from_slice(&indices[end..]);
                folds.push(CrossValidationFold {
                    train_indices,
                    val_indices,
                });
                start = end;
            }
            Ok(folds)
        }
        CrossValidationStrategy::StratifiedKFold(k) => {
            // Guard against degenerate k: k == 0 would otherwise panic below
            // on `i % k`, and k == 1 would produce an empty training set.
            if k < 2 {
                return Err(NeuralError::ValidationError(
                    "k must be at least 2".to_string(),
                ));
            }
            // Group sample indices by class label so each fold can preserve
            // the overall class distribution.
            let mut class_indices: HashMap<usize, Vec<usize>> = HashMap::new();
            for i in 0..n_samples {
                let (_, target) = dataset.get(i)?;
                // Multi-column targets (assumed one-hot/probabilities):
                // argmax over the first row. Single-column targets: the
                // scalar value itself is taken as the class id (falling back
                // to class 0 when it is not representable as usize).
                let class_idx = if target.ndim() > 1 && target.shape()[1] > 1 {
                    let mut max_idx = 0;
                    let mut max_val = target[[0, 0]];
                    for j in 1..target.shape()[1] {
                        if target[[0, j]] > max_val {
                            max_idx = j;
                            max_val = target[[0, j]];
                        }
                    }
                    max_idx
                } else {
                    target[[0]].to_usize().unwrap_or(0)
                };
                class_indices.entry(class_idx).or_default().push(i);
            }
            let mut folds: Vec<CrossValidationFold> = (0..k)
                .map(|_| CrossValidationFold {
                    train_indices: Vec::new(),
                    val_indices: Vec::new(),
                })
                .collect();
            // Deal each class's samples round-robin across folds so class
            // proportions stay roughly equal in every fold.
            for (_, mut indices) in class_indices {
                if self.config.shuffle {
                    use scirs2_core::random::seq::SliceRandom;
                    if let Some(seed) = self.config.random_seed {
                        let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(seed);
                        indices.shuffle(&mut rng);
                    } else {
                        let mut rng = scirs2_core::random::rng();
                        indices.shuffle(&mut rng);
                    }
                }
                for (i, &idx) in indices.iter().enumerate() {
                    folds[i % k].val_indices.push(idx);
                }
            }
            // Training indices are the complement of each fold's validation set.
            for fold in &mut folds {
                let val_set: std::collections::HashSet<usize> =
                    fold.val_indices.iter().cloned().collect();
                fold.train_indices =
                    (0..n_samples).filter(|i| !val_set.contains(i)).collect();
            }
            Ok(folds)
        }
        CrossValidationStrategy::LeaveOneOut => {
            // n folds: sample i is the sole validation sample of fold i.
            let mut folds = Vec::with_capacity(n_samples);
            for i in 0..n_samples {
                let val_indices = vec![i];
                let mut train_indices = Vec::with_capacity(n_samples - 1);
                for j in 0..n_samples {
                    if j != i {
                        train_indices.push(j);
                    }
                }
                folds.push(CrossValidationFold {
                    train_indices,
                    val_indices,
                });
            }
            Ok(folds)
        }
        CrossValidationStrategy::LeavePOut(p) => {
            // p == 0 would otherwise panic below on `n_samples / p`.
            if p == 0 {
                return Err(NeuralError::ValidationError(
                    "p must be at least 1".to_string(),
                ));
            }
            if p >= n_samples {
                return Err(NeuralError::ValidationError(format!(
                    "p ({}) must be less than dataset size ({})",
                    p, n_samples
                )));
            }
            // NOTE: this is a sequential block scheme, not exhaustive
            // leave-p-out (which would enumerate all C(n, p) combinations):
            // fold i validates on samples [i*p, (i+1)*p). When p does not
            // divide n_samples, the trailing `n_samples % p` samples are
            // never used for validation.
            let indices: Vec<usize> = (0..n_samples).collect();
            let n_folds = n_samples / p;
            let mut folds = Vec::with_capacity(n_folds);
            for i in 0..n_folds {
                let start = i * p;
                let end = ((i + 1) * p).min(n_samples);
                let val_indices = indices[start..end].to_vec();
                let mut train_indices = Vec::with_capacity(n_samples - p);
                for (j, &idx) in indices.iter().enumerate() {
                    if j < start || j >= end {
                        train_indices.push(idx);
                    }
                }
                folds.push(CrossValidationFold {
                    train_indices,
                    val_indices,
                });
            }
            Ok(folds)
        }
        CrossValidationStrategy::ShuffleSplit(n_splits, test_size) => {
            if test_size <= 0.0 || test_size >= 1.0 {
                return Err(NeuralError::ValidationError(
                    "test_size must be between 0 and 1".to_string(),
                ));
            }
            let indices: Vec<usize> = (0..n_samples).collect();
            let test_count = (n_samples as f64 * test_size).ceil() as usize;
            if test_count >= n_samples {
                return Err(NeuralError::ValidationError(
                    "test_size too large for dataset".to_string(),
                ));
            }
            let mut folds = Vec::with_capacity(n_splits);
            for split in 0..n_splits {
                let mut shuffled = indices.clone();
                use scirs2_core::random::seq::SliceRandom;
                if let Some(seed) = self.config.random_seed {
                    // BUGFIX: derive a distinct deterministic seed per split.
                    // Previously the RNG was re-seeded with the same value on
                    // every iteration, making all n_splits splits identical
                    // whenever a fixed seed was configured.
                    let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(
                        seed.wrapping_add(split as u64),
                    );
                    shuffled.shuffle(&mut rng);
                } else {
                    let mut rng = scirs2_core::random::rng();
                    shuffled.shuffle(&mut rng);
                }
                // First `test_count` shuffled indices validate; the rest train.
                let val_indices = shuffled[0..test_count].to_vec();
                let train_indices = shuffled[test_count..].to_vec();
                folds.push(CrossValidationFold {
                    train_indices,
                    val_indices,
                });
            }
            Ok(folds)
        }
    }
}
/// Run the configured cross-validation: for each fold, build a fresh model
/// via `model_builder` and evaluate it on the fold's validation subset.
///
/// Returns a map from metric name (e.g. "loss", "accuracy") to the per-fold
/// metric values, in fold order.
///
/// NOTE(review): the training subset is constructed but never used — the
/// freshly built model is evaluated without being fitted on
/// `fold.train_indices`. Confirm whether training is intentionally out of
/// scope here (e.g. expected to happen inside `ModelBuilder::build`).
pub fn cross_validate<L: Layer<F> + Clone>(
    &mut self,
    model_builder: &dyn ModelBuilder<F, Model = L>,
    dataset: &dyn Dataset<F>,
    loss_fn: Option<&dyn crate::losses::Loss<F>>,
) -> Result<HashMap<String, Vec<F>>> {
    let folds = self.create_folds(dataset)?;
    let metrics = &self.config.metrics;
    // Pre-register one result bucket per configured metric, keyed by the
    // metric's canonical display name.
    let mut results: HashMap<String, Vec<F>> = metrics
        .iter()
        .map(|m| {
            let name = match m {
                MetricType::Loss => "loss".to_string(),
                MetricType::Accuracy => "accuracy".to_string(),
                MetricType::Precision => "precision".to_string(),
                MetricType::Recall => "recall".to_string(),
                MetricType::F1Score => "f1_score".to_string(),
                MetricType::MeanSquaredError => "mse".to_string(),
                MetricType::MeanAbsoluteError => "mae".to_string(),
                MetricType::RSquared => "r2".to_string(),
                MetricType::AUC => "auc".to_string(),
                MetricType::Custom(name) => name.clone(),
            };
            (name, Vec::with_capacity(folds.len()))
        })
        .collect();
    for (fold_idx, fold) in folds.iter().enumerate() {
        if self.config.verbose > 0 {
            println!("Fold {}/{}", fold_idx + 1, folds.len());
        }
        // See NOTE(review) above: `_train_dataset` is built and discarded.
        let _train_dataset = DatasetSubset::new(dataset, &fold.train_indices)?;
        let val_dataset = DatasetSubset::new(dataset, &fold.val_indices)?;
        let model = model_builder.build()?;
        let fold_metrics = self.evaluator.evaluate(&model, &val_dataset, loss_fn)?;
        // Metric names the evaluator emits that were not pre-registered
        // above are silently dropped — the names must match exactly.
        for (name, value) in fold_metrics {
            if let Some(values) = results.get_mut(&name) {
                values.push(value);
            }
        }
    }
    if self.config.verbose > 0 {
        // Summary line per metric: mean ± population standard deviation
        // (variance divided by n, not n - 1).
        for (name, values) in &results {
            if !values.is_empty() {
                let sum = values.iter().fold(F::zero(), |acc, &x| acc + x);
                let mean = sum / F::from(values.len()).expect("Operation failed");
                let variance_sum = values
                    .iter()
                    .fold(F::zero(), |acc, &x| acc + (x - mean) * (x - mean));
                let std =
                    (variance_sum / F::from(values.len()).expect("Operation failed")).sqrt();
                println!("{}: {:.4} ± {:.4}", name, mean, std);
            }
        }
    }
    Ok(results)
}
}