use std::fmt;
use axum::extract::multipart::{Field, MultipartError, MultipartRejection};
use axum::extract::{FromRequest, Request};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
/// Every failure that can occur while extracting a typed value from a
/// multipart request.
///
/// The `Display` text of a variant is sent as the HTTP response body and the
/// `IntoResponse` implementation chooses the status code noted on each
/// variant below.
#[derive(Debug)]
pub enum TypedMultipartError {
/// The request was rejected while building the `Multipart` extractor
/// (created through `From<MultipartRejection>`). Status: 400.
InvalidRequest {
source: MultipartRejection,
},
/// Reading the multipart body failed (created through
/// `From<MultipartError>`, e.g. when reading a field chunk). Status: 400.
InvalidRequestBody {
source: MultipartError,
},
/// A field named `field_name` is missing. Status: 400.
/// NOTE(review): not constructed in this file; presumably raised by
/// `TryFromMultipartWithState` implementations.
MissingField {
field_name: String,
},
/// The contents of `field_name` could not be parsed as the requested type.
/// Status: 415.
WrongFieldType {
field_name: String,
// Name of the Rust type the parser was asked to produce (e.g. "bool").
wanted: String,
// Stringified description of the underlying parse/decoding failure.
source: String,
},
/// The field `field_name` appeared more than once. Status: 400.
/// NOTE(review): not constructed in this file; presumably raised by
/// `TryFromMultipartWithState` implementations.
DuplicateField {
field_name: String,
},
/// The field `field_name` is not known to the target type. Status: 400.
/// NOTE(review): not constructed in this file; presumably raised by
/// `TryFromMultipartWithState` implementations.
UnknownField {
field_name: String,
},
/// `value` is not a valid enum value for `field_name`. Status: 400.
/// NOTE(review): not constructed in this file; presumably raised by
/// enum field parsers defined elsewhere.
InvalidEnumValue {
field_name: String,
value: String,
},
/// A field without a name was encountered. Status: 400.
NamelessField,
/// The field's data exceeded `limit_bytes` while being read. Status: 413.
FieldTooLarge {
field_name: String,
limit_bytes: usize,
},
/// Catch-all carrying a stringified error; in this file it is used for
/// temporary-file creation and write failures. Status: 500.
Other {
source: String,
},
}
/// Human-readable rendering of every error variant.
///
/// The `IntoResponse` implementation sends this text as the response body,
/// so the wording is client-facing.
impl fmt::Display for TypedMultipartError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Payload-free variant: a fixed message, nothing to interpolate.
            Self::NamelessField => f.write_str("Encountered a field without a name"),

            // Variants that surface an underlying error message.
            Self::Other { source } => write!(f, "{source}"),
            Self::InvalidRequest { source } => write!(f, "Invalid multipart request: {source}"),
            Self::InvalidRequestBody { source } => write!(f, "Invalid multipart body: {source}"),

            // Variants that only name the offending field.
            Self::MissingField { field_name } => write!(f, "Missing field: `{field_name}`"),
            Self::DuplicateField { field_name } => write!(f, "Duplicate field: `{field_name}`"),
            Self::UnknownField { field_name } => write!(f, "Unknown field: `{field_name}`"),

            // Variants that carry extra detail about the offending field.
            Self::InvalidEnumValue { field_name, value } => {
                write!(f, "Invalid enum value `{value}` for field `{field_name}`")
            }
            Self::FieldTooLarge { field_name, limit_bytes } => {
                write!(f, "Field `{field_name}` exceeds size limit of {limit_bytes} bytes")
            }
            Self::WrongFieldType { field_name, wanted, source } => {
                write!(f, "Wrong type for field `{field_name}` (expected {wanted}): {source}")
            }
        }
    }
}
impl std::error::Error for TypedMultipartError {}
/// Converts the error into an HTTP response: a status code chosen per
/// variant, with the `Display` text as the body.
impl IntoResponse for TypedMultipartError {
    fn into_response(self) -> Response {
        // Render the message first; the match below consumes `self`.
        let body = self.to_string();
        let status = match self {
            Self::Other { .. } => StatusCode::INTERNAL_SERVER_ERROR,
            Self::FieldTooLarge { .. } => StatusCode::PAYLOAD_TOO_LARGE,
            Self::WrongFieldType { .. } => StatusCode::UNSUPPORTED_MEDIA_TYPE,
            // Everything else is a malformed request.
            Self::NamelessField
            | Self::InvalidEnumValue { .. }
            | Self::UnknownField { .. }
            | Self::DuplicateField { .. }
            | Self::MissingField { .. }
            | Self::InvalidRequestBody { .. }
            | Self::InvalidRequest { .. } => StatusCode::BAD_REQUEST,
        };
        (status, body).into_response()
    }
}
impl From<MultipartError> for TypedMultipartError {
fn from(source: MultipartError) -> Self {
Self::InvalidRequestBody { source }
}
}
impl From<MultipartRejection> for TypedMultipartError {
fn from(source: MultipartRejection) -> Self {
Self::InvalidRequest { source }
}
}
/// Builds a value of the implementing type from a whole multipart request.
///
/// `TypedMultipart<T>`'s `FromRequest` implementation calls this with the
/// request's `Multipart` extractor and the handler state.
pub trait TryFromMultipartWithState<S: Send + Sync>: Sized {
/// Reads from `multipart` (with access to `state`) and resolves to the
/// constructed value or a `TypedMultipartError`.
///
/// The returned future must be `Send`.
fn try_from_multipart_with_state(
multipart: &mut axum::extract::Multipart,
state: &S,
) -> impl std::future::Future<Output = Result<Self, TypedMultipartError>> + Send;
}
/// Builds a value of the implementing type from a single multipart field.
///
/// This file provides implementations for `String`, `bool`, `char`, the
/// primitive numeric types, `tempfile::NamedTempFile` and `FieldData<T>`.
pub trait TryFromFieldWithState<S: Send + Sync>: Sized {
/// Consumes `field` and resolves to the parsed value or a
/// `TypedMultipartError`.
///
/// `limit_bytes` is an optional cap on the field's size in bytes; the
/// implementations in this file fail with `FieldTooLarge` once the data
/// read exceeds it. `state` is the handler state (unused by the
/// implementations in this file, apart from being forwarded).
///
/// The returned future must be `Send`.
fn try_from_field_with_state(
field: Field<'_>,
limit_bytes: Option<usize>,
state: &S,
) -> impl std::future::Future<Output = Result<Self, TypedMultipartError>> + Send;
}
/// Owned snapshot of the descriptive information attached to a multipart
/// field, captured from a `Field` before its body is consumed.
#[derive(Debug, Clone)]
pub struct FieldMetadata {
// Copied from `Field::name()`; `None` when the field has no name.
pub name: Option<String>,
// Copied from `Field::file_name()`.
pub file_name: Option<String>,
// Copied from `Field::content_type()`.
pub content_type: Option<String>,
// Clone of the field's own header map (`Field::headers()`).
pub headers: axum::http::HeaderMap,
}
impl From<&Field<'_>> for FieldMetadata {
fn from(field: &Field<'_>) -> Self {
Self {
name: field.name().map(String::from),
file_name: field.file_name().map(String::from),
content_type: field.content_type().map(String::from),
headers: field.headers().clone(),
}
}
}
/// A parsed field value paired with the metadata of the field it came from.
#[derive(Debug)]
pub struct FieldData<T> {
// Name, file name, content type and headers captured before parsing.
pub metadata: FieldMetadata,
// The value produced by `T`'s own field parser.
pub contents: T,
}
impl<T, S> TryFromFieldWithState<S> for FieldData<T>
where
T: TryFromFieldWithState<S> + Send,
S: Send + Sync,
{
async fn try_from_field_with_state(
field: Field<'_>,
limit_bytes: Option<usize>,
state: &S,
) -> Result<Self, TypedMultipartError> {
let metadata = FieldMetadata::from(&field);
let contents = T::try_from_field_with_state(field, limit_bytes, state).await?;
Ok(Self { metadata, contents })
}
}
/// Extractor wrapper around a value of type `T` built from a multipart
/// request.
///
/// The inner value is public (`.0`) and also reachable through `Deref` /
/// `DerefMut`, so handlers can use the wrapper as if it were the value.
#[derive(Debug)]
pub struct TypedMultipart<T>(pub T);

