use dashmap::DashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
/// Upper bound on the number of keys tracked at once; once reached, requests
/// from keys not already in the table are rejected.
const MAX_ENTRIES: usize = 50_000;
/// Number of `check_and_increment` calls between opportunities to sweep
/// expired entries (a sweep only runs if the table is also at capacity).
const CAPACITY_CHECK_INTERVAL: u64 = 128;
/// Per-key fixed-window request limiter backed by a concurrent map.
///
/// Each key gets its own window: the first request after a window has
/// elapsed starts a new one, and at most `limit` requests are admitted per
/// window. The number of tracked keys is capped at `MAX_ENTRIES`.
pub struct RateLimiter {
    // --- configuration ---
    /// Maximum number of requests a key may make per window.
    limit: u32,
    /// Length of each key's counting window.
    window: Duration,

    // --- per-key state ---
    /// Key -> (requests counted in the current window, instant the window began).
    state: DashMap<String, (u32, Instant)>,

    // --- statistics ---
    /// Requests admitted by `check_and_increment`.
    allowed: AtomicU64,
    /// Requests rejected by `check_and_increment` (over limit or table full).
    limited: AtomicU64,
    /// Unseen keys rejected because the table was at capacity.
    evicted: AtomicU64,
    /// Total `check_and_increment` calls; used to pace expired-entry sweeps.
    ops: AtomicU64,
}
impl RateLimiter {
    /// Creates a limiter that admits `limit_per_minute` requests per key in
    /// each 60-second window.
    pub fn new(limit_per_minute: u32) -> Self {
        Self::with_window(limit_per_minute, Duration::from_secs(60))
    }

    /// Creates a limiter that admits `limit` requests per key in each `window`.
    pub fn with_window(limit: u32, window: Duration) -> Self {
        Self {
            state: DashMap::new(),
            limit,
            window,
            allowed: AtomicU64::new(0),
            limited: AtomicU64::new(0),
            ops: AtomicU64::new(0),
            evicted: AtomicU64::new(0),
        }
    }

    /// Records a request from `ip` and reports whether it is within the limit.
    ///
    /// If the key's window has elapsed, it is restarted at the current instant
    /// before counting. Returns `true` (and counts the request) when the key has
    /// made fewer than `limit` requests in its current window.
    ///
    /// Returns `false` in two cases:
    /// - the key has already used its allowance for this window;
    /// - the table already tracks `MAX_ENTRIES` keys and `ip` is not one of them.
    ///   In this case the `evicted` counter is bumped; despite its name, that
    ///   counter tallies these capacity rejections, not entries removed by sweeps.
    ///
    /// The capacity bound is soft. The length check and the insertion are
    /// separate map operations, so concurrent callers can push the table
    /// slightly past `MAX_ENTRIES`.
    pub fn check_and_increment(&self, ip: &str) -> bool {
        let now = Instant::now();
        let ops = self.ops.fetch_add(1, Ordering::Relaxed);

        // A sweep is O(table size), so only attempt one every
        // CAPACITY_CHECK_INTERVAL calls, and only when the table is full.
        if ops.is_multiple_of(CAPACITY_CHECK_INTERVAL) && self.state.len() >= MAX_ENTRIES {
            self.evict_expired(now);
        }

        // Table full: reject unseen keys rather than grow without bound.
        if self.state.len() >= MAX_ENTRIES && !self.state.contains_key(ip) {
            self.limited.fetch_add(1, Ordering::Relaxed);
            self.evicted.fetch_add(1, Ordering::Relaxed);
            return false;
        }

        let allowed = {
            // The entry guard holds this key's shard lock until the block ends,
            // so the reset-and-increment below is atomic per key.
            let mut slot = self.state.entry(ip.to_string()).or_insert((0, now));
            let (count, window_start) = &mut *slot;
            if now.duration_since(*window_start) >= self.window {
                *count = 0;
                *window_start = now;
            }
            if *count < self.limit {
                *count += 1;
                true
            } else {
                false
            }
        };

        let counter = if allowed { &self.allowed } else { &self.limited };
        counter.fetch_add(1, Ordering::Relaxed);
        allowed
    }

    /// Removes every entry whose window has elapsed as of `now`.
    fn evict_expired(&self, now: Instant) {
        self.state
            .retain(|_, (_, window_start)| now.duration_since(*window_start) < self.window);
    }

