use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::time::sleep;
/// Bucket algorithm requested through `RateLimitConfig::algorithm`.
///
/// NOTE(review): `RateLimiter` in this file never reads `RateLimitConfig::algorithm`,
/// so selecting `LeakyBucket` currently behaves exactly like `TokenBucket`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RateLimitAlgorithm {
/// Token bucket; the variant chosen by `RateLimitConfig::new`.
TokenBucket,
/// Leaky bucket; selectable, but not implemented separately (see the note above).
LeakyBucket,
}
/// Parameters for a `RateLimiter`.
///
/// Build one with `new`, `per_second` or `per_minute`, then adjust it with the
/// `with_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
/// Maximum number of tokens the bucket can hold; a new limiter starts full.
pub capacity: u64,
/// Tokens credited for each `refill_interval` of elapsed time.
pub refill_rate: u64,
/// Time span over which `refill_rate` tokens are credited.
pub refill_interval: Duration,
/// Requested algorithm.
/// NOTE(review): not read by `RateLimiter` in this file; every limiter acts as a token bucket.
pub algorithm: RateLimitAlgorithm,
/// When `true`, `RateLimiter::acquire` may sleep until tokens refill; when `false` it
/// returns `false` immediately if tokens are short. `try_acquire` never waits.
pub block_on_limit: bool,
}
impl RateLimitConfig {
    /// Token-bucket config holding `capacity` tokens and crediting `capacity` tokens per
    /// `refill_interval`; `acquire` blocks when the limit is hit.
    pub fn new(capacity: u64, refill_interval: Duration) -> Self {
        // By default the bucket refills by its full capacity each interval.
        let refill_rate = capacity;
        Self {
            capacity,
            refill_rate,
            refill_interval,
            algorithm: RateLimitAlgorithm::TokenBucket,
            block_on_limit: true,
        }
    }

    /// Allows `requests` tokens per second (bucket capacity is also `requests`).
    pub fn per_second(requests: u64) -> Self {
        let one_second = Duration::from_secs(1);
        Self::new(requests, one_second)
    }

    /// Allows `requests` tokens per minute (bucket capacity is also `requests`).
    pub fn per_minute(requests: u64) -> Self {
        let one_minute = Duration::from_secs(60);
        Self::new(requests, one_minute)
    }

    /// Returns the config with `refill_rate` replaced; all other fields are kept.
    pub fn with_refill_rate(self, rate: u64) -> Self {
        Self {
            refill_rate: rate,
            ..self
        }
    }

    /// Returns the config with `algorithm` replaced; all other fields are kept.
    pub fn with_algorithm(self, algorithm: RateLimitAlgorithm) -> Self {
        Self { algorithm, ..self }
    }

