1use crate::config::{
8 LatencyDistribution, RouteConfig, RouteFaultInjectionConfig, RouteFaultType, RouteLatencyConfig,
9};
10use crate::{Error, Result};
11use axum::http::{HeaderMap, Method, StatusCode, Uri};
12use rand::{rng, Rng};
13use regex::Regex;
14use std::collections::HashMap;
15use std::time::Duration;
16use tokio::time::sleep;
17use tracing::debug;
18
19#[derive(Debug, Clone)]
21pub struct RouteMatcher {
22 routes: Vec<CompiledRoute>,
24}
25
26#[derive(Debug, Clone)]
28struct CompiledRoute {
29 config: RouteConfig,
31 path_pattern: Regex,
33 method: Method,
35}
36
37impl RouteMatcher {
38 pub fn new(routes: Vec<RouteConfig>) -> Result<Self> {
40 let mut compiled_routes = Vec::new();
41
42 for route in routes {
43 let path_pattern = Self::compile_path_pattern(&route.path)?;
45 let method = route.method.parse::<Method>().map_err(|e| {
46 Error::generic(format!("Invalid HTTP method '{}': {}", route.method, e))
47 })?;
48
49 compiled_routes.push(CompiledRoute {
50 config: route,
51 path_pattern,
52 method,
53 });
54 }
55
56 Ok(Self {
57 routes: compiled_routes,
58 })
59 }
60
61 pub fn match_route(&self, method: &Method, uri: &Uri) -> Option<&RouteConfig> {
63 let path = uri.path();
64
65 for compiled_route in &self.routes {
66 if &compiled_route.method != method {
68 continue;
69 }
70
71 if compiled_route.path_pattern.is_match(path) {
73 return Some(&compiled_route.config);
74 }
75 }
76
77 None
78 }
79
80 fn compile_path_pattern(pattern: &str) -> Result<Regex> {
83 let mut regex_pattern = String::new();
85 let mut chars = pattern.chars().peekable();
86
87 while let Some(ch) = chars.next() {
88 match ch {
89 '{' => {
90 let mut param_name = String::new();
92 while let Some(&next_ch) = chars.peek() {
93 if next_ch == '}' {
94 chars.next(); regex_pattern.push_str("([^/]+)");
97 break;
98 } else {
99 param_name.push(chars.next().unwrap());
100 }
101 }
102 }
103 '*' => {
104 regex_pattern.push_str(".*");
106 }
107 ch if ".+?^$|\\[]()".contains(ch) => {
108 regex_pattern.push('\\');
110 regex_pattern.push(ch);
111 }
112 ch => {
113 regex_pattern.push(ch);
114 }
115 }
116 }
117
118 let full_pattern = format!("^{}$", regex_pattern);
120 Regex::new(&full_pattern)
121 .map_err(|e| Error::generic(format!("Invalid route pattern '{}': {}", pattern, e)))
122 }
123}
124
125#[derive(Debug, Clone)]
127pub struct RouteChaosInjector {
128 matcher: RouteMatcher,
130}
131
132impl RouteChaosInjector {
133 pub fn new(routes: Vec<RouteConfig>) -> Result<Self> {
135 let matcher = RouteMatcher::new(routes)?;
136 Ok(Self { matcher })
137 }
138
139 pub fn should_inject_fault(
141 &self,
142 method: &Method,
143 uri: &Uri,
144 ) -> Option<RouteFaultInjectionResult> {
145 let route = self.matcher.match_route(method, uri)?;
146 let fault_config = route.fault_injection.as_ref()?;
147
148 if !fault_config.enabled {
149 return None;
150 }
151
152 let mut rng = rng();
154 if rng.random::<f64>() > fault_config.probability {
155 return None;
156 }
157
158 if fault_config.fault_types.is_empty() {
160 return None;
161 }
162
163 let fault_type =
164 &fault_config.fault_types[rng.random_range(0..fault_config.fault_types.len())];
165
166 Some(RouteFaultInjectionResult {
167 fault_type: fault_type.clone(),
168 })
169 }
170
171 pub async fn inject_latency(&self, method: &Method, uri: &Uri) -> Result<()> {
173 let route = match self.matcher.match_route(method, uri) {
174 Some(r) => r,
175 None => return Ok(()), };
177
178 let latency_config = match &route.latency {
179 Some(cfg) => cfg,
180 None => return Ok(()), };
182
183 if !latency_config.enabled {
184 return Ok(());
185 }
186
187 let mut rng = rng();
189 if rng.random::<f64>() > latency_config.probability {
190 return Ok(());
191 }
192
193 let delay_ms = self.calculate_delay(latency_config)?;
194 if delay_ms > 0 {
195 debug!("Injecting per-route latency: {}ms for {} {}", delay_ms, method, uri.path());
196 sleep(Duration::from_millis(delay_ms)).await;
197 }
198
199 Ok(())
200 }
201
202 fn calculate_delay(&self, config: &RouteLatencyConfig) -> Result<u64> {
204 let mut rng = rng();
205
206 let base_delay = match &config.distribution {
207 LatencyDistribution::Fixed => config.fixed_delay_ms.unwrap_or(0),
208 LatencyDistribution::Normal {
209 mean_ms,
210 std_dev_ms,
211 } => {
212 let u1: f64 = rng.random();
214 let u2: f64 = rng.random();
215 let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
216 let value = mean_ms + std_dev_ms * z0;
217 value.max(0.0) as u64
218 }
219 LatencyDistribution::Exponential { lambda } => {
220 let u: f64 = rng.random();
222 let value = -lambda.ln() * (1.0 - u);
223 value.max(0.0) as u64
224 }
225 LatencyDistribution::Uniform => {
226 if let Some((min, max)) = config.random_delay_range_ms {
227 rng.random_range(min..=max)
228 } else {
229 config.fixed_delay_ms.unwrap_or(0)
230 }
231 }
232 };
233
234 let delay = if config.jitter_percent > 0.0 {
236 let jitter = (base_delay as f64 * config.jitter_percent / 100.0) as u64;
237 let jitter_offset = rng.random_range(0..=jitter);
238 if rng.random_bool(0.5) {
239 base_delay + jitter_offset
240 } else {
241 base_delay.saturating_sub(jitter_offset)
242 }
243 } else {
244 base_delay
245 };
246
247 Ok(delay)
248 }
249
250 pub fn get_fault_response(&self, method: &Method, uri: &Uri) -> Option<RouteFaultResponse> {
252 let fault_result = self.should_inject_fault(method, uri)?;
253
254 match &fault_result.fault_type {
255 RouteFaultType::HttpError {
256 status_code,
257 message,
258 } => Some(RouteFaultResponse {
259 status_code: *status_code,
260 error_message: message
261 .clone()
262 .unwrap_or_else(|| format!("Injected HTTP error {}", status_code)),
263 fault_type: "http_error".to_string(),
264 }),
265 RouteFaultType::ConnectionError { message } => Some(RouteFaultResponse {
266 status_code: 503,
267 error_message: message.clone().unwrap_or_else(|| "Connection error".to_string()),
268 fault_type: "connection_error".to_string(),
269 }),
270 RouteFaultType::Timeout {
271 duration_ms,
272 message,
273 } => Some(RouteFaultResponse {
274 status_code: 504,
275 error_message: message
276 .clone()
277 .unwrap_or_else(|| format!("Request timeout after {}ms", duration_ms)),
278 fault_type: "timeout".to_string(),
279 }),
280 RouteFaultType::PartialResponse { truncate_percent } => Some(RouteFaultResponse {
281 status_code: 200,
282 error_message: format!("Partial response (truncated at {}%)", truncate_percent),
283 fault_type: "partial_response".to_string(),
284 }),
285 RouteFaultType::PayloadCorruption { corruption_type } => Some(RouteFaultResponse {
286 status_code: 200,
287 error_message: format!("Payload corruption ({})", corruption_type),
288 fault_type: "payload_corruption".to_string(),
289 }),
290 }
291 }
292}
293
294#[derive(Debug, Clone)]
296pub struct RouteFaultInjectionResult {
297 pub fault_type: RouteFaultType,
299}
300
/// Description of an injected fault, as produced by the route chaos injector.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RouteFaultResponse {
    /// HTTP status code associated with the injected fault.
    pub status_code: u16,
    /// Human-readable description of the fault: the configured message when one
    /// is set, otherwise a generated default.
    pub error_message: String,
    /// Short label for the fault kind, e.g. `http_error`, `connection_error`,
    /// `timeout`, `partial_response` or `payload_corruption`.
    pub fault_type: String,
}
311
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{RouteConfig, RouteResponseConfig};

    /// Builds a bare route for `method path`. It returns an empty 200 and has
    /// neither fault injection nor latency configured.
    fn create_test_route(path: &str, method: &str) -> RouteConfig {
        let response = RouteResponseConfig {
            status: 200,
            headers: HashMap::new(),
            body: None,
        };
        RouteConfig {
            path: path.to_string(),
            method: method.to_string(),
            request: None,
            response,
            fault_injection: None,
            latency: None,
        }
    }

    #[test]
    fn test_path_pattern_compilation() {
        let re = RouteMatcher::compile_path_pattern("/users/{id}").unwrap();
        let cases = [
            ("/users/123", true),
            ("/users/abc", true),
            ("/users/123/posts", false),
            ("/users", false),
        ];
        for (path, expected) in cases {
            assert_eq!(re.is_match(path), expected, "unexpected match result for {path}");
        }
    }

    #[test]
    fn test_route_matching() {
        let matcher = RouteMatcher::new(vec![
            create_test_route("/users/{id}", "GET"),
            create_test_route("/orders/{order_id}", "POST"),
            create_test_route("/health", "GET"),
        ])
        .unwrap();

        let matches = |method: Method, path: &'static str| {
            matcher.match_route(&method, &Uri::from_static(path)).is_some()
        };

        assert!(matches(Method::GET, "/users/123"));
        assert!(matches(Method::POST, "/orders/456"));
        assert!(matches(Method::GET, "/health"));
        assert!(!matches(Method::GET, "/unknown"));
    }

    #[tokio::test]
    async fn test_latency_injection() {
        let mut route = create_test_route("/test", "GET");
        route.latency = Some(RouteLatencyConfig {
            enabled: true,
            probability: 1.0,
            fixed_delay_ms: Some(10),
            random_delay_range_ms: None,
            jitter_percent: 0.0,
            distribution: LatencyDistribution::Fixed,
        });
        let injector = RouteChaosInjector::new(vec![route]).unwrap();

        let started = std::time::Instant::now();
        injector.inject_latency(&Method::GET, &Uri::from_static("/test")).await.unwrap();

        assert!(started.elapsed() >= Duration::from_millis(10));
    }

    #[test]
    fn test_fault_injection() {
        let mut route = create_test_route("/test", "GET");
        route.fault_injection = Some(RouteFaultInjectionConfig {
            enabled: true,
            probability: 1.0,
            fault_types: vec![RouteFaultType::HttpError {
                status_code: 500,
                message: Some("Test error".to_string()),
            }],
        });

        let response = RouteChaosInjector::new(vec![route])
            .unwrap()
            .get_fault_response(&Method::GET, &Uri::from_static("/test"))
            .unwrap();

        assert_eq!(response.status_code, 500);
        assert_eq!(response.error_message, "Test error");
    }
}