use core::time::Duration;
use better_bucket::{Bucket, Decision as BucketDecision};
use clock_lib::{Clock, SystemClock};
use crate::decision::Decision;
#[cfg(feature = "runtime")]
use crate::error::ThrottleError;
use crate::limiter::Limiter;
/// Rate limiter backed by a refilling [`Bucket`] of tokens.
///
/// Every operation delegates to the wrapped bucket. The clock type `C`
/// defaults to [`SystemClock`] and can be replaced with
/// [`Throttle::with_clock`], e.g. to drive refills from a manual clock in
/// tests.
#[derive(Debug)]
pub struct Throttle<C: Clock = SystemClock> {
    // The bucket that stores tokens and answers every acquire/peek query.
    bucket: Bucket<C>,
}
impl Throttle<SystemClock> {
    /// Creates a throttle for `rate` acquisitions per second, driven by the
    /// system clock.
    ///
    /// The new throttle starts with `rate` tokens available
    /// (built with `Bucket::per_second`).
    #[must_use]
    pub fn per_second(rate: u32) -> Self {
        let bucket = Bucket::per_second(rate);
        Self { bucket }
    }

    /// Creates a throttle for `amount` acquisitions per `period`, driven by
    /// the system clock.
    ///
    /// Construction is delegated to `Bucket::per_duration`; refill
    /// granularity is whatever that constructor configures.
    #[must_use]
    pub fn per_duration(amount: u32, period: Duration) -> Self {
        let bucket = Bucket::per_duration(amount, period);
        Self { bucket }
    }
}
impl<C: Clock> Throttle<C> {
    /// Consumes this throttle and returns one whose bucket is driven by
    /// `clock` (via `Bucket::with_clock`).
    #[must_use]
    pub fn with_clock<C2: Clock>(self, clock: C2) -> Throttle<C2> {
        let bucket = self.bucket.with_clock(clock);
        Throttle { bucket }
    }

    /// Returns the bucket's capacity: the largest cost that can ever be
    /// granted in a single request.
    #[inline]
    #[must_use]
    pub fn capacity(&self) -> u32 {
        self.bucket.capacity()
    }

    /// Returns the number of tokens currently available in the bucket.
    #[inline]
    #[must_use]
    pub fn available(&self) -> u32 {
        self.bucket.available()
    }

    /// Tries to take one token without waiting; `true` if it was taken.
    #[inline]
    #[must_use]
    pub fn try_acquire(&self) -> bool {
        self.try_acquire_with_cost(1)
    }

    /// Tries to take `cost` tokens at once without waiting.
    ///
    /// All-or-nothing: returns `true` only if the whole cost was taken.
    #[inline]
    #[must_use]
    pub fn try_acquire_with_cost(&self, cost: u32) -> bool {
        self.bucket.try_acquire(cost)
    }

    /// Reports what acquiring `cost` tokens would do right now, without
    /// consuming anything.
    ///
    /// Returns [`Decision::Impossible`] when `cost` exceeds the capacity or
    /// the bucket never refills, [`Decision::Acquired`] when enough tokens
    /// are available, and otherwise [`Decision::Retry`] with an estimate of
    /// how long until the missing tokens have been refilled.
    #[inline]
    #[must_use]
    pub fn peek(&self, cost: u32) -> Decision {
        // A request larger than the bucket can ever hold is never satisfiable.
        if cost > self.bucket.capacity() {
            return Decision::Impossible;
        }
        // `None` means nothing is missing: enough tokens are already present.
        let Some(deficit) = cost
            .checked_sub(self.bucket.available())
            .filter(|&missing| missing > 0)
        else {
            return Decision::Acquired;
        };
        let (refill_amount, period) = {
            let cfg = self.bucket.config();
            (cfg.refill_amount(), cfg.refill_period())
        };
        // With no refill at all the deficit can never be made up.
        if refill_amount == 0 || period.is_zero() {
            return Decision::Impossible;
        }
        Decision::Retry {
            after: estimate_refill_wait(period, deficit, refill_amount),
        }
    }

    /// Attempts to consume `cost` tokens and translates the bucket's answer
    /// into a [`Decision`].
    ///
    /// A denial whose `retry_after` is `Duration::MAX` is treated as a
    /// request that can never succeed.
    #[inline]
    fn decide(&self, cost: u32) -> Decision {
        match self.bucket.acquire(cost) {
            BucketDecision::Allowed => Decision::Acquired,
            BucketDecision::Denied { retry_after } => {
                if retry_after == Duration::MAX {
                    Decision::Impossible
                } else {
                    Decision::Retry { after: retry_after }
                }
            }
            // Any other bucket outcome is conservatively reported as impossible.
            _ => Decision::Impossible,
        }
    }
}
#[cfg(feature = "runtime")]
#[cfg_attr(docsrs, doc(cfg(feature = "runtime")))]
impl<C: Clock> Throttle<C> {
    /// Waits asynchronously until one token can be taken, then takes it.
    ///
    /// # Errors
    ///
    /// Same as [`Throttle::acquire_with_cost`] with a cost of `1`.
    pub async fn acquire(&self) -> Result<(), ThrottleError> {
        self.acquire_with_cost(1).await
    }

