use std::{
convert::Infallible,
error::Error,
fmt,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use bytes::Bytes;
use faststr::FastStr;
use futures_util::stream::Stream;
use http_body::{Frame, SizeHint};
use http_body_util::{BodyExt, Full, StreamBody, combinators::BoxBody};
use hyper::body::Incoming;
use linkedbytes::{LinkedBytes, Node};
use pin_project::pin_project;
#[cfg(feature = "json")]
use serde::de::DeserializeOwned;
use crate::error::BoxError;
/// A pinned, heap-allocated, type-erased [`Stream`] that is `Send + Sync` and yields `T`.
type BoxStream<'a, T> = Pin<Box<dyn Stream<Item = T> + Send + Sync + 'a>>;
/// HTTP body type that wraps one of several underlying body representations.
///
/// Construct it with [`Body::empty`], [`Body::from_incoming`], [`Body::from_stream`],
/// [`Body::from_body`], or one of the `From` conversions in this file.
#[pin_project]
pub struct Body {
// Pinned so that `poll_frame` can project down to the wrapped body.
#[pin]
repr: BodyRepr,
}
/// Internal representation behind [`Body`]; every variant holds a pinned inner body.
#[pin_project(project = BodyProj)]
enum BodyRepr {
/// A single in-memory [`Bytes`] payload.
Full(#[pin] Full<Bytes>),
/// A `hyper::body::Incoming` body.
Hyper(#[pin] Incoming),
/// A boxed stream of frames adapted into a body through [`StreamBody`].
Stream(#[pin] StreamBody<BoxStream<'static, Result<Frame<Bytes>, BoxError>>>),
/// A type-erased body whose error type is [`BoxError`].
Body(#[pin] BoxBody<Bytes, BoxError>),
}
impl Default for Body {
fn default() -> Self {
Body::empty()
}
}
impl Body {
pub fn empty() -> Self {
Self {
repr: BodyRepr::Full(Full::new(Bytes::new())),
}
}
pub fn from_incoming(incoming: Incoming) -> Self {
Self {
repr: BodyRepr::Hyper(incoming),
}
}
pub fn from_stream<S>(stream: S) -> Self
where
S: Stream<Item = Result<Frame<Bytes>, BoxError>> + Send + Sync + 'static,
{
Self {
repr: BodyRepr::Stream(StreamBody::new(Box::pin(stream))),
}
}
pub fn from_body<B>(body: B) -> Self
where
B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError>,
{
Self {
repr: BodyRepr::Body(BoxBody::new(body.map_err(Into::into))),
}
}
}
impl http_body::Body for Body {
type Data = Bytes;
type Error = BoxError;
fn poll_frame(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
match self.project().repr.project() {
BodyProj::Full(full) => http_body::Body::poll_frame(full, cx).map_err(BoxError::from),
BodyProj::Hyper(incoming) => {
http_body::Body::poll_frame(incoming, cx).map_err(BoxError::from)
}
BodyProj::Stream(stream) => http_body::Body::poll_frame(stream, cx),
BodyProj::Body(body) => http_body::Body::poll_frame(body, cx),
}
}
fn is_end_stream(&self) -> bool {
match &self.repr {
BodyRepr::Full(full) => http_body::Body::is_end_stream(full),
BodyRepr::Hyper(incoming) => http_body::Body::is_end_stream(incoming),
BodyRepr::Stream(stream) => http_body::Body::is_end_stream(stream),
BodyRepr::Body(body) => http_body::Body::is_end_stream(body),
}
}
fn size_hint(&self) -> SizeHint {
match &self.repr {
BodyRepr::Full(full) => http_body::Body::size_hint(full),
BodyRepr::Hyper(incoming) => http_body::Body::size_hint(incoming),
BodyRepr::Stream(stream) => http_body::Body::size_hint(stream),
BodyRepr::Body(body) => http_body::Body::size_hint(body),
}
}
}
impl fmt::Debug for Body {
    /// Writes only the name of the active representation; the payload is never printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match &self.repr {
            BodyRepr::Full(_) => "Body::Full",
            BodyRepr::Hyper(_) => "Body::Hyper",
            BodyRepr::Stream(_) => "Body::Stream",
            BodyRepr::Body(_) => "Body::Body",
        };
        f.write_str(label)
    }
}
// Private module holding `SealedBody`, the supertrait required by `BodyConversion`.
mod sealed {
/// Marker trait for sendable `http_body::Body` types whose `Data` is `Send`.
pub trait SealedBody
where
Self: http_body::Body + Sized + Send,
Self::Data: Send,
{
}
// Blanket impl: every `http_body::Body + Send` type with `Send` data is a `SealedBody`.
impl<T> SealedBody for T
where
T: http_body::Body + Send,
T::Data: Send,
{
}
}
/// Extension methods that consume a body and collect it into an owned value.
///
/// Every method takes `self` by value and returns a `Send` future.
pub trait BodyConversion: sealed::SealedBody
where
    <Self as http_body::Body>::Data: Send,
{
    /// Collects the whole body into a single [`Bytes`].
    ///
    /// # Errors
    ///
    /// Returns [`BodyConvertError::BodyCollectionError`] if collecting the body fails; the
    /// underlying error is discarded.
    fn into_bytes(self) -> impl Future<Output = Result<Bytes, BodyConvertError>> + Send {
        async {
            match self.collect().await {
                Ok(collected) => Ok(collected.to_bytes()),
                Err(_) => Err(BodyConvertError::BodyCollectionError),
            }
        }
    }

    /// Collects the whole body into a `Vec<u8>`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`BodyConversion::into_bytes`].
    fn into_vec(self) -> impl Future<Output = Result<Vec<u8>, BodyConvertError>> + Send {
        async {
            let bytes = self.into_bytes().await?;
            Ok(Vec::from(bytes))
        }
    }

    /// Collects the whole body into a `String`, validating that it is UTF-8.
    ///
    /// # Errors
    ///
    /// Returns [`BodyConvertError::StringUtf8Error`] if the bytes are not valid UTF-8, and
    /// propagates any collection error.
    fn into_string(self) -> impl Future<Output = Result<String, BodyConvertError>> + Send {
        async {
            let raw = self.into_vec().await?;
            if simdutf8::basic::from_utf8(&raw).is_err() {
                return Err(BodyConvertError::StringUtf8Error);
            }
            // SAFETY: `raw` has just been validated as UTF-8.
            Ok(unsafe { String::from_utf8_unchecked(raw) })
        }
    }

    /// Collects the whole body into a `String` without validating it.
    ///
    /// # Safety
    ///
    /// The body must consist of valid UTF-8; no check is performed.
    ///
    /// # Errors
    ///
    /// Propagates any collection error.
    unsafe fn into_string_unchecked(
        self,
    ) -> impl Future<Output = Result<String, BodyConvertError>> + Send {
        async {
            let raw = self.into_vec().await?;
            // SAFETY: validity of the bytes is the caller's obligation (see `# Safety`).
            let text = unsafe { String::from_utf8_unchecked(raw) };
            Ok(text)
        }
    }

    /// Collects the whole body into a [`FastStr`], validating that it is UTF-8.
    ///
    /// # Errors
    ///
    /// Returns [`BodyConvertError::StringUtf8Error`] if the bytes are not valid UTF-8, and
    /// propagates any collection error.
    fn into_faststr(self) -> impl Future<Output = Result<FastStr, BodyConvertError>> + Send {
        async {
            let raw = self.into_bytes().await?;
            if simdutf8::basic::from_utf8(&raw).is_err() {
                return Err(BodyConvertError::StringUtf8Error);
            }
            // SAFETY: `raw` has just been validated as UTF-8.
            Ok(unsafe { FastStr::from_bytes_unchecked(raw) })
        }
    }

    /// Collects the whole body into a [`FastStr`] without validating it.
    ///
    /// # Safety
    ///
    /// The body must consist of valid UTF-8; no check is performed.
    ///
    /// # Errors
    ///
    /// Propagates any collection error.
    unsafe fn into_faststr_unchecked(
        self,
    ) -> impl Future<Output = Result<FastStr, BodyConvertError>> + Send {
        async {
            let raw = self.into_bytes().await?;
            // SAFETY: validity of the bytes is the caller's obligation (see `# Safety`).
            let text = unsafe { FastStr::from_bytes_unchecked(raw) };
            Ok(text)
        }
    }

    /// Collects the whole body and deserializes it into `T`.
    ///
    /// # Errors
    ///
    /// Returns [`BodyConvertError::JsonDeserializeError`] if deserialization fails, and
    /// propagates any collection error.
    #[cfg(feature = "json")]
    fn into_json<T>(self) -> impl Future<Output = Result<T, BodyConvertError>> + Send
    where
        T: DeserializeOwned,
    {
        async {
            let payload = self.into_bytes().await?;
            crate::utils::json::deserialize(&payload)
                .map_err(|err| BodyConvertError::JsonDeserializeError(err))
        }
    }
}
// Blanket impl: every `SealedBody` whose `Data` is `Send` gets the `BodyConversion`
// methods through the trait's default implementations.
impl<T> BodyConversion for T
where
T: sealed::SealedBody,
<T as http_body::Body>::Data: Send,
{
}
/// Errors returned by the [`BodyConversion`] methods.
#[derive(Debug)]
pub enum BodyConvertError {
/// Collecting the body failed.
BodyCollectionError,
/// The collected body is not valid UTF-8.
StringUtf8Error,
/// Deserializing the collected body failed; wraps the deserializer's error.
#[cfg(feature = "json")]
JsonDeserializeError(crate::utils::json::Error),
}
impl fmt::Display for BodyConvertError {
    /// Writes a short human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::BodyCollectionError => "failed to collect body",
            Self::StringUtf8Error => "body is not a valid string",
            // The only variant with a payload: format it directly and return.
            #[cfg(feature = "json")]
            Self::JsonDeserializeError(err) => {
                return write!(f, "failed to deserialize body: {err}");
            }
        };
        f.write_str(message)
    }
}
impl Error for BodyConvertError {
    /// Returns the wrapped deserialization error, if any; other variants have no source.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::BodyCollectionError | Self::StringUtf8Error => None,
            #[cfg(feature = "json")]
            Self::JsonDeserializeError(inner) => Some(inner),
        }
    }
}
impl From<()> for Body {
fn from(_: ()) -> Self {
Self::empty()
}
}
impl From<&'static str> for Body {
    /// Wraps a static string as a full body without copying it.
    fn from(value: &'static str) -> Self {
        let payload = Bytes::from_static(value.as_bytes());
        Self {
            repr: BodyRepr::Full(Full::new(payload)),
        }
    }
}
impl From<Vec<u8>> for Body {
    /// Converts the vector into [`Bytes`] and wraps it as a full body.
    fn from(value: Vec<u8>) -> Self {
        let payload = Bytes::from(value);
        let repr = BodyRepr::Full(Full::new(payload));
        Self { repr }
    }
}
impl From<Bytes> for Body {
    /// Wraps the given [`Bytes`] as a full body.
    fn from(value: Bytes) -> Self {
        let full = Full::new(value);
        Self {
            repr: BodyRepr::Full(full),
        }
    }
}
impl From<FastStr> for Body {
    /// Converts the [`FastStr`] into [`Bytes`] and wraps it as a full body.
    fn from(value: FastStr) -> Self {
        let payload = value.into_bytes();
        let repr = BodyRepr::Full(Full::new(payload));
        Self { repr }
    }
}
impl From<String> for Body {
    /// Converts the string into [`Bytes`] and wraps it as a full body.
    fn from(value: String) -> Self {
        let payload = Bytes::from(value);
        Self {
            repr: BodyRepr::Full(Full::new(payload)),
        }
    }
}
/// Adapter that wraps an iterator so it can be used as an `http_body::Body`.
struct LinkedBytesBody<I> {
// Source iterator; the `http_body::Body` impl requires it to yield `Node`s.
inner: I,
}
// Emits each `Node` of the wrapped iterator as one HTTP data frame.
impl<I> http_body::Body for LinkedBytesBody<I>
where
    I: Iterator<Item = Node> + Unpin,
{
    type Data = Bytes;
    type Error = Infallible;

    /// Yields the next node as a data frame, or `None` once the iterator is exhausted.
    ///
    /// This never returns `Poll::Pending`; every call resolves immediately.
    fn poll_frame(
        self: Pin<&mut Self>,
        _: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        // `I: Unpin`, so the pin can be dropped to reach the iterator mutably.
        let this = self.get_mut();
        let Some(node) = this.inner.next() else {
            return Poll::Ready(None);
        };
        // Normalize every node kind to `Bytes`.
        let bytes = match node {
            Node::Bytes(bytes) => bytes,
            Node::BytesMut(bytesmut) => bytesmut.freeze(),
            Node::FastStr(faststr) => faststr.into_bytes(),
        };
        Poll::Ready(Some(Ok(Frame::data(bytes))))
    }

    /// Always `false`: the end is only discovered when `poll_frame` exhausts the iterator.
    fn is_end_stream(&self) -> bool {
        false
    }

    /// Reports an unknown size (lower bound 0, no upper bound).
    ///
    /// A body's size hint bounds the remaining number of *bytes*. The iterator's own
    /// `size_hint` counts remaining *nodes*, so forwarding it would advertise a wrong and
    /// possibly "exact" body length, which a consumer may use to frame the message
    /// (for example as a `Content-Length`). The byte total cannot be derived from the
    /// iterator without consuming it, so no bound is claimed.
    fn size_hint(&self) -> SizeHint {
        SizeHint::default()
    }
}
impl From<LinkedBytes> for Body {
    /// Turns the list's nodes into a body that emits one data frame per node.
    fn from(value: LinkedBytes) -> Self {
        let nodes = value.into_iter_list();
        let adapter = LinkedBytesBody { inner: nodes };
        Self::from_body(adapter)
    }
}
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use faststr::FastStr;
    use linkedbytes::LinkedBytes;

    use super::Body;
    use crate::body::BodyConversion;

    /// A `LinkedBytes` with two nodes must collect into the concatenation of both.
    #[tokio::test]
    async fn test_from_linked_bytes() {
        let mut linked = LinkedBytes::new();
        linked.insert(Bytes::from_static(b"Hello, "));
        linked.insert_faststr(FastStr::new("world!"));

        let collected = Body::from(linked).into_string().await.unwrap();
        assert_eq!(collected, "Hello, world!");
    }
}