use scirs2_core::ndarray::{ArrayD, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive, ToPrimitive};
use std::fmt::Debug;
/// Direction in which the monitored metric is expected to move.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StoppingMode {
    /// Smaller metric values count as better (e.g. a loss).
    Min,
    /// Larger metric values count as better (e.g. an accuracy).
    Max,
}
/// Why an early-stopping monitor halted (or that it has not halted yet).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum StopReason {
    /// The metric failed to improve for as many checks as the patience allows.
    PatienceExhausted {
        /// Number of checks since the last improvement when the stop fired.
        checks_without_improvement: usize,
    },
    /// The metric was NaN or infinite.
    MetricDiverged,
    /// The metric crossed the configured divergence threshold.
    ThresholdExceeded {
        /// The configured threshold, rendered as text.
        threshold: String,
    },
    /// No stop condition has fired.
    NotStopped,
}
/// Outcome of feeding one metric value to an early-stopping monitor.
#[derive(Clone, Debug)]
pub struct StepResult {
    /// `true` when training should halt.
    pub should_stop: bool,
    /// `true` when this metric value became the new best.
    pub improved: bool,
    /// Checks since the last improvement, measured after this step.
    pub patience_counter: usize,
    /// The stop reason; `StopReason::NotStopped` while training may continue.
    pub reason: StopReason,
}
/// Metric monitor that decides when a training run should stop early.
///
/// The configuration fields are set at construction or through the `with_*`
/// builders; the remaining fields are running state updated by every check.
#[derive(Clone, Debug)]
pub struct EarlyStopping {
    /// Number of checks without improvement at which the monitor stops.
    patience: usize,
    /// Minimum change that counts as an improvement (stored as an absolute value).
    min_delta: f64,
    /// Whether lower or higher metric values are better.
    mode: StoppingMode,
    /// Number of initial checks during which patience cannot trigger a stop.
    warmup_checks: usize,
    /// Optional bound beyond which the metric is treated as diverged.
    divergence_threshold: Option<f64>,
    /// Best metric value accepted so far.
    best_value: f64,
    /// Zero-based index of the check that produced `best_value`.
    best_epoch: usize,
    /// Checks since the last improvement.
    counter: usize,
    /// Total number of metric values received.
    total_checks: usize,
    /// Set once any stop condition has fired.
    stopped: bool,
    /// Reason recorded when the monitor stopped.
    stop_reason: StopReason,
    /// Every metric value received, in arrival order.
    metric_history: Vec<f64>,
}
impl EarlyStopping {
    /// Creates a monitor with no warmup and no divergence threshold.
    ///
    /// * `patience` - number of consecutive non-improving checks tolerated
    ///   before the monitor stops.
    /// * `min_delta` - minimum change that counts as an improvement; the
    ///   absolute value is stored, so the sign is ignored.
    /// * `mode` - whether lower (`Min`) or higher (`Max`) values are better.
    pub fn new(patience: usize, min_delta: f64, mode: StoppingMode) -> Self {
        Self {
            patience,
            min_delta: min_delta.abs(),
            mode,
            warmup_checks: 0,
            divergence_threshold: None,
            best_value: Self::initial_best(mode),
            best_epoch: 0,
            counter: 0,
            total_checks: 0,
            stopped: false,
            stop_reason: StopReason::NotStopped,
            metric_history: Vec::new(),
        }
    }

    /// Sets the number of initial checks during which patience cannot stop
    /// the run. Non-finite metrics and the divergence threshold still stop
    /// the run during warmup.
    pub fn with_warmup(mut self, warmup_checks: usize) -> Self {
        self.warmup_checks = warmup_checks;
        self
    }

    /// Sets a divergence threshold.
    ///
    /// In `Min` mode the run stops when the metric exceeds `threshold`; in
    /// `Max` mode it stops when the metric falls below `-threshold`.
    pub fn with_divergence_threshold(mut self, threshold: f64) -> Self {
        self.divergence_threshold = Some(threshold);
        self
    }

    /// Records `metric` and returns `true` when training should stop.
    ///
    /// Shorthand for [`EarlyStopping::step`] when only the stop flag matters.
    pub fn check(&mut self, metric: f64) -> bool {
        self.step(metric).should_stop
    }

