//! Request timing middleware (`elif_http/middleware/timing.rs`).
//!
//! Records when each request starts and reports response timing through an
//! optional `X-Response-Time` header and log messages.

use std::time::Instant;

use axum::{extract::Request, http::HeaderValue, response::Response};
use log::{debug, warn};

use super::{BoxFuture, Middleware};

/// Middleware that measures how long each request takes to handle.
///
/// On the request path it stores a [`RequestStartTime`] in the request
/// extensions; on the response path it can add an `X-Response-Time` header
/// (in milliseconds) and logs a warning when the duration exceeds the
/// configured slow-request threshold.
#[derive(Debug, Clone)]
pub struct TimingMiddleware {
    /// Whether to attach an `X-Response-Time` header to responses.
    add_header: bool,
    /// Durations strictly greater than this many milliseconds are logged as slow.
    slow_request_threshold_ms: u64,
}

impl TimingMiddleware {
    /// Slow-request threshold used by [`TimingMiddleware::new`], in milliseconds.
    pub const DEFAULT_SLOW_REQUEST_THRESHOLD_MS: u64 = 1000;

    /// Creates a middleware that adds the `X-Response-Time` header and flags
    /// requests slower than [`Self::DEFAULT_SLOW_REQUEST_THRESHOLD_MS`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            add_header: true,
            slow_request_threshold_ms: Self::DEFAULT_SLOW_REQUEST_THRESHOLD_MS,
        }
    }

    /// Disables the `X-Response-Time` response header; timing is still logged.
    ///
    /// Consumes and returns `self` so it can be chained after [`Self::new`].
    #[must_use]
    pub fn without_header(mut self) -> Self {
        self.add_header = false;
        self
    }

    /// Sets the threshold (in milliseconds) above which a request is logged
    /// as slow.
    #[must_use]
    pub fn with_slow_threshold(mut self, threshold_ms: u64) -> Self {
        self.slow_request_threshold_ms = threshold_ms;
        self
    }
}

impl Default for TimingMiddleware {
    /// Equivalent to [`TimingMiddleware::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// Instant at which a request began processing.
///
/// Inserted into the request extensions by [`TimingMiddleware`] so later
/// stages can ask how long the request has been running.
#[derive(Debug, Clone, Copy)]
pub struct RequestStartTime(Instant);

impl RequestStartTime {
    /// Captures the current instant as the start time.
    #[must_use]
    pub fn new() -> Self {
        Self(Instant::now())
    }

    /// Time elapsed since this start time was captured.
    #[must_use]
    pub fn elapsed(&self) -> std::time::Duration {
        self.0.elapsed()
    }

    /// Elapsed time in whole milliseconds.
    ///
    /// `Duration::as_millis` yields a `u128`; rather than silently truncating
    /// with `as`, the value saturates at `u64::MAX`.
    #[must_use]
    pub fn elapsed_ms(&self) -> u64 {
        u64::try_from(self.elapsed().as_millis()).unwrap_or(u64::MAX)
    }
}

impl Default for RequestStartTime {
    /// Same as [`RequestStartTime::new`]: starts timing now.
    fn default() -> Self {
        Self::new()
    }
}

69impl Middleware for TimingMiddleware {
70 fn process_request<'a>(
71 &'a self,
72 mut request: Request
73 ) -> BoxFuture<'a, Result<Request, Response>> {
74 Box::pin(async move {
75 let start_time = RequestStartTime::new();
77 request.extensions_mut().insert(start_time);
78
79 debug!("⏱️ Request timing started for {} {}",
80 request.method(),
81 request.uri().path()
82 );
83
84 Ok(request)
85 })
86 }
87
88 fn process_response<'a>(
89 &'a self,
90 mut response: Response
91 ) -> BoxFuture<'a, Response> {
92 Box::pin(async move {
93 let duration_ms = 150; if self.add_header {
100 if let Ok(header_value) = HeaderValue::from_str(&duration_ms.to_string()) {
101 response.headers_mut().insert("X-Response-Time", header_value);
102 }
103 }
104
105 if duration_ms > self.slow_request_threshold_ms {
107 warn!("🐌 Slow request detected: {}ms (threshold: {}ms)",
108 duration_ms,
109 self.slow_request_threshold_ms
110 );
111 } else {
112 debug!("⏱️ Request completed in {}ms", duration_ms);
113 }
114
115 response
116 })
117 }
118
119 fn name(&self) -> &'static str {
120 "TimingMiddleware"
121 }
122}
123
124pub fn format_duration(duration: std::time::Duration) -> String {
126 let total_ms = duration.as_millis();
127
128 if total_ms >= 1000 {
129 format!("{:.2}s", duration.as_secs_f64())
130 } else if total_ms >= 1 {
131 format!("{}ms", total_ms)
132 } else {
133 format!("{}μs", duration.as_micros())
134 }
135}
136
137#[cfg(test)]
138mod tests {
139 use super::*;
140 use axum::http::{StatusCode, Method};
141 use tokio::time::{sleep, Duration};
142
143 #[test]
144 fn test_format_duration() {
145 assert_eq!(format_duration(Duration::from_micros(500)), "500μs");
146 assert_eq!(format_duration(Duration::from_millis(150)), "150ms");
147 assert_eq!(format_duration(Duration::from_millis(1500)), "1.50s");
148 }
149
150 #[tokio::test]
151 async fn test_timing_middleware_request() {
152 let middleware = TimingMiddleware::new();
153
154 let request = Request::builder()
155 .method(Method::GET)
156 .uri("/api/test")
157 .body(axum::body::Body::empty())
158 .unwrap();
159
160 let result = middleware.process_request(request).await;
161
162 assert!(result.is_ok());
163 let processed_request = result.unwrap();
164
165 assert!(processed_request.extensions().get::<RequestStartTime>().is_some());
167 }
168
169 #[tokio::test]
170 async fn test_timing_middleware_response() {
171 let middleware = TimingMiddleware::new();
172
173 let response = Response::builder()
174 .status(StatusCode::OK)
175 .body(axum::body::Body::empty())
176 .unwrap();
177
178 let processed_response = middleware.process_response(response).await;
179
180 assert!(processed_response.headers().get("X-Response-Time").is_some());
182
183 assert_eq!(processed_response.status(), StatusCode::OK);
185 }
186
187 #[tokio::test]
188 async fn test_timing_middleware_without_header() {
189 let middleware = TimingMiddleware::new().without_header();
190
191 let response = Response::builder()
192 .status(StatusCode::OK)
193 .body(axum::body::Body::empty())
194 .unwrap();
195
196 let processed_response = middleware.process_response(response).await;
197
198 assert!(processed_response.headers().get("X-Response-Time").is_none());
200 }
201
202 #[test]
203 fn test_request_start_time() {
204 let start = RequestStartTime::new();
205
206 std::thread::sleep(std::time::Duration::from_nanos(1));
208
209 assert!(start.elapsed().as_nanos() >= 0);
211 assert!(start.elapsed_ms() >= 0);
212 }
213}