use std::{
collections::HashMap,
sync::{Arc, Mutex, RwLock},
time::Duration,
};
/// Soft cap on the number of keys held by a [`QuotaRegistry`]. Once the map
/// holds this many entries, `state_for_key` runs a pruning pass before adding
/// another; the map may still grow past it if nothing can be pruned.
const MAX_QUOTA_ENTRIES: usize = 100;
/// Registry mapping a string key to a shared [`QuotaState`], so that every
/// caller asking for the same key on the same registry observes one backoff
/// state.
///
/// Entries are created lazily by [`QuotaRegistry::state_for_key`]; separate
/// registries never share state, even for identical keys.
#[derive(Debug, Default)]
pub struct QuotaRegistry {
    // Key -> shared state. Read-locked for lookups of existing keys,
    // write-locked to insert new keys or prune.
    inner: RwLock<HashMap<String, Arc<QuotaState>>>,
}
impl QuotaRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self {
            inner: RwLock::new(HashMap::new()),
        }
    }

    /// Returns the shared [`QuotaState`] for `key`, creating it on first use.
    ///
    /// Repeated calls with the same `key` return clones of the same `Arc` for
    /// as long as the entry remains in the registry.
    ///
    /// When a new key is about to be inserted and the registry already holds
    /// `MAX_QUOTA_ENTRIES` entries, entries that are idle (no consecutive 429s)
    /// *and* no longer referenced outside the registry are pruned. An entry
    /// some caller still holds is never evicted: doing so would hand the next
    /// caller a fresh, independent state and silently split the backoff for
    /// that key. If nothing can be pruned the registry grows past the cap and a
    /// warning is logged.
    ///
    /// A poisoned lock is recovered (and logged) rather than propagated.
    #[must_use]
    pub fn state_for_key(&self, key: &str) -> Arc<QuotaState> {
        // Fast path: shared read lock for keys that already exist.
        if let Ok(map) = self.inner.read()
            && let Some(state) = map.get(key)
        {
            return Arc::clone(state);
        }

        let mut map = self.inner.write().unwrap_or_else(|e| {
            tracing::error!("QuotaRegistry poisoned: {e}");
            e.into_inner()
        });

        // Another thread may have inserted `key` between our read and write
        // locks (or the read lock was poisoned). Reuse it rather than pruning,
        // which could otherwise evict the entry that thread just handed out.
        if let Some(state) = map.get(key) {
            return Arc::clone(state);
        }

        if map.len() >= MAX_QUOTA_ENTRIES {
            // While we hold the write lock nobody can obtain a new clone from
            // the map, so a strong count of 1 means only the registry refers to
            // this state and dropping it cannot split shared backoff.
            map.retain(|_, state| {
                state.consecutive_429_count() != 0 || Arc::strong_count(state) > 1
            });
            if map.len() >= MAX_QUOTA_ENTRIES {
                tracing::warn!(
                    entries = map.len(),
                    "Quota registry still at capacity after pruning idle entries"
                );
            }
        }

        let state = Arc::new(QuotaState::new());
        map.insert(key.to_owned(), Arc::clone(&state));
        state
    }
}
/// Mutable backoff bookkeeping, guarded by the mutex inside [`QuotaState`].
#[derive(Debug)]
struct QuotaInner {
    /// Quota hits recorded since the last `record_success` (which resets it to 0).
    consecutive_429s: u32,
    /// Deadline that `wait_for_quota` sleeps until; `None` when no backoff is armed.
    backoff_until: Option<tokio::time::Instant>,
}
/// Mutex-protected quota/backoff tracker, shared via `Arc` among all callers
/// that use the same key (see [`QuotaRegistry::state_for_key`]).
#[derive(Debug)]
pub struct QuotaState {
    // All state lives behind one std mutex; it is never held across an `.await`.
    inner: Mutex<QuotaInner>,
}
impl Default for QuotaState {
fn default() -> Self {
Self::new()
}
}
impl QuotaState {
    /// Creates a state with no recorded quota hits and no active backoff.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            inner: Mutex::new(QuotaInner {
                consecutive_429s: 0,
                backoff_until: None,
            }),
        }
    }

    /// Records a quota rejection and re-arms the backoff deadline.
    ///
    /// Increments the consecutive-hit counter, then sets the deadline to now
    /// plus the larger of `retry_after` and the exponential backoff (with
    /// jitter) for the new count. The new deadline replaces any previous one.
    ///
    /// If the mutex is poisoned the hit is logged and dropped.
    pub fn record_quota_hit(&self, retry_after: Duration) {
        let mut inner = match self.inner.lock() {
            Ok(guard) => guard,
            Err(e) => {
                tracing::error!(
                    error = %e,
                    "QuotaState mutex poisoned in record_quota_hit — skipping"
                );
                return;
            }
        };
        // Saturate rather than `+= 1`: an overflow panic while the guard is held
        // would poison the mutex for every other user of this state.
        inner.consecutive_429s = inner.consecutive_429s.saturating_add(1);
        let count = inner.consecutive_429s;
        let now = tokio::time::Instant::now();
        let exp_backoff = exponential_backoff_with_jitter(count);
        let effective_backoff = retry_after.max(exp_backoff);
        tracing::warn!(
            consecutive_429s = count,
            retry_after_ms = u64::try_from(retry_after.as_millis()).unwrap_or_else(|e| {
                tracing::warn!("Int conversion failed: {}", e);
                u64::MAX
            }),
            effective_backoff_ms =
                u64::try_from(effective_backoff.as_millis()).unwrap_or_else(|e| {
                    tracing::warn!("Int conversion failed: {}", e);
                    u64::MAX
                }),
            "Quota hit — backing off"
        );
        // `Instant + Duration` panics on overflow, and `retry_after` is not
        // bounded here, so an absurdly large value would panic under the lock.
        // Fall back to the (capped) exponential backoff instead.
        let deadline = now.checked_add(effective_backoff).unwrap_or_else(|| {
            tracing::warn!(
                consecutive_429s = count,
                "Quota backoff deadline overflowed — falling back to exponential backoff"
            );
            now + exp_backoff
        });
        inner.backoff_until = Some(deadline);
    }

    /// Sleeps until the current backoff deadline, if one is set and still in
    /// the future; otherwise returns immediately.
    ///
    /// The deadline is read once and the lock released before sleeping, so
    /// hits recorded during the sleep do not extend this particular wait. If
    /// the mutex is poisoned, returns immediately without backing off.
    pub async fn wait_for_quota(&self) {
        let (deadline, count) = {
            let inner = match self.inner.lock() {
                Ok(guard) => guard,
                Err(e) => {
                    tracing::error!(
                        error = %e,
                        "QuotaState mutex poisoned in wait_for_quota — proceeding without backoff"
                    );
                    return;
                }
            };
            (inner.backoff_until, inner.consecutive_429s)
        };
        if let Some(until) = deadline {
            let now = tokio::time::Instant::now();
            if until > now {
                let wait = until - now;
                tracing::warn!(
                    wait_ms = u64::try_from(wait.as_millis()).unwrap_or_else(|e| {
                        tracing::warn!("Int conversion failed: {}", e);
                        u64::MAX
                    }),
                    consecutive_429s = count,
                    "Quota backoff — waiting"
                );
                tokio::time::sleep(wait).await;
                tracing::info!("Quota backoff complete — resuming operations");
            }
        }
    }

    /// Clears the consecutive-hit counter and the backoff deadline after a
    /// successful operation. Does nothing (and logs nothing) if no hits are
    /// currently recorded, or if the mutex is poisoned.
    pub fn record_success(&self) {
        let mut inner = match self.inner.lock() {
            Ok(guard) => guard,
            Err(e) => {
                tracing::error!(
                    error = %e,
                    "QuotaState mutex poisoned in record_success — skipping reset"
                );
                return;
            }
        };
        if inner.consecutive_429s > 0 {
            tracing::info!(
                previous_consecutive_429s = inner.consecutive_429s,
                "Quota state reset after successful operation"
            );
            inner.consecutive_429s = 0;
            inner.backoff_until = None;
        }
    }

    /// Returns the number of quota hits since the last success, or 0 if the
    /// mutex is poisoned.
    #[must_use]
    pub fn consecutive_429_count(&self) -> u32 {
        match self.inner.lock() {
            Ok(guard) => guard.consecutive_429s,
            Err(e) => {
                tracing::error!(
                    error = %e,
                    "QuotaState mutex poisoned in consecutive_429_count — returning 0"
                );
                0
            }
        }
    }
}
use crate::error::MAX_BACKOFF_SECS;
/// Exclusive upper bound, in milliseconds, on the random jitter added to a backoff.
const MAX_JITTER_MS: u64 = 500;

