use std::num::NonZeroU64;
use std::time::{Duration, Instant};

use crate::vmm::disk_config::DiskThrottle;
/// Token-bucket rate limiter backing per-disk I/O throttling.
///
/// The balance is tracked as a *signed* value so that a single request
/// larger than `capacity` can still be granted (whenever the balance is
/// non-negative) by driving the balance into debt; later refills must
/// pay the debt back before anything else is granted.
#[derive(Debug)]
pub(crate) struct TokenBucket {
    /// Maximum token balance (burst size).
    pub(crate) capacity: u64,
    /// Tokens credited per second of wall time.
    pub(crate) refill_rate: u64,
    /// Current balance; goes negative after an oversized grant.
    pub(crate) available: i64,
    /// Instant up to which refill credit has already been applied.
    pub(crate) last_refill: Instant,
    /// When true, every request is granted and no accounting is done.
    pub(crate) unlimited: bool,
    /// Test-only override for `nanos_until_n_tokens`.
    #[cfg(test)]
    pub(crate) forced_nanos_until_n_tokens: Option<u64>,
}

impl TokenBucket {
    /// A bucket that grants every request and does no accounting.
    pub(crate) fn unlimited() -> Self {
        Self {
            capacity: 0,
            refill_rate: 0,
            available: 0,
            last_refill: Instant::now(),
            unlimited: true,
            #[cfg(test)]
            forced_nanos_until_n_tokens: None,
        }
    }

    /// A bucket seeded full at `capacity`, refilling at
    /// `refill_rate_per_sec` tokens per second.
    ///
    /// A zero capacity or zero rate means "no throttle" and degrades to
    /// an unlimited bucket.
    pub(crate) fn new(capacity: u64, refill_rate_per_sec: u64) -> Self {
        if capacity == 0 || refill_rate_per_sec == 0 {
            return Self::unlimited();
        }
        Self {
            capacity,
            refill_rate: refill_rate_per_sec,
            // Seed full; capacities above i64::MAX clamp to the largest
            // representable balance.
            available: i64::try_from(capacity).unwrap_or(i64::MAX),
            last_refill: Instant::now(),
            unlimited: false,
            #[cfg(test)]
            forced_nanos_until_n_tokens: None,
        }
    }

    /// Credit tokens accrued since `last_refill`, capping at `capacity`.
    ///
    /// `last_refill` advances only by the time corresponding to the
    /// whole tokens actually credited, so the sub-token remainder of the
    /// elapsed interval carries forward to the next call. (Snapping
    /// `last_refill` to `now` unconditionally would discard up to one
    /// token's worth of time on every crediting call, under-delivering
    /// the configured rate when the bucket is polled frequently.)
    pub(crate) fn refill(&mut self) {
        if self.unlimited {
            return;
        }
        let now = Instant::now();
        let elapsed_ns = now.duration_since(self.last_refill).as_nanos();
        // Whole tokens minted over the elapsed interval (floored), clamped
        // to u64::MAX so the arithmetic below cannot overflow u128.
        let new_tokens =
            ((self.refill_rate as u128 * elapsed_ns) / 1_000_000_000).min(u64::MAX as u128);
        if new_tokens == 0 {
            // Less than one whole token accrued; leave `last_refill`
            // untouched so the partial interval keeps accumulating.
            return;
        }
        let add = i64::try_from(new_tokens).unwrap_or(i64::MAX);
        let cap_i64 = i64::try_from(self.capacity).unwrap_or(i64::MAX);
        self.available = self.available.saturating_add(add).min(cap_i64);
        // Nanoseconds needed to mint exactly `new_tokens` tokens; this is
        // <= elapsed_ns because `new_tokens` was floored above.
        let used_ns = (new_tokens * 1_000_000_000).div_ceil(self.refill_rate as u128);
        let credit = Duration::from_nanos(u64::try_from(used_ns).unwrap_or(u64::MAX));
        // `min(now)` guards the saturated-conversion corner cases so the
        // refill origin can never move into the future.
        self.last_refill = self
            .last_refill
            .checked_add(credit)
            .map_or(now, |t| t.min(now));
    }

    /// Grant predicate shared by `consume` and `can_consume`: oversized
    /// requests (`n > capacity`) are granted whenever the balance is
    /// non-negative; normal requests need the full amount banked.
    fn is_grantable(&self, n: u64, n_signed: i64) -> bool {
        if n > self.capacity {
            self.available >= 0
        } else {
            self.available >= n_signed
        }
    }

