use std::collections::VecDeque;
use super::ApplyPolicy;
/// Per-rank weight-space divergence, plus optional norm data for derived views.
///
/// `deltas` is always present. `pre_norms` and `post_norm` are optional; both
/// must be `Some` for [`DivergenceReport::cosine_similarities`] and
/// [`DivergenceReport::magnitude_shifts`] to return anything.
pub struct DivergenceReport {
    /// Relative divergence per rank. `cosine_similarities` interprets each entry
    /// as `|pre - post| / |post|`, i.e. a distance scaled by `post_norm`.
    pub deltas: Vec<f64>,
    /// Per-rank "pre" norms. Presumably aligned index-for-index with `deltas`
    /// (TODO confirm against the producer); mismatched lengths are not rejected.
    pub pre_norms: Option<Vec<f64>>,
    /// Single "post" norm that every entry of `pre_norms` is compared against.
    pub post_norm: Option<f64>,
}

impl DivergenceReport {
    /// Norms strictly below this are treated as zero to avoid dividing by ~0.
    const NORM_EPSILON: f64 = 1e-10;

    /// Largest per-rank relative delta; `0.0` when `deltas` is empty.
    ///
    /// The fold starts at `0.0`, so negative entries never win. If any entry is
    /// NaN the result is NaN: `f64::max` discards NaN operands, which would
    /// otherwise let a rank with non-finite divergence be reported as if it had
    /// not diverged at all (an all-NaN report would read as `0.0`).
    pub fn max_relative_delta(&self) -> f64 {
        if self.deltas.iter().any(|d| d.is_nan()) {
            return f64::NAN;
        }
        self.deltas.iter().copied().fold(0.0_f64, f64::max)
    }

    /// Per-rank cosine of the angle between the "pre" and "post" vectors,
    /// reconstructed from norms alone via the law of cosines:
    ///
    /// `cos θ = (|a|² + |b|² − |a − b|²) / (2·|a|·|b|)`, with `|a − b| = delta · post_norm`.
    ///
    /// Results are clamped to `[-1, 1]` to absorb rounding error.
    ///
    /// Returns `None` when either norm field is missing or `post_norm` is below
    /// `1e-10`. A rank whose `pre_norm` is below `1e-10` yields `0.0`. `deltas`
    /// and `pre_norms` are zipped, so the output has the shorter of the two lengths.
    pub fn cosine_similarities(&self) -> Option<Vec<f64>> {
        let pre_norms = self.pre_norms.as_ref()?;
        let post_norm = self.post_norm?;
        if post_norm < Self::NORM_EPSILON {
            return None;
        }
        Some(
            self.deltas
                .iter()
                .zip(pre_norms)
                .map(|(&delta, &pre_norm)| {
                    if pre_norm < Self::NORM_EPSILON {
                        // Angle is undefined for a ~zero vector; report "unrelated".
                        return 0.0;
                    }
                    let diff_sq = (delta * post_norm).powi(2);
                    let pre_sq = pre_norm.powi(2);
                    let post_sq = post_norm.powi(2);
                    ((pre_sq + post_sq - diff_sq) / (2.0 * pre_norm * post_norm)).clamp(-1.0, 1.0)
                })
                .collect(),
        )
    }

    /// Signed norm change per rank: `pre_norm - post_norm` (positive when the
    /// "pre" norm was larger). `None` when either norm field is missing.
    pub fn magnitude_shifts(&self) -> Option<Vec<f64>> {
        let pre_norms = self.pre_norms.as_ref()?;
        let post_norm = self.post_norm?;
        Some(pre_norms.iter().map(|&pre| pre - post_norm).collect())
    }
}
/// Recommendation returned by [`ConvergenceGuard::report`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ConvergenceAction {
    /// No intervention suggested.
    Stable,
    /// Divergence has been rising across recent reports and exceeds the
    /// threshold; the caller should hold back growth (meaning is caller-defined).
    SuppressGrowth,
    /// Scale down by `factor`.
    /// NOTE(review): not currently produced by `ConvergenceGuard::report`.
    NudgeDown { factor: f64 },
}

/// Keeps a short sliding window of per-report maximum divergence and flags a
/// sustained upward trend.
pub struct ConvergenceGuard {
    /// `Sync` records samples but never runs the trend check; `Cadence` and
    /// `Async` do.
    policy: ApplyPolicy,
    /// When `false`, `report` returns `Stable` without recording anything.
    enabled: bool,
    /// The newest sample must exceed this for a rising trend to count.
    threshold: f64,
    /// Most recent `max_relative_delta` values, oldest first, capped at
    /// `HISTORY_LEN` entries.
    history: VecDeque<f64>,
}

impl ConvergenceGuard {
    /// Number of recent divergence samples retained in `history`.
    const HISTORY_LEN: usize = 5;

