use dashmap::DashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
/// Number of tracked domains at which the least-recently-used entry is evicted before a
/// new domain is inserted.
const MAX_ENTRIES: usize = 2 * 1024;
/// Consecutive failures at which a guard built with [`WaitGuard::new`] flags a domain.
const DEFAULT_FAILURE_THRESHOLD: u32 = 3;
/// Lower bound for reduced timeouts; base timeouts at or below it are returned unchanged.
const MIN_TIMEOUT: Duration = Duration::from_micros(500_000);
/// Per-domain bookkeeping stored inside a [`WaitGuard`].
#[derive(Debug)]
struct DomainEntry {
    /// Failures recorded since the last success, or since the entry was created.
    consecutive_failures: u32,
    /// Value of the guard's access counter at the most recent update; a smaller value
    /// means the entry was used less recently.
    last_access: u64,
}

/// Tracks consecutive request failures per domain and shortens timeouts for domains
/// that keep failing.
///
/// Every method takes `&self`. State lives in a [`DashMap`] plus an atomic counter, so one
/// instance can be shared between threads (see [`global_wait_guard`]).
#[derive(Debug)]
pub struct WaitGuard {
    /// Tracked domains, kept to roughly `MAX_ENTRIES` by LRU eviction.
    entries: DashMap<Box<str>, DomainEntry>,
    /// Monotonic logical clock used to order entries for LRU eviction.
    access_counter: AtomicU64,
    /// Consecutive failures at which a domain is flagged; always at least 1.
    failure_threshold: u32,
}
impl WaitGuard {
    /// Creates an empty guard that flags a domain after [`DEFAULT_FAILURE_THRESHOLD`]
    /// consecutive failures.
    pub fn new() -> Self {
        Self::with_threshold(DEFAULT_FAILURE_THRESHOLD)
    }

    /// Creates an empty guard that flags a domain after `threshold` consecutive failures.
    ///
    /// A `threshold` of `0` is clamped to `1`, because the threshold is used as a divisor
    /// in [`adjusted_timeout`](Self::adjusted_timeout).
    pub fn with_threshold(threshold: u32) -> Self {
        Self {
            entries: DashMap::with_capacity(256),
            access_counter: AtomicU64::new(0),
            failure_threshold: threshold.max(1),
        }
    }

    /// Takes the next value of the logical clock used for LRU ordering.
    fn next_tick(&self) -> u64 {
        // Only distinct, increasing values are needed; no other memory is synchronized.
        self.access_counter.fetch_add(1, Ordering::Relaxed)
    }

    /// Records one failed request against `domain`.
    ///
    /// Increments the domain's consecutive-failure count (saturating at `u32::MAX`) and
    /// marks it as recently used. An untracked domain is inserted with a count of 1. If
    /// the table already holds `MAX_ENTRIES` entries, the least-recently-used entry is
    /// evicted first. Under concurrent inserts the cap is a soft bound.
    pub fn record_bad(&self, domain: &str) {
        let tick = self.next_tick();

        // Fast path: the domain is already tracked, so no key allocation is needed.
        if let Some(mut entry) = self.entries.get_mut(domain) {
            entry.consecutive_failures = entry.consecutive_failures.saturating_add(1);
            entry.last_access = tick;
            return;
        }

        // Evict while holding no guard into the map. `evict_lru` iterates every shard,
        // and DashMap may deadlock if a thread holding a reference calls into the map again.
        if self.entries.len() >= MAX_ENTRIES {
            self.evict_lru();
        }

        // Another thread may have inserted `domain` since the lookup above. The entry API
        // increments that entry instead of resetting its count to 1.
        self.entries
            .entry(domain.into())
            .and_modify(|entry| {
                entry.consecutive_failures = entry.consecutive_failures.saturating_add(1);
                entry.last_access = tick;
            })
            .or_insert(DomainEntry {
                consecutive_failures: 1,
                last_access: tick,
            });
    }

    /// Records a successful request: resets `domain`'s failure count to zero.
    ///
    /// Untracked domains are ignored, so successes never grow the table.
    pub fn record_good(&self, domain: &str) {
        if let Some(mut entry) = self.entries.get_mut(domain) {
            entry.consecutive_failures = 0;
            entry.last_access = self.next_tick();
        }
    }

