use super::types::{QuotaUsage, Window};
use crate::config::rate_limit::TierLimits;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::SystemTime;
use tokio::sync::RwLock;
/// Request counter for a single rate-limit window, paired with the instant at
/// which that window rolls over.
#[derive(Debug, Clone)]
struct WindowState {
    count: u32,
    reset_at: SystemTime,
}

impl WindowState {
    /// Builds a zeroed counter whose window ends at `reset_at`.
    fn new(reset_at: SystemTime) -> Self {
        WindowState { reset_at, count: 0 }
    }

    /// `true` once `now` has reached or passed the reset instant.
    fn is_expired(&self, now: SystemTime) -> bool {
        self.reset_at <= now
    }

    /// Starts a fresh window: the count goes back to zero and the new reset
    /// instant is adopted.
    fn reset(&mut self, reset_at: SystemTime) {
        *self = WindowState::new(reset_at);
    }
}
/// Request counters for one client across every tracked window.
#[derive(Debug, Clone)]
struct ClientQuota {
    minute: WindowState,
    hour: WindowState,
    day: WindowState,
    month: WindowState,
}

impl ClientQuota {
    /// Creates a quota with all counters at zero; each window's reset instant
    /// is `Window::<kind>.next_reset(now)`.
    fn new(now: SystemTime) -> Self {
        Self {
            minute: WindowState::new(Window::Minute.next_reset(now)),
            hour: WindowState::new(Window::Hour.next_reset(now)),
            day: WindowState::new(Window::Day.next_reset(now)),
            month: WindowState::new(Window::Month.next_reset(now)),
        }
    }

    /// Resets every window whose reset instant has been reached at `now`.
    ///
    /// A reset window gets count 0 and a new reset instant of
    /// `Window::<kind>.next_reset(now)`.
    fn update(&mut self, now: SystemTime) {
        if self.minute.is_expired(now) {
            self.minute.reset(Window::Minute.next_reset(now));
        }
        if self.hour.is_expired(now) {
            self.hour.reset(Window::Hour.next_reset(now));
        }
        if self.day.is_expired(now) {
            self.day.reset(Window::Day.next_reset(now));
        }
        if self.month.is_expired(now) {
            self.month.reset(Window::Month.next_reset(now));
        }
    }

    /// Records one request in every window.
    ///
    /// Uses saturating arithmetic. `exceeds` caps only the minute, hour and day
    /// counters before an increment, so the month counter has no upper bound of
    /// its own. With a large daily limit it could otherwise overflow `u32`,
    /// panicking in debug builds or silently wrapping to 0 in release.
    fn increment(&mut self) {
        self.minute.count = self.minute.count.saturating_add(1);
        self.hour.count = self.hour.count.saturating_add(1);
        self.day.count = self.day.count.saturating_add(1);
        self.month.count = self.month.count.saturating_add(1);
    }

    /// Snapshot of the raw counters.
    ///
    /// Windows are not refreshed here; call `update` first for current values.
    fn usage(&self) -> QuotaUsage {
        QuotaUsage {
            minute: self.minute.count,
            hour: self.hour.count,
            day: self.day.count,
            month: self.month.count,
        }
    }

    /// `true` if the minute, hour or day counter has already reached its limit
    /// in `limits`.
    ///
    /// The month window is counted but not enforced. The `TierLimits` literals
    /// in this file's tests show no monthly field — NOTE(review): confirm
    /// before adding one.
    fn exceeds(&self, limits: &TierLimits) -> bool {
        self.minute.count >= limits.requests_per_minute
            || self.hour.count >= limits.requests_per_hour
            || self.day.count >= limits.requests_per_day
    }
}
/// Registry of per-client request quotas, keyed by a client string.
///
/// The map lives behind `Arc<RwLock<..>>` (a `tokio` lock), so handles that
/// share the same `Arc` observe the same counters.
pub struct QuotaTracker {
    quotas: Arc<RwLock<HashMap<String, ClientQuota>>>,
}

