doiget_core/rate_limiter.rs
1//! Process-wide rate limiter for HTTP fetches across all `Source` impls.
2//!
3//! See `docs/SECURITY.md` (per-session fetch flood mitigation) and
4//! `docs/SOURCES.md` §6 (Politeness defaults). The constants enforced here
5//! are the load-bearing safeguards from `docs/LEGAL.md` §6 safeguard 8.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10
11use tokio::sync::{Mutex, OwnedSemaphorePermit, Semaphore};
12use tokio::time::{sleep_until, Instant};
13
14use crate::RateLimits;
15
16/// Process-wide async rate limiter.
17///
18/// Enforces three invariants on every [`acquire`](RateLimiter::acquire):
19/// 1. **Global concurrency** — at most
20/// [`RateLimits::max_concurrent_fetches`](crate::RateLimits::max_concurrent_fetches)
21/// in flight at once.
22/// 2. **Global rate** — at most
23/// [`RateLimits::max_fetches_per_second`](crate::RateLimits::max_fetches_per_second)
24/// starts in any rolling one-second window.
25/// 3. **Per-source backoff** — at least
26/// [`RateLimits::per_source_backoff_ms`](crate::RateLimits::per_source_backoff_ms)
27/// between consecutive starts to the same source name.
28///
29/// The returned [`Permit`] holds the concurrency slot for the lifetime of the
30/// value; drop it when the fetch is done.
31///
32/// 429 / `Retry-After` handling is split: the limiter only exposes the admin
33/// hook [`sleep_for`](RateLimiter::sleep_for); the actual `Retry-After`
34/// header parse and call lives at the `Source::fetch` call site, per
35/// `docs/SOURCES.md` §6.
36#[derive(Debug)]
37pub struct RateLimiter {
38 limits: RateLimits,
39 sem: Arc<Semaphore>,
40 // Global rolling-second window: timestamps of starts within the last second.
41 global_starts: Arc<Mutex<Vec<Instant>>>,
42 // Earliest-allowed start time per source name.
43 per_source_next: Arc<Mutex<HashMap<String, Instant>>>,
44}
45
46/// Held while a fetch is in flight; releases the concurrency slot on drop.
47#[derive(Debug)]
48pub struct Permit {
49 _slot: OwnedSemaphorePermit,
50}
51
52impl RateLimiter {
53 /// Construct from the hard-coded [`RateLimits`] (the only public path).
54 pub fn new(limits: RateLimits) -> Self {
55 let max = limits.max_concurrent_fetches() as usize;
56 Self {
57 limits,
58 sem: Arc::new(Semaphore::new(max)),
59 global_starts: Arc::new(Mutex::new(Vec::new())),
60 per_source_next: Arc::new(Mutex::new(HashMap::new())),
61 }
62 }
63
64 /// Block until a slot is available, then return a [`Permit`].
65 ///
66 /// Order of waits, in this exact sequence:
67 /// 1. global concurrency (semaphore acquire),
68 /// 2. global rate cap (sleep if the rolling-second window is full),
69 /// 3. per-source backoff (sleep until the source's `next` time).
70 ///
71 /// Lock-acquisition order is always `global_starts` first, THEN
72 /// `per_source_next`. Any future call site that needs both locks MUST
73 /// follow the same order to keep the system deadlock-free.
74 pub async fn acquire(&self, source: &str) -> Permit {
75 // Step 1: global concurrency — bounded by Semaphore::new(max).
76 // `acquire_owned` only errors when the semaphore is closed; this
77 // type never closes it (no `close()` call exists), so the Err arm
78 // is structurally unreachable. The local `allow` is the documented
79 // exception to the workspace `expect_used` lint.
80 #[allow(clippy::expect_used)]
81 let slot = self
82 .sem
83 .clone()
84 .acquire_owned()
85 .await
86 .expect("rate-limiter semaphore is never closed");
87
88 // Step 2: global rate cap. Loop until the rolling-second window has
89 // room, sleeping until the oldest entry ages out if it does not.
90 let max_per_sec = self.limits.max_fetches_per_second() as usize;
91 let one_sec = Duration::from_secs(1);
92 loop {
93 let mut starts = self.global_starts.lock().await;
94 let now = Instant::now();
95 // Prune entries older than 1 s. starts is FIFO so we can drop
96 // a contiguous prefix.
97 let cutoff = now.checked_sub(one_sec).unwrap_or(now);
98 let drop_count = starts.iter().take_while(|t| **t <= cutoff).count();
99 if drop_count > 0 {
100 starts.drain(..drop_count);
101 }
102 if starts.len() < max_per_sec {
103 break;
104 }
105 // Window is full — wake at the moment the oldest entry ages out.
106 // starts.len() >= max_per_sec >= 1 here, so [0] is safe.
107 let wake = starts[0] + one_sec;
108 drop(starts);
109 sleep_until(wake).await;
110 }
111
112 // Step 3: per-source backoff. Acquire `per_source_next` strictly
113 // after dropping `global_starts` above (lock order documented).
114 let backoff = Duration::from_millis(self.limits.per_source_backoff_ms());
115 let mut next_map = self.per_source_next.lock().await;
116 let now = Instant::now();
117 if let Some(&next) = next_map.get(source) {
118 if now < next {
119 drop(next_map);
120 sleep_until(next).await;
121 next_map = self.per_source_next.lock().await;
122 }
123 }
124 // Record this start in both ledgers. We re-read `Instant::now()`
125 // because we may have slept in step 2 or step 3.
126 let start = Instant::now();
127 next_map.insert(source.to_string(), start + backoff);
128 drop(next_map);
129
130 // Push the start timestamp into the global window. Done AFTER
131 // releasing per_source_next to keep the documented lock order
132 // (global → per-source) on every code path.
133 let mut starts = self.global_starts.lock().await;
134 starts.push(start);
135 drop(starts);
136
137 Permit { _slot: slot }
138 }
139
140 /// Tell the limiter to delay further starts to `source` by at least
141 /// `dur`. Used when the source returns 429 with `Retry-After`.
142 pub async fn sleep_for(&self, source: &str, dur: Duration) {
143 let mut next_map = self.per_source_next.lock().await;
144 let target = Instant::now() + dur;
145 let entry = next_map.entry(source.to_string()).or_insert(target);
146 if *entry < target {
147 *entry = target;
148 }
149 }
150}
151
152// ---------------------------------------------------------------------------
153// Tests
154// ---------------------------------------------------------------------------
155
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use crate::{RateLimits, MAX_CONCURRENT_FETCHES, MAX_FETCHES_PER_SECOND};

    /// Build a shared limiter from `RateLimits::HARD_CODED`.
    fn limiter() -> Arc<RateLimiter> {
        Arc::new(RateLimiter::new(RateLimits::HARD_CODED))
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn concurrent_acquires_respect_max_concurrency() {
        // Ten tasks race for permits; the number of simultaneously held
        // permits must never exceed MAX_CONCURRENT_FETCHES.
        let rl = limiter();
        let in_flight = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));

        let tasks: Vec<_> = (0..10u32)
            .map(|i| {
                let rl = Arc::clone(&rl);
                let in_flight = Arc::clone(&in_flight);
                let peak = Arc::clone(&peak);
                let src = format!("src-{}", i);
                tokio::spawn(async move {
                    let permit = rl.acquire(&src).await;
                    let current = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                    peak.fetch_max(current, Ordering::SeqCst);
                    // Keep the permit for a while so the other tasks have
                    // to contend for the remaining slots.
                    tokio::time::sleep(Duration::from_millis(50)).await;
                    in_flight.fetch_sub(1, Ordering::SeqCst);
                    drop(permit);
                })
            })
            .collect();

        for task in tasks {
            task.await.expect("task ok");
        }

        let max = peak.load(Ordering::SeqCst);
        assert!(
            max <= MAX_CONCURRENT_FETCHES as usize,
            "max concurrent live = {}, expected <= {}",
            max,
            MAX_CONCURRENT_FETCHES
        );
        assert!(max > 0, "at least one acquire should succeed");
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn same_source_starts_separated_by_backoff() {
        // Back-to-back acquires for one source must be spaced by at least
        // per_source_backoff_ms.
        let rl = limiter();
        let backoff_ms = RateLimits::HARD_CODED.per_source_backoff_ms();

        let began = Instant::now();
        let first = rl.acquire("crossref").await;
        drop(first);
        let _second = rl.acquire("crossref").await;
        let elapsed = began.elapsed();

        assert!(
            elapsed >= Duration::from_millis(backoff_ms),
            "elapsed {:?} < backoff {} ms",
            elapsed,
            backoff_ms
        );
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn different_sources_no_per_source_wait() {
        // Acquiring for source A and then source B must not incur the
        // per-source backoff. (The global rate cap still applies, but two
        // starts are not enough for it to bind.)
        let rl = limiter();
        let backoff = Duration::from_millis(RateLimits::HARD_CODED.per_source_backoff_ms());

        let began = Instant::now();
        let _permit_a = rl.acquire("source-a").await;
        let _permit_b = rl.acquire("source-b").await;
        let elapsed = began.elapsed();

        assert!(
            elapsed < backoff,
            "elapsed {:?} should be well under per-source backoff {:?}",
            elapsed,
            backoff
        );
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn global_rate_caps_starts_per_second() {
        // Ten distinct sources are acquired one after another. Each permit
        // is released right away so the concurrency cap never interferes
        // with the rate cap under test. No more than MAX_FETCHES_PER_SECOND
        // acquires may finish within the first second; the rest have to
        // wait for the rolling window to free up.
        let rl = limiter();
        let max_per_sec = MAX_FETCHES_PER_SECOND as usize;

        let began = Instant::now();
        let mut offsets: Vec<Duration> = Vec::with_capacity(10);
        for i in 0..10u32 {
            let src = format!("src-{}", i);
            let permit = rl.acquire(&src).await;
            offsets.push(began.elapsed());
            // Released immediately: this test is about rate, not concurrency.
            drop(permit);
        }

        // Count the acquires that completed less than one second after
        // the start of the test.
        let first_second = Duration::from_secs(1);
        let in_first_sec = offsets
            .iter()
            .filter(|offset| **offset < first_second)
            .count();
        assert!(
            in_first_sec <= max_per_sec,
            "{} starts completed in first second, expected <= {}",
            in_first_sec,
            max_per_sec
        );
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn sleep_for_delays_target_source() {
        // After sleep_for("X", 500ms), acquire("X") has to wait at least
        // 500 ms, while acquire("Y") in the same period is not held back.
        let rl = limiter();
        let delay = Duration::from_millis(500);
        rl.sleep_for("X", delay).await;

        // The other source proceeds without the penalty.
        let began_y = Instant::now();
        let _permit_y = rl.acquire("Y").await;
        let elapsed_y = began_y.elapsed();
        assert!(
            elapsed_y < delay,
            "Y elapsed {:?} should be far less than {:?}",
            elapsed_y,
            delay
        );

        // The penalised source waits for the full delay.
        let began_x = Instant::now();
        let _permit_x = rl.acquire("X").await;
        let elapsed_x = began_x.elapsed();
        assert!(
            elapsed_x >= delay,
            "X elapsed {:?} < requested delay {:?}",
            elapsed_x,
            delay
        );
    }
}
311}