use std::{pin::Pin, task::Poll};
use bytes::Buf;
use futures_core::Stream;
use sync_wrapper::SyncWrapper;
pin_project_lite::pin_project! {
/// An HTTP body whose data chunks have type `D` (by default [`bytes::Bytes`]).
///
/// All of its behavior is delegated to an internal stream that produces the
/// data either as a single in-memory chunk, as a stream checked against a
/// declared byte length, or as a multipart stream.
pub struct Body<D = bytes::Bytes> {
// Source of the body's data chunks. Marked `#[pin]` so it can be polled
// through the pin projection in `poll_frame`.
#[pin]
pub(crate) stream: BodyStream<D>,
}
}
impl<D> http_body::Body for Body<D>
where
    D: Buf + From<Vec<u8>> + From<&'static [u8]> + 'static,
{
    type Data = D;
    type Error = crate::IOError;

    /// Polls the inner stream and wraps every successful chunk in a data
    /// frame; this body never produces trailer frames.
    #[inline]
    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
        let stream = self.project().stream;
        stream
            .poll_next(cx)
            .map(|next| next.map(|item| item.map(http_body::Frame::data)))
    }

    /// Reports the number of bytes the body still has to yield as an exact
    /// size hint (zero for an empty or already-consumed single chunk).
    fn size_hint(&self) -> http_body::SizeHint {
        let exact = match &self.stream {
            BodyStream::Once {
                chunk: Some(Ok(data)),
            } => u64::try_from(data.remaining()).expect("usize should fit in u64"),
            BodyStream::Once { .. } => 0,
            BodyStream::ExactLen { s } => s.remaining,
            BodyStream::Multipart { s } => s.remaining(),
        };
        http_body::SizeHint::with_exact(exact)
    }

    /// A single-chunk body has ended once its chunk has been taken; the
    /// streaming variants have ended once no bytes remain.
    fn is_end_stream(&self) -> bool {
        match &self.stream {
            BodyStream::ExactLen { s } => s.remaining == 0,
            BodyStream::Multipart { s } => s.remaining() == 0,
            BodyStream::Once { chunk } => chunk.is_none(),
        }
    }
}
impl<D> Body<D> {
    /// Creates a body that carries no data at all.
    #[inline]
    pub fn empty() -> Self {
        let stream = BodyStream::Once { chunk: None };
        Self { stream }
    }
}
impl<D> From<&'static [u8]> for Body<D>
where
    D: From<&'static [u8]>,
{
    /// Wraps a static byte slice as a body consisting of one chunk.
    #[inline]
    fn from(value: &'static [u8]) -> Self {
        let data = D::from(value);
        Self {
            stream: BodyStream::Once {
                chunk: Some(Ok(data)),
            },
        }
    }
}
impl<D> From<&'static str> for Body<D>
where
    D: From<&'static [u8]>,
{
    /// Wraps a static string's UTF-8 bytes as a body consisting of one chunk.
    #[inline]
    fn from(value: &'static str) -> Self {
        let bytes: &'static [u8] = value.as_bytes();
        let chunk = Some(Ok(D::from(bytes)));
        Self {
            stream: BodyStream::Once { chunk },
        }
    }
}
impl<D> From<Vec<u8>> for Body<D>
where
    D: From<Vec<u8>>,
{
    /// Takes ownership of a byte vector and uses it as a single-chunk body.
    #[inline]
    fn from(value: Vec<u8>) -> Self {
        let chunk = Some(Ok(D::from(value)));
        Self {
            stream: BodyStream::Once { chunk },
        }
    }
}
impl<D> From<String> for Body<D>
where
    D: From<Vec<u8>>,
{
    /// Takes ownership of a string's bytes and uses them as a single-chunk body.
    #[inline]
    fn from(value: String) -> Self {
        let bytes: Vec<u8> = value.into_bytes();
        Self {
            stream: BodyStream::Once {
                chunk: Some(Ok(D::from(bytes))),
            },
        }
    }
}
pin_project_lite::pin_project! {
/// The source of a [`Body`]'s data chunks.
#[project = BodyStreamProj]
pub(crate) enum BodyStream<D> {
// A single pre-computed chunk (or error), handed out on the first poll.
// `None` once it has been taken, or from the start for an empty body.
Once {
chunk: Option<Result<D, crate::IOError>>,
},
// A stream whose total byte count is checked against a declared length.
ExactLen {
#[pin]
s: ExactLenStream<D>,
},
// A multipart stream provided by `crate::serving`.
Multipart {
#[pin]
s: crate::serving::MultipartStream<D>,
},
}
}
impl<D> Stream for BodyStream<D>
where
    D: 'static + Buf + From<Vec<u8>> + From<&'static [u8]>,
{
    type Item = Result<D, crate::IOError>;

    /// Yields the next chunk from whichever source backs this body.
    ///
    /// The `Once` variant is always ready: it returns its chunk on the first
    /// poll and `None` on every poll after that.
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Result<D, crate::IOError>>> {
        let projected = self.project();
        match projected {
            BodyStreamProj::ExactLen { s } => s.poll_next(cx),
            BodyStreamProj::Multipart { s } => s.poll_next(cx),
            BodyStreamProj::Once { chunk } => {
                let next = chunk.take();
                Poll::Ready(next)
            }
        }
    }
}
/// Error reported when a length-checked stream ends before producing the
/// number of bytes it declared.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct StreamTooShortError {
    /// How many of the declared bytes were never produced.
    remaining: u64,
}

