use std::path::Path;
/// Stages of a distributed checkpoint save, listed in the order the
/// coordinator in this file moves through them.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CheckpointPhase {
    /// No save in progress. This is the phase a new coordinator starts in and
    /// the one `resume_training` returns to.
    Training,
    /// A save has been requested (set by `request_save`).
    SaveRequested,
    /// Enough worker acknowledgements have been counted (set by `worker_ready`).
    WeightsSynced,
    /// The checkpoint is being written (set by `start_writing`).
    Writing,
    /// The checkpoint save has finished (set by `complete`).
    Complete,
}
/// Tracks the progress of one checkpoint save across `world_size` workers.
///
/// The coordinator is a small state machine over [`CheckpointPhase`]; the
/// intended order is
/// `Training -> SaveRequested -> WeightsSynced -> Writing -> Complete -> Training`.
///
/// Only [`request_save`](Self::request_save) and
/// [`worker_ready`](Self::worker_ready) check the current phase. The remaining
/// transitions (`start_writing`, `complete`, `resume_training`) are applied
/// unconditionally, so callers are responsible for invoking them in order.
pub struct DistributedCheckpointCoordinator {
    /// Current position in the checkpoint lifecycle.
    phase: CheckpointPhase,
    /// Worker acknowledgements counted for the pending save request.
    acks_received: usize,
    /// Number of acknowledgements required before the weights count as synced.
    world_size: usize,
    /// Step recorded by the most recent accepted save request (0 until then).
    checkpoint_step: usize,
}

impl DistributedCheckpointCoordinator {
    /// Creates a coordinator in the `Training` phase that expects `world_size`
    /// acknowledgements per save.
    pub fn new(world_size: usize) -> Self {
        Self {
            phase: CheckpointPhase::Training,
            acks_received: 0,
            world_size,
            checkpoint_step: 0,
        }
    }

    /// Starts a checkpoint save for `step`.
    ///
    /// Returns `true` and moves to `SaveRequested` when the coordinator is in
    /// `Training`. In any other phase the request is rejected: `false` is
    /// returned and no state changes.
    pub fn request_save(&mut self, step: usize) -> bool {
        if self.phase != CheckpointPhase::Training {
            return false;
        }
        self.phase = CheckpointPhase::SaveRequested;
        self.checkpoint_step = step;
        self.acks_received = 0;
        true
    }

    /// Records one worker acknowledgement for the pending save request.
    ///
    /// Returns `true` exactly when this acknowledgement brings the count up to
    /// `world_size`, at which point the phase becomes `WeightsSynced`.
    ///
    /// Acknowledgements that arrive in any phase other than `SaveRequested`
    /// are ignored and return `false`.
    pub fn worker_ready(&mut self) -> bool {
        // Without this guard a stray ack during `Training` could push the
        // coordinator into `WeightsSynced` (after which `request_save` is
        // rejected), and a late or duplicate ack during `Writing`/`Complete`
        // would move the phase backwards to `WeightsSynced`.
        if self.phase != CheckpointPhase::SaveRequested {
            return false;
        }
        self.acks_received += 1;
        if self.acks_received >= self.world_size {
            self.phase = CheckpointPhase::WeightsSynced;
            true
        } else {
            false
        }
    }

    /// Sets the phase to `Writing`. The current phase is not checked.
    pub fn start_writing(&mut self) {
        self.phase = CheckpointPhase::Writing;
    }

    /// Sets the phase to `Complete`. The current phase is not checked.
    pub fn complete(&mut self) {
        self.phase = CheckpointPhase::Complete;
    }

    /// Returns to `Training` and clears the acknowledgement count.
    ///
    /// The current phase is not checked, so this also resets a save that has
    /// not finished. The recorded checkpoint step is left untouched.
    pub fn resume_training(&mut self) {
        self.phase = CheckpointPhase::Training;
        self.acks_received = 0;
    }

    /// Current phase of the checkpoint lifecycle.
    pub fn phase(&self) -> CheckpointPhase {
        self.phase
    }

