1use axum::extract::Multipart;
2use std::path::{Path, PathBuf};
3use uuid::Uuid;
4
5use crate::error::{AppError, AppResult};
6
7#[derive(Debug, Clone, serde::Serialize)]
9pub struct UploadedFile {
10 pub filename: String,
11 pub original_name: String,
12 pub filepath: String,
13 pub mime_type: String,
14 pub size: u64,
15}
16
17pub struct UploadConfig {
19 pub upload_dir: PathBuf,
20 pub max_file_size: u64,
21 pub max_files: usize,
22 pub allowed_mimes: AllowedMimes,
23}
24
25impl UploadConfig {
26 pub fn new(upload_dir: impl Into<PathBuf>) -> Self {
27 Self {
28 upload_dir: upload_dir.into(),
29 max_file_size: 10 * 1024 * 1024, max_files: 10,
31 allowed_mimes: AllowedMimes::images(),
32 }
33 }
34
35 pub fn max_file_size(mut self, size: u64) -> Self {
36 self.max_file_size = size;
37 self
38 }
39
40 pub fn max_files(mut self, count: usize) -> Self {
41 self.max_files = count;
42 self
43 }
44
45 pub fn allowed_mimes(mut self, mimes: AllowedMimes) -> Self {
46 self.allowed_mimes = mimes;
47 self
48 }
49}
50
/// Allow-list of content types an upload may have.
///
/// During content sniffing the entries are tried in order, and the first entry whose
/// signatures all match the file wins.
#[derive(Debug, Clone)]
pub struct AllowedMimes(pub Vec<MimeEntry>);

/// One allowed content type.
#[derive(Debug, Clone, Copy)]
pub struct MimeEntry {
    /// MIME type, e.g. `"image/png"`.
    pub mime: &'static str,
    /// Accepted file extensions (lowercase, without the dot). The first one is the
    /// canonical extension.
    pub extensions: &'static [&'static str],
    /// Signatures that must *all* be present for content to be recognised as this type.
    /// An empty slice means the type cannot be recognised from its bytes.
    pub magic_bytes: &'static [MagicSignature],
}

