use crate::mmdb::{MmdbError, MmdbHeader, SearchTree};
use lru::LruCache;
use matchy_data_format::DataValue;
use matchy_literal_hash::LiteralHash;
use matchy_paraglob::Paraglob;
use std::cell::RefCell;
use std::hash::BuildHasherDefault;
use std::net::IpAddr;
use std::num::NonZeroUsize;
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
#[cfg(not(target_family = "wasm"))]
use std::time::Duration;
#[cfg(not(target_family = "wasm"))]
use crate::updater::{LiveOptions, LiveState};
#[cfg(not(target_family = "wasm"))]
use memmap2::Mmap;
#[cfg(not(target_family = "wasm"))]
use std::fs::File;
#[cfg(not(target_family = "wasm"))]
pub use crate::updater::{
FallbackCallback, FallbackEvent, ReloadCallback, ReloadEvent, ReloadSource,
};
/// Per-database LRU query cache. Uses FxHasher (fast, not DoS-resistant);
/// keys are local query strings, not attacker-controlled map keys.
type QueryCacheInner = LruCache<String, QueryResult, BuildHasherDefault<rustc_hash::FxHasher>>;
thread_local! {
    // One cache per (thread, cache generation). Keying by generation lets a
    // reloaded database get a fresh cache without any cross-thread locking.
    static QUERY_CACHES: RefCell<rustc_hash::FxHashMap<u64, QueryCacheInner>> =
        RefCell::new(rustc_hash::FxHashMap::default());
}
// Monotonically increasing id source; each opened database gets its own generation.
static NEXT_CACHE_GENERATION: AtomicU64 = AtomicU64::new(1);
/// Allocate a fresh, process-unique cache generation id (first value is 1, so
/// 0 never collides with a real generation).
pub(crate) fn next_cache_generation() -> u64 {
    NEXT_CACHE_GENERATION.fetch_add(1, Ordering::Relaxed)
}
/// Lifetime query counters for a database, updated atomically so a shared
/// reference can record statistics from any thread.
#[derive(Debug, Default)]
pub struct DatabaseStats {
    pub total_queries: AtomicU64,
    pub queries_with_match: AtomicU64,
    pub queries_without_match: AtomicU64,
    pub cache_hits: AtomicU64,
    pub cache_misses: AtomicU64,
    pub ip_queries: AtomicU64,
    pub string_queries: AtomicU64,
}
/// Plain-value copy of [`DatabaseStats`], cheap to move and compare.
#[derive(Debug, Clone, Copy, Default)]
pub struct DatabaseStatsSnapshot {
    pub total_queries: u64,
    pub queries_with_match: u64,
    pub queries_without_match: u64,
    pub cache_hits: u64,
    pub cache_misses: u64,
    pub ip_queries: u64,
    pub string_queries: u64,
}
impl DatabaseStats {
    /// Capture a copy of every counter. Each counter is read individually with
    /// `Relaxed` ordering, so the snapshot is not an atomic cross-counter view.
    pub fn snapshot(&self) -> DatabaseStatsSnapshot {
        let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        DatabaseStatsSnapshot {
            total_queries: read(&self.total_queries),
            queries_with_match: read(&self.queries_with_match),
            queries_without_match: read(&self.queries_without_match),
            cache_hits: read(&self.cache_hits),
            cache_misses: read(&self.cache_misses),
            ip_queries: read(&self.ip_queries),
            string_queries: read(&self.string_queries),
        }
    }
}
impl DatabaseStatsSnapshot {
    /// Fraction of cache operations that were hits; 0.0 with no cache activity.
    #[must_use]
    pub fn cache_hit_rate(&self) -> f64 {
        match self.cache_hits + self.cache_misses {
            0 => 0.0,
            ops => self.cache_hits as f64 / ops as f64,
        }
    }
    /// Fraction of all queries that produced a match; 0.0 with no queries.
    #[must_use]
    pub fn match_rate(&self) -> f64 {
        match self.total_queries {
            0 => 0.0,
            total => self.queries_with_match as f64 / total as f64,
        }
    }
}
/// Result of a database lookup, with decoded data attached.
#[derive(Debug, Clone)]
pub enum QueryResult {
    /// Match in the MMDB IP search tree.
    Ip {
        data: DataValue,
        /// Length of the matched network prefix in bits.
        prefix_len: u8,
        /// Offset of the record in the data section (reusable with `decode_at_offset`).
        data_offset: u32,
    },
    /// Match(es) against the literal/glob sections; the three vectors are
    /// parallel — one slot per matched pattern id.
    Pattern {
        pattern_ids: Vec<u32>,
        /// Decoded data per match; `None` when a pattern has no mapped data.
        data: Vec<Option<DataValue>>,
        data_offsets: Vec<u32>,
    },
    /// The relevant section was searched and nothing matched.
    NotFound,
}
/// Lightweight lookup outcome carrying only offsets and flags, deferring data
/// decoding until the caller asks for it (see `Database::decode_at_offset`).
#[derive(Debug, Clone, Copy)]
pub struct LookupRef {
    pub found: bool,
    pub data_offset: u32,
    pub prefix_len: u8,
    pub result_type: u8,
}
impl LookupRef {
    /// A miss. `result_type` 0 marks "no result"; all fields are zeroed.
    #[inline]
    #[must_use]
    pub const fn not_found() -> Self {
        Self {
            result_type: 0,
            prefix_len: 0,
            data_offset: 0,
            found: false,
        }
    }
    /// A hit in the IP search tree; `result_type` 1 marks an IP result.
    #[inline]
    #[must_use]
    pub const fn ip(data_offset: u32, prefix_len: u8) -> Self {
        Self {
            result_type: 1,
            prefix_len,
            data_offset,
            found: true,
        }
    }
    /// A hit in the pattern/literal sections; `result_type` 2 marks a pattern
    /// result (prefix length is meaningless and left at 0).
    #[inline]
    #[must_use]
    pub const fn pattern(data_offset: u32) -> Self {
        Self {
            result_type: 2,
            prefix_len: 0,
            data_offset,
            found: true,
        }
    }
}
/// Which sections a database file contains, as detected by `Database::detect_format`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DatabaseFormat {
    /// Plain MMDB: IP search tree + data section only.
    IpOnly,
    /// Standalone file with the "PARAGLOB" magic: pattern matcher only.
    PatternOnly,
    /// MMDB with appended pattern and/or literal sections.
    Combined,
}
/// Backing storage for the raw database bytes.
enum DatabaseStorage {
    /// Bytes held in memory (from `from_bytes`, or a full file read on wasm).
    Owned(Vec<u8>),
    /// Memory-mapped file (native targets only).
    #[cfg(not(target_family = "wasm"))]
    Mmap(Mmap),
}
impl DatabaseStorage {
    /// Borrow the raw database bytes regardless of the backing storage.
    fn as_slice(&self) -> &[u8] {
        match self {
            // Both arms rely on deref coercion to &[u8].
            Self::Owned(bytes) => bytes,
            #[cfg(not(target_family = "wasm"))]
            Self::Mmap(mapping) => mapping,
        }
    }
}
/// View over the pattern-id → data-offset table embedded in a combined
/// database: `pattern_count` little-endian u32 offsets starting at
/// `mappings_offset` in the raw file bytes.
#[derive(Clone)]
struct PatternDataMappings {
    mappings_offset: usize,
    pattern_count: usize,
}
impl PatternDataMappings {
    /// Data-section offset for `pattern_id`, or `None` when the id is out of
    /// range or its table slot would fall past the end of `data`.
    fn get_offset(&self, pattern_id: u32, data: &[u8]) -> Option<u32> {
        let idx = pattern_id as usize;
        if idx >= self.pattern_count {
            return None;
        }
        let start = self.mappings_offset + idx * 4;
        let end = start + 4;
        if end > data.len() {
            return None;
        }
        let mut raw = [0u8; 4];
        raw.copy_from_slice(&data[start..end]);
        Some(u32::from_le_bytes(raw))
    }
}
/// Default capacity of the per-thread query LRU cache.
const DEFAULT_QUERY_CACHE_SIZE: usize = 10_000;
/// Options controlling how a `Database` is opened.
#[derive(Clone, Default)]
pub struct DatabaseOptions {
    /// Path of the database file (ignored when `bytes` is set).
    pub path: PathBuf,
    /// Query cache capacity; `Some(0)` disables caching, `None` keeps the default.
    pub cache_capacity: Option<usize>,
    /// In-memory database bytes; takes precedence over `path` when present.
    pub bytes: Option<Vec<u8>>,
    /// Reuse an existing cache generation instead of allocating a fresh one.
    pub cache_generation: Option<u64>,
}
/// Builder for opening a [`Database`], optionally with live reload / auto-update.
pub struct DatabaseOpener {
    options: DatabaseOptions,
    /// Live-reload configuration (watching / updating); native targets only.
    #[cfg(not(target_family = "wasm"))]
    live: LiveOptions,
}
impl DatabaseOpener {
    /// Builder seeded with a file path; finish with [`Self::open`].
    fn new(path: impl Into<PathBuf>) -> Self {
        Self {
            options: DatabaseOptions {
                path: path.into(),
                ..Default::default()
            },
            #[cfg(not(target_family = "wasm"))]
            live: LiveOptions::default(),
        }
    }
    /// Set the per-thread query cache capacity.
    #[must_use]
    pub fn cache_capacity(mut self, capacity: usize) -> Self {
        self.options.cache_capacity = Some(capacity);
        self
    }
    /// Disable query caching (equivalent to a capacity of 0).
    #[must_use]
    pub fn no_cache(mut self) -> Self {
        self.options.cache_capacity = Some(0);
        self
    }
    /// Enable live reloading of the database file (native only).
    #[cfg(not(target_family = "wasm"))]
    #[must_use]
    pub fn watch(mut self) -> Self {
        self.live.enabled = true;
        self
    }
    /// Enable periodic self-update from the database's embedded update URL.
    /// Implies live mode; applies the default interval unless one was set.
    #[cfg(all(not(target_family = "wasm"), feature = "auto-update"))]
    #[must_use]
    pub fn auto_update(mut self) -> Self {
        self.live.enabled = true;
        self.live.auto_update_enabled = true;
        if self.live.update_interval.is_none() {
            self.live.update_interval = Some(Duration::from_secs(
                crate::updater::DEFAULT_UPDATE_INTERVAL_SECS,
            ));
        }
        self
    }
    /// How often to check for a new database version.
    #[cfg(all(not(target_family = "wasm"), feature = "auto-update"))]
    #[must_use]
    pub fn update_interval(mut self, interval: Duration) -> Self {
        self.live.update_interval = Some(interval);
        self
    }
    /// Directory where downloaded database files are kept.
    #[cfg(all(not(target_family = "wasm"), feature = "auto-update"))]
    #[must_use]
    pub fn cache_dir(mut self, path: impl Into<std::path::PathBuf>) -> Self {
        self.live.cache_dir = Some(path.into());
        self
    }
    /// How often the live watcher polls the file for changes.
    #[cfg(not(target_family = "wasm"))]
    #[must_use]
    pub fn poll_interval(mut self, interval: Duration) -> Self {
        self.live.poll_interval = Some(interval);
        self
    }
    /// Callback invoked on live reload events.
    #[cfg(not(target_family = "wasm"))]
    #[must_use]
    pub fn on_reload<F>(mut self, callback: F) -> Self
    where
        F: Fn(ReloadEvent) + Send + Sync + 'static,
    {
        self.live.reload_callback = Some(Arc::new(callback));
        self
    }
    /// Callback invoked when the live database falls back to the previous
    /// snapshot after a data error (see `Database::lookup_live`).
    #[cfg(not(target_family = "wasm"))]
    #[must_use]
    pub fn on_fallback<F>(mut self, callback: F) -> Self
    where
        F: Fn(FallbackEvent) + Send + Sync + 'static,
    {
        self.live.fallback_callback = Some(Arc::new(callback));
        self
    }
    /// Open the database, starting the live updater when requested.
    ///
    /// # Errors
    /// Fails on I/O or format errors, or when `auto_update()` was requested
    /// but the database carries no embedded update URL.
    #[cfg(not(target_family = "wasm"))]
    pub fn open(self) -> Result<Database, DatabaseError> {
        let db = Database::open_with_options(self.options.clone())?;
        #[cfg(feature = "auto-update")]
        if self.live.auto_update_enabled && db.update_url().is_none() {
            return Err(DatabaseError::Config(
                "auto_update() requires database with embedded update URL".to_string(),
            ));
        }
        if self.live.enabled {
            // Wrap the opened database in a live-reloading handle.
            let live_state =
                self.live
                    .start_updater(&self.options.path, db, self.options.cache_capacity);
            Ok(Database::with_live_state(live_state))
        } else {
            Ok(db)
        }
    }
    /// Open the database (wasm has no live-reload support).
    #[cfg(target_family = "wasm")]
    pub fn open(self) -> Result<Database, DatabaseError> {
        Database::open_with_options(self.options)
    }
    /// Builder seeded with in-memory bytes instead of a path.
    #[must_use]
    pub fn from_bytes_builder(bytes: Vec<u8>) -> Self {
        Self {
            options: DatabaseOptions {
                bytes: Some(bytes),
                ..Default::default()
            },
            #[cfg(not(target_family = "wasm"))]
            live: LiveOptions::default(),
        }
    }
}
/// A loaded matchy database: an optional MMDB IP section plus optional
/// literal-hash and glob pattern sections, all over one byte buffer.
pub struct Database {
    /// Raw database bytes (owned or memory-mapped).
    data: DatabaseStorage,
    format: DatabaseFormat,
    /// MMDB header; present for IP-only and combined databases.
    ip_header: Option<MmdbHeader>,
    /// Exact-string matcher. The `'static` lifetime actually borrows from
    /// `data` via a transmute in `from_storage` — see the note there.
    literal_hash: Option<LiteralHash<'static>>,
    /// Glob pattern matcher.
    pattern_matcher: Option<Paraglob>,
    /// Maps glob pattern ids to data-section offsets (combined format only).
    pattern_data_mappings: Option<PatternDataMappings>,
    cache_capacity: usize,
    cache_enabled: bool,
    /// Query counters, shared with live snapshots via `Arc`.
    stats: Arc<DatabaseStats>,
    /// Key into the per-thread query caches.
    cache_generation: u64,
    /// When set, this instance is a live-reloading handle whose own data
    /// fields are placeholders; operations delegate to the current snapshot.
    #[cfg(not(target_family = "wasm"))]
    live: Option<Box<LiveState>>,
}
// SAFETY: NOTE(review) — these impls are asserted manually, presumably because
// a field (e.g. the self-referential `LiteralHash<'static>` borrow of `data`,
// or `Paraglob` internals) suppresses the auto traits. Soundness requires that
// every field is actually safe to share/send across threads; that cannot be
// verified from this file alone — TODO confirm against the matchy_literal_hash
// and matchy_paraglob implementations.
unsafe impl Send for Database {}
unsafe impl Sync for Database {}
impl Database {
    /// Run `f` against this database's per-thread query cache.
    ///
    /// Returns `None` when caching is disabled; otherwise lazily creates the
    /// cache for (current thread, this generation) and returns `Some(f(..))`.
    #[inline]
    fn with_cache<F, R>(&self, f: F) -> Option<R>
    where
        F: FnOnce(&mut QueryCacheInner) -> R,
    {
        if !self.cache_enabled {
            return None;
        }
        QUERY_CACHES.with(|caches| {
            let mut caches_borrow = caches.borrow_mut();
            let cache = caches_borrow
                .entry(self.cache_generation)
                .or_insert_with(|| {
                    LruCache::with_hasher(
                        // Invariant upheld by `open_with_options`: capacity 0
                        // always comes with cache_enabled == false, so this
                        // expect cannot fire through public entry points.
                        NonZeroUsize::new(self.cache_capacity)
                            .expect("cache_capacity > 0 when cache_enabled is true"),
                        BuildHasherDefault::<rustc_hash::FxHasher>::default(),
                    )
                });
            Some(f(cache))
        })
    }
    /// Start building a database handle from a file path.
    pub fn from(path: impl Into<PathBuf>) -> DatabaseOpener {
        DatabaseOpener::new(path)
    }
    /// Start building a database handle from in-memory bytes.
    #[must_use]
    pub fn from_bytes_builder(bytes: Vec<u8>) -> DatabaseOpener {
        DatabaseOpener::from_bytes_builder(bytes)
    }
    /// Drop all cached query results for this database on the current thread.
    pub fn clear_cache(&self) {
        if self.cache_enabled {
            QUERY_CACHES.with(|caches| {
                if let Some(cache) = caches.borrow_mut().get_mut(&self.cache_generation) {
                    cache.clear();
                }
            });
        }
    }
    /// Remove the current thread's cache for `generation` entirely, freeing
    /// its memory (entries for other generations are untouched).
    pub fn clear_cache_generation(generation: u64) {
        QUERY_CACHES.with(|caches| {
            caches.borrow_mut().remove(&generation);
        });
    }
    /// Number of entries in this thread's query cache (0 when caching is
    /// disabled or the cache was never touched on this thread).
    #[must_use]
    pub fn cache_size(&self) -> usize {
        if !self.cache_enabled {
            return 0;
        }
        QUERY_CACHES.with(|caches| {
            caches
                .borrow()
                .get(&self.cache_generation)
                .map_or(0, lru::LruCache::len)
        })
    }
    /// Snapshot of the lifetime query statistics.
    #[must_use]
    pub fn stats(&self) -> DatabaseStatsSnapshot {
        self.stats.snapshot()
    }
#[must_use]
pub fn mode(&self) -> matchy_match_mode::MatchMode {
if let Some(ref pm) = self.pattern_matcher {
return pm.mode();
}
if let Some(ref lh) = self.literal_hash {
return lh.mode();
}
matchy_match_mode::MatchMode::CaseSensitive
}
    /// Open a database from `options`, then apply cache settings.
    ///
    /// # Errors
    /// Fails on unreadable paths/bytes, invalid path encoding, or an
    /// unrecognized database format.
    pub fn open_with_options(options: DatabaseOptions) -> Result<Self, DatabaseError> {
        // In-memory bytes take precedence over the path.
        let mut db = if let Some(bytes) = options.bytes {
            Self::from_storage(DatabaseStorage::Owned(bytes))?
        } else {
            Self::open_internal(
                options
                    .path
                    .to_str()
                    .ok_or_else(|| DatabaseError::Io("Invalid path encoding".to_string()))?,
            )?
        };
        if let Some(capacity) = options.cache_capacity {
            if capacity == 0 {
                // Capacity 0 disables caching entirely. `with_cache` relies on
                // this: the LRU is never built with a zero capacity.
                db.cache_enabled = false;
            } else {
                db.cache_capacity = capacity;
                db.cache_enabled = true;
            }
        }
        if let Some(generation) = options.cache_generation {
            db.cache_generation = generation;
        }
        Ok(db)
    }
    /// Open a database by memory-mapping the file (native targets).
    #[cfg(not(target_family = "wasm"))]
    pub(crate) fn open_internal(path: &str) -> Result<Self, DatabaseError> {
        let file = File::open(path)
            .map_err(|e| DatabaseError::Io(format!("Failed to open {path}: {e}")))?;
        // SAFETY: NOTE(review) — mapping is unsafe because another process
        // truncating or rewriting the file while mapped is undefined behavior;
        // callers are expected to replace the database file atomically, not
        // modify it in place. TODO confirm this is documented for users.
        let mmap = unsafe { Mmap::map(&file) }
            .map_err(|e| DatabaseError::Io(format!("Failed to mmap {path}: {e}")))?;
        Self::from_storage(DatabaseStorage::Mmap(mmap))
    }
    /// Open a database by reading the whole file into memory (wasm: no mmap).
    #[cfg(target_family = "wasm")]
    pub(crate) fn open_internal(path: &str) -> Result<Self, DatabaseError> {
        let bytes = std::fs::read(path)
            .map_err(|e| DatabaseError::Io(format!("Failed to read {}: {}", path, e)))?;
        Self::from_storage(DatabaseStorage::Owned(bytes))
    }
    /// Open a database from bytes already in memory.
    pub fn from_bytes(data: Vec<u8>) -> Result<Self, DatabaseError> {
        Self::from_storage(DatabaseStorage::Owned(data))
    }
    /// Wrap a started [`LiveState`] in a placeholder `Database` handle.
    ///
    /// The placeholder's own data fields are empty; lookups delegate to the
    /// live state's current snapshot (see `resolve_live_db`). Cache and stats
    /// settings are copied from the snapshot so they stay shared.
    #[cfg(not(target_family = "wasm"))]
    fn with_live_state(live_state: LiveState) -> Self {
        let snapshot = live_state.current.load_full();
        Self {
            data: DatabaseStorage::Owned(vec![]),
            format: snapshot.format,
            ip_header: None,
            literal_hash: None,
            pattern_matcher: None,
            pattern_data_mappings: None,
            cache_capacity: snapshot.cache_capacity,
            cache_enabled: snapshot.cache_enabled,
            stats: snapshot.stats.clone(),
            cache_generation: live_state.generation.load(Ordering::Acquire),
            live: Some(Box::new(live_state)),
        }
    }
    /// Fetch the current live snapshot, memoized per thread by generation so
    /// the `Arc` is only re-fetched after a reload bumps the generation.
    #[cfg(not(target_family = "wasm"))]
    fn resolve_live_db(live: &LiveState) -> Arc<Database> {
        use crate::updater::LOCAL_DB;
        let current_gen = live.generation.load(Ordering::Acquire);
        LOCAL_DB.with(|local| {
            let mut local_ref = local.borrow_mut();
            match &*local_ref {
                // Memoized snapshot still matches the live generation: reuse it.
                Some((gen, db)) if *gen == current_gen => db.clone(),
                _ => {
                    let new_db = live.current.load_full();
                    *local_ref = Some((current_gen, new_db.clone()));
                    new_db
                }
            }
        })
    }
    /// Lookup against the current live snapshot, falling back to the previous
    /// snapshot when the current one reports a data (format) error.
    ///
    /// On a successful fallback the previous snapshot is promoted back to
    /// current, this thread's memoized snapshot is invalidated, and the
    /// fallback callback (if any) is notified.
    #[cfg(not(target_family = "wasm"))]
    fn lookup_live(
        &self,
        query: &str,
        live: &LiveState,
    ) -> Result<Option<QueryResult>, DatabaseError> {
        use crate::updater::{FallbackEvent, LOCAL_DB};
        let db = Self::resolve_live_db(live);
        match db.lookup(query) {
            Ok(result) => Ok(result),
            // Only data/format errors trigger fallback (see `is_data_error`).
            Err(e) if e.is_data_error() => {
                if let Some(prev_db) = live.previous.load_full().as_ref() {
                    match prev_db.lookup(query) {
                        Ok(result) => {
                            // Promote the known-good previous snapshot.
                            live.current.store(prev_db.clone());
                            live.previous.store(Arc::new(None));
                            // NOTE(review): only this thread's memoized snapshot
                            // is cleared; other threads refresh only when the
                            // generation changes — confirm the updater bumps the
                            // generation on fallback.
                            LOCAL_DB.with(|local| {
                                *local.borrow_mut() = None;
                            });
                            if let Some(ref callback) = live.fallback_callback {
                                callback(FallbackEvent {
                                    error: e.to_string(),
                                    generation: live.generation.load(Ordering::Acquire),
                                });
                            }
                            Ok(result)
                        }
                        // Previous snapshot failed too: surface the original error.
                        Err(_) => Err(e),
                    }
                } else {
                    Err(e)
                }
            }
            Err(e) => Err(e),
        }
    }
    /// Parse `storage` into a `Database`: detect the format, then wire up the
    /// IP header, pattern matcher, and literal hash sections as applicable.
    fn from_storage(storage: DatabaseStorage) -> Result<Self, DatabaseError> {
        let mut db = Self {
            data: storage,
            format: DatabaseFormat::IpOnly,
            ip_header: None,
            literal_hash: None,
            pattern_matcher: None,
            pattern_data_mappings: None,
            cache_capacity: DEFAULT_QUERY_CACHE_SIZE,
            cache_enabled: true,
            stats: Arc::new(DatabaseStats::default()),
            cache_generation: next_cache_generation(),
            #[cfg(not(target_family = "wasm"))]
            live: None,
        };
        // SAFETY: NOTE(review) — this transmute fabricates a 'static lifetime
        // for a borrow of db.data, making `Database` self-referential. It is
        // sound only because both storage variants (Vec / Mmap) keep the bytes
        // at a stable heap address when `db` moves, and `data` is never
        // mutated or replaced while the derived parsers are alive. Any future
        // mutation of `data` would invalidate `literal_hash`/`pattern_matcher`.
        let data: &'static [u8] = unsafe { std::mem::transmute(db.data.as_slice()) };
        db.format = Self::detect_format(data)?;
        match db.format {
            DatabaseFormat::IpOnly => {
                db.ip_header = Some(MmdbHeader::from_file(data).map_err(DatabaseError::Format)?);
            }
            DatabaseFormat::PatternOnly => {
                let pg = Self::load_pattern_section(data, 0).map_err(|e| {
                    DatabaseError::Unsupported(format!("Failed to load pattern section: {e}"))
                })?;
                db.pattern_matcher = Some(pg);
            }
            DatabaseFormat::Combined => {
                db.ip_header = Some(MmdbHeader::from_file(data).map_err(DatabaseError::Format)?);
                if let Some(offset) = Self::find_pattern_section_fast(data) {
                    let (pg, map) =
                        Self::load_combined_pattern_section(data, offset).map_err(|e| {
                            DatabaseError::Unsupported(format!(
                                "Failed to load pattern section: {e}"
                            ))
                        })?;
                    db.pattern_matcher = Some(pg);
                    db.pattern_data_mappings = Some(map);
                }
            }
        }
        // A literal section may accompany any format. The offset points at the
        // 16-byte separator, so the payload begins 16 bytes later.
        if let Some(offset) = Self::find_literal_section_fast(data) {
            let literal_data = &data[offset + 16..];
            let match_mode = Self::read_match_mode_from_metadata(data);
            db.literal_hash = Some(LiteralHash::from_buffer(literal_data, match_mode).map_err(
                |e| DatabaseError::Unsupported(format!("Failed to load literal hash: {e}")),
            )?);
        }
        Ok(db)
    }
    /// Current live-reload generation, or 0 for non-live databases.
    #[cfg(not(target_family = "wasm"))]
    #[must_use]
    pub fn generation(&self) -> u64 {
        match &self.live {
            Some(live) => live.generation.load(Ordering::Acquire),
            None => 0,
        }
    }
    /// Live reload is unavailable on wasm; the generation is always 0.
    #[cfg(target_family = "wasm")]
    pub fn generation(&self) -> u64 {
        0
    }
pub fn lookup(&self, query: &str) -> Result<Option<QueryResult>, DatabaseError> {
#[cfg(not(target_family = "wasm"))]
if let Some(ref live) = self.live {
return self.lookup_live(query, live);
}
if let Some(Some(result)) = self.with_cache(|cache| cache.get(query).cloned()) {
self.stats.total_queries.fetch_add(1, Ordering::Relaxed);
self.stats.cache_hits.fetch_add(1, Ordering::Relaxed);
match &result {
QueryResult::Ip { .. } => {
self.stats.ip_queries.fetch_add(1, Ordering::Relaxed);
self.stats
.queries_with_match
.fetch_add(1, Ordering::Relaxed);
}
QueryResult::Pattern { .. } => {
self.stats.string_queries.fetch_add(1, Ordering::Relaxed);
self.stats
.queries_with_match
.fetch_add(1, Ordering::Relaxed);
}
QueryResult::NotFound => {
if query.parse::<IpAddr>().is_ok() {
self.stats.ip_queries.fetch_add(1, Ordering::Relaxed);
} else {
self.stats.string_queries.fetch_add(1, Ordering::Relaxed);
}
self.stats
.queries_without_match
.fetch_add(1, Ordering::Relaxed);
}
}
return Ok(Some(result));
}
let result = if let Ok(addr) = query.parse::<IpAddr>() {
self.lookup_ip_uncached(addr)?
} else {
self.lookup_string_uncached(query)?
};
self.stats.total_queries.fetch_add(1, Ordering::Relaxed);
if self.cache_enabled {
self.stats.cache_misses.fetch_add(1, Ordering::Relaxed);
}
match &result {
Some(QueryResult::Ip { .. }) => {
self.stats.ip_queries.fetch_add(1, Ordering::Relaxed);
self.stats
.queries_with_match
.fetch_add(1, Ordering::Relaxed);
}
Some(QueryResult::Pattern { .. }) => {
self.stats.string_queries.fetch_add(1, Ordering::Relaxed);
self.stats
.queries_with_match
.fetch_add(1, Ordering::Relaxed);
}
Some(QueryResult::NotFound) => {
self.stats.string_queries.fetch_add(1, Ordering::Relaxed);
self.stats
.queries_without_match
.fetch_add(1, Ordering::Relaxed);
}
None => {
self.stats
.queries_without_match
.fetch_add(1, Ordering::Relaxed);
}
}
if let Some(ref res) = result {
self.with_cache(|cache| cache.put(query.to_string(), res.clone()));
}
Ok(result)
}
    /// IP lookup without cache or stats: walk the MMDB search tree and decode
    /// the matched record.
    fn lookup_ip_uncached(&self, addr: IpAddr) -> Result<Option<QueryResult>, DatabaseError> {
        let header = match &self.ip_header {
            Some(h) => h,
            // No IP section: this database cannot answer IP queries at all.
            None => return Ok(None),
        };
        let tree = SearchTree::new(self.data.as_slice(), header);
        let tree_result = tree.lookup(addr).map_err(DatabaseError::Format)?;
        let tree_result = match tree_result {
            Some(r) => r,
            // Searched but no record covers this address: an explicit miss.
            None => return Ok(Some(QueryResult::NotFound)),
        };
        let data = self.decode_ip_data(header, tree_result.data_offset)?;
        Ok(Some(QueryResult::Ip {
            data,
            prefix_len: tree_result.prefix_len,
            data_offset: tree_result.data_offset,
        }))
    }
pub fn lookup_ip(&self, addr: IpAddr) -> Result<Option<QueryResult>, DatabaseError> {
let query = addr.to_string();
if let Some(Some(result)) = self.with_cache(|cache| cache.get(&query).cloned()) {
return Ok(Some(result));
}
let result = self.lookup_ip_uncached(addr)?;
if let Some(ref res) = result {
self.with_cache(|cache| cache.put(query, res.clone()));
}
Ok(result)
}
    /// Look up an item extracted by [`crate::extractor`]: extracted IPs go
    /// straight to the IP path (no re-parsing); everything else goes through
    /// the general string `lookup`.
    pub fn lookup_extracted(
        &self,
        item: &crate::extractor::Match,
        input: &[u8],
    ) -> Result<Option<QueryResult>, DatabaseError> {
        use crate::extractor::ExtractedItem;
        match &item.item {
            ExtractedItem::Ipv4(ip) => self.lookup_ip(IpAddr::V4(*ip)),
            ExtractedItem::Ipv6(ip) => self.lookup_ip(IpAddr::V6(*ip)),
            _ => self.lookup(item.as_str(input)),
        }
    }
    /// String lookup without cache or stats: consult the literal hash first,
    /// then the glob matcher, accumulating every match.
    fn lookup_string_uncached(&self, pattern: &str) -> Result<Option<QueryResult>, DatabaseError> {
        // Parallel result vectors: one slot per matched pattern id.
        let mut all_pattern_ids = Vec::new();
        let mut all_data_values = Vec::new();
        let mut all_data_offsets = Vec::new();
        // Exact-string match via the literal hash section.
        if let Some(literal_hash) = &self.literal_hash {
            if let Some(pattern_id) = literal_hash.lookup(pattern) {
                if let Some(data_offset) = literal_hash.get_data_offset(pattern_id) {
                    // Literal data lives in the MMDB data section, so the IP
                    // header must exist to decode it.
                    let header = self.ip_header.as_ref().ok_or_else(|| {
                        DatabaseError::Format(MmdbError::InvalidFormat(
                            "Literal hash present but no IP header".to_string(),
                        ))
                    })?;
                    let data = self.decode_ip_data(header, data_offset)?;
                    all_pattern_ids.push(pattern_id);
                    all_data_values.push(Some(data));
                    all_data_offsets.push(data_offset);
                }
            }
        }
        // Glob matches; a single query can match several patterns.
        if let Some(ref pg) = self.pattern_matcher {
            let glob_pattern_ids = pg.find_all(pattern);
            for &pattern_id in &glob_pattern_ids {
                let (data, offset) = match (&self.pattern_data_mappings, &self.ip_header) {
                    // Combined format: pattern ids map to MMDB data offsets.
                    (Some(mappings), Some(header)) => {
                        if let Some(data_offset) =
                            mappings.get_offset(pattern_id, self.data.as_slice())
                        {
                            (Some(self.decode_ip_data(header, data_offset)?), data_offset)
                        } else {
                            (None, 0)
                        }
                    }
                    (Some(_), None) => {
                        // from_storage only builds mappings alongside an IP header.
                        unreachable!(
                            "pattern_data_mappings present without ip_header - invalid database state"
                        )
                    }
                    // Pattern-only format: data (if any) comes from paraglob itself.
                    (None, _) => {
                        (pg.get_pattern_data(pattern_id), 0)
                    }
                };
                all_pattern_ids.push(pattern_id);
                all_data_values.push(data);
                all_data_offsets.push(offset);
            }
        }
        if all_pattern_ids.is_empty() {
            // NotFound only when some string section exists; `None` means this
            // database cannot answer string queries at all.
            if self.literal_hash.is_some() || self.pattern_matcher.is_some() {
                Ok(Some(QueryResult::NotFound))
            } else {
                Ok(None)
            }
        } else {
            Ok(Some(QueryResult::Pattern {
                pattern_ids: all_pattern_ids,
                data: all_data_values,
                data_offsets: all_data_offsets,
            }))
        }
    }
pub fn lookup_string(&self, pattern: &str) -> Result<Option<QueryResult>, DatabaseError> {
if let Some(Some(result)) = self.with_cache(|cache| cache.get(pattern).cloned()) {
return Ok(Some(result));
}
let result = self.lookup_string_uncached(pattern)?;
if let Some(ref res) = result {
self.with_cache(|cache| cache.put(pattern.to_string(), res.clone()));
}
Ok(result)
}
    /// Decode the record at `offset` within the MMDB data section.
    ///
    /// The data section starts after the search tree plus the 16-byte
    /// separator. NOTE(review): the slice below assumes
    /// `tree_size + 16 <= data.len()`; a corrupt header would panic here —
    /// confirm the header is validated upstream.
    fn decode_ip_data(&self, header: &MmdbHeader, offset: u32) -> Result<DataValue, DatabaseError> {
        use matchy_data_format::DataDecoder;
        let data_section_start = header.tree_size + 16;
        let data_section = &self.data.as_slice()[data_section_start..];
        let decoder = DataDecoder::new(data_section, 0);
        decoder
            .decode(offset)
            .map_err(|e| DatabaseError::Format(MmdbError::DecodeError(e.to_string())))
    }
    /// Decode-free lookup: returns only offsets/flags. Pair with
    /// [`Self::decode_at_offset`] to materialize the record on demand.
    ///
    /// # Errors
    /// Returns a `DatabaseError` when the search tree is malformed.
    pub fn lookup_ref(&self, query: &str) -> Result<LookupRef, DatabaseError> {
        // Live handles delegate to the current snapshot.
        #[cfg(not(target_family = "wasm"))]
        if let Some(ref live) = self.live {
            return Self::resolve_live_db(live).lookup_ref(query);
        }
        // Reuse any cached full result rather than re-searching.
        if let Some(Some(result)) = self.with_cache(|cache| cache.get(query).cloned()) {
            return Ok(match result {
                QueryResult::Ip {
                    prefix_len,
                    data_offset,
                    ..
                } => LookupRef::ip(data_offset, prefix_len),
                // Multiple pattern matches collapse to the first offset here.
                QueryResult::Pattern { data_offsets, .. } => {
                    LookupRef::pattern(*data_offsets.first().unwrap_or(&0))
                }
                QueryResult::NotFound => LookupRef::not_found(),
            });
        }
        if let Ok(addr) = query.parse::<IpAddr>() {
            self.lookup_ip_ref(addr)
        } else {
            Ok(self.lookup_string_ref(query))
        }
    }
fn lookup_ip_ref(&self, addr: IpAddr) -> Result<LookupRef, DatabaseError> {
let header = match &self.ip_header {
Some(h) => h,
None => return Ok(LookupRef::not_found()),
};
let tree = SearchTree::new(self.data.as_slice(), header);
let tree_result = tree.lookup(addr).map_err(DatabaseError::Format)?;
match tree_result {
Some(r) => Ok(LookupRef::ip(r.data_offset, r.prefix_len)),
None => Ok(LookupRef::not_found()),
}
}
    /// Offset-only string lookup: the first literal-hash hit wins, then the
    /// first glob match that has a data mapping. Infallible — misses and
    /// unmapped matches yield `not_found`.
    fn lookup_string_ref(&self, pattern: &str) -> LookupRef {
        if let Some(literal_hash) = &self.literal_hash {
            if let Some(pattern_id) = literal_hash.lookup(pattern) {
                if let Some(data_offset) = literal_hash.get_data_offset(pattern_id) {
                    return LookupRef::pattern(data_offset);
                }
            }
        }
        if let Some(ref pg) = self.pattern_matcher {
            let glob_pattern_ids = pg.find_all(pattern);
            // Only the first glob match is reported in ref form.
            if let Some(&pattern_id) = glob_pattern_ids.first() {
                if let Some(mappings) = &self.pattern_data_mappings {
                    if let Some(data_offset) = mappings.get_offset(pattern_id, self.data.as_slice())
                    {
                        return LookupRef::pattern(data_offset);
                    }
                }
            }
        }
        LookupRef::not_found()
    }
    /// Decode the record at a data-section `offset` previously obtained from
    /// [`Self::lookup_ref`] (or a `QueryResult`'s `data_offset`).
    ///
    /// # Errors
    /// Fails when no IP data section exists or the record is malformed.
    pub fn decode_at_offset(&self, offset: u32) -> Result<DataValue, DatabaseError> {
        // Live handles delegate to the current snapshot.
        #[cfg(not(target_family = "wasm"))]
        if let Some(ref live) = self.live {
            return Self::resolve_live_db(live).decode_at_offset(offset);
        }
        let header = self.ip_header.as_ref().ok_or_else(|| {
            DatabaseError::Format(MmdbError::InvalidFormat(
                "No IP header - cannot decode data".to_string(),
            ))
        })?;
        self.decode_ip_data(header, offset)
    }
    /// Sniff the database format from raw bytes.
    ///
    /// Order: a leading "PARAGLOB" magic means pattern-only; otherwise an
    /// MMDB metadata marker is required. Section-offset metadata (present in
    /// databases with `pattern_section_offset`/`literal_section_offset` keys)
    /// decides combined vs IP-only; files without it fall back to scanning
    /// for the pattern separator.
    fn detect_format(data: &[u8]) -> Result<DatabaseFormat, DatabaseError> {
        let has_paraglob_start = data.len() >= 8 && &data[0..8] == b"PARAGLOB";
        if has_paraglob_start {
            return Ok(DatabaseFormat::PatternOnly);
        }
        let has_mmdb = crate::mmdb::find_metadata_marker(data).is_ok();
        if !has_mmdb {
            return Err(DatabaseError::Format(MmdbError::InvalidFormat(
                "Unknown database format (no MMDB or PARAGLOB marker)".to_string(),
            )));
        }
        // Preferred path: explicit section offsets in the metadata map
        // (offset 0 means "section absent").
        if let Ok(metadata) = crate::mmdb::MmdbMetadata::from_file(data) {
            if let Ok(DataValue::Map(map)) = metadata.as_value() {
                if let Some(DataValue::Uint32(pattern_offset)) = map.get("pattern_section_offset") {
                    let has_patterns = *pattern_offset != 0;
                    if let Some(DataValue::Uint32(literal_offset)) =
                        map.get("literal_section_offset")
                    {
                        let has_literals = *literal_offset != 0;
                        if has_patterns || has_literals {
                            return Ok(DatabaseFormat::Combined);
                        } else {
                            return Ok(DatabaseFormat::IpOnly);
                        }
                    }
                }
            }
        }
        // Legacy path: scan the whole file for the pattern separator.
        let pattern_separator = b"MMDB_PATTERN\x00\x00\x00\x00";
        let has_pattern_section = data.windows(16).any(|window| window == pattern_separator);
        if has_pattern_section {
            Ok(DatabaseFormat::Combined)
        } else {
            Ok(DatabaseFormat::IpOnly)
        }
    }
    /// Human-readable description of the detected database format.
    #[must_use]
    pub fn format(&self) -> &str {
        match self.format {
            DatabaseFormat::IpOnly => "IP database",
            DatabaseFormat::PatternOnly => "Pattern database",
            DatabaseFormat::Combined => "Combined IP+Pattern database",
        }
    }
    /// True when an MMDB IP section is loaded.
    #[must_use]
    pub fn has_ip_data(&self) -> bool {
        self.ip_header.is_some()
    }
    /// True when any string-matching section (literal or glob) is loaded.
    #[must_use]
    pub fn has_string_data(&self) -> bool {
        self.literal_hash.is_some() || self.pattern_matcher.is_some()
    }
    /// True when the exact-string (literal hash) section is loaded.
    #[must_use]
    pub fn has_literal_data(&self) -> bool {
        self.literal_hash.is_some()
    }
    /// True when the glob pattern section is loaded.
    #[must_use]
    pub fn has_glob_data(&self) -> bool {
        self.pattern_matcher.is_some()
    }
    /// Backwards-compatible alias for [`Self::has_string_data`].
    #[deprecated(
        since = "0.5.0",
        note = "Use has_literal_data or has_glob_data instead"
    )]
    #[must_use]
    pub fn has_pattern_data(&self) -> bool {
        self.has_string_data()
    }
    /// The MMDB metadata map, if this database has one. Live handles read the
    /// metadata from the current snapshot.
    #[must_use]
    pub fn metadata(&self) -> Option<DataValue> {
        #[cfg(not(target_family = "wasm"))]
        if let Some(live) = &self.live {
            return live.current.load().metadata();
        }
        // Pattern-only databases carry no MMDB metadata.
        if !self.has_ip_data() {
            return None;
        }
        use crate::mmdb::MmdbMetadata;
        let metadata = MmdbMetadata::from_file(self.data.as_slice()).ok()?;
        metadata.as_value().ok()
    }
    /// The embedded `update_url` metadata value, if present (required by
    /// `DatabaseOpener::auto_update`).
    #[must_use]
    pub fn update_url(&self) -> Option<String> {
        if let Some(DataValue::Map(map)) = self.metadata() {
            if let Some(DataValue::String(url)) = map.get("update_url") {
                return Some(url.clone());
            }
        }
        None
    }
    /// The original pattern text for a glob pattern id, when available.
    #[must_use]
    pub fn get_pattern_string(&self, pattern_id: u32) -> Option<String> {
        let pg = self.pattern_matcher.as_ref()?;
        pg.get_pattern(pattern_id)
    }
    /// Number of loaded glob patterns (0 when no pattern section).
    #[must_use]
    pub fn pattern_count(&self) -> usize {
        match &self.pattern_matcher {
            Some(pg) => pg.pattern_count(),
            None => 0,
        }
    }
    /// Number of glob entries: prefers the `glob_entry_count` metadata key,
    /// falling back to the loaded matcher's own count.
    #[must_use]
    pub fn glob_count(&self) -> usize {
        if let Some(DataValue::Map(map)) = self.metadata() {
            if let Some(count) = map.get("glob_entry_count") {
                if let Some(val) = Self::extract_uint_from_datavalue(count) {
                    return usize::try_from(val).unwrap_or(usize::MAX);
                }
            }
        }
        self.pattern_count()
    }
    /// Number of literal entries: prefers the `literal_entry_count` metadata
    /// key, falling back to the loaded literal hash's count.
    #[must_use]
    pub fn literal_count(&self) -> usize {
        if let Some(DataValue::Map(map)) = self.metadata() {
            if let Some(count) = map.get("literal_entry_count") {
                if let Some(val) = Self::extract_uint_from_datavalue(count) {
                    return usize::try_from(val).unwrap_or(usize::MAX);
                }
            }
        }
        match &self.literal_hash {
            Some(lh) => lh.entry_count() as usize,
            None => 0,
        }
    }
    /// Number of IP entries: the `ip_entry_count` metadata key when present,
    /// otherwise the standard MMDB `node_count`, otherwise 0.
    #[must_use]
    pub fn ip_count(&self) -> usize {
        if let Some(DataValue::Map(map)) = self.metadata() {
            if let Some(count) = map.get("ip_entry_count") {
                if let Some(val) = Self::extract_uint_from_datavalue(count) {
                    return usize::try_from(val).unwrap_or(usize::MAX);
                }
            }
            if let Some(count) = map.get("node_count") {
                if let Some(val) = Self::extract_uint_from_datavalue(count) {
                    return usize::try_from(val).unwrap_or(usize::MAX);
                }
            }
        }
        0
    }
    /// Widen any unsigned-integer `DataValue` variant to u64; `None` for
    /// every other variant.
    fn extract_uint_from_datavalue(value: &DataValue) -> Option<u64> {
        match value {
            DataValue::Uint16(v) => Some(u64::from(*v)),
            DataValue::Uint32(v) => Some(u64::from(*v)),
            DataValue::Uint64(v) => Some(*v),
            _ => None,
        }
    }
    /// Locate the pattern section's payload offset via metadata when present
    /// (0 means "no section"); otherwise warn and fall back to a linear scan.
    /// Both paths return the offset of the payload, i.e. the byte after the
    /// 16-byte separator.
    fn find_pattern_section_fast(data: &[u8]) -> Option<usize> {
        if let Ok(metadata) = crate::mmdb::MmdbMetadata::from_file(data) {
            if let Ok(DataValue::Map(map)) = metadata.as_value() {
                if let Some(DataValue::Uint32(offset)) = map.get("pattern_section_offset") {
                    let offset_val = *offset as usize;
                    if offset_val == 0 {
                        return None;
                    }
                    return Some(offset_val);
                }
            }
        }
        eprintln!("Warning: Database lacks section offset metadata, falling back to full file scan (slower load time)");
        Self::find_pattern_section_slow(data)
    }