    /// Records `metric`, updates the running state and reports the outcome.
    ///
    /// Every call increments the check count and appends to the metric
    /// history, including calls made after the monitor has stopped. Once
    /// stopped, every later call returns the recorded stop reason.
    ///
    /// Stop conditions, in order of evaluation:
    /// 1. the metric is NaN or infinite (`MetricDiverged`);
    /// 2. the metric crosses the divergence threshold (`ThresholdExceeded`);
    /// 3. outside warmup, a non-improving check brings the counter to at
    ///    least `patience` (`PatienceExhausted`).
    pub fn step(&mut self, metric: f64) -> StepResult {
        self.total_checks += 1;
        self.metric_history.push(metric);

        if self.stopped {
            return StepResult {
                should_stop: true,
                improved: false,
                patience_counter: self.counter,
                reason: self.stop_reason.clone(),
            };
        }

        if !metric.is_finite() {
            return self.halt(StopReason::MetricDiverged);
        }

        if let Some(threshold) = self.divergence_threshold {
            // NOTE(review): Max mode compares against the negated threshold;
            // confirm this is the intended convention for callers.
            let diverged = match self.mode {
                StoppingMode::Min => metric > threshold,
                StoppingMode::Max => metric < -threshold,
            };
            if diverged {
                return self.halt(StopReason::ThresholdExceeded {
                    threshold: threshold.to_string(),
                });
            }
        }

        let improved = self.is_improvement(metric);
        if improved {
            self.best_value = metric;
            // `total_checks` was already incremented, so this is zero-based.
            self.best_epoch = self.total_checks - 1;
            self.counter = 0;
        } else {
            self.counter += 1;
        }

        let in_warmup = self.total_checks <= self.warmup_checks;
        // Patience may only run out on a non-improving check. Without the
        // `!improved` guard a patience of 0 would stop right after a new best
        // and report `improved: false` despite having updated `best_value`.
        if !in_warmup && !improved && self.counter >= self.patience {
            return self.halt(StopReason::PatienceExhausted {
                checks_without_improvement: self.counter,
            });
        }

        StepResult {
            should_stop: false,
            improved,
            patience_counter: self.counter,
            reason: StopReason::NotStopped,
        }
    }

    /// Starting value of `best_value`: any finite metric improves on it.
    fn initial_best(mode: StoppingMode) -> f64 {
        match mode {
            StoppingMode::Min => f64::INFINITY,
            StoppingMode::Max => f64::NEG_INFINITY,
        }
    }

    /// Marks the monitor as stopped for `reason` and builds the stop result.
    fn halt(&mut self, reason: StopReason) -> StepResult {
        self.stopped = true;
        self.stop_reason = reason.clone();
        StepResult {
            should_stop: true,
            improved: false,
            patience_counter: self.counter,
            reason,
        }
    }

    /// Returns `true` when `metric` beats the current best by more than `min_delta`.
    fn is_improvement(&self, metric: f64) -> bool {
        match self.mode {
            StoppingMode::Min => metric < self.best_value - self.min_delta,
            StoppingMode::Max => metric > self.best_value + self.min_delta,
        }
    }

    /// Best metric value accepted so far (±infinity before any improvement).
    pub fn best_value(&self) -> f64 {
        self.best_value
    }

    /// Zero-based index of the check that produced the best value.
    pub fn best_epoch(&self) -> usize {
        self.best_epoch
    }

    /// Number of checks since the last improvement.
    pub fn patience_counter(&self) -> usize {
        self.counter
    }

    /// Total number of metric values received since construction or reset.
    pub fn total_checks(&self) -> usize {
        self.total_checks
    }

    /// All metric values received since construction or reset, in order.
    pub fn metric_history(&self) -> &[f64] {
        &self.metric_history
    }

    /// Whether any stop condition has fired.
    pub fn is_stopped(&self) -> bool {
        self.stopped
    }

    /// The recorded stop reason, or `StopReason::NotStopped`.
    pub fn stop_reason(&self) -> &StopReason {
        &self.stop_reason
    }

    /// Clears all running state while keeping the configuration (patience,
    /// min delta, mode, warmup and divergence threshold).
    pub fn reset(&mut self) {
        self.best_value = Self::initial_best(self.mode);
        self.best_epoch = 0;
        self.counter = 0;
        self.total_checks = 0;
        self.stopped = false;
        self.stop_reason = StopReason::NotStopped;
        self.metric_history.clear();
    }
}
/// Early-stopping monitor that also keeps a copy of the parameters seen at
/// the best metric value.
#[derive(Debug)]
pub struct EarlyStoppingWithState<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + ToPrimitive + Clone,
{
    /// Metric bookkeeping and stop decisions.
    inner: EarlyStopping,
    /// Parameters captured at the most recent improving `step`, if any.
    best_params: Option<Vec<ArrayD<F>>>,
}
impl<F> EarlyStoppingWithState<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + ToPrimitive + Clone,
{
    /// Creates a monitor with no stored parameters; the arguments are
    /// forwarded unchanged to `EarlyStopping::new`.
    pub fn new(patience: usize, min_delta: f64, mode: StoppingMode) -> Self {
        let inner = EarlyStopping::new(patience, min_delta, mode);
        Self {
            inner,
            best_params: None,
        }
    }

