use crate::Result;
use rand::Rng;
use std::collections::HashMap;
use std::time::Duration;

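/// Statistical distribution used to sample artificial latency values.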
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum LatencyDistribution {
    #[default]
    Fixed,
    Normal,
    Pareto,
}

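/// Configuration describing the artificial latency applied to simulated requests.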
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct LatencyProfile {
    /// Baseline latency in milliseconds.
    pub base_ms: u64,
    /// Random jitter in milliseconds applied around `base_ms` (used by `Fixed`).
    pub jitter_ms: u64,
    /// Distribution used to sample latency values.
    #[serde(default)]
    pub distribution: LatencyDistribution,
    /// Standard deviation for `Normal`; defaults to 20% of `base_ms` when unset.
    #[serde(default)]
    pub std_dev_ms: Option<f64>,
    /// Shape parameter for `Pareto`; defaults to 2.0 when unset.
    #[serde(default)]
    pub pareto_shape: Option<f64>,
    /// Lower bound applied to every sampled latency.
    #[serde(default)]
    pub min_ms: u64,
    /// Optional upper bound applied to every sampled latency.
    #[serde(default)]
    pub max_ms: Option<u64>,
    /// Fixed latency overrides keyed by request tag.
    pub tag_overrides: HashMap<String, u64>,
}

impl Default for LatencyProfile {
    fn default() -> Self {
        Self {
            base_ms: 50,
            jitter_ms: 20,
            distribution: LatencyDistribution::Fixed,
            std_dev_ms: None,
            pareto_shape: None,
            min_ms: 0,
            max_ms: None,
            tag_overrides: HashMap::new(),
        }
    }
}

impl LatencyProfile {
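    /// Create a profile that uses the `Fixed` distribution with the given
    /// base latency and jitter.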
    pub fn new(base_ms: u64, jitter_ms: u64) -> Self {
        Self {
            base_ms,
            jitter_ms,
            distribution: LatencyDistribution::Fixed,
            std_dev_ms: None,
            pareto_shape: None,
            min_ms: 0,
            max_ms: None,
            tag_overrides: HashMap::new(),
        }
    }

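    /// Create a profile that samples latency from a normal distribution
    /// centered on `base_ms` with the given standard deviation.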
    pub fn with_normal_distribution(base_ms: u64, std_dev_ms: f64) -> Self {
        Self {
            base_ms,
            jitter_ms: 0,
            distribution: LatencyDistribution::Normal,
            std_dev_ms: Some(std_dev_ms),
            pareto_shape: None,
            min_ms: 0,
            max_ms: None,
            tag_overrides: HashMap::new(),
        }
    }

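    /// Create a profile that samples latency from a Pareto distribution with
    /// `base_ms` as the scale (minimum) and the given shape parameter.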
    pub fn with_pareto_distribution(base_ms: u64, shape: f64) -> Self {
        Self {
            base_ms,
            jitter_ms: 0,
            distribution: LatencyDistribution::Pareto,
            std_dev_ms: None,
            pareto_shape: Some(shape),
            min_ms: 0,
            max_ms: None,
            tag_overrides: HashMap::new(),
        }
    }

    pub fn with_tag_override(mut self, tag: String, latency_ms: u64) -> Self {
        self.tag_overrides.insert(tag, latency_ms);
        self
    }

    pub fn with_min_ms(mut self, min_ms: u64) -> Self {
        self.min_ms = min_ms;
        self
    }

    pub fn with_max_ms(mut self, max_ms: u64) -> Self {
        self.max_ms = Some(max_ms);
        self
    }

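    /// Sample a latency for a request with the given tags. Tag overrides take
    /// precedence; otherwise a value is drawn from the configured distribution
    /// and clamped to `[min_ms, max_ms]`.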
    pub fn calculate_latency(&self, tags: &[String]) -> Duration {
        let mut rng = rand::rng();

        // A matching tag override short-circuits the distribution entirely.
        if let Some(&override_ms) = tags.iter().find_map(|tag| self.tag_overrides.get(tag)) {
            return Duration::from_millis(override_ms);
        }

        let mut latency_ms = match self.distribution {
            LatencyDistribution::Fixed => {
                if self.jitter_ms > 0 {
                    // Uniform jitter in [base - jitter, base + jitter], saturating at zero.
                    let offset = rng.random_range(0..=self.jitter_ms * 2);
                    self.base_ms
                        .saturating_add(offset)
                        .saturating_sub(self.jitter_ms)
                } else {
                    self.base_ms
                }
            }
            LatencyDistribution::Normal => {
                let std_dev = self.std_dev_ms.unwrap_or((self.base_ms as f64) * 0.2);
                let mean = self.base_ms as f64;

                // Box-Muller transform; keep u1 strictly positive so ln(u1) stays finite.
                let u1: f64 = rng.random::<f64>().max(f64::MIN_POSITIVE);
                let u2: f64 = rng.random();

                let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
                (mean + std_dev * z0).max(0.0) as u64
            }
            LatencyDistribution::Pareto => {
                let shape = self.pareto_shape.unwrap_or(2.0);
                let scale = self.base_ms as f64;

                // Inverse-CDF sampling; `u` is in [0, 1), so the divisor stays positive.
                let u: f64 = rng.random();
                (scale / (1.0 - u).powf(1.0 / shape)) as u64
            }
        };

        latency_ms = latency_ms.max(self.min_ms);
        if let Some(max_ms) = self.max_ms {
            latency_ms = latency_ms.min(max_ms);
        }

        Duration::from_millis(latency_ms)
    }
}

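/// Configuration for probabilistic fault injection.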
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct FaultConfig {
    /// Probability in `[0.0, 1.0]` that a request fails.
    pub failure_rate: f64,
    /// Pool of HTTP status codes to choose from when a failure is injected.
    pub status_codes: Vec<u16>,
    /// Named JSON error bodies; one is picked at random for injected failures.
    pub error_responses: HashMap<String, serde_json::Value>,
}

impl Default for FaultConfig {
    fn default() -> Self {
        Self {
            failure_rate: 0.0,
            status_codes: vec![500, 502, 503, 504],
            error_responses: HashMap::new(),
        }
    }
}

impl FaultConfig {
    pub fn new(failure_rate: f64) -> Self {
        Self {
            failure_rate: failure_rate.clamp(0.0, 1.0),
            ..Default::default()
        }
    }

    pub fn with_status_code(mut self, code: u16) -> Self {
        if !self.status_codes.contains(&code) {
            self.status_codes.push(code);
        }
        self
    }

    pub fn with_error_response(mut self, key: String, response: serde_json::Value) -> Self {
        self.error_responses.insert(key, response);
        self
    }

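    /// Decide whether the current request should fail, based on `failure_rate`.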
    pub fn should_fail(&self) -> bool {
        if self.failure_rate <= 0.0 {
            return false;
        }
        if self.failure_rate >= 1.0 {
            return true;
        }

        let mut rng = rand::rng();
        rng.random_bool(self.failure_rate)
    }

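    /// Pick a status code and, if any are configured, an error body for an
    /// injected failure.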
    pub fn get_failure_response(&self) -> (u16, Option<serde_json::Value>) {
        let mut rng = rand::rng();

        let status_code = if self.status_codes.is_empty() {
            500
        } else {
            let index = rng.random_range(0..self.status_codes.len());
            self.status_codes[index]
        };

        let error_response = if self.error_responses.is_empty() {
            None
        } else {
            let keys: Vec<&String> = self.error_responses.keys().collect();
            let key = keys[rng.random_range(0..keys.len())];
            self.error_responses.get(key).cloned()
        };

        (status_code, error_response)
    }
}

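/// Combines a [`LatencyProfile`] and a [`FaultConfig`] and applies them to
/// incoming requests when enabled.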
#[derive(Debug, Clone)]
pub struct LatencyInjector {
    latency_profile: LatencyProfile,
    fault_config: FaultConfig,
    enabled: bool,
}

impl LatencyInjector {
    pub fn new(latency_profile: LatencyProfile, fault_config: FaultConfig) -> Self {
        Self {
            latency_profile,
            fault_config,
            enabled: true,
        }
    }

    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

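    /// Sleep for the latency computed from the profile; a no-op when disabled.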
    pub async fn inject_latency(&self, tags: &[String]) -> Result<()> {
        if !self.enabled {
            return Ok(());
        }

        let latency = self.latency_profile.calculate_latency(tags);
        if !latency.is_zero() {
            tokio::time::sleep(latency).await;
        }

        Ok(())
    }

    pub fn should_inject_failure(&self) -> bool {
        if !self.enabled {
            return false;
        }

        self.fault_config.should_fail()
    }

    pub fn get_failure_response(&self) -> (u16, Option<serde_json::Value>) {
        self.fault_config.get_failure_response()
    }

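    /// Apply latency and fault injection to a request. Returns `Some((status,
    /// body))` when a failure should be returned instead of the real response,
    /// and `None` when the request should proceed normally.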
    pub async fn process_request(
        &self,
        tags: &[String],
    ) -> Result<Option<(u16, Option<serde_json::Value>)>> {
        if !self.enabled {
            return Ok(None);
        }

        self.inject_latency(tags).await?;

        if self.should_inject_failure() {
            let (status, response) = self.get_failure_response();
            return Ok(Some((status, response)));
        }

        Ok(None)
    }

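    /// Replace the latency profile in place.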
    pub fn update_profile(&mut self, profile: LatencyProfile) {
        self.latency_profile = profile;
    }

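    /// Replace the latency profile on an injector shared behind an
    /// `Arc<RwLock<_>>`, taking the write lock for the duration of the update.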
    pub async fn update_profile_async(
        this: &std::sync::Arc<tokio::sync::RwLock<Self>>,
        profile: LatencyProfile,
    ) -> Result<()> {
        let mut injector = this.write().await;
        injector.update_profile(profile);
        Ok(())
    }
}

impl Default for LatencyInjector {
    fn default() -> Self {
        Self::new(LatencyProfile::default(), FaultConfig::default())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_update_profile() {
        let mut injector =
            LatencyInjector::new(LatencyProfile::new(50, 20), FaultConfig::default());

        let new_profile = LatencyProfile::new(100, 30);
        injector.update_profile(new_profile.clone());

        // The injector should now hold the new profile and remain enabled.
        assert_eq!(injector.latency_profile.base_ms, new_profile.base_ms);
        assert_eq!(injector.latency_profile.jitter_ms, new_profile.jitter_ms);
        assert!(injector.is_enabled());
    }

    #[tokio::test]
    async fn test_update_profile_async() {
        use std::sync::Arc;
        use tokio::sync::RwLock;

        let injector = Arc::new(RwLock::new(LatencyInjector::new(
            LatencyProfile::new(50, 20),
            FaultConfig::default(),
        )));

        let new_profile = LatencyProfile::new(200, 50);
        LatencyInjector::update_profile_async(&injector, new_profile)
            .await
            .unwrap();

        let guard = injector.read().await;
        assert_eq!(guard.latency_profile.base_ms, 200);
        assert!(guard.is_enabled());
    }

    #[test]
    fn test_latency_profile_default() {
        let profile = LatencyProfile::default();
        assert_eq!(profile.base_ms, 50);
        assert_eq!(profile.jitter_ms, 20);
        assert_eq!(profile.min_ms, 0);
        assert!(profile.max_ms.is_none());
        assert!(matches!(profile.distribution, LatencyDistribution::Fixed));
    }

    #[test]
    fn test_latency_profile_new() {
        let profile = LatencyProfile::new(100, 25);
        assert_eq!(profile.base_ms, 100);
        assert_eq!(profile.jitter_ms, 25);
        assert!(matches!(profile.distribution, LatencyDistribution::Fixed));
    }

    #[test]
    fn test_latency_profile_normal_distribution() {
        let profile = LatencyProfile::with_normal_distribution(100, 20.0);
        assert_eq!(profile.base_ms, 100);
        assert!(matches!(profile.distribution, LatencyDistribution::Normal));
        assert_eq!(profile.std_dev_ms, Some(20.0));
    }

    #[test]
    fn test_latency_profile_pareto_distribution() {
        let profile = LatencyProfile::with_pareto_distribution(100, 2.5);
        assert_eq!(profile.base_ms, 100);
        assert!(matches!(profile.distribution, LatencyDistribution::Pareto));
        assert_eq!(profile.pareto_shape, Some(2.5));
    }

    #[test]
    fn test_latency_profile_with_tag_override() {
        let profile = LatencyProfile::default()
            .with_tag_override("slow".to_string(), 500)
            .with_tag_override("fast".to_string(), 10);

        assert_eq!(profile.tag_overrides.get("slow"), Some(&500));
        assert_eq!(profile.tag_overrides.get("fast"), Some(&10));
    }

    #[test]
    fn test_latency_profile_with_bounds() {
        let profile = LatencyProfile::default().with_min_ms(10).with_max_ms(1000);

        assert_eq!(profile.min_ms, 10);
        assert_eq!(profile.max_ms, Some(1000));
    }

    #[test]
    fn test_calculate_latency_with_tag_override() {
        let profile = LatencyProfile::default().with_tag_override("slow".to_string(), 500);

        let tags = vec!["slow".to_string()];
        let latency = profile.calculate_latency(&tags);
        assert_eq!(latency, Duration::from_millis(500));
    }

    #[test]
    fn test_calculate_latency_fixed_distribution() {
        let profile = LatencyProfile::new(100, 0);
        let tags = Vec::new();
        let latency = profile.calculate_latency(&tags);
        assert_eq!(latency, Duration::from_millis(100));
    }

    #[test]
    fn test_calculate_latency_respects_min_bound() {
        let profile = LatencyProfile::new(10, 0).with_min_ms(50);
        let tags = Vec::new();
        let latency = profile.calculate_latency(&tags);
        assert!(latency >= Duration::from_millis(50));
    }

    #[test]
    fn test_calculate_latency_respects_max_bound() {
        let profile = LatencyProfile::with_pareto_distribution(100, 2.0).with_max_ms(200);

        for _ in 0..100 {
            let latency = profile.calculate_latency(&[]);
            assert!(latency <= Duration::from_millis(200));
        }
    }

    #[test]
    fn test_fault_config_default() {
        let config = FaultConfig::default();
        assert_eq!(config.failure_rate, 0.0);
        assert!(!config.status_codes.is_empty());
        assert!(config.error_responses.is_empty());
    }

    #[test]
    fn test_fault_config_new() {
        let config = FaultConfig::new(0.5);
        assert_eq!(config.failure_rate, 0.5);
    }

    #[test]
    fn test_fault_config_clamps_failure_rate() {
        let config = FaultConfig::new(1.5);
        assert_eq!(config.failure_rate, 1.0);

        let config = FaultConfig::new(-0.5);
        assert_eq!(config.failure_rate, 0.0);
    }

    #[test]
    fn test_fault_config_with_status_code() {
        let config = FaultConfig::default()
            .with_status_code(400)
            .with_status_code(404);

        assert!(config.status_codes.contains(&400));
        assert!(config.status_codes.contains(&404));
    }

    #[test]
    fn test_fault_config_with_error_response() {
        let response = serde_json::json!({"error": "test"});
        let config =
            FaultConfig::default().with_error_response("test".to_string(), response.clone());

        assert_eq!(config.error_responses.get("test"), Some(&response));
    }

    #[test]
    fn test_fault_config_should_fail_zero_rate() {
        let config = FaultConfig::new(0.0);
        assert!(!config.should_fail());
    }

    #[test]
    fn test_fault_config_should_fail_full_rate() {
        let config = FaultConfig::new(1.0);
        assert!(config.should_fail());
    }

    #[test]
    fn test_fault_config_should_fail_probabilistic() {
        let config = FaultConfig::new(0.5);
        let mut failures = 0;
        let iterations = 1000;

        for _ in 0..iterations {
            if config.should_fail() {
                failures += 1;
            }
        }

        let failure_rate = failures as f64 / iterations as f64;
        assert!(failure_rate > 0.4 && failure_rate < 0.6);
    }

    #[test]
    fn test_fault_config_get_failure_response() {
        let config = FaultConfig::new(1.0).with_status_code(502);

        let (status, _) = config.get_failure_response();
        assert!(config.status_codes.contains(&status));
    }

    #[test]
    fn test_latency_injector_new() {
        let injector = LatencyInjector::new(LatencyProfile::default(), FaultConfig::default());
        assert!(injector.is_enabled());
    }

    #[test]
    fn test_latency_injector_enable_disable() {
        let mut injector = LatencyInjector::default();
        assert!(injector.is_enabled());

        injector.set_enabled(false);
        assert!(!injector.is_enabled());

        injector.set_enabled(true);
        assert!(injector.is_enabled());
    }

    #[tokio::test]
    async fn test_latency_injector_inject_latency() {
        let injector = LatencyInjector::new(LatencyProfile::new(10, 0), FaultConfig::default());

        let start = std::time::Instant::now();
        injector.inject_latency(&[]).await.unwrap();
        let elapsed = start.elapsed();

        assert!(elapsed >= Duration::from_millis(8));
    }

    #[tokio::test]
    async fn test_latency_injector_disabled_no_latency() {
        let mut injector =
            LatencyInjector::new(LatencyProfile::new(100, 0), FaultConfig::default());
        injector.set_enabled(false);

        let start = std::time::Instant::now();
        injector.inject_latency(&[]).await.unwrap();
        let elapsed = start.elapsed();

        assert!(elapsed < Duration::from_millis(10));
    }

    #[test]
    fn test_latency_injector_should_inject_failure() {
        let injector = LatencyInjector::new(LatencyProfile::default(), FaultConfig::new(1.0));

        assert!(injector.should_inject_failure());
    }

    #[test]
    fn test_latency_injector_disabled_no_failure() {
        let mut injector = LatencyInjector::new(LatencyProfile::default(), FaultConfig::new(1.0));
        injector.set_enabled(false);

        assert!(!injector.should_inject_failure());
    }

    #[tokio::test]
    async fn test_latency_injector_process_request_no_failure() {
        let injector = LatencyInjector::new(LatencyProfile::new(10, 0), FaultConfig::new(0.0));

        let result = injector.process_request(&[]).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_latency_injector_process_request_with_failure() {
        let fault_config = FaultConfig {
            failure_rate: 1.0,
            status_codes: vec![503],
            ..Default::default()
        };

        let injector = LatencyInjector::new(LatencyProfile::new(10, 0), fault_config);

        let result = injector.process_request(&[]).await.unwrap();
        assert!(result.is_some());

        let (status, _) = result.unwrap();
        assert_eq!(status, 503);
    }

    #[tokio::test]
    async fn test_latency_injector_process_request_disabled() {
        let mut injector = LatencyInjector::new(LatencyProfile::new(100, 0), FaultConfig::new(1.0));
        injector.set_enabled(false);

        let result = injector.process_request(&[]).await.unwrap();
        assert!(result.is_none());
    }
}