use std::borrow::Cow;
#[cfg(feature = "dfa")]
use aho_corasick::{
AhoCorasick as AcEngine, AhoCorasickBuilder, AhoCorasickKind, MatchKind as AhoCorasickMatchKind,
};
use daachorse::{
DoubleArrayAhoCorasick, DoubleArrayAhoCorasickBuilder,
MatchKind as DoubleArrayAhoCorasickMatchKind,
charwise::{CharwiseDoubleArrayAhoCorasick, CharwiseDoubleArrayAhoCorasickBuilder},
};
use crate::MatcherError;
use super::encoding::{DIRECT_BOUNDARY_MASK, DIRECT_BOUNDARY_SHIFT, DIRECT_RULE_MASK};
use super::pattern::{PatternEntry, PatternIndex};
/// Non-ASCII density above which `ScanPlan` routes a scan to the charwise
/// automaton; at or below it the bytewise automaton is used.
pub(super) const CHARWISE_DENSITY_THRESHOLD: f32 = 0.67;

/// Returns the fraction of `text`'s bytes that `super::simd::count_non_ascii_simd`
/// counts as non-ASCII. An empty string yields `0.0`.
///
/// The ratio is computed in `f32`, so for very long inputs it is only
/// approximate; callers that need an exact "no ASCII bytes at all" answer
/// re-check the bytes themselves.
#[inline(always)]
pub(super) fn text_non_ascii_density(text: &str) -> f32 {
    let total = text.len();
    if total == 0 {
        0.0
    } else {
        super::simd::count_non_ascii_simd(text.as_bytes()) as f32 / total as f32
    }
}
/// Compiled matching plan for a deduplicated pattern set.
///
/// Holds two automata built over the same patterns: a byte-oriented one and a
/// `char`-oriented one. The scanning methods pick one of them per call. Both are
/// `None` when the pattern set is empty.
#[derive(Clone)]
pub(super) struct ScanPlan {
    /// Pattern bookkeeping built from the deduplicated entries; it also supplies
    /// the per-pattern values stored in the automata.
    patterns: PatternIndex,
    /// Whether every deduplicated pattern is pure ASCII. This enables the
    /// shortcut that rejects text containing no ASCII bytes.
    all_patterns_ascii: bool,
    /// Byte-oriented matcher (daachorse, plus an optional DFA).
    bytewise_matcher: Option<BytewiseMatcher>,
    /// `char`-oriented daachorse automaton, used for non-ASCII-heavy text.
    charwise_matcher: Option<CharwiseMatcher>,
}

/// Byte-oriented matcher over the deduplicated patterns.
#[derive(Clone)]
struct BytewiseMatcher {
    /// With the `dfa` feature: an aho-corasick engine plus a table mapping each
    /// aho-corasick pattern id to that pattern's `u32` value. When present it is
    /// used instead of `daac` for slice-based scans.
    #[cfg(feature = "dfa")]
    dfa: Option<(AcEngine, Vec<u32>)>,
    /// daachorse double-array automaton. It is always built and is the only
    /// engine used for iterator-based scans.
    daac: DoubleArrayAhoCorasick<u32>,
}

