use std::fmt::Display;
use std::iter;
use std::simd::Simd;
use std::{borrow::Cow, collections::HashMap};
use ahash::AHashMap;
use aho_corasick::{AhoCorasick, AhoCorasickBuilder, AhoCorasickKind::DFA};
use bitflags::bitflags;
use nohash_hasher::{IntMap, IntSet, IsEnabled};
use serde::{Deserializer, Serializer};
use sonic_rs::{Deserialize, Serialize};
use tinyvec::ArrayVec;
use crate::process::process_matcher::reduce_text_process;
use crate::{MatchResultTrait, TextMatcherTrait};
/// Maximum number of distinct comma-separated sub-words tracked per word; it is
/// also the lane count of each word's split-bit vector.
const WORD_COMBINATION_LIMIT: usize = 32;
/// All-zero lane vector: a word is reported once its combined split-bit state
/// compares equal to this value.
const ZEROS: Simd<u8, WORD_COMBINATION_LIMIT> =
    Simd::from_array([u8::MIN; WORD_COMBINATION_LIMIT]);
bitflags! {
    /// Set of text-processing steps, stored as raw `u8` bits and interpreted by
    /// `reduce_text_process`.
    ///
    /// `None` is a real bit (`0b1`), not the empty set. Inside `SimpleMatcher`,
    /// `WordDelete` only influences how words are processed and `TextDelete`
    /// only influences how input text is processed.
    ///
    /// Declaration order matters: `Display` lists flag names in this order.
    #[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)]
    pub struct SimpleMatchType: u8 {
        const None = 0b0000_0001;
        const Fanjian = 0b0000_0010;
        const WordDelete = 0b0000_0100;
        const TextDelete = 0b0000_1000;
        /// `WordDelete | TextDelete`.
        const Delete = Self::WordDelete.bits() | Self::TextDelete.bits();
        const Normalize = 0b0001_0000;
        /// `Delete | Normalize`.
        const DeleteNormalize = Self::Delete.bits() | Self::Normalize.bits();
        /// `Fanjian | DeleteNormalize`.
        const FanjianDeleteNormalize = Self::Fanjian.bits() | Self::DeleteNormalize.bits();
        const PinYin = 0b0010_0000;
        const PinYinChar = 0b0100_0000;
    }
}
impl Serialize for SimpleMatchType {
    /// Writes the flag set as its raw `u8` bit pattern.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_u8(self.bits())
    }
}
impl<'de> Deserialize<'de> for SimpleMatchType {
    /// Reads a raw `u8` and keeps every bit, including bits that correspond to
    /// no named flag.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        u8::deserialize(deserializer).map(Self::from_bits_retain)
    }
}
impl Display for SimpleMatchType {
    /// Formats the contained named flags, lower-cased and joined with `_`.
    ///
    /// NOTE(review): the joined string is written with `{:?}`, so the output is
    /// wrapped in double quotes (e.g. `"fanjian_worddelete"`). This is kept
    /// as-is in case callers rely on it.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut joined = String::new();
        for (position, (name, _)) in self.iter_names().enumerate() {
            if position > 0 {
                joined.push('_');
            }
            joined.push_str(&name.to_lowercase());
        }
        write!(f, "{joined:?}")
    }
}
// Opt-in marker required by nohash_hasher so `SimpleMatchType` can be used as the
// key of `IntMap`/`IntSet` (e.g. `SimpleMatcher::simple_match_type_ac_table_map`).
impl IsEnabled for SimpleMatchType {}
/// Borrowed word table: match type -> (word id -> word text).
///
/// `IntMap` is a `std::collections::HashMap` with a nohash hasher, so this shape
/// fits the generic input accepted by `SimpleMatcher::new`.
pub type SimpleMatchTypeWordMap<'a> = IntMap<SimpleMatchType, IntMap<u64, &'a str>>;
/// Per-word data stored by `SimpleMatcher`, keyed by word id.
#[derive(Debug, Clone)]
struct WordConf {
/// The word exactly as supplied (possibly a comma-separated sub-word combination).
word: String,
/// Initial match state, one lane per distinct sub-word; lanes beyond the number
/// of sub-words are zero. Copied once per processed text variant while scanning.
split_bit: Simd<u8, WORD_COMBINATION_LIMIT>,
}
/// One compiled Aho-Corasick automaton plus the owner of each of its patterns.
#[derive(Debug, Clone)]
struct SimpleAcTable {
ac_matcher: AhoCorasick,
/// Indexed by pattern id: `(word_id, offset)`, where `offset` is the lane of the
/// matched sub-word inside that word's split-bit vector.
ac_word_conf_list: Vec<(u64, usize)>,
}
/// A word reported by `SimpleMatcher`.
#[derive(Debug, Serialize)]
pub struct SimpleResult<'a> {
/// Id of the matched word, as supplied to `SimpleMatcher::new`.
pub word_id: u64,
/// The matched word; `process` fills this with a borrow of the word stored in
/// the matcher.
pub word: Cow<'a, str>,
}
impl MatchResultTrait<'_> for SimpleResult<'_> {
    /// Id of the matched word.
    fn word_id(&self) -> u64 {
        self.word_id
    }

