Skip to main content

ntex_multipart/
server.rs

1//! Multipart payload support
2use std::cell::{Cell, RefCell, RefMut};
3use std::task::{Context, Poll};
4use std::{cmp, convert::TryFrom, fmt, marker::PhantomData, pin::Pin, rc::Rc};
5
6use futures::stream::{LocalBoxStream, Stream, StreamExt};
7use ntex::http::error::{DecodeError, PayloadError};
8use ntex::http::header::{self, HeaderMap, HeaderName, HeaderValue};
9use ntex::task::LocalWaker;
10use ntex::util::{Bytes, BytesMut};
11
12use crate::error::MultipartError;
13
/// Maximum number of headers parsed per multipart field.
///
/// Used as the size of the `httparse` header array in `InnerMultipart::read_headers`;
/// a field header block with more entries than this fails to parse.
const MAX_HEADERS: usize = 32;
15
/// The server-side implementation of `multipart/form-data` requests.
///
/// Parses the incoming body stream into [`Field`]s through its `Stream`
/// implementation. A yielded `Field` must be dropped before the next one can
/// be produced; until then polling the `Multipart` returns `Poll::Pending`.
/// Nested multipart bodies are not supported: a field whose content type is
/// `multipart/*` is reported as `MultipartError::Nested`.
pub struct Multipart {
    // Access token shared (via `Safety::clone`) with every `Field` handed out;
    // decides whether this handle may touch the shared payload.
    safety: Safety,
    // Error found while reading the request headers; reported on the first poll.
    error: Option<MultipartError>,
    // Parser state; `None` when construction failed (see `error`).
    inner: Option<Rc<RefCell<InnerMultipart>>>,
}
27
28enum InnerMultipartItem {
29    None,
30    Field(Rc<RefCell<InnerField>>),
31}
32
/// Position of the multipart parser between fields.
#[derive(Debug, PartialEq)]
enum InnerState {
    /// Skipping the preamble up to the first delimiter line.
    FirstBoundary,
    /// Reading the header block of the next field.
    Headers,
    /// Expecting the delimiter line that follows a field's body.
    Boundary,
    /// The closing delimiter was seen; no more fields will be produced.
    Eof,
}
44
/// Parser state behind a [`Multipart`].
struct InnerMultipart {
    // Shared, buffered request body; a clone is given to each field's reader.
    payload: PayloadRef,
    // Boundary token from the request's `Content-Type` (without the leading `--`).
    boundary: String,
    // Current position in the multipart grammar.
    state: InnerState,
    // Field most recently handed out; drained before the next one is parsed.
    item: InnerMultipartItem,
}
51
52impl Multipart {
53    /// Create multipart instance for boundary.
54    pub fn new<S>(headers: &HeaderMap, stream: S) -> Multipart
55    where
56        S: Stream<Item = Result<Bytes, PayloadError>> + Unpin + 'static,
57    {
58        match Self::boundary(headers) {
59            Ok(boundary) => Multipart {
60                error: None,
61                safety: Safety::new(),
62                inner: Some(Rc::new(RefCell::new(InnerMultipart {
63                    boundary,
64                    payload: PayloadRef::new(PayloadBuffer::new(Box::new(stream))),
65                    state: InnerState::FirstBoundary,
66                    item: InnerMultipartItem::None,
67                }))),
68            },
69            Err(err) => Multipart { error: Some(err), safety: Safety::new(), inner: None },
70        }
71    }
72
73    /// Extract boundary info from headers.
74    fn boundary(headers: &HeaderMap) -> Result<String, MultipartError> {
75        if let Some(content_type) = headers.get(&header::CONTENT_TYPE) {
76            if let Ok(content_type) = content_type.to_str() {
77                if let Ok(ct) = content_type.parse::<mime::Mime>() {
78                    if let Some(boundary) = ct.get_param(mime::BOUNDARY) {
79                        Ok(boundary.as_str().to_owned())
80                    } else {
81                        Err(MultipartError::Boundary)
82                    }
83                } else {
84                    Err(MultipartError::ParseContentType)
85                }
86            } else {
87                Err(MultipartError::ParseContentType)
88            }
89        } else {
90            Err(MultipartError::NoContentType)
91        }
92    }
93}
94
95impl Stream for Multipart {
96    type Item = Result<Field, MultipartError>;
97
98    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
99        if let Some(err) = self.error.take() {
100            Poll::Ready(Some(Err(err)))
101        } else if self.safety.current() {
102            let this = self.get_mut();
103            let mut inner = this.inner.as_mut().unwrap().borrow_mut();
104            if let Some(mut payload) = inner.payload.get_mut(&this.safety) {
105                payload.poll_stream(cx)?;
106            }
107            inner.poll(&this.safety, cx)
108        } else if !self.safety.is_clean() {
109            Poll::Ready(Some(Err(MultipartError::NotConsumed)))
110        } else {
111            Poll::Pending
112        }
113    }
114}
115
116impl InnerMultipart {
117    fn read_headers(payload: &mut PayloadBuffer) -> Result<Option<HeaderMap>, MultipartError> {
118        match payload.read_until(b"\r\n\r\n")? {
119            None => {
120                if payload.eof {
121                    Err(MultipartError::Incomplete)
122                } else {
123                    Ok(None)
124                }
125            }
126            Some(bytes) => {
127                let mut hdrs = [httparse::EMPTY_HEADER; MAX_HEADERS];
128                match httparse::parse_headers(&bytes, &mut hdrs) {
129                    Ok(httparse::Status::Complete((_, hdrs))) => {
130                        // convert headers
131                        let mut headers = HeaderMap::with_capacity(hdrs.len());
132                        for h in hdrs {
133                            if let Ok(name) = HeaderName::try_from(h.name) {
134                                if let Ok(value) = HeaderValue::try_from(h.value) {
135                                    headers.append(name, value);
136                                } else {
137                                    return Err(DecodeError::Header.into());
138                                }
139                            } else {
140                                return Err(DecodeError::Header.into());
141                            }
142                        }
143                        Ok(Some(headers))
144                    }
145                    Ok(httparse::Status::Partial) => Err(DecodeError::Header.into()),
146                    Err(err) => Err(DecodeError::from(err).into()),
147                }
148            }
149        }
150    }
151
152    fn read_boundary(
153        payload: &mut PayloadBuffer,
154        boundary: &str,
155    ) -> Result<Option<bool>, MultipartError> {
156        // TODO: need to read epilogue
157        match payload.readline_or_eof()? {
158            None => {
159                if payload.eof {
160                    Ok(Some(true))
161                } else {
162                    Ok(None)
163                }
164            }
165            Some(chunk) => {
166                if chunk.len() < boundary.len() + 4
167                    || &chunk[..2] != b"--"
168                    || &chunk[2..boundary.len() + 2] != boundary.as_bytes()
169                {
170                    Err(MultipartError::Boundary)
171                } else if &chunk[boundary.len() + 2..] == b"\r\n" {
172                    Ok(Some(false))
173                } else if &chunk[boundary.len() + 2..boundary.len() + 4] == b"--"
174                    && (chunk.len() == boundary.len() + 4
175                        || &chunk[boundary.len() + 4..] == b"\r\n")
176                {
177                    Ok(Some(true))
178                } else {
179                    Err(MultipartError::Boundary)
180                }
181            }
182        }
183    }
184
185    fn skip_until_boundary(
186        payload: &mut PayloadBuffer,
187        boundary: &str,
188    ) -> Result<Option<bool>, MultipartError> {
189        let mut eof = false;
190        loop {
191            match payload.readline()? {
192                Some(chunk) => {
193                    if chunk.is_empty() {
194                        return Err(MultipartError::Boundary);
195                    }
196                    if chunk.len() < boundary.len() {
197                        continue;
198                    }
199                    if &chunk[..2] == b"--" && &chunk[2..chunk.len() - 2] == boundary.as_bytes()
200                    {
201                        break;
202                    } else {
203                        if chunk.len() < boundary.len() + 2 {
204                            continue;
205                        }
206                        let b: &[u8] = boundary.as_ref();
207                        if &chunk[..boundary.len()] == b
208                            && &chunk[boundary.len()..boundary.len() + 2] == b"--"
209                        {
210                            eof = true;
211                            break;
212                        }
213                    }
214                }
215                None => {
216                    return if payload.eof {
217                        Err(MultipartError::Incomplete)
218                    } else {
219                        Ok(None)
220                    };
221                }
222            }
223        }
224        Ok(Some(eof))
225    }
226
227    fn poll(
228        &mut self,
229        safety: &Safety,
230        cx: &mut Context,
231    ) -> Poll<Option<Result<Field, MultipartError>>> {
232        if self.state == InnerState::Eof {
233            Poll::Ready(None)
234        } else {
235            // release field
236            loop {
237                // Nested multipart streams of fields has to be consumed
238                // before switching to next
239                if safety.current() {
240                    let stop = match self.item {
241                        InnerMultipartItem::Field(ref mut field) => {
242                            match field.borrow_mut().poll(safety) {
243                                Poll::Pending => return Poll::Pending,
244                                Poll::Ready(Some(Ok(_))) => continue,
245                                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
246                                Poll::Ready(None) => true,
247                            }
248                        }
249                        InnerMultipartItem::None => false,
250                    };
251                    if stop {
252                        self.item = InnerMultipartItem::None;
253                    }
254                    if let InnerMultipartItem::None = self.item {
255                        break;
256                    }
257                }
258            }
259
260            let headers = if let Some(mut payload) = self.payload.get_mut(safety) {
261                match self.state {
262                    // read until first boundary
263                    InnerState::FirstBoundary => {
264                        match InnerMultipart::skip_until_boundary(&mut payload, &self.boundary)?
265                        {
266                            Some(eof) => {
267                                if eof {
268                                    self.state = InnerState::Eof;
269                                    return Poll::Ready(None);
270                                } else {
271                                    self.state = InnerState::Headers;
272                                }
273                            }
274                            None => return Poll::Pending,
275                        }
276                    }
277                    // read boundary
278                    InnerState::Boundary => {
279                        match InnerMultipart::read_boundary(&mut payload, &self.boundary)? {
280                            None => return Poll::Pending,
281                            Some(eof) => {
282                                if eof {
283                                    self.state = InnerState::Eof;
284                                    return Poll::Ready(None);
285                                } else {
286                                    self.state = InnerState::Headers;
287                                }
288                            }
289                        }
290                    }
291                    _ => (),
292                }
293
294                // read field headers for next field
295                if self.state == InnerState::Headers {
296                    if let Some(headers) = InnerMultipart::read_headers(&mut payload)? {
297                        self.state = InnerState::Boundary;
298                        headers
299                    } else {
300                        return Poll::Pending;
301                    }
302                } else {
303                    unreachable!()
304                }
305            } else {
306                log::debug!("NotReady: field is in flight");
307                return Poll::Pending;
308            };
309
310            // content type
311            let mut mt = mime::APPLICATION_OCTET_STREAM;
312            if let Some(content_type) = headers.get(&header::CONTENT_TYPE)
313                && let Ok(content_type) = content_type.to_str()
314                && let Ok(ct) = content_type.parse::<mime::Mime>()
315            {
316                mt = ct;
317            }
318
319            self.state = InnerState::Boundary;
320
321            // nested multipart stream
322            if mt.type_() == mime::MULTIPART {
323                Poll::Ready(Some(Err(MultipartError::Nested)))
324            } else {
325                let field = Rc::new(RefCell::new(InnerField::new(
326                    self.payload.clone(),
327                    self.boundary.clone(),
328                    &headers,
329                )?));
330                self.item = InnerMultipartItem::Field(Rc::clone(&field));
331
332                Poll::Ready(Some(Ok(Field::new(safety.clone(cx), headers, mt, field))))
333            }
334        }
335    }
336}
337
338impl Drop for InnerMultipart {
339    fn drop(&mut self) {
340        // InnerMultipartItem::Field has to be dropped first because of Safety.
341        self.item = InnerMultipartItem::None;
342    }
343}
344
/// A single field in a multipart stream.
///
/// The field body is read as `Bytes` chunks through its `Stream` implementation.
pub struct Field {
    // Parsed `Content-Type` of the field (`application/octet-stream` when the
    // header is absent or unparsable).
    ct: mime::Mime,
    // Headers of this field's header block.
    headers: HeaderMap,
    // Body reader shared with the parent parser, which drains it when the field
    // is dropped unread.
    inner: Rc<RefCell<InnerField>>,
    // Token obtained from the parent's `Safety::clone`; while it is alive the
    // parent `Multipart` is not current and cannot advance.
    safety: Safety,
}
352
353impl Field {
354    fn new(
355        safety: Safety,
356        headers: HeaderMap,
357        ct: mime::Mime,
358        inner: Rc<RefCell<InnerField>>,
359    ) -> Self {
360        Field { ct, headers, inner, safety }
361    }
362
363    /// Get a map of headers
364    pub fn headers(&self) -> &HeaderMap {
365        &self.headers
366    }
367
368    /// Get the content type of the field
369    pub fn content_type(&self) -> &mime::Mime {
370        &self.ct
371    }
372}
373
374impl Stream for Field {
375    type Item = Result<Bytes, MultipartError>;
376
377    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
378        if self.safety.current() {
379            let mut inner = self.inner.borrow_mut();
380            if let Some(mut payload) = inner.payload.as_ref().unwrap().get_mut(&self.safety) {
381                payload.poll_stream(cx)?;
382            }
383            inner.poll(&self.safety)
384        } else if !self.safety.is_clean() {
385            Poll::Ready(Some(Err(MultipartError::NotConsumed)))
386        } else {
387            Poll::Pending
388        }
389    }
390}
391
392impl fmt::Debug for Field {
393    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
394        writeln!(f, "\nField: {}", self.ct)?;
395        writeln!(f, "  boundary: {}", self.inner.borrow().boundary)?;
396        writeln!(f, "  headers:")?;
397        for (key, val) in self.headers.iter() {
398            writeln!(f, "    {:?}: {:?}", key, val)?;
399        }
400        Ok(())
401    }
402}
403
404struct InnerField {
405    payload: Option<PayloadRef>,
406    boundary: String,
407    eof: bool,
408    length: Option<u64>,
409}
410
411impl InnerField {
412    fn new(
413        payload: PayloadRef,
414        boundary: String,
415        headers: &HeaderMap,
416    ) -> Result<InnerField, PayloadError> {
417        let len = if let Some(len) = headers.get(&header::CONTENT_LENGTH) {
418            if let Ok(s) = len.to_str() {
419                if let Ok(len) = s.parse::<u64>() {
420                    Some(len)
421                } else {
422                    return Err(PayloadError::Incomplete(None));
423                }
424            } else {
425                return Err(PayloadError::Incomplete(None));
426            }
427        } else {
428            None
429        };
430
431        Ok(InnerField { boundary, payload: Some(payload), eof: false, length: len })
432    }
433
434    /// Reads body part content chunk of the specified size.
435    /// The body part must has `Content-Length` header with proper value.
436    fn read_len(
437        payload: &mut PayloadBuffer,
438        size: &mut u64,
439    ) -> Poll<Option<Result<Bytes, MultipartError>>> {
440        if *size == 0 {
441            Poll::Ready(None)
442        } else {
443            match payload.read_max(*size)? {
444                Some(mut chunk) => {
445                    let len = cmp::min(chunk.len() as u64, *size);
446                    *size -= len;
447                    let ch = chunk.split_to(len as usize);
448                    if !chunk.is_empty() {
449                        payload.unprocessed(chunk);
450                    }
451                    Poll::Ready(Some(Ok(ch)))
452                }
453                None => {
454                    if payload.eof && (*size != 0) {
455                        Poll::Ready(Some(Err(MultipartError::Incomplete)))
456                    } else {
457                        Poll::Pending
458                    }
459                }
460            }
461        }
462    }
463
464    /// Reads content chunk of body part with unknown length.
465    /// The `Content-Length` header for body part is not necessary.
466    fn read_stream(
467        payload: &mut PayloadBuffer,
468        boundary: &str,
469    ) -> Poll<Option<Result<Bytes, MultipartError>>> {
470        let mut pos = 0;
471
472        let len = payload.buf.len();
473        if len == 0 {
474            return if payload.eof {
475                Poll::Ready(Some(Err(MultipartError::Incomplete)))
476            } else {
477                Poll::Pending
478            };
479        }
480
481        // check boundary
482        if len > 4 && payload.buf[0] == b'\r' {
483            let b_len = if &payload.buf[..2] == b"\r\n" && &payload.buf[2..4] == b"--" {
484                Some(4)
485            } else if &payload.buf[1..3] == b"--" {
486                Some(3)
487            } else {
488                None
489            };
490
491            if let Some(b_len) = b_len {
492                let b_size = boundary.len() + b_len;
493                if len < b_size {
494                    return Poll::Pending;
495                } else if &payload.buf[b_len..b_size] == boundary.as_bytes() {
496                    // found boundary
497                    return Poll::Ready(None);
498                }
499            }
500        }
501
502        loop {
503            return if let Some(idx) = twoway::find_bytes(&payload.buf[pos..], b"\r") {
504                let cur = pos + idx;
505
506                // check if we have enough data for boundary detection
507                if cur + 4 > len {
508                    if cur > 0 {
509                        Poll::Ready(Some(Ok(payload.buf.split_to(cur))))
510                    } else {
511                        Poll::Pending
512                    }
513                } else {
514                    // check boundary
515                    if (&payload.buf[cur..cur + 2] == b"\r\n"
516                        && &payload.buf[cur + 2..cur + 4] == b"--")
517                        || (&payload.buf[cur..=cur] == b"\r"
518                            && &payload.buf[cur + 1..cur + 3] == b"--")
519                    {
520                        if cur != 0 {
521                            // return buffer
522                            Poll::Ready(Some(Ok(payload.buf.split_to(cur))))
523                        } else {
524                            pos = cur + 1;
525                            continue;
526                        }
527                    } else {
528                        // not boundary
529                        pos = cur + 1;
530                        continue;
531                    }
532                }
533            } else {
534                Poll::Ready(Some(Ok(payload.buf.take())))
535            };
536        }
537    }
538
539    fn poll(&mut self, s: &Safety) -> Poll<Option<Result<Bytes, MultipartError>>> {
540        if self.payload.is_none() {
541            return Poll::Ready(None);
542        }
543
544        let result = if let Some(mut payload) = self.payload.as_ref().unwrap().get_mut(s) {
545            if !self.eof {
546                let res = if let Some(ref mut len) = self.length {
547                    InnerField::read_len(&mut payload, len)
548                } else {
549                    InnerField::read_stream(&mut payload, &self.boundary)
550                };
551
552                match res {
553                    Poll::Pending => return Poll::Pending,
554                    Poll::Ready(Some(Ok(bytes))) => return Poll::Ready(Some(Ok(bytes))),
555                    Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
556                    Poll::Ready(None) => self.eof = true,
557                }
558            }
559
560            match payload.readline() {
561                Ok(None) => Poll::Pending,
562                Ok(Some(line)) => {
563                    if line.as_ref() != b"\r\n" {
564                        log::warn!(
565                            "multipart field did not read all the data or it is malformed"
566                        );
567                    }
568                    Poll::Ready(None)
569                }
570                Err(e) => Poll::Ready(Some(Err(e))),
571            }
572        } else {
573            Poll::Pending
574        };
575
576        if let Poll::Ready(None) = result {
577            self.payload.take();
578        }
579        result
580    }
581}
582
/// Shared handle to the buffered request body.
///
/// Cloned into every field reader; `get_mut` only grants access to the buffer
/// when the supplied `Safety` token is current.
struct PayloadRef {
    payload: Rc<RefCell<PayloadBuffer>>,
}
586
587impl PayloadRef {
588    fn new(payload: PayloadBuffer) -> PayloadRef {
589        PayloadRef { payload: Rc::new(payload.into()) }
590    }
591
592    fn get_mut<'a, 'b>(&'a self, s: &'b Safety) -> Option<RefMut<'a, PayloadBuffer>>
593    where
594        'a: 'b,
595    {
596        if s.current() { Some(self.payload.borrow_mut()) } else { None }
597    }
598}
599
600impl Clone for PayloadRef {
601    fn clone(&self) -> PayloadRef {
602        PayloadRef { payload: Rc::clone(&self.payload) }
603    }
604}
605
/// Access token guarding the shared payload.
///
/// All tokens derived from one `Safety::new` share an `Rc` whose strong count
/// acts as a counter. Each token records the count at its creation (`level`).
/// It is *current*, and so allowed to touch the payload, only while the count
/// still equals that level and the shared `clean` flag is set. For example, a
/// parent is not current while a child made by `Safety::clone` is alive.
/// Dropping a token wakes the task registered in it; it never panics.
#[derive(Debug)]
struct Safety {
    // Waker notified when this token is dropped (registered by `clone`).
    task: LocalWaker,
    // Strong count of `payload` observed when this token was created.
    level: usize,
    // Shared counter; only its reference count carries information.
    payload: Rc<PhantomData<bool>>,
    // Shared flag consulted by `current`/`is_clean`.
    // NOTE(review): nothing in this file ever sets it to `false` — confirm intent.
    clean: Rc<Cell<bool>>,
}
616
617impl Safety {
618    fn new() -> Safety {
619        let payload = Rc::new(PhantomData);
620        Safety {
621            task: LocalWaker::new(),
622            level: Rc::strong_count(&payload),
623            clean: Rc::new(Cell::new(true)),
624            payload,
625        }
626    }
627
628    fn current(&self) -> bool {
629        Rc::strong_count(&self.payload) == self.level && self.clean.get()
630    }
631
632    fn is_clean(&self) -> bool {
633        self.clean.get()
634    }
635
636    fn clone(&self, cx: &mut Context) -> Safety {
637        let payload = Rc::clone(&self.payload);
638        let s = Safety {
639            task: LocalWaker::new(),
640            level: Rc::strong_count(&payload),
641            clean: self.clean.clone(),
642            payload,
643        };
644        s.task.register(cx.waker());
645        s
646    }
647}
648
649impl Drop for Safety {
650    fn drop(&mut self) {
651        // parent task is dead
652        if Rc::strong_count(&self.payload) != self.level {
653            self.clean.set(true);
654        }
655        if let Some(task) = self.task.take() {
656            task.wake()
657        }
658    }
659}
660
661/// Payload buffer
662struct PayloadBuffer {
663    eof: bool,
664    buf: BytesMut,
665    stream: LocalBoxStream<'static, Result<Bytes, PayloadError>>,
666}
667
668impl PayloadBuffer {
669    /// Create new `PayloadBuffer` instance
670    fn new<S>(stream: S) -> Self
671    where
672        S: Stream<Item = Result<Bytes, PayloadError>> + 'static,
673    {
674        PayloadBuffer { eof: false, buf: BytesMut::new(), stream: stream.boxed_local() }
675    }
676
677    fn poll_stream(&mut self, cx: &mut Context) -> Result<(), PayloadError> {
678        loop {
679            match Pin::new(&mut self.stream).poll_next(cx) {
680                Poll::Ready(Some(Ok(data))) => self.buf.extend_from_slice(&data),
681                Poll::Ready(Some(Err(e))) => return Err(e),
682                Poll::Ready(None) => {
683                    self.eof = true;
684                    return Ok(());
685                }
686                Poll::Pending => return Ok(()),
687            }
688        }
689    }
690
691    /// Read exact number of bytes
692    #[cfg(test)]
693    fn read_exact(&mut self, size: usize) -> Option<Bytes> {
694        if size <= self.buf.len() { Some(self.buf.split_to(size)) } else { None }
695    }
696
697    fn read_max(&mut self, size: u64) -> Result<Option<Bytes>, MultipartError> {
698        if !self.buf.is_empty() {
699            let size = std::cmp::min(self.buf.len() as u64, size) as usize;
700            Ok(Some(self.buf.split_to(size)))
701        } else if self.eof {
702            Err(MultipartError::Incomplete)
703        } else {
704            Ok(None)
705        }
706    }
707
708    /// Read until specified ending
709    pub fn read_until(&mut self, line: &[u8]) -> Result<Option<Bytes>, MultipartError> {
710        let res =
711            twoway::find_bytes(&self.buf, line).map(|idx| self.buf.split_to(idx + line.len()));
712
713        if res.is_none() && self.eof { Err(MultipartError::Incomplete) } else { Ok(res) }
714    }
715
716    /// Read bytes until new line delimiter
717    pub fn readline(&mut self) -> Result<Option<Bytes>, MultipartError> {
718        self.read_until(b"\n")
719    }
720
721    /// Read bytes until new line delimiter or eof
722    pub fn readline_or_eof(&mut self) -> Result<Option<Bytes>, MultipartError> {
723        match self.readline() {
724            Err(MultipartError::Incomplete) if self.eof => Ok(Some(self.buf.take())),
725            line => line,
726        }
727    }
728
729    /// Put unprocessed data back to the buffer
730    pub fn unprocessed(&mut self, data: Bytes) {
731        let buf = BytesMut::from(data.as_ref());
732        let buf = std::mem::replace(&mut self.buf, buf);
733        self.buf.extend_from_slice(&buf);
734    }
735}
736
737#[cfg(test)]
738mod tests {
739    use super::*;
740
741    use futures::future::lazy;
742    use ntex::{channel::bstream, channel::mpsc, util::Bytes};
743
744    #[ntex::test]
745    async fn test_boundary() {
746        let headers = HeaderMap::new();
747        match Multipart::boundary(&headers) {
748            Err(MultipartError::NoContentType) => (),
749            _ => unreachable!("should not happen"),
750        }
751
752        let mut headers = HeaderMap::new();
753        headers.insert(header::CONTENT_TYPE, header::HeaderValue::from_static("test"));
754
755        match Multipart::boundary(&headers) {
756            Err(MultipartError::ParseContentType) => (),
757            _ => unreachable!("should not happen"),
758        }
759
760        let mut headers = HeaderMap::new();
761        headers
762            .insert(header::CONTENT_TYPE, header::HeaderValue::from_static("multipart/mixed"));
763        match Multipart::boundary(&headers) {
764            Err(MultipartError::Boundary) => (),
765            _ => unreachable!("should not happen"),
766        }
767
768        let mut headers = HeaderMap::new();
769        headers.insert(
770            header::CONTENT_TYPE,
771            header::HeaderValue::from_static(
772                "multipart/mixed; boundary=\"5c02368e880e436dab70ed54e1c58209\"",
773            ),
774        );
775
776        assert_eq!(Multipart::boundary(&headers).unwrap(), "5c02368e880e436dab70ed54e1c58209");
777    }
778
779    fn create_stream() -> (
780        mpsc::Sender<Result<Bytes, PayloadError>>,
781        impl Stream<Item = Result<Bytes, PayloadError>>,
782    ) {
783        let (tx, rx) = mpsc::channel();
784
785        (tx, rx.map(|res| res.map_err(|_| panic!())))
786    }
787    // Stream that returns from a Bytes, one char at a time and Pending every other poll()
788    struct SlowStream {
789        bytes: Bytes,
790        pos: usize,
791        ready: bool,
792    }
793
794    impl SlowStream {
795        fn new(bytes: Bytes) -> SlowStream {
796            SlowStream { bytes, pos: 0, ready: false }
797        }
798    }
799
800    impl Stream for SlowStream {
801        type Item = Result<Bytes, PayloadError>;
802
803        fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
804            let this = self.get_mut();
805            if !this.ready {
806                this.ready = true;
807                cx.waker().wake_by_ref();
808                return Poll::Pending;
809            }
810            if this.pos == this.bytes.len() {
811                return Poll::Ready(None);
812            }
813            let res = Poll::Ready(Some(Ok(this.bytes.slice(this.pos..(this.pos + 1)))));
814            this.pos += 1;
815            this.ready = false;
816            res
817        }
818    }
819
820    fn create_simple_request_with_header() -> (Bytes, HeaderMap) {
821        let bytes = Bytes::from(
822            "testasdadsad\r\n\
823             --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
824             Content-Disposition: form-data; name=\"file\"; filename=\"fn.txt\"\r\n\
825             Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\
826             test\r\n\
827             --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
828             Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\
829             data\r\n\
830             --abbc761f78ff4d7cb7573b5a23f96ef0--\r\n",
831        );
832        let mut headers = HeaderMap::new();
833        headers.insert(
834            header::CONTENT_TYPE,
835            header::HeaderValue::from_static(
836                "multipart/mixed; boundary=\"abbc761f78ff4d7cb7573b5a23f96ef0\"",
837            ),
838        );
839        (bytes, headers)
840    }
841
842    #[ntex::test]
843    async fn test_multipart_no_end_crlf() {
844        let (sender, payload) = create_stream();
845        let (mut bytes, headers) = create_simple_request_with_header();
846        let bytes_stripped = bytes.split_to(bytes.len()); // strip crlf
847
848        sender.send(Ok(bytes_stripped)).unwrap();
849        drop(sender); // eof
850
851        let mut multipart = Multipart::new(&headers, payload);
852
853        match multipart.next().await.unwrap() {
854            Ok(_) => (),
855            _ => unreachable!(),
856        }
857
858        match multipart.next().await.unwrap() {
859            Ok(_) => (),
860            _ => unreachable!(),
861        }
862
863        match multipart.next().await {
864            None => (),
865            _ => unreachable!(),
866        }
867    }
868
869    #[ntex::test]
870    async fn test_multipart() {
871        let (sender, payload) = create_stream();
872        let (bytes, headers) = create_simple_request_with_header();
873
874        sender.send(Ok(bytes)).unwrap();
875
876        let mut multipart = Multipart::new(&headers, payload);
877        match multipart.next().await {
878            Some(Ok(mut field)) => {
879                assert_eq!(field.content_type().type_(), mime::TEXT);
880                assert_eq!(field.content_type().subtype(), mime::PLAIN);
881
882                match field.next().await.unwrap() {
883                    Ok(chunk) => assert_eq!(chunk, "test"),
884                    _ => unreachable!(),
885                }
886                match field.next().await {
887                    None => (),
888                    _ => unreachable!(),
889                }
890            }
891            _ => unreachable!(),
892        }
893
894        match multipart.next().await.unwrap() {
895            Ok(mut field) => {
896                assert_eq!(field.content_type().type_(), mime::TEXT);
897                assert_eq!(field.content_type().subtype(), mime::PLAIN);
898
899                match field.next().await {
900                    Some(Ok(chunk)) => assert_eq!(chunk, "data"),
901                    _ => unreachable!(),
902                }
903                match field.next().await {
904                    None => (),
905                    _ => unreachable!(),
906                }
907            }
908            _ => unreachable!(),
909        }
910
911        match multipart.next().await {
912            None => (),
913            _ => unreachable!(),
914        }
915    }
916
917    // Loops, collecting all bytes until end-of-field
918    async fn get_whole_field(field: &mut Field) -> BytesMut {
919        let mut b = BytesMut::new();
920        loop {
921            match field.next().await {
922                Some(Ok(chunk)) => b.extend_from_slice(&chunk),
923                None => return b,
924                _ => unreachable!(),
925            }
926        }
927    }
928
929    #[ntex::test]
930    async fn test_stream() {
931        let (bytes, headers) = create_simple_request_with_header();
932        let payload = SlowStream::new(bytes);
933
934        let mut multipart = Multipart::new(&headers, payload);
935        match multipart.next().await.unwrap() {
936            Ok(mut field) => {
937                assert_eq!(field.content_type().type_(), mime::TEXT);
938                assert_eq!(field.content_type().subtype(), mime::PLAIN);
939
940                assert_eq!(get_whole_field(&mut field).await, "test");
941            }
942            _ => unreachable!(),
943        }
944
945        match multipart.next().await {
946            Some(Ok(mut field)) => {
947                assert_eq!(field.content_type().type_(), mime::TEXT);
948                assert_eq!(field.content_type().subtype(), mime::PLAIN);
949
950                assert_eq!(get_whole_field(&mut field).await, "data");
951            }
952            _ => unreachable!(),
953        }
954
955        match multipart.next().await {
956            None => (),
957            _ => unreachable!(),
958        }
959    }
960
961    // #[ntex::test]
962    // async fn test_basic() {
963    //     let (_sender, payload) = bstream::channel();
964    //     let mut payload = PayloadBuffer::new(payload);
965
966    //     assert_eq!(payload.buf.len(), 0);
967    //     assert!(lazy(|cx| payload.poll_stream(cx)).await.is_err());
968    //     assert_eq!(None, payload.read_max(1).unwrap());
969    // }
970
971    #[ntex::test]
972    async fn test_eof() {
973        let (sender, payload) = bstream::channel();
974        let mut payload = PayloadBuffer::new(payload);
975
976        assert_eq!(None, payload.read_max(4).unwrap());
977        sender.feed_data(Bytes::from("data"));
978        sender.feed_eof();
979        lazy(|cx| payload.poll_stream(cx)).await.unwrap();
980
981        assert_eq!(Some(Bytes::from("data")), payload.read_max(4).unwrap());
982        assert_eq!(payload.buf.len(), 0);
983        assert!(payload.read_max(1).is_err());
984        assert!(payload.eof);
985    }
986
987    #[ntex::test]
988    async fn test_err() {
989        let (sender, payload) = bstream::channel();
990        let mut payload = PayloadBuffer::new(payload);
991        assert_eq!(None, payload.read_max(1).unwrap());
992        sender.set_error(PayloadError::Incomplete(None));
993        lazy(|cx| payload.poll_stream(cx)).await.err().unwrap();
994    }
995
996    #[ntex::test]
997    async fn test_readmax() {
998        let (sender, payload) = bstream::channel();
999        let mut payload = PayloadBuffer::new(payload);
1000
1001        sender.feed_data(Bytes::from("line1"));
1002        sender.feed_data(Bytes::from("line2"));
1003        lazy(|cx| payload.poll_stream(cx)).await.unwrap();
1004        assert_eq!(payload.buf.len(), 10);
1005
1006        assert_eq!(Some(Bytes::from("line1")), payload.read_max(5).unwrap());
1007        assert_eq!(payload.buf.len(), 5);
1008
1009        assert_eq!(Some(Bytes::from("line2")), payload.read_max(5).unwrap());
1010        assert_eq!(payload.buf.len(), 0);
1011    }
1012
1013    #[ntex::test]
1014    async fn test_readexactly() {
1015        let (sender, payload) = bstream::channel();
1016        let mut payload = PayloadBuffer::new(payload);
1017
1018        assert_eq!(None, payload.read_exact(2));
1019
1020        sender.feed_data(Bytes::from("line1"));
1021        sender.feed_data(Bytes::from("line2"));
1022        lazy(|cx| payload.poll_stream(cx)).await.unwrap();
1023
1024        assert_eq!(Some(Bytes::from_static(b"li")), payload.read_exact(2));
1025        assert_eq!(payload.buf.len(), 8);
1026
1027        assert_eq!(Some(Bytes::from_static(b"ne1l")), payload.read_exact(4));
1028        assert_eq!(payload.buf.len(), 4);
1029    }
1030
1031    #[ntex::test]
1032    async fn test_readuntil() {
1033        let (sender, payload) = bstream::channel();
1034        let mut payload = PayloadBuffer::new(payload);
1035
1036        assert_eq!(None, payload.read_until(b"ne").unwrap());
1037
1038        sender.feed_data(Bytes::from("line1"));
1039        sender.feed_data(Bytes::from("line2"));
1040        lazy(|cx| payload.poll_stream(cx)).await.unwrap();
1041
1042        assert_eq!(Some(Bytes::from("line")), payload.read_until(b"ne").unwrap());
1043        assert_eq!(payload.buf.len(), 6);
1044
1045        assert_eq!(Some(Bytes::from("1line2")), payload.read_until(b"2").unwrap());
1046        assert_eq!(payload.buf.len(), 0);
1047    }
1048}