1use crate::storage::UploadedFile;
69use axum::{
70 extract::{multipart::Field, FromRequest, Multipart, Request},
71 http::StatusCode,
72 response::{IntoResponse, Response},
73};
74use std::fmt;
75
/// Per-file size limit, in bytes (10 MiB), applied by the upload extractors in this module.
pub const DEFAULT_MAX_FILE_SIZE: usize = 10 * 1024 * 1024;

/// Maximum number of file parts [`MultiFileUpload`] accepts in a single request.
pub const DEFAULT_MAX_FILES: usize = 10;
81
/// Errors produced while extracting uploaded files from a multipart request.
///
/// The type implements [`IntoResponse`], so it can be used directly as an
/// extractor rejection.
#[derive(Debug)]
pub enum FileUploadError {
    /// The request contained no part that carries a filename.
    MissingFile,

    /// More than one file part was present where exactly one was expected.
    MultipleFiles,

    /// The multipart body could not be parsed or read; holds the underlying
    /// error rendered as text.
    MultipartError(String),

    /// A file part exceeded the allowed size.
    FileTooLarge {
        /// Size observed for the offending file, in bytes.
        actual: usize,
        /// The limit that was exceeded, in bytes.
        max: usize,
    },

    /// The request carried more file parts than allowed.
    TooManyFiles {
        /// Number of file parts seen when the limit was hit.
        actual: usize,
        /// Maximum number of file parts allowed.
        max: usize,
    },

