1use std::time::Instant;
7use axum::{
8 extract::Request,
9 response::Response,
10 http::Method,
11};
12use tracing::{info, warn, error, Span, Level};
13use uuid::Uuid;
14
15use crate::middleware::{Middleware, BoxFuture};
16
17#[derive(Debug, Clone)]
19pub struct TracingConfig {
20 pub trace_bodies: bool,
22 pub trace_response_bodies: bool,
24 pub max_body_size: usize,
26 pub level: Level,
28 pub include_sensitive_headers: bool,
30 pub sensitive_headers: Vec<String>,
32}
33
34impl Default for TracingConfig {
35 fn default() -> Self {
36 Self {
37 trace_bodies: false,
38 trace_response_bodies: false,
39 max_body_size: 1024,
40 level: Level::INFO,
41 include_sensitive_headers: false,
42 sensitive_headers: vec![
43 "authorization".to_string(),
44 "cookie".to_string(),
45 "x-api-key".to_string(),
46 "x-auth-token".to_string(),
47 ],
48 }
49 }
50}
51
52impl TracingConfig {
53 pub fn with_body_tracing(mut self) -> Self {
55 self.trace_bodies = true;
56 self
57 }
58
59 pub fn with_response_body_tracing(mut self) -> Self {
61 self.trace_response_bodies = true;
62 self
63 }
64
65 pub fn with_max_body_size(mut self, size: usize) -> Self {
67 self.max_body_size = size;
68 self
69 }
70
71 pub fn with_level(mut self, level: Level) -> Self {
73 self.level = level;
74 self
75 }
76
77 pub fn with_sensitive_headers(mut self) -> Self {
79 self.include_sensitive_headers = true;
80 self
81 }
82
83 pub fn add_sensitive_header(mut self, header: String) -> Self {
85 self.sensitive_headers.push(header.to_lowercase());
86 self
87 }
88}
89
90pub struct TracingMiddleware {
92 config: TracingConfig,
93}
94
95impl TracingMiddleware {
96 pub fn new() -> Self {
98 Self {
99 config: TracingConfig::default(),
100 }
101 }
102
103 pub fn with_config(config: TracingConfig) -> Self {
105 Self { config }
106 }
107
108 pub fn with_body_tracing(mut self) -> Self {
110 self.config = self.config.with_body_tracing();
111 self
112 }
113
114 pub fn with_level(mut self, level: Level) -> Self {
116 self.config = self.config.with_level(level);
117 self
118 }
119
120 fn is_sensitive_header(&self, name: &str) -> bool {
122 if self.config.include_sensitive_headers {
123 return false;
124 }
125
126 let name_lower = name.to_lowercase();
127 self.config.sensitive_headers.iter().any(|h| h == &name_lower)
128 }
129
130 fn format_headers(&self, headers: &axum::http::HeaderMap) -> String {
132 headers
133 .iter()
134 .map(|(name, value)| {
135 let name_str = name.as_str();
136 let value_str = if self.is_sensitive_header(name_str) {
137 "[REDACTED]"
138 } else {
139 value.to_str().unwrap_or("[INVALID_UTF8]")
140 };
141 format!("{}={}", name_str, value_str)
142 })
143 .collect::<Vec<_>>()
144 .join(", ")
145 }
146}
147
148impl Default for TracingMiddleware {
149 fn default() -> Self {
150 Self::new()
151 }
152}
153
154impl Middleware for TracingMiddleware {
155 fn process_request<'a>(
156 &'a self,
157 mut request: Request
158 ) -> BoxFuture<'a, Result<Request, Response>> {
159 Box::pin(async move {
160 let start_time = Instant::now();
161 let request_id = Uuid::new_v4();
162
163 let span = match self.config.level {
165 Level::ERROR => tracing::error_span!(
166 "http_request",
167 method = %request.method(),
168 uri = %request.uri(),
169 request_id = %request_id,
170 remote_addr = tracing::field::Empty,
171 ),
172 Level::WARN => tracing::warn_span!(
173 "http_request",
174 method = %request.method(),
175 uri = %request.uri(),
176 request_id = %request_id,
177 remote_addr = tracing::field::Empty,
178 ),
179 Level::INFO => tracing::info_span!(
180 "http_request",
181 method = %request.method(),
182 uri = %request.uri(),
183 request_id = %request_id,
184 remote_addr = tracing::field::Empty,
185 ),
186 Level::DEBUG => tracing::debug_span!(
187 "http_request",
188 method = %request.method(),
189 uri = %request.uri(),
190 request_id = %request_id,
191 remote_addr = tracing::field::Empty,
192 ),
193 Level::TRACE => tracing::trace_span!(
194 "http_request",
195 method = %request.method(),
196 uri = %request.uri(),
197 request_id = %request_id,
198 remote_addr = tracing::field::Empty,
199 ),
200 };
201
202 request.extensions_mut().insert(RequestMetadata {
204 request_id,
205 start_time,
206 span: span.clone(),
207 });
208
209 let _enter = span.enter();
211
212 match self.config.level {
214 Level::ERROR => error!(
215 "HTTP Request: {} {} (ID: {})",
216 request.method(),
217 request.uri(),
218 request_id
219 ),
220 Level::WARN => warn!(
221 "HTTP Request: {} {} (ID: {})",
222 request.method(),
223 request.uri(),
224 request_id
225 ),
226 Level::INFO => info!(
227 "HTTP Request: {} {} (ID: {})",
228 request.method(),
229 request.uri(),
230 request_id
231 ),
232 Level::DEBUG => {
233 let headers = self.format_headers(request.headers());
234 tracing::debug!(
235 "HTTP Request: {} {} (ID: {}) - Headers: {}",
236 request.method(),
237 request.uri(),
238 request_id,
239 headers
240 );
241 },
242 Level::TRACE => {
243 let headers = self.format_headers(request.headers());
244 tracing::trace!(
245 "HTTP Request: {} {} (ID: {}) - Headers: {} - Body tracing: {}",
246 request.method(),
247 request.uri(),
248 request_id,
249 headers,
250 self.config.trace_bodies
251 );
252 }
253 }
254
255 Ok(request)
256 })
257 }
258
259 fn process_response<'a>(
260 &'a self,
261 response: Response
262 ) -> BoxFuture<'a, Response> {
263 Box::pin(async move {
264 let status = response.status();
265
266 match self.config.level {
271 Level::ERROR if status.is_server_error() => {
272 error!("HTTP Response: {} (Server Error)", status);
273 },
274 Level::WARN if status.is_client_error() => {
275 warn!("HTTP Response: {} (Client Error)", status);
276 },
277 Level::INFO => {
278 info!("HTTP Response: {}", status);
279 },
280 Level::DEBUG => {
281 let headers = self.format_headers(response.headers());
282 tracing::debug!(
283 "HTTP Response: {} - Headers: {}",
284 status,
285 headers
286 );
287 },
288 Level::TRACE => {
289 let headers = self.format_headers(response.headers());
290 tracing::trace!(
291 "HTTP Response: {} - Headers: {} - Body tracing: {}",
292 status,
293 headers,
294 self.config.trace_response_bodies
295 );
296 },
297 _ => {} }
299
300 response
301 })
302 }
303
304 fn name(&self) -> &'static str {
305 "TracingMiddleware"
306 }
307}
308
309#[derive(Debug, Clone)]
311pub struct RequestMetadata {
312 pub request_id: Uuid,
313 pub start_time: Instant,
314 pub span: Span,
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{HeaderValue, Method, StatusCode};
    use tracing_test::traced_test;

    /// Processing a request succeeds and attaches well-formed metadata.
    #[traced_test]
    #[tokio::test]
    async fn test_tracing_middleware_basic() {
        let middleware = TracingMiddleware::new();

        let request = Request::builder()
            .method(Method::GET)
            .uri("/test")
            .body(axum::body::Body::empty())
            .unwrap();

        // Bracket the call with two clock readings instead of asserting that
        // "some time has elapsed", which depends on the clock's resolution.
        let before = Instant::now();
        let result = middleware.process_request(request).await;
        let after = Instant::now();
        assert!(result.is_ok());

        let processed_request = result.unwrap();

        let metadata = processed_request.extensions().get::<RequestMetadata>();
        assert!(metadata.is_some());

        let metadata = metadata.unwrap();
        assert!(!metadata.request_id.is_nil());
        assert!(metadata.start_time >= before);
        assert!(metadata.start_time <= after);
    }

    /// Processing a response leaves its status untouched.
    #[traced_test]
    #[tokio::test]
    async fn test_tracing_middleware_response() {
        let middleware = TracingMiddleware::new();

        let response = Response::builder()
            .status(StatusCode::OK)
            .body(axum::body::Body::empty())
            .unwrap();

        let processed_response = middleware.process_response(response).await;
        assert_eq!(processed_response.status(), StatusCode::OK);
    }

    /// Builder methods on the config are reflected in the middleware.
    #[test]
    fn test_tracing_config_customization() {
        let config = TracingConfig::default()
            .with_body_tracing()
            .with_level(Level::DEBUG)
            .with_max_body_size(2048)
            .add_sensitive_header("x-custom-secret".to_string());

        let middleware = TracingMiddleware::with_config(config);
        assert!(middleware.config.trace_bodies);
        assert_eq!(middleware.config.level, Level::DEBUG);
        assert_eq!(middleware.config.max_body_size, 2048);
        assert!(middleware
            .config
            .sensitive_headers
            .contains(&"x-custom-secret".to_string()));
    }

    /// Default sensitive headers are detected regardless of the name's case.
    #[test]
    fn test_sensitive_header_detection() {
        let middleware = TracingMiddleware::new();

        assert!(middleware.is_sensitive_header("Authorization"));
        assert!(middleware.is_sensitive_header("COOKIE"));
        assert!(middleware.is_sensitive_header("x-api-key"));
        assert!(!middleware.is_sensitive_header("content-type"));
        assert!(!middleware.is_sensitive_header("accept"));
    }

    /// Formatting redacts sensitive values and keeps the others.
    #[test]
    fn test_header_formatting() {
        let middleware = TracingMiddleware::new();

        let mut headers = axum::http::HeaderMap::new();
        headers.insert("content-type", HeaderValue::from_static("application/json"));
        headers.insert("authorization", HeaderValue::from_static("Bearer secret"));
        headers.insert("x-custom", HeaderValue::from_static("value"));

        let formatted = middleware.format_headers(&headers);

        assert!(formatted.contains("content-type=application/json"));
        assert!(formatted.contains("authorization=[REDACTED]"));
        assert!(formatted.contains("x-custom=value"));
    }

    /// The middleware reports its fixed name.
    #[test]
    fn test_tracing_middleware_name() {
        let middleware = TracingMiddleware::new();
        assert_eq!(middleware.name(), "TracingMiddleware");
    }
}