elif_http/middleware/
timeout.rs

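//! # Timeout Middleware
//!
//! Framework middleware for request timeout handling, intended to replace tower-http's
//! `TimeoutLayer` with a framework-native implementation.
//!
//! [`TimeoutMiddleware`] only records the configured deadline in the request extensions
//! (as [`TimeoutInfo`]) and logs `408 Request Timeout` responses; the deadline itself is
//! enforced by [`apply_timeout`] and [`TimeoutHandler`].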
5
6use std::time::Duration;
7use tokio::time::{timeout, Timeout};
8use axum::{
9    extract::Request,
10    response::{Response, IntoResponse},
11    http::StatusCode,
12};
13use tracing::{warn, error};
14
15use crate::{
16    middleware::{Middleware, BoxFuture},
17    HttpError,
18};
19
/// Settings that control how request timeouts are applied and reported.
#[derive(Debug, Clone)]
pub struct TimeoutConfig {
    /// Maximum time a request is allowed to take.
    pub timeout: Duration,
    /// When `true`, timeout events are logged.
    pub log_timeouts: bool,
    /// Message placed in the timeout error response.
    pub timeout_message: String,
}

impl Default for TimeoutConfig {
    /// 30 second timeout, logging enabled, message `"Request timed out"`.
    fn default() -> Self {
        let timeout = Duration::from_secs(30);
        let timeout_message = String::from("Request timed out");
        Self {
            timeout,
            log_timeouts: true,
            timeout_message,
        }
    }
}

impl TimeoutConfig {
    /// Build a configuration with the given `timeout`; every other field keeps its default.
    pub fn new(timeout: Duration) -> Self {
        Self::default().with_timeout(timeout)
    }

    /// Return this configuration with its timeout duration replaced.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }

    /// Return this configuration with timeout logging switched on or off.
    pub fn with_logging(self, log_timeouts: bool) -> Self {
        Self {
            log_timeouts,
            ..self
        }
    }

    /// Return this configuration with a new timeout error message.
    pub fn with_message<S: Into<String>>(self, message: S) -> Self {
        Self {
            timeout_message: message.into(),
            ..self
        }
    }
}
68
69/// Framework timeout middleware for HTTP requests
70pub struct TimeoutMiddleware {
71    config: TimeoutConfig,
72}
73
74impl TimeoutMiddleware {
75    /// Create new timeout middleware with default 30 second timeout
76    pub fn new() -> Self {
77        Self {
78            config: TimeoutConfig::default(),
79        }
80    }
81
82    /// Create timeout middleware with specific duration
83    pub fn with_duration(timeout: Duration) -> Self {
84        Self {
85            config: TimeoutConfig::new(timeout),
86        }
87    }
88
89    /// Create timeout middleware with custom configuration
90    pub fn with_config(config: TimeoutConfig) -> Self {
91        Self { config }
92    }
93
94    /// Set timeout duration (builder pattern)
95    pub fn timeout(mut self, duration: Duration) -> Self {
96        self.config = self.config.with_timeout(duration);
97        self
98    }
99
100    /// Enable or disable logging (builder pattern) 
101    pub fn logging(mut self, enabled: bool) -> Self {
102        self.config = self.config.with_logging(enabled);
103        self
104    }
105
106    /// Set custom timeout message (builder pattern)
107    pub fn message<S: Into<String>>(mut self, message: S) -> Self {
108        self.config = self.config.with_message(message);
109        self
110    }
111
112    /// Get timeout duration
113    pub fn duration(&self) -> Duration {
114        self.config.timeout
115    }
116
117    /// Create timeout error response
118    fn timeout_response(&self) -> Response {
119        let error = HttpError::timeout(&self.config.timeout_message);
120        error.into_response()
121    }
122}
123
124impl Default for TimeoutMiddleware {
125    fn default() -> Self {
126        Self::new()
127    }
128}
129
130impl Middleware for TimeoutMiddleware {
131    fn process_request<'a>(
132        &'a self,
133        request: Request
134    ) -> BoxFuture<'a, Result<Request, Response>> {
135        Box::pin(async move {
136            // Store timeout duration in request extensions for downstream middleware
137            // This allows handlers to know the timeout that's been applied
138            let mut request = request;
139            request.extensions_mut().insert(TimeoutInfo {
140                duration: self.config.timeout,
141                message: self.config.timeout_message.clone(),
142            });
143
144            Ok(request)
145        })
146    }
147
148    fn process_response<'a>(
149        &'a self,
150        response: Response
151    ) -> BoxFuture<'a, Response> {
152        Box::pin(async move {
153            // For timeout middleware, response processing is mainly for logging
154            // The actual timeout handling happens at the handler level or higher
155            
156            if response.status() == StatusCode::REQUEST_TIMEOUT && self.config.log_timeouts {
157                warn!("Request timed out after {:?}", self.config.timeout);
158            }
159
160            response
161        })
162    }
163
164    fn name(&self) -> &'static str {
165        "TimeoutMiddleware"
166    }
167}
168
/// Timeout settings stored in the request extensions by [`TimeoutMiddleware`].
///
/// Read it in a handler with `request.extensions().get::<TimeoutInfo>()`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TimeoutInfo {
    /// Configured timeout duration.
    pub duration: Duration,
    /// Message configured for timeout error responses.
    pub message: String,
}
175
176/// Helper function to apply timeout to a future
177pub async fn apply_timeout<F, T>(
178    future: F,
179    duration: Duration,
180    timeout_message: &str,
181) -> Result<T, Response>
182where
183    F: std::future::Future<Output = T>,
184{
185    match timeout(duration, future).await {
186        Ok(result) => Ok(result),
187        Err(_) => {
188            error!("Request timed out after {:?}: {}", duration, timeout_message);
189            let error = HttpError::timeout(timeout_message);
190            Err(error.into_response())
191        }
192    }
193}
194
/// Wrapper that runs an inner service (typically a handler) under a deadline.
///
/// When used as a `tower::Service`, a call that exceeds `duration` resolves to `Err`
/// carrying the response built from `HttpError::timeout(message)`.
///
/// `Clone` is derived (bounded on `F: Clone`) so the wrapper can be used wherever the
/// service itself must be cloneable.
#[derive(Clone)]
pub struct TimeoutHandler<F> {
    /// Wrapped service.
    handler: F,
    /// Maximum time allowed for the wrapped service to respond.
    duration: Duration,
    /// Message used for the timeout error response.
    message: String,
}

impl<F> TimeoutHandler<F> {
    /// Wrap `handler` with the given timeout and the default message `"Request timed out"`.
    pub fn new(handler: F, duration: Duration) -> Self {
        Self {
            handler,
            duration,
            message: "Request timed out".to_string(),
        }
    }

    /// Replace the message used in the timeout error response.
    #[must_use]
    pub fn with_message<S: Into<String>>(mut self, message: S) -> Self {
        self.message = message.into();
        self
    }
}
216
217impl<F, Fut, T> tower::Service<Request> for TimeoutHandler<F>
218where
219    F: tower::Service<Request, Response = T, Future = Fut> + Clone + Send + 'static,
220    Fut: std::future::Future<Output = Result<T, F::Error>> + Send + 'static,
221    T: axum::response::IntoResponse,
222{
223    type Response = Response;
224    type Error = Response;
225    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
226
227    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
228        match self.handler.poll_ready(cx) {
229            std::task::Poll::Ready(Ok(())) => std::task::Poll::Ready(Ok(())),
230            std::task::Poll::Ready(Err(_)) => {
231                let error = HttpError::internal("Handler not ready");
232                std::task::Poll::Ready(Err(error.into_response()))
233            },
234            std::task::Poll::Pending => std::task::Poll::Pending,
235        }
236    }
237
238    fn call(&mut self, request: Request) -> Self::Future {
239        let handler = self.handler.clone();
240        let mut handler = handler;
241        let duration = self.duration;
242        let message = self.message.clone();
243
244        Box::pin(async move {
245            match timeout(duration, handler.call(request)).await {
246                Ok(Ok(response)) => Ok(response.into_response()),
247                Ok(Err(_)) => {
248                    let error = HttpError::internal("Handler error");
249                    Err(error.into_response())
250                },
251                Err(_) => {
252                    error!("Request timed out after {:?}: {}", duration, message);
253                    let error = HttpError::timeout(&message);
254                    Err(error.into_response())
255                }
256            }
257        })
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{Method, StatusCode};
    use tokio::time::{sleep, Duration as TokioDuration};
    use std::time::Duration;

    /// Build a request with an empty body for the given method and URI.
    fn empty_request(method: Method, uri: &str) -> Request {
        Request::builder()
            .method(method)
            .uri(uri)
            .body(axum::body::Body::empty())
            .unwrap()
    }

    /// Build a response with an empty body and the given status.
    fn empty_response(status: StatusCode) -> Response {
        Response::builder()
            .status(status)
            .body(axum::body::Body::empty())
            .unwrap()
    }

    #[tokio::test]
    async fn test_timeout_middleware_basic() {
        let middleware = TimeoutMiddleware::new();

        let result = middleware
            .process_request(empty_request(Method::GET, "/test"))
            .await;
        assert!(result.is_ok());

        let processed_request = result.unwrap();

        // Check that timeout info was added to extensions
        let timeout_info = processed_request
            .extensions()
            .get::<TimeoutInfo>()
            .expect("TimeoutInfo extension should be present");
        assert_eq!(timeout_info.duration, Duration::from_secs(30));
        assert_eq!(timeout_info.message, "Request timed out");
    }

    // Pure configuration checks need no async runtime, so they use plain #[test].
    #[test]
    fn test_timeout_middleware_custom_config() {
        let config = TimeoutConfig::new(Duration::from_secs(60))
            .with_logging(false)
            .with_message("Custom timeout");

        let middleware = TimeoutMiddleware::with_config(config);

        assert_eq!(middleware.duration(), Duration::from_secs(60));
        assert!(!middleware.config.log_timeouts);
        assert_eq!(middleware.config.timeout_message, "Custom timeout");
    }

    #[test]
    fn test_timeout_middleware_builder() {
        let middleware = TimeoutMiddleware::new()
            .timeout(Duration::from_secs(45))
            .logging(true)
            .message("Builder timeout");

        assert_eq!(middleware.duration(), Duration::from_secs(45));
        assert!(middleware.config.log_timeouts);
        assert_eq!(middleware.config.timeout_message, "Builder timeout");
    }

    #[tokio::test]
    async fn test_timeout_middleware_response() {
        let middleware = TimeoutMiddleware::new();

        let processed_response = middleware
            .process_response(empty_response(StatusCode::OK))
            .await;
        assert_eq!(processed_response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_timeout_middleware_passes_through_408() {
        // A 408 is only logged, never rewritten.
        let middleware = TimeoutMiddleware::new();

        let processed_response = middleware
            .process_response(empty_response(StatusCode::REQUEST_TIMEOUT))
            .await;
        assert_eq!(processed_response.status(), StatusCode::REQUEST_TIMEOUT);
    }

    #[test]
    fn test_timeout_middleware_name() {
        let middleware = TimeoutMiddleware::new();
        assert_eq!(middleware.name(), "TimeoutMiddleware");
    }

    #[tokio::test]
    async fn test_apply_timeout_success() {
        let future = async { "success" };
        let result = apply_timeout(future, Duration::from_secs(1), "test timeout").await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
    }

    #[tokio::test]
    async fn test_apply_timeout_failure() {
        let future = async {
            sleep(TokioDuration::from_secs(2)).await;
            "should not reach here"
        };

        let result = apply_timeout(future, Duration::from_millis(100), "test timeout").await;
        assert!(result.is_err());

        // Verify it's a timeout response
        let response = result.unwrap_err();
        assert_eq!(response.status(), StatusCode::REQUEST_TIMEOUT);
    }

    #[test]
    fn test_timeout_config_defaults() {
        let config = TimeoutConfig::default();

        assert_eq!(config.timeout, Duration::from_secs(30));
        assert!(config.log_timeouts);
        assert_eq!(config.timeout_message, "Request timed out");
    }

    #[test]
    fn test_timeout_config_new_keeps_other_defaults() {
        let config = TimeoutConfig::new(Duration::from_secs(5));

        assert_eq!(config.timeout, Duration::from_secs(5));
        assert!(config.log_timeouts);
        assert_eq!(config.timeout_message, "Request timed out");
    }

    #[tokio::test]
    async fn test_timeout_info_extension() {
        let middleware = TimeoutMiddleware::with_duration(Duration::from_secs(15));

        let processed_request = middleware
            .process_request(empty_request(Method::POST, "/api/test"))
            .await
            .unwrap();

        let timeout_info = processed_request
            .extensions()
            .get::<TimeoutInfo>()
            .unwrap();
        assert_eq!(timeout_info.duration, Duration::from_secs(15));
        assert_eq!(timeout_info.message, "Request timed out");
    }
}