use std::borrow::Cow;
#[cfg(feature = "dfa")]
use aho_corasick::{
Anchored, Input, MatchKind as AcMatchKind,
automaton::{Automaton as _, OverlappingState},
dfa::{Builder as AcDfaBuilder, DFA as AcDfa},
};
use daachorse::{
DoubleArrayAhoCorasick as BytewiseDAACEngine,
DoubleArrayAhoCorasickBuilder as BytewiseDAACBuilder, MatchKind as DAACMatchKind,
charwise::{
CharwiseDoubleArrayAhoCorasick as CharwiseDAACEngine,
CharwiseDoubleArrayAhoCorasickBuilder as CharwiseDAACBuilder,
},
};
use super::{
error::MatcherError,
pattern::{PatternEntry, PatternIndex},
rule::RuleInfo,
};
/// Character-density cutoff consulted by the `dispatch!` macro: a density at or
/// above this value routes the call to the bytewise engine, anything lower goes
/// to the charwise engine.
pub(super) const CHARWISE_DENSITY_THRESHOLD: f32 = 0.55;
/// Computes the ratio of characters to bytes in `text`.
///
/// The character count comes from `bytecount::num_chars` over the raw bytes
/// and is divided by the byte length. An empty string is reported as `1.0`,
/// which avoids a division by zero.
#[inline(always)]
pub(super) fn text_char_density(text: &str) -> f32 {
    let raw = text.as_bytes();
    match raw.len() {
        // Nothing to measure: report the maximum density.
        0 => 1.0,
        total => bytecount::num_chars(raw) as f32 / total as f32,
    }
}
/// Common interface implemented by the automaton backends in this file
/// (`BytewiseMatcher` and `CharwiseMatcher`).
///
/// The `on_value` callbacks are invoked with `(value, start, end)` for each
/// reported hit; returning `true` from the callback stops the scan.
trait ScanEngine {
/// Returns `true` if at least one pattern is found in `text`.
fn is_match(&self, text: &str) -> bool;
/// Feeds every hit found in `text` to `on_value`.
///
/// Returns `true` if the callback asked to stop, `false` if the scan ran
/// to completion without the callback returning `true`.
fn for_each_match_value(
&self,
text: &str,
on_value: impl FnMut(u32, usize, usize) -> bool,
) -> bool;
/// Same contract as `for_each_match_value`, but the haystack is supplied
/// as a stream of bytes instead of a string slice.
fn for_each_match_value_from_iter(
&self,
iter: impl Iterator<Item = u8>,
on_value: impl FnMut(u32, usize, usize) -> bool,
) -> bool;
/// Heap memory usage of the engine as reported by the underlying
/// automaton(s).
fn heap_bytes(&self) -> usize;
}
/// Byte-oriented Aho-Corasick DFA together with the lookup table needed to
/// translate its pattern ids into caller-facing values.
///
/// Only compiled when the `dfa` feature is enabled.
#[cfg(feature = "dfa")]
#[derive(Clone)]
struct BytewiseDFAEngine {
/// The compiled `aho_corasick` DFA.
dfa: AcDfa,
/// Value reported for each DFA pattern, indexed by the pattern id.
dfa_to_value: Vec<u32>,
/// Whether `dfa` carries a prefilter; selects between the library's
/// overlapping search and the hand-written state walk in
/// `for_each_match_value`.
has_prefilter: bool,
}
#[cfg(feature = "dfa")]
impl BytewiseDFAEngine {
    /// Reports whether any pattern occurs in `text`.
    ///
    /// A search error from the DFA is treated the same as "no match".
    fn is_match(&self, text: &str) -> bool {
        matches!(self.dfa.try_find(&Input::new(text)), Ok(Some(_)))
    }

