Skip to main content

doiget_core/
rate_limiter.rs

1//! Process-wide rate limiter for HTTP fetches across all `Source` impls.
2//!
3//! See `docs/SECURITY.md` (per-session fetch flood mitigation) and
4//! `docs/SOURCES.md` §6 (Politeness defaults). The constants enforced here
5//! are the load-bearing safeguards from `docs/LEGAL.md` §6 safeguard 8.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10
11use tokio::sync::{Mutex, OwnedSemaphorePermit, Semaphore};
12use tokio::time::{sleep_until, Instant};
13
14use crate::RateLimits;
15
/// Process-wide async rate limiter.
///
/// Enforces three invariants on every [`acquire`](RateLimiter::acquire):
///   1. **Global concurrency** — at most
///      [`RateLimits::max_concurrent_fetches`](crate::RateLimits::max_concurrent_fetches)
///      in flight at once.
///   2. **Global rate** — at most
///      [`RateLimits::max_fetches_per_second`](crate::RateLimits::max_fetches_per_second)
///      starts in any rolling one-second window.
///   3. **Per-source backoff** — at least
///      [`RateLimits::per_source_backoff_ms`](crate::RateLimits::per_source_backoff_ms)
///      between consecutive starts to the same source name.
///
/// The returned [`Permit`] holds the concurrency slot for the lifetime of the
/// value; drop it when the fetch is done.
///
/// 429 / `Retry-After` handling is split: the limiter only exposes the
/// [`sleep_for`](RateLimiter::sleep_for) hook, which pushes a source's next
/// allowed start further out; parsing the `Retry-After` header and calling
/// the hook happens at the `Source::fetch` call site, per
/// `docs/SOURCES.md` §6.
#[derive(Debug)]
pub struct RateLimiter {
    /// Limits consulted by `new` (semaphore size) and by every `acquire`
    /// (rate cap and per-source backoff).
    limits: RateLimits,
    /// Concurrency gate; `new` sizes it from `max_concurrent_fetches()` and
    /// each [`Permit`] owns one of its permits.
    sem: Arc<Semaphore>,
    /// Global rolling-second window: timestamps of starts within the last
    /// second. `acquire` appends to the back and prunes aged-out entries
    /// from the front, so the vector stays in start order.
    global_starts: Arc<Mutex<Vec<Instant>>>,
    /// Earliest-allowed start time per source name. Written by `acquire`
    /// (start + backoff) and pushed later by `sleep_for`.
    per_source_next: Arc<Mutex<HashMap<String, Instant>>>,
}
45
46/// Held while a fetch is in flight; releases the concurrency slot on drop.
47#[derive(Debug)]
48pub struct Permit {
49    _slot: OwnedSemaphorePermit,
50}
51
52impl RateLimiter {
53    /// Construct from the hard-coded [`RateLimits`] (the only public path).
54    pub fn new(limits: RateLimits) -> Self {
55        let max = limits.max_concurrent_fetches() as usize;
56        Self {
57            limits,
58            sem: Arc::new(Semaphore::new(max)),
59            global_starts: Arc::new(Mutex::new(Vec::new())),
60            per_source_next: Arc::new(Mutex::new(HashMap::new())),
61        }
62    }
63
64    /// Block until a slot is available, then return a [`Permit`].
65    ///
66    /// Order of waits, in this exact sequence:
67    ///   1. global concurrency (semaphore acquire),
68    ///   2. global rate cap (sleep if the rolling-second window is full),
69    ///   3. per-source backoff (sleep until the source's `next` time).
70    ///
71    /// Lock-acquisition order is always `global_starts` first, THEN
72    /// `per_source_next`. Any future call site that needs both locks MUST
73    /// follow the same order to keep the system deadlock-free.
74    pub async fn acquire(&self, source: &str) -> Permit {
75        // Step 1: global concurrency — bounded by Semaphore::new(max).
76        // `acquire_owned` only errors when the semaphore is closed; this
77        // type never closes it (no `close()` call exists), so the Err arm
78        // is structurally unreachable. The local `allow` is the documented
79        // exception to the workspace `expect_used` lint.
80        #[allow(clippy::expect_used)]
81        let slot = self
82            .sem
83            .clone()
84            .acquire_owned()
85            .await
86            .expect("rate-limiter semaphore is never closed");
87
88        // Step 2: global rate cap. Loop until the rolling-second window has
89        // room, sleeping until the oldest entry ages out if it does not.
90        let max_per_sec = self.limits.max_fetches_per_second() as usize;
91        let one_sec = Duration::from_secs(1);
92        loop {
93            let mut starts = self.global_starts.lock().await;
94            let now = Instant::now();
95            // Prune entries older than 1 s. starts is FIFO so we can drop
96            // a contiguous prefix.
97            let cutoff = now.checked_sub(one_sec).unwrap_or(now);
98            let drop_count = starts.iter().take_while(|t| **t <= cutoff).count();
99            if drop_count > 0 {
100                starts.drain(..drop_count);
101            }
102            if starts.len() < max_per_sec {
103                break;
104            }
105            // Window is full — wake at the moment the oldest entry ages out.
106            // starts.len() >= max_per_sec >= 1 here, so [0] is safe.
107            let wake = starts[0] + one_sec;
108            drop(starts);
109            sleep_until(wake).await;
110        }
111
112        // Step 3: per-source backoff. Acquire `per_source_next` strictly
113        // after dropping `global_starts` above (lock order documented).
114        let backoff = Duration::from_millis(self.limits.per_source_backoff_ms());
115        let mut next_map = self.per_source_next.lock().await;
116        let now = Instant::now();
117        if let Some(&next) = next_map.get(source) {
118            if now < next {
119                drop(next_map);
120                sleep_until(next).await;
121                next_map = self.per_source_next.lock().await;
122            }
123        }
124        // Record this start in both ledgers. We re-read `Instant::now()`
125        // because we may have slept in step 2 or step 3.
126        let start = Instant::now();
127        next_map.insert(source.to_string(), start + backoff);
128        drop(next_map);
129
130        // Push the start timestamp into the global window. Done AFTER
131        // releasing per_source_next to keep the documented lock order
132        // (global → per-source) on every code path.
133        let mut starts = self.global_starts.lock().await;
134        starts.push(start);
135        drop(starts);
136
137        Permit { _slot: slot }
138    }
139
140    /// Tell the limiter to delay further starts to `source` by at least
141    /// `dur`. Used when the source returns 429 with `Retry-After`.
142    pub async fn sleep_for(&self, source: &str, dur: Duration) {
143        let mut next_map = self.per_source_next.lock().await;
144        let target = Instant::now() + dur;
145        let entry = next_map.entry(source.to_string()).or_insert(target);
146        if *entry < target {
147            *entry = target;
148        }
149    }
150}
151
152// ---------------------------------------------------------------------------
153// Tests
154// ---------------------------------------------------------------------------
155
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use crate::{RateLimits, MAX_CONCURRENT_FETCHES, MAX_FETCHES_PER_SECOND};

    /// Build a shareable limiter from the production constants
    /// (`RateLimits::HARD_CODED`).
    fn limiter() -> Arc<RateLimiter> {
        let inner = RateLimiter::new(RateLimits::HARD_CODED);
        Arc::new(inner)
    }

    /// Ten tasks race for permits; the number holding one at the same time
    /// must never exceed `MAX_CONCURRENT_FETCHES`.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn concurrent_acquires_respect_max_concurrency() {
        let rl = limiter();
        let in_flight = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));

        let tasks: Vec<_> = (0..10u32)
            .map(|i| {
                let rl = Arc::clone(&rl);
                let in_flight = Arc::clone(&in_flight);
                let peak = Arc::clone(&peak);
                let src = format!("src-{}", i);
                tokio::spawn(async move {
                    let permit = rl.acquire(&src).await;
                    let holders = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                    peak.fetch_max(holders, Ordering::SeqCst);
                    // Keep the slot occupied for a while so the other tasks
                    // actually have to queue behind it.
                    tokio::time::sleep(Duration::from_millis(50)).await;
                    in_flight.fetch_sub(1, Ordering::SeqCst);
                    drop(permit);
                })
            })
            .collect();

        for task in tasks {
            task.await.expect("task ok");
        }

        let max = peak.load(Ordering::SeqCst);
        assert!(
            max <= MAX_CONCURRENT_FETCHES as usize,
            "max concurrent live = {}, expected <= {}",
            max,
            MAX_CONCURRENT_FETCHES
        );
        assert!(max > 0, "at least one acquire should succeed");
    }

    /// Back-to-back acquires for one source are spaced by at least the
    /// per-source backoff, even though the first permit is released at once.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn same_source_starts_separated_by_backoff() {
        let rl = limiter();
        let backoff_ms = RateLimits::HARD_CODED.per_source_backoff_ms();

        let began = Instant::now();
        drop(rl.acquire("crossref").await);
        let _second = rl.acquire("crossref").await;
        let elapsed = began.elapsed();

        assert!(
            elapsed >= Duration::from_millis(backoff_ms),
            "elapsed {:?} < backoff {} ms",
            elapsed,
            backoff_ms
        );
    }

    /// Acquires for two distinct sources do not wait on each other's
    /// backoff. (The global rate cap still applies, but two starts are too
    /// few for it to bind.)
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn different_sources_no_per_source_wait() {
        let rl = limiter();
        let backoff_ms = RateLimits::HARD_CODED.per_source_backoff_ms();
        let backoff = Duration::from_millis(backoff_ms);

        let began = Instant::now();
        let _first = rl.acquire("source-a").await;
        let _second = rl.acquire("source-b").await;
        let elapsed = began.elapsed();

        assert!(
            elapsed < backoff,
            "elapsed {:?} should be well under per-source backoff {:?}",
            elapsed,
            backoff
        );
    }

    /// Ten sequential acquires on distinct sources, each permit released
    /// straight away so the concurrency cap never interferes: no more than
    /// `MAX_FETCHES_PER_SECOND` of them may finish inside the first second;
    /// the rest have to wait for the rolling window to free up.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn global_rate_caps_starts_per_second() {
        let rl = limiter();
        let max_per_sec = MAX_FETCHES_PER_SECOND as usize;
        let first_second = Duration::from_secs(1);

        let began = Instant::now();
        let mut completion_offsets: Vec<Duration> = Vec::with_capacity(10);
        for i in 0..10u32 {
            let src = format!("src-{}", i);
            let permit = rl.acquire(&src).await;
            completion_offsets.push(began.elapsed());
            // Rate is under test here, not concurrency — free the slot now.
            drop(permit);
        }

        let in_first_sec = completion_offsets
            .iter()
            .filter(|&&offset| offset < first_second)
            .count();
        assert!(
            in_first_sec <= max_per_sec,
            "{} starts completed in first second, expected <= {}",
            in_first_sec,
            max_per_sec
        );
    }

    /// After `sleep_for("X", 500ms)`, an acquire for "X" takes at least
    /// 500 ms while an acquire for "Y" in the same window is not held up.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn sleep_for_delays_target_source() {
        let rl = limiter();
        let delay = Duration::from_millis(500);
        rl.sleep_for("X", delay).await;

        // The untouched source goes straight through.
        let y_began = Instant::now();
        let _p_y = rl.acquire("Y").await;
        let elapsed_y = y_began.elapsed();
        assert!(
            elapsed_y < delay,
            "Y elapsed {:?} should be far less than {:?}",
            elapsed_y,
            delay
        );

        // The penalised source waits out the full delay.
        let x_began = Instant::now();
        let _p_x = rl.acquire("X").await;
        let elapsed_x = x_began.elapsed();
        assert!(
            elapsed_x >= delay,
            "X elapsed {:?} < requested delay {:?}",
            elapsed_x,
            delay
        );
    }
}