leaky_bucket/
lib.rs

1//! [<img alt="github" src="https://img.shields.io/badge/github-udoprog/leaky--bucket-8da0cb?style=for-the-badge&logo=github" height="20">](https://github.com/udoprog/leaky-bucket)
2//! [<img alt="crates.io" src="https://img.shields.io/crates/v/leaky-bucket.svg?style=for-the-badge&color=fc8d62&logo=rust" height="20">](https://crates.io/crates/leaky-bucket)
3//! [<img alt="docs.rs" src="https://img.shields.io/badge/docs.rs-leaky--bucket-66c2a5?style=for-the-badge&logoColor=white&logo=data:image/svg+xml;base64,PHN2ZyByb2xlPSJpbWciIHhtbG5zPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyIgdmlld0JveD0iMCAwIDUxMiA1MTIiPjxwYXRoIGZpbGw9IiNmNWY1ZjUiIGQ9Ik00ODguNiAyNTAuMkwzOTIgMjE0VjEwNS41YzAtMTUtOS4zLTI4LjQtMjMuNC0zMy43bC0xMDAtMzcuNWMtOC4xLTMuMS0xNy4xLTMuMS0yNS4zIDBsLTEwMCAzNy41Yy0xNC4xIDUuMy0yMy40IDE4LjctMjMuNCAzMy43VjIxNGwtOTYuNiAzNi4yQzkuMyAyNTUuNSAwIDI2OC45IDAgMjgzLjlWMzk0YzAgMTMuNiA3LjcgMjYuMSAxOS45IDMyLjJsMTAwIDUwYzEwLjEgNS4xIDIyLjEgNS4xIDMyLjIgMGwxMDMuOS01MiAxMDMuOSA1MmMxMC4xIDUuMSAyMi4xIDUuMSAzMi4yIDBsMTAwLTUwYzEyLjItNi4xIDE5LjktMTguNiAxOS45LTMyLjJWMjgzLjljMC0xNS05LjMtMjguNC0yMy40LTMzLjd6TTM1OCAyMTQuOGwtODUgMzEuOXYtNjguMmw4NS0zN3Y3My4zek0xNTQgMTA0LjFsMTAyLTM4LjIgMTAyIDM4LjJ2LjZsLTEwMiA0MS40LTEwMi00MS40di0uNnptODQgMjkxLjFsLTg1IDQyLjV2LTc5LjFsODUtMzguOHY3NS40em0wLTExMmwtMTAyIDQxLjQtMTAyLTQxLjR2LS42bDEwMi0zOC4yIDEwMiAzOC4ydi42em0yNDAgMTEybC04NSA0Mi41di03OS4xbDg1LTM4Ljh2NzUuNHptMC0xMTJsLTEwMiA0MS40LTEwMi00MS40di0uNmwxMDItMzguMiAxMDIgMzguMnYuNnoiPjwvcGF0aD48L3N2Zz4K" height="20">](https://docs.rs/leaky-bucket)
4//!
5//! A token-based rate limiter based on the [leaky bucket] algorithm.
6//!
7//! If the bucket overflows and goes over its max configured capacity, the task
8//! that tried to acquire the tokens will be suspended until the required number
9//! of tokens has been drained from the bucket.
10//!
11//! Since this crate uses timing facilities from tokio it has to be used within
12//! a Tokio runtime with the [`time` feature] enabled.
13//!
14//! This library has some neat features, which includes:
15//!
16//! **Not requiring a background task**. This is usually needed by token bucket
17//! rate limiters to drive progress. Instead, one of the waiting tasks
18//! temporarily assumes the role as coordinator (called the *core*). This
19//! reduces the amount of tasks needing to sleep, which can be a source of
20//! jitter for imprecise sleeping implementations and tight limiters. See below
21//! for more details.
22//!
//! **Dropped tasks** release any resources they've reserved, so that
//! constructing and cancelling asynchronous tasks does not end up consuming
//! wait slots that are never used, which would be the case for cell-based
//! rate limiters.
26//!
27//! <br>
28//!
29//! ## Usage
30//!
31//! The core type is [`RateLimiter`], which allows for limiting the throughput
32//! of a section using its [`acquire`], [`try_acquire`], and [`acquire_one`]
33//! methods.
34//!
35//! The following is a simple example where we wrap requests through a HTTP
36//! `Client`, to ensure that we don't exceed a given limit:
37//!
38//! ```
39//! use leaky_bucket::RateLimiter;
40//! # struct Client;
41//! # impl Client { async fn request<T>(&self, path: &str) -> Result<T> { todo!() } }
42//! # trait DeserializeOwned {}
43//! # impl DeserializeOwned for Vec<Post> {}
44//! # type Result<T> = core::result::Result<T, ()>;
45//!
46//! /// A blog client.
47//! pub struct BlogClient {
48//!     limiter: RateLimiter,
49//!     client: Client,
50//! }
51//!
52//! struct Post {
53//!     // ..
54//! }
55//!
56//! impl BlogClient {
57//!     /// Get all posts from the service.
58//!     pub async fn get_posts(&self) -> Result<Vec<Post>> {
59//!         self.request("posts").await
60//!     }
61//!
62//!     /// Perform a request against the service, limiting requests to abide by a rate limit.
63//!     async fn request<T>(&self, path: &str) -> Result<T>
64//!     where
65//!         T: DeserializeOwned
66//!     {
67//!         // Before we start sending a request, we block on acquiring one token.
68//!         self.limiter.acquire(1).await;
69//!         self.client.request::<T>(path).await
70//!     }
71//! }
72//! ```
73//!
74//! <br>
75//!
76//! ## Implementation details
77//!
78//! Each rate limiter has two acquisition modes. A fast path and a slow path.
79//! The fast path is used if the desired number of tokens are readily available,
80//! and simply involves decrementing the number of tokens available in the
81//! shared pool.
82//!
83//! If the required number of tokens is not available, the task will be forced
84//! to be suspended until the next refill interval. Here one of the acquiring
85//! tasks will switch over to work as a *core*. This is known as *core
86//! switching*.
87//!
88//! ```
89//! use leaky_bucket::RateLimiter;
90//! use tokio::time::Duration;
91//!
92//! # #[tokio::main(flavor="current_thread", start_paused=true)] async fn main() {
93//! let limiter = RateLimiter::builder()
94//!     .initial(10)
95//!     .interval(Duration::from_millis(100))
96//!     .build();
97//!
98//! // This is instantaneous since the rate limiter starts with 10 tokens to
99//! // spare.
100//! limiter.acquire(10).await;
101//!
102//! // This however needs to core switch and wait for a while until the desired
103//! // number of tokens is available.
104//! limiter.acquire(3).await;
105//! # }
106//! ```
107//!
108//! The core is responsible for sleeping for the configured interval so that
109//! more tokens can be added. After which it ensures that any tasks that are
110//! waiting to acquire including itself are appropriately unsuspended.
111//!
112//! On-demand core switching is what allows this rate limiter implementation to
113//! work without a coordinating background thread. But we need to ensure that
//! any asynchronous task that uses [`RateLimiter`] either runs an
//! [`acquire`] call to completion, or is *cancelled* by being dropped.
116//!
117//! If none of these hold, the core might leak and be locked indefinitely
118//! preventing any future use of the rate limiter from making progress. This is
119//! similar to if you would lock an asynchronous [`Mutex`] but never drop its
120//! guard.
121//!
122//! > You can run this example with:
123//! >
124//! > ```sh
125//! > cargo run --example block_forever
126//! > ```
127//!
128//! ```no_run
129//! use std::future::Future;
130//! use std::sync::Arc;
131//! use std::task::Context;
132//!
133//! use leaky_bucket::RateLimiter;
134//!
135//! struct Waker;
136//! # impl std::task::Wake for Waker { fn wake(self: Arc<Self>) { } }
137//!
138//! # #[tokio::main(flavor="current_thread", start_paused=true)] async fn main() {
139//! let limiter = Arc::new(RateLimiter::builder().build());
140//!
141//! let waker = Arc::new(Waker).into();
142//! let mut cx = Context::from_waker(&waker);
143//!
144//! let mut a0 = Box::pin(limiter.acquire(1));
145//! // Poll once to ensure that the core task is assigned.
146//! assert!(a0.as_mut().poll(&mut cx).is_pending());
147//! assert!(a0.is_core());
148//!
149//! // We leak the core task, preventing the rate limiter from making progress
150//! // by assigning new core tasks.
151//! std::mem::forget(a0);
152//!
153//! // Awaiting acquire here would block forever.
154//! // limiter.acquire(1).await;
155//! # }
156//! ```
157//!
158//! <br>
159//!
160//! ## Fairness
161//!
162//! By default [`RateLimiter`] uses a *fair* scheduler. This ensures that the
163//! core task makes progress even if there are many tasks waiting to acquire
164//! tokens. This might cause more core switching, increasing the total work
165//! needed. An unfair scheduler is expected to do a bit less work under
166//! contention. But without fair scheduling some tasks might end up taking
167//! longer to acquire than expected.
168//!
169//! Unfair rate limiters also have access to a fast path for acquiring tokens,
170//! which might further improve throughput.
171//!
172//! This behavior can be tweaked with the [`Builder::fair`] option.
173//!
174//! ```
175//! use leaky_bucket::RateLimiter;
176//!
177//! let limiter = RateLimiter::builder()
178//!     .fair(false)
179//!     .build();
180//! ```
181//!
182//! The `unfair-scheduling` example can showcase this phenomenon.
183//!
184//! ```sh
185//! cargo run --example unfair_scheduling
186//! ```
187//!
188//! ```text
189//! # fair
190//! Max: 1011ms, Total: 1012ms
191//! Timings:
192//!  0: 101ms
193//!  1: 101ms
194//!  2: 101ms
195//!  3: 101ms
196//!  4: 101ms
197//!  ...
198//! # unfair
199//! Max: 1014ms, Total: 1014ms
200//! Timings:
201//!  0: 1014ms
202//!  1: 101ms
203//!  2: 101ms
204//!  3: 101ms
205//!  4: 101ms
206//!  ...
207//! ```
208//!
209//! As can be seen above the first task in the *unfair* scheduler takes longer
210//! to run because it prioritises releasing other tasks waiting to acquire over
211//! itself.
212//!
213//! [`acquire_one`]: https://docs.rs/leaky-bucket/1/leaky_bucket/struct.RateLimiter.html#method.acquire_one
214//! [`acquire`]: https://docs.rs/leaky-bucket/1/leaky_bucket/struct.RateLimiter.html#method.acquire
215//! [`Builder::fair`]: https://docs.rs/leaky-bucket/1/leaky_bucket/struct.Builder.html#method.fair
216//! [`Mutex`]: https://docs.rs/tokio/1/tokio/sync/struct.Mutex.html
217//! [`RateLimiter`]: https://docs.rs/leaky-bucket/1/leaky_bucket/struct.RateLimiter.html
218//! [`time` feature]: https://docs.rs/tokio/1/tokio/#feature-flags
219//! [`try_acquire`]: https://docs.rs/leaky-bucket/1/leaky_bucket/struct.RateLimiter.html#method.try_acquire
220//! [leaky bucket]: https://en.wikipedia.org/wiki/Leaky_bucket
221
222#![no_std]
223#![deny(missing_docs)]
224
225extern crate alloc;
226
227#[macro_use]
228extern crate std;
229
230use core::cell::UnsafeCell;
231use core::convert::TryFrom as _;
232use core::fmt;
233use core::future::Future;
234use core::mem::{self, ManuallyDrop};
235use core::ops::{Deref, DerefMut};
236use core::pin::Pin;
237use core::ptr;
238use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
239use core::task::{Context, Poll, Waker};
240
241use alloc::sync::Arc;
242
243use parking_lot::{Mutex, MutexGuard};
244use pin_project_lite::pin_project;
245use tokio::time::{self, Duration, Instant};
246
// Internal tracing shim: when the `tracing` feature is enabled, `trace!`
// forwards to `tracing::trace!`.
#[cfg(feature = "tracing")]
macro_rules! trace {
    ($($arg:tt)*) => {
        tracing::trace!($($arg)*)
    };
}

