use std::borrow::Cow;
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
#[cfg(not(feature = "vectorscan"))]
use aho_corasick::AhoCorasickKind;
use aho_corasick::{AhoCorasick, AhoCorasickBuilder};
use serde::Serialize;
use tinyvec::TinyVec;
use crate::process::process_matcher::{
ProcessType, ProcessTypeBitNode, ProcessedTextMasks, build_process_type_tree,
reduce_text_process_emit, reduce_text_process_with_tree, return_processed_string_to_pool,
};
#[cfg(feature = "vectorscan")]
use crate::vectorscan::VectorscanScanner;
/// Scratch space reused across match calls (one instance lives in a thread-local).
///
/// Per-word-conf slots are not cleared between calls. Instead each slot records the
/// generation in which it was last written, and a slot whose stamp differs from
/// `generation` is treated as stale.
struct SimpleMatchState {
    /// Current generation; advanced by every call to [`SimpleMatchState::prepare`].
    generation: u32,
    /// Per word conf: flattened counter matrix (sub-word rows × processed-text columns).
    matrix: Vec<TinyVec<[i32; 16]>>,
    /// Per word conf: generation in which its `matrix` entry was last initialised.
    matrix_generation: Vec<u32>,
    /// Per word conf: generation in which it was vetoed by a forbidden sub-word.
    not_flags_generation: Vec<u32>,
    /// Word confs initialised during the current generation, in first-touch order.
    touched_indices: Vec<usize>,
}

impl SimpleMatchState {
    /// Creates an empty state; all buffers are grown lazily by [`Self::prepare`].
    fn new() -> Self {
        SimpleMatchState {
            generation: 0,
            matrix: Vec::new(),
            matrix_generation: Vec::new(),
            not_flags_generation: Vec::new(),
            touched_indices: Vec::new(),
        }
    }

    /// Starts a new generation and ensures at least `size` word-conf slots exist.
    ///
    /// If advancing the counter would overflow, every stamp is reset to 0 and the
    /// counter restarts at 1, so no stale stamp can equal the new generation.
    fn prepare(&mut self, size: usize) {
        match self.generation.checked_add(1) {
            Some(next_generation) => self.generation = next_generation,
            None => {
                self.matrix_generation.fill(0);
                self.not_flags_generation.fill(0);
                self.generation = 1;
            }
        }

        if size > self.matrix.len() {
            // Fresh slots are stamped 0, which never equals a live generation (>= 1).
            self.matrix.resize(size, TinyVec::new());
            self.matrix_generation.resize(size, 0);
            self.not_flags_generation.resize(size, 0);
        }

        self.touched_indices.clear();
    }
}
thread_local! {
    /// Scratch state shared by every `SimpleMatcher` call made on this thread.
    static SIMPLE_MATCH_STATE: RefCell<SimpleMatchState> = SimpleMatchState::new().into();
}

/// Borrowed word table: process type → (word id → word pattern).
pub type SimpleTable<'a> = HashMap<ProcessType, HashMap<u32, &'a str>>;

/// Same shape as [`SimpleTable`], but each word pattern is a `Cow<str>` and may be
/// either borrowed or owned.
pub type SimpleTableSerde<'a> = HashMap<ProcessType, HashMap<u32, Cow<'a, str>>>;
/// One matched word: the caller-supplied id and the original pattern text.
#[derive(Serialize, Debug)]
pub struct SimpleResult<'a> {
    pub word_id: u32,
    pub word: Cow<'a, str>,
}

/// Compiled form of a single word pattern.
#[derive(Debug, Clone)]
struct WordConf {
    /// Caller-supplied id, reported back as [`SimpleResult::word_id`].
    word_id: u32,
    /// The original, unsplit pattern text.
    word: String,
    /// Initial counter for each sub-word row: required sub-words first, then
    /// forbidden sub-words.
    split_bit: Vec<i32>,
    /// Row index of the first forbidden sub-word (equals the number of required
    /// sub-words).
    not_offset: usize,
}

