tower_http/timeout/service.rs

1use crate::timeout::body::TimeoutBody;
2use http::{Request, Response, StatusCode};
3use pin_project_lite::pin_project;
4use std::{
5    future::Future,
6    pin::Pin,
7    task::{ready, Context, Poll},
8    time::Duration,
9};
10use tokio::time::Sleep;
11use tower_layer::Layer;
12use tower_service::Service;
13
14/// Layer that applies the [`Timeout`] middleware which apply a timeout to requests.
15///
16/// See the [module docs](super) for an example.
17#[derive(Debug, Clone, Copy)]
18pub struct TimeoutLayer {
19    timeout: Duration,
20    status_code: StatusCode,
21}
22
23impl TimeoutLayer {
24    /// Creates a new [`TimeoutLayer`].
25    ///
26    /// By default, it will return a `408 Request Timeout` response if the request does not complete within the specified timeout.
27    /// To customize the response status code, use the `with_status_code` method.
28    #[deprecated(since = "0.6.7", note = "Use `TimeoutLayer::with_status_code` instead")]
29    pub fn new(timeout: Duration) -> Self {
30        Self::with_status_code(StatusCode::REQUEST_TIMEOUT, timeout)
31    }
32
33    /// Creates a new [`TimeoutLayer`] with the specified status code for the timeout response.
34    pub fn with_status_code(status_code: StatusCode, timeout: Duration) -> Self {
35        Self {
36            timeout,
37            status_code,
38        }
39    }
40}
41
42impl<S> Layer<S> for TimeoutLayer {
43    type Service = Timeout<S>;
44
45    fn layer(&self, inner: S) -> Self::Service {
46        Timeout::with_status_code(inner, self.status_code, self.timeout)
47    }
48}
49
50/// Middleware which apply a timeout to requests.
51///
52/// See the [module docs](super) for an example.
53#[derive(Debug, Clone, Copy)]
54pub struct Timeout<S> {
55    inner: S,
56    timeout: Duration,
57    status_code: StatusCode,
58}
59
60impl<S> Timeout<S> {
61    /// Creates a new [`Timeout`].
62    ///
63    /// By default, it will return a `408 Request Timeout` response if the request does not complete within the specified timeout.
64    /// To customize the response status code, use the `with_status_code` method.
65    #[deprecated(since = "0.6.7", note = "Use `Timeout::with_status_code` instead")]
66    pub fn new(inner: S, timeout: Duration) -> Self {
67        Self::with_status_code(inner, StatusCode::REQUEST_TIMEOUT, timeout)
68    }
69
70    /// Creates a new [`Timeout`] with the specified status code for the timeout response.
71    pub fn with_status_code(inner: S, status_code: StatusCode, timeout: Duration) -> Self {
72        Self {
73            inner,
74            timeout,
75            status_code,
76        }
77    }
78
79    define_inner_service_accessors!();
80
81    /// Returns a new [`Layer`] that wraps services with a `Timeout` middleware.
82    ///
83    /// [`Layer`]: tower_layer::Layer
84    #[deprecated(
85        since = "0.6.7",
86        note = "Use `Timeout::layer_with_status_code` instead"
87    )]
88    pub fn layer(timeout: Duration) -> TimeoutLayer {
89        TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, timeout)
90    }
91
92    /// Returns a new [`Layer`] that wraps services with a `Timeout` middleware with the specified status code.
93    pub fn layer_with_status_code(status_code: StatusCode, timeout: Duration) -> TimeoutLayer {
94        TimeoutLayer::with_status_code(status_code, timeout)
95    }
96}
97
98impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for Timeout<S>
99where
100    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
101    ResBody: Default,
102{
103    type Response = S::Response;
104    type Error = S::Error;
105    type Future = ResponseFuture<S::Future>;
106
107    #[inline]
108    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
109        self.inner.poll_ready(cx)
110    }
111
112    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
113        let sleep = tokio::time::sleep(self.timeout);
114        ResponseFuture {
115            inner: self.inner.call(req),
116            sleep,
117            status_code: self.status_code,
118        }
119    }
120}
121
122pin_project! {
123    /// Response future for [`Timeout`].
124    pub struct ResponseFuture<F> {
125        #[pin]
126        inner: F,
127        #[pin]
128        sleep: Sleep,
129        status_code: StatusCode,
130    }
131}
132
133impl<F, B, E> Future for ResponseFuture<F>
134where
135    F: Future<Output = Result<Response<B>, E>>,
136    B: Default,
137{
138    type Output = Result<Response<B>, E>;
139
140    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
141        let this = self.project();
142
143        if this.sleep.poll(cx).is_ready() {
144            let mut res = Response::new(B::default());
145            *res.status_mut() = *this.status_code;
146            return Poll::Ready(Ok(res));
147        }
148
149        this.inner.poll(cx)
150    }
151}
152
153/// Applies a [`TimeoutBody`] to the request body.
154#[derive(Clone, Debug)]
155pub struct RequestBodyTimeoutLayer {
156    timeout: Duration,
157}
158
159impl RequestBodyTimeoutLayer {
160    /// Creates a new [`RequestBodyTimeoutLayer`].
161    pub fn new(timeout: Duration) -> Self {
162        Self { timeout }
163    }
164}
165
166impl<S> Layer<S> for RequestBodyTimeoutLayer {
167    type Service = RequestBodyTimeout<S>;
168
169    fn layer(&self, inner: S) -> Self::Service {
170        RequestBodyTimeout::new(inner, self.timeout)
171    }
172}
173
174/// Applies a [`TimeoutBody`] to the request body.
175#[derive(Clone, Debug)]
176pub struct RequestBodyTimeout<S> {
177    inner: S,
178    timeout: Duration,
179}
180
181impl<S> RequestBodyTimeout<S> {
182    /// Creates a new [`RequestBodyTimeout`].
183    pub fn new(service: S, timeout: Duration) -> Self {
184        Self {
185            inner: service,
186            timeout,
187        }
188    }
189
190    /// Returns a new [`Layer`] that wraps services with a [`RequestBodyTimeoutLayer`] middleware.
191    ///
192    /// [`Layer`]: tower_layer::Layer
193    pub fn layer(timeout: Duration) -> RequestBodyTimeoutLayer {
194        RequestBodyTimeoutLayer::new(timeout)
195    }
196
197    define_inner_service_accessors!();
198}
199
200impl<S, ReqBody> Service<Request<ReqBody>> for RequestBodyTimeout<S>
201where
202    S: Service<Request<TimeoutBody<ReqBody>>>,
203{
204    type Response = S::Response;
205    type Error = S::Error;
206    type Future = S::Future;
207
208    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
209        self.inner.poll_ready(cx)
210    }
211
212    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
213        let req = req.map(|body| TimeoutBody::new(self.timeout, body));
214        self.inner.call(req)
215    }
216}
217
218/// Applies a [`TimeoutBody`] to the response body.
219#[derive(Clone)]
220pub struct ResponseBodyTimeoutLayer {
221    timeout: Duration,
222}
223
224impl ResponseBodyTimeoutLayer {
225    /// Creates a new [`ResponseBodyTimeoutLayer`].
226    pub fn new(timeout: Duration) -> Self {
227        Self { timeout }
228    }
229}
230
231impl<S> Layer<S> for ResponseBodyTimeoutLayer {
232    type Service = ResponseBodyTimeout<S>;
233
234    fn layer(&self, inner: S) -> Self::Service {
235        ResponseBodyTimeout::new(inner, self.timeout)
236    }
237}
238
239/// Applies a [`TimeoutBody`] to the response body.
240#[derive(Clone)]
241pub struct ResponseBodyTimeout<S> {
242    inner: S,
243    timeout: Duration,
244}
245
246impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for ResponseBodyTimeout<S>
247where
248    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
249{
250    type Response = Response<TimeoutBody<ResBody>>;
251    type Error = S::Error;
252    type Future = ResponseBodyTimeoutFuture<S::Future>;
253
254    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
255        self.inner.poll_ready(cx)
256    }
257
258    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
259        ResponseBodyTimeoutFuture {
260            inner: self.inner.call(req),
261            timeout: self.timeout,
262        }
263    }
264}
265
266impl<S> ResponseBodyTimeout<S> {
267    /// Creates a new [`ResponseBodyTimeout`].
268    pub fn new(service: S, timeout: Duration) -> Self {
269        Self {
270            inner: service,
271            timeout,
272        }
273    }
274
275    /// Returns a new [`Layer`] that wraps services with a [`ResponseBodyTimeoutLayer`] middleware.
276    ///
277    /// [`Layer`]: tower_layer::Layer
278    pub fn layer(timeout: Duration) -> ResponseBodyTimeoutLayer {
279        ResponseBodyTimeoutLayer::new(timeout)
280    }
281
282    define_inner_service_accessors!();
283}
284
285pin_project! {
286    /// Response future for [`ResponseBodyTimeout`].
287    pub struct ResponseBodyTimeoutFuture<Fut> {
288        #[pin]
289        inner: Fut,
290        timeout: Duration,
291    }
292}
293
294impl<Fut, ResBody, E> Future for ResponseBodyTimeoutFuture<Fut>
295where
296    Fut: Future<Output = Result<Response<ResBody>, E>>,
297{
298    type Output = Result<Response<TimeoutBody<ResBody>>, E>;
299
300    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
301        let timeout = self.timeout;
302        let this = self.project();
303        let res = ready!(this.inner.poll(cx))?;
304        Poll::Ready(Ok(res.map(|body| TimeoutBody::new(timeout, body))))
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::Body;
    use http::{Request, Response, StatusCode};
    use std::time::Duration;
    use tower::{BoxError, ServiceBuilder, ServiceExt};

    #[tokio::test]
    async fn request_completes_within_timeout() {
        let mut service = ServiceBuilder::new()
            .layer(TimeoutLayer::with_status_code(
                StatusCode::GATEWAY_TIMEOUT,
                Duration::from_secs(1),
            ))
            .service_fn(fast_handler);

        let request = Request::get("/").body(Body::empty()).unwrap();
        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn timeout_middleware_with_custom_status_code() {
        // Use a code other than the 408 default, so that a passing test
        // actually shows the configured status code is used.
        let mut service = Timeout::with_status_code(
            tower::service_fn(slow_handler),
            StatusCode::SERVICE_UNAVAILABLE,
            Duration::from_millis(10),
        );

        let request = Request::get("/").body(Body::empty()).unwrap();
        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    #[tokio::test]
    async fn timeout_response_has_empty_body() {
        let mut service = ServiceBuilder::new()
            .layer(TimeoutLayer::with_status_code(
                StatusCode::GATEWAY_TIMEOUT,
                Duration::from_millis(10),
            ))
            .service_fn(slow_handler);

        let request = Request::get("/").body(Body::empty()).unwrap();
        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::GATEWAY_TIMEOUT);

        // Verify the body is empty (default)
        use http_body_util::BodyExt;
        let body = res.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        assert!(bytes.is_empty());
    }

    #[tokio::test]
    async fn deprecated_new_method_compatibility() {
        #[allow(deprecated)]
        let layer = TimeoutLayer::new(Duration::from_millis(10));

        let mut service = ServiceBuilder::new().layer(layer).service_fn(slow_handler);

        let request = Request::get("/").body(Body::empty()).unwrap();
        let res = service.ready().await.unwrap().call(request).await.unwrap();

        // Should use default 408 status code
        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
    }

    #[tokio::test]
    async fn request_body_timeout_passes_request_through() {
        let mut service = ServiceBuilder::new()
            .layer(RequestBodyTimeoutLayer::new(Duration::from_secs(1)))
            .service_fn(|_req: Request<TimeoutBody<Body>>| async {
                Ok::<_, BoxError>(Response::new(Body::empty()))
            });

        let request = Request::get("/").body(Body::empty()).unwrap();
        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn response_body_timeout_preserves_status() {
        let mut service = ServiceBuilder::new()
            .layer(ResponseBodyTimeoutLayer::new(Duration::from_secs(1)))
            .service_fn(fast_handler);

        let request = Request::get("/").body(Body::empty()).unwrap();
        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    async fn slow_handler(_req: Request<Body>) -> Result<Response<Body>, BoxError> {
        tokio::time::sleep(Duration::from_secs(10)).await;
        Ok(Response::builder()
            .status(StatusCode::OK)
            .body(Body::empty())
            .unwrap())
    }

    async fn fast_handler(_req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::builder()
            .status(StatusCode::OK)
            .body(Body::empty())
            .unwrap())
    }
}