mockforge_http/
op_middleware.rs

1//! Middleware/utilities to apply latency/failure and overrides per operation.
2use axum::body::Body;
3use axum::extract::State;
4use axum::http::{Request, StatusCode};
5use axum::middleware::Next;
6use axum::response::{Json, Response};
7use serde_json::Value;
8
9use crate::latency_profiles::LatencyProfiles;
10use mockforge_core::{FailureInjector, Overrides, TrafficShaper};
11
/// Metadata describing the OpenAPI operation that matched the current request.
///
/// [`fault_then_next`] reads it from the request extensions, and [`apply_overrides`]
/// receives it to select override rules.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct OperationMeta {
    /// OpenAPI operation ID.
    pub id: String,
    /// Tags declared on this operation.
    pub tags: Vec<String>,
    /// API path pattern (e.g. `/users/{id}`).
    pub path: String,
}
22
23/// Shared state for operation middleware
24#[derive(Clone)]
25pub struct Shared {
26    /// Latency profiles for request simulation
27    pub profiles: LatencyProfiles,
28    /// Response overrides configuration
29    pub overrides: Overrides,
30    /// Optional failure injector for chaos engineering
31    pub failure_injector: Option<FailureInjector>,
32    /// Optional traffic shaper for bandwidth/loss simulation
33    pub traffic_shaper: Option<TrafficShaper>,
34    /// Whether overrides are enabled
35    pub overrides_enabled: bool,
36    /// Whether traffic shaping is enabled
37    pub traffic_shaping_enabled: bool,
38}
39
40/// Middleware to add shared state to request extensions
41pub async fn add_shared_extension(
42    State(shared): State<Shared>,
43    mut req: Request<Body>,
44    next: Next,
45) -> Response {
46    req.extensions_mut().insert(shared);
47    next.run(req).await
48}
49
50/// Middleware to apply fault injection before processing request
51pub async fn fault_then_next(req: Request<Body>, next: Next) -> Response {
52    let shared = req.extensions().get::<Shared>().unwrap().clone();
53    let op = req.extensions().get::<OperationMeta>().cloned();
54
55    // First, check the new enhanced failure injection system
56    if let Some(failure_injector) = &shared.failure_injector {
57        let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
58        if let Some((status_code, error_message)) = failure_injector.process_request(tags) {
59            let mut res = Response::new(axum::body::Body::from(error_message));
60            *res.status_mut() =
61                StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
62            return res;
63        }
64    }
65
66    // Fallback to legacy latency profiles system for backward compatibility
67    if let Some(op) = &op {
68        if let Some((code, msg)) = shared
69            .profiles
70            .maybe_fault(&op.id, &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>())
71            .await
72        {
73            let mut res = Response::new(axum::body::Body::from(msg));
74            *res.status_mut() =
75                StatusCode::from_u16(code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
76            return res;
77        }
78    }
79
80    // Apply traffic shaping (bandwidth throttling and burst loss) to the request
81    if shared.traffic_shaping_enabled {
82        if let Some(traffic_shaper) = &shared.traffic_shaper {
83            // Calculate request size for bandwidth throttling
84            let request_size = calculate_request_size(&req);
85
86            let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
87
88            // Apply traffic shaping
89            match traffic_shaper.process_transfer(request_size, tags).await {
90                Ok(Some(_timeout)) => {
91                    // Request was "lost" due to burst loss - return timeout error
92                    let mut res = Response::new(axum::body::Body::from(
93                        "Request timeout due to traffic shaping",
94                    ));
95                    *res.status_mut() = StatusCode::REQUEST_TIMEOUT;
96                    return res;
97                }
98                Ok(None) => {
99                    // Transfer allowed, continue
100                }
101                Err(e) => {
102                    // Traffic shaping error - return internal server error
103                    let mut res = Response::new(axum::body::Body::from(format!(
104                        "Traffic shaping error: {}",
105                        e
106                    )));
107                    *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
108                    return res;
109                }
110            }
111        }
112    }
113
114    let (parts, body) = req.into_parts();
115    let req = Request::from_parts(parts, body);
116
117    let response = next.run(req).await;
118
119    // Apply traffic shaping to the response
120    if shared.traffic_shaping_enabled {
121        if let Some(traffic_shaper) = &shared.traffic_shaper {
122            // Calculate response size for bandwidth throttling
123            let response_size = calculate_response_size(&response);
124
125            let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
126
127            // Apply traffic shaping to response
128            match traffic_shaper.process_transfer(response_size, tags).await {
129                Ok(Some(_timeout)) => {
130                    // Response was "lost" due to burst loss - return timeout error
131                    let mut res = Response::new(axum::body::Body::from(
132                        "Response timeout due to traffic shaping",
133                    ));
134                    *res.status_mut() = StatusCode::GATEWAY_TIMEOUT;
135                    return res;
136                }
137                Ok(None) => {
138                    // Transfer allowed, continue
139                }
140                Err(e) => {
141                    // Traffic shaping error - return internal server error
142                    let mut res = Response::new(axum::body::Body::from(format!(
143                        "Traffic shaping error: {}",
144                        e
145                    )));
146                    *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
147                    return res;
148                }
149            }
150        }
151    }
152
153    response
154}
155
156/// Apply response overrides to a JSON body based on operation metadata
157///
158/// # Arguments
159/// * `shared` - Shared middleware state containing override configuration
160/// * `op` - Optional operation metadata for override matching
161/// * `body` - JSON response body to modify in-place
162pub fn apply_overrides(shared: &Shared, op: Option<&OperationMeta>, body: &mut Value) {
163    if shared.overrides_enabled {
164        if let Some(op) = op {
165            shared.overrides.apply(
166                &op.id,
167                &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>(),
168                &op.path,
169                body,
170            );
171        }
172    }
173}
174
175/// Calculate the approximate size of an HTTP request for bandwidth throttling
176fn calculate_request_size<B>(req: &Request<B>) -> u64 {
177    let mut size = 0u64;
178
179    // Add header sizes (rough estimate)
180    for (name, value) in req.headers() {
181        size += name.as_str().len() as u64;
182        size += value.as_bytes().len() as u64;
183    }
184
185    // Add URI size
186    size += req.uri().to_string().len() as u64;
187
188    // Add body size (if available)
189    // Note: This is a rough estimate since we can't easily get the body size here
190    // without consuming the body. In practice, this would need to be implemented
191    // differently to get accurate body sizes.
192    size += 1024; // Rough estimate for body size
193
194    size
195}
196
197/// Calculate the approximate size of an HTTP response for bandwidth throttling
198fn calculate_response_size(res: &Response) -> u64 {
199    let mut size = 0u64;
200
201    // Add header sizes
202    for (name, value) in res.headers() {
203        size += name.as_str().len() as u64;
204        size += value.as_bytes().len() as u64;
205    }
206
207    // Add status line size (rough estimate)
208    size += 50;
209
210    // Add body size (rough estimate)
211    // Similar to request, this is a rough estimate
212    size += 2048; // Rough estimate for response body size
213
214    size
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{Request, Response, StatusCode};
    use serde_json::json;

    /// `Shared` with default profiles/overrides, no injector or shaper, and the given flags.
    fn shared_with(overrides_enabled: bool, traffic_shaping_enabled: bool) -> Shared {
        Shared {
            profiles: LatencyProfiles::default(),
            overrides: Overrides::default(),
            failure_injector: None,
            traffic_shaper: None,
            overrides_enabled,
            traffic_shaping_enabled,
        }
    }

    /// `OperationMeta` built from borrowed strings.
    fn op_meta(id: &str, tags: &[&str], path: &str) -> OperationMeta {
        OperationMeta {
            id: id.to_owned(),
            tags: tags.iter().map(|tag| (*tag).to_owned()).collect(),
            path: path.to_owned(),
        }
    }

    /// Runs `apply_overrides` on a small JSON body and asserts the body came back unchanged.
    fn assert_overrides_leave_body_untouched(shared: &Shared, op: Option<&OperationMeta>) {
        let mut body = json!({"name": "John"});
        let original = body.clone();

        apply_overrides(shared, op, &mut body);

        assert_eq!(body, original);
    }

    #[test]
    fn test_operation_meta_creation() {
        let meta = op_meta("getUserById", &["users", "public"], "/users/{id}");

        assert_eq!(meta.id, "getUserById");
        assert_eq!(meta.tags.len(), 2);
        assert_eq!(meta.path, "/users/{id}");
    }

    #[test]
    fn test_shared_creation() {
        let shared = shared_with(false, false);

        assert!(!shared.overrides_enabled);
        assert!(!shared.traffic_shaping_enabled);
        assert!(shared.failure_injector.is_none());
        assert!(shared.traffic_shaper.is_none());
    }

    #[test]
    fn test_shared_with_failure_injector() {
        let shared = Shared {
            failure_injector: Some(FailureInjector::new(None, true)),
            ..shared_with(false, false)
        };

        assert!(shared.failure_injector.is_some());
    }

    #[test]
    fn test_apply_overrides_disabled() {
        // Overrides are switched off, so the body must not change.
        let op = op_meta("getUser", &[], "/users");
        assert_overrides_leave_body_untouched(&shared_with(false, false), Some(&op));
    }

    #[test]
    fn test_apply_overrides_enabled_no_rules() {
        // Enabled, but a default `Overrides` carries no rules to apply.
        let op = op_meta("getUser", &[], "/users");
        assert_overrides_leave_body_untouched(&shared_with(true, false), Some(&op));
    }

    #[test]
    fn test_apply_overrides_with_none_operation() {
        // Without operation metadata there is nothing to match overrides against.
        assert_overrides_leave_body_untouched(&shared_with(true, false), None);
    }

    #[test]
    fn test_calculate_request_size_basic() {
        let req = Request::builder()
            .uri("/test")
            .header("content-type", "application/json")
            .body(())
            .unwrap();

        let size = calculate_request_size(&req);

        // Headers + URI + body estimate: strictly positive, and at least URI + header name.
        assert!(size > 0);
        let lower_bound = ("/test".len() + "content-type".len()) as u64;
        assert!(size >= lower_bound);
    }

    #[test]
    fn test_calculate_request_size_with_multiple_headers() {
        let req = Request::builder()
            .uri("/api/users")
            .header("content-type", "application/json")
            .header("authorization", "Bearer token123")
            .header("user-agent", "test-client")
            .body(())
            .unwrap();

        let size = calculate_request_size(&req);

        // Every header contributes, so the total comfortably exceeds 100 bytes.
        assert!(size > 100);
    }

    #[test]
    fn test_calculate_response_size_basic() {
        let res = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .body(axum::body::Body::empty())
            .unwrap();

        let size = calculate_response_size(&res);

        // Status-line allowance (50) + headers + body estimate.
        assert!(size > 0);
        assert!(size >= 50);
    }

    #[test]
    fn test_calculate_response_size_with_multiple_headers() {
        let res = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .header("cache-control", "no-cache")
            .header("x-request-id", "123-456-789")
            .body(axum::body::Body::empty())
            .unwrap();

        let size = calculate_response_size(&res);

        // Every header contributes on top of the status-line allowance.
        assert!(size > 100);
    }

    #[test]
    fn test_shared_clone() {
        let shared = shared_with(true, true);
        let cloned = shared.clone();

        assert_eq!(shared.overrides_enabled, cloned.overrides_enabled);
        assert_eq!(shared.traffic_shaping_enabled, cloned.traffic_shaping_enabled);
    }

    #[test]
    fn test_operation_meta_clone() {
        let meta = op_meta("testOp", &["tag1"], "/test");
        let cloned = meta.clone();

        assert_eq!(meta.id, cloned.id);
        assert_eq!(meta.tags, cloned.tags);
        assert_eq!(meta.path, cloned.path);
    }
}