use core::time::Duration;
#[cfg(loom)]
use loom::sync::atomic::{AtomicU64, Ordering};
use clock_lib::{Clock, Monotonic, SystemClock};
#[cfg(not(loom))]
use core::sync::atomic::{AtomicU64, Ordering};
use crate::config::BucketConfig;
use crate::decision::Decision;
/// Number of millitokens in one whole token. Balances are tracked in
/// millitokens and converted with this factor at the API boundary.
const MILLI: u64 = 1_000;
/// Number of fractional bits in the fixed-point refill rate
/// (`refill_per_ms_q`, millitokens per millisecond).
const REFILL_FRAC_BITS: u32 = 22;
/// Packs a millitoken balance and a millisecond timestamp into one `u64`.
///
/// The balance is clamped to `u32::MAX` and placed in the upper 32 bits;
/// `last_ms` occupies the lower 32 bits.
#[inline]
fn pack(millitokens: u64, last_ms: u32) -> u64 {
    let balance = core::cmp::min(millitokens, u64::from(u32::MAX));
    let high = balance << 32;
    let low = u64::from(last_ms);
    high | low
}
/// Splits a packed state word into `(millitokens, last_ms)`.
///
/// The upper 32 bits are the balance and the lower 32 bits the timestamp.
#[inline]
fn unpack(state: u64) -> (u64, u32) {
    const LOW_MASK: u64 = 0xFFFF_FFFF;
    let millitokens = state >> 32;
    // The mask leaves at most 32 significant bits, so the cast is lossless.
    let last_ms = (state & LOW_MASK) as u32;
    (millitokens, last_ms)
}
/// Builds the configuration used by the shorthand constructors.
///
/// If `capacity` or `amount` is zero, or `period` is zero-length, an all-zero
/// configuration with a one-second period is produced instead of the
/// requested one; otherwise the arguments are forwarded unchanged.
fn tier1_config(capacity: u32, amount: u32, period: Duration, initial: u32) -> BucketConfig {
    let usable = capacity != 0 && amount != 0 && !period.is_zero();
    if !usable {
        return BucketConfig::raw(0, 0, Duration::from_secs(1), 0);
    }
    BucketConfig::raw(capacity, amount, period, initial)
}
/// Token bucket whose mutable state lives in a single atomic word.
///
/// `state` packs the current balance in millitokens (upper 32 bits) together
/// with the millisecond timestamp of the last update (lower 32 bits); see
/// `pack` and `unpack`. Timestamps are measured from `created_at`.
// NOTE(review): the 64-byte alignment presumably keeps each bucket on its own
// cache line to avoid false sharing — confirm against the crate's design notes.
#[repr(align(64))]
pub struct Bucket<C: Clock = SystemClock> {
/// Packed `(millitokens, last_ms)` pair.
state: AtomicU64,
/// Upper bound of the balance, in millitokens (never above `u32::MAX`
/// when constructed through `build`).
capacity_millitokens: u64,
/// Refill rate in millitokens per millisecond, as a fixed-point value with
/// `REFILL_FRAC_BITS` fractional bits. Zero means the bucket never refills.
refill_per_ms_q: u128,
/// Clock reading taken at construction; the origin for `now_ms`.
created_at: Monotonic,
/// Configuration the bucket was built from.
config: BucketConfig,
/// Time source used for refill accounting.
clock: C,
}
/// Constructs a bucket from `config`, using `clock` as its time source.
///
/// * The capacity is converted to millitokens and clamped to `u32::MAX` so it
///   fits the packed state word.
/// * The refill rate is precomputed as millitokens per millisecond in fixed
///   point (`amount * 1e9 / period_nanos`, scaled by `2^REFILL_FRAC_BITS`);
///   it is zero when the amount or the period is zero.
/// * The initial balance is clamped to the capacity and stamped with time 0,
///   i.e. the moment `clock.now()` was sampled here.
fn build<C: Clock>(config: BucketConfig, clock: C) -> Bucket<C> {
    let created_at = clock.now();

    let capacity_millitokens = core::cmp::min(
        u64::from(config.capacity()).saturating_mul(MILLI),
        u64::from(u32::MAX),
    );

    let amount = config.refill_amount();
    let period = config.refill_period();
    let refill_per_ms_q = if amount != 0 && !period.is_zero() {
        // Periods longer than `u64::MAX` ns saturate; the floor of 1 ns keeps
        // the division well-defined.
        let period_nanos = u64::try_from(period.as_nanos()).map_or(u64::MAX, |ns| ns.max(1));
        let scaled_amount = (u128::from(amount) * 1_000_000_000) << REFILL_FRAC_BITS;
        scaled_amount / u128::from(period_nanos)
    } else {
        0
    };

    let requested_initial = u64::from(config.initial()) * MILLI;
    let initial_millitokens = core::cmp::min(requested_initial, capacity_millitokens);

    Bucket {
        state: AtomicU64::new(pack(initial_millitokens, 0)),
        capacity_millitokens,
        refill_per_ms_q,
        created_at,
        config,
        clock,
    }
}
impl Bucket<SystemClock> {
    /// Creates a bucket holding `rate` tokens that refills `rate` tokens every
    /// second and starts full.
    ///
    /// A `rate` of zero yields a bucket with zero capacity.
    #[must_use]
    pub fn per_second(rate: u32) -> Self {
        Self::per_duration(rate, Duration::from_secs(1))
    }