// When the `tracing` feature is disabled, `trace!` expands to nothing so the
// call sites compile away entirely.
#[cfg(not(feature = "tracing"))]
macro_rules! trace {
    ($($arg:tt)*) => {};
}
258
259mod linked_list;
260use self::linked_list::{LinkedList, Node};
261
/// Default factor for how to calculate max refill value.
const DEFAULT_REFILL_MAX_FACTOR: usize = 10;

/// Interval to bump the shared mutex guard to allow other parts of the system
/// to make progress. Code which loops should use this number to determine
/// how many times it should loop before calling [Guard::bump].
///
/// If we do not respect this limit we might inadvertently end up starving other
/// tasks from making progress so that they can unblock.
const BUMP_LIMIT: usize = 16;

/// The maximum supported balance.
///
/// Limited to `isize::MAX` because the balance is stored shifted left by one
/// bit inside a `usize` (see [State::encode]).
const MAX_BALANCE: usize = isize::MAX as usize;
275
/// Marker trait which indicates that a type represents a unique held critical section.
trait IsCritical {}
// The raw critical section data, accessible when the lock is already held.
impl IsCritical for Critical {}
// The guard wrapper around the locked critical section.
impl IsCritical for Guard<'_> {}
280
/// Linked task state.
///
/// One of these is embedded in every pending acquire operation and linked into
/// the shared waiter queue.
struct Task {
    /// Remaining tokens that need to be satisfied.
    remaining: usize,
    /// If this node has been released or not. We make this an atomic to permit
    /// access to it without synchronization.
    complete: AtomicBool,
    /// The waker associated with the node.
    waker: Option<Waker>,
}
291
292impl Task {
293    /// Construct a new task state with the given permits remaining.
294    const fn new() -> Self {
295        Self {
296            remaining: 0,
297            complete: AtomicBool::new(false),
298            waker: None,
299        }
300    }
301
302    /// Test if the current node is completed.
303    fn is_completed(&self) -> bool {
304        self.remaining == 0
305    }
306
307    /// Fill the current node from the given pool of tokens and modify it.
308    fn fill(&mut self, current: &mut usize) {
309        let removed = usize::min(self.remaining, *current);
310        self.remaining -= removed;
311        *current -= removed;
312    }
313}
314
/// A borrowed rate limiter.
///
/// Wrapper which lets borrowed and owned (`Arc`) rate limiters be treated
/// uniformly through `Deref` — presumably by the shared acquire future; verify
/// against `AcquireFut`.
struct BorrowedRateLimiter<'a>(&'a RateLimiter);

