use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Limits that a rate limiter applies to each key it tracks.
///
/// Request and token usage are both measured against [`window`](Self::window).
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum number of requests admitted within one window.
    pub max_requests: u32,
    /// Length of one accounting window.
    pub window: Duration,
    /// Optional token budget; `None` disables token-based limiting.
    ///
    /// Despite the name, the limiters in this file count tokens against
    /// `window`, so the budget is only "per minute" when `window` is 60 seconds.
    pub max_tokens_per_minute: Option<u32>,
}
impl Default for RateLimitConfig {
    /// Returns the baseline limits: 60 requests and a 100,000-token budget
    /// per 60-second window.
    fn default() -> Self {
        let window = Duration::from_secs(60);
        let max_tokens_per_minute = Some(100_000);
        Self {
            max_requests: 60,
            window,
            max_tokens_per_minute,
        }
    }
}
impl RateLimitConfig {
    /// Builds a configuration with a 60-second window and a token budget.
    fn with_minute_window(max_requests: u32, max_tokens: u32) -> Self {
        Self {
            max_requests,
            window: Duration::from_secs(60),
            max_tokens_per_minute: Some(max_tokens),
        }
    }

    /// Low-throughput preset: 10 requests and 10,000 tokens per 60-second
    /// window.
    pub fn conservative() -> Self {
        Self::with_minute_window(10, 10_000)
    }

    /// High-throughput preset: 500 requests and 1,000,000 tokens per
    /// 60-second window.
    pub fn aggressive() -> Self {
        Self::with_minute_window(500, 1_000_000)
    }
}
/// Usage counters for one key within its current rate-limit window.
#[derive(Debug, Clone)]
struct RequestWindow {
    /// Requests admitted since `window_start`.
    count: u32,
    /// Tokens charged since `window_start`.
    tokens: u32,
    /// Moment at which this window was opened.
    window_start: Instant,
}
impl Default for RequestWindow {
    /// Opens an empty window that starts at the moment of the call.
    fn default() -> Self {
        let window_start = Instant::now();
        Self {
            count: 0,
            tokens: 0,
            window_start,
        }
    }
}
/// Rate limiter that tracks request and token usage separately for each
/// provider name.
#[derive(Debug)]
pub struct RateLimiter {
    /// Limits applied to every provider tracked by this limiter.
    config: RateLimitConfig,
    /// Per-provider usage, keyed by provider name, behind an async
    /// read/write lock.
    windows: RwLock<HashMap<String, RequestWindow>>,
}
impl RateLimiter {
pub fn new(config: RateLimitConfig) -> Self {
Self {
config,
windows: RwLock::new(HashMap::new()),
}
}
pub async fn check(&self, provider: &str) -> RateLimitResult {
let windows = self.windows.read().await;
let window = windows.get(provider).cloned().unwrap_or_default();
self.evaluate(&window)
}
pub async fn acquire(&self, provider: &str) -> Result<(), RateLimitError> {
loop {
let result = self.try_acquire(provider).await;
match result {
Ok(()) => return Ok(()),
Err(RateLimitError::Limited { retry_after }) => {
tokio::time::sleep(retry_after).await;
}
}
}
}
pub async fn try_acquire(&self, provider: &str) -> Result<(), RateLimitError> {
self.try_acquire_with_tokens(provider, 0).await
}
pub async fn try_acquire_with_tokens(
&self,
provider: &str,
estimated_tokens: u32,
) -> Result<(), RateLimitError> {
let mut windows = self.windows.write().await;
let window = windows.entry(provider.to_string()).or_default();
let elapsed = window.window_start.elapsed();
if elapsed >= self.config.window {
*window = RequestWindow::default();
}
if window.count >= self.config.max_requests {
let retry_after = self.config.window - elapsed;
return Err(RateLimitError::Limited { retry_after });
}
if let Some(max_tokens) = self.config.max_tokens_per_minute {
if estimated_tokens > 0 && window.tokens + estimated_tokens > max_tokens {
let retry_after = self.config.window - elapsed;
return Err(RateLimitError::Limited { retry_after });
}
}
window.count += 1;
window.tokens += estimated_tokens;
Ok(())
}
pub async fn record_tokens(&self, provider: &str, additional_tokens: u32) {
let mut windows = self.windows.write().await;
if let Some(window) = windows.get_mut(provider) {
window.tokens += additional_tokens;
}
}
pub async fn stats(&self, provider: &str) -> RateLimitStats {
let windows = self.windows.read().await;
let window = windows.get(provider).cloned().unwrap_or_default();
RateLimitStats {
requests_used: window.count,
requests_limit: self.config.max_requests,
tokens_used: window.tokens,
tokens_limit: self.config.max_tokens_per_minute,
window_remaining: self
.config
.window
.saturating_sub(window.window_start.elapsed()),
}
}
fn evaluate(&self, window: &RequestWindow) -> RateLimitResult {
let elapsed = window.window_start.elapsed();
if elapsed >= self.config.window {
return RateLimitResult::Allowed;
}
if window.count >= self.config.max_requests {
let retry_after = self.config.window - elapsed;
return RateLimitResult::Limited { retry_after };
}
if let Some(max_tokens) = self.config.max_tokens_per_minute {
if window.tokens >= max_tokens {
let retry_after = self.config.window - elapsed;
return RateLimitResult::Limited { retry_after };
}
}
RateLimitResult::Allowed
}
}
/// Outcome of evaluating a usage window against the configured limits.
#[derive(Debug, Clone)]
pub enum RateLimitResult {
    /// The limits have not been reached.
    Allowed,
    /// A limit has been reached.
    Limited {
        /// Time left until the current window ends.
        retry_after: Duration,
    },
}
/// Error returned when a request cannot be admitted.
#[derive(Debug, thiserror::Error)]
pub enum RateLimitError {
    /// A limit for the current window has been reached.
    #[error("Rate limited, retry after {retry_after:?}")]
    Limited {
        /// Suggested wait before trying again.
        retry_after: Duration,
    },
}
/// Point-in-time usage snapshot for one provider.
#[derive(Debug, Clone)]
pub struct RateLimitStats {
    /// Requests recorded in the provider's window.
    pub requests_used: u32,
    /// Configured request limit per window.
    pub requests_limit: u32,
    /// Tokens recorded in the provider's window.
    pub tokens_used: u32,
    /// Configured token budget, if any.
    pub tokens_limit: Option<u32>,
    /// Time left before the window expires.
    pub window_remaining: Duration,
}
/// Pairs a provider value with a shared [`RateLimiter`] and the name under
/// which that provider's usage is tracked.
pub struct RateLimitedProvider<P> {
    /// The wrapped provider.
    inner: P,
    /// Limiter shared through an `Arc`.
    limiter: Arc<RateLimiter>,
    /// Key passed to the limiter for this provider.
    provider_name: String,
}
impl<P> RateLimitedProvider<P> {
    /// Wraps `inner` so that its usage is tracked by `limiter` under
    /// `provider_name`.
    pub fn new(inner: P, limiter: Arc<RateLimiter>, provider_name: &str) -> Self {
        let provider_name = String::from(provider_name);
        Self {
            inner,
            limiter,
            provider_name,
        }
    }

