1use super::error::MediaError;
12
/// Outcome of a successful MIME detection: the MIME type, a file extension
/// to pair with it, and the signal the result was derived from.
#[derive(Debug, Clone, PartialEq)]
pub struct DetectedMime {
    /// Detected MIME type, e.g. `image/png`.
    pub mime_type: String,

    /// File extension matching `mime_type`, without a leading dot (e.g. `png`).
    pub extension: String,

    /// Which signal produced this result.
    pub source: DetectionSource,
}
25
/// Signal a [`DetectedMime`] was derived from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
// `Extension` is never constructed in this file, which is presumably why the
// dead-code lint is silenced.
// NOTE(review): confirm no other module builds it before removing the allow.
#[allow(dead_code)] pub enum DetectionSource {
    /// The content bytes themselves identified the type (SVG text heuristic
    /// or an `infer` signature match).
    MagicBytes,

    /// NOTE(review): presumably "derived from a file-name extension"; nothing
    /// in this file produces this variant — confirm against callers.
    Extension,

    /// The content was not recognized, so the server-declared MIME type was
    /// used as-is.
    ServerHint,
}
39
40pub fn detect_mime(data: &[u8], server_mime: Option<&str>) -> Result<DetectedMime, MediaError> {
47 let server_mime_normalized: Option<String> = server_mime.map(|m| m.to_ascii_lowercase());
49 let server_mime_ref = server_mime_normalized.as_deref();
50
51 let inspect_len = data.len().min(8192);
52 let sample = &data[..inspect_len];
53
54 if sample.len() >= 5 {
58 let text_start = std::str::from_utf8(&sample[..sample.len().min(512)]);
59 if let Ok(text) = text_start {
60 let trimmed = text.trim_start();
61 let has_xml_prefix = trimmed.starts_with("<?xml");
62 let has_svg_tag = has_svg_element(trimmed);
63 let has_svg_ns = trimmed.contains("xmlns=\"http://www.w3.org/2000/svg\"")
64 || trimmed.contains("xmlns='http://www.w3.org/2000/svg'");
65 if has_svg_tag || (has_xml_prefix && has_svg_ns) {
66 return Ok(DetectedMime {
67 mime_type: "image/svg+xml".to_string(),
68 extension: "svg".to_string(),
69 source: DetectionSource::MagicBytes,
70 });
71 }
72 }
73 }
74
75 if let Some(kind) = infer::get(sample) {
77 let mime_type = kind.mime_type().to_string();
78 let extension = kind.extension().to_string();
79
80 if let Some(server) = server_mime_ref {
82 if !is_mime_alias(&mime_type, server) {
83 let detected_category = mime_type.split('/').next();
84 let server_category = server.split('/').next();
85 if detected_category != server_category {
86 return Err(MediaError::MimeDetectionFailed {
88 reason: format!(
89 "category conflict: server declared '{server}' \
90 but magic bytes detected '{mime_type}'"
91 ),
92 });
93 } else {
94 tracing::debug!(
96 detected = %mime_type,
97 server = %server,
98 "MIME subtype mismatch: magic bytes disagree with server hint, using magic bytes"
99 );
100 }
101 }
102 }
103
104 return Ok(DetectedMime {
105 mime_type,
106 extension,
107 source: DetectionSource::MagicBytes,
108 });
109 }
110
111 if let Some(server) = server_mime_ref {
113 if server != "application/octet-stream" {
114 let extension = mime_to_extension(server);
115 return Ok(DetectedMime {
116 mime_type: server.to_string(),
117 extension,
118 source: DetectionSource::ServerHint,
119 });
120 }
121 }
122
123 Err(MediaError::mime_detection_failed(
125 inspect_len,
126 server_mime.map(|s| s.to_string()),
127 ))
128}
129
/// Returns `true` when `a` and `b` name the same media type: either the two
/// strings are equal, or they form one of the known alias pairs below.
///
/// Matching is symmetric (argument order does not matter) and ignores ASCII
/// case, because MIME type and subtype names are case-insensitive.
///
/// The `audio/mp4` / `video/mp4` pair deliberately crosses top-level
/// categories: both names describe the same container format.
pub fn is_mime_alias(a: &str, b: &str) -> bool {
    // Each pair may be written in either order; the lookup checks both.
    const ALIAS_PAIRS: &[(&str, &str)] = &[
        ("audio/mp3", "audio/mpeg"),
        ("audio/wav", "audio/x-wav"),
        ("image/jpeg", "image/jpg"),
        ("audio/flac", "audio/x-flac"),
        ("audio/aiff", "audio/x-aiff"),
        ("image/vnd.microsoft.icon", "image/x-icon"),
        ("audio/mp4", "video/mp4"),
    ];

    if a.eq_ignore_ascii_case(b) {
        return true;
    }

    ALIAS_PAIRS.iter().any(|&(left, right)| {
        (a.eq_ignore_ascii_case(left) && b.eq_ignore_ascii_case(right))
            || (a.eq_ignore_ascii_case(right) && b.eq_ignore_ascii_case(left))
    })
}
156
157pub fn mime_to_extension(mime: &str) -> String {
164 let manual = match mime {
166 "image/png" => Some("png"),
167 "image/jpeg" | "image/jpg" => Some("jpg"),
168 "image/gif" => Some("gif"),
169 "image/webp" => Some("webp"),
170 "image/svg+xml" => Some("svg"),
171 "image/vnd.microsoft.icon" | "image/x-icon" => Some("ico"),
172 "audio/mpeg" | "audio/mp3" => Some("mp3"),
173 "audio/wav" | "audio/x-wav" => Some("wav"),
174 "audio/ogg" => Some("ogg"),
175 "audio/flac" | "audio/x-flac" => Some("flac"),
176 "audio/aiff" | "audio/x-aiff" => Some("aiff"),
177 "audio/mp4" | "audio/x-m4a" => Some("m4a"),
178 "video/mp4" => Some("mp4"),
179 "video/webm" => Some("webm"),
180 "application/pdf" => Some("pdf"),
181 "application/json" => Some("json"),
182 "text/plain" => Some("txt"),
183 "text/html" => Some("html"),
184 "text/csv" => Some("csv"),
185 _ => None,
186 };
187
188 if let Some(ext) = manual {
189 return ext.to_string();
190 }
191
192 if let Some(exts) = mime_guess::get_mime_extensions_str(mime) {
194 if let Some(ext) = exts.first() {
195 return sanitize_extension(ext);
196 }
197 }
198
199 "bin".to_string()
200}
201
/// Reports whether `text` contains an `<svg` opening tag.
///
/// A match counts only when `<svg` is followed by a character that can end a
/// tag name (space, `>`, `/`, tab, newline, carriage return) or when the text
/// ends right after `<svg`. Names that merely start with `svg`, such as
/// `<svg-report>`, are therefore rejected. The search is case-sensitive.
fn has_svg_element(text: &str) -> bool {
    const TAG: &str = "<svg";
    let bytes = text.as_bytes();

    text.match_indices(TAG).any(|(start, _)| {
        match bytes.get(start + TAG.len()) {
            // Text stops right after the tag name (e.g. truncated input).
            None => true,
            Some(&follower) => {
                matches!(follower, b' ' | b'>' | b'/' | b'\t' | b'\n' | b'\r')
            }
        }
    })
}
224
/// Reduces `ext` to a safe file-extension token: keeps only ASCII letters,
/// ASCII digits and `-`, and lowercases what remains.
///
/// Path separators, dots, spaces and all other characters are dropped, so the
/// result can be empty.
fn sanitize_extension(ext: &str) -> String {
    let mut cleaned = String::with_capacity(ext.len());
    for ch in ext.chars() {
        if ch == '-' || ch.is_ascii_alphanumeric() {
            // Every kept character is ASCII, so ASCII lowercasing is exact.
            cleaned.push(ch.to_ascii_lowercase());
        }
    }
    cleaned
}
233
// Unit tests for MIME detection, alias matching, extension mapping and the
// SVG heuristic. Several SVG tests pin known limitations of the heuristic
// (false positives / false negatives) rather than ideal behavior.
#[cfg(test)]
mod tests {
    use super::*;

