use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use crate::error::{AiError, Result};
/// Handle to a token-bucket rate limiter.
///
/// The handle only wraps an `Arc`, so cloning it is cheap and every clone
/// operates on the same underlying bucket.
#[derive(Clone)]
pub struct RateLimiter {
// Shared bucket state; all clones of this handle point at the same allocation.
inner: Arc<RateLimiterInner>,
}
/// Bucket state shared by every clone of a `RateLimiter`.
struct RateLimiterInner {
/// Upper bound on the number of tokens the bucket may hold.
capacity: usize,
/// Tokens currently available; fractional because refills are computed as
/// `refill_rate * elapsed_seconds`.
tokens: RwLock<f64>,
/// Tokens added to the bucket per second.
refill_rate: f64,
/// Instant at which `tokens` was last topped up.
last_refill: RwLock<Instant>,
}
impl RateLimiter {
    /// Minimum time that must pass between two refills; also the shortest
    /// pause used while waiting for a token.
    const REFILL_INTERVAL: Duration = Duration::from_millis(10);

    /// Longest single pause while waiting for a token. Bounding the pause
    /// keeps the `f64` -> `Duration` conversion far away from overflow.
    const MAX_WAIT: Duration = Duration::from_secs(1);

    /// Creates a limiter that refills at `requests_per_second` tokens per
    /// second and starts with a full bucket.
    ///
    /// `burst_size` is the bucket capacity. When it is `None` the capacity
    /// defaults to `requests_per_second` rounded up to the next whole number.
    ///
    /// A capacity of zero (an explicit `Some(0)`, or a default derived from a
    /// non-positive or NaN rate) yields a limiter that can never hand out a
    /// token.
    #[must_use]
    pub fn new(requests_per_second: f64, burst_size: Option<usize>) -> Self {
        // Float-to-integer `as` casts saturate: NaN and negative rates become
        // 0 and huge rates become `usize::MAX`, so this never panics.
        let capacity = burst_size.unwrap_or(requests_per_second.ceil() as usize);
        let inner = RateLimiterInner {
            capacity,
            tokens: RwLock::new(capacity as f64),
            refill_rate: requests_per_second,
            last_refill: RwLock::new(Instant::now()),
        };
        Self {
            inner: Arc::new(inner),
        }
    }

    /// Waits until a token is available, consumes it and returns a guard.
    ///
    /// While the bucket is empty the task sleeps for roughly the time the
    /// missing fraction of a token needs to accrue (bounded to the range
    /// 10 ms ..= 1 s) instead of waking at a fixed 10 ms cadence.
    ///
    /// The future never resolves if the bucket can never hold a full token,
    /// i.e. the capacity is zero, or the bucket is empty and the refill rate
    /// is not positive.
    ///
    /// # Errors
    ///
    /// This implementation never returns `Err`; the `Result` is part of the
    /// established signature.
    pub async fn acquire(&self) -> Result<RateLimitGuard> {
        loop {
            self.refill_tokens().await;
            let shortfall = {
                let mut tokens = self.inner.tokens.write().await;
                if *tokens >= 1.0 {
                    *tokens -= 1.0;
                    return Ok(RateLimitGuard {
                        limiter: self.clone(),
                    });
                }
                1.0 - *tokens
            };
            // The write guard is released before sleeping so other tasks can
            // use the bucket in the meantime.
            tokio::time::sleep(self.pause_for(shortfall)).await;
        }
    }

    /// Consumes a token if one is available right now.
    ///
    /// Returns `None` when fewer than one token is in the bucket. The only
    /// awaiting done here is for the internal locks.
    pub async fn try_acquire(&self) -> Option<RateLimitGuard> {
        self.refill_tokens().await;
        let mut tokens = self.inner.tokens.write().await;
        if *tokens < 1.0 {
            return None;
        }
        *tokens -= 1.0;
        Some(RateLimitGuard {
            limiter: self.clone(),
        })
    }

    /// Credits the bucket with the tokens earned since the last refill.
    ///
    /// Refills are skipped until at least `REFILL_INTERVAL` has passed.
    async fn refill_tokens(&self) {
        // Lock order is always `last_refill` then `tokens`; no other method
        // holds `tokens` while asking for `last_refill`.
        let mut last_refill = self.inner.last_refill.write().await;

        // One timestamp serves both as the end of the measured interval and
        // as the new refill mark. Time spent below waiting for the `tokens`
        // lock is therefore credited by the next refill rather than lost.
        let now = Instant::now();
        let elapsed = now.saturating_duration_since(*last_refill);
        if elapsed < Self::REFILL_INTERVAL {
            return;
        }

        // `max(0.0)` stops a negative rate from draining the bucket and maps
        // a NaN rate to "no refill" (`f64::max` returns the non-NaN operand).
        let earned = (self.inner.refill_rate * elapsed.as_secs_f64()).max(0.0);
        let ceiling = self.inner.capacity as f64;

        let mut tokens = self.inner.tokens.write().await;
        *tokens = (*tokens + earned).min(ceiling);
        *last_refill = now;
    }

