use crate::config::RateLimitConfig;
use dashmap::DashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Outcome of a single rate-limit check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RateLimitResult {
    /// The request was admitted.
    Allowed {
        /// Requests still permitted in the current per-minute window.
        remaining_minute: u32,
        /// Requests still permitted in the current per-hour window.
        remaining_hour: u32,
    },
    /// The request was rejected because a window's quota is used up.
    Exceeded {
        /// Seconds the caller should wait before retrying.
        retry_after: u64,
        /// Which window rejected the request (this file uses `"minute"` or `"hour"`).
        limit_type: &'static str,
    },
}

impl RateLimitResult {
    /// Returns `true` if the request was admitted.
    pub fn is_allowed(&self) -> bool {
        match self {
            Self::Allowed { .. } => true,
            Self::Exceeded { .. } => false,
        }
    }

    /// Returns `true` if the request was rejected.
    pub fn is_exceeded(&self) -> bool {
        match self {
            Self::Exceeded { .. } => true,
            Self::Allowed { .. } => false,
        }
    }

    /// Seconds to wait before retrying, or `None` when the request was admitted.
    pub fn retry_after(&self) -> Option<u64> {
        match self {
            Self::Allowed { .. } => None,
            Self::Exceeded { retry_after: wait, .. } => Some(*wait),
        }
    }
}
/// Fixed-window request counters for a single rate-limited key.
///
/// Two independent windows are tracked: a 60-second one and a 3600-second one.
/// Windows are fixed rather than sliding. Once a window has ended, the next call to
/// [`check_and_increment`](Self::check_and_increment) zeroes its counter and starts
/// a new window at that moment.
#[derive(Debug)]
pub struct KeyRateLimit {
    /// Requests admitted in the current 60-second window.
    minute_count: u32,
    /// Requests admitted in the current 3600-second window.
    hour_count: u32,
    /// Instant at which the current 60-second window ends.
    minute_reset: Instant,
    /// Instant at which the current 3600-second window ends.
    hour_reset: Instant,
}

impl KeyRateLimit {
    /// Length of the short ("minute") window.
    const MINUTE_WINDOW: Duration = Duration::from_secs(60);
    /// Length of the long ("hour") window.
    const HOUR_WINDOW: Duration = Duration::from_secs(3600);

    /// Creates zeroed counters whose windows both start now.
    pub fn new() -> Self {
        let now = Instant::now();
        Self {
            minute_count: 0,
            hour_count: 0,
            minute_reset: now + Self::MINUTE_WINDOW,
            hour_reset: now + Self::HOUR_WINDOW,
        }
    }

    /// Attempts to admit one request under `config`'s limits.
    ///
    /// Windows whose end has passed are reset first. If either window's count has
    /// already reached its limit, the request is rejected and **not** counted.
    /// Otherwise both counters are incremented and the remaining quotas are returned.
    ///
    /// On rejection, `retry_after` is the number of seconds until the blocking window
    /// resets, rounded up and at least 1. If both windows are exhausted, the one that
    /// resets later is reported, because the request cannot succeed before both have
    /// reset.
    pub fn check_and_increment(&mut self, config: &RateLimitConfig) -> RateLimitResult {
        let now = Instant::now();
        if now >= self.minute_reset {
            self.minute_count = 0;
            self.minute_reset = now + Self::MINUTE_WINDOW;
        }
        if now >= self.hour_reset {
            self.hour_count = 0;
            self.hour_reset = now + Self::HOUR_WINDOW;
        }

        let minute_block = (self.minute_count >= config.requests_per_minute)
            .then_some((self.minute_reset, "minute"));
        let hour_block =
            (self.hour_count >= config.requests_per_hour).then_some((self.hour_reset, "hour"));
        let blocking = match (minute_block, hour_block) {
            // Both exhausted: waiting only for the earlier reset would just earn
            // another rejection, so report the later one.
            (Some(minute), Some(hour)) => Some(if minute.0 > hour.0 { minute } else { hour }),
            (minute, hour) => minute.or(hour),
        };
        if let Some((reset, limit_type)) = blocking {
            return RateLimitResult::Exceeded {
                retry_after: Self::secs_until(reset, now),
                limit_type,
            };
        }

        // Cannot overflow or underflow: each count was strictly below its limit
        // before these increments.
        self.minute_count += 1;
        self.hour_count += 1;
        RateLimitResult::Allowed {
            remaining_minute: config.requests_per_minute - self.minute_count,
            remaining_hour: config.requests_per_hour - self.hour_count,
        }
    }

