#[cfg(feature = "serde_json")]
use serde_json::{Value, from_slice, to_vec};
use {
super::{FromBytes, IntoBytes},
futures_util::{Stream, StreamExt},
hyper::{
Error,
body::{Body, Bytes, Frame, Incoming, SizeHint},
},
parking_lot::RwLock,
std::{
fmt::Debug,
io::{Error as IoError, Result as IoResult},
pin::Pin,
sync::Arc,
task::{Context, Poll, Waker},
},
tracing::error,
};
/// An HTTP body that is empty, a single in-memory buffer, a hyper
/// [`Incoming`] body, or a type-erased [`Stream`] of [`Bytes`].
///
/// The `Incoming` and `Stream` sources live behind `Arc<RwLock<..>>`, so every
/// clone of a `StreamingBody` refers to the same underlying source.
#[derive(Default)]
pub enum StreamingBody {
    /// No body at all; this is the `Default` variant.
    #[default]
    Null,
    /// An in-memory buffer; it becomes `None` once it has been polled out.
    Bytes { bytes: Option<Bytes> },
    /// A body received through hyper.
    Incoming { incoming: Arc<RwLock<Incoming>> },
    /// A boxed stream whose items are emitted as data frames.
    Stream {
        stream: Arc<RwLock<Pin<Box<dyn Stream<Item = Bytes> + Send + Sync>>>>,
    },
}
impl Clone for StreamingBody {
fn clone(&self) -> Self {
match self {
Self::Null => Self::Null,
Self::Bytes { bytes } => Self::Bytes {
bytes: bytes.clone(),
},
Self::Incoming { incoming } => Self::Incoming {
incoming: incoming.clone(),
},
Self::Stream { stream } => Self::Stream {
stream: stream.clone(),
},
}
}
}
impl Debug for StreamingBody {
    /// Prints only the variant name; the payload is never inspected.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant = match self {
            Self::Null => "Null",
            Self::Bytes { .. } => "Bytes",
            Self::Incoming { .. } => "Incoming",
            Self::Stream { .. } => "Stream",
        };
        f.write_str(variant)
    }
}
impl StreamingBody {
fn into_vec(self) -> Vec<u8> {
let mut ret = Vec::new();
let mut cx = Context::from_waker(Waker::noop());
match self {
StreamingBody::Null => (),
StreamingBody::Bytes { bytes } => {
if let Some(data) = bytes {
ret.extend_from_slice(&data)
}
}
StreamingBody::Incoming { incoming } => {
let mut incoming = incoming.write();
while !incoming.is_end_stream() {
match Pin::new(&mut *incoming).poll_frame(&mut cx) {
Poll::Ready(Some(Ok(frame))) => match frame.into_data() {
Ok(data) => ret.extend_from_slice(&data),
Err(e) => error!(?e, "Failed to get data"),
},
Poll::Pending => {
cx.waker().wake_by_ref();
continue;
}
Poll::Ready(Some(Err(e))) => {
error!(?e, "Failed to get frame");
break;
}
Poll::Ready(None) => break,
}
}
}
StreamingBody::Stream { stream } => loop {
let mut stream = stream.write();
match stream.as_mut().poll_next(&mut cx) {
Poll::Ready(Some(data)) => ret.extend_from_slice(&data),
Poll::Ready(None) => break,
Poll::Pending => {
cx.waker().wake_by_ref();
continue;
}
}
},
}
ret
}
}
impl Body for StreamingBody {
    type Data = Bytes;
    type Error = Error;

