use chrono::{DateTime, Utc};
use dashmap::DashMap;
use std::{sync::Arc, time::Duration};
/// Limits applied to one identifier's token bucket.
///
/// `requests_per_minute` sets the sustained refill rate and `burst_size` sets
/// the bucket capacity, i.e. how many requests can be admitted back-to-back.
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Tokens restored to the bucket per minute.
    pub requests_per_minute: u32,
    /// Maximum number of tokens the bucket can hold.
    pub burst_size: u32,
}

impl RateLimitConfig {
    /// Shared constructor behind the named presets below.
    fn preset(requests_per_minute: u32, burst_size: u32) -> Self {
        Self {
            requests_per_minute,
            burst_size,
        }
    }

    /// 60 requests per minute, bursts of up to 100.
    pub fn free_tier() -> Self {
        Self::preset(60, 100)
    }

    /// 600 requests per minute, bursts of up to 1 000.
    pub fn professional() -> Self {
        Self::preset(600, 1000)
    }

    /// 10 000 requests per minute, bursts of up to 20 000.
    pub fn unlimited() -> Self {
        Self::preset(10_000, 20_000)
    }

    /// 100 000 requests per minute, bursts of up to 200 000.
    pub fn dev_mode() -> Self {
        Self::preset(100_000, 200_000)
    }
}
/// Token bucket backing one identifier's rate limit.
///
/// Tokens refill continuously at `refill_rate` per second up to `max_tokens`;
/// each admitted request removes its cost from `tokens`.
#[derive(Debug, Clone)]
struct TokenBucket {
    /// Tokens currently available (fractional because refill is continuous).
    tokens: f64,
    /// Bucket capacity, taken from `RateLimitConfig::burst_size`.
    max_tokens: f64,
    /// Tokens added per second (`requests_per_minute / 60`).
    refill_rate: f64,
    /// Wall-clock time up to which refill has already been credited.
    last_refill: DateTime<Utc>,
}

impl TokenBucket {
    /// Creates a full bucket for `config`.
    fn new(config: &RateLimitConfig) -> Self {
        let max_tokens = f64::from(config.burst_size);
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate: f64::from(config.requests_per_minute) / 60.0,
            last_refill: Utc::now(),
        }
    }

    /// Seconds between `last_refill` and `now`, or `None` if `now` is earlier
    /// (the wall clock stepped backwards).
    ///
    /// Keeps the full sub-millisecond precision of the difference. Truncating
    /// to whole milliseconds and then advancing `last_refill` to `now` would
    /// discard the remainder on every call and under-refill busy buckets.
    fn elapsed_secs(&self, now: DateTime<Utc>) -> Option<f64> {
        (now - self.last_refill)
            .to_std()
            .ok()
            .map(|elapsed| elapsed.as_secs_f64())
    }

    /// Refills, then deducts `tokens` if enough are available.
    ///
    /// Returns `false` without deducting when the bucket is short, when
    /// `tokens` is NaN, or when `tokens` is negative. A negative cost would
    /// otherwise *add* tokens and push the bucket past its capacity.
    fn try_consume(&mut self, tokens: f64) -> bool {
        if tokens < 0.0 {
            return false;
        }
        self.refill();
        if self.tokens >= tokens {
            self.tokens -= tokens;
            true
        } else {
            false
        }
    }

    /// Credits tokens for the time elapsed since `last_refill`, capped at
    /// `max_tokens`. Does nothing if the clock moved backwards.
    fn refill(&mut self) {
        let now = Utc::now();
        if let Some(elapsed) = self.elapsed_secs(now) {
            self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
            self.last_refill = now;
        }
    }

    /// Whole tokens available after refilling.
    fn remaining(&mut self) -> u32 {
        self.refill();
        self.tokens.floor() as u32
    }

    /// Time until at least one token is available (zero if one already is).
    fn retry_after(&mut self) -> Duration {
        self.retry_after_for(1.0)
    }

    /// Time until `cost` tokens are available (zero if they already are).
    ///
    /// Returns `Duration::MAX` when the wait is unbounded or unrepresentable
    /// (e.g. a zero refill rate) instead of panicking the way
    /// `Duration::from_secs_f64` does on non-finite input. A cost above
    /// `max_tokens` can never be admitted. For such a cost, the returned value
    /// is only the time needed to accumulate `cost` tokens in an uncapped
    /// bucket.
    fn retry_after_for(&mut self, cost: f64) -> Duration {
        self.refill();
        let deficit = cost - self.tokens;
        if deficit <= 0.0 {
            return Duration::ZERO;
        }
        Duration::try_from_secs_f64(deficit / self.refill_rate).unwrap_or(Duration::MAX)
    }

    /// Whether the bucket would be at capacity if it were refilled at `now`.
    ///
    /// Such a bucket behaves like a freshly created (full) one, so it can be
    /// discarded without changing what the identifier is allowed to do.
    fn is_full_at(&self, now: DateTime<Utc>) -> bool {
        let projected = self
            .elapsed_secs(now)
            .map_or(self.tokens, |elapsed| self.tokens + elapsed * self.refill_rate);
        projected >= self.max_tokens
    }
}

