mockforge_core/
traffic_shaping.rs

1//! Traffic shaping beyond latency simulation
2//!
3//! This module provides advanced traffic shaping capabilities including:
4//! - Bandwidth throttling using token bucket algorithm
5//! - Burst packet loss simulation
6//! - Integration with existing latency and fault injection
7
8use crate::{Error, Result};
9use rand::Rng;
10use std::borrow::Cow;
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tokio::sync::{Mutex, RwLock};
15
/// Map key of the shared token bucket used when none of a request's tags has a bandwidth override.
const GLOBAL_BUCKET_KEY: &str = "__global__";
17
18/// Bandwidth throttling configuration
19#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
20#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
21pub struct BandwidthConfig {
22    /// Enable bandwidth throttling
23    pub enabled: bool,
24    /// Maximum bandwidth in bytes per second (0 = unlimited)
25    pub max_bytes_per_sec: u64,
26    /// Token bucket capacity in bytes (burst allowance)
27    pub burst_capacity_bytes: u64,
28    /// Tag-based bandwidth overrides
29    pub tag_overrides: HashMap<String, u64>,
30}
31
32impl Default for BandwidthConfig {
33    fn default() -> Self {
34        Self {
35            enabled: false,
36            max_bytes_per_sec: 0,              // Unlimited
37            burst_capacity_bytes: 1024 * 1024, // 1MB burst capacity
38            tag_overrides: HashMap::new(),
39        }
40    }
41}
42
43impl BandwidthConfig {
44    /// Create a new bandwidth configuration
45    pub fn new(max_bytes_per_sec: u64, burst_capacity_bytes: u64) -> Self {
46        Self {
47            enabled: true,
48            max_bytes_per_sec,
49            burst_capacity_bytes,
50            tag_overrides: HashMap::new(),
51        }
52    }
53
54    /// Add a tag-based bandwidth override
55    pub fn with_tag_override(mut self, tag: String, max_bytes_per_sec: u64) -> Self {
56        self.tag_overrides.insert(tag, max_bytes_per_sec);
57        self
58    }
59
60    /// Get the effective bandwidth limit for the given tags
61    pub fn get_effective_limit(&self, tags: &[String]) -> u64 {
62        // Check for tag overrides (use the first matching tag)
63        if let Some(&override_limit) = tags.iter().find_map(|tag| self.tag_overrides.get(tag)) {
64            return override_limit;
65        }
66        self.max_bytes_per_sec
67    }
68}
69
70/// Burst loss configuration for simulating intermittent connectivity issues
71#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
72#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
73pub struct BurstLossConfig {
74    /// Enable burst loss simulation
75    pub enabled: bool,
76    /// Probability of entering a loss burst (0.0 to 1.0)
77    pub burst_probability: f64,
78    /// Duration of loss burst in milliseconds
79    pub burst_duration_ms: u64,
80    /// Packet loss rate during burst (0.0 to 1.0)
81    pub loss_rate_during_burst: f64,
82    /// Recovery time between bursts in milliseconds
83    pub recovery_time_ms: u64,
84    /// Tag-based burst loss overrides
85    pub tag_overrides: HashMap<String, BurstLossOverride>,
86}
87
88/// Tag-specific burst loss configuration override
89#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
90#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
91pub struct BurstLossOverride {
92    /// Probability of entering a loss burst (0.0 to 1.0)
93    pub burst_probability: f64,
94    /// Duration of loss burst in milliseconds
95    pub burst_duration_ms: u64,
96    /// Packet loss rate during burst (0.0 to 1.0)
97    pub loss_rate_during_burst: f64,
98    /// Recovery time between bursts in milliseconds
99    pub recovery_time_ms: u64,
100}
101
102impl Default for BurstLossConfig {
103    fn default() -> Self {
104        Self {
105            enabled: false,
106            burst_probability: 0.1,      // 10% chance of burst
107            burst_duration_ms: 5000,     // 5 second bursts
108            loss_rate_during_burst: 0.5, // 50% loss during burst
109            recovery_time_ms: 30000,     // 30 second recovery
110            tag_overrides: HashMap::new(),
111        }
112    }
113}
114
115impl BurstLossConfig {
116    /// Create a new burst loss configuration
117    pub fn new(
118        burst_probability: f64,
119        burst_duration_ms: u64,
120        loss_rate: f64,
121        recovery_time_ms: u64,
122    ) -> Self {
123        Self {
124            enabled: true,
125            burst_probability: burst_probability.clamp(0.0, 1.0),
126            burst_duration_ms,
127            loss_rate_during_burst: loss_rate.clamp(0.0, 1.0),
128            recovery_time_ms,
129            tag_overrides: HashMap::new(),
130        }
131    }
132
133    /// Add a tag-based burst loss override
134    pub fn with_tag_override(mut self, tag: String, override_config: BurstLossOverride) -> Self {
135        self.tag_overrides.insert(tag, override_config);
136        self
137    }
138
139    /// Get the effective burst loss config for the given tags
140    pub fn effective_config<'a>(&'a self, tags: &[String]) -> Cow<'a, BurstLossConfig> {
141        if let Some(override_config) = tags.iter().find_map(|tag| self.tag_overrides.get(tag)) {
142            let mut temp_config = self.clone();
143            temp_config.burst_probability = override_config.burst_probability;
144            temp_config.burst_duration_ms = override_config.burst_duration_ms;
145            temp_config.loss_rate_during_burst = override_config.loss_rate_during_burst;
146            temp_config.recovery_time_ms = override_config.recovery_time_ms;
147            Cow::Owned(temp_config)
148        } else {
149            Cow::Borrowed(self)
150        }
151    }
152}
153
/// Token bucket used to pace byte transfers.
///
/// One token stands for one byte. Tokens accrue continuously at `refill_rate` per second
/// and are capped at `capacity`; a new bucket starts full.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens available as of `last_refill` (bytes that may be sent without waiting).
    tokens: f64,
    /// Upper bound on `tokens`, i.e. the burst size in bytes.
    capacity: f64,
    /// Refill speed in tokens (bytes) per second.
    refill_rate: f64,
    /// When `tokens` was last brought up to date.
    last_refill: Instant,
}

