use std::error::Error as StdError;
use std::fmt::{Debug, Formatter};
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::Bytes;
use futures_core::Stream;
use http_body::{Frame, SizeHint};
use http_body_util::combinators::BoxBody;
use sync_wrapper::SyncWrapper;
use crate::error::error_impl::impl_into_cot_error;
use crate::{Error, Result};
/// An HTTP message body.
///
/// The data is held in a crate-private representation that can be a fixed,
/// fully-buffered chunk of bytes, a user-supplied stream of byte chunks, an
/// axum body, or a type-erased `http_body` body.
#[derive(Debug)]
pub struct Body {
    /// The underlying representation of this body.
    pub(crate) inner: BodyInner,
}
/// The concrete representation backing a [`Body`].
pub(crate) enum BodyInner {
    /// A fully-buffered body. When polled it yields its bytes as a single data
    /// frame and is left empty afterwards.
    Fixed(Bytes),
    /// A pinned, boxed stream of byte chunks; each item is forwarded as a data
    /// frame (or error) when the body is polled.
    ///
    /// NOTE(review): the stream is only required to be `Send`; the
    /// `SyncWrapper` is presumably there so that `Body` itself can be `Sync`.
    Streaming(SyncWrapper<Pin<Box<dyn Stream<Item = Result<Bytes>> + Send>>>),
    /// A body produced by axum; its errors are wrapped in `ReadRequestBody`
    /// when polled.
    Axum(SyncWrapper<axum::body::Body>),
    /// Any other `http_body` body, boxed and type-erased with cot's `Error` as
    /// its error type.
    Wrapper(BoxBody<Bytes, Error>),
}
impl Debug for BodyInner {
    /// Formats the variant name as a tuple, followed by its payload.
    ///
    /// The `Streaming` stream and the `Wrapper` boxed body are not inspected;
    /// they are rendered as the placeholder string `"..."` instead.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        static OPAQUE: &str = "...";

        let (variant, payload): (&str, &dyn Debug) = match self {
            Self::Fixed(bytes) => ("Fixed", bytes as &dyn Debug),
            Self::Streaming(_) => ("Streaming", &OPAQUE as &dyn Debug),
            Self::Axum(body) => ("Axum", body as &dyn Debug),
            Self::Wrapper(_) => ("Wrapper", &OPAQUE as &dyn Debug),
        };

        f.debug_tuple(variant).field(payload).finish()
    }
}
impl Body {
#[must_use]
const fn new(inner: BodyInner) -> Self {
Self { inner }
}
#[must_use]
pub const fn empty() -> Self {
Self::new(BodyInner::Fixed(Bytes::new()))
}
#[must_use]
pub fn fixed<T: Into<Bytes>>(data: T) -> Self {
Self::new(BodyInner::Fixed(data.into()))
}
#[must_use]
pub fn streaming<T: Stream<Item = Result<Bytes>> + Send + 'static>(stream: T) -> Self {
Self::new(BodyInner::Streaming(SyncWrapper::new(Box::pin(stream))))
}
pub async fn into_bytes(self) -> Result<Bytes> {
self.into_bytes_limited(usize::MAX).await
}
pub async fn into_bytes_limited(self, limit: usize) -> Result<Bytes> {
use http_body_util::BodyExt;
Ok(http_body_util::Limited::new(self, limit)
.collect()
.await
.map(http_body_util::Collected::to_bytes)
.map_err(ReadRequestBody)?)
}
#[must_use]
pub(crate) fn axum(inner: axum::body::Body) -> Self {
Self::new(BodyInner::Axum(SyncWrapper::new(inner)))
}
#[must_use]
pub(crate) fn wrapper(inner: BoxBody<Bytes, Error>) -> Self {
Self::new(BodyInner::Wrapper(inner))
}
}
impl Default for Body {
fn default() -> Self {
Self::empty()
}
}
impl http_body::Body for Body {
    type Data = Bytes;
    type Error = Error;

    /// Yields the next data frame of the body.
    ///
    /// * `Fixed` yields its whole buffer as one frame and leaves an empty
    ///   buffer behind, so the following poll reports the end of the body.
    /// * `Streaming` forwards each chunk (or error) produced by the stream.
    /// * `Axum` and `Wrapper` delegate to the inner body; any error it reports
    ///   is wrapped in `ReadRequestBody` before being converted into `Error`.
    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<std::result::Result<Frame<Self::Data>, Self::Error>>> {
        match &mut self.get_mut().inner {
            BodyInner::Fixed(buffer) => {
                if buffer.is_empty() {
                    return Poll::Ready(None);
                }
                // Leave an empty buffer behind so the next poll ends the body.
                let chunk = std::mem::take(buffer);
                Poll::Ready(Some(Ok(Frame::data(chunk))))
            }
            BodyInner::Streaming(wrapped) => wrapped
                .get_mut()
                .as_mut()
                .poll_next(cx)
                .map(|next| next.map(|item| item.map(Frame::data))),
            BodyInner::Axum(wrapped) => Pin::new(wrapped.get_mut())
                .poll_frame(cx)
                .map_err(|error| ReadRequestBody(Box::new(error)).into()),
            BodyInner::Wrapper(boxed) => Pin::new(boxed)
                .poll_frame(cx)
                .map_err(|error| ReadRequestBody(Box::new(error)).into()),
        }
    }

    /// Reports whether the body is known to have no more frames.
    ///
    /// `Streaming` and `Axum` bodies are never reported as finished here;
    /// their end is only observed when `poll_frame` yields `None`.
    fn is_end_stream(&self) -> bool {
        match &self.inner {
            BodyInner::Fixed(buffer) => buffer.is_empty(),
            BodyInner::Wrapper(boxed) => boxed.is_end_stream(),
            BodyInner::Streaming(_) | BodyInner::Axum(_) => false,
        }
    }