/// `char`-oriented daachorse automaton carrying the same `u32` values as the
/// bytewise matcher.
type CharwiseMatcher = CharwiseDoubleArrayAhoCorasick<u32>;
impl ScanPlan {
pub(super) fn compile(
dedup_patterns: &[Cow<'_, str>],
dedup_entries: Vec<Vec<PatternEntry>>,
) -> Result<Self, MatcherError> {
let patterns = PatternIndex::new(dedup_entries);
let value_map = patterns.build_value_map();
let (bytewise_matcher, charwise_matcher) = compile_automata(dedup_patterns, &value_map)?;
let all_patterns_ascii = dedup_patterns.iter().all(|p| p.is_ascii());
Ok(Self {
bytewise_matcher,
charwise_matcher,
all_patterns_ascii,
patterns,
})
}
pub(super) fn patterns(&self) -> &PatternIndex {
&self.patterns
}
#[inline(always)]
pub(super) fn has_dfa(&self) -> bool {
#[cfg(feature = "dfa")]
{
self.bytewise_matcher
.as_ref()
.is_some_and(|m| m.dfa.is_some())
}
#[cfg(not(feature = "dfa"))]
{
false
}
}
pub(super) fn heap_bytes(&self) -> usize {
let bw = self.bytewise_matcher.as_ref().map_or(0, |m| m.heap_bytes());
let cw = self.charwise_matcher.as_ref().map_or(0, |m| m.heap_bytes());
bw + cw + self.patterns.heap_bytes()
}
#[inline(always)]
pub(super) fn is_match(&self, text: &str) -> bool {
if self.all_patterns_ascii {
let density = text_non_ascii_density(text);
if density >= 1.0 && text.bytes().all(|b| b >= 0x80) {
return false;
}
if density > CHARWISE_DENSITY_THRESHOLD {
return self
.charwise_matcher
.as_ref()
.is_some_and(|m| m.is_match_text(text));
}
return self
.bytewise_matcher
.as_ref()
.is_some_and(|m| m.is_match(text));
}
if text.is_ascii() {
self.bytewise_matcher
.as_ref()
.is_some_and(|m| m.is_match(text))
} else {
self.charwise_matcher
.as_ref()
.is_some_and(|m| m.is_match_text(text))
}
}
#[inline(always)]
pub(super) fn for_each_match_value(
&self,
text: &str,
density: f32,
on_value: impl FnMut(u32, usize, usize) -> bool,
) -> bool {
if self.all_patterns_ascii && density >= 1.0 && text.bytes().all(|b| b >= 0x80) {
return false;
}
if density <= CHARWISE_DENSITY_THRESHOLD {
if let Some(ref matcher) = self.bytewise_matcher {
return matcher.for_each_match_value(text, on_value);
}
} else if let Some(ref matcher) = self.charwise_matcher {
return matcher.for_each_match_value(text, on_value);
}
false
}
#[inline(always)]
pub(super) fn for_each_match_value_from_iter(
&self,
iter: impl Iterator<Item = u8>,
density: f32,
on_value: impl FnMut(u32, usize, usize) -> bool,
) -> bool {
if density <= CHARWISE_DENSITY_THRESHOLD {
if let Some(ref matcher) = self.bytewise_matcher {
return matcher.for_each_match_value_from_iter(iter, on_value);
}
} else if let Some(ref matcher) = self.charwise_matcher {
return matcher.for_each_match_value_from_iter(iter, on_value);
}
false
}
#[inline(always)]
pub(super) fn for_each_rule_idx_simple(
&self,
text: &str,
density: f32,
on_rule: impl FnMut(usize, u8, usize, usize),
) {
if density <= CHARWISE_DENSITY_THRESHOLD {
if let Some(ref matcher) = self.bytewise_matcher {
matcher.for_each_rule_idx_simple(text, on_rule);
}
} else if let Some(ref matcher) = self.charwise_matcher {
matcher.for_each_rule_idx_simple(text, on_rule);
}
}
}
impl BytewiseMatcher {
    /// Unpacks a direct value into `(rule index, boundary flags)` using the
    /// masks and shift imported from `super::encoding`.
    #[inline(always)]
    fn split_direct_value(value: u32) -> (usize, u8) {
        let rule = (value & DIRECT_RULE_MASK) as usize;
        let boundary = ((value & DIRECT_BOUNDARY_MASK) >> DIRECT_BOUNDARY_SHIFT) as u8;
        (rule, boundary)
    }

    /// Returns `true` if any pattern occurs in `text`. Uses the DFA when
    /// present, otherwise the daachorse automaton.
    #[inline(always)]
    fn is_match(&self, text: &str) -> bool {
        #[cfg(feature = "dfa")]
        if let Some((engine, _)) = &self.dfa {
            return engine.is_match(text);
        }
        matches!(self.daac.find_iter(text).next(), Some(_))
    }

    /// Feeds overlapping matches to `on_value(value, start, end)`. Returns
    /// `true` as soon as the callback does, `false` if it never does.
    #[inline(always)]
    fn for_each_match_value(
        &self,
        text: &str,
        mut on_value: impl FnMut(u32, usize, usize) -> bool,
    ) -> bool {
        #[cfg(feature = "dfa")]
        if let Some((engine, to_value)) = &self.dfa {
            return engine.find_overlapping_iter(text).any(|m| {
                // SAFETY: `to_value` is collected from the same pattern sequence
                // that was handed to the DFA builder (see `build_current_bytewise`),
                // so it has one entry per pattern id the DFA can report.
                let value = unsafe { *to_value.get_unchecked(m.pattern().as_usize()) };
                on_value(value, m.start(), m.end())
            });
        }
        self.daac
            .find_overlapping_iter(text)
            .any(|hit| on_value(hit.value(), hit.start(), hit.end()))
    }

