mockforge_core/
traffic_shaping.rs

1//! Traffic shaping beyond latency simulation
2//!
3//! This module provides advanced traffic shaping capabilities including:
4//! - Bandwidth throttling using token bucket algorithm
5//! - Burst packet loss simulation
6//! - Integration with existing latency and fault injection
7
8use crate::{Error, Result};
9use rand::Rng;
10use std::borrow::Cow;
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tokio::sync::{Mutex, RwLock};
15
/// Key under which the shared (non-tag-specific) token bucket is stored in the
/// shaper's bucket map; used whenever none of a request's tags has a bandwidth override.
const GLOBAL_BUCKET_KEY: &str = "__global__";
17
18/// Bandwidth throttling configuration
19#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
20pub struct BandwidthConfig {
21    /// Enable bandwidth throttling
22    pub enabled: bool,
23    /// Maximum bandwidth in bytes per second (0 = unlimited)
24    pub max_bytes_per_sec: u64,
25    /// Token bucket capacity in bytes (burst allowance)
26    pub burst_capacity_bytes: u64,
27    /// Tag-based bandwidth overrides
28    pub tag_overrides: HashMap<String, u64>,
29}
30
31impl Default for BandwidthConfig {
32    fn default() -> Self {
33        Self {
34            enabled: false,
35            max_bytes_per_sec: 0,              // Unlimited
36            burst_capacity_bytes: 1024 * 1024, // 1MB burst capacity
37            tag_overrides: HashMap::new(),
38        }
39    }
40}
41
42impl BandwidthConfig {
43    /// Create a new bandwidth configuration
44    pub fn new(max_bytes_per_sec: u64, burst_capacity_bytes: u64) -> Self {
45        Self {
46            enabled: true,
47            max_bytes_per_sec,
48            burst_capacity_bytes,
49            tag_overrides: HashMap::new(),
50        }
51    }
52
53    /// Add a tag-based bandwidth override
54    pub fn with_tag_override(mut self, tag: String, max_bytes_per_sec: u64) -> Self {
55        self.tag_overrides.insert(tag, max_bytes_per_sec);
56        self
57    }
58
59    /// Get the effective bandwidth limit for the given tags
60    pub fn get_effective_limit(&self, tags: &[String]) -> u64 {
61        // Check for tag overrides (use the first matching tag)
62        if let Some(&override_limit) = tags.iter().find_map(|tag| self.tag_overrides.get(tag)) {
63            return override_limit;
64        }
65        self.max_bytes_per_sec
66    }
67}
68
69/// Burst loss configuration for simulating intermittent connectivity issues
70#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
71pub struct BurstLossConfig {
72    /// Enable burst loss simulation
73    pub enabled: bool,
74    /// Probability of entering a loss burst (0.0 to 1.0)
75    pub burst_probability: f64,
76    /// Duration of loss burst in milliseconds
77    pub burst_duration_ms: u64,
78    /// Packet loss rate during burst (0.0 to 1.0)
79    pub loss_rate_during_burst: f64,
80    /// Recovery time between bursts in milliseconds
81    pub recovery_time_ms: u64,
82    /// Tag-based burst loss overrides
83    pub tag_overrides: HashMap<String, BurstLossOverride>,
84}
85
86#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
87pub struct BurstLossOverride {
88    pub burst_probability: f64,
89    pub burst_duration_ms: u64,
90    pub loss_rate_during_burst: f64,
91    pub recovery_time_ms: u64,
92}
93
94impl Default for BurstLossConfig {
95    fn default() -> Self {
96        Self {
97            enabled: false,
98            burst_probability: 0.1,      // 10% chance of burst
99            burst_duration_ms: 5000,     // 5 second bursts
100            loss_rate_during_burst: 0.5, // 50% loss during burst
101            recovery_time_ms: 30000,     // 30 second recovery
102            tag_overrides: HashMap::new(),
103        }
104    }
105}
106
107impl BurstLossConfig {
108    /// Create a new burst loss configuration
109    pub fn new(
110        burst_probability: f64,
111        burst_duration_ms: u64,
112        loss_rate: f64,
113        recovery_time_ms: u64,
114    ) -> Self {
115        Self {
116            enabled: true,
117            burst_probability: burst_probability.clamp(0.0, 1.0),
118            burst_duration_ms,
119            loss_rate_during_burst: loss_rate.clamp(0.0, 1.0),
120            recovery_time_ms,
121            tag_overrides: HashMap::new(),
122        }
123    }
124
125    /// Add a tag-based burst loss override
126    pub fn with_tag_override(mut self, tag: String, override_config: BurstLossOverride) -> Self {
127        self.tag_overrides.insert(tag, override_config);
128        self
129    }
130
131    /// Get the effective burst loss config for the given tags
132    pub fn effective_config<'a>(&'a self, tags: &[String]) -> Cow<'a, BurstLossConfig> {
133        if let Some(override_config) = tags.iter().find_map(|tag| self.tag_overrides.get(tag)) {
134            let mut temp_config = self.clone();
135            temp_config.burst_probability = override_config.burst_probability;
136            temp_config.burst_duration_ms = override_config.burst_duration_ms;
137            temp_config.loss_rate_during_burst = override_config.loss_rate_during_burst;
138            temp_config.recovery_time_ms = override_config.recovery_time_ms;
139            Cow::Owned(temp_config)
140        } else {
141            Cow::Borrowed(self)
142        }
143    }
144}
145
/// Token bucket for bandwidth throttling.
///
/// One token stands for one byte. The bucket starts full, is topped up
/// according to the time elapsed since the previous refill, and never holds
/// more than `capacity` tokens.
#[derive(Debug)]
struct TokenBucket {
    /// Current number of tokens (bytes that can be sent right now)
    tokens: f64,
    /// Maximum number of tokens the bucket can hold
    capacity: f64,
    /// Rate of token replenishment (tokens per second)
    refill_rate: f64,
    /// Instant at which `tokens` was last brought up to date
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a full bucket holding `capacity` tokens that refills at
    /// `refill_rate_bytes_per_sec` tokens per second.
    fn new(capacity: u64, refill_rate_bytes_per_sec: u64) -> Self {
        Self {
            tokens: capacity as f64,
            capacity: capacity as f64,
            refill_rate: refill_rate_bytes_per_sec as f64,
            last_refill: Instant::now(),
        }
    }

