use std::time::{Duration, Instant};
use tokio::time;
/// Token-bucket rate limiter for byte streams.
///
/// Tokens are bytes: the bucket refills at `bytes_per_sec` and holds at most
/// `capacity` tokens (a tenth of a second of traffic, but never less than 8192).
/// [`TokenBucket::consume`] may drive the balance negative; the resulting debt is
/// paid off by sleeping, so sustained throughput is limited to roughly
/// `bytes_per_sec`. A rate of `0` disables throttling: `consume` never sleeps.
#[derive(Debug)]
pub struct TokenBucket {
    /// Target throughput in bytes per second.
    ///
    /// Prefer [`TokenBucket::set_bandwidth`] over writing this field directly:
    /// a direct write does not recompute `capacity` or reset the bucket.
    pub bytes_per_sec: u64,
    /// Current balance; negative values are debt that `consume` waits off.
    tokens: i64,
    /// Maximum balance (burst size).
    capacity: i64,
    /// Instant up to which elapsed time has already been converted into tokens.
    last_refill: Instant,
    /// Truncated `1e9 / bytes_per_sec` (0 when the rate is 0).
    ///
    /// Kept for compatibility only: wait times are computed from `bytes_per_sec`
    /// directly, because this value is too coarse above ~1e8 B/s and becomes 0
    /// above 1e9 B/s.
    #[allow(dead_code)] // only read by unit tests
    nanos_per_byte: u64,
}

impl TokenBucket {
    /// Nanoseconds per second, widened so rate arithmetic cannot overflow.
    const NANOS_PER_SEC: u128 = 1_000_000_000;
    /// Waits at or below this many nanoseconds are skipped; the debt is carried
    /// into the next call instead of paying timer overhead for a tiny sleep.
    const MIN_SLEEP_NANOS: u128 = 10_000;
    /// Smallest burst size, so low rates can still pass a reasonably sized chunk.
    const MIN_CAPACITY: u64 = 8192;

    /// Creates a full bucket limited to `bytes_per_sec` (`0` = unthrottled).
    pub fn new(bytes_per_sec: u64) -> Self {
        let capacity = Self::capacity_for(bytes_per_sec);
        Self {
            bytes_per_sec,
            tokens: capacity,
            capacity,
            last_refill: Instant::now(),
            nanos_per_byte: Self::nanos_per_byte_for(bytes_per_sec),
        }
    }

    /// Accounts for sending `bytes`, sleeping if that leaves the bucket in debt.
    ///
    /// The bytes are charged before any sleep. A request larger than the capacity
    /// therefore waits exactly for its deficit, rather than having the refill cap
    /// discard part of the time it already waited. Deficits that would take at
    /// most 10 µs are not slept; they stay as debt for the next call. If this
    /// future is dropped while sleeping, the bytes remain charged.
    pub async fn consume(&mut self, bytes: usize) {
        self.refill();
        let bytes = i64::try_from(bytes).unwrap_or(i64::MAX);
        self.tokens = self.tokens.saturating_sub(bytes);
        if self.tokens >= 0 || self.bytes_per_sec == 0 {
            return;
        }
        let deficit = u128::from(self.tokens.unsigned_abs());
        let wait_nanos = Self::ceil_div(
            deficit * Self::NANOS_PER_SEC,
            u128::from(self.bytes_per_sec),
        );
        if wait_nanos > Self::MIN_SLEEP_NANOS {
            time::sleep(Self::duration_from_nanos(wait_nanos)).await;
            self.refill();
        }
    }

