1use crate::{Error, Result};
9use rand::Rng;
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13use tokio::sync::Mutex;
14
15#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
17pub struct BandwidthConfig {
18 pub enabled: bool,
20 pub max_bytes_per_sec: u64,
22 pub burst_capacity_bytes: u64,
24 pub tag_overrides: HashMap<String, u64>,
26}
27
28impl Default for BandwidthConfig {
29 fn default() -> Self {
30 Self {
31 enabled: false,
32 max_bytes_per_sec: 0, burst_capacity_bytes: 1024 * 1024, tag_overrides: HashMap::new(),
35 }
36 }
37}
38
39impl BandwidthConfig {
40 pub fn new(max_bytes_per_sec: u64, burst_capacity_bytes: u64) -> Self {
42 Self {
43 enabled: true,
44 max_bytes_per_sec,
45 burst_capacity_bytes,
46 tag_overrides: HashMap::new(),
47 }
48 }
49
50 pub fn with_tag_override(mut self, tag: String, max_bytes_per_sec: u64) -> Self {
52 self.tag_overrides.insert(tag, max_bytes_per_sec);
53 self
54 }
55
56 pub fn get_effective_limit(&self, tags: &[String]) -> u64 {
58 if let Some(&override_limit) = tags.iter().find_map(|tag| self.tag_overrides.get(tag)) {
60 return override_limit;
61 }
62 self.max_bytes_per_sec
63 }
64}
65
66#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
68pub struct BurstLossConfig {
69 pub enabled: bool,
71 pub burst_probability: f64,
73 pub burst_duration_ms: u64,
75 pub loss_rate_during_burst: f64,
77 pub recovery_time_ms: u64,
79 pub tag_overrides: HashMap<String, BurstLossOverride>,
81}
82
83#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
84pub struct BurstLossOverride {
85 pub burst_probability: f64,
86 pub burst_duration_ms: u64,
87 pub loss_rate_during_burst: f64,
88 pub recovery_time_ms: u64,
89}
90
91impl Default for BurstLossConfig {
92 fn default() -> Self {
93 Self {
94 enabled: false,
95 burst_probability: 0.1, burst_duration_ms: 5000, loss_rate_during_burst: 0.5, recovery_time_ms: 30000, tag_overrides: HashMap::new(),
100 }
101 }
102}
103
104impl BurstLossConfig {
105 pub fn new(
107 burst_probability: f64,
108 burst_duration_ms: u64,
109 loss_rate: f64,
110 recovery_time_ms: u64,
111 ) -> Self {
112 Self {
113 enabled: true,
114 burst_probability: burst_probability.clamp(0.0, 1.0),
115 burst_duration_ms,
116 loss_rate_during_burst: loss_rate.clamp(0.0, 1.0),
117 recovery_time_ms,
118 tag_overrides: HashMap::new(),
119 }
120 }
121
122 pub fn with_tag_override(mut self, tag: String, override_config: BurstLossOverride) -> Self {
124 self.tag_overrides.insert(tag, override_config);
125 self
126 }
127
128 pub fn get_effective_config(&self, tags: &[String]) -> &BurstLossConfig {
130 if let Some(override_config) = tags.iter().find_map(|tag| self.tag_overrides.get(tag)) {
132 let mut temp_config = self.clone();
135 temp_config.burst_probability = override_config.burst_probability;
136 temp_config.burst_duration_ms = override_config.burst_duration_ms;
137 temp_config.loss_rate_during_burst = override_config.loss_rate_during_burst;
138 temp_config.recovery_time_ms = override_config.recovery_time_ms;
139 return Box::leak(Box::new(temp_config));
140 }
141 self
142 }
143}
144
/// Token-bucket rate limiter measured in bytes.
///
/// Tokens accumulate at `refill_rate` per second, up to `capacity`.
///
/// A request larger than `capacity` is admitted once the bucket is full. It
/// leaves the bucket in debt (negative tokens), which later requests must
/// wait out. Oversized transfers are therefore throttled at the configured
/// rate instead of being impossible to satisfy.
#[derive(Debug)]
struct TokenBucket {
    /// Available tokens (bytes); negative while an oversized request is being repaid.
    tokens: f64,
    /// Maximum number of tokens the bucket can hold.
    capacity: f64,
    /// Tokens added per second.
    refill_rate: f64,
    /// When tokens were last credited.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full.
    fn new(capacity: u64, refill_rate_bytes_per_sec: u64) -> Self {
        Self {
            tokens: capacity as f64,
            capacity: capacity as f64,
            refill_rate: refill_rate_bytes_per_sec as f64,
            last_refill: Instant::now(),
        }
    }

