use std::{io, fmt};
use std::task::{Context, Poll};
use std::pin::Pin;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, ReadBuf};
/// A body: an optional underlying reader plus size and chunking bookkeeping.
///
/// Construct with `with_sized`/`with_unsized`, or use `Body::default()` for an
/// empty body.
#[derive(Debug)]
pub struct Body<'r> {
/// Size in bytes, if known: either preset via `with_sized`, computed lazily
/// by `size()` for seekable bodies, or decremented by `to_bytes()` after a read.
size: Option<usize>,
/// The underlying data source (or lack of one).
inner: Inner<'r>,
/// Maximum chunk size; `DEFAULT_MAX_CHUNK` unless changed via
/// `set_max_chunk_size`. NOTE(review): the consumer of this value is not
/// visible in this file.
max_chunk: usize,
}
/// A single trait combining [`AsyncRead`] and [`AsyncSeek`], so that a reader
/// supporting both can be stored as one trait object (`dyn AsyncReadSeek`).
pub trait AsyncReadSeek: AsyncRead + AsyncSeek {}

/// Blanket implementation: anything that can both read and seek asynchronously
/// is automatically an `AsyncReadSeek`.
impl<T> AsyncReadSeek for T where T: AsyncRead + AsyncSeek {}
/// A pinned, boxed reader that can also seek, which lets `Body::size()`
/// measure it.
type SizedBody<'r> = Pin<Box<dyn AsyncReadSeek + Send + 'r>>;
/// A pinned, boxed reader that cannot seek, so its size cannot be measured.
type UnsizedBody<'r> = Pin<Box<dyn AsyncRead + Send + 'r>>;
/// The data source behind a `Body`.
enum Inner<'r> {
/// A seekable source; reads delegate to it.
Seekable(SizedBody<'r>),
/// A non-seekable source; reads delegate to it.
Unsized(UnsizedBody<'r>),
/// A seekable source kept after `Body::strip()` (e.g. so its size can still
/// be queried); reads yield no data.
Phantom(SizedBody<'r>),
/// No source at all; reads yield no data.
None,
}
impl Default for Body<'_> {
    /// Builds an empty body: no data source, a known size of zero bytes, and
    /// the default maximum chunk size.
    fn default() -> Self {
        let max_chunk = Self::DEFAULT_MAX_CHUNK;
        Self { inner: Inner::None, size: Some(0), max_chunk }
    }
}
impl<'r> Body<'r> {
    /// The maximum chunk size used unless overridden: 4096 bytes.
    pub const DEFAULT_MAX_CHUNK: usize = 4096;

    /// Creates a body backed by a seekable reader.
    ///
    /// When `preset_size` is `Some`, [`Body::size()`] returns it without
    /// touching the reader; when `None`, the size is computed lazily by
    /// seeking.
    pub(crate) fn with_sized<T>(body: T, preset_size: Option<usize>) -> Self
        where T: AsyncReadSeek + Send + 'r
    {
        Body {
            size: preset_size,
            inner: Inner::Seekable(Box::pin(body)),
            max_chunk: Body::DEFAULT_MAX_CHUNK,
        }
    }

    /// Creates a body backed by a non-seekable reader whose size is unknown.
    pub(crate) fn with_unsized<T>(body: T) -> Self
        where T: AsyncRead + Send + 'r
    {
        Body {
            size: None,
            inner: Inner::Unsized(Box::pin(body)),
            max_chunk: Body::DEFAULT_MAX_CHUNK,
        }
    }

    /// Sets the maximum chunk size reported by [`Body::max_chunk_size()`].
    pub(crate) fn set_max_chunk_size(&mut self, max_chunk: usize) {
        self.max_chunk = max_chunk;
    }

    /// Strips the body's data while keeping what is needed to report its size.
    ///
    /// Seekable (or already stripped) bodies become `Inner::Phantom`, keeping
    /// their reader, preset/computed size and chunk size; reading a phantom
    /// body yields no data. Unsized and empty bodies are replaced with
    /// `Body::default()`, which also resets the chunk size to the default.
    pub(crate) fn strip(&mut self) {
        let body = std::mem::take(self);
        *self = match body.inner {
            Inner::Seekable(b) | Inner::Phantom(b) => Body {
                size: body.size,
                inner: Inner::Phantom(b),
                max_chunk: body.max_chunk,
            },
            Inner::Unsized(_) | Inner::None => Body::default()
        };
    }

    /// Returns `true` if the body has no data source at all.
    ///
    /// A stripped (phantom) body still has a source and is not `none`.
    #[inline(always)]
    pub fn is_none(&self) -> bool {
        matches!(self.inner, Inner::None)
    }

    /// Returns `true` if the body has a data source; the inverse of
    /// [`Body::is_none()`].
    #[inline(always)]
    pub fn is_some(&self) -> bool {
        !self.is_none()
    }

    /// Returns the size currently recorded for the body, without any I/O.
    ///
    /// This is the preset size, or a size previously computed by
    /// [`Body::size()`] / adjusted by [`Body::to_bytes()`].
    pub fn preset_size(&self) -> Option<usize> {
        self.size
    }

    /// Returns the configured maximum chunk size.
    pub fn max_chunk_size(&self) -> usize {
        self.max_chunk
    }

    /// Returns the body's size in bytes, computing and caching it if needed.
    ///
    /// If a size is already recorded it is returned directly. Otherwise, for
    /// seekable (or phantom) bodies, the number of bytes between the current
    /// position and the end of the stream is measured by seeking to the end
    /// and back, then cached. Returns `None` for unsized/empty bodies, if any
    /// seek fails, or if the remaining length does not fit in a `usize`.
    pub async fn size(&mut self) -> Option<usize> {
        if let Some(size) = self.size {
            return Some(size);
        }

        if let Inner::Seekable(ref mut body) | Inner::Phantom(ref mut body) = self.inner {
            // `TryFrom` is not in the 2018-edition prelude; import it locally.
            use std::convert::TryFrom;

            let pos = body.seek(io::SeekFrom::Current(0)).await.ok()?;
            let end = body.seek(io::SeekFrom::End(0)).await.ok()?;
            body.seek(io::SeekFrom::Start(pos)).await.ok()?;

            // Seeking past the end of a stream is permitted, so `pos > end` is
            // possible: treat it as zero bytes remaining instead of
            // underflowing. A remaining length that does not fit in `usize`
            // (e.g. > 4 GiB on a 32-bit target) is reported as unknown rather
            // than silently truncated.
            let size = usize::try_from(end.saturating_sub(pos)).ok()?;
            self.size = Some(size);
            return Some(size);
        }

        None
    }

    /// Moves the body out, leaving `Body::default()` in its place.
    #[inline(always)]
    pub fn take(&mut self) -> Self {
        std::mem::take(self)
    }

    /// Reads the rest of the body into a byte vector.
    ///
    /// Phantom and empty bodies produce an empty vector. A recorded size is
    /// decreased by the number of bytes read (never below zero).
    ///
    /// # Errors
    ///
    /// Propagates (after logging) any I/O error from the underlying reader.
    pub async fn to_bytes(&mut self) -> io::Result<Vec<u8>> {
        let mut vec = Vec::new();
        let n = match self.read_to_end(&mut vec).await {
            Ok(n) => n,
            Err(e) => {
                error_!("Error reading body: {:?}", e);
                return Err(e);
            }
        };

        // The bytes just consumed are no longer part of the body.
        if let Some(ref mut size) = self.size {
            *size = size.saturating_sub(n);
        }

        Ok(vec)
    }

    /// Reads the rest of the body and decodes it as UTF-8.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from [`Body::to_bytes()`]; invalid UTF-8 is
    /// logged and returned as an `io::ErrorKind::InvalidData` error.
    pub async fn to_string(&mut self) -> io::Result<String> {
        String::from_utf8(self.to_bytes().await?).map_err(|e| {
            error_!("Body is invalid UTF-8: {}", e);
            io::Error::new(io::ErrorKind::InvalidData, e)
        })
    }
}
impl AsyncRead for Body<'_> {
    /// Reads from the body's data source, if it has a readable one.
    ///
    /// Seekable and unsized bodies forward the call to their inner reader.
    /// Phantom (stripped) and empty bodies complete immediately without
    /// writing to `buf`, which under the `AsyncRead` contract signals
    /// end-of-stream.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        match &mut self.inner {
            Inner::Seekable(seekable) => seekable.as_mut().poll_read(cx, buf),
            Inner::Unsized(streaming) => streaming.as_mut().poll_read(cx, buf),
            Inner::Phantom(_) | Inner::None => Poll::Ready(Ok(())),
        }
    }
}
/// Formats only the variant's name; the boxed readers themselves do not
/// implement `Debug`.
impl fmt::Debug for Inner<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// NOTE(review): `"...".fmt(f)` uses whichever `str` formatting trait
// (`Debug` -> quoted, `Display` -> unquoted) is in scope here; the import
// that decides this is not visible in this chunk. `f.write_str(..)` or
// `fmt::Debug::fmt(..)` would make the intent explicit.
match self {
Inner::Seekable(_) => "seekable".fmt(f),
Inner::Unsized(_) => "unsized".fmt(f),
Inner::Phantom(_) => "phantom".fmt(f),
Inner::None => "none".fmt(f),
}
}
}