use std::borrow::Cow;
#[cfg(feature = "dfa")]
use aho_corasick::{
Anchored, Input, MatchKind as AhoCorasickMatchKind, automaton::Automaton,
dfa::DFA as AcDfaEngine,
};
use daachorse::{
DoubleArrayAhoCorasick, DoubleArrayAhoCorasickBuilder,
MatchKind as DoubleArrayAhoCorasickMatchKind,
charwise::{CharwiseDoubleArrayAhoCorasick, CharwiseDoubleArrayAhoCorasickBuilder},
};
use crate::MatcherError;
#[cfg(feature = "harry")]
use super::harry::HarryMatcher;
use super::rule::{PatternEntry, PatternIndex};
/// Largest number of ASCII patterns for which the byte-level matcher is compiled into an
/// `aho_corasick` DFA; larger ASCII sets fall back to the daachorse double-array automaton
/// (see `build_current_bytewise`).
// NOTE(review): presumably the cutoff reflects DFA build time / memory growth — confirm the
// 7_000 value against benchmarks before changing it.
#[cfg(feature = "dfa")]
const AC_DFA_PATTERN_THRESHOLD: usize = 7_000;
/// Compiled multi-pattern scanner: the automata used to search text, plus the
/// [`PatternIndex`] whose value map supplies the `u32` values the automata report.
#[derive(Clone)]
pub(super) struct ScanPlan {
    /// Byte-level automaton over the ASCII patterns; `None` when no pattern is ASCII.
    bytewise_matcher: Option<BytewiseMatcher>,
    /// Char-level automaton, present only when at least one pattern is non-ASCII. For a
    /// mixed ASCII/non-ASCII set it holds every pattern; otherwise only the non-ASCII ones.
    charwise_matcher: Option<CharwiseMatcher>,
    /// Optional Harry matcher (feature `harry`). It is only attempted when
    /// `charwise_matcher` is `None`, and stays `None` when `HarryMatcher::build` declines.
    #[cfg(feature = "harry")]
    harry_matcher: Option<Box<HarryMatcher>>,
    /// Rule entries for each deduplicated pattern.
    patterns: PatternIndex,
}
/// Byte-level matcher over the ASCII subset of the pattern set.
#[derive(Clone)]
enum BytewiseMatcher {
    /// `aho_corasick` DFA (feature `dfa`), chosen for sets of at most
    /// `AC_DFA_PATTERN_THRESHOLD` patterns. The DFA reports pattern IDs, and
    /// `to_value[pattern_id]` is the value associated with that pattern.
    #[cfg(feature = "dfa")]
    AcDfa {
        matcher: Box<AcDfaEngine>,
        to_value: Vec<u32>,
    },
    /// daachorse double-array automaton that stores each pattern's value directly.
    DaacBytewise(DoubleArrayAhoCorasick<u32>),
}
/// Char-level matcher used for non-ASCII text (and built only when some pattern is non-ASCII).
#[derive(Clone)]
enum CharwiseMatcher {
    /// daachorse charwise double-array automaton that stores each pattern's value directly.
    DaacCharwise(CharwiseDoubleArrayAhoCorasick<u32>),
}
impl ScanPlan {
pub(super) fn compile(
dedup_patterns: &[Cow<'_, str>],
dedup_entries: Vec<Vec<PatternEntry>>,
) -> Result<Self, MatcherError> {
let patterns = PatternIndex::new(dedup_entries);
let value_map = patterns.build_value_map();
let (bytewise_matcher, charwise_matcher) = compile_automata(dedup_patterns, &value_map)?;
#[cfg(feature = "harry")]
let harry_matcher = if charwise_matcher.is_none() {
let patvals: Vec<(&str, u32)> = dedup_patterns
.iter()
.enumerate()
.map(|(i, p)| (p.as_ref(), value_map[i]))
.collect();
HarryMatcher::build(&patvals).map(Box::new)
} else {
None
};
Ok(Self {
bytewise_matcher,
charwise_matcher,
#[cfg(feature = "harry")]
harry_matcher,
patterns,
})
}
#[inline(always)]
pub(super) fn patterns(&self) -> &PatternIndex {
&self.patterns
}
pub(super) fn heap_bytes(&self) -> usize {
let bw = self.bytewise_matcher.as_ref().map_or(0, |m| m.heap_bytes());
let cw = self.charwise_matcher.as_ref().map_or(0, |m| m.heap_bytes());
#[cfg(feature = "harry")]
let harry = self.harry_matcher.as_ref().map_or(0, |m| m.heap_bytes());
#[cfg(not(feature = "harry"))]
let harry = 0;
bw + cw + harry + self.patterns.heap_bytes()
}
#[cfg(feature = "harry")]
#[inline(always)]
fn uses_dfa(&self) -> bool {
#[cfg(feature = "dfa")]
if let Some(BytewiseMatcher::AcDfa { .. }) = &self.bytewise_matcher {
return true;
}
false
}
#[inline(always)]
pub(super) fn is_match(&self, text: &str) -> bool {
#[cfg(feature = "harry")]
if self.harry_matcher.as_ref().is_some_and(|_| {
self.charwise_matcher.is_none() && (!self.uses_dfa() || !text.is_ascii())
}) {
return self.harry_matcher.as_ref().unwrap().is_match(text);
}
if self.charwise_matcher.is_none() || text.is_ascii() {
self.bytewise_matcher
.as_ref()
.is_some_and(|m| m.is_match(text))
} else {
self.charwise_matcher
.as_ref()
.is_some_and(|m| m.is_match(text))
}
}
#[inline(always)]
pub(super) fn for_each_match_value(
&self,
text: &str,
is_ascii: bool,
on_value: impl FnMut(u32) -> bool,
) -> bool {
let use_bytewise = self.charwise_matcher.is_none() || is_ascii;
if use_bytewise {
if let Some(ref matcher) = self.bytewise_matcher {
return matcher.for_each_match_value(text, on_value);
}
} else if let Some(ref matcher) = self.charwise_matcher {
return matcher.for_each_match_value(text, on_value);
}
false
}
#[inline(always)]
pub(super) fn for_each_match_value_from_iter<I: Iterator<Item = u8>>(
&self,
iter: I,
is_ascii: bool,
on_value: impl FnMut(u32) -> bool,
) -> bool {
let use_bytewise = self.charwise_matcher.is_none() || is_ascii;
if use_bytewise {
if let Some(ref matcher) = self.bytewise_matcher {
return matcher.for_each_match_value_from_iter(iter, on_value);
}
} else if let Some(ref matcher) = self.charwise_matcher {
return matcher.for_each_match_value_from_iter(iter, on_value);
}
false
}
}
impl BytewiseMatcher {
    /// Returns `true` if any pattern occurs in `text`.
    #[inline(always)]
    fn is_match(&self, text: &str) -> bool {
        match self {
            // `try_find` only returns `Err` for unsupported search configurations; the DFA is
            // built with Standard semantics and searched unanchored, which presumably can never
            // fail, hence the `unwrap` — TODO confirm against the start kind used at build time.
            #[cfg(feature = "dfa")]
            Self::AcDfa { matcher, .. } => matcher.try_find(&Input::new(text)).unwrap().is_some(),
            Self::DaacBytewise(matcher) => matcher.find_iter(text).next().is_some(),
        }
    }

