use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{debug, warn};
/// Per-IP token-bucket rate limiter.
///
/// Cloning is cheap: every clone holds the same `Arc`, so all clones observe
/// and update the same per-client state.
#[derive(Debug, Clone)]
pub struct RateLimiter {
    /// Shared, lock-protected bookkeeping for all tracked clients.
    inner: Arc<RwLock<RateLimiterInner>>,
    /// Sustained refill rate, expressed as tokens per minute.
    requests_per_minute: u32,
    /// Token count a new client starts with, and the cap applied on refill.
    burst_size: u32,
    /// Minimum time between sweeps that evict idle client entries.
    cleanup_interval: Duration,
}
/// Mutable state guarded by the limiter's lock.
#[derive(Debug)]
struct RateLimiterInner {
    /// Token-bucket state for each client address seen so far.
    clients: HashMap<IpAddr, ClientState>,
    /// When idle entries were last swept from `clients`.
    last_cleanup: Instant,
}
/// Token-bucket bookkeeping for a single client address.
#[derive(Debug)]
struct ClientState {
    /// Tokens currently available to spend; one is consumed per admitted request.
    tokens: u32,
    /// Instant from which elapsed time is measured when computing the next refill.
    last_refill: Instant,
    /// Instant of the client's most recent rate-limit check; used to evict idle entries.
    last_request: Instant,
}
impl RateLimiter {
    /// Interval between sweeps of idle client entries used by [`RateLimiter::new`].
    const DEFAULT_CLEANUP_INTERVAL: Duration = Duration::from_secs(300);
    /// Idle time after which a client's entry is evicted during a sweep.
    const STALE_ENTRY_AGE: Duration = Duration::from_secs(600);
    /// Nanoseconds in one minute; refill arithmetic is done at this resolution.
    const NANOS_PER_MINUTE: u128 = 60_000_000_000;

    /// Creates a limiter that refills `requests_per_minute` tokens per minute
    /// and lets each client hold at most `burst_size` tokens.
    ///
    /// A `requests_per_minute` of zero means buckets never refill: each client
    /// may make `burst_size` requests in total while its entry is tracked.
    pub fn new(requests_per_minute: u32, burst_size: u32) -> Self {
        Self {
            inner: Arc::new(RwLock::new(RateLimiterInner {
                clients: HashMap::new(),
                last_cleanup: Instant::now(),
            })),
            requests_per_minute,
            burst_size,
            cleanup_interval: Self::DEFAULT_CLEANUP_INTERVAL,
        }
    }

    /// Records a request from `ip` and reports whether it is allowed.
    ///
    /// Returns `true` and consumes one token when the client has a token
    /// available, `false` otherwise. A previously unseen client starts with a
    /// full bucket. The call takes the write lock, and occasionally sweeps
    /// idle entries before handling the request.
    pub async fn check_rate_limit(&self, ip: IpAddr) -> bool {
        let mut limiter = self.inner.write().await;
        let now = Instant::now();

        if now.saturating_duration_since(limiter.last_cleanup) > self.cleanup_interval {
            self.cleanup_old_entries(&mut limiter, now);
            limiter.last_cleanup = now;
        }

        let state = limiter.clients.entry(ip).or_insert_with(|| ClientState {
            tokens: self.burst_size,
            last_refill: now,
            last_request: now,
        });

        let earned = self.tokens_earned(now.saturating_duration_since(state.last_refill));
        if earned > 0 {
            let room = self.burst_size.saturating_sub(state.tokens);
            if earned >= room {
                // Bucket is full: surplus accrual is discarded, so restart the clock.
                state.tokens = self.burst_size;
                state.last_refill = now;
            } else {
                // `earned < room`, so this can neither overflow nor exceed the cap.
                state.tokens += earned;
                // Advance only by the time those tokens cost, so the fractional
                // remainder keeps counting toward the next token instead of
                // being thrown away on every refill.
                let accounted_until = self
                    .time_to_earn(earned)
                    .and_then(|spent| state.last_refill.checked_add(spent))
                    .unwrap_or(now);
                state.last_refill = accounted_until;
            }
            debug!(
                "Refilled {} tokens for IP {}, total: {}",
                earned, ip, state.tokens
            );
        }

        state.last_request = now;

        if state.tokens == 0 {
            warn!("Rate limit exceeded for IP {}", ip);
            return false;
        }

        state.tokens -= 1;
        debug!(
            "Rate limit OK for IP {}, remaining tokens: {}",
            ip, state.tokens
        );
        true
    }

