clia_ntex_multipart/
server.rs

1//! Multipart payload support
2use std::cell::{Cell, RefCell, RefMut};
3use std::task::{Context, Poll};
4use std::{cmp, convert::TryFrom, fmt, marker::PhantomData, pin::Pin, rc::Rc};
5
6use futures::stream::{LocalBoxStream, Stream, StreamExt};
7use ntex::http::error::{ParseError, PayloadError};
8use ntex::http::header::{self, HeaderMap, HeaderName, HeaderValue};
9use ntex::task::LocalWaker;
10use ntex::util::{Bytes, BytesMut};
11
12use crate::error::MultipartError;
13
/// Capacity of the `httparse` header array used for each part, i.e. the
/// largest number of headers a single part may carry.
const MAX_HEADERS: usize = 32;
15
16/// The server-side implementation of `multipart/form-data` requests.
17///
18/// This will parse the incoming stream into `MultipartItem` instances via its
19/// Stream implementation.
20/// `MultipartItem::Field` contains multipart field. `MultipartItem::Multipart`
21/// is used for nested multipart streams.
22pub struct Multipart {
23    safety: Safety,
24    error: Option<MultipartError>,
25    inner: Option<Rc<RefCell<InnerMultipart>>>,
26}
27
28enum InnerMultipartItem {
29    None,
30    Field(Rc<RefCell<InnerField>>),
31}
32
/// Where the multipart parser currently is within the body.
#[derive(PartialEq, Debug)]
enum InnerState {
    /// A close delimiter has been seen; no further parts follow.
    Eof,
    /// Discarding the preamble until the first delimiter line.
    FirstBoundary,
    /// Expecting the delimiter line that follows a part's content.
    Boundary,
    /// Expecting the header block of the next part.
    Headers,
}
44
45struct InnerMultipart {
46    payload: PayloadRef,
47    boundary: String,
48    state: InnerState,
49    item: InnerMultipartItem,
50}
51
52impl Multipart {
53    /// Create multipart instance for boundary.
54    pub fn new<S>(headers: &HeaderMap, stream: S) -> Multipart
55    where
56        S: Stream<Item = Result<Bytes, PayloadError>> + Unpin + 'static,
57    {
58        match Self::boundary(headers) {
59            Ok(boundary) => Multipart {
60                error: None,
61                safety: Safety::new(),
62                inner: Some(Rc::new(RefCell::new(InnerMultipart {
63                    boundary,
64                    payload: PayloadRef::new(PayloadBuffer::new(Box::new(stream))),
65                    state: InnerState::FirstBoundary,
66                    item: InnerMultipartItem::None,
67                }))),
68            },
69            Err(err) => Multipart { error: Some(err), safety: Safety::new(), inner: None },
70        }
71    }
72
73    /// Extract boundary info from headers.
74    fn boundary(headers: &HeaderMap) -> Result<String, MultipartError> {
75        if let Some(content_type) = headers.get(&header::CONTENT_TYPE) {
76            if let Ok(content_type) = content_type.to_str() {
77                if let Ok(ct) = content_type.parse::<mime::Mime>() {
78                    if let Some(boundary) = ct.get_param(mime::BOUNDARY) {
79                        Ok(boundary.as_str().to_owned())
80                    } else {
81                        Err(MultipartError::Boundary)
82                    }
83                } else {
84                    Err(MultipartError::ParseContentType)
85                }
86            } else {
87                Err(MultipartError::ParseContentType)
88            }
89        } else {
90            Err(MultipartError::NoContentType)
91        }
92    }
93}
94
95impl Stream for Multipart {
96    type Item = Result<Field, MultipartError>;
97
98    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
99        if let Some(err) = self.error.take() {
100            Poll::Ready(Some(Err(err)))
101        } else if self.safety.current() {
102            let this = self.get_mut();
103            let mut inner = this.inner.as_mut().unwrap().borrow_mut();
104            if let Some(mut payload) = inner.payload.get_mut(&this.safety) {
105                payload.poll_stream(cx)?;
106            }
107            inner.poll(&this.safety, cx)
108        } else if !self.safety.is_clean() {
109            Poll::Ready(Some(Err(MultipartError::NotConsumed)))
110        } else {
111            Poll::Pending
112        }
113    }
114}
115
116impl InnerMultipart {
117    fn read_headers(payload: &mut PayloadBuffer) -> Result<Option<HeaderMap>, MultipartError> {
118        match payload.read_until(b"\r\n\r\n")? {
119            None => {
120                if payload.eof {
121                    Err(MultipartError::Incomplete)
122                } else {
123                    Ok(None)
124                }
125            }
126            Some(bytes) => {
127                let mut hdrs = [httparse::EMPTY_HEADER; MAX_HEADERS];
128                match httparse::parse_headers(&bytes, &mut hdrs) {
129                    Ok(httparse::Status::Complete((_, hdrs))) => {
130                        // convert headers
131                        let mut headers = HeaderMap::with_capacity(hdrs.len());
132                        for h in hdrs {
133                            if let Ok(name) = HeaderName::try_from(h.name) {
134                                if let Ok(value) = HeaderValue::try_from(h.value) {
135                                    headers.append(name, value);
136                                } else {
137                                    return Err(ParseError::Header.into());
138                                }
139                            } else {
140                                return Err(ParseError::Header.into());
141                            }
142                        }
143                        Ok(Some(headers))
144                    }
145                    Ok(httparse::Status::Partial) => Err(ParseError::Header.into()),
146                    Err(err) => Err(ParseError::from(err).into()),
147                }
148            }
149        }
150    }
151
152    fn read_boundary(
153        payload: &mut PayloadBuffer,
154        boundary: &str,
155    ) -> Result<Option<bool>, MultipartError> {
156        // TODO: need to read epilogue
157        match payload.readline_or_eof()? {
158            None => {
159                if payload.eof {
160                    Ok(Some(true))
161                } else {
162                    Ok(None)
163                }
164            }
165            Some(chunk) => {
166                if chunk.len() < boundary.len() + 4
167                    || &chunk[..2] != b"--"
168                    || &chunk[2..boundary.len() + 2] != boundary.as_bytes()
169                {
170                    Err(MultipartError::Boundary)
171                } else if &chunk[boundary.len() + 2..] == b"\r\n" {
172                    Ok(Some(false))
173                } else if &chunk[boundary.len() + 2..boundary.len() + 4] == b"--"
174                    && (chunk.len() == boundary.len() + 4
175                        || &chunk[boundary.len() + 4..] == b"\r\n")
176                {
177                    Ok(Some(true))
178                } else {
179                    Err(MultipartError::Boundary)
180                }
181            }
182        }
183    }
184
185    fn skip_until_boundary(
186        payload: &mut PayloadBuffer,
187        boundary: &str,
188    ) -> Result<Option<bool>, MultipartError> {
189        let mut eof = false;
190        loop {
191            match payload.readline()? {
192                Some(chunk) => {
193                    if chunk.is_empty() {
194                        return Err(MultipartError::Boundary);
195                    }
196                    if chunk.len() < boundary.len() {
197                        continue;
198                    }
199                    if &chunk[..2] == b"--" && &chunk[2..chunk.len() - 2] == boundary.as_bytes()
200                    {
201                        break;
202                    } else {
203                        if chunk.len() < boundary.len() + 2 {
204                            continue;
205                        }
206                        let b: &[u8] = boundary.as_ref();
207                        if &chunk[..boundary.len()] == b
208                            && &chunk[boundary.len()..boundary.len() + 2] == b"--"
209                        {
210                            eof = true;
211                            break;
212                        }
213                    }
214                }
215                None => {
216                    return if payload.eof {
217                        Err(MultipartError::Incomplete)
218                    } else {
219                        Ok(None)
220                    };
221                }
222            }
223        }
224        Ok(Some(eof))
225    }
226
227    fn poll(
228        &mut self,
229        safety: &Safety,
230        cx: &mut Context,
231    ) -> Poll<Option<Result<Field, MultipartError>>> {
232        if self.state == InnerState::Eof {
233            Poll::Ready(None)
234        } else {
235            // release field
236            loop {
237                // Nested multipart streams of fields has to be consumed
238                // before switching to next
239                if safety.current() {
240                    let stop = match self.item {
241                        InnerMultipartItem::Field(ref mut field) => {
242                            match field.borrow_mut().poll(safety) {
243                                Poll::Pending => return Poll::Pending,
244                                Poll::Ready(Some(Ok(_))) => continue,
245                                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
246                                Poll::Ready(None) => true,
247                            }
248                        }
249                        InnerMultipartItem::None => false,
250                    };
251                    if stop {
252                        self.item = InnerMultipartItem::None;
253                    }
254                    if let InnerMultipartItem::None = self.item {
255                        break;
256                    }
257                }
258            }
259
260            let headers = if let Some(mut payload) = self.payload.get_mut(safety) {
261                match self.state {
262                    // read until first boundary
263                    InnerState::FirstBoundary => {
264                        match InnerMultipart::skip_until_boundary(&mut payload, &self.boundary)?
265                        {
266                            Some(eof) => {
267                                if eof {
268                                    self.state = InnerState::Eof;
269                                    return Poll::Ready(None);
270                                } else {
271                                    self.state = InnerState::Headers;
272                                }
273                            }
274                            None => return Poll::Pending,
275                        }
276                    }
277                    // read boundary
278                    InnerState::Boundary => {
279                        match InnerMultipart::read_boundary(&mut payload, &self.boundary)? {
280                            None => return Poll::Pending,
281                            Some(eof) => {
282                                if eof {
283                                    self.state = InnerState::Eof;
284                                    return Poll::Ready(None);
285                                } else {
286                                    self.state = InnerState::Headers;
287                                }
288                            }
289                        }
290                    }
291                    _ => (),
292                }
293
294                // read field headers for next field
295                if self.state == InnerState::Headers {
296                    if let Some(headers) = InnerMultipart::read_headers(&mut payload)? {
297                        self.state = InnerState::Boundary;
298                        headers
299                    } else {
300                        return Poll::Pending;
301                    }
302                } else {
303                    unreachable!()
304                }
305            } else {
306                log::debug!("NotReady: field is in flight");
307                return Poll::Pending;
308            };
309
310            // content type
311            let mut mt = mime::APPLICATION_OCTET_STREAM;
312            if let Some(content_type) = headers.get(&header::CONTENT_TYPE) {
313                if let Ok(content_type) = content_type.to_str() {
314                    if let Ok(ct) = content_type.parse::<mime::Mime>() {
315                        mt = ct;
316                    }
317                }
318            }
319
320            self.state = InnerState::Boundary;
321
322            // nested multipart stream
323            if mt.type_() == mime::MULTIPART {
324                Poll::Ready(Some(Err(MultipartError::Nested)))
325            } else {
326                let field = Rc::new(RefCell::new(InnerField::new(
327                    self.payload.clone(),
328                    self.boundary.clone(),
329                    &headers,
330                )?));
331                self.item = InnerMultipartItem::Field(Rc::clone(&field));
332
333                Poll::Ready(Some(Ok(Field::new(safety.clone(cx), headers, mt, field))))
334            }
335        }
336    }
337}
338
339impl Drop for InnerMultipart {
340    fn drop(&mut self) {
341        // InnerMultipartItem::Field has to be dropped first because of Safety.
342        self.item = InnerMultipartItem::None;
343    }
344}
345
346/// A single field in a multipart stream
347pub struct Field {
348    ct: mime::Mime,
349    headers: HeaderMap,
350    inner: Rc<RefCell<InnerField>>,
351    safety: Safety,
352}
353
354impl Field {
355    fn new(
356        safety: Safety,
357        headers: HeaderMap,
358        ct: mime::Mime,
359        inner: Rc<RefCell<InnerField>>,
360    ) -> Self {
361        Field { ct, headers, inner, safety }
362    }
363
364    /// Get a map of headers
365    pub fn headers(&self) -> &HeaderMap {
366        &self.headers
367    }
368
369    /// Get the content type of the field
370    pub fn content_type(&self) -> &mime::Mime {
371        &self.ct
372    }
373}
374
375impl Stream for Field {
376    type Item = Result<Bytes, MultipartError>;
377
378    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
379        if self.safety.current() {
380            let mut inner = self.inner.borrow_mut();
381            if let Some(mut payload) = inner.payload.as_ref().unwrap().get_mut(&self.safety) {
382                payload.poll_stream(cx)?;
383            }
384            inner.poll(&self.safety)
385        } else if !self.safety.is_clean() {
386            Poll::Ready(Some(Err(MultipartError::NotConsumed)))
387        } else {
388            Poll::Pending
389        }
390    }
391}
392
393impl fmt::Debug for Field {
394    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
395        writeln!(f, "\nField: {}", self.ct)?;
396        writeln!(f, "  boundary: {}", self.inner.borrow().boundary)?;
397        writeln!(f, "  headers:")?;
398        for (key, val) in self.headers.iter() {
399            writeln!(f, "    {:?}: {:?}", key, val)?;
400        }
401        Ok(())
402    }
403}
404
405struct InnerField {
406    payload: Option<PayloadRef>,
407    boundary: String,
408    eof: bool,
409    length: Option<u64>,
410}
411
412impl InnerField {
413    fn new(
414        payload: PayloadRef,
415        boundary: String,
416        headers: &HeaderMap,
417    ) -> Result<InnerField, PayloadError> {
418        let len = if let Some(len) = headers.get(&header::CONTENT_LENGTH) {
419            if let Ok(s) = len.to_str() {
420                if let Ok(len) = s.parse::<u64>() {
421                    Some(len)
422                } else {
423                    return Err(PayloadError::Incomplete(None));
424                }
425            } else {
426                return Err(PayloadError::Incomplete(None));
427            }
428        } else {
429            None
430        };
431
432        Ok(InnerField { boundary, payload: Some(payload), eof: false, length: len })
433    }
434
435    /// Reads body part content chunk of the specified size.
436    /// The body part must has `Content-Length` header with proper value.
437    fn read_len(
438        payload: &mut PayloadBuffer,
439        size: &mut u64,
440    ) -> Poll<Option<Result<Bytes, MultipartError>>> {
441        if *size == 0 {
442            Poll::Ready(None)
443        } else {
444            match payload.read_max(*size)? {
445                Some(mut chunk) => {
446                    let len = cmp::min(chunk.len() as u64, *size);
447                    *size -= len;
448                    let ch = chunk.split_to(len as usize);
449                    if !chunk.is_empty() {
450                        payload.unprocessed(chunk);
451                    }
452                    Poll::Ready(Some(Ok(ch)))
453                }
454                None => {
455                    if payload.eof && (*size != 0) {
456                        Poll::Ready(Some(Err(MultipartError::Incomplete)))
457                    } else {
458                        Poll::Pending
459                    }
460                }
461            }
462        }
463    }
464
465    /// Reads content chunk of body part with unknown length.
466    /// The `Content-Length` header for body part is not necessary.
467    fn read_stream(
468        payload: &mut PayloadBuffer,
469        boundary: &str,
470    ) -> Poll<Option<Result<Bytes, MultipartError>>> {
471        let mut pos = 0;
472
473        let len = payload.buf.len();
474        if len == 0 {
475            return if payload.eof {
476                Poll::Ready(Some(Err(MultipartError::Incomplete)))
477            } else {
478                Poll::Pending
479            };
480        }
481
482        // check boundary
483        if len > 4 && payload.buf[0] == b'\r' {
484            let b_len = if &payload.buf[..2] == b"\r\n" && &payload.buf[2..4] == b"--" {
485                Some(4)
486            } else if &payload.buf[1..3] == b"--" {
487                Some(3)
488            } else {
489                None
490            };
491
492            if let Some(b_len) = b_len {
493                let b_size = boundary.len() + b_len;
494                if len < b_size {
495                    return Poll::Pending;
496                } else if &payload.buf[b_len..b_size] == boundary.as_bytes() {
497                    // found boundary
498                    return Poll::Ready(None);
499                }
500            }
501        }
502
503        loop {
504            return if let Some(idx) = twoway::find_bytes(&payload.buf[pos..], b"\r") {
505                let cur = pos + idx;
506
507                // check if we have enough data for boundary detection
508                if cur + 4 > len {
509                    if cur > 0 {
510                        Poll::Ready(Some(Ok(payload.buf.split_to(cur).freeze())))
511                    } else {
512                        Poll::Pending
513                    }
514                } else {
515                    // check boundary
516                    if (&payload.buf[cur..cur + 2] == b"\r\n"
517                        && &payload.buf[cur + 2..cur + 4] == b"--")
518                        || (&payload.buf[cur..=cur] == b"\r"
519                            && &payload.buf[cur + 1..cur + 3] == b"--")
520                    {
521                        if cur != 0 {
522                            // return buffer
523                            Poll::Ready(Some(Ok(payload.buf.split_to(cur).freeze())))
524                        } else {
525                            pos = cur + 1;
526                            continue;
527                        }
528                    } else {
529                        // not boundary
530                        pos = cur + 1;
531                        continue;
532                    }
533                }
534            } else {
535                Poll::Ready(Some(Ok(payload.buf.split().freeze())))
536            };
537        }
538    }
539
540    fn poll(&mut self, s: &Safety) -> Poll<Option<Result<Bytes, MultipartError>>> {
541        if self.payload.is_none() {
542            return Poll::Ready(None);
543        }
544
545        let result = if let Some(mut payload) = self.payload.as_ref().unwrap().get_mut(s) {
546            if !self.eof {
547                let res = if let Some(ref mut len) = self.length {
548                    InnerField::read_len(&mut payload, len)
549                } else {
550                    InnerField::read_stream(&mut payload, &self.boundary)
551                };
552
553                match res {
554                    Poll::Pending => return Poll::Pending,
555                    Poll::Ready(Some(Ok(bytes))) => return Poll::Ready(Some(Ok(bytes))),
556                    Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
557                    Poll::Ready(None) => self.eof = true,
558                }
559            }
560
561            match payload.readline() {
562                Ok(None) => Poll::Pending,
563                Ok(Some(line)) => {
564                    if line.as_ref() != b"\r\n" {
565                        log::warn!(
566                            "multipart field did not read all the data or it is malformed"
567                        );
568                    }
569                    Poll::Ready(None)
570                }
571                Err(e) => Poll::Ready(Some(Err(e))),
572            }
573        } else {
574            Poll::Pending
575        };
576
577        if let Poll::Ready(None) = result {
578            self.payload.take();
579        }
580        result
581    }
582}
583
584struct PayloadRef {
585    payload: Rc<RefCell<PayloadBuffer>>,
586}
587
588impl PayloadRef {
589    fn new(payload: PayloadBuffer) -> PayloadRef {
590        PayloadRef { payload: Rc::new(payload.into()) }
591    }
592
593    fn get_mut<'a, 'b>(&'a self, s: &'b Safety) -> Option<RefMut<'a, PayloadBuffer>>
594    where
595        'a: 'b,
596    {
597        if s.current() {
598            Some(self.payload.borrow_mut())
599        } else {
600            None
601        }
602    }
603}
604
605impl Clone for PayloadRef {
606    fn clone(&self) -> PayloadRef {
607        PayloadRef { payload: Rc::clone(&self.payload) }
608    }
609}
610
611/// Counter. It tracks of number of clones of payloads and give access to
612/// payload only to top most task panics if Safety get destroyed and it not top
613/// most task.
614#[derive(Debug)]
615struct Safety {
616    task: LocalWaker,
617    level: usize,
618    payload: Rc<PhantomData<bool>>,
619    clean: Rc<Cell<bool>>,
620}
621
622impl Safety {
623    fn new() -> Safety {
624        let payload = Rc::new(PhantomData);
625        Safety {
626            task: LocalWaker::new(),
627            level: Rc::strong_count(&payload),
628            clean: Rc::new(Cell::new(true)),
629            payload,
630        }
631    }
632
633    fn current(&self) -> bool {
634        Rc::strong_count(&self.payload) == self.level && self.clean.get()
635    }
636
637    fn is_clean(&self) -> bool {
638        self.clean.get()
639    }
640
641    fn clone(&self, cx: &mut Context) -> Safety {
642        let payload = Rc::clone(&self.payload);
643        let s = Safety {
644            task: LocalWaker::new(),
645            level: Rc::strong_count(&payload),
646            clean: self.clean.clone(),
647            payload,
648        };
649        s.task.register(cx.waker());
650        s
651    }
652}
653
654impl Drop for Safety {
655    fn drop(&mut self) {
656        // parent task is dead
657        if Rc::strong_count(&self.payload) != self.level {
658            self.clean.set(true);
659        }
660        if let Some(task) = self.task.take() {
661            task.wake()
662        }
663    }
664}
665
666/// Payload buffer
667struct PayloadBuffer {
668    eof: bool,
669    buf: BytesMut,
670    stream: LocalBoxStream<'static, Result<Bytes, PayloadError>>,
671}
672
673impl PayloadBuffer {
674    /// Create new `PayloadBuffer` instance
675    fn new<S>(stream: S) -> Self
676    where
677        S: Stream<Item = Result<Bytes, PayloadError>> + 'static,
678    {
679        PayloadBuffer { eof: false, buf: BytesMut::new(), stream: stream.boxed_local() }
680    }
681
682    fn poll_stream(&mut self, cx: &mut Context) -> Result<(), PayloadError> {
683        loop {
684            match Pin::new(&mut self.stream).poll_next(cx) {
685                Poll::Ready(Some(Ok(data))) => self.buf.extend_from_slice(&data),
686                Poll::Ready(Some(Err(e))) => return Err(e),
687                Poll::Ready(None) => {
688                    self.eof = true;
689                    return Ok(());
690                }
691                Poll::Pending => return Ok(()),
692            }
693        }
694    }
695
696    /// Read exact number of bytes
697    #[cfg(test)]
698    fn read_exact(&mut self, size: usize) -> Option<Bytes> {
699        if size <= self.buf.len() {
700            Some(self.buf.split_to(size).freeze())
701        } else {
702            None
703        }
704    }
705
706    fn read_max(&mut self, size: u64) -> Result<Option<Bytes>, MultipartError> {
707        if !self.buf.is_empty() {
708            let size = std::cmp::min(self.buf.len() as u64, size) as usize;
709            Ok(Some(self.buf.split_to(size).freeze()))
710        } else if self.eof {
711            Err(MultipartError::Incomplete)
712        } else {
713            Ok(None)
714        }
715    }
716
717    /// Read until specified ending
718    pub fn read_until(&mut self, line: &[u8]) -> Result<Option<Bytes>, MultipartError> {
719        let res = twoway::find_bytes(&self.buf, line)
720            .map(|idx| self.buf.split_to(idx + line.len()).freeze());
721
722        if res.is_none() && self.eof {
723            Err(MultipartError::Incomplete)
724        } else {
725            Ok(res)
726        }
727    }
728
729    /// Read bytes until new line delimiter
730    pub fn readline(&mut self) -> Result<Option<Bytes>, MultipartError> {
731        self.read_until(b"\n")
732    }
733
734    /// Read bytes until new line delimiter or eof
735    pub fn readline_or_eof(&mut self) -> Result<Option<Bytes>, MultipartError> {
736        match self.readline() {
737            Err(MultipartError::Incomplete) if self.eof => Ok(Some(self.buf.split().freeze())),
738            line => line,
739        }
740    }
741
742    /// Put unprocessed data back to the buffer
743    pub fn unprocessed(&mut self, data: Bytes) {
744        let buf = BytesMut::from(data.as_ref());
745        let buf = std::mem::replace(&mut self.buf, buf);
746        self.buf.extend_from_slice(&buf);
747    }
748}
749
750#[cfg(test)]
751mod tests {
752    use super::*;
753
754    use futures::future::lazy;
755    use ntex::channel::mpsc;
756    use ntex::http::h1::Payload;
757    use ntex::util::Bytes;
758
759    #[ntex::test]
760    async fn test_boundary() {
761        let headers = HeaderMap::new();
762        match Multipart::boundary(&headers) {
763            Err(MultipartError::NoContentType) => (),
764            _ => unreachable!("should not happen"),
765        }
766
767        let mut headers = HeaderMap::new();
768        headers.insert(header::CONTENT_TYPE, header::HeaderValue::from_static("test"));
769
770        match Multipart::boundary(&headers) {
771            Err(MultipartError::ParseContentType) => (),
772            _ => unreachable!("should not happen"),
773        }
774
775        let mut headers = HeaderMap::new();
776        headers
777            .insert(header::CONTENT_TYPE, header::HeaderValue::from_static("multipart/mixed"));
778        match Multipart::boundary(&headers) {
779            Err(MultipartError::Boundary) => (),
780            _ => unreachable!("should not happen"),
781        }
782
783        let mut headers = HeaderMap::new();
784        headers.insert(
785            header::CONTENT_TYPE,
786            header::HeaderValue::from_static(
787                "multipart/mixed; boundary=\"5c02368e880e436dab70ed54e1c58209\"",
788            ),
789        );
790
791        assert_eq!(Multipart::boundary(&headers).unwrap(), "5c02368e880e436dab70ed54e1c58209");
792    }
793
794    fn create_stream() -> (
795        mpsc::Sender<Result<Bytes, PayloadError>>,
796        impl Stream<Item = Result<Bytes, PayloadError>>,
797    ) {
798        let (tx, rx) = mpsc::channel();
799
800        (tx, rx.map(|res| res.map_err(|_| panic!())))
801    }
802    // Stream that returns from a Bytes, one char at a time and Pending every other poll()
803    struct SlowStream {
804        bytes: Bytes,
805        pos: usize,
806        ready: bool,
807    }
808
809    impl SlowStream {
810        fn new(bytes: Bytes) -> SlowStream {
811            SlowStream { bytes, pos: 0, ready: false }
812        }
813    }
814
815    impl Stream for SlowStream {
816        type Item = Result<Bytes, PayloadError>;
817
818        fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
819            let this = self.get_mut();
820            if !this.ready {
821                this.ready = true;
822                cx.waker().wake_by_ref();
823                return Poll::Pending;
824            }
825            if this.pos == this.bytes.len() {
826                return Poll::Ready(None);
827            }
828            let res = Poll::Ready(Some(Ok(this.bytes.slice(this.pos..(this.pos + 1)))));
829            this.pos += 1;
830            this.ready = false;
831            res
832        }
833    }
834
835    fn create_simple_request_with_header() -> (Bytes, HeaderMap) {
836        let bytes = Bytes::from(
837            "testasdadsad\r\n\
838             --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
839             Content-Disposition: form-data; name=\"file\"; filename=\"fn.txt\"\r\n\
840             Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\
841             test\r\n\
842             --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
843             Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\
844             data\r\n\
845             --abbc761f78ff4d7cb7573b5a23f96ef0--\r\n",
846        );
847        let mut headers = HeaderMap::new();
848        headers.insert(
849            header::CONTENT_TYPE,
850            header::HeaderValue::from_static(
851                "multipart/mixed; boundary=\"abbc761f78ff4d7cb7573b5a23f96ef0\"",
852            ),
853        );
854        (bytes, headers)
855    }
856
857    #[ntex::test]
858    async fn test_multipart_no_end_crlf() {
859        let (sender, payload) = create_stream();
860        let (mut bytes, headers) = create_simple_request_with_header();
861        let bytes_stripped = bytes.split_to(bytes.len()); // strip crlf
862
863        sender.send(Ok(bytes_stripped)).unwrap();
864        drop(sender); // eof
865
866        let mut multipart = Multipart::new(&headers, payload);
867
868        match multipart.next().await.unwrap() {
869            Ok(_) => (),
870            _ => unreachable!(),
871        }
872
873        match multipart.next().await.unwrap() {
874            Ok(_) => (),
875            _ => unreachable!(),
876        }
877
878        match multipart.next().await {
879            None => (),
880            _ => unreachable!(),
881        }
882    }
883
884    #[ntex::test]
885    async fn test_multipart() {
886        let (sender, payload) = create_stream();
887        let (bytes, headers) = create_simple_request_with_header();
888
889        sender.send(Ok(bytes)).unwrap();
890
891        let mut multipart = Multipart::new(&headers, payload);
892        match multipart.next().await {
893            Some(Ok(mut field)) => {
894                assert_eq!(field.content_type().type_(), mime::TEXT);
895                assert_eq!(field.content_type().subtype(), mime::PLAIN);
896
897                match field.next().await.unwrap() {
898                    Ok(chunk) => assert_eq!(chunk, "test"),
899                    _ => unreachable!(),
900                }
901                match field.next().await {
902                    None => (),
903                    _ => unreachable!(),
904                }
905            }
906            _ => unreachable!(),
907        }
908
909        match multipart.next().await.unwrap() {
910            Ok(mut field) => {
911                assert_eq!(field.content_type().type_(), mime::TEXT);
912                assert_eq!(field.content_type().subtype(), mime::PLAIN);
913
914                match field.next().await {
915                    Some(Ok(chunk)) => assert_eq!(chunk, "data"),
916                    _ => unreachable!(),
917                }
918                match field.next().await {
919                    None => (),
920                    _ => unreachable!(),
921                }
922            }
923            _ => unreachable!(),
924        }
925
926        match multipart.next().await {
927            None => (),
928            _ => unreachable!(),
929        }
930    }
931
932    // Loops, collecting all bytes until end-of-field
933    async fn get_whole_field(field: &mut Field) -> BytesMut {
934        let mut b = BytesMut::new();
935        loop {
936            match field.next().await {
937                Some(Ok(chunk)) => b.extend_from_slice(&chunk),
938                None => return b,
939                _ => unreachable!(),
940            }
941        }
942    }
943
944    #[ntex::test]
945    async fn test_stream() {
946        let (bytes, headers) = create_simple_request_with_header();
947        let payload = SlowStream::new(bytes);
948
949        let mut multipart = Multipart::new(&headers, payload);
950        match multipart.next().await.unwrap() {
951            Ok(mut field) => {
952                assert_eq!(field.content_type().type_(), mime::TEXT);
953                assert_eq!(field.content_type().subtype(), mime::PLAIN);
954
955                assert_eq!(get_whole_field(&mut field).await, "test");
956            }
957            _ => unreachable!(),
958        }
959
960        match multipart.next().await {
961            Some(Ok(mut field)) => {
962                assert_eq!(field.content_type().type_(), mime::TEXT);
963                assert_eq!(field.content_type().subtype(), mime::PLAIN);
964
965                assert_eq!(get_whole_field(&mut field).await, "data");
966            }
967            _ => unreachable!(),
968        }
969
970        match multipart.next().await {
971            None => (),
972            _ => unreachable!(),
973        }
974    }
975
976    #[ntex::test]
977    async fn test_basic() {
978        let (_, payload) = Payload::create(false);
979        let mut payload = PayloadBuffer::new(payload);
980
981        assert_eq!(payload.buf.len(), 0);
982        assert!(lazy(|cx| payload.poll_stream(cx)).await.is_err());
983        assert_eq!(None, payload.read_max(1).unwrap());
984    }
985
986    #[ntex::test]
987    async fn test_eof() {
988        let (mut sender, payload) = Payload::create(false);
989        let mut payload = PayloadBuffer::new(payload);
990
991        assert_eq!(None, payload.read_max(4).unwrap());
992        sender.feed_data(Bytes::from("data"));
993        sender.feed_eof();
994        lazy(|cx| payload.poll_stream(cx)).await.unwrap();
995
996        assert_eq!(Some(Bytes::from("data")), payload.read_max(4).unwrap());
997        assert_eq!(payload.buf.len(), 0);
998        assert!(payload.read_max(1).is_err());
999        assert!(payload.eof);
1000    }
1001
1002    #[ntex::test]
1003    async fn test_err() {
1004        let (mut sender, payload) = Payload::create(false);
1005        let mut payload = PayloadBuffer::new(payload);
1006        assert_eq!(None, payload.read_max(1).unwrap());
1007        sender.set_error(PayloadError::Incomplete(None));
1008        lazy(|cx| payload.poll_stream(cx)).await.err().unwrap();
1009    }
1010
1011    #[ntex::test]
1012    async fn test_readmax() {
1013        let (mut sender, payload) = Payload::create(false);
1014        let mut payload = PayloadBuffer::new(payload);
1015
1016        sender.feed_data(Bytes::from("line1"));
1017        sender.feed_data(Bytes::from("line2"));
1018        lazy(|cx| payload.poll_stream(cx)).await.unwrap();
1019        assert_eq!(payload.buf.len(), 10);
1020
1021        assert_eq!(Some(Bytes::from("line1")), payload.read_max(5).unwrap());
1022        assert_eq!(payload.buf.len(), 5);
1023
1024        assert_eq!(Some(Bytes::from("line2")), payload.read_max(5).unwrap());
1025        assert_eq!(payload.buf.len(), 0);
1026    }
1027
1028    #[ntex::test]
1029    async fn test_readexactly() {
1030        let (mut sender, payload) = Payload::create(false);
1031        let mut payload = PayloadBuffer::new(payload);
1032
1033        assert_eq!(None, payload.read_exact(2));
1034
1035        sender.feed_data(Bytes::from("line1"));
1036        sender.feed_data(Bytes::from("line2"));
1037        lazy(|cx| payload.poll_stream(cx)).await.unwrap();
1038
1039        assert_eq!(Some(Bytes::from_static(b"li")), payload.read_exact(2));
1040        assert_eq!(payload.buf.len(), 8);
1041
1042        assert_eq!(Some(Bytes::from_static(b"ne1l")), payload.read_exact(4));
1043        assert_eq!(payload.buf.len(), 4);
1044    }
1045
1046    #[ntex::test]
1047    async fn test_readuntil() {
1048        let (mut sender, payload) = Payload::create(false);
1049        let mut payload = PayloadBuffer::new(payload);
1050
1051        assert_eq!(None, payload.read_until(b"ne").unwrap());
1052
1053        sender.feed_data(Bytes::from("line1"));
1054        sender.feed_data(Bytes::from("line2"));
1055        lazy(|cx| payload.poll_stream(cx)).await.unwrap();
1056
1057        assert_eq!(Some(Bytes::from("line")), payload.read_until(b"ne").unwrap());
1058        assert_eq!(payload.buf.len(), 6);
1059
1060        assert_eq!(Some(Bytes::from("1line2")), payload.read_until(b"2").unwrap());
1061        assert_eq!(payload.buf.len(), 0);
1062    }
1063}