// http_multipart/lib.rs

1#![forbid(unsafe_code)]
2
3mod content_disposition;
4mod error;
5mod field;
6mod header;
7
8pub use self::{error::MultipartError, field::Field};
9
10use core::{future::poll_fn, pin::Pin};
11
12use bytes::{Buf, BytesMut};
13use field::FieldDecoder;
14use futures_core::stream::Stream;
15use http::{header::HeaderMap, Method, Request};
16use memchr::memmem;
17use pin_project_lite::pin_project;
18
19use self::{content_disposition::ContentDisposition, error::PayloadError};
20
21/// Multipart protocol using high level API that operate over [Stream] trait.
22///
23/// `http` crate is used as Http request input. It provides necessary header information needed for
24/// [Multipart].
25///
26/// # Examples:
27/// ```rust
28/// use std::{convert::Infallible, error, pin::pin};
29///
30/// use futures_core::stream::Stream;
31/// use http::Request;
32///
33/// async fn handle<B>(req: Request<B>) -> Result<(), Box<dyn error::Error + Send + Sync>>
34/// where
35///     B: Stream<Item = Result<Vec<u8>, Infallible>>
36/// {
37///     // destruct request type.
38///     let (parts, body) = req.into_parts();
39///     let req = Request::from_parts(parts, ());
40///
41///     // prepare multipart handling.
42///     let mut multipart = http_multipart::multipart(&req, body)?;
43///
44///     // pin multipart and start await on the request body.
45///     let mut multipart = pin!(multipart);
46///
47///     // try async iterate through fields of the multipart.
48///     while let Some(mut field) = multipart.try_next().await? {
49///         // try async iterate through single field's bytes data.
50///         while let Some(chunk) = field.try_next().await? {
51///             // handle bytes data.
52///         }
53///     }
54///
55///     Ok(())
56/// }
57/// ```
58pub fn multipart<Ext, B, T, E>(req: &Request<Ext>, body: B) -> Result<Multipart<B>, MultipartError>
59where
60    B: Stream<Item = Result<T, E>>,
61    T: AsRef<[u8]>,
62    E: Into<PayloadError>,
63{
64    multipart_with_config(req, body, Config::default())
65}
66
67/// [multipart] with [Config] that used for customize behavior of [Multipart].
68pub fn multipart_with_config<Ext, B, T, E>(
69    req: &Request<Ext>,
70    body: B,
71    config: Config,
72) -> Result<Multipart<B>, MultipartError>
73where
74    B: Stream<Item = Result<T, E>>,
75    T: AsRef<[u8]>,
76    E: Into<PayloadError>,
77{
78    if req.method() != Method::POST {
79        return Err(MultipartError::NoPostMethod);
80    }
81
82    let boundary = header::boundary(req.headers())?;
83
84    Ok(Multipart {
85        stream: body,
86        buf: BytesMut::new(),
87        boundary: boundary.into(),
88        headers: HeaderMap::new(),
89        pending_field: false,
90        config,
91    })
92}
93
/// Configuration for [Multipart] type
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct Config {
    /// limit the max size of internal buffer.
    /// internal buffer is used to cache overlapped chunks around boundary and field headers.
    /// Default to 1MB
    pub buf_limit: usize,
}

