use std::fmt::Debug;
use web_time_compat::{Duration, Instant};
/// A token-bucket rate limiter, generic over the clock type `I` so tests can
/// substitute a deterministic timestamp for the real monotonic clock.
#[derive(Debug)]
pub(crate) struct TokenBucket<I> {
    /// Refill rate in tokens per second (see `tokens_to_duration`, which
    /// converts via `tokens * 1_000_000 / rate` microseconds).
    rate: u64,
    /// Capacity: `bucket` is always clamped to at most this value.
    bucket_max: u64,
    /// Tokens currently available to claim.
    bucket: u64,
    /// Instant up to which accrual has been credited. `refill` advances this
    /// only by whole-token increments, so fractional progress carries over.
    added_tokens_at: I,
}
impl<I: TokenBucketInstant> TokenBucket<I> {
    /// Creates a bucket that starts completely full (`bucket == bucket_max`),
    /// with `now` recorded as the point accrual starts from.
    pub(crate) fn new(config: &TokenBucketConfig, now: I) -> Self {
        Self {
            rate: config.rate,
            bucket_max: config.bucket_max,
            bucket: config.bucket_max,
            added_tokens_at: now,
        }
    }

    /// Returns `true` when no tokens are currently available.
    #[cfg_attr(not(test), expect(dead_code))]
    pub(crate) fn is_empty(&self) -> bool {
        self.bucket == 0
    }

    /// Maximum number of tokens the bucket can hold.
    pub(crate) fn max(&self) -> u64 {
        self.bucket_max
    }

    /// Removes `count` tokens immediately (claim + commit in one step).
    #[cfg_attr(not(test), expect(dead_code))]
    pub(crate) fn drain(&mut self, count: u64) -> Result<BecameEmpty, InsufficientTokensError> {
        Ok(self.claim(count)?.commit())
    }

    /// Reserves `count` tokens without deducting them yet. The returned
    /// [`ClaimedTokens`] guard performs the deduction when committed or
    /// dropped; `discard` releases the reservation instead.
    pub(crate) fn claim(
        &mut self,
        count: u64,
    ) -> Result<ClaimedTokens<I>, InsufficientTokensError> {
        if count > self.bucket {
            return Err(InsufficientTokensError {
                available: self.bucket,
            });
        }
        Ok(ClaimedTokens::new(self, count))
    }

    /// Applies a new configuration. Tokens accrued under the old rate are
    /// credited first via `refill`, then the level is clamped to the new
    /// maximum.
    pub(crate) fn adjust(&mut self, now: I, config: &TokenBucketConfig) {
        self.refill(now);
        // refill() advances added_tokens_at only by whole-token increments
        // (and leaves it untouched when rate == 0); jump to `now` so the new
        // rate starts from a clean slate with no partial-token carry-over.
        self.added_tokens_at = std::cmp::max(self.added_tokens_at, now);
        self.rate = config.rate;
        self.bucket_max = config.bucket_max;
        self.bucket = std::cmp::min(self.bucket, self.bucket_max);
    }

    /// Earliest instant at which `tokens` tokens will be available, assuming
    /// no intervening drains. Returns the last-credited instant if they are
    /// already available.
    pub(crate) fn tokens_available_at(&self, tokens: u64) -> Result<I, NeverEnoughTokensError> {
        let tokens_needed = tokens.saturating_sub(self.bucket);
        if tokens_needed == 0 {
            return Ok(self.added_tokens_at);
        }
        // No refill will ever happen at a zero rate. (Checked before the
        // capacity test, so zero-rate requests report ZeroRate even when
        // they also exceed bucket_max.)
        if self.rate == 0 {
            return Err(NeverEnoughTokensError::ZeroRate);
        }
        // The bucket can never hold more than bucket_max tokens at once.
        if tokens > self.bucket_max {
            return Err(NeverEnoughTokensError::ExceedsMaxTokens);
        }
        let time_needed = Self::tokens_to_duration(tokens_needed, self.rate)
            .ok_or(NeverEnoughTokensError::ZeroRate)?;
        // Defensive floor: never report "available now" while tokens are
        // still owed (tokens_needed > 0).
        let time_needed = std::cmp::max(time_needed, Duration::from_micros(1));
        self.added_tokens_at
            .checked_add(time_needed)
            .ok_or(NeverEnoughTokensError::InstantNotRepresentable)
    }

