1use parking_lot::RwLock;
36use std::collections::HashMap;
37use std::net::SocketAddr;
38use std::sync::Arc;
39use std::time::{Duration, Instant};
40use thiserror::Error;
41use tokio::time::sleep;
42use tracing::debug;
43
44#[derive(Error, Debug)]
46pub enum ThrottleError {
47 #[error("Rate limit exceeded")]
48 RateLimitExceeded,
49
50 #[error("Quota exhausted")]
51 QuotaExhausted,
52
53 #[error("Peer not found: {0}")]
54 PeerNotFound(String),
55}
56
57#[derive(Debug, Clone)]
59pub struct TokenBucket {
60 capacity: u64,
62 tokens: Arc<RwLock<f64>>,
64 refill_rate: f64,
66 last_refill: Arc<RwLock<Instant>>,
68}
69
70impl TokenBucket {
71 pub fn new(capacity: u64, refill_rate: f64) -> Self {
73 Self {
74 capacity,
75 tokens: Arc::new(RwLock::new(capacity as f64)),
76 refill_rate,
77 last_refill: Arc::new(RwLock::new(Instant::now())),
78 }
79 }
80
81 pub fn try_consume(&self, amount: u64) -> bool {
83 self.refill();
84
85 let mut tokens = self.tokens.write();
86 if *tokens >= amount as f64 {
87 *tokens -= amount as f64;
88 true
89 } else {
90 false
91 }
92 }
93
94 pub async fn consume(&self, amount: u64) {
96 loop {
97 if self.try_consume(amount) {
98 return;
99 }
100
101 let tokens = *self.tokens.read();
103 let needed = amount as f64 - tokens;
104 let wait_time = Duration::from_secs_f64(needed / self.refill_rate);
105
106 sleep(wait_time.min(Duration::from_millis(100))).await;
107 }
108 }
109
110 fn refill(&self) {
112 let now = Instant::now();
113 let mut last_refill = self.last_refill.write();
114 let elapsed = now.duration_since(*last_refill).as_secs_f64();
115
116 if elapsed > 0.0 {
117 let new_tokens = elapsed * self.refill_rate;
118 let mut tokens = self.tokens.write();
119 *tokens = (*tokens + new_tokens).min(self.capacity as f64);
120 *last_refill = now;
121 }
122 }
123
124 pub fn available_tokens(&self) -> u64 {
126 self.refill();
127 *self.tokens.read() as u64
128 }
129
130 pub fn capacity(&self) -> u64 {
132 self.capacity
133 }
134
135 pub fn refill_rate(&self) -> f64 {
137 self.refill_rate
138 }
139}
140
/// Quality-of-service class assigned to a peer.
///
/// Variants are declared from lowest to highest priority with discriminants `0..=3`,
/// so the derived ordering satisfies `BestEffort < Normal < High < Critical`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum QosPriority {
    /// Lowest class.
    BestEffort = 0,
    /// Default class.
    Normal = 1,
    /// Elevated class.
    High = 2,
    /// Highest class.
    Critical = 3,
}

impl QosPriority {
    /// Factor by which a peer's base rate limit is scaled for this priority.
    ///
    /// Each class doubles the previous one, from `0.5` for `BestEffort` up to `4.0`
    /// for `Critical`.
    pub fn multiplier(&self) -> f64 {
        match *self {
            Self::BestEffort => 0.5,
            Self::Normal => 1.0,
            Self::High => 2.0,
            Self::Critical => 4.0,
        }
    }
}
165
/// Settings for a `BandwidthThrottle`.
///
/// Limits feed token buckets whose tokens are bytes and whose refill rate is expressed
/// per second. A limit of `0` means no bucket is created for that direction, i.e. the
/// direction is not limited.
#[derive(Clone, Debug)]
pub struct BandwidthConfig {
    /// Limit shared by all uploads; `0` disables it.
    pub global_upload_limit: u64,
    /// Limit shared by all downloads; `0` disables it.
    pub global_download_limit: u64,
    /// Base upload limit per peer, scaled by the peer's priority; `0` disables it.
    pub peer_upload_limit: u64,
    /// Base download limit per peer, scaled by the peer's priority; `0` disables it.
    pub peer_download_limit: u64,
    /// When `true`, bucket capacity is the rate times `burst_multiplier`; when `false`
    /// the capacity equals the rate.
    pub allow_burst: bool,
    /// Capacity-to-rate ratio applied when `allow_burst` is set.
    pub burst_multiplier: f64,
}

