use crate::{MetricsError, Result};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;
/// Lock-free token-bucket rate limiter.
///
/// All mutable state lives in one `AtomicU64`, so crediting refill and
/// debiting tokens happen together in a single compare-and-swap:
///
/// * upper 32 bits: current balance in millitokens (1 token = 1000 millitokens);
/// * lower 32 bits: milliseconds since `created_at` at the last state update,
///   kept modulo 2^32 (the value wraps roughly every 49.7 days).
///
/// Because the balance field is 32 bits wide, the usable capacity is at most
/// `u32::MAX` millitokens (4_294_967 whole tokens). Larger requested capacities
/// are clamped in [`TokenBucket::new`].
///
/// The 64-byte alignment rounds the struct to a full 64-byte block, so adjacent
/// buckets (e.g. in an array) do not share a 64-byte cache line.
#[repr(align(64))]
pub struct TokenBucket {
    /// Packed `(millitokens, last_update_ms)`; see `pack` / `unpack`.
    state: AtomicU64,
    /// Maximum balance in millitokens; never exceeds `u32::MAX`.
    capacity_millitokens: u64,
    /// Refill rate in micro-tokens per millisecond (numerically equal to
    /// `refill_per_second * 1000`); 0 disables refill.
    refill_micro_per_ms: u64,
    /// Origin of the millisecond clock stored in `state`.
    created_at: Instant,
}

impl TokenBucket {
    /// Largest balance that fits in the 32-bit token half of the packed state.
    const MAX_CAPACITY_MILLITOKENS: u64 = u32::MAX as u64;

    /// Backwards clock steps up to this many milliseconds are read as "no time
    /// elapsed" rather than as a wrap-around of the 32-bit millisecond clock.
    const CLOCK_SKEW_TOLERANCE_MS: u32 = 1_000;

    /// Creates a bucket that starts full.
    ///
    /// * `capacity`: maximum number of whole tokens. Values above 4_294_967
    ///   are clamped, because the packed state stores the balance in 32 bits of
    ///   millitokens. [`capacity`](Self::capacity) reports the clamped value.
    /// * `refill_per_second`: tokens added per second, stored with a resolution
    ///   of 0.001 token/s. NaN, infinite, zero, or negative rates (and rates that
    ///   round to 0 at that resolution) disable refill entirely.
    pub fn new(capacity: u32, refill_per_second: f64) -> Self {
        // Clamp here so `capacity()`, `available()` and refill all agree. The
        // previous code let `pack` clamp silently, leaving `capacity()` larger
        // than any balance the bucket could ever hold.
        let cap_mt = (u64::from(capacity) * 1_000).min(Self::MAX_CAPACITY_MILLITOKENS);
        let rate = if refill_per_second.is_finite() && refill_per_second > 0.0 {
            // tokens/s * 1000 == millitokens/s == micro-tokens/ms.
            // `as` saturates for huge finite inputs.
            (refill_per_second * 1_000.0).round() as u64
        } else {
            0
        };
        Self {
            state: AtomicU64::new(pack(cap_mt, 0)),
            capacity_millitokens: cap_mt,
            refill_micro_per_ms: rate,
            created_at: Instant::now(),
        }
    }

    /// Maximum number of whole tokens the bucket can hold (after clamping).
    #[must_use]
    #[inline]
    pub fn capacity(&self) -> u32 {
        (self.capacity_millitokens / 1_000).min(u64::from(u32::MAX)) as u32
    }

    /// Configured refill rate in tokens per second, at 0.001 token/s resolution.
    #[must_use]
    #[inline]
    pub fn refill_per_second(&self) -> f64 {
        self.refill_micro_per_ms as f64 / 1_000.0
    }

    /// Number of whole tokens that could be acquired right now.
    ///
    /// Includes refill accrued since the last state update but does not write it
    /// back. Under concurrency the value may already be stale when returned.
    #[must_use]
    pub fn available(&self) -> u32 {
        let (tokens_mt, last_ms) = unpack(self.state.load(Ordering::Relaxed));
        let now_ms = self.now_ms();
        let mt = self.refilled(tokens_mt, last_ms, now_ms);
        (mt / 1_000).min(u64::from(u32::MAX)) as u32
    }

