use std::pin::Pin;
use std::task::{Context, Poll};
use std::path::Path;
use std::io::{self, Cursor};
use tokio::fs::File;
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, ReadBuf, Take};
use futures::stream::Stream;
use futures::ready;
use crate::http::hyper;
use crate::ext::{PollExt, Chain};
use crate::data::{Capped, N};
/// A size-limited, asynchronously readable stream of request data.
///
/// All reads are delegated to `chain`, which combines an in-memory buffer
/// with a [`StreamReader`] and is wrapped in a `Take` so that no more than
/// the configured limit of bytes is ever produced.
pub struct DataStream<'r> {
/// In-memory cursor chained with the underlying reader, capped by `Take`.
pub(crate) chain: Take<Chain<Cursor<Vec<u8>>, StreamReader<'r>>>,
}
/// Adapter that exposes a chunked `StreamKind` source through `AsyncRead`.
///
/// Chunks pulled from `inner` are held in `state` while they are copied out
/// to the caller's read buffer.
pub struct StreamReader<'r> {
/// Where the reader currently is: awaiting a chunk, draining one, or done.
state: State,
/// The underlying source of byte chunks.
inner: StreamKind<'r>,
}
/// Read progress of a `StreamReader`.
enum State {
/// No chunk is buffered; the inner stream must be polled for the next one.
Pending,
/// A chunk has been received and is being copied out through the cursor.
Partial(Cursor<hyper::Bytes>),
/// The inner stream has ended; every further read reports zero bytes.
Done,
}
/// The concrete source of chunks backing a `StreamReader`.
enum StreamKind<'r> {
/// A source that yields no chunks at all.
Empty,
/// A mutably borrowed hyper body.
Body(&'r mut hyper::Body),
/// A single field of a multipart payload, as produced by `multer`.
Multipart(multer::Field<'r>)
}
impl<'r> DataStream<'r> {
    /// Builds a data stream from an already-buffered prefix and a reader.
    ///
    /// A cursor over `buf` is chained with `stream`, and the combination is
    /// capped so that at most `limit` bytes can be read from the result.
    pub(crate) fn new(buf: Vec<u8>, stream: StreamReader<'r>, limit: u64) -> Self {
        let chain = Chain::new(Cursor::new(buf), stream).take(limit);
        Self { chain }
    }

    /// Reports whether the source held more data than the limit allowed.
    ///
    /// When the remaining limit is non-zero this returns `Ok(false)` without
    /// touching the stream. Otherwise the limit is reset to a single byte and
    /// a one-byte read is attempted; the result is `true` exactly when that
    /// read yields a byte. The probed byte is consumed and discarded.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error raised by the probing read.
    async fn limit_exceeded(&mut self) -> io::Result<bool> {
        // Only reached once the limit has been used up, hence out of line
        // and marked cold.
        #[cold]
        async fn _limit_exceeded(stream: &mut DataStream<'_>) -> io::Result<bool> {
            stream.chain.set_limit(1);
            let mut buf = [0u8; 1];
            Ok(stream.read(&mut buf).await? != 0)
        }

        Ok(self.chain.limit() == 0 && _limit_exceeded(self).await?)
    }

    /// Returns a size hint: the smaller of the in-memory buffer's length and
    /// the remaining read limit.
    ///
    /// The buffer length is the full length of the vector held by the cursor;
    /// the cursor's current position is not taken into account.
    pub fn hint(&self) -> usize {
        let buf_len = self.chain.get_ref().get_ref().0.get_ref().len();

        // Compare in `u64` so that a limit larger than `usize::MAX` (possible
        // where `usize` is narrower than 64 bits) is not truncated before the
        // comparison. The minimum never exceeds `buf_len`, so converting it
        // back to `usize` is lossless.
        std::cmp::min(buf_len as u64, self.chain.limit()) as usize
    }

    /// Copies the stream into `writer` until end-of-stream or the limit.
    ///
    /// Returns an `N` holding the number of bytes written and a `complete`
    /// flag that is `false` when the limit was hit while data remained.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from copying or from the limit probe.
    #[inline(always)]
    pub async fn stream_to<W>(mut self, mut writer: W) -> io::Result<N>
        where W: AsyncWrite + Unpin
    {
        let written = tokio::io::copy(&mut self, &mut writer).await?;
        Ok(N { written, complete: !self.limit_exceeded().await? })
    }

    /// Copies the stream into `writer` and returns only the byte count.
    ///
    /// Unlike [`DataStream::stream_to`], no probe for data beyond the limit
    /// is performed.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from copying.
    #[inline(always)]
    pub async fn stream_precise_to<W>(mut self, mut writer: W) -> io::Result<u64>
        where W: AsyncWrite + Unpin
    {
        tokio::io::copy(&mut self, &mut writer).await
    }

    /// Reads the stream into a byte vector pre-sized with [`DataStream::hint`].
    ///
    /// # Errors
    ///
    /// Returns any I/O error encountered while streaming.
    pub async fn into_bytes(self) -> io::Result<Capped<Vec<u8>>> {
        let mut vec = Vec::with_capacity(self.hint());
        let n = self.stream_to(&mut vec).await?;
        Ok(Capped { value: vec, n })
    }

    /// Reads the stream into a `String` pre-sized with [`DataStream::hint`].
    ///
    /// # Errors
    ///
    /// Returns any I/O error from `read_to_string` (which includes the data
    /// not being valid UTF-8) or from the limit probe.
    pub async fn into_string(mut self) -> io::Result<Capped<String>> {
        let mut string = String::with_capacity(self.hint());
        let written = self.read_to_string(&mut string).await?;
        let n = N { written: written as u64, complete: !self.limit_exceeded().await? };
        Ok(Capped { value: string, n })
    }

    /// Streams the data into a file created at `path`.
    ///
    /// Writes go through a `BufWriter` wrapped around the file; the file
    /// handle itself is returned together with the transfer summary.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from creating the file or from streaming.
    pub async fn into_file<P: AsRef<Path>>(self, path: P) -> io::Result<Capped<File>> {
        let mut file = File::create(path).await?;
        let n = self.stream_to(&mut tokio::io::BufWriter::new(&mut file)).await?;
        Ok(Capped { value: file, n })
    }
}
impl StreamReader<'_> {
pub fn empty() -> Self {
Self { inner: StreamKind::Empty, state: State::Done }
}
}
impl<'r> From<&'r mut hyper::Body> for StreamReader<'r> {
fn from(body: &'r mut hyper::Body) -> Self {
Self { inner: StreamKind::Body(body), state: State::Pending }
}
}
impl<'r> From<multer::Field<'r>> for StreamReader<'r> {
fn from(field: multer::Field<'r>) -> Self {
Self { inner: StreamKind::Multipart(field), state: State::Pending }
}
}
impl AsyncRead for DataStream<'_> {
    /// Forwards the read to the limited, chained reader held in `chain`.
    #[inline(always)]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        let chain = Pin::new(&mut this.chain);
        chain.poll_read(cx, buf)
    }
}
impl Stream for StreamKind<'_> {
    type Item = io::Result<hyper::Bytes>;

    /// Polls the wrapped source for its next chunk.
    ///
    /// Errors from the body or the multipart field are wrapped in an
    /// `io::Error` of kind `Other`; the empty source ends immediately.
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        match this {
            StreamKind::Empty => Poll::Ready(None),
            StreamKind::Body(body) => {
                let polled = Pin::new(body).poll_next(cx);
                polled.map_err_ext(|err| io::Error::new(io::ErrorKind::Other, err))
            }
            StreamKind::Multipart(field) => {
                let polled = Pin::new(field).poll_next(cx);
                polled.map_err_ext(|err| io::Error::new(io::ErrorKind::Other, err))
            }
        }
    }

    /// Reports the wrapped source's own size hint, or exactly zero items for
    /// the empty source.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            StreamKind::Empty => (0, Some(0)),
            StreamKind::Body(body) => body.size_hint(),
            StreamKind::Multipart(field) => field.size_hint(),
        }
    }
}
impl AsyncRead for StreamReader<'_> {
    /// Copies bytes from the underlying chunk stream into `buf`.
    ///
    /// The reader moves between three states: `Pending` polls the inner
    /// stream for the next chunk, `Partial` copies the buffered chunk out,
    /// and `Done` reports end-of-stream by leaving `buf` untouched. Errors
    /// from the inner stream are returned as-is. A read into a buffer with no
    /// remaining capacity returns immediately without touching the stream.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // A zero-capacity read can never copy anything. Without this guard
        // the `Partial` arm below would see "no bytes copied", conclude that
        // the current chunk is exhausted and drop its unread bytes, and the
        // loop would then keep pulling and discarding further chunks.
        if buf.remaining() == 0 {
            return Poll::Ready(Ok(()));
        }

        loop {
            self.state = match self.state {
                State::Pending => {
                    match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
                        Some(Err(e)) => return Poll::Ready(Err(e)),
                        Some(Ok(bytes)) => State::Partial(Cursor::new(bytes)),
                        None => State::Done,
                    }
                },
                State::Partial(ref mut cursor) => {
                    let rem = buf.remaining();
                    match ready!(Pin::new(cursor).poll_read(cx, buf)) {
                        // `buf` has room (checked above), so an unchanged
                        // `remaining()` means the chunk is fully drained:
                        // go fetch the next one.
                        Ok(()) if rem == buf.remaining() => State::Pending,
                        result => return Poll::Ready(result),
                    }
                }
                State::Done => return Poll::Ready(Ok(())),
            };
        }
    }
}