    /// Calls `on_value` with the value of every overlapping match in `text`, in the order the
    /// engine reports them. Returns `true` as soon as `on_value` returns `true`; otherwise
    /// returns `false` after all matches have been visited.
    #[inline(always)]
    fn for_each_match_value(&self, text: &str, mut on_value: impl FnMut(u32) -> bool) -> bool {
        match self {
            #[cfg(feature = "dfa")]
            Self::AcDfa { matcher, to_value } => {
                for hit in matcher.try_find_overlapping_iter(Input::new(text)).unwrap() {
                    // SAFETY: `hit.pattern()` is the ID of one of the patterns the DFA was built
                    // from, and `to_value` is built with exactly one entry per such pattern
                    // (see `build_current_bytewise`), so the index is in bounds.
                    let value = unsafe { *to_value.get_unchecked(hit.pattern().as_usize()) };
                    if on_value(value) {
                        return true;
                    }
                }
                false
            }
            Self::DaacBytewise(matcher) => {
                for hit in matcher.find_overlapping_iter(text) {
                    if on_value(hit.value()) {
                        return true;
                    }
                }
                false
            }
        }
    }

    /// Same contract as `for_each_match_value`, but consumes the haystack as a byte iterator.
    #[inline(always)]
    fn for_each_match_value_from_iter<I: Iterator<Item = u8>>(
        &self,
        iter: I,
        mut on_value: impl FnMut(u32) -> bool,
    ) -> bool {
        match self {
            #[cfg(feature = "dfa")]
            Self::AcDfa { matcher, to_value } => {
                // Walk the DFA by hand, one byte at a time, starting from the unanchored start state.
                // NOTE(review): the start state itself is never checked for matches, so a match
                // before the first byte (presumably only possible with an empty pattern) is not
                // reported, unlike `for_each_match_value`. Confirm empty patterns are excluded upstream.
                let mut sid = matcher.start_state(Anchored::No).unwrap();
                for byte in iter {
                    sid = matcher.next_state(Anchored::No, sid, byte);
                    // `is_special` is checked first so that ordinary states skip `is_match`.
                    if matcher.is_special(sid) && matcher.is_match(sid) {
                        // Report every pattern that ends at this byte.
                        for i in 0..matcher.match_len(sid) {
                            let pid = matcher.match_pattern(sid, i);
                            // SAFETY: `pid` comes from the DFA's own match table and `to_value`
                            // has one entry per pattern the DFA was built from.
                            let value = unsafe { *to_value.get_unchecked(pid.as_usize()) };
                            if on_value(value) {
                                return true;
                            }
                        }
                    }
                }
                false
            }
            Self::DaacBytewise(matcher) => {
                for hit in matcher.find_overlapping_iter_from_iter(iter) {
                    if on_value(hit.value()) {
                        return true;
                    }
                }
                false
            }
        }
    }

