use std::{
collections::HashMap,
hash::Hash,
num::NonZeroU32,
sync::{Arc, Mutex, PoisonError, Weak},
time::{Duration, Instant},
};
use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
/// Error returned by [`BoundedKeyedLimiter::check_key`].
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum BoundedLimiterError {
    /// The key has exhausted its quota for the moment.
    #[error("rate limit exceeded for key")]
    RateLimited,
}
/// Per-key state: the key's own limiter plus when the key was last checked.
struct Entry {
    /// Limiter enforcing the configured quota for this key alone.
    limiter: DefaultDirectRateLimiter,
    /// Time of the most recent check for this key, whether it was allowed or rejected;
    /// drives both idle eviction and least-recently-seen eviction.
    last_seen: Instant,
}
/// State shared by every clone of a [`BoundedKeyedLimiter`] and by its prune task.
struct Inner<K: Eq + Hash + Clone> {
    /// Currently tracked keys and their limiters.
    map: Mutex<HashMap<K, Entry>>,
    /// Quota handed to each newly tracked key.
    quota: Quota,
    /// Table size at which a new key triggers eviction before it is inserted.
    max_tracked_keys: usize,
    /// How long a key may go unseen before it becomes eligible for removal.
    idle_eviction: Duration,
}
/// A rate limiter that keeps an independent quota per key while bounding how many keys it
/// tracks at once.
///
/// Cloning is cheap: every clone shares the same key table.
#[allow(
    missing_debug_implementations,
    reason = "wraps governor RateLimiter which has no Debug impl"
)]
pub struct BoundedKeyedLimiter<K: Eq + Hash + Clone> {
    /// Shared state; the background prune task, if one was spawned, holds only a weak
    /// reference to it.
    inner: Arc<Inner<K>>,
}
// Cloning copies the handle, not the table: all clones observe and update the same
// per-key state. Written by hand so `K` itself does not need to be `Clone`-derived.
impl<K: Eq + Hash + Clone> Clone for BoundedKeyedLimiter<K> {
    fn clone(&self) -> Self {
        let shared = Arc::clone(&self.inner);
        Self { inner: shared }
    }
}
impl<K: Eq + Hash + Clone + Send + Sync + 'static> BoundedKeyedLimiter<K> {
    /// Creates a limiter that applies `quota` independently to every key while tracking at
    /// most `max_tracked_keys` keys at once.
    ///
    /// When a new key arrives and the table is full, keys not seen within `idle_eviction`
    /// are dropped first; if the table is still full, the least-recently-seen key is
    /// evicted. An evicted key starts over with a fresh quota when it reappears.
    ///
    /// If called from inside a Tokio runtime, a background task also drops idle keys
    /// periodically. Outside a runtime, idle keys are only dropped when the table fills up.
    ///
    /// A `max_tracked_keys` of `0` is treated as `1`, because the key being checked always
    /// has to be stored.
    #[must_use]
    pub(crate) fn new(quota: Quota, max_tracked_keys: usize, idle_eviction: Duration) -> Self {
        let inner = Arc::new(Inner {
            map: Mutex::new(HashMap::new()),
            quota,
            // `check_key` always inserts the new key, so a cap of 0 could never be honoured.
            max_tracked_keys: max_tracked_keys.max(1),
            idle_eviction,
        });
        Self::spawn_prune_task(&inner);
        Self { inner }
    }

    /// Creates a limiter using governor's `Quota::per_minute(requests_per_minute)` per key.
    ///
    /// A rate of `0` is treated as `1`. `max_tracked_keys` caps the number of tracked keys
    /// and `idle_eviction` is how long an unseen key is kept before it may be dropped.
    #[must_use]
    pub fn with_per_minute(
        requests_per_minute: u32,
        max_tracked_keys: usize,
        idle_eviction: Duration,
    ) -> Self {
        let rate = NonZeroU32::new(requests_per_minute).unwrap_or(NonZeroU32::MIN);
        Self::new(Quota::per_minute(rate), max_tracked_keys, idle_eviction)
    }

    /// Creates a limiter using governor's `Quota::per_second(requests_per_second)` per key.
    ///
    /// A rate of `0` is treated as `1`. `max_tracked_keys` caps the number of tracked keys
    /// and `idle_eviction` is how long an unseen key is kept before it may be dropped.
    #[must_use]
    pub fn with_per_second(
        requests_per_second: u32,
        max_tracked_keys: usize,
        idle_eviction: Duration,
    ) -> Self {
        let rate = NonZeroU32::new(requests_per_second).unwrap_or(NonZeroU32::MIN);
        Self::new(Quota::per_second(rate), max_tracked_keys, idle_eviction)
    }

    /// Spawns a background task that periodically drops idle keys.
    ///
    /// Does nothing when called outside a Tokio runtime. The task holds only a [`Weak`]
    /// reference, so it exits on its first tick after the last limiter handle is dropped.
    fn spawn_prune_task(inner: &Arc<Inner<K>>) {
        let Ok(handle) = tokio::runtime::Handle::try_current() else {
            return;
        };
        let weak: Weak<Inner<K>> = Arc::downgrade(inner);
        // Sweep several times per idle window, but never more often than once a minute.
        let interval = (inner.idle_eviction / 4).max(Duration::from_mins(1));
        handle.spawn(async move {
            let mut ticker = tokio::time::interval(interval);
            // After a stall (e.g. a suspended host), run a single sweep instead of a burst of
            // back-to-back catch-up sweeps.
            ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
            // The first tick completes immediately; consume it so no sweep runs at startup.
            ticker.tick().await;
            loop {
                ticker.tick().await;
                let Some(inner) = weak.upgrade() else {
                    return;
                };
                Self::prune_idle(&inner);
            }
        });
    }

    /// Drops every key that has not been seen within `idle_eviction`.
    fn prune_idle(inner: &Inner<K>) {
        let mut guard = inner.map.lock().unwrap_or_else(PoisonError::into_inner);
        Self::evict_idle(&mut guard, Instant::now(), inner.idle_eviction);
    }

    /// Removes entries whose `last_seen` is earlier than `now - idle_eviction`.
    ///
    /// If `now - idle_eviction` is not representable (the window reaches back before the
    /// monotonic clock's origin, or `idle_eviction` is very large), no entry can have been
    /// idle that long, so nothing is removed.
    fn evict_idle(map: &mut HashMap<K, Entry>, now: Instant, idle_eviction: Duration) {
        // Falling back to a cutoff of "now" here would make every entry look idle and wipe
        // all rate-limit state, handing every over-quota key a fresh quota.
        if let Some(cutoff) = now.checked_sub(idle_eviction) {
            map.retain(|_, entry| entry.last_seen >= cutoff);
        }
    }

    /// Removes the entry with the oldest `last_seen`, if the map is non-empty.
    ///
    /// This is a linear scan over the whole table.
    fn evict_lru(map: &mut HashMap<K, Entry>) {
        let oldest_key = map
            .iter()
            .min_by_key(|(_, entry)| entry.last_seen)
            .map(|(k, _)| k.clone());
        if let Some(key) = oldest_key {
            map.remove(&key);
        }
    }

    /// Records a request for `key` and checks it against that key's quota.
    ///
    /// Every call, whether allowed or rejected, refreshes the key's last-seen time. A key
    /// that keeps sending requests while over quota therefore stays tracked instead of being
    /// evicted into a fresh quota. A key that is not currently tracked gets a new limiter.
    /// If the table is full, idle keys are dropped first and then, if it is still full, the
    /// least-recently-seen key.
    ///
    /// # Errors
    ///
    /// Returns [`BoundedLimiterError::RateLimited`] when `key` has exhausted its quota.
    pub fn check_key(&self, key: &K) -> Result<(), BoundedLimiterError> {
        let mut guard = self
            .inner
            .map
            .lock()
            .unwrap_or_else(PoisonError::into_inner);
        let now = Instant::now();

        // Known key: update in place without cloning `key`.
        if let Some(entry) = guard.get_mut(key) {
            entry.last_seen = now;
            return entry
                .limiter
                .check()
                .map_err(|_| BoundedLimiterError::RateLimited);
        }

        // New key: make room so the table never grows past the cap (which is at least 1).
        if guard.len() >= self.inner.max_tracked_keys {
            Self::evict_idle(&mut guard, now, self.inner.idle_eviction);
            if guard.len() >= self.inner.max_tracked_keys {
                Self::evict_lru(&mut guard);
            }
        }

        let limiter = RateLimiter::direct(self.inner.quota);
        let result = limiter
            .check()
            .map_err(|_| BoundedLimiterError::RateLimited);
        guard.insert(
            key.clone(),
            Entry {
                limiter,
                last_seen: now,
            },
        );
        result
    }

    /// Returns the number of keys currently tracked.
    #[must_use]
    pub fn len(&self) -> usize {
        self.inner
            .map
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
            .len()
    }

    /// Returns `true` when no keys are currently tracked.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
