use bytes::{Buf, Bytes, BytesMut};
use http_body::{Body as HttpBody, Frame};
use pin_project::pin_project;
use std::fmt;
use std::pin::Pin;
use std::task::{Context, Poll};
/// The part of a body that has not been buffered yet.
///
/// Either the live inner body, or an error that was captured while reading it.
#[pin_project(project = RemainingProj)]
#[derive(Debug)]
pub enum Remaining<B: HttpBody> {
    /// The underlying body, still streaming.
    Body(#[pin] B),
    /// An error captured from the underlying body.
    ///
    /// `PartialBufferedBody::poll_frame` takes the error out when it yields it,
    /// so `None` means the error was already reported and the stream is over.
    Error(Option<B::Error>),
}
/// A body that yields an already-buffered prefix before the rest of the
/// underlying stream (or a stored error).
#[pin_project]
pub struct PartialBufferedBody<B: HttpBody> {
    /// Bytes read ahead of time; taken (set to `None`) once yielded.
    prefix: Option<Bytes>,
    /// What follows the prefix.
    #[pin]
    remaining: Remaining<B>,
}
impl<B> PartialBufferedBody<B>
where
B: HttpBody,
{
pub fn new(prefix: Option<Bytes>, remaining: Remaining<B>) -> Self {
Self { prefix, remaining }
}
pub fn prefix(&self) -> Option<&Bytes> {
self.prefix.as_ref()
}
pub fn into_parts(self) -> (Option<Bytes>, Remaining<B>) {
(self.prefix, self.remaining)
}
}
impl<B: HttpBody> HttpBody for PartialBufferedBody<B> {
type Data = Bytes;
type Error = B::Error;
fn poll_frame(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
let this = self.project();
if let Some(prefix) = this.prefix.take() {
return Poll::Ready(Some(Ok(Frame::data(prefix))));
}
match this.remaining.project() {
RemainingProj::Body(body) => match body.poll_frame(cx) {
Poll::Ready(Some(Ok(frame))) => {
let frame = frame.map_data(|mut data| data.copy_to_bytes(data.remaining()));
Poll::Ready(Some(Ok(frame)))
}
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
},
RemainingProj::Error(error) => {
if let Some(err) = error.take() {
Poll::Ready(Some(Err(err)))
} else {
Poll::Ready(None)
}
}
}
}
fn size_hint(&self) -> http_body::SizeHint {
let prefix_len = self.prefix.as_ref().map(|b| b.len() as u64).unwrap_or(0);
match &self.remaining {
Remaining::Body(body) => {
let hint = body.size_hint();
let lower = hint.lower().saturating_add(prefix_len);
let upper = hint.upper().map(|u| {
u.saturating_add(prefix_len).max(lower)
});
let mut result = http_body::SizeHint::new();
result.set_lower(lower);
if let Some(u) = upper {
result.set_upper(u);
}
result
}
Remaining::Error(_) => http_body::SizeHint::with_exact(prefix_len),
}
}
fn is_end_stream(&self) -> bool {
if self.prefix.is_some() {
return false;
}
match &self.remaining {
Remaining::Body(body) => body.is_end_stream(),
Remaining::Error(err) => err.is_none(),
}
}
}
/// A request/response body in one of three buffering states.
#[pin_project(project = BufferedBodyProj)]
pub enum BufferedBody<B: HttpBody> {
    /// The whole body is in memory; becomes `None` after it has been yielded.
    Complete(Option<Bytes>),
    /// Some bytes were buffered; the rest is still streaming (or failed).
    Partial(#[pin] PartialBufferedBody<B>),
    /// Nothing was buffered; frames are forwarded from the inner body.
    Passthrough(#[pin] B),
}
impl<B: HttpBody> HttpBody for BufferedBody<B> {
    type Data = Bytes;
    type Error = B::Error;

    /// Polls the next frame according to the current buffering state.
    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        match self.project() {
            // Fully buffered: emit the payload once, then signal the end.
            BufferedBodyProj::Complete(slot) => {
                Poll::Ready(slot.take().map(|payload| Ok(Frame::data(payload))))
            }
            BufferedBodyProj::Partial(inner) => inner.poll_frame(cx),
            // Unbuffered: forward frames, converting data chunks into `Bytes`.
            BufferedBodyProj::Passthrough(inner) => inner.poll_frame(cx).map(|next| {
                next.map(|res| {
                    res.map(|frame| frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining())))
                })
            }),
        }
    }