impl Deref for BorrowedRateLimiter<'_> {
    type Target = RateLimiter;

    #[inline]
    fn deref(&self) -> &RateLimiter {
        self.0
    }
}
326
/// State which is only accessed while holding the rate limiter's mutex.
struct Critical {
    /// Waiter list.
    waiters: LinkedList<Task>,
    /// The deadline for when more tokens can be added.
    deadline: Instant,
}
333
/// Guard over the locked [Critical] section.
#[repr(transparent)]
struct Guard<'a> {
    critical: MutexGuard<'a, Critical>,
}

impl Guard<'_> {
    /// Temporarily yield the underlying mutex to other waiting threads, then
    /// re-acquire it. See [BUMP_LIMIT] for why loops must call this.
    #[inline]
    fn bump(this: &mut Guard<'_>) {
        MutexGuard::bump(&mut this.critical)
    }
}
345
// Transparent access to the guarded critical section.
impl Deref for Guard<'_> {
    type Target = Critical;

    #[inline]
    fn deref(&self) -> &Critical {
        &self.critical
    }
}

impl DerefMut for Guard<'_> {
    #[inline]
    fn deref_mut(&mut self) -> &mut Critical {
        &mut self.critical
    }
}
361
impl Critical {
    /// Link the given task at the front of the waiter queue.
    #[inline]
    fn push_task_front(&mut self, task: &mut Node<Task>) {
        // SAFETY: We both have mutable access to the node being pushed, and
        // mutable access to the critical section through `self`. So we know we
        // have exclusive tampering rights to the waiter queue.
        unsafe {
            self.waiters.push_front(task.into());
        }
    }

    /// Link the given task at the back of the waiter queue.
    #[inline]
    fn push_task(&mut self, task: &mut Node<Task>) {
        // SAFETY: We both have mutable access to the node being pushed, and
        // mutable access to the critical section through `self`. So we know we
        // have exclusive tampering rights to the waiter queue.
        unsafe {
            self.waiters.push_back(task.into());
        }
    }

    /// Unlink the given task from the waiter queue.
    #[inline]
    fn remove_task(&mut self, task: &mut Node<Task>) {
        // SAFETY: We both have mutable access to the node being removed, and
        // mutable access to the critical section through `self`. So we know we
        // have exclusive tampering rights to the waiter queue.
        unsafe {
            self.waiters.remove(task.into());
        }
    }

    /// Release the current core. Beyond this point the current task may no
    /// longer interact exclusively with the core.
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "trace"))]
    fn release(&mut self, state: &mut State<'_>) {
        trace!("releasing core");
        state.available = true;

        // Find another task that might take over as core. Once it has acquired
        // core status it will have to make sure it is no longer linked into the
        // wait queue.
        unsafe {
            if let Some(node) = self.waiters.front() {
                trace!(node = ?node, "waking next core");

                // Only wake; the woken task claims core status itself once it
                // observes the `available` flag under the critical lock.
                if let Some(ref waker) = node.as_ref().waker {
                    waker.wake_by_ref();
                }
            }
        }
    }
}
414
/// Decoded snapshot of the packed atomic state of a [RateLimiter].
///
/// Writes itself back to the limiter's atomic on drop (see the `Drop` impl),
/// unless consumed through [State::try_save].
#[derive(Debug)]
struct State<'a> {
    /// Original state.
    state: usize,
    /// If the core is available or not.
    available: bool,
    /// The balance.
    balance: usize,
    /// The rate limiter the state is associated with.
    lim: &'a RateLimiter,
}
426
impl<'a> State<'a> {
    /// Try to acquire `permits` through the lock-free fast path, retrying the
    /// compare-and-swap up to [BUMP_LIMIT] times before giving up.
    ///
    /// NOTE(review): on the failure path `self` is dropped, which writes the
    /// last observed snapshot back through the `Drop` impl — confirm this is
    /// benign with respect to concurrent writers in the callers of `load`.
    fn try_fast_path(mut self, permits: usize) -> bool {
        let mut attempts = 0;

        // Fast path where we just try to nab any available permit without
        // locking.
        //
        // We do have to race against anyone else grabbing permits here when
        // storing the state back.
        while self.balance >= permits {
            // Abandon fast path if we've tried too many times.
            if attempts == BUMP_LIMIT {
                break;
            }

            self.balance -= permits;

            // The CAS lost the race; `new_state` is a fresh snapshot, so retry
            // with the up-to-date balance.
            if let Err(new_state) = self.try_save() {
                self = new_state;
                attempts += 1;
                continue;
            }

            return true;
        }

        false
    }

