1use std::sync::Arc;
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::time::{Duration, Instant};
9use tokio::sync::RwLock;
10use tokio::time::sleep;
11
/// Settings for a [`BandwidthLimiter`].
///
/// Rates are expressed in bytes per second; a rate of `0` leaves that
/// direction unthrottled.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Upload cap in bytes per second (`0` = no cap).
    pub upload_rate: u64,
    /// Download cap in bytes per second (`0` = no cap).
    pub download_rate: u64,
    /// Token-bucket capacity as a multiple of the per-second rate, i.e. how
    /// many seconds' worth of traffic may be sent in one burst.
    pub burst_multiplier: f64,
    /// Transfers smaller than this many bytes bypass the limiter: they are
    /// not delayed, and `limit_upload`/`limit_download` do not add them to
    /// the statistics.
    pub min_transfer_size: u64,
    /// Master switch; when `false` no throttling is applied at all.
    pub enabled: bool,
}

impl Default for RateLimitConfig {
    /// Uncapped rates (`0`), a 1.5x burst allowance, a 1024-byte minimum
    /// metered transfer, and limiting switched on.
    fn default() -> Self {
        let no_cap = 0;
        Self {
            upload_rate: no_cap,
            download_rate: no_cap,
            burst_multiplier: 1.5,
            min_transfer_size: 1024,
            enabled: true,
        }
    }
}

impl RateLimitConfig {
    /// Converts megabits per second into whole bytes per second.
    ///
    /// Fractional bytes are truncated; negative or NaN input saturates to `0`
    /// (which means "no cap").
    fn mbps_to_bytes_per_sec(mbps: f64) -> u64 {
        (mbps * 1_000_000.0 / 8.0) as u64
    }

    /// Builds a config with separate upload/download caps given in Mbit/s.
    ///
    /// All other settings come from [`RateLimitConfig::default`].
    #[must_use]
    #[inline]
    pub fn with_rates(upload_mbps: f64, download_mbps: f64) -> Self {
        let upload_rate = Self::mbps_to_bytes_per_sec(upload_mbps);
        let download_rate = Self::mbps_to_bytes_per_sec(download_mbps);
        Self {
            upload_rate,
            download_rate,
            ..Default::default()
        }
    }

    /// Builds a config with the same cap (Mbit/s) in both directions.
    #[must_use]
    #[inline]
    pub fn symmetric(rate_mbps: f64) -> Self {
        Self::with_rates(rate_mbps, rate_mbps)
    }

    /// Builds a config with limiting switched off entirely.
    #[must_use]
    #[inline]
    pub fn unlimited() -> Self {
        Self {
            enabled: false,
            ..Default::default()
        }
    }
}
68
69struct TokenBucket {
71 tokens: AtomicU64,
73 max_tokens: u64,
75 rate: u64,
77 last_refill: RwLock<Instant>,
79}
80
81impl TokenBucket {
82 fn new(rate: u64, burst_multiplier: f64) -> Self {
83 let max_tokens = (rate as f64 * burst_multiplier) as u64;
84 Self {
85 tokens: AtomicU64::new(max_tokens),
86 max_tokens,
87 rate,
88 last_refill: RwLock::new(Instant::now()),
89 }
90 }
91
92 async fn refill(&self) {
93 let mut last = self.last_refill.write().await;
94 let now = Instant::now();
95 let elapsed = now.duration_since(*last);
96
97 if elapsed.as_millis() > 0 {
98 let new_tokens = (elapsed.as_secs_f64() * self.rate as f64) as u64;
99 let current = self.tokens.load(Ordering::Relaxed);
100 let updated = current.saturating_add(new_tokens).min(self.max_tokens);
101 self.tokens.store(updated, Ordering::Relaxed);
102 *last = now;
103 }
104 }
105
106 async fn consume(&self, bytes: u64) -> Duration {
107 self.refill().await;
108
109 let current = self.tokens.load(Ordering::Relaxed);
110
111 if current >= bytes {
112 self.tokens.fetch_sub(bytes, Ordering::Relaxed);
113 Duration::ZERO
114 } else {
115 let needed = bytes.saturating_sub(current);
117 let wait_secs = needed as f64 / self.rate as f64;
118 Duration::from_secs_f64(wait_secs)
119 }
120 }
121
122 fn available(&self) -> u64 {
123 self.tokens.load(Ordering::Relaxed)
124 }
125}
126
127pub struct BandwidthLimiter {
129 config: RateLimitConfig,
130 upload_bucket: Option<TokenBucket>,
131 download_bucket: Option<TokenBucket>,
132 stats: Arc<RwLock<BandwidthStats>>,
133}
134
/// Running transfer counters kept by a [`BandwidthLimiter`].
#[derive(Debug, Clone, Default)]
pub struct BandwidthStats {
    /// Total bytes counted in the upload direction.
    pub bytes_uploaded: u64,
    /// Total bytes counted in the download direction.
    pub bytes_downloaded: u64,
    /// Average upload throughput in bytes per second since `started_at`,
    /// as of the last counter update.
    pub upload_rate: f64,
    /// Average download throughput in bytes per second since `started_at`,
    /// as of the last counter update.
    pub download_rate: f64,
    /// Sum of all delays the limiter asked callers to sleep.
    pub total_wait_time: Duration,
    /// Number of transfers that had to be delayed.
    pub limited_transfers: u64,
    /// When tracking began. `None` (as produced by `Default`) means the
    /// average rates are never computed.
    pub started_at: Option<Instant>,
}

