use std::fmt::Display;
use std::iter;
use std::{borrow::Cow, collections::HashMap};
use ahash::AHashMap;
use aho_corasick_unsafe::{AhoCorasick, AhoCorasickBuilder, AhoCorasickKind};
use bitflags::bitflags;
use nohash_hasher::{IntMap, IntSet, IsEnabled};
use serde::{Deserializer, Serializer};
use sonic_rs::{Deserialize, Serialize};
use crate::matcher::{MatchResultTrait, TextMatcherTrait};
use crate::process::process_matcher::{
build_smt_tree, reduce_text_process_emit, reduce_text_process_with_tree, SimpleMatchTypeBitNode,
};
bitflags! {
    /// Bit set selecting which text transformations a word list is matched under.
    ///
    /// Each single-bit flag names one transformation step; the composite flags are
    /// plain unions of those bits. The exact meaning of every step is defined by the
    /// text-processing module, not by this type.
    #[derive(Hash, PartialEq, Eq, Clone, Copy, Debug, Default)]
    pub struct SimpleMatchType: u8 {
        /// Bit 0.
        const None = 1 << 0;
        /// Bit 1.
        const Fanjian = 1 << 1;
        /// Bit 2.
        const WordDelete = 1 << 2;
        /// Bit 3.
        const TextDelete = 1 << 3;
        /// `WordDelete | TextDelete` (`0b00001100`).
        const Delete = Self::WordDelete.bits() | Self::TextDelete.bits();
        /// Bit 4.
        const Normalize = 1 << 4;
        /// `Delete | Normalize` (`0b00011100`).
        const DeleteNormalize = Self::Delete.bits() | Self::Normalize.bits();
        /// `Fanjian | DeleteNormalize` (`0b00011110`).
        const FanjianDeleteNormalize = Self::Fanjian.bits() | Self::DeleteNormalize.bits();
        /// Bit 5.
        const PinYin = 1 << 5;
        /// Bit 6.
        const PinYinChar = 1 << 6;
    }
}
impl Serialize for SimpleMatchType {
    /// Serializes the flag set as its raw `u8` bit pattern.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let raw_bits: u8 = self.bits();
        serializer.serialize_u8(raw_bits)
    }
}
impl<'de> Deserialize<'de> for SimpleMatchType {
    /// Reads a `u8` and reinterprets it as a flag set.
    ///
    /// Bits that do not correspond to a declared flag are kept as-is
    /// (`from_bits_retain`), so no input byte is rejected here.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        u8::deserialize(deserializer).map(SimpleMatchType::from_bits_retain)
    }
}
impl Display for SimpleMatchType {
    /// Writes the lower-cased names of the contained flags joined by `_`.
    ///
    /// The joined string is emitted through its `Debug` formatting, so the output
    /// is surrounded by double quotes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let lowered_names: Vec<String> = self
            .iter_names()
            .map(|(flag_name, _)| flag_name.to_lowercase())
            .collect();
        let joined = lowered_names.join("_");
        write!(f, "{:?}", joined)
    }
}
// Marker impl from `nohash_hasher`; this file uses `SimpleMatchType` as the key of
// `IntMap`s (e.g. `SimpleMatchTypeWordMap`, `SimpleMatcher::smt_ac_table_map`).
impl IsEnabled for SimpleMatchType {}
/// Word table keyed first by match type, then by `u32` word id, with borrowed word text
/// as the value.
// NOTE(review): not referenced elsewhere in this part of the file; presumably the input
// shape used by callers that build a `SimpleMatcher` — confirm against call sites.
pub type SimpleMatchTypeWordMap<'a> = IntMap<SimpleMatchType, IntMap<u32, &'a str>>;
/// Per-word configuration recorded by `SimpleMatcher::build_simple_ac_table`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct WordConf {
// The word exactly as supplied by the caller (including any `&` / `~` separators).
word: String,
// Initial counter for each distinct split segment: the first `not_index` entries belong
// to AND segments (initialised to their occurrence count), the rest to NOT segments
// (initialised to 0 and lowered by one per repeated occurrence).
split_bit: Vec<i32>,
// Number of distinct AND segments; offsets `>= not_index` in `split_bit` are NOT segments.
not_index: usize,
}
/// Aho-Corasick automaton for one match type plus the mapping from its patterns back to
/// the words that contain them.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
struct SimpleAcTable {
// Automaton built over the de-duplicated, processed split segments.
ac_matcher: AhoCorasick,
// Indexed by automaton pattern id: every `(word_id, offset)` whose segment produced that
// pattern, where `offset` indexes into the word's `WordConf::split_bit`.
ac_dedup_word_conf_list: Vec<Vec<(u32, usize)>>,
}
/// One matched word reported by `SimpleMatcher`.
#[derive(Debug, Serialize)]
pub struct SimpleResult<'a> {
/// Id the word was registered under.
pub word_id: u32,
/// Text of the matched word; `SimpleMatcher::process` fills this with a borrow of the
/// word stored inside the matcher.
pub word: Cow<'a, str>,
}
impl MatchResultTrait<'_> for SimpleResult<'_> {
fn word_id(&self) -> u32 {
self.word_id
}
fn word(&self) -> &str {
self.word.as_ref()
}
}
/// Multi-pattern word matcher supporting `&` (all segments required) and `~` (segment
/// forbidden) combinations, evaluated over one or more processed variants of the text.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct SimpleMatcher {
// Shared text-processing tree; only built when the matcher was created from at least
// four match types, otherwise each match type processes the text on its own.
smt_tree: Option<Vec<SimpleMatchTypeBitNode>>,
// One automaton table per match type, keyed by the match type with `WordDelete` removed.
smt_ac_table_map: IntMap<SimpleMatchType, SimpleAcTable>,
// Configuration of every registered word, keyed by word id.
simple_word_conf_map: IntMap<u32, WordConf>,
}
impl SimpleMatcher {
pub fn new<I, S1, S2>(
smt_word_map: &HashMap<SimpleMatchType, HashMap<u32, I, S1>, S2>,
) -> SimpleMatcher
where
I: AsRef<str>,
{
let mut simple_matcher = SimpleMatcher {
smt_tree: None,
smt_ac_table_map: IntMap::default(),
simple_word_conf_map: IntMap::default(),
};
for (&simple_match_type, simple_word_map) in smt_word_map {
let simple_ac_table = simple_matcher.build_simple_ac_table(
simple_match_type - SimpleMatchType::TextDelete,
simple_word_map,
);
simple_matcher.smt_ac_table_map.insert(
simple_match_type - SimpleMatchType::WordDelete,
simple_ac_table,
);
}
if smt_word_map.len() >= 4 {
simple_matcher.smt_tree = Some(build_smt_tree(
&simple_matcher
.smt_ac_table_map
.keys()
.copied()
.collect::<Vec<SimpleMatchType>>(),
));
}
simple_matcher
}
fn build_simple_ac_table<I, S2>(
&mut self,
simple_match_type: SimpleMatchType,
simple_word_map: &HashMap<u32, I, S2>,
) -> SimpleAcTable
where
I: AsRef<str>,
{
let mut ac_dedup_word_id = 0;
let mut ac_dedup_word_conf_list = Vec::with_capacity(simple_word_map.len());
let mut ac_dedup_word_list = Vec::with_capacity(simple_word_map.len());
let mut ac_dedup_word_id_map = AHashMap::with_capacity(simple_word_map.len());
for (&simple_word_id, simple_word) in simple_word_map {
let mut ac_split_word_and_counter = AHashMap::default();
let mut ac_split_word_not_counter = AHashMap::default();
let mut start = 0;
let mut is_and = false;
let mut is_not = false;
for (index, char) in simple_word.as_ref().match_indices(['&', '~']) {
if (is_and || start == 0) && start != index {
ac_split_word_and_counter
.entry(unsafe { simple_word.as_ref().get_unchecked(start..index) })
.and_modify(|cnt| *cnt += 1)
.or_insert(1);
}
if is_not && start != index {
ac_split_word_not_counter
.entry(unsafe { simple_word.as_ref().get_unchecked(start..index) })
.and_modify(|cnt| *cnt -= 1)
.or_insert(0);
}
match char {
"&" => {
is_and = true;
is_not = false;
start = index + 1;
}
"~" => {
is_and = false;
is_not = true;
start = index + 1
}
_ => {}
}
}
if (is_and || start == 0) && start != simple_word.as_ref().len() {
ac_split_word_and_counter
.entry(unsafe { simple_word.as_ref().get_unchecked(start..) })
.and_modify(|cnt| *cnt += 1)
.or_insert(1);
}
if is_not && start != simple_word.as_ref().len() {
ac_split_word_not_counter
.entry(unsafe { simple_word.as_ref().get_unchecked(start..) })
.and_modify(|cnt| *cnt -= 1)
.or_insert(0);
}
let not_index = ac_split_word_and_counter.len();
let split_bit = ac_split_word_and_counter
.values()
.copied()
.chain(ac_split_word_not_counter.values().copied())
.collect::<Vec<i32>>();
self.simple_word_conf_map.insert(
simple_word_id,
WordConf {
word: simple_word.as_ref().to_owned(),
split_bit,
not_index,
},
);
for (offset, &split_word) in ac_split_word_and_counter
.keys()
.chain(ac_split_word_not_counter.keys())
.enumerate()
{
for ac_word in reduce_text_process_emit(simple_match_type, split_word) {
if let Some(ac_dedup_word_id) = ac_dedup_word_id_map.get(ac_word.as_ref()) {
let word_conf_list: &mut Vec<(u32, usize)> = unsafe {
ac_dedup_word_conf_list.get_unchecked_mut(*ac_dedup_word_id as usize)
};
word_conf_list.push((simple_word_id, offset));
} else {
ac_dedup_word_id_map.insert(ac_word.clone(), ac_dedup_word_id);
ac_dedup_word_conf_list.push(vec![(simple_word_id, offset)]);
ac_dedup_word_list.push(ac_word);
ac_dedup_word_id += 1;
}
}
}
}
#[cfg(feature = "dfa")]
let aho_corasick_kind = AhoCorasickKind::DFA;
#[cfg(not(feature = "dfa"))]
let aho_corasick_kind = AhoCorasickKind::ContiguousNFA;
#[cfg(feature = "serde")]
let prefilter = false;
#[cfg(not(feature = "serde"))]
let prefilter = true;
SimpleAcTable {
ac_matcher: AhoCorasickBuilder::new()
.kind(Some(aho_corasick_kind))
.ascii_case_insensitive(true)
.prefilter(prefilter)
.build(ac_dedup_word_list.iter().map(|ac_word| ac_word.as_ref()))
.unwrap(),
ac_dedup_word_conf_list,
}
}
}
impl<'a> TextMatcherTrait<'a, SimpleResult<'a>> for SimpleMatcher {
fn is_match(&self, text: &str) -> bool {
if text.is_empty() {
return false;
}
let mut word_id_split_bit_map = IntMap::default();
let mut word_id_set = IntSet::default();
let mut not_word_id_set = IntSet::default();
if let Some(smt_tree) = &self.smt_tree {
let (smt_index_set_map, processed_text_list) =
reduce_text_process_with_tree(smt_tree, text);
for (&simple_match_type, simple_ac_table) in &self.smt_ac_table_map {
let processed_index_set =
unsafe { smt_index_set_map.get(&simple_match_type).unwrap_unchecked() };
let processed_times = processed_index_set.len();
for (index, &processed_index) in processed_index_set.iter().enumerate() {
for ac_dedup_result in unsafe {
simple_ac_table
.ac_matcher
.try_find_overlapping_iter(
processed_text_list.get_unchecked(processed_index).as_ref(),
)
.unwrap_unchecked()
} {
for &(word_id, offset) in unsafe {
simple_ac_table
.ac_dedup_word_conf_list
.get_unchecked(ac_dedup_result.pattern().as_usize())
} {
if not_word_id_set.contains(&word_id) {
continue;
}
let word_conf = unsafe {
self.simple_word_conf_map.get(&word_id).unwrap_unchecked()
};
let split_bit_matrix =
word_id_split_bit_map.entry(word_id).or_insert_with(|| {
word_conf
.split_bit
.iter()
.map(|&bit| {
iter::repeat(bit).take(processed_times).collect()
})
.collect::<Vec<Vec<i32>>>()
});
unsafe {
let bit = split_bit_matrix
.get_unchecked_mut(offset)
.get_unchecked_mut(index);
*bit = bit
.unchecked_add((offset < word_conf.not_index) as i32 * -2 + 1);
if offset >= word_conf.not_index && *bit > 0 {
not_word_id_set.insert(word_id);
word_id_set.remove(&word_id);
continue;
}
if split_bit_matrix
.iter()
.all(|split_bit_vec| split_bit_vec.iter().any(|&bit| bit <= 0))
{
word_id_set.insert(word_id);
}
}
}
}
}
if !word_id_set.is_empty() {
return true;
}
}
} else {
for (&simple_match_type, simple_ac_table) in &self.smt_ac_table_map {
let processed_text_list = reduce_text_process_emit(simple_match_type, text);
let processed_times = processed_text_list.len();
for (index, processed_text) in processed_text_list.iter().enumerate() {
for ac_dedup_result in unsafe {
simple_ac_table
.ac_matcher
.try_find_overlapping_iter(processed_text.as_ref())
.unwrap_unchecked()
} {
for &(word_id, offset) in unsafe {
simple_ac_table
.ac_dedup_word_conf_list
.get_unchecked(ac_dedup_result.pattern().as_usize())
} {
if not_word_id_set.contains(&word_id) {
continue;
}
let word_conf = unsafe {
self.simple_word_conf_map.get(&word_id).unwrap_unchecked()
};
let split_bit_matrix =
word_id_split_bit_map.entry(word_id).or_insert_with(|| {
word_conf
.split_bit
.iter()
.map(|&bit| {
iter::repeat(bit).take(processed_times).collect()
})
.collect::<Vec<Vec<i32>>>()
});
unsafe {
let bit = split_bit_matrix
.get_unchecked_mut(offset)
.get_unchecked_mut(index);
*bit = bit
.unchecked_add((offset < word_conf.not_index) as i32 * -2 + 1);
if offset >= word_conf.not_index && *bit > 0 {
not_word_id_set.insert(word_id);
word_id_set.remove(&word_id);
continue;
}
if split_bit_matrix
.iter()
.all(|split_bit_vec| split_bit_vec.iter().any(|&bit| bit <= 0))
{
word_id_set.insert(word_id);
}
}
}
}
}
if !word_id_set.is_empty() {
return true;
}
}
}
false
}
fn process(&'a self, text: &str) -> Vec<SimpleResult<'a>> {
if text.is_empty() {
return Vec::new();
}
let mut word_id_split_bit_map = IntMap::default();
let mut not_word_id_set = IntSet::default();
if let Some(smt_tree) = &self.smt_tree {
let (smt_index_set_map, processed_text_list) =
reduce_text_process_with_tree(smt_tree, text);
for (&simple_match_type, simple_ac_table) in &self.smt_ac_table_map {
let processed_index_set =
unsafe { smt_index_set_map.get(&simple_match_type).unwrap_unchecked() };
let processed_times = processed_index_set.len();
for (index, &processed_index) in processed_index_set.iter().enumerate() {
for ac_dedup_result in unsafe {
simple_ac_table
.ac_matcher
.try_find_overlapping_iter(
processed_text_list.get_unchecked(processed_index).as_ref(),
)
.unwrap_unchecked()
} {
for &(word_id, offset) in unsafe {
simple_ac_table
.ac_dedup_word_conf_list
.get_unchecked(ac_dedup_result.pattern().as_usize())
} {
if not_word_id_set.contains(&word_id) {
continue;
}
let word_conf = unsafe {
self.simple_word_conf_map.get(&word_id).unwrap_unchecked()
};
let split_bit_matrix =
word_id_split_bit_map.entry(word_id).or_insert_with(|| {
word_conf
.split_bit
.iter()
.map(|&bit| {
iter::repeat(bit).take(processed_times).collect()
})
.collect::<Vec<Vec<i32>>>()
});
unsafe {
let split_bit = split_bit_matrix
.get_unchecked_mut(offset)
.get_unchecked_mut(index);
*split_bit = split_bit
.unchecked_add((offset < word_conf.not_index) as i32 * -2 + 1);
if offset >= word_conf.not_index && *split_bit > 0 {
not_word_id_set.insert(word_id);
word_id_split_bit_map.remove(&word_id);
}
};
}
}
}
}
} else {
for (&simple_match_type, simple_ac_table) in &self.smt_ac_table_map {
let processed_text_list = reduce_text_process_emit(simple_match_type, text);
let processed_times = processed_text_list.len();
for (index, processed_text) in processed_text_list.iter().enumerate() {
for ac_dedup_result in unsafe {
simple_ac_table
.ac_matcher
.try_find_overlapping_iter(processed_text.as_ref())
.unwrap_unchecked()
} {
for &(word_id, offset) in unsafe {
simple_ac_table
.ac_dedup_word_conf_list
.get_unchecked(ac_dedup_result.pattern().as_usize())
} {
if not_word_id_set.contains(&word_id) {
continue;
}
let word_conf = unsafe {
self.simple_word_conf_map.get(&word_id).unwrap_unchecked()
};
let split_bit_matrix =
word_id_split_bit_map.entry(word_id).or_insert_with(|| {
word_conf
.split_bit
.iter()
.map(|&bit| {
iter::repeat(bit).take(processed_times).collect()
})
.collect::<Vec<Vec<i32>>>()
});
unsafe {
let split_bit = split_bit_matrix
.get_unchecked_mut(offset)
.get_unchecked_mut(index);
*split_bit = split_bit
.unchecked_add((offset < word_conf.not_index) as i32 * -2 + 1);
if offset >= word_conf.not_index && *split_bit > 0 {
not_word_id_set.insert(word_id);
word_id_split_bit_map.remove(&word_id);
}
};
}
}
}
}
}
word_id_split_bit_map
.into_iter()
.filter_map(|(word_id, split_bit_matrix)| {
split_bit_matrix
.into_iter()
.all(|split_bit_vec| split_bit_vec.into_iter().any(|split_bit| split_bit <= 0))
.then_some(SimpleResult {
word_id,
word: Cow::Borrowed(
&unsafe { self.simple_word_conf_map.get(&word_id).unwrap_unchecked() }
.word,
),
})
})
.collect()
}
}