use axum::extract::{FromRequest, FromRequestParts};
use axum::response::{IntoResponse, Response};
/// Implements `Deref` and `DerefMut` for a generic single-field tuple struct
/// `$wrapper<T>`, so a wrapper value can be used wherever `&T` / `&mut T` is
/// expected. The target type is the wrapped `T` (field `.0`).
macro_rules! impl_extractor_deref {
    ($wrapper:ident) => {
        impl<T> ::std::ops::Deref for $wrapper<T> {
            type Target = T;

            fn deref(&self) -> &T {
                let $wrapper(inner) = self;
                inner
            }
        }

        impl<T> ::std::ops::DerefMut for $wrapper<T> {
            fn deref_mut(&mut self) -> &mut T {
                let $wrapper(inner) = self;
                inner
            }
        }
    };
}
/// Form extractor: a newtype around the extracted value `T`.
///
/// The `FromRequest` impl in this file delegates to `axum::extract::Form` and
/// reports extraction failures as `crate::AutumnError`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Form<T>(pub T);
// Generates `Deref`/`DerefMut` to the wrapped `T`.
impl_extractor_deref!(Form);
/// Extracts a [`Form`] by running axum's form extractor and translating its
/// rejection into `crate::AutumnError` (same status code and body text).
impl<S, T> FromRequest<S> for Form<T>
where
    S: Send + Sync,
    axum::extract::Form<T>: FromRequest<S, Rejection = axum::extract::rejection::FormRejection>,
{
    type Rejection = crate::AutumnError;

    async fn from_request(req: axum::extract::Request, state: &S) -> Result<Self, Self::Rejection> {
        match axum::extract::Form::<T>::from_request(req, state).await {
            Ok(axum::extract::Form(payload)) => Ok(Self(payload)),
            Err(rejection) => Err(rejection_to_error(
                rejection.status(),
                rejection.body_text(),
            )),
        }
    }
}
/// JSON extractor/response: a newtype around the value `T`.
///
/// The `FromRequest` impl in this file delegates to `axum::extract::Json` and
/// reports failures as `crate::AutumnError`; the `IntoResponse` impl delegates
/// to `axum::Json`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Json<T>(pub T);
// Generates `Deref`/`DerefMut` to the wrapped `T`.
impl_extractor_deref!(Json);
/// Extracts a [`Json`] by running axum's JSON extractor and translating its
/// rejection into `crate::AutumnError` (same status code and body text).
impl<S, T> FromRequest<S> for Json<T>
where
    S: Send + Sync,
    axum::extract::Json<T>: FromRequest<S, Rejection = axum::extract::rejection::JsonRejection>,
{
    type Rejection = crate::AutumnError;

    async fn from_request(req: axum::extract::Request, state: &S) -> Result<Self, Self::Rejection> {
        let axum::extract::Json(payload) = axum::extract::Json::<T>::from_request(req, state)
            .await
            .map_err(|rejection| rejection_to_error(rejection.status(), rejection.body_text()))?;
        Ok(Self(payload))
    }
}
impl<T> IntoResponse for Json<T>
where
axum::Json<T>: IntoResponse,
{
fn into_response(self) -> Response {
axum::Json(self.0).into_response()
}
}
/// Path-parameter extractor: a newtype around the extracted value `T`.
///
/// The `FromRequestParts` impl in this file delegates to `axum::extract::Path`
/// and reports extraction failures as `crate::AutumnError`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Path<T>(pub T);
// Generates `Deref`/`DerefMut` to the wrapped `T`.
impl_extractor_deref!(Path);
/// Extracts a [`Path`] by running axum's path extractor and translating its
/// rejection into `crate::AutumnError` (same status code and body text).
impl<S, T> FromRequestParts<S> for Path<T>
where
    S: Send + Sync,
    axum::extract::Path<T>:
        FromRequestParts<S, Rejection = axum::extract::rejection::PathRejection>,
{
    type Rejection = crate::AutumnError;

    async fn from_request_parts(
        parts: &mut axum::http::request::Parts,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        let extracted = axum::extract::Path::<T>::from_request_parts(parts, state).await;
        match extracted {
            Ok(axum::extract::Path(params)) => Ok(Self(params)),
            Err(rejection) => Err(rejection_to_error(
                rejection.status(),
                rejection.body_text(),
            )),
        }
    }
}
/// Query-string extractor: a newtype around the extracted value `T`.
///
/// The `FromRequestParts` impl in this file delegates to `axum::extract::Query`
/// and reports extraction failures as `crate::AutumnError`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Query<T>(pub T);
// Generates `Deref`/`DerefMut` to the wrapped `T`.
impl_extractor_deref!(Query);
/// Extracts a [`Query`] by running axum's query extractor and translating its
/// rejection into `crate::AutumnError` (same status code and body text).
impl<S, T> FromRequestParts<S> for Query<T>
where
    S: Send + Sync,
    axum::extract::Query<T>:
        FromRequestParts<S, Rejection = axum::extract::rejection::QueryRejection>,
{
    type Rejection = crate::AutumnError;

    async fn from_request_parts(
        parts: &mut axum::http::request::Parts,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        let axum::extract::Query(params) =
            axum::extract::Query::<T>::from_request_parts(parts, state)
                .await
                .map_err(|rejection| {
                    rejection_to_error(rejection.status(), rejection.body_text())
                })?;
        Ok(Self(params))
    }
}
fn rejection_to_error(status: http::StatusCode, body_text: String) -> crate::AutumnError {
crate::AutumnError::bad_request_msg(body_text).with_status(status)
}
/// Multipart extractor that pairs axum's multipart reader with the upload
/// configuration used to validate each field.
#[cfg(feature = "multipart")]
pub struct Multipart {
// Underlying axum extractor that yields the raw fields.
inner: axum::extract::Multipart,
// Upload settings; this file reads `allowed_mime_types`,
// `max_file_size_bytes` and `max_request_size_bytes` from it.
config: crate::security::config::UploadConfig,
}
#[cfg(feature = "multipart")]
impl Multipart {
    /// Returns the next multipart field, or `Ok(None)` once the body is exhausted.
    ///
    /// Fields that carry a file name are checked against
    /// `config.allowed_mime_types` when that list is non-empty. The comparison
    /// is ASCII case-insensitive and accepts either the full `Content-Type`
    /// value or its media type with parameters stripped, so
    /// `text/plain; charset=utf-8` matches an allowed `text/plain`.
    /// Fields without a file name are not checked.
    ///
    /// The returned field carries `config.max_file_size_bytes` as its size limit.
    ///
    /// # Errors
    ///
    /// - the underlying multipart error, converted by `multipart_error_to_error`;
    /// - a bad-request error when a checked file field has no content type;
    /// - a bad-request error when its content type is not in the allow-list.
    pub async fn next_field(&mut self) -> crate::AutumnResult<Option<MultipartField<'_>>> {
        let Some(field) = self
            .inner
            .next_field()
            .await
            .map_err(|err| multipart_error_to_error(&err))?
        else {
            return Ok(None);
        };

