1use crate::{Error, Result};
9use rand::Rng;
10use std::borrow::Cow;
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tokio::sync::{Mutex, RwLock};
15
/// Key of the token bucket used for transfers whose tags match no bandwidth override.
const GLOBAL_BUCKET_KEY: &str = "__global__";
17
18#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
20#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
21pub struct BandwidthConfig {
22 pub enabled: bool,
24 pub max_bytes_per_sec: u64,
26 pub burst_capacity_bytes: u64,
28 pub tag_overrides: HashMap<String, u64>,
30}
31
32impl Default for BandwidthConfig {
33 fn default() -> Self {
34 Self {
35 enabled: false,
36 max_bytes_per_sec: 0, burst_capacity_bytes: 1024 * 1024, tag_overrides: HashMap::new(),
39 }
40 }
41}
42
43impl BandwidthConfig {
44 pub fn new(max_bytes_per_sec: u64, burst_capacity_bytes: u64) -> Self {
46 Self {
47 enabled: true,
48 max_bytes_per_sec,
49 burst_capacity_bytes,
50 tag_overrides: HashMap::new(),
51 }
52 }
53
54 pub fn with_tag_override(mut self, tag: String, max_bytes_per_sec: u64) -> Self {
56 self.tag_overrides.insert(tag, max_bytes_per_sec);
57 self
58 }
59
60 pub fn get_effective_limit(&self, tags: &[String]) -> u64 {
62 if let Some(&override_limit) = tags.iter().find_map(|tag| self.tag_overrides.get(tag)) {
64 return override_limit;
65 }
66 self.max_bytes_per_sec
67 }
68}
69
70#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
72#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
73pub struct BurstLossConfig {
74 pub enabled: bool,
76 pub burst_probability: f64,
78 pub burst_duration_ms: u64,
80 pub loss_rate_during_burst: f64,
82 pub recovery_time_ms: u64,
84 pub tag_overrides: HashMap<String, BurstLossOverride>,
86}
87
88#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
90#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
91pub struct BurstLossOverride {
92 pub burst_probability: f64,
94 pub burst_duration_ms: u64,
96 pub loss_rate_during_burst: f64,
98 pub recovery_time_ms: u64,
100}
101
102impl Default for BurstLossConfig {
103 fn default() -> Self {
104 Self {
105 enabled: false,
106 burst_probability: 0.1, burst_duration_ms: 5000, loss_rate_during_burst: 0.5, recovery_time_ms: 30000, tag_overrides: HashMap::new(),
111 }
112 }
113}
114
115impl BurstLossConfig {
116 pub fn new(
118 burst_probability: f64,
119 burst_duration_ms: u64,
120 loss_rate: f64,
121 recovery_time_ms: u64,
122 ) -> Self {
123 Self {
124 enabled: true,
125 burst_probability: burst_probability.clamp(0.0, 1.0),
126 burst_duration_ms,
127 loss_rate_during_burst: loss_rate.clamp(0.0, 1.0),
128 recovery_time_ms,
129 tag_overrides: HashMap::new(),
130 }
131 }
132
133 pub fn with_tag_override(mut self, tag: String, override_config: BurstLossOverride) -> Self {
135 self.tag_overrides.insert(tag, override_config);
136 self
137 }
138
139 pub fn effective_config<'a>(&'a self, tags: &[String]) -> Cow<'a, BurstLossConfig> {
141 if let Some(override_config) = tags.iter().find_map(|tag| self.tag_overrides.get(tag)) {
142 let mut temp_config = self.clone();
143 temp_config.burst_probability = override_config.burst_probability;
144 temp_config.burst_duration_ms = override_config.burst_duration_ms;
145 temp_config.loss_rate_during_burst = override_config.loss_rate_during_burst;
146 temp_config.recovery_time_ms = override_config.recovery_time_ms;
147 Cow::Owned(temp_config)
148 } else {
149 Cow::Borrowed(self)
150 }
151 }
152}
153
/// Byte-denominated token bucket used for bandwidth throttling.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens (bytes) currently available; never exceeds `capacity` after a refill.
    tokens: f64,
    /// Maximum number of tokens the bucket can hold (burst size).
    capacity: f64,
    /// Tokens added per second.
    refill_rate: f64,
    /// When `tokens` was last brought up to date.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a full bucket holding `capacity` tokens that refills at
    /// `refill_rate_bytes_per_sec` tokens per second.
    fn new(capacity: u64, refill_rate_bytes_per_sec: u64) -> Self {
        Self {
            tokens: capacity as f64,
            capacity: capacity as f64,
            refill_rate: refill_rate_bytes_per_sec as f64,
            last_refill: Instant::now(),
        }
    }

    /// Adds the tokens accrued since the last refill, capped at `capacity`.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.capacity);
        self.last_refill = now;
    }

    /// Refills, then takes `bytes` tokens if enough are available.
    ///
    /// Returns `true` on success; on failure the bucket is left untouched (apart from the refill).
    fn try_consume(&mut self, bytes: u64) -> bool {
        self.refill();
        let wanted = bytes as f64;
        if self.tokens >= wanted {
            self.tokens -= wanted;
            true
        } else {
            false
        }
    }

    /// Refills, then estimates how long until `bytes` tokens are available.
    ///
    /// Returns `Duration::ZERO` if they are available now. Returns `Duration::MAX` when the
    /// bucket never refills (`refill_rate == 0`) or the wait does not fit in a `Duration`;
    /// previously both cases panicked inside `Duration::from_secs_f64`.
    fn time_until_available(&mut self, bytes: u64) -> Duration {
        self.refill();
        if self.tokens >= bytes as f64 {
            return Duration::ZERO;
        }
        if self.refill_rate <= 0.0 {
            return Duration::MAX;
        }
        let tokens_needed = bytes as f64 - self.tokens;
        Duration::try_from_secs_f64(tokens_needed / self.refill_rate).unwrap_or(Duration::MAX)
    }
}
211
212#[derive(Debug)]
214struct BurstLossState {
215 in_burst: bool,
217 burst_start: Option<Instant>,
219 recovery_start: Option<Instant>,
221}
222
223impl BurstLossState {
224 fn new() -> Self {
225 Self {
226 in_burst: false,
227 burst_start: None,
228 recovery_start: None,
229 }
230 }
231
232 fn should_drop_packet(&mut self, config: &BurstLossConfig) -> bool {
234 if !config.enabled {
235 return false;
236 }
237
238 let now = Instant::now();
239
240 match (self.in_burst, self.burst_start, self.recovery_start) {
241 (true, Some(burst_start), _) => {
242 let burst_duration = now.duration_since(burst_start);
244 if burst_duration >= Duration::from_millis(config.burst_duration_ms) {
245 self.in_burst = false;
247 self.burst_start = None;
248 self.recovery_start = Some(now);
249 false } else {
251 let mut rng = rand::rng();
253 rng.random_bool(config.loss_rate_during_burst)
254 }
255 }
256 (true, None, _) => {
257 self.in_burst = false;
259 false
260 }
261 (false, _, Some(recovery_start)) => {
262 let recovery_duration = now.duration_since(recovery_start);
264 if recovery_duration >= Duration::from_millis(config.recovery_time_ms) {
265 self.recovery_start = None;
267 let mut rng = rand::rng();
269 if rng.random_bool(config.burst_probability) {
270 self.in_burst = true;
271 self.burst_start = Some(now);
272 rng.random_bool(config.loss_rate_during_burst)
274 } else {
275 false
276 }
277 } else {
278 false }
280 }
281 (false, _, None) => {
282 let mut rng = rand::rng();
284 if rng.random_bool(config.burst_probability) {
285 self.in_burst = true;
286 self.burst_start = Some(now);
287 rng.random_bool(config.loss_rate_during_burst)
288 } else {
289 false
290 }
291 }
292 }
293 }
294}
295
296#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
298#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
299pub struct TrafficShapingConfig {
300 pub bandwidth: BandwidthConfig,
302 pub burst_loss: BurstLossConfig,
304}
305
306#[derive(Debug, Clone)]
308pub struct TrafficShaper {
309 bandwidth_config: BandwidthConfig,
311 burst_loss_config: BurstLossConfig,
313 token_buckets: Arc<RwLock<HashMap<String, Arc<Mutex<TokenBucket>>>>>,
315 burst_loss_state: Arc<Mutex<BurstLossState>>,
317}
318
319impl TrafficShaper {
320 pub fn new(config: TrafficShapingConfig) -> Self {
322 Self {
323 bandwidth_config: config.bandwidth,
324 burst_loss_config: config.burst_loss,
325 token_buckets: Arc::new(RwLock::new(HashMap::new())),
326 burst_loss_state: Arc::new(Mutex::new(BurstLossState::new())),
327 }
328 }
329
330 pub async fn throttle_bandwidth(&self, data_size: u64, tags: &[String]) -> Result<()> {
332 if !self.bandwidth_config.enabled {
333 return Ok(());
334 }
335
336 let (bucket_key, effective_limit) = self.resolve_bandwidth_bucket(tags);
337
338 if effective_limit == 0 {
339 return Ok(());
340 }
341
342 let bucket_arc = self.get_or_create_bucket(&bucket_key, effective_limit).await;
343
344 {
345 let mut bucket = bucket_arc.lock().await;
346 if bucket.try_consume(data_size) {
347 return Ok(());
348 }
349
350 let wait_time = bucket.time_until_available(data_size);
351 drop(bucket);
352
353 if wait_time.is_zero() {
354 return Err(Error::generic(format!(
355 "Failed to acquire bandwidth tokens for {} bytes",
356 data_size
357 )));
358 }
359
360 tokio::time::sleep(wait_time).await;
361 }
362
363 let mut bucket = bucket_arc.lock().await;
364 if bucket.try_consume(data_size) {
365 Ok(())
366 } else {
367 Err(Error::generic(format!(
368 "Failed to acquire bandwidth tokens for {} bytes",
369 data_size
370 )))
371 }
372 }
373
374 pub async fn should_drop_packet(&self, tags: &[String]) -> bool {
376 if !self.burst_loss_config.enabled {
377 return false;
378 }
379
380 let effective_config = self.burst_loss_config.effective_config(tags);
381 let mut state = self.burst_loss_state.lock().await;
382 state.should_drop_packet(effective_config.as_ref())
383 }
384
385 pub async fn process_transfer(
387 &self,
388 data_size: u64,
389 tags: &[String],
390 ) -> Result<Option<Duration>> {
391 self.throttle_bandwidth(data_size, tags).await?;
393
394 if self.should_drop_packet(tags).await {
396 return Ok(Some(Duration::from_millis(100))); }
398
399 Ok(None)
400 }
401
402 pub async fn get_bandwidth_stats(&self) -> BandwidthStats {
404 let maybe_bucket = {
405 let guard = self.token_buckets.read().await;
406 guard.get(GLOBAL_BUCKET_KEY).cloned()
407 };
408
409 if let Some(bucket_arc) = maybe_bucket {
410 let bucket = bucket_arc.lock().await;
411 BandwidthStats {
412 current_tokens: bucket.tokens as u64,
413 capacity: bucket.capacity as u64,
414 refill_rate_bytes_per_sec: bucket.refill_rate as u64,
415 }
416 } else {
417 BandwidthStats {
418 current_tokens: self.bandwidth_config.burst_capacity_bytes,
419 capacity: self.bandwidth_config.burst_capacity_bytes,
420 refill_rate_bytes_per_sec: self.bandwidth_config.max_bytes_per_sec,
421 }
422 }
423 }
424
425 pub async fn get_burst_loss_stats(&self) -> BurstLossStats {
427 let state = self.burst_loss_state.lock().await;
428 BurstLossStats {
429 in_burst: state.in_burst,
430 burst_start: state.burst_start,
431 recovery_start: state.recovery_start,
432 }
433 }
434
435 async fn get_or_create_bucket(
436 &self,
437 bucket_key: &str,
438 effective_limit: u64,
439 ) -> Arc<Mutex<TokenBucket>> {
440 if let Some(existing) = self.token_buckets.read().await.get(bucket_key).cloned() {
441 return existing;
442 }
443
444 let mut buckets = self.token_buckets.write().await;
445 buckets
446 .entry(bucket_key.to_string())
447 .or_insert_with(|| {
448 Arc::new(Mutex::new(TokenBucket::new(
449 self.bandwidth_config.burst_capacity_bytes,
450 effective_limit,
451 )))
452 })
453 .clone()
454 }
455
456 fn resolve_bandwidth_bucket(&self, tags: &[String]) -> (String, u64) {
457 if let Some((tag, limit)) = tags.iter().find_map(|tag| {
458 self.bandwidth_config.tag_overrides.get(tag).map(|limit| (tag.as_str(), *limit))
459 }) {
460 (format!("tag:{}", tag), limit)
461 } else {
462 (GLOBAL_BUCKET_KEY.to_string(), self.bandwidth_config.max_bytes_per_sec)
463 }
464 }
465}
466
/// Snapshot of a token bucket as reported by `TrafficShaper::get_bandwidth_stats`.
#[derive(Clone, Debug)]
pub struct BandwidthStats {
    /// Whole tokens (bytes) available when the snapshot was taken.
    pub current_tokens: u64,
    /// Bucket size in bytes.
    pub capacity: u64,
    /// Refill rate in bytes per second.
    pub refill_rate_bytes_per_sec: u64,
}
477
/// Snapshot of the burst-loss state machine as reported by `TrafficShaper::get_burst_loss_stats`.
#[derive(Clone, Debug)]
pub struct BurstLossStats {
    /// Whether a burst was in progress when the snapshot was taken.
    pub in_burst: bool,
    /// Start time of the current burst, if any.
    pub burst_start: Option<Instant>,
    /// Start time of the current recovery window, if any.
    pub recovery_start: Option<Instant>,
}
488
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// A transfer that exceeds the remaining tokens waits for the bucket to refill.
    #[tokio::test]
    async fn test_bandwidth_throttling() {
        // 100-byte bucket refilled at 1000 B/s.
        let config = TrafficShapingConfig {
            bandwidth: BandwidthConfig::new(1000, 100),
            burst_loss: BurstLossConfig::default(),
        };
        let shaper = TrafficShaper::new(config);

        // Fits in the initially full bucket.
        let result = shaper.throttle_bandwidth(50, &[]).await;
        assert!(result.is_ok());

        // 50 tokens left and 80 needed: 30 missing bytes at 1000 B/s is at least 30 ms.
        let start = Instant::now();
        let result = shaper.throttle_bandwidth(80, &[]).await;
        let elapsed = start.elapsed();
        assert!(result.is_ok());
        assert!(elapsed >= Duration::from_millis(30));
    }

    /// A transfer larger than the burst capacity can never be admitted.
    #[tokio::test]
    async fn test_transfer_larger_than_burst_capacity_is_rejected() {
        let shaper = TrafficShaper::new(TrafficShapingConfig {
            bandwidth: BandwidthConfig::new(1000, 100),
            burst_loss: BurstLossConfig::default(),
        });

        let result = shaper.throttle_bandwidth(200, &[]).await;
        assert!(
            result.is_err(),
            "a 200-byte transfer can never fit in a 100-byte bucket"
        );
    }

    /// With throttling disabled, any transfer size passes straight through.
    #[tokio::test]
    async fn test_disabled_bandwidth_is_passthrough() {
        let shaper = TrafficShaper::new(TrafficShapingConfig::default());
        shaper
            .throttle_bandwidth(u64::MAX, &[])
            .await
            .expect("disabled throttling never fails");
    }

    /// With burst probability and loss rate both 1.0, every packet is dropped.
    #[tokio::test]
    async fn test_burst_loss() {
        let config = TrafficShapingConfig {
            bandwidth: BandwidthConfig::default(),
            burst_loss: BurstLossConfig::new(1.0, 1000, 1.0, 1000),
        };
        let shaper = TrafficShaper::new(config);

        // First packet starts a burst and is itself dropped.
        let should_drop = shaper.should_drop_packet(&[]).await;
        assert!(should_drop);

        // Subsequent packets fall inside the 1 s burst window.
        for _ in 0..5 {
            let should_drop = shaper.should_drop_packet(&[]).await;
            assert!(should_drop);
        }
    }

    /// A tag override is enforced even when the global limit is unlimited (0).
    #[tokio::test]
    async fn test_bandwidth_tag_override_with_global_unlimited() {
        let mut bandwidth = BandwidthConfig::default();
        bandwidth.enabled = true;
        bandwidth.max_bytes_per_sec = 0;
        bandwidth.burst_capacity_bytes = 100;
        bandwidth = bandwidth.with_tag_override("limited".to_string(), 100);

        let shaper = TrafficShaper::new(TrafficShapingConfig {
            bandwidth,
            burst_loss: BurstLossConfig::default(),
        });

        let tags = vec!["limited".to_string()];
        shaper
            .throttle_bandwidth(100, &tags)
            .await
            .expect("initial transfer should succeed immediately");

        // The bucket is now empty: 100 bytes at 100 B/s takes about one second.
        let start = Instant::now();
        shaper
            .throttle_bandwidth(100, &tags)
            .await
            .expect("tag override should throttle but eventually succeed");
        assert!(
            start.elapsed() >= Duration::from_millis(900),
            "override-specific transfer should respect configured rate"
        );
    }

    /// The first tag with an override determines the effective limit.
    #[test]
    fn test_bandwidth_config_overrides() {
        let mut config = BandwidthConfig::new(1000, 100);
        config = config.with_tag_override("high-priority".to_string(), 5000);

        assert_eq!(config.get_effective_limit(&[]), 1000);
        assert_eq!(config.get_effective_limit(&["high-priority".to_string()]), 5000);
        assert_eq!(
            config.get_effective_limit(&["low-priority".to_string(), "high-priority".to_string()]),
            5000
        );
    }

    /// A matching tag replaces all four burst parameters.
    #[test]
    fn test_burst_loss_effective_config_override() {
        let override_cfg = BurstLossOverride {
            burst_probability: 0.8,
            burst_duration_ms: 2000,
            loss_rate_during_burst: 0.9,
            recovery_time_ms: 5000,
        };

        let config =
            BurstLossConfig::default().with_tag_override("flaky".to_string(), override_cfg.clone());

        let effective = config.effective_config(&["flaky".to_string()]);
        assert_eq!(effective.burst_probability, override_cfg.burst_probability);
        assert_eq!(effective.burst_duration_ms, override_cfg.burst_duration_ms);
        assert_eq!(effective.loss_rate_during_burst, override_cfg.loss_rate_during_burst);
        assert_eq!(effective.recovery_time_ms, override_cfg.recovery_time_ms);
    }
}