    /// An exact hint for in-memory payloads; otherwise delegated.
    fn size_hint(&self) -> http_body::SizeHint {
        match self {
            BufferedBody::Complete(slot) => {
                http_body::SizeHint::with_exact(slot.as_ref().map_or(0, |b| b.len() as u64))
            }
            BufferedBody::Partial(inner) => inner.size_hint(),
            BufferedBody::Passthrough(inner) => inner.size_hint(),
        }
    }

    /// A fully-buffered body ends once its payload has been taken.
    fn is_end_stream(&self) -> bool {
        match self {
            BufferedBody::Complete(slot) => slot.is_none(),
            BufferedBody::Partial(inner) => inner.is_end_stream(),
            BufferedBody::Passthrough(inner) => inner.is_end_stream(),
        }
    }
}
/// Outcome of `BufferedBody::collect_exact`.
#[derive(Debug)]
pub enum CollectExactResult<B>
where
    B: HttpBody,
{
    /// At least the requested number of bytes were buffered.
    AtLeast {
        /// Everything buffered so far. This may be longer than requested,
        /// because whole frames / whole payloads are kept.
        buffered: Bytes,
        /// The unread remainder of the body, or `None` when nothing is left.
        remaining: Option<Remaining<B>>,
    },
    /// The body ended, or failed, before enough bytes were buffered.
    Incomplete {
        /// Whatever was buffered, or `None` when nothing was.
        buffered: Option<Bytes>,
        /// The error that stopped reading, if any.
        error: Option<B::Error>,
    },
}
impl<B> CollectExactResult<B>
where
    B: HttpBody,
{
    /// Turns the collection outcome back into a body that replays what was
    /// read, followed by whatever was left.
    ///
    /// - `AtLeast` with a remainder becomes `Partial` (bytes, then remainder).
    /// - `AtLeast` without a remainder becomes `Complete`.
    /// - `Incomplete` with an error becomes `Partial` (bytes, then the error).
    /// - `Incomplete` without an error becomes `Complete`.
    pub fn into_buffered_body(self) -> BufferedBody<B> {
        match self {
            CollectExactResult::AtLeast {
                buffered,
                remaining: Some(rest),
            } => BufferedBody::Partial(PartialBufferedBody::new(Some(buffered), rest)),
            CollectExactResult::AtLeast {
                buffered,
                remaining: None,
            } => BufferedBody::Complete(Some(buffered)),
            CollectExactResult::Incomplete {
                buffered,
                error: Some(err),
            } => BufferedBody::Partial(PartialBufferedBody::new(
                buffered,
                Remaining::Error(Some(err)),
            )),
            CollectExactResult::Incomplete {
                buffered,
                error: None,
            } => BufferedBody::Complete(buffered),
        }
    }
}
/// Concatenates an optional buffered `prefix` with freshly read `data`.
///
/// No copy is made when one side is empty. An absent or empty prefix returns
/// `data` as-is, and empty `data` returns the prefix as-is. Otherwise both
/// halves are copied into one buffer sized up front, so the concatenation
/// costs a single allocation instead of an allocate-then-grow.
fn combine_bytes(prefix: Option<Bytes>, data: Bytes) -> Bytes {
    match prefix {
        None => data,
        Some(head) if head.is_empty() => data,
        Some(head) if data.is_empty() => head,
        Some(head) => {
            let mut joined = BytesMut::with_capacity(head.len() + data.len());
            joined.extend_from_slice(&head);
            joined.extend_from_slice(&data);
            joined.freeze()
        }
    }
}
impl<B> BufferedBody<B>
where
B: HttpBody,
{
pub async fn collect(self) -> Result<Bytes, Self>
where
B::Data: Send,
{
use http_body_util::BodyExt;
match self {
BufferedBody::Complete(Some(bytes)) => Ok(bytes),
BufferedBody::Complete(None) => Ok(Bytes::new()),
BufferedBody::Passthrough(body) => match body.collect().await {
Ok(collected) => Ok(collected.to_bytes()),
Err(err) => Err(BufferedBody::Partial(PartialBufferedBody::new(
None,
Remaining::Error(Some(err)),
))),
},
BufferedBody::Partial(partial) => {
let (prefix, remaining) = partial.into_parts();
match remaining {
Remaining::Body(body) => match body.collect().await {
Ok(collected) => {
if let Some(prefix_bytes) = prefix {
let mut combined = BytesMut::from(prefix_bytes.as_ref());
combined.extend_from_slice(&collected.to_bytes());
Ok(combined.freeze())
} else {
Ok(collected.to_bytes())
}
}
Err(err) => Err(BufferedBody::Partial(PartialBufferedBody::new(
prefix,
Remaining::Error(Some(err)),
))),
},
Remaining::Error(err) => Err(BufferedBody::Partial(PartialBufferedBody::new(
prefix,
Remaining::Error(err),
))),
}
}
}
}
pub async fn collect_exact(self, limit_bytes: usize) -> CollectExactResult<B>
where
B: Unpin,
{
match self {
BufferedBody::Complete(Some(data)) => {
if data.len() >= limit_bytes {
CollectExactResult::AtLeast {
buffered: data,
remaining: None,
}
} else {
CollectExactResult::Incomplete {
buffered: Some(data),
error: None,
}
}
}
BufferedBody::Complete(None) => {
CollectExactResult::Incomplete {
buffered: None,
error: None,
}
}
BufferedBody::Partial(partial) => {
let (prefix, remaining) = partial.into_parts();
match prefix {
Some(buffered) if buffered.len() >= limit_bytes => {
CollectExactResult::AtLeast {
buffered,
remaining: Some(remaining),
}
}
prefix => {
let prefix_len = prefix.as_ref().map(|p| p.len()).unwrap_or(0);
match remaining {
Remaining::Body(stream) => {
let needed = limit_bytes - prefix_len;
let result = collect_exact_from_stream(stream, needed).await;
match result {
CollectExactResult::AtLeast {
buffered: new_bytes,
remaining,
} => {
let combined = combine_bytes(prefix, new_bytes);
CollectExactResult::AtLeast {
buffered: combined,
remaining,
}
}
CollectExactResult::Incomplete {
buffered: new_bytes,
error,
} => {
let combined = if let Some(new) = new_bytes {
Some(combine_bytes(prefix, new))
} else {
prefix
};
CollectExactResult::Incomplete {
buffered: combined,
error,
}
}
}
}
Remaining::Error(error) => {
CollectExactResult::Incomplete {
buffered: prefix,
error,
}
}
}
}
}
}
BufferedBody::Passthrough(stream) => {
collect_exact_from_stream(stream, limit_bytes).await
}
}
}
}
async fn collect_exact_from_stream<B>(mut stream: B, limit_bytes: usize) -> CollectExactResult<B>
where
B: HttpBody + Unpin,
{
use http_body_util::BodyExt;
let mut buffer = BytesMut::new();
while buffer.len() < limit_bytes {
match stream.frame().await {
Some(Ok(frame)) => {
if let Ok(mut data) = frame.into_data() {
buffer.extend_from_slice(&data.copy_to_bytes(data.remaining()));
}
}
Some(Err(error)) => {
return CollectExactResult::Incomplete {
buffered: if buffer.is_empty() {
None
} else {
Some(buffer.freeze())
},
error: Some(error),
};
}
None => {
return CollectExactResult::Incomplete {
buffered: if buffer.is_empty() {
None
} else {
Some(buffer.freeze())
},
error: None,
};
}
}
}
CollectExactResult::AtLeast {
buffered: buffer.freeze(),
remaining: Some(Remaining::Body(stream)),
}
}
impl<B: HttpBody> fmt::Debug for BufferedBody<B> {
    /// Summarises the body's state without printing payload bytes or the
    /// inner body (which need not implement `Debug`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            BufferedBody::Complete(payload) => {
                let mut tuple = f.debug_tuple("Complete");
                match payload {
                    Some(bytes) => tuple.field(&format!("{} bytes", bytes.len())),
                    None => tuple.field(&"consumed"),
                };
                tuple.finish()
            }
            BufferedBody::Partial(inner) => f
                .debug_struct("Partial")
                .field("prefix_len", &inner.prefix().map_or(0, |p| p.len()))
                .field("remaining", &"...")
                .finish(),
            BufferedBody::Passthrough(_) => f.debug_tuple("Passthrough").field(&"...").finish(),
        }
    }
}