1use parking_lot::Mutex;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use thiserror::Error;
11
12#[derive(Error, Debug, Clone)]
14pub enum ThrottleError {
15 #[error("Bandwidth limit exceeded, retry after {0:?}")]
16 RateLimitExceeded(Duration),
17
18 #[error("Invalid throttle configuration: {0}")]
19 InvalidConfig(String),
20
21 #[error("Throttle disabled")]
22 Disabled,
23}
24
/// Settings for a [`BandwidthThrottle`].
///
/// Derives `PartialEq`/`Eq` so configurations can be compared (for example, to
/// skip a no-op `update_config`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ThrottleConfig {
    /// Upload limit in bytes per second; `None` leaves uploads unthrottled.
    pub max_upload_bytes_per_sec: Option<u64>,

    /// Download limit in bytes per second; `None` leaves downloads unthrottled.
    pub max_download_bytes_per_sec: Option<u64>,

    /// Bucket capacity (maximum burst) in bytes, shared by both directions.
    /// When `None`, each bucket's capacity defaults to twice its rate.
    pub burst_size_bytes: Option<u64>,

    /// Master switch. When `false`, `check_and_consume` reports `Disabled` and
    /// `available_bandwidth` returns `None`.
    pub enabled: bool,

    /// Minimum time between token refills; must be non-zero.
    pub refill_interval: Duration,
}
44
45impl Default for ThrottleConfig {
46 fn default() -> Self {
47 Self {
48 max_upload_bytes_per_sec: None,
49 max_download_bytes_per_sec: None,
50 burst_size_bytes: None,
51 enabled: false,
52 refill_interval: Duration::from_millis(100), }
54 }
55}
56
57impl ThrottleConfig {
58 pub fn mobile() -> Self {
61 Self {
62 max_upload_bytes_per_sec: Some(1_000_000), max_download_bytes_per_sec: Some(5_000_000), burst_size_bytes: Some(2_000_000), enabled: true,
66 refill_interval: Duration::from_millis(100),
67 }
68 }
69
70 pub fn iot() -> Self {
73 Self {
74 max_upload_bytes_per_sec: Some(128_000), max_download_bytes_per_sec: Some(512_000), burst_size_bytes: Some(256_000), enabled: true,
78 refill_interval: Duration::from_millis(100),
79 }
80 }
81
82 pub fn low_power() -> Self {
85 Self {
86 max_upload_bytes_per_sec: Some(64_000), max_download_bytes_per_sec: Some(256_000), burst_size_bytes: Some(128_000), enabled: true,
90 refill_interval: Duration::from_millis(200), }
92 }
93
94 pub fn validate(&self) -> Result<(), ThrottleError> {
96 if self.refill_interval.is_zero() {
97 return Err(ThrottleError::InvalidConfig(
98 "Refill interval must be > 0".to_string(),
99 ));
100 }
101
102 if let Some(burst) = self.burst_size_bytes {
103 if burst == 0 {
104 return Err(ThrottleError::InvalidConfig(
105 "Burst size must be > 0 if specified".to_string(),
106 ));
107 }
108 }
109
110 Ok(())
111 }
112}
113
/// Token bucket holding byte credits for one traffic direction.
///
/// Credits accrue at `refill_rate` bytes per second up to `capacity`. They are
/// only credited once at least `refill_interval` has passed since the last
/// refill. When that happens, the whole elapsed time is credited, so the
/// interval granularity loses no credit.
///
/// NOTE: because tokens are capped at `capacity`, a single request larger than
/// the burst size can never succeed; `consume` still returns a finite hint.
#[derive(Debug)]
struct TokenBucket {
    /// Bytes currently available to spend.
    tokens: f64,

    /// Maximum tokens the bucket can hold (the burst size).
    capacity: f64,

    /// Refill rate in bytes per second.
    refill_rate: f64,

    /// When tokens were last credited.
    last_refill: Instant,

    /// Minimum time between refills.
    refill_interval: Duration,
}

