use crate::metric::{
Metric, MetricName,
store::{Aggregate, Direction, EventStoreClient, Split},
};
/// The condition an early stopping strategy uses to decide that training should stop.
///
/// `Debug` and `PartialEq` are derived so that conditions can be logged and compared in
/// assertions.
#[derive(Clone, Debug, PartialEq)]
pub enum StoppingCondition {
    /// Stop once the monitored metric has not improved for a number of epochs.
    NoImprovementSince {
        /// How many epochs may elapse since the best epoch before training is stopped.
        n_epochs: usize,
    },
}
/// A strategy that decides whether the training loop should be stopped early.
///
/// Implementors must be `Send`.
pub trait EarlyStoppingStrategy: Send {
/// Returns `true` when training should stop after the given `epoch`.
///
/// Takes `&mut self`, so implementations may update internal state on every call.
/// `store` is the event store client the strategy can query for its decision.
fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool;
}
/// An [`EarlyStoppingStrategy`] that can be duplicated while stored as a boxed trait object.
pub trait CloneEarlyStoppingStrategy: EarlyStoppingStrategy + Send {
/// Clones the strategy into a new boxed trait object.
fn clone_box(&self) -> Box<dyn CloneEarlyStoppingStrategy>;
}
/// Blanket implementation: every clonable, sendable, `'static` early stopping strategy can be
/// cloned into a boxed trait object.
impl<T: EarlyStoppingStrategy + Clone + Send + 'static> CloneEarlyStoppingStrategy for T {
    fn clone_box(&self) -> Box<dyn CloneEarlyStoppingStrategy> {
        // Clone the concrete value, then erase its type behind the trait object.
        let duplicate = T::clone(self);
        Box::new(duplicate)
    }
}
impl Clone for Box<dyn CloneEarlyStoppingStrategy> {
fn clone(&self) -> Box<dyn CloneEarlyStoppingStrategy> {
self.clone_box()
}
}
/// An [early stopping strategy](EarlyStoppingStrategy) that monitors a single metric read
/// from the event store and remembers the best value seen so far.
#[derive(Clone)]
pub struct MetricEarlyStoppingStrategy {
/// Condition that turns a lack of improvement into a decision to stop.
condition: StoppingCondition,
/// Name of the metric looked up in the event store.
metric_name: MetricName,
/// Aggregate passed to the event store when looking up the metric.
aggregate: Aggregate,
/// Whether a lower or a higher metric value counts as an improvement.
direction: Direction,
/// Split passed to the event store when looking up the metric.
split: Split,
/// Epoch at which `best_value` was recorded (initialised to 1 by `new`).
best_epoch: usize,
/// Best metric value observed so far.
best_value: f64,
/// When set, the strategy never requests a stop for epochs `<=` this value.
warmup_epochs: Option<usize>,
}
impl EarlyStoppingStrategy for MetricEarlyStoppingStrategy {
    /// Decides whether training should stop after `epoch`.
    ///
    /// The monitored metric is looked up in `store` for the given epoch:
    ///
    /// * if the metric cannot be found, a warning is logged and `false` is returned;
    /// * if the value improves on the best value seen so far (strictly lower or strictly
    ///   higher, depending on the configured direction), it becomes the new best and `false`
    ///   is returned;
    /// * if warmup epochs are configured and `epoch` is still within them, `false` is returned;
    /// * otherwise the stopping condition is evaluated against the number of epochs elapsed
    ///   since the best epoch.
    fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool {
        let Some(current_value) =
            store.find_metric(&self.metric_name, epoch, self.aggregate, &self.split)
        else {
            log::warn!("Can't find metric for early stopping.");
            return false;
        };

        // Strict comparison: a value equal to the current best is not an improvement.
        let is_best = match self.direction {
            Direction::Lowest => current_value < self.best_value,
            Direction::Highest => current_value > self.best_value,
        };

        if is_best {
            log::info!(
                "New best epoch found {} {}: {}",
                epoch,
                self.metric_name,
                current_value
            );
            self.best_value = current_value;
            self.best_epoch = epoch;
            return false;
        }

        // The best value is still tracked during warmup (above), but stopping is suppressed.
        if self
            .warmup_epochs
            .is_some_and(|warmup_epochs| epoch <= warmup_epochs)
        {
            return false;
        }

        match self.condition {
            StoppingCondition::NoImprovementSince { n_epochs } => {
                // `best_epoch` starts at 1 and only changes on improvement, so a caller passing
                // a smaller epoch (e.g. 0) would make a plain subtraction underflow: a panic in
                // debug builds and a bogus "stop now" in release builds. Saturate at zero.
                let epochs_since_best = epoch.saturating_sub(self.best_epoch);
                let should_stop = epochs_since_best >= n_epochs;

                if should_stop {
                    log::info!(
                        "Stopping training loop, no improvement since epoch {}, {}: {}, current \
                         epoch {}, {}: {}",
                        self.best_epoch,
                        self.metric_name,
                        self.best_value,
                        epoch,
                        self.metric_name,
                        current_value
                    );
                }

                should_stop
            }
        }
    }
}
impl MetricEarlyStoppingStrategy {
    /// Creates a strategy that monitors `metric`.
    ///
    /// # Arguments
    ///
    /// * `metric` - Metric whose name is used to look values up in the event store.
    /// * `aggregate` - Aggregate used when looking the metric up.
    /// * `direction` - Whether the lowest or the highest value is considered the best.
    /// * `split` - Split used when looking the metric up.
    /// * `condition` - Condition that decides when to stop.
    ///
    /// The best value starts at the worst representable finite value for `direction`, the
    /// best epoch starts at 1, and no warmup is configured.
    pub fn new<Me: Metric>(
        metric: &Me,
        aggregate: Aggregate,
        direction: Direction,
        split: Split,
        condition: StoppingCondition,
    ) -> Self {
        // Seed with the worst finite value so that any ordinary metric value beats it.
        let best_value = match direction {
            Direction::Lowest => f64::MAX,
            Direction::Highest => f64::MIN,
        };
        let metric_name = metric.name();

        Self {
            condition,
            metric_name,
            aggregate,
            direction,
            split,
            best_epoch: 1,
            best_value,
            warmup_epochs: None,
        }
    }

