use super::error::AuthError;
use dashmap::DashMap;
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Sliding-window rate limiter that keeps per-app request timestamps in memory.
///
/// Each app id maps to the instants of its recently admitted requests, oldest first.
/// Entries whose timestamps have all expired are dropped by `cleanup`, which can be
/// run periodically via `start_cleanup_task`.
#[derive(Debug)]
pub struct InMemoryRateLimiter {
    /// Admission timestamps per app id, ordered oldest (front) to newest (back).
    windows: DashMap<String, VecDeque<Instant>>,
    /// Period between sweeps performed by the background cleanup task.
    cleanup_interval: Duration,
    /// Length of the sliding window; timestamps older than this no longer count.
    window_duration: Duration,
}
impl InMemoryRateLimiter {
pub fn new(cleanup_interval: Duration) -> Self {
Self {
windows: DashMap::new(),
cleanup_interval,
window_duration: Duration::from_secs(60),
}
}
#[cfg(test)]
pub fn with_window_duration(mut self, window: Duration) -> Self {
self.window_duration = window;
self
}
pub fn check_rate_limit(&self, app_id: &str, limit_per_minute: u32) -> Result<(), AuthError> {
let now = Instant::now();
let window = self.window_duration;
let mut entry = self.windows.entry(app_id.to_owned()).or_default();
let timestamps = entry.value_mut();
while let Some(&front) = timestamps.front() {
if now.duration_since(front) > window {
timestamps.pop_front();
} else {
break;
}
}
if timestamps.len() >= limit_per_minute as usize {
return Err(AuthError::RateLimitExceeded);
}
timestamps.push_back(now);
Ok(())
}
pub fn start_cleanup_task(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
let interval = self.cleanup_interval;
tokio::spawn(async move {
let mut tick = tokio::time::interval(interval);
tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
tick.tick().await;
self.cleanup();
}
})
}
pub(crate) fn cleanup(&self) {
let now = Instant::now();
let window = self.window_duration;
self.windows.retain(|_key, timestamps| {
while let Some(&front) = timestamps.front() {
if now.duration_since(front) > window {
timestamps.pop_front();
} else {
break;
}
}
!timestamps.is_empty()
});
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn allows_requests_under_limit() {
        let limiter = InMemoryRateLimiter::new(Duration::from_secs(60));
        for _ in 0..5 {
            assert!(limiter.check_rate_limit("app1", 10).is_ok());
        }
    }

    #[test]
    fn rejects_requests_over_limit() {
        let limiter = InMemoryRateLimiter::new(Duration::from_secs(60));
        for _ in 0..10 {
            limiter.check_rate_limit("app1", 10).unwrap();
        }
        let result = limiter.check_rate_limit("app1", 10);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), AuthError::RateLimitExceeded));
    }

    #[test]
    fn independent_limits_per_app() {
        let limiter = InMemoryRateLimiter::new(Duration::from_secs(60));
        for _ in 0..5 {
            limiter.check_rate_limit("app1", 5).unwrap();
        }
        assert!(limiter.check_rate_limit("app1", 5).is_err());
        assert!(limiter.check_rate_limit("app2", 5).is_ok());
    }

    /// Cleanup must keep entries whose timestamps are still inside the window.
    #[test]
    fn cleanup_retains_active_entries() {
        let limiter = InMemoryRateLimiter::new(Duration::from_secs(60));
        limiter.check_rate_limit("app1", 100).unwrap();
        assert!(!limiter.windows.is_empty());
        limiter.cleanup();
        assert!(
            !limiter.windows.is_empty(),
            "an entry inside the window must survive cleanup"
        );
    }

    #[test]
    fn zero_limit_always_rejects() {
        let limiter = InMemoryRateLimiter::new(Duration::from_secs(60));
        let result = limiter.check_rate_limit("app1", 0);
        assert!(matches!(result.unwrap_err(), AuthError::RateLimitExceeded));
    }

    #[test]
    fn limit_of_one_allows_single_request() {
        let limiter = InMemoryRateLimiter::new(Duration::from_secs(60));
        assert!(limiter.check_rate_limit("app1", 1).is_ok());
        assert!(limiter.check_rate_limit("app1", 1).is_err());
    }

    /// The window slides: once earlier requests age out, the app is admitted again.
    #[tokio::test]
    async fn admits_again_after_window_expires() {
        let limiter = InMemoryRateLimiter::new(Duration::from_secs(60))
            .with_window_duration(Duration::from_millis(100));
        limiter.check_rate_limit("app1", 1).unwrap();
        assert!(limiter.check_rate_limit("app1", 1).is_err());
        tokio::time::sleep(Duration::from_millis(150)).await;
        assert!(
            limiter.check_rate_limit("app1", 1).is_ok(),
            "request should be admitted once the previous one left the window"
        );
    }

    #[tokio::test]
    async fn cleanup_removes_expired_entries() {
        let limiter = InMemoryRateLimiter::new(Duration::from_secs(60))
            .with_window_duration(Duration::from_millis(1));
        limiter.check_rate_limit("app1", 100).unwrap();
        assert!(!limiter.windows.is_empty());
        tokio::time::sleep(Duration::from_millis(5)).await;
        limiter.cleanup();
        assert!(
            limiter.windows.is_empty(),
            "expired entry should have been removed by cleanup"
        );
    }

    /// Many tasks racing on one key must admit exactly `limit` requests.
    #[tokio::test(flavor = "multi_thread")]
    async fn concurrent_rate_limit_enforcement() {
        let limiter = Arc::new(InMemoryRateLimiter::new(Duration::from_secs(60)));
        let limit: u32 = 30;
        let num_tasks: usize = 60;
        let mut handles = Vec::with_capacity(num_tasks);
        for _ in 0..num_tasks {
            let limiter = limiter.clone();
            handles.push(tokio::spawn(async move {
                limiter.check_rate_limit("contended-app", limit).is_ok()
            }));
        }
        let mut accepted = 0u32;
        for handle in handles {
            if handle.await.unwrap() {
                accepted += 1;
            }
        }
        assert_eq!(
            accepted, limit,
            "exactly {limit} requests should have been accepted, but {accepted} were"
        );
    }
}