mod classify;
mod plan;
use std::collections::HashSet;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{LazyLock, Mutex};
use thiserror::Error;
use classify::Classified;
use plan::Plan;
/// Cost class a pattern compiles to.
///
/// Compare tiers with [`MatcherTier::rank`]: a higher rank is a cheaper tier with a
/// smaller [`typical_budget_ns`](MatcherTier::typical_budget_ns).
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MatcherTier {
    /// Rank 4; typical budget 30 ns.
    Byte,
    /// Rank 3; typical budget 200 ns.
    Literal,
    /// Rank 2; typical budget 500 ns.
    LiteralSet,
    /// Rank 1; no fixed budget.
    Regex,
}

impl std::fmt::Display for MatcherTier {
    /// Writes the variant name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` (unlike `write_str`) honours width/fill/alignment flags such as
        // `{:>10}`, so tier names line up in tabular output.
        f.pad(match self {
            Self::Byte => "Byte",
            Self::Literal => "Literal",
            Self::LiteralSet => "LiteralSet",
            Self::Regex => "Regex",
        })
    }
}

impl MatcherTier {
    /// Ordinal used to compare tiers; higher is better.
    ///
    /// The builder treats a pattern as below its configured minimum when
    /// `tier.rank() < min.rank()`.
    #[must_use]
    pub const fn rank(self) -> u8 {
        match self {
            Self::Byte => 4,
            Self::Literal => 3,
            Self::LiteralSet => 2,
            Self::Regex => 1,
        }
    }

    /// Indicative per-call latency budget for this tier, in nanoseconds.
    ///
    /// Returns `None` for [`MatcherTier::Regex`], which has no fixed budget.
    #[must_use]
    pub const fn typical_budget_ns(self) -> Option<u64> {
        match self {
            Self::Byte => Some(30),
            Self::Literal => Some(200),
            Self::LiteralSet => Some(500),
            Self::Regex => None,
        }
    }
}
/// Location of one match produced by a [`StrMatcher`].
///
/// Offsets are byte positions into the searched haystack (presumably the
/// half-open range `start..end`, like the Aho-Corasick hits used by
/// `StrMatcherSet` — confirm against the plan implementation).
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Match {
    /// Byte offset where the match begins.
    pub start: usize,
    /// Byte offset where the match ends.
    pub end: usize,
}
/// Location of one match produced by a [`StrMatcherSet`], tagged with the
/// pattern that produced it.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct SetMatch {
    /// Byte offset where the match begins.
    pub start: usize,
    /// Byte offset where the match ends.
    pub end: usize,
    /// Position of the matching pattern in the input given to the set builder.
    pub pattern_idx: usize,
}
/// Errors returned while compiling a pattern into a [`StrMatcher`] or [`StrMatcherSet`].
///
/// Each message starts with a `strmatch:` headline followed by newline-separated,
/// one-space-indented detail lines (`reason:` / `hint:`).
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum BuildError {
/// The pattern was empty.
#[error(
"strmatch: empty pattern\n \
hint: an empty pattern matches every position; pass a non-empty \
literal or wrap the matcher in Option for an absent-matcher slot"
)]
Empty,
/// The pattern could not be parsed as a regex.
#[error(
"strmatch: regex syntax error in pattern {pattern:?}\n \
reason: {source}\n \
hint: {hint}"
)]
Syntax {
/// The offending pattern text.
pattern: String,
/// Parser error from `regex_syntax`, boxed; exposed as this error's `source()`.
#[source]
source: Box<regex_syntax::Error>,
/// Static remediation hint shown to the user.
hint: &'static str,
},
/// The pattern compiled below the builder's configured minimum tier while the
/// `OnBelowMin::Reject` policy was in effect.
#[error(
"strmatch: pattern {pattern:?} compiles to tier {got}, but builder \
requires at least {wanted}\n \
reason: {reason}\n \
hint: {hint}"
)]
TierTooLow {
/// The offending pattern text.
pattern: String,
/// Minimum tier configured on the builder.
wanted: MatcherTier,
/// Tier the pattern actually compiled to.
got: MatcherTier,
/// Classifier's explanation for the chosen tier.
reason: &'static str,
/// Classifier's suggestion for reaching a better tier.
hint: &'static str,
},
}
/// Policy applied when a pattern compiles below the builder's minimum tier.
///
/// Only consulted when a minimum tier has been set on the builder.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub enum OnBelowMin {
    /// Build the matcher anyway (the default).
    #[default]
    Allow,
    /// Build the matcher anyway and always emit a warning.
    Warn,
    /// Fail with [`BuildError::TierTooLow`].
    Reject,
}
/// A single compiled pattern together with the tier it was compiled to.
///
/// Build one with [`StrMatcher::new`] or through a [`StrMatcherBuilder`].
pub struct StrMatcher {
    /// Executable search plan chosen by the classifier.
    plan: Plan,
    /// Tier the pattern compiled to.
    tier: MatcherTier,
    /// Source text the matcher was built from.
    pattern: String,
    /// Classifier's explanation for the chosen tier.
    reason: &'static str,
}
/// Shows the tier, pattern, and reason; the compiled plan is omitted.
impl std::fmt::Debug for StrMatcher {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("StrMatcher");
        out.field("tier", &self.tier);
        out.field("pattern", &self.pattern);
        out.field("reason", &self.reason);
        out.finish_non_exhaustive()
    }
}
impl StrMatcher {
    /// Compile `pattern` using a default [`StrMatcherBuilder`].
    ///
    /// # Errors
    ///
    /// Propagates any [`BuildError`] returned by [`StrMatcherBuilder::build`].
    pub fn new(pattern: &str) -> Result<Self, BuildError> {
        StrMatcherBuilder::new().build(pattern)
    }