    /// The matched word as a string slice (borrowed from the `Cow`).
    fn word(&self) -> &str {
        &self.word
    }
}
/// Matches text against words and comma-separated word combinations, using one
/// Aho-Corasick table per text-processing `SimpleMatchType`.
#[derive(Clone)]
pub struct SimpleMatcher {
/// Match type used to process input text -> table of the words registered for it.
simple_match_type_ac_table_map: IntMap<SimpleMatchType, SimpleAcTable>,
/// Word id -> word text and its initial split-bit state.
simple_wordconf_map: IntMap<u64, WordConf>,
}
impl SimpleMatcher {
pub fn new<I, S1, S2>(
simple_match_type_word_map: &HashMap<SimpleMatchType, HashMap<u64, I, S1>, S2>,
) -> SimpleMatcher
where
I: AsRef<str>,
{
let mut simple_matcher = SimpleMatcher {
simple_match_type_ac_table_map: IntMap::default(),
simple_wordconf_map: IntMap::default(),
};
for (simple_match_type, simple_word_map) in simple_match_type_word_map {
let simple_ac_table = simple_matcher.build_simple_ac_table(
*simple_match_type - SimpleMatchType::TextDelete,
simple_word_map,
);
simple_matcher.simple_match_type_ac_table_map.insert(
*simple_match_type - SimpleMatchType::WordDelete,
simple_ac_table,
);
}
simple_matcher
}
fn build_simple_ac_table<I, S2>(
&mut self,
simple_match_type: SimpleMatchType,
simple_word_map: &HashMap<u64, I, S2>,
) -> SimpleAcTable
where
I: AsRef<str>,
{
let mut ac_wordlist = Vec::new();
let mut ac_word_conf_list = Vec::new();
for (&simple_word_id, simple_word) in simple_word_map {
let mut ac_split_word_counter = AHashMap::default();
for ac_split_word in simple_word.as_ref().split(',').filter(|&x| !x.is_empty()) {
ac_split_word_counter
.entry(ac_split_word)
.and_modify(|cnt| *cnt += 1)
.or_insert(1);
}
let split_bit_vec = ac_split_word_counter
.values()
.take(WORD_COMBINATION_LIMIT)
.map(|&x| 1 << (x.min(8) - 1))
.collect::<ArrayVec<[u8; WORD_COMBINATION_LIMIT]>>();
let split_bit = Simd::load_or_default(&split_bit_vec);
self.simple_wordconf_map.insert(
simple_word_id,
WordConf {
word: simple_word.as_ref().to_owned(),
split_bit,
},
);
for (offset, &split_word) in ac_split_word_counter
.keys()
.take(WORD_COMBINATION_LIMIT)
.enumerate()
{
for ac_word in reduce_text_process(simple_match_type, split_word) {
ac_wordlist.push(ac_word);
ac_word_conf_list.push((simple_word_id, offset));
}
}
}
SimpleAcTable {
ac_matcher: AhoCorasickBuilder::new()
.kind(Some(DFA))
.ascii_case_insensitive(true)
.build(
ac_wordlist
.iter()
.map(|ac_word| ac_word.as_ref().as_bytes()),
)
.unwrap(),
ac_word_conf_list,
}
}
}
impl<'a> TextMatcherTrait<'a, SimpleResult<'a>> for SimpleMatcher {
fn is_match(&self, text: &str) -> bool {
if text.is_empty() {
return false;
}
let mut word_id_split_bit_map = IntMap::default();
for (&simple_match_type, simple_ac_table) in &self.simple_match_type_ac_table_map {
let processed_text_list = reduce_text_process(simple_match_type, text);
let processed_times = processed_text_list.len();
for (index, processed_text) in processed_text_list.iter().enumerate() {
for ac_result in simple_ac_table
.ac_matcher
.find_overlapping_iter(processed_text.as_ref())
{
let ac_word_id = ac_result.pattern().as_usize();
let ac_word_conf =
unsafe { simple_ac_table.ac_word_conf_list.get_unchecked(ac_word_id) };
let word_id = ac_word_conf.0;
let word_conf =
unsafe { self.simple_wordconf_map.get(&word_id).unwrap_unchecked() };
let split_bit_vec = word_id_split_bit_map.entry(word_id).or_insert_with(|| {
iter::repeat_n(word_conf.split_bit, processed_times)
.collect::<ArrayVec<[_; 8]>>()
});
*unsafe {
split_bit_vec
.get_unchecked_mut(index)
.as_mut_array()
.get_unchecked_mut(ac_word_conf.1)
} >>= 1;
if split_bit_vec
.iter()
.fold(Simd::splat(1), |acc, &bit| acc & bit)
== ZEROS
{
return true;
}
}
}
}
false
}
fn process(&'a self, text: &str) -> Vec<SimpleResult<'a>> {
let mut result_list = Vec::new();
if text.is_empty() {
return result_list;
}
let mut word_id_set = IntSet::default();
let mut word_id_split_bit_map = IntMap::default();
for (&simple_match_type, simple_ac_table) in &self.simple_match_type_ac_table_map {
let processed_text_list = reduce_text_process(simple_match_type, text);
let processed_times = processed_text_list.len();
for (index, processed_text) in processed_text_list.iter().enumerate() {
for ac_result in simple_ac_table
.ac_matcher
.find_overlapping_iter(processed_text.as_ref())
{
let ac_word_conf = unsafe {
simple_ac_table
.ac_word_conf_list
.get_unchecked(ac_result.pattern().as_usize())
};
let word_id = ac_word_conf.0;
if word_id_set.contains(&word_id) {
continue;
}
let word_conf =
unsafe { self.simple_wordconf_map.get(&word_id).unwrap_unchecked() };
let split_bit_vec = word_id_split_bit_map.entry(word_id).or_insert_with(|| {
iter::repeat_n(word_conf.split_bit, processed_times)
.collect::<ArrayVec<[_; 8]>>()
});
*unsafe {
split_bit_vec
.get_unchecked_mut(index)
.as_mut_array()
.get_unchecked_mut(ac_word_conf.1)
} >>= 1;
if split_bit_vec
.iter()
.fold(Simd::splat(1), |acc, &bit| acc & bit)
== ZEROS
{
word_id_set.insert(word_id);
result_list.push(SimpleResult {
word_id,
word: Cow::Borrowed(&word_conf.word),
});
}
}
}
}
result_list
}
}