        if field.file_name().is_some() && !self.config.allowed_mime_types.is_empty() {
            let Some(content_type) = field.content_type() else {
                return Err(crate::AutumnError::bad_request_msg(
                    "missing content type on uploaded file",
                ));
            };

            // Media type without parameters, e.g. `text/plain` for
            // `text/plain; charset=utf-8`.
            let essence = content_type
                .split_once(';')
                .map_or(content_type, |(media_type, _params)| media_type)
                .trim();

            let is_allowed = self.config.allowed_mime_types.iter().any(|candidate| {
                candidate.eq_ignore_ascii_case(content_type)
                    || candidate.eq_ignore_ascii_case(essence)
            });
            if !is_allowed {
                return Err(crate::AutumnError::bad_request_msg(format!(
                    "unsupported upload content type: {content_type}"
                )));
            }
        }

        Ok(Some(MultipartField {
            inner: field,
            max_file_size_bytes: self.config.max_file_size_bytes,
        }))
    }
}
/// Extracts a [`Multipart`]: picks up the `UploadConfig` stored in the request
/// extensions (or its `Default`), applies `max_request_size_bytes` as the body
/// limit, then runs axum's multipart extractor. Rejections are converted with
/// `multipart_rejection_to_error`.
#[cfg(feature = "multipart")]
impl<S> axum::extract::FromRequest<S> for Multipart
where
    S: Send + Sync,
    axum::extract::Multipart:
        axum::extract::FromRequest<S, Rejection = axum::extract::multipart::MultipartRejection>,
{
    type Rejection = crate::AutumnError;

    async fn from_request(
        mut req: axum::extract::Request,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        let config = match req
            .extensions()
            .get::<crate::security::config::UploadConfig>()
        {
            Some(configured) => configured.clone(),
            None => crate::security::config::UploadConfig::default(),
        };

        // The limit must be applied before axum's extractor consumes the request.
        let body_limit = axum::extract::DefaultBodyLimit::max(config.max_request_size_bytes);
        body_limit.apply(&mut req);

        match axum::extract::Multipart::from_request(req, state).await {
            Ok(inner) => Ok(Self { inner, config }),
            Err(rejection) => Err(multipart_rejection_to_error(&rejection)),
        }
    }
}
/// A single multipart field together with the maximum number of bytes that the
/// size-limited readers in this file will accept from it.
#[cfg(feature = "multipart")]
pub struct MultipartField<'a> {
// Underlying axum field the data is read from.
inner: axum::extract::multipart::Field<'a>,
// Byte limit enforced by `bytes_limited`, `save_to` and `save_to_blob_store`.
max_file_size_bytes: usize,
}
/// State threaded through the chunk stream built by
/// `MultipartField::save_to_blob_store`.
#[cfg(all(feature = "multipart", feature = "storage"))]
struct MultipartFieldStreamState<'a> {
// Field the chunks are pulled from.
inner: axum::extract::multipart::Field<'a>,
// Number of bytes seen so far (accumulated with a saturating add).
total: usize,
// Byte limit; exceeding it makes the stream yield a `PayloadTooLarge` error.
max: usize,
// Set once an error item has been yielded, so the next poll ends the stream.
errored: bool,
}
#[cfg(feature = "multipart")]
impl<'a> MultipartField<'a> {
/// Returns the field name reported by the wrapped axum field, if any.
#[must_use]
pub fn name(&self) -> Option<&str> {
self.inner.name()
}
/// Returns the file name reported by the wrapped axum field, if any.
#[must_use]
pub fn file_name(&self) -> Option<&str> {
self.inner.file_name()
}
/// Returns the content type reported by the wrapped axum field, if any.
#[must_use]
pub fn content_type(&self) -> Option<&str> {
self.inner.content_type()
}
#[must_use]
pub fn with_max_bytes(mut self, max: usize) -> Self {
self.max_file_size_bytes = self.max_file_size_bytes.min(max);
self
}
/// Reads the whole field into memory, failing as soon as the accumulated
/// size exceeds the field's byte limit.
///
/// # Errors
///
/// Returns the converted multipart error if reading a chunk fails, or the
/// `file_too_large_error` (413) once the limit is exceeded.
pub async fn bytes_limited(mut self) -> crate::AutumnResult<Vec<u8>> {
    let limit = self.max_file_size_bytes;
    let mut collected: Vec<u8> = Vec::new();
    loop {
        let next = self
            .inner
            .chunk()
            .await
            .map_err(|err| multipart_error_to_error(&err))?;
        let Some(chunk) = next else {
            return Ok(collected);
        };
        // `collected` holds every accepted byte, so this is the running total.
        if collected.len() + chunk.len() > limit {
            return Err(file_too_large_error(limit));
        }
        collected.extend_from_slice(&chunk);
    }
}
/// Streams this field into `store` under `key` without buffering the whole
/// upload, enforcing the field's byte limit while streaming.
///
/// The blob's content type is the field's content type, or
/// `application/octet-stream` when the field declares none.
///
/// # Errors
///
/// Whatever `put_stream` returns, converted with
/// `BlobStoreError::into_autumn_error`. The stream itself yields
/// `PayloadTooLarge` when the limit is exceeded and a mapped multipart error
/// when reading a chunk fails; how the store surfaces those stream errors is
/// up to the `BlobStore` implementation (not visible here).
#[cfg(feature = "storage")]
pub async fn save_to_blob_store<'b>(
self,
store: &'b (dyn crate::storage::BlobStore + '_),
key: impl Into<String>,
) -> crate::AutumnResult<crate::storage::Blob>
where
'a: 'b,
{
let key = key.into();
// Copy the content type out before `self.inner` is moved into the stream state.
let content_type = self
.inner
.content_type()
.map_or_else(|| "application/octet-stream".to_owned(), str::to_owned);
let state = MultipartFieldStreamState {
inner: self.inner,
total: 0,
max: self.max_file_size_bytes,
errored: false,
};
// Each step pulls one chunk from the field and yields it (or an error).
let stream = futures::stream::unfold(state, |mut state| async move {
// An error item was already yielded on the previous step: end the stream.
if state.errored {
return None;
}
match state.inner.chunk().await {
Ok(Some(chunk)) => {
// Saturating add so the running total cannot wrap around.
state.total = state.total.saturating_add(chunk.len());
if state.total > state.max {
let err = crate::storage::BlobStoreError::PayloadTooLarge(format!(
"uploaded file exceeds limit of {} bytes",
state.max,
));
state.errored = true;
Some((Err(err), state))
} else {
Some((Ok(chunk), state))
}
}
// Field exhausted: normal end of stream.
Ok(None) => None,
Err(err) => {
state.errored = true;
let mapped = blob_error_from_multipart(&err);
Some((Err(mapped), state))
}
}
});
let stream: crate::storage::ByteStream<'b> = Box::pin(stream);
store
.put_stream(&key, &content_type, stream)
.await
.map_err(crate::storage::BlobStoreError::into_autumn_error)
}
pub async fn save_to<P: AsRef<std::path::Path>>(
mut self,
path: P,
) -> crate::AutumnResult<usize> {
use tokio::io::AsyncWriteExt as _;
let path = path.as_ref();
let mut file = tokio::fs::File::create(path)
.await
.map_err(crate::AutumnError::internal_server_error)?;
let mut written = 0usize;
while let Some(chunk) = self
.inner
.chunk()
.await
.map_err(|err| multipart_error_to_error(&err))?
{
written += chunk.len();
if written > self.max_file_size_bytes {
drop(file);
let _ = tokio::fs::remove_file(path).await;
return Err(file_too_large_error(self.max_file_size_bytes));
}
file.write_all(&chunk)
.await
.map_err(crate::AutumnError::internal_server_error)?;
}
file.flush()
.await
.map_err(crate::AutumnError::internal_server_error)?;
Ok(written)
}
}
/// Converts axum's multipart rejection into a `crate::AutumnError` with the
/// rejection's body text as message and its status code.
#[cfg(feature = "multipart")]
fn multipart_rejection_to_error(
    err: &axum::extract::multipart::MultipartRejection,
) -> crate::AutumnError {
    let message = err.body_text();
    let status = err.status();
    crate::AutumnError::bad_request_msg(message).with_status(status)
}
/// Maps an axum multipart error onto a `BlobStoreError` by HTTP status:
/// 413 becomes `PayloadTooLarge`, any other 4xx becomes `InvalidInput`, and
/// everything else becomes `Io`. The error's body text is used as the message.
#[cfg(all(feature = "multipart", feature = "storage"))]
fn blob_error_from_multipart(
    err: &axum::extract::multipart::MultipartError,
) -> crate::storage::BlobStoreError {
    let status = err.status();
    let body = err.body_text();
    if status == http::StatusCode::PAYLOAD_TOO_LARGE {
        crate::storage::BlobStoreError::PayloadTooLarge(body)
    } else if status.is_client_error() {
        crate::storage::BlobStoreError::InvalidInput(body)
    } else {
        crate::storage::BlobStoreError::Io(body)
    }
}
/// Converts an axum multipart error into a `crate::AutumnError` with the
/// error's body text as message and its status code.
#[cfg(feature = "multipart")]
fn multipart_error_to_error(err: &axum::extract::multipart::MultipartError) -> crate::AutumnError {
    let message = err.body_text();
    let status = err.status();
    crate::AutumnError::bad_request_msg(message).with_status(status)
}
/// Builds the 413 Payload Too Large error reported when an uploaded file
/// exceeds `max_file_size_bytes`.
#[cfg(feature = "multipart")]
fn file_too_large_error(max_file_size_bytes: usize) -> crate::AutumnError {
    let message = format!("uploaded file exceeds limit of {max_file_size_bytes} bytes");
    crate::AutumnError::bad_request_msg(message).with_status(http::StatusCode::PAYLOAD_TOO_LARGE)
}
pub use axum::extract::State;
#[cfg(all(test, feature = "multipart"))]
mod tests {
    use super::*;
    use axum::extract::FromRequest;
    use axum::http::Request;

