use std::{borrow::Borrow, collections::HashMap, hash::Hash, sync::Arc, time::Duration};
use ruma::time::Instant;
use super::locks::RwLock;
/// Default upper bound for an item's back-off delay, in seconds (fifteen minutes).
const MAX_DELAY: u64 = 60 * 15;
/// Default back-off multiplier, in seconds, applied to `2^failure_count`.
const MULTIPLIER: u64 = 15;
/// A cache of keys that failed, each remembered for an exponentially growing
/// back-off delay.
///
/// Re-inserting a key that is already stored bumps its failure count and
/// restarts its delay. Clones share the same underlying storage.
#[derive(Clone, Debug)]
pub struct FailuresCache<T>
where
    T: Eq + Hash,
{
    inner: Arc<InnerCache<T>>,
}
/// Shared state behind a [`FailuresCache`]; clones of the cache point at the
/// same instance through an `Arc`.
#[derive(Debug)]
struct InnerCache<T>
where
    T: Eq + Hash,
{
    /// Upper bound for the back-off delay of any item.
    max_delay: Duration,
    /// Number of seconds multiplied with `2^failure_count` to get an item's delay.
    backoff_multiplier: u64,
    /// Failed keys and their back-off bookkeeping.
    items: RwLock<HashMap<T, FailuresItem>>,
}

impl<T> Default for InnerCache<T>
where
    T: Eq + Hash,
{
    /// Settings taken from `MAX_DELAY` and `MULTIPLIER`, with no stored items.
    fn default() -> Self {
        let max_delay = Duration::from_secs(MAX_DELAY);
        let backoff_multiplier = MULTIPLIER;
        Self { max_delay, backoff_multiplier, items: Default::default() }
    }
}
/// Back-off bookkeeping for a single failed key.
#[derive(Debug, Clone, Copy)]
struct FailuresItem {
    /// When the key was last recorded as failed.
    insertion_time: Instant,
    /// How long after `insertion_time` the key stays active.
    duration: Duration,
    /// How many times the key was recorded again while already stored
    /// (0 for the first failure).
    failure_count: u8,
}

impl FailuresItem {
    /// Returns `true` once at least `duration` has passed since `insertion_time`.
    fn expired(&self) -> bool {
        let age = self.insertion_time.elapsed();
        age >= self.duration
    }

    /// Forces the item to report as expired by shrinking its window to zero.
    fn expire(&mut self) {
        self.duration = Duration::ZERO;
    }
}
impl<T> FailuresCache<T>
where
T: Eq + Hash,
{
pub fn new() -> Self {
Self { inner: Default::default() }
}
pub fn with_settings(max_delay: Duration, multiplier: u8) -> Self {
Self {
inner: InnerCache {
max_delay,
backoff_multiplier: multiplier.into(),
items: Default::default(),
}
.into(),
}
}
pub fn contains<Q>(&self, key: &Q) -> bool
where
T: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
self.inner.items.read().get(key).is_some_and(|item| !item.expired())
}
pub fn failure_count<Q>(&self, key: &Q) -> Option<u8>
where
T: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
self.inner.items.read().get(key).map(|i| i.failure_count)
}
fn calculate_delay(&self, failure_count: u8) -> Duration {
let exponential_backoff = 2u64.saturating_pow(failure_count.into());
let delay = exponential_backoff.saturating_mul(self.inner.backoff_multiplier);
Duration::from_secs(delay).clamp(Duration::from_secs(1), self.inner.max_delay)
}
pub fn insert(&self, item: T) {
self.extend([item]);
}
pub fn extend(&self, iterator: impl IntoIterator<Item = T>) {
let mut lock = self.inner.items.write();
let now = Instant::now();
for key in iterator {
let failure_count = if let Some(value) = lock.get(&key) {
value.failure_count.saturating_add(1)
} else {
0
};
let delay = self.calculate_delay(failure_count);
let item = FailuresItem { insertion_time: now, duration: delay, failure_count };
lock.insert(key, item);
}
}
pub fn remove<'a, I, Q>(&'a self, iterator: I)
where
I: Iterator<Item = &'a Q>,
T: Borrow<Q>,
Q: Hash + Eq + 'a + ?Sized,
{
let mut lock = self.inner.items.write();
for item in iterator {
lock.remove(item);
}
}
#[doc(hidden)]
pub fn expire(&self, item: &T) {
self.inner.items.write().get_mut(item).map(FailuresItem::expire);
}
}
impl<T: Eq + Hash> Default for FailuresCache<T> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use proptest::prelude::*;

    use super::FailuresCache;

    #[test]
    fn failures_cache() {
        let cache = FailuresCache::new();
        assert!(!cache.contains(&1));

        cache.extend([1u8].iter());
        assert!(cache.contains(&1));

        // Zero out the back-off window directly so the entry reads as expired.
        cache.inner.items.write().get_mut(&1).unwrap().duration = Duration::from_secs(0);
        assert!(!cache.contains(&1));

        // Removing the key drops it from the backing map entirely.
        cache.remove([1u8].iter());
        assert!(cache.inner.items.read().get(&1).is_none());
    }

    #[test]
    fn failures_cache_timeout() {
        let cache: FailuresCache<u8> = FailuresCache::new();
        // Default settings: 15 seconds, doubled per failure, capped at 15 minutes.
        let expected_secs: [u64; 8] = [15, 30, 60, 120, 240, 480, 900, 900];

        for (failure_count, expected) in (0u8..).zip(expected_secs) {
            assert_eq!(
                cache.calculate_delay(failure_count).as_secs(),
                expected,
                "unexpected delay for failure count {failure_count}"
            );
        }
    }

    proptest! {
        #[test]
        fn failures_cache_proptest_timeout(count in 0..10u8) {
            let cache: FailuresCache<u8> = FailuresCache::new();
            let delay = cache.calculate_delay(count).as_secs();
            prop_assert!((15..=900).contains(&delay));
        }
    }
}