elif_http/middleware/
tracing.rs

//! # Tracing Middleware
//!
//! Framework middleware for HTTP request tracing and observability.
//! Replaces tower-http's `TraceLayer` with a framework-native implementation.
5
6use std::time::Instant;
7use axum::{
8    extract::Request,
9    response::Response,
10    http::Method,
11};
12use tracing::{info, warn, error, Span, Level};
13use uuid::Uuid;
14
15use crate::middleware::{Middleware, BoxFuture};
16
17/// Configuration for tracing middleware
18#[derive(Debug, Clone)]
19pub struct TracingConfig {
20    /// Whether to trace request bodies
21    pub trace_bodies: bool,
22    /// Whether to trace response bodies  
23    pub trace_response_bodies: bool,
24    /// Maximum body size to trace (in bytes)
25    pub max_body_size: usize,
26    /// Log level for requests
27    pub level: Level,
28    /// Whether to include sensitive headers in traces
29    pub include_sensitive_headers: bool,
30    /// Headers considered sensitive (will be redacted)
31    pub sensitive_headers: Vec<String>,
32}
33
34impl Default for TracingConfig {
35    fn default() -> Self {
36        Self {
37            trace_bodies: false,
38            trace_response_bodies: false,
39            max_body_size: 1024,
40            level: Level::INFO,
41            include_sensitive_headers: false,
42            sensitive_headers: vec![
43                "authorization".to_string(),
44                "cookie".to_string(),
45                "x-api-key".to_string(),
46                "x-auth-token".to_string(),
47            ],
48        }
49    }
50}
51
52impl TracingConfig {
53    /// Enable body tracing
54    pub fn with_body_tracing(mut self) -> Self {
55        self.trace_bodies = true;
56        self
57    }
58
59    /// Enable response body tracing
60    pub fn with_response_body_tracing(mut self) -> Self {
61        self.trace_response_bodies = true;
62        self
63    }
64
65    /// Set maximum body size for tracing
66    pub fn with_max_body_size(mut self, size: usize) -> Self {
67        self.max_body_size = size;
68        self
69    }
70
71    /// Set tracing level
72    pub fn with_level(mut self, level: Level) -> Self {
73        self.level = level;
74        self
75    }
76
77    /// Include sensitive headers in traces (not recommended for production)
78    pub fn with_sensitive_headers(mut self) -> Self {
79        self.include_sensitive_headers = true;
80        self
81    }
82
83    /// Add custom sensitive header
84    pub fn add_sensitive_header(mut self, header: String) -> Self {
85        self.sensitive_headers.push(header.to_lowercase());
86        self
87    }
88}
89
90/// Framework tracing middleware for HTTP requests
91pub struct TracingMiddleware {
92    config: TracingConfig,
93}
94
95impl TracingMiddleware {
96    /// Create new tracing middleware with default configuration
97    pub fn new() -> Self {
98        Self {
99            config: TracingConfig::default(),
100        }
101    }
102
103    /// Create tracing middleware with custom configuration
104    pub fn with_config(config: TracingConfig) -> Self {
105        Self { config }
106    }
107
108    /// Enable body tracing
109    pub fn with_body_tracing(mut self) -> Self {
110        self.config = self.config.with_body_tracing();
111        self
112    }
113
114    /// Set tracing level
115    pub fn with_level(mut self, level: Level) -> Self {
116        self.config = self.config.with_level(level);
117        self
118    }
119
120    /// Check if header is sensitive
121    fn is_sensitive_header(&self, name: &str) -> bool {
122        if self.config.include_sensitive_headers {
123            return false;
124        }
125        
126        let name_lower = name.to_lowercase();
127        self.config.sensitive_headers.iter().any(|h| h == &name_lower)
128    }
129
130    /// Format headers for tracing
131    fn format_headers(&self, headers: &axum::http::HeaderMap) -> String {
132        headers
133            .iter()
134            .map(|(name, value)| {
135                let name_str = name.as_str();
136                let value_str = if self.is_sensitive_header(name_str) {
137                    "[REDACTED]"
138                } else {
139                    value.to_str().unwrap_or("[INVALID_UTF8]")
140                };
141                format!("{}={}", name_str, value_str)
142            })
143            .collect::<Vec<_>>()
144            .join(", ")
145    }
146}
147
148impl Default for TracingMiddleware {
149    fn default() -> Self {
150        Self::new()
151    }
152}
153
154impl Middleware for TracingMiddleware {
155    fn process_request<'a>(
156        &'a self,
157        mut request: Request
158    ) -> BoxFuture<'a, Result<Request, Response>> {
159        Box::pin(async move {
160            let start_time = Instant::now();
161            let request_id = Uuid::new_v4();
162            
163            // Create tracing span for this request
164            let span = match self.config.level {
165                Level::ERROR => tracing::error_span!(
166                    "http_request",
167                    method = %request.method(),
168                    uri = %request.uri(),
169                    request_id = %request_id,
170                    remote_addr = tracing::field::Empty,
171                ),
172                Level::WARN => tracing::warn_span!(
173                    "http_request",
174                    method = %request.method(),
175                    uri = %request.uri(),
176                    request_id = %request_id,
177                    remote_addr = tracing::field::Empty,
178                ),
179                Level::INFO => tracing::info_span!(
180                    "http_request",
181                    method = %request.method(),
182                    uri = %request.uri(),
183                    request_id = %request_id,
184                    remote_addr = tracing::field::Empty,
185                ),
186                Level::DEBUG => tracing::debug_span!(
187                    "http_request",
188                    method = %request.method(),
189                    uri = %request.uri(),
190                    request_id = %request_id,
191                    remote_addr = tracing::field::Empty,
192                ),
193                Level::TRACE => tracing::trace_span!(
194                    "http_request",
195                    method = %request.method(),
196                    uri = %request.uri(),
197                    request_id = %request_id,
198                    remote_addr = tracing::field::Empty,
199                ),
200            };
201
202            // Store request metadata in extensions
203            request.extensions_mut().insert(RequestMetadata {
204                request_id,
205                start_time,
206                span: span.clone(),
207            });
208
209            // Enter the span for this request
210            let _enter = span.enter();
211
212            // Log request details
213            match self.config.level {
214                Level::ERROR => error!(
215                    "HTTP Request: {} {} (ID: {})",
216                    request.method(),
217                    request.uri(),
218                    request_id
219                ),
220                Level::WARN => warn!(
221                    "HTTP Request: {} {} (ID: {})",
222                    request.method(),
223                    request.uri(), 
224                    request_id
225                ),
226                Level::INFO => info!(
227                    "HTTP Request: {} {} (ID: {})",
228                    request.method(),
229                    request.uri(),
230                    request_id
231                ),
232                Level::DEBUG => {
233                    let headers = self.format_headers(request.headers());
234                    tracing::debug!(
235                        "HTTP Request: {} {} (ID: {}) - Headers: {}",
236                        request.method(),
237                        request.uri(),
238                        request_id,
239                        headers
240                    );
241                },
242                Level::TRACE => {
243                    let headers = self.format_headers(request.headers());
244                    tracing::trace!(
245                        "HTTP Request: {} {} (ID: {}) - Headers: {} - Body tracing: {}",
246                        request.method(),
247                        request.uri(),
248                        request_id,
249                        headers,
250                        self.config.trace_bodies
251                    );
252                }
253            }
254
255            Ok(request)
256        })
257    }
258
259    fn process_response<'a>(
260        &'a self,
261        response: Response
262    ) -> BoxFuture<'a, Response> {
263        Box::pin(async move {
264            let status = response.status();
265            
266            // Try to get request metadata from response extensions
267            // Note: In real middleware pipeline, this would be passed through
268            // For now, we'll create minimal tracing
269            
270            match self.config.level {
271                Level::ERROR if status.is_server_error() => {
272                    error!("HTTP Response: {} (Server Error)", status);
273                },
274                Level::WARN if status.is_client_error() => {
275                    warn!("HTTP Response: {} (Client Error)", status);
276                },
277                Level::INFO => {
278                    info!("HTTP Response: {}", status);
279                },
280                Level::DEBUG => {
281                    let headers = self.format_headers(response.headers());
282                    tracing::debug!(
283                        "HTTP Response: {} - Headers: {}",
284                        status,
285                        headers
286                    );
287                },
288                Level::TRACE => {
289                    let headers = self.format_headers(response.headers());
290                    tracing::trace!(
291                        "HTTP Response: {} - Headers: {} - Body tracing: {}",
292                        status,
293                        headers,
294                        self.config.trace_response_bodies
295                    );
296                },
297                _ => {} // Skip logging for other combinations
298            }
299
300            response
301        })
302    }
303
304    fn name(&self) -> &'static str {
305        "TracingMiddleware"
306    }
307}
308
309/// Request metadata stored in request extensions
310#[derive(Debug, Clone)]
311pub struct RequestMetadata {
312    pub request_id: Uuid,
313    pub start_time: Instant,
314    pub span: Span,
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{Method, StatusCode, HeaderValue};
    use tracing_test::traced_test;

    #[traced_test]
    #[tokio::test]
    async fn test_tracing_middleware_basic() {
        let middleware = TracingMiddleware::new();

        let request = Request::builder()
            .method(Method::GET)
            .uri("/test")
            .body(axum::body::Body::empty())
            .unwrap();

        let result = middleware.process_request(request).await;
        assert!(result.is_ok());

        let processed_request = result.unwrap();

        // Check that request metadata was added
        let metadata = processed_request
            .extensions()
            .get::<RequestMetadata>()
            .expect("process_request should insert RequestMetadata");

        assert!(!metadata.request_id.is_nil());
        // Deterministic replacement for `elapsed().as_nanos() > 0`, which can be
        // zero when the monotonic clock has not ticked between the two reads.
        assert!(metadata.start_time <= Instant::now());
    }

    #[traced_test]
    #[tokio::test]
    async fn test_tracing_middleware_response() {
        let middleware = TracingMiddleware::new();

        let response = Response::builder()
            .status(StatusCode::OK)
            .body(axum::body::Body::empty())
            .unwrap();

        let processed_response = middleware.process_response(response).await;
        assert_eq!(processed_response.status(), StatusCode::OK);
    }

    // The tests below never await, so they run on the plain test harness
    // rather than spinning up a Tokio runtime.
    #[test]
    fn test_tracing_config_customization() {
        let config = TracingConfig::default()
            .with_body_tracing()
            .with_level(Level::DEBUG)
            .with_max_body_size(2048)
            .add_sensitive_header("x-custom-secret".to_string());

        let middleware = TracingMiddleware::with_config(config);
        assert!(middleware.config.trace_bodies);
        assert_eq!(middleware.config.level, Level::DEBUG);
        assert_eq!(middleware.config.max_body_size, 2048);
        assert!(middleware
            .config
            .sensitive_headers
            .contains(&"x-custom-secret".to_string()));
    }

    #[test]
    fn test_sensitive_header_detection() {
        let middleware = TracingMiddleware::new();

        assert!(middleware.is_sensitive_header("Authorization"));
        assert!(middleware.is_sensitive_header("COOKIE"));
        assert!(middleware.is_sensitive_header("x-api-key"));
        assert!(!middleware.is_sensitive_header("content-type"));
        assert!(!middleware.is_sensitive_header("accept"));
    }

    #[test]
    fn test_header_formatting() {
        let middleware = TracingMiddleware::new();

        let mut headers = axum::http::HeaderMap::new();
        headers.insert("content-type", HeaderValue::from_static("application/json"));
        headers.insert("authorization", HeaderValue::from_static("Bearer secret"));
        headers.insert("x-custom", HeaderValue::from_static("value"));

        let formatted = middleware.format_headers(&headers);

        assert!(formatted.contains("content-type=application/json"));
        assert!(formatted.contains("authorization=[REDACTED]"));
        assert!(formatted.contains("x-custom=value"));
    }

    #[test]
    fn test_tracing_middleware_name() {
        let middleware = TracingMiddleware::new();
        assert_eq!(middleware.name(), "TracingMiddleware");
    }
}