    // PNG signature (0x89 "PNG" CR LF 0x1A LF) followed by zero padding.
    const PNG_HEADER: &[u8] = &[0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0, 0, 0, 0];
    // JPEG start-of-image marker (FF D8 FF E0) followed by zero padding.
    const JPEG_HEADER: &[u8] = &[0xFF, 0xD8, 0xFF, 0xE0, 0, 0, 0, 0, 0, 0, 0, 0];
    // "RIFF", a 4-byte size field, then "WAVE".
    const WAV_HEADER: &[u8] = &[
        0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57, 0x41, 0x56, 0x45, ];

    // --- Magic-byte detection without a server hint ---

    #[test]
    fn detect_png_magic_bytes() {
        let result = detect_mime(PNG_HEADER, None).unwrap();
        assert_eq!(result.mime_type, "image/png");
        assert_eq!(result.extension, "png");
        assert_eq!(result.source, DetectionSource::MagicBytes);
    }

    #[test]
    fn detect_jpeg_magic_bytes() {
        let result = detect_mime(JPEG_HEADER, None).unwrap();
        assert_eq!(result.mime_type, "image/jpeg");
        assert_eq!(result.source, DetectionSource::MagicBytes);
    }

    #[test]
    fn detect_wav_magic_bytes() {
        let result = detect_mime(WAV_HEADER, None).unwrap();
        // Only checks for "wav" in the name, so either `audio/wav` or
        // `audio/x-wav` is accepted.
        assert!(
            result.mime_type.contains("wav"),
            "expected wav, got {}",
            result.mime_type
        );
        assert_eq!(result.source, DetectionSource::MagicBytes);
    }