/// Per-identifier token-bucket rate limiter.
///
/// Each identifier gets its own bucket, created lazily on its first check from
/// either its per-identifier config (see [`RateLimiter::set_config`]) or the
/// default config.
pub struct RateLimiter {
    /// Buckets keyed by identifier.
    buckets: Arc<DashMap<String, TokenBucket>>,
    /// Config used for identifiers without a custom entry.
    default_config: RateLimitConfig,
    /// Per-identifier overrides of `default_config`.
    custom_configs: Arc<DashMap<String, RateLimitConfig>>,
}

impl RateLimiter {
    /// Creates a limiter with no buckets and no per-identifier overrides.
    pub fn new(default_config: RateLimitConfig) -> Self {
        Self {
            buckets: Arc::new(DashMap::new()),
            default_config,
            custom_configs: Arc::new(DashMap::new()),
        }
    }

    /// Sets `identifier`'s config and discards its existing bucket. The next
    /// check therefore starts from a full bucket under the new limits.
    pub fn set_config(&self, identifier: &str, config: RateLimitConfig) {
        self.custom_configs.insert(identifier.to_string(), config);
        self.buckets.remove(identifier);
    }

    /// Checks one token for `identifier` and consumes it if the request is allowed.
    pub fn check_rate_limit(&self, identifier: &str) -> RateLimitResult {
        self.check_rate_limit_with_cost(identifier, 1.0)
    }

    /// Checks `cost` tokens for `identifier` and consumes them if the request is allowed.
    ///
    /// When the request is denied, `retry_after` is the time until `cost`
    /// tokens (not merely one) will be available. A client that waits that
    /// long is therefore not immediately rejected again. Negative costs are
    /// always denied.
    pub fn check_rate_limit_with_cost(&self, identifier: &str, cost: f64) -> RateLimitResult {
        let config = self
            .custom_configs
            .get(identifier)
            .map_or_else(|| self.default_config.clone(), |c| c.clone());
        let mut entry = self
            .buckets
            .entry(identifier.to_string())
            .or_insert_with(|| TokenBucket::new(&config));
        let allowed = entry.try_consume(cost);
        let remaining = entry.remaining();
        let retry_after = if allowed {
            None
        } else {
            Some(entry.retry_after_for(cost))
        };
        RateLimitResult {
            allowed,
            remaining,
            retry_after,
            limit: config.requests_per_minute,
        }
    }

    /// Returns `identifier`'s remaining tokens and the time until one token is
    /// available. Returns `None` if the identifier has no bucket yet or its
    /// bucket was cleaned up.
    pub fn get_stats(&self, identifier: &str) -> Option<RateLimitStats> {
        self.buckets
            .get_mut(identifier)
            .map(|mut bucket| RateLimitStats {
                remaining: bucket.remaining(),
                retry_after: bucket.retry_after(),
            })
    }

