use std::collections::{HashMap, HashSet};
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use thiserror::Error;
use tokio::sync::RwLock;
/// Pruning of stale per-IP entries is skipped while the request map holds
/// fewer than this many IPs.
const PRUNE_THRESHOLD_IPS: usize = 1024;
/// When pruning runs, an entry is dropped once its window started more than
/// `window * STALE_WINDOW_MULTIPLIER` ago.
const STALE_WINDOW_MULTIPLIER: u32 = 10;
/// Tuning knobs for [`RateLimiter`].
///
/// Each IP may make at most `max_requests` requests per `window`. IPs in
/// `whitelist` are never limited; IPs in `blacklist` are always rejected.
/// The whitelist is consulted first, so an IP present in both is allowed.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum number of requests an IP may make within one window.
    pub max_requests: u32,
    /// Length of each fixed counting window.
    pub window: Duration,
    /// IPs that bypass every check.
    pub whitelist: HashSet<IpAddr>,
    /// IPs that are rejected outright.
    pub blacklist: HashSet<IpAddr>,
}
impl Default for RateLimitConfig {
    /// Allows 60 requests per 60-second window, with empty whitelist and
    /// blacklist.
    fn default() -> Self {
        let window = Duration::from_secs(60);
        Self {
            max_requests: 60,
            window,
            whitelist: HashSet::default(),
            blacklist: HashSet::default(),
        }
    }
}
#[derive(Debug, Error)]
pub enum RateLimitError {
#[error("تم تجاوز الحد المسموح للطلبات. حاول لاحقًا.\nRate limit exceeded. Try again later.")]
LimitExceeded,
#[error("العنوان محظور (قائمة سوداء).\nIP is blacklisted.")]
Blacklisted,
}
/// Per-IP bookkeeping for the current fixed window.
#[derive(Debug)]
struct RequestInfo {
    /// Requests seen from this IP in the current window, including ones that
    /// were rejected.
    count: u32,
    /// Instant at which the current window began.
    window_start: Instant,
}
#[derive(Debug)]
pub struct RateLimiter {
config: RateLimitConfig,
requests: RwLock<HashMap<IpAddr, RequestInfo>>,
}
impl RateLimiter {
pub fn new(config: RateLimitConfig) -> Arc<Self> {
Arc::new(Self {
config,
requests: RwLock::new(HashMap::new()),
})
}
pub async fn check(&self, ip: IpAddr) -> Result<(), RateLimitError> {
if self.config.whitelist.contains(&ip) {
return Ok(()); }
if self.config.blacklist.contains(&ip) {
return Err(RateLimitError::Blacklisted);
}
let mut reqs = self.requests.write().await;
let now = Instant::now();
self.prune_if_needed(&mut reqs, now);
let entry = reqs.entry(ip).or_insert(RequestInfo {
count: 0,
window_start: now,
});
if now.saturating_duration_since(entry.window_start) > self.config.window {
entry.count = 1;
entry.window_start = now;
} else {
entry.count += 1;
}
if entry.count > self.config.max_requests {
return Err(RateLimitError::LimitExceeded);
}
Ok(())
}
pub async fn retry_after_seconds(&self, ip: IpAddr) -> u64 {
let reqs = self.requests.read().await;
if let Some(info) = reqs.get(&ip) {
let elapsed = Instant::now().saturating_duration_since(info.window_start);
let remaining = self.config.window.saturating_sub(elapsed);
return remaining.as_secs().max(1);
}
self.config.window.as_secs().max(1)
}
fn prune_if_needed(&self, reqs: &mut HashMap<IpAddr, RequestInfo>, now: Instant) {
if reqs.len() < PRUNE_THRESHOLD_IPS {
return;
}
let stale_after = self.config.window.saturating_mul(STALE_WINDOW_MULTIPLIER);
reqs.retain(|_, info| now.saturating_duration_since(info.window_start) <= stale_after);
}
pub fn add_whitelist(&mut self, ip: IpAddr) {
self.config.whitelist.insert(ip);
}
pub fn add_blacklist(&mut self, ip: IpAddr) {
self.config.blacklist.insert(ip);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config with the given limits and empty allow/deny lists.
    fn config_with(max_requests: u32, window: Duration) -> RateLimitConfig {
        RateLimitConfig {
            max_requests,
            window,
            whitelist: HashSet::new(),
            blacklist: HashSet::new(),
        }
    }

    /// Once the map holds at least `PRUNE_THRESHOLD_IPS` entries, a request
    /// arriving after every entry has gone stale should sweep them away.
    #[tokio::test]
    async fn stale_entries_are_pruned_when_map_grows() {
        let limiter = RateLimiter::new(config_with(10, Duration::from_millis(5)));

        let total = PRUNE_THRESHOLD_IPS + 64;
        for idx in 0..total {
            let third = ((idx / 256) & 0xff) as u8;
            let fourth = (idx % 256) as u8;
            let _ = limiter.check(IpAddr::from([10, 0, third, fourth])).await;
        }

        let size_before = limiter.requests.read().await.len();
        assert!(size_before >= PRUNE_THRESHOLD_IPS);

        // 70ms is past the 10 x 5ms stale horizon, so every entry above is stale.
        tokio::time::sleep(Duration::from_millis(70)).await;
        let _ = limiter.check(IpAddr::from([127, 0, 0, 1])).await;

        let size_after = limiter.requests.read().await.len();
        assert!(size_after < size_before);
        assert!(size_after <= 2);
    }

    /// A blacklisted IP is rejected on its very first request.
    #[tokio::test]
    async fn blacklisted_ip_is_denied_immediately() {
        let blocked_ip = IpAddr::from([203, 0, 113, 10]);
        let mut config = config_with(10, Duration::from_secs(60));
        config.blacklist.insert(blocked_ip);
        let limiter = RateLimiter::new(config);

        let outcome = limiter.check(blocked_ip).await;
        assert!(matches!(outcome, Err(RateLimitError::Blacklisted)));
    }
}