use std::{
pin::Pin,
str,
task::{Context, Poll},
};
use bytes::{Buf, Bytes, BytesMut};
use futures_core::{stream::LocalBoxStream, Stream};
use futures_util::StreamExt;
use crate::{error::MultipartError, multipart_type::MultipartType};
/// Position of [`MultipartReader`] within the multipart body it is parsing.
#[derive(PartialEq, Debug)]
enum InnerState {
    /// Skipping any preamble until the first delimiter line is seen.
    FirstBoundary,
    /// Reading the current part's header lines, up to the blank line that ends them.
    Headers,
    /// Reading the current part's body until the next delimiter line.
    Boundary,
    /// Finished: no further items will be produced.
    Eof,
}
/// One part of a multipart body, as yielded by [`MultipartReader`].
///
/// The fields are private; use [`MultipartItem::headers`] and
/// [`MultipartItem::data`] to read them.
#[derive(Debug)]
pub struct MultipartItem {
    /// `(name, value)` pairs in the order the reader collected them.
    headers: Vec<(String, String)>,
    /// Body bytes collected by the reader.
    data: BytesMut,
}

impl MultipartItem {
    /// Returns the part's headers as `(name, value)` pairs, in collection order.
    pub fn headers(&self) -> &[(String, String)] {
        &self.headers
    }

    /// Returns the part's body bytes.
    pub fn data(&self) -> &[u8] {
        &self.data
    }
}
/// Streaming parser that splits a multipart body into [`MultipartItem`]s.
///
/// Build one with a `from_*` constructor, then poll it as a
/// [`Stream`] of `Result<MultipartItem, MultipartError>`.
pub struct MultipartReader<'a> {
    /// Boundary taken from the constructor argument or the `Content-Type` header.
    pub boundary: String,
    /// Multipart subtype, e.g. `MultipartType::FormData` for `multipart/form-data`.
    pub multipart_type: MultipartType,
    /// Source of raw body chunks.
    stream: LocalBoxStream<'a, Result<Bytes, MultipartError>>,
    /// Received bytes that have not yet been consumed as complete lines.
    buf: BytesMut,
    /// Where the parser currently is in the body.
    state: InnerState,
    /// The part whose headers/body are currently being collected, if any.
    pending_item: Option<MultipartItem>,
}
impl<'a> MultipartReader<'a> {
    /// Creates a reader over `stream` with an explicit `boundary` and `multipart_type`.
    ///
    /// `boundary` is the value of the `Content-Type` `boundary` parameter. Delimiter
    /// lines are recognised both as `--` followed by `boundary` (RFC 2046) and as
    /// `boundary` verbatim, so a boundary that already carries the leading `--`
    /// keeps working.
    ///
    /// # Errors
    ///
    /// Returns [`MultipartError::InvalidBoundary`] if `boundary` is empty, since an
    /// empty boundary would match every line.
    pub fn from_stream_with_boundary_and_type<S>(
        stream: S,
        boundary: &str,
        multipart_type: MultipartType,
    ) -> Result<MultipartReader<'a>, MultipartError>
    where
        S: Stream<Item = Result<Bytes, MultipartError>> + 'a,
    {
        if boundary.is_empty() {
            return Err(MultipartError::InvalidBoundary);
        }
        Ok(MultipartReader {
            stream: stream.boxed_local(),
            boundary: boundary.to_string(),
            multipart_type,
            state: InnerState::FirstBoundary,
            pending_item: None,
            buf: BytesMut::new(),
        })
    }

