axum/extract/
multipart.rs1use super::{FromRequest, Request};
6use crate::body::Bytes;
7use axum_core::{
8 __composite_rejection as composite_rejection, __define_rejection as define_rejection,
9 extract::OptionalFromRequest,
10 response::{IntoResponse, Response},
11 RequestExt,
12};
13use futures_util::stream::Stream;
14use http::{
15 header::{HeaderMap, CONTENT_TYPE},
16 StatusCode,
17};
18use std::{
19 error::Error,
20 fmt,
21 pin::Pin,
22 task::{Context, Poll},
23};
24
25#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
63#[derive(Debug)]
64pub struct Multipart {
65 inner: multer::Multipart<'static>,
66}
67
68impl<S> FromRequest<S> for Multipart
69where
70 S: Send + Sync,
71{
72 type Rejection = MultipartRejection;
73
74 async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
75 let boundary = content_type_str(req.headers())
76 .and_then(|content_type| multer::parse_boundary(content_type).ok())
77 .ok_or(InvalidBoundary)?;
78 let stream = req.with_limited_body().into_body();
79 let multipart = multer::Multipart::new(stream.into_data_stream(), boundary);
80 Ok(Self { inner: multipart })
81 }
82}
83
84impl<S> OptionalFromRequest<S> for Multipart
85where
86 S: Send + Sync,
87{
88 type Rejection = MultipartRejection;
89
90 async fn from_request(req: Request, _state: &S) -> Result<Option<Self>, Self::Rejection> {
91 let Some(content_type) = content_type_str(req.headers()) else {
92 return Ok(None);
93 };
94 match multer::parse_boundary(content_type) {
95 Ok(boundary) => {
96 let stream = req.with_limited_body().into_body();
97 let multipart = multer::Multipart::new(stream.into_data_stream(), boundary);
98 Ok(Some(Self { inner: multipart }))
99 }
100 Err(multer::Error::NoMultipart) => Ok(None),
101 Err(_) => Err(MultipartRejection::InvalidBoundary(InvalidBoundary)),
102 }
103 }
104}
105
106impl Multipart {
107 pub async fn next_field(&mut self) -> Result<Option<Field<'_>>, MultipartError> {
109 let field = self
110 .inner
111 .next_field()
112 .await
113 .map_err(MultipartError::from_multer)?;
114
115 if let Some(field) = field {
116 Ok(Some(Field {
117 inner: field,
118 _multipart: self,
119 }))
120 } else {
121 Ok(None)
122 }
123 }
124}
125
126#[derive(Debug)]
128pub struct Field<'a> {
129 inner: multer::Field<'static>,
130 _multipart: &'a mut Multipart,
133}
134
135impl Stream for Field<'_> {
136 type Item = Result<Bytes, MultipartError>;
137
138 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
139 Pin::new(&mut self.inner)
140 .poll_next(cx)
141 .map_err(MultipartError::from_multer)
142 }
143}
144
145impl Field<'_> {
146 #[must_use]
150 pub fn name(&self) -> Option<&str> {
151 self.inner.name()
152 }
153
154 #[must_use]
158 pub fn file_name(&self) -> Option<&str> {
159 self.inner.file_name()
160 }
161
162 #[must_use]
164 pub fn content_type(&self) -> Option<&str> {
165 self.inner.content_type().map(|m| m.as_ref())
166 }
167
168 #[must_use]
170 pub fn headers(&self) -> &HeaderMap {
171 self.inner.headers()
172 }
173
174 pub async fn bytes(self) -> Result<Bytes, MultipartError> {
176 self.inner
177 .bytes()
178 .await
179 .map_err(MultipartError::from_multer)
180 }
181
182 pub async fn text(self) -> Result<String, MultipartError> {
184 self.inner.text().await.map_err(MultipartError::from_multer)
185 }
186
187 pub async fn chunk(&mut self) -> Result<Option<Bytes>, MultipartError> {
226 self.inner
227 .chunk()
228 .await
229 .map_err(MultipartError::from_multer)
230 }
231}
232
233#[derive(Debug)]
235pub struct MultipartError {
236 source: multer::Error,
237}
238
239impl MultipartError {
240 fn from_multer(multer: multer::Error) -> Self {
241 Self { source: multer }
242 }
243
244 #[must_use]
246 pub fn body_text(&self) -> String {
247 if is_body_limit_error(&self.source) {
248 "Request payload is too large".to_owned()
249 } else {
250 self.source.to_string()
251 }
252 }
253
254 #[must_use]
256 pub fn status(&self) -> http::StatusCode {
257 status_code_from_multer_error(&self.source)
258 }
259}
260
261fn status_code_from_multer_error(err: &multer::Error) -> StatusCode {
262 match err {
263 multer::Error::UnknownField { .. }
264 | multer::Error::IncompleteFieldData { .. }
265 | multer::Error::IncompleteHeaders
266 | multer::Error::ReadHeaderFailed(..)
267 | multer::Error::DecodeHeaderName { .. }
268 | multer::Error::DecodeContentType(..)
269 | multer::Error::NoBoundary
270 | multer::Error::DecodeHeaderValue { .. }
271 | multer::Error::NoMultipart
272 | multer::Error::IncompleteStream => StatusCode::BAD_REQUEST,
273 multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => {
274 StatusCode::PAYLOAD_TOO_LARGE
275 }
276 multer::Error::StreamReadFailed(err) => {
277 if let Some(err) = err.downcast_ref::<multer::Error>() {
278 return status_code_from_multer_error(err);
279 }
280
281 if err
282 .downcast_ref::<crate::Error>()
283 .and_then(|err| err.source())
284 .and_then(|err| err.downcast_ref::<http_body_util::LengthLimitError>())
285 .is_some()
286 {
287 return StatusCode::PAYLOAD_TOO_LARGE;
288 }
289
290 StatusCode::INTERNAL_SERVER_ERROR
291 }
292 _ => StatusCode::INTERNAL_SERVER_ERROR,
293 }
294}
295
296fn is_body_limit_error(err: &multer::Error) -> bool {
297 match err {
298 multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => true,
299 multer::Error::StreamReadFailed(err) => {
300 if let Some(err) = err.downcast_ref::<multer::Error>() {
301 return is_body_limit_error(err);
302 }
303 err.downcast_ref::<crate::Error>()
304 .and_then(|err| err.source())
305 .and_then(|err| err.downcast_ref::<http_body_util::LengthLimitError>())
306 .is_some()
307 }
308 _ => false,
309 }
310}
311
312impl fmt::Display for MultipartError {
313 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
314 write!(f, "Error parsing `multipart/form-data` request")
315 }
316}
317
318impl std::error::Error for MultipartError {
319 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
320 Some(&self.source)
321 }
322}
323
324impl IntoResponse for MultipartError {
325 fn into_response(self) -> Response {
326 let body = self.body_text();
327 axum_core::__log_rejection!(
328 rejection_type = Self,
329 body_text = body,
330 status = self.status(),
331 );
332 (self.status(), body).into_response()
333 }
334}
335
336fn content_type_str(headers: &HeaderMap) -> Option<&str> {
337 headers.get(CONTENT_TYPE)?.to_str().ok()
338}
339
composite_rejection! {
    /// Rejection used by the [`Multipart`] extractor (both the required and the optional form).
    ///
    /// Contains one variant for each way extracting a [`Multipart`] can fail.
    pub enum MultipartRejection {
        InvalidBoundary,
    }
}
348
define_rejection! {
    #[status = BAD_REQUEST]
    #[body = "Invalid `boundary` for `multipart/form-data` request"]
    /// Rejection used when the `Content-Type` of a request is missing, unreadable, or does not
    /// yield a usable `multipart/form-data` boundary.
    ///
    /// Responds with `400 Bad Request` and the body text above.
    pub struct InvalidBoundary;
}
356
#[cfg(test)]
mod tests {
    use axum_core::extract::DefaultBodyLimit;

