use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use crate::error::{Error, Result};
/// Parameters for a token-bucket rate limiter: a steady refill rate plus a
/// burst allowance (bucket capacity).
#[derive(Debug, Clone, Copy)]
pub struct TokenBucketConfig {
/// Steady-state refill rate, in tokens (= requests) per second.
pub requests_per_second: u64,
/// Nominal window over which `requests_per_second` tokens accrue
/// (constructors always set this to one second).
pub refill_period: Duration,
/// Bucket capacity: the largest burst the limiter will admit at once.
/// `new` raises this to at least `requests_per_second`.
pub burst_size: u64,
}
impl TokenBucketConfig {
    /// Builds a config that refills `requests_per_second` tokens every second.
    ///
    /// The burst size is raised to at least the per-second rate so a full
    /// second of traffic always fits in the bucket.
    pub fn new(requests_per_second: u64, burst_size: u64) -> Self {
        let capacity = if burst_size < requests_per_second {
            requests_per_second
        } else {
            burst_size
        };
        Self {
            requests_per_second,
            refill_period: Duration::from_secs(1),
            burst_size: capacity,
        }
    }

    /// Default limit applied per upstream provider (100 rps, burst 100).
    pub fn per_provider() -> Self {
        Self::new(100, 100)
    }

    /// Default limit applied per model (10 rps, burst 10).
    pub fn per_model() -> Self {
        Self::new(10, 10)
    }

    /// Default limit applied per user (1 rps, burst 1).
    pub fn per_user() -> Self {
        Self::new(1, 1)
    }

    /// A practically unbounded configuration (10^9 rps with matching burst).
    pub fn unlimited() -> Self {
        const PRACTICALLY_INFINITE: u64 = 1_000_000_000;
        Self {
            requests_per_second: PRACTICALLY_INFINITE,
            refill_period: Duration::from_secs(1),
            burst_size: PRACTICALLY_INFINITE,
        }
    }

    /// Cost of a single request, in tokens. Currently unused.
    #[allow(dead_code)]
    fn tokens_cost(&self) -> i64 {
        1
    }

    /// Refill rate expressed in tokens per millisecond.
    fn tokens_per_ms(&self) -> f64 {
        self.requests_per_second as f64 / 1000.0
    }
}
/// Token-bucket rate limiter whose state is shared across clones via `Arc`,
/// so every clone draws from (and refills) the same bucket.
pub struct RateLimiter {
/// Tokens currently available; consumed and refilled with CAS operations.
/// Stays >= 0: consumption only succeeds when enough tokens are present.
tokens: Arc<AtomicI64>,
/// Unix timestamp in milliseconds of the last refill, used to compute how
/// much credit has accrued since.
last_refill: Arc<AtomicU64>,
/// Immutable bucket parameters (refill rate and capacity).
config: TokenBucketConfig,
}
impl RateLimiter {
    /// Creates a limiter whose bucket starts full (`config.burst_size` tokens).
    pub fn new(config: TokenBucketConfig) -> Self {
        Self {
            tokens: Arc::new(AtomicI64::new(config.burst_size as i64)),
            last_refill: Arc::new(AtomicU64::new(current_time_ms())),
            config,
        }
    }

    /// Attempts to consume one token.
    ///
    /// # Errors
    /// Returns `Error::InvalidRequest` when the bucket is empty.
    pub fn check_and_consume(&self) -> Result<()> {
        self.check_and_consume_tokens(1)
    }

    /// Refills the bucket for elapsed time, then atomically consumes `tokens`
    /// tokens. Consumption is all-or-nothing: either every requested token is
    /// taken or the bucket is left untouched.
    ///
    /// # Errors
    /// Returns `Error::InvalidRequest` when fewer than `tokens` are available.
    pub fn check_and_consume_tokens(&self, tokens: u64) -> Result<()> {
        let tokens_cost = tokens as i64;
        self.refill();
        let mut current = self.tokens.load(Ordering::Acquire);
        loop {
            if current < tokens_cost {
                return Err(Error::InvalidRequest("Rate limit exceeded".to_string()));
            }
            match self.tokens.compare_exchange(
                current,
                current - tokens_cost,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return Ok(()),
                // Another thread changed the balance; retry with what we saw.
                Err(actual) => current = actual,
            }
        }
    }

