1#![doc=include_str!( "../README.md")]
2#![allow(clippy::declare_interior_mutable_const)]
3
4mod error;
5mod gcra;
6mod nanos;
7mod quota;
8mod snapshot;
9mod state;
10mod timer;
11
12pub use error::TooManyRequests;
13pub use quota::Quota;
14pub use snapshot::RateSnapshot;
15
16use std::{
17 net::{IpAddr, SocketAddr},
18 sync::Arc,
19};
20
21use http::header::{HeaderMap, HeaderName, FORWARDED};
22
23use crate::state::{keyed::DefaultKeyedStateStore, RateLimiter};
24
25#[derive(Clone)]
26pub struct RateLimit {
27 limit: Arc<RateLimiter<IpAddr, DefaultKeyedStateStore<IpAddr>>>,
28}
29
30impl RateLimit {
31 pub fn new(quota: Quota) -> Self {
33 Self {
34 limit: Arc::new(RateLimiter::hashmap(quota)),
35 }
36 }
37
38 pub fn rate_limit(&self, headers: &HeaderMap, addr: &SocketAddr) -> Result<RateSnapshot, TooManyRequests> {
45 let addr = maybe_x_forwarded_for(headers)
46 .or_else(|| maybe_x_real_ip(headers))
47 .or_else(|| maybe_forwarded(headers))
48 .unwrap_or_else(|| addr.ip());
49 self.limit.check_key(&addr).map_err(TooManyRequests::from)
50 }
51}
52
53const X_REAL_IP: HeaderName = HeaderName::from_static("x-real-ip");
54const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
55
56fn maybe_x_forwarded_for(headers: &HeaderMap) -> Option<IpAddr> {
57 headers
58 .get(X_FORWARDED_FOR)
59 .and_then(|hv| hv.to_str().ok())
60 .and_then(|s| s.split(',').find_map(|s| s.trim().parse().ok()))
61}
62
63fn maybe_x_real_ip(headers: &HeaderMap) -> Option<IpAddr> {
64 headers
65 .get(X_REAL_IP)
66 .and_then(|hv| hv.to_str().ok())
67 .and_then(|s| s.parse().ok())
68}
69
70fn maybe_forwarded(headers: &HeaderMap) -> Option<IpAddr> {
71 headers
72 .get_all(FORWARDED)
73 .iter()
74 .filter_map(|h| h.to_str().ok())
75 .flat_map(|val| val.split(';'))
76 .flat_map(|p| p.split(','))
77 .map(|val| val.trim().splitn(2, '='))
78 .find_map(|mut val| match (val.next(), val.next()) {
79 (Some(name), Some(val)) if name.trim().eq_ignore_ascii_case("for") => {
80 let val = val.trim();
81 val.parse::<IpAddr>()
82 .or_else(|_| val.parse::<SocketAddr>().map(|addr| addr.ip()))
83 .ok()
84 }
85 _ => None,
86 })
87}
88
/// Test-only shorthand for an un-keyed ("direct") limiter with in-memory state.
#[cfg(test)]
type DefaultDirectRateLimiter =
    crate::state::RateLimiter<crate::state::direct::NotKeyed, crate::state::InMemoryState>;
91
#[cfg(test)]
mod test {
    use core::{hash::Hash, num::NonZeroU32, time::Duration};

    use std::thread;

    use all_asserts::*;
    use http::header::HeaderValue;

    use crate::{
        error::InsufficientCapacity,
        quota::Quota,
        state::{keyed::HashMapStateStore, RateLimiter},
        timer::{DefaultTimer, FakeRelativeClock, Timer},
        DefaultDirectRateLimiter,
    };

    use super::*;

    #[test]
    fn forwarded_header() {
        let mut headers = HeaderMap::new();
        headers.insert(
            FORWARDED,
            HeaderValue::from_static("for =192.0.2.60;proto=http;by=203.0.113.43"),
        );
        assert_eq!(maybe_forwarded(&headers).unwrap().to_string(), "192.0.2.60");
    }

