use std::sync::RwLock;
use fst::{IntoStreamer, Map, Streamer};
use crate::ranking::{L0Inner, L0Snapshot, PROMOTE_THRESHOLD};
/// Raw bytes of the dictionary FST, baked into the binary at compile time.
///
/// By default this is the FST checked in under `data/pinyin.fst`.
#[cfg(not(feature = "bootstrap_only"))]
const DICT_BYTES: &[u8] = include_bytes!("../data/pinyin.fst");
/// With the `bootstrap_only` feature, the FST is instead read from
/// `$OUT_DIR/bootstrap.fst` (presumably generated by the build script —
/// confirm against build.rs).
#[cfg(feature = "bootstrap_only")]
const DICT_BYTES: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/bootstrap.fst"));
/// Pinyin → word dictionary: an immutable embedded FST (the "L1" layer)
/// plus a small mutable in-memory override layer ("L0") for pins and
/// pick counters.
pub struct PinyinDict {
    /// Immutable FST. Each key is the byte string `pinyin \0 word`; the value
    /// is a frequency, and higher values sort first in `lookup`.
    map: Map<&'static [u8]>,
    /// Override layer: pinned word per pinyin reading plus pending
    /// `(pinyin, word)` pick counters; guarded for shared `&self` mutation.
    l0: RwLock<L0Inner>,
}
impl PinyinDict {
pub fn embedded() -> Self {
Self {
map: Map::new(DICT_BYTES).expect("invalid embedded pinyin FST"),
l0: RwLock::new(L0Inner::new()),
}
}
pub fn l0_pin_count(&self) -> usize {
self.l0.read().map(|g| g.pins.len()).unwrap_or(0)
}
pub fn l0_pending_count(&self) -> usize {
self.l0.read().map(|g| g.pick_counts.len()).unwrap_or(0)
}
pub fn len(&self) -> usize {
self.map.len()
}
pub fn is_empty(&self) -> bool {
self.map.len() == 0
}
pub fn lookup(&self, pinyin: &str) -> Vec<String> {
let mut out = Vec::new();
self.lookup_into(pinyin, &mut out);
out
}
pub fn lookup_into(&self, pinyin: &str, out: &mut Vec<String>) {
out.clear();
let lower = pinyin.to_ascii_lowercase();
let mut prefix = lower.into_bytes();
let prefix_len = prefix.len();
prefix.push(0u8);
let mut upper = prefix.clone();
let last = upper.len() - 1;
upper[last] = 0x01;
let mut scratch: Vec<(String, u64)> = Vec::with_capacity(8);
let mut stream = self
.map
.range()
.ge(prefix.as_slice())
.lt(upper.as_slice())
.into_stream();
while let Some((key, value)) = stream.next() {
if key.len() <= prefix_len + 1 {
continue;
}
let word_bytes = &key[prefix_len + 1..];
if let Ok(s) = core::str::from_utf8(word_bytes) {
scratch.push((s.to_string(), value));
}
}
scratch.sort_by_key(|s| std::cmp::Reverse(s.1));
out.reserve(scratch.len());
for (w, _) in scratch.drain(..) {
out.push(w);
}
if let Ok(l0) = self.l0.read()
&& let Some(pref) = l0.pins.get(&lower_str(pinyin))
&& let Some(idx) = out.iter().position(|w| w == pref)
&& idx > 0
{
let p = out.remove(idx);
out.insert(0, p);
}
}
pub fn prefix(&self, prefix: &str) -> Vec<(String, String)> {
let lower = prefix.to_ascii_lowercase();
let lo = lower.into_bytes();
let hi = bump_last(&lo);
let mut stream = self
.map
.range()
.ge(lo.as_slice())
.lt(hi.as_slice())
.into_stream();
let mut results: Vec<(String, String)> = Vec::new();
while let Some((key, _value)) = stream.next() {
let Some(sep) = key.iter().position(|b| *b == 0u8) else {
continue;
};
let (pinyin_bytes, rest) = key.split_at(sep);
let word_bytes = &rest[1..];
if let (Ok(pinyin), Ok(word)) = (
core::str::from_utf8(pinyin_bytes),
core::str::from_utf8(word_bytes),
) {
results.push((pinyin.to_string(), word.to_string()));
}
}
results.sort();
results
}
pub fn prefix_with_freq(&self, prefix: &str) -> Vec<(String, String, u64)> {
let lower = prefix.to_ascii_lowercase();
let lo = lower.into_bytes();
let hi = bump_last(&lo);
let mut stream = self
.map
.range()
.ge(lo.as_slice())
.lt(hi.as_slice())
.into_stream();
let mut results: Vec<(String, String, u64)> = Vec::new();
while let Some((key, value)) = stream.next() {
let Some(sep) = key.iter().position(|b| *b == 0u8) else {
continue;
};
let (pinyin_bytes, rest) = key.split_at(sep);
let word_bytes = &rest[1..];
if let (Ok(pinyin), Ok(word)) = (
core::str::from_utf8(pinyin_bytes),
core::str::from_utf8(word_bytes),
) {
results.push((pinyin.to_string(), word.to_string(), value));
}
}
results
}
pub fn record_pick(&self, pinyin: &str, word: &str) -> bool {
if !self.exists_in_l1(pinyin, word) {
return false;
}
let lower = lower_str(pinyin);
let Ok(mut l0) = self.l0.write() else {
return false;
};
let key = (lower.clone(), word.to_string());
let count = l0.pick_counts.entry(key).or_insert(0);
*count += 1;
if *count >= PROMOTE_THRESHOLD {
l0.pins.insert(lower.clone(), word.to_string());
l0.pick_counts.retain(|(p, _), _| p != &lower);
return true;
}
false
}
pub fn pin(&self, pinyin: &str, word: &str) -> bool {
if !self.exists_in_l1(pinyin, word) {
return false;
}
let lower = lower_str(pinyin);
let Ok(mut l0) = self.l0.write() else {
return false;
};
l0.pins.insert(lower.clone(), word.to_string());
l0.pick_counts.retain(|(p, _), _| p != &lower);
true
}
pub fn forget(&self, pinyin: &str) -> bool {
let lower = lower_str(pinyin);
let Ok(mut l0) = self.l0.write() else {
return false;
};
let had_pin = l0.pins.remove(&lower).is_some();
let len_before = l0.pick_counts.len();
l0.pick_counts.retain(|(p, _), _| p != &lower);
had_pin || l0.pick_counts.len() != len_before
}
pub fn export_l0(&self) -> L0Snapshot {
let Ok(l0) = self.l0.read() else {
return L0Snapshot::default();
};
L0Snapshot {
pins: l0
.pins
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect(),
pick_counts: l0
.pick_counts
.iter()
.map(|((p, w), n)| (p.clone(), w.clone(), *n))
.collect(),
}
}
pub fn import_l0(&self, snap: L0Snapshot) -> usize {
let valid_pins: Vec<(String, String)> = snap
.pins
.into_iter()
.filter(|(p, w)| self.exists_in_l1(p, w))
.collect();
let valid_counts: Vec<((String, String), u32)> = snap
.pick_counts
.into_iter()
.filter_map(|(p, w, n)| {
if self.exists_in_l1(&p, &w) {
Some(((p, w), n))
} else {
None
}
})
.collect();
let accepted = valid_pins.len();
let Ok(mut l0) = self.l0.write() else {
return 0;
};
l0.pins = valid_pins.into_iter().collect();
l0.pick_counts = valid_counts.into_iter().collect();
accepted
}
fn exists_in_l1(&self, pinyin: &str, word: &str) -> bool {
self.lookup(pinyin).iter().any(|w| w == word)
}
}
/// Normalizes a pinyin key by lowercasing ASCII letters; every other
/// character (including non-ASCII text) is copied through unchanged.
fn lower_str(s: &str) -> String {
    let mut normalized = s.to_owned();
    normalized.make_ascii_lowercase();
    normalized
}
/// Returns an exclusive upper bound for "all byte strings starting with
/// `bytes`": the smallest string greater than every such string.
///
/// Trailing `0xFF` bytes cannot be incremented, so they are dropped and the
/// byte before them is bumped instead (e.g. `b"a\xFF"` → `b"b"`). When no
/// byte can be bumped (empty input or all `0xFF`), there is no finite bound.
/// In that case the input with `0xFF` appended is returned. This is still a
/// correct bound when keys never contain a `0xFF` byte, as with valid UTF-8.
fn bump_last(bytes: &[u8]) -> Vec<u8> {
    match bytes.iter().rposition(|&b| b != 0xFF) {
        Some(i) => {
            let mut bound = bytes[..=i].to_vec();
            bound[i] += 1;
            bound
        }
        None => {
            let mut bound = bytes.to_vec();
            bound.push(0xFF);
            bound
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // All tests run against the embedded FST. Tests that rely on entries such
    // as 时/事 are gated on `not(feature = "bootstrap_only")`, presumably
    // because the bootstrap FST lacks them — confirm against the build data.

    #[test]
    fn embedded_loads() {
        let d = PinyinDict::embedded();
        assert!(d.len() >= 50, "bootstrap should have at least 50 entries");
    }

    #[test]
    fn lookup_zhongguo_returns_zhongguo() {
        let d = PinyinDict::embedded();
        let words = d.lookup("zhongguo");
        assert_eq!(words.first().map(String::as_str), Some("中国"));
    }

    #[test]
    fn lookup_wo_returns_wo_first() {
        let d = PinyinDict::embedded();
        let words = d.lookup("wo");
        assert_eq!(words.first().map(String::as_str), Some("我"));
    }

    #[test]
    fn lookup_shi_returns_multiple() {
        let d = PinyinDict::embedded();
        let words = d.lookup("shi");
        assert!(
            words.len() >= 3,
            "expected ≥3 candidates for shi, got {words:?}"
        );
        assert!(words.contains(&"是".to_string()));
    }

    #[test]
    fn lookup_unknown_returns_empty() {
        let d = PinyinDict::embedded();
        assert!(d.lookup("qzqzqz").is_empty());
    }

    #[test]
    fn case_insensitive() {
        let d = PinyinDict::embedded();
        assert_eq!(d.lookup("WO"), d.lookup("wo"));
        assert_eq!(d.lookup("ZhongGuo"), d.lookup("zhongguo"));
    }

    #[test]
    fn lookup_into_reuses_buffer() {
        let d = PinyinDict::embedded();
        let mut buf = Vec::with_capacity(16);
        d.lookup_into("ni", &mut buf);
        let cap_after_first = buf.capacity();
        d.lookup_into("ta", &mut buf);
        // The buffer is cleared and refilled, never shrunk.
        assert!(buf.capacity() >= cap_after_first);
        assert_eq!(buf.first().map(String::as_str), Some("他"));
    }

    #[test]
    fn prefix_returns_sorted_pairs() {
        let d = PinyinDict::embedded();
        let pairs = d.prefix("zhong");
        assert!(pairs.iter().any(|(p, _)| p == "zhong"));
        assert!(pairs.iter().any(|(p, w)| p == "zhongguo" && w == "中国"));
    }

    #[cfg(not(feature = "bootstrap_only"))]
    #[test]
    fn l0_starts_empty() {
        let d = PinyinDict::embedded();
        assert_eq!(d.l0_pin_count(), 0);
        assert_eq!(d.l0_pending_count(), 0);
    }

    #[cfg(not(feature = "bootstrap_only"))]
    #[test]
    fn record_pick_promotes_after_threshold() {
        let d = PinyinDict::embedded();
        let target = "时";
        for _ in 0..(PROMOTE_THRESHOLD - 1) {
            assert!(!d.record_pick("shi", target));
        }
        assert!(d.record_pick("shi", target), "should promote on Nth pick");
        assert_eq!(d.lookup("shi").first().map(String::as_str), Some(target));
        assert_eq!(d.l0_pin_count(), 1);
        assert_eq!(d.l0_pending_count(), 0);
    }

    #[cfg(not(feature = "bootstrap_only"))]
    #[test]
    fn record_pick_resets_on_promotion_so_others_must_earn_3_again() {
        let d = PinyinDict::embedded();
        for _ in 0..PROMOTE_THRESHOLD {
            d.record_pick("shi", "时");
        }
        // A single pick of a rival must not dislodge the fresh pin.
        assert!(!d.record_pick("shi", "事"));
        assert_eq!(d.lookup("shi").first().map(String::as_str), Some("时"));
        for _ in 0..(PROMOTE_THRESHOLD - 1) {
            d.record_pick("shi", "事");
        }
        assert_eq!(d.lookup("shi").first().map(String::as_str), Some("事"));
    }

    #[cfg(not(feature = "bootstrap_only"))]
    #[test]
    fn record_pick_rejects_unknown_word() {
        let d = PinyinDict::embedded();
        for _ in 0..PROMOTE_THRESHOLD {
            assert!(!d.record_pick("shi", "this_is_not_a_real_word"));
        }
        assert_eq!(d.l0_pin_count(), 0);
        assert_eq!(d.l0_pending_count(), 0);
    }

    #[cfg(not(feature = "bootstrap_only"))]
    #[test]
    fn pin_force_pins_without_counters() {
        let d = PinyinDict::embedded();
        assert!(d.pin("shi", "时"));
        assert_eq!(d.lookup("shi").first().map(String::as_str), Some("时"));
        assert_eq!(d.l0_pending_count(), 0);
    }

    #[cfg(not(feature = "bootstrap_only"))]
    #[test]
    fn forget_clears_pin_and_counters() {
        let d = PinyinDict::embedded();
        d.pin("shi", "时");
        d.record_pick("shi", "事");
        assert!(d.forget("shi"));
        assert_eq!(d.l0_pin_count(), 0);
        assert_eq!(d.l0_pending_count(), 0);
    }

    #[cfg(not(feature = "bootstrap_only"))]
    #[test]
    fn export_import_roundtrip() {
        let d = PinyinDict::embedded();
        d.pin("shi", "时");
        d.record_pick("zhongguo", "中国");
        let snap = d.export_l0();
        assert_eq!(snap.pins.len(), 1);
        assert_eq!(snap.pick_counts.len(), 1);
        d.forget("shi");
        d.forget("zhongguo");
        assert_eq!(d.l0_pin_count(), 0);
        let accepted = d.import_l0(snap);
        assert_eq!(accepted, 1);
        assert_eq!(d.lookup("shi").first().map(String::as_str), Some("时"));
    }

    #[cfg(not(feature = "bootstrap_only"))]
    #[test]
    fn import_drops_invalid_entries() {
        let d = PinyinDict::embedded();
        let snap = L0Snapshot {
            pins: vec![
                ("shi".into(), "时".into()),
                ("shi".into(), "bogus_word".into()),
            ],
            pick_counts: vec![("shi".into(), "ghost_word".into(), 2)],
        };
        let accepted = d.import_l0(snap);
        assert_eq!(accepted, 1);
        assert_eq!(d.l0_pending_count(), 0);
    }

    // Uses only entries present in both FST variants, so it is not gated.
    #[test]
    fn l0_pin_pin_lookup_compiles() {
        let d = PinyinDict::embedded();
        assert!(d.pin("zhongguo", "中国"));
        assert_eq!(d.l0_pin_count(), 1);
        assert!(d.forget("zhongguo"));
        assert_eq!(d.l0_pin_count(), 0);
    }
}