/// Gives shared access to the wrapped value.
impl<T> std::ops::Deref for TypedMultipart<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

/// Gives mutable access to the wrapped value.
impl<T> std::ops::DerefMut for TypedMultipart<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
impl<T, S> FromRequest<S> for TypedMultipart<T>
where
T: TryFromMultipartWithState<S>,
S: Send + Sync + 'static,
{
type Rejection = TypedMultipartError;
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
let mut multipart = axum::extract::Multipart::from_request(req, state)
.await
.map_err(TypedMultipartError::from)?;
let value = T::try_from_multipart_with_state(&mut multipart, state).await?;
Ok(Self(value))
}
}
async fn read_field_data(
mut field: Field<'_>,
limit: Option<usize>,
) -> Result<(String, Vec<u8>), TypedMultipartError> {
let field_name = field.name().unwrap_or_default().to_string();
let data = if let Some(limit) = limit {
let mut buf = Vec::new();
while let Some(chunk) = field.chunk().await? {
buf.extend_from_slice(&chunk);
if buf.len() > limit {
return Err(TypedMultipartError::FieldTooLarge {
field_name,
limit_bytes: limit,
});
}
}
buf
} else {
field.bytes().await?.to_vec()
};
Ok((field_name, data))
}
/// Interprets a textual form value as a boolean.
///
/// Accepts `true`/`yes`/`y`/`1`/`on` as `true` and `false`/`no`/`n`/`0`/`off`
/// as `false`, ignoring ASCII case and surrounding whitespace (the numeric
/// field parsers trim their input as well). Returns `None` for anything
/// else, including empty or whitespace-only input.
fn str_to_bool(s: &str) -> Option<bool> {
    const TRUTHY: [&str; 5] = ["true", "yes", "y", "1", "on"];
    const FALSY: [&str; 5] = ["false", "no", "n", "0", "off"];

    let s = s.trim();
    // Case-insensitive comparison without allocating a lowercase copy.
    let is_one_of = |words: &[&str]| words.iter().any(|word| s.eq_ignore_ascii_case(word));

    if is_one_of(&TRUTHY) {
        Some(true)
    } else if is_one_of(&FALSY) {
        Some(false)
    } else {
        None
    }
}
/// Parses a field as UTF-8 text.
impl<S: Send + Sync> TryFromFieldWithState<S> for String {
    /// Reads the field body (honouring `limit_bytes`) and decodes it as
    /// UTF-8; invalid UTF-8 yields `WrongFieldType`.
    async fn try_from_field_with_state(
        field: Field<'_>,
        limit_bytes: Option<usize>,
        _state: &S,
    ) -> Result<Self, TypedMultipartError> {
        let (field_name, bytes) = read_field_data(field, limit_bytes).await?;
        match String::from_utf8(bytes) {
            Ok(text) => Ok(text),
            Err(utf8_err) => Err(TypedMultipartError::WrongFieldType {
                field_name,
                wanted: String::from("String"),
                source: utf8_err.to_string(),
            }),
        }
    }
}
/// Parses a field as a boolean using the spellings accepted by
/// `str_to_bool`.
impl<S: Send + Sync> TryFromFieldWithState<S> for bool {
    /// Reads the field body (honouring `limit_bytes`), decodes it as UTF-8
    /// and maps it to a boolean; either step failing yields `WrongFieldType`.
    async fn try_from_field_with_state(
        field: Field<'_>,
        limit_bytes: Option<usize>,
        _state: &S,
    ) -> Result<Self, TypedMultipartError> {
        let (field_name, bytes) = read_field_data(field, limit_bytes).await?;

        // Both failure paths build the same error shape; only the cause differs.
        let wrong_type = move |cause: String| TypedMultipartError::WrongFieldType {
            field_name,
            wanted: "bool".to_string(),
            source: cause,
        };

        let text = match std::str::from_utf8(&bytes) {
            Ok(text) => text,
            Err(utf8_err) => return Err(wrong_type(utf8_err.to_string())),
        };
        match str_to_bool(text) {
            Some(flag) => Ok(flag),
            None => Err(wrong_type(format!("invalid boolean value: `{text}`"))),
        }
    }
}
/// Generates a `TryFromFieldWithState` implementation for each listed
/// numeric type.
///
/// Each generated parser reads the field body (honouring `limit_bytes`),
/// decodes it as UTF-8, trims surrounding whitespace and parses it with the
/// type's `FromStr` implementation. Both decoding and parsing failures are
/// reported as `WrongFieldType`, with the type's name as `wanted`.
macro_rules! impl_try_from_field_for_number {
    ($($num:ty),* $(,)?) => {
        $(
            impl<S: Send + Sync> TryFromFieldWithState<S> for $num {
                async fn try_from_field_with_state(
                    field: Field<'_>,
                    limit_bytes: Option<usize>,
                    _state: &S,
                ) -> Result<Self, TypedMultipartError> {
                    // Type name reported in error messages.
                    const WANTED: &str = stringify!($num);

                    let (field_name, bytes) = read_field_data(field, limit_bytes).await?;
                    let text = match std::str::from_utf8(&bytes) {
                        Ok(text) => text,
                        Err(utf8_err) => {
                            return Err(TypedMultipartError::WrongFieldType {
                                field_name,
                                wanted: WANTED.to_string(),
                                source: utf8_err.to_string(),
                            });
                        }
                    };
                    match text.trim().parse::<$num>() {
                        Ok(number) => Ok(number),
                        Err(parse_err) => Err(TypedMultipartError::WrongFieldType {
                            field_name,
                            wanted: WANTED.to_string(),
                            source: parse_err.to_string(),
                        }),
                    }
                }
            }
        )*
    };
}