    /// Waits asynchronously until `cost` tokens can be taken at once.
    ///
    /// Between attempts it sleeps for the retry delay reported by the
    /// bucket. Wait time and a trace event are recorded under the
    /// `"throttle"` label whether or not the call succeeds. A success
    /// counter is bumped only on success.
    ///
    /// # Errors
    ///
    /// Returns [`ThrottleError::CostExceedsCapacity`] when the bucket
    /// reports the request as impossible.
    /// NOTE(review): that also covers a denial with an unbounded
    /// `retry_after`, which may not strictly mean `cost > capacity`.
    pub async fn acquire_with_cost(&self, cost: u32) -> Result<(), ThrottleError> {
        let timer = crate::obs::Timer::start();
        let outcome = self.wait_until_granted(cost).await;
        let granted = outcome.is_ok();
        if granted {
            crate::obs::acquired("throttle");
        }
        crate::obs::wait("throttle", &timer);
        crate::obs::trace_acquire("throttle", cost, granted, &timer);
        outcome
    }

    /// Retries `decide` until it grants `cost` tokens or deems them
    /// impossible, sleeping for each suggested retry delay in between.
    async fn wait_until_granted(&self, cost: u32) -> Result<(), ThrottleError> {
        loop {
            match self.decide(cost) {
                Decision::Acquired => return Ok(()),
                Decision::Impossible => {
                    return Err(ThrottleError::CostExceedsCapacity {
                        cost,
                        capacity: self.capacity(),
                    });
                }
                Decision::Retry { after } => crate::rt::sleep(after).await,
            }
        }
    }
}
/// Estimates how long until `deficit` missing tokens have been refilled,
/// assuming tokens accrue linearly at `refill_amount` per `period`.
///
/// Computes `period * deficit / refill_amount`, rounded up to the next whole
/// nanosecond so the result never undershoots that linear estimate.
/// NOTE(review): if the bucket refills in discrete steps, this is a lower
/// bound rather than an exact wake-up time.
///
/// Returns [`Duration::MAX`] in two cases:
/// - `refill_amount` is zero, so the deficit can never be refilled. This
///   case used to panic with a division by zero.
/// - the wait does not fit in a `Duration`.
///
/// Earlier versions clamped the result to `u64::MAX` nanoseconds
/// (~584 years), which silently truncated longer waits.
#[inline]
fn estimate_refill_wait(period: Duration, deficit: u32, refill_amount: u32) -> Duration {
    const NANOS_PER_SEC: u128 = 1_000_000_000;

    if refill_amount == 0 {
        return Duration::MAX;
    }
    // u128 comfortably holds `Duration::MAX` in nanos times `u32::MAX`;
    // saturating_mul is purely defensive.
    let numerator = period.as_nanos().saturating_mul(u128::from(deficit));
    let nanos = numerator.div_ceil(u128::from(refill_amount));
    // Split into whole seconds + sub-second nanos so the full `Duration`
    // range is usable instead of only what fits in `u64` nanoseconds.
    match (
        u64::try_from(nanos / NANOS_PER_SEC),
        u32::try_from(nanos % NANOS_PER_SEC),
    ) {
        // `subsec < 1e9`, so `Duration::new` cannot overflow here.
        (Ok(secs), Ok(subsec)) => Duration::new(secs, subsec),
        _ => Duration::MAX,
    }
}
impl<C: Clock> Limiter for Throttle<C> {
    /// Non-consuming check; forwards to the inherent [`Throttle::peek`].
    #[inline]
    fn peek(&self, cost: u32) -> Decision {
        Throttle::<C>::peek(self, cost)
    }

    /// Consumes `cost` tokens if possible, reporting the outcome as a
    /// [`Decision`].
    #[inline]
    fn acquire_cost(&self, cost: u32) -> Decision {
        self.decide(cost)
    }

    /// Forwards to the inherent [`Throttle::available`].
    #[inline]
    fn available(&self) -> u32 {
        Throttle::<C>::available(self)
    }