    /// How long to sleep before re-checking a bucket that is `shortfall`
    /// tokens short of one full token.
    fn pause_for(&self, shortfall: f64) -> Duration {
        let secs = shortfall / self.inner.refill_rate;
        if secs.is_finite() && secs > 0.0 {
            // Capping the seconds before converting keeps
            // `Duration::from_secs_f64` within its valid input range.
            Duration::from_secs_f64(secs.min(Self::MAX_WAIT.as_secs_f64()))
                .max(Self::REFILL_INTERVAL)
        } else {
            // Zero, negative, infinite or NaN estimate: fall back to the
            // shortest poll interval.
            Self::REFILL_INTERVAL
        }
    }

    /// Returns the number of tokens in the bucket after applying any pending
    /// refill. The value may be fractional.
    pub async fn available_tokens(&self) -> f64 {
        self.refill_tokens().await;
        *self.inner.tokens.read().await
    }

    /// Maximum number of tokens the bucket can hold.
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.inner.capacity
    }

    /// Refill rate in tokens per second, as passed to [`RateLimiter::new`].
    #[must_use]
    pub fn refill_rate(&self) -> f64 {
        self.inner.refill_rate
    }
}
/// Value handed out by a `RateLimiter` when a token has been consumed.
///
/// No `Drop` implementation is visible in this file, so dropping the guard
/// does not appear to return the token to the bucket.
pub struct RateLimitGuard {
// Handle to the limiter that issued this guard. Nothing in this file reads
// the field, which is why the dead-code lint is silenced.
#[allow(dead_code)]
limiter: RateLimiter,
}
/// A set of independent `RateLimiter`s looked up by tier name.
pub struct TieredRateLimiter {
/// Tier name -> limiter. The async `RwLock` lets tiers be registered
/// through a shared reference.
limiters: Arc<RwLock<std::collections::HashMap<String, RateLimiter>>>,
}
impl Default for TieredRateLimiter {
fn default() -> Self {
Self::new()
}
}
impl TieredRateLimiter {
    /// Creates a limiter set with no tiers registered.
    #[must_use]
    pub fn new() -> Self {
        Self {
            limiters: Arc::new(RwLock::new(std::collections::HashMap::new())),
        }
    }

    /// Registers `limiter` under `tier`, replacing any limiter previously
    /// stored under the same name.
    pub async fn add_tier(&self, tier: impl Into<String>, limiter: RateLimiter) {
        self.limiters.write().await.insert(tier.into(), limiter);
    }

    /// Waits for a token from the limiter registered under `tier`.
    ///
    /// The tier map is only locked for the lookup. A caller that is waiting
    /// keeps using the limiter it found, even if the tier is replaced through
    /// `add_tier` in the meantime.
    ///
    /// # Errors
    ///
    /// Returns `AiError::Configuration` when no limiter is registered for
    /// `tier`; otherwise forwards the result of the limiter's `acquire`.
    pub async fn acquire(&self, tier: &str) -> Result<RateLimitGuard> {
        let limiter = self
            .limiter_for(tier)
            .await
            .ok_or_else(|| AiError::Configuration(format!("No rate limiter for tier: {tier}")))?;
        limiter.acquire().await
    }

    /// Takes a token from the limiter registered under `tier` if one is
    /// available right now.
    ///
    /// Returns `None` both when the tier is unknown and when its bucket is
    /// empty.
    pub async fn try_acquire(&self, tier: &str) -> Option<RateLimitGuard> {
        let limiter = self.limiter_for(tier).await?;
        limiter.try_acquire().await
    }

    /// Looks up the limiter for `tier` and returns an owned handle to it.
    ///
    /// The read guard is dropped before this function returns, so callers
    /// never hold the map lock while waiting on a limiter. Holding it across
    /// a long wait would let a single queued `add_tier` writer stall every
    /// other tier, because the lock stops admitting new readers once a writer
    /// is waiting. Cloning a `RateLimiter` only bumps an `Arc` count.
    async fn limiter_for(&self, tier: &str) -> Option<RateLimiter> {
        let limiters = self.limiters.read().await;
        limiters.get(tier).cloned()
    }
}
/// Builder-style configuration from which a `RateLimiter` is constructed.
pub struct RateLimiterConfig {
/// Refill rate forwarded to `RateLimiter::new`.
requests_per_second: f64,
/// Optional bucket capacity; `None` leaves the choice to `RateLimiter::new`.
burst_size: Option<usize>,
}
impl RateLimiterConfig {
    /// Starts a configuration with the given refill rate and no explicit
    /// burst size.
    #[must_use]
    pub fn new(requests_per_second: f64) -> Self {
        let burst_size = None;
        Self {
            requests_per_second,
            burst_size,
        }
    }

