burn_train/checkpoint/strategy/
composed.rs1use crate::metric::store::EventStoreClient;
2
3use super::{CheckpointingAction, CheckpointingStrategy};
4use std::collections::HashSet;
5
6pub struct ComposedCheckpointingStrategy {
9 strategies: Vec<Box<dyn CheckpointingStrategy>>,
10 deleted: Vec<HashSet<usize>>,
11}
12
13#[derive(Default)]
15pub struct ComposedCheckpointingStrategyBuilder {
16 strategies: Vec<Box<dyn CheckpointingStrategy>>,
17}
18
19impl ComposedCheckpointingStrategyBuilder {
20 #[allow(clippy::should_implement_trait)]
22 pub fn add<S>(mut self, strategy: S) -> Self
23 where
24 S: CheckpointingStrategy + 'static,
25 {
26 self.strategies.push(Box::new(strategy));
27 self
28 }
29
30 pub fn build(self) -> ComposedCheckpointingStrategy {
32 ComposedCheckpointingStrategy::new(self.strategies)
33 }
34}
35
36impl ComposedCheckpointingStrategy {
37 fn new(strategies: Vec<Box<dyn CheckpointingStrategy>>) -> Self {
38 Self {
39 deleted: strategies.iter().map(|_| HashSet::new()).collect(),
40 strategies,
41 }
42 }
43 pub fn builder() -> ComposedCheckpointingStrategyBuilder {
46 ComposedCheckpointingStrategyBuilder::default()
47 }
48}
49
50impl CheckpointingStrategy for ComposedCheckpointingStrategy {
51 fn checkpointing(
52 &mut self,
53 epoch: usize,
54 collector: &EventStoreClient,
55 ) -> Vec<CheckpointingAction> {
56 let mut saved = false;
57 let mut actions = Vec::new();
58 let mut epochs_to_check = Vec::new();
59
60 for (i, strategy) in self.strategies.iter_mut().enumerate() {
61 let actions = strategy.checkpointing(epoch, collector);
62 if actions.is_empty() {
65 self.deleted
66 .get_mut(i)
67 .expect("As many 'deleted' as 'strategies'.")
68 .insert(epoch);
69 }
70
71 for action in actions {
72 match action {
73 CheckpointingAction::Delete(epoch) => {
74 self.deleted
75 .get_mut(i)
76 .expect("As many 'deleted' as 'strategies'.")
77 .insert(epoch);
78 epochs_to_check.push(epoch);
79 }
80 CheckpointingAction::Save => saved = true,
81 }
82 }
83 }
84
85 if saved {
86 actions.push(CheckpointingAction::Save);
87 }
88
89 for epoch in epochs_to_check.into_iter() {
90 let mut num_true = 0;
91 for i in 0..self.strategies.len() {
92 if self
93 .deleted
94 .get(i)
95 .expect("Ad many 'deleted' as 'strategies'.")
96 .contains(&epoch)
97 {
98 num_true += 1;
99 }
100 }
101
102 if num_true == self.strategies.len() {
103 actions.push(CheckpointingAction::Delete(epoch));
104
105 for i in 0..self.strategies.len() {
106 self.deleted
107 .get_mut(i)
108 .expect("As many 'deleted' as 'strategies'.")
109 .remove(&epoch);
110 }
111 }
112 }
113
114 actions
115 }
116}
117
118#[cfg(test)]
119mod tests {
120 use super::*;
121 use crate::{checkpoint::KeepLastNCheckpoints, metric::store::LogEventStore};
122
123 #[test]
124 fn should_delete_when_both_deletes() {
125 let store = EventStoreClient::new(LogEventStore::default());
126 let mut strategy = ComposedCheckpointingStrategy::builder()
127 .add(KeepLastNCheckpoints::new(1))
128 .add(KeepLastNCheckpoints::new(2))
129 .build();
130
131 assert_eq!(
132 vec![CheckpointingAction::Save],
133 strategy.checkpointing(1, &store)
134 );
135
136 assert_eq!(
137 vec![CheckpointingAction::Save],
138 strategy.checkpointing(2, &store)
139 );
140
141 assert_eq!(
142 vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)],
143 strategy.checkpointing(3, &store)
144 );
145 }
146}