mockforge_http/
op_middleware.rs

//! Middleware and helpers that apply per-operation failure injection, latency-profile faults, traffic shaping, and JSON response overrides.
2use axum::body::Body;
3use axum::extract::State;
4use axum::http::{Request, StatusCode};
5use axum::middleware::Next;
6use axum::response::Response;
7use serde_json::Value;
8
9use crate::latency_profiles::LatencyProfiles;
10use mockforge_core::{FailureInjector, Overrides, TrafficShaper};
11
/// Metadata for the OpenAPI operation the current request was routed to.
///
/// [`fault_then_next`] reads it from the request extensions (when present) to
/// select faults and traffic-shaping rules by tag, and [`apply_overrides`] uses
/// it to match override rules by operation ID, tags and path.
#[derive(Debug, Clone)]
pub struct OperationMeta {
    /// OpenAPI operation ID
    pub id: String,
    /// Tags associated with this operation
    pub tags: Vec<String>,
    /// API path pattern (e.g. `/users/{id}`)
    pub path: String,
}
22
23/// Shared state for operation middleware
24#[derive(Clone)]
25pub struct Shared {
26    /// Latency profiles for request simulation
27    pub profiles: LatencyProfiles,
28    /// Response overrides configuration
29    pub overrides: Overrides,
30    /// Optional failure injector for chaos engineering
31    pub failure_injector: Option<FailureInjector>,
32    /// Optional traffic shaper for bandwidth/loss simulation
33    pub traffic_shaper: Option<TrafficShaper>,
34    /// Whether overrides are enabled
35    pub overrides_enabled: bool,
36    /// Whether traffic shaping is enabled
37    pub traffic_shaping_enabled: bool,
38}
39
40/// Middleware to add shared state to request extensions
41pub async fn add_shared_extension(
42    State(shared): State<Shared>,
43    mut req: Request<Body>,
44    next: Next,
45) -> Response {
46    req.extensions_mut().insert(shared);
47    next.run(req).await
48}
49
50/// Middleware to apply fault injection before processing request
51pub async fn fault_then_next(req: Request<Body>, next: Next) -> Response {
52    let shared = match req.extensions().get::<Shared>() {
53        Some(s) => s.clone(),
54        None => {
55            tracing::error!("Shared extension not found in request - ensure add_shared_extension middleware is configured");
56            let mut res =
57                Response::new(Body::from("Internal server error: middleware misconfiguration"));
58            *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
59            return res;
60        }
61    };
62    let op = req.extensions().get::<OperationMeta>().cloned();
63
64    // First, check the new enhanced failure injection system
65    if let Some(failure_injector) = &shared.failure_injector {
66        let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
67        if let Some((status_code, error_message)) = failure_injector.process_request(tags) {
68            let mut res = Response::new(axum::body::Body::from(error_message));
69            *res.status_mut() =
70                StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
71            return res;
72        }
73    }
74
75    // Fallback to legacy latency profiles system for backward compatibility
76    if let Some(op) = &op {
77        if let Some((code, msg)) = shared
78            .profiles
79            .maybe_fault(&op.id, &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>())
80            .await
81        {
82            let mut res = Response::new(axum::body::Body::from(msg));
83            *res.status_mut() =
84                StatusCode::from_u16(code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
85            return res;
86        }
87    }
88
89    // Apply traffic shaping (bandwidth throttling and burst loss) to the request
90    if shared.traffic_shaping_enabled {
91        if let Some(traffic_shaper) = &shared.traffic_shaper {
92            // Calculate request size for bandwidth throttling
93            let request_size = calculate_request_size(&req);
94
95            let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
96
97            // Apply traffic shaping
98            match traffic_shaper.process_transfer(request_size, tags).await {
99                Ok(Some(_timeout)) => {
100                    // Request was "lost" due to burst loss - return timeout error
101                    let mut res = Response::new(axum::body::Body::from(
102                        "Request timeout due to traffic shaping",
103                    ));
104                    *res.status_mut() = StatusCode::REQUEST_TIMEOUT;
105                    return res;
106                }
107                Ok(None) => {
108                    // Transfer allowed, continue
109                }
110                Err(e) => {
111                    // Traffic shaping error - return internal server error
112                    let mut res = Response::new(axum::body::Body::from(format!(
113                        "Traffic shaping error: {}",
114                        e
115                    )));
116                    *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
117                    return res;
118                }
119            }
120        }
121    }
122
123    let (parts, body) = req.into_parts();
124    let req = Request::from_parts(parts, body);
125
126    let response = next.run(req).await;
127
128    // Apply traffic shaping to the response
129    if shared.traffic_shaping_enabled {
130        if let Some(traffic_shaper) = &shared.traffic_shaper {
131            // Calculate response size for bandwidth throttling
132            let response_size = calculate_response_size(&response);
133
134            let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
135
136            // Apply traffic shaping to response
137            match traffic_shaper.process_transfer(response_size, tags).await {
138                Ok(Some(_timeout)) => {
139                    // Response was "lost" due to burst loss - return timeout error
140                    let mut res = Response::new(axum::body::Body::from(
141                        "Response timeout due to traffic shaping",
142                    ));
143                    *res.status_mut() = StatusCode::GATEWAY_TIMEOUT;
144                    return res;
145                }
146                Ok(None) => {
147                    // Transfer allowed, continue
148                }
149                Err(e) => {
150                    // Traffic shaping error - return internal server error
151                    let mut res = Response::new(axum::body::Body::from(format!(
152                        "Traffic shaping error: {}",
153                        e
154                    )));
155                    *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
156                    return res;
157                }
158            }
159        }
160    }
161
162    response
163}
164
165/// Apply response overrides to a JSON body based on operation metadata
166///
167/// # Arguments
168/// * `shared` - Shared middleware state containing override configuration
169/// * `op` - Optional operation metadata for override matching
170/// * `body` - JSON response body to modify in-place
171pub fn apply_overrides(shared: &Shared, op: Option<&OperationMeta>, body: &mut Value) {
172    if shared.overrides_enabled {
173        if let Some(op) = op {
174            shared.overrides.apply(
175                &op.id,
176                &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>(),
177                &op.path,
178                body,
179            );
180        }
181    }
182}
183
184/// Calculate the approximate size of an HTTP request for bandwidth throttling
185fn calculate_request_size<B>(req: &Request<B>) -> u64 {
186    let mut size = 0u64;
187
188    // Add header sizes (rough estimate)
189    for (name, value) in req.headers() {
190        size += name.as_str().len() as u64;
191        size += value.as_bytes().len() as u64;
192    }
193
194    // Add URI size
195    size += req.uri().to_string().len() as u64;
196
197    // Add body size (if available)
198    // Note: This is a rough estimate since we can't easily get the body size here
199    // without consuming the body. In practice, this would need to be implemented
200    // differently to get accurate body sizes.
201    size += 1024; // Rough estimate for body size
202
203    size
204}
205
206/// Calculate the approximate size of an HTTP response for bandwidth throttling
207fn calculate_response_size(res: &Response) -> u64 {
208    let mut size = 0u64;
209
210    // Add header sizes
211    for (name, value) in res.headers() {
212        size += name.as_str().len() as u64;
213        size += value.as_bytes().len() as u64;
214    }
215
216    // Add status line size (rough estimate)
217    size += 50;
218
219    // Add body size (rough estimate)
220    // Similar to request, this is a rough estimate
221    size += 2048; // Rough estimate for response body size
222
223    size
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{Request, Response, StatusCode};
    use serde_json::json;

    /// Build a `Shared` with default profiles/overrides and no injector or shaper.
    fn shared_with(overrides_enabled: bool, traffic_shaping_enabled: bool) -> Shared {
        Shared {
            profiles: LatencyProfiles::default(),
            overrides: Overrides::default(),
            failure_injector: None,
            traffic_shaper: None,
            overrides_enabled,
            traffic_shaping_enabled,
        }
    }

    /// Build an `OperationMeta` from borrowed string parts.
    fn op_meta(id: &str, tags: &[&str], path: &str) -> OperationMeta {
        OperationMeta {
            id: id.to_owned(),
            tags: tags.iter().map(|tag| tag.to_string()).collect(),
            path: path.to_owned(),
        }
    }

    /// Run `apply_overrides` on a fixed JSON body and assert it comes back unchanged.
    fn assert_body_untouched(shared: &Shared, op: Option<&OperationMeta>) {
        let mut body = json!({"name": "John"});
        let expected = body.clone();

        apply_overrides(shared, op, &mut body);

        assert_eq!(body, expected);
    }

    #[test]
    fn test_operation_meta_creation() {
        let meta = op_meta("getUserById", &["users", "public"], "/users/{id}");

        assert_eq!(meta.id, "getUserById");
        assert_eq!(meta.tags.len(), 2);
        assert_eq!(meta.path, "/users/{id}");
    }

    #[test]
    fn test_shared_creation() {
        let shared = shared_with(false, false);

        assert!(!shared.overrides_enabled);
        assert!(!shared.traffic_shaping_enabled);
        assert!(shared.failure_injector.is_none());
        assert!(shared.traffic_shaper.is_none());
    }

    #[test]
    fn test_shared_with_failure_injector() {
        let shared = Shared {
            failure_injector: Some(FailureInjector::new(None, true)),
            ..shared_with(false, false)
        };

        assert!(shared.failure_injector.is_some());
    }

    #[test]
    fn test_apply_overrides_disabled() {
        // Overrides switched off: the body must pass through untouched.
        let op = op_meta("getUser", &[], "/users");
        assert_body_untouched(&shared_with(false, false), Some(&op));
    }

    #[test]
    fn test_apply_overrides_enabled_no_rules() {
        // Overrides switched on but no rules configured: nothing to change.
        let op = op_meta("getUser", &[], "/users");
        assert_body_untouched(&shared_with(true, false), Some(&op));
    }

    #[test]
    fn test_apply_overrides_with_none_operation() {
        // No operation metadata: there is nothing to match rules against.
        assert_body_untouched(&shared_with(true, false), None);
    }

    #[test]
    fn test_calculate_request_size_basic() {
        let req = Request::builder()
            .uri("/test")
            .header("content-type", "application/json")
            .body(())
            .unwrap();

        let size = calculate_request_size(&req);
        let minimum = ("/test".len() + "content-type".len()) as u64;

        // Headers + URI + body estimate is always positive...
        assert!(size > 0);
        // ...and covers at least the URI and the header name.
        assert!(size >= minimum);
    }

    #[test]
    fn test_calculate_request_size_with_multiple_headers() {
        let req = Request::builder()
            .uri("/api/users")
            .header("content-type", "application/json")
            .header("authorization", "Bearer token123")
            .header("user-agent", "test-client")
            .body(())
            .unwrap();

        // A request with several headers should report a reasonable size.
        assert!(calculate_request_size(&req) > 100);
    }

    #[test]
    fn test_calculate_response_size_basic() {
        let res = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .body(axum::body::Body::empty())
            .unwrap();

        let size = calculate_response_size(&res);

        // Status line + headers + body estimate is always positive...
        assert!(size > 0);
        // ...and includes at least the 50-byte status line allowance.
        assert!(size >= 50);
    }

    #[test]
    fn test_calculate_response_size_with_multiple_headers() {
        let res = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .header("cache-control", "no-cache")
            .header("x-request-id", "123-456-789")
            .body(axum::body::Body::empty())
            .unwrap();

        // Every header contributes to the total.
        assert!(calculate_response_size(&res) > 100);
    }

    #[test]
    fn test_shared_clone() {
        let original = shared_with(true, true);
        let copy = original.clone();

        assert_eq!(original.overrides_enabled, copy.overrides_enabled);
        assert_eq!(original.traffic_shaping_enabled, copy.traffic_shaping_enabled);
    }

    #[test]
    fn test_operation_meta_clone() {
        let original = op_meta("testOp", &["tag1"], "/test");
        let copy = original.clone();

        assert_eq!(original.id, copy.id);
        assert_eq!(original.tags, copy.tags);
        assert_eq!(original.path, copy.path);
    }
}