use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tokio::time::sleep;
/// Bandwidth-limit settings for a `BandwidthLimiter`.
///
/// Rates are in bytes per second; a rate of `0` means that direction gets no
/// token bucket and is therefore not throttled.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Upload limit in bytes per second (`0` = no upload limit).
    pub upload_rate: u64,
    /// Download limit in bytes per second (`0` = no download limit).
    pub download_rate: u64,
    /// Bucket capacity as a multiple of the rate, i.e. how many seconds' worth
    /// of traffic may be accumulated for a burst.
    pub burst_multiplier: f64,
    /// Transfers smaller than this many bytes bypass the limiter.
    pub min_transfer_size: u64,
    /// Master switch; when `false` no throttling is applied at all.
    pub enabled: bool,
}

impl Default for RateLimitConfig {
    /// Enabled, but with both rates at `0` (so nothing is actually throttled),
    /// a 1.5x burst allowance and a 1024-byte minimum transfer size.
    fn default() -> Self {
        Self {
            upload_rate: 0,
            download_rate: 0,
            burst_multiplier: 1.5,
            min_transfer_size: 1024,
            enabled: true,
        }
    }
}

impl RateLimitConfig {
    /// Converts megabits per second into whole bytes per second.
    ///
    /// The `as` cast saturates, so negative or NaN inputs become `0`
    /// ("no limit" for that direction).
    fn mbps_to_bytes_per_sec(mbps: f64) -> u64 {
        (mbps * 1_000_000.0 / 8.0) as u64
    }

    /// Config with the given upload/download limits in Mbit/s; every other
    /// field takes its `Default` value.
    #[must_use]
    #[inline]
    pub fn with_rates(upload_mbps: f64, download_mbps: f64) -> Self {
        let upload_rate = Self::mbps_to_bytes_per_sec(upload_mbps);
        let download_rate = Self::mbps_to_bytes_per_sec(download_mbps);
        Self {
            upload_rate,
            download_rate,
            ..Self::default()
        }
    }

    /// Config with the same limit (in Mbit/s) in both directions.
    #[must_use]
    #[inline]
    pub fn symmetric(rate_mbps: f64) -> Self {
        Self::with_rates(rate_mbps, rate_mbps)
    }

    /// Default config with limiting switched off (`enabled == false`).
    #[must_use]
    #[inline]
    pub fn unlimited() -> Self {
        Self {
            enabled: false,
            ..Self::default()
        }
    }
}
/// Mutable part of a [`TokenBucket`]. It is kept behind a single mutex so that
/// refilling and charging a transfer happen as one atomic step; two concurrent
/// callers can never both spend the same credit.
struct BucketState {
    /// Byte credit available right now. A negative value is debt that has
    /// already been handed out to callers who are sleeping it off.
    tokens: f64,
    /// Instant up to which `tokens` has been refilled.
    last_refill: Instant,
}

/// Token bucket metering bytes at `rate` bytes/second, holding at most
/// `max_tokens` bytes of accumulated credit.
///
/// A transfer that exceeds the current credit is still charged in full (the
/// balance goes negative), and the caller is told how long to wait for that
/// debt to be repaid. Later callers therefore queue behind earlier ones instead
/// of re-using credit that was already spent.
struct TokenBucket {
    state: std::sync::Mutex<BucketState>,
    max_tokens: u64,
    rate: u64,
}

impl TokenBucket {
    /// Creates a full bucket whose capacity is `rate * burst_multiplier`.
    /// A NaN or negative product saturates to `0` through the `as` cast.
    fn new(rate: u64, burst_multiplier: f64) -> Self {
        let max_tokens = (rate as f64 * burst_multiplier) as u64;
        Self {
            state: std::sync::Mutex::new(BucketState {
                tokens: max_tokens as f64,
                last_refill: Instant::now(),
            }),
            max_tokens,
            rate,
        }
    }