    #[test]
    fn rejects_too_many() {
        let clock = FakeRelativeClock::default();
        let lb = RateLimiter::direct_with_clock(Quota::per_second(2), &clock);
        let ms = Duration::from_millis(1);

        assert!(lb.check().is_ok(), "Now: {:?}", clock.now());
        clock.advance(ms);
        assert!(lb.check().is_ok(), "Now: {:?}", clock.now());

        clock.advance(ms);
        assert!(lb.check().is_err(), "Now: {:?}", clock.now());

        clock.advance(ms * 1000);
        assert!(lb.check().is_ok(), "Now: {:?}", clock.now());
        clock.advance(ms);
        assert!(lb.check().is_ok());

        clock.advance(ms);
        assert!(lb.check().is_err(), "{lb:?}");
    }

    #[test]
    fn all_1_identical_to_1() {
        let clock = FakeRelativeClock::default();
        let lb = RateLimiter::direct_with_clock(Quota::per_second(2), &clock);
        let ms = Duration::from_millis(1);
        let one = NonZeroU32::new(1).unwrap();

        assert!(lb.check_n(one).unwrap().is_ok(), "Now: {:?}", clock.now());
        clock.advance(ms);
        assert!(lb.check_n(one).unwrap().is_ok(), "Now: {:?}", clock.now());

        clock.advance(ms);
        assert!(lb.check_n(one).unwrap().is_err(), "Now: {:?}", clock.now());

        clock.advance(ms * 1000);
        assert!(lb.check_n(one).unwrap().is_ok(), "Now: {:?}", clock.now());
        clock.advance(ms);
        assert!(lb.check_n(one).unwrap().is_ok());

        clock.advance(ms);
        assert!(lb.check_n(one).unwrap().is_err(), "{lb:?}");
    }

    #[test]
    fn never_allows_more_than_capacity_all() {
        let clock = FakeRelativeClock::default();
        let lb = RateLimiter::direct_with_clock(Quota::per_second(4), &clock);
        let ms = Duration::from_millis(1);

        let num = NonZeroU32::new(2).unwrap();

        assert!(lb.check_n(num).unwrap().is_ok());
        assert!(lb.check_n(num).unwrap().is_ok());

        clock.advance(ms);
        assert!(lb.check_n(num).unwrap().is_err());

        clock.advance(ms * 1000);
        assert!(lb.check_n(num).unwrap().is_ok());
        clock.advance(ms);
        assert!(lb.check_n(num).unwrap().is_ok());

        clock.advance(ms);
        assert!(lb.check_n(num).unwrap().is_err(), "{:?}", lb);
    }

    #[test]
    fn rejects_too_many_all() {
        let clock = FakeRelativeClock::default();
        let lb = RateLimiter::direct_with_clock(Quota::per_second(5), &clock);
        let ms = Duration::from_millis(1);

        let num = NonZeroU32::new(15).unwrap();

        // 15 cells can never fit a burst of 5: the outer `Err` is the capacity error.
        assert!(lb.check_n(num).is_err());

        clock.advance(ms * 3 * 1000);
        assert!(lb.check_n(num).is_err());
    }

    #[test]
    fn all_capacity_check_rejects_excess() {
        let clock = FakeRelativeClock::default();
        let lb = RateLimiter::direct_with_clock(Quota::per_second(5), &clock);

        assert_eq!(Err(InsufficientCapacity(5)), lb.check_n(NonZeroU32::new(15).unwrap()));
        assert_eq!(Err(InsufficientCapacity(5)), lb.check_n(NonZeroU32::new(6).unwrap()));
        assert_eq!(Err(InsufficientCapacity(5)), lb.check_n(NonZeroU32::new(7).unwrap()));
    }

    #[test]
    fn correct_wait_time() {
        let clock = FakeRelativeClock::default();
        let lb = RateLimiter::direct_with_clock(Quota::per_second(5), &clock);
        let ms = Duration::from_millis(1);
        let mut conforming = 0;
        for _i in 0..20 {
            clock.advance(ms);
            let res = lb.check();
            match res {
                Ok(_) => {
                    conforming += 1;
                }
                Err(wait) => {
                    // Waiting exactly the reported time must make the next check succeed.
                    clock.advance(wait.wait_time_from(clock.now()));
                    assert!(lb.check().is_ok());
                    conforming += 1;
                }
            }
        }
        assert_eq!(20, conforming);
    }

