Skip to main content

reinhardt_middleware/
timeout.rs

1//! Timeout middleware for limiting request processing time
2//!
3//! This middleware wraps requests with a timeout, returning an error
4//! if the handler doesn't complete within the specified duration.
5
6use async_trait::async_trait;
7use hyper::StatusCode;
8use reinhardt_http::{Handler, Middleware, Request, Response, Result};
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::time::timeout;
12
/// Configuration for [`TimeoutMiddleware`].
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot build it
/// with a struct literal; use [`TimeoutConfig::new`] or [`Default::default`].
/// Two configurations compare equal when all their settings are equal.
///
/// # Examples
///
/// ```
/// use reinhardt_middleware::timeout::TimeoutConfig;
/// use std::time::Duration;
///
/// let config = TimeoutConfig::new(Duration::from_secs(30));
/// assert_eq!(config, TimeoutConfig::new(Duration::from_secs(30)));
/// ```
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TimeoutConfig {
	/// Maximum time the downstream handler may take before the middleware
	/// gives up and answers on its behalf.
	pub duration: Duration,
}
29
30impl TimeoutConfig {
31	/// Create a new timeout configuration
32	///
33	/// # Examples
34	///
35	/// ```
36	/// use reinhardt_middleware::timeout::TimeoutConfig;
37	/// use std::time::Duration;
38	///
39	/// let config = TimeoutConfig::new(Duration::from_secs(60));
40	/// ```
41	pub fn new(duration: Duration) -> Self {
42		Self { duration }
43	}
44}
45
46impl Default for TimeoutConfig {
47	fn default() -> Self {
48		Self {
49			duration: Duration::from_secs(30),
50		}
51	}
52}
53
54/// Timeout middleware
55///
56/// Wraps request processing with a timeout, returning REQUEST_TIMEOUT (408)
57/// if the handler doesn't complete within the configured duration.
58///
59/// # Examples
60///
61/// ```
62/// use reinhardt_middleware::timeout::{TimeoutMiddleware, TimeoutConfig};
63/// use std::time::Duration;
64///
65/// let config = TimeoutConfig::new(Duration::from_secs(30));
66/// let middleware = TimeoutMiddleware::new(config);
67/// ```
68pub struct TimeoutMiddleware {
69	config: TimeoutConfig,
70}
71
72impl TimeoutMiddleware {
73	/// Create a new timeout middleware
74	///
75	/// # Examples
76	///
77	/// ```
78	/// use reinhardt_middleware::timeout::{TimeoutMiddleware, TimeoutConfig};
79	/// use std::time::Duration;
80	///
81	/// let config = TimeoutConfig::new(Duration::from_secs(30));
82	/// let middleware = TimeoutMiddleware::new(config);
83	/// ```
84	pub fn new(config: TimeoutConfig) -> Self {
85		Self { config }
86	}
87}
88
89#[async_trait]
90impl Middleware for TimeoutMiddleware {
91	async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
92		match timeout(self.config.duration, next.handle(request)).await {
93			Ok(result) => result,
94			Err(_) => {
95				Ok(Response::new(StatusCode::REQUEST_TIMEOUT)
96					.with_body("Request Timeout".to_string()))
97			}
98		}
99	}
100}
101
#[cfg(test)]
mod tests {
	// `StatusCode`, `Duration`, `Arc` and the reinhardt types come from the
	// parent module's imports via this glob.
	use super::*;
	use bytes::Bytes;
	use hyper::{HeaderMap, Method, Version};
	use tokio::time::sleep;

	/// Handler that answers `200 OK` immediately.
	struct FastHandler;

	#[async_trait]
	impl Handler for FastHandler {
		async fn handle(&self, _request: Request) -> Result<Response> {
			Ok(Response::ok())
		}
	}

	/// Handler that sleeps for `delay` before answering `200 OK`.
	struct SlowHandler {
		delay: Duration,
	}

	#[async_trait]
	impl Handler for SlowHandler {
		async fn handle(&self, _request: Request) -> Result<Response> {
			sleep(self.delay).await;
			Ok(Response::ok())
		}
	}

	/// Builds the minimal `GET /test` HTTP/1.1 request used by every
	/// middleware test.
	fn test_request() -> Request {
		Request::builder()
			.method(Method::GET)
			.uri("/test")
			.version(Version::HTTP_11)
			.headers(HeaderMap::new())
			.body(Bytes::new())
			.build()
			.unwrap()
	}

	/// Sends `test_request()` through a `TimeoutMiddleware` limited to `limit`,
	/// with `handler` as the downstream handler.
	async fn run_with_timeout(limit: Duration, handler: Arc<dyn Handler>) -> Response {
		let middleware = TimeoutMiddleware::new(TimeoutConfig::new(limit));
		middleware.process(test_request(), handler).await.unwrap()
	}

	#[tokio::test]
	async fn test_fast_request_completes() {
		let response = run_with_timeout(Duration::from_secs(1), Arc::new(FastHandler)).await;

		assert_eq!(response.status, StatusCode::OK);
	}

	#[tokio::test]
	async fn test_slow_request_times_out() {
		let handler = Arc::new(SlowHandler {
			delay: Duration::from_millis(500),
		});
		let response = run_with_timeout(Duration::from_millis(100), handler).await;

		assert_eq!(response.status, StatusCode::REQUEST_TIMEOUT);
		assert_eq!(response.body, Bytes::from("Request Timeout"));
	}

	#[tokio::test]
	async fn test_request_just_within_timeout() {
		// The handler uses a quarter of the budget, leaving headroom for
		// scheduler jitter on loaded CI machines.
		let handler = Arc::new(SlowHandler {
			delay: Duration::from_millis(50),
		});
		let response = run_with_timeout(Duration::from_millis(200), handler).await;

		assert_eq!(response.status, StatusCode::OK);
	}

	// Pure configuration checks: no futures are awaited, so no runtime is needed.
	#[test]
	fn test_custom_timeout_duration() {
		let custom_duration = Duration::from_secs(5);
		let config = TimeoutConfig::new(custom_duration);

		assert_eq!(config.duration, custom_duration);
	}

	#[test]
	fn test_default_timeout_config() {
		let config = TimeoutConfig::default();

		assert_eq!(config.duration, Duration::from_secs(30));
	}
}