use std::borrow::Cow;
#[cfg(feature = "runtime_build")]
use std::collections::HashMap;
#[cfg(feature = "runtime_build")]
use std::collections::HashSet;
use std::simd::{Simd, cmp::SimdPartialOrd};
use crate::process::simd_utils::{
simd_ascii_delete_mask, skip_ascii_simd, skip_non_digit_ascii_simd,
};
/// A table-driven matcher that maps or removes individual characters.
///
/// All tables are held as `Cow<'static, _>`, so a matcher can wrap either
/// data embedded in the binary (`Cow::Borrowed`) or data built at runtime
/// (`Cow::Owned`).
#[derive(Clone)]
pub enum SingleCharMatcher {
    /// Character-to-character mapping.
    ///
    /// `l1`/`l2` form a two-stage page table. `l1` holds one little-endian
    /// `u16` per 256-code-point page: 0 means the page is empty, otherwise the
    /// value is a block index into `l2`. `l2` holds little-endian `u32` target
    /// code points, with 0 meaning "no entry".
    Fanjian {
        l1: Cow<'static, [u8]>,
        l2: Cow<'static, [u8]>,
    },
    /// Character-to-string mapping.
    ///
    /// Uses the same two-stage layout as `Fanjian`. Each `u32` value packs a
    /// byte offset into `strings` in its upper 24 bits and a byte length in its
    /// lower 8 bits. When `trim_space` is set, surrounding whitespace is trimmed
    /// from every looked-up string.
    Pinyin {
        l1: Cow<'static, [u8]>,
        l2: Cow<'static, [u8]>,
        strings: Cow<'static, str>,
        trim_space: bool,
    },
    /// Character deletion.
    ///
    /// `bitset` stores one bit per code point, tested as
    /// `bitset[cp / 8] & (1 << (cp % 8))`. `ascii_lut` holds the 16 bytes
    /// covering code points `0..128` for the ASCII fast path.
    Delete {
        bitset: Cow<'static, [u8]>,
        ascii_lut: [u8; 16],
    },
}
/// What a single-character matcher reported for one matched character.
///
/// Each iterator item pairs this value with the `(start, end)` byte range
/// of the matched character.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SingleCharMatch<'a> {
    /// The matched character maps to this `char`.
    Char(char),
    /// The matched character maps to this string.
    Str(&'a str),
    /// The matched character is marked for deletion.
    Delete,
}
/// Looks up code point `cp` in a two-stage (page / slot) table.
///
/// `l1` holds one little-endian `u16` per 256-code-point page. A value of 0
/// means the page has no entries; any other value is the index of a 256-slot
/// block in `l2`. `l2` holds little-endian `u32` slot values, where 0 means
/// "no entry".
///
/// Returns `None` when the page or slot is empty. It also returns `None`
/// when either table is too short for the requested index, so a truncated
/// or malformed table reads as "no entry" instead of panicking.
#[inline(always)]
fn page_table_lookup(cp: u32, l1: &[u8], l2: &[u8]) -> Option<u32> {
    let page_idx = (cp >> 8) as usize;
    let char_idx = (cp & 0xFF) as usize;

    let page_bytes: [u8; 2] = l1.get(page_idx * 2..page_idx * 2 + 2)?.try_into().ok()?;
    let page = u16::from_le_bytes(page_bytes) as usize;
    if page == 0 {
        return None;
    }

    let l2_idx = page * 256 + char_idx;
    let val_bytes: [u8; 4] = l2.get(l2_idx * 4..l2_idx * 4 + 4)?.try_into().ok()?;
    let val = u32::from_le_bytes(val_bytes);
    (val != 0).then_some(val)
}
/// Decodes the multi-byte UTF-8 sequence whose lead byte is `bytes[offset]`.
///
/// Returns `(code_point, sequence_len)` with `sequence_len` in `2..=4`. The
/// length is chosen from the lead byte alone and continuation bytes are not
/// validated, so the result is only meaningful for well-formed input.
///
/// # Safety
///
/// `offset` must index the lead byte of a complete, well-formed *non-ASCII*
/// UTF-8 sequence inside `bytes`. One example is a char boundary `i` of a
/// `&str` where `s.as_bytes()[i] >= 0x80`. Any other position (an ASCII
/// byte, a continuation byte, or a truncated sequence) can read past the end
/// of `bytes`.
#[inline(always)]
unsafe fn decode_utf8_raw(bytes: &[u8], offset: usize) -> (u32, usize) {
    debug_assert!(offset < bytes.len(), "decode_utf8_raw: offset out of bounds");
    // SAFETY: the caller guarantees `offset` is in bounds.
    let b0 = unsafe { *bytes.get_unchecked(offset) };
    debug_assert!(b0 >= 0xC0, "decode_utf8_raw: not a multi-byte lead byte");
    if b0 < 0xE0 {
        debug_assert!(offset + 2 <= bytes.len());
        // SAFETY: a 2-byte lead byte is followed by one continuation byte
        // (caller contract: the sequence is complete).
        let b1 = unsafe { *bytes.get_unchecked(offset + 1) };
        (((b0 as u32 & 0x1F) << 6) | (b1 as u32 & 0x3F), 2)
    } else if b0 < 0xF0 {
        debug_assert!(offset + 3 <= bytes.len());
        // SAFETY: a 3-byte lead byte is followed by two continuation bytes
        // (caller contract: the sequence is complete).
        let b1 = unsafe { *bytes.get_unchecked(offset + 1) };
        let b2 = unsafe { *bytes.get_unchecked(offset + 2) };
        (
            ((b0 as u32 & 0x0F) << 12) | ((b1 as u32 & 0x3F) << 6) | (b2 as u32 & 0x3F),
            3,
        )
    } else {
        debug_assert!(offset + 4 <= bytes.len());
        // SAFETY: a 4-byte lead byte is followed by three continuation bytes
        // (caller contract: the sequence is complete).
        let b1 = unsafe { *bytes.get_unchecked(offset + 1) };
        let b2 = unsafe { *bytes.get_unchecked(offset + 2) };
        let b3 = unsafe { *bytes.get_unchecked(offset + 3) };
        (
            ((b0 as u32 & 0x07) << 18)
                | ((b1 as u32 & 0x3F) << 12)
                | ((b2 as u32 & 0x3F) << 6)
                | (b3 as u32 & 0x3F),
            4,
        )
    }
}
/// Specialized iterator over the characters that a `Fanjian` table rewrites.
///
/// It is normally created by `SingleCharMatcher::fanjian_iter`. Each item is
/// `(start, end, SingleCharMatch::Char(mapped))`, where `start..end` is the
/// byte range of the original character in `text`.
pub struct FanjianFindIter<'a> {
    /// Text being scanned.
    pub text: &'a str,
    /// Byte position where the next search starts. It is expected to lie on
    /// a char boundary of `text`.
    pub byte_offset: usize,
    /// First-stage page table (one little-endian `u16` per 256-code-point page).
    pub l1: &'a [u8],
    /// Second-stage table (little-endian `u32` code points; 0 = no mapping).
    pub l2: &'a [u8],
}
impl<'a> Iterator for FanjianFindIter<'a> {
    type Item = (usize, usize, SingleCharMatch<'a>);

    /// Returns the next non-ASCII character whose table entry maps it to a
    /// different, valid `char`, as `(start, end, SingleCharMatch::Char(mapped))`.
    ///
    /// ASCII bytes are never looked up; `skip_ascii_simd` skips them in bulk.
    /// Table entries that are not valid Unicode scalar values are treated as
    /// "no mapping", the same way the generic `SingleCharFindIter` treats them.
    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        let bytes = self.text.as_bytes();
        let len = bytes.len();
        // `byte_offset` is a public field, so it may have been set past the end
        // or into the middle of a multi-byte character. Re-align it to a char
        // boundary before the unchecked decode below can see it.
        while !self.text.is_char_boundary(self.byte_offset) {
            if self.byte_offset >= len {
                return None;
            }
            self.byte_offset += 1;
        }
        loop {
            self.byte_offset = skip_ascii_simd(bytes, self.byte_offset);
            if self.byte_offset >= len {
                return None;
            }
            let start = self.byte_offset;
            // SAFETY: `start` is a char boundary of `self.text`. It was re-aligned
            // above, and since then the offset has only moved by whole characters
            // or over ASCII bytes (assumed contract of `skip_ascii_simd`: it
            // returns the first non-ASCII position, or `len`). A complete
            // multi-byte sequence therefore starts at `start`.
            let (cp, char_len) = unsafe { decode_utf8_raw(bytes, start) };
            self.byte_offset += char_len;
            // The table can come from caller-supplied data (`SingleCharMatcher::fanjian`,
            // `fanjian_from_map`), so the mapped value is validated, not trusted.
            if let Some(mapped_cp) = page_table_lookup(cp, self.l1, self.l2)
                && mapped_cp != cp
                && let Some(mapped) = char::from_u32(mapped_cp)
            {
                return Some((start, self.byte_offset, SingleCharMatch::Char(mapped)));
            }
        }
    }
}
/// Specialized iterator over the characters that a `Delete` bitset marks.
///
/// It is normally created by `SingleCharMatcher::delete_iter`. Each item is
/// `(start, end, SingleCharMatch::Delete)`, where `start..end` is the byte
/// range of the marked character in `text`.
pub struct DeleteFindIter<'a> {
    /// Text being scanned.
    pub text: &'a str,
    /// Byte position where the next search starts. It is expected to lie on
    /// a char boundary of `text`.
    pub byte_offset: usize,
    /// One bit per code point (`bitset[cp / 8] & (1 << (cp % 8))`). It is
    /// consulted for non-ASCII characters.
    pub bitset: &'a [u8],
    /// Bits for code points `0..128`, consulted for ASCII bytes. It is
    /// normally the first 16 bytes of `bitset` (see `SingleCharMatcher::delete`).
    pub ascii_lut: [u8; 16],
}
impl<'a> Iterator for DeleteFindIter<'a> {
    type Item = (usize, usize, SingleCharMatch<'a>);

    /// Returns the byte range of the next character marked for deletion, as
    /// `(start, end, SingleCharMatch::Delete)`.
    ///
    /// ASCII bytes are tested against `ascii_lut`. After an ASCII byte that is
    /// not deleted, the scan moves ahead 16 bytes at a time until a chunk
    /// contains a stop byte, then finishes byte by byte. Non-ASCII characters
    /// are decoded and tested against `bitset`; code points beyond the end of
    /// `bitset` are never deleted.
    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        let bytes = self.text.as_bytes();
        let len = bytes.len();
        // `byte_offset` is a public field, so it may have been set past the end
        // or inside a multi-byte character. Re-align it to a char boundary so
        // that the unchecked decode below only ever starts on a lead byte.
        while !self.text.is_char_boundary(self.byte_offset) {
            if self.byte_offset >= len {
                return None;
            }
            self.byte_offset += 1;
        }
        let ascii_lut_simd = Simd::<u8, 16>::from_array(self.ascii_lut);
        loop {
            if self.byte_offset >= len {
                return None;
            }
            let b = bytes[self.byte_offset];
            let start = self.byte_offset;
            if b < 0x80 {
                let cp = b as usize;
                self.byte_offset += 1;
                if (self.ascii_lut[cp >> 3] & (1 << (cp & 7))) != 0 {
                    return Some((start, self.byte_offset, SingleCharMatch::Delete));
                }
                // Bulk skip. Lane i of `stop_mask` is set when byte i of the chunk
                // is non-ASCII or is flagged by `simd_ascii_delete_mask`. Only
                // ASCII bytes are ever skipped here, so the offset stays on a
                // char boundary.
                while self.byte_offset + 16 <= len {
                    let chunk = Simd::<u8, 16>::from_slice(&bytes[self.byte_offset..]);
                    let non_ascii_mask = chunk.simd_ge(Simd::<u8, 16>::splat(0x80u8)).to_bitmask();
                    let del_mask = simd_ascii_delete_mask(chunk, ascii_lut_simd);
                    let stop_mask = non_ascii_mask | del_mask;
                    if stop_mask != 0 {
                        self.byte_offset += stop_mask.trailing_zeros() as usize;
                        break;
                    }
                    self.byte_offset += 16;
                }
                // Scalar tail: handles the final bytes when fewer than 16 remain.
                // If the SIMD loop already stopped on a stop byte, this exits at once.
                while self.byte_offset < len {
                    let b2 = bytes[self.byte_offset];
                    if b2 >= 0x80 {
                        break;
                    }
                    let cp2 = b2 as usize;
                    if (self.ascii_lut[cp2 >> 3] & (1 << (cp2 & 7))) != 0 {
                        break;
                    }
                    self.byte_offset += 1;
                }
            } else {
                // SAFETY: `start` is a char boundary. It was re-aligned on entry,
                // and the offset has only advanced by whole characters or over
                // ASCII bytes since then. `bytes[start] >= 0x80`, so a complete
                // multi-byte sequence of the valid `&str` starts here.
                let (cp, char_len) = unsafe { decode_utf8_raw(bytes, start) };
                self.byte_offset += char_len;
                let cp_usize = cp as usize;
                if cp_usize / 8 < self.bitset.len()
                    && (self.bitset[cp_usize / 8] & (1 << (cp_usize % 8))) != 0
                {
                    return Some((start, self.byte_offset, SingleCharMatch::Delete));
                }
            }
        }
    }
}
/// Specialized iterator over the characters that a `Pinyin` table maps to strings.
///
/// It is normally created by `SingleCharMatcher::pinyin_iter`. Non-ASCII
/// characters and ASCII digits are looked up; other ASCII bytes are skipped
/// by `skip_non_digit_ascii_simd`. Each item is
/// `(start, end, SingleCharMatch::Str(s))`.
pub struct PinyinFindIter<'a> {
    /// Text being scanned.
    pub text: &'a str,
    /// Byte position where the next search starts. It is expected to lie on
    /// a char boundary of `text`.
    pub byte_offset: usize,
    /// First-stage page table (one little-endian `u16` per 256-code-point page).
    pub l1: &'a [u8],
    /// Second-stage table. Each little-endian `u32` packs `(offset << 8) | byte_len`.
    pub l2: &'a [u8],
    /// Pool of concatenated strings that the packed offsets index into.
    pub strings: &'a str,
    /// When set, leading and trailing whitespace is trimmed from each result.
    pub trim_space: bool,
}
impl<'a> Iterator for PinyinFindIter<'a> {
    type Item = (usize, usize, SingleCharMatch<'a>);

    /// Returns the next character that has a table entry, as
    /// `(start, end, SingleCharMatch::Str(s))`. Here `s` is the packed slice
    /// of `strings`, trimmed when `trim_space` is set.
    ///
    /// Entries whose `(offset, length)` is not a valid slice of `strings`
    /// (out of range, or not on char boundaries) are skipped instead of
    /// panicking.
    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        let bytes = self.text.as_bytes();
        let len = bytes.len();
        // `byte_offset` is a public field, so it may have been set past the end
        // or inside a multi-byte character. Re-align it to a char boundary so
        // that the unchecked decode below only ever starts on a lead byte.
        while !self.text.is_char_boundary(self.byte_offset) {
            if self.byte_offset >= len {
                return None;
            }
            self.byte_offset += 1;
        }
        loop {
            self.byte_offset = skip_non_digit_ascii_simd(bytes, self.byte_offset);
            if self.byte_offset >= len {
                return None;
            }
            let start = self.byte_offset;
            let b = bytes[start];
            let (cp, char_len) = if b < 0x80 {
                (b as u32, 1)
            } else {
                // SAFETY: `start` is a char boundary. It was re-aligned on entry,
                // and since then the offset has only advanced by whole characters
                // or over skipped ASCII bytes (assumed contract of
                // `skip_non_digit_ascii_simd`). `bytes[start] >= 0x80`, so a
                // complete multi-byte sequence starts here.
                unsafe { decode_utf8_raw(bytes, start) }
            };
            self.byte_offset += char_len;
            if let Some(val) = page_table_lookup(cp, self.l1, self.l2) {
                // Packed value: upper 24 bits = byte offset, lower 8 bits = byte length.
                let offset = (val >> 8) as usize;
                let str_len = (val & 0xFF) as usize;
                if let Some(s) = self.strings.get(offset..offset + str_len) {
                    let s = if self.trim_space { s.trim() } else { s };
                    return Some((start, self.byte_offset, SingleCharMatch::Str(s)));
                }
            }
        }
    }
}
/// Variant-agnostic iterator over the matches of a `SingleCharMatcher`.
///
/// It is created by `SingleCharMatcher::find_iter`. It dispatches on the
/// matcher variant for every character instead of using a specialized fast
/// path.
pub struct SingleCharFindIter<'a> {
    /// Byte position in `text` where the next search begins. It is always a
    /// char boundary: it starts at 0 and only moves to the end of a matched
    /// character or to the end of the text.
    byte_offset: usize,
    /// Text being scanned.
    text: &'a str,
    /// Matcher whose tables drive the search.
    matcher: &'a SingleCharMatcher,
}
impl<'a> SingleCharFindIter<'a> {
#[inline(always)]
pub fn new(matcher: &'a SingleCharMatcher, text: &'a str) -> Self {
Self {
matcher,
text,
byte_offset: 0,
}
}
}
impl<'a> Iterator for SingleCharFindIter<'a> {
    type Item = (usize, usize, SingleCharMatch<'a>);

    /// Scans forward from the current position and returns the next character
    /// the matcher applies to, as `(start, end, match)` byte offsets into the
    /// text. Once the end of the text is reached, the iterator stays exhausted.
    ///
    /// Per-variant rules:
    /// * `Fanjian`: ASCII is skipped. A character matches when its table entry
    ///   is a valid `char` different from itself.
    /// * `Pinyin`: ASCII other than digits is skipped. A character matches when
    ///   its entry describes a valid slice of `strings`, which is trimmed if
    ///   `trim_space` is set.
    /// * `Delete`: a character matches when its bit is set in `bitset`.
    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        let text = &self.text[self.byte_offset..];
        for (i, c) in text.char_indices() {
            let cp = c as u32;
            let start = self.byte_offset + i;
            let end = start + c.len_utf8();
            match self.matcher {
                SingleCharMatcher::Fanjian { l1, l2 } => {
                    if cp < 0x80 {
                        continue;
                    }
                    if let Some(mapped_cp) = page_table_lookup(cp, l1, l2) {
                        // An invalid scalar value falls back to `c`, i.e. no match.
                        let mapped = char::from_u32(mapped_cp).unwrap_or(c);
                        if mapped != c {
                            self.byte_offset = end;
                            return Some((start, end, SingleCharMatch::Char(mapped)));
                        }
                    }
                }
                SingleCharMatcher::Pinyin {
                    l1,
                    l2,
                    strings,
                    trim_space,
                } => {
                    if cp < 0x80 && !c.is_ascii_digit() {
                        continue;
                    }
                    if let Some(val) = page_table_lookup(cp, l1, l2) {
                        // Packed value: upper 24 bits = byte offset, lower 8 bits = length.
                        let offset = (val >> 8) as usize;
                        let len = (val & 0xFF) as usize;
                        // `get` rejects both out-of-range and non-char-boundary
                        // ranges, so a malformed entry is skipped instead of panicking.
                        if let Some(s) = strings.get(offset..offset + len) {
                            let s = if *trim_space { s.trim() } else { s };
                            self.byte_offset = end;
                            return Some((start, end, SingleCharMatch::Str(s)));
                        }
                    }
                }
                SingleCharMatcher::Delete { bitset, .. } => {
                    let cp_usize = cp as usize;
                    if cp_usize / 8 < bitset.len()
                        && (bitset[cp_usize / 8] & (1 << (cp_usize % 8))) != 0
                    {
                        self.byte_offset = end;
                        return Some((start, end, SingleCharMatch::Delete));
                    }
                }
            }
        }
        self.byte_offset = self.text.len();
        None
    }
}
impl SingleCharMatcher {
    /// Returns a variant-agnostic iterator over the matches of `self` in `text`.
    #[inline(always)]
    pub fn find_iter<'a>(&'a self, text: &'a str) -> SingleCharFindIter<'a> {
        SingleCharFindIter::new(self, text)
    }

