1use std::time::Duration;
7use tokio::time::{timeout, Timeout};
8use axum::{
9 extract::Request,
10 response::{Response, IntoResponse},
11 http::StatusCode,
12};
13use tracing::{warn, error};
14
15use crate::{
16 middleware::{Middleware, BoxFuture},
17 HttpError,
18};
19
/// Settings used by [`TimeoutMiddleware`]: the timeout duration, whether
/// timed-out responses are logged, and the message for the timeout error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TimeoutConfig {
    /// Timeout duration. The middleware copies it into each request's
    /// `TimeoutInfo` extension.
    pub timeout: Duration,
    /// When `true`, responses with status `408 Request Timeout` are logged
    /// with `warn!`.
    pub log_timeouts: bool,
    /// Message used when building the timeout error response.
    pub timeout_message: String,
}

impl Default for TimeoutConfig {
    /// A 30-second timeout, logging enabled, message `"Request timed out"`.
    fn default() -> Self {
        Self {
            timeout: Duration::from_secs(30),
            log_timeouts: true,
            timeout_message: "Request timed out".to_string(),
        }
    }
}

impl TimeoutConfig {
    /// Creates a config with the given timeout and default logging/message.
    #[must_use]
    pub fn new(timeout: Duration) -> Self {
        Self {
            timeout,
            ..Default::default()
        }
    }

    /// Replaces the timeout duration.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Enables or disables logging of timed-out responses.
    #[must_use]
    pub fn with_logging(mut self, log_timeouts: bool) -> Self {
        self.log_timeouts = log_timeouts;
        self
    }

    /// Replaces the message used in the timeout error response.
    #[must_use]
    pub fn with_message<S: Into<String>>(mut self, message: S) -> Self {
        self.timeout_message = message.into();
        self
    }
}
68
69pub struct TimeoutMiddleware {
71 config: TimeoutConfig,
72}
73
74impl TimeoutMiddleware {
75 pub fn new() -> Self {
77 Self {
78 config: TimeoutConfig::default(),
79 }
80 }
81
82 pub fn with_duration(timeout: Duration) -> Self {
84 Self {
85 config: TimeoutConfig::new(timeout),
86 }
87 }
88
89 pub fn with_config(config: TimeoutConfig) -> Self {
91 Self { config }
92 }
93
94 pub fn timeout(mut self, duration: Duration) -> Self {
96 self.config = self.config.with_timeout(duration);
97 self
98 }
99
100 pub fn logging(mut self, enabled: bool) -> Self {
102 self.config = self.config.with_logging(enabled);
103 self
104 }
105
106 pub fn message<S: Into<String>>(mut self, message: S) -> Self {
108 self.config = self.config.with_message(message);
109 self
110 }
111
112 pub fn duration(&self) -> Duration {
114 self.config.timeout
115 }
116
117 fn timeout_response(&self) -> Response {
119 let error = HttpError::timeout(&self.config.timeout_message);
120 error.into_response()
121 }
122}
123
124impl Default for TimeoutMiddleware {
125 fn default() -> Self {
126 Self::new()
127 }
128}
129
130impl Middleware for TimeoutMiddleware {
131 fn process_request<'a>(
132 &'a self,
133 request: Request
134 ) -> BoxFuture<'a, Result<Request, Response>> {
135 Box::pin(async move {
136 let mut request = request;
139 request.extensions_mut().insert(TimeoutInfo {
140 duration: self.config.timeout,
141 message: self.config.timeout_message.clone(),
142 });
143
144 Ok(request)
145 })
146 }
147
148 fn process_response<'a>(
149 &'a self,
150 response: Response
151 ) -> BoxFuture<'a, Response> {
152 Box::pin(async move {
153 if response.status() == StatusCode::REQUEST_TIMEOUT && self.config.log_timeouts {
157 warn!("Request timed out after {:?}", self.config.timeout);
158 }
159
160 response
161 })
162 }
163
164 fn name(&self) -> &'static str {
165 "TimeoutMiddleware"
166 }
167}
168
/// Timeout settings that `TimeoutMiddleware` inserts into a request's
/// extensions, so downstream handlers can read the configured limit.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TimeoutInfo {
    /// Configured timeout duration.
    pub duration: Duration,
    /// Configured message for the timeout error response.
    pub message: String,
}
175
176pub async fn apply_timeout<F, T>(
178 future: F,
179 duration: Duration,
180 timeout_message: &str,
181) -> Result<T, Response>
182where
183 F: std::future::Future<Output = T>,
184{
185 match timeout(duration, future).await {
186 Ok(result) => Ok(result),
187 Err(_) => {
188 error!("Request timed out after {:?}: {}", duration, timeout_message);
189 let error = HttpError::timeout(timeout_message);
190 Err(error.into_response())
191 }
192 }
193}
194
/// Wraps an inner service (`handler`) so that each call is bounded by
/// `duration`.
///
/// When the bound is exceeded, the call fails with a response built from
/// `message`.
#[derive(Debug, Clone)]
pub struct TimeoutHandler<F> {
    handler: F,
    duration: Duration,
    message: String,
}

impl<F> TimeoutHandler<F> {
    /// Wraps `handler` with a `duration` limit and the default
    /// `"Request timed out"` message.
    #[must_use]
    pub fn new(handler: F, duration: Duration) -> Self {
        Self {
            handler,
            duration,
            message: "Request timed out".to_string(),
        }
    }

