#[cfg(feature = "runtime_build")]
use ahash::{AHashMap, AHashSet};
use std::borrow::Cow;
use aho_corasick::{
AhoCorasick, AhoCorasickBuilder, AhoCorasickKind, MatchKind as AhoCorasickMatchKind,
};
use crate::process::string_pool::get_string_from_pool;
use crate::process::transform::simd::{multibyte_density, skip_ascii_simd};
use crate::process::transform::utf8::decode_utf8_raw;
/// Builds a copy of `text` in which each span yielded by `iter` is replaced by
/// a single character.
///
/// `iter` yields `(start, end, ch)` triples meaning "bytes `start..end` of
/// `text` become `ch`". Spans are consumed in iteration order and must be
/// ascending, non-overlapping and on char boundaries, because the gaps between
/// them are copied with `&text[a..b]` slicing (which panics otherwise).
///
/// Returns `None` when `iter` yields nothing, so the caller can keep using
/// `text` unchanged without any copy being made.
#[inline(always)]
fn replace_scan<I>(text: &str, mut iter: I) -> Option<String>
where
    I: Iterator<Item = (usize, usize, char)>,
{
    // Bail out before touching the pool when there is nothing to replace.
    let (first_start, first_end, first_ch) = iter.next()?;

    // NOTE(review): the argument is presumably a capacity hint — confirm
    // against `get_string_from_pool`.
    let mut out = get_string_from_pool(text.len());
    out.push_str(&text[..first_start]);
    out.push(first_ch);

    // Thread the "end of the previous span" through the fold so that each step
    // copies the untouched gap and then the replacement character.
    let tail_from = iter.fold(first_end, |cursor, (span_start, span_end, mapped)| {
        out.push_str(&text[cursor..span_start]);
        out.push(mapped);
        span_end
    });

    out.push_str(&text[tail_from..]);
    Some(out)
}
/// Builds a copy of `text` in which each span yielded by `iter` is replaced by
/// a string slice, and reports `multibyte_density` of the result.
///
/// `iter` yields `(start, end, replacement)` triples meaning "bytes
/// `start..end` of `text` become `replacement`". Spans must be ascending,
/// non-overlapping and on char boundaries; the gaps between them are copied
/// with `&text[a..b]` slicing, which panics otherwise.
///
/// Returns `None` when `iter` yields nothing. Otherwise returns the rewritten
/// string together with the value of `multibyte_density` computed over the
/// rewritten bytes (not over the input).
#[inline(always)]
fn replace_spans_with_density<'a, I>(text: &str, mut iter: I) -> Option<(String, f32)>
where
    I: Iterator<Item = (usize, usize, &'a str)>,
{
    // Nothing to replace: report that without allocating.
    let (first_start, first_end, first_repl) = iter.next()?;

    // NOTE(review): the argument is presumably a capacity hint — confirm
    // against `get_string_from_pool`.
    let mut out = get_string_from_pool(text.len());
    out.push_str(&text[..first_start]);
    out.push_str(first_repl);

    let mut cursor = first_end;
    for (span_start, span_end, repl) in iter {
        // Untouched gap first, then the replacement for this span.
        out.extend([&text[cursor..span_start], repl]);
        cursor = span_end;
    }
    out.push_str(&text[cursor..]);

    let density = multibyte_density(out.as_bytes());
    Some((out, density))
}
/// Source of replacement spans consumed by [`SliceReplacingByteIter`].
///
/// Implementations must yield spans in ascending, non-overlapping order.
pub(crate) trait ReplacementFinder<'a> {
    /// Returns the next `(start, end, replacement)` triple — source bytes
    /// `start..end` are to be substituted by `replacement` — or `None` once
    /// the finder is exhausted.
    fn next_replacement(&mut self) -> Option<(usize, usize, &'a [u8])>;
}

/// Byte iterator that streams `source` with the spans reported by a
/// [`ReplacementFinder`] substituted on the fly, without materializing the
/// rewritten text.
pub(crate) struct SliceReplacingByteIter<'a, F: ReplacementFinder<'a>> {
    /// Producer of the replacement spans.
    finder: F,
    /// Original bytes being streamed.
    source: &'a [u8],
    /// Index of the next unread byte in `source`.
    pos: usize,
    /// Start of the upcoming span, or `usize::MAX` when there is none left.
    next_start: usize,
    /// End (exclusive) of the upcoming span.
    next_end: usize,
    /// Replacement bytes of the upcoming span.
    next_repl: &'a [u8],
    /// Replacement currently being emitted.
    repl: &'a [u8],
    /// Index of the next byte to emit from `repl`.
    repl_pos: usize,
}

impl<'a, F: ReplacementFinder<'a>> SliceReplacingByteIter<'a, F> {
    /// Creates the iterator and pre-fetches the first span from `finder`.
    #[inline(always)]
    fn new(mut finder: F, source: &'a [u8]) -> Self {
        let empty: &'a [u8] = &[];
        // `usize::MAX` is the "no pending span" sentinel: `pos` can never reach it.
        let (next_start, next_end, next_repl) = finder
            .next_replacement()
            .unwrap_or((usize::MAX, 0, empty));
        Self {
            finder,
            source,
            pos: 0,
            next_start,
            next_end,
            next_repl,
            repl: empty,
            repl_pos: 0,
        }
    }

    /// Loads the following span from the finder, or arms the sentinel when the
    /// finder is exhausted.
    #[inline(always)]
    fn advance_finder(&mut self) {
        match self.finder.next_replacement() {
            Some((start, end, repl)) => {
                self.next_start = start;
                self.next_end = end;
                self.next_repl = repl;
            }
            None => self.next_start = usize::MAX,
        }
    }
}