impl Default for Config {
    /// Default configuration: `buf_limit` of 1MB (`1024 * 1024` bytes).
    fn default() -> Self {
        Self { buf_limit: 1024 * 1024 }
    }
}
108
109pin_project! {
110    pub struct Multipart<S> {
111        #[pin]
112        stream: S,
113        buf: BytesMut,
114        boundary: Box<[u8]>,
115        headers: HeaderMap,
116        pending_field: bool,
117        config: Config
118    }
119}
120
/// Prefix of every boundary delimiter line and suffix of the closing one.
const DOUBLE_HYPHEN: &[u8; 2] = b"--";
/// Line feed; searched for to locate the end of a delimiter line.
const LF: &[u8; 1] = b"\n";
/// Blank line (CRLF CRLF) that terminates the header section of a field.
const DOUBLE_CR_LF: &[u8; 4] = b"\r\n\r\n";
124
125impl<S, T, E> Multipart<S>
126where
127    S: Stream<Item = Result<T, E>>,
128    T: AsRef<[u8]>,
129    E: Into<PayloadError>,
130{
131    // take in &mut Pin<&mut Self> so Field can borrow it as Pin<&mut Multipart>.
132    // this avoid another explicit stack pin when operating on Field type.
133    pub async fn try_next<'s>(self: &'s mut Pin<&mut Self>) -> Result<Option<Field<'s, S>>, MultipartError> {
134        let boundary_len = self.boundary.len();
135
136        if self.pending_field {
137            self.as_mut().consume_pending_field().await?;
138        }
139
140        loop {
141            let this = self.as_mut().project();
142            if let Some(idx) = memmem::find(this.buf, LF) {
143                // backtrack one byte to exclude CR
144                let slice = match idx.checked_sub(1) {
145                    Some(idx) => &this.buf[..idx],
146                    // no CR before LF.
147                    None => return Err(MultipartError::Boundary),
148                };
149
150                match slice.len() {
151                    // empty line. skip.
152                    0 => {
153                        // forward one byte to include LF and remove the empty line.
154                        this.buf.advance(idx + 1);
155                        continue;
156                    }
157                    // not enough data to operate.
158                    len if len < (boundary_len + 2) => {}
159                    // not boundary.
160                    _ if &slice[..2] != DOUBLE_HYPHEN => return Err(MultipartError::Boundary),
161                    // non last boundary
162                    _ if this.boundary.as_ref().eq(&slice[2..]) => {
163                        // forward one byte to include CRLF and remove the boundary line.
164                        this.buf.advance(idx + 1);
165
166                        let field = self.as_mut().parse_field().await?;
167                        return Ok(Some(field));
168                    }
169                    // last boundary.
170                    len if len == (boundary_len + 4) => {
171                        let at = boundary_len + 2;
172                        // TODO: add log for ill formed ending boundary?;
173                        let _ = this.boundary.as_ref().eq(&slice[2..at]) && &slice[at..] == DOUBLE_HYPHEN;
174                        return Ok(None);
175                    }
176                    // boundary line exceed expected length.
177                    _ => return Err(MultipartError::Boundary),
178                }
179            }
180
181            if self.buf_overflow() {
182                return Err(MultipartError::BufferOverflow);
183            }
184
185            self.as_mut().try_read_stream_to_buf().await?;
186        }
187    }
188
189    async fn parse_field(mut self: Pin<&mut Self>) -> Result<Field<'_, S>, MultipartError> {
190        loop {
191            let this = self.as_mut().project();
192
193            if let Some(idx) = memmem::find(this.buf, DOUBLE_CR_LF) {
194                let slice = &this.buf[..idx + 4];
195
196                header::parse_headers(this.headers, slice)?;
197                this.buf.advance(slice.len());
198
199                let cp = ContentDisposition::try_from_header(this.headers)?;
200
201                header::check_headers(this.headers)?;
202
203                let length = header::content_length_opt(this.headers)?;
204
205                *this.pending_field = true;
206
207                return Ok(Field::new(length, cp, self));
208            }
209
210            if self.buf_overflow() {
211                return Err(MultipartError::Header(httparse::Error::TooManyHeaders));
212            }
213
214            self.as_mut().try_read_stream_to_buf().await?;
215        }
216    }
217
218    #[cold]
219    #[inline(never)]
220    async fn consume_pending_field(mut self: Pin<&mut Self>) -> Result<(), MultipartError> {
221        let mut field_ty = FieldDecoder::default();
222
223        loop {
224            let this = self.as_mut().project();
225            if let Some(idx) = field_ty.try_find_split_idx(this.buf, this.boundary)? {
226                this.buf.advance(idx);
227            }
228            if matches!(field_ty, FieldDecoder::StreamEnd) {
229                *this.pending_field = false;
230                return Ok(());
231            }
232            self.as_mut().try_read_stream_to_buf().await?;
233        }
234    }
235
236    async fn try_read_stream_to_buf(mut self: Pin<&mut Self>) -> Result<(), MultipartError> {
237        let bytes = self.as_mut().try_read_stream().await?;
238        self.project().buf.extend_from_slice(bytes.as_ref());
239        Ok(())
240    }
241
242    async fn try_read_stream(mut self: Pin<&mut Self>) -> Result<T, MultipartError> {
243        match poll_fn(move |cx| self.as_mut().project().stream.poll_next(cx)).await {
244            Some(Ok(bytes)) => Ok(bytes),
245            Some(Err(e)) => Err(MultipartError::Payload(e.into())),
246            None => Err(MultipartError::UnexpectedEof),
247        }
248    }
249
250    pub(crate) fn buf_overflow(&self) -> bool {
251        self.buf.len() > self.config.buf_limit
252    }
253}
254
#[cfg(test)]
mod test {
    use std::{convert::Infallible, pin::pin};

