1use axum::body::Body;
3use axum::extract::State;
4use axum::http::{Request, StatusCode};
5use axum::middleware::Next;
6use axum::response::Response;
7use serde_json::Value;
8
9use crate::latency_profiles::LatencyProfiles;
10use mockforge_core::{FailureInjector, Overrides, TrafficShaper};
11
/// Metadata describing the OpenAPI operation a request was routed to.
///
/// Stored as a request extension and read by [`fault_then_next`] (fault
/// injection / traffic shaping keyed by `id` and `tags`) and by
/// [`apply_overrides`] (which also uses `path`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OperationMeta {
    /// Operation identifier (e.g. `getUserById`).
    pub id: String,
    /// Tags attached to the operation; used to select fault/shaping rules.
    pub tags: Vec<String>,
    /// Route path template (e.g. `/users/{id}`).
    pub path: String,
}
22
23#[derive(Clone)]
25pub struct Shared {
26 pub profiles: LatencyProfiles,
28 pub overrides: Overrides,
30 pub failure_injector: Option<FailureInjector>,
32 pub traffic_shaper: Option<TrafficShaper>,
34 pub overrides_enabled: bool,
36 pub traffic_shaping_enabled: bool,
38}
39
40pub async fn add_shared_extension(
42 State(shared): State<Shared>,
43 mut req: Request<Body>,
44 next: Next,
45) -> Response {
46 req.extensions_mut().insert(shared);
47 next.run(req).await
48}
49
50pub async fn fault_then_next(req: Request<Body>, next: Next) -> Response {
52 let shared = match req.extensions().get::<Shared>() {
53 Some(s) => s.clone(),
54 None => {
55 tracing::error!("Shared extension not found in request - ensure add_shared_extension middleware is configured");
56 let mut res =
57 Response::new(Body::from("Internal server error: middleware misconfiguration"));
58 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
59 return res;
60 }
61 };
62 let op = req.extensions().get::<OperationMeta>().cloned();
63
64 if let Some(failure_injector) = &shared.failure_injector {
66 let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
67 if let Some((status_code, error_message)) = failure_injector.process_request(tags) {
68 let mut res = Response::new(Body::from(error_message));
69 *res.status_mut() =
70 StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
71 return res;
72 }
73 }
74
75 if let Some(op) = &op {
77 if let Some((code, msg)) = shared
78 .profiles
79 .maybe_fault(&op.id, &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>())
80 .await
81 {
82 let mut res = Response::new(Body::from(msg));
83 *res.status_mut() =
84 StatusCode::from_u16(code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
85 return res;
86 }
87 }
88
89 if shared.traffic_shaping_enabled {
91 if let Some(traffic_shaper) = &shared.traffic_shaper {
92 let request_size = calculate_request_size(&req);
94
95 let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
96
97 match traffic_shaper.process_transfer(request_size, tags).await {
99 Ok(Some(_timeout)) => {
100 let mut res =
102 Response::new(Body::from("Request timeout due to traffic shaping"));
103 *res.status_mut() = StatusCode::REQUEST_TIMEOUT;
104 return res;
105 }
106 Ok(None) => {
107 }
109 Err(e) => {
110 let mut res =
112 Response::new(Body::from(format!("Traffic shaping error: {}", e)));
113 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
114 return res;
115 }
116 }
117 }
118 }
119
120 let (parts, body) = req.into_parts();
121 let req = Request::from_parts(parts, body);
122
123 let response = next.run(req).await;
124
125 if shared.traffic_shaping_enabled {
127 if let Some(traffic_shaper) = &shared.traffic_shaper {
128 let response_size = calculate_response_size(&response);
130
131 let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
132
133 match traffic_shaper.process_transfer(response_size, tags).await {
135 Ok(Some(_timeout)) => {
136 let mut res =
138 Response::new(Body::from("Response timeout due to traffic shaping"));
139 *res.status_mut() = StatusCode::GATEWAY_TIMEOUT;
140 return res;
141 }
142 Ok(None) => {
143 }
145 Err(e) => {
146 let mut res =
148 Response::new(Body::from(format!("Traffic shaping error: {}", e)));
149 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
150 return res;
151 }
152 }
153 }
154 }
155
156 response
157}
158
159pub fn apply_overrides(shared: &Shared, op: Option<&OperationMeta>, body: &mut Value) {
166 if shared.overrides_enabled {
167 if let Some(op) = op {
168 shared.overrides.apply(
169 &op.id,
170 &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>(),
171 &op.path,
172 body,
173 );
174 }
175 }
176}
177
178fn calculate_request_size<B>(req: &Request<B>) -> u64 {
180 let mut size = 0u64;
181
182 for (name, value) in req.headers() {
184 size += name.as_str().len() as u64;
185 size += value.as_bytes().len() as u64;
186 }
187
188 size += req.uri().to_string().len() as u64;
190
191 size += 1024; size
198}
199
200fn calculate_response_size(res: &Response) -> u64 {
202 let mut size = 0u64;
203
204 for (name, value) in res.headers() {
206 size += name.as_str().len() as u64;
207 size += value.as_bytes().len() as u64;
208 }
209
210 size += 50;
212
213 size += 2048; size
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{Request, Response, StatusCode};
    use serde_json::json;

    /// `Shared` with default profiles/overrides and no injector or shaper;
    /// only the two feature flags vary.
    fn shared_with_flags(overrides_enabled: bool, traffic_shaping_enabled: bool) -> Shared {
        Shared {
            profiles: LatencyProfiles::default(),
            overrides: Overrides::default(),
            failure_injector: None,
            traffic_shaper: None,
            overrides_enabled,
            traffic_shaping_enabled,
        }
    }

    /// Untagged operation metadata for override tests.
    fn untagged_op(id: &str, path: &str) -> OperationMeta {
        OperationMeta {
            id: id.to_string(),
            tags: vec![],
            path: path.to_string(),
        }
    }

    /// Sum of header name+value byte lengths, mirroring the size estimators.
    fn header_bytes(pairs: &[(&str, &str)]) -> u64 {
        pairs.iter().map(|(n, v)| (n.len() + v.len()) as u64).sum()
    }

    #[test]
    fn test_operation_meta_creation() {
        let meta = OperationMeta {
            id: "getUserById".to_string(),
            tags: vec!["users".to_string(), "public".to_string()],
            path: "/users/{id}".to_string(),
        };

        assert_eq!(meta.id, "getUserById");
        assert_eq!(meta.tags.len(), 2);
        assert_eq!(meta.path, "/users/{id}");
    }

    #[test]
    fn test_shared_creation() {
        let shared = shared_with_flags(false, false);

        assert!(!shared.overrides_enabled);
        assert!(!shared.traffic_shaping_enabled);
        assert!(shared.failure_injector.is_none());
        assert!(shared.traffic_shaper.is_none());
    }

    #[test]
    fn test_shared_with_failure_injector() {
        let shared = Shared {
            failure_injector: Some(FailureInjector::new(None, true)),
            ..shared_with_flags(false, false)
        };

        assert!(shared.failure_injector.is_some());
    }

    #[test]
    fn test_apply_overrides_disabled() {
        let shared = shared_with_flags(false, false);
        let op = untagged_op("getUser", "/users");

        let mut body = json!({"name": "John"});
        let original = body.clone();

        apply_overrides(&shared, Some(&op), &mut body);

        assert_eq!(body, original);
    }

    #[test]
    fn test_apply_overrides_enabled_no_rules() {
        let shared = shared_with_flags(true, false);
        let op = untagged_op("getUser", "/users");

        let mut body = json!({"name": "John"});
        let original = body.clone();

        apply_overrides(&shared, Some(&op), &mut body);

        assert_eq!(body, original);
    }

    #[test]
    fn test_apply_overrides_with_none_operation() {
        let shared = shared_with_flags(true, false);

        let mut body = json!({"name": "John"});
        let original = body.clone();

        apply_overrides(&shared, None, &mut body);

        assert_eq!(body, original);
    }

    #[test]
    fn test_calculate_request_size_basic() {
        let req = Request::builder()
            .uri("/test")
            .header("content-type", "application/json")
            .body(())
            .unwrap();

        let size = calculate_request_size(&req);

        // headers + URI + fixed 1024-byte body allowance
        let expected =
            header_bytes(&[("content-type", "application/json")]) + "/test".len() as u64 + 1024;
        assert_eq!(size, expected);
    }

    #[test]
    fn test_calculate_request_size_with_multiple_headers() {
        let req = Request::builder()
            .uri("/api/users")
            .header("content-type", "application/json")
            .header("authorization", "Bearer token123")
            .header("user-agent", "test-client")
            .body(())
            .unwrap();

        let size = calculate_request_size(&req);

        let expected = header_bytes(&[
            ("content-type", "application/json"),
            ("authorization", "Bearer token123"),
            ("user-agent", "test-client"),
        ]) + "/api/users".len() as u64
            + 1024;
        assert_eq!(size, expected);
    }

    #[test]
    fn test_calculate_response_size_basic() {
        let res = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .body(axum::body::Body::empty())
            .unwrap();

        let size = calculate_response_size(&res);

        // headers + fixed 50-byte overhead + fixed 2048-byte body allowance
        let expected = header_bytes(&[("content-type", "application/json")]) + 50 + 2048;
        assert_eq!(size, expected);
    }

    #[test]
    fn test_calculate_response_size_with_multiple_headers() {
        let res = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .header("cache-control", "no-cache")
            .header("x-request-id", "123-456-789")
            .body(axum::body::Body::empty())
            .unwrap();

        let size = calculate_response_size(&res);

        let expected = header_bytes(&[
            ("content-type", "application/json"),
            ("cache-control", "no-cache"),
            ("x-request-id", "123-456-789"),
        ]) + 50
            + 2048;
        assert_eq!(size, expected);
    }

    #[test]
    fn test_shared_clone() {
        let shared = shared_with_flags(true, true);

        let cloned = shared.clone();

        assert_eq!(shared.overrides_enabled, cloned.overrides_enabled);
        assert_eq!(shared.traffic_shaping_enabled, cloned.traffic_shaping_enabled);
    }

    #[test]
    fn test_operation_meta_clone() {
        let meta = OperationMeta {
            id: "testOp".to_string(),
            tags: vec!["tag1".to_string()],
            path: "/test".to_string(),
        };

        let cloned = meta.clone();

        assert_eq!(meta.id, cloned.id);
        assert_eq!(meta.tags, cloned.tags);
        assert_eq!(meta.path, cloned.path);
    }
}