fn find_pattern_section_slow(data: &[u8]) -> Option<usize> {
let separator = b"MMDB_PATTERN\x00\x00\x00\x00";
for i in 0..data.len().saturating_sub(16) {
if &data[i..i + 16] == separator {
return Some(i + 16);
}
}
None
}
fn find_literal_section_fast(data: &[u8]) -> Option<usize> {
if let Ok(metadata) = crate::mmdb::MmdbMetadata::from_file(data) {
if let Ok(DataValue::Map(map)) = metadata.as_value() {
if let Some(DataValue::Uint32(offset)) = map.get("literal_section_offset") {
let offset_val = *offset as usize;
if offset_val == 0 {
return None;
}
return Some(offset_val - 16);
}
}
}
if data.len() > 1024 * 1024 {
eprintln!("Warning: Database lacks section offset metadata, falling back to full file scan (slower load time)");
}
Self::find_literal_section_slow(data)
}
fn find_literal_section_slow(data: &[u8]) -> Option<usize> {
let separator = b"MMDB_LITERAL\x00\x00\x00\x00";
(0..data.len().saturating_sub(16)).find(|&i| &data[i..i + 16] == separator)
}
    /// Load a standalone pattern-only ("PARAGLOB") database.
    ///
    /// Only `offset == 0` with the PARAGLOB magic is accepted; anything else
    /// is rejected as invalid.
    fn load_pattern_section(data: &'static [u8], offset: usize) -> Result<Paraglob, String> {
        if offset >= data.len() {
            return Err("Pattern section offset out of bounds".to_string());
        }
        let match_mode = Self::read_match_mode_from_metadata(data);
        if offset == 0 && data.len() >= 8 && &data[0..8] == b"PARAGLOB" {
            // SAFETY: NOTE(review) — from_mmap is unsafe and receives bytes
            // straight from disk; confirm the parser bounds-checks its
            // internal offsets against hostile input.
            let result = unsafe { Paraglob::from_mmap(data, match_mode) };
            return result.map_err(|e| format!("Failed to parse pattern-only database: {e}"));
        }
        Err("Invalid pattern-only database format".to_string())
    }
    /// Load the pattern section of a combined database.
    ///
    /// On-disk layout at `offset` (all integers little-endian):
    /// `[u32 total_size][u32 paraglob_size][paraglob bytes...]`
    /// `[u32 pattern_count][u32 data_offset; pattern_count]`
    fn load_combined_pattern_section(
        data: &'static [u8],
        offset: usize,
    ) -> Result<(Paraglob, PatternDataMappings), String> {
        if offset >= data.len() {
            return Err("Pattern section offset out of bounds".to_string());
        }
        let match_mode = Self::read_match_mode_from_metadata(data);
        if offset + 8 > data.len() {
            return Err("Pattern section header truncated".to_string());
        }
        // total_size is currently unused; paraglob_size delimits the matcher blob.
        let _total_size = u32::from_le_bytes([
            data[offset],
            data[offset + 1],
            data[offset + 2],
            data[offset + 3],
        ]);
        let paraglob_size = u32::from_le_bytes([
            data[offset + 4],
            data[offset + 5],
            data[offset + 6],
            data[offset + 7],
        ]) as usize;
        let paraglob_start = offset + 8;
        let paraglob_end = paraglob_start + paraglob_size;
        if paraglob_end > data.len() {
            return Err(format!(
                "Paraglob section extends beyond file (start={}, size={}, file_len={})",
                paraglob_start,
                paraglob_size,
                data.len()
            ));
        }
        let paraglob_data = &data[paraglob_start..paraglob_end];
        // SAFETY: NOTE(review) — from_mmap is unsafe; bounds above only check
        // the outer blob, the parser must validate its internal offsets.
        let paraglob = unsafe { Paraglob::from_mmap(paraglob_data, match_mode) };
        let paraglob = paraglob.map_err(|e| format!("Failed to parse paraglob section: {e}"))?;
        // The id -> data-offset mapping table immediately follows the matcher.
        let mappings_start = paraglob_end;
        if mappings_start + 4 > data.len() {
            return Err("Pattern mappings section truncated".to_string());
        }
        let pattern_count = u32::from_le_bytes([
            data[mappings_start],
            data[mappings_start + 1],
            data[mappings_start + 2],
            data[mappings_start + 3],
        ]) as usize;
        let offsets_start = mappings_start + 4;
        let total_mapping_bytes = pattern_count * 4;
        if offsets_start + total_mapping_bytes > data.len() {
            return Err(format!(
                "Pattern mappings section out of bounds (need {total_mapping_bytes} bytes)"
            ));
        }
        // Store only the table's location; entries are read lazily.
        let mappings = PatternDataMappings {
            mappings_offset: offsets_start,
            pattern_count,
        };
        Ok((paraglob, mappings))
    }
    /// Read the `match_mode` metadata key (1 = case-insensitive, anything
    /// else = case-sensitive); defaults to case-sensitive when the key is
    /// absent or the metadata is unreadable.
    fn read_match_mode_from_metadata(data: &[u8]) -> matchy_match_mode::MatchMode {
        use matchy_match_mode::MatchMode;
        if let Ok(metadata) = crate::mmdb::MmdbMetadata::from_file(data) {
            if let Ok(DataValue::Map(map)) = metadata.as_value() {
                if let Some(DataValue::Uint16(mode_val)) = map.get("match_mode") {
                    return match *mode_val {
                        1 => MatchMode::CaseInsensitive,
                        _ => MatchMode::CaseSensitive,
                    };
                }
            }
        }
        MatchMode::CaseSensitive
    }
}
/// Errors produced while opening or querying a database.
#[derive(Debug)]
pub enum DatabaseError {
    /// Filesystem failure (message includes path and cause).
    Io(String),
    /// Structural MMDB error: bad header, tree, or data section.
    Format(MmdbError),
    /// Recognized but unloadable content (e.g. an unparsable section).
    Unsupported(String),
    /// Invalid builder configuration (e.g. auto-update without update URL).
    Config(String),
}
impl std::fmt::Display for DatabaseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(msg) => write!(f, "I/O error: {msg}"),
            Self::Format(err) => write!(f, "Format error: {err}"),
            Self::Unsupported(msg) => write!(f, "Unsupported: {msg}"),
            Self::Config(msg) => write!(f, "Configuration error: {msg}"),
        }
    }
}
// Marker impl: `source()` stays the default `None`, so `Format`'s inner
// `MmdbError` is rendered through Display rather than chained.
impl std::error::Error for DatabaseError {}
impl DatabaseError {
    /// True for errors that indicate corrupt database contents — this is the
    /// trigger for live-reload fallback (see `lookup_live`).
    #[must_use]
    pub fn is_data_error(&self) -> bool {
        matches!(self, Self::Format(_))
    }
}
#[cfg(test)]
mod tests {
    // Integration-style tests; most rely on the GeoLite2 fixture under
    // tests/data/, so they require the repository checkout to run.
    use super::*;
    // Format detection on a plain MMDB file.
    #[test]
    fn test_detect_ip_database() {
        let db = Database::from("tests/data/GeoLite2-Country.mmdb")
            .open()
            .unwrap();
        assert_eq!(db.format, DatabaseFormat::IpOnly);
        assert!(db.has_ip_data());
        assert!(!db.has_string_data());
    }
    // IPv4 lookup returns an Ip result with a plausible prefix and map data.
    #[test]
    fn test_lookup_ip_address() {
        let db = Database::from("tests/data/GeoLite2-Country.mmdb")
            .open()
            .unwrap();
        let result = db.lookup("1.1.1.1").unwrap();
        assert!(result.is_some());
        if let Some(QueryResult::Ip {
            data, prefix_len, ..
        }) = result
        {
            assert!(prefix_len > 0);
            assert!(prefix_len <= 32);
            match data {
                DataValue::Map(map) => {
                    assert!(!map.is_empty());
                }
                _ => panic!("Expected map data"),
            }
        } else {
            panic!("Expected IP result");
        }
    }
    // IPv6 lookups go through the same tree with up to 128 prefix bits.
    #[test]
    fn test_lookup_ipv6() {
        let db = Database::from("tests/data/GeoLite2-Country.mmdb")
            .open()
            .unwrap();
        let result = db.lookup("2001:4860:4860::8888").unwrap();
        assert!(result.is_some());
        if let Some(QueryResult::Ip { prefix_len, .. }) = result {
            assert!(prefix_len > 0);
            assert!(prefix_len <= 128);
        }
    }
    // Private addresses are absent from GeoLite2 => explicit NotFound.
    #[test]
    fn test_lookup_not_found() {
        let db = Database::from("tests/data/GeoLite2-Country.mmdb")
            .open()
            .unwrap();
        let result = db.lookup("127.0.0.1").unwrap();
        assert!(matches!(result, Some(QueryResult::NotFound)));
    }
    // lookup() routes IP-shaped queries to the tree and others to strings.
    #[test]
    fn test_auto_detect_query_type() {
        let db = Database::from("tests/data/GeoLite2-Country.mmdb")
            .open()
            .unwrap();
        let result = db.lookup("8.8.8.8").unwrap();
        assert!(matches!(result, Some(QueryResult::Ip { .. })));
        let result = db.lookup("example.com").unwrap();
        assert!(result.is_none() || matches!(result, Some(QueryResult::NotFound)));
    }
    // lookup_extracted() dispatches extractor hits without re-parsing IPs.
    #[test]
    fn test_lookup_extracted() {
        use crate::extractor::Extractor;
        let db = Database::from("tests/data/GeoLite2-Country.mmdb")
            .open()
            .unwrap();
        let extractor = Extractor::new().unwrap();
        let log_line = b"Connection from 8.8.8.8 and 2001:4860:4860::8888";
        let matches: Vec<_> = extractor.extract_from_line(log_line).collect();
        assert_eq!(matches.len(), 2, "Should extract 2 IP addresses");
        let result = db.lookup_extracted(&matches[0], log_line).unwrap();
        assert!(
            matches!(result, Some(QueryResult::Ip { .. })),
            "IPv4 should match via lookup_extracted"
        );
        let result = db.lookup_extracted(&matches[1], log_line).unwrap();
        assert!(
            matches!(result, Some(QueryResult::Ip { .. })),
            "IPv6 should match via lookup_extracted"
        );
        let log_line = b"Visit example.com for more info";
        let matches: Vec<_> = extractor.extract_from_line(log_line).collect();
        assert_eq!(matches.len(), 1, "Should extract 1 domain");
        let result = db.lookup_extracted(&matches[0], log_line).unwrap();
        assert!(
            result.is_none() || matches!(result, Some(QueryResult::NotFound)),
            "Domain should not match in IP-only database"
        );
    }
    // Standard MMDB lacks ip_entry_count, so node_count is the fallback.
    #[test]
    fn test_ip_count_returns_node_count_for_standard_mmdb() {
        let db = Database::from("tests/data/GeoLite2-Country.mmdb")
            .open()
            .unwrap();
        let count = db.ip_count();
        assert!(
            count > 0,
            "ip_count() should return node_count for standard MMDB"
        );
        assert!(
            count > 1_000_000,
            "GeoLite2-Country should have > 1M nodes, got {count}"
        );
    }
    // matchy-built databases embed ip_entry_count, which takes precedence.
    #[test]
    fn test_ip_count_prefers_ip_entry_count_when_available() {
        use matchy_format::DatabaseBuilder;
        use matchy_match_mode::MatchMode;
        use std::collections::HashMap;
        let temp_dir = tempfile::TempDir::new().unwrap();
        let output_path = temp_dir.path().join("test.mxy");
        let mut builder = DatabaseBuilder::new(MatchMode::CaseSensitive);
        let mut data1 = HashMap::new();
        data1.insert("test".to_string(), DataValue::String("value1".to_string()));
        builder.add_entry("10.0.0.0/8", data1).unwrap();
        let mut data2 = HashMap::new();
        data2.insert("test".to_string(), DataValue::String("value2".to_string()));
        builder.add_entry("192.168.0.0/16", data2).unwrap();
        let mut data3 = HashMap::new();
        data3.insert("test".to_string(), DataValue::String("value3".to_string()));
        builder.add_entry("172.16.0.0/12", data3).unwrap();
        let db_data = builder.build().unwrap();
        std::fs::write(&output_path, &db_data).unwrap();
        let db = Database::from(output_path.to_str().unwrap())
            .open()
            .unwrap();
        let count = db.ip_count();
        assert_eq!(
            count, 3,
            "ip_count() should return ip_entry_count (3) for matchy-built DB"
        );
    }
    // lookup_ref + decode_at_offset must agree with the full lookup(), even
    // through a live (watch) handle that delegates to a snapshot.
    #[test]
    fn test_lookup_ref_with_auto_reload() {
        let db = Database::from("tests/data/GeoLite2-Country.mmdb")
            .watch()
            .open()
            .unwrap();
        let lookup = db.lookup_ref("1.1.1.1").unwrap();
        assert!(lookup.found, "lookup_ref should find 1.1.1.1 with auto-reload enabled");
        assert_eq!(lookup.result_type, 1, "result_type should be 1 (IP)");
        assert!(lookup.prefix_len > 0);
        let data = db.decode_at_offset(lookup.data_offset).unwrap();
        match data {
            DataValue::Map(map) => assert!(!map.is_empty(), "decoded data should not be empty"),
            _ => panic!("Expected map data from decode_at_offset"),
        }
        let full_result = db.lookup("1.1.1.1").unwrap();
        if let Some(QueryResult::Ip { data: full_data, prefix_len, .. }) = full_result {
            assert_eq!(prefix_len, lookup.prefix_len);
            let ref_data = db.decode_at_offset(lookup.data_offset).unwrap();
            assert_eq!(full_data, ref_data, "lookup_ref+decode should match lookup");
        } else {
            panic!("Full lookup should also find 1.1.1.1");
        }
    }
}