    /// Try to take `n` tokens. Debits the balance and returns `true` on
    /// success; a rejected request leaves the balance untouched.
    pub(crate) fn consume(&mut self, n: u64) -> bool {
        if self.unlimited || n == 0 {
            return true;
        }
        self.refill();
        // Sizes beyond i64::MAX cannot be represented in the signed
        // balance; reject them outright (without touching the balance).
        let Ok(n_signed) = i64::try_from(n) else {
            return false;
        };
        if !self.is_grantable(n, n_signed) {
            return false;
        }
        self.available = self.available.saturating_sub(n_signed);
        true
    }

    /// Non-destructive variant of [`Self::consume`]: refills, then
    /// reports whether `n` tokens would be granted right now.
    pub(crate) fn can_consume(&mut self, n: u64) -> bool {
        if self.unlimited || n == 0 {
            return true;
        }
        self.refill();
        let Ok(n_signed) = i64::try_from(n) else {
            return false;
        };
        self.is_grantable(n, n_signed)
    }

    /// Nanoseconds until `need` tokens could be granted, assuming no
    /// other consumers. Returns 0 when the request would be granted
    /// immediately. Oversized requests (`need > capacity`) only wait for
    /// the balance to climb back to zero, mirroring `is_grantable`.
    pub(crate) fn nanos_until_n_tokens(&mut self, need: u64) -> u64 {
        // Test hook: takes precedence over every fast path, including
        // `unlimited`.
        #[cfg(test)]
        if let Some(forced) = self.forced_nanos_until_n_tokens {
            return forced;
        }
        if self.unlimited || need == 0 {
            return 0;
        }
        self.refill();
        let deficit_i128: i128 = if need > self.capacity {
            if self.available >= 0 {
                return 0;
            }
            -(self.available as i128)
        } else {
            let avail_i128 = self.available as i128;
            let need_i128 = need as i128;
            if avail_i128 >= need_i128 {
                return 0;
            }
            need_i128 - avail_i128
        };
        debug_assert!(
            deficit_i128 > 0,
            "deficit must be positive after the early-return \
             arms above (need={need}, available={})",
            self.available,
        );
        let deficit_u128 = deficit_i128 as u128;
        // Round up: report the first instant at which the deficit is
        // fully covered; saturate at u64::MAX for absurd deficits.
        let numerator = deficit_u128 * 1_000_000_000;
        let denom = self.refill_rate as u128;
        let nanos_u128 = numerator.div_ceil(denom);
        u64::try_from(nanos_u128).unwrap_or(u64::MAX)
    }

    /// Test-only: backdate or pin the refill origin.
    #[cfg(test)]
    pub(crate) fn set_last_refill_for_test(&mut self, t: Instant) {
        self.last_refill = t;
    }

    /// Test-only: force `nanos_until_n_tokens` to return `forced`.
    #[cfg(test)]
    pub(crate) fn set_forced_nanos_until_n_tokens_for_test(&mut self, forced: u64) {
        self.forced_nanos_until_n_tokens = Some(forced);
    }

    /// Test-only: clear the override and restore real deficit math.
    #[cfg(test)]
    pub(crate) fn clear_forced_nanos_until_n_tokens_for_test(&mut self) {
        self.forced_nanos_until_n_tokens = None;
    }
}
/// Build the `(ops, bytes)` token-bucket pair for a disk's throttle
/// settings.
///
/// Each dimension is independent: a missing rate yields an unlimited
/// bucket (any burst setting for that dimension is ignored), while a
/// present rate seeds a bucket whose capacity is the configured burst
/// or, by default, one second's worth of tokens.
pub(crate) fn buckets_from_throttle(throttle: DiskThrottle) -> (TokenBucket, TokenBucket) {
    // Shared constructor for one (rate, burst) dimension.
    let build = |rate: Option<NonZeroU64>, burst: Option<NonZeroU64>| match rate {
        None => TokenBucket::unlimited(),
        Some(rate_nz) => {
            let per_sec = rate_nz.get();
            let capacity = burst.map_or(per_sec, NonZeroU64::get);
            TokenBucket::new(capacity, per_sec)
        }
    };
    (
        build(throttle.iops, throttle.iops_burst_capacity),
        build(throttle.bytes_per_sec, throttle.bytes_burst_capacity),
    )
}
// Unit tests for `TokenBucket` and `buckets_from_throttle`. Most tests
// pin `last_refill` through the test-only setter so refill math is
// deterministic; a few rely on real sleeps and are timing-based.
#[cfg(test)]
mod tests {
    use super::super::DiskThrottle;
    use super::*;
    use std::num::NonZeroU64;

