use std::time::{Duration, Instant};
/// Static configuration for a token bucket: size, refill rate, and how full
/// a freshly built bucket starts.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TokenBucketConfig {
    /// Maximum number of tokens the bucket can hold; must be non-zero.
    pub capacity: u64,
    /// Whole tokens credited per second of elapsed time; `0` means the bucket
    /// never refills.
    pub refill_per_sec: u64,
    /// Fraction of `capacity` that a bucket built from this config starts
    /// with. Must be finite.
    pub startup_warm_fraction: f64,
}

impl TokenBucketConfig {
    /// Warm fraction applied by [`TokenBucketConfig::new`].
    pub const DEFAULT_STARTUP_WARM_FRACTION: f64 = 0.25;

    /// Builds a config with the given capacity and refill rate, using
    /// [`Self::DEFAULT_STARTUP_WARM_FRACTION`] as the warm fraction.
    pub fn new(capacity: u64, refill_per_sec: u64) -> Self {
        let startup_warm_fraction = Self::DEFAULT_STARTUP_WARM_FRACTION;
        Self {
            capacity,
            refill_per_sec,
            startup_warm_fraction,
        }
    }

    /// Checks the config for values a bucket cannot operate with.
    ///
    /// A zero `refill_per_sec` is accepted, as is any finite warm fraction
    /// (including values outside `[0, 1]`).
    ///
    /// # Errors
    ///
    /// Returns a static message when `capacity` is zero (reported first), or
    /// when `startup_warm_fraction` is NaN or infinite.
    pub fn validate(&self) -> Result<(), &'static str> {
        let has_capacity = self.capacity != 0;
        let fraction_is_finite = self.startup_warm_fraction.is_finite();
        match (has_capacity, fraction_is_finite) {
            (false, _) => Err("token bucket capacity must be > 0"),
            (true, false) => Err("startup_warm_fraction must be a finite number"),
            (true, true) => Ok(()),
        }
    }
}
/// A token-bucket rate limiter.
///
/// Tokens accrue at `cfg.refill_per_sec` per second, never exceeding
/// `cfg.capacity`, and are spent with [`TokenBucket::try_consume`]. Each
/// time-dependent operation has an `*_at` variant taking the current
/// [`Instant`] explicitly, so callers (and tests) can drive the clock
/// deterministically; the plain variants read `Instant::now()`.
///
/// Refill bookkeeping is done in integer nanoseconds, so sub-token progress
/// is carried across polls and no `f64` rounding is involved.
#[derive(Debug, Clone)]
pub struct TokenBucket {
    cfg: TokenBucketConfig,
    /// Available tokens; every constructor and mutator keeps this `<= cfg.capacity`.
    tokens: u64,
    /// Instant up to which elapsed time has already been paid out as tokens.
    /// Time after it is refill credit that has not yet amounted to a whole token.
    last_refill_at: Instant,
}

impl TokenBucket {
    /// Nanoseconds per second, for exact integer refill arithmetic.
    const NANOS_PER_SEC: u128 = 1_000_000_000;

    /// Creates a bucket holding `floor(capacity * startup_warm_fraction)`
    /// tokens, with the fraction clamped into `[0, 1]` (a NaN fraction yields
    /// an empty bucket). The refill clock starts at `Instant::now()`.
    pub fn new(cfg: TokenBucketConfig) -> Self {
        let warm_frac = cfg.startup_warm_fraction.clamp(0.0, 1.0);
        // `as` saturates (and maps NaN to 0); `new_at` re-caps at capacity.
        let initial = ((cfg.capacity as f64) * warm_frac).floor() as u64;
        Self::new_at(cfg, initial, Instant::now())
    }

    /// Creates a bucket holding `initial_tokens` (capped at `cfg.capacity`)
    /// whose refill clock starts at `now`.
    pub fn new_at(cfg: TokenBucketConfig, initial_tokens: u64, now: Instant) -> Self {
        Self {
            tokens: initial_tokens.min(cfg.capacity),
            last_refill_at: now,
            cfg,
        }
    }

