Skip to main content

axum/extract/
multipart.rs

1//! Extractor that parses `multipart/form-data` requests commonly used with file uploads.
2//!
3//! See [`Multipart`] for more details.
4
5use super::{FromRequest, Request};
6use crate::body::Bytes;
7use axum_core::{
8    __composite_rejection as composite_rejection, __define_rejection as define_rejection,
9    extract::OptionalFromRequest,
10    response::{IntoResponse, Response},
11    RequestExt,
12};
13use futures_util::stream::Stream;
14use http::{
15    header::{HeaderMap, CONTENT_TYPE},
16    StatusCode,
17};
18use std::{
19    error::Error,
20    fmt,
21    pin::Pin,
22    task::{Context, Poll},
23};
24
25/// Extractor that parses `multipart/form-data` requests (commonly used with file uploads).
26///
27/// ⚠️ Since extracting multipart form data from the request requires consuming the body, the
28/// `Multipart` extractor must be *last* if there are multiple extractors in a handler.
29/// See ["the order of extractors"][order-of-extractors]
30///
31/// [order-of-extractors]: crate::extract#the-order-of-extractors
32///
33/// # Example
34///
35/// ```rust,no_run
36/// use axum::{
37///     extract::Multipart,
38///     routing::post,
39///     Router,
40/// };
41/// use futures_util::stream::StreamExt;
42///
43/// async fn upload(mut multipart: Multipart) {
44///     while let Some(mut field) = multipart.next_field().await.unwrap() {
45///         let name = field.name().unwrap().to_string();
46///         let data = field.bytes().await.unwrap();
47///
48///         println!("Length of `{}` is {} bytes", name, data.len());
49///     }
50/// }
51///
52/// let app = Router::new().route("/upload", post(upload));
53/// # let _: Router = app;
54/// ```
55///
56/// # Large Files
57///
58/// For security reasons, by default, `Multipart` limits the request body size to 2MB.
59/// See [`DefaultBodyLimit`][default-body-limit] for how to configure this limit.
60///
61/// [default-body-limit]: crate::extract::DefaultBodyLimit
62#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
63#[derive(Debug)]
64pub struct Multipart {
65    inner: multer::Multipart<'static>,
66}
67
68impl<S> FromRequest<S> for Multipart
69where
70    S: Send + Sync,
71{
72    type Rejection = MultipartRejection;
73
74    async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
75        let boundary = content_type_str(req.headers())
76            .and_then(|content_type| multer::parse_boundary(content_type).ok())
77            .ok_or(InvalidBoundary)?;
78        let stream = req.with_limited_body().into_body();
79        let multipart = multer::Multipart::new(stream.into_data_stream(), boundary);
80        Ok(Self { inner: multipart })
81    }
82}
83
84impl<S> OptionalFromRequest<S> for Multipart
85where
86    S: Send + Sync,
87{
88    type Rejection = MultipartRejection;
89
90    async fn from_request(req: Request, _state: &S) -> Result<Option<Self>, Self::Rejection> {
91        let Some(content_type) = content_type_str(req.headers()) else {
92            return Ok(None);
93        };
94        match multer::parse_boundary(content_type) {
95            Ok(boundary) => {
96                let stream = req.with_limited_body().into_body();
97                let multipart = multer::Multipart::new(stream.into_data_stream(), boundary);
98                Ok(Some(Self { inner: multipart }))
99            }
100            Err(multer::Error::NoMultipart) => Ok(None),
101            Err(_) => Err(MultipartRejection::InvalidBoundary(InvalidBoundary)),
102        }
103    }
104}
105
106impl Multipart {
107    /// Yields the next [`Field`] if available.
108    pub async fn next_field(&mut self) -> Result<Option<Field<'_>>, MultipartError> {
109        let field = self
110            .inner
111            .next_field()
112            .await
113            .map_err(MultipartError::from_multer)?;
114
115        if let Some(field) = field {
116            Ok(Some(Field {
117                inner: field,
118                _multipart: self,
119            }))
120        } else {
121            Ok(None)
122        }
123    }
124}
125
126/// A single field in a multipart stream.
127#[derive(Debug)]
128pub struct Field<'a> {
129    inner: multer::Field<'static>,
130    // multer requires there to only be one live `multer::Field` at any point. This enforces that
131    // statically, which multer does not do, it returns an error instead.
132    _multipart: &'a mut Multipart,
133}
134
135impl Stream for Field<'_> {
136    type Item = Result<Bytes, MultipartError>;
137
138    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
139        Pin::new(&mut self.inner)
140            .poll_next(cx)
141            .map_err(MultipartError::from_multer)
142    }
143}
144
145impl Field<'_> {
146    /// The field name found in the
147    /// [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition)
148    /// header.
149    #[must_use]
150    pub fn name(&self) -> Option<&str> {
151        self.inner.name()
152    }
153
154    /// The file name found in the
155    /// [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition)
156    /// header.
157    #[must_use]
158    pub fn file_name(&self) -> Option<&str> {
159        self.inner.file_name()
160    }
161
162    /// Get the [content type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type) of the field.
163    #[must_use]
164    pub fn content_type(&self) -> Option<&str> {
165        self.inner.content_type().map(|m| m.as_ref())
166    }
167
168    /// Get a map of headers as [`HeaderMap`].
169    #[must_use]
170    pub fn headers(&self) -> &HeaderMap {
171        self.inner.headers()
172    }
173
174    /// Get the full data of the field as [`Bytes`].
175    pub async fn bytes(self) -> Result<Bytes, MultipartError> {
176        self.inner
177            .bytes()
178            .await
179            .map_err(MultipartError::from_multer)
180    }
181
182    /// Get the full field data as text.
183    pub async fn text(self) -> Result<String, MultipartError> {
184        self.inner.text().await.map_err(MultipartError::from_multer)
185    }
186
187    /// Stream a chunk of the field data.
188    ///
189    /// When the field data has been exhausted, this will return [`None`].
190    ///
191    /// Note this does the same thing as `Field`'s [`Stream`] implementation.
192    ///
193    /// # Example
194    ///
195    /// ```
196    /// use axum::{
197    ///    extract::Multipart,
198    ///    routing::post,
199    ///    response::IntoResponse,
200    ///    http::StatusCode,
201    ///    Router,
202    /// };
203    ///
204    /// async fn upload(mut multipart: Multipart) -> Result<(), (StatusCode, String)> {
205    ///     while let Some(mut field) = multipart
206    ///         .next_field()
207    ///         .await
208    ///         .map_err(|err| (StatusCode::BAD_REQUEST, err.to_string()))?
209    ///     {
210    ///         while let Some(chunk) = field
211    ///             .chunk()
212    ///             .await
213    ///             .map_err(|err| (StatusCode::BAD_REQUEST, err.to_string()))?
214    ///         {
215    ///             println!("received {} bytes", chunk.len());
216    ///         }
217    ///     }
218    ///
219    ///     Ok(())
220    /// }
221    ///
222    /// let app = Router::new().route("/upload", post(upload));
223    /// # let _: Router = app;
224    /// ```
225    pub async fn chunk(&mut self) -> Result<Option<Bytes>, MultipartError> {
226        self.inner
227            .chunk()
228            .await
229            .map_err(MultipartError::from_multer)
230    }
231}
232
233/// Errors associated with parsing `multipart/form-data` requests.
234#[derive(Debug)]
235pub struct MultipartError {
236    source: multer::Error,
237}
238
239impl MultipartError {
240    fn from_multer(multer: multer::Error) -> Self {
241        Self { source: multer }
242    }
243
244    /// Get the response body text used for this rejection.
245    #[must_use]
246    pub fn body_text(&self) -> String {
247        if is_body_limit_error(&self.source) {
248            "Request payload is too large".to_owned()
249        } else {
250            self.source.to_string()
251        }
252    }
253
254    /// Get the status code used for this rejection.
255    #[must_use]
256    pub fn status(&self) -> http::StatusCode {
257        status_code_from_multer_error(&self.source)
258    }
259}
260
261fn status_code_from_multer_error(err: &multer::Error) -> StatusCode {
262    match err {
263        multer::Error::UnknownField { .. }
264        | multer::Error::IncompleteFieldData { .. }
265        | multer::Error::IncompleteHeaders
266        | multer::Error::ReadHeaderFailed(..)
267        | multer::Error::DecodeHeaderName { .. }
268        | multer::Error::DecodeContentType(..)
269        | multer::Error::NoBoundary
270        | multer::Error::DecodeHeaderValue { .. }
271        | multer::Error::NoMultipart
272        | multer::Error::IncompleteStream => StatusCode::BAD_REQUEST,
273        multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => {
274            StatusCode::PAYLOAD_TOO_LARGE
275        }
276        multer::Error::StreamReadFailed(err) => {
277            if let Some(err) = err.downcast_ref::<multer::Error>() {
278                return status_code_from_multer_error(err);
279            }
280
281            if err
282                .downcast_ref::<crate::Error>()
283                .and_then(|err| err.source())
284                .and_then(|err| err.downcast_ref::<http_body_util::LengthLimitError>())
285                .is_some()
286            {
287                return StatusCode::PAYLOAD_TOO_LARGE;
288            }
289
290            StatusCode::INTERNAL_SERVER_ERROR
291        }
292        _ => StatusCode::INTERNAL_SERVER_ERROR,
293    }
294}
295
296fn is_body_limit_error(err: &multer::Error) -> bool {
297    match err {
298        multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => true,
299        multer::Error::StreamReadFailed(err) => {
300            if let Some(err) = err.downcast_ref::<multer::Error>() {
301                return is_body_limit_error(err);
302            }
303            err.downcast_ref::<crate::Error>()
304                .and_then(|err| err.source())
305                .and_then(|err| err.downcast_ref::<http_body_util::LengthLimitError>())
306                .is_some()
307        }
308        _ => false,
309    }
310}
311
312impl fmt::Display for MultipartError {
313    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
314        write!(f, "Error parsing `multipart/form-data` request")
315    }
316}
317
318impl std::error::Error for MultipartError {
319    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
320        Some(&self.source)
321    }
322}
323
324impl IntoResponse for MultipartError {
325    fn into_response(self) -> Response {
326        let body = self.body_text();
327        axum_core::__log_rejection!(
328            rejection_type = Self,
329            body_text = body,
330            status = self.status(),
331        );
332        (self.status(), body).into_response()
333    }
334}
335
336fn content_type_str(headers: &HeaderMap) -> Option<&str> {
337    headers.get(CONTENT_TYPE)?.to_str().ok()
338}
339
composite_rejection! {
    /// Rejection used for [`Multipart`].
    ///
    /// Contains one variant for each way the [`Multipart`] extractor can fail. Currently the
    /// only variant is [`InvalidBoundary`], produced while reading the request's `Content-Type`
    /// header.
    pub enum MultipartRejection {
        InvalidBoundary,
    }
}
348
349define_rejection! {
350    #[status = BAD_REQUEST]
351    #[body = "Invalid `boundary` for `multipart/form-data` request"]
352    /// Rejection type used if the `boundary` in a `multipart/form-data` is
353    /// missing or invalid.
354    pub struct InvalidBoundary;
355}
356
#[cfg(test)]
mod tests {
    use axum_core::extract::DefaultBodyLimit;

