use crate::{Body, Headers, body::BodyType, h3::Frame};
use futures_lite::{AsyncRead, ready};
use std::{io, pin::Pin, task::Poll};
/// Adapts a [`BodyType`] into an [`AsyncRead`] whose output is the body wrapped in
/// HTTP/3 DATA frames.
#[derive(Debug)]
pub struct H3Body {
    /// The underlying body whose bytes are being framed.
    body: BodyType,
    /// Whether the single up-front DATA frame header has already been emitted. Only
    /// consulted for static bodies and streaming bodies with a known length; bodies of
    /// unknown length emit a fresh header on every read.
    header_written: bool,
}
impl From<BodyType> for H3Body {
fn from(body: BodyType) -> Self {
Self {
body,
header_written: false,
}
}
}
impl H3Body {
    /// Wraps a crate [`Body`] so it can be read out as HTTP/3 DATA frames.
    pub(crate) fn new(body: Body) -> Self {
        Self::from(body.0)
    }

    /// Returns trailers from the underlying streaming source, once it is exhausted.
    ///
    /// Yields `None` for non-streaming bodies and for streaming bodies that have not
    /// yet reached end-of-stream. Otherwise returns whatever the underlying source's
    /// `trailers()` reports.
    pub fn trailers(&mut self) -> Option<Headers> {
        let BodyType::Streaming {
            async_read,
            done: true,
            ..
        } = &mut self.body
        else {
            return None;
        };
        async_read.get_mut().as_mut().trailers()
    }
}
impl AsyncRead for H3Body {
    /// Reads the next chunk of the body into `buf`, wrapped in HTTP/3 DATA framing.
    ///
    /// - `Empty` bodies produce nothing.
    /// - `Static` bodies emit one DATA frame header (sized to the content remaining at
    ///   that point) on the first non-empty read, then the content across as many reads
    ///   as needed.
    /// - `Streaming` bodies with a known length emit one DATA frame header for the full
    ///   declared length on the first read, then the payload.
    /// - `Streaming` bodies of unknown length emit one DATA frame per successful read of
    ///   the underlying source.
    ///
    /// # Errors
    ///
    /// - [`io::ErrorKind::WriteZero`] if `buf` cannot hold a pending frame header, or,
    ///   for unknown-length bodies, a header plus at least one payload byte.
    /// - [`io::ErrorKind::UnexpectedEof`] if a known-length streaming source ends before
    ///   producing its declared number of bytes.
    /// - Any error returned by the underlying streaming source.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        match &mut this.body {
            BodyType::Empty => Poll::Ready(Ok(0)),
            BodyType::Static { content, cursor } => {
                let remaining = content.len() - *cursor;
                if remaining == 0 {
                    return Poll::Ready(Ok(0));
                }
                let mut written = 0;
                if !this.header_written {
                    let frame = Frame::Data(remaining as u64);
                    written += frame.encode(buf).ok_or_else(|| {
                        io::Error::new(
                            io::ErrorKind::WriteZero,
                            "buffer too small for frame header",
                        )
                    })?;
                    this.header_written = true;
                }
                let bytes = remaining.min(buf.len() - written);
                buf[written..written + bytes].copy_from_slice(&content[*cursor..*cursor + bytes]);
                *cursor += bytes;
                Poll::Ready(Ok(written + bytes))
            }
            BodyType::Streaming {
                async_read,
                len: Some(len),
                done,
                progress,
            } => {
                if *done {
                    return Poll::Ready(Ok(0));
                }
                let header_len = if this.header_written {
                    0
                } else {
                    Frame::Data(*len).encoded_len()
                };
                // A buffer shorter than the pending header used to underflow here.
                // Report it the same way the static branch does.
                let Some(available) = buf.len().checked_sub(header_len) else {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "buffer too small for frame header",
                    )));
                };
                let remaining = *len - *progress;
                if remaining > 0 && available == 0 {
                    // No room for payload. Polling the source with an empty slice
                    // yields Ok(0), which would be mistaken for end-of-stream and
                    // truncate the body. Emit only the pending header, if any.
                    if header_len > 0 {
                        Frame::Data(*len).encode(buf).ok_or_else(|| {
                            io::Error::new(
                                io::ErrorKind::WriteZero,
                                "buffer too small for frame header",
                            )
                        })?;
                        this.header_written = true;
                    }
                    return Poll::Ready(Ok(header_len));
                }
                let max_bytes = usize::try_from(remaining).map_or(available, |r| r.min(available));
                let bytes = ready!(
                    async_read
                        .get_mut()
                        .as_mut()
                        .poll_read(cx, &mut buf[header_len..header_len + max_bytes])
                )?;
                if bytes == 0 && remaining > 0 {
                    // The DATA frame header promises `len` bytes. Ending early would
                    // produce a truncated, malformed HTTP/3 message, so surface it.
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        "streaming body ended before its declared length",
                    )));
                }
                if !this.header_written {
                    // The header is written only after the read succeeds, so a Pending
                    // or failed read leaves the header still pending for the next call.
                    Frame::Data(*len).encode(buf).ok_or_else(|| {
                        io::Error::new(
                            io::ErrorKind::WriteZero,
                            "buffer too small for frame header",
                        )
                    })?;
                    this.header_written = true;
                }
                if bytes == 0 {
                    *done = true;
                } else {
                    *progress += bytes as u64;
                }
                Poll::Ready(Ok(header_len + bytes))
            }
            BodyType::Streaming {
                async_read,
                len: None,
                done,
                progress,
            } => {
                if *done {
                    return Poll::Ready(Ok(0));
                }
                // Reserve room for the largest header this read could need. The payload
                // is then shifted left if the actual header turns out shorter.
                let reserved = Frame::Data(buf.len() as u64).encoded_len();
                if buf.len() <= reserved {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "buffer too small for DATA frame",
                    )));
                }
                let bytes = ready!(
                    async_read
                        .get_mut()
                        .as_mut()
                        .poll_read(cx, &mut buf[reserved..])
                )?;
                if bytes == 0 {
                    *done = true;
                    return Poll::Ready(Ok(0));
                }
                *progress += bytes as u64;
                let frame = Frame::Data(bytes as u64);
                let header_len = frame
                    .encode(buf)
                    .expect("header for a payload shorter than buf fits in the reserved prefix");
                if header_len < reserved {
                    buf.copy_within(reserved..reserved + bytes, header_len);
                }
                Poll::Ready(Ok(header_len + bytes))
            }
        }
    }
}