use std::{
fmt,
pin::Pin,
task::{Context, Poll},
};
use bytes::Bytes;
use futures::{Stream, StreamExt, TryStreamExt};
use http_body::SizeHint;
use crate::header::HeaderMap;
/// One item produced by a streaming [`Body`]: either a chunk of payload bytes or a set of
/// trailing headers.
#[derive(Debug)]
pub enum BodyComponent {
    /// A chunk of body data.
    Data(Bytes),
    /// Trailing headers. `Body::collect` ignores these; `BodyWrapper` holds them until
    /// `poll_trailers` is called.
    Trailers(HeaderMap),
}
/// An HTTP message body, either fully buffered in memory or produced incrementally by a stream.
pub enum Body {
    /// A fully buffered payload.
    Bytes(Vec<u8>),
    /// A payload produced by an asynchronous stream of [`BodyComponent`]s.
    Stream {
        /// Expected total number of data bytes, when known. `Body::collect` uses it as the
        /// initial buffer capacity, and `BodyWrapper::size_hint` reports it as an *exact*
        /// length, so callers should only set it when it is accurate.
        size_hint: Option<usize>,
        /// The component stream itself.
        stream: BodyStream,
    },
}
/// Boxed, pinned stream of body components (or errors). The `Send + Sync + 'static` bounds let
/// a body be moved between threads and stored without borrowing from anything.
pub type BodyStream =
    Pin<Box<dyn Stream<Item = Result<BodyComponent, anyhow::Error>> + Send + Sync + 'static>>;