    /// Forwards the warmup length to the inner monitor.
    pub fn with_warmup(self, warmup_checks: usize) -> Self {
        let Self { inner, best_params } = self;
        Self {
            inner: inner.with_warmup(warmup_checks),
            best_params,
        }
    }

    /// Forwards the divergence threshold to the inner monitor.
    pub fn with_divergence_threshold(self, threshold: f64) -> Self {
        let Self { inner, best_params } = self;
        Self {
            inner: inner.with_divergence_threshold(threshold),
            best_params,
        }
    }

    /// Feeds `metric` to the inner monitor and, when it reports an
    /// improvement, stores a copy of `params` as the best parameters.
    pub fn step(&mut self, metric: f64, params: &[ArrayD<F>]) -> StepResult {
        let outcome = self.inner.step(metric);
        if outcome.improved {
            self.best_params = Some(params.to_owned());
        }
        outcome
    }

    /// Feeds `metric` to the inner monitor without touching the stored
    /// parameters, so after an improving `check` the stored parameters no
    /// longer correspond to `best_value`.
    pub fn check(&mut self, metric: f64) -> bool {
        self.inner.check(metric)
    }

    /// Parameters captured at the last improving `step`, if any.
    pub fn best_params(&self) -> Option<&[ArrayD<F>]> {
        self.best_params.as_ref().map(|stored| stored.as_slice())
    }

    /// Best metric value tracked by the inner monitor.
    pub fn best_value(&self) -> f64 {
        self.inner.best_value()
    }

    /// Index of the check that produced the best value.
    pub fn best_epoch(&self) -> usize {
        self.inner.best_epoch()
    }

    /// Whether the inner monitor has stopped.
    pub fn is_stopped(&self) -> bool {
        self.inner.is_stopped()
    }

    /// Stop reason recorded by the inner monitor.
    pub fn stop_reason(&self) -> &StopReason {
        self.inner.stop_reason()
    }

    /// Metric values received by the inner monitor, in order.
    pub fn metric_history(&self) -> &[f64] {
        self.inner.metric_history()
    }

