1use crate::{Error, Result};
9use rand::Rng;
10use std::borrow::Cow;
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tokio::sync::{Mutex, RwLock};
15
/// Map key of the shared token bucket used when no tag override applies.
/// Tag-specific buckets use keys of the form `"tag:<name>"` instead.
const GLOBAL_BUCKET_KEY: &str = "__global__";
17
18#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
20pub struct BandwidthConfig {
21 pub enabled: bool,
23 pub max_bytes_per_sec: u64,
25 pub burst_capacity_bytes: u64,
27 pub tag_overrides: HashMap<String, u64>,
29}
30
31impl Default for BandwidthConfig {
32 fn default() -> Self {
33 Self {
34 enabled: false,
35 max_bytes_per_sec: 0, burst_capacity_bytes: 1024 * 1024, tag_overrides: HashMap::new(),
38 }
39 }
40}
41
42impl BandwidthConfig {
43 pub fn new(max_bytes_per_sec: u64, burst_capacity_bytes: u64) -> Self {
45 Self {
46 enabled: true,
47 max_bytes_per_sec,
48 burst_capacity_bytes,
49 tag_overrides: HashMap::new(),
50 }
51 }
52
53 pub fn with_tag_override(mut self, tag: String, max_bytes_per_sec: u64) -> Self {
55 self.tag_overrides.insert(tag, max_bytes_per_sec);
56 self
57 }
58
59 pub fn get_effective_limit(&self, tags: &[String]) -> u64 {
61 if let Some(&override_limit) = tags.iter().find_map(|tag| self.tag_overrides.get(tag)) {
63 return override_limit;
64 }
65 self.max_bytes_per_sec
66 }
67}
68
69#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
71pub struct BurstLossConfig {
72 pub enabled: bool,
74 pub burst_probability: f64,
76 pub burst_duration_ms: u64,
78 pub loss_rate_during_burst: f64,
80 pub recovery_time_ms: u64,
82 pub tag_overrides: HashMap<String, BurstLossOverride>,
84}
85
86#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
87pub struct BurstLossOverride {
88 pub burst_probability: f64,
89 pub burst_duration_ms: u64,
90 pub loss_rate_during_burst: f64,
91 pub recovery_time_ms: u64,
92}
93
94impl Default for BurstLossConfig {
95 fn default() -> Self {
96 Self {
97 enabled: false,
98 burst_probability: 0.1, burst_duration_ms: 5000, loss_rate_during_burst: 0.5, recovery_time_ms: 30000, tag_overrides: HashMap::new(),
103 }
104 }
105}
106
107impl BurstLossConfig {
108 pub fn new(
110 burst_probability: f64,
111 burst_duration_ms: u64,
112 loss_rate: f64,
113 recovery_time_ms: u64,
114 ) -> Self {
115 Self {
116 enabled: true,
117 burst_probability: burst_probability.clamp(0.0, 1.0),
118 burst_duration_ms,
119 loss_rate_during_burst: loss_rate.clamp(0.0, 1.0),
120 recovery_time_ms,
121 tag_overrides: HashMap::new(),
122 }
123 }
124
125 pub fn with_tag_override(mut self, tag: String, override_config: BurstLossOverride) -> Self {
127 self.tag_overrides.insert(tag, override_config);
128 self
129 }
130
131 pub fn effective_config<'a>(&'a self, tags: &[String]) -> Cow<'a, BurstLossConfig> {
133 if let Some(override_config) = tags.iter().find_map(|tag| self.tag_overrides.get(tag)) {
134 let mut temp_config = self.clone();
135 temp_config.burst_probability = override_config.burst_probability;
136 temp_config.burst_duration_ms = override_config.burst_duration_ms;
137 temp_config.loss_rate_during_burst = override_config.loss_rate_during_burst;
138 temp_config.recovery_time_ms = override_config.recovery_time_ms;
139 Cow::Owned(temp_config)
140 } else {
141 Cow::Borrowed(self)
142 }
143 }
144}
145
/// Token bucket measured in bytes: holds at most `capacity` tokens and
/// refills continuously at `refill_rate` tokens per second.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens currently available (may be fractional).
    tokens: f64,
    /// Upper bound on `tokens`.
    capacity: f64,
    /// Tokens added per second of elapsed time.
    refill_rate: f64,
    /// When `tokens` was last brought up to date.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a full bucket.
    fn new(capacity: u64, refill_rate_bytes_per_sec: u64) -> Self {
        Self {
            tokens: capacity as f64,
            capacity: capacity as f64,
            refill_rate: refill_rate_bytes_per_sec as f64,
            last_refill: Instant::now(),
        }
    }

    /// Adds the tokens accrued since the last refill, capped at `capacity`.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.capacity);
        self.last_refill = now;
    }

    /// Refills, then takes `bytes` tokens if enough are available.
    /// Returns `false` and leaves the bucket untouched otherwise.
    fn try_consume(&mut self, bytes: u64) -> bool {
        self.refill();
        if self.tokens >= bytes as f64 {
            self.tokens -= bytes as f64;
            true
        } else {
            false
        }
    }

    /// Refills, then returns how long until `bytes` tokens will have accrued
    /// (`Duration::ZERO` if they are available now).
    ///
    /// Saturates at `Duration::MAX` when the wait is unbounded (zero refill
    /// rate) or too long to represent. `Duration::from_secs_f64` would panic
    /// in that case.
    fn time_until_available(&mut self, bytes: u64) -> Duration {
        self.refill();
        let missing = bytes as f64 - self.tokens;
        if missing <= 0.0 {
            return Duration::ZERO;
        }
        Duration::try_from_secs_f64(missing / self.refill_rate).unwrap_or(Duration::MAX)
    }
}
203
204#[derive(Debug)]
206struct BurstLossState {
207 in_burst: bool,
209 burst_start: Option<Instant>,
211 recovery_start: Option<Instant>,
213}
214
215impl BurstLossState {
216 fn new() -> Self {
217 Self {
218 in_burst: false,
219 burst_start: None,
220 recovery_start: None,
221 }
222 }
223
224 fn should_drop_packet(&mut self, config: &BurstLossConfig) -> bool {
226 if !config.enabled {
227 return false;
228 }
229
230 let now = Instant::now();
231
232 match (self.in_burst, self.burst_start, self.recovery_start) {
233 (true, Some(burst_start), _) => {
234 let burst_duration = now.duration_since(burst_start);
236 if burst_duration >= Duration::from_millis(config.burst_duration_ms) {
237 self.in_burst = false;
239 self.burst_start = None;
240 self.recovery_start = Some(now);
241 false } else {
243 let mut rng = rand::rng();
245 rng.random_bool(config.loss_rate_during_burst)
246 }
247 }
248 (true, None, _) => {
249 self.in_burst = false;
251 false
252 }
253 (false, _, Some(recovery_start)) => {
254 let recovery_duration = now.duration_since(recovery_start);
256 if recovery_duration >= Duration::from_millis(config.recovery_time_ms) {
257 self.recovery_start = None;
259 let mut rng = rand::rng();
261 if rng.random_bool(config.burst_probability) {
262 self.in_burst = true;
263 self.burst_start = Some(now);
264 rng.random_bool(config.loss_rate_during_burst)
266 } else {
267 false
268 }
269 } else {
270 false }
272 }
273 (false, _, None) => {
274 let mut rng = rand::rng();
276 if rng.random_bool(config.burst_probability) {
277 self.in_burst = true;
278 self.burst_start = Some(now);
279 rng.random_bool(config.loss_rate_during_burst)
280 } else {
281 false
282 }
283 }
284 }
285 }
286}
287
/// Combined configuration for [`TrafficShaper`]: bandwidth throttling plus
/// burst packet-loss simulation. Both parts are disabled by default.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
pub struct TrafficShapingConfig {
    /// Token-bucket bandwidth limiting settings.
    pub bandwidth: BandwidthConfig,
    /// Burst packet-loss settings.
    pub burst_loss: BurstLossConfig,
}
296
297#[derive(Debug, Clone)]
299pub struct TrafficShaper {
300 bandwidth_config: BandwidthConfig,
302 burst_loss_config: BurstLossConfig,
304 token_buckets: Arc<RwLock<HashMap<String, Arc<Mutex<TokenBucket>>>>>,
306 burst_loss_state: Arc<Mutex<BurstLossState>>,
308}
309
310impl TrafficShaper {
311 pub fn new(config: TrafficShapingConfig) -> Self {
313 Self {
314 bandwidth_config: config.bandwidth,
315 burst_loss_config: config.burst_loss,
316 token_buckets: Arc::new(RwLock::new(HashMap::new())),
317 burst_loss_state: Arc::new(Mutex::new(BurstLossState::new())),
318 }
319 }
320
321 pub async fn throttle_bandwidth(&self, data_size: u64, tags: &[String]) -> Result<()> {
323 if !self.bandwidth_config.enabled {
324 return Ok(());
325 }
326
327 let (bucket_key, effective_limit) = self.resolve_bandwidth_bucket(tags);
328
329 if effective_limit == 0 {
330 return Ok(());
331 }
332
333 let bucket_arc = self.get_or_create_bucket(&bucket_key, effective_limit).await;
334
335 {
336 let mut bucket = bucket_arc.lock().await;
337 if bucket.try_consume(data_size) {
338 return Ok(());
339 }
340
341 let wait_time = bucket.time_until_available(data_size);
342 drop(bucket);
343
344 if wait_time.is_zero() {
345 return Err(Error::generic(format!(
346 "Failed to acquire bandwidth tokens for {} bytes",
347 data_size
348 )));
349 }
350
351 tokio::time::sleep(wait_time).await;
352 }
353
354 let mut bucket = bucket_arc.lock().await;
355 if bucket.try_consume(data_size) {
356 Ok(())
357 } else {
358 Err(Error::generic(format!(
359 "Failed to acquire bandwidth tokens for {} bytes",
360 data_size
361 )))
362 }
363 }
364
365 pub async fn should_drop_packet(&self, tags: &[String]) -> bool {
367 if !self.burst_loss_config.enabled {
368 return false;
369 }
370
371 let effective_config = self.burst_loss_config.effective_config(tags);
372 let mut state = self.burst_loss_state.lock().await;
373 state.should_drop_packet(effective_config.as_ref())
374 }
375
376 pub async fn process_transfer(
378 &self,
379 data_size: u64,
380 tags: &[String],
381 ) -> Result<Option<Duration>> {
382 self.throttle_bandwidth(data_size, tags).await?;
384
385 if self.should_drop_packet(tags).await {
387 return Ok(Some(Duration::from_millis(100))); }
389
390 Ok(None)
391 }
392
393 pub async fn get_bandwidth_stats(&self) -> BandwidthStats {
395 let maybe_bucket = {
396 let guard = self.token_buckets.read().await;
397 guard.get(GLOBAL_BUCKET_KEY).cloned()
398 };
399
400 if let Some(bucket_arc) = maybe_bucket {
401 let bucket = bucket_arc.lock().await;
402 BandwidthStats {
403 current_tokens: bucket.tokens as u64,
404 capacity: bucket.capacity as u64,
405 refill_rate_bytes_per_sec: bucket.refill_rate as u64,
406 }
407 } else {
408 BandwidthStats {
409 current_tokens: self.bandwidth_config.burst_capacity_bytes,
410 capacity: self.bandwidth_config.burst_capacity_bytes,
411 refill_rate_bytes_per_sec: self.bandwidth_config.max_bytes_per_sec,
412 }
413 }
414 }
415
416 pub async fn get_burst_loss_stats(&self) -> BurstLossStats {
418 let state = self.burst_loss_state.lock().await;
419 BurstLossStats {
420 in_burst: state.in_burst,
421 burst_start: state.burst_start,
422 recovery_start: state.recovery_start,
423 }
424 }
425
426 async fn get_or_create_bucket(
427 &self,
428 bucket_key: &str,
429 effective_limit: u64,
430 ) -> Arc<Mutex<TokenBucket>> {
431 if let Some(existing) = self.token_buckets.read().await.get(bucket_key).cloned() {
432 return existing;
433 }
434
435 let mut buckets = self.token_buckets.write().await;
436 buckets
437 .entry(bucket_key.to_string())
438 .or_insert_with(|| {
439 Arc::new(Mutex::new(TokenBucket::new(
440 self.bandwidth_config.burst_capacity_bytes,
441 effective_limit,
442 )))
443 })
444 .clone()
445 }
446
447 fn resolve_bandwidth_bucket(&self, tags: &[String]) -> (String, u64) {
448 if let Some((tag, limit)) = tags.iter().find_map(|tag| {
449 self.bandwidth_config.tag_overrides.get(tag).map(|limit| (tag.as_str(), *limit))
450 }) {
451 (format!("tag:{}", tag), limit)
452 } else {
453 (GLOBAL_BUCKET_KEY.to_string(), self.bandwidth_config.max_bytes_per_sec)
454 }
455 }
456}
457
/// Snapshot of the global token bucket, returned by
/// `TrafficShaper::get_bandwidth_stats`.
#[derive(Debug, Clone)]
pub struct BandwidthStats {
    /// Available tokens (bytes), truncated to an integer; equals the
    /// configured burst capacity when no global bucket exists yet.
    pub current_tokens: u64,
    /// Bucket capacity in bytes.
    pub capacity: u64,
    /// Refill rate in bytes per second.
    pub refill_rate_bytes_per_sec: u64,
}
465
/// Snapshot of the burst-loss state machine, returned by
/// `TrafficShaper::get_burst_loss_stats`.
#[derive(Debug, Clone)]
pub struct BurstLossStats {
    /// Whether a loss burst is currently active.
    pub in_burst: bool,
    /// When the active burst started, if any.
    pub burst_start: Option<Instant>,
    /// When the current recovery period started, if any.
    pub recovery_start: Option<Instant>,
}
473
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Shaper with a 1000 B/s global limit and a 100-byte burst; burst loss off.
    fn limited_shaper() -> TrafficShaper {
        TrafficShaper::new(TrafficShapingConfig {
            bandwidth: BandwidthConfig::new(1000, 100),
            burst_loss: BurstLossConfig::default(),
        })
    }

    #[tokio::test]
    async fn test_bandwidth_throttling() {
        let shaper = limited_shaper();

        // Fits in the initial burst: no waiting.
        let result = shaper.throttle_bandwidth(50, &[]).await;
        assert!(result.is_ok());

        // Needs ~30 more tokens at 1000 B/s, i.e. a ~30 ms wait.
        let start = Instant::now();
        let result = shaper.throttle_bandwidth(80, &[]).await;
        let elapsed = start.elapsed();
        assert!(result.is_ok());
        assert!(elapsed >= Duration::from_millis(30));
    }

    #[tokio::test]
    async fn test_transfer_larger_than_burst_capacity_fails() {
        let shaper = limited_shaper();
        // Tokens never exceed the 100-byte capacity, so 200 bytes can never be granted.
        assert!(shaper.throttle_bandwidth(200, &[]).await.is_err());
    }

    #[tokio::test]
    async fn test_disabled_shaper_is_noop() {
        let shaper = TrafficShaper::new(TrafficShapingConfig::default());
        assert!(shaper.throttle_bandwidth(u64::MAX, &[]).await.is_ok());
        assert!(!shaper.should_drop_packet(&[]).await);
        assert!(matches!(shaper.process_transfer(10, &[]).await, Ok(None)));
    }

    #[tokio::test]
    async fn test_bandwidth_stats() {
        let shaper = limited_shaper();

        // No bucket yet: stats come straight from the configuration.
        let before = shaper.get_bandwidth_stats().await;
        assert_eq!(before.current_tokens, 100);
        assert_eq!(before.capacity, 100);
        assert_eq!(before.refill_rate_bytes_per_sec, 1000);

        shaper
            .throttle_bandwidth(40, &[])
            .await
            .expect("40 bytes fit in the initial burst");
        let after = shaper.get_bandwidth_stats().await;
        assert_eq!(after.capacity, 100);
        assert_eq!(after.refill_rate_bytes_per_sec, 1000);
        assert!((60..=100).contains(&after.current_tokens));
    }

    #[tokio::test]
    async fn test_burst_loss() {
        let config = TrafficShapingConfig {
            bandwidth: BandwidthConfig::default(),
            burst_loss: BurstLossConfig::new(1.0, 1000, 1.0, 1000),
        };
        let shaper = TrafficShaper::new(config);

        // Probability 1.0 starts a burst immediately, and loss rate 1.0 drops everything.
        let should_drop = shaper.should_drop_packet(&[]).await;
        assert!(should_drop);

        for _ in 0..5 {
            let should_drop = shaper.should_drop_packet(&[]).await;
            assert!(should_drop);
        }
    }

    #[tokio::test]
    async fn test_burst_without_loss_never_drops() {
        let shaper = TrafficShaper::new(TrafficShapingConfig {
            bandwidth: BandwidthConfig::default(),
            burst_loss: BurstLossConfig::new(1.0, 1000, 0.0, 1000),
        });

        for _ in 0..5 {
            assert!(!shaper.should_drop_packet(&[]).await);
        }

        let stats = shaper.get_burst_loss_stats().await;
        assert!(stats.in_burst);
        assert!(stats.burst_start.is_some());
        assert!(stats.recovery_start.is_none());
    }

    #[tokio::test]
    async fn test_bandwidth_tag_override_with_global_unlimited() {
        let mut bandwidth = BandwidthConfig::default();
        bandwidth.enabled = true;
        bandwidth.max_bytes_per_sec = 0;
        bandwidth.burst_capacity_bytes = 100;
        bandwidth = bandwidth.with_tag_override("limited".to_string(), 100);

        let shaper = TrafficShaper::new(TrafficShapingConfig {
            bandwidth,
            burst_loss: BurstLossConfig::default(),
        });

        let tags = vec!["limited".to_string()];
        shaper
            .throttle_bandwidth(100, &tags)
            .await
            .expect("initial transfer should succeed immediately");

        let start = Instant::now();
        shaper
            .throttle_bandwidth(100, &tags)
            .await
            .expect("tag override should throttle but eventually succeed");
        assert!(
            start.elapsed() >= Duration::from_millis(900),
            "override-specific transfer should respect configured rate"
        );
    }

    #[test]
    fn test_bandwidth_config_overrides() {
        let mut config = BandwidthConfig::new(1000, 100);
        config = config.with_tag_override("high-priority".to_string(), 5000);

        assert_eq!(config.get_effective_limit(&[]), 1000);
        assert_eq!(config.get_effective_limit(&["high-priority".to_string()]), 5000);
        assert_eq!(
            config.get_effective_limit(&["low-priority".to_string(), "high-priority".to_string()]),
            5000
        );
    }

    #[test]
    fn test_burst_loss_new_clamps_probabilities() {
        let config = BurstLossConfig::new(1.5, 10, -0.5, 10);
        assert_eq!(config.burst_probability, 1.0);
        assert_eq!(config.loss_rate_during_burst, 0.0);
    }

    #[test]
    fn test_burst_loss_effective_config_override() {
        let override_cfg = BurstLossOverride {
            burst_probability: 0.8,
            burst_duration_ms: 2000,
            loss_rate_during_burst: 0.9,
            recovery_time_ms: 5000,
        };

        let config =
            BurstLossConfig::default().with_tag_override("flaky".to_string(), override_cfg.clone());

        let effective = config.effective_config(&["flaky".to_string()]);
        assert_eq!(effective.burst_probability, override_cfg.burst_probability);
        assert_eq!(effective.burst_duration_ms, override_cfg.burst_duration_ms);
        assert_eq!(effective.loss_rate_during_burst, override_cfg.loss_rate_during_burst);
        assert_eq!(effective.recovery_time_ms, override_cfg.recovery_time_ms);
    }
}