#[cfg(not(target_arch = "wasm32"))]
mod body;
#[cfg(not(target_arch = "wasm32"))]
pub(crate) use body::BandwidthBody;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
/// Token-bucket bandwidth limiter.
///
/// Cloning is cheap: every clone holds the same reference-counted state, so
/// all clones draw from one common budget of `bytes_per_sec` bytes per second.
#[derive(Clone)]
pub struct BandwidthLimiter {
    /// Shared bucket state; cloning a handle only bumps the reference count.
    inner: Arc<BandwidthInner>,
}
/// Bucket state shared by every clone of a `BandwidthLimiter`.
struct BandwidthInner {
    /// Tokens (bytes) currently available; never raised above `bytes_per_sec`.
    tokens: AtomicU64,
    /// Monotonic timestamp (nanoseconds) up to which elapsed time has already
    /// been converted into tokens.
    last_refill_ns: AtomicU64,
    /// Refill rate in bytes per second, which is also the bucket's capacity.
    bytes_per_sec: u64,
}
impl BandwidthLimiter {
    /// Nanoseconds per second, used for every rate <-> time conversion.
    const NANOS_PER_SEC: u128 = 1_000_000_000;

    /// Creates a limiter that admits `bytes_per_sec` bytes per second.
    ///
    /// The bucket starts full, so up to `bytes_per_sec` bytes may be consumed
    /// immediately. The bucket never holds more than `bytes_per_sec` tokens.
    ///
    /// # Panics
    ///
    /// Panics if `bytes_per_sec` is 0.
    pub fn new(bytes_per_sec: u64) -> Self {
        assert!(bytes_per_sec > 0, "bytes_per_sec must be greater than 0");
        let now_ns = crate::clock::monotonic_nanos();
        Self {
            inner: Arc::new(BandwidthInner {
                bytes_per_sec,
                tokens: AtomicU64::new(bytes_per_sec),
                last_refill_ns: AtomicU64::new(now_ns),
            }),
        }
    }

    /// Takes up to `n` tokens (bytes) and returns how many were actually taken.
    ///
    /// Does not sleep. The bucket is first refilled for elapsed time. The result
    /// is in `0..=n` and is smaller than `n` when fewer than `n` tokens are
    /// available.
    pub fn try_consume(&self, n: u64) -> u64 {
        self.refill();
        // The closure always returns `Some`, so `fetch_update` cannot fail. On
        // success it yields the token count observed just before our update.
        let before = self
            .inner
            .tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                Some(current.saturating_sub(n))
            })
            .unwrap_or_else(|_| unreachable!("update closure always returns Some"));
        before.min(n)
    }

    /// Estimates how long to wait until `bytes_needed` tokens are available.
    ///
    /// Returns [`Duration::ZERO`] if enough tokens are available now. Otherwise
    /// returns the time for the missing bytes to accrue at the configured rate.
    /// That time is rounded *up* to whole nanoseconds. Absent other consumers,
    /// sleeping for the returned duration is therefore long enough for the next
    /// refill to cover the deficit. The estimate ignores time already accrued
    /// toward the next token, so it may overshoot slightly.
    ///
    /// The bucket never holds more than `bytes_per_sec` tokens. A request larger
    /// than that can never be satisfied by a single
    /// [`try_consume`](Self::try_consume) call.
    pub fn wait_duration(&self, bytes_needed: u64) -> Duration {
        self.refill();
        let available = self.inner.tokens.load(Ordering::Relaxed);
        if available >= bytes_needed {
            return Duration::ZERO;
        }
        let deficit = u128::from(bytes_needed - available);
        let bps = u128::from(self.inner.bytes_per_sec.max(1));
        // Ceiling division. Cannot overflow: deficit < 2^64 and NANOS_PER_SEC < 2^30.
        let nanos = (deficit * Self::NANOS_PER_SEC + bps - 1) / bps;
        Duration::from_nanos(nanos.min(u128::from(u64::MAX)) as u64)
    }

    /// Credits tokens for the time elapsed since the last refill.
    ///
    /// Only whole bytes are credited. The refill timestamp advances by the time
    /// those bytes cost, rounded up, so the fractional remainder carries over to
    /// the next call and the long-run rate does not exceed `bytes_per_sec`. If
    /// the elapsed time alone fills the bucket, the timestamp jumps straight to
    /// `now`, because the capacity cap would discard any surplus anyway.
    ///
    /// When threads race, only the one whose compare-exchange on
    /// `last_refill_ns` succeeds credits tokens. The others return without
    /// refilling.
    fn refill(&self) {
        let inner = &*self.inner;
        let bps = u128::from(inner.bytes_per_sec.max(1));
        let now = crate::clock::monotonic_nanos();
        let last = inner.last_refill_ns.load(Ordering::Relaxed);
        let elapsed_ns = now.saturating_sub(last);
        if elapsed_ns == 0 {
            return;
        }
        // u64 * u64 always fits in u128, so this product cannot overflow.
        let earned = u128::from(elapsed_ns) * bps / Self::NANOS_PER_SEC;
        if earned == 0 {
            return;
        }
        let (new_bytes, next_last) = if earned >= bps {
            // Bucket will be full regardless; also avoids narrowing `earned`,
            // which may exceed u64::MAX here.
            (inner.bytes_per_sec, now)
        } else {
            // earned < bps <= u64::MAX, so the narrowing casts are lossless, and
            // ceil(earned * 1e9 / bps) <= elapsed_ns keeps next_last <= now.
            let cost_ns = (earned * Self::NANOS_PER_SEC + bps - 1) / bps;
            (earned as u64, last + cost_ns as u64)
        };
        let won_refill = inner
            .last_refill_ns
            .compare_exchange(last, next_last, Ordering::Relaxed, Ordering::Relaxed)
            .is_ok();
        if won_refill {
            inner
                .tokens
                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                    Some(current.saturating_add(new_bytes).min(inner.bytes_per_sec))
                })
                .ok();
        }
    }
}
impl std::fmt::Debug for BandwidthLimiter {
    /// Shows the configured rate and a snapshot of the tokens available now.
    ///
    /// The snapshot is a plain load; it does not trigger a refill.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let state = &*self.inner;
        let available = state.tokens.load(Ordering::Relaxed);
        let mut out = f.debug_struct("BandwidthLimiter");
        out.field("bytes_per_sec", &state.bytes_per_sec);
        out.field("available", &available);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Tests that assert exact token counts choose low rates. At `r` bytes/sec a
    // single token takes `1 / r` seconds to accrue, so the real time spent
    // between calls cannot add a token. This holds even with scheduler hiccups
    // on a loaded CI machine.

