reinhardt_middleware/
timeout.rs1use async_trait::async_trait;
7use hyper::StatusCode;
8use reinhardt_http::{Handler, Middleware, Request, Response, Result};
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::time::timeout;
12
/// Settings for [`TimeoutMiddleware`].
///
/// Currently this only carries the deadline applied to the wrapped handler.
/// Marked `#[non_exhaustive]`, so code outside this crate builds it through
/// [`TimeoutConfig::new`] or [`Default`] rather than a struct literal.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct TimeoutConfig {
    /// Maximum time the downstream handler may take to produce a response.
    pub duration: Duration,
}

impl TimeoutConfig {
    /// Deadline used by the [`Default`] implementation: 30 seconds.
    const DEFAULT_DURATION: Duration = Duration::from_secs(30);

    /// Creates a configuration whose deadline is exactly `duration`.
    pub fn new(duration: Duration) -> Self {
        TimeoutConfig { duration }
    }
}

impl Default for TimeoutConfig {
    /// Returns a configuration with a 30-second deadline.
    fn default() -> Self {
        Self::new(Self::DEFAULT_DURATION)
    }
}
53
54pub struct TimeoutMiddleware {
69 config: TimeoutConfig,
70}
71
72impl TimeoutMiddleware {
73 pub fn new(config: TimeoutConfig) -> Self {
85 Self { config }
86 }
87}
88
89#[async_trait]
90impl Middleware for TimeoutMiddleware {
91 async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
92 match timeout(self.config.duration, next.handle(request)).await {
93 Ok(result) => result,
94 Err(_) => {
95 Ok(Response::new(StatusCode::REQUEST_TIMEOUT)
96 .with_body("Request Timeout".to_string()))
97 }
98 }
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, StatusCode, Version};
    use std::time::Duration;
    use tokio::time::sleep;

    /// Handler that answers `200 OK` immediately.
    struct FastHandler;

    #[async_trait]
    impl Handler for FastHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::ok())
        }
    }

    /// Handler that sleeps for `delay` before answering `200 OK`.
    struct SlowHandler {
        delay: Duration,
    }

    #[async_trait]
    impl Handler for SlowHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            sleep(self.delay).await;
            Ok(Response::ok())
        }
    }

    /// Builds the minimal `GET /test` HTTP/1.1 request shared by the middleware tests.
    fn get_test_request() -> Request {
        Request::builder()
            .method(Method::GET)
            .uri("/test")
            .version(Version::HTTP_11)
            .headers(HeaderMap::new())
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn test_fast_request_completes() {
        let middleware = TimeoutMiddleware::new(TimeoutConfig::new(Duration::from_secs(1)));

        let response = middleware
            .process(get_test_request(), Arc::new(FastHandler))
            .await
            .unwrap();

        assert_eq!(response.status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_slow_request_times_out() {
        let middleware = TimeoutMiddleware::new(TimeoutConfig::new(Duration::from_millis(100)));
        let slow = Arc::new(SlowHandler {
            delay: Duration::from_millis(500),
        });

        let response = middleware
            .process(get_test_request(), slow)
            .await
            .unwrap();

        assert_eq!(response.status, StatusCode::REQUEST_TIMEOUT);
        assert_eq!(response.body, Bytes::from("Request Timeout"));
    }

    #[tokio::test]
    async fn test_request_just_within_timeout() {
        let middleware = TimeoutMiddleware::new(TimeoutConfig::new(Duration::from_millis(200)));
        let slow = Arc::new(SlowHandler {
            delay: Duration::from_millis(50),
        });

        let response = middleware
            .process(get_test_request(), slow)
            .await
            .unwrap();

        assert_eq!(response.status, StatusCode::OK);
    }

    #[test]
    fn test_custom_timeout_duration() {
        let requested = Duration::from_secs(5);

        assert_eq!(TimeoutConfig::new(requested).duration, requested);
    }

    #[test]
    fn test_default_timeout_config() {
        assert_eq!(TimeoutConfig::default().duration, Duration::from_secs(30));
    }
}