    /// Replaces the message used in the timeout error response.
    #[must_use]
    pub fn with_message<S: Into<String>>(mut self, message: S) -> Self {
        self.message = message.into();
        self
    }
}
216
217impl<F, Fut, T> tower::Service<Request> for TimeoutHandler<F>
218where
219 F: tower::Service<Request, Response = T, Future = Fut> + Clone + Send + 'static,
220 Fut: std::future::Future<Output = Result<T, F::Error>> + Send + 'static,
221 T: axum::response::IntoResponse,
222{
223 type Response = Response;
224 type Error = Response;
225 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
226
227 fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
228 match self.handler.poll_ready(cx) {
229 std::task::Poll::Ready(Ok(())) => std::task::Poll::Ready(Ok(())),
230 std::task::Poll::Ready(Err(_)) => {
231 let error = HttpError::internal("Handler not ready");
232 std::task::Poll::Ready(Err(error.into_response()))
233 },
234 std::task::Poll::Pending => std::task::Poll::Pending,
235 }
236 }
237
238 fn call(&mut self, request: Request) -> Self::Future {
239 let handler = self.handler.clone();
240 let mut handler = handler;
241 let duration = self.duration;
242 let message = self.message.clone();
243
244 Box::pin(async move {
245 match timeout(duration, handler.call(request)).await {
246 Ok(Ok(response)) => Ok(response.into_response()),
247 Ok(Err(_)) => {
248 let error = HttpError::internal("Handler error");
249 Err(error.into_response())
250 },
251 Err(_) => {
252 error!("Request timed out after {:?}: {}", duration, message);
253 let error = HttpError::timeout(&message);
254 Err(error.into_response())
255 }
256 }
257 })
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{Method, StatusCode};
    use std::time::Duration;
    use tokio::time::{sleep, Duration as TokioDuration};

    /// Builds a bodiless request with the given method and URI.
    fn empty_request(method: Method, uri: &str) -> Request {
        Request::builder()
            .method(method)
            .uri(uri)
            .body(axum::body::Body::empty())
            .unwrap()
    }

    #[tokio::test]
    async fn test_timeout_middleware_basic() {
        let middleware = TimeoutMiddleware::new();

        let outcome = middleware
            .process_request(empty_request(Method::GET, "/test"))
            .await;
        assert!(outcome.is_ok());
        let processed = outcome.unwrap();

        let info = processed
            .extensions()
            .get::<TimeoutInfo>()
            .expect("TimeoutInfo extension should be attached");
        assert_eq!(info.duration, Duration::from_secs(30));
        assert_eq!(info.message, "Request timed out");
    }

    #[tokio::test]
    async fn test_timeout_middleware_custom_config() {
        let middleware = TimeoutMiddleware::with_config(
            TimeoutConfig::new(Duration::from_secs(60))
                .with_logging(false)
                .with_message("Custom timeout"),
        );

        assert_eq!(middleware.duration(), Duration::from_secs(60));
        assert!(!middleware.config.log_timeouts);
        assert_eq!(middleware.config.timeout_message, "Custom timeout");
    }

    #[tokio::test]
    async fn test_timeout_middleware_builder() {
        let middleware = TimeoutMiddleware::new()
            .timeout(Duration::from_secs(45))
            .logging(true)
            .message("Builder timeout");

        assert_eq!(middleware.duration(), Duration::from_secs(45));
        assert!(middleware.config.log_timeouts);
        assert_eq!(middleware.config.timeout_message, "Builder timeout");
    }

    #[tokio::test]
    async fn test_timeout_middleware_response() {
        let middleware = TimeoutMiddleware::new();
        let ok_response = Response::builder()
            .status(StatusCode::OK)
            .body(axum::body::Body::empty())
            .unwrap();

        let processed = middleware.process_response(ok_response).await;
        assert_eq!(processed.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_timeout_middleware_name() {
        assert_eq!(TimeoutMiddleware::new().name(), "TimeoutMiddleware");
    }

    #[tokio::test]
    async fn test_apply_timeout_success() {
        let outcome =
            apply_timeout(async { "success" }, Duration::from_secs(1), "test timeout").await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "success");
    }

    #[tokio::test]
    async fn test_apply_timeout_failure() {
        let slow = async {
            sleep(TokioDuration::from_secs(2)).await;
            "should not reach here"
        };

        let outcome = apply_timeout(slow, Duration::from_millis(100), "test timeout").await;
        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err().status(), StatusCode::REQUEST_TIMEOUT);
    }

    #[tokio::test]
    async fn test_timeout_config_defaults() {
        let defaults = TimeoutConfig::default();

        assert_eq!(defaults.timeout, Duration::from_secs(30));
        assert!(defaults.log_timeouts);
        assert_eq!(defaults.timeout_message, "Request timed out");
    }

    #[tokio::test]
    async fn test_timeout_info_extension() {
        let middleware = TimeoutMiddleware::with_duration(Duration::from_secs(15));

        let processed = middleware
            .process_request(empty_request(Method::POST, "/api/test"))
            .await
            .unwrap();

        let info = processed.extensions().get::<TimeoutInfo>().unwrap();
        assert_eq!(info.duration, Duration::from_secs(15));
        assert_eq!(info.message, "Request timed out");
    }
}