use std::sync::atomic::{AtomicU64, Ordering};
use crate::clock::Timestamp;
use crate::quota::Window;
use crate::quota::strategy::QuotaTracker;
/// Fixed-point scale for token balances: one whole token is stored as
/// `PRECISION` units, so refills over short intervals can credit fractions
/// of a token.
const PRECISION: u64 = 1_000_000;

/// Token-bucket quota tracker.
///
/// All mutable state lives in atomics, so every operation takes `&self`.
pub(crate) struct TokenBucket {
    // Configuration, fixed at construction.
    /// Maximum balance, in whole tokens.
    capacity: u64,
    /// Window length, as returned by `Window::as_nanos`.
    window_nanos: u64,

    // Refill state.
    /// Current balance in fixed-point units (whole tokens × `PRECISION`).
    tokens: AtomicU64,
    /// Raw timestamp (`Timestamp.0`) at which tokens were last credited.
    last_refill: AtomicU64,

    // Burn-rate accounting.
    /// Raw timestamp (`Timestamp.0`) at which the current accounting window began.
    window_start: AtomicU64,
    /// Whole tokens recorded since `window_start`.
    consumed_in_window: AtomicU64,
}
impl TokenBucket {
    /// Creates a bucket that starts full and refills `capacity` tokens per
    /// `window`.
    ///
    /// `now` is the reference point for both the refill clock and the
    /// burn-rate window.
    ///
    /// `capacity * PRECISION` must fit in a `u64`, and the window must have a
    /// non-zero length. Otherwise the fixed-point arithmetic overflows or
    /// divides by zero.
    pub(crate) fn new(capacity: u64, window: Window, now: Timestamp) -> Self {
        Self {
            capacity,
            window_nanos: window.as_nanos(),
            tokens: AtomicU64::new(capacity * PRECISION),
            last_refill: AtomicU64::new(now.0),
            consumed_in_window: AtomicU64::new(0),
            window_start: AtomicU64::new(now.0),
        }
    }

    /// Credits the tokens accrued since the last refill and returns the
    /// resulting balance in fixed-point units (whole tokens × `PRECISION`).
    ///
    /// Tokens accrue at `capacity` per `window_nanos`. The balance is capped
    /// at `capacity * PRECISION`. This also rolls the burn-rate window over
    /// once it has lasted `window_nanos`, whether or not any tokens are
    /// credited.
    ///
    /// Concurrency notes:
    /// - The elapsed interval is claimed with a compare-and-swap on
    ///   `last_refill`, so two callers racing over the same interval cannot
    ///   both credit it.
    /// - The add-and-clamp on `tokens` is a single atomic update, so it cannot
    ///   overwrite a concurrent deduction.
    fn refill(&self, now: Timestamp) -> u64 {
        self.roll_window(now.0);

        let last = self.last_refill.load(Ordering::Acquire);
        let elapsed = now.0.saturating_sub(last);
        if elapsed == 0 {
            return self.tokens.load(Ordering::Acquire);
        }

        let max_tokens = self.capacity * PRECISION;
        // elapsed * capacity * PRECISION / window, written as elapsed * max_tokens.
        // Both factors are u64, so the product always fits in u128.
        let accrued =
            u128::from(elapsed) * u128::from(max_tokens) / u128::from(self.window_nanos);
        // Clamp before narrowing. After a long idle period the accrual can
        // exceed u64::MAX, and a truncating cast would wrap it to an arbitrary
        // (even zero) amount. The clamped value is <= max_tokens, so the cast
        // is lossless.
        let tokens_to_add = accrued.min(u128::from(max_tokens)) as u64;
        if tokens_to_add == 0 {
            // Less than one fixed-point unit has accrued. Leave `last_refill`
            // untouched so the elapsed time keeps accumulating.
            return self.tokens.load(Ordering::Acquire);
        }

        if self
            .last_refill
            .compare_exchange(last, now.0, Ordering::AcqRel, Ordering::Acquire)
            .is_err()
        {
            // Another caller refilled after we read `last`. Crediting our
            // interval as well would count it twice.
            return self.tokens.load(Ordering::Acquire);
        }

        // The closure always returns `Some`, so the result is always `Ok`.
        let (Ok(previous) | Err(previous)) =
            self.tokens
                .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                    Some(current.saturating_add(tokens_to_add).min(max_tokens))
                });
        previous.saturating_add(tokens_to_add).min(max_tokens)
    }

    /// Starts a new burn-rate accounting window at `now_nanos` once the
    /// current one has lasted at least `window_nanos`, clearing
    /// `consumed_in_window`.
    fn roll_window(&self, now_nanos: u64) {
        let start = self.window_start.load(Ordering::Acquire);
        if now_nanos.saturating_sub(start) < self.window_nanos {
            return;
        }
        // Only the caller that wins the CAS clears the counter, so concurrent
        // callers cannot reset the same window twice.
        if self
            .window_start
            .compare_exchange(start, now_nanos, Ordering::AcqRel, Ordering::Acquire)
            .is_ok()
        {
            self.consumed_in_window.store(0, Ordering::Release);
        }
    }
}
impl QuotaTracker for TokenBucket {
    /// Reports whether `amount` whole tokens are available at `now`.
    ///
    /// Credits any pending refill first but deducts nothing. An `amount`
    /// whose fixed-point cost does not fit in a `u64` can never be met and
    /// yields `false` rather than overflowing.
    fn check(&self, amount: u64, now: Timestamp) -> bool {
        let available = self.refill(now);
        matches!(amount.checked_mul(PRECISION), Some(cost) if available >= cost)
    }