impl Default for BandwidthConfig {
    /// No limits in either direction, bursting enabled with a factor of `2.0`.
    fn default() -> Self {
        // A limit of zero switches the corresponding bucket off.
        const UNLIMITED: u64 = 0;

        BandwidthConfig {
            global_upload_limit: UNLIMITED,
            global_download_limit: UNLIMITED,
            peer_upload_limit: UNLIMITED,
            peer_download_limit: UNLIMITED,
            allow_burst: true,
            burst_multiplier: 2.0,
        }
    }
}
195
/// Counters describing what a `BandwidthThrottle` has let through and how long callers
/// waited for it. All fields start at zero.
#[derive(Clone, Debug, Default)]
pub struct ThrottleStats {
    /// Total bytes admitted for upload.
    pub bytes_uploaded: u64,
    /// Total bytes admitted for download.
    pub bytes_downloaded: u64,
    /// Number of waiting throttle calls that took longer than one millisecond.
    pub throttle_count: u64,
    /// Sum of the durations of the calls counted in `throttle_count`.
    pub total_wait_time: Duration,
    /// Current upload rate.
    /// NOTE(review): never written by the code in this section of the file — confirm
    /// whether it is maintained elsewhere.
    pub current_upload_rate: f64,
    /// Current download rate.
    /// NOTE(review): never written by the code in this section of the file — confirm
    /// whether it is maintained elsewhere.
    pub current_download_rate: f64,
}
212
213struct PeerThrottle {
215 upload: Option<TokenBucket>,
216 download: Option<TokenBucket>,
217 priority: QosPriority,
218}
219
220pub struct BandwidthThrottle {
222 config: BandwidthConfig,
223 global_upload: Option<Arc<TokenBucket>>,
224 global_download: Option<Arc<TokenBucket>>,
225 peer_throttles: Arc<RwLock<HashMap<SocketAddr, PeerThrottle>>>,
226 stats: Arc<RwLock<ThrottleStats>>,
227}
228
229impl BandwidthThrottle {
230 pub fn new(config: BandwidthConfig) -> Self {
232 let burst_capacity_multiplier = if config.allow_burst {
233 config.burst_multiplier
234 } else {
235 1.0
236 };
237
238 let global_upload = if config.global_upload_limit > 0 {
239 Some(Arc::new(TokenBucket::new(
240 (config.global_upload_limit as f64 * burst_capacity_multiplier) as u64,
241 config.global_upload_limit as f64,
242 )))
243 } else {
244 None
245 };
246
247 let global_download = if config.global_download_limit > 0 {
248 Some(Arc::new(TokenBucket::new(
249 (config.global_download_limit as f64 * burst_capacity_multiplier) as u64,
250 config.global_download_limit as f64,
251 )))
252 } else {
253 None
254 };
255
256 Self {
257 config,
258 global_upload,
259 global_download,
260 peer_throttles: Arc::new(RwLock::new(HashMap::new())),
261 stats: Arc::new(RwLock::new(ThrottleStats::default())),
262 }
263 }
264
265 pub fn register_peer(&self, addr: SocketAddr, priority: QosPriority) {
267 let burst_multiplier = if self.config.allow_burst {
268 self.config.burst_multiplier
269 } else {
270 1.0
271 };
272
273 let upload = if self.config.peer_upload_limit > 0 {
274 let rate = self.config.peer_upload_limit as f64 * priority.multiplier();
275 Some(TokenBucket::new((rate * burst_multiplier) as u64, rate))
276 } else {
277 None
278 };
279
280 let download = if self.config.peer_download_limit > 0 {
281 let rate = self.config.peer_download_limit as f64 * priority.multiplier();
282 Some(TokenBucket::new((rate * burst_multiplier) as u64, rate))
283 } else {
284 None
285 };
286
287 let throttle = PeerThrottle {
288 upload,
289 download,
290 priority,
291 };
292
293 self.peer_throttles.write().insert(addr, throttle);
294 debug!("Registered peer {} with priority {:?}", addr, priority);
295 }
296
297 pub fn unregister_peer(&self, addr: &SocketAddr) {
299 self.peer_throttles.write().remove(addr);
300 debug!("Unregistered peer {}", addr);
301 }
302
303 pub async fn throttle_upload(&self, addr: &SocketAddr, bytes: u64) {
305 let start = Instant::now();
306
307 if let Some(global) = &self.global_upload {
309 global.consume(bytes).await;
310 }
311
312 let upload_bucket = {
314 let peer_throttles = self.peer_throttles.read();
315 peer_throttles
316 .get(addr)
317 .and_then(|peer| peer.upload.clone())
318 };
319
320 if let Some(upload) = upload_bucket {
321 upload.consume(bytes).await;
322 }
323
324 {
326 let mut stats = self.stats.write();
327 stats.bytes_uploaded += bytes;
328 let wait_time = start.elapsed();
329 if wait_time > Duration::from_millis(1) {
330 stats.throttle_count += 1;
331 stats.total_wait_time += wait_time;
332 }
333 }
334 }
335
336 pub async fn throttle_download(&self, addr: &SocketAddr, bytes: u64) {
338 let start = Instant::now();
339
340 if let Some(global) = &self.global_download {
342 global.consume(bytes).await;
343 }
344
345 let download_bucket = {
347 let peer_throttles = self.peer_throttles.read();
348 peer_throttles
349 .get(addr)
350 .and_then(|peer| peer.download.clone())
351 };
352
353 if let Some(download) = download_bucket {
354 download.consume(bytes).await;
355 }
356
357 {
359 let mut stats = self.stats.write();
360 stats.bytes_downloaded += bytes;
361 let wait_time = start.elapsed();
362 if wait_time > Duration::from_millis(1) {
363 stats.throttle_count += 1;
364 stats.total_wait_time += wait_time;
365 }
366 }
367 }
368
369 pub fn try_throttle_upload(&self, addr: &SocketAddr, bytes: u64) -> bool {
371 if let Some(global) = &self.global_upload {
373 if !global.try_consume(bytes) {
374 return false;
375 }
376 }
377
378 let peer_throttles = self.peer_throttles.read();
380 if let Some(peer) = peer_throttles.get(addr) {
381 if let Some(upload) = &peer.upload {
382 if !upload.try_consume(bytes) {
383 return false;
384 }
385 }
386 }
387
388 self.stats.write().bytes_uploaded += bytes;
390 true
391 }
392
393 pub fn try_throttle_download(&self, addr: &SocketAddr, bytes: u64) -> bool {
395 if let Some(global) = &self.global_download {
397 if !global.try_consume(bytes) {
398 return false;
399 }
400 }
401
402 let peer_throttles = self.peer_throttles.read();
404 if let Some(peer) = peer_throttles.get(addr) {
405 if let Some(download) = &peer.download {
406 if !download.try_consume(bytes) {
407 return false;
408 }
409 }
410 }
411
412 self.stats.write().bytes_downloaded += bytes;
414 true
415 }
416
417 pub fn update_peer_priority(&self, addr: &SocketAddr, priority: QosPriority) {
419 let mut peer_throttles = self.peer_throttles.write();
420 if let Some(peer) = peer_throttles.get_mut(addr) {
421 peer.priority = priority;
422 debug!("Updated peer {} priority to {:?}", addr, priority);
423 }
424 }
425
426 pub fn stats(&self) -> ThrottleStats {
428 self.stats.read().clone()
429 }
430
431 pub fn reset_stats(&self) {
433 *self.stats.write() = ThrottleStats::default();
434 }
435
436 pub fn available_upload_bandwidth(&self) -> Option<u64> {
438 self.global_upload.as_ref().map(|b| b.available_tokens())
439 }
440
441 pub fn available_download_bandwidth(&self) -> Option<u64> {
443 self.global_download.as_ref().map(|b| b.available_tokens())
444 }
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    /// Configuration with bursting disabled, so bucket capacity equals the rate.
    fn no_burst_config(
        global_upload_limit: u64,
        global_download_limit: u64,
        peer_upload_limit: u64,
        peer_download_limit: u64,
    ) -> BandwidthConfig {
        BandwidthConfig {
            global_upload_limit,
            global_download_limit,
            peer_upload_limit,
            peer_download_limit,
            allow_burst: false,
            burst_multiplier: 1.0,
        }
    }

    #[test]
    fn test_token_bucket() {
        let bucket = TokenBucket::new(100, 10.0);

        assert_eq!(bucket.available_tokens(), 100);
        assert!(bucket.try_consume(50));
        assert_eq!(bucket.available_tokens(), 50);
        assert!(bucket.try_consume(50));
        assert_eq!(bucket.available_tokens(), 0);
        assert!(!bucket.try_consume(1));
    }

    #[tokio::test]
    async fn test_token_bucket_refill() {
        let bucket = TokenBucket::new(100, 100.0);

        let drained_at = Instant::now();
        assert!(bucket.try_consume(100));
        assert_eq!(bucket.available_tokens(), 0);

        tokio::time::sleep(Duration::from_millis(500)).await;

        let available = bucket.available_tokens();
        // Every token now in the bucket was earned after `drained_at`, so the measured
        // elapsed time is an upper bound that stays valid even when the sleep overruns
        // on a loaded machine.
        let ceiling = ((drained_at.elapsed().as_secs_f64() * 100.0).ceil() as u64).min(100);
        assert!(
            (45..=ceiling).contains(&available),
            "Got {} tokens, expected between 45 and {}",
            available,
            ceiling
        );
    }

    #[test]
    fn test_qos_priority() {
        assert_eq!(QosPriority::BestEffort.multiplier(), 0.5);
        assert_eq!(QosPriority::Normal.multiplier(), 1.0);
        assert_eq!(QosPriority::High.multiplier(), 2.0);
        assert_eq!(QosPriority::Critical.multiplier(), 4.0);
    }

    #[test]
    fn test_bandwidth_config_default() {
        let config = BandwidthConfig::default();
        assert_eq!(config.global_upload_limit, 0);
        assert!(config.allow_burst);
        assert_eq!(config.burst_multiplier, 2.0);
    }

    #[test]
    fn test_bandwidth_throttle() {
        let throttle = BandwidthThrottle::new(no_burst_config(1000, 2000, 0, 0));
        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();

        throttle.register_peer(addr, QosPriority::Normal);

        assert_eq!(throttle.available_download_bandwidth(), Some(2000));
        assert!(throttle.try_throttle_upload(&addr, 500));
        assert_eq!(throttle.stats().bytes_uploaded, 500);

        // A request above the bucket capacity is rejected and must not be counted.
        assert!(!throttle.try_throttle_upload(&addr, 1001));
        assert_eq!(throttle.stats().bytes_uploaded, 500);
    }

    #[test]
    fn test_peer_priority() {
        let throttle = BandwidthThrottle::new(no_burst_config(0, 0, 1000, 0));

        // High doubles the 1000-byte base limit: 2000 fits, 2001 can never fit.
        let high: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        throttle.register_peer(high, QosPriority::High);
        assert!(!throttle.try_throttle_upload(&high, 2001));
        assert!(throttle.try_throttle_upload(&high, 2000));

        // BestEffort halves it: 500 fits, 501 can never fit.
        let low: SocketAddr = "127.0.0.1:8081".parse().unwrap();
        throttle.register_peer(low, QosPriority::BestEffort);
        assert!(!throttle.try_throttle_upload(&low, 501));
        assert!(throttle.try_throttle_upload(&low, 500));

        // No global limit is configured, so nothing is reported for it.
        assert_eq!(throttle.available_upload_bandwidth(), None);
    }
}