    /// Returns the configured number of warmup epochs, if any.
    pub fn warmup_epochs(&self) -> Option<usize> {
        self.warmup_epochs
    }

    /// Returns the strategy with its warmup epochs replaced by `warmup`.
    ///
    /// While the epoch is less than or equal to the warmup value, the strategy does not
    /// request a stop. Passing `None` disables the warmup.
    pub fn with_warmup_epochs(mut self, warmup: Option<usize>) -> Self {
        self.warmup_epochs = warmup;
        self
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::{
        EventProcessorTraining, TestBackend,
        logger::InMemoryMetricLogger,
        metric::{
            LossMetric,
            processor::{
                MetricsTraining, MinimalEventProcessor,
                test_utils::{end_epoch, process_train},
            },
            store::LogEventStore,
        },
    };

    use super::*;

    #[test]
    fn never_early_stop_while_it_is_improving() {
        test_early_stopping(
            None,
            1,
            &[
                (&[0.5, 0.3], false, "Should not stop first epoch"),
                (&[0.4, 0.3], false, "Should not stop when improving"),
                (&[0.3, 0.3], false, "Should not stop when improving"),
                (&[0.2, 0.3], false, "Should not stop when improving"),
            ],
        );
    }

    #[test]
    fn early_stop_when_no_improvement_since_two_epochs() {
        test_early_stopping(
            None,
            2,
            &[
                (&[1.0, 0.5], false, "Should not stop first epoch"),
                (&[0.5, 0.3], false, "Should not stop when improving"),
                (
                    &[1.0, 3.0],
                    false,
                    "Should not stop first time it gets worse",
                ),
                (
                    &[1.0, 2.0],
                    true,
                    "Should stop since two following epochs didn't improve",
                ),
            ],
        );
    }

    #[test]
    fn early_stopping_with_warmup() {
        test_early_stopping(
            Some(3),
            2,
            &[
                (&[1.0, 0.5], false, "Should not stop during warmup"),
                (&[1.0, 0.5], false, "Should not stop during warmup"),
                (&[1.0, 0.5], false, "Should not stop during warmup"),
                (
                    &[1.0, 0.5],
                    true,
                    "Should stop when not improving after warmup",
                ),
            ],
        );
    }

    #[test]
    fn early_stop_when_stays_equal() {
        test_early_stopping(
            None,
            2,
            &[
                (&[0.5, 0.3], false, "Should not stop first epoch"),
                (
                    &[0.5, 0.3],
                    false,
                    "Should not stop first time it stays the same",
                ),
                (
                    &[0.5, 0.3],
                    true,
                    "Should stop since two following epochs didn't improve",
                ),
            ],
        );
    }

    /// Drives a [`MetricEarlyStoppingStrategy`] on the mean training loss (lowest is best).
    ///
    /// Each entry of `data` describes one epoch, starting at epoch 1: the loss points
    /// processed during that epoch, the expected result of `should_stop` once the epoch has
    /// ended, and the message reported when the expectation is not met.
    fn test_early_stopping(
        warmup: Option<usize>,
        n_epochs: usize,
        data: &[(&[f64], bool, &str)],
    ) {
        let loss = LossMetric::<TestBackend>::new();
        let mut early_stopping = MetricEarlyStoppingStrategy::new(
            &loss,
            Aggregate::Mean,
            Direction::Lowest,
            Split::Train,
            StoppingCondition::NoImprovementSince { n_epochs },
        )
        .with_warmup_epochs(warmup);

        let mut store = LogEventStore::default();
        let mut metrics = MetricsTraining::<f64, f64>::default();

        store.register_logger(InMemoryMetricLogger::default());
        metrics.register_train_metric_numeric(loss);

        let store = Arc::new(EventStoreClient::new(store));
        let mut processor = MinimalEventProcessor::new(metrics, store.clone());

        let mut epoch = 1;
        processor.process_train(crate::LearnerEvent::Start);

        for (points, expected_stop, comment) in data {
            for point in points.iter() {
                process_train(&mut processor, *point, epoch);
            }
            end_epoch(&mut processor, epoch);

            assert_eq!(
                *expected_stop,
                early_stopping.should_stop(epoch, &store),
                "{comment}"
            );

            epoch += 1;
        }
    }
}