1use crate::Result;
4use rand::Rng;
5use std::collections::HashMap;
6use std::time::Duration;
7
8#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
10#[serde(rename_all = "lowercase")]
11pub enum LatencyDistribution {
12 #[default]
14 Fixed,
15 Normal,
17 Pareto,
19}
20
21#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
23pub struct LatencyProfile {
24 pub base_ms: u64,
26 pub jitter_ms: u64,
28 #[serde(default)]
30 pub distribution: LatencyDistribution,
31 #[serde(default)]
33 pub std_dev_ms: Option<f64>,
34 #[serde(default)]
36 pub pareto_shape: Option<f64>,
37 #[serde(default)]
39 pub min_ms: u64,
40 #[serde(default)]
42 pub max_ms: Option<u64>,
43 pub tag_overrides: HashMap<String, u64>,
45}
46
47impl Default for LatencyProfile {
48 fn default() -> Self {
49 Self {
50 base_ms: 50, jitter_ms: 20, distribution: LatencyDistribution::Fixed,
53 std_dev_ms: None,
54 pareto_shape: None,
55 min_ms: 0,
56 max_ms: None,
57 tag_overrides: HashMap::new(),
58 }
59 }
60}
61
62impl LatencyProfile {
63 pub fn new(base_ms: u64, jitter_ms: u64) -> Self {
65 Self {
66 base_ms,
67 jitter_ms,
68 distribution: LatencyDistribution::Fixed,
69 std_dev_ms: None,
70 pareto_shape: None,
71 min_ms: 0,
72 max_ms: None,
73 tag_overrides: HashMap::new(),
74 }
75 }
76
77 pub fn with_normal_distribution(base_ms: u64, std_dev_ms: f64) -> Self {
79 Self {
80 base_ms,
81 jitter_ms: 0, distribution: LatencyDistribution::Normal,
83 std_dev_ms: Some(std_dev_ms),
84 pareto_shape: None,
85 min_ms: 0,
86 max_ms: None,
87 tag_overrides: HashMap::new(),
88 }
89 }
90
91 pub fn with_pareto_distribution(base_ms: u64, shape: f64) -> Self {
93 Self {
94 base_ms,
95 jitter_ms: 0, distribution: LatencyDistribution::Pareto,
97 std_dev_ms: None,
98 pareto_shape: Some(shape),
99 min_ms: 0,
100 max_ms: None,
101 tag_overrides: HashMap::new(),
102 }
103 }
104
105 pub fn with_tag_override(mut self, tag: String, latency_ms: u64) -> Self {
107 self.tag_overrides.insert(tag, latency_ms);
108 self
109 }
110
111 pub fn with_min_ms(mut self, min_ms: u64) -> Self {
113 self.min_ms = min_ms;
114 self
115 }
116
117 pub fn with_max_ms(mut self, max_ms: u64) -> Self {
119 self.max_ms = Some(max_ms);
120 self
121 }
122
123 pub fn calculate_latency(&self, tags: &[String]) -> Duration {
125 let mut rng = rand::rng();
126
127 if let Some(&override_ms) = tags.iter().find_map(|tag| self.tag_overrides.get(tag)) {
130 return Duration::from_millis(override_ms);
131 }
132
133 let mut latency_ms = match self.distribution {
134 LatencyDistribution::Fixed => {
135 let jitter = if self.jitter_ms > 0 {
137 rng.random_range(0..=self.jitter_ms * 2).saturating_sub(self.jitter_ms)
138 } else {
139 0
140 };
141 self.base_ms.saturating_add(jitter)
142 }
143 LatencyDistribution::Normal => {
144 let std_dev = self.std_dev_ms.unwrap_or((self.base_ms as f64) * 0.2);
146 let mean = self.base_ms as f64;
147
148 let u1: f64 = rng.random();
150 let u2: f64 = rng.random();
151
152 let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
154 (mean + std_dev * z0).max(0.0) as u64
155 }
156 LatencyDistribution::Pareto => {
157 let shape = self.pareto_shape.unwrap_or(2.0);
159 let scale = self.base_ms as f64;
160
161 let u: f64 = rng.random();
163 (scale / (1.0 - u).powf(1.0 / shape)) as u64
164 }
165 };
166
167 latency_ms = latency_ms.max(self.min_ms);
169 if let Some(max_ms) = self.max_ms {
170 latency_ms = latency_ms.min(max_ms);
171 }
172
173 Duration::from_millis(latency_ms)
174 }
175}
176
177#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
179pub struct FaultConfig {
180 pub failure_rate: f64,
182 pub status_codes: Vec<u16>,
184 pub error_responses: HashMap<String, serde_json::Value>,
186}
187
188impl Default for FaultConfig {
189 fn default() -> Self {
190 Self {
191 failure_rate: 0.0,
192 status_codes: vec![500, 502, 503, 504],
193 error_responses: HashMap::new(),
194 }
195 }
196}
197
198impl FaultConfig {
199 pub fn new(failure_rate: f64) -> Self {
201 Self {
202 failure_rate: failure_rate.clamp(0.0, 1.0),
203 ..Default::default()
204 }
205 }
206
207 pub fn with_status_code(mut self, code: u16) -> Self {
209 if !self.status_codes.contains(&code) {
210 self.status_codes.push(code);
211 }
212 self
213 }
214
215 pub fn with_error_response(mut self, key: String, response: serde_json::Value) -> Self {
217 self.error_responses.insert(key, response);
218 self
219 }
220
221 pub fn should_fail(&self) -> bool {
223 if self.failure_rate <= 0.0 {
224 return false;
225 }
226 if self.failure_rate >= 1.0 {
227 return true;
228 }
229
230 let mut rng = rand::rng();
231 rng.random_bool(self.failure_rate)
232 }
233
234 pub fn get_failure_response(&self) -> (u16, Option<serde_json::Value>) {
236 let mut rng = rand::rng();
237
238 let status_code = if self.status_codes.is_empty() {
239 500
240 } else {
241 let index = rng.random_range(0..self.status_codes.len());
242 self.status_codes[index]
243 };
244
245 let error_response = if self.error_responses.is_empty() {
246 None
247 } else {
248 let keys: Vec<&String> = self.error_responses.keys().collect();
249 let key = keys[rng.random_range(0..keys.len())];
250 self.error_responses.get(key).cloned()
251 };
252
253 (status_code, error_response)
254 }
255}
256
257#[derive(Debug, Clone)]
259pub struct LatencyInjector {
260 latency_profile: LatencyProfile,
262 fault_config: FaultConfig,
264 enabled: bool,
266}
267
268impl LatencyInjector {
269 pub fn new(latency_profile: LatencyProfile, fault_config: FaultConfig) -> Self {
271 Self {
272 latency_profile,
273 fault_config,
274 enabled: true,
275 }
276 }
277
278 pub fn set_enabled(&mut self, enabled: bool) {
280 self.enabled = enabled;
281 }
282
283 pub fn is_enabled(&self) -> bool {
285 self.enabled
286 }
287
288 pub async fn inject_latency(&self, tags: &[String]) -> Result<()> {
290 if !self.enabled {
291 return Ok(());
292 }
293
294 let latency = self.latency_profile.calculate_latency(tags);
295 if !latency.is_zero() {
296 tokio::time::sleep(latency).await;
297 }
298
299 Ok(())
300 }
301
302 pub fn should_inject_failure(&self) -> bool {
304 if !self.enabled {
305 return false;
306 }
307
308 self.fault_config.should_fail()
309 }
310
311 pub fn get_failure_response(&self) -> (u16, Option<serde_json::Value>) {
313 self.fault_config.get_failure_response()
314 }
315
316 pub async fn process_request(
318 &self,
319 tags: &[String],
320 ) -> Result<Option<(u16, Option<serde_json::Value>)>> {
321 if !self.enabled {
322 return Ok(None);
323 }
324
325 self.inject_latency(tags).await?;
327
328 if self.should_inject_failure() {
330 let (status, response) = self.get_failure_response();
331 return Ok(Some((status, response)));
332 }
333
334 Ok(None)
335 }
336}
337
338impl Default for LatencyInjector {
339 fn default() -> Self {
340 Self::new(LatencyProfile::default(), FaultConfig::default())
341 }
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_latency_profile_default() {
        let profile = LatencyProfile::default();
        assert_eq!(profile.base_ms, 50);
        assert_eq!(profile.jitter_ms, 20);
        assert_eq!(profile.min_ms, 0);
        assert!(profile.max_ms.is_none());
        assert!(matches!(profile.distribution, LatencyDistribution::Fixed));
    }

    #[test]
    fn test_latency_profile_new() {
        let profile = LatencyProfile::new(100, 25);
        assert_eq!(profile.base_ms, 100);
        assert_eq!(profile.jitter_ms, 25);
        assert!(matches!(profile.distribution, LatencyDistribution::Fixed));
    }

    #[test]
    fn test_latency_profile_normal_distribution() {
        let profile = LatencyProfile::with_normal_distribution(100, 20.0);
        assert_eq!(profile.base_ms, 100);
        assert!(matches!(profile.distribution, LatencyDistribution::Normal));
        assert_eq!(profile.std_dev_ms, Some(20.0));
    }

    #[test]
    fn test_latency_profile_pareto_distribution() {
        let profile = LatencyProfile::with_pareto_distribution(100, 2.5);
        assert_eq!(profile.base_ms, 100);
        assert!(matches!(profile.distribution, LatencyDistribution::Pareto));
        assert_eq!(profile.pareto_shape, Some(2.5));
    }

    #[test]
    fn test_latency_profile_with_tag_override() {
        let profile = LatencyProfile::default()
            .with_tag_override("slow".to_string(), 500)
            .with_tag_override("fast".to_string(), 10);

        assert_eq!(profile.tag_overrides.get("slow"), Some(&500));
        assert_eq!(profile.tag_overrides.get("fast"), Some(&10));
    }

    #[test]
    fn test_latency_profile_with_bounds() {
        let profile = LatencyProfile::default().with_min_ms(10).with_max_ms(1000);

        assert_eq!(profile.min_ms, 10);
        assert_eq!(profile.max_ms, Some(1000));
    }

    #[test]
    fn test_calculate_latency_with_tag_override() {
        let profile = LatencyProfile::default().with_tag_override("slow".to_string(), 500);

        let tags = vec!["slow".to_string()];
        let latency = profile.calculate_latency(&tags);
        assert_eq!(latency, Duration::from_millis(500));
    }

    #[test]
    fn test_calculate_latency_tag_override_bypasses_bounds() {
        // Overrides are returned verbatim; min/max only constrain sampled values.
        let profile = LatencyProfile::new(100, 0)
            .with_max_ms(200)
            .with_tag_override("slow".to_string(), 500);

        let latency = profile.calculate_latency(&["slow".to_string()]);
        assert_eq!(latency, Duration::from_millis(500));
    }

    #[test]
    fn test_calculate_latency_fixed_distribution() {
        let profile = LatencyProfile::new(100, 0);
        let latency = profile.calculate_latency(&[]);
        assert_eq!(latency, Duration::from_millis(100));
    }

    #[test]
    fn test_calculate_latency_fixed_jitter_stays_within_range() {
        let profile = LatencyProfile::new(100, 20);
        let allowed = Duration::from_millis(80)..=Duration::from_millis(120);

        for _ in 0..200 {
            let latency = profile.calculate_latency(&[]);
            assert!(allowed.contains(&latency), "latency {latency:?} outside 100±20 ms");
        }
    }

    #[test]
    fn test_calculate_latency_respects_min_bound() {
        let profile = LatencyProfile::new(10, 0).with_min_ms(50);
        let latency = profile.calculate_latency(&[]);
        assert!(latency >= Duration::from_millis(50));
    }

    #[test]
    fn test_calculate_latency_respects_max_bound() {
        let profile = LatencyProfile::with_pareto_distribution(100, 2.0).with_max_ms(200);

        for _ in 0..100 {
            let latency = profile.calculate_latency(&[]);
            assert!(latency <= Duration::from_millis(200));
        }
    }

    #[test]
    fn test_calculate_latency_normal_respects_bounds() {
        let profile = LatencyProfile::with_normal_distribution(100, 30.0)
            .with_min_ms(50)
            .with_max_ms(150);
        let allowed = Duration::from_millis(50)..=Duration::from_millis(150);

        for _ in 0..200 {
            let latency = profile.calculate_latency(&[]);
            assert!(allowed.contains(&latency), "latency {latency:?} outside [50, 150] ms");
        }
    }

    #[test]
    fn test_fault_config_default() {
        let config = FaultConfig::default();
        assert_eq!(config.failure_rate, 0.0);
        assert!(!config.status_codes.is_empty());
        assert!(config.error_responses.is_empty());
    }

    #[test]
    fn test_fault_config_new() {
        let config = FaultConfig::new(0.5);
        assert_eq!(config.failure_rate, 0.5);
    }

    #[test]
    fn test_fault_config_clamps_failure_rate() {
        let config = FaultConfig::new(1.5);
        assert_eq!(config.failure_rate, 1.0);

        let config = FaultConfig::new(-0.5);
        assert_eq!(config.failure_rate, 0.0);
    }

    #[test]
    fn test_fault_config_with_status_code() {
        let config = FaultConfig::default().with_status_code(400).with_status_code(404);

        assert!(config.status_codes.contains(&400));
        assert!(config.status_codes.contains(&404));
    }

    #[test]
    fn test_fault_config_with_error_response() {
        let response = serde_json::json!({"error": "test"});
        let config =
            FaultConfig::default().with_error_response("test".to_string(), response.clone());

        assert_eq!(config.error_responses.get("test"), Some(&response));
    }

    #[test]
    fn test_fault_config_should_fail_zero_rate() {
        let config = FaultConfig::new(0.0);
        assert!(!config.should_fail());
    }

    #[test]
    fn test_fault_config_should_fail_full_rate() {
        let config = FaultConfig::new(1.0);
        assert!(config.should_fail());
    }

    #[test]
    fn test_fault_config_should_fail_probabilistic() {
        let config = FaultConfig::new(0.5);
        let iterations = 1000;
        let failures = (0..iterations).filter(|_| config.should_fail()).count();

        // ±0.1 around 0.5 is more than six standard deviations for 1000 trials.
        let failure_rate = failures as f64 / iterations as f64;
        assert!(failure_rate > 0.4 && failure_rate < 0.6);
    }

    #[test]
    fn test_fault_config_get_failure_response() {
        let config = FaultConfig::new(1.0).with_status_code(502);

        let (status, _) = config.get_failure_response();
        assert!(config.status_codes.contains(&status));
    }

    #[test]
    fn test_fault_config_get_failure_response_falls_back_to_500() {
        let config = FaultConfig {
            failure_rate: 1.0,
            status_codes: Vec::new(),
            ..Default::default()
        };

        assert_eq!(config.get_failure_response(), (500, None));
    }

    #[test]
    fn test_fault_config_get_failure_response_returns_configured_body() {
        let body = serde_json::json!({"error": "boom"});
        let config = FaultConfig::new(1.0).with_error_response("boom".to_string(), body.clone());

        let (_, response) = config.get_failure_response();
        assert_eq!(response, Some(body));
    }

    #[test]
    fn test_latency_injector_new() {
        let injector = LatencyInjector::new(LatencyProfile::default(), FaultConfig::default());
        assert!(injector.is_enabled());
    }

    #[test]
    fn test_latency_injector_enable_disable() {
        let mut injector = LatencyInjector::default();
        assert!(injector.is_enabled());

        injector.set_enabled(false);
        assert!(!injector.is_enabled());

        injector.set_enabled(true);
        assert!(injector.is_enabled());
    }

    #[tokio::test]
    async fn test_latency_injector_inject_latency() {
        let injector = LatencyInjector::new(LatencyProfile::new(10, 0), FaultConfig::default());

        let start = std::time::Instant::now();
        injector.inject_latency(&[]).await.unwrap();
        let elapsed = start.elapsed();

        assert!(elapsed >= Duration::from_millis(8));
    }

    #[tokio::test]
    async fn test_latency_injector_disabled_no_latency() {
        let mut injector =
            LatencyInjector::new(LatencyProfile::new(100, 0), FaultConfig::default());
        injector.set_enabled(false);

        let start = std::time::Instant::now();
        injector.inject_latency(&[]).await.unwrap();
        let elapsed = start.elapsed();

        // The point is that the 100 ms delay is skipped, not to benchmark the no-op path;
        // a tight bound (e.g. 10 ms) flakes on loaded CI machines.
        assert!(elapsed < Duration::from_millis(50));
    }

    #[test]
    fn test_latency_injector_should_inject_failure() {
        let injector = LatencyInjector::new(LatencyProfile::default(), FaultConfig::new(1.0));

        assert!(injector.should_inject_failure());
    }

    #[test]
    fn test_latency_injector_disabled_no_failure() {
        let mut injector = LatencyInjector::new(LatencyProfile::default(), FaultConfig::new(1.0));
        injector.set_enabled(false);

        assert!(!injector.should_inject_failure());
    }

    #[tokio::test]
    async fn test_latency_injector_process_request_no_failure() {
        let injector = LatencyInjector::new(LatencyProfile::new(10, 0), FaultConfig::new(0.0));

        let result = injector.process_request(&[]).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_latency_injector_process_request_with_failure() {
        let fault_config = FaultConfig {
            failure_rate: 1.0,
            status_codes: vec![503],
            ..Default::default()
        };

        let injector = LatencyInjector::new(LatencyProfile::new(10, 0), fault_config);

        let result = injector.process_request(&[]).await.unwrap();
        assert!(result.is_some());

        let (status, _) = result.unwrap();
        assert_eq!(status, 503);
    }

    #[tokio::test]
    async fn test_latency_injector_process_request_disabled() {
        let mut injector = LatencyInjector::new(LatencyProfile::new(100, 0), FaultConfig::new(1.0));
        injector.set_enabled(false);

        let result = injector.process_request(&[]).await.unwrap();
        assert!(result.is_none());
    }
}