1use core::{
2 pin::Pin,
3 task::{Context, Poll},
4};
5
6use bytes::Bytes;
7use pin_project_lite::pin_project;
8use warp::{
9 hyper::{Body as HyperBody, Request as HyperRequest},
10 Buf, Error as WarpError, Stream,
11};
12
13pub mod error;
14pub mod utils;
15
16use error::Error;
17
pin_project! {
    /// A body that can be backed by one of several representations: a boxed
    /// [`Buf`], a contiguous [`Bytes`], a pinned boxed stream of [`Bytes`]
    /// chunks, or a hyper [`HyperBody`].
    ///
    /// `Buf` and `Bytes` bodies can be converted synchronously with
    /// `Body::to_bytes`; `Stream` and `HyperBody` bodies need
    /// `Body::to_bytes_async` (see `Body::require_to_bytes_async`).
    #[project = BodyProj]
    pub enum Body {
        // Any in-memory buffer, boxed so different `Buf` types share one variant.
        Buf { inner: Box<dyn Buf + Send + 'static> },
        // A contiguous byte buffer.
        Bytes { inner: Bytes },
        // A boxed stream of `Bytes` chunks; `#[pin]` makes the pin structural so
        // the stream can be polled through the `BodyProj` projection.
        Stream { #[pin] inner: Pin<Box<dyn Stream<Item = Result<Bytes, WarpError>> + Send + 'static>> },
        // hyper's body type, polled through its `Stream` implementation.
        HyperBody { #[pin] inner: HyperBody }
    }
}
28
29impl core::fmt::Debug for Body {
30 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
31 match self {
32 Self::Buf { inner } => f.debug_tuple("Buf").field(&inner.chunk()).finish(),
33 Self::Bytes { inner } => f.debug_tuple("Bytes").field(&inner).finish(),
34 Self::Stream { inner: _ } => write!(f, "Stream"),
35 Self::HyperBody { inner: _ } => write!(f, "HyperBody"),
36 }
37 }
38}
39
40impl core::fmt::Display for Body {
41 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
42 write!(f, "{self:?}")
43 }
44}
45
46impl Default for Body {
47 fn default() -> Self {
48 Self::Bytes {
49 inner: Bytes::default(),
50 }
51 }
52}
53
54impl Body {
56 pub fn with_buf(buf: impl Buf + Send + 'static) -> Self {
57 Self::Buf {
58 inner: Box::new(buf),
59 }
60 }
61
62 pub fn with_bytes(bytes: Bytes) -> Self {
63 Self::Bytes { inner: bytes }
64 }
65
66 pub fn with_stream(
67 stream: impl Stream<Item = Result<impl Buf + 'static, WarpError>> + Send + 'static,
68 ) -> Self {
69 Self::Stream {
70 inner: Box::pin(utils::buf_stream_to_bytes_stream(stream)),
71 }
72 }
73
74 pub fn with_hyper_body(hyper_body: HyperBody) -> Self {
75 Self::HyperBody { inner: hyper_body }
76 }
77}
78
79impl From<HyperBody> for Body {
80 fn from(body: HyperBody) -> Self {
81 Self::with_hyper_body(body)
82 }
83}
84
85impl Body {
86 pub fn require_to_bytes_async(&self) -> bool {
87 matches!(
88 self,
89 Self::Stream { inner: _ } | Self::HyperBody { inner: _ }
90 )
91 }
92
93 pub fn to_bytes(self) -> Bytes {
94 match self {
95 Self::Buf { inner } => utils::buf_to_bytes(inner),
96 Self::Bytes { inner } => inner,
97 Self::Stream { inner: _ } => panic!("Please call require_to_bytes_async first"),
98 Self::HyperBody { inner: _ } => panic!("Please call require_to_bytes_async first"),
99 }
100 }
101
102 pub async fn to_bytes_async(self) -> Result<Bytes, Error> {
103 match self {
104 Self::Buf { inner } => Ok(utils::buf_to_bytes(inner)),
105 Self::Bytes { inner } => Ok(inner),
106 Self::Stream { inner } => utils::bytes_stream_to_bytes(inner)
107 .await
108 .map_err(Into::into),
109 Self::HyperBody { inner } => {
110 utils::hyper_body_to_bytes(inner).await.map_err(Into::into)
111 }
112 }
113 }
114}
115
116impl Stream for Body {
120 type Item = Result<Bytes, Error>;
121
122 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
123 match self.project() {
124 BodyProj::Buf { inner: buf } => {
125 if buf.has_remaining() {
126 let bytes = Bytes::copy_from_slice(buf.chunk());
127 let cnt = buf.chunk().len();
128 buf.advance(cnt);
129 Poll::Ready(Some(Ok(bytes)))
130 } else {
131 Poll::Ready(None)
132 }
133 }
134 BodyProj::Bytes { inner } => {
135 if !inner.is_empty() {
136 let bytes = inner.clone();
137 inner.clear();
138 Poll::Ready(Some(Ok(bytes)))
139 } else {
140 Poll::Ready(None)
141 }
142 }
143 BodyProj::Stream { inner } => inner.poll_next(cx).map_err(Into::into),
144 BodyProj::HyperBody { inner } => inner.poll_next(cx).map_err(Into::into),
145 }
146 }
147}
148
149pub fn buf_request_to_body_request(
151 req: HyperRequest<impl Buf + Send + 'static>,
152) -> HyperRequest<Body> {
153 let (parts, body) = req.into_parts();
154 HyperRequest::from_parts(parts, Body::with_buf(body))
155}
156
157pub fn bytes_request_to_body_request(req: HyperRequest<Bytes>) -> HyperRequest<Body> {
158 let (parts, body) = req.into_parts();
159 HyperRequest::from_parts(parts, Body::with_bytes(body))
160}
161
162pub fn stream_request_to_body_request(
163 req: HyperRequest<impl Stream<Item = Result<impl Buf + 'static, WarpError>> + Send + 'static>,
164) -> HyperRequest<Body> {
165 let (parts, body) = req.into_parts();
166 HyperRequest::from_parts(parts, Body::with_stream(body))
167}
168
169pub fn hyper_body_request_to_body_request(req: HyperRequest<HyperBody>) -> HyperRequest<Body> {
170 let (parts, body) = req.into_parts();
171 HyperRequest::from_parts(parts, Body::with_hyper_body(body))
172}
173
#[cfg(test)]
mod tests {
    // Tests for each `Body` variant: construction, the `require_to_bytes_async`
    // flag, `to_bytes` / `to_bytes_async`, polling as a `Stream`, and the
    // matching `*_request_to_body_request` helper.
    use futures_util::{stream::BoxStream, StreamExt as _, TryStreamExt};

