use std::sync::Arc;
use std::task::{Context, Poll};
use bytes::Bytes;
use futures_util::future;
use futures_util::stream::{Stream, TryStreamExt};
use spin::mutex::spin::SpinMutex as Mutex;
#[cfg(feature = "tokio-io")]
use {tokio::io::AsyncRead, tokio_util::io::ReaderStream};
use crate::buffer::StreamBuffer;
use crate::constraints::Constraints;
use crate::content_disposition::ContentDisposition;
use crate::error::Error;
use crate::field::Field;
use crate::{Result, constants, helpers};
/// An async, pull-based parser for a `multipart/form-data` body stream.
///
/// Fields are obtained one at a time via `next_field` / `poll_next_field`.
#[derive(Debug)]
pub struct Multipart<'r> {
// Parsing state shared with the most recently returned `Field`.
// `poll_next_field` requires the Arc's strong count to be exactly 1
// (i.e. the previous field has been dropped) before it will proceed.
state: Arc<Mutex<MultipartState<'r>>>,
}
/// Mutable parsing state shared between the `Multipart` handle and the
/// currently active `Field`.
#[derive(Debug)]
pub(crate) struct MultipartState<'r> {
// Buffered view over the underlying byte stream.
pub(crate) buffer: StreamBuffer<'r>,
// `constants::BOUNDARY_EXT` + boundary — the delimiter as matched at the
// start of a boundary line (presumably `"--<boundary>"`; see `constants`).
pub(crate) boundary_bytes: Vec<u8>,
// `constants::CRLF` + `constants::BOUNDARY_EXT` + boundary — the delimiter
// as it terminates a field's data.
pub(crate) field_boundary_bytes: Vec<u8>,
// Current position in the parsing state machine.
pub(crate) stage: StreamingStage,
// Zero-based index that will be assigned to the next field returned.
pub(crate) next_field_idx: usize,
// Name of the field currently being read, if it carried one.
pub(crate) curr_field_name: Option<String>,
// Byte limit applying to the field currently being read.
pub(crate) curr_field_size_limit: u64,
// Bytes of the current field consumed so far (checked against the limit).
pub(crate) curr_field_size_counter: u64,
// User-supplied limits and the allowed-field whitelist.
pub(crate) constraints: Constraints,
}
/// Stage of the streaming state machine driven by
/// `Multipart::poll_next_field`. Stages are visited in roughly declaration
/// order, looping back from `ReadingFieldData` to `ReadingBoundary` for each
/// subsequent field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum StreamingStage {
// Skipping any preamble before the first boundary delimiter.
FindingFirstBoundary,
// Consuming the boundary delimiter itself.
ReadingBoundary,
// Checking whether the delimiter just read is followed by another
// `BOUNDARY_EXT`, which marks the closing boundary of the whole body.
DeterminingBoundaryType,
// Skipping transport padding, then the CRLF ending the boundary line.
ReadingTransportPadding,
// Reading the upcoming field's headers (terminated by CRLFCRLF).
ReadingFieldHeaders,
// Streaming the current field's data until the field boundary.
ReadingFieldData,
// The closing boundary was seen; no more fields will be produced.
Eof,
}
impl<'r> Multipart<'r> {
pub fn new<S, O, E, B>(stream: S, boundary: B) -> Self
where
S: Stream<Item = Result<O, E>> + Send + 'r,
O: Into<Bytes> + 'static,
E: Into<Box<dyn std::error::Error + Send + Sync>> + 'r,
B: Into<String>,
{
Multipart::with_constraints(stream, boundary, Constraints::default())
}
pub fn with_constraints<S, O, E, B>(stream: S, boundary: B, constraints: Constraints) -> Self
where
S: Stream<Item = Result<O, E>> + Send + 'r,
O: Into<Bytes> + 'static,
E: Into<Box<dyn std::error::Error + Send + Sync>> + 'r,
B: Into<String>,
{
let boundary = boundary.into();
let boundary_bytes = format!("{}{}", constants::BOUNDARY_EXT, boundary).into_bytes();
let field_boundary_bytes =
format!("{}{}{}", constants::CRLF, constants::BOUNDARY_EXT, boundary).into_bytes();
let stream = stream
.map_ok(|b| b.into())
.map_err(|err| Error::StreamReadFailed(err.into()));
Multipart {
state: Arc::new(Mutex::new(MultipartState {
buffer: StreamBuffer::new(stream, constraints.size_limit.whole_stream),
boundary_bytes,
field_boundary_bytes,
stage: StreamingStage::FindingFirstBoundary,
next_field_idx: 0,
curr_field_name: None,
curr_field_size_limit: constraints.size_limit.per_field,
curr_field_size_counter: 0,
constraints,
})),
}
}
#[cfg(feature = "tokio-io")]
pub fn with_reader<R, B>(reader: R, boundary: B) -> Self
where
R: AsyncRead + Unpin + Send + 'r,
B: Into<String>,
{
let stream = ReaderStream::new(reader);
Multipart::new(stream, boundary)
}
#[cfg(feature = "tokio-io")]
pub fn with_reader_and_constraints<R, B>(
reader: R,
boundary: B,
constraints: Constraints,
) -> Self
where
R: AsyncRead + Unpin + Send + 'r,
B: Into<String>,
{
let stream = ReaderStream::new(reader);
Multipart::with_constraints(stream, boundary, constraints)
}
pub async fn next_field(&mut self) -> Result<Option<Field<'r>>> {
future::poll_fn(|cx| self.poll_next_field(cx)).await
}
pub fn poll_next_field(&mut self, cx: &mut Context<'_>) -> Poll<Result<Option<Field<'r>>>> {
if Arc::strong_count(&self.state) != 1 {
return Poll::Ready(Err(Error::LockFailure));
}
debug_assert_eq!(Arc::strong_count(&self.state), 1);
debug_assert!(self.state.try_lock().is_some(), "expected exclusive lock");
let mut lock = match self.state.try_lock() {
Some(lock) => lock,
None => return Poll::Ready(Err(Error::LockFailure)),
};
let state = &mut *lock;
if state.stage == StreamingStage::Eof {
return Poll::Ready(Ok(None));
}
state.buffer.poll_stream(cx)?;
if state.stage == StreamingStage::FindingFirstBoundary {
match state.buffer.read_to(&state.boundary_bytes) {
Some(_) => state.stage = StreamingStage::ReadingBoundary,
None => {
state.buffer.poll_stream(cx)?;
if state.buffer.eof {
return Poll::Ready(Err(Error::IncompleteStream));
}
}
}
}
if state.stage == StreamingStage::ReadingFieldData {
match state.buffer.read_field_data(
&state.field_boundary_bytes,
state.curr_field_name.as_deref(),
)? {
Some((done, bytes)) => {
state.curr_field_size_counter += bytes.len() as u64;
if state.curr_field_size_counter > state.curr_field_size_limit {
return Poll::Ready(Err(Error::FieldSizeExceeded {
limit: state.curr_field_size_limit,
field_name: state.curr_field_name.clone(),
}));
}
if done {
state.stage = StreamingStage::ReadingBoundary;
} else {
return Poll::Pending;
}
}
None => {
return Poll::Pending;
}
}
}
if state.stage == StreamingStage::ReadingBoundary {
let boundary_bytes = match state.buffer.read_exact(state.boundary_bytes.len()) {
Some(bytes) => bytes,
None => {
return if state.buffer.eof {
Poll::Ready(Err(Error::IncompleteStream))
} else {
Poll::Pending
};
}
};
if &boundary_bytes[..] == state.boundary_bytes.as_slice() {
state.stage = StreamingStage::DeterminingBoundaryType;
} else {
return Poll::Ready(Err(Error::IncompleteStream));
}
}
if state.stage == StreamingStage::DeterminingBoundaryType {
let ext_len = constants::BOUNDARY_EXT.len();
let next_bytes = match state.buffer.peek_exact(ext_len) {
Some(bytes) => bytes,
None => {
return if state.buffer.eof {
Poll::Ready(Err(Error::IncompleteStream))
} else {
Poll::Pending
};
}
};
if next_bytes == constants::BOUNDARY_EXT.as_bytes() {
state.stage = StreamingStage::Eof;
return Poll::Ready(Ok(None));
} else {
state.stage = StreamingStage::ReadingTransportPadding;
}
}
if state.stage == StreamingStage::ReadingTransportPadding {
if !state.buffer.advance_past_transport_padding() {
return if state.buffer.eof {
Poll::Ready(Err(Error::IncompleteStream))
} else {
Poll::Pending
};
}
let crlf_len = constants::CRLF.len();
let crlf_bytes = match state.buffer.read_exact(crlf_len) {
Some(bytes) => bytes,
None => {
return if state.buffer.eof {
Poll::Ready(Err(Error::IncompleteStream))
} else {
Poll::Pending
};
}
};
if &crlf_bytes[..] == constants::CRLF.as_bytes() {
state.stage = StreamingStage::ReadingFieldHeaders;
} else {
return Poll::Ready(Err(Error::IncompleteStream));
}
}
if state.stage == StreamingStage::ReadingFieldHeaders {
let headers_limit = state.constraints.size_limit.headers;
let header_bytes = match state.buffer.read_until(constants::CRLF_CRLF.as_bytes()) {
Some(bytes) => bytes,
None => {
if state.buffer.buf.len() as u64 > headers_limit {
return Poll::Ready(Err(Error::HeadersSizeExceeded {
limit: headers_limit,
}));
}
return if state.buffer.eof {
return Poll::Ready(Err(Error::IncompleteStream));
} else {
Poll::Pending
};
}
};
if header_bytes.len() as u64 > headers_limit {
return Poll::Ready(Err(Error::HeadersSizeExceeded {
limit: headers_limit,
}));
}
let mut headers = [httparse::EMPTY_HEADER; constants::MAX_HEADERS];
let headers = match httparse::parse_headers(&header_bytes, &mut headers)
.map_err(Error::ReadHeaderFailed)?
{
httparse::Status::Complete((_, raw_headers)) => {
match helpers::convert_raw_headers_to_header_map(raw_headers) {
Ok(headers) => headers,
Err(err) => {
return Poll::Ready(Err(err));
}
}
}
httparse::Status::Partial => {
return Poll::Ready(Err(Error::IncompleteHeaders));
}
};
state.stage = StreamingStage::ReadingFieldData;
let field_idx = state.next_field_idx;
state.next_field_idx += 1;
let content_disposition = ContentDisposition::parse(&headers);
let field_size_limit = state
.constraints
.size_limit
.extract_size_limit_for(content_disposition.field_name.as_deref());
state.curr_field_name = content_disposition.field_name.clone();
state.curr_field_size_limit = field_size_limit;
state.curr_field_size_counter = 0;
let field_name = content_disposition.field_name.as_deref();
if !state.constraints.is_it_allowed(field_name) {
return Poll::Ready(Err(Error::UnknownField {
field_name: field_name.map(str::to_owned),
}));
}
drop(lock); let field = Field::new(self.state.clone(), headers, field_idx, content_disposition);
return Poll::Ready(Ok(Some(field)));
}
Poll::Pending
}
pub async fn next_field_with_idx(&mut self) -> Result<Option<(usize, Field<'r>)>> {
self.next_field()
.await
.map(|f| f.map(|field| (field.index(), field)))
}
}