use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use tokio::time::sleep;
/// Token-bucket rate limiter that can be shared between async tasks.
///
/// Cloning is cheap: every clone holds the same `Arc`, so all clones draw
/// tokens from one shared bucket.
#[derive(Clone, Debug)]
pub struct RateLimiter {
/// Bucket state behind a tokio (async) mutex, so the lock is awaited rather
/// than blocking the executor thread.
inner: Arc<Mutex<RateLimiterInner>>,
}
/// Mutable bucket state guarded by the mutex inside `RateLimiter`.
#[derive(Debug)]
struct RateLimiterInner {
/// Maximum number of tokens the bucket can hold (refills are clamped to this).
capacity: u32,
/// Tokens currently available to hand out; each successful acquire takes one.
tokens: u32,
/// Tokens credited per second of elapsed time.
refill_rate: u32,
/// Reference instant from which elapsed time is measured for the next refill.
last_refill: Instant,
}
impl RateLimiter {
pub fn new(capacity: u32, refill_rate: u32) -> Self {
Self {
inner: Arc::new(Mutex::new(RateLimiterInner {
capacity,
tokens: capacity,
refill_rate,
last_refill: Instant::now(),
})),
}
}
pub fn finnhub_default() -> Self {
Self::new(30, 30)
}
pub fn finnhub_15s_window() -> Self {
Self::new(450, 30)
}
pub async fn acquire(&self) -> Result<(), crate::Error> {
loop {
let mut limiter = self.inner.lock().await;
let now = Instant::now();
let elapsed = now.duration_since(limiter.last_refill);
let tokens_to_add = (elapsed.as_secs_f64() * f64::from(limiter.refill_rate)) as u32;
if tokens_to_add > 0 {
limiter.tokens = (limiter.tokens + tokens_to_add).min(limiter.capacity);
limiter.last_refill = now;
}
if limiter.tokens > 0 {
limiter.tokens -= 1;
return Ok(());
}
let tokens_needed = 1;
let wait_time =
Duration::from_secs_f64(f64::from(tokens_needed) / f64::from(limiter.refill_rate));
drop(limiter); sleep(wait_time).await;
}
}
pub async fn try_acquire(&self) -> Result<(), crate::Error> {
let mut limiter = self.inner.lock().await;
let now = Instant::now();
let elapsed = now.duration_since(limiter.last_refill);
let tokens_to_add = (elapsed.as_secs_f64() * f64::from(limiter.refill_rate)) as u32;
if tokens_to_add > 0 {
limiter.tokens = (limiter.tokens + tokens_to_add).min(limiter.capacity);
limiter.last_refill = now;
}
if limiter.tokens > 0 {
limiter.tokens -= 1;
Ok(())
} else {
let retry_after = (1.0 / f64::from(limiter.refill_rate)).ceil() as u64;
Err(crate::Error::RateLimitExceeded { retry_after })
}
}
pub async fn available_tokens(&self) -> u32 {
let mut limiter = self.inner.lock().await;
let now = Instant::now();
let elapsed = now.duration_since(limiter.last_refill);
let tokens_to_add = (elapsed.as_secs_f64() * f64::from(limiter.refill_rate)) as u32;
if tokens_to_add > 0 {
limiter.tokens = (limiter.tokens + tokens_to_add).min(limiter.capacity);
limiter.last_refill = now;
}
limiter.tokens
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A limiter with capacity 2 and a refill rate of 2 tokens/second grants
    /// exactly two immediate permits, rejects a third, and grants again once
    /// enough time has passed for one token to accrue.
    #[tokio::test]
    async fn test_rate_limiter_basic() {
        let bucket = RateLimiter::new(2, 2);

        for _ in 0..2 {
            assert!(bucket.try_acquire().await.is_ok());
        }
        assert!(bucket.try_acquire().await.is_err());

        // 600 ms at 2 tokens/s is enough for one whole token to accrue.
        let pause = Duration::from_millis(600);
        sleep(pause).await;

        assert!(bucket.try_acquire().await.is_ok());
    }
}