    /// Drops the stored parameters and resets the inner monitor.
    pub fn reset(&mut self) {
        self.best_params = None;
        self.inner.reset();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// A strictly decreasing loss never triggers a stop in `Min` mode.
    #[test]
    fn test_min_mode_basic_improvement() {
        let mut es = EarlyStopping::new(3, 0.0, StoppingMode::Min);
        for &loss in &[1.0, 0.9, 0.8, 0.7, 0.6, 0.5] {
            assert!(!es.check(loss));
        }
        assert_eq!(es.best_epoch(), 5);
        assert!((es.best_value() - 0.5).abs() < 1e-10);
    }

    /// Three non-improving checks exhaust a patience of 3.
    #[test]
    fn test_min_mode_patience_exhausted() {
        let mut es = EarlyStopping::new(3, 0.0, StoppingMode::Min);
        assert!(!es.check(0.5));
        assert!(!es.check(0.6));
        assert!(!es.check(0.7));
        assert!(es.check(0.8));
        assert_eq!(es.patience_counter(), 3);
    }

    /// In `Max` mode larger values are improvements.
    #[test]
    fn test_max_mode_basic() {
        let mut es = EarlyStopping::new(2, 0.0, StoppingMode::Max);
        assert!(!es.check(0.5));
        assert!(!es.check(0.7));
        assert!(!es.check(0.6));
        assert!(es.check(0.6));
        assert_eq!(es.best_epoch(), 1);
        assert!((es.best_value() - 0.7).abs() < 1e-10);
    }

    /// Changes smaller than `min_delta` do not count as improvements.
    #[test]
    fn test_min_delta_effect() {
        let mut es = EarlyStopping::new(3, 0.01, StoppingMode::Min);
        assert!(!es.check(1.0));
        assert!(!es.check(0.995));
        assert!(!es.check(0.996));
        assert!(es.check(0.997));
        assert!((es.best_value() - 1.0).abs() < 1e-10);
    }

    /// A NaN metric stops immediately with `MetricDiverged`.
    #[test]
    fn test_nan_causes_divergence_stop() {
        let mut es = EarlyStopping::new(10, 0.0, StoppingMode::Min);
        assert!(!es.check(0.5));
        let result = es.step(f64::NAN);
        assert!(result.should_stop);
        assert_eq!(es.stop_reason(), &StopReason::MetricDiverged);
    }

    /// An infinite metric stops immediately with `MetricDiverged`.
    #[test]
    fn test_infinity_causes_divergence_stop() {
        let mut es = EarlyStopping::new(10, 0.0, StoppingMode::Min);
        assert!(!es.check(0.5));
        assert!(es.check(f64::INFINITY));
        assert_eq!(es.stop_reason(), &StopReason::MetricDiverged);
    }

    /// Exceeding the divergence threshold stops with `ThresholdExceeded`.
    #[test]
    fn test_divergence_threshold() {
        let mut es =
            EarlyStopping::new(10, 0.0, StoppingMode::Min).with_divergence_threshold(5.0);
        assert!(!es.check(1.0));
        assert!(!es.check(4.0));
        let result = es.step(6.0);
        assert!(result.should_stop);
        match es.stop_reason() {
            StopReason::ThresholdExceeded { .. } => {}
            other => panic!("Expected ThresholdExceeded, got {:?}", other),
        }
    }

    /// Patience cannot stop the run during warmup, but the counter still
    /// accumulates, so the first post-warmup check stops here.
    #[test]
    fn test_warmup_period() {
        let mut es = EarlyStopping::new(2, 0.0, StoppingMode::Min).with_warmup(3);
        assert!(!es.check(0.5));
        assert!(!es.check(0.6));
        assert!(!es.check(0.7));
        assert!(es.check(0.8));
    }

    /// `reset` clears the stop flag, the check count and the history.
    #[test]
    fn test_reset() {
        let mut es = EarlyStopping::new(2, 0.0, StoppingMode::Min);
        assert!(!es.check(0.5));
        assert!(!es.check(0.6));
        assert!(es.check(0.7));
        assert!(es.is_stopped());
        es.reset();
        assert!(!es.is_stopped());
        assert_eq!(es.total_checks(), 0);
        assert!(es.metric_history().is_empty());
    }

    /// `step` reports improvement and the patience counter per call.
    #[test]
    fn test_step_result_details() {
        let mut es = EarlyStopping::new(3, 0.0, StoppingMode::Min);
        let r1 = es.step(1.0);
        assert!(r1.improved);
        assert!(!r1.should_stop);
        assert_eq!(r1.patience_counter, 0);
        let r2 = es.step(1.5);
        assert!(!r2.improved);
        assert_eq!(r2.patience_counter, 1);
    }

    /// Every metric passed to `check` is recorded in order.
    #[test]
    fn test_metric_history_tracking() {
        let mut es = EarlyStopping::new(5, 0.0, StoppingMode::Min);
        for &v in &[0.9, 0.8, 0.85, 0.7, 0.75] {
            es.check(v);
        }
        assert_eq!(es.metric_history(), &[0.9, 0.8, 0.85, 0.7, 0.75]);
        assert_eq!(es.total_checks(), 5);
    }

    /// Parameters are captured only on improving steps.
    #[test]
    fn test_with_state_stores_best_params() {
        let mut es = EarlyStoppingWithState::<f64>::new(3, 0.0, StoppingMode::Min);
        let params1 = vec![Array::from_vec(vec![1.0, 2.0]).into_dyn()];
        let params2 = vec![Array::from_vec(vec![3.0, 4.0]).into_dyn()];
        let params3 = vec![Array::from_vec(vec![5.0, 6.0]).into_dyn()];

        let r1 = es.step(1.0, &params1);
        assert!(r1.improved);
        assert!(es.best_params().is_some());

        let r2 = es.step(0.5, &params2);
        assert!(r2.improved);
        let best = es.best_params().expect("should have best params");
        assert_eq!(best[0].as_slice().expect("contiguous"), &[3.0, 4.0]);

        let r3 = es.step(0.8, &params3);
        assert!(!r3.improved);
        let best = es.best_params().expect("should still have best params");
        assert_eq!(best[0].as_slice().expect("contiguous"), &[3.0, 4.0]);
    }

    /// `reset` drops the stored parameters and clears the stop flag.
    #[test]
    fn test_with_state_reset_clears_params() {
        let mut es = EarlyStoppingWithState::<f64>::new(3, 0.0, StoppingMode::Min);
        let params = vec![Array::from_vec(vec![1.0]).into_dyn()];
        es.step(0.5, &params);
        assert!(es.best_params().is_some());
        es.reset();
        assert!(es.best_params().is_none());
        assert!(!es.is_stopped());
    }

    /// The stateful wrapper stops on the same schedule as the inner monitor.
    #[test]
    fn test_with_state_stopping_behavior() {
        let mut es = EarlyStoppingWithState::<f64>::new(2, 0.0, StoppingMode::Max);
        let params = vec![Array::from_vec(vec![0.0]).into_dyn()];
        let r1 = es.step(0.8, &params);
        assert!(!r1.should_stop);
        let r2 = es.step(0.7, &params);
        assert!(!r2.should_stop);
        let r3 = es.step(0.6, &params);
        assert!(r3.should_stop);
        assert!(es.is_stopped());
    }
}