// burn_train/checkpoint/strategy/composed.rs

use crate::metric::store::EventStoreClient;

use super::{CheckpointingAction, CheckpointingStrategy};
use std::collections::HashSet;
6/// Compose multiple checkpointing strategy and only delete checkpoints when both strategy flag an
7/// epoch to be deleted.
8pub struct ComposedCheckpointingStrategy {
9    strategies: Vec<Box<dyn CheckpointingStrategy>>,
10    deleted: Vec<HashSet<usize>>,
11}
13/// Help building a [checkpointing strategy](CheckpointingStrategy) by combining multiple ones.
14#[derive(Default)]
15pub struct ComposedCheckpointingStrategyBuilder {
16    strategies: Vec<Box<dyn CheckpointingStrategy>>,
17}
19impl ComposedCheckpointingStrategyBuilder {
20    /// Add a new [checkpointing strategy](CheckpointingStrategy).
21    #[allow(clippy::should_implement_trait)]
22    pub fn add<S>(mut self, strategy: S) -> Self
23    where
24        S: CheckpointingStrategy + 'static,
25    {
26        self.strategies.push(Box::new(strategy));
27        self
28    }
29
30    /// Create a new [composed checkpointing strategy](ComposedCheckpointingStrategy).
31    pub fn build(self) -> ComposedCheckpointingStrategy {
32        ComposedCheckpointingStrategy::new(self.strategies)
33    }
34}
36impl ComposedCheckpointingStrategy {
37    fn new(strategies: Vec<Box<dyn CheckpointingStrategy>>) -> Self {
38        Self {
39            deleted: strategies.iter().map(|_| HashSet::new()).collect(),
40            strategies,
41        }
42    }
43    /// Create a new builder which help compose multiple
44    /// [checkpointing strategies](CheckpointingStrategy).
45    pub fn builder() -> ComposedCheckpointingStrategyBuilder {
46        ComposedCheckpointingStrategyBuilder::default()
47    }
48}
50impl CheckpointingStrategy for ComposedCheckpointingStrategy {
51    fn checkpointing(
52        &mut self,
53        epoch: usize,
54        collector: &EventStoreClient,
55    ) -> Vec<CheckpointingAction> {
56        let mut saved = false;
57        let mut actions = Vec::new();
58        let mut epochs_to_check = Vec::new();
59
60        for (i, strategy) in self.strategies.iter_mut().enumerate() {
61            let actions = strategy.checkpointing(epoch, collector);
62            // We assume that the strategy would not want the current epoch to be saved.
63            // So we flag it as deleted.
64            if actions.is_empty() {
65                self.deleted
66                    .get_mut(i)
67                    .expect("As many 'deleted' as 'strategies'.")
68                    .insert(epoch);
69            }
70
71            for action in actions {
72                match action {
73                    CheckpointingAction::Delete(epoch) => {
74                        self.deleted
75                            .get_mut(i)
76                            .expect("As many 'deleted' as 'strategies'.")
77                            .insert(epoch);
78                        epochs_to_check.push(epoch);
79                    }
80                    CheckpointingAction::Save => saved = true,
81                }
82            }
83        }
84
85        if saved {
86            actions.push(CheckpointingAction::Save);
87        }
88
89        for epoch in epochs_to_check.into_iter() {
90            let mut num_true = 0;
91            for i in 0..self.strategies.len() {
92                if self
93                    .deleted
94                    .get(i)
95                    .expect("Ad many 'deleted' as 'strategies'.")
96                    .contains(&epoch)
97                {
98                    num_true += 1;
99                }
100            }
101
102            if num_true == self.strategies.len() {
103                actions.push(CheckpointingAction::Delete(epoch));
104
105                for i in 0..self.strategies.len() {
106                    self.deleted
107                        .get_mut(i)
108                        .expect("As many 'deleted' as 'strategies'.")
109                        .remove(&epoch);
110                }
111            }
112        }
113
114        actions
115    }
116}
/// Unit tests for [`ComposedCheckpointingStrategy`].
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{checkpoint::KeepLastNCheckpoints, metric::store::LogEventStore};

    /// A checkpoint is only deleted once both composed strategies have asked for its deletion.
    #[test]
    fn should_delete_when_both_deletes() {
        let store = EventStoreClient::new(LogEventStore::default());
        let mut strategy = ComposedCheckpointingStrategy::builder()
            .add(KeepLastNCheckpoints::new(1))
            .add(KeepLastNCheckpoints::new(2))
            .build();

        // Epoch 1: only a save is expected.
        assert_eq!(
            vec![CheckpointingAction::Save],
            strategy.checkpointing(1, &store)
        );

        // Epoch 2: still only a save, so the composition has not agreed to delete epoch 1.
        assert_eq!(
            vec![CheckpointingAction::Save],
            strategy.checkpointing(2, &store)
        );

        // Epoch 3: epoch 1 is now deleted alongside the save.
        assert_eq!(
            vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)],
            strategy.checkpointing(3, &store)
        );
    }
}