// mpart_async/server.rs

1use bytes::{Bytes, BytesMut};
2use futures_core::Stream;
3use http::header::{HeaderMap, HeaderName, HeaderValue};
4use httparse::Status;
5use log::debug;
6use pin_project_lite::pin_project;
7use std::borrow::Cow;
8use std::error::Error as StdError;
9use std::mem;
10use std::pin::Pin;
11use std::sync::{Arc, Mutex};
12use std::task::{Context, Poll};
13use thiserror::Error;
14
15use memchr::{memchr, memmem::find, memrchr};
16use percent_encoding::percent_decode_str;
17
/// Boxed, thread-safe error type that errors from the underlying byte stream are converted into
/// (see the `E: Into<AnyStdError>` bounds and `MultipartError::Stream`).
type AnyStdError = Box<dyn StdError + Send + Sync + 'static>;
19
/// A single field of a [`MultipartStream`], which is itself a stream of bytes.
///
/// This represents either an uploaded file, or a simple text value.
///
/// Each field will have some headers and then a body.
/// There are no assumptions made when parsing about what headers are present, but some of the helper methods here (such as `content_type()` & `filename()`) will return an error if they aren't present.
/// The body will be returned as an inner stream of bytes from the request, but only up to the end of the field.
///
/// Fields are not concurrent against their parent multipart request. This is because multipart submissions are a single http request and we don't support going backwards or skipping bytes. In other words you can't read from multiple fields of the same request at the same time: you must wait for one field to finish being read before moving on.
pub struct MultipartField<S, E>
where
    S: Stream<Item = Result<Bytes, E>> + Unpin,
    E: Into<AnyStdError>,
{
    // Headers that were parsed for this field
    headers: HeaderMap<HeaderValue>,
    // Parser state shared with the `MultipartStream` that produced this field
    state: Arc<Mutex<MultipartState<S, E>>>,
}
38
39impl<S, E> MultipartField<S, E>
40where
41    S: Stream<Item = Result<Bytes, E>> + Unpin,
42    E: Into<AnyStdError>,
43{
44    /// Return the headers for the field
45    ///
46    /// You can use `self.headers.get("my-header").and_then(|val| val.to_str().ok())` to get out the header if present
47    pub fn headers(&self) -> &HeaderMap<HeaderValue> {
48        &self.headers
49    }
50
51    /// Return the content type of the field (if present or error)
52    pub fn content_type(&self) -> Result<&str, MultipartError> {
53        if let Some(val) = self.headers.get("content-type") {
54            return val.to_str().map_err(|_| MultipartError::InvalidHeader);
55        }
56
57        Err(MultipartError::InvalidHeader)
58    }
59
60    /// Return the filename of the field (if present or error).
61    /// The returned filename will be utf8 percent-decoded
62    pub fn filename(&self) -> Result<Cow<str>, MultipartError> {
63        if let Some(val) = self.headers.get("content-disposition") {
64            let string_val =
65                std::str::from_utf8(val.as_bytes()).map_err(|_| MultipartError::InvalidHeader)?;
66            if let Some(filename) = get_dispo_param(string_val, "filename*") {
67                let stripped = strip_utf8_prefix(filename);
68                return Ok(stripped);
69            }
70            if let Some(filename) = get_dispo_param(string_val, "filename") {
71                return Ok(filename);
72            }
73        }
74
75        Err(MultipartError::InvalidHeader)
76    }
77
78    /// Return the name of the field (if present or error).
79    /// The returned name will be utf8 percent-decoded
80    pub fn name(&self) -> Result<Cow<str>, MultipartError> {
81        if let Some(val) = self.headers.get("content-disposition") {
82            let string_val =
83                std::str::from_utf8(val.as_bytes()).map_err(|_| MultipartError::InvalidHeader)?;
84            if let Some(filename) = get_dispo_param(string_val, "name") {
85                return Ok(filename);
86            }
87        }
88
89        Err(MultipartError::InvalidHeader)
90    }
91}
92
/// Removes a leading `UTF-8''` charset marker (as used by `filename*` values) from `string`.
///
/// The marker is matched ASCII case-insensitively. Borrowed input stays borrowed, and owned
/// input is trimmed in place, so no new allocation is made. Input without the marker is
/// returned unchanged.
fn strip_utf8_prefix(string: Cow<str>) -> Cow<str> {
    const PREFIX: &str = "utf-8''";

    // `get` yields `None` when the input is too short or the cut would split a character,
    // so this can never panic on arbitrary header text.
    let has_prefix = string
        .get(..PREFIX.len())
        .map_or(false, |head| head.eq_ignore_ascii_case(PREFIX));

    if !has_prefix {
        return string;
    }

    match string {
        Cow::Borrowed(borrowed) => Cow::Borrowed(&borrowed[PREFIX.len()..]),
        Cow::Owned(mut owned) => {
            owned.replace_range(..PREFIX.len(), "");
            Cow::Owned(owned)
        }
    }
}
101
102/// This function will get a disposition param from `content-disposition` header & try to escape it if there are escaped quotes or percent encoding
103fn get_dispo_param<'a>(input: &'a str, param: &str) -> Option<Cow<'a, str>> {
104    debug!("dispo param:{input}, field `{param}`");
105    if let Some(start_idx) = find(input.as_bytes(), param.as_bytes()) {
106        debug!("Start idx found:{start_idx}");
107        let end_param = start_idx + param.len();
108        //check bounds
109        if input.len() > end_param + 2 && &input[end_param..end_param + 2] == "=\"" {
110            let start = end_param + 2;
111
112            let mut snippet = &input[start..];
113
114            // If we encounter a `\"` in the string we need to escape it
115            // This means that we need to create a new escaped string as it will be discontiguous
116            let mut escaped_buffer: Option<String> = None;
117
118            while let Some(end) = memchr(b'"', snippet.as_bytes()) {
119                // if we encounter a backslash before the quote
120                if end > 0
121                    && snippet
122                        .get(end - 1..end)
123                        .map_or(false, |character| character == "\\")
124                {
125                    // We get an existing escaped buffer or create an empty string
126                    let mut buffer = escaped_buffer.unwrap_or_default();
127
128                    // push up until the escaped quote
129                    buffer.push_str(&snippet[..end - 1]);
130                    // push in the quote itself
131                    buffer.push('"');
132
133                    escaped_buffer = Some(buffer);
134
135                    // Move the buffer ahead
136                    snippet = &snippet[end + 1..];
137                    continue;
138                } else {
139                    // we're at the end
140
141                    // if we have something escaped
142                    match escaped_buffer {
143                        Some(mut escaped) => {
144                            // tack on the end of the string
145                            escaped.push_str(&snippet[0..end]);
146
147                            // Double escape with percent decode
148                            if escaped.contains('%') {
149                                let decoded_val = percent_decode_str(&escaped).decode_utf8_lossy();
150                                return Some(Cow::Owned(decoded_val.into_owned()));
151                            }
152
153                            return Some(Cow::Owned(escaped));
154                        }
155                        None => {
156                            let value = &snippet[0..end];
157
158                            // Escape with percent decode, if necessary
159                            if value.contains('%') {
160                                let decoded_val = percent_decode_str(value).decode_utf8_lossy();
161
162                                return Some(decoded_val);
163                            }
164
165                            return Some(Cow::Borrowed(value));
166                        }
167                    }
168                }
169            }
170        }
171    }
172
173    None
174}
175
176//Streams out bytes
177impl<S, E> Stream for MultipartField<S, E>
178where
179    S: Stream<Item = Result<Bytes, E>> + Unpin,
180    E: Into<AnyStdError>,
181{
182    type Item = Result<Bytes, MultipartError>;
183
184    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
185        let self_mut = &mut self.as_mut();
186
187        let state = &mut self_mut
188            .state
189            .try_lock()
190            .map_err(|_| MultipartError::InternalBorrowError)?;
191
192        match Pin::new(&mut state.parser).poll_next(cx) {
193            Poll::Pending => Poll::Pending,
194            Poll::Ready(Some(Err(err))) => {
195                Poll::Ready(Some(Err(MultipartError::Stream(err.into()))))
196            }
197            Poll::Ready(None) => Poll::Ready(None),
198            //If we have headers, we have reached the next file
199            Poll::Ready(Some(Ok(ParseOutput::Headers(headers)))) => {
200                state.next_item = Some(headers);
201                Poll::Ready(None)
202            }
203            Poll::Ready(Some(Ok(ParseOutput::Bytes(bytes)))) => Poll::Ready(Some(Ok(bytes))),
204        }
205    }
206}
207
// State used to drive the parser, shared between a `MultipartStream` and the
// `MultipartField`s it hands out.
struct MultipartState<S, E>
where
    S: Stream<Item = Result<Bytes, E>> + Unpin,
    E: Into<AnyStdError>,
{
    // The low-level parser that both the stream and its fields poll
    parser: MultipartParser<S, E>,
    // Headers of the next field, stored by a `MultipartField` that ran into them while
    // streaming its body; taken by `MultipartStream` on its next poll
    next_item: Option<HeaderMap<HeaderValue>>,
}
217
/// The main `MultipartStream` struct which will contain one or more fields (a stream of streams)
///
/// You can construct this given a boundary and a stream of bytes from a server request.
///
/// **Please Note**: If you are reading in a field, you must exhaust the field's bytes before moving onto the next field.
/// Polling the `MultipartStream` while the parser is still producing body bytes yields
/// `MultipartError::ShouldPollField`.
/// ```no_run
/// # use warp::Filter;
/// # use bytes::{Buf, BufMut, BytesMut};
/// # use futures_util::TryStreamExt;
/// # use futures_core::Stream;
/// # use mime::Mime;
/// # use mpart_async::server::MultipartStream;
/// # use std::convert::Infallible;
/// # #[tokio::main]
/// # async fn main() {
/// #     // Match any request and return hello world!
/// #     let routes = warp::any()
/// #         .and(warp::header::<Mime>("content-type"))
/// #         .and(warp::body::stream())
/// #         .and_then(mpart);
/// #     warp::serve(routes).run(([127, 0, 0, 1], 3030)).await;
/// # }
/// # async fn mpart(
/// #     mime: Mime,
/// #     body: impl Stream<Item = Result<impl Buf, warp::Error>> + Unpin,
/// # ) -> Result<impl warp::Reply, Infallible> {
/// #     let boundary = mime.get_param("boundary").map(|v| v.to_string()).unwrap();
/// let mut stream = MultipartStream::new(boundary, body.map_ok(|mut buf| {
///     let mut ret = BytesMut::with_capacity(buf.remaining());
///     ret.put(buf);
///     ret.freeze()
/// }));
///
/// while let Ok(Some(mut field)) = stream.try_next().await {
///     println!("Field received:{}", field.name().unwrap());
///     if let Ok(filename) = field.filename() {
///         println!("Field filename:{}", filename);
///     }
///
///     while let Ok(Some(bytes)) = field.try_next().await {
///         println!("Bytes received:{}", bytes.len());
///     }
/// }
/// #     Ok(format!("Thanks!\n"))
/// # }
/// ```
pub struct MultipartStream<S, E>
where
    S: Stream<Item = Result<Bytes, E>> + Unpin,
    E: Into<AnyStdError>,
{
    // Parser state; a clone of this handle is given to every `MultipartField` produced
    state: Arc<Mutex<MultipartState<S, E>>>,
}
272
273impl<S, E> MultipartStream<S, E>
274where
275    S: Stream<Item = Result<Bytes, E>> + Unpin,
276    E: Into<AnyStdError>,
277{
278    /// Construct a MultipartStream given a boundary
279    pub fn new<I: Into<Bytes>>(boundary: I, stream: S) -> Self {
280        Self {
281            state: Arc::new(Mutex::new(MultipartState {
282                parser: MultipartParser::new(boundary, stream),
283                next_item: None,
284            })),
285        }
286    }
287}
288
289impl<S, E> Stream for MultipartStream<S, E>
290where
291    S: Stream<Item = Result<Bytes, E>> + Unpin,
292    E: Into<AnyStdError>,
293{
294    type Item = Result<MultipartField<S, E>, MultipartError>;
295
296    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
297        let self_mut = &mut self.as_mut();
298
299        let state = &mut self_mut
300            .state
301            .try_lock()
302            .map_err(|_| MultipartError::InternalBorrowError)?;
303
304        if let Some(headers) = state.next_item.take() {
305            return Poll::Ready(Some(Ok(MultipartField {
306                headers,
307                state: self_mut.state.clone(),
308            })));
309        }
310
311        match Pin::new(&mut state.parser).poll_next(cx) {
312            Poll::Pending => Poll::Pending,
313            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
314            Poll::Ready(None) => Poll::Ready(None),
315
316            //If we have headers, we have reached the next file
317            Poll::Ready(Some(Ok(ParseOutput::Headers(headers)))) => {
318                Poll::Ready(Some(Ok(MultipartField {
319                    headers,
320                    state: self_mut.state.clone(),
321                })))
322            }
323            Poll::Ready(Some(Ok(ParseOutput::Bytes(_bytes)))) => {
324                //If we are returning bytes from this stream, then there is some error
325                Poll::Ready(Some(Err(MultipartError::ShouldPollField)))
326            }
327        }
328    }
329}
330
#[derive(Error, Debug)]
/// The error type returned by the multipart streams, fields and parser in this module
pub enum MultipartError {
    /// Given if the request does not start with the expected `--<boundary>\r\n` delimiter
    #[error("Invalid Boundary. (expected {expected:?}, found {found:?})")]
    InvalidBoundary {
        /// The expected boundary
        expected: String,
        /// The bytes that were found instead (lossily decoded)
        found: String,
    },
    /// Given if the headers are incomplete when they are parsed
    #[error("Incomplete Headers")]
    IncompleteHeader,
    /// Given if a header or header parameter (such as name or filename) is not present or is malformed
    #[error("Invalid Header Value")]
    InvalidHeader,
    /// Given if the stream is polled while a field still has body bytes to be read
    #[error(
        "Tried to poll an MultipartStream when the MultipartField should be polled, try using `flatten()`"
    )]
    ShouldPollField,
    /// Given if the shared parser state could not be locked because the Mutex is in use somewhere else
    #[error("Tried to poll an MultipartField and the Mutex has already been locked")]
    InternalBorrowError,
    /// Given if there is an error when parsing headers
    #[error(transparent)]
    HeaderParse(#[from] httparse::Error),
    /// Given if there is an error in the underlying stream
    #[error(transparent)]
    Stream(#[from] AnyStdError),
    /// Given if the stream ends while reading headers
    #[error("EOF while reading headers")]
    EOFWhileReadingHeaders,
    /// Given if the stream ends while reading the boundary
    #[error("EOF while reading boundary")]
    EOFWhileReadingBoundary,
    /// Given if the stream ends while reading the body and there is no end boundary
    #[error("EOF while reading body")]
    EOFWhileReadingBody,
    /// Given if a boundary is followed by two bytes that are neither `\r\n` nor `--`
    #[error("Garbage following boundary: {0:02x?}")]
    GarbageAfterBoundary([u8; 2]),
}
375
pin_project! {
    /// A low-level parser which [`MultipartStream`] uses.
    ///
    /// Returns either the headers of a field or a chunk of body bytes (see [`ParseOutput`]).
    ///
    /// Unless there is an issue with using [`MultipartStream`] you don't normally want to use this struct
    #[project = ParserProj]
    pub struct MultipartParser<S, E>
    where
        S: Stream<Item = Result<Bytes, E>>,
        E: Into<AnyStdError>,
    {
        // The boundary as given to `new`; once the opening delimiter has been read it is
        // replaced by `\r\n--<boundary>`, the form that separates fields
        boundary: Bytes,
        // Bytes received from `stream` that have not been handed out yet
        buffer: BytesMut,
        // Which part of the request is currently being read
        state: State,
        // The underlying stream of request bytes
        #[pin]
        stream: S,
    }
}
395
396impl<S, E> MultipartParser<S, E>
397where
398    S: Stream<Item = Result<Bytes, E>>,
399    E: Into<AnyStdError>,
400{
401    /// Construct a raw parser from a boundary/stream.
402    pub fn new<I: Into<Bytes>>(boundary: I, stream: S) -> Self {
403        Self {
404            boundary: boundary.into(),
405            buffer: BytesMut::new(),
406            state: State::ReadingBoundary,
407            stream,
408        }
409    }
410}
411
412const NUM_HEADERS: usize = 16;
413
414fn get_headers(buffer: &[u8]) -> Result<HeaderMap<HeaderValue>, MultipartError> {
415    let mut headers = [httparse::EMPTY_HEADER; NUM_HEADERS];
416
417    let idx = match httparse::parse_headers(buffer, &mut headers)? {
418        Status::Complete((idx, _val)) => idx,
419        Status::Partial => return Err(MultipartError::IncompleteHeader),
420    };
421
422    let mut header_map = HeaderMap::with_capacity(idx);
423
424    for header in headers.iter().take(idx) {
425        if !header.name.is_empty() {
426            header_map.insert(
427                HeaderName::from_bytes(header.name.as_bytes())
428                    .map_err(|_| MultipartError::InvalidHeader)?,
429                HeaderValue::from_bytes(header.value).map_err(|_| MultipartError::InvalidHeader)?,
430            );
431        }
432    }
433
434    Ok(header_map)
435}
436
437impl<S, E> Stream for MultipartParser<S, E>
438where
439    S: Stream<Item = Result<Bytes, E>>,
440    E: Into<AnyStdError>,
441{
442    type Item = Result<ParseOutput, MultipartError>;
443
444    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
445        let ParserProj {
446            boundary,
447            buffer,
448            state,
449            mut stream,
450        } = self.project();
451
452        loop {
453            match state {
454                State::ReadingBoundary => {
455                    let boundary_len = boundary.len();
456
457                    //If the buffer is too small
458                    if buffer.len() < boundary_len + 4 {
459                        match futures_core::ready!(stream.as_mut().poll_next(cx)) {
460                            Some(Ok(bytes)) => {
461                                buffer.extend_from_slice(&bytes);
462                                continue;
463                            }
464                            Some(Err(e)) => {
465                                return Poll::Ready(Some(Err(MultipartError::Stream(e.into()))))
466                            }
467                            None => {
468                                return Poll::Ready(Some(Err(
469                                    MultipartError::EOFWhileReadingBoundary,
470                                )))
471                            }
472                        }
473                    }
474
475                    //If the buffer starts with `--<boundary>\r\n`
476                    if &buffer[..2] == b"--"
477                        && buffer[2..boundary_len + 2] == *boundary
478                        && &buffer[boundary_len + 2..boundary_len + 4] == b"\r\n"
479                    {
480                        //remove the boundary from the buffer, returning the tail
481                        *buffer = buffer.split_off(boundary_len + 4);
482                        *state = State::ReadingHeader;
483
484                        //Update the boundary here to include the \r\n-- preamble for individual fields
485                        let mut new_boundary = BytesMut::with_capacity(boundary_len + 4);
486
487                        new_boundary.extend_from_slice(b"\r\n--");
488                        new_boundary.extend_from_slice(boundary);
489
490                        *boundary = new_boundary.freeze();
491
492                        cx.waker().wake_by_ref();
493                        return Poll::Pending;
494                    } else {
495                        let expected = format!("--{}\\r\\n", String::from_utf8_lossy(boundary));
496                        let found =
497                            String::from_utf8_lossy(&buffer[..boundary_len + 4]).to_string();
498
499                        let error = MultipartError::InvalidBoundary { expected, found };
500
501                        //There is some error with the boundary...
502                        return Poll::Ready(Some(Err(error)));
503                    }
504                }
505                State::ReadingHeader => {
506                    if let Some(end) = find(buffer, b"\r\n\r\n") {
507                        //Need to include the end header bytes for the parse to work
508                        let end = end + 4;
509
510                        let header_map = match get_headers(&buffer[0..end]) {
511                            Ok(headers) => headers,
512                            Err(error) => {
513                                *state = State::Finished;
514                                return Poll::Ready(Some(Err(error)));
515                            }
516                        };
517
518                        *buffer = buffer.split_off(end);
519                        *state = State::StreamingContent(buffer.is_empty());
520
521                        cx.waker().wake_by_ref();
522
523                        return Poll::Ready(Some(Ok(ParseOutput::Headers(header_map))));
524                    } else {
525                        match futures_core::ready!(stream.as_mut().poll_next(cx)) {
526                            Some(Ok(bytes)) => {
527                                buffer.extend_from_slice(&bytes);
528                                continue;
529                            }
530                            Some(Err(e)) => {
531                                return Poll::Ready(Some(Err(MultipartError::Stream(e.into()))))
532                            }
533                            None => {
534                                return Poll::Ready(Some(Err(
535                                    MultipartError::EOFWhileReadingHeaders,
536                                )))
537                            }
538                        }
539                    }
540                }
541
542                State::StreamingContent(exhausted) => {
543                    let boundary_len = boundary.len();
544
545                    if buffer.is_empty() || *exhausted {
546                        *state = State::StreamingContent(false);
547                        match futures_core::ready!(stream.as_mut().poll_next(cx)) {
548                            Some(Ok(bytes)) => {
549                                buffer.extend_from_slice(&bytes);
550                                continue;
551                            }
552                            Some(Err(e)) => {
553                                return Poll::Ready(Some(Err(MultipartError::Stream(e.into()))))
554                            }
555                            None => {
556                                return Poll::Ready(Some(Err(MultipartError::EOFWhileReadingBody)))
557                            }
558                        }
559                    }
560
561                    //We want to check the value of the buffer to see if there looks like there is an `end` boundary.
562                    if let Some(idx) = find(buffer, boundary) {
563                        //Check the length has enough bytes for us
564                        if buffer.len() < idx + 2 + boundary_len {
565                            // FIXME: cannot stop the read successfully here!
566                            *state = State::StreamingContent(true);
567                            continue;
568                        }
569
570                        //The start of the boundary is 4 chars. i.e, after `\r\n--`
571                        let end_boundary = idx + boundary_len;
572
573                        //We want the chars after the boundary basically
574                        let after_boundary = &buffer[end_boundary..end_boundary + 2];
575
576                        if after_boundary == b"\r\n" {
577                            let mut other_bytes = (*buffer).split_off(idx);
578
579                            //Remove the boundary-related bytes from the start of the buffer
580                            other_bytes = other_bytes.split_off(2 + boundary_len);
581
582                            //Return the bytes up to the boundary, we're finished and need to go back to reading headers
583                            let return_bytes = Bytes::from(mem::replace(buffer, other_bytes));
584
585                            //Replace the buffer with the extra bytes
586                            *state = State::ReadingHeader;
587                            cx.waker().wake_by_ref();
588
589                            return Poll::Ready(Some(Ok(ParseOutput::Bytes(return_bytes))));
590                        } else if after_boundary == b"--" {
591                            //We're at the end, just truncate the bytes
592                            buffer.truncate(idx);
593                            *state = State::Finished;
594
595                            return Poll::Ready(Some(Ok(ParseOutput::Bytes(Bytes::from(
596                                mem::take(buffer),
597                            )))));
598                        } else {
599                            return Poll::Ready(Some(Err(MultipartError::GarbageAfterBoundary([
600                                after_boundary[0],
601                                after_boundary[1],
602                            ]))));
603                        }
604                    } else {
605                        //We need to check for partial matches by checking the last boundary_len bytes
606                        let buffer_len = buffer.len();
607
608                        //Clamp to zero if the boundary length is bigger than the buffer
609                        let start_idx =
610                            (buffer_len as i64 - (boundary_len as i64 - 1)).max(0) as usize;
611
612                        let end_of_buffer = &buffer[start_idx..];
613
614                        if let Some(idx) = memrchr(b'\r', end_of_buffer) {
615                            //If to the end of the match equals the same amount of bytes
616                            if end_of_buffer[idx..] == boundary[..(end_of_buffer.len() - idx)] {
617                                *state = State::StreamingContent(true);
618
619                                //we return all the bytes except for the start of our boundary
620                                let mut output = buffer.split_off(idx + start_idx);
621                                mem::swap(&mut output, buffer);
622
623                                cx.waker().wake_by_ref();
624                                return Poll::Ready(Some(Ok(ParseOutput::Bytes(output.freeze()))));
625                            }
626                        }
627
628                        let output = mem::take(buffer);
629                        return Poll::Ready(Some(Ok(ParseOutput::Bytes(output.freeze()))));
630                    }
631                }
632                State::Finished => return Poll::Ready(None),
633            }
634        }
635    }
636}
637
/// The part of the request that a `MultipartParser` is currently reading
#[derive(Debug, PartialEq)]
enum State {
    /// Initial state: waiting for the opening `--<boundary>\r\n`
    ReadingBoundary,
    /// Buffering a field's headers until the terminating `\r\n\r\n` is seen
    ReadingHeader,
    /// Streaming a field's body. When the flag is `true`, more input is read from the
    /// underlying stream before the buffer is examined again
    StreamingContent(bool),
    /// Nothing more will be produced: polling yields `None`
    Finished,
}
645
#[derive(Debug)]
/// The output from a `MultipartParser`
pub enum ParseOutput {
    /// The parsed headers of a field, produced once the field's header section is complete
    Headers(HeaderMap<HeaderValue>),
    /// A chunk of a field's body
    Bytes(Bytes),
}
654
655#[cfg(test)]
656mod tests {
657    use super::*;
658    use crate::client::ByteStream;
659    use futures_util::StreamExt;
660
661    #[tokio::test]
662    async fn read_stream() {
663        let input: &[u8] = b"--AaB03x\r\n\
664                Content-Disposition: form-data; name=\"file\"; filename=\"text.txt\"\r\n\
665                Content-Type: text/plain\r\n\
666                \r\n\
667                Lorem Ipsum\n\r\n\
668                --AaB03x\r\n\
669                Content-Disposition: form-data; name=\"name1\"\r\n\
670                \r\n\
671                value1\r\n\
672                --AaB03x\r\n\
673                Content-Disposition: form-data; name=\"name2\"\r\n\
674                \r\n\
675                value2\r\n\
676                --AaB03x--\r\n";
677
678        let mut stream = MultipartStream::new("AaB03x", ByteStream::new(input));
679
680        if let Some(Ok(mut mpart_field)) = stream.next().await {
681            assert_eq!(mpart_field.name().ok(), Some(Cow::Borrowed("file")));
682            assert_eq!(mpart_field.filename().ok(), Some(Cow::Borrowed("text.txt")));
683
684            if let Some(Ok(bytes)) = mpart_field.next().await {
685                assert_eq!(bytes, Bytes::from(b"Lorem Ipsum\n" as &[u8]));
686            }
687        } else {
688            panic!("First value should be a field")
689        }
690    }
691
    // When both `filename` and `filename*` are present, `filename()` returns the
    // percent-decoded `filename*` value.
    #[tokio::test]
    async fn read_utf_8_filename() {
        let input: &[u8] = b"--AaB03x\r\n\
                Content-Disposition: form-data; name=\"file\"; filename=\"text.txt\"; filename*=\"aous%20.txt\"\r\n\
                Content-Type: text/plain\r\n\
                \r\n\
                Lorem Ipsum\n\r\n\
                --AaB03x--\r\n";

        let mut stream = MultipartStream::new("AaB03x", ByteStream::new(input));

        let field = stream.next().await.unwrap().unwrap();
        assert_eq!(field.filename().ok(), Some(Cow::Borrowed("aous .txt")));
    }
706
    // Exercises `get_dispo_param` directly: plain, percent-encoded, escaped-quote and empty
    // values. Note that `empty` and `percent_encoded` are not separated by a `;` in the input.
    #[test]
    fn read_filename() {
        let input = "form-data; name=\"file\";\
                           filename=\"text%20.txt\";\
                           quoted=\"with a \\\" quote and another \\\" quote\";\
                           empty=\"\"\
                           percent_encoded=\"foo%20%3Cbar%3E\"\
                           ";
        let name = get_dispo_param(input, "name");
        let filename = get_dispo_param(input, "filename");
        let with_a_quote = get_dispo_param(input, "quoted");
        let empty = get_dispo_param(input, "empty");
        let percent_encoded = get_dispo_param(input, "percent_encoded");

        assert_eq!(name, Some(Cow::Borrowed("file")));
        assert_eq!(filename, Some(Cow::Borrowed("text .txt")));
        assert_eq!(
            with_a_quote,
            Some(Cow::Owned("with a \" quote and another \" quote".into()))
        );
        assert_eq!(empty, Some(Cow::Borrowed("")));
        assert_eq!(percent_encoded, Some(Cow::Borrowed("foo <bar>")));
    }
730
    // Same as `read_filename`, but with non-ASCII (multi-byte) characters in the `name` and
    // `filename*` values.
    #[test]
    fn read_filename_umlaut() {
        let input = "form-data; name=\"äöüß\";\
                           filename*=\"äöü ß%20.txt\";\
                           quoted=\"with a \\\" quote and another \\\" quote\";\
                           empty=\"\"\
                           percent_encoded=\"foo%20%3Cbar%3E\"\
                           ";
        let name = get_dispo_param(input, "name");
        let filename = get_dispo_param(input, "filename*");
        let with_a_quote = get_dispo_param(input, "quoted");
        let empty = get_dispo_param(input, "empty");
        let percent_encoded = get_dispo_param(input, "percent_encoded");

        assert_eq!(name, Some(Cow::Borrowed("äöüß")));
        assert_eq!(filename, Some(Cow::Borrowed("äöü ß .txt")));
        assert_eq!(
            with_a_quote,
            Some(Cow::Owned("with a \" quote and another \" quote".into()))
        );
        assert_eq!(empty, Some(Cow::Borrowed("")));
        assert_eq!(percent_encoded, Some(Cow::Borrowed("foo <bar>")));
    }
754
    // Drives the low-level `MultipartParser` over three fields and checks that it alternates
    // between headers and body bytes, then ends with `None` after the closing boundary.
    #[tokio::test]
    async fn reads_streams_and_fields() {
        let input: &[u8] = b"--AaB03x\r\n\
                Content-Disposition: form-data; name=\"file\"; filename=\"text.txt\"\r\n\
                Content-Type: text/plain\r\n\
                \r\n\
                Lorem Ipsum\n\r\n\
                --AaB03x\r\n\
                Content-Disposition: form-data; name=\"name1\"\r\n\
                \r\n\
                value1\r\n\
                --AaB03x\r\n\
                Content-Disposition: form-data; name=\"name2\"\r\n\
                \r\n\
                value2\r\n\
                --AaB03x--\r\n";

        let mut read = MultipartParser::new("AaB03x", ByteStream::new(input));

        if let Some(Ok(ParseOutput::Headers(val))) = read.next().await {
            println!("Headers:{:?}", val);
        } else {
            panic!("First value should be a header")
        }

        if let Some(Ok(ParseOutput::Bytes(bytes))) = read.next().await {
            assert_eq!(&*bytes, b"Lorem Ipsum\n");
        } else {
            panic!("Second value should be bytes")
        }

        if let Some(Ok(ParseOutput::Headers(val))) = read.next().await {
            println!("Headers:{:?}", val);
        } else {
            panic!("Third value should be a header")
        }

        if let Some(Ok(ParseOutput::Bytes(bytes))) = read.next().await {
            assert_eq!(&*bytes, b"value1");
        } else {
            panic!("Fourth value should be bytes")
        }

        if let Some(Ok(ParseOutput::Headers(val))) = read.next().await {
            println!("Headers:{:?}", val);
        } else {
            panic!("Fifth value should be a header")
        }

        if let Some(Ok(ParseOutput::Bytes(bytes))) = read.next().await {
            assert_eq!(&*bytes, b"value2");
        } else {
            panic!("Sixth value should be bytes")
        }

        // The closing `--AaB03x--` finishes the parser
        assert!(read.next().await.is_none());
    }
812
813    #[tokio::test]
814    async fn unfinished_header() {
815        let input: &[u8] = b"--AaB03x\r\n\
816                Content-Disposition: form-data; name=\"file\"; filename=\"text.txt\"\r\n\
817                Content-Type: text/plain";
818        let mut read = MultipartParser::new("AaB03x", ByteStream::new(input));
819
820        let ret = read.next().await;
821
822        assert!(matches!(
823            ret,
824            Some(Err(MultipartError::EOFWhileReadingHeaders))
825        ),);
826    }
827
828    #[tokio::test]
829    async fn unfinished_second_header() {
830        let input: &[u8] = b"--AaB03x\r\n\
831                Content-Disposition: form-data; name=\"file\"; filename=\"text.txt\"\r\n\
832                Content-Type: text/plain\r\n\
833                \r\n\
834                Lorem Ipsum\n\r\n\
835                --AaB03x\r\n\
836                Content-Disposition: form-data; name=\"name1\"";
837
838        let mut read = MultipartParser::new("AaB03x", ByteStream::new(input));
839
840        if let Some(Ok(ParseOutput::Headers(val))) = read.next().await {
841            println!("Headers:{:?}", val);
842        } else {
843            panic!("First value should be a header")
844        }
845
846        if let Some(Ok(ParseOutput::Bytes(bytes))) = read.next().await {
847            assert_eq!(&*bytes, b"Lorem Ipsum\n");
848        } else {
849            panic!("Second value should be bytes")
850        }
851
852        let ret = read.next().await;
853
854        assert!(matches!(
855            ret,
856            Some(Err(MultipartError::EOFWhileReadingHeaders))
857        ),);
858    }
859
860    #[tokio::test]
861    async fn invalid_header() {
862        let input: &[u8] = b"--AaB03x\r\n\
863                I am a bad header\r\n\
864                \r\n";
865
866        let mut read = MultipartParser::new("AaB03x", ByteStream::new(input));
867
868        let val = read.next().await.unwrap();
869
870        match val {
871            Err(MultipartError::HeaderParse(err)) => {
872                //all good
873                println!("{}", err);
874            }
875            val => {
876                panic!("Expecting Parse Error, Instead got:{:?}", val);
877            }
878        }
879    }
880
881    #[tokio::test]
882    async fn invalid_boundary() {
883        let input: &[u8] = b"--InvalidBoundary\r\n\
884                Content-Disposition: form-data; name=\"file\"; filename=\"text.txt\"\r\n\
885                Content-Type: text/plain\r\n\
886                \r\n\
887                Lorem Ipsum\n\r\n\
888                --InvalidBoundary--\r\n";
889
890        let mut read = MultipartParser::new("AaB03x", ByteStream::new(input));
891
892        let val = read.next().await.unwrap();
893
894        match val {
895            Err(MultipartError::InvalidBoundary { expected, found }) => {
896                assert_eq!(expected, "--AaB03x\\r\\n");
897                assert_eq!(found, "--InvalidB");
898            }
899            val => {
900                panic!("Expecting Invalid Boundary Error, Instead got:{:?}", val);
901            }
902        }
903    }
904
905    #[tokio::test]
906    async fn zero_read() {
907        use bytes::{BufMut, BytesMut};
908
909        let input = b"----------------------------332056022174478975396798\r\n\
910                Content-Disposition: form-data; name=\"file\"\r\n\
911                Content-Type: application/octet-stream\r\n\
912                \r\n\
913                \r\n\
914                \r\n\
915                dolphin\n\
916                whale\r\n\
917                ----------------------------332056022174478975396798--\r\n";
918
919        let boundary = "--------------------------332056022174478975396798";
920
921        let mut read = MultipartStream::new(boundary, ByteStream::new(input));
922
923        let mut part = match read.next().await.unwrap() {
924            Ok(mf) => {
925                assert_eq!(mf.name().unwrap(), "file");
926                assert_eq!(mf.content_type().unwrap(), "application/octet-stream");
927                mf
928            }
929            Err(e) => panic!("unexpected: {}", e),
930        };
931
932        let mut buffer = BytesMut::new();
933
934        loop {
935            match part.next().await {
936                Some(Ok(bytes)) => buffer.put(bytes),
937                Some(Err(e)) => panic!("unexpected {}", e),
938                None => break,
939            }
940        }
941
942        let nth = read.next().await;
943        assert!(nth.is_none());
944
945        assert_eq!(buffer.as_ref(), b"\r\n\r\ndolphin\nwhale");
946    }
947
948    #[tokio::test]
949    async fn r_read() {
950        use std::convert::Infallible;
951
952        //Used to ensure partial matches are working!
953
954        #[derive(Clone)]
955        pub struct SplitStream {
956            packets: Vec<Bytes>,
957        }
958
959        impl SplitStream {
960            pub fn new() -> Self {
961                SplitStream { packets: vec![] }
962            }
963
964            pub fn add_packet<P: Into<Bytes>>(&mut self, bytes: P) {
965                self.packets.push(bytes.into());
966            }
967        }
968
969        impl Stream for SplitStream {
970            type Item = Result<Bytes, Infallible>;
971
972            fn poll_next(
973                mut self: Pin<&mut Self>,
974                _cx: &mut Context<'_>,
975            ) -> Poll<Option<Self::Item>> {
976                if self.as_mut().packets.is_empty() {
977                    return Poll::Ready(None);
978                }
979
980                Poll::Ready(Some(Ok(self.as_mut().packets.remove(0))))
981            }
982        }
983
984        use bytes::{BufMut, BytesMut};
985
986        //This is a packet split on the boundary to test partial matching
987        let input1: &[u8] = b"----------------------------332056022174478975396798\r\n\
988                Content-Disposition: form-data; name=\"file\"\r\n\
989                Content-Type: application/octet-stream\r\n\
990                \r\n\
991                \r\r\r\r\r\r\r\r\r\r\r\r\r\
992                \r\n\
993                ----------------------------332";
994
995        //This is the rest of the packet
996        let input2: &[u8] = b"056022174478975396798--\r\n";
997
998        let boundary = "--------------------------332056022174478975396798";
999
1000        let mut split_stream = SplitStream::new();
1001
1002        split_stream.add_packet(&*input1);
1003        split_stream.add_packet(&*input2);
1004
1005        let mut read = MultipartStream::new(boundary, split_stream);
1006
1007        let mut part = match read.next().await.unwrap() {
1008            Ok(mf) => {
1009                assert_eq!(mf.name().unwrap(), "file");
1010                assert_eq!(mf.content_type().unwrap(), "application/octet-stream");
1011                mf
1012            }
1013            Err(e) => panic!("unexpected: {}", e),
1014        };
1015
1016        let mut buffer = BytesMut::new();
1017
1018        loop {
1019            match part.next().await {
1020                Some(Ok(bytes)) => buffer.put(bytes),
1021                Some(Err(e)) => panic!("unexpected {}", e),
1022                None => break,
1023            }
1024        }
1025
1026        let nth = read.next().await;
1027        assert!(nth.is_none());
1028
1029        assert_eq!(buffer.as_ref(), b"\r\r\r\r\r\r\r\r\r\r\r\r\r");
1030    }
1031
1032    #[test]
1033    fn test_strip_no_strip_necessary() {
1034        let name: Cow<str> = Cow::Owned("äöüß.txt".to_owned());
1035
1036        let res = strip_utf8_prefix(name.clone());
1037
1038        assert_eq!(res, name);
1039    }
1040
1041    #[test]
1042    fn test_strip_uppercase_utf8() {
1043        let name: Cow<str> = Cow::Owned("UTF-8''äöüß.txt".to_owned());
1044
1045        let res = strip_utf8_prefix(name);
1046
1047        assert_eq!(res, "äöüß.txt");
1048    }
1049
1050    #[test]
1051    fn test_strip_lowercase_utf8() {
1052        let name: Cow<str> = Cow::Owned("utf-8''äöüß.txt".to_owned());
1053
1054        let res = strip_utf8_prefix(name);
1055
1056        assert_eq!(res, "äöüß.txt");
1057    }
1058}