rama_http/layer/timeout/
body.rs

1use http_body::Body;
2use pin_project_lite::pin_project;
3use rama_core::error::BoxError;
4use rama_http_types::dep::http_body;
5use std::{
6    future::Future,
7    pin::Pin,
8    task::{Context, Poll, ready},
9    time::Duration,
10};
11use tokio::time::{Sleep, sleep};
12
13pin_project! {
14    /// Middleware that applies a timeout to request and response bodies.
15    ///
16    /// Wrapper around a [`http_body::Body`] to time out if data is not ready within the specified duration.
17    ///
18    /// Bodies must produce data at most within the specified timeout.
19    /// If the body does not produce a requested data frame within the timeout period, it will return an error.
20    ///
21    /// # Differences from [`Timeout`][crate::timeout::Timeout]
22    ///
23    /// [`Timeout`][crate::timeout::Timeout] applies a timeout to the request future, not body.
24    /// That timeout is not reset when bytes are handled, whether the request is active or not.
25    /// Bodies are handled asynchronously outside of the tower stack's future and thus needs an additional timeout.
26    ///
27    /// This middleware will return a [`TimeoutError`].
28    ///
29    /// # Example
30    ///
31    /// ```
32    /// use bytes::Bytes;
33    /// use std::time::Duration;
34    /// use rama_http::{Request, Response};
35    /// use rama_http::dep::http_body_util::Full;
36    /// use rama_http::layer::timeout::RequestBodyTimeoutLayer;
37    /// use rama_core::{Layer, service::service_fn};
38    ///
39    /// async fn handle(_: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, std::convert::Infallible> {
40    ///     // ...
41    ///     # todo!()
42    /// }
43    ///
44    /// # #[tokio::main]
45    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
46    /// let svc = RequestBodyTimeoutLayer::new(
47    ///     // Timeout bodies after 30 seconds of inactivity
48    ///     Duration::from_secs(30)
49    /// ).layer(service_fn(handle));
50    /// # Ok(())
51    /// # }
52    /// ```
53    pub struct TimeoutBody<B> {
54        timeout: Duration,
55        #[pin]
56        sleep: Option<Sleep>,
57        #[pin]
58        body: B,
59    }
60}
61
62impl<B> TimeoutBody<B> {
63    /// Creates a new [`TimeoutBody`].
64    pub fn new(timeout: Duration, body: B) -> Self {
65        TimeoutBody {
66            timeout,
67            sleep: None,
68            body,
69        }
70    }
71}
72
73impl<B> Body for TimeoutBody<B>
74where
75    B: Body,
76    B::Error: Into<BoxError>,
77{
78    type Data = B::Data;
79    type Error = Box<dyn std::error::Error + Send + Sync>;
80
81    fn poll_frame(
82        self: Pin<&mut Self>,
83        cx: &mut Context<'_>,
84    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
85        let mut this = self.project();
86
87        // Start the `Sleep` if not active.
88        let sleep_pinned = if let Some(some) = this.sleep.as_mut().as_pin_mut() {
89            some
90        } else {
91            this.sleep.set(Some(sleep(*this.timeout)));
92            this.sleep.as_mut().as_pin_mut().unwrap()
93        };
94
95        // Error if the timeout has expired.
96        if let Poll::Ready(()) = sleep_pinned.poll(cx) {
97            return Poll::Ready(Some(Err(Box::new(TimeoutError(())))));
98        }
99
100        // Check for body data.
101        let frame = ready!(this.body.poll_frame(cx));
102        // A frame is ready. Reset the `Sleep`...
103        this.sleep.set(None);
104
105        Poll::Ready(frame.transpose().map_err(Into::into).transpose())
106    }
107}
108
/// Error produced by [`TimeoutBody`] when the wrapped body does not yield
/// before the configured timeout elapses.
#[derive(Debug)]
pub struct TimeoutError(());

impl std::fmt::Display for TimeoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("data was not received within the designated timeout")
    }
}

impl std::error::Error for TimeoutError {}

#[cfg(test)]
mod tests {
    use super::*;

    use bytes::Bytes;
    use pin_project_lite::pin_project;
    use rama_http_types::dep::{http_body::Frame, http_body_util::BodyExt};
    use std::fmt;

    /// Error type for [`MockBody`]; the mock never actually yields it.
    #[derive(Debug)]
    struct MockError;

    impl fmt::Display for MockError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("mock error")
        }
    }

    impl std::error::Error for MockError {}

    pin_project! {
        /// Body that stays pending until `sleep` completes, then yields an empty data frame.
        struct MockBody {
            #[pin]
            sleep: Sleep,
        }
    }

    impl MockBody {
        /// Creates a mock body that becomes ready after `delay`.
        fn delayed(delay: Duration) -> Self {
            Self { sleep: sleep(delay) }
        }
    }

    impl Body for MockBody {
        type Data = Bytes;
        type Error = MockError;

        fn poll_frame(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            match self.project().sleep.poll(cx) {
                Poll::Pending => Poll::Pending,
                Poll::Ready(()) => Poll::Ready(Some(Ok(Frame::data(Bytes::new())))),
            }
        }
    }

    #[tokio::test]
    async fn test_body_available_within_timeout() {
        // The body becomes ready (1s) before the timeout fires (2s).
        let body = TimeoutBody::new(
            Duration::from_secs(2),
            MockBody::delayed(Duration::from_secs(1)),
        );

        let result = body.boxed().frame().await.expect("no frame");
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_body_unavailable_within_timeout_error() {
        // The timeout fires (1s) before the body becomes ready (2s).
        let body = TimeoutBody::new(
            Duration::from_secs(1),
            MockBody::delayed(Duration::from_secs(2)),
        );

        let result = body.boxed().frame().await.expect("no frame");
        assert!(result.is_err());
    }
}