    #[test]
    fn actual_threadsafety() {
        let clock = FakeRelativeClock::default();
        let lim = RateLimiter::direct_with_clock(Quota::per_second(20), &clock);
        let ms = Duration::from_millis(1);

        crossbeam::scope(|scope| {
            for _i in 0..20 {
                scope.spawn(|_| {
                    assert!(lim.check().is_ok());
                });
            }
        })
        .unwrap();

        clock.advance(ms * 2);
        assert!(lim.check().is_err());
        clock.advance(ms * 998);
        assert!(lim.check().is_ok());
    }

    #[test]
    fn default_direct() {
        let limiter = RateLimiter::direct_with_clock(Quota::per_second(20), &DefaultTimer);
        assert!(limiter.check().is_ok());
    }

    #[test]
    fn stresstest_large_quotas() {
        use std::{sync::Arc, thread};

        let quota = Quota::per_second(1_000_000_001);
        let rate_limiter = Arc::new(RateLimiter::direct(quota));

        fn rlspin(rl: Arc<DefaultDirectRateLimiter>) {
            for _ in 0..1_000_000 {
                rl.check().map_err(|e| dbg!(e)).unwrap();
            }
        }

        let rate_limiter2 = Arc::clone(&rate_limiter);
        // Keep the handle and join it. A detached thread's panic would be silently
        // lost, and the test could return while the second spinner is still running.
        let spinner = thread::spawn(move || {
            rlspin(rate_limiter2);
        });
        rlspin(rate_limiter);
        spinner
            .join()
            .expect("background rate-limiter spinner panicked");
    }

    const KEYS: &[u32] = &[1u32, 2u32];

    #[test]
    fn accepts_first_cell() {
        let clock = FakeRelativeClock::default();
        let lb = RateLimiter::hashmap_with_clock(Quota::per_second(5), &clock);
        for key in KEYS {
            assert!(lb.check_key(&key).is_ok(), "key {:?}", key);
        }
    }

    /// Consumes `limiter` and returns the keys still held by its state store, sorted.
    fn retained_keys<T: Clone + Hash + Eq + Copy + Ord>(
        limiter: RateLimiter<T, HashMapStateStore<T>, FakeRelativeClock>,
    ) -> Vec<T> {
        let state = limiter.into_state_store();
        let map = state.lock().unwrap();
        let mut keys: Vec<T> = map.keys().copied().collect();
        keys.sort();
        keys
    }

    #[test]
    fn expiration() {
        let clock = FakeRelativeClock::default();
        let ms = Duration::from_millis(1);

        // Touches "foo" at t, "bar" at t+200ms and "baz" at t+800ms.
        let make_bucket = || {
            let lim = RateLimiter::hashmap_with_clock(Quota::per_second(1), &clock);
            lim.check_key(&"foo").unwrap();
            clock.advance(ms * 200);
            lim.check_key(&"bar").unwrap();
            clock.advance(ms * 600);
            lim.check_key(&"baz").unwrap();
            lim
        };
        let keys = &["bar", "baz", "foo"];

        let lim_shrunk = make_bucket();
        lim_shrunk.retain_recent();
        assert_eq!(retained_keys(lim_shrunk), keys);

        let lim_later = make_bucket();
        clock.advance(ms * 1200);
        lim_later.retain_recent();
        assert_eq!(retained_keys(lim_later), vec!["bar", "baz"]);

        let lim_later = make_bucket();
        clock.advance(ms * (1200 + 200));
        lim_later.retain_recent();
        assert_eq!(retained_keys(lim_later), vec!["baz"]);

        let lim_later = make_bucket();
        clock.advance(ms * (1200 + 200 + 600));
        lim_later.retain_recent();
        assert_eq!(retained_keys(lim_later), Vec::<&str>::new());
    }

    #[test]
    fn hashmap_length() {
        let lim = RateLimiter::hashmap(Quota::per_second(1));
        assert_eq!(lim.len(), 0);
        assert!(lim.is_empty());

        lim.check_key(&"foo").unwrap();
        assert_eq!(lim.len(), 1);
        assert!(!lim.is_empty());

        lim.check_key(&"bar").unwrap();
        assert_eq!(lim.len(), 2);
        assert!(!lim.is_empty());

        lim.check_key(&"baz").unwrap();
        assert_eq!(lim.len(), 3);
        assert!(!lim.is_empty());
    }

