leaky-bucket 0.12.2

Futures-aware rate limiter implementation.
Documentation
//! [<img alt="github" src="https://img.shields.io/badge/github-udoprog/leaky--bucket-8da0cb?style=for-the-badge&logo=github" height="20">](https://github.com/udoprog/leaky-bucket)
//! [<img alt="crates.io" src="https://img.shields.io/crates/v/leaky-bucket.svg?style=for-the-badge&color=fc8d62&logo=rust" height="20">](https://crates.io/crates/leaky-bucket)
//! [<img alt="docs.rs" src="https://img.shields.io/badge/docs.rs-leaky--bucket-66c2a5?style=for-the-badge&logoColor=white&logo=data:image/svg+xml;base64,PHN2ZyByb2xlPSJpbWciIHhtbG5zPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyIgdmlld0JveD0iMCAwIDUxMiA1MTIiPjxwYXRoIGZpbGw9IiNmNWY1ZjUiIGQ9Ik00ODguNiAyNTAuMkwzOTIgMjE0VjEwNS41YzAtMTUtOS4zLTI4LjQtMjMuNC0zMy43bC0xMDAtMzcuNWMtOC4xLTMuMS0xNy4xLTMuMS0yNS4zIDBsLTEwMCAzNy41Yy0xNC4xIDUuMy0yMy40IDE4LjctMjMuNCAzMy43VjIxNGwtOTYuNiAzNi4yQzkuMyAyNTUuNSAwIDI2OC45IDAgMjgzLjlWMzk0YzAgMTMuNiA3LjcgMjYuMSAxOS45IDMyLjJsMTAwIDUwYzEwLjEgNS4xIDIyLjEgNS4xIDMyLjIgMGwxMDMuOS01MiAxMDMuOSA1MmMxMC4xIDUuMSAyMi4xIDUuMSAzMi4yIDBsMTAwLTUwYzEyLjItNi4xIDE5LjktMTguNiAxOS45LTMyLjJWMjgzLjljMC0xNS05LjMtMjguNC0yMy40LTMzLjd6TTM1OCAyMTQuOGwtODUgMzEuOXYtNjguMmw4NS0zN3Y3My4zek0xNTQgMTA0LjFsMTAyLTM4LjIgMTAyIDM4LjJ2LjZsLTEwMiA0MS40LTEwMi00MS40di0uNnptODQgMjkxLjFsLTg1IDQyLjV2LTc5LjFsODUtMzguOHY3NS40em0wLTExMmwtMTAyIDQxLjQtMTAyLTQxLjR2LS42bDEwMi0zOC4yIDEwMiAzOC4ydi42em0yNDAgMTEybC04NSA0Mi41di03OS4xbDg1LTM4Ljh2NzUuNHptMC0xMTJsLTEwMiA0MS40LTEwMi00MS40di0uNmwxMDItMzguMiAxMDIgMzguMnYuNnoiPjwvcGF0aD48L3N2Zz4K" height="20">](https://docs.rs/leaky-bucket)
//! [<img alt="build status" src="https://img.shields.io/github/actions/workflow/status/udoprog/leaky-bucket/ci.yml?branch=main&style=for-the-badge" height="20">](https://github.com/udoprog/leaky-bucket/actions?query=branch%3Amain)
//!
//! A token-based rate limiter based on the [leaky bucket] algorithm.
//!
//! If the bucket overflows and goes over its max configured capacity, the task
//! that tried to acquire the tokens will be suspended until the required number
//! of tokens has been drained from the bucket.
//!
//! Since this crate uses timing facilities from tokio it has to be used within
//! a Tokio runtime with the [`time` feature] enabled.
//!
//! <br>
//!
//! ## Usage
//!
//! Add the following to your `Cargo.toml`:
//!
//! ```toml
//! leaky-bucket = "0.12.2"
//! ```
//!
//! <br>
//!
//! ## Examples
//!
//! The core type is the [`RateLimiter`] type, which allows for limiting the
//! throughput of a section using its [`acquire`] and [`acquire_one`] methods.
//!
//! ```
//! use leaky_bucket::RateLimiter;
//! use std::time;
//!
//! #[tokio::main]
//! async fn main() {
//!     let limiter = RateLimiter::builder()
//!         .max(10)
//!         .initial(0)
//!         .refill(5)
//!         .build();
//!
//!     let start = time::Instant::now();
//!
//!     println!("Waiting for permit...");
//!
//!     // Should take ~400 ms to acquire in total.
//!     let a = limiter.acquire(7);
//!     let b = limiter.acquire(3);
//!     let c = limiter.acquire(10);
//!
//!     let ((), (), ()) = tokio::join!(a, b, c);
//!
//!     println!(
//!         "I made it in {:?}!",
//!         time::Instant::now().duration_since(start)
//!     );
//! }
//! ```
//!
//! <br>
//!
//! ## Implementation details
//!
//! Each rate limiter has two acquisition modes. A fast path and a slow path.
//! The fast path is used if the desired number of tokens are readily available,
//! and involves incrementing an atomic counter indicating that the acquired
//! number of tokens have been added to the bucket.
//!
//! If this counter goes over its configured maximum capacity, it overflows into
//! a slow path. Here one of the acquiring tasks will switch over to work as a
//! *core*. This is known as *core switching*.
//!
//! ```
//! use leaky_bucket::RateLimiter;
//! use std::time;
//!
//! # #[tokio::main] async fn main() {
//! let limiter = RateLimiter::builder()
//!     .initial(10)
//!     .interval(time::Duration::from_millis(100))
//!     .build();
//!
//! // This is instantaneous since the rate limiter starts with 10 tokens to
//! // spare.
//! limiter.acquire(10).await;
//!
//! // This however needs to core switch and wait for a while until the desired
//! // number of tokens is available.
//! limiter.acquire(3).await;
//! # }
//! ```
//!
//! The core is responsible for sleeping for the configured interval so that
//! more tokens can be added. After which it ensures that any tasks that are
//! waiting to acquire including itself are appropriately unsuspended.
//!
//! On-demand core switching is what allows this rate limiter implementation to
//! work without a coordinating background thread. But we need to ensure that
//! any asynchronous tasks that uses [`RateLimiter`] must either run an
//! [`acquire`] call to completion, or be *cancelled* by being dropped.
//!
//! If none of these hold, the core might leak and be locked indefinitely
//! preventing any future use of the rate limiter from making progress. This is
//! similar to if you would lock an asynchronous [`Mutex`] but never drop its
//! guard.
//!
//! > You can run this example with:
//! >
//! > ```sh
//! > cargo run --example block-forever
//! > ```
//!
//! ```
//! use leaky_bucket::RateLimiter;
//! use std::future::Future;
//! use std::sync::Arc;
//! use std::task::Context;
//!
//! struct Waker;
//! # impl std::task::Wake for Waker { fn wake(self: Arc<Self>) { } }
//!
//! # #[tokio::main] async fn main() {
//! let limiter = Arc::new(RateLimiter::builder().build());
//!
//! let waker = Arc::new(Waker).into();
//! let mut cx = Context::from_waker(&waker);
//!
//! let mut a0 = Box::pin(limiter.acquire(1));
//! // Poll once to ensure that the core task is assigned.
//! assert!(a0.as_mut().poll(&mut cx).is_pending());
//! assert!(a0.is_core());
//!
//! // We leak the core task, preventing the rate limiter from making progress
//! // by assigning new core tasks.
//! std::mem::forget(a0);
//!
//! // Awaiting acquire here would block forever.
//! // limiter.acquire(1).await;
//! # }
//! ```
//!
//! <br>
//!
//! ## Fairness
//!
//! By default [`RateLimiter`] uses a *fair* scheduler. This ensures that the
//! core task makes progress even if there are many tasks waiting to acquire
//! tokens. As a result it causes more frequent core switching, increasing the
//! total work needed. An unfair scheduler is expected to do a bit less work
//! under contention. But without fair scheduling some tasks might end up taking
//! longer to acquire than expected.
//!
//! This behavior can be tweaked with the [`Builder::fair`] option.
//!
//! ```
//! use leaky_bucket::RateLimiter;
//!
//! let limiter = RateLimiter::builder()
//!     .fair(false)
//!     .build();
//! ```
//!
//! The `unfair-scheduling` example can showcase this phenomenon.
//!
//! ```sh
//! cargh run --example unfair-scheduling
//! ```
//!
//! ```text
//! # fair
//! Max: 1011ms, Total: 1012ms
//! Timings:
//!  0: 101ms
//!  1: 101ms
//!  2: 101ms
//!  3: 101ms
//!  4: 101ms
//!  ...
//! # unfair
//! Max: 1014ms, Total: 1014ms
//! Timings:
//!  0: 1014ms
//!  1: 101ms
//!  2: 101ms
//!  3: 101ms
//!  4: 101ms
//!  ...
//! ```
//!
//! As can be seen above the first task in the *unfair* scheduler takes longer
//! to run because it prioritises releasing other tasks waiting to acquire over
//! itself.
//!
//! [`acquire_one`]: https://docs.rs/leaky-bucket/0/leaky_bucket/struct.RateLimiter.html#method.acquire_one
//! [`acquire`]: https://docs.rs/leaky-bucket/0/leaky_bucket/struct.RateLimiter.html#method.acquire
//! [`Builder::fair`]: https://docs.rs/leaky-bucket/0/leaky_bucket/struct.Builder.html#method.fair
//! [`Mutex`]: https://docs.rs/tokio/1/tokio/sync/struct.Mutex.html
//! [`RateLimiter`]: https://docs.rs/leaky-bucket/0/leaky_bucket/struct.RateLimiter.html
//! [`time` feature]: https://docs.rs/tokio/1/tokio/#feature-flags
//! [leaky bucket]: https://en.wikipedia.org/wiki/Leaky_bucket