    /// Returns the config with `block_on_limit` replaced; all other fields are kept.
    pub fn with_blocking(self, block: bool) -> Self {
        Self {
            block_on_limit: block,
            ..self
        }
    }
}
/// Mutable limiter state kept behind the limiter's mutex: bucket level plus counters.
#[derive(Debug)]
struct RateLimiterState {
/// Tokens currently in the bucket; fractional because refill is credited in
/// proportion to elapsed time.
tokens: f64,
/// Instant at which tokens were last credited by `refill_tokens`.
last_refill: Instant,
/// Requests recorded; each one is also counted in exactly one of the next two fields.
total_requests: u64,
/// Requests that obtained their tokens.
requests_allowed: u64,
/// Requests recorded as refused.
requests_denied: u64,
}
/// Point-in-time snapshot returned by `RateLimiter::stats`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RateLimitStats {
/// Requests recorded since creation or the last `reset`.
pub total_requests: u64,
/// Recorded requests that obtained their tokens.
pub requests_allowed: u64,
/// Recorded requests that were refused.
pub requests_denied: u64,
/// Whole tokens in the bucket after refilling (fractional part truncated).
pub available_tokens: u64,
/// Percentage of recorded requests that were allowed (0.0 before any request).
/// Despite the name, this is an acceptance rate, not how full the bucket is.
pub utilization_percent: f64,
}
/// Rate limiter whose bucket and counters live behind an `Arc<Mutex<..>>`; cloning a
/// limiter shares that state, so all clones draw from one limit.
pub struct RateLimiter {
/// Immutable configuration (copied into each clone).
config: RateLimitConfig,
/// Bucket level and counters, shared with every clone of this limiter.
state: Arc<Mutex<RateLimiterState>>,
}
impl RateLimiter {
    /// Creates a limiter whose bucket starts full (`config.capacity` tokens).
    ///
    /// NOTE(review): `config.algorithm` is never consulted; every limiter behaves as a
    /// token bucket regardless of the selected `RateLimitAlgorithm`.
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            state: Arc::new(Mutex::new(RateLimiterState {
                tokens: config.capacity as f64,
                last_refill: Instant::now(),
                total_requests: 0,
                requests_allowed: 0,
                requests_denied: 0,
            })),
            config,
        }
    }

    /// Takes `tokens` tokens from the bucket, waiting for refills when configured to.
    ///
    /// Each call is recorded exactly once in the statistics: as allowed when it returns
    /// `true`, as denied when it returns `false`, no matter how often it had to wait.
    ///
    /// Returns `false` without waiting when the tokens are not available and either
    /// `block_on_limit` is `false` or refilling can never satisfy the request (`tokens`
    /// exceeds `capacity`, or `refill_rate` is zero). Otherwise sleeps until enough
    /// tokens have accrued and returns `true`. The lock is released before every sleep.
    pub async fn acquire(&self, tokens: u64) -> bool {
        loop {
            let wait = {
                let mut state = self.state.lock();
                self.refill_tokens(&mut state);
                if Self::take(&mut state, tokens) {
                    return Self::record(&mut state, true);
                }
                match self.time_until_available(&state, tokens) {
                    Some(delay) if self.config.block_on_limit => delay,
                    _ => return Self::record(&mut state, false),
                }
            };
            // The guard was dropped at the end of the block above, so the mutex is never
            // held across this await.
            sleep(wait).await;
        }
    }

    /// Takes `tokens` tokens if they are available right now; never waits.
    ///
    /// Returns whether the tokens were taken; the call is recorded as allowed or denied.
    pub fn try_acquire(&self, tokens: u64) -> bool {
        let mut state = self.state.lock();
        self.refill_tokens(&mut state);
        let allowed = Self::take(&mut state, tokens);
        Self::record(&mut state, allowed)
    }

    /// Refills the bucket, then returns a snapshot of the counters and whole tokens left.
    pub fn stats(&self) -> RateLimitStats {
        let mut state = self.state.lock();
        self.refill_tokens(&mut state);
        let utilization_percent = if state.total_requests > 0 {
            (state.requests_allowed as f64 / state.total_requests as f64) * 100.0
        } else {
            0.0
        };
        RateLimitStats {
            total_requests: state.total_requests,
            requests_allowed: state.requests_allowed,
            requests_denied: state.requests_denied,
            // Truncates the fractional part of the bucket level.
            available_tokens: state.tokens as u64,
            utilization_percent,
        }
    }

    /// Refills the bucket to capacity and zeroes all counters (visible to every clone).
    pub fn reset(&self) {
        let mut state = self.state.lock();
        state.tokens = self.config.capacity as f64;
        state.last_refill = Instant::now();
        state.total_requests = 0;
        state.requests_allowed = 0;
        state.requests_denied = 0;
    }

    /// Credits the tokens accrued since `last_refill` at `refill_rate` per
    /// `refill_interval`, capped at `capacity`.
    ///
    /// Refill is continuous rather than gated on a whole interval having elapsed, so the
    /// wait computed by `time_until_available` is exactly what the bucket needs. The old
    /// gated refill made `acquire` wake early, find nothing, and loop many times.
    fn refill_tokens(&self, state: &mut RateLimiterState) {
        let now = Instant::now();
        let elapsed = now.saturating_duration_since(state.last_refill);
        state.last_refill = now;
        let per_sec = self.tokens_per_sec();
        let accrued = if per_sec.is_infinite() {
            // Zero-length interval with a non-zero rate: the bucket is always full.
            f64::INFINITY
        } else {
            elapsed.as_secs_f64() * per_sec
        };
        state.tokens = (state.tokens + accrued).min(self.config.capacity as f64);
    }

    /// Refill speed in tokens per second. Returns `0.0` when `refill_rate` is zero and
    /// `+inf` when `refill_interval` is zero. Uses `as_secs_f64` so that sub-millisecond
    /// intervals are not truncated to zero.
    fn tokens_per_sec(&self) -> f64 {
        if self.config.refill_rate == 0 {
            return 0.0;
        }
        self.config.refill_rate as f64 / self.config.refill_interval.as_secs_f64()
    }

    /// Time until `tokens` tokens will be in the bucket, rounded up to whole milliseconds.
    /// The result is at least 1 ms so a retry loop always makes progress.
    ///
    /// Returns `None` when refilling can never satisfy the request. That happens when it
    /// exceeds the bucket capacity or nothing is ever refilled.
    fn time_until_available(&self, state: &RateLimiterState, tokens: u64) -> Option<Duration> {
        let per_sec = self.tokens_per_sec();
        if tokens > self.config.capacity || per_sec <= 0.0 {
            return None;
        }
        let missing = (tokens as f64 - state.tokens).max(0.0);
        let millis = (missing / per_sec * 1000.0).ceil().max(1.0);
        Some(Duration::from_millis(millis as u64))
    }

    /// Removes `tokens` from the bucket if enough are present; reports whether it did.
    fn take(state: &mut RateLimiterState, tokens: u64) -> bool {
        let needed = tokens as f64;
        let available = state.tokens >= needed;
        if available {
            state.tokens -= needed;
        }
        available
    }

    /// Counts one request as allowed or denied and returns `allowed` unchanged.
    fn record(state: &mut RateLimiterState, allowed: bool) -> bool {
        state.total_requests += 1;
        if allowed {
            state.requests_allowed += 1;
        } else {
            state.requests_denied += 1;
        }
        allowed
    }
}
impl Clone for RateLimiter {
fn clone(&self) -> Self {
Self {
config: self.config.clone(),
state: Arc::clone(&self.state),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{sleep, Duration};

    /// Limiter built from `RateLimitConfig::new(capacity, interval)` (blocking enabled).
    fn make_limiter(capacity: u64, interval: Duration) -> RateLimiter {
        RateLimiter::new(RateLimitConfig::new(capacity, interval))
    }

    #[tokio::test]
    async fn test_rate_limiter_basic() {
        let limiter = make_limiter(10, Duration::from_secs(1));
        (0..10).for_each(|_| assert!(limiter.try_acquire(1)));
        assert!(!limiter.try_acquire(1));
        let RateLimitStats {
            requests_allowed,
            requests_denied,
            ..
        } = limiter.stats();
        assert_eq!(requests_allowed, 10);
        assert_eq!(requests_denied, 1);
    }

    #[tokio::test]
    async fn test_rate_limiter_refill() {
        let limiter = make_limiter(5, Duration::from_millis(100));
        (0..5).for_each(|_| assert!(limiter.try_acquire(1)));
        assert!(!limiter.try_acquire(1));
        // Wait longer than one refill interval so the bucket has tokens again.
        sleep(Duration::from_millis(150)).await;
        assert!(limiter.try_acquire(1));
    }

    #[tokio::test]
    async fn test_rate_limiter_blocking() {
        let limiter = RateLimiter::new(
            RateLimitConfig::new(2, Duration::from_millis(100)).with_blocking(true),
        );
        limiter.acquire(2).await;
        let started = Instant::now();
        limiter.acquire(1).await;
        assert!(started.elapsed() >= Duration::from_millis(50));
    }

    #[tokio::test]
    async fn test_rate_limiter_stats() {
        let limiter = make_limiter(10, Duration::from_secs(1));
        (0..5).for_each(|_| {
            limiter.try_acquire(1);
        });
        let stats = limiter.stats();
        assert_eq!(
            (
                stats.total_requests,
                stats.requests_allowed,
                stats.requests_denied,
                stats.available_tokens,
            ),
            (5, 5, 0, 5)
        );
        assert_eq!(stats.utilization_percent, 100.0);
    }

    #[tokio::test]
    async fn test_rate_limiter_reset() {
        let limiter = make_limiter(5, Duration::from_secs(1));
        (0..5).for_each(|_| {
            limiter.try_acquire(1);
        });
        limiter.reset();
        assert!(limiter.try_acquire(1));
        assert_eq!(limiter.stats().total_requests, 1);
    }

    #[tokio::test]
    async fn test_rate_limiter_per_second() {
        let limiter = RateLimiter::new(RateLimitConfig::per_second(100));
        let cfg = &limiter.config;
        assert_eq!(cfg.capacity, 100);
        assert_eq!(cfg.refill_interval, Duration::from_secs(1));
    }

    #[tokio::test]
    async fn test_rate_limiter_per_minute() {
        let limiter = RateLimiter::new(RateLimitConfig::per_minute(1000));
        let cfg = &limiter.config;
        assert_eq!(cfg.capacity, 1000);
        assert_eq!(cfg.refill_interval, Duration::from_secs(60));
    }
}