#![forbid(unsafe_code)]
use std::time::{Duration, Instant};
#[cfg(loom)]
use loom::sync::atomic::{AtomicU64, Ordering};
#[cfg(loom)]
use loom::sync::Mutex;
#[cfg(not(loom))]
use std::sync::atomic::{AtomicU64, Ordering};
#[cfg(not(loom))]
use std::sync::Mutex;
/// A source of [`Instant`] readings.
///
/// [`Throttle`] measures refill time as the difference between successive
/// `now()` readings, so swapping the clock (real time via [`SystemClock`], or a
/// test-driven [`ManualClock`]) changes how quickly tokens accrue.
pub trait Clock: Sync + Send {
    /// Returns this clock's current reading.
    fn now(&self) -> Instant;
}

/// The real clock: every reading is [`Instant::now`].
#[derive(Debug, Default, Clone, Copy)]
pub struct SystemClock;

impl Clock for SystemClock {
    fn now(&self) -> Instant {
        Instant::now()
    }
}

// Forwarding impls: a borrowed or shared handle to a clock is itself a clock, so
// callers can keep their own handle (e.g. to advance a `ManualClock`) while the
// throttle holds another.
impl<C: Clock + ?Sized> Clock for &C {
    fn now(&self) -> Instant {
        C::now(*self)
    }
}

impl<C: Clock + ?Sized> Clock for std::sync::Arc<C> {
    fn now(&self) -> Instant {
        C::now(self)
    }
}

#[cfg(loom)]
impl<C: Clock + ?Sized> Clock for loom::sync::Arc<C> {
    fn now(&self) -> Instant {
        C::now(self)
    }
}
/// A [`Clock`] that moves only when [`ManualClock::advance`] is called.
///
/// It captures the real instant at construction and reports that instant plus
/// the total of all advances so far. This makes refill timing deterministic.
#[derive(Debug)]
pub struct ManualClock {
    /// Real instant captured by `new`; the clock's reading before any advance.
    base: Instant,
    /// Total advance applied so far, in nanoseconds (saturates at `u64::MAX`).
    offset_nanos: AtomicU64,
}

impl ManualClock {
    /// Creates a clock whose initial reading is the real `Instant::now()` at call time.
    pub fn new() -> Self {
        let base = Instant::now();
        let offset_nanos = AtomicU64::new(0);
        ManualClock { base, offset_nanos }
    }

    /// Moves the clock forward by `delta`.
    ///
    /// Both the conversion of `delta` to nanoseconds and the running total saturate
    /// at `u64::MAX` instead of wrapping. Concurrent calls are folded together with a
    /// compare-and-swap retry loop, so no advance is lost.
    pub fn advance(&self, delta: Duration) {
        let step = match u64::try_from(delta.as_nanos()) {
            Ok(nanos) => nanos,
            Err(_) => u64::MAX,
        };
        let mut observed = self.offset_nanos.load(Ordering::Acquire);
        // Retry until our saturating add lands on the value we last observed.
        while let Err(actual) = self.offset_nanos.compare_exchange_weak(
            observed,
            observed.saturating_add(step),
            Ordering::AcqRel,
            Ordering::Acquire,
        ) {
            observed = actual;
        }
    }
}

impl Default for ManualClock {
    /// Equivalent to [`ManualClock::new`].
    fn default() -> Self {
        ManualClock::new()
    }
}
impl Clock for ManualClock {
    /// Returns the construction-time instant shifted forward by the total advanced so far.
    fn now(&self) -> Instant {
        let offset = Duration::from_nanos(self.offset_nanos.load(Ordering::Acquire));
        self.base + offset
    }
}
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum ThrottleError {
#[error("requested {requested} tokens exceeds capacity {capacity}")]
RequestExceedsCapacity {
requested: u64,
capacity: u64,
},
#[error(
"zero refill rate cannot satisfy acquire of {requested} tokens \
(only {available} available)"
)]
ZeroRefillExhausted {
requested: u64,
available: u64,
},
}
/// A token-bucket rate limiter.
///
/// The bucket holds at most `capacity` tokens and is topped up at
/// `refill_per_sec` tokens per second of time elapsed on the clock `C`
/// (by default the real [`SystemClock`]). Tokens are credited lazily, each time
/// an acquire attempt runs.
pub struct Throttle<C: Clock = SystemClock> {
    /// Maximum number of tokens the bucket can hold; larger requests always fail.
    capacity: u64,
    /// Tokens credited per second of elapsed clock time; 0 means never refilled.
    refill_per_sec: u64,
    /// Current token count. Deductions use compare-and-swap and do not hold
    /// the `last_refill` lock.
    available: AtomicU64,
    /// Clock reading up to which elapsed time has already been turned into tokens.
    /// Locked during a refill so a given stretch of time is credited once.
    last_refill: Mutex<Instant>,
    /// Time source consulted on every refill.
    clock: C,
}
impl Throttle<SystemClock> {
pub fn new(capacity: u64, refill_per_sec: u64) -> Self {
Self::with_clock(capacity, refill_per_sec, SystemClock)
}
}
impl<C: Clock> Throttle<C> {
    /// Nanoseconds per second, used to convert between elapsed time and tokens.
    const NANOS_PER_SEC: u128 = 1_000_000_000;