#![deny(missing_docs)]
#![deny(rustdoc::missing_doc_code_examples)]

use parking_lot::{Mutex, MutexGuard};
use std::cell::UnsafeCell;
use std::convert::TryFrom as _;
use std::fmt;
use std::future::Future;
use std::marker;
use std::mem;
use std::pin::Pin;
use std::ptr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
use tokio::time;
use tracing::trace;

#[doc(hidden)]
pub mod linked_list;
use self::linked_list::{LinkedList, Node};

/// Default factor for how to calculate max refill value.
const DEFAULT_REFILL_MAX_FACTOR: usize = 10;

/// Interval to bump the shared mutex guard to allow other parts of the system
/// to make process. Processes which loop should use this number to determine
/// how many times it should loop before calling [MutexGuard::bump].
///
/// If we do not respect this limit we might inadvertently end up starving other
/// tasks from making progress so that they can unblock.
const BUMP_LIMIT: usize = 16;

/// Linked task state.
struct Task {
    /// Remaining tokens that need to be satisfied.
    remaining: usize,
    /// Link to [Linking::complete].
    complete: Option<ptr::NonNull<AtomicBool>>,
    /// The waker associated with the node.
    waker: Option<Waker>,
}

impl Task {
    /// Construct a new task state with the given permits remaining.
    const fn new() -> Self {
        Self {
            remaining: 0,
            complete: None,
            waker: None,
        }
    }