impl QuotaTracker {
    /// Creates a tracker with no clients recorded.
    pub fn new() -> Self {
        Self {
            quotas: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Checks `key` against `limits` and, if allowed, records one request.
    ///
    /// The client's entry is created on first use, and expired windows are
    /// reset before the check. If `ClientQuota::exceeds` reports a limit
    /// already reached, this returns `false` and the request is not counted.
    /// Otherwise every window is incremented and `true` is returned.
    ///
    /// The check and the increment both happen under one write-lock guard, so
    /// two concurrent callers cannot both pass the same last free slot.
    pub async fn check_and_increment(&self, key: &str, limits: &TierLimits) -> bool {
        let now = SystemTime::now();
        let mut quotas = self.quotas.write().await;
        let quota = quotas
            .entry(key.to_string())
            .or_insert_with(|| ClientQuota::new(now));
        quota.update(now);
        if quota.exceeds(limits) {
            return false;
        }
        quota.increment();
        true
    }

    /// Returns the current usage for `key`.
    ///
    /// Windows whose reset instant has passed are reported as 0, and an
    /// unknown key reports all-zero usage. This is a pure query: it takes only
    /// a read lock and never inserts an entry. Probing arbitrary keys therefore
    /// cannot grow the map or inflate `client_count`.
    pub async fn get_usage(&self, key: &str) -> QuotaUsage {
        let now = SystemTime::now();
        // Clone the entry so the read guard is released at the end of this
        // statement. Window resets are applied to the copy, not to shared state.
        let mut snapshot = self
            .quotas
            .read()
            .await
            .get(key)
            .cloned()
            .unwrap_or_else(|| ClientQuota::new(now));
        snapshot.update(now);
        snapshot.usage()
    }

    /// Whole seconds until `window` resets for `key`, or `None` if `key` is not
    /// tracked.
    ///
    /// The remaining time is rounded up. A window that is still active
    /// therefore never reports 0; truncating would tell a client to retry
    /// while it is still limited. A window whose reset instant has already
    /// passed reports 0.
    pub async fn time_until_reset(&self, key: &str, window: Window) -> Option<u64> {
        let now = SystemTime::now();
        let quotas = self.quotas.read().await;
        let quota = quotas.get(key)?;
        let reset_at = match window {
            Window::Minute => quota.minute.reset_at,
            Window::Hour => quota.hour.reset_at,
            Window::Day => quota.day.reset_at,
            Window::Month => quota.month.reset_at,
        };
        // `duration_since` fails when `reset_at` is in the past; treat that as
        // "already reset".
        let remaining = reset_at.duration_since(now).unwrap_or_default();
        let whole = remaining.as_secs();
        Some(if remaining.subsec_nanos() > 0 {
            whole.saturating_add(1)
        } else {
            whole
        })
    }

    /// Removes every client whose minute, hour, day and month windows have all
    /// expired.
    pub async fn cleanup_expired(&self) {
        let now = SystemTime::now();
        let mut quotas = self.quotas.write().await;
        quotas.retain(|_, quota| {
            !(quota.minute.is_expired(now)
                && quota.hour.is_expired(now)
                && quota.day.is_expired(now)
                && quota.month.is_expired(now))
        });
    }

    /// Number of clients currently held in the map.
    pub async fn client_count(&self) -> usize {
        self.quotas.read().await.len()
    }
}
impl Default for QuotaTracker {
fn default() -> Self {
Self::new()
}
}
impl Clone for QuotaTracker {
fn clone(&self) -> Self {
Self {
quotas: Arc::clone(&self.quotas),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline tier shared by most tests: 10/min, 100/hour, 1000/day,
    /// 5 concurrent.
    fn standard_limits() -> TierLimits {
        TierLimits {
            requests_per_minute: 10,
            requests_per_hour: 100,
            requests_per_day: 1000,
            max_concurrent: 5,
        }
    }

    #[tokio::test]
    async fn test_quota_tracker_new() {
        let fresh = QuotaTracker::new();
        assert_eq!(fresh.client_count().await, 0);
    }

    #[tokio::test]
    async fn test_check_and_increment_first_request() {
        let tracker = QuotaTracker::new();
        let tier = standard_limits();
        assert!(tracker.check_and_increment("user1", &tier).await);
        assert_eq!(tracker.client_count().await, 1);
        let usage = tracker.get_usage("user1").await;
        assert_eq!(
            (usage.minute, usage.hour, usage.day, usage.month),
            (1, 1, 1, 1)
        );
    }

    #[tokio::test]
    async fn test_check_and_increment_within_limit() {
        let tracker = QuotaTracker::new();
        let tier = standard_limits();
        for _ in 0..9 {
            assert!(tracker.check_and_increment("user1", &tier).await);
        }
        assert_eq!(tracker.get_usage("user1").await.minute, 9);
    }

    #[tokio::test]
    async fn test_check_and_increment_exceeds_limit() {
        let tracker = QuotaTracker::new();
        let tier = standard_limits();
        for attempt in 0..10 {
            let allowed = tracker.check_and_increment("user1", &tier).await;
            assert!(allowed, "Request {} should succeed", attempt);
        }
        let over_limit = tracker.check_and_increment("user1", &tier).await;
        assert!(!over_limit);
        assert_eq!(tracker.get_usage("user1").await.minute, 10);
    }

    #[tokio::test]
    async fn test_different_clients_have_separate_quotas() {
        let tracker = QuotaTracker::new();
        let tier = TierLimits {
            requests_per_minute: 5,
            requests_per_hour: 50,
            requests_per_day: 500,
            max_concurrent: 3,
        };
        for _ in 0..5 {
            assert!(tracker.check_and_increment("user1", &tier).await);
        }
        assert!(tracker.check_and_increment("user2", &tier).await);
        let first = tracker.get_usage("user1").await;
        let second = tracker.get_usage("user2").await;
        assert_eq!((first.minute, second.minute), (5, 1));
    }

    #[tokio::test]
    async fn test_hour_limit_enforcement() {
        let tracker = QuotaTracker::new();
        let tier = TierLimits {
            requests_per_minute: 100,
            requests_per_hour: 10,
            ..standard_limits()
        };
        for _ in 0..10 {
            assert!(tracker.check_and_increment("user1", &tier).await);
        }
        assert!(!tracker.check_and_increment("user1", &tier).await);
    }

    #[tokio::test]
    async fn test_day_limit_enforcement() {
        let tracker = QuotaTracker::new();
        let tier = TierLimits {
            requests_per_minute: 1000,
            requests_per_hour: 10000,
            requests_per_day: 5,
            max_concurrent: 10,
        };
        for _ in 0..5 {
            assert!(tracker.check_and_increment("user1", &tier).await);
        }
        assert!(!tracker.check_and_increment("user1", &tier).await);
    }

    #[tokio::test]
    async fn test_time_until_reset() {
        let tracker = QuotaTracker::new();
        let tier = standard_limits();
        assert!(tracker.check_and_increment("user1", &tier).await);
        let remaining = tracker.time_until_reset("user1", Window::Minute).await;
        assert!(remaining.is_some());
        let seconds = remaining.unwrap();
        assert!((1..=60).contains(&seconds));
    }

    #[tokio::test]
    async fn test_cleanup_expired() {
        let tracker = QuotaTracker::new();
        let tier = standard_limits();
        for client in ["user1", "user2"] {
            tracker.check_and_increment(client, &tier).await;
        }
        assert_eq!(tracker.client_count().await, 2);
        tracker.cleanup_expired().await;
        assert_eq!(tracker.client_count().await, 2);
    }

    #[tokio::test]
    async fn test_quota_tracker_clone() {
        let original = QuotaTracker::new();
        let handle = original.clone();
        let tier = standard_limits();
        original.check_and_increment("user1", &tier).await;
        assert_eq!(handle.get_usage("user1").await.minute, 1);
    }
}