use std::borrow::Cow;
use std::cell::RefCell;
#[cfg(feature = "runtime_build")]
use std::collections::HashMap;
use std::collections::HashSet;
use std::fmt::Display;
use std::sync::OnceLock;
use tinyvec::TinyVec;
use bitflags::bitflags;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::process::constants::*;
use crate::process::multi_char_matcher::MultiCharMatcher;
use crate::process::single_char_matcher::{SingleCharMatch, SingleCharMatcher};
/// Capacity reserved up front for each thread's pool of reusable `String` buffers.
const STRING_POOL_INIT_CAP: usize = 16;
/// Most `String` buffers a thread keeps for reuse; buffers returned beyond this are dropped.
const STRING_POOL_MAX: usize = 128;
/// Capacity reserved up front for each thread's pool of reusable text/mask vectors.
const MASKS_POOL_INIT_CAP: usize = 4;
/// Most text/mask vectors a thread keeps for reuse; vectors returned beyond this are dropped.
const MASKS_POOL_MAX: usize = 16;
/// Per-thread scratch buffers reused across calls to `walk_process_tree`.
struct TransformThreadState {
    /// For each node of the process-type tree, the index into the text/mask list of the
    /// text variant that node produced.
    tree_node_indices: Vec<usize>,
    /// Emptied text/mask vectors kept for reuse (at most `MASKS_POOL_MAX`).
    masks_pool: Vec<ProcessedTextMasks<'static>>,
}
impl TransformThreadState {
fn new() -> Self {
Self {
tree_node_indices: Vec::with_capacity(16),
masks_pool: Vec::with_capacity(MASKS_POOL_INIT_CAP),
}
}
}
thread_local! {
    /// This thread's pool of reusable `String` buffers (see `get_string_from_pool`).
    static STRING_POOL: RefCell<Vec<String>> = RefCell::new(Vec::with_capacity(STRING_POOL_INIT_CAP));
}
thread_local! {
    /// This thread's scratch state for `walk_process_tree`.
    static TRANSFORM_STATE: RefCell<TransformThreadState> = RefCell::new(TransformThreadState::new());
}
/// Takes an empty `String` with capacity for at least `capacity` bytes from this thread's
/// pool, or allocates a fresh one when the pool is empty.
fn get_string_from_pool(capacity: usize) -> String {
    STRING_POOL.with(|pool| match pool.borrow_mut().pop() {
        Some(mut s) => {
            s.clear();
            // `reserve(n)` guarantees room for `len + n` bytes, and `len` is 0 after `clear`,
            // so reserving the full `capacity` is what guarantees `s.capacity() >= capacity`.
            // It is a no-op when the recycled buffer is already large enough.
            s.reserve(capacity);
            s
        }
        None => String::with_capacity(capacity),
    })
}
/// Hands `s` back to this thread's string pool for reuse.
///
/// The buffer is simply dropped once the pool already holds `STRING_POOL_MAX` strings.
fn return_string_to_pool(s: String) {
    STRING_POOL.with_borrow_mut(|pool| {
        if pool.len() < STRING_POOL_MAX {
            pool.push(s);
        }
    });
}
/// Recycles `text_masks`: every owned text buffer goes back to the string pool, and the
/// emptied vector goes into this thread's masks pool (kept only while it holds fewer than
/// `MASKS_POOL_MAX` vectors).
pub(crate) fn return_processed_string_to_pool(mut text_masks: ProcessedTextMasks) {
    // Borrowed variants are simply dropped; only owned buffers are worth recycling.
    for (cow, _) in text_masks.drain(..) {
        if let Cow::Owned(s) = cow {
            return_string_to_pool(s);
        }
    }
    // SAFETY: `drain(..)` above removed every element, so the vector holds no borrows tied to
    // the input lifetime. `ProcessedTextMasks<'_>` and `ProcessedTextMasks<'static>` differ
    // only in that lifetime and therefore have identical layout; only the empty allocation is
    // kept for reuse.
    let empty: ProcessedTextMasks<'static> = unsafe { std::mem::transmute(text_masks) };
    TRANSFORM_STATE.with(|state| {
        let mut state = state.borrow_mut();
        // Cap the pool so a burst of calls cannot retain an unbounded number of vectors.
        if state.masks_pool.len() < MASKS_POOL_MAX {
            state.masks_pool.push(empty);
        }
    });
}
bitflags! {
    /// Set of text transformations, stored as a `u8` bit set.
    ///
    /// Each single-bit flag names one transformation step. `DeleteNormalize` and
    /// `FanjianDeleteNormalize` are convenience unions of single-bit flags.
    #[derive(Hash, PartialEq, Eq, Clone, Copy, Debug, Default)]
    pub struct ProcessType: u8 {
        /// Identity step: `text_process` skips it, and the root of the process-type tree
        /// stands for it.
        const None = 0b00000001;
        /// Single-character mapping built from the `FANJIAN` table (presumably traditional ->
        /// simplified Chinese, going by the name; confirm).
        const Fanjian = 0b00000010;
        /// Removes characters; the runtime-build path builds it from `TEXT_DELETE` and
        /// `WHITE_SPACE`.
        const Delete = 0b00000100;
        /// Multi-character replacements; the runtime-build path builds it from the `NORM` and
        /// `NUM_NORM` tables.
        const Normalize = 0b00001000;
        /// `Delete | Normalize`.
        const DeleteNormalize = 0b00001100;
        /// `Fanjian | Delete | Normalize`.
        const FanjianDeleteNormalize = 0b00001110;
        /// Replaces characters with strings from the `PINYIN` table.
        const PinYin = 0b00010000;
        /// Same table as `PinYin`, but the matcher is built with its boolean flag set to `true`
        /// (exact effect is defined by `SingleCharMatcher`; TODO confirm).
        const PinYinChar = 0b00100000;
    }
}
impl Serialize for ProcessType {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
self.bits().serialize(serializer)
}
}
impl<'de> Deserialize<'de> for ProcessType {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let bits: u8 = u8::deserialize(deserializer)?;
Ok(ProcessType::from_bits_retain(bits))
}
}
impl Display for ProcessType {
    /// Writes the contained named flags, lowercased and joined with `_`
    /// (for example `Fanjian | Delete` renders as `fanjian_delete`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut rendered = String::new();
        for (position, (name, _)) in self.iter_names().enumerate() {
            if position > 0 {
                rendered.push('_');
            }
            rendered.push_str(&name.to_lowercase());
        }
        f.write_str(&rendered)
    }
}
/// Lazily built matchers, one slot per bit of the `u8` [`ProcessType`], indexed by the
/// position of that bit.
static PROCESS_MATCHER_CACHE: [OnceLock<ProcessMatcher>; 8] = [const { OnceLock::new() }; 8];
/// Distinct text variants produced while processing one input, each paired with a `u64` mask
/// that has bit `p.bits()` set for every process type `p` under which the variant applies.
pub type ProcessedTextMasks<'a> = Vec<(Cow<'a, str>, u64)>;
/// Transformation engine for one single-bit [`ProcessType`].
#[derive(Clone)]
pub(crate) enum ProcessMatcher {
    /// Multi-character find/replace (`Normalize`; also the empty matcher used for `None`).
    MultiChar(MultiCharMatcher),
    /// Per-character mapping or deletion (`Fanjian`, `PinYin`/`PinYinChar`, `Delete`).
    SingleChar(SingleCharMatcher),
}
impl ProcessMatcher {
    /// Rebuilds `text`, splicing in a replacement for every match yielded by `iter`.
    ///
    /// `iter` yields `(start, end, payload)` byte ranges into `text`. The slicing below assumes
    /// they arrive in ascending, non-overlapping order (TODO confirm the matcher iterators
    /// guarantee this). `push_replacement` appends the replacement for one payload.
    ///
    /// Returns `(false, Borrowed(text))` when there is no match. Otherwise it returns
    /// `(true, Owned(..))`, built in a buffer taken from the thread-local string pool.
    #[inline(always)]
    fn replace_scan<'a, I, M, F>(
        text: &'a str,
        mut iter: I,
        mut push_replacement: F,
    ) -> (bool, Cow<'a, str>)
    where
        I: Iterator<Item = (usize, usize, M)>,
        F: FnMut(&mut String, M),
    {
        let Some(first) = iter.next() else {
            return (false, Cow::Borrowed(text));
        };
        let mut out = get_string_from_pool(text.len());
        let mut copied_up_to = 0;
        for (start, end, payload) in std::iter::once(first).chain(iter) {
            out.push_str(&text[copied_up_to..start]);
            push_replacement(&mut out, payload);
            copied_up_to = end;
        }
        out.push_str(&text[copied_up_to..]);
        (true, Cow::Owned(out))
    }

    /// Applies this matcher's replacements to `text`.
    ///
    /// The flag is `true` exactly when an owned, rewritten string is returned. Must not be
    /// called on a `Delete` matcher; use [`ProcessMatcher::delete_all`] for that.
    #[inline(always)]
    pub(crate) fn replace_all<'a>(&self, text: &'a str) -> (bool, Cow<'a, str>) {
        match self {
            ProcessMatcher::MultiChar(mc) => {
                let replacements = mc.replace_list();
                Self::replace_scan(text, mc.find_iter(text), |out, idx| {
                    out.push_str(replacements[idx]);
                })
            }
            ProcessMatcher::SingleChar(matcher @ SingleCharMatcher::Fanjian { .. }) => {
                Self::replace_scan(text, matcher.fanjian_iter(text), |out, m| {
                    if let SingleCharMatch::Char(c) = m {
                        out.push(c);
                    }
                })
            }
            ProcessMatcher::SingleChar(matcher @ SingleCharMatcher::Pinyin { .. }) => {
                Self::replace_scan(text, matcher.pinyin_iter(text), |out, m| {
                    if let SingleCharMatch::Str(s) = m {
                        out.push_str(s);
                    }
                })
            }
            ProcessMatcher::SingleChar(SingleCharMatcher::Delete { .. }) => {
                debug_assert!(false, "replace_all called on Delete matcher");
                (false, Cow::Borrowed(text))
            }
        }
    }

    /// Removes every character matched by this (single-char `Delete`) matcher from `text`.
    ///
    /// The flag is `true` exactly when an owned, shortened string is returned.
    #[inline(always)]
    pub(crate) fn delete_all<'a>(&self, text: &'a str) -> (bool, Cow<'a, str>) {
        match self {
            ProcessMatcher::SingleChar(matcher) => {
                Self::replace_scan(text, matcher.delete_iter(text), |_, _| {})
            }
            ProcessMatcher::MultiChar(_) => {
                debug_assert!(false, "delete_all called on non-Delete matcher");
                (false, Cow::Borrowed(text))
            }
        }
    }
}
/// Returns the lazily built, process-wide cached matcher for one single-bit [`ProcessType`].
///
/// The cache slot is the position of the flag's set bit. Each matcher is therefore built at
/// most once (on first request), either from the raw tables (`runtime_build` feature) or
/// from the precompiled byte tables.
///
/// # Panics
///
/// Panics if `process_type_bit` does not have exactly one bit set. Without this check:
/// - An empty set would index past the end of the cache.
/// - A multi-bit value would be built into, and cached in, the slot of its lowest bit. Under
///   `runtime_build` that stores an empty matcher which every later lookup of the real
///   single-bit type would then receive.
///
/// Without `runtime_build`, an undefined single bit also panics (`unreachable!`).
pub fn get_process_matcher(process_type_bit: ProcessType) -> &'static ProcessMatcher {
    assert!(
        process_type_bit.bits().is_power_of_two(),
        "get_process_matcher expects exactly one ProcessType bit, got {process_type_bit:?}"
    );
    // A single bit of a `u8` is at position 0..=7, which always fits the 8-slot cache.
    let index = process_type_bit.bits().trailing_zeros() as usize;
    PROCESS_MATCHER_CACHE[index].get_or_init(|| {
        #[cfg(feature = "runtime_build")]
        {
            match process_type_bit {
                ProcessType::Fanjian => {
                    let mut map = HashMap::new();
                    for line in FANJIAN.trim().lines() {
                        let mut split = line.split('\t');
                        let k = split.next().unwrap().chars().next().unwrap() as u32;
                        let v = split.next().unwrap().chars().next().unwrap() as u32;
                        // Identity mappings are dropped; they would only force a rewrite.
                        if k != v {
                            map.insert(k, v);
                        }
                    }
                    ProcessMatcher::SingleChar(SingleCharMatcher::fanjian_from_map(map))
                }
                ProcessType::PinYin | ProcessType::PinYinChar => {
                    let mut map = HashMap::new();
                    for line in PINYIN.trim().lines() {
                        let mut split = line.split('\t');
                        let k = split.next().unwrap().chars().next().unwrap() as u32;
                        let v = split.next().unwrap();
                        map.insert(k, v);
                    }
                    ProcessMatcher::SingleChar(SingleCharMatcher::pinyin_from_map(
                        map,
                        process_type_bit == ProcessType::PinYinChar,
                    ))
                }
                ProcessType::Delete => ProcessMatcher::SingleChar(
                    SingleCharMatcher::delete_from_sources(TEXT_DELETE, WHITE_SPACE),
                ),
                ProcessType::Normalize => {
                    let mut process_dict: HashMap<&'static str, &'static str> = HashMap::new();
                    for process_map in [NORM, NUM_NORM] {
                        process_dict.extend(process_map.trim().lines().map(|pair_str| {
                            let mut split = pair_str.split('\t');
                            (split.next().unwrap(), split.next().unwrap())
                        }));
                    }
                    process_dict.retain(|&key, &mut value| key != value);
                    ProcessMatcher::MultiChar(MultiCharMatcher::new_from_dict(process_dict))
                }
                // `None` (and any undefined single bit) gets a matcher that never matches.
                _ => ProcessMatcher::MultiChar(MultiCharMatcher::new_empty()),
            }
        }
        #[cfg(not(feature = "runtime_build"))]
        {
            match process_type_bit {
                ProcessType::None => ProcessMatcher::MultiChar(MultiCharMatcher::new_empty()),
                ProcessType::Fanjian => ProcessMatcher::SingleChar(SingleCharMatcher::fanjian(
                    Cow::Borrowed(FANJIAN_L1_BYTES),
                    Cow::Borrowed(FANJIAN_L2_BYTES),
                )),
                ProcessType::Delete => ProcessMatcher::SingleChar(SingleCharMatcher::delete(
                    Cow::Borrowed(DELETE_BITSET_BYTES),
                )),
                ProcessType::Normalize => {
                    #[cfg(feature = "dfa")]
                    {
                        ProcessMatcher::MultiChar(
                            MultiCharMatcher::new(NORMALIZE_PROCESS_LIST_STR.lines())
                                .with_replace_list(
                                    NORMALIZE_PROCESS_REPLACE_LIST_STR.lines().collect(),
                                ),
                        )
                    }
                    #[cfg(not(feature = "dfa"))]
                    {
                        ProcessMatcher::MultiChar(
                            MultiCharMatcher::deserialize_from(NORMALIZE_PROCESS_MATCHER_BYTES)
                                .with_replace_list(
                                    NORMALIZE_PROCESS_REPLACE_LIST_STR.lines().collect(),
                                ),
                        )
                    }
                }
                ProcessType::PinYin => ProcessMatcher::SingleChar(SingleCharMatcher::pinyin(
                    Cow::Borrowed(PINYIN_L1_BYTES),
                    Cow::Borrowed(PINYIN_L2_BYTES),
                    Cow::Borrowed(PINYIN_STR_BYTES),
                    false,
                )),
                ProcessType::PinYinChar => ProcessMatcher::SingleChar(SingleCharMatcher::pinyin(
                    Cow::Borrowed(PINYIN_L1_BYTES),
                    Cow::Borrowed(PINYIN_L2_BYTES),
                    Cow::Borrowed(PINYIN_STR_BYTES),
                    true,
                )),
                _ => unreachable!(),
            }
        }
    })
}
/// Applies every transformation step in `process_type` to `text`, in flag order, and returns
/// only the final result.
///
/// Borrows `text` unchanged when no step modifies it. An owned intermediate buffer that a
/// later step supersedes is recycled into the thread-local string pool.
#[inline(always)]
pub fn text_process<'a>(process_type: ProcessType, text: &'a str) -> Cow<'a, str> {
    let mut current: Cow<'a, str> = Cow::Borrowed(text);
    for step in process_type.iter() {
        let matcher = get_process_matcher(step);
        let rewritten = match step {
            ProcessType::None => continue,
            ProcessType::Delete => match matcher.delete_all(&current) {
                (true, Cow::Owned(s)) => s,
                _ => continue,
            },
            _ => match matcher.replace_all(&current) {
                (true, Cow::Owned(s)) => s,
                _ => continue,
            },
        };
        if let Cow::Owned(superseded) = std::mem::replace(&mut current, Cow::Owned(rewritten)) {
            return_string_to_pool(superseded);
        }
    }
    current
}
/// Applies the steps of `process_type` in order and returns every distinct stage.
///
/// The list starts with `text` itself. Each step that changes the latest stage appends its
/// result, so the last element is the fully processed text.
#[inline(always)]
pub fn reduce_text_process<'a>(process_type: ProcessType, text: &'a str) -> Vec<Cow<'a, str>> {
    let mut variants: Vec<Cow<'a, str>> = vec![Cow::Borrowed(text)];
    for step in process_type.iter() {
        let matcher = get_process_matcher(step);
        let latest: &str = variants
            .last()
            .expect("It should always have at least one element");
        let appended: Option<String> = match step {
            ProcessType::None => continue,
            ProcessType::Delete => match matcher.delete_all(latest) {
                (true, Cow::Owned(s)) => Some(s),
                _ => None,
            },
            _ => match matcher.replace_all(latest) {
                (true, Cow::Owned(s)) => Some(s),
                _ => None,
            },
        };
        if let Some(next) = appended {
            variants.push(Cow::Owned(next));
        }
    }
    variants
}
#[inline(always)]
pub fn reduce_text_process_emit<'a>(process_type: ProcessType, text: &'a str) -> Vec<Cow<'a, str>> {
let mut text_list: Vec<Cow<'a, str>> = Vec::new();
text_list.push(Cow::Borrowed(text));
for process_type_bit in process_type.iter() {
let pm = get_process_matcher(process_type_bit);
let current_text = text_list
.last_mut()
.expect("It should always have at least one element");
match process_type_bit {
ProcessType::None => continue,
ProcessType::Delete => {
if let (true, Cow::Owned(processed)) = pm.delete_all(current_text.as_ref()) {
text_list.push(Cow::Owned(processed));
}
}
_ => {
if let (true, Cow::Owned(processed)) = pm.replace_all(current_text.as_ref()) {
*current_text = Cow::Owned(processed);
}
}
}
}
text_list
}
/// One node of the process-type tree built by `build_process_type_tree`.
///
/// A node stands for applying `process_type_bit` to its parent's text. The path from the root
/// spells out the single-bit steps of the process types listed in `process_type_list`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ProcessTypeBitNode {
    /// Every process type whose step sequence passes through this node.
    process_type_list: Vec<ProcessType>,
    /// The single step this node applies (`ProcessType::None` for the root).
    process_type_bit: ProcessType,
    /// Indices (into the tree vector) of this node's children.
    children: Vec<usize>,
    /// Bit `p.bits()` is set for every process type `p` in `process_type_list`.
    folded_mask: u64,
}
/// Builds a prefix tree over the single-bit steps of every process type in `process_type_set`.
///
/// How the tree is built:
/// - Node 0 is the root: the untransformed text, representing `ProcessType::None`.
/// - Each process type is walked bit by bit. An existing child that applies the same bit is
///   reused, otherwise a new node is appended, so shared step prefixes are computed only once.
/// - Every node on a type's path records that type in `process_type_list` and sets bit
///   `process_type.bits()` in its `folded_mask`.
/// - Children always receive larger indices than their parent, so visiting the returned vector
///   in order visits every parent before its children.
///
/// # Panics
///
/// Panics if a process type's raw bits value is 64 or more. Such a value cannot be given its
/// own bit in the 64-bit `folded_mask`; an unchecked shift would overflow, and in release
/// builds would silently alias another type's bit.
pub fn build_process_type_tree(process_type_set: &HashSet<ProcessType>) -> Vec<ProcessTypeBitNode> {
    /// Mask bit identifying `process_type` in `ProcessTypeBitNode::folded_mask`.
    fn mask_bit(process_type: ProcessType) -> u64 {
        1u64.checked_shl(u32::from(process_type.bits()))
            .unwrap_or_else(|| {
                panic!(
                    "ProcessType {process_type:?} (bits {}) does not fit in a 64-bit folded mask",
                    process_type.bits()
                )
            })
    }

    let mut root = ProcessTypeBitNode {
        process_type_list: Vec::new(),
        process_type_bit: ProcessType::None,
        children: Vec::new(),
        folded_mask: 0,
    };
    if process_type_set.contains(&ProcessType::None) {
        root.process_type_list.push(ProcessType::None);
        root.folded_mask |= mask_bit(ProcessType::None);
    }
    let mut process_type_tree = vec![root];

    for &process_type in process_type_set.iter() {
        let type_mask = mask_bit(process_type);
        let mut current_node_index = 0;
        for process_type_bit in process_type.iter() {
            let current_node = &process_type_tree[current_node_index];
            // The root already stands for `ProcessType::None`, so a leading `None` step is
            // absorbed into it instead of creating a child.
            if current_node.process_type_bit == process_type_bit {
                continue;
            }
            let found_child = current_node
                .children
                .iter()
                .copied()
                .find(|&idx| process_type_tree[idx].process_type_bit == process_type_bit);
            current_node_index = match found_child {
                Some(child_idx) => {
                    let child = &mut process_type_tree[child_idx];
                    child.process_type_list.push(process_type);
                    child.folded_mask |= type_mask;
                    child_idx
                }
                None => {
                    process_type_tree.push(ProcessTypeBitNode {
                        process_type_list: vec![process_type],
                        process_type_bit,
                        children: Vec::new(),
                        folded_mask: type_mask,
                    });
                    let new_node_index = process_type_tree.len() - 1;
                    process_type_tree[current_node_index]
                        .children
                        .push(new_node_index);
                    new_node_index
                }
            };
        }
    }
    process_type_tree
}
/// Adds `changed` to `text_masks` unless an identical text is already present.
///
/// Returns the index the caller should use:
/// - `current_index` when nothing changed (`None`).
/// - The index of an existing identical entry; the redundant buffer is recycled into the
///   string pool.
/// - Otherwise the index of the newly appended entry, which starts with an empty mask.
#[inline(always)]
fn dedup_insert(
    text_masks: &mut ProcessedTextMasks<'_>,
    current_index: usize,
    changed: Option<String>,
) -> usize {
    let Some(processed) = changed else {
        return current_index;
    };
    let existing = text_masks
        .iter()
        .position(|(variant, _)| variant.as_ref() == processed.as_str());
    match existing {
        Some(pos) => {
            return_string_to_pool(processed);
            pos
        }
        None => {
            text_masks.push((Cow::Owned(processed), 0u64));
            text_masks.len() - 1
        }
    }
}
/// Applies every step of `process_type_tree` to `text`, deduplicating identical results.
///
/// Returns the distinct text variants, plus whether `on_variant` asked to stop early. Each
/// variant's mask is the OR of the `folded_mask`s of the tree nodes that produced it (the
/// root's mask for `text` itself).
///
/// When `LAZY` is true, `on_variant(variant, index, mask_bits)` is called:
/// - as soon as a variant is first produced, with its mask so far (skipped when that mask is
///   zero);
/// - once more at the end, with only the mask bits added afterwards.
///
/// Returning `true` from `on_variant` stops the walk immediately and yields `(masks, true)`.
/// When `LAZY` is false, `on_variant` is never called and the flag is always `false`.
///
/// The returned vector is taken from the per-thread masks pool;
/// `return_processed_string_to_pool` puts such vectors back.
#[inline(always)]
pub fn walk_process_tree<'a, const LAZY: bool, F>(
    process_type_tree: &[ProcessTypeBitNode],
    text: &'a str,
    on_variant: &mut F,
) -> (ProcessedTextMasks<'a>, bool)
where
    F: FnMut(&str, usize, u64) -> bool,
{
    // Move the reusable buffers out of the thread-local state up front so no `RefCell` borrow
    // is held while `on_variant` runs. A callback that re-enters this function (or recycles
    // masks) would otherwise panic with "already borrowed".
    let (pooled, mut node_text_indices) = TRANSFORM_STATE.with(|state| {
        let mut ts = state.borrow_mut();
        (ts.masks_pool.pop(), std::mem::take(&mut ts.tree_node_indices))
    });
    // `Vec` and `Cow` are covariant in their lifetime, so the pooled `'static` (and empty)
    // vector coerces to `'a` without any `unsafe`.
    let mut text_masks: ProcessedTextMasks<'a> = pooled.unwrap_or_default();
    text_masks.clear();
    text_masks.push((Cow::Borrowed(text), process_type_tree[0].folded_mask));

    let stopped = 'scan: {
        // In LAZY mode, the mask bits of each variant already reported to `on_variant`.
        let mut scanned_masks: TinyVec<[u64; 8]> = TinyVec::new();
        if LAZY {
            scanned_masks.push(0u64);
            let root_mask = process_type_tree[0].folded_mask;
            if root_mask != 0 && on_variant(text, 0, root_mask) {
                break 'scan true;
            }
            scanned_masks[0] = root_mask;
        }
        if process_type_tree[0].children.is_empty() {
            break 'scan false;
        }
        // Maps each tree node to the index of the variant it produced. Parents precede their
        // children in the tree vector, so a parent's entry is filled before it is read.
        node_text_indices.clear();
        node_text_indices.resize(process_type_tree.len(), 0);
        for (current_node_index, current_node) in process_type_tree.iter().enumerate() {
            let current_index = node_text_indices[current_node_index];
            for &child_node_index in &current_node.children {
                let child_node = &process_type_tree[child_node_index];
                let pm = get_process_matcher(child_node.process_type_bit);
                let changed = match child_node.process_type_bit {
                    ProcessType::None => None,
                    ProcessType::Delete => {
                        let current_text = text_masks[current_index].0.as_ref();
                        match pm.delete_all(current_text) {
                            (true, Cow::Owned(processed)) => Some(processed),
                            _ => None,
                        }
                    }
                    _ => {
                        let current_text = text_masks[current_index].0.as_ref();
                        match pm.replace_all(current_text) {
                            (true, Cow::Owned(processed)) => Some(processed),
                            _ => None,
                        }
                    }
                };
                let old_len = if LAZY { text_masks.len() } else { 0 };
                let child_index = dedup_insert(&mut text_masks, current_index, changed);
                if LAZY {
                    while scanned_masks.len() < text_masks.len() {
                        scanned_masks.push(0u64);
                    }
                }
                node_text_indices[child_node_index] = child_index;
                text_masks[child_index].1 |= child_node.folded_mask;
                // Only brand-new variants are reported eagerly; bits added to existing ones
                // are reported as deltas after the walk.
                if LAZY && child_index >= old_len {
                    let mask = text_masks[child_index].1;
                    if mask != 0
                        && on_variant(text_masks[child_index].0.as_ref(), child_index, mask)
                    {
                        break 'scan true;
                    }
                    scanned_masks[child_index] = mask;
                }
            }
        }
        if LAZY {
            for i in 0..text_masks.len() {
                let delta = text_masks[i].1 & !scanned_masks[i];
                if delta != 0 && on_variant(text_masks[i].0.as_ref(), i, delta) {
                    break 'scan true;
                }
            }
        }
        false
    };

    TRANSFORM_STATE.with(|state| state.borrow_mut().tree_node_indices = node_text_indices);
    (text_masks, stopped)
}