Skip to main content

entrenar/train/transformer_trainer/
distributed_checkpoint.rs

//! Distributed checkpoint save/load coordination.
//!
//! In DDP training, all workers hold identical weights (C-DDP-001), so only
//! rank 0 writes the checkpoint; this avoids concurrent writes to the same files.
//! A barrier ensures every worker has synced its weights to the CPU before
//! rank 0 saves.
//!
//! # Protocol
//!
//! 1. The coordinator broadcasts a "save" command (a heartbeat with a special timestamp).
//! 2. All workers call `sync_weights_to_cpu()` on their CUDA trainers.
//! 3. All workers send a "ready" acknowledgement.
//! 4. Rank 0 writes the checkpoint (`model.safetensors` + `config.json`).
//! 5. Rank 0 broadcasts "save complete".
//! 6. All workers resume training.
15
16use std::path::Path;
17
/// Distributed checkpoint state machine.
///
/// A save is expected to move through the phases in declaration order:
/// `Training` → `SaveRequested` → `WeightsSynced` → `Writing` → `Complete`,
/// and then back to `Training` once training resumes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CheckpointPhase {
    /// Normal training, no checkpoint in progress
    Training,
    /// Coordinator has requested checkpoint save
    SaveRequested,
    /// Worker has synced weights to CPU and is ready
    WeightsSynced,
    /// Rank 0 is writing checkpoint
    Writing,
    /// Checkpoint complete, resume training
    Complete,
}
32
33/// Coordinator for distributed checkpoint saves.
34///
35/// Tracks which workers have acknowledged the save request
36/// and coordinates the barrier.
37pub struct DistributedCheckpointCoordinator {
38    /// Current phase
39    phase: CheckpointPhase,
40    /// Number of workers that have acknowledged
41    acks_received: usize,
42    /// Total number of workers expected
43    world_size: usize,
44    /// Step at which checkpoint was requested
45    checkpoint_step: usize,
46}
47
48impl DistributedCheckpointCoordinator {
49    /// Create a new coordinator.
50    pub fn new(world_size: usize) -> Self {
51        Self { phase: CheckpointPhase::Training, acks_received: 0, world_size, checkpoint_step: 0 }
52    }
53
54    /// Request a checkpoint save at the given step.
55    ///
56    /// Returns true if the request was accepted (no save already in progress).
57    pub fn request_save(&mut self, step: usize) -> bool {
58        if self.phase != CheckpointPhase::Training {
59            return false;
60        }
61        self.phase = CheckpointPhase::SaveRequested;
62        self.checkpoint_step = step;
63        self.acks_received = 0;
64        true
65    }
66
67    /// Record a worker's acknowledgement that weights are synced.
68    ///
69    /// Returns true if all workers have acknowledged (barrier complete).
70    pub fn worker_ready(&mut self) -> bool {
71        self.acks_received += 1;
72        if self.acks_received >= self.world_size {
73            self.phase = CheckpointPhase::WeightsSynced;
74            true
75        } else {
76            false
77        }
78    }
79
80    /// Mark that writing has started (only rank 0 calls this).
81    pub fn start_writing(&mut self) {
82        self.phase = CheckpointPhase::Writing;
83    }
84
85    /// Mark checkpoint as complete, resume training.
86    pub fn complete(&mut self) {
87        self.phase = CheckpointPhase::Complete;
88        // Will transition back to Training on next step
89    }
90
91    /// Reset to training phase (call after checkpoint is done).
92    pub fn resume_training(&mut self) {
93        self.phase = CheckpointPhase::Training;
94        self.acks_received = 0;
95    }
96
97    /// Get current phase.
98    pub fn phase(&self) -> CheckpointPhase {
99        self.phase
100    }
101
102    /// Get the step at which checkpoint was requested.
103    pub fn checkpoint_step(&self) -> usize {
104        self.checkpoint_step
105    }
106}
107
/// Check that every collected weight hash equals the local one.
///
/// `local_hash` is this worker's 32-byte digest and `all_hashes` holds the
/// digests gathered from the workers. The result is `true` only when no entry
/// of `all_hashes` differs from `local_hash`; an empty `all_hashes` therefore
/// yields `true`.
///
/// # Contract (C-DDP-001)
///
/// All workers must have identical weights. This function only reports the
/// comparison result; halting training on a mismatch (Jidoka) is left to the
/// caller.
pub fn verify_weight_consistency(local_hash: &[u8; 32], all_hashes: &[[u8; 32]]) -> bool {
    for peer_hash in all_hashes {
        // A single diverging digest is enough to fail the check.
        if peer_hash != local_hash {
            return false;
        }
    }
    true
}
120
/// Compute a 32-byte fingerprint of weight data for consistency verification.
///
/// NOTE: this is a placeholder and is **not** BLAKE3. The fingerprint comes
/// from the standard library's `DefaultHasher`, so it carries only 64 bits of
/// entropy: bytes `0..8` are the hasher output and bytes `8..32` are derived
/// from it by chained multiplications. It is suitable for spotting diverged
/// replicas, not as a cryptographic digest. In production, use the `blake3`
/// crate.
///
/// The hasher is fed a length prefix (the number of weight bytes) followed by
/// each weight in little-endian byte order. Weights are compared by bit
/// pattern, so `0.0` and `-0.0` produce different input bytes.
///
/// The weights are streamed through a small fixed-size stack buffer, so no
/// heap copy of the whole weight slice is made.
pub fn hash_weights(weights: &[f32]) -> [u8; 32] {
    use std::hash::Hasher;

    /// Number of weights serialized per `write` call.
    const CHUNK_WEIGHTS: usize = 1024;
    /// Multipliers used to derive the three trailing 8-byte lanes.
    const LANE_MULTIPLIERS: [u64; 3] =
        [0x517cc1b727220a95, 0x6c62272e07bb0142, 0x62b821756295c58d];

    let mut hasher = std::collections::hash_map::DefaultHasher::new();

    // Length prefix first, so inputs of different length never share a stream.
    hasher.write_usize(weights.len() * 4);

    let mut buffer = [0u8; CHUNK_WEIGHTS * 4];
    for chunk in weights.chunks(CHUNK_WEIGHTS) {
        for (slot, weight) in buffer.chunks_exact_mut(4).zip(chunk) {
            slot.copy_from_slice(&weight.to_le_bytes());
        }
        // Only the part of the buffer filled by this chunk is hashed.
        hasher.write(&buffer[..chunk.len() * 4]);
    }

    let mut lane = hasher.finish();
    let mut digest = [0u8; 32];
    digest[..8].copy_from_slice(&lane.to_le_bytes());
    // Fill the remaining lanes with values derived from the previous lane.
    for (lane_bytes, multiplier) in digest[8..].chunks_exact_mut(8).zip(LANE_MULTIPLIERS.iter()) {
        lane = lane.wrapping_mul(*multiplier);
        lane_bytes.copy_from_slice(&lane.to_le_bytes());
    }
    digest
}
150
/// Decide whether a checkpoint is due at `step`.
///
/// A `save_interval` of 0 disables checkpointing entirely, including the
/// final-step save. Otherwise a save is due when either
///
/// * `step` is a non-zero multiple of `save_interval`, or
/// * `max_steps` is set and `step` has reached (or passed) it.
pub fn should_save_checkpoint(step: usize, save_interval: usize, max_steps: Option<usize>) -> bool {
    if save_interval == 0 {
        return false;
    }

    // Regular cadence; step 0 is excluded even though it is a multiple of everything.
    let on_interval = step != 0 && step.is_multiple_of(save_interval);
    // The end of the run always gets a checkpoint.
    let reached_end = max_steps.is_some_and(|max| step >= max);

    on_interval || reached_end
}
174
/// Build the directory path used for the checkpoint taken at `step`.
///
/// The result is `output_dir` with a single `checkpoint-<step>` component
/// appended. Nothing is created on disk.
pub fn checkpoint_path(output_dir: &Path, step: usize) -> std::path::PathBuf {
    let dir_name = format!("checkpoint-{step}");
    let mut full_path = output_dir.to_path_buf();
    full_path.push(dir_name);
    full_path
}
179
// Unit tests for the coordinator state machine and the free helper functions
// defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks one save through every phase with three workers and checks the
    /// phase reported after each transition.
    #[test]
    fn test_checkpoint_coordinator_lifecycle() {
        let mut coord = DistributedCheckpointCoordinator::new(3);
        assert_eq!(coord.phase(), CheckpointPhase::Training);

        // Request save
        assert!(coord.request_save(100));
        assert_eq!(coord.phase(), CheckpointPhase::SaveRequested);
        assert_eq!(coord.checkpoint_step(), 100);

        // Can't request again while saving
        assert!(!coord.request_save(101));

        // Workers acknowledge
        assert!(!coord.worker_ready()); // 1 of 3
        assert!(!coord.worker_ready()); // 2 of 3
        assert!(coord.worker_ready()); // 3 of 3 — barrier complete
        assert_eq!(coord.phase(), CheckpointPhase::WeightsSynced);

        // Write
        coord.start_writing();
        assert_eq!(coord.phase(), CheckpointPhase::Writing);

        // Complete
        coord.complete();
        assert_eq!(coord.phase(), CheckpointPhase::Complete);

        // Resume
        coord.resume_training();
        assert_eq!(coord.phase(), CheckpointPhase::Training);
    }

    /// All collected hashes equal the local hash, so the check passes.
    #[test]
    fn test_verify_weight_consistency_identical() {
        let hash = [42u8; 32];
        let all = vec![[42u8; 32], [42u8; 32], [42u8; 32]];
        assert!(verify_weight_consistency(&hash, &all));
    }

    /// A single hash that differs in one byte makes the check fail.
    #[test]
    fn test_verify_weight_consistency_mismatch() {
        let hash = [42u8; 32];
        let mut bad = [42u8; 32];
        bad[0] = 99;
        let all = vec![[42u8; 32], bad, [42u8; 32]];
        assert!(!verify_weight_consistency(&hash, &all));
    }

    /// Hashing the same weights twice yields the same digest.
    #[test]
    fn test_hash_weights_deterministic() {
        let weights = vec![1.0f32, 2.0, 3.0];
        let h1 = hash_weights(&weights);
        let h2 = hash_weights(&weights);
        assert_eq!(h1, h2);
    }

    /// Weights that differ in one element yield different digests.
    #[test]
    fn test_hash_weights_different_inputs() {
        let a = hash_weights(&[1.0, 2.0, 3.0]);
        let b = hash_weights(&[1.0, 2.0, 4.0]);
        assert_ne!(a, b);
    }

    /// Covers the interval rule, the disabled (zero) interval and the
    /// final-step rule.
    #[test]
    fn test_should_save_checkpoint() {
        // Regular intervals
        assert!(!should_save_checkpoint(0, 25, Some(100)));
        assert!(should_save_checkpoint(25, 25, Some(100)));
        assert!(should_save_checkpoint(50, 25, Some(100)));
        assert!(should_save_checkpoint(100, 25, Some(100))); // final step

        // No interval
        assert!(!should_save_checkpoint(25, 0, Some(100)));

        // Final step
        assert!(should_save_checkpoint(100, 1000, Some(100)));
    }

    /// The checkpoint directory is `checkpoint-<step>` under the output dir.
    #[test]
    fn test_checkpoint_path() {
        let path = checkpoint_path(Path::new("/tmp/checkpoints"), 500);
        assert_eq!(path, std::path::PathBuf::from("/tmp/checkpoints/checkpoint-500"));
    }
}