use crate::storage::UploadedFile;
use axum::{
extract::{multipart::Field, FromRequest, Multipart, Request},
http::StatusCode,
response::{IntoResponse, Response},
};
use std::fmt;
pub const DEFAULT_MAX_FILE_SIZE: usize = 10 * 1024 * 1024;
pub const DEFAULT_MAX_FILES: usize = 10;
#[derive(Debug)]
pub enum FileUploadError {
MissingFile,
MultipleFiles,
MultipartError(String),
FileTooLarge {
actual: usize,
max: usize,
},
TooManyFiles {
actual: usize,
max: usize,
},
MissingField(String),
}
impl fmt::Display for FileUploadError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::MissingFile => write!(f, "No file found in upload"),
Self::MultipleFiles => write!(f, "Multiple files found, expected single file"),
Self::MultipartError(msg) => write!(f, "Multipart error: {msg}"),
Self::FileTooLarge { actual, max } => {
write!(f, "File size {actual} bytes exceeds maximum of {max} bytes")
}
Self::TooManyFiles { actual, max } => {
write!(f, "Upload contains {actual} files, maximum is {max}")
}
Self::MissingField(field) => write!(f, "Missing required field: {field}"),
}
}
}
impl std::error::Error for FileUploadError {}
impl IntoResponse for FileUploadError {
fn into_response(self) -> Response {
let status = match self {
Self::FileTooLarge { .. } => StatusCode::PAYLOAD_TOO_LARGE,
Self::MissingFile | Self::MissingField(_) | Self::MultipleFiles | Self::TooManyFiles { .. } | Self::MultipartError(_) => {
StatusCode::BAD_REQUEST
}
};
(status, self.to_string()).into_response()
}
}
#[derive(Debug)]
pub struct FileUpload(pub UploadedFile);
impl<S> FromRequest<S> for FileUpload
where
S: Send + Sync,
{
type Rejection = FileUploadError;
#[allow(clippy::manual_async_fn)]
fn from_request(
req: Request,
state: &S,
) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
async move {
let mut multipart = Multipart::from_request(req, state)
.await
.map_err(|e| FileUploadError::MultipartError(e.to_string()))?;
let mut files = Vec::new();
while let Some(field) = multipart
.next_field()
.await
.map_err(|e| FileUploadError::MultipartError(e.to_string()))?
{
if field.file_name().is_none() {
continue;
}
let filename = field
.file_name()
.ok_or_else(|| FileUploadError::MissingField("filename".to_string()))?
.to_string();
let content_type = field
.content_type()
.unwrap_or("application/octet-stream")
.to_string();
let data = read_field_data(field, DEFAULT_MAX_FILE_SIZE).await?;
files.push(UploadedFile {
filename,
content_type,
data,
});
}
match files.len() {
0 => Err(FileUploadError::MissingFile),
1 => Ok(Self(files.into_iter().next().unwrap())),
_ => Err(FileUploadError::MultipleFiles),
}
}
}
}
#[derive(Debug)]
pub struct MultiFileUpload(pub Vec<UploadedFile>);
impl<S> FromRequest<S> for MultiFileUpload
where
S: Send + Sync,
{
type Rejection = FileUploadError;
#[allow(clippy::manual_async_fn)]
fn from_request(
req: Request,
state: &S,
) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
async move {
let mut multipart = Multipart::from_request(req, state)
.await
.map_err(|e| FileUploadError::MultipartError(e.to_string()))?;
let mut files = Vec::new();
while let Some(field) = multipart
.next_field()
.await
.map_err(|e| FileUploadError::MultipartError(e.to_string()))?
{
if field.file_name().is_none() {
continue;
}
if files.len() >= DEFAULT_MAX_FILES {
return Err(FileUploadError::TooManyFiles {
actual: files.len() + 1,
max: DEFAULT_MAX_FILES,
});
}
let filename = field
.file_name()
.ok_or_else(|| FileUploadError::MissingField("filename".to_string()))?
.to_string();
let content_type = field
.content_type()
.unwrap_or("application/octet-stream")
.to_string();
let data = read_field_data(field, DEFAULT_MAX_FILE_SIZE).await?;
files.push(UploadedFile {
filename,
content_type,
data,
});
}
if files.is_empty() {
return Err(FileUploadError::MissingFile);
}
Ok(Self(files))
}
}
}
async fn read_field_data(
field: Field<'_>,
max_size: usize,
) -> Result<Vec<u8>, FileUploadError> {
let data = field
.bytes()
.await
.map_err(|e| FileUploadError::MultipartError(e.to_string()))?;
if data.len() > max_size {
return Err(FileUploadError::FileTooLarge {
actual: data.len(),
max: max_size,
});
}
Ok(data.to_vec())
}
#[cfg(test)]
mod tests {
use super::*;
use axum::http::{header, Request};
use axum::body::Body;
fn create_multipart_request(files: Vec<(&str, &str, &[u8])>) -> Request<Body> {
use std::fmt::Write;
let boundary = "----WebKitFormBoundary7MA4YWxkTrZu0gW";
let mut body = String::new();
for (name, filename, content) in files {
body.push_str("------WebKitFormBoundary7MA4YWxkTrZu0gW\r\n");
write!(
&mut body,
"Content-Disposition: form-data; name=\"{name}\"; filename=\"{filename}\"\r\n"
).unwrap();
body.push_str("Content-Type: application/octet-stream\r\n\r\n");
body.push_str(&String::from_utf8_lossy(content));
body.push_str("\r\n");
}
body.push_str("------WebKitFormBoundary7MA4YWxkTrZu0gW--\r\n");
Request::builder()
.method("POST")
.header(
header::CONTENT_TYPE,
format!("multipart/form-data; boundary={boundary}"),
)
.body(Body::from(body))
.unwrap()
}
#[tokio::test]
async fn test_single_file_upload() {
let req = create_multipart_request(vec![("file", "test.txt", b"Hello, World!")]);
let result = FileUpload::from_request(req, &()).await;
assert!(result.is_ok());
let FileUpload(file) = result.unwrap();
assert_eq!(file.filename, "test.txt");
assert_eq!(file.data, b"Hello, World!");
}
#[tokio::test]
async fn test_multiple_files_rejected_by_single_upload() {
let req = create_multipart_request(vec![
("file1", "test1.txt", b"File 1"),
("file2", "test2.txt", b"File 2"),
]);
let result = FileUpload::from_request(req, &()).await;
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), FileUploadError::MultipleFiles));
}
#[tokio::test]
async fn test_multi_file_upload() {
let req = create_multipart_request(vec![
("file1", "test1.txt", b"File 1"),
("file2", "test2.txt", b"File 2"),
]);
let result = MultiFileUpload::from_request(req, &()).await;
assert!(result.is_ok());
let MultiFileUpload(files) = result.unwrap();
assert_eq!(files.len(), 2);
assert_eq!(files[0].filename, "test1.txt");
assert_eq!(files[1].filename, "test2.txt");
}
#[tokio::test]
async fn test_missing_file() {
let req = Request::builder()
.method("POST")
.header(
header::CONTENT_TYPE,
"multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW",
)
.body(Body::from(
"------WebKitFormBoundary7MA4YWxkTrZu0gW--\r\n",
))
.unwrap();
let result = FileUpload::from_request(req, &()).await;
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), FileUploadError::MissingFile));
}
}