    // An unlimited bucket never rejects, regardless of request count.
    #[test]
    fn token_bucket_unlimited_always_grants() {
        let mut tb = TokenBucket::unlimited();
        for _ in 0..1_000_000 {
            assert!(tb.consume(1));
        }
    }

    // The bucket is seeded full: exactly `capacity` single-token grants,
    // then rejection (rate=1 is too slow to top up mid-test).
    #[test]
    fn token_bucket_consumes_capacity() {
        let mut tb = TokenBucket::new(100, 1); for _ in 0..100 {
            assert!(tb.consume(1));
        }
        assert!(!tb.consume(1));
    }

    // Real-sleep test: 200ms at 10 tokens/sec accrues at least 1 token.
    #[test]
    fn token_bucket_refills_over_time() {
        let mut tb = TokenBucket::new(100, 10);
        for _ in 0..100 {
            assert!(tb.consume(1));
        }
        assert!(
            !tb.consume(1),
            "bucket exhausted; refill too slow to top up in microseconds",
        );
        std::thread::sleep(std::time::Duration::from_millis(200));
        assert!(
            tb.consume(1),
            "after 200ms at 10 tokens/sec, at least 1 should be available",
        );
    }

    // Zero capacity or zero rate both degrade to an unlimited bucket.
    #[test]
    fn throttle_zero_rate_becomes_unlimited() {
        let mut tb = TokenBucket::new(0, 100);
        for _ in 0..10_000 {
            assert!(tb.consume(1));
        }
        let mut tb = TokenBucket::new(100, 0);
        for _ in 0..10_000 {
            assert!(tb.consume(1));
        }
    }

    // Real-sleep test: 1.1s at 10 tokens/sec refills all 10 of capacity.
    #[test]
    fn token_bucket_refill_uses_elapsed_wall_time() {
        let mut tb = TokenBucket::new(10, 10);
        for _ in 0..10 {
            assert!(tb.consume(1));
        }
        assert!(!tb.consume(1));
        std::thread::sleep(std::time::Duration::from_millis(1100));
        for _ in 0..10 {
            assert!(
                tb.consume(1),
                "bucket should have refilled to capacity after sleep"
            );
        }
    }

    // consume(0) always succeeds and never touches the balance.
    #[test]
    fn token_bucket_consume_zero_is_free() {
        let mut tb = TokenBucket::new(10, 10);
        for _ in 0..1_000 {
            assert!(tb.consume(0));
        }
        for _ in 0..10 {
            assert!(tb.consume(1));
        }
        assert!(!tb.consume(1));
    }

    // The test-only override replaces the deficit math entirely, for any
    // `need`, until cleared.
    #[test]
    fn token_bucket_forced_nanos_until_n_tokens_overrides_deficit() {
        let mut tb = TokenBucket::new(10, 10);
        for _ in 0..10 {
            assert!(tb.consume(1));
        }
        assert!(!tb.consume(1));
        let real_nanos = tb.nanos_until_n_tokens(5);
        assert!(
            real_nanos > 0,
            "real deficit must be positive after drain; got {real_nanos}",
        );
        tb.set_forced_nanos_until_n_tokens_for_test(0);
        assert_eq!(
            tb.nanos_until_n_tokens(5),
            0,
            "override must force 0 regardless of deficit",
        );
        assert_eq!(
            tb.nanos_until_n_tokens(u64::MAX),
            0,
            "override is need-independent; u64::MAX still returns the forced value",
        );
        tb.set_forced_nanos_until_n_tokens_for_test(123_456);
        assert_eq!(tb.nanos_until_n_tokens(1), 123_456);
        tb.clear_forced_nanos_until_n_tokens_for_test();
        let post_clear = tb.nanos_until_n_tokens(5);
        assert!(
            post_clear > 0,
            "clearing override must restore real deficit math; got {post_clear}",
        );
    }