    /// Calls `on_value(value, start, end)` for each overlapping match in `text`.
    ///
    /// When the DFA has a prefilter the library's overlapping search is used;
    /// otherwise the bytes are walked manually via [`Self::scan`]. Returns
    /// `true` as soon as the callback returns `true`, `false` otherwise
    /// (including when the library search reports an error).
    #[inline(always)]
    fn for_each_match_value(
        &self,
        text: &str,
        mut on_value: impl FnMut(u32, usize, usize) -> bool,
    ) -> bool {
        if !self.has_prefilter {
            return self.scan(text.as_bytes(), on_value);
        }

        let haystack = Input::new(text);
        let mut cursor = OverlappingState::start();
        while self
            .dfa
            .try_find_overlapping(&haystack, &mut cursor)
            .is_ok()
        {
            let Some(hit) = cursor.get_match() else {
                break;
            };
            let slot = hit.pattern().as_usize();
            // SAFETY: `dfa_to_value` is expected to hold one entry per pattern
            // compiled into `dfa` (that is how `build_current_bytewise` creates
            // both), so every pattern id the DFA reports is in bounds.
            unsafe { core::hint::assert_unchecked(slot < self.dfa_to_value.len()) };
            if on_value(self.dfa_to_value[slot], hit.start(), hit.end()) {
                return true;
            }
        }
        false
    }

    /// Walks the DFA over a stream of bytes, reporting every pattern that ends
    /// at each position.
    ///
    /// Positions are counted from the first byte yielded by `iter`. Returns
    /// `true` if `on_value` stopped the walk, `false` otherwise (including when
    /// no start state could be obtained).
    #[cfg_attr(feature = "_profile_boundaries", inline(never))]
    #[cfg_attr(not(feature = "_profile_boundaries"), inline(always))]
    fn scan_from_iter(
        &self,
        iter: impl Iterator<Item = u8>,
        mut on_value: impl FnMut(u32, usize, usize) -> bool,
    ) -> bool {
        let mode = Anchored::No;
        let Ok(mut state) = self.dfa.start_state(mode) else {
            return false;
        };

        for (offset, byte) in iter.enumerate() {
            state = self.dfa.next_state(mode, state, byte);
            // Ordinary states need no further inspection.
            if !self.dfa.is_special(state) {
                continue;
            }
            if self.dfa.is_dead(state) {
                break;
            }
            if !self.dfa.is_match(state) {
                continue;
            }

            // The match ends just past the byte that was consumed.
            let end = offset + 1;
            for nth in 0..self.dfa.match_len(state) {
                let pattern = self.dfa.match_pattern(state, nth);
                let start = end - self.dfa.pattern_len(pattern);
                // SAFETY: `dfa_to_value` is expected to hold one entry per
                // pattern compiled into `dfa` (see `build_current_bytewise`),
                // so the pattern id is a valid index.
                let value = unsafe { *self.dfa_to_value.get_unchecked(pattern.as_usize()) };
                if on_value(value, start, end) {
                    return true;
                }
            }
        }
        false
    }

    /// Heap usage: the DFA's own `memory_usage()` plus the allocated capacity
    /// of the pattern-to-value table.
    fn heap_bytes(&self) -> usize {
        let table_bytes = self.dfa_to_value.capacity() * size_of::<u32>();
        self.dfa.memory_usage() + table_bytes
    }

    /// Slice counterpart of [`Self::scan_from_iter`]: walks the DFA over
    /// `text` and reports every pattern that ends at each position.
    ///
    /// Returns `true` if `on_value` stopped the walk, `false` otherwise.
    #[cfg_attr(feature = "_profile_boundaries", inline(never))]
    #[cfg_attr(not(feature = "_profile_boundaries"), inline(always))]
    fn scan(&self, text: &[u8], mut on_value: impl FnMut(u32, usize, usize) -> bool) -> bool {
        let mode = Anchored::No;
        let Ok(mut state) = self.dfa.start_state(mode) else {
            return false;
        };

        for (offset, byte) in text.iter().copied().enumerate() {
            state = self.dfa.next_state(mode, state, byte);
            // Ordinary states need no further inspection.
            if !self.dfa.is_special(state) {
                continue;
            }
            if self.dfa.is_dead(state) {
                break;
            }
            if !self.dfa.is_match(state) {
                continue;
            }

            // The match ends just past the byte that was consumed.
            let end = offset + 1;
            for nth in 0..self.dfa.match_len(state) {
                let pattern = self.dfa.match_pattern(state, nth);
                let start = end - self.dfa.pattern_len(pattern);
                // SAFETY: `dfa_to_value` is expected to hold one entry per
                // pattern compiled into `dfa` (see `build_current_bytewise`),
                // so the pattern id is a valid index.
                let value = unsafe { *self.dfa_to_value.get_unchecked(pattern.as_usize()) };
                if on_value(value, start, end) {
                    return true;
                }
            }
        }
        false
    }
}
/// Byte-oriented matcher: a `daachorse` double-array automaton and, when the
/// `dfa` feature is enabled, an additional `aho_corasick` DFA engine.
#[derive(Clone)]
struct BytewiseMatcher {
/// Bytewise double-array Aho-Corasick automaton carrying `u32` values.
daac: BytewiseDAACEngine<u32>,
/// DFA-backed engine; present only with the `dfa` feature.
#[cfg(feature = "dfa")]
dfa_engine: BytewiseDFAEngine,
}
impl ScanEngine for BytewiseMatcher {
    /// Returns `true` if any pattern occurs in `text`.
    ///
    /// Uses the DFA engine when the `dfa` feature is enabled and the
    /// double-array automaton otherwise.
    #[inline(always)]
    fn is_match(&self, text: &str) -> bool {
        #[cfg(feature = "dfa")]
        let found = self.dfa_engine.is_match(text);
        #[cfg(not(feature = "dfa"))]
        let found = self.daac.find_iter(text).next().is_some();
        found
    }

