use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::Bytes;
use http_body::{Body, Frame, SizeHint};
use http_body_util::combinators::UnsyncBoxBody;
use http_body_util::BodyExt;
use http_body_util::Full;
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
pub type ReqBody = UnsyncBoxBody<Bytes, BoxError>;
type BoxStreamBody = Pin<Box<dyn Body<Data = Bytes, Error = BoxError> + Send>>;
pub struct RespBody {
kind: BodyKind,
}
/// Internal representation of a [`RespBody`].
enum BodyKind {
    /// A fully buffered payload held in `http_body_util::Full`.
    Full(Full<Bytes>),
    /// A boxed streaming body whose frames and errors are forwarded unchanged.
    Stream(BoxStreamBody),
}
impl RespBody {
pub fn new(body: Bytes) -> Self {
Self {
kind: BodyKind::Full(Full::new(body)),
}
}
pub fn stream<B>(body: B) -> Self
where
B: Body<Data = Bytes, Error = BoxError> + Send + 'static,
{
Self {
kind: BodyKind::Stream(Box::pin(body)),
}
}
pub fn stream_capped<B>(body: B, max_bytes: u64) -> Self
where
B: Body<Data = Bytes, Error = BoxError> + Send + 'static,
{
Self {
kind: BodyKind::Stream(Box::pin(CappedBody {
inner: Box::pin(body),
emitted: 0,
limit: max_bytes,
})),
}
}
}
/// Streaming body adapter that errors once the running total of data-frame bytes
/// pulled from `inner` exceeds `limit`.
struct CappedBody {
    /// The wrapped body that frames are pulled from.
    inner: BoxStreamBody,
    /// Total data-frame bytes seen from `inner` so far (saturating add), including the
    /// frame that crossed the limit, if any.
    emitted: u64,
    /// Maximum total data bytes allowed; a frame that pushes `emitted` above it errors.
    limit: u64,
}
impl Body for CappedBody {
type Data = Bytes;
type Error = BoxError;
fn poll_frame(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
let this = self.get_mut();
match this.inner.as_mut().poll_frame(cx) {
Poll::Ready(Some(Ok(frame))) => {
if let Some(data) = frame.data_ref() {
this.emitted = this.emitted.saturating_add(data.len() as u64);
if this.emitted > this.limit {
return Poll::Ready(Some(Err(format!(
"response body exceeded the {}-byte limit",
this.limit
)
.into())));
}
}
Poll::Ready(Some(Ok(frame)))
}
other => other,
}
}
fn is_end_stream(&self) -> bool {
self.inner.is_end_stream()
}
fn size_hint(&self) -> SizeHint {
self.inner.size_hint()
}
}
impl Body for RespBody {
    type Data = Bytes;
    type Error = BoxError;

    /// Polls whichever variant is active.
    ///
    /// Stream frames and errors pass through untouched. The buffered variant's error type
    /// is uninhabited, so its error branch is converted with an empty `match`.
    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        let this = self.get_mut();
        match this.kind {
            BodyKind::Stream(ref mut stream) => stream.as_mut().poll_frame(cx),
            BodyKind::Full(ref mut full) => {
                let polled = Pin::new(full).poll_frame(cx);
                polled.map_err(|unreachable| match unreachable {})
            }
        }
    }

    /// Delegates to the active variant.
    fn is_end_stream(&self) -> bool {
        match self.kind {
            BodyKind::Stream(ref stream) => stream.is_end_stream(),
            BodyKind::Full(ref full) => full.is_end_stream(),
        }
    }

    /// Delegates to the active variant.
    fn size_hint(&self) -> SizeHint {
        match self.kind {
            BodyKind::Stream(ref stream) => stream.size_hint(),
            BodyKind::Full(ref full) => full.size_hint(),
        }
    }
}
pub fn box_body<B>(body: B) -> ReqBody
where
B: hyper::body::Body<Data = Bytes> + Send + 'static,
B::Error: Into<BoxError>,
{
body.map_err(Into::into).boxed_unsync()
}
#[cfg(test)]
mod tests {
    use super::*;
    use http_body_util::StreamBody;

    /// Buffers the entire body and returns its data concatenated into one `Bytes`.
    async fn collect_bytes(body: RespBody) -> Bytes {
        body.collect().await.expect("body collects").to_bytes()
    }

    /// Polls `body` one frame at a time until it ends or errors.
    ///
    /// Returns the data payloads delivered, in order (non-data frames are skipped),
    /// together with the first error, if any. Polling stops at that error.
    async fn drain_frames(mut body: RespBody) -> (Vec<Bytes>, Option<BoxError>) {
        let mut chunks = Vec::new();
        loop {
            let next = std::future::poll_fn(|cx| Pin::new(&mut body).poll_frame(cx)).await;
            match next {
                Some(Ok(frame)) => {
                    if let Ok(data) = frame.into_data() {
                        chunks.push(data);
                    }
                }
                Some(Err(error)) => return (chunks, Some(error)),
                None => return (chunks, None),
            }
        }
    }

    #[tokio::test]
    async fn full_body_yields_its_buffer() {
        let body = RespBody::new(Bytes::from_static(b"hello"));
        assert_eq!(collect_bytes(body).await, Bytes::from_static(b"hello"));
    }

    #[tokio::test]
    async fn streaming_body_yields_each_frame() {
        let frames = futures_util::stream::iter(vec![
            Ok::<_, BoxError>(Frame::data(Bytes::from_static(b"a"))),
            Ok(Frame::data(Bytes::from_static(b"b"))),
            Ok(Frame::data(Bytes::from_static(b"c"))),
        ]);
        let (chunks, error) = drain_frames(RespBody::stream(StreamBody::new(frames))).await;
        if let Some(error) = error {
            panic!("unexpected body error: {error}");
        }
        assert_eq!(
            chunks,
            vec![
                Bytes::from_static(b"a"),
                Bytes::from_static(b"b"),
                Bytes::from_static(b"c"),
            ]
        );
    }

    #[tokio::test]
    async fn capped_stream_errors_once_it_exceeds_the_limit() {
        let frames = futures_util::stream::iter(vec![
            Ok::<_, BoxError>(Frame::data(Bytes::from_static(b"aaaa"))),
            Ok(Frame::data(Bytes::from_static(b"bbbb"))),
            Ok(Frame::data(Bytes::from_static(b"cccc"))),
        ]);
        let body = RespBody::stream_capped(StreamBody::new(frames), 10);
        let (chunks, error) = drain_frames(body).await;
        let delivered: usize = chunks.iter().map(Bytes::len).sum();
        assert!(error.is_some(), "the body should error once it exceeds the cap");
        assert_eq!(delivered, 8, "only the frames within the cap are delivered");
    }
}