impl<'a, F: ReplacementFinder<'a>> Iterator for SliceReplacingByteIter<'a, F> {
    type Item = u8;

    /// Yields the next output byte: pending replacement bytes first, then
    /// either the start of a new replacement or a verbatim source byte.
    #[inline(always)]
    fn next(&mut self) -> Option<u8> {
        loop {
            // Drain the replacement that is currently being emitted.
            if let Some(&byte) = self.repl.get(self.repl_pos) {
                self.repl_pos += 1;
                return Some(byte);
            }

            // Not at a span boundary: pass the source byte through (or finish).
            if self.pos != self.next_start {
                let &byte = self.source.get(self.pos)?;
                self.pos += 1;
                return Some(byte);
            }

            // Enter the span: skip the matched source bytes and switch to its
            // replacement. Looping instead of indexing `repl[0]` directly means
            // an empty replacement deletes the span rather than panicking, and
            // immediately adjacent spans are still honoured.
            self.pos = self.next_end;
            self.repl = self.next_repl;
            self.repl_pos = 0;
            self.advance_finder();
        }
    }
}
/// Looks up code point `cp` in a two-stage page table.
///
/// * `l1` maps the page number (`cp >> 8`) to a page slot; slot `0` means the
///   page has no entries.
/// * `l2` stores 256 values per slot; the entry for `cp` lives at
///   `slot * 256 + (cp & 0xFF)`, and a stored `0` means "no mapping".
///
/// Returns the stored non-zero value, or `None` when the page is outside `l1`,
/// the page is unmapped, or the entry is zero.
///
/// Both accesses are bounds-checked, so an `l1` entry that points past the end
/// of `l2` is reported as a miss instead of reading out of bounds.
#[inline(always)]
fn page_table_lookup(cp: u32, l1: &[u16], l2: &[u32]) -> Option<u32> {
    let page_idx = (cp >> 8) as usize;
    let char_idx = (cp & 0xFF) as usize;

    let page = usize::from(*l1.get(page_idx)?);
    if page == 0 {
        return None;
    }

    // `page` fits in 16 bits, so `page * 256 + char_idx` cannot overflow `usize`.
    l2.get(page * 256 + char_idx)
        .copied()
        .filter(|&value| value != 0)
}
/// Decodes a little-endian byte blob into a boxed table of `u16` values.
///
/// `bytes` is expected to hold a whole number of 2-byte words; this is checked
/// only in debug builds. In release builds a trailing odd byte is ignored.
#[cfg(not(feature = "runtime_build"))]
#[inline]
fn decode_u16_table(bytes: &[u8]) -> Box<[u16]> {
    debug_assert_eq!(bytes.len() % 2, 0);
    let mut words = Vec::with_capacity(bytes.len() / 2);
    for pair in bytes.chunks_exact(2) {
        words.push(u16::from_le_bytes([pair[0], pair[1]]));
    }
    words.into_boxed_slice()
}
/// Decodes a little-endian byte blob into a boxed table of `u32` values.
///
/// `bytes` is expected to hold a whole number of 4-byte words; this is checked
/// only in debug builds. In release builds trailing leftover bytes are ignored.
#[cfg(not(feature = "runtime_build"))]
#[inline]
fn decode_u32_table(bytes: &[u8]) -> Box<[u32]> {
    debug_assert_eq!(bytes.len() % 4, 0);
    let mut words = Vec::with_capacity(bytes.len() / 4);
    for quad in bytes.chunks_exact(4) {
        words.push(u32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]));
    }
    words.into_boxed_slice()
}
/// Re-packs a `(offset << 8) | length` string reference so that it no longer
/// covers leading or trailing ASCII spaces.
///
/// `value` addresses `length` bytes of `strings` starting at byte `offset`.
/// A `value` of `0` (the "no entry" marker) is returned unchanged. For a
/// span consisting solely of spaces the result has length `0` and an offset
/// equal to the end of the original span.
///
/// Panics if the span reaches past the end of `strings` while scanning.
#[inline]
fn trim_pinyin_packed(value: u32, strings: &str) -> u32 {
    if value == 0 {
        return 0;
    }

    let raw = strings.as_bytes();
    let span_start = (value >> 8) as usize;
    let span_end = span_start + (value & 0xFF) as usize;
    let is_content = |idx: &usize| raw[*idx] != b' ';

    // First non-space byte scanning forward; `span_end` if there is none.
    let lo = (span_start..span_end).find(is_content).unwrap_or(span_end);
    // One past the last non-space byte scanning backward; `lo` if there is none.
    let hi = (lo..span_end).rev().find(is_content).map_or(lo, |idx| idx + 1);

    ((lo as u32) << 8) | ((hi - lo) as u32)
}
struct FanjianFindIter<'a> {
l1: &'a [u16],
l2: &'a [u32],
text: &'a str,
byte_offset: usize,
}
impl<'a> Iterator for FanjianFindIter<'a> {
type Item = (usize, usize, char);
#[inline(always)]
fn next(&mut self) -> Option<Self::Item> {
let bytes = self.text.as_bytes();
let len = bytes.len();
loop {
self.byte_offset = skip_ascii_simd(bytes, self.byte_offset);
if self.byte_offset >= len {
return None;
}
let start = self.byte_offset;
let (cp, char_len) = unsafe { decode_utf8_raw(bytes, start) };
self.byte_offset += char_len;
if let Some(mapped_cp) = page_table_lookup(cp, self.l1, self.l2)
&& mapped_cp != cp
{
debug_assert!(char::from_u32(mapped_cp).is_some());
let mapped = unsafe { char::from_u32_unchecked(mapped_cp) };
return Some((start, self.byte_offset, mapped));
}
}
}
}
/// Byte iterator that streams the source text with every character reported by
/// [`FanjianFindIter`] replaced by the UTF-8 encoding of its mapped character.
pub(crate) struct FanjianByteIter<'a> {
    /// Producer of `(start, end, mapped)` replacements.
    find_iter: FanjianFindIter<'a>,
    /// Original bytes being streamed.
    source: &'a [u8],
    /// Index of the next unread byte in `source`.
    pos: usize,
    /// Start of the upcoming replacement, or `usize::MAX` when none is left.
    next_start: usize,
    /// End (exclusive) of the upcoming replacement.
    next_end: usize,
    /// Character to emit for the upcoming replacement.
    next_char: char,
    /// UTF-8 encoding of the replacement currently being emitted.
    buf: [u8; 4],
    /// Index of the next byte to emit from `buf`.
    buf_pos: u8,
    /// Number of valid bytes in `buf`.
    buf_len: u8,
}

