use lru::LruCache;
use std::{collections::VecDeque, net::IpAddr, num::NonZeroUsize};
use tokio::time::{Duration, Instant};
/// Per-IP sliding-window rate limiter.
///
/// For every tracked IP the limiter remembers when its recently admitted
/// requests happened; the set of tracked IPs lives in a bounded LRU cache.
pub struct RateLimiter {
    /// Maximum number of admitted requests per IP inside one window
    /// (`0` disables limiting).
    max: usize,
    /// Length of the sliding window.
    interval: Duration,
    /// Admission timestamps per IP, oldest at the front.
    entries: LruCache<IpAddr, VecDeque<Instant>>,
}
impl RateLimiter {
    /// Number of distinct IPs tracked by limiters built with [`RateLimiter::new`].
    pub const DEFAULT_CAPACITY: usize = 1024;

    /// Creates a limiter that admits at most `max` requests per IP within any
    /// window of length `interval`, tracking up to [`Self::DEFAULT_CAPACITY`] IPs.
    ///
    /// A `max` of `0` disables limiting: [`check`](Self::check) always returns `true`.
    pub fn new(max: usize, interval: Duration) -> Self {
        let capacity =
            NonZeroUsize::new(Self::DEFAULT_CAPACITY).expect("DEFAULT_CAPACITY is non-zero");
        Self::with_capacity(max, interval, capacity)
    }

    /// Like [`RateLimiter::new`], but tracks up to `capacity` distinct IPs.
    ///
    /// Once more than `capacity` IPs are being tracked, the least recently
    /// checked IP's history is evicted from the LRU cache, so its count starts
    /// over on its next request.
    pub fn with_capacity(max: usize, interval: Duration, capacity: NonZeroUsize) -> Self {
        Self {
            entries: LruCache::new(capacity),
            max,
            interval,
        }
    }

    /// Decides whether a request from `ip` is admitted, recording it if so.
    ///
    /// Timestamps strictly older than `interval` are discarded first; one that
    /// is exactly `interval` old still counts. Returns `false` (recording
    /// nothing) when `ip` already has `max` admitted requests in the window.
    /// Always returns `true` when `max` is `0`.
    pub fn check(&mut self, ip: IpAddr) -> bool {
        if self.max == 0 {
            return true;
        }
        let now = Instant::now();
        let history = self.entries.get_or_insert_mut(ip, VecDeque::new);
        // Timestamps are appended as `Instant::now()`, so they are in order and
        // every expired one sits at the front.
        while let Some(&oldest) = history.front() {
            if now.duration_since(oldest) <= self.interval {
                break;
            }
            history.pop_front();
        }
        if history.len() >= self.max {
            return false;
        }
        history.push_back(now);
        true
    }
}
/// Backoff state remembered for a single IP.
#[derive(Copy, Clone)]
struct BackoffEntry {
    /// Earliest instant at which the IP may attempt again.
    next_allowed: Instant,
    /// Failures recorded for the IP since its entry was created.
    failures: u32,
}
/// Per-IP exponential backoff for failed authentication attempts.
///
/// Failure state for each IP lives in a bounded LRU cache.
pub struct AuthBackoff {
    /// Delay imposed after the first failure; consecutive failures scale it
    /// up by powers of two.
    base_delay: Duration,
    /// Upper bound on any single imposed delay.
    max_delay: Duration,
    /// Failure count and next-allowed instant per IP.
    entries: LruCache<IpAddr, BackoffEntry>,
}
impl AuthBackoff {
    /// Number of distinct IPs tracked by instances built with [`AuthBackoff::new`].
    pub const DEFAULT_CAPACITY: usize = 1024;

    /// Creates a backoff tracker whose delays start at `base_delay`, double
    /// with each consecutive failure, and never exceed `max_delay`.
    /// Tracks up to [`Self::DEFAULT_CAPACITY`] IPs.
    pub fn new(base_delay: Duration, max_delay: Duration) -> Self {
        let capacity =
            NonZeroUsize::new(Self::DEFAULT_CAPACITY).expect("DEFAULT_CAPACITY is non-zero");
        Self::with_capacity(base_delay, max_delay, capacity)
    }

    /// Like [`AuthBackoff::new`], but tracks up to `capacity` distinct IPs.
    ///
    /// When the cache is full, the least recently used IP's failure history
    /// is evicted, which also resets its backoff.
    pub fn with_capacity(base_delay: Duration, max_delay: Duration, capacity: NonZeroUsize) -> Self {
        Self {
            entries: LruCache::new(capacity),
            base_delay,
            max_delay,
        }
    }

    /// Returns `true` if `ip` may attempt now.
    ///
    /// This is the case when the IP has no recorded failures or its current
    /// delay has elapsed.
    pub fn allow(&mut self, ip: IpAddr) -> bool {
        match self.entries.get(&ip) {
            Some(entry) => Instant::now() >= entry.next_allowed,
            None => true,
        }
    }