    /// Add tokens and release any pending tasks.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(skip(self, critical, f), level = "trace")
    )]
    #[inline]
    fn add_tokens<F, O>(&mut self, critical: &mut Guard<'_>, tokens: usize, f: F) -> O
    where
        F: FnOnce(&mut Guard<'_>, &mut State) -> O,
    {
        if tokens > 0 {
            debug_assert!(
                tokens <= MAX_BALANCE,
                "Additional tokens {} must be less than {}",
                tokens,
                MAX_BALANCE
            );

            // Cap at the configured maximum, then hand out tokens to queued
            // waiters before invoking the continuation.
            self.balance = (self.balance + tokens).min(self.lim.max);
            drain_wait_queue(critical, self);
            let output = f(critical, self);
            return output;
        }

        f(critical, self)
    }

    /// Unpack a raw state word: bit 0 is the core-availability flag, the
    /// remaining bits are the balance.
    #[inline]
    fn decode(state: usize, lim: &'a RateLimiter) -> Self {
        State {
            state,
            available: state & 1 == 1,
            balance: state >> 1,
            lim,
        }
    }

    /// Pack the balance and availability flag back into a single state word.
    #[inline]
    fn encode(&self) -> usize {
        (self.balance << 1) | usize::from(self.available)
    }

    /// Try to save the state, but only succeed if it hasn't been modified.
    #[inline]
    fn try_save(self) -> Result<(), Self> {
        // ManuallyDrop keeps `Drop` from unconditionally storing the state;
        // this method publishes it only through the compare-exchange below.
        let this = ManuallyDrop::new(self);

        match this.lim.state.compare_exchange(
            this.state,
            this.encode(),
            Ordering::Release,
            Ordering::Relaxed,
        ) {
            Ok(_) => Ok(()),
            // Hand back a snapshot decoded from the value that won the race.
            Err(state) => Err(State::decode(state, this.lim)),
        }
    }
}
514
impl Drop for State<'_> {
    /// Publish the (possibly modified) snapshot back to the rate limiter's
    /// packed atomic state.
    #[inline]
    fn drop(&mut self) {
        self.lim.state.store(self.encode(), Ordering::Release);
    }
}
521
/// A token-bucket rate limiter.
pub struct RateLimiter {
    /// Tokens to add every `interval` period.
    refill: usize,
    /// Interval at which tokens are added.
    interval: Duration,
    /// Max number of tokens associated with the rate limiter.
    max: usize,
    /// If the rate limiter is fair or not.
    fair: bool,
    /// The state of the rate limiter, packing the balance and the
    /// core-availability flag into one word (see [State::encode]).
    state: AtomicUsize,
    /// Critical state of the rate limiter, guarded by a mutex.
    critical: Mutex<Critical>,
}
537
538impl RateLimiter {
    /// Construct a new [`Builder`] for a [`RateLimiter`].
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    /// use tokio::time::Duration;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .initial(100)
    ///     .refill(100)
    ///     .max(1000)
    ///     .interval(Duration::from_millis(250))
    ///     .fair(false)
    ///     .build();
    /// ```
    pub fn builder() -> Builder {
        Builder::default()
    }

    /// Get the refill amount of this rate limiter as set through
    /// [`Builder::refill`].
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .refill(1024)
    ///     .build();
    ///
    /// assert_eq!(limiter.refill(), 1024);
    /// ```
    pub fn refill(&self) -> usize {
        self.refill
    }

    /// Get the refill interval of this rate limiter as set through
    /// [`Builder::interval`].
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    /// use tokio::time::Duration;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .interval(Duration::from_millis(1000))
    ///     .build();
    ///
    /// assert_eq!(limiter.interval(), Duration::from_millis(1000));
    /// ```
    pub fn interval(&self) -> Duration {
        self.interval
    }

    /// Get the max value of this rate limiter as set through [`Builder::max`].
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .max(1024)
    ///     .build();
    ///
    /// assert_eq!(limiter.max(), 1024);
    /// ```
    pub fn max(&self) -> usize {
        self.max
    }

    /// Test if the current rate limiter is fair as specified through
    /// [`Builder::fair`].
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .fair(true)
    ///     .build();
    ///
    /// assert_eq!(limiter.is_fair(), true);
    /// ```
    pub fn is_fair(&self) -> bool {
        self.fair
    }
630
    /// Get the current token balance.
    ///
    /// This indicates how many tokens can be requested without blocking.
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    ///
    /// # #[tokio::main(flavor="current_thread", start_paused=true)] async fn main() {
    /// let limiter = RateLimiter::builder()
    ///     .initial(100)
    ///     .build();
    ///
    /// assert_eq!(limiter.balance(), 100);
    /// limiter.acquire(10).await;
    /// assert_eq!(limiter.balance(), 90);
    /// # }
    /// ```
    pub fn balance(&self) -> usize {
        // Bit 0 of the packed state is the core-availability flag; the
        // balance lives in the remaining bits.
        self.state.load(Ordering::Acquire) >> 1
    }
653
    /// Acquire a single permit.
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    ///
    /// # #[tokio::main(flavor="current_thread", start_paused=true)] async fn main() {
    /// let limiter = RateLimiter::builder()
    ///     .initial(10)
    ///     .build();
    ///
    /// limiter.acquire_one().await;
    /// # }
    /// ```
    pub fn acquire_one(&self) -> Acquire<'_> {
        self.acquire(1)
    }

    /// Acquire the given number of permits, suspending the current task until
    /// they are available.
    ///
    /// If zero permits are specified, this function never suspends the current
    /// task.
    ///
    /// The returned future must be run to completion or dropped; leaking it
    /// can lock up the limiter (see the module documentation).
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    ///
    /// # #[tokio::main(flavor="current_thread", start_paused=true)] async fn main() {
    /// let limiter = RateLimiter::builder()
    ///     .initial(10)
    ///     .build();
    ///
    /// limiter.acquire(10).await;
    /// # }
    /// ```
    pub fn acquire(&self, permits: usize) -> Acquire<'_> {
        Acquire {
            inner: AcquireFut::new(BorrowedRateLimiter(self), permits),
        }
    }
697
    /// Try to acquire the given number of permits, returning `true` if the
    /// given number of permits were successfully acquired.
    ///
    /// If the scheduler is fair, and there are pending tasks waiting to acquire
    /// tokens this method will return `false`.
    ///
    /// If zero permits are specified, this method returns `true`.
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    /// use tokio::time;
    ///
    /// # #[tokio::main(flavor="current_thread", start_paused=true)] async fn main() {
    /// let limiter = RateLimiter::builder().refill(1).initial(1).build();
    ///
    /// assert!(limiter.try_acquire(1));
    /// assert!(!limiter.try_acquire(1));
    /// assert!(limiter.try_acquire(0));
    ///
    /// time::sleep(limiter.interval() * 2).await;
    ///
    /// assert!(limiter.try_acquire(1));
    /// assert!(limiter.try_acquire(1));
    /// assert!(!limiter.try_acquire(1));
    /// # }
    /// ```
    pub fn try_acquire(&self, permits: usize) -> bool {
        // Lock-free attempt first (always succeeds for zero permits; skipped
        // entirely for fair limiters).
        if self.try_fast_path(permits) {
            return true;
        }

        let mut critical = self.lock();

        // Reload the state while we are under the critical lock, this
        // ensures that the `available` flag is up-to-date since it is only
        // ever modified while holding the critical lock.
        let mut state = self.take();

        // The core is *not* available, which also implies that there are tasks
        // ahead which are busy.
        if !state.available {
            // Dropping `state` writes the taken snapshot back to the atomic.
            return false;
        }

        let now = Instant::now();

        // Here we try to assume core duty temporarily to see if we can
        // release a sufficient number of tokens to allow the current task
        // to proceed.
        if let Some((tokens, deadline)) = self.calculate_drain(critical.deadline, now) {
            state.balance = (state.balance + tokens).min(self.max);
            critical.deadline = deadline;
        }

        if state.balance >= permits {
            state.balance -= permits;
            // The decremented balance is published when `state` drops.
            return true;
        }

        false
    }
761
    /// Acquire a permit using an owned future.
    ///
    /// If zero permits are specified, this function never suspends the current
    /// task.
    ///
    /// This requires the [`RateLimiter`] to be wrapped inside of an
    /// [`std::sync::Arc`] but will in contrast permit the acquire operation to
    /// be owned by another struct making it more suitable for embedding.
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    /// use std::sync::Arc;
    ///
    /// # #[tokio::main(flavor="current_thread", start_paused=true)] async fn main() {
    /// let limiter = Arc::new(RateLimiter::builder().initial(10).build());
    ///
    /// limiter.acquire_owned(10).await;
    /// # }
    /// ```
    ///
    /// Example when embedded into another future. This wouldn't be possible
    /// with [`RateLimiter::acquire`] since it would otherwise hold a reference
    /// to the corresponding [`RateLimiter`] instance.
    ///
    /// ```
    /// use std::future::Future;
    /// use std::pin::Pin;
    /// use std::sync::Arc;
    /// use std::task::{Context, Poll};
    ///
    /// use leaky_bucket::{AcquireOwned, RateLimiter};
    /// use pin_project::pin_project;
    ///
    /// #[pin_project]
    /// struct MyFuture {
    ///     limiter: Arc<RateLimiter>,
    ///     #[pin]
    ///     acquire: Option<AcquireOwned>,
    /// }
    ///
    /// impl Future for MyFuture {
    ///     type Output = ();
    ///
    ///     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
    ///         let mut this = self.project();
    ///
    ///         loop {
    ///             if let Some(acquire) = this.acquire.as_mut().as_pin_mut() {
    ///                 futures::ready!(acquire.poll(cx));
    ///                 return Poll::Ready(());
    ///             }
    ///
    ///             this.acquire.set(Some(this.limiter.clone().acquire_owned(100)));
    ///         }
    ///     }
    /// }
    ///
    /// # #[tokio::main(flavor="current_thread", start_paused=true)] async fn main() {
    /// let limiter = Arc::new(RateLimiter::builder().initial(100).build());
    ///
    /// let future = MyFuture { limiter, acquire: None };
    /// future.await;
    /// # }
    /// ```
    pub fn acquire_owned(self: Arc<Self>, permits: usize) -> AcquireOwned {
        AcquireOwned {
            inner: AcquireFut::new(self, permits),
        }
    }