    /// Reports each overlapping match in `text` to `on_value` as
    /// `(value, start, end)`; returns `true` if the callback stopped the scan.
    #[inline(always)]
    fn for_each_match_value(
        &self,
        text: &str,
        on_value: impl FnMut(u32, usize, usize) -> bool,
    ) -> bool {
        #[cfg(feature = "dfa")]
        {
            self.dfa_engine.for_each_match_value(text, on_value)
        }
        #[cfg(not(feature = "dfa"))]
        {
            // Rebound locally so the parameter itself needs no `mut` in the
            // `dfa` configuration.
            let mut on_value = on_value;
            self.daac
                .find_overlapping_iter(text)
                .any(|hit| on_value(hit.value(), hit.start(), hit.end()))
        }
    }

    /// Streaming variant of `for_each_match_value`.
    ///
    /// With the `dfa` feature and no prefilter on the DFA, the DFA's manual
    /// state walk consumes `iter`; in every other case the double-array
    /// automaton does.
    #[inline(always)]
    fn for_each_match_value_from_iter(
        &self,
        iter: impl Iterator<Item = u8>,
        mut on_value: impl FnMut(u32, usize, usize) -> bool,
    ) -> bool {
        #[cfg(feature = "dfa")]
        if !self.dfa_engine.has_prefilter {
            return self.dfa_engine.scan_from_iter(iter, on_value);
        }
        self.daac
            .find_overlapping_iter_from_iter(iter)
            .any(|hit| on_value(hit.value(), hit.start(), hit.end()))
    }

    /// Heap usage of the double-array automaton plus, with the `dfa` feature,
    /// that of the DFA engine.
    fn heap_bytes(&self) -> usize {
        #[cfg(feature = "dfa")]
        let dfa_bytes = self.dfa_engine.heap_bytes();
        #[cfg(not(feature = "dfa"))]
        let dfa_bytes = 0usize;
        self.daac.heap_bytes() + dfa_bytes
    }
}
/// Character-wise double-array Aho-Corasick automaton carrying `u32` values.
type CharwiseMatcher = CharwiseDAACEngine<u32>;
impl ScanEngine for CharwiseMatcher {
    /// Returns `true` if the automaton finds at least one match in `text`.
    fn is_match(&self, text: &str) -> bool {
        let mut hits = self.find_iter(text);
        hits.next().is_some()
    }

    /// Reports each overlapping match in `text` to `on_value` as
    /// `(value, start, end)`; returns `true` if the callback stopped the scan.
    #[inline(always)]
    fn for_each_match_value(
        &self,
        text: &str,
        mut on_value: impl FnMut(u32, usize, usize) -> bool,
    ) -> bool {
        self.find_overlapping_iter(text)
            .any(|hit| on_value(hit.value(), hit.start(), hit.end()))
    }