    use super::*;
    use crate::{routing::post, test_helpers::*, Router};

    #[crate::test]
    async fn content_type_with_encoding() {
        const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();
        const FILE_NAME: &str = "index.html";
        const CONTENT_TYPE: &str = "text/html; charset=utf-8";

        async fn handle(mut multipart: Multipart) -> impl IntoResponse {
            let field = multipart.next_field().await.unwrap().unwrap();

            assert_eq!(field.name().unwrap(), "file");
            assert_eq!(field.file_name().unwrap(), FILE_NAME);
            assert_eq!(field.content_type().unwrap(), CONTENT_TYPE);
            assert_eq!(field.headers()["foo"], "bar");
            assert_eq!(field.bytes().await.unwrap(), BYTES);

            assert!(multipart.next_field().await.unwrap().is_none());
        }

        let app = Router::new().route("/", post(handle));

        let client = TestClient::new(app);

        let form = reqwest::multipart::Form::new().part(
            "file",
            reqwest::multipart::Part::bytes(BYTES)
                .file_name(FILE_NAME)
                .mime_str(CONTENT_TYPE)
                .unwrap()
                .headers(reqwest::header::HeaderMap::from_iter([(
                    reqwest::header::HeaderName::from_static("foo"),
                    reqwest::header::HeaderValue::from_static("bar"),
                )])),
        );

        // Without this check a rejected extraction would skip the handler (and all of its
        // assertions) and the test would still pass.
        let res = client.post("/").multipart(form).await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    fn _multipart_from_request_limited() {
        async fn handler(_: Multipart) {}
        let _app: Router = Router::new()
            .route("/", post(handler))
            .layer(tower_http::limit::RequestBodyLimitLayer::new(1024));
    }

    #[crate::test]
    async fn body_too_large() {
        const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();

        async fn handle(mut multipart: Multipart) -> Result<(), MultipartError> {
            while let Some(field) = multipart.next_field().await? {
                field.bytes().await?;
            }
            Ok(())
        }

        let app = Router::new()
            .route("/", post(handle))
            .layer(DefaultBodyLimit::max(BYTES.len() - 1));

        let client = TestClient::new(app);

        let form =
            reqwest::multipart::Form::new().part("file", reqwest::multipart::Part::bytes(BYTES));

        let res = client.post("/").multipart(form).await;
        assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
        assert_eq!(res.text().await, "Request payload is too large");
    }

    #[crate::test]
    async fn invalid_boundary_is_rejected() {
        async fn handle(_: Multipart) {}

        let app = Router::new().route("/", post(handle));
        let client = TestClient::new(app);

        // No `Content-Type` header at all.
        let res = client.post("/").await;
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);

        // `multipart/form-data` without the mandatory `boundary` parameter.
        let res = client
            .post("/")
            .header("content-type", "multipart/form-data")
            .await;
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        assert_eq!(
            res.text().await,
            "Invalid `boundary` for `multipart/form-data` request"
        );
    }

