tower_http/timeout/
body.rs1use crate::BoxError;
2use http_body::Body;
3use pin_project_lite::pin_project;
4use std::{
5 future::Future,
6 pin::Pin,
7 task::{ready, Context, Poll},
8 time::Duration,
9};
10use tokio::time::{sleep, Sleep};
11
12pin_project! {
13 pub struct TimeoutBody<B> {
54 timeout: Duration,
55 #[pin]
56 sleep: Option<Sleep>,
57 #[pin]
58 body: B,
59 }
60}
61
62impl<B> TimeoutBody<B> {
63 pub fn new(timeout: Duration, body: B) -> Self {
65 TimeoutBody {
66 timeout,
67 sleep: None,
68 body,
69 }
70 }
71}
72
73impl<B> Body for TimeoutBody<B>
74where
75 B: Body,
76 B::Error: Into<BoxError>,
77{
78 type Data = B::Data;
79 type Error = Box<dyn std::error::Error + Send + Sync>;
80
81 fn poll_frame(
82 self: Pin<&mut Self>,
83 cx: &mut Context<'_>,
84 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
85 let mut this = self.project();
86
87 let sleep_pinned = if let Some(some) = this.sleep.as_mut().as_pin_mut() {
89 some
90 } else {
91 this.sleep.set(Some(sleep(*this.timeout)));
92 this.sleep.as_mut().as_pin_mut().unwrap()
93 };
94
95 if let Poll::Ready(()) = sleep_pinned.poll(cx) {
97 return Poll::Ready(Some(Err(Box::new(TimeoutError(())))));
98 }
99
100 let frame = ready!(this.body.poll_frame(cx));
102 this.sleep.set(None);
104
105 Poll::Ready(frame.transpose().map_err(Into::into).transpose())
106 }
107}
108
/// Error produced by [`TimeoutBody`] when the inner body does not yield a
/// frame within the configured timeout.
#[derive(Debug)]
pub struct TimeoutError(());

impl std::error::Error for TimeoutError {}

impl std::fmt::Display for TimeoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("data was not received within the designated timeout")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    use bytes::Bytes;
    use http_body::Frame;
    use http_body_util::BodyExt;
    use pin_project_lite::pin_project;
    use std::{error::Error, fmt::Display};

    /// Error type for `MockBody`; `MockBody` never actually produces it.
    #[derive(Debug)]
    struct MockError;

    impl Error for MockError {}

    impl Display for MockError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "mock error")
        }
    }

    pin_project! {
        /// A body that yields one empty data frame once its `sleep` completes.
        struct MockBody {
            #[pin]
            sleep: Sleep
        }
    }

    impl Body for MockBody {
        type Data = Bytes;
        type Error = MockError;

        fn poll_frame(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
            let this = self.project();
            this.sleep
                .poll(cx)
                .map(|_| Some(Ok(Frame::data(vec![].into()))))
        }
    }

    #[tokio::test]
    async fn test_body_available_within_timeout() {
        let mock_sleep = Duration::from_secs(1);
        let timeout_sleep = Duration::from_secs(2);

        let mock_body = MockBody {
            sleep: sleep(mock_sleep),
        };
        let timeout_body = TimeoutBody::new(timeout_sleep, mock_body);

        assert!(timeout_body
            .boxed()
            .frame()
            .await
            .expect("no frame")
            .is_ok());
    }

    #[tokio::test]
    async fn test_body_unavailable_within_timeout_error() {
        let mock_sleep = Duration::from_secs(2);
        let timeout_sleep = Duration::from_secs(1);

        let mock_body = MockBody {
            sleep: sleep(mock_sleep),
        };
        let timeout_body = TimeoutBody::new(timeout_sleep, mock_body);

        let err = timeout_body
            .boxed()
            .frame()
            .await
            .expect("no frame")
            .expect_err("expected the body to time out");
        // Make sure the failure is the timeout itself, not some other error.
        assert!(err.is::<TimeoutError>(), "unexpected error: {err}");
    }
}