burn_train/checkpoint/strategy/
metric.rs

1use super::CheckpointingStrategy;
2use crate::{
3    checkpoint::CheckpointingAction,
4    metric::{
5        store::{Aggregate, Direction, EventStoreClient, Split},
6        Metric,
7    },
8};
9
10/// Keep the best checkpoint based on a metric.
11pub struct MetricCheckpointingStrategy {
12    current: Option<usize>,
13    aggregate: Aggregate,
14    direction: Direction,
15    split: Split,
16    name: String,
17}
18
19impl MetricCheckpointingStrategy {
20    /// Create a new metric strategy.
21    pub fn new<M>(aggregate: Aggregate, direction: Direction, split: Split) -> Self
22    where
23        M: Metric,
24    {
25        Self {
26            current: None,
27            name: M::NAME.to_string(),
28            aggregate,
29            direction,
30            split,
31        }
32    }
33}
34
35impl CheckpointingStrategy for MetricCheckpointingStrategy {
36    fn checkpointing(
37        &mut self,
38        epoch: usize,
39        store: &EventStoreClient,
40    ) -> Vec<CheckpointingAction> {
41        let best_epoch =
42            match store.find_epoch(&self.name, self.aggregate, self.direction, self.split) {
43                Some(epoch_best) => epoch_best,
44                None => epoch,
45            };
46
47        let mut actions = Vec::new();
48
49        if let Some(current) = self.current {
50            if current != best_epoch {
51                actions.push(CheckpointingAction::Delete(current));
52            }
53        }
54
55        if best_epoch == epoch {
56            actions.push(CheckpointingAction::Save);
57        }
58
59        self.current = Some(best_epoch);
60
61        actions
62    }
63}
64
#[cfg(test)]
mod tests {
    use crate::{
        logger::InMemoryMetricLogger,
        metric::{
            processor::{
                test_utils::{end_epoch, process_train},
                Metrics, MinimalEventProcessor,
            },
            store::LogEventStore,
            LossMetric,
        },
        TestBackend,
    };

    use super::*;
    use std::sync::Arc;

    #[test]
    fn always_keep_the_best_epoch() {
        // Strategy under test: keep the epoch with the lowest mean training loss.
        let mut strategy = MetricCheckpointingStrategy::new::<LossMetric<TestBackend>>(
            Aggregate::Mean,
            Direction::Lowest,
            Split::Train,
        );

        // Event store with an in-memory logger registered for training.
        let mut store = LogEventStore::default();
        store.register_logger_train(InMemoryMetricLogger::default());

        // Metrics with the loss registered as a numeric training metric.
        let mut metrics = Metrics::<f64, f64>::default();
        metrics.register_train_metric_numeric(LossMetric::<TestBackend>::new());

        let store = Arc::new(EventStoreClient::new(store));
        let mut processor = MinimalEventProcessor::new(metrics, store.clone());

        // First epoch: two points, mean 0.75.
        let epoch = 1;
        process_train(&mut processor, 1.0, epoch);
        process_train(&mut processor, 0.5, epoch);
        end_epoch(&mut processor, epoch);

        // Nothing is tracked yet, so the current record must simply be saved.
        let actions = strategy.checkpointing(epoch, &store);
        assert_eq!(vec![CheckpointingAction::Save], actions);

        // Second epoch: two points, mean 0.4.
        let epoch = epoch + 1;
        process_train(&mut processor, 0.5, epoch);
        process_train(&mut processor, 0.3, epoch);
        end_epoch(&mut processor, epoch);

        // The new epoch is better: delete the previous record and save the current one.
        let actions = strategy.checkpointing(epoch, &store);
        assert_eq!(
            vec![CheckpointingAction::Delete(1), CheckpointingAction::Save],
            actions
        );

        // Last epoch: two points, mean 2.0.
        let epoch = epoch + 1;
        process_train(&mut processor, 1.0, epoch);
        process_train(&mut processor, 3.0, epoch);
        end_epoch(&mut processor, epoch);

        // The previous record is still the best one: it must not be deleted, and no new
        // record must be saved.
        let actions = strategy.checkpointing(epoch, &store);
        assert!(actions.is_empty());
    }
}