    /// Test if the current node is completed.
    fn is_completed(&self) -> bool {
        self.remaining == 0
    }

    /// Fill the current node from the given pool of tokens and modify it.
    fn fill(&mut self, current: &mut usize) {
        let removed = usize::min(self.remaining, *current);
        self.remaining -= removed;
        *current -= removed;
    }
}

/// A borrowed rate limiter.
struct BorrowedRateLimiter<'a>(&'a RateLimiter);

impl AsRef<RateLimiter> for BorrowedRateLimiter<'_> {
    fn as_ref(&self) -> &RateLimiter {
        self.0
    }
}

struct Critical {
    /// Current balance of tokens. A value of 0 means that it is empty. Goes up
    /// to [`RateLimiter::max`].
    balance: usize,
    /// Waiter list.
    waiters: LinkedList<Task>,
    /// The deadline for when more tokens can be be added.
    deadline: time::Instant,
    /// If the core is available.
    available: bool,
}

impl Critical {
    /// Release the current core. Beyond this point the current task may no
    /// longer interact exclusively with the core.
    #[tracing::instrument(skip(self), level = "trace")]
    fn release(&mut self) {
        trace!("releasing core");
        self.available = true;

        // Find another task that might take over as core. Once it has acquired
        // core status it will have to make sure it is no longer linked into the
        // wait queue.
        //
        // We have to do this, because another task might miss that the core is
        // available since it's hidden behind an atomic, so we wake any task up
        // to ensure that it will always be picked up.
        //
        // Safety: We're holding the lock guard to all the waiters so we can be
        // certain that we have exclusive access.
        unsafe {
            if let Some(mut node) = self.waiters.front_mut() {
                trace!(node = ?node, "waking next core");

                if let Some(waker) = node.as_mut().waker.take() {
                    waker.wake();
                }
            }
        }
    }
}

/// A token-bucket rate limiter.
pub struct RateLimiter {
    /// Tokens to add every `per` duration.
    refill: usize,
    /// Interval in milliseconds to add tokens.
    interval: time::Duration,
    /// Max number of tokens associated with the rate limiter.
    max: usize,
    /// If the rate limiter is fair or not.
    fair: bool,
    /// Critical state of the rate limiter.
    critical: Mutex<Critical>,
}

impl RateLimiter {
    /// Construct a new [`Builder`] for a [`RateLimiter`].
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    /// use std::time::Duration;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .initial(100)
    ///     .refill(100)
    ///     .max(1000)
    ///     .interval(Duration::from_millis(250))
    ///     .fair(false)
    ///     .build();
    /// ```
    pub fn builder() -> Builder {
        Builder::default()
    }

    /// Get the refill amount  of this rate limiter as set through
    /// [`Builder::refill`].
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .refill(1024)
    ///     .build();
    ///
    /// assert_eq!(limiter.refill(), 1024);
    /// ```
    pub fn refill(&self) -> usize {
        self.refill
    }

    /// Get the refill interval of this rate limiter as set through
    /// [`Builder::interval`].
    ///
    /// # Examples
    ///
    /// ```
    /// use std::time::Duration;
    ///
    /// use leaky_bucket::RateLimiter;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .interval(Duration::from_millis(1000))
    ///     .build();
    ///
    /// assert_eq!(limiter.interval(), Duration::from_millis(1000));
    /// ```
    pub fn interval(&self) -> time::Duration {
        self.interval
    }

    /// Get the max value of this rate limiter as set through [`Builder::max`].
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .max(1024)
    ///     .build();
    ///
    /// assert_eq!(limiter.max(), 1024);
    /// ```
    pub fn max(&self) -> usize {
        self.max
    }

    /// Test if the current rate limiter is fair as specified through
    /// [`Builder::fair`].
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .fair(true)
    ///     .build();
    ///
    /// assert_eq!(limiter.is_fair(), true);
    /// ```
    pub fn is_fair(&self) -> bool {
        self.fair
    }

    /// Get the current token balance.
    ///
    /// This indicates how many tokens can be requested without blocking.
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    ///
    /// # #[tokio::main] async fn main() {
    /// let limiter = RateLimiter::builder()
    ///     .initial(100)
    ///     .build();
    ///
    /// assert_eq!(limiter.balance(), 100);
    /// limiter.acquire(10).await;
    /// assert_eq!(limiter.balance(), 90);
    /// # }
    /// ```
    pub fn balance(&self) -> usize {
        self.critical.lock().balance
    }

