Skip to main content

http_multipart/
lib.rs

1#![forbid(unsafe_code)]
2
3mod content_disposition;
4mod error;
5mod field;
6mod header;
7
8pub use self::{error::MultipartError, field::Field};
9
10use core::{future::poll_fn, pin::Pin};
11
12use bytes::{Buf, BytesMut};
13use field::FieldDecoder;
14use futures_core::stream::Stream;
15use http::{header::HeaderMap, Method, Request};
16use memchr::memmem;
17use pin_project_lite::pin_project;
18
19use self::{content_disposition::ContentDisposition, error::PayloadError};
20
21/// Multipart protocol using high level API that operate over [Stream] trait.
22///
23/// `http` crate is used as Http request input. It provides necessary header information needed for
24/// [Multipart].
25///
26/// # Examples:
27/// ```rust
28/// use std::{convert::Infallible, error, pin::pin};
29///
30/// use futures_core::stream::Stream;
31/// use http::Request;
32///
33/// async fn handle<B>(req: Request<B>) -> Result<(), Box<dyn error::Error + Send + Sync>>
34/// where
35///     B: Stream<Item = Result<Vec<u8>, Infallible>>
36/// {
37///     // destruct request type.
38///     let (parts, body) = req.into_parts();
39///     let req = Request::from_parts(parts, ());
40///
41///     // prepare multipart handling.
42///     let mut multipart = http_multipart::multipart(&req, body)?;
43///
44///     // pin multipart and start await on the request body.
45///     let mut multipart = pin!(multipart);
46///
47///     // try async iterate through fields of the multipart.
48///     while let Some(mut field) = multipart.try_next().await? {
49///         // try async iterate through single field's bytes data.
50///         while let Some(chunk) = field.try_next().await? {
51///             // handle bytes data.
52///         }
53///     }
54///
55///     Ok(())
56/// }
57/// ```
58pub fn multipart<Ext, B, T, E>(req: &Request<Ext>, body: B) -> Result<Multipart<B>, MultipartError>
59where
60    B: Stream<Item = Result<T, E>>,
61    T: AsRef<[u8]>,
62    E: Into<PayloadError>,
63{
64    multipart_with_config(req, body, Config::default())
65}
66
67/// [multipart] with [Config] that used for customize behavior of [Multipart].
68pub fn multipart_with_config<Ext, B, T, E>(
69    req: &Request<Ext>,
70    body: B,
71    config: Config,
72) -> Result<Multipart<B>, MultipartError>
73where
74    B: Stream<Item = Result<T, E>>,
75    T: AsRef<[u8]>,
76    E: Into<PayloadError>,
77{
78    if req.method() != Method::POST {
79        return Err(MultipartError::NoPostMethod);
80    }
81
82    let boundary = header::boundary(req.headers())?;
83
84    Ok(Multipart {
85        stream: body,
86        buf: BytesMut::new(),
87        boundary: boundary.into(),
88        headers: HeaderMap::new(),
89        pending_field: false,
90        config,
91    })
92}
93
/// Configuration for the [Multipart] type.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct Config {
    /// Limit of the internal buffer size, in bytes.
    ///
    /// The internal buffer caches bytes that overlap a boundary and the headers of a field while
    /// they are parsed. A buffer holding more than `buf_limit` bytes is treated as overflow.
    /// The limit is checked before more body data is read, so the buffer may briefly hold up to
    /// one body chunk beyond it.
    ///
    /// Defaults to 1 MiB (`1024 * 1024` bytes).
    pub buf_limit: usize,
}

