use std::net::IpAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::{Duration, Instant};
use dashmap::DashMap;
use parking_lot::Mutex;
/// Per-IP bookkeeping shared between the limiter map and in-flight checks.
struct IpState {
    // Number of currently reserved slots for this IP: incremented by
    // `RateLimiter::check`, decremented by `RateLimiter::release`.
    count: AtomicU32,
    // Instant of the most recent successful `check`; read by `cleanup`
    // to decide whether an idle (count == 0) entry may be evicted.
    last_access: Mutex<Instant>,
}
/// Per-IP concurrency limiter: admits at most `max_per_ip` simultaneous
/// reservations per address and evicts idle entries after `window`.
pub struct RateLimiter {
    // Live per-IP state. Entries are created lazily by `check` and
    // removed by `cleanup`.
    limits: DashMap<IpAddr, Arc<IpState>>,
    // Maximum simultaneous reservations allowed per IP.
    max_per_ip: u32,
    // Idle duration after which a zero-count entry may be evicted.
    window: Duration,
}
impl RateLimiter {
    /// Creates a limiter that allows at most `max_per_ip` concurrent
    /// reservations per IP, evicting idle entries after `window`.
    pub fn new(max_per_ip: u32, window: Duration) -> Self {
        Self {
            limits: DashMap::new(),
            max_per_ip,
            window,
        }
    }

    /// Attempts to reserve one slot for `ip`.
    ///
    /// Increments the per-IP counter and returns `Ok(())` while the count
    /// is below `max_per_ip`; otherwise returns a [`RateLimitError`]
    /// carrying the observed count. A successful reservation should be
    /// paired with [`release`](Self::release) (see `RateLimitGuard`).
    pub fn check(&self, ip: IpAddr) -> Result<(), RateLimitError> {
        // Hold the map entry guard across the whole reservation. This
        // closes a race with `cleanup`: the previous code cloned the Arc
        // and dropped the guard, so `cleanup` could evict the entry between
        // the load and the increment; a concurrent `check` would then
        // insert a fresh zero-count state and over-admit the IP.
        let entry = self.limits.entry(ip).or_insert_with(|| {
            Arc::new(IpState {
                count: AtomicU32::new(0),
                last_access: Mutex::new(Instant::now()),
            })
        });
        let state = entry.value();
        // `fetch_update` retries the CAS internally; returning `None` from
        // the closure aborts the update and yields the observed count.
        let reserved = state
            .count
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |count| {
                if count < self.max_per_ip {
                    Some(count + 1)
                } else {
                    None
                }
            });
        match reserved {
            Ok(_) => {
                // Lock order here is shard guard -> last_access mutex, the
                // same order `cleanup` uses inside `retain`, so this cannot
                // deadlock with the cleanup task.
                *state.last_access.lock() = Instant::now();
                Ok(())
            }
            Err(current) => Err(RateLimitError {
                ip,
                current,
                max: self.max_per_ip,
            }),
        }
    }

    /// Releases one previously reserved slot for `ip`.
    ///
    /// Saturates at zero so an unmatched release cannot underflow. Also
    /// refreshes `last_access` so the idle-eviction window in `cleanup`
    /// is measured from the most recent activity, not merely the last
    /// successful `check`.
    pub fn release(&self, ip: IpAddr) {
        if let Some(state) = self.limits.get(&ip) {
            let _ = state
                .count
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |count| {
                    Some(count.saturating_sub(1))
                });
            *state.last_access.lock() = Instant::now();
        }
    }

    /// Returns the number of currently reserved slots for `ip`
    /// (0 when the IP has no entry).
    pub fn current_count(&self, ip: IpAddr) -> u32 {
        self.limits
            .get(&ip)
            .map(|s| s.count.load(Ordering::SeqCst))
            .unwrap_or(0)
    }

    /// Evicts idle entries: zero in-flight count and no activity within
    /// `window`. Entries with a nonzero count are always retained.
    pub fn cleanup(&self) {
        let now = Instant::now();
        self.limits.retain(|_, state| {
            let last = *state.last_access.lock();
            let count = state.count.load(Ordering::SeqCst);
            count > 0 || now.duration_since(last) < self.window
        });
    }

    /// Spawns a background tokio task that runs `cleanup` every
    /// `window / 2`. Must be called from within a tokio runtime.
    pub fn start_cleanup_task(self: Arc<Self>) {
        // Clamp to a nonzero period: `tokio::time::interval` panics on a
        // zero duration, which `window / 2` yields for degenerate windows.
        let period = (self.window / 2).max(Duration::from_millis(1));
        let limiter = self;
        tokio::spawn(async move {
            let mut ticker = tokio::time::interval(period);
            loop {
                ticker.tick().await;
                limiter.cleanup();
            }
        });
    }
}
/// Error returned by `RateLimiter::check` when `ip` has already reached
/// its concurrency limit.
#[derive(Debug, Clone)]
pub struct RateLimitError {
    // Address that was refused.
    pub ip: IpAddr,
    // Count observed at the moment of refusal (>= `max`).
    pub current: u32,
    // Configured per-IP limit that was hit.
    pub max: u32,
}
impl std::fmt::Display for RateLimitError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Rate limit exceeded for {}: {} concurrent tests (max {})",
self.ip, self.current, self.max
)
}
}
impl std::error::Error for RateLimitError {}
/// Scoped handle for a granted reservation: the `Drop` impl below calls
/// `RateLimiter::release` for `ip`, so the slot is returned automatically
/// when the guard goes out of scope.
pub struct RateLimitGuard {
    limiter: Arc<RateLimiter>,
    ip: IpAddr,
}
impl RateLimitGuard {
    /// Wraps an already-granted reservation for `ip`; the matching
    /// `release` happens automatically when the guard is dropped.
    ///
    /// Note: `new` does not itself call `check` — construct the guard only
    /// after `check` returned `Ok`.
    #[must_use = "dropping the guard immediately releases the rate-limit slot"]
    pub fn new(limiter: Arc<RateLimiter>, ip: IpAddr) -> Self {
        Self { limiter, ip }
    }
}
impl Drop for RateLimitGuard {
    /// Returns the reserved slot for `self.ip`, pairing the `check` that
    /// preceded this guard's construction.
    fn drop(&mut self) {
        self.limiter.release(self.ip);
    }
}
/// Configuration for building an optional `RateLimiter`.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    // Per-IP concurrency cap; `None` disables rate limiting entirely
    // (see `build`, which returns `None` in that case).
    pub max_per_ip: Option<u32>,
    // Idle-eviction window in seconds, passed to `RateLimiter::new`.
    pub window_secs: u64,
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
max_per_ip: None, window_secs: 60,
}
}
}
impl RateLimitConfig {
    /// Builds a shared limiter from this configuration, or `None` when
    /// `max_per_ip` is unset (limiting disabled).
    pub fn build(&self) -> Option<Arc<RateLimiter>> {
        let max = self.max_per_ip?;
        let window = Duration::from_secs(self.window_secs);
        Some(Arc::new(RateLimiter::new(max, window)))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    // Shorthand for a 192.168.1.x test address.
    fn ip(last_octet: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(192, 168, 1, last_octet))
    }