    /// Attempts to take `n` tokens without blocking.
    ///
    /// Accrued refill is credited and the debit applied in one compare-and-swap,
    /// so concurrent callers can never collectively take more tokens than the
    /// bucket holds. `n == 0` always succeeds and leaves the state untouched.
    ///
    /// # Errors
    ///
    /// Returns [`MetricsError::WouldBlock`] when fewer than `n` whole tokens are
    /// available. This is always the case when `n` exceeds the capacity.
    #[inline]
    pub fn try_acquire(&self, n: u32) -> Result<()> {
        if n == 0 {
            return Ok(());
        }
        let needed = u64::from(n) * 1_000;
        loop {
            let packed = self.state.load(Ordering::Relaxed);
            let (tokens_mt, last_ms) = unpack(packed);
            let now_ms = self.now_ms();
            let mt = self.refilled(tokens_mt, last_ms, now_ms);
            if mt < needed {
                return Err(MetricsError::WouldBlock);
            }
            let new_packed = pack(mt - needed, now_ms);
            // Relaxed is enough: `state` is the only mutable shared field, and
            // the CAS alone serializes all updates to it.
            if self
                .state
                .compare_exchange_weak(packed, new_packed, Ordering::Relaxed, Ordering::Relaxed)
                .is_ok()
            {
                return Ok(());
            }
            // Lost a race or failed spuriously: reload and recompute.
        }
    }

    /// Boolean convenience wrapper around [`try_acquire`](Self::try_acquire).
    #[must_use]
    #[inline]
    pub fn acquire(&self, n: u32) -> bool {
        self.try_acquire(n).is_ok()
    }

    /// Refills the bucket to capacity and restarts its refill clock at "now".
    pub fn reset(&self) {
        let now_ms = self.now_ms();
        self.state
            .store(pack(self.capacity_millitokens, now_ms), Ordering::SeqCst);
    }

    /// Milliseconds since `created_at`, truncated to the low 32 bits.
    ///
    /// The value wraps about every 49.7 days, and `elapsed_ms` undoes the wrap
    /// with modular subtraction. Saturating instead (the old behaviour) froze
    /// the clock after 49.7 days of bucket lifetime and stopped all refill.
    #[inline]
    fn now_ms(&self) -> u32 {
        // Intentional truncating cast: keep only the low 32 bits.
        self.created_at.elapsed().as_millis() as u32
    }

    /// Milliseconds from `last_ms` to `now_ms` on the wrapping 32-bit clock.
    ///
    /// A difference within `CLOCK_SKEW_TOLERANCE_MS` of a full wrap means
    /// `now_ms` is slightly *behind* `last_ms`. That can happen if clock reads on
    /// different threads are not perfectly ordered relative to the Relaxed state
    /// accesses. Such a difference is reported as 0, as the previous
    /// `saturating_sub` did.
    ///
    /// Limitation: an idle gap of more than one wrap (~49.7 days) between state
    /// updates cannot be told apart from a shorter one.
    #[inline]
    fn elapsed_ms(last_ms: u32, now_ms: u32) -> u32 {
        let delta = now_ms.wrapping_sub(last_ms);
        if delta > u32::MAX - Self::CLOCK_SKEW_TOLERANCE_MS {
            0
        } else {
            delta
        }
    }

    /// Balance in millitokens after crediting refill for the time between
    /// `last_ms` and `now_ms`, capped at the capacity.
    #[inline]
    fn refilled(&self, tokens_mt: u64, last_ms: u32, now_ms: u32) -> u64 {
        if self.refill_micro_per_ms == 0 {
            return tokens_mt;
        }
        let elapsed_ms = u64::from(Self::elapsed_ms(last_ms, now_ms));
        let added_micro = elapsed_ms.saturating_mul(self.refill_micro_per_ms);
        // micro-tokens -> millitokens; any sub-millitoken remainder is dropped.
        let added_mt = added_micro / 1_000;
        tokens_mt
            .saturating_add(added_mt)
            .min(self.capacity_millitokens)
    }
}
/// Combines a millitoken balance and a 32-bit millisecond timestamp into one
/// `u64`.
///
/// The balance goes in the upper 32 bits and saturates at `u32::MAX`. The
/// timestamp goes in the lower 32 bits unchanged.
#[inline]
fn pack(tokens_mt: u64, last_ms: u32) -> u64 {
    let high_half = u32::try_from(tokens_mt).unwrap_or(u32::MAX);
    let low_half = u64::from(last_ms);
    (u64::from(high_half) << 32) | low_half
}
/// Splits a word produced by `pack` back into `(millitokens, last_ms)`.
///
/// The upper 32 bits are the balance; the lower 32 bits are the timestamp.
#[inline]
fn unpack(packed: u64) -> (u64, u32) {
    // A `u64 -> u32` cast keeps exactly the low 32 bits.
    (packed >> 32, packed as u32)
}
impl std::fmt::Debug for TokenBucket {
    /// Shows the capacity, a refill-adjusted snapshot of the current balance,
    /// and the refill rate. The packed atomic state is not printed directly.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let capacity = self.capacity();
        let available = self.available();
        let refill_per_second = self.refill_per_second();
        let mut out = f.debug_struct("TokenBucket");
        out.field("capacity", &capacity);
        out.field("available", &available);
        out.field("refill_per_second", &refill_per_second);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn pack_unpack_round_trip() {
        // Includes one value above u32::MAX to exercise the balance clamp.
        for &mt in &[0_u64, 1, 1000, 50_000, u32::MAX as u64, u32::MAX as u64 + 1] {
            for &ms in &[0_u32, 1, 1000, u32::MAX] {
                let (tmt, tms) = unpack(pack(mt, ms));
                assert_eq!(tmt, mt.min(u32::MAX as u64));
                assert_eq!(tms, ms);
            }
        }
    }