    /// Returns the specialized iterator for a `Fanjian` matcher.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not [`SingleCharMatcher::Fanjian`].
    #[inline(always)]
    pub fn fanjian_iter<'a>(&'a self, text: &'a str) -> FanjianFindIter<'a> {
        let SingleCharMatcher::Fanjian { l1, l2 } = self else {
            panic!("fanjian_iter called on non-Fanjian matcher");
        };
        FanjianFindIter {
            l1,
            l2,
            text,
            byte_offset: 0,
        }
    }

    /// Returns the specialized iterator for a `Delete` matcher.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not [`SingleCharMatcher::Delete`].
    #[inline(always)]
    pub fn delete_iter<'a>(&'a self, text: &'a str) -> DeleteFindIter<'a> {
        let SingleCharMatcher::Delete { bitset, ascii_lut } = self else {
            panic!("delete_iter called on non-Delete matcher");
        };
        DeleteFindIter {
            bitset,
            ascii_lut: *ascii_lut,
            text,
            byte_offset: 0,
        }
    }

    /// Returns the specialized iterator for a `Pinyin` matcher.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not [`SingleCharMatcher::Pinyin`].
    #[inline(always)]
    pub fn pinyin_iter<'a>(&'a self, text: &'a str) -> PinyinFindIter<'a> {
        let SingleCharMatcher::Pinyin {
            l1,
            l2,
            strings,
            trim_space,
        } = self
        else {
            panic!("pinyin_iter called on non-Pinyin matcher");
        };
        PinyinFindIter {
            l1,
            l2,
            strings,
            trim_space: *trim_space,
            text,
            byte_offset: 0,
        }
    }

