use super::types::{RateLimitEntry, RateLimitResult};
use crate::config::models::rate_limit::{RateLimitConfig, RateLimitStrategy};
use dashmap::DashMap;
use std::sync::Arc;
use std::time::Duration;
/// Per-key request rate limiter whose counting strategy is selected by
/// `RateLimitConfig::strategy`.
///
/// Per-key state lives in a concurrent `DashMap` keyed by caller-supplied
/// strings. Clones of a limiter share that map through the `Arc`.
pub struct RateLimiter {
    /// Limiter settings: the enable flag, the selected strategy and the
    /// default limit reported when limiting is disabled.
    pub(super) config: RateLimitConfig,
    /// Per-key limiter state, shared across clones of this limiter.
    pub(super) entries: Arc<DashMap<String, RateLimitEntry>>,
    /// Window length. `new` uses 60 seconds; `with_window` takes it explicitly.
    /// NOTE(review): presumably consumed by the strategy impls defined in
    /// sibling modules; confirm there.
    pub(super) window: Duration,
}
impl RateLimiter {
    /// Creates a rate limiter from `config` with a 60-second window.
    ///
    /// Equivalent to `RateLimiter::with_window(config, Duration::from_secs(60))`.
    pub fn new(config: RateLimitConfig) -> Self {
        Self::with_window(config, Duration::from_secs(60))
    }

    /// Creates a rate limiter from `config` with an explicit `window` length.
    ///
    /// The per-key entry map starts empty. Clones of the returned limiter
    /// share that map.
    pub fn with_window(config: RateLimitConfig, window: Duration) -> Self {
        Self {
            config,
            entries: Arc::new(DashMap::new()),
            window,
        }
    }

    /// Evaluates `key` against the configured strategy without recording a request.
    ///
    /// When limiting is disabled (`config.enabled == false`), the request is
    /// always allowed. In that case `limit` and `remaining` both report
    /// `config.default_rpm`, `current_count` is 0, and no retry hint is given.
    pub async fn check(&self, key: &str) -> RateLimitResult {
        self.evaluate_strategy(key, false).await
    }

    /// Evaluates `key` and records the request in the same strategy call.
    ///
    /// Prefer this over calling [`check`](Self::check) and then the deprecated
    /// [`record`](Self::record). That two-step pattern lets concurrent callers
    /// interleave between the check and the record.
    ///
    /// The disabled-limiter behavior is the same as for [`check`](Self::check).
    pub async fn check_and_record(&self, key: &str) -> RateLimitResult {
        self.evaluate_strategy(key, true).await
    }

    /// Records one request for `key` without evaluating the limit.
    ///
    /// For the sliding-window and fixed-window strategies, this appends the
    /// current instant to the key's timestamp list, creating the entry if it
    /// does not exist. It does nothing when limiting is disabled.
    #[deprecated(note = "Use check_and_record() instead to avoid race conditions")]
    pub async fn record(&self, key: &str) {
        if !self.config.enabled {
            return;
        }
        match self.config.strategy {
            RateLimitStrategy::SlidingWindow | RateLimitStrategy::FixedWindow => {
                // The map guard is a temporary, so it is released at the end
                // of this statement.
                self.entries
                    .entry(key.to_string())
                    .or_default()
                    .timestamps
                    .push(std::time::Instant::now());
            }
            RateLimitStrategy::TokenBucket => {
                // Token-bucket consumption only happens inside
                // `check_and_record`. This path deliberately leaves the map
                // untouched rather than inserting a `Default` entry. That
                // avoids pre-seeding bucket state for the key.
                // NOTE(review): check_token_bucket_impl initialisation was not
                // visible here; confirm absent keys are handled there.
            }
        }
    }

    /// Shared body of [`check`](Self::check) and
    /// [`check_and_record`](Self::check_and_record).
    ///
    /// `record` is forwarded to the strategy implementation unchanged.
    async fn evaluate_strategy(&self, key: &str, record: bool) -> RateLimitResult {
        if !self.config.enabled {
            return self.allow_all_result();
        }
        match self.config.strategy {
            RateLimitStrategy::SlidingWindow => self.check_sliding_window_impl(key, record).await,
            RateLimitStrategy::TokenBucket => self.check_token_bucket_impl(key, record).await,
            RateLimitStrategy::FixedWindow => self.check_fixed_window_impl(key, record).await,
        }
    }

    /// Result returned when limiting is disabled: always allowed and never throttled.
    fn allow_all_result(&self) -> RateLimitResult {
        RateLimitResult {
            allowed: true,
            current_count: 0,
            limit: self.config.default_rpm,
            remaining: self.config.default_rpm,
            reset_after_secs: 0,
            retry_after_secs: None,
        }
    }
}
impl Clone for RateLimiter {
fn clone(&self) -> Self {
Self {
config: self.config.clone(),
entries: self.entries.clone(),
window: self.window,
}
}
}