use std::sync::{OnceLock, RwLock};
use inputx_fsa::{Dict, Fsa};
use crate::ranking::{L0Inner, L0Snapshot, PROMOTE_THRESHOLD};
// Raw bytes of the main pinyin dictionary. By default they come from the
// `inputx_pinyin_data_core` crate; with the `bootstrap_only` feature they are the
// `bootstrap.dict` file pulled in from `OUT_DIR` at compile time.
#[cfg(not(feature = "bootstrap_only"))]
const DICT_BYTES: &[u8] = inputx_pinyin_data_core::EMBEDDED_PINYIN_DICT;
#[cfg(feature = "bootstrap_only")]
const DICT_BYTES: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/bootstrap.dict"));
// Bigram data. Only embedded when `bigrams` is enabled and `bootstrap_only` is not;
// otherwise the slice is empty, which `PinyinDict::embedded` treats as "not available".
#[cfg(all(not(feature = "bootstrap_only"), feature = "bigrams"))]
const BIGRAMS_BYTES: &[u8] = inputx_pinyin_data_bigrams::EMBEDDED_BIGRAMS;
#[cfg(any(feature = "bootstrap_only", not(feature = "bigrams")))]
const BIGRAMS_BYTES: &[u8] = &[];
// Second bigram data set ("intra"), gated by the same features as `BIGRAMS_BYTES`.
#[cfg(all(not(feature = "bootstrap_only"), feature = "bigrams"))]
const BIGRAMS_INTRA_BYTES: &[u8] = inputx_pinyin_data_bigrams::EMBEDDED_BIGRAMS_INTRA;
#[cfg(any(feature = "bootstrap_only", not(feature = "bigrams")))]
const BIGRAMS_INTRA_BYTES: &[u8] = &[];
// Trigram data. Only embedded when `trigrams` is enabled and `bootstrap_only` is not;
// otherwise empty.
#[cfg(all(not(feature = "bootstrap_only"), feature = "trigrams"))]
const TRIGRAMS_BYTES: &[u8] = inputx_pinyin_data_trigrams::EMBEDDED_TRIGRAMS;
#[cfg(any(feature = "bootstrap_only", not(feature = "trigrams")))]
const TRIGRAMS_BYTES: &[u8] = &[];
/// Pinyin dictionary bundled with optional n-gram data and mutable user-pick state.
pub struct PinyinDict {
/// Main dictionary; queried with lowercase pinyin bytes as the key.
map: Dict<&'static [u8]>,
/// Optional bigram counts, keyed as `prev`, a `0` byte, then `next`.
bigrams: Option<Fsa<&'static [u8]>>,
/// Optional second bigram index using the same key layout; its counts are added to `bigrams`.
bigrams_intra: Option<Fsa<&'static [u8]>>,
/// Optional trigram dictionary, keyed as `prev_prev`, a `0` byte, then `prev`.
trigrams: Option<Dict<&'static [u8]>>,
/// User state behind a lock: pinned words (`pins`) and pending pick counters (`pick_counts`).
l0: RwLock<L0Inner>,
/// Cache of the highest frequency seen for each single-character word; built on first use.
char_max_freq: OnceLock<std::collections::HashMap<char, u64>>,
}
impl PinyinDict {
/// Builds a dictionary from the data compiled into the binary.
///
/// On non-wasm targets, if the `INPUTX_PINYIN_DICT` environment variable is set, the main
/// dictionary is read from that path instead and the buffer is leaked to obtain a
/// `'static` slice. Bigram and trigram indexes are only loaded when their embedded byte
/// slices are non-empty.
///
/// # Panics
///
/// Panics if the override file cannot be read, or if any of the dictionary / FSA images
/// fails to parse.
pub fn embedded() -> Self {
    /// Parses an embedded FSA image; an empty slice yields `None`.
    fn load_optional(bytes: &'static [u8], label: &str) -> Option<Fsa<&'static [u8]>> {
        if bytes.is_empty() {
            return None;
        }
        match Fsa::new(bytes) {
            Ok(fsa) => Some(fsa),
            Err(_) => panic!("invalid embedded {label} fsa"),
        }
    }

    /// Parses an embedded dictionary image; an empty slice yields `None`.
    fn load_optional_dict(bytes: &'static [u8], label: &str) -> Option<Dict<&'static [u8]>> {
        if bytes.is_empty() {
            return None;
        }
        match Dict::new(bytes) {
            Ok(dict) => Some(dict),
            Err(_) => panic!("invalid embedded {label} dict"),
        }
    }

    #[cfg(not(target_arch = "wasm32"))]
    let dict_bytes: &'static [u8] = if let Some(path) = std::env::var_os("INPUTX_PINYIN_DICT") {
        match std::fs::read(&path) {
            Ok(data) => {
                // Leaked on purpose: the dictionary borrows its bytes for `'static`.
                let leaked: &'static [u8] = Box::leak(data.into_boxed_slice());
                leaked
            }
            Err(e) => panic!("INPUTX_PINYIN_DICT {path:?}: {e}"),
        }
    } else {
        DICT_BYTES
    };
    #[cfg(target_arch = "wasm32")]
    let dict_bytes: &'static [u8] = DICT_BYTES;

    let map = Dict::new(dict_bytes).expect("invalid pinyin dict");
    let bigrams = load_optional(BIGRAMS_BYTES, "bigrams");
    let bigrams_intra = load_optional(BIGRAMS_INTRA_BYTES, "bigrams_intra");
    let trigrams = load_optional_dict(TRIGRAMS_BYTES, "trigrams");
    Self {
        map,
        bigrams,
        bigrams_intra,
        trigrams,
        l0: RwLock::new(L0Inner::new()),
        char_max_freq: OnceLock::new(),
    }
}
/// Returns the highest frequency recorded for `c` as a single-character word,
/// or `0` when the character has no such entry.
pub fn char_max_freq(&self, c: char) -> u64 {
    match self.build_char_freq_cache().get(&c) {
        Some(&freq) => freq,
        None => 0,
    }
}
/// Returns the per-character maximum-frequency cache, building it on first use.
///
/// The build walks every dictionary entry and keeps, for each word that is exactly
/// one `char` long, the largest frequency seen.
fn build_char_freq_cache(&self) -> &std::collections::HashMap<char, u64> {
    self.char_max_freq.get_or_init(|| {
        let mut best = std::collections::HashMap::with_capacity(8192);
        self.map.prefix_for_each(b"", |_code, word_bytes, freq| {
            let Ok(word) = core::str::from_utf8(word_bytes) else { return };
            let mut letters = word.chars();
            // Only words made of exactly one character contribute.
            if let (Some(ch), None) = (letters.next(), letters.next()) {
                best.entry(ch)
                    .and_modify(|seen: &mut u64| {
                        if freq > *seen {
                            *seen = freq;
                        }
                    })
                    .or_insert(freq);
            }
        });
        best
    })
}
/// Number of pinned entries; `0` if the lock is poisoned.
pub fn l0_pin_count(&self) -> usize {
    match self.l0.read() {
        Ok(state) => state.pins.len(),
        Err(_) => 0,
    }
}
/// Number of pending pick counters; `0` if the lock is poisoned.
pub fn l0_pending_count(&self) -> usize {
    match self.l0.read() {
        Ok(state) => state.pick_counts.len(),
        Err(_) => 0,
    }
}
/// Size reported by the main dictionary, cast to `usize`.
pub fn len(&self) -> usize {
self.map.len() as usize
}
/// Whether the main dictionary reports itself as empty.
pub fn is_empty(&self) -> bool {
self.map.is_empty()
}
/// Convenience wrapper around [`Self::lookup_into`] that returns a freshly allocated list.
pub fn lookup(&self, pinyin: &str) -> Vec<String> {
    let mut words = Vec::new();
    self.lookup_into(pinyin, &mut words);
    words
}
pub fn lookup_into(&self, pinyin: &str, out: &mut Vec<String>) {
out.clear();
let lower = pinyin.to_ascii_lowercase();
self.map.get_for_each(lower.as_bytes(), |word, _freq| {
if let Ok(s) = core::str::from_utf8(word) {
out.push(s.to_string());
}
});
if let Ok(l0) = self.l0.read()
&& let Some(pref) = l0.pins.get(&lower_str(pinyin))
&& let Some(idx) = out.iter().position(|w| w == pref)
&& idx > 0
{
let p = out.remove(idx);
out.insert(0, p);
}
}
/// Whether any dictionary key starts with `prefix` (compared after ASCII lowercasing).
pub fn prefix_exists(&self, prefix: &str) -> bool {
    let needle = prefix.to_ascii_lowercase();
    self.map.contains_prefix(needle.as_bytes())
}
/// Collects every `(pinyin, word)` pair whose key starts with `prefix`, sorted.
///
/// Entries whose key or word is not valid UTF-8 are skipped.
pub fn prefix(&self, prefix: &str) -> Vec<(String, String)> {
    let needle = prefix.to_ascii_lowercase();
    let mut pairs: Vec<(String, String)> = Vec::new();
    self.map.prefix_for_each(needle.as_bytes(), |code, word, _freq| {
        let (Ok(code), Ok(word)) = (core::str::from_utf8(code), core::str::from_utf8(word))
        else {
            return;
        };
        pairs.push((code.to_string(), word.to_string()));
    });
    pairs.sort();
    pairs
}
/// Calls `visit(pinyin, word, freq)` for every entry whose key starts with `prefix`.
///
/// Entries whose key or word is not valid UTF-8 are silently skipped.
pub fn prefix_for_each<F>(&self, prefix: &str, mut visit: F)
where
    F: FnMut(&str, &str, u64),
{
    self.prefix_for_each_raw(prefix, |pinyin_bytes, word_bytes, freq| {
        let decoded = (
            core::str::from_utf8(pinyin_bytes),
            core::str::from_utf8(word_bytes),
        );
        match decoded {
            (Ok(pinyin), Ok(word)) => visit(pinyin, word, freq),
            _ => {}
        }
    });
}
/// Byte-level variant of [`Self::prefix_for_each`]: hands the raw key and word bytes to
/// `visit` without UTF-8 validation. `prefix` is ASCII-lowercased before the scan.
pub fn prefix_for_each_raw<F>(&self, prefix: &str, mut visit: F)
where
    F: FnMut(&[u8], &[u8], u64),
{
    let needle = prefix.to_ascii_lowercase();
    self.map.prefix_for_each(needle.as_bytes(), |code, word, value| {
        visit(code, word, value)
    });
}
/// Collects `(pinyin, word, freq)` for every entry whose key starts with `prefix`,
/// in the order the dictionary yields them (no sorting is applied).
pub fn prefix_with_freq(&self, prefix: &str) -> Vec<(String, String, u64)> {
    let needle = prefix.to_ascii_lowercase();
    let mut rows: Vec<(String, String, u64)> = Vec::new();
    self.map.prefix_for_each(needle.as_bytes(), |code, word, value| {
        let (Ok(code), Ok(word)) = (core::str::from_utf8(code), core::str::from_utf8(word))
        else {
            return;
        };
        rows.push((code.to_string(), word.to_string(), value));
    });
    rows
}
/// Records that the user picked `word` for `pinyin`.
///
/// Returns `true` only when this pick reaches `PROMOTE_THRESHOLD` and the word becomes
/// the pin for that key; all pending counters for the key are cleared at that point.
/// Returns `false` for words the dictionary does not list under `pinyin`, for picks
/// below the threshold, and when the lock is poisoned.
pub fn record_pick(&self, pinyin: &str, word: &str) -> bool {
    if !self.exists_in_l1(pinyin, word) {
        return false;
    }
    let code = lower_str(pinyin);
    let mut state = match self.l0.write() {
        Ok(guard) => guard,
        Err(_) => return false,
    };
    let picks = {
        let slot = state
            .pick_counts
            .entry((code.clone(), word.to_string()))
            .or_insert(0);
        *slot += 1;
        *slot
    };
    if picks < PROMOTE_THRESHOLD {
        return false;
    }
    // Promotion: pin the word and drop every counter that belongs to this key.
    state.pins.insert(code.clone(), word.to_string());
    state.pick_counts.retain(|(counted_code, _), _| *counted_code != code);
    true
}
/// Pins `word` for `pinyin` immediately, bypassing the pick counter.
///
/// Pending counters for the key are cleared. Returns `false` when the dictionary does
/// not list `word` under `pinyin` or when the lock is poisoned.
pub fn pin(&self, pinyin: &str, word: &str) -> bool {
    if !self.exists_in_l1(pinyin, word) {
        return false;
    }
    let code = lower_str(pinyin);
    match self.l0.write() {
        Ok(mut state) => {
            state.pins.insert(code.clone(), word.to_string());
            state.pick_counts.retain(|(counted_code, _), _| *counted_code != code);
            true
        }
        Err(_) => false,
    }
}
/// Replaces the contents of `out` with `(word, freq)` for every candidate of `pinyin`.
///
/// Pins are not consulted; non-UTF-8 entries are skipped.
pub fn lookup_with_freq_into(&self, pinyin: &str, out: &mut Vec<(String, u64)>) {
    out.clear();
    let code = lower_str(pinyin);
    self.map.get_for_each(code.as_bytes(), |word, freq| {
        match core::str::from_utf8(word) {
            Ok(text) => out.push((text.to_string(), freq)),
            Err(_) => {}
        }
    });
}
/// Replaces the contents of `out` with `(word, score)` for every candidate of `pinyin`,
/// sorted by descending score.
///
/// Each score is `PINYIN_PHRASE_BASE + freq`; the score of the pinned word for this key
/// (if any) is additionally multiplied by 1000. The sort is stable, so equal scores keep
/// dictionary order. Non-UTF-8 entries are skipped.
pub fn lookup_with_scores_into(&self, pinyin: &str, out: &mut Vec<(String, f64)>) {
    const PINYIN_PHRASE_BASE: f64 = 400_000.0;
    const PIN_MULTIPLIER: f64 = 1000.0;
    out.clear();
    let lower = lower_str(pinyin);
    // Collect straight into `out`; no intermediate buffer is needed.
    self.map.get_for_each(lower.as_bytes(), |word, freq| {
        if let Ok(s) = core::str::from_utf8(word) {
            out.push((s.to_string(), PINYIN_PHRASE_BASE + freq as f64));
        }
    });
    // Clone the pin so the read guard is released before scoring and sorting.
    let pinned: Option<String> = self.l0.read().ok().and_then(|g| g.pins.get(&lower).cloned());
    if let Some(pinned) = pinned.as_deref() {
        for (word, score) in out.iter_mut() {
            if word.as_str() == pinned {
                *score *= PIN_MULTIPLIER;
            }
        }
    }
    out.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
}
/// Segments `buffer` into dictionary words and returns the best-scoring chain as
/// `(score, sentence, words)`.
///
/// Returns `None` when the buffer length in bytes is outside `MIN_LEN..=MAX_LEN` or when
/// no sequence of dictionary keys covers the whole buffer.
///
/// Scoring: each word contributes `freq + bigram_boost(previous word, word) -
/// STEP_PENALTY`. For every byte position only the single best partial path is kept, so
/// the bigram context is the last word of that best path.
pub fn best_composition_chain(&self, buffer: &str) -> Option<(f64, String, Vec<String>)> {
    const MIN_LEN: usize = 4;
    const MAX_LEN: usize = 30;
    const MAX_SYL: usize = 24;
    const STEP_PENALTY: f64 = 100_000.0;
    let n = buffer.len();
    if !(MIN_LEN..=MAX_LEN).contains(&n) {
        return None;
    }
    // dp[i] = best path covering buffer[..i]: (score, start of last segment, last word).
    let mut dp: Vec<Option<(f64, usize, String)>> = vec![None; n + 1];
    dp[0] = Some((0.0, 0, String::new()));
    let mut scratch: Vec<(String, u64)> = Vec::new();
    for i in 1..=n {
        let lo = i.saturating_sub(MAX_SYL);
        // Accumulate the winner locally so dp[j] can be borrowed instead of cloned.
        let mut best: Option<(f64, usize, String)> = None;
        for j in lo..i {
            let Some((prev_score, _, prev_word)) = dp[j].as_ref() else {
                continue;
            };
            // `get` yields `None` when the range does not fall on char boundaries.
            let Some(seg) = buffer.get(j..i) else {
                continue;
            };
            self.lookup_raw_into(seg, &mut scratch);
            // The empty word at dp[0] means "no previous word".
            let prev_word_opt = (!prev_word.is_empty()).then_some(prev_word.as_str());
            for (word, raw_freq) in &scratch {
                let bonus = self.bigram_boost(prev_word_opt, word);
                let step_score = (*raw_freq as f64) + bonus - STEP_PENALTY;
                let total = *prev_score + step_score;
                // Strict comparison: the first candidate reaching a score wins ties.
                if best.as_ref().map_or(true, |cur| total > cur.0) {
                    best = Some((total, j, word.clone()));
                }
            }
        }
        dp[i] = best;
    }
    let final_score = dp[n].as_ref()?.0;
    // Follow the back-pointers from the end, moving the words out of the table.
    let mut chain: Vec<String> = Vec::new();
    let mut pos = n;
    while pos > 0 {
        let (_, prev_pos, word) = dp[pos].take()?;
        chain.push(word);
        pos = prev_pos;
    }
    chain.reverse();
    let sentence = chain.concat();
    Some((final_score, sentence, chain))
}
/// Segments `buffer` into dictionary words and returns the best-scoring result as
/// `(score, sentence)`.
///
/// This is [`Self::best_composition_chain`] without the per-word chain; see that method
/// for the length limits and the scoring rules. Returns `None` in the same cases.
pub fn best_composition(&self, buffer: &str) -> Option<(f64, String)> {
    // Delegate so both entry points share one copy of the segmentation algorithm.
    self.best_composition_chain(buffer)
        .map(|(score, sentence, _chain)| (score, sentence))
}
/// Returns up to `k` distinct segmentations of `buffer` as `(score, sentence)`, best
/// first.
///
/// Returns an empty list when `k == 0`, when the buffer length in bytes is outside
/// `MIN_LEN..=MAX_LEN`, or when no sequence of dictionary keys covers the whole buffer.
/// Each word contributes `freq + bigram_boost(previous word, word) - STEP_PENALTY`, and
/// every byte position keeps its `k` best partial paths.
pub fn top_k_compositions(&self, buffer: &str, k: usize) -> Vec<(f64, String)> {
    const MIN_LEN: usize = 4;
    const MAX_LEN: usize = 30;
    const MAX_SYL: usize = 24;
    const STEP_PENALTY: f64 = 100_000.0;
    let n = buffer.len();
    if k == 0 || !(MIN_LEN..=MAX_LEN).contains(&n) {
        return Vec::new();
    }
    // dp[i] = up to k best paths covering buffer[..i], each stored as
    // (score, start of last segment, index of predecessor in dp[start], last word).
    let mut dp: Vec<Vec<(f64, usize, usize, String)>> = vec![Vec::new(); n + 1];
    dp[0].push((0.0, 0, 0, String::new()));
    let mut scratch: Vec<(String, u64)> = Vec::new();
    for i in 1..=n {
        let lo = i.saturating_sub(MAX_SYL);
        let mut candidates: Vec<(f64, usize, usize, String)> = Vec::new();
        for j in lo..i {
            if dp[j].is_empty() {
                continue;
            }
            // `get` yields `None` when the range does not fall on char boundaries.
            let Some(seg) = buffer.get(j..i) else {
                continue;
            };
            self.lookup_raw_into(seg, &mut scratch);
            for (prev_idx, prev_path) in dp[j].iter().enumerate() {
                // The empty word at dp[0] means "no previous word".
                let prev_word_opt = (!prev_path.3.is_empty()).then_some(prev_path.3.as_str());
                for (word, raw_freq) in &scratch {
                    let bonus = self.bigram_boost(prev_word_opt, word);
                    let step_score = (*raw_freq as f64) + bonus - STEP_PENALTY;
                    let total = prev_path.0 + step_score;
                    candidates.push((total, j, prev_idx, word.clone()));
                }
            }
        }
        candidates.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
        candidates.truncate(k);
        dp[i] = candidates;
    }
    let end_count = dp[n].len();
    let mut out: Vec<(f64, String)> = Vec::with_capacity(end_count);
    let mut seen: std::collections::HashSet<String> =
        std::collections::HashSet::with_capacity(end_count);
    // Walk back from each end path by its OWN index. Looking an end path up by its
    // (prev_pos, prev_idx) pair is ambiguous when several end words share a predecessor
    // and would rebuild the wrong sentence for all but the first of them.
    for end_idx in 0..end_count {
        let end_score = dp[n][end_idx].0;
        let mut chain: Vec<&str> = Vec::new();
        let mut pos = n;
        let mut cur_idx = end_idx;
        while pos > 0 {
            let entry = &dp[pos][cur_idx];
            chain.push(entry.3.as_str());
            pos = entry.1;
            cur_idx = entry.2;
        }
        chain.reverse();
        let sentence = chain.concat();
        // Different segmentations can spell the same sentence; keep the best-scored one.
        if seen.insert(sentence.clone()) {
            out.push((end_score, sentence));
        }
    }
    out
}
/// Replaces the contents of `out` with `(word, freq)` for `pinyin`, straight from the
/// dictionary: no pin handling, non-UTF-8 entries skipped.
fn lookup_raw_into(&self, pinyin: &str, out: &mut Vec<(String, u64)>) {
    out.clear();
    let code = pinyin.to_ascii_lowercase();
    self.map.get_for_each(code.as_bytes(), |word, freq| {
        match core::str::from_utf8(word) {
            Ok(text) => out.push((text.to_string(), freq)),
            Err(_) => {}
        }
    });
}
/// Returns every `(prev, next, count)` bigram, with counts from the two bigram indexes
/// summed per pair.
///
/// Keys without a `0` separator, with an empty `next` part, or with non-UTF-8 parts are
/// skipped. The result comes out of a `HashMap`, so its order is unspecified.
pub fn iter_bigrams(&self) -> Vec<(String, String, u64)> {
    use std::collections::HashMap;
    let mut totals: HashMap<(String, String), u64> = HashMap::new();
    for index in self.bigrams.iter().chain(self.bigrams_intra.iter()) {
        index.prefix_for_each(b"", |key, count| {
            let Some(sep) = key.iter().position(|&b| b == 0) else {
                return;
            };
            let (prev, next) = (&key[..sep], &key[sep + 1..]);
            if next.is_empty() {
                return;
            }
            let (Ok(prev), Ok(next)) = (core::str::from_utf8(prev), core::str::from_utf8(next))
            else {
                return;
            };
            *totals.entry((prev.to_string(), next.to_string())).or_insert(0) += count;
        });
    }
    let mut rows = Vec::with_capacity(totals.len());
    for ((prev, next), count) in totals {
        rows.push((prev, next, count));
    }
    rows
}
/// Suggests up to `limit` words that follow `prev`, using the primary bigram index only.
///
/// Followers with a count below `MIN_PREDICTION_COUNT` are ignored. Results are ordered
/// by descending count, ties broken by ascending word. Returns an empty list when `prev`
/// is empty, `limit` is `0`, or no bigram index is loaded.
pub fn predict_next_words(&self, prev: &str, limit: usize) -> Vec<(String, u64)> {
    const MIN_PREDICTION_COUNT: u64 = 30;
    if prev.is_empty() || limit == 0 {
        return Vec::new();
    }
    let bigrams = match self.bigrams.as_ref() {
        Some(index) => index,
        None => return Vec::new(),
    };
    // Scan key: `prev` followed by the 0 separator.
    let mut needle = Vec::with_capacity(prev.len() + 1);
    needle.extend_from_slice(prev.as_bytes());
    needle.push(0u8);
    let skip = needle.len();
    let mut followers: Vec<(String, u64)> = Vec::new();
    bigrams.prefix_for_each(&needle, |key, count| {
        if count < MIN_PREDICTION_COUNT {
            return;
        }
        // Whatever follows the separator is the candidate next word.
        match core::str::from_utf8(&key[skip..]) {
            Ok(next) if !next.is_empty() => followers.push((next.to_string(), count)),
            _ => {}
        }
    });
    followers.sort_by(|(word_a, count_a), (word_b, count_b)| {
        count_b.cmp(count_a).then_with(|| word_a.cmp(word_b))
    });
    followers.truncate(limit);
    followers
}
/// Suggests up to `limit` words that follow the pair `(prev_prev, prev)`, using the
/// trigram dictionary only; there is no fallback to bigrams.
///
/// Returns an empty list when `prev` is empty, `limit` is `0`, `prev_prev` is `None` or
/// empty, or no trigram dictionary is loaded. Candidates with a count below
/// `MIN_TRIGRAM_COUNT` are ignored. Results are ordered by descending count, ties broken
/// by ascending word.
pub fn predict_next_words_context(
    &self,
    prev_prev: Option<&str>,
    prev: &str,
    limit: usize,
) -> Vec<(String, u64)> {
    const MIN_TRIGRAM_COUNT: u64 = 15;
    if prev.is_empty() || limit == 0 {
        return Vec::new();
    }
    let (prev_prev, trigrams) = match (prev_prev, self.trigrams.as_ref()) {
        (Some(context), Some(index)) if !context.is_empty() => (context, index),
        _ => return Vec::new(),
    };
    // Lookup key: `prev_prev`, the 0 separator, then `prev`.
    let mut code = Vec::with_capacity(prev_prev.len() + 1 + prev.len());
    code.extend_from_slice(prev_prev.as_bytes());
    code.push(0u8);
    code.extend_from_slice(prev.as_bytes());
    let mut followers: Vec<(String, u64)> = Vec::new();
    trigrams.get_for_each(&code, |c_bytes, count| {
        if count < MIN_TRIGRAM_COUNT {
            return;
        }
        match core::str::from_utf8(c_bytes) {
            Ok(next) => followers.push((next.to_string(), count)),
            Err(_) => {}
        }
    });
    followers.sort_by(|(word_a, count_a), (word_b, count_b)| {
        count_b.cmp(count_a).then_with(|| word_a.cmp(word_b))
    });
    followers.truncate(limit);
    followers
}
/// Score bonus for `next` following `prev`, in `0.0..=BIGRAM_BOOST_MAX`.
///
/// The counts of the pair in both bigram indexes are summed and scaled logarithmically:
/// `ln(count + 1) / ln(BIGRAM_REF + 1)`, capped at `1.0`, times `BIGRAM_BOOST_MAX`.
/// Returns `0.0` when `prev` is `None`, when either word is empty, when no bigram index
/// is loaded, or when the pair is unknown.
pub fn bigram_boost(&self, prev: Option<&str>, next: &str) -> f64 {
    const BIGRAM_BOOST_MAX: f64 = 50_000.0;
    const BIGRAM_REF: f64 = 1000.0;
    let Some(prev) = prev else { return 0.0 };
    if prev.is_empty() || next.is_empty() {
        return 0.0;
    }
    // Nothing to look up: skip building the key altogether.
    if self.bigrams.is_none() && self.bigrams_intra.is_none() {
        return 0.0;
    }
    let mut key = Vec::with_capacity(prev.len() + 1 + next.len());
    key.extend_from_slice(prev.as_bytes());
    key.push(0u8);
    key.extend_from_slice(next.as_bytes());
    let count_inter = self.bigrams.as_ref()
        .and_then(|m| m.get(&key)).unwrap_or(0);
    let count_intra = self.bigrams_intra.as_ref()
        .and_then(|m| m.get(&key)).unwrap_or(0);
    // Saturate instead of wrapping; the boost is capped well below this anyway.
    let count = count_inter.saturating_add(count_intra);
    if count == 0 {
        return 0.0;
    }
    let scaled = ((count as f64) + 1.0).ln() / (BIGRAM_REF + 1.0).ln();
    BIGRAM_BOOST_MAX * scaled.min(1.0)
}
/// Returns the word pinned for `pinyin` (ASCII case-insensitive), if any.
/// A poisoned lock is reported as `None`.
pub fn pinned_word(&self, pinyin: &str) -> Option<String> {
    let code = lower_str(pinyin);
    let state = self.l0.read().ok()?;
    state.pins.get(&code).cloned()
}
/// Drops the pin and all pending pick counters for `pinyin`.
///
/// Returns `true` when anything was removed; `false` when nothing was stored for the key
/// or the lock is poisoned.
pub fn forget(&self, pinyin: &str) -> bool {
    let code = lower_str(pinyin);
    let mut state = match self.l0.write() {
        Ok(guard) => guard,
        Err(_) => return false,
    };
    let removed_pin = state.pins.remove(&code).is_some();
    let pending_before = state.pick_counts.len();
    state.pick_counts.retain(|(counted_code, _), _| *counted_code != code);
    let removed_counters = state.pick_counts.len() != pending_before;
    removed_pin || removed_counters
}
/// Copies the current pins and pending pick counters into a snapshot.
/// A poisoned lock yields the default snapshot.
pub fn export_l0(&self) -> L0Snapshot {
    match self.l0.read() {
        Ok(state) => {
            let pins = state
                .pins
                .iter()
                .map(|(code, word)| (code.clone(), word.clone()))
                .collect();
            let pick_counts = state
                .pick_counts
                .iter()
                .map(|((code, word), picks)| (code.clone(), word.clone(), *picks))
                .collect();
            L0Snapshot { pins, pick_counts }
        }
        Err(_) => L0Snapshot::default(),
    }
}
pub fn import_l0(&self, snap: L0Snapshot) -> usize {
let valid_pins: Vec<(String, String)> = snap
.pins
.into_iter()
.filter(|(p, w)| self.exists_in_l1(p, w))
.collect();
let valid_counts: Vec<((String, String), u32)> = snap
.pick_counts
.into_iter()
.filter_map(|(p, w, n)| {
if self.exists_in_l1(&p, &w) {
Some(((p, w), n))
} else {
None
}
})
.collect();
let accepted = valid_pins.len();
let Ok(mut l0) = self.l0.write() else {
return 0;
};
l0.pins = valid_pins.into_iter().collect();
l0.pick_counts = valid_counts.into_iter().collect();
accepted
}
fn exists_in_l1(&self, pinyin: &str, word: &str) -> bool {
self.lookup(pinyin).iter().any(|w| w == word)
}
}
/// Returns a copy of `s` with ASCII letters lowercased; non-ASCII characters are kept
/// unchanged.
fn lower_str(s: &str) -> String {
    let mut lowered = s.to_owned();
    lowered.make_ascii_lowercase();
    lowered
}
#[cfg(test)]
mod tests {
use super::*;
/// The embedded dictionary must parse and hold at least a minimal number of entries.
#[test]
fn embedded_loads() {
    let dict = PinyinDict::embedded();
    let entries = dict.len();
    assert!(entries >= 50, "bootstrap should have at least 50 entries");
}
/// Sanity check on the full data set: dictionary size, a non-empty bigram index, and
/// the cold-start behaviour of the trigram predictor.
#[cfg(not(feature = "bootstrap_only"))]
#[test]
fn shipped_data_at_expected_scale() {
    let dict = PinyinDict::embedded();
    assert!(dict.len() >= 140_000, "pinyin.dict too small: {} codes", dict.len());
    assert!(
        dict.bigram_boost(Some("中国"), "人民") > 0.0
            || dict.bigram_boost(Some("我们"), "一起") > 0.0,
        "bigrams index looks empty"
    );
    // The chained call only has to run; its result is not inspected.
    let _ = dict.predict_next_words_context(Some("我们"), "一起", 10);
    let cold = dict.predict_next_words_context(None, "我们", 10);
    assert!(cold.is_empty(), "cold-start (no prev_prev) must be empty");
}
/// `zhongguo` must rank 中国 first.
#[test]
fn lookup_zhongguo_returns_zhongguo() {
    let dict = PinyinDict::embedded();
    let candidates = dict.lookup("zhongguo");
    let top = candidates.first().map(String::as_str);
    assert_eq!(top, Some("中国"));
}
/// `wo` must rank 我 first.
#[test]
fn lookup_wo_returns_wo_first() {
    let dict = PinyinDict::embedded();
    let candidates = dict.lookup("wo");
    let top = candidates.first().map(String::as_str);
    assert_eq!(top, Some("我"));
}
/// A highly ambiguous syllable must return several candidates, including 是.
#[test]
fn lookup_shi_returns_multiple() {
    let dict = PinyinDict::embedded();
    let words = dict.lookup("shi");
    assert!(
        words.len() >= 3,
        "expected ≥3 candidates for shi, got {words:?}"
    );
    let wanted = "是".to_string();
    assert!(words.contains(&wanted));
}
/// A key that is not in the dictionary yields no candidates.
#[test]
fn lookup_unknown_returns_empty() {
    let dict = PinyinDict::embedded();
    let candidates = dict.lookup("qzqzqz");
    assert!(candidates.is_empty());
}
/// Lookups ignore ASCII case in the key.
#[test]
fn case_insensitive() {
    let dict = PinyinDict::embedded();
    let upper = dict.lookup("WO");
    let lower = dict.lookup("wo");
    assert_eq!(upper, lower);
    let mixed = dict.lookup("ZhongGuo");
    let plain = dict.lookup("zhongguo");
    assert_eq!(mixed, plain);
}
/// `lookup_into` clears and refills the caller's buffer without shrinking its capacity.
#[test]
fn lookup_into_reuses_buffer() {
    let dict = PinyinDict::embedded();
    let mut scratch = Vec::with_capacity(16);
    dict.lookup_into("ni", &mut scratch);
    let capacity_after_first = scratch.capacity();
    dict.lookup_into("ta", &mut scratch);
    assert!(scratch.capacity() >= capacity_after_first);
    let top = scratch.first().map(String::as_str);
    assert_eq!(top, Some("他"));
}
/// A prefix scan returns both the exact key and longer keys that extend it.
#[test]
fn prefix_returns_sorted_pairs() {
    let dict = PinyinDict::embedded();
    let pairs = dict.prefix("zhong");
    let has_exact_key = pairs.iter().any(|(code, _)| code == "zhong");
    assert!(has_exact_key);
    let has_longer_key = pairs
        .iter()
        .any(|(code, word)| code == "zhongguo" && word == "中国");
    assert!(has_longer_key);
}
/// A fresh dictionary carries no pins and no pending counters.
#[cfg(not(feature = "bootstrap_only"))]
#[test]
fn l0_starts_empty() {
    let dict = PinyinDict::embedded();
    let pins = dict.l0_pin_count();
    let pending = dict.l0_pending_count();
    assert_eq!(pins, 0);
    assert_eq!(pending, 0);
}
/// The pick that reaches `PROMOTE_THRESHOLD` pins the word and clears its counters.
#[cfg(not(feature = "bootstrap_only"))]
#[test]
fn record_pick_promotes_after_threshold() {
    let dict = PinyinDict::embedded();
    let target = "时";
    let picks_before_promotion = PROMOTE_THRESHOLD - 1;
    for _ in 0..picks_before_promotion {
        assert!(!dict.record_pick("shi", target));
    }
    assert!(dict.record_pick("shi", target), "should promote on Nth pick");
    let top = dict.lookup("shi");
    assert_eq!(top.first().map(String::as_str), Some(target));
    assert_eq!(dict.l0_pin_count(), 1);
    assert_eq!(dict.l0_pending_count(), 0);
}
/// After a promotion, a competing word starts from zero and needs the full threshold
/// of picks before it replaces the pin.
#[cfg(not(feature = "bootstrap_only"))]
#[test]
fn record_pick_resets_on_promotion_so_others_must_earn_3_again() {
    let dict = PinyinDict::embedded();
    let top_for_shi = |dict: &PinyinDict| dict.lookup("shi").into_iter().next();

    for _ in 0..PROMOTE_THRESHOLD {
        dict.record_pick("shi", "时");
    }
    // One pick of the competitor is not enough to displace the pin.
    assert!(!dict.record_pick("shi", "事"));
    assert_eq!(top_for_shi(&dict).as_deref(), Some("时"));

    for _ in 0..(PROMOTE_THRESHOLD - 1) {
        dict.record_pick("shi", "事");
    }
    assert_eq!(top_for_shi(&dict).as_deref(), Some("事"));
}
/// Picks of a word the dictionary does not know are rejected and leave no state behind.
#[cfg(not(feature = "bootstrap_only"))]
#[test]
fn record_pick_rejects_unknown_word() {
    let dict = PinyinDict::embedded();
    for _ in 0..PROMOTE_THRESHOLD {
        let promoted = dict.record_pick("shi", "this_is_not_a_real_word");
        assert!(!promoted);
    }
    assert_eq!(dict.l0_pin_count(), 0);
    assert_eq!(dict.l0_pending_count(), 0);
}
/// `pin` takes effect immediately and creates no pending counters.
#[cfg(not(feature = "bootstrap_only"))]
#[test]
fn pin_force_pins_without_counters() {
    let dict = PinyinDict::embedded();
    assert!(dict.pin("shi", "时"));
    let candidates = dict.lookup("shi");
    assert_eq!(candidates.first().map(String::as_str), Some("时"));
    assert_eq!(dict.l0_pending_count(), 0);
}
/// `forget` removes both the pin and the pending counters of a key.
#[cfg(not(feature = "bootstrap_only"))]
#[test]
fn forget_clears_pin_and_counters() {
    let dict = PinyinDict::embedded();
    dict.pin("shi", "时");
    dict.record_pick("shi", "事");
    let removed_something = dict.forget("shi");
    assert!(removed_something);
    assert_eq!(dict.l0_pin_count(), 0);
    assert_eq!(dict.l0_pending_count(), 0);
}
/// A snapshot taken with `export_l0` restores the same pin through `import_l0`.
#[cfg(not(feature = "bootstrap_only"))]
#[test]
fn export_import_roundtrip() {
    let dict = PinyinDict::embedded();
    dict.pin("shi", "时");
    dict.record_pick("zhongguo", "中国");

    let snap = dict.export_l0();
    assert_eq!(snap.pins.len(), 1);
    assert_eq!(snap.pick_counts.len(), 1);

    // Wipe the live state so the import has to restore it.
    dict.forget("shi");
    dict.forget("zhongguo");
    assert_eq!(dict.l0_pin_count(), 0);

    let accepted = dict.import_l0(snap);
    assert_eq!(accepted, 1);
    let candidates = dict.lookup("shi");
    assert_eq!(candidates.first().map(String::as_str), Some("时"));
}
/// Snapshot entries that name words unknown to the dictionary are dropped on import.
#[cfg(not(feature = "bootstrap_only"))]
#[test]
fn import_drops_invalid_entries() {
    let dict = PinyinDict::embedded();
    let pins = vec![
        ("shi".into(), "时".into()),
        ("shi".into(), "bogus_word".into()),
    ];
    let pick_counts = vec![("shi".into(), "ghost_word".into(), 2)];
    let snap = L0Snapshot { pins, pick_counts };
    let accepted = dict.import_l0(snap);
    assert_eq!(accepted, 1);
    assert_eq!(dict.l0_pending_count(), 0);
}
/// Pin / forget round trip that also runs under the bootstrap-only data set.
#[test]
fn l0_pin_pin_lookup_compiles() {
    let dict = PinyinDict::embedded();
    let pinned = dict.pin("zhongguo", "中国");
    assert!(pinned);
    assert_eq!(dict.l0_pin_count(), 1);
    let forgotten = dict.forget("zhongguo");
    assert!(forgotten);
    assert_eq!(dict.l0_pin_count(), 0);
}
/// An empty previous word or a zero limit yields no predictions.
#[test]
fn predict_next_words_empty_inputs() {
    let dict = PinyinDict::embedded();
    let for_empty_prev = dict.predict_next_words("", 10);
    assert!(for_empty_prev.is_empty());
    let for_zero_limit = dict.predict_next_words("今天", 0);
    assert!(for_zero_limit.is_empty());
}
/// Ranking check on the shipped data: 理想 must lead for `lixiang`.
#[cfg(not(feature = "bootstrap_only"))]
#[test]
fn lookup_lixiang_lixiang_leads() {
    let dict = PinyinDict::embedded();
    let cands = dict.lookup("lixiang");
    let top = cands.first().map(String::as_str);
    assert_eq!(
        top,
        Some("理想"),
        "expected 理想 #1 for lixiang; got {:?}",
        cands.iter().take(5).collect::<Vec<_>>()
    );
}
/// Ranking check on the shipped data: 缺失 must lead for `queshi`.
#[cfg(not(feature = "bootstrap_only"))]
#[test]
fn lookup_queshi_quexi_polish_log_promoted() {
    let dict = PinyinDict::embedded();
    let cands = dict.lookup("queshi");
    let top = cands.first().map(String::as_str);
    assert_eq!(
        top,
        Some("缺失"),
        "expected 缺失 #1 (was 确实 before polish-log auto-tune); top5={:?}",
        cands.iter().take(5).collect::<Vec<_>>()
    );
}
/// The traditional forms 於 and 國 must not appear among the top five candidates.
#[cfg(not(feature = "bootstrap_only"))]
#[test]
fn lookup_traditional_dropped_after_strip() {
    let dict = PinyinDict::embedded();

    let yu_cands = dict.lookup("yu");
    let yu_top = &yu_cands[..yu_cands.len().min(5)];
    assert!(
        !yu_cands.iter().take(5).any(|w| w == "於"),
        "於 should be stripped; got top5={:?}",
        yu_top
    );

    let guo_cands = dict.lookup("guo");
    let guo_top = &guo_cands[..guo_cands.len().min(5)];
    assert!(
        !guo_cands.iter().take(5).any(|w| w == "國"),
        "國 should be stripped; got top5={:?}",
        guo_top
    );
}
/// Bigram predictions for 今天 must include at least one very common follower.
#[cfg(not(feature = "bootstrap_only"))]
#[test]
fn predict_next_words_jintian_top_followers() {
    let dict = PinyinDict::embedded();
    let preds = dict.predict_next_words("今天", 10);
    assert!(!preds.is_empty(), "expected predictions for 今天");
    let words: Vec<&str> = preds.iter().map(|(w, _)| w.as_str()).collect();
    let common_followers = ["的", "在", "是", "我", "我们"];
    let has_common_followers = common_followers.iter().any(|w| words.contains(w));
    assert!(
        has_common_followers,
        "expected at least one of 的/在/是/我/我们 in 今天 predictions; got {words:?}"
    );
}
/// Whatever the trigram predictor returns must be ordered by descending count.
#[cfg(not(feature = "bootstrap_only"))]
#[test]
fn predict_next_words_context_uses_trigram_or_empty() {
    let dict = PinyinDict::embedded();
    let with_context = dict.predict_next_words_context(Some("今天"), "的", 10);
    // `windows(2)` yields nothing for fewer than two results, so an empty list passes.
    for w in with_context.windows(2) {
        assert!(
            w[0].1 >= w[1].1,
            "trigram results must be sorted desc; got {w:?}"
        );
    }
}
/// With an unknown `prev_prev` the predictor returns nothing instead of falling back
/// to bigram data.
#[cfg(not(feature = "bootstrap_only"))]
#[test]
fn predict_next_words_context_no_backoff_for_chains() {
    let dict = PinyinDict::embedded();
    let chained = dict.predict_next_words_context(Some("锟斤拷"), "我们", 5);
    assert!(
        chained.is_empty(),
        "chained prediction with empty trigram must NOT backoff to bigram; \
         got {chained:?}"
    );
}
/// 我们 + 的 must produce predictions, and every one of them must meet the minimum count.
#[cfg(not(feature = "bootstrap_only"))]
#[test]
fn predict_next_words_context_threshold_15_surfaces_real_predictions() {
    let dict = PinyinDict::embedded();
    let r = dict.predict_next_words_context(Some("我们"), "的", 10);
    assert!(
        !r.is_empty(),
        "我们的 should predict at threshold 15 (counts 40/30/19); got empty"
    );
    let all_clear_threshold = r.iter().all(|(_, c)| *c >= 15);
    assert!(
        all_clear_threshold,
        "every prediction must clear the 15 threshold; got {r:?}"
    );
}
/// Without a `prev_prev` word the trigram predictor returns nothing.
#[cfg(not(feature = "bootstrap_only"))]
#[test]
fn predict_next_words_context_cold_start_returns_empty() {
    let dict = PinyinDict::embedded();
    let cold = dict.predict_next_words_context(None, "我们", 5);
    assert!(
        cold.is_empty(),
        "cold start (no prev_prev) must return empty under v1.4 strict; \
         got {cold:?}"
    );
}
/// Bigram predictions must be ordered by descending count.
#[cfg(not(feature = "bootstrap_only"))]
#[test]
fn predict_next_words_sorted_desc() {
    let dict = PinyinDict::embedded();
    let preds = dict.predict_next_words("我们", 5);
    // `windows(2)` is empty for fewer than two predictions, so no explicit guard is needed.
    for w in preds.windows(2) {
        assert!(
            w[0].1 >= w[1].1,
            "predictions must be sorted by count desc; got {:?} then {:?}",
            w[0],
            w[1]
        );
    }
}
/// Without a previous word there is no boost, whatever the next word is.
#[test]
fn bigram_boost_zero_without_prev() {
    let dict = PinyinDict::embedded();
    let for_word = dict.bigram_boost(None, "好");
    assert_eq!(for_word, 0.0);
    let for_empty = dict.bigram_boost(None, "");
    assert_eq!(for_empty, 0.0);
}
/// Empty words and unknown pairs get no boost.
#[test]
fn bigram_boost_zero_for_empty_or_unknown() {
    let dict = PinyinDict::embedded();
    let empty_prev = dict.bigram_boost(Some(""), "好");
    assert_eq!(empty_prev, 0.0);
    let empty_next = dict.bigram_boost(Some("今天"), "");
    assert_eq!(empty_next, 0.0);
    let unknown_pair = dict.bigram_boost(Some("锟斤拷"), "烫烫烫");
    assert_eq!(unknown_pair, 0.0);
}
/// Buffers shorter than the minimum length are never segmented.
#[test]
fn best_composition_too_short_returns_none() {
    let dict = PinyinDict::embedded();
    let empty = dict.best_composition("");
    assert!(empty.is_none());
    let two_bytes = dict.best_composition("ni");
    assert!(two_bytes.is_none());
    let three_bytes = dict.best_composition("nih");
    assert!(three_bytes.is_none());
}
/// `nihao` must come out as the phrase 你好.
#[cfg(not(feature = "bootstrap_only"))]
#[test]
fn best_composition_nihao_keeps_phrase() {
    let dict = PinyinDict::embedded();
    let best = dict.best_composition("nihao");
    let (_, chain) = best.expect("hit");
    assert_eq!(chain, "你好", "want 你好 as single phrase, got {chain:?}");
}
/// A longer buffer must segment into a plausible, purely CJK sentence.
#[cfg(not(feature = "bootstrap_only"))]
#[test]
fn best_composition_nihaomawojiao_segments() {
    let dict = PinyinDict::embedded();
    let (score, chain) = match dict.best_composition("nihaomawojiao") {
        Some(hit) => hit,
        None => panic!("expected some segmentation for nihaomawojiao"),
    };
    eprintln!("nihaomawojiao → {chain:?} (score {score})");
    let only_cjk = chain.chars().all(|c| ('\u{4e00}'..='\u{9fff}').contains(&c));
    assert!(only_cjk, "expected pure-CJK segmentation, got {chain:?}");
    let char_count = chain.chars().count();
    assert!(
        (4..=7).contains(&char_count),
        "expected 4-7 CJK chars, got {char_count} in {chain:?}"
    );
}
// Latency gate for `predict_next_words_context`: each probe is timed ITER times and
// min / p50 / p95 / max are printed. Budgets are only enforced when `debug_assertions`
// is off; in debug builds the numbers are printed but can never fail the test.
#[cfg(not(feature = "bootstrap_only"))]
#[test]
fn perfgate_predict_next_words_under_budget() {
let d = PinyinDict::embedded();
const ITER: usize = 30;
// MIN_BUDGET_NS bounds the fastest run, MAX_BUDGET_NS bounds the p95.
const MIN_BUDGET_NS: u128 = 2_000_000; const MAX_BUDGET_NS: u128 = 5_000_000;
// Each probe is (prev_prev, prev, limit, label).
// NOTE(review): `predict_next_words_context` returns an empty list right away when
// `prev_prev` is `None`, so the "cold-bigram-*" probes never reach a bigram lookup —
// confirm whether they were meant to call `predict_next_words` instead.
let probes: &[(Option<&str>, &str, usize, &str)] = &[
(None, "的", 10, "cold-bigram-的"),
(None, "我", 10, "cold-bigram-我"),
(None, "今天", 10, "cold-bigram-今天"),
(Some("今天"), "的", 10, "chained-trigram-今天-的"),
(Some("锟斤拷"), "无关词", 10, "chained-empty-fast-bailout"),
(None, "的", 50, "cold-bigram-的-limit50"),
];
let mut all_passed = true;
for (prev_prev, prev, limit, label) in probes {
let mut times: Vec<u128> = Vec::with_capacity(ITER);
for _ in 0..ITER {
let start = std::time::Instant::now();
let _ = d.predict_next_words_context(*prev_prev, prev, *limit);
times.push(start.elapsed().as_nanos());
}
// Sort ascending so the percentiles can be read by index.
times.sort_unstable();
let min = times[0];
let p50 = times[times.len() / 2];
let p95 = times[(times.len() * 95) / 100];
let max = *times.last().unwrap();
eprintln!(
"perfgate-predict {label:>30}: min={:>5.2}ms p50={:>5.2}ms p95={:>5.2}ms max={:>5.2}ms",
min as f64 / 1_000_000.0,
p50 as f64 / 1_000_000.0,
p95 as f64 / 1_000_000.0,
max as f64 / 1_000_000.0,
);
// Budgets only apply to builds without debug assertions.
if !cfg!(debug_assertions) {
if min > MIN_BUDGET_NS {
eprintln!(" ^^ FAIL: min {:.2}ms exceeds {}ms uncontended budget",
min as f64 / 1_000_000.0, MIN_BUDGET_NS / 1_000_000);
all_passed = false;
}
if p95 > MAX_BUDGET_NS {
eprintln!(" ^^ FAIL: p95 {:.2}ms exceeds {}ms",
p95 as f64 / 1_000_000.0, MAX_BUDGET_NS / 1_000_000);
all_passed = false;
}
}
}
// Every probe is reported before failing, so one run shows all offenders.
assert!(all_passed || cfg!(debug_assertions),
"perfgate-predict failed — see eprintln above");
}
/// At least one of the common pairs 今天→的 / 今天→是 must receive a positive boost.
#[cfg(not(feature = "bootstrap_only"))]
#[test]
fn bigram_boost_positive_for_common_pair() {
    let dict = PinyinDict::embedded();
    let boost_de = dict.bigram_boost(Some("今天"), "的");
    let boost_shi = dict.bigram_boost(Some("今天"), "是");
    let any_positive = boost_de > 0.0 || boost_shi > 0.0;
    assert!(
        any_positive,
        "expected positive bigram boost for 今天→的/是, got de={boost_de} shi={boost_shi}"
    );
}
}