use bytes::{Bytes, BytesMut};
use futures_core::Stream;
use http::header::{HeaderMap, HeaderName, HeaderValue};
use httparse::Status;
use pin_project::pin_project;
use std::error::Error as StdError;
use std::mem;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use thiserror::Error;
use twoway::{find_bytes, rfind_bytes};
/// Boxed, thread-safe error type that errors from the wrapped byte stream are converted into.
type AnyStdError = Box<dyn StdError + Send + Sync + 'static>;
/// One part of a multipart body.
///
/// Holds the part's parsed headers together with a shared handle to the parser
/// state; the `Stream` implementation for this type yields `Bytes` chunks.
pub struct MultipartField<S, E>
where
S: Stream<Item = Result<Bytes, E>> + Unpin,
E: Into<AnyStdError>,
{
/// Headers parsed for this part.
headers: HeaderMap<HeaderValue>,
/// Parser state shared (via `Arc<Mutex<_>>`) with other handles onto the same parser.
state: Arc<Mutex<MultipartState<S, E>>>,
}
impl<S, E> MultipartField<S, E>
where
    S: Stream<Item = Result<Bytes, E>> + Unpin,
    E: Into<AnyStdError>,
{
    /// Returns every header that was parsed for this field.
    pub fn headers(&self) -> &HeaderMap<HeaderValue> {
        &self.headers
    }

    /// Returns the value of the `content-type` header as a string slice.
    ///
    /// # Errors
    ///
    /// Returns [`MultipartError::InvalidHeader`] when the header is absent or
    /// its value cannot be converted with `HeaderValue::to_str`.
    pub fn content_type<'a>(&'a self) -> Result<&'a str, MultipartError> {
        match self.headers.get("content-type") {
            Some(value) => value.to_str().map_err(|_| MultipartError::InvalidHeader),
            None => Err(MultipartError::InvalidHeader),
        }
    }

    /// Returns the quoted `filename` parameter of the `content-disposition` header.
    ///
    /// # Errors
    ///
    /// Returns [`MultipartError::InvalidHeader`] when the header is absent, is
    /// not convertible with `HeaderValue::to_str`, or has no such parameter.
    pub fn filename<'a>(&'a self) -> Result<&'a str, MultipartError> {
        self.disposition_param("filename")
    }

    /// Returns the quoted `name` parameter of the `content-disposition` header.
    ///
    /// # Errors
    ///
    /// Returns [`MultipartError::InvalidHeader`] when the header is absent, is
    /// not convertible with `HeaderValue::to_str`, or has no such parameter.
    pub fn name<'a>(&'a self) -> Result<&'a str, MultipartError> {
        self.disposition_param("name")
    }

    /// Looks up `param` inside the `content-disposition` header of this field.
    fn disposition_param<'a>(&'a self, param: &str) -> Result<&'a str, MultipartError> {
        let raw = self
            .headers
            .get("content-disposition")
            .ok_or(MultipartError::InvalidHeader)?;
        let disposition = raw.to_str().map_err(|_| MultipartError::InvalidHeader)?;
        get_dispo_param(disposition, param).ok_or(MultipartError::InvalidHeader)
    }
}
/// Extracts the value of a quoted parameter (`param="value"`) from a
/// `Content-Disposition`-style header value.
///
/// Every occurrence of `param` in `input` is considered, in order. An
/// occurrence only counts when it starts the input or directly follows `;` or
/// whitespace, and is immediately followed by `="`. This keeps `name` from
/// matching the tail of `filename`, and skips occurrences that merely appear
/// inside another parameter's quoted value.
///
/// Returns the text between the opening quote and the next `"`, or `None` when
/// no occurrence qualifies, the closing quote is missing, or `param` is empty.
/// Matching is case-sensitive and escaped quotes inside the value are not
/// interpreted.
fn get_dispo_param<'a>(input: &'a str, param: &str) -> Option<&'a str> {
    if param.is_empty() {
        return None;
    }
    let mut search_from = 0;
    while let Some(offset) = input[search_from..].find(param) {
        let start_idx = search_from + offset;
        let end_param = start_idx + param.len();
        // The match must begin a parameter, not sit in the middle of a longer token.
        let at_boundary = input[..start_idx]
            .chars()
            .next_back()
            .map_or(true, |prev| prev == ';' || prev.is_whitespace());
        let rest = &input[end_param..];
        if at_boundary && rest.starts_with("=\"") {
            let value = &rest[2..];
            if let Some(end) = value.find('"') {
                return Some(&value[..end]);
            }
        }
        // Resume one character past the start of this match so that
        // overlapping candidates are still examined.
        search_from = start_idx
            + input[start_idx..]
                .chars()
                .next()
                .map_or(1, char::len_utf8);
    }
    None
}
impl<S, E> Stream for MultipartField<S, E>
where
S: Stream<Item = Result<Bytes, E>> + Unpin,
E: Into<AnyStdError>,
{
type Item = Result<Bytes, MultipartError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let self_mut = &mut self.as_mut();
let state = &mut self_mut
.state
.try_lock()
.map_err(|_| MultipartError::InternalBorrowError)?;
match Pin::new(&mut state.parser).poll_next(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Some(Err(err))) => {
return Poll::Ready(Some(Err(MultipartError::Stream(err.into()))))
}
Poll::Ready(None) => return Poll::Ready(None),
Poll::Ready(Some(Ok(ParseOutput::Headers(headers)))) => {
state.next_item = Some(headers);
return Poll::Ready(None);
}
Poll::Ready(Some(Ok(ParseOutput::Bytes(bytes)))) => {
return Poll::Ready(Some(Ok(bytes)))
}
}
}
}
/// State shared between a `MultipartStream` and the `MultipartField`s it hands out.
struct MultipartState<S, E>
where
S: Stream<Item = Result<Bytes, E>> + Unpin,
E: Into<AnyStdError>,
{
/// The low-level parser that both handles poll.
parser: MultipartParser<S, E>,
/// Headers of the next part, stored by `MultipartField::poll_next` when the parser
/// emits them, and taken by `MultipartStream::poll_next` to build the next field.
next_item: Option<HeaderMap<HeaderValue>>,
}
/// Stream of `MultipartField`s parsed from an underlying stream of `Bytes` chunks.
pub struct MultipartStream<S, E>
where
S: Stream<Item = Result<Bytes, E>> + Unpin,
E: Into<AnyStdError>,
{
/// Parser state, cloned into every `MultipartField` this stream yields.
state: Arc<Mutex<MultipartState<S, E>>>,
}
impl<S, E> MultipartStream<S, E>
where
    S: Stream<Item = Result<Bytes, E>> + Unpin,
    E: Into<AnyStdError>,
{
    /// Creates a multipart stream that parses `stream` using `boundary`.
    ///
    /// `boundary` is handed unchanged to [`MultipartParser::new`]; the parser
    /// and an empty `next_item` slot are wrapped in the shared state.
    pub fn new<I: Into<Bytes>>(boundary: I, stream: S) -> Self {
        let parser = MultipartParser::new(boundary, stream);
        let shared = MultipartState {
            parser,
            next_item: None,
        };
        Self {
            state: Arc::new(Mutex::new(shared)),
        }
    }
}
impl<S, E> Stream for MultipartStream<S, E>
where
S: Stream<Item = Result<Bytes, E>> + Unpin,
E: Into<AnyStdError>,
{
type Item = Result<MultipartField<S, E>, MultipartError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let self_mut = &mut self.as_mut();
let state = &mut self_mut
.state
.try_lock()
.map_err(|_| MultipartError::InternalBorrowError)?;
if let Some(headers) = state.next_item.take() {
return Poll::Ready(Some(Ok(MultipartField {
headers,
state: self_mut.state.clone(),
})));
}
match Pin::new(&mut state.parser).poll_next(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
Poll::Ready(None) => return Poll::Ready(None),
Poll::Ready(Some(Ok(ParseOutput::Headers(headers)))) => {
return Poll::Ready(Some(Ok(MultipartField {
headers,
state: self_mut.state.clone(),
})));
}
Poll::Ready(Some(Ok(ParseOutput::Bytes(_bytes)))) => {
return Poll::Ready(Some(Err(MultipartError::ShouldPollField)));
}
}
}
}
/// Errors produced while parsing a multipart body or reading a field's metadata.
#[derive(Error, Debug)]
pub enum MultipartError {
/// The first bytes of the input were not `--<boundary>\r\n`.
#[error("Invalid Boundary. (expected {expected:?}, found {found:?})")]
InvalidBoundary {
expected: String,
found: String,
},
/// `httparse` reported a partial result for a header block.
#[error("Incomplete Headers")]
IncompleteHeader,
/// A header name or value could not be converted, or a requested header or
/// header parameter was not present.
#[error("Invalid Header Value")]
InvalidHeader,
/// The parser produced body bytes while a `MultipartStream` was being polled.
#[error(
"Tried to poll an MultipartStream when the MultipartField should be polled, try using `flatten()`"
)]
ShouldPollField,
/// `try_lock` on the shared parser state failed.
#[error("Tried to poll an MultipartField and the Mutex has already been locked")]
InternalBorrowError,
/// Error reported by `httparse` while parsing a header block.
#[error(transparent)]
HeaderParse(#[from] httparse::Error),
/// Boxed error, used to carry failures of the underlying byte stream.
#[error(transparent)]
Stream(#[from] AnyStdError),
/// The underlying stream ended before a header block was terminated by `\r\n\r\n`.
#[error("EOF while reading headers")]
EOFWhileReadingHeaders,
/// The underlying stream ended before the first boundary line could be read.
#[error("EOF while reading boundary")]
EOFWhileReadingBoundary,
/// The underlying stream ended while body content was still expected.
#[error("EOF while reading body")]
EOFWhileReadingBody,
/// The two bytes following a boundary were neither `\r\n` nor `--`.
#[error("Garbage following boundary: {0:02x?}")]
GarbageAfterBoundary([u8; 2]),
}
/// Low-level multipart parser: a stream of `ParseOutput` items (headers and body
/// chunks) produced from an underlying stream of `Bytes`.
#[pin_project(project = ParserProj)]
pub struct MultipartParser<S, E>
where
S: Stream<Item = Result<Bytes, E>>,
E: Into<AnyStdError>,
{
/// Boundary to search for. Starts as the raw boundary given to `new`; after the
/// first boundary line is matched it is replaced by `\r\n--<boundary>`.
boundary: Bytes,
/// Bytes received from `stream` that have not been emitted or consumed yet.
buffer: BytesMut,
/// Current position in the parsing state machine.
state: State,
/// The underlying byte stream.
#[pin]
stream: S,
}
impl<S, E> MultipartParser<S, E>
where
    S: Stream<Item = Result<Bytes, E>>,
    E: Into<AnyStdError>,
{
    /// Creates a parser over `stream` that looks for `boundary`.
    ///
    /// The parser starts with an empty buffer in `State::ReadingBoundary`.
    pub fn new<I: Into<Bytes>>(boundary: I, stream: S) -> Self {
        let boundary = boundary.into();
        let buffer = BytesMut::new();
        Self {
            boundary,
            buffer,
            state: State::ReadingBoundary,
            stream,
        }
    }
}
/// Number of header slots handed to `httparse` for a single part.
const NUM_HEADERS: usize = 16;

/// Parses a complete header block into a `HeaderMap`.
///
/// `buffer` must contain the header lines including the terminating empty
/// line. Repeated header names keep only the last value, because
/// `HeaderMap::insert` replaces existing entries.
///
/// # Errors
///
/// * `MultipartError::HeaderParse` when `httparse` rejects the block.
/// * `MultipartError::IncompleteHeader` when `httparse` reports a partial block.
/// * `MultipartError::InvalidHeader` when a name or value is not accepted by
///   `HeaderName::from_bytes` / `HeaderValue::from_bytes`.
fn get_headers(buffer: &[u8]) -> Result<HeaderMap<HeaderValue>, MultipartError> {
    let mut slots = [httparse::EMPTY_HEADER; NUM_HEADERS];
    // `parse_headers` returns (bytes consumed, parsed headers). Only the slice
    // of parsed headers is relevant: sizing the map from the byte count would
    // over-allocate and can exceed the map's maximum capacity on large blocks.
    let parsed = match httparse::parse_headers(buffer, &mut slots)? {
        Status::Complete((_consumed, parsed)) => parsed,
        Status::Partial => return Err(MultipartError::IncompleteHeader),
    };
    let mut header_map = HeaderMap::with_capacity(parsed.len());
    for header in parsed {
        let name = HeaderName::from_bytes(header.name.as_bytes())
            .map_err(|_| MultipartError::InvalidHeader)?;
        let value =
            HeaderValue::from_bytes(header.value).map_err(|_| MultipartError::InvalidHeader)?;
        header_map.insert(name, value);
    }
    Ok(header_map)
}
impl<S, E> Stream for MultipartParser<S, E>
where
    S: Stream<Item = Result<Bytes, E>>,
    E: Into<AnyStdError>,
{
    type Item = Result<ParseOutput, MultipartError>;

    /// Drives the parsing state machine and yields the next `ParseOutput`.
    ///
    /// * `ReadingBoundary`: buffers input until `--<boundary>\r\n` can be
    ///   checked, then switches the searched boundary to `\r\n--<boundary>`.
    /// * `ReadingHeader`: buffers input until `\r\n\r\n` is found, then yields
    ///   `ParseOutput::Headers`.
    /// * `StreamingContent`: yields `ParseOutput::Bytes` chunks until the next
    ///   boundary. A trailing fragment that could be the beginning of a
    ///   boundary is held back until more input arrives.
    /// * `Finished`: yields `None`.
    ///
    /// Terminal errors (end of input in any state, an invalid first boundary,
    /// a header block that fails to parse, unexpected bytes after a boundary)
    /// move the parser to `Finished`, so later polls return `None` instead of
    /// polling the exhausted inner stream again or repeating the same error.
    /// Errors yielded by the inner stream itself leave the state unchanged.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let ParserProj {
            boundary,
            buffer,
            state,
            mut stream,
        } = self.project();
        loop {
            match state {
                State::ReadingBoundary => {
                    let boundary_len = boundary.len();
                    // Need "--" + boundary + "\r\n" before the comparison is possible.
                    if buffer.len() < boundary_len + 4 {
                        match futures_core::ready!(stream.as_mut().poll_next(cx)) {
                            Some(Ok(bytes)) => {
                                buffer.extend_from_slice(&bytes);
                                continue;
                            }
                            Some(Err(e)) => {
                                return Poll::Ready(Some(Err(MultipartError::Stream(e.into()))))
                            }
                            None => {
                                *state = State::Finished;
                                return Poll::Ready(Some(Err(
                                    MultipartError::EOFWhileReadingBoundary,
                                )));
                            }
                        }
                    }
                    if &buffer[..2] == b"--"
                        && &buffer[2..boundary_len + 2] == &*boundary
                        && &buffer[boundary_len + 2..boundary_len + 4] == b"\r\n"
                    {
                        *buffer = buffer.split_off(boundary_len + 4);
                        *state = State::ReadingHeader;
                        // Every later boundary is preceded by CRLF, so search
                        // for "\r\n--<boundary>" from here on.
                        let mut new_boundary = BytesMut::with_capacity(boundary_len + 4);
                        new_boundary.extend_from_slice(b"\r\n--");
                        new_boundary.extend_from_slice(&boundary);
                        *boundary = new_boundary.freeze();
                        // Go straight on to the headers with what is buffered.
                        continue;
                    } else {
                        let expected = format!("--{}\\r\\n", String::from_utf8_lossy(&boundary));
                        let found =
                            String::from_utf8_lossy(&buffer[..boundary_len + 4]).to_string();
                        *state = State::Finished;
                        let error = MultipartError::InvalidBoundary { expected, found };
                        return Poll::Ready(Some(Err(error)));
                    }
                }
                State::ReadingHeader => {
                    if let Some(end) = find_bytes(&buffer, b"\r\n\r\n") {
                        // Include the blank line that terminates the header block.
                        let end = end + 4;
                        let header_map = match get_headers(&buffer[0..end]) {
                            Ok(headers) => headers,
                            Err(error) => {
                                *state = State::Finished;
                                return Poll::Ready(Some(Err(error)));
                            }
                        };
                        *buffer = buffer.split_off(end);
                        *state = State::StreamingContent(buffer.is_empty());
                        cx.waker().wake_by_ref();
                        return Poll::Ready(Some(Ok(ParseOutput::Headers(header_map))));
                    } else {
                        match futures_core::ready!(stream.as_mut().poll_next(cx)) {
                            Some(Ok(bytes)) => {
                                buffer.extend_from_slice(&bytes);
                                continue;
                            }
                            Some(Err(e)) => {
                                return Poll::Ready(Some(Err(MultipartError::Stream(e.into()))))
                            }
                            None => {
                                *state = State::Finished;
                                return Poll::Ready(Some(Err(
                                    MultipartError::EOFWhileReadingHeaders,
                                )));
                            }
                        }
                    }
                }
                State::StreamingContent(exhausted) => {
                    let boundary_len = boundary.len();
                    // `exhausted` means the buffered bytes cannot be resolved
                    // without more input.
                    if buffer.is_empty() || *exhausted {
                        *state = State::StreamingContent(false);
                        match futures_core::ready!(stream.as_mut().poll_next(cx)) {
                            Some(Ok(bytes)) => {
                                buffer.extend_from_slice(&bytes);
                                continue;
                            }
                            Some(Err(e)) => {
                                return Poll::Ready(Some(Err(MultipartError::Stream(e.into()))))
                            }
                            None => {
                                *state = State::Finished;
                                return Poll::Ready(Some(Err(MultipartError::EOFWhileReadingBody)));
                            }
                        }
                    }
                    if let Some(idx) = find_bytes(&buffer, boundary) {
                        // The two bytes after the boundary decide what comes next.
                        if buffer.len() < idx + 2 + boundary_len {
                            *state = State::StreamingContent(true);
                            continue;
                        }
                        let end_boundary = idx + boundary_len;
                        let after_boundary = &buffer[end_boundary..end_boundary + 2];
                        if after_boundary == b"\r\n" {
                            // Another part follows: emit the content before the
                            // boundary and keep what comes after "boundary\r\n".
                            let mut other_bytes = (*buffer).split_off(idx);
                            other_bytes = other_bytes.split_off(2 + boundary_len);
                            let return_bytes = Bytes::from(mem::replace(buffer, other_bytes));
                            *state = State::ReadingHeader;
                            cx.waker().wake_by_ref();
                            return Poll::Ready(Some(Ok(ParseOutput::Bytes(return_bytes))));
                        } else if after_boundary == b"--" {
                            // Closing boundary: emit the remaining content and stop.
                            buffer.truncate(idx);
                            *state = State::Finished;
                            return Poll::Ready(Some(Ok(ParseOutput::Bytes(Bytes::from(
                                mem::take(buffer),
                            )))));
                        } else {
                            let garbage = [after_boundary[0], after_boundary[1]];
                            *state = State::Finished;
                            return Poll::Ready(Some(Err(MultipartError::GarbageAfterBoundary(
                                garbage,
                            ))));
                        }
                    } else {
                        // No complete boundary is buffered. Check whether the
                        // last `boundary_len - 1` bytes end with a prefix of it.
                        let buffer_len = buffer.len();
                        let start_idx = (buffer_len + 1).saturating_sub(boundary_len);
                        let end_of_buffer = &buffer[start_idx..];
                        if let Some(idx) = rfind_bytes(end_of_buffer, b"\r") {
                            if &end_of_buffer[idx..] == &boundary[..(end_of_buffer.len() - idx)] {
                                // Emit everything before the candidate prefix and
                                // hold the prefix back until more input arrives.
                                *state = State::StreamingContent(true);
                                let mut output = buffer.split_off(idx + start_idx);
                                mem::swap(&mut output, buffer);
                                cx.waker().wake_by_ref();
                                return Poll::Ready(Some(Ok(ParseOutput::Bytes(output.freeze()))));
                            }
                        }
                        let output = mem::take(buffer);
                        return Poll::Ready(Some(Ok(ParseOutput::Bytes(output.freeze()))));
                    }
                }
                State::Finished => return Poll::Ready(None),
            }
        }
    }
}
/// Position of `MultipartParser` in its parsing state machine.
#[derive(Debug, PartialEq)]
enum State {
/// Waiting for the initial `--<boundary>\r\n` line.
ReadingBoundary,
/// Accumulating a header block until `\r\n\r\n` is found.
ReadingHeader,
/// Emitting body bytes; when the flag is `true`, more input is read from the
/// underlying stream before the buffer is scanned again.
StreamingContent(bool),
/// Parsing is over; polling yields `None`.
Finished,
}
/// Item produced by `MultipartParser`.
#[derive(Debug)]
pub enum ParseOutput {
/// The parsed headers that start a part.
Headers(HeaderMap<HeaderValue>),
/// A chunk of body bytes belonging to the current part.
Bytes(Bytes),
}
// Tests for the multipart parser, the field/stream wrappers and the
// content-disposition parameter helper.
#[cfg(test)]
mod tests {
use super::*;
use crate::client::ByteStream;
use futures_util::StreamExt;
// The first field of a three-part body exposes its `name` and `filename`, and its
// first body chunk is the file content.
#[tokio::test]
async fn read_stream() {
let input: &[u8] = b"--AaB03x\r\n\
Content-Disposition: form-data; name=\"file\"; filename=\"text.txt\"\r\n\
Content-Type: text/plain\r\n\
\r\n\
Lorem Ipsum\n\r\n\
--AaB03x\r\n\
Content-Disposition: form-data; name=\"name1\"\r\n\
\r\n\
value1\r\n\
--AaB03x\r\n\
Content-Disposition: form-data; name=\"name2\"\r\n\
\r\n\
value2\r\n\
--AaB03x--\r\n";
let mut stream = MultipartStream::new("AaB03x", ByteStream::new(input));
if let Some(Ok(mut mpart_field)) = stream.next().await {
assert_eq!(mpart_field.name().ok(), Some("file"));
assert_eq!(mpart_field.filename().ok(), Some("text.txt"));
if let Some(Ok(bytes)) = mpart_field.next().await {
assert_eq!(bytes, Bytes::from(b"Lorem Ipsum\n" as &[u8]));
}
} else {
panic!("First value should be a field")
}
}
// `get_dispo_param` extracts both quoted parameters from a content-disposition value.
#[test]
fn read_filename() {
let input = "form-data; name=\"file\"; filename=\"text.txt\"";
let name = get_dispo_param(input, "name");
let filename = get_dispo_param(input, "filename");
assert_eq!(name, Some("file"));
assert_eq!(filename, Some("text.txt"));
}
// The low-level parser alternates `Headers` and `Bytes` items for each of the three
// parts and then ends.
#[tokio::test]
async fn reads_streams_and_fields() {
let input: &[u8] = b"--AaB03x\r\n\
Content-Disposition: form-data; name=\"file\"; filename=\"text.txt\"\r\n\
Content-Type: text/plain\r\n\
\r\n\
Lorem Ipsum\n\r\n\
--AaB03x\r\n\
Content-Disposition: form-data; name=\"name1\"\r\n\
\r\n\
value1\r\n\
--AaB03x\r\n\
Content-Disposition: form-data; name=\"name2\"\r\n\
\r\n\
value2\r\n\
--AaB03x--\r\n";
let mut read = MultipartParser::new("AaB03x", ByteStream::new(input));
if let Some(Ok(ParseOutput::Headers(val))) = read.next().await {
println!("Headers:{:?}", val);
} else {
panic!("First value should be a header")
}
if let Some(Ok(ParseOutput::Bytes(bytes))) = read.next().await {
assert_eq!(&*bytes, b"Lorem Ipsum\n");
} else {
panic!("Second value should be bytes")
}
if let Some(Ok(ParseOutput::Headers(val))) = read.next().await {
println!("Headers:{:?}", val);
} else {
panic!("Third value should be a header")
}
if let Some(Ok(ParseOutput::Bytes(bytes))) = read.next().await {
assert_eq!(&*bytes, b"value1");
} else {
panic!("Fourth value should be bytes")
}
if let Some(Ok(ParseOutput::Headers(val))) = read.next().await {
println!("Headers:{:?}", val);
} else {
panic!("Fifth value should be a header")
}
if let Some(Ok(ParseOutput::Bytes(bytes))) = read.next().await {
assert_eq!(&*bytes, b"value2");
} else {
panic!("Sixth value should be bytes")
}
assert!(read.next().await.is_none());
}
// Input that ends in the middle of the first header block reports
// `EOFWhileReadingHeaders`.
#[tokio::test]
async fn unfinished_header() {
let input: &[u8] = b"--AaB03x\r\n\
Content-Disposition: form-data; name=\"file\"; filename=\"text.txt\"\r\n\
Content-Type: text/plain";
let mut read = MultipartParser::new("AaB03x", ByteStream::new(input));
let ret = read.next().await;
assert!(matches!(
ret,
Some(Err(MultipartError::EOFWhileReadingHeaders))
),);
}
// The first part parses normally; input ending inside the second part's headers
// reports `EOFWhileReadingHeaders`.
#[tokio::test]
async fn unfinished_second_header() {
let input: &[u8] = b"--AaB03x\r\n\
Content-Disposition: form-data; name=\"file\"; filename=\"text.txt\"\r\n\
Content-Type: text/plain\r\n\
\r\n\
Lorem Ipsum\n\r\n\
--AaB03x\r\n\
Content-Disposition: form-data; name=\"name1\"";
let mut read = MultipartParser::new("AaB03x", ByteStream::new(input));
if let Some(Ok(ParseOutput::Headers(val))) = read.next().await {
println!("Headers:{:?}", val);
} else {
panic!("First value should be a header")
}
if let Some(Ok(ParseOutput::Bytes(bytes))) = read.next().await {
assert_eq!(&*bytes, b"Lorem Ipsum\n");
} else {
panic!("Second value should be bytes")
}
let ret = read.next().await;
assert!(matches!(
ret,
Some(Err(MultipartError::EOFWhileReadingHeaders))
),);
}
// A header line that `httparse` rejects surfaces as `MultipartError::HeaderParse`.
#[tokio::test]
async fn invalid_header() {
let input: &[u8] = b"--AaB03x\r\n\
I am a bad header\r\n\
\r\n";
let mut read = MultipartParser::new("AaB03x", ByteStream::new(input));
let val = read.next().await.unwrap();
match val {
Err(MultipartError::HeaderParse(err)) => {
println!("{}", err);
}
val => {
panic!("Expecting Parse Error, Instead got:{:?}", val);
}
}
}
// A first line that does not match the configured boundary reports `InvalidBoundary`
// with the expected text and the bytes that were actually found.
#[tokio::test]
async fn invalid_boundary() {
let input: &[u8] = b"--InvalidBoundary\r\n\
Content-Disposition: form-data; name=\"file\"; filename=\"text.txt\"\r\n\
Content-Type: text/plain\r\n\
\r\n\
Lorem Ipsum\n\r\n\
--InvalidBoundary--\r\n";
let mut read = MultipartParser::new("AaB03x", ByteStream::new(input));
let val = read.next().await.unwrap();
match val {
Err(MultipartError::InvalidBoundary { expected, found }) => {
assert_eq!(expected, "--AaB03x\\r\\n");
assert_eq!(found, "--InvalidB");
}
val => {
panic!("Expecting Invalid Boundary Error, Instead got:{:?}", val);
}
}
}
// A body that starts with two CRLF pairs is returned intact, and the outer stream
// ends after the single part.
#[tokio::test]
async fn zero_read() {
use bytes::{BufMut, BytesMut};
let input = b"----------------------------332056022174478975396798\r\n\
Content-Disposition: form-data; name=\"file\"\r\n\
Content-Type: application/octet-stream\r\n\
\r\n\
\r\n\
\r\n\
dolphin\n\
whale\r\n\
----------------------------332056022174478975396798--\r\n";
let boundary = "--------------------------332056022174478975396798";
let mut read = MultipartStream::new(boundary, ByteStream::new(input));
let mut part = match read.next().await.unwrap() {
Ok(mf) => {
assert_eq!(mf.name().unwrap(), "file");
assert_eq!(mf.content_type().unwrap(), "application/octet-stream");
mf
}
Err(e) => panic!("unexpected: {}", e),
};
let mut buffer = BytesMut::new();
loop {
match part.next().await {
Some(Ok(bytes)) => buffer.put(bytes),
Some(Err(e)) => panic!("unexpected {}", e),
None => break,
}
}
let nth = read.next().await;
assert!(nth.is_none());
assert_eq!(buffer.as_ref(), b"\r\n\r\ndolphin\nwhale");
}
// A body made of `\r` bytes, followed by a closing boundary that is split across two
// packets, is returned without any boundary bytes.
#[tokio::test]
async fn r_read() {
use std::convert::Infallible;
// Test stream that yields its queued packets one at a time, then ends.
#[derive(Clone)]
pub struct SplitStream {
packets: Vec<Bytes>,
}
impl SplitStream {
pub fn new() -> Self {
SplitStream { packets: vec![] }
}
pub fn add_packet<P: Into<Bytes>>(&mut self, bytes: P) {
self.packets.push(bytes.into());
}
}
impl Stream for SplitStream {
type Item = Result<Bytes, Infallible>;
fn poll_next(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
if self.as_mut().packets.is_empty() {
return Poll::Ready(None);
}
Poll::Ready(Some(Ok(self.as_mut().packets.remove(0))))
}
}
use bytes::{BufMut, BytesMut};
let input1: &[u8] = b"----------------------------332056022174478975396798\r\n\
Content-Disposition: form-data; name=\"file\"\r\n\
Content-Type: application/octet-stream\r\n\
\r\n\
\r\r\r\r\r\r\r\r\r\r\r\r\r\
\r\n\
----------------------------332";
let input2: &[u8] = b"056022174478975396798--\r\n";
let boundary = "--------------------------332056022174478975396798";
let mut split_stream = SplitStream::new();
split_stream.add_packet(&*input1);
split_stream.add_packet(&*input2);
let mut read = MultipartStream::new(boundary, split_stream);
let mut part = match read.next().await.unwrap() {
Ok(mf) => {
assert_eq!(mf.name().unwrap(), "file");
assert_eq!(mf.content_type().unwrap(), "application/octet-stream");
mf
}
Err(e) => panic!("unexpected: {}", e),
};
let mut buffer = BytesMut::new();
loop {
match part.next().await {
Some(Ok(bytes)) => buffer.put(bytes),
Some(Err(e)) => panic!("unexpected {}", e),
None => break,
}
}
let nth = read.next().await;
assert!(nth.is_none());
assert_eq!(buffer.as_ref(), b"\r\r\r\r\r\r\r\r\r\r\r\r\r");
}
}