    #[test]
    fn hashmap_shrink_to_fit() {
        let clock = FakeRelativeClock::default();
        let lim = RateLimiter::hashmap_with_clock(Quota::per_second(20), &clock);
        let ms = Duration::from_millis(1);

        assert!(lim
            .check_key_n(&"long-lived".to_string(), NonZeroU32::new(10).unwrap())
            .unwrap()
            .is_ok());
        assert!(lim.check_key(&"short-lived".to_string()).is_ok());

        clock.advance(ms * 300);
        lim.retain_recent();
        lim.shrink_to_fit();

        assert_eq!(lim.len(), 1);
    }

    /// Returns the process's peak resident set size (`ru_maxrss`) from `getrusage(2)`.
    ///
    /// NOTE(review): `ru_maxrss` is a high-water mark, and its unit is
    /// platform-dependent (kilobytes on Linux, bytes on macOS per the respective
    /// man pages). `LEAK_TOLERANCE` therefore means different amounts per platform.
    // `ru_maxrss` is a `c_long`. Widening keeps 32-bit targets compiling, but it is
    // an identity conversion where `c_long == i64`.
    #[allow(clippy::useless_conversion)]
    fn resident_memory_size() -> i64 {
        // SAFETY: `rusage` is a plain C struct of integers and `timeval`s, for
        // which the all-zero bit pattern is a valid value.
        let mut out: libc::rusage = unsafe { std::mem::zeroed() };
        // SAFETY: `out` is a valid, exclusively borrowed `rusage` for the call.
        let rc = unsafe { libc::getrusage(libc::RUSAGE_SELF, &mut out) };
        assert_eq!(rc, 0, "getrusage(RUSAGE_SELF) failed");
        i64::from(out.ru_maxrss)
    }

    const LEAK_TOLERANCE: i64 = 1024 * 1024 * 10;

    /// Guard that records the peak RSS on creation and asserts on drop that it
    /// has not grown by more than `LEAK_TOLERANCE`.
    struct LeakCheck {
        usage_before: i64,
        n_iter: usize,
    }

    impl Drop for LeakCheck {
        fn drop(&mut self) {
            // Panicking again while already unwinding would abort the whole test
            // binary and hide the original failure.
            if thread::panicking() {
                return;
            }
            let usage_after = resident_memory_size();
            assert_le!(usage_after, self.usage_before + LEAK_TOLERANCE);
        }
    }

    impl LeakCheck {
        fn new(n_iter: usize) -> Self {
            LeakCheck {
                n_iter,
                usage_before: resident_memory_size(),
            }
        }
    }

    #[test]
    fn memleak_gcra() {
        let bucket = RateLimiter::direct(Quota::per_second(1_000_000));

        let leak_check = LeakCheck::new(500_000);

        for _i in 0..leak_check.n_iter {
            drop(bucket.check());
        }
    }

    #[test]
    fn memleak_gcra_multi() {
        let bucket = RateLimiter::direct(Quota::per_second(1_000_000));
        let leak_check = LeakCheck::new(500_000);

        for _i in 0..leak_check.n_iter {
            drop(bucket.check_n(NonZeroU32::new(2).unwrap()));
        }
    }

    #[test]
    fn memleak_gcra_threaded() {
        let bucket = Arc::new(RateLimiter::direct(Quota::per_second(1_000_000)));
        let leak_check = LeakCheck::new(5_000);

        for _i in 0..leak_check.n_iter {
            let bucket = Arc::clone(&bucket);
            thread::spawn(move || {
                assert!(bucket.check().is_ok());
            })
            .join()
            .unwrap();
        }
    }

    #[test]
    fn memleak_keyed() {
        let bucket = RateLimiter::keyed(Quota::per_second(50));

        let leak_check = LeakCheck::new(500_000);

        for i in 0..leak_check.n_iter {
            drop(bucket.check_key(&(i % 1000)));
        }
    }
}