async_rate_limiter/
token_bucket.rs

1use std::{
2    pin::Pin,
3    sync::{
4        atomic::{
5            AtomicUsize,
6            Ordering::{Relaxed, SeqCst},
7        },
8        Arc, Mutex, MutexGuard,
9    },
10    task::Poll,
11    time::{Duration, Instant},
12};
13
14use futures::{Future, FutureExt};
15
16use crate::rt::delay;
17
/// Number of nanoseconds in one second; used to turn a per-second rate into
/// a per-token period.
const NANOS_PER_SEC: u64 = 1_000 * 1_000 * 1_000;
19
/// A token-bucket rate limiter.
///
/// Cloning produces another handle to the same bucket: the mutable state,
/// the burst setting and the clone counter live behind `Arc`s.
pub struct TokenBucketRateLimiter {
    /// Nanoseconds between two generated tokens, derived from the rate.
    period_ns: u64,
    /// Maximum number of tokens the bucket may hold (the burst size).
    burst: Arc<AtomicUsize>,
    /// Mutable bucket state, guarded by a mutex.
    inner: Arc<Mutex<TokenBucketInner>>,
    /// Starts at 1 and is incremented on every clone.
    /// NOTE(review): nothing in this file decrements it, so it counts handles
    /// ever created rather than live ones — confirm this is intended.
    counter: Arc<AtomicUsize>,
}

