use axum_core::{
__composite_rejection as composite_rejection, __define_rejection as define_rejection,
body::Body,
extract::FromRequest,
response::{IntoResponse, Response},
RequestExt,
};
use bytes::Bytes;
use futures_core::stream::Stream;
use http::{
header::{HeaderMap, CONTENT_TYPE},
Request, StatusCode,
};
use std::{
error::Error,
fmt,
pin::Pin,
task::{Context, Poll},
};
/// Extractor for `multipart/form-data` request bodies.
///
/// This is a thin wrapper around a `multer::Multipart` parser. Fields are
/// pulled one at a time with `next_field`, or the whole value can be turned
/// into a stream of fields with `into_stream`.
///
/// The extractor is built through `FromRequest`, which takes ownership of the
/// whole request, body included.
#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
#[derive(Debug)]
pub struct Multipart {
// Underlying `multer` parser that owns the request body stream.
inner: multer::Multipart<'static>,
}
/// Builds a [`Multipart`] out of an incoming request.
///
/// The request is taken by value and its body is handed to the parser; the
/// application state is not used.
impl<S> FromRequest<S> for Multipart
where
    S: Send + Sync,
{
    /// Produced when no usable `boundary` can be read from the request headers.
    type Rejection = MultipartRejection;

    /// Reads the multipart boundary from the headers and wraps the request body
    /// in a `multer` parser.
    ///
    /// The body is obtained through `RequestExt::with_limited_body` before it is
    /// converted into a data stream.
    ///
    /// # Errors
    ///
    /// Rejects with `InvalidBoundary` (converted into [`MultipartRejection`])
    /// when `parse_boundary` returns `None`.
    async fn from_request(req: Request<Body>, _state: &S) -> Result<Self, Self::Rejection> {
        // The headers have to be inspected before the request is consumed below.
        let Some(boundary) = parse_boundary(req.headers()) else {
            return Err(InvalidBoundary.into());
        };

        let body = req.with_limited_body().into_body();
        let inner = multer::Multipart::new(body.into_data_stream(), boundary);

        Ok(Self { inner })
    }
}
impl Multipart {
    /// Yields the next [`Field`], if there is one.
    ///
    /// Returns `Ok(None)` when the underlying parser reports that no further
    /// field is available.
    ///
    /// # Errors
    ///
    /// Any `multer::Error` raised while advancing the parser is wrapped in a
    /// [`MultipartError`].
    pub async fn next_field(&mut self) -> Result<Option<Field>, MultipartError> {
        match self.inner.next_field().await {
            Ok(next) => Ok(next.map(|inner| Field { inner })),
            Err(err) => Err(MultipartError::from_multer(err)),
        }
    }

    /// Converts this extractor into a [`Stream`] of its fields.
    ///
    /// Each item is produced by calling [`Multipart::next_field`]; the stream
    /// ends when that call returns `Ok(None)`, and an error from it is emitted
    /// as an `Err` item.
    pub fn into_stream(self) -> impl Stream<Item = Result<Field, MultipartError>> + Send + 'static {
        futures_util::stream::try_unfold(self, |mut parser| async move {
            // The parser itself is the unfold state: it is threaded through to
            // the next step together with every field it yields.
            match parser.next_field().await {
                Ok(Some(field)) => Ok(Some((field, parser))),
                Ok(None) => Ok(None),
                Err(err) => Err(err),
            }
        })
    }
}
/// A single field of a multipart request.
///
/// Wraps a `multer::Field`. The field exposes its metadata through accessor
/// methods and its contents either all at once (`bytes`, `text`), chunk by
/// chunk (`chunk`), or through its `Stream` implementation.
#[derive(Debug)]
pub struct Field {
// Underlying `multer` field that all methods delegate to.
inner: multer::Field<'static>,
}
/// Streams the contents of the field as chunks of [`Bytes`].
impl Stream for Field {
    type Item = Result<Bytes, MultipartError>;

    /// Polls the wrapped `multer` field and converts any error it yields into a
    /// [`MultipartError`]; every other poll outcome is passed through as is.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let inner = Pin::new(&mut self.inner);

        match inner.poll_next(cx) {
            Poll::Ready(Some(Ok(chunk))) => Poll::Ready(Some(Ok(chunk))),
            Poll::Ready(Some(Err(err))) => {
                Poll::Ready(Some(Err(MultipartError::from_multer(err))))
            }
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}
impl Field {
    /// The name of this field, as reported by the wrapped `multer` field.
    ///
    /// Returns `None` when the underlying field has no name.
    #[must_use]
    pub fn name(&self) -> Option<&str> {
        self.inner.name()
    }

    /// The file name of this field, as reported by the wrapped `multer` field.
    ///
    /// Returns `None` when the underlying field has no file name.
    #[must_use]
    pub fn file_name(&self) -> Option<&str> {
        self.inner.file_name()
    }

