Skip to main content

mockforge_http/
op_middleware.rs

1//! Middleware/utilities to apply latency/failure and overrides per operation.
2use axum::body::Body;
3use axum::extract::State;
4use axum::http::{Request, StatusCode};
5use axum::middleware::Next;
6use axum::response::Response;
7use serde_json::Value;
8
9use crate::latency_profiles::LatencyProfiles;
10use mockforge_core::{FailureInjector, Overrides, TrafficShaper};
11
/// Metadata describing the OpenAPI operation that the current request targets.
///
/// Read from the request extensions by [`fault_then_next`] and passed to
/// [`apply_overrides`], where its fields select per-operation and per-tag rules.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OperationMeta {
    /// OpenAPI operation ID (e.g. `getUserById`).
    pub id: String,
    /// Tags associated with this operation; used to match tag-scoped failure
    /// injection, traffic shaping and override rules.
    pub tags: Vec<String>,
    /// API path pattern (e.g. `/users/{id}`).
    pub path: String,
}
22
23/// Shared state for operation middleware
24#[derive(Clone)]
25pub struct Shared {
26    /// Latency profiles for request simulation
27    pub profiles: LatencyProfiles,
28    /// Response overrides configuration
29    pub overrides: Overrides,
30    /// Optional failure injector for chaos engineering
31    pub failure_injector: Option<FailureInjector>,
32    /// Optional traffic shaper for bandwidth/loss simulation
33    pub traffic_shaper: Option<TrafficShaper>,
34    /// Whether overrides are enabled
35    pub overrides_enabled: bool,
36    /// Whether traffic shaping is enabled
37    pub traffic_shaping_enabled: bool,
38}
39
40/// Middleware to add shared state to request extensions
41pub async fn add_shared_extension(
42    State(shared): State<Shared>,
43    mut req: Request<Body>,
44    next: Next,
45) -> Response {
46    req.extensions_mut().insert(shared);
47    next.run(req).await
48}
49
50/// Middleware to apply fault injection before processing request
51pub async fn fault_then_next(req: Request<Body>, next: Next) -> Response {
52    let shared = match req.extensions().get::<Shared>() {
53        Some(s) => s.clone(),
54        None => {
55            tracing::error!("Shared extension not found in request - ensure add_shared_extension middleware is configured");
56            let mut res =
57                Response::new(Body::from("Internal server error: middleware misconfiguration"));
58            *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
59            return res;
60        }
61    };
62    let op = req.extensions().get::<OperationMeta>().cloned();
63
64    // First, check the new enhanced failure injection system
65    if let Some(failure_injector) = &shared.failure_injector {
66        let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
67        if let Some((status_code, error_message)) = failure_injector.process_request(tags) {
68            let mut res = Response::new(Body::from(error_message));
69            *res.status_mut() =
70                StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
71            return res;
72        }
73    }
74
75    // Fallback to legacy latency profiles system for backward compatibility
76    if let Some(op) = &op {
77        if let Some((code, msg)) = shared
78            .profiles
79            .maybe_fault(&op.id, &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>())
80            .await
81        {
82            let mut res = Response::new(Body::from(msg));
83            *res.status_mut() =
84                StatusCode::from_u16(code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
85            return res;
86        }
87    }
88
89    // Apply traffic shaping (bandwidth throttling and burst loss) to the request
90    if shared.traffic_shaping_enabled {
91        if let Some(traffic_shaper) = &shared.traffic_shaper {
92            // Calculate request size for bandwidth throttling
93            let request_size = calculate_request_size(&req);
94
95            let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
96
97            // Apply traffic shaping
98            match traffic_shaper.process_transfer(request_size, tags).await {
99                Ok(Some(_timeout)) => {
100                    // Request was "lost" due to burst loss - return timeout error
101                    let mut res =
102                        Response::new(Body::from("Request timeout due to traffic shaping"));
103                    *res.status_mut() = StatusCode::REQUEST_TIMEOUT;
104                    return res;
105                }
106                Ok(None) => {
107                    // Transfer allowed, continue
108                }
109                Err(e) => {
110                    // Traffic shaping error - return internal server error
111                    let mut res =
112                        Response::new(Body::from(format!("Traffic shaping error: {}", e)));
113                    *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
114                    return res;
115                }
116            }
117        }
118    }
119
120    let (parts, body) = req.into_parts();
121    let req = Request::from_parts(parts, body);
122
123    let response = next.run(req).await;
124
125    // Apply traffic shaping to the response
126    if shared.traffic_shaping_enabled {
127        if let Some(traffic_shaper) = &shared.traffic_shaper {
128            // Calculate response size for bandwidth throttling
129            let response_size = calculate_response_size(&response);
130
131            let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
132
133            // Apply traffic shaping to response
134            match traffic_shaper.process_transfer(response_size, tags).await {
135                Ok(Some(_timeout)) => {
136                    // Response was "lost" due to burst loss - return timeout error
137                    let mut res =
138                        Response::new(Body::from("Response timeout due to traffic shaping"));
139                    *res.status_mut() = StatusCode::GATEWAY_TIMEOUT;
140                    return res;
141                }
142                Ok(None) => {
143                    // Transfer allowed, continue
144                }
145                Err(e) => {
146                    // Traffic shaping error - return internal server error
147                    let mut res =
148                        Response::new(Body::from(format!("Traffic shaping error: {}", e)));
149                    *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
150                    return res;
151                }
152            }
153        }
154    }
155
156    response
157}
158
159/// Apply response overrides to a JSON body based on operation metadata
160///
161/// # Arguments
162/// * `shared` - Shared middleware state containing override configuration
163/// * `op` - Optional operation metadata for override matching
164/// * `body` - JSON response body to modify in-place
165pub fn apply_overrides(shared: &Shared, op: Option<&OperationMeta>, body: &mut Value) {
166    if shared.overrides_enabled {
167        if let Some(op) = op {
168            shared.overrides.apply(
169                &op.id,
170                &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>(),
171                &op.path,
172                body,
173            );
174        }
175    }
176}
177
178/// Calculate the approximate size of an HTTP request for bandwidth throttling
179fn calculate_request_size<B>(req: &Request<B>) -> u64 {
180    let mut size = 0u64;
181
182    // Add header sizes (rough estimate)
183    for (name, value) in req.headers() {
184        size += name.as_str().len() as u64;
185        size += value.as_bytes().len() as u64;
186    }
187
188    // Add URI size
189    size += req.uri().to_string().len() as u64;
190
191    // Add body size (if available)
192    // Note: This is a rough estimate since we can't easily get the body size here
193    // without consuming the body. In practice, this would need to be implemented
194    // differently to get accurate body sizes.
195    size += 1024; // Rough estimate for body size
196
197    size
198}
199
200/// Calculate the approximate size of an HTTP response for bandwidth throttling
201fn calculate_response_size(res: &Response) -> u64 {
202    let mut size = 0u64;
203
204    // Add header sizes
205    for (name, value) in res.headers() {
206        size += name.as_str().len() as u64;
207        size += value.as_bytes().len() as u64;
208    }
209
210    // Add status line size (rough estimate)
211    size += 50;
212
213    // Add body size (rough estimate)
214    // Similar to request, this is a rough estimate
215    size += 2048; // Rough estimate for response body size
216
217    size
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{Request, Response, StatusCode};
    use serde_json::json;

    /// `Shared` with no injector or shaper and the given feature switches.
    fn shared_with(overrides_enabled: bool, traffic_shaping_enabled: bool) -> Shared {
        Shared {
            profiles: LatencyProfiles::default(),
            overrides: Overrides::default(),
            failure_injector: None,
            traffic_shaper: None,
            overrides_enabled,
            traffic_shaping_enabled,
        }
    }

    /// `OperationMeta` built from borrowed parts.
    fn op_meta(id: &str, tags: &[&str], path: &str) -> OperationMeta {
        OperationMeta {
            id: id.to_string(),
            tags: tags.iter().copied().map(String::from).collect(),
            path: path.to_string(),
        }
    }

    /// Asserts that `apply_overrides` leaves a sample body untouched.
    fn assert_body_unchanged(shared: &Shared, op: Option<&OperationMeta>) {
        let mut body = json!({"name": "John"});
        let before = body.clone();
        apply_overrides(shared, op, &mut body);
        assert_eq!(body, before);
    }

    #[test]
    fn test_operation_meta_creation() {
        let meta = op_meta("getUserById", &["users", "public"], "/users/{id}");

        assert_eq!(meta.id, "getUserById");
        assert_eq!(meta.tags.len(), 2);
        assert_eq!(meta.path, "/users/{id}");
    }

    #[test]
    fn test_shared_creation() {
        let shared = shared_with(false, false);

        assert!(!shared.overrides_enabled);
        assert!(!shared.traffic_shaping_enabled);
        assert!(shared.failure_injector.is_none());
        assert!(shared.traffic_shaper.is_none());
    }

    #[test]
    fn test_shared_with_failure_injector() {
        let shared = Shared {
            failure_injector: Some(FailureInjector::new(None, true)),
            ..shared_with(false, false)
        };

        assert!(shared.failure_injector.is_some());
    }

    #[test]
    fn test_apply_overrides_disabled() {
        // Overrides switched off: the body must come back untouched.
        let op = op_meta("getUser", &[], "/users");
        assert_body_unchanged(&shared_with(false, false), Some(&op));
    }

    #[test]
    fn test_apply_overrides_enabled_no_rules() {
        // Overrides switched on but no rules configured: nothing to change.
        let op = op_meta("getUser", &[], "/users");
        assert_body_unchanged(&shared_with(true, false), Some(&op));
    }

    #[test]
    fn test_apply_overrides_with_none_operation() {
        // Without operation metadata there is nothing to match against.
        assert_body_unchanged(&shared_with(true, false), None);
    }

    #[test]
    fn test_calculate_request_size_basic() {
        let req = Request::builder()
            .uri("/test")
            .header("content-type", "application/json")
            .body(())
            .expect("valid request");

        let size = calculate_request_size(&req);
        let floor = ("/test".len() + "content-type".len()) as u64;

        assert!(size > 0);
        assert!(size >= floor);
    }

    #[test]
    fn test_calculate_request_size_with_multiple_headers() {
        let headers = [
            ("content-type", "application/json"),
            ("authorization", "Bearer token123"),
            ("user-agent", "test-client"),
        ];
        let req = headers
            .iter()
            .fold(Request::builder().uri("/api/users"), |builder, &(name, value)| {
                builder.header(name, value)
            })
            .body(())
            .expect("valid request");

        // Several headers plus the body estimate comfortably exceed 100 bytes.
        assert!(calculate_request_size(&req) > 100);
    }

    #[test]
    fn test_calculate_response_size_basic() {
        let res = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .body(axum::body::Body::empty())
            .expect("valid response");

        let size = calculate_response_size(&res);

        assert!(size > 0);
        // At least the fixed status-line allowance.
        assert!(size >= 50);
    }

    #[test]
    fn test_calculate_response_size_with_multiple_headers() {
        let headers = [
            ("content-type", "application/json"),
            ("cache-control", "no-cache"),
            ("x-request-id", "123-456-789"),
        ];
        let res = headers
            .iter()
            .fold(Response::builder().status(StatusCode::OK), |builder, &(name, value)| {
                builder.header(name, value)
            })
            .body(axum::body::Body::empty())
            .expect("valid response");

        assert!(calculate_response_size(&res) > 100);
    }

    #[test]
    fn test_shared_clone() {
        let original = shared_with(true, true);
        let copy = original.clone();

        assert_eq!(original.overrides_enabled, copy.overrides_enabled);
        assert_eq!(original.traffic_shaping_enabled, copy.traffic_shaping_enabled);
    }

    #[test]
    fn test_operation_meta_clone() {
        let original = op_meta("testOp", &["tag1"], "/test");
        let copy = original.clone();

        assert_eq!(original.id, copy.id);
        assert_eq!(original.tags, copy.tags);
        assert_eq!(original.path, copy.path);
    }
}