    /// Records a failed attempt from `ip` and returns the delay now imposed on it.
    ///
    /// The n-th consecutive failure yields `base_delay * 2^(n-1)`, clamped to
    /// `max_delay`.
    ///
    /// # Panics
    ///
    /// May panic if `Instant::now() + delay` cannot be represented, e.g. when
    /// `max_delay` is `Duration::MAX`.
    pub fn record_failure(&mut self, ip: IpAddr) -> Duration {
        let now = Instant::now();
        let failures = self
            .entries
            .get(&ip)
            .map_or(1, |entry| entry.failures.saturating_add(1));
        let delay = self.delay_for(failures);
        self.entries.put(
            ip,
            BackoffEntry {
                failures,
                next_allowed: now + delay,
            },
        );
        delay
    }

    /// Clears all failure state for `ip` after a successful attempt.
    pub fn record_success(&mut self, ip: IpAddr) {
        self.entries.pop(&ip);
    }

    /// Computes the delay for the `failures`-th consecutive failure.
    fn delay_for(&self, failures: u32) -> Duration {
        // Clamp the exponent only far enough that `1 << exp` fits in a u32.
        // A tighter clamp would make delays plateau below `max_delay`.
        // A product that overflows `Duration` falls back to `max_delay`.
        let exp = failures.saturating_sub(1).min(u32::BITS - 1);
        self.base_delay
            .checked_mul(1u32 << exp)
            .map_or(self.max_delay, |delay| delay.min(self.max_delay))
    }
}
#[cfg(test)]
mod tests {
    use super::{AuthBackoff, RateLimiter};
    use std::net::{IpAddr, Ipv4Addr};
    use tokio::time::Duration;

    // Fixed test addresses; 192.0.2.0/24 is reserved for documentation (RFC 5737).
    const LOCAL: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST);
    const OTHER: IpAddr = IpAddr::V4(Ipv4Addr::new(192, 0, 2, 1));

    // Tests that assert "blocked" use windows and delays far longer than any
    // plausible scheduler stall so they cannot flake. Only the recovery tests
    // sleep, and they sleep past the deadline with a margin.

    #[tokio::test(flavor = "current_thread")]
    async fn rate_limiter_blocks_over_limit() {
        let mut limiter = RateLimiter::new(2, Duration::from_secs(60));
        assert!(limiter.check(LOCAL));
        assert!(limiter.check(LOCAL));
        assert!(!limiter.check(LOCAL));
        // Limits are tracked per IP.
        assert!(limiter.check(OTHER));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn rate_limiter_recovers_after_interval() {
        let interval = Duration::from_millis(20);
        let mut limiter = RateLimiter::new(1, interval);
        assert!(limiter.check(LOCAL));
        tokio::time::sleep(interval + Duration::from_millis(10)).await;
        assert!(limiter.check(LOCAL));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn rate_limiter_zero_max_is_unlimited() {
        let mut limiter = RateLimiter::new(0, Duration::from_secs(60));
        for _ in 0..100 {
            assert!(limiter.check(LOCAL));
        }
    }

    #[tokio::test(flavor = "current_thread")]
    async fn auth_backoff_blocks_after_failure() {
        let mut backoff = AuthBackoff::new(Duration::from_secs(60), Duration::from_secs(600));
        assert!(backoff.allow(LOCAL));
        backoff.record_failure(LOCAL);
        assert!(!backoff.allow(LOCAL));
        // Other IPs are unaffected.
        assert!(backoff.allow(OTHER));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn auth_backoff_recovers_after_delay() {
        let mut backoff = AuthBackoff::new(Duration::from_millis(5), Duration::from_millis(20));
        let delay = backoff.record_failure(LOCAL);
        assert_eq!(delay, Duration::from_millis(5));
        tokio::time::sleep(delay + Duration::from_millis(5)).await;
        assert!(backoff.allow(LOCAL));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn auth_backoff_delay_doubles_and_caps() {
        let mut backoff = AuthBackoff::new(Duration::from_secs(1), Duration::from_secs(4));
        let delays: Vec<Duration> = (0..4).map(|_| backoff.record_failure(LOCAL)).collect();
        assert_eq!(
            delays,
            vec![
                Duration::from_secs(1),
                Duration::from_secs(2),
                Duration::from_secs(4),
                Duration::from_secs(4),
            ]
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn auth_backoff_clears_on_success() {
        let mut backoff = AuthBackoff::new(Duration::from_secs(60), Duration::from_secs(600));
        backoff.record_failure(LOCAL);
        backoff.record_failure(LOCAL);
        assert!(!backoff.allow(LOCAL));
        backoff.record_success(LOCAL);
        assert!(backoff.allow(LOCAL));
        // The failure count restarts, so the next delay is back to the base.
        assert_eq!(backoff.record_failure(LOCAL), Duration::from_secs(60));
    }
}