salvo_core 0.92.0

Salvo is a powerful web framework that can make your work easier.
Documentation
use std::fmt::{self, Debug, Formatter};
use std::io::{Error as IoError, Result as IoResult};
use std::pin::Pin;
use std::task::{Context, Poll};

use bytes::Bytes;
use futures_util::stream::Stream;
use hyper::body::{Body, Frame, Incoming, SizeHint};

use crate::BoxedError;
use crate::fuse::{ArcFusewire, FuseEvent};

pub(crate) type BoxedBody =
    Pin<Box<dyn Body<Data = Bytes, Error = BoxedError> + Send + Sync + 'static>>;
pub(crate) type PollFrame = Poll<Option<Result<Frame<Bytes>, IoError>>>;

/// Body for HTTP request.
///
/// A request body is either absent (`None`), fully buffered in memory (`Once`), or streamed
/// from an inner body (`Hyper` for hyper's `Incoming`, `Boxed` for any other type-erased body).
/// The streaming variants may carry a fusewire that is notified while frames are polled.
#[non_exhaustive]
#[derive(Default)]
pub enum ReqBody {
    /// No body. This is the `Default` value.
    #[default]
    None,
    /// Already-buffered bytes, yielded as a single data frame (nothing when empty).
    Once(Bytes),
    /// Hyper default body.
    Hyper {
        /// Inner hyper incoming body.
        inner: Incoming,
        /// Optional fusewire notified of frame activity while polling `inner`.
        fusewire: Option<ArcFusewire>,
    },
    /// Boxed (type-erased) body.
    Boxed {
        /// Inner pinned, type-erased body.
        inner: BoxedBody,
        /// Optional fusewire notified of frame activity while polling `inner`.
        fusewire: Option<ArcFusewire>,
    },
}
impl ReqBody {
    /// Attaches (or clears, with `None`) the fusewire observing frame activity.
    ///
    /// Only the streaming variants (`Hyper` and `Boxed`) hold a fusewire; for `None` and
    /// `Once` the given value is simply discarded.
    #[doc(hidden)]
    pub fn set_fusewire(&mut self, value: Option<ArcFusewire>) {
        if let Self::Hyper { fusewire, .. } | Self::Boxed { fusewire, .. } = self {
            *fusewire = value;
        }
    }
    /// Returns `true` if no body is set (`ReqBody::None`).
    #[inline]
    pub fn is_none(&self) -> bool {
        matches!(self, Self::None)
    }
    /// Returns `true` if the body is fully buffered (`ReqBody::Once`).
    #[inline]
    pub fn is_once(&self) -> bool {
        matches!(self, Self::Once(_))
    }
    /// Returns `true` if the body is hyper's default incoming body (`ReqBody::Hyper`).
    #[inline]
    pub fn is_hyper(&self) -> bool {
        matches!(self, Self::Hyper { .. })
    }
    /// Returns `true` if the body is a type-erased boxed body (`ReqBody::Boxed`).
    #[inline]
    pub fn is_boxed(&self) -> bool {
        matches!(self, Self::Boxed { .. })
    }

    /// Moves the current body out, leaving `ReqBody::None` in its place.
    #[inline]
    #[must_use]
    pub fn take(&mut self) -> Self {
        // `ReqBody::default()` is `ReqBody::None`, so this is a replace-with-None.
        std::mem::take(self)
    }
}

impl Body for ReqBody {
    type Data = Bytes;
    type Error = IoError;

