use std::{collections::HashMap, net::SocketAddr, sync::Mutex, time::Duration};
/// Smoothing factor for the latency EWMA: the weight given to the newest
/// sample (the running average keeps the remaining `1.0 - EWMA_ALPHA`).
const EWMA_ALPHA: f64 = 0.2;

/// Mutable per-upstream counters kept inside the health table.
#[derive(Debug, Default, Clone)]
struct UpstreamStat {
    /// Number of queries recorded as successful.
    successes: u64,
    /// Number of queries recorded as failed.
    failures: u64,
    /// Exponentially weighted moving average of success latency in
    /// milliseconds; `None` until the first success is recorded.
    ewma_latency_ms: Option<f64>,
    /// Message from the most recent recorded failure, if any.
    last_error: Option<String>,
}

impl UpstreamStat {
    /// Counts one success and folds `latency` into the latency EWMA.
    ///
    /// The first sample seeds the average directly; every later sample is
    /// blended as `EWMA_ALPHA * sample + (1 - EWMA_ALPHA) * previous`.
    fn record_success(&mut self, latency: Duration) {
        let sample_ms = latency.as_secs_f64() * 1000.0;
        let blended = self.ewma_latency_ms.map_or(sample_ms, |prev| {
            EWMA_ALPHA * sample_ms + (1.0 - EWMA_ALPHA) * prev
        });
        self.successes += 1;
        self.ewma_latency_ms = Some(blended);
    }

    /// Counts one failure and remembers `error` as the latest failure message.
    ///
    /// The latency average is deliberately left untouched: only successful
    /// round-trips feed it.
    fn record_failure(&mut self, error: String) {
        self.failures += 1;
        self.last_error = Some(error);
    }
}
/// Point-in-time, read-only view of one upstream's health figures.
#[derive(Debug, Clone, PartialEq)]
pub struct UpstreamHealthRow {
    /// Address of the upstream these figures belong to.
    pub addr: SocketAddr,
    /// Number of successful queries recorded for this upstream.
    pub successes: u64,
    /// Number of failed queries recorded for this upstream.
    pub failures: u64,
    /// Fraction of recorded attempts that succeeded, in `0.0..=1.0`.
    pub success_rate: f64,
    /// Exponentially weighted moving average of success latency in
    /// milliseconds; `None` if no success has been recorded.
    pub ewma_latency_ms: Option<f64>,
    /// Message of the most recent recorded failure, if any.
    pub last_error: Option<String>,
}

impl UpstreamHealthRow {
    /// Total number of recorded queries (successes plus failures).
    #[must_use]
    pub fn attempts(&self) -> u64 {
        let Self {
            successes,
            failures,
            ..
        } = *self;
        successes + failures
    }
}
/// Thread-safe registry of per-upstream query outcomes and latency.
///
/// All methods take `&self`; the per-address table sits behind a [`Mutex`],
/// so a single instance can be shared between callers.
#[derive(Debug, Default)]
pub struct UpstreamHealth {
    rows: Mutex<HashMap<SocketAddr, UpstreamStat>>,
}

impl UpstreamHealth {
    /// Creates an empty registry with no upstreams recorded.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Records a successful query to `addr` that took `latency`.
    ///
    /// The upstream's row is created on first use.
    pub fn record_success(&self, addr: SocketAddr, latency: Duration) {
        self.lock().entry(addr).or_default().record_success(latency);
    }

    /// Records a failed query to `addr`, keeping `error` as its latest
    /// failure message. The upstream's row is created on first use.
    pub fn record_failure(&self, addr: SocketAddr, error: String) {
        self.lock().entry(addr).or_default().record_failure(error);
    }

    /// Returns one row per upstream recorded so far, sorted by address.
    ///
    /// `success_rate` is `successes / (successes + failures)`, falling back to
    /// `0.0` for a row with no attempts.
    #[must_use]
    pub fn snapshot(&self) -> Vec<UpstreamHealthRow> {
        // Copy the rows out inside a scope so the lock is released before the
        // sort; concurrent `record_*` calls only wait for the copy itself.
        let mut out: Vec<UpstreamHealthRow> = {
            let rows = self.lock();
            rows.iter()
                .map(|(&addr, stat)| {
                    let attempts = stat.successes + stat.failures;
                    let success_rate = if attempts == 0 {
                        0.0
                    } else {
                        stat.successes as f64 / attempts as f64
                    };
                    UpstreamHealthRow {
                        addr,
                        successes: stat.successes,
                        failures: stat.failures,
                        success_rate,
                        ewma_latency_ms: stat.ewma_latency_ms,
                        last_error: stat.last_error.clone(),
                    }
                })
                .collect()
        };
        // Addresses are unique map keys, so an unstable sort is deterministic.
        out.sort_unstable_by_key(|r| r.addr);
        out
    }

