use std::collections::HashMap;
use std::hash::Hash;
use std::time::Duration;
use tokio::time::Instant;
use crate::config::GlobalRng;
/// Deterministic exponential backoff schedule: `base * 2^attempt`, never exceeding `max`.
#[derive(Debug, Clone)]
pub struct ExponentialBackoff {
    /// Delay returned for the first attempt (`attempt == 0`).
    base: Duration,
    /// Ceiling applied to every delay this schedule produces.
    max: Duration,
}

impl ExponentialBackoff {
    /// Largest shift applied to the multiplier: `1u32 << 31` is the biggest power of two
    /// representable in the `u32` that `Duration::saturating_mul` accepts.
    const MAX_EXPONENT: u32 = u32::BITS - 1;

    /// Creates a schedule starting at `base` and capped at `max`.
    pub const fn new(base: Duration, max: Duration) -> Self {
        Self { base, max }
    }

    /// Returns the delay for a zero-indexed `attempt`: `base * 2^attempt`, clamped to `max`.
    ///
    /// The exponent is only limited to keep the shift in range; the product saturates
    /// instead of overflowing, so the configured `max` is always reachable regardless of
    /// how small `base` is relative to it.
    #[inline]
    pub fn delay(&self, attempt: u32) -> Duration {
        let multiplier = 1u32 << attempt.min(Self::MAX_EXPONENT);
        self.base.saturating_mul(multiplier).min(self.max)
    }

    /// Returns the delay after `failures` consecutive failures (one-indexed).
    ///
    /// Zero failures means no delay; the first failure maps to [`Self::delay`]`(0)`, i.e. `base`.
    #[inline]
    pub fn delay_for_failures(&self, failures: u32) -> Duration {
        match failures.checked_sub(1) {
            Some(attempt) => self.delay(attempt),
            None => Duration::ZERO,
        }
    }

    /// Initial delay of the schedule.
    #[inline]
    // Kept for API completeness; it may have no non-test callers in this crate.
    #[allow(dead_code)]
    pub fn base(&self) -> Duration {
        self.base
    }

    /// Upper bound of the schedule.
    #[inline]
    pub fn max(&self) -> Duration {
        self.max
    }
}

impl Default for ExponentialBackoff {
    /// 5 second initial delay, capped at 5 minutes.
    fn default() -> Self {
        Self::new(Duration::from_secs(5), Duration::from_secs(300))
    }
}
/// Per-key failure bookkeeping used by `TrackedBackoff`.
#[derive(Debug, Clone)]
struct BackoffEntry {
    /// Failures recorded since the key was last cleared (saturates at `u32::MAX`).
    consecutive_failures: u32,
    /// When the most recent failure was recorded; drives eviction and stale cleanup.
    last_failure: Instant,
    /// The key counts as "in backoff" while `now < retry_after`.
    retry_after: Instant,
}

/// Tracks exponential backoff independently per key, holding at most `max_entries` keys.
#[derive(Debug)]
pub struct TrackedBackoff<K: Eq + Hash> {
    entries: HashMap<K, BackoffEntry>,
    config: ExponentialBackoff,
    /// Capacity bound; exceeding it evicts the entries with the oldest `last_failure`.
    max_entries: usize,
}

impl<K: Eq + Hash + Clone> TrackedBackoff<K> {
    /// Creates an empty tracker that uses `config` for delays and keeps at most `max_entries` keys.
    pub fn new(config: ExponentialBackoff, max_entries: usize) -> Self {
        Self {
            entries: HashMap::new(),
            config,
            max_entries,
        }
    }

    /// Creates a tracker with `ExponentialBackoff::default()` and room for 256 keys.
    pub fn with_defaults() -> Self {
        Self::new(ExponentialBackoff::default(), 256)
    }

    /// Returns `true` while `key` has a recorded failure whose retry time is still in the future.
    pub fn is_in_backoff(&self, key: &K) -> bool {
        matches!(self.entries.get(key), Some(entry) if Instant::now() < entry.retry_after)
    }

    /// Time left until `key` may be retried, or `None` if it is unknown or already retryable.
    pub fn remaining_backoff(&self, key: &K) -> Option<Duration> {
        let entry = self.entries.get(key)?;
        let now = Instant::now();
        (now < entry.retry_after).then(|| entry.retry_after - now)
    }

