1use axum::body::Body;
3use axum::extract::State;
4use axum::http::{Request, StatusCode};
5use axum::middleware::Next;
6use axum::response::Response;
7use serde_json::Value;
8
9use crate::latency_profiles::LatencyProfiles;
10use mockforge_core::{FailureInjector, Overrides, TrafficShaper};
11
/// Metadata describing the API operation a request has been matched to.
///
/// `fault_then_next` looks this up in the request extensions and
/// `apply_overrides` receives it by reference; both use it to select the
/// fault / override rules that apply to the operation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OperationMeta {
    /// Operation identifier (for example `getUserById`).
    pub id: String,
    /// Tags attached to the operation; handed to the failure injector,
    /// latency profiles, traffic shaper and overrides as a rule selector.
    pub tags: Vec<String>,
    /// Path of the operation (for example `/users/{id}`); forwarded to
    /// `Overrides::apply`.
    pub path: String,
}
22
23#[derive(Clone)]
25pub struct Shared {
26 pub profiles: LatencyProfiles,
28 pub overrides: Overrides,
30 pub failure_injector: Option<FailureInjector>,
32 pub traffic_shaper: Option<TrafficShaper>,
34 pub overrides_enabled: bool,
36 pub traffic_shaping_enabled: bool,
38}
39
40pub async fn add_shared_extension(
42 State(shared): State<Shared>,
43 mut req: Request<Body>,
44 next: Next,
45) -> Response {
46 req.extensions_mut().insert(shared);
47 next.run(req).await
48}
49
50pub async fn fault_then_next(req: Request<Body>, next: Next) -> Response {
52 let shared = match req.extensions().get::<Shared>() {
53 Some(s) => s.clone(),
54 None => {
55 tracing::error!("Shared extension not found in request - ensure add_shared_extension middleware is configured");
56 let mut res =
57 Response::new(Body::from("Internal server error: middleware misconfiguration"));
58 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
59 return res;
60 }
61 };
62 let op = req.extensions().get::<OperationMeta>().cloned();
63
64 if let Some(failure_injector) = &shared.failure_injector {
66 let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
67 if let Some((status_code, error_message)) = failure_injector.process_request(tags) {
68 let mut res = Response::new(Body::from(error_message));
69 *res.status_mut() =
70 StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
71 return res;
72 }
73 }
74
75 if let Some(op) = &op {
77 if let Some((code, msg)) = shared
78 .profiles
79 .maybe_fault(&op.id, &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>())
80 .await
81 {
82 let mut res = Response::new(Body::from(msg));
83 *res.status_mut() =
84 StatusCode::from_u16(code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
85 return res;
86 }
87 }
88
89 if shared.traffic_shaping_enabled {
91 if let Some(traffic_shaper) = &shared.traffic_shaper {
92 let request_size = calculate_request_size(&req);
94
95 let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
96
97 match traffic_shaper.process_transfer(request_size, tags).await {
99 Ok(Some(_timeout)) => {
100 let mut res =
102 Response::new(Body::from("Request timeout due to traffic shaping"));
103 *res.status_mut() = StatusCode::REQUEST_TIMEOUT;
104 return res;
105 }
106 Ok(None) => {
107 }
109 Err(e) => {
110 let mut res =
112 Response::new(Body::from(format!("Traffic shaping error: {}", e)));
113 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
114 return res;
115 }
116 }
117 }
118 }
119
120 let (parts, body) = req.into_parts();
121 let req = Request::from_parts(parts, body);
122
123 let response = next.run(req).await;
124
125 if shared.traffic_shaping_enabled {
127 if let Some(traffic_shaper) = &shared.traffic_shaper {
128 let response_size = calculate_response_size(&response);
130
131 let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
132
133 match traffic_shaper.process_transfer(response_size, tags).await {
135 Ok(Some(_timeout)) => {
136 let mut res =
138 Response::new(Body::from("Response timeout due to traffic shaping"));
139 *res.status_mut() = StatusCode::GATEWAY_TIMEOUT;
140 return res;
141 }
142 Ok(None) => {
143 }
145 Err(e) => {
146 let mut res =
148 Response::new(Body::from(format!("Traffic shaping error: {}", e)));
149 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
150 return res;
151 }
152 }
153 }
154 }
155
156 response
157}
158
159pub fn apply_overrides(shared: &Shared, op: Option<&OperationMeta>, body: &mut Value) {
166 if shared.overrides_enabled {
167 if let Some(op) = op {
168 shared.overrides.apply(
169 &op.id,
170 &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>(),
171 &op.path,
172 body,
173 );
174 }
175 }
176}
177
178fn calculate_request_size<B>(req: &Request<B>) -> u64 {
180 let mut size = 0u64;
181
182 for (name, value) in req.headers() {
184 size += name.as_str().len() as u64;
185 size += value.as_bytes().len() as u64;
186 }
187
188 size += req.uri().to_string().len() as u64;
190
191 if let Some(content_length) = req.headers().get(http::header::CONTENT_LENGTH) {
193 if let Ok(len_str) = content_length.to_str() {
194 if let Ok(len) = len_str.parse::<u64>() {
195 size += len;
196 return size;
197 }
198 }
199 }
200
201 let method = req.method();
203 if method == http::Method::POST || method == http::Method::PUT || method == http::Method::PATCH
204 {
205 size += 256; }
207
208 size
209}
210
211fn calculate_response_size(res: &Response) -> u64 {
213 let mut size = 0u64;
214
215 for (name, value) in res.headers() {
217 size += name.as_str().len() as u64;
218 size += value.as_bytes().len() as u64;
219 }
220
221 size += 15; if let Some(content_length) = res.headers().get(http::header::CONTENT_LENGTH) {
226 if let Ok(len_str) = content_length.to_str() {
227 if let Ok(len) = len_str.parse::<u64>() {
228 size += len;
229 return size;
230 }
231 }
232 }
233
234 match res.status().as_u16() {
236 204 | 304 => {} _ => size += 256, }
239
240 size
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{Request, Response, StatusCode};
    use serde_json::json;

    /// Builds a `Shared` with default profiles/overrides, no injector and no
    /// shaper, and the two feature flags set as requested.
    fn shared_with(overrides_enabled: bool, traffic_shaping_enabled: bool) -> Shared {
        Shared {
            profiles: LatencyProfiles::default(),
            overrides: Overrides::default(),
            failure_injector: None,
            traffic_shaper: None,
            overrides_enabled,
            traffic_shaping_enabled,
        }
    }

    /// Operation metadata without tags, used by the override tests.
    fn sample_op() -> OperationMeta {
        OperationMeta {
            id: "getUser".to_string(),
            tags: vec![],
            path: "/users".to_string(),
        }
    }

    #[test]
    fn test_operation_meta_creation() {
        let meta = OperationMeta {
            id: "getUserById".to_string(),
            tags: vec!["users".to_string(), "public".to_string()],
            path: "/users/{id}".to_string(),
        };

        assert_eq!(meta.id, "getUserById");
        assert_eq!(meta.tags.len(), 2);
        assert_eq!(meta.path, "/users/{id}");
    }

    #[test]
    fn test_shared_creation() {
        let shared = shared_with(false, false);

        assert!(!shared.overrides_enabled);
        assert!(!shared.traffic_shaping_enabled);
        assert!(shared.failure_injector.is_none());
        assert!(shared.traffic_shaper.is_none());
    }

    #[test]
    fn test_shared_with_failure_injector() {
        let failure_injector = FailureInjector::new(None, true);
        let shared = Shared {
            failure_injector: Some(failure_injector),
            ..shared_with(false, false)
        };

        assert!(shared.failure_injector.is_some());
    }

    #[test]
    fn test_apply_overrides_disabled() {
        let shared = shared_with(false, false);
        let op = sample_op();

        let mut body = json!({"name": "John"});
        let original = body.clone();

        apply_overrides(&shared, Some(&op), &mut body);

        // Overrides are switched off, so the body must be untouched.
        assert_eq!(body, original);
    }

    #[test]
    fn test_apply_overrides_enabled_no_rules() {
        let shared = shared_with(true, false);
        let op = sample_op();

        let mut body = json!({"name": "John"});
        let original = body.clone();

        apply_overrides(&shared, Some(&op), &mut body);

        // Default overrides contain no rules, so nothing changes.
        assert_eq!(body, original);
    }

    #[test]
    fn test_apply_overrides_with_none_operation() {
        let shared = shared_with(true, false);

        let mut body = json!({"name": "John"});
        let original = body.clone();

        apply_overrides(&shared, None, &mut body);

        // Without operation metadata there is nothing to match rules against.
        assert_eq!(body, original);
    }

    #[test]
    fn test_calculate_request_size_basic() {
        let req = Request::builder()
            .uri("/test")
            .header("content-type", "application/json")
            .body(())
            .unwrap();

        let size = calculate_request_size(&req);

        assert!(size > 0);
        assert!(size >= "/test".len() as u64 + "content-type".len() as u64);
    }

    #[test]
    fn test_calculate_request_size_with_multiple_headers() {
        let req = Request::builder()
            .uri("/api/users")
            .header("content-type", "application/json")
            .header("authorization", "Bearer token123")
            .header("user-agent", "test-client")
            .body(())
            .unwrap();

        let size = calculate_request_size(&req);

        // A GET without Content-Length counts only the URI and the headers
        // (87 bytes here); the previous `size > 100` bound could not hold.
        let expected = "/api/users".len()
            + "content-type".len()
            + "application/json".len()
            + "authorization".len()
            + "Bearer token123".len()
            + "user-agent".len()
            + "test-client".len();
        assert_eq!(size, expected as u64);
    }

    #[test]
    fn test_calculate_request_size_uses_content_length() {
        let req = Request::builder()
            .method("POST")
            .uri("/upload")
            .header("content-length", "42")
            .body(())
            .unwrap();

        let size = calculate_request_size(&req);

        // A declared Content-Length is used as the body size and no
        // additional body estimate is added.
        let expected = "/upload".len() + "content-length".len() + "42".len() + 42;
        assert_eq!(size, expected as u64);
    }

    #[test]
    fn test_calculate_response_size_basic() {
        let res = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .body(axum::body::Body::empty())
            .unwrap();

        let size = calculate_response_size(&res);

        assert!(size > 0);
        assert!(size >= 50);
    }

    #[test]
    fn test_calculate_response_size_with_multiple_headers() {
        let res = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .header("cache-control", "no-cache")
            .header("x-request-id", "123-456-789")
            .body(axum::body::Body::empty())
            .unwrap();

        let size = calculate_response_size(&res);

        assert!(size > 100);
    }

    #[test]
    fn test_shared_clone() {
        let shared = shared_with(true, true);

        let cloned = shared.clone();

        assert_eq!(shared.overrides_enabled, cloned.overrides_enabled);
        assert_eq!(shared.traffic_shaping_enabled, cloned.traffic_shaping_enabled);
    }

    #[test]
    fn test_operation_meta_clone() {
        let meta = OperationMeta {
            id: "testOp".to_string(),
            tags: vec!["tag1".to_string()],
            path: "/test".to_string(),
        };

        let cloned = meta.clone();

        assert_eq!(meta.id, cloned.id);
        assert_eq!(meta.tags, cloned.tags);
        assert_eq!(meta.path, cloned.path);
    }
}