    /// Borrows the wrapped provider.
    pub fn inner(&self) -> &P {
        &self.inner
    }

    /// Acquires a request slot for this provider by delegating to the shared
    /// limiter's `acquire`.
    pub async fn acquire(&self) -> Result<(), RateLimitError> {
        let key = self.provider_name.as_str();
        self.limiter.acquire(key).await
    }

    /// Returns the shared limiter's usage snapshot for this provider.
    pub async fn stats(&self) -> RateLimitStats {
        let key = self.provider_name.as_str();
        self.limiter.stats(key).await
    }
}
/// Subscription tier that selects a user's rate limits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum UserTier {
    /// Lowest limits; this is the `Default` tier.
    #[default]
    Free,
    /// Mid-level limits.
    Pro,
    /// Highest limits.
    Enterprise,
}
impl UserTier {
    /// Returns the rate limits granted to this tier.
    ///
    /// Every tier uses a 60-second window; only the request limit and token
    /// budget differ.
    pub fn rate_limit_config(&self) -> RateLimitConfig {
        let (max_requests, token_budget) = match self {
            UserTier::Free => (10, 5_000),
            UserTier::Pro => (100, 50_000),
            UserTier::Enterprise => (1000, 500_000),
        };
        RateLimitConfig {
            max_requests,
            window: Duration::from_secs(60),
            max_tokens_per_minute: Some(token_budget),
        }
    }
}
/// Rate limiter that tracks usage per user and derives each user's limits
/// from their [`UserTier`].
#[derive(Debug)]
pub struct UserRateLimiter {
    /// Per-user usage state, keyed by user id.
    user_windows: RwLock<HashMap<String, UserRateLimitState>>,
    /// Tier used for users that have no explicit override.
    default_tier: UserTier,
    /// Explicit per-user tier assignments, keyed by user id.
    tier_overrides: RwLock<HashMap<String, UserTier>>,
}
/// Usage state stored for a single user.
#[derive(Debug, Clone)]
struct UserRateLimitState {
    /// Tier in effect the last time this user's state was updated.
    tier: UserTier,
    /// The user's current usage window.
    window: RequestWindow,
}
impl Default for UserRateLimitState {
    /// Returns the state of a user with no recorded usage: the `Free` tier
    /// and an empty window that starts now.
    fn default() -> Self {
        let window = RequestWindow::default();
        Self {
            tier: UserTier::Free,
            window,
        }
    }
}
impl UserRateLimiter {
pub fn new(default_tier: UserTier) -> Self {
Self {
user_windows: RwLock::new(HashMap::new()),
default_tier,
tier_overrides: RwLock::new(HashMap::new()),
}
}
pub async fn set_user_tier(&self, user_id: &str, tier: UserTier) {
let mut overrides = self.tier_overrides.write().await;
overrides.insert(user_id.to_string(), tier);
}
pub async fn get_user_tier(&self, user_id: &str) -> UserTier {
let overrides = self.tier_overrides.read().await;
overrides.get(user_id).copied().unwrap_or(self.default_tier)
}
pub async fn try_acquire(&self, user_id: &str) -> Result<(), RateLimitError> {
self.try_acquire_with_tokens(user_id, 0).await
}
pub async fn try_acquire_with_tokens(
&self,
user_id: &str,
estimated_tokens: u32,
) -> Result<(), RateLimitError> {
let tier = self.get_user_tier(user_id).await;
let config = tier.rate_limit_config();
let mut windows = self.user_windows.write().await;
let state = windows
.entry(user_id.to_string())
.or_insert_with(|| UserRateLimitState {
tier,
window: RequestWindow::default(),
});
state.tier = tier;
let elapsed = state.window.window_start.elapsed();
if elapsed >= config.window {
state.window = RequestWindow::default();
}
if state.window.count >= config.max_requests {
let retry_after = config.window - elapsed;
return Err(RateLimitError::Limited { retry_after });
}
if let Some(max_tokens) = config.max_tokens_per_minute {
if estimated_tokens > 0 && state.window.tokens + estimated_tokens > max_tokens {
let retry_after = config.window - elapsed;
return Err(RateLimitError::Limited { retry_after });
}
}
state.window.count += 1;
state.window.tokens += estimated_tokens;
Ok(())
}
pub async fn user_stats(&self, user_id: &str) -> UserRateLimitStats {
let tier = self.get_user_tier(user_id).await;
let config = tier.rate_limit_config();
let windows = self.user_windows.read().await;
let state = windows.get(user_id).cloned().unwrap_or_default();
let elapsed = state.window.window_start.elapsed();
let window_remaining = if elapsed >= config.window {
config.window
} else {
config.window - elapsed
};
UserRateLimitStats {
user_id: user_id.to_string(),
tier,
requests_used: state.window.count,
requests_limit: config.max_requests,
tokens_used: state.window.tokens,
tokens_limit: config.max_tokens_per_minute,
window_remaining,
}
}
}
/// Point-in-time usage snapshot for one user.
#[derive(Debug, Clone)]
pub struct UserRateLimitStats {
    /// Id of the user the snapshot describes.
    pub user_id: String,
    /// Tier whose limits are reported.
    pub tier: UserTier,
    /// Requests recorded in the user's window.
    pub requests_used: u32,
    /// Request limit of the tier.
    pub requests_limit: u32,
    /// Tokens recorded in the user's window.
    pub tokens_used: u32,
    /// Token budget of the tier, if any.
    pub tokens_limit: Option<u32>,
    /// Time left in the window.
    pub window_remaining: Duration,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `outcome` is the `Limited` rejection.
    fn is_limited(outcome: &Result<(), RateLimitError>) -> bool {
        matches!(outcome, Err(RateLimitError::Limited { .. }))
    }

