// mockforge_http/contract_diff_middleware.rs
use axum::http::{header::CONTENT_LENGTH, HeaderMap};
use axum::{body::Body, extract::Request, middleware::Next, response::Response};
use mockforge_core::{
    ai_contract_diff::CapturedRequest, request_capture::get_global_capture_manager,
};
use std::collections::HashMap;
use tracing::debug;

/// Upper bound, in bytes, on the request body buffered by the contract-diff
/// capture middleware (1 MiB).
const MAX_CAPTURE_BODY_SIZE: usize = 1 << 20;

19pub async fn capture_for_contract_diff(req: Request<Body>, next: Next) -> Response {
21 let method = req.method().to_string();
22 let uri = req.uri().clone();
23 let path = uri.path().to_string();
24 let query = uri.query();
25
26 let headers = extract_headers_for_capture(req.headers());
28
29 let query_params = if let Some(query) = query {
31 parse_query_params(query)
32 } else {
33 HashMap::new()
34 };
35
36 let (parts, body) = req.into_parts();
38 let body_bytes = match axum::body::to_bytes(body, MAX_CAPTURE_BODY_SIZE).await {
39 Ok(b) => b,
40 Err(_) => {
41 let rebuilt = Request::from_parts(parts, Body::empty());
43 return next.run(rebuilt).await;
44 }
45 };
46
47 let captured_body = if !body_bytes.is_empty() {
49 serde_json::from_slice::<serde_json::Value>(&body_bytes).ok()
50 } else {
51 None
52 };
53
54 let rebuilt = Request::from_parts(parts, Body::from(body_bytes));
56
57 let response = next.run(rebuilt).await;
59
60 let status_code = response.status().as_u16();
62
63 let mut captured = CapturedRequest::new(&method, &path, "proxy_middleware")
65 .with_headers(headers)
66 .with_query_params(query_params)
67 .with_response(status_code, None);
68
69 if let Some(body_value) = captured_body {
70 captured = captured.with_body(body_value);
71 }
72
73 if let Some(capture_manager) = get_global_capture_manager() {
75 if let Err(e) = capture_manager.capture(captured).await {
76 debug!("Failed to capture request for contract diff: {}", e);
77 }
78 }
79
80 response
81}
82
83fn extract_headers_for_capture(headers: &HeaderMap) -> HashMap<String, String> {
85 let mut captured_headers = HashMap::new();
86
87 let safe_headers = [
89 "accept",
90 "accept-encoding",
91 "accept-language",
92 "content-type",
93 "content-length",
94 "user-agent",
95 "referer",
96 "origin",
97 "x-requested-with",
98 ];
99
100 for header_name in safe_headers {
101 if let Some(value) = headers.get(header_name) {
102 if let Ok(value_str) = value.to_str() {
103 captured_headers.insert(header_name.to_string(), value_str.to_string());
104 }
105 }
106 }
107
108 captured_headers
109}
110
111fn parse_query_params(query: &str) -> HashMap<String, String> {
113 let mut params = HashMap::new();
114
115 for pair in query.split('&') {
116 if let Some((key, value)) = pair.split_once('=') {
117 let decoded_key = urlencoding::decode(key).unwrap_or_else(|_| key.into());
118 let decoded_value = urlencoding::decode(value).unwrap_or_else(|_| value.into());
119 params.insert(decoded_key.to_string(), decoded_value.to_string());
120 }
121 }
122
123 params
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    /// Borrows the value stored under `key` as `&str` for terser assertions.
    fn lookup<'a>(map: &'a HashMap<String, String>, key: &str) -> Option<&'a str> {
        map.get(key).map(String::as_str)
    }

    #[test]
    fn test_extract_headers_for_capture() {
        let mut headers = HeaderMap::new();
        for (name, value) in [
            ("content-type", "application/json"),
            ("authorization", "Bearer token"),
            ("accept", "application/json"),
        ] {
            headers.insert(name, HeaderValue::from_static(value));
        }

        let captured = extract_headers_for_capture(&headers);

        assert_eq!(lookup(&captured, "content-type"), Some("application/json"));
        assert_eq!(lookup(&captured, "accept"), Some("application/json"));
        // Credentials are outside the allow-list and must never be recorded.
        assert!(!captured.contains_key("authorization"));
        assert_eq!(captured.len(), 2);
    }

    #[test]
    fn test_parse_query_params() {
        let params = parse_query_params("name=John&age=30&city=New%20York");

        assert_eq!(lookup(&params, "name"), Some("John"));
        assert_eq!(lookup(&params, "age"), Some("30"));
        assert_eq!(lookup(&params, "city"), Some("New York"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_parse_query_params_empty() {
        assert!(parse_query_params("").is_empty());
    }
}