    /// Reports whether a request from `ip` would currently be allowed, without
    /// recording it.
    ///
    /// Unknown keys and keys whose window has elapsed are reported as allowed.
    /// Unlike [`check_and_increment`](Self::check_and_increment), this ignores the
    /// table-capacity bound. It may therefore return `true` for an unseen key
    /// that `check_and_increment` would reject.
    pub fn check(&self, ip: &str) -> bool {
        let now = Instant::now();
        if let Some(entry) = self.state.get(ip) {
            if now.duration_since(entry.1) >= self.window {
                return true;
            }
            entry.0 < self.limit
        } else {
            true
        }
    }

    /// Returns how many requests `ip` has made in its current window.
    ///
    /// Returns 0 for unknown keys and for keys whose window has elapsed. This
    /// matches [`check`](Self::check), which treats an elapsed window as fresh;
    /// previously the stale count from the lapsed window was reported.
    pub fn get_count(&self, ip: &str) -> u32 {
        let now = Instant::now();
        self.state
            .get(ip)
            .filter(|entry| now.duration_since(entry.1) < self.window)
            .map_or(0, |entry| entry.0)
    }

    /// Removes entries whose window began at least two windows ago.
    ///
    /// The doubled age saturates at `Duration::MAX` instead of panicking on
    /// overflow, so very large windows are handled safely.
    pub fn cleanup(&self) {
        let now = Instant::now();
        let max_age = self.window.saturating_mul(2);
        self.state
            .retain(|_, (_, window_start)| now.duration_since(*window_start) < max_age);
    }

    /// Number of keys currently tracked.
    pub fn len(&self) -> usize {
        self.state.len()
    }

    /// Returns `true` when no keys are tracked.
    pub fn is_empty(&self) -> bool {
        self.state.is_empty()
    }

    /// Returns the current configuration and counters.
    ///
    /// Counters are loaded one at a time with relaxed ordering, so under
    /// concurrent use they are not a single consistent snapshot.
    /// `window_secs` is in whole seconds, so a sub-second window reports 0.
    pub fn stats(&self) -> RateLimiterStats {
        RateLimiterStats {
            tracked_ips: self.state.len(),
            max_entries: MAX_ENTRIES,
            allowed: self.allowed.load(Ordering::Relaxed),
            limited: self.limited.load(Ordering::Relaxed),
            evicted: self.evicted.load(Ordering::Relaxed),
            limit: self.limit,
            window_secs: self.window.as_secs(),
        }
    }

    /// Maximum number of keys the limiter tracks before rejecting unseen keys.
    pub fn max_entries(&self) -> usize {
        MAX_ENTRIES
    }