    /// Whole seconds from `now` until `deadline`, rounded up and never below 1.
    ///
    /// The value is rounded up so that a client waiting exactly this long is past
    /// the reset; truncating would under-report by up to a second.
    fn secs_until(deadline: Instant, now: Instant) -> u64 {
        let remaining = deadline.saturating_duration_since(now);
        let secs = remaining.as_secs() + u64::from(remaining.subsec_nanos() > 0);
        secs.max(1)
    }

    /// Returns `true` once *both* windows have ended.
    ///
    /// Only then is discarding this entry equivalent to keeping it. The minute window
    /// can outlast the hour window when it started shortly before the hour ended.
    /// Dropping the entry in that case would hand the key a fresh per-minute quota.
    pub fn is_expired(&self) -> bool {
        let now = Instant::now();
        now >= self.minute_reset && now >= self.hour_reset
    }

    /// Requests admitted in the current 60-second window.
    pub fn minute_count(&self) -> u32 {
        self.minute_count
    }

    /// Requests admitted in the current 3600-second window.
    pub fn hour_count(&self) -> u32 {
        self.hour_count
    }
}

/// Equivalent to [`KeyRateLimit::new`].
impl Default for KeyRateLimit {
    fn default() -> Self {
        Self::new()
    }
}
/// Concurrent map from a key to its [`KeyRateLimit`] counters.
///
/// Cloning is cheap and yields a handle to the *same* map, because the map lives
/// behind an [`Arc`]. Clones therefore share counters.
#[derive(Clone)]
pub struct RateLimitStore {
    state: Arc<DashMap<String, KeyRateLimit>>,
}

impl RateLimitStore {
    /// Initial capacity used by `new`.
    const DEFAULT_CAPACITY: usize = 1000;

    /// Creates an empty store pre-sized for 1000 keys.
    pub fn new() -> Self {
        Self::with_capacity(Self::DEFAULT_CAPACITY)
    }

    /// Creates an empty store pre-sized for `capacity` keys.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            state: Arc::new(DashMap::with_capacity(capacity)),
        }
    }

    /// Records one request for `key` under `config` and returns the verdict.
    ///
    /// State for a key that has not been seen, or that has been removed, is created
    /// on demand.
    pub fn check(&self, key: &str, config: &RateLimitConfig) -> RateLimitResult {
        // Fast path: existing keys are updated in place without allocating an
        // owned `String` just to perform the lookup.
        if let Some(mut limit) = self.state.get_mut(key) {
            return limit.check_and_increment(config);
        }
        // Slow path: if another thread inserted this key after the lookup above,
        // `entry` yields that existing state instead of overwriting it.
        self.state
            .entry(key.to_owned())
            .or_default()
            .check_and_increment(config)
    }

    /// Removes every key whose counters have fully expired, as defined by
    /// [`KeyRateLimit::is_expired`].
    pub fn cleanup_expired(&self) {
        self.state.retain(|_, limit| !limit.is_expired());
    }

    /// Number of keys currently tracked.
    pub fn len(&self) -> usize {
        self.state.len()
    }

    /// Returns `true` if no keys are tracked.
    pub fn is_empty(&self) -> bool {
        self.state.is_empty()
    }

    /// Forgets the state for `key`, if any.
    pub fn remove(&self, key: &str) {
        self.state.remove(key);
    }

    /// Forgets the state for every key.
    pub fn clear(&self) {
        self.state.clear();
    }

    /// Spawns a background task that calls [`cleanup_expired`](Self::cleanup_expired)
    /// once per `interval`.
    ///
    /// The task loops forever and keeps this store alive through the `Arc`. Use the
    /// returned handle's `abort` to stop it.
    ///
    /// # Panics
    ///
    /// `tokio::time::interval` panics if `interval` is zero. `tokio::spawn` panics if
    /// called outside a Tokio runtime.
    #[cfg(feature = "cleanup-task")]
    pub fn spawn_cleanup_task(
        self: Arc<Self>,
        interval: std::time::Duration,
    ) -> tokio::task::JoinHandle<()> {
        tokio::spawn(async move {
            let mut ticker = tokio::time::interval(interval);
            // The first tick completes immediately. Consume it so the first sweep
            // happens one full `interval` after spawning.
            ticker.tick().await;
            loop {
                ticker.tick().await;
                let before = self.len();
                self.cleanup_expired();
                let after = self.len();
                // Other tasks may insert keys while the sweep runs, so `after` can
                // exceed `before`. Saturate rather than underflow, which would panic
                // in debug builds and kill this task. The count is approximate under
                // concurrent inserts.
                let removed = before.saturating_sub(after);
                if removed > 0 {
                    tracing::debug!(
                        before = before,
                        after = after,
                        removed = removed,
                        "Rate limiter cleanup completed"
                    );
                }
            }
        })
    }
}

