flodl 0.5.2

floDl — a flow-graph deep learning framework built on libtorch
Documentation
//! Unified convergence guard for DDP parameter averaging.
//!
//! Monitors weight-space divergence across replicas and recommends cadence
//! adjustments. Replaces the old dual system (loss-based for NCCL, norm-based
//! for CPU) with a single guard that both paths feed into.
//!
//! The primary signal is `||params_before_sync - params_after_sync|| / ||params_after_sync||`,
//! which measures each replica's L2 distance from the consensus (average).
//! Three norms (pre, post, diff) give a full decomposition into magnitude
//! shift and directional coherence for free.

use std::collections::VecDeque;

use super::ApplyPolicy;

// ---------------------------------------------------------------------------
// Divergence report
// ---------------------------------------------------------------------------

/// Divergence measurements from a single sync event.
///
/// The primary signal is `deltas` -- per-rank normalized L2 distance from the
/// replica's pre-sync params to the post-sync average. This captures both
/// magnitude drift and directional divergence.
///
/// When `pre_norms` and `post_norm` are available, the decomposition into
/// magnitude shift and cosine similarity is free via:
/// `cos = (||pre||^2 + ||post||^2 - (delta * ||post||)^2) / (2 * ||pre|| * ||post||)`
pub struct DivergenceReport {
    /// Per-rank normalized delta: `||pre - post|| / ||post||`.
    pub deltas: Vec<f64>,
    /// Per-rank pre-sync param L2 norm. `None` when not computed (NCCL v1).
    pub pre_norms: Option<Vec<f64>>,
    /// Post-sync param L2 norm (same across ranks after AllReduce).
    /// `None` when not computed (NCCL v1).
    pub post_norm: Option<f64>,
}

impl DivergenceReport {
    /// Max normalized delta (`||pre-post|| / ||post||`) across ranks.
    /// This is the primary cadence signal: worst-case replica drift.
    pub fn max_relative_delta(&self) -> f64 {
        self.deltas.iter().copied().fold(0.0_f64, f64::max)
    }

    /// Per-rank cosine similarity between pre-sync and post-sync params.
    ///
    /// Computed via the identity:
    /// `cos = (||pre||^2 + ||post||^2 - ||pre-post||^2) / (2 * ||pre|| * ||post||)`
    ///
    /// Returns `None` if `pre_norms` or `post_norm` are not available.
    /// Values near 1.0 = directionally aligned, near 0.0 = orthogonal.
    pub fn cosine_similarities(&self) -> Option<Vec<f64>> {
        let pre_norms = self.pre_norms.as_ref()?;
        let post_norm = self.post_norm?;
        if post_norm < 1e-10 {
            return None;
        }
        Some(
            self.deltas
                .iter()
                .zip(pre_norms)
                .map(|(&delta, &pre_norm)| {
                    if pre_norm < 1e-10 {
                        return 0.0;
                    }
                    // ||pre - post||^2 = (delta * ||post||)^2 since delta is normalized
                    let diff_sq = (delta * post_norm).powi(2);
                    let pre_sq = pre_norm.powi(2);
                    let post_sq = post_norm.powi(2);
                    ((pre_sq + post_sq - diff_sq) / (2.0 * pre_norm * post_norm)).clamp(-1.0, 1.0)
                })
                .collect(),
        )
    }

    /// Per-rank magnitude shift: `||pre|| - ||post||`.
    ///
    /// Positive = replica had larger weights than consensus (possible overfitting).
    /// Negative = replica had smaller weights.
    /// Returns `None` if `pre_norms` or `post_norm` are not available.
    pub fn magnitude_shifts(&self) -> Option<Vec<f64>> {
        let pre_norms = self.pre_norms.as_ref()?;
        let post_norm = self.post_norm?;
        Some(pre_norms.iter().map(|&pre| pre - post_norm).collect())
    }
}

// ---------------------------------------------------------------------------
// Convergence action
// ---------------------------------------------------------------------------

/// Recommended cadence adjustment from the convergence guard.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ConvergenceAction {
    /// No divergence concern. Cadence can grow normally.
    Stable,
    /// Divergence trending up. Suppress anchor/overshoot growth (hold current values).
    SuppressGrowth,
    /// High sustained divergence. Nudge anchor down by `factor` (0.0-1.0).
    /// Factor 0.5 = halve the anchor. Same as the existing CPU path behavior.
    NudgeDown { factor: f64 },
}

