use super::{EvaluationConfig, Evaluator, MetricType};
use crate::data::Dataset;
use crate::error::{Error, Result};
use crate::layers::Layer;
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::collections::HashMap;
use std::fmt::{Debug, Display};
/// Configuration for running validation passes during training.
#[derive(Debug, Clone)]
pub struct ValidationConfig {
/// Number of samples per validation batch.
pub batch_size: usize,
/// Whether to shuffle the validation dataset (usually `false`).
pub shuffle: bool,
/// Number of worker threads used for data loading (0 = synchronous).
pub num_workers: usize,
/// Optional cap on the number of validation steps; `None` runs the full dataset.
pub steps: Option<usize>,
/// Metrics to compute during validation (e.g. loss, accuracy).
pub metrics: Vec<MetricType>,
/// Verbosity level; values > 0 enable progress/early-stopping log output.
pub verbose: usize,
/// Optional early-stopping behavior; `None` disables early stopping.
pub early_stopping: Option<EarlyStoppingConfig>,
}
impl Default for ValidationConfig {
fn default() -> Self {
Self {
batch_size: 32,
shuffle: false,
num_workers: 0,
steps: None,
metrics: vec![MetricType::Loss],
verbose: 1,
early_stopping: None,
}
}
pub struct EarlyStoppingConfig {
pub monitor: String,
pub min_delta: f64,
pub patience: usize,
pub restore_best_weights: bool,
pub mode: EarlyStoppingMode,
impl Default for EarlyStoppingConfig {
monitor: "val_loss".to_string(),
min_delta: 0.0001,
patience: 5,
restore_best_weights: true,
mode: EarlyStoppingMode::Min,
/// Direction of improvement for the monitored early-stopping metric.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EarlyStoppingMode {
    /// Lower values are better (e.g. loss).
    Min,
    /// Higher values are better (e.g. accuracy).
    Max,
}
#[derive(Debug)]
pub struct ValidationHandler<
F: Float + Debug + ScalarOperand + Display + FromPrimitive + Send + Sync,
> {
pub config: ValidationConfig,
evaluator: Evaluator<F>,
early_stopping: Option<EarlyStoppingState<F>>,
pub struct EarlyStoppingState<
config: EarlyStoppingConfig,
best_value: F,
wait: usize,
best_weights: Option<Vec<Array<F, IxDyn>>>,
stopped_epoch: Option<usize>,
impl<F: Float + Debug + ScalarOperand + Display + FromPrimitive + Send + Sync>
ValidationHandler<F>
{
pub fn new(config: ValidationConfig) -> Result<Self> {
let eval_config = EvaluationConfig {
batch_size: config.batch_size,
shuffle: config.shuffle,
num_workers: config.num_workers,
metrics: config.metrics.clone(),
steps: config.steps,
verbose: config.verbose,
};
let evaluator = Evaluator::new(eval_config)?;
let early_stopping = config
.early_stopping
.as_ref()
.map(|es_config| EarlyStoppingState {
config: es_config.clone(),
best_value: match es_config.mode {
EarlyStoppingMode::Min => F::infinity(),
EarlyStoppingMode::Max => F::neg_infinity(),
},
wait: 0,
best_weights: None,
stopped_epoch: None,
});
Ok(Self {
config,
evaluator,
early_stopping,
})
pub fn validate<L: Layer<F>>(
&mut self,
model: &mut L,
dataset: &dyn Dataset<F>,
loss_fn: Option<&dyn crate::losses::Loss<F>>,
epoch: usize,
) -> Result<(HashMap<String, F>, bool)> {
model.set_training(false);
let metrics = self.evaluator.evaluate(model, dataset, loss_fn)?;
let mut val_metrics = HashMap::new();
for (name, value) in metrics {
val_metrics.insert(format!("val_{}", name), value);
let should_stop = if let Some(ref mut es_state) = self.early_stopping {
let monitor_value = if let Some(value) = val_metrics.get(&es_state.config.monitor) {
*value
} else {
return Err(Error::ValidationError(format!(
"Early stopping monitor '{}' not found in validation metrics",
es_state.config.monitor
)));
};
let improved = match es_state.config.mode {
EarlyStoppingMode::Min => {
monitor_value + F::from(es_state.config.min_delta).expect("Failed to convert to float")
< es_state.best_value
}
EarlyStoppingMode::Max => {
monitor_value - F::from(es_state.config.min_delta).expect("Failed to convert to float")
> es_state.best_value
if improved {
if self.config.verbose > 0 {
println!(
"Epoch {}: {} improved from {:.4} to {:.4}",
epoch, es_state.config.monitor, es_state.best_value, monitor_value
);
es_state.best_value = monitor_value;
es_state.wait = 0;
if es_state.config.restore_best_weights {
es_state.best_weights = Some(model.params());
false
es_state.wait += 1;
"Epoch {}: {} did not improve from {:.4}",
epoch, es_state.config.monitor, es_state.best_value
if es_state.wait >= es_state.config.patience {
if self.config.verbose > 0 {
println!(
"Early stopping triggered: no improvement in {} for {} epochs",
es_state.config.monitor, es_state.config.patience
);
}
es_state.stopped_epoch = Some(epoch);
if es_state.config.restore_best_weights {
if let Some(ref best_weights) = es_state.best_weights {
let mut params = model.params();
for (i, best_param) in bestweights.iter().enumerate() {
if i < params.len() {
params[i].assign(best_param);
}
}
}
true
} else {
false
}
} else {
false
model.set_training(true);
Ok((val_metrics, should_stop))
pub fn has_early_stopping(&self) -> bool {
self.early_stopping.is_some()
pub fn get_early_stopping_state(&self) -> Option<&EarlyStoppingState<F>> {
self.early_stopping.as_ref()
pub fn reset_early_stopping(&mut self) {
if let Some(ref mut es_state) = self.early_stopping {
es_state.best_value = match es_state.config.mode {
EarlyStoppingMode::Min => F::infinity(),
EarlyStoppingMode::Max => F::neg_infinity(),
es_state.wait = 0;
es_state.best_weights = None;
es_state.stopped_epoch = None;