impl TokenBucket {
    /// Creates a full bucket (`tokens == capacity == burst_bytes`).
    fn new(rate_bytes_per_sec: u64, burst_bytes: u64, refill_interval: Duration) -> Self {
        Self {
            tokens: burst_bytes as f64,
            capacity: burst_bytes as f64,
            refill_rate: rate_bytes_per_sec as f64,
            last_refill: Instant::now(),
            refill_interval,
        }
    }

    /// Credits tokens for the time elapsed since the last refill, as of `now`.
    /// Does nothing until at least `refill_interval` has elapsed.
    fn refill_at(&mut self, now: Instant) {
        let elapsed = now.saturating_duration_since(self.last_refill);

        if elapsed >= self.refill_interval {
            let tokens_to_add = self.refill_rate * elapsed.as_secs_f64();
            self.tokens = (self.tokens + tokens_to_add).min(self.capacity);
            self.last_refill = now;
        }
    }

    /// Deducts `bytes` tokens if available.
    ///
    /// On shortfall nothing is deducted, and the error carries how long to wait
    /// before enough tokens will have been credited.
    fn consume(&mut self, bytes: u64) -> Result<(), Duration> {
        let now = Instant::now();
        self.refill_at(now);

        let requested = bytes as f64;
        if self.tokens >= requested {
            self.tokens -= requested;
            return Ok(());
        }

        let deficit = requested - self.tokens;
        let since_refill = now.saturating_duration_since(self.last_refill);
        Err(self.retry_after(deficit, since_refill))
    }

    /// Time from now until `deficit` more tokens will have been credited, given
    /// that `since_refill` has already passed since the last refill.
    fn retry_after(&self, deficit: f64, since_refill: Duration) -> Duration {
        if self.refill_rate <= 0.0 {
            // A zero-rate bucket never refills. Dividing by zero here would feed
            // `inf` into `Duration::from_secs_f64`, which panics.
            return Duration::MAX;
        }

        // Credit lands only once `refill_interval` has elapsed since the last
        // refill, so the earliest useful retry is the later of that moment and
        // the moment the deficit has accrued.
        let accrue_secs = deficit / self.refill_rate;
        let target_secs = accrue_secs.max(self.refill_interval.as_secs_f64());
        let wait_secs = (target_secs - since_refill.as_secs_f64()).max(0.0);

        if wait_secs >= Duration::MAX.as_secs_f64() {
            Duration::MAX
        } else {
            Duration::from_secs_f64(wait_secs)
        }
    }

