1use crate::{Error, Result};
9use rand::Rng;
10use std::borrow::Cow;
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tokio::sync::{Mutex, RwLock};
15
/// Map key of the shared token bucket used when none of a transfer's tags has a
/// bandwidth override; tag-specific buckets are stored under `tag:<name>` keys instead.
const GLOBAL_BUCKET_KEY: &str = "__global__";
17
/// Settings for token-bucket bandwidth throttling.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct BandwidthConfig {
    /// Master switch; when `false` no throttling is applied.
    pub enabled: bool,
    /// Default sustained rate in bytes per second. A value of `0` is treated as
    /// "unlimited" by the shaper (no throttling for transfers using this limit).
    pub max_bytes_per_sec: u64,
    /// Capacity, in bytes, of each token bucket (the largest burst that can pass
    /// without waiting).
    pub burst_capacity_bytes: u64,
    /// Per-tag rate overrides in bytes per second, keyed by tag name. The first tag
    /// of a transfer that has an entry here determines the limit.
    pub tag_overrides: HashMap<String, u64>,
}
30
31impl Default for BandwidthConfig {
32 fn default() -> Self {
33 Self {
34 enabled: false,
35 max_bytes_per_sec: 0, burst_capacity_bytes: 1024 * 1024, tag_overrides: HashMap::new(),
38 }
39 }
40}
41
42impl BandwidthConfig {
43 pub fn new(max_bytes_per_sec: u64, burst_capacity_bytes: u64) -> Self {
45 Self {
46 enabled: true,
47 max_bytes_per_sec,
48 burst_capacity_bytes,
49 tag_overrides: HashMap::new(),
50 }
51 }
52
53 pub fn with_tag_override(mut self, tag: String, max_bytes_per_sec: u64) -> Self {
55 self.tag_overrides.insert(tag, max_bytes_per_sec);
56 self
57 }
58
59 pub fn get_effective_limit(&self, tags: &[String]) -> u64 {
61 if let Some(&override_limit) = tags.iter().find_map(|tag| self.tag_overrides.get(tag)) {
63 return override_limit;
64 }
65 self.max_bytes_per_sec
66 }
67}
68
/// Settings for simulated burst packet loss.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct BurstLossConfig {
    /// Master switch; when `false` no packet is ever dropped.
    pub enabled: bool,
    /// Probability, evaluated per packet while idle, that a loss burst starts.
    pub burst_probability: f64,
    /// How long a burst lasts once started, in milliseconds.
    pub burst_duration_ms: u64,
    /// Probability that an individual packet is dropped while a burst is active.
    pub loss_rate_during_burst: f64,
    /// Quiet period after a burst, in milliseconds, during which no packet is
    /// dropped and no new burst can start.
    pub recovery_time_ms: u64,
    /// Per-tag parameter overrides, keyed by tag name. The first tag of a transfer
    /// that has an entry here replaces the four burst parameters above.
    pub tag_overrides: HashMap<String, BurstLossOverride>,
}
85
/// Tag-specific burst-loss parameters; each field replaces the same-named field of
/// [`BurstLossConfig`] when the override applies.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct BurstLossOverride {
    /// Replacement for `BurstLossConfig::burst_probability`.
    pub burst_probability: f64,
    /// Replacement for `BurstLossConfig::burst_duration_ms`.
    pub burst_duration_ms: u64,
    /// Replacement for `BurstLossConfig::loss_rate_during_burst`.
    pub loss_rate_during_burst: f64,
    /// Replacement for `BurstLossConfig::recovery_time_ms`.
    pub recovery_time_ms: u64,
}
98
99impl Default for BurstLossConfig {
100 fn default() -> Self {
101 Self {
102 enabled: false,
103 burst_probability: 0.1, burst_duration_ms: 5000, loss_rate_during_burst: 0.5, recovery_time_ms: 30000, tag_overrides: HashMap::new(),
108 }
109 }
110}
111
112impl BurstLossConfig {
113 pub fn new(
115 burst_probability: f64,
116 burst_duration_ms: u64,
117 loss_rate: f64,
118 recovery_time_ms: u64,
119 ) -> Self {
120 Self {
121 enabled: true,
122 burst_probability: burst_probability.clamp(0.0, 1.0),
123 burst_duration_ms,
124 loss_rate_during_burst: loss_rate.clamp(0.0, 1.0),
125 recovery_time_ms,
126 tag_overrides: HashMap::new(),
127 }
128 }
129
130 pub fn with_tag_override(mut self, tag: String, override_config: BurstLossOverride) -> Self {
132 self.tag_overrides.insert(tag, override_config);
133 self
134 }
135
136 pub fn effective_config<'a>(&'a self, tags: &[String]) -> Cow<'a, BurstLossConfig> {
138 if let Some(override_config) = tags.iter().find_map(|tag| self.tag_overrides.get(tag)) {
139 let mut temp_config = self.clone();
140 temp_config.burst_probability = override_config.burst_probability;
141 temp_config.burst_duration_ms = override_config.burst_duration_ms;
142 temp_config.loss_rate_during_burst = override_config.loss_rate_during_burst;
143 temp_config.recovery_time_ms = override_config.recovery_time_ms;
144 Cow::Owned(temp_config)
145 } else {
146 Cow::Borrowed(self)
147 }
148 }
149}
150
/// Token bucket used to meter bytes: tokens accrue continuously at `refill_rate`
/// per second up to `capacity`, and every byte transferred costs one token.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens currently available (fractional, since refills are time-proportional).
    tokens: f64,
    /// Upper bound on `tokens`.
    capacity: f64,
    /// Tokens added per second of elapsed time.
    refill_rate: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full.
    fn new(capacity: u64, refill_rate_bytes_per_sec: u64) -> Self {
        Self {
            tokens: capacity as f64,
            capacity: capacity as f64,
            refill_rate: refill_rate_bytes_per_sec as f64,
            last_refill: Instant::now(),
        }
    }

    /// Credits the tokens earned since the last refill, capped at `capacity`.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        let tokens_to_add = elapsed * self.refill_rate;

        self.tokens = (self.tokens + tokens_to_add).min(self.capacity);
        self.last_refill = now;
    }

    /// Refills, then takes `bytes` tokens if that many are available.
    ///
    /// Returns `true` when the tokens were taken; on `false` the bucket is unchanged
    /// apart from the refill.
    fn try_consume(&mut self, bytes: u64) -> bool {
        self.refill();
        let requested = bytes as f64;
        if self.tokens >= requested {
            self.tokens -= requested;
            true
        } else {
            false
        }
    }

    /// Refills, then estimates how long until `bytes` tokens would be available at
    /// the current refill rate.
    ///
    /// Returns [`Duration::ZERO`] when enough tokens are already present. The
    /// estimate ignores `capacity`, so it stays finite even for a request that can
    /// never fit in the bucket. Returns [`Duration::MAX`] when the bucket never
    /// refills (`refill_rate` of zero) or the wait does not fit in a `Duration`;
    /// previously these cases panicked inside `Duration::from_secs_f64`.
    fn time_until_available(&mut self, bytes: u64) -> Duration {
        self.refill();
        let requested = bytes as f64;
        if self.tokens >= requested {
            return Duration::ZERO;
        }
        if self.refill_rate <= 0.0 {
            // Dividing by zero would yield infinity; report "never" instead.
            return Duration::MAX;
        }

        let seconds_needed = (requested - self.tokens) / self.refill_rate;
        Duration::try_from_secs_f64(seconds_needed).unwrap_or(Duration::MAX)
    }
}
208
/// Mutable state of the burst-loss simulation: idle, in a burst, or recovering.
#[derive(Debug)]
struct BurstLossState {
    /// `true` while a loss burst is active.
    in_burst: bool,
    /// When the active burst began; `None` outside a burst.
    burst_start: Option<Instant>,
    /// When the post-burst recovery period began; `None` when not recovering.
    recovery_start: Option<Instant>,
}
219
220impl BurstLossState {
221 fn new() -> Self {
222 Self {
223 in_burst: false,
224 burst_start: None,
225 recovery_start: None,
226 }
227 }
228
229 fn should_drop_packet(&mut self, config: &BurstLossConfig) -> bool {
231 if !config.enabled {
232 return false;
233 }
234
235 let now = Instant::now();
236
237 match (self.in_burst, self.burst_start, self.recovery_start) {
238 (true, Some(burst_start), _) => {
239 let burst_duration = now.duration_since(burst_start);
241 if burst_duration >= Duration::from_millis(config.burst_duration_ms) {
242 self.in_burst = false;
244 self.burst_start = None;
245 self.recovery_start = Some(now);
246 false } else {
248 let mut rng = rand::rng();
250 rng.random_bool(config.loss_rate_during_burst)
251 }
252 }
253 (true, None, _) => {
254 self.in_burst = false;
256 false
257 }
258 (false, _, Some(recovery_start)) => {
259 let recovery_duration = now.duration_since(recovery_start);
261 if recovery_duration >= Duration::from_millis(config.recovery_time_ms) {
262 self.recovery_start = None;
264 let mut rng = rand::rng();
266 if rng.random_bool(config.burst_probability) {
267 self.in_burst = true;
268 self.burst_start = Some(now);
269 rng.random_bool(config.loss_rate_during_burst)
271 } else {
272 false
273 }
274 } else {
275 false }
277 }
278 (false, _, None) => {
279 let mut rng = rand::rng();
281 if rng.random_bool(config.burst_probability) {
282 self.in_burst = true;
283 self.burst_start = Some(now);
284 rng.random_bool(config.loss_rate_during_burst)
285 } else {
286 false
287 }
288 }
289 }
290 }
291}
292
/// Combined traffic-shaping configuration consumed by `TrafficShaper::new`.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
pub struct TrafficShapingConfig {
    /// Bandwidth-throttling settings.
    pub bandwidth: BandwidthConfig,
    /// Burst packet-loss settings.
    pub burst_loss: BurstLossConfig,
}
301
/// Applies bandwidth throttling and simulated burst packet loss.
///
/// The bucket map and the burst-loss state are held behind `Arc`s, so clones of a
/// shaper share them; the two configuration values are copied per clone.
#[derive(Debug, Clone)]
pub struct TrafficShaper {
    /// Bandwidth-throttling settings.
    bandwidth_config: BandwidthConfig,
    /// Burst packet-loss settings.
    burst_loss_config: BurstLossConfig,
    /// Token buckets keyed by bucket name (the global key or `tag:<name>`),
    /// created on first use.
    token_buckets: Arc<RwLock<HashMap<String, Arc<Mutex<TokenBucket>>>>>,
    /// Single burst-loss state machine used for every transfer, whatever its tags.
    burst_loss_state: Arc<Mutex<BurstLossState>>,
}
314
315impl TrafficShaper {
316 pub fn new(config: TrafficShapingConfig) -> Self {
318 Self {
319 bandwidth_config: config.bandwidth,
320 burst_loss_config: config.burst_loss,
321 token_buckets: Arc::new(RwLock::new(HashMap::new())),
322 burst_loss_state: Arc::new(Mutex::new(BurstLossState::new())),
323 }
324 }
325
326 pub async fn throttle_bandwidth(&self, data_size: u64, tags: &[String]) -> Result<()> {
328 if !self.bandwidth_config.enabled {
329 return Ok(());
330 }
331
332 let (bucket_key, effective_limit) = self.resolve_bandwidth_bucket(tags);
333
334 if effective_limit == 0 {
335 return Ok(());
336 }
337
338 let bucket_arc = self.get_or_create_bucket(&bucket_key, effective_limit).await;
339
340 {
341 let mut bucket = bucket_arc.lock().await;
342 if bucket.try_consume(data_size) {
343 return Ok(());
344 }
345
346 let wait_time = bucket.time_until_available(data_size);
347 drop(bucket);
348
349 if wait_time.is_zero() {
350 return Err(Error::generic(format!(
351 "Failed to acquire bandwidth tokens for {} bytes",
352 data_size
353 )));
354 }
355
356 tokio::time::sleep(wait_time).await;
357 }
358
359 let mut bucket = bucket_arc.lock().await;
360 if bucket.try_consume(data_size) {
361 Ok(())
362 } else {
363 Err(Error::generic(format!(
364 "Failed to acquire bandwidth tokens for {} bytes",
365 data_size
366 )))
367 }
368 }
369
370 pub async fn should_drop_packet(&self, tags: &[String]) -> bool {
372 if !self.burst_loss_config.enabled {
373 return false;
374 }
375
376 let effective_config = self.burst_loss_config.effective_config(tags);
377 let mut state = self.burst_loss_state.lock().await;
378 state.should_drop_packet(effective_config.as_ref())
379 }
380
381 pub async fn process_transfer(
383 &self,
384 data_size: u64,
385 tags: &[String],
386 ) -> Result<Option<Duration>> {
387 self.throttle_bandwidth(data_size, tags).await?;
389
390 if self.should_drop_packet(tags).await {
392 return Ok(Some(Duration::from_millis(100))); }
394
395 Ok(None)
396 }
397
398 pub async fn get_bandwidth_stats(&self) -> BandwidthStats {
400 let maybe_bucket = {
401 let guard = self.token_buckets.read().await;
402 guard.get(GLOBAL_BUCKET_KEY).cloned()
403 };
404
405 if let Some(bucket_arc) = maybe_bucket {
406 let bucket = bucket_arc.lock().await;
407 BandwidthStats {
408 current_tokens: bucket.tokens as u64,
409 capacity: bucket.capacity as u64,
410 refill_rate_bytes_per_sec: bucket.refill_rate as u64,
411 }
412 } else {
413 BandwidthStats {
414 current_tokens: self.bandwidth_config.burst_capacity_bytes,
415 capacity: self.bandwidth_config.burst_capacity_bytes,
416 refill_rate_bytes_per_sec: self.bandwidth_config.max_bytes_per_sec,
417 }
418 }
419 }
420
421 pub async fn get_burst_loss_stats(&self) -> BurstLossStats {
423 let state = self.burst_loss_state.lock().await;
424 BurstLossStats {
425 in_burst: state.in_burst,
426 burst_start: state.burst_start,
427 recovery_start: state.recovery_start,
428 }
429 }
430
431 async fn get_or_create_bucket(
432 &self,
433 bucket_key: &str,
434 effective_limit: u64,
435 ) -> Arc<Mutex<TokenBucket>> {
436 if let Some(existing) = self.token_buckets.read().await.get(bucket_key).cloned() {
437 return existing;
438 }
439
440 let mut buckets = self.token_buckets.write().await;
441 buckets
442 .entry(bucket_key.to_string())
443 .or_insert_with(|| {
444 Arc::new(Mutex::new(TokenBucket::new(
445 self.bandwidth_config.burst_capacity_bytes,
446 effective_limit,
447 )))
448 })
449 .clone()
450 }
451
452 fn resolve_bandwidth_bucket(&self, tags: &[String]) -> (String, u64) {
453 if let Some((tag, limit)) = tags.iter().find_map(|tag| {
454 self.bandwidth_config.tag_overrides.get(tag).map(|limit| (tag.as_str(), *limit))
455 }) {
456 (format!("tag:{}", tag), limit)
457 } else {
458 (GLOBAL_BUCKET_KEY.to_string(), self.bandwidth_config.max_bytes_per_sec)
459 }
460 }
461}
462
/// Snapshot of a token bucket as returned by `TrafficShaper::get_bandwidth_stats`.
#[derive(Debug, Clone)]
pub struct BandwidthStats {
    /// Tokens (bytes) available in the bucket, truncated to a whole number.
    pub current_tokens: u64,
    /// Maximum number of tokens the bucket can hold.
    pub capacity: u64,
    /// Rate at which the bucket refills, in bytes per second.
    pub refill_rate_bytes_per_sec: u64,
}
473
/// Snapshot of the burst-loss state as returned by
/// `TrafficShaper::get_burst_loss_stats`.
#[derive(Debug, Clone)]
pub struct BurstLossStats {
    /// `true` while a loss burst is active.
    pub in_burst: bool,
    /// When the active burst began, if any.
    pub burst_start: Option<Instant>,
    /// When the current recovery period began, if any.
    pub recovery_start: Option<Instant>,
}
484
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Assembles a shaper from the two configuration halves.
    fn shaper_with(bandwidth: BandwidthConfig, burst_loss: BurstLossConfig) -> TrafficShaper {
        TrafficShaper::new(TrafficShapingConfig {
            bandwidth,
            burst_loss,
        })
    }

    #[tokio::test]
    async fn test_bandwidth_throttling() {
        // 1000 B/s sustained with a 100-byte burst allowance.
        let shaper = shaper_with(BandwidthConfig::new(1000, 100), BurstLossConfig::default());

        // Fits inside the initial burst allowance.
        let first = shaper.throttle_bandwidth(50, &[]).await;
        assert!(first.is_ok());

        // Only 50 tokens remain, so 80 bytes must wait for roughly 30 more.
        let started = Instant::now();
        let second = shaper.throttle_bandwidth(80, &[]).await;
        let waited = started.elapsed();
        assert!(second.is_ok());
        assert!(waited >= Duration::from_millis(30));
    }

    #[tokio::test]
    async fn test_burst_loss() {
        // Certain burst start and certain loss while bursting.
        let shaper = shaper_with(
            BandwidthConfig::default(),
            BurstLossConfig::new(1.0, 1000, 1.0, 1000),
        );

        // The very first packet triggers the burst and is dropped.
        assert!(shaper.should_drop_packet(&[]).await);

        // Every following packet inside the burst window is dropped as well.
        for _ in 0..5 {
            assert!(shaper.should_drop_packet(&[]).await);
        }
    }

    #[tokio::test]
    async fn test_bandwidth_tag_override_with_global_unlimited() {
        // Global limit of 0 (unlimited); only the "limited" tag is throttled.
        let bandwidth = BandwidthConfig {
            enabled: true,
            max_bytes_per_sec: 0,
            burst_capacity_bytes: 100,
            ..BandwidthConfig::default()
        }
        .with_tag_override("limited".to_string(), 100);

        let shaper = shaper_with(bandwidth, BurstLossConfig::default());
        let tags = vec!["limited".to_string()];

        shaper
            .throttle_bandwidth(100, &tags)
            .await
            .expect("initial transfer should succeed immediately");

        let started = Instant::now();
        shaper
            .throttle_bandwidth(100, &tags)
            .await
            .expect("tag override should throttle but eventually succeed");
        let waited = started.elapsed();
        assert!(
            waited >= Duration::from_millis(900),
            "override-specific transfer should respect configured rate"
        );
    }

    #[test]
    fn test_bandwidth_config_overrides() {
        let config =
            BandwidthConfig::new(1000, 100).with_tag_override("high-priority".to_string(), 5000);

        let no_tags: [String; 0] = [];
        let only_high = ["high-priority".to_string()];
        let low_then_high = ["low-priority".to_string(), "high-priority".to_string()];

        assert_eq!(config.get_effective_limit(&no_tags), 1000);
        assert_eq!(config.get_effective_limit(&only_high), 5000);
        assert_eq!(config.get_effective_limit(&low_then_high), 5000);
    }

    #[test]
    fn test_burst_loss_effective_config_override() {
        let expected = BurstLossOverride {
            burst_probability: 0.8,
            burst_duration_ms: 2000,
            loss_rate_during_burst: 0.9,
            recovery_time_ms: 5000,
        };

        let config =
            BurstLossConfig::default().with_tag_override("flaky".to_string(), expected.clone());

        let tags = ["flaky".to_string()];
        let effective = config.effective_config(&tags);
        assert_eq!(effective.burst_probability, expected.burst_probability);
        assert_eq!(effective.burst_duration_ms, expected.burst_duration_ms);
        assert_eq!(effective.loss_rate_during_burst, expected.loss_rate_during_burst);
        assert_eq!(effective.recovery_time_ms, expected.recovery_time_ms);
    }
}