    /// Runs axum's own multipart extractor over `body`, sent as
    /// `multipart/form-data` with the boundary `boundary`.
    async fn parse_multipart(body: &'static str) -> axum::extract::Multipart {
        let req = Request::builder()
            .header("content-type", "multipart/form-data; boundary=boundary")
            .body(axum::body::Body::from(body))
            .unwrap();
        axum::extract::Multipart::from_request(req, &())
            .await
            .unwrap()
    }

    /// Wraps a raw axum field in a `MultipartField` with the given byte limit.
    fn limited<'a>(
        field: axum::extract::multipart::Field<'a>,
        max_file_size_bytes: usize,
    ) -> MultipartField<'a> {
        MultipartField {
            inner: field,
            max_file_size_bytes,
        }
    }

    #[tokio::test]
    async fn test_multipart_field_bytes_limited_success() {
        let mut multipart = parse_multipart(
            "--boundary\r\nContent-Disposition: form-data; name=\"file\"; filename=\"test.txt\"\r\n\r\nhello\r\n--boundary--\r\n",
        )
        .await;
        let field = multipart.next_field().await.unwrap().unwrap();

        let bytes = limited(field, 100).bytes_limited().await.unwrap();

        assert_eq!(bytes, b"hello");
    }

    #[tokio::test]
    async fn test_multipart_field_bytes_limited_too_large() {
        let mut multipart = parse_multipart(
            "--boundary\r\nContent-Disposition: form-data; name=\"file\"; filename=\"test.txt\"\r\n\r\nhello world\r\n--boundary--\r\n",
        )
        .await;
        let field = multipart.next_field().await.unwrap().unwrap();

        let err = limited(field, 5).bytes_limited().await.unwrap_err();

        assert_eq!(err.status(), http::StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[tokio::test]
    async fn test_multipart_field_save_to_success() {
        let mut multipart = parse_multipart(
            "--boundary\r\nContent-Disposition: form-data; name=\"file\"; filename=\"test.txt\"\r\n\r\nfile content\r\n--boundary--\r\n",
        )
        .await;
        let field = multipart.next_field().await.unwrap().unwrap();
        let wrapper = limited(field, 100);

        let dir = tempfile::tempdir().unwrap();
        let file_path = dir.path().join("out.txt");
        let written = wrapper.save_to(&file_path).await.unwrap();

        assert_eq!(written, 12);
        let content = std::fs::read_to_string(&file_path).unwrap();
        assert_eq!(content, "file content");
    }

    #[tokio::test]
    async fn test_multipart_field_save_to_too_large() {
        let mut multipart = parse_multipart(
            "--boundary\r\nContent-Disposition: form-data; name=\"file\"; filename=\"test.txt\"\r\n\r\nfile content\r\n--boundary--\r\n",
        )
        .await;
        let field = multipart.next_field().await.unwrap().unwrap();
        let wrapper = limited(field, 4);

        let dir = tempfile::tempdir().unwrap();
        let file_path = dir.path().join("out_large.txt");
        let err = wrapper.save_to(&file_path).await.unwrap_err();

        assert_eq!(err.status(), http::StatusCode::PAYLOAD_TOO_LARGE);
        // The partially written file must be gone after the failure.
        assert!(!file_path.exists());
    }

    #[cfg(feature = "storage")]
    #[tokio::test]
    async fn test_multipart_field_save_to_blob_store_success() {
        use crate::storage::{BlobStore, LocalBlobStore, local::SigningKey};
        use std::time::Duration;

        let mut multipart = parse_multipart(
            "--boundary\r\nContent-Disposition: form-data; name=\"file\"; filename=\"test.txt\"\r\nContent-Type: text/plain\r\n\r\nblob content\r\n--boundary--\r\n",
        )
        .await;
        let field = multipart.next_field().await.unwrap().unwrap();
        let wrapper = limited(field, 100);

        let root = tempfile::tempdir().unwrap();
        let store = LocalBlobStore::new(
            "local",
            root.path(),
            "/blobs",
            Duration::from_secs(3600),
            SigningKey::random(),
            vec![],
        )
        .unwrap();

        let blob = wrapper.save_to_blob_store(&store, "myblob").await.unwrap();

        assert_eq!(blob.key, "myblob");
        assert_eq!(blob.content_type, "text/plain");
        let bytes = store.get("myblob").await.unwrap();
        assert_eq!(&bytes[..], b"blob content");
    }

    #[cfg(feature = "storage")]
    #[tokio::test]
    async fn test_multipart_field_save_to_blob_store_too_large() {
        use crate::storage::{BlobStore, LocalBlobStore, local::SigningKey};
        use std::time::Duration;

        let mut multipart = parse_multipart(
            "--boundary\r\nContent-Disposition: form-data; name=\"file\"; filename=\"test.txt\"\r\nContent-Type: text/plain\r\n\r\nblob content\r\n--boundary--\r\n",
        )
        .await;
        let field = multipart.next_field().await.unwrap().unwrap();
        let wrapper = limited(field, 4);

        let root = tempfile::tempdir().unwrap();
        let store = LocalBlobStore::new(
            "local",
            root.path(),
            "/blobs",
            Duration::from_secs(3600),
            SigningKey::random(),
            vec![],
        )
        .unwrap();

        let err = wrapper
            .save_to_blob_store(&store, "myblob")
            .await
            .unwrap_err();

        assert_eq!(err.status(), http::StatusCode::PAYLOAD_TOO_LARGE);
        // Nothing must be stored under the key after the failed upload.
        let get_err = store.get("myblob").await.unwrap_err();
        assert_eq!(get_err.status(), http::StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_multipart_field_metadata() {
        let mut multipart = parse_multipart(
            "--boundary\r\nContent-Disposition: form-data; name=\"custom_name\"; filename=\"custom_file.png\"\r\nContent-Type: image/png\r\n\r\npng\r\n--boundary--\r\n",
        )
        .await;
        let field = multipart.next_field().await.unwrap().unwrap();
        let wrapper = limited(field, 100);

        assert_eq!(wrapper.name(), Some("custom_name"));
        assert_eq!(wrapper.file_name(), Some("custom_file.png"));
        assert_eq!(wrapper.content_type(), Some("image/png"));

        let tighter = wrapper.with_max_bytes(50);
        assert_eq!(tighter.max_file_size_bytes, 50);

        // A larger value must not loosen the limit.
        let not_tighter = tighter.with_max_bytes(200);
        assert_eq!(not_tighter.max_file_size_bytes, 50);
    }
}