use super::{EvaluationConfig, Evaluator, MetricType, ModelBuilder};
use crate::data::Dataset;
use crate::error::{NeuralError, Result};
use crate::layers::Layer;
use scirs2_core::random::SeedableRng;
use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::{Float, FromPrimitive};
use std::collections::HashMap;
use std::fmt::Debug;
/// Strategy for partitioning a dataset into cross-validation folds.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CrossValidationStrategy {
    /// Split into `k` disjoint folds of (near-)equal size.
    KFold(usize),
    /// k-fold that distributes each class's samples round-robin across the `k` folds.
    StratifiedKFold(usize),
    /// One fold per sample: each sample serves as the validation set exactly once.
    LeaveOneOut,
    /// Non-overlapping folds that each hold out `p` consecutive samples.
    LeavePOut(usize),
    /// `n` independent random splits, holding out the given fraction (0..1) for validation.
    ShuffleSplit(usize, f64),
}
#[derive(Debug, Clone)]
pub struct CrossValidationConfig {
pub strategy: CrossValidationStrategy,
pub shuffle: bool,
pub random_seed: Option<u64>,
pub batch_size: usize,
pub num_workers: usize,
pub metrics: Vec<MetricType>,
pub verbose: usize,
impl Default for CrossValidationConfig {
fn default() -> Self {
Self {
strategy: CrossValidationStrategy::KFold(5),
shuffle: true,
random_seed: None,
batch_size: 32,
num_workers: 0,
metrics: vec![MetricType::Loss, MetricType::Accuracy],
verbose: 1,
}
}
/// A single cross-validation fold, expressed as index lists into the dataset.
///
/// `Clone` is derived (in addition to the original `Debug`) so folds can be
/// duplicated cheaply; both fields are plain `Vec<usize>`.
#[derive(Debug, Clone)]
pub struct CrossValidationFold {
    /// Indices of samples used for training in this fold.
    pub train_indices: Vec<usize>,
    /// Indices of samples held out for validation in this fold.
    pub val_indices: Vec<usize>,
}
pub struct CrossValidator<
F: Float + Debug + ScalarOperand + FromPrimitive + std::fmt::Display + Send + Sync,
> {
pub config: CrossValidationConfig,
evaluator: Evaluator<F>,
impl<F: Float + Debug + ScalarOperand + FromPrimitive + std::fmt::Display + Send + Sync>
CrossValidator<F>
{
pub fn new(config: CrossValidationConfig) -> Result<Self> {
let eval_config = EvaluationConfig {
batch_size: config.batch_size,
shuffle: false, num_workers: config.num_workers,
metrics: config.metrics.clone(),
steps: None,
verbose: config.verbose,
};
let evaluator = Evaluator::new(eval_config)?;
Ok(Self { config, evaluator })
pub fn create_folds(&self, dataset: &dyn Dataset<F>) -> Result<Vec<CrossValidationFold>> {
let n_samples = dataset.len();
match self.config.strategy {
CrossValidationStrategy::KFold(k) => {
if k < 2 {
return Err(NeuralError::ValidationError(
"k must be at least 2".to_string(),
));
}
if n_samples < k {
return Err(NeuralError::ValidationError(format!(
"Dataset size ({}) must be at least equal to k ({})",
n_samples, k
)));
let mut indices: Vec<usize> = (0..n_samples).collect();
if self.config.shuffle {
use scirs2_core::random::seq::SliceRandom;
if let Some(seed) = self.config.random_seed {
let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(seed);
indices.shuffle(&mut rng);
} else {
let mut rng = rng();
}
let fold_size = n_samples / k;
let remainder = n_samples % k;
let mut folds = Vec::with_capacity(k);
let mut start = 0;
for i in 0..k {
let fold_size_adjusted = if i < remainder {
fold_size + 1
fold_size
};
let end = start + fold_size_adjusted;
let val_indices = indices[start..end].to_vec();
let mut train_indices = Vec::with_capacity(n_samples - val_indices.len());
for &idx in &indices[0..start] {
train_indices.push(idx);
for &idx in &indices[end..] {
folds.push(CrossValidationFold {
train_indices,
val_indices,
});
start = end;
Ok(folds)
}
CrossValidationStrategy::StratifiedKFold(k) => {
let mut class_indices = HashMap::new();
for i in 0..n_samples {
let (_, target) = dataset.get(i)?;
let class_idx = if target.ndim() > 1 && target.shape()[1] > 1 {
let mut max_idx = 0;
let mut max_val = target[[0, 0]];
for j in 1..target.shape()[1] {
if target[[0, j]] > max_val {
max_idx = j;
max_val = target[[0, j]];
}
}
max_idx
target[[0]].to_usize().unwrap_or(0)
class_indices
.entry(class_idx)
.or_insert_with(Vec::new)
.push(i);
for _ in 0..k {
train_indices: Vec::new(),
val_indices: Vec::new(),
for (_, mut indices) in class_indices {
if self.config.shuffle {
use scirs2_core::random::seq::SliceRandom;
if let Some(seed) = self.config.random_seed {
let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(seed);
indices.shuffle(&mut rng);
} else {
let mut rng = rng();
for (i, &idx) in indices.iter().enumerate() {
let fold_idx = i % k;
folds[fold_idx].val_indices.push(idx);
for fold in folds.iter_mut().take(k) {
let val_indices = &fold.val_indices;
for i in 0..n_samples {
if !val_indices.contains(&i) {
train_indices.push(i);
fold.train_indices = train_indices;
CrossValidationStrategy::LeaveOneOut => {
let mut folds = Vec::with_capacity(n_samples);
let val_indices = vec![i];
let mut train_indices = Vec::with_capacity(n_samples - 1);
for j in 0..n_samples {
if j != i {
train_indices.push(j);
CrossValidationStrategy::LeavePOut(p) => {
if p >= n_samples {
"p ({}) must be less than dataset size ({})",
p, n_samples
let n_folds = n_samples / p;
let mut folds = Vec::with_capacity(n_folds);
for i in 0..n_folds {
let start = i * p;
let end = (i + 1) * p;
let mut train_indices = Vec::with_capacity(n_samples - p);
for (j, &idx) in indices.iter().enumerate().take(n_samples) {
if j < start || j >= end {
train_indices.push(idx);
CrossValidationStrategy::ShuffleSplit(n_splits, test_size) => {
if test_size <= 0.0 || test_size >= 1.0 {
"test_size must be between 0 and 1".to_string(),
let indices: Vec<usize> = (0..n_samples).collect();
let test_count = (n_samples as f64 * test_size).ceil() as usize;
if test_count >= n_samples {
"test_size too large for dataset".to_string(),
let mut folds = Vec::with_capacity(n_splits);
let rng_with_seed = self
.config
.random_seed
.map(scirs2_core::random::rngs::StdRng::from_seed);
for _ in 0..n_splits {
let mut shuffled = indices.clone();
if let Some(mut rng) = rng_with_seed.clone() {
shuffled.shuffle(&mut rng);
let val_indices = shuffled[0..test_count].to_vec();
let train_indices = shuffled[test_count..].to_vec();
pub fn cross_validate<L: Layer<F> + Clone>(
&mut self,
model_builder: &dyn ModelBuilder<F, Model = L>,
dataset: &dyn Dataset<F>,
loss_fn: Option<&dyn crate::losses::Loss<F>>,
) -> Result<HashMap<String, Vec<F>>> {
let folds = self.create_folds(dataset)?;
let metrics = &self.config.metrics;
let mut results: HashMap<String, Vec<F>> = metrics
.iter()
.map(|m| {
let name = match m {
MetricType::Loss => "loss".to_string(),
MetricType::Accuracy => "accuracy".to_string(),
MetricType::Precision => "precision".to_string(),
MetricType::Recall => "recall".to_string(),
MetricType::F1Score => "f1_score".to_string(),
MetricType::MeanSquaredError => "mse".to_string(),
MetricType::MeanAbsoluteError => "mae".to_string(),
MetricType::RSquared => "r2".to_string(),
MetricType::AUC => "auc".to_string(),
MetricType::Custom(name) => name.clone(),
};
(name, Vec::with_capacity(folds.len()))
})
.collect();
for (fold_idx, fold) in folds.iter().enumerate() {
if self.config.verbose > 0 {
println!("Fold {}/{}", fold_idx + 1, folds.len());
struct DatasetSubset<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> {
data: Vec<(
scirs2_core::ndarray::Array<F, scirs2_core::ndarray::IxDyn>,
)>,
impl<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> DatasetSubset<F> {
fn new(dataset: &dyn Dataset<F>, indices: &[usize]) -> Result<Self> {
let mut data = Vec::with_capacity(indices.len());
for &idx in indices {
let (input, target) = dataset.get(idx)?;
data.push((input, target));
Ok(Self { data })
impl<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> Dataset<F>
for DatasetSubset<F>
{
fn len(&self) -> usize {
self.data.len()
fn get(
&self,
idx: usize,
) -> Result<(
)> {
if idx >= self.data.len() {
return Err(crate::error::NeuralError::InferenceError(format!(
"Index out of bounds: {} >= {}",
idx,
self.data.len()
)));
Ok((self.data[idx].0.clone(), self.data[idx].1.clone()))
fn box_clone(&self) -> Box<dyn Dataset<F> + Send + Sync> {
let cloned_data = self.data.clone();
Box::new(Self { data: cloned_data })
let _train_dataset = DatasetSubset::new(dataset, &fold.train_indices)?;
let val_dataset = DatasetSubset::new(dataset, &fold.val_indices)?;
let model = model_builder.build()?;
let fold_metrics = self.evaluator.evaluate(&model, &val_dataset, loss_fn)?;
for (name, value) in fold_metrics {
if let Some(values) = results.get_mut(&name) {
values.push(value);
for (name, values) in &results {
if !values.is_empty() {
let sum = values.iter().fold(F::zero(), |acc, &x| acc + x);
let mean = sum / F::from(values.len()).expect("Operation failed");
let variance_sum = values
.iter()
.fold(F::zero(), |acc, &x| acc + (x - mean) * (x - mean));
let std = (variance_sum / F::from(values.len()).expect("Operation failed")).sqrt();
if self.config.verbose > 0 {
println!("{}: {:.4} ± {:.4}", name, mean, std);
Ok(results)