impl TokenBucket {
    /// Create a full bucket of `capacity` tokens refilling at `refill_rate_bytes_per_sec`.
    fn new(capacity: u64, refill_rate_bytes_per_sec: u64) -> Self {
        let capacity = capacity as f64;
        TokenBucket {
            tokens: capacity,
            capacity,
            refill_rate: refill_rate_bytes_per_sec as f64,
            last_refill: Instant::now(),
        }
    }

    /// Credit the tokens earned since the previous refill, never exceeding `capacity`.
    fn refill(&mut self) {
        let now = Instant::now();
        let earned = now.duration_since(self.last_refill).as_secs_f64() * self.refill_rate;
        self.last_refill = now;
        self.tokens = (self.tokens + earned).min(self.capacity);
    }

    /// Take `bytes` tokens if that many are available right now; returns whether it did.
    fn try_consume(&mut self, bytes: u64) -> bool {
        self.refill();
        let cost = bytes as f64;
        let affordable = self.tokens >= cost;
        if affordable {
            self.tokens -= cost;
        }
        affordable
    }

    /// How long until `bytes` tokens are available (zero if they already are).
    ///
    /// `tokens` never exceeds `capacity`, so a request larger than `capacity` is still not
    /// satisfiable after waiting the returned time. Panics (via `Duration::from_secs_f64`)
    /// when the wait is not representable, e.g. with a zero `refill_rate`.
    fn time_until_available(&mut self, bytes: u64) -> Duration {
        self.refill();
        let cost = bytes as f64;
        if self.tokens >= cost {
            return Duration::ZERO;
        }
        let deficit = cost - self.tokens;
        Duration::from_secs_f64(deficit / self.refill_rate)
    }
}
211
212/// Burst loss state machine
213#[derive(Debug)]
214struct BurstLossState {
215    /// Whether currently in a loss burst
216    in_burst: bool,
217    /// When the current burst started
218    burst_start: Option<Instant>,
219    /// When the current recovery period started
220    recovery_start: Option<Instant>,
221}
222
223impl BurstLossState {
224    fn new() -> Self {
225        Self {
226            in_burst: false,
227            burst_start: None,
228            recovery_start: None,
229        }
230    }
231
232    /// Determine if a packet should be lost based on current state
233    fn should_drop_packet(&mut self, config: &BurstLossConfig) -> bool {
234        if !config.enabled {
235            return false;
236        }
237
238        let now = Instant::now();
239
240        match (self.in_burst, self.burst_start, self.recovery_start) {
241            (true, Some(burst_start), _) => {
242                // Currently in burst - check if burst should end
243                let burst_duration = now.duration_since(burst_start);
244                if burst_duration >= Duration::from_millis(config.burst_duration_ms) {
245                    // End burst and start recovery
246                    self.in_burst = false;
247                    self.burst_start = None;
248                    self.recovery_start = Some(now);
249                    false // Don't drop this packet
250                } else {
251                    // Still in burst - apply loss rate
252                    let mut rng = rand::rng();
253                    rng.random_bool(config.loss_rate_during_burst)
254                }
255            }
256            (true, None, _) => {
257                // Invalid state: in burst but no burst start time - reset to normal
258                self.in_burst = false;
259                false
260            }
261            (false, _, Some(recovery_start)) => {
262                // In recovery - check if recovery should end
263                let recovery_duration = now.duration_since(recovery_start);
264                if recovery_duration >= Duration::from_millis(config.recovery_time_ms) {
265                    // End recovery
266                    self.recovery_start = None;
267                    // Check if we should start a new burst
268                    let mut rng = rand::rng();
269                    if rng.random_bool(config.burst_probability) {
270                        self.in_burst = true;
271                        self.burst_start = Some(now);
272                        // Apply loss rate for the first packet of the burst
273                        rng.random_bool(config.loss_rate_during_burst)
274                    } else {
275                        false
276                    }
277                } else {
278                    false // Still in recovery
279                }
280            }
281            (false, _, None) => {
282                // Not in burst or recovery - check if we should start a burst
283                let mut rng = rand::rng();
284                if rng.random_bool(config.burst_probability) {
285                    self.in_burst = true;
286                    self.burst_start = Some(now);
287                    rng.random_bool(config.loss_rate_during_burst)
288                } else {
289                    false
290                }
291            }
292        }
293    }
294}
295
296/// Traffic shaping configuration combining all features
297#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
298#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
299pub struct TrafficShapingConfig {
300    /// Bandwidth throttling configuration
301    pub bandwidth: BandwidthConfig,
302    /// Burst loss configuration
303    pub burst_loss: BurstLossConfig,
304}
305
306/// Main traffic shaper combining bandwidth throttling and burst loss
307#[derive(Debug, Clone)]
308pub struct TrafficShaper {
309    /// Bandwidth configuration
310    bandwidth_config: BandwidthConfig,
311    /// Burst loss configuration
312    burst_loss_config: BurstLossConfig,
313    /// Token buckets keyed by effective tag/group
314    token_buckets: Arc<RwLock<HashMap<String, Arc<Mutex<TokenBucket>>>>>,
315    /// Burst loss state
316    burst_loss_state: Arc<Mutex<BurstLossState>>,
317}
318
319impl TrafficShaper {
320    /// Create a new traffic shaper
321    pub fn new(config: TrafficShapingConfig) -> Self {
322        Self {
323            bandwidth_config: config.bandwidth,
324            burst_loss_config: config.burst_loss,
325            token_buckets: Arc::new(RwLock::new(HashMap::new())),
326            burst_loss_state: Arc::new(Mutex::new(BurstLossState::new())),
327        }
328    }
329
330    /// Apply bandwidth throttling to a data transfer
331    pub async fn throttle_bandwidth(&self, data_size: u64, tags: &[String]) -> Result<()> {
332        if !self.bandwidth_config.enabled {
333            return Ok(());
334        }
335
336        let (bucket_key, effective_limit) = self.resolve_bandwidth_bucket(tags);
337
338        if effective_limit == 0 {
339            return Ok(());
340        }
341
342        let bucket_arc = self.get_or_create_bucket(&bucket_key, effective_limit).await;
343
344        {
345            let mut bucket = bucket_arc.lock().await;
346            if bucket.try_consume(data_size) {
347                return Ok(());
348            }
349
350            let wait_time = bucket.time_until_available(data_size);
351            drop(bucket);
352
353            if wait_time.is_zero() {
354                return Err(Error::generic(format!(
355                    "Failed to acquire bandwidth tokens for {} bytes",
356                    data_size
357                )));
358            }
359
360            tokio::time::sleep(wait_time).await;
361        }
362
363        let mut bucket = bucket_arc.lock().await;
364        if bucket.try_consume(data_size) {
365            Ok(())
366        } else {
367            Err(Error::generic(format!(
368                "Failed to acquire bandwidth tokens for {} bytes",
369                data_size
370            )))
371        }
372    }
373
374    /// Check if a packet should be dropped due to burst loss
375    pub async fn should_drop_packet(&self, tags: &[String]) -> bool {
376        if !self.burst_loss_config.enabled {
377            return false;
378        }
379
380        let effective_config = self.burst_loss_config.effective_config(tags);
381        let mut state = self.burst_loss_state.lock().await;
382        state.should_drop_packet(effective_config.as_ref())
383    }
384
385    /// Process a data transfer with both bandwidth throttling and burst loss
386    pub async fn process_transfer(
387        &self,
388        data_size: u64,
389        tags: &[String],
390    ) -> Result<Option<Duration>> {
391        // First, apply bandwidth throttling
392        self.throttle_bandwidth(data_size, tags).await?;
393
394        // Then, check for burst loss
395        if self.should_drop_packet(tags).await {
396            return Ok(Some(Duration::from_millis(100))); // Simulate packet timeout
397        }
398
399        Ok(None)
400    }
401
402    /// Get current bandwidth usage statistics
403    pub async fn get_bandwidth_stats(&self) -> BandwidthStats {
404        let maybe_bucket = {
405            let guard = self.token_buckets.read().await;
406            guard.get(GLOBAL_BUCKET_KEY).cloned()
407        };
408
409        if let Some(bucket_arc) = maybe_bucket {
410            let bucket = bucket_arc.lock().await;
411            BandwidthStats {
412                current_tokens: bucket.tokens as u64,
413                capacity: bucket.capacity as u64,
414                refill_rate_bytes_per_sec: bucket.refill_rate as u64,
415            }
416        } else {
417            BandwidthStats {
418                current_tokens: self.bandwidth_config.burst_capacity_bytes,
419                capacity: self.bandwidth_config.burst_capacity_bytes,
420                refill_rate_bytes_per_sec: self.bandwidth_config.max_bytes_per_sec,
421            }
422        }
423    }
424
425    /// Get current burst loss state
426    pub async fn get_burst_loss_stats(&self) -> BurstLossStats {
427        let state = self.burst_loss_state.lock().await;
428        BurstLossStats {
429            in_burst: state.in_burst,
430            burst_start: state.burst_start,
431            recovery_start: state.recovery_start,
432        }
433    }
434
435    async fn get_or_create_bucket(
436        &self,
437        bucket_key: &str,
438        effective_limit: u64,
439    ) -> Arc<Mutex<TokenBucket>> {
440        if let Some(existing) = self.token_buckets.read().await.get(bucket_key).cloned() {
441            return existing;
442        }
443
444        let mut buckets = self.token_buckets.write().await;
445        buckets
446            .entry(bucket_key.to_string())
447            .or_insert_with(|| {
448                Arc::new(Mutex::new(TokenBucket::new(
449                    self.bandwidth_config.burst_capacity_bytes,
450                    effective_limit,
451                )))
452            })
453            .clone()
454    }
455
456    fn resolve_bandwidth_bucket(&self, tags: &[String]) -> (String, u64) {
457        if let Some((tag, limit)) = tags.iter().find_map(|tag| {
458            self.bandwidth_config.tag_overrides.get(tag).map(|limit| (tag.as_str(), *limit))
459        }) {
460            (format!("tag:{}", tag), limit)
461        } else {
462            (GLOBAL_BUCKET_KEY.to_string(), self.bandwidth_config.max_bytes_per_sec)
463        }
464    }
465}
466
/// Point-in-time view of a token bucket used for bandwidth throttling.
#[derive(Clone, Debug)]
pub struct BandwidthStats {
    /// Available tokens (bytes) as last recorded by the bucket.
    pub current_tokens: u64,
    /// Bucket capacity in bytes, i.e. the burst allowance.
    pub capacity: u64,
    /// Rate at which tokens are replenished, in bytes per second.
    pub refill_rate_bytes_per_sec: u64,
}
477
/// Point-in-time view of the burst-loss state machine.
#[derive(Clone, Debug)]
pub struct BurstLossStats {
    /// True while a loss burst is active.
    pub in_burst: bool,
    /// Start of the active burst, or `None` when not in a burst.
    pub burst_start: Option<Instant>,
    /// Start of the current recovery period, or `None` when not recovering.
    pub recovery_start: Option<Instant>,
}
488
#[cfg(test)]
mod tests {
    // Several tests measure real sleeps, so the module takes on the order of a second.
    use super::*;
    use std::time::Duration;