    /// Creates a reader over an in-memory body (the bytes are copied).
    ///
    /// # Errors
    ///
    /// Same as [`Self::from_stream_with_boundary_and_type`].
    pub fn from_data_with_boundary_and_type(
        data: &[u8],
        boundary: &str,
        multipart_type: MultipartType,
    ) -> Result<MultipartReader<'a>, MultipartError> {
        let chunk = Bytes::copy_from_slice(data);
        let stream = futures_util::stream::iter(std::iter::once(Ok(chunk)));
        Self::from_stream_with_boundary_and_type(stream, boundary, multipart_type)
    }

    /// Creates a reader over `stream`, taking the boundary and subtype from the
    /// `Content-Type` entry of `headers` (name matched ASCII case-insensitively).
    ///
    /// # Errors
    ///
    /// - [`MultipartError::NoContentType`] if no `Content-Type` header is present.
    /// - [`MultipartError::InvalidContentType`] if the value is not a valid media type
    ///   or its top-level type is not `multipart`.
    /// - [`MultipartError::InvalidBoundary`] if the `boundary` parameter is missing
    ///   or empty.
    /// - [`MultipartError::InvalidMultipartType`] if the subtype does not parse as a
    ///   [`MultipartType`].
    pub fn from_stream_with_headers<S>(
        stream: S,
        headers: &[(String, String)],
    ) -> Result<MultipartReader<'a>, MultipartError>
    where
        S: Stream<Item = Result<Bytes, MultipartError>> + 'a,
    {
        let (_, content_type) = headers
            .iter()
            .find(|(key, _)| key.eq_ignore_ascii_case("content-type"))
            .ok_or(MultipartError::NoContentType)?;
        let ct = content_type
            .parse::<mime::Mime>()
            .map_err(|_| MultipartError::InvalidContentType)?;
        // Check the media type first so a non-multipart body reports the more
        // accurate error instead of a missing boundary.
        if ct.type_() != mime::MULTIPART {
            return Err(MultipartError::InvalidContentType);
        }
        let boundary = ct
            .get_param(mime::BOUNDARY)
            .ok_or(MultipartError::InvalidBoundary)?;
        let multipart_type = ct
            .subtype()
            .as_str()
            .parse::<MultipartType>()
            .map_err(|_| MultipartError::InvalidMultipartType)?;
        Self::from_stream_with_boundary_and_type(stream, boundary.as_str(), multipart_type)
    }

    /// Creates a reader over an in-memory body (the bytes are copied), configured
    /// from `headers`.
    ///
    /// # Errors
    ///
    /// Same as [`Self::from_stream_with_headers`].
    pub fn from_data_with_headers(
        data: &[u8],
        headers: &[(String, String)],
    ) -> Result<MultipartReader<'a>, MultipartError> {
        let chunk = Bytes::copy_from_slice(data);
        let stream = futures_util::stream::iter(std::iter::once(Ok(chunk)));
        Self::from_stream_with_headers(stream, headers)
    }

    /// Returns `true` if `data` (one line) starts with a delimiter for this boundary.
    ///
    /// Accepts the RFC 2046 form `--boundary` as well as `boundary` verbatim, for
    /// callers whose boundary already includes the leading `--`.
    fn is_boundary(&self, data: &[u8]) -> bool {
        let boundary = self.boundary.as_bytes();
        data.starts_with(boundary)
            || data
                .strip_prefix(b"--")
                .map_or(false, |rest| rest.starts_with(boundary))
    }
}
impl<'a> Stream for MultipartReader<'a> {
type Item = Result<MultipartItem, MultipartError>;
fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let finder = memchr::memmem::Finder::new("\r\n");
loop {
while let Some(idx) = finder.find(&this.buf) {
println!("{}", String::from_utf8_lossy(&this.buf[..idx]));
match this.state {
InnerState::FirstBoundary => {
if this.is_boundary(&this.buf[..idx]) {
this.state = InnerState::Headers;
};
}
InnerState::Boundary => {
if this.is_boundary(&this.buf[..idx]) {
if let Some(item) = this.pending_item.take() {
this.buf.advance(2 + idx);
this.state = InnerState::Headers;
return std::task::Poll::Ready(Some(Ok(item)));
}
this.state = InnerState::Headers;
this.pending_item = Some(MultipartItem {
headers: vec![],
data: BytesMut::new(),
});
};
this.pending_item
.as_mut()
.unwrap()
.data
.extend(&this.buf[..idx])
}
InnerState::Headers => {
if this.pending_item.is_none() {
this.pending_item = Some(MultipartItem {
headers: vec![],
data: BytesMut::new(),
});
}
let header = match str::from_utf8(&this.buf[..idx]) {
Ok(h) => h,
Err(_) => {
this.state = InnerState::Eof;
return std::task::Poll::Ready(Some(Err(
MultipartError::InvalidItemHeader,
)));
}
};
if header.trim().is_empty() {
this.buf.advance(2 + idx);
this.state = InnerState::Boundary;
continue;
}
let header_parts: Vec<&str> = header.split(": ").collect();
if header_parts.len() != 2 {
this.state = InnerState::Eof;
return std::task::Poll::Ready(Some(Err(
MultipartError::InvalidItemHeader,
)));
}
this.pending_item
.as_mut()
.unwrap()
.headers
.push((header_parts[0].to_string(), header_parts[1].to_string()));
}
InnerState::Eof => {
return std::task::Poll::Ready(None);
}
}
this.buf.advance(2 + idx);
}
match Pin::new(&mut this.stream).poll_next(cx) {
Poll::Ready(Some(Ok(data))) => {
this.buf.extend_from_slice(&data);
}
Poll::Ready(None) => {
this.state = InnerState::Eof;
return std::task::Poll::Ready(None);
}
Poll::Ready(Some(Err(e))) => {
this.state = InnerState::Eof;
return std::task::Poll::Ready(Some(Err(e)));
}
Poll::Pending => {
return std::task::Poll::Pending;
}
};
}
}
}
#[cfg(test)]
mod tests {
    // Tests for `MultipartReader`, driven by in-memory bodies.
    use super::*;