    /// Creates a `Fanjian` matcher from prebuilt two-stage tables.
    ///
    /// The tables are not validated here; they are read with the
    /// page/slot layout of `page_table_lookup`.
    pub fn fanjian(l1: Cow<'static, [u8]>, l2: Cow<'static, [u8]>) -> Self {
        SingleCharMatcher::Fanjian { l1, l2 }
    }

    /// Creates a `Delete` matcher from a bitset with one bit per code point.
    ///
    /// The first 16 bytes (code points `0..128`) are copied into the ASCII
    /// lookup table. If the bitset is shorter than that, the missing ASCII
    /// bytes are left unmarked.
    pub fn delete(bitset: Cow<'static, [u8]>) -> Self {
        let mut ascii_lut = [0u8; 16];
        let copy_len = bitset.len().min(16);
        ascii_lut[..copy_len].copy_from_slice(&bitset[..copy_len]);
        SingleCharMatcher::Delete { bitset, ascii_lut }
    }

    /// Creates a `Pinyin` matcher from prebuilt tables and a string pool.
    ///
    /// Each `l2` value packs `(byte_offset << 8) | byte_len` into `strings`.
    pub fn pinyin(
        l1: Cow<'static, [u8]>,
        l2: Cow<'static, [u8]>,
        strings: Cow<'static, str>,
        trim_space: bool,
    ) -> Self {
        SingleCharMatcher::Pinyin {
            l1,
            l2,
            strings,
            trim_space,
        }
    }

