use std::borrow::Cow;
use std::collections::HashSet;
use tinyvec::TinyVec;
use crate::process::process_matcher::{ProcessMatcher, get_process_matcher};
use crate::process::process_type::ProcessType;
use crate::process::string_pool::{
ProcessedTextMasks, TRANSFORM_STATE, TextVariant, return_string_to_pool,
};
/// One node of the flat, index-linked transformation trie produced by
/// [`build_process_type_tree`].
///
/// Nodes live in a single `Vec`; parent/child links are indices into that
/// vector rather than pointers.
#[derive(Clone)]
pub struct ProcessTypeBitNode {
    /// Composite process types whose step sequence passes through this node.
    process_type_list: Vec<ProcessType>,
    /// The single transformation step this node stands for
    /// (`ProcessType::None` for the root).
    pub(crate) process_type_bit: ProcessType,
    /// Indices (into the tree vector) of this node's children.
    pub(crate) children: Vec<usize>,
    /// Matcher cached for `process_type_bit`; the root is built with `None`.
    pub(crate) matcher: Option<&'static ProcessMatcher>,
    /// OR of one bit per entry of `process_type_list`.
    pub(crate) folded_mask: u64,
}

impl ProcessTypeBitNode {
    /// Rebuilds `folded_mask` from `process_type_list`, mapping each process
    /// type's raw bits to a bit position through `pt_index_table`.
    ///
    /// The previous mask is replaced, not extended.
    ///
    /// # Panics
    ///
    /// Panics if a process type's raw bits fall outside the 64-entry table.
    pub(crate) fn recompute_mask_with_index(&mut self, pt_index_table: &[u8; 64]) {
        let mut mask = 0u64;
        for pt in &self.process_type_list {
            let bit_position = pt_index_table[pt.bits() as usize];
            mask |= 1u64 << bit_position;
        }
        self.folded_mask = mask;
    }
}
/// Builds the flat transformation trie for a set of composite process types.
///
/// Index `0` of the returned vector is always the root, whose step is
/// `ProcessType::None`. Each composite type is decomposed into its individual
/// steps (via `ProcessType::iter`); walking those steps from the root either
/// reuses an existing child for the step or appends a new node. Every node
/// visited on the way records the composite type in its `process_type_list`
/// and sets bit `1 << composite.bits()` in its `folded_mask`.
///
/// Because children are always appended after their parent, a child's index
/// is strictly greater than its parent's index.
pub fn build_process_type_tree(process_type_set: &HashSet<ProcessType>) -> Vec<ProcessTypeBitNode> {
    // The root only carries a process type when the identity transform was requested.
    let root_has_none = process_type_set.contains(&ProcessType::None);
    let root = ProcessTypeBitNode {
        process_type_list: if root_has_none {
            vec![ProcessType::None]
        } else {
            Vec::new()
        },
        process_type_bit: ProcessType::None,
        children: Vec::new(),
        matcher: None,
        folded_mask: if root_has_none {
            1u64 << ProcessType::None.bits()
        } else {
            0
        },
    };
    let mut nodes = vec![root];

    for &composite in process_type_set {
        let mut cursor = 0usize;
        for step in composite.iter() {
            // A step equal to the node we are standing on needs no edge.
            if nodes[cursor].process_type_bit == step {
                continue;
            }

            let existing_child = nodes[cursor]
                .children
                .iter()
                .copied()
                .find(|&idx| nodes[idx].process_type_bit == step);

            cursor = match existing_child {
                Some(idx) => {
                    // Shared prefix: tag the existing node with this composite type.
                    let node = &mut nodes[idx];
                    node.process_type_list.push(composite);
                    node.folded_mask |= 1u64 << composite.bits();
                    idx
                }
                None => {
                    // First time this step appears under `cursor`: append a new node.
                    let fresh = ProcessTypeBitNode {
                        process_type_list: vec![composite],
                        process_type_bit: step,
                        children: Vec::new(),
                        matcher: Some(get_process_matcher(step)),
                        folded_mask: 1u64 << composite.bits(),
                    };
                    let idx = nodes.len();
                    nodes.push(fresh);
                    nodes[cursor].children.push(idx);
                    idx
                }
            };
        }
    }

    nodes
}
/// Resolves the variant index for a transformation result, deduplicating by
/// text content.
///
/// * `changed == None` (the step left the text untouched): returns
///   `current_index`, i.e. the parent's variant.
/// * `changed == Some(s)` and an existing variant already holds the same text:
///   hands `s` back via `return_string_to_pool` and returns that variant's
///   index. The existing variant's `is_ascii` flag is left as it was.
/// * Otherwise: appends a new variant owning `s` with an empty mask and the
///   supplied `is_ascii` flag, and returns its index.
fn dedup_insert(
    text_masks: &mut ProcessedTextMasks<'_>,
    current_index: usize,
    changed: Option<String>,
    is_ascii: bool,
) -> usize {
    let candidate = match changed {
        Some(s) => s,
        None => return current_index,
    };

    let duplicate_at = text_masks.iter().position(|variant| {
        variant.text.len() == candidate.len() && variant.text.as_ref() == candidate.as_str()
    });

    if let Some(slot) = duplicate_at {
        return_string_to_pool(candidate);
        return slot;
    }

    let slot = text_masks.len();
    text_masks.push(TextVariant {
        text: Cow::Owned(candidate),
        mask: 0u64,
        is_ascii,
    });
    slot
}
/// Walks the transformation trie over `text`, producing every distinct text
/// variant together with the bitmask of process types that apply to it.
///
/// Variant `0` is always the borrowed input text. For each tree node, every
/// child step is applied to the parent's variant; the result is deduplicated
/// through `dedup_insert`, and the child's `folded_mask` is OR-ed into the
/// resolved variant's mask.
///
/// # Type parameters
///
/// * `LAZY` – when `true`, `on_variant` is invoked as variants are discovered
///   and the walk stops as soon as it returns `true`. When `false`,
///   `on_variant` is never called and all variants are always produced.
///
/// # Parameters
///
/// * `process_type_tree` – trie in the layout produced by
///   `build_process_type_tree`: root at index 0, parents before children.
/// * `text` – the input to transform.
/// * `on_variant` – callback `(variant_text, variant_index, mask, is_ascii)`;
///   returning `true` requests early termination. `mask` contains only bits
///   not previously reported for that variant.
///
/// # Returns
///
/// `(variants, stopped)`, where `stopped` is `true` iff `on_variant` returned
/// `true`. On early stop the variant list may be incomplete.
///
/// # Panics
///
/// Panics if `process_type_tree` is empty, or if a non-root node has no
/// cached matcher.
#[inline(always)]
pub fn walk_process_tree<'a, const LAZY: bool, F>(
    process_type_tree: &[ProcessTypeBitNode],
    text: &'a str,
    on_variant: &mut F,
) -> (ProcessedTextMasks<'a>, bool)
where
    F: FnMut(&str, usize, u64, bool) -> bool,
{
    {
        // NOTE(review): this borrow is held for the whole walk, including the
        // `on_variant` callbacks and `dedup_insert` — confirm neither path
        // re-borrows TRANSFORM_STATE in a conflicting way.
        let mut ts = TRANSFORM_STATE.borrow_mut();
        let pooled: Option<ProcessedTextMasks<'static>> = ts.masks_pool.pop();
        // SAFETY: source and target are the same type and differ only in the
        // lifetime parameter, which is shortened from 'static to 'a. The
        // buffer is cleared immediately below before any element is read.
        let mut text_masks: ProcessedTextMasks<'a> =
            unsafe { std::mem::transmute(pooled.unwrap_or_default()) };
        text_masks.clear();

        // Variant 0 is the untouched input.
        let root_is_ascii = text.is_ascii();
        text_masks.push(TextVariant {
            text: Cow::Borrowed(text),
            mask: process_type_tree[0].folded_mask,
            is_ascii: root_is_ascii,
        });

        // Per-variant record of mask bits already reported to `on_variant`
        // (only maintained in LAZY mode).
        let mut scanned_masks: TinyVec<[u64; 8]> = TinyVec::new();
        if LAZY {
            scanned_masks.push(0u64);
            let root_mask = process_type_tree[0].folded_mask;
            if root_mask != 0 && on_variant(text, 0, root_mask, root_is_ascii) {
                return (text_masks, true);
            }
            scanned_masks[0] = root_mask;
        }

        if process_type_tree[0].children.is_empty() {
            return (text_masks, false);
        }

        // tree_node_indices[node] = index of the variant that node resolved to.
        // All entries start at 0, so the root maps to variant 0.
        ts.tree_node_indices.clear();
        ts.tree_node_indices.resize(process_type_tree.len(), 0);

        let mut stopped = false;
        // Iterating in index order relies on parents preceding their children,
        // so a node's variant index is already set when the node is reached.
        'walk: for (current_node_index, current_node) in process_type_tree.iter().enumerate() {
            let current_index = ts.tree_node_indices[current_node_index];
            let parent_is_ascii = text_masks[current_index].is_ascii;

            for &child_node_index in &current_node.children {
                let child_node = &process_type_tree[child_node_index];
                let pm = child_node
                    .matcher
                    .expect("non-root process tree nodes always cache a matcher");

                // `changed` is None when the step left the text untouched;
                // in that case the child inherits the parent's ASCII flag.
                let (changed, child_is_ascii) = match child_node.process_type_bit {
                    ProcessType::None => (None, parent_is_ascii),
                    ProcessType::PinYin | ProcessType::PinYinChar => {
                        let current_text = text_masks[current_index].text.as_ref();
                        if let Some(processed) = pm.replace_all(current_text) {
                            // NOTE(review): changed output is flagged ASCII
                            // without re-scanning. This presumes no non-ASCII
                            // character survives the pinyin replacement —
                            // confirm for mixed-script input.
                            (Some(processed), true)
                        } else {
                            (None, parent_is_ascii)
                        }
                    }
                    ProcessType::Fanjian => {
                        let current_text = text_masks[current_index].text.as_ref();
                        if let Some(processed) = pm.replace_all(current_text) {
                            // Changed output is flagged non-ASCII without
                            // re-scanning (presumably a replacement implies
                            // CJK output — TODO confirm).
                            (Some(processed), false)
                        } else {
                            (None, parent_is_ascii)
                        }
                    }
                    ProcessType::Delete => {
                        let current_text = text_masks[current_index].text.as_ref();
                        if let Some(processed) = pm.delete_all(current_text) {
                            // Short-circuit: the scan is skipped when the
                            // parent was already ASCII.
                            let ia = parent_is_ascii || processed.is_ascii();
                            (Some(processed), ia)
                        } else {
                            (None, parent_is_ascii)
                        }
                    }
                    _ => {
                        let current_text = text_masks[current_index].text.as_ref();
                        if let Some(processed) = pm.replace_all(current_text) {
                            let ia = processed.is_ascii();
                            (Some(processed), ia)
                        } else {
                            (None, parent_is_ascii)
                        }
                    }
                };

                // Remember the length before insertion so a freshly appended
                // variant can be told apart from a deduplicated one.
                let old_len = if LAZY { text_masks.len() } else { 0 };
                let child_index =
                    dedup_insert(&mut text_masks, current_index, changed, child_is_ascii);
                if LAZY {
                    while scanned_masks.len() < text_masks.len() {
                        scanned_masks.push(0u64);
                    }
                }

                ts.tree_node_indices[child_node_index] = child_index;
                text_masks[child_index].mask |= child_node.folded_mask;

                // Newly created variants are reported immediately; bits added
                // later to existing variants are picked up by the final pass.
                if LAZY && child_index >= old_len {
                    let mask = text_masks[child_index].mask;
                    let is_ascii = text_masks[child_index].is_ascii;
                    if mask != 0
                        && on_variant(
                            text_masks[child_index].text.as_ref(),
                            child_index,
                            mask,
                            is_ascii,
                        )
                    {
                        stopped = true;
                        break 'walk;
                    }
                    scanned_masks[child_index] = mask;
                }
            }
        }

        if LAZY {
            if stopped {
                return (text_masks, true);
            }
            // Final pass: report mask bits that were merged into a variant
            // after it had already been scanned.
            for i in 0..text_masks.len() {
                let delta = text_masks[i].mask & !scanned_masks[i];
                if delta != 0
                    && on_variant(
                        text_masks[i].text.as_ref(),
                        i,
                        delta,
                        text_masks[i].is_ascii,
                    )
                {
                    return (text_masks, true);
                }
            }
        }

        (text_masks, false)
    }
}