mockforge_http/
contract_diff_middleware.rs1use axum::http::header::CONTENT_LENGTH;
9use axum::http::HeaderMap;
10use axum::{body::Body, extract::Request, middleware::Next, response::Response};
11use mockforge_core::{
12 ai_contract_diff::CapturedRequest, request_capture::get_global_capture_manager,
13};
14use std::collections::HashMap;
15use tracing::debug;
16
/// Returns the largest request body, in bytes, that contract-diff capture will buffer.
///
/// The limit is read from `MOCKFORGE_CONTRACT_DIFF_MAX_BODY_MB` (a whole number of MiB)
/// on every call; surrounding whitespace in the value is ignored. A missing or
/// unparseable value falls back to 10 MiB. The byte count saturates at `usize::MAX`
/// instead of overflowing for absurdly large settings.
fn max_capture_body_size() -> usize {
    const DEFAULT_MB: usize = 10;
    std::env::var("MOCKFORGE_CONTRACT_DIFF_MAX_BODY_MB")
        .ok()
        // Trim so values like " 20" from .env / compose files are not silently ignored.
        .and_then(|raw| raw.trim().parse::<usize>().ok())
        .unwrap_or(DEFAULT_MB)
        .saturating_mul(1024 * 1024)
}
36
37pub async fn capture_for_contract_diff(req: Request<Body>, next: Next) -> Response {
39 let method = req.method().to_string();
40 let uri = req.uri().clone();
41 let path = uri.path().to_string();
42 let query = uri.query();
43 let max_body = max_capture_body_size();
44
45 let content_length = req
50 .headers()
51 .get(CONTENT_LENGTH)
52 .and_then(|v| v.to_str().ok())
53 .and_then(|s| s.parse::<usize>().ok());
54 if let Some(len) = content_length {
55 if len > max_body {
56 debug!(
57 "contract_diff: skipping capture for {} {} — content-length {} exceeds cap {}",
58 method, path, len, max_body
59 );
60 return next.run(req).await;
61 }
62 }
63
64 let headers = extract_headers_for_capture(req.headers());
66
67 let query_params = if let Some(query) = query {
69 parse_query_params(query)
70 } else {
71 HashMap::new()
72 };
73
74 let (parts, body) = req.into_parts();
76 let body_bytes = match axum::body::to_bytes(body, max_body).await {
77 Ok(b) => b,
78 Err(_) => {
79 return Response::builder()
87 .status(http::StatusCode::PAYLOAD_TOO_LARGE)
88 .header(
89 http::header::CONTENT_TYPE,
90 "application/json",
91 )
92 .body(Body::from(format!(
93 r#"{{"error":"PAYLOAD_TOO_LARGE","message":"chunked request body exceeded contract_diff capture cap (~{} MiB); raise MOCKFORGE_CONTRACT_DIFF_MAX_BODY_MB or send Content-Length"}}"#,
94 max_body / (1024 * 1024)
95 )))
96 .unwrap_or_else(|_| {
97 Response::new(Body::from("payload too large"))
98 });
99 }
100 };
101
102 let captured_body = if !body_bytes.is_empty() {
104 serde_json::from_slice::<serde_json::Value>(&body_bytes).ok()
105 } else {
106 None
107 };
108
109 let rebuilt = Request::from_parts(parts, Body::from(body_bytes));
111
112 let response = next.run(rebuilt).await;
114
115 let status_code = response.status().as_u16();
117
118 let mut captured = CapturedRequest::new(&method, &path, "proxy_middleware")
120 .with_headers(headers)
121 .with_query_params(query_params)
122 .with_response(status_code, None);
123
124 if let Some(body_value) = captured_body {
125 captured = captured.with_body(body_value);
126 }
127
128 if let Some(capture_manager) = get_global_capture_manager() {
130 if let Err(e) = capture_manager.capture(captured).await {
131 debug!("Failed to capture request for contract diff: {}", e);
132 }
133 }
134
135 response
136}
137
138fn extract_headers_for_capture(headers: &HeaderMap) -> HashMap<String, String> {
140 let mut captured_headers = HashMap::new();
141
142 let safe_headers = [
144 "accept",
145 "accept-encoding",
146 "accept-language",
147 "content-type",
148 "content-length",
149 "user-agent",
150 "referer",
151 "origin",
152 "x-requested-with",
153 ];
154
155 for header_name in safe_headers {
156 if let Some(value) = headers.get(header_name) {
157 if let Ok(value_str) = value.to_str() {
158 captured_headers.insert(header_name.to_string(), value_str.to_string());
159 }
160 }
161 }
162
163 captured_headers
164}
165
166fn parse_query_params(query: &str) -> HashMap<String, String> {
168 let mut params = HashMap::new();
169
170 for pair in query.split('&') {
171 if let Some((key, value)) = pair.split_once('=') {
172 let decoded_key = urlencoding::decode(key).unwrap_or_else(|_| key.into());
173 let decoded_value = urlencoding::decode(value).unwrap_or_else(|_| value.into());
174 params.insert(decoded_key.to_string(), decoded_value.to_string());
175 }
176 }
177
178 params
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    /// Allow-listed headers are copied verbatim; `authorization` must be dropped.
    #[test]
    fn test_extract_headers_for_capture() {
        let mut headers = HeaderMap::new();
        for (name, value) in [
            ("content-type", "application/json"),
            ("authorization", "Bearer token"),
            ("accept", "application/json"),
        ] {
            headers.insert(name, HeaderValue::from_static(value));
        }

        let captured = extract_headers_for_capture(&headers);

        assert_eq!(
            captured.get("content-type").map(String::as_str),
            Some("application/json")
        );
        assert_eq!(
            captured.get("accept").map(String::as_str),
            Some("application/json")
        );
        assert!(!captured.contains_key("authorization"));
    }

    /// Plain and percent-encoded values are decoded into the map.
    #[test]
    fn test_parse_query_params() {
        let params = parse_query_params("name=John&age=30&city=New%20York");

        for (key, expected) in [("name", "John"), ("age", "30"), ("city", "New York")] {
            assert_eq!(
                params.get(key).map(String::as_str),
                Some(expected),
                "unexpected value for {}",
                key
            );
        }
    }

    /// An empty query string yields no parameters.
    #[test]
    fn test_parse_query_params_empty() {
        assert!(parse_query_params("").is_empty());
    }
}