/// Failure probability delta for the empirical Bernstein bounds: intervals
/// built with this delta hold with probability at least 1 - delta.
pub const BERNSTEIN_DELTA: f64 = 0.05;
/// Minimum sample count before `bernstein_promotion_test` will return any
/// verdict other than `Inconclusive`.
pub const MIN_SAMPLES_FOR_BERNSTEIN: u64 = 10;
/// Streaming mean/variance tracker using Welford's numerically stable
/// online algorithm, plus observed min/max for range-based bounds.
#[derive(Debug, Clone)]
pub struct WelfordTracker {
    /// Number of samples observed.
    pub n: u64,
    /// Running mean of all samples.
    pub mean: f64,
    /// Sum of squared deviations from the current mean (Welford's M2).
    pub m2: f64,
    /// Smallest sample seen; +inf until the first update.
    pub r_min: f64,
    /// Largest sample seen; -inf until the first update.
    pub r_max: f64,
}

impl Default for WelfordTracker {
    /// BUG FIX: `#[derive(Default)]` zero-initialized `r_min`/`r_max`, so a
    /// default-constructed tracker silently clamped the observed range to
    /// include 0.0 (e.g. after `update(5.0)`, `r_min` stayed 0.0). `Default`
    /// now matches `new()`'s infinity sentinels.
    fn default() -> Self {
        Self::new()
    }
}

impl WelfordTracker {
    /// Creates an empty tracker. The range sentinels start at +/- infinity
    /// so the first sample initializes both bounds.
    pub fn new() -> Self {
        Self {
            n: 0,
            mean: 0.0,
            m2: 0.0,
            r_min: f64::INFINITY,
            r_max: f64::NEG_INFINITY,
        }
    }

    /// Folds one sample into the running statistics (Welford update).
    pub fn update(&mut self, x: f64) {
        self.n += 1;
        let delta1 = x - self.mean;
        self.mean += delta1 / self.n as f64;
        // delta2 uses the *updated* mean; delta1 * delta2 is the standard
        // numerically stable M2 increment.
        let delta2 = x - self.mean;
        self.m2 += delta1 * delta2;
        if x < self.r_min {
            self.r_min = x;
        }
        if x > self.r_max {
            self.r_max = x;
        }
    }

    /// Unbiased (Bessel-corrected) sample variance. Infinite for n < 2,
    /// where the sample variance is undefined.
    pub fn variance(&self) -> f64 {
        if self.n > 1 {
            self.m2 / (self.n - 1) as f64
        } else {
            f64::INFINITY
        }
    }

    /// Observed range (max - min); 0.0 before any samples arrive.
    pub fn range(&self) -> f64 {
        if self.n == 0 {
            0.0
        } else {
            self.r_max - self.r_min
        }
    }

    /// Resets the moment statistics but keeps the observed min/max as a
    /// soft prior on the range across resets (e.g. after drift detection).
    pub fn flush(&mut self) {
        let (keep_min, keep_max) = (self.r_min, self.r_max);
        *self = Self::new();
        self.r_min = keep_min;
        self.r_max = keep_max;
    }
}
/// Exponentially-weighted variant of the Welford tracker.
///
/// Recent samples dominate the statistics (decay factor `alpha`), which lets
/// the tracker follow drifting metrics. `n_eff` approximates the effective
/// sample size as 1 / (sum of squared normalized weights).
#[derive(Debug, Clone)]
pub struct EwmaWelfordTracker {
    /// Effective sample count implied by the exponential weights.
    pub n_eff: f64,
    /// Biased (zero-initialized) EWMA of the samples.
    pub mean: f64,
    /// EWMA second-moment accumulator.
    pub m2: f64,
    /// Decay factor; larger values mean longer memory.
    pub alpha: f64,
    /// Raw count of samples observed.
    pub n: u64,
    /// Smallest sample seen; +inf until the first update.
    pub r_min: f64,
    /// Largest sample seen; -inf until the first update.
    pub r_max: f64,
}

