use std::borrow::Cow;
use std::collections::HashMap;
use qubit_codec::{
CIntegerLiteralCodec,
CStringLiteralCodec,
CodecError,
HexCodec,
};
use roxmltree::{
Document,
Node,
};
use crate::{
MagicValueType,
MimeDetectionPolicy,
MimeError,
MimeGlob,
MimeMagic,
MimeMagicMatcher,
MimeResult,
MimeType,
MimeTypeBuilder,
};
/// An in-memory index of MIME type definitions supporting lookup by name and detection by
/// file name or file content.
#[derive(Debug, Clone)]
pub struct MimeRepository {
    /// Every registered MIME type; the other fields refer to entries by index into this vector.
    mime_types: Vec<MimeType>,
    /// Lowercased canonical names and aliases mapped to an index in `mime_types`.
    name_map: HashMap<String, usize>,
    /// Globs containing no wildcard characters, keyed by the lowercased pattern.
    literal_globs: HashMap<String, Vec<GlobEntry>>,
    /// `*.ext` globs, keyed by the lowercased extension (without the leading `*.`).
    extension_globs: HashMap<String, Vec<GlobEntry>>,
    /// All remaining globs, which are tested one by one against each file name.
    other_globs: Vec<GlobEntry>,
    /// Largest `max_test_bytes()` value of any registered magic rule.
    max_test_bytes: usize,
}
/// A glob registered for one MIME type, as stored in the repository's glob indexes.
#[derive(Debug, Clone)]
struct GlobEntry {
    /// The glob itself (pattern, weight and case-sensitivity flag).
    glob: MimeGlob,
    /// Index of the owning MIME type in `MimeRepository::mime_types`.
    mime_index: usize,
}
impl MimeRepository {
    /// Builds a repository from a shared-mime-info style XML document.
    ///
    /// Any `<!DOCTYPE>` declaration is removed before parsing. Only `<mime-type>` children of
    /// the `<mime-info>` root element are read; other child elements are ignored.
    ///
    /// # Errors
    ///
    /// Fails when the XML is malformed, when the root element is not `<mime-info>`, or when
    /// any `<mime-type>` definition fails to parse.
    pub fn from_xml(xml: &str) -> MimeResult<Self> {
        let cleaned = strip_doctype(xml);
        let document = Document::parse(&cleaned)?;
        let root = document.root_element();
        let root_name = root.tag_name().name();
        if root_name != "mime-info" {
            return Err(MimeError::invalid_element(
                root_name,
                "root element must be <mime-info>",
            ));
        }
        let mut repository = Self::empty();
        let definitions = root
            .children()
            .filter(|node| node.is_element() && node.tag_name().name() == "mime-type");
        for definition in definitions {
            repository.add_mime_type(parse_mime_type(definition)?);
        }
        Ok(repository)
    }

    /// Creates a repository with no MIME types registered.
    pub fn empty() -> Self {
        Self {
            mime_types: Vec::new(),
            name_map: HashMap::new(),
            literal_globs: HashMap::new(),
            extension_globs: HashMap::new(),
            other_globs: Vec::new(),
            max_test_bytes: 0,
        }
    }

    /// Returns every registered MIME type in registration order.
    pub fn all(&self) -> &[MimeType] {
        self.mime_types.as_slice()
    }

    /// Looks up a MIME type by canonical name or alias; the name is lowercased before lookup.
    pub fn get(&self, name: &str) -> Option<&MimeType> {
        let index = *self.name_map.get(&normalize_mime_name(name))?;
        self.mime_types.get(index)
    }

    /// Returns the largest `max_test_bytes()` value over all registered magic rules.
    pub fn max_test_bytes(&self) -> usize {
        self.max_test_bytes
    }