    /// Refills if due, then returns the whole number of tokens available.
    fn available_tokens(&mut self) -> u64 {
        self.refill_at(Instant::now());
        self.tokens as u64
    }
}
178
/// Which per-direction budget a transfer is charged against.
///
/// Derives `Hash` so directions can key maps and sets (e.g. per-direction stats).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TrafficDirection {
    /// Charged against the bucket built from `max_upload_bytes_per_sec`.
    Upload,
    /// Charged against the bucket built from `max_download_bytes_per_sec`.
    Download,
}
185
186#[derive(Clone)]
188pub struct BandwidthThrottle {
189 config: ThrottleConfig,
190 upload_bucket: Arc<Mutex<Option<TokenBucket>>>,
191 download_bucket: Arc<Mutex<Option<TokenBucket>>>,
192}
193
194impl BandwidthThrottle {
195 pub fn new(config: ThrottleConfig) -> Result<Self, ThrottleError> {
197 config.validate()?;
198
199 let upload_bucket = config.max_upload_bytes_per_sec.map(|rate| {
200 let burst = config.burst_size_bytes.unwrap_or(rate * 2);
201 TokenBucket::new(rate, burst, config.refill_interval)
202 });
203
204 let download_bucket = config.max_download_bytes_per_sec.map(|rate| {
205 let burst = config.burst_size_bytes.unwrap_or(rate * 2);
206 TokenBucket::new(rate, burst, config.refill_interval)
207 });
208
209 Ok(Self {
210 config: config.clone(),
211 upload_bucket: Arc::new(Mutex::new(upload_bucket)),
212 download_bucket: Arc::new(Mutex::new(download_bucket)),
213 })
214 }
215
216 pub fn check_and_consume(
219 &self,
220 direction: TrafficDirection,
221 bytes: u64,
222 ) -> Result<(), ThrottleError> {
223 if !self.config.enabled {
224 return Err(ThrottleError::Disabled);
225 }
226
227 let bucket = match direction {
228 TrafficDirection::Upload => &self.upload_bucket,
229 TrafficDirection::Download => &self.download_bucket,
230 };
231
232 let mut guard = bucket.lock();
233 if let Some(bucket) = guard.as_mut() {
234 bucket
235 .consume(bytes)
236 .map_err(ThrottleError::RateLimitExceeded)
237 } else {
238 Ok(())
240 }
241 }
242
243 pub fn available_bandwidth(&self, direction: TrafficDirection) -> Option<u64> {
245 if !self.config.enabled {
246 return None;
247 }
248
249 let bucket = match direction {
250 TrafficDirection::Upload => &self.upload_bucket,
251 TrafficDirection::Download => &self.download_bucket,
252 };
253
254 let mut guard = bucket.lock();
255 guard.as_mut().map(|b| b.available_tokens())
256 }
257
258 pub fn set_enabled(&mut self, enabled: bool) {
260 Arc::make_mut(&mut Arc::new(self.config.clone())).enabled = enabled;
261 }
262
263 pub fn is_enabled(&self) -> bool {
265 self.config.enabled
266 }
267
268 pub fn config(&self) -> &ThrottleConfig {
270 &self.config
271 }
272
273 pub fn update_config(&mut self, config: ThrottleConfig) -> Result<(), ThrottleError> {
275 config.validate()?;
276
277 let upload_bucket = config.max_upload_bytes_per_sec.map(|rate| {
279 let burst = config.burst_size_bytes.unwrap_or(rate * 2);
280 TokenBucket::new(rate, burst, config.refill_interval)
281 });
282
283 let download_bucket = config.max_download_bytes_per_sec.map(|rate| {
284 let burst = config.burst_size_bytes.unwrap_or(rate * 2);
285 TokenBucket::new(rate, burst, config.refill_interval)
286 });
287
288 *self.upload_bucket.lock() = upload_bucket;
289 *self.download_bucket.lock() = download_bucket;
290 self.config = config;
291
292 Ok(())
293 }
294}
295
/// Counters describing throttling activity, per direction.
///
/// NOTE(review): nothing in this file updates these counters. Presumably a
/// caller records into them (bytes let through vs. held back, and how many
/// times throttling kicked in); confirm against callers before relying on
/// that reading. All counters start at zero via `Default`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ThrottleStats {
    /// Upload bytes allowed through (per the field name).
    pub upload_bytes_allowed: u64,

    /// Upload bytes held back by throttling (per the field name).
    pub upload_bytes_throttled: u64,

    /// Download bytes allowed through (per the field name).
    pub download_bytes_allowed: u64,

    /// Download bytes held back by throttling (per the field name).
    pub download_bytes_throttled: u64,

    /// Number of upload throttling events (per the field name).
    pub upload_throttle_count: u64,

    /// Number of download throttling events (per the field name).
    pub download_throttle_count: u64,
}
317
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_throttle_config_default() {
        let config = ThrottleConfig::default();
        assert!(!config.enabled);
        assert!(config.max_upload_bytes_per_sec.is_none());
        assert!(config.max_download_bytes_per_sec.is_none());
        assert!(config.burst_size_bytes.is_none());
        assert_eq!(config.refill_interval, Duration::from_millis(100));
    }

    #[test]
    fn test_throttle_config_mobile() {
        let config = ThrottleConfig::mobile();
        assert!(config.enabled);
        assert_eq!(config.max_upload_bytes_per_sec, Some(1_000_000));
        assert_eq!(config.max_download_bytes_per_sec, Some(5_000_000));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_throttle_config_iot() {
        let config = ThrottleConfig::iot();
        assert!(config.enabled);
        assert_eq!(config.max_upload_bytes_per_sec, Some(128_000));
        assert_eq!(config.max_download_bytes_per_sec, Some(512_000));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_throttle_config_low_power() {
        let config = ThrottleConfig::low_power();
        assert!(config.enabled);
        assert_eq!(config.max_upload_bytes_per_sec, Some(64_000));
        assert_eq!(config.max_download_bytes_per_sec, Some(256_000));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_throttle_disabled() {
        let config = ThrottleConfig::default();
        let throttle = BandwidthThrottle::new(config).unwrap();

        let result = throttle.check_and_consume(TrafficDirection::Upload, 1000);
        assert!(matches!(result, Err(ThrottleError::Disabled)));
    }

    #[test]
    fn test_throttle_upload_within_limit() {
        let config = ThrottleConfig {
            enabled: true,
            max_upload_bytes_per_sec: Some(1000),
            burst_size_bytes: Some(2000),
            ..Default::default()
        };

        let throttle = BandwidthThrottle::new(config).unwrap();

        let result = throttle.check_and_consume(TrafficDirection::Upload, 1500);
        assert!(result.is_ok());
    }

    #[test]
    fn test_throttle_upload_exceeds_limit() {
        let config = ThrottleConfig {
            enabled: true,
            max_upload_bytes_per_sec: Some(1000),
            burst_size_bytes: Some(2000),
            ..Default::default()
        };

        let throttle = BandwidthThrottle::new(config).unwrap();

        // Drain the full burst; a failure here would invalidate the test.
        assert!(throttle
            .check_and_consume(TrafficDirection::Upload, 2000)
            .is_ok());

        let result = throttle.check_and_consume(TrafficDirection::Upload, 100);
        assert!(matches!(result, Err(ThrottleError::RateLimitExceeded(_))));
    }

    #[test]
    fn test_throttle_download_within_limit() {
        let config = ThrottleConfig {
            enabled: true,
            max_download_bytes_per_sec: Some(5000),
            burst_size_bytes: Some(10000),
            ..Default::default()
        };

        let throttle = BandwidthThrottle::new(config).unwrap();

        let result = throttle.check_and_consume(TrafficDirection::Download, 8000);
        assert!(result.is_ok());
    }

    #[test]
    fn test_throttle_refill() {
        let config = ThrottleConfig {
            enabled: true,
            max_upload_bytes_per_sec: Some(1000),
            burst_size_bytes: Some(1000),
            refill_interval: Duration::from_millis(100),
            ..Default::default()
        };

        let throttle = BandwidthThrottle::new(config).unwrap();

        assert!(throttle
            .check_and_consume(TrafficDirection::Upload, 1000)
            .is_ok());

        thread::sleep(Duration::from_millis(150));

        // At least ~150 ms at 1000 B/s has been credited, capped at the burst.
        let available = throttle
            .available_bandwidth(TrafficDirection::Upload)
            .expect("upload is limited, so a bucket exists");
        assert!(available >= 100, "expected refill, got {available}");
        assert!(available <= 1000, "refill exceeded burst: {available}");
    }

    #[test]
    fn test_throttle_available_bandwidth() {
        let config = ThrottleConfig {
            enabled: true,
            max_upload_bytes_per_sec: Some(1000),
            burst_size_bytes: Some(2000),
            ..Default::default()
        };

        let throttle = BandwidthThrottle::new(config).unwrap();

        let available = throttle.available_bandwidth(TrafficDirection::Upload);
        assert_eq!(available, Some(2000));
    }

    #[test]
    fn test_throttle_burst_defaults_to_twice_rate() {
        let config = ThrottleConfig {
            enabled: true,
            max_upload_bytes_per_sec: Some(1000),
            ..Default::default()
        };

        let throttle = BandwidthThrottle::new(config).unwrap();

        assert_eq!(
            throttle.available_bandwidth(TrafficDirection::Upload),
            Some(2000)
        );
    }

    #[test]
    fn test_throttle_independent_directions() {
        let config = ThrottleConfig {
            enabled: true,
            max_upload_bytes_per_sec: Some(1000),
            max_download_bytes_per_sec: Some(5000),
            burst_size_bytes: Some(2000),
            ..Default::default()
        };

        let throttle = BandwidthThrottle::new(config).unwrap();

        assert!(throttle
            .check_and_consume(TrafficDirection::Upload, 2000)
            .is_ok());

        let result = throttle.check_and_consume(TrafficDirection::Download, 2000);
        assert!(result.is_ok());
    }

    #[test]
    fn test_throttle_clones_share_buckets() {
        let config = ThrottleConfig {
            enabled: true,
            max_upload_bytes_per_sec: Some(1000),
            burst_size_bytes: Some(1000),
            // Long interval so no refill can happen between the two calls.
            refill_interval: Duration::from_secs(3600),
            ..Default::default()
        };

        let throttle = BandwidthThrottle::new(config).unwrap();
        let clone = throttle.clone();

        assert!(throttle
            .check_and_consume(TrafficDirection::Upload, 1000)
            .is_ok());
        let result = clone.check_and_consume(TrafficDirection::Upload, 1);
        assert!(matches!(result, Err(ThrottleError::RateLimitExceeded(_))));
    }

    #[test]
    fn test_throttle_available_bandwidth_none_when_disabled_or_unlimited() {
        let disabled = BandwidthThrottle::new(ThrottleConfig {
            max_upload_bytes_per_sec: Some(1000),
            ..Default::default()
        })
        .unwrap();
        assert_eq!(disabled.available_bandwidth(TrafficDirection::Upload), None);

        let upload_only = BandwidthThrottle::new(ThrottleConfig {
            enabled: true,
            max_upload_bytes_per_sec: Some(1000),
            ..Default::default()
        })
        .unwrap();
        assert_eq!(
            upload_only.available_bandwidth(TrafficDirection::Download),
            None
        );
    }

    #[test]
    fn test_throttle_update_config() {
        let config = ThrottleConfig {
            enabled: true,
            max_upload_bytes_per_sec: Some(1000),
            ..Default::default()
        };

        let mut throttle = BandwidthThrottle::new(config).unwrap();

        let new_config = ThrottleConfig {
            enabled: true,
            max_upload_bytes_per_sec: Some(5000),
            burst_size_bytes: Some(10000),
            ..Default::default()
        };

        throttle.update_config(new_config).unwrap();

        let available = throttle.available_bandwidth(TrafficDirection::Upload);
        assert_eq!(available, Some(10000));
    }

    #[test]
    fn test_throttle_config_validation() {
        let config = ThrottleConfig {
            refill_interval: Duration::from_secs(0),
            ..Default::default()
        };

        let result = BandwidthThrottle::new(config);
        assert!(matches!(result, Err(ThrottleError::InvalidConfig(_))));
    }

    #[test]
    fn test_throttle_zero_burst_rejected() {
        let config = ThrottleConfig {
            burst_size_bytes: Some(0),
            ..Default::default()
        };

        let result = BandwidthThrottle::new(config);
        assert!(matches!(result, Err(ThrottleError::InvalidConfig(_))));
    }

    #[test]
    fn test_throttle_no_limit_direction() {
        let config = ThrottleConfig {
            enabled: true,
            max_upload_bytes_per_sec: Some(1000),
            ..Default::default()
        };

        let throttle = BandwidthThrottle::new(config).unwrap();

        let result = throttle.check_and_consume(TrafficDirection::Download, 1_000_000);
        assert!(result.is_ok());
    }
}