// async_throttle/multi.rs

use crate::RateLimiter;
use backoff::backoff::Backoff;
use backoff::{ExponentialBackoff, ExponentialBackoffBuilder};
use std::hash::Hash;
use std::sync::Arc;
use std::time::Duration;

7/// [`MultiRateLimiter`] enables key-based rate limiting, where each key has its own [`RateLimiter`].
8///
9/// This behavior is useful when you want to throttle a set of keys independently, for example
10/// you may have a web crawler that wants to throttle its requests to each domain independently.
11///
12/// # Examples
13///
14/// ```
15/// use async_throttle::MultiRateLimiter;
16/// use std::sync::Arc;
17///
18/// #[tokio::main]
19/// async fn main() {
20///    let period = std::time::Duration::from_secs(5);
21///    let rate_limiter = MultiRateLimiter::new(period);
22///    
23///    // This completes instantly
24///    rate_limiter.throttle("foo", || computation()).await;
25///
26///    // This completes instantly
27///    rate_limiter.throttle("bar", || computation()).await;
28///
29///    // This takes 5 seconds to complete because the key "foo" is rate limited
30///    rate_limiter.throttle("foo", || computation()).await;
31/// }
32///
33/// async fn computation() { }
34/// ```
35pub struct MultiRateLimiter<K> {
36    /// The period for each [`RateLimiter`] associated with a particular key
37    period: Duration,
38
39    /// The key-specific [`RateLimiter`]s
40    ///
41    /// The [`RateLimiter`]s are stored in a [`dashmap::DashMap`], which is a concurrent hash map.
42    /// Note that keys may map to the same shard within the [`dashmap::DashMap`], so you may experience
43    /// increase latency due to the spin-looping nature of [MultiRateLimiter::throttle] combined
44    /// with the fallibility of [`dashmap::DashMap::try_entry`].
45    rate_limiters: dashmap::DashMap<K, RateLimiter>,
46}
47
48impl<K: Eq + Hash + Clone> MultiRateLimiter<K> {
49    /// Creates a new [`MultiRateLimiter`].
50    pub fn new(period: Duration) -> Self {
51        Self {
52            period,
53            rate_limiters: dashmap::DashMap::new(),
54        }
55    }
56
57    /// Throttles the execution of a function based on a key.
58    /// Throttling is key-specific, so multiple keys can be throttled independently.
59    ///
60    /// # Examples
61    ///
62    /// ```
63    /// use async_throttle::MultiRateLimiter;
64    /// use anyhow::Result;
65    /// use std::sync::Arc;
66    ///
67    /// async fn do_work() { /* some computation */ }
68    ///
69    /// async fn throttle_by_key(the_key: u32, limiter: Arc<MultiRateLimiter<u32>>) {
70    ///    limiter.throttle(the_key, || do_work()).await
71    /// }
72    pub async fn throttle<Fut, F, T>(&self, key: K, f: F) -> T
73    where
74        Fut: std::future::Future<Output = T>,
75        F: FnOnce() -> Fut,
76    {
77        loop {
78            let mut backoff = get_backoff();
79
80            match self.rate_limiters.try_entry(key.clone()) {
81                None => {
82                    // Safety: `next_backoff` always returns Some(Duration)
83                    tokio::time::sleep(backoff.next_backoff().unwrap()).await
84                }
85                Some(entry) => {
86                    let rate_limiter = entry.or_insert_with(|| RateLimiter::new(self.period));
87                    return rate_limiter.value().throttle(f).await;
88                }
89            }
90        }
91    }
92}
93
94fn get_backoff() -> ExponentialBackoff {
95    ExponentialBackoffBuilder::default()
96        .with_initial_interval(Duration::from_millis(50))
97        .with_max_elapsed_time(None)
98        .build()
99}