    /// Creates a bucket holding `amount` tokens that refills `amount` tokens
    /// every `period` and starts full.
    ///
    /// A zero `amount` or zero-length `period` yields a bucket with zero
    /// capacity.
    #[must_use]
    pub fn per_duration(amount: u32, period: Duration) -> Self {
        let config = tier1_config(amount, amount, period, amount);
        Self::from_config(config)
    }

    /// Creates a bucket from an explicit configuration, timed by a fresh
    /// [`SystemClock`].
    #[must_use]
    pub fn from_config(config: BucketConfig) -> Self {
        let clock = SystemClock::new();
        build(config, clock)
    }
}
impl<C: Clock> Bucket<C> {
    /// Rebuilds the bucket on top of a different clock.
    ///
    /// Only the configuration is carried over: the new bucket is constructed
    /// from scratch, so its balance starts at the configured initial value and
    /// its time origin is the moment of this call on `clock`.
    #[must_use]
    pub fn with_clock<C2: Clock>(self, clock: C2) -> Bucket<C2> {
        build(self.config, clock)
    }

    /// Tries to take `n` tokens and reports the outcome.
    ///
    /// Returns [`Decision::Allowed`] when the tokens were deducted. Otherwise
    /// returns [`Decision::Denied`] whose `retry_after` is the time the missing
    /// tokens need to refill, or `Duration::MAX` when the request can never
    /// succeed (it exceeds the capacity, or the bucket does not refill).
    /// A request for zero tokens is always allowed.
    #[inline]
    pub fn acquire(&self, n: u32) -> Decision {
        self.acquire_inner(n)
    }

    /// Like [`Bucket::acquire`], but only reports whether the tokens were taken.
    #[inline]
    #[must_use]
    pub fn try_acquire(&self, n: u32) -> bool {
        self.acquire_inner(n).is_allowed()
    }

    /// Returns the number of whole tokens currently available, including the
    /// refill accrued since the last update. The stored state is not modified.
    #[inline]
    #[must_use]
    pub fn available(&self) -> u32 {
        let (tokens_mt, last_ms) = unpack(self.state.load(Ordering::Acquire));
        // Sample the clock only AFTER reading the state. A sample taken before
        // the load can be older than a `last_ms` stored concurrently by another
        // thread; the wrapping difference in `refilled` would then look like a
        // ~49-day idle period and report a full bucket.
        let now_ms = self.now_ms();
        let refilled = self.refilled(tokens_mt, last_ms, now_ms);
        u32::try_from(refilled / MILLI).unwrap_or(u32::MAX)
    }

