use std::{
error::Error as StdError,
fs::File,
io::Write,
pin::Pin,
task::{Context, Poll},
};
use bytes::{Bytes, BytesMut};
use futures_util::{
io::{self, AsyncRead, AsyncWrite, AsyncWriteExt},
stream::{Stream, TryStreamExt},
};
use http::header::{CONTENT_DISPOSITION, CONTENT_TYPE};
use tracing::trace;
use crate::{
utils::{parse_content_disposition, parse_content_type, parse_part_headers},
Error, Field, Flag, FormData, Result, State,
};
impl<T, B, E> Stream for State<T>
where
T: Stream<Item = Result<B, E>> + Unpin,
B: Into<Bytes>,
E: Into<Box<dyn StdError + Send + Sync>>,
{
type Item = Result<Bytes>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
if self.is_readable {
trace!("attempting to decode a part");
if let Some(data) = self.decode() {
trace!("part decoded from buffer");
return Poll::Ready(Some(Ok(data)));
}
if Flag::Next == self.flag {
return Poll::Ready(None);
}
if Flag::Eof == self.flag {
self.length -= self.buffer.len() as u64;
self.buffer.clear();
self.eof = true;
return Poll::Ready(None);
}
self.is_readable = false;
}
trace!("polling data from stream");
if self.eof {
self.is_readable = true;
continue;
}
self.buffer.reserve(1);
let bytect = match Pin::new(self.io_mut()).poll_next(cx) {
Poll::Pending => {
return Poll::Pending;
}
Poll::Ready(Some(Ok(b))) => {
let b = b.into();
let l = b.len() as u64;
if let Some(max) = self.limits.checked_stream_size(self.length + l) {
return Poll::Ready(Some(Err(Error::PayloadTooLarge(max))));
}
self.buffer.extend_from_slice(&b);
self.length += l;
l
}
Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Some(Err(Error::BoxError(e.into()))))
}
Poll::Ready(None) => 0,
};
if bytect == 0 {
self.eof = true;
}
self.is_readable = true;
}
}
}
impl<T, B, E> Field<T>
where
    T: Stream<Item = Result<B, E>> + Unpin,
    B: Into<Bytes>,
    E: Into<Box<dyn StdError + Send + Sync>>,
{
    /// Reads the rest of this field's body into one contiguous `Bytes`.
    ///
    /// # Errors
    /// Propagates any error yielded by the field stream, including size-limit errors.
    pub async fn bytes(&mut self) -> Result<Bytes> {
        let mut bytes = BytesMut::new();
        while let Some(buf) = self.try_next().await? {
            bytes.extend_from_slice(&buf);
        }
        Ok(bytes.freeze())
    }

    /// Streams the rest of this field's body into `writer`, then flushes it.
    ///
    /// The writer is only borrowed for the duration of the call, so no `Send` or
    /// `'static` requirement is imposed on it.
    ///
    /// Returns the number of bytes written.
    ///
    /// # Errors
    /// Propagates field-stream errors and any I/O error from writing or flushing.
    pub async fn copy_to<W>(&mut self, writer: &mut W) -> Result<u64>
    where
        W: AsyncWrite + Unpin,
    {
        let mut n: u64 = 0;
        while let Some(buf) = self.try_next().await? {
            writer.write_all(&buf).await?;
            n += buf.len() as u64;
        }
        writer.flush().await?;
        Ok(n)
    }

    /// Writes the rest of this field's body into `file`, then flushes it.
    ///
    /// Returns the number of bytes written.
    ///
    /// NOTE(review): this performs blocking `std::fs` I/O inside an async fn; on a
    /// multi-threaded async runtime consider moving it to a blocking task.
    ///
    /// # Errors
    /// Propagates field-stream errors and any I/O error from writing or flushing.
    pub async fn copy_to_file(&mut self, file: &mut File) -> Result<u64> {
        let mut n: u64 = 0;
        while let Some(buf) = self.try_next().await? {
            // `write` may perform a short write and silently drop the tail of the
            // chunk; `write_all` keeps writing until the whole chunk is persisted.
            file.write_all(&buf)?;
            n += buf.len() as u64;
        }
        file.flush()?;
        Ok(n)
    }

    /// Drains and discards the rest of this field's body.
    ///
    /// # Errors
    /// Propagates any error yielded by the field stream.
    pub async fn ignore(&mut self) -> Result<()> {
        while self.try_next().await?.is_some() {}
        Ok(())
    }
}
impl<T, B, E> AsyncRead for Field<T>
where
T: Stream<Item = Result<B, E>> + Unpin,
B: Into<Bytes>,
E: Into<Box<dyn StdError + Send + Sync>>,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
mut buf: &mut [u8],
) -> Poll<io::Result<usize>> {
match self.poll_next(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => Poll::Ready(Ok(0)),
Poll::Ready(Some(Ok(b))) => Poll::Ready(Ok(buf.write(&b)?)),
Poll::Ready(Some(Err(e))) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))),
}
}
}
impl<T, B, E> Stream for Field<T>
where
T: Stream<Item = Result<B, E>> + Unpin,
B: Into<Bytes>,
E: Into<Box<dyn StdError + Send + Sync>>,
{
type Item = Result<Bytes>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
trace!("polling {} {}", self.index, self.state.is_some());
let state = match self.state.clone() {
None => return Poll::Ready(None),
Some(state) => state,
};
let is_file = self.filename.is_some();
let mut state = state
.try_lock()
.map_err(|e| Error::TryLockError(e.to_string()))?;
match Pin::new(&mut *state).poll_next(cx)? {
Poll::Pending => Poll::Pending,
Poll::Ready(res) => match res {
None => {
if let Some(waker) = state.waker_mut().take() {
waker.wake();
}
trace!("polled {}", self.index);
drop(self.state.take());
Poll::Ready(None)
}
Some(buf) => {
let l = buf.len();
if is_file {
if let Some(max) = state.limits.checked_file_size(self.length + l) {
return Poll::Ready(Some(Err(Error::FileTooLarge(max))));
}
} else if let Some(max) = state.limits.checked_field_size(self.length + l) {
return Poll::Ready(Some(Err(Error::FieldTooLarge(max))));
}
self.length += l;
trace!("polled bytes {}/{}", buf.len(), self.length);
Poll::Ready(Some(Ok(buf)))
}
},
}
}
}
impl<T, B, E> Stream for FormData<T>
where
    T: Stream<Item = Result<B, E>> + Unpin,
    B: Into<Bytes>,
    E: Into<Box<dyn StdError + Send + Sync>>,
{
    type Item = Result<Field<T>>;

    /// Parses the next part's headers and yields them as a [`Field`].
    ///
    /// Only one field is in flight at a time: after a field is yielded, the current
    /// task's waker is stored in the shared state, and this stream stays `Pending`
    /// while a waker is stored (the `Field` stream takes and wakes it when its body
    /// ends).
    ///
    /// # Errors
    /// - `Error::TryLockError` if the shared state lock is contended.
    /// - `Error::PartsTooMany`, `Error::FieldNameTooLong`, `Error::FilesTooMany`,
    ///   `Error::FieldsTooMany` when a configured limit would be exceeded.
    /// - `Error::InvalidHeader` / `Error::InvalidContentDisposition` for malformed
    ///   part headers.
    /// - Any error surfaced by the underlying parser state.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut state = self
            .state
            .try_lock()
            .map_err(|e| Error::TryLockError(e.to_string()))?;

        if state.waker().is_some() {
            // The previously yielded field is still being read. Returning `Pending`
            // obliges us to wake *this* task later, so refresh the stored waker with
            // the current one instead of relying on one captured by an earlier poll
            // (which may belong to a different task/context).
            state.waker_mut().replace(cx.waker().clone());
            return Poll::Pending;
        }

        match Pin::new(&mut *state).poll_next(cx)? {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) => {
                trace!("parse eof");
                Poll::Ready(None)
            }
            Poll::Ready(Some(buf)) => {
                trace!("parse part");
                if let Some(max) = state.limits.checked_parts(state.total + 1) {
                    return Poll::Ready(Some(Err(Error::PartsTooMany(max))));
                }

                let mut headers = match parse_part_headers(&buf) {
                    Ok(h) => h,
                    Err(_) => return Poll::Ready(Some(Err(Error::InvalidHeader))),
                };

                // `Content-Disposition` is mandatory and carries the field name and
                // optional filename; it is removed so only extra headers remain.
                let (name, filename) = match headers
                    .remove(CONTENT_DISPOSITION)
                    .and_then(|v| parse_content_disposition(v.as_bytes()).ok())
                {
                    Some(n) => n,
                    None => return Poll::Ready(Some(Err(Error::InvalidContentDisposition))),
                };

                if let Some(max) = state.limits.checked_field_name_size(name.len()) {
                    return Poll::Ready(Some(Err(Error::FieldNameTooLong(max))));
                }

                // Parts with a filename count as files, the rest as plain fields.
                if filename.is_some() {
                    if let Some(max) = state.limits.checked_files(state.files + 1) {
                        return Poll::Ready(Some(Err(Error::FilesTooMany(max))));
                    }
                    state.files += 1;
                } else {
                    if let Some(max) = state.limits.checked_fields(state.fields + 1) {
                        return Poll::Ready(Some(Err(Error::FieldsTooMany(max))));
                    }
                    state.fields += 1;
                }

                let mut field = Field::<T>::empty();
                field.name = name;
                field.filename = filename;
                field.index = state.index();
                field.content_type = parse_content_type(headers.remove(CONTENT_TYPE).as_ref());
                field.state_mut().replace(self.state());
                if !headers.is_empty() {
                    field.headers_mut().replace(headers);
                }

                // Park this stream until the yielded field finishes its body.
                state.waker_mut().replace(cx.waker().clone());
                Poll::Ready(Some(Ok(field)))
            }
        }
    }
}