burn_train/checkpoint/strategy/
metric.rs1use super::CheckpointingStrategy;
2use crate::{
3 checkpoint::CheckpointingAction,
4 metric::{
5 Metric, MetricName,
6 store::{Aggregate, Direction, EventStoreClient, Split},
7 },
8};
9
10pub struct MetricCheckpointingStrategy {
12 current: Option<usize>,
13 aggregate: Aggregate,
14 direction: Direction,
15 split: Split,
16 name: MetricName,
17}
18
19impl MetricCheckpointingStrategy {
20 pub fn new<M>(metric: &M, aggregate: Aggregate, direction: Direction, split: Split) -> Self
22 where
23 M: Metric,
24 {
25 Self {
26 current: None,
27 name: metric.name(),
28 aggregate,
29 direction,
30 split,
31 }
32 }
33}
34
35impl CheckpointingStrategy for MetricCheckpointingStrategy {
36 fn checkpointing(
37 &mut self,
38 epoch: usize,
39 store: &EventStoreClient,
40 ) -> Vec<CheckpointingAction> {
41 let best_epoch =
42 match store.find_epoch(&self.name, self.aggregate, self.direction, self.split) {
43 Some(epoch_best) => epoch_best,
44 None => epoch,
45 };
46
47 let mut actions = Vec::new();
48
49 if let Some(current) = self.current
50 && current != best_epoch
51 {
52 actions.push(CheckpointingAction::Delete(current));
53 }
54
55 if best_epoch == epoch {
56 actions.push(CheckpointingAction::Save);
57 }
58
59 self.current = Some(best_epoch);
60
61 actions
62 }
63}
64
#[cfg(test)]
mod tests {
    use crate::{
        TestBackend,
        logger::InMemoryMetricLogger,
        metric::{
            LossMetric,
            processor::{
                MetricsTraining, MinimalEventProcessor,
                test_utils::{end_epoch, process_train},
            },
            store::LogEventStore,
        },
    };

    use super::*;
    use std::sync::Arc;

    #[test]
    fn always_keep_the_best_epoch() {
        let loss = LossMetric::<TestBackend>::new();
        let mut event_store = LogEventStore::default();
        let mut strategy = MetricCheckpointingStrategy::new(
            &loss,
            Aggregate::Mean,
            Direction::Lowest,
            Split::Train,
        );
        let mut metrics = MetricsTraining::<f64, f64>::default();

        // Register an in-memory logger for the train split and the loss metric itself.
        event_store.register_logger_train(InMemoryMetricLogger::default());
        metrics.register_train_metric_numeric(loss);

        let store = Arc::new(EventStoreClient::new(event_store));
        let mut processor = MinimalEventProcessor::new(metrics, store.clone());

        // One row per epoch: (epoch, loss values fed to the processor, expected actions).
        let scenarios = [
            // First epoch: nothing tracked yet, so it is simply saved.
            (1, [1.0, 0.5], vec![CheckpointingAction::Save]),
            // Second epoch has lower losses: the first checkpoint is replaced.
            (
                2,
                [0.5, 0.3],
                vec![CheckpointingAction::Delete(1), CheckpointingAction::Save],
            ),
            // Third epoch has higher losses: the second checkpoint is kept untouched.
            (3, [1.0, 3.0], Vec::new()),
        ];

        for (epoch, losses, expected) in scenarios {
            for value in losses {
                process_train(&mut processor, value, epoch);
            }
            end_epoch(&mut processor, epoch);

            assert_eq!(expected, strategy.checkpointing(epoch, &store));
        }
    }
}