use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
/// Token-bucket limiter for byte throughput that can be shared between tasks.
///
/// Cloning is cheap: every clone points at the same bucket through an
/// [`Arc`], so all clones draw from one common budget.
#[derive(Clone)]
pub struct RateLimiter {
    /// Sustained rate in bytes per second; `0` means "no limit".
    ///
    /// The rate never changes after construction, so it is mirrored here,
    /// outside the mutex, to let [`RateLimiter::is_limited`] and the unlimited
    /// fast path of [`RateLimiter::consume`] answer without locking.
    rate_bytes_per_sec: u64,
    /// Mutable bucket state shared by all clones.
    inner: Arc<Mutex<RateLimiterInner>>,
}

/// Bucket state guarded by the mutex inside [`RateLimiter`].
struct RateLimiterInner {
    /// Refill rate in bytes per second (`0` for an unlimited limiter).
    rate_bytes_per_sec: u64,
    /// Maximum number of tokens the bucket may hold.
    capacity: u64,
    /// Tokens (bytes) currently available.
    tokens: u64,
    /// Instant at which tokens were last credited.
    last_refill: Instant,
}

impl RateLimiterInner {
    /// Credits the tokens earned between `last_refill` and `now`, capped at
    /// `capacity`.
    ///
    /// `last_refill` only advances when at least one whole token was earned,
    /// so sub-token intervals keep accumulating instead of being discarded.
    fn refill(&mut self, now: Instant) {
        let elapsed = now.duration_since(self.last_refill);
        // Float-to-int `as` saturates (and maps NaN to 0), so `earned` may be
        // as large as `u64::MAX`; hence the saturating add below.
        let earned = (elapsed.as_secs_f64() * self.rate_bytes_per_sec as f64) as u64;
        if earned > 0 {
            self.tokens = self.tokens.saturating_add(earned).min(self.capacity);
            self.last_refill = now;
        }
    }
}

impl RateLimiter {
    /// Creates a limiter that admits `rate_bytes_per_sec` bytes per second.
    ///
    /// The bucket holds one second's worth of traffic (at least one token)
    /// and starts full, so an initial burst of that size passes undelayed.
    /// A rate of `0` produces an unlimited limiter.
    pub fn new(rate_bytes_per_sec: u64) -> Self {
        let capacity = rate_bytes_per_sec.max(1);
        Self {
            rate_bytes_per_sec,
            inner: Arc::new(Mutex::new(RateLimiterInner {
                rate_bytes_per_sec,
                capacity,
                tokens: capacity,
                last_refill: Instant::now(),
            })),
        }
    }

    /// Creates a limiter that never delays callers (equivalent to `new(0)`).
    pub fn unlimited() -> Self {
        Self::new(0)
    }

    /// Returns `true` when this limiter enforces a rate, `false` when it is
    /// unlimited (constructed with a rate of `0`).
    pub fn is_limited(&self) -> bool {
        self.rate_bytes_per_sec != 0
    }