impl Default for Config {
    fn default() -> Self {
        Self { buf_limit: 1024 * 1024 }
    }
}
108
pin_project! {
    /// Multipart body parser created by [multipart] or [multipart_with_config].
    ///
    /// Fields are yielded one at a time through `try_next`; see the example on [multipart].
    pub struct Multipart<S> {
        // request body stream the multipart data is read from. pinned because it is polled in place.
        #[pin]
        stream: S,
        // bytes read from `stream` that have not been consumed yet (boundary lines, field
        // headers and field content not handed out yet).
        buf: BytesMut,
        // boundary value extracted from the request headers, without the leading "--".
        boundary: Box<[u8]>,
        // headers of the current field. one map is reused for every field.
        headers: HeaderMap,
        // set when a field has been handed out; its leftover content must be skipped before the
        // next boundary line can be parsed.
        pending_field: bool,
        // user supplied limits.
        config: Config
    }
}
120
/// `--`: prefix of every boundary line and suffix of the closing boundary line.
const DOUBLE_HYPHEN: &[u8; 2] = &[b'-', b'-'];
/// `\n`: terminator searched for when locating a boundary line.
const LF: &[u8; 1] = &[b'\n'];
/// `\r\n\r\n`: blank line that ends a field's header section.
const DOUBLE_CR_LF: &[u8; 4] = &[b'\r', b'\n', b'\r', b'\n'];
/// `\r\n--`: CRLF followed by the boundary prefix. Not used in this file; presumably used by the
/// field decoder to find the end of field content (verify in field.rs).
const FIELD_DELIMITER: &[u8; 4] = &[b'\r', b'\n', b'-', b'-'];
125
126impl<S, T, E> Multipart<S>
127where
128    S: Stream<Item = Result<T, E>>,
129    T: AsRef<[u8]>,
130    E: Into<PayloadError>,
131{
132    // take in &mut Pin<&mut Self> so Field can borrow it as Pin<&mut Multipart>.
133    // this avoid another explicit stack pin when operating on Field type.
134    pub async fn try_next<'s>(self: &'s mut Pin<&mut Self>) -> Result<Option<Field<'s, S>>, MultipartError> {
135        let boundary_len = self.boundary.len();
136
137        if self.pending_field {
138            self.as_mut().consume_pending_field().await?;
139        }
140
141        loop {
142            let this = self.as_mut().project();
143            if let Some(idx) = memmem::find(this.buf, LF) {
144                // backtrack one byte to exclude CR
145                let slice = match idx.checked_sub(1) {
146                    Some(idx) => &this.buf[..idx],
147                    // no CR before LF.
148                    None => return Err(MultipartError::Boundary),
149                };
150
151                match slice.len() {
152                    // empty line. skip.
153                    0 => {
154                        // forward one byte to include LF and remove the empty line.
155                        this.buf.advance(idx + 1);
156                        continue;
157                    }
158                    // not enough data to operate.
159                    len if len < (boundary_len + 2) => {}
160                    // not boundary.
161                    _ if &slice[..2] != DOUBLE_HYPHEN => return Err(MultipartError::Boundary),
162                    // non last boundary
163                    _ if this.boundary.as_ref().eq(&slice[2..]) => {
164                        // forward one byte to include CRLF and remove the boundary line.
165                        this.buf.advance(idx + 1);
166
167                        let field = self.as_mut().parse_field().await?;
168                        return Ok(Some(field));
169                    }
170                    // last boundary.
171                    len if len == (boundary_len + 4) => {
172                        let at = boundary_len + 2;
173                        // TODO: add log for ill formed ending boundary?;
174                        let _ = this.boundary.as_ref().eq(&slice[2..at]) && &slice[at..] == DOUBLE_HYPHEN;
175                        return Ok(None);
176                    }
177                    // boundary line exceed expected length.
178                    _ => return Err(MultipartError::Boundary),
179                }
180            }
181
182            if self.buf_overflow() {
183                return Err(MultipartError::BufferOverflow);
184            }
185
186            self.as_mut().try_read_stream_to_buf().await?;
187        }
188    }
189
190    async fn parse_field(mut self: Pin<&mut Self>) -> Result<Field<'_, S>, MultipartError> {
191        loop {
192            let this = self.as_mut().project();
193
194            if let Some(idx) = memmem::find(this.buf, DOUBLE_CR_LF) {
195                let slice = &this.buf[..idx + 4];
196
197                header::parse_headers(this.headers, slice)?;
198                this.buf.advance(slice.len());
199
200                let cp = ContentDisposition::try_from_header(this.headers)?;
201
202                header::check_headers(this.headers)?;
203
204                let length = header::content_length_opt(this.headers)?;
205
206                *this.pending_field = true;
207
208                return Ok(Field::new(length, cp, self));
209            }
210
211            if self.buf_overflow() {
212                return Err(MultipartError::Header(httparse::Error::TooManyHeaders));
213            }
214
215            self.as_mut().try_read_stream_to_buf().await?;
216        }
217    }
218
219    #[cold]
220    #[inline(never)]
221    async fn consume_pending_field(mut self: Pin<&mut Self>) -> Result<(), MultipartError> {
222        let mut field_ty = FieldDecoder::default();
223
224        loop {
225            let this = self.as_mut().project();
226            if let Some(idx) = field_ty.try_find_split_idx(this.buf, this.boundary)? {
227                this.buf.advance(idx);
228            }
229            if matches!(field_ty, FieldDecoder::StreamEnd) {
230                *this.pending_field = false;
231                return Ok(());
232            }
233            self.as_mut().try_read_stream_to_buf().await?;
234        }
235    }
236
237    async fn try_read_stream_to_buf(mut self: Pin<&mut Self>) -> Result<(), MultipartError> {
238        let bytes = self.as_mut().try_read_stream().await?;
239        self.project().buf.extend_from_slice(bytes.as_ref());
240        Ok(())
241    }
242
243    async fn try_read_stream(mut self: Pin<&mut Self>) -> Result<T, MultipartError> {
244        match poll_fn(move |cx| self.as_mut().project().stream.poll_next(cx)).await {
245            Some(Ok(bytes)) => Ok(bytes),
246            Some(Err(e)) => Err(MultipartError::Payload(e.into())),
247            None => Err(MultipartError::UnexpectedEof),
248        }
249    }
250
251    pub(crate) fn buf_overflow(&self) -> bool {
252        self.buf.len() > self.config.buf_limit
253    }
254}
255
#[cfg(test)]
mod test {
    use std::{convert::Infallible, pin::pin};