    /// Records one more failure for `key` and schedules its next allowed retry.
    ///
    /// The delay comes from `ExponentialBackoff::delay_for_failures` and is multiplied by
    /// a jitter factor drawn from `GlobalRng` in `0.8..=1.2`. This means the jittered delay can
    /// exceed the configured `max` by up to 20%. If the tracker is over capacity afterwards,
    /// the least-recently-failed entries are evicted. Other keys are preferred as victims,
    /// so a clock tie cannot evict the key that was just recorded.
    pub fn record_failure(&mut self, key: K) {
        let now = Instant::now();
        // Single hash lookup: the entry stays borrowed while its retry time is computed.
        // `self.config` is a disjoint field, so reading it here does not conflict.
        let entry = self.entries.entry(key.clone()).or_insert(BackoffEntry {
            consecutive_failures: 0,
            last_failure: now,
            retry_after: now,
        });
        entry.consecutive_failures = entry.consecutive_failures.saturating_add(1);
        entry.last_failure = now;
        let consecutive_failures = entry.consecutive_failures;

        let backoff = self.config.delay_for_failures(consecutive_failures);
        let jitter_factor: f64 = GlobalRng::random_range(0.8_f64..=1.2_f64);
        let jittered_backoff = backoff.mul_f64(jitter_factor);
        entry.retry_after = now + jittered_backoff;

        tracing::debug!(
            failures = consecutive_failures,
            backoff_ms = jittered_backoff.as_millis() as u64,
            "Backoff recorded for key"
        );
        self.evict_if_needed(&key);
    }

    /// Forgets all failure history for `key`.
    pub fn record_success(&mut self, key: &K) {
        if self.entries.remove(key).is_some() {
            tracing::debug!("Backoff cleared for key");
        }
    }

    /// Number of consecutive failures recorded for `key` (0 if none).
    pub fn failure_count(&self, key: &K) -> u32 {
        self.entries
            .get(key)
            .map_or(0, |entry| entry.consecutive_failures)
    }

    /// Drops entries that are both retryable now and whose last failure is older than the
    /// schedule's `max` delay.
    pub fn cleanup_expired(&mut self) {
        let now = Instant::now();
        let max_stale = self.config.max();
        self.entries.retain(|_, entry| {
            let is_past_retry = now >= entry.retry_after;
            let is_stale = now.duration_since(entry.last_failure) > max_stale;
            !(is_past_retry && is_stale)
        });
    }

    /// Removes every tracked key.
    pub fn clear(&mut self) {
        self.entries.clear();
    }

