use axum::extract::FromRequest;
use axum::extract::multipart::{Field, Multipart};
use axum::http::{Method, Request, StatusCode};
use axum::response::{IntoResponse, Response};
use crate::error::{Error, Result};
use crate::middleware::csrf::{CSRF_COOKIE_NAME, constant_time_eq, extract_cookie_value};
use crate::storage::Storage;
/// A file part of a multipart request, held entirely in memory.
///
/// Instances are produced by [`UploadedFile::from_field`], which copies the
/// field's metadata and reads its whole body into `data`.
#[derive(Debug)]
pub struct UploadedFile {
/// File name reported by the multipart field, if it carried one.
filename: Option<String>,
/// Content type reported by the multipart field, if it carried one.
content_type: Option<String>,
/// Raw bytes of the field body.
data: Vec<u8>,
}
impl UploadedFile {
pub async fn from_field(field: Field<'_>) -> Result<Self> {
let filename = field.file_name().map(|s| s.to_owned());
let content_type = field.content_type().map(|s| s.to_owned());
let data = field
.bytes()
.await
.map_err(|e| Error::BadRequest(format!("Failed to read upload: {e}")))?
.to_vec();
Ok(Self {
filename,
content_type,
data,
})
}
pub fn filename(&self) -> Option<&str> {
self.filename.as_deref()
}
pub fn content_type(&self) -> Option<&str> {
self.content_type.as_deref()
}
pub fn size(&self) -> usize {
self.data.len()
}
pub fn data(&self) -> &[u8] {
&self.data
}
pub fn into_bytes(self) -> Vec<u8> {
self.data
}
pub fn validate(self) -> FileValidator {
FileValidator {
file: self,
error: None,
}
}
pub async fn save(self, storage: &Storage, path: &str) -> Result<()> {
storage.put(path, self.data).await?;
Ok(())
}
}
/// Chainable validation state for an [`UploadedFile`].
///
/// Created by [`UploadedFile::validate`]. Each check is skipped once `error`
/// is set, so only the first failure is kept.
pub struct FileValidator {
/// The upload being validated.
file: UploadedFile,
/// Message of the first failed check, or `None` while every check has passed.
error: Option<String>,
}
impl FileValidator {
pub fn max_size(mut self, max: usize) -> Self {
if self.error.is_none() && self.file.size() > max {
let mb = max / (1024 * 1024);
self.error = Some(format!("File too large (max {}MB)", mb.max(1)));
}
self
}
pub fn allowed_types(mut self, types: &[&str]) -> Self {
if self.error.is_none() {
let ct = self
.file
.content_type()
.unwrap_or("application/octet-stream");
if !types.contains(&ct) {
self.error = Some(format!("File type {ct} not allowed"));
}
}
self
}
pub async fn save(self, storage: &Storage, path: &str) -> Result<()> {
if let Some(msg) = self.error {
return Err(Error::BadRequest(msg));
}
storage.put(path, self.file.data).await?;
Ok(())
}
pub fn finish(self) -> Result<UploadedFile> {
if let Some(msg) = self.error {
return Err(Error::BadRequest(msg));
}
Ok(self.file)
}
}
/// Wrapper around axum's [`Multipart`] extractor.
///
/// Its `FromRequest` implementation in this file performs a CSRF check before
/// the form is handed to the caller.
pub struct MultipartForm {
/// The underlying multipart stream that fields are read from.
inner: Multipart,
}
impl MultipartForm {
pub async fn next_field(&mut self) -> Result<Option<Field<'_>>> {
self.inner
.next_field()
.await
.map_err(|e| Error::BadRequest(format!("Multipart error: {e}")))
}
}
impl<S> FromRequest<S> for MultipartForm
where
S: Send + Sync,
{
type Rejection = Response;
async fn from_request(
request: Request<axum::body::Body>,
state: &S,
) -> std::result::Result<Self, Self::Rejection> {
let (parts, body) = request.into_parts();
let needs_csrf = !matches!(parts.method, Method::GET | Method::HEAD | Method::OPTIONS);
let cookie_token = if needs_csrf {
extract_cookie_value(&parts.headers, CSRF_COOKIE_NAME)
} else {
None
};
let header_token = parts
.headers
.get("x-csrf-token")
.and_then(|v| v.to_str().ok())
.map(String::from);
let request = Request::from_parts(parts, body);
let mut multipart = Multipart::from_request(request, state)
.await
.map_err(|e| e.into_response())?;
if needs_csrf {
if let (Some(header), Some(cookie)) = (&header_token, &cookie_token) {
if constant_time_eq(header, cookie) {
return Ok(Self { inner: multipart });
}
}
let field = multipart
.next_field()
.await
.map_err(|_| StatusCode::BAD_REQUEST.into_response())?
.ok_or_else(|| StatusCode::FORBIDDEN.into_response())?;
if field.name() != Some("_csrf") {
return Err(StatusCode::FORBIDDEN.into_response());
}
let token = field
.text()
.await
.map_err(|_| StatusCode::BAD_REQUEST.into_response())?;
match &cookie_token {
Some(cookie) if constant_time_eq(&token, cookie) => {}
_ => return Err(StatusCode::FORBIDDEN.into_response()),
}
}
Ok(Self { inner: multipart })
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One mebibyte, the size limit used by the rejection tests.
    const MIB: usize = 1024 * 1024;

    /// Builds an unnamed upload with the given content type and payload.
    fn upload(content_type: Option<&str>, data: Vec<u8>) -> UploadedFile {
        UploadedFile {
            filename: None,
            content_type: content_type.map(String::from),
            data,
        }
    }

    #[test]
    fn uploaded_file_accessors() {
        let file = UploadedFile {
            filename: Some(String::from("avatar.png")),
            content_type: Some(String::from("image/png")),
            data: vec![1, 2, 3, 4],
        };

        assert_eq!(file.filename(), Some("avatar.png"));
        assert_eq!(file.content_type(), Some("image/png"));
        assert_eq!(file.size(), 4);
        assert_eq!(file.data(), &[1, 2, 3, 4]);
    }

    #[test]
    fn uploaded_file_into_bytes() {
        let file = upload(None, vec![10, 20, 30]);

        assert_eq!(file.into_bytes(), vec![10, 20, 30]);
    }

    #[test]
    fn validator_max_size_passes() {
        let outcome = upload(None, vec![0; 100])
            .validate()
            .max_size(1024)
            .finish();

        assert!(outcome.is_ok());
    }

    #[test]
    fn validator_max_size_rejects() {
        let outcome = upload(None, vec![0; 2 * MIB])
            .validate()
            .max_size(MIB)
            .finish();

        assert!(outcome.is_err());
        let err = outcome.unwrap_err().to_string();
        assert!(err.contains("too large"), "error: {err}");
    }

    #[test]
    fn validator_allowed_types_passes() {
        let outcome = upload(Some("image/png"), vec![0])
            .validate()
            .allowed_types(&["image/png", "image/jpeg"])
            .finish();

        assert!(outcome.is_ok());
    }

    #[test]
    fn validator_allowed_types_rejects() {
        let outcome = upload(Some("image/gif"), vec![0])
            .validate()
            .allowed_types(&["image/png", "image/jpeg"])
            .finish();

        assert!(outcome.is_err());
        let err = outcome.unwrap_err().to_string();
        assert!(err.contains("image/gif"), "error: {err}");
    }

    #[test]
    fn validator_chain_stops_at_first_error() {
        // Both checks would fail; only the first one's message may surface.
        let outcome = upload(Some("text/plain"), vec![0; 2 * MIB])
            .validate()
            .max_size(MIB)
            .allowed_types(&["image/png"])
            .finish();

        assert!(outcome.is_err());
        let err = outcome.unwrap_err().to_string();
        assert!(err.contains("too large"), "error: {err}");
    }

    #[test]
    fn validator_no_content_type_treated_as_octet_stream() {
        let outcome = upload(None, vec![0])
            .validate()
            .allowed_types(&["image/png"])
            .finish();

        assert!(outcome.is_err());
        let err = outcome.unwrap_err().to_string();
        assert!(err.contains("application/octet-stream"), "error: {err}");
    }
}