http_rate/
lib.rs

1#![doc=include_str!( "../README.md")]
2#![allow(clippy::declare_interior_mutable_const)]
3
4mod error;
5mod gcra;
6mod nanos;
7mod quota;
8mod snapshot;
9mod state;
10mod timer;
11
12pub use error::TooManyRequests;
13pub use quota::Quota;
14pub use snapshot::RateSnapshot;
15
16use std::{
17    net::{IpAddr, SocketAddr},
18    sync::Arc,
19};
20
21use http::header::{HeaderMap, HeaderName, FORWARDED};
22
23use crate::state::{keyed::DefaultKeyedStateStore, RateLimiter};
24
25#[derive(Clone)]
26pub struct RateLimit {
27    limit: Arc<RateLimiter<IpAddr, DefaultKeyedStateStore<IpAddr>>>,
28}
29
30impl RateLimit {
31    /// Construct a new RateLimit with given quota.
32    pub fn new(quota: Quota) -> Self {
33        Self {
34            limit: Arc::new(RateLimiter::hashmap(quota)),
35        }
36    }
37
38    /// Rate limit [Request] based on it's [HeaderMap] state and given client [SocketAddr]
39    /// "x-real-ip", "x-forwarded-for" and "forwarded" headers are checked in order start
40    /// from left to determine client's socket address. Received [SocketAddr] will be used
41    /// as fallback when all headers are absent or can't provide valid client address.
42    ///
43    /// [Request]: http::Request
44    pub fn rate_limit(&self, headers: &HeaderMap, addr: &SocketAddr) -> Result<RateSnapshot, TooManyRequests> {
45        let addr = maybe_x_forwarded_for(headers)
46            .or_else(|| maybe_x_real_ip(headers))
47            .or_else(|| maybe_forwarded(headers))
48            .unwrap_or_else(|| addr.ip());
49        self.limit.check_key(&addr).map_err(TooManyRequests::from)
50    }
51}
52
53const X_REAL_IP: HeaderName = HeaderName::from_static("x-real-ip");
54const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
55
56fn maybe_x_forwarded_for(headers: &HeaderMap) -> Option<IpAddr> {
57    headers
58        .get(X_FORWARDED_FOR)
59        .and_then(|hv| hv.to_str().ok())
60        .and_then(|s| s.split(',').find_map(|s| s.trim().parse().ok()))
61}
62
63fn maybe_x_real_ip(headers: &HeaderMap) -> Option<IpAddr> {
64    headers
65        .get(X_REAL_IP)
66        .and_then(|hv| hv.to_str().ok())
67        .and_then(|s| s.parse().ok())
68}
69
70fn maybe_forwarded(headers: &HeaderMap) -> Option<IpAddr> {
71    headers
72        .get_all(FORWARDED)
73        .iter()
74        .filter_map(|h| h.to_str().ok())
75        .flat_map(|val| val.split(';'))
76        .flat_map(|p| p.split(','))
77        .map(|val| val.trim().splitn(2, '='))
78        .find_map(|mut val| match (val.next(), val.next()) {
79            (Some(name), Some(val)) if name.trim().eq_ignore_ascii_case("for") => {
80                let val = val.trim();
81                val.parse::<IpAddr>()
82                    .or_else(|_| val.parse::<SocketAddr>().map(|addr| addr.ip()))
83                    .ok()
84            }
85            _ => None,
86        })
87}
88
/// Limiter type with `NotKeyed` key and `InMemoryState` store, named for use in the tests.
#[cfg(test)]
type DefaultDirectRateLimiter = RateLimiter<state::direct::NotKeyed, state::InMemoryState>;
91
92#[cfg(test)]
93mod test {
94    use core::{num::NonZeroU32, time::Duration};
95
96    use std::thread;
97
98    use all_asserts::*;
99    use http::header::HeaderValue;
100
101    use crate::{
102        error::InsufficientCapacity,
103        quota::Quota,
104        state::RateLimiter,
105        timer::{DefaultTimer, FakeRelativeClock, Timer},
106        DefaultDirectRateLimiter,
107    };
108
109    use super::*;
110
111    #[test]
112    fn forwarded_header() {
113        let mut headers = HeaderMap::new();
114        headers.insert(
115            FORWARDED,
116            HeaderValue::from_static("for =192.0.2.60;proto=http;by=203.0.113.43"),
117        );
118        assert_eq!(maybe_forwarded(&headers).unwrap().to_string(), "192.0.2.60");
119    }
120
121    #[test]
122    fn rejects_too_many() {
123        let clock = FakeRelativeClock::default();
124        let lb = RateLimiter::direct_with_clock(Quota::per_second(2), &clock);
125        let ms = Duration::from_millis(1);
126
127        // use up our burst capacity (2 in the first second):
128        assert!(lb.check().is_ok(), "Now: {:?}", clock.now());
129        clock.advance(ms);
130        assert!(lb.check().is_ok(), "Now: {:?}", clock.now());
131
132        clock.advance(ms);
133        assert!(lb.check().is_err(), "Now: {:?}", clock.now());
134
135        // should be ok again in 1s:
136        clock.advance(ms * 1000);
137        assert!(lb.check().is_ok(), "Now: {:?}", clock.now());
138        clock.advance(ms);
139        assert!(lb.check().is_ok());
140
141        clock.advance(ms);
142        assert!(lb.check().is_err(), "{lb:?}");
143    }
144
145    #[test]
146    fn all_1_identical_to_1() {
147        let clock = FakeRelativeClock::default();
148        let lb = RateLimiter::direct_with_clock(Quota::per_second(2), &clock);
149        let ms = Duration::from_millis(1);
150        let one = NonZeroU32::new(1).unwrap();
151
152        // use up our burst capacity (2 in the first second):
153        assert!(lb.check_n(one).unwrap().is_ok(), "Now: {:?}", clock.now());
154        clock.advance(ms);
155        assert!(lb.check_n(one).unwrap().is_ok(), "Now: {:?}", clock.now());
156
157        clock.advance(ms);
158        assert!(lb.check_n(one).unwrap().is_err(), "Now: {:?}", clock.now());
159
160        // should be ok again in 1s:
161        clock.advance(ms * 1000);
162        assert!(lb.check_n(one).unwrap().is_ok(), "Now: {:?}", clock.now());
163        clock.advance(ms);
164        assert!(lb.check_n(one).unwrap().is_ok());
165
166        clock.advance(ms);
167        assert!(lb.check_n(one).unwrap().is_err(), "{lb:?}");
168    }
169
170    #[test]
171    fn never_allows_more_than_capacity_all() {
172        let clock = FakeRelativeClock::default();
173        let lb = RateLimiter::direct_with_clock(Quota::per_second(4), &clock);
174        let ms = Duration::from_millis(1);
175
176        let num = NonZeroU32::new(2).unwrap();
177
178        // Use up the burst capacity:
179        assert!(lb.check_n(num).unwrap().is_ok());
180        assert!(lb.check_n(num).unwrap().is_ok());
181
182        clock.advance(ms);
183        assert!(lb.check_n(num).unwrap().is_err());
184
185        // should be ok again in 1s:
186        clock.advance(ms * 1000);
187        assert!(lb.check_n(num).unwrap().is_ok());
188        clock.advance(ms);
189        assert!(lb.check_n(num).unwrap().is_ok());
190
191        clock.advance(ms);
192        assert!(lb.check_n(num).unwrap().is_err(), "{:?}", lb);
193    }
194
195    #[test]
196    fn rejects_too_many_all() {
197        let clock = FakeRelativeClock::default();
198        let lb = RateLimiter::direct_with_clock(Quota::per_second(5), &clock);
199        let ms = Duration::from_millis(1);
200
201        let num = NonZeroU32::new(15).unwrap();
202
203        // Should not allow the first 15 cells on a capacity 5 bucket:
204        assert!(lb.check_n(num).is_err());
205
206        // After 3 and 20 seconds, it should not allow 15 on that bucket either:
207        clock.advance(ms * 3 * 1000);
208        assert!(lb.check_n(num).is_err());
209    }
210
211    #[test]
212    fn all_capacity_check_rejects_excess() {
213        let clock = FakeRelativeClock::default();
214        let lb = RateLimiter::direct_with_clock(Quota::per_second(5), &clock);
215
216        assert_eq!(Err(InsufficientCapacity(5)), lb.check_n(NonZeroU32::new(15).unwrap()));
217        assert_eq!(Err(InsufficientCapacity(5)), lb.check_n(NonZeroU32::new(6).unwrap()));
218        assert_eq!(Err(InsufficientCapacity(5)), lb.check_n(NonZeroU32::new(7).unwrap()));
219    }
220
221    #[test]
222    fn correct_wait_time() {
223        let clock = FakeRelativeClock::default();
224        // Bucket adding a new element per 200ms:
225        let lb = RateLimiter::direct_with_clock(Quota::per_second(5), &clock);
226        let ms = Duration::from_millis(1);
227        let mut conforming = 0;
228        for _i in 0..20 {
229            clock.advance(ms);
230            let res = lb.check();
231            match res {
232                Ok(_) => {
233                    conforming += 1;
234                }
235                Err(wait) => {
236                    clock.advance(wait.wait_time_from(clock.now()));
237                    assert!(lb.check().is_ok());
238                    conforming += 1;
239                }
240            }
241        }
242        assert_eq!(20, conforming);
243    }
244
245    #[test]
246    fn actual_threadsafety() {
247        use crossbeam;
248
249        let clock = FakeRelativeClock::default();
250        let lim = RateLimiter::direct_with_clock(Quota::per_second(20), &clock);
251        let ms = Duration::from_millis(1);
252
253        crossbeam::scope(|scope| {
254            for _i in 0..20 {
255                scope.spawn(|_| {
256                    assert!(lim.check().is_ok());
257                });
258            }
259        })
260        .unwrap();
261
262        clock.advance(ms * 2);
263        assert!(lim.check().is_err());
264        clock.advance(ms * 998);
265        assert!(lim.check().is_ok());
266    }
267
268    #[test]
269    fn default_direct() {
270        let limiter = RateLimiter::direct_with_clock(Quota::per_second(20), &DefaultTimer);
271        assert!(limiter.check().is_ok());
272    }
273
274    #[test]
275    fn stresstest_large_quotas() {
276        use std::{sync::Arc, thread};
277
278        let quota = Quota::per_second(1_000_000_001);
279        let rate_limiter = Arc::new(RateLimiter::direct(quota));
280
281        fn rlspin(rl: Arc<DefaultDirectRateLimiter>) {
282            for _ in 0..1_000_000 {
283                rl.check().map_err(|e| dbg!(e)).unwrap();
284            }
285        }
286
287        let rate_limiter2 = rate_limiter.clone();
288        thread::spawn(move || {
289            rlspin(rate_limiter2);
290        });
291        rlspin(rate_limiter);
292    }
293
294    const KEYS: &[u32] = &[1u32, 2u32];
295
296    #[test]
297    fn accepts_first_cell() {
298        let clock = FakeRelativeClock::default();
299        let lb = RateLimiter::hashmap_with_clock(Quota::per_second(5), &clock);
300        for key in KEYS {
301            assert!(lb.check_key(&key).is_ok(), "key {:?}", key);
302        }
303    }
304
305    use crate::state::keyed::HashMapStateStore;
306    use core::hash::Hash;
307
308    fn retained_keys<T: Clone + Hash + Eq + Copy + Ord>(
309        limiter: RateLimiter<T, HashMapStateStore<T>, FakeRelativeClock>,
310    ) -> Vec<T> {
311        let state = limiter.into_state_store();
312        let map = state.lock().unwrap();
313        let mut keys: Vec<T> = map.keys().copied().collect();
314        keys.sort();
315        keys
316    }
317
318    #[test]
319    fn expiration() {
320        let clock = FakeRelativeClock::default();
321        let ms = Duration::from_millis(1);
322
323        let make_bucket = || {
324            let lim = RateLimiter::hashmap_with_clock(Quota::per_second(1), &clock);
325            lim.check_key(&"foo").unwrap();
326            clock.advance(ms * 200);
327            lim.check_key(&"bar").unwrap();
328            clock.advance(ms * 600);
329            lim.check_key(&"baz").unwrap();
330            lim
331        };
332        let keys = &["bar", "baz", "foo"];
333
334        // clean up all keys that are indistinguishable from unoccupied keys:
335        let lim_shrunk = make_bucket();
336        lim_shrunk.retain_recent();
337        assert_eq!(retained_keys(lim_shrunk), keys);
338
339        let lim_later = make_bucket();
340        clock.advance(ms * 1200);
341        lim_later.retain_recent();
342        assert_eq!(retained_keys(lim_later), vec!["bar", "baz"]);
343
344        let lim_later = make_bucket();
345        clock.advance(ms * (1200 + 200));
346        lim_later.retain_recent();
347        assert_eq!(retained_keys(lim_later), vec!["baz"]);
348
349        let lim_later = make_bucket();
350        clock.advance(ms * (1200 + 200 + 600));
351        lim_later.retain_recent();
352        assert_eq!(retained_keys(lim_later), Vec::<&str>::new());
353    }
354
355    #[test]
356    fn hashmap_length() {
357        let lim = RateLimiter::hashmap(Quota::per_second(1));
358        assert_eq!(lim.len(), 0);
359        assert!(lim.is_empty());
360
361        lim.check_key(&"foo").unwrap();
362        assert_eq!(lim.len(), 1);
363        assert!(!lim.is_empty(),);
364
365        lim.check_key(&"bar").unwrap();
366        assert_eq!(lim.len(), 2);
367        assert!(!lim.is_empty());
368
369        lim.check_key(&"baz").unwrap();
370        assert_eq!(lim.len(), 3);
371        assert!(!lim.is_empty());
372    }
373
374    #[test]
375    fn hashmap_shrink_to_fit() {
376        let clock = FakeRelativeClock::default();
377        // a steady rate of 3ms between elements:
378        let lim = RateLimiter::hashmap_with_clock(Quota::per_second(20), &clock);
379        let ms = Duration::from_millis(1);
380
381        assert!(lim
382            .check_key_n(&"long-lived".to_string(), NonZeroU32::new(10).unwrap())
383            .unwrap()
384            .is_ok(),);
385        assert!(lim.check_key(&"short-lived".to_string()).is_ok());
386
387        // Move the clock forward far enough that the short-lived key gets dropped:
388        clock.advance(ms * 300);
389        lim.retain_recent();
390        lim.shrink_to_fit();
391
392        assert_eq!(lim.len(), 1);
393    }
394
395    fn resident_memory_size() -> i64 {
396        let mut out: libc::rusage = unsafe { std::mem::zeroed() };
397        assert!(unsafe { libc::getrusage(libc::RUSAGE_SELF, &mut out) } == 0);
398        out.ru_maxrss
399    }
400
401    const LEAK_TOLERANCE: i64 = 1024 * 1024 * 10;
402
403    struct LeakCheck {
404        usage_before: i64,
405        n_iter: usize,
406    }
407
408    impl Drop for LeakCheck {
409        fn drop(&mut self) {
410            let usage_after = resident_memory_size();
411            assert_le!(usage_after, self.usage_before + LEAK_TOLERANCE);
412        }
413    }
414
415    impl LeakCheck {
416        fn new(n_iter: usize) -> Self {
417            LeakCheck {
418                n_iter,
419                usage_before: resident_memory_size(),
420            }
421        }
422    }
423
424    #[test]
425    fn memleak_gcra() {
426        let bucket = RateLimiter::direct(Quota::per_second(1_000_000));
427
428        let leak_check = LeakCheck::new(500_000);
429
430        for _i in 0..leak_check.n_iter {
431            drop(bucket.check());
432        }
433    }
434
435    #[test]
436    fn memleak_gcra_multi() {
437        let bucket = RateLimiter::direct(Quota::per_second(1_000_000));
438        let leak_check = LeakCheck::new(500_000);
439
440        for _i in 0..leak_check.n_iter {
441            drop(bucket.check_n(NonZeroU32::new(2).unwrap()));
442        }
443    }
444
445    #[test]
446    fn memleak_gcra_threaded() {
447        let bucket = Arc::new(RateLimiter::direct(Quota::per_second(1_000_000)));
448        let leak_check = LeakCheck::new(5_000);
449
450        for _i in 0..leak_check.n_iter {
451            let bucket = Arc::clone(&bucket);
452            thread::spawn(move || {
453                assert!(bucket.check().is_ok());
454            })
455            .join()
456            .unwrap();
457        }
458    }
459
460    #[test]
461    fn memleak_keyed() {
462        let bucket = RateLimiter::keyed(Quota::per_second(50));
463
464        let leak_check = LeakCheck::new(500_000);
465
466        for i in 0..leak_check.n_iter {
467            drop(bucket.check_key(&(i % 1000)));
468        }
469    }
470}