leaky_bucket/lib.rs
1//! [<img alt="github" src="https://img.shields.io/badge/github-udoprog/leaky--bucket-8da0cb?style=for-the-badge&logo=github" height="20">](https://github.com/udoprog/leaky-bucket)
2//! [<img alt="crates.io" src="https://img.shields.io/crates/v/leaky-bucket.svg?style=for-the-badge&color=fc8d62&logo=rust" height="20">](https://crates.io/crates/leaky-bucket)
3//! [<img alt="docs.rs" src="https://img.shields.io/badge/docs.rs-leaky--bucket-66c2a5?style=for-the-badge&logoColor=white&logo=data:image/svg+xml;base64,PHN2ZyByb2xlPSJpbWciIHhtbG5zPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyIgdmlld0JveD0iMCAwIDUxMiA1MTIiPjxwYXRoIGZpbGw9IiNmNWY1ZjUiIGQ9Ik00ODguNiAyNTAuMkwzOTIgMjE0VjEwNS41YzAtMTUtOS4zLTI4LjQtMjMuNC0zMy43bC0xMDAtMzcuNWMtOC4xLTMuMS0xNy4xLTMuMS0yNS4zIDBsLTEwMCAzNy41Yy0xNC4xIDUuMy0yMy40IDE4LjctMjMuNCAzMy43VjIxNGwtOTYuNiAzNi4yQzkuMyAyNTUuNSAwIDI2OC45IDAgMjgzLjlWMzk0YzAgMTMuNiA3LjcgMjYuMSAxOS45IDMyLjJsMTAwIDUwYzEwLjEgNS4xIDIyLjEgNS4xIDMyLjIgMGwxMDMuOS01MiAxMDMuOSA1MmMxMC4xIDUuMSAyMi4xIDUuMSAzMi4yIDBsMTAwLTUwYzEyLjItNi4xIDE5LjktMTguNiAxOS45LTMyLjJWMjgzLjljMC0xNS05LjMtMjguNC0yMy40LTMzLjd6TTM1OCAyMTQuOGwtODUgMzEuOXYtNjguMmw4NS0zN3Y3My4zek0xNTQgMTA0LjFsMTAyLTM4LjIgMTAyIDM4LjJ2LjZsLTEwMiA0MS40LTEwMi00MS40di0uNnptODQgMjkxLjFsLTg1IDQyLjV2LTc5LjFsODUtMzguOHY3NS40em0wLTExMmwtMTAyIDQxLjQtMTAyLTQxLjR2LS42bDEwMi0zOC4yIDEwMiAzOC4ydi42em0yNDAgMTEybC04NSA0Mi41di03OS4xbDg1LTM4Ljh2NzUuNHptMC0xMTJsLTEwMiA0MS40LTEwMi00MS40di0uNmwxMDItMzguMiAxMDIgMzguMnYuNnoiPjwvcGF0aD48L3N2Zz4K" height="20">](https://docs.rs/leaky-bucket)
4//!
5//! A token-based rate limiter based on the [leaky bucket] algorithm.
6//!
7//! If the bucket overflows and goes over its max configured capacity, the task
8//! that tried to acquire the tokens will be suspended until the required number
9//! of tokens has been drained from the bucket.
10//!
11//! Since this crate uses timing facilities from tokio it has to be used within
12//! a Tokio runtime with the [`time` feature] enabled.
13//!
14//! This library has some neat features, which includes:
15//!
16//! **Not requiring a background task**. This is usually needed by token bucket
17//! rate limiters to drive progress. Instead, one of the waiting tasks
18//! temporarily assumes the role as coordinator (called the *core*). This
//! reduces the number of tasks needing to sleep, which can be a source of
20//! jitter for imprecise sleeping implementations and tight limiters. See below
21//! for more details.
22//!
//! **Dropped tasks** release any resources they've reserved, so constructing
//! and cancelling asynchronous tasks does not end up taking up wait slots
//! that are never used, which would be the case for cell-based rate limiters.
26//!
27//! <br>
28//!
29//! ## Usage
30//!
31//! The core type is [`RateLimiter`], which allows for limiting the throughput
32//! of a section using its [`acquire`], [`try_acquire`], and [`acquire_one`]
33//! methods.
34//!
//! The following is a simple example where we wrap requests through an HTTP
36//! `Client`, to ensure that we don't exceed a given limit:
37//!
38//! ```
39//! use leaky_bucket::RateLimiter;
40//! # struct Client;
41//! # impl Client { async fn request<T>(&self, path: &str) -> Result<T> { todo!() } }
42//! # trait DeserializeOwned {}
43//! # impl DeserializeOwned for Vec<Post> {}
44//! # type Result<T> = core::result::Result<T, ()>;
45//!
46//! /// A blog client.
47//! pub struct BlogClient {
48//! limiter: RateLimiter,
49//! client: Client,
50//! }
51//!
52//! struct Post {
53//! // ..
54//! }
55//!
56//! impl BlogClient {
57//! /// Get all posts from the service.
58//! pub async fn get_posts(&self) -> Result<Vec<Post>> {
59//! self.request("posts").await
60//! }
61//!
62//! /// Perform a request against the service, limiting requests to abide by a rate limit.
63//! async fn request<T>(&self, path: &str) -> Result<T>
64//! where
65//! T: DeserializeOwned
66//! {
67//! // Before we start sending a request, we block on acquiring one token.
68//! self.limiter.acquire(1).await;
69//! self.client.request::<T>(path).await
70//! }
71//! }
72//! ```
73//!
74//! <br>
75//!
76//! ## Implementation details
77//!
78//! Each rate limiter has two acquisition modes. A fast path and a slow path.
79//! The fast path is used if the desired number of tokens are readily available,
80//! and simply involves decrementing the number of tokens available in the
81//! shared pool.
82//!
83//! If the required number of tokens is not available, the task will be forced
84//! to be suspended until the next refill interval. Here one of the acquiring
85//! tasks will switch over to work as a *core*. This is known as *core
86//! switching*.
87//!
88//! ```
89//! use leaky_bucket::RateLimiter;
90//! use tokio::time::Duration;
91//!
92//! # #[tokio::main(flavor="current_thread", start_paused=true)] async fn main() {
93//! let limiter = RateLimiter::builder()
94//! .initial(10)
95//! .interval(Duration::from_millis(100))
96//! .build();
97//!
98//! // This is instantaneous since the rate limiter starts with 10 tokens to
99//! // spare.
100//! limiter.acquire(10).await;
101//!
102//! // This however needs to core switch and wait for a while until the desired
103//! // number of tokens is available.
104//! limiter.acquire(3).await;
105//! # }
106//! ```
107//!
108//! The core is responsible for sleeping for the configured interval so that
//! more tokens can be added. After that, it ensures that any tasks that are
//! waiting to acquire, including itself, are appropriately unsuspended.
111//!
112//! On-demand core switching is what allows this rate limiter implementation to
113//! work without a coordinating background thread. But we need to ensure that
//! any asynchronous tasks that use [`RateLimiter`] must either run an
115//! [`acquire`] call to completion, or be *cancelled* by being dropped.
116//!
//! If neither of these holds, the core might leak and be locked indefinitely
118//! preventing any future use of the rate limiter from making progress. This is
119//! similar to if you would lock an asynchronous [`Mutex`] but never drop its
120//! guard.
121//!
122//! > You can run this example with:
123//! >
124//! > ```sh
125//! > cargo run --example block_forever
126//! > ```
127//!
128//! ```no_run
129//! use std::future::Future;
130//! use std::sync::Arc;
131//! use std::task::Context;
132//!
133//! use leaky_bucket::RateLimiter;
134//!
135//! struct Waker;
136//! # impl std::task::Wake for Waker { fn wake(self: Arc<Self>) { } }
137//!
138//! # #[tokio::main(flavor="current_thread", start_paused=true)] async fn main() {
139//! let limiter = Arc::new(RateLimiter::builder().build());
140//!
141//! let waker = Arc::new(Waker).into();
142//! let mut cx = Context::from_waker(&waker);
143//!
144//! let mut a0 = Box::pin(limiter.acquire(1));
145//! // Poll once to ensure that the core task is assigned.
146//! assert!(a0.as_mut().poll(&mut cx).is_pending());
147//! assert!(a0.is_core());
148//!
149//! // We leak the core task, preventing the rate limiter from making progress
150//! // by assigning new core tasks.
151//! std::mem::forget(a0);
152//!
153//! // Awaiting acquire here would block forever.
154//! // limiter.acquire(1).await;
155//! # }
156//! ```
157//!
158//! <br>
159//!
160//! ## Fairness
161//!
162//! By default [`RateLimiter`] uses a *fair* scheduler. This ensures that the
163//! core task makes progress even if there are many tasks waiting to acquire
164//! tokens. This might cause more core switching, increasing the total work
165//! needed. An unfair scheduler is expected to do a bit less work under
166//! contention. But without fair scheduling some tasks might end up taking
167//! longer to acquire than expected.
168//!
169//! Unfair rate limiters also have access to a fast path for acquiring tokens,
170//! which might further improve throughput.
171//!
172//! This behavior can be tweaked with the [`Builder::fair`] option.
173//!
174//! ```
175//! use leaky_bucket::RateLimiter;
176//!
177//! let limiter = RateLimiter::builder()
178//! .fair(false)
179//! .build();
180//! ```
181//!
//! The `unfair_scheduling` example showcases this phenomenon:
183//!
184//! ```sh
185//! cargo run --example unfair_scheduling
186//! ```
187//!
188//! ```text
189//! # fair
190//! Max: 1011ms, Total: 1012ms
191//! Timings:
192//! 0: 101ms
193//! 1: 101ms
194//! 2: 101ms
195//! 3: 101ms
196//! 4: 101ms
197//! ...
198//! # unfair
199//! Max: 1014ms, Total: 1014ms
200//! Timings:
201//! 0: 1014ms
202//! 1: 101ms
203//! 2: 101ms
204//! 3: 101ms
205//! 4: 101ms
206//! ...
207//! ```
208//!
209//! As can be seen above the first task in the *unfair* scheduler takes longer
210//! to run because it prioritises releasing other tasks waiting to acquire over
211//! itself.
212//!
213//! [`acquire_one`]: https://docs.rs/leaky-bucket/1/leaky_bucket/struct.RateLimiter.html#method.acquire_one
214//! [`acquire`]: https://docs.rs/leaky-bucket/1/leaky_bucket/struct.RateLimiter.html#method.acquire
215//! [`Builder::fair`]: https://docs.rs/leaky-bucket/1/leaky_bucket/struct.Builder.html#method.fair
216//! [`Mutex`]: https://docs.rs/tokio/1/tokio/sync/struct.Mutex.html
217//! [`RateLimiter`]: https://docs.rs/leaky-bucket/1/leaky_bucket/struct.RateLimiter.html
218//! [`time` feature]: https://docs.rs/tokio/1/tokio/#feature-flags
219//! [`try_acquire`]: https://docs.rs/leaky-bucket/1/leaky_bucket/struct.RateLimiter.html#method.try_acquire
220//! [leaky bucket]: https://en.wikipedia.org/wiki/Leaky_bucket
221
222#![no_std]
223#![deny(missing_docs)]
224
225extern crate alloc;
226
227#[macro_use]
228extern crate std;
229
230use core::cell::UnsafeCell;
231use core::convert::TryFrom as _;
232use core::fmt;
233use core::future::Future;
234use core::mem::{self, ManuallyDrop};
235use core::ops::{Deref, DerefMut};
236use core::pin::Pin;
237use core::ptr;
238use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
239use core::task::{Context, Poll, Waker};
240
241use alloc::sync::Arc;
242
243use parking_lot::{Mutex, MutexGuard};
244use pin_project_lite::pin_project;
245use tokio::time::{self, Duration, Instant};
246
/// Forward trace-level diagnostics to the `tracing` crate when the `tracing`
/// feature is enabled.
#[cfg(feature = "tracing")]
macro_rules! trace {
    ($($tokens:tt)*) => { tracing::trace!($($tokens)*) };
}