833
    /// Lock the critical section of the rate limiter and return the associated guard.
    fn lock(&self) -> Guard<'_> {
        Guard {
            critical: self.critical.lock(),
        }
    }

    /// Load a snapshot of the current state without modifying the atomic.
    fn load(&self) -> State<'_> {
        State::decode(self.state.load(Ordering::Acquire), self)
    }

    /// Take the current state by swapping in an empty (zero) state word. The
    /// taken state is written back when the returned [State] is dropped or
    /// saved.
    fn take(&self) -> State<'_> {
        State::decode(self.state.swap(0, Ordering::Acquire), self)
    }
850
851    /// Try to use fast path.
852    fn try_fast_path(&self, permits: usize) -> bool {
853        if permits == 0 {
854            return true;
855        }
856
857        if self.fair {
858            return false;
859        }
860
861        self.load().try_fast_path(permits)
862    }
863
    /// Calculate refill amount. Returning a tuple of how much to fill and remaining
    /// duration to sleep until the next refill time if appropriate.
    ///
    /// The maximum number of additional tokens this method will ever return is
    /// limited to [`MAX_BALANCE`] to ensure that addition with an existing
    /// balance will never overflow.
    ///
    /// NOTE(review): divides by `interval` in milliseconds, so this assumes the
    /// builder guarantees a non-zero interval — verify against `Builder`.
    fn calculate_drain(&self, deadline: Instant, now: Instant) -> Option<(usize, Instant)> {
        // The deadline has not yet been reached; nothing to drain.
        if now < deadline {
            return None;
        }

        // Time elapsed in milliseconds since the last deadline.
        let millis = self.interval.as_millis();
        let since = now.saturating_duration_since(deadline).as_millis();

        // Number of whole refill periods that have elapsed, counting the one
        // that ended at `deadline` itself (hence the `+ 1`).
        let periods = usize::try_from(since / millis + 1).unwrap_or(usize::MAX);

        // Saturate at MAX_BALANCE so a later addition cannot overflow.
        let tokens = periods
            .checked_mul(self.refill)
            .unwrap_or(MAX_BALANCE)
            .min(MAX_BALANCE);

        let rem = u64::try_from(since % millis).unwrap_or(u64::MAX);

        // Calculated time remaining until the next deadline.
        let deadline = now
            + self
                .interval
                .saturating_sub(time::Duration::from_millis(rem));

        Some((tokens, deadline))
    }
896}
897
impl fmt::Debug for RateLimiter {
    /// Only the configuration is shown; the live atomic state and waiter
    /// queue are deliberately omitted (`finish_non_exhaustive`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("RateLimiter")
            .field("refill", &self.refill)
            .field("interval", &self.interval)
            .field("max", &self.max)
            .field("fair", &self.fair)
            .finish_non_exhaustive()
    }
}
908
/// Refill the wait queue with the given number of tokens, waking every task
/// whose request is fully satisfied.
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, level = "trace"))]
fn drain_wait_queue(critical: &mut Guard<'_>, state: &mut State<'_>) {
    trace!(?state, "releasing waiters");

    let mut bump = 0;

    // SAFETY: we're holding the lock guard to all the waiters so we can be
    // sure that we have exclusive access to the wait queue.
    unsafe {
        while state.balance > 0 {
            let mut node = match critical.waiters.pop_back() {
                Some(node) => node,
                None => break,
            };

            let n = node.as_mut();
            n.fill(&mut state.balance);

            trace! {
                ?state,
                remaining = n.remaining,
                "filled node",
            };

            // Partially filled: the balance is exhausted, so re-queue the
            // node and stop draining.
            if !n.is_completed() {
                critical.waiters.push_back(node);
                break;
            }

            // Publish completion before waking so the woken task observes it
            // without taking the lock.
            n.complete.store(true, Ordering::Release);

            if let Some(waker) = n.waker.take() {
                waker.wake();
            }

            bump += 1;

            // Periodically yield the critical lock so other threads can make
            // progress; see [BUMP_LIMIT].
            if bump == BUMP_LIMIT {
                Guard::bump(critical);
                bump = 0;
            }
        }
    }
}
954
// SAFETY: All the internals of acquire are thread safe and correctly
// synchronized. The embedded waiter queue doesn't have anything inherently
// unsafe in it.
unsafe impl Send for RateLimiter {}
unsafe impl Sync for RateLimiter {}
960
/// A builder for a [`RateLimiter`].
pub struct Builder {
    /// The max number of tokens; `None` means "derive from refill/initial".
    max: Option<usize>,
    /// The initial count of tokens.
    initial: usize,
    /// Number of tokens to add on each `interval` tick.
    refill: usize,
    /// The duration between refills.
    interval: Duration,
    /// If the rate limiter is fair or not.
    fair: bool,
}
974
impl Builder {
    /// Configure the max number of tokens to use.
    ///
    /// If unspecified, this will default to be 10 times the [`refill`] or the
    /// [`initial`] value, whichever is largest.
    ///
    /// The maximum supported balance is limited to [`isize::MAX`].
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .max(10_000)
    ///     .build();
    /// ```
    ///
    /// [`refill`]: Builder::refill
    /// [`initial`]: Builder::initial
    pub fn max(&mut self, max: usize) -> &mut Self {
        self.max = Some(max);
        self
    }

    /// Configure the initial number of tokens. The default value is `0`.
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .initial(10)
    ///     .build();
    /// ```
    pub fn initial(&mut self, initial: usize) -> &mut Self {
        self.initial = initial;
        self
    }

    /// Configure the time duration between which we add [`refill`] number to
    /// the bucket rate limiter.
    ///
    /// This is 100ms by default.
    ///
    /// # Panics
    ///
    /// This panics if the provided interval does not fit within the millisecond
    /// bounds of a [`u64`] or is zero.
    ///
    /// ```should_panic
    /// use leaky_bucket::RateLimiter;
    /// use tokio::time::Duration;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .interval(Duration::from_secs(u64::MAX))
    ///     .build();
    /// ```
    ///
    /// ```should_panic
    /// use leaky_bucket::RateLimiter;
    /// use tokio::time::Duration;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .interval(Duration::from_millis(0))
    ///     .build();
    /// ```
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    /// use tokio::time::Duration;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .interval(Duration::from_millis(100))
    ///     .build();
    /// ```
    ///
    /// [`refill`]: Builder::refill
    pub fn interval(&mut self, interval: Duration) -> &mut Self {
        assert! {
            interval.as_millis() != 0,
            "interval must be non-zero",
        };
        assert! {
            u64::try_from(interval.as_millis()).is_ok(),
            "interval must fit within a 64-bit integer"
        };
        self.interval = interval;
        self
    }