    use super::*;

    // `Body::Buf`, built from the `Buf` extracted by `warp::body::aggregate()`.
    #[tokio::test]
    async fn test_with_buf() {
        // Synchronous conversion is allowed for `Buf` bodies.
        let buf = warp::test::request()
            .body("foo")
            .filter(&warp::body::aggregate())
            .await
            .unwrap();
        let body = Body::with_buf(buf);
        assert!(matches!(body, Body::Buf { inner: _ }));
        assert!(!body.require_to_bytes_async());
        assert_eq!(body.to_bytes(), Bytes::copy_from_slice(b"foo"));

        // The async conversion also works for `Buf` bodies.
        let buf = warp::test::request()
            .body("foo")
            .filter(&warp::body::aggregate())
            .await
            .unwrap();
        let body = Body::with_buf(buf);
        assert_eq!(
            body.to_bytes_async().await.unwrap(),
            Bytes::copy_from_slice(b"foo")
        );

        // Polled as a `Stream`: one chunk with the data, then end of stream.
        let buf = warp::test::request()
            .body("foo")
            .filter(&warp::body::aggregate())
            .await
            .unwrap();
        let mut body = Body::with_buf(buf);
        assert_eq!(
            body.next().await.unwrap().unwrap(),
            Bytes::copy_from_slice(b"foo")
        );
        assert!(body.next().await.is_none());

        // Request-level helper: the extracted request's body becomes `Body::Buf`.
        let req = warp::test::request()
            .body("foo")
            .filter(&warp_filter_request::with_body_aggregate())
            .await
            .unwrap();
        let (_, body) = buf_request_to_body_request(req).into_parts();
        assert!(matches!(body, Body::Buf { inner: _ }));
        assert!(!body.require_to_bytes_async());
        assert_eq!(body.to_bytes(), Bytes::copy_from_slice(b"foo"));
    }

    // `Body::Bytes`, built from the `Bytes` extracted by `warp::body::bytes()`.
    #[tokio::test]
    async fn test_with_bytes() {
        // Synchronous conversion is allowed for `Bytes` bodies.
        let bytes = warp::test::request()
            .body("foo")
            .filter(&warp::body::bytes())
            .await
            .unwrap();
        let body = Body::with_bytes(bytes);
        assert!(matches!(body, Body::Bytes { inner: _ }));
        assert!(!body.require_to_bytes_async());
        assert_eq!(body.to_bytes(), Bytes::copy_from_slice(b"foo"));

        // The async conversion also works for `Bytes` bodies.
        let bytes = warp::test::request()
            .body("foo")
            .filter(&warp::body::bytes())
            .await
            .unwrap();
        let body = Body::with_bytes(bytes);
        assert_eq!(
            body.to_bytes_async().await.unwrap(),
            Bytes::copy_from_slice(b"foo")
        );

        // Polled as a `Stream`: the whole buffer once, then end of stream.
        let bytes = warp::test::request()
            .body("foo")
            .filter(&warp::body::bytes())
            .await
            .unwrap();
        let mut body = Body::with_bytes(bytes);
        assert_eq!(
            body.next().await.unwrap().unwrap(),
            Bytes::copy_from_slice(b"foo")
        );
        assert!(body.next().await.is_none());

        // Request-level helper: the extracted request's body becomes `Body::Bytes`.
        let req = warp::test::request()
            .body("foo")
            .filter(&warp_filter_request::with_body_bytes())
            .await
            .unwrap();
        let (_, body) = bytes_request_to_body_request(req).into_parts();
        assert!(matches!(body, Body::Bytes { inner: _ }));
        assert!(!body.require_to_bytes_async());
        assert_eq!(body.to_bytes(), Bytes::copy_from_slice(b"foo"));
    }

