use crate::metric::{
store::{Aggregate, Direction, EventStoreClient, Split},
Metric,
};
/// The condition an [`EarlyStoppingStrategy`] uses to decide that training should end.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StoppingCondition {
    /// Stop once at least `n_epochs` epochs have elapsed since the epoch at which the
    /// best metric value was observed.
    NoImprovementSince {
        /// Number of epochs without improvement tolerated before stopping.
        n_epochs: usize,
    },
}
/// A policy consulted once per epoch to decide whether the training loop should end early.
pub trait EarlyStoppingStrategy {
    /// Returns `true` when training should stop after `epoch`.
    ///
    /// Takes `&mut self` so implementations can keep state across epochs, and receives the
    /// event `store` from which logged metrics can be read.
    fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool;
}
/// An [`EarlyStoppingStrategy`] driven by a single numeric metric read from the event store.
///
/// On each epoch the metric value is compared with the best value recorded so far; when it
/// does not improve, the configured [`StoppingCondition`] decides whether to stop.
pub struct MetricEarlyStoppingStrategy {
    // --- What is tracked ---
    /// Name of the tracked metric.
    metric_name: String,
    /// Aggregate of the metric looked up in the store (e.g. `Aggregate::Mean`).
    aggregate: Aggregate,
    /// Split the metric is read from (e.g. `Split::Train`).
    split: Split,
    /// Whether lower or higher metric values count as an improvement.
    direction: Direction,

    // --- When to stop ---
    /// Rule applied when the current epoch is not an improvement.
    condition: StoppingCondition,