    // The override is checked before the `unlimited` fast path.
    #[test]
    fn token_bucket_forced_nanos_until_n_tokens_overrides_unlimited() {
        let mut tb = TokenBucket::unlimited();
        assert_eq!(tb.nanos_until_n_tokens(1_000), 0);
        tb.set_forced_nanos_until_n_tokens_for_test(7);
        assert_eq!(
            tb.nanos_until_n_tokens(1_000),
            7,
            "override must take precedence over the unlimited fast path",
        );
        tb.clear_forced_nanos_until_n_tokens_for_test();
        assert_eq!(tb.nanos_until_n_tokens(1_000), 0);
    }

    // Drive the balance to ~-i64::MAX via an oversized grant; the wait
    // computation must saturate rather than overflow.
    #[test]
    fn nanos_until_n_tokens_saturates_at_u64_max() {
        let mut tb = TokenBucket::new(1, 1);
        tb.set_last_refill_for_test(std::time::Instant::now());
        let huge = i64::MAX as u64;
        assert!(tb.consume(huge), "overconsume succeeds when available >= 0");
        tb.set_last_refill_for_test(std::time::Instant::now());
        assert!(
            !tb.can_consume(1),
            "post-overconsume balance must be negative — \
             any positive consume rejected by the gate",
        );
        tb.set_last_refill_for_test(std::time::Instant::now());
        assert_eq!(
            tb.nanos_until_n_tokens(u64::MAX),
            u64::MAX,
            "u64-scale deficit at rate=1 must saturate at u64::MAX",
        );
    }

    // Exact ceiling-division check on a deficit that divides evenly.
    #[test]
    fn nanos_until_n_tokens_ceil_div_exact() {
        let mut tb = TokenBucket::new(10, 10);
        for _ in 0..10 {
            assert!(tb.consume(1));
        }
        tb.set_last_refill_for_test(std::time::Instant::now());
        assert_eq!(
            tb.nanos_until_n_tokens(5),
            500_000_000,
            "deficit=5 with rate=10/sec must equal 0.5s = 500_000_000 ns",
        );
    }

    // Unlimited bucket: zero wait for any need.
    #[test]
    fn nanos_until_n_tokens_unlimited_returns_zero() {
        let mut tb = TokenBucket::unlimited();
        assert_eq!(
            tb.nanos_until_n_tokens(u64::MAX),
            0,
            "unlimited bucket must return 0 regardless of need",
        );
    }

    // Backdating last_refill by 2s refills the bucket, so no wait.
    #[test]
    fn nanos_until_n_tokens_post_refill_returns_zero() {
        let mut tb = TokenBucket::new(10, 10);
        for _ in 0..10 {
            assert!(tb.consume(1));
        }
        tb.set_last_refill_for_test(std::time::Instant::now() - std::time::Duration::from_secs(2));
        assert_eq!(
            tb.nanos_until_n_tokens(1),
            0,
            "post-refill `available >= need` must return 0",
        );
    }

    // n > capacity is granted when the balance is non-negative and
    // leaves the bucket in debt; followers stall until the debt clears.
    #[test]
    fn token_bucket_oversized_grants_and_drives_negative() {
        let mut tb = TokenBucket::new(100, 100);
        assert!(
            tb.consume(150),
            "oversized consume must grant when available >= 0",
        );
        tb.set_last_refill_for_test(std::time::Instant::now());
        assert_eq!(
            tb.nanos_until_n_tokens(101),
            500_000_000,
            "post-overconsume debt of 50 (capacity=100, n=150 → 100-150) \
             produces 500ms at rate=100/sec; deficit math: \
             -available * 1e9 / refill_rate = 50 * 1e9 / 100",
        );
        assert!(
            !tb.can_consume(1),
            "follower (any size) stalls while bucket is in debt",
        );
    }

    // A second oversized request while in debt is rejected and must not
    // change the balance.
    #[test]
    fn token_bucket_oversized_back_to_back_second_stalls() {
        let mut tb = TokenBucket::new(100, 100);
        assert!(tb.consume(150));
        tb.set_last_refill_for_test(std::time::Instant::now());
        assert_eq!(
            tb.nanos_until_n_tokens(101),
            500_000_000,
            "post-first-overconsume debt of 50 → 500ms at rate=100",
        );
        tb.set_last_refill_for_test(std::time::Instant::now());
        assert!(
            !tb.consume(150),
            "second oversized must stall while bucket is in debt",
        );
        tb.set_last_refill_for_test(std::time::Instant::now());
        assert_eq!(
            tb.nanos_until_n_tokens(101),
            500_000_000,
            "failed consume must NOT deepen the debt — deficit \
             unchanged at 50",
        );
        assert!(!tb.can_consume(150));
    }

