use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::Mutex;
/// Token-bucket rate limiter.
///
/// The bucket holds a fractional token balance that is topped up according to
/// the wall-clock time elapsed since the previous refill and is capped at
/// `max_tokens`.
pub struct RateLimiter {
/// Bucket capacity; a refill never raises the balance above this value.
max_tokens: u32,
/// Number of tokens added per second of elapsed time.
refill_rate: u32,
/// Current token balance, guarded by an async mutex.
tokens: Arc<Mutex<f64>>,
/// Time of the most recent refill, in milliseconds since the Unix epoch.
last_update: Arc<AtomicU64>,
}
impl RateLimiter {
    /// Delay between attempts in [`RateLimiter::wait_for_token`].
    const POLL_INTERVAL: Duration = Duration::from_millis(100);

    /// Creates a limiter whose bucket starts full.
    ///
    /// `max_tokens` is the bucket capacity and `refill_rate` is the number of
    /// tokens restored per second.
    pub fn new(max_tokens: u32, refill_rate: u32) -> Self {
        let now = Self::now_millis();
        tracing::debug!(
            "Initializing RateLimiter (max: {}, rate: {}/s)",
            max_tokens,
            refill_rate
        );
        Self {
            max_tokens,
            refill_rate,
            tokens: Arc::new(Mutex::new(f64::from(max_tokens))),
            last_update: Arc::new(AtomicU64::new(now)),
        }
    }

    /// Creates a limiter that refills `rate` tokens per second and can hold
    /// up to ten seconds' worth of tokens.
    ///
    /// The capacity saturates at `u32::MAX` instead of overflowing when
    /// `rate` is very large.
    pub fn per_second(rate: u32) -> Self {
        Self::new(rate.saturating_mul(10), rate)
    }

    /// Tries to take `tokens` tokens from the bucket without waiting.
    ///
    /// The bucket is refilled for the elapsed time first. Returns `true` and
    /// deducts the tokens when the balance is sufficient; otherwise returns
    /// `false` and leaves the balance untouched.
    pub async fn acquire(&self, tokens: u32) -> bool {
        let requested = f64::from(tokens);
        let mut current = self.tokens.lock().await;
        self.refill(&mut current);
        if *current >= requested {
            *current -= requested;
            tracing::debug!("Token acquired (remaining: {:.2})", *current);
            true
        } else {
            tracing::warn!(
                "Rate limit exceeded: not enough tokens (available: {:.2})",
                *current
            );
            false
        }
    }

    /// Waits until `tokens` tokens have been acquired, retrying every 100 ms.
    ///
    /// This never completes when the request cannot be satisfied: the balance
    /// is capped at `max_tokens`, so asking for more than the capacity (or
    /// for more than the current balance when `refill_rate` is zero) polls
    /// forever. Callers that need a bound should wrap this in a timeout.
    pub async fn wait_for_token(&self, tokens: u32) {
        tracing::debug!("Waiting for {} tokens...", tokens);
        while !self.acquire(tokens).await {
            tokio::time::sleep(Self::POLL_INTERVAL).await;
        }
    }

    /// Adds the tokens earned since the last refill to `tokens`, capped at
    /// the bucket capacity, and records the current time as the new refill
    /// point.
    ///
    /// Callers pass the value behind the `tokens` mutex, so the relaxed
    /// atomic accesses here are serialized by that lock. A wall clock that
    /// moved backwards counts as zero elapsed time.
    fn refill(&self, tokens: &mut f64) {
        let now = Self::now_millis();
        let last = self.last_update.load(Ordering::Relaxed);
        let elapsed_secs = now.saturating_sub(last) as f64 / 1000.0;
        let refilled = *tokens + f64::from(self.refill_rate) * elapsed_secs;
        *tokens = refilled.min(f64::from(self.max_tokens));
        self.last_update.store(now, Ordering::Relaxed);
    }

    /// Returns the number of tokens available right now.
    ///
    /// The bucket is refilled for the elapsed time before the balance is
    /// read, so the result is not stale when no `acquire` happened recently.
    pub async fn current_tokens(&self) -> f64 {
        let mut current = self.tokens.lock().await;
        self.refill(&mut current);
        *current
    }

