1use crate::error::Result;
6use rand::Rng;
7use std::collections::HashMap;
8use std::time::Duration;
9
10#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
12#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
13#[serde(rename_all = "lowercase")]
14pub enum LatencyDistribution {
15 #[default]
17 Fixed,
18 Normal,
20 Pareto,
22}
23
24#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
26#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
27pub struct LatencyProfile {
28 pub base_ms: u64,
30 pub jitter_ms: u64,
32 #[serde(default)]
34 pub distribution: LatencyDistribution,
35 #[serde(default)]
37 pub std_dev_ms: Option<f64>,
38 #[serde(default)]
40 pub pareto_shape: Option<f64>,
41 #[serde(default)]
43 pub min_ms: u64,
44 #[serde(default)]
46 pub max_ms: Option<u64>,
47 pub tag_overrides: HashMap<String, u64>,
49}
50
51impl Default for LatencyProfile {
52 fn default() -> Self {
53 Self {
54 base_ms: 50, jitter_ms: 20, distribution: LatencyDistribution::Fixed,
57 std_dev_ms: None,
58 pareto_shape: None,
59 min_ms: 0,
60 max_ms: None,
61 tag_overrides: HashMap::new(),
62 }
63 }
64}
65
66impl LatencyProfile {
67 pub fn new(base_ms: u64, jitter_ms: u64) -> Self {
69 Self {
70 base_ms,
71 jitter_ms,
72 distribution: LatencyDistribution::Fixed,
73 std_dev_ms: None,
74 pareto_shape: None,
75 min_ms: 0,
76 max_ms: None,
77 tag_overrides: HashMap::new(),
78 }
79 }
80
81 pub fn with_normal_distribution(base_ms: u64, std_dev_ms: f64) -> Self {
83 Self {
84 base_ms,
85 jitter_ms: 0, distribution: LatencyDistribution::Normal,
87 std_dev_ms: Some(std_dev_ms),
88 pareto_shape: None,
89 min_ms: 0,
90 max_ms: None,
91 tag_overrides: HashMap::new(),
92 }
93 }
94
95 pub fn with_pareto_distribution(base_ms: u64, shape: f64) -> Self {
97 Self {
98 base_ms,
99 jitter_ms: 0, distribution: LatencyDistribution::Pareto,
101 std_dev_ms: None,
102 pareto_shape: Some(shape),
103 min_ms: 0,
104 max_ms: None,
105 tag_overrides: HashMap::new(),
106 }
107 }
108
109 pub fn with_tag_override(mut self, tag: String, latency_ms: u64) -> Self {
111 self.tag_overrides.insert(tag, latency_ms);
112 self
113 }
114
115 pub fn with_min_ms(mut self, min_ms: u64) -> Self {
117 self.min_ms = min_ms;
118 self
119 }
120
121 pub fn with_max_ms(mut self, max_ms: u64) -> Self {
123 self.max_ms = Some(max_ms);
124 self
125 }
126
127 pub fn calculate_latency(&self, tags: &[String]) -> Duration {
129 let mut rng = rand::rng();
130
131 if let Some(&override_ms) = tags.iter().find_map(|tag| self.tag_overrides.get(tag)) {
134 return Duration::from_millis(override_ms);
135 }
136
137 let mut latency_ms = match self.distribution {
138 LatencyDistribution::Fixed => {
139 let jitter = if self.jitter_ms > 0 {
141 rng.random_range(0..=self.jitter_ms * 2).saturating_sub(self.jitter_ms)
142 } else {
143 0
144 };
145 self.base_ms.saturating_add(jitter)
146 }
147 LatencyDistribution::Normal => {
148 let std_dev = self.std_dev_ms.unwrap_or((self.base_ms as f64) * 0.2);
150 let mean = self.base_ms as f64;
151
152 let u1: f64 = rng.random();
154 let u2: f64 = rng.random();
155
156 let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
158 (mean + std_dev * z0).max(0.0) as u64
159 }
160 LatencyDistribution::Pareto => {
161 let shape = self.pareto_shape.unwrap_or(2.0);
163 let scale = self.base_ms as f64;
164
165 let u: f64 = rng.random();
167 (scale / (1.0 - u).powf(1.0 / shape)) as u64
168 }
169 };
170
171 latency_ms = latency_ms.max(self.min_ms);
173 if let Some(max_ms) = self.max_ms {
174 latency_ms = latency_ms.min(max_ms);
175 }
176
177 Duration::from_millis(latency_ms)
178 }
179}
180
181#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
183pub struct FaultConfig {
184 pub failure_rate: f64,
186 pub status_codes: Vec<u16>,
188 pub error_responses: HashMap<String, serde_json::Value>,
190}
191
192impl Default for FaultConfig {
193 fn default() -> Self {
194 Self {
195 failure_rate: 0.0,
196 status_codes: vec![500, 502, 503, 504],
197 error_responses: HashMap::new(),
198 }
199 }
200}
201
202impl FaultConfig {
203 pub fn new(failure_rate: f64) -> Self {
205 Self {
206 failure_rate: failure_rate.clamp(0.0, 1.0),
207 ..Default::default()
208 }
209 }
210
211 pub fn with_status_code(mut self, code: u16) -> Self {
213 if !self.status_codes.contains(&code) {
214 self.status_codes.push(code);
215 }
216 self
217 }
218
219 pub fn with_error_response(mut self, key: String, response: serde_json::Value) -> Self {
221 self.error_responses.insert(key, response);
222 self
223 }
224
225 pub fn should_fail(&self) -> bool {
227 if self.failure_rate <= 0.0 {
228 return false;
229 }
230 if self.failure_rate >= 1.0 {
231 return true;
232 }
233
234 let mut rng = rand::rng();
235 rng.random_bool(self.failure_rate)
236 }
237
238 pub fn get_failure_response(&self) -> (u16, Option<serde_json::Value>) {
240 let mut rng = rand::rng();
241
242 let status_code = if self.status_codes.is_empty() {
243 500
244 } else {
245 let index = rng.random_range(0..self.status_codes.len());
246 self.status_codes[index]
247 };
248
249 let error_response = if self.error_responses.is_empty() {
250 None
251 } else {
252 let keys: Vec<&String> = self.error_responses.keys().collect();
253 let key = keys[rng.random_range(0..keys.len())];
254 self.error_responses.get(key).cloned()
255 };
256
257 (status_code, error_response)
258 }
259}
260
261#[derive(Debug, Clone)]
263pub struct LatencyInjector {
264 latency_profile: LatencyProfile,
266 fault_config: FaultConfig,
268 enabled: bool,
270}
271
272impl LatencyInjector {
273 pub fn new(latency_profile: LatencyProfile, fault_config: FaultConfig) -> Self {
275 Self {
276 latency_profile,
277 fault_config,
278 enabled: true,
279 }
280 }
281
282 pub fn set_enabled(&mut self, enabled: bool) {
284 self.enabled = enabled;
285 }
286
287 pub fn is_enabled(&self) -> bool {
289 self.enabled
290 }
291
292 pub async fn inject_latency(&self, tags: &[String]) -> Result<()> {
294 if !self.enabled {
295 return Ok(());
296 }
297
298 let latency = self.latency_profile.calculate_latency(tags);
299 if !latency.is_zero() {
300 tokio::time::sleep(latency).await;
301 }
302
303 Ok(())
304 }
305
306 pub fn should_inject_failure(&self) -> bool {
308 if !self.enabled {
309 return false;
310 }
311
312 self.fault_config.should_fail()
313 }
314
315 pub fn get_failure_response(&self) -> (u16, Option<serde_json::Value>) {
317 self.fault_config.get_failure_response()
318 }
319
320 pub async fn process_request(
322 &self,
323 tags: &[String],
324 ) -> Result<Option<(u16, Option<serde_json::Value>)>> {
325 if !self.enabled {
326 return Ok(None);
327 }
328
329 self.inject_latency(tags).await?;
331
332 if self.should_inject_failure() {
334 let (status, response) = self.get_failure_response();
335 return Ok(Some((status, response)));
336 }
337
338 Ok(None)
339 }
340
341 pub fn profile(&self) -> &LatencyProfile {
346 &self.latency_profile
347 }
348
349 pub fn update_profile(&mut self, profile: LatencyProfile) {
354 self.latency_profile = profile;
355 }
356
357 pub async fn update_profile_async(
365 this: &std::sync::Arc<tokio::sync::RwLock<Self>>,
366 profile: LatencyProfile,
367 ) -> Result<()> {
368 let mut injector = this.write().await;
369 injector.update_profile(profile);
370 Ok(())
371 }
372}
373
374impl Default for LatencyInjector {
375 fn default() -> Self {
376 Self::new(LatencyProfile::default(), FaultConfig::default())
377 }
378}
379
380#[cfg(test)]
381mod tests {
382 use super::*;
383
384 #[tokio::test]
385 async fn test_update_profile() {
386 let mut injector =
387 LatencyInjector::new(LatencyProfile::new(50, 20), FaultConfig::default());
388
389 let new_profile = LatencyProfile::new(100, 30);
391 injector.update_profile(new_profile.clone());
392
393 assert!(injector.is_enabled());
396 }
397
398 #[tokio::test]
399 async fn test_update_profile_async() {
400 use std::sync::Arc;
401 use tokio::sync::RwLock;
402
403 let injector = Arc::new(RwLock::new(LatencyInjector::new(
404 LatencyProfile::new(50, 20),
405 FaultConfig::default(),
406 )));
407
408 let new_profile = LatencyProfile::new(200, 50);
410 LatencyInjector::update_profile_async(&injector, new_profile).await.unwrap();
411
412 assert!(injector.read().await.is_enabled());
414 }
415
416 #[test]
417 fn test_latency_profile_default() {
418 let profile = LatencyProfile::default();
419 assert_eq!(profile.base_ms, 50);
420 assert_eq!(profile.jitter_ms, 20);
421 assert_eq!(profile.min_ms, 0);
422 assert!(profile.max_ms.is_none());
423 assert!(matches!(profile.distribution, LatencyDistribution::Fixed));
424 }
425
426 #[test]
427 fn test_latency_profile_new() {
428 let profile = LatencyProfile::new(100, 25);
429 assert_eq!(profile.base_ms, 100);
430 assert_eq!(profile.jitter_ms, 25);
431 assert!(matches!(profile.distribution, LatencyDistribution::Fixed));
432 }
433
434 #[test]
435 fn test_latency_profile_normal_distribution() {
436 let profile = LatencyProfile::with_normal_distribution(100, 20.0);
437 assert_eq!(profile.base_ms, 100);
438 assert!(matches!(profile.distribution, LatencyDistribution::Normal));
439 assert_eq!(profile.std_dev_ms, Some(20.0));
440 }
441
442 #[test]
443 fn test_latency_profile_pareto_distribution() {
444 let profile = LatencyProfile::with_pareto_distribution(100, 2.5);
445 assert_eq!(profile.base_ms, 100);
446 assert!(matches!(profile.distribution, LatencyDistribution::Pareto));
447 assert_eq!(profile.pareto_shape, Some(2.5));
448 }
449
450 #[test]
451 fn test_latency_profile_with_tag_override() {
452 let profile = LatencyProfile::default()
453 .with_tag_override("slow".to_string(), 500)
454 .with_tag_override("fast".to_string(), 10);
455
456 assert_eq!(profile.tag_overrides.get("slow"), Some(&500));
457 assert_eq!(profile.tag_overrides.get("fast"), Some(&10));
458 }
459
460 #[test]
461 fn test_latency_profile_with_bounds() {
462 let profile = LatencyProfile::default().with_min_ms(10).with_max_ms(1000);
463
464 assert_eq!(profile.min_ms, 10);
465 assert_eq!(profile.max_ms, Some(1000));
466 }
467
468 #[test]
469 fn test_calculate_latency_with_tag_override() {
470 let profile = LatencyProfile::default().with_tag_override("slow".to_string(), 500);
471
472 let tags = vec!["slow".to_string()];
473 let latency = profile.calculate_latency(&tags);
474 assert_eq!(latency, Duration::from_millis(500));
475 }
476
477 #[test]
478 fn test_calculate_latency_fixed_distribution() {
479 let profile = LatencyProfile::new(100, 0);
480 let tags = Vec::new();
481 let latency = profile.calculate_latency(&tags);
482 assert_eq!(latency, Duration::from_millis(100));
483 }
484
485 #[test]
486 fn test_calculate_latency_respects_min_bound() {
487 let profile = LatencyProfile::new(10, 0).with_min_ms(50);
488 let tags = Vec::new();
489 let latency = profile.calculate_latency(&tags);
490 assert!(latency >= Duration::from_millis(50));
491 }
492
493 #[test]
494 fn test_calculate_latency_respects_max_bound() {
495 let profile = LatencyProfile::with_pareto_distribution(100, 2.0).with_max_ms(200);
496
497 for _ in 0..100 {
498 let latency = profile.calculate_latency(&[]);
499 assert!(latency <= Duration::from_millis(200));
500 }
501 }
502
503 #[test]
504 fn test_fault_config_default() {
505 let config = FaultConfig::default();
506 assert_eq!(config.failure_rate, 0.0);
507 assert!(!config.status_codes.is_empty());
508 assert!(config.error_responses.is_empty());
509 }
510
511 #[test]
512 fn test_fault_config_new() {
513 let config = FaultConfig::new(0.5);
514 assert_eq!(config.failure_rate, 0.5);
515 }
516
517 #[test]
518 fn test_fault_config_clamps_failure_rate() {
519 let config = FaultConfig::new(1.5);
520 assert_eq!(config.failure_rate, 1.0);
521
522 let config = FaultConfig::new(-0.5);
523 assert_eq!(config.failure_rate, 0.0);
524 }
525
526 #[test]
527 fn test_fault_config_with_status_code() {
528 let config = FaultConfig::default().with_status_code(400).with_status_code(404);
529
530 assert!(config.status_codes.contains(&400));
531 assert!(config.status_codes.contains(&404));
532 }
533
534 #[test]
535 fn test_fault_config_with_error_response() {
536 let response = serde_json::json!({"error": "test"});
537 let config =
538 FaultConfig::default().with_error_response("test".to_string(), response.clone());
539
540 assert_eq!(config.error_responses.get("test"), Some(&response));
541 }
542
543 #[test]
544 fn test_fault_config_should_fail_zero_rate() {
545 let config = FaultConfig::new(0.0);
546 assert!(!config.should_fail());
547 }
548
549 #[test]
550 fn test_fault_config_should_fail_full_rate() {
551 let config = FaultConfig::new(1.0);
552 assert!(config.should_fail());
553 }
554
555 #[test]
556 fn test_fault_config_should_fail_probabilistic() {
557 let config = FaultConfig::new(0.5);
558 let mut failures = 0;
559 let iterations = 1000;
560
561 for _ in 0..iterations {
562 if config.should_fail() {
563 failures += 1;
564 }
565 }
566
567 let failure_rate = failures as f64 / iterations as f64;
569 assert!(failure_rate > 0.4 && failure_rate < 0.6);
570 }
571
572 #[test]
573 fn test_fault_config_get_failure_response() {
574 let config = FaultConfig::new(1.0).with_status_code(502);
575
576 let (status, _) = config.get_failure_response();
577 assert!(config.status_codes.contains(&status));
578 }
579
580 #[test]
581 fn test_latency_injector_new() {
582 let injector = LatencyInjector::new(LatencyProfile::default(), FaultConfig::default());
583 assert!(injector.is_enabled());
584 }
585
586 #[test]
587 fn test_latency_injector_enable_disable() {
588 let mut injector = LatencyInjector::default();
589 assert!(injector.is_enabled());
590
591 injector.set_enabled(false);
592 assert!(!injector.is_enabled());
593
594 injector.set_enabled(true);
595 assert!(injector.is_enabled());
596 }
597
598 #[tokio::test]
599 async fn test_latency_injector_inject_latency() {
600 let injector = LatencyInjector::new(LatencyProfile::new(10, 0), FaultConfig::default());
601
602 let start = std::time::Instant::now();
603 injector.inject_latency(&[]).await.unwrap();
604 let elapsed = start.elapsed();
605
606 assert!(elapsed >= Duration::from_millis(8));
607 }
608
609 #[tokio::test]
610 async fn test_latency_injector_disabled_no_latency() {
611 let mut injector =
612 LatencyInjector::new(LatencyProfile::new(100, 0), FaultConfig::default());
613 injector.set_enabled(false);
614
615 let start = std::time::Instant::now();
616 injector.inject_latency(&[]).await.unwrap();
617 let elapsed = start.elapsed();
618
619 assert!(elapsed < Duration::from_millis(10));
620 }
621
622 #[test]
623 fn test_latency_injector_should_inject_failure() {
624 let injector = LatencyInjector::new(LatencyProfile::default(), FaultConfig::new(1.0));
625
626 assert!(injector.should_inject_failure());
627 }
628
629 #[test]
630 fn test_latency_injector_disabled_no_failure() {
631 let mut injector = LatencyInjector::new(LatencyProfile::default(), FaultConfig::new(1.0));
632 injector.set_enabled(false);
633
634 assert!(!injector.should_inject_failure());
635 }
636
637 #[tokio::test]
638 async fn test_latency_injector_process_request_no_failure() {
639 let injector = LatencyInjector::new(LatencyProfile::new(10, 0), FaultConfig::new(0.0));
640
641 let result = injector.process_request(&[]).await.unwrap();
642 assert!(result.is_none());
643 }
644
645 #[tokio::test]
646 async fn test_latency_injector_process_request_with_failure() {
647 let fault_config = FaultConfig {
648 failure_rate: 1.0,
649 status_codes: vec![503], ..Default::default()
651 };
652
653 let injector = LatencyInjector::new(LatencyProfile::new(10, 0), fault_config);
654
655 let result = injector.process_request(&[]).await.unwrap();
656 assert!(result.is_some());
657
658 let (status, _) = result.unwrap();
659 assert_eq!(status, 503);
660 }
661
662 #[tokio::test]
663 async fn test_latency_injector_process_request_disabled() {
664 let mut injector = LatencyInjector::new(LatencyProfile::new(100, 0), FaultConfig::new(1.0));
665 injector.set_enabled(false);
666
667 let result = injector.process_request(&[]).await.unwrap();
668 assert!(result.is_none());
669 }
670}