use crate::config::RateLimitConfig;
use crate::error::{GuardError, Result};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
#[cfg(feature = "rate-limit")]
use governor::{Quota, RateLimiter as GovernorLimiter};
#[cfg(feature = "rate-limit")]
use std::num::NonZeroU32;
/// Per-user request rate limiter configured by a [`RateLimitConfig`].
///
/// With the `rate-limit` feature enabled, each user id is given its own governor
/// limiter, created lazily and kept in a shared map. Without the feature, the
/// limiter admits every request.
pub struct RateLimiter {
    // Only read by the `rate-limit` implementations; the allow keeps feature-less
    // builds free of a "field is never read" warning.
    #[cfg_attr(not(feature = "rate-limit"), allow(dead_code))]
    config: RateLimitConfig,
    /// One direct (un-keyed) limiter per user id, shared behind an async RwLock so
    /// lookups for already-tracked users can proceed concurrently.
    #[cfg(feature = "rate-limit")]
    limiters: Arc<
        RwLock<
            HashMap<
                String,
                Arc<
                    GovernorLimiter<
                        governor::state::NotKeyed,
                        governor::state::InMemoryState,
                        governor::clock::DefaultClock,
                    >,
                >,
            >,
        >,
    >,
}
/// Limiter type held for each user: a direct (un-keyed), in-memory governor
/// limiter on the default clock.
#[cfg(feature = "rate-limit")]
type DirectLimiter = GovernorLimiter<
    governor::state::NotKeyed,
    governor::state::InMemoryState,
    governor::clock::DefaultClock,
>;

impl RateLimiter {
    /// Number of tracked users above which [`RateLimiter::cleanup`] discards
    /// every per-user limiter.
    #[cfg(feature = "rate-limit")]
    const MAX_TRACKED_USERS: usize = 10_000;

    /// Creates a limiter for `config` with no per-user state yet.
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            config,
            #[cfg(feature = "rate-limit")]
            limiters: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Admits or rejects one request from `user_id`.
    ///
    /// Always succeeds when `config.enabled` is false. Otherwise one cell is
    /// consumed from the user's limiter, which is created on first use.
    ///
    /// # Errors
    ///
    /// Returns [`GuardError::RateLimitExceeded`] when the user's quota is exhausted.
    #[cfg(feature = "rate-limit")]
    pub async fn check(&self, user_id: &str) -> Result<()> {
        if !self.config.enabled {
            return Ok(());
        }
        let limiter = self.get_or_create_limiter(user_id).await;
        match limiter.check() {
            Ok(_) => Ok(()),
            Err(_) => Err(GuardError::RateLimitExceeded(format!(
                "Rate limit exceeded for user: {}. Limit: {} requests/minute",
                user_id, self.config.requests_per_minute
            ))),
        }
    }

    /// Feature-less build: every request is admitted.
    #[cfg(not(feature = "rate-limit"))]
    pub async fn check(&self, _user_id: &str) -> Result<()> {
        Ok(())
    }

    /// Builds the governor quota from the config.
    ///
    /// The quota API takes `NonZeroU32`, so a zero `requests_per_minute` falls
    /// back to 60 and a zero `burst_size` falls back to 10.
    #[cfg(feature = "rate-limit")]
    fn quota(&self) -> Quota {
        let per_minute = NonZeroU32::new(self.config.requests_per_minute)
            .unwrap_or(NonZeroU32::new(60).unwrap());
        let burst =
            NonZeroU32::new(self.config.burst_size).unwrap_or(NonZeroU32::new(10).unwrap());
        Quota::per_minute(per_minute).allow_burst(burst)
    }

    /// Returns the limiter already tracked for `user_id`, without creating one.
    #[cfg(feature = "rate-limit")]
    async fn existing_limiter(&self, user_id: &str) -> Option<Arc<DirectLimiter>> {
        // The read guard is a temporary and is released at the end of this statement.
        self.limiters.read().await.get(user_id).cloned()
    }

    /// Returns the limiter for `user_id`, inserting a fresh one if none exists.
    #[cfg(feature = "rate-limit")]
    async fn get_or_create_limiter(&self, user_id: &str) -> Arc<DirectLimiter> {
        // Fast path under the shared read lock for already-tracked users.
        if let Some(limiter) = self.existing_limiter(user_id).await {
            return limiter;
        }
        let mut limiters = self.limiters.write().await;
        // Another task may have inserted between releasing the read lock and taking
        // the write lock; `entry` reuses that limiter instead of replacing it.
        let limiter = limiters
            .entry(user_id.to_string())
            .or_insert_with(|| Arc::new(GovernorLimiter::direct(self.quota())));
        Arc::clone(limiter)
    }