    use bytes::Bytes;
    use futures_util::FutureExt;
    use http::header::{HeaderValue, CONTENT_DISPOSITION, CONTENT_LENGTH, CONTENT_TYPE};

    use super::*;

    // Body stream yielding the whole payload as a single chunk.
    fn once_body(b: impl Into<Bytes>) -> impl Stream<Item = Result<Bytes, Infallible>> {
        futures_util::stream::once(async { Ok(b.into()) })
    }

    // Request::new leaves the method at its default (not POST), which must be rejected.
    #[test]
    fn method() {
        let req = Request::new(());
        let body = once_body(Bytes::new());
        let err = multipart(&req, body).err();
        assert!(matches!(err, Some(MultipartError::NoPostMethod)));
    }

    // Four fields, with and without Content-Length. Fields are read fully, read partially
    // (the third one is dropped unread), and the closing boundary is reached twice.
    #[test]
    fn basic() {
        let body = b"\
            --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
            Content-Disposition: form-data; name=\"file\"; filename=\"foo.txt\"\r\n\
            Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\
            test\r\n\
            --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
            Content-Disposition: form-data; name=\"file\"; filename=\"bar.txt\"\r\n\
            Content-Type: text/plain\r\n\r\n\
            testdata\r\n\
            --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
            Content-Disposition: form-data; name=\"file\"; filename=\"bar.txt\"\r\n\
            Content-Type: text/plain\r\n\r\n\
            testdata\r\n\
            --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
            Content-Disposition: form-data; name=\"file\"; filename=\"bar.txt\"\r\n\
            Content-Type: text/plain\r\nContent-Length: 9\r\n\r\n\
            testdata2\r\n\
            --abbc761f78ff4d7cb7573b5a23f96ef0--\r\n\
            ";

        let mut req = Request::new(());
        *req.method_mut() = Method::POST;
        req.headers_mut().insert(
            CONTENT_TYPE,
            HeaderValue::from_static("multipart/mixed; boundary=abbc761f78ff4d7cb7573b5a23f96ef0"),
        );

        let body = once_body(Bytes::copy_from_slice(body));

        let multipart = multipart(&req, body).unwrap();

        let mut multipart = pin!(multipart);

        // first field: has Content-Length.
        {
            let mut field = multipart.try_next().now_or_never().unwrap().unwrap().unwrap();

            assert_eq!(
                field.headers().get(CONTENT_DISPOSITION).unwrap(),
                HeaderValue::from_static("form-data; name=\"file\"; filename=\"foo.txt\"")
            );
            assert_eq!(field.name().unwrap(), "file");
            assert_eq!(field.file_name().unwrap(), "foo.txt");
            assert_eq!(
                field.headers().get(CONTENT_TYPE).unwrap(),
                HeaderValue::from_static("text/plain; charset=utf-8")
            );
            assert_eq!(
                field.headers().get(CONTENT_LENGTH).unwrap(),
                HeaderValue::from_static("4")
            );
            assert_eq!(
                field.try_next().now_or_never().unwrap().unwrap().unwrap().chunk(),
                b"test"
            );
            assert!(field.try_next().now_or_never().unwrap().unwrap().is_none());
        }

        // second field: no Content-Length, so the previous field's header must not leak in.
        {
            let mut field = multipart.try_next().now_or_never().unwrap().unwrap().unwrap();

            assert_eq!(
                field.headers().get(CONTENT_DISPOSITION).unwrap(),
                HeaderValue::from_static("form-data; name=\"file\"; filename=\"bar.txt\"")
            );
            assert_eq!(field.name().unwrap(), "file");
            assert_eq!(field.file_name().unwrap(), "bar.txt");
            assert_eq!(
                field.headers().get(CONTENT_TYPE).unwrap(),
                HeaderValue::from_static("text/plain")
            );
            assert!(field.headers().get(CONTENT_LENGTH).is_none());
            assert_eq!(
                field.try_next().now_or_never().unwrap().unwrap().unwrap().chunk(),
                b"testdata"
            );
            assert!(field.try_next().now_or_never().unwrap().unwrap().is_none());
        }

        // test drop field without consuming.
        multipart.try_next().now_or_never().unwrap().unwrap().unwrap();

        // fourth field: reached only if the dropped third field was skipped correctly.
        {
            let mut field = multipart.try_next().now_or_never().unwrap().unwrap().unwrap();

            assert_eq!(
                field.headers().get(CONTENT_DISPOSITION).unwrap(),
                HeaderValue::from_static("form-data; name=\"file\"; filename=\"bar.txt\"")
            );
            assert_eq!(field.name().unwrap(), "file");
            assert_eq!(field.file_name().unwrap(), "bar.txt");
            assert_eq!(
                field.headers().get(CONTENT_TYPE).unwrap(),
                HeaderValue::from_static("text/plain")
            );
            assert_eq!(
                field.headers().get(CONTENT_LENGTH).unwrap(),
                HeaderValue::from_static("9")
            );
            assert_eq!(
                field.try_next().now_or_never().unwrap().unwrap().unwrap().chunk(),
                b"testdata2"
            );
            assert!(field.try_next().now_or_never().unwrap().unwrap().is_none());
        }

        // closing boundary: stays at the end, so repeated calls keep yielding None.
        assert!(multipart.try_next().now_or_never().unwrap().unwrap().is_none());
        assert!(multipart.try_next().now_or_never().unwrap().unwrap().is_none());
    }