    /// Credits tokens for the time elapsed since the last refill.
    ///
    /// BUG FIXES vs. the previous version:
    /// 1. Tokens are computed with `floor`, not `ceil`. `ceil` granted a whole
    ///    token for ANY sub-token interval while still advancing
    ///    `last_refill`, so a 1 rps limiter polled every millisecond refilled
    ///    at ~1000 rps. With `floor`, intervals too short for a whole token
    ///    return early WITHOUT advancing `last_refill`, letting fractional
    ///    credit keep accruing.
    /// 2. The interval is claimed by CAS-ing `last_refill` first; only the
    ///    winner credits tokens, and it retries its token CAS until it lands.
    ///    Previously both CAS failures were ignored independently, so a racing
    ///    thread could advance the clock while its token credit was lost, or
    ///    two threads could credit the same interval twice.
    fn refill(&self) {
        let now = current_time_ms();
        let last = self.last_refill.load(Ordering::Acquire);
        if now <= last {
            return;
        }
        let elapsed_ms = (now - last) as f64;
        let tokens_to_add = (elapsed_ms * self.config.tokens_per_ms()).floor() as i64;
        if tokens_to_add <= 0 {
            // Not enough elapsed time for a whole token; keep `last_refill`
            // where it is so the partial credit isn't discarded.
            return;
        }
        // Claim [last, now]. Losing this race means another thread is
        // crediting the same interval, so we must not also add tokens.
        // Advancing fully to `now` discards the sub-token remainder, which
        // slightly under-refills (< 1 token per credit) — safe for a limiter.
        if self
            .last_refill
            .compare_exchange(last, now, Ordering::AcqRel, Ordering::Acquire)
            .is_err()
        {
            return;
        }
        // We own the credit for this interval: add it, clamped at capacity,
        // retrying until the CAS succeeds so the credit cannot be dropped.
        let capacity = self.config.burst_size as i64;
        let mut current = self.tokens.load(Ordering::Acquire);
        loop {
            let new_tokens = (current + tokens_to_add).min(capacity);
            match self.tokens.compare_exchange(
                current,
                new_tokens,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return,
                Err(actual) => current = actual,
            }
        }
    }

    /// Refills, then reports how many tokens are currently available.
    pub fn available_tokens(&self) -> u64 {
        self.refill();
        self.tokens.load(Ordering::Acquire).max(0) as u64
    }

    /// Maximum number of tokens the bucket can hold.
    pub fn capacity(&self) -> u64 {
        self.config.burst_size
    }

    /// True when no tokens are currently available.
    pub fn is_limited(&self) -> bool {
        self.available_tokens() == 0
    }