    #[crate::test]
    async fn optional_multipart() {
        const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();

        async fn handle(multipart: Option<Multipart>) -> Result<StatusCode, MultipartError> {
            if let Some(mut multipart) = multipart {
                while let Some(field) = multipart.next_field().await? {
                    field.bytes().await?;
                }
                Ok(StatusCode::OK)
            } else {
                Ok(StatusCode::NO_CONTENT)
            }
        }

        let app = Router::new().route("/", post(handle));
        let client = TestClient::new(app);
        let form =
            reqwest::multipart::Form::new().part("file", reqwest::multipart::Part::bytes(BYTES));

        let res = client.post("/").multipart(form).await;
        assert_eq!(res.status(), StatusCode::OK);

        // No `Content-Type`: the optional extractor is absent.
        let res = client.post("/").await;
        assert_eq!(res.status(), StatusCode::NO_CONTENT);

        // A non-multipart `Content-Type` is also treated as "absent", not as an error.
        let res = client
            .post("/")
            .header("content-type", "application/json")
            .body("{}")
            .await;
        assert_eq!(res.status(), StatusCode::NO_CONTENT);

        // A multipart `Content-Type` without a boundary is malformed and still rejected.
        let res = client
            .post("/")
            .header("content-type", "multipart/form-data")
            .await;
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
    }
}