    /// The number of tokens to add at each [`interval`] interval. The default
    /// value is `1`.
    ///
    /// # Panics
    ///
    /// Panics if a refill amount of `0` is specified.
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .refill(100)
    ///     .build();
    /// ```
    ///
    /// [`interval`]: Builder::interval
    pub fn refill(&mut self, refill: usize) -> &mut Self {
        assert!(refill > 0, "refill amount cannot be zero");
        self.refill = refill;
        self
    }

    /// Configure the rate limiter to be fair.
    ///
    /// Fairness is enabled by default.
    ///
    /// Fairness ensures that tasks make progress in the order that they acquire
    /// even when the rate limiter is under contention. An unfair scheduler
    /// might have a higher total throughput.
    ///
    /// Fair scheduling also affects the behavior of
    /// [`RateLimiter::try_acquire`] which will return `false` if there are any
    /// pending tasks since they should be given priority.
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .refill(100)
    ///     .fair(false)
    ///     .build();
    /// ```
    pub fn fair(&mut self, fair: bool) -> &mut Self {
        self.fair = fair;
        self
    }

    /// Construct a new [`RateLimiter`].
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    /// use tokio::time::Duration;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .refill(100)
    ///     .interval(Duration::from_millis(200))
    ///     .max(10_000)
    ///     .build();
    /// ```
    pub fn build(&self) -> RateLimiter {
        // The first refill tick happens one full interval from now.
        let deadline = Instant::now() + self.interval;

        // Clamp configured values to the maximum representable balance.
        let initial = self.initial.min(MAX_BALANCE);
        let refill = self.refill.min(MAX_BALANCE);

        // Derive the max from refill/initial when not configured explicitly.
        let max = match self.max {
            Some(max) => max.min(MAX_BALANCE),
            None => refill
                .max(initial)
                .saturating_mul(DEFAULT_REFILL_MAX_FACTOR)
                .min(MAX_BALANCE),
        };

        // The initial balance can never exceed the configured maximum.
        let initial = initial.min(max);

        RateLimiter {
            refill,
            interval: self.interval,
            max,
            fair: self.fair,
            // The balance is packed into the upper bits; the low bit is
            // presumably the "core available" flag consumed by `poll` via
            // `s.available` — NOTE(review): confirm against `RateLimiter::take`.
            state: AtomicUsize::new(initial << 1 | 1),
            critical: Mutex::new(Critical {
                waiters: LinkedList::new(),
                deadline,
            }),
        }
    }
}
1164
1165/// Construct a new builder with default options.
1166///
1167/// # Examples
1168///
1169/// ```
1170/// use leaky_bucket::Builder;
1171///
1172/// let limiter = Builder::default().build();
1173/// ```
1174impl Default for Builder {
1175    fn default() -> Self {
1176        Self {
1177            max: None,
1178            initial: 0,
1179            refill: 1,
1180            interval: Duration::from_millis(100),
1181            fair: true,
1182        }
1183    }
1184}
1185
/// The state of an acquire operation.
#[derive(Debug, Clone, Copy)]
enum AcquireFutState {
    /// Initial unconfigured state; nothing has been negotiated yet.
    Initial,
    /// The acquire is linked into the wait queue, waiting to be released by
    /// the core.
    Waiting,
    /// The operation is completed.
    Complete,
    /// The task is currently the core, i.e. it drives refill timing on behalf
    /// of all waiters.
    Core,
}
1198
/// Inner state and methods of the acquire.
#[repr(transparent)]
struct AcquireFutInner {
    /// Aliased task state. Wrapped in `UnsafeCell` since the node is shared
    /// with the intrusive wait queue while linked.
    node: UnsafeCell<Node<Task>>,
}
1205
impl AcquireFutInner {
    /// Construct a new inner state holding an unlinked wait-queue node.
    const fn new() -> AcquireFutInner {
        AcquireFutInner {
            node: UnsafeCell::new(Node::new(Task::new())),
        }
    }

    /// Access the completion flag.
    pub fn complete(&self) -> &AtomicBool {
        // SAFETY: This is always safe to access since it's atomic.
        unsafe { &*ptr::addr_of!((*self.node.get()).complete) }
    }