    // --- Unrecognized content and server-hint fallback ---

    #[test]
    fn unknown_bytes_returns_error() {
        let data = &[0x00, 0x01, 0x02, 0x03, 0x04, 0x05];
        let result = detect_mime(data, None);
        assert!(result.is_err());
    }

    // `application/octet-stream` is not treated as a usable hint.
    #[test]
    fn unknown_bytes_with_octet_stream_returns_error() {
        let data = &[0x00, 0x01, 0x02, 0x03];
        let result = detect_mime(data, Some("application/octet-stream"));
        assert!(result.is_err());
    }

    #[test]
    fn unknown_bytes_with_server_hint_accepted() {
        let data = &[0x00, 0x01, 0x02, 0x03];
        let result = detect_mime(data, Some("image/png")).unwrap();
        assert_eq!(result.mime_type, "image/png");
        assert_eq!(result.source, DetectionSource::ServerHint);
    }

    // Same top-level category, different subtype: magic bytes win.
    #[test]
    fn magic_bytes_preferred_over_same_category_hint() {
        let result = detect_mime(PNG_HEADER, Some("image/webp")).unwrap();
        assert_eq!(result.mime_type, "image/png");
        assert_eq!(result.source, DetectionSource::MagicBytes);
    }

    // The server hint is lowercased before it is returned.
    #[test]
    fn uppercase_server_mime_normalized() {
        let data = &[0x00, 0x01, 0x02, 0x03];
        let result = detect_mime(data, Some("IMAGE/PNG")).unwrap();
        assert_eq!(result.mime_type, "image/png");
    }

    #[test]
    fn svg_detection() {
        let svg = b"<svg xmlns=\"http://www.w3.org/2000/svg\" viewBox=\"0 0 100 100\"></svg>";
        let result = detect_mime(svg, None).unwrap();
        assert_eq!(result.mime_type, "image/svg+xml");
        assert_eq!(result.extension, "svg");
    }

    // --- mime_to_extension: curated table ---

    #[test]
    fn mime_to_extension_common_types() {
        assert_eq!(mime_to_extension("image/png"), "png");
        assert_eq!(mime_to_extension("image/jpeg"), "jpg");
        assert_eq!(mime_to_extension("audio/mpeg"), "mp3");
        assert_eq!(mime_to_extension("audio/wav"), "wav");
        assert_eq!(mime_to_extension("audio/x-wav"), "wav");
        assert_eq!(mime_to_extension("audio/flac"), "flac");
        assert_eq!(mime_to_extension("text/plain"), "txt");
        assert_eq!(mime_to_extension("application/pdf"), "pdf");
        assert_eq!(mime_to_extension("application/json"), "json");
    }