impl<'a> Iterator for FanjianByteIter<'a> {
    type Item = u8;

    #[inline(always)]
    fn next(&mut self) -> Option<u8> {
        // Finish emitting the replacement character that is in flight.
        if self.buf_pos < self.buf_len {
            let idx = usize::from(self.buf_pos);
            self.buf_pos += 1;
            return Some(self.buf[idx]);
        }

        // At a replacement boundary: skip the source character, encode the
        // mapped one into `buf`, and emit its first byte right away.
        if self.pos == self.next_start {
            self.pos = self.next_end;
            let encoded_len = self.next_char.encode_utf8(&mut self.buf).len();
            self.buf_len = encoded_len as u8;
            self.buf_pos = 1;
            self.advance_find_iter();
            return Some(self.buf[0]);
        }

        // Otherwise pass the source byte through, or finish at end of input.
        let byte = *self.source.get(self.pos)?;
        self.pos += 1;
        Some(byte)
    }
}

impl<'a> FanjianByteIter<'a> {
    /// Loads the following replacement from `find_iter`, or arms the
    /// `usize::MAX` sentinel when there are no more.
    #[inline(always)]
    fn advance_find_iter(&mut self) {
        if let Some((start, end, mapped)) = self.find_iter.next() {
            self.next_start = start;
            self.next_end = end;
            self.next_char = mapped;
        } else {
            self.next_start = usize::MAX;
        }
    }
}
/// Character-to-character replacer backed by a two-stage code point table.
#[derive(Clone)]
pub(crate) struct FanjianMatcher {
    /// First-stage table (page number to page slot).
    l1: Box<[u16]>,
    /// Second-stage table (mapped code points, 256 per slot).
    l2: Box<[u32]>,
}

