use std::collections::HashMap;
use std::sync::Arc;
use parking_lot::Mutex;
/// Smoothed health statistics for a single provider.
#[derive(Clone, Debug)]
pub struct ProviderStats {
    /// Exponential moving average of call outcomes, where a success counts as `1.0`
    /// and a failure as `0.0`.
    pub success_ema: f64,
    /// Exponential moving average of call latency, in milliseconds.
    pub latency_ema_ms: f64,
    /// Number of calls recorded for this provider.
    pub total_calls: u64,
}
impl Default for ProviderStats {
    /// Optimistic prior for a provider with no recorded calls: assumed fully successful,
    /// with a 500 ms latency estimate and a zero call count.
    fn default() -> Self {
        const PRIOR_SUCCESS: f64 = 1.0;
        const PRIOR_LATENCY_MS: f64 = 500.0;
        Self {
            total_calls: 0,
            success_ema: PRIOR_SUCCESS,
            latency_ema_ms: PRIOR_LATENCY_MS,
        }
    }
}
/// Tracks per-provider exponential moving averages of success rate and latency, and
/// periodically proposes a provider order ranked by those averages.
///
/// Cloning is cheap and yields a handle to the *same* state: the statistics map and the
/// call counter both live behind `Arc`s, so clones observe each other's updates.
#[derive(Clone, Debug)]
pub struct EmaTracker {
    /// Running statistics keyed by provider name.
    stats: Arc<Mutex<HashMap<String, ProviderStats>>>,
    /// Smoothing factor: the weight of the newest sample in each EMA update.
    alpha: f64,
    /// `maybe_reorder` yields a ranking on every `reorder_interval`-th call; `0` disables it.
    reorder_interval: u64,
    /// Number of `maybe_reorder` calls made so far (shared across clones).
    call_counter: Arc<Mutex<u64>>,
}
impl EmaTracker {
#[must_use]
pub fn new(alpha: f64, reorder_interval: u64) -> Self {
Self {
stats: Arc::new(Mutex::new(HashMap::new())),
alpha,
reorder_interval,
call_counter: Arc::new(Mutex::new(0)),
}
}
pub fn record(&self, provider_name: &str, success: bool, latency_ms: u64) {
let mut stats = self.stats.lock();
let entry = stats.entry(provider_name.to_owned()).or_default();
let success_val = if success { 1.0 } else { 0.0 };
entry.success_ema = self.alpha * success_val + (1.0 - self.alpha) * entry.success_ema;
#[allow(clippy::cast_precision_loss)]
let latency_f = latency_ms as f64;
entry.latency_ema_ms = self.alpha * latency_f + (1.0 - self.alpha) * entry.latency_ema_ms;
entry.total_calls += 1;
}
#[must_use]
pub fn maybe_reorder(&self, current_order: &[String]) -> Option<Vec<String>> {
let mut counter = self.call_counter.lock();
*counter += 1;
if self.reorder_interval == 0 || !(*counter).is_multiple_of(self.reorder_interval) {
return None;
}
let stats = self.stats.lock();
let mut scored: Vec<(String, f64)> = current_order
.iter()
.map(|name| {
let s = stats.get(name).cloned().unwrap_or_default();
let score = s.success_ema - s.latency_ema_ms / 10_000.0;
(name.clone(), score)
})
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
Some(scored.into_iter().map(|(name, _)| name).collect())
}
#[must_use]
pub fn snapshot(&self) -> HashMap<String, ProviderStats> {
self.stats.lock().clone()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned provider-name list without repeating `.to_string()` at every call site.
    fn names(list: &[&str]) -> Vec<String> {
        list.iter().map(|&s| s.to_owned()).collect()
    }

    /// Returns a tracker with `alpha = 1.0` (each EMA equals its latest sample) and interval 1,
    /// after p1 failed slowly and p2 succeeded quickly, plus the order `["p1", "p2"]`.
    fn tracker_with_slow_p1_fast_p2() -> (EmaTracker, Vec<String>) {
        let t = EmaTracker::new(1.0, 1);
        let order = names(&["p1", "p2"]);
        t.record("p1", false, 9000);
        let _ = t.maybe_reorder(&order);
        t.record("p2", true, 10);
        (t, order)
    }

    #[test]
    fn default_stats_optimistic_prior() {
        let s = ProviderStats::default();
        assert!((s.success_ema - 1.0).abs() < f64::EPSILON);
        assert!(s.latency_ema_ms > 0.0);
        assert_eq!(s.total_calls, 0);
    }

    #[test]
    fn new_tracker_empty_stats() {
        let t = EmaTracker::new(0.3, 10);
        assert!(t.snapshot().is_empty());
    }

    #[test]
    fn record_updates_success_ema() {
        let t = EmaTracker::new(0.5, 100);
        t.record("p1", true, 100);
        let snap = t.snapshot();
        let s = snap.get("p1").unwrap();
        assert!((s.success_ema - 1.0).abs() < 1e-9);
        t.record("p1", false, 100);
        let snap = t.snapshot();
        let s = snap.get("p1").unwrap();
        assert!((s.success_ema - 0.5).abs() < 1e-9);
    }

    #[test]
    fn record_updates_latency_ema() {
        let t = EmaTracker::new(0.5, 100);
        t.record("p1", true, 200);
        let snap = t.snapshot();
        let s = snap.get("p1").unwrap();
        // First sample is blended with the 500 ms prior: 0.5 * 200 + 0.5 * 500 = 350.
        assert!((s.latency_ema_ms - 350.0).abs() < 1e-6);
    }

    #[test]
    fn record_increments_total_calls() {
        let t = EmaTracker::new(0.3, 100);
        t.record("p1", true, 10);
        t.record("p1", true, 10);
        assert_eq!(t.snapshot().get("p1").unwrap().total_calls, 2);
    }

    #[test]
    fn maybe_reorder_returns_none_before_interval() {
        let t = EmaTracker::new(0.3, 10);
        let order = names(&["p1", "p2"]);
        for _ in 0..9 {
            assert!(t.maybe_reorder(&order).is_none());
        }
    }

    #[test]
    fn maybe_reorder_returns_order_at_interval() {
        let t = EmaTracker::new(0.3, 10);
        let order = names(&["p1", "p2"]);
        for _ in 0..9 {
            let _ = t.maybe_reorder(&order);
        }
        let result = t.maybe_reorder(&order);
        assert!(result.is_some());
        assert_eq!(result.unwrap().len(), 2);
    }

    #[test]
    fn maybe_reorder_fires_on_every_multiple_of_interval() {
        let t = EmaTracker::new(0.3, 3);
        let order = names(&["p1"]);
        let mut fired = Vec::new();
        for call in 1..=9_u32 {
            if t.maybe_reorder(&order).is_some() {
                fired.push(call);
            }
        }
        assert_eq!(fired, vec![3, 6, 9]);
    }

    #[test]
    fn maybe_reorder_fast_reliable_rises_to_top() {
        let (t, order) = tracker_with_slow_p1_fast_p2();
        let result = t.maybe_reorder(&order).unwrap();
        assert_eq!(result[0], "p2");
    }

    #[test]
    fn maybe_reorder_slow_unreliable_drops() {
        let (t, order) = tracker_with_slow_p1_fast_p2();
        let result = t.maybe_reorder(&order).unwrap();
        assert_eq!(result[result.len() - 1], "p1");
    }

    #[test]
    fn maybe_reorder_ties_keep_input_order() {
        // Unrecorded providers all get the default score, so a stable sort keeps input order.
        let t = EmaTracker::new(0.3, 1);
        let order = names(&["c", "a", "b"]);
        assert_eq!(t.maybe_reorder(&order), Some(order.clone()));
    }

    #[test]
    fn maybe_reorder_interval_zero_always_none() {
        let t = EmaTracker::new(0.3, 0);
        let order = names(&["p1"]);
        for _ in 0..100 {
            assert!(
                t.maybe_reorder(&order).is_none(),
                "interval=0 should never trigger reorder"
            );
        }
    }

    #[test]
    fn record_multiple_providers_independent() {
        let t = EmaTracker::new(0.5, 100);
        t.record("p1", true, 100);
        t.record("p2", false, 200);
        let snap = t.snapshot();
        let p1 = snap.get("p1").unwrap();
        let p2 = snap.get("p2").unwrap();
        assert!(
            p1.success_ema > p2.success_ema,
            "p1 success should be higher than p2"
        );
        assert_eq!(p1.total_calls, 1);
        assert_eq!(p2.total_calls, 1);
    }

    #[test]
    fn maybe_reorder_empty_order_returns_empty() {
        let t = EmaTracker::new(0.3, 1);
        let result = t.maybe_reorder(&[]).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn record_many_failures_drives_success_ema_toward_zero() {
        let t = EmaTracker::new(0.5, 100);
        for _ in 0..20 {
            t.record("p1", false, 100);
        }
        let snap = t.snapshot();
        let s = snap.get("p1").unwrap();
        assert!(
            s.success_ema < 0.01,
            "success EMA should be near 0 after many failures, got {}",
            s.success_ema
        );
    }

    #[test]
    fn clone_shares_state() {
        let t = EmaTracker::new(0.5, 100);
        let handle = t.clone();
        handle.record("p1", true, 100);
        assert_eq!(t.snapshot().get("p1").map(|s| s.total_calls), Some(1));
    }

    #[test]
    fn snapshot_is_detached_copy() {
        let t = EmaTracker::new(0.5, 100);
        t.record("p1", true, 100);
        let before = t.snapshot();
        t.record("p1", true, 100);
        assert_eq!(before["p1"].total_calls, 1);
        assert_eq!(t.snapshot()["p1"].total_calls, 2);
    }
}