    /// A transfer needing more tokens than remain in the bucket must wait for a refill.
    #[tokio::test]
    async fn test_bandwidth_throttling() {
        let config = TrafficShapingConfig {
            bandwidth: BandwidthConfig::new(1000, 100), // 1000 bytes/sec, 100 byte burst
            burst_loss: BurstLossConfig::default(),
        };
        let shaper = TrafficShaper::new(config);

        // 50 bytes fit in the full 100-byte bucket, so this returns without waiting
        let result = shaper.throttle_bandwidth(50, &[]).await;
        assert!(result.is_ok());

        // 80 bytes are within the burst capacity but exceed the ~50 tokens left,
        // so this call has to wait for the bucket to refill
        let start = Instant::now();
        let result = shaper.throttle_bandwidth(80, &[]).await; // 50 + 80 = 130 bytes against a 100-byte bucket
        let elapsed = start.elapsed();
        assert!(result.is_ok());
        // The ~30-byte shortfall refills at 1000 bytes/sec, i.e. roughly 30ms.
        // NOTE(review): the computed wait is marginally under 30ms because a few tokens accrue
        // between the two calls; this presumably passes because tokio's timer rounds sleeps up
        // to whole milliseconds — confirm if the test becomes flaky.
        assert!(elapsed >= Duration::from_millis(30)); // ~30ms to refill the 30-byte shortfall at 1000 bytes/sec
    }

