entrenar 0.7.13

Training & Optimization library with autograd, LoRA, quantization, and model merging
Documentation
//! Distributed checkpoint save/load coordination.
//!
//! In DDP training, all workers hold identical weights (C-DDP-001).
//! Only rank 0 writes the checkpoint to avoid concurrent file writes.
//! A barrier ensures all workers sync weights to CPU before rank 0 saves.
//!
//! # Protocol
//!
//! 1. Coordinator broadcasts "save" command (heartbeat with special timestamp)
//! 2. All workers call `sync_weights_to_cpu()` on their CUDA trainers
//! 3. All workers send a "ready" acknowledgement
//! 4. Rank 0 writes checkpoint (model.safetensors + config.json)
//! 5. Rank 0 broadcasts "save complete"
//! 6. All workers resume training

use std::path::Path;

/// Distributed checkpoint state machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CheckpointPhase {
    /// Normal training, no checkpoint in progress
    Training,
    /// Coordinator has requested checkpoint save
    SaveRequested,
    /// Worker has synced weights to CPU and is ready
    WeightsSynced,
    /// Rank 0 is writing checkpoint
    Writing,
    /// Checkpoint complete, resume training
    Complete,
}

/// Coordinator for distributed checkpoint saves.
///
/// Tracks which workers have acknowledged the save request
/// and coordinates the barrier.
pub struct DistributedCheckpointCoordinator {
    /// Current phase
    phase: CheckpointPhase,
    /// Number of workers that have acknowledged
    acks_received: usize,
    /// Total number of workers expected
    world_size: usize,
    /// Step at which checkpoint was requested
    checkpoint_step: usize,
}

impl DistributedCheckpointCoordinator {
    /// Create a new coordinator.
    pub fn new(world_size: usize) -> Self {
        Self { phase: CheckpointPhase::Training, acks_received: 0, world_size, checkpoint_step: 0 }
    }

    /// Request a checkpoint save at the given step.
    ///
    /// Returns true if the request was accepted (no save already in progress).
    pub fn request_save(&mut self, step: usize) -> bool {
        if self.phase != CheckpointPhase::Training {
            return false;
        }
        self.phase = CheckpointPhase::SaveRequested;
        self.checkpoint_step = step;
        self.acks_received = 0;
        true
    }

    /// Record a worker's acknowledgement that weights are synced.
    ///
    /// Returns true if all workers have acknowledged (barrier complete).
    pub fn worker_ready(&mut self) -> bool {
        self.acks_received += 1;
        if self.acks_received >= self.world_size {
            self.phase = CheckpointPhase::WeightsSynced;
            true
        } else {
            false
        }
    }

    /// Mark that writing has started (only rank 0 calls this).
    pub fn start_writing(&mut self) {
        self.phase = CheckpointPhase::Writing;
    }

    /// Mark checkpoint as complete, resume training.
    pub fn complete(&mut self) {
        self.phase = CheckpointPhase::Complete;
        // Will transition back to Training on next step
    }

    /// Reset to training phase (call after checkpoint is done).
    pub fn resume_training(&mut self) {
        self.phase = CheckpointPhase::Training;
        self.acks_received = 0;
    }

    /// Get current phase.
    pub fn phase(&self) -> CheckpointPhase {
        self.phase
    }

    /// Get the step at which checkpoint was requested.
    pub fn checkpoint_step(&self) -> usize {
        self.checkpoint_step
    }
}

/// Verify checkpoint integrity across workers.
///
/// Each worker computes a BLAKE3 hash of their CPU model weights.
/// The coordinator collects hashes and verifies they're identical.
///
/// # Contract (C-DDP-001)
///
/// All workers must have identical weights. If hashes differ,
/// training is halted (Jidoka).
pub fn verify_weight_consistency(local_hash: &[u8; 32], all_hashes: &[[u8; 32]]) -> bool {
    all_hashes.iter().all(|h| h == local_hash)
}

