use std::sync::atomic::{AtomicU64, Ordering};
/// Number of accepted samples required before `adjusted_ms_per_cost_unit`
/// returns a value and before outlier rejection is applied.
const MIN_SAMPLES: u64 = 10;
/// EMA weight of a new sample is `ALPHA_NUMERATOR / ALPHA_DENOMINATOR` (0.05).
const ALPHA_NUMERATOR: u64 = 5;
/// Denominator of the EMA weight; see `ALPHA_NUMERATOR`.
const ALPHA_DENOMINATOR: u64 = 100;
/// Once warmed up, a sample whose ratio exceeds the current EMA by more than
/// this factor is discarded.
const OUTLIER_RATIO: f64 = 10.0;
/// Lower bound of the value reported by `adjusted_ms_per_cost_unit`.
const MIN_MS_PER_UNIT: f64 = 0.001;
/// Upper bound applied to each sample before it enters the EMA, and to the
/// reported value.
const MAX_MS_PER_UNIT: f64 = 50.0;
/// Fixed-point scale used to store the EMA in an `AtomicU64`.
const SCALE: f64 = 1_000_000.0;

/// Lock-free feedback loop that learns a "milliseconds per estimated cost
/// unit" factor from observed latencies.
///
/// The factor is an exponential moving average (EMA) of
/// `actual_ms / estimate_cost(dataset_size, ef_search)`. It is stored in
/// fixed-point inside an atomic, so `record` only needs `&self`.
#[derive(Debug, Default)]
pub(crate) struct CboFeedbackLoop {
    /// EMA multiplied by `SCALE`; `0` is reserved for "no sample recorded yet".
    ema_scaled: AtomicU64,
    /// Number of samples accepted into the EMA.
    sample_count: AtomicU64,
}

impl CboFeedbackLoop {
    /// Creates a feedback loop with no samples.
    #[must_use]
    pub(crate) fn new() -> Self {
        Self::default()
    }

    /// Records one observed execution and folds it into the EMA.
    ///
    /// The sample is ignored when:
    /// - `actual_ms` is not a finite, strictly positive number;
    /// - `dataset_size` is zero, or the estimated cost is not positive
    ///   (e.g. `ef_search == 0`);
    /// - at least `MIN_SAMPLES` samples exist and the observed ratio is more
    ///   than `OUTLIER_RATIO` times the current EMA.
    ///
    /// The count read, outlier check and EMA update are separate atomic
    /// operations. Concurrent callers may interleave between them, but each
    /// EMA step is applied atomically.
    pub(crate) fn record(&self, dataset_size: usize, ef_search: usize, actual_ms: f64) {
        // `actual_ms <= 0.0` alone lets NaN through: every comparison with NaN
        // is false, and NaN later saturates to 0 when cast to u64, silently
        // dragging the EMA toward zero.
        if !actual_ms.is_finite() || actual_ms <= 0.0 || dataset_size == 0 {
            return;
        }
        let estimated_cost = Self::estimate_cost(dataset_size, ef_search);
        if estimated_cost <= 0.0 {
            return;
        }
        let observed_ratio = actual_ms / estimated_cost;
        // Outlier rejection only starts once the EMA rests on enough samples.
        if self.sample_count.load(Ordering::Relaxed) >= MIN_SAMPLES {
            let current_ema = self.current_ema();
            if current_ema > 0.0 && observed_ratio / current_ema > OUTLIER_RATIO {
                return;
            }
        }
        self.sample_count.fetch_add(1, Ordering::Relaxed);
        self.ema_update(observed_ratio);
    }

    /// Returns the learned factor clamped to
    /// `[MIN_MS_PER_UNIT, MAX_MS_PER_UNIT]`.
    ///
    /// Returns `None` before `MIN_SAMPLES` samples have been accepted, or
    /// while the EMA is still zero.
    #[must_use]
    pub(crate) fn adjusted_ms_per_cost_unit(&self) -> Option<f64> {
        if self.sample_count.load(Ordering::Relaxed) < MIN_SAMPLES {
            return None;
        }
        let v = self.current_ema();
        if v > 0.0 {
            Some(v.clamp(MIN_MS_PER_UNIT, MAX_MS_PER_UNIT))
        } else {
            None
        }
    }

    /// Number of samples accepted into the EMA so far.
    #[must_use]
    pub(crate) fn sample_count(&self) -> u64 {
        self.sample_count.load(Ordering::Relaxed)
    }

    /// Current EMA converted back from fixed-point (`0.0` before any sample).
    #[must_use]
    fn current_ema(&self) -> f64 {
        // Stored values never exceed MAX_MS_PER_UNIT * SCALE (5e7), well
        // below 2^53, so this u64 -> f64 conversion is exact.
        #[allow(clippy::cast_precision_loss)]
        let scaled = self.ema_scaled.load(Ordering::Relaxed) as f64;
        scaled / SCALE
    }

    /// Heuristic cost model: `log2(dataset_size + 1) * (ef_search / 100)`.
    fn estimate_cost(dataset_size: usize, ef_search: usize) -> f64 {
        // usize -> f64 may round for astronomically large values; irrelevant
        // for a heuristic cost estimate.
        #[allow(clippy::cast_precision_loss)]
        let n_factor = (dataset_size as f64 + 1.0).log2();
        #[allow(clippy::cast_precision_loss)]
        let ef_factor = ef_search as f64 / 100.0;
        n_factor * ef_factor
    }