    /// Returns bounds on the remaining length of the body.
    ///
    /// Exact for `Fixed`, delegated for `Wrapper`, and `SizeHint::new()` (no
    /// known bounds) for `Streaming` and `Axum`.
    fn size_hint(&self) -> SizeHint {
        match &self.inner {
            BodyInner::Fixed(buffer) => {
                let remaining = buffer.len() as u64;
                SizeHint::with_exact(remaining)
            }
            BodyInner::Wrapper(boxed) => boxed.size_hint(),
            BodyInner::Streaming(_) | BodyInner::Axum(_) => SizeHint::new(),
        }
    }
}
macro_rules! body_from_impl {
($ty:ty) => {
impl From<$ty> for Body {
fn from(buf: $ty) -> Self {
Self::new(BodyInner::Fixed(Bytes::from(buf)))
}
}
};
}
body_from_impl!(&'static [u8]);
body_from_impl!(Vec<u8>);
body_from_impl!(&'static str);
body_from_impl!(String);
body_from_impl!(Bytes);
/// Error produced when a body could not be read.
///
/// The underlying error is kept as this error's `source`. It is constructed in
/// this file when collecting a body in `Body::into_bytes_limited` fails, and
/// when an axum or boxed wrapper body returns an error from `poll_frame`.
#[derive(Debug, thiserror::Error)]
#[error("could not retrieve request body: {0}")]
struct ReadRequestBody(#[source] Box<dyn StdError + Send + Sync>);
// Makes `ReadRequestBody` convertible into cot's `Error` (used via `.into()`
// and `?` above). NOTE(review): `BAD_REQUEST` presumably selects the HTTP
// status reported for this error — confirm against the macro definition.
impl_into_cot_error!(ReadRequestBody, BAD_REQUEST);
#[cfg(test)]
mod tests {
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use futures::stream;
    use http_body::Body as HttpBody;

    use super::*;

    #[test]
    fn body_empty() {
        let body = Body::empty();
        if let BodyInner::Fixed(data) = body.inner {
            assert!(data.is_empty());
        } else {
            panic!("Body::empty should create a fixed empty body");
        }
    }

    #[test]
    fn body_fixed() {
        let content = "Hello, world!";
        let body = Body::fixed(content);
        if let BodyInner::Fixed(data) = body.inner {
            assert_eq!(data, Bytes::from(content));
        } else {
            panic!("Body::fixed should create a fixed body with the given content");
        }
    }

    #[cot::test]
    async fn body_streaming() {
        let stream = stream::once(async { Ok(Bytes::from("Hello, world!")) });
        let body = Body::streaming(stream);
        assert!(
            matches!(body.inner, BodyInner::Streaming(_)),
            "Body::streaming should create a streaming body"
        );
    }

    #[cot::test]
    async fn body_into_bytes_fixed() {
        let content = "Hello, world!";
        let bytes = Body::fixed(content).into_bytes().await.unwrap();
        assert_eq!(bytes, Bytes::from(content));
    }

    #[cot::test]
    async fn body_into_bytes_streaming() {
        let chunks: Vec<Result<Bytes>> =
            vec![Ok(Bytes::from("Hello, ")), Ok(Bytes::from("world!"))];
        let bytes = Body::streaming(stream::iter(chunks))
            .into_bytes()
            .await
            .unwrap();
        assert_eq!(bytes, Bytes::from("Hello, world!"));
    }

    #[cot::test]
    async fn body_into_bytes_limited_exceeded() {
        let body = Body::fixed("Hello, world!");
        assert!(body.into_bytes_limited(5).await.is_err());
    }

    #[cot::test]
    async fn http_body_poll_frame_fixed() {
        let content = "Hello, world!";
        let mut body = Body::fixed(content);
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());
        match Pin::new(&mut body).poll_frame(&mut cx) {
            Poll::Ready(Some(Ok(frame))) => {
                assert_eq!(frame.into_data().unwrap(), Bytes::from(content));
            }
            _ => panic!("Body::fixed should return the content in poll_frame"),
        }
        match Pin::new(&mut body).poll_frame(&mut cx) {
            Poll::Ready(None) => {}
            _ => panic!("Body::fixed should return None after the content is consumed"),
        }
    }

    #[cot::test]
    async fn http_body_poll_frame_streaming() {
        let content = "Hello, world!";
        let mut body = Body::streaming(stream::once(async move { Ok(Bytes::from(content)) }));
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());
        match Pin::new(&mut body).poll_frame(&mut cx) {
            Poll::Ready(Some(Ok(frame))) => {
                assert_eq!(frame.into_data().unwrap(), Bytes::from(content));
            }
            _ => panic!("Body::streaming should return the content in poll_frame"),
        }
        match Pin::new(&mut body).poll_frame(&mut cx) {
            Poll::Ready(None) => {}
            _ => panic!("Body::streaming should return None after the content is consumed"),
        }
    }

    #[test]
    fn http_body_is_end_stream() {
        let body = Body::empty();
        assert!(body.is_end_stream());
        let body = Body::fixed("Hello, world!");
        assert!(!body.is_end_stream());
    }

    #[test]
    fn http_body_size_hint() {
        let body = Body::empty();
        assert_eq!(body.size_hint().exact(), Some(0));
        let content = "Hello, world!";
        let body = Body::fixed(content);
        assert_eq!(body.size_hint().exact(), Some(content.len() as u64));
    }
}