    // --- is_mime_alias ---

    #[test]
    fn is_mime_alias_known_pairs() {
        assert!(is_mime_alias("audio/mp3", "audio/mpeg"));
        assert!(is_mime_alias("audio/mpeg", "audio/mp3"));
        assert!(is_mime_alias("image/jpeg", "image/jpg"));
        assert!(is_mime_alias("image/jpg", "image/jpeg"));
        assert!(is_mime_alias("audio/wav", "audio/x-wav"));
        assert!(is_mime_alias("audio/x-wav", "audio/wav"));
        assert!(is_mime_alias("audio/flac", "audio/x-flac"));
        assert!(is_mime_alias("audio/x-flac", "audio/flac"));
        assert!(is_mime_alias("audio/aiff", "audio/x-aiff"));
        assert!(is_mime_alias("image/vnd.microsoft.icon", "image/x-icon"));
        assert!(is_mime_alias("audio/mp4", "video/mp4"));
        assert!(is_mime_alias("video/mp4", "audio/mp4"));
    }

    #[test]
    fn is_mime_alias_identity() {
        assert!(is_mime_alias("image/png", "image/png"));
        assert!(is_mime_alias("audio/mpeg", "audio/mpeg"));
    }

    #[test]
    fn is_mime_alias_non_aliases() {
        assert!(!is_mime_alias("image/png", "image/jpeg"));
        assert!(!is_mime_alias("audio/mp3", "image/png"));
        assert!(!is_mime_alias("audio/ogg", "audio/flac"));
    }

    // --- Magic bytes vs. server hint conflicts ---

    // PNG content declared as audio: different top-level category, so
    // detection fails with the error whose code is "NIKA-251".
    #[test]
    fn cross_category_mismatch_is_rejected() {
        let result = detect_mime(PNG_HEADER, Some("audio/wav"));
        assert!(
            result.is_err(),
            "Cross-category mismatch should be rejected"
        );
        assert_eq!(result.unwrap_err().code(), "NIKA-251");
    }

    #[test]
    fn same_category_alias_is_accepted() {
        let result = detect_mime(WAV_HEADER, Some("audio/x-wav"));
        assert!(result.is_ok());
    }

    #[test]
    fn same_category_subtype_mismatch_uses_magic_bytes() {
        let result = detect_mime(JPEG_HEADER, Some("image/webp")).unwrap();
        assert_eq!(result.mime_type, "image/jpeg");
        assert_eq!(result.source, DetectionSource::MagicBytes);
    }

    // --- mime_to_extension: alias spellings map to the same extension ---

    #[test]
    fn mime_to_extension_alias_image_jpg() {
        assert_eq!(mime_to_extension("image/jpg"), "jpg");
    }

    #[test]
    fn mime_to_extension_alias_audio_mp3() {
        assert_eq!(mime_to_extension("audio/mp3"), "mp3");
    }

    #[test]
    fn mime_to_extension_alias_audio_x_flac() {
        assert_eq!(mime_to_extension("audio/x-flac"), "flac");
    }

    #[test]
    fn mime_to_extension_alias_audio_x_wav() {
        assert_eq!(mime_to_extension("audio/x-wav"), "wav");
    }

    #[test]
    fn mime_to_extension_alias_audio_x_aiff() {
        assert_eq!(mime_to_extension("audio/x-aiff"), "aiff");
    }

    #[test]
    fn mime_to_extension_alias_image_x_icon() {
        assert_eq!(mime_to_extension("image/x-icon"), "ico");
    }

    #[test]
    fn mime_to_extension_alias_audio_mp4() {
        assert_eq!(mime_to_extension("audio/mp4"), "m4a");
    }

    // --- is_mime_alias: argument order must not matter ---