    use bytes::Bytes;
    use futures_util::FutureExt;
    use http::header::{HeaderValue, CONTENT_DISPOSITION, CONTENT_LENGTH, CONTENT_TYPE};

    use super::*;

    /// Body stream that yields `b` as a single chunk and then ends.
    fn once_body(b: impl Into<Bytes>) -> impl Stream<Item = Result<Bytes, Infallible>> {
        futures_util::stream::once(async { Ok(b.into()) })
    }

    /// Bodiless POST request carrying `content_type` as its Content-Type header.
    fn post_request(content_type: &'static str) -> Request<()> {
        let mut req = Request::new(());
        *req.method_mut() = Method::POST;
        req.headers_mut()
            .insert(CONTENT_TYPE, HeaderValue::from_static(content_type));
        req
    }

    // a freshly constructed request is not POST and must be rejected.
    #[test]
    fn method() {
        let req = Request::new(());
        let body = once_body(Bytes::new());
        let err = multipart(&req, body).err();
        assert!(matches!(err, Some(MultipartError::NoPostMethod)));
    }

    // four fields, with and without Content-Length, one of them dropped unread.
    #[test]
    fn basic() {
        let body = b"\
            --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
            Content-Disposition: form-data; name=\"file\"; filename=\"foo.txt\"\r\n\
            Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\
            test\r\n\
            --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
            Content-Disposition: form-data; name=\"file\"; filename=\"bar.txt\"\r\n\
            Content-Type: text/plain\r\n\r\n\
            testdata\r\n\
            --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
            Content-Disposition: form-data; name=\"file\"; filename=\"bar.txt\"\r\n\
            Content-Type: text/plain\r\n\r\n\
            testdata\r\n\
            --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
            Content-Disposition: form-data; name=\"file\"; filename=\"bar.txt\"\r\n\
            Content-Type: text/plain\r\nContent-Length: 9\r\n\r\n\
            testdata2\r\n\
            --abbc761f78ff4d7cb7573b5a23f96ef0--\r\n\
            ";

        let req = post_request("multipart/mixed; boundary=abbc761f78ff4d7cb7573b5a23f96ef0");

        let body = once_body(Bytes::copy_from_slice(body));

        let multipart = multipart(&req, body).unwrap();

        let mut multipart = pin!(multipart);

        {
            let mut field = multipart.try_next().now_or_never().unwrap().unwrap().unwrap();

            assert_eq!(
                field.headers().get(CONTENT_DISPOSITION).unwrap(),
                HeaderValue::from_static("form-data; name=\"file\"; filename=\"foo.txt\"")
            );
            assert_eq!(field.name().unwrap(), "file");
            assert_eq!(field.file_name().unwrap(), "foo.txt");
            assert_eq!(
                field.headers().get(CONTENT_TYPE).unwrap(),
                HeaderValue::from_static("text/plain; charset=utf-8")
            );
            assert_eq!(
                field.headers().get(CONTENT_LENGTH).unwrap(),
                HeaderValue::from_static("4")
            );
            assert_eq!(
                field.try_next().now_or_never().unwrap().unwrap().unwrap().chunk(),
                b"test"
            );
            assert!(field.try_next().now_or_never().unwrap().unwrap().is_none());
        }

        {
            let mut field = multipart.try_next().now_or_never().unwrap().unwrap().unwrap();

            assert_eq!(
                field.headers().get(CONTENT_DISPOSITION).unwrap(),
                HeaderValue::from_static("form-data; name=\"file\"; filename=\"bar.txt\"")
            );
            assert_eq!(field.name().unwrap(), "file");
            assert_eq!(field.file_name().unwrap(), "bar.txt");
            assert_eq!(
                field.headers().get(CONTENT_TYPE).unwrap(),
                HeaderValue::from_static("text/plain")
            );
            assert!(field.headers().get(CONTENT_LENGTH).is_none());
            assert_eq!(
                field.try_next().now_or_never().unwrap().unwrap().unwrap().chunk(),
                b"testdata"
            );
            assert!(field.try_next().now_or_never().unwrap().unwrap().is_none());
        }

        // test drop field without consuming.
        multipart.try_next().now_or_never().unwrap().unwrap().unwrap();

        {
            let mut field = multipart.try_next().now_or_never().unwrap().unwrap().unwrap();

            assert_eq!(
                field.headers().get(CONTENT_DISPOSITION).unwrap(),
                HeaderValue::from_static("form-data; name=\"file\"; filename=\"bar.txt\"")
            );
            assert_eq!(field.name().unwrap(), "file");
            assert_eq!(field.file_name().unwrap(), "bar.txt");
            assert_eq!(
                field.headers().get(CONTENT_TYPE).unwrap(),
                HeaderValue::from_static("text/plain")
            );
            assert_eq!(
                field.headers().get(CONTENT_LENGTH).unwrap(),
                HeaderValue::from_static("9")
            );
            assert_eq!(
                field.try_next().now_or_never().unwrap().unwrap().unwrap().chunk(),
                b"testdata2"
            );
            assert!(field.try_next().now_or_never().unwrap().unwrap().is_none());
        }

        // the closing boundary is reported as end of fields, and keeps being reported.
        assert!(multipart.try_next().now_or_never().unwrap().unwrap().is_none());
        assert!(multipart.try_next().now_or_never().unwrap().unwrap().is_none());
    }

    #[test]
    fn field_header_overflow() {
        let body = b"\
            --12345\r\n\
            Content-Disposition: form-data; name=\"file\"; filename=\"foo.txt\"\r\n\
            Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4";

        let req = post_request("multipart/mixed; boundary=12345");

        let body = once_body(Bytes::copy_from_slice(body));

        // limit is set to 7 so the first boundary can be parsed.
        let multipart = multipart_with_config(&req, body, Config { buf_limit: 7 }).unwrap();

        let mut multipart = pin!(multipart);

        assert!(matches!(
            multipart.try_next().now_or_never().unwrap().err().unwrap(),
            MultipartError::Header(httparse::Error::TooManyHeaders)
        ));
    }

    #[test]
    fn boundary_overflow() {
        let body = b"--123456";

        let req = post_request("multipart/mixed; boundary=12345");

        let body = once_body(Bytes::copy_from_slice(body));

        // limit is set to 7 so the first boundary can not be parsed.
        let multipart = multipart_with_config(&req, body, Config { buf_limit: 7 }).unwrap();

        let mut multipart = pin!(multipart);

        assert!(matches!(
            multipart.try_next().now_or_never().unwrap().err().unwrap(),
            MultipartError::BufferOverflow
        ));
    }
}