    /// Get the underlying task mutably.
    ///
    /// We prove that the caller does indeed have mutable access to the node by
    /// passing in a mutable reference to the critical section.
    #[inline]
    pub fn get_task<'crit, C>(
        self: Pin<&'crit mut Self>,
        critical: &'crit mut C,
    ) -> (&'crit mut C, &'crit mut Node<Task>)
    where
        C: IsCritical,
    {
        // SAFETY: Caller has exclusive access to the critical section, since
        // it's passed in as a mutable argument. We can also ensure that none of
        // the borrows outlive the provided closure.
        unsafe { (critical, &mut *self.node.get()) }
    }

    /// Update the waiting state for this acquisition task. This might require
    /// that we update the associated waker.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(skip(self, critical, waker), level = "trace")
    )]
    fn update(self: Pin<&mut Self>, critical: &mut Guard<'_>, waker: &Waker) {
        let (critical, task) = self.get_task(critical);

        // Make sure the task is linked into the wait queue before it parks.
        if !task.is_linked() {
            critical.push_task_front(task);
        }

        // Only clone the waker when it actually changed; `will_wake` avoids a
        // redundant clone on repeated polls from the same task.
        let new_waker = match task.waker {
            None => true,
            Some(ref w) => !w.will_wake(waker),
        };

        if new_waker {
            trace!("updating waker");
            task.waker = Some(waker.clone());
        }
    }

    /// Ensure that the current core task is correctly linked up if needed.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(skip(self, critical, lim), level = "trace")
    )]
    fn link_core(self: Pin<&mut Self>, critical: &mut Critical, lim: &RateLimiter) {
        let (critical, task) = self.get_task(critical);

        match (lim.fair, task.is_linked()) {
            (true, false) => {
                // Fair scheduling needs to ensure that the core is part of the wait
                // queue, and will be woken up in-order with other tasks.
                critical.push_task(task);
            }
            (false, true) => {
                // Unfair scheduling will not wake the core in order, so
                // don't bother having it linked.
                critical.remove_task(task);
            }
            _ => {}
        }
    }

    /// Release any remaining tokens which are associated with this particular task.
    fn release_remaining(
        self: Pin<&mut Self>,
        critical: &mut Guard<'_>,
        state: &mut State<'_>,
        permits: usize,
    ) {
        let (critical, task) = self.get_task(critical);

        // Unlink before releasing so the node cannot be filled again.
        if task.is_linked() {
            critical.remove_task(task);
        }

        // Hand back permits which we've acquired so far.
        let release = permits.saturating_sub(task.remaining);
        state.add_tokens(critical, release, |_, _| ());
    }

    /// Drain the given number of tokens through the core. Returns `true` if the
    /// core has been completed.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(skip(self, critical), level = "trace")
    )]
    fn drain_core(
        self: Pin<&mut Self>,
        critical: &mut Guard<'_>,
        state: &mut State<'_>,
        tokens: usize,
    ) -> bool {
        let completed = state.add_tokens(critical, tokens, |critical, state| {
            let (_, task) = self.get_task(critical);

            // If the limiter is not fair, we need to in addition to draining
            // remaining tokens from linked nodes, drain it from ourselves. We
            // fill the current holder of the core last (self). To ensure that
            // it stays around for as long as possible.
            if !state.lim.fair {
                task.fill(&mut state.balance);
            }

            task.is_completed()
        });

        if completed {
            // Everything was drained, including the current core (if
            // appropriate). So we can release it now.
            critical.release(state);
        }

        completed
    }

    /// Assume the current core and calculate how long we must sleep for in
    /// order to do it.
    ///
    /// Returns `true` while the task remains the core; `false` once its own
    /// request has been fully filled.
    ///
    /// # Safety
    ///
    /// This might link the current task into the task queue, so the caller must
    /// ensure that it is pinned.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(skip(self, critical), level = "trace")
    )]
    fn assume_core(
        mut self: Pin<&mut Self>,
        critical: &mut Guard<'_>,
        state: &mut State<'_>,
        now: Instant,
    ) -> bool {
        self.as_mut().link_core(critical, state.lim);

        // If no deadline has been reached there is nothing to drain; simply
        // remain the core and go (back) to sleep.
        let (tokens, deadline) = match state.lim.calculate_drain(critical.deadline, now) {
            Some(tokens) => tokens,
            None => return true,
        };

        // It is appropriate to update the deadline.
        critical.deadline = deadline;
        !self.drain_core(critical, state, tokens)
    }
}
1366
pin_project! {
    /// The future associated with acquiring permits from a rate limiter using
    /// [`RateLimiter::acquire`].
    ///
    /// Internally this wraps the shared acquire future over a borrowed rate
    /// limiter.
    #[project(!Unpin)]
    pub struct Acquire<'a> {
        #[pin]
        inner: AcquireFut<BorrowedRateLimiter<'a>>,
    }
}
1376
impl Acquire<'_> {
    /// Test if this acquire task is currently coordinating the rate limiter.
    ///
    /// The coordinating (core) task is the one that sleeps until the next
    /// refill deadline and distributes new tokens to the other waiters.
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    /// use std::future::Future;
    /// use std::sync::Arc;
    /// use std::task::Context;
    ///
    /// struct Waker;
    /// # impl std::task::Wake for Waker { fn wake(self: Arc<Self>) { } }
    ///
    /// # #[tokio::main(flavor="current_thread", start_paused=true)] async fn main() {
    /// let limiter = RateLimiter::builder().build();
    ///
    /// let waker = Arc::new(Waker).into();
    /// let mut cx = Context::from_waker(&waker);
    ///
    /// let a1 = limiter.acquire(1);
    /// tokio::pin!(a1);
    ///
    /// assert!(!a1.is_core());
    /// assert!(a1.as_mut().poll(&mut cx).is_pending());
    /// assert!(a1.is_core());
    ///
    /// a1.as_mut().await;
    ///
    /// // After completion this is no longer a core.
    /// assert!(!a1.is_core());
    /// # }
    /// ```
    pub fn is_core(&self) -> bool {
        self.inner.is_core()
    }
}
1414
1415impl Future for Acquire<'_> {
1416    type Output = ();
1417
1418    #[inline]
1419    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
1420        self.project().inner.poll(cx)
1421    }
1422}
1423
pin_project! {
    /// The future associated with acquiring permits from a rate limiter using
    /// [`RateLimiter::acquire_owned`].
    ///
    /// Internally this wraps the shared acquire future over an `Arc`-owned
    /// rate limiter, so it is not tied to a borrow.
    #[project(!Unpin)]
    pub struct AcquireOwned {
        #[pin]
        inner: AcquireFut<Arc<RateLimiter>>,
    }
}
1433
impl AcquireOwned {
    /// Test if this acquire task is currently coordinating the rate limiter.
    ///
    /// The coordinating (core) task is the one that sleeps until the next
    /// refill deadline and distributes new tokens to the other waiters.
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    /// use std::future::Future;
    /// use std::sync::Arc;
    /// use std::task::Context;
    ///
    /// struct Waker;
    /// # impl std::task::Wake for Waker { fn wake(self: Arc<Self>) { } }
    ///
    /// # #[tokio::main(flavor="current_thread", start_paused=true)] async fn main() {
    /// let limiter = Arc::new(RateLimiter::builder().build());
    ///
    /// let waker = Arc::new(Waker).into();
    /// let mut cx = Context::from_waker(&waker);
    ///
    /// let a1 = limiter.acquire_owned(1);
    /// tokio::pin!(a1);
    ///
    /// assert!(!a1.is_core());
    /// assert!(a1.as_mut().poll(&mut cx).is_pending());
    /// assert!(a1.is_core());
    ///
    /// a1.as_mut().await;
    ///
    /// // After completion this is no longer a core.
    /// assert!(!a1.is_core());
    /// # }
    /// ```
    pub fn is_core(&self) -> bool {
        self.inner.is_core()
    }
}
1471
1472impl Future for AcquireOwned {
1473    type Output = ();
1474
1475    #[inline]
1476    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
1477        self.project().inner.poll(cx)
1478    }
1479}
1480
pin_project! {
    /// Shared implementation of the acquire future, generic over how the rate
    /// limiter is reached (borrowed or `Arc`-owned).
    #[project(!Unpin)]
    #[project = AcquireFutProj]
    struct AcquireFut<T>
    where
        T: Deref<Target = RateLimiter>,
    {
        /// Handle to the rate limiter being acquired from.
        lim: T,
        /// Total number of permits this future is acquiring.
        permits: usize,
        /// Current state of the acquisition state machine.
        state: AcquireFutState,
        /// Timer used while this task acts as the core; constructed lazily.
        #[pin]
        sleep: Option<time::Sleep>,
        /// The intrusive wait-queue node for this task.
        #[pin]
        inner: AcquireFutInner,
    }

    impl<T> PinnedDrop for AcquireFut<T>
    where
        T: Deref<Target = RateLimiter>,
    {
        fn drop(this: Pin<&mut Self>) {
            let AcquireFutProj { lim, permits, state, inner, .. } = this.project();

            // Only a future which is linked as a waiter or acting as the core
            // has touched shared state and needs cleanup.
            let is_core = match *state {
                AcquireFutState::Waiting => false,
                AcquireFutState::Core { .. } => true,
                _ => return,
            };

            // Unlink the node and hand back permits acquired so far.
            let mut critical = lim.lock();
            let mut s = lim.take();
            inner.release_remaining(&mut critical, &mut s, *permits);

            // If we were coordinating the limiter, let another task take over.
            if is_core {
                critical.release(&mut s);
            }

            *state = AcquireFutState::Complete;
        }
    }
}
1522
1523impl<T> AcquireFut<T>
1524where
1525    T: Deref<Target = RateLimiter>,
1526{
1527    #[inline]
1528    const fn new(lim: T, permits: usize) -> Self {
1529        Self {
1530            lim,
1531            permits,
1532            state: AcquireFutState::Initial,
1533            sleep: None,
1534            inner: AcquireFutInner::new(),
1535        }
1536    }
1537
1538    fn is_core(&self) -> bool {
1539        matches!(&self.state, AcquireFutState::Core { .. })
1540    }
1541}
1542
// SAFETY: All the internals of acquire are thread safe and correctly
// synchronized. The embedded waiter queue doesn't have anything inherently
// unsafe in it. The bounds additionally require that the limiter handle `T`
// itself is `Send`/`Sync` respectively.
unsafe impl<T> Send for AcquireFut<T> where T: Send + Deref<Target = RateLimiter> {}
unsafe impl<T> Sync for AcquireFut<T> where T: Sync + Deref<Target = RateLimiter> {}
1548
impl<T> Future for AcquireFut<T>
where
    T: Deref<Target = RateLimiter>,
{
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let AcquireFutProj {
            lim,
            permits,
            state,
            mut sleep,
            inner: mut internal,
            ..
        } = self.project();

        // Hold onto the critical lock for core operations, but only acquire it
        // when strictly necessary.
        let mut critical;

        // Shared state.
        //
        // Once we are holding onto the critical lock, we take the entire state
        // to ensure that any fast-path negotiators do not observe any available
        // permits while potential core work is ongoing.
        let mut s;

        // Hold onto any call to `Instant::now` which we might perform, so we
        // don't have to get the current time multiple times.
        let outer_now;

        match *state {
            AcquireFutState::Complete => {
                return Poll::Ready(());
            }
            AcquireFutState::Initial => {
                // If the rate limiter is not fair, try to opportunistically
                // just acquire a permit through the known atomic state.
                //
                // This is known as the fast path, but requires acquire to race
                // against other tasks when storing the state back.
                if lim.try_fast_path(*permits) {
                    *state = AcquireFutState::Complete;
                    return Poll::Ready(());
                }

                critical = lim.lock();
                s = lim.take();

                let now = Instant::now();

                // If we've hit a deadline, calculate the number of tokens
                // to drain and perform it in line here. This is necessary
                // because the core isn't aware of how long we sleep between
                // each acquire, so we need to perform some of the drain
                // work here in order to avoid accruing a debt that needs to
                // be filled in later.
                //
                // If we didn't do this, and the process slept for a long
                // time, the next time a core is acquired it would be very
                // far removed from the expected deadline and has no idea
                // when permits were acquired, so it would over-eagerly
                // release a lot of acquires and accumulate permits.
                //
                // This is tested for in the `test_idle` suite of tests.
                let tokens =
                    if let Some((tokens, deadline)) = lim.calculate_drain(critical.deadline, now) {
                        trace!(tokens, "inline drain");
                        // We pre-emptively update the deadline of the core
                        // since it might bump, and we don't want other
                        // processes to observe that the deadline has been
                        // reached.
                        critical.deadline = deadline;
                        tokens
                    } else {
                        0
                    };

                // Add the drained tokens and try to fill our own request
                // directly from the resulting balance.
                let completed = s.add_tokens(&mut critical, tokens, |critical, s| {
                    let (_, task) = internal.as_mut().get_task(critical);
                    task.remaining = *permits;
                    task.fill(&mut s.balance);
                    task.is_completed()
                });

                if completed {
                    *state = AcquireFutState::Complete;
                    return Poll::Ready(());
                }

                // Try to take over as core. If we're unsuccessful we just
                // ensure that we're linked into the wait queue.
                if !mem::take(&mut s.available) {
                    internal.as_mut().update(&mut critical, cx.waker());
                    *state = AcquireFutState::Waiting;
                    return Poll::Pending;
                }

                // SAFETY: This is done in a pinned section, so we know that
                // the linked section stays alive for the duration of this
                // future due to pinning guarantees.
                internal.as_mut().link_core(&mut critical, lim);
                Guard::bump(&mut critical);
                *state = AcquireFutState::Core;
                outer_now = Some(now);
            }
            AcquireFutState::Waiting => {
                // If we are complete, then return as ready.
                //
                // This field is atomic, so we can safely read it under shared
                // access and do not require a lock.
                if internal.complete().load(Ordering::Acquire) {
                    *state = AcquireFutState::Complete;
                    return Poll::Ready(());
                }

                // Note: we need to operate under this lock to ensure that
                // the core acquired here (or elsewhere) observes that the
                // current task has been linked up.
                critical = lim.lock();
                s = lim.take();

                // Try to take over as core. If we're unsuccessful we
                // just ensure that we're linked into the wait queue.
                if !mem::take(&mut s.available) {
                    internal.update(&mut critical, cx.waker());
                    return Poll::Pending;
                }

                let now = Instant::now();

                // This is done in a pinned section, so we know that the linked
                // section stays alive for the duration of this future due to
                // pinning guarantees.
                if !internal.as_mut().assume_core(&mut critical, &mut s, now) {
                    // Marks as completed.
                    *state = AcquireFutState::Complete;
                    return Poll::Ready(());
                }

                Guard::bump(&mut critical);
                *state = AcquireFutState::Core;
                outer_now = Some(now);
            }
            AcquireFutState::Core => {
                critical = lim.lock();
                s = lim.take();
                outer_now = None;
            }
        }

        trace!(until = ?critical.deadline, "taking over core and sleeping");

        // Lazily construct, or reset, the sleep timer so it targets the
        // current refill deadline.
        let mut sleep = match sleep.as_mut().as_pin_mut() {
            Some(mut sleep) => {
                if sleep.deadline() != critical.deadline {
                    sleep.as_mut().reset(critical.deadline);
                }

                sleep
            }
            None => {
                sleep.set(Some(time::sleep_until(critical.deadline)));
                sleep.as_mut().as_pin_mut().unwrap()
            }
        };

        if sleep.as_mut().poll(cx).is_pending() {
            return Poll::Pending;
        }

        // The deadline elapsed; schedule the next refill one interval ahead.
        critical.deadline = outer_now.unwrap_or_else(Instant::now) + lim.interval;

        if internal.drain_core(&mut critical, &mut s, lim.refill) {
            *state = AcquireFutState::Complete;
            return Poll::Ready(());
        }

        // Still the core: wake ourselves so the next poll re-arms the timer
        // against the updated deadline.
        cx.waker().wake_by_ref();
        Poll::Pending
    }
}
1731
#[cfg(test)]
mod tests {
    use super::{Acquire, AcquireOwned, RateLimiter};

    fn assert_is_send<T: Send>() {}
    fn assert_is_sync<T: Sync>() {}

    /// Statically verify that the limiter and both public acquire futures
    /// are `Send` and `Sync`.
    #[test]
    fn assert_send_sync() {
        assert_is_send::<RateLimiter>();
        assert_is_sync::<RateLimiter>();

        assert_is_send::<Acquire<'_>>();
        assert_is_sync::<Acquire<'_>>();

        assert_is_send::<AcquireOwned>();
        assert_is_sync::<AcquireOwned>();
    }
}