impl FanjianMatcher {
    /// Creates a scanner over `text` that yields `(start, end, mapped)` for
    /// each character the table remaps.
    #[inline(always)]
    fn iter<'a>(&'a self, text: &'a str) -> FanjianFindIter<'a> {
        let l1 = &self.l1[..];
        let l2 = &self.l2[..];
        FanjianFindIter {
            l1,
            l2,
            text,
            byte_offset: 0,
        }
    }

    /// Returns a streaming byte iterator over `text` with all remapped
    /// characters substituted.
    #[inline(always)]
    pub(crate) fn byte_iter<'a>(&'a self, text: &'a str) -> FanjianByteIter<'a> {
        let mut find_iter = self.iter(text);
        // Pre-fetch the first replacement; `usize::MAX` marks "none pending".
        let (next_start, next_end, next_char) =
            find_iter.next().unwrap_or((usize::MAX, 0, '\0'));
        FanjianByteIter {
            find_iter,
            source: text.as_bytes(),
            pos: 0,
            next_start,
            next_end,
            next_char,
            buf: [0; 4],
            buf_pos: 0,
            buf_len: 0,
        }
    }

    /// Returns `text` with all remapped characters substituted, or `None` when
    /// no character of `text` is remapped.
    pub(crate) fn replace(&self, text: &str) -> Option<String> {
        let spans = self.iter(text);
        replace_scan(text, spans)
    }

    /// Builds a matcher from little-endian encoded `l1` (`u16`) and `l2`
    /// (`u32`) table blobs.
    #[cfg(not(feature = "runtime_build"))]
    pub(crate) fn new(l1: &'static [u8], l2: &'static [u8]) -> Self {
        let l1 = decode_u16_table(l1);
        let l2 = decode_u32_table(l2);
        Self { l1, l2 }
    }

    /// Builds a matcher from a `code point -> mapped code point` map.
    #[cfg(feature = "runtime_build")]
    pub(crate) fn from_map(map: AHashMap<u32, u32>) -> Self {
        let (stage1, stage2) = build_2_stage_table(&map);
        Self {
            l1: stage1.into_boxed_slice(),
            l2: stage2.into_boxed_slice(),
        }
    }
}
/// Scanner that walks `text` and reports every character that has a string
/// replacement in the two-stage table `l1`/`l2`.
///
/// Table values are packed as `(offset << 8) | length` and address a slice of
/// `strings`.
struct PinyinFindIter<'a> {
    /// First-stage table (page number to page slot).
    l1: &'a [u16],
    /// Second-stage table (packed string references, 256 per slot).
    l2: &'a [u32],
    /// Backing storage for all replacement strings.
    strings: &'a str,
    /// Text being scanned.
    text: &'a str,
    /// Byte offset in `text` at which the next scan resumes.
    byte_offset: usize,
}