    /// Acquire a single permit.
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    ///
    /// # #[tokio::main] async fn main() {
    /// let limiter = RateLimiter::builder()
    ///     .initial(10)
    ///     .build();
    ///
    /// limiter.acquire_one().await;
    /// # }
    /// ```
    pub fn acquire_one(&self) -> Acquire<'_> {
        self.acquire(1)
    }

    /// Acquire the given number of permits, suspending the current task until
    /// they are available.
    ///
    /// If zero permits are specified, this function never suspends the current
    /// task.
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    ///
    /// # #[tokio::main] async fn main() {
    /// let limiter = RateLimiter::builder()
    ///     .initial(10)
    ///     .build();
    ///
    /// limiter.acquire(10).await;
    /// # }
    /// ```
    pub fn acquire(&self, permits: usize) -> Acquire<'_> {
        Acquire(AcquireFut::new(BorrowedRateLimiter(self), permits))
    }

    /// Acquire a permit using an owned future.
    ///
    /// If zero permits are specified, this function never suspends the current
    /// task.
    ///
    /// This required the [`RateLimiter`] to be wrapped inside of an
    /// [`std::sync::Arc`] but will in contrast permit the acquire operation to
    /// be owned by another struct making it more suitable for embedding.
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    /// use std::sync::Arc;
    ///
    /// # #[tokio::main] async fn main() {
    /// let limiter = Arc::new(RateLimiter::builder().initial(10).build());
    ///
    /// limiter.acquire_owned(10).await;
    /// # }
    /// ```
    ///
    /// Example when embedded into another future. This wouldn't be possible
    /// with [`RateLimiter::acquire`] since it would otherwise hold a reference
    /// to the corresponding [`RateLimiter`] instance.
    ///
    /// ```
    /// use leaky_bucket::{AcquireOwned, RateLimiter};
    /// use pin_project::pin_project;
    /// use std::future::Future;
    /// use std::pin::Pin;
    /// use std::sync::Arc;
    /// use std::task::{Context, Poll};
    /// use std::time::Duration;
    ///
    /// #[pin_project]
    /// struct MyFuture {
    ///     limiter: Arc<RateLimiter>,
    ///     #[pin]
    ///     acquire: Option<AcquireOwned>,
    /// }
    ///
    /// impl Future for MyFuture {
    ///     type Output = ();
    ///
    ///     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
    ///         let mut this = self.project();
    ///
    ///         loop {
    ///             if let Some(acquire) = this.acquire.as_mut().as_pin_mut() {
    ///                 futures::ready!(acquire.poll(cx));
    ///                 return Poll::Ready(());
    ///             }
    ///
    ///             this.acquire.set(Some(this.limiter.clone().acquire_owned(100)));
    ///         }
    ///     }
    /// }
    ///
    /// # #[tokio::main] async fn main() {
    /// let limiter = Arc::new(RateLimiter::builder().initial(100).build());
    ///
    /// let future = MyFuture { limiter, acquire: None };
    /// future.await;
    /// # }
    /// ```
    pub fn acquire_owned(self: Arc<Self>, permits: usize) -> AcquireOwned {
        AcquireOwned(AcquireFut::new(self, permits))
    }
}

// Safety: All the internals of acquire is thread safe and correctly
// synchronized. The embedded waiter queue doesn't have anything inherently
// unsafe in it.
unsafe impl Send for RateLimiter {}
unsafe impl Sync for RateLimiter {}

/// A builder for a [`RateLimiter`].
pub struct Builder {
    /// The max number of tokens.
    max: Option<usize>,
    /// The initial count of tokens.
    initial: usize,
    /// Tokens to add every `per` duration.
    refill: usize,
    /// Interval to add tokens in milliseconds.
    interval: time::Duration,
    /// If the rate limiter is fair or not.
    fair: bool,
}

impl Builder {
    /// Configure the max number of tokens to use.
    ///
    /// If unspecified, this will default to be 2 times the [`refill`] or the
    /// [`initial`] value, whichever is largest.
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .max(10_000)
    ///     .build();
    /// ```
    ///
    /// [`refill`]: Builder::refill
    /// [`initial`]: Builder::initial
    pub fn max(&mut self, max: usize) -> &mut Self {
        self.max = Some(max);
        self
    }

    /// Configure the initial number of tokens to configure. The default value
    /// is `0`.
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .initial(10)
    ///     .build();
    /// ```
    pub fn initial(&mut self, initial: usize) -> &mut Self {
        self.initial = initial;
        self
    }

    /// Configure the time duration between which we add [`refill`] number to
    /// the bucket rate limiter.
    ///
    /// # Panics
    ///
    /// This panics if the provided interval does not fit within the millisecond
    /// bounds of a [usize] or is zero.
    ///
    /// ```should_panic
    /// use leaky_bucket::RateLimiter;
    /// use std::time;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .interval(time::Duration::from_secs(u64::MAX))
    ///     .build();
    /// ```
    ///
    /// ```should_panic
    /// use leaky_bucket::RateLimiter;
    /// use std::time;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .interval(time::Duration::from_millis(0))
    ///     .build();
    /// ```
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    /// use std::time;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .interval(time::Duration::from_millis(100))
    ///     .build();
    /// ```
    ///
    /// [`refill`]: Builder::refill
    pub fn interval(&mut self, interval: time::Duration) -> &mut Self {
        assert! {
            interval.as_millis() != 0,
            "interval must be non-zero",
        };
        assert! {
            u64::try_from(interval.as_millis()).is_ok(),
            "interval must fit within a 64-bit integer"
        };
        self.interval = interval;
        self
    }