// ---------------------------------------------------------------------------
// Convergence guard
// ---------------------------------------------------------------------------

/// Unified convergence guard for both NCCL and CPU averaging paths.
///
/// Owns the divergence ring buffer and mode-specific reaction logic.
/// The coordinator feeds [`DivergenceReport`]s after each sync event;
/// the guard returns a [`ConvergenceAction`] that the coordinator applies
/// to ElChe's anchor and overshoot controls.
///
/// The ring buffer persists across sync intervals (required for trend
/// detection). Per-interval accumulators (`nccl_sync_divergence`) live
/// in the coordinator and are reset there.
pub struct ConvergenceGuard {
    policy: ApplyPolicy,
    enabled: bool,
    threshold: f64,
    /// Ring buffer of `max_relative_delta` from recent sync intervals (up to 5).
    history: VecDeque<f64>,
}


impl ConvergenceGuard {
    /// Create a new guard for the given training policy.
    ///
    /// `threshold` controls sensitivity. Set high (e.g. 1.0) for log-only mode
    /// during calibration. A value of ~2x steady-state divergence is typical.
    pub fn new(policy: ApplyPolicy, enabled: bool, threshold: f64) -> Self {
        Self {
            policy,
            enabled,
            threshold,
            history: VecDeque::with_capacity(6),
        }
    }

    /// Feed a divergence report and get a cadence action.
    ///
    /// Mode-specific logic:
    /// - `Sync`: always `Stable` (sync every batch, divergence near-zero by construction).
    ///   In practice, the worker skips divergence computation entirely for Sync mode.
    /// - `Cadence`: trend detection (3 consecutive rising values above threshold).
    /// - `Async`: same as Cadence for v1 (spike detection deferred to v2).
    pub fn report(&mut self, report: &DivergenceReport) -> ConvergenceAction {
        if !self.enabled {
            return ConvergenceAction::Stable;
        }

        let divergence = report.max_relative_delta();

        // Push to ring buffer (keep last 5 intervals).
        if self.history.len() >= 5 {
            self.history.pop_front();
        }
        self.history.push_back(divergence);

        match self.policy {
            ApplyPolicy::Sync => ConvergenceAction::Stable,
            ApplyPolicy::Cadence | ApplyPolicy::Async => self.check_trend(),
        }
    }

    /// Trend detection: 3 consecutive rising values, latest above threshold.
    ///
    /// Filters out single spikes and noise. Only sustained drift suppresses
    /// growth. Returns `SuppressGrowth` (not `NudgeDown`) because aggressive
    /// anchor reduction on trends alone can kill convergence. The overhead
    /// auto-tune in ElChe handles loosening.
    fn check_trend(&self) -> ConvergenceAction {
        if self.history.len() < 3 {
            return ConvergenceAction::Stable;
        }

        let len = self.history.len();
        let rising = self.history[len - 1] > self.history[len - 2]
            && self.history[len - 2] > self.history[len - 3]
            && self.history[len - 1] > self.threshold;

        if rising {
            crate::verbose!(
                "  ddp: weight-space divergence trending up | history={:.4?} | suppressing growth",
                Vec::from(self.history.clone()),
            );
            ConvergenceAction::SuppressGrowth
        } else {
            ConvergenceAction::Stable
        }
    }

