use std::collections::{HashMap, VecDeque};
use std::net::IpAddr;
use std::sync::Mutex;
use std::time::{Duration, Instant};
pub struct SlidingWindowRateLimiter {
limit: u32,
window: Duration,
entries: Mutex<HashMap<IpAddr, VecDeque<Instant>>>,
}
impl SlidingWindowRateLimiter {
pub fn new(limit: u32, window: Duration) -> Self {
Self {
limit,
window,
entries: Mutex::new(HashMap::new()),
}
}
pub fn check(&self, ip: IpAddr) -> bool {
if self.limit == 0 {
return true;
}
let now = Instant::now();
let cutoff = now - self.window;
let mut entries = self.entries.lock().unwrap();
let timestamps = entries.entry(ip).or_default();
while timestamps.front().is_some_and(|&t| t <= cutoff) {
timestamps.pop_front();
}
if timestamps.len() >= self.limit as usize {
return false;
}
timestamps.push_back(now);
true
}
pub fn sweep(&self) {
let now = Instant::now();
let cutoff = now - self.window;
let mut entries = self.entries.lock().unwrap();
entries.retain(|_, timestamps| {
while timestamps.front().is_some_and(|&t| t <= cutoff) {
timestamps.pop_front();
}
!timestamps.is_empty()
});
}
#[cfg(test)]
pub fn entry_count(&self) -> usize {
self.entries.lock().unwrap().len()
}
}
pub struct GatewayRateLimiter {
pair: SlidingWindowRateLimiter,
webhook: SlidingWindowRateLimiter,
}
impl GatewayRateLimiter {
pub fn new(pair_per_min: u32, webhook_per_min: u32, window: Duration) -> Self {
Self {
pair: SlidingWindowRateLimiter::new(pair_per_min, window),
webhook: SlidingWindowRateLimiter::new(webhook_per_min, window),
}
}
pub fn check_pair(&self, ip: IpAddr) -> bool {
self.pair.check(ip)
}
pub fn check_webhook(&self, ip: IpAddr) -> bool {
self.webhook.check(ip)
}
pub fn sweep(&self) {
self.pair.sweep();
self.webhook.sweep();
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::net::{IpAddr, Ipv4Addr};
fn localhost() -> IpAddr {
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))
}
fn other_ip() -> IpAddr {
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))
}
#[test]
fn test_zero_limit_allows_all() {
let limiter = SlidingWindowRateLimiter::new(0, Duration::from_secs(60));
for _ in 0..100 {
assert!(limiter.check(localhost()));
}
}
#[test]
fn test_allows_up_to_limit() {
let limiter = SlidingWindowRateLimiter::new(3, Duration::from_secs(60));
assert!(limiter.check(localhost()));
assert!(limiter.check(localhost()));
assert!(limiter.check(localhost()));
assert!(!limiter.check(localhost())); }
#[test]
fn test_different_ips_independent() {
let limiter = SlidingWindowRateLimiter::new(1, Duration::from_secs(60));
assert!(limiter.check(localhost()));
assert!(limiter.check(other_ip())); assert!(!limiter.check(localhost())); }
#[test]
fn test_expired_entries_freed() {
let limiter = SlidingWindowRateLimiter::new(1, Duration::from_secs(5));
assert!(limiter.check(localhost()));
assert!(!limiter.check(localhost()));
let short = SlidingWindowRateLimiter::new(1, Duration::from_millis(50));
assert!(short.check(localhost()));
std::thread::sleep(Duration::from_millis(100)); assert!(short.check(localhost())); }
#[test]
fn test_sweep_clears_stale_ips() {
let limiter = SlidingWindowRateLimiter::new(1, Duration::from_millis(1));
assert!(limiter.check(localhost()));
std::thread::sleep(Duration::from_millis(5));
limiter.sweep();
assert_eq!(limiter.entry_count(), 0);
}
#[test]
fn test_gateway_rate_limiter_separate_limits() {
let grl = GatewayRateLimiter::new(1, 2, Duration::from_secs(60));
assert!(grl.check_pair(localhost()));
assert!(!grl.check_pair(localhost())); assert!(grl.check_webhook(localhost()));
assert!(grl.check_webhook(localhost()));
assert!(!grl.check_webhook(localhost())); }
}