#[cfg(feature = "parking_lot")]
use parking_lot::RwLock;
use std::sync::Arc;
#[cfg(not(feature = "parking_lot"))]
use std::sync::RwLock;
use tokio::{
sync::Semaphore,
time::{sleep_until, Duration, Instant},
};
/// Shared state of a bucket: its configuration, the current token count and the bookkeeping
/// needed to refill it over time.
#[derive(Debug)]
struct LeakyBucketInner {
/// Upper bound the token count is clamped to whenever it is refreshed.
max: u32,
/// Length of one refill period.
refill_interval: Duration,
/// Number of tokens added for every elapsed `refill_interval`.
refill_amount: u32,
/// Number of tokens currently held by the bucket.
tokens: RwLock<u32>,
/// Instant up to which elapsed refill periods have already been credited to `tokens`.
last_refill: RwLock<Instant>,
/// Created with a single permit and taken at the start of `acquire`, so only one `acquire`
/// call proceeds at a time.
semaphore: Semaphore,
}
impl LeakyBucketInner {
    /// Creates the bucket state.
    ///
    /// `tokens` is the initial token count, `max` the cap applied on every refresh,
    /// `refill_amount` the number of tokens credited per elapsed `refill_interval`.
    /// The refill clock starts at the current instant and the semaphore is created with a
    /// single permit.
    fn new(max: u32, tokens: u32, refill_interval: Duration, refill_amount: u32) -> Self {
        Self {
            tokens: RwLock::new(tokens),
            max,
            refill_interval,
            refill_amount,
            last_refill: RwLock::new(Instant::now()),
            semaphore: Semaphore::new(1),
        }
    }

    /// Returns a copy of the instant up to which refills have been credited.
    ///
    /// The read guard is released before returning, so callers never hold a lock across an
    /// `.await`.
    fn last_refill_instant(&self) -> Instant {
        #[cfg(feature = "parking_lot")]
        let last_refill = self.last_refill.read();
        #[cfg(not(feature = "parking_lot"))]
        let last_refill = self.last_refill.read().expect("RwLock poisoned");
        *last_refill
    }

    /// Removes `amount` tokens from the bucket, stopping at zero instead of underflowing.
    fn take_tokens(&self, amount: u32) {
        #[cfg(feature = "parking_lot")]
        let mut tokens = self.tokens.write();
        #[cfg(not(feature = "parking_lot"))]
        let mut tokens = self.tokens.write().expect("RwLock poisoned");
        *tokens = (*tokens).saturating_sub(amount);
    }

    /// Credits all refill periods that have fully elapsed since the last refresh and returns
    /// the resulting token count (never more than `max`).
    ///
    /// `last_refill` is only advanced by whole intervals, so a partially elapsed interval is
    /// carried over to the next call.
    #[inline]
    fn update_tokens(&self) -> u32 {
        // Both locks are always taken in this order (`last_refill`, then `tokens`).
        #[cfg(feature = "parking_lot")]
        let mut last_refill = self.last_refill.write();
        #[cfg(not(feature = "parking_lot"))]
        let mut last_refill = self.last_refill.write().expect("RwLock poisoned");
        #[cfg(feature = "parking_lot")]
        let mut tokens = self.tokens.write();
        #[cfg(not(feature = "parking_lot"))]
        let mut tokens = self.tokens.write().expect("RwLock poisoned");

        let time_passed = Instant::now() - *last_refill;

        // Integer nanosecond arithmetic: a floating point quotient can land just below a whole
        // number and be floored one refill short, and a zero interval would yield inf/NaN.
        let interval_nanos = self.refill_interval.as_nanos();
        let whole_intervals = if interval_nanos == 0 {
            // A zero interval means refills are instantaneous; saturate the count.
            u128::from(u32::MAX)
        } else {
            time_passed.as_nanos() / interval_nanos
        };
        // The value is clamped to `u32::MAX` first, so the cast cannot truncate.
        #[allow(clippy::cast_possible_truncation)]
        let refills_since = whole_intervals.min(u128::from(u32::MAX)) as u32;

        // Saturating arithmetic: after a long idle period the raw product/sum can exceed
        // `u32::MAX`; the result is clamped to `max` anyway.
        let refilled = (*tokens).saturating_add(self.refill_amount.saturating_mul(refills_since));
        *tokens = refilled.min(self.max);

        // `refill_interval * refills_since` never exceeds `time_passed`, so this cannot move
        // `last_refill` into the future.
        *last_refill += self.refill_interval * refills_since;

        *tokens
    }

    /// Returns the current token count after crediting elapsed refills.
    #[inline]
    fn tokens(&self) -> u32 {
        self.update_tokens()
    }

    /// Returns the instant at which the next refill period completes.
    fn next_refill(&self) -> Instant {
        self.update_tokens();
        self.last_refill_instant() + self.refill_interval
    }