    /// Number of tracked keys, including ones whose backoff has already elapsed.
    // Kept for API completeness; it may have no non-test callers in this crate.
    #[allow(dead_code)]
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` when no keys are tracked.
    // Kept for API completeness; it may have no non-test callers in this crate.
    #[allow(dead_code)]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Evicts least-recently-failed entries until the tracker is within `max_entries`.
    ///
    /// `keep` (the key just recorded) sorts after every other candidate. It is therefore
    /// evicted only when nothing else remains, for example when `max_entries == 0`.
    /// This matters because `Iterator::min_by_key` breaks ties by iteration order, and
    /// `HashMap` iteration order is arbitrary.
    fn evict_if_needed(&mut self, keep: &K) {
        while self.entries.len() > self.max_entries {
            let victim = self
                .entries
                .iter()
                .min_by_key(|(k, entry)| (*k == keep, entry.last_failure))
                .map(|(k, _)| k.clone());
            match victim {
                Some(k) => {
                    self.entries.remove(&k);
                }
                None => break,
            }
        }
    }

    /// Keys whose retry time is still in the future, in arbitrary order.
    pub fn keys_in_backoff(&self) -> Vec<K> {
        let now = Instant::now();
        self.entries
            .iter()
            .filter(|(_, entry)| now < entry.retry_after)
            .map(|(k, _)| k.clone())
            .collect()
    }

    /// Removes every entry whose retry time has passed and returns how many were removed.
    ///
    /// Unlike `cleanup_expired`, this also discards the failure counts of keys that failed
    /// recently, so their next failure starts the schedule over.
    pub fn remove_expired_entries(&mut self) -> usize {
        let now = Instant::now();
        let before = self.entries.len();
        self.entries.retain(|_, entry| now < entry.retry_after);
        before - self.entries.len()
    }

    /// The backoff schedule this tracker applies.
    pub fn config(&self) -> &ExponentialBackoff {
        &self.config
    }

    /// Test hook: overrides the retry time of an existing entry (no-op for unknown keys).
    #[cfg(test)]
    pub fn set_retry_after(&mut self, key: &K, retry_after: Instant) {
        if let Some(entry) = self.entries.get_mut(key) {
            entry.retry_after = retry_after;
        }
    }
}

impl<K: Eq + Hash + Clone> Default for TrackedBackoff<K> {
    /// Same as `TrackedBackoff::with_defaults`.
    fn default() -> Self {
        Self::with_defaults()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a tracker with a whole-second schedule and the given capacity.
    fn tracker<K: Eq + Hash + Clone>(
        base_secs: u64,
        max_secs: u64,
        capacity: usize,
    ) -> TrackedBackoff<K> {
        let schedule =
            ExponentialBackoff::new(Duration::from_secs(base_secs), Duration::from_secs(max_secs));
        TrackedBackoff::new(schedule, capacity)
    }

    #[test]
    fn test_exponential_delay_zero_indexed() {
        let schedule = ExponentialBackoff::new(Duration::from_secs(1), Duration::from_secs(60));
        // (attempt, expected seconds); attempt 6 would be 64s but is clamped to the 60s max.
        let cases = [(0, 1), (1, 2), (2, 4), (3, 8), (4, 16), (5, 32), (6, 60)];
        for (attempt, secs) in cases {
            assert_eq!(schedule.delay(attempt), Duration::from_secs(secs));
        }
    }

    #[test]
    fn test_delay_for_failures_one_indexed() {
        let schedule = ExponentialBackoff::new(Duration::from_secs(1), Duration::from_secs(300));
        assert_eq!(schedule.delay_for_failures(0), Duration::ZERO);
        // (failures, expected seconds): one failure maps to attempt 0.
        for (failures, secs) in [(1, 1), (2, 2), (3, 4), (4, 8)] {
            assert_eq!(schedule.delay_for_failures(failures), Duration::from_secs(secs));
        }
    }

    #[test]
    fn test_backoff_capped_at_max() {
        let schedule = ExponentialBackoff::new(Duration::from_secs(10), Duration::from_secs(60));
        let ceiling = Duration::from_secs(60);
        for attempt in [6, 10, 100] {
            assert_eq!(schedule.delay(attempt), ceiling);
        }
    }

    #[test]
    fn test_tracked_backoff_not_in_backoff_initially() {
        let fresh: TrackedBackoff<String> = TrackedBackoff::with_defaults();
        assert!(!fresh.is_in_backoff(&String::from("test")));
    }

    #[test]
    fn test_tracked_backoff_after_failure() {
        let mut tb: TrackedBackoff<String> = TrackedBackoff::with_defaults();
        let name = String::from("test");
        tb.record_failure(name.clone());
        assert!(tb.is_in_backoff(&name));
        assert_eq!(tb.failure_count(&name), 1);
    }

    #[test]
    fn test_tracked_backoff_cleared_on_success() {
        let mut tb: TrackedBackoff<String> = TrackedBackoff::with_defaults();
        let name = String::from("test");
        tb.record_failure(name.clone());
        assert!(tb.is_in_backoff(&name));

        tb.record_success(&name);
        assert!(!tb.is_in_backoff(&name));
        assert_eq!(tb.failure_count(&name), 0);
    }

    #[test]
    fn test_tracked_backoff_consecutive_failures() {
        let mut tb: TrackedBackoff<String> = tracker(1, 300, 100);
        let name = String::from("test");
        for expected in 1..=3 {
            tb.record_failure(name.clone());
            assert_eq!(tb.failure_count(&name), expected);
        }
    }

    #[test]
    fn test_tracked_backoff_eviction() {
        let mut tb: TrackedBackoff<u32> = tracker(1, 300, 5);
        (0..10).for_each(|k| tb.record_failure(k));
        assert!(tb.len() <= 5);
    }

    #[test]
    fn test_different_keys_tracked_separately() {
        let mut tb: TrackedBackoff<String> = TrackedBackoff::with_defaults();
        let (failing, untouched) = (String::from("key1"), String::from("key2"));
        tb.record_failure(failing.clone());
        assert!(tb.is_in_backoff(&failing));
        assert!(!tb.is_in_backoff(&untouched));
    }

    #[test]
    fn test_remaining_backoff() {
        let mut tb: TrackedBackoff<String> = tracker(10, 300, 100);
        let name = String::from("test");
        assert!(tb.remaining_backoff(&name).is_none());

        tb.record_failure(name.clone());
        // Pin the retry time so the jittered delay does not affect the bounds below.
        tb.set_retry_after(&name, Instant::now() + Duration::from_secs(10));

        let remaining = tb.remaining_backoff(&name);
        assert!(remaining.is_some());
        let left = remaining.unwrap();
        assert!((Duration::from_secs(9)..=Duration::from_secs(10)).contains(&left));
    }

    #[test]
    fn test_keys_in_backoff_stable_across_consecutive_calls() {
        let mut tb: TrackedBackoff<u32> = tracker(300, 3600, 64);
        tb.record_failure(1u32);
        tb.record_failure(2u32);

        // HashMap iteration order is unspecified, so compare sorted snapshots.
        let sorted_snapshot = |t: &TrackedBackoff<u32>| {
            let mut keys = t.keys_in_backoff();
            keys.sort();
            keys
        };
        let first_sorted = sorted_snapshot(&tb);
        let second_sorted = sorted_snapshot(&tb);

        assert_eq!(
            first_sorted, second_sorted,
            "keys_in_backoff must return identical results on consecutive calls"
        );
        assert!(first_sorted.contains(&1u32));
        assert!(first_sorted.contains(&2u32));
    }

    #[test]
    fn test_remove_expired_entries_count() {
        let mut tb: TrackedBackoff<u32> = tracker(300, 3600, 64);
        for k in 1u32..=3 {
            tb.record_failure(k);
        }
        // Force keys 1 and 2 into the past; key 3 keeps its 300s-ish backoff.
        for k in [1u32, 2] {
            tb.set_retry_after(&k, Instant::now() - Duration::from_secs(1));
        }

        assert_eq!(tb.remove_expired_entries(), 2);
        assert_eq!(tb.len(), 1);
        assert!(!tb.is_in_backoff(&1u32));
        assert!(!tb.is_in_backoff(&2u32));
        assert!(tb.is_in_backoff(&3u32));
    }
}