    /// With burst probability and loss rate both at 1.0, every packet during the burst is dropped.
    #[tokio::test]
    async fn test_burst_loss() {
        let config = TrafficShapingConfig {
            bandwidth: BandwidthConfig::default(),
            burst_loss: BurstLossConfig::new(1.0, 1000, 1.0, 1000), // 100% burst probability, 1s bursts, 100% loss, 1s recovery
        };
        let shaper = TrafficShaper::new(config);

        // First packet starts a burst and is dropped
        let should_drop = shaper.should_drop_packet(&[]).await;
        assert!(should_drop);

        // The burst lasts 1s, so these immediate follow-ups are dropped as well
        for _ in 0..5 {
            let should_drop = shaper.should_drop_packet(&[]).await;
            assert!(should_drop);
        }
    }

    /// A tag override must throttle even when the global limit is 0 (unlimited).
    #[tokio::test]
    async fn test_bandwidth_tag_override_with_global_unlimited() {
        let mut bandwidth = BandwidthConfig::default();
        bandwidth.enabled = true;
        bandwidth.max_bytes_per_sec = 0;
        bandwidth.burst_capacity_bytes = 100;
        bandwidth = bandwidth.with_tag_override("limited".to_string(), 100);

        let shaper = TrafficShaper::new(TrafficShapingConfig {
            bandwidth,
            burst_loss: BurstLossConfig::default(),
        });

        let tags = vec!["limited".to_string()];
        // Drains the 100-byte bucket of the "limited" tag
        shaper
            .throttle_bandwidth(100, &tags)
            .await
            .expect("initial transfer should succeed immediately");

        // Another 100 bytes at 100 bytes/sec needs about one second of refill
        let start = Instant::now();
        shaper
            .throttle_bandwidth(100, &tags)
            .await
            .expect("tag override should throttle but eventually succeed");
        assert!(
            start.elapsed() >= Duration::from_millis(900),
            "override-specific transfer should respect configured rate"
        );
    }