    #[test]
    fn is_mime_alias_aiff_reverse_order() {
        assert!(
            is_mime_alias("audio/x-aiff", "audio/aiff"),
            "reverse order must match: audio/x-aiff -> audio/aiff"
        );
        assert!(
            is_mime_alias("audio/aiff", "audio/x-aiff"),
            "canonical order must match: audio/aiff -> audio/x-aiff"
        );
    }

    #[test]
    fn is_mime_alias_icon_reverse_order() {
        assert!(
            is_mime_alias("image/x-icon", "image/vnd.microsoft.icon"),
            "reverse order must match: image/x-icon -> image/vnd.microsoft.icon"
        );
        assert!(
            is_mime_alias("image/vnd.microsoft.icon", "image/x-icon"),
            "canonical order must match: image/vnd.microsoft.icon -> image/x-icon"
        );
    }

    #[test]
    fn is_mime_alias_mp4_cross_category_reverse() {
        assert!(
            is_mime_alias("video/mp4", "audio/mp4"),
            "reverse order must match: video/mp4 -> audio/mp4"
        );
        assert!(
            is_mime_alias("audio/mp4", "video/mp4"),
            "canonical order must match: audio/mp4 -> video/mp4"
        );
    }

    #[test]
    fn is_mime_alias_flac_reverse_order() {
        assert!(
            is_mime_alias("audio/x-flac", "audio/flac"),
            "reverse order must match: audio/x-flac -> audio/flac"
        );
    }

    // --- MP4 container: the audio/video alias prevents a category conflict ---

    #[test]
    fn detect_mime_mp4_container_with_audio_mp4_server_hint() {
        // Box size 0x20, "ftyp", major brand "isom", then compatible brands.
        let mp4_header: Vec<u8> = vec![
            0x00, 0x00, 0x00, 0x20, 0x66, 0x74, 0x79, 0x70, 0x69, 0x73, 0x6F, 0x6D, 0x00, 0x00, 0x02, 0x00, 0x69, 0x73, 0x6F, 0x6D, 0x69, 0x73, 0x6F, 0x32, 0x61, 0x76, 0x63, 0x31, 0x6D, 0x70, 0x34, 0x31, ];

        let result = detect_mime(&mp4_header, Some("audio/mp4"));
        assert!(
            result.is_ok(),
            "audio/mp4 <> video/mp4 alias should prevent cross-category rejection, got: {:?}",
            result.err()
        );

        let detected = result.unwrap();
        assert_eq!(
            detected.mime_type, "video/mp4",
            "magic bytes should win (video/mp4), not server hint"
        );
        assert_eq!(detected.source, DetectionSource::MagicBytes);
    }

    #[test]
    fn detect_mime_mp4_container_with_video_mp4_server_matches() {
        // Same `ftyp`/`isom` header as in the previous test.
        let mp4_header: Vec<u8> = vec![
            0x00, 0x00, 0x00, 0x20, 0x66, 0x74, 0x79, 0x70, 0x69, 0x73, 0x6F, 0x6D, 0x00, 0x00,
            0x02, 0x00, 0x69, 0x73, 0x6F, 0x6D, 0x69, 0x73, 0x6F, 0x32, 0x61, 0x76, 0x63, 0x31,
            0x6D, 0x70, 0x34, 0x31,
        ];

        let result = detect_mime(&mp4_header, Some("video/mp4"));
        assert!(result.is_ok());
        let detected = result.unwrap();
        assert_eq!(detected.mime_type, "video/mp4");
    }

    // Path separators, dots, spaces and shell metacharacters are stripped.
    #[test]
    fn sanitize_extension_rejects_traversal() {
        assert_eq!(sanitize_extension("../etc/passwd"), "etcpasswd");
        assert_eq!(sanitize_extension("png;rm -rf /"), "pngrm-rf");
    }

    // --- SVG heuristic edge cases ---

    // A tag name that merely starts with "svg" must not count as `<svg>`.
    #[test]
    fn svg_heuristic_rejects_svg_prefixed_tag() {
        let data = b"<svg-report><item>data</item></svg-report>";
        let result = detect_mime(data, None);
        assert!(
            result.is_err(),
            "<svg-report> should NOT be detected as SVG"
        );
    }