    /// Returns a snapshot of the bucket for `ip`, or `None` if the client is
    /// not currently tracked.
    ///
    /// The snapshot is read-only: `remaining` includes tokens that have
    /// accrued since the last refill, but nothing is written back.
    /// `reset_time` is the last refill instant plus the time needed to earn
    /// one token; when the refill rate is zero no token will ever accrue and
    /// the last refill instant itself is reported.
    pub async fn get_status(&self, ip: IpAddr) -> Option<RateLimitStatus> {
        let limiter = self.inner.read().await;
        let state = limiter.clients.get(&ip)?;

        let now = Instant::now();
        let earned = self.tokens_earned(now.saturating_duration_since(state.last_refill));
        let remaining = state.tokens.saturating_add(earned).min(self.burst_size);
        let reset_time = self
            .time_to_earn(1)
            .and_then(|interval| state.last_refill.checked_add(interval))
            .unwrap_or(state.last_refill);

        Some(RateLimitStatus {
            remaining,
            limit: self.burst_size,
            reset_time,
        })
    }

    /// Removes every client whose last request is at least
    /// `STALE_ENTRY_AGE` older than `now`.
    fn cleanup_old_entries(&self, limiter: &mut RateLimiterInner, now: Instant) {
        limiter.clients.retain(|ip, state| {
            // Compare elapsed time instead of computing `now - age`, which can
            // panic when the resulting `Instant` is not representable.
            let should_keep =
                now.saturating_duration_since(state.last_request) < Self::STALE_ENTRY_AGE;
            if !should_keep {
                debug!("Cleaned up rate limit entry for IP {}", ip);
            }
            should_keep
        });
    }

    /// Whole tokens earned over `elapsed` at the configured rate, saturating
    /// at `u32::MAX`.
    fn tokens_earned(&self, elapsed: Duration) -> u32 {
        // u128 cannot overflow here: max nanos (~1.8e28) * u32::MAX (~4.3e9) < 2^128.
        let earned =
            elapsed.as_nanos() * u128::from(self.requests_per_minute) / Self::NANOS_PER_MINUTE;
        u32::try_from(earned).unwrap_or(u32::MAX)
    }

    /// Time needed to earn `tokens` tokens, rounded up to a whole nanosecond.
    ///
    /// Returns `None` when the rate is zero or the span does not fit in a
    /// `Duration` built from `u64` nanoseconds.
    fn time_to_earn(&self, tokens: u32) -> Option<Duration> {
        let rate = u128::from(self.requests_per_minute);
        if rate == 0 {
            return None;
        }
        // Round up so the refill clock is never advanced by less than the
        // tokens actually cost (which would over-credit the client).
        let nanos = (u128::from(tokens) * Self::NANOS_PER_MINUTE + rate - 1) / rate;
        u64::try_from(nanos).ok().map(Duration::from_nanos)
    }
}
/// Point-in-time view of one client's rate-limit bucket.
#[derive(Debug, Clone)]
pub struct RateLimitStatus {
    /// Tokens the client could spend at the time of the snapshot.
    pub remaining: u32,
    /// Maximum number of tokens the client's bucket can hold.
    pub limit: u32,
    /// Instant derived from the client's last refill plus the per-token refill interval.
    pub reset_time: Instant,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::IpAddr;
    use std::str::FromStr;

    /// A fresh client may spend exactly `burst_size` tokens before being denied.
    #[tokio::test]
    async fn rate_limiter_allows_initial_requests() {
        let limiter = RateLimiter::new(60, 10);
        let ip = IpAddr::from_str("127.0.0.1").unwrap();

        for _ in 0..10 {
            assert!(limiter.check_rate_limit(ip).await);
        }
        assert!(!limiter.check_rate_limit(ip).await);
    }

    /// After the bucket is drained, waiting longer than one token interval
    /// makes the next request succeed.
    #[tokio::test]
    async fn rate_limiter_refills_tokens() {
        // 60 requests/minute is one token per second.
        let limiter = RateLimiter::new(60, 5);
        let ip = IpAddr::from_str("127.0.0.1").unwrap();

        for _ in 0..5 {
            assert!(limiter.check_rate_limit(ip).await);
        }
        assert!(!limiter.check_rate_limit(ip).await);

        // Wait slightly over one token interval so at least one token accrues;
        // the limiter reads the real clock, so a real sleep is required.
        tokio::time::sleep(tokio::time::Duration::from_millis(1100)).await;
        assert!(limiter.check_rate_limit(ip).await);
    }

    /// Each address gets its own independent bucket.
    #[tokio::test]
    async fn rate_limiter_tracks_multiple_ips() {
        let limiter = RateLimiter::new(60, 3);
        let ip1 = IpAddr::from_str("127.0.0.1").unwrap();
        let ip2 = IpAddr::from_str("192.168.1.1").unwrap();

        assert!(limiter.check_rate_limit(ip1).await);
        assert!(limiter.check_rate_limit(ip2).await);

        let status1 = limiter.get_status(ip1).await.unwrap();
        let status2 = limiter.get_status(ip2).await.unwrap();

        assert_eq!(status1.remaining, 2);
        assert_eq!(status2.remaining, 2);
        assert_eq!(status1.limit, 3);
        assert_eq!(status2.limit, 3);
    }
}