impl std::fmt::Display for StreamTooShortError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { remaining } = *self;
        write!(f, "stream ended with {} bytes still expected", remaining)
    }
}

impl std::error::Error for StreamTooShortError {}
/// Error reported when a length-checked stream produces more bytes than it
/// declared.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct StreamTooLongError {
    /// How many bytes the offending chunk carried beyond the declared length.
    extra: u64,
}

impl std::fmt::Display for StreamTooLongError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { extra } = *self;
        write!(f, "stream returned (at least) {} bytes more than expected", extra)
    }
}

impl std::error::Error for StreamTooLongError {}
/// A stream adapter that enforces an exact total byte length on an inner
/// stream of [`Buf`] chunks.
///
/// Chunks pass through unchanged while the running total stays within the
/// declared length. If the inner stream yields more bytes than declared, the
/// offending chunk is replaced by a [`StreamTooLongError`]. If the inner stream
/// ends before the declared length is reached, one final [`StreamTooShortError`]
/// is yielded. Errors produced by the inner stream itself are passed through.
///
/// The inner stream is dropped once it has finished or once a length violation
/// has been reported. Every later poll returns `Poll::Ready(None)`, and the
/// inner stream is never polled again.
pub(crate) struct ExactLenStream<D> {
    /// The wrapped stream, or `None` once it has finished or a length violation
    /// was reported. The `Stream` contract allows a stream to panic or
    /// misbehave if it is polled again after returning `Ready(None)`, so it
    /// must not be touched after that point. `SyncWrapper` makes this field
    /// `Sync` even though the boxed stream is only `Send`.
    // The boxed trait-object type is inherently verbose; an alias would not
    // make it clearer.
    #[allow(clippy::type_complexity)]
    stream: SyncWrapper<Option<Pin<Box<dyn Stream<Item = Result<D, crate::IOError>> + Send>>>>,
    /// Number of bytes still expected from the inner stream.
    remaining: u64,
}

impl<D> ExactLenStream<D> {
    /// Wraps `stream`, which is expected to yield exactly `len` bytes in total.
    pub(crate) fn new(
        len: u64,
        stream: Pin<Box<dyn Stream<Item = Result<D, crate::IOError>> + Send>>,
    ) -> Self {
        Self {
            stream: SyncWrapper::new(Some(stream)),
            remaining: len,
        }
    }
}