    use super::*;
    use crate::{routing::post, test_helpers::*, Router};

    #[crate::test]
    async fn content_type_with_encoding() {
        const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();
        const FILE_NAME: &str = "index.html";
        const CONTENT_TYPE: &str = "text/html; charset=utf-8";

        async fn handle(mut multipart: Multipart) -> impl IntoResponse {
            let field = multipart.next_field().await.unwrap().unwrap();

            assert_eq!(field.file_name().unwrap(), FILE_NAME);
            assert_eq!(field.content_type().unwrap(), CONTENT_TYPE);
            assert_eq!(field.headers()["foo"], "bar");
            assert_eq!(field.bytes().await.unwrap(), BYTES);

            assert!(multipart.next_field().await.unwrap().is_none());
        }

        let app = Router::new().route("/", post(handle));

        let client = TestClient::new(app);

        let form = reqwest::multipart::Form::new().part(
            "file",
            reqwest::multipart::Part::bytes(BYTES)
                .file_name(FILE_NAME)
                .mime_str(CONTENT_TYPE)
                .unwrap()
                .headers(reqwest::header::HeaderMap::from_iter([(
                    reqwest::header::HeaderName::from_static("foo"),
                    reqwest::header::HeaderValue::from_static("bar"),
                )])),
        );

        // The assertions live in the handler, so also check that it completed normally.
        let res = client.post("/").multipart(form).await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    // No need for this to be a #[test], we just want to make sure it compiles
    fn _multipart_from_request_limited() {
        async fn handler(_: Multipart) {}
        let _app: Router = Router::new()
            .route("/", post(handler))
            .layer(tower_http::limit::RequestBodyLimitLayer::new(1024));
    }

    #[crate::test]
    async fn body_too_large() {
        const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();

        async fn handle(mut multipart: Multipart) -> Result<(), MultipartError> {
            while let Some(field) = multipart.next_field().await? {
                field.bytes().await?;
            }
            Ok(())
        }

        let app = Router::new()
            .route("/", post(handle))
            .layer(DefaultBodyLimit::max(BYTES.len() - 1));

        let client = TestClient::new(app);

        let form =
            reqwest::multipart::Form::new().part("file", reqwest::multipart::Part::bytes(BYTES));

        let res = client.post("/").multipart(form).await;
        assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
        assert_eq!(res.text().await, "Request payload is too large");
    }

    #[crate::test]
    async fn optional_multipart() {
        const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();

        async fn handle(multipart: Option<Multipart>) -> Result<StatusCode, MultipartError> {
            if let Some(mut multipart) = multipart {
                while let Some(field) = multipart.next_field().await? {
                    field.bytes().await?;
                }
                Ok(StatusCode::OK)
            } else {
                Ok(StatusCode::NO_CONTENT)
            }
        }

        let app = Router::new().route("/", post(handle));
        let client = TestClient::new(app);
        let form =
            reqwest::multipart::Form::new().part("file", reqwest::multipart::Part::bytes(BYTES));

        let res = client.post("/").multipart(form).await;
        assert_eq!(res.status(), StatusCode::OK);

        // No `Content-Type` at all: the extractor yields `None`.
        let res = client.post("/").await;
        assert_eq!(res.status(), StatusCode::NO_CONTENT);

        // A content type that is not multipart also yields `None`.
        let res = client
            .post("/")
            .header("content-type", "text/plain")
            .await;
        assert_eq!(res.status(), StatusCode::NO_CONTENT);

        // A multipart content type without a boundary is rejected rather than ignored.
        let res = client
            .post("/")
            .header("content-type", "multipart/form-data")
            .await;
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        assert_eq!(
            res.text().await,
            "Invalid `boundary` for `multipart/form-data` request"
        );
    }
}