    /// A required piece of part metadata was absent; holds its name.
    MissingField(String),
}
113
114impl fmt::Display for FileUploadError {
115 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116 match self {
117 Self::MissingFile => write!(f, "No file found in upload"),
118 Self::MultipleFiles => write!(f, "Multiple files found, expected single file"),
119 Self::MultipartError(msg) => write!(f, "Multipart error: {msg}"),
120 Self::FileTooLarge { actual, max } => {
121 write!(f, "File size {actual} bytes exceeds maximum of {max} bytes")
122 }
123 Self::TooManyFiles { actual, max } => {
124 write!(f, "Upload contains {actual} files, maximum is {max}")
125 }
126 Self::MissingField(field) => write!(f, "Missing required field: {field}"),
127 }
128 }
129}
130
// No methods are overridden: the trait's defaults are used, on top of the
// `Debug` and `Display` implementations defined for this type.
impl std::error::Error for FileUploadError {}
132
133impl IntoResponse for FileUploadError {
134 fn into_response(self) -> Response {
135 let status = match self {
136 Self::FileTooLarge { .. } => StatusCode::PAYLOAD_TOO_LARGE,
137 Self::MissingFile | Self::MissingField(_) | Self::MultipleFiles | Self::TooManyFiles { .. } | Self::MultipartError(_) => {
138 StatusCode::BAD_REQUEST
139 }
140 };
141
142 (status, self.to_string()).into_response()
143 }
144}
145
/// Extractor that yields exactly one uploaded file from a multipart request.
///
/// Parts without a filename are ignored. Extraction is rejected with
/// [`FileUploadError::MissingFile`] when no file part is present and with
/// [`FileUploadError::MultipleFiles`] when more than one is.
#[derive(Debug)]
pub struct FileUpload(pub UploadedFile);
165
166impl<S> FromRequest<S> for FileUpload
167where
168 S: Send + Sync,
169{
170 type Rejection = FileUploadError;
171
172 #[allow(clippy::manual_async_fn)]
173 fn from_request(
174 req: Request,
175 state: &S,
176 ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
177 async move {
178 let mut multipart = Multipart::from_request(req, state)
179 .await
180 .map_err(|e| FileUploadError::MultipartError(e.to_string()))?;
181
182 let mut files = Vec::new();
183
184 while let Some(field) = multipart
186 .next_field()
187 .await
188 .map_err(|e| FileUploadError::MultipartError(e.to_string()))?
189 {
190 if field.file_name().is_none() {
192 continue;
193 }
194
195 let filename = field
196 .file_name()
197 .ok_or_else(|| FileUploadError::MissingField("filename".to_string()))?
198 .to_string();
199
200 let content_type = field
201 .content_type()
202 .unwrap_or("application/octet-stream")
203 .to_string();
204
205 let data = read_field_data(field, DEFAULT_MAX_FILE_SIZE).await?;
207
208 files.push(UploadedFile {
209 filename,
210 content_type,
211 data,
212 });
213 }
214
215 match files.len() {
217 0 => Err(FileUploadError::MissingFile),
218 1 => Ok(Self(files.into_iter().next().unwrap())),
219 _ => Err(FileUploadError::MultipleFiles),
220 }
221 }
222 }
223}
224
/// Extractor that collects every file part of a multipart request, in the
/// order the parts appear.
///
/// Parts without a filename are ignored. Extraction is rejected with
/// [`FileUploadError::MissingFile`] when no file part is present and with
/// [`FileUploadError::TooManyFiles`] when more than [`DEFAULT_MAX_FILES`]
/// file parts are sent.
#[derive(Debug)]
pub struct MultiFileUpload(pub Vec<UploadedFile>);
244
245impl<S> FromRequest<S> for MultiFileUpload
246where
247 S: Send + Sync,
248{
249 type Rejection = FileUploadError;
250
251 #[allow(clippy::manual_async_fn)]
252 fn from_request(
253 req: Request,
254 state: &S,
255 ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
256 async move {
257 let mut multipart = Multipart::from_request(req, state)
258 .await
259 .map_err(|e| FileUploadError::MultipartError(e.to_string()))?;
260
261 let mut files = Vec::new();
262
263 while let Some(field) = multipart
264 .next_field()
265 .await
266 .map_err(|e| FileUploadError::MultipartError(e.to_string()))?
267 {
268 if field.file_name().is_none() {
270 continue;
271 }
272
273 if files.len() >= DEFAULT_MAX_FILES {
275 return Err(FileUploadError::TooManyFiles {
276 actual: files.len() + 1,
277 max: DEFAULT_MAX_FILES,
278 });
279 }
280
281 let filename = field
282 .file_name()
283 .ok_or_else(|| FileUploadError::MissingField("filename".to_string()))?
284 .to_string();
285
286 let content_type = field
287 .content_type()
288 .unwrap_or("application/octet-stream")
289 .to_string();
290
291 let data = read_field_data(field, DEFAULT_MAX_FILE_SIZE).await?;
293
294 files.push(UploadedFile {
295 filename,
296 content_type,
297 data,
298 });
299 }
300
301 if files.is_empty() {
302 return Err(FileUploadError::MissingFile);
303 }
304
305 Ok(Self(files))
306 }
307 }
308}
309
310async fn read_field_data(
315 field: Field<'_>,
316 max_size: usize,
317) -> Result<Vec<u8>, FileUploadError> {
318 let data = field
319 .bytes()
320 .await
321 .map_err(|e| FileUploadError::MultipartError(e.to_string()))?;
322
323 if data.len() > max_size {
325 return Err(FileUploadError::FileTooLarge {
326 actual: data.len(),
327 max: max_size,
328 });
329 }
330
331 Ok(data.to_vec())
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{header, Request};

    /// Multipart boundary shared by every request built in these tests.
    const BOUNDARY: &str = "----WebKitFormBoundary7MA4YWxkTrZu0gW";

    /// Builds a `multipart/form-data` POST request with one file part per
    /// `(field name, filename, content)` tuple.
    ///
    /// The body is assembled as raw bytes so that file contents which are not
    /// valid UTF-8 reach the extractor unchanged.
    fn create_multipart_request(files: Vec<(&str, &str, &[u8])>) -> Request<Body> {
        use std::io::Write;

        let mut body: Vec<u8> = Vec::new();

        for (name, filename, content) in files {
            write!(body, "--{BOUNDARY}\r\n").expect("writing to a Vec cannot fail");
            write!(
                body,
                "Content-Disposition: form-data; name=\"{name}\"; filename=\"{filename}\"\r\n"
            )
            .expect("writing to a Vec cannot fail");
            body.extend_from_slice(b"Content-Type: application/octet-stream\r\n\r\n");
            body.extend_from_slice(content);
            body.extend_from_slice(b"\r\n");
        }

        // Closing delimiter.
        write!(body, "--{BOUNDARY}--\r\n").expect("writing to a Vec cannot fail");

        Request::builder()
            .method("POST")
            .header(
                header::CONTENT_TYPE,
                format!("multipart/form-data; boundary={BOUNDARY}"),
            )
            .body(Body::from(body))
            .unwrap()
    }

    /// A request with one file part is accepted by `FileUpload`.
    #[tokio::test]
    async fn test_single_file_upload() {
        let req = create_multipart_request(vec![("file", "test.txt", b"Hello, World!")]);

        let result = FileUpload::from_request(req, &()).await;
        assert!(result.is_ok());

        let FileUpload(file) = result.unwrap();
        assert_eq!(file.filename, "test.txt");
        assert_eq!(file.data, b"Hello, World!");
    }

    /// File contents that are not valid UTF-8 arrive byte-for-byte.
    #[tokio::test]
    async fn test_binary_content_round_trips() {
        let payload: &[u8] = &[0x00, 0xFF, 0xFE, 0x80, 0x7F];
        let req = create_multipart_request(vec![("file", "blob.bin", payload)]);

        let result = FileUpload::from_request(req, &()).await;
        assert!(result.is_ok());

        let FileUpload(file) = result.unwrap();
        assert_eq!(file.filename, "blob.bin");
        assert_eq!(file.data, payload);
    }

    /// `FileUpload` refuses a request that carries two file parts.
    #[tokio::test]
    async fn test_multiple_files_rejected_by_single_upload() {
        let req = create_multipart_request(vec![
            ("file1", "test1.txt", b"File 1"),
            ("file2", "test2.txt", b"File 2"),
        ]);

        let result = FileUpload::from_request(req, &()).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), FileUploadError::MultipleFiles));
    }

    /// `MultiFileUpload` returns every file part in request order.
    #[tokio::test]
    async fn test_multi_file_upload() {
        let req = create_multipart_request(vec![
            ("file1", "test1.txt", b"File 1"),
            ("file2", "test2.txt", b"File 2"),
        ]);

        let result = MultiFileUpload::from_request(req, &()).await;
        assert!(result.is_ok());

        let MultiFileUpload(files) = result.unwrap();
        assert_eq!(files.len(), 2);
        assert_eq!(files[0].filename, "test1.txt");
        assert_eq!(files[1].filename, "test2.txt");
    }

    /// A multipart body with no parts at all is rejected as `MissingFile`.
    #[tokio::test]
    async fn test_missing_file() {
        // An empty file list yields a body holding only the closing delimiter.
        let req = create_multipart_request(Vec::new());

        let result = FileUpload::from_request(req, &()).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), FileUploadError::MissingFile));
    }
}