    /// Start configuring a matcher (minimum tier, case sensitivity, ...).
    #[must_use]
    pub fn builder() -> StrMatcherBuilder {
        StrMatcherBuilder::default()
    }

    /// Report whether the pattern occurs anywhere in `hay`.
    #[must_use]
    #[inline]
    pub fn is_match(&self, hay: &[u8]) -> bool {
        self.plan.is_match(hay)
    }

    /// Return the match the compiled plan reports first, if any.
    #[must_use]
    #[inline]
    pub fn find(&self, hay: &[u8]) -> Option<Match> {
        self.plan.find(hay)
    }

    /// Collect every match in `hay` up front and hand them back as an owning iterator.
    #[must_use]
    pub fn find_iter(&self, hay: &[u8]) -> std::vec::IntoIter<Match> {
        let mut matches: Vec<Match> = Vec::new();
        self.plan.collect_matches(hay, &mut matches);
        matches.into_iter()
    }

    /// Tier this pattern compiled to.
    #[must_use]
    pub fn tier(&self) -> MatcherTier {
        self.tier
    }

    /// Source text the matcher was built from.
    #[must_use]
    pub fn pattern(&self) -> &str {
        self.pattern.as_str()
    }

    /// Classifier's explanation for the chosen tier.
    #[must_use]
    pub fn reason(&self) -> &'static str {
        self.reason
    }
}
/// Configures how patterns are compiled into [`StrMatcher`]s and [`StrMatcherSet`]s.
#[derive(Clone, Debug, Default)]
pub struct StrMatcherBuilder {
    /// Minimum acceptable tier; `None` disables the check.
    min_tier: Option<MatcherTier>,
    /// What to do with a pattern that compiles below `min_tier`.
    on_below_min: OnBelowMin,
    /// Whether matching ignores ASCII case.
    ascii_case_insensitive: bool,
}
impl StrMatcherBuilder {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn min_tier(mut self, tier: MatcherTier) -> Self {
self.min_tier = Some(tier);
self
}
#[must_use]
pub fn on_below_min(mut self, policy: OnBelowMin) -> Self {
self.on_below_min = policy;
self
}
#[must_use]
pub fn ascii_case_insensitive(mut self, enabled: bool) -> Self {
self.ascii_case_insensitive = enabled;
self
}
pub fn build(&self, pattern: &str) -> Result<StrMatcher, BuildError> {
let classified = classify::classify(pattern, self.ascii_case_insensitive)?;
let tier = classified.descriptor.tier;
if let Some(min) = self.min_tier
&& tier.rank() < min.rank()
{
return self.apply_below_min(pattern, min, classified);
}
if tier == MatcherTier::Regex {
warn::on_regex_fallback(
pattern,
classified.descriptor.reason,
classified.descriptor.hint,
false,
);
metrics_inc_fallback();
}
Ok(StrMatcher {
plan: classified.plan,
tier,
pattern: pattern.to_string(),
reason: classified.descriptor.reason,
})
}
fn apply_below_min(
&self,
pattern: &str,
wanted: MatcherTier,
classified: Classified,
) -> Result<StrMatcher, BuildError> {
let got = classified.descriptor.tier;
let reason = classified.descriptor.reason;
let hint = classified.descriptor.hint;
match self.on_below_min {
OnBelowMin::Reject => Err(BuildError::TierTooLow {
pattern: pattern.to_string(),
wanted,
got,
reason,
hint,
}),
OnBelowMin::Warn => {
warn::on_regex_fallback(pattern, reason, hint, true);
metrics_inc_fallback();
Ok(StrMatcher {
plan: classified.plan,
tier: got,
pattern: pattern.to_string(),
reason,
})
}
OnBelowMin::Allow => {
if got == MatcherTier::Regex {
warn::on_regex_fallback(pattern, reason, hint, false);
metrics_inc_fallback();
}
Ok(StrMatcher {
plan: classified.plan,
tier: got,
pattern: pattern.to_string(),
reason,
})
}
}
}
}
/// A compiled collection of patterns searched together.
///
/// Patterns whose plans reduce to mergeable literals (with the builder's case
/// sensitivity) are normally folded into one Aho-Corasick automaton; every other
/// pattern keeps its own [`StrMatcher`]. Reported `pattern_idx` values are
/// positions in the builder's input.
pub struct StrMatcherSet {
    /// Shared automaton over all folded literals, if any pattern was folded.
    merged: Option<MergedAc>,
    /// Patterns matched one by one, tagged with their input index.
    individual: Vec<(usize, StrMatcher)>,
    /// Tier of every input pattern, in input order.
    tiers: Vec<MatcherTier>,
}