    /// Step passed to the most recent accepted `request_save` call, or 0 if
    /// no request has been accepted yet.
    pub fn checkpoint_step(&self) -> usize {
        self.checkpoint_step
    }
}
/// Reports whether every hash in `all_hashes` is byte-for-byte equal to
/// `local_hash`.
///
/// Returns `false` as soon as one differing entry is found. An empty
/// `all_hashes` slice yields `true`, because no entry disagrees.
pub fn verify_weight_consistency(local_hash: &[u8; 32], all_hashes: &[[u8; 32]]) -> bool {
    let differs_from_local = |candidate: &[u8; 32]| candidate != local_hash;
    !all_hashes.iter().any(differs_from_local)
}
/// Produces a 32-byte fingerprint of `weights`.
///
/// Each weight is serialized as 4 little-endian bytes, and the resulting byte
/// buffer is hashed with the standard library's `DefaultHasher`. The 64-bit
/// hash fills bytes `0..8` of the result; bytes `8..16`, `16..24` and `24..32`
/// are each the previous 8-byte lane multiplied (wrapping) by a fixed constant.
///
/// The whole output is therefore a function of a single 64-bit hash value.
pub fn hash_weights(weights: &[f32]) -> [u8; 32] {
    use std::hash::{Hash, Hasher};

    // Per-lane wrapping multipliers. The leading 1 leaves the first lane equal
    // to the raw 64-bit hash; each later lane is derived from the one before.
    const LANE_FACTORS: [u64; 4] = [
        1,
        0x517cc1b727220a95,
        0x6c62272e07bb0142,
        0x62b821756295c58d,
    ];

    // Flatten the weights into one contiguous little-endian byte buffer.
    let mut raw = Vec::with_capacity(weights.len() * 4);
    for value in weights {
        raw.extend_from_slice(&value.to_le_bytes());
    }

    let mut state = std::collections::hash_map::DefaultHasher::new();
    raw.hash(&mut state);

    let mut digest = [0u8; 32];
    let mut lane_value = state.finish();
    for (lane, factor) in digest.chunks_exact_mut(8).zip(LANE_FACTORS.iter()) {
        lane_value = lane_value.wrapping_mul(*factor);
        lane.copy_from_slice(&lane_value.to_le_bytes());
    }
    digest
}
/// Decides whether a checkpoint should be written at `step`.
///
/// Returns `true` when either condition holds:
/// - `step` is non-zero and a multiple of `save_interval`;
/// - `max_steps` is `Some(max)` and `step >= max`.
///
/// A `save_interval` of 0 always yields `false`, regardless of `max_steps`.
pub fn should_save_checkpoint(step: usize, save_interval: usize, max_steps: Option<usize>) -> bool {
    // Checked first: a zero interval short-circuits before the `max_steps` test.
    if save_interval == 0 {
        return false;
    }

    let on_interval = step > 0 && step.is_multiple_of(save_interval);
    let reached_max = max_steps.is_some_and(|max| step >= max);
    on_interval || reached_max
}
/// Builds the location of the checkpoint for `step`: the entry named
/// `checkpoint-<step>` inside `output_dir`.
///
/// Only the path is computed; nothing is created on disk.
pub fn checkpoint_path(output_dir: &Path, step: usize) -> std::path::PathBuf {
    let entry_name = format!("checkpoint-{step}");
    let mut full_path = output_dir.to_path_buf();
    full_path.push(entry_name);
    full_path
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks one coordinator through a full save cycle with three workers.
    #[test]
    fn test_checkpoint_coordinator_lifecycle() {
        let mut coordinator = DistributedCheckpointCoordinator::new(3);
        assert_eq!(coordinator.phase(), CheckpointPhase::Training);

        assert!(coordinator.request_save(100));
        assert_eq!(coordinator.phase(), CheckpointPhase::SaveRequested);
        assert_eq!(coordinator.checkpoint_step(), 100);

        // A second request while one is pending is rejected.
        assert!(!coordinator.request_save(101));

        // Only the third acknowledgement completes the sync.
        assert!(!coordinator.worker_ready());
        assert!(!coordinator.worker_ready());
        assert!(coordinator.worker_ready());
        assert_eq!(coordinator.phase(), CheckpointPhase::WeightsSynced);

        coordinator.start_writing();
        assert_eq!(coordinator.phase(), CheckpointPhase::Writing);

        coordinator.complete();
        assert_eq!(coordinator.phase(), CheckpointPhase::Complete);

        coordinator.resume_training();
        assert_eq!(coordinator.phase(), CheckpointPhase::Training);
    }

    /// Identical hashes on every rank are reported as consistent.
    #[test]
    fn test_verify_weight_consistency_identical() {
        let local = [42u8; 32];
        let gathered = vec![[42u8; 32], [42u8; 32], [42u8; 32]];
        assert!(verify_weight_consistency(&local, &gathered));
    }

    /// A single differing hash is enough to fail the check.
    #[test]
    fn test_verify_weight_consistency_mismatch() {
        let local = [42u8; 32];
        let mut diverged = [42u8; 32];
        diverged[0] = 99;
        let gathered = vec![[42u8; 32], diverged, [42u8; 32]];
        assert!(!verify_weight_consistency(&local, &gathered));
    }

    /// Hashing the same weights twice gives the same digest.
    #[test]
    fn test_hash_weights_deterministic() {
        let weights = vec![1.0f32, 2.0, 3.0];
        let first = hash_weights(&weights);
        let second = hash_weights(&weights);
        assert_eq!(first, second);
    }

    /// Weights that differ in one element give different digests.
    #[test]
    fn test_hash_weights_different_inputs() {
        let left = hash_weights(&[1.0, 2.0, 3.0]);
        let right = hash_weights(&[1.0, 2.0, 4.0]);
        assert_ne!(left, right);
    }

    /// Covers the interval rule, the final-step rule and the zero interval.
    #[test]
    fn test_should_save_checkpoint() {
        // Step 0 is never saved on the interval rule.
        assert!(!should_save_checkpoint(0, 25, Some(100)));
        // Multiples of the interval are saved.
        assert!(should_save_checkpoint(25, 25, Some(100)));
        assert!(should_save_checkpoint(50, 25, Some(100)));
        assert!(should_save_checkpoint(100, 25, Some(100)));
        // A zero interval disables saving.
        assert!(!should_save_checkpoint(25, 0, Some(100)));
        // The final step is saved even when it is off-interval.
        assert!(should_save_checkpoint(100, 1000, Some(100)));
    }

    /// The checkpoint directory name is `checkpoint-<step>` under the output dir.
    #[test]
    fn test_checkpoint_path() {
        let actual = checkpoint_path(Path::new("/tmp/checkpoints"), 500);
        let expected = std::path::PathBuf::from("/tmp/checkpoints/checkpoint-500");
        assert_eq!(actual, expected);
    }
}