    /// Waits until `bytes` tokens have been taken from the bucket.
    ///
    /// Returns immediately for an unlimited limiter. Otherwise the request is
    /// served incrementally: each pass takes whatever tokens exist, then
    /// sleeps until the next portion (never more than one full bucket) can
    /// have been refilled. Tokens are only ever deducted when they are really
    /// present, so the configured rate also holds when several tasks share
    /// the limiter, and requests larger than the bucket still complete.
    ///
    /// The mutex is released while sleeping. If the returned future is dropped
    /// part-way through, tokens already taken are not handed back.
    pub async fn consume(&self, bytes: u64) {
        let rate = self.rate_bytes_per_sec;
        if rate == 0 {
            return;
        }

        let mut remaining = bytes;
        loop {
            let wait = {
                let mut inner = self.inner.lock().await;
                inner.refill(Instant::now());

                let granted = remaining.min(inner.tokens);
                inner.tokens -= granted;
                remaining -= granted;
                if remaining == 0 {
                    return;
                }

                // The bucket is empty now. Wait for the next portion, bounded
                // by the capacity so the sleep is at most one second and the
                // `Duration` conversion cannot overflow.
                let next_portion = remaining.min(inner.capacity);
                Duration::from_secs_f64(next_portion as f64 / rate as f64)
            };
            tokio::time::sleep(wait).await;
        }
    }
}
/// Parses a human-readable bandwidth limit into bytes per second.
///
/// Matching is case-insensitive and surrounding whitespace is ignored.
/// Accepted forms:
///
/// * `<number>GB`, `<number>MB`, `<number>KB` — binary multiples (1 KB = 1024
///   bytes); the number may be fractional, e.g. `1.5MB`.
/// * `<integer>B` — an exact byte count (fractions are rejected).
/// * `<number>` with no unit — interpreted as kilobytes.
///
/// A trailing `/s` or `/sec` is allowed on any form.
///
/// Returns `None` when the text is not a valid number, or when the number is
/// negative, NaN or infinite. Finite values too large for `u64` saturate to
/// `u64::MAX`.
pub fn parse_limit(s: &str) -> Option<u64> {
    const KIB: f64 = 1024.0;
    // Two-letter units are tried before the bare `B` suffix, which would
    // otherwise match the tail of all of them.
    const SCALED_UNITS: [(&str, f64); 3] = [
        ("GB", KIB * KIB * KIB),
        ("MB", KIB * KIB),
        ("KB", KIB),
    ];

    /// Parses `number` as a non-negative finite float and scales it.
    fn scaled(number: &str, multiplier: f64) -> Option<u64> {
        let value = number.trim().parse::<f64>().ok()?;
        // `as u64` would silently turn negatives and NaN into 0 (which callers
        // read as "unlimited") and infinity into `u64::MAX`; reject them.
        if !value.is_finite() || value < 0.0 {
            return None;
        }
        // Float-to-int `as` saturates, so oversized values clamp to `u64::MAX`.
        Some((value * multiplier) as u64)
    }

    let upper = s.trim().to_uppercase();
    // Drop the optional per-second marker, then any whitespace that sat
    // between the unit and the marker (e.g. "10 MB /s").
    let text = upper
        .trim_end_matches("/S")
        .trim_end_matches("/SEC")
        .trim_end();

    for &(suffix, multiplier) in SCALED_UNITS.iter() {
        if let Some(number) = text.strip_suffix(suffix) {
            return scaled(number, multiplier);
        }
    }
    if let Some(number) = text.strip_suffix('B') {
        return number.trim().parse::<u64>().ok();
    }
    // No unit at all: the value is taken to be in kilobytes.
    scaled(text, KIB)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    /// `unlimited()` must let even a large request through without delay.
    #[tokio::test]
    async fn test_unlimited_returns_immediately() {
        let limiter = RateLimiter::unlimited();
        let start = Instant::now();
        limiter.consume(1024 * 1024).await;
        assert!(start.elapsed() < Duration::from_millis(10));
    }

    /// Requesting twice the bucket at 10 KiB/s has to wait roughly a second
    /// for the second half (900 ms leaves slack for timer granularity).
    #[tokio::test]
    async fn test_rate_limiter_throttles() {
        let limiter = RateLimiter::new(10 * 1024);
        let start = Instant::now();
        limiter.consume(20 * 1024).await;
        let elapsed = start.elapsed();
        assert!(
            elapsed >= Duration::from_millis(900),
            "elapsed={:?}",
            elapsed
        );
    }

    /// A request that fits in the initially full bucket is served at once.
    #[tokio::test]
    async fn test_small_consume_no_wait() {
        let limiter = RateLimiter::new(1024 * 1024);
        let start = Instant::now();
        limiter.consume(1).await;
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    #[test]
    fn test_parse_limit_kb() {
        assert_eq!(parse_limit("512KB"), Some(512 * 1024));
        assert_eq!(parse_limit("512kb"), Some(512 * 1024));
    }

    #[test]
    fn test_parse_limit_mb() {
        assert_eq!(parse_limit("10MB"), Some(10 * 1024 * 1024));
        assert_eq!(parse_limit("1MB/s"), Some(1024 * 1024));
    }

    #[test]
    fn test_parse_limit_gb() {
        assert_eq!(parse_limit("1GB"), Some(1024 * 1024 * 1024));
    }

    /// The plain `B` suffix takes whole byte counts only.
    #[test]
    fn test_parse_limit_bytes() {
        assert_eq!(parse_limit("2048B"), Some(2048));
        assert_eq!(parse_limit("1.5B"), None);
    }

    /// Both per-second spellings are stripped, as is surrounding whitespace.
    #[test]
    fn test_parse_limit_per_second_suffix() {
        assert_eq!(parse_limit("1MB/sec"), Some(1024 * 1024));
        assert_eq!(parse_limit("  64kb/s  "), Some(64 * 1024));
    }

    #[test]
    fn test_parse_limit_fractional() {
        assert_eq!(parse_limit("1.5KB"), Some(1536));
    }

    /// A number without a unit is read as kilobytes.
    #[test]
    fn test_parse_limit_bare_number() {
        assert_eq!(parse_limit("100"), Some(100 * 1024));
    }

    #[test]
    fn test_parse_limit_invalid() {
        assert_eq!(parse_limit("abc"), None);
    }

    /// `new(0)` behaves exactly like `unlimited()`.
    #[tokio::test]
    async fn test_zero_rate_is_unlimited() {
        let limiter = RateLimiter::new(0);
        let start = Instant::now();
        limiter.consume(100 * 1024 * 1024).await;
        assert!(start.elapsed() < Duration::from_millis(10));
    }
}