/// Draws a random jitter value in the half-open range `[0, MAX_JITTER_MS)` milliseconds.
fn jitter_ms() -> u64 {
    let window = 0..MAX_JITTER_MS;
    fastrand::u64(window)
}
/// Returns the exponential backoff for the `attempt`-th consecutive quota hit,
/// with random jitter added.
///
/// The base delay is `2^attempt` seconds (an `attempt` of 0 is treated as 1),
/// capped at `MAX_BACKOFF_SECS`. Then `jitter_ms()` milliseconds
/// (`< MAX_JITTER_MS`) are added on top.
fn exponential_backoff_with_jitter(attempt: u32) -> Duration {
    // Use `checked_pow`, not `checked_shl`. `checked_shl` only rejects shift
    // amounts >= 64 and silently drops bits shifted out. `2u64 << 63` is
    // `Some(0)`, so attempt 64 previously got a zero-second base delay.
    let base_secs = 2u64
        .checked_pow(attempt.max(1))
        .map_or(MAX_BACKOFF_SECS, |secs| secs.min(MAX_BACKOFF_SECS));
    // Sum as `Duration`s so a large cap cannot overflow a u64 millisecond count.
    Duration::from_secs(base_secs) + Duration::from_millis(jitter_ms())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Hits increment the counter; a success resets it to zero.
    #[tokio::test]
    async fn test_quota_state_records_and_resets() {
        let state = QuotaState::new();
        assert_eq!(state.consecutive_429_count(), 0);
        state.record_quota_hit(Duration::from_millis(10));
        assert_eq!(state.consecutive_429_count(), 1);
        state.record_quota_hit(Duration::from_millis(10));
        assert_eq!(state.consecutive_429_count(), 2);
        state.record_success();
        assert_eq!(state.consecutive_429_count(), 0);
    }

    /// With no recorded hits there is no deadline, so waiting is (nearly) instant.
    #[tokio::test]
    async fn test_wait_for_quota_returns_immediately_when_no_backoff() {
        let state = QuotaState::new();
        let start = tokio::time::Instant::now();
        state.wait_for_quota().await;
        let elapsed = start.elapsed();
        assert!(elapsed < Duration::from_millis(50));
    }

    /// Once the paused clock has moved past the deadline, waiting does not sleep.
    #[tokio::test]
    async fn test_backoff_timing() {
        tokio::time::pause();
        let state = QuotaState::new();
        state.record_quota_hit(Duration::from_millis(1));
        assert!(state.inner.lock().unwrap().backoff_until.is_some());
        // Attempt 1 backs off 2 s plus < 500 ms of jitter, so 3 s is past it.
        tokio::time::advance(Duration::from_secs(3)).await;
        let before = tokio::time::Instant::now();
        state.wait_for_quota().await;
        assert_eq!(
            tokio::time::Instant::now(),
            before,
            "wait_for_quota should not sleep once the deadline has passed"
        );
    }

    /// Base delay doubles per attempt and is capped.
    #[test]
    fn test_exponential_backoff_progression() {
        let d1 = exponential_backoff_with_jitter(1);
        let d2 = exponential_backoff_with_jitter(2);
        let d3 = exponential_backoff_with_jitter(3);
        let d7 = exponential_backoff_with_jitter(7);
        assert!(d1 >= Duration::from_secs(2) && d1 < Duration::from_millis(2500));
        assert!(d2 >= Duration::from_secs(4) && d2 < Duration::from_millis(4500));
        assert!(d3 >= Duration::from_secs(8) && d3 < Duration::from_millis(8500));
        assert!(d7 >= Duration::from_mins(2) && d7 < Duration::from_millis(120_500));
    }

    /// Two handles to one `Arc<QuotaState>` (e.g. two agents) see each other's
    /// hits and resets.
    #[tokio::test]
    async fn test_multiple_agents_respect_shared_quota_state() {
        tokio::time::pause();
        let shared = Arc::new(QuotaState::new());
        let agent_a = Arc::clone(&shared);
        let agent_b = Arc::clone(&shared);

        agent_a.record_quota_hit(Duration::from_millis(100));
        assert_eq!(agent_b.consecutive_429_count(), 1);
        assert!(agent_b.inner.lock().unwrap().backoff_until.is_some());

        agent_b.record_quota_hit(Duration::from_millis(100));
        assert_eq!(agent_a.consecutive_429_count(), 2);

        tokio::time::advance(Duration::from_mins(2)).await;
        agent_a.record_success();
        assert_eq!(agent_b.consecutive_429_count(), 0);
        assert!(agent_b.inner.lock().unwrap().backoff_until.is_none());
    }

    #[test]
    fn test_quota_state_default() {
        let state = QuotaState::default();
        assert_eq!(state.consecutive_429_count(), 0);
    }

    #[tokio::test]
    async fn test_double_success_reset_is_idempotent() {
        let state = QuotaState::new();
        state.record_quota_hit(Duration::from_millis(10));
        assert_eq!(state.consecutive_429_count(), 1);
        state.record_success();
        assert_eq!(state.consecutive_429_count(), 0);
        state.record_success();
        assert_eq!(state.consecutive_429_count(), 0);
    }

    #[test]
    fn test_jitter_is_nondeterministic_for_same_attempt() {
        let values: Vec<u64> = (0..10).map(|_i| jitter_ms()).collect();
        let distinct: std::collections::HashSet<u64> = values.iter().copied().collect();
        assert!(
            distinct.len() >= 2,
            "Expected at least 2 distinct jitter values, got {values:?}"
        );
    }

    #[test]
    fn test_jitter_bounded() {
        for _ in 0..100 {
            let j = jitter_ms();
            assert!(j < MAX_JITTER_MS, "jitter {j} should be < {MAX_JITTER_MS}");
        }
    }

    /// Each attempt lands in `[base, base + MAX_JITTER_MS)`, with the expected
    /// base computed independently of the implementation (`2^attempt`, capped).
    #[test]
    fn test_backoff_with_jitter_bounded() {
        for attempt in 1..=20 {
            let d = exponential_backoff_with_jitter(attempt);
            let base_secs = 2u64
                .checked_pow(attempt)
                .map_or(MAX_BACKOFF_SECS, |s| s.min(MAX_BACKOFF_SECS));
            let base = Duration::from_secs(base_secs);
            let max_with_jitter = Duration::from_millis(base_secs * 1000 + MAX_JITTER_MS);
            assert!(
                d >= base && d < max_with_jitter,
                "attempt {attempt}: {d:?} not in [{base:?}, {max_with_jitter:?})"
            );
        }
    }

    #[tokio::test]
    async fn test_quota_counter_increases_monotonically() {
        let state = QuotaState::new();
        for expected in 1..=5 {
            state.record_quota_hit(Duration::from_millis(1));
            assert_eq!(state.consecutive_429_count(), expected);
        }
    }

    /// A `retry_after` longer than the exponential backoff wins.
    #[tokio::test]
    async fn test_record_quota_hit_uses_max_of_retry_and_exponential() {
        tokio::time::pause();
        let state = QuotaState::new();
        let large_retry = Duration::from_mins(5);
        state.record_quota_hit(large_retry);
        let until = {
            let inner = state.inner.lock().unwrap();
            assert!(inner.backoff_until.is_some());
            inner.backoff_until.unwrap()
        };
        let now = tokio::time::Instant::now();
        assert!(until >= now + Duration::from_secs(290));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn stress_quota_storm_concurrent_429s() {
        let state = Arc::new(QuotaState::new());
        let tasks: u32 = 10;
        let hits_per_task: u32 = 5;
        let mut handles = Vec::new();
        for _ in 0..tasks {
            let qs = Arc::clone(&state);
            handles.push(tokio::spawn(async move {
                for _ in 0..hits_per_task {
                    qs.record_quota_hit(Duration::from_millis(10));
                }
            }));
        }
        for h in handles {
            h.await.expect("task should complete");
        }
        assert_eq!(state.consecutive_429_count(), tasks * hits_per_task);
        assert!(
            state.inner.lock().unwrap().backoff_until.is_some(),
            "backoff_until should be set after storm"
        );
        state.record_success();
        assert_eq!(state.consecutive_429_count(), 0);
        assert!(
            state.inner.lock().unwrap().backoff_until.is_none(),
            "backoff_until should be cleared after success"
        );
    }

    #[tokio::test]
    async fn test_quota_exhaustion_backoff_caps() {
        tokio::time::pause();
        let state = QuotaState::new();
        let hit_count = 20u32;
        for _ in 0..hit_count {
            state.record_quota_hit(Duration::from_millis(1));
        }
        assert_eq!(state.consecutive_429_count(), hit_count);
        let until = {
            let inner = state.inner.lock().unwrap();
            assert!(
                inner.backoff_until.is_some(),
                "backoff_until should be set after many hits"
            );
            inner.backoff_until.unwrap()
        };
        let now = tokio::time::Instant::now();
        let backoff_duration = until - now;
        let max_allowed = Duration::from_millis(MAX_BACKOFF_SECS * 1000 + MAX_JITTER_MS);
        assert!(
            backoff_duration <= max_allowed,
            "backoff {backoff_duration:?} exceeds cap {max_allowed:?}"
        );
        tokio::time::advance(Duration::from_secs(MAX_BACKOFF_SECS + 1)).await;
        state.wait_for_quota().await;
        state.record_success();
        assert_eq!(state.consecutive_429_count(), 0);
    }

    #[test]
    fn same_key_returns_same_quota_state() {
        let registry = QuotaRegistry::new();
        let a = registry.state_for_key("test-key-same");
        let b = registry.state_for_key("test-key-same");
        assert!(Arc::ptr_eq(&a, &b), "Same key should return the same Arc");
    }

    #[test]
    fn different_keys_return_independent_quota_states() {
        let registry = QuotaRegistry::new();
        let a = registry.state_for_key("test-key-alpha");
        let b = registry.state_for_key("test-key-beta");
        assert!(
            !Arc::ptr_eq(&a, &b),
            "Different keys should return different Arcs"
        );
    }

    #[test]
    fn different_keys_have_independent_backoff() {
        let registry = QuotaRegistry::new();
        let a = registry.state_for_key("test-key-independent-a");
        let b = registry.state_for_key("test-key-independent-b");
        a.record_quota_hit(Duration::from_secs(10));
        assert!(a.consecutive_429_count() > 0);
        assert_eq!(b.consecutive_429_count(), 0);
    }

    #[test]
    fn different_registries_are_fully_independent() {
        let registry_a = QuotaRegistry::new();
        let registry_b = QuotaRegistry::new();
        let state_a = registry_a.state_for_key("shared-key");
        let state_b = registry_b.state_for_key("shared-key");
        assert!(
            !Arc::ptr_eq(&state_a, &state_b),
            "Different registries should not share state even for the same key"
        );
        state_a.record_quota_hit(Duration::from_secs(10));
        assert!(state_a.consecutive_429_count() > 0);
        assert_eq!(state_b.consecutive_429_count(), 0);
    }

    /// Filling the registry to capacity with idle keys must not prune a key that
    /// is currently backing off.
    #[test]
    fn registry_pruning_keeps_backed_off_entries() {
        let registry = QuotaRegistry::new();
        let hot = registry.state_for_key("hot");
        hot.record_quota_hit(Duration::from_secs(10));
        drop(hot);
        for i in 0..MAX_QUOTA_ENTRIES {
            let _ = registry.state_for_key(&format!("idle-{i}"));
        }
        let again = registry.state_for_key("hot");
        assert_eq!(
            again.consecutive_429_count(),
            1,
            "backed-off entry should survive pruning"
        );
    }
}