    /// Creates a guard with an empty history.
    pub fn new(policy: ApplyPolicy, enabled: bool, threshold: f64) -> Self {
        Self {
            policy,
            enabled,
            threshold,
            history: VecDeque::with_capacity(Self::HISTORY_LEN),
        }
    }

    /// Records `report.max_relative_delta()` and returns the suggested action.
    ///
    /// Disabled guards return `Stable` and record nothing. Enabled guards always
    /// record the sample (evicting the oldest once `HISTORY_LEN` is reached);
    /// under `Sync` they then return `Stable`, otherwise they run the trend check.
    pub fn report(&mut self, report: &DivergenceReport) -> ConvergenceAction {
        if !self.enabled {
            return ConvergenceAction::Stable;
        }
        let divergence = report.max_relative_delta();
        if self.history.len() >= Self::HISTORY_LEN {
            self.history.pop_front();
        }
        self.history.push_back(divergence);
        match self.policy {
            ApplyPolicy::Sync => ConvergenceAction::Stable,
            ApplyPolicy::Cadence | ApplyPolicy::Async => self.check_trend(),
        }
    }

    /// `SuppressGrowth` when the last three samples are strictly increasing and
    /// the newest exceeds `threshold`; otherwise `Stable`.
    ///
    /// Fewer than three samples is always `Stable`. Because every comparison
    /// involving NaN is false, a NaN sample in the last three also yields `Stable`.
    fn check_trend(&self) -> ConvergenceAction {
        let len = self.history.len();
        if len < 3 {
            return ConvergenceAction::Stable;
        }
        let oldest = self.history[len - 3];
        let previous = self.history[len - 2];
        let latest = self.history[len - 1];
        let rising = latest > previous && previous > oldest && latest > self.threshold;
        if rising {
            // `VecDeque`'s Debug output is the same list form as `Vec`'s, so no
            // copy is needed just to log it.
            crate::verbose!(
                " ddp: weight-space divergence trending up | history={:.4?} | suppressing growth",
                &self.history,
            );
            ConvergenceAction::SuppressGrowth
        } else {
            ConvergenceAction::Stable
        }
    }