    #[test]
    fn new_starts_with_full_bandwidth() {
        let bw = BandwidthLimiter::new(1);
        assert_eq!(bw.try_consume(1), 1);
        assert_eq!(bw.try_consume(1), 0);
    }

    #[test]
    fn wait_duration_zero_when_available() {
        let bw = BandwidthLimiter::new(1000);
        assert_eq!(bw.wait_duration(100), Duration::ZERO);
    }

    #[test]
    fn wait_duration_nonzero_when_exhausted() {
        // 10 B/s: one token per 100 ms, so 5 missing bytes need about 500 ms.
        let bw = BandwidthLimiter::new(10);
        assert_eq!(bw.try_consume(10), 10);
        let wait = bw.wait_duration(5);
        assert!(wait > Duration::ZERO);
        assert!(wait <= Duration::from_millis(500), "wait too long: {wait:?}");
    }

    #[test]
    fn refill_replenishes() {
        let bw = BandwidthLimiter::new(10_000);
        bw.try_consume(10_000);
        std::thread::sleep(Duration::from_millis(110));
        let got = bw.try_consume(5000);
        assert!(got > 0, "expected some tokens after refill, got {got}");
    }

    #[test]
    fn clone_shares_state() {
        // 4 B/s: a stray refill would need 250 ms between calls.
        let a = BandwidthLimiter::new(4);
        let b = a.clone();
        assert_eq!(a.try_consume(2), 2);
        assert_eq!(b.try_consume(2), 2);
        assert_eq!(b.try_consume(1), 0);
    }

    #[test]
    fn debug_output() {
        let bw = BandwidthLimiter::new(500);
        let dbg = format!("{bw:?}");
        assert!(dbg.contains("BandwidthLimiter"));
        assert!(dbg.contains("500"));
    }

    #[test]
    fn try_consume_zero() {
        let bw = BandwidthLimiter::new(100);
        assert_eq!(bw.try_consume(0), 0);
    }

    #[test]
    fn wait_duration_zero_bytes() {
        let bw = BandwidthLimiter::new(100);
        assert_eq!(bw.wait_duration(0), Duration::ZERO);
    }

    #[test]
    fn wait_duration_exact_boundary() {
        // The bucket starts full and is capped at capacity, so exactly
        // `bytes_per_sec` is always available up front.
        let bw = BandwidthLimiter::new(100);
        assert_eq!(bw.wait_duration(100), Duration::ZERO);
    }

    #[test]
    fn partial_consumption() {
        // 5 B/s: a stray refill would need 200 ms between calls.
        let bw = BandwidthLimiter::new(5);
        assert_eq!(bw.try_consume(3), 3);
        assert_eq!(bw.try_consume(3), 2);
    }

    #[test]
    #[should_panic(expected = "bytes_per_sec must be greater than 0")]
    fn zero_bytes_per_sec() {
        BandwidthLimiter::new(0);
    }
}