    // Oversized followers only wait for the balance to reach zero, not
    // for `need` whole tokens.
    #[test]
    fn nanos_until_n_tokens_oversized_follower_waits_for_zero() {
        let mut tb = TokenBucket::new(100, 100);
        assert!(tb.consume(150));
        tb.set_last_refill_for_test(std::time::Instant::now());
        assert_eq!(
            tb.nanos_until_n_tokens(200),
            500_000_000,
            "oversized follower waits for available to climb to 0",
        );
    }

    // Normal-sized followers wait for the balance to reach `need`.
    #[test]
    fn nanos_until_n_tokens_normal_follower_after_debt() {
        let mut tb = TokenBucket::new(100, 100);
        assert!(tb.consume(150));
        tb.set_last_refill_for_test(std::time::Instant::now());
        assert_eq!(
            tb.nanos_until_n_tokens(10),
            600_000_000,
            "normal-sized follower waits for available to climb \
             from -50 to need=10",
        );
    }

    // Requests beyond i64::MAX hit the signed-cast guard: rejected, and
    // the rejection must not debit the balance.
    #[test]
    fn token_bucket_consume_rejects_n_above_i64_max() {
        let mut tb = TokenBucket::new(100, 100);
        let pathological = (i64::MAX as u64) + 1;
        assert!(
            !tb.can_consume(pathological),
            "n > i64::MAX must fail can_consume — i64 cast guard",
        );
        assert!(
            !tb.consume(pathological),
            "n > i64::MAX must fail consume — i64 cast guard",
        );
        assert!(
            tb.can_consume(100),
            "rejection must NOT decrement balance — full seed of \
             100 tokens must still be grantable",
        );
        assert!(!tb.consume(u64::MAX));
        assert!(!tb.can_consume(u64::MAX));
        assert!(
            tb.can_consume(100),
            "second rejection round must also leave balance \
             unchanged at the seed value",
        );
    }

    // Zero-cost requests succeed even while the bucket is in debt and
    // leave the debt untouched.
    #[test]
    fn token_bucket_zero_consume_succeeds_in_debt() {
        let mut tb = TokenBucket::new(100, 100);
        assert!(tb.consume(150));
        tb.set_last_refill_for_test(std::time::Instant::now());
        assert!(
            !tb.can_consume(1),
            "bucket must be in debt — any positive consume rejected",
        );
        assert!(tb.consume(0));
        assert!(tb.can_consume(0));
        tb.set_last_refill_for_test(std::time::Instant::now());
        assert_eq!(
            tb.nanos_until_n_tokens(101),
            500_000_000,
            "consume(0) / can_consume(0) must NOT touch balance — \
             debt of 50 unchanged after zero-cost requests",
        );
    }

    // A 1s backdate at rate=100 credits 100 tokens: pays the -50 debt
    // and leaves +50, which the follow-up consume drains to 0.
    #[test]
    fn token_bucket_debt_clears_with_refill() {
        let mut tb = TokenBucket::new(100, 100);
        assert!(tb.consume(150));
        tb.set_last_refill_for_test(std::time::Instant::now() - std::time::Duration::from_secs(1));
        assert!(
            tb.consume(50),
            "consume must succeed after refill clears the debt",
        );
        tb.set_last_refill_for_test(std::time::Instant::now());
        assert_eq!(
            tb.nanos_until_n_tokens(1),
            10_000_000,
            "post-consume(50) balance must be 0 — refill paid -50 \
             debt and consume(50) drained the recovered +50; \
             nanos for need=1 at rate=100 = 10_000_000",
        );
    }

    // Unlimited buckets grant even u64::MAX and report zero wait.
    #[test]
    fn token_bucket_unlimited_grants_oversized() {
        let mut tb = TokenBucket::unlimited();
        assert!(tb.consume(u64::MAX));
        assert!(tb.can_consume(u64::MAX));
        assert_eq!(
            tb.nanos_until_n_tokens(u64::MAX),
            0,
            "unlimited bucket reports zero wait for any need",
        );
    }