/// A byte sequence expected at a fixed offset from the start of a file.
#[derive(Debug, Clone, Copy)]
pub struct MagicSignature {
    /// Byte offset from the start of the file.
    pub offset: usize,
    /// Bytes that must appear at `offset`.
    pub bytes: &'static [u8],
}
66
67impl AllowedMimes {
68 pub fn images() -> Self {
69 Self(vec![
70 MimeEntry {
71 mime: "image/jpeg",
72 extensions: &["jpg", "jpeg"],
73 magic_bytes: &[MagicSignature {
74 offset: 0,
75 bytes: &[0xFF, 0xD8, 0xFF],
76 }],
77 },
78 MimeEntry {
79 mime: "image/png",
80 extensions: &["png"],
81 magic_bytes: &[MagicSignature {
82 offset: 0,
83 bytes: &[0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A],
84 }],
85 },
86 MimeEntry {
87 mime: "image/webp",
88 extensions: &["webp"],
89 magic_bytes: &[
90 MagicSignature {
91 offset: 0,
92 bytes: b"RIFF",
93 },
94 MagicSignature {
95 offset: 8,
96 bytes: b"WEBP",
97 },
98 ],
99 },
100 MimeEntry {
101 mime: "image/gif",
102 extensions: &["gif"],
103 magic_bytes: &[MagicSignature {
104 offset: 0,
105 bytes: b"GIF8",
106 }],
107 },
108 ])
109 }
110
111 pub fn images_with_svg() -> Self {
112 let mut mimes = Self::images();
113 mimes.0.push(MimeEntry {
114 mime: "image/svg+xml",
115 extensions: &["svg"],
116 magic_bytes: &[], });
118 mimes
119 }
120
121 pub fn documents() -> Self {
122 Self(vec![
123 MimeEntry {
124 mime: "application/pdf",
125 extensions: &["pdf"],
126 magic_bytes: &[MagicSignature {
127 offset: 0,
128 bytes: b"%PDF",
129 }],
130 },
131 MimeEntry {
132 mime: "text/plain",
133 extensions: &["txt"],
134 magic_bytes: &[],
135 },
136 ])
137 }
138
139 pub fn all_media() -> Self {
140 let mut mimes = Self::images();
141 mimes.0.extend(vec![
142 MimeEntry {
143 mime: "video/mp4",
144 extensions: &["mp4"],
145 magic_bytes: &[MagicSignature {
146 offset: 4,
147 bytes: b"ftyp",
148 }],
149 },
150 MimeEntry {
151 mime: "video/webm",
152 extensions: &["webm"],
153 magic_bytes: &[MagicSignature {
154 offset: 0,
155 bytes: &[0x1A, 0x45, 0xDF, 0xA3],
156 }],
157 },
158 MimeEntry {
159 mime: "audio/mpeg",
160 extensions: &["mp3"],
161 magic_bytes: &[
162 MagicSignature {
163 offset: 0,
164 bytes: &[0xFF, 0xFB],
165 },
166 MagicSignature {
167 offset: 0,
168 bytes: b"ID3",
169 },
170 ],
171 },
172 ]);
173 mimes
174 }
175
176 fn find_entry(&self, mime: &str) -> Option<&MimeEntry> {
178 self.0.iter().find(|e| e.mime == mime)
179 }
180
181 fn detect_mime_from_bytes(&self, data: &[u8]) -> Option<&str> {
183 for entry in &self.0 {
184 if entry.magic_bytes.is_empty() {
185 continue;
186 }
187 let all_match = entry.magic_bytes.iter().all(|sig| {
188 if sig.offset + sig.bytes.len() > data.len() {
189 return false;
190 }
191 &data[sig.offset..sig.offset + sig.bytes.len()] == sig.bytes
192 });
193 if all_match {
194 return Some(entry.mime);
195 }
196 }
197 None
198 }
199}
200
/// Normalises a client-supplied file extension by keeping only ASCII letters and digits,
/// lowercased. For example, `"J.P/G"` becomes `"jpg"`.
fn sanitize_extension(ext: &str) -> String {
    let mut cleaned = String::with_capacity(ext.len());
    for ch in ext.chars() {
        if ch.is_ascii_alphanumeric() {
            cleaned.push(ch.to_ascii_lowercase());
        }
    }
    cleaned
}
208
209fn validate_svg_content(data: &[u8]) -> AppResult<()> {
211 let content = std::str::from_utf8(data)
212 .map_err(|_| AppError::BadRequest("Invalid SVG: not valid UTF-8".to_string()))?;
213
214 let lower = content.to_lowercase();
215
216 let dangerous = [
217 "<script",
218 "javascript:",
219 "onerror",
220 "onload",
221 "onclick",
222 "onmouseover",
223 "onfocus",
224 "onblur",
225 "eval(",
226 "expression(",
227 "url(data:",
228 "<!entity",
229 "<!doctype",
230 "xlink:href=\"data:",
231 "xlink:href=\"javascript:",
232 ];
233
234 for pattern in &dangerous {
235 if lower.contains(pattern) {
236 return Err(AppError::BadRequest(format!(
237 "SVG contains forbidden content: {}",
238 pattern
239 )));
240 }
241 }
242
243 Ok(())
244}
245
246fn safe_resolve(base: &Path, filename: &str) -> AppResult<PathBuf> {
248 let resolved = base.join(filename);
249 let canonical_base = base
250 .canonicalize()
251 .unwrap_or_else(|_| base.to_path_buf());
252 let canonical_resolved = resolved
253 .parent()
254 .and_then(|p| p.canonicalize().ok())
255 .unwrap_or_else(|| canonical_base.clone());
256
257 if !canonical_resolved.starts_with(&canonical_base) {
258 return Err(AppError::BadRequest(
259 "Invalid file path: path traversal detected".to_string(),
260 ));
261 }
262
263 Ok(resolved)
264}
265
266pub async fn handle_upload(
268 multipart: Multipart,
269 config: &UploadConfig,
270) -> AppResult<Vec<UploadedFile>> {
271 handle_upload_inner(multipart, config).await
272}
273
274async fn handle_upload_inner(
275 mut multipart: Multipart,
276 config: &UploadConfig,
277) -> AppResult<Vec<UploadedFile>> {
278 let mut files = Vec::new();
279
280 while let Some(field) = multipart
281 .next_field()
282 .await
283 .map_err(|e| AppError::BadRequest(format!("Multipart error: {}", e)))?
284 {
285 if files.len() >= config.max_files {
287 return Err(AppError::BadRequest(format!(
288 "Too many files. Maximum allowed: {}",
289 config.max_files
290 )));
291 }
292
293 let original_name = field
294 .file_name()
295 .map(|s| s.to_string())
296 .unwrap_or_else(|| "unknown".to_string());
297
298 let claimed_content_type = field
299 .content_type()
300 .map(|s| s.to_string())
301 .unwrap_or_else(|| "application/octet-stream".to_string());
302
303 let data = field
305 .bytes()
306 .await
307 .map_err(|e| AppError::BadRequest(format!("Failed to read file: {}", e)))?;
308
309 if data.len() as u64 > config.max_file_size {
311 return Err(AppError::BadRequest(format!(
312 "File '{}' exceeds maximum size of {} bytes",
313 original_name, config.max_file_size
314 )));
315 }
316
317 if data.is_empty() {
319 return Err(AppError::BadRequest("Empty file uploaded".to_string()));
320 }
321
322 let detected_mime = config.allowed_mimes.detect_mime_from_bytes(&data);
324
325 let actual_mime = if let Some(detected) = detected_mime {
328 detected.to_string()
329 } else if config.allowed_mimes.find_entry(&claimed_content_type).is_some() {
330 let entry = config.allowed_mimes.find_entry(&claimed_content_type).unwrap();
332 if entry.magic_bytes.is_empty() {
333 claimed_content_type.clone()
334 } else {
335 return Err(AppError::BadRequest(format!(
336 "File '{}': content does not match declared type '{}'",
337 original_name, claimed_content_type
338 )));
339 }
340 } else {
341 return Err(AppError::BadRequest(format!(
342 "File type '{}' is not allowed",
343 claimed_content_type
344 )));
345 };
346
347 let mime_entry = config
349 .allowed_mimes
350 .find_entry(&actual_mime)
351 .ok_or_else(|| {
352 AppError::BadRequest(format!("File type '{}' is not allowed", actual_mime))
353 })?;
354
355 let original_ext = Path::new(&original_name)
357 .extension()
358 .and_then(|e| e.to_str())
359 .map(|e| sanitize_extension(e))
360 .unwrap_or_default();
361
362 let safe_extension = if mime_entry.extensions.contains(&original_ext.as_str()) {
363 original_ext
364 } else {
365 mime_entry.extensions[0].to_string()
366 };
367
368 if actual_mime == "image/svg+xml" {
370 validate_svg_content(&data)?;
371 }
372
373 let filename = format!("{}.{}", Uuid::new_v4(), safe_extension);
375
376 tokio::fs::create_dir_all(&config.upload_dir)
378 .await
379 .map_err(|e| AppError::Internal(format!("Failed to create upload dir: {}", e)))?;
380
381 let filepath = safe_resolve(&config.upload_dir, &filename)?;
383
384 tokio::fs::write(&filepath, &data)
386 .await
387 .map_err(|e| AppError::Internal(format!("Failed to write file: {}", e)))?;
388
389 files.push(UploadedFile {
390 filename,
391 original_name,
392 filepath: config.upload_dir.to_string_lossy().to_string(),
393 mime_type: actual_mime,
394 size: data.len() as u64,
395 });
396 }
397
398 Ok(files)
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sanitize_extension() {
        assert_eq!(sanitize_extension("jpg"), "jpg");
        assert_eq!(sanitize_extension("JPG"), "jpg");
        assert_eq!(sanitize_extension("ph.p"), "php");
        assert_eq!(sanitize_extension("j/p\\g"), "jpg");
        assert_eq!(sanitize_extension(""), "");
    }

    #[test]
    fn test_detect_jpeg() {
        let allowed = AllowedMimes::images();
        let jpeg_header = [0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10];
        assert_eq!(allowed.detect_mime_from_bytes(&jpeg_header), Some("image/jpeg"));
    }

    #[test]
    fn test_detect_png() {
        let allowed = AllowedMimes::images();
        let png_header = [0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A];
        assert_eq!(allowed.detect_mime_from_bytes(&png_header), Some("image/png"));
    }

    #[test]
    fn test_detect_gif() {
        let allowed = AllowedMimes::images();
        assert_eq!(
            allowed.detect_mime_from_bytes(b"GIF89a"),
            Some("image/gif")
        );
    }

    #[test]
    fn test_detect_webp_requires_both_signatures() {
        let allowed = AllowedMimes::images();
        assert_eq!(
            allowed.detect_mime_from_bytes(b"RIFF\x24\x00\x00\x00WEBPVP8 "),
            Some("image/webp")
        );
        // A RIFF container that is not WebP (here a WAV file) must not be accepted.
        assert_eq!(
            allowed.detect_mime_from_bytes(b"RIFF\x24\x00\x00\x00WAVEfmt "),
            None
        );
    }

    #[test]
    fn test_detect_pdf() {
        let allowed = AllowedMimes::documents();
        assert_eq!(
            allowed.detect_mime_from_bytes(b"%PDF-1.7\n"),
            Some("application/pdf")
        );
        // text/plain has no signature, so plain text is never sniffed.
        assert_eq!(allowed.detect_mime_from_bytes(b"hello"), None);
    }

    #[test]
    fn test_detect_mp4() {
        let allowed = AllowedMimes::all_media();
        assert_eq!(
            allowed.detect_mime_from_bytes(b"\x00\x00\x00\x18ftypmp42"),
            Some("video/mp4")
        );
    }

    #[test]
    fn test_detect_truncated_header() {
        let allowed = AllowedMimes::images();
        assert_eq!(allowed.detect_mime_from_bytes(&[0x89, 0x50]), None);
        assert_eq!(allowed.detect_mime_from_bytes(&[]), None);
    }

    #[test]
    fn test_detect_unknown() {
        let allowed = AllowedMimes::images();
        assert_eq!(allowed.detect_mime_from_bytes(b"random data"), None);
    }

    #[test]
    fn test_svg_validation_safe() {
        let svg = b"<svg xmlns=\"http://www.w3.org/2000/svg\"><circle r=\"10\"/></svg>";
        assert!(validate_svg_content(svg).is_ok());
    }

    #[test]
    fn test_svg_validation_xss_script() {
        let svg = b"<svg><script>alert('xss')</script></svg>";
        assert!(validate_svg_content(svg).is_err());
    }

    #[test]
    fn test_svg_validation_xss_event() {
        let svg = b"<svg><rect onload=\"alert('xss')\"/></svg>";
        assert!(validate_svg_content(svg).is_err());
    }

    #[test]
    fn test_svg_validation_xss_javascript_uri() {
        let svg = b"<svg><a xlink:href=\"javascript:alert(1)\">click</a></svg>";
        assert!(validate_svg_content(svg).is_err());
    }

    #[test]
    fn test_svg_validation_rejects_doctype() {
        let svg = b"<!DOCTYPE svg [<!ENTITY x \"y\">]><svg>&x;</svg>";
        assert!(validate_svg_content(svg).is_err());
    }

    #[test]
    fn test_svg_validation_rejects_invalid_utf8() {
        assert!(validate_svg_content(&[0x3C, 0xFF, 0xFE]).is_err());
    }

    #[test]
    fn test_extension_whitelist() {
        let allowed = AllowedMimes::images();
        let entry = allowed.find_entry("image/jpeg").unwrap();
        assert!(entry.extensions.contains(&"jpg"));
        assert!(entry.extensions.contains(&"jpeg"));
        assert!(allowed.find_entry("application/x-msdownload").is_none());
    }

    #[test]
    fn test_safe_resolve_normal() {
        let base = Path::new("/uploads");
        let result = safe_resolve(base, "file.jpg");
        assert_eq!(result.ok(), Some(PathBuf::from("/uploads/file.jpg")));
    }

    #[cfg(unix)]
    #[test]
    fn test_safe_resolve_rejects_absolute_path() {
        assert!(safe_resolve(Path::new("/uploads"), "/etc/passwd").is_err());
    }
}