    /// Restores the bucket to full capacity and restarts the refill clock.
    pub fn reset(&self) {
        self.tokens
            .store(self.config.burst_size as i64, Ordering::Release);
        self.last_refill.store(current_time_ms(), Ordering::Release);
    }
}
impl Clone for RateLimiter {
    /// Produces another handle to the SAME bucket: the atomics are shared
    /// (refcount bump, no copy), so all clones rate-limit collectively.
    fn clone(&self) -> Self {
        let Self {
            tokens,
            last_refill,
            config,
        } = self;
        Self {
            tokens: Arc::clone(tokens),
            last_refill: Arc::clone(last_refill),
            config: *config,
        }
    }
}
/// Milliseconds since the Unix epoch, saturating to 0 if the system clock
/// reads earlier than the epoch (refill then no-ops instead of panicking).
fn current_time_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
use super::*;
// A fresh limiter starts with a full bucket equal to its burst capacity.
#[test]
fn test_rate_limiter_creation() {
let limiter = RateLimiter::new(TokenBucketConfig::per_provider());
assert_eq!(limiter.capacity(), 100);
assert_eq!(limiter.available_tokens(), 100);
}
// Exactly `burst_size` rapid requests succeed; the next one is rejected.
// NOTE(review): timing-sensitive — assumes the 10 calls finish before a
// refill credits another token.
#[test]
fn test_rate_limiter_consume() {
let limiter = RateLimiter::new(TokenBucketConfig::new(10, 10));
for i in 0..10 {
assert!(
limiter.check_and_consume().is_ok(),
"Request {} should succeed",
i
);
}
assert!(limiter.check_and_consume().is_err());
}
// Multi-token consumption drains the bucket in chunks and is all-or-nothing.
#[test]
fn test_rate_limiter_consume_multiple() {
let limiter = RateLimiter::new(TokenBucketConfig::new(100, 100));
assert!(limiter.check_and_consume_tokens(50).is_ok());
assert_eq!(limiter.available_tokens(), 50);
assert!(limiter.check_and_consume_tokens(50).is_ok());
assert_eq!(limiter.available_tokens(), 0);
assert!(limiter.check_and_consume().is_err());
}
// After draining most of a 1000 rps bucket, sleeping ~10ms should credit
// some tokens back (at 1 token/ms).
#[test]
fn test_rate_limiter_refill() {
let limiter = RateLimiter::new(TokenBucketConfig::new(1000, 1000));
for _ in 0..995 {
assert!(limiter.check_and_consume().is_ok());
}
let available = limiter.available_tokens();
assert!(available < 10, "Should have consumed most tokens");
std::thread::sleep(Duration::from_millis(10));
let available_after = limiter.available_tokens();
assert!(available_after > 0, "Should have refilled tokens");
}
// Clones share the same bucket: consuming through one handle is visible
// through the other.
#[test]
fn test_rate_limiter_clone() {
let limiter1 = RateLimiter::new(TokenBucketConfig::new(10, 10));
let limiter2 = limiter1.clone();
assert!(limiter1.check_and_consume().is_ok());
assert_eq!(limiter1.available_tokens(), 9);
assert_eq!(limiter2.available_tokens(), 9);
}
// reset() restores a partially drained bucket to full capacity.
#[test]
fn test_rate_limiter_reset() {
let limiter = RateLimiter::new(TokenBucketConfig::new(10, 10));
for _ in 0..5 {
assert!(limiter.check_and_consume().is_ok());
}
assert_eq!(limiter.available_tokens(), 5);
limiter.reset();
assert_eq!(limiter.available_tokens(), 10);
}
// An unlimited config never rejects within any realistic request count.
#[test]
fn test_rate_limiter_unlimited() {
let limiter = RateLimiter::new(TokenBucketConfig::unlimited());
for _ in 0..1000 {
assert!(limiter.check_and_consume().is_ok());
}
}
// is_limited() flips to true only once the bucket is fully drained.
#[test]
fn test_rate_limiter_is_limited() {
let limiter = RateLimiter::new(TokenBucketConfig::new(10, 10));
assert!(!limiter.is_limited());
for _ in 0..5 {
assert!(limiter.check_and_consume().is_ok());
}
assert!(!limiter.is_limited());
for _ in 0..5 {
assert!(limiter.check_and_consume().is_ok());
}
assert!(limiter.is_limited());
}
// Preset config constructors: expected rates and burst sizes.
#[test]
fn test_per_provider_config() {
let config = TokenBucketConfig::per_provider();
assert_eq!(config.requests_per_second, 100);
assert_eq!(config.burst_size, 100);
}
#[test]
fn test_per_model_config() {
let config = TokenBucketConfig::per_model();
assert_eq!(config.requests_per_second, 10);
assert_eq!(config.burst_size, 10);
}
#[test]
fn test_per_user_config() {
let config = TokenBucketConfig::per_user();
assert_eq!(config.requests_per_second, 1);
assert_eq!(config.burst_size, 1);
}
// 10 tasks × 15 attempts = 150 requests against a 100-token bucket: at most
// burst (100) plus a small refill margin may succeed. The <= 110 bound
// leaves room for tokens credited while the tasks run.
#[tokio::test]
async fn test_concurrent_access() {
use tokio::task::JoinSet;
let limiter = RateLimiter::new(TokenBucketConfig::new(100, 100));
let mut set = JoinSet::new();
for _ in 0..10 {
let limiter = limiter.clone();
set.spawn(async move {
let mut success_count = 0;
for _ in 0..15 {
if limiter.check_and_consume().is_ok() {
success_count += 1;
}
}
success_count
});
}
let mut total_success = 0;
while let Some(result) = set.join_next().await {
total_success += result.unwrap();
}
assert!(
total_success <= 110,
"Expected <= 110, got {}",
total_success
);
assert!(total_success > 0);
}
}