use super::CheckpointingStrategy;
use crate::{
checkpoint::CheckpointingAction,
metric::{
store::{Aggregate, Direction, EventStoreClient, Split},
Metric,
},
};
/// Checkpointing strategy that keeps only the checkpoint of the best epoch for one metric.
///
/// The best epoch is whatever `EventStoreClient::find_epoch` reports for the metric name,
/// aggregation, direction and split configured on this strategy.
pub struct MetricCheckpointingStrategy {
    /// Name of the tracked metric, taken from the metric type's `NAME`.
    name: String,
    /// Aggregation used when searching for the best epoch (e.g. `Aggregate::Mean`).
    aggregate: Aggregate,
    /// Which direction of the aggregated value counts as better (e.g. `Direction::Lowest`).
    direction: Direction,
    /// Data split whose metric values are inspected (e.g. `Split::Train`).
    split: Split,
    /// Best epoch recorded by the previous checkpointing call, if any.
    current: Option<usize>,
}
impl MetricCheckpointingStrategy {
    /// Creates a strategy tracking the metric type `M`, identified by `M::NAME`.
    ///
    /// `aggregate`, `direction` and `split` select how the best epoch is looked up in the
    /// event store. No best epoch is tracked until the first checkpointing call.
    pub fn new<M: Metric>(aggregate: Aggregate, direction: Direction, split: Split) -> Self {
        let name = M::NAME.to_string();

        Self {
            aggregate,
            direction,
            split,
            name,
            current: None,
        }
    }
}
impl CheckpointingStrategy for MetricCheckpointingStrategy {
    /// Decides which checkpoint actions to perform at the end of `epoch`.
    ///
    /// The best epoch is looked up in `store` using the configured metric name, aggregate,
    /// direction and split. If the store returns nothing, `epoch` itself is treated as the
    /// best. The returned actions are, in this order:
    /// - `Delete(previous)` when the epoch tracked by the previous call is no longer the best;
    /// - `Save` when the best epoch is the current `epoch`.
    ///
    /// The best epoch is then remembered for the next call. On the first call, if the store
    /// already reports an earlier best epoch, no action is returned and that epoch is simply
    /// adopted as the tracked one.
    fn checkpointing(
        &mut self,
        epoch: usize,
        store: &EventStoreClient,
    ) -> Vec<CheckpointingAction> {
        // With no recorded metric values yet, fall back to treating this epoch as the best.
        let best_epoch = store
            .find_epoch(&self.name, self.aggregate, self.direction, self.split)
            .unwrap_or(epoch);

        let mut actions = Vec::new();

        // Track the new best epoch and get back the previously tracked one; its checkpoint
        // is obsolete if it has been superseded.
        if let Some(previous) = self.current.replace(best_epoch) {
            if previous != best_epoch {
                actions.push(CheckpointingAction::Delete(previous));
            }
        }

        if best_epoch == epoch {
            actions.push(CheckpointingAction::Save);
        }

        actions
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        logger::InMemoryMetricLogger,
        metric::{
            processor::{
                test_utils::{end_epoch, process_train},
                Metrics, MinimalEventProcessor,
            },
            store::LogEventStore,
            LossMetric,
        },
        TestBackend,
    };
    use std::sync::Arc;

    use super::*;

    /// Feeds three epochs of train losses and checks that only the best epoch's checkpoint
    /// is kept (lowest mean loss wins).
    #[test]
    fn always_keep_the_best_epoch() {
        let mut log_store = LogEventStore::default();
        log_store.register_logger_train(InMemoryMetricLogger::default());

        let mut metrics = Metrics::<f64, f64>::default();
        metrics.register_train_metric_numeric(LossMetric::<TestBackend>::new());

        let client = Arc::new(EventStoreClient::new(log_store));
        let mut processor = MinimalEventProcessor::new(metrics, Arc::clone(&client));

        let mut strategy = MetricCheckpointingStrategy::new::<LossMetric<TestBackend>>(
            Aggregate::Mean,
            Direction::Lowest,
            Split::Train,
        );

        // Each row: epoch, train losses logged during it, expected checkpointing actions.
        let scenarios = [
            // First epoch is the best so far: save it.
            (1, [1.0, 0.5], vec![CheckpointingAction::Save]),
            // Lower mean loss: drop epoch 1's checkpoint and save epoch 2.
            (
                2,
                [0.5, 0.3],
                vec![CheckpointingAction::Delete(1), CheckpointingAction::Save],
            ),
            // Higher mean loss: epoch 2 stays the best, nothing to do.
            (3, [1.0, 3.0], Vec::new()),
        ];

        for (epoch, losses, expected) in scenarios {
            for loss in losses {
                process_train(&mut processor, loss, epoch);
            }
            end_epoch(&mut processor, epoch);

            assert_eq!(expected, strategy.checkpointing(epoch, &client));
        }
    }
}