Skip to main content

karbon_framework/storage/
upload.rs

1use axum::extract::Multipart;
2use std::path::{Path, PathBuf};
3use uuid::Uuid;
4
5use crate::error::{AppError, AppResult};
6
7/// Uploaded file info returned after processing
8#[derive(Debug, Clone, serde::Serialize)]
9pub struct UploadedFile {
10    pub filename: String,
11    pub original_name: String,
12    pub filepath: String,
13    pub mime_type: String,
14    pub size: u64,
15}
16
17/// Upload configuration
18pub struct UploadConfig {
19    pub upload_dir: PathBuf,
20    pub max_file_size: u64,
21    pub max_files: usize,
22    pub allowed_mimes: AllowedMimes,
23}
24
25impl UploadConfig {
26    pub fn new(upload_dir: impl Into<PathBuf>) -> Self {
27        Self {
28            upload_dir: upload_dir.into(),
29            max_file_size: 10 * 1024 * 1024, // 10 Mo
30            max_files: 10,
31            allowed_mimes: AllowedMimes::images(),
32        }
33    }
34
35    pub fn max_file_size(mut self, size: u64) -> Self {
36        self.max_file_size = size;
37        self
38    }
39
40    pub fn max_files(mut self, count: usize) -> Self {
41        self.max_files = count;
42        self
43    }
44
45    pub fn allowed_mimes(mut self, mimes: AllowedMimes) -> Self {
46        self.allowed_mimes = mimes;
47        self
48    }
49}
50
/// Whitelist of accepted upload content types. Each entry pairs a MIME type
/// with the extensions a stored file may carry and the byte signatures used
/// to recognise its content.
pub struct AllowedMimes(pub Vec<MimeEntry>);

/// One accepted content type.
pub struct MimeEntry {
    /// MIME type string, e.g. `"image/png"`.
    pub mime: &'static str,
    /// Extensions (without the dot) a stored file may keep. They are compared
    /// against a sanitized, lowercased client extension, so they should be
    /// lowercase ASCII. The first one is used when the client's extension is
    /// not listed.
    pub extensions: &'static [&'static str],
    /// Signatures that must ALL be present for content to be recognised as
    /// this type. An empty slice means the type cannot be recognised from its
    /// bytes and is accepted based on the client-declared content type.
    pub magic_bytes: &'static [MagicSignature],
}