    /// The number of tokens to add at each [`interval`] interval. The default
    /// value is `1`.
    ///
    /// # Panics
    ///
    /// Panics if a refill amount of `0` is specified.
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    /// use std::time;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .refill(100)
    ///     .build();
    /// ```
    ///
    /// [`interval`]: Builder::interval
    pub fn refill(&mut self, refill: usize) -> &mut Self {
        assert!(refill > 0, "refill amount cannot be zero");
        self.refill = refill;
        self
    }

    /// Configure the rate limiter to be fair. By default the rate limiter is
    /// *fair* which ensures that all tasks make steady progress even under
    /// contention. But an unfair scheduler might have a higher total
    /// throughput.
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .refill(100)
    ///     .fair(false)
    ///     .build();
    /// ```
    pub fn fair(&mut self, fair: bool) -> &mut Self {
        self.fair = fair;
        self
    }

    /// Construct a new [`RateLimiter`].
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    /// use std::time;
    ///
    /// let limiter = RateLimiter::builder()
    ///     .refill(100)
    ///     .interval(time::Duration::from_millis(200))
    ///     .max(10_000)
    ///     .build();
    /// ```
    pub fn build(&self) -> RateLimiter {
        let deadline = time::Instant::now() + self.interval;

        let max = match self.max {
            Some(max) => max,
            None => usize::max(self.refill, self.initial).saturating_mul(DEFAULT_REFILL_MAX_FACTOR),
        };

        let initial = usize::min(self.initial, max);

        RateLimiter {
            refill: self.refill,
            interval: self.interval,
            max,
            fair: self.fair,
            critical: Mutex::new(Critical {
                balance: initial,
                waiters: LinkedList::new(),
                deadline,
                available: true,
            }),
        }
    }
}

/// Construct a new builder with default options.
///
/// # Examples
///
/// ```
/// use leaky_bucket::Builder;
///
/// let limiter = Builder::default().build();
/// ```
impl Default for Builder {
    fn default() -> Self {
        Self {
            max: None,
            initial: 0,
            refill: 1,
            interval: time::Duration::from_millis(100),
            fair: true,
        }
    }
}

/// The state of an acquire operation.
#[allow(clippy::large_enum_variant)]
enum State {
    /// Initial unconfigured state.
    Initial,
    /// The acquire is waiting to be released by the core.
    Waiting,
    /// This operation is currently the core.
    ///
    /// We need to take care to ensure that we don't move the configured sleep.
    /// Since it needs to be pinned to be polled.
    Core {
        /// The current sleep of the core.
        sleep: time::Sleep,
    },
    /// The operation is completed.
    Complete,
}

/// Internal state of the acquire. This is separated because it can be computed
/// in constant time.
struct AcquireState {
    /// If we are linked or not.
    linked: bool,
    /// Inner state of the acquire.
    linking: UnsafeCell<Linking>,
}

impl AcquireState {
    #[allow(clippy::declare_interior_mutable_const)]
    const INITIAL: AcquireState = AcquireState {
        linked: false,
        linking: UnsafeCell::new(Linking {
            task: Node::new(Task::new()),
            complete: AtomicBool::new(false),
            _pin: marker::PhantomPinned,
        }),
    };

    /// Access the completion flag.
    pub fn complete(&self) -> &AtomicBool {
        // Safety: This is always safe to access since it's atomic.
        unsafe {
            let ptr = self.linking.get() as *const _ as *const Node<Task>;
            let ptr = ptr.add(1) as *const AtomicBool;
            &*ptr
        }
    }

    /// Get the underlying task.
    pub unsafe fn task(&self) -> &Node<Task> {
        let ptr = self.linking.get() as *mut Node<Task>;
        &*ptr
    }

    /// Get the underlying task mutably.
    pub unsafe fn task_mut(&mut self) -> &mut Node<Task> {
        let ptr = self.linking.get() as *mut Node<Task>;
        &mut *ptr
    }

    /// Get the underlying task mutably and completion flag as a pair.
    pub unsafe fn update_project(&mut self) -> (&mut Node<Task>, &AtomicBool, &mut bool) {
        let node = self.linking.get() as *mut Node<Task>;
        let complete = node.add(1) as *const _ as *const AtomicBool;
        let node = &mut *(node as *mut Node<Task>);
        let complete = &*complete;
        (node, complete, &mut self.linked)
    }

    /// Update the waiting state for this acquisition task. This might require
    /// that we update the associated waker.
    #[tracing::instrument(skip(self, critical, waker), level = "trace")]
    fn update(&mut self, critical: &mut MutexGuard<'_, Critical>, waker: &Waker) {
        // Safety: we're ensured to do this under the critical lock since we've
        // passed the relevant guard in through `waiters`.
        let (task, complete, linked) = unsafe { self.update_project() };

        if !*linked {
            trace!("linking self");
            *linked = true;

            unsafe {
                critical.waiters.push_front(task.into());
            }
        }

        let w = &mut task.waker;

        let new_waker = match w {
            None => true,
            Some(w) => !w.will_wake(waker),
        };

        if new_waker {
            trace!("updating waker");
            *w = Some(waker.clone());
        }

        if task.complete.is_none() {
            trace!("setting complete");
            task.complete = Some(complete.into());
        }
    }