    /// Locks the bucket state and recovers from poisoning. The critical
    /// sections below only do clock reads and float arithmetic, so a poisoned
    /// guard still holds consistent data.
    fn lock_state(&self) -> std::sync::MutexGuard<'_, BucketState> {
        self.state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Credits `state` for the time elapsed since its last refill, capped at
    /// `max_tokens`. Fractional credit is kept, so frequent calls at low rates
    /// still accumulate tokens instead of truncating each increment to zero.
    fn refill(&self, state: &mut BucketState) {
        let now = Instant::now();
        let elapsed = now.saturating_duration_since(state.last_refill);
        let earned = elapsed.as_secs_f64() * self.rate as f64;
        state.tokens = (state.tokens + earned).min(self.max_tokens as f64);
        state.last_refill = now;
    }

    /// Charges `bytes` against the bucket. Returns how long the caller must
    /// wait before sending them, or `Duration::ZERO` if the credit covered the
    /// transfer.
    ///
    /// The charge is always applied, even when a wait is returned. That keeps
    /// the next caller from spending the credit that repays this transfer.
    /// Waits too large for a `Duration` (or infinite, for a zero rate)
    /// saturate to `Duration::MAX` instead of panicking.
    ///
    /// This is `async` only so call sites can `.await` it. It never suspends,
    /// so the std mutex guard is never held across an await point.
    async fn consume(&self, bytes: u64) -> Duration {
        let mut state = self.lock_state();
        self.refill(&mut state);
        state.tokens -= bytes as f64;
        if state.tokens >= 0.0 {
            Duration::ZERO
        } else {
            let wait_secs = -state.tokens / self.rate as f64;
            Duration::try_from_secs_f64(wait_secs).unwrap_or(Duration::MAX)
        }
    }

    /// Credit currently available in bytes, refreshed to now. Outstanding
    /// debt is reported as `0`.
    fn available(&self) -> u64 {
        let mut state = self.lock_state();
        self.refill(&mut state);
        state.tokens.max(0.0) as u64
    }
}
/// Throttles upload and download traffic with (at most) one token bucket per
/// direction and keeps running transfer statistics.
pub struct BandwidthLimiter {
    /// Settings this limiter was built from.
    config: RateLimitConfig,
    /// Upload bucket; `None` means uploads are never throttled.
    upload_bucket: Option<TokenBucket>,
    /// Download bucket; `None` means downloads are never throttled.
    download_bucket: Option<TokenBucket>,
    /// Lock-protected transfer counters.
    stats: Arc<RwLock<BandwidthStats>>,
}
/// Snapshot of the transfer counters kept by a `BandwidthLimiter`.
#[derive(Debug, Clone, Default)]
pub struct BandwidthStats {
    /// Total bytes recorded as uploaded.
    pub bytes_uploaded: u64,
    /// Total bytes recorded as downloaded.
    pub bytes_downloaded: u64,
    /// Average upload throughput in bytes/s since `started_at`.
    pub upload_rate: f64,
    /// Average download throughput in bytes/s since `started_at`.
    pub download_rate: f64,
    /// Sum of all delays the limiter asked callers to wait.
    pub total_wait_time: Duration,
    /// Number of transfers that had to wait.
    pub limited_transfers: u64,
    /// When collection began; `None` for a `Default`-constructed value.
    pub started_at: Option<Instant>,
}

impl BandwidthStats {
    /// Zeroed statistics with the clock started now.
    fn new() -> Self {
        let started_at = Some(Instant::now());
        Self {
            started_at,
            ..Self::default()
        }
    }

