// entrenar/train/transformer_trainer/distributed_checkpoint.rs
use std::path::Path;
17
/// Phase of the checkpoint state machine driven by
/// [`DistributedCheckpointCoordinator`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CheckpointPhase {
    /// Initial phase; the only phase in which `request_save` is accepted.
    Training,
    /// A save request was accepted; worker acknowledgements are being counted.
    SaveRequested,
    /// At least `world_size` acknowledgements arrived for the pending request.
    WeightsSynced,
    /// Entered via `start_writing`.
    Writing,
    /// Entered via `complete`.
    Complete,
}

/// Coordinates a single in-flight checkpoint across `world_size` workers.
///
/// The intended cycle is
/// `Training -> SaveRequested -> WeightsSynced -> Writing -> Complete -> Training`.
pub struct DistributedCheckpointCoordinator {
    /// Current phase of the state machine.
    phase: CheckpointPhase,
    /// Acknowledgements counted for the pending save request.
    acks_received: usize,
    /// Number of acknowledgements required to reach `WeightsSynced`.
    world_size: usize,
    /// Step given to the most recently accepted `request_save` (0 before any).
    checkpoint_step: usize,
}

impl DistributedCheckpointCoordinator {
    /// Creates a coordinator in the `Training` phase that expects
    /// `world_size` acknowledgements per checkpoint.
    pub fn new(world_size: usize) -> Self {
        Self { phase: CheckpointPhase::Training, acks_received: 0, world_size, checkpoint_step: 0 }
    }

    /// Requests a checkpoint at `step`.
    ///
    /// Returns `true` and moves to `SaveRequested` only when the coordinator
    /// is currently in `Training`; otherwise returns `false` and changes
    /// nothing.
    pub fn request_save(&mut self, step: usize) -> bool {
        if self.phase != CheckpointPhase::Training {
            return false;
        }
        self.phase = CheckpointPhase::SaveRequested;
        self.checkpoint_step = step;
        self.acks_received = 0;
        true
    }

    /// Records one worker acknowledgement for the pending save request.
    ///
    /// Returns `true` on the acknowledgement that brings the count up to
    /// `world_size`, at which point the phase becomes `WeightsSynced`.
    ///
    /// Acknowledgements that arrive in any phase other than `SaveRequested`
    /// are ignored and yield `false`. Without this guard a stray ack could
    /// push the coordinator out of `Training` with no save requested, and a
    /// late or duplicate ack could roll `Writing`/`Complete` back to
    /// `WeightsSynced`.
    pub fn worker_ready(&mut self) -> bool {
        if self.phase != CheckpointPhase::SaveRequested {
            return false;
        }
        self.acks_received += 1;
        if self.acks_received >= self.world_size {
            self.phase = CheckpointPhase::WeightsSynced;
            true
        } else {
            false
        }
    }

    /// Moves to the `Writing` phase (the current phase is not checked).
    pub fn start_writing(&mut self) {
        self.phase = CheckpointPhase::Writing;
    }

    /// Moves to the `Complete` phase (the current phase is not checked).
    pub fn complete(&mut self) {
        self.phase = CheckpointPhase::Complete;
    }

    /// Returns to `Training` and clears the acknowledgement counter.
    ///
    /// `checkpoint_step` is left untouched, so it keeps reporting the last
    /// accepted request.
    pub fn resume_training(&mut self) {
        self.phase = CheckpointPhase::Training;
        self.acks_received = 0;
    }

    /// Current phase.
    pub fn phase(&self) -> CheckpointPhase {
        self.phase
    }