    #[test]
    fn svg_heuristic_leading_whitespace() {
        let data = b"  \t\n  <svg xmlns=\"http://www.w3.org/2000/svg\"></svg>";
        let result = detect_mime(data, None).unwrap();
        assert_eq!(result.mime_type, "image/svg+xml");
        assert_eq!(result.extension, "svg");
        assert_eq!(result.source, DetectionSource::MagicBytes);
    }

    // The `<svg>` tag alone is enough; no namespace declaration is required.
    #[test]
    fn svg_heuristic_xml_declaration_plus_svg_no_xmlns() {
        let data = b"<?xml version=\"1.0\"?>\n<svg></svg>";
        let result = detect_mime(data, None).unwrap();
        assert_eq!(result.mime_type, "image/svg+xml");
        assert_eq!(result.extension, "svg");
    }

    // Pins a known false positive: the heuristic does not parse XML, so an
    // `<svg>` inside a comment still triggers detection.
    #[test]
    fn svg_heuristic_false_positive_svg_in_xml_comment() {
        let data = b"<?xml version=\"1.0\"?><!-- comment with <svg> in it --><root/>";
        let result = detect_mime(data, None);
        assert!(
            result.is_ok(),
            "heuristic fires on <svg> inside XML comment (known false positive)"
        );
        let detected = result.unwrap();
        assert_eq!(detected.mime_type, "image/svg+xml");
    }

    // Pins a known false negative: only the first 512 bytes are searched.
    #[test]
    fn svg_heuristic_false_negative_beyond_512_byte_window() {
        let mut data = Vec::new();
        data.extend_from_slice(b"<data>");
        // 6 + 610 bytes of filler push the `<svg` tag past byte 512.
        data.extend(std::iter::repeat_n(b'x', 610));
        data.extend_from_slice(b"</data><svg xmlns=\"http://www.w3.org/2000/svg\"></svg>");

        let result = detect_mime(&data, None);
        assert!(
            result.is_err(),
            "SVG beyond 512-byte window should not be detected (false negative)"
        );
    }

    #[test]
    fn svg_heuristic_single_quotes_xmlns() {
        let data = b"<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'></svg>";
        let result = detect_mime(data, None).unwrap();
        assert_eq!(result.mime_type, "image/svg+xml");
        assert_eq!(result.extension, "svg");
    }

    // Pins current behavior: the tag search is case-sensitive and `<SVG` does
    // not start with `<?xml`, so nothing matches.
    #[test]
    fn svg_heuristic_uppercase_tag_not_detected() {
        let data = b"<SVG xmlns=\"http://www.w3.org/2000/svg\"></SVG>";
        let result = detect_mime(data, None);
        assert!(
            result.is_err(),
            "uppercase <SVG> not detected by case-sensitive heuristic"
        );
    }

    #[test]
    fn svg_heuristic_self_closing_no_xmlns() {
        let data = b"<svg/>";
        let result = detect_mime(data, None).unwrap();
        assert_eq!(result.mime_type, "image/svg+xml");
        assert_eq!(result.extension, "svg");
    }

    // An invalid UTF-8 byte inside the inspected window disables the SVG
    // heuristic even though the data begins with `<svg`.
    #[test]
    fn svg_heuristic_binary_with_svg_utf8_prefix() {
        let mut data = Vec::new();
        data.extend_from_slice(b"<svg");
        data.push(0xFF); data.extend(std::iter::repeat_n(0x00u8, 20));
        let result = detect_mime(&data, None);
        assert!(
            result.is_err(),
            "binary data with <svg prefix but invalid UTF-8 should not be detected as SVG"
        );
    }

    // Five bytes is the minimum length at which the SVG heuristic runs.
    #[test]
    fn svg_heuristic_exactly_five_bytes() {
        let data = b"<svg>";
        assert_eq!(data.len(), 5);
        let result = detect_mime(data, None).unwrap();
        assert_eq!(result.mime_type, "image/svg+xml");
        assert_eq!(result.extension, "svg");
        assert_eq!(result.source, DetectionSource::MagicBytes);
    }
}