    /// Builds the two-stage table read by `page_table_lookup` from a
    /// code point -> value map.
    ///
    /// Returns `(l1, l2)` as little-endian byte buffers:
    /// * `l1` has one `u16` per 256-code-point page of `0..=0x10FFFF`
    ///   (4352 entries). A value of 0 marks a page without keys; otherwise it
    ///   is the page's block index in `l2`.
    /// * `l2` holds `u32` values in blocks of 256. Block 0 is an all-zero
    ///   placeholder so that 0 in `l1` can mean "absent".
    ///
    /// A stored value of 0 cannot be told apart from a missing key on lookup.
    ///
    /// # Panics
    ///
    /// Panics if any key is above `0x10FFFF`.
    #[cfg(feature = "runtime_build")]
    fn build_2_stage_table(map: &HashMap<u32, u32>) -> (Vec<u8>, Vec<u8>) {
        // 0x110000 code points / 256 per page = 4352 pages.
        const PAGE_COUNT: usize = 0x11_0000 >> 8;

        let distinct_pages: HashSet<u32> = map.keys().map(|&k| k >> 8).collect();
        let mut page_list: Vec<u32> = distinct_pages.into_iter().collect();
        page_list.sort_unstable();
        if let Some(&highest) = page_list.last() {
            assert!(
                (highest as usize) < PAGE_COUNT,
                "build_2_stage_table: key above U+10FFFF cannot be stored"
            );
        }

        let mut l1 = vec![0u16; PAGE_COUNT];
        for (i, &page) in page_list.iter().enumerate() {
            // `i + 1` because l2 block 0 is the reserved empty block. At most
            // PAGE_COUNT (4352) blocks exist, so the index fits in a u16.
            l1[page as usize] = (i + 1) as u16;
        }

        let mut l2 = vec![0u32; (page_list.len() + 1) * 256];
        for (&cp, &val) in map {
            let block = l1[(cp >> 8) as usize] as usize;
            l2[block * 256 + (cp & 0xFF) as usize] = val;
        }

        let mut l1_bytes = Vec::with_capacity(l1.len() * 2);
        l1_bytes.extend(l1.iter().flat_map(|v| v.to_le_bytes()));
        let mut l2_bytes = Vec::with_capacity(l2.len() * 4);
        l2_bytes.extend(l2.iter().flat_map(|v| v.to_le_bytes()));
        (l1_bytes, l2_bytes)
    }

