1use axum::body::Body;
3use axum::http::{Request, StatusCode};
4use axum::{extract::State, middleware::Next, response::Response};
5use serde_json::Value;
6
7use crate::latency_profiles::LatencyProfiles;
8use mockforge_core::{FailureInjector, Overrides, TrafficShaper};
9
/// Metadata describing the operation a request was routed to.
///
/// The middleware in this module reads it from the request extensions and
/// forwards its fields to the failure injector, latency profiles, traffic
/// shaper and overrides.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OperationMeta {
    /// Operation identifier (e.g. `getUserById`).
    pub id: String,
    /// Tags attached to the operation; passed along when evaluating faults,
    /// traffic shaping and overrides.
    pub tags: Vec<String>,
    /// Path of the operation (e.g. `/users/{id}`); passed to the overrides.
    pub path: String,
}
20
21#[derive(Clone)]
23pub struct Shared {
24 pub profiles: LatencyProfiles,
26 pub overrides: Overrides,
28 pub failure_injector: Option<FailureInjector>,
30 pub traffic_shaper: Option<TrafficShaper>,
32 pub overrides_enabled: bool,
34 pub traffic_shaping_enabled: bool,
36}
37
38pub async fn add_shared_extension(
40 State(shared): State<Shared>,
41 mut req: Request<Body>,
42 next: Next,
43) -> Response {
44 req.extensions_mut().insert(shared);
45 next.run(req).await
46}
47
48pub async fn fault_then_next(req: Request<Body>, next: Next) -> Response {
50 let shared = req.extensions().get::<Shared>().unwrap().clone();
51 let op = req.extensions().get::<OperationMeta>().cloned();
52
53 if let Some(failure_injector) = &shared.failure_injector {
55 let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
56 if let Some((status_code, error_message)) = failure_injector.process_request(tags) {
57 let mut res = Response::new(axum::body::Body::from(error_message));
58 *res.status_mut() =
59 StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
60 return res;
61 }
62 }
63
64 if let Some(op) = &op {
66 if let Some((code, msg)) = shared
67 .profiles
68 .maybe_fault(&op.id, &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>())
69 .await
70 {
71 let mut res = Response::new(axum::body::Body::from(msg));
72 *res.status_mut() =
73 StatusCode::from_u16(code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
74 return res;
75 }
76 }
77
78 if shared.traffic_shaping_enabled {
80 if let Some(traffic_shaper) = &shared.traffic_shaper {
81 let request_size = calculate_request_size(&req);
83
84 let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
85
86 match traffic_shaper.process_transfer(request_size, tags).await {
88 Ok(Some(_timeout)) => {
89 let mut res = Response::new(axum::body::Body::from(
91 "Request timeout due to traffic shaping",
92 ));
93 *res.status_mut() = StatusCode::REQUEST_TIMEOUT;
94 return res;
95 }
96 Ok(None) => {
97 }
99 Err(e) => {
100 let mut res = Response::new(axum::body::Body::from(format!(
102 "Traffic shaping error: {}",
103 e
104 )));
105 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
106 return res;
107 }
108 }
109 }
110 }
111
112 let (parts, body) = req.into_parts();
113 let req = Request::from_parts(parts, body);
114
115 let response = next.run(req).await;
116
117 if shared.traffic_shaping_enabled {
119 if let Some(traffic_shaper) = &shared.traffic_shaper {
120 let response_size = calculate_response_size(&response);
122
123 let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
124
125 match traffic_shaper.process_transfer(response_size, tags).await {
127 Ok(Some(_timeout)) => {
128 let mut res = Response::new(axum::body::Body::from(
130 "Response timeout due to traffic shaping",
131 ));
132 *res.status_mut() = StatusCode::GATEWAY_TIMEOUT;
133 return res;
134 }
135 Ok(None) => {
136 }
138 Err(e) => {
139 let mut res = Response::new(axum::body::Body::from(format!(
141 "Traffic shaping error: {}",
142 e
143 )));
144 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
145 return res;
146 }
147 }
148 }
149 }
150
151 response
152}
153
154pub fn apply_overrides(shared: &Shared, op: Option<&OperationMeta>, body: &mut Value) {
161 if shared.overrides_enabled {
162 if let Some(op) = op {
163 shared.overrides.apply(
164 &op.id,
165 &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>(),
166 &op.path,
167 body,
168 );
169 }
170 }
171}
172
173fn calculate_request_size<B>(req: &Request<B>) -> u64 {
175 let mut size = 0u64;
176
177 for (name, value) in req.headers() {
179 size += name.as_str().len() as u64;
180 size += value.as_bytes().len() as u64;
181 }
182
183 size += req.uri().to_string().len() as u64;
185
186 size += 1024; size
193}
194
195fn calculate_response_size(res: &Response) -> u64 {
197 let mut size = 0u64;
198
199 for (name, value) in res.headers() {
201 size += name.as_str().len() as u64;
202 size += value.as_bytes().len() as u64;
203 }
204
205 size += 50;
207
208 size += 2048; size
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{Request, Response, StatusCode};
    use serde_json::json;

    /// Builds a `Shared` with default profiles and overrides, no failure
    /// injector, no traffic shaper, and the given feature flags.
    fn shared_with_flags(overrides_enabled: bool, traffic_shaping_enabled: bool) -> Shared {
        Shared {
            profiles: LatencyProfiles::default(),
            overrides: Overrides::default(),
            failure_injector: None,
            traffic_shaper: None,
            overrides_enabled,
            traffic_shaping_enabled,
        }
    }

    /// Builds an `OperationMeta` from borrowed pieces.
    fn operation(id: &str, tags: &[&str], path: &str) -> OperationMeta {
        OperationMeta {
            id: id.to_string(),
            tags: tags.iter().map(|tag| tag.to_string()).collect(),
            path: path.to_string(),
        }
    }

    /// Runs `apply_overrides` on a sample body and asserts it is left unchanged.
    fn assert_body_untouched(shared: &Shared, op: Option<&OperationMeta>) {
        let mut body = json!({"name": "John"});
        let original = body.clone();

        apply_overrides(shared, op, &mut body);

        assert_eq!(body, original);
    }

    #[test]
    fn test_operation_meta_creation() {
        let meta = operation("getUserById", &["users", "public"], "/users/{id}");

        assert_eq!(meta.id, "getUserById");
        assert_eq!(meta.tags.len(), 2);
        assert_eq!(meta.path, "/users/{id}");
    }

    #[test]
    fn test_shared_creation() {
        let shared = shared_with_flags(false, false);

        assert!(!shared.overrides_enabled);
        assert!(!shared.traffic_shaping_enabled);
        assert!(shared.failure_injector.is_none());
        assert!(shared.traffic_shaper.is_none());
    }

    #[test]
    fn test_shared_with_failure_injector() {
        let mut shared = shared_with_flags(false, false);
        shared.failure_injector = Some(FailureInjector::new(None, true));

        assert!(shared.failure_injector.is_some());
    }

    #[test]
    fn test_apply_overrides_disabled() {
        let shared = shared_with_flags(false, false);
        let op = operation("getUser", &[], "/users");

        assert_body_untouched(&shared, Some(&op));
    }

    #[test]
    fn test_apply_overrides_enabled_no_rules() {
        let shared = shared_with_flags(true, false);
        let op = operation("getUser", &[], "/users");

        assert_body_untouched(&shared, Some(&op));
    }

    #[test]
    fn test_apply_overrides_with_none_operation() {
        let shared = shared_with_flags(true, false);

        assert_body_untouched(&shared, None);
    }

    #[test]
    fn test_calculate_request_size_basic() {
        let req = Request::builder()
            .uri("/test")
            .header("content-type", "application/json")
            .body(())
            .unwrap();

        let size = calculate_request_size(&req);

        assert!(size > 0);
        // The estimate includes at least the URI and the header name.
        assert!(size >= "/test".len() as u64 + "content-type".len() as u64);
    }

    #[test]
    fn test_calculate_request_size_with_multiple_headers() {
        let req = Request::builder()
            .uri("/api/users")
            .header("content-type", "application/json")
            .header("authorization", "Bearer token123")
            .header("user-agent", "test-client")
            .body(())
            .unwrap();

        let size = calculate_request_size(&req);

        assert!(size > 100);
    }

    #[test]
    fn test_calculate_response_size_basic() {
        let res = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .body(axum::body::Body::empty())
            .unwrap();

        let size = calculate_response_size(&res);

        assert!(size > 0);
        assert!(size >= 50);
    }

    #[test]
    fn test_calculate_response_size_with_multiple_headers() {
        let res = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .header("cache-control", "no-cache")
            .header("x-request-id", "123-456-789")
            .body(axum::body::Body::empty())
            .unwrap();

        let size = calculate_response_size(&res);

        assert!(size > 100);
    }

    #[test]
    fn test_shared_clone() {
        let shared = shared_with_flags(true, true);

        let cloned = shared.clone();

        assert_eq!(shared.overrides_enabled, cloned.overrides_enabled);
        assert_eq!(shared.traffic_shaping_enabled, cloned.traffic_shaping_enabled);
    }

    #[test]
    fn test_operation_meta_clone() {
        let meta = operation("testOp", &["tag1"], "/test");

        let cloned = meta.clone();

        assert_eq!(meta.id, cloned.id);
        assert_eq!(meta.tags, cloned.tags);
        assert_eq!(meta.path, cloned.path);
    }
}