/// Adapter that exposes a [`Body`] through the `http_body::Body` trait. It tracks how much of
/// the body has been emitted and holds any trailers pulled from the stream.
pub struct BodyWrapper {
    /// The wrapped body.
    body: Body,
    /// Set once `poll_data` has handed out the buffered `Body::Bytes` payload.
    has_written_blob: bool,
    /// Running total of data bytes yielded from a `Body::Stream`.
    bytes_streamed: usize,
    /// Trailers taken from the stream, held until `poll_trailers` returns them.
    pending_trailers: Option<HeaderMap>,
    /// Whether the body has been fully consumed; this is what `is_end_stream` reports.
    eof: bool,
}
impl From<Body> for BodyWrapper {
    /// Wraps `body` in a fresh adapter: nothing emitted yet, no trailers pending, and not at
    /// end of stream.
    fn from(body: Body) -> Self {
        let (has_written_blob, eof) = (false, false);
        Self {
            eof,
            pending_trailers: None,
            bytes_streamed: 0,
            has_written_blob,
            body,
        }
    }
}
/// Unwraps the adapter and returns the underlying [`Body`] in its current state.
///
/// Implemented as `From` rather than `Into` so that both `Body::from(wrapper)` and
/// `wrapper.into()` work; the latter comes from the standard blanket impl. Any trailers the
/// wrapper has already pulled from the stream, but not yet returned, are dropped.
impl From<BodyWrapper> for Body {
    fn from(wrapper: BodyWrapper) -> Self {
        wrapper.body
    }
}
impl Body {
pub fn new() -> Self {
Self::default()
}
pub fn empty() -> Self {
Self::default()
}
pub fn bytes_and_trailers(bytes: Vec<u8>, trailers: HeaderMap) -> Self {
Body::Stream {
size_hint: Some(bytes.len()),
stream: Box::pin(futures::stream::iter([
Ok(BodyComponent::Data(bytes.into())),
Ok(BodyComponent::Trailers(trailers)),
])),
}
}
pub fn trailers(trailers: HeaderMap) -> Self {
Body::Stream {
size_hint: Some(0),
stream: Box::pin(futures::stream::once(async move {
Ok(BodyComponent::Trailers(trailers))
})),
}
}
pub async fn collect(self) -> Result<Vec<u8>, anyhow::Error> {
match self {
Body::Bytes(x) => Ok(x),
Body::Stream {
size_hint,
mut stream,
} => {
let mut out = Vec::with_capacity(size_hint.unwrap_or_default());
while let Some(component) = stream.next().await.transpose()? {
match component {
BodyComponent::Data(data) => {
out.extend_from_slice(&data[..]);
}
BodyComponent::Trailers(_) => (),
}
}
Ok(out)
}
}
}
pub fn into_stream(
self,
) -> Pin<Box<dyn Stream<Item = Result<BodyComponent, anyhow::Error>> + Send + Sync + 'static>>
{
match self {
Body::Bytes(bytes) => Box::pin(futures::stream::once(async move {
Ok(BodyComponent::Data(bytes.into()))
})),
Body::Stream {
size_hint: _,
stream,
} => stream,
}
}
}
/// A byte vector becomes a fully buffered body.
///
/// Implemented as `From` (not `Into`) so that `Body::from(vec)` also works; `vec.into()` still
/// works through the standard blanket impl.
impl From<Vec<u8>> for Body {
    fn from(bytes: Vec<u8>) -> Self {
        Body::Bytes(bytes)
    }
}
/// A `String` becomes a fully buffered body containing its UTF-8 bytes. Ownership of the
/// allocation is reused, so no copy is made.
///
/// Implemented as `From` (not `Into`); `string.into()` still works through the blanket impl.
impl From<String> for Body {
    fn from(text: String) -> Self {
        Body::Bytes(text.into_bytes())
    }
}
/// A string slice becomes a fully buffered body holding a copy of its UTF-8 bytes.
///
/// Implemented as `From` (not `Into`); `"...".into()` still works through the blanket impl.
impl From<&str> for Body {
    fn from(text: &str) -> Self {
        Body::Bytes(text.as_bytes().to_vec())
    }
}
/// `()` becomes an empty, fully buffered body.
///
/// Implemented as `From` (not `Into`); `().into()` still works through the blanket impl.
impl From<()> for Body {
    fn from(_: ()) -> Self {
        Body::Bytes(Vec::new())
    }
}
impl From<BodyStream> for Body {
fn from(stream: BodyStream) -> Self {
Self::Stream {
size_hint: None,
stream,
}
}
}
/// Adapts a plain stream of byte chunks into a streaming body.
///
/// Each `Ok` chunk becomes a [`BodyComponent::Data`] item and errors pass through unchanged.
/// The resulting body has no trailers and no size hint.
impl From<Pin<Box<dyn Stream<Item = Result<Bytes, anyhow::Error>> + Send + Sync + 'static>>>
    for Body
{
    fn from(
        value: Pin<Box<dyn Stream<Item = Result<Bytes, anyhow::Error>> + Send + Sync + 'static>>,
    ) -> Self {
        Self::Stream {
            size_hint: None,
            // The tuple-variant constructor is itself a function, so no closure is needed.
            stream: Box::pin(value.map_ok(BodyComponent::Data)),
        }
    }
}
impl fmt::Debug for Body {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Bytes(arg0) => f.debug_tuple("Bytes").field(arg0).finish(),
Self::Stream { size_hint, .. } => f.debug_tuple("Stream").field(size_hint).finish(),
}
}
}
impl Default for Body {
fn default() -> Self {
Body::Bytes(vec![])
}
}
impl http_body::Body for BodyWrapper {
type Data = Bytes;
type Error = anyhow::Error;
fn poll_data(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
match &mut self.body {
Body::Bytes(bytes) => {
let bytes = std::mem::take(bytes);
if self.has_written_blob {
return Poll::Ready(None);
}
self.eof = true;
self.has_written_blob = true;
Poll::Ready(Some(Ok(bytes.into())))
}
Body::Stream {
size_hint: _,
stream,
} => match stream.poll_next_unpin(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => {
self.eof = true;
Poll::Ready(None)
}
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
Poll::Ready(Some(Ok(BodyComponent::Data(data)))) => {
self.bytes_streamed += data.len();
Poll::Ready(Some(Ok(data)))
}
Poll::Ready(Some(Ok(BodyComponent::Trailers(trailers)))) => {
self.pending_trailers = Some(trailers);
Poll::Ready(None)
}
},
}
}
fn poll_trailers(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<Option<headers::HeaderMap>, Self::Error>> {
self.eof = true;
if let Some(pending_trailers) = self.pending_trailers.take() {
Poll::Ready(Ok(Some(pending_trailers.into())))
} else {
Poll::Ready(Ok(None))
}
}
fn is_end_stream(&self) -> bool {
self.eof
}
fn size_hint(&self) -> SizeHint {
if self.eof {
return SizeHint::with_exact(0);
}
match &self.body {
Body::Bytes(bytes) => SizeHint::with_exact(bytes.len() as u64),
Body::Stream {
size_hint,
stream: _,
} => size_hint
.map(|x| SizeHint::with_exact(x as u64))
.unwrap_or_default(),
}
}
}