impl<'a> Iterator for PinyinFindIter<'a> {
    /// `(start, end, replacement)`: bytes `start..end` of the text hold one
    /// character that should be replaced by `replacement`.
    type Item = (usize, usize, &'a str);

    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        let raw = self.text.as_bytes();
        let pool: &'a str = self.strings;
        loop {
            // NOTE(review): presumably returns the offset of the first
            // non-ASCII byte at or after `byte_offset` — confirm.
            self.byte_offset = skip_ascii_simd(raw, self.byte_offset);
            if self.byte_offset >= raw.len() {
                return None;
            }

            let char_start = self.byte_offset;
            // SAFETY: `raw` comes from a `&str` and `char_start` is in bounds.
            // NOTE(review): assumes that is all `decode_utf8_raw` requires — confirm.
            let (cp, width) = unsafe { decode_utf8_raw(raw, char_start) };
            self.byte_offset += width;

            let Some(packed) = page_table_lookup(cp, self.l1, self.l2) else {
                continue;
            };
            let from = (packed >> 8) as usize;
            let to = from + (packed & 0xFF) as usize;
            // Entries that point past the string pool are skipped, not reported.
            if to <= pool.len() {
                return Some((char_start, self.byte_offset, &pool[from..to]));
            }
        }
    }
}
/// Adapts [`PinyinFindIter`] to the [`ReplacementFinder`] interface by
/// exposing each replacement string as bytes.
pub(crate) struct PinyinFindAdapter<'a> {
    /// Underlying scanner.
    find_iter: PinyinFindIter<'a>,
}

impl<'a> ReplacementFinder<'a> for PinyinFindAdapter<'a> {
    #[inline(always)]
    fn next_replacement(&mut self) -> Option<(usize, usize, &'a [u8])> {
        let (start, end, replacement) = self.find_iter.next()?;
        Some((start, end, replacement.as_bytes()))
    }
}

/// Streaming byte iterator produced by [`PinyinMatcher::byte_iter`].
pub(crate) type PinyinByteIter<'a> = SliceReplacingByteIter<'a, PinyinFindAdapter<'a>>;
/// Character-to-string replacer backed by a two-stage code point table whose
/// values reference slices of a shared string pool.
#[derive(Clone)]
pub(crate) struct PinyinMatcher {
    /// First-stage table (page number to page slot).
    l1: Box<[u16]>,
    /// Second-stage table; values are packed as `(offset << 8) | length`.
    l2: Box<[u32]>,
    /// String pool addressed by the packed values in `l2`.
    strings: Cow<'static, str>,
}

impl PinyinMatcher {
    /// Creates a scanner over `text` that yields `(start, end, replacement)`
    /// for each character that has a table entry.
    #[inline(always)]
    fn iter<'a>(&'a self, text: &'a str) -> PinyinFindIter<'a> {
        let l1 = &self.l1[..];
        let l2 = &self.l2[..];
        let strings: &'a str = self.strings.as_ref();
        PinyinFindIter {
            l1,
            l2,
            strings,
            text,
            byte_offset: 0,
        }
    }

    /// Returns a streaming byte iterator over `text` with all mapped
    /// characters substituted by their replacement strings.
    #[inline(always)]
    pub(crate) fn byte_iter<'a>(&'a self, text: &'a str) -> PinyinByteIter<'a> {
        let finder = PinyinFindAdapter {
            find_iter: self.iter(text),
        };
        SliceReplacingByteIter::new(finder, text.as_bytes())
    }

    /// Returns the rewritten text together with the `multibyte_density` of the
    /// result, or `None` when no character of `text` has a table entry.
    pub(crate) fn replace(&self, text: &str) -> Option<(String, f32)> {
        let spans = self.iter(text);
        replace_spans_with_density(text, spans)
    }

    /// Builds a matcher from little-endian encoded `l1` (`u16`) and `l2`
    /// (`u32`) table blobs plus the string pool they reference.
    ///
    /// When `trim_space` is set, every table entry is re-packed so that its
    /// replacement excludes leading and trailing spaces.
    #[cfg(not(feature = "runtime_build"))]
    pub(crate) fn new(
        l1: &'static [u8],
        l2: &'static [u8],
        strings: &'static str,
        trim_space: bool,
    ) -> Self {
        let l1 = decode_u16_table(l1);
        let mut l2 = decode_u32_table(l2);
        if trim_space {
            l2.iter_mut()
                .for_each(|packed| *packed = trim_pinyin_packed(*packed, strings));
        }
        Self {
            l1,
            l2,
            strings: Cow::Borrowed(strings),
        }
    }