    /// Detects MIME types from a file name or path.
    ///
    /// Only the last path component (after the final `/` or `\`) is considered. Literal and
    /// `*.ext` globs are found through lowercased index keys and then confirmed with the glob's
    /// own matcher; all other globs are tested directly. Only the matches with the highest
    /// weight, and among those the longest pattern, are returned.
    pub fn detect_by_filename(&self, filename: &str) -> Vec<&MimeType> {
        let base_name = filename_from_path(filename);
        if base_name.is_empty() {
            return Vec::new();
        }
        let key = base_name.to_lowercase();
        let mut best = GlobDetectionResult::new();
        if let Some(literal_entries) = self.literal_globs.get(&key) {
            best.add_matching_entries(literal_entries, base_name);
        }
        extension_suffixes(&key)
            .into_iter()
            .filter_map(|suffix| self.extension_globs.get(suffix))
            .for_each(|bucket| best.add_matching_entries(bucket, base_name));
        self.other_globs
            .iter()
            .filter(|entry| entry.glob.matches(base_name))
            .for_each(|entry| best.compare_add(entry));
        best.entries
            .into_iter()
            .filter_map(|entry| self.mime_types.get(entry.mime_index))
            .collect()
    }

    /// Detects MIME types from leading file content.
    ///
    /// Returns every MIME type with a matching magic rule at the highest matching priority.
    pub fn detect_by_content(&self, bytes: &[u8]) -> Vec<&MimeType> {
        let mut best = MagicDetectionResult::new();
        for mime_type in &self.mime_types {
            for magic in mime_type.magics() {
                let priority = magic.priority();
                // A rule below the current best priority can never be kept, so skip matching it.
                let can_compete = priority >= best.best_priority;
                if can_compete && magic.matches(bytes) {
                    best.compare_add(priority, mime_type);
                }
            }
        }
        best.mime_types
    }

    /// Detects MIME types using both the file name and the content.
    ///
    /// Under [`MimeDetectionPolicy::PreferFilename`], a single unambiguous file-name match is
    /// returned without inspecting `bytes`. Otherwise the two detections are combined: an empty
    /// side defers to the other, a MIME type found by both wins, and failing that the first
    /// content match is used.
    pub fn detect(
        &self,
        filename: &str,
        bytes: &[u8],
        policy: MimeDetectionPolicy,
    ) -> Vec<&MimeType> {
        let by_name = self.detect_by_filename(filename);
        let unambiguous = by_name.len() == 1;
        if unambiguous && policy == MimeDetectionPolicy::PreferFilename {
            return by_name;
        }
        merge_results(by_name, self.detect_by_content(bytes))
    }

    /// Registers `mime_type` in every index and appends it to the type list.
    fn add_mime_type(&mut self, mime_type: MimeType) {
        // The new type will live at the current end of `mime_types`.
        let slot = self.mime_types.len();
        self.index_names(slot, &mime_type);
        self.index_globs(slot, &mime_type);
        self.index_magics(&mime_type);
        self.mime_types.push(mime_type);
    }

    /// Maps the lowercased canonical name and every lowercased alias to `mime_index`.
    fn index_names(&mut self, mime_index: usize, mime_type: &MimeType) {
        self.name_map
            .insert(normalize_mime_name(mime_type.name()), mime_index);
        for alias in mime_type.aliases() {
            let key = normalize_mime_name(alias);
            self.name_map.insert(key, mime_index);
        }
    }

    /// Files each glob into the extension index, the literal index, or the fallback list.
    fn index_globs(&mut self, mime_index: usize, mime_type: &MimeType) {
        for glob in mime_type.globs() {
            let pattern = glob.pattern();
            let bucket = if let Some(extension) = extension_pattern(pattern) {
                self.extension_globs
                    .entry(extension.to_lowercase())
                    .or_default()
            } else if is_literal_pattern(pattern) {
                self.literal_globs
                    .entry(pattern.to_lowercase())
                    .or_default()
            } else {
                &mut self.other_globs
            };
            bucket.push(GlobEntry {
                glob: glob.clone(),
                mime_index,
            });
        }
    }

    /// Raises `max_test_bytes` to cover every magic rule of `mime_type`.
    fn index_magics(&mut self, mime_type: &MimeType) {
        for magic in mime_type.magics() {
            let needed = magic.max_test_bytes();
            if needed > self.max_test_bytes {
                self.max_test_bytes = needed;
            }
        }
    }
}
/// Accumulates the best-ranked glob matches while a file name is checked against the indexes.
#[derive(Debug)]
struct GlobDetectionResult<'a> {
    /// Glob weight of the currently kept entries.
    best_weight: u16,
    /// Pattern length (`pattern().len()`) of the currently kept entries.
    best_length: usize,
    /// Entries tied for the best weight and pattern length seen so far.
    entries: Vec<&'a GlobEntry>,
}
impl<'a> GlobDetectionResult<'a> {
    /// Creates an empty result; the first entry offered always becomes the initial best.
    fn new() -> Self {
        Self {
            best_weight: 0,
            best_length: 0,
            entries: Vec::new(),
        }
    }

