use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::{Buf, Bytes};
use h3::server::RequestStream;
/// Errors that can occur while reading an HTTP/3 request body through
/// [`QuicIncomingBody`].
#[derive(thiserror::Error, Debug)]
pub enum H3BodyError {
/// The underlying h3 stream reported an error while receiving data or trailers.
#[error("h3 error: {0}")]
StreamError(#[from] h3::error::StreamError),
/// A data chunk was received after the data phase had already ended.
#[error("unexpected data after trailers")]
DataAfterTrailers,
/// A received chunk was larger than what remained of the size hint given to the body.
#[error("the given buffer size hint was exceeded")]
BufferExceeded,
}
/// The phase of the body that is currently being read.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
/// Data chunks are being read. The payload is the number of bytes that may
/// still be received before the size hint is exceeded, or `None` when no
/// size hint was given and the amount of data is not tracked.
Data(Option<u64>),
/// The stream reported the end of the data; trailers have not been read yet.
Trailers,
/// The body is finished (successfully or with an error); nothing more is yielded.
Done,
}
/// An [`http_body::Body`] implementation that reads a request body from an
/// h3 [`RequestStream`].
pub struct QuicIncomingBody<S> {
/// The h3 request stream that data chunks and trailers are read from.
stream: RequestStream<S, Bytes>,
/// The phase of the body that is currently being read.
state: State,
}
impl<S> QuicIncomingBody<S> {
    /// Creates a body that reads from `stream`, starting in the data phase.
    ///
    /// `size_hint` is the number of body bytes expected from the stream. When
    /// it is `Some`, receiving a chunk larger than the remaining amount makes
    /// the body fail with [`H3BodyError::BufferExceeded`]. When it is `None`,
    /// the amount of received data is not tracked.
    pub fn new(stream: RequestStream<S, Bytes>, size_hint: Option<u64>) -> Self {
        let state = State::Data(size_hint);
        Self { stream, state }
    }
}
impl<S: h3::quic::RecvStream> http_body::Body for QuicIncomingBody<S> {
    type Data = Bytes;
    type Error = H3BodyError;

    /// Polls the underlying h3 stream for the next body frame.
    ///
    /// The body moves through three phases:
    ///
    /// 1. **Data**: chunks from the stream are yielded as data frames. If a
    ///    size hint was given, each chunk is subtracted from the remaining
    ///    budget; a chunk larger than the budget ends the body with
    ///    [`H3BodyError::BufferExceeded`].
    /// 2. **Trailers**: once the stream reports the end of the data, the
    ///    trailers are requested and, if present, yielded as a trailers frame.
    /// 3. **Done**: every further poll returns `None`.
    ///
    /// Any error ends the body; subsequent polls return `None`.
    fn poll_frame(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
        let QuicIncomingBody { stream, state } = self.as_mut().get_mut();

        if *state == State::Done {
            return Poll::Ready(None);
        }

        if let State::Data(remaining) = state {
            match stream.poll_recv_data(cx) {
                Poll::Ready(Ok(Some(mut buf))) => {
                    let buf_size = buf.remaining() as u64;

                    if let Some(remaining) = remaining {
                        // Enforce the size hint: a chunk that does not fit
                        // into the remaining budget terminates the body.
                        if buf_size > *remaining {
                            *state = State::Done;
                            return Poll::Ready(Some(Err(H3BodyError::BufferExceeded)));
                        }

                        *remaining -= buf_size;
                    }

                    return Poll::Ready(Some(Ok(http_body::Frame::data(
                        buf.copy_to_bytes(buf_size as usize),
                    ))));
                }
                Poll::Ready(Ok(None)) => {
                    // End of data: fall through to the trailers phase below.
                    *state = State::Trailers;
                }
                Poll::Ready(Err(err)) => {
                    *state = State::Done;
                    return Poll::Ready(Some(Err(err.into())));
                }
                Poll::Pending => {
                    return Poll::Pending;
                }
            }
        }

        // Trailers phase. The stream is polled for data once more so that a
        // chunk arriving after the end of the data is reported as an error
        // instead of being silently ignored.
        let resp = match stream.poll_recv_data(cx) {
            Poll::Ready(Ok(None)) => match std::pin::pin!(stream.recv_trailers()).poll(cx) {
                Poll::Ready(Ok(Some(trailers))) => {
                    Poll::Ready(Some(Ok(http_body::Frame::trailers(trailers))))
                }
                Poll::Pending => {
                    // The trailers future is created and dropped within this
                    // call, so it cannot be resumed on a later poll. If it is
                    // not ready immediately, the body ends without trailers.
                    #[cfg(feature = "tracing")]
                    tracing::warn!("recv_trailers is pending");
                    Poll::Ready(None)
                }
                Poll::Ready(Ok(None)) => Poll::Ready(None),
                Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err.into()))),
            },
            Poll::Ready(Ok(Some(_))) => Poll::Ready(Some(Err(H3BodyError::DataAfterTrailers))),
            Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err.into()))),
            // Stay in the trailers phase and retry on the next wake-up.
            Poll::Pending => return Poll::Pending,
        };

        // Every ready outcome of the trailers phase finishes the body.
        *state = State::Done;
        resp
    }

    /// Reports how many data bytes are still expected.
    ///
    /// While in the data phase with a size hint, the remaining budget is
    /// reported as the exact size; without a size hint nothing is known. Once
    /// the data phase is over, no further data bytes will be yielded.
    fn size_hint(&self) -> http_body::SizeHint {
        match self.state {
            State::Data(Some(remaining)) => http_body::SizeHint::with_exact(remaining),
            State::Data(None) => http_body::SizeHint::default(),
            State::Trailers | State::Done => http_body::SizeHint::with_exact(0),
        }
    }

    /// Returns `true` only once the body is completely finished.
    ///
    /// `true` promises that `poll_frame` will return `None`. That only holds
    /// in the `Done` state: with an exhausted size hint or in the trailers
    /// phase, `poll_frame` can still yield a trailers frame or an error, so
    /// those states must report `false` to avoid consumers dropping trailers.
    fn is_end_stream(&self) -> bool {
        matches!(self.state, State::Done)
    }
}