Skip to main content

mockforge_http/
op_middleware.rs

1//! Middleware/utilities to apply latency/failure and overrides per operation.
2use axum::body::Body;
3use axum::extract::State;
4use axum::http::{Request, StatusCode};
5use axum::middleware::Next;
6use axum::response::Response;
7use serde_json::Value;
8
9use crate::latency_profiles::LatencyProfiles;
10use mockforge_core::{FailureInjector, Overrides, TrafficShaper};
11
/// Metadata describing the OpenAPI operation matched for the current request.
///
/// The middleware in this module reads this value from the request
/// extensions; the operation ID and tags are what the fault-injection and
/// override lookups are keyed on.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OperationMeta {
    /// OpenAPI operation ID
    pub id: String,
    /// Tags associated with this operation
    pub tags: Vec<String>,
    /// API path pattern
    pub path: String,
}
22
23/// Shared state for operation middleware
24#[derive(Clone)]
25pub struct Shared {
26    /// Latency profiles for request simulation
27    pub profiles: LatencyProfiles,
28    /// Response overrides configuration
29    pub overrides: Overrides,
30    /// Optional failure injector for chaos engineering
31    pub failure_injector: Option<FailureInjector>,
32    /// Optional traffic shaper for bandwidth/loss simulation
33    pub traffic_shaper: Option<TrafficShaper>,
34    /// Whether overrides are enabled
35    pub overrides_enabled: bool,
36    /// Whether traffic shaping is enabled
37    pub traffic_shaping_enabled: bool,
38}
39
40/// Middleware to add shared state to request extensions
41pub async fn add_shared_extension(
42    State(shared): State<Shared>,
43    mut req: Request<Body>,
44    next: Next,
45) -> Response {
46    req.extensions_mut().insert(shared);
47    next.run(req).await
48}
49
50/// Middleware to apply fault injection before processing request
51pub async fn fault_then_next(req: Request<Body>, next: Next) -> Response {
52    let shared = match req.extensions().get::<Shared>() {
53        Some(s) => s.clone(),
54        None => {
55            tracing::error!("Shared extension not found in request - ensure add_shared_extension middleware is configured");
56            let mut res =
57                Response::new(Body::from("Internal server error: middleware misconfiguration"));
58            *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
59            return res;
60        }
61    };
62    let op = req.extensions().get::<OperationMeta>().cloned();
63
64    // First, check the new enhanced failure injection system
65    if let Some(failure_injector) = &shared.failure_injector {
66        let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
67        if let Some((status_code, error_message)) = failure_injector.process_request(tags) {
68            let mut res = Response::new(Body::from(error_message));
69            *res.status_mut() =
70                StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
71            return res;
72        }
73    }
74
75    // Fallback to legacy latency profiles system for backward compatibility
76    if let Some(op) = &op {
77        if let Some((code, msg)) = shared
78            .profiles
79            .maybe_fault(&op.id, &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>())
80            .await
81        {
82            let mut res = Response::new(Body::from(msg));
83            *res.status_mut() =
84                StatusCode::from_u16(code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
85            return res;
86        }
87    }
88
89    // Apply traffic shaping (bandwidth throttling and burst loss) to the request
90    if shared.traffic_shaping_enabled {
91        if let Some(traffic_shaper) = &shared.traffic_shaper {
92            // Calculate request size for bandwidth throttling
93            let request_size = calculate_request_size(&req);
94
95            let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
96
97            // Apply traffic shaping
98            match traffic_shaper.process_transfer(request_size, tags).await {
99                Ok(Some(_timeout)) => {
100                    // Request was "lost" due to burst loss - return timeout error
101                    let mut res =
102                        Response::new(Body::from("Request timeout due to traffic shaping"));
103                    *res.status_mut() = StatusCode::REQUEST_TIMEOUT;
104                    return res;
105                }
106                Ok(None) => {
107                    // Transfer allowed, continue
108                }
109                Err(e) => {
110                    // Traffic shaping error - return internal server error
111                    let mut res =
112                        Response::new(Body::from(format!("Traffic shaping error: {}", e)));
113                    *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
114                    return res;
115                }
116            }
117        }
118    }
119
120    let (parts, body) = req.into_parts();
121    let req = Request::from_parts(parts, body);
122
123    let response = next.run(req).await;
124
125    // Apply traffic shaping to the response
126    if shared.traffic_shaping_enabled {
127        if let Some(traffic_shaper) = &shared.traffic_shaper {
128            // Calculate response size for bandwidth throttling
129            let response_size = calculate_response_size(&response);
130
131            let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
132
133            // Apply traffic shaping to response
134            match traffic_shaper.process_transfer(response_size, tags).await {
135                Ok(Some(_timeout)) => {
136                    // Response was "lost" due to burst loss - return timeout error
137                    let mut res =
138                        Response::new(Body::from("Response timeout due to traffic shaping"));
139                    *res.status_mut() = StatusCode::GATEWAY_TIMEOUT;
140                    return res;
141                }
142                Ok(None) => {
143                    // Transfer allowed, continue
144                }
145                Err(e) => {
146                    // Traffic shaping error - return internal server error
147                    let mut res =
148                        Response::new(Body::from(format!("Traffic shaping error: {}", e)));
149                    *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
150                    return res;
151                }
152            }
153        }
154    }
155
156    response
157}
158
159/// Apply response overrides to a JSON body based on operation metadata
160///
161/// # Arguments
162/// * `shared` - Shared middleware state containing override configuration
163/// * `op` - Optional operation metadata for override matching
164/// * `body` - JSON response body to modify in-place
165pub fn apply_overrides(shared: &Shared, op: Option<&OperationMeta>, body: &mut Value) {
166    if shared.overrides_enabled {
167        if let Some(op) = op {
168            shared.overrides.apply(
169                &op.id,
170                &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>(),
171                &op.path,
172                body,
173            );
174        }
175    }
176}
177
178/// Calculate the approximate size of an HTTP request for bandwidth throttling
179fn calculate_request_size<B>(req: &Request<B>) -> u64 {
180    let mut size = 0u64;
181
182    // Add header sizes
183    for (name, value) in req.headers() {
184        size += name.as_str().len() as u64;
185        size += value.as_bytes().len() as u64;
186    }
187
188    // Add URI size
189    size += req.uri().to_string().len() as u64;
190
191    // Use Content-Length header for body size when available
192    if let Some(content_length) = req.headers().get(http::header::CONTENT_LENGTH) {
193        if let Ok(len_str) = content_length.to_str() {
194            if let Ok(len) = len_str.parse::<u64>() {
195                size += len;
196                return size;
197            }
198        }
199    }
200
201    // Fallback: estimate body size from method (GET/HEAD/DELETE typically have no body)
202    let method = req.method();
203    if method == http::Method::POST || method == http::Method::PUT || method == http::Method::PATCH
204    {
205        size += 256; // Conservative estimate for requests without Content-Length
206    }
207
208    size
209}
210
211/// Calculate the approximate size of an HTTP response for bandwidth throttling
212fn calculate_response_size(res: &Response) -> u64 {
213    let mut size = 0u64;
214
215    // Add header sizes
216    for (name, value) in res.headers() {
217        size += name.as_str().len() as u64;
218        size += value.as_bytes().len() as u64;
219    }
220
221    // Add status line size
222    size += 15; // "HTTP/1.1 200 OK\r\n"
223
224    // Use Content-Length header for body size when available
225    if let Some(content_length) = res.headers().get(http::header::CONTENT_LENGTH) {
226        if let Ok(len_str) = content_length.to_str() {
227            if let Ok(len) = len_str.parse::<u64>() {
228                size += len;
229                return size;
230            }
231        }
232    }
233
234    // Fallback: estimate from status code (204/304 have no body)
235    match res.status().as_u16() {
236        204 | 304 => {}   // No body
237        _ => size += 256, // Conservative estimate for responses without Content-Length
238    }
239
240    size
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{Method, Request, Response, StatusCode};
    use serde_json::json;

    /// Build a `Shared` with default profiles/overrides, no failure injector
    /// and no traffic shaper, using the given feature switches.
    fn shared_fixture(overrides_enabled: bool, traffic_shaping_enabled: bool) -> Shared {
        Shared {
            profiles: LatencyProfiles::default(),
            overrides: Overrides::default(),
            failure_injector: None,
            traffic_shaper: None,
            overrides_enabled,
            traffic_shaping_enabled,
        }
    }

    /// Build a tag-less operation for `GET /users`.
    fn get_user_op() -> OperationMeta {
        OperationMeta {
            id: "getUser".to_string(),
            tags: vec![],
            path: "/users".to_string(),
        }
    }

    #[test]
    fn test_operation_meta_creation() {
        let meta = OperationMeta {
            id: "getUserById".to_string(),
            tags: vec!["users".to_string(), "public".to_string()],
            path: "/users/{id}".to_string(),
        };

        assert_eq!(meta.id, "getUserById");
        assert_eq!(meta.tags.len(), 2);
        assert_eq!(meta.path, "/users/{id}");
    }

    #[test]
    fn test_shared_creation() {
        let shared = shared_fixture(false, false);

        assert!(!shared.overrides_enabled);
        assert!(!shared.traffic_shaping_enabled);
        assert!(shared.failure_injector.is_none());
        assert!(shared.traffic_shaper.is_none());
    }

    #[test]
    fn test_shared_with_failure_injector() {
        let shared = Shared {
            failure_injector: Some(FailureInjector::new(None, true)),
            ..shared_fixture(false, false)
        };

        assert!(shared.failure_injector.is_some());
    }

    #[test]
    fn test_apply_overrides_disabled() {
        let shared = shared_fixture(false, false);
        let op = get_user_op();

        let mut body = json!({"name": "John"});
        let original = body.clone();

        apply_overrides(&shared, Some(&op), &mut body);

        // Should not modify body when overrides are disabled
        assert_eq!(body, original);
    }

    #[test]
    fn test_apply_overrides_enabled_no_rules() {
        let shared = shared_fixture(true, false);
        let op = get_user_op();

        let mut body = json!({"name": "John"});
        let original = body.clone();

        apply_overrides(&shared, Some(&op), &mut body);

        // Should not modify body when there are no override rules
        assert_eq!(body, original);
    }

    #[test]
    fn test_apply_overrides_with_none_operation() {
        let shared = shared_fixture(true, false);

        let mut body = json!({"name": "John"});
        let original = body.clone();

        apply_overrides(&shared, None, &mut body);

        // Should not modify body when operation is None
        assert_eq!(body, original);
    }

    #[test]
    fn test_calculate_request_size_basic() {
        let req = Request::builder()
            .uri("/test")
            .header("content-type", "application/json")
            .body(())
            .unwrap();

        let size = calculate_request_size(&req);

        // GET without Content-Length: headers + URI only, no body estimate
        let expected = ("content-type".len() + "application/json".len() + "/test".len()) as u64;
        assert_eq!(size, expected);
    }

    #[test]
    fn test_calculate_request_size_with_multiple_headers() {
        let req = Request::builder()
            .uri("/api/users")
            .header("content-type", "application/json")
            .header("authorization", "Bearer token123")
            .header("user-agent", "test-client")
            .body(())
            .unwrap();

        let size = calculate_request_size(&req);

        // Should account for every header name and value plus the URI
        let expected = ("content-type".len()
            + "application/json".len()
            + "authorization".len()
            + "Bearer token123".len()
            + "user-agent".len()
            + "test-client".len()
            + "/api/users".len()) as u64;
        assert_eq!(size, expected);
    }

    #[test]
    fn test_calculate_request_size_uses_content_length() {
        let req = Request::builder()
            .method(Method::POST)
            .uri("/upload")
            .header("content-length", "42")
            .body(())
            .unwrap();

        let size = calculate_request_size(&req);

        // The declared length replaces the 256-byte fallback estimate
        let expected = ("content-length".len() + "42".len() + "/upload".len()) as u64 + 42;
        assert_eq!(size, expected);
    }

    #[test]
    fn test_calculate_request_size_post_without_content_length() {
        let req = Request::builder().method(Method::POST).uri("/items").body(()).unwrap();

        let size = calculate_request_size(&req);

        // Body-carrying methods without Content-Length get the 256-byte estimate
        assert_eq!(size, "/items".len() as u64 + 256);
    }

    #[test]
    fn test_calculate_response_size_basic() {
        let res = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .body(axum::body::Body::empty())
            .unwrap();

        let size = calculate_response_size(&res);

        // Headers + 15-byte status line estimate + 256-byte body estimate
        let expected = ("content-type".len() + "application/json".len()) as u64 + 15 + 256;
        assert_eq!(size, expected);
    }

    #[test]
    fn test_calculate_response_size_with_multiple_headers() {
        let res = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .header("cache-control", "no-cache")
            .header("x-request-id", "123-456-789")
            .body(axum::body::Body::empty())
            .unwrap();

        let size = calculate_response_size(&res);

        // Should account for all headers
        let header_bytes = ("content-type".len()
            + "application/json".len()
            + "cache-control".len()
            + "no-cache".len()
            + "x-request-id".len()
            + "123-456-789".len()) as u64;
        assert_eq!(size, header_bytes + 15 + 256);
    }

    #[test]
    fn test_calculate_response_size_no_content() {
        let res = Response::builder()
            .status(StatusCode::NO_CONTENT)
            .body(axum::body::Body::empty())
            .unwrap();

        // 204 responses carry no body, so only the status line estimate remains
        assert_eq!(calculate_response_size(&res), 15);
    }

    #[test]
    fn test_shared_clone() {
        let shared = shared_fixture(true, true);

        let cloned = shared.clone();

        assert_eq!(shared.overrides_enabled, cloned.overrides_enabled);
        assert_eq!(shared.traffic_shaping_enabled, cloned.traffic_shaping_enabled);
    }

    #[test]
    fn test_operation_meta_clone() {
        let meta = OperationMeta {
            id: "testOp".to_string(),
            tags: vec!["tag1".to_string()],
            path: "/test".to_string(),
        };

        let cloned = meta.clone();

        assert_eq!(meta.id, cloned.id);
        assert_eq!(meta.tags, cloned.tags);
        assert_eq!(meta.path, cloned.path);
    }
}