#[cfg(test)]
mod tests {
    use std::{net::IpAddr, num::NonZeroU32, time::Duration};

    use governor::Quota;

    use super::BoundedKeyedLimiter;

    /// Builds a distinct IPv4 address from `n`.
    fn ip(n: u32) -> IpAddr {
        IpAddr::from(n.to_be_bytes())
    }

    /// A limiter allowing `per_minute` requests per key, capped at `cap` keys, with a
    /// one-hour idle window (far longer than any test here runs).
    fn make_limiter(per_minute: u32, cap: usize) -> BoundedKeyedLimiter<IpAddr> {
        let quota = Quota::per_minute(NonZeroU32::new(per_minute).unwrap());
        BoundedKeyedLimiter::new(quota, cap, Duration::from_hours(1))
    }

    /// Sleeps briefly so that later checks record distinguishably later timestamps.
    fn pause(ms: u64) {
        std::thread::sleep(Duration::from_millis(ms));
    }

    #[test]
    fn never_exceeds_max_tracked_keys() {
        let limiter = make_limiter(10, 100);
        for i in 0..10_000_u32 {
            let _ = limiter.check_key(&ip(i));
            let tracked = limiter.len();
            assert!(
                tracked <= 100,
                "tracked keys exceeded cap at iteration {i}: {tracked} > 100"
            );
        }
        assert_eq!(limiter.len(), 100, "table should be full at the cap");
    }