/// Compute a BLAKE3 hash of weight data for consistency verification.
///
/// Uses the same BLAKE3 implementation as `apr train archive`.
pub fn hash_weights(weights: &[f32]) -> [u8; 32] {
    // Convert f32 to bytes and hash
    let byte_len = weights.len() * 4;
    let mut bytes = vec![0u8; byte_len];
    for (i, &w) in weights.iter().enumerate() {
        bytes[i * 4..(i + 1) * 4].copy_from_slice(&w.to_le_bytes());
    }

    // Simple BLAKE3-style hash (using built-in hasher as placeholder)
    // In production, use the blake3 crate
    use std::hash::{Hash, Hasher};
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    bytes.hash(&mut hasher);
    let hash = hasher.finish();

    let mut result = [0u8; 32];
    result[..8].copy_from_slice(&hash.to_le_bytes());
    // Fill remaining bytes with derived values for uniqueness
    let hash2 = hash.wrapping_mul(0x517cc1b727220a95);
    result[8..16].copy_from_slice(&hash2.to_le_bytes());
    let hash3 = hash2.wrapping_mul(0x6c62272e07bb0142);
    result[16..24].copy_from_slice(&hash3.to_le_bytes());
    let hash4 = hash3.wrapping_mul(0x62b821756295c58d);
    result[24..32].copy_from_slice(&hash4.to_le_bytes());
    result
}

/// Determine if a checkpoint should be saved at this step.
///
/// Follows the `save_interval` logic from the training config,
/// accounting for the distributed checkpoint overhead.
pub fn should_save_checkpoint(step: usize, save_interval: usize, max_steps: Option<usize>) -> bool {
    if save_interval == 0 {
        return false;
    }

    // Save at regular intervals
    if step > 0 && step.is_multiple_of(save_interval) {
        return true;
    }

    // Always save at the final step
    if let Some(max) = max_steps {
        if step >= max {
            return true;
        }
    }

    false
}

/// Get the checkpoint directory path for a given step.
pub fn checkpoint_path(output_dir: &Path, step: usize) -> std::path::PathBuf {
    output_dir.join(format!("checkpoint-{step}"))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_checkpoint_coordinator_lifecycle() {
        let mut coord = DistributedCheckpointCoordinator::new(3);
        assert_eq!(coord.phase(), CheckpointPhase::Training);

        // Request save
        assert!(coord.request_save(100));
        assert_eq!(coord.phase(), CheckpointPhase::SaveRequested);
        assert_eq!(coord.checkpoint_step(), 100);

        // Can't request again while saving
        assert!(!coord.request_save(101));

        // Workers acknowledge
        assert!(!coord.worker_ready()); // 1 of 3
        assert!(!coord.worker_ready()); // 2 of 3
        assert!(coord.worker_ready()); // 3 of 3 — barrier complete
        assert_eq!(coord.phase(), CheckpointPhase::WeightsSynced);

        // Write
        coord.start_writing();
        assert_eq!(coord.phase(), CheckpointPhase::Writing);

        // Complete
        coord.complete();
        assert_eq!(coord.phase(), CheckpointPhase::Complete);

        // Resume
        coord.resume_training();
        assert_eq!(coord.phase(), CheckpointPhase::Training);
    }

    #[test]
    fn test_verify_weight_consistency_identical() {
        let hash = [42u8; 32];
        let all = vec![[42u8; 32], [42u8; 32], [42u8; 32]];
        assert!(verify_weight_consistency(&hash, &all));
    }

    #[test]
    fn test_verify_weight_consistency_mismatch() {
        let hash = [42u8; 32];
        let mut bad = [42u8; 32];
        bad[0] = 99;
        let all = vec![[42u8; 32], bad, [42u8; 32]];
        assert!(!verify_weight_consistency(&hash, &all));
    }

    #[test]
    fn test_hash_weights_deterministic() {
        let weights = vec![1.0f32, 2.0, 3.0];
        let h1 = hash_weights(&weights);
        let h2 = hash_weights(&weights);
        assert_eq!(h1, h2);
    }

    #[test]
    fn test_hash_weights_different_inputs() {
        let a = hash_weights(&[1.0, 2.0, 3.0]);
        let b = hash_weights(&[1.0, 2.0, 4.0]);
        assert_ne!(a, b);
    }

    #[test]
    fn test_should_save_checkpoint() {
        // Regular intervals
        assert!(!should_save_checkpoint(0, 25, Some(100)));
        assert!(should_save_checkpoint(25, 25, Some(100)));
        assert!(should_save_checkpoint(50, 25, Some(100)));
        assert!(should_save_checkpoint(100, 25, Some(100))); // final step

        // No interval
        assert!(!should_save_checkpoint(25, 0, Some(100)));

        // Final step
        assert!(should_save_checkpoint(100, 1000, Some(100)));
    }

    #[test]
    fn test_checkpoint_path() {
        let path = checkpoint_path(Path::new("/tmp/checkpoints"), 500);
        assert_eq!(path, std::path::PathBuf::from("/tmp/checkpoints/checkpoint-500"));
    }
}