    /// Builds a `Fanjian` matcher from a code point -> code point map.
    ///
    /// Entries whose target is not a valid Unicode scalar value are dropped,
    /// because they cannot be reported as a `char`. The generic `find_iter`
    /// path already treats such entries as "no mapping".
    ///
    /// # Panics
    ///
    /// Panics if any key is above `0x10FFFF`.
    #[cfg(feature = "runtime_build")]
    pub fn fanjian_from_map(mut map: HashMap<u32, u32>) -> Self {
        map.retain(|_, target| char::from_u32(*target).is_some());
        let (l1, l2) = Self::build_2_stage_table(&map);
        Self::fanjian(Cow::Owned(l1), Cow::Owned(l2))
    }

    /// Builds a `Delete` matcher that marks two sets of characters:
    /// * every character of `text_delete`, after trimming the whole text and
    ///   splitting it into lines (the `\n` / `\r\n` terminators are not marked);
    /// * every character of every `white_space` entry.
    #[cfg(feature = "runtime_build")]
    pub fn delete_from_sources(text_delete: &str, white_space: &[&str]) -> Self {
        // One bit per code point in 0..0x110000 (139264 bytes).
        let mut bitset = vec![0u8; 0x11_0000 / 8];
        let listed = text_delete.trim().lines().flat_map(str::chars);
        let spaces = white_space.iter().flat_map(|ws| ws.chars());
        for c in listed.chain(spaces) {
            let cp = c as usize;
            bitset[cp / 8] |= 1 << (cp % 8);
        }
        Self::delete(Cow::Owned(bitset))
    }

