// async_rate_limit/sliding_window.rs

1use std::sync::Arc;
2
3use crate::limiters::{ThreadsafeRateLimiter, ThreadsafeVariableRateLimiter};
4use tokio::sync::{OwnedSemaphorePermit, Semaphore};
5use tokio::time::Duration;
6
7/// A rate limiter that records calls (optionally with a specified cost) during a sliding window `Duration`.
8///
9/// [`SlidingWindowRateLimiter`] implements both [`RateLimiter`] and [`VariableCostRateLimiter`], so both
10/// [`SlidingWindowRateLimiter::wait_until_ready()`] and [`SlidingWindowRateLimiter::wait_with_cost()`] can be used (even together). For instance, `limiter.wait_until_ready().await`
11///  and `limiter.wait_with_cost(1).await` would have the same effect.
12///
13/// # Example: Simple Rate Limiter
14///
15/// A method that calls an external API with a rate limit should not be called more than three times per second.
16/// ```
17/// use tokio::time::{Instant, Duration};
18/// use async_rate_limit::limiters::RateLimiter;
19/// use async_rate_limit::sliding_window::SlidingWindowRateLimiter;
20///
21/// #[tokio::main]
22/// async fn main() -> () {
23///     let mut limiter = SlidingWindowRateLimiter::new(Duration::from_secs(1), 3);
24///     
25///     for _ in 0..4 {
26///         // the 4th call will take place ~1 second after the first call
27///         limited_method(&mut limiter).await;
28///     }
29/// }
30///
31/// // note the use of the `RateLimiter` trait, rather than the direct type
32/// async fn limited_method<T>(limiter: &mut T) where T: RateLimiter {
33///     limiter.wait_until_ready().await;
34///     println!("{:?}", Instant::now());
35/// }
36/// ```
37///
38#[derive(Clone, Debug)]
39pub struct SlidingWindowRateLimiter {
40    /// Length of the window, i.e. a permit acquired at T is released at T + window
41    window: Duration,
42    /// Number of calls (or cost thereof) allowed during a given sliding window
43    permits: Arc<Semaphore>,
44}
45
46impl SlidingWindowRateLimiter {
47    /// Creates a new `SlidingWindowRateLimiter` that allows `limit` calls during any `window` Duration.
48    pub fn new(window: Duration, limit: usize) -> Self {
49        let permits = Arc::new(Semaphore::new(limit));
50
51        SlidingWindowRateLimiter { window, permits }
52    }
53
54    /// Creates a new `SlidingWindowRateLimiter` with an externally provided `Arc<Semaphore>` for permits.
55    /// # Example: A Shared Variable Cost Rate Limiter
56    ///
57    /// ```
58    /// use std::sync::Arc;
59    /// use tokio::sync::Semaphore;
60    /// use async_rate_limit::limiters::VariableCostRateLimiter;
61    /// use async_rate_limit::sliding_window::SlidingWindowRateLimiter;
62    /// use tokio::time::{Instant, Duration};
63    ///
64    /// #[tokio::main]
65    /// async fn main() -> () {
66    ///     let permits = Arc::new(Semaphore::new(5));
67    ///     let mut limiter1 =
68    ///     SlidingWindowRateLimiter::new_with_permits(Duration::from_secs(2), permits.clone());
69    ///     let mut limiter2 =
70    ///     SlidingWindowRateLimiter::new_with_permits(Duration::from_secs(2), permits.clone());
71    ///
72    ///     // Note: the above is semantically equivalent to creating `limiter1` with
73    ///     //  `SlidingWindowRateLimiter::new`, then cloning it.
74    ///
75    ///     limiter1.wait_with_cost(3).await;
76    ///     // the second call will wait 2s, since the first call consumed 3/5 shared permits
77    ///     limiter2.wait_with_cost(3).await;
78    /// }
79    /// ```
80    pub fn new_with_permits(window: Duration, permits: Arc<Semaphore>) -> Self {
81        SlidingWindowRateLimiter { window, permits }
82    }
83
84    async fn drop_permit_after_window(window: Duration, permit: OwnedSemaphorePermit) {
85        tokio::time::sleep(window).await;
86        drop(permit);
87    }
88}
89
90impl ThreadsafeRateLimiter for SlidingWindowRateLimiter {
91    /// Wait with an implied cost of 1, see the [initial example](#example-simple-rate-limiter)
92    async fn wait_until_ready(&self) {
93        let permit = self
94            .permits
95            .clone()
96            .acquire_owned()
97            .await
98            .expect("Failed to acquire permit for call");
99
100        tokio::spawn(Self::drop_permit_after_window(self.window, permit));
101    }
102}
103
104impl ThreadsafeVariableRateLimiter for SlidingWindowRateLimiter {
105    /// Wait with some variable cost per usage.
106    ///
107    /// # Example: A Shared Variable Cost Rate Limiter
108    ///
109    /// An API specifies that you may make 5 "calls" per second, but some endpoints cost more than one call.
110    /// - `/lite` costs 1 unit per call
111    /// - `/heavy` costs 3 units per call
112    /// ```
113    /// use tokio::time::{Instant, Duration};
114    /// use async_rate_limit::limiters::VariableCostRateLimiter;
115    /// use async_rate_limit::sliding_window::SlidingWindowRateLimiter;
116    ///
117    /// #[tokio::main]
118    /// async fn main() -> () {
119    ///     let mut limiter = SlidingWindowRateLimiter::new(Duration::from_secs(1), 5);
120    ///     
121    ///     for _ in 0..3 {
122    ///         // these will proceed immediately, spending 3 units
123    ///         get_lite(&mut limiter).await;
124    ///     }
125    ///
126    ///     // 3/5 units are spent, so this will wait for ~1s to proceed since it costs another 3
127    ///     get_heavy(&mut limiter).await;
128    /// }
129    ///
130    /// // note the use of the `VariableCostRateLimiter` trait, rather than the direct type
131    /// async fn get_lite<T>(limiter: &mut T) where T: VariableCostRateLimiter {
132    ///     limiter.wait_with_cost(1).await;
133    ///     println!("Lite: {:?}", Instant::now());
134    /// }
135    ///
136    /// async fn get_heavy<T>(limiter: &mut T) where T: VariableCostRateLimiter {
137    ///     limiter.wait_with_cost(3).await;
138    ///     println!("Heavy: {:?}", Instant::now());
139    /// }
140    /// ```
141    async fn wait_with_cost(&self, cost: usize) {
142        let permits = self
143            .permits
144            .clone()
145            .acquire_many_owned(cost as u32)
146            .await
147            .unwrap_or_else(|_| panic!("Failed to acquire {} permits for call", cost));
148
149        tokio::spawn(Self::drop_permit_after_window(self.window, permits));
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{pause, Instant};

    mod rate_limiter_tests {
        use super::*;
        use crate::limiters::{RateLimiter, ThreadsafeRateLimiter};

        #[tokio::test]
        async fn test_proceeds_immediately_below_limit() {
            let limiter = SlidingWindowRateLimiter::new(Duration::from_secs(3), 7);

            let start = Instant::now();

            for _ in 0..7 {
                limiter.wait_until_ready().await;
            }

            let end = Instant::now();

            let duration = end - start;

            assert!(duration > Duration::from_secs(0));
            assert!(duration < Duration::from_millis(100));
        }

        #[tokio::test]
        async fn test_waits_at_limit() {
            pause();

            let limiter = SlidingWindowRateLimiter::new(Duration::from_secs(1), 3);

            let start = Instant::now();

            for _ in 0..10 {
                limiter.wait_until_ready().await;
            }

            let end = Instant::now();

            let duration = end - start;

            assert!(duration > Duration::from_secs(3));
            assert!(duration < Duration::from_secs(4));
        }

        #[tokio::test]
        async fn test_many_simultaneous_waiters() {
            pause();

            let limiter = SlidingWindowRateLimiter::new(Duration::from_secs(1), 3);

            let start = Instant::now();

            let mut tasks = Vec::with_capacity(10);

            for _ in 0..10 {
                // Clones share the same `Arc<Semaphore>`, so every task competes for the same
                // three permits per window; no extra locking is needed.
                let limiter = limiter.clone();

                tasks.push(tokio::spawn(async move {
                    limiter.wait_until_ready().await;
                }));
            }

            for task in tasks {
                // Surface panics from spawned tasks instead of silently discarding them.
                task.await.expect("rate-limited task panicked");
            }

            let end = Instant::now();

            let duration = end - start;

            assert!(duration > Duration::from_secs(3));
            assert!(duration < Duration::from_secs(4));
        }

        #[tokio::test]
        async fn test_trait_threadsafe_bounds() {
            let limiter = SlidingWindowRateLimiter::new(Duration::from_secs(3), 7);

            assert_threadsafe(&limiter).await;
        }

        #[tokio::test]
        async fn test_trait_non_threadsafe_bounds() {
            let mut limiter = SlidingWindowRateLimiter::new(Duration::from_secs(3), 7);

            assert_non_threadsafe(&mut limiter).await;
        }

        async fn assert_threadsafe<T: ThreadsafeRateLimiter>(limiter: &T) {
            let start = Instant::now();

            for _ in 0..7 {
                limiter.wait_until_ready().await;
            }

            let end = Instant::now();

            let duration = end - start;

            assert!(duration > Duration::from_secs(0));
            assert!(duration < Duration::from_millis(100));
        }

        async fn assert_non_threadsafe<T: RateLimiter>(limiter: &mut T) {
            let start = Instant::now();

            for _ in 0..7 {
                limiter.wait_until_ready().await;
            }

            let end = Instant::now();

            let duration = end - start;

            assert!(duration > Duration::from_secs(0));
            assert!(duration < Duration::from_millis(100));
        }
    }

    mod variable_cost_rate_limiter_tests {
        use super::*;
        use crate::limiters::ThreadsafeVariableRateLimiter;

        #[tokio::test]
        async fn test_proceeds_immediately_below_limit() {
            let limiter = SlidingWindowRateLimiter::new(Duration::from_secs(3), 7);

            let start = Instant::now();

            for _ in 0..3 {
                limiter.wait_with_cost(2).await;
            }

            let end = Instant::now();

            let duration = end - start;

            assert!(duration > Duration::from_secs(0));
            assert!(duration < Duration::from_millis(100));
        }

        #[tokio::test]
        async fn test_waits_at_limit() {
            pause();

            let limiter = SlidingWindowRateLimiter::new(Duration::from_secs(1), 3);

            let start = Instant::now();

            limiter.wait_with_cost(3).await;
            limiter.wait_with_cost(3).await;
            limiter.wait_with_cost(3).await;

            let end = Instant::now();

            let duration = end - start;

            assert!(duration > Duration::from_secs(2));
            assert!(duration < Duration::from_secs(3));
        }

        #[tokio::test]
        async fn test_with_threadsafe_bound() {
            pause();

            let limiter = SlidingWindowRateLimiter::new(Duration::from_secs(1), 3);

            assert_threadsafe(&limiter).await;
        }

        async fn assert_threadsafe<T>(limiter: &T)
        where
            T: ThreadsafeVariableRateLimiter,
        {
            let start = Instant::now();

            limiter.wait_with_cost(3).await;
            limiter.wait_with_cost(3).await;
            limiter.wait_with_cost(3).await;

            let end = Instant::now();

            let duration = end - start;

            assert!(duration > Duration::from_secs(2));
            assert!(duration < Duration::from_secs(3));
        }

        #[tokio::test]
        async fn test_waiters_with_shared_permits() {
            pause();

            let permits = Arc::new(Semaphore::new(5));
            let limiter1 =
                SlidingWindowRateLimiter::new_with_permits(Duration::from_secs(2), permits.clone());
            let limiter2 =
                SlidingWindowRateLimiter::new_with_permits(Duration::from_secs(2), permits.clone());

            let start = Instant::now();

            limiter1.wait_with_cost(3).await;
            limiter2.wait_with_cost(3).await;

            let end = Instant::now();

            let duration = end - start;

            assert!(duration > Duration::from_secs(2));
            assert!(duration < Duration::from_secs(3));
        }

        #[tokio::test]
        async fn test_many_waiters() {
            pause();

            let limiter = SlidingWindowRateLimiter::new(Duration::from_secs(1), 3);

            let start = Instant::now();

            let mut tasks = Vec::with_capacity(10);

            for _ in 0..10 {
                // Clones share the same `Arc<Semaphore>`, so every task competes for the same
                // three permits per window; no extra locking is needed.
                let limiter = limiter.clone();

                tasks.push(tokio::spawn(async move {
                    limiter.wait_with_cost(3).await;
                }));
            }

            for task in tasks {
                // Surface panics from spawned tasks instead of silently discarding them.
                task.await.expect("rate-limited task panicked");
            }

            let end = Instant::now();

            let duration = end - start;

            assert!(duration > Duration::from_secs(9));
            assert!(duration < Duration::from_secs(10));
        }
    }
}