1use axum::{
4 extract::{ConnectInfo, MatchedPath, Request},
5 http::HeaderMap,
6 middleware::Next,
7 response::Response,
8};
9use mockforge_core::{create_http_log_entry, log_request_global};
10use std::collections::HashMap;
11use std::net::SocketAddr;
12use std::time::Instant;
13use tracing::info;
14
15pub async fn log_http_requests(
17 ConnectInfo(addr): ConnectInfo<SocketAddr>,
18 matched_path: Option<MatchedPath>,
19 req: Request,
20 next: Next,
21) -> Response {
22 let start_time = Instant::now();
23 let method = req.method().to_string();
24 let uri = req.uri().to_string();
25 let path = matched_path
26 .map(|mp| mp.as_str().to_string())
27 .unwrap_or_else(|| uri.split('?').next().unwrap_or(&uri).to_string());
28
29 let query_params: HashMap<String, String> = req
31 .uri()
32 .query()
33 .map(|q| {
34 url::form_urlencoded::parse(q.as_bytes())
35 .into_owned()
36 .collect()
37 })
38 .unwrap_or_default();
39
40 let headers = extract_safe_headers(req.headers());
42
43 let user_agent = req
45 .headers()
46 .get("user-agent")
47 .and_then(|h| h.to_str().ok())
48 .map(|s| s.to_string());
49
50 let response = next.run(req).await;
52
53 let response_time_ms = start_time.elapsed().as_millis() as u64;
55 let status_code = response.status().as_u16();
56
57 let response_size_bytes = response
59 .headers()
60 .get("content-length")
61 .and_then(|h| h.to_str().ok())
62 .and_then(|s| s.parse::<u64>().ok())
63 .unwrap_or(0);
64
65 let error_message = if status_code >= 400 {
67 Some(format!(
68 "HTTP {} {}",
69 status_code,
70 response.status().canonical_reason().unwrap_or("Unknown")
71 ))
72 } else {
73 None
74 };
75
76 let mut log_entry = create_http_log_entry(
78 &method,
79 &path,
80 status_code,
81 response_time_ms,
82 Some(addr.ip().to_string()),
83 user_agent,
84 headers,
85 response_size_bytes,
86 error_message,
87 );
88
89 let query_params_for_log = query_params.clone();
91 if !query_params_for_log.is_empty() {
92 for (key, value) in query_params_for_log {
93 log_entry.metadata.insert(format!("query.{}", key), value);
94 }
95 }
96
97 log_request_global(log_entry).await;
99
100 if !query_params.is_empty() {
102 let query_params_clone = query_params.clone();
103 info!(
104 method = %method,
105 path = %path,
106 query = ?query_params_clone,
107 status = status_code,
108 duration_ms = response_time_ms,
109 client_ip = %addr.ip(),
110 "HTTP request processed"
111 );
112 } else {
113 info!(
114 method = %method,
115 path = %path,
116 status = status_code,
117 duration_ms = response_time_ms,
118 client_ip = %addr.ip(),
119 "HTTP request processed"
120 );
121 }
122
123 response
124}
125
126fn extract_safe_headers(headers: &HeaderMap) -> HashMap<String, String> {
128 let mut safe_headers = HashMap::new();
129
130 let safe_header_names = [
132 "accept",
133 "accept-encoding",
134 "accept-language",
135 "cache-control",
136 "content-type",
137 "content-length",
138 "user-agent",
139 "referer",
140 "host",
141 "x-forwarded-for",
142 "x-real-ip",
143 ];
144
145 for name in safe_header_names {
146 if let Some(value) = headers.get(name) {
147 if let Ok(value_str) = value.to_str() {
148 safe_headers.insert(name.to_string(), value_str.to_string());
149 }
150 }
151 }
152
153 safe_headers
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    #[test]
    fn test_extract_safe_headers_empty() {
        let headers = HeaderMap::new();
        let safe_headers = extract_safe_headers(&headers);
        assert_eq!(safe_headers.len(), 0);
    }

    #[test]
    fn test_extract_safe_headers_with_safe_headers() {
        let mut headers = HeaderMap::new();
        headers.insert("content-type", HeaderValue::from_static("application/json"));
        headers.insert("user-agent", HeaderValue::from_static("test-agent"));
        headers.insert("accept", HeaderValue::from_static("application/json"));

        let safe_headers = extract_safe_headers(&headers);

        assert_eq!(safe_headers.len(), 3);
        assert_eq!(safe_headers.get("content-type"), Some(&"application/json".to_string()));
        assert_eq!(safe_headers.get("user-agent"), Some(&"test-agent".to_string()));
        assert_eq!(safe_headers.get("accept"), Some(&"application/json".to_string()));
    }

    #[test]
    fn test_extract_safe_headers_excludes_sensitive_headers() {
        let mut headers = HeaderMap::new();
        headers.insert("content-type", HeaderValue::from_static("application/json"));
        headers.insert("authorization", HeaderValue::from_static("Bearer token123"));
        headers.insert("cookie", HeaderValue::from_static("session=abc123"));
        headers.insert("x-api-key", HeaderValue::from_static("secret-key"));

        let safe_headers = extract_safe_headers(&headers);

        assert_eq!(safe_headers.len(), 1);
        assert_eq!(safe_headers.get("content-type"), Some(&"application/json".to_string()));

        assert!(!safe_headers.contains_key("authorization"));
        assert!(!safe_headers.contains_key("cookie"));
        assert!(!safe_headers.contains_key("x-api-key"));
    }

    #[test]
    fn test_extract_safe_headers_all_safe_header_types() {
        let mut headers = HeaderMap::new();

        headers.insert("accept", HeaderValue::from_static("application/json"));
        headers.insert("accept-encoding", HeaderValue::from_static("gzip, deflate"));
        headers.insert("accept-language", HeaderValue::from_static("en-US"));
        headers.insert("cache-control", HeaderValue::from_static("no-cache"));
        headers.insert("content-type", HeaderValue::from_static("application/json"));
        headers.insert("content-length", HeaderValue::from_static("123"));
        headers.insert("user-agent", HeaderValue::from_static("Mozilla/5.0"));
        headers.insert("referer", HeaderValue::from_static("https://example.com"));
        headers.insert("host", HeaderValue::from_static("api.example.com"));
        headers.insert("x-forwarded-for", HeaderValue::from_static("192.168.1.1"));
        headers.insert("x-real-ip", HeaderValue::from_static("192.168.1.2"));

        let safe_headers = extract_safe_headers(&headers);

        assert_eq!(safe_headers.len(), 11);
        assert_eq!(safe_headers.get("accept"), Some(&"application/json".to_string()));
        assert_eq!(safe_headers.get("accept-encoding"), Some(&"gzip, deflate".to_string()));
        assert_eq!(safe_headers.get("accept-language"), Some(&"en-US".to_string()));
        assert_eq!(safe_headers.get("cache-control"), Some(&"no-cache".to_string()));
        assert_eq!(safe_headers.get("content-type"), Some(&"application/json".to_string()));
        assert_eq!(safe_headers.get("content-length"), Some(&"123".to_string()));
        assert_eq!(safe_headers.get("user-agent"), Some(&"Mozilla/5.0".to_string()));
        assert_eq!(safe_headers.get("referer"), Some(&"https://example.com".to_string()));
        assert_eq!(safe_headers.get("host"), Some(&"api.example.com".to_string()));
        assert_eq!(safe_headers.get("x-forwarded-for"), Some(&"192.168.1.1".to_string()));
        assert_eq!(safe_headers.get("x-real-ip"), Some(&"192.168.1.2".to_string()));
    }

    #[test]
    fn test_extract_safe_headers_handles_invalid_utf8() {
        let mut headers = HeaderMap::new();
        headers.insert("content-type", HeaderValue::from_static("application/json"));
        // Bytes >= 0x80 are accepted by `HeaderValue::from_bytes` (obs-text) but are
        // rejected by `HeaderValue::to_str`, so this value must be skipped, not panic.
        headers.insert(
            "user-agent",
            HeaderValue::from_bytes(b"\xff\xfe").expect("obs-text bytes are valid header bytes"),
        );

        let safe_headers = extract_safe_headers(&headers);

        assert_eq!(safe_headers.len(), 1);
        assert!(safe_headers.contains_key("content-type"));
        assert!(!safe_headers.contains_key("user-agent"));
    }

    #[test]
    fn test_extract_safe_headers_case_insensitive() {
        let mut headers = HeaderMap::new();
        headers.insert("Content-Type", HeaderValue::from_static("application/json"));
        headers.insert("User-Agent", HeaderValue::from_static("test"));

        let safe_headers = extract_safe_headers(&headers);

        assert_eq!(safe_headers.len(), 2);
        assert!(safe_headers.contains_key("content-type"));
        assert!(safe_headers.contains_key("user-agent"));
    }

    #[test]
    fn test_extract_safe_headers_mixed_safe_and_unsafe() {
        let mut headers = HeaderMap::new();
        headers.insert("content-type", HeaderValue::from_static("application/json"));
        headers.insert("authorization", HeaderValue::from_static("Bearer token"));
        headers.insert("user-agent", HeaderValue::from_static("Mozilla/5.0"));
        headers.insert("x-api-key", HeaderValue::from_static("secret"));
        headers.insert("accept", HeaderValue::from_static("*/*"));

        let safe_headers = extract_safe_headers(&headers);

        assert_eq!(safe_headers.len(), 3);
        assert!(safe_headers.contains_key("content-type"));
        assert!(safe_headers.contains_key("user-agent"));
        assert!(safe_headers.contains_key("accept"));

        assert!(!safe_headers.contains_key("authorization"));
        assert!(!safe_headers.contains_key("x-api-key"));
    }
}