    /// Recomputes both rates as lifetime averages (bytes / elapsed seconds).
    /// Leaves them untouched when there is no start time or no time has passed.
    fn update_rates(&mut self) {
        let Some(start) = self.started_at else {
            return;
        };
        let secs = start.elapsed().as_secs_f64();
        if secs <= 0.0 {
            return;
        }
        self.upload_rate = self.bytes_uploaded as f64 / secs;
        self.download_rate = self.bytes_downloaded as f64 / secs;
    }
}
impl BandwidthLimiter {
#[must_use]
#[inline]
pub fn new(config: RateLimitConfig) -> Self {
let upload_bucket = if config.enabled && config.upload_rate > 0 {
Some(TokenBucket::new(
config.upload_rate,
config.burst_multiplier,
))
} else {
None
};
let download_bucket = if config.enabled && config.download_rate > 0 {
Some(TokenBucket::new(
config.download_rate,
config.burst_multiplier,
))
} else {
None
};
Self {
config,
upload_bucket,
download_bucket,
stats: Arc::new(RwLock::new(BandwidthStats::new())),
}
}
pub async fn limit_upload(&self, bytes: u64) {
if !self.config.enabled || bytes < self.config.min_transfer_size {
return;
}
if let Some(ref bucket) = self.upload_bucket {
let wait = bucket.consume(bytes).await;
if !wait.is_zero() {
let mut stats = self.stats.write().await;
stats.total_wait_time += wait;
stats.limited_transfers += 1;
drop(stats);
sleep(wait).await;
}
let mut stats = self.stats.write().await;
stats.bytes_uploaded += bytes;
stats.update_rates();
}
}
pub async fn limit_download(&self, bytes: u64) {
if !self.config.enabled || bytes < self.config.min_transfer_size {
return;
}
if let Some(ref bucket) = self.download_bucket {
let wait = bucket.consume(bytes).await;
if !wait.is_zero() {
let mut stats = self.stats.write().await;
stats.total_wait_time += wait;
stats.limited_transfers += 1;
drop(stats);
sleep(wait).await;
}
let mut stats = self.stats.write().await;
stats.bytes_downloaded += bytes;
stats.update_rates();
}
}
pub async fn record_upload(&self, bytes: u64) {
let mut stats = self.stats.write().await;
stats.bytes_uploaded += bytes;
stats.update_rates();
}
pub async fn record_download(&self, bytes: u64) {
let mut stats = self.stats.write().await;
stats.bytes_downloaded += bytes;
stats.update_rates();
}
#[must_use]
pub async fn stats(&self) -> BandwidthStats {
self.stats.read().await.clone()
}
#[must_use]
#[inline]
pub fn available_upload(&self) -> Option<u64> {
self.upload_bucket.as_ref().map(|b| b.available())
}
#[must_use]
#[inline]
pub fn available_download(&self) -> Option<u64> {
self.download_bucket.as_ref().map(|b| b.available())
}
#[must_use]
#[inline]
pub fn is_enabled(&self) -> bool {
self.config.enabled
}
#[must_use]
#[inline]
pub fn upload_rate(&self) -> u64 {
self.config.upload_rate
}
#[must_use]
#[inline]
pub fn download_rate(&self) -> u64 {
self.config.download_rate
}
}
/// Applies a global bandwidth limit plus an independent per-peer limit whose
/// rates are a fraction of the global ones.
pub struct PeerRateLimiter {
    /// Limiter shared by all peers.
    global: Arc<BandwidthLimiter>,
    /// Lazily created per-peer limiters, keyed by peer id.
    peer_limiters: RwLock<std::collections::HashMap<String, Arc<BandwidthLimiter>>>,
    /// Multiplier applied to the global rates to derive each peer's rates.
    peer_rate_fraction: f64,
}

impl PeerRateLimiter {
    /// Creates a limiter from the global config.
    ///
    /// Each peer's rates are `global rate * peer_rate_fraction`, truncated to
    /// whole bytes/s. A product that truncates to `0` leaves that direction
    /// unthrottled for the peer; the global limit still applies.
    #[must_use]
    #[inline]
    pub fn new(global_config: RateLimitConfig, peer_rate_fraction: f64) -> Self {
        Self {
            global: Arc::new(BandwidthLimiter::new(global_config)),
            peer_limiters: RwLock::new(std::collections::HashMap::new()),
            peer_rate_fraction,
        }
    }

    /// Config used for a newly seen peer: scaled global rates, a 2x burst,
    /// a 512-byte minimum transfer size, and the global enabled flag.
    fn peer_config(&self) -> RateLimitConfig {
        RateLimitConfig {
            upload_rate: (self.global.upload_rate() as f64 * self.peer_rate_fraction) as u64,
            download_rate: (self.global.download_rate() as f64 * self.peer_rate_fraction)
                as u64,
            burst_multiplier: 2.0,
            min_transfer_size: 512,
            enabled: self.global.is_enabled(),
        }
    }