    /// The content type of this field, rendered as a string slice.
    ///
    /// Returns `None` when the underlying field has no content type.
    #[must_use]
    pub fn content_type(&self) -> Option<&str> {
        let mime = self.inner.content_type()?;
        Some(mime.as_ref())
    }

    /// The headers attached to this field.
    #[must_use]
    pub fn headers(&self) -> &HeaderMap {
        self.inner.headers()
    }

    /// Consumes the field and collects its full contents as raw bytes.
    ///
    /// # Errors
    ///
    /// Any `multer::Error` raised while reading is wrapped in a
    /// [`MultipartError`].
    pub async fn bytes(self) -> Result<Bytes, MultipartError> {
        match self.inner.bytes().await {
            Ok(data) => Ok(data),
            Err(err) => Err(MultipartError::from_multer(err)),
        }
    }

    /// Consumes the field and collects its full contents as a `String`.
    ///
    /// Decoding is delegated entirely to `multer::Field::text`.
    ///
    /// # Errors
    ///
    /// Any `multer::Error` raised while reading or decoding is wrapped in a
    /// [`MultipartError`].
    pub async fn text(self) -> Result<String, MultipartError> {
        match self.inner.text().await {
            Ok(text) => Ok(text),
            Err(err) => Err(MultipartError::from_multer(err)),
        }
    }