/// Stand-in used when the `tracing` feature is disabled. Every invocation
/// expands to nothing, so none of the arguments are ever evaluated.
#[cfg(not(feature = "tracing"))]
macro_rules! trace {
    ($($tokens:tt)*) => {};
}
258
259mod linked_list;
260use self::linked_list::{LinkedList, Node};
261
/// Default factor for how to calculate the max balance when no explicit max
/// has been configured.
const DEFAULT_REFILL_MAX_FACTOR: usize = 10;

/// Interval at which to bump the shared mutex guard so that other parts of
/// the system get a chance to make progress. Processes which loop while
/// holding the critical lock should use this number to determine how many
/// times they may loop before calling [Guard::bump].
///
/// If we do not respect this limit we might inadvertently end up starving other
/// tasks from making progress so that they can unblock.
const BUMP_LIMIT: usize = 16;

/// The maximum supported balance.
///
/// The balance is stored shifted left by one bit in the shared state word
/// (the low bit holds the core-availability flag), so it must fit within
/// `isize::MAX`.
const MAX_BALANCE: usize = isize::MAX as usize;
275
276/// Marker trait which indicates that a type represents a unique held critical section.
277trait IsCritical {}
278impl IsCritical for Critical {}
279impl IsCritical for Guard<'_> {}
280
/// Linked task state, stored in a node of the waiter queue.
struct Task {
    /// Remaining tokens that need to be satisfied.
    remaining: usize,
    /// If this node has been released or not. We make this an atomic to permit
    /// access to it without synchronization.
    complete: AtomicBool,
    /// The waker associated with the node.
    waker: Option<Waker>,
}

impl Task {
    /// Construct a new, empty task state: zero tokens remaining, not marked as
    /// complete, and with no waker registered.
    const fn new() -> Self {
        Self {
            remaining: 0,
            complete: AtomicBool::new(false),
            waker: None,
        }
    }

    /// Test if the current node has no tokens left to acquire.
    ///
    /// This only inspects `remaining`; it does not read the `complete` flag.
    fn is_completed(&self) -> bool {
        self.remaining == 0
    }

