use crate::callbacks::core::Callback;
use crate::{TrainResult, TrainingState};
/// Callback that requests an end to training once validation loss stops improving.
///
/// An epoch counts as an improvement when its validation loss is lower than the
/// best recorded loss by more than `min_delta`. After `patience` consecutive
/// epochs without improvement the stop flag is raised.
#[derive(Debug, Clone)]
pub struct EarlyStoppingCallback {
    /// Number of consecutive non-improving epochs that raises the stop flag.
    pub patience: usize,
    /// Margin by which a loss must undercut the best recorded loss to count as an improvement.
    pub min_delta: f64,
    /// Best validation loss recorded so far; `None` until the first one is recorded.
    best_val_loss: Option<f64>,
    /// Consecutive epochs seen without an improvement.
    wait: usize,
    /// Raised once `wait` reaches `patience`.
    stop_training: bool,
}

impl EarlyStoppingCallback {
    /// Creates a callback with the given `patience` and `min_delta`.
    ///
    /// The internal state starts empty: no best loss recorded, a wait counter
    /// of zero, and the stop flag cleared.
    pub fn new(patience: usize, min_delta: f64) -> Self {
        Self {
            patience,
            min_delta,
            best_val_loss: None,
            wait: 0,
            stop_training: false,
        }
    }
}
impl Callback for EarlyStoppingCallback {
    /// Updates the early-stopping state from this epoch's validation loss.
    ///
    /// Does nothing when `state.val_loss` is `None`. Otherwise the loss is an
    /// improvement if no best loss has been recorded yet, or if it is lower
    /// than the best loss by more than `min_delta`. An improvement becomes the
    /// new best loss and resets the wait counter; anything else increments the
    /// counter, and once it reaches `patience` a message is printed and the
    /// stop flag is raised.
    ///
    /// Always returns `Ok(())`.
    fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
        if let Some(val_loss) = state.val_loss {
            let improved = if val_loss.is_nan() {
                // A NaN loss is never an improvement. Recording it as the best
                // loss would make every later `val_loss < best - min_delta`
                // comparison false, so no epoch could ever improve again.
                false
            } else {
                match self.best_val_loss {
                    None => true,
                    Some(best) => val_loss < best - self.min_delta,
                }
            };

            if improved {
                self.best_val_loss = Some(val_loss);
                self.wait = 0;
            } else {
                self.wait += 1;
                if self.wait >= self.patience {
                    println!(
                        "Early stopping at epoch {} (no improvement for {} epochs)",
                        epoch, self.patience
                    );
                    self.stop_training = true;
                }
            }
        }
        Ok(())
    }

    /// Returns `true` once the stop flag has been raised by `on_epoch_end`.
    fn should_stop(&self) -> bool {
        self.stop_training
    }
}
/// Callback that watches validation loss for plateaus and computes a reduced
/// learning rate when one is detected.
// NOTE(review): `dead_code` is presumably allowed because nothing in the crate
// constructs this callback yet — confirm, and drop the attribute once it is used.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct ReduceLrOnPlateauCallback {
    /// Multiplier applied to the current learning rate when a plateau is detected.
    pub factor: f64,
    /// Number of consecutive non-improving epochs that counts as a plateau.
    pub patience: usize,
    /// Margin by which a loss must undercut the best recorded loss to count as an improvement.
    pub min_delta: f64,
    /// Lower bound for the computed learning rate.
    pub min_lr: f64,
    /// Best validation loss recorded so far; `None` until the first one is recorded.
    best_val_loss: Option<f64>,
    /// Consecutive epochs seen without an improvement.
    wait: usize,
}

impl ReduceLrOnPlateauCallback {
    /// Creates a callback with the given settings.
    ///
    /// The internal state starts empty: no best loss recorded and a wait
    /// counter of zero.
    #[allow(dead_code)]
    pub fn new(factor: f64, patience: usize, min_delta: f64, min_lr: f64) -> Self {
        Self {
            factor,
            patience,
            min_delta,
            min_lr,
            best_val_loss: None,
            wait: 0,
        }
    }
}
impl Callback for ReduceLrOnPlateauCallback {
    /// Updates the plateau state from this epoch's validation loss.
    ///
    /// Does nothing when `state.val_loss` is `None`. Otherwise the loss is an
    /// improvement if no best loss has been recorded yet, or if it is lower
    /// than the best loss by more than `min_delta`. An improvement becomes the
    /// new best loss and resets the wait counter; anything else increments it.
    ///
    /// Once the counter reaches `patience`, the candidate rate
    /// `state.learning_rate * factor` is computed, bounded below by `min_lr`,
    /// and printed if it differs from the current rate. The counter is then
    /// reset.
    ///
    /// Always returns `Ok(())`.
    fn on_epoch_end(&mut self, _epoch: usize, state: &TrainingState) -> TrainResult<()> {
        if let Some(val_loss) = state.val_loss {
            let improved = if val_loss.is_nan() {
                // A NaN loss is never an improvement. Recording it as the best
                // loss would make every later `val_loss < best - min_delta`
                // comparison false, so no epoch could ever improve again.
                false
            } else {
                match self.best_val_loss {
                    None => true,
                    Some(best) => val_loss < best - self.min_delta,
                }
            };

            if improved {
                self.best_val_loss = Some(val_loss);
                self.wait = 0;
            } else {
                self.wait += 1;
                if self.wait >= self.patience {
                    let new_lr = (state.learning_rate * self.factor).max(self.min_lr);
                    // NOTE(review): the new rate is only reported here. `state`
                    // is a shared reference, so this method cannot write it
                    // back — confirm where the reduction is actually applied.
                    if new_lr != state.learning_rate {
                        println!("Reducing learning rate to {:.6}", new_lr);
                    }
                    self.wait = 0;
                }
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a baseline `TrainingState`; tests override individual fields as needed.
    fn create_test_state() -> TrainingState {
        TrainingState {
            epoch: 0,
            batch: 0,
            train_loss: 1.0,
            val_loss: Some(0.8),
            batch_loss: 0.5,
            learning_rate: 0.001,
            metrics: HashMap::new(),
        }
    }

    /// With a patience of 2 and a `min_delta` of 0.01, two improving epochs
    /// followed by two non-improving ones must raise the stop flag on the
    /// second non-improving epoch only.
    #[test]
    fn test_early_stopping() {
        let mut callback = EarlyStoppingCallback::new(2, 0.01);
        let mut state = create_test_state();

        // (validation loss, expected stop flag) for epochs 0..=3, in order.
        let schedule = [(1.0, false), (0.8, false), (0.81, false), (0.82, true)];

        for (epoch, &(loss, expect_stop)) in schedule.iter().enumerate() {
            state.val_loss = Some(loss);
            callback.on_epoch_end(epoch, &state).expect("unwrap");
            assert_eq!(callback.should_stop(), expect_stop);
        }
    }
}