burn_train/checkpoint/strategy/
metric.rs

1use super::CheckpointingStrategy;
2use crate::{
3    checkpoint::CheckpointingAction,
4    metric::{
5        Metric, MetricName,
6        store::{Aggregate, Direction, EventStoreClient, Split},
7    },
8};
9
10/// Keep the best checkpoint based on a metric.
11pub struct MetricCheckpointingStrategy {
12    current: Option<usize>,
13    aggregate: Aggregate,
14    direction: Direction,
15    split: Split,
16    name: MetricName,
17}
18
19impl MetricCheckpointingStrategy {
20    /// Create a new metric checkpointing strategy.
21    pub fn new<M>(metric: &M, aggregate: Aggregate, direction: Direction, split: Split) -> Self
22    where
23        M: Metric,
24    {
25        Self {
26            current: None,
27            name: metric.name(),
28            aggregate,
29            direction,
30            split,
31        }
32    }
33}
34
35impl CheckpointingStrategy for MetricCheckpointingStrategy {
36    fn checkpointing(
37        &mut self,
38        epoch: usize,
39        store: &EventStoreClient,
40    ) -> Vec<CheckpointingAction> {
41        let best_epoch =
42            match store.find_epoch(&self.name, self.aggregate, self.direction, self.split) {
43                Some(epoch_best) => epoch_best,
44                None => epoch,
45            };
46
47        let mut actions = Vec::new();
48
49        if let Some(current) = self.current
50            && current != best_epoch
51        {
52            actions.push(CheckpointingAction::Delete(current));
53        }
54
55        if best_epoch == epoch {
56            actions.push(CheckpointingAction::Save);
57        }
58
59        self.current = Some(best_epoch);
60
61        actions
62    }
63}
64
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        TestBackend,
        logger::InMemoryMetricLogger,
        metric::{
            LossMetric,
            processor::{
                MetricsTraining, MinimalEventProcessor,
                test_utils::{end_epoch, process_train},
            },
            store::LogEventStore,
        },
    };
    use std::sync::Arc;

    #[test]
    fn always_keep_the_best_epoch() {
        let loss = LossMetric::<TestBackend>::new();

        // Rank epochs by their lowest mean training loss.
        let mut strategy = MetricCheckpointingStrategy::new(
            &loss,
            Aggregate::Mean,
            Direction::Lowest,
            Split::Train,
        );

        // Event store backed by an in-memory training logger.
        let mut log_store = LogEventStore::default();
        log_store.register_logger_train(InMemoryMetricLogger::default());

        // The loss is the only numeric training metric.
        let mut metrics = MetricsTraining::<f64, f64>::default();
        metrics.register_train_metric_numeric(loss);

        let store = Arc::new(EventStoreClient::new(log_store));
        let mut processor = MinimalEventProcessor::new(metrics, Arc::clone(&store));

        // Epoch 1, mean loss 0.75: the first epoch is the best so far, so it is saved.
        let epoch = 1;
        for value in [1.0, 0.5] {
            process_train(&mut processor, value, epoch);
        }
        end_epoch(&mut processor, epoch);
        assert_eq!(
            strategy.checkpointing(epoch, &store),
            vec![CheckpointingAction::Save]
        );

        // Epoch 2, mean loss 0.4: new best, so epoch 1 is dropped and epoch 2 saved.
        let epoch = 2;
        for value in [0.5, 0.3] {
            process_train(&mut processor, value, epoch);
        }
        end_epoch(&mut processor, epoch);
        assert_eq!(
            strategy.checkpointing(epoch, &store),
            vec![CheckpointingAction::Delete(1), CheckpointingAction::Save]
        );

        // Epoch 3, mean loss 2.0: worse than epoch 2, so nothing is saved or deleted.
        let epoch = 3;
        for value in [1.0, 3.0] {
            process_train(&mut processor, value, epoch);
        }
        end_epoch(&mut processor, epoch);
        let actions = strategy.checkpointing(epoch, &store);
        assert!(actions.is_empty());
    }
}
135}