use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Mutex, OwnedSemaphorePermit, Semaphore};
use tokio::time::{sleep_until, Instant};
use crate::RateLimits;
/// Async limiter that bounds how many fetches run at once, how many fetches start per
/// second overall, and how closely fetches against the same source may follow each other.
#[derive(Debug)]
pub struct RateLimiter {
    /// Limit configuration; consulted when building the semaphore and on every acquire.
    limits: RateLimits,
    /// Concurrency slots; each live `Permit` owns one.
    sem: Arc<Semaphore>,
    /// Instants at which recent fetches were allowed to start (used for the per-second cap).
    global_starts: Arc<Mutex<Vec<Instant>>>,
    /// For each source key, the earliest instant at which its next fetch may start.
    per_source_next: Arc<Mutex<HashMap<String, Instant>>>,
}
/// Guard returned by `RateLimiter::acquire`.
///
/// It owns one of the limiter's concurrency slots, and the slot is released when the permit
/// is dropped. Keep it alive for the whole duration of the fetch it authorises.
#[must_use = "dropping a Permit immediately releases its concurrency slot"]
#[derive(Debug)]
pub struct Permit {
    /// The semaphore slot; held only for its drop side effect.
    _slot: OwnedSemaphorePermit,
}
impl RateLimiter {
pub fn new(limits: RateLimits) -> Self {
let max = limits.max_concurrent_fetches() as usize;
Self {
limits,
sem: Arc::new(Semaphore::new(max)),
global_starts: Arc::new(Mutex::new(Vec::new())),
per_source_next: Arc::new(Mutex::new(HashMap::new())),
}
}
pub async fn acquire(&self, source: &str) -> Permit {
#[allow(clippy::expect_used)]
let slot = self
.sem
.clone()
.acquire_owned()
.await
.expect("rate-limiter semaphore is never closed");
let max_per_sec = self.limits.max_fetches_per_second() as usize;
let one_sec = Duration::from_secs(1);
loop {
let mut starts = self.global_starts.lock().await;
let now = Instant::now();
let cutoff = now.checked_sub(one_sec).unwrap_or(now);
let drop_count = starts.iter().take_while(|t| **t <= cutoff).count();
if drop_count > 0 {
starts.drain(..drop_count);
}
if starts.len() < max_per_sec {
break;
}
let wake = starts[0] + one_sec;
drop(starts);
sleep_until(wake).await;
}
let backoff = Duration::from_millis(self.limits.per_source_backoff_ms());
let mut next_map = self.per_source_next.lock().await;
let now = Instant::now();
if let Some(&next) = next_map.get(source) {
if now < next {
drop(next_map);
sleep_until(next).await;
next_map = self.per_source_next.lock().await;
}
}
let start = Instant::now();
next_map.insert(source.to_string(), start + backoff);
drop(next_map);
let mut starts = self.global_starts.lock().await;
starts.push(start);
drop(starts);
Permit { _slot: slot }
}
pub async fn sleep_for(&self, source: &str, dur: Duration) {
let mut next_map = self.per_source_next.lock().await;
let target = Instant::now() + dur;
let entry = next_map.entry(source.to_string()).or_insert(target);
if *entry < target {
*entry = target;
}
}
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    // Every test runs on a single-threaded runtime with the clock paused, so awaited
    // sleeps advance virtual time instead of blocking the test.
    use super::*;
    use crate::{RateLimits, MAX_CONCURRENT_FETCHES, MAX_FETCHES_PER_SECOND};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Shared limiter configured with the crate's hard-coded limits.
    fn limiter() -> Arc<RateLimiter> {
        Arc::new(RateLimiter::new(RateLimits::HARD_CODED))
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn concurrent_acquires_respect_max_concurrency() {
        let rl = limiter();
        let live = Arc::new(AtomicUsize::new(0));
        let max_seen = Arc::new(AtomicUsize::new(0));
        // One task per distinct source, so the per-source backoff never comes into play.
        let handles: Vec<_> = (0..10u32)
            .map(|i| {
                let rl = Arc::clone(&rl);
                let live = Arc::clone(&live);
                let max_seen = Arc::clone(&max_seen);
                let src = format!("src-{}", i);
                tokio::spawn(async move {
                    let permit = rl.acquire(&src).await;
                    let in_flight = live.fetch_add(1, Ordering::SeqCst) + 1;
                    max_seen.fetch_max(in_flight, Ordering::SeqCst);
                    tokio::time::sleep(Duration::from_millis(50)).await;
                    live.fetch_sub(1, Ordering::SeqCst);
                    drop(permit);
                })
            })
            .collect();
        for handle in handles {
            handle.await.expect("task ok");
        }
        let peak = max_seen.load(Ordering::SeqCst);
        assert!(
            peak <= MAX_CONCURRENT_FETCHES as usize,
            "max concurrent live = {}, expected <= {}",
            peak,
            MAX_CONCURRENT_FETCHES
        );
        assert!(peak > 0, "at least one acquire should succeed");
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn same_source_starts_separated_by_backoff() {
        let rl = limiter();
        let backoff_ms = RateLimits::HARD_CODED.per_source_backoff_ms();
        let started_at = Instant::now();
        drop(rl.acquire("crossref").await);
        let _second = rl.acquire("crossref").await;
        let waited = started_at.elapsed();
        assert!(
            waited >= Duration::from_millis(backoff_ms),
            "elapsed {:?} < backoff {} ms",
            waited,
            backoff_ms
        );
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn different_sources_no_per_source_wait() {
        let rl = limiter();
        let backoff = Duration::from_millis(RateLimits::HARD_CODED.per_source_backoff_ms());
        let started_at = Instant::now();
        let _first = rl.acquire("source-a").await;
        let _second = rl.acquire("source-b").await;
        let waited = started_at.elapsed();
        assert!(
            waited < backoff,
            "elapsed {:?} should be well under per-source backoff {:?}",
            waited,
            backoff
        );
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn global_rate_caps_starts_per_second() {
        let rl = limiter();
        let max_per_sec = MAX_FETCHES_PER_SECOND as usize;
        let started_at = Instant::now();
        let mut offsets: Vec<Duration> = Vec::with_capacity(10);
        for i in 0..10u32 {
            let permit = rl.acquire(&format!("src-{}", i)).await;
            offsets.push(started_at.elapsed());
            drop(permit);
        }
        let one_sec = Duration::from_secs(1);
        let in_first_sec = offsets.iter().filter(|&&offset| offset < one_sec).count();
        assert!(
            in_first_sec <= max_per_sec,
            "{} starts completed in first second, expected <= {}",
            in_first_sec,
            max_per_sec
        );
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn sleep_for_delays_target_source() {
        let rl = limiter();
        let delay = Duration::from_millis(500);
        rl.sleep_for("X", delay).await;

        // An unrelated source must not inherit X's delay.
        let y_start = Instant::now();
        let _y_permit = rl.acquire("Y").await;
        let y_waited = y_start.elapsed();
        assert!(
            y_waited < delay,
            "Y elapsed {:?} should be far less than {:?}",
            y_waited,
            delay
        );

        // The delayed source itself must wait out the full delay.
        let x_start = Instant::now();
        let _x_permit = rl.acquire("X").await;
        let x_waited = x_start.elapsed();
        assert!(
            x_waited >= delay,
            "X elapsed {:?} < requested delay {:?}",
            x_waited,
            delay
        );
    }
}