    // Field header section never terminates and exceeds the buffer limit.
    #[test]
    fn field_header_overflow() {
        let body = b"\
            --12345\r\n\
            Content-Disposition: form-data; name=\"file\"; filename=\"foo.txt\"\r\n\
            Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4";

        let mut req = Request::new(());
        *req.method_mut() = Method::POST;
        req.headers_mut().insert(
            CONTENT_TYPE,
            HeaderValue::from_static("multipart/mixed; boundary=12345"),
        );

        let body = once_body(Bytes::copy_from_slice(body));

        // limit is set to 7 so the first boundary can be parsed.
        let multipart = multipart_with_config(&req, body, Config { buf_limit: 7 }).unwrap();

        let mut multipart = pin!(multipart);

        assert!(matches!(
            multipart.try_next().now_or_never().unwrap().err().unwrap(),
            MultipartError::Header(httparse::Error::TooManyHeaders)
        ));
    }

    // Regression: boundary() in header.rs computed `end` as a subslice-relative
    // index rather than an absolute one, so &header[start..end] panicked when
    // Content-Type contained extra parameters after the boundary value.
    #[test]
    fn boundary_with_trailing_content_type_param() {
        let mut req = Request::new(());
        *req.method_mut() = Method::POST;
        req.headers_mut().insert(
            CONTENT_TYPE,
            HeaderValue::from_static("multipart/form-data; boundary=abc; charset=utf-8"),
        );
        let body = once_body(Bytes::new());
        // Should extract boundary "abc" without panicking.
        let result = multipart(&req, body);
        assert!(result.is_ok());
    }