    /// Builds a matcher from a `code point -> replacement string` map.
    ///
    /// The replacement strings are concatenated into an owned pool, and each
    /// entry is packed as `(offset << 8) | length`. When `trim_space` is set,
    /// every entry is then re-packed to exclude leading and trailing spaces.
    #[cfg(feature = "runtime_build")]
    pub(crate) fn from_map(map: AHashMap<u32, &str>, trim_space: bool) -> Self {
        let mut pool = String::new();
        let mut packed: AHashMap<u32, u32> = AHashMap::default();
        for (key, value) in map {
            let offset = pool.len() as u32;
            let length = value.len() as u32;
            pool.push_str(value);
            packed.insert(key, (offset << 8) | length);
        }

        let (stage1, stage2) = build_2_stage_table(&packed);
        let strings: Cow<'static, str> = Cow::Owned(pool);
        let mut l2 = stage2.into_boxed_slice();
        if trim_space {
            l2.iter_mut()
                .for_each(|entry| *entry = trim_pinyin_packed(*entry, strings.as_ref()));
        }
        Self {
            l1: stage1.into_boxed_slice(),
            l2,
            strings,
        }
    }
}
/// Multi-pattern replacer: an Aho-Corasick automaton finds pattern matches and
/// `replace_list` supplies the replacement for each pattern.
#[derive(Clone)]
pub(crate) struct NormalizeMatcher {
/// Automaton built over the search patterns.
engine: AhoCorasick,
/// Replacement strings, indexed by the pattern id reported for a match.
replace_list: Vec<&'static str>,
}
/// Adapts an Aho-Corasick match iterator to the [`ReplacementFinder`]
/// interface by pairing every match with its replacement bytes.
pub(crate) struct NormalizeFindAdapter<'a> {
    /// Underlying match iterator.
    find_iter: aho_corasick::FindIter<'a, 'a>,
    /// Replacement strings, indexed by pattern id.
    replace_list: &'a [&'static str],
}

impl<'a> ReplacementFinder<'a> for NormalizeFindAdapter<'a> {
    /// Panics if a match reports a pattern id with no entry in `replace_list`.
    #[inline(always)]
    fn next_replacement(&mut self) -> Option<(usize, usize, &'a [u8])> {
        let found = self.find_iter.next()?;
        let replacement = self.replace_list[found.pattern().as_usize()];
        Some((found.start(), found.end(), replacement.as_bytes()))
    }
}