    /// Ensure that the current core task is correctly linked up if needed.
    #[tracing::instrument(skip(self, critical, lim), level = "trace")]
    unsafe fn link_core(&mut self, critical: &mut Critical, lim: &RateLimiter) {
        if lim.fair {
            // Fair scheduling needs to ensure that the core is part of the wait
            // queue, and will be woken up in-order with other tasks.
            if !mem::replace(&mut self.linked, true) {
                critical.waiters.push_front(self.task_mut().into());
            }
        } else {
            // Unfair scheduling the core task is not supposed to be in the wait
            // queue, so remove it from there if we've successfully stolen it.
            // Ensure that the current task is *not* linked since it is now to
            // become the coordinator for everyone else.
            if mem::take(&mut self.linked) {
                critical.waiters.remove(self.task_mut().into());
            }
        }
    }

    /// Refill the wait queue with the given number of tokens.
    #[tracing::instrument(skip(self, critical, lim), level = "trace")]
    fn drain_wait_queue(
        &self,
        critical: &mut MutexGuard<'_, Critical>,
        tokens: usize,
        lim: &RateLimiter,
    ) {
        critical.balance = critical.balance.saturating_add(tokens);
        trace!(tokens = tokens, "draining tokens");

        let mut bump = 0;

        // Safety: we're holding the lock guard to all the waiters so we can be
        // sure that we have exclusive access to the wait queue.
        unsafe {
            while critical.balance > 0 {
                let mut node = match critical.waiters.pop_back() {
                    Some(node) => node,
                    None => break,
                };

                let n = node.as_mut();
                n.fill(&mut critical.balance);

                trace! {
                    balance = critical.balance,
                    remaining = n.remaining,
                    "filled node",
                };

                if !n.is_completed() {
                    critical.waiters.push_back(node);
                    break;
                }

                if let Some(complete) = n.complete.take() {
                    complete.as_ref().store(true, Ordering::Release);
                }

                if let Some(waker) = n.waker.take() {
                    waker.wake();
                }

                bump += 1;

                if bump == BUMP_LIMIT {
                    MutexGuard::bump(critical);
                    bump = 0;
                }
            }
        }

        if critical.balance > lim.max {
            critical.balance = lim.max;
        }
    }

    /// Drain the given number of tokens through the core. Returns `true` if the
    /// core has been completed.
    #[tracing::instrument(skip(self, critical, tokens, lim), level = "trace")]
    fn drain_core(
        &mut self,
        critical: &mut MutexGuard<'_, Critical>,
        tokens: usize,
        lim: &RateLimiter,
    ) -> bool {
        self.drain_wait_queue(critical, tokens, lim);

        if lim.fair {
            debug_assert! {
                self.linked,
                "core must be linked for fair scheduler",
            };

            // We only need to check the state since the current core holder is
            // linked up to the wait queue.
            //
            // Safety: we're doing this under the critical lock so we know we
            // have exclusive access to the node.
            if unsafe { self.task().is_completed() } {
                // Task was unlinked by the drain action.
                self.linked = false;
                return true;
            }

            false
        } else {
            debug_assert! {
                !self.linked,
                "core must not be linked for an unfair scheduler",
            };

            // If the limiter is not fair, we need to in addition to draining
            // remaining tokens from linked nodes, drain it from ourselves. We
            // fill the current holder of the core last (self). To ensure that
            // it stays around for as long as possible.
            //
            // Safety: we know that no one else holds the task at this point.
            // The in particular the task is not linked into the wait queue.
            let c = unsafe { &mut *self.task_mut() };
            c.fill(&mut critical.balance);
            c.is_completed()
        }
    }

    /// Assume the current core and calculate how long we must sleep for in
    /// order to do it.
    ///
    /// # Safety
    ///
    /// This might link the current task into the task queue, so the caller must
    /// ensure that it is pinned.
    #[tracing::instrument(skip(self, critical, lim), level = "trace")]
    unsafe fn assume_core(
        &mut self,
        critical: &mut MutexGuard<'_, Critical>,
        lim: &RateLimiter,
    ) -> bool {
        self.link_core(critical, lim);

        let (tokens, deadline) = match calculate_drain(critical.deadline, lim.interval) {
            Some(tokens) => tokens,
            None => return true,
        };

        // It is appropriate to update the deadline.
        critical.deadline = deadline;

        if self.drain_core(critical, tokens, lim) {
            // We synthetically "ran" at the current time minus the remaining time
            // we need to wait until the last update period.
            critical.release();
            return false;
        }

        true
    }
}

impl fmt::Debug for AcquireState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("AcquireState").finish()
    }
}

/// The future associated with acquiring permits from a rate limiter using
/// [`RateLimiter::acquire`].
pub struct Acquire<'a>(AcquireFut<BorrowedRateLimiter<'a>>);

impl Acquire<'_> {
    /// Test if this acquire task is currently coordinating the rate limiter.
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    /// use std::future::Future;
    /// use std::sync::Arc;
    /// use std::task::Context;
    ///
    /// struct Waker;
    /// # impl std::task::Wake for Waker { fn wake(self: Arc<Self>) { } }
    ///
    /// # #[tokio::main] async fn main() {
    /// let limiter = RateLimiter::builder().build();
    ///
    /// let waker = Arc::new(Waker).into();
    /// let mut cx = Context::from_waker(&waker);
    ///
    /// let a1 = limiter.acquire(1);
    /// tokio::pin!(a1);
    ///
    /// assert!(!a1.is_core());
    /// assert!(a1.as_mut().poll(&mut cx).is_pending());
    /// assert!(a1.is_core());
    ///
    /// a1.as_mut().await;
    ///
    /// // After completion this is no longer a core.
    /// assert!(!a1.is_core());
    /// # }
    /// ```
    pub fn is_core(&self) -> bool {
        self.0.is_core()
    }
}

