1use axum::body::Body;
3use axum::extract::State;
4use axum::http::{Request, StatusCode};
5use axum::middleware::Next;
6use axum::response::Response;
7use serde_json::Value;
8
9use crate::latency_profiles::LatencyProfiles;
10use mockforge_core::{FailureInjector, Overrides, TrafficShaper};
11
/// Metadata about the API operation matched by the current request.
///
/// `fault_then_next` reads it from the request extensions. Its tags drive
/// failure injection, latency faults and traffic shaping. `apply_overrides`
/// uses the id, tags and path to select response overrides.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct OperationMeta {
    /// Operation identifier, e.g. `getUserById`.
    pub id: String,
    /// Tags attached to the operation; used to match rules keyed by tag.
    pub tags: Vec<String>,
    /// Route path of the operation, e.g. `/users/{id}`.
    pub path: String,
}
22
23#[derive(Clone)]
25pub struct Shared {
26 pub profiles: LatencyProfiles,
28 pub overrides: Overrides,
30 pub failure_injector: Option<FailureInjector>,
32 pub traffic_shaper: Option<TrafficShaper>,
34 pub overrides_enabled: bool,
36 pub traffic_shaping_enabled: bool,
38}
39
40pub async fn add_shared_extension(
42 State(shared): State<Shared>,
43 mut req: Request<Body>,
44 next: Next,
45) -> Response {
46 req.extensions_mut().insert(shared);
47 next.run(req).await
48}
49
50pub async fn fault_then_next(req: Request<Body>, next: Next) -> Response {
52 let shared = match req.extensions().get::<Shared>() {
53 Some(s) => s.clone(),
54 None => {
55 tracing::error!("Shared extension not found in request - ensure add_shared_extension middleware is configured");
56 let mut res =
57 Response::new(Body::from("Internal server error: middleware misconfiguration"));
58 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
59 return res;
60 }
61 };
62 let op = req.extensions().get::<OperationMeta>().cloned();
63
64 if let Some(failure_injector) = &shared.failure_injector {
66 let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
67 if let Some((status_code, error_message)) = failure_injector.process_request(tags) {
68 let mut res = Response::new(axum::body::Body::from(error_message));
69 *res.status_mut() =
70 StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
71 return res;
72 }
73 }
74
75 if let Some(op) = &op {
77 if let Some((code, msg)) = shared
78 .profiles
79 .maybe_fault(&op.id, &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>())
80 .await
81 {
82 let mut res = Response::new(axum::body::Body::from(msg));
83 *res.status_mut() =
84 StatusCode::from_u16(code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
85 return res;
86 }
87 }
88
89 if shared.traffic_shaping_enabled {
91 if let Some(traffic_shaper) = &shared.traffic_shaper {
92 let request_size = calculate_request_size(&req);
94
95 let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
96
97 match traffic_shaper.process_transfer(request_size, tags).await {
99 Ok(Some(_timeout)) => {
100 let mut res = Response::new(axum::body::Body::from(
102 "Request timeout due to traffic shaping",
103 ));
104 *res.status_mut() = StatusCode::REQUEST_TIMEOUT;
105 return res;
106 }
107 Ok(None) => {
108 }
110 Err(e) => {
111 let mut res = Response::new(axum::body::Body::from(format!(
113 "Traffic shaping error: {}",
114 e
115 )));
116 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
117 return res;
118 }
119 }
120 }
121 }
122
123 let (parts, body) = req.into_parts();
124 let req = Request::from_parts(parts, body);
125
126 let response = next.run(req).await;
127
128 if shared.traffic_shaping_enabled {
130 if let Some(traffic_shaper) = &shared.traffic_shaper {
131 let response_size = calculate_response_size(&response);
133
134 let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
135
136 match traffic_shaper.process_transfer(response_size, tags).await {
138 Ok(Some(_timeout)) => {
139 let mut res = Response::new(axum::body::Body::from(
141 "Response timeout due to traffic shaping",
142 ));
143 *res.status_mut() = StatusCode::GATEWAY_TIMEOUT;
144 return res;
145 }
146 Ok(None) => {
147 }
149 Err(e) => {
150 let mut res = Response::new(axum::body::Body::from(format!(
152 "Traffic shaping error: {}",
153 e
154 )));
155 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
156 return res;
157 }
158 }
159 }
160 }
161
162 response
163}
164
165pub fn apply_overrides(shared: &Shared, op: Option<&OperationMeta>, body: &mut Value) {
172 if shared.overrides_enabled {
173 if let Some(op) = op {
174 shared.overrides.apply(
175 &op.id,
176 &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>(),
177 &op.path,
178 body,
179 );
180 }
181 }
182}
183
184fn calculate_request_size<B>(req: &Request<B>) -> u64 {
186 let mut size = 0u64;
187
188 for (name, value) in req.headers() {
190 size += name.as_str().len() as u64;
191 size += value.as_bytes().len() as u64;
192 }
193
194 size += req.uri().to_string().len() as u64;
196
197 size += 1024; size
204}
205
206fn calculate_response_size(res: &Response) -> u64 {
208 let mut size = 0u64;
209
210 for (name, value) in res.headers() {
212 size += name.as_str().len() as u64;
213 size += value.as_bytes().len() as u64;
214 }
215
216 size += 50;
218
219 size += 2048; size
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{Request, Response, StatusCode};
    use serde_json::json;

    /// `Shared` fixture with default profiles and overrides, no failure
    /// injector or traffic shaper, and the two feature flags as requested.
    fn shared_with_flags(overrides_enabled: bool, traffic_shaping_enabled: bool) -> Shared {
        Shared {
            profiles: LatencyProfiles::default(),
            overrides: Overrides::default(),
            failure_injector: None,
            traffic_shaper: None,
            overrides_enabled,
            traffic_shaping_enabled,
        }
    }

    /// `OperationMeta` fixture built from borrowed strings.
    fn operation(id: &str, tags: &[&str], path: &str) -> OperationMeta {
        OperationMeta {
            id: id.to_owned(),
            tags: tags.iter().map(|tag| (*tag).to_owned()).collect(),
            path: path.to_owned(),
        }
    }

    /// Runs `apply_overrides` on a small JSON body and asserts it is unchanged.
    fn assert_body_untouched(shared: &Shared, op: Option<&OperationMeta>) {
        let mut body = json!({"name": "John"});
        let before = body.clone();

        apply_overrides(shared, op, &mut body);

        assert_eq!(body, before);
    }

    #[test]
    fn test_operation_meta_creation() {
        let meta = operation("getUserById", &["users", "public"], "/users/{id}");

        assert_eq!(meta.id, "getUserById");
        assert_eq!(meta.tags.len(), 2);
        assert_eq!(meta.path, "/users/{id}");
    }

    #[test]
    fn test_shared_creation() {
        let shared = shared_with_flags(false, false);

        assert!(!shared.overrides_enabled);
        assert!(!shared.traffic_shaping_enabled);
        assert!(shared.failure_injector.is_none());
        assert!(shared.traffic_shaper.is_none());
    }

    #[test]
    fn test_shared_with_failure_injector() {
        let shared = Shared {
            failure_injector: Some(FailureInjector::new(None, true)),
            ..shared_with_flags(false, false)
        };

        assert!(shared.failure_injector.is_some());
    }

    #[test]
    fn test_apply_overrides_disabled() {
        let op = operation("getUser", &[], "/users");
        assert_body_untouched(&shared_with_flags(false, false), Some(&op));
    }

    #[test]
    fn test_apply_overrides_enabled_no_rules() {
        let op = operation("getUser", &[], "/users");
        assert_body_untouched(&shared_with_flags(true, false), Some(&op));
    }

    #[test]
    fn test_apply_overrides_with_none_operation() {
        assert_body_untouched(&shared_with_flags(true, false), None);
    }

    #[test]
    fn test_calculate_request_size_basic() {
        let req = Request::builder()
            .uri("/test")
            .header("content-type", "application/json")
            .body(())
            .unwrap();

        let size = calculate_request_size(&req);
        let uri_plus_header_name = "/test".len() as u64 + "content-type".len() as u64;

        assert!(size > 0);
        // At minimum the URI and the header name must be counted.
        assert!(size >= uri_plus_header_name);
    }

    #[test]
    fn test_calculate_request_size_with_multiple_headers() {
        let req = Request::builder()
            .uri("/api/users")
            .header("content-type", "application/json")
            .header("authorization", "Bearer token123")
            .header("user-agent", "test-client")
            .body(())
            .unwrap();

        let size = calculate_request_size(&req);

        assert!(size > 100);
    }

    #[test]
    fn test_calculate_response_size_basic() {
        let res = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .body(axum::body::Body::empty())
            .unwrap();

        let size = calculate_response_size(&res);

        assert!(size > 0);
        // The fixed 50-byte allowance alone guarantees this lower bound.
        assert!(size >= 50);
    }

    #[test]
    fn test_calculate_response_size_with_multiple_headers() {
        let res = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .header("cache-control", "no-cache")
            .header("x-request-id", "123-456-789")
            .body(axum::body::Body::empty())
            .unwrap();

        let size = calculate_response_size(&res);

        assert!(size > 100);
    }

    #[test]
    fn test_shared_clone() {
        let shared = shared_with_flags(true, true);

        let cloned = shared.clone();

        assert_eq!(shared.overrides_enabled, cloned.overrides_enabled);
        assert_eq!(shared.traffic_shaping_enabled, cloned.traffic_shaping_enabled);
    }

    #[test]
    fn test_operation_meta_clone() {
        let meta = operation("testOp", &["tag1"], "/test");

        let cloned = meta.clone();

        assert_eq!(meta.id, cloned.id);
        assert_eq!(meta.tags, cloned.tags);
        assert_eq!(meta.path, cloned.path);
    }
}