    /// Polls the next frame of the body.
    ///
    /// - `None` is immediately finished.
    /// - `Once` yields its bytes as one data frame (nothing if empty) and is left empty.
    /// - `Hyper` / `Boxed` delegate to the inner body, converting its error into `IoError`.
    ///   If a fusewire is attached, it receives `GainFrame` when a frame is produced and
    ///   `WaitFrame` when the inner body is pending.
    fn poll_frame(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollFrame {
        let (poll, fusewire) = match &mut *self {
            Self::None => return Poll::Ready(None),
            Self::Once(bytes) => {
                return if bytes.is_empty() {
                    Poll::Ready(None)
                } else {
                    // Leave an empty `Bytes` behind so the next poll reports end of body.
                    Poll::Ready(Some(Ok(Frame::data(std::mem::take(bytes)))))
                };
            }
            Self::Hyper { inner, fusewire } => (
                Pin::new(inner).poll_frame(cx).map_err(IoError::other),
                fusewire.as_ref(),
            ),
            Self::Boxed { inner, fusewire } => (
                Pin::new(inner).poll_frame(cx).map_err(IoError::other),
                fusewire.as_ref(),
            ),
        };

        if let Some(wire) = fusewire {
            match &poll {
                Poll::Ready(Some(Ok(_))) => {
                    wire.event(FuseEvent::GainFrame);
                }
                Poll::Pending => {
                    wire.event(FuseEvent::WaitFrame);
                }
                // End of body and errors are passed through without notifying the fusewire.
                Poll::Ready(_) => {}
            }
        }
        poll
    }

    /// Reports whether the body is known to have no more frames.
    fn is_end_stream(&self) -> bool {
        match self {
            Self::None => true,
            Self::Once(bytes) => bytes.is_empty(),
            Self::Hyper { inner, .. } => inner.is_end_stream(),
            Self::Boxed { inner, .. } => inner.is_end_stream(),
        }
    }

    /// Exact size for `None` / `Once`; otherwise whatever the inner body reports.
    fn size_hint(&self) -> SizeHint {
        match self {
            Self::Hyper { inner, .. } => inner.size_hint(),
            Self::Boxed { inner, .. } => inner.size_hint(),
            Self::Once(bytes) => SizeHint::with_exact(bytes.len() as u64),
            Self::None => SizeHint::with_exact(0),
        }
    }
}
impl Stream for ReqBody {
    type Item = IoResult<Frame<Bytes>>;

    /// Exposes the body's frames as a `Stream` by delegating to `Body::poll_frame`.
    ///
    /// `poll_frame` already yields `IoError`, so errors are forwarded unchanged rather than
    /// wrapped in a second `IoError::other` layer. Double wrapping would hide the original
    /// error (e.g. the hyper error) from `get_ref`, `into_inner` and downcasting.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Body::poll_frame(self, cx)
    }
}

impl From<Bytes> for ReqBody {
    fn from(value: Bytes) -> Self {
        Self::Once(value)
    }
}
impl From<Incoming> for ReqBody {
    fn from(inner: Incoming) -> Self {
        Self::Hyper {
            inner,
            fusewire: None,
        }
    }
}
impl From<String> for ReqBody {
    #[inline]
    fn from(value: String) -> Self {
        Self::Once(value.into())
    }
}
impl TryFrom<ReqBody> for Incoming {
    type Error = crate::Error;

    /// Extracts the underlying hyper `Incoming` body.
    ///
    /// Any fusewire attached to a `ReqBody::Hyper` is discarded.
    ///
    /// # Errors
    ///
    /// Returns an error for every variant other than `ReqBody::Hyper`, because only that
    /// variant wraps an `Incoming`. The message names the offending variant.
    fn try_from(body: ReqBody) -> Result<Self, Self::Error> {
        match body {
            ReqBody::Hyper { inner, .. } => Ok(inner),
            ReqBody::None => Err(crate::Error::other(
                "ReqBody::None cannot convert to Incoming",
            )),
            ReqBody::Once(_) => Err(crate::Error::other(
                "ReqBody::Once cannot convert to Incoming",
            )),
            ReqBody::Boxed { .. } => Err(crate::Error::other(
                "ReqBody::Boxed cannot convert to Incoming",
            )),
        }
    }
}

impl From<&'static [u8]> for ReqBody {
    fn from(value: &'static [u8]) -> Self {
        Self::Once(Bytes::from_static(value))
    }
}

impl From<&'static str> for ReqBody {
    fn from(value: &'static str) -> Self {
        Self::Once(Bytes::from_static(value.as_bytes()))
    }
}

impl From<Vec<u8>> for ReqBody {
    fn from(value: Vec<u8>) -> Self {
        Self::Once(value.into())
    }
}

impl<T> From<Box<T>> for ReqBody
where
    T: Into<Self>,
{
    fn from(value: Box<T>) -> Self {
        (*value).into()
    }
}