    /// Direct access to divergence history for testing and monitoring.
    pub fn history(&self) -> &VecDeque<f64> {
        &self.history
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    fn make_report(deltas: &[f64]) -> DivergenceReport {
        DivergenceReport {
            deltas: deltas.to_vec(),
            pre_norms: None,
            post_norm: None,
        }
    }

    fn make_full_report(
        deltas: &[f64],
        pre_norms: &[f64],
        post_norm: f64,
    ) -> DivergenceReport {
        DivergenceReport {
            deltas: deltas.to_vec(),
            pre_norms: Some(pre_norms.to_vec()),
            post_norm: Some(post_norm),
        }
    }

    // --- DivergenceReport ---

    #[test]
    fn max_relative_delta_picks_worst_rank() {
        let r = make_report(&[0.01, 0.05, 0.03]);
        assert!((r.max_relative_delta() - 0.05).abs() < 1e-10);
    }

    #[test]
    fn max_relative_delta_empty_is_zero() {
        let r = make_report(&[]);
        assert_eq!(r.max_relative_delta(), 0.0);
    }

    #[test]
    fn cosine_similarities_none_when_missing_norms() {
        let r = make_report(&[0.01, 0.02]);
        assert!(r.cosine_similarities().is_none());
        assert!(r.magnitude_shifts().is_none());
    }

    #[test]
    fn cosine_similarities_correct_for_small_delta() {
        // Small delta -> pre and post nearly identical -> cos near 1.0
        let r = make_full_report(&[0.001], &[10.0], 10.0);
        let cos = r.cosine_similarities().unwrap();
        assert!(cos[0] > 0.999, "expected cos near 1.0, got {}", cos[0]);
    }

    #[test]
    fn magnitude_shifts_correct() {
        let r = make_full_report(&[0.01, 0.02], &[10.5, 9.8], 10.0);
        let shifts = r.magnitude_shifts().unwrap();
        assert!((shifts[0] - 0.5).abs() < 1e-10);
        assert!((shifts[1] - (-0.2)).abs() < 1e-10);
    }

    // --- ConvergenceGuard ---

    #[test]
    fn sync_always_stable() {
        let mut g = ConvergenceGuard::new(ApplyPolicy::Sync, true, 0.01);
        for _ in 0..10 {
            let action = g.report(&make_report(&[0.5, 0.5]));
            assert_eq!(action, ConvergenceAction::Stable);
        }
    }

    #[test]
    fn cadence_trend_suppresses_growth() {
        let mut g = ConvergenceGuard::new(ApplyPolicy::Cadence, true, 0.01);
        // 3 rising values above threshold
        assert_eq!(g.report(&make_report(&[0.02])), ConvergenceAction::Stable); // only 1
        assert_eq!(g.report(&make_report(&[0.03])), ConvergenceAction::Stable); // only 2
        assert_eq!(
            g.report(&make_report(&[0.04])),
            ConvergenceAction::SuppressGrowth
        ); // 3 rising
    }

    #[test]
    fn cadence_non_rising_is_stable() {
        let mut g = ConvergenceGuard::new(ApplyPolicy::Cadence, true, 0.01);
        g.report(&make_report(&[0.05]));
        g.report(&make_report(&[0.04])); // dropped
        assert_eq!(
            g.report(&make_report(&[0.06])),
            ConvergenceAction::Stable
        ); // not 3 consecutive rising
    }

    #[test]
    fn below_threshold_is_stable() {
        let mut g = ConvergenceGuard::new(ApplyPolicy::Cadence, true, 0.10);
        // Rising but below threshold
        g.report(&make_report(&[0.01]));
        g.report(&make_report(&[0.02]));
        assert_eq!(
            g.report(&make_report(&[0.03])),
            ConvergenceAction::Stable
        );
    }

    #[test]
    fn disabled_always_stable() {
        let mut g = ConvergenceGuard::new(ApplyPolicy::Cadence, false, 0.01);
        g.report(&make_report(&[0.1]));
        g.report(&make_report(&[0.2]));
        assert_eq!(
            g.report(&make_report(&[0.3])),
            ConvergenceAction::Stable
        );
    }

    #[test]
    fn history_capped_at_5() {
        let mut g = ConvergenceGuard::new(ApplyPolicy::Cadence, true, 0.01);
        for i in 0..10 {
            g.report(&make_report(&[i as f64 * 0.01]));
        }
        assert_eq!(g.history().len(), 5);
    }

    #[test]
    fn async_same_as_cadence_v1() {
        let mut gc = ConvergenceGuard::new(ApplyPolicy::Cadence, true, 0.01);
        let mut ga = ConvergenceGuard::new(ApplyPolicy::Async, true, 0.01);
        let reports: Vec<DivergenceReport> =
            vec![make_report(&[0.02]), make_report(&[0.03]), make_report(&[0.04])];
        for r in &reports {
            let ac = gc.report(r);
            let aa = ga.report(r);
            assert_eq!(ac, aa);
        }
    }
}