1use std::marker::PhantomData;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4use std::{fmt, mem};
5
6use bytes::{Bytes, BytesMut};
7use futures_core::Stream;
8use pin_project::{pin_project, project};
9
10use crate::error::Error;
11
/// Size information reported by a [`MessageBody`].
///
/// `Eq` is derived alongside `PartialEq` because every variant holds only
/// plain integers, so equality is total.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum BodySize {
    /// No body at all.
    None,
    /// A body that is known to contain zero bytes.
    Empty,
    /// A body with a known length in bytes, stored as `usize`.
    Sized(usize),
    /// A body with a known length in bytes, stored as `u64`.
    Sized64(u64),
    /// A body whose total length is not known up front.
    Stream,
}

impl BodySize {
    /// Returns `true` when this size guarantees there are no body bytes:
    /// `None`, `Empty`, or a sized body of length zero.
    ///
    /// `Stream` is never considered EOF here, because its length is unknown.
    pub fn is_eof(&self) -> bool {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            BodySize::None | BodySize::Empty | BodySize::Sized(0) | BodySize::Sized64(0) => true,
            BodySize::Sized(_) | BodySize::Sized64(_) | BodySize::Stream => false,
        }
    }
}
33
34pub trait MessageBody {
36 fn size(&self) -> BodySize;
37
38 fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>>;
39}
40
41impl MessageBody for () {
42 fn size(&self) -> BodySize {
43 BodySize::Empty
44 }
45
46 fn poll_next(&mut self, _: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> {
47 Poll::Ready(None)
48 }
49}
50
51impl<T: MessageBody> MessageBody for Box<T> {
52 fn size(&self) -> BodySize {
53 self.as_ref().size()
54 }
55
56 fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> {
57 self.as_mut().poll_next(cx)
58 }
59}
60
61#[pin_project]
62pub enum ResponseBody<B> {
63 Body(B),
64 Other(Body),
65}
66
67impl ResponseBody<Body> {
68 pub fn into_body<B>(self) -> ResponseBody<B> {
69 match self {
70 ResponseBody::Body(b) => ResponseBody::Other(b),
71 ResponseBody::Other(b) => ResponseBody::Other(b),
72 }
73 }
74}
75
76impl<B> ResponseBody<B> {
77 pub fn take_body(&mut self) -> ResponseBody<B> {
78 std::mem::replace(self, ResponseBody::Other(Body::None))
79 }
80}
81
82impl<B: MessageBody> ResponseBody<B> {
83 pub fn as_ref(&self) -> Option<&B> {
84 if let ResponseBody::Body(ref b) = self {
85 Some(b)
86 } else {
87 None
88 }
89 }
90}
91
92impl<B: MessageBody> MessageBody for ResponseBody<B> {
93 fn size(&self) -> BodySize {
94 match self {
95 ResponseBody::Body(ref body) => body.size(),
96 ResponseBody::Other(ref body) => body.size(),
97 }
98 }
99
100 fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> {
101 match self {
102 ResponseBody::Body(ref mut body) => body.poll_next(cx),
103 ResponseBody::Other(ref mut body) => body.poll_next(cx),
104 }
105 }
106}
107
108impl<B: MessageBody> Stream for ResponseBody<B> {
109 type Item = Result<Bytes, Error>;
110
111 #[project]
112 fn poll_next(
113 self: Pin<&mut Self>,
114 cx: &mut Context<'_>,
115 ) -> Poll<Option<Self::Item>> {
116 #[project]
117 match self.project() {
118 ResponseBody::Body(ref mut body) => body.poll_next(cx),
119 ResponseBody::Other(ref mut body) => body.poll_next(cx),
120 }
121 }
122}
123
124pub enum Body {
126 None,
128 Empty,
130 Bytes(Bytes),
132 Message(Box<dyn MessageBody>),
134}
135
136impl Body {
137 pub fn from_slice(s: &[u8]) -> Body {
139 Body::Bytes(Bytes::copy_from_slice(s))
140 }
141
142 pub fn from_message<B: MessageBody + 'static>(body: B) -> Body {
144 Body::Message(Box::new(body))
145 }
146}
147
148impl MessageBody for Body {
149 fn size(&self) -> BodySize {
150 match self {
151 Body::None => BodySize::None,
152 Body::Empty => BodySize::Empty,
153 Body::Bytes(ref bin) => BodySize::Sized(bin.len()),
154 Body::Message(ref body) => body.size(),
155 }
156 }
157
158 fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> {
159 match self {
160 Body::None => Poll::Ready(None),
161 Body::Empty => Poll::Ready(None),
162 Body::Bytes(ref mut bin) => {
163 let len = bin.len();
164 if len == 0 {
165 Poll::Ready(None)
166 } else {
167 Poll::Ready(Some(Ok(mem::replace(bin, Bytes::new()))))
168 }
169 }
170 Body::Message(ref mut body) => body.poll_next(cx),
171 }
172 }
173}
174
175impl PartialEq for Body {
176 fn eq(&self, other: &Body) -> bool {
177 match *self {
178 Body::None => match *other {
179 Body::None => true,
180 _ => false,
181 },
182 Body::Empty => match *other {
183 Body::Empty => true,
184 _ => false,
185 },
186 Body::Bytes(ref b) => match *other {
187 Body::Bytes(ref b2) => b == b2,
188 _ => false,
189 },
190 Body::Message(_) => false,
191 }
192 }
193}
194
195impl fmt::Debug for Body {
196 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
197 match *self {
198 Body::None => write!(f, "Body::None"),
199 Body::Empty => write!(f, "Body::Empty"),
200 Body::Bytes(ref b) => write!(f, "Body::Bytes({:?})", b),
201 Body::Message(_) => write!(f, "Body::Message(_)"),
202 }
203 }
204}
205
206impl From<&'static str> for Body {
207 fn from(s: &'static str) -> Body {
208 Body::Bytes(Bytes::from_static(s.as_ref()))
209 }
210}
211
212impl From<&'static [u8]> for Body {
213 fn from(s: &'static [u8]) -> Body {
214 Body::Bytes(Bytes::from_static(s))
215 }
216}
217
218impl From<Vec<u8>> for Body {
219 fn from(vec: Vec<u8>) -> Body {
220 Body::Bytes(Bytes::from(vec))
221 }
222}
223
224impl From<String> for Body {
225 fn from(s: String) -> Body {
226 s.into_bytes().into()
227 }
228}
229
230impl<'a> From<&'a String> for Body {
231 fn from(s: &'a String) -> Body {
232 Body::Bytes(Bytes::copy_from_slice(AsRef::<[u8]>::as_ref(&s)))
233 }
234}
235
236impl From<Bytes> for Body {
237 fn from(s: Bytes) -> Body {
238 Body::Bytes(s)
239 }
240}
241
242impl From<BytesMut> for Body {
243 fn from(s: BytesMut) -> Body {
244 Body::Bytes(s.freeze())
245 }
246}
247
248impl From<serde_json::Value> for Body {
249 fn from(v: serde_json::Value) -> Body {
250 Body::Bytes(v.to_string().into())
251 }
252}
253
254impl<S> From<SizedStream<S>> for Body
255where
256 S: Stream<Item = Result<Bytes, Error>> + 'static,
257{
258 fn from(s: SizedStream<S>) -> Body {
259 Body::from_message(s)
260 }
261}
262
263impl<S, E> From<BodyStream<S, E>> for Body
264where
265 S: Stream<Item = Result<Bytes, E>> + 'static,
266 E: Into<Error> + 'static,
267{
268 fn from(s: BodyStream<S, E>) -> Body {
269 Body::from_message(s)
270 }
271}
272
273impl MessageBody for Bytes {
274 fn size(&self) -> BodySize {
275 BodySize::Sized(self.len())
276 }
277
278 fn poll_next(&mut self, _: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> {
279 if self.is_empty() {
280 Poll::Ready(None)
281 } else {
282 Poll::Ready(Some(Ok(mem::replace(self, Bytes::new()))))
283 }
284 }
285}
286
287impl MessageBody for BytesMut {
288 fn size(&self) -> BodySize {
289 BodySize::Sized(self.len())
290 }
291
292 fn poll_next(&mut self, _: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> {
293 if self.is_empty() {
294 Poll::Ready(None)
295 } else {
296 Poll::Ready(Some(Ok(mem::replace(self, BytesMut::new()).freeze())))
297 }
298 }
299}
300
301impl MessageBody for &'static str {
302 fn size(&self) -> BodySize {
303 BodySize::Sized(self.len())
304 }
305
306 fn poll_next(&mut self, _: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> {
307 if self.is_empty() {
308 Poll::Ready(None)
309 } else {
310 Poll::Ready(Some(Ok(Bytes::from_static(
311 mem::replace(self, "").as_ref(),
312 ))))
313 }
314 }
315}
316
317impl MessageBody for &'static [u8] {
318 fn size(&self) -> BodySize {
319 BodySize::Sized(self.len())
320 }
321
322 fn poll_next(&mut self, _: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> {
323 if self.is_empty() {
324 Poll::Ready(None)
325 } else {
326 Poll::Ready(Some(Ok(Bytes::from_static(mem::replace(self, b"")))))
327 }
328 }
329}
330
331impl MessageBody for Vec<u8> {
332 fn size(&self) -> BodySize {
333 BodySize::Sized(self.len())
334 }
335
336 fn poll_next(&mut self, _: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> {
337 if self.is_empty() {
338 Poll::Ready(None)
339 } else {
340 Poll::Ready(Some(Ok(Bytes::from(mem::replace(self, Vec::new())))))
341 }
342 }
343}
344
345impl MessageBody for String {
346 fn size(&self) -> BodySize {
347 BodySize::Sized(self.len())
348 }
349
350 fn poll_next(&mut self, _: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> {
351 if self.is_empty() {
352 Poll::Ready(None)
353 } else {
354 Poll::Ready(Some(Ok(Bytes::from(
355 mem::replace(self, String::new()).into_bytes(),
356 ))))
357 }
358 }
359}
360
361#[pin_project]
364pub struct BodyStream<S, E> {
365 #[pin]
366 stream: S,
367 _t: PhantomData<E>,
368}
369
370impl<S, E> BodyStream<S, E>
371where
372 S: Stream<Item = Result<Bytes, E>>,
373 E: Into<Error>,
374{
375 pub fn new(stream: S) -> Self {
376 BodyStream {
377 stream,
378 _t: PhantomData,
379 }
380 }
381}
382
383impl<S, E> MessageBody for BodyStream<S, E>
384where
385 S: Stream<Item = Result<Bytes, E>>,
386 E: Into<Error>,
387{
388 fn size(&self) -> BodySize {
389 BodySize::Stream
390 }
391
392 fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> {
393 unsafe { Pin::new_unchecked(self) }
394 .project()
395 .stream
396 .poll_next(cx)
397 .map(|res| res.map(|res| res.map_err(std::convert::Into::into)))
398 }
399}
400
401#[pin_project]
404pub struct SizedStream<S> {
405 size: u64,
406 #[pin]
407 stream: S,
408}
409
410impl<S> SizedStream<S>
411where
412 S: Stream<Item = Result<Bytes, Error>>,
413{
414 pub fn new(size: u64, stream: S) -> Self {
415 SizedStream { size, stream }
416 }
417}
418
419impl<S> MessageBody for SizedStream<S>
420where
421 S: Stream<Item = Result<Bytes, Error>>,
422{
423 fn size(&self) -> BodySize {
424 BodySize::Sized64(self.size)
425 }
426
427 fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> {
428 unsafe { Pin::new_unchecked(self) }
429 .project()
430 .stream
431 .poll_next(cx)
432 }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::future::poll_fn;

    impl Body {
        /// Test helper: the payload of a `Body::Bytes`; panics for other variants.
        pub(crate) fn get_ref(&self) -> &[u8] {
            match *self {
                Body::Bytes(ref bin) => &bin,
                _ => panic!(),
            }
        }
    }

    impl ResponseBody<Body> {
        /// Test helper: the `Bytes` payload of either variant; panics otherwise.
        pub(crate) fn get_ref(&self) -> &[u8] {
            match *self {
                ResponseBody::Body(ref b) => b.get_ref(),
                ResponseBody::Other(ref b) => b.get_ref(),
            }
        }
    }

    #[requiem_rt::test]
    async fn test_static_str() {
        assert_eq!(Body::from("").size(), BodySize::Sized(0));
        assert_eq!(Body::from("test").size(), BodySize::Sized(4));
        assert_eq!(Body::from("test").get_ref(), b"test");

        assert_eq!("test".size(), BodySize::Sized(4));
        assert_eq!(
            poll_fn(|cx| "test".poll_next(cx)).await.unwrap().ok(),
            Some(Bytes::from("test"))
        );
    }

    #[requiem_rt::test]
    async fn test_static_bytes() {
        assert_eq!(Body::from(b"test".as_ref()).size(), BodySize::Sized(4));
        assert_eq!(Body::from(b"test".as_ref()).get_ref(), b"test");
        assert_eq!(
            Body::from_slice(b"test".as_ref()).size(),
            BodySize::Sized(4)
        );
        assert_eq!(Body::from_slice(b"test".as_ref()).get_ref(), b"test");

        assert_eq!((&b"test"[..]).size(), BodySize::Sized(4));
        assert_eq!(
            poll_fn(|cx| (&b"test"[..]).poll_next(cx))
                .await
                .unwrap()
                .ok(),
            Some(Bytes::from("test"))
        );
    }

    #[requiem_rt::test]
    async fn test_vec() {
        assert_eq!(Body::from(Vec::from("test")).size(), BodySize::Sized(4));
        assert_eq!(Body::from(Vec::from("test")).get_ref(), b"test");

        assert_eq!(Vec::from("test").size(), BodySize::Sized(4));
        assert_eq!(
            poll_fn(|cx| Vec::from("test").poll_next(cx))
                .await
                .unwrap()
                .ok(),
            Some(Bytes::from("test"))
        );
    }

    #[requiem_rt::test]
    async fn test_bytes() {
        let mut b = Bytes::from("test");
        assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4));
        assert_eq!(Body::from(b.clone()).get_ref(), b"test");

        assert_eq!(b.size(), BodySize::Sized(4));
        assert_eq!(
            poll_fn(|cx| b.poll_next(cx)).await.unwrap().ok(),
            Some(Bytes::from("test"))
        );
    }

    #[requiem_rt::test]
    async fn test_bytes_mut() {
        let mut b = BytesMut::from("test");
        assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4));
        assert_eq!(Body::from(b.clone()).get_ref(), b"test");

        assert_eq!(b.size(), BodySize::Sized(4));
        assert_eq!(
            poll_fn(|cx| b.poll_next(cx)).await.unwrap().ok(),
            Some(Bytes::from("test"))
        );
    }

    #[requiem_rt::test]
    async fn test_string() {
        let mut b = "test".to_owned();
        assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4));
        assert_eq!(Body::from(b.clone()).get_ref(), b"test");
        assert_eq!(Body::from(&b).size(), BodySize::Sized(4));
        assert_eq!(Body::from(&b).get_ref(), b"test");

        assert_eq!(b.size(), BodySize::Sized(4));
        assert_eq!(
            poll_fn(|cx| b.poll_next(cx)).await.unwrap().ok(),
            Some(Bytes::from("test"))
        );
    }

    #[requiem_rt::test]
    async fn test_unit() {
        assert_eq!(().size(), BodySize::Empty);
        assert!(poll_fn(|cx| ().poll_next(cx)).await.is_none());
    }

    #[requiem_rt::test]
    async fn test_box() {
        let mut val = Box::new(());
        assert_eq!(val.size(), BodySize::Empty);
        assert!(poll_fn(|cx| val.poll_next(cx)).await.is_none());
    }

    /// `BodyStream` over a `!Unpin` stream (an async block) exercises the
    /// pinning path and the error conversion in `poll_next`.
    #[requiem_rt::test]
    async fn test_body_stream() {
        let mut body = BodyStream::new(futures_util::stream::once(async {
            Ok::<_, Error>(Bytes::from_static(b"1"))
        }));
        assert_eq!(body.size(), BodySize::Stream);
        assert_eq!(
            poll_fn(|cx| body.poll_next(cx)).await.unwrap().ok(),
            Some(Bytes::from_static(b"1"))
        );
        assert!(poll_fn(|cx| body.poll_next(cx)).await.is_none());
    }

    /// `SizedStream` reports its declared size and forwards stream items.
    #[requiem_rt::test]
    async fn test_sized_stream() {
        let mut body = SizedStream::new(
            4,
            futures_util::stream::once(async { Ok::<_, Error>(Bytes::from_static(b"test")) }),
        );
        assert_eq!(body.size(), BodySize::Sized64(4));
        assert_eq!(
            poll_fn(|cx| body.poll_next(cx)).await.unwrap().ok(),
            Some(Bytes::from_static(b"test"))
        );
        assert!(poll_fn(|cx| body.poll_next(cx)).await.is_none());
    }

    #[requiem_rt::test]
    async fn test_body_eq() {
        assert!(Body::None == Body::None);
        assert!(Body::None != Body::Empty);
        assert!(Body::Empty == Body::Empty);
        assert!(Body::Empty != Body::None);
        assert!(
            Body::Bytes(Bytes::from_static(b"1"))
                == Body::Bytes(Bytes::from_static(b"1"))
        );
        assert!(Body::Bytes(Bytes::from_static(b"1")) != Body::None);
    }

    #[requiem_rt::test]
    async fn test_body_debug() {
        assert!(format!("{:?}", Body::None).contains("Body::None"));
        assert!(format!("{:?}", Body::Empty).contains("Body::Empty"));
        assert!(format!("{:?}", Body::Bytes(Bytes::from_static(b"1"))).contains("1"));
    }

    #[requiem_rt::test]
    async fn test_serde_json() {
        use serde_json::json;
        assert_eq!(
            Body::from(serde_json::Value::String("test".into())).size(),
            BodySize::Sized(6)
        );
        assert_eq!(
            Body::from(json!({"test-key":"test-value"})).size(),
            BodySize::Sized(25)
        );
    }
}