    // Boundary check: n == capacity is NOT "oversized" — it must go
    // through the normal `available >= n` gate, never the debt branch.
    #[test]
    fn token_bucket_consume_at_capacity_takes_normal_branch() {
        let mut tb = TokenBucket::new(100, 100);
        assert!(
            tb.consume(100),
            "n == capacity must succeed via normal-path \
             available >= n_signed gate",
        );
        tb.set_last_refill_for_test(std::time::Instant::now());
        assert_eq!(
            tb.nanos_until_n_tokens(1),
            10_000_000,
            "post-drain balance must be 0 (not negative); deficit=1 \
             at rate=100 = 10_000_000 ns. Overconsume branch entered \
             would drive balance to -100, deficit=101 → 1_010_000_000",
        );
        tb.set_last_refill_for_test(std::time::Instant::now());
        assert!(
            !tb.consume(100),
            "n == capacity (not > capacity) must fail when \
             available < n_signed; overconsume branch is \
             strictly `n > capacity`, not `n >= capacity`",
        );
        tb.set_last_refill_for_test(std::time::Instant::now());
        assert_eq!(
            tb.nanos_until_n_tokens(1),
            10_000_000,
            "available unchanged at 0 — overconsume branch did \
             NOT drive it to -100 (which would yield 1_010_000_000), \
             proving the boundary check is `>` not `>=`",
        );
    }

    // With no burst configured, capacity defaults to one second's worth
    // of tokens (capacity == rate) for both dimensions.
    #[test]
    fn buckets_from_throttle_default_burst_equals_rate() {
        let throttle = DiskThrottle {
            iops: NonZeroU64::new(1_000),
            bytes_per_sec: NonZeroU64::new(50_000),
            iops_burst_capacity: None,
            bytes_burst_capacity: None,
        };
        let (mut ops, mut bytes) = buckets_from_throttle(throttle);
        assert_eq!(ops.capacity, 1_000);
        assert_eq!(ops.refill_rate, 1_000);
        assert!(
            ops.can_consume(1_000),
            "1-second-burst seed equals rate — bucket admits a \
             capacity-sized request immediately",
        );
        assert_eq!(bytes.capacity, 50_000);
        assert_eq!(bytes.refill_rate, 50_000);
        assert!(
            bytes.can_consume(50_000),
            "bytes bucket also seeded full at capacity",
        );
    }

    // Explicit burst capacities override the rate-derived default.
    #[test]
    fn buckets_from_throttle_burst_capacity_overrides_rate() {
        let throttle = DiskThrottle {
            iops: NonZeroU64::new(1_000),
            bytes_per_sec: NonZeroU64::new(50_000),
            iops_burst_capacity: NonZeroU64::new(5_000),
            bytes_burst_capacity: NonZeroU64::new(250_000),
        };
        let (mut ops, mut bytes) = buckets_from_throttle(throttle);
        assert_eq!(ops.capacity, 5_000);
        assert_eq!(ops.refill_rate, 1_000);
        assert!(
            ops.can_consume(5_000),
            "ops bucket seeded equal to burst capacity (5_000)",
        );
        assert_eq!(bytes.capacity, 250_000);
        assert_eq!(bytes.refill_rate, 50_000);
        assert!(
            bytes.can_consume(250_000),
            "bytes bucket seeded equal to burst capacity (250_000)",
        );
    }

    // Burst without a rate is meaningless: both buckets stay unlimited.
    #[test]
    fn buckets_from_throttle_burst_without_rate_is_unlimited() {
        let throttle = DiskThrottle {
            iops: None,
            bytes_per_sec: None,
            iops_burst_capacity: NonZeroU64::new(5_000),
            bytes_burst_capacity: NonZeroU64::new(250_000),
        };
        let (ops, bytes) = buckets_from_throttle(throttle);
        assert!(ops.unlimited);
        assert!(bytes.unlimited);
    }

    // Each dimension resolves its burst/default independently.
    #[test]
    fn buckets_from_throttle_per_dimension_independence() {
        let throttle = DiskThrottle {
            iops: NonZeroU64::new(1_000),
            bytes_per_sec: NonZeroU64::new(50_000),
            iops_burst_capacity: None,
            bytes_burst_capacity: NonZeroU64::new(200_000),
        };
        let (ops, bytes) = buckets_from_throttle(throttle);
        assert_eq!(ops.capacity, 1_000, "iops bucket falls back to rate");
        assert_eq!(bytes.capacity, 200_000, "bytes bucket honours burst");
    }
}