    /// Calls `on_rule(rule_idx, boundary, start, end)` for every overlapping
    /// match, decoding each value with [`BytewiseMatcher::split_direct_value`].
    #[inline(always)]
    fn for_each_rule_idx_simple(
        &self,
        text: &str,
        mut on_rule: impl FnMut(usize, u8, usize, usize),
    ) {
        #[cfg(feature = "dfa")]
        if let Some((engine, to_value)) = &self.dfa {
            engine.find_overlapping_iter(text).for_each(|m| {
                // SAFETY: `to_value` holds one entry per pattern given to the DFA
                // builder, in the same order, so the reported id is in bounds.
                let value = unsafe { *to_value.get_unchecked(m.pattern().as_usize()) };
                let (rule, boundary) = Self::split_direct_value(value);
                on_rule(rule, boundary, m.start(), m.end());
            });
            return;
        }
        self.daac.find_overlapping_iter(text).for_each(|hit| {
            let (rule, boundary) = Self::split_direct_value(hit.value());
            on_rule(rule, boundary, hit.start(), hit.end());
        });
    }

    /// Streaming variant over a byte iterator. This always uses the daachorse
    /// automaton, since the DFA path in this impl is only driven with `&str` input.
    #[inline(always)]
    fn for_each_match_value_from_iter(
        &self,
        iter: impl Iterator<Item = u8>,
        mut on_value: impl FnMut(u32, usize, usize) -> bool,
    ) -> bool {
        self.daac
            .find_overlapping_iter_from_iter(iter)
            .any(|hit| on_value(hit.value(), hit.start(), hit.end()))
    }

    /// Approximate heap footprint of the daachorse automaton, plus the DFA and
    /// its value table when present.
    fn heap_bytes(&self) -> usize {
        #[cfg(feature = "dfa")]
        {
            self.daac.heap_bytes()
                + self.dfa.as_ref().map_or(0, |(engine, to_value)| {
                    engine.memory_usage() + to_value.capacity() * size_of::<u32>()
                })
        }
        #[cfg(not(feature = "dfa"))]
        {
            self.daac.heap_bytes()
        }
    }
}
/// Slice-based matching helpers for the charwise automaton, mirroring the
/// inherent methods of `BytewiseMatcher`.
trait CharwiseMatcherExt {
    /// Returns `true` if any pattern occurs in `text`.
    fn is_match_text(&self, text: &str) -> bool;

    /// Feeds overlapping matches in `haystack` to `callback(value, start, end)`.
    /// Returns `true` as soon as the callback does, `false` if it never does.
    fn for_each_match_value(
        &self,
        haystack: &str,
        callback: impl FnMut(u32, usize, usize) -> bool,
    ) -> bool;

    /// Calls `callback(rule_idx, boundary, start, end)` for every overlapping
    /// match in `haystack`.
    fn for_each_rule_idx_simple(
        &self,
        haystack: &str,
        callback: impl FnMut(usize, u8, usize, usize),
    );
}

/// Streaming (byte-iterator) matching helper for the charwise automaton.
trait CharwiseMatcherStreamExt {
    /// Feeds overlapping matches found in the byte stream `bytes` to
    /// `callback(value, start, end)`. Returns `true` as soon as the callback
    /// does, `false` if it never does.
    fn for_each_match_value_from_iter(
        &self,
        bytes: impl Iterator<Item = u8>,
        callback: impl FnMut(u32, usize, usize) -> bool,
    ) -> bool;
}
impl CharwiseMatcherStreamExt for CharwiseMatcher {
    /// Streams `iter` through the charwise automaton and calls
    /// `on_value(value, start, end)` for each overlapping match. Returns `true`
    /// as soon as the callback does, `false` if it never does.
    #[inline(always)]
    fn for_each_match_value_from_iter(
        &self,
        iter: impl Iterator<Item = u8>,
        mut on_value: impl FnMut(u32, usize, usize) -> bool,
    ) -> bool {
        // SAFETY: daachorse's charwise `find_overlapping_iter_from_iter` is an
        // `unsafe fn` whose documented precondition is that the byte stream is
        // valid UTF-8.
        // NOTE(review): this method is safe yet performs no such check, so every
        // caller must feed UTF-8 bytes. Confirm this at the call sites, or make
        // the method `unsafe` so the obligation is visible to callers.
        let overlapping = unsafe { self.find_overlapping_iter_from_iter(iter) };
        overlapping.any(|hit| on_value(hit.value(), hit.start(), hit.end()))
    }
}
impl CharwiseMatcherExt for CharwiseMatcher {
    /// Returns `true` when the charwise automaton finds any match in `text`.
    fn is_match_text(&self, text: &str) -> bool {
        matches!(self.find_iter(text).next(), Some(_))
    }

    /// Feeds overlapping matches to `on_value(value, start, end)`. Returns
    /// `true` as soon as the callback does, `false` if it never does.
    #[inline(always)]
    fn for_each_match_value(
        &self,
        text: &str,
        mut on_value: impl FnMut(u32, usize, usize) -> bool,
    ) -> bool {
        self.find_overlapping_iter(text)
            .any(|hit| on_value(hit.value(), hit.start(), hit.end()))
    }