impl Future for Acquire<'_> {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let inner = unsafe { Pin::map_unchecked_mut(self, |this| &mut this.0) };
        inner.poll(cx)
    }
}

/// The future associated with acquiring permits from a rate limiter using
/// [`RateLimiter::acquire_owned`].
pub struct AcquireOwned(AcquireFut<Arc<RateLimiter>>);

impl AcquireOwned {
    /// Test if this acquire task is currently coordinating the rate limiter.
    ///
    /// # Examples
    ///
    /// ```
    /// use leaky_bucket::RateLimiter;
    /// use std::future::Future;
    /// use std::sync::Arc;
    /// use std::task::Context;
    ///
    /// struct Waker;
    /// # impl std::task::Wake for Waker { fn wake(self: Arc<Self>) { } }
    ///
    /// # #[tokio::main] async fn main() {
    /// let limiter = Arc::new(RateLimiter::builder().build());
    ///
    /// let waker = Arc::new(Waker).into();
    /// let mut cx = Context::from_waker(&waker);
    ///
    /// let a1 = limiter.acquire_owned(1);
    /// tokio::pin!(a1);
    ///
    /// assert!(!a1.is_core());
    /// assert!(a1.as_mut().poll(&mut cx).is_pending());
    /// assert!(a1.is_core());
    ///
    /// a1.as_mut().await;
    ///
    /// // After completion this is no longer a core.
    /// assert!(!a1.is_core());
    /// # }
    /// ```
    pub fn is_core(&self) -> bool {
        self.0.is_core()
    }
}

impl Future for AcquireOwned {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let inner = unsafe { Pin::map_unchecked_mut(self, |this| &mut this.0) };
        inner.poll(cx)
    }
}

struct AcquireFut<T>
where
    T: AsRef<RateLimiter>,
{
    /// Inner shared state.
    lim: T,
    /// The number of permits associated with this future.
    permits: usize,
    /// State of the acquisition.
    state: State,
    /// The internal acquire state.
    internal: AcquireState,
}

impl<T> AcquireFut<T>
where
    T: AsRef<RateLimiter>,
{
    #[inline]
    fn new(lim: T, permits: usize) -> Self {
        Self {
            lim,
            permits,
            state: State::Initial,
            internal: AcquireState::INITIAL,
        }
    }

    fn is_core(&self) -> bool {
        matches!(&self.state, State::Core { .. })
    }
}

// Safety: All the internals of acquire is thread safe and correctly
// synchronized. The embedded waiter queue doesn't have anything inherently
// unsafe in it.
unsafe impl<T> Send for AcquireFut<T> where T: AsRef<RateLimiter> {}
unsafe impl<T> Sync for AcquireFut<T> where T: AsRef<RateLimiter> {}