    /// Step of the most recently accepted save request.
    pub fn checkpoint_step(&self) -> usize {
        self.checkpoint_step
    }
}
107
/// Checks that every hash in `all_hashes` equals `local_hash`.
///
/// Returns `true` when no entry differs, which includes the case where
/// `all_hashes` is empty.
pub fn verify_weight_consistency(local_hash: &[u8; 32], all_hashes: &[[u8; 32]]) -> bool {
    let any_mismatch = all_hashes.iter().any(|peer| peer != local_hash);
    !any_mismatch
}
120
/// Produces a 32-byte fingerprint of `weights`.
///
/// Each weight is serialized as 4 little-endian bytes, the resulting byte
/// vector is fed to the standard library's `DefaultHasher`, and the 64-bit
/// result fills bytes `0..8`. The remaining three 8-byte lanes are obtained by
/// repeatedly multiplying the previous lane by a fixed constant (wrapping), so
/// all 32 bytes are a function of that single 64-bit hash.
///
/// Comparison is on the bit patterns of the floats, not on their numeric
/// values.
pub fn hash_weights(weights: &[f32]) -> [u8; 32] {
    use std::hash::{Hash, Hasher};

    // Multipliers applied in order to derive lanes 1..=3 from lane 0.
    const LANE_MULTIPLIERS: [u64; 3] =
        [0x517cc1b727220a95, 0x6c62272e07bb0142, 0x62b821756295c58d];

    // Materialize the bytes in a Vec so the value is hashed through the same
    // `Hash` impl as before.
    let raw: Vec<u8> = weights.iter().flat_map(|w| w.to_le_bytes()).collect();

    let mut state = std::collections::hash_map::DefaultHasher::new();
    raw.hash(&mut state);

    let mut lane = state.finish();
    let mut digest = [0u8; 32];
    digest[..8].copy_from_slice(&lane.to_le_bytes());

    for (slot, factor) in digest[8..].chunks_exact_mut(8).zip(LANE_MULTIPLIERS) {
        lane = lane.wrapping_mul(factor);
        slot.copy_from_slice(&lane.to_le_bytes());
    }

    digest
}
150
/// Decides whether a checkpoint should be written at `step`.
///
/// * A `save_interval` of 0 disables saving entirely.
/// * Otherwise a save is due at every positive step that is a multiple of
///   `save_interval`.
/// * A save is also due whenever `max_steps` is `Some(max)` and `step >= max`,
///   regardless of the interval.
pub fn should_save_checkpoint(step: usize, save_interval: usize, max_steps: Option<usize>) -> bool {
    if save_interval == 0 {
        return false;
    }

    let on_interval = step > 0 && step.is_multiple_of(save_interval);
    let reached_limit = max_steps.is_some_and(|limit| step >= limit);

    on_interval || reached_limit
}
174
/// Builds the location for the checkpoint taken at `step`:
/// `<output_dir>/checkpoint-<step>`.
pub fn checkpoint_path(output_dir: &Path, step: usize) -> std::path::PathBuf {
    let mut location = output_dir.to_path_buf();
    location.push(format!("checkpoint-{step}"));
    location
}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks the coordinator through one full checkpoint cycle with three workers.
    #[test]
    fn test_checkpoint_coordinator_lifecycle() {
        let mut coordinator = DistributedCheckpointCoordinator::new(3);
        assert_eq!(coordinator.phase(), CheckpointPhase::Training);

        // First request is accepted and records the step.
        assert!(coordinator.request_save(100));
        assert_eq!(coordinator.phase(), CheckpointPhase::SaveRequested);
        assert_eq!(coordinator.checkpoint_step(), 100);

        // A second request while one is pending is refused.
        assert!(!coordinator.request_save(101));

        // Only the third acknowledgement completes the sync.
        assert!(!coordinator.worker_ready());
        assert!(!coordinator.worker_ready());
        assert!(coordinator.worker_ready());
        assert_eq!(coordinator.phase(), CheckpointPhase::WeightsSynced);

        coordinator.start_writing();
        assert_eq!(coordinator.phase(), CheckpointPhase::Writing);

        coordinator.complete();
        assert_eq!(coordinator.phase(), CheckpointPhase::Complete);

        coordinator.resume_training();
        assert_eq!(coordinator.phase(), CheckpointPhase::Training);
    }

    /// Identical hashes on every worker are reported as consistent.
    #[test]
    fn test_verify_weight_consistency_identical() {
        let local = [42u8; 32];
        let gathered = [[42u8; 32]; 3];
        assert!(verify_weight_consistency(&local, &gathered));
    }

    /// A single differing byte on one worker is reported as inconsistent.
    #[test]
    fn test_verify_weight_consistency_mismatch() {
        let local = [42u8; 32];
        let mut diverged = [42u8; 32];
        diverged[0] = 99;
        let gathered = [[42u8; 32], diverged, [42u8; 32]];
        assert!(!verify_weight_consistency(&local, &gathered));
    }

    /// Hashing the same weights twice yields the same digest.
    #[test]
    fn test_hash_weights_deterministic() {
        let weights = [1.0f32, 2.0, 3.0];
        let first = hash_weights(&weights);
        let second = hash_weights(&weights);
        assert_eq!(first, second);
    }

    /// Changing one weight changes the digest.
    #[test]
    fn test_hash_weights_different_inputs() {
        let left = hash_weights(&[1.0, 2.0, 3.0]);
        let right = hash_weights(&[1.0, 2.0, 4.0]);
        assert_ne!(left, right);
    }

    /// Covers interval hits, the disabled interval, and the final-step rule.
    #[test]
    fn test_should_save_checkpoint() {
        // Step 0 never triggers an interval save.
        assert!(!should_save_checkpoint(0, 25, Some(100)));

        // Multiples of the interval trigger a save.
        assert!(should_save_checkpoint(25, 25, Some(100)));
        assert!(should_save_checkpoint(50, 25, Some(100)));
        assert!(should_save_checkpoint(100, 25, Some(100)));

        // Interval 0 disables saving.
        assert!(!should_save_checkpoint(25, 0, Some(100)));

        // Reaching max_steps triggers a save even off-interval.
        assert!(should_save_checkpoint(100, 1000, Some(100)));
    }

    /// The checkpoint directory name embeds the step number.
    #[test]
    fn test_checkpoint_path() {
        let built = checkpoint_path(Path::new("/tmp/checkpoints"), 500);
        let expected = std::path::PathBuf::from("/tmp/checkpoints/checkpoint-500");
        assert_eq!(built, expected);
    }
}