    /// The first tag (in order) with an override determines the limit.
    #[test]
    fn test_bandwidth_config_overrides() {
        let mut config = BandwidthConfig::new(1000, 100);
        config = config.with_tag_override("high-priority".to_string(), 5000);

        assert_eq!(config.get_effective_limit(&[]), 1000);
        assert_eq!(config.get_effective_limit(&["high-priority".to_string()]), 5000);
        assert_eq!(
            config.get_effective_limit(&["low-priority".to_string(), "high-priority".to_string()]),
            5000
        );
    }

    /// A matching tag override replaces all four burst parameters.
    #[test]
    fn test_burst_loss_effective_config_override() {
        let override_cfg = BurstLossOverride {
            burst_probability: 0.8,
            burst_duration_ms: 2000,
            loss_rate_during_burst: 0.9,
            recovery_time_ms: 5000,
        };

        let config =
            BurstLossConfig::default().with_tag_override("flaky".to_string(), override_cfg.clone());

        let effective = config.effective_config(&["flaky".to_string()]);
        assert_eq!(effective.burst_probability, override_cfg.burst_probability);
        assert_eq!(effective.burst_duration_ms, override_cfg.burst_duration_ms);
        assert_eq!(effective.loss_rate_during_burst, override_cfg.loss_rate_during_burst);
        assert_eq!(effective.recovery_time_ms, override_cfg.recovery_time_ms);
    }
}