    /// Polls the next frame of the body.
    ///
    /// `Null` yields no frames, `Bytes` yields its buffer once as a single data
    /// frame, `Incoming` forwards to the wrapped hyper body, and `Stream` wraps
    /// each item of the stream in a data frame (it never yields an error).
    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        match self.get_mut() {
            Self::Null => Poll::Ready(None),
            // `take` leaves `None` behind so the buffer is emitted only once.
            Self::Bytes { bytes } => Poll::Ready(bytes.take().map(|b| Ok(Frame::data(b)))),
            Self::Incoming { incoming } => Pin::new(&mut *incoming.write()).poll_frame(cx),
            Self::Stream { stream } => {
                let mut stream = stream.write();
                stream
                    .as_mut()
                    .poll_next(cx)
                    .map(|opt| opt.map(|i| Ok(Frame::data(i))))
            }
        }
    }

    /// Returns `true` once the body is known to have no more frames.
    ///
    /// A `Stream` cannot tell in advance, so it always answers `false`.
    fn is_end_stream(&self) -> bool {
        match self {
            Self::Null => true,
            Self::Bytes { bytes } => bytes.is_none(),
            Self::Incoming { incoming } => incoming.read().is_end_stream(),
            Self::Stream { .. } => false,
        }
    }

    /// Reports bounds on the number of body *bytes* still to come.
    ///
    /// Empty bodies (`Null`, or `Bytes` whose buffer was already taken) report
    /// an exact size of zero. A `Stream` reports an unknown size: its own
    /// `Stream::size_hint` counts remaining *items*, not bytes, so it cannot be
    /// turned into a byte count (an exact hint may be used by consumers as the
    /// body length).
    fn size_hint(&self) -> SizeHint {
        match self {
            Self::Null | Self::Bytes { bytes: None } => SizeHint::with_exact(0),
            Self::Bytes { bytes: Some(bytes) } => SizeHint::with_exact(bytes.len() as u64),
            Self::Incoming { incoming } => incoming.read().size_hint(),
            Self::Stream { .. } => SizeHint::default(),
        }
    }
}
/// Conversion of a value into a [`StreamingBody`].
pub trait IntoStreamingBody {
    /// Consumes `self` and produces the corresponding [`StreamingBody`].
    fn into_streaming_body(self) -> StreamingBody;
}
impl IntoStreamingBody for Incoming {
fn into_streaming_body(self) -> StreamingBody {
StreamingBody::Incoming {
incoming: RwLock::new(self).into(),
}
}
}
impl IntoStreamingBody for &str {
    /// Converts the string slice through [`IntoBytes`] and delegates to the
    /// resulting value's own `IntoStreamingBody` implementation.
    fn into_streaming_body(self) -> StreamingBody {
        let converted = IntoBytes::into(self);
        converted.into_streaming_body()
    }
}
impl IntoStreamingBody for &[u8] {
    /// Converts the byte slice through [`IntoBytes`] and delegates to the
    /// resulting value's own `IntoStreamingBody` implementation.
    fn into_streaming_body(self) -> StreamingBody {
        let converted = IntoBytes::into(self);
        converted.into_streaming_body()
    }
}
impl IntoStreamingBody for String {
    /// Converts the owned string through [`IntoBytes`] and delegates to the
    /// resulting value's own `IntoStreamingBody` implementation.
    fn into_streaming_body(self) -> StreamingBody {
        let converted = IntoBytes::into(self);
        converted.into_streaming_body()
    }
}
impl IntoStreamingBody for Bytes {
    /// Wraps the buffer as a `StreamingBody::Bytes` that is emitted once.
    fn into_streaming_body(self) -> StreamingBody {
        let bytes = Some(self);
        StreamingBody::Bytes { bytes }
    }
}
impl<S, T> IntoStreamingBody for Pin<Box<S>>
where
S: Stream<Item = T> + Send + Sync + 'static,
T: IntoBytes,
{
fn into_streaming_body(self) -> StreamingBody {
StreamingBody::Stream {
stream: Arc::new(RwLock::new(Box::pin(self.map(|i| i.into())))),
}
}
}
impl IntoStreamingBody for () {
    /// The unit value maps to the empty body, which is the enum's default
    /// (`StreamingBody::Null`).
    fn into_streaming_body(self) -> StreamingBody {
        StreamingBody::default()
    }
}
impl From<StreamingBody> for String {
fn from(value: StreamingBody) -> Self {
String::from_utf8(value.into_vec()).unwrap_or_default()
}
}
impl From<StreamingBody> for Vec<u8> {
fn from(value: StreamingBody) -> Self {
value.into_vec()
}
}
impl From<StreamingBody> for () {
    /// Discards the body without reading it.
    fn from(_value: StreamingBody) -> Self {
        // Nothing to convert: the value is dropped at the end of this call.
    }
}
impl From<()> for StreamingBody {
fn from(_value: ()) -> Self {
StreamingBody::Null
}
}
impl<T> From<StreamingBody> for Pin<Box<dyn Stream<Item = IoResult<T>> + Send + Sync>>
where
    T: FromBytes,
{
    /// Exposes the body as a boxed stream whose successful chunks are converted
    /// into `T` via [`FromBytes`]; errors pass through unchanged.
    fn from(value: StreamingBody) -> Self {
        let converted = value.map(|chunk| chunk.map(FromBytes::from));
        Box::pin(converted)
    }
}
impl Stream for StreamingBody {
    type Item = IoResult<Bytes>;

    /// Yields the body's data frames as `Bytes`, converting body errors into
    /// [`IoError`]s.
    ///
    /// Non-data frames (trailers) have no representation in this item type, so
    /// they are skipped and the next frame is polled, rather than being
    /// reported as an error after all the data was already delivered.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            let frame = match self.as_mut().poll_frame(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(IoError::other(e)))),
                Poll::Ready(Some(Ok(frame))) => frame,
            };
            if let Ok(data) = frame.into_data() {
                return Poll::Ready(Some(Ok(data)));
            }
        }
    }
}
#[cfg(feature = "serde_json")]
impl IntoStreamingBody for Value {
    /// Serializes the JSON value into an in-memory `StreamingBody::Bytes`.
    ///
    /// The serialized `Vec<u8>` is moved into [`Bytes`] without another copy.
    /// A serialization failure is logged and produces an empty body.
    fn into_streaming_body(self) -> StreamingBody {
        let json = to_vec(&self).unwrap_or_else(|e| {
            error!(?e, "Failed to serialize the JSON body");
            Vec::new()
        });
        Bytes::from(json).into_streaming_body()
    }
}
#[cfg(feature = "serde_json")]
impl From<StreamingBody> for Value {
    /// Drains the body and parses it as JSON.
    ///
    /// An empty body maps to `Value::Null`. Any other body that fails to parse
    /// is logged and also becomes `Value::Null`, instead of being swallowed
    /// silently.
    fn from(value: StreamingBody) -> Self {
        let bytes = value.into_vec();
        if bytes.is_empty() {
            return Value::Null;
        }
        from_slice::<Value>(&bytes).unwrap_or_else(|e| {
            error!(?e, "Failed to parse the body as JSON");
            Value::Null
        })
    }
}