    /// Creates a throttle whose bucket starts full (`capacity` tokens) and refills
    /// at `refill_per_sec` tokens per second as measured by `clock`.
    ///
    /// A `refill_per_sec` of 0 yields a bucket that is never replenished.
    pub fn with_clock(capacity: u64, refill_per_sec: u64, clock: C) -> Self {
        let now = clock.now();
        Self {
            capacity,
            refill_per_sec,
            available: AtomicU64::new(capacity),
            last_refill: Mutex::new(now),
            clock,
        }
    }

    /// Maximum number of tokens the bucket can hold.
    pub fn capacity(&self) -> u64 {
        self.capacity
    }

    /// Tokens credited per second of elapsed clock time.
    pub fn refill_per_sec(&self) -> u64 {
        self.refill_per_sec
    }

    /// Returns the stored token count.
    ///
    /// This is a snapshot: it does not credit time elapsed since the last refill.
    /// `try_acquire(0)` brings the count up to date without consuming anything.
    pub fn available(&self) -> u64 {
        self.available.load(Ordering::Acquire)
    }

    /// Tries to take `n` tokens without blocking; returns whether they were taken.
    ///
    /// Fails immediately (without refilling) when `n > capacity`. Otherwise elapsed
    /// time is credited first. After that, `n == 0` always succeeds, and any other
    /// `n` succeeds only if at least `n` tokens are available.
    pub fn try_acquire(&self, n: u64) -> bool {
        if n > self.capacity {
            return false;
        }
        self.refill();
        n == 0 || self.consume(n)
    }

    /// Takes `n` tokens, sleeping the current thread until enough have accrued.
    ///
    /// Waiting uses `std::thread::sleep`, i.e. real time, regardless of the injected
    /// clock. With a clock that does not advance on its own (such as a `ManualClock`
    /// nobody advances), this does not return until another thread moves the clock
    /// forward. Each sleep estimates the time the missing tokens need, clamped to
    /// [1 ms, 1 s], after which the acquire is retried.
    ///
    /// # Errors
    ///
    /// - [`ThrottleError::RequestExceedsCapacity`] if `n > capacity` (it could never
    ///   succeed).
    /// - [`ThrottleError::ZeroRefillExhausted`] if the refill rate is zero and the
    ///   bucket does not currently hold `n` tokens.
    pub fn acquire_blocking(&self, n: u64) -> Result<(), ThrottleError> {
        const MIN_SLEEP_NANOS: u128 = 1_000_000;
        const MAX_SLEEP_NANOS: u128 = 1_000_000_000;

        if n > self.capacity {
            return Err(ThrottleError::RequestExceedsCapacity {
                requested: n,
                capacity: self.capacity,
            });
        }
        if n == 0 || self.try_acquire(n) {
            return Ok(());
        }
        if self.refill_per_sec == 0 {
            // Nothing will ever be added, so waiting cannot help.
            return Err(ThrottleError::ZeroRefillExhausted {
                requested: n,
                available: self.available(),
            });
        }
        let rate = u128::from(self.refill_per_sec);
        loop {
            // Estimate how long the missing tokens take to accrue. The clamp keeps
            // us from spinning (>= 1 ms) while still re-checking regularly during
            // long waits (<= 1 s).
            let missing = n.saturating_sub(self.available()).max(1);
            let wait_nanos = (u128::from(missing).saturating_mul(Self::NANOS_PER_SEC) / rate)
                .clamp(MIN_SLEEP_NANOS, MAX_SLEEP_NANOS);
            // Cannot fail: `wait_nanos <= MAX_SLEEP_NANOS`, which fits in a u64.
            let wait = Duration::from_nanos(u64::try_from(wait_nanos).unwrap_or(u64::MAX));
            std::thread::sleep(wait);
            if self.try_acquire(n) {
                return Ok(());
            }
        }
    }