    /// Offers every entry of `entries` whose glob matches `filename` to `compare_add`.
    fn add_matching_entries(&mut self, entries: &'a [GlobEntry], filename: &str) {
        for entry in entries.iter().filter(|entry| entry.glob.matches(filename)) {
            self.compare_add(entry);
        }
    }

    /// Keeps `entry` if it ranks at least as well as the current best matches.
    ///
    /// A higher glob weight wins. Between equal weights, the longer pattern wins. Entries tied on
    /// both are kept together, but at most once per MIME type. Without that check, a type matched
    /// by several equally ranked globs (for example two case variants of one extension) would be
    /// reported twice, and callers that test for a single unambiguous match would be misled.
    fn compare_add(&mut self, entry: &'a GlobEntry) {
        let weight = entry.glob.weight();
        let length = entry.glob.pattern().len();
        let outranks = self.entries.is_empty()
            || weight > self.best_weight
            || (weight == self.best_weight && length > self.best_length);
        if outranks {
            self.entries.clear();
            self.entries.push(entry);
            self.best_weight = weight;
            self.best_length = length;
        } else if weight == self.best_weight
            && length == self.best_length
            && !self
                .entries
                .iter()
                .any(|kept| kept.mime_index == entry.mime_index)
        {
            self.entries.push(entry);
        }
    }
}
/// Removes the first `<!DOCTYPE ...>` declaration from `xml`.
///
/// Returns the input unchanged (borrowed) when there is no declaration or it is unterminated.
/// The declaration ends at its first `>` unless an internal subset (`[` before that `>`) is
/// present. In that case it ends at the `]`, optional whitespace, `>` sequence that closes the
/// subset. The search is confined to the declaration itself, so later `]>` text (for example
/// the end of a CDATA section) is never mistaken for the end of the DOCTYPE.
fn strip_doctype(xml: &str) -> Cow<'_, str> {
    let Some(start) = xml.find("<!DOCTYPE") else {
        return Cow::Borrowed(xml);
    };
    // `find` returns a char boundary, so this slice cannot panic.
    let Some(length) = doctype_declaration_len(&xml[start..]) else {
        return Cow::Borrowed(xml);
    };
    // The declaration ends with the ASCII `>`, so `end` is also a char boundary.
    let end = start + length;
    let mut cleaned = String::with_capacity(xml.len() - length);
    cleaned.push_str(&xml[..start]);
    cleaned.push_str(&xml[end..]);
    Cow::Owned(cleaned)
}

/// Returns the byte length of the DOCTYPE declaration at the start of `decl`, including its
/// closing `>`, or `None` when the declaration is not terminated.
fn doctype_declaration_len(decl: &str) -> Option<usize> {
    let first_close = decl.find('>')?;
    let Some(open) = decl[..first_close].find('[') else {
        // No internal subset: the first `>` closes the declaration.
        return Some(first_close + 1);
    };
    // Internal subset: per XML 1.0 it is closed by `]`, optional whitespace, then `>`.
    let mut search_from = open + 1;
    while let Some(relative) = decl[search_from..].find(']') {
        let close = search_from + relative;
        let after = decl[close + 1..].trim_start_matches([' ', '\t', '\r', '\n']);
        if after.starts_with('>') {
            return Some(decl.len() - after.len() + 1);
        }
        search_from = close + 1;
    }
    None
}
/// Accumulates the MIME types whose magic rules matched at the highest priority seen so far.
#[derive(Debug)]
struct MagicDetectionResult<'a> {
    /// Priority of the currently kept MIME types.
    best_priority: u16,
    /// MIME types tied at `best_priority` (no duplicates by `==`), in first-match order.
    mime_types: Vec<&'a MimeType>,
}
impl<'a> MagicDetectionResult<'a> {
    /// Creates an empty result; the first match offered always becomes the initial best.
    fn new() -> Self {
        Self {
            best_priority: 0,
            mime_types: Vec::new(),
        }
    }