impl_try_from_field_for_number!(
    i8, i16, i32, i64, i128, u8, u16, u32, u64, u128, isize, usize, f32, f64,
);
/// Parses a field whose body is exactly one character.
impl<S: Send + Sync> TryFromFieldWithState<S> for char {
    /// Reads the field body (honouring `limit_bytes`), decodes it as UTF-8
    /// and requires it to contain exactly one `char`. The text is not
    /// trimmed, so whitespace counts as a character. Any other content
    /// yields `WrongFieldType`.
    async fn try_from_field_with_state(
        field: Field<'_>,
        limit_bytes: Option<usize>,
        _state: &S,
    ) -> Result<Self, TypedMultipartError> {
        let (field_name, bytes) = read_field_data(field, limit_bytes).await?;
        let text = match std::str::from_utf8(&bytes) {
            Ok(text) => text,
            Err(utf8_err) => {
                return Err(TypedMultipartError::WrongFieldType {
                    field_name,
                    wanted: "char".to_string(),
                    source: utf8_err.to_string(),
                });
            }
        };

        let mut remaining = text.chars();
        let first = remaining.next();
        let nothing_follows = remaining.next().is_none();
        match first {
            Some(only) if nothing_follows => Ok(only),
            _ => Err(TypedMultipartError::WrongFieldType {
                field_name,
                wanted: "char".to_string(),
                source: "expected exactly one character".to_string(),
            }),
        }
    }
}
impl<S: Send + Sync> TryFromFieldWithState<S> for tempfile::NamedTempFile {
async fn try_from_field_with_state(
mut field: Field<'_>,
limit_bytes: Option<usize>,
_state: &S,
) -> Result<Self, TypedMultipartError> {
let field_name = field.name().unwrap_or_default().to_string();
let mut temp = Self::new().map_err(|e| TypedMultipartError::Other {
source: e.to_string(),
})?;
let mut total = 0usize;
while let Some(chunk) = field.chunk().await? {
total += chunk.len();
if let Some(limit) = limit_bytes
&& total > limit
{
return Err(TypedMultipartError::FieldTooLarge {
field_name,
limit_bytes: limit,
});
}
std::io::Write::write_all(&mut temp, &chunk).map_err(|e| {
TypedMultipartError::Other {
source: e.to_string(),
}
})?;
}
Ok(temp)
}
}
// Unit tests for the pieces that need no live multipart request:
// `str_to_bool`, the error type's `Display` text and status-code mapping,
// and the `Deref`/`DerefMut` behaviour of `TypedMultipart`.
#[cfg(test)]
mod tests {
use super::*;
use axum::http::StatusCode;
use axum::response::IntoResponse;
// Every accepted truthy spelling, in mixed case, maps to `Some(true)`.
#[test]
fn test_str_to_bool_truthy() {
for val in &[
"true", "True", "TRUE", "yes", "Yes", "y", "Y", "1", "on", "ON",
] {
assert_eq!(str_to_bool(val), Some(true), "expected true for `{val}`");
}
}
// Every accepted falsy spelling, in mixed case, maps to `Some(false)`.
#[test]
fn test_str_to_bool_falsy() {
for val in &[
"false", "False", "FALSE", "no", "No", "n", "N", "0", "off", "OFF",
] {
assert_eq!(str_to_bool(val), Some(false), "expected false for `{val}`");
}
}
// Unrecognised text, including the empty string, maps to `None`.
#[test]
fn test_str_to_bool_invalid() {
for val in &["maybe", "2", "", "yep", "nah"] {
assert_eq!(str_to_bool(val), None, "expected None for `{val}`");
}
}
// Exact `Display` wording for MissingField, FieldTooLarge and WrongFieldType.
#[test]
fn test_error_display() {
let err = TypedMultipartError::MissingField {
field_name: "name".to_string(),
};
assert_eq!(err.to_string(), "Missing field: `name`");
let err = TypedMultipartError::FieldTooLarge {
field_name: "file".to_string(),
limit_bytes: 1024,
};
assert_eq!(
err.to_string(),
"Field `file` exceeds size limit of 1024 bytes"
);
let err = TypedMultipartError::WrongFieldType {
field_name: "age".to_string(),
wanted: "i32".to_string(),
source: "invalid digit".to_string(),
};
assert_eq!(
err.to_string(),
"Wrong type for field `age` (expected i32): invalid digit"
);
}
// Exact `Display` wording for DuplicateField.
#[test]
fn test_error_display_duplicate_field() {
let err = TypedMultipartError::DuplicateField {
field_name: "email".to_string(),
};
assert_eq!(err.to_string(), "Duplicate field: `email`");
}
// Exact `Display` wording for UnknownField.
#[test]
fn test_error_display_unknown_field() {
let err = TypedMultipartError::UnknownField {
field_name: "foo".to_string(),
};
assert_eq!(err.to_string(), "Unknown field: `foo`");
}
// Exact `Display` wording for InvalidEnumValue.
#[test]
fn test_error_display_invalid_enum_value() {
let err = TypedMultipartError::InvalidEnumValue {
field_name: "status".to_string(),
value: "maybe".to_string(),
};
assert_eq!(
err.to_string(),
"Invalid enum value `maybe` for field `status`"
);
}
// Exact `Display` wording for NamelessField.
#[test]
fn test_error_display_nameless_field() {
let err = TypedMultipartError::NamelessField;
assert_eq!(err.to_string(), "Encountered a field without a name");
}
// `Other` displays its stored message verbatim.
#[test]
fn test_error_display_other() {
let err = TypedMultipartError::Other {
source: "something went wrong".to_string(),
};
assert_eq!(err.to_string(), "something went wrong");
}
// Status-code mapping: DuplicateField -> 400.
#[test]
fn test_into_response_duplicate_field() {
let err = TypedMultipartError::DuplicateField {
field_name: "x".to_string(),
};
let resp = err.into_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
}
// Status-code mapping: UnknownField -> 400.
#[test]
fn test_into_response_unknown_field() {
let err = TypedMultipartError::UnknownField {
field_name: "x".to_string(),
};
let resp = err.into_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
}
// Status-code mapping: InvalidEnumValue -> 400.
#[test]
fn test_into_response_invalid_enum_value() {
let err = TypedMultipartError::InvalidEnumValue {
field_name: "x".to_string(),
value: "bad".to_string(),
};
let resp = err.into_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
}
// Status-code mapping: NamelessField -> 400.
#[test]
fn test_into_response_nameless_field() {
let err = TypedMultipartError::NamelessField;
let resp = err.into_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
}
// Status-code mapping: WrongFieldType -> 415.
#[test]
fn test_into_response_wrong_field_type() {
let err = TypedMultipartError::WrongFieldType {
field_name: "age".to_string(),
wanted: "i32".to_string(),
source: "err".to_string(),
};
let resp = err.into_response();
assert_eq!(resp.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
}
// Status-code mapping: FieldTooLarge -> 413.
#[test]
fn test_into_response_field_too_large() {
let err = TypedMultipartError::FieldTooLarge {
field_name: "file".to_string(),
limit_bytes: 100,
};
let resp = err.into_response();
assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE);
}
// Status-code mapping: Other -> 500.
#[test]
fn test_into_response_other() {
let err = TypedMultipartError::Other {
source: "err".to_string(),
};
let resp = err.into_response();
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
// Status-code mapping: MissingField -> 400.
#[test]
fn test_into_response_missing_field() {
let err = TypedMultipartError::MissingField {
field_name: "x".to_string(),
};
let resp = err.into_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
}
// The error type can be boxed as `dyn std::error::Error`.
#[test]
fn test_error_trait_is_implemented() {
let err: Box<dyn std::error::Error> = Box::new(TypedMultipartError::Other {
source: "test".to_string(),
});
assert_eq!(err.to_string(), "test");
}
// `Deref` exposes the inner value and its methods through the wrapper.
#[test]
fn test_typed_multipart_deref() {
let tm = TypedMultipart("hello".to_string());
assert_eq!(&*tm, "hello");
assert_eq!(tm.len(), 5); }
// `DerefMut` allows mutating the inner value through the wrapper.
#[test]
fn test_typed_multipart_deref_mut() {
let mut tm = TypedMultipart(vec![1, 2, 3]);
tm.push(4);
assert_eq!(&*tm, &[1, 2, 3, 4]);
}
}