/// Aho-Corasick automaton over literals drawn from several set patterns.
struct MergedAc {
    ac: aho_corasick::AhoCorasick,
    /// Maps an automaton pattern ID (literal position) to the owning set pattern index.
    pattern_indices: Vec<usize>,
}

/// Summarises the set without dumping automaton internals, mirroring
/// `StrMatcher`'s `Debug` output.
impl std::fmt::Debug for StrMatcherSet {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StrMatcherSet")
            .field("tiers", &self.tiers)
            .field(
                "merged_literals",
                &self.merged.as_ref().map_or(0, |m| m.pattern_indices.len()),
            )
            .field("individual", &self.individual)
            .finish_non_exhaustive()
    }
}
impl StrMatcherSet {
    /// Compile `patterns` using a default [`StrMatcherBuilder`].
    ///
    /// # Errors
    ///
    /// Propagates the first [`BuildError`] produced by any pattern.
    pub fn new<I, S>(patterns: I) -> Result<Self, BuildError>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        StrMatcherBuilder::new().build_set(patterns)
    }

    /// Start configuring a set (minimum tier, case sensitivity, ...).
    #[must_use]
    pub fn builder() -> StrMatcherBuilder {
        StrMatcherBuilder::default()
    }

    /// Report whether any pattern in the set occurs in `hay`.
    ///
    /// The merged literal automaton is consulted first, then each individual matcher.
    #[must_use]
    pub fn is_match(&self, hay: &[u8]) -> bool {
        let merged_hit = self
            .merged
            .as_ref()
            .is_some_and(|merged| merged.ac.find(hay).is_some());
        merged_hit || self.individual.iter().any(|(_, matcher)| matcher.is_match(hay))
    }

    /// Return the match with the smallest start offset across all patterns.
    ///
    /// When an individual matcher ties on start, the lower `pattern_idx` wins.
    /// Literals inside the merged automaton resolve leftmost-longest among themselves.
    #[must_use]
    pub fn earliest_match(&self, hay: &[u8]) -> Option<SetMatch> {
        let from_merged = self.merged.as_ref().and_then(|merged| {
            merged.ac.find(hay).map(|hit| SetMatch {
                start: hit.start(),
                end: hit.end(),
                pattern_idx: merged.pattern_indices[hit.pattern().as_usize()],
            })
        });
        self.individual
            .iter()
            .filter_map(|(idx, matcher)| {
                matcher.find(hay).map(|found| SetMatch {
                    start: found.start,
                    end: found.end,
                    pattern_idx: *idx,
                })
            })
            .fold(from_merged, |best, candidate| match best {
                // Keep the incumbent unless the candidate is strictly earlier by
                // (start, pattern_idx).
                Some(current)
                    if (current.start, current.pattern_idx)
                        <= (candidate.start, candidate.pattern_idx) =>
                {
                    Some(current)
                }
                _ => Some(candidate),
            })
    }

    /// Collect all matches, ordered by `(start, pattern_idx)`.
    ///
    /// Merged literals are reported as one non-overlapping stream across every
    /// merged pattern; each individual matcher contributes its own matches.
    #[must_use]
    pub fn find_iter(&self, hay: &[u8]) -> std::vec::IntoIter<SetMatch> {
        let mut hits: Vec<SetMatch> = Vec::new();
        if let Some(merged) = self.merged.as_ref() {
            hits.extend(merged.ac.find_iter(hay).map(|hit| SetMatch {
                start: hit.start(),
                end: hit.end(),
                pattern_idx: merged.pattern_indices[hit.pattern().as_usize()],
            }));
        }
        for (idx, matcher) in &self.individual {
            hits.extend(matcher.find_iter(hay).map(|found| SetMatch {
                start: found.start,
                end: found.end,
                pattern_idx: *idx,
            }));
        }
        // Stable sort: equal keys keep their collection order.
        hits.sort_by_key(|hit| (hit.start, hit.pattern_idx));
        hits.into_iter()
    }

    /// Count patterns per tier, indexed `[Byte, Literal, LiteralSet, Regex]`.
    #[must_use]
    pub fn tier_counts(&self) -> [usize; 4] {
        self.tiers.iter().fold([0_usize; 4], |mut acc, tier| {
            let slot = match tier {
                MatcherTier::Byte => 0,
                MatcherTier::Literal => 1,
                MatcherTier::LiteralSet => 2,
                MatcherTier::Regex => 3,
            };
            acc[slot] += 1;
            acc
        })
    }

    /// Number of patterns the set was built from.
    #[must_use]
    pub fn len(&self) -> usize {
        self.tiers.len()
    }

    /// Whether the set was built from zero patterns.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.tiers.is_empty()
    }
}
impl StrMatcherBuilder {
pub fn build_set<I, S>(&self, patterns: I) -> Result<StrMatcherSet, BuildError>
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
let mut individual: Vec<(usize, StrMatcher)> = Vec::new();
let mut merge_lits: Vec<Vec<u8>> = Vec::new();
let mut pattern_indices: Vec<usize> = Vec::new();
let mut tiers: Vec<MatcherTier> = Vec::new();
for (idx, pat) in patterns.into_iter().enumerate() {
let m = self.build(pat.as_ref())?;
tiers.push(m.tier);
match m.plan.merge_literals() {
Some((lits, ci)) if ci == self.ascii_case_insensitive => {
for lit in lits {
pattern_indices.push(idx);
merge_lits.push(lit);
}
}
_ => individual.push((idx, m)),
}
}
let merged = (!merge_lits.is_empty()).then(|| {
let ac = aho_corasick::AhoCorasickBuilder::new()
.ascii_case_insensitive(self.ascii_case_insensitive)
.match_kind(aho_corasick::MatchKind::LeftmostLongest)
.build(&merge_lits)
.expect("strmatch: merged AC build failed on a pre-validated literal set");
MergedAc {
ac,
pattern_indices,
}
});
Ok(StrMatcherSet {
merged,
individual,
tiers,
})
}
}
/// Rate-limited reporting of patterns that fall back to the regex tier.
///
/// Non-forced reports are de-duplicated by pattern hash. The first `WARN_CAP`
/// distinct patterns log at WARN and later ones at DEBUG. A one-time INFO summary
/// is emitted when the cap is first exceeded.
mod warn {
use super::{AtomicBool, AtomicUsize, HashSet, LazyLock, Mutex, Ordering};
/// Number of distinct (non-forced) fallback patterns logged at WARN before
/// further ones are demoted to DEBUG.
pub(super) const WARN_CAP: usize = 10;
/// Hashes of every pattern already reported, whether forced or not.
static WARNED_HASHES: LazyLock<Mutex<HashSet<u64>>> =
LazyLock::new(|| Mutex::new(HashSet::new()));
/// Count of distinct patterns reported through the non-forced path.
static DISTINCT: AtomicUsize = AtomicUsize::new(0);
/// Set once the "cap exceeded" INFO summary has been logged.
static SUMMARY_EMITTED: AtomicBool = AtomicBool::new(false);
/// Hash a pattern for de-duplication.
///
/// Two distinct patterns that collide on the 64-bit hash are treated as the same
/// pattern.
fn hash_pattern(pattern: &str) -> u64 {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut h = DefaultHasher::new();
pattern.hash(&mut h);
h.finish()
}
/// Report that `pattern` compiled to the regex tier.
///
/// * `force == false`: a no-op for patterns already seen. Otherwise the pattern is
///   counted and logged at WARN while within `WARN_CAP`, and at DEBUG afterwards.
/// * `force == true`: always logs at WARN, ignoring de-duplication and the cap.
///   It does not advance the distinct-pattern counter and never emits the summary.
pub(super) fn on_regex_fallback(
pattern: &str,
reason: &'static str,
hint: &'static str,
force: bool,
) {
let h = hash_pattern(pattern);
// Record the hash even when forced, so a later non-forced report of the same
// pattern counts as already warned. A poisoned lock is recovered rather than
// propagated.
let already_warned = {
let mut set = WARNED_HASHES.lock().unwrap_or_else(|e| e.into_inner());
!set.insert(h)
};
if already_warned && !force {
return;
}
// `fetch_add` hands each newly seen pattern a unique 1-based ordinal; forced
// reports only peek at the counter.
let n = if force {
DISTINCT.load(Ordering::Relaxed)
} else {
DISTINCT.fetch_add(1, Ordering::Relaxed) + 1
};
if force || n <= WARN_CAP {
tracing::warn!(
target: "hyperi_rustlib::strmatch",
pattern,
reason,
hint,
"pattern falls through to regex engine on hot path"
);
} else {
tracing::debug!(
target: "hyperi_rustlib::strmatch",
pattern,
reason,
hint,
"regex fallback (WARN suppressed past cap)"
);
}
// Only the caller holding ordinal WARN_CAP + 1 reaches the summary. The swap
// additionally guarantees it is logged at most once.
if !force && n == WARN_CAP + 1 && !SUMMARY_EMITTED.swap(true, Ordering::Relaxed) {
tracing::info!(
target: "hyperi_rustlib::strmatch",
cap = WARN_CAP,
"{}+ distinct patterns have fallen through to the regex engine; \
further fall-throughs log at DEBUG. Inspect StrMatcher::tier() / \
StrMatcherSet::tier_counts() at runtime, or scrape the \
hyperi_strmatch_regex_fallback_total metric.",
WARN_CAP,
);
}
}
/// Clear the de-dup set, the distinct counter, and the summary flag so tests
/// start from a clean state.
#[cfg(test)]
pub(super) fn reset_for_tests() {
WARNED_HASHES
.lock()
.unwrap_or_else(|e| e.into_inner())
.clear();
DISTINCT.store(0, Ordering::Relaxed);
SUMMARY_EMITTED.store(false, Ordering::Relaxed);
}
}
/// Bump the `hyperi_strmatch_regex_fallback_total` counter by one.
///
/// Compiles to a no-op unless the `metrics` feature is enabled.
#[inline]
fn metrics_inc_fallback() {
    #[cfg(feature = "metrics")]
    {
        let fallbacks = metrics::counter!("hyperi_strmatch_regex_fallback_total");
        fallbacks.increment(1);
    }
}
/// Test-only hook that clears the regex-fallback warning state: the de-dup set,
/// the distinct-pattern counter, and the summary flag.
#[doc(hidden)]
#[cfg(test)]
pub fn reset_warn_state_for_tests() {
    self::warn::reset_for_tests()
}
#[cfg(test)]
mod tests;