mockforge_core/
traffic_shaping.rs

1//! Traffic shaping beyond latency simulation
2//!
3//! This module provides advanced traffic shaping capabilities including:
4//! - Bandwidth throttling using token bucket algorithm
5//! - Burst packet loss simulation
6//! - Integration with existing latency and fault injection
7
8use crate::{Error, Result};
9use rand::Rng;
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13use tokio::sync::Mutex;
14
15/// Bandwidth throttling configuration
16#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
17pub struct BandwidthConfig {
18    /// Enable bandwidth throttling
19    pub enabled: bool,
20    /// Maximum bandwidth in bytes per second (0 = unlimited)
21    pub max_bytes_per_sec: u64,
22    /// Token bucket capacity in bytes (burst allowance)
23    pub burst_capacity_bytes: u64,
24    /// Tag-based bandwidth overrides
25    pub tag_overrides: HashMap<String, u64>,
26}
27
28impl Default for BandwidthConfig {
29    fn default() -> Self {
30        Self {
31            enabled: false,
32            max_bytes_per_sec: 0,              // Unlimited
33            burst_capacity_bytes: 1024 * 1024, // 1MB burst capacity
34            tag_overrides: HashMap::new(),
35        }
36    }
37}
38
39impl BandwidthConfig {
40    /// Create a new bandwidth configuration
41    pub fn new(max_bytes_per_sec: u64, burst_capacity_bytes: u64) -> Self {
42        Self {
43            enabled: true,
44            max_bytes_per_sec,
45            burst_capacity_bytes,
46            tag_overrides: HashMap::new(),
47        }
48    }
49
50    /// Add a tag-based bandwidth override
51    pub fn with_tag_override(mut self, tag: String, max_bytes_per_sec: u64) -> Self {
52        self.tag_overrides.insert(tag, max_bytes_per_sec);
53        self
54    }
55
56    /// Get the effective bandwidth limit for the given tags
57    pub fn get_effective_limit(&self, tags: &[String]) -> u64 {
58        // Check for tag overrides (use the first matching tag)
59        if let Some(&override_limit) = tags.iter().find_map(|tag| self.tag_overrides.get(tag)) {
60            return override_limit;
61        }
62        self.max_bytes_per_sec
63    }
64}
65
66/// Burst loss configuration for simulating intermittent connectivity issues
67#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
68pub struct BurstLossConfig {
69    /// Enable burst loss simulation
70    pub enabled: bool,
71    /// Probability of entering a loss burst (0.0 to 1.0)
72    pub burst_probability: f64,
73    /// Duration of loss burst in milliseconds
74    pub burst_duration_ms: u64,
75    /// Packet loss rate during burst (0.0 to 1.0)
76    pub loss_rate_during_burst: f64,
77    /// Recovery time between bursts in milliseconds
78    pub recovery_time_ms: u64,
79    /// Tag-based burst loss overrides
80    pub tag_overrides: HashMap<String, BurstLossOverride>,
81}
82
83#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
84pub struct BurstLossOverride {
85    pub burst_probability: f64,
86    pub burst_duration_ms: u64,
87    pub loss_rate_during_burst: f64,
88    pub recovery_time_ms: u64,
89}
90
91impl Default for BurstLossConfig {
92    fn default() -> Self {
93        Self {
94            enabled: false,
95            burst_probability: 0.1,      // 10% chance of burst
96            burst_duration_ms: 5000,     // 5 second bursts
97            loss_rate_during_burst: 0.5, // 50% loss during burst
98            recovery_time_ms: 30000,     // 30 second recovery
99            tag_overrides: HashMap::new(),
100        }
101    }
102}
103
104impl BurstLossConfig {
105    /// Create a new burst loss configuration
106    pub fn new(
107        burst_probability: f64,
108        burst_duration_ms: u64,
109        loss_rate: f64,
110        recovery_time_ms: u64,
111    ) -> Self {
112        Self {
113            enabled: true,
114            burst_probability: burst_probability.clamp(0.0, 1.0),
115            burst_duration_ms,
116            loss_rate_during_burst: loss_rate.clamp(0.0, 1.0),
117            recovery_time_ms,
118            tag_overrides: HashMap::new(),
119        }
120    }
121
122    /// Add a tag-based burst loss override
123    pub fn with_tag_override(mut self, tag: String, override_config: BurstLossOverride) -> Self {
124        self.tag_overrides.insert(tag, override_config);
125        self
126    }
127
128    /// Get the effective burst loss config for the given tags
129    pub fn get_effective_config(&self, tags: &[String]) -> &BurstLossConfig {
130        // Check for tag overrides (use the first matching tag)
131        if let Some(override_config) = tags.iter().find_map(|tag| self.tag_overrides.get(tag)) {
132            // Create a temporary config with the override values
133            // This is a bit of a hack, but works for our use case
134            let mut temp_config = self.clone();
135            temp_config.burst_probability = override_config.burst_probability;
136            temp_config.burst_duration_ms = override_config.burst_duration_ms;
137            temp_config.loss_rate_during_burst = override_config.loss_rate_during_burst;
138            temp_config.recovery_time_ms = override_config.recovery_time_ms;
139            return Box::leak(Box::new(temp_config));
140        }
141        self
142    }
143}
144
/// Token bucket used to pace byte transfers.
///
/// One token stands for one byte. The bucket starts full, gains `refill_rate`
/// tokens per second of wall-clock time, and never holds more than `capacity`.
#[derive(Debug)]
struct TokenBucket {
    /// Bytes that may be sent right now without waiting
    tokens: f64,
    /// Upper bound on `tokens` (the burst allowance)
    capacity: f64,
    /// Tokens added per second
    refill_rate: f64,
    /// Instant at which `tokens` was last brought up to date
    last_refill: Instant,
}

