use std::io::{self, Cursor};
use std::path::Path;
use std::pin::Pin;
use std::task::{Context, Poll};
use futures::ready;
use futures::stream::Stream;
use hyper::body::{Body, Bytes, Incoming as HyperBody};
use tokio::fs::File;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf, Take};
use tokio_util::io::StreamReader;
use crate::data::transform::Transform;
use crate::data::{Capped, N};
use crate::util::Chain;
use super::peekable::Peekable;
use super::transform::TransformBuf;
/// A stream of request body data.
///
/// The stream is either the limited base reader itself (`Base`) or a transform
/// layer wrapping another `DataStream` (`Transform`). `DataStream::new` builds
/// the layers with the base reader innermost and each transformer wrapped
/// around the stream built so far.
// NOTE(review): presumably the lint fires because `Base` stores the whole
// reader inline while `Transform` only holds boxed data; boxing `Base` would
// add an allocation per stream. Confirm before changing.
#[allow(clippy::large_enum_variant)]
#[non_exhaustive]
pub enum DataStream<'r> {
    /// The limited reader over buffered bytes followed by the raw body.
    #[doc(hidden)]
    Base(BaseReader<'r>),
    /// A transformer applied on top of another `DataStream`.
    #[doc(hidden)]
    Transform(TransformReader<'r>),
}
/// One transform layer of a `DataStream`: reads from `stream` and passes the
/// bytes through `transformer`.
pub struct TransformReader<'r> {
    // The transform applied to bytes read from `stream`.
    transformer: Pin<Box<dyn Transform + Send + Sync + 'r>>,
    // The wrapped (next-inner) stream. `DataStream` contains `TransformReader`,
    // which contains `DataStream`, so this recursive field must be boxed.
    stream: Pin<Box<DataStream<'r>>>,
    // Set once a read of `stream` fills no bytes; from then on reads are
    // answered by the transformer's `poll_finish`.
    inner_done: bool,
}
/// The innermost reader of every `DataStream`: an in-memory byte buffer
/// combined (via `crate::util::Chain`) with the raw body reader, capped by
/// `Take` at the data limit.
///
/// `DataStream::new` fills the buffer from a `Peekable`, presumably with the
/// bytes already peeked from the raw reader; TODO confirm.
pub type BaseReader<'r> = Take<Chain<Cursor<Vec<u8>>, RawReader<'r>>>;
/// Adapts a `RawStream` of `Bytes` chunks into an `AsyncRead` via `StreamReader`.
pub type RawReader<'r> = StreamReader<RawStream<'r>, Bytes>;
/// The raw source of body bytes, consumed as a `Stream` of `io::Result<Bytes>`.
// NOTE(review): presumably allowed because the wrapped third-party body types
// differ considerably in size; confirm before boxing any variant.
#[allow(clippy::large_enum_variant)]
pub enum RawStream<'r> {
    /// No data: the stream ends on the first poll.
    Empty,
    /// A hyper request body.
    Body(HyperBody),
    /// An HTTP/3 (QUIC) request body, only with the `http3-preview` feature.
    #[cfg(feature = "http3-preview")]
    H3Body(crate::listener::Cancellable<crate::listener::quic::QuicRx>),
    /// A single field of a multipart form.
    Multipart(multer::Field<'r>),
}
impl<'r> TransformReader<'r> {
fn base_mut(&mut self) -> &mut BaseReader<'r> {
match self.stream.as_mut().get_mut() {
DataStream::Base(base) => base,
DataStream::Transform(inner) => inner.base_mut(),
}
}
fn base(&self) -> &BaseReader<'r> {
match self.stream.as_ref().get_ref() {
DataStream::Base(base) => base,
DataStream::Transform(inner) => inner.base(),
}
}
}
impl<'r> DataStream<'r> {
    /// Assembles a data stream from its parts.
    ///
    /// The bytes held in the peek `buffer` are chained in front of the raw
    /// `reader`, and the combined reader is capped at `limit` bytes. Each
    /// transformer in `transformers` then wraps the stream built so far, in
    /// iteration order, so the last transformer becomes the outermost layer.
    pub(crate) fn new(
        transformers: Vec<Pin<Box<dyn Transform + Send + Sync + 'r>>>,
        Peekable { buffer, reader, .. }: Peekable<512, RawReader<'r>>,
        limit: u64,
    ) -> Self {
        let base = Chain::new(Cursor::new(buffer), reader).take(limit);
        transformers
            .into_iter()
            .fold(DataStream::Base(base), |inner, transformer| {
                DataStream::Transform(TransformReader {
                    transformer,
                    stream: Box::pin(inner),
                    inner_done: false,
                })
            })
    }