    /// Forwards to the inherent [`Throttle::capacity`].
    #[inline]
    fn capacity(&self) -> u32 {
        Throttle::<C>::capacity(self)
    }
}
#[cfg(test)]
mod tests {
    // Unwrapping in tests is intentional: a panic is the desired failure signal.
    #![allow(clippy::unwrap_used)]
    use super::estimate_refill_wait;
    use super::Throttle;
    use crate::decision::Decision;
    use crate::error::ThrottleError;
    use crate::limiter::Limiter;
    use clock_lib::ManualClock;
    use core::time::Duration;
    use std::sync::Arc;

    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn test_public_types_are_send_sync() {
        assert_send_sync::<Throttle>();
        assert_send_sync::<Decision>();
        assert_send_sync::<ThrottleError>();
    }

    #[test]
    fn test_try_acquire_grants_up_to_capacity_then_refuses() {
        let throttle = Throttle::per_second(3);
        assert!(throttle.try_acquire());
        assert!(throttle.try_acquire());
        assert!(throttle.try_acquire());
        assert!(!throttle.try_acquire());
    }

    #[test]
    fn test_try_acquire_with_cost_is_all_or_nothing() {
        let throttle = Throttle::per_second(10);
        assert!(throttle.try_acquire_with_cost(7));
        assert!(!throttle.try_acquire_with_cost(7));
        assert!(throttle.try_acquire_with_cost(3));
    }

    #[test]
    fn test_refill_after_a_full_period_under_manual_clock() {
        let clock = Arc::new(ManualClock::new());
        let throttle = Throttle::per_second(4).with_clock(clock.clone());
        for _ in 0..4 {
            assert!(throttle.try_acquire());
        }
        assert!(!throttle.try_acquire());
        clock.advance(Duration::from_secs(1));
        assert!(throttle.try_acquire());
    }

    #[test]
    fn test_peek_reports_without_consuming() {
        let clock = Arc::new(ManualClock::new());
        let throttle = Throttle::per_second(4).with_clock(clock.clone());
        assert_eq!(throttle.peek(1), Decision::Acquired);
        // Peeking must not take any tokens.
        assert_eq!(throttle.available(), 4);
        assert_eq!(throttle.peek(5), Decision::Impossible);
        assert!(throttle.try_acquire_with_cost(4));
        assert!(matches!(throttle.peek(1), Decision::Retry { .. }));
    }

    #[test]
    fn test_estimate_refill_wait_rounds_up_linear_estimate() {
        assert_eq!(
            estimate_refill_wait(Duration::from_secs(1), 1, 4),
            Duration::from_millis(250)
        );
        assert_eq!(
            estimate_refill_wait(Duration::from_secs(1), 3, 4),
            Duration::from_millis(750)
        );
        assert_eq!(
            estimate_refill_wait(Duration::from_nanos(10), 1, 3),
            Duration::from_nanos(4)
        );
    }

    #[test]
    fn test_acquire_cost_reports_retry_then_impossible() {
        let throttle = Throttle::per_second(2);
        assert_eq!(throttle.acquire_cost(2), Decision::Acquired);
        assert!(matches!(throttle.acquire_cost(1), Decision::Retry { .. }));
        assert_eq!(throttle.acquire_cost(3), Decision::Impossible);
    }

    #[test]
    fn test_available_tracks_consumption() {
        let throttle = Throttle::per_second(5);
        assert_eq!(throttle.available(), 5);
        assert!(throttle.try_acquire_with_cost(2));
        assert_eq!(throttle.available(), 3);
    }

    // The async API (`acquire`, `acquire_with_cost`) only exists with the
    // `runtime` feature; without this gate `cargo test` fails to compile.
    #[cfg(feature = "runtime")]
    #[tokio::test]
    async fn test_acquire_returns_immediately_when_a_token_is_free() {
        let throttle = Throttle::per_second(1);
        assert!(throttle.acquire().await.is_ok());
    }

    #[cfg(feature = "runtime")]
    #[tokio::test]
    async fn test_acquire_with_cost_errors_when_cost_exceeds_capacity() {
        let throttle = Throttle::per_second(5);
        let err = throttle.acquire_with_cost(9).await.unwrap_err();
        assert_eq!(
            err,
            ThrottleError::CostExceedsCapacity {
                cost: 9,
                capacity: 5,
            }
        );
    }

    #[cfg(feature = "runtime")]
    #[tokio::test]
    async fn test_acquire_waits_for_refill_then_succeeds() {
        let throttle = Throttle::per_second(1000);
        for _ in 0..1000 {
            assert!(throttle.try_acquire());
        }
        assert!(!throttle.try_acquire());
        assert!(throttle.acquire().await.is_ok());
    }
}