use bytes::{Buf, Bytes};
use http_body::{Body, Frame, SizeHint};
use http_body_util::Full;
use pin_project_lite::pin_project;
use std::{
pin::Pin,
task::{Context, Poll},
};
use tower::BoxError;
pin_project! {
    // Internal state of `ErrorBody`: either the wrapped body `B` or a buffered,
    // already-rendered payload. `#[project = ...]` names the pinned projection
    // type that `poll_frame` matches on.
    #[project = ErrorBodyInnerProj]
    enum ErrorBodyInner<B> {
        // Frames come from the wrapped body; `poll_frame` copies each data
        // chunk into `Bytes` and converts errors into `BoxError`.
        Passthrough {
            #[pin]
            inner: B,
        },
        // Frames come from a `Full<Bytes>` built by `ErrorBody::error` or by
        // `Default`; its error type is uninhabited, so polling it never fails.
        Rendered {
            #[pin]
            inner: Full<Bytes>,
        },
    }
}
pin_project! {
    /// HTTP body that either forwards a wrapped body `B` or yields a
    /// pre-rendered error payload.
    ///
    /// When `B: Body`, this type implements `Body` with `Data = Bytes` and
    /// `Error = BoxError`, regardless of which source it reads from.
    pub struct ErrorBody<B> {
        // Which of the two sources this body reads from; pinned so the
        // wrapped body can be polled in place.
        #[pin]
        inner: ErrorBodyInner<B>,
    }
}
impl<B> ErrorBody<B> {
    /// Wraps `body` so that its frames are forwarded when this body is polled.
    pub(crate) fn passthrough(body: B) -> Self {
        let state = ErrorBodyInner::Passthrough { inner: body };
        Self { inner: state }
    }

    /// Builds a body whose content is exactly `bytes`, backed by a `Full`
    /// body instead of the wrapped `B`.
    pub(crate) fn error(bytes: Bytes) -> Self {
        let rendered = Full::new(bytes);
        let state = ErrorBodyInner::Rendered { inner: rendered };
        Self { inner: state }
    }
}
impl<B> Body for ErrorBody<B>
where
    B: Body,
    B::Data: Buf,
    B::Error: Into<BoxError>,
{
    type Data = Bytes;
    type Error = BoxError;

    /// Polls whichever source backs this body for its next frame.
    ///
    /// For a passthrough body, each data chunk is copied into [`Bytes`] and
    /// errors are converted into [`BoxError`]. For a rendered body, frames are
    /// forwarded unchanged; its error type is uninhabited and is eliminated
    /// with an empty `match`.
    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        let this = self.project();
        match this.inner.project() {
            ErrorBodyInnerProj::Passthrough { inner } => inner
                .poll_frame(cx)
                .map_ok(|frame| {
                    frame.map_data(|mut chunk| chunk.copy_to_bytes(chunk.remaining()))
                })
                .map_err(|err| err.into()),
            ErrorBodyInnerProj::Rendered { inner } => {
                inner.poll_frame(cx).map_err(|never| match never {})
            }
        }
    }

    /// Reports end-of-stream according to the active source.
    fn is_end_stream(&self) -> bool {
        match self.inner {
            ErrorBodyInner::Passthrough { ref inner } => inner.is_end_stream(),
            ErrorBodyInner::Rendered { ref inner } => inner.is_end_stream(),
        }
    }

    /// Reports the size hint of the active source.
    fn size_hint(&self) -> SizeHint {
        match self.inner {
            ErrorBodyInner::Passthrough { ref inner } => inner.size_hint(),
            ErrorBodyInner::Rendered { ref inner } => inner.size_hint(),
        }
    }
}
impl<B> Default for ErrorBody<B>
where
B: Default,
{
fn default() -> Self {
Self {
inner: ErrorBodyInner::Rendered {
inner: Default::default(),
},
}
}
}
#[cfg(test)]
mod test {
    use crate::tower_http::ErrorBody;
    use bytes::Bytes;
    use http_body::Body;
    use http_body_util::Full;

    /// Asserts that `body` reports an exact size hint of `expected` bytes.
    fn assert_exact_size(body: &ErrorBody<Full<Bytes>>, expected: u64) {
        let hint = body.size_hint();
        assert_eq!(hint.exact(), Some(expected));
    }

    #[test]
    fn test_error_body_size_hint_passthrough() {
        let wrapped = Full::new(Bytes::from_static(b"test data"));
        let body: ErrorBody<Full<Bytes>> = ErrorBody::passthrough(wrapped);
        assert_exact_size(&body, 9);
    }

    #[test]
    fn test_error_body_size_hint_error() {
        let message = Bytes::from_static(b"error message");
        let body: ErrorBody<Full<Bytes>> = ErrorBody::error(message);
        assert_exact_size(&body, 13);
    }

    #[test]
    fn test_error_body_default() {
        let body: ErrorBody<Full<Bytes>> = ErrorBody::default();
        assert_exact_size(&body, 0);
    }
}