    /// Returns the timeout to use for `domain`, shortened if the domain keeps failing.
    ///
    /// - `base` is returned unchanged if it is already `<= MIN_TIMEOUT`, if the domain is
    ///   untracked, or if the domain has fewer than `failure_threshold` consecutive failures.
    /// - Otherwise `base` is halved once per full multiple of `failure_threshold`, capped at
    ///   10 halvings (`base / 1024`). The result is clamped to at least `MIN_TIMEOUT`.
    pub fn adjusted_timeout(&self, domain: &str, base: Duration) -> Duration {
        // Upper bound on halvings; it also keeps the shift below within range.
        const MAX_HALVINGS: u32 = 10;

        // Never raise a base that is already at or below the floor.
        if base <= MIN_TIMEOUT {
            return base;
        }
        // The read guard is dropped as soon as the count is copied out.
        let failures = match self.entries.get(domain) {
            Some(entry) => entry.consecutive_failures,
            None => return base,
        };
        if failures < self.failure_threshold {
            return base;
        }
        // `failure_threshold >= 1` (see `with_threshold`), so the division cannot panic.
        let halvings = (failures / self.failure_threshold).min(MAX_HALVINGS);
        (base / (1u32 << halvings)).max(MIN_TIMEOUT)
    }

    /// Returns `true` once `domain` has at least `failure_threshold` consecutive failures.
    pub fn is_flagged(&self, domain: &str) -> bool {
        self.entries
            .get(domain)
            .is_some_and(|entry| entry.consecutive_failures >= self.failure_threshold)
    }

    /// Number of tracked domains. This is a snapshot; concurrent calls may change it.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` if no domain is tracked. This is a snapshot; concurrent calls may
    /// change it.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Removes the entry with the smallest `last_access` tick, if any.
    ///
    /// Best-effort: if that entry is refreshed by another thread between the scan and the
    /// removal, it is no longer least recently used and is kept.
    fn evict_lru(&self) {
        // Copy out the tick and key so no shard guard outlives its loop iteration.
        let mut oldest: Option<(u64, Box<str>)> = None;
        for entry in self.entries.iter() {
            let tick = entry.value().last_access;
            let is_older = match &oldest {
                Some((oldest_tick, _)) => tick < *oldest_tick,
                None => true,
            };
            if is_older {
                oldest = Some((tick, entry.key().clone()));
            }
        }
        if let Some((tick, key)) = oldest {
            self.entries
                .remove_if(&key, |_, entry| entry.last_access == tick);
        }
    }
}
impl Default for WaitGuard {
fn default() -> Self {
Self::new()
}
}
/// Process-wide [`WaitGuard`], built with [`WaitGuard::new`] on first access.
static GLOBAL_WAIT_GUARD: std::sync::LazyLock<WaitGuard> =
    std::sync::LazyLock::new(WaitGuard::new);

/// Returns the process-wide shared [`WaitGuard`], initializing it on the first call.
#[inline]
pub fn global_wait_guard() -> &'static WaitGuard {
    std::sync::LazyLock::force(&GLOBAL_WAIT_GUARD)
}
#[cfg(test)]
mod tests {
    //! Unit tests for `WaitGuard`: thresholds, timeout reduction, LRU eviction,
    //! concurrency, and the global accessor.
    use super::*;

    #[test]
    fn test_new_domain_not_flagged() {
        let guard = WaitGuard::new();
        assert!(!guard.is_flagged("example.com"));
        assert_eq!(
            guard.adjusted_timeout("example.com", Duration::from_secs(30)),
            Duration::from_secs(30)
        );
    }

    #[test]
    fn test_below_threshold_no_reduction() {
        let guard = WaitGuard::new();
        guard.record_bad("example.com");
        guard.record_bad("example.com");
        assert!(!guard.is_flagged("example.com"));
        assert_eq!(
            guard.adjusted_timeout("example.com", Duration::from_secs(30)),
            Duration::from_secs(30)
        );
    }

    #[test]
    fn test_at_threshold_halves_timeout() {
        let guard = WaitGuard::new();
        for _ in 0..3 {
            guard.record_bad("example.com");
        }
        assert!(guard.is_flagged("example.com"));
        assert_eq!(
            guard.adjusted_timeout("example.com", Duration::from_secs(30)),
            Duration::from_secs(15)
        );
    }

    #[test]
    fn test_double_threshold_quarters_timeout() {
        // 6 failures = 2 x threshold, so two halvings: 30s / 4 = 7.5s.
        let guard = WaitGuard::new();
        for _ in 0..6 {
            guard.record_bad("example.com");
        }
        assert_eq!(
            guard.adjusted_timeout("example.com", Duration::from_secs(30)),
            Duration::from_secs(7) + Duration::from_millis(500)
        );
    }

    #[test]
    fn test_minimum_timeout_floor() {
        let guard = WaitGuard::new();
        for _ in 0..100 {
            guard.record_bad("example.com");
        }
        assert_eq!(
            guard.adjusted_timeout("example.com", Duration::from_secs(30)),
            MIN_TIMEOUT
        );
    }