    /// Parses a three-part `multipart/form-data` body, checks that both constructors
    /// accept it, and checks that the stream yields exactly three items without errors.
    #[futures_test::test]
    async fn valid_request() {
        // NOTE: the `boundary` parameter below already includes a leading `--`, and the
        // delimiter lines in `data` repeat it verbatim. Under RFC 2046 a delimiter line
        // would be `--` followed by the boundary. This test relies on the reader
        // accepting the verbatim form.
        let headermap = vec![(
            "Content-Type".to_string(),
            "multipart/form-data; boundary=--974767299852498929531610575".to_string(),
        )];
        // Each line ends in CRLF: every `\r` escape is followed by the literal newline
        // of the source line. The continuation lines must stay at column 0, because any
        // indentation would become part of the body. Parts 2 and 3 are each followed by
        // an empty line before the next delimiter.
        let data = b"--974767299852498929531610575\r
Content-Disposition: form-data; name=\"text\"\r
\r
text default\r
--974767299852498929531610575\r
Content-Disposition: form-data; name=\"file1\"; filename=\"a.txt\"\r
Content-Type: text/plain\r
\r
Content of a.txt.\r
\r\n--974767299852498929531610575\r
Content-Disposition: form-data; name=\"file2\"; filename=\"a.html\"\r
Content-Type: text/html\r
\r
<!DOCTYPE html><title>Content of a.html.</title>\r
\r
--974767299852498929531610575--\r\n";
        // Both construction paths accept the same body.
        assert!(MultipartReader::from_data_with_headers(data, &headermap).is_ok());
        assert!(MultipartReader::from_data_with_boundary_and_type(
            data,
            "--974767299852498929531610575",
            MultipartType::FormData
        )
        .is_ok());
        let mut reader = MultipartReader::from_data_with_headers(data, &headermap).unwrap();
        // The subtype is taken from `multipart/form-data`.
        assert_eq!(reader.multipart_type, MultipartType::FormData);
        let mut items = vec![];
        // Drain the stream, failing the test on the first parse error.
        loop {
            match reader.next().await {
                Some(Ok(item)) => items.push(item),
                None => break,
                Some(Err(e)) => panic!("Error: {:?}", e),
            }
        }
        assert_eq!(items.len(), 3);
    }
}