    /// Credits tokens for the time elapsed since the last refill, capped at `capacity`.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.capacity);
        self.last_refill = now;
    }

    /// Tokens that must be available before `bytes` can be admitted.
    ///
    /// Capped at `capacity`; otherwise a request larger than the bucket could
    /// never be satisfied.
    fn required_tokens(&self, bytes: u64) -> f64 {
        (bytes as f64).min(self.capacity)
    }

    /// Takes `bytes` tokens if enough are available and returns whether it did.
    fn try_consume(&mut self, bytes: u64) -> bool {
        self.refill();
        if self.tokens >= self.required_tokens(bytes) {
            // May go negative for an oversized request; later refills repay the debt.
            self.tokens -= bytes as f64;
            true
        } else {
            false
        }
    }

    /// Time until `try_consume(bytes)` would succeed, assuming no other consumers.
    ///
    /// Returns `Duration::MAX` instead of panicking when the wait cannot be
    /// represented, for example with a zero refill rate.
    fn time_until_available(&mut self, bytes: u64) -> Duration {
        self.refill();
        let deficit = self.required_tokens(bytes) - self.tokens;
        if deficit <= 0.0 {
            return Duration::ZERO;
        }
        Duration::try_from_secs_f64(deficit / self.refill_rate).unwrap_or(Duration::MAX)
    }
}
202
/// Mutable state of the burst-loss state machine.
#[derive(Debug)]
struct BurstLossState {
    /// `true` while a loss burst is in progress.
    in_burst: bool,
    /// Start of the current burst, set when a burst begins.
    burst_start: Option<Instant>,
    /// Start of the recovery period that follows a burst, if one is running.
    recovery_start: Option<Instant>,
}
213
214impl BurstLossState {
215 fn new() -> Self {
216 Self {
217 in_burst: false,
218 burst_start: None,
219 recovery_start: None,
220 }
221 }
222
223 fn should_drop_packet(&mut self, config: &BurstLossConfig) -> bool {
225 if !config.enabled {
226 return false;
227 }
228
229 let now = Instant::now();
230
231 match (self.in_burst, self.burst_start, self.recovery_start) {
232 (true, Some(burst_start), _) => {
233 let burst_duration = now.duration_since(burst_start);
235 if burst_duration >= Duration::from_millis(config.burst_duration_ms) {
236 self.in_burst = false;
238 self.burst_start = None;
239 self.recovery_start = Some(now);
240 false } else {
242 let mut rng = rand::rng();
244 rng.random_bool(config.loss_rate_during_burst)
245 }
246 }
247 (true, None, _) => {
248 self.in_burst = false;
250 false
251 }
252 (false, _, Some(recovery_start)) => {
253 let recovery_duration = now.duration_since(recovery_start);
255 if recovery_duration >= Duration::from_millis(config.recovery_time_ms) {
256 self.recovery_start = None;
258 let mut rng = rand::rng();
260 if rng.random_bool(config.burst_probability) {
261 self.in_burst = true;
262 self.burst_start = Some(now);
263 rng.random_bool(config.loss_rate_during_burst)
265 } else {
266 false
267 }
268 } else {
269 false }
271 }
272 (false, _, None) => {
273 let mut rng = rand::rng();
275 if rng.random_bool(config.burst_probability) {
276 self.in_burst = true;
277 self.burst_start = Some(now);
278 rng.random_bool(config.loss_rate_during_burst)
279 } else {
280 false
281 }
282 }
283 }
284 }
285}
286
287#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
289pub struct TrafficShapingConfig {
290 pub bandwidth: BandwidthConfig,
292 pub burst_loss: BurstLossConfig,
294}
295
296#[derive(Debug, Clone)]
298pub struct TrafficShaper {
299 bandwidth_config: BandwidthConfig,
301 burst_loss_config: BurstLossConfig,
303 token_bucket: Arc<Mutex<TokenBucket>>,
305 burst_loss_state: Arc<Mutex<BurstLossState>>,
307}
308
309impl TrafficShaper {
310 pub fn new(config: TrafficShapingConfig) -> Self {
312 let token_bucket = if config.bandwidth.enabled && config.bandwidth.max_bytes_per_sec > 0 {
313 TokenBucket::new(
314 config.bandwidth.burst_capacity_bytes,
315 config.bandwidth.max_bytes_per_sec,
316 )
317 } else {
318 TokenBucket::new(u64::MAX, u64::MAX)
320 };
321
322 Self {
323 bandwidth_config: config.bandwidth,
324 burst_loss_config: config.burst_loss,
325 token_bucket: Arc::new(Mutex::new(token_bucket)),
326 burst_loss_state: Arc::new(Mutex::new(BurstLossState::new())),
327 }
328 }
329
330 pub async fn throttle_bandwidth(&self, data_size: u64, tags: &[String]) -> Result<()> {
332 if !self.bandwidth_config.enabled {
333 return Ok(());
334 }
335
336 let effective_limit = self.bandwidth_config.get_effective_limit(tags);
337 if effective_limit == 0 {
338 return Ok(());
340 }
341
342 let mut bucket = self.token_bucket.lock().await;
343
344 if !bucket.try_consume(data_size) {
345 let wait_time = bucket.time_until_available(data_size);
347 if !wait_time.is_zero() {
348 tokio::time::sleep(wait_time).await;
349 }
350 if !bucket.try_consume(data_size) {
352 return Err(Error::generic(format!(
353 "Failed to acquire bandwidth tokens for {} bytes",
354 data_size
355 )));
356 }
357 }
358
359 Ok(())
360 }
361
362 pub async fn should_drop_packet(&self, tags: &[String]) -> bool {
364 if !self.burst_loss_config.enabled {
365 return false;
366 }
367
368 let effective_config = self.burst_loss_config.get_effective_config(tags);
369 let mut state = self.burst_loss_state.lock().await;
370 state.should_drop_packet(effective_config)
371 }
372
373 pub async fn process_transfer(
375 &self,
376 data_size: u64,
377 tags: &[String],
378 ) -> Result<Option<Duration>> {
379 self.throttle_bandwidth(data_size, tags).await?;
381
382 if self.should_drop_packet(tags).await {
384 return Ok(Some(Duration::from_millis(100))); }
386
387 Ok(None)
388 }
389
390 pub async fn get_bandwidth_stats(&self) -> BandwidthStats {
392 let bucket = self.token_bucket.lock().await;
393 BandwidthStats {
394 current_tokens: bucket.tokens as u64,
395 capacity: bucket.capacity as u64,
396 refill_rate_bytes_per_sec: bucket.refill_rate as u64,
397 }
398 }
399
400 pub async fn get_burst_loss_stats(&self) -> BurstLossStats {
402 let state = self.burst_loss_state.lock().await;
403 BurstLossStats {
404 in_burst: state.in_burst,
405 burst_start: state.burst_start,
406 recovery_start: state.recovery_start,
407 }
408 }
409}
410
/// Snapshot of a `TrafficShaper`'s token bucket, as returned by `get_bandwidth_stats`.
#[derive(Debug, Clone)]
pub struct BandwidthStats {
    /// Tokens (bytes) available at snapshot time.
    pub current_tokens: u64,
    /// Bucket capacity in bytes.
    pub capacity: u64,
    /// Refill rate in bytes per second.
    pub refill_rate_bytes_per_sec: u64,
}
418
/// Snapshot of the burst-loss state machine, as returned by `get_burst_loss_stats`.
#[derive(Debug, Clone)]
pub struct BurstLossStats {
    /// Whether a loss burst was active at snapshot time.
    pub in_burst: bool,
    /// When the active burst started, if any.
    pub burst_start: Option<Instant>,
    /// When the current recovery period started, if any.
    pub recovery_start: Option<Instant>,
}
426
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a shaper from the two configuration sections.
    fn build_shaper(bandwidth: BandwidthConfig, burst_loss: BurstLossConfig) -> TrafficShaper {
        TrafficShaper::new(TrafficShapingConfig { bandwidth, burst_loss })
    }

    /// A request larger than the remaining burst tokens waits for a refill.
    #[tokio::test]
    async fn test_bandwidth_throttling() {
        // 1000 bytes/s with a 100-byte burst capacity.
        let shaper = build_shaper(BandwidthConfig::new(1000, 100), BurstLossConfig::default());

        // Fits in the initial burst, so there is no wait.
        assert!(shaper.throttle_bandwidth(50, &[]).await.is_ok());

        // About 50 tokens remain, so about 30 more must refill at 1000 bytes/s.
        let started = Instant::now();
        let outcome = shaper.throttle_bandwidth(80, &[]).await;
        let waited = started.elapsed();
        assert!(outcome.is_ok());
        assert!(waited >= Duration::from_millis(30));
    }

    /// With certain bursts and certain loss, every packet inside the burst is dropped.
    #[tokio::test]
    async fn test_burst_loss() {
        let burst_loss = BurstLossConfig::new(1.0, 1000, 1.0, 1000);
        let shaper = build_shaper(BandwidthConfig::default(), burst_loss);

        // The first packet starts the burst; the next five fall inside the 1 s window.
        for _ in 0..6 {
            assert!(shaper.should_drop_packet(&[]).await);
        }
    }

    /// The first tag with an override determines the limit.
    #[test]
    fn test_bandwidth_config_overrides() {
        let tag = |name: &str| name.to_string();
        let config =
            BandwidthConfig::new(1000, 100).with_tag_override(tag("high-priority"), 5000);

        assert_eq!(config.get_effective_limit(&[]), 1000);
        assert_eq!(config.get_effective_limit(&[tag("high-priority")]), 5000);
        // Tags without an override are skipped.
        assert_eq!(
            config.get_effective_limit(&[tag("low-priority"), tag("high-priority")]),
            5000
        );
    }
}