    /// Maximum number of tokens the bucket can hold.
    pub fn capacity(&self) -> u64 {
        self.cfg.capacity
    }

    /// Refills against the current time and returns the available token count.
    pub fn tokens(&mut self) -> u64 {
        self.refill_now();
        self.tokens
    }

    /// Refills against `Instant::now()`; see [`TokenBucket::refill_against`].
    pub fn refill_now(&mut self) {
        let now = Instant::now();
        self.refill_against(now);
    }

    /// Converts time elapsed since the last refill into tokens, as of `now`.
    ///
    /// Only whole tokens are credited. The refill clock then advances by
    /// exactly the time those tokens took to accrue (rounded up to the
    /// nanosecond), so leftover sub-token progress carries over to the next
    /// call instead of being discarded. A full bucket does not bank credit:
    /// its clock is moved forward to `now`. A `now` earlier than the last
    /// refill leaves the bucket unchanged.
    pub fn refill_against(&mut self, now: Instant) {
        let capacity = self.cfg.capacity;
        if self.tokens >= capacity {
            // Time spent sitting full must not count towards refilling the
            // first tokens consumed afterwards.
            if now > self.last_refill_at {
                self.last_refill_at = now;
            }
            return;
        }
        if self.cfg.refill_per_sec == 0 {
            return;
        }
        let rate = u128::from(self.cfg.refill_per_sec);
        let elapsed_ns = now.saturating_duration_since(self.last_refill_at).as_nanos();
        // Saturation can only occur for absurd elapsed*rate products, which
        // always exceed the remaining room and take the "fill" branch below.
        let added = elapsed_ns.saturating_mul(rate) / Self::NANOS_PER_SEC;
        if added == 0 {
            // Leave the clock alone so sub-token progress keeps accumulating.
            return;
        }
        let room = u128::from(capacity - self.tokens);
        if added >= room {
            self.tokens = capacity;
            self.last_refill_at = now;
        } else {
            // `added < room <= u64::MAX`, so this cast cannot truncate.
            self.tokens += added as u64;
            // ceil(added / rate) seconds, in ns. This never exceeds
            // `elapsed_ns`, so the clock never moves past `now`.
            let paid_ns = (added * Self::NANOS_PER_SEC + rate - 1) / rate;
            self.last_refill_at += Self::duration_from_nanos(paid_ns);
        }
    }

    /// Refills against the current time, then takes `n` tokens if all of them
    /// are available. Returns whether the tokens were taken.
    pub fn try_consume(&mut self, n: u64) -> bool {
        self.refill_now();
        self.try_consume_no_refill(n)
    }

    /// Like [`TokenBucket::try_consume`], but refills against `now`.
    pub fn try_consume_at(&mut self, n: u64, now: Instant) -> bool {
        self.refill_against(now);
        self.try_consume_no_refill(n)
    }

    /// Takes `n` tokens if all are available; never takes a partial amount.
    fn try_consume_no_refill(&mut self, n: u64) -> bool {
        match self.tokens.checked_sub(n) {
            Some(remaining) => {
                self.tokens = remaining;
                true
            }
            None => false,
        }
    }

    /// Estimated wait, from `Instant::now()`, until `n` tokens are available.
    /// See [`TokenBucket::time_until_n_at`].
    pub fn time_until_n(&mut self, n: u64) -> Option<Duration> {
        self.time_until_n_at(n, Instant::now())
    }

