Skip to main content

axum_extra/extract/
multipart.rs

1//! Extractor that parses `multipart/form-data` requests commonly used with file uploads.
2//!
3//! See [`Multipart`] for more details.
4
5use axum_core::{
6    __composite_rejection as composite_rejection, __define_rejection as define_rejection,
7    body::Body,
8    extract::FromRequest,
9    response::{IntoResponse, Response},
10    RequestExt,
11};
12use bytes::Bytes;
13use futures_core::stream::Stream;
14use http::{
15    header::{HeaderMap, CONTENT_TYPE},
16    Request, StatusCode,
17};
18use std::{
19    error::Error,
20    fmt,
21    pin::Pin,
22    task::{Context, Poll},
23};
24
25/// Extractor that parses `multipart/form-data` requests (commonly used with file uploads).
26///
27/// ⚠️ Since extracting multipart form data from the request requires consuming the body, the
28/// `Multipart` extractor must be *last* if there are multiple extractors in a handler.
29/// See ["the order of extractors"][order-of-extractors]
30///
31/// [order-of-extractors]: crate::extract#the-order-of-extractors
32///
33/// # Example
34///
35/// ```
36/// use axum::{
37///     routing::post,
38///     Router,
39/// };
40/// use axum_extra::extract::Multipart;
41///
42/// async fn upload(mut multipart: Multipart) {
43///     while let Some(mut field) = multipart.next_field().await.unwrap() {
44///         let name = field.name().unwrap().to_string();
45///         let data = field.bytes().await.unwrap();
46///
47///         println!("Length of `{}` is {} bytes", name, data.len());
48///     }
49/// }
50///
51/// let app = Router::new().route("/upload", post(upload));
52/// # let _: Router = app;
53/// ```
54///
55/// # Field Exclusivity
56///
57/// A [`Field`] represents a raw, self-decoding stream into multipart data. As such, only one
58/// [`Field`] from a given Multipart instance may be live at once. That is, a [`Field`] emitted by
59/// [`next_field()`] must be dropped before calling [`next_field()`] again. Failure to do so will
60/// result in an error.
61///
62/// ```
63/// use axum_extra::extract::Multipart;
64///
65/// async fn handler(mut multipart: Multipart) {
66///     let field_1 = multipart.next_field().await;
67///
68///     // We cannot get the next field while `field_1` is still alive. Have to drop `field_1`
69///     // first.
70///     let field_2 = multipart.next_field().await;
71///     assert!(field_2.is_err());
72/// }
73/// ```
74///
75/// In general you should consume `Multipart` by looping over the fields in order and make sure not
76/// to keep `Field`s around from previous loop iterations. That will minimize the risk of runtime
77/// errors.
78///
79/// # Differences between this and `axum::extract::Multipart`
80///
81/// `axum::extract::Multipart` uses lifetimes to enforce field exclusivity at compile time, however
82/// that leads to significant usability issues such as `Field` not being `'static`.
83///
84/// `axum_extra::extract::Multipart` instead enforces field exclusivity at runtime which makes
85/// things easier to use at the cost of possible runtime errors.
86///
87/// [`next_field()`]: Multipart::next_field
88#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
89#[derive(Debug)]
90pub struct Multipart {
91    inner: multer::Multipart<'static>,
92}
93
94impl<S> FromRequest<S> for Multipart
95where
96    S: Send + Sync,
97{
98    type Rejection = MultipartRejection;
99
100    async fn from_request(req: Request<Body>, _state: &S) -> Result<Self, Self::Rejection> {
101        let boundary = parse_boundary(req.headers()).ok_or(InvalidBoundary)?;
102        let stream = req.with_limited_body().into_body();
103        let multipart = multer::Multipart::new(stream.into_data_stream(), boundary);
104        Ok(Self { inner: multipart })
105    }
106}
107
108impl Multipart {
109    /// Yields the next [`Field`] if available.
110    pub async fn next_field(&mut self) -> Result<Option<Field>, MultipartError> {
111        let field = self
112            .inner
113            .next_field()
114            .await
115            .map_err(MultipartError::from_multer)?;
116
117        if let Some(field) = field {
118            Ok(Some(Field { inner: field }))
119        } else {
120            Ok(None)
121        }
122    }
123
124    /// Convert the `Multipart` into a stream of its fields.
125    pub fn into_stream(self) -> impl Stream<Item = Result<Field, MultipartError>> + Send + 'static {
126        futures_util::stream::try_unfold(self, |mut multipart| async move {
127            let field = multipart.next_field().await?;
128            Ok(field.map(|field| (field, multipart)))
129        })
130    }
131}
132
133/// A single field in a multipart stream.
134#[derive(Debug)]
135pub struct Field {
136    inner: multer::Field<'static>,
137}
138
139impl Stream for Field {
140    type Item = Result<Bytes, MultipartError>;
141
142    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
143        Pin::new(&mut self.inner)
144            .poll_next(cx)
145            .map_err(MultipartError::from_multer)
146    }
147}
148
149impl Field {
150    /// The field name found in the
151    /// [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition)
152    /// header.
153    #[must_use]
154    pub fn name(&self) -> Option<&str> {
155        self.inner.name()
156    }
157
158    /// The file name found in the
159    /// [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition)
160    /// header.
161    #[must_use]
162    pub fn file_name(&self) -> Option<&str> {
163        self.inner.file_name()
164    }
165
166    /// Get the [content type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type) of the field.
167    #[must_use]
168    pub fn content_type(&self) -> Option<&str> {
169        self.inner.content_type().map(|m| m.as_ref())
170    }
171
172    /// Get a map of headers as [`HeaderMap`].
173    #[must_use]
174    pub fn headers(&self) -> &HeaderMap {
175        self.inner.headers()
176    }
177
178    /// Get the full data of the field as [`Bytes`].
179    pub async fn bytes(self) -> Result<Bytes, MultipartError> {
180        self.inner
181            .bytes()
182            .await
183            .map_err(MultipartError::from_multer)
184    }
185
186    /// Get the full field data as text.
187    pub async fn text(self) -> Result<String, MultipartError> {
188        self.inner.text().await.map_err(MultipartError::from_multer)
189    }
190
191    /// Stream a chunk of the field data.
192    ///
193    /// When the field data has been exhausted, this will return [`None`].
194    ///
195    /// Note this does the same thing as `Field`'s [`Stream`] implementation.
196    ///
197    /// # Example
198    ///
199    /// ```
200    /// use axum::{
201    ///    routing::post,
202    ///    response::IntoResponse,
203    ///    http::StatusCode,
204    ///    Router,
205    /// };
206    /// use axum_extra::extract::Multipart;
207    ///
208    /// async fn upload(mut multipart: Multipart) -> Result<(), (StatusCode, String)> {
209    ///     while let Some(mut field) = multipart
210    ///         .next_field()
211    ///         .await
212    ///         .map_err(|err| (StatusCode::BAD_REQUEST, err.to_string()))?
213    ///     {
214    ///         while let Some(chunk) = field
215    ///             .chunk()
216    ///             .await
217    ///             .map_err(|err| (StatusCode::BAD_REQUEST, err.to_string()))?
218    ///         {
219    ///             println!("received {} bytes", chunk.len());
220    ///         }
221    ///     }
222    ///
223    ///     Ok(())
224    /// }
225    ///
226    /// let app = Router::new().route("/upload", post(upload));
227    /// # let _: Router = app;
228    /// ```
229    pub async fn chunk(&mut self) -> Result<Option<Bytes>, MultipartError> {
230        self.inner
231            .chunk()
232            .await
233            .map_err(MultipartError::from_multer)
234    }
235}
236
237/// Errors associated with parsing `multipart/form-data` requests.
238#[derive(Debug)]
239pub struct MultipartError {
240    source: multer::Error,
241}
242
243impl MultipartError {
244    fn from_multer(multer: multer::Error) -> Self {
245        Self { source: multer }
246    }
247
248    /// Get the response body text used for this rejection.
249    pub fn body_text(&self) -> String {
250        let body = if is_body_limit_error(&self.source) {
251            "Request payload is too large".to_owned()
252        } else {
253            self.source.to_string()
254        };
255        axum_core::__log_rejection!(
256            rejection_type = Self,
257            body_text = body,
258            status = self.status(),
259        );
260        body
261    }
262
263    /// Get the status code used for this rejection.
264    #[must_use]
265    pub fn status(&self) -> http::StatusCode {
266        status_code_from_multer_error(&self.source)
267    }
268}
269
270fn status_code_from_multer_error(err: &multer::Error) -> StatusCode {
271    match err {
272        multer::Error::UnknownField { .. }
273        | multer::Error::IncompleteFieldData { .. }
274        | multer::Error::IncompleteHeaders
275        | multer::Error::ReadHeaderFailed(..)
276        | multer::Error::DecodeHeaderName { .. }
277        | multer::Error::DecodeContentType(..)
278        | multer::Error::NoBoundary
279        | multer::Error::DecodeHeaderValue { .. }
280        | multer::Error::NoMultipart
281        | multer::Error::IncompleteStream => StatusCode::BAD_REQUEST,
282        multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => {
283            StatusCode::PAYLOAD_TOO_LARGE
284        }
285        multer::Error::StreamReadFailed(err) => {
286            if let Some(err) = err.downcast_ref::<multer::Error>() {
287                return status_code_from_multer_error(err);
288            }
289
290            if err
291                .downcast_ref::<axum_core::Error>()
292                .and_then(|err| err.source())
293                .and_then(|err| err.downcast_ref::<http_body_util::LengthLimitError>())
294                .is_some()
295            {
296                return StatusCode::PAYLOAD_TOO_LARGE;
297            }
298
299            StatusCode::INTERNAL_SERVER_ERROR
300        }
301        _ => StatusCode::INTERNAL_SERVER_ERROR,
302    }
303}
304
305fn is_body_limit_error(err: &multer::Error) -> bool {
306    match err {
307        multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => true,
308        multer::Error::StreamReadFailed(err) => {
309            if let Some(err) = err.downcast_ref::<multer::Error>() {
310                return is_body_limit_error(err);
311            }
312            err.downcast_ref::<axum_core::Error>()
313                .and_then(|err| err.source())
314                .and_then(|err| err.downcast_ref::<http_body_util::LengthLimitError>())
315                .is_some()
316        }
317        _ => false,
318    }
319}
320
321impl IntoResponse for MultipartError {
322    fn into_response(self) -> Response {
323        (self.status(), self.body_text()).into_response()
324    }
325}
326
327impl fmt::Display for MultipartError {
328    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
329        write!(f, "Error parsing `multipart/form-data` request")
330    }
331}
332
333impl std::error::Error for MultipartError {
334    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
335        Some(&self.source)
336    }
337}
338
339fn parse_boundary(headers: &HeaderMap) -> Option<String> {
340    let content_type = headers.get(CONTENT_TYPE)?.to_str().ok()?;
341    multer::parse_boundary(content_type).ok()
342}
343
composite_rejection! {
    /// Rejection used for [`Multipart`].
    ///
    /// Contains one variant for each way the [`Multipart`] extractor can fail.
    ///
    /// Errors that happen later, while reading fields, are reported as [`MultipartError`]
    /// instead.
    pub enum MultipartRejection {
        InvalidBoundary,
    }
}
352
define_rejection! {
    #[status = BAD_REQUEST]
    #[body = "Invalid `boundary` for `multipart/form-data` request"]
    /// Rejection type used if the `boundary` in a `multipart/form-data` request's
    /// `Content-Type` header is missing or invalid.
    pub struct InvalidBoundary;
}
360
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{extract::DefaultBodyLimit, routing::post, Router};

    #[tokio::test]
    async fn content_type_with_encoding() {
        const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();
        const FILE_NAME: &str = "index.html";
        const CONTENT_TYPE: &str = "text/html; charset=utf-8";

        async fn handle(mut multipart: Multipart) -> impl IntoResponse {
            let field = multipart.next_field().await.unwrap().unwrap();

            assert_eq!(field.file_name().unwrap(), FILE_NAME);
            assert_eq!(field.content_type().unwrap(), CONTENT_TYPE);
            assert_eq!(field.bytes().await.unwrap(), BYTES);

            assert!(multipart.next_field().await.unwrap().is_none());
        }

        let app = Router::new().route("/", post(handle));

        let client = TestClient::new(app);

        let form = reqwest::multipart::Form::new().part(
            "file",
            reqwest::multipart::Part::bytes(BYTES)
                .file_name(FILE_NAME)
                .mime_str(CONTENT_TYPE)
                .unwrap(),
        );

        // All of the assertions above live in the handler. If the extractor rejected the
        // request, the handler would never run and the test would pass silently, so
        // require a successful response.
        let res = client.post("/").multipart(form).await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    // No need for this to be a #[test], we just want to make sure it compiles
    fn _multipart_from_request_limited() {
        async fn handler(_: Multipart) {}
        let _app: Router<()> = Router::new().route("/", post(handler));
    }

    #[tokio::test]
    async fn body_too_large() {
        const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();

        async fn handle(mut multipart: Multipart) -> Result<(), MultipartError> {
            while let Some(field) = multipart.next_field().await? {
                field.bytes().await?;
            }
            Ok(())
        }

        let app = Router::new()
            .route("/", post(handle))
            .layer(DefaultBodyLimit::max(BYTES.len() - 1));

        let client = TestClient::new(app);

        let form =
            reqwest::multipart::Form::new().part("file", reqwest::multipart::Part::bytes(BYTES));

        let res = client.post("/").multipart(form).await;
        assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
        assert_eq!(res.text().await, "Request payload is too large");
    }

    #[tokio::test]
    #[cfg(feature = "tracing")]
    async fn body_too_large_with_tracing() {
        const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();

        async fn handle(mut multipart: Multipart) -> impl IntoResponse {
            let result: Result<(), MultipartError> = async {
                while let Some(field) = multipart.next_field().await? {
                    field.bytes().await?;
                }
                Ok(())
            }
            .await;

            let subscriber = tracing_subscriber::FmtSubscriber::builder()
                .with_max_level(tracing::level_filters::LevelFilter::TRACE)
                .with_writer(std::io::sink)
                .finish();

            let guard = tracing::subscriber::set_default(subscriber);
            let response = result.into_response();
            drop(guard);

            response
        }

        let app = Router::new()
            .route("/", post(handle))
            .layer(DefaultBodyLimit::max(BYTES.len() - 1));

        let client = TestClient::new(app);

        let form =
            reqwest::multipart::Form::new().part("file", reqwest::multipart::Part::bytes(BYTES));

        let res = client.post("/").multipart(form).await;
        assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }
}