use std::collections::{HashMap, VecDeque};
use std::sync::RwLock;
use std::time::{Duration, Instant};
/// Outcome of a single rate-limit check.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RateLimitDecision {
    /// The attempt is within the configured budget.
    Allowed,
    /// The budget is already used up; the attempt is rejected.
    Denied,
}
/// Sliding-window attempt limiter that keeps a separate budget per
/// connection key.
pub(super) struct PerConnectionRateLimiter {
    /// Maximum number of admitted attempts a key may have inside one window.
    max_attempts: u32,
    /// Length of the sliding window.
    window: Duration,
    /// Timestamps of admitted attempts per connection key, oldest first,
    /// guarded by a lock so the limiter can be used through `&self`.
    state: RwLock<HashMap<String, VecDeque<Instant>>>,
}
impl PerConnectionRateLimiter {
    /// Creates a limiter that admits at most `max_attempts` attempts per
    /// connection key within any sliding `window`.
    pub(super) fn new(max_attempts: u32, window: Duration) -> Self {
        Self {
            max_attempts,
            window,
            state: RwLock::new(HashMap::new()),
        }
    }

    /// Checks (and, if admitted, records) an attempt for `conn_key` at the
    /// current time, i.e. `Instant::now()`.
    pub(super) fn admit(&self, conn_key: &str) -> RateLimitDecision {
        self.admit_at(conn_key, Instant::now())
    }

    /// Checks an attempt for `conn_key` as if it happened at `now`.
    ///
    /// Attempts recorded more than `window` before `now` are dropped first.
    /// If the key then still holds `max_attempts` (or more) recorded
    /// attempts, the call returns [`RateLimitDecision::Denied`] and records
    /// nothing, so denied attempts never push the window forward. Otherwise
    /// `now` is recorded and [`RateLimitDecision::Allowed`] is returned.
    pub(super) fn admit_at(&self, conn_key: &str, now: Instant) -> RateLimitDecision {
        // A poisoned lock is recovered rather than propagated as a panic.
        let mut state = self.state.write().unwrap_or_else(|p| p.into_inner());
        let entry = state.entry(conn_key.to_string()).or_default();

        // `checked_sub` returns `None` when `now - window` is not
        // representable (huge window, or `now` close to the clock's origin).
        // In that case the window reaches back past every recorded attempt,
        // so nothing may be evicted. Falling back to `cutoff = now` would
        // instead expire the whole history and disable the limit.
        if let Some(cutoff) = now.checked_sub(self.window) {
            // Timestamps are appended in call order, so expired ones sit at
            // the front. An attempt exactly `window` old is still counted.
            while matches!(entry.front(), Some(oldest) if *oldest < cutoff) {
                entry.pop_front();
            }
        }

        // Widen both sides instead of narrowing `len()` so the comparison
        // can never truncate.
        if entry.len() as u64 >= u64::from(self.max_attempts) {
            return RateLimitDecision::Denied;
        }
        entry.push_back(now);
        RateLimitDecision::Allowed
    }

    /// Drops everything recorded for `conn_key`; its next attempt starts
    /// from an empty history.
    pub(super) fn forget(&self, conn_key: &str) {
        let mut state = self.state.write().unwrap_or_else(|p| p.into_inner());
        state.remove(conn_key);
    }

    /// Test helper: number of attempts currently stored for `conn_key`
    /// (0 for an unknown key). Expired entries are only removed by
    /// `admit_at`, so this reports the stored count as of the last call.
    #[cfg(test)]
    pub(super) fn attempts_in_window(&self, conn_key: &str) -> usize {
        let state = self.state.read().unwrap_or_else(|p| p.into_inner());
        state.get(conn_key).map_or(0, |q| q.len())
    }
}
#[cfg(test)]
mod tests {
    use super::RateLimitDecision::{Allowed, Denied};
    use super::*;

    /// Builds a `Duration` of `n` milliseconds.
    fn ms(n: u64) -> Duration {
        Duration::from_millis(n)
    }

    /// The first `max_attempts` calls pass; the next one is rejected.
    #[test]
    fn admits_up_to_limit_then_denies() {
        let limiter = PerConnectionRateLimiter::new(3, Duration::from_secs(60));
        for _ in 0..3 {
            assert_eq!(limiter.admit("c"), Allowed);
        }
        assert_eq!(limiter.admit("c"), Denied);
    }

    /// Once the recorded attempts age out of the window, capacity returns.
    #[test]
    fn sliding_window_replenishes_after_elapsed_time() {
        let limiter = PerConnectionRateLimiter::new(2, ms(100));
        let start = Instant::now();
        // (offset from `start` in ms, expected decision)
        let timeline = [(0, Allowed), (10, Allowed), (20, Denied), (200, Allowed)];
        for &(offset_ms, expected) in &timeline {
            assert_eq!(limiter.admit_at("c", start + ms(offset_ms)), expected);
        }
    }

    /// Rejected attempts are not recorded, so hammering the limiter does not
    /// postpone the moment the key is admitted again.
    #[test]
    fn denied_attempts_do_not_extend_punishment() {
        let limiter = PerConnectionRateLimiter::new(2, ms(100));
        let start = Instant::now();
        // Use up the whole budget at a single instant.
        for _ in 0..2 {
            limiter.admit_at("c", start);
        }
        (0..50).for_each(|_| assert_eq!(limiter.admit_at("c", start), Denied));
        assert_eq!(limiter.attempts_in_window("c"), 2);
        assert_eq!(limiter.admit_at("c", start + ms(200)), Allowed);
    }

    /// Exhausting one key's budget leaves other keys untouched.
    #[test]
    fn connections_are_isolated() {
        let limiter = PerConnectionRateLimiter::new(1, Duration::from_secs(60));
        let observed = [limiter.admit("a"), limiter.admit("a"), limiter.admit("b")];
        assert_eq!(observed, [Allowed, Denied, Allowed]);
    }

    /// `forget` wipes the key's history, so it is admitted again right away.
    #[test]
    fn forget_clears_state() {
        let limiter = PerConnectionRateLimiter::new(1, Duration::from_secs(60));
        limiter.admit("c");
        limiter.forget("c");
        assert_eq!(limiter.admit("c"), Allowed);
    }
}