burn_train/checkpoint/strategy/
lastn.rs1use super::CheckpointingStrategy;
2use crate::{checkpoint::CheckpointingAction, metric::store::EventStoreClient};
3
4#[derive(new)]
9pub struct KeepLastNCheckpoints {
10 num_keep: usize,
11}
12
13impl CheckpointingStrategy for KeepLastNCheckpoints {
14 fn checkpointing(
15 &mut self,
16 epoch: usize,
17 _store: &EventStoreClient,
18 ) -> Vec<CheckpointingAction> {
19 let mut actions = vec![CheckpointingAction::Save];
20
21 if let Some(epoch) = usize::checked_sub(epoch, self.num_keep)
22 && epoch > 0
23 {
24 actions.push(CheckpointingAction::Delete(epoch));
25 }
26
27 actions
28 }
29}
30
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metric::store::LogEventStore;

    #[test]
    fn should_always_delete_lastn_epoch_if_higher_than_one() {
        let mut strategy = KeepLastNCheckpoints::new(2);
        let store = EventStoreClient::new(LogEventStore::default());

        assert_eq!(
            vec![CheckpointingAction::Save],
            strategy.checkpointing(1, &store)
        );

        assert_eq!(
            vec![CheckpointingAction::Save],
            strategy.checkpointing(2, &store)
        );

        assert_eq!(
            vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)],
            strategy.checkpointing(3, &store)
        );
    }

    /// The retention window must keep sliding: every epoch past `num_keep` evicts
    /// exactly the epoch `num_keep` positions behind it, not just the first one.
    #[test]
    fn should_slide_retention_window_every_epoch() {
        let num_keep = 2;
        let mut strategy = KeepLastNCheckpoints::new(num_keep);
        let store = EventStoreClient::new(LogEventStore::default());

        // Warm-up: no epoch has left the window yet, so only saves are requested.
        for epoch in 1..=num_keep {
            assert_eq!(
                vec![CheckpointingAction::Save],
                strategy.checkpointing(epoch, &store)
            );
        }

        // Steady state: each new epoch evicts `epoch - num_keep`.
        for epoch in (num_keep + 1)..=6 {
            assert_eq!(
                vec![
                    CheckpointingAction::Save,
                    CheckpointingAction::Delete(epoch - num_keep)
                ],
                strategy.checkpointing(epoch, &store)
            );
        }
    }

    /// Keeping a single checkpoint deletes the previous epoch as soon as one exists.
    #[test]
    fn should_delete_previous_epoch_when_keeping_one() {
        let mut strategy = KeepLastNCheckpoints::new(1);
        let store = EventStoreClient::new(LogEventStore::default());

        assert_eq!(
            vec![CheckpointingAction::Save],
            strategy.checkpointing(1, &store)
        );

        assert_eq!(
            vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)],
            strategy.checkpointing(2, &store)
        );

        assert_eq!(
            vec![CheckpointingAction::Save, CheckpointingAction::Delete(2)],
            strategy.checkpointing(3, &store)
        );
    }
}