    // `Body::Stream`, built from the stream extracted by `warp::body::stream()`.
    #[tokio::test]
    async fn test_with_stream() {
        // `Stream` bodies report that they need the async conversion.
        let stream = warp::test::request()
            .body("foo")
            .filter(&warp::body::stream())
            .await
            .unwrap();
        let body = Body::with_stream(stream);
        assert!(matches!(body, Body::Stream { inner: _ }));
        assert!(body.require_to_bytes_async());
        assert_eq!(
            body.to_bytes_async().await.unwrap(),
            Bytes::copy_from_slice(b"foo")
        );

        // Polled as a `Stream`: this request body arrives as a single chunk.
        let stream = warp::test::request()
            .body("foo")
            .filter(&warp::body::stream())
            .await
            .unwrap();
        let mut body = Body::with_stream(stream);
        assert_eq!(
            body.next().await.unwrap().unwrap(),
            Bytes::copy_from_slice(b"foo")
        );
        assert!(body.next().await.is_none());

        // Request-level helper: the extracted request's body becomes `Body::Stream`.
        let req = warp::test::request()
            .body("foo")
            .filter(&warp_filter_request::with_body_stream())
            .await
            .unwrap();
        let (_, body) = stream_request_to_body_request(req).into_parts();
        assert!(matches!(body, Body::Stream { inner: _ }));
        assert!(body.require_to_bytes_async());
        assert_eq!(
            body.to_bytes_async().await.unwrap(),
            Bytes::copy_from_slice(b"foo")
        );
    }

    // `Body::HyperBody`, built directly from a hyper body.
    #[tokio::test]
    async fn test_with_hyper_body() {
        // `HyperBody` bodies report that they need the async conversion.
        let hyper_body = HyperBody::from("foo");
        let body = Body::with_hyper_body(hyper_body);
        assert!(matches!(body, Body::HyperBody { inner: _ }));
        assert!(body.require_to_bytes_async());
        assert_eq!(
            body.to_bytes_async().await.unwrap(),
            Bytes::copy_from_slice(b"foo")
        );

        // Polled as a `Stream`: one chunk with the data, then end of stream.
        let hyper_body = HyperBody::from("foo");
        let mut body = Body::with_hyper_body(hyper_body);
        assert_eq!(
            body.next().await.unwrap().unwrap(),
            Bytes::copy_from_slice(b"foo")
        );
        assert!(body.next().await.is_none());

        // Request-level helper: the request's body becomes `Body::HyperBody`.
        let req = HyperRequest::new(HyperBody::from("foo"));
        let (_, body) = hyper_body_request_to_body_request(req).into_parts();
        assert!(matches!(body, Body::HyperBody { inner: _ }));
        assert!(body.require_to_bytes_async());
        assert_eq!(
            body.to_bytes_async().await.unwrap(),
            Bytes::copy_from_slice(b"foo")
        );
    }

    // A wrapper holding a boxed, type-erased stream with boxed errors; used
    // below to check that a `Body` can be adapted into that shape.
    pin_project! {
        pub struct BodyWrapper {
            #[pin]
            inner: BoxStream<'static, Result<Bytes, Box<dyn std::error::Error + Send + Sync + 'static>>>
        }
    }
    // Compile-time check: `err_into()` + `boxed()` turn `Buf` and `Stream`
    // bodies into `BodyWrapper::inner`. The wrappers are never polled.
    #[tokio::test]
    async fn test_wrapper() {
        let buf = warp::test::request()
            .body("foo")
            .filter(&warp::body::aggregate())
            .await
            .unwrap();
        let body = Body::with_buf(buf);
        let _ = BodyWrapper {
            inner: body.err_into().boxed(),
        };

        let stream = warp::test::request()
            .body("foo")
            .filter(&warp::body::stream())
            .await
            .unwrap();
        let body = Body::with_stream(stream);
        let _ = BodyWrapper {
            inner: body.err_into().boxed(),
        };
    }
}