/// Back-reference from one deduplicated automaton pattern to one sub-word row of
/// one compiled word.
#[derive(Debug, Clone)]
struct WordConfEntry {
    /// Process type the word was registered under; the hit only counts on text
    /// variants whose mask contains this type.
    process_type: ProcessType,
    /// Index into `SimpleMatcher::word_conf_list`.
    word_conf_idx: usize,
    /// Sub-word row within that word's counter matrix.
    offset: usize,
}

/// Literal multi-pattern engine. Which variants exist depends on the `vectorscan`
/// feature; with it enabled, the Aho-Corasick variant is still used for an empty
/// pattern set.
#[derive(Debug, Clone)]
enum AcMatcher {
    // NOTE(review): presumably silences a dead-code lint in some feature
    // combination — confirm it is still needed.
    #[cfg_attr(feature = "vectorscan", allow(dead_code))]
    AhoCorasick(AhoCorasick),
    #[cfg(feature = "vectorscan")]
    Vectorscan(VectorscanScanner),
}

/// Matches texts against `&`/`~` composed word patterns across several text
/// process types.
#[derive(Debug, Clone)]
pub struct SimpleMatcher {
    /// Tree built from the registered process types; drives text preprocessing.
    process_type_tree: Vec<ProcessTypeBitNode>,
    /// Automaton over every deduplicated sub-word literal.
    ac_matcher: AcMatcher,
    /// Indexed by automaton pattern id: every sub-word row that literal stands for.
    ac_dedup_word_conf_list: Vec<Vec<WordConfEntry>>,
    /// All compiled words, indexed by `WordConfEntry::word_conf_idx`.
    word_conf_list: Vec<WordConf>,
}
impl SimpleMatcher {
pub fn new<'a, I, S1, S2>(
process_type_word_map: &'a HashMap<ProcessType, HashMap<u32, I, S1>, S2>,
) -> SimpleMatcher
where
I: AsRef<str> + 'a,
{
let word_size: usize = process_type_word_map.values().map(|m| m.len()).sum();
let mut process_type_set = HashSet::with_capacity(process_type_word_map.len());
let mut ac_dedup_word_conf_list = Vec::with_capacity(word_size);
let mut word_conf_list: Vec<WordConf> = Vec::with_capacity(word_size);
let mut word_id_to_idx: HashMap<u32, usize> = HashMap::with_capacity(word_size);
let mut ac_dedup_word_id = 0;
let mut ac_dedup_word_list = Vec::with_capacity(word_size);
let mut ac_dedup_word_id_map = HashMap::with_capacity(word_size);
for (&process_type, simple_word_map) in process_type_word_map {
let word_process_type = process_type - ProcessType::Delete;
process_type_set.insert(process_type.bits());
for (&simple_word_id, simple_word) in simple_word_map {
if simple_word.as_ref().is_empty() {
continue;
}
let mut ac_split_word_and_counter = HashMap::new();
let mut ac_split_word_not_counter = HashMap::new();
let mut start = 0;
let mut current_is_not = false;
let mut add_sub_word = |word: &'a str, is_not: bool| {
if word.is_empty() {
return;
}
if is_not {
let entry = ac_split_word_not_counter.entry(word).or_insert(1);
*entry -= 1;
} else {
let entry = ac_split_word_and_counter.entry(word).or_insert(0);
*entry += 1;
}
};
for (index, char) in simple_word.as_ref().match_indices(['&', '~']) {
add_sub_word(&simple_word.as_ref()[start..index], current_is_not);
current_is_not = char == "~";
start = index + 1;
}
add_sub_word(&simple_word.as_ref()[start..], current_is_not);
if ac_split_word_and_counter.is_empty() && ac_split_word_not_counter.is_empty() {
continue;
}
let not_offset = ac_split_word_and_counter.len();
let split_bit = ac_split_word_and_counter
.values()
.copied()
.chain(ac_split_word_not_counter.values().copied())
.collect::<Vec<i32>>();
let word_conf_idx = if let Some(&existing_idx) = word_id_to_idx.get(&simple_word_id)
{
word_conf_list[existing_idx] = WordConf {
word_id: simple_word_id,
word: simple_word.as_ref().to_owned(),
split_bit,
not_offset,
};
existing_idx
} else {
let idx = word_conf_list.len();
word_id_to_idx.insert(simple_word_id, idx);
word_conf_list.push(WordConf {
word_id: simple_word_id,
word: simple_word.as_ref().to_owned(),
split_bit,
not_offset,
});
idx
};
for (offset, &split_word) in ac_split_word_and_counter
.keys()
.chain(ac_split_word_not_counter.keys())
.enumerate()
{
for ac_word in reduce_text_process_emit(word_process_type, split_word) {
let Some(&ac_dedup_word_id) = ac_dedup_word_id_map.get(ac_word.as_ref())
else {
ac_dedup_word_id_map.insert(ac_word.clone(), ac_dedup_word_id);
ac_dedup_word_conf_list.push(vec![WordConfEntry {
process_type,
word_conf_idx,
offset,
}]);
ac_dedup_word_list.push(ac_word);
ac_dedup_word_id += 1;
continue;
};
ac_dedup_word_conf_list[ac_dedup_word_id as usize].push(WordConfEntry {
process_type,
word_conf_idx,
offset,
});
}
}
}
}
let process_type_tree = build_process_type_tree(&process_type_set);
let patterns = ac_dedup_word_list
.iter()
.map(|ac_word| ac_word.as_ref())
.collect::<Vec<_>>();
#[cfg(feature = "vectorscan")]
let ac_matcher = if patterns.is_empty() {
AcMatcher::AhoCorasick(AhoCorasickBuilder::new().build(&patterns).unwrap())
} else {
let flags = vec![0u32; patterns.len()];
AcMatcher::Vectorscan(
VectorscanScanner::new_literal(&patterns, &flags)
.expect("failed to compile vectorscan literal database"),
)
};
#[cfg(not(feature = "vectorscan"))]
let ac_matcher = {
#[cfg(feature = "dfa")]
let aho_corasick_kind = AhoCorasickKind::DFA;
#[cfg(not(feature = "dfa"))]
let aho_corasick_kind = AhoCorasickKind::ContiguousNFA;
AcMatcher::AhoCorasick(
AhoCorasickBuilder::new()
.kind(Some(aho_corasick_kind))
.build(&patterns)
.unwrap(),
)
};
SimpleMatcher {
process_type_tree,
ac_matcher,
ac_dedup_word_conf_list,
word_conf_list,
}
}
fn _word_match_with_processed_text_process_type_masks<'a>(
&'a self,
processed_text_process_type_masks: &ProcessedTextMasks<'a>,
state: &mut SimpleMatchState,
) {
if self.ac_dedup_word_conf_list.is_empty() {
return;
}
let processed_times = processed_text_process_type_masks.len();
for (index, (processed_text, process_type_mask)) in
processed_text_process_type_masks.iter().enumerate()
{
match &self.ac_matcher {
AcMatcher::AhoCorasick(ac_matcher) => {
for ac_dedup_result in ac_matcher.find_overlapping_iter(processed_text.as_ref())
{
let pattern_idx = ac_dedup_result.pattern().as_usize();
self.process_match(
pattern_idx,
index,
*process_type_mask,
processed_times,
state,
);
}
}
#[cfg(feature = "vectorscan")]
AcMatcher::Vectorscan(scanner) => {
let _ = scanner.scan(processed_text.as_ref().as_bytes(), |pattern_idx| {
self.process_match(
pattern_idx,
index,
*process_type_mask,
processed_times,
state,
);
});
}
}
}
}
#[inline]
fn process_match(
&self,
pattern_idx: usize,
text_index: usize,
process_type_mask: u64,
processed_times: usize,
state: &mut SimpleMatchState,
) {
let generation = state.generation;
for &WordConfEntry {
process_type: match_process_type,
word_conf_idx,
offset,
} in &self.ac_dedup_word_conf_list[pattern_idx]
{
if process_type_mask & (1u64 << match_process_type.bits()) == 0
|| state.not_flags_generation[word_conf_idx] == generation
{
continue;
}
let word_conf = &self.word_conf_list[word_conf_idx];
if state.matrix_generation[word_conf_idx] != generation {
state.matrix_generation[word_conf_idx] = generation;
state.touched_indices.push(word_conf_idx);
let num_splits = word_conf.split_bit.len();
let flat_matrix = &mut state.matrix[word_conf_idx];
flat_matrix.clear();
flat_matrix.resize(num_splits * processed_times, 0i32);
for (s, &bit) in word_conf.split_bit.iter().enumerate() {
let row_start = s * processed_times;
flat_matrix[row_start..row_start + processed_times].fill(bit);
}
}
let flat_matrix = &mut state.matrix[word_conf_idx];
let bit = &mut flat_matrix[offset * processed_times + text_index];
*bit += (offset < word_conf.not_offset) as i32 * -2 + 1;
if offset >= word_conf.not_offset && *bit > 0 {
state.not_flags_generation[word_conf_idx] = generation;
}
}
}
pub fn is_match<'a>(&'a self, text: &'a str) -> bool {
if text.is_empty() {
return false;
}
let processed_text_process_type_masks =
reduce_text_process_with_tree(&self.process_type_tree, text);
let result = self.is_match_preprocessed(&processed_text_process_type_masks);
return_processed_string_to_pool(processed_text_process_type_masks);
result
}
pub fn process<'a>(&'a self, text: &'a str) -> Vec<SimpleResult<'a>> {
if text.is_empty() {
return Vec::new();
}
let processed_text_process_type_masks =
reduce_text_process_with_tree(&self.process_type_tree, text);
let result = self.process_preprocessed(&processed_text_process_type_masks);
return_processed_string_to_pool(processed_text_process_type_masks);
result
}
fn is_match_preprocessed<'a>(
&'a self,
processed_text_process_type_masks: &ProcessedTextMasks<'a>,
) -> bool {
SIMPLE_MATCH_STATE.with(|state| {
let mut state = state.borrow_mut();
state.prepare(self.word_conf_list.len());
self._word_match_with_processed_text_process_type_masks(
processed_text_process_type_masks,
&mut state,
);
let generation = state.generation;
let processed_times = processed_text_process_type_masks.len();
state.touched_indices.iter().any(|&word_conf_idx| {
if state.not_flags_generation[word_conf_idx] == generation {
return false;
}
let num_splits = self.word_conf_list[word_conf_idx].split_bit.len();
let flat_matrix = &state.matrix[word_conf_idx];
(0..num_splits).all(|s| {
flat_matrix[s * processed_times..(s + 1) * processed_times]
.iter()
.any(|&bit| bit <= 0)
})
})
})
}
fn process_preprocessed<'a>(
&'a self,
processed_text_process_type_masks: &ProcessedTextMasks<'a>,
) -> Vec<SimpleResult<'a>> {
SIMPLE_MATCH_STATE.with(|state| {
let mut state = state.borrow_mut();
state.prepare(self.word_conf_list.len());
self._word_match_with_processed_text_process_type_masks(
processed_text_process_type_masks,
&mut state,
);
let generation = state.generation;
let processed_times = processed_text_process_type_masks.len();
state
.touched_indices
.iter()
.filter_map(|&word_conf_idx| {
if state.not_flags_generation[word_conf_idx] == generation {
return None;
}
let word_conf = &self.word_conf_list[word_conf_idx];
let num_splits = word_conf.split_bit.len();
let flat_matrix = &state.matrix[word_conf_idx];
(0..num_splits)
.all(|s| {
flat_matrix[s * processed_times..(s + 1) * processed_times]
.iter()
.any(|&bit| bit <= 0)
})
.then_some(SimpleResult {
word_id: word_conf.word_id,
word: Cow::Borrowed(&word_conf.word),
})
})
.collect()
})
}
}