    /// Atomically blends `new_value` into the EMA.
    ///
    /// The value is first clamped to `[0, MAX_MS_PER_UNIT]`. Its scaled form
    /// is kept at least `1` (i.e. `1 / SCALE`). Otherwise a ratio below
    /// `1 / SCALE` would truncate to the "no samples" sentinel `0`, and the
    /// next sample would overwrite the EMA instead of blending into it.
    fn ema_update(&self, new_value: f64) {
        let clamped = new_value.clamp(0.0, MAX_MS_PER_UNIT);
        // `clamped` is non-negative and at most MAX_MS_PER_UNIT, so the
        // scaled value fits in u64 without sign loss or truncation.
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let new_scaled = ((clamped * SCALE) as u64).max(1);
        // `fetch_update` re-runs the closure on contention (a CAS loop). The
        // closure always returns `Some`, so the result is always `Ok`.
        let _ = self
            .ema_scaled
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |old_scaled| {
                Some(Self::blend(old_scaled, new_scaled))
            });
    }

    /// One fixed-point EMA step: `alpha * new + (1 - alpha) * old`.
    ///
    /// The first sample (`old_scaled == 0`) seeds the EMA directly. When both
    /// inputs are >= 1, the result is also >= 1, so the sentinel `0` can only
    /// mean "no samples yet".
    fn blend(old_scaled: u64, new_scaled: u64) -> u64 {
        if old_scaled == 0 {
            return new_scaled;
        }
        // Widen to u128 so the weighted sum cannot overflow before dividing.
        let num = u128::from(new_scaled) * u128::from(ALPHA_NUMERATOR)
            + u128::from(old_scaled) * u128::from(ALPHA_DENOMINATOR - ALPHA_NUMERATOR);
        // The quotient is a convex combination of two u64 values, so it fits.
        u64::try_from(num / u128::from(ALPHA_DENOMINATOR)).unwrap_or(u64::MAX)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `times` identical observations into `fb`.
    fn feed(fb: &CboFeedbackLoop, times: u64, dataset_size: usize, ef_search: usize, ms: f64) {
        (0..times).for_each(|_| fb.record(dataset_size, ef_search, ms));
    }

    #[test]
    fn test_no_adjustment_before_min_samples() {
        let fb = CboFeedbackLoop::new();
        feed(&fb, MIN_SAMPLES - 1, 10_000, 100, 5.0);
        assert!(
            fb.adjusted_ms_per_cost_unit().is_none(),
            "should return None until MIN_SAMPLES observations"
        );
    }

    #[test]
    fn test_adjustment_after_min_samples() {
        let fb = CboFeedbackLoop::new();
        feed(&fb, MIN_SAMPLES, 10_000, 100, 5.0);
        let maybe_factor = fb.adjusted_ms_per_cost_unit();
        assert!(maybe_factor.is_some(), "should return Some after MIN_SAMPLES");
        let factor = maybe_factor.unwrap();
        let bounds = MIN_MS_PER_UNIT..=MAX_MS_PER_UNIT;
        assert!(bounds.contains(&factor), "adjusted value {factor} out of bounds");
    }

    #[test]
    fn test_ema_converges_toward_observed_ratio() {
        let fb = CboFeedbackLoop::new();
        feed(&fb, 50, 10_000, 100, 2.0);
        // ef_search = 100 gives an ef factor of exactly 1.0.
        let expected = 2.0 / (10_001_f64.log2() * 1.0);
        let v = fb.adjusted_ms_per_cost_unit().expect("should have value");
        let within_tolerance = (v - expected).abs() < 0.05;
        assert!(
            within_tolerance,
            "EMA {v:.4} should be near expected {expected:.4}"
        );
    }

    #[test]
    fn test_outlier_rejection() {
        let fb = CboFeedbackLoop::new();
        feed(&fb, 20, 10_000, 100, 2.0);
        let (ema_before, count_before) = (fb.current_ema(), fb.sample_count());
        // Roughly 10_000x the learned ratio: far above OUTLIER_RATIO.
        fb.record(10_000, 100, 20_000.0);
        let (ema_after, count_after) = (fb.current_ema(), fb.sample_count());
        assert_eq!(
            count_before, count_after,
            "outlier should be rejected, sample count unchanged"
        );
        assert!(
            (ema_after - ema_before).abs() < f64::EPSILON,
            "EMA should be unchanged after outlier rejection"
        );
    }

    #[test]
    fn test_zero_or_negative_actual_ms_ignored() {
        let fb = CboFeedbackLoop::new();
        [0.0, -1.0]
            .into_iter()
            .for_each(|ms| fb.record(10_000, 100, ms));
        assert_eq!(fb.sample_count(), 0, "invalid samples should be ignored");
    }

    #[test]
    fn test_zero_dataset_size_ignored() {
        let fb = CboFeedbackLoop::new();
        fb.record(0, 100, 5.0);
        assert_eq!(fb.sample_count(), 0);
    }

    #[test]
    fn test_bounds_clamping() {
        let fb = CboFeedbackLoop::new();
        feed(&fb, MIN_SAMPLES, 10_000, 100, 0.001);
        let factor = fb.adjusted_ms_per_cost_unit().unwrap();
        assert!(factor >= MIN_MS_PER_UNIT, "should be clamped to minimum");
    }

    #[test]
    fn test_large_value_clamped_before_cast() {
        let fb = CboFeedbackLoop::new();
        // Tiny estimated cost and a huge latency produce an enormous ratio.
        feed(&fb, MIN_SAMPLES, 1, 1, 1e10);
        if let Some(factor) = fb.adjusted_ms_per_cost_unit() {
            assert!(
                factor <= MAX_MS_PER_UNIT,
                "value must be clamped to MAX_MS_PER_UNIT, got {factor}"
            );
        }
    }
}