    /// Mutably borrows the innermost base reader, descending through every
    /// transform layer.
    fn base_mut(&mut self) -> &mut BaseReader<'r> {
        match self {
            DataStream::Transform(layer) => layer.base_mut(),
            DataStream::Base(reader) => reader,
        }
    }

    /// Shared-borrow counterpart of `base_mut`.
    fn base(&self) -> &BaseReader<'r> {
        match self {
            DataStream::Transform(layer) => layer.base(),
            DataStream::Base(reader) => reader,
        }
    }

    /// Reports whether the underlying data was longer than the read limit.
    ///
    /// Only when the limit has been fully consumed (`limit() == 0`) does this
    /// probe further: it temporarily raises the limit to 1 and tries to read a
    /// single byte. If a byte arrives, the data exceeded the limit. That byte is
    /// consumed and discarded, and the limit is set back to 0. If the probe
    /// read fails, the error is returned before the limit is reset.
    async fn limit_exceeded(&mut self) -> io::Result<bool> {
        #[cold]
        async fn probe_past_limit(base: &mut BaseReader<'_>) -> io::Result<bool> {
            base.set_limit(1);
            let mut probe = [0u8; 1];
            let read = base.read(&mut probe).await?;
            base.set_limit(0);
            Ok(read != 0)
        }

        let base = self.base_mut();
        if base.limit() != 0 {
            return Ok(false);
        }

        probe_past_limit(base).await
    }

    /// Number of bytes still waiting in the in-memory peek buffer, capped at
    /// the remaining read limit.
    ///
    /// Returns 0 when the chained reader no longer exposes the buffer cursor
    /// (presumably once it has been drained).
    pub fn hint(&self) -> usize {
        let base = self.base();
        let (Some(cursor), _) = base.get_ref().get_ref() else {
            return 0;
        };

        let len = cursor.get_ref().len() as u64;
        let unread = len - cursor.position().min(len);
        unread.min(base.limit()) as usize
    }

    /// Copies all readable data into `writer`.
    ///
    /// Returns an `N` holding the number of bytes written and whether the data
    /// was complete, i.e. whether the read limit was *not* exceeded.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from reading this stream or writing to `writer`.
    #[inline(always)]
    pub async fn stream_to<W>(mut self, mut writer: W) -> io::Result<N>
    where
        W: AsyncWrite + Unpin,
    {
        let written = tokio::io::copy(&mut self, &mut writer).await?;
        let complete = !self.limit_exceeded().await?;
        Ok(N { written, complete })
    }

    /// Copies all readable data into `writer` and returns the number of bytes
    /// written, without checking whether the read limit was exceeded.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from reading this stream or writing to `writer`.
    #[inline(always)]
    pub async fn stream_precise_to<W>(mut self, mut writer: W) -> io::Result<u64>
    where
        W: AsyncWrite + Unpin,
    {
        tokio::io::copy(&mut self, &mut writer).await
    }

    /// Reads the whole stream into a byte vector, preallocated with `hint()`.
    pub async fn into_bytes(self) -> io::Result<Capped<Vec<u8>>> {
        let mut bytes = Vec::with_capacity(self.hint());
        let n = self.stream_to(&mut bytes).await?;
        Ok(Capped { value: bytes, n })
    }

    /// Reads the whole stream into a `String`, preallocated with `hint()`.
    ///
    /// # Errors
    ///
    /// Fails on I/O errors, and, per `AsyncReadExt::read_to_string`, when the
    /// data is not valid UTF-8.
    pub async fn into_string(mut self) -> io::Result<Capped<String>> {
        let mut text = String::with_capacity(self.hint());
        let written = self.read_to_string(&mut text).await? as u64;
        let complete = !self.limit_exceeded().await?;
        Ok(Capped {
            value: text,
            n: N { written, complete },
        })
    }

    /// Creates the file at `path` and streams the data into it through a
    /// `BufWriter`, returning the file handle.
    pub async fn into_file<P: AsRef<Path>>(self, path: P) -> io::Result<Capped<File>> {
        let mut file = File::create(path).await?;
        let n = {
            let mut writer = tokio::io::BufWriter::new(&mut file);
            self.stream_to(&mut writer).await?
        };

        Ok(Capped { value: file, n })
    }
}
impl AsyncRead for DataStream<'_> {
    /// Forwards the read to whichever reader this variant holds: the base
    /// reader directly, or the outermost transform layer.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        match this {
            DataStream::Transform(layer) => Pin::new(layer).poll_read(cx, buf),
            DataStream::Base(reader) => Pin::new(reader).poll_read(cx, buf),
        }
    }
}
impl AsyncRead for TransformReader<'_> {
    /// Reads from the wrapped stream and runs the new bytes through the
    /// transformer.
    ///
    /// While the inner stream is live, each call reads into `buf` and then
    /// calls `Transform::transform`. The `TransformBuf` cursor marks where this
    /// call's newly read bytes begin. Once a read of the inner stream fills no
    /// bytes, the inner stream is treated as finished, and this and every later
    /// call is answered by `Transform::poll_finish`.
    ///
    /// If the transform leaves no new bytes in `buf`, the task wakes itself and
    /// returns `Poll::Pending`. The read is retried instead of reporting a
    /// zero-byte read, which would look like EOF to the caller.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // With no spare capacity, any read trivially fills zero bytes. Without
        // this guard that would be mistaken for EOF of the inner stream,
        // permanently setting `inner_done` and silently dropping whatever data
        // the inner stream still had. A zero-capacity read is not an EOF signal.
        if buf.remaining() == 0 {
            return Poll::Ready(Ok(()));
        }

        let init_fill = buf.filled().len();
        if !self.inner_done {
            ready!(Pin::new(&mut self.stream).poll_read(cx, buf))?;
            // Capacity is available here, so a successful read that fills
            // nothing means the inner stream reached EOF.
            self.inner_done = init_fill == buf.filled().len();
        }

        if self.inner_done {
            return self.transformer.as_mut().poll_finish(cx, buf);
        }

        let mut tbuf = TransformBuf {
            buf,
            cursor: init_fill,
        };
        self.transformer.as_mut().transform(&mut tbuf)?;
        if buf.filled().len() == init_fill {
            // Nothing to hand out yet: reschedule ourselves rather than
            // returning an empty (EOF-looking) read.
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }

        Poll::Ready(Ok(()))
    }
}
impl Stream for RawStream<'_> {
    type Item = io::Result<Bytes>;

    /// Yields the next chunk of body bytes from the underlying source.
    ///
    /// For a hyper body, non-DATA frames (trailers) are mapped to an empty
    /// chunk instead of ending the stream. Errors from hyper and multer are
    /// wrapped with `io::Error::other`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.get_mut() {
            // TODO: trailer headers are currently discarded.
            RawStream::Body(body) => Pin::new(body)
                .poll_frame(cx)
                .map_ok(|frame| frame.into_data().unwrap_or_else(|_| Bytes::new()))
                .map_err(io::Error::other),
            #[cfg(feature = "http3-preview")]
            RawStream::H3Body(stream) => Pin::new(stream).poll_next(cx),
            RawStream::Multipart(s) => Pin::new(s).poll_next(cx).map_err(io::Error::other),
            RawStream::Empty => Poll::Ready(None),
        }
    }

    /// Bounds reported by the underlying source.
    ///
    /// NOTE(review): hyper's `SizeHint` counts body *bytes*, whereas
    /// `Stream::size_hint` nominally counts *items* (chunks). The hyper bounds
    /// are forwarded as-is; treat them only as an optimization hint.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            RawStream::Body(body) => {
                let hint = body.size_hint();
                // `SizeHint` is `u64`. A plain `as usize` cast would truncate on
                // 32-bit targets and could report an upper bound that is too
                // small. Instead, saturate the lower bound and report an
                // unrepresentable upper bound as `None`, as the `size_hint`
                // contract requires.
                let lower = usize::try_from(hint.lower()).unwrap_or(usize::MAX);
                let upper = hint.upper().and_then(|x| usize::try_from(x).ok());
                (lower, upper)
            }
            // A live HTTP/3 body can yield data, so `(0, Some(0))` ("nothing
            // left") would be wrong; the remaining length is unknown.
            #[cfg(feature = "http3-preview")]
            RawStream::H3Body(_) => (0, None),
            RawStream::Multipart(mp) => mp.size_hint(),
            RawStream::Empty => (0, Some(0)),
        }
    }
}
impl std::fmt::Display for RawStream<'_> {
    /// Writes a short, human-readable name for the kind of body source.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = match self {
            RawStream::Empty => "empty stream",
            RawStream::Body(_) => "request body",
            #[cfg(feature = "http3-preview")]
            RawStream::H3Body(_) => "http3 quic stream",
            RawStream::Multipart(_) => "multipart form field",
        };

        f.write_str(description)
    }
}
impl<'r> From<HyperBody> for RawStream<'r> {
fn from(value: HyperBody) -> Self {
Self::Body(value)
}
}
/// Wraps an HTTP/3 request body as `RawStream::H3Body`.
#[cfg(feature = "http3-preview")]
impl From<crate::listener::Cancellable<crate::listener::quic::QuicRx>> for RawStream<'_> {
    fn from(rx: crate::listener::Cancellable<crate::listener::quic::QuicRx>) -> Self {
        RawStream::H3Body(rx)
    }
}
impl<'r> From<multer::Field<'r>> for RawStream<'r> {
fn from(value: multer::Field<'r>) -> Self {
Self::Multipart(value)
}
}