1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
use super::CheckpointingStrategy;
use crate::{checkpoint::CheckpointingAction, metric::store::EventStoreClient};

/// Checkpointing strategy that keeps only the last N checkpoints.
///
/// Very useful when training: it minimizes disk space while ensuring that the training can
/// still be resumed even if something goes wrong.
///
/// Every call to the strategy requests a save; once the epoch number exceeds `num_keep`, it
/// also requests deletion of the checkpoint written at epoch `epoch - num_keep`.
///
/// Instances are built with `KeepLastNCheckpoints::new(num_keep)`, generated by the `new`
/// derive below.
#[derive(new)]
pub struct KeepLastNCheckpoints {
    /// Number of most recent checkpoints to keep (the N in "last N").
    num_keep: usize,
}

impl CheckpointingStrategy for KeepLastNCheckpoints {
    /// Returns the checkpoint actions to perform for `epoch`.
    ///
    /// The result always starts with [`CheckpointingAction::Save`]. When `epoch - num_keep` is
    /// strictly positive, a [`CheckpointingAction::Delete`] for that older epoch follows, so
    /// that at most `num_keep` checkpoints remain. The event store is not consulted.
    ///
    /// Note that with `num_keep == 0` the returned actions include a `Delete` for `epoch`
    /// itself.
    fn checkpointing(
        &mut self,
        epoch: usize,
        _store: &EventStoreClient,
    ) -> Vec<CheckpointingAction> {
        let save = CheckpointingAction::Save;

        // `checked_sub` yields `None` while fewer than `num_keep` epochs have elapsed; the
        // guard additionally rejects epoch 0, which is never deleted.
        match epoch.checked_sub(self.num_keep) {
            Some(stale) if stale > 0 => vec![save, CheckpointingAction::Delete(stale)],
            _ => vec![save],
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::metric::store::LogEventStore;

    /// With `num_keep = 2`, the first two epochs only save; the third epoch also deletes the
    /// checkpoint from epoch 1.
    #[test]
    fn should_always_delete_lastn_epoch_if_higher_than_one() {
        let store = EventStoreClient::new(LogEventStore::default());
        let mut strategy = KeepLastNCheckpoints::new(2);

        // Small shorthand so each assertion reads as "expected actions at epoch N".
        let mut actions_at = |epoch: usize| strategy.checkpointing(epoch, &store);

        // Still inside the retention window: nothing to delete yet.
        assert_eq!(vec![CheckpointingAction::Save], actions_at(1));
        assert_eq!(vec![CheckpointingAction::Save], actions_at(2));

        // Epoch 1 has now fallen out of the window of the last two checkpoints.
        assert_eq!(
            vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)],
            actions_at(3)
        );
    }
}