    /// Current wall-clock time in milliseconds since the Unix epoch.
    ///
    /// Yields 0 when the system clock is set before the epoch and clamps to
    /// `u64::MAX` rather than silently truncating the 128-bit millisecond
    /// count.
    fn now_millis() -> u64 {
        let millis = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis();
        millis.min(u128::from(u64::MAX)) as u64
    }
}
/// Limits how many requests may be in flight at the same time.
///
/// Clones share the same counter, so a limit is enforced across all clones.
pub struct RequestLimiter {
/// Upper bound on the number of simultaneously active requests.
max_concurrent: u32,
/// Number of requests currently active, shared between clones.
current: Arc<AtomicU64>,
}
impl RequestLimiter {
    /// Creates a limiter that allows at most `max_concurrent` active requests.
    pub fn new(max_concurrent: u32) -> Self {
        tracing::debug!(
            "Initializing RequestLimiter (max concurrent: {})",
            max_concurrent
        );
        Self {
            max_concurrent,
            current: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Returns `true` when fewer than `max_concurrent` requests are active.
    ///
    /// This is only a snapshot: another thread may take the last slot right
    /// after it returns, so use [`RequestLimiter::start_request`] to actually
    /// reserve a slot.
    pub fn can_request(&self) -> bool {
        self.current.load(Ordering::Acquire) < u64::from(self.max_concurrent)
    }

    /// Reserves a slot for a new request.
    ///
    /// On success the returned [`RequestGuard`] holds the slot and releases
    /// it when dropped.
    ///
    /// # Errors
    ///
    /// Returns an error message when `max_concurrent` requests are already
    /// active.
    pub fn start_request(&self) -> Result<RequestGuard, &'static str> {
        let limit = u64::from(self.max_concurrent);
        // Check and increment in one atomic step; a separate load followed by
        // fetch_add would let concurrent callers all pass the check and push
        // the counter above the limit.
        let reserved = self
            .current
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |active| {
                if active < limit {
                    Some(active + 1)
                } else {
                    None
                }
            });
        match reserved {
            Ok(prev) => {
                tracing::debug!("Request started (active: {})", prev + 1);
                Ok(RequestGuard {
                    limiter: Arc::new(self.clone()),
                })
            }
            Err(_) => {
                tracing::warn!("Concurrency limit reached: {}", self.max_concurrent);
                Err("Maximum concurrent requests exceeded")
            }
        }
    }

    /// Returns the number of requests that are currently active.
    pub fn current_requests(&self) -> u64 {
        self.current.load(Ordering::Acquire)
    }

    /// Releases one slot; called when a [`RequestGuard`] is dropped.
    fn finish_request(&self) {
        let prev = self.current.fetch_sub(1, Ordering::AcqRel);
        // saturating_sub keeps the log line from panicking on underflow,
        // which would otherwise happen inside a Drop implementation.
        tracing::debug!("Request finished (active: {})", prev.saturating_sub(1));
    }
}
impl Clone for RequestLimiter {
    /// Produces a handle with the same limit that shares this limiter's
    /// active-request counter.
    fn clone(&self) -> Self {
        let Self {
            max_concurrent,
            current,
        } = self;
        Self {
            max_concurrent: *max_concurrent,
            current: Arc::clone(current),
        }
    }
}
/// Handle for one active request, created by `RequestLimiter::start_request`.
///
/// Dropping the guard calls `finish_request` on the limiter it refers to.
pub struct RequestGuard {
/// Limiter handle whose active-request counter this guard decrements on drop.
limiter: Arc<RequestLimiter>,
}
impl Drop for RequestGuard {
fn drop(&mut self) {
self.limiter.finish_request();
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A full 10-token bucket serves two 5-token requests and then rejects a
    /// further 1-token request.
    #[tokio::test]
    async fn test_rate_limiter() {
        let bucket = RateLimiter::new(10, 10);

        let first_half = bucket.acquire(5).await;
        assert!(first_half);

        let second_half = bucket.acquire(5).await;
        assert!(second_half);

        let one_more = bucket.acquire(1).await;
        assert!(!one_more);
    }

    /// With a limit of two, the third slot is refused until a guard is
    /// dropped.
    #[test]
    fn test_request_limiter() {
        let gate = RequestLimiter::new(2);
        assert!(gate.can_request());

        let first = gate.start_request().unwrap();
        assert!(gate.can_request());

        let _second = gate.start_request().unwrap();
        assert!(!gate.can_request());

        // Releasing one guard frees a slot again.
        drop(first);
        assert!(gate.can_request());
    }
}