    /// Returns the capacity in whole tokens.
    #[must_use]
    pub const fn capacity(&self) -> u32 {
        // `build` clamps `capacity_millitokens` to `u32::MAX`, so the quotient
        // always fits; `as` is used because `try_from` is not `const`.
        (self.capacity_millitokens / MILLI) as u32
    }

    /// Returns a copy of the configuration this bucket was built from.
    #[must_use]
    pub const fn config(&self) -> BucketConfig {
        self.config
    }

    /// Refills the bucket to capacity and restarts refill accounting at the
    /// current time. Unconditionally overwrites any concurrent update.
    pub fn reset(&self) {
        let now_ms = self.now_ms();
        // Release pairs with the Acquire loads: a thread that observes this
        // state samples its clock after this thread sampled `now_ms`.
        self.state
            .store(pack(self.capacity_millitokens, now_ms), Ordering::Release);
    }

    /// Milliseconds elapsed since the bucket was created, reduced modulo 2^32.
    ///
    /// The counter wraps roughly every 49.7 days; `refilled` compensates by
    /// using wrapping subtraction.
    #[inline]
    fn now_ms(&self) -> u32 {
        let elapsed = self.clock.now().saturating_duration_since(self.created_at);
        (elapsed.as_millis() & u128::from(u32::MAX)) as u32
    }

    /// Computes the balance at `now_ms`, given a balance of `tokens_mt` recorded
    /// at `last_ms`. The result never exceeds the capacity.
    ///
    /// `now_ms` must not be older than `last_ms`: the elapsed time is a
    /// wrapping difference, so an older sample is indistinguishable from a very
    /// long idle period and would saturate the balance.
    #[inline]
    fn refilled(&self, tokens_mt: u64, last_ms: u32, now_ms: u32) -> u64 {
        if self.refill_per_ms_q == 0 {
            return tokens_mt;
        }
        let elapsed_ms = now_ms.wrapping_sub(last_ms);
        if elapsed_ms == 0 {
            return tokens_mt;
        }
        // Fixed-point multiply; the shift drops the fractional millitokens.
        let added = u128::from(elapsed_ms).saturating_mul(self.refill_per_ms_q) >> REFILL_FRAC_BITS;
        let added_mt = u64::try_from(added).unwrap_or(u64::MAX);
        tokens_mt
            .saturating_add(added_mt)
            .min(self.capacity_millitokens)
    }

    /// Returns the time needed to refill `deficit_mt` millitokens, rounded up
    /// to a whole millisecond, or `Duration::MAX` if the bucket never refills.
    fn time_for(&self, deficit_mt: u64) -> Duration {
        if self.refill_per_ms_q == 0 {
            return Duration::MAX;
        }
        // Ceiling division in fixed point: (deficit << FRAC) / rate, rounded up.
        let numerator =
            (u128::from(deficit_mt) << REFILL_FRAC_BITS).saturating_add(self.refill_per_ms_q - 1);
        let millis = numerator / self.refill_per_ms_q;
        Duration::from_millis(u64::try_from(millis).unwrap_or(u64::MAX))
    }