    /// Credits tokens accrued since the last refill, capped at `bucket_max`.
    /// Reports whether the bucket transitioned from empty to non-empty.
    pub(crate) fn refill(&mut self, now: I) -> BecameNonEmpty {
        let elapsed = now.saturating_duration_since(self.added_tokens_at);
        if elapsed > I::IGNORE_THRESHOLD {
            // An enormous jump is more likely a clock anomaly than genuine
            // elapsed time; skip the credit rather than flood the bucket.
            tracing::debug!(
                "Time jump of {elapsed:?} is larger than {:?}; not refilling token bucket",
                I::IGNORE_THRESHOLD,
            );
            self.added_tokens_at = now;
            return BecameNonEmpty::No;
        }
        let old_bucket = self.bucket;
        // Whole tokens earned over `elapsed` (rounds down).
        let bucket_inc = Self::duration_to_tokens(elapsed, self.rate);
        self.bucket = std::cmp::min(self.bucket_max, self.bucket.saturating_add(bucket_inc));
        // Advance the timestamp by exactly the duration those whole tokens
        // represent; the floor/ceil pairing guarantees this is <= elapsed,
        // so the fractional remainder is preserved for the next refill.
        let added_tokens_at_inc =
            Self::tokens_to_duration(bucket_inc, self.rate).unwrap_or(Duration::ZERO);
        self.added_tokens_at = self
            .added_tokens_at
            .checked_add(added_tokens_at_inc)
            .expect("overflowed time");
        debug_assert!(self.added_tokens_at <= now);
        if old_bucket == 0 && self.bucket != 0 {
            BecameNonEmpty::Yes
        } else {
            BecameNonEmpty::No
        }
    }

    /// Time needed to accrue `tokens` at `rate` tokens/second, rounded up to
    /// whole microseconds (saturating for huge token counts). `None` when
    /// `rate` is zero.
    fn tokens_to_duration(tokens: u64, rate: u64) -> Option<Duration> {
        if rate == 0 {
            return None;
        }
        let micros = tokens.saturating_mul(1000 * 1000).div_ceil(rate);
        Some(Duration::from_micros(micros))
    }

    /// Whole tokens accrued over `time` at `rate` tokens/second (rounds
    /// down; saturates on overflow).
    fn duration_to_tokens(time: Duration, rate: u64) -> u64 {
        let micros = u64::try_from(time.as_micros()).unwrap_or(u64::MAX);
        rate.saturating_mul(micros) / (1000 * 1000)
    }
}
/// Configuration for a [`TokenBucket`]: refill `rate` in tokens per second
/// and the bucket capacity `bucket_max`.
#[derive(Clone, Debug)]
pub(crate) struct TokenBucketConfig {
    /// Tokens credited per second of elapsed time.
    pub(crate) rate: u64,
    /// Maximum number of tokens the bucket can accumulate.
    pub(crate) bucket_max: u64,
}
/// A reservation of `count` tokens against `bucket`. Dropping the guard (or
/// calling `commit`) deducts the tokens; `discard` releases the reservation
/// without deducting anything.
#[derive(Debug)]
pub(crate) struct ClaimedTokens<'a, I> {
    /// The bucket the tokens were claimed from.
    bucket: &'a mut TokenBucket<I>,
    /// Number of tokens still reserved; zeroed once committed or discarded.
    count: u64,
}
impl<'a, I> ClaimedTokens<'a, I> {
    /// Records a reservation of `count` tokens against `bucket`.
    fn new(bucket: &'a mut TokenBucket<I>, count: u64) -> Self {
        Self { bucket, count }
    }

    /// Finalizes the claim, deducting the reserved tokens from the bucket
    /// and reporting whether that emptied it.
    pub(crate) fn commit(mut self) -> BecameEmpty {
        self.commit_impl()
    }

    /// Shrinks the reservation down to `count` tokens. Fails (leaving the
    /// claim untouched) if `count` exceeds the amount currently reserved.
    pub(crate) fn reduce(&mut self, count: u64) -> Result<(), InsufficientTokensError> {
        match self.count.checked_sub(count) {
            Some(_) => {
                self.count = count;
                Ok(())
            }
            None => Err(InsufficientTokensError {
                available: self.count,
            }),
        }
    }

    /// Abandons the claim: zeroing the count first means the drop handler
    /// deducts nothing, so the tokens stay in the bucket.
    pub(crate) fn discard(mut self) {
        self.count = 0;
    }