    /// Converts clock time elapsed since the last refill into tokens.
    ///
    /// Only whole tokens are credited. `last_refill` advances by exactly the time
    /// those tokens represent, so a fractional remainder carries over to the next
    /// call instead of being lost. Tokens beyond `capacity` are discarded.
    fn refill(&self) {
        if self.refill_per_sec == 0 {
            return;
        }
        let mut last = self
            .last_refill
            .lock()
            .expect("invariant: throttle last_refill mutex must not be poisoned");
        // Read the clock only while holding the lock. `*last` is only ever moved up
        // to at most a reading taken under this lock, so for a monotonic clock this
        // reading is never earlier than `*last`. Sampling before locking would let a
        // thread that lost the lock race compare a stale `now` against a `*last`
        // that was already moved past it. `saturating_duration_since` still guards
        // against a non-monotonic `Clock`.
        let now = self.clock.now();
        let elapsed_nanos = now.saturating_duration_since(*last).as_nanos();
        let rate = u128::from(self.refill_per_sec);
        let new_tokens_u128 = elapsed_nanos.saturating_mul(rate) / Self::NANOS_PER_SEC;
        if new_tokens_u128 == 0 {
            return;
        }
        let new_tokens = u64::try_from(new_tokens_u128).unwrap_or(u64::MAX);
        // Time that exactly accounts for `new_tokens`; never more than the elapsed time.
        let credited_nanos = (u128::from(new_tokens) * Self::NANOS_PER_SEC) / rate;
        *last += Duration::from_nanos(u64::try_from(credited_nanos).unwrap_or(u64::MAX));
        self.credit(new_tokens);
    }

    /// Adds `tokens` to the bucket, saturating at `capacity`.
    fn credit(&self, tokens: u64) {
        let mut cur = self.available.load(Ordering::Acquire);
        loop {
            let target = cur.saturating_add(tokens).min(self.capacity);
            if target == cur {
                // Already full; nothing to store.
                return;
            }
            match self.available.compare_exchange_weak(
                cur,
                target,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return,
                Err(actual) => cur = actual,
            }
        }
    }

    /// Atomically deducts `n` tokens if at least `n` are present; returns whether it did.
    fn consume(&self, n: u64) -> bool {
        let mut cur = self.available.load(Ordering::Acquire);
        loop {
            if cur < n {
                return false;
            }
            match self.available.compare_exchange_weak(
                cur,
                cur - n,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return true,
                Err(actual) => cur = actual,
            }
        }
    }
}
impl<C: Clock> std::fmt::Debug for Throttle<C> {
    /// Shows the configuration and a snapshot of the token count. The clock and
    /// refill timestamp are left out, which `finish_non_exhaustive` marks with `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Throttle");
        builder.field("capacity", &self.capacity);
        builder.field("refill_per_sec", &self.refill_per_sec);
        builder.field("available", &self.available());
        builder.finish_non_exhaustive()
    }
}
#[cfg(all(test, not(loom)))]
mod tests {
    use super::*;
    use hegel::generators as gs;
    use hegel::TestCase;
    use std::sync::Arc;

    #[test]
    fn new_starts_full() {
        let t: Throttle = Throttle::new(8, 1);
        assert_eq!(t.capacity(), 8);
        assert_eq!(t.refill_per_sec(), 1);
        assert_eq!(t.available(), 8);
    }

    #[test]
    fn try_acquire_zero_is_always_true_even_when_empty() {
        let t: Throttle = Throttle::new(2, 0);
        assert!(t.try_acquire(2));
        assert_eq!(t.available(), 0);
        assert!(t.try_acquire(0));
    }

    #[test]
    fn try_acquire_above_capacity_fails_fast() {
        let t: Throttle = Throttle::new(4, 100);
        assert!(!t.try_acquire(5));
        assert_eq!(t.available(), 4);
    }

    #[test]
    fn manual_clock_drives_refill() {
        let clock = Arc::new(ManualClock::new());
        let t = Throttle::with_clock(10, 100, Arc::clone(&clock));
        assert!(t.try_acquire(10));
        assert_eq!(t.available(), 0);
        clock.advance(Duration::from_millis(100));
        assert!(t.try_acquire(0));
        assert_eq!(t.available(), 10);
    }

    #[test]
    fn manual_clock_caps_at_capacity() {
        let clock = Arc::new(ManualClock::new());
        let t = Throttle::with_clock(5, 1, Arc::clone(&clock));
        assert!(t.try_acquire(5));
        clock.advance(Duration::from_secs(3600));
        assert!(t.try_acquire(0));
        assert_eq!(t.available(), 5);
    }

