1use axum::body::Body;
3use axum::extract::State;
4use axum::http::{Request, StatusCode};
5use axum::middleware::Next;
6use axum::response::{Json, Response};
7use serde_json::Value;
8
9use crate::latency_profiles::LatencyProfiles;
10use mockforge_core::{FailureInjector, Overrides, TrafficShaper};
11
/// Metadata describing the operation a request is handled by.
///
/// `fault_then_next` looks it up (optionally) in the request extensions;
/// `apply_overrides` receives it explicitly. Which component inserts it into
/// the extensions is not visible in this module.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct OperationMeta {
    /// Operation identifier (e.g. `getUserById`), passed to latency profiles
    /// and overrides.
    pub id: String,
    /// Tags passed to the failure injector, latency profiles, traffic shaper
    /// and overrides.
    pub tags: Vec<String>,
    /// Route path (e.g. `/users/{id}`), passed to overrides.
    pub path: String,
}
22
/// Configuration and components used by the middleware in this module.
///
/// `add_shared_extension` inserts it into the request extensions, where
/// `fault_then_next` reads it back; `apply_overrides` takes it by reference.
#[derive(Clone)]
pub struct Shared {
    /// Consulted by `fault_then_next` via `maybe_fault(op_id, tags)` when an
    /// `OperationMeta` is attached to the request.
    pub profiles: LatencyProfiles,
    /// Applied to JSON bodies by `apply_overrides`, only when
    /// `overrides_enabled` is true.
    pub overrides: Overrides,
    /// When present, `fault_then_next` calls `process_request(tags)` on it
    /// and returns its status/message if it yields one.
    pub failure_injector: Option<FailureInjector>,
    /// Used by `fault_then_next` to shape request and response transfers,
    /// only when `traffic_shaping_enabled` is true.
    pub traffic_shaper: Option<TrafficShaper>,
    /// Gates `apply_overrides`; when false the body is left untouched.
    pub overrides_enabled: bool,
    /// Gates traffic shaping in `fault_then_next`.
    pub traffic_shaping_enabled: bool,
}
39
40pub async fn add_shared_extension(
42 State(shared): State<Shared>,
43 mut req: Request<Body>,
44 next: Next,
45) -> Response {
46 req.extensions_mut().insert(shared);
47 next.run(req).await
48}
49
50pub async fn fault_then_next(req: Request<Body>, next: Next) -> Response {
52 let shared = req.extensions().get::<Shared>().unwrap().clone();
53 let op = req.extensions().get::<OperationMeta>().cloned();
54
55 if let Some(failure_injector) = &shared.failure_injector {
57 let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
58 if let Some((status_code, error_message)) = failure_injector.process_request(tags) {
59 let mut res = Response::new(axum::body::Body::from(error_message));
60 *res.status_mut() =
61 StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
62 return res;
63 }
64 }
65
66 if let Some(op) = &op {
68 if let Some((code, msg)) = shared
69 .profiles
70 .maybe_fault(&op.id, &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>())
71 .await
72 {
73 let mut res = Response::new(axum::body::Body::from(msg));
74 *res.status_mut() =
75 StatusCode::from_u16(code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
76 return res;
77 }
78 }
79
80 if shared.traffic_shaping_enabled {
82 if let Some(traffic_shaper) = &shared.traffic_shaper {
83 let request_size = calculate_request_size(&req);
85
86 let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
87
88 match traffic_shaper.process_transfer(request_size, tags).await {
90 Ok(Some(_timeout)) => {
91 let mut res = Response::new(axum::body::Body::from(
93 "Request timeout due to traffic shaping",
94 ));
95 *res.status_mut() = StatusCode::REQUEST_TIMEOUT;
96 return res;
97 }
98 Ok(None) => {
99 }
101 Err(e) => {
102 let mut res = Response::new(axum::body::Body::from(format!(
104 "Traffic shaping error: {}",
105 e
106 )));
107 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
108 return res;
109 }
110 }
111 }
112 }
113
114 let (parts, body) = req.into_parts();
115 let req = Request::from_parts(parts, body);
116
117 let response = next.run(req).await;
118
119 if shared.traffic_shaping_enabled {
121 if let Some(traffic_shaper) = &shared.traffic_shaper {
122 let response_size = calculate_response_size(&response);
124
125 let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
126
127 match traffic_shaper.process_transfer(response_size, tags).await {
129 Ok(Some(_timeout)) => {
130 let mut res = Response::new(axum::body::Body::from(
132 "Response timeout due to traffic shaping",
133 ));
134 *res.status_mut() = StatusCode::GATEWAY_TIMEOUT;
135 return res;
136 }
137 Ok(None) => {
138 }
140 Err(e) => {
141 let mut res = Response::new(axum::body::Body::from(format!(
143 "Traffic shaping error: {}",
144 e
145 )));
146 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
147 return res;
148 }
149 }
150 }
151 }
152
153 response
154}
155
156pub fn apply_overrides(shared: &Shared, op: Option<&OperationMeta>, body: &mut Value) {
163 if shared.overrides_enabled {
164 if let Some(op) = op {
165 shared.overrides.apply(
166 &op.id,
167 &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>(),
168 &op.path,
169 body,
170 );
171 }
172 }
173}
174
175fn calculate_request_size<B>(req: &Request<B>) -> u64 {
177 let mut size = 0u64;
178
179 for (name, value) in req.headers() {
181 size += name.as_str().len() as u64;
182 size += value.as_bytes().len() as u64;
183 }
184
185 size += req.uri().to_string().len() as u64;
187
188 size += 1024; size
195}
196
197fn calculate_response_size(res: &Response) -> u64 {
199 let mut size = 0u64;
200
201 for (name, value) in res.headers() {
203 size += name.as_str().len() as u64;
204 size += value.as_bytes().len() as u64;
205 }
206
207 size += 50;
209
210 size += 2048; size
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{Request, Response, StatusCode};
    use serde_json::json;

    /// A `Shared` with default profiles/overrides and no injector or shaper.
    fn shared_with(overrides_enabled: bool, traffic_shaping_enabled: bool) -> Shared {
        Shared {
            profiles: LatencyProfiles::default(),
            overrides: Overrides::default(),
            failure_injector: None,
            traffic_shaper: None,
            overrides_enabled,
            traffic_shaping_enabled,
        }
    }

    /// An untagged `getUser` operation used by the override tests.
    fn get_user_op() -> OperationMeta {
        OperationMeta {
            id: "getUser".to_string(),
            tags: vec![],
            path: "/users".to_string(),
        }
    }

    /// Bytes contributed by headers: `name.len() + value.len()` per pair, which
    /// is how the size helpers count them.
    fn header_bytes(headers: &[(&str, &str)]) -> u64 {
        headers
            .iter()
            .map(|(name, value)| (name.len() + value.len()) as u64)
            .sum()
    }

    #[test]
    fn test_operation_meta_creation() {
        let meta = OperationMeta {
            id: "getUserById".to_string(),
            tags: vec!["users".to_string(), "public".to_string()],
            path: "/users/{id}".to_string(),
        };

        assert_eq!(meta.id, "getUserById");
        assert_eq!(meta.tags.len(), 2);
        assert_eq!(meta.path, "/users/{id}");
    }

    #[test]
    fn test_shared_creation() {
        let shared = shared_with(false, false);

        assert!(!shared.overrides_enabled);
        assert!(!shared.traffic_shaping_enabled);
        assert!(shared.failure_injector.is_none());
        assert!(shared.traffic_shaper.is_none());
    }

    #[test]
    fn test_shared_with_failure_injector() {
        let shared = Shared {
            failure_injector: Some(FailureInjector::new(None, true)),
            ..shared_with(false, false)
        };

        assert!(shared.failure_injector.is_some());
    }

    #[test]
    fn test_apply_overrides_disabled() {
        let shared = shared_with(false, false);
        let op = get_user_op();

        let mut body = json!({"name": "John"});
        let original = body.clone();

        apply_overrides(&shared, Some(&op), &mut body);

        assert_eq!(body, original);
    }

    #[test]
    fn test_apply_overrides_enabled_no_rules() {
        let shared = shared_with(true, false);
        let op = get_user_op();

        let mut body = json!({"name": "John"});
        let original = body.clone();

        apply_overrides(&shared, Some(&op), &mut body);

        assert_eq!(body, original);
    }

    #[test]
    fn test_apply_overrides_with_none_operation() {
        let shared = shared_with(true, false);

        let mut body = json!({"name": "John"});
        let original = body.clone();

        apply_overrides(&shared, None, &mut body);

        assert_eq!(body, original);
    }

    #[test]
    fn test_calculate_request_size_basic() {
        let req = Request::builder()
            .uri("/test")
            .header("content-type", "application/json")
            .body(())
            .unwrap();

        let size = calculate_request_size(&req);

        // Headers and URI are counted exactly, plus a fixed 1024-byte allowance.
        let expected =
            header_bytes(&[("content-type", "application/json")]) + "/test".len() as u64 + 1024;
        assert_eq!(size, expected);
    }

    #[test]
    fn test_calculate_request_size_with_multiple_headers() {
        let req = Request::builder()
            .uri("/api/users")
            .header("content-type", "application/json")
            .header("authorization", "Bearer token123")
            .header("user-agent", "test-client")
            .body(())
            .unwrap();

        let size = calculate_request_size(&req);

        let expected = header_bytes(&[
            ("content-type", "application/json"),
            ("authorization", "Bearer token123"),
            ("user-agent", "test-client"),
        ]) + "/api/users".len() as u64
            + 1024;
        assert_eq!(size, expected);
    }

    #[test]
    fn test_calculate_response_size_basic() {
        let res = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .body(axum::body::Body::empty())
            .unwrap();

        let size = calculate_response_size(&res);

        // Headers are counted exactly, plus fixed amounts of 50 and 2048 bytes.
        let expected = header_bytes(&[("content-type", "application/json")]) + 50 + 2048;
        assert_eq!(size, expected);
    }

    #[test]
    fn test_calculate_response_size_with_multiple_headers() {
        let res = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .header("cache-control", "no-cache")
            .header("x-request-id", "123-456-789")
            .body(axum::body::Body::empty())
            .unwrap();

        let size = calculate_response_size(&res);

        let expected = header_bytes(&[
            ("content-type", "application/json"),
            ("cache-control", "no-cache"),
            ("x-request-id", "123-456-789"),
        ]) + 50
            + 2048;
        assert_eq!(size, expected);
    }

    #[test]
    fn test_shared_clone() {
        let shared = shared_with(true, true);

        let cloned = shared.clone();

        assert_eq!(shared.overrides_enabled, cloned.overrides_enabled);
        assert_eq!(shared.traffic_shaping_enabled, cloned.traffic_shaping_enabled);
    }

    #[test]
    fn test_operation_meta_clone() {
        let meta = OperationMeta {
            id: "testOp".to_string(),
            tags: vec!["tag1".to_string()],
            path: "/test".to_string(),
        };

        let cloned = meta.clone();

        assert_eq!(meta.id, cloned.id);
        assert_eq!(meta.tags, cloned.tags);
        assert_eq!(meta.path, cloned.path);
    }
}