    /// Calls `on_rule(rule_idx, boundary, start, end)` for every overlapping
    /// match. The rule index and boundary bits are unpacked from the match value
    /// with the direct-value masks.
    #[inline(always)]
    fn for_each_rule_idx_simple(
        &self,
        text: &str,
        mut on_rule: impl FnMut(usize, u8, usize, usize),
    ) {
        self.find_overlapping_iter(text).for_each(|hit| {
            let packed = hit.value();
            let rule = (packed & DIRECT_RULE_MASK) as usize;
            let boundary = ((packed & DIRECT_BOUNDARY_MASK) >> DIRECT_BOUNDARY_SHIFT) as u8;
            on_rule(rule, boundary, hit.start(), hit.end());
        });
    }
}
/// Builds the bytewise and charwise automata for the deduplicated patterns.
///
/// Pattern `i` is registered in both automata with value `value_map[i]`. The
/// bytewise build runs on a scoped worker thread while the charwise build runs
/// on the calling thread.
///
/// Returns `(None, None)` when `dedup_patterns` is empty; otherwise both
/// matchers are `Some`.
///
/// # Errors
/// Returns the charwise build error if that build fails. Otherwise returns the
/// bytewise build error if that build fails.
///
/// # Panics
/// Panics if `value_map` has fewer entries than `dedup_patterns`. A panic on the
/// bytewise worker thread is re-raised on the calling thread with its original
/// payload.
fn compile_automata(
    dedup_patterns: &[Cow<'_, str>],
    value_map: &[u32],
) -> Result<(Option<BytewiseMatcher>, Option<CharwiseMatcher>), MatcherError> {
    if dedup_patterns.is_empty() {
        return Ok((None, None));
    }
    let all_patvals: Vec<(&str, u32)> = dedup_patterns
        .iter()
        .enumerate()
        .map(|(i, p)| (p.as_ref(), value_map[i]))
        .collect();
    // The worker thread needs its own owned copy; `(&str, u32)` pairs are cheap.
    let bytewise_source = all_patvals.clone();
    std::thread::scope(|s| {
        let bytewise_handle = s.spawn(move || build_current_bytewise(bytewise_source));
        let charwise = CharwiseDoubleArrayAhoCorasickBuilder::new()
            .match_kind(DoubleArrayAhoCorasickMatchKind::Standard)
            .build_with_values(all_patvals)
            .map_err(MatcherError::automaton_build);
        // Join before propagating any charwise error, so a worker panic always
        // resurfaces with its own payload. Otherwise `thread::scope`'s implicit
        // join would replace it with a generic "a scoped thread panicked".
        let bytewise = bytewise_handle
            .join()
            .unwrap_or_else(|payload| std::panic::resume_unwind(payload));
        let charwise = charwise?;
        let bytewise = bytewise?;
        Ok((Some(bytewise), Some(charwise)))
    })
}
/// Builds the byte-oriented matcher from `(pattern, value)` pairs.
///
/// The daachorse automaton is always built, and its failure is returned as an
/// error. With the `dfa` feature, an aho-corasick DFA over the same patterns is
/// also built as an accelerator. If that build fails, the matcher falls back to
/// the daachorse automaton alone (`dfa` is `None`, which `ScanPlan::has_dfa`
/// reports) instead of failing the whole compile.
///
/// # Errors
/// Returns the daachorse build error mapped through `MatcherError::automaton_build`.
fn build_current_bytewise(all_patvals: Vec<(&str, u32)>) -> Result<BytewiseMatcher, MatcherError> {
    // `(&str, u32)` is `Copy`, so the builder can consume copies from a
    // borrowing iterator instead of a cloned `Vec`.
    let daac = DoubleArrayAhoCorasickBuilder::new()
        .match_kind(DoubleArrayAhoCorasickMatchKind::Standard)
        .build_with_values(all_patvals.iter().copied())
        .map_err(MatcherError::automaton_build)?;
    #[cfg(feature = "dfa")]
    let dfa = AhoCorasickBuilder::new()
        .kind(Some(AhoCorasickKind::DFA))
        .match_kind(AhoCorasickMatchKind::Standard)
        .build(all_patvals.iter().map(|&(pattern, _)| pattern))
        // Best effort: a forced DFA can fail to build (for example when it would
        // be too large). Every matching path already falls back to `daac` when
        // `dfa` is `None`.
        .ok()
        .map(|engine| {
            // aho-corasick numbers patterns in the order they were supplied, so
            // `to_value[id]` is the value of the pattern reported as `id`.
            let to_value: Vec<u32> = all_patvals.iter().map(|&(_, value)| value).collect();
            (engine, to_value)
        });
    Ok(BytewiseMatcher {
        daac,
        #[cfg(feature = "dfa")]
        dfa,
    })
}