    /// Locks the table, recovering the data if the mutex was poisoned.
    ///
    /// The critical sections in this type only bump counters, update the
    /// EWMA, or copy rows out, so the table stays usable after another thread
    /// panicked while holding the lock. Recovering (instead of panicking)
    /// keeps one such panic from cascading into every later call.
    fn lock(&self) -> std::sync::MutexGuard<'_, HashMap<SocketAddr, UpstreamStat>> {
        self.rows
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a literal socket address; panics on malformed test input.
    fn addr(s: &str) -> SocketAddr {
        s.parse().expect("valid socket addr")
    }

    #[test]
    fn success_updates_count_and_latency_for_the_right_upstream() {
        let h = UpstreamHealth::new();
        let a = addr("1.1.1.1:53");
        let b = addr("9.9.9.9:53");
        h.record_success(a, Duration::from_millis(20));
        h.record_success(a, Duration::from_millis(40));
        h.record_success(b, Duration::from_millis(100));
        let snap = h.snapshot();
        assert_eq!(snap.len(), 2);
        let row_a = snap.iter().find(|r| r.addr == a).unwrap();
        assert_eq!(row_a.successes, 2);
        assert_eq!(row_a.failures, 0);
        // 0.2 * 40 + 0.8 * 20 = 24
        let ewma = row_a.ewma_latency_ms.expect("latency recorded");
        assert!((ewma - 24.0).abs() < 1e-9, "ewma was {ewma}");
        assert!((row_a.success_rate - 1.0).abs() < 1e-9);
        let row_b = snap.iter().find(|r| r.addr == b).unwrap();
        assert_eq!(row_b.successes, 1);
        assert_eq!(row_b.ewma_latency_ms, Some(100.0));
    }

    #[test]
    fn failure_increments_count_and_records_error_without_touching_latency() {
        let h = UpstreamHealth::new();
        let a = addr("1.1.1.1:53");
        h.record_success(a, Duration::from_millis(30));
        h.record_failure(a, "upstream UDP query timed out".to_owned());
        let snap = h.snapshot();
        let row = &snap[0];
        assert_eq!(row.successes, 1);
        assert_eq!(row.failures, 1);
        assert_eq!(row.attempts(), 2);
        assert_eq!(row.ewma_latency_ms, Some(30.0));
        assert_eq!(
            row.last_error.as_deref(),
            Some("upstream UDP query timed out")
        );
        assert!((row.success_rate - 0.5).abs() < 1e-9);
    }

    #[test]
    fn failure_only_upstream_has_no_latency() {
        let h = UpstreamHealth::new();
        let a = addr("8.8.8.8:53");
        h.record_failure(a, "boom".to_owned());
        let snap = h.snapshot();
        let row = &snap[0];
        assert_eq!(row.successes, 0);
        assert_eq!(row.failures, 1);
        assert_eq!(row.ewma_latency_ms, None);
        assert_eq!(row.success_rate, 0.0);
    }

    #[test]
    fn empty_snapshot_is_empty() {
        assert!(UpstreamHealth::new().snapshot().is_empty());
    }

    // Pins the documented ordering contract of `snapshot` independently of
    // HashMap iteration order.
    #[test]
    fn snapshot_is_sorted_by_address() {
        let h = UpstreamHealth::new();
        for s in ["9.9.9.9:53", "1.1.1.1:53", "8.8.8.8:53"] {
            h.record_success(addr(s), Duration::from_millis(10));
        }
        let order: Vec<SocketAddr> = h.snapshot().iter().map(|r| r.addr).collect();
        assert_eq!(
            order,
            [addr("1.1.1.1:53"), addr("8.8.8.8:53"), addr("9.9.9.9:53")]
        );
    }

    // `last_error` is overwritten by each failure and is not cleared by a
    // subsequent success.
    #[test]
    fn last_error_tracks_most_recent_failure() {
        let h = UpstreamHealth::new();
        let a = addr("1.1.1.1:53");
        h.record_failure(a, "first".to_owned());
        h.record_failure(a, "second".to_owned());
        h.record_success(a, Duration::from_millis(5));
        let snap = h.snapshot();
        assert_eq!(snap.len(), 1);
        let row = &snap[0];
        assert_eq!(row.failures, 2);
        assert_eq!(row.successes, 1);
        assert_eq!(row.last_error.as_deref(), Some("second"));
    }
}