use std::borrow::Cow;
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
#[cfg(feature = "dfa")]
use aho_corasick::AhoCorasickKind;
use aho_corasick::{AhoCorasick, AhoCorasickBuilder};
use serde::Serialize;
use tinyvec::TinyVec;
use crate::process::process_matcher::{
ProcessType, ProcessTypeBitNode, ProcessedTextMasks, build_process_type_tree,
reduce_text_process_emit, return_processed_string_to_pool, walk_process_tree,
};
/// Maximum number of segments a rule may have and still be tracked with the
/// `u64` satisfaction bitmask (`WordState::satisfied_mask`). Rules with more
/// AND segments, or more segments in total, use the per-variant count matrix.
const BITMASK_CAPACITY: usize = 64;
/// Per-rule scratch state for one matching pass.
///
/// Each `*_generation` field is compared against `SimpleMatchState::generation`.
/// A field equal to the current generation means "set during this pass". Stale
/// values from earlier passes are therefore ignored without clearing them.
#[derive(Default, Clone, Copy)]
struct WordState {
    /// Generation in which the rule was first hit; at that point it was pushed
    /// onto `touched_indices` and its mask/matrix were reset.
    matrix_generation: u32,
    /// Generation in which one of the rule's NOT segments vetoed it.
    not_generation: u32,
    /// Generation in which a bitmask-tracked rule without NOT segments became
    /// fully satisfied (lets later hits on it be skipped).
    satisfied_generation: u32,
    /// Bit `i` is set once AND segment `i` has been satisfied; only meaningful
    /// when `matrix_generation` equals the current generation.
    satisfied_mask: u64,
}
/// Reusable scratch space for one `SimpleMatcher` query, stored in the
/// `SIMPLE_MATCH_STATE` thread-local.
struct SimpleMatchState {
    /// One entry per rule index; grown on demand by `prepare`.
    word_states: Vec<WordState>,
    /// Per-rule flat counter matrix (`segments × text variants`, row-major by
    /// segment); only initialized for rules with `use_matrix`.
    matrix: Vec<TinyVec<[i32; 16]>>,
    /// Rule indices hit during the current pass, in first-hit order.
    touched_indices: Vec<usize>,
    /// Id of the current pass; incremented by `prepare`.
    generation: u32,
}
impl SimpleMatchState {
fn new() -> Self {
Self {
word_states: Vec::new(),
matrix: Vec::new(),
touched_indices: Vec::new(),
generation: 0,
}
}
fn prepare(&mut self, size: usize) {
if self.generation == u32::MAX {
for state in self.word_states.iter_mut() {
state.matrix_generation = 0;
state.not_generation = 0;
state.satisfied_generation = 0;
}
self.generation = 1;
} else {
self.generation += 1;
}
if self.word_states.len() < size {
self.word_states.resize(size, WordState::default());
self.matrix.resize(size, TinyVec::new());
}
self.touched_indices.clear();
}
}
// Scratch state shared by every `SimpleMatcher` query on the current thread.
// Queries take it with `RefCell::borrow_mut`, so a query must not start another
// query on the same thread while it still holds the borrow (that would panic).
thread_local! {
    static SIMPLE_MATCH_STATE: RefCell<SimpleMatchState> = RefCell::new(SimpleMatchState::new());
}
/// Rule table: for each [`ProcessType`], a map from word id to rule text.
/// This is one of the shapes accepted by `SimpleMatcher::new`.
pub type SimpleTable<'a> = HashMap<ProcessType, HashMap<u32, &'a str>>;
/// Same shape as [`SimpleTable`], but with `Cow` values so that rule text may be
/// borrowed or owned (NOTE(review): presumably for deserialized tables — confirm).
pub type SimpleTableSerde<'a> = HashMap<ProcessType, HashMap<u32, Cow<'a, str>>>;
/// One rule reported as matched by `SimpleMatcher::process`.
#[derive(Serialize, Debug)]
pub struct SimpleResult<'a> {
    /// Id the rule was registered under in the input table.
    pub word_id: u32,
    /// The original rule text, borrowed from the matcher.
    pub word: Cow<'a, str>,
}
/// Per-rule data consulted on every pattern hit (`process_match`); the fields
/// used only for reporting live in [`RuleCold`].
#[derive(Debug, Clone)]
struct RuleHot {
    /// Initial counter per segment: the `and_count` AND segments first, then the
    /// NOT segments. An AND value is the required occurrence count (>= 1). A NOT
    /// value is `1 - repetitions` (0 for a single `~x`).
    segment_counts: Vec<i32>,
    /// Number of distinct AND segments (the prefix of `segment_counts`).
    and_count: usize,
    /// `satisfied_mask` value meaning every AND segment is satisfied. It is 0 when
    /// `and_count` is 0 or exceeds `BITMASK_CAPACITY`; the matrix decides then.
    expected_mask: u64,
    /// Tally hits in the per-variant counter matrix instead of the bitmask fast
    /// path (repeated segments, or more than `BITMASK_CAPACITY` segments).
    use_matrix: bool,
    /// Total number of segments (AND + NOT). It is stored via an `as u16` cast,
    /// so counts above `u16::MAX` would be truncated.
    num_splits: u16,
}
/// Per-rule data needed only when a match is reported.
#[derive(Debug, Clone)]
struct RuleCold {
    /// Id the rule was registered under in the input table.
    word_id: u32,
    /// The rule text exactly as given, including `&`/`~` separators.
    word: String,
}
/// One `(rule, segment)` pair fed by a deduplicated Aho-Corasick pattern.
#[derive(Debug, Clone)]
struct PatternEntry {
    /// Single bit `1 << process_type.bits()` for the rule's process type. A hit
    /// counts only if the scanned text variant's mask contains this bit.
    process_type_mask: u64,
    /// Index into `SimpleMatcher::rule_hot` / `rule_cold`.
    rule_idx: u32,
    /// Position of the segment within the rule: an index into `segment_counts`
    /// and the row of the rule's counter matrix.
    offset: u16,
}
/// Automaton backend used by `SimpleMatcher`; currently always Aho-Corasick.
#[derive(Debug, Clone)]
enum InternalMatcher {
    AhoCorasick(AhoCorasick),
}
/// Multi-rule matcher. Each rule is a string of segments joined by `&` (segments
/// that must occur) and `~` (veto segments). Rules are matched against the text
/// variants produced by the rule's [`ProcessType`] pipeline.
#[derive(Debug, Clone)]
pub struct SimpleMatcher {
    /// Tree built from every process type in the input table; walked to produce
    /// the processed text variants that get scanned.
    process_type_tree: Vec<ProcessTypeBitNode>,
    /// Automaton over the deduplicated segment patterns.
    ac_matcher: InternalMatcher,
    /// All `PatternEntry` values, grouped contiguously by pattern id.
    ac_dedup_entries: Vec<PatternEntry>,
    /// For pattern id `p`, `(start, len)` of its slice in `ac_dedup_entries`.
    ac_dedup_ranges: Vec<(usize, usize)>,
    /// Hot per-rule data, indexed by rule index.
    rule_hot: Vec<RuleHot>,
    /// Cold per-rule data, indexed by rule index.
    rule_cold: Vec<RuleCold>,
}
impl SimpleMatcher {
pub fn new<'a, I, S1, S2>(
process_type_word_map: &'a HashMap<ProcessType, HashMap<u32, I, S1>, S2>,
) -> SimpleMatcher
where
I: AsRef<str> + 'a,
{
let word_size: usize = process_type_word_map.values().map(|m| m.len()).sum();
let mut process_type_set: HashSet<ProcessType> =
HashSet::with_capacity(process_type_word_map.len());
let mut dedup_entries: Vec<Vec<PatternEntry>> = Vec::with_capacity(word_size);
let mut rule_hot: Vec<RuleHot> = Vec::with_capacity(word_size);
let mut rule_cold: Vec<RuleCold> = Vec::with_capacity(word_size);
let mut word_id_to_idx: HashMap<(ProcessType, u32), usize> =
HashMap::with_capacity(word_size);
let mut next_pattern_id: usize = 0;
let mut dedup_patterns = Vec::with_capacity(word_size);
let mut pattern_id_map: HashMap<Cow<str>, usize> = HashMap::with_capacity(word_size);
for (&process_type, simple_word_map) in process_type_word_map {
let word_process_type = process_type - ProcessType::Delete;
process_type_set.insert(process_type);
for (&simple_word_id, simple_word) in simple_word_map {
if simple_word.as_ref().is_empty() {
continue;
}
let mut and_splits: HashMap<&str, i32> = HashMap::new();
let mut not_splits: HashMap<&str, i32> = HashMap::new();
let mut start = 0;
let mut current_is_not = false;
let mut add_sub_word = |word: &'a str, is_not: bool| {
if word.is_empty() {
return;
}
if is_not {
let entry = not_splits.entry(word).or_insert(1);
*entry -= 1;
} else {
let entry = and_splits.entry(word).or_insert(0);
*entry += 1;
}
};
for (index, char) in simple_word.as_ref().match_indices(['&', '~']) {
add_sub_word(&simple_word.as_ref()[start..index], current_is_not);
current_is_not = char == "~";
start = index + 1;
}
add_sub_word(&simple_word.as_ref()[start..], current_is_not);
if and_splits.is_empty() && not_splits.is_empty() {
continue;
}
let and_count = and_splits.len();
let segment_counts = and_splits
.values()
.copied()
.chain(not_splits.values().copied())
.collect::<Vec<i32>>();
let expected_mask = if and_count > 0 && and_count <= BITMASK_CAPACITY {
u64::MAX >> (BITMASK_CAPACITY - and_count)
} else {
0
};
let num_splits = segment_counts.len() as u16;
let use_matrix = and_count > BITMASK_CAPACITY
|| segment_counts.len() > BITMASK_CAPACITY
|| segment_counts[..and_count].iter().any(|&v| v != 1)
|| segment_counts[and_count..].iter().any(|&v| v != 0);
let rule_idx = if let Some(&existing_idx) =
word_id_to_idx.get(&(process_type, simple_word_id))
{
rule_hot[existing_idx] = RuleHot {
segment_counts,
and_count,
expected_mask,
use_matrix,
num_splits,
};
rule_cold[existing_idx] = RuleCold {
word_id: simple_word_id,
word: simple_word.as_ref().to_owned(),
};
existing_idx
} else {
let idx = rule_hot.len();
word_id_to_idx.insert((process_type, simple_word_id), idx);
rule_hot.push(RuleHot {
segment_counts,
and_count,
expected_mask,
use_matrix,
num_splits,
});
rule_cold.push(RuleCold {
word_id: simple_word_id,
word: simple_word.as_ref().to_owned(),
});
idx
};
for (offset, &split_word) in and_splits.keys().chain(not_splits.keys()).enumerate()
{
for ac_word in reduce_text_process_emit(word_process_type, split_word) {
let Some(&existing_dedup_id) = pattern_id_map.get(ac_word.as_ref()) else {
pattern_id_map.insert(ac_word.clone(), next_pattern_id);
dedup_entries.push(vec![PatternEntry {
process_type_mask: 1u64 << process_type.bits(),
rule_idx: rule_idx as u32,
offset: offset as u16,
}]);
dedup_patterns.push(ac_word);
next_pattern_id += 1;
continue;
};
dedup_entries[existing_dedup_id].push(PatternEntry {
process_type_mask: 1u64 << process_type.bits(),
rule_idx: rule_idx as u32,
offset: offset as u16,
});
}
}
}
}
let process_type_tree = build_process_type_tree(&process_type_set);
let patterns = dedup_patterns
.iter()
.map(|ac_word| ac_word.as_ref())
.collect::<Vec<_>>();
let ac_matcher = InternalMatcher::AhoCorasick({
#[cfg(feature = "dfa")]
let aho_corasick_kind = AhoCorasickKind::DFA;
#[cfg(not(feature = "dfa"))]
let aho_corasick_kind = aho_corasick::AhoCorasickKind::ContiguousNFA;
AhoCorasickBuilder::new()
.kind(Some(aho_corasick_kind))
.build(patterns)
.unwrap()
});
let mut ac_dedup_entries = Vec::with_capacity(dedup_entries.iter().map(|v| v.len()).sum());
let mut ac_dedup_ranges = Vec::with_capacity(dedup_entries.len());
for entries in dedup_entries {
let start = ac_dedup_entries.len();
let len = entries.len();
ac_dedup_entries.extend(entries);
ac_dedup_ranges.push((start, len));
}
SimpleMatcher {
process_type_tree,
ac_matcher,
ac_dedup_entries,
ac_dedup_ranges,
rule_hot,
rule_cold,
}
}
pub fn is_match(&self, text: &str) -> bool {
if text.is_empty() {
return false;
}
let tree = &self.process_type_tree;
let max_pt = tree.len();
SIMPLE_MATCH_STATE.with(|state_cell| {
let mut state = state_cell.borrow_mut();
state.prepare(self.rule_hot.len());
let (text_masks, stopped) =
walk_process_tree::<true, _>(tree, text, &mut |txt, idx, mask| {
self.scan_variant(txt, idx, mask, max_pt, &mut state, true)
});
if stopped {
return_processed_string_to_pool(text_masks);
return true;
}
let generation = state.generation;
let result = state.touched_indices.iter().any(|&rule_idx| {
if state.word_states[rule_idx].not_generation == generation {
return false;
}
Self::is_rule_satisfied(
&self.rule_hot[rule_idx],
&state.word_states,
&state.matrix,
rule_idx,
max_pt,
)
});
return_processed_string_to_pool(text_masks);
result
})
}
pub fn process<'a>(&'a self, text: &'a str) -> Vec<SimpleResult<'a>> {
let mut results = Vec::new();
self.process_into(text, &mut results);
results
}
pub fn process_into<'a>(&'a self, text: &'a str, results: &mut Vec<SimpleResult<'a>>) {
if text.is_empty() {
return;
}
let (processed, _) =
walk_process_tree::<false, _>(&self.process_type_tree, text, &mut |_, _, _| false);
self.process_preprocessed_into(&processed, results);
return_processed_string_to_pool(processed);
}
fn process_preprocessed_into<'a>(
&'a self,
processed_text_process_type_masks: &ProcessedTextMasks<'a>,
results: &mut Vec<SimpleResult<'a>>,
) {
SIMPLE_MATCH_STATE.with(|state| {
let mut state = state.borrow_mut();
state.prepare(self.rule_hot.len());
self.scan_all_variants(processed_text_process_type_masks, &mut state, false);
let generation = state.generation;
let num_variants = processed_text_process_type_masks.len();
for &rule_idx in &state.touched_indices {
if state.word_states[rule_idx].not_generation == generation {
continue;
}
if Self::is_rule_satisfied(
&self.rule_hot[rule_idx],
&state.word_states,
&state.matrix,
rule_idx,
num_variants,
) {
let cold = &self.rule_cold[rule_idx];
results.push(SimpleResult {
word_id: cold.word_id,
word: Cow::Borrowed(&cold.word),
});
}
}
});
}
fn scan_all_variants<'a>(
&'a self,
processed_text_process_type_masks: &ProcessedTextMasks<'a>,
state: &mut SimpleMatchState,
exit_early: bool,
) -> bool {
if self.ac_dedup_ranges.is_empty() {
return false;
}
let num_variants = processed_text_process_type_masks.len();
for (index, (processed_text, process_type_mask)) in
processed_text_process_type_masks.iter().enumerate()
{
if *process_type_mask == 0 {
continue;
}
if self.scan_variant(
processed_text.as_ref(),
index,
*process_type_mask,
num_variants,
state,
exit_early,
) {
return true;
}
}
false
}
#[inline(always)]
fn scan_variant(
&self,
processed_text: &str,
index: usize,
process_type_mask: u64,
num_variants: usize,
state: &mut SimpleMatchState,
exit_early: bool,
) -> bool {
match &self.ac_matcher {
InternalMatcher::AhoCorasick(ac_matcher) => {
for ac_dedup_result in ac_matcher.find_overlapping_iter(processed_text) {
let pattern_idx = ac_dedup_result.pattern().as_usize();
if self.process_match(
pattern_idx,
index,
process_type_mask,
num_variants,
state,
exit_early,
) {
return true;
}
}
false
}
}
}
#[inline(always)]
fn process_match(
&self,
pattern_idx: usize,
text_index: usize,
process_type_mask: u64,
num_variants: usize,
state: &mut SimpleMatchState,
exit_early: bool,
) -> bool {
let generation = state.generation;
let (start, len) = self.ac_dedup_ranges[pattern_idx];
for entry in &self.ac_dedup_entries[start..start + len] {
let &PatternEntry {
process_type_mask: match_process_type_mask,
rule_idx,
offset,
} = entry;
let rule_idx = rule_idx as usize;
let offset = offset as usize;
if process_type_mask & match_process_type_mask == 0
|| state.word_states[rule_idx].not_generation == generation
{
continue;
}
let rule = &self.rule_hot[rule_idx];
if state.word_states[rule_idx].satisfied_generation == generation {
if exit_early {
return true;
}
continue;
}
if state.word_states[rule_idx].matrix_generation != generation {
state.word_states[rule_idx].matrix_generation = generation;
state.touched_indices.push(rule_idx);
state.word_states[rule_idx].satisfied_mask = 0;
if rule.use_matrix {
Self::init_matrix(
&mut state.matrix[rule_idx],
&rule.segment_counts,
num_variants,
);
}
}
let is_satisfied = if rule.use_matrix {
let flat_matrix = &mut state.matrix[rule_idx];
let bit = &mut flat_matrix[offset * num_variants + text_index];
if offset < rule.and_count {
*bit -= 1; } else {
*bit += 1; }
if offset < rule.and_count {
if *bit <= 0 && offset < BITMASK_CAPACITY {
state.word_states[rule_idx].satisfied_mask |= 1u64 << offset;
}
} else if *bit > 0 {
state.word_states[rule_idx].not_generation = generation;
}
Self::is_rule_satisfied(
rule,
&state.word_states,
&state.matrix,
rule_idx,
num_variants,
)
} else if offset < rule.and_count {
if offset < BITMASK_CAPACITY {
state.word_states[rule_idx].satisfied_mask |= 1u64 << offset;
}
let expected_mask = rule.expected_mask;
let satisfied = state.word_states[rule_idx].satisfied_mask == expected_mask;
if satisfied && rule.and_count == rule.num_splits as usize {
state.word_states[rule_idx].satisfied_generation = generation;
}
satisfied
} else {
state.word_states[rule_idx].not_generation = generation;
false
};
if exit_early
&& is_satisfied
&& rule.and_count == rule.num_splits as usize
&& state.word_states[rule_idx].not_generation != generation
{
return true;
}
}
false
}
#[inline(always)]
fn is_rule_satisfied(
rule: &RuleHot,
word_states: &[WordState],
matrix: &[TinyVec<[i32; 16]>],
rule_idx: usize,
num_variants: usize,
) -> bool {
let expected_mask = rule.expected_mask;
if expected_mask > 0 {
return word_states[rule_idx].satisfied_mask == expected_mask;
}
let num_splits = rule.num_splits as usize;
let flat_matrix = &matrix[rule_idx];
(0..num_splits).all(|s| {
flat_matrix[s * num_variants..(s + 1) * num_variants]
.iter()
.any(|&bit| bit <= 0)
})
}
#[cold]
#[inline(never)]
fn init_matrix(
flat_matrix: &mut TinyVec<[i32; 16]>,
segment_counts: &[i32],
num_variants: usize,
) {
let num_splits = segment_counts.len();
flat_matrix.clear();
flat_matrix.resize(num_splits * num_variants, 0i32);
for (s, &bit) in segment_counts.iter().enumerate() {
let row_start = s * num_variants;
flat_matrix[row_start..row_start + num_variants].fill(bit);
}
}
}