    /// Records a magic match of `mime_type` at `priority`.
    ///
    /// A strictly higher priority (or the first match) replaces everything kept so far. An equal
    /// priority appends the type unless an equal type is already present. Lower priorities are
    /// ignored.
    fn compare_add(&mut self, priority: u16, mime_type: &'a MimeType) {
        let first_or_better = self.mime_types.is_empty() || priority > self.best_priority;
        if first_or_better {
            self.best_priority = priority;
            self.mime_types.clear();
            self.mime_types.push(mime_type);
            return;
        }
        let ties = priority == self.best_priority;
        if ties && !self.mime_types.contains(&mime_type) {
            self.mime_types.push(mime_type);
        }
    }
}
/// Parses a `<mime-type>` element into a [`MimeType`].
///
/// Recognised children are `<comment>` (a description keyed by its `xml:lang` language, or
/// `""` when absent), `<alias>`, `<sub-class-of>`, `<glob>` and `<magic>`. Any other child
/// element is ignored.
///
/// # Errors
///
/// Fails when the `type` attribute is missing or empty, or when a recognised child element
/// fails to parse.
fn parse_mime_type(node: Node<'_, '_>) -> MimeResult<MimeType> {
    // roxmltree resolves namespace prefixes, so `xml:lang` is stored as ("<XML namespace>",
    // "lang"). A plain "xml:lang" lookup has no namespace and never matches, which would file
    // every translated comment under the default language.
    const XML_NAMESPACE: &str = "http://www.w3.org/XML/1998/namespace";

    let name = required_attr(node, "type")?.to_owned();
    let mut builder = MimeTypeBuilder::new(&name);
    for child in node.children().filter(Node::is_element) {
        match child.tag_name().name() {
            "comment" => {
                let language = child.attribute((XML_NAMESPACE, "lang")).unwrap_or("");
                builder = builder.description(language, child.text().unwrap_or(""));
            }
            "alias" => builder = builder.alias(required_attr(child, "type")?),
            "sub-class-of" => builder = builder.super_type(required_attr(child, "type")?),
            "glob" => builder = builder.glob(parse_glob(child)?),
            "magic" => builder = builder.magic(parse_magic(child)?),
            _ => {}
        }
    }
    Ok(builder.build())
}
/// Parses a `<glob>` element.
///
/// `pattern` is required. `weight` defaults to `MimeGlob::DEFAULT_WEIGHT` and must lie within
/// `MimeGlob::MIN_WEIGHT..=MimeGlob::MAX_WEIGHT`. `case-sensitive` defaults to `false`.
fn parse_glob(node: Node<'_, '_>) -> MimeResult<MimeGlob> {
    let glob_pattern = required_attr(node, "pattern")?;
    let (lowest, highest) = (MimeGlob::MIN_WEIGHT, MimeGlob::MAX_WEIGHT);
    let glob_weight = optional_u16_attr(node, "weight", lowest, highest, MimeGlob::DEFAULT_WEIGHT)?;
    let exact_case = optional_bool_attr(node, "case-sensitive", false)?;
    MimeGlob::new(glob_pattern, glob_weight, exact_case)
}
fn parse_magic(node: Node<'_, '_>) -> MimeResult<MimeMagic> {
let priority = optional_u16_attr(
node,
"priority",
MimeMagic::MIN_PRIORITY,
MimeMagic::MAX_PRIORITY,
MimeMagic::DEFAULT_PRIORITY,
)?;
let matchers: MimeResult<Vec<_>> = node
.children()
.filter(Node::is_element)
.filter(|child| child.tag_name().name() == "match")
.map(parse_matcher)
.collect();
let matchers = matchers?;
if matchers.is_empty() {
return Err(MimeError::invalid_element(
"magic",
"magic must contain at least one match",
));
}
Ok(MimeMagic::new(priority, matchers))
}
/// Parses a `<match>` element, recursing into nested `<match>` children.
///
/// `type`, `offset` and `value` are required; `mask` is optional. Attributes are validated in
/// that order, followed by the nested matchers, and the first failure is returned.
fn parse_matcher(node: Node<'_, '_>) -> MimeResult<MimeMagicMatcher> {
    let type_name = required_attr(node, "type")?;
    let Some(value_type) = MagicValueType::from_name(type_name) else {
        return Err(MimeError::invalid_attr("match", "type", type_name, "unknown type"));
    };
    let (offset_begin, offset_end) = parse_offset(required_attr(node, "offset")?)?;
    let value = parse_value(value_type, required_attr(node, "value")?)?;
    let mask = node
        .attribute("mask")
        .map(|raw_mask| parse_mask(value_type, raw_mask))
        .transpose()?;
    let nested = node
        .children()
        .filter(|child| child.is_element() && child.tag_name().name() == "match");
    let mut sub_matchers = Vec::new();
    for child in nested {
        sub_matchers.push(parse_matcher(child)?);
    }
    MimeMagicMatcher::new(
        value_type,
        offset_begin,
        offset_end,
        value,
        mask,
        sub_matchers,
    )
}
/// Returns the value of attribute `name`, treating an absent or empty attribute as missing.
fn required_attr<'a>(node: Node<'a, '_>, name: &str) -> MimeResult<&'a str> {
    match node.attribute(name) {
        Some(found) if !found.is_empty() => Ok(found),
        _ => Err(MimeError::invalid_attr(
            node.tag_name().name(),
            name,
            "",
            "required attribute is missing",
        )),
    }
}
fn optional_u16_attr(
node: Node<'_, '_>,
name: &str,
min: u16,
max: u16,
default: u16,
) -> MimeResult<u16> {
let Some(value) = node.attribute(name) else {
return Ok(default);
};
let parsed = value.parse::<u16>().map_err(|error| {
MimeError::invalid_attr(node.tag_name().name(), name, value, error.to_string())
})?;
if parsed < min || parsed > max {
return Err(MimeError::invalid_attr(
node.tag_name().name(),
name,
value,
format!("value must be in {min}..={max}"),
));
}
Ok(parsed)
}
/// Reads attribute `name` as exactly `"true"` or `"false"`, returning `default` when absent.
fn optional_bool_attr(node: Node<'_, '_>, name: &str, default: bool) -> MimeResult<bool> {
    let Some(raw) = node.attribute(name) else {
        return Ok(default);
    };
    // `bool::from_str` accepts precisely "true" and "false", nothing else.
    raw.parse::<bool>().map_err(|_| {
        MimeError::invalid_attr(
            node.tag_name().name(),
            name,
            raw,
            "expected true or false",
        )
    })
}
/// Parses a magic `offset` attribute of the form `N` or `BEGIN:END` into an inclusive pair.
///
/// A single number `N` yields `(N, N)`. `BEGIN` must not exceed `END`.
fn parse_offset(value: &str) -> MimeResult<(usize, usize)> {
    let (begin_text, end_text) = value.split_once(':').unwrap_or((value, value));
    let begin = parse_usize(begin_text, "offset")?;
    let end = parse_usize(end_text, "offset")?;
    if begin <= end {
        Ok((begin, end))
    } else {
        Err(MimeError::invalid_attr(
            "match",
            "offset",
            value,
            "offset begin must not exceed offset end",
        ))
    }
}
fn parse_usize(value: &str, attribute: &str) -> MimeResult<usize> {
value.parse::<usize>().map_err(|error| {
MimeError::invalid_attr(
"match",
attribute,
value,
format!("invalid integer: {error}"),
)
})
}
/// Decodes a `<match value="...">` attribute into the bytes to compare.
///
/// `string` matches use a C string literal; every other type uses a C integer literal.
fn parse_value(value_type: MagicValueType, value: &str) -> MimeResult<Vec<u8>> {
    if matches!(value_type, MagicValueType::String) {
        parse_c_string_bytes(value)
    } else {
        parse_numeric_bytes(value_type, value)
    }
}
fn parse_mask(value_type: MagicValueType, value: &str) -> MimeResult<Vec<u8>> {
match value_type {
MagicValueType::String => parse_hex_bytes(value),
_ => parse_numeric_bytes(value_type, value),
}
}
fn parse_c_string_bytes(value: &str) -> MimeResult<Vec<u8>> {
CStringLiteralCodec::new().decode(value).map_err(|error| {
MimeError::invalid_attr(
"match",
"value",
value,
format!("invalid C string literal: {error}"),
)
})
}
/// Decodes a C integer literal into big-endian bytes of the width `value_type` requires.
///
/// The `as` casts keep only the low-order bytes, so literals wider than the width are
/// truncated rather than rejected. Decode failures are reported against the `value` attribute.
///
/// # Panics
///
/// Panics if `value_type` has no numeric width, or a width other than 1, 2 or 4.
fn parse_numeric_bytes(value_type: MagicValueType, value: &str) -> MimeResult<Vec<u8>> {
    let number = match CIntegerLiteralCodec::new().decode(value) {
        Ok(number) => number,
        Err(error) => {
            return Err(MimeError::invalid_attr(
                "match",
                "value",
                value,
                format!("invalid C integer literal: {error}"),
            ))
        }
    };
    let width = value_type
        .numeric_width()
        .expect("numeric parser should only receive numeric magic types");
    let encoded = match width {
        1 => vec![number as u8],
        2 => (number as u16).to_be_bytes().to_vec(),
        4 => (number as u32).to_be_bytes().to_vec(),
        _ => unreachable!("unsupported numeric magic width"),
    };
    Ok(encoded)
}
fn parse_hex_bytes(value: &str) -> MimeResult<Vec<u8>> {
HexCodec::new()
.with_prefix("0x")
.with_ignore_prefix_case(true)
.decode(value)
.map_err(|error| match error {
CodecError::MissingPrefix { .. } => {
MimeError::invalid_attr("match", "mask", value, "string mask must start with 0x")
}
other => MimeError::invalid_attr(
"match",
"mask",
value,
format!("invalid hex byte: {other}"),
),
})
}
/// Produces the case-folded key used to index MIME type names and aliases.
fn normalize_mime_name(name: &str) -> String {
    str::to_lowercase(name)
}
/// Returns the final path component, treating both `/` and `\` as separators.
///
/// A path that ends in a separator yields an empty string.
fn filename_from_path(path: &str) -> &str {
    // Both separators are single-byte ASCII, so `index + 1` is always a char boundary.
    match path.rfind(['/', '\\']) {
        Some(index) => &path[index + 1..],
        None => path,
    }
}
/// Lists every non-empty suffix that follows a `.` in `filename`, longest first.
///
/// For example, `"a.tar.gz"` yields `["tar.gz", "gz"]`.
fn extension_suffixes(filename: &str) -> Vec<&str> {
    let mut suffixes = Vec::new();
    // `.` is ASCII and never occurs inside a multi-byte UTF-8 sequence, so slicing right after
    // it is always on a char boundary.
    for (position, byte) in filename.bytes().enumerate() {
        if byte == b'.' {
            let suffix = &filename[position + 1..];
            if !suffix.is_empty() {
                suffixes.push(suffix);
            }
        }
    }
    suffixes
}
/// Returns the extension of a simple `*.ext` glob, or `None` for any other pattern.
///
/// The extension must be non-empty and free of glob metacharacters.
fn extension_pattern(pattern: &str) -> Option<&str> {
    const METACHARACTERS: [char; 8] = ['*', '?', '{', '}', '!', '[', ']', '^'];
    pattern
        .strip_prefix("*.")
        .filter(|extension| !extension.is_empty() && !extension.contains(METACHARACTERS))
}
/// Reports whether `pattern` contains no glob metacharacters and so matches only itself.
fn is_literal_pattern(pattern: &str) -> bool {
    const METACHARACTERS: [char; 8] = ['*', '?', '{', '}', '!', '[', ']', '^'];
    pattern.find(METACHARACTERS).is_none()
}
/// Combines file-name and content detections into at most one MIME type.
///
/// If one side is empty, the first entry of the other side is used. Otherwise the first
/// file-name candidate whose name also appears among the content matches wins, falling back
/// to the first content match.
fn merge_results<'a>(
    from_filename: Vec<&'a MimeType>,
    from_content: Vec<&'a MimeType>,
) -> Vec<&'a MimeType> {
    let first_content = from_content.first().copied();
    let chosen = match (from_filename.is_empty(), from_content.is_empty()) {
        (true, _) => first_content,
        (false, true) => from_filename.first().copied(),
        (false, false) => from_filename
            .iter()
            .copied()
            .find(|candidate| {
                from_content
                    .iter()
                    .any(|content| content.name() == candidate.name())
            })
            .or(first_content),
    };
    chosen.into_iter().collect()
}