burn_train/checkpoint/strategy/
lastn.rs

1use super::CheckpointingStrategy;
2use crate::{checkpoint::CheckpointingAction, metric::store::EventStoreClient};
3
4/// Keep the last N checkpoints.
5///
6/// Very useful when training, minimizing disk space while ensuring that the training can be
7/// resumed even if something goes wrong.
8#[derive(new)]
9pub struct KeepLastNCheckpoints {
10    num_keep: usize,
11}
12
13impl CheckpointingStrategy for KeepLastNCheckpoints {
14    fn checkpointing(
15        &mut self,
16        epoch: usize,
17        _store: &EventStoreClient,
18    ) -> Vec<CheckpointingAction> {
19        let mut actions = vec![CheckpointingAction::Save];
20
21        if let Some(epoch) = usize::checked_sub(epoch, self.num_keep)
22            && epoch > 0
23        {
24            actions.push(CheckpointingAction::Delete(epoch));
25        }
26
27        actions
28    }
29}
30
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metric::store::LogEventStore;

    #[test]
    fn should_always_delete_lastn_epoch_if_higher_than_one() {
        let mut strategy = KeepLastNCheckpoints::new(2);
        let store = EventStoreClient::new(LogEventStore::default());

        assert_eq!(
            vec![CheckpointingAction::Save],
            strategy.checkpointing(1, &store)
        );

        assert_eq!(
            vec![CheckpointingAction::Save],
            strategy.checkpointing(2, &store)
        );

        assert_eq!(
            vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)],
            strategy.checkpointing(3, &store)
        );
    }

    /// Once the window is full, every subsequent epoch must delete exactly the
    /// checkpoint that just fell out of it, not only the first one.
    #[test]
    fn should_keep_rolling_window_after_first_deletion() {
        let mut strategy = KeepLastNCheckpoints::new(2);
        let store = EventStoreClient::new(LogEventStore::default());

        let expected = [
            (1, vec![CheckpointingAction::Save]),
            (2, vec![CheckpointingAction::Save]),
            (3, vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)]),
            (4, vec![CheckpointingAction::Save, CheckpointingAction::Delete(2)]),
            (5, vec![CheckpointingAction::Save, CheckpointingAction::Delete(3)]),
        ];

        for (epoch, actions) in expected {
            assert_eq!(actions, strategy.checkpointing(epoch, &store), "epoch {epoch}");
        }
    }

    /// Keeping a single checkpoint deletes the previous epoch every time, starting at
    /// epoch 2 (epoch 0 is never requested for deletion).
    #[test]
    fn should_delete_previous_epoch_when_keeping_one() {
        let mut strategy = KeepLastNCheckpoints::new(1);
        let store = EventStoreClient::new(LogEventStore::default());

        assert_eq!(
            vec![CheckpointingAction::Save],
            strategy.checkpointing(1, &store)
        );

        assert_eq!(
            vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)],
            strategy.checkpointing(2, &store)
        );

        assert_eq!(
            vec![CheckpointingAction::Save, CheckpointingAction::Delete(2)],
            strategy.checkpointing(3, &store)
        );
    }
}