impl<T> Future for AcquireFut<T>
where
    T: AsRef<RateLimiter>,
{
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = unsafe { self.get_unchecked_mut() };
        let lim = this.lim.as_ref();

        loop {
            match &mut this.state {
                State::Initial => {
                    // Safety: The task is not linked up yet, so we can safely
                    // inspect the number of permits without having to
                    // synchronize.
                    if this.permits == 0 {
                        this.state = State::Complete;
                        return Poll::Ready(());
                    }

                    let mut critical = lim.critical.lock();

                    // If we've hit a deadline, calculate the number of tokens
                    // to drain and perform it in line here. This is necessary
                    // because the core isn't aware of how long we sleep between
                    // each acquire, so we need to perform some of the drain
                    // work here in order to avoid acruing a debt that needs to
                    // be filled later in.
                    //
                    // If we didn't do this, and the process slept for a long
                    // time, the next time a core is acquired it would be very
                    // far removed from the expected deadline and has no idea
                    // when permits were acquired, so it would over-eagerly
                    // release a lot of acquires and accumulate permits.
                    //
                    // This is tested for in the `test_idle` suite of tests.
                    if let Some((tokens, deadline)) =
                        calculate_drain(critical.deadline, lim.interval)
                    {
                        trace!(tokens = tokens, "inline drain");
                        // We pre-emptively update the deadline of the core
                        // since it might bump, and we don't want other
                        // processes to observe that the deadline has been
                        // reached.
                        critical.deadline = deadline;
                        this.internal.drain_wait_queue(&mut critical, tokens, lim);
                    }

                    // Test the fast path first, where we simply subtract the
                    // permits available from the current balance.
                    if let Some(balance) = critical.balance.checked_sub(this.permits) {
                        critical.balance = balance;
                        this.state = State::Complete;
                        return Poll::Ready(());
                    }

                    let balance = mem::take(&mut critical.balance);

                    // Safety: This is done in a pinned section, so we know that
                    // the linked section stays alive for the duration of this
                    // future due to pinning guarantees.
                    unsafe {
                        this.internal.task_mut().remaining = this.permits - balance;
                    }

                    // Try to take over as core. If we're unsuccessful we just
                    // ensure that we're linked into the wait queue.
                    if !mem::take(&mut critical.available) {
                        this.internal.update(&mut critical, cx.waker());
                        this.state = State::Waiting;
                        return Poll::Pending;
                    }

                    // Safety: This is done in a pinned section, so we know that
                    // the linked section stays alive for the duration of this
                    // future due to pinning guarantees.
                    unsafe { this.internal.link_core(&mut critical, lim) };

                    trace!(until = ?critical.deadline, "taking over core and sleeping");
                    this.state = State::Core {
                        sleep: time::sleep_until(critical.deadline),
                    };

                    trace!("no immediate tokens available");
                }
                State::Waiting => {
                    // If we are complete, then return as ready.
                    //
                    // This field is atomic, so we can safely read it under shared
                    // access and do not require a lock.
                    if this.internal.complete().load(Ordering::Acquire) {
                        this.state = State::Complete;
                        return Poll::Ready(());
                    }

                    // Note: we need to operate under this lock to ensure that
                    // the core acquired here (or elsewhere) observes that the
                    // current task has been linked up.
                    let mut critical = lim.critical.lock();

                    // Try to take over as core. If we're unsuccessful we
                    // just ensure that we're linked into the wait queue.
                    if !mem::take(&mut critical.available) {
                        this.internal.update(&mut critical, cx.waker());
                        return Poll::Pending;
                    }

                    // Safety: This is done in a pinned section, so we know that
                    // the linked section stays alive for the duration of this
                    // future due to pinning guarantees.
                    let assumed = unsafe { this.internal.assume_core(&mut critical, lim) };

                    if !assumed {
                        // Marks as completed.
                        this.state = State::Complete;
                        return Poll::Ready(());
                    }

                    trace!(until = ?critical.deadline, "taking over core and sleeping");
                    this.state = State::Core {
                        sleep: time::sleep_until(critical.deadline),
                    };
                }
                State::Core { sleep } => {
                    let mut sleep = unsafe { Pin::new_unchecked(sleep) };

                    if sleep.as_mut().poll(cx).is_pending() {
                        return Poll::Pending;
                    }

                    let now = time::Instant::now();
                    trace!(now = ?now, "sleep completed");
                    let mut critical = lim.critical.lock();
                    critical.deadline = now + lim.interval;

                    // Safety: we know that we're the only one with access to core
                    // because we ensured it as we acquire the `available` lock.
                    if this.internal.drain_core(&mut critical, lim.refill, lim) {
                        critical.release();
                        this.state = State::Complete;
                        return Poll::Ready(());
                    }

                    trace!(sleep = ?lim.interval, "keeping core and sleeping");
                    sleep.as_mut().reset(critical.deadline);
                }
                State::Complete => {
                    panic!("polled after completion");
                }
            }
        }
    }
}

impl<T> Drop for AcquireFut<T>
where
    T: AsRef<RateLimiter>,
{
    fn drop(&mut self) {
        let lim = self.lim.as_ref();

        match &mut self.state {
            State::Waiting => unsafe {
                debug_assert! {
                    self.internal.linked,
                    "waiting nodes have to be linked",
                };

                // While the node is linked into the wait queue we have to
                // ensure it's only accessed under a lock, but once it's been
                // unlinked we can do what we want with it.
                let mut critical = lim.critical.lock();
                critical.waiters.remove(self.internal.task_mut().into());
            },
            State::Core { .. } => unsafe {
                let mut critical = lim.critical.lock();

                if mem::take(&mut self.internal.linked) {
                    critical.waiters.remove(self.internal.task_mut().into());
                }

                critical.release();
            },
            _ => (),
        }
    }
}

/// All of the state that is linked into the wait queue.
///
/// This is only ever accessed through raw pointer manipulation to avoid issues
/// with field aliasing.
#[repr(C)]
struct Linking {
    /// The node in the linked list.
    task: Node<Task>,
    /// If this node has been released or not. We make this an atomic to permit
    /// access to it without synchronization.
    complete: AtomicBool,
    /// Avoids noalias heuristics from kicking in on references to a `Linking`
    /// struct.
    _pin: marker::PhantomPinned,
}

/// Calculate refill amount. Returning a tuple of how much to fill and remaining
/// duration to sleep until the next refill time if appropriate.
fn calculate_drain(
    deadline: time::Instant,
    interval: time::Duration,
) -> Option<(usize, time::Instant)> {
    let now = time::Instant::now();

    if now < deadline {
        return None;
    }

    // Time elapsed in milliseconds since the last deadline.
    let millis = interval.as_millis();
    let since = now.saturating_duration_since(deadline).as_millis();

    let tokens = usize::try_from(since / millis + 1).unwrap_or(usize::MAX);
    let rem = u64::try_from(since % millis).unwrap_or(u64::MAX);

    // Calculated time remaining until the next deadline.
    let deadline = now + (interval - time::Duration::from_millis(rem));
    Some((tokens, deadline))
}

#[cfg(test)]
mod tests {
    use super::{Acquire, AcquireOwned, RateLimiter};

    fn is_send<T: Send>() {}
    fn is_sync<T: Sync>() {}

    #[test]
    fn assert_send_sync() {
        is_send::<AcquireOwned>();
        is_sync::<AcquireOwned>();

        is_send::<RateLimiter>();
        is_sync::<RateLimiter>();

        is_send::<Acquire<'_>>();
        is_sync::<Acquire<'_>>();
    }
}