use super::quota::QuotaTracker;
use super::types::RateLimitDecision;
use crate::config::rate_limit::{RateLimitTier, TierLimits};
use async_trait::async_trait;
use governor::{
clock::DefaultClock,
state::{InMemoryState, NotKeyed},
Quota, RateLimiter as GovernorRateLimiter,
};
use std::collections::HashMap;
use std::num::NonZeroU32;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
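/// Common interface for rate-limit checks, so callers can swap backends
/// (e.g. the in-memory `MultiTierRateLimiter` below) behind one async call.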
#[async_trait]
pub trait RateLimiter: Send + Sync {
async fn check_rate_limit(&self, key: &str, tier: RateLimitTier) -> RateLimitDecision;
}
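/// Per-client state: a direct (unkeyed) governor limiter plus the tier and
/// limits it was built from, so later tier changes can be detected.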
struct ClientLimiter {
limiter: GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>,
tier: RateLimitTier,
limits: TierLimits,
}
impl ClientLimiter {
fn new(tier: RateLimitTier, limits: TierLimits) -> Self {
        // Fall back to one request/minute if the configured limit is zero.
        let quota = Quota::per_minute(
            NonZeroU32::new(limits.requests_per_minute).unwrap_or(NonZeroU32::MIN),
        );
let limiter = GovernorRateLimiter::direct(quota);
Self {
limiter,
tier,
limits,
}
}
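    /// Try to consume one cell from the GCRA limiter; `false` means throttled.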
fn check(&self) -> bool {
self.limiter.check().is_ok()
}
    /// The configured per-minute ceiling. Live remaining counts come from the
    /// `QuotaTracker`; governor's direct limiter does not expose them without
    /// extra middleware.
    fn remaining(&self) -> u32 {
        self.limits.requests_per_minute
    }
    /// Start of the next wall-clock minute as a Unix timestamp. This is an
    /// approximation: governor's GCRA does not align to minute boundaries.
    fn reset_at(&self) -> u64 {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        ((now / 60) + 1) * 60
    }
}
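/// Layered limiter: `QuotaTracker` enforces minute/hour/day counters while a
/// per-client governor limiter smooths bursts inside the minute window.
///
/// A minimal usage sketch (assumes a populated `RateLimitConfig`):
///
/// ```ignore
/// let limiter = MultiTierRateLimiter::new(config);
/// let decision = limiter.check_rate_limit("client-1", RateLimitTier::Free).await;
/// assert!(decision.allowed);
/// ```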
pub struct MultiTierRateLimiter {
    // Shared handles: cloning the Arc must not reset a client's limiter state.
    limiters: Arc<RwLock<HashMap<String, Arc<ClientLimiter>>>>,
    quota_tracker: QuotaTracker,
    config: Arc<crate::config::RateLimitConfig>,
}
impl MultiTierRateLimiter {
pub fn new(config: crate::config::RateLimitConfig) -> Self {
Self {
limiters: Arc::new(RwLock::new(HashMap::new())),
quota_tracker: QuotaTracker::new(),
config: Arc::new(config),
}
}
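    /// Fetch the shared limiter for `key`, rebuilding it if the tier changed.
    /// A tier change intentionally discards the client's previous burst state.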
    async fn get_or_create_limiter(&self, key: &str, tier: RateLimitTier) -> Arc<ClientLimiter> {
        let mut limiters = self.limiters.write().await;
        if let Some(limiter) = limiters.get(key) {
            if limiter.tier == tier {
                // Hand back the shared handle so every caller consumes the
                // same governor state (a fresh clone would reset it).
                return Arc::clone(limiter);
            }
            // Tier changed: rebuild the limiter with the new quota below.
            limiters.remove(key);
        }
        let limits = self.config.get_limits(tier).clone();
        let limiter = Arc::new(ClientLimiter::new(tier, limits));
        limiters.insert(key.to_string(), Arc::clone(&limiter));
        limiter
    }
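    /// Check order: windowed quotas first, then the burst limiter. The quota
    /// increment is not rolled back if the burst check subsequently denies,
    /// so a burst-denied request still counts against the minute window.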
async fn check_all_limits(&self, key: &str, tier: RateLimitTier) -> RateLimitDecision {
let limits = self.config.get_limits(tier).clone();
if !self.quota_tracker.check_and_increment(key, &limits).await {
let reset_at = self.calculate_reset_time(key).await;
let retry_after = self.calculate_retry_after(reset_at);
return RateLimitDecision::deny(limits.requests_per_minute, reset_at, retry_after);
}
let limiter = self.get_or_create_limiter(key, tier).await;
if !limiter.check() {
let reset_at = limiter.reset_at();
let retry_after = self.calculate_retry_after(reset_at);
return RateLimitDecision::deny(limits.requests_per_minute, reset_at, retry_after);
}
let usage = self.quota_tracker.get_usage(key).await;
let remaining = limits
.requests_per_minute
.saturating_sub(usage.minute);
let reset_at = limiter.reset_at();
RateLimitDecision::allow(limits.requests_per_minute, remaining, reset_at)
}
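    /// Absolute Unix timestamp at which the key's minute window resets,
    /// assuming a full 60s window when the tracker has no entry for the key.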
async fn calculate_reset_time(&self, key: &str) -> u64 {
let reset_secs = self
.quota_tracker
.time_until_reset(key, super::types::Window::Minute)
.await
.unwrap_or(60);
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
now + reset_secs
}
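    /// Seconds a client should wait before retrying; zero once already reset.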
fn calculate_retry_after(&self, reset_at: u64) -> u32 {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
reset_at.saturating_sub(now) as u32
}
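    /// Evict expired quota windows. Nothing calls this automatically; one
    /// option is a background task (a sketch, assuming the limiter lives in
    /// an `Arc`):
    ///
    /// ```ignore
    /// let bg = Arc::clone(&limiter);
    /// tokio::spawn(async move {
    ///     let mut tick = tokio::time::interval(std::time::Duration::from_secs(60));
    ///     loop {
    ///         tick.tick().await;
    ///         bg.cleanup().await;
    ///     }
    /// });
    /// ```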
pub async fn cleanup(&self) {
self.quota_tracker.cleanup_expired().await;
}
}
#[async_trait]
impl RateLimiter for MultiTierRateLimiter {
async fn check_rate_limit(&self, key: &str, tier: RateLimitTier) -> RateLimitDecision {
self.check_all_limits(key, tier).await
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::RateLimitConfig;
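    // Fixture: small free-tier limits (10/min) keep the deny-path tests fast.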
fn test_config() -> RateLimitConfig {
RateLimitConfig {
enabled: true,
default_tier: RateLimitTier::Free,
free: TierLimits {
requests_per_minute: 10,
requests_per_hour: 100,
requests_per_day: 1000,
max_concurrent: 5,
},
pro: TierLimits {
requests_per_minute: 100,
requests_per_hour: 1000,
requests_per_day: 10000,
max_concurrent: 50,
},
enterprise: TierLimits {
requests_per_minute: 1000,
requests_per_hour: 10000,
requests_per_day: 100000,
max_concurrent: 200,
},
}
}
#[tokio::test]
async fn test_rate_limiter_new() {
let config = test_config();
        let limiter = MultiTierRateLimiter::new(config);
        // A freshly constructed limiter should serve its first request.
        let decision = limiter.check_rate_limit("user1", RateLimitTier::Free).await;
        assert!(decision.allowed);
}
#[tokio::test]
async fn test_rate_limiter_allows_first_request() {
let config = test_config();
let limiter = MultiTierRateLimiter::new(config);
let decision = limiter
.check_rate_limit("user1", RateLimitTier::Free)
.await;
assert!(decision.allowed);
assert_eq!(decision.limit, 10);
assert!(decision.remaining > 0);
assert!(decision.retry_after.is_none());
}
#[tokio::test]
async fn test_rate_limiter_allows_within_limit() {
let config = test_config();
let limiter = MultiTierRateLimiter::new(config);
for i in 0..9 {
let decision = limiter
.check_rate_limit("user1", RateLimitTier::Free)
.await;
assert!(
decision.allowed,
"Request {} should be allowed",
i + 1
);
}
}
#[tokio::test]
async fn test_rate_limiter_denies_over_limit() {
let config = test_config();
let limiter = MultiTierRateLimiter::new(config);
for _ in 0..10 {
let decision = limiter
.check_rate_limit("user1", RateLimitTier::Free)
.await;
assert!(decision.allowed);
}
let decision = limiter
.check_rate_limit("user1", RateLimitTier::Free)
.await;
assert!(!decision.allowed);
assert_eq!(decision.remaining, 0);
assert!(decision.retry_after.is_some());
assert!(decision.retry_after.unwrap() > 0);
}
#[tokio::test]
async fn test_rate_limiter_different_tiers() {
let config = test_config();
let limiter = MultiTierRateLimiter::new(config);
for _ in 0..10 {
let decision = limiter
.check_rate_limit("free_user", RateLimitTier::Free)
.await;
assert!(decision.allowed);
}
let decision = limiter
.check_rate_limit("free_user", RateLimitTier::Free)
.await;
assert!(!decision.allowed);
for i in 0..50 {
let decision = limiter
.check_rate_limit("pro_user", RateLimitTier::Pro)
.await;
assert!(decision.allowed, "Pro request {} should succeed", i);
}
}
#[tokio::test]
async fn test_rate_limiter_separate_clients() {
let config = test_config();
let limiter = MultiTierRateLimiter::new(config);
for _ in 0..10 {
limiter
.check_rate_limit("user1", RateLimitTier::Free)
.await;
}
let decision1 = limiter
.check_rate_limit("user1", RateLimitTier::Free)
.await;
assert!(!decision1.allowed);
let decision2 = limiter
.check_rate_limit("user2", RateLimitTier::Free)
.await;
assert!(decision2.allowed);
}
#[tokio::test]
async fn test_rate_limit_decision_invariants() {
let config = test_config();
let limiter = MultiTierRateLimiter::new(config);
let decision = limiter
.check_rate_limit("user1", RateLimitTier::Free)
.await;
assert!(decision.validate_invariants().is_ok());
for _ in 0..10 {
limiter
.check_rate_limit("user2", RateLimitTier::Free)
.await;
}
let deny_decision = limiter
.check_rate_limit("user2", RateLimitTier::Free)
.await;
assert!(deny_decision.validate_invariants().is_ok());
}
#[tokio::test]
async fn test_client_limiter_creation() {
let limits = TierLimits {
requests_per_minute: 100,
requests_per_hour: 1000,
requests_per_day: 10000,
max_concurrent: 50,
};
let limiter = ClientLimiter::new(RateLimitTier::Pro, limits);
assert_eq!(limiter.tier, RateLimitTier::Pro);
        assert!(limiter.check());
    }
#[tokio::test]
async fn test_tier_change_creates_new_limiter() {
let config = test_config();
let limiter = MultiTierRateLimiter::new(config);
limiter
.check_rate_limit("user1", RateLimitTier::Free)
.await;
let decision = limiter
.check_rate_limit("user1", RateLimitTier::Pro)
.await;
assert!(decision.allowed);
        assert_eq!(decision.limit, 100);
    }
}