    /// Shared implementation of `acquire` / `try_acquire`: a CAS loop over the
    /// packed state word.
    #[inline]
    fn acquire_inner(&self, n: u32) -> Decision {
        if n == 0 {
            return Decision::Allowed;
        }
        let need_mt = u64::from(n) * MILLI;
        if need_mt > self.capacity_millitokens {
            // More than the bucket can ever hold: waiting will not help.
            return Decision::Denied {
                retry_after: Duration::MAX,
            };
        }
        loop {
            let current = self.state.load(Ordering::Acquire);
            let (tokens_mt, last_ms) = unpack(current);
            // The clock is sampled after every load (including CAS retries).
            // Every stored `last_ms` was sampled before it was published, so a
            // sample taken after observing it cannot be older. Sampling once
            // before the loop let a concurrent writer publish a newer
            // `last_ms`; the wrapping subtraction in `refilled` then produced a
            // huge elapsed time and refilled the bucket to capacity.
            let now_ms = self.now_ms();
            let refilled = self.refilled(tokens_mt, last_ms, now_ms);
            if refilled < need_mt {
                // Denials leave the state untouched; refill keeps accruing
                // from the previously stored timestamp.
                return Decision::Denied {
                    retry_after: self.time_for(need_mt - refilled),
                };
            }
            let next = pack(refilled - need_mt, now_ms);
            match self.state.compare_exchange_weak(
                current,
                next,
                // Release publishes the new timestamp; a failed exchange needs
                // no ordering because the loop reloads with Acquire.
                Ordering::AcqRel,
                Ordering::Relaxed,
            ) {
                Ok(_) => return Decision::Allowed,
                Err(_) => core::hint::spin_loop(),
            }
        }
    }
}
/// Debug output shows the capacity, the currently available tokens (which
/// samples the clock) and the configuration; the clock itself is omitted.
impl<C: Clock> core::fmt::Debug for Bucket<C> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("Bucket");
        out.field("capacity", &self.capacity());
        out.field("available", &self.available());
        out.field("config", &self.config);
        out.finish()
    }
}
/// Common surface of a token bucket, independent of the clock type.
///
/// All methods take `&self` and no generic parameters.
pub trait TokenBucket {
/// Tries to take `n` tokens and reports the outcome as a [`Decision`].
fn acquire(&self, n: u32) -> Decision;
/// Tries to take `n` tokens; returns `true` if they were granted.
#[must_use]
fn try_acquire(&self, n: u32) -> bool;
/// Returns the number of whole tokens currently available.
#[must_use]
fn available(&self) -> u32;
/// Returns the maximum number of whole tokens the bucket can hold.
#[must_use]
fn capacity(&self) -> u32;
}
/// Forwards every trait method to the inherent method of the same name.
///
/// The calls are written as `Bucket::method(self, ..)` paths, which resolve to
/// the inherent implementations rather than back to this trait.
impl<C: Clock> TokenBucket for Bucket<C> {
    #[inline]
    fn acquire(&self, n: u32) -> Decision {
        Bucket::acquire(self, n)
    }

    #[inline]
    fn try_acquire(&self, n: u32) -> bool {
        Bucket::try_acquire(self, n)
    }

    #[inline]
    fn available(&self) -> u32 {
        Bucket::available(self)
    }

