mockforge_core/
traffic_shaping.rs

1//! Traffic shaping beyond latency simulation
2//!
3//! This module provides advanced traffic shaping capabilities including:
4//! - Bandwidth throttling using token bucket algorithm
5//! - Burst packet loss simulation
6//! - Integration with existing latency and fault injection
7
8use crate::{Error, Result};
9use rand::Rng;
10use std::borrow::Cow;
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tokio::sync::{Mutex, RwLock};
15
/// Map key of the shared token bucket used when none of a transfer's tags has a
/// bandwidth override.
const GLOBAL_BUCKET_KEY: &str = "__global__";
17
/// Bandwidth throttling configuration
///
/// Describes a byte-rate limit, the burst allowance of the token bucket that
/// enforces it, and optional per-tag rate overrides.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct BandwidthConfig {
    /// Enable bandwidth throttling
    pub enabled: bool,
    /// Maximum bandwidth in bytes per second (0 = unlimited)
    pub max_bytes_per_sec: u64,
    /// Token bucket capacity in bytes (burst allowance)
    pub burst_capacity_bytes: u64,
    /// Tag-based bandwidth overrides: tag name -> bytes-per-second limit that
    /// replaces `max_bytes_per_sec` for transfers carrying that tag
    pub tag_overrides: HashMap<String, u64>,
}
30
31impl Default for BandwidthConfig {
32    fn default() -> Self {
33        Self {
34            enabled: false,
35            max_bytes_per_sec: 0,              // Unlimited
36            burst_capacity_bytes: 1024 * 1024, // 1MB burst capacity
37            tag_overrides: HashMap::new(),
38        }
39    }
40}
41
42impl BandwidthConfig {
43    /// Create a new bandwidth configuration
44    pub fn new(max_bytes_per_sec: u64, burst_capacity_bytes: u64) -> Self {
45        Self {
46            enabled: true,
47            max_bytes_per_sec,
48            burst_capacity_bytes,
49            tag_overrides: HashMap::new(),
50        }
51    }
52
53    /// Add a tag-based bandwidth override
54    pub fn with_tag_override(mut self, tag: String, max_bytes_per_sec: u64) -> Self {
55        self.tag_overrides.insert(tag, max_bytes_per_sec);
56        self
57    }
58
59    /// Get the effective bandwidth limit for the given tags
60    pub fn get_effective_limit(&self, tags: &[String]) -> u64 {
61        // Check for tag overrides (use the first matching tag)
62        if let Some(&override_limit) = tags.iter().find_map(|tag| self.tag_overrides.get(tag)) {
63            return override_limit;
64        }
65        self.max_bytes_per_sec
66    }
67}
68
/// Burst loss configuration for simulating intermittent connectivity issues
///
/// Loss is modelled as alternating phases: a burst during which packets are
/// dropped at `loss_rate_during_burst`, followed by a recovery period during
/// which nothing is dropped.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct BurstLossConfig {
    /// Enable burst loss simulation
    pub enabled: bool,
    /// Probability of entering a loss burst (0.0 to 1.0)
    pub burst_probability: f64,
    /// Duration of loss burst in milliseconds
    pub burst_duration_ms: u64,
    /// Packet loss rate during burst (0.0 to 1.0)
    pub loss_rate_during_burst: f64,
    /// Recovery time between bursts in milliseconds
    pub recovery_time_ms: u64,
    /// Tag-based burst loss overrides: tag name -> replacement values for the
    /// four burst parameters above
    pub tag_overrides: HashMap<String, BurstLossOverride>,
}
85
/// Tag-specific burst loss configuration override
///
/// Carries the same four tunables as [`BurstLossConfig`], without the
/// `enabled` flag or nested overrides.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct BurstLossOverride {
    /// Probability of entering a loss burst (0.0 to 1.0)
    pub burst_probability: f64,
    /// Duration of loss burst in milliseconds
    pub burst_duration_ms: u64,
    /// Packet loss rate during burst (0.0 to 1.0)
    pub loss_rate_during_burst: f64,
    /// Recovery time between bursts in milliseconds
    pub recovery_time_ms: u64,
}
98
99impl Default for BurstLossConfig {
100    fn default() -> Self {
101        Self {
102            enabled: false,
103            burst_probability: 0.1,      // 10% chance of burst
104            burst_duration_ms: 5000,     // 5 second bursts
105            loss_rate_during_burst: 0.5, // 50% loss during burst
106            recovery_time_ms: 30000,     // 30 second recovery
107            tag_overrides: HashMap::new(),
108        }
109    }
110}
111
112impl BurstLossConfig {
113    /// Create a new burst loss configuration
114    pub fn new(
115        burst_probability: f64,
116        burst_duration_ms: u64,
117        loss_rate: f64,
118        recovery_time_ms: u64,
119    ) -> Self {
120        Self {
121            enabled: true,
122            burst_probability: burst_probability.clamp(0.0, 1.0),
123            burst_duration_ms,
124            loss_rate_during_burst: loss_rate.clamp(0.0, 1.0),
125            recovery_time_ms,
126            tag_overrides: HashMap::new(),
127        }
128    }
129
130    /// Add a tag-based burst loss override
131    pub fn with_tag_override(mut self, tag: String, override_config: BurstLossOverride) -> Self {
132        self.tag_overrides.insert(tag, override_config);
133        self
134    }
135
136    /// Get the effective burst loss config for the given tags
137    pub fn effective_config<'a>(&'a self, tags: &[String]) -> Cow<'a, BurstLossConfig> {
138        if let Some(override_config) = tags.iter().find_map(|tag| self.tag_overrides.get(tag)) {
139            let mut temp_config = self.clone();
140            temp_config.burst_probability = override_config.burst_probability;
141            temp_config.burst_duration_ms = override_config.burst_duration_ms;
142            temp_config.loss_rate_during_burst = override_config.loss_rate_during_burst;
143            temp_config.recovery_time_ms = override_config.recovery_time_ms;
144            Cow::Owned(temp_config)
145        } else {
146            Cow::Borrowed(self)
147        }
148    }
149}
150
/// Token bucket for bandwidth throttling
///
/// One token represents one byte. Tokens accumulate at `refill_rate` per
/// second up to `capacity`, and are removed when a transfer is admitted.
#[derive(Debug)]
struct TokenBucket {
    /// Current number of tokens (bytes that can be sent)
    tokens: f64,
    /// Maximum capacity of the bucket
    capacity: f64,
    /// Rate of token replenishment (tokens per second)
    refill_rate: f64,
    /// Last refill timestamp
    last_refill: Instant,
}

