// axum_extra/extract/
multipart.rs1use axum_core::{
6 __composite_rejection as composite_rejection, __define_rejection as define_rejection,
7 body::Body,
8 extract::FromRequest,
9 response::{IntoResponse, Response},
10 RequestExt,
11};
12use bytes::Bytes;
13use futures_core::stream::Stream;
14use http::{
15 header::{HeaderMap, CONTENT_TYPE},
16 Request, StatusCode,
17};
18use std::{
19 error::Error,
20 fmt,
21 pin::Pin,
22 task::{Context, Poll},
23};
24
25#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
89#[derive(Debug)]
90pub struct Multipart {
91 inner: multer::Multipart<'static>,
92}
93
94impl<S> FromRequest<S> for Multipart
95where
96 S: Send + Sync,
97{
98 type Rejection = MultipartRejection;
99
100 async fn from_request(req: Request<Body>, _state: &S) -> Result<Self, Self::Rejection> {
101 let boundary = parse_boundary(req.headers()).ok_or(InvalidBoundary)?;
102 let stream = req.with_limited_body().into_body();
103 let multipart = multer::Multipart::new(stream.into_data_stream(), boundary);
104 Ok(Self { inner: multipart })
105 }
106}
107
108impl Multipart {
109 pub async fn next_field(&mut self) -> Result<Option<Field>, MultipartError> {
111 let field = self
112 .inner
113 .next_field()
114 .await
115 .map_err(MultipartError::from_multer)?;
116
117 if let Some(field) = field {
118 Ok(Some(Field { inner: field }))
119 } else {
120 Ok(None)
121 }
122 }
123
124 pub fn into_stream(self) -> impl Stream<Item = Result<Field, MultipartError>> + Send + 'static {
126 futures_util::stream::try_unfold(self, |mut multipart| async move {
127 let field = multipart.next_field().await?;
128 Ok(field.map(|field| (field, multipart)))
129 })
130 }
131}
132
133#[derive(Debug)]
135pub struct Field {
136 inner: multer::Field<'static>,
137}
138
139impl Stream for Field {
140 type Item = Result<Bytes, MultipartError>;
141
142 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
143 Pin::new(&mut self.inner)
144 .poll_next(cx)
145 .map_err(MultipartError::from_multer)
146 }
147}
148
149impl Field {
150 #[must_use]
154 pub fn name(&self) -> Option<&str> {
155 self.inner.name()
156 }
157
158 #[must_use]
162 pub fn file_name(&self) -> Option<&str> {
163 self.inner.file_name()
164 }
165
166 #[must_use]
168 pub fn content_type(&self) -> Option<&str> {
169 self.inner.content_type().map(|m| m.as_ref())
170 }
171
172 #[must_use]
174 pub fn headers(&self) -> &HeaderMap {
175 self.inner.headers()
176 }
177
178 pub async fn bytes(self) -> Result<Bytes, MultipartError> {
180 self.inner
181 .bytes()
182 .await
183 .map_err(MultipartError::from_multer)
184 }
185
186 pub async fn text(self) -> Result<String, MultipartError> {
188 self.inner.text().await.map_err(MultipartError::from_multer)
189 }
190
191 pub async fn chunk(&mut self) -> Result<Option<Bytes>, MultipartError> {
230 self.inner
231 .chunk()
232 .await
233 .map_err(MultipartError::from_multer)
234 }
235}
236
237#[derive(Debug)]
239pub struct MultipartError {
240 source: multer::Error,
241}
242
243impl MultipartError {
244 fn from_multer(multer: multer::Error) -> Self {
245 Self { source: multer }
246 }
247
248 pub fn body_text(&self) -> String {
250 let body = if is_body_limit_error(&self.source) {
251 "Request payload is too large".to_owned()
252 } else {
253 self.source.to_string()
254 };
255 axum_core::__log_rejection!(
256 rejection_type = Self,
257 body_text = body,
258 status = self.status(),
259 );
260 body
261 }
262
263 #[must_use]
265 pub fn status(&self) -> http::StatusCode {
266 status_code_from_multer_error(&self.source)
267 }
268}
269
270fn status_code_from_multer_error(err: &multer::Error) -> StatusCode {
271 match err {
272 multer::Error::UnknownField { .. }
273 | multer::Error::IncompleteFieldData { .. }
274 | multer::Error::IncompleteHeaders
275 | multer::Error::ReadHeaderFailed(..)
276 | multer::Error::DecodeHeaderName { .. }
277 | multer::Error::DecodeContentType(..)
278 | multer::Error::NoBoundary
279 | multer::Error::DecodeHeaderValue { .. }
280 | multer::Error::NoMultipart
281 | multer::Error::IncompleteStream => StatusCode::BAD_REQUEST,
282 multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => {
283 StatusCode::PAYLOAD_TOO_LARGE
284 }
285 multer::Error::StreamReadFailed(err) => {
286 if let Some(err) = err.downcast_ref::<multer::Error>() {
287 return status_code_from_multer_error(err);
288 }
289
290 if err
291 .downcast_ref::<axum_core::Error>()
292 .and_then(|err| err.source())
293 .and_then(|err| err.downcast_ref::<http_body_util::LengthLimitError>())
294 .is_some()
295 {
296 return StatusCode::PAYLOAD_TOO_LARGE;
297 }
298
299 StatusCode::INTERNAL_SERVER_ERROR
300 }
301 _ => StatusCode::INTERNAL_SERVER_ERROR,
302 }
303}
304
305fn is_body_limit_error(err: &multer::Error) -> bool {
306 match err {
307 multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => true,
308 multer::Error::StreamReadFailed(err) => {
309 if let Some(err) = err.downcast_ref::<multer::Error>() {
310 return is_body_limit_error(err);
311 }
312 err.downcast_ref::<axum_core::Error>()
313 .and_then(|err| err.source())
314 .and_then(|err| err.downcast_ref::<http_body_util::LengthLimitError>())
315 .is_some()
316 }
317 _ => false,
318 }
319}
320
321impl IntoResponse for MultipartError {
322 fn into_response(self) -> Response {
323 (self.status(), self.body_text()).into_response()
324 }
325}
326
327impl fmt::Display for MultipartError {
328 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
329 write!(f, "Error parsing `multipart/form-data` request")
330 }
331}
332
333impl std::error::Error for MultipartError {
334 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
335 Some(&self.source)
336 }
337}
338
339fn parse_boundary(headers: &HeaderMap) -> Option<String> {
340 let content_type = headers.get(CONTENT_TYPE)?.to_str().ok()?;
341 multer::parse_boundary(content_type).ok()
342}
343
344composite_rejection! {
345 pub enum MultipartRejection {
349 InvalidBoundary,
350 }
351}
352
353define_rejection! {
354 #[status = BAD_REQUEST]
355 #[body = "Invalid `boundary` for `multipart/form-data` request"]
356 pub struct InvalidBoundary;
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{extract::DefaultBodyLimit, routing::post, Router};

    #[tokio::test]
    async fn content_type_with_encoding() {
        const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();
        const FILE_NAME: &str = "index.html";
        const CONTENT_TYPE: &str = "text/html; charset=utf-8";

        async fn handle(mut multipart: Multipart) -> impl IntoResponse {
            let field = multipart.next_field().await.unwrap().unwrap();

            assert_eq!(field.file_name().unwrap(), FILE_NAME);
            assert_eq!(field.content_type().unwrap(), CONTENT_TYPE);
            assert_eq!(field.bytes().await.unwrap(), BYTES);

            assert!(multipart.next_field().await.unwrap().is_none());
        }

        let app = Router::new().route("/", post(handle));

        let client = TestClient::new(app);

        let form = reqwest::multipart::Form::new().part(
            "file",
            reqwest::multipart::Part::bytes(BYTES)
                .file_name(FILE_NAME)
                .mime_str(CONTENT_TYPE)
                .unwrap(),
        );

        let res = client.post("/").multipart(form).await;
        // `handle` returns `()` once all of its assertions hold, so any status other than
        // `200 OK` means the extractor or the handler did not complete successfully.
        assert_eq!(res.status(), StatusCode::OK);
    }

    // Deliberately not a `#[test]`: it only has to compile, which proves that `Multipart`
    // is usable as an extractor in a plain `Router<()>`.
    fn _multipart_from_request_limited() {
        async fn handler(_: Multipart) {}
        let _app: Router<()> = Router::new().route("/", post(handler));
    }

    #[tokio::test]
    async fn body_too_large() {
        const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();

        async fn handle(mut multipart: Multipart) -> Result<(), MultipartError> {
            while let Some(field) = multipart.next_field().await? {
                field.bytes().await?;
            }
            Ok(())
        }

        let app = Router::new()
            .route("/", post(handle))
            .layer(DefaultBodyLimit::max(BYTES.len() - 1));

        let client = TestClient::new(app);

        let form =
            reqwest::multipart::Form::new().part("file", reqwest::multipart::Part::bytes(BYTES));

        let res = client.post("/").multipart(form).await;
        assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
        assert_eq!(res.text().await, "Request payload is too large");
    }

    #[tokio::test]
    #[cfg(feature = "tracing")]
    async fn body_too_large_with_tracing() {
        const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();

        async fn handle(mut multipart: Multipart) -> impl IntoResponse {
            let result: Result<(), MultipartError> = async {
                while let Some(field) = multipart.next_field().await? {
                    field.bytes().await?;
                }
                Ok(())
            }
            .await;

            // Build the response with a TRACE-level subscriber installed so the rejection
            // logging path inside `MultipartError::body_text` is exercised.
            let subscriber = tracing_subscriber::FmtSubscriber::builder()
                .with_max_level(tracing::level_filters::LevelFilter::TRACE)
                .with_writer(std::io::sink)
                .finish();

            let guard = tracing::subscriber::set_default(subscriber);
            let response = result.into_response();
            drop(guard);

            response
        }

        let app = Router::new()
            .route("/", post(handle))
            .layer(DefaultBodyLimit::max(BYTES.len() - 1));

        let client = TestClient::new(app);

        let form =
            reqwest::multipart::Form::new().part("file", reqwest::multipart::Part::bytes(BYTES));

        let res = client.post("/").multipart(form).await;
        assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
        // Logging must not alter the body text the client receives.
        assert_eq!(res.text().await, "Request payload is too large");
    }
}