    /// Adds the tokens earned since the last refill, capped at `capacity`.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        let tokens_to_add = elapsed * self.refill_rate;

        self.tokens = (self.tokens + tokens_to_add).min(self.capacity);
        self.last_refill = now;
    }

    /// Refills the bucket, then removes `bytes` tokens if that many are available.
    ///
    /// Returns `true` when the tokens were taken and `false` (leaving the
    /// bucket untouched apart from the refill) when there were not enough.
    fn try_consume(&mut self, bytes: u64) -> bool {
        self.refill();
        if self.tokens >= bytes as f64 {
            self.tokens -= bytes as f64;
            true
        } else {
            false
        }
    }

    /// Refills the bucket and estimates how long to wait for `bytes` tokens.
    ///
    /// Returns `Duration::ZERO` when enough tokens are already present.
    /// Otherwise the wait is the missing token count divided by the refill
    /// rate. If that quotient cannot be represented as a `Duration` (for
    /// example because the refill rate is zero, which makes it infinite),
    /// `Duration::MAX` is returned instead of panicking.
    ///
    /// The estimate ignores the capacity cap: because `refill` never lets the
    /// bucket exceed `capacity`, a request larger than `capacity` will not
    /// become satisfiable even after the returned duration has passed.
    fn time_until_available(&mut self, bytes: u64) -> Duration {
        self.refill();
        if self.tokens >= bytes as f64 {
            return Duration::ZERO;
        }

        let tokens_needed = bytes as f64 - self.tokens;
        // `Duration::from_secs_f64` panics on non-finite or overflowing input;
        // the fallible variant lets us saturate instead.
        Duration::try_from_secs_f64(tokens_needed / self.refill_rate).unwrap_or(Duration::MAX)
    }
}
203
/// Mutable state of the burst loss state machine.
#[derive(Debug)]
struct BurstLossState {
    /// `true` while a loss burst is in progress
    in_burst: bool,
    /// Start time of the burst in progress, if any
    burst_start: Option<Instant>,
    /// Start time of the recovery period in progress, if any
    recovery_start: Option<Instant>,
}
214
215impl BurstLossState {
216    fn new() -> Self {
217        Self {
218            in_burst: false,
219            burst_start: None,
220            recovery_start: None,
221        }
222    }
223
224    /// Determine if a packet should be lost based on current state
225    fn should_drop_packet(&mut self, config: &BurstLossConfig) -> bool {
226        if !config.enabled {
227            return false;
228        }
229
230        let now = Instant::now();
231
232        match (self.in_burst, self.burst_start, self.recovery_start) {
233            (true, Some(burst_start), _) => {
234                // Currently in burst - check if burst should end
235                let burst_duration = now.duration_since(burst_start);
236                if burst_duration >= Duration::from_millis(config.burst_duration_ms) {
237                    // End burst and start recovery
238                    self.in_burst = false;
239                    self.burst_start = None;
240                    self.recovery_start = Some(now);
241                    false // Don't drop this packet
242                } else {
243                    // Still in burst - apply loss rate
244                    let mut rng = rand::rng();
245                    rng.random_bool(config.loss_rate_during_burst)
246                }
247            }
248            (true, None, _) => {
249                // Invalid state: in burst but no burst start time - reset to normal
250                self.in_burst = false;
251                false
252            }
253            (false, _, Some(recovery_start)) => {
254                // In recovery - check if recovery should end
255                let recovery_duration = now.duration_since(recovery_start);
256                if recovery_duration >= Duration::from_millis(config.recovery_time_ms) {
257                    // End recovery
258                    self.recovery_start = None;
259                    // Check if we should start a new burst
260                    let mut rng = rand::rng();
261                    if rng.random_bool(config.burst_probability) {
262                        self.in_burst = true;
263                        self.burst_start = Some(now);
264                        // Apply loss rate for the first packet of the burst
265                        rng.random_bool(config.loss_rate_during_burst)
266                    } else {
267                        false
268                    }
269                } else {
270                    false // Still in recovery
271                }
272            }
273            (false, _, None) => {
274                // Not in burst or recovery - check if we should start a burst
275                let mut rng = rand::rng();
276                if rng.random_bool(config.burst_probability) {
277                    self.in_burst = true;
278                    self.burst_start = Some(now);
279                    rng.random_bool(config.loss_rate_during_burst)
280                } else {
281                    false
282                }
283            }
284        }
285    }
286}
287
288/// Traffic shaping configuration combining all features
289#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
290pub struct TrafficShapingConfig {
291    /// Bandwidth throttling configuration
292    pub bandwidth: BandwidthConfig,
293    /// Burst loss configuration
294    pub burst_loss: BurstLossConfig,
295}
296
297/// Main traffic shaper combining bandwidth throttling and burst loss
298#[derive(Debug, Clone)]
299pub struct TrafficShaper {
300    /// Bandwidth configuration
301    bandwidth_config: BandwidthConfig,
302    /// Burst loss configuration
303    burst_loss_config: BurstLossConfig,
304    /// Token buckets keyed by effective tag/group
305    token_buckets: Arc<RwLock<HashMap<String, Arc<Mutex<TokenBucket>>>>>,
306    /// Burst loss state
307    burst_loss_state: Arc<Mutex<BurstLossState>>,
308}
309
310impl TrafficShaper {
311    /// Create a new traffic shaper
312    pub fn new(config: TrafficShapingConfig) -> Self {
313        Self {
314            bandwidth_config: config.bandwidth,
315            burst_loss_config: config.burst_loss,
316            token_buckets: Arc::new(RwLock::new(HashMap::new())),
317            burst_loss_state: Arc::new(Mutex::new(BurstLossState::new())),
318        }
319    }
320
321    /// Apply bandwidth throttling to a data transfer
322    pub async fn throttle_bandwidth(&self, data_size: u64, tags: &[String]) -> Result<()> {
323        if !self.bandwidth_config.enabled {
324            return Ok(());
325        }
326
327        let (bucket_key, effective_limit) = self.resolve_bandwidth_bucket(tags);
328
329        if effective_limit == 0 {
330            return Ok(());
331        }
332
333        let bucket_arc = self.get_or_create_bucket(&bucket_key, effective_limit).await;
334
335        {
336            let mut bucket = bucket_arc.lock().await;
337            if bucket.try_consume(data_size) {
338                return Ok(());
339            }
340
341            let wait_time = bucket.time_until_available(data_size);
342            drop(bucket);
343
344            if wait_time.is_zero() {
345                return Err(Error::generic(format!(
346                    "Failed to acquire bandwidth tokens for {} bytes",
347                    data_size
348                )));
349            }
350
351            tokio::time::sleep(wait_time).await;
352        }
353
354        let mut bucket = bucket_arc.lock().await;
355        if bucket.try_consume(data_size) {
356            Ok(())
357        } else {
358            Err(Error::generic(format!(
359                "Failed to acquire bandwidth tokens for {} bytes",
360                data_size
361            )))
362        }
363    }
364
365    /// Check if a packet should be dropped due to burst loss
366    pub async fn should_drop_packet(&self, tags: &[String]) -> bool {
367        if !self.burst_loss_config.enabled {
368            return false;
369        }
370
371        let effective_config = self.burst_loss_config.effective_config(tags);
372        let mut state = self.burst_loss_state.lock().await;
373        state.should_drop_packet(effective_config.as_ref())
374    }
375
376    /// Process a data transfer with both bandwidth throttling and burst loss
377    pub async fn process_transfer(
378        &self,
379        data_size: u64,
380        tags: &[String],
381    ) -> Result<Option<Duration>> {
382        // First, apply bandwidth throttling
383        self.throttle_bandwidth(data_size, tags).await?;
384
385        // Then, check for burst loss
386        if self.should_drop_packet(tags).await {
387            return Ok(Some(Duration::from_millis(100))); // Simulate packet timeout
388        }
389
390        Ok(None)
391    }
392
393    /// Get current bandwidth usage statistics
394    pub async fn get_bandwidth_stats(&self) -> BandwidthStats {
395        let maybe_bucket = {
396            let guard = self.token_buckets.read().await;
397            guard.get(GLOBAL_BUCKET_KEY).cloned()
398        };
399
400        if let Some(bucket_arc) = maybe_bucket {
401            let bucket = bucket_arc.lock().await;
402            BandwidthStats {
403                current_tokens: bucket.tokens as u64,
404                capacity: bucket.capacity as u64,
405                refill_rate_bytes_per_sec: bucket.refill_rate as u64,
406            }
407        } else {
408            BandwidthStats {
409                current_tokens: self.bandwidth_config.burst_capacity_bytes,
410                capacity: self.bandwidth_config.burst_capacity_bytes,
411                refill_rate_bytes_per_sec: self.bandwidth_config.max_bytes_per_sec,
412            }
413        }
414    }
415
416    /// Get current burst loss state
417    pub async fn get_burst_loss_stats(&self) -> BurstLossStats {
418        let state = self.burst_loss_state.lock().await;
419        BurstLossStats {
420            in_burst: state.in_burst,
421            burst_start: state.burst_start,
422            recovery_start: state.recovery_start,
423        }
424    }
425
426    async fn get_or_create_bucket(
427        &self,
428        bucket_key: &str,
429        effective_limit: u64,
430    ) -> Arc<Mutex<TokenBucket>> {
431        if let Some(existing) = self.token_buckets.read().await.get(bucket_key).cloned() {
432            return existing;
433        }
434
435        let mut buckets = self.token_buckets.write().await;
436        buckets
437            .entry(bucket_key.to_string())
438            .or_insert_with(|| {
439                Arc::new(Mutex::new(TokenBucket::new(
440                    self.bandwidth_config.burst_capacity_bytes,
441                    effective_limit,
442                )))
443            })
444            .clone()
445    }
446
447    fn resolve_bandwidth_bucket(&self, tags: &[String]) -> (String, u64) {
448        if let Some((tag, limit)) = tags.iter().find_map(|tag| {
449            self.bandwidth_config.tag_overrides.get(tag).map(|limit| (tag.as_str(), *limit))
450        }) {
451            (format!("tag:{}", tag), limit)
452        } else {
453            (GLOBAL_BUCKET_KEY.to_string(), self.bandwidth_config.max_bytes_per_sec)
454        }
455    }
456}
457
/// Snapshot of a token bucket's bandwidth figures.
#[derive(Clone, Debug)]
pub struct BandwidthStats {
    /// Tokens (bytes) currently in the bucket
    pub current_tokens: u64,
    /// Maximum number of tokens (bytes) the bucket can hold
    pub capacity: u64,
    /// Refill rate in bytes per second
    pub refill_rate_bytes_per_sec: u64,
}
465
/// Snapshot of the burst loss state machine.
#[derive(Clone, Debug)]
pub struct BurstLossStats {
    /// `true` while a loss burst is in progress
    pub in_burst: bool,
    /// Start time of the burst in progress, if any
    pub burst_start: Option<Instant>,
    /// Start time of the recovery period in progress, if any
    pub recovery_start: Option<Instant>,
}
473
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// A request that fits the remaining burst passes at once; the next one
    /// has to wait for the bucket to refill.
    #[tokio::test]
    async fn test_bandwidth_throttling() {
        // 1000 bytes/sec with a 100 byte burst bucket
        let shaper = TrafficShaper::new(TrafficShapingConfig {
            bandwidth: BandwidthConfig::new(1000, 100),
            burst_loss: BurstLossConfig::default(),
        });

        // 50 bytes fit into the full bucket.
        assert!(shaper.throttle_bandwidth(50, &[]).await.is_ok());

        // Only 50 tokens are left, so 80 bytes need about 30ms of refill.
        let started = Instant::now();
        let outcome = shaper.throttle_bandwidth(80, &[]).await;
        let waited = started.elapsed();
        assert!(outcome.is_ok());
        assert!(waited >= Duration::from_millis(30));
    }

    /// With 100% burst probability and 100% loss every packet is dropped
    /// while the burst lasts.
    #[tokio::test]
    async fn test_burst_loss() {
        let shaper = TrafficShaper::new(TrafficShapingConfig {
            bandwidth: BandwidthConfig::default(),
            burst_loss: BurstLossConfig::new(1.0, 1000, 1.0, 1000),
        });

        // The very first packet starts the burst and is lost.
        assert!(shaper.should_drop_packet(&[]).await);

        // Packets that follow inside the burst are lost as well.
        for _ in 0..5 {
            assert!(shaper.should_drop_packet(&[]).await);
        }
    }

    /// A tag override is enforced even though the global limit is unlimited.
    #[tokio::test]
    async fn test_bandwidth_tag_override_with_global_unlimited() {
        let bandwidth = BandwidthConfig {
            enabled: true,
            max_bytes_per_sec: 0,
            burst_capacity_bytes: 100,
            ..BandwidthConfig::default()
        }
        .with_tag_override("limited".to_string(), 100);

        let shaper = TrafficShaper::new(TrafficShapingConfig {
            bandwidth,
            burst_loss: BurstLossConfig::default(),
        });

        let tags = vec!["limited".to_string()];

        // Drains the whole bucket without waiting.
        shaper
            .throttle_bandwidth(100, &tags)
            .await
            .expect("initial transfer should succeed immediately");

        // The bucket is empty now: 100 bytes at 100 bytes/sec take about a second.
        let started = Instant::now();
        shaper
            .throttle_bandwidth(100, &tags)
            .await
            .expect("tag override should throttle but eventually succeed");
        assert!(
            started.elapsed() >= Duration::from_millis(900),
            "override-specific transfer should respect configured rate"
        );
    }

    /// The first tag with an override decides the limit; otherwise the global limit applies.
    #[test]
    fn test_bandwidth_config_overrides() {
        let config =
            BandwidthConfig::new(1000, 100).with_tag_override("high-priority".to_string(), 5000);

        let no_tags: [String; 0] = [];
        let only_high = ["high-priority".to_string()];
        let low_then_high = ["low-priority".to_string(), "high-priority".to_string()];

        assert_eq!(config.get_effective_limit(&no_tags), 1000);
        assert_eq!(config.get_effective_limit(&only_high), 5000);
        assert_eq!(config.get_effective_limit(&low_then_high), 5000);
    }

    /// A matching tag replaces all four burst parameters.
    #[test]
    fn test_burst_loss_effective_config_override() {
        let expected = BurstLossOverride {
            burst_probability: 0.8,
            burst_duration_ms: 2000,
            loss_rate_during_burst: 0.9,
            recovery_time_ms: 5000,
        };

        let config =
            BurstLossConfig::default().with_tag_override("flaky".to_string(), expected.clone());

        let resolved = config.effective_config(&["flaky".to_string()]);
        assert_eq!(resolved.burst_probability, expected.burst_probability);
        assert_eq!(resolved.burst_duration_ms, expected.burst_duration_ms);
        assert_eq!(resolved.loss_rate_during_burst, expected.loss_rate_during_burst);
        assert_eq!(resolved.recovery_time_ms, expected.recovery_time_ms);
    }
}