cfg_feature! {
    #![feature = "quinn"]
    pub(crate) mod h3 {
        use std::boxed::Box;
        use std::pin::Pin;
        use std::task::{ready, Context, Poll};
        use std::fmt::{self, Debug, Formatter};

        use hyper::body::{Body, Frame, SizeHint};
        use salvo_http3::quic::RecvStream;
        use salvo_http3::error::Code;

        use bytes::{Buf, Bytes};

        use crate::BoxedError;
        use crate::http::ReqBody;

        /// Http3 request body.
        ///
        /// Wraps an HTTP/3 server request stream and exposes its received data as a hyper
        /// `Body`.
        pub struct H3ReqBody<S, B>
        where
            S: RecvStream + Send + Unpin,
            B: Buf + Send + Unpin,
        {
            inner: salvo_http3::server::RequestStream<S, B>,
        }
        impl<S, B> Debug for H3ReqBody<S, B>
        where
            S: RecvStream + Send + Unpin,
            B: Buf + Send + Unpin,
        {
            fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
                f.debug_struct("H3ReqBody").finish()
            }
        }
        impl<S, B> H3ReqBody<S, B>
        where
            S: RecvStream + Send + Unpin + 'static,
            B: Buf + Send + Unpin + 'static,
        {
            /// Create new `H3ReqBody` instance.
            pub fn new(inner: salvo_http3::server::RequestStream<S, B>) -> Self {
                Self { inner }
            }
        }

        impl<S, B> Body for H3ReqBody<S, B>
        where
            S: RecvStream + Send + Unpin,
            B: Buf + Send + Unpin,
        {
            type Data = Bytes;
            type Error = BoxedError;

            /// Polls the stream for the next chunk of received data and yields it as one data
            /// frame. `Ok(None)` from the stream ends the body; stream errors are boxed.
            fn poll_frame(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
                let this = &mut *self;
                match ready!(this.inner.poll_recv_data(cx)) {
                    Ok(Some(mut buf)) => {
                        // `Buf::chunk()` only exposes the first contiguous slice; a buffer made
                        // of several chunks would silently lose data. Drain everything that
                        // remains (this is zero-copy when the buffer is already `Bytes`).
                        let len = buf.remaining();
                        Poll::Ready(Some(Ok(Frame::data(buf.copy_to_bytes(len)))))
                    }
                    Ok(None) => Poll::Ready(None),
                    Err(e) => Poll::Ready(Some(Err(e.into()))),
                }
            }

            fn is_end_stream(&self) -> bool {
                // Conservatively never claims the end; the end is detected when
                // `poll_recv_data` yields `Ok(None)`.
                false // TODO: check this
            }

            fn size_hint(&self) -> SizeHint {
                SizeHint::default()
            }
        }

        impl<S, B> Drop for H3ReqBody<S, B>
        where
            S: RecvStream + Send + Unpin,
            B: Buf + Send + Unpin,
        {
            /// Calls `stop_sending(H3_NO_ERROR)` on the request stream when the body is dropped.
            fn drop(&mut self) {
                self.inner.stop_sending(Code::H3_NO_ERROR);
            }
        }

        impl<S, B> From<H3ReqBody<S, B>> for ReqBody
        where
            S: RecvStream + Send + Sync + Unpin + 'static,
            B: Buf + Send + Sync + Unpin + 'static,
        {
            /// Boxes the HTTP/3 body into `ReqBody::Boxed`, with no fusewire attached.
            fn from(value: H3ReqBody<S, B>) -> Self {
                Self::Boxed {
                    inner: Box::pin(value),
                    fusewire: None,
                }
            }
        }
    }
}

impl Debug for ReqBody {
    /// Formats the variant name. Buffered bytes and the hyper body are shown; boxed bodies
    /// and fusewires are elided.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            Self::Once(bytes) => f.debug_tuple("ReqBody::Once").field(bytes).finish(),
            Self::Hyper { inner, .. } => {
                let mut builder = f.debug_struct("ReqBody::Hyper");
                builder.field("inner", inner);
                builder.finish()
            }
            Self::None => f.write_str("ReqBody::None"),
            Self::Boxed { .. } => f.write_str("ReqBody::Boxed{..}"),
        }
    }
}

#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::*;

    /// `take` hands back the previous body and leaves `None` behind.
    #[test]
    fn test_take() {
        let mut body = ReqBody::Once(Bytes::from("abc"));
        let taken = body.take();
        assert!(taken.is_once());
        assert!(body.is_none());
    }

    /// The `None` variant's debug output names the variant.
    #[test]
    fn test_debug() {
        let rendered = format!("{:?}", ReqBody::None);
        assert!(rendered.contains("ReqBody::None"));
    }

    /// Absent and empty buffered bodies both report end of stream.
    #[test]
    fn test_is_end_stream() {
        assert!(ReqBody::None.is_end_stream());
        assert!(ReqBody::Once(Bytes::new()).is_end_stream());
    }
}