// burn_train/checkpoint/strategy/metric.rs
use super::CheckpointingStrategy;
use crate::{
    checkpoint::CheckpointingAction,
    metric::{
        store::{Aggregate, Direction, EventStoreClient, Split},
        Metric,
    },
};
9
10pub struct MetricCheckpointingStrategy {
12 current: Option<usize>,
13 aggregate: Aggregate,
14 direction: Direction,
15 split: Split,
16 name: String,
17}
18
19impl MetricCheckpointingStrategy {
20 pub fn new<M>(aggregate: Aggregate, direction: Direction, split: Split) -> Self
22 where
23 M: Metric,
24 {
25 Self {
26 current: None,
27 name: M::NAME.to_string(),
28 aggregate,
29 direction,
30 split,
31 }
32 }
33}
34
35impl CheckpointingStrategy for MetricCheckpointingStrategy {
36 fn checkpointing(
37 &mut self,
38 epoch: usize,
39 store: &EventStoreClient,
40 ) -> Vec<CheckpointingAction> {
41 let best_epoch =
42 match store.find_epoch(&self.name, self.aggregate, self.direction, self.split) {
43 Some(epoch_best) => epoch_best,
44 None => epoch,
45 };
46
47 let mut actions = Vec::new();
48
49 if let Some(current) = self.current {
50 if current != best_epoch {
51 actions.push(CheckpointingAction::Delete(current));
52 }
53 }
54
55 if best_epoch == epoch {
56 actions.push(CheckpointingAction::Save);
57 }
58
59 self.current = Some(best_epoch);
60
61 actions
62 }
63}
64
#[cfg(test)]
mod tests {
    use crate::{
        logger::InMemoryMetricLogger,
        metric::{
            processor::{
                test_utils::{end_epoch, process_train},
                Metrics, MinimalEventProcessor,
            },
            store::LogEventStore,
            LossMetric,
        },
        TestBackend,
    };

    use super::*;
    use std::sync::Arc;

    /// Runs three training epochs and checks that the strategy saves a new
    /// best epoch, deletes the one it replaces, and does nothing when the
    /// latest epoch is not the best.
    #[test]
    fn always_keep_the_best_epoch() {
        let mut strategy = MetricCheckpointingStrategy::new::<LossMetric<TestBackend>>(
            Aggregate::Mean,
            Direction::Lowest,
            Split::Train,
        );

        // Event store backed by an in-memory logger for the train split.
        let mut event_store = LogEventStore::default();
        event_store.register_logger_train(InMemoryMetricLogger::default());

        // Only the loss metric is registered, as a numeric train metric.
        let mut metrics = Metrics::<f64, f64>::default();
        metrics.register_train_metric_numeric(LossMetric::<TestBackend>::new());

        let client = Arc::new(EventStoreClient::new(event_store));
        let mut processor = MinimalEventProcessor::new(metrics, client.clone());

        // Epoch 1: nothing is tracked yet, so the checkpoint must be saved.
        process_train(&mut processor, 1.0, 1);
        process_train(&mut processor, 0.5, 1);
        end_epoch(&mut processor, 1);

        assert_eq!(
            vec![CheckpointingAction::Save],
            strategy.checkpointing(1, &client)
        );

        // Epoch 2: lower losses than epoch 1, so it replaces epoch 1.
        process_train(&mut processor, 0.5, 2);
        process_train(&mut processor, 0.3, 2);
        end_epoch(&mut processor, 2);

        assert_eq!(
            vec![CheckpointingAction::Delete(1), CheckpointingAction::Save],
            strategy.checkpointing(2, &client)
        );

        // Epoch 3: higher losses, so epoch 2 stays and no action is expected.
        process_train(&mut processor, 1.0, 3);
        process_train(&mut processor, 3.0, 3);
        end_epoch(&mut processor, 3);

        assert!(strategy.checkpointing(3, &client).is_empty());
    }
}