1use axum::{
4 extract::{ConnectInfo, MatchedPath, Request},
5 http::HeaderMap,
6 middleware::Next,
7 response::Response,
8};
9use mockforge_core::{
10 create_http_log_entry, log_request_global, request_logger::RealityTraceMetadata,
11};
12use std::collections::HashMap;
13use std::net::SocketAddr;
14use std::time::Instant;
15use tracing::info;
16
17pub async fn log_http_requests(
19 ConnectInfo(addr): ConnectInfo<SocketAddr>,
20 matched_path: Option<MatchedPath>,
21 req: Request,
22 next: Next,
23) -> Response {
24 let start_time = Instant::now();
25 let method = req.method().to_string();
26 let uri = req.uri().to_string();
27 let path = matched_path
28 .map(|mp| mp.as_str().to_string())
29 .unwrap_or_else(|| uri.split('?').next().unwrap_or(&uri).to_string());
30
31 let query_params: HashMap<String, String> = req
33 .uri()
34 .query()
35 .map(|q| url::form_urlencoded::parse(q.as_bytes()).into_owned().collect())
36 .unwrap_or_default();
37
38 let headers = extract_safe_headers(req.headers());
40
41 let user_agent = req
43 .headers()
44 .get("user-agent")
45 .and_then(|h| h.to_str().ok())
46 .map(|s| s.to_string());
47
48 let reality_metadata = req.extensions().get::<RealityTraceMetadata>().cloned();
51
52 let response = next.run(req).await;
54
55 let response_time_ms = start_time.elapsed().as_millis() as u64;
57 let status_code = response.status().as_u16();
58
59 let response_size_bytes = response
61 .headers()
62 .get("content-length")
63 .and_then(|h| h.to_str().ok())
64 .and_then(|s| s.parse::<u64>().ok())
65 .unwrap_or(0);
66
67 let error_message = if status_code >= 400 {
69 Some(format!(
70 "HTTP {} {}",
71 status_code,
72 response.status().canonical_reason().unwrap_or("Unknown")
73 ))
74 } else {
75 None
76 };
77
78 let mut log_entry = create_http_log_entry(
80 &method,
81 &path,
82 status_code,
83 response_time_ms,
84 Some(addr.ip().to_string()),
85 user_agent,
86 headers,
87 response_size_bytes,
88 error_message,
89 );
90
91 let query_params_for_log = query_params.clone();
93 if !query_params_for_log.is_empty() {
94 for (key, value) in query_params_for_log {
95 log_entry.metadata.insert(format!("query.{}", key), value);
96 }
97 }
98
99 log_entry.reality_metadata = reality_metadata;
101
102 log_request_global(log_entry).await;
104
105 if !query_params.is_empty() {
107 let query_params_clone = query_params.clone();
108 info!(
109 method = %method,
110 path = %path,
111 query = ?query_params_clone,
112 status = status_code,
113 duration_ms = response_time_ms,
114 client_ip = %addr.ip(),
115 "HTTP request processed"
116 );
117 } else {
118 info!(
119 method = %method,
120 path = %path,
121 status = status_code,
122 duration_ms = response_time_ms,
123 client_ip = %addr.ip(),
124 "HTTP request processed"
125 );
126 }
127
128 response
129}
130
131fn extract_safe_headers(headers: &HeaderMap) -> HashMap<String, String> {
133 let mut safe_headers = HashMap::new();
134
135 let safe_header_names = [
137 "accept",
138 "accept-encoding",
139 "accept-language",
140 "cache-control",
141 "content-type",
142 "content-length",
143 "user-agent",
144 "referer",
145 "host",
146 "x-forwarded-for",
147 "x-real-ip",
148 ];
149
150 for name in safe_header_names {
151 if let Some(value) = headers.get(name) {
152 if let Ok(value_str) = value.to_str() {
153 safe_headers.insert(name.to_string(), value_str.to_string());
154 }
155 }
156 }
157
158 safe_headers
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    #[test]
    fn test_extract_safe_headers_empty() {
        let headers = HeaderMap::new();
        let safe_headers = extract_safe_headers(&headers);
        assert_eq!(safe_headers.len(), 0);
    }

    #[test]
    fn test_extract_safe_headers_with_safe_headers() {
        let mut headers = HeaderMap::new();
        headers.insert("content-type", HeaderValue::from_static("application/json"));
        headers.insert("user-agent", HeaderValue::from_static("test-agent"));
        headers.insert("accept", HeaderValue::from_static("application/json"));

        let safe_headers = extract_safe_headers(&headers);

        assert_eq!(safe_headers.len(), 3);
        assert_eq!(safe_headers.get("content-type"), Some(&"application/json".to_string()));
        assert_eq!(safe_headers.get("user-agent"), Some(&"test-agent".to_string()));
        assert_eq!(safe_headers.get("accept"), Some(&"application/json".to_string()));
    }

    #[test]
    fn test_extract_safe_headers_excludes_sensitive_headers() {
        let mut headers = HeaderMap::new();
        headers.insert("content-type", HeaderValue::from_static("application/json"));
        headers.insert("authorization", HeaderValue::from_static("Bearer token123"));
        headers.insert("cookie", HeaderValue::from_static("session=abc123"));
        headers.insert("x-api-key", HeaderValue::from_static("secret-key"));

        let safe_headers = extract_safe_headers(&headers);

        assert_eq!(safe_headers.len(), 1);
        assert_eq!(safe_headers.get("content-type"), Some(&"application/json".to_string()));

        assert!(!safe_headers.contains_key("authorization"));
        assert!(!safe_headers.contains_key("cookie"));
        assert!(!safe_headers.contains_key("x-api-key"));
    }

    #[test]
    fn test_extract_safe_headers_all_safe_header_types() {
        let mut headers = HeaderMap::new();

        headers.insert("accept", HeaderValue::from_static("application/json"));
        headers.insert("accept-encoding", HeaderValue::from_static("gzip, deflate"));
        headers.insert("accept-language", HeaderValue::from_static("en-US"));
        headers.insert("cache-control", HeaderValue::from_static("no-cache"));
        headers.insert("content-type", HeaderValue::from_static("application/json"));
        headers.insert("content-length", HeaderValue::from_static("123"));
        headers.insert("user-agent", HeaderValue::from_static("Mozilla/5.0"));
        headers.insert("referer", HeaderValue::from_static("https://example.com"));
        headers.insert("host", HeaderValue::from_static("api.example.com"));
        headers.insert("x-forwarded-for", HeaderValue::from_static("192.168.1.1"));
        headers.insert("x-real-ip", HeaderValue::from_static("192.168.1.2"));

        let safe_headers = extract_safe_headers(&headers);

        assert_eq!(safe_headers.len(), 11);
        assert_eq!(safe_headers.get("accept"), Some(&"application/json".to_string()));
        assert_eq!(safe_headers.get("accept-encoding"), Some(&"gzip, deflate".to_string()));
        assert_eq!(safe_headers.get("accept-language"), Some(&"en-US".to_string()));
        assert_eq!(safe_headers.get("cache-control"), Some(&"no-cache".to_string()));
        assert_eq!(safe_headers.get("content-type"), Some(&"application/json".to_string()));
        assert_eq!(safe_headers.get("content-length"), Some(&"123".to_string()));
        assert_eq!(safe_headers.get("user-agent"), Some(&"Mozilla/5.0".to_string()));
        assert_eq!(safe_headers.get("referer"), Some(&"https://example.com".to_string()));
        assert_eq!(safe_headers.get("host"), Some(&"api.example.com".to_string()));
        assert_eq!(safe_headers.get("x-forwarded-for"), Some(&"192.168.1.1".to_string()));
        assert_eq!(safe_headers.get("x-real-ip"), Some(&"192.168.1.2".to_string()));
    }

    #[test]
    fn test_extract_safe_headers_handles_invalid_utf8() {
        let mut headers = HeaderMap::new();
        headers.insert("content-type", HeaderValue::from_static("application/json"));
        // 0xFF is a legal header byte (obs-text) but not visible ASCII, so
        // `HeaderValue::to_str` fails and the header must be skipped rather than panic.
        headers.insert(
            "user-agent",
            HeaderValue::from_bytes(b"bad\xffagent").expect("0xFF is a valid header byte"),
        );

        let safe_headers = extract_safe_headers(&headers);

        assert_eq!(safe_headers.len(), 1);
        assert!(safe_headers.contains_key("content-type"));
        assert!(!safe_headers.contains_key("user-agent"));
    }

    #[test]
    fn test_extract_safe_headers_case_insensitive() {
        let mut headers = HeaderMap::new();
        headers.insert("Content-Type", HeaderValue::from_static("application/json"));
        headers.insert("User-Agent", HeaderValue::from_static("test"));

        let safe_headers = extract_safe_headers(&headers);

        assert_eq!(safe_headers.len(), 2);
        assert!(safe_headers.contains_key("content-type"));
        assert!(safe_headers.contains_key("user-agent"));
    }

    #[test]
    fn test_extract_safe_headers_mixed_safe_and_unsafe() {
        let mut headers = HeaderMap::new();
        headers.insert("content-type", HeaderValue::from_static("application/json"));
        headers.insert("authorization", HeaderValue::from_static("Bearer token"));
        headers.insert("user-agent", HeaderValue::from_static("Mozilla/5.0"));
        headers.insert("x-api-key", HeaderValue::from_static("secret"));
        headers.insert("accept", HeaderValue::from_static("*/*"));

        let safe_headers = extract_safe_headers(&headers);

        assert_eq!(safe_headers.len(), 3);
        assert!(safe_headers.contains_key("content-type"));
        assert!(safe_headers.contains_key("user-agent"));
        assert!(safe_headers.contains_key("accept"));

        assert!(!safe_headers.contains_key("authorization"));
        assert!(!safe_headers.contains_key("x-api-key"));
    }
}