use std::collections::HashMap;
use std::error::Error as StdError;
use std::fmt;
use std::ops::Range;
use std::str::from_utf8;
use memchr::{memchr, memmem::find};
/// Content-sniffing database built from a shared-mime-info `mime.cache` buffer.
///
/// Only the parent list and the magic list of the cache are parsed. The buffer itself is
/// retained because MIME type names and matchlet records are read from it by offset while
/// matching.
pub struct Database<C> {
    /// The raw cache bytes.
    cache: C,
    /// Every type that takes part in matching; `text_idx`, `binary_idx` and the entries of
    /// `children` are indices into this slice.
    types: Box<[FrozenType]>,
    /// Match entries of all types, concatenated; each type owns the sub-range
    /// `FrozenType::matches`.
    matches: FrozenMatches,
    /// Child type indices of all types, concatenated; each type owns the sub-range
    /// `FrozenType::children`.
    children: FrozenChildren,
    /// Index of `text/plain`, the fallback result for data without NUL bytes.
    text_idx: u32,
    /// Index of `application/octet-stream`, the fallback result for data containing a NUL byte.
    binary_idx: u32,
}
/// Mutable per-type record used while the database is being assembled; `Database::freeze`
/// turns it into a `FrozenType`.
#[derive(Debug, Default)]
struct Type {
    /// Magic match entries naming this type, in the order they appear in the magic list.
    matches: Vec<Match>,
    /// Child types, identified by the cache offset of their MIME type name (converted to type
    /// indices when frozen).
    children: Vec<u32>,
}
/// One magic match entry: a run of top-level matchlet records inside the cache.
#[derive(Debug)]
struct Match {
    /// Number of consecutive 32-byte matchlet records.
    cnt: u32,
    /// Cache offset of the first matchlet record.
    off: u32,
}
/// Immutable per-type record stored in `Database::types`.
#[derive(Debug)]
struct FrozenType {
    /// Cache offset of the NUL-terminated MIME type name.
    mime_type_off: u32,
    /// Range into `Database::matches` holding this type's match entries.
    matches: Range<u32>,
    /// Range into `Database::children`; those entries are indices into `Database::types`.
    children: Range<u32>,
}
/// Match entries of all types, concatenated; sliced by `FrozenType::matches`.
type FrozenMatches = Box<[Match]>;
/// Child type indices of all types, concatenated; sliced by `FrozenType::children`.
type FrozenChildren = Box<[u32]>;
impl<C> Database<C>
where
C: AsRef<[u8]>,
{
pub fn open(cache: C) -> Result<Self, Error> {
let (parents, types) = parse_offsets(cache.as_ref()).ok_or(Error::MalformedOffsets)?;
let mut parents = parse_parents(cache.as_ref(), parents).ok_or(Error::MalformedParents)?;
let mut types = parse_types(cache.as_ref(), types).ok_or(Error::MalformedTypes)?;
let (text_off, binary_off) = add_parents(cache.as_ref(), &mut parents, &mut types)?;
add_children(&parents, &mut types);
Ok(Self::freeze(cache, types, text_off, binary_off))
}
fn freeze(cache: C, types: HashMap<u32, Type>, text_off: u32, binary_off: u32) -> Self {
let mut matches = Vec::new();
let mut children = Vec::new();
let types = types
.into_iter()
.map(|(mime_type_off, r#type)| {
let matches_start = u32::try_from(matches.len()).unwrap();
let matches_end = matches_start + u32::try_from(r#type.matches.len()).unwrap();
matches.extend(r#type.matches);
let children_start = u32::try_from(children.len()).unwrap();
let children_end = children_start + u32::try_from(r#type.children.len()).unwrap();
children.extend(r#type.children);
FrozenType {
mime_type_off,
matches: matches_start..matches_end,
children: children_start..children_end,
}
})
.collect::<Box<[_]>>();
let mime_type_off_to_type_idx = types
.iter()
.enumerate()
.map(|(idx, r#type)| (r#type.mime_type_off, u32::try_from(idx).unwrap()))
.collect::<HashMap<_, _>>();
for type_idx in &mut children {
*type_idx = mime_type_off_to_type_idx[type_idx];
}
Self {
cache,
types,
matches: matches.into(),
children: children.into(),
text_idx: mime_type_off_to_type_idx[&text_off],
binary_idx: mime_type_off_to_type_idx[&binary_off],
}
}
pub fn r#match<'a>(&'a self, bytes: &[u8]) -> Result<&'a str, Error> {
let is_text = memchr(0, bytes).is_none();
let text_type = &self.types[self.text_idx as usize];
let binary_type = &self.types[self.binary_idx as usize];
let mime_type_off = if is_text
&& let Some(mime_type_off) = self.match_children(&text_type.children, bytes)?
{
mime_type_off
} else if let Some(mime_type_off) = self.match_children(&binary_type.children, bytes)? {
mime_type_off
} else if is_text {
text_type.mime_type_off
} else {
binary_type.mime_type_off
};
resolve_mime_type(self.cache.as_ref(), mime_type_off).ok_or(Error::InvalidMimeType)
}
fn match_children(&self, children: &Range<u32>, bytes: &[u8]) -> Result<Option<u32>, Error> {
let children = &self.children[children.start as usize..children.end as usize];
for type_idx in children {
let r#type = &self.types[*type_idx as usize];
let matches = &self.matches[r#type.matches.start as usize..r#type.matches.end as usize];
debug_assert!(!matches.is_empty());
for r#match in matches {
if self
.try_match(r#match.cnt, r#match.off, bytes)
.ok_or(Error::MalformedMatch)?
{
let mime_type_off = self
.match_children(&r#type.children, bytes)?
.unwrap_or(r#type.mime_type_off);
return Ok(Some(mime_type_off));
}
}
}
Ok(None)
}
fn try_match(&self, cnt: u32, off: u32, bytes: &[u8]) -> Option<bool> {
let cnt = cnt as usize;
let off = off as usize;
let matchlets = self.cache.as_ref().get(off..)?.get(..32 * cnt)?;
for idx in 0..cnt {
let matchlet = &matchlets[32 * idx..][..32];
if self.try_matchlet(matchlet, bytes)? {
let cnt = parse_u32(matchlet, 24);
let off = parse_u32(matchlet, 28);
if cnt == 0 || self.try_match(cnt, off, bytes)? {
return Some(true);
}
}
}
Some(false)
}
#[inline]
fn try_matchlet(&self, matchlet: &[u8], bytes: &[u8]) -> Option<bool> {
let range_start = parse_u32(matchlet, 0) as usize;
let range_length = parse_u32(matchlet, 4) as usize;
let data_length = parse_u32(matchlet, 12) as usize;
let data_offset = parse_u32(matchlet, 16) as usize;
let mask_offset = parse_u32(matchlet, 20) as usize;
let data = self.cache.as_ref().get(data_offset..)?.get(..data_length)?;
let mask = if mask_offset != 0 {
Some(self.cache.as_ref().get(mask_offset..)?.get(..data_length)?)
} else {
None
};
for pos in range_start..range_start + range_length {
if let Some(bytes) = bytes.get(pos..pos + data_length) {
if let Some(mask) = mask {
if data
.iter()
.zip(mask)
.zip(bytes)
.all(|((data, mask), byte)| data & mask == byte & mask)
{
return Some(true);
}
} else {
if data == bytes {
return Some(true);
}
}
} else {
break;
}
}
Some(false)
}
}
fn add_parents(
cache: &[u8],
parents: &mut HashMap<u32, Vec<u32>>,
types: &mut HashMap<u32, Type>,
) -> Result<(u32, u32), Error> {
let text_off =
find_or_insert_mime_type(cache, types, "text/plain\0").ok_or(Error::MissingTextPlain)?;
let binary_off = find_or_insert_mime_type(cache, types, "application/octet-stream\0")
.ok_or(Error::MissingTextPlain)?;
types.retain(|mime_type_off, _type| {
if *mime_type_off == text_off || *mime_type_off == binary_off {
true
} else if let Some(mime_type) = resolve_mime_type(cache, *mime_type_off) {
let parent_mime_type_off = if mime_type.starts_with("text/") {
text_off
} else {
binary_off
};
parents
.entry(*mime_type_off)
.or_default()
.push(parent_mime_type_off);
true
} else {
false
}
});
Ok((text_off, binary_off))
}
fn find_or_insert_mime_type(
cache: &[u8],
types: &mut HashMap<u32, Type>,
mime_type0: &str,
) -> Option<u32> {
let mime_type = &mime_type0[..mime_type0.len() - 1];
let mime_type_off = types
.keys()
.copied()
.find(|mime_type_off| resolve_mime_type(cache, *mime_type_off) == Some(mime_type));
mime_type_off.or_else(|| {
let mime_type_off = find(cache, mime_type0.as_bytes())?.try_into().ok()?;
types.insert(mime_type_off, Default::default());
Some(mime_type_off)
})
}
/// Records every type as a child of each of its parents.
///
/// `parents` maps a MIME type offset to the offsets of its parents. An edge is recorded only
/// when both the child and the parent have an entry in `types`. Afterwards every type's
/// `children` list is sorted by MIME type offset and deduplicated.
fn add_children(parents: &HashMap<u32, Vec<u32>>, types: &mut HashMap<u32, Type>) {
    for (child_off, parent_offs) in parents {
        if !types.contains_key(child_off) {
            continue;
        }
        for parent_off in parent_offs {
            let Some(parent) = types.get_mut(parent_off) else {
                continue;
            };
            parent.children.push(*child_off);
        }
    }
    types.values_mut().for_each(|r#type| {
        r#type.children.sort_unstable();
        r#type.children.dedup();
    });
}
/// Reads the NUL-terminated MIME type name stored at `mime_type_off` in `cache`.
///
/// Returns `None` if the offset is past the end of `cache`, no NUL byte follows it, or the
/// bytes before the NUL are not valid UTF-8.
fn resolve_mime_type(cache: &[u8], mime_type_off: u32) -> Option<&str> {
    let tail = cache.get(mime_type_off as usize..)?;
    memchr(0, tail).and_then(|nul| from_utf8(&tail[..nul]).ok())
}
fn parse_offsets(cache: &[u8]) -> Option<(&[u8], &[u8])> {
if cache.len() < 2 * 2 + 9 * 4 {
return None;
}
let parents_off = parse_u32(cache, 2 * 2 + 4) as usize;
let types_off = parse_u32(cache, 2 * 2 + 5 * 4) as usize;
let parents = cache.get(parents_off..)?;
let types = cache.get(types_off..)?;
Some((parents, types))
}
/// Parses the parent list section of the cache.
///
/// `bytes` starts at the parent list: a big-endian `u32` entry count followed by 8-byte
/// entries of (MIME type offset, parents offset). Each parents offset is resolved against
/// `cache` and points at a `u32` count followed by that many parent MIME type offsets.
///
/// Returns a map from a MIME type offset to the offsets of its parents (entries for the same
/// type are merged), or `None` if any count or offset reaches outside the buffers.
fn parse_parents(cache: &[u8], bytes: &[u8]) -> Option<HashMap<u32, Vec<u32>>> {
    let cnt = parse_u32(bytes.get(..4)?, 0) as usize;
    // `checked_mul` keeps a corrupt count from wrapping on 32-bit targets, where the wrapped
    // length would pass the bounds check and the per-entry indexing would then panic.
    let entries = bytes.get(4..)?.get(..cnt.checked_mul(8)?)?;
    let mut parents = HashMap::<_, Vec<_>>::with_capacity(cnt);
    for entry in entries.chunks_exact(8) {
        let mime_type_off = parse_u32(entry, 0);
        let list = cache.get(parse_u32(entry, 4) as usize..)?;
        let n_parents = parse_u32(list.get(..4)?, 0) as usize;
        let list = list.get(4..)?.get(..n_parents.checked_mul(4)?)?;
        parents
            .entry(mime_type_off)
            .or_default()
            .extend(list.chunks_exact(4).map(|word| parse_u32(word, 0)));
    }
    Some(parents)
}
/// Parses the magic list section of the cache into per-type match entries.
///
/// `bytes` starts at the magic list header: a big-endian `u32` match count at offset 0 and
/// the cache offset of the first 16-byte match entry at offset 8 (the word at offset 4 is not
/// read). Each entry contributes a [`Match`] (matchlet count at 8, first matchlet offset at 12)
/// to the type named by the MIME type offset at 4.
///
/// Returns `None` if the header is truncated or the entries reach past the end of `cache`.
fn parse_types(cache: &[u8], bytes: &[u8]) -> Option<HashMap<u32, Type>> {
    let header = bytes.get(..3 * 4)?;
    let cnt = parse_u32(header, 0) as usize;
    let off = parse_u32(header, 8) as usize;
    // `checked_mul` keeps a corrupt count from wrapping on 32-bit targets, where the wrapped
    // length would pass the bounds check and the per-entry indexing would then panic.
    let entries = cache.get(off..)?.get(..cnt.checked_mul(16)?)?;
    let mut types = HashMap::<_, Type>::with_capacity(cnt);
    for entry in entries.chunks_exact(16) {
        // The word at offset 0 is not read (presumably the match priority per the
        // shared-mime-info cache format).
        let mime_type_off = parse_u32(entry, 4);
        let n_matchlets = parse_u32(entry, 8);
        let first_matchlet = parse_u32(entry, 12);
        types
            .entry(mime_type_off)
            .or_default()
            .matches
            .push(Match {
                cnt: n_matchlets,
                off: first_matchlet,
            });
    }
    Some(types)
}
/// Reads a big-endian `u32` starting at byte `off` of `bytes`.
///
/// # Panics
///
/// Panics if fewer than four bytes are available at `off`.
fn parse_u32(bytes: &[u8], off: usize) -> u32 {
    let word = &bytes[off..][..4];
    u32::from_be_bytes([word[0], word[1], word[2], word[3]])
}
/// Errors produced while opening a MIME cache or matching data against it.
#[derive(Debug)]
pub enum Error {
    /// The cache header is truncated or a section offset lies past the end of the cache.
    MalformedOffsets,
    /// The parent list could not be parsed.
    MalformedParents,
    /// The magic list could not be parsed.
    MalformedTypes,
    /// A magic match refers to matchlet data outside the cache.
    MalformedMatch,
    /// `text/plain` could not be located in the cache.
    MissingTextPlain,
    /// `application/octet-stream` could not be located in the cache.
    MissingApplicationOctetStream,
    /// The matched type's name is not a readable NUL-terminated UTF-8 string.
    InvalidMimeType,
}

impl fmt::Display for Error {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.write_str(match self {
            Error::MalformedOffsets => "Malformed offsets",
            Error::MalformedParents => "Malformed parents",
            Error::MalformedTypes => "Malformed types",
            Error::MalformedMatch => "Malformed match",
            Error::MissingTextPlain => "Missing text/plain MIME type",
            Error::MissingApplicationOctetStream => "Missing application/octet-stream MIME type",
            Error::InvalidMimeType => "Invalid MIME type",
        })
    }
}

// No variant wraps another error, so the default `source` (returning `None`) is correct.
impl StdError for Error {}