    /// Forgets all tracked keys and zeroes every counter.
    ///
    /// Each field is reset by a separate operation. Calls racing with `reset`
    /// may therefore land on either side of it.
    pub fn reset(&self) {
        self.state.clear();
        self.allowed.store(0, Ordering::Relaxed);
        self.limited.store(0, Ordering::Relaxed);
        self.ops.store(0, Ordering::Relaxed);
        self.evicted.store(0, Ordering::Relaxed);
    }
}
/// Configuration and counters reported by `RateLimiter::stats`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct RateLimiterStats {
/// Number of keys currently tracked.
pub tracked_ips: usize,
/// Maximum number of keys tracked before unseen keys are rejected.
pub max_entries: usize,
/// Requests admitted by `check_and_increment`.
pub allowed: u64,
/// Requests rejected by `check_and_increment` (over limit or table full).
pub limited: u64,
/// Unseen keys rejected because the table was full. Entries removed by
/// expiry sweeps are not counted here.
pub evicted: u64,
/// Requests admitted per key per window.
pub limit: u32,
/// Window length in whole seconds (sub-second windows are truncated).
pub window_secs: u64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_new_ip_allowed() {
        let limiter = RateLimiter::new(10);
        assert!(limiter.check_and_increment("192.168.1.1"));
    }

    #[test]
    fn test_within_limit() {
        let limiter = RateLimiter::new(5);
        let ip = "10.0.0.1";
        for _ in 0..5 {
            assert!(limiter.check_and_increment(ip));
        }
        assert_eq!(limiter.get_count(ip), 5);
    }

    #[test]
    fn test_exceeds_limit() {
        let limiter = RateLimiter::new(3);
        let ip = "10.0.0.2";
        assert!(limiter.check_and_increment(ip));
        assert!(limiter.check_and_increment(ip));
        assert!(limiter.check_and_increment(ip));
        assert!(!limiter.check_and_increment(ip));
        assert!(!limiter.check_and_increment(ip));
    }

    #[test]
    fn test_window_reset() {
        let limiter = RateLimiter::with_window(2, Duration::from_millis(50));
        let ip = "10.0.0.3";
        assert!(limiter.check_and_increment(ip));
        assert!(limiter.check_and_increment(ip));
        assert!(!limiter.check_and_increment(ip));
        thread::sleep(Duration::from_millis(60));
        assert!(limiter.check_and_increment(ip));
    }

    #[test]
    fn test_different_ips_independent() {
        let limiter = RateLimiter::new(2);
        assert!(limiter.check_and_increment("ip1"));
        assert!(limiter.check_and_increment("ip1"));
        assert!(!limiter.check_and_increment("ip1"));
        assert!(limiter.check_and_increment("ip2"));
        assert!(limiter.check_and_increment("ip2"));
        assert!(!limiter.check_and_increment("ip2"));
    }

    #[test]
    fn test_check_without_increment() {
        let limiter = RateLimiter::new(2);
        let ip = "10.0.0.4";
        assert!(limiter.check(ip));
        // `check` must not create or bump an entry.
        assert_eq!(limiter.get_count(ip), 0);
        limiter.check_and_increment(ip);
        limiter.check_and_increment(ip);
        assert!(!limiter.check(ip));
    }

    #[test]
    fn test_cleanup() {
        let limiter = RateLimiter::with_window(10, Duration::from_millis(25));
        limiter.check_and_increment("ip1");
        limiter.check_and_increment("ip2");
        assert_eq!(limiter.len(), 2);
        // Sleep past two windows so `cleanup` drops both entries.
        thread::sleep(Duration::from_millis(60));
        limiter.cleanup();
        assert_eq!(limiter.len(), 0);
    }

    #[test]
    fn test_stats() {
        let limiter = RateLimiter::new(2);
        limiter.check_and_increment("ip1");
        limiter.check_and_increment("ip1");
        limiter.check_and_increment("ip1");
        let stats = limiter.stats();
        assert_eq!(stats.tracked_ips, 1);
        assert_eq!(stats.max_entries, MAX_ENTRIES);
        assert_eq!(stats.allowed, 2);
        assert_eq!(stats.limited, 1);
        assert_eq!(stats.limit, 2);
    }

    #[test]
    fn test_capacity_bound() {
        let limiter = RateLimiter::with_window(100, Duration::from_secs(60));
        assert_eq!(limiter.max_entries(), MAX_ENTRIES);
        assert_eq!(limiter.stats().evicted, 0);

        // Fill the table to capacity with distinct keys. None expire within the
        // 60s window, so no sweep can free space during this test.
        for i in 0..MAX_ENTRIES {
            assert!(limiter.check_and_increment(&format!("ip{i}")));
        }
        assert_eq!(limiter.len(), MAX_ENTRIES);

        // An unseen key is rejected once the table is full...
        assert!(!limiter.check_and_increment("newcomer"));
        // ...while keys that are already tracked keep being served.
        assert!(limiter.check_and_increment("ip0"));

        let stats = limiter.stats();
        assert_eq!(stats.tracked_ips, MAX_ENTRIES);
        assert_eq!(stats.evicted, 1);
        assert_eq!(stats.limited, 1);
        assert_eq!(stats.allowed, MAX_ENTRIES as u64 + 1);
    }

    #[test]
    fn test_reset() {
        let limiter = RateLimiter::new(10);
        limiter.check_and_increment("ip1");
        limiter.check_and_increment("ip2");
        limiter.reset();
        assert!(limiter.is_empty());
        let stats = limiter.stats();
        assert_eq!(stats.allowed, 0);
        assert_eq!(stats.limited, 0);
    }

    #[test]
    fn test_concurrent_access() {
        use std::sync::Arc;

        let limiter = Arc::new(RateLimiter::new(100));
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let limiter = Arc::clone(&limiter);
                thread::spawn(move || {
                    for _ in 0..10 {
                        limiter.check_and_increment(&format!("ip{i}"));
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(limiter.len(), 10);
        let stats = limiter.stats();
        assert_eq!(stats.allowed, 100);
    }
}