    /// Pulls the next chunk of this field's contents.
    ///
    /// The `Option` is forwarded unchanged from `multer::Field::chunk`;
    /// `Ok(None)` presumably marks the end of the field's data.
    ///
    /// # Errors
    ///
    /// Any `multer::Error` raised while reading is wrapped in a
    /// [`MultipartError`].
    pub async fn chunk(&mut self) -> Result<Option<Bytes>, MultipartError> {
        match self.inner.chunk().await {
            Ok(chunk) => Ok(chunk),
            Err(err) => Err(MultipartError::from_multer(err)),
        }
    }
}
/// Error produced while reading a `multipart/form-data` request.
///
/// Wraps the `multer::Error` that caused the failure; that error is exposed
/// through `std::error::Error::source`. The value can be turned into an HTTP
/// response via its `IntoResponse` implementation.
#[derive(Debug)]
pub struct MultipartError {
// The underlying parser error this value was created from.
source: multer::Error,
}
impl MultipartError {
fn from_multer(multer: multer::Error) -> Self {
Self { source: multer }
}
pub fn body_text(&self) -> String {
let body = if is_body_limit_error(&self.source) {
"Request payload is too large".to_owned()
} else {
self.source.to_string()
};
axum_core::__log_rejection!(
rejection_type = Self,
body_text = body,
status = self.status(),
);
body
}
#[must_use]
pub fn status(&self) -> http::StatusCode {
status_code_from_multer_error(&self.source)
}
}
fn status_code_from_multer_error(err: &multer::Error) -> StatusCode {
match err {
multer::Error::UnknownField { .. }
| multer::Error::IncompleteFieldData { .. }
| multer::Error::IncompleteHeaders
| multer::Error::ReadHeaderFailed(..)
| multer::Error::DecodeHeaderName { .. }
| multer::Error::DecodeContentType(..)
| multer::Error::NoBoundary
| multer::Error::DecodeHeaderValue { .. }
| multer::Error::NoMultipart
| multer::Error::IncompleteStream => StatusCode::BAD_REQUEST,
multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => {
StatusCode::PAYLOAD_TOO_LARGE
}
multer::Error::StreamReadFailed(err) => {
if let Some(err) = err.downcast_ref::<multer::Error>() {
return status_code_from_multer_error(err);
}
if err
.downcast_ref::<axum_core::Error>()
.and_then(|err| err.source())
.and_then(|err| err.downcast_ref::<http_body_util::LengthLimitError>())
.is_some()
{
return StatusCode::PAYLOAD_TOO_LARGE;
}
StatusCode::INTERNAL_SERVER_ERROR
}
_ => StatusCode::INTERNAL_SERVER_ERROR,
}
}
fn is_body_limit_error(err: &multer::Error) -> bool {
match err {
multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => true,
multer::Error::StreamReadFailed(err) => {
if let Some(err) = err.downcast_ref::<multer::Error>() {
return is_body_limit_error(err);
}
err.downcast_ref::<axum_core::Error>()
.and_then(|err| err.source())
.and_then(|err| err.downcast_ref::<http_body_util::LengthLimitError>())
.is_some()
}
_ => false,
}
}
/// Turns the error into a response made of its status code and body text.
impl IntoResponse for MultipartError {
    fn into_response(self) -> Response {
        let status = self.status();
        let body = self.body_text();
        (status, body).into_response()
    }
}
/// Writes a fixed, generic message; the detailed cause is available through
/// `std::error::Error::source`.
impl fmt::Display for MultipartError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Error parsing `multipart/form-data` request")
    }
}
/// Exposes the wrapped `multer::Error` as the source of this error.
impl std::error::Error for MultipartError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = &self.source;
        Some(cause)
    }
}
/// Extracts the multipart boundary from the `Content-Type` header.
///
/// Returns `None` when the header is absent, when its value cannot be read as
/// a string, or when `multer::parse_boundary` rejects it.
fn parse_boundary(headers: &HeaderMap) -> Option<String> {
    headers
        .get(CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .and_then(|content_type| multer::parse_boundary(content_type).ok())
}
composite_rejection! {
/// Rejection used by the `Multipart` extractor.
///
/// Has one variant for each way extraction can fail; at present that is only
/// an invalid or missing boundary.
pub enum MultipartRejection {
InvalidBoundary,
}
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "Invalid `boundary` for `multipart/form-data` request"]
/// Rejection used when no valid `boundary` could be read from the request's
/// `Content-Type` header.
pub struct InvalidBoundary;
}
// Integration-style tests that drive the extractor through a real router and
// the crate's test client.
#[cfg(test)]
mod tests {
use super::*;
use crate::test_helpers::*;
use axum::{extract::DefaultBodyLimit, routing::post, Router};
// Uploads a single file part whose content type carries a `charset` parameter
// and checks that the handler sees the file name, content type and bytes
// unchanged, and that no further field follows.
#[tokio::test]
async fn content_type_with_encoding() {
const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();
const FILE_NAME: &str = "index.html";
const CONTENT_TYPE: &str = "text/html; charset=utf-8";
// The assertions live inside the handler because that is where the parsed
// field is available.
async fn handle(mut multipart: Multipart) -> impl IntoResponse {
let field = multipart.next_field().await.unwrap().unwrap();
assert_eq!(field.file_name().unwrap(), FILE_NAME);
assert_eq!(field.content_type().unwrap(), CONTENT_TYPE);
assert_eq!(field.bytes().await.unwrap(), BYTES);
assert!(multipart.next_field().await.unwrap().is_none());
}
let app = Router::new().route("/", post(handle));
let client = TestClient::new(app);
let form = reqwest::multipart::Form::new().part(
"file",
reqwest::multipart::Part::bytes(BYTES)
.file_name(FILE_NAME)
.mime_str(CONTENT_TYPE)
.unwrap(),
);
client.post("/").multipart(form).await;
}
// Not a runtime test and never called: it only has to compile, which checks
// that a handler taking `Multipart` can be routed on a `Router<()>`.
fn _multipart_from_request_limited() {
async fn handler(_: Multipart) {}
let _app: Router<()> = Router::new().route("/", post(handler));
}
// Sends a payload while the body limit is set below the payload size and
// expects a 413 response carrying the fixed "too large" message.
#[tokio::test]
async fn body_too_large() {
const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();
// Reads every field to the end so the limit is actually hit; the error is
// returned and converted into the response.
async fn handle(mut multipart: Multipart) -> Result<(), MultipartError> {
while let Some(field) = multipart.next_field().await? {
field.bytes().await?;
}
Ok(())
}
let app = Router::new()
.route("/", post(handle))
// One byte less than the file content alone, so the request body cannot fit.
.layer(DefaultBodyLimit::max(BYTES.len() - 1));
let client = TestClient::new(app);
let form =
reqwest::multipart::Form::new().part("file", reqwest::multipart::Part::bytes(BYTES));
let res = client.post("/").multipart(form).await;
assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
assert_eq!(res.text().await, "Request payload is too large");
}
// Same scenario as `body_too_large`, but the error is converted into a response
// while a TRACE-level subscriber is installed, so the rejection-logging path
// runs with tracing enabled.
#[tokio::test]
#[cfg(feature = "tracing")]
async fn body_too_large_with_tracing() {
const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();
async fn handle(mut multipart: Multipart) -> impl IntoResponse {
let result: Result<(), MultipartError> = async {
while let Some(field) = multipart.next_field().await? {
field.bytes().await?;
}
Ok(())
}
.await;
// The subscriber discards its output; it only needs to enable TRACE events.
let subscriber = tracing_subscriber::FmtSubscriber::builder()
.with_max_level(tracing::level_filters::LevelFilter::TRACE)
.with_writer(std::io::sink)
.finish();
// The guard keeps the subscriber installed only while the response is built.
let guard = tracing::subscriber::set_default(subscriber);
let response = result.into_response();
drop(guard);
response
}
let app = Router::new()
.route("/", post(handle))
.layer(DefaultBodyLimit::max(BYTES.len() - 1));
let client = TestClient::new(app);
let form =
reqwest::multipart::Form::new().part("file", reqwest::multipart::Part::bytes(BYTES));
let res = client.post("/").multipart(form).await;
assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
}
}