    #[test]
    fn new_bucket_starts_full() {
        let b = TokenBucket::new(10, 5.0);
        assert_eq!(b.capacity(), 10);
        assert_eq!(b.available(), 10);
    }

    #[test]
    fn try_acquire_zero_is_noop() {
        let b = TokenBucket::new(5, 1.0);
        b.try_acquire(0).unwrap();
        assert_eq!(b.available(), 5);
    }

    #[test]
    fn drains_then_refuses() {
        let b = TokenBucket::new(3, 0.0);
        assert!(b.acquire(1));
        assert!(b.acquire(1));
        assert!(b.acquire(1));
        assert!(!b.acquire(1));
        assert!(matches!(b.try_acquire(1), Err(MetricsError::WouldBlock)));
    }

    #[test]
    fn request_larger_than_capacity_would_block() {
        let b = TokenBucket::new(5, 0.0);
        assert!(matches!(b.try_acquire(6), Err(MetricsError::WouldBlock)));
        // A refused request must not consume anything.
        assert_eq!(b.available(), 5);
    }

    #[test]
    fn refills_over_time() {
        // 20 tokens/s = one token per 50 ms. That is slow enough that the
        // "empty" check right after draining does not depend on the scheduler
        // running us within a few milliseconds.
        let b = TokenBucket::new(10, 20.0);
        assert!(b.acquire(10));
        assert_eq!(b.available(), 0);
        thread::sleep(Duration::from_millis(100));
        // Read once so the assertion and its message report the same value.
        let avail = b.available();
        assert!(avail >= 1, "expected ≥ 1 token after 100 ms, got {avail}");
        assert!(b.acquire(1));
    }

    #[test]
    fn refill_caps_at_capacity() {
        let b = TokenBucket::new(5, 1000.0);
        assert!(b.acquire(3));
        thread::sleep(Duration::from_millis(50));
        assert_eq!(b.available(), 5);
    }

    #[test]
    fn reset_restores_capacity() {
        let b = TokenBucket::new(4, 1.0);
        assert!(b.acquire(4));
        assert_eq!(b.available(), 0);
        b.reset();
        assert_eq!(b.available(), 4);
    }

    #[test]
    fn concurrent_acquire_never_overshoots_capacity() {
        let b = Arc::new(TokenBucket::new(100, 0.0));
        let threads = 8;
        let per_thread_demand = 30u32;
        let handles: Vec<_> = (0..threads)
            .map(|_| {
                let b = Arc::clone(&b);
                thread::spawn(move || {
                    let mut taken = 0u32;
                    for _ in 0..per_thread_demand {
                        if b.acquire(1) {
                            taken += 1;
                        }
                    }
                    taken
                })
            })
            .collect();
        let total: u32 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        assert_eq!(total, 100, "atomic-CAS bucket must never exceed capacity");
        assert_eq!(b.available(), 0);
    }

    #[test]
    fn invalid_refill_rate_treated_as_zero() {
        let a = TokenBucket::new(5, f64::NAN);
        assert_eq!(a.refill_per_second(), 0.0);
        let b = TokenBucket::new(5, -1.0);
        assert_eq!(b.refill_per_second(), 0.0);
        let c = TokenBucket::new(5, f64::INFINITY);
        assert_eq!(c.refill_per_second(), 0.0);
    }

    #[test]
    fn debug_impl() {
        let b = TokenBucket::new(7, 2.5);
        let s = format!("{b:?}");
        assert!(s.contains("TokenBucket"));
        assert!(s.contains("capacity: 7"));
    }
}