impl EwmaWelfordTracker {
    /// Builds an empty tracker with the given decay factor.
    pub fn new(alpha: f64) -> Self {
        Self {
            n_eff: 1.0,
            mean: 0.0,
            m2: 0.0,
            alpha,
            n: 0,
            r_min: f64::INFINITY,
            r_max: f64::NEG_INFINITY,
        }
    }

    /// Folds one sample into the exponentially-weighted statistics.
    pub fn update(&mut self, x: f64) {
        self.n += 1;
        let w = 1.0 - self.alpha; // weight assigned to the newest sample
        let dev = x - self.mean;
        self.mean += w * dev;
        self.m2 = self.alpha * self.m2 + self.alpha * w * dev * dev;
        // Recurrence for the sum of squared weights: existing weights shrink
        // by alpha, the new sample contributes (1 - alpha)^2.
        let sum_sq = self.alpha * self.alpha / self.n_eff + w * w;
        self.n_eff = if sum_sq > 1e-15 { 1.0 / sum_sq } else { 1.0 / w };
        self.r_min = self.r_min.min(x);
        self.r_max = self.r_max.max(x);
    }

    /// Bias-corrected mean: divides out the Adam-style 1 - alpha^n warm-up
    /// bias of the zero-initialized EWMA.
    pub fn corrected_mean(&self) -> f64 {
        match 1.0 - self.alpha.powi(self.n as i32) {
            bias if bias > 1e-15 => self.mean / bias,
            _ => self.mean,
        }
    }

    /// Bias-corrected variance estimate: debiases `m2` the same way as the
    /// mean, then rescales by 1 / (1 - alpha).
    /// NOTE(review): the 1/(1-alpha) rescaling presumably converts the EWMA
    /// second moment back to a variance scale — confirm against the
    /// derivation used elsewhere in the project.
    pub fn corrected_variance(&self) -> f64 {
        let bias = 1.0 - self.alpha.powi(self.n as i32);
        let debiased = if bias > 1e-15 { self.m2 / bias } else { self.m2 };
        let w = 1.0 - self.alpha;
        if w > 1e-15 {
            debiased / w
        } else {
            debiased
        }
    }

    /// Observed range (max - min); 0.0 before any samples arrive.
    pub fn range(&self) -> f64 {
        match self.n {
            0 => 0.0,
            _ => self.r_max - self.r_min,
        }
    }