    #[test]
    fn test_allows_under_limit() {
        let limiter = RateLimiter::new(2, Duration::from_secs(60));
        let addr = ip(1);
        assert!(limiter.check(addr).is_ok());
        assert!(limiter.check(addr).is_ok());
        assert_eq!(limiter.current_count(addr), 2);
    }

    #[test]
    fn test_blocks_over_limit() {
        let limiter = RateLimiter::new(2, Duration::from_secs(60));
        let addr = ip(1);
        // Two reservations fill the limit; the third must be refused.
        assert!(limiter.check(addr).is_ok());
        assert!(limiter.check(addr).is_ok());
        assert!(limiter.check(addr).is_err());
    }

    #[test]
    fn test_release_allows_new() {
        let limiter = RateLimiter::new(1, Duration::from_secs(60));
        let addr = ip(1);
        assert!(limiter.check(addr).is_ok());
        assert!(limiter.check(addr).is_err());
        // Releasing the slot makes room for a new reservation.
        limiter.release(addr);
        assert!(limiter.check(addr).is_ok());
    }

    #[test]
    fn test_different_ips_independent() {
        let limiter = RateLimiter::new(1, Duration::from_secs(60));
        let first = ip(1);
        let second = ip(2);
        // Each IP gets its own counter.
        assert!(limiter.check(first).is_ok());
        assert!(limiter.check(second).is_ok());
        assert!(limiter.check(first).is_err());
        assert!(limiter.check(second).is_err());
    }
}