1use axum::body::Body;
3use axum::http::{Request, StatusCode};
4use axum::{extract::State, middleware::Next, response::Response};
5use serde_json::Value;
6
7use crate::latency_profiles::LatencyProfiles;
8use mockforge_core::{FailureInjector, Overrides, TrafficShaper};
9
/// Metadata about the operation a request was matched to.
///
/// [`fault_then_next`] reads it (when present) from the request extensions, and
/// [`apply_overrides`] receives it explicitly.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct OperationMeta {
    /// Operation identifier; passed to latency-profile fault lookup and to overrides.
    pub id: String,
    /// Tags passed to the failure injector, latency profiles, traffic shaper and overrides.
    pub tags: Vec<String>,
    /// Route path (e.g. `/users/{id}`); passed to overrides.
    pub path: String,
}
16
/// Per-router state shared by the fault-injection / traffic-shaping middleware.
///
/// [`add_shared_extension`] inserts a clone of this into every request's extensions,
/// where [`fault_then_next`] reads it back.
#[derive(Clone)]
pub struct Shared {
    /// Latency profiles; `fault_then_next` asks them (`maybe_fault`) whether to fail an operation.
    pub profiles: LatencyProfiles,
    /// Response overrides applied by [`apply_overrides`] when `overrides_enabled` is true.
    pub overrides: Overrides,
    /// Optional failure injector, consulted first by `fault_then_next` when `Some`.
    pub failure_injector: Option<FailureInjector>,
    /// Optional traffic shaper; only used when `traffic_shaping_enabled` is also true.
    pub traffic_shaper: Option<TrafficShaper>,
    /// Gate for [`apply_overrides`]; when false, bodies are left untouched.
    pub overrides_enabled: bool,
    /// Gate for request/response traffic shaping in `fault_then_next`.
    pub traffic_shaping_enabled: bool,
}
26
27pub async fn add_shared_extension(
28 State(shared): State<Shared>,
29 mut req: Request<Body>,
30 next: Next,
31) -> Response {
32 req.extensions_mut().insert(shared);
33 next.run(req).await
34}
35
36pub async fn fault_then_next(req: Request<Body>, next: Next) -> Response {
37 let shared = req.extensions().get::<Shared>().unwrap().clone();
38 let op = req.extensions().get::<OperationMeta>().cloned();
39
40 if let Some(failure_injector) = &shared.failure_injector {
42 let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
43 if let Some((status_code, error_message)) = failure_injector.process_request(tags) {
44 let mut res = Response::new(axum::body::Body::from(error_message));
45 *res.status_mut() =
46 StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
47 return res;
48 }
49 }
50
51 if let Some(op) = &op {
53 if let Some((code, msg)) = shared
54 .profiles
55 .maybe_fault(&op.id, &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>())
56 .await
57 {
58 let mut res = Response::new(axum::body::Body::from(msg));
59 *res.status_mut() =
60 StatusCode::from_u16(code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
61 return res;
62 }
63 }
64
65 if shared.traffic_shaping_enabled {
67 if let Some(traffic_shaper) = &shared.traffic_shaper {
68 let request_size = calculate_request_size(&req);
70
71 let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
72
73 match traffic_shaper.process_transfer(request_size, tags).await {
75 Ok(Some(_timeout)) => {
76 let mut res = Response::new(axum::body::Body::from(
78 "Request timeout due to traffic shaping",
79 ));
80 *res.status_mut() = StatusCode::REQUEST_TIMEOUT;
81 return res;
82 }
83 Ok(None) => {
84 }
86 Err(e) => {
87 let mut res = Response::new(axum::body::Body::from(format!(
89 "Traffic shaping error: {}",
90 e
91 )));
92 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
93 return res;
94 }
95 }
96 }
97 }
98
99 let (parts, body) = req.into_parts();
100 let req = Request::from_parts(parts, body);
101
102 let response = next.run(req).await;
103
104 if shared.traffic_shaping_enabled {
106 if let Some(traffic_shaper) = &shared.traffic_shaper {
107 let response_size = calculate_response_size(&response);
109
110 let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
111
112 match traffic_shaper.process_transfer(response_size, tags).await {
114 Ok(Some(_timeout)) => {
115 let mut res = Response::new(axum::body::Body::from(
117 "Response timeout due to traffic shaping",
118 ));
119 *res.status_mut() = StatusCode::GATEWAY_TIMEOUT;
120 return res;
121 }
122 Ok(None) => {
123 }
125 Err(e) => {
126 let mut res = Response::new(axum::body::Body::from(format!(
128 "Traffic shaping error: {}",
129 e
130 )));
131 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
132 return res;
133 }
134 }
135 }
136 }
137
138 response
139}
140
141pub fn apply_overrides(shared: &Shared, op: Option<&OperationMeta>, body: &mut Value) {
142 if shared.overrides_enabled {
143 if let Some(op) = op {
144 shared.overrides.apply(
145 &op.id,
146 &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>(),
147 &op.path,
148 body,
149 );
150 }
151 }
152}
153
154fn calculate_request_size<B>(req: &Request<B>) -> u64 {
156 let mut size = 0u64;
157
158 for (name, value) in req.headers() {
160 size += name.as_str().len() as u64;
161 size += value.as_bytes().len() as u64;
162 }
163
164 size += req.uri().to_string().len() as u64;
166
167 size += 1024; size
174}
175
176fn calculate_response_size(res: &Response) -> u64 {
178 let mut size = 0u64;
179
180 for (name, value) in res.headers() {
182 size += name.as_str().len() as u64;
183 size += value.as_bytes().len() as u64;
184 }
185
186 size += 50;
188
189 size += 2048; size
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{Request, Response, StatusCode};
    use serde_json::json;

    /// `Shared` with default profiles/overrides, no injector, no shaper, and the given flags.
    fn shared_with(overrides_enabled: bool, traffic_shaping_enabled: bool) -> Shared {
        Shared {
            profiles: LatencyProfiles::default(),
            overrides: Overrides::default(),
            failure_injector: None,
            traffic_shaper: None,
            overrides_enabled,
            traffic_shaping_enabled,
        }
    }

    /// Builds an `OperationMeta` from borrowed parts.
    fn operation(id: &str, tags: &[&str], path: &str) -> OperationMeta {
        OperationMeta {
            id: id.to_string(),
            tags: tags.iter().map(|tag| tag.to_string()).collect(),
            path: path.to_string(),
        }
    }

    /// Runs `apply_overrides` on a sample body and asserts the body is unchanged.
    fn assert_body_unchanged(shared: &Shared, op: Option<&OperationMeta>) {
        let mut body = json!({"name": "John"});
        let before = body.clone();
        apply_overrides(shared, op, &mut body);
        assert_eq!(body, before);
    }

    #[test]
    fn test_operation_meta_creation() {
        let meta = operation("getUserById", &["users", "public"], "/users/{id}");

        assert_eq!(meta.id, "getUserById");
        assert_eq!(meta.tags.len(), 2);
        assert_eq!(meta.path, "/users/{id}");
    }

    #[test]
    fn test_shared_creation() {
        let shared = shared_with(false, false);

        assert!(!shared.overrides_enabled);
        assert!(!shared.traffic_shaping_enabled);
        assert!(shared.failure_injector.is_none());
        assert!(shared.traffic_shaper.is_none());
    }

    #[test]
    fn test_shared_with_failure_injector() {
        let shared = Shared {
            failure_injector: Some(FailureInjector::new(None, true)),
            ..shared_with(false, false)
        };

        assert!(shared.failure_injector.is_some());
    }

    #[test]
    fn test_apply_overrides_disabled() {
        let op = operation("getUser", &[], "/users");
        assert_body_unchanged(&shared_with(false, false), Some(&op));
    }

    #[test]
    fn test_apply_overrides_enabled_no_rules() {
        let op = operation("getUser", &[], "/users");
        assert_body_unchanged(&shared_with(true, false), Some(&op));
    }

    #[test]
    fn test_apply_overrides_with_none_operation() {
        assert_body_unchanged(&shared_with(true, false), None);
    }

    #[test]
    fn test_calculate_request_size_basic() {
        let req = Request::builder()
            .uri("/test")
            .header("content-type", "application/json")
            .body(())
            .unwrap();

        let size = calculate_request_size(&req);

        assert!(size > 0);
        assert!(size >= ("/test".len() + "content-type".len()) as u64);
    }

    #[test]
    fn test_calculate_request_size_with_multiple_headers() {
        let req = Request::builder()
            .uri("/api/users")
            .header("content-type", "application/json")
            .header("authorization", "Bearer token123")
            .header("user-agent", "test-client")
            .body(())
            .unwrap();

        assert!(calculate_request_size(&req) > 100);
    }

    #[test]
    fn test_calculate_response_size_basic() {
        let res = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .body(axum::body::Body::empty())
            .unwrap();

        let size = calculate_response_size(&res);

        assert!(size > 0);
        assert!(size >= 50);
    }

    #[test]
    fn test_calculate_response_size_with_multiple_headers() {
        let res = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .header("cache-control", "no-cache")
            .header("x-request-id", "123-456-789")
            .body(axum::body::Body::empty())
            .unwrap();

        assert!(calculate_response_size(&res) > 100);
    }

    #[test]
    fn test_shared_clone() {
        let shared = shared_with(true, true);
        let copy = shared.clone();

        assert_eq!(shared.overrides_enabled, copy.overrides_enabled);
        assert_eq!(shared.traffic_shaping_enabled, copy.traffic_shaping_enabled);
    }

    #[test]
    fn test_operation_meta_clone() {
        let meta = operation("testOp", &["tag1"], "/test");
        let copy = meta.clone();

        assert_eq!(meta.id, copy.id);
        assert_eq!(meta.tags, copy.tags);
        assert_eq!(meta.path, copy.path);
    }
}
414}