    /// Time from `now` until `n` tokens will be available, assuming no other
    /// consumption in between.
    ///
    /// Refills against `now` first. Returns `Some(Duration::ZERO)` if `n`
    /// tokens are already available. Returns `None` if `n` exceeds capacity,
    /// or if tokens are short and the bucket never refills.
    ///
    /// The wait is rounded up to the nanosecond. Calling
    /// `try_consume_at(n, now + wait)` therefore succeeds. Refill credit
    /// already accrued since the last refill shortens the wait. Astronomically
    /// long waits saturate rather than panic.
    pub fn time_until_n_at(&mut self, n: u64, now: Instant) -> Option<Duration> {
        if n > self.cfg.capacity {
            return None;
        }
        self.refill_against(now);
        if self.tokens >= n {
            return Some(Duration::ZERO);
        }
        if self.cfg.refill_per_sec == 0 {
            return None;
        }
        let rate = u128::from(self.cfg.refill_per_sec);
        let needed = u128::from(n - self.tokens);
        let accrue = Self::duration_from_nanos((needed * Self::NANOS_PER_SEC + rate - 1) / rate);
        let wait = match self.last_refill_at.checked_add(accrue) {
            Some(ready_at) => ready_at.saturating_duration_since(now),
            // Not representable as an `Instant`; `accrue` is a safe overestimate.
            None => accrue,
        };
        Some(wait)
    }