    /// Restarts the statistics (e.g. after drift), preserving `alpha` and
    /// keeping the observed min/max as a soft range prior.
    pub fn flush(&mut self) {
        let saved = (self.alpha, self.r_min, self.r_max);
        *self = Self::new(saved.0);
        self.r_min = saved.1;
        self.r_max = saved.2;
    }
}
/// Half-width of the empirical Bernstein confidence interval
/// (Maurer & Pontil form): sqrt(2 V ln(2/d) / n) + 7 R ln(2/d) / (3 (n - 1)),
/// where V is the sample variance, R the observed range, and d the failure
/// probability `delta`.
///
/// Returns infinity when n < 2, where the empirical variance is undefined.
pub fn bernstein_halfwidth(variance: f64, n: f64, range: f64, delta: f64) -> f64 {
    if n < 2.0 {
        return f64::INFINITY;
    }
    let log_term = (2.0 / delta).ln();
    // Variance-driven term plus range-driven correction term.
    (2.0 * variance * log_term / n).sqrt() + (7.0 * range / (3.0 * (n - 1.0))) * log_term
}
/// Two-sided empirical Bernstein CI around `mean`, built from raw Welford
/// state (`m2`, `n`) and the observed `range`.
///
/// Returns (-inf, +inf) when fewer than two samples exist, since the sample
/// variance (and hence the bound) is undefined.
pub fn empirical_bernstein_ci(mean: f64, m2: f64, n: u64, range: f64, delta: f64) -> (f64, f64) {
    if n < 2 {
        return (f64::NEG_INFINITY, f64::INFINITY);
    }
    let samples = n as f64;
    let sample_var = m2 / (samples - 1.0);
    let half_width = bernstein_halfwidth(sample_var, samples, range, delta);
    (mean - half_width, mean + half_width)
}
/// Empirical Bernstein CI for an EWMA tracker: uses the effective sample
/// size in place of the raw count and the bias-corrected moments.
///
/// Returns (-inf, +inf) until at least two samples have been observed.
pub fn ewma_bernstein_ci(tracker: &EwmaWelfordTracker, delta: f64) -> (f64, f64) {
    if tracker.n < 2 {
        return (f64::NEG_INFINITY, f64::INFINITY);
    }
    let center = tracker.corrected_mean();
    let half_width = bernstein_halfwidth(
        tracker.corrected_variance(),
        tracker.n_eff,
        tracker.range(),
        delta,
    );
    (center - half_width, center + half_width)
}
/// Outcome of a challenger-vs-champion promotion test.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PromotionVerdict {
    /// Challenger is statistically better than the champion; promote it.
    Promote,
    /// Not enough evidence either way; keep collecting samples.
    Inconclusive,
    /// Challenger's mean is no better than the champion's metric.
    Worse,
}
/// Decides whether a challenger (lower metric = better) should replace the
/// champion, using a one-sided empirical Bernstein comparison.
///
/// Verdict rules, in order:
/// * `Inconclusive` below `MIN_SAMPLES_FOR_BERNSTEIN` samples.
/// * `Worse` whenever the challenger's mean is not strictly below the
///   champion's metric (note: no significance test is applied on this side).
/// * `Promote` only when the upper bound of the challenger's CI clears the
///   champion's metric.
/// * `Inconclusive` otherwise (CIs overlap the champion's metric).
pub fn bernstein_promotion_test(
    challenger_mean: f64,
    challenger_m2: f64,
    challenger_n: u64,
    challenger_range: f64,
    champion_metric: f64,
    delta: f64,
) -> PromotionVerdict {
    if challenger_n < MIN_SAMPLES_FOR_BERNSTEIN {
        return PromotionVerdict::Inconclusive;
    }
    let ci = empirical_bernstein_ci(
        challenger_mean,
        challenger_m2,
        challenger_n,
        challenger_range,
        delta,
    );
    if challenger_mean >= champion_metric {
        return PromotionVerdict::Worse;
    }
    if ci.1 < champion_metric {
        return PromotionVerdict::Promote;
    }
    PromotionVerdict::Inconclusive
}
/// Per-arm summary statistics: (mean, m2 sum-of-squared-deviations, sample
/// count n, observed range).
pub type ArmStats = (f64, f64, u64, f64);
/// Picks the statistically dominant arm from a Pareto front of `ArmStats`
/// (lower mean = better), or `None` when no arm's CI separates cleanly from
/// all others.
///
/// An arm wins when its CI upper bound lies strictly below every other arm's
/// CI lower bound and it also has the best (lowest) mean.
pub fn bernstein_compare(front: &[ArmStats], delta: f64) -> Option<usize> {
    if front.is_empty() {
        return None;
    }
    if front.len() == 1 {
        // A lone arm trivially wins.
        return Some(0);
    }
    let cis: Vec<(f64, f64)> = front
        .iter()
        .map(|&(mean, m2, n, range)| empirical_bernstein_ci(mean, m2, n, range, delta))
        .collect();
    // Primary pass: full statistical dominance plus best mean.
    for (i, &(_lo_i, hi_i)) in cis.iter().enumerate() {
        if !hi_i.is_finite() {
            continue;
        }
        let dominates = cis
            .iter()
            .enumerate()
            .all(|(j, &(lo_j, _))| j == i || hi_i < lo_j);
        if dominates {
            let mean_i = front[i].0;
            let is_best_mean = front
                .iter()
                .enumerate()
                .all(|(j, &(m, _, _, _))| j == i || mean_i <= m);
            if is_best_mean {
                return Some(i);
            }
        }
    }
    // Fallback tiebreak on the best-mean arm.
    // BUG FIX: previously this compared the best arm's upper bound only
    // against the SECOND-lowest-mean arm's lower bound, so with 3+ arms a
    // wide-CI arm (e.g. one with n < 2, whose lower bound is -inf) could be
    // ignored and a winner declared without true statistical dominance. The
    // best arm must now separate from the minimum lower bound of ALL others.
    let mut order: Vec<usize> = (0..front.len()).collect();
    order.sort_by(|&a, &b| {
        front[a]
            .0
            .partial_cmp(&front[b].0)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    let best_idx = order[0];
    let (_, hi_best) = cis[best_idx];
    let lo_rest_min = order[1..]
        .iter()
        .map(|&j| cis[j].0)
        .fold(f64::INFINITY, f64::min);
    if hi_best.is_finite() && hi_best < lo_rest_min {
        return Some(best_idx);
    }
    None
}
#[cfg(test)]
mod tests {
    //! Unit tests for the Welford/EWMA trackers and the empirical Bernstein
    //! promotion / Pareto-comparison helpers.
    use super::*;

    // Classic Welford fixture: mean 5, sample variance 32/7, range 7.
    #[test]
    fn welford_tracker_mean_and_variance() {
        let mut t = WelfordTracker::new();
        let values = [2.0f64, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        for &v in &values {
            t.update(v);
        }
        assert!(
            (t.mean - 5.0).abs() < 1e-10,
            "mean should be 5.0, got {}",
            t.mean
        );
        // Bessel-corrected: sum of squared deviations is 32, n - 1 = 7.
        let expected_var = 32.0 / 7.0;
        assert!(
            (t.variance() - expected_var).abs() < 1e-10,
            "variance should be {expected_var:.6}, got {}",
            t.variance()
        );
        assert!(
            (t.range() - 7.0).abs() < 1e-10,
            "range should be 7.0, got {}",
            t.range()
        );
    }

    // With a single sample the mean equals the sample and the variance is
    // deliberately reported as infinite (undefined for n < 2).
    #[test]
    fn welford_tracker_cold_start_single_sample() {
        let mut t = WelfordTracker::new();
        let sample = core::f64::consts::PI;
        t.update(sample);
        assert!(
            (t.mean - sample).abs() < 1e-12,
            "single-sample mean should equal the sample, got {}",
            t.mean
        );
        assert!(
            t.variance().is_infinite(),
            "single-sample variance should be infinite (undefined), got {}",
            t.variance()
        );
    }

    // flush() resets the moment statistics but keeps min/max as a soft
    // range prior across resets.
    #[test]
    fn welford_tracker_flush_resets_stats_keeps_range() {
        let mut t = WelfordTracker::new();
        for v in [1.0f64, 2.0, 3.0, 4.0] {
            t.update(v);
        }
        let (kept_min, kept_max) = (t.r_min, t.r_max);
        t.flush();
        assert_eq!(t.n, 0, "flush should reset n to 0, got {}", t.n);
        assert!(
            (t.r_min - kept_min).abs() < 1e-12,
            "flush should preserve r_min as soft bound: expected {kept_min}, got {}",
            t.r_min
        );
        assert!(
            (t.r_max - kept_max).abs() < 1e-12,
            "flush should preserve r_max as soft bound: expected {kept_max}, got {}",
            t.r_max
        );
    }

    // The variance term of the bound must be monotone in the variance.
    #[test]
    fn bernstein_bound_widens_with_variance() {
        let range = 1.0;
        let n = 100.0;
        let delta = 0.05;
        let hw_low = bernstein_halfwidth(0.01, n, range, delta);
        let hw_high = bernstein_halfwidth(0.5, n, range, delta);
        assert!(
            hw_high > hw_low,
            "higher variance should produce wider bound: hw_low={hw_low:.6}, hw_high={hw_high:.6}"
        );
    }

    // The point of the empirical Bernstein bound: at low variance it beats
    // the variance-agnostic Hoeffding bound R * sqrt(ln(2/d) / (2n)).
    #[test]
    fn bernstein_tighter_than_hoeffding_at_low_variance() {
        let range = 2.0;
        let delta = 0.05;
        for &(variance, n_u64) in &[(0.01f64, 1000u64), (0.1, 2000), (0.001, 500)] {
            let n = n_u64 as f64;
            let bernstein = bernstein_halfwidth(variance, n, range, delta);
            let hoeffding = range * ((2.0 / delta).ln() / (2.0 * n)).sqrt();
            assert!(
                bernstein < hoeffding,
                "Bernstein should be < Hoeffding at large n with low variance \
                (n={n_u64}, var={variance}): bernstein={bernstein:.6}, hoeffding={hoeffding:.6}"
            );
        }
        // Also check the variance component alone against Hoeffding.
        let n = 1000.0f64;
        let var_small = 0.001;
        let term1_small = (2.0 * var_small * (2.0_f64 / delta).ln() / n).sqrt();
        let hoeff = range * ((2.0 / delta).ln() / (2.0 * n)).sqrt();
        assert!(
            term1_small < hoeff,
            "Bernstein term1 (variance component) should be < Hoeffding for small variance: \
            term1={term1_small:.6}, hoeffding={hoeff:.6}"
        );
    }

    // n < 2 makes the sample variance undefined, so the bound is infinite.
    #[test]
    fn bernstein_returns_infinity_for_small_n() {
        assert!(
            bernstein_halfwidth(0.5, 0.0, 1.0, 0.05).is_infinite(),
            "n=0 should give INFINITY"
        );
        assert!(
            bernstein_halfwidth(0.5, 1.0, 1.0, 0.05).is_infinite(),
            "n=1 should give INFINITY"
        );
    }

    // A nominally better mean must not promote when the CI is still wide
    // (few samples, high variance).
    #[test]
    fn bernstein_promotion_requires_statistical_certainty() {
        let champion_metric = 0.50;
        let challenger_mean = 0.48;
        let challenger_range = 1.0;
        let n = 11u64;
        let variance = 0.2f64;
        // m2 = variance * (n - 1) reconstructs the Welford accumulator.
        let challenger_m2 = variance * (n - 1) as f64;
        let verdict = bernstein_promotion_test(
            challenger_mean,
            challenger_m2,
            n,
            challenger_range,
            champion_metric,
            BERNSTEIN_DELTA,
        );
        assert_ne!(
            verdict,
            PromotionVerdict::Promote,
            "small n ({n}) with high variance should NOT promote: got {verdict:?}"
        );
    }

    // Large n + tiny variance + big advantage: CI clears the champion.
    #[test]
    fn bernstein_promotes_with_clear_advantage() {
        let champion_metric = 0.50;
        let challenger_mean = 0.10;
        let challenger_range = 0.05;
        let n = 1000u64;
        let variance = 0.001f64;
        let challenger_m2 = variance * (n - 1) as f64;
        let verdict = bernstein_promotion_test(
            challenger_mean,
            challenger_m2,
            n,
            challenger_range,
            champion_metric,
            BERNSTEIN_DELTA,
        );
        assert_eq!(
            verdict,
            PromotionVerdict::Promote,
            "large n, tiny variance, large advantage should Promote: got {verdict:?}"
        );
    }

    // Mean >= champion metric short-circuits to Worse (no CI needed).
    #[test]
    fn bernstein_worse_when_challenger_mean_exceeds_champion() {
        let verdict = bernstein_promotion_test(
            0.80, 0.01, 200, 0.1, 0.50, BERNSTEIN_DELTA,
        );
        assert_eq!(
            verdict,
            PromotionVerdict::Worse,
            "challenger with higher mean error should be Worse, got {verdict:?}"
        );
    }

    // Below the sample floor the test refuses to judge at all.
    #[test]
    fn bernstein_inconclusive_below_min_samples() {
        let verdict = bernstein_promotion_test(
            0.10, 0.001, MIN_SAMPLES_FOR_BERNSTEIN - 1, 0.1, 0.50, BERNSTEIN_DELTA,
        );
        assert_eq!(
            verdict,
            PromotionVerdict::Inconclusive,
            "fewer than MIN_SAMPLES_FOR_BERNSTEIN should be Inconclusive, got {verdict:?}"
        );
    }

    // After a drift-triggered flush the EWMA CI must widen: certainty is
    // lost on reset even though the range prior is retained.
    #[test]
    fn bernstein_handles_drift_via_ewma_decay() {
        let alpha = 0.98f64;
        let mut tracker = EwmaWelfordTracker::new(alpha);
        for i in 0..500u64 {
            // Oscillate around 0.5 with amplitude 0.1.
            let x = 0.5 + 0.1 * (if i % 2 == 0 { 1.0 } else { -1.0 });
            tracker.update(x);
        }
        let (lo_before, hi_before) = ewma_bernstein_ci(&tracker, BERNSTEIN_DELTA);
        let width_before = hi_before - lo_before;
        assert!(
            width_before.is_finite(),
            "pre-flush CI should be finite after 500 samples: lo={lo_before}, hi={hi_before}"
        );
        tracker.flush();
        for _ in 0..50u64 {
            tracker.update(0.9);
        }
        let (lo_after, hi_after) = ewma_bernstein_ci(&tracker, BERNSTEIN_DELTA);
        let width_after = hi_after - lo_after;
        assert!(
            width_after.is_finite(),
            "post-flush CI should be finite after 50 samples: lo={lo_after}, hi={hi_after}"
        );
        assert!(
            width_after > width_before,
            "CI after drift flush should be wider (less certainty after reset): \
            before={width_before:.6}, after={width_after:.6}"
        );
    }

    // Two well-separated tight-CI arms: the comparator picks the lower one.
    #[test]
    fn pareto_front_can_invoke_bernstein_tiebreak() {
        let n0 = 2000u64;
        let var0 = 0.0001f64;
        let m2_0 = var0 * (n0 - 1) as f64;
        let range0 = 0.05f64;
        let n1 = 2000u64;
        let var1 = 0.0001f64;
        let m2_1 = var1 * (n1 - 1) as f64;
        let range1 = 0.05f64;
        let front: &[ArmStats] = &[
            (0.10, m2_0, n0, range0),
            (0.50, m2_1, n1, range1),
        ];
        let winner = bernstein_compare(front, BERNSTEIN_DELTA);
        assert_eq!(
            winner,
            Some(0),
            "Pareto tiebreak should select arm 0 (hi_0 < lo_1, statistically dominant): got {winner:?}"
        );
    }

    // Overlapping CIs (few samples, high variance): no winner.
    #[test]
    fn pareto_front_returns_none_when_uncertain() {
        let n = 15u64;
        let var = 0.5f64;
        let m2 = var * (n - 1) as f64;
        let range = 1.0f64;
        let front: &[ArmStats] = &[
            (0.48, m2, n, range),
            (0.52, m2, n, range),
        ];
        let winner = bernstein_compare(front, BERNSTEIN_DELTA);
        assert_eq!(
            winner, None,
            "Pareto tiebreak with overlapping CIs should return None, got {winner:?}"
        );
    }

    // Degenerate front: one arm always wins regardless of its statistics.
    #[test]
    fn pareto_front_single_entry_always_wins() {
        let front: &[ArmStats] = &[(0.3, 0.01, 5, 0.1)];
        assert_eq!(
            bernstein_compare(front, BERNSTEIN_DELTA),
            Some(0),
            "single-entry Pareto front should always return Some(0)"
        );
    }
}