use std::collections::{HashMap, VecDeque};
use std::time::{Duration, Instant};
use async_trait::async_trait;
use awp_types::TrustLevel;
use dashmap::DashMap;
/// Admission check for incoming requests, keyed by caller identity and trust level.
#[async_trait]
pub trait RateLimiter: Send + Sync {
    /// Decides whether one more request from `key` at `trust_level` may proceed.
    ///
    /// Returns `Ok(())` when the request is admitted. On rejection it returns `Err` carrying
    /// a retry-after hint in seconds (this is how [`InMemoryRateLimiter`] fills it in).
    async fn check(&self, key: &str, trust_level: TrustLevel) -> Result<(), u64>;
}
/// Request budget applied to a single trust level.
#[derive(Debug, Clone, Copy)]
pub struct RateLimitConfig {
    /// Maximum number of requests admitted per key within one window.
    pub max_requests: u64,
    /// Intended window length, in seconds.
    ///
    /// NOTE(review): `InMemoryRateLimiter::check` does not read this field; it applies the
    /// limiter's own `window_size` to every trust level. Keep the two consistent when
    /// building a limiter with `with_config`.
    pub window_secs: u64,
}
/// Process-local sliding-window rate limiter.
///
/// Each `(trust level, key)` pair keeps a queue of the instants at which its requests were
/// admitted. Timestamps older than `window_size` are discarded lazily, the next time that
/// pair is checked.
pub struct InMemoryRateLimiter {
    /// Admitted-request timestamps keyed by `"{trust_level}:{key}"`, oldest first.
    ///
    /// NOTE(review): nothing in this file removes keys from this map (only their timestamps
    /// are trimmed), so it grows with the number of distinct keys seen. Consider periodic
    /// pruning for long-running processes.
    windows: DashMap<String, VecDeque<Instant>>,
    /// Per-trust-level budgets; trust levels absent from this map are not limited.
    limits: HashMap<TrustLevel, RateLimitConfig>,
    /// Sliding-window length shared by every trust level.
    window_size: Duration,
}
impl InMemoryRateLimiter {
    /// Creates a limiter with the built-in budgets and a 60-second sliding window.
    ///
    /// | Trust level | Requests per window |
    /// |-------------|---------------------|
    /// | `Anonymous` | 30                  |
    /// | `Known`     | 120                 |
    /// | `Partner`   | 600                 |
    ///
    /// Trust levels not listed have no budget and are not limited.
    pub fn new() -> Self {
        let per_minute = |max_requests| RateLimitConfig { max_requests, window_secs: 60 };
        let defaults = HashMap::from([
            (TrustLevel::Anonymous, per_minute(30)),
            (TrustLevel::Known, per_minute(120)),
            (TrustLevel::Partner, per_minute(600)),
        ]);
        Self::with_config(defaults, Duration::from_secs(60))
    }

    /// Creates a limiter from an explicit budget table and sliding-window length.
    ///
    /// `window_size` applies to every trust level in `limits`; trust levels missing from
    /// `limits` are not limited. No request history is carried over: the limiter starts empty.
    pub fn with_config(
        limits: HashMap<TrustLevel, RateLimitConfig>,
        window_size: Duration,
    ) -> Self {
        Self {
            windows: DashMap::new(),
            limits,
            window_size,
        }
    }
}
impl Default for InMemoryRateLimiter {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl RateLimiter for InMemoryRateLimiter {
async fn check(&self, key: &str, trust_level: TrustLevel) -> Result<(), u64> {
let config = match self.limits.get(&trust_level) {
Some(c) => *c,
None => return Ok(()),
};
let composite_key = format!("{trust_level}:{key}");
let now = Instant::now();
let window_start = now - self.window_size;
let mut entry = self.windows.entry(composite_key).or_default();
let deque = entry.value_mut();
while deque.front().is_some_and(|t| *t < window_start) {
deque.pop_front();
}
if deque.len() as u64 >= config.max_requests {
let oldest = deque.front().copied().unwrap_or(now);
let expires_at = oldest + self.window_size;
let retry_after = expires_at.duration_since(now).as_secs().max(1);
return Err(retry_after);
}
deque.push_back(now);
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_anonymous_under_limit() {
        let limiter = InMemoryRateLimiter::new();
        for attempt in 0..30 {
            let outcome = limiter.check("client1", TrustLevel::Anonymous).await;
            assert!(outcome.is_ok(), "anonymous request #{attempt} should be admitted");
        }
    }

    #[tokio::test]
    async fn test_anonymous_over_limit() {
        let limiter = InMemoryRateLimiter::new();
        for _ in 0..30 {
            limiter.check("client1", TrustLevel::Anonymous).await.unwrap();
        }
        let retry_after = limiter
            .check("client1", TrustLevel::Anonymous)
            .await
            .expect_err("31st anonymous request within the window must be rejected");
        assert!(retry_after > 0);
    }

    #[tokio::test]
    async fn test_internal_unlimited() {
        let limiter = InMemoryRateLimiter::new();
        for attempt in 0..1000 {
            let outcome = limiter.check("client1", TrustLevel::Internal).await;
            assert!(outcome.is_ok(), "internal request #{attempt} should never be limited");
        }
    }

    #[tokio::test]
    async fn test_different_keys_independent() {
        let limiter = InMemoryRateLimiter::new();
        for _ in 0..30 {
            limiter.check("client1", TrustLevel::Anonymous).await.unwrap();
        }
        let other_client = limiter.check("client2", TrustLevel::Anonymous).await;
        assert!(other_client.is_ok(), "client2 must not share client1's budget");
    }

    #[tokio::test]
    async fn test_different_trust_levels_independent() {
        let limiter = InMemoryRateLimiter::new();
        for _ in 0..30 {
            limiter.check("client1", TrustLevel::Anonymous).await.unwrap();
        }
        let as_known = limiter.check("client1", TrustLevel::Known).await;
        assert!(as_known.is_ok(), "Known budget must be separate from Anonymous");
    }

    #[tokio::test]
    async fn test_custom_config() {
        let limits = HashMap::from([(
            TrustLevel::Anonymous,
            RateLimitConfig { max_requests: 2, window_secs: 1 },
        )]);
        let limiter = InMemoryRateLimiter::with_config(limits, Duration::from_secs(1));
        let first = limiter.check("c", TrustLevel::Anonymous).await;
        let second = limiter.check("c", TrustLevel::Anonymous).await;
        let third = limiter.check("c", TrustLevel::Anonymous).await;
        assert!(first.is_ok() && second.is_ok(), "first two requests fit the budget");
        assert!(third.is_err(), "third request exceeds max_requests = 2");
    }

    #[tokio::test]
    async fn test_known_higher_limit() {
        let limiter = InMemoryRateLimiter::new();
        for attempt in 0..120 {
            let outcome = limiter.check("client1", TrustLevel::Known).await;
            assert!(outcome.is_ok(), "known request #{attempt} should be admitted");
        }
        assert!(limiter.check("client1", TrustLevel::Known).await.is_err());
    }

    #[tokio::test]
    async fn test_partner_higher_limit() {
        let limiter = InMemoryRateLimiter::new();
        for attempt in 0..600 {
            let outcome = limiter.check("client1", TrustLevel::Partner).await;
            assert!(outcome.is_ok(), "partner request #{attempt} should be admitted");
        }
        assert!(limiter.check("client1", TrustLevel::Partner).await.is_err());
    }
}