impl TokenBucket {
    /// Create a new token bucket that starts full.
    fn new(capacity: u64, refill_rate_bytes_per_sec: u64) -> Self {
        Self {
            tokens: capacity as f64,
            capacity: capacity as f64,
            refill_rate: refill_rate_bytes_per_sec as f64,
            last_refill: Instant::now(),
        }
    }

    /// Refill tokens based on the time elapsed since the previous refill,
    /// never exceeding `capacity`.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        let tokens_to_add = elapsed * self.refill_rate;

        self.tokens = (self.tokens + tokens_to_add).min(self.capacity);
        self.last_refill = now;
    }

    /// Try to consume tokens for the given number of bytes.
    ///
    /// Returns `true` and deducts the tokens when enough are available after
    /// refilling; otherwise returns `false` and leaves the balance untouched.
    fn try_consume(&mut self, bytes: u64) -> bool {
        self.refill();
        let requested = bytes as f64;
        if self.tokens >= requested {
            self.tokens -= requested;
            true
        } else {
            false
        }
    }

    /// Get the time to wait until enough tokens are available.
    ///
    /// Returns `Duration::ZERO` when the request can be satisfied right now.
    /// The estimate is purely rate-based and does not consider whether `bytes`
    /// exceeds `capacity`. When the wait cannot be represented (a refill rate
    /// of zero, or a deficit so large that the wait overflows `Duration`),
    /// `Duration::MAX` is returned instead of panicking.
    fn time_until_available(&mut self, bytes: u64) -> Duration {
        self.refill();
        let requested = bytes as f64;
        if self.tokens >= requested {
            return Duration::ZERO;
        }

        let tokens_needed = requested - self.tokens;
        let seconds_needed = tokens_needed / self.refill_rate;
        // `Duration::from_secs_f64` panics on infinite or overflowing input;
        // the fallible constructor lets us saturate instead.
        Duration::try_from_secs_f64(seconds_needed).unwrap_or(Duration::MAX)
    }
}
208
209/// Burst loss state machine
210#[derive(Debug)]
211struct BurstLossState {
212    /// Whether currently in a loss burst
213    in_burst: bool,
214    /// When the current burst started
215    burst_start: Option<Instant>,
216    /// When the current recovery period started
217    recovery_start: Option<Instant>,
218}
219
220impl BurstLossState {
221    fn new() -> Self {
222        Self {
223            in_burst: false,
224            burst_start: None,
225            recovery_start: None,
226        }
227    }
228
229    /// Determine if a packet should be lost based on current state
230    fn should_drop_packet(&mut self, config: &BurstLossConfig) -> bool {
231        if !config.enabled {
232            return false;
233        }
234
235        let now = Instant::now();
236
237        match (self.in_burst, self.burst_start, self.recovery_start) {
238            (true, Some(burst_start), _) => {
239                // Currently in burst - check if burst should end
240                let burst_duration = now.duration_since(burst_start);
241                if burst_duration >= Duration::from_millis(config.burst_duration_ms) {
242                    // End burst and start recovery
243                    self.in_burst = false;
244                    self.burst_start = None;
245                    self.recovery_start = Some(now);
246                    false // Don't drop this packet
247                } else {
248                    // Still in burst - apply loss rate
249                    let mut rng = rand::rng();
250                    rng.random_bool(config.loss_rate_during_burst)
251                }
252            }
253            (true, None, _) => {
254                // Invalid state: in burst but no burst start time - reset to normal
255                self.in_burst = false;
256                false
257            }
258            (false, _, Some(recovery_start)) => {
259                // In recovery - check if recovery should end
260                let recovery_duration = now.duration_since(recovery_start);
261                if recovery_duration >= Duration::from_millis(config.recovery_time_ms) {
262                    // End recovery
263                    self.recovery_start = None;
264                    // Check if we should start a new burst
265                    let mut rng = rand::rng();
266                    if rng.random_bool(config.burst_probability) {
267                        self.in_burst = true;
268                        self.burst_start = Some(now);
269                        // Apply loss rate for the first packet of the burst
270                        rng.random_bool(config.loss_rate_during_burst)
271                    } else {
272                        false
273                    }
274                } else {
275                    false // Still in recovery
276                }
277            }
278            (false, _, None) => {
279                // Not in burst or recovery - check if we should start a burst
280                let mut rng = rand::rng();
281                if rng.random_bool(config.burst_probability) {
282                    self.in_burst = true;
283                    self.burst_start = Some(now);
284                    rng.random_bool(config.loss_rate_during_burst)
285                } else {
286                    false
287                }
288            }
289        }
290    }
291}
292
/// Traffic shaping configuration combining all features
///
/// The derived `Default` uses the defaults of both parts, which leaves
/// bandwidth throttling and burst loss disabled.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
pub struct TrafficShapingConfig {
    /// Bandwidth throttling configuration
    pub bandwidth: BandwidthConfig,
    /// Burst loss configuration
    pub burst_loss: BurstLossConfig,
}
301
/// Main traffic shaper combining bandwidth throttling and burst loss
///
/// The mutable state lives behind `Arc`s, so a cloned shaper shares its token
/// buckets and burst loss state with the original.
#[derive(Debug, Clone)]
pub struct TrafficShaper {
    /// Bandwidth configuration
    bandwidth_config: BandwidthConfig,
    /// Burst loss configuration
    burst_loss_config: BurstLossConfig,
    /// Token buckets keyed by effective tag/group; created on first use
    token_buckets: Arc<RwLock<HashMap<String, Arc<Mutex<TokenBucket>>>>>,
    /// Burst loss state (a single state machine shared by all tags)
    burst_loss_state: Arc<Mutex<BurstLossState>>,
}
314
315impl TrafficShaper {
316    /// Create a new traffic shaper
317    pub fn new(config: TrafficShapingConfig) -> Self {
318        Self {
319            bandwidth_config: config.bandwidth,
320            burst_loss_config: config.burst_loss,
321            token_buckets: Arc::new(RwLock::new(HashMap::new())),
322            burst_loss_state: Arc::new(Mutex::new(BurstLossState::new())),
323        }
324    }
325
326    /// Apply bandwidth throttling to a data transfer
327    pub async fn throttle_bandwidth(&self, data_size: u64, tags: &[String]) -> Result<()> {
328        if !self.bandwidth_config.enabled {
329            return Ok(());
330        }
331
332        let (bucket_key, effective_limit) = self.resolve_bandwidth_bucket(tags);
333
334        if effective_limit == 0 {
335            return Ok(());
336        }
337
338        let bucket_arc = self.get_or_create_bucket(&bucket_key, effective_limit).await;
339
340        {
341            let mut bucket = bucket_arc.lock().await;
342            if bucket.try_consume(data_size) {
343                return Ok(());
344            }
345
346            let wait_time = bucket.time_until_available(data_size);
347            drop(bucket);
348
349            if wait_time.is_zero() {
350                return Err(Error::generic(format!(
351                    "Failed to acquire bandwidth tokens for {} bytes",
352                    data_size
353                )));
354            }
355
356            tokio::time::sleep(wait_time).await;
357        }
358
359        let mut bucket = bucket_arc.lock().await;
360        if bucket.try_consume(data_size) {
361            Ok(())
362        } else {
363            Err(Error::generic(format!(
364                "Failed to acquire bandwidth tokens for {} bytes",
365                data_size
366            )))
367        }
368    }
369
370    /// Check if a packet should be dropped due to burst loss
371    pub async fn should_drop_packet(&self, tags: &[String]) -> bool {
372        if !self.burst_loss_config.enabled {
373            return false;
374        }
375
376        let effective_config = self.burst_loss_config.effective_config(tags);
377        let mut state = self.burst_loss_state.lock().await;
378        state.should_drop_packet(effective_config.as_ref())
379    }
380
381    /// Process a data transfer with both bandwidth throttling and burst loss
382    pub async fn process_transfer(
383        &self,
384        data_size: u64,
385        tags: &[String],
386    ) -> Result<Option<Duration>> {
387        // First, apply bandwidth throttling
388        self.throttle_bandwidth(data_size, tags).await?;
389
390        // Then, check for burst loss
391        if self.should_drop_packet(tags).await {
392            return Ok(Some(Duration::from_millis(100))); // Simulate packet timeout
393        }
394
395        Ok(None)
396    }
397
398    /// Get current bandwidth usage statistics
399    pub async fn get_bandwidth_stats(&self) -> BandwidthStats {
400        let maybe_bucket = {
401            let guard = self.token_buckets.read().await;
402            guard.get(GLOBAL_BUCKET_KEY).cloned()
403        };
404
405        if let Some(bucket_arc) = maybe_bucket {
406            let bucket = bucket_arc.lock().await;
407            BandwidthStats {
408                current_tokens: bucket.tokens as u64,
409                capacity: bucket.capacity as u64,
410                refill_rate_bytes_per_sec: bucket.refill_rate as u64,
411            }
412        } else {
413            BandwidthStats {
414                current_tokens: self.bandwidth_config.burst_capacity_bytes,
415                capacity: self.bandwidth_config.burst_capacity_bytes,
416                refill_rate_bytes_per_sec: self.bandwidth_config.max_bytes_per_sec,
417            }
418        }
419    }
420
421    /// Get current burst loss state
422    pub async fn get_burst_loss_stats(&self) -> BurstLossStats {
423        let state = self.burst_loss_state.lock().await;
424        BurstLossStats {
425            in_burst: state.in_burst,
426            burst_start: state.burst_start,
427            recovery_start: state.recovery_start,
428        }
429    }
430
431    async fn get_or_create_bucket(
432        &self,
433        bucket_key: &str,
434        effective_limit: u64,
435    ) -> Arc<Mutex<TokenBucket>> {
436        if let Some(existing) = self.token_buckets.read().await.get(bucket_key).cloned() {
437            return existing;
438        }
439
440        let mut buckets = self.token_buckets.write().await;
441        buckets
442            .entry(bucket_key.to_string())
443            .or_insert_with(|| {
444                Arc::new(Mutex::new(TokenBucket::new(
445                    self.bandwidth_config.burst_capacity_bytes,
446                    effective_limit,
447                )))
448            })
449            .clone()
450    }
451
452    fn resolve_bandwidth_bucket(&self, tags: &[String]) -> (String, u64) {
453        if let Some((tag, limit)) = tags.iter().find_map(|tag| {
454            self.bandwidth_config.tag_overrides.get(tag).map(|limit| (tag.as_str(), *limit))
455        }) {
456            (format!("tag:{}", tag), limit)
457        } else {
458            (GLOBAL_BUCKET_KEY.to_string(), self.bandwidth_config.max_bytes_per_sec)
459        }
460    }
461}
462
/// Bandwidth usage statistics for the token bucket
///
/// All values are whole bytes; fractional token counts are truncated.
#[derive(Debug, Clone)]
pub struct BandwidthStats {
    /// Current number of available tokens (bytes that can be sent)
    pub current_tokens: u64,
    /// Maximum token bucket capacity (burst allowance)
    pub capacity: u64,
    /// Token refill rate in bytes per second
    pub refill_rate_bytes_per_sec: u64,
}
473
/// Burst loss state statistics
///
/// A point-in-time copy of the burst loss state machine's fields.
#[derive(Debug, Clone)]
pub struct BurstLossStats {
    /// Whether currently in a loss burst period
    pub in_burst: bool,
    /// Timestamp when the current burst started (if in burst)
    pub burst_start: Option<Instant>,
    /// Timestamp when recovery period started (if recovering)
    pub recovery_start: Option<Instant>,
}
484
/// Unit tests for the configuration helpers and the shaper's throttling and
/// burst loss behavior. The async tests measure real elapsed time.
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// A transfer that exceeds the remaining tokens must wait for a refill.
    #[tokio::test]
    async fn test_bandwidth_throttling() {
        let config = TrafficShapingConfig {
            bandwidth: BandwidthConfig::new(1000, 100), // 1000 bytes/sec, 100 byte burst
            burst_loss: BurstLossConfig::default(),
        };
        let shaper = TrafficShaper::new(config);

        // Small transfer should succeed immediately
        let result = shaper.throttle_bandwidth(50, &[]).await;
        assert!(result.is_ok());

        // Second transfer fits the burst capacity but not the 50 tokens left
        let start = Instant::now();
        let result = shaper.throttle_bandwidth(80, &[]).await; // 50 + 80 = 130 total, need to wait for refill
        let elapsed = start.elapsed();
        assert!(result.is_ok());
        // Should have waited at least some time due to throttling
        assert!(elapsed >= Duration::from_millis(30)); // 30 missing bytes at 1000 bytes/sec take ~30ms to refill
    }

    /// With 100% burst probability and 100% loss, every packet is dropped
    /// while the burst lasts.
    #[tokio::test]
    async fn test_burst_loss() {
        let config = TrafficShapingConfig {
            bandwidth: BandwidthConfig::default(),
            burst_loss: BurstLossConfig::new(1.0, 1000, 1.0, 1000), // 100% burst probability, 100% loss
        };
        let shaper = TrafficShaper::new(config);

        // First packet should trigger burst and be dropped
        let should_drop = shaper.should_drop_packet(&[]).await;
        assert!(should_drop);

        // Subsequent packets in burst should also be dropped
        for _ in 0..5 {
            let should_drop = shaper.should_drop_packet(&[]).await;
            assert!(should_drop);
        }
    }

    /// A tag override must be enforced even when the global limit is 0
    /// (unlimited).
    #[tokio::test]
    async fn test_bandwidth_tag_override_with_global_unlimited() {
        let mut bandwidth = BandwidthConfig::default();
        bandwidth.enabled = true;
        bandwidth.max_bytes_per_sec = 0;
        bandwidth.burst_capacity_bytes = 100;
        bandwidth = bandwidth.with_tag_override("limited".to_string(), 100);

        let shaper = TrafficShaper::new(TrafficShapingConfig {
            bandwidth,
            burst_loss: BurstLossConfig::default(),
        });

        // The first transfer drains the full 100-byte burst allowance.
        let tags = vec!["limited".to_string()];
        shaper
            .throttle_bandwidth(100, &tags)
            .await
            .expect("initial transfer should succeed immediately");

        // Refilling 100 bytes at 100 bytes/sec takes about one second.
        let start = Instant::now();
        shaper
            .throttle_bandwidth(100, &tags)
            .await
            .expect("tag override should throttle but eventually succeed");
        assert!(
            start.elapsed() >= Duration::from_millis(900),
            "override-specific transfer should respect configured rate"
        );
    }

    /// Override lookup falls back to the base limit and skips tags that have
    /// no override.
    #[test]
    fn test_bandwidth_config_overrides() {
        let mut config = BandwidthConfig::new(1000, 100);
        config = config.with_tag_override("high-priority".to_string(), 5000);

        assert_eq!(config.get_effective_limit(&[]), 1000);
        assert_eq!(config.get_effective_limit(&["high-priority".to_string()]), 5000);
        assert_eq!(
            config.get_effective_limit(&["low-priority".to_string(), "high-priority".to_string()]),
            5000
        );
    }

    /// The effective config for a matching tag carries all four override
    /// values.
    #[test]
    fn test_burst_loss_effective_config_override() {
        let override_cfg = BurstLossOverride {
            burst_probability: 0.8,
            burst_duration_ms: 2000,
            loss_rate_during_burst: 0.9,
            recovery_time_ms: 5000,
        };

        let config =
            BurstLossConfig::default().with_tag_override("flaky".to_string(), override_cfg.clone());

        let effective = config.effective_config(&["flaky".to_string()]);
        assert_eq!(effective.burst_probability, override_cfg.burst_probability);
        assert_eq!(effective.burst_duration_ms, override_cfg.burst_duration_ms);
        assert_eq!(effective.loss_rate_during_burst, override_cfg.loss_rate_during_burst);
        assert_eq!(effective.recovery_time_ms, override_cfg.recovery_time_ms);
    }
}