use crate::time::{sleep, Duration, Instant};
use std::sync::Arc;
use std::sync::Mutex;
use tracing::{debug, instrument};
/// Token-bucket rate limiter that is cheap to clone.
///
/// Cloning only bumps the `Arc` reference count, so every clone derived from
/// the same [`RateLimiter::new`] call draws from one shared bucket and the
/// limit applies across all of those handles.
#[derive(Clone)]
pub struct RateLimiter {
    /// Shared bucket state behind a blocking `std::sync::Mutex`; the lock is
    /// only taken for short, synchronous critical sections.
    bucket: Arc<Mutex<TokenBucket>>,
}
/// Mutable state of the token bucket that backs a [`RateLimiter`].
struct TokenBucket {
    /// Tokens currently available; may be fractional between refills.
    tokens: f64,
    /// Upper bound that `tokens` is clamped to on every refill.
    capacity: f64,
    /// Tokens credited per second of elapsed time.
    refill_rate: f64,
    /// Moment at which `tokens` was last brought up to date.
    last_refill: Instant,
}
impl RateLimiter {
pub fn new(rate: f64) -> Self {
let capacity = rate.max(1.0); let now = Instant::now();
Self {
bucket: Arc::new(Mutex::new(TokenBucket {
tokens: capacity,
capacity,
refill_rate: rate,
last_refill: now,
})),
}
}
pub fn ncbi_default() -> Self {
Self::new(3.0)
}
pub fn ncbi_with_key() -> Self {
Self::new(10.0)
}
#[instrument(skip(self))]
pub async fn acquire(&self) -> crate::Result<()> {
let should_wait = {
let mut bucket = self.bucket.lock().unwrap();
self.refill_bucket(&mut bucket);
if bucket.tokens >= 1.0 {
bucket.tokens -= 1.0;
debug!(remaining_tokens = %bucket.tokens, "Token acquired immediately");
false
} else {
debug!("No tokens available, need to wait");
true
}
};
if should_wait {
let wait_duration = Duration::from_secs(1).as_secs_f64() / self.rate();
let wait_duration = Duration::from_millis((wait_duration * 1000.0) as u64);
debug!(
wait_duration_ms = wait_duration.as_millis(),
"Waiting for rate limit"
);
sleep(wait_duration).await;
let mut bucket = self.bucket.lock().unwrap();
self.refill_bucket(&mut bucket);
bucket.tokens = bucket.tokens.min(bucket.capacity);
if bucket.tokens >= 1.0 {
bucket.tokens -= 1.0;
debug!(remaining_tokens = %bucket.tokens, "Token acquired after waiting");
}
}
Ok(())
}
pub fn check_available(&self) -> bool {
let mut bucket = self.bucket.lock().unwrap();
self.refill_bucket(&mut bucket);
bucket.tokens >= 1.0
}
pub fn token_count(&self) -> f64 {
let mut bucket = self.bucket.lock().unwrap();
self.refill_bucket(&mut bucket);
bucket.tokens
}
pub fn rate(&self) -> f64 {
let bucket = self.bucket.lock().unwrap();
bucket.refill_rate
}
fn refill_bucket(&self, bucket: &mut TokenBucket) {
let now = Instant::now();
let elapsed = now.duration_since(bucket.last_refill);
let tokens_to_add = elapsed.as_secs_f64() * bucket.refill_rate;
bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.capacity);
bucket.last_refill = now;
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_basic_functionality() {
        let limiter = RateLimiter::new(5.0);
        limiter.acquire().await.unwrap();
        let rate = limiter.rate();
        assert!((rate - 5.0).abs() < 0.1);
    }

    #[tokio::test]
    async fn test_check_available() {
        let limiter = RateLimiter::new(2.0);
        assert!(limiter.check_available());
        // Draining the full capacity (2 tokens) must leave less than one token.
        limiter.acquire().await.unwrap();
        limiter.acquire().await.unwrap();
        assert!(!limiter.check_available());
    }

    #[tokio::test]
    async fn test_ncbi_presets() {
        let default_limiter = RateLimiter::ncbi_default();
        let with_key_limiter = RateLimiter::ncbi_with_key();
        assert!((default_limiter.rate() - 3.0).abs() < 0.1);
        assert!((with_key_limiter.rate() - 10.0).abs() < 0.1);
    }

    #[tokio::test]
    async fn test_capacity_is_at_least_one_token() {
        // Sub-1.0 rates still start with one whole token available.
        let slow = RateLimiter::new(0.5);
        assert!((slow.token_count() - 1.0).abs() < 1e-9);
        // Otherwise the bucket starts full at `rate` tokens.
        let fast = RateLimiter::new(5.0);
        assert!((fast.token_count() - 5.0).abs() < 1e-9);
    }

    #[tokio::test]
    async fn test_clones_share_bucket() {
        let limiter = RateLimiter::new(2.0);
        let handle = limiter.clone();
        handle.acquire().await.unwrap();
        // The token taken through the clone is missing from the original too.
        assert!(limiter.token_count() < 2.0);
    }

    #[tokio::test]
    async fn test_rate_limiting_basic() {
        let limiter = RateLimiter::new(1.0);
        limiter.acquire().await.unwrap();
        // The bucket (capacity 1) is now empty, so this call has to wait for
        // a refill and must then consume the refilled token.
        limiter.acquire().await.unwrap();
        let tokens = limiter.token_count();
        assert!(tokens >= 0.0);
        assert!(
            tokens < 1.0,
            "second acquire should have consumed a token, {} left",
            tokens
        );
    }
}