impl TokenBucket {
    /// Build a full bucket holding `capacity` bytes that refills at the given rate.
    fn new(capacity: u64, refill_rate_bytes_per_sec: u64) -> Self {
        let capacity = capacity as f64;
        Self {
            tokens: capacity,
            capacity,
            refill_rate: refill_rate_bytes_per_sec as f64,
            last_refill: Instant::now(),
        }
    }

    /// Credit the tokens earned since the previous refill, capped at `capacity`.
    fn refill(&mut self) {
        let now = Instant::now();
        let gained = now.duration_since(self.last_refill).as_secs_f64() * self.refill_rate;
        self.last_refill = now;
        self.tokens = self.capacity.min(self.tokens + gained);
    }

    /// Take `bytes` tokens if that many are available; returns whether it did.
    fn try_consume(&mut self, bytes: u64) -> bool {
        self.refill();
        let wanted = bytes as f64;
        let enough = self.tokens >= wanted;
        if enough {
            self.tokens -= wanted;
        }
        enough
    }

    /// How long until `bytes` tokens would be available at the current refill rate.
    ///
    /// Does not account for `capacity`: for `bytes > capacity` the returned wait is
    /// never enough, because tokens stop accumulating at the cap.
    fn time_until_available(&mut self, bytes: u64) -> Duration {
        self.refill();
        let shortfall = bytes as f64 - self.tokens;
        if shortfall <= 0.0 {
            return Duration::ZERO;
        }
        Duration::from_secs_f64(shortfall / self.refill_rate)
    }
}
202
203/// Burst loss state machine
204#[derive(Debug)]
205struct BurstLossState {
206    /// Whether currently in a loss burst
207    in_burst: bool,
208    /// When the current burst started
209    burst_start: Option<Instant>,
210    /// When the current recovery period started
211    recovery_start: Option<Instant>,
212}
213
214impl BurstLossState {
215    fn new() -> Self {
216        Self {
217            in_burst: false,
218            burst_start: None,
219            recovery_start: None,
220        }
221    }
222
223    /// Determine if a packet should be lost based on current state
224    fn should_drop_packet(&mut self, config: &BurstLossConfig) -> bool {
225        if !config.enabled {
226            return false;
227        }
228
229        let now = Instant::now();
230
231        match (self.in_burst, self.burst_start, self.recovery_start) {
232            (true, Some(burst_start), _) => {
233                // Currently in burst - check if burst should end
234                let burst_duration = now.duration_since(burst_start);
235                if burst_duration >= Duration::from_millis(config.burst_duration_ms) {
236                    // End burst and start recovery
237                    self.in_burst = false;
238                    self.burst_start = None;
239                    self.recovery_start = Some(now);
240                    false // Don't drop this packet
241                } else {
242                    // Still in burst - apply loss rate
243                    let mut rng = rand::rng();
244                    rng.random_bool(config.loss_rate_during_burst)
245                }
246            }
247            (true, None, _) => {
248                // Invalid state: in burst but no burst start time - reset to normal
249                self.in_burst = false;
250                false
251            }
252            (false, _, Some(recovery_start)) => {
253                // In recovery - check if recovery should end
254                let recovery_duration = now.duration_since(recovery_start);
255                if recovery_duration >= Duration::from_millis(config.recovery_time_ms) {
256                    // End recovery
257                    self.recovery_start = None;
258                    // Check if we should start a new burst
259                    let mut rng = rand::rng();
260                    if rng.random_bool(config.burst_probability) {
261                        self.in_burst = true;
262                        self.burst_start = Some(now);
263                        // Apply loss rate for the first packet of the burst
264                        rng.random_bool(config.loss_rate_during_burst)
265                    } else {
266                        false
267                    }
268                } else {
269                    false // Still in recovery
270                }
271            }
272            (false, _, None) => {
273                // Not in burst or recovery - check if we should start a burst
274                let mut rng = rand::rng();
275                if rng.random_bool(config.burst_probability) {
276                    self.in_burst = true;
277                    self.burst_start = Some(now);
278                    rng.random_bool(config.loss_rate_during_burst)
279                } else {
280                    false
281                }
282            }
283        }
284    }
285}
286
287/// Traffic shaping configuration combining all features
288#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
289pub struct TrafficShapingConfig {
290    /// Bandwidth throttling configuration
291    pub bandwidth: BandwidthConfig,
292    /// Burst loss configuration
293    pub burst_loss: BurstLossConfig,
294}
295
296/// Main traffic shaper combining bandwidth throttling and burst loss
297#[derive(Debug, Clone)]
298pub struct TrafficShaper {
299    /// Bandwidth configuration
300    bandwidth_config: BandwidthConfig,
301    /// Burst loss configuration
302    burst_loss_config: BurstLossConfig,
303    /// Token bucket for bandwidth throttling (per connection/IP could be added later)
304    token_bucket: Arc<Mutex<TokenBucket>>,
305    /// Burst loss state
306    burst_loss_state: Arc<Mutex<BurstLossState>>,
307}
308
309impl TrafficShaper {
310    /// Create a new traffic shaper
311    pub fn new(config: TrafficShapingConfig) -> Self {
312        let token_bucket = if config.bandwidth.enabled && config.bandwidth.max_bytes_per_sec > 0 {
313            TokenBucket::new(
314                config.bandwidth.burst_capacity_bytes,
315                config.bandwidth.max_bytes_per_sec,
316            )
317        } else {
318            // Unlimited bandwidth - create a bucket that never blocks
319            TokenBucket::new(u64::MAX, u64::MAX)
320        };
321
322        Self {
323            bandwidth_config: config.bandwidth,
324            burst_loss_config: config.burst_loss,
325            token_bucket: Arc::new(Mutex::new(token_bucket)),
326            burst_loss_state: Arc::new(Mutex::new(BurstLossState::new())),
327        }
328    }
329
330    /// Apply bandwidth throttling to a data transfer
331    pub async fn throttle_bandwidth(&self, data_size: u64, tags: &[String]) -> Result<()> {
332        if !self.bandwidth_config.enabled {
333            return Ok(());
334        }
335
336        let effective_limit = self.bandwidth_config.get_effective_limit(tags);
337        if effective_limit == 0 {
338            // Unlimited bandwidth
339            return Ok(());
340        }
341
342        let mut bucket = self.token_bucket.lock().await;
343
344        if !bucket.try_consume(data_size) {
345            // Wait for tokens to become available
346            let wait_time = bucket.time_until_available(data_size);
347            if !wait_time.is_zero() {
348                tokio::time::sleep(wait_time).await;
349            }
350            // Try again (should succeed now)
351            if !bucket.try_consume(data_size) {
352                return Err(Error::generic(format!(
353                    "Failed to acquire bandwidth tokens for {} bytes",
354                    data_size
355                )));
356            }
357        }
358
359        Ok(())
360    }
361
362    /// Check if a packet should be dropped due to burst loss
363    pub async fn should_drop_packet(&self, tags: &[String]) -> bool {
364        if !self.burst_loss_config.enabled {
365            return false;
366        }
367
368        let effective_config = self.burst_loss_config.get_effective_config(tags);
369        let mut state = self.burst_loss_state.lock().await;
370        state.should_drop_packet(effective_config)
371    }
372
373    /// Process a data transfer with both bandwidth throttling and burst loss
374    pub async fn process_transfer(
375        &self,
376        data_size: u64,
377        tags: &[String],
378    ) -> Result<Option<Duration>> {
379        // First, apply bandwidth throttling
380        self.throttle_bandwidth(data_size, tags).await?;
381
382        // Then, check for burst loss
383        if self.should_drop_packet(tags).await {
384            return Ok(Some(Duration::from_millis(100))); // Simulate packet timeout
385        }
386
387        Ok(None)
388    }
389
390    /// Get current bandwidth usage statistics
391    pub async fn get_bandwidth_stats(&self) -> BandwidthStats {
392        let bucket = self.token_bucket.lock().await;
393        BandwidthStats {
394            current_tokens: bucket.tokens as u64,
395            capacity: bucket.capacity as u64,
396            refill_rate_bytes_per_sec: bucket.refill_rate as u64,
397        }
398    }
399
400    /// Get current burst loss state
401    pub async fn get_burst_loss_stats(&self) -> BurstLossStats {
402        let state = self.burst_loss_state.lock().await;
403        BurstLossStats {
404            in_burst: state.in_burst,
405            burst_start: state.burst_start,
406            recovery_start: state.recovery_start,
407        }
408    }
409}
410
411/// Bandwidth usage statistics
412#[derive(Debug, Clone)]
413pub struct BandwidthStats {
414    pub current_tokens: u64,
415    pub capacity: u64,
416    pub refill_rate_bytes_per_sec: u64,
417}
418
419/// Burst loss state statistics
420#[derive(Debug, Clone)]
421pub struct BurstLossStats {
422    pub in_burst: bool,
423    pub burst_start: Option<Instant>,
424    pub recovery_start: Option<Instant>,
425}
426
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[tokio::test]
    async fn test_bandwidth_throttling() {
        // 1000 bytes/sec refill with a 100 byte bucket.
        let shaper = TrafficShaper::new(TrafficShapingConfig {
            bandwidth: BandwidthConfig::new(1000, 100),
            burst_loss: BurstLossConfig::default(),
        });

        // 50 bytes fit in the initially full bucket.
        assert!(shaper.throttle_bandwidth(50, &[]).await.is_ok());

        // Only ~50 tokens remain, so 80 more bytes must wait ~30ms for the refill.
        let started = Instant::now();
        let outcome = shaper.throttle_bandwidth(80, &[]).await;
        let waited = started.elapsed();

        assert!(outcome.is_ok());
        assert!(waited >= Duration::from_millis(30));
    }

    #[tokio::test]
    async fn test_burst_loss() {
        // Bursts always start and every packet inside one is lost.
        let shaper = TrafficShaper::new(TrafficShapingConfig {
            bandwidth: BandwidthConfig::default(),
            burst_loss: BurstLossConfig::new(1.0, 1000, 1.0, 1000),
        });

        // The first call opens the burst; the five that follow fall inside it.
        for _ in 0..6 {
            assert!(shaper.should_drop_packet(&[]).await);
        }
    }

    #[test]
    fn test_bandwidth_config_overrides() {
        let config = BandwidthConfig::new(1000, 100)
            .with_tag_override("high-priority".to_string(), 5000);
        let tags = |names: &[&str]| names.iter().map(|n| n.to_string()).collect::<Vec<_>>();

        assert_eq!(config.get_effective_limit(&[]), 1000);
        assert_eq!(config.get_effective_limit(&tags(&["high-priority"])), 5000);
        assert_eq!(
            config.get_effective_limit(&tags(&["low-priority", "high-priority"])),
            5000
        );
    }
}