rama_http/layer/timeout/
body.rs1use http_body::Body;
2use pin_project_lite::pin_project;
3use rama_core::error::BoxError;
4use rama_http_types::dep::http_body;
5use std::{
6 future::Future,
7 pin::Pin,
8 task::{Context, Poll, ready},
9 time::Duration,
10};
11use tokio::time::{Sleep, sleep};
12
13pin_project! {
14 pub struct TimeoutBody<B> {
54 timeout: Duration,
55 #[pin]
56 sleep: Option<Sleep>,
57 #[pin]
58 body: B,
59 }
60}
61
62impl<B> TimeoutBody<B> {
63 pub fn new(timeout: Duration, body: B) -> Self {
65 TimeoutBody {
66 timeout,
67 sleep: None,
68 body,
69 }
70 }
71}
72
73impl<B> Body for TimeoutBody<B>
74where
75 B: Body,
76 B::Error: Into<BoxError>,
77{
78 type Data = B::Data;
79 type Error = Box<dyn std::error::Error + Send + Sync>;
80
81 fn poll_frame(
82 self: Pin<&mut Self>,
83 cx: &mut Context<'_>,
84 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
85 let mut this = self.project();
86
87 let sleep_pinned = if let Some(some) = this.sleep.as_mut().as_pin_mut() {
89 some
90 } else {
91 this.sleep.set(Some(sleep(*this.timeout)));
92 this.sleep.as_mut().as_pin_mut().unwrap()
93 };
94
95 if let Poll::Ready(()) = sleep_pinned.poll(cx) {
97 return Poll::Ready(Some(Err(Box::new(TimeoutError(())))));
98 }
99
100 let frame = ready!(this.body.poll_frame(cx));
102 this.sleep.set(None);
104
105 Poll::Ready(frame.transpose().map_err(Into::into).transpose())
106 }
107}
108
/// Error yielded by `TimeoutBody` when the wrapped body does not produce a
/// frame before the configured timeout elapses.
#[derive(Debug)]
pub struct TimeoutError(());

impl std::fmt::Display for TimeoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("data was not received within the designated timeout")
    }
}

impl std::error::Error for TimeoutError {}
#[cfg(test)]
mod tests {
    use super::*;

    use bytes::Bytes;
    use pin_project_lite::pin_project;
    use rama_http_types::dep::{http_body::Frame, http_body_util::BodyExt};
    use std::{error::Error, fmt::Display};

    /// Error type required by the `Body` impl of `MockBody`; it is never produced.
    #[derive(Debug)]
    struct MockError;

    impl Display for MockError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "mock error")
        }
    }

    impl Error for MockError {}

    pin_project! {
        // Body that yields an empty data frame once its inner sleep has completed.
        struct MockBody {
            #[pin]
            sleep: Sleep
        }
    }

    impl Body for MockBody {
        type Data = Bytes;
        type Error = MockError;

        fn poll_frame(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            // Stay pending until the mock delay has elapsed.
            ready!(self.project().sleep.poll(cx));
            Poll::Ready(Some(Ok(Frame::data(Bytes::new()))))
        }
    }

    /// Builds a `TimeoutBody` around a `MockBody` that needs `body_delay` to
    /// produce its frame, guarded by the given `timeout`.
    fn mock_timeout_body(body_delay: Duration, timeout: Duration) -> TimeoutBody<MockBody> {
        let mock = MockBody {
            sleep: sleep(body_delay),
        };
        TimeoutBody::new(timeout, mock)
    }

    #[tokio::test]
    async fn test_body_available_within_timeout() {
        let body = mock_timeout_body(Duration::from_secs(1), Duration::from_secs(2));

        let outcome = body.boxed().frame().await.expect("no frame");
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_body_unavailable_within_timeout_error() {
        let body = mock_timeout_body(Duration::from_secs(2), Duration::from_secs(1));

        let outcome = body.boxed().frame().await.unwrap();
        assert!(outcome.is_err());
    }
}