    #[test]
    fn test_record_good_clears_failures() {
        let guard = WaitGuard::new();
        for _ in 0..5 {
            guard.record_bad("example.com");
        }
        assert!(guard.is_flagged("example.com"));
        guard.record_good("example.com");
        assert!(!guard.is_flagged("example.com"));
        assert_eq!(
            guard.adjusted_timeout("example.com", Duration::from_secs(30)),
            Duration::from_secs(30)
        );
    }

    #[test]
    fn test_record_good_noop_for_untracked() {
        let guard = WaitGuard::new();
        guard.record_good("never-seen.com");
        assert!(guard.is_empty());
    }

    #[test]
    fn test_lru_eviction_at_capacity() {
        let guard = WaitGuard::new();
        for i in 0..MAX_ENTRIES {
            guard.record_bad(&format!("domain-{i}.com"));
        }
        assert_eq!(guard.len(), MAX_ENTRIES);
        // `domain-0.com` holds the smallest tick, so it is the eviction victim.
        guard.record_bad("new-domain.com");
        assert_eq!(guard.len(), MAX_ENTRIES);
        assert!(!guard.entries.contains_key("domain-0.com"));
        assert!(guard.entries.contains_key("new-domain.com"));
    }

    #[test]
    fn test_custom_threshold() {
        let guard = WaitGuard::with_threshold(1);
        guard.record_bad("fast-flag.com");
        assert!(guard.is_flagged("fast-flag.com"));
        assert_eq!(
            guard.adjusted_timeout("fast-flag.com", Duration::from_secs(30)),
            Duration::from_secs(15)
        );
    }

    #[test]
    fn test_threshold_min_one() {
        let guard = WaitGuard::with_threshold(0);
        assert_eq!(guard.failure_threshold, 1);
    }

    #[test]
    fn test_saturating_add_no_overflow() {
        let guard = WaitGuard::new();
        guard.entries.insert(
            "overflow.com".into(),
            DomainEntry {
                consecutive_failures: u32::MAX - 1,
                last_access: 0,
            },
        );
        guard.record_bad("overflow.com");
        guard.record_bad("overflow.com");
        let entry = guard.entries.get("overflow.com").unwrap();
        assert_eq!(entry.consecutive_failures, u32::MAX);
    }

    #[test]
    fn test_multiple_domains_independent() {
        let guard = WaitGuard::new();
        for _ in 0..5 {
            guard.record_bad("bad.com");
        }
        guard.record_bad("ok.com");
        assert!(guard.is_flagged("bad.com"));
        assert!(!guard.is_flagged("ok.com"));
    }

    #[test]
    fn test_global_singleton_accessible() {
        // Every call must hand out the same instance.
        assert!(std::ptr::eq(global_wait_guard(), global_wait_guard()));
    }

    #[test]
    fn test_zero_base_never_inflated() {
        let guard = WaitGuard::new();
        for _ in 0..10 {
            guard.record_bad("exhausted.com");
        }
        assert!(guard.is_flagged("exhausted.com"));
        assert_eq!(
            guard.adjusted_timeout("exhausted.com", Duration::ZERO),
            Duration::ZERO
        );
    }

    #[test]
    fn test_small_base_never_inflated() {
        let guard = WaitGuard::new();
        for _ in 0..10 {
            guard.record_bad("small.com");
        }
        let base = Duration::from_millis(200);
        assert_eq!(guard.adjusted_timeout("small.com", base), base);
    }

    #[test]
    fn test_base_at_min_timeout_returned_unchanged() {
        let guard = WaitGuard::new();
        for _ in 0..10 {
            guard.record_bad("floor.com");
        }
        assert_eq!(
            guard.adjusted_timeout("floor.com", MIN_TIMEOUT),
            MIN_TIMEOUT
        );
    }

    #[test]
    fn test_sequential_record_bad_accumulates() {
        let guard = WaitGuard::new();
        guard.record_bad("race.com");
        guard.record_bad("race.com");
        let entry = guard.entries.get("race.com").unwrap();
        assert_eq!(entry.consecutive_failures, 2);
    }

    #[test]
    fn test_concurrent_record_bad_does_not_reset() {
        const THREADS: u32 = 4;
        const CALLS_PER_THREAD: u32 = 250;
        let guard = WaitGuard::new();
        // Seed the entry so every concurrent call updates the existing entry; no
        // increment may be lost or reset while threads contend on the same shard.
        guard.record_bad("race.com");
        std::thread::scope(|scope| {
            for _ in 0..THREADS {
                scope.spawn(|| {
                    for _ in 0..CALLS_PER_THREAD {
                        guard.record_bad("race.com");
                    }
                });
            }
        });
        let entry = guard.entries.get("race.com").unwrap();
        assert_eq!(entry.consecutive_failures, 1 + THREADS * CALLS_PER_THREAD);
    }
}