impl<D> futures_core::Stream for ExactLenStream<D>
where
    D: Buf,
{
    type Item = Result<D, crate::IOError>;

    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Result<D, crate::IOError>>> {
        let this = Pin::into_inner(self);
        let slot = this.stream.get_mut();
        // Already finished (or failed a length check): never re-poll the inner stream.
        let Some(inner) = slot.as_mut() else {
            return Poll::Ready(None);
        };
        let polled = inner.as_mut().poll_next(cx);
        match polled {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
            Poll::Ready(Some(Ok(chunk))) => {
                let chunk_len = crate::as_u64(chunk.remaining());
                if let Some(new_remaining) = this.remaining.checked_sub(chunk_len) {
                    this.remaining = new_remaining;
                    return Poll::Ready(Some(Ok(chunk)));
                }
                // `checked_sub` failed, so `chunk_len > remaining`; this cannot underflow.
                let extra = chunk_len - std::mem::take(&mut this.remaining);
                // The length contract is already broken. Stop reading so that
                // later chunks are not reported as further, misleading overruns.
                *slot = None;
                Poll::Ready(Some(Err(crate::IOError::other(StreamTooLongError {
                    extra,
                }))))
            }
            Poll::Ready(None) => {
                // The inner stream is exhausted; drop it so it is never polled again.
                *slot = None;
                match std::mem::take(&mut this.remaining) {
                    0 => Poll::Ready(None),
                    remaining => Poll::Ready(Some(Err(crate::IOError::other(
                        StreamTooShortError { remaining },
                    )))),
                }
            }
        }
    }
}
// Compile-time check that the default `Body` (i.e. `Body<bytes::Bytes>`) is
// both `Send` and `Sync`: the crate fails to build if either bound stops
// holding. `_assert` is never called; type-checking it is enough.
const _: () = {
fn _assert() {
fn assert_bounds<T: Sync + Send>() {}
assert_bounds::<Body>();
}
};
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use futures_util::StreamExt as _;

    use super::*;

    /// Builds an `ExactLenStream` that declares `len` bytes and is backed by
    /// the given static chunks, each yielded as one `Ok` item.
    fn exact_len_over(len: u64, chunks: Vec<&'static str>) -> ExactLenStream<Bytes> {
        let inner = futures_util::stream::iter(
            chunks
                .into_iter()
                .map(|chunk| Ok::<Bytes, crate::IOError>(Bytes::from(chunk))),
        );
        ExactLenStream::new(len, Box::pin(inner))
    }

    #[tokio::test]
    async fn correct_exact_len_stream() {
        let mut stream = std::pin::pin!(exact_len_over(5, vec!["h", "ello"]));
        assert_eq!(stream.remaining, 5);

        let first = stream.next().await.unwrap().unwrap();
        assert_eq!(first.remaining(), 1);
        assert_eq!(stream.remaining, 4);

        let second = stream.next().await.unwrap().unwrap();
        assert_eq!(second.remaining(), 4);
        assert_eq!(stream.remaining, 0);

        // Polling past the end keeps returning `None`.
        assert!(stream.next().await.is_none());
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn short_exact_len_stream() {
        let mut stream = std::pin::pin!(exact_len_over(10, vec!["hello"]));
        assert_eq!(stream.remaining, 10);

        let chunk = stream.next().await.unwrap().unwrap();
        assert_eq!(chunk.remaining(), 5);
        assert_eq!(stream.remaining, 5);

        let err: crate::IOError = stream.next().await.unwrap().unwrap_err();
        let err = err.downcast::<StreamTooShortError>().unwrap();
        assert_eq!(err, StreamTooShortError { remaining: 5 });

        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn long_exact_len_stream() {
        let mut stream = std::pin::pin!(exact_len_over(3, vec!["h", "ello"]));
        assert_eq!(stream.remaining, 3);

        let chunk = stream.next().await.unwrap().unwrap();
        assert_eq!(chunk.remaining(), 1);
        assert_eq!(stream.remaining, 2);

        let err = stream.next().await.unwrap().unwrap_err();
        let err = err.downcast::<StreamTooLongError>().unwrap();
        assert_eq!(err, StreamTooLongError { extra: 2 });

        assert!(stream.next().await.is_none());
    }
}