/// A run of bytes expected at a fixed position in a file's content.
pub struct MagicSignature {
    /// Position of the first expected byte, counted from the start of the content.
    pub offset: usize,
    /// Exact bytes expected starting at `offset`.
    pub bytes: &'static [u8],
}
66
67impl AllowedMimes {
68    pub fn images() -> Self {
69        Self(vec![
70            MimeEntry {
71                mime: "image/jpeg",
72                extensions: &["jpg", "jpeg"],
73                magic_bytes: &[MagicSignature {
74                    offset: 0,
75                    bytes: &[0xFF, 0xD8, 0xFF],
76                }],
77            },
78            MimeEntry {
79                mime: "image/png",
80                extensions: &["png"],
81                magic_bytes: &[MagicSignature {
82                    offset: 0,
83                    bytes: &[0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A],
84                }],
85            },
86            MimeEntry {
87                mime: "image/webp",
88                extensions: &["webp"],
89                magic_bytes: &[
90                    MagicSignature {
91                        offset: 0,
92                        bytes: b"RIFF",
93                    },
94                    MagicSignature {
95                        offset: 8,
96                        bytes: b"WEBP",
97                    },
98                ],
99            },
100            MimeEntry {
101                mime: "image/gif",
102                extensions: &["gif"],
103                magic_bytes: &[MagicSignature {
104                    offset: 0,
105                    bytes: b"GIF8",
106                }],
107            },
108        ])
109    }
110
111    pub fn images_with_svg() -> Self {
112        let mut mimes = Self::images();
113        mimes.0.push(MimeEntry {
114            mime: "image/svg+xml",
115            extensions: &["svg"],
116            magic_bytes: &[], // SVG validated separately via content check
117        });
118        mimes
119    }
120
121    pub fn documents() -> Self {
122        Self(vec![
123            MimeEntry {
124                mime: "application/pdf",
125                extensions: &["pdf"],
126                magic_bytes: &[MagicSignature {
127                    offset: 0,
128                    bytes: b"%PDF",
129                }],
130            },
131            MimeEntry {
132                mime: "text/plain",
133                extensions: &["txt"],
134                magic_bytes: &[],
135            },
136        ])
137    }
138
139    pub fn all_media() -> Self {
140        let mut mimes = Self::images();
141        mimes.0.extend(vec![
142            MimeEntry {
143                mime: "video/mp4",
144                extensions: &["mp4"],
145                magic_bytes: &[MagicSignature {
146                    offset: 4,
147                    bytes: b"ftyp",
148                }],
149            },
150            MimeEntry {
151                mime: "video/webm",
152                extensions: &["webm"],
153                magic_bytes: &[MagicSignature {
154                    offset: 0,
155                    bytes: &[0x1A, 0x45, 0xDF, 0xA3],
156                }],
157            },
158            MimeEntry {
159                mime: "audio/mpeg",
160                extensions: &["mp3"],
161                magic_bytes: &[
162                    MagicSignature {
163                        offset: 0,
164                        bytes: &[0xFF, 0xFB],
165                    },
166                    MagicSignature {
167                        offset: 0,
168                        bytes: b"ID3",
169                    },
170                ],
171            },
172        ]);
173        mimes
174    }
175
176    /// Find a MIME entry matching the given content type
177    fn find_entry(&self, mime: &str) -> Option<&MimeEntry> {
178        self.0.iter().find(|e| e.mime == mime)
179    }
180
181    /// Detect the real MIME type by inspecting magic bytes
182    fn detect_mime_from_bytes(&self, data: &[u8]) -> Option<&str> {
183        for entry in &self.0 {
184            if entry.magic_bytes.is_empty() {
185                continue;
186            }
187            let all_match = entry.magic_bytes.iter().all(|sig| {
188                if sig.offset + sig.bytes.len() > data.len() {
189                    return false;
190                }
191                &data[sig.offset..sig.offset + sig.bytes.len()] == sig.bytes
192            });
193            if all_match {
194                return Some(entry.mime);
195            }
196        }
197        None
198    }
199}
200
/// Reduces a client-supplied extension to lowercase ASCII letters and digits.
///
/// Every other character (dots, slashes, backslashes, non-ASCII, …) is
/// dropped, so the result cannot contain a path separator or a second extension.
fn sanitize_extension(ext: &str) -> String {
    let mut cleaned = String::with_capacity(ext.len());
    for ch in ext.chars() {
        if ch.is_ascii_alphanumeric() {
            cleaned.push(ch.to_ascii_lowercase());
        }
    }
    cleaned
}
208
209/// Validate that SVG content doesn't contain dangerous elements
210fn validate_svg_content(data: &[u8]) -> AppResult<()> {
211    let content = std::str::from_utf8(data)
212        .map_err(|_| AppError::BadRequest("Invalid SVG: not valid UTF-8".to_string()))?;
213
214    let lower = content.to_lowercase();
215
216    let dangerous = [
217        "<script",
218        "javascript:",
219        "onerror",
220        "onload",
221        "onclick",
222        "onmouseover",
223        "onfocus",
224        "onblur",
225        "eval(",
226        "expression(",
227        "url(data:",
228        "<!entity",
229        "<!doctype",
230        "xlink:href=\"data:",
231        "xlink:href=\"javascript:",
232    ];
233
234    for pattern in &dangerous {
235        if lower.contains(pattern) {
236            return Err(AppError::BadRequest(format!(
237                "SVG contains forbidden content: {}",
238                pattern
239            )));
240        }
241    }
242
243    Ok(())
244}
245
246/// Safely resolve a path within a base directory, preventing path traversal
247fn safe_resolve(base: &Path, filename: &str) -> AppResult<PathBuf> {
248    let resolved = base.join(filename);
249    let canonical_base = base
250        .canonicalize()
251        .unwrap_or_else(|_| base.to_path_buf());
252    let canonical_resolved = resolved
253        .parent()
254        .and_then(|p| p.canonicalize().ok())
255        .unwrap_or_else(|| canonical_base.clone());
256
257    if !canonical_resolved.starts_with(&canonical_base) {
258        return Err(AppError::BadRequest(
259            "Invalid file path: path traversal detected".to_string(),
260        ));
261    }
262
263    Ok(resolved)
264}
265
266/// Handle file upload from multipart form data
267pub async fn handle_upload(
268    multipart: Multipart,
269    config: &UploadConfig,
270) -> AppResult<Vec<UploadedFile>> {
271    handle_upload_inner(multipart, config).await
272}
273
274async fn handle_upload_inner(
275    mut multipart: Multipart,
276    config: &UploadConfig,
277) -> AppResult<Vec<UploadedFile>> {
278    let mut files = Vec::new();
279
280    while let Some(field) = multipart
281        .next_field()
282        .await
283        .map_err(|e| AppError::BadRequest(format!("Multipart error: {}", e)))?
284    {
285        // Limit number of files
286        if files.len() >= config.max_files {
287            return Err(AppError::BadRequest(format!(
288                "Too many files. Maximum allowed: {}",
289                config.max_files
290            )));
291        }
292
293        let original_name = field
294            .file_name()
295            .map(|s| s.to_string())
296            .unwrap_or_else(|| "unknown".to_string());
297
298        let claimed_content_type = field
299            .content_type()
300            .map(|s| s.to_string())
301            .unwrap_or_else(|| "application/octet-stream".to_string());
302
303        // Read file data
304        let data = field
305            .bytes()
306            .await
307            .map_err(|e| AppError::BadRequest(format!("Failed to read file: {}", e)))?;
308
309        // Validate size
310        if data.len() as u64 > config.max_file_size {
311            return Err(AppError::BadRequest(format!(
312                "File '{}' exceeds maximum size of {} bytes",
313                original_name, config.max_file_size
314            )));
315        }
316
317        // Reject empty files
318        if data.is_empty() {
319            return Err(AppError::BadRequest("Empty file uploaded".to_string()));
320        }
321
322        // Detect real MIME type from magic bytes
323        let detected_mime = config.allowed_mimes.detect_mime_from_bytes(&data);
324
325        // Determine actual MIME type: prefer detected, fallback to claimed for types
326        // without magic bytes (like SVG, TXT)
327        let actual_mime = if let Some(detected) = detected_mime {
328            detected.to_string()
329        } else if config.allowed_mimes.find_entry(&claimed_content_type).is_some() {
330            // Only trust claimed type if the entry has no magic bytes to check
331            let entry = config.allowed_mimes.find_entry(&claimed_content_type).unwrap();
332            if entry.magic_bytes.is_empty() {
333                claimed_content_type.clone()
334            } else {
335                return Err(AppError::BadRequest(format!(
336                    "File '{}': content does not match declared type '{}'",
337                    original_name, claimed_content_type
338                )));
339            }
340        } else {
341            return Err(AppError::BadRequest(format!(
342                "File type '{}' is not allowed",
343                claimed_content_type
344            )));
345        };
346
347        // Verify the MIME type is in the allowed list
348        let mime_entry = config
349            .allowed_mimes
350            .find_entry(&actual_mime)
351            .ok_or_else(|| {
352                AppError::BadRequest(format!("File type '{}' is not allowed", actual_mime))
353            })?;
354
355        // Validate extension against whitelist for this MIME type
356        let original_ext = Path::new(&original_name)
357            .extension()
358            .and_then(|e| e.to_str())
359            .map(|e| sanitize_extension(e))
360            .unwrap_or_default();
361
362        let safe_extension = if mime_entry.extensions.contains(&original_ext.as_str()) {
363            original_ext
364        } else {
365            mime_entry.extensions[0].to_string()
366        };
367
368        // SVG-specific content validation
369        if actual_mime == "image/svg+xml" {
370            validate_svg_content(&data)?;
371        }
372
373        // Generate unique filename with safe extension
374        let filename = format!("{}.{}", Uuid::new_v4(), safe_extension);
375
376        // Ensure directory exists
377        tokio::fs::create_dir_all(&config.upload_dir)
378            .await
379            .map_err(|e| AppError::Internal(format!("Failed to create upload dir: {}", e)))?;
380
381        // Resolve path safely (prevents traversal)
382        let filepath = safe_resolve(&config.upload_dir, &filename)?;
383
384        // Write file
385        tokio::fs::write(&filepath, &data)
386            .await
387            .map_err(|e| AppError::Internal(format!("Failed to write file: {}", e)))?;
388
389        files.push(UploadedFile {
390            filename,
391            original_name,
392            filepath: config.upload_dir.to_string_lossy().to_string(),
393            mime_type: actual_mime,
394            size: data.len() as u64,
395        });
396    }
397
398    Ok(files)
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the calling test unless `validate_svg_content` refuses `svg`.
    fn assert_svg_rejected(svg: &[u8]) {
        assert!(
            validate_svg_content(svg).is_err(),
            "SVG should have been rejected: {}",
            String::from_utf8_lossy(svg)
        );
    }

    #[test]
    fn test_sanitize_extension() {
        let cases = [
            ("jpg", "jpg"),
            ("JPG", "jpg"),
            ("ph.p", "php"),
            ("j/p\\g", "jpg"),
        ];
        for &(raw, expected) in cases.iter() {
            assert_eq!(sanitize_extension(raw), expected, "input {:?}", raw);
        }
    }

    #[test]
    fn test_detect_jpeg() {
        let jpeg_start: &[u8] = &[0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10];
        assert_eq!(
            AllowedMimes::images().detect_mime_from_bytes(jpeg_start),
            Some("image/jpeg")
        );
    }

    #[test]
    fn test_detect_png() {
        let png_start: &[u8] = &[0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A];
        assert_eq!(
            AllowedMimes::images().detect_mime_from_bytes(png_start),
            Some("image/png")
        );
    }

    #[test]
    fn test_detect_gif() {
        assert_eq!(
            AllowedMimes::images().detect_mime_from_bytes(b"GIF89a"),
            Some("image/gif")
        );
    }

    #[test]
    fn test_detect_unknown() {
        assert!(AllowedMimes::images()
            .detect_mime_from_bytes(b"random data")
            .is_none());
    }

    #[test]
    fn test_svg_validation_safe() {
        let harmless = b"<svg xmlns=\"http://www.w3.org/2000/svg\"><circle r=\"10\"/></svg>";
        assert!(validate_svg_content(harmless).is_ok());
    }

    #[test]
    fn test_svg_validation_xss_script() {
        assert_svg_rejected(b"<svg><script>alert('xss')</script></svg>");
    }

    #[test]
    fn test_svg_validation_xss_event() {
        assert_svg_rejected(b"<svg><rect onload=\"alert('xss')\"/></svg>");
    }

    #[test]
    fn test_svg_validation_xss_javascript_uri() {
        assert_svg_rejected(b"<svg><a xlink:href=\"javascript:alert(1)\">click</a></svg>");
    }

    #[test]
    fn test_extension_whitelist() {
        let allowed = AllowedMimes::images();
        let jpeg = allowed
            .find_entry("image/jpeg")
            .expect("image/jpeg is part of the image whitelist");
        for wanted in ["jpg", "jpeg"].iter() {
            assert!(jpeg.extensions.contains(wanted), "missing extension {}", wanted);
        }
    }

    #[test]
    fn test_safe_resolve_normal() {
        assert!(safe_resolve(Path::new("/uploads"), "file.jpg").is_ok());
    }
}