use std::time::Instant;
use tracing::debug;
use crate::types::{ConfigProvider, RateLimitEntry, RateLimiter};
/// Decides whether a request from `ip` fits within the sliding-window rate limit.
///
/// Every address owns a queue of timestamps of its *accepted* requests. On each
/// call the timestamps that have aged out of the window are dropped, and the
/// request is accepted only if fewer than `max_requests` remain. Rejected
/// requests are not recorded.
///
/// Before the check, an opportunistic cleanup sweep may run: it is attempted
/// only when cleanup is enabled, the map holds more than `threshold` entries,
/// and at least `interval` has elapsed since the previous sweep (or no sweep
/// has run yet). The sweep keeps an entry only if its oldest timestamp is
/// younger than twice the window; entries for which `oldest()` yields nothing
/// are dropped as well.
///
/// # Arguments
///
/// * `limiter` - limiter whose state stays locked until this function returns.
/// * `ip` - textual client address, used as the map key.
/// * `config` - source of the rate-limit and cleanup settings.
///
/// # Returns
///
/// `true` if the request is accepted (its timestamp is then recorded),
/// `false` if the address has already used up its quota for the current window.
pub async fn check_rate_limit(
    limiter: &RateLimiter,
    ip: &str,
    config: &impl ConfigProvider,
) -> bool {
    let rate_config = config.rate_limit_config();
    let cleanup_config = config.rate_limit_cleanup_config();

    let mut state = limiter.state().lock().await;

    // Read the clock only once the lock is held. Sampling it before the await
    // would let a task that waited for the lock record a timestamp older than
    // ones already queued by tasks that got in first, which breaks the
    // oldest-first ordering the eviction loop below relies on.
    let now = Instant::now();

    if cleanup_config.is_enabled() && state.entries.len() > cleanup_config.threshold {
        let sweep_due = state
            .last_cleanup
            .map_or(true, |last| now.duration_since(last) >= cleanup_config.interval);

        if sweep_due {
            state.last_cleanup = Some(now);

            // Computed once for the whole sweep; saturating so that a very
            // large configured window cannot overflow and panic under the lock.
            let stale_after = rate_config.window_duration.saturating_mul(2);

            let before_count = state.entries.len();
            state.entries.retain(|_, entry| {
                entry
                    .oldest()
                    .is_some_and(|t| now.duration_since(t) < stale_after)
            });

            let removed = before_count - state.entries.len();
            if removed > 0 {
                debug!(
                    removed_entries = removed,
                    remaining_entries = state.entries.len(),
                    "Rate limiter cleanup completed"
                );
            }
        }
    }

    let entry = state
        .entries
        .entry(ip.to_string())
        .or_insert_with(|| RateLimitEntry {
            timestamps: std::collections::VecDeque::new(),
        });

    // Evict accepted requests that have slid out of the window.
    while let Some(&oldest) = entry.timestamps.front() {
        if now.duration_since(oldest) < rate_config.window_duration {
            break;
        }
        entry.timestamps.pop_front();
    }

    // Checked conversion instead of a truncating `as` cast: a count that does
    // not fit in `u32` can never be below the limit, so it is denied.
    let has_capacity = u32::try_from(entry.timestamps.len())
        .is_ok_and(|accepted| accepted < rate_config.max_requests);

    if has_capacity {
        entry.timestamps.push_back(now);
    }
    has_capacity
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::TestConfig;
    use crate::types::{RateLimitCleanupConfig, RateLimitConfig};
    use std::time::Duration;

    /// Window used by the expiry tests. It is deliberately far longer than a
    /// scheduler hiccup, so the "still inside the window" denial assertions do
    /// not race against the clock.
    const SHORT_WINDOW: Duration = Duration::from_millis(50);

    /// Sleep that is guaranteed to outlast `SHORT_WINDOW`.
    const PAST_SHORT_WINDOW: Duration = Duration::from_millis(75);

    /// Builds a `TestConfig` with an explicit window and cleanup policy; every
    /// other field comes from `TestConfig::default()`.
    fn config_with(
        max_requests: u32,
        window_duration: Duration,
        cleanup_threshold: usize,
        cleanup_interval: Duration,
    ) -> TestConfig {
        TestConfig {
            rate_limit: RateLimitConfig {
                max_requests,
                window_duration,
            },
            cleanup: RateLimitCleanupConfig {
                threshold: cleanup_threshold,
                interval: cleanup_interval,
            },
            ..TestConfig::default()
        }
    }

    /// The very first request from an address is accepted.
    #[tokio::test]
    async fn test_first_request_allowed() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(5, 60);
        let allowed = check_rate_limit(&limiter, "192.168.1.1", &config).await;
        assert!(allowed);
    }

    /// Exactly `max_requests` requests are all accepted.
    #[tokio::test]
    async fn test_requests_within_limit_allowed() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(5, 60);
        for i in 0..5 {
            let allowed = check_rate_limit(&limiter, "192.168.1.1", &config).await;
            assert!(allowed, "Request {} should be allowed", i + 1);
        }
    }

    /// The request after the quota is used up is rejected.
    #[tokio::test]
    async fn test_request_exceeding_limit_blocked() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(3, 60);
        for _ in 0..3 {
            assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        }
        let blocked = check_rate_limit(&limiter, "192.168.1.1", &config).await;
        assert!(!blocked, "Request exceeding limit should be blocked");
    }

    /// Exhausting one address's quota does not affect another address.
    #[tokio::test]
    async fn test_different_ips_independent() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(2, 60);

        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);

        assert!(check_rate_limit(&limiter, "192.168.1.2", &config).await);
        assert!(check_rate_limit(&limiter, "192.168.1.2", &config).await);
        assert!(!check_rate_limit(&limiter, "192.168.1.2", &config).await);
    }

    /// Each accepted request adds one to the entry's request count.
    #[tokio::test]
    async fn test_counter_increments_correctly() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(5, 60);

        check_rate_limit(&limiter, "192.168.1.1", &config).await;
        check_rate_limit(&limiter, "192.168.1.1", &config).await;
        check_rate_limit(&limiter, "192.168.1.1", &config).await;

        let state = limiter.state().lock().await;
        let entry = state.entries.get("192.168.1.1").unwrap();
        assert_eq!(entry.request_count(), 3);
    }

    /// Boundary case: a quota of one accepts a single request only.
    #[tokio::test]
    async fn test_limit_of_one() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(1, 60);
        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);
    }

    /// IPv6 textual addresses are keyed like any other string.
    #[tokio::test]
    async fn test_ipv6_addresses() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(2, 60);

        assert!(check_rate_limit(&limiter, "::1", &config).await);
        assert!(check_rate_limit(&limiter, "::1", &config).await);
        assert!(!check_rate_limit(&limiter, "::1", &config).await);

        assert!(check_rate_limit(&limiter, "2001:db8::1", &config).await);
    }

    /// Rejected requests are not recorded, so the count stays at the quota.
    #[tokio::test]
    async fn test_multiple_blocked_requests() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(1, 60);

        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        for _ in 0..5 {
            assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);
        }

        let state = limiter.state().lock().await;
        let entry = state.entries.get("192.168.1.1").unwrap();
        assert_eq!(entry.request_count(), 1);
    }

    /// A cloned limiter observes and updates the same underlying state.
    #[tokio::test]
    async fn test_limiter_clone_shares_state() {
        let limiter1 = RateLimiter::new();
        let limiter2 = limiter1.clone();
        let config = TestConfig::new().with_rate_limit(2, 60);

        assert!(check_rate_limit(&limiter1, "192.168.1.1", &config).await);
        assert!(check_rate_limit(&limiter2, "192.168.1.1", &config).await);
        assert!(!check_rate_limit(&limiter1, "192.168.1.1", &config).await);
    }

    /// With cleanup off, every tracked address is kept.
    ///
    /// NOTE(review): this relies on `TestConfig::new()` leaving cleanup
    /// disabled (presumably a threshold of zero) — confirm in `test_utils`.
    #[tokio::test]
    async fn test_cleanup_disabled_when_threshold_zero() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(100, 60);

        for i in 0..100 {
            check_rate_limit(&limiter, &format!("192.168.1.{}", i), &config).await;
        }

        let state = limiter.state().lock().await;
        assert_eq!(state.entries.len(), 100);
    }

    /// One map entry is created per distinct address.
    #[tokio::test]
    async fn test_entries_tracked_per_ip() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(10, 60);

        for i in 0..5 {
            check_rate_limit(&limiter, &format!("10.0.0.{}", i), &config).await;
        }

        let state = limiter.state().lock().await;
        assert_eq!(state.entries.len(), 5);
    }

    /// Once the window has elapsed, a previously blocked address is accepted again.
    #[tokio::test]
    async fn test_window_reset_after_expiration() {
        let limiter = RateLimiter::new();
        let config = config_with(2, SHORT_WINDOW, 0, Duration::from_secs(60));

        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);

        tokio::time::sleep(PAST_SHORT_WINDOW).await;

        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
    }

    /// After the window elapses the full quota is available again, and it is
    /// enforced again once used up.
    #[tokio::test]
    async fn test_window_reset_resets_counter() {
        let limiter = RateLimiter::new();
        let config = config_with(3, SHORT_WINDOW, 0, Duration::from_secs(60));

        for _ in 0..3 {
            assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        }
        assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);

        tokio::time::sleep(PAST_SHORT_WINDOW).await;

        for _ in 0..3 {
            assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        }
        assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);
    }

    /// Inside a long window the count only grows and the quota still applies.
    #[tokio::test]
    async fn test_window_not_expired_keeps_count() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(5, 3600);

        for _ in 0..3 {
            assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        }

        // Scoped so the guard is released before the next rate-limit calls,
        // which need to take the same lock.
        {
            let state = limiter.state().lock().await;
            assert_eq!(state.entries.get("192.168.1.1").unwrap().request_count(), 3);
        }

        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);

        let state = limiter.state().lock().await;
        assert_eq!(state.entries.get("192.168.1.1").unwrap().request_count(), 5);
    }

    /// Windows are tracked per address: one address's window expiring does not
    /// depend on when another address was first seen.
    #[tokio::test]
    async fn test_different_ips_different_windows() {
        let limiter = RateLimiter::new();
        let config = config_with(2, Duration::from_millis(50), 0, Duration::from_secs(60));

        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);

        tokio::time::sleep(Duration::from_millis(10)).await;
        assert!(check_rate_limit(&limiter, "192.168.1.2", &config).await);

        // 10 ms + 50 ms puts the first address past its 50 ms window.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);

        let state = limiter.state().lock().await;
        assert!(state.entries.contains_key("192.168.1.1"));
        assert!(state.entries.contains_key("192.168.1.2"));
    }

    /// With cleanup enabled, entries whose timestamps are long expired are
    /// swept out once the map grows past the threshold.
    #[tokio::test]
    async fn test_cleanup_removes_expired_entries() {
        let limiter = RateLimiter::new();
        let config = config_with(100, Duration::from_millis(1), 1, Duration::from_millis(1));

        check_rate_limit(&limiter, "192.168.1.1", &config).await;
        tokio::time::sleep(Duration::from_millis(10)).await;

        check_rate_limit(&limiter, "192.168.1.2", &config).await;
        tokio::time::sleep(Duration::from_millis(10)).await;

        // Third call sees two entries (> threshold of 1), so the sweep runs
        // before this address is inserted.
        check_rate_limit(&limiter, "192.168.1.3", &config).await;

        let state = limiter.state().lock().await;
        assert!(
            !state.entries.contains_key("192.168.1.1"),
            "Expired entry should have been cleaned up"
        );
        assert!(
            state.entries.contains_key("192.168.1.3"),
            "Recent entry should still exist"
        );
    }

    /// Concurrent requests for one address are counted exactly: all ten fit
    /// the quota of ten, and the eleventh is rejected.
    #[tokio::test]
    async fn test_concurrent_requests_same_ip() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(10, 60);

        let mut handles = vec![];
        for _ in 0..10 {
            let limiter_clone = limiter.clone();
            let handle = tokio::spawn(async move {
                // Each task builds its own config so nothing borrowed crosses
                // the spawn boundary.
                let config = TestConfig::new().with_rate_limit(10, 60);
                check_rate_limit(&limiter_clone, "192.168.1.1", &config).await
            });
            handles.push(handle);
        }

        let results: Vec<bool> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        assert_eq!(results.iter().filter(|&&r| r).count(), 10);
        assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);
    }

    /// Concurrent first requests from fifty distinct addresses are all
    /// accepted and each gets its own entry.
    #[tokio::test]
    async fn test_concurrent_requests_different_ips() {
        let limiter = RateLimiter::new();

        let mut handles = vec![];
        for i in 0..50 {
            let limiter_clone = limiter.clone();
            let ip = format!("192.168.1.{}", i);
            let handle = tokio::spawn(async move {
                let config = TestConfig::new().with_rate_limit(5, 60);
                check_rate_limit(&limiter_clone, &ip, &config).await
            });
            handles.push(handle);
        }

        let results: Vec<bool> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        assert!(results.iter().all(|&r| r));

        let state = limiter.state().lock().await;
        assert_eq!(state.entries.len(), 50);
    }
}