/// Equivalent to [`RateLimitStore::new`].
impl Default for RateLimitStore {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created limiter has counted nothing yet.
    #[test]
    fn test_key_rate_limit_new() {
        let fresh = KeyRateLimit::new();
        assert_eq!((fresh.minute_count, fresh.hour_count), (0, 0));
    }

    /// A single request under the free plan is admitted and counted in both windows.
    #[test]
    fn test_check_and_increment_allowed() {
        let plan = RateLimitConfig::for_plan("free");
        let mut limiter = KeyRateLimit::new();
        let outcome = limiter.check_and_increment(&plan);
        assert!(outcome.is_allowed());
        assert_eq!((limiter.minute_count, limiter.hour_count), (1, 1));
    }

    /// Exhausting the per-minute quota yields a rejection with a positive wait.
    /// Assumes the free plan allows 20 requests per minute.
    #[test]
    fn test_check_and_increment_exceeded_minute() {
        let plan = RateLimitConfig::for_plan("free");
        let mut limiter = KeyRateLimit::new();
        for attempt in 1..=20 {
            assert!(
                limiter.check_and_increment(&plan).is_allowed(),
                "request {attempt} should have been allowed"
            );
        }
        let rejected = limiter.check_and_increment(&plan);
        assert!(rejected.is_exceeded());
        assert!(matches!(rejected.retry_after(), Some(secs) if secs > 0));
    }

    /// The predicate and accessor helpers agree with the variant.
    #[test]
    fn test_rate_limit_result_methods() {
        let allowed = RateLimitResult::Allowed {
            remaining_minute: 10,
            remaining_hour: 100,
        };
        assert!(allowed.is_allowed() && !allowed.is_exceeded());
        assert_eq!(allowed.retry_after(), None);

        let exceeded = RateLimitResult::Exceeded {
            retry_after: 30,
            limit_type: "minute",
        };
        assert!(exceeded.is_exceeded() && !exceeded.is_allowed());
        assert_eq!(exceeded.retry_after(), Some(30));
    }

    /// Each distinct key gets its own entry in the store.
    #[test]
    fn test_store_basic() {
        let store = RateLimitStore::new();
        let plan = RateLimitConfig::for_plan("free");
        for (expected_len, user) in (1usize..).zip(["user1", "user2"]) {
            assert!(store.check(user, &plan).is_allowed());
            assert_eq!(store.len(), expected_len);
        }
    }

    /// Removing a key drops its entry.
    #[test]
    fn test_store_remove() {
        let plan = RateLimitConfig::for_plan("free");
        let store = RateLimitStore::new();
        let _ = store.check("user1", &plan);
        assert_eq!(store.len(), 1);
        store.remove("user1");
        assert_eq!(store.len(), 0);
    }

    /// Clearing the store drops every entry.
    #[test]
    fn test_store_clear() {
        let plan = RateLimitConfig::for_plan("free");
        let store = RateLimitStore::new();
        for user in ["user1", "user2"] {
            let _ = store.check(user, &plan);
        }
        assert_eq!(store.len(), 2);
        store.clear();
        assert!(store.is_empty());
    }
}