mockforge_http/
op_middleware.rs

//! Per-operation middleware and helpers: failure injection, legacy latency-profile faults,
//! request/response traffic shaping, and JSON body overrides.
2use axum::body::Body;
3use axum::http::{Request, StatusCode};
4use axum::{extract::State, middleware::Next, response::Response};
5use serde_json::Value;
6
7use crate::latency_profiles::LatencyProfiles;
8use mockforge_core::{FailureInjector, Overrides, TrafficShaper};
9
/// Identifying metadata for the API operation a request is routed to.
///
/// Read from request extensions by [`fault_then_next`] and passed to [`apply_overrides`]
/// so that faults, traffic shaping and overrides can be targeted per operation or per tag.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct OperationMeta {
    /// Operation identifier (e.g. an OpenAPI `operationId` such as `getUserById`).
    pub id: String,
    /// Tags attached to the operation; used to select tag-scoped rules.
    pub tags: Vec<String>,
    /// Route path for the operation (e.g. `/users/{id}`).
    pub path: String,
}
16
17#[derive(Clone)]
18pub struct Shared {
19    pub profiles: LatencyProfiles,
20    pub overrides: Overrides,
21    pub failure_injector: Option<FailureInjector>,
22    pub traffic_shaper: Option<TrafficShaper>,
23    pub overrides_enabled: bool,
24    pub traffic_shaping_enabled: bool,
25}
26
27pub async fn add_shared_extension(
28    State(shared): State<Shared>,
29    mut req: Request<Body>,
30    next: Next,
31) -> Response {
32    req.extensions_mut().insert(shared);
33    next.run(req).await
34}
35
36pub async fn fault_then_next(req: Request<Body>, next: Next) -> Response {
37    let shared = req.extensions().get::<Shared>().unwrap().clone();
38    let op = req.extensions().get::<OperationMeta>().cloned();
39
40    // First, check the new enhanced failure injection system
41    if let Some(failure_injector) = &shared.failure_injector {
42        let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
43        if let Some((status_code, error_message)) = failure_injector.process_request(tags) {
44            let mut res = Response::new(axum::body::Body::from(error_message));
45            *res.status_mut() =
46                StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
47            return res;
48        }
49    }
50
51    // Fallback to legacy latency profiles system for backward compatibility
52    if let Some(op) = &op {
53        if let Some((code, msg)) = shared
54            .profiles
55            .maybe_fault(&op.id, &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>())
56            .await
57        {
58            let mut res = Response::new(axum::body::Body::from(msg));
59            *res.status_mut() =
60                StatusCode::from_u16(code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
61            return res;
62        }
63    }
64
65    // Apply traffic shaping (bandwidth throttling and burst loss) to the request
66    if shared.traffic_shaping_enabled {
67        if let Some(traffic_shaper) = &shared.traffic_shaper {
68            // Calculate request size for bandwidth throttling
69            let request_size = calculate_request_size(&req);
70
71            let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
72
73            // Apply traffic shaping
74            match traffic_shaper.process_transfer(request_size, tags).await {
75                Ok(Some(_timeout)) => {
76                    // Request was "lost" due to burst loss - return timeout error
77                    let mut res = Response::new(axum::body::Body::from(
78                        "Request timeout due to traffic shaping",
79                    ));
80                    *res.status_mut() = StatusCode::REQUEST_TIMEOUT;
81                    return res;
82                }
83                Ok(None) => {
84                    // Transfer allowed, continue
85                }
86                Err(e) => {
87                    // Traffic shaping error - return internal server error
88                    let mut res = Response::new(axum::body::Body::from(format!(
89                        "Traffic shaping error: {}",
90                        e
91                    )));
92                    *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
93                    return res;
94                }
95            }
96        }
97    }
98
99    let (parts, body) = req.into_parts();
100    let req = Request::from_parts(parts, body);
101
102    let response = next.run(req).await;
103
104    // Apply traffic shaping to the response
105    if shared.traffic_shaping_enabled {
106        if let Some(traffic_shaper) = &shared.traffic_shaper {
107            // Calculate response size for bandwidth throttling
108            let response_size = calculate_response_size(&response);
109
110            let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
111
112            // Apply traffic shaping to response
113            match traffic_shaper.process_transfer(response_size, tags).await {
114                Ok(Some(_timeout)) => {
115                    // Response was "lost" due to burst loss - return timeout error
116                    let mut res = Response::new(axum::body::Body::from(
117                        "Response timeout due to traffic shaping",
118                    ));
119                    *res.status_mut() = StatusCode::GATEWAY_TIMEOUT;
120                    return res;
121                }
122                Ok(None) => {
123                    // Transfer allowed, continue
124                }
125                Err(e) => {
126                    // Traffic shaping error - return internal server error
127                    let mut res = Response::new(axum::body::Body::from(format!(
128                        "Traffic shaping error: {}",
129                        e
130                    )));
131                    *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
132                    return res;
133                }
134            }
135        }
136    }
137
138    response
139}
140
141pub fn apply_overrides(shared: &Shared, op: Option<&OperationMeta>, body: &mut Value) {
142    if shared.overrides_enabled {
143        if let Some(op) = op {
144            shared.overrides.apply(
145                &op.id,
146                &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>(),
147                &op.path,
148                body,
149            );
150        }
151    }
152}
153
154/// Calculate the approximate size of an HTTP request for bandwidth throttling
155fn calculate_request_size<B>(req: &Request<B>) -> u64 {
156    let mut size = 0u64;
157
158    // Add header sizes (rough estimate)
159    for (name, value) in req.headers() {
160        size += name.as_str().len() as u64;
161        size += value.as_bytes().len() as u64;
162    }
163
164    // Add URI size
165    size += req.uri().to_string().len() as u64;
166
167    // Add body size (if available)
168    // Note: This is a rough estimate since we can't easily get the body size here
169    // without consuming the body. In practice, this would need to be implemented
170    // differently to get accurate body sizes.
171    size += 1024; // Rough estimate for body size
172
173    size
174}
175
176/// Calculate the approximate size of an HTTP response for bandwidth throttling
177fn calculate_response_size(res: &Response) -> u64 {
178    let mut size = 0u64;
179
180    // Add header sizes
181    for (name, value) in res.headers() {
182        size += name.as_str().len() as u64;
183        size += value.as_bytes().len() as u64;
184    }
185
186    // Add status line size (rough estimate)
187    size += 50;
188
189    // Add body size (rough estimate)
190    // Similar to request, this is a rough estimate
191    size += 2048; // Rough estimate for response body size
192
193    size
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{Request, Response, StatusCode};
    use serde_json::json;

    /// `Shared` with default profiles/overrides, no injector or shaper, and the given flags.
    fn shared_with_flags(overrides_enabled: bool, traffic_shaping_enabled: bool) -> Shared {
        Shared {
            profiles: LatencyProfiles::default(),
            overrides: Overrides::default(),
            failure_injector: None,
            traffic_shaper: None,
            overrides_enabled,
            traffic_shaping_enabled,
        }
    }

    /// `OperationMeta` built from borrowed string parts.
    fn op_meta(id: &str, tags: &[&str], path: &str) -> OperationMeta {
        OperationMeta {
            id: id.to_owned(),
            tags: tags.iter().map(|tag| tag.to_string()).collect(),
            path: path.to_owned(),
        }
    }

    /// Runs `apply_overrides` on a fixed JSON body and asserts the body is left unchanged.
    fn assert_body_unchanged(shared: &Shared, op: Option<&OperationMeta>) {
        let expected = json!({"name": "John"});
        let mut body = expected.clone();
        apply_overrides(shared, op, &mut body);
        assert_eq!(body, expected);
    }

    #[test]
    fn test_operation_meta_creation() {
        let meta = op_meta("getUserById", &["users", "public"], "/users/{id}");

        assert_eq!(meta.id, "getUserById");
        assert_eq!(meta.tags.len(), 2);
        assert_eq!(meta.path, "/users/{id}");
    }

    #[test]
    fn test_shared_creation() {
        let shared = shared_with_flags(false, false);

        assert!(!shared.overrides_enabled);
        assert!(!shared.traffic_shaping_enabled);
        assert!(shared.failure_injector.is_none());
        assert!(shared.traffic_shaper.is_none());
    }

    #[test]
    fn test_shared_with_failure_injector() {
        let shared = Shared {
            failure_injector: Some(FailureInjector::new(None, true)),
            ..shared_with_flags(false, false)
        };

        assert!(shared.failure_injector.is_some());
    }

    #[test]
    fn test_apply_overrides_disabled() {
        // Should not modify body when overrides are disabled
        let op = op_meta("getUser", &[], "/users");
        assert_body_unchanged(&shared_with_flags(false, false), Some(&op));
    }

    #[test]
    fn test_apply_overrides_enabled_no_rules() {
        // Should not modify body when there are no override rules
        let op = op_meta("getUser", &[], "/users");
        assert_body_unchanged(&shared_with_flags(true, false), Some(&op));
    }

    #[test]
    fn test_apply_overrides_with_none_operation() {
        // Should not modify body when operation is None
        assert_body_unchanged(&shared_with_flags(true, false), None);
    }

    #[test]
    fn test_calculate_request_size_basic() {
        let req = Request::builder()
            .uri("/test")
            .header("content-type", "application/json")
            .body(())
            .unwrap();

        let size = calculate_request_size(&req);
        let uri_and_header_name = ("/test".len() + "content-type".len()) as u64;

        // Should be > 0 (includes headers + URI + body estimate)
        assert!(size > 0);
        // Should include at least the URI and header sizes
        assert!(size >= uri_and_header_name);
    }

    #[test]
    fn test_calculate_request_size_with_multiple_headers() {
        let req = Request::builder()
            .uri("/api/users")
            .header("content-type", "application/json")
            .header("authorization", "Bearer token123")
            .header("user-agent", "test-client")
            .body(())
            .unwrap();

        // Should account for all headers: reasonable size with multiple headers
        assert!(calculate_request_size(&req) > 100);
    }

    #[test]
    fn test_calculate_response_size_basic() {
        let res = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .body(axum::body::Body::empty())
            .unwrap();

        let size = calculate_response_size(&res);

        // Should be > 0 (includes status line + headers + body estimate)
        assert!(size > 0);
        // Should include at least the status line estimate (50) and header sizes
        assert!(size >= 50);
    }

    #[test]
    fn test_calculate_response_size_with_multiple_headers() {
        let res = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .header("cache-control", "no-cache")
            .header("x-request-id", "123-456-789")
            .body(axum::body::Body::empty())
            .unwrap();

        // Should account for all headers
        assert!(calculate_response_size(&res) > 100);
    }

    #[test]
    fn test_shared_clone() {
        let original = shared_with_flags(true, true);
        let copy = original.clone();

        assert_eq!(original.overrides_enabled, copy.overrides_enabled);
        assert_eq!(original.traffic_shaping_enabled, copy.traffic_shaping_enabled);
    }

    #[test]
    fn test_operation_meta_clone() {
        let original = op_meta("testOp", &["tag1"], "/test");
        let copy = original.clone();

        assert_eq!(original.id, copy.id);
        assert_eq!(original.tags, copy.tags);
        assert_eq!(original.path, copy.path);
    }
}