tower_http/timeout/
body.rs

1use crate::BoxError;
2use http_body::Body;
3use pin_project_lite::pin_project;
4use std::{
5    future::Future,
6    pin::Pin,
7    task::{ready, Context, Poll},
8    time::Duration,
9};
10use tokio::time::{sleep, Sleep};
11
12pin_project! {
13    /// Middleware that applies a timeout to request and response bodies.
14    ///
15    /// Wrapper around a [`Body`][`http_body::Body`] to time out if data is not ready within the specified duration.
16    /// The timeout is enforced between consecutive [`Frame`][`http_body::Frame`] polls, and it
17    /// resets after each poll.
18    /// The total time to produce a [`Body`][`http_body::Body`] could exceed the timeout duration without
19    /// timing out, as long as no single interval between polls exceeds the timeout.
20    ///
21    /// If the [`Body`][`http_body::Body`] does not produce a requested data frame within the timeout period, it will return a [`TimeoutError`].
22    ///
23    /// # Differences from [`Timeout`][crate::timeout::Timeout]
24    ///
25    /// [`Timeout`][crate::timeout::Timeout] applies a timeout to the request future, not body.
26    /// That timeout is not reset when bytes are handled, whether the request is active or not.
27    /// Bodies are handled asynchronously outside of the tower stack's future and thus needs an additional timeout.
28    ///
29    /// # Example
30    ///
31    /// ```
32    /// use http::{Request, Response};
33    /// use bytes::Bytes;
34    /// use http_body_util::Full;
35    /// use std::time::Duration;
36    /// use tower::ServiceBuilder;
37    /// use tower_http::timeout::RequestBodyTimeoutLayer;
38    ///
39    /// async fn handle(_: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, std::convert::Infallible> {
40    ///     // ...
41    ///     # todo!()
42    /// }
43    ///
44    /// # #[tokio::main]
45    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
46    /// let svc = ServiceBuilder::new()
47    ///     // Timeout bodies after 30 seconds of inactivity
48    ///     .layer(RequestBodyTimeoutLayer::new(Duration::from_secs(30)))
49    ///     .service_fn(handle);
50    /// # Ok(())
51    /// # }
52    /// ```
53    pub struct TimeoutBody<B> {
54        timeout: Duration,
55        #[pin]
56        sleep: Option<Sleep>,
57        #[pin]
58        body: B,
59    }
60}
61
62impl<B> TimeoutBody<B> {
63    /// Creates a new [`TimeoutBody`].
64    pub fn new(timeout: Duration, body: B) -> Self {
65        TimeoutBody {
66            timeout,
67            sleep: None,
68            body,
69        }
70    }
71}
72
73impl<B> Body for TimeoutBody<B>
74where
75    B: Body,
76    B::Error: Into<BoxError>,
77{
78    type Data = B::Data;
79    type Error = Box<dyn std::error::Error + Send + Sync>;
80
81    fn poll_frame(
82        self: Pin<&mut Self>,
83        cx: &mut Context<'_>,
84    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
85        let mut this = self.project();
86
87        // Start the `Sleep` if not active.
88        let sleep_pinned = if let Some(some) = this.sleep.as_mut().as_pin_mut() {
89            some
90        } else {
91            this.sleep.set(Some(sleep(*this.timeout)));
92            this.sleep.as_mut().as_pin_mut().unwrap()
93        };
94
95        // Error if the timeout has expired.
96        if let Poll::Ready(()) = sleep_pinned.poll(cx) {
97            return Poll::Ready(Some(Err(Box::new(TimeoutError(())))));
98        }
99
100        // Check for body data.
101        let frame = ready!(this.body.poll_frame(cx));
102        // A frame is ready. Reset the `Sleep`...
103        this.sleep.set(None);
104
105        Poll::Ready(frame.transpose().map_err(Into::into).transpose())
106    }
107}
108
/// Error produced by a [`TimeoutBody`] when the timeout elapses before data is received.
///
/// The private unit field prevents construction outside of this module.
#[derive(Debug)]
pub struct TimeoutError(());

impl std::fmt::Display for TimeoutError {
    /// Writes a fixed, human-readable description of the timeout.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("data was not received within the designated timeout")
    }
}

impl std::error::Error for TimeoutError {}
#[cfg(test)]
mod tests {
    use super::*;

    use bytes::Bytes;
    use http_body::Frame;
    use http_body_util::BodyExt;
    use pin_project_lite::pin_project;
    use std::{error::Error, fmt::Display};

    /// Error type of [`MockBody`]; the mock itself never returns it.
    #[derive(Debug)]
    struct MockError;

    impl Display for MockError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("mock error")
        }
    }

    impl Error for MockError {}

    pin_project! {
        /// Body that yields an empty data frame once its internal timer has elapsed.
        struct MockBody {
            #[pin]
            sleep: Sleep
        }
    }

    impl Body for MockBody {
        type Data = Bytes;
        type Error = MockError;

        fn poll_frame(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
            // Stay pending until the timer fires, then hand out an empty data frame.
            match self.project().sleep.poll(cx) {
                Poll::Ready(()) => Poll::Ready(Some(Ok(Frame::data(Bytes::new())))),
                Poll::Pending => Poll::Pending,
            }
        }
    }

    /// Builds a [`TimeoutBody`] limited to `timeout` around a [`MockBody`] that becomes ready
    /// after `delay`.
    fn delayed_body(timeout: Duration, delay: Duration) -> TimeoutBody<MockBody> {
        let inner = MockBody {
            sleep: sleep(delay),
        };
        TimeoutBody::new(timeout, inner)
    }

    #[tokio::test]
    async fn test_body_available_within_timeout() {
        // The body is ready after 1s, well inside the 2s timeout.
        let mut body = delayed_body(Duration::from_secs(2), Duration::from_secs(1)).boxed();

        let outcome = body.frame().await.expect("no frame");

        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_body_unavailable_within_timeout_error() {
        // The body needs 2s, but the timeout fires after 1s.
        let mut body = delayed_body(Duration::from_secs(1), Duration::from_secs(2)).boxed();

        let outcome = body.frame().await.unwrap();

        assert!(outcome.is_err());
    }
}