use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use tracing::{debug, trace};
use crate::protocol::CHUNK_SIZE;
/// Capacity of the token bucket, expressed as a number of `CHUNK_SIZE` chunks.
///
/// A freshly created [`Throttler`] starts with a full bucket, so this many
/// chunks can be acquired without waiting.
const BURST_CHUNKS: usize = 10;

/// Token-bucket rate limiter measured in bytes.
///
/// Create one with [`Throttler::new`] and await [`Throttler::acquire`] to
/// consume byte budget; `acquire` sleeps until enough budget has refilled.
#[derive(Debug)]
pub struct Throttler {
    /// Refill rate in bytes per second.
    rate: u64,
    /// Maximum number of tokens (bytes) the bucket can hold.
    bucket_size: u64,
    /// Mutable bucket state, locked by each `acquire` call.
    state: Mutex<ThrottlerState>,
}

/// Mutable part of the token bucket.
#[derive(Debug)]
struct ThrottlerState {
    /// Current token balance, in bytes.
    tokens: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: Instant,
}
impl Throttler {
    /// Creates a throttler limited to `rate_kb_s` KiB (1024 bytes) per second.
    ///
    /// Returns `None` when `rate_kb_s` is zero (no throttler is created).
    /// The bucket holds `BURST_CHUNKS * CHUNK_SIZE` bytes and starts full.
    pub fn new(rate_kb_s: u64) -> Option<Self> {
        if rate_kb_s == 0 {
            return None;
        }
        // Saturate instead of overflowing: a wrapped product could become 0,
        // which would later turn into a division by zero in `acquire`.
        let rate = rate_kb_s.saturating_mul(1024);
        let bucket_size = (BURST_CHUNKS * CHUNK_SIZE) as u64;
        debug!(
            rate_kb_s,
            rate_bytes_s = rate,
            bucket_size,
            "Throttler created"
        );
        Some(Self {
            rate,
            bucket_size,
            state: Mutex::new(ThrottlerState {
                tokens: bucket_size as f64,
                last_refill: Instant::now(),
            }),
        })
    }

    /// Waits until `bytes` worth of budget is available, then consumes it.
    ///
    /// Tokens refill continuously at `rate` bytes per second, capped at the
    /// bucket size. A request larger than the bucket is admitted once the
    /// bucket holds a full bucket's worth of tokens; the excess leaves the
    /// balance negative, and subsequent calls wait until that debt is repaid.
    /// This keeps the long-term average at the configured rate.
    pub async fn acquire(&self, bytes: usize) {
        let requested = bytes as f64;
        let capacity = self.bucket_size as f64;
        // Never wait for more tokens than the bucket can hold: tokens are
        // clamped to `capacity` on refill, so a larger threshold would never be
        // reached and the loop below would spin forever.
        let threshold = requested.min(capacity);
        loop {
            let wait_duration = {
                let mut state = self.state.lock().await;
                let now = Instant::now();
                let elapsed = now
                    .saturating_duration_since(state.last_refill)
                    .as_secs_f64();
                state.tokens = (state.tokens + elapsed * self.rate as f64).min(capacity);
                state.last_refill = now;
                if state.tokens >= threshold {
                    // May go negative for oversized requests (see doc comment).
                    state.tokens -= requested;
                    trace!(tokens_remaining = state.tokens, "Tokens consumed");
                    return;
                }
                // `rate` is always > 0 here: `new` rejects 0 and saturates on overflow.
                let deficit = threshold - state.tokens;
                Duration::from_secs_f64(deficit / self.rate as f64)
            };
            // The lock is released before sleeping so other callers can refill/check.
            trace!(
                wait_ms = wait_duration.as_secs_f64() * 1000.0,
                "Throttle: sleeping"
            );
            tokio::time::sleep(wait_duration).await;
        }
    }

    /// Returns the configured rate in KiB per second.
    pub fn rate_kb_s(&self) -> u64 {
        self.rate / 1024
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_returns_none_when_rate_is_zero() {
        assert!(Throttler::new(0).is_none());
    }

    #[test]
    fn test_new_returns_some_when_rate_is_nonzero() {
        let t = Throttler::new(100).unwrap();
        assert_eq!(t.rate_kb_s(), 100);
    }

    #[tokio::test]
    async fn test_acquire_immediate_when_tokens_available() {
        let t = Throttler::new(1).unwrap();
        let start = Instant::now();
        t.acquire(1).await;
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    #[tokio::test]
    async fn test_acquire_sleeps_when_exhausted() {
        // 10 KiB/s: after draining the bucket, 2048 bytes need ~200ms of refill.
        let t = Throttler::new(10).unwrap();
        t.acquire(BURST_CHUNKS * CHUNK_SIZE).await;
        let start = Instant::now();
        t.acquire(2048).await;
        assert!(
            start.elapsed() >= Duration::from_millis(150),
            "Expected >= 150ms sleep, got {:?}",
            start.elapsed()
        );
    }

    #[tokio::test]
    async fn test_acquire_partial_refill() {
        let t = Throttler::new(1024).unwrap();
        t.acquire(BURST_CHUNKS * CHUNK_SIZE).await;
        tokio::time::sleep(Duration::from_millis(100)).await;
        let start = Instant::now();
        t.acquire(50_000).await;
        assert!(
            start.elapsed() < Duration::from_millis(50),
            "Expected immediate acquire after refill, got {:?}",
            start.elapsed()
        );
    }

    // Slow (~5s) real-time test; run explicitly with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore]
    async fn test_rate_limiting_sustained() {
        let t = Throttler::new(10).unwrap();
        t.acquire(BURST_CHUNKS * CHUNK_SIZE).await;
        let start = Instant::now();
        for _ in 0..5 {
            t.acquire(10 * 1024).await;
        }
        let elapsed = start.elapsed();
        assert!(
            elapsed >= Duration::from_secs(4),
            "Expected >= 4s for 50KB at 10KB/s, got {:?}",
            elapsed
        );
    }
}