/// Streaming byte iterator produced by [`NormalizeMatcher::byte_iter`].
pub(crate) type NormalizeByteIter<'a> = SliceReplacingByteIter<'a, NormalizeFindAdapter<'a>>;
impl NormalizeMatcher {
    /// Starts an Aho-Corasick search over `text`.
    #[inline(always)]
    fn find_iter<'a>(&'a self, text: &'a str) -> aho_corasick::FindIter<'a, 'a> {
        self.engine.find_iter(text)
    }

    /// Returns a streaming byte iterator over `text` with every pattern match
    /// substituted by its replacement.
    #[inline(always)]
    pub(crate) fn byte_iter<'a>(&'a self, text: &'a str) -> NormalizeByteIter<'a> {
        let finder = NormalizeFindAdapter {
            find_iter: self.find_iter(text),
            replace_list: &self.replace_list,
        };
        SliceReplacingByteIter::new(finder, text.as_bytes())
    }

    /// Returns the rewritten text together with the `multibyte_density` of the
    /// result, or `None` when no pattern matches in `text`.
    ///
    /// Panics if a match reports a pattern id with no entry in the
    /// replacement list.
    pub(crate) fn replace(&self, text: &str) -> Option<(String, f32)> {
        let spans = self.find_iter(text).map(|found| {
            let replacement = self.replace_list[found.pattern().as_usize()];
            (found.start(), found.end(), replacement)
        });
        replace_spans_with_density(text, spans)
    }

    /// Builds a matcher over `patterns` using a DFA with leftmost-longest
    /// match semantics. The replacement list starts out empty; supply it with
    /// [`with_replacements`](Self::with_replacements).
    ///
    /// Panics if the automaton cannot be built.
    pub(crate) fn new<I, P>(patterns: I) -> Self
    where
        I: IntoIterator<Item = P>,
        P: AsRef<str> + AsRef<[u8]>,
    {
        let engine = AhoCorasickBuilder::new()
            .kind(Some(AhoCorasickKind::DFA))
            .match_kind(AhoCorasickMatchKind::LeftmostLongest)
            .build(patterns)
            .unwrap();
        Self {
            engine,
            replace_list: Vec::new(),
        }
    }

    /// Sets the replacement strings; entry `i` is used for matches of the
    /// `i`-th pattern passed to [`new`](Self::new).
    pub(crate) fn with_replacements(self, replace_list: Vec<&'static str>) -> Self {
        Self {
            replace_list,
            ..self
        }
    }

    /// Builds a matcher from a `pattern -> replacement` dictionary.
    ///
    /// Entries are sorted by pattern first so that pattern ids, and therefore
    /// the automaton, do not depend on hash-map iteration order.
    #[cfg(feature = "runtime_build")]
    pub(crate) fn from_dict(dict: AHashMap<&'static str, &'static str>) -> Self {
        let mut pairs: Vec<(&'static str, &'static str)> = dict.into_iter().collect();
        pairs.sort_unstable_by_key(|&(pattern, _)| pattern);
        let (patterns, replace_list): (Vec<&'static str>, Vec<&'static str>) =
            pairs.into_iter().unzip();
        Self::new(patterns).with_replacements(replace_list)
    }
}
/// Builds the two-stage lookup table for a `code point -> value` map.
///
/// Returns `(l1, l2)` where:
/// * `l1` has one entry per 256-code-point page up to `0x10FFFF`, holding the
///   page's slot number in `l2`, or `0` when the page has no keys;
/// * `l2` holds 256 values per slot. Slot `0` is an all-zero page reserved for
///   "unmapped"; slots are assigned from `1` in ascending page order.
///
/// Panics if a key is greater than `0x10FFFF`.
#[cfg(feature = "runtime_build")]
fn build_2_stage_table(map: &AHashMap<u32, u32>) -> (Vec<u16>, Vec<u32>) {
    const L1_SIZE: usize = (0x10FFFF >> 8) + 1;
    const PAGE_LEN: usize = 256;

    // Distinct pages that contain at least one key, in ascending order.
    let distinct_pages: AHashSet<u32> = map.keys().map(|&cp| cp >> 8).collect();
    let mut ordered_pages: Vec<u32> = distinct_pages.into_iter().collect();
    ordered_pages.sort_unstable();

    let mut l1 = vec![0u16; L1_SIZE];
    let mut l2 = vec![0u32; (ordered_pages.len() + 1) * PAGE_LEN];

    // Assign slots 1, 2, ... to the pages that are in use.
    for (slot, &page) in (1usize..).zip(ordered_pages.iter()) {
        l1[page as usize] = slot as u16;
    }

    // Drop every value into its slot; every key's page has a slot by now.
    for (&cp, &value) in map.iter() {
        let slot = l1[(cp >> 8) as usize] as usize;
        l2[slot * PAGE_LEN + (cp & 0xFF) as usize] = value;
    }

    (l1, l2)
}
// Consistency tests: for each matcher, the streaming `byte_iter` output must
// equal the bytes of the materialized `replace` result (or the input itself
// when `replace` returns `None`). Only built without the `runtime_build`
// feature because the fixtures use the precompiled tables from `constants`.
#[cfg(all(test, not(feature = "runtime_build")))]
mod tests {
use super::*;
use super::super::constants;
/// Fanjian matcher built from the precompiled tables.
fn fanjian() -> FanjianMatcher {
FanjianMatcher::new(constants::FANJIAN_L1_BYTES, constants::FANJIAN_L2_BYTES)
}
/// Pinyin matcher with `trim_space` disabled.
fn pinyin() -> PinyinMatcher {
PinyinMatcher::new(
constants::PINYIN_L1_BYTES,
constants::PINYIN_L2_BYTES,
constants::PINYIN_STR_BYTES,
false,
)
}
/// Pinyin matcher with `trim_space` enabled.
fn pinyin_char() -> PinyinMatcher {
PinyinMatcher::new(
constants::PINYIN_L1_BYTES,
constants::PINYIN_L2_BYTES,
constants::PINYIN_STR_BYTES,
true,
)
}
/// Normalize matcher whose patterns and replacements are read line by line
/// from the two constant lists.
fn normalize_matcher() -> NormalizeMatcher {
let patterns: Vec<&str> = constants::NORMALIZE_PROCESS_LIST_STR.lines().collect();
let replace_list: Vec<&'static str> = constants::NORMALIZE_PROCESS_REPLACE_LIST_STR
.lines()
.collect();
NormalizeMatcher::new(patterns.iter()).with_replacements(replace_list)
}
/// Asserts that `byte_iter` and `replace` agree for a Fanjian matcher.
fn assert_byte_iter_eq_replace_fanjian(matcher: &FanjianMatcher, text: &str) {
// `None` from `replace` means "unchanged", so compare against the input.
let materialized: Vec<u8> = match matcher.replace(text) {
Some(s) => s.into_bytes(),
None => text.as_bytes().to_vec(),
};
let streamed: Vec<u8> = matcher.byte_iter(text).collect();
assert_eq!(materialized, streamed, "fanjian mismatch for: {:?}", text);
}
/// Asserts that `byte_iter` and `replace` agree for a Pinyin matcher.
fn assert_byte_iter_eq_replace_pinyin(matcher: &PinyinMatcher, text: &str) {
// The density component of the `replace` result is not compared.
let materialized: Vec<u8> = match matcher.replace(text) {
Some((s, _)) => s.into_bytes(),
None => text.as_bytes().to_vec(),
};
let streamed: Vec<u8> = matcher.byte_iter(text).collect();
assert_eq!(materialized, streamed, "pinyin mismatch for: {:?}", text);
}
/// Asserts that `byte_iter` and `replace` agree for a Normalize matcher.
fn assert_byte_iter_eq_replace_normalize(matcher: &NormalizeMatcher, text: &str) {
// The density component of the `replace` result is not compared.
let materialized: Vec<u8> = match matcher.replace(text) {
Some((s, _)) => s.into_bytes(),
None => text.as_bytes().to_vec(),
};
let streamed: Vec<u8> = matcher.byte_iter(text).collect();
assert_eq!(materialized, streamed, "normalize mismatch for: {:?}", text);
}
// Hand-picked cases: empty input, pure ASCII, pure CJK, mixed, single chars.
#[test]
fn fanjian_byte_iter_matches_replace() {
let m = fanjian();
for text in ["", "hello", "國際經濟", "abc東def國", "a", "東"] {
assert_byte_iter_eq_replace_fanjian(&m, text);
}
}
#[test]
fn pinyin_byte_iter_matches_replace() {
let m = pinyin();
for text in ["", "hello", "中文", "abc中def文", "a", "中"] {
assert_byte_iter_eq_replace_pinyin(&m, text);
}
}
#[test]
fn pinyin_char_byte_iter_matches_replace() {
let m = pinyin_char();
for text in ["", "hello", "中文", "abc中def文"] {
assert_byte_iter_eq_replace_pinyin(&m, text);
}
}
#[test]
fn normalize_byte_iter_matches_replace() {
let m = normalize_matcher();
for text in ["", "hello", "ABC", "abc123def", "①②③"] {
assert_byte_iter_eq_replace_normalize(&m, text);
}
}
// Property tests: the same equivalence on 500 generated strings per matcher.
// The strategy is the regex `\PC{0,200}`, i.e. up to 200 characters outside
// the Unicode "Other" (control/format/unassigned) category.
proptest::proptest! {
#![proptest_config(proptest::prelude::ProptestConfig::with_cases(500))]
#[test]
fn prop_fanjian_byte_iter(text in "\\PC{0,200}") {
let m = fanjian();
assert_byte_iter_eq_replace_fanjian(&m, &text);
}
#[test]
fn prop_pinyin_byte_iter(text in "\\PC{0,200}") {
let m = pinyin();
assert_byte_iter_eq_replace_pinyin(&m, &text);
}
#[test]
fn prop_normalize_byte_iter(text in "\\PC{0,200}") {
let m = normalize_matcher();
assert_byte_iter_eq_replace_normalize(&m, &text);
}
}
}