    /// Approximate heap usage.
    /// - DFA variant: the DFA's `memory_usage()` plus the allocated capacity of `to_value`.
    /// - daachorse variant: the automaton's `heap_bytes()`.
    fn heap_bytes(&self) -> usize {
        match self {
            #[cfg(feature = "dfa")]
            Self::AcDfa { matcher, to_value } => {
                matcher.memory_usage() + to_value.capacity() * size_of::<u32>()
            }
            Self::DaacBytewise(matcher) => matcher.heap_bytes(),
        }
    }
}
impl CharwiseMatcher {
    /// Reports whether at least one pattern occurs anywhere in `text`.
    #[inline(always)]
    fn is_match(&self, text: &str) -> bool {
        let Self::DaacCharwise(automaton) = self;
        automaton.find_iter(text).next().is_some()
    }

    /// Hands the value of each overlapping match in `text` to `on_value`, stopping at (and
    /// returning `true` for) the first call that returns `true`. Returns `false` if every
    /// match was visited without an early stop.
    #[inline(always)]
    fn for_each_match_value(&self, text: &str, mut on_value: impl FnMut(u32) -> bool) -> bool {
        let Self::DaacCharwise(automaton) = self;
        automaton
            .find_overlapping_iter(text)
            .any(|hit| on_value(hit.value()))
    }

