arl/
lib.rs

1//! Async rate limiter for tokio runtime.
2//! Used to prevent DOSing a remote service or limiting usage of some other resources.
3
4#![forbid(unsafe_code)]
5#![warn(
6    anonymous_parameters,
7    clippy::needless_borrow,
8    missing_docs,
9    missing_copy_implementations,
10    missing_debug_implementations,
11    nonstandard_style,
12    rust_2018_idioms,
13    single_use_lifetimes,
14    trivial_casts,
15    trivial_numeric_casts,
16    unreachable_pub,
17    unused_extern_crates,
18    unused_qualifications,
19    variant_size_differences
20)]
21
22use tokio::sync::mpsc::{channel, Receiver, Sender};
23use tokio::sync::oneshot;
24use tokio::time::sleep;
25use tokio::time::{Duration, Instant};
26
27/// A rate limiter, or `time barrier` to prevent continuing execution before given time passes.
28/// Used mainly for things like dealing with remote API DOS protection.
29#[derive(Clone, Debug)]
30pub struct RateLimiter {
31    sender: Sender<Message>,
32}
33
34impl RateLimiter {
35    /// Creates a new RateLimiter that prevents an async task from continuing too quickly.
36    /// # Example:
37    /// ```no_run
38    /// use std::time::Duration;
39    /// use arl::RateLimiter;
40    /// let limiter = RateLimiter::new(75, Duration::from_secs(60));
41    /// async {
42    /// loop {
43    ///         limiter.wait().await;
44    ///         // Call a remote api here.
45    ///         // This will ensure that the remote api will not be hit more than 75 times in 60 seconds block.
46    ///      }
47    /// };
48    ///```
49    /// RateLimiter can be cloned and send to other threads: it will use the same counter and limits
50    /// for all the threads.
51    pub fn new(count: usize, duration: Duration) -> Self {
52        let (sender, receiver) = channel(count);
53        RateLimiter::spawn_receiver(receiver, count, duration);
54        Self { sender }
55    }
56
57    /// Make the current task wait until given limits have passed.
58    /// Uses `tokio::time::sleep()` internally, so it allows other tasks to continue in the meantime.
59    /// # Example:
60    /// ```no_run
61    /// use std::time::Duration;
62    /// use arl::RateLimiter;
63    /// let limiter = RateLimiter::new(2, Duration::from_secs(1));
64    /// async {
65    ///     loop {
66    ///         limiter.wait().await;
67    ///         // continue here knowing that it won't be executed more than twice in a second
68    ///     }   
69    /// };
70    /// ```
71    pub async fn wait(&self) {
72        let (s, r) = oneshot::channel::<()>();
73        self.sender
74            .send(Message { sender: s })
75            .await
76            .expect("unable to send to arl channel");
77        r.await.expect("unable to read from arl channel");
78    }
79
80    fn spawn_receiver(mut receiver: Receiver<Message>, count: usize, duration: Duration) {
81        tokio::spawn(async move {
82            let mut queue = Vec::with_capacity(count);
83            while let Some(message) = receiver.recv().await {
84                while !queue.is_empty() && queue[0] <= Instant::now() {
85                    queue.remove(0);
86                }
87                if queue.len() > count {
88                    let alarm = queue.remove(0);
89                    sleep(alarm - Instant::now()).await;
90                }
91                message
92                    .sender
93                    .send(())
94                    .expect("unable to send to arl client channel");
95                queue.push(Instant::now() + duration);
96            }
97        });
98    }
99}
100
101#[derive(Debug)]
102struct Message {
103    sender: oneshot::Sender<()>,
104}
105
#[cfg(test)]
mod test {
    use crate::RateLimiter;
    use std::time::Duration;
    use tokio::time::Instant;

    /// Calls that stay within the limit must not be delayed noticeably.
    #[tokio::test]
    async fn up_to_limit_execute_quickly() {
        const COUNT: usize = 10;
        let limiter = RateLimiter::new(COUNT, Duration::from_secs(60));
        let begun = Instant::now();
        for _ in 0..COUNT {
            limiter.wait().await;
        }
        let took = begun.elapsed();
        assert!(took < Duration::from_millis(10));
    }

    /// Calls beyond the limit are spread out over successive windows.
    #[tokio::test]
    async fn over_limit_execute_proportionally() {
        const COUNT: usize = 10;
        const CHUNKS: usize = 3;
        let limiter = RateLimiter::new(COUNT, Duration::from_secs(1));
        let begun = Instant::now();
        // CHUNKS back-to-back batches of COUNT calls each.
        for _ in 0..COUNT * CHUNKS {
            limiter.wait().await;
        }
        let took = begun.elapsed();
        // The first batch runs immediately, the second after 1 second and the third after
        // 2 seconds. The total must therefore exceed CHUNKS - 1 seconds.
        assert!(took > Duration::from_secs(CHUNKS as u64 - 1));
    }
}