    /// Drops buckets that have been idle for at least an hour *and* would be
    /// full by now.
    ///
    /// The fullness check matters for slow-refilling configs. Discarding a
    /// still-depleted bucket would hand the identifier a fresh full burst on
    /// its next request.
    pub fn cleanup(&self) {
        let now = Utc::now();
        self.buckets.retain(|_, bucket| {
            (now - bucket.last_refill).num_hours() < 1 || !bucket.is_full_at(now)
        });
    }
}
impl Default for RateLimiter {
    /// Builds a limiter whose default tier is [`RateLimitConfig::professional`].
    fn default() -> Self {
        let config = RateLimitConfig::professional();
        Self::new(config)
    }
}
/// Outcome of a single `check_rate_limit` / `check_rate_limit_with_cost` call.
#[derive(Clone, Debug)]
pub struct RateLimitResult {
    /// Whether the request was admitted (and its cost deducted).
    pub allowed: bool,
    /// Whole tokens left in the bucket after this check.
    pub remaining: u32,
    /// Suggested wait before retrying; `None` when the request was allowed.
    pub retry_after: Option<Duration>,
    /// The `requests_per_minute` of the config applied to this identifier.
    pub limit: u32,
}
/// Snapshot of one identifier's bucket, as returned by `RateLimiter::get_stats`.
#[derive(Clone, Debug)]
pub struct RateLimitStats {
    /// Whole tokens currently available.
    pub remaining: u32,
    /// Time until at least one token is available; zero when one already is.
    pub retry_after: Duration,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration as StdDuration;

    #[test]
    fn test_token_bucket_creation() {
        let config = RateLimitConfig::free_tier();
        let bucket = TokenBucket::new(&config);
        assert_eq!(bucket.max_tokens, 100.0);
        assert_eq!(bucket.tokens, 100.0);
    }

    #[test]
    fn test_token_consumption() {
        let config = RateLimitConfig::free_tier();
        let mut bucket = TokenBucket::new(&config);
        assert!(bucket.try_consume(1.0));
        assert_eq!(bucket.remaining(), 99);
        assert!(bucket.try_consume(10.0));
        assert_eq!(bucket.remaining(), 89);
    }

    #[test]
    fn test_rate_limit_enforcement() {
        let config = RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        };
        let mut bucket = TokenBucket::new(&config);
        for _ in 0..10 {
            assert!(bucket.try_consume(1.0));
        }
        assert!(!bucket.try_consume(1.0));
    }

    #[test]
    fn test_token_refill() {
        let config = RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        };
        let mut bucket = TokenBucket::new(&config);
        for _ in 0..10 {
            bucket.try_consume(1.0);
        }
        assert_eq!(bucket.remaining(), 0);
        // Simulate two seconds passing by backdating the last refill instead of
        // sleeping: the test stays fast and does not depend on scheduler timing.
        let two_seconds = chrono::Duration::from_std(StdDuration::from_secs(2))
            .expect("two seconds fits in a chrono duration");
        bucket.last_refill = bucket.last_refill - two_seconds;
        let remaining = bucket.remaining();
        assert!(
            (1..=3).contains(&remaining),
            "Expected 1-3 tokens, got {remaining}"
        );
    }

    #[test]
    fn test_rate_limiter_per_identifier() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 5,
        });
        let result1 = limiter.check_rate_limit("user1");
        let result2 = limiter.check_rate_limit("user2");
        assert!(result1.allowed);
        assert!(result2.allowed);
        assert_eq!(result1.remaining, 4);
        assert_eq!(result2.remaining, 4);
    }

    #[test]
    fn test_custom_config() {
        let limiter = RateLimiter::new(RateLimitConfig::free_tier());
        limiter.set_config("premium_user", RateLimitConfig::unlimited());
        let free_result = limiter.check_rate_limit("free_user");
        let premium_result = limiter.check_rate_limit("premium_user");
        assert!(free_result.limit < premium_result.limit);
    }

    #[test]
    fn test_rate_limit_with_cost() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        });
        let result = limiter.check_rate_limit_with_cost("user1", 5.0);
        assert!(result.allowed);
        assert_eq!(result.remaining, 5);
    }
}