mockforge_http/
op_middleware.rs

1//! Middleware/utilities to apply latency/failure and overrides per operation.
2use axum::body::Body;
3use axum::http::{Request, StatusCode};
4use axum::{extract::State, middleware::Next, response::Response};
5use serde_json::Value;
6
7use crate::latency_profiles::LatencyProfiles;
8use mockforge_core::{FailureInjector, Overrides, TrafficShaper};
9
/// Metadata for the OpenAPI operation handling the current request.
///
/// Read from the request extensions by [`fault_then_next`] and passed to
/// [`apply_overrides`] to select per-operation fault and override rules.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct OperationMeta {
    /// OpenAPI operation ID
    pub id: String,
    /// Tags associated with this operation
    pub tags: Vec<String>,
    /// API path pattern (e.g. `/users/{id}`)
    pub path: String,
}
20
21/// Shared state for operation middleware
22#[derive(Clone)]
23pub struct Shared {
24    /// Latency profiles for request simulation
25    pub profiles: LatencyProfiles,
26    /// Response overrides configuration
27    pub overrides: Overrides,
28    /// Optional failure injector for chaos engineering
29    pub failure_injector: Option<FailureInjector>,
30    /// Optional traffic shaper for bandwidth/loss simulation
31    pub traffic_shaper: Option<TrafficShaper>,
32    /// Whether overrides are enabled
33    pub overrides_enabled: bool,
34    /// Whether traffic shaping is enabled
35    pub traffic_shaping_enabled: bool,
36}
37
38/// Middleware to add shared state to request extensions
39pub async fn add_shared_extension(
40    State(shared): State<Shared>,
41    mut req: Request<Body>,
42    next: Next,
43) -> Response {
44    req.extensions_mut().insert(shared);
45    next.run(req).await
46}
47
48/// Middleware to apply fault injection before processing request
49pub async fn fault_then_next(req: Request<Body>, next: Next) -> Response {
50    let shared = req.extensions().get::<Shared>().unwrap().clone();
51    let op = req.extensions().get::<OperationMeta>().cloned();
52
53    // First, check the new enhanced failure injection system
54    if let Some(failure_injector) = &shared.failure_injector {
55        let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
56        if let Some((status_code, error_message)) = failure_injector.process_request(tags) {
57            let mut res = Response::new(axum::body::Body::from(error_message));
58            *res.status_mut() =
59                StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
60            return res;
61        }
62    }
63
64    // Fallback to legacy latency profiles system for backward compatibility
65    if let Some(op) = &op {
66        if let Some((code, msg)) = shared
67            .profiles
68            .maybe_fault(&op.id, &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>())
69            .await
70        {
71            let mut res = Response::new(axum::body::Body::from(msg));
72            *res.status_mut() =
73                StatusCode::from_u16(code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
74            return res;
75        }
76    }
77
78    // Apply traffic shaping (bandwidth throttling and burst loss) to the request
79    if shared.traffic_shaping_enabled {
80        if let Some(traffic_shaper) = &shared.traffic_shaper {
81            // Calculate request size for bandwidth throttling
82            let request_size = calculate_request_size(&req);
83
84            let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
85
86            // Apply traffic shaping
87            match traffic_shaper.process_transfer(request_size, tags).await {
88                Ok(Some(_timeout)) => {
89                    // Request was "lost" due to burst loss - return timeout error
90                    let mut res = Response::new(axum::body::Body::from(
91                        "Request timeout due to traffic shaping",
92                    ));
93                    *res.status_mut() = StatusCode::REQUEST_TIMEOUT;
94                    return res;
95                }
96                Ok(None) => {
97                    // Transfer allowed, continue
98                }
99                Err(e) => {
100                    // Traffic shaping error - return internal server error
101                    let mut res = Response::new(axum::body::Body::from(format!(
102                        "Traffic shaping error: {}",
103                        e
104                    )));
105                    *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
106                    return res;
107                }
108            }
109        }
110    }
111
112    let (parts, body) = req.into_parts();
113    let req = Request::from_parts(parts, body);
114
115    let response = next.run(req).await;
116
117    // Apply traffic shaping to the response
118    if shared.traffic_shaping_enabled {
119        if let Some(traffic_shaper) = &shared.traffic_shaper {
120            // Calculate response size for bandwidth throttling
121            let response_size = calculate_response_size(&response);
122
123            let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
124
125            // Apply traffic shaping to response
126            match traffic_shaper.process_transfer(response_size, tags).await {
127                Ok(Some(_timeout)) => {
128                    // Response was "lost" due to burst loss - return timeout error
129                    let mut res = Response::new(axum::body::Body::from(
130                        "Response timeout due to traffic shaping",
131                    ));
132                    *res.status_mut() = StatusCode::GATEWAY_TIMEOUT;
133                    return res;
134                }
135                Ok(None) => {
136                    // Transfer allowed, continue
137                }
138                Err(e) => {
139                    // Traffic shaping error - return internal server error
140                    let mut res = Response::new(axum::body::Body::from(format!(
141                        "Traffic shaping error: {}",
142                        e
143                    )));
144                    *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
145                    return res;
146                }
147            }
148        }
149    }
150
151    response
152}
153
154/// Apply response overrides to a JSON body based on operation metadata
155///
156/// # Arguments
157/// * `shared` - Shared middleware state containing override configuration
158/// * `op` - Optional operation metadata for override matching
159/// * `body` - JSON response body to modify in-place
160pub fn apply_overrides(shared: &Shared, op: Option<&OperationMeta>, body: &mut Value) {
161    if shared.overrides_enabled {
162        if let Some(op) = op {
163            shared.overrides.apply(
164                &op.id,
165                &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>(),
166                &op.path,
167                body,
168            );
169        }
170    }
171}
172
173/// Calculate the approximate size of an HTTP request for bandwidth throttling
174fn calculate_request_size<B>(req: &Request<B>) -> u64 {
175    let mut size = 0u64;
176
177    // Add header sizes (rough estimate)
178    for (name, value) in req.headers() {
179        size += name.as_str().len() as u64;
180        size += value.as_bytes().len() as u64;
181    }
182
183    // Add URI size
184    size += req.uri().to_string().len() as u64;
185
186    // Add body size (if available)
187    // Note: This is a rough estimate since we can't easily get the body size here
188    // without consuming the body. In practice, this would need to be implemented
189    // differently to get accurate body sizes.
190    size += 1024; // Rough estimate for body size
191
192    size
193}
194
195/// Calculate the approximate size of an HTTP response for bandwidth throttling
196fn calculate_response_size(res: &Response) -> u64 {
197    let mut size = 0u64;
198
199    // Add header sizes
200    for (name, value) in res.headers() {
201        size += name.as_str().len() as u64;
202        size += value.as_bytes().len() as u64;
203    }
204
205    // Add status line size (rough estimate)
206    size += 50;
207
208    // Add body size (rough estimate)
209    // Similar to request, this is a rough estimate
210    size += 2048; // Rough estimate for response body size
211
212    size
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{Request, Response, StatusCode};
    use serde_json::json;

    /// Build a `Shared` with default profiles/overrides, no injector or shaper,
    /// and the given feature switches.
    fn shared_with_flags(overrides_enabled: bool, traffic_shaping_enabled: bool) -> Shared {
        Shared {
            profiles: LatencyProfiles::default(),
            overrides: Overrides::default(),
            failure_injector: None,
            traffic_shaper: None,
            overrides_enabled,
            traffic_shaping_enabled,
        }
    }

    /// Build an `OperationMeta` from borrowed string data.
    fn op_meta(id: &str, tags: &[&str], path: &str) -> OperationMeta {
        OperationMeta {
            id: id.to_owned(),
            tags: tags.iter().map(|tag| (*tag).to_owned()).collect(),
            path: path.to_owned(),
        }
    }

    /// Run `apply_overrides` on a small JSON body and return `(before, after)`.
    fn run_overrides(shared: &Shared, op: Option<&OperationMeta>) -> (Value, Value) {
        let before = json!({"name": "John"});
        let mut after = before.clone();
        apply_overrides(shared, op, &mut after);
        (before, after)
    }

    #[test]
    fn test_operation_meta_creation() {
        let meta = op_meta("getUserById", &["users", "public"], "/users/{id}");

        assert_eq!(meta.id, "getUserById");
        assert_eq!(meta.tags.len(), 2);
        assert_eq!(meta.path, "/users/{id}");
    }

    #[test]
    fn test_shared_creation() {
        let shared = shared_with_flags(false, false);

        assert!(!shared.overrides_enabled);
        assert!(!shared.traffic_shaping_enabled);
        assert!(shared.failure_injector.is_none());
        assert!(shared.traffic_shaper.is_none());
    }

    #[test]
    fn test_shared_with_failure_injector() {
        let shared = Shared {
            failure_injector: Some(FailureInjector::new(None, true)),
            ..shared_with_flags(false, false)
        };

        assert!(shared.failure_injector.is_some());
    }

    #[test]
    fn test_apply_overrides_disabled() {
        let shared = shared_with_flags(false, false);
        let op = op_meta("getUser", &[], "/users");

        let (before, after) = run_overrides(&shared, Some(&op));

        // Overrides are switched off, so the body must be untouched.
        assert_eq!(after, before);
    }

    #[test]
    fn test_apply_overrides_enabled_no_rules() {
        let shared = shared_with_flags(true, false);
        let op = op_meta("getUser", &[], "/users");

        let (before, after) = run_overrides(&shared, Some(&op));

        // No override rules are configured, so nothing can change.
        assert_eq!(after, before);
    }

    #[test]
    fn test_apply_overrides_with_none_operation() {
        let shared = shared_with_flags(true, false);

        let (before, after) = run_overrides(&shared, None);

        // Without operation metadata there is nothing to match against.
        assert_eq!(after, before);
    }

    #[test]
    fn test_calculate_request_size_basic() {
        let request = Request::builder()
            .uri("/test")
            .header("content-type", "application/json")
            .body(())
            .unwrap();

        let estimated = calculate_request_size(&request);
        let lower_bound = ("/test".len() + "content-type".len()) as u64;

        // Headers, URI and the body estimate are all counted.
        assert!(estimated > 0);
        assert!(estimated >= lower_bound);
    }

    #[test]
    fn test_calculate_request_size_with_multiple_headers() {
        let request = Request::builder()
            .uri("/api/users")
            .header("content-type", "application/json")
            .header("authorization", "Bearer token123")
            .header("user-agent", "test-client")
            .body(())
            .unwrap();

        let estimated = calculate_request_size(&request);

        // Several headers should push the estimate past a modest floor.
        assert!(estimated > 100);
    }

    #[test]
    fn test_calculate_response_size_basic() {
        let response = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .body(axum::body::Body::empty())
            .unwrap();

        let estimated = calculate_response_size(&response);

        // Status-line estimate (50) plus headers plus body estimate.
        assert!(estimated > 0);
        assert!(estimated >= 50);
    }

    #[test]
    fn test_calculate_response_size_with_multiple_headers() {
        let response = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .header("cache-control", "no-cache")
            .header("x-request-id", "123-456-789")
            .body(axum::body::Body::empty())
            .unwrap();

        let estimated = calculate_response_size(&response);

        assert!(estimated > 100);
    }

    #[test]
    fn test_shared_clone() {
        let original = shared_with_flags(true, true);
        let copy = original.clone();

        assert_eq!(original.overrides_enabled, copy.overrides_enabled);
        assert_eq!(original.traffic_shaping_enabled, copy.traffic_shaping_enabled);
    }

    #[test]
    fn test_operation_meta_clone() {
        let original = op_meta("testOp", &["tag1"], "/test");
        let copy = original.clone();

        assert_eq!(original.id, copy.id);
        assert_eq!(original.tags, copy.tags);
        assert_eq!(original.path, copy.path);
    }
}