    // --- Running state ---
    /// Best metric value observed so far.
    best_value: f64,
    /// Epoch at which `best_value` was observed.
    best_epoch: usize,
}
impl EarlyStoppingStrategy for MetricEarlyStoppingStrategy {
    /// Reads the tracked metric for `epoch` and decides whether training should stop.
    ///
    /// The metric `self.metric_name` is looked up in `store` with the configured aggregate
    /// and split. Its value becomes the new best if it is strictly better than the best
    /// value seen so far: lower for `Direction::Lowest`, higher for `Direction::Highest`.
    /// In that case `false` is returned. NaN never compares as better, so a NaN metric is
    /// treated as "no improvement". Otherwise the configured [`StoppingCondition`] decides.
    ///
    /// Returns `false` (keep training) when the metric cannot be found for this epoch.
    fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool {
        let current_value =
            match store.find_metric(&self.metric_name, epoch, self.aggregate, self.split) {
                Some(value) => value,
                None => {
                    log::warn!("Can't find metric for early stopping.");
                    return false;
                }
            };

        let is_best = match self.direction {
            Direction::Lowest => current_value < self.best_value,
            Direction::Highest => current_value > self.best_value,
        };

        if is_best {
            log::info!(
                "New best epoch found {} {}: {}",
                epoch,
                self.metric_name,
                current_value
            );
            self.best_value = current_value;
            self.best_epoch = epoch;
            return false;
        }

        match self.condition {
            StoppingCondition::NoImprovementSince { n_epochs } => {
                // `saturating_sub` prevents a `usize` underflow. That underflow would panic in
                // debug builds and wrap to a huge value in release builds. It can happen when
                // `epoch` is smaller than the recorded best epoch, e.g. a non-improving call
                // with epoch 0 while `best_epoch` still holds its initial value.
                let epochs_without_improvement = epoch.saturating_sub(self.best_epoch);
                let should_stop = epochs_without_improvement >= n_epochs;
                if should_stop {
                    log::info!(
                        "Stopping training loop, no improvement since epoch {}, {}: {}, current \
                         epoch {}, {}: {}",
                        self.best_epoch,
                        self.metric_name,
                        self.best_value,
                        epoch,
                        self.metric_name,
                        current_value
                    );
                }
                should_stop
            }
        }
    }
}
impl MetricEarlyStoppingStrategy {
    /// Creates a strategy that tracks the metric `Me`, identified by `Me::NAME`.
    ///
    /// # Arguments
    /// * `aggregate` - aggregate of the metric to read from the store.
    /// * `direction` - whether lower (`Lowest`) or higher (`Highest`) values are better.
    /// * `split` - split the metric is read from.
    /// * `condition` - rule deciding when to stop once the metric stops improving.
    ///
    /// The best value starts at the worst finite `f64` for the chosen direction:
    /// `f64::MAX` when lower is better and `f64::MIN` when higher is better.
    pub fn new<Me: Metric>(
        aggregate: Aggregate,
        direction: Direction,
        split: Split,
        condition: StoppingCondition,
    ) -> Self {
        // Every finite metric except the sentinel itself compares as better than this value.
        let worst_possible = match direction {
            Direction::Highest => f64::MIN,
            Direction::Lowest => f64::MAX,
        };
        let metric_name = Me::NAME.to_string();

        Self {
            best_value: worst_possible,
            // NOTE(review): assumes epochs are numbered from 1 — confirm with the training loop.
            best_epoch: 1,
            metric_name,
            aggregate,
            split,
            direction,
            condition,
        }
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::{
        logger::InMemoryMetricLogger,
        metric::{
            processor::{
                test_utils::{end_epoch, process_train},
                Metrics, MinimalEventProcessor,
            },
            store::LogEventStore,
            LossMetric,
        },
        TestBackend,
    };

    use super::*;

    #[test]
    fn never_early_stop_while_it_is_improving() {
        test_early_stopping(
            1,
            &[
                (&[0.5, 0.3], false, "Should not stop first epoch"),
                (&[0.4, 0.3], false, "Should not stop when improving"),
                (&[0.3, 0.3], false, "Should not stop when improving"),
                (&[0.2, 0.3], false, "Should not stop when improving"),
            ],
        );
    }

    #[test]
    fn early_stop_when_no_improvement_since_two_epochs() {
        test_early_stopping(
            2,
            &[
                (&[1.0, 0.5], false, "Should not stop first epoch"),
                (&[0.5, 0.3], false, "Should not stop when improving"),
                (
                    &[1.0, 3.0],
                    false,
                    "Should not stop first time it gets worse",
                ),
                (
                    &[1.0, 2.0],
                    true,
                    "Should stop since two following epochs didn't improve",
                ),
            ],
        );
    }

    #[test]
    fn early_stop_when_stays_equal() {
        test_early_stopping(
            2,
            &[
                (&[0.5, 0.3], false, "Should not stop first epoch"),
                (
                    &[0.5, 0.3],
                    false,
                    "Should not stop first time it stays the same",
                ),
                (
                    &[0.5, 0.3],
                    true,
                    "Should stop since two following epochs didn't improve",
                ),
            ],
        );
    }

    /// Runs a mean-train-loss early-stopping strategy over successive epochs.
    ///
    /// Each entry of `data` is `(loss points logged during the epoch, expected stop
    /// decision after the epoch, assertion message)`. Epochs are numbered from 1.
    fn test_early_stopping(n_epochs: usize, data: &[(&[f64], bool, &str)]) {
        let mut early_stopping = MetricEarlyStoppingStrategy::new::<LossMetric<TestBackend>>(
            Aggregate::Mean,
            Direction::Lowest,
            Split::Train,
            StoppingCondition::NoImprovementSince { n_epochs },
        );
        let mut store = LogEventStore::default();
        let mut metrics = Metrics::<f64, f64>::default();

        store.register_logger_train(InMemoryMetricLogger::default());
        metrics.register_train_metric_numeric(LossMetric::<TestBackend>::new());

        let store = Arc::new(EventStoreClient::new(store));
        let mut processor = MinimalEventProcessor::new(metrics, store.clone());

        let mut epoch = 1;
        for (points, expected_stop, comment) in data {
            for point in points.iter() {
                process_train(&mut processor, *point, epoch);
            }
            end_epoch(&mut processor, epoch);

            assert_eq!(
                *expected_stop,
                early_stopping.should_stop(epoch, &store),
                "{comment}"
            );
            epoch += 1;
        }
    }
}