    /// Consumes `amount` tokens at `now` and adds them to the burn-rate
    /// window.
    ///
    /// Availability is not verified here (use `check`). The balance is
    /// clamped at zero instead of going negative.
    fn record(&self, amount: u64, now: Timestamp) {
        self.refill(now);
        // An absurdly large amount simply drains the bucket.
        let cost = amount.saturating_mul(PRECISION);
        // Use one atomic read-modify-write. A separate load followed by
        // `fetch_sub` lets two concurrent callers subtract from the same
        // balance and wrap the counter to nearly `u64::MAX`, which makes the
        // bucket effectively unlimited. The closure always returns `Some`, so
        // the ignored result is always `Ok`.
        let _ = self
            .tokens
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                Some(current.saturating_sub(cost))
            });
        self.consumed_in_window.fetch_add(amount, Ordering::Relaxed);
    }

    /// Whole tokens available at `now`. Any fractional remainder is discarded.
    fn remaining(&self, now: Timestamp) -> u64 {
        self.refill(now) / PRECISION
    }

    /// Configured maximum balance, in whole tokens.
    fn capacity(&self) -> u64 {
        self.capacity
    }

    /// Average consumption, in tokens per second, since the current
    /// accounting window started.
    ///
    /// Reports `0.0` when less than a millisecond has elapsed, to avoid
    /// dividing by a near-zero duration.
    fn burn_rate(&self, now: Timestamp) -> f64 {
        let ws = self.window_start.load(Ordering::Acquire);
        let elapsed_secs = now.0.saturating_sub(ws) as f64 / 1_000_000_000.0;
        if elapsed_secs < 0.001 {
            return 0.0;
        }
        let consumed = self.consumed_in_window.load(Ordering::Acquire);
        consumed as f64 / elapsed_secs
    }

    /// Refills the bucket to capacity and restarts both the refill clock and
    /// the burn-rate window at `now`.
    fn reset(&self, now: Timestamp) {
        self.tokens
            .store(self.capacity * PRECISION, Ordering::Release);
        self.last_refill.store(now.0, Ordering::Release);
        self.consumed_in_window.store(0, Ordering::Release);
        self.window_start.store(now.0, Ordering::Release);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts milliseconds into a raw (nanosecond) `Timestamp`.
    fn at_ms(millis: u64) -> Timestamp {
        Timestamp(millis * 1_000_000)
    }

    /// A fresh bucket with a one-minute window, created at t = 0.
    fn minute_bucket(capacity: u64) -> TokenBucket {
        TokenBucket::new(capacity, Window::Minute, at_ms(0))
    }

    #[test]
    fn new_bucket_is_full() {
        let bucket = minute_bucket(100);
        assert_eq!(bucket.remaining(at_ms(0)), 100);
        assert!(bucket.check(100, at_ms(0)));
        assert!(!bucket.check(101, at_ms(0)));
    }

    #[test]
    fn consume_reduces_remaining() {
        let bucket = minute_bucket(100);
        bucket.record(30, at_ms(0));
        assert_eq!(bucket.remaining(at_ms(0)), 70);
    }

    #[test]
    fn tokens_refill_over_time() {
        let bucket = minute_bucket(60);
        bucket.record(60, at_ms(0));
        // Empty at first; half a window restores half; a full window restores all.
        for &(millis, expected) in &[(0, 0), (30_000, 30), (60_000, 60)] {
            assert_eq!(bucket.remaining(at_ms(millis)), expected, "at {} ms", millis);
        }
    }

    #[test]
    fn never_exceeds_capacity() {
        let bucket = minute_bucket(100);
        // Two full windows of idle time still cap at capacity.
        assert_eq!(bucket.remaining(at_ms(120_000)), 100);
    }

    #[test]
    fn burn_rate_tracks_consumption() {
        let bucket = minute_bucket(100);
        bucket.record(10, at_ms(1_000));
        bucket.record(10, at_ms(2_000));
        // 20 tokens over the 5 s since the window opened.
        let rate = bucket.burn_rate(at_ms(5_000));
        assert!((rate - 4.0).abs() < 0.1);
    }

    #[test]
    fn usage_ratio() {
        let bucket = minute_bucket(100);
        assert!(bucket.usage_ratio(at_ms(0)).abs() < 0.01);
        bucket.record(80, at_ms(0));
        assert!((bucket.usage_ratio(at_ms(0)) - 0.8).abs() < 0.01);
    }

    #[test]
    fn predicted_exhaustion() {
        let bucket = minute_bucket(100);
        bucket.record(50, at_ms(5_000));
        let secs = bucket.predicted_exhaustion_secs(at_ms(5_000));
        assert!((secs - 5.0).abs() < 1.0);
    }

    #[test]
    fn reset_restores_full_capacity() {
        let bucket = minute_bucket(100);
        bucket.record(100, at_ms(0));
        assert_eq!(bucket.remaining(at_ms(0)), 0);
        bucket.reset(at_ms(1_000));
        assert_eq!(bucket.remaining(at_ms(1_000)), 100);
    }
}