    /// Three requests fit in the window, the fourth is rejected, and a new
    /// request is admitted once the window has elapsed.
    #[tokio::test]
    async fn test_rate_limiter() {
        let limiter = RateLimiter::new(RateLimitConfig {
            max_requests: 3,
            window: Duration::from_millis(100),
            max_tokens_per_minute: None,
        });

        for _ in 0..3 {
            assert!(limiter.try_acquire("test").await.is_ok());
        }
        assert!(is_limited(&limiter.try_acquire("test").await));

        // Sleep past the 100 ms window so the next call starts a new one.
        tokio::time::sleep(Duration::from_millis(150)).await;
        assert!(limiter.try_acquire("test").await.is_ok());
    }

    /// Stats reflect the number of requests admitted so far.
    #[tokio::test]
    async fn test_stats() {
        let limiter = RateLimiter::new(RateLimitConfig::default());
        for _ in 0..2 {
            limiter.try_acquire("provider1").await.unwrap();
        }

        let stats = limiter.stats("provider1").await;
        assert_eq!(stats.requests_used, 2);
    }

    /// Tier overrides change the effective limits on a per-user basis.
    #[tokio::test]
    async fn test_user_rate_limiter_tiers() {
        let limiter = UserRateLimiter::new(UserTier::Free);

        // Users without an override fall back to the default tier.
        assert_eq!(limiter.get_user_tier("user1").await, UserTier::Free);
        limiter.set_user_tier("user2", UserTier::Pro).await;
        assert_eq!(limiter.get_user_tier("user2").await, UserTier::Pro);

        // The free tier admits exactly ten requests per window.
        for _ in 0..10 {
            assert!(limiter.try_acquire("free_user").await.is_ok());
        }
        assert!(is_limited(&limiter.try_acquire("free_user").await));

        // A pro user stays within limits after fifty requests.
        limiter.set_user_tier("pro_user", UserTier::Pro).await;
        for _ in 0..50 {
            assert!(limiter.try_acquire("pro_user").await.is_ok());
        }

        let stats = limiter.user_stats("pro_user").await;
        assert_eq!(stats.tier, UserTier::Pro);
        assert_eq!(stats.requests_used, 50);
        assert_eq!(stats.requests_limit, 100);
    }
}