use super::error::MediaError;
/// Outcome of a MIME detection pass.
///
/// Bundles the detected MIME type, a matching file extension, and a marker
/// telling which detection strategy produced the answer.
#[derive(Debug, Clone, PartialEq)]
pub struct DetectedMime {
    /// MIME type string, e.g. `image/png`.
    pub mime_type: String,
    /// File extension without a leading dot, e.g. `png`.
    pub extension: String,
    /// Strategy that produced `mime_type` and `extension`.
    pub source: DetectionSource,
}
/// Identifies which strategy produced a [`DetectedMime`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
// NOTE(review): presumably needed because `Extension` is never constructed in
// this file — confirm before removing the allow.
#[allow(dead_code)]
pub enum DetectionSource {
    /// The type was derived from the content itself (SVG text heuristic or a
    /// signature match reported by `infer`).
    MagicBytes,
    /// Not constructed in this file; presumably reserved for detection based
    /// on a file extension — TODO confirm.
    Extension,
    /// The type was taken from the server-declared MIME type because content
    /// inspection found nothing.
    ServerHint,
}
/// Detects the MIME type of `data`, optionally cross-checked against the MIME
/// type declared by the server.
///
/// Detection proceeds in three stages; the first one that succeeds wins:
///
/// 1. **SVG heuristic** – the first 512 bytes (at most) are decoded as UTF-8
///    and searched for an `<svg` element, or for an `<?xml` prolog combined
///    with the SVG namespace. If the 512-byte cut lands inside a multi-byte
///    character, the valid prefix before that character is used instead of
///    giving up.
/// 2. **Magic bytes** – `infer::get` is run on the first 8192 bytes (at most).
///    When a server hint is present and is not an alias of the detected type,
///    a differing top-level category (e.g. `audio/*` vs `image/*`) is an
///    error, while a differing subtype is only logged and the detected type
///    is used.
/// 3. **Server hint** – used as-is unless it is `application/octet-stream`.
///
/// The server hint is reduced to its lowercase essence before use: parameters
/// after `;` (such as `charset=utf-8`) and surrounding whitespace are dropped.
/// A hint that is empty after this normalisation is treated as absent.
///
/// # Errors
///
/// Returns [`MediaError::MimeDetectionFailed`] when magic bytes and the server
/// hint disagree on the top-level category, and the error built by
/// `MediaError::mime_detection_failed` when no stage produced a result.
pub fn detect_mime(data: &[u8], server_mime: Option<&str>) -> Result<DetectedMime, MediaError> {
    // Content-Type values commonly carry parameters ("text/plain; charset=utf-8");
    // only the essence is meaningful for alias/category checks and extension lookup.
    let server_mime_normalized: Option<String> = server_mime
        .map(|m| m.split(';').next().unwrap_or(m).trim().to_ascii_lowercase())
        .filter(|m| !m.is_empty());
    let server_mime_ref = server_mime_normalized.as_deref();

    let inspect_len = data.len().min(8192);
    let sample = &data[..inspect_len];

    if sample.len() >= 5 {
        let window = &sample[..sample.len().min(512)];
        let text_start = match std::str::from_utf8(window) {
            Ok(text) => Some(text),
            // `error_len() == None` means the input ended in the middle of a
            // character. When that is caused by our own 512-byte cut (not by
            // the data itself ending), the bytes before it are still valid text.
            Err(err) if window.len() < sample.len() && err.error_len().is_none() => {
                std::str::from_utf8(&window[..err.valid_up_to()]).ok()
            }
            // Genuinely invalid UTF-8: not text, so not SVG.
            Err(_) => None,
        };
        if let Some(text) = text_start {
            let trimmed = text.trim_start();
            let has_xml_prefix = trimmed.starts_with("<?xml");
            let has_svg_tag = has_svg_element(trimmed);
            let has_svg_ns = trimmed.contains("xmlns=\"http://www.w3.org/2000/svg\"")
                || trimmed.contains("xmlns='http://www.w3.org/2000/svg'");
            if has_svg_tag || (has_xml_prefix && has_svg_ns) {
                return Ok(DetectedMime {
                    mime_type: "image/svg+xml".to_string(),
                    extension: "svg".to_string(),
                    source: DetectionSource::MagicBytes,
                });
            }
        }
    }

    if let Some(kind) = infer::get(sample) {
        let mime_type = kind.mime_type().to_string();
        let extension = kind.extension().to_string();
        if let Some(server) = server_mime_ref {
            if !is_mime_alias(&mime_type, server) {
                let detected_category = mime_type.split('/').next();
                let server_category = server.split('/').next();
                if detected_category != server_category {
                    return Err(MediaError::MimeDetectionFailed {
                        reason: format!(
                            "category conflict: server declared '{server}' \
                             but magic bytes detected '{mime_type}'"
                        ),
                    });
                } else {
                    tracing::debug!(
                        detected = %mime_type,
                        server = %server,
                        "MIME subtype mismatch: magic bytes disagree with server hint, using magic bytes"
                    );
                }
            }
        }
        return Ok(DetectedMime {
            mime_type,
            extension,
            source: DetectionSource::MagicBytes,
        });
    }

    if let Some(server) = server_mime_ref {
        // `application/octet-stream` carries no information, so it is not
        // accepted as a hint.
        if server != "application/octet-stream" {
            let extension = mime_to_extension(server);
            return Ok(DetectedMime {
                mime_type: server.to_string(),
                extension,
                source: DetectionSource::ServerHint,
            });
        }
    }

    // Report the hint exactly as the caller supplied it.
    Err(MediaError::mime_detection_failed(
        inspect_len,
        server_mime.map(|s| s.to_string()),
    ))
}
/// Returns `true` when `a` and `b` name the same media type.
///
/// Two strings are considered aliases when they are byte-for-byte equal or
/// when they form one of the known alias pairs listed below, in either order.
/// The comparison is case-sensitive.
pub fn is_mime_alias(a: &str, b: &str) -> bool {
    /// Known alias pairs; argument order does not matter when matching.
    const ALIAS_PAIRS: [(&str, &str); 7] = [
        ("audio/mp3", "audio/mpeg"),
        ("audio/wav", "audio/x-wav"),
        ("image/jpeg", "image/jpg"),
        ("audio/flac", "audio/x-flac"),
        ("audio/aiff", "audio/x-aiff"),
        ("image/vnd.microsoft.icon", "image/x-icon"),
        ("audio/mp4", "video/mp4"),
    ];

    a == b
        || ALIAS_PAIRS
            .iter()
            .any(|&(left, right)| (a == left && b == right) || (a == right && b == left))
}
/// Maps a MIME type to a file extension (without a leading dot).
///
/// Lookup order:
/// 1. a built-in table of common types (exact, case-sensitive match);
/// 2. the first extension reported by `mime_guess`, passed through
///    `sanitize_extension`;
/// 3. the literal fallback `"bin"`.
pub fn mime_to_extension(mime: &str) -> String {
    let builtin: Option<&'static str> = match mime {
        // Images
        "image/png" => Some("png"),
        "image/jpeg" | "image/jpg" => Some("jpg"),
        "image/gif" => Some("gif"),
        "image/webp" => Some("webp"),
        "image/svg+xml" => Some("svg"),
        "image/vnd.microsoft.icon" | "image/x-icon" => Some("ico"),
        // Audio
        "audio/mpeg" | "audio/mp3" => Some("mp3"),
        "audio/wav" | "audio/x-wav" => Some("wav"),
        "audio/ogg" => Some("ogg"),
        "audio/flac" | "audio/x-flac" => Some("flac"),
        "audio/aiff" | "audio/x-aiff" => Some("aiff"),
        "audio/mp4" | "audio/x-m4a" => Some("m4a"),
        // Video
        "video/mp4" => Some("mp4"),
        "video/webm" => Some("webm"),
        // Documents and text
        "application/pdf" => Some("pdf"),
        "application/json" => Some("json"),
        "text/plain" => Some("txt"),
        "text/html" => Some("html"),
        "text/csv" => Some("csv"),
        _ => None,
    };

    match builtin {
        Some(known) => known.to_string(),
        // Not in the table: ask `mime_guess`, and fall back to "bin" when it
        // knows nothing either.
        None => mime_guess::get_mime_extensions_str(mime)
            .and_then(|candidates| candidates.first())
            .map(|candidate| sanitize_extension(candidate))
            .unwrap_or_else(|| "bin".to_string()),
    }
}
/// Reports whether `text` contains an `<svg` opening tag.
///
/// An occurrence of `<svg` counts only when it is the very end of `text` or is
/// immediately followed by a space, tab, line feed, carriage return, `>` or
/// `/`. This keeps element names that merely start with "svg" (such as
/// `<svg-report>`) from matching. The search is case-sensitive.
fn has_svg_element(text: &str) -> bool {
    const OPEN_TAG: &str = "<svg";
    let raw = text.as_bytes();

    text.match_indices(OPEN_TAG).any(|(start, _)| {
        match raw.get(start + OPEN_TAG.len()) {
            // `<svg` is the last thing in the text.
            None => true,
            // Otherwise the tag name must end right here.
            Some(&follower) => {
                matches!(follower, b' ' | b'>' | b'/' | b'\t' | b'\n' | b'\r')
            }
        }
    })
}
/// Reduces `ext` to a safe, lowercase file extension.
///
/// Only ASCII letters, ASCII digits and `-` are kept; every other character
/// (path separators, dots, whitespace, shell metacharacters, non-ASCII text)
/// is dropped. The result may be empty.
fn sanitize_extension(ext: &str) -> String {
    let mut cleaned = String::with_capacity(ext.len());
    for ch in ext.chars() {
        if ch == '-' || ch.is_ascii_alphanumeric() {
            // Only ASCII survives the filter, so ASCII lowercasing is sufficient.
            cleaned.push(ch.to_ascii_lowercase());
        }
    }
    cleaned
}
// Unit tests for MIME detection, alias matching, extension mapping and the
// SVG text heuristic.
#[cfg(test)]
mod tests {
use super::*;
// PNG signature (8 bytes) followed by four zero bytes.
const PNG_HEADER: &[u8] = &[0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0, 0, 0, 0];
// Bytes FF D8 FF E0 followed by eight zero bytes.
const JPEG_HEADER: &[u8] = &[0xFF, 0xD8, 0xFF, 0xE0, 0, 0, 0, 0, 0, 0, 0, 0];
// ASCII "RIFF", four zero bytes, then ASCII "WAVE".
const WAV_HEADER: &[u8] = &[
0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57, 0x41, 0x56, 0x45, ];
// --- Magic-byte detection without a server hint ---
#[test]
fn detect_png_magic_bytes() {
let result = detect_mime(PNG_HEADER, None).unwrap();
assert_eq!(result.mime_type, "image/png");
assert_eq!(result.extension, "png");
assert_eq!(result.source, DetectionSource::MagicBytes);
}
#[test]
fn detect_jpeg_magic_bytes() {
let result = detect_mime(JPEG_HEADER, None).unwrap();
assert_eq!(result.mime_type, "image/jpeg");
assert_eq!(result.source, DetectionSource::MagicBytes);
}
// Only checks that the reported type contains "wav", so either spelling of
// the WAV MIME type is accepted.
#[test]
fn detect_wav_magic_bytes() {
let result = detect_mime(WAV_HEADER, None).unwrap();
assert!(
result.mime_type.contains("wav"),
"expected wav, got {}",
result.mime_type
);
assert_eq!(result.source, DetectionSource::MagicBytes);
}
// --- Unrecognised content: behaviour depends on the server hint ---
#[test]
fn unknown_bytes_returns_error() {
let data = &[0x00, 0x01, 0x02, 0x03, 0x04, 0x05];
let result = detect_mime(data, None);
assert!(result.is_err());
}
// `application/octet-stream` is not accepted as a usable hint.
#[test]
fn unknown_bytes_with_octet_stream_returns_error() {
let data = &[0x00, 0x01, 0x02, 0x03];
let result = detect_mime(data, Some("application/octet-stream"));
assert!(result.is_err());
}
// Any other hint is trusted when the content itself is unrecognised.
#[test]
fn unknown_bytes_with_server_hint_accepted() {
let data = &[0x00, 0x01, 0x02, 0x03];
let result = detect_mime(data, Some("image/png")).unwrap();
assert_eq!(result.mime_type, "image/png");
assert_eq!(result.source, DetectionSource::ServerHint);
}
// Same top-level category, different subtype: the detected type wins.
#[test]
fn magic_bytes_preferred_over_same_category_hint() {
let result = detect_mime(PNG_HEADER, Some("image/webp")).unwrap();
assert_eq!(result.mime_type, "image/png");
assert_eq!(result.source, DetectionSource::MagicBytes);
}
// The server hint is lowercased before it is used.
#[test]
fn uppercase_server_mime_normalized() {
let data = &[0x00, 0x01, 0x02, 0x03];
let result = detect_mime(data, Some("IMAGE/PNG")).unwrap();
assert_eq!(result.mime_type, "image/png");
}
#[test]
fn svg_detection() {
let svg = b"<svg xmlns=\"http://www.w3.org/2000/svg\" viewBox=\"0 0 100 100\"></svg>";
let result = detect_mime(svg, None).unwrap();
assert_eq!(result.mime_type, "image/svg+xml");
assert_eq!(result.extension, "svg");
}
// --- mime_to_extension and is_mime_alias basics ---
#[test]
fn mime_to_extension_common_types() {
assert_eq!(mime_to_extension("image/png"), "png");
assert_eq!(mime_to_extension("image/jpeg"), "jpg");
assert_eq!(mime_to_extension("audio/mpeg"), "mp3");
assert_eq!(mime_to_extension("audio/wav"), "wav");
assert_eq!(mime_to_extension("audio/x-wav"), "wav");
assert_eq!(mime_to_extension("audio/flac"), "flac");
assert_eq!(mime_to_extension("text/plain"), "txt");
assert_eq!(mime_to_extension("application/pdf"), "pdf");
assert_eq!(mime_to_extension("application/json"), "json");
}
// Each alias pair is checked in at least one argument order; most in both.
#[test]
fn is_mime_alias_known_pairs() {
assert!(is_mime_alias("audio/mp3", "audio/mpeg"));
assert!(is_mime_alias("audio/mpeg", "audio/mp3"));
assert!(is_mime_alias("image/jpeg", "image/jpg"));
assert!(is_mime_alias("image/jpg", "image/jpeg"));
assert!(is_mime_alias("audio/wav", "audio/x-wav"));
assert!(is_mime_alias("audio/x-wav", "audio/wav"));
assert!(is_mime_alias("audio/flac", "audio/x-flac"));
assert!(is_mime_alias("audio/x-flac", "audio/flac"));
assert!(is_mime_alias("audio/aiff", "audio/x-aiff"));
assert!(is_mime_alias("image/vnd.microsoft.icon", "image/x-icon"));
assert!(is_mime_alias("audio/mp4", "video/mp4"));
assert!(is_mime_alias("video/mp4", "audio/mp4"));
}
#[test]
fn is_mime_alias_identity() {
assert!(is_mime_alias("image/png", "image/png"));
assert!(is_mime_alias("audio/mpeg", "audio/mpeg"));
}
#[test]
fn is_mime_alias_non_aliases() {
assert!(!is_mime_alias("image/png", "image/jpeg"));
assert!(!is_mime_alias("audio/mp3", "image/png"));
assert!(!is_mime_alias("audio/ogg", "audio/flac"));
}
// --- Conflicts between magic bytes and the server hint ---
// PNG content declared as audio: the categories differ, so detection fails.
// The error's `code()` is expected to be "NIKA-251".
#[test]
fn cross_category_mismatch_is_rejected() {
let result = detect_mime(PNG_HEADER, Some("audio/wav"));
assert!(
result.is_err(),
"Cross-category mismatch should be rejected"
);
assert_eq!(result.unwrap_err().code(), "NIKA-251");
}
#[test]
fn same_category_alias_is_accepted() {
let result = detect_mime(WAV_HEADER, Some("audio/x-wav"));
assert!(result.is_ok());
}
#[test]
fn same_category_subtype_mismatch_uses_magic_bytes() {
let result = detect_mime(JPEG_HEADER, Some("image/webp")).unwrap();
assert_eq!(result.mime_type, "image/jpeg");
assert_eq!(result.source, DetectionSource::MagicBytes);
}
// --- mime_to_extension: one test per alias spelling ---
#[test]
fn mime_to_extension_alias_image_jpg() {
assert_eq!(mime_to_extension("image/jpg"), "jpg");
}
#[test]
fn mime_to_extension_alias_audio_mp3() {
assert_eq!(mime_to_extension("audio/mp3"), "mp3");
}
#[test]
fn mime_to_extension_alias_audio_x_flac() {
assert_eq!(mime_to_extension("audio/x-flac"), "flac");
}
#[test]
fn mime_to_extension_alias_audio_x_wav() {
assert_eq!(mime_to_extension("audio/x-wav"), "wav");
}
#[test]
fn mime_to_extension_alias_audio_x_aiff() {
assert_eq!(mime_to_extension("audio/x-aiff"), "aiff");
}
#[test]
fn mime_to_extension_alias_image_x_icon() {
assert_eq!(mime_to_extension("image/x-icon"), "ico");
}
#[test]
fn mime_to_extension_alias_audio_mp4() {
assert_eq!(mime_to_extension("audio/mp4"), "m4a");
}
// --- is_mime_alias: argument order must not matter ---
#[test]
fn is_mime_alias_aiff_reverse_order() {
assert!(
is_mime_alias("audio/x-aiff", "audio/aiff"),
"reverse order must match: audio/x-aiff -> audio/aiff"
);
assert!(
is_mime_alias("audio/aiff", "audio/x-aiff"),
"canonical order must match: audio/aiff -> audio/x-aiff"
);
}
#[test]
fn is_mime_alias_icon_reverse_order() {
assert!(
is_mime_alias("image/x-icon", "image/vnd.microsoft.icon"),
"reverse order must match: image/x-icon -> image/vnd.microsoft.icon"
);
assert!(
is_mime_alias("image/vnd.microsoft.icon", "image/x-icon"),
"canonical order must match: image/vnd.microsoft.icon -> image/x-icon"
);
}
// The only alias pair that spans two top-level categories.
#[test]
fn is_mime_alias_mp4_cross_category_reverse() {
assert!(
is_mime_alias("video/mp4", "audio/mp4"),
"reverse order must match: video/mp4 -> audio/mp4"
);
assert!(
is_mime_alias("audio/mp4", "video/mp4"),
"canonical order must match: audio/mp4 -> video/mp4"
);
}
#[test]
fn is_mime_alias_flac_reverse_order() {
assert!(
is_mime_alias("audio/x-flac", "audio/flac"),
"reverse order must match: audio/x-flac -> audio/flac"
);
}
// --- MP4 container vs. audio/video server hints ---
// Header bytes: a 4-byte size field, ASCII "ftyp", brand "isom", four more
// bytes, then the ASCII brands "isom", "iso2", "avc1", "mp41".
// The audio/mp4 <-> video/mp4 alias must suppress the category-conflict error.
#[test]
fn detect_mime_mp4_container_with_audio_mp4_server_hint() {
let mp4_header: Vec<u8> = vec![
0x00, 0x00, 0x00, 0x20, 0x66, 0x74, 0x79, 0x70, 0x69, 0x73, 0x6F, 0x6D, 0x00, 0x00, 0x02, 0x00, 0x69, 0x73, 0x6F, 0x6D, 0x69, 0x73, 0x6F, 0x32, 0x61, 0x76, 0x63, 0x31, 0x6D, 0x70, 0x34, 0x31, ];
let result = detect_mime(&mp4_header, Some("audio/mp4"));
assert!(
result.is_ok(),
"audio/mp4 <> video/mp4 alias should prevent cross-category rejection, got: {:?}",
result.err()
);
let detected = result.unwrap();
assert_eq!(
detected.mime_type, "video/mp4",
"magic bytes should win (video/mp4), not server hint"
);
assert_eq!(detected.source, DetectionSource::MagicBytes);
}
// Same header bytes as above, with a hint that matches the detected type.
#[test]
fn detect_mime_mp4_container_with_video_mp4_server_matches() {
let mp4_header: Vec<u8> = vec![
0x00, 0x00, 0x00, 0x20, 0x66, 0x74, 0x79, 0x70, 0x69, 0x73, 0x6F, 0x6D, 0x00, 0x00,
0x02, 0x00, 0x69, 0x73, 0x6F, 0x6D, 0x69, 0x73, 0x6F, 0x32, 0x61, 0x76, 0x63, 0x31,
0x6D, 0x70, 0x34, 0x31,
];
let result = detect_mime(&mp4_header, Some("video/mp4"));
assert!(result.is_ok());
let detected = result.unwrap();
assert_eq!(detected.mime_type, "video/mp4");
}
// Everything except ASCII alphanumerics and '-' is stripped.
#[test]
fn sanitize_extension_rejects_traversal() {
assert_eq!(sanitize_extension("../etc/passwd"), "etcpasswd");
assert_eq!(sanitize_extension("png;rm -rf /"), "pngrm-rf");
}
// --- SVG heuristic edge cases ---
// "<svg" followed by '-' is a different element name, not an SVG root.
#[test]
fn svg_heuristic_rejects_svg_prefixed_tag() {
let data = b"<svg-report><item>data</item></svg-report>";
let result = detect_mime(data, None);
assert!(
result.is_err(),
"<svg-report> should NOT be detected as SVG"
);
}
#[test]
fn svg_heuristic_leading_whitespace() {
let data = b" \t\n <svg xmlns=\"http://www.w3.org/2000/svg\"></svg>";
let result = detect_mime(data, None).unwrap();
assert_eq!(result.mime_type, "image/svg+xml");
assert_eq!(result.extension, "svg");
assert_eq!(result.source, DetectionSource::MagicBytes);
}
// A bare <svg> tag is sufficient; the namespace attribute is not required.
#[test]
fn svg_heuristic_xml_declaration_plus_svg_no_xmlns() {
let data = b"<?xml version=\"1.0\"?>\n<svg></svg>";
let result = detect_mime(data, None).unwrap();
assert_eq!(result.mime_type, "image/svg+xml");
assert_eq!(result.extension, "svg");
}
// Pins a known limitation: the heuristic does not parse XML, so "<svg>"
// inside a comment still triggers detection.
#[test]
fn svg_heuristic_false_positive_svg_in_xml_comment() {
let data = b"<?xml version=\"1.0\"?><!-- comment with <svg> in it --><root/>";
let result = detect_mime(data, None);
assert!(
result.is_ok(),
"heuristic fires on <svg> inside XML comment (known false positive)"
);
let detected = result.unwrap();
assert_eq!(detected.mime_type, "image/svg+xml");
}
// Pins a known limitation: "<data>" plus 610 filler bytes pushes the <svg
// tag past the 512-byte window the heuristic inspects.
#[test]
fn svg_heuristic_false_negative_beyond_512_byte_window() {
let mut data = Vec::new();
data.extend_from_slice(b"<data>");
data.extend(std::iter::repeat_n(b'x', 610));
data.extend_from_slice(b"</data><svg xmlns=\"http://www.w3.org/2000/svg\"></svg>");
let result = detect_mime(&data, None);
assert!(
result.is_err(),
"SVG beyond 512-byte window should not be detected (false negative)"
);
}
#[test]
fn svg_heuristic_single_quotes_xmlns() {
let data = b"<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'></svg>";
let result = detect_mime(data, None).unwrap();
assert_eq!(result.mime_type, "image/svg+xml");
assert_eq!(result.extension, "svg");
}
// Pins a known limitation: tag matching is case-sensitive, and the namespace
// alone is not enough without an "<?xml" prefix.
#[test]
fn svg_heuristic_uppercase_tag_not_detected() {
let data = b"<SVG xmlns=\"http://www.w3.org/2000/svg\"></SVG>";
let result = detect_mime(data, None);
assert!(
result.is_err(),
"uppercase <SVG> not detected by case-sensitive heuristic"
);
}
#[test]
fn svg_heuristic_self_closing_no_xmlns() {
let data = b"<svg/>";
let result = detect_mime(data, None).unwrap();
assert_eq!(result.mime_type, "image/svg+xml");
assert_eq!(result.extension, "svg");
}
// The 0xFF byte makes the sample invalid UTF-8, so the text heuristic must
// not run on it even though it starts with "<svg".
#[test]
fn svg_heuristic_binary_with_svg_utf8_prefix() {
let mut data = Vec::new();
data.extend_from_slice(b"<svg");
data.push(0xFF); data.extend(std::iter::repeat_n(0x00u8, 20));
let result = detect_mime(&data, None);
assert!(
result.is_err(),
"binary data with <svg prefix but invalid UTF-8 should not be detected as SVG"
);
}
// Five bytes is the minimum sample length at which the heuristic runs.
#[test]
fn svg_heuristic_exactly_five_bytes() {
let data = b"<svg>";
assert_eq!(data.len(), 5);
let result = detect_mime(data, None).unwrap();
assert_eq!(result.mime_type, "image/svg+xml");
assert_eq!(result.extension, "svg");
assert_eq!(result.source, DetectionSource::MagicBytes);
}
}