use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use crate::error::Result;
use crate::router::BoxFuture;
/// Storage backend for per-key hit counters used by throttling.
///
/// Implementors must be shareable across threads (`Send + Sync`) and must not
/// borrow non-`'static` data.
pub trait ThrottleStore: 'static + Send + Sync {
    /// Records one hit for `key` and resolves to the key's resulting hit count.
    ///
    /// NOTE(review): `ttl` is presumably the counting-window length;
    /// `MemoryThrottleStore` uses it that way — confirm for other backends.
    fn incr(&self, key: String, ttl: Duration) -> BoxFuture<'_, Result<u64>>;

    /// Resolves to the hit count currently stored for `key`, without recording
    /// a new hit.
    fn count(&self, key: String) -> BoxFuture<'_, Result<u64>>;
}
/// Minimum number of stored keys before `incr` sweeps expired entries.
const SWEEP_THRESHOLD: usize = 4096;

/// Hit counter for one key's current fixed window.
struct Entry {
    /// Hits recorded since the window opened.
    count: u64,
    /// Instant at which the window closes; at or after it the counter is stale.
    expires_at: Instant,
}

/// Lock-protected contents of a [`MemoryThrottleStore`].
struct State {
    entries: HashMap<String, Entry>,
    /// Map size that must be exceeded before the next sweep. It is raised
    /// after every sweep so a map holding many *live* keys is not rescanned
    /// on every call.
    sweep_at: usize,
}

impl Default for State {
    fn default() -> Self {
        Self {
            entries: HashMap::new(),
            sweep_at: SWEEP_THRESHOLD,
        }
    }
}

impl State {
    /// Removes expired entries once the map has grown past `sweep_at`.
    ///
    /// Afterwards the trigger moves to twice the surviving size (never below
    /// [`SWEEP_THRESHOLD`]). The O(n) `retain` therefore runs only after the
    /// map has roughly doubled: amortised O(1) per call, instead of O(n) per
    /// call whenever more than `SWEEP_THRESHOLD` keys are still live.
    fn maybe_sweep(&mut self, now: Instant) {
        if self.entries.len() > self.sweep_at {
            self.entries.retain(|_, entry| entry.expires_at > now);
            self.sweep_at = SWEEP_THRESHOLD.max(self.entries.len().saturating_mul(2));
        }
    }

    /// Records one hit for `key` at `now` and returns the hit count of the
    /// key's current window. A missing or expired entry starts a fresh window
    /// ending at `now + ttl`.
    fn incr(&mut self, key: String, ttl: Duration, now: Instant) -> u64 {
        self.maybe_sweep(now);
        let entry = self.entries.entry(key).or_insert_with(|| Entry {
            count: 0,
            expires_at: now + ttl,
        });
        if entry.expires_at <= now {
            entry.count = 0;
            entry.expires_at = now + ttl;
        }
        // Saturate rather than overflow: a wrapped counter would read as 0 and
        // silently lift the throttle (and `+=` panics in debug builds).
        entry.count = entry.count.saturating_add(1);
        entry.count
    }

    /// Returns the hit count of `key`'s window at `now`, or 0 if the key is
    /// unknown or its window has expired. Does not record a hit.
    fn count(&self, key: &str, now: Instant) -> u64 {
        self.entries
            .get(key)
            .filter(|entry| entry.expires_at > now)
            .map_or(0, |entry| entry.count)
    }
}

/// In-process [`ThrottleStore`] that keeps fixed-window counters in a
/// mutex-protected `HashMap`.
///
/// Clones share the same counters. `incr` prunes expired entries lazily once
/// the map grows large.
#[derive(Clone, Default)]
pub struct MemoryThrottleStore {
    inner: Arc<Mutex<State>>,
}

impl MemoryThrottleStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Locks the shared state. A poisoned lock is recovered rather than
    /// propagated, so one panicking caller does not break the store for
    /// everyone else.
    fn lock(&self) -> std::sync::MutexGuard<'_, State> {
        self.inner
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }
}

impl ThrottleStore for MemoryThrottleStore {
    fn incr(&self, key: String, ttl: Duration) -> BoxFuture<'_, Result<u64>> {
        Box::pin(async move {
            let now = Instant::now();
            // The guard is a temporary dropped at the end of this statement,
            // so it is never held across an await point.
            let count = self.lock().incr(key, ttl, now);
            Ok(count)
        })
    }

    fn count(&self, key: String) -> BoxFuture<'_, Result<u64>> {
        Box::pin(async move {
            let now = Instant::now();
            let count = self.lock().count(&key, now);
            Ok(count)
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Hits inside one window accumulate; the first hit after the window
    /// closes starts a new count.
    #[tokio::test]
    async fn counts_within_a_window_then_resets_after_it() {
        let store = MemoryThrottleStore::new();
        let ttl = Duration::from_millis(80);
        assert_eq!(store.incr("k".into(), ttl).await.unwrap(), 1);
        assert_eq!(store.incr("k".into(), ttl).await.unwrap(), 2);
        assert_eq!(store.incr("k".into(), ttl).await.unwrap(), 3);
        tokio::time::sleep(Duration::from_millis(120)).await;
        assert_eq!(store.incr("k".into(), ttl).await.unwrap(), 1);
    }

    /// Counters are tracked per key.
    #[tokio::test]
    async fn distinct_keys_count_independently() {
        let store = MemoryThrottleStore::new();
        let ttl = Duration::from_secs(60);
        assert_eq!(store.incr("a".into(), ttl).await.unwrap(), 1);
        assert_eq!(store.incr("b".into(), ttl).await.unwrap(), 1);
        assert_eq!(store.incr("a".into(), ttl).await.unwrap(), 2);
    }

    /// `count` reports the current hits without recording one itself, and is
    /// 0 for keys that were never incremented.
    #[tokio::test]
    async fn count_reports_hits_without_recording_one() {
        let store = MemoryThrottleStore::new();
        let ttl = Duration::from_secs(60);
        assert_eq!(store.count("k".into()).await.unwrap(), 0);
        store.incr("k".into(), ttl).await.unwrap();
        store.incr("k".into(), ttl).await.unwrap();
        assert_eq!(store.count("k".into()).await.unwrap(), 2);
        assert_eq!(store.count("k".into()).await.unwrap(), 2);
        assert_eq!(store.count("other".into()).await.unwrap(), 0);
    }

    /// A zero-length window is already expired when `count` looks at it, so
    /// expiry is exercised without sleeping.
    #[tokio::test]
    async fn count_is_zero_once_window_expires() {
        let store = MemoryThrottleStore::new();
        store.incr("k".into(), Duration::ZERO).await.unwrap();
        assert_eq!(store.count("k".into()).await.unwrap(), 0);
    }

    /// Clones share one underlying map, so hits through either are combined.
    #[tokio::test]
    async fn clones_share_counters() {
        let store = MemoryThrottleStore::new();
        let clone = store.clone();
        let ttl = Duration::from_secs(60);
        store.incr("k".into(), ttl).await.unwrap();
        assert_eq!(clone.incr("k".into(), ttl).await.unwrap(), 2);
        assert_eq!(store.count("k".into()).await.unwrap(), 2);
    }
}