    #[test]
    fn evicted_keys_get_fresh_quota() {
        let limiter = make_limiter(2, 2);
        let target = ip(1);
        assert!(limiter.check_key(&target).is_ok(), "first ok");
        assert!(limiter.check_key(&target).is_ok(), "second ok");
        assert!(limiter.check_key(&target).is_err(), "third blocked");

        // Four newer keys push `target` out of a two-key table.
        for n in 2..=5_u32 {
            pause(5);
            let _ = limiter.check_key(&ip(n));
        }

        assert!(
            limiter.check_key(&target).is_ok(),
            "evicted key gets a fresh quota on reappearance"
        );
    }

    #[test]
    fn active_over_quota_key_not_evicted() {
        let limiter = make_limiter(2, 3);
        for i in 100..103_u32 {
            let _ = limiter.check_key(&ip(i));
        }
        assert_eq!(limiter.len(), 3);
        pause(5);

        // Use up the attacker's quota of two.
        let attacker = ip(200);
        let _ = limiter.check_key(&attacker);
        let _ = limiter.check_key(&attacker);

        // Keep the attacker active while a stream of new keys forces LRU evictions.
        for new_key in 300..310_u32 {
            pause(2);
            let _ = limiter.check_key(&attacker);
            pause(2);
            let _ = limiter.check_key(&ip(new_key));
        }

        let _ = limiter.check_key(&attacker);
        assert!(
            limiter.check_key(&attacker).is_err(),
            "actively over-quota attacker must not be evicted into a fresh quota"
        );
    }
}