    #[inline]
    fn capacity(&self) -> u32 {
        Bucket::capacity(self)
    }
}
// Unit tests for `Bucket`. They are compiled only for regular test builds
// (not under `loom`) and drive time through a `ManualClock` so refill
// behaviour is deterministic.
#[cfg(all(test, not(loom)))]
mod tests {
// `unwrap` is acceptable here: a failed thread join should fail the test.
#![allow(clippy::unwrap_used)]
use super::{Bucket, TokenBucket};
use crate::decision::Decision;
use clock_lib::{ManualClock, SystemClock};
use core::time::Duration;
use std::sync::Arc;
use std::thread;
// Builds a `rate`-per-second bucket driven by a manual clock and returns the
// clock handle alongside it so tests can advance time explicitly.
fn manual_bucket(rate: u32) -> (Arc<ManualClock>, Bucket<Arc<ManualClock>>) {
let clock = Arc::new(ManualClock::new());
let bucket = Bucket::per_second(rate).with_clock(Arc::clone(&clock));
(clock, bucket)
}
// Compile-time check: the bucket can be shared and sent across threads.
#[test]
fn test_bucket_is_send_and_sync() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<Bucket<SystemClock>>();
assert_send_sync::<Bucket<Arc<ManualClock>>>();
}
// A freshly built bucket reports its full capacity as available.
#[test]
fn test_starts_full() {
let (_clock, bucket) = manual_bucket(10);
assert_eq!(bucket.available(), 10);
assert_eq!(bucket.capacity(), 10);
}
// A granted request reduces the available count by exactly `n`.
#[test]
fn test_acquire_deducts_tokens() {
let (_clock, bucket) = manual_bucket(10);
assert_eq!(bucket.acquire(3), Decision::Allowed);
assert_eq!(bucket.available(), 7);
}
// Draining the bucket exactly succeeds; the next request is denied.
#[test]
fn test_exact_empty_then_denied() {
let (_clock, bucket) = manual_bucket(10);
assert!(bucket.try_acquire(10)); assert_eq!(bucket.available(), 0);
assert!(!bucket.try_acquire(1));
}
// A zero-token request is granted even when the bucket is empty.
#[test]
fn test_acquire_zero_always_allowed() {
let (_clock, bucket) = manual_bucket(1);
assert!(bucket.try_acquire(1)); assert!(!bucket.try_acquire(1));
assert!(bucket.try_acquire(0)); }
// A request larger than the capacity is denied with `Duration::MAX`.
#[test]
fn test_request_above_capacity_never_grantable() {
let (_clock, bucket) = manual_bucket(5);
assert_eq!(
bucket.acquire(6),
Decision::Denied {
retry_after: Duration::MAX
}
);
}
// After one full refill period an emptied bucket is full again.
#[test]
fn test_full_refill_after_one_period() {
let (clock, bucket) = manual_bucket(10);
assert!(bucket.try_acquire(10));
assert!(!bucket.try_acquire(1));
clock.advance(Duration::from_secs(1));
assert_eq!(bucket.available(), 10);
assert!(bucket.try_acquire(10));
}
// 250 ms at 100 tokens per second restores 25 tokens.
#[test]
fn test_partial_refill_is_proportional() {
let (clock, bucket) = manual_bucket(100);
assert!(bucket.try_acquire(100));
clock.advance(Duration::from_millis(250)); assert_eq!(bucket.available(), 25);
}
// Refill never exceeds the capacity, however long the bucket idles.
#[test]
fn test_refill_saturates_at_capacity() {
let (clock, bucket) = manual_bucket(10);
assert!(bucket.try_acquire(10));
clock.advance(Duration::from_secs(100)); assert_eq!(bucket.available(), 10); }
// Five years of idle time (far beyond the 32-bit millisecond range) still
// yields a full, usable bucket.
#[test]
fn test_refill_after_long_idle_saturates_without_overflow() {
let (clock, bucket) = manual_bucket(1_000);
assert!(bucket.try_acquire(1_000));
clock.advance(Duration::from_secs(60 * 60 * 24 * 365 * 5));
assert_eq!(bucket.available(), 1_000);
assert!(bucket.try_acquire(1_000));
}
// With 10 tokens per second, a deficit of 5 tokens takes 500 ms to refill.
#[test]
fn test_denied_reports_retry_after() {
let (_clock, bucket) = manual_bucket(10);
assert!(bucket.try_acquire(10)); assert_eq!(
bucket.acquire(5),
Decision::Denied {
retry_after: Duration::from_millis(500)
}
);
}
// `per_duration` refills the full amount once per custom period.
#[test]
fn test_per_duration_uses_custom_period() {
let clock = Arc::new(ManualClock::new());
let bucket =
Bucket::per_duration(5, Duration::from_millis(100)).with_clock(Arc::clone(&clock));
assert!(bucket.try_acquire(5));
clock.advance(Duration::from_millis(100));
assert_eq!(bucket.available(), 5);
}
// A period shorter than the 1 ms timestamp resolution still refills; after
// 1 ms the bucket is back at its capacity of 5.
#[test]
fn test_sub_millisecond_period_still_refills() {
let clock = Arc::new(ManualClock::new());
let bucket =
Bucket::per_duration(5, Duration::from_micros(200)).with_clock(Arc::clone(&clock));
assert!(bucket.try_acquire(5));
clock.advance(Duration::from_millis(1));
assert_eq!(bucket.available(), 5); }
// A zero rate produces a zero-capacity bucket that denies every non-zero
// request.
#[test]
fn test_zero_rate_is_deny_all() {
let bucket = Bucket::per_second(0);
assert_eq!(bucket.capacity(), 0);
assert_eq!(bucket.available(), 0);
assert!(!bucket.try_acquire(1));
assert!(bucket.try_acquire(0));
}
// `reset` restores the full capacity without any time passing.
#[test]
fn test_reset_refills_to_capacity() {
let (_clock, bucket) = manual_bucket(5);
assert!(bucket.try_acquire(5));
assert_eq!(bucket.available(), 0);
bucket.reset();
assert_eq!(bucket.available(), 5);
}
// The trait can be used through `&dyn TokenBucket`.
#[test]
fn test_trait_object_safe_surface() {
let (_clock, bucket) = manual_bucket(4);
let as_trait: &dyn TokenBucket = &bucket;
assert_eq!(as_trait.capacity(), 4);
assert!(as_trait.try_acquire(4));
assert!(!as_trait.try_acquire(1));
}
// 8 threads request 30 tokens each (240 in total) from a 100-token bucket
// whose clock never advances: exactly 100 requests must be granted.
#[test]
fn test_concurrent_acquire_never_over_grants() {
let clock = Arc::new(ManualClock::new());
let bucket = Arc::new(Bucket::per_second(100).with_clock(clock));
let threads = 8;
let demand = 30u32;
let handles: Vec<_> = (0..threads)
.map(|_| {
let bucket = Arc::clone(&bucket);
thread::spawn(move || {
let mut taken = 0u32;
for _ in 0..demand {
if bucket.try_acquire(1) {
taken += 1;
}
}
taken
})
})
.collect();
let total: u32 = handles.into_iter().map(|h| h.join().unwrap()).sum();
assert_eq!(total, 100, "CAS bucket must grant exactly capacity");
assert_eq!(bucket.available(), 0);
}
// 16 threads take 3 tokens at a time from a capacity that is not a multiple
// of 3, with a frozen clock. Grants must never exceed the capacity, must be
// whole multiples of the take size, and must match what remains available.
#[test]
fn test_high_contention_conserves_every_token() {
const CAPACITY: u32 = 6_001;
const THREADS: u32 = 16;
const TAKE: u32 = 3;
let clock = Arc::new(ManualClock::new()); let bucket = Arc::new(Bucket::per_second(CAPACITY).with_clock(clock));
let handles: Vec<_> = (0..THREADS)
.map(|_| {
let bucket = Arc::clone(&bucket);
thread::spawn(move || {
let mut taken = 0u32;
for _ in 0..CAPACITY {
if bucket.try_acquire(TAKE) {
taken += TAKE;
}
}
taken
})
})
.collect();
let granted: u32 = handles.into_iter().map(|h| h.join().unwrap()).sum();
assert!(granted <= CAPACITY, "over-grant: {granted} > {CAPACITY}");
assert_eq!(granted % TAKE, 0, "a partial take was granted");
assert_eq!(
bucket.available(),
CAPACITY - granted,
"tokens were lost or corrupted under contention"
);
}
// `unpack(pack(..))` returns the inputs, with the balance clamped to
// `u32::MAX`.
#[test]
fn test_pack_unpack_round_trip() {
for &mt in &[0_u64, 1, 1_000, 50_000, u64::from(u32::MAX)] {
for &ms in &[0_u32, 1, 1_000, u32::MAX] {
let (got_mt, got_ms) = super::unpack(super::pack(mt, ms));
assert_eq!(got_mt, mt.min(u64::from(u32::MAX)));
assert_eq!(got_ms, ms);
}
}
}
}