1use axum::body::Body;
3use axum::extract::State;
4use axum::http::{Request, StatusCode};
5use axum::middleware::Next;
6use axum::response::Response;
7use serde_json::Value;
8
9use crate::latency_profiles::LatencyProfiles;
10use mockforge_chaos::core_failure_injection::FailureInjector;
11use mockforge_chaos::core_traffic_shaping::TrafficShaper;
12use mockforge_core::Overrides;
13
/// Metadata about the operation a request was routed to.
///
/// Read from the request extensions by [`fault_then_next`] and passed to
/// [`apply_overrides`] to select which fault, shaping, and override rules
/// apply.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct OperationMeta {
    /// Operation identifier (used for latency-profile faults and overrides).
    pub id: String,
    /// Tags used to select failure-injection, traffic-shaping, and override rules.
    pub tags: Vec<String>,
    /// Route path passed to the override rules (e.g. `/users/{id}`).
    pub path: String,
}
24
25#[derive(Clone)]
27pub struct Shared {
28 pub profiles: LatencyProfiles,
30 pub overrides: Overrides,
32 pub failure_injector: Option<FailureInjector>,
34 pub traffic_shaper: Option<TrafficShaper>,
36 pub overrides_enabled: bool,
38 pub traffic_shaping_enabled: bool,
40}
41
42pub async fn add_shared_extension(
44 State(shared): State<Shared>,
45 mut req: Request<Body>,
46 next: Next,
47) -> Response {
48 req.extensions_mut().insert(shared);
49 next.run(req).await
50}
51
52pub async fn fault_then_next(req: Request<Body>, next: Next) -> Response {
54 let shared = match req.extensions().get::<Shared>() {
55 Some(s) => s.clone(),
56 None => {
57 tracing::error!("Shared extension not found in request - ensure add_shared_extension middleware is configured");
58 let mut res =
59 Response::new(Body::from("Internal server error: middleware misconfiguration"));
60 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
61 return res;
62 }
63 };
64 let op = req.extensions().get::<OperationMeta>().cloned();
65
66 if let Some(failure_injector) = &shared.failure_injector {
68 let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
69 if let Some((status_code, error_message)) = failure_injector.process_request(tags) {
70 let mut res = Response::new(Body::from(error_message));
71 *res.status_mut() =
72 StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
73 return res;
74 }
75 }
76
77 if let Some(op) = &op {
79 if let Some((code, msg)) = shared
80 .profiles
81 .maybe_fault(&op.id, &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>())
82 .await
83 {
84 let mut res = Response::new(Body::from(msg));
85 *res.status_mut() =
86 StatusCode::from_u16(code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
87 return res;
88 }
89 }
90
91 if shared.traffic_shaping_enabled {
93 if let Some(traffic_shaper) = &shared.traffic_shaper {
94 let request_size = calculate_request_size(&req);
96
97 let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
98
99 match traffic_shaper.process_transfer(request_size, tags).await {
101 Ok(Some(_timeout)) => {
102 let mut res =
104 Response::new(Body::from("Request timeout due to traffic shaping"));
105 *res.status_mut() = StatusCode::REQUEST_TIMEOUT;
106 return res;
107 }
108 Ok(None) => {
109 }
111 Err(e) => {
112 let mut res =
114 Response::new(Body::from(format!("Traffic shaping error: {}", e)));
115 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
116 return res;
117 }
118 }
119 }
120 }
121
122 let (parts, body) = req.into_parts();
123 let req = Request::from_parts(parts, body);
124
125 let response = next.run(req).await;
126
127 if shared.traffic_shaping_enabled {
129 if let Some(traffic_shaper) = &shared.traffic_shaper {
130 let response_size = calculate_response_size(&response);
132
133 let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
134
135 match traffic_shaper.process_transfer(response_size, tags).await {
137 Ok(Some(_timeout)) => {
138 let mut res =
140 Response::new(Body::from("Response timeout due to traffic shaping"));
141 *res.status_mut() = StatusCode::GATEWAY_TIMEOUT;
142 return res;
143 }
144 Ok(None) => {
145 }
147 Err(e) => {
148 let mut res =
150 Response::new(Body::from(format!("Traffic shaping error: {}", e)));
151 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
152 return res;
153 }
154 }
155 }
156 }
157
158 response
159}
160
161pub fn apply_overrides(shared: &Shared, op: Option<&OperationMeta>, body: &mut Value) {
168 if shared.overrides_enabled {
169 if let Some(op) = op {
170 shared.overrides.apply(
171 &op.id,
172 &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>(),
173 &op.path,
174 body,
175 );
176 }
177 }
178}
179
180fn calculate_request_size<B>(req: &Request<B>) -> u64 {
182 let mut size = 0u64;
183
184 for (name, value) in req.headers() {
186 size += name.as_str().len() as u64;
187 size += value.as_bytes().len() as u64;
188 }
189
190 size += req.uri().to_string().len() as u64;
192
193 if let Some(content_length) = req.headers().get(http::header::CONTENT_LENGTH) {
195 if let Ok(len_str) = content_length.to_str() {
196 if let Ok(len) = len_str.parse::<u64>() {
197 size += len;
198 return size;
199 }
200 }
201 }
202
203 let method = req.method();
205 if method == http::Method::POST || method == http::Method::PUT || method == http::Method::PATCH
206 {
207 size += 256; }
209
210 size
211}
212
213fn calculate_response_size(res: &Response) -> u64 {
215 let mut size = 0u64;
216
217 for (name, value) in res.headers() {
219 size += name.as_str().len() as u64;
220 size += value.as_bytes().len() as u64;
221 }
222
223 size += 15; if let Some(content_length) = res.headers().get(http::header::CONTENT_LENGTH) {
228 if let Ok(len_str) = content_length.to_str() {
229 if let Ok(len) = len_str.parse::<u64>() {
230 size += len;
231 return size;
232 }
233 }
234 }
235
236 match res.status().as_u16() {
238 204 | 304 => {} _ => size += 256, }
241
242 size
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{Request, Response, StatusCode};
    use serde_json::json;

    /// Builds a [`Shared`] with default profiles and overrides, no failure
    /// injector, no traffic shaper, and the given feature switches.
    fn shared_with_flags(overrides_enabled: bool, traffic_shaping_enabled: bool) -> Shared {
        Shared {
            profiles: LatencyProfiles::default(),
            overrides: Overrides::default(),
            failure_injector: None,
            traffic_shaper: None,
            overrides_enabled,
            traffic_shaping_enabled,
        }
    }

    /// Builds an [`OperationMeta`] from string slices.
    fn operation(id: &str, tags: &[&str], path: &str) -> OperationMeta {
        OperationMeta {
            id: id.to_owned(),
            tags: tags.iter().map(|tag| (*tag).to_owned()).collect(),
            path: path.to_owned(),
        }
    }

    /// Runs `apply_overrides` on a small JSON body and asserts it is unchanged.
    fn assert_overrides_leave_body_unchanged(shared: &Shared, op: Option<&OperationMeta>) {
        let mut body = json!({"name": "John"});
        let before = body.clone();

        apply_overrides(shared, op, &mut body);

        assert_eq!(body, before);
    }

    #[test]
    fn test_operation_meta_creation() {
        let meta = operation("getUserById", &["users", "public"], "/users/{id}");

        assert_eq!(meta.id, "getUserById");
        assert_eq!(meta.tags.len(), 2);
        assert_eq!(meta.path, "/users/{id}");
    }

    #[test]
    fn test_shared_creation() {
        let shared = shared_with_flags(false, false);

        assert!(!shared.overrides_enabled);
        assert!(!shared.traffic_shaping_enabled);
        assert!(shared.failure_injector.is_none());
        assert!(shared.traffic_shaper.is_none());
    }

    #[test]
    fn test_shared_with_failure_injector() {
        let shared = Shared {
            failure_injector: Some(FailureInjector::new(None, true)),
            ..shared_with_flags(false, false)
        };

        assert!(shared.failure_injector.is_some());
    }

    #[test]
    fn test_apply_overrides_disabled() {
        let shared = shared_with_flags(false, false);
        let op = operation("getUser", &[], "/users");

        assert_overrides_leave_body_unchanged(&shared, Some(&op));
    }

    #[test]
    fn test_apply_overrides_enabled_no_rules() {
        let shared = shared_with_flags(true, false);
        let op = operation("getUser", &[], "/users");

        assert_overrides_leave_body_unchanged(&shared, Some(&op));
    }

    #[test]
    fn test_apply_overrides_with_none_operation() {
        let shared = shared_with_flags(true, false);

        assert_overrides_leave_body_unchanged(&shared, None);
    }

    #[test]
    fn test_calculate_request_size_basic() {
        let req = Request::builder()
            .uri("/test")
            .header("content-type", "application/json")
            .body(())
            .unwrap();

        let size = calculate_request_size(&req);
        let lower_bound = ("/test".len() + "content-type".len()) as u64;

        assert!(size > 0);
        assert!(size >= lower_bound);
    }

    #[test]
    fn test_calculate_request_size_with_multiple_headers() {
        let req = Request::builder()
            .uri("/api/users")
            .header("content-type", "application/json")
            .header("authorization", "Bearer token123")
            .header("user-agent", "test-client")
            .body(())
            .unwrap();

        assert!(calculate_request_size(&req) > 50);
    }

    #[test]
    fn test_calculate_response_size_basic() {
        let res = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .body(axum::body::Body::empty())
            .unwrap();

        let size = calculate_response_size(&res);

        assert!(size > 0);
        assert!(size >= 50);
    }

    #[test]
    fn test_calculate_response_size_with_multiple_headers() {
        let res = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .header("cache-control", "no-cache")
            .header("x-request-id", "123-456-789")
            .body(axum::body::Body::empty())
            .unwrap();

        assert!(calculate_response_size(&res) > 100);
    }

    #[test]
    fn test_shared_clone() {
        let shared = shared_with_flags(true, true);
        let cloned = shared.clone();

        assert_eq!(shared.overrides_enabled, cloned.overrides_enabled);
        assert_eq!(shared.traffic_shaping_enabled, cloned.traffic_shaping_enabled);
    }

    #[test]
    fn test_operation_meta_clone() {
        let meta = operation("testOp", &["tag1"], "/test");
        let cloned = meta.clone();

        assert_eq!(meta.id, cloned.id);
        assert_eq!(meta.tags, cloned.tags);
        assert_eq!(meta.path, cloned.path);
    }
}