    /// Recent divergence samples, oldest first.
    pub fn history(&self) -> &VecDeque<f64> {
        &self.history
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Report carrying only per-rank deltas (no norm data).
    fn make_report(deltas: &[f64]) -> DivergenceReport {
        DivergenceReport {
            deltas: deltas.to_vec(),
            pre_norms: None,
            post_norm: None,
        }
    }

    /// Report with norm data, so the cosine/magnitude views are available.
    fn make_full_report(
        deltas: &[f64],
        pre_norms: &[f64],
        post_norm: f64,
    ) -> DivergenceReport {
        DivergenceReport {
            deltas: deltas.to_vec(),
            pre_norms: Some(pre_norms.to_vec()),
            post_norm: Some(post_norm),
        }
    }

    #[test]
    fn max_relative_delta_picks_worst_rank() {
        let r = make_report(&[0.01, 0.05, 0.03]);
        assert!((r.max_relative_delta() - 0.05).abs() < 1e-10);
    }

    #[test]
    fn max_relative_delta_empty_is_zero() {
        let r = make_report(&[]);
        assert_eq!(r.max_relative_delta(), 0.0);
    }

    #[test]
    fn cosine_similarities_none_when_missing_norms() {
        let r = make_report(&[0.01, 0.02]);
        assert!(r.cosine_similarities().is_none());
        assert!(r.magnitude_shifts().is_none());
    }

    #[test]
    fn cosine_similarities_correct_for_small_delta() {
        let r = make_full_report(&[0.001], &[10.0], 10.0);
        let cos = r.cosine_similarities().unwrap();
        assert!(cos[0] > 0.999, "expected cos near 1.0, got {}", cos[0]);
    }

    #[test]
    fn cosine_similarities_none_when_post_norm_is_zero() {
        let r = make_full_report(&[0.1], &[1.0], 0.0);
        assert!(r.cosine_similarities().is_none());
    }

    #[test]
    fn cosine_similarities_degenerate_and_identical() {
        // Zero pre-norm maps to 0.0; zero delta with equal norms is exactly 1.0.
        let r = make_full_report(&[0.1, 0.0], &[0.0, 5.0], 5.0);
        assert_eq!(r.cosine_similarities().unwrap(), vec![0.0, 1.0]);
    }

    #[test]
    fn cosine_similarities_orthogonal_3_4_5() {
        // |a| = 3, |b| = 4, |a - b| = 5 (delta = 5 / 4) → right angle.
        let r = make_full_report(&[1.25], &[3.0], 4.0);
        let cos = r.cosine_similarities().unwrap();
        assert!(cos[0].abs() < 1e-12, "expected cos near 0.0, got {}", cos[0]);
    }

    #[test]
    fn magnitude_shifts_correct() {
        let r = make_full_report(&[0.01, 0.02], &[10.5, 9.8], 10.0);
        let shifts = r.magnitude_shifts().unwrap();
        assert!((shifts[0] - 0.5).abs() < 1e-10);
        assert!((shifts[1] - (-0.2)).abs() < 1e-10);
    }

    #[test]
    fn sync_always_stable() {
        let mut g = ConvergenceGuard::new(ApplyPolicy::Sync, true, 0.01);
        for _ in 0..10 {
            let action = g.report(&make_report(&[0.5, 0.5]));
            assert_eq!(action, ConvergenceAction::Stable);
        }
    }

    #[test]
    fn cadence_trend_suppresses_growth() {
        let mut g = ConvergenceGuard::new(ApplyPolicy::Cadence, true, 0.01);
        assert_eq!(g.report(&make_report(&[0.02])), ConvergenceAction::Stable);
        assert_eq!(g.report(&make_report(&[0.03])), ConvergenceAction::Stable);
        assert_eq!(g.report(&make_report(&[0.04])), ConvergenceAction::SuppressGrowth);
    }

    #[test]
    fn cadence_non_rising_is_stable() {
        let mut g = ConvergenceGuard::new(ApplyPolicy::Cadence, true, 0.01);
        g.report(&make_report(&[0.05]));
        g.report(&make_report(&[0.04]));
        assert_eq!(g.report(&make_report(&[0.06])), ConvergenceAction::Stable);
    }

    #[test]
    fn cadence_equal_samples_are_not_rising() {
        let mut g = ConvergenceGuard::new(ApplyPolicy::Cadence, true, 0.01);
        g.report(&make_report(&[0.05]));
        g.report(&make_report(&[0.05]));
        assert_eq!(g.report(&make_report(&[0.05])), ConvergenceAction::Stable);
    }

    #[test]
    fn cadence_detects_trend_after_dip() {
        let mut g = ConvergenceGuard::new(ApplyPolicy::Cadence, true, 0.01);
        g.report(&make_report(&[0.05]));
        g.report(&make_report(&[0.02]));
        assert_eq!(g.report(&make_report(&[0.03])), ConvergenceAction::Stable);
        assert_eq!(g.report(&make_report(&[0.04])), ConvergenceAction::SuppressGrowth);
    }

    #[test]
    fn below_threshold_is_stable() {
        let mut g = ConvergenceGuard::new(ApplyPolicy::Cadence, true, 0.10);
        g.report(&make_report(&[0.01]));
        g.report(&make_report(&[0.02]));
        assert_eq!(g.report(&make_report(&[0.03])), ConvergenceAction::Stable);
    }

    #[test]
    fn disabled_always_stable() {
        let mut g = ConvergenceGuard::new(ApplyPolicy::Cadence, false, 0.01);
        g.report(&make_report(&[0.1]));
        g.report(&make_report(&[0.2]));
        assert_eq!(g.report(&make_report(&[0.3])), ConvergenceAction::Stable);
    }

    #[test]
    fn disabled_records_no_history() {
        let mut g = ConvergenceGuard::new(ApplyPolicy::Cadence, false, 0.01);
        g.report(&make_report(&[0.1]));
        g.report(&make_report(&[0.2]));
        assert!(g.history().is_empty());
    }

    #[test]
    fn history_capped_at_5() {
        let mut g = ConvergenceGuard::new(ApplyPolicy::Cadence, true, 0.01);
        for i in 0..10 {
            g.report(&make_report(&[i as f64 * 0.01]));
        }
        assert_eq!(g.history().len(), 5);
    }

    #[test]
    fn history_keeps_most_recent_samples() {
        let mut g = ConvergenceGuard::new(ApplyPolicy::Cadence, true, 0.01);
        for i in 0..10 {
            g.report(&make_report(&[i as f64 * 0.01]));
        }
        let h = g.history();
        assert!((h.front().copied().unwrap() - 0.05).abs() < 1e-12);
        assert!((h.back().copied().unwrap() - 0.09).abs() < 1e-12);
    }

    #[test]
    fn async_same_as_cadence_v1() {
        let mut gc = ConvergenceGuard::new(ApplyPolicy::Cadence, true, 0.01);
        let mut ga = ConvergenceGuard::new(ApplyPolicy::Async, true, 0.01);
        let reports: Vec<DivergenceReport> =
            vec![make_report(&[0.02]), make_report(&[0.03]), make_report(&[0.04])];
        for r in &reports {
            let ac = gc.report(r);
            let aa = ga.report(r);
            assert_eq!(ac, aa);
        }
    }
}