elif_http/middleware/
timing.rs

//! # Timing Middleware
//!
//! HTTP request timing middleware for performance monitoring.
4
5use std::time::Instant;
6use axum::{
7    extract::Request,
8    response::Response,
9    http::HeaderValue,
10};
11use log::{debug, warn};
12
13use super::{Middleware, BoxFuture};
14
/// Request timing middleware that tracks request duration and can add an
/// `X-Response-Time` header to responses.
///
/// Construct with [`TimingMiddleware::new`] (or `Default`) and adjust with the
/// builder methods.
#[derive(Debug, Clone)]
pub struct TimingMiddleware {
    /// Whether to add the `X-Response-Time` header to responses.
    add_header: bool,
    /// Requests taking longer than this many milliseconds are logged as warnings.
    slow_request_threshold_ms: u64,
}
22
23impl TimingMiddleware {
24    /// Create new timing middleware with default settings
25    pub fn new() -> Self {
26        Self {
27            add_header: true,
28            slow_request_threshold_ms: 1000, // 1 second
29        }
30    }
31    
32    /// Disable adding timing header to responses
33    pub fn without_header(mut self) -> Self {
34        self.add_header = false;
35        self
36    }
37    
38    /// Set slow request warning threshold in milliseconds
39    pub fn with_slow_threshold(mut self, threshold_ms: u64) -> Self {
40        self.slow_request_threshold_ms = threshold_ms;
41        self
42    }
43}
44
45impl Default for TimingMiddleware {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
/// Extension value holding the [`Instant`] at which a request was first seen
/// by [`TimingMiddleware`].
///
/// `TimingMiddleware::process_request` inserts one of these into the request
/// extensions. It is `Copy`, so it can be cheaply copied out of an extension
/// map and inserted elsewhere.
#[derive(Debug, Clone, Copy)]
pub struct RequestStartTime(Instant);

impl RequestStartTime {
    /// Captures the current instant as the start time.
    #[must_use]
    pub fn new() -> Self {
        Self(Instant::now())
    }

    /// Time elapsed since this start time was captured.
    #[must_use]
    pub fn elapsed(&self) -> std::time::Duration {
        self.0.elapsed()
    }

    /// Elapsed time in whole milliseconds (truncated toward zero).
    ///
    /// Saturates at `u64::MAX` rather than silently wrapping on the
    /// `u128 -> u64` narrowing that a plain `as` cast would perform.
    #[must_use]
    pub fn elapsed_ms(&self) -> u64 {
        u64::try_from(self.elapsed().as_millis()).unwrap_or(u64::MAX)
    }
}

impl Default for RequestStartTime {
    /// Equivalent to [`RequestStartTime::new`]: captures the current instant.
    fn default() -> Self {
        Self::new()
    }
}
68
69impl Middleware for TimingMiddleware {
70    fn process_request<'a>(
71        &'a self, 
72        mut request: Request
73    ) -> BoxFuture<'a, Result<Request, Response>> {
74        Box::pin(async move {
75            // Store start time in request extensions
76            let start_time = RequestStartTime::new();
77            request.extensions_mut().insert(start_time);
78            
79            debug!("⏱️  Request timing started for {} {}", 
80                request.method(), 
81                request.uri().path()
82            );
83            
84            Ok(request)
85        })
86    }
87    
88    fn process_response<'a>(
89        &'a self, 
90        mut response: Response
91    ) -> BoxFuture<'a, Response> {
92        Box::pin(async move {
93            // Try to get start time from response context
94            // Note: In real implementation, we'd need better state management
95            // For now, we'll create a mock duration
96            let duration_ms = 150; // Placeholder
97            
98            // Add timing header if enabled
99            if self.add_header {
100                if let Ok(header_value) = HeaderValue::from_str(&duration_ms.to_string()) {
101                    response.headers_mut().insert("X-Response-Time", header_value);
102                }
103            }
104            
105            // Check for slow requests and log warning
106            if duration_ms > self.slow_request_threshold_ms {
107                warn!("🐌 Slow request detected: {}ms (threshold: {}ms)", 
108                    duration_ms, 
109                    self.slow_request_threshold_ms
110                );
111            } else {
112                debug!("⏱️  Request completed in {}ms", duration_ms);
113            }
114            
115            response
116        })
117    }
118    
119    fn name(&self) -> &'static str {
120        "TimingMiddleware"
121    }
122}
123
124/// Utility function to format duration for display
125pub fn format_duration(duration: std::time::Duration) -> String {
126    let total_ms = duration.as_millis();
127    
128    if total_ms >= 1000 {
129        format!("{:.2}s", duration.as_secs_f64())
130    } else if total_ms >= 1 {
131        format!("{}ms", total_ms)
132    } else {
133        format!("{}μs", duration.as_micros())
134    }
135}
136
137#[cfg(test)]
138mod tests {
139    use super::*;
140    use axum::http::{StatusCode, Method};
141    use tokio::time::{sleep, Duration};
142    
143    #[test]
144    fn test_format_duration() {
145        assert_eq!(format_duration(Duration::from_micros(500)), "500μs");
146        assert_eq!(format_duration(Duration::from_millis(150)), "150ms");
147        assert_eq!(format_duration(Duration::from_millis(1500)), "1.50s");
148    }
149    
150    #[tokio::test]
151    async fn test_timing_middleware_request() {
152        let middleware = TimingMiddleware::new();
153        
154        let request = Request::builder()
155            .method(Method::GET)
156            .uri("/api/test")
157            .body(axum::body::Body::empty())
158            .unwrap();
159        
160        let result = middleware.process_request(request).await;
161        
162        assert!(result.is_ok());
163        let processed_request = result.unwrap();
164        
165        // Should have start time in extensions
166        assert!(processed_request.extensions().get::<RequestStartTime>().is_some());
167    }
168    
169    #[tokio::test]
170    async fn test_timing_middleware_response() {
171        let middleware = TimingMiddleware::new();
172        
173        let response = Response::builder()
174            .status(StatusCode::OK)
175            .body(axum::body::Body::empty())
176            .unwrap();
177        
178        let processed_response = middleware.process_response(response).await;
179        
180        // Should have timing header
181        assert!(processed_response.headers().get("X-Response-Time").is_some());
182        
183        // Status should be preserved
184        assert_eq!(processed_response.status(), StatusCode::OK);
185    }
186    
187    #[tokio::test]
188    async fn test_timing_middleware_without_header() {
189        let middleware = TimingMiddleware::new().without_header();
190        
191        let response = Response::builder()
192            .status(StatusCode::OK)
193            .body(axum::body::Body::empty())
194            .unwrap();
195        
196        let processed_response = middleware.process_response(response).await;
197        
198        // Should NOT have timing header
199        assert!(processed_response.headers().get("X-Response-Time").is_none());
200    }
201    
202    #[test]
203    fn test_request_start_time() {
204        let start = RequestStartTime::new();
205        
206        // Add a tiny delay to ensure some time passes
207        std::thread::sleep(std::time::Duration::from_nanos(1));
208        
209        // Should have elapsed time
210        assert!(start.elapsed().as_nanos() >= 0);
211        assert!(start.elapsed_ms() >= 0);
212    }
213}