    /// Byte-iterator variant of `for_each_match_value`, with the same early-stop contract.
    #[inline(always)]
    fn for_each_match_value_from_iter<I: Iterator<Item = u8>>(
        &self,
        iter: I,
        mut on_value: impl FnMut(u32) -> bool,
    ) -> bool {
        let Self::DaacCharwise(automaton) = self;
        // SAFETY: daachorse requires the byte stream to be valid UTF-8 for this search.
        // NOTE(review): this safe fn cannot check that; it relies on callers only routing
        // UTF-8 byte streams here — confirm at every call site.
        let mut hits = unsafe { automaton.find_overlapping_iter_from_iter(iter) };
        hits.any(|hit| on_value(hit.value()))
    }

    /// Heap bytes owned by the underlying daachorse automaton.
    fn heap_bytes(&self) -> usize {
        let Self::DaacCharwise(automaton) = self;
        automaton.heap_bytes()
    }
}
/// Compiles the scan automata for a deduplicated pattern set.
///
/// `value_map[i]` is the value reported for a hit on `dedup_patterns[i]`.
///
/// * ASCII patterns go into a byte-level matcher (see `build_current_bytewise`).
/// * If any pattern is non-ASCII, a char-level matcher is built too. When the set mixes ASCII
///   and non-ASCII patterns, the char-level matcher receives *every* pattern, so non-ASCII
///   text can be scanned with it alone. Otherwise it holds only the non-ASCII patterns.
///
/// In the mixed case the two automata are built concurrently on scoped threads.
///
/// # Errors
/// Returns an automaton-construction error if either build fails.
///
/// # Panics
/// Panics if `value_map` is shorter than `dedup_patterns`. A panic inside the background
/// byte-level build is re-raised on the calling thread with its original payload.
fn compile_automata(
    dedup_patterns: &[Cow<'_, str>],
    value_map: &[u32],
) -> Result<(Option<BytewiseMatcher>, Option<CharwiseMatcher>), MatcherError> {
    // Pair each pattern with its value once, preserving the original pattern order.
    let all_patvals: Vec<(&str, u32)> = dedup_patterns
        .iter()
        .enumerate()
        .map(|(i, p)| (p.as_ref(), value_map[i]))
        .collect();
    let (ascii_patvals, non_ascii_patvals): (Vec<(&str, u32)>, Vec<(&str, u32)>) =
        all_patvals.iter().copied().partition(|&(p, _)| p.is_ascii());

    let has_ascii = !ascii_patvals.is_empty();
    let has_non_ascii = !non_ascii_patvals.is_empty();

    // A mixed set needs the charwise engine to know every pattern, because non-ASCII text is
    // scanned only by that engine.
    let charwise_source: &[(&str, u32)] = if has_ascii && has_non_ascii {
        &all_patvals
    } else {
        &non_ascii_patvals
    };

    let build_bytewise = move || build_current_bytewise(ascii_patvals, value_map.len());
    let build_charwise = || -> Result<CharwiseMatcher, MatcherError> {
        let matcher = CharwiseDoubleArrayAhoCorasickBuilder::new()
            .match_kind(DoubleArrayAhoCorasickMatchKind::Standard)
            .build_with_values(charwise_source.iter().copied())
            .map_err(MatcherError::automaton_build)?;
        Ok(CharwiseMatcher::DaacCharwise(matcher))
    };

    match (has_ascii, has_non_ascii) {
        (false, false) => Ok((None, None)),
        (true, false) => Ok((Some(build_bytewise()?), None)),
        (false, true) => Ok((None, Some(build_charwise()?))),
        (true, true) => std::thread::scope(|s| {
            let bytewise_handle = s.spawn(build_bytewise);
            let charwise = build_charwise()?;
            // Re-raise a worker panic with its original payload instead of replacing it
            // with a generic message.
            let bytewise = bytewise_handle
                .join()
                .unwrap_or_else(|payload| std::panic::resume_unwind(payload))?;
            Ok((Some(bytewise), Some(charwise)))
        }),
    }
}
/// Builds the byte-level matcher for the ASCII subset of the pattern set.
///
/// With the `dfa` feature, a set of at most `AC_DFA_PATTERN_THRESHOLD` patterns is compiled
/// into an `aho_corasick` DFA. The DFA is paired with a `to_value` table indexed by pattern ID.
/// Larger sets, and builds without `dfa`, use a daachorse double-array automaton that stores
/// each pattern's value directly.
///
/// `_value_map_len` is currently unused.
///
/// # Errors
/// Returns an automaton-construction error if the chosen engine fails to build.
fn build_current_bytewise(
    ascii_patvals: Vec<(&str, u32)>,
    _value_map_len: usize,
) -> Result<BytewiseMatcher, MatcherError> {
    #[cfg(feature = "dfa")]
    if ascii_patvals.len() <= AC_DFA_PATTERN_THRESHOLD {
        let matcher = AcDfaEngine::builder()
            .match_kind(AhoCorasickMatchKind::Standard)
            .build(ascii_patvals.iter().map(|&(pattern, _)| pattern))
            .map_err(MatcherError::automaton_build)?;
        // Pattern IDs follow input order, so `to_value[pid]` is the value of pattern `pid`.
        // Built only on this path so the daachorse fallback does not pay for it.
        let to_value: Vec<u32> = ascii_patvals.iter().map(|&(_, value)| value).collect();
        return Ok(BytewiseMatcher::AcDfa {
            matcher: Box::new(matcher),
            to_value,
        });
    }
    Ok(BytewiseMatcher::DaacBytewise(
        DoubleArrayAhoCorasickBuilder::new()
            .match_kind(DoubleArrayAhoCorasickMatchKind::Standard)
            .build_with_values(ascii_patvals)
            .map_err(MatcherError::automaton_build)?,
    ))
}
// Test-only introspection helpers.
#[cfg(all(test, feature = "harry"))]
impl ScanPlan {
    /// Returns `true` if compilation stored a Harry matcher for this plan.
    pub(super) fn has_harry(&self) -> bool {
        self.harry_matcher.is_some()
    }
}
#[cfg(all(test, feature = "harry"))]
mod tests {
    use super::*;

    /// Compiles a plan over `patterns`, attaching no rule entries to any of them.
    fn compile_from_strings(patterns: &[&str]) -> ScanPlan {
        let dedup_entries: Vec<Vec<PatternEntry>> =
            (0..patterns.len()).map(|_| Vec::new()).collect();
        let dedup_patterns: Vec<Cow<'_, str>> =
            patterns.iter().copied().map(Cow::Borrowed).collect();
        ScanPlan::compile(&dedup_patterns, dedup_entries).expect("compile should succeed")
    }

    /// Convenience wrapper around `compile_from_strings` for owned pattern lists.
    fn compile_owned(patterns: &[String]) -> ScanPlan {
        let borrowed: Vec<&str> = patterns.iter().map(String::as_str).collect();
        compile_from_strings(&borrowed)
    }

    #[test]
    fn harry_built_for_large_ascii_sets() {
        let patterns: Vec<String> = (0..64).map(|i| format!("token{i:02}")).collect();
        let plan = compile_owned(&patterns);
        assert!(plan.has_harry(), "should build Harry for 64 ASCII patterns");
    }

    #[test]
    fn harry_not_built_for_small_sets() {
        let patterns: Vec<String> = (0..8).map(|i| format!("token{i:02}")).collect();
        let plan = compile_owned(&patterns);
        assert!(
            !plan.has_harry(),
            "should not build Harry for < 64 patterns"
        );
    }

    #[test]
    fn harry_not_built_for_mixed_pattern_sets() {
        // 32 ASCII patterns followed by 32 CJK patterns, in that order.
        let patterns: Vec<String> = (0..32)
            .map(|i| format!("token{i:02}"))
            .chain((0..32).map(|i| format!("测试{i:02}")))
            .collect();
        let plan = compile_owned(&patterns);
        assert!(
            !plan.has_harry(),
            "should not build Harry for mixed ASCII+CJK patterns"
        );
        assert!(
            plan.charwise_matcher.is_some(),
            "charwise engine should exist for CJK patterns"
        );
    }
}