impl BandwidthStats {
    /// Zeroed counters with the clock started now.
    fn new() -> Self {
        Self {
            started_at: Some(Instant::now()),
            ..Self::default()
        }
    }

    /// Recomputes both average rates from the byte counters.
    ///
    /// Does nothing when the clock was never started or no measurable time
    /// has elapsed yet.
    fn update_rates(&mut self) {
        let elapsed_secs = match self.started_at {
            Some(start) => start.elapsed().as_secs_f64(),
            None => return,
        };
        if elapsed_secs <= 0.0 {
            return;
        }
        self.upload_rate = self.bytes_uploaded as f64 / elapsed_secs;
        self.download_rate = self.bytes_downloaded as f64 / elapsed_secs;
    }
}
172
173impl BandwidthLimiter {
174 #[must_use]
176 #[inline]
177 pub fn new(config: RateLimitConfig) -> Self {
178 let upload_bucket = if config.enabled && config.upload_rate > 0 {
179 Some(TokenBucket::new(
180 config.upload_rate,
181 config.burst_multiplier,
182 ))
183 } else {
184 None
185 };
186
187 let download_bucket = if config.enabled && config.download_rate > 0 {
188 Some(TokenBucket::new(
189 config.download_rate,
190 config.burst_multiplier,
191 ))
192 } else {
193 None
194 };
195
196 Self {
197 config,
198 upload_bucket,
199 download_bucket,
200 stats: Arc::new(RwLock::new(BandwidthStats::new())),
201 }
202 }
203
204 pub async fn limit_upload(&self, bytes: u64) {
208 if !self.config.enabled || bytes < self.config.min_transfer_size {
209 return;
210 }
211
212 if let Some(ref bucket) = self.upload_bucket {
213 let wait = bucket.consume(bytes).await;
214 if !wait.is_zero() {
215 let mut stats = self.stats.write().await;
216 stats.total_wait_time += wait;
217 stats.limited_transfers += 1;
218 drop(stats);
219
220 sleep(wait).await;
221 }
222
223 let mut stats = self.stats.write().await;
224 stats.bytes_uploaded += bytes;
225 stats.update_rates();
226 }
227 }
228
229 pub async fn limit_download(&self, bytes: u64) {
231 if !self.config.enabled || bytes < self.config.min_transfer_size {
232 return;
233 }
234
235 if let Some(ref bucket) = self.download_bucket {
236 let wait = bucket.consume(bytes).await;
237 if !wait.is_zero() {
238 let mut stats = self.stats.write().await;
239 stats.total_wait_time += wait;
240 stats.limited_transfers += 1;
241 drop(stats);
242
243 sleep(wait).await;
244 }
245
246 let mut stats = self.stats.write().await;
247 stats.bytes_downloaded += bytes;
248 stats.update_rates();
249 }
250 }
251
252 pub async fn record_upload(&self, bytes: u64) {
254 let mut stats = self.stats.write().await;
255 stats.bytes_uploaded += bytes;
256 stats.update_rates();
257 }
258
259 pub async fn record_download(&self, bytes: u64) {
261 let mut stats = self.stats.write().await;
262 stats.bytes_downloaded += bytes;
263 stats.update_rates();
264 }
265
266 #[must_use]
268 pub async fn stats(&self) -> BandwidthStats {
269 self.stats.read().await.clone()
270 }
271
272 #[must_use]
274 #[inline]
275 pub fn available_upload(&self) -> Option<u64> {
276 self.upload_bucket.as_ref().map(|b| b.available())
277 }
278
279 #[must_use]
281 #[inline]
282 pub fn available_download(&self) -> Option<u64> {
283 self.download_bucket.as_ref().map(|b| b.available())
284 }
285
286 #[must_use]
288 #[inline]
289 pub fn is_enabled(&self) -> bool {
290 self.config.enabled
291 }
292
293 #[must_use]
295 #[inline]
296 pub fn upload_rate(&self) -> u64 {
297 self.config.upload_rate
298 }
299
300 #[must_use]
302 #[inline]
303 pub fn download_rate(&self) -> u64 {
304 self.config.download_rate
305 }
306}
307
308pub struct PeerRateLimiter {
310 global: Arc<BandwidthLimiter>,
312 peer_limiters: RwLock<std::collections::HashMap<String, Arc<BandwidthLimiter>>>,
314 peer_rate_fraction: f64,
316}
317
318impl PeerRateLimiter {
319 #[must_use]
321 #[inline]
322 pub fn new(global_config: RateLimitConfig, peer_rate_fraction: f64) -> Self {
323 Self {
324 global: Arc::new(BandwidthLimiter::new(global_config)),
325 peer_limiters: RwLock::new(std::collections::HashMap::new()),
326 peer_rate_fraction,
327 }
328 }
329
330 #[must_use]
332 pub async fn get_peer_limiter(&self, peer_id: &str) -> Arc<BandwidthLimiter> {
333 {
334 let limiters = self.peer_limiters.read().await;
335 if let Some(limiter) = limiters.get(peer_id) {
336 return Arc::clone(limiter);
337 }
338 }
339
340 let peer_config = RateLimitConfig {
341 upload_rate: (self.global.upload_rate() as f64 * self.peer_rate_fraction) as u64,
342 download_rate: (self.global.download_rate() as f64 * self.peer_rate_fraction) as u64,
343 burst_multiplier: 2.0, min_transfer_size: 512,
345 enabled: self.global.is_enabled(),
346 };
347
348 let limiter = Arc::new(BandwidthLimiter::new(peer_config));
349
350 let mut limiters = self.peer_limiters.write().await;
351 limiters.insert(peer_id.to_string(), Arc::clone(&limiter));
352
353 limiter
354 }
355
356 pub async fn limit_upload(&self, peer_id: &str, bytes: u64) {
358 self.global.limit_upload(bytes).await;
360
361 let peer_limiter = self.get_peer_limiter(peer_id).await;
362 peer_limiter.limit_upload(bytes).await;
363 }
364
365 pub async fn limit_download(&self, peer_id: &str, bytes: u64) {
367 self.global.limit_download(bytes).await;
368
369 let peer_limiter = self.get_peer_limiter(peer_id).await;
370 peer_limiter.limit_download(bytes).await;
371 }
372
373 #[must_use]
375 pub async fn global_stats(&self) -> BandwidthStats {
376 self.global.stats().await
377 }
378
379 #[must_use]
381 pub async fn peer_stats(&self, peer_id: &str) -> Option<BandwidthStats> {
382 let limiters = self.peer_limiters.read().await;
383 if let Some(limiter) = limiters.get(peer_id) {
384 Some(limiter.stats().await)
385 } else {
386 None
387 }
388 }
389
390 pub async fn remove_peer(&self, peer_id: &str) {
392 let mut limiters = self.peer_limiters.write().await;
393 limiters.remove(peer_id);
394 }
395
396 #[must_use]
398 pub async fn peer_count(&self) -> usize {
399 self.peer_limiters.read().await.len()
400 }
401}
402
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let RateLimitConfig {
            upload_rate,
            download_rate,
            enabled,
            ..
        } = RateLimitConfig::default();
        assert!(enabled);
        assert_eq!(upload_rate, 0);
        assert_eq!(download_rate, 0);
    }

    #[test]
    fn test_config_with_rates() {
        // 100 Mbit/s = 12.5 MB/s and 50 Mbit/s = 6.25 MB/s.
        let config = RateLimitConfig::with_rates(100.0, 50.0);
        assert_eq!(config.upload_rate, 12_500_000);
        assert_eq!(config.download_rate, 6_250_000);
    }

    #[tokio::test]
    async fn test_unlimited_limiter() {
        let limiter = BandwidthLimiter::new(RateLimitConfig::unlimited());

        // With limiting disabled even large transfers must return immediately.
        let started = Instant::now();
        limiter.limit_upload(10_000_000).await;
        limiter.limit_download(10_000_000).await;
        assert!(started.elapsed() < Duration::from_millis(10));
    }

    #[tokio::test]
    async fn test_stats_recording() {
        let limiter = BandwidthLimiter::new(RateLimitConfig::unlimited());
        limiter.record_upload(1000).await;
        limiter.record_download(2000).await;

        let BandwidthStats {
            bytes_uploaded,
            bytes_downloaded,
            ..
        } = limiter.stats().await;
        assert_eq!(bytes_uploaded, 1000);
        assert_eq!(bytes_downloaded, 2000);
    }

    #[tokio::test]
    async fn test_peer_rate_limiter() {
        let peers = PeerRateLimiter::new(RateLimitConfig::unlimited(), 0.25);

        let _limiter = peers.get_peer_limiter("peer1").await;
        assert_eq!(peers.peer_count().await, 1);

        peers.remove_peer("peer1").await;
        assert_eq!(peers.peer_count().await, 0);
    }
}
460}