    /// Takes `amount` tokens, sleeping until enough refills have happened if the bucket does
    /// not currently hold that many.
    ///
    /// The semaphore permit is held for the whole call, so concurrent callers are served one
    /// at a time.
    ///
    /// # Panics
    ///
    /// Panics if the bucket holds fewer than `amount` tokens and `refill_amount` is zero,
    /// because no amount of waiting could ever satisfy the request.
    async fn acquire(&self, amount: u32) {
        let _permit = self.semaphore.acquire().await;

        let current_tokens = self.update_tokens();
        if current_tokens < amount {
            assert!(
                self.refill_amount > 0,
                "Cannot wait for tokens: the bucket's refill amount is zero"
            );

            let tokens_needed = amount - current_tokens;
            // Ceiling division: a partial refill still requires waiting a full interval.
            let refills_needed = tokens_needed / self.refill_amount
                + u32::from(tokens_needed % self.refill_amount != 0);

            let target_time = self.last_refill_instant() + self.refill_interval * refills_needed;
            sleep_until(target_time).await;

            self.update_tokens();
        }

        self.take_tokens(amount);
    }
}
/// Public handle to a bucket.
///
/// The state lives behind an [`Arc`], so cloning a `LeakyBucket` yields another handle to the
/// same underlying bucket rather than an independent copy.
#[derive(Clone, Debug)]
pub struct LeakyBucket {
/// Reference-counted bucket state shared by all clones of this handle.
inner: Arc<LeakyBucketInner>,
}
impl LeakyBucket {
    /// Builds a handle around freshly created, reference-counted bucket state.
    fn new(max: u32, tokens: u32, refill_interval: Duration, refill_amount: u32) -> Self {
        Self {
            inner: Arc::new(LeakyBucketInner::new(max, tokens, refill_interval, refill_amount)),
        }
    }

    /// Returns a [`Builder`] with no options set.
    #[must_use]
    pub const fn builder() -> Builder {
        Builder::new()
    }

    /// Returns the configured maximum number of tokens.
    #[must_use]
    pub fn max(&self) -> u32 {
        self.inner.max
    }

    /// Returns the number of tokens currently in the bucket, after crediting any refills
    /// that have become due.
    #[must_use]
    pub fn tokens(&self) -> u32 {
        self.inner.tokens()
    }

    /// Returns the instant at which the next refill period completes.
    #[must_use]
    pub fn next_refill(&self) -> Instant {
        self.inner.next_refill()
    }

    /// Acquires a single token; shorthand for `acquire(1)`.
    ///
    /// # Panics
    ///
    /// Panics under the same condition as [`LeakyBucket::acquire`], i.e. when the configured
    /// maximum is zero.
    #[inline]
    pub async fn acquire_one(&self) {
        self.acquire(1).await;
    }

    /// Acquires `amount` tokens, waiting for refills if the bucket does not hold enough.
    ///
    /// # Panics
    ///
    /// Panics if `amount` is greater than the configured maximum.
    pub async fn acquire(&self, amount: u32) {
        let limit = self.max();
        assert!(
            limit >= amount,
            "Acquiring more tokens than the configured maximum is not possible"
        );

        self.inner.acquire(amount).await;
    }
}
/// Collects optional settings for a [`LeakyBucket`].
///
/// Every field starts out as `None`; `build` substitutes a default for each option that was
/// never set.
#[derive(Debug)]
pub struct Builder {
/// Maximum number of tokens, if set.
max: Option<u32>,
/// Initial number of tokens, if set.
tokens: Option<u32>,
/// Time between refills, if set.
refill_interval: Option<Duration>,
/// Number of tokens added per refill, if set.
refill_amount: Option<u32>,
}
impl Builder {
    /// Creates a builder with every option unset.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            refill_amount: None,
            refill_interval: None,
            tokens: None,
            max: None,
        }
    }

    /// Sets the maximum number of tokens the bucket may hold.
    #[must_use]
    pub const fn max(self, max: u32) -> Self {
        Self {
            max: Some(max),
            ..self
        }
    }

    /// Sets the number of tokens the bucket starts with.
    #[must_use]
    pub const fn tokens(self, tokens: u32) -> Self {
        Self {
            tokens: Some(tokens),
            ..self
        }
    }

    /// Sets the time between two refills.
    #[must_use]
    pub const fn refill_interval(self, refill_interval: Duration) -> Self {
        Self {
            refill_interval: Some(refill_interval),
            ..self
        }
    }

    /// Sets how many tokens are added on each refill.
    #[must_use]
    pub const fn refill_amount(self, refill_amount: u32) -> Self {
        Self {
            refill_amount: Some(refill_amount),
            ..self
        }
    }

    /// Consumes the builder and creates the [`LeakyBucket`].
    ///
    /// Options that were never set fall back to: a maximum of 120 tokens, 0 initial tokens,
    /// a refill interval of one second and a refill amount of 1.
    #[must_use]
    pub fn build(self) -> LeakyBucket {
        const DEFAULT_MAX: u32 = 120;
        const DEFAULT_TOKENS: u32 = 0;
        const DEFAULT_REFILL_INTERVAL: Duration = Duration::from_secs(1);
        const DEFAULT_REFILL_AMOUNT: u32 = 1;

        let Self {
            max,
            tokens,
            refill_interval,
            refill_amount,
        } = self;

        LeakyBucket::new(
            max.unwrap_or(DEFAULT_MAX),
            tokens.unwrap_or(DEFAULT_TOKENS),
            refill_interval.unwrap_or(DEFAULT_REFILL_INTERVAL),
            refill_amount.unwrap_or(DEFAULT_REFILL_AMOUNT),
        )
    }
}
impl Default for Builder {
fn default() -> Self {
Self::new()
}
}