    #[test]
    fn manual_clock_zero_refill_does_not_replenish() {
        let clock = Arc::new(ManualClock::new());
        let t = Throttle::with_clock(3, 0, Arc::clone(&clock));
        assert!(t.try_acquire(3));
        clock.advance(Duration::from_secs(60));
        assert!(!t.try_acquire(1));
    }

    #[test]
    fn manual_clock_carries_fractional_refill_time() {
        // At 3 tokens/s, 500 ms yields only 1 whole token. The leftover ~167 ms
        // must count toward later tokens, so one full second yields exactly 3.
        let clock = Arc::new(ManualClock::new());
        let t = Throttle::with_clock(10, 3, Arc::clone(&clock));
        assert!(t.try_acquire(10));
        clock.advance(Duration::from_millis(500));
        assert!(t.try_acquire(0));
        assert_eq!(t.available(), 1);
        clock.advance(Duration::from_millis(500));
        assert!(t.try_acquire(0));
        assert_eq!(
            t.available(),
            3,
            "one second at 3 tokens/s must yield 3 tokens"
        );
    }

    #[test]
    fn concurrent_try_acquire_grants_exactly_capacity_with_frozen_clock() {
        let clock = Arc::new(ManualClock::new());
        let t = Arc::new(Throttle::with_clock(100, 1_000, Arc::clone(&clock)));
        let workers: Vec<_> = (0..8)
            .map(|_| {
                let t = Arc::clone(&t);
                std::thread::spawn(move || (0..50).filter(|_| t.try_acquire(1)).count())
            })
            .collect();
        let granted: usize = workers
            .into_iter()
            .map(|h| h.join().expect("worker thread panicked"))
            .sum();
        // 400 attempts against 100 tokens and a frozen clock: every token is
        // handed out exactly once.
        assert_eq!(granted, 100);
        assert_eq!(t.available(), 0);
    }

    #[test]
    fn acquire_blocking_above_capacity_returns_typed_error() {
        let t: Throttle = Throttle::new(2, 1);
        let err = t.acquire_blocking(5).unwrap_err();
        assert_eq!(
            err,
            ThrottleError::RequestExceedsCapacity {
                requested: 5,
                capacity: 2,
            }
        );
    }

    #[test]
    fn acquire_blocking_zero_refill_with_empty_bucket_returns_error() {
        let t: Throttle = Throttle::new(1, 0);
        assert!(t.try_acquire(1));
        let err = t.acquire_blocking(1).unwrap_err();
        assert!(matches!(
            err,
            ThrottleError::ZeroRefillExhausted {
                requested: 1,
                available: 0,
            }
        ));
    }

    #[test]
    fn acquire_blocking_zero_request_is_noop() {
        let t: Throttle = Throttle::new(1, 0);
        t.acquire_blocking(0).unwrap();
        assert_eq!(t.available(), 1);
    }

    #[test]
    fn acquire_blocking_waits_for_refill() {
        // 2 tokens at 200/s take ~10 ms to accrue. The lower bound is looser
        // (5 ms) to tolerate timer granularity.
        let t = Throttle::new(2, 200);
        assert!(t.try_acquire(2));
        let start = Instant::now();
        t.acquire_blocking(2).unwrap();
        let elapsed = start.elapsed();
        assert!(
            elapsed >= Duration::from_millis(5),
            "acquire returned in {elapsed:?}, expected roughly 10ms (at least 5ms)"
        );
        assert!(
            elapsed < Duration::from_secs(2),
            "acquire took unexpectedly long: {elapsed:?}"
        );
    }

    #[hegel::test(test_cases = 64)]
    fn manual_clock_envelope_holds(tc: TestCase) {
        let capacity = tc.draw(gs::integers::<u64>().min_value(1).max_value(64));
        let refill = tc.draw(gs::integers::<u64>().min_value(1).max_value(1_000));
        let req_sizes: Vec<u64> = (0..16)
            .map(|_| tc.draw(gs::integers::<u64>().min_value(0).max_value(capacity)))
            .collect();
        let clock = Arc::new(ManualClock::new());
        let t = Throttle::with_clock(capacity, refill, Arc::clone(&clock));
        let mut granted: u128 = 0;
        for n in &req_sizes {
            if t.try_acquire(*n) {
                granted += u128::from(*n);
            }
        }
        assert!(
            granted <= u128::from(capacity),
            "granted {granted} > capacity {capacity} with frozen clock"
        );
    }
}