    /// Bounds memory by dropping all per-user limiters once more than
    /// `MAX_TRACKED_USERS` are tracked.
    ///
    /// NOTE(review): clearing resets every user's budget, including users who are
    /// currently throttled, so a cleanup briefly lets them through again.
    #[cfg(feature = "rate-limit")]
    pub async fn cleanup(&self) {
        // Check under the shared lock first so routine calls don't block traffic.
        if self.limiters.read().await.len() <= Self::MAX_TRACKED_USERS {
            return;
        }
        let mut limiters = self.limiters.write().await;
        // Re-check: another task may have cleared the map while we waited.
        if limiters.len() > Self::MAX_TRACKED_USERS {
            limiters.clear();
        }
    }

    /// Feature-less build: there is no per-user state to clean up.
    #[cfg(not(feature = "rate-limit"))]
    pub async fn cleanup(&self) {}

    /// Status reported when no limiting applies: allowed, with the configured
    /// per-minute quota as `remaining`.
    #[cfg(feature = "rate-limit")]
    fn unrestricted_status(&self) -> RateLimitStatus {
        RateLimitStatus {
            allowed: true,
            remaining: self.config.requests_per_minute,
            reset_at: None,
        }
    }

    /// Reports whether `user_id` is currently allowed through.
    ///
    /// Users without a tracked limiter are reported as allowed, and no limiter is
    /// created for them. For tracked users the probe goes through the limiter.
    /// `remaining` is the configured per-minute quota when allowed and 0 when
    /// denied, not a live count. `reset_at` is a fixed 60 seconds when denied.
    #[cfg(feature = "rate-limit")]
    pub async fn status(&self, user_id: &str) -> RateLimitStatus {
        if !self.config.enabled {
            return self.unrestricted_status();
        }
        // An untracked user has no spent quota (their next check starts a fresh
        // limiter), so answer without allocating state or spending a cell.
        let limiter = match self.existing_limiter(user_id).await {
            Some(limiter) => limiter,
            None => return self.unrestricted_status(),
        };
        // NOTE(review): governor's direct limiter appears to expose no
        // non-consuming peek, so probing a tracked user still spends one cell.
        match limiter.check() {
            Ok(_) => self.unrestricted_status(),
            Err(_) => RateLimitStatus {
                allowed: false,
                remaining: 0,
                reset_at: Some(std::time::Duration::from_secs(60)),
            },
        }
    }

    /// Feature-less build: always allowed, with unbounded remaining capacity.
    #[cfg(not(feature = "rate-limit"))]
    pub async fn status(&self, _user_id: &str) -> RateLimitStatus {
        RateLimitStatus {
            allowed: true,
            remaining: u32::MAX,
            reset_at: None,
        }
    }
}
/// Snapshot of a user's rate-limit state, as returned by `RateLimiter::status`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitStatus {
    /// Whether the probed request was admitted.
    pub allowed: bool,
    /// Reported request allowance. This is an approximation rather than a live
    /// count; callers should not rely on it being exact.
    pub remaining: u32,
    /// Relative wait (a duration, not a point in time) before requests may be
    /// admitted again; `None` when not currently limited.
    pub reset_at: Option<std::time::Duration>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With limiting disabled, repeated checks for the same user always succeed.
    #[tokio::test]
    async fn test_rate_limit_disabled() {
        let config = RateLimitConfig {
            enabled: false,
            ..Default::default()
        };
        let limiter = RateLimiter::new(config);
        assert!(limiter.check("user1").await.is_ok());
        assert!(limiter.check("user1").await.is_ok());
    }

    /// With limiting disabled, `status` reports the user as allowed.
    #[tokio::test]
    async fn test_status_disabled_reports_allowed() {
        let config = RateLimitConfig {
            enabled: false,
            ..Default::default()
        };
        let limiter = RateLimiter::new(config);
        assert!(limiter.status("user1").await.allowed);
    }

    /// A burst of 2 admits two immediate requests and rejects the third.
    #[cfg(feature = "rate-limit")]
    #[tokio::test]
    async fn test_rate_limit_basic() {
        let config = RateLimitConfig {
            enabled: true,
            requests_per_minute: 2,
            burst_size: 2,
            ..Default::default()
        };
        let limiter = RateLimiter::new(config);
        assert!(limiter.check("user1").await.is_ok());
        assert!(limiter.check("user1").await.is_ok());
        assert!(limiter.check("user1").await.is_err());
    }

    /// Exhausting one user's quota must not affect another user, and rejection
    /// surfaces as `GuardError::RateLimitExceeded`.
    #[cfg(feature = "rate-limit")]
    #[tokio::test]
    async fn test_rate_limit_is_per_user() {
        let config = RateLimitConfig {
            enabled: true,
            requests_per_minute: 1,
            burst_size: 1,
            ..Default::default()
        };
        let limiter = RateLimiter::new(config);
        assert!(limiter.check("user1").await.is_ok());
        assert!(matches!(
            limiter.check("user1").await,
            Err(GuardError::RateLimitExceeded(_))
        ));
        assert!(limiter.check("user2").await.is_ok());
    }
}