rocket_multipart/
reader.rs

1use std::{
2    io,
3    pin::Pin,
4    str::{FromStr, Utf8Error},
5    task::{ready, Context, Poll},
6};
7
8use memchr::{memchr, memmem::find, memrchr};
9use rocket::{
10    data::{DataStream, FromData, Outcome, ToByteUnit},
11    futures::StreamExt,
12    http::{ContentType, Header, HeaderMap, Status},
13    tokio::io::{AsyncBufRead, AsyncRead, ReadBuf},
14    Data, Request,
15};
16use thiserror::Error;
17use tokio_util::{
18    bytes::{BufMut, Bytes, BytesMut},
19    codec::{Decoder, FramedRead},
20};
21
22type Result<T, E = Error> = std::result::Result<T, E>;
23
24/// Error returned by `MultipartReader`
25#[derive(Debug, Error)]
26pub enum Error {
27    /// An underlying IO error
28    #[error(transparent)]
29    Io(#[from] io::Error),
30    /// A header was not utf8 encoded
31    #[error(transparent)]
32    Encoding(#[from] Utf8Error),
33    /// An error from `serde_json`
34    ///
35    /// Only available on `json` feature
36    #[cfg(feature = "json")]
37    #[error(transparent)]
38    Json(#[from] serde_json::Error),
39    /// The content-type of a multipart stream did not specify a boundary
40    #[error("The content type of a multipart stream must specify a boundary")]
41    BoundaryNotSpecified,
42}
43
44/// A single section in a multipart stream
45///
46/// Implements both `AsyncRead` and `AsyncBufRead`, and can be used with any API
47/// that expects either.
48pub struct MultipartReadSection<'r, 'a> {
49    headers: HeaderMap<'static>,
50    reader: &'a mut MultipartReader<'r>,
51}
52
53impl<'a> MultipartReadSection<'_, 'a> {
54    /// Gets the list of headers specific to this multipart section
55    pub fn headers(&self) -> &HeaderMap<'static> {
56        &self.headers
57    }
58
59    /// Retrieves the `Content-Type` header (if it exists) and parses it.
60    pub fn content_type(&self) -> Option<ContentType> {
61        let s = self.headers.get_one("Content-Type")?;
62        ContentType::from_str(s).ok()
63    }
64
65    /// Read the entire stream into a single bytes object.
66    ///
67    /// Should generally be more effecient than `read_to_end`, since it
68    /// generally avoids copying data into a new buffer.
69    pub async fn to_bytes(self) -> Result<Bytes> {
70        let mut raw_data = BytesMut::new();
71        while let MultipartFrame::Data(bytes) = &mut self.reader.buffer {
72            raw_data.unsplit(bytes.split());
73            match self.reader.stream.next().await {
74                Some(Ok(next)) => self.reader.buffer = next,
75                Some(Err(e)) => {
76                    self.reader.buffer = MultipartFrame::End;
77                    return Err(e);
78                }
79                None => self.reader.buffer = MultipartFrame::End,
80            }
81        }
82        Ok(raw_data.freeze())
83    }
84
85    /// Read the entire stream, and parse it as a JSON object
86    ///
87    /// Only available on `json` feature
88    #[cfg(feature = "json")]
89    pub async fn json<T: serde::de::DeserializeOwned>(self) -> Result<T> {
90        let bytes = self.to_bytes().await?;
91        Ok(serde_json::from_slice(&bytes)?)
92    }
93}
94
95impl AsyncRead for MultipartReadSection<'_, '_> {
96    fn poll_read(
97        self: Pin<&mut Self>,
98        cx: &mut Context<'_>,
99        buf: &mut ReadBuf<'_>,
100    ) -> Poll<io::Result<()>> {
101        // Relies on AsyncBufRead
102        let mut this = self.get_mut();
103        match Pin::new(&mut this).poll_fill_buf(cx) {
104            Poll::Ready(Ok(buffer)) => {
105                let write_buf = buf.initialize_unfilled();
106                let len = buffer.len().min(write_buf.len());
107                write_buf[..len].copy_from_slice(&buffer[..len]);
108                unsafe { buf.advance_mut(len) };
109                Pin::new(this).consume(len);
110                Poll::Ready(Ok(()))
111            }
112            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
113            Poll::Pending => return Poll::Pending,
114        }
115    }
116}
117
118impl AsyncBufRead for MultipartReadSection<'_, '_> {
119    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
120        let this = self.get_mut();
121        let buffer = &mut this.reader.buffer;
122        if buffer.is_empty() {
123            match ready!(this.reader.stream.poll_next_unpin(cx)) {
124                Some(Ok(by)) => *buffer = by,
125                None => *buffer = MultipartFrame::End,
126                Some(Err(e)) => return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))),
127            }
128        }
129        // Buffer now either has data, or we have run out of data to provide
130        if let MultipartFrame::Data(data) = buffer {
131            return Poll::Ready(Ok(data));
132        } else {
133            return Poll::Ready(Ok(b""));
134        }
135    }
136
137    fn consume(self: Pin<&mut Self>, amt: usize) {
138        let this = self.get_mut();
139        if let MultipartFrame::Data(data) = &mut this.reader.buffer {
140            let _ = data.split_to(amt.min(data.len()));
141        }
142    }
143}
144
/// Where `MultipartDecoder` currently is within the multipart body.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
enum MultipartDecoderState {
    /// Skipping the preamble while searching for the first boundary.
    BeforeFirstBoundary,
    /// Reading the header lines of a section.
    Headers,
    /// Reading the body bytes of a section.
    Data,
    /// The closing delimiter was seen; all further input is discarded.
    End,
}
152struct MultipartDecoder<'r> {
153    state: MultipartDecoderState,
154    boundary: &'r str,
155}
156
157#[derive(Debug, PartialEq, Eq, Clone, Hash)]
158enum MultipartFrame {
159    Boundary,
160    Header(Header<'static>),
161    Data(BytesMut),
162    End,
163}
164
165impl MultipartFrame {
166    fn is_empty(&self) -> bool {
167        if let Self::Data(v) = self {
168            v.is_empty()
169        } else {
170            false
171        }
172    }
173}
174
/// Extra capacity (1 KiB) reserved in the decode buffer each time the decoder
/// needs more input.
const CHUNK_SIZE: usize = 1 << 10;
176
177impl MultipartDecoder<'_> {
178    fn parse_header(header: &[u8]) -> Result<Header<'static>> {
179        if let Some(middle) = memchr(b':', header) {
180            Ok(Header::new(
181                std::str::from_utf8(&header[..middle])?.to_owned(),
182                std::str::from_utf8(&header[middle + 1..])?
183                    .trim()
184                    .to_owned(),
185            ))
186        } else {
187            // Malformed header
188            todo!()
189        }
190    }
191}
192
193impl Decoder for MultipartDecoder<'_> {
194    type Item = MultipartFrame;
195    type Error = Error;
196
197    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
198        loop {
199            match self.state {
200                MultipartDecoderState::BeforeFirstBoundary => {
201                    if let Some(pos) = find(src, self.boundary.as_bytes()) {
202                        if pos < 2 {
203                            let _ = src.split_to(pos + 1);
204                            continue;
205                        }
206                        if pos + self.boundary.len() + 2 > src.len() {
207                            let _ = src.split_to(pos - 4);
208                            src.reserve(CHUNK_SIZE);
209                            return Ok(None);
210                        }
211                        if &src[pos - 2..pos] == b"--"
212                            && &src[pos + self.boundary.len()..][..2] == b"--"
213                        {
214                            self.state = MultipartDecoderState::End;
215                            continue;
216                        }
217                        if &src[pos - 2..pos] != b"--"
218                            && &src[pos + self.boundary.len()..][..2] != b"\r\n"
219                        {
220                            let _ = src.split_to(pos + self.boundary.len());
221                        } else {
222                            let _ = src.split_to(pos + self.boundary.len() + "\r\n".len());
223                            self.state = MultipartDecoderState::Headers;
224                            return Ok(Some(MultipartFrame::Boundary));
225                        }
226                    } else {
227                        src.reserve(CHUNK_SIZE);
228                        return Ok(None);
229                    }
230                }
231                MultipartDecoderState::Headers => {
232                    if let Some(end) = find(src, b"\r\n") {
233                        let header = src.split_to(end + "\r\n".len());
234                        if end == 0 {
235                            self.state = MultipartDecoderState::Data;
236                            continue;
237                        }
238                        return Ok(Some(MultipartFrame::Header(Self::parse_header(
239                            &header[..],
240                        )?)));
241                    } else {
242                        src.reserve(CHUNK_SIZE);
243                        return Ok(None);
244                    }
245                }
246                MultipartDecoderState::Data => {
247                    if let Some(pos) = find(src, self.boundary.as_bytes()) {
248                        if pos < 4 {
249                            let data = src.split_to(pos + 1);
250                            return Ok(Some(MultipartFrame::Data(data)));
251                        }
252                        if pos + self.boundary.len() + 2 > src.len() {
253                            let data = src.split_to(pos - 4);
254                            src.reserve(CHUNK_SIZE);
255                            if data.is_empty() {
256                                return Ok(None);
257                            } else {
258                                return Ok(Some(MultipartFrame::Data(data)));
259                            }
260                        }
261                        if &src[pos - 4..pos] == b"\r\n--"
262                            && &src[pos + self.boundary.len()..][..2] == b"--"
263                        {
264                            self.state = MultipartDecoderState::End;
265                            if pos - 4 > 0 {
266                                let data = src.split_to(pos - 4);
267                                return Ok(Some(MultipartFrame::Data(data)));
268                            }
269                            continue;
270                        }
271                        if &src[pos - 4..pos] != b"\r\n--"
272                            && &src[pos + self.boundary.len()..][..2] != b"\r\n"
273                        {
274                            let data = src.split_to(pos + self.boundary.len());
275                            return Ok(Some(MultipartFrame::Data(data)));
276                        } else {
277                            if pos > 4 {
278                                let data = src.split_to(pos - 4);
279                                return Ok(Some(MultipartFrame::Data(data)));
280                            }
281                            let _ = src.split_to(pos + self.boundary.len() + "\r\n".len());
282                            self.state = MultipartDecoderState::Headers;
283                            return Ok(Some(MultipartFrame::Boundary));
284                        }
285                    } else {
286                        let end = src
287                            .len()
288                            .saturating_sub(self.boundary.len() + 4) // Known safe prefix
289                            .max(memrchr(b'\r', src).unwrap_or(0));
290                        let data = src.split_to(end);
291                        src.reserve(CHUNK_SIZE);
292                        if data.is_empty() {
293                            return Ok(None);
294                        } else {
295                            return Ok(Some(MultipartFrame::Data(data)));
296                        }
297                    }
298                }
299                MultipartDecoderState::End => {
300                    let _ = src.split();
301                    src.reserve(CHUNK_SIZE);
302                    return Ok(None);
303                }
304            }
305        }
306    }
307}
308
309/// A data guard for `multipart/*` data. Provides async reading of the
310/// individual multipart sections.
311///
312/// # Example
313///
314/// ```rust,no_run
315/// # use rocket::{post, tokio::io::AsyncReadExt};
316/// # use rocket_multipart::MultipartReader;
317/// #[post("/mixed", data = "<mixed>")]
318/// async fn multipart_data(mut mixed: MultipartReader<'_>) -> String {
319///     while let Some(mut a) = mixed.next().await.unwrap() {
320///         if let Some(ct) = a.headers().get_one("Content-Type") {
321///             // Check content_type
322///         }
323///         let mut buf = vec![];
324///         a.read_to_end(&mut buf).await.unwrap();
325///         // Use section's body
326///     }
327/// #   String::new()
328/// }
329/// ```
330///
331/// # Limits
332///
333/// Like most data guards, `MultipartReader` provides a configurable limit. It
334/// uses the `file/multipart` limit or a default limit of 1 MiB. This is the
335/// limit for the entire stream, not individual sections.
336pub struct MultipartReader<'r> {
337    stream: FramedRead<DataStream<'r>, MultipartDecoder<'r>>,
338    buffer: MultipartFrame,
339    content_type: &'r ContentType,
340}
341
342impl<'r> MultipartReader<'r> {
343    /// Gets the next section from this multipart reader. The returned section
344    /// mutably borrows from `self`, so it must be dropped before another secton
345    /// can be read.
346    pub async fn next(&mut self) -> Result<Option<MultipartReadSection<'r, '_>>> {
347        while self.buffer != MultipartFrame::Boundary {
348            match self.stream.next().await {
349                Some(Ok(MultipartFrame::End)) | None => {
350                    self.buffer = MultipartFrame::End;
351                    return Ok(None);
352                }
353                Some(Ok(val)) => self.buffer = val,
354                Some(Err(e)) => return Err(e),
355            }
356        }
357
358        let mut headers = HeaderMap::new();
359        loop {
360            match self.stream.next().await {
361                Some(Ok(MultipartFrame::End)) | None => {
362                    self.buffer = MultipartFrame::End;
363                    if headers.is_empty() {
364                        return Ok(None);
365                    } else {
366                        break;
367                    }
368                }
369                Some(Ok(MultipartFrame::Header(header))) => headers.add(header),
370                Some(Ok(MultipartFrame::Boundary)) => {
371                    self.buffer = MultipartFrame::Boundary;
372                    break;
373                }
374                Some(Ok(val @ MultipartFrame::Data(_))) => {
375                    self.buffer = val;
376                    break;
377                }
378                // Some(Ok(val)) => self.buffer = val,
379                Some(Err(e)) => return Err(e),
380            }
381        }
382        Ok(Some(MultipartReadSection {
383            headers,
384            reader: self,
385        }))
386    }
387
388    /// The content type of the multipart stream as a whole. The primary type is always `multipart`
389    pub fn content_type(&self) -> &'r ContentType {
390        self.content_type
391    }
392}
393
394#[rocket::async_trait]
395impl<'r> FromData<'r> for MultipartReader<'r> {
396    type Error = Error;
397
398    async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> Outcome<'r, Self> {
399        let limit = req
400            .rocket()
401            .config()
402            .limits
403            .get("file/multipart")
404            .unwrap_or(1.mebibytes());
405        if let Some(content_type) = req.content_type().filter(|c| c.top() == "multipart") {
406            if let Some(boundary) = content_type.param("boundary") {
407                Outcome::Success(Self {
408                    stream: FramedRead::new(
409                        data.open(limit),
410                        MultipartDecoder {
411                            state: MultipartDecoderState::BeforeFirstBoundary,
412                            boundary,
413                        },
414                    ),
415                    buffer: MultipartFrame::Data(BytesMut::new()),
416                    content_type,
417                })
418            } else {
419                Outcome::Error((Status::BadRequest, Error::BoundaryNotSpecified))
420            }
421        } else {
422            Outcome::Forward((data, Status::BadRequest))
423        }
424    }
425}