    /// Returns the limiter for `peer_id`, creating it on first use.
    ///
    /// Concurrent first calls for the same peer all receive the same limiter.
    #[must_use]
    pub async fn get_peer_limiter(&self, peer_id: &str) -> Arc<BandwidthLimiter> {
        // Fast path: most lookups hit an existing peer under the shared lock.
        {
            let limiters = self.peer_limiters.read().await;
            if let Some(limiter) = limiters.get(peer_id) {
                return Arc::clone(limiter);
            }
        }
        // Slow path: another task may have inserted this peer between dropping
        // the read lock and taking the write lock. The entry API keeps the
        // existing limiter instead of replacing it, which a blind `insert`
        // would do (splitting the peer's budget and stats across two limiters).
        let mut limiters = self.peer_limiters.write().await;
        let limiter = limiters
            .entry(peer_id.to_string())
            .or_insert_with(|| Arc::new(BandwidthLimiter::new(self.peer_config())));
        Arc::clone(limiter)
    }

    /// Throttles an upload through the global limiter, then the peer's limiter.
    // NOTE(review): the two waits are taken one after another, so a transfer
    // limited by both may be delayed by up to the sum of the two waits.
    pub async fn limit_upload(&self, peer_id: &str, bytes: u64) {
        self.global.limit_upload(bytes).await;
        let peer_limiter = self.get_peer_limiter(peer_id).await;
        peer_limiter.limit_upload(bytes).await;
    }

    /// Throttles a download through the global limiter, then the peer's limiter.
    pub async fn limit_download(&self, peer_id: &str, bytes: u64) {
        self.global.limit_download(bytes).await;
        let peer_limiter = self.get_peer_limiter(peer_id).await;
        peer_limiter.limit_download(bytes).await;
    }

    /// Statistics of the global limiter.
    #[must_use]
    pub async fn global_stats(&self) -> BandwidthStats {
        self.global.stats().await
    }

    /// Statistics for `peer_id`, or `None` if the peer has no limiter.
    #[must_use]
    pub async fn peer_stats(&self, peer_id: &str) -> Option<BandwidthStats> {
        // Clone the Arc so the map lock is released before awaiting the
        // peer's own stats lock; writers to the map are not blocked meanwhile.
        let limiter = self.peer_limiters.read().await.get(peer_id).cloned()?;
        Some(limiter.stats().await)
    }

    /// Forgets the limiter for `peer_id`. Existing `Arc` handles keep working.
    pub async fn remove_peer(&self, peer_id: &str) {
        let mut limiters = self.peer_limiters.write().await;
        limiters.remove(peer_id);
    }

    /// Number of peers that currently have a limiter.
    #[must_use]
    pub async fn peer_count(&self) -> usize {
        self.peer_limiters.read().await.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let RateLimitConfig {
            upload_rate,
            download_rate,
            enabled,
            ..
        } = RateLimitConfig::default();
        assert!(enabled);
        assert_eq!(upload_rate, 0);
        assert_eq!(download_rate, 0);
    }

    #[test]
    fn test_config_with_rates() {
        // 100 Mbit/s and 50 Mbit/s expressed in bytes per second.
        let config = RateLimitConfig::with_rates(100.0, 50.0);
        assert_eq!(config.upload_rate, 12_500_000);
        assert_eq!(config.download_rate, 6_250_000);
    }

    #[tokio::test]
    async fn test_unlimited_limiter() {
        let limiter = BandwidthLimiter::new(RateLimitConfig::unlimited());
        let started = Instant::now();
        limiter.limit_upload(10_000_000).await;
        limiter.limit_download(10_000_000).await;
        // A disabled limiter must return essentially immediately.
        assert!(started.elapsed() < Duration::from_millis(10));
    }

    #[tokio::test]
    async fn test_stats_recording() {
        let limiter = BandwidthLimiter::new(RateLimitConfig::unlimited());
        limiter.record_upload(1000).await;
        limiter.record_download(2000).await;
        let snapshot = limiter.stats().await;
        assert_eq!(snapshot.bytes_uploaded, 1000);
        assert_eq!(snapshot.bytes_downloaded, 2000);
    }

    #[tokio::test]
    async fn test_peer_rate_limiter() {
        let peers = PeerRateLimiter::new(RateLimitConfig::unlimited(), 0.25);
        let _limiter = peers.get_peer_limiter("peer1").await;
        assert_eq!(peers.peer_count().await, 1);
        peers.remove_peer("peer1").await;
        assert_eq!(peers.peer_count().await, 0);
    }
}