    /// Shared commit path, also reached via `Drop`. Claiming more tokens
    /// than the bucket holds is a bookkeeping bug, hence the panic.
    fn commit_impl(&mut self) -> BecameEmpty {
        let held = self.bucket.bucket;
        let remaining = held.checked_sub(self.count).unwrap_or_else(|| {
            panic!(
                "claim commit failed: {}, {}",
                self.count, held,
            )
        });
        self.bucket.bucket = remaining;
        self.count = 0;
        match remaining {
            0 => BecameEmpty::Yes,
            _ => BecameEmpty::No,
        }
    }
}
impl<'a, I> std::ops::Drop for ClaimedTokens<'a, I> {
    /// Dropping a claim commits it by default. `commit` and `discard` zero
    /// the reserved count first, which turns this into a harmless no-op
    /// deduction of zero tokens.
    fn drop(&mut self) {
        let _ = self.commit_impl();
    }
}
/// Returned when a claim or drain asks for more tokens than are available.
#[derive(Copy, Clone, Debug, PartialEq, Eq, thiserror::Error)]
#[error("insufficient tokens for operation")]
pub(crate) struct InsufficientTokensError {
    /// Tokens that were available at the time of the failed request.
    available: u64,
}
impl InsufficientTokensError {
    /// Number of tokens that were actually available when the request failed.
    pub(crate) fn available_tokens(&self) -> u64 {
        let Self { available } = *self;
        available
    }
}
/// Returned by `TokenBucket::tokens_available_at` when the request can never
/// be satisfied, no matter how long the caller waits.
#[derive(Copy, Clone, Debug, PartialEq, Eq, thiserror::Error)]
#[error("there will never be enough tokens for this operation")]
pub(crate) enum NeverEnoughTokensError {
    /// More tokens were requested than the bucket can ever hold at once.
    ExceedsMaxTokens,
    /// The refill rate is zero, so no further tokens will ever accrue.
    ZeroRate,
    /// The availability instant overflows the clock type `I`.
    InstantNotRepresentable,
}
/// Whether a `refill` transitioned the bucket from empty to non-empty.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum BecameNonEmpty {
    Yes,
    No,
}
/// Whether committing a claim left the bucket with zero tokens.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum BecameEmpty {
    Yes,
    No,
}
/// Clock abstraction used by [`TokenBucket`], letting tests substitute a
/// deterministic timestamp for the real monotonic clock.
pub(crate) trait TokenBucketInstant:
    Copy + Clone + Debug + PartialEq + Eq + PartialOrd + Ord
{
    /// Elapsed spans larger than this are treated as clock anomalies and
    /// ignored by `TokenBucket::refill` instead of being credited.
    const IGNORE_THRESHOLD: Duration;
    /// `self + duration`, or `None` on overflow of the instant type.
    fn checked_add(&self, duration: Duration) -> Option<Self>;
    /// `self - earlier`, or `None` if `earlier` is after `self`.
    fn checked_duration_since(&self, earlier: Self) -> Option<Duration>;
    /// `self - earlier`, clamped to `Duration::ZERO` if `earlier` is after
    /// `self`. Default implementation routes through the checked variant.
    fn saturating_duration_since(&self, earlier: Self) -> Duration {
        self.checked_duration_since(earlier).unwrap_or_default()
    }
}
impl TokenBucketInstant for Instant {
const IGNORE_THRESHOLD: Duration = Duration::from_secs((u32::MAX / 4) as u64);
#[inline]
fn checked_add(&self, duration: Duration) -> Option<Self> {
self.checked_add(duration)
}
#[inline]
fn checked_duration_since(&self, earlier: Self) -> Option<Duration> {
self.checked_duration_since(earlier)
}
#[inline]
fn saturating_duration_since(&self, earlier: Self) -> Duration {
self.saturating_duration_since(earlier)
}
}
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use rand::Rng;

    /// Millisecond-resolution fake clock for deterministic tests.
    #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
    struct MillisTimestamp(u64);

    impl TokenBucketInstant for MillisTimestamp {
        const IGNORE_THRESHOLD: Duration = Duration::from_millis(1_000_000_000);
        fn checked_add(&self, duration: Duration) -> Option<Self> {
            // Sub-millisecond parts of `duration` are truncated by as_millis.
            let duration = u64::try_from(duration.as_millis()).ok()?;
            self.0.checked_add(duration).map(Self)
        }
        fn checked_duration_since(&self, earlier: Self) -> Option<Duration> {
            Some(Duration::from_millis(self.0.checked_sub(earlier.0)?))
        }
    }

    /// Adjusting at the current instant applies rate/capacity changes and
    /// clamps the level, without accruing any tokens.
    #[test]
    fn adjust_now() {
        let time = MillisTimestamp(100);
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, time);
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 100);
        assert_eq!(tb.rate, 10);
        // Rate change alone leaves the level untouched.
        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 100,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 100);
        // Shrinking the capacity clamps the level down.
        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 40,
            },
        );
        assert_eq!(tb.bucket, 40);
        assert_eq!(tb.bucket_max, 40);
        // Growing the capacity does not refill the level.
        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 100,
            },
        );
        assert_eq!(tb.bucket, 40);
        assert_eq!(tb.bucket_max, 100);
        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 200,
                bucket_max: 100,
            },
        );
        assert_eq!(tb.bucket, 40);
        assert_eq!(tb.bucket_max, 100);
        assert_eq!(tb.rate, 200);
    }

    /// Adjusting at a later instant credits tokens accrued under the OLD
    /// rate before the new configuration takes effect.
    #[test]
    fn adjust_future() {
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(100));
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 100);
        assert_eq!(tb.rate, 10);
        // Bucket was already full, so the 200ms of old-rate accrual is moot.
        tb.adjust(
            MillisTimestamp(300),
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 200,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 200);
        // 200ms at 20 tokens/s under the previous adjust's rate => +4.
        tb.adjust(
            MillisTimestamp(500),
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 200,
            },
        );
        assert_eq!(tb.bucket, 104);
        assert_eq!(tb.bucket_max, 200);
        // Accrued 104 -> clamped to the new max of 100; rate drops to 0.
        tb.adjust(
            MillisTimestamp(700),
            &TokenBucketConfig {
                rate: 0,
                bucket_max: 100,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 100);
        // Zero rate in the interim => nothing accrued before this adjust.
        tb.adjust(
            MillisTimestamp(900),
            &TokenBucketConfig {
                rate: 100,
                bucket_max: 200,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 200);
    }

    /// Zero rate and/or zero capacity stop all refilling.
    #[test]
    fn adjust_zero() {
        let time = MillisTimestamp(100);
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, time);
        // rate = 0: level frozen even after a huge elapsed time.
        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 0,
                bucket_max: 200,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 200);
        assert_eq!(tb.rate, 0);
        tb.refill(MillisTimestamp(10_000_000));
        assert_eq!(tb.bucket, 100);
        // bucket_max = 0: level clamps to 0 and stays there.
        let mut tb = TokenBucket::new(&config, time);
        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 10,
                bucket_max: 0,
            },
        );
        assert_eq!(tb.bucket, 0);
        assert_eq!(tb.bucket_max, 0);
        assert_eq!(tb.rate, 10);
        tb.refill(MillisTimestamp(10_000_000));
        assert_eq!(tb.bucket, 0);
        // Both zero.
        let mut tb = TokenBucket::new(&config, time);
        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 0,
                bucket_max: 0,
            },
        );
        assert_eq!(tb.bucket, 0);
        assert_eq!(tb.bucket_max, 0);
        assert_eq!(tb.rate, 0);
        tb.refill(MillisTimestamp(10_000_000));
        assert_eq!(tb.bucket, 0);
    }

    /// is_empty flips exactly when the last token is drained and exactly
    /// when a refill crosses the one-whole-token boundary (100ms at 10/s).
    #[test]
    fn is_empty() {
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(100));
        assert!(!tb.is_empty());
        tb.drain(99).unwrap();
        assert!(!tb.is_empty());
        tb.drain(1).unwrap();
        assert!(tb.is_empty());
        // 99ms: still less than one whole token.
        tb.refill(MillisTimestamp(199));
        assert!(tb.is_empty());
        tb.refill(MillisTimestamp(200));
        assert!(!tb.is_empty());
    }

    /// Basic drain/refill arithmetic at 10 tokens/second, including the
    /// whole-token carry-over across consecutive refills.
    #[test]
    fn correctness() {
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(100));
        tb.drain(50).unwrap();
        assert_eq!(tb.bucket, 50);
        tb.refill(MillisTimestamp(1100));
        assert_eq!(tb.bucket, 60);
        tb.drain(50).unwrap();
        assert_eq!(tb.bucket, 10);
        tb.refill(MillisTimestamp(2100));
        assert_eq!(tb.bucket, 20);
        // 1ms and 99ms extra: below the 100ms-per-token threshold.
        tb.refill(MillisTimestamp(2101));
        assert_eq!(tb.bucket, 20);
        tb.refill(MillisTimestamp(2199));
        assert_eq!(tb.bucket, 20);
        tb.refill(MillisTimestamp(2200));
        assert_eq!(tb.bucket, 21);
    }

    /// Partial-token elapsed time rounds down, and the unused remainder is
    /// preserved for the next refill.
    #[test]
    fn rounding() {
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(0));
        tb.drain(100).unwrap();
        tb.refill(MillisTimestamp(99));
        assert_eq!(tb.bucket, 0);
        // 150ms total => 1 token credited, 50ms of remainder carried.
        tb.refill(MillisTimestamp(150));
        assert_eq!(tb.bucket, 1);
        tb.refill(MillisTimestamp(199));
        assert_eq!(tb.bucket, 1);
        tb.refill(MillisTimestamp(200));
        assert_eq!(tb.bucket, 2);
    }

    /// tokens_available_at predictions track drains/refills exactly, and the
    /// never-succeeds cases map to the right error variants.
    #[test]
    fn tokens_available_at() {
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(0));
        tb.drain(100).unwrap();
        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(0)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(200)));
        // Sub-token refill at t=40 does not move the predictions: the base
        // instant (added_tokens_at) only advances by whole tokens.
        tb.refill(MillisTimestamp(40));
        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(0)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(200)));
        tb.refill(MillisTimestamp(100));
        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(200)));
        tb.drain(1).unwrap();
        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(200)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(300)));
        tb.refill(MillisTimestamp(140));
        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(200)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(300)));
        tb.refill(MillisTimestamp(210));
        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(200)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(200)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(300)));
        use NeverEnoughTokensError as NETE;
        // Exactly bucket_max is reachable; anything above never is.
        assert_eq!(tb.tokens_available_at(100), Ok(MillisTimestamp(10_100)));
        assert_eq!(tb.tokens_available_at(101), Err(NETE::ExceedsMaxTokens));
        assert_eq!(
            tb.tokens_available_at(u64::MAX),
            Err(NETE::ExceedsMaxTokens),
        );
        // With rate 0, whatever is in the bucket now is all there will be.
        tb.adjust(
            MillisTimestamp(210),
            &TokenBucketConfig {
                rate: 0,
                bucket_max: 100,
            },
        );
        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(210)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(210)));
        assert_eq!(tb.tokens_available_at(2), Err(NETE::ZeroRate));
    }

    /// duration -> tokens -> duration is a fixed point under one further
    /// conversion: the floor/ceil pairing must not drift across round trips.
    #[test]
    fn test_duration_token_round_trip() {
        let tokens_to_duration = TokenBucket::<Instant>::tokens_to_duration;
        let duration_to_tokens = TokenBucket::<Instant>::duration_to_tokens;
        // Edge cases: zero/tiny/max durations crossed with tiny/huge rates.
        let mut duration_rate_pairs = vec![
            (Duration::from_nanos(0), 1),
            (Duration::from_nanos(1), 1),
            (Duration::from_micros(2), 1),
            (Duration::MAX, 1),
            (Duration::from_nanos(0), 3),
            (Duration::from_nanos(1), 3),
            (Duration::from_micros(2), 3),
            (Duration::MAX, 3),
            (Duration::from_nanos(0), 1000),
            (Duration::from_nanos(1), 1000),
            (Duration::from_micros(2), 1000),
            (Duration::MAX, 1000),
            (Duration::from_nanos(0), u64::MAX),
            (Duration::from_nanos(1), u64::MAX),
            (Duration::from_micros(2), u64::MAX),
            (Duration::MAX, u64::MAX),
        ];
        // Plus random fuzzing; Duration::new panics on out-of-range inputs,
        // so those samples are skipped via catch_unwind.
        let mut rng = rand::rng();
        for _ in 0..10_000 {
            let secs = rng.random();
            let nanos = rng.random();
            let Ok(random_duration) = std::panic::catch_unwind(|| Duration::new(secs, nanos))
            else {
                continue;
            };
            let random_rate = rng.random();
            duration_rate_pairs.push((random_duration, random_rate));
        }
        for (original_duration, rate) in duration_rate_pairs {
            let tokens = duration_to_tokens(original_duration, rate);
            let duration = tokens_to_duration(tokens, rate).unwrap();
            assert_eq!(tokens, duration_to_tokens(duration, rate));
        }
    }
}