    /// Fill the current node from the given pool of tokens.
    ///
    /// Moves `min(self.remaining, *current)` tokens out of `current` and
    /// subtracts the same amount from `self.remaining`.
    fn fill(&mut self, current: &mut usize) {
        let removed = usize::min(self.remaining, *current);
        self.remaining -= removed;
        *current -= removed;
    }
}
314
315/// A borrowed rate limiter.
316struct BorrowedRateLimiter<'a>(&'a RateLimiter);
317
318impl Deref for BorrowedRateLimiter<'_> {
319 type Target = RateLimiter;
320
321 #[inline]
322 fn deref(&self) -> &RateLimiter {
323 self.0
324 }
325}
326
327struct Critical {
328 /// Waiter list.
329 waiters: LinkedList<Task>,
330 /// The deadline for when more tokens can be be added.
331 deadline: Instant,
332}
333
334#[repr(transparent)]
335struct Guard<'a> {
336 critical: MutexGuard<'a, Critical>,
337}
338
339impl Guard<'_> {
340 #[inline]
341 fn bump(this: &mut Guard<'_>) {
342 MutexGuard::bump(&mut this.critical)
343 }
344}
345
346impl Deref for Guard<'_> {
347 type Target = Critical;
348
349 #[inline]
350 fn deref(&self) -> &Critical {
351 &self.critical
352 }
353}
354
355impl DerefMut for Guard<'_> {
356 #[inline]
357 fn deref_mut(&mut self) -> &mut Critical {
358 &mut self.critical
359 }
360}
361
362impl Critical {
363 #[inline]
364 fn push_task_front(&mut self, task: &mut Node<Task>) {
365 // SAFETY: We both have mutable access to the node being pushed, and
366 // mutable access to the critical section through `self`. So we know we
367 // have exclusive tampering rights to the waiter queue.
368 unsafe {
369 self.waiters.push_front(task.into());
370 }
371 }
372
373 #[inline]
374 fn push_task(&mut self, task: &mut Node<Task>) {
375 // SAFETY: We both have mutable access to the node being pushed, and
376 // mutable access to the critical section through `self`. So we know we
377 // have exclusive tampering rights to the waiter queue.
378 unsafe {
379 self.waiters.push_back(task.into());
380 }
381 }
382
383 #[inline]
384 fn remove_task(&mut self, task: &mut Node<Task>) {
385 // SAFETY: We both have mutable access to the node being pushed, and
386 // mutable access to the critical section through `self`. So we know we
387 // have exclusive tampering rights to the waiter queue.
388 unsafe {
389 self.waiters.remove(task.into());
390 }
391 }
392
393 /// Release the current core. Beyond this point the current task may no
394 /// longer interact exclusively with the core.
395 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "trace"))]
396 fn release(&mut self, state: &mut State<'_>) {
397 trace!("releasing core");
398 state.available = true;
399
400 // Find another task that might take over as core. Once it has acquired
401 // core status it will have to make sure it is no longer linked into the
402 // wait queue.
403 unsafe {
404 if let Some(node) = self.waiters.front() {
405 trace!(node = ?node, "waking next core");
406
407 if let Some(ref waker) = node.as_ref().waker {
408 waker.wake_by_ref();
409 }
410 }
411 }
412 }
413}
414
415#[derive(Debug)]
416struct State<'a> {
417 /// Original state.
418 state: usize,
419 /// If the core is available or not.
420 available: bool,
421 /// The balance.
422 balance: usize,
423 /// The rate limiter the state is associated with.
424 lim: &'a RateLimiter,
425}
426
427impl<'a> State<'a> {
428 fn try_fast_path(mut self, permits: usize) -> bool {
429 let mut attempts = 0;
430
431 // Fast path where we just try to nab any available permit without
432 // locking.
433 //
434 // We do have to race against anyone else grabbing permits here when
435 // storing the state back.
436 while self.balance >= permits {
437 // Abandon fast path if we've tried too many times.
438 if attempts == BUMP_LIMIT {
439 break;
440 }
441
442 self.balance -= permits;
443
444 if let Err(new_state) = self.try_save() {
445 self = new_state;
446 attempts += 1;
447 continue;
448 }
449
450 return true;
451 }
452
453 false
454 }
455
456 /// Add tokens and release any pending tasks.
457 #[cfg_attr(
458 feature = "tracing",
459 tracing::instrument(skip(self, critical, f), level = "trace")
460 )]
461 #[inline]
462 fn add_tokens<F, O>(&mut self, critical: &mut Guard<'_>, tokens: usize, f: F) -> O
463 where
464 F: FnOnce(&mut Guard<'_>, &mut State) -> O,
465 {
466 if tokens > 0 {
467 debug_assert!(
468 tokens <= MAX_BALANCE,
469 "Additional tokens {} must be less than {}",
470 tokens,
471 MAX_BALANCE
472 );
473
474 self.balance = (self.balance + tokens).min(self.lim.max);
475 drain_wait_queue(critical, self);
476 let output = f(critical, self);
477 return output;
478 }
479
480 f(critical, self)
481 }
482
483 #[inline]
484 fn decode(state: usize, lim: &'a RateLimiter) -> Self {
485 State {
486 state,
487 available: state & 1 == 1,
488 balance: state >> 1,
489 lim,
490 }
491 }
492
493 #[inline]
494 fn encode(&self) -> usize {
495 (self.balance << 1) | usize::from(self.available)
496 }
497
498 /// Try to save the state, but only succeed if it hasn't been modified.
499 #[inline]
500 fn try_save(self) -> Result<(), Self> {
501 let this = ManuallyDrop::new(self);
502
503 match this.lim.state.compare_exchange(
504 this.state,
505 this.encode(),
506 Ordering::Release,
507 Ordering::Relaxed,
508 ) {
509 Ok(_) => Ok(()),
510 Err(state) => Err(State::decode(state, this.lim)),
511 }
512 }
513}
514
515impl Drop for State<'_> {
516 #[inline]
517 fn drop(&mut self) {
518 self.lim.state.store(self.encode(), Ordering::Release);
519 }
520}
521
522/// A token-bucket rate limiter.
523pub struct RateLimiter {
524 /// Tokens to add every `per` duration.
525 refill: usize,
526 /// Interval in milliseconds to add tokens.
527 interval: Duration,
528 /// Max number of tokens associated with the rate limiter.
529 max: usize,
530 /// If the rate limiter is fair or not.
531 fair: bool,
532 /// The state of the rate limiter.
533 state: AtomicUsize,
534 /// Critical state of the rate limiter.
535 critical: Mutex<Critical>,
536}
537
538impl RateLimiter {
539 /// Construct a new [`Builder`] for a [`RateLimiter`].
540 ///
541 /// # Examples
542 ///
543 /// ```
544 /// use leaky_bucket::RateLimiter;
545 /// use tokio::time::Duration;
546 ///
547 /// let limiter = RateLimiter::builder()
548 /// .initial(100)
549 /// .refill(100)
550 /// .max(1000)
551 /// .interval(Duration::from_millis(250))
552 /// .fair(false)
553 /// .build();
554 /// ```
555 pub fn builder() -> Builder {
556 Builder::default()
557 }
558
559 /// Get the refill amount of this rate limiter as set through
560 /// [`Builder::refill`].
561 ///
562 /// # Examples
563 ///
564 /// ```
565 /// use leaky_bucket::RateLimiter;
566 ///
567 /// let limiter = RateLimiter::builder()
568 /// .refill(1024)
569 /// .build();
570 ///
571 /// assert_eq!(limiter.refill(), 1024);
572 /// ```
573 pub fn refill(&self) -> usize {
574 self.refill
575 }
576
577 /// Get the refill interval of this rate limiter as set through
578 /// [`Builder::interval`].
579 ///
580 /// # Examples
581 ///
582 /// ```
583 /// use leaky_bucket::RateLimiter;
584 /// use tokio::time::Duration;
585 ///
586 /// let limiter = RateLimiter::builder()
587 /// .interval(Duration::from_millis(1000))
588 /// .build();
589 ///
590 /// assert_eq!(limiter.interval(), Duration::from_millis(1000));
591 /// ```
592 pub fn interval(&self) -> Duration {
593 self.interval
594 }
595
596 /// Get the max value of this rate limiter as set through [`Builder::max`].
597 ///
598 /// # Examples
599 ///
600 /// ```
601 /// use leaky_bucket::RateLimiter;
602 ///
603 /// let limiter = RateLimiter::builder()
604 /// .max(1024)
605 /// .build();
606 ///
607 /// assert_eq!(limiter.max(), 1024);
608 /// ```
609 pub fn max(&self) -> usize {
610 self.max
611 }
612
613 /// Test if the current rate limiter is fair as specified through
614 /// [`Builder::fair`].
615 ///
616 /// # Examples
617 ///
618 /// ```
619 /// use leaky_bucket::RateLimiter;
620 ///
621 /// let limiter = RateLimiter::builder()
622 /// .fair(true)
623 /// .build();
624 ///
625 /// assert_eq!(limiter.is_fair(), true);
626 /// ```
627 pub fn is_fair(&self) -> bool {
628 self.fair
629 }
630
631 /// Get the current token balance.
632 ///
633 /// This indicates how many tokens can be requested without blocking.
634 ///
635 /// # Examples
636 ///
637 /// ```
638 /// use leaky_bucket::RateLimiter;
639 ///
640 /// # #[tokio::main(flavor="current_thread", start_paused=true)] async fn main() {
641 /// let limiter = RateLimiter::builder()
642 /// .initial(100)
643 /// .build();
644 ///
645 /// assert_eq!(limiter.balance(), 100);
646 /// limiter.acquire(10).await;
647 /// assert_eq!(limiter.balance(), 90);
648 /// # }
649 /// ```
650 pub fn balance(&self) -> usize {
651 self.state.load(Ordering::Acquire) >> 1
652 }
653
654 /// Acquire a single permit.
655 ///
656 /// # Examples
657 ///
658 /// ```
659 /// use leaky_bucket::RateLimiter;
660 ///
661 /// # #[tokio::main(flavor="current_thread", start_paused=true)] async fn main() {
662 /// let limiter = RateLimiter::builder()
663 /// .initial(10)
664 /// .build();
665 ///
666 /// limiter.acquire_one().await;
667 /// # }
668 /// ```
669 pub fn acquire_one(&self) -> Acquire<'_> {
670 self.acquire(1)
671 }
672
673 /// Acquire the given number of permits, suspending the current task until
674 /// they are available.
675 ///
676 /// If zero permits are specified, this function never suspends the current
677 /// task.
678 ///
679 /// # Examples
680 ///
681 /// ```
682 /// use leaky_bucket::RateLimiter;
683 ///
684 /// # #[tokio::main(flavor="current_thread", start_paused=true)] async fn main() {
685 /// let limiter = RateLimiter::builder()
686 /// .initial(10)
687 /// .build();
688 ///
689 /// limiter.acquire(10).await;
690 /// # }
691 /// ```
692 pub fn acquire(&self, permits: usize) -> Acquire<'_> {
693 Acquire {
694 inner: AcquireFut::new(BorrowedRateLimiter(self), permits),
695 }
696 }
697
698 /// Try to acquire the given number of permits, returning `true` if the
699 /// given number of permits were successfully acquired.
700 ///
701 /// If the scheduler is fair, and there are pending tasks waiting to acquire
702 /// tokens this method will return `false`.
703 ///
704 /// If zero permits are specified, this method returns `true`.
705 ///
706 /// # Examples
707 ///
708 /// ```
709 /// use leaky_bucket::RateLimiter;
710 /// use tokio::time;
711 ///
712 /// # #[tokio::main(flavor="current_thread", start_paused=true)] async fn main() {
713 /// let limiter = RateLimiter::builder().refill(1).initial(1).build();
714 ///
715 /// assert!(limiter.try_acquire(1));
716 /// assert!(!limiter.try_acquire(1));
717 /// assert!(limiter.try_acquire(0));
718 ///
719 /// time::sleep(limiter.interval() * 2).await;
720 ///
721 /// assert!(limiter.try_acquire(1));
722 /// assert!(limiter.try_acquire(1));
723 /// assert!(!limiter.try_acquire(1));
724 /// # }
725 /// ```
726 pub fn try_acquire(&self, permits: usize) -> bool {
727 if self.try_fast_path(permits) {
728 return true;
729 }
730
731 let mut critical = self.lock();
732
733 // Reload the state while we are under the critical lock, this
734 // ensures that the `available` flag is up-to-date since it is only
735 // ever modified while holding the critical lock.
736 let mut state = self.take();
737
738 // The core is *not* available, which also implies that there are tasks
739 // ahead which are busy.
740 if !state.available {
741 return false;
742 }
743
744 let now = Instant::now();
745
746 // Here we try to assume core duty temporarily to see if we can
747 // release a sufficient number of tokens to allow the current task
748 // to proceed.
749 if let Some((tokens, deadline)) = self.calculate_drain(critical.deadline, now) {
750 state.balance = (state.balance + tokens).min(self.max);
751 critical.deadline = deadline;
752 }
753
754 if state.balance >= permits {
755 state.balance -= permits;
756 return true;
757 }
758
759 false
760 }
761
762 /// Acquire a permit using an owned future.
763 ///
764 /// If zero permits are specified, this function never suspends the current
765 /// task.
766 ///
767 /// This required the [`RateLimiter`] to be wrapped inside of an
768 /// [`std::sync::Arc`] but will in contrast permit the acquire operation to
769 /// be owned by another struct making it more suitable for embedding.
770 ///
771 /// # Examples
772 ///
773 /// ```
774 /// use leaky_bucket::RateLimiter;
775 /// use std::sync::Arc;
776 ///
777 /// # #[tokio::main(flavor="current_thread", start_paused=true)] async fn main() {
778 /// let limiter = Arc::new(RateLimiter::builder().initial(10).build());
779 ///
780 /// limiter.acquire_owned(10).await;
781 /// # }
782 /// ```
783 ///
784 /// Example when embedded into another future. This wouldn't be possible
785 /// with [`RateLimiter::acquire`] since it would otherwise hold a reference
786 /// to the corresponding [`RateLimiter`] instance.
787 ///
788 /// ```
789 /// use std::future::Future;
790 /// use std::pin::Pin;
791 /// use std::sync::Arc;
792 /// use std::task::{Context, Poll};
793 ///
794 /// use leaky_bucket::{AcquireOwned, RateLimiter};
795 /// use pin_project::pin_project;
796 ///
797 /// #[pin_project]
798 /// struct MyFuture {
799 /// limiter: Arc<RateLimiter>,
800 /// #[pin]
801 /// acquire: Option<AcquireOwned>,
802 /// }
803 ///
804 /// impl Future for MyFuture {
805 /// type Output = ();
806 ///
807 /// fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
808 /// let mut this = self.project();
809 ///
810 /// loop {
811 /// if let Some(acquire) = this.acquire.as_mut().as_pin_mut() {
812 /// futures::ready!(acquire.poll(cx));
813 /// return Poll::Ready(());
814 /// }
815 ///
816 /// this.acquire.set(Some(this.limiter.clone().acquire_owned(100)));
817 /// }
818 /// }
819 /// }
820 ///
821 /// # #[tokio::main(flavor="current_thread", start_paused=true)] async fn main() {
822 /// let limiter = Arc::new(RateLimiter::builder().initial(100).build());
823 ///
824 /// let future = MyFuture { limiter, acquire: None };
825 /// future.await;
826 /// # }
827 /// ```
828 pub fn acquire_owned(self: Arc<Self>, permits: usize) -> AcquireOwned {
829 AcquireOwned {
830 inner: AcquireFut::new(self, permits),
831 }
832 }
833
834 /// Lock the critical section of the rate limiter and return the associated guard.
835 fn lock(&self) -> Guard<'_> {
836 Guard {
837 critical: self.critical.lock(),
838 }
839 }
840
841 /// Load the current state.
842 fn load(&self) -> State<'_> {
843 State::decode(self.state.load(Ordering::Acquire), self)
844 }
845
846 /// Take the current state, leaving the core state intact.
847 fn take(&self) -> State<'_> {
848 State::decode(self.state.swap(0, Ordering::Acquire), self)
849 }
850
851 /// Try to use fast path.
852 fn try_fast_path(&self, permits: usize) -> bool {
853 if permits == 0 {
854 return true;
855 }
856
857 if self.fair {
858 return false;
859 }
860
861 self.load().try_fast_path(permits)
862 }
863
864 /// Calculate refill amount. Returning a tuple of how much to fill and remaining
865 /// duration to sleep until the next refill time if appropriate.
866 ///
867 /// The maximum number of additional tokens this method will ever return is
868 /// limited to [`MAX_BALANCE`] to ensure that addition with an existing
869 /// balance will never overflow.
870 fn calculate_drain(&self, deadline: Instant, now: Instant) -> Option<(usize, Instant)> {
871 if now < deadline {
872 return None;
873 }
874
875 // Time elapsed in milliseconds since the last deadline.
876 let millis = self.interval.as_millis();
877 let since = now.saturating_duration_since(deadline).as_millis();
878
879 let periods = usize::try_from(since / millis + 1).unwrap_or(usize::MAX);
880
881 let tokens = periods
882 .checked_mul(self.refill)
883 .unwrap_or(MAX_BALANCE)
884 .min(MAX_BALANCE);
885
886 let rem = u64::try_from(since % millis).unwrap_or(u64::MAX);
887
888 // Calculated time remaining until the next deadline.
889 let deadline = now
890 + self
891 .interval
892 .saturating_sub(time::Duration::from_millis(rem));
893
894 Some((tokens, deadline))
895 }
896}
897
898impl fmt::Debug for RateLimiter {
899 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
900 f.debug_struct("RateLimiter")
901 .field("refill", &self.refill)
902 .field("interval", &self.interval)
903 .field("max", &self.max)
904 .field("fair", &self.fair)
905 .finish_non_exhaustive()
906 }
907}
908
909/// Refill the wait queue with the given number of tokens.
910#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, level = "trace"))]
911fn drain_wait_queue(critical: &mut Guard<'_>, state: &mut State<'_>) {
912 trace!(?state, "releasing waiters");
913
914 let mut bump = 0;
915
916 // SAFETY: we're holding the lock guard to all the waiters so we can be
917 // sure that we have exclusive access to the wait queue.
918 unsafe {
919 while state.balance > 0 {
920 let mut node = match critical.waiters.pop_back() {
921 Some(node) => node,
922 None => break,
923 };
924
925 let n = node.as_mut();
926 n.fill(&mut state.balance);
927
928 trace! {
929 ?state,
930 remaining = n.remaining,
931 "filled node",
932 };
933
934 if !n.is_completed() {
935 critical.waiters.push_back(node);
936 break;
937 }
938
939 n.complete.store(true, Ordering::Release);
940
941 if let Some(waker) = n.waker.take() {
942 waker.wake();
943 }
944
945 bump += 1;
946
947 if bump == BUMP_LIMIT {
948 Guard::bump(critical);
949 bump = 0;
950 }
951 }
952 }
953}
954
955// SAFETY: All the internals of acquire is thread safe and correctly
956// synchronized. The embedded waiter queue doesn't have anything inherently
957// unsafe in it.
958unsafe impl Send for RateLimiter {}
959unsafe impl Sync for RateLimiter {}
960
/// A builder for a [`RateLimiter`].
pub struct Builder {
    /// The max number of tokens, if explicitly configured.
    max: Option<usize>,
    /// The initial count of tokens.
    initial: usize,
    /// Tokens to add every `interval`.
    refill: usize,
    /// Interval between token refills.
    interval: Duration,
    /// If the rate limiter is fair or not.
    fair: bool,
}
974
975impl Builder {
976 /// Configure the max number of tokens to use.
977 ///
978 /// If unspecified, this will default to be 10 times the [`refill`] or the
979 /// [`initial`] value, whichever is largest.
980 ///
981 /// The maximum supported balance is limited to [`isize::MAX`].
982 ///
983 /// # Examples
984 ///
985 /// ```
986 /// use leaky_bucket::RateLimiter;
987 ///
988 /// let limiter = RateLimiter::builder()
989 /// .max(10_000)
990 /// .build();
991 /// ```
992 ///
993 /// [`refill`]: Builder::refill
994 /// [`initial`]: Builder::initial
995 pub fn max(&mut self, max: usize) -> &mut Self {
996 self.max = Some(max);
997 self
998 }
999
1000 /// Configure the initial number of tokens to configure. The default value
1001 /// is `0`.
1002 ///
1003 /// # Examples
1004 ///
1005 /// ```
1006 /// use leaky_bucket::RateLimiter;
1007 ///
1008 /// let limiter = RateLimiter::builder()
1009 /// .initial(10)
1010 /// .build();
1011 /// ```
1012 pub fn initial(&mut self, initial: usize) -> &mut Self {
1013 self.initial = initial;
1014 self
1015 }
1016
1017 /// Configure the time duration between which we add [`refill`] number to
1018 /// the bucket rate limiter.
1019 ///
1020 /// This is 100ms by default.
1021 ///
1022 /// # Panics
1023 ///
1024 /// This panics if the provided interval does not fit within the millisecond
1025 /// bounds of a [usize] or is zero.
1026 ///
1027 /// ```should_panic
1028 /// use leaky_bucket::RateLimiter;
1029 /// use tokio::time::Duration;
1030 ///
1031 /// let limiter = RateLimiter::builder()
1032 /// .interval(Duration::from_secs(u64::MAX))
1033 /// .build();
1034 /// ```
1035 ///
1036 /// ```should_panic
1037 /// use leaky_bucket::RateLimiter;
1038 /// use tokio::time::Duration;
1039 ///
1040 /// let limiter = RateLimiter::builder()
1041 /// .interval(Duration::from_millis(0))
1042 /// .build();
1043 /// ```
1044 ///
1045 /// # Examples
1046 ///
1047 /// ```
1048 /// use leaky_bucket::RateLimiter;
1049 /// use tokio::time::Duration;
1050 ///
1051 /// let limiter = RateLimiter::builder()
1052 /// .interval(Duration::from_millis(100))
1053 /// .build();
1054 /// ```
1055 ///
1056 /// [`refill`]: Builder::refill
1057 pub fn interval(&mut self, interval: Duration) -> &mut Self {
1058 assert! {
1059 interval.as_millis() != 0,
1060 "interval must be non-zero",
1061 };
1062 assert! {
1063 u64::try_from(interval.as_millis()).is_ok(),
1064 "interval must fit within a 64-bit integer"
1065 };
1066 self.interval = interval;
1067 self
1068 }
1069
1070 /// The number of tokens to add at each [`interval`] interval. The default
1071 /// value is `1`.
1072 ///
1073 /// # Panics
1074 ///
1075 /// Panics if a refill amount of `0` is specified.
1076 ///
1077 /// # Examples
1078 ///
1079 /// ```
1080 /// use leaky_bucket::RateLimiter;
1081 ///
1082 /// let limiter = RateLimiter::builder()
1083 /// .refill(100)
1084 /// .build();
1085 /// ```
1086 ///
1087 /// [`interval`]: Builder::interval
1088 pub fn refill(&mut self, refill: usize) -> &mut Self {
1089 assert!(refill > 0, "refill amount cannot be zero");
1090 self.refill = refill;
1091 self
1092 }
1093
1094 /// Configure the rate limiter to be fair.
1095 ///
1096 /// Fairness is enabled by deafult.
1097 ///
1098 /// Fairness ensures that tasks make progress in the order that they acquire
1099 /// even when the rate limiter is under contention. An unfair scheduler
1100 /// might have a higher total throughput.
1101 ///
1102 /// Fair scheduling also affects the behavior of
1103 /// [`RateLimiter::try_acquire`] which will return `false` if there are any
1104 /// pending tasks since they should be given priority.
1105 ///
1106 /// # Examples
1107 ///
1108 /// ```
1109 /// use leaky_bucket::RateLimiter;
1110 ///
1111 /// let limiter = RateLimiter::builder()
1112 /// .refill(100)
1113 /// .fair(false)
1114 /// .build();
1115 /// ```
1116 pub fn fair(&mut self, fair: bool) -> &mut Self {
1117 self.fair = fair;
1118 self
1119 }
1120
1121 /// Construct a new [`RateLimiter`].
1122 ///
1123 /// # Examples
1124 ///
1125 /// ```
1126 /// use leaky_bucket::RateLimiter;
1127 /// use tokio::time::Duration;
1128 ///
1129 /// let limiter = RateLimiter::builder()
1130 /// .refill(100)
1131 /// .interval(Duration::from_millis(200))
1132 /// .max(10_000)
1133 /// .build();
1134 /// ```
1135 pub fn build(&self) -> RateLimiter {
1136 let deadline = Instant::now() + self.interval;
1137
1138 let initial = self.initial.min(MAX_BALANCE);
1139 let refill = self.refill.min(MAX_BALANCE);
1140
1141 let max = match self.max {
1142 Some(max) => max.min(MAX_BALANCE),
1143 None => refill
1144 .max(initial)
1145 .saturating_mul(DEFAULT_REFILL_MAX_FACTOR)
1146 .min(MAX_BALANCE),
1147 };
1148
1149 let initial = initial.min(max);
1150
1151 RateLimiter {
1152 refill,
1153 interval: self.interval,
1154 max,
1155 fair: self.fair,
1156 state: AtomicUsize::new(initial << 1 | 1),
1157 critical: Mutex::new(Critical {
1158 waiters: LinkedList::new(),
1159 deadline,
1160 }),
1161 }
1162 }
1163}
1164
1165/// Construct a new builder with default options.
1166///
1167/// # Examples
1168///
1169/// ```
1170/// use leaky_bucket::Builder;
1171///
1172/// let limiter = Builder::default().build();
1173/// ```
1174impl Default for Builder {
1175 fn default() -> Self {
1176 Self {
1177 max: None,
1178 initial: 0,
1179 refill: 1,
1180 interval: Duration::from_millis(100),
1181 fair: true,
1182 }
1183 }
1184}
1185
/// Lifecycle of a single acquire operation.
///
/// A future starts out as `Initial`. On its first poll it either completes
/// straight away, queues itself as `Waiting`, or takes over as the `Core`. It
/// ends up in `Complete`.
#[derive(Clone, Copy, Debug)]
enum AcquireFutState {
    /// Not polled yet; nothing has been linked or acquired.
    Initial,
    /// Queued up and waiting to be released by the core.
    Waiting,
    /// Finished; further polls return `Ready` immediately.
    Complete,
    /// This future is the one currently acting as the core.
    Core,
}
1198
/// Inner state and methods of the acquire.
///
/// Wraps the task node that this acquire links into the rate limiter's wait
/// queue. The node is held in an `UnsafeCell` because it is aliased: mutable
/// access is only handed out through [`AcquireFutInner::get_task`], which
/// requires an exclusive borrow of the critical section. The atomic
/// `complete` flag is read through [`AcquireFutInner::complete`] without the
/// lock.
#[repr(transparent)]
struct AcquireFutInner {
    /// Aliased task state (potentially linked into the wait queue).
    node: UnsafeCell<Node<Task>>,
}
1205
impl AcquireFutInner {
    /// Construct the inner state around a freshly created task node.
    const fn new() -> AcquireFutInner {
        AcquireFutInner {
            node: UnsafeCell::new(Node::new(Task::new())),
        }
    }

