async_rate_limiter/
token_bucket.rs1use std::{
2 pin::Pin,
3 sync::{
4 atomic::{
5 AtomicUsize,
6 Ordering::{Relaxed, SeqCst},
7 },
8 Arc, Mutex, MutexGuard,
9 },
10 task::Poll,
11 time::{Duration, Instant},
12};
13
14use futures::{Future, FutureExt};
15
16use crate::rt::delay;
17
/// Number of nanoseconds in one second; used to turn a per-second rate into a refill period.
const NANOS_PER_SEC: u64 = 1_000 * 1_000 * 1_000;
19
20pub struct TokenBucketRateLimiter {
21 inner: Arc<Mutex<TokenBucketInner>>,
22 counter: Arc<AtomicUsize>,
23 burst: Arc<AtomicUsize>,
24 period_ns: u64,
25}
26
/// Mutable bucket state, kept behind the limiter's mutex.
struct TokenBucketInner {
    /// Refill reference point: whole periods elapsed since this instant become tokens.
    /// Reservations move it forward, possibly past "now".
    last: Instant,
    /// Tokens currently stored in the bucket.
    tokens: usize,
}
31
32impl Clone for TokenBucketRateLimiter {
33 fn clone(&self) -> Self {
34 let prev = self.counter.fetch_add(1, SeqCst);
35 if prev == usize::MAX {
36 panic!("cannot clone `TokenBucketRateLimiter` -- too many outstanding instances");
37 }
38
39 TokenBucketRateLimiter {
40 inner: self.inner.clone(),
41 counter: self.counter.clone(),
42 burst: self.burst.clone(),
43 period_ns: self.period_ns,
44 }
45 }
46}
47
48impl TokenBucketRateLimiter {
49 pub fn new(rate: usize) -> TokenBucketRateLimiter {
55 assert!(rate > 0);
56
57 let period_ns = NANOS_PER_SEC.checked_div(rate as u64).unwrap();
58
59 let inner = TokenBucketInner {
60 tokens: 1,
61 last: Instant::now(),
62 };
63 let inner = Arc::new(Mutex::new(inner));
64
65 TokenBucketRateLimiter {
66 inner,
67 counter: Arc::new(AtomicUsize::new(1)),
68 burst: Arc::new(AtomicUsize::new(rate)),
69 period_ns,
70 }
71 }
72
73 pub fn burst(&self, burst: usize) -> &TokenBucketRateLimiter {
79 assert!(burst > 0);
80 self.burst.store(burst, Relaxed);
81 self
82 }
83
84 pub fn try_acquire(&self) -> Result<(), Duration> {
89 let mut inner = self.inner.lock().unwrap();
90 match self.try_acquire_inner(&mut inner) {
91 Ok(_) => Ok(()),
92 Err(next) => Err(next.saturating_duration_since(Instant::now())),
93 }
94 }
95
96 pub async fn acquire(&self) {
99 let need_to_wait = {
100 let mut inner = self.inner.lock().unwrap();
101 let Err(next) = self.try_acquire_inner(&mut inner) else {
102 return;
103 };
104
105 inner.last = next;
106 self.inc_num_tokens(&mut inner);
107 self.dec_num_tokens(&mut inner);
108
109 next.saturating_duration_since(Instant::now())
110 };
111
112 if !need_to_wait.is_zero() {
113 Token::new(need_to_wait, self).await
114 }
115 }
116
117 pub async fn acquire_with_timeout(&self, timeout: Duration) -> bool {
123 let need_to_wait = {
124 let mut inner = self.inner.lock().unwrap();
125 let Err(next) = self.try_acquire_inner(&mut inner) else {
126 return true;
127 };
128
129 inner.last = next;
130 self.inc_num_tokens(&mut inner);
131 self.dec_num_tokens(&mut inner);
132
133 let need_to_wait = next.saturating_duration_since(Instant::now());
134 if need_to_wait > timeout {
135 return false;
137 }
138
139 inner.last = next;
140 self.inc_num_tokens(&mut inner);
141 self.dec_num_tokens(&mut inner);
142
143 if need_to_wait.is_zero() {
144 return true;
145 }
146
147 need_to_wait
148 };
149
150 Token::new(need_to_wait, self).await;
151 true
152 }
153
154 fn try_acquire_inner(&self, inner: &mut MutexGuard<TokenBucketInner>) -> Result<(), Instant> {
155 if let Some(remain) = inner.tokens.checked_sub(1) {
156 inner.tokens = remain;
157 return Ok(());
158 }
159
160 match self.tokens_since_last(inner) {
161 Ok((tokens, duration)) => {
162 self.set_num_tokens(inner, tokens);
163 inner.last += duration;
164 self.dec_num_tokens(inner);
166 Ok(())
167 }
168 Err(duration) => Err(inner.last + duration),
169 }
170 }
171
172 fn tokens_since_last(
180 &self,
181 inner: &MutexGuard<TokenBucketInner>,
182 ) -> Result<(usize, Duration), Duration> {
183 let now = Instant::now();
184 let since_last = now
185 .checked_duration_since(inner.last)
186 .unwrap_or(Duration::ZERO);
187 let since_nanos = since_last.as_nanos();
188 if since_nanos >= self.period_ns as u128 {
189 let tokens = since_nanos / (self.period_ns as u128);
190 assert!(tokens >= 1);
191 Ok((
192 tokens as usize,
193 Duration::from_nanos(tokens as u64 * self.period_ns),
194 ))
195 } else {
196 Err(Duration::from_nanos(self.period_ns))
197 }
198 }
199
200 fn set_num_tokens(&self, inner: &mut MutexGuard<TokenBucketInner>, num: usize) {
201 inner.tokens = std::cmp::min(self.burst.load(Relaxed), num);
202 }
203
204 fn dec_num_tokens(&self, inner: &mut MutexGuard<TokenBucketInner>) -> Option<usize> {
205 if let Some(num) = inner.tokens.checked_sub(1) {
206 inner.tokens = num;
207 Some(num)
208 } else {
209 None
210 }
211 }
212
213 fn inc_num_tokens(&self, inner: &mut MutexGuard<TokenBucketInner>) -> usize {
214 if let Some(num) = inner.tokens.checked_add(1) {
215 self.set_num_tokens(inner, num);
216 }
217 inner.tokens
218 }
219}
220
221struct Token<'a> {
222 fut: Pin<Box<dyn Future<Output = ()>>>,
223 token_bucket: &'a TokenBucketRateLimiter,
224 consumed: bool,
225}
226
227unsafe impl Send for Token<'_> {}
228
229impl<'a> Token<'a> {
230 fn new(duration: Duration, token_bucket: &'a TokenBucketRateLimiter) -> Token<'a> {
231 let fut = delay(duration);
232 let fut = Box::pin(fut);
233 Self {
234 fut,
235 token_bucket,
236 consumed: false,
237 }
238 }
239}
240
241impl Future for Token<'_> {
242 type Output = ();
243
244 fn poll(
245 self: std::pin::Pin<&mut Self>,
246 cx: &mut std::task::Context<'_>,
247 ) -> std::task::Poll<Self::Output> {
248 let this = unsafe { self.get_unchecked_mut() };
249 match this.fut.poll_unpin(cx) {
250 std::task::Poll::Ready(_) => {
251 this.consumed = true;
252 Poll::Ready(())
253 }
254 std::task::Poll::Pending => Poll::Pending,
255 }
256 }
257}
258
259impl Drop for Token<'_> {
260 fn drop(&mut self) {
261 if self.consumed {
262 return;
263 }
264
265 let mut inner = self.token_bucket.inner.lock().unwrap();
267 let num = inner.tokens + 1;
268 self.token_bucket.set_num_tokens(&mut inner, num);
269 }
270}