use std::collections::{HashMap, HashSet};
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use thiserror::Error;
use tokio::sync::RwLock;
/// Settings for per-IP request limiting.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum number of counted requests an address may make within one window.
    pub max_requests: u32,
    /// Length of a single counting window.
    pub window: Duration,
    /// Addresses that are always allowed and never counted.
    pub whitelist: HashSet<IpAddr>,
    /// Addresses that are always rejected (unless they are also whitelisted).
    pub blacklist: HashSet<IpAddr>,
}

impl Default for RateLimitConfig {
    /// 60 requests per 60-second window, with empty whitelist and blacklist.
    fn default() -> Self {
        const DEFAULT_LIMIT: u32 = 60;
        let window = Duration::from_secs(60);
        Self {
            max_requests: DEFAULT_LIMIT,
            window,
            whitelist: HashSet::default(),
            blacklist: HashSet::default(),
        }
    }
}
#[derive(Debug, Error)]
pub enum RateLimitError {
#[error("تم تجاوز الحد المسموح للطلبات. حاول لاحقًا.\nRate limit exceeded. Try again later.")]
LimitExceeded,
#[error("العنوان محظور (قائمة سوداء).\nIP is blacklisted.")]
Blacklisted,
}
/// Per-address bookkeeping for the current counting window.
#[derive(Debug)]
struct RequestInfo {
    /// Number of requests counted since `window_start`.
    count: u32,
    /// Moment at which the current window began.
    window_start: Instant,
}
/// Fixed-window, per-IP request limiter.
///
/// Combines a limiting policy with a table of per-address counters that is guarded by an
/// async (tokio) read/write lock.
#[derive(Debug)]
pub struct RateLimiter {
    /// Policy: request limit, window length, and allow/deny lists.
    config: RateLimitConfig,
    /// Per-address request counters for the current window.
    requests: RwLock<HashMap<IpAddr, RequestInfo>>,
}
impl RateLimiter {
    /// Creates a limiter with the given configuration and an empty request table.
    ///
    /// The limiter is returned inside an [`Arc`] so it can be shared between tasks.
    pub fn new(config: RateLimitConfig) -> Arc<Self> {
        Arc::new(Self {
            config,
            requests: RwLock::new(HashMap::new()),
        })
    }

    /// Records a request from `ip` and decides whether it may proceed.
    ///
    /// Whitelisted addresses are always allowed and are not counted. The whitelist is
    /// consulted before the blacklist, so an address present in both is allowed.
    /// Every other request is counted against a fixed window of `config.window` that starts
    /// at the first request seen from that address after the previous window elapsed.
    ///
    /// # Errors
    ///
    /// * [`RateLimitError::Blacklisted`] if `ip` is blacklisted (and not whitelisted).
    /// * [`RateLimitError::LimitExceeded`] if this request pushes the count for the current
    ///   window above `config.max_requests`. Rejected requests are still counted.
    pub async fn check(&self, ip: IpAddr) -> Result<(), RateLimitError> {
        if self.config.whitelist.contains(&ip) {
            return Ok(());
        }
        if self.config.blacklist.contains(&ip) {
            return Err(RateLimitError::Blacklisted);
        }

        let mut reqs = self.requests.write().await;
        let now = Instant::now();
        let entry = reqs.entry(ip).or_insert(RequestInfo {
            count: 0,
            window_start: now,
        });

        if now.duration_since(entry.window_start) > self.config.window {
            // The previous window has elapsed: open a new one with this request as its first.
            entry.count = 1;
            entry.window_start = now;
        } else {
            // Saturate instead of overflowing: a flooding client must not be able to wrap the
            // counter (a panic in debug builds, a silent reset to 0 in release builds).
            entry.count = entry.count.saturating_add(1);
        }

        if entry.count > self.config.max_requests {
            return Err(RateLimitError::LimitExceeded);
        }
        Ok(())
    }

    /// Drops tracking entries whose window has already elapsed.
    ///
    /// `check` inserts an entry for every non-whitelisted, non-blacklisted address it sees
    /// and never removes it, so the table otherwise grows without bound. Call this
    /// periodically to bound memory use. Removing an elapsed entry does not change any
    /// decision: the next request from that address would have started a fresh window anyway.
    pub async fn purge_expired(&self) {
        let window = self.config.window;
        let mut reqs = self.requests.write().await;
        // Read the clock after acquiring the lock so no entry can have a later window_start.
        let now = Instant::now();
        reqs.retain(|_, info| now.duration_since(info.window_start) <= window);
    }

    /// Adds `ip` to the whitelist so that future `check` calls for it always succeed.
    ///
    /// Requires exclusive access. Because [`RateLimiter::new`] returns an `Arc<Self>`, this
    /// is only reachable through `Arc::get_mut` while no other clone of the `Arc` exists.
    /// Any counter already tracked for `ip` is left untouched.
    pub fn add_whitelist(&mut self, ip: IpAddr) {
        self.config.whitelist.insert(ip);
    }

    /// Adds `ip` to the blacklist so that future `check` calls for it fail with
    /// [`RateLimitError::Blacklisted`] (unless `ip` is also whitelisted).
    ///
    /// Requires exclusive access. Because [`RateLimiter::new`] returns an `Arc<Self>`, this
    /// is only reachable through `Arc::get_mut` while no other clone of the `Arc` exists.
    pub fn add_blacklist(&mut self, ip: IpAddr) {
        self.config.blacklist.insert(ip);
    }
}