    /// Converts a nanosecond count into a `Duration`, saturating the
    /// whole-second part at `u64::MAX`.
    fn duration_from_nanos(nanos: u128) -> Duration {
        let whole_secs = (nanos / Self::NANOS_PER_SEC).min(u128::from(u64::MAX)) as u64;
        // The remainder is below 1e9, so it always fits in a u32.
        let sub_nanos = (nanos % Self::NANOS_PER_SEC) as u32;
        Duration::new(whole_secs, sub_nanos)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with the given capacity/refill rate and a fully warm start.
    fn mk_cfg(capacity: u64, refill: u64) -> TokenBucketConfig {
        TokenBucketConfig {
            capacity,
            refill_per_sec: refill,
            startup_warm_fraction: 1.0,
        }
    }

    #[test]
    fn new_initializes_to_warm_fraction() {
        // Zero refill rate: `tokens()` reads the real clock, and a non-zero
        // rate would make this test depend on scheduling delays.
        let cfg = TokenBucketConfig {
            capacity: 100,
            refill_per_sec: 0,
            startup_warm_fraction: 0.25,
        };
        let mut b = TokenBucket::new(cfg);
        assert_eq!(b.tokens(), 25);
    }

    #[test]
    fn warm_fraction_zero_starts_empty() {
        let cfg = TokenBucketConfig {
            capacity: 100,
            refill_per_sec: 0,
            startup_warm_fraction: 0.0,
        };
        let mut b = TokenBucket::new(cfg);
        assert_eq!(b.tokens(), 0);
    }

    #[test]
    fn warm_fraction_clamps_above_one() {
        let cfg = TokenBucketConfig {
            capacity: 100,
            refill_per_sec: 0,
            startup_warm_fraction: 5.0,
        };
        let mut b = TokenBucket::new(cfg);
        assert_eq!(b.tokens(), 100);
    }

    #[test]
    fn config_new_uses_default_warm_fraction() {
        let cfg = TokenBucketConfig::new(10, 1);
        assert_eq!(
            cfg.startup_warm_fraction,
            TokenBucketConfig::DEFAULT_STARTUP_WARM_FRACTION
        );
    }

    #[test]
    fn new_at_clamps_initial_tokens_to_capacity() {
        let now = Instant::now();
        let b = TokenBucket::new_at(mk_cfg(10, 0), 50, now);
        assert_eq!(b.tokens, 10);
    }

    #[test]
    fn try_consume_succeeds_when_enough_tokens() {
        let now = Instant::now();
        let mut b = TokenBucket::new_at(mk_cfg(100, 0), 100, now);
        assert!(b.try_consume_at(40, now));
        assert!(b.try_consume_at(60, now));
        assert!(!b.try_consume_at(1, now));
    }

    #[test]
    fn try_consume_zero_always_succeeds() {
        let now = Instant::now();
        let mut b = TokenBucket::new_at(mk_cfg(10, 0), 0, now);
        assert!(b.try_consume_at(0, now));
        assert_eq!(b.tokens, 0);
    }

    #[test]
    fn try_consume_rejects_partial_when_under() {
        let now = Instant::now();
        let mut b = TokenBucket::new_at(mk_cfg(100, 0), 50, now);
        assert!(!b.try_consume_at(60, now));
        assert_eq!(b.tokens, 50);
    }

    #[test]
    fn refill_replenishes_at_configured_rate() {
        let t0 = Instant::now();
        let mut b = TokenBucket::new_at(mk_cfg(100, 10), 0, t0);
        let t1 = t0 + Duration::from_secs(5);
        b.refill_against(t1);
        assert_eq!(b.tokens, 50);
    }

    #[test]
    fn refill_clamps_at_capacity() {
        let t0 = Instant::now();
        let mut b = TokenBucket::new_at(mk_cfg(100, 100), 90, t0);
        let t1 = t0 + Duration::from_secs(10);
        b.refill_against(t1);
        assert_eq!(b.tokens, 100);
    }

    #[test]
    fn refill_zero_rate_never_replenishes() {
        let t0 = Instant::now();
        let mut b = TokenBucket::new_at(mk_cfg(100, 0), 30, t0);
        let t1 = t0 + Duration::from_secs(60);
        b.refill_against(t1);
        assert_eq!(b.tokens, 30);
    }

    #[test]
    fn time_until_n_zero_when_already_available() {
        let now = Instant::now();
        let mut b = TokenBucket::new_at(mk_cfg(100, 10), 100, now);
        assert_eq!(b.time_until_n(50), Some(Duration::ZERO));
    }

    #[test]
    fn time_until_n_estimates_refill_wait() {
        let now = Instant::now();
        let mut b = TokenBucket::new_at(mk_cfg(100, 10), 0, now);
        let wait = b.time_until_n(50).unwrap();
        // `time_until_n` reads the real clock, so any time that passed since
        // `now` may already have been credited. Bound the estimate by the
        // measured slack instead of a fixed tolerance to avoid flakiness.
        let slack = now.elapsed();
        assert!(wait <= Duration::from_secs(5));
        assert!(wait + slack + Duration::from_millis(1) >= Duration::from_secs(5));
    }

    #[test]
    fn time_until_n_none_when_n_exceeds_capacity() {
        let now = Instant::now();
        let mut b = TokenBucket::new_at(mk_cfg(100, 10), 100, now);
        assert_eq!(b.time_until_n(200), None);
    }

    #[test]
    fn time_until_n_none_when_zero_refill_and_under() {
        let now = Instant::now();
        let mut b = TokenBucket::new_at(mk_cfg(100, 0), 10, now);
        assert_eq!(b.time_until_n(50), None);
    }

    #[test]
    fn validate_rejects_zero_capacity() {
        let cfg = TokenBucketConfig {
            capacity: 0,
            refill_per_sec: 10,
            startup_warm_fraction: 0.25,
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_rejects_non_finite_warm_fraction() {
        for bad in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY].iter().copied() {
            let cfg = TokenBucketConfig {
                capacity: 10,
                refill_per_sec: 1,
                startup_warm_fraction: bad,
            };
            assert!(cfg.validate().is_err());
        }
    }

    #[test]
    fn validate_accepts_zero_refill() {
        let cfg = TokenBucketConfig {
            capacity: 10,
            refill_per_sec: 0,
            startup_warm_fraction: 0.25,
        };
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn refill_idempotent_with_no_elapsed_time() {
        let now = Instant::now();
        let mut b = TokenBucket::new_at(mk_cfg(100, 10), 50, now);
        b.refill_against(now);
        b.refill_against(now);
        assert_eq!(b.tokens, 50);
    }

    #[test]
    fn fractional_refill_does_not_lose_progress_across_short_polls() {
        let t0 = Instant::now();
        let mut b = TokenBucket::new_at(mk_cfg(100, 10), 0, t0);
        let mut t = t0;
        for _ in 0..20 {
            t += Duration::from_millis(50);
            b.refill_against(t);
        }
        assert_eq!(b.tokens, 10);
    }
}