use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio::time::sleep;
/// Rate limiter that enforces a minimum interval between successive completed
/// calls to `wait`. Waiting requires `&mut self`, so callers that share one
/// limiter must provide their own synchronization.
#[derive(Debug)]
pub struct RateLimiter {
    /// Time at which the most recent `wait` call finished (initially back-dated
    /// by one interval).
    last_request: Instant,
    /// Minimum spacing between calls; `Duration::ZERO` disables limiting.
    min_interval: Duration,
}
impl RateLimiter {
    /// Creates a limiter that allows roughly `requests_per_second` completed
    /// [`wait`](Self::wait) calls per second.
    ///
    /// A rate of `0` sets the minimum interval to zero, so `wait` never sleeps.
    /// The first `wait` after construction normally returns without sleeping.
    pub fn new(requests_per_second: u32) -> Self {
        let min_interval = if requests_per_second > 0 {
            Duration::from_secs_f64(1.0 / f64::from(requests_per_second))
        } else {
            Duration::ZERO
        };
        let now = Instant::now();
        // Back-date the previous request by one interval so the first `wait`
        // returns immediately. `Instant - Duration` panics when the result is
        // not representable (possible on some platforms shortly after boot).
        // Use the checked form and fall back to `now`. At worst, the very first
        // `wait` then sleeps for one interval.
        let last_request = now.checked_sub(min_interval).unwrap_or(now);
        Self {
            last_request,
            min_interval,
        }
    }

    /// Sleeps until at least the minimum interval has passed since the
    /// previous `wait` finished. Then it records the current time as the new
    /// reference point.
    ///
    /// If this future is dropped while sleeping, no new time is recorded.
    pub async fn wait(&mut self) {
        let elapsed = self.last_request.elapsed();
        if elapsed < self.min_interval {
            sleep(self.min_interval - elapsed).await;
        }
        self.last_request = Instant::now();
    }
}
/// Rate limiter usable through `&self`. It bounds how many permits are held at
/// once and how closely request start times are spaced.
#[derive(Debug)]
pub struct ConcurrentRateLimiter {
    /// Bounds the number of simultaneously held permits.
    semaphore: Arc<Semaphore>,
    /// Earliest start time that the next reservation may take, in nanoseconds
    /// since `epoch`.
    next_allowed_nanos: AtomicU64,
    /// Reference instant that `next_allowed_nanos` is measured from.
    epoch: Instant,
    /// Minimum spacing between reserved start times, in nanoseconds.
    min_interval_nanos: u64,
}
impl ConcurrentRateLimiter {
pub fn new(requests_per_second: u32, max_concurrent: usize) -> Self {
let min_interval = if requests_per_second > 0 {
Duration::from_secs_f64(1.0 / requests_per_second as f64)
} else {
Duration::ZERO
};
let epoch = Instant::now();
Self {
semaphore: Arc::new(Semaphore::new(max_concurrent)),
next_allowed_nanos: AtomicU64::new(0),
epoch,
min_interval_nanos: min_interval.as_nanos() as u64,
}
}
pub async fn acquire(&self) -> OwnedSemaphorePermit {
let permit = self
.semaphore
.clone()
.acquire_owned()
.await
.expect("ConcurrentRateLimiter semaphore unexpectedly closed");
let wait_nanos = loop {
let now_nanos = self.epoch.elapsed().as_nanos() as u64;
let current = self.next_allowed_nanos.load(Ordering::Acquire);
let scheduled = if now_nanos >= current {
now_nanos
} else {
current
};
let next = scheduled + self.min_interval_nanos;
match self.next_allowed_nanos.compare_exchange_weak(
current,
next,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => break scheduled.saturating_sub(now_nanos),
Err(_) => continue, }
};
if wait_nanos > 0 {
sleep(Duration::from_nanos(wait_nanos)).await;
}
permit
}
}
use crate::exchange::manifest::{RateLimitCategory, RateLimitConfig};
/// Per-category rate limiting: one independently locked [`RateLimiter`] for
/// each `RateLimitCategory`.
pub struct CategoryRateLimiter {
    /// One limiter per category, looked up by `category as usize`. This
    /// presumably relies on the discriminants being `0..COUNT`; confirm against
    /// the `RateLimitCategory` definition.
    limiters: [tokio::sync::Mutex<RateLimiter>; RateLimitCategory::COUNT],
}
impl CategoryRateLimiter {
    /// Builds one [`RateLimiter`] per entry of `RateLimitCategory::ALL`. Each
    /// one is configured with the rate that `config.rps` returns for its
    /// category.
    pub fn from_config(config: &RateLimitConfig) -> Self {
        Self {
            limiters: RateLimitCategory::ALL
                .map(|category| tokio::sync::Mutex::new(RateLimiter::new(config.rps(category)))),
        }
    }

    /// Waits until a request in `category` may proceed.
    ///
    /// The category's limiter stays locked for the whole wait. Concurrent
    /// callers in the same category therefore take turns, while other
    /// categories use their own locks and are unaffected.
    pub async fn wait(&self, category: RateLimitCategory) {
        let entry = &self.limiters[category as usize];
        let mut guard = entry.lock().await;
        guard.wait().await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_rate_limiter_respects_interval() {
        let mut limiter = RateLimiter::new(10);
        let start = Instant::now();
        limiter.wait().await;
        limiter.wait().await;
        let elapsed = start.elapsed();
        assert!(elapsed >= Duration::from_millis(90));
    }

    #[tokio::test]
    async fn test_rate_limiter_zero_rps_is_unlimited() {
        let mut limiter = RateLimiter::new(0);
        let start = Instant::now();
        for _ in 0..5 {
            limiter.wait().await;
        }
        // Generous bound: a zero rate never sleeps, so five calls finish far
        // below one second even on a heavily loaded machine.
        assert!(start.elapsed() < Duration::from_secs(1));
    }

    #[tokio::test]
    async fn test_concurrent_rate_limiter_spaces_acquisitions() {
        let limiter = ConcurrentRateLimiter::new(10, 2);
        let start = Instant::now();
        for _ in 0..3 {
            let _permit = limiter.acquire().await;
        }
        // Start slots are reserved 100 ms apart, so the third acquisition
        // cannot complete earlier than ~200 ms after the first.
        assert!(start.elapsed() >= Duration::from_millis(180));
    }
}