/// Mutable state of a [`TokenBucketRateLimiter`].
struct TokenBucketInner {
    /// Reference instant from which newly generated tokens are measured.
    /// It can lie in the future while a waiter holds a reserved slot.
    last: Instant,
    /// Tokens currently available.
    tokens: usize,
}
31
32impl Clone for TokenBucketRateLimiter {
33    fn clone(&self) -> Self {
34        let prev = self.counter.fetch_add(1, SeqCst);
35        if prev == usize::MAX {
36            panic!("cannot clone `TokenBucketRateLimiter` -- too many outstanding instances");
37        }
38
39        TokenBucketRateLimiter {
40            inner: self.inner.clone(),
41            counter: self.counter.clone(),
42            burst: self.burst.clone(),
43            period_ns: self.period_ns,
44        }
45    }
46}
47
48impl TokenBucketRateLimiter {
49    /// Creates a new [`TokenBucketRateLimiter`].
50    ///
51    /// `rate` specifies the average number of operations allowed per second.
52    ///
53    /// **Note**: `rate` *MUST* be greater than zero.
54    pub fn new(rate: usize) -> TokenBucketRateLimiter {
55        assert!(rate > 0);
56
57        let period_ns = NANOS_PER_SEC.checked_div(rate as u64).unwrap();
58
59        let inner = TokenBucketInner {
60            tokens: 1,
61            last: Instant::now(),
62        };
63        let inner = Arc::new(Mutex::new(inner));
64
65        TokenBucketRateLimiter {
66            inner,
67            counter: Arc::new(AtomicUsize::new(1)),
68            burst: Arc::new(AtomicUsize::new(rate)),
69            period_ns,
70        }
71    }
72
73    /// `burst` specifies the maximum burst number of operations allowed.
74    ///
75    /// The default value of `burst` is same as `rate`.
76    ///
77    /// **Note**: `burst` *MUST* be greater than zero.
78    pub fn burst(&self, burst: usize) -> &TokenBucketRateLimiter {
79        assert!(burst > 0);
80        self.burst.store(burst, Relaxed);
81        self
82    }
83
84    /// Try to acquire a token. Return `Ok` if the token is successfully
85    /// acquired, it means that you can safely perform frequency-controlled
86    /// operations. Otherwise `Err(duration)` is returned, `duration` is the
87    /// minimum time to wait.
88    pub fn try_acquire(&self) -> Result<(), Duration> {
89        let mut inner = self.inner.lock().unwrap();
90        match self.try_acquire_inner(&mut inner) {
91            Ok(_) => Ok(()),
92            Err(next) => Err(next.saturating_duration_since(Instant::now())),
93        }
94    }
95
96    /// Acquire a token. When the token is successfully acquired, it means that
97    /// you can safely perform frequency-controlled operations.
98    pub async fn acquire(&self) {
99        let need_to_wait = {
100            let mut inner = self.inner.lock().unwrap();
101            let Err(next) = self.try_acquire_inner(&mut inner) else {
102                return;
103            };
104
105            inner.last = next;
106            self.inc_num_tokens(&mut inner);
107            self.dec_num_tokens(&mut inner);
108
109            next.saturating_duration_since(Instant::now())
110        };
111
112        if !need_to_wait.is_zero() {
113            Token::new(need_to_wait, self).await
114        }
115    }
116
117    /// Acquire a token. When the token is successfully acquired, it means that
118    /// you can safely perform frequency-controlled operations.
119    ///
120    /// If the method fails to obtain a token after exceeding the `timeout`,
121    /// false will be returned, otherwise true will be returned.
122    pub async fn acquire_with_timeout(&self, timeout: Duration) -> bool {
123        let need_to_wait = {
124            let mut inner = self.inner.lock().unwrap();
125            let Err(next) = self.try_acquire_inner(&mut inner) else {
126                return true;
127            };
128
129            inner.last = next;
130            self.inc_num_tokens(&mut inner);
131            self.dec_num_tokens(&mut inner);
132
133            let need_to_wait = next.saturating_duration_since(Instant::now());
134            if need_to_wait > timeout {
135                // failed with timeout
136                return false;
137            }
138
139            inner.last = next;
140            self.inc_num_tokens(&mut inner);
141            self.dec_num_tokens(&mut inner);
142
143            if need_to_wait.is_zero() {
144                return true;
145            }
146
147            need_to_wait
148        };
149
150        Token::new(need_to_wait, self).await;
151        true
152    }
153
154    fn try_acquire_inner(&self, inner: &mut MutexGuard<TokenBucketInner>) -> Result<(), Instant> {
155        if let Some(remain) = inner.tokens.checked_sub(1) {
156            inner.tokens = remain;
157            return Ok(());
158        }
159
160        match self.tokens_since_last(inner) {
161            Ok((tokens, duration)) => {
162                self.set_num_tokens(inner, tokens);
163                inner.last += duration;
164                // consume 1 token
165                self.dec_num_tokens(inner);
166                Ok(())
167            }
168            Err(duration) => Err(inner.last + duration),
169        }
170    }
171
172    // Get tokens generated since `last` time. Return:
173    //
174    // - Ok((tokens, duration)) if tokens has been generated, duration is the time it
175    //   takes to generate the token.
176    //
177    // - Err(duration) if tokens hasn't been generate yet, need to wait until
178    //   next cycle, duration is the period of the cycle.
179    fn tokens_since_last(
180        &self,
181        inner: &MutexGuard<TokenBucketInner>,
182    ) -> Result<(usize, Duration), Duration> {
183        let now = Instant::now();
184        let since_last = now
185            .checked_duration_since(inner.last)
186            .unwrap_or(Duration::ZERO);
187        let since_nanos = since_last.as_nanos();
188        if since_nanos >= self.period_ns as u128 {
189            let tokens = since_nanos / (self.period_ns as u128);
190            assert!(tokens >= 1);
191            Ok((
192                tokens as usize,
193                Duration::from_nanos(tokens as u64 * self.period_ns),
194            ))
195        } else {
196            Err(Duration::from_nanos(self.period_ns))
197        }
198    }
199
200    fn set_num_tokens(&self, inner: &mut MutexGuard<TokenBucketInner>, num: usize) {
201        inner.tokens = std::cmp::min(self.burst.load(Relaxed), num);
202    }
203
204    fn dec_num_tokens(&self, inner: &mut MutexGuard<TokenBucketInner>) -> Option<usize> {
205        if let Some(num) = inner.tokens.checked_sub(1) {
206            inner.tokens = num;
207            Some(num)
208        } else {
209            None
210        }
211    }
212
213    fn inc_num_tokens(&self, inner: &mut MutexGuard<TokenBucketInner>) -> usize {
214        if let Some(num) = inner.tokens.checked_add(1) {
215            self.set_num_tokens(inner, num);
216        }
217        inner.tokens
218    }
219}
220
221struct Token<'a> {
222    fut: Pin<Box<dyn Future<Output = ()>>>,
223    token_bucket: &'a TokenBucketRateLimiter,
224    consumed: bool,
225}
226
227unsafe impl Send for Token<'_> {}
228
229impl<'a> Token<'a> {
230    fn new(duration: Duration, token_bucket: &'a TokenBucketRateLimiter) -> Token<'a> {
231        let fut = delay(duration);
232        let fut = Box::pin(fut);
233        Self {
234            fut,
235            token_bucket,
236            consumed: false,
237        }
238    }
239}
240
241impl Future for Token<'_> {
242    type Output = ();
243
244    fn poll(
245        self: std::pin::Pin<&mut Self>,
246        cx: &mut std::task::Context<'_>,
247    ) -> std::task::Poll<Self::Output> {
248        let this = unsafe { self.get_unchecked_mut() };
249        match this.fut.poll_unpin(cx) {
250            std::task::Poll::Ready(_) => {
251                this.consumed = true;
252                Poll::Ready(())
253            }
254            std::task::Poll::Pending => Poll::Pending,
255        }
256    }
257}
258
259impl Drop for Token<'_> {
260    fn drop(&mut self) {
261        if self.consumed {
262            return;
263        }
264
265        // Return unused token
266        let mut inner = self.token_bucket.inner.lock().unwrap();
267        let num = inner.tokens + 1;
268        self.token_bucket.set_num_tokens(&mut inner, num);
269    }
270}