    /// Returns the configuration with its burst size set to `burst_size`.
    #[must_use]
    pub fn with_burst_size(self, burst_size: usize) -> Self {
        Self {
            burst_size: Some(burst_size),
            ..self
        }
    }

    /// Consumes the configuration and constructs the limiter it describes.
    #[must_use]
    pub fn build(self) -> RateLimiter {
        let Self {
            requests_per_second,
            burst_size,
        } = self;
        RateLimiter::new(requests_per_second, burst_size)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::sleep;

    /// A limiter with spare capacity hands out a token.
    #[tokio::test]
    async fn test_rate_limiter_basic() {
        let limiter = RateLimiter::new(10.0, Some(10));
        let guard = limiter.acquire().await;
        assert!(guard.is_ok());
    }

    /// With a capacity of one, the second immediate attempt is refused.
    #[tokio::test]
    async fn test_rate_limiter_try_acquire() {
        let limiter = RateLimiter::new(10.0, Some(1));
        let first = limiter.try_acquire().await;
        assert!(first.is_some());
        let second = limiter.try_acquire().await;
        assert!(second.is_none());
    }

    /// Dropping a guard does not give the token back; only elapsed time
    /// refills the bucket.
    #[tokio::test]
    async fn test_rate_limiter_refill() {
        let limiter = RateLimiter::new(100.0, Some(1));
        let guard = limiter.try_acquire().await.unwrap();
        drop(guard);
        assert!(limiter.try_acquire().await.is_none());
        sleep(Duration::from_millis(20)).await;
        assert!(limiter.try_acquire().await.is_some());
    }

    /// Exactly `burst_size` tokens can be taken back to back.
    #[tokio::test]
    async fn test_rate_limiter_burst() {
        let limiter = RateLimiter::new(10.0, Some(5));
        let mut guards = Vec::new();
        for _ in 0..5 {
            if let Some(guard) = limiter.try_acquire().await {
                guards.push(guard);
            }
        }
        assert_eq!(guards.len(), 5);
        assert!(limiter.try_acquire().await.is_none());
    }

    /// `available_tokens` reflects consumption.
    #[tokio::test]
    async fn test_rate_limiter_available_tokens() {
        let limiter = RateLimiter::new(10.0, Some(10));
        let tokens = limiter.available_tokens().await;
        // Exact comparison is safe here: the bucket starts full and refills
        // are capped at the capacity.
        assert_eq!(tokens, 10.0);

        let _guard = limiter.try_acquire().await.unwrap();
        let tokens = limiter.available_tokens().await;
        // A refill may already have credited a fraction of a token if the
        // test was descheduled for 10 ms or more, so accept anything from the
        // exact post-consumption value up to (not including) a full bucket.
        assert!(
            (9.0..10.0).contains(&tokens),
            "unexpected token count: {tokens}"
        );
    }

    /// Known tiers hand out tokens; unknown tiers are an error.
    #[tokio::test]
    async fn test_tiered_rate_limiter() {
        let tiered = TieredRateLimiter::new();
        tiered
            .add_tier("free", RateLimiter::new(1.0, Some(1)))
            .await;
        tiered
            .add_tier("premium", RateLimiter::new(10.0, Some(10)))
            .await;

        let free = tiered.acquire("free").await;
        assert!(free.is_ok());
        let premium = tiered.acquire("premium").await;
        assert!(premium.is_ok());
        let missing = tiered.acquire("nonexistent").await;
        assert!(missing.is_err());
    }

    /// The builder forwards rate and burst size to the limiter.
    #[tokio::test]
    async fn test_rate_limiter_config_builder() {
        let limiter = RateLimiterConfig::new(5.0).with_burst_size(10).build();
        assert_eq!(limiter.capacity(), 10);
        assert_eq!(limiter.refill_rate(), 5.0);
    }

    /// Ten tasks sharing one limiter with ten tokens all complete.
    #[tokio::test]
    async fn test_concurrent_access() {
        let limiter = RateLimiter::new(100.0, Some(10));
        // Collect first so every task is spawned before any is awaited.
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let limiter = limiter.clone();
                tokio::spawn(async move {
                    let _guard = limiter.acquire().await.unwrap();
                    sleep(Duration::from_millis(10)).await;
                })
            })
            .collect();
        for handle in handles {
            handle.await.unwrap();
        }
    }
}