    /// Access the completion flag of the embedded task.
    ///
    /// No lock is required, since the flag is an [`AtomicBool`].
    pub fn complete(&self) -> &AtomicBool {
        // SAFETY: Only a shared reference to the atomic `complete` field is
        // produced, and atomics permit concurrent access through shared
        // references.
        unsafe { &*ptr::addr_of!((*self.node.get()).complete) }
    }

    /// Get the underlying task mutably.
    ///
    /// We prove that the caller does indeed have mutable access to the node by
    /// passing in a mutable reference to the critical section. The critical
    /// section is handed back alongside the node so both can be used, and both
    /// borrows are tied to the `'crit` lifetime.
    #[inline]
    pub fn get_task<'crit, C>(
        self: Pin<&'crit mut Self>,
        critical: &'crit mut C,
    ) -> (&'crit mut C, &'crit mut Node<Task>)
    where
        C: IsCritical,
    {
        // SAFETY: The caller holds an exclusive borrow of the critical section,
        // which is what serializes access to the node. `self` is pinned and
        // mutably borrowed. Both returned borrows are bound to `'crit`, so
        // neither can outlive the borrows passed in.
        unsafe { (critical, &mut *self.node.get()) }
    }

    /// Update the waiting state for this acquisition task.
    ///
    /// If the task is not linked yet, it is pushed onto the front of the wait
    /// queue (via `push_task_front`). The stored waker is replaced when there
    /// is none, or when the existing one would not wake the same task as
    /// `waker` (per `Waker::will_wake`).
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(skip(self, critical, waker), level = "trace")
    )]
    fn update(self: Pin<&mut Self>, critical: &mut Guard<'_>, waker: &Waker) {
        let (critical, task) = self.get_task(critical);

        if !task.is_linked() {
            critical.push_task_front(task);
        }

        let new_waker = match task.waker {
            None => true,
            Some(ref w) => !w.will_wake(waker),
        };

        if new_waker {
            trace!("updating waker");
            task.waker = Some(waker.clone());
        }
    }

    /// Ensure that the current core task is correctly linked up if needed.
    ///
    /// With fair scheduling the core must be in the wait queue, so it is
    /// pushed if it isn't linked. With unfair scheduling it is removed if it is
    /// linked.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(skip(self, critical, lim), level = "trace")
    )]
    fn link_core(self: Pin<&mut Self>, critical: &mut Critical, lim: &RateLimiter) {
        let (critical, task) = self.get_task(critical);

        match (lim.fair, task.is_linked()) {
            (true, false) => {
                // Fair scheduling needs to ensure that the core is part of the wait
                // queue, and will be woken up in-order with other tasks.
                critical.push_task(task);
            }
            (false, true) => {
                // Unfair scheduling will not wake the core in order, so
                // don't bother having it linked.
                critical.remove_task(task);
            }
            _ => {}
        }
    }

    /// Release any remaining tokens which are associated with this particular task.
    ///
    /// Unlinks the task if it is still queued. It then returns the permits it
    /// had been granted so far, `permits - task.remaining` (saturating at
    /// zero), through `add_tokens`.
    fn release_remaining(
        self: Pin<&mut Self>,
        critical: &mut Guard<'_>,
        state: &mut State<'_>,
        permits: usize,
    ) {
        let (critical, task) = self.get_task(critical);

        if task.is_linked() {
            critical.remove_task(task);
        }

        // Hand back permits which we've acquired so far.
        let release = permits.saturating_sub(task.remaining);
        state.add_tokens(critical, release, |_, _| ());
    }

    /// Drain the given number of tokens through the core. Returns `true` if the
    /// core has been completed.
    ///
    /// On completion the core is released through `critical.release`.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(skip(self, critical), level = "trace")
    )]
    fn drain_core(
        self: Pin<&mut Self>,
        critical: &mut Guard<'_>,
        state: &mut State<'_>,
        tokens: usize,
    ) -> bool {
        let completed = state.add_tokens(critical, tokens, |critical, state| {
            let (_, task) = self.get_task(critical);

            // If the limiter is not fair, then in addition to draining
            // remaining tokens into linked nodes, we drain into ourselves. We
            // fill the current holder of the core (self) last, to ensure that
            // it stays around for as long as possible.
            if !state.lim.fair {
                task.fill(&mut state.balance);
            }

            task.is_completed()
        });

        if completed {
            // Everything was drained, including the current core (if
            // appropriate). So we can release it now.
            critical.release(state);
        }

        completed
    }

    /// Take over as the core and perform any drain that is due at `now`.
    ///
    /// The task is first linked (or unlinked) according to the fairness policy.
    /// If `calculate_drain` reports tokens to drain, the shared deadline is
    /// advanced and the tokens are drained through `drain_core`.
    ///
    /// Returns `true` if this task remains the core and should go on to sleep
    /// until the next deadline. Returns `false` if the drain completed the
    /// task, in which case `drain_core` has already released the core.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(skip(self, critical), level = "trace")
    )]
    fn assume_core(
        mut self: Pin<&mut Self>,
        critical: &mut Guard<'_>,
        state: &mut State<'_>,
        now: Instant,
    ) -> bool {
        self.as_mut().link_core(critical, state.lim);

        let (tokens, deadline) = match state.lim.calculate_drain(critical.deadline, now) {
            Some(tokens) => tokens,
            None => return true,
        };

        // It is appropriate to update the deadline.
        critical.deadline = deadline;
        !self.drain_core(critical, state, tokens)
    }
}
1366
pin_project! {
    /// The future associated with acquiring permits from a rate limiter using
    /// [`RateLimiter::acquire`].
    ///
    /// This future is `!Unpin` and has to be pinned before it is polled. See
    /// [`AcquireOwned`] for a variant that holds the limiter through an
    /// [`Arc`] instead of a borrow.
    #[project(!Unpin)]
    pub struct Acquire<'a> {
        // Shared acquire state machine, driven through a borrowed limiter.
        #[pin]
        inner: AcquireFut<BorrowedRateLimiter<'a>>,
    }
}
1376
1377impl Acquire<'_> {
1378 /// Test if this acquire task is currently coordinating the rate limiter.
1379 ///
1380 /// # Examples
1381 ///
1382 /// ```
1383 /// use leaky_bucket::RateLimiter;
1384 /// use std::future::Future;
1385 /// use std::sync::Arc;
1386 /// use std::task::Context;
1387 ///
1388 /// struct Waker;
1389 /// # impl std::task::Wake for Waker { fn wake(self: Arc<Self>) { } }
1390 ///
1391 /// # #[tokio::main(flavor="current_thread", start_paused=true)] async fn main() {
1392 /// let limiter = RateLimiter::builder().build();
1393 ///
1394 /// let waker = Arc::new(Waker).into();
1395 /// let mut cx = Context::from_waker(&waker);
1396 ///
1397 /// let a1 = limiter.acquire(1);
1398 /// tokio::pin!(a1);
1399 ///
1400 /// assert!(!a1.is_core());
1401 /// assert!(a1.as_mut().poll(&mut cx).is_pending());
1402 /// assert!(a1.is_core());
1403 ///
1404 /// a1.as_mut().await;
1405 ///
1406 /// // After completion this is no longer a core.
1407 /// assert!(!a1.is_core());
1408 /// # }
1409 /// ```
1410 pub fn is_core(&self) -> bool {
1411 self.inner.is_core()
1412 }
1413}
1414
1415impl Future for Acquire<'_> {
1416 type Output = ();
1417
1418 #[inline]
1419 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
1420 self.project().inner.poll(cx)
1421 }
1422}
1423
pin_project! {
    /// The future associated with acquiring permits from a rate limiter using
    /// [`RateLimiter::acquire_owned`].
    ///
    /// This future holds the rate limiter through an [`Arc`], so it does not
    /// borrow from the caller. It is `!Unpin` and has to be pinned before it is
    /// polled.
    #[project(!Unpin)]
    pub struct AcquireOwned {
        // Shared acquire state machine, driven through a reference-counted limiter.
        #[pin]
        inner: AcquireFut<Arc<RateLimiter>>,
    }
}
1433
1434impl AcquireOwned {
1435 /// Test if this acquire task is currently coordinating the rate limiter.
1436 ///
1437 /// # Examples
1438 ///
1439 /// ```
1440 /// use leaky_bucket::RateLimiter;
1441 /// use std::future::Future;
1442 /// use std::sync::Arc;
1443 /// use std::task::Context;
1444 ///
1445 /// struct Waker;
1446 /// # impl std::task::Wake for Waker { fn wake(self: Arc<Self>) { } }
1447 ///
1448 /// # #[tokio::main(flavor="current_thread", start_paused=true)] async fn main() {
1449 /// let limiter = Arc::new(RateLimiter::builder().build());
1450 ///
1451 /// let waker = Arc::new(Waker).into();
1452 /// let mut cx = Context::from_waker(&waker);
1453 ///
1454 /// let a1 = limiter.acquire_owned(1);
1455 /// tokio::pin!(a1);
1456 ///
1457 /// assert!(!a1.is_core());
1458 /// assert!(a1.as_mut().poll(&mut cx).is_pending());
1459 /// assert!(a1.is_core());
1460 ///
1461 /// a1.as_mut().await;
1462 ///
1463 /// // After completion this is no longer a core.
1464 /// assert!(!a1.is_core());
1465 /// # }
1466 /// ```
1467 pub fn is_core(&self) -> bool {
1468 self.inner.is_core()
1469 }
1470}
1471
1472impl Future for AcquireOwned {
1473 type Output = ();
1474
1475 #[inline]
1476 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
1477 self.project().inner.poll(cx)
1478 }
1479}
1480
1481pin_project! {
1482 #[project(!Unpin)]
1483 #[project = AcquireFutProj]
1484 struct AcquireFut<T>
1485 where
1486 T: Deref<Target = RateLimiter>,
1487 {
1488 lim: T,
1489 permits: usize,
1490 state: AcquireFutState,
1491 #[pin]
1492 sleep: Option<time::Sleep>,
1493 #[pin]
1494 inner: AcquireFutInner,
1495 }
1496
1497 impl<T> PinnedDrop for AcquireFut<T>
1498 where
1499 T: Deref<Target = RateLimiter>,
1500 {
1501 fn drop(this: Pin<&mut Self>) {
1502 let AcquireFutProj { lim, permits, state, inner, .. } = this.project();
1503
1504 let is_core = match *state {
1505 AcquireFutState::Waiting => false,
1506 AcquireFutState::Core { .. } => true,
1507 _ => return,
1508 };
1509
1510 let mut critical = lim.lock();
1511 let mut s = lim.take();
1512 inner.release_remaining(&mut critical, &mut s, *permits);
1513
1514 if is_core {
1515 critical.release(&mut s);
1516 }
1517
1518 *state = AcquireFutState::Complete;
1519 }
1520 }
1521}
1522
1523impl<T> AcquireFut<T>
1524where
1525 T: Deref<Target = RateLimiter>,
1526{
1527 #[inline]
1528 const fn new(lim: T, permits: usize) -> Self {
1529 Self {
1530 lim,
1531 permits,
1532 state: AcquireFutState::Initial,
1533 sleep: None,
1534 inner: AcquireFutInner::new(),
1535 }
1536 }
1537
1538 fn is_core(&self) -> bool {
1539 matches!(&self.state, AcquireFutState::Core { .. })
1540 }
1541}
1542
// SAFETY: These impls are written by hand because the wait-queue node embedded
// in `AcquireFutInner` sits in an `UnsafeCell` (and presumably holds
// intrusive-list pointers). Mutable access to that node is only handed out by
// `AcquireFutInner::get_task`, which requires an exclusive borrow of the
// critical section protected by the rate limiter's lock. The one part of it
// read without the lock, the `complete` flag, is an `AtomicBool`. The `T`
// bounds ensure the rate limiter handle itself may be sent or shared.
unsafe impl<T> Send for AcquireFut<T> where T: Send + Deref<Target = RateLimiter> {}
unsafe impl<T> Sync for AcquireFut<T> where T: Sync + Deref<Target = RateLimiter> {}
1548
impl<T> Future for AcquireFut<T>
where
    T: Deref<Target = RateLimiter>,
{
    type Output = ();

    /// Drive the acquire state machine.
    ///
    /// - `Initial`: try the fast path. Otherwise take the critical lock,
    ///   perform any overdue drain inline and try to fill this task's request.
    ///   If that is not enough, either queue up as a waiter or take over as the
    ///   core.
    /// - `Waiting`: complete if the task has been flagged as complete.
    ///   Otherwise either take over as the core or refresh the stored waker.
    /// - `Core`: re-take the critical lock and continue as the core.
    ///
    /// The core sleeps until the shared deadline. It then schedules the next
    /// deadline and drains `refill` tokens. Finally it either completes or wakes
    /// itself to go around again.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let AcquireFutProj {
            lim,
            permits,
            state,
            mut sleep,
            inner: mut internal,
            ..
        } = self.project();

        // Hold onto the critical lock for core operations, but only acquire it
        // when strictly necessary.
        let mut critical;

        // Shared state.
        //
        // Once we are holding onto the critical lock, we take the entire state
        // to ensure that any fast-path negotiators do not observe any available
        // permits while potential core work is ongoing.
        let mut s;

        // Hold onto any call to `Instant::now` which we might perform, so we
        // don't have to get the current time multiple times.
        let outer_now;

        match *state {
            AcquireFutState::Complete => {
                return Poll::Ready(());
            }
            AcquireFutState::Initial => {
                // Try to opportunistically acquire the permits through the
                // known atomic state, without taking the critical lock.
                // NOTE(review): this is called regardless of fairness here;
                // presumably `try_fast_path` applies the fairness policy
                // itself — confirm.
                //
                // This is known as the fast path, but it requires the acquire
                // to race against other tasks when storing the state back.
                if lim.try_fast_path(*permits) {
                    *state = AcquireFutState::Complete;
                    return Poll::Ready(());
                }

                critical = lim.lock();
                s = lim.take();

                let now = Instant::now();

                // If we've hit a deadline, calculate the number of tokens
                // to drain and perform it in line here. This is necessary
                // because the core isn't aware of how long we sleep between
                // each acquire, so we need to perform some of the drain
                // work here in order to avoid accruing a debt that needs to
                // be filled later on.
                //
                // If we didn't do this, and the process slept for a long
                // time, the next time a core is acquired it would be very
                // far removed from the expected deadline and has no idea
                // when permits were acquired, so it would over-eagerly
                // release a lot of acquires and accumulate permits.
                //
                // This is tested for in the `test_idle` suite of tests.
                let tokens =
                    if let Some((tokens, deadline)) = lim.calculate_drain(critical.deadline, now) {
                        trace!(tokens, "inline drain");
                        // We pre-emptively update the deadline of the core
                        // since it might bump, and we don't want other
                        // processes to observe that the deadline has been
                        // reached.
                        critical.deadline = deadline;
                        tokens
                    } else {
                        0
                    };

                // Record how many permits this task needs, then fill it from
                // the balance.
                let completed = s.add_tokens(&mut critical, tokens, |critical, s| {
                    let (_, task) = internal.as_mut().get_task(critical);
                    task.remaining = *permits;
                    task.fill(&mut s.balance);
                    task.is_completed()
                });

                if completed {
                    *state = AcquireFutState::Complete;
                    return Poll::Ready(());
                }

                // Try to take over as core. If we're unsuccessful we just
                // ensure that we're linked into the wait queue.
                if !mem::take(&mut s.available) {
                    internal.as_mut().update(&mut critical, cx.waker());
                    *state = AcquireFutState::Waiting;
                    return Poll::Pending;
                }

                // NOTE: This is done in a pinned section, so we know that
                // the linked section stays alive for the duration of this
                // future due to pinning guarantees.
                internal.as_mut().link_core(&mut critical, lim);
                Guard::bump(&mut critical);
                *state = AcquireFutState::Core;
                outer_now = Some(now);
            }
            AcquireFutState::Waiting => {
                // If we are complete, then return as ready.
                //
                // This field is atomic, so we can safely read it under shared
                // access and do not require a lock.
                if internal.complete().load(Ordering::Acquire) {
                    *state = AcquireFutState::Complete;
                    return Poll::Ready(());
                }

                // Note: we need to operate under this lock to ensure that
                // the core acquired here (or elsewhere) observes that the
                // current task has been linked up.
                critical = lim.lock();
                s = lim.take();

                // Try to take over as core. If we're unsuccessful we
                // just ensure that we're linked into the wait queue.
                if !mem::take(&mut s.available) {
                    internal.update(&mut critical, cx.waker());
                    return Poll::Pending;
                }

                let now = Instant::now();

                // This is done in a pinned section, so we know that the linked
                // section stays alive for the duration of this future due to
                // pinning guarantees.
                if !internal.as_mut().assume_core(&mut critical, &mut s, now) {
                    // Marks as completed.
                    *state = AcquireFutState::Complete;
                    return Poll::Ready(());
                }

                Guard::bump(&mut critical);
                *state = AcquireFutState::Core;
                outer_now = Some(now);
            }
            AcquireFutState::Core => {
                critical = lim.lock();
                s = lim.take();
                outer_now = None;
            }
        }

        trace!(until = ?critical.deadline, "taking over core and sleeping");

        // Reuse the existing timer if there is one (resetting it only when the
        // deadline moved), otherwise create it.
        let mut sleep = match sleep.as_mut().as_pin_mut() {
            Some(mut sleep) => {
                if sleep.deadline() != critical.deadline {
                    sleep.as_mut().reset(critical.deadline);
                }

                sleep
            }
            None => {
                sleep.set(Some(time::sleep_until(critical.deadline)));
                sleep.as_mut().as_pin_mut().unwrap()
            }
        };

        if sleep.as_mut().poll(cx).is_pending() {
            return Poll::Pending;
        }

        // The deadline was reached: schedule the next one and drain a refill.
        critical.deadline = outer_now.unwrap_or_else(Instant::now) + lim.interval;

        if internal.drain_core(&mut critical, &mut s, lim.refill) {
            *state = AcquireFutState::Complete;
            return Poll::Ready(());
        }

        // Still the core but not complete: wake ourselves to sleep until the
        // newly scheduled deadline on the next poll.
        cx.waker().wake_by_ref();
        Poll::Pending
    }
}
1731
#[cfg(test)]
mod tests {
    use super::{Acquire, AcquireOwned, RateLimiter};

    /// Compiles only if `T` is both `Send` and `Sync`.
    fn assert_thread_safe<T: Send + Sync>() {}

    #[test]
    fn assert_send_sync() {
        assert_thread_safe::<AcquireOwned>();
        assert_thread_safe::<RateLimiter>();
        assert_thread_safe::<Acquire<'_>>();
    }
}