    #[test]
    fn boundary_quoted() {
        // RFC 2045 §5.1: boundary may be a quoted-string.
        let mut req = Request::new(());
        *req.method_mut() = Method::POST;
        req.headers_mut().insert(
            CONTENT_TYPE,
            HeaderValue::from_static("multipart/form-data; boundary=\"abc def\""),
        );
        let body = once_body(Bytes::new());
        // Should extract "abc def" (quotes stripped), not fail.
        assert!(multipart(&req, body).is_ok());
    }

    #[test]
    fn boundary_unquoted_whitespace() {
        // Unquoted boundary values may carry surrounding OWS from header folding.
        let mut req = Request::new(());
        *req.method_mut() = Method::POST;
        req.headers_mut().insert(
            CONTENT_TYPE,
            HeaderValue::from_static("multipart/form-data; boundary= abc "),
        );
        let body = once_body(Bytes::new());
        // Should extract "abc" (whitespace trimmed), not fail.
        assert!(multipart(&req, body).is_ok());
    }

    #[test]
    fn boundary_quoted_with_leading_whitespace() {
        // OWS before the opening quote must also be accepted.
        let mut req = Request::new(());
        *req.method_mut() = Method::POST;
        req.headers_mut().insert(
            CONTENT_TYPE,
            HeaderValue::from_static("multipart/form-data; boundary= \"abc def\""),
        );
        let body = once_body(Bytes::new());
        assert!(multipart(&req, body).is_ok());
    }

    #[test]
    fn boundary_quoted_unclosed() {
        // An unclosed quote is malformed and must return an error.
        let mut req = Request::new(());
        *req.method_mut() = Method::POST;
        req.headers_mut().insert(
            CONTENT_TYPE,
            HeaderValue::from_static("multipart/form-data; boundary=\"unclosed"),
        );
        let body = once_body(Bytes::new());
        assert!(matches!(multipart(&req, body), Err(MultipartError::Boundary)));
    }

    // Skipping an unread field must still find the next boundary when its leading "--" is
    // split across two body chunks.
    #[test]
    fn consume_field_boundary_split_across_chunks() {
        // Full logical body (boundary = "abc"):
        //   --abc\r\n<headers-f1>\r\n\r\nsomedata\r\n--abc\r\n<headers-f2>\r\n\r\nhello\r\n--abc--\r\n
        //
        // Split just before the second '-' of "--abc":
        let chunk1 = Bytes::from_static(b"--abc\r\nContent-Disposition: form-data; name=\"f1\"\r\n\r\nsomedata\r\n-");
        let chunk2 =
            Bytes::from_static(b"-abc\r\nContent-Disposition: form-data; name=\"f2\"\r\n\r\nhello\r\n--abc--\r\n");

        let mut req = Request::new(());
        *req.method_mut() = Method::POST;
        req.headers_mut()
            .insert(CONTENT_TYPE, HeaderValue::from_static("multipart/mixed; boundary=abc"));

        let body = futures_util::stream::iter(vec![Ok::<_, Infallible>(chunk1), Ok(chunk2)]);
        let multipart = multipart(&req, body).unwrap();
        let mut multipart = pin!(multipart);

        // Retrieve field1 then drop it without consuming any bytes.
        // The next try_next() call will invoke consume_pending_field(), which
        // is where the split-boundary bug manifests.
        {
            let _field1 = multipart.try_next().now_or_never().unwrap().unwrap().unwrap();
        }

        // field2 must still be accessible.
        let field2 = multipart.try_next().now_or_never().unwrap().unwrap();
        assert!(
            field2.is_some(),
            "field2 was skipped because '--' was split across two stream chunks"
        );
    }

    // Incomplete boundary line (no LF) larger than the buffer limit.
    #[test]
    fn boundary_overflow() {
        let body = b"--123456";

        let mut req = Request::new(());
        *req.method_mut() = Method::POST;
        req.headers_mut().insert(
            CONTENT_TYPE,
            HeaderValue::from_static("multipart/mixed; boundary=12345"),
        );

        let body = once_body(Bytes::copy_from_slice(body));

        // limit is set to 7 so the first boundary can not be parsed.
        let multipart = multipart_with_config(&req, body, Config { buf_limit: 7 }).unwrap();

        let mut multipart = pin!(multipart);

        assert!(matches!(
            multipart.try_next().now_or_never().unwrap().err().unwrap(),
            MultipartError::BufferOverflow
        ));
    }
}