    /// Streaming variant of `for_each_match_value` that consumes raw bytes.
    #[inline(always)]
    fn for_each_match_value_from_iter(
        &self,
        iter: impl Iterator<Item = u8>,
        mut on_value: impl FnMut(u32, usize, usize) -> bool,
    ) -> bool {
        // SAFETY: NOTE(review) - the library marks this call `unsafe`;
        // presumably `iter` must yield a valid UTF-8 byte sequence. Nothing in
        // this function checks that, so it has to be guaranteed by the callers
        // - confirm at the call sites.
        let mut hits = unsafe { self.find_overlapping_iter_from_iter(iter) };
        hits.any(|hit| on_value(hit.value(), hit.start(), hit.end()))
    }

    /// Heap usage as reported by the automaton's own `heap_bytes`.
    fn heap_bytes(&self) -> usize {
        // Call the inherent method explicitly rather than this trait method.
        CharwiseDAACEngine::heap_bytes(self)
    }
}
/// The pair of automata a scan can be dispatched to (see `dispatch!`).
#[derive(Clone)]
struct Engines {
/// Byte-oriented matcher, chosen for densities at or above
/// `CHARWISE_DENSITY_THRESHOLD`.
bytewise: BytewiseMatcher,
/// Character-oriented matcher, chosen for densities below the threshold.
charwise: CharwiseMatcher,
}
impl Engines {
    /// Reports whether the bytewise DFA carries a prefilter.
    ///
    /// Always `false` when the `dfa` feature is disabled.
    fn has_dfa_prefilter(&self) -> bool {
        #[cfg(feature = "dfa")]
        let prefiltered = self.bytewise.dfa_engine.has_prefilter;
        #[cfg(not(feature = "dfa"))]
        let prefiltered = false;
        prefiltered
    }
}
/// Routes a `ScanEngine` method call to one of the two engines in `$engines`.
///
/// `$density >= CHARWISE_DENSITY_THRESHOLD` selects `bytewise`; any other
/// outcome of the comparison (including a NaN density) selects `charwise`.
/// The remaining arguments are forwarded unchanged to `$method`.
macro_rules! dispatch {
($engines:expr, $density:expr, $method:ident ($($arg:expr),*)) => {
if $density >= CHARWISE_DENSITY_THRESHOLD {
ScanEngine::$method(&$engines.bytewise, $($arg),*)
} else {
ScanEngine::$method(&$engines.charwise, $($arg),*)
}
};
}
/// Compiled scanning state: the automata plus the pattern index they were
/// built from.
#[derive(Clone)]
pub(super) struct ScanPlan {
/// Bytewise and charwise automata used for scanning.
engines: Engines,
/// Pattern index created from the deduplicated pattern entries.
patterns: PatternIndex,
}
impl ScanPlan {
    /// Builds a plan from deduplicated patterns and their entries.
    ///
    /// The pattern index is created first, a value map is derived from it with
    /// `rule_info`, and both automata are then compiled from
    /// `dedup_patterns` and that map.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `compile_automata`.
    pub(super) fn compile(
        dedup_patterns: &[Cow<'_, str>],
        dedup_entries: Vec<Vec<PatternEntry>>,
        rule_info: &[RuleInfo],
    ) -> Result<Self, MatcherError> {
        debug_assert!(
            !dedup_patterns.is_empty(),
            "ScanPlan::compile called with zero patterns"
        );
        let patterns = PatternIndex::new(dedup_entries);
        let value_map = patterns.build_value_map(rule_info);
        compile_automata(dedup_patterns, &value_map).map(|engines| Self { engines, patterns })
    }

    /// Borrows the pattern index stored in this plan.
    pub(super) fn patterns(&self) -> &PatternIndex {
        &self.patterns
    }

    /// Sum of the heap usage reported by both automata and the pattern index.
    pub(super) fn heap_bytes(&self) -> usize {
        let automata_bytes =
            self.engines.bytewise.heap_bytes() + self.engines.charwise.heap_bytes();
        automata_bytes + self.patterns.heap_bytes()
    }

    /// Returns `true` if any pattern occurs in `text`, picking the engine from
    /// the text's character density.
    #[inline(always)]
    pub(super) fn is_match(&self, text: &str) -> bool {
        dispatch!(self.engines, text_char_density(text), is_match(text))
    }

    /// Reports each match in `text` to `on_value` as `(value, start, end)`.
    ///
    /// `density` selects the engine; it is supplied by the caller (presumably
    /// computed with `text_char_density`). Returns `true` if the callback
    /// stopped the scan.
    #[inline(always)]
    pub(super) fn for_each_match_value(
        &self,
        text: &str,
        density: f32,
        on_value: impl FnMut(u32, usize, usize) -> bool,
    ) -> bool {
        dispatch!(self.engines, density, for_each_match_value(text, on_value))
    }

    /// Whether the bytewise DFA carries a prefilter (see
    /// `Engines::has_dfa_prefilter`).
    pub(super) fn has_dfa_prefilter(&self) -> bool {
        self.engines.has_dfa_prefilter()
    }

    /// Streaming variant of `for_each_match_value`: the haystack arrives as a
    /// byte iterator and `density` selects the engine.
    #[inline(always)]
    pub(super) fn for_each_match_value_from_iter(
        &self,
        iter: impl Iterator<Item = u8>,
        density: f32,
        on_value: impl FnMut(u32, usize, usize) -> bool,
    ) -> bool {
        dispatch!(
            self.engines,
            density,
            for_each_match_value_from_iter(iter, on_value)
        )
    }
}
/// Compiles the bytewise and charwise automata for `dedup_patterns`.
///
/// Pattern `i` is paired with `value_map[i]`. The bytewise build runs on a
/// scoped worker thread while the charwise automaton is built on the calling
/// thread.
///
/// # Errors
///
/// Returns the mapped build error of whichever automaton fails to build.
///
/// # Panics
///
/// Panics if `value_map` is shorter than `dedup_patterns`, or if the worker
/// thread panicked.
#[optimize(speed)]
fn compile_automata(
    dedup_patterns: &[Cow<'_, str>],
    value_map: &[u32],
) -> Result<Engines, MatcherError> {
    // Owned pair list handed over to the worker thread.
    let mut all_patvals: Vec<(&str, u32)> = Vec::with_capacity(dedup_patterns.len());
    for (idx, pattern) in dedup_patterns.iter().enumerate() {
        all_patvals.push((pattern.as_ref(), value_map[idx]));
    }

    std::thread::scope(|scope| {
        let worker = scope.spawn(|| build_current_bytewise(all_patvals));

        let charwise = CharwiseDAACBuilder::new()
            .match_kind(DAACMatchKind::Standard)
            .build_with_values(
                dedup_patterns
                    .iter()
                    .enumerate()
                    .map(|(idx, pattern)| (pattern.as_ref(), value_map[idx])),
            )
            .map_err(MatcherError::automaton_build)?;

        let joined = worker.join().expect("bytewise automaton build panicked");
        let bytewise = joined?;
        Ok(Engines { bytewise, charwise })
    })
}
/// Builds the bytewise matcher from `(pattern, value)` pairs.
///
/// With the `dfa` feature, an `aho_corasick` DFA is compiled first from the
/// patterns, together with a table holding each pair's value in pair order.
/// The double-array automaton is then built from the same pairs.
///
/// # Errors
///
/// Returns the mapped build error of the DFA or of the double-array automaton.
fn build_current_bytewise(all_patvals: Vec<(&str, u32)>) -> Result<BytewiseMatcher, MatcherError> {
    #[cfg(feature = "dfa")]
    let dfa_engine = {
        let dfa_to_value: Vec<u32> = all_patvals.iter().map(|&(_, v)| v).collect();
        let dfa = AcDfaBuilder::new()
            .match_kind(AcMatchKind::Standard)
            .build(all_patvals.iter().map(|(p, _)| p))
            .map_err(MatcherError::automaton_build)?;
        BytewiseDFAEngine {
            // Read the prefilter flag before `dfa` is moved into the struct.
            has_prefilter: dfa.prefilter().is_some(),
            dfa,
            dfa_to_value,
        }
    };

    let daac = BytewiseDAACBuilder::new()
        .match_kind(DAACMatchKind::Standard)
        .build_with_values(all_patvals)
        .map_err(MatcherError::automaton_build)?;

    Ok(BytewiseMatcher {
        daac,
        #[cfg(feature = "dfa")]
        dfa_engine,
    })
}