    /// Converts time elapsed since `last_refill` into tokens, capped at `capacity`.
    ///
    /// Only the time that paid for whole tokens is consumed from the clock. The
    /// fractional remainder carries over to the next refill, so low rates are not
    /// systematically under-credited. The arithmetic is done in 128 bits so that
    /// long idle periods at high rates cannot overflow.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed_nanos = now.duration_since(self.last_refill).as_nanos();
        let rate = u128::from(self.bytes_per_sec);
        let earned = elapsed_nanos.saturating_mul(rate) / Self::NANOS_PER_SEC;
        if earned == 0 {
            // Either the rate is 0 or less than one token's worth of time has passed.
            return;
        }
        let room = i128::from(self.capacity) - i128::from(self.tokens);
        if room <= 0 || earned >= room.unsigned_abs() {
            // Bucket is full: nothing accrues while it stays full, so restart the clock.
            self.tokens = self.capacity;
            self.last_refill = now;
        } else {
            // `earned < room`, so the new balance stays below `capacity` and fits in i64.
            self.tokens = (i128::from(self.tokens) + earned as i128) as i64;
            // Advance the clock only by the time those whole tokens cost (<= elapsed).
            let paid_nanos = Self::ceil_div(earned * Self::NANOS_PER_SEC, rate);
            self.last_refill += Self::duration_from_nanos(paid_nanos);
        }
    }

    /// Refills the bucket to capacity and restarts the refill clock.
    pub fn reset(&mut self) {
        self.tokens = self.capacity;
        self.last_refill = Instant::now();
    }

    /// Current balance in bytes; negative while in debt. Does not refill first.
    pub fn available_tokens(&self) -> i64 {
        self.tokens
    }

    /// Changes the rate, recomputes the capacity, and resets to a full bucket.
    pub fn set_bandwidth(&mut self, bytes_per_sec: u64) {
        self.bytes_per_sec = bytes_per_sec;
        self.capacity = Self::capacity_for(bytes_per_sec);
        self.nanos_per_byte = Self::nanos_per_byte_for(bytes_per_sec);
        self.reset();
    }

    /// Burst size for a rate: 100 ms of traffic, at least `MIN_CAPACITY` bytes.
    fn capacity_for(bytes_per_sec: u64) -> i64 {
        // u64::MAX / 10 < i64::MAX, so this cast never wraps.
        (bytes_per_sec / 10).max(Self::MIN_CAPACITY) as i64
    }

    /// Truncated nanoseconds per byte, or 0 for a zero rate.
    fn nanos_per_byte_for(bytes_per_sec: u64) -> u64 {
        1_000_000_000_u64.checked_div(bytes_per_sec).unwrap_or(0)
    }

    /// `num / den` rounded up; `den` must be non-zero.
    fn ceil_div(num: u128, den: u128) -> u128 {
        num / den + u128::from(num % den != 0)
    }

    /// Builds a `Duration` from a 128-bit nanosecond count, saturating at `Duration::MAX`.
    fn duration_from_nanos(nanos: u128) -> Duration {
        let secs = nanos / Self::NANOS_PER_SEC;
        let subsec = (nanos % Self::NANOS_PER_SEC) as u32;
        u64::try_from(secs).map_or(Duration::MAX, |s| Duration::new(s, subsec))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Construction stores the rate verbatim and starts with a full bucket.
    #[test]
    fn test_token_bucket_creation() {
        let fresh = TokenBucket::new(1_000_000);
        let TokenBucket {
            bytes_per_sec,
            tokens,
            capacity,
            ..
        } = fresh;
        assert_eq!(bytes_per_sec, 1_000_000);
        assert!(capacity > 0);
        assert_eq!(tokens, capacity);
    }

    // Capacity is a tenth of the rate, floored at 8192 bytes.
    #[test]
    fn test_token_bucket_capacity() {
        for (rate, expected) in [(10_000_000_u64, 1_000_000_i64), (1000, 8192)] {
            assert_eq!(TokenBucket::new(rate).capacity, expected);
        }
    }

    // Spending within the current balance deducts exactly the requested amount.
    #[tokio::test]
    async fn test_token_consumption() {
        let mut bucket = TokenBucket::new(1_000_000);
        let before = bucket.tokens;
        bucket.consume(1500).await;
        let spent = before - bucket.tokens;
        assert_eq!(spent, 1500);
    }

    // After real time passes, a manual refill adds tokens without exceeding capacity.
    #[tokio::test]
    async fn test_token_refill() {
        let mut bucket = TokenBucket::new(1_000_000);
        let drained = 500;
        bucket.tokens = drained;
        time::sleep(Duration::from_millis(10)).await;
        bucket.refill();
        assert!((drained + 1..=bucket.capacity).contains(&bucket.tokens));
    }

    // Rates that divide one second evenly yield an exact per-byte cost.
    #[test]
    fn test_nanos_per_byte() {
        for (rate, expected) in [(1_000_000_u64, 1000_u64), (10_000_000, 100)] {
            assert_eq!(TokenBucket::new(rate).nanos_per_byte, expected);
        }
    }

    // Raising the rate grows the capacity and refills the bucket.
    #[test]
    fn test_bandwidth_update() {
        let mut bucket = TokenBucket::new(1_000_000);
        let capacity_before = bucket.capacity;
        bucket.set_bandwidth(10_000_000);
        assert_eq!(bucket.bytes_per_sec, 10_000_000);
        assert!(capacity_before < bucket.capacity);
        assert_eq!(bucket.tokens, bucket.capacity);
    }

    // Resetting an empty bucket restores it to full.
    #[test]
    fn test_reset() {
        let mut bucket = TokenBucket::new(1_000_000);
        bucket.tokens = 0;
        bucket.reset();
        assert_eq!(bucket.available_tokens(), bucket.capacity);
    }
}