    /// Builds a `Pinyin` matcher from a code point -> string map.
    ///
    /// All strings are concatenated into one pool. Each table value packs
    /// the string's byte offset in the pool (upper 24 bits) and its byte
    /// length (lower 8 bits).
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// * a key is above `0x10FFFF`;
    /// * a string is longer than 255 bytes;
    /// * the pool grows past the 24-bit offset range (16 MiB).
    ///
    /// In all three cases the packed value would otherwise be silently corrupted.
    #[cfg(feature = "runtime_build")]
    pub fn pinyin_from_map(map: HashMap<u32, &str>, trim_space: bool) -> Self {
        const MAX_OFFSET: usize = 0x00FF_FFFF;
        const MAX_LEN: usize = 0xFF;

        // A packed value of 0 reads back as "absent", and 0 is exactly how an
        // empty string at offset 0 would be encoded. Packing non-empty strings
        // first gives empty strings a non-zero offset whenever the pool is
        // non-empty, so their presence no longer depends on HashMap iteration order.
        let mut entries: Vec<(u32, &str)> = map.into_iter().collect();
        entries.sort_by_key(|&(_, s)| s.is_empty());

        let mut strings = String::new();
        let mut packed: HashMap<u32, u32> = HashMap::with_capacity(entries.len());
        for (cp, s) in entries {
            let offset = strings.len();
            assert!(
                offset <= MAX_OFFSET,
                "pinyin string pool exceeds the 24-bit offset range"
            );
            assert!(
                s.len() <= MAX_LEN,
                "pinyin string for U+{cp:04X} is longer than 255 bytes"
            );
            strings.push_str(s);
            // Lossless casts: both values were range-checked just above.
            packed.insert(cp, ((offset as u32) << 8) | s.len() as u32);
        }

        let (l1, l2) = Self::build_2_stage_table(&packed);
        Self::pinyin(
            Cow::Owned(l1),
            Cow::Owned(l2),
            Cow::Owned(strings),
            trim_space,
        )
    }
}