use std::collections::HashMap;
use std::sync::RwLock;
use inputx_fsa::Dict;
use crate::layer::{DEFAULT_LAYER_PREFS, LAYER_COUNT, Layer, unpack};
/// Compiled dictionary blob written to Cargo's `OUT_DIR` (the build-script
/// output directory) and embedded into the binary at compile time.
const DICT_BYTES: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/wubi86.dict"));
/// Number of `WubiDict::record_pick` calls for the same `(code, word)` needed
/// before that word is pinned for the code. Taken from the
/// `WUBI_PROMOTE_THRESHOLD` environment variable at compile time; 3 when unset.
pub const PROMOTE_THRESHOLD: u32 = parse_threshold_const();
/// Resolves the promotion threshold at compile time: the parsed value of
/// `WUBI_PROMOTE_THRESHOLD` if it was set when compiling, otherwise 3.
const fn parse_threshold_const() -> u32 {
    if let Some(raw) = option_env!("WUBI_PROMOTE_THRESHOLD") {
        parse_u32_const(raw)
    } else {
        3
    }
}
/// Parses a non-empty string of ASCII decimal digits into a `u32`.
///
/// Usable in const context; there, any panic below becomes a compile error
/// carrying the message.
///
/// # Panics
/// Panics if `s` is empty, contains a non-digit byte, does not fit in `u32`,
/// or parses to zero.
const fn parse_u32_const(s: &str) -> u32 {
    let bytes = s.as_bytes();
    if bytes.is_empty() {
        panic!("WUBI_PROMOTE_THRESHOLD must not be empty");
    }
    let mut i = 0;
    let mut n: u32 = 0;
    while i < bytes.len() {
        let b = bytes[i];
        if !b.is_ascii_digit() {
            panic!("WUBI_PROMOTE_THRESHOLD must be ASCII digits");
        }
        let digit = (b - b'0') as u32;
        // Checked arithmetic: an oversized value gets a descriptive message
        // instead of a generic overflow error (or a silent wrap in release).
        n = match n.checked_mul(10) {
            Some(scaled) => match scaled.checked_add(digit) {
                Some(sum) => sum,
                None => panic!("WUBI_PROMOTE_THRESHOLD does not fit in u32"),
            },
            None => panic!("WUBI_PROMOTE_THRESHOLD does not fit in u32"),
        };
        i += 1;
    }
    if n == 0 {
        panic!("WUBI_PROMOTE_THRESHOLD must be >= 1");
    }
    n
}
/// Plain-data copy of the user ("L0") state, produced by
/// `WubiDict::export_l0` and consumed by `WubiDict::import_l0`.
#[derive(Debug, Clone)]
pub struct L0Snapshot {
/// `(code, word)` pairs: the word pinned for each code.
pub pins: Vec<(String, String)>,
/// `(code, word, count)` triples: recorded picks that have not (yet) led to a pin.
pub pick_counts: Vec<(String, String, u32)>,
/// Per-layer score multipliers, indexed by `Layer::as_index()`.
pub layer_prefs: [f64; LAYER_COUNT],
}
/// Mutable user ("L0") state layered over the static dictionary.
struct L0Inner {
    /// code -> word pinned to the front of that code's candidates.
    pins: HashMap<String, String>,
    /// (code, word) -> number of recorded picks not yet promoted to a pin.
    pick_counts: HashMap<(String, String), u32>,
    /// Per-layer score multipliers, indexed by `Layer::as_index()`.
    layer_prefs: [f64; LAYER_COUNT],
}

impl L0Inner {
    /// Empty state using `DEFAULT_LAYER_PREFS` as the layer multipliers.
    fn new() -> Self {
        Self {
            pins: HashMap::new(),
            pick_counts: HashMap::new(),
            layer_prefs: DEFAULT_LAYER_PREFS,
        }
    }
}

impl Default for L0Inner {
    /// Same as [`L0Inner::new`]. A derived `Default` would set every layer
    /// multiplier to 0.0 instead of `DEFAULT_LAYER_PREFS`, silently diverging
    /// from `new()`.
    fn default() -> Self {
        Self::new()
    }
}
/// Wubi dictionary: an immutable embedded dictionary (referred to as "L1",
/// backed by `inputx_fsa::Dict`) plus a mutable user layer ("L0") holding
/// pins, pick counters and per-layer score multipliers.
pub struct WubiDict {
/// Static entries; each value unpacks (via `unpack`) into a layer and a frequency.
map: Dict<&'static [u8]>,
/// User overrides behind an `RwLock`, so readers can hold it concurrently.
l0: RwLock<L0Inner>,
}
impl WubiDict {
/// Builds a dictionary from the compiled-in `wubi86.dict` blob with an empty
/// user (L0) layer.
///
/// # Panics
/// Panics if `Dict::new` rejects the embedded bytes.
pub fn embedded() -> Self {
    let map = Dict::new(DICT_BYTES).expect("invalid embedded wubi dict");
    let l0 = RwLock::new(L0Inner::new());
    Self { map, l0 }
}
/// Number of entries in the static dictionary, as reported by `Dict::len`.
pub fn len(&self) -> usize {
    let n = self.map.len();
    n as usize
}

/// Whether the static dictionary has no entries.
pub fn is_empty(&self) -> bool {
    self.map.is_empty()
}

/// Number of codes that currently have a pinned word (0 if the L0 lock is poisoned).
pub fn l0_pin_count(&self) -> usize {
    match self.l0.read() {
        Ok(guard) => guard.pins.len(),
        Err(_) => 0,
    }
}

/// Number of `(code, word)` pick counters currently held (0 if the L0 lock is poisoned).
pub fn l0_pending_count(&self) -> usize {
    match self.l0.read() {
        Ok(guard) => guard.pick_counts.len(),
        Err(_) => 0,
    }
}
pub fn lookup(&self, code: &str) -> Vec<String> {
let mut out = Vec::new();
self.lookup_into(code, &mut out);
out
}
pub fn lookup_with_scores_into(&self, code: &str, out: &mut Vec<(String, f64)>) {
let mut layered = Vec::with_capacity(out.capacity());
self.lookup_with_layer_into(code, &mut layered);
out.clear();
out.reserve(layered.len());
for (w, score, _layer) in layered.drain(..) {
out.push((w, score));
}
}
/// Looks up `code` (ASCII case-insensitively) and writes ranked
/// `(word, score, layer)` candidates into `out`, best first. `out` is cleared.
///
/// Each entry's base score is `layer.base() * pref + freq`, where `pref` is the
/// current L0 multiplier for the entry's layer. Precedence is:
/// 1. the word pinned for this code, if any;
/// 2. on a full 4-letter code, single characters whose frequency exceeds that of
///    every multi-character candidate ("promoted");
/// 3. everything else by descending score. The sort is stable, so equal
///    scores keep dictionary order.
///
/// This is the same precedence `lookup_into` applies. The reported score still
/// carries the historical ×100 (promoted) and ×1000 (pinned) boosts. Ordering no
/// longer relies on them, because a multiplier cannot lift a zero score and a
/// strongly promoted character could outrank the pin.
pub fn lookup_with_layer_into(
    &self,
    code: &str,
    out: &mut Vec<(String, f64, Layer)>,
) {
    /// One dictionary hit plus the facts needed to rank it.
    struct Candidate {
        word: String,
        score: f64,
        single: bool,
        freq: u64,
        layer: Layer,
    }

    out.clear();
    let lower = code.to_ascii_lowercase();
    // Read L0 once so the multipliers and the pin come from the same state.
    let (prefs, pinned) = match self.l0.read() {
        Ok(g) => (g.layer_prefs, g.pins.get(&lower).cloned()),
        Err(_) => (DEFAULT_LAYER_PREFS, None),
    };
    let full_code = lower.len() == 4;
    let mut candidates: Vec<Candidate> = Vec::with_capacity(8);
    let mut max_phrase_freq: u64 = 0;
    self.map.get_for_each(lower.as_bytes(), |word, value| {
        if let Ok(s) = core::str::from_utf8(word) {
            let (layer, freq) = unpack(value);
            let single = s.chars().count() == 1;
            if !single && freq > max_phrase_freq {
                max_phrase_freq = freq;
            }
            candidates.push(Candidate {
                word: s.to_string(),
                score: layer.base() as f64 * prefs[layer.as_index()] + freq as f64,
                single,
                freq,
                layer,
            });
        }
    });

    // NaN scores (e.g. from a non-finite layer multiplier) rank last. A total
    // order is required because `sort_by` may panic on an inconsistent comparator.
    let order_key = |s: f64| if s.is_nan() { f64::NEG_INFINITY } else { s };
    let mut ranked: Vec<(bool, bool, Candidate)> = candidates
        .into_iter()
        .map(|mut c| {
            let promoted = full_code && c.single && c.freq > max_phrase_freq;
            let is_pin = pinned.as_deref() == Some(c.word.as_str());
            // Historical score boosts, kept so reported scores are unchanged.
            if promoted {
                c.score *= 100.0;
            }
            if is_pin {
                c.score *= 1000.0;
            }
            (is_pin, promoted, c)
        })
        .collect();
    ranked.sort_by(|a, b| {
        b.0.cmp(&a.0)
            .then(b.1.cmp(&a.1))
            .then_with(|| order_key(b.2.score).total_cmp(&order_key(a.2.score)))
    });
    out.reserve(ranked.len());
    out.extend(ranked.into_iter().map(|(_, _, c)| (c.word, c.score, c.layer)));
}
pub fn lookup_into(&self, code: &str, out: &mut Vec<String>) {
out.clear();
let lower = code.to_ascii_lowercase();
let prefix_len = lower.len();
let prefs = self
.l0
.read()
.map(|g| g.layer_prefs)
.unwrap_or(DEFAULT_LAYER_PREFS);
let full_code = prefix_len == 4;
let mut scratch: Vec<(String, f64, bool, u64)> = Vec::with_capacity(8);
let mut max_phrase_freq: u64 = 0;
self.map.get_for_each(lower.as_bytes(), |word, value| {
if let Ok(s) = core::str::from_utf8(word) {
let (layer, freq) = unpack(value);
let base = layer.base() as f64;
let pref = prefs[layer.as_index()];
let is_single = s.chars().count() == 1;
if !is_single && freq > max_phrase_freq {
max_phrase_freq = freq;
}
scratch.push((
s.to_string(),
base * pref + freq as f64,
is_single,
freq,
));
}
});
scratch.sort_by(|a, b| {
let a_promote = full_code && a.2 && a.3 > max_phrase_freq;
let b_promote = full_code && b.2 && b.3 > max_phrase_freq;
if a_promote != b_promote {
return if a_promote {
std::cmp::Ordering::Less
} else {
std::cmp::Ordering::Greater
};
}
b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
});
out.reserve(scratch.len());
for (w, _, _, _) in scratch.drain(..) {
out.push(w);
}
if let Ok(l0) = self.l0.read() {
if let Some(pref) = l0.pins.get(code) {
if let Some(idx) = out.iter().position(|w| w == pref) {
if idx > 0 {
let p = out.remove(idx);
out.insert(0, p);
}
}
}
}
}
pub fn lookup_with_meta(&self, code: &str) -> Vec<(String, Layer, u64)> {
let lower = code.to_ascii_lowercase();
let mut results = Vec::new();
self.map.get_for_each(lower.as_bytes(), |word, value| {
if let Ok(s) = core::str::from_utf8(word) {
let (layer, freq) = unpack(value);
results.push((s.to_string(), layer, freq));
}
});
results
}
/// Completions for a partial code: `(word, freq, code_len)` for every entry
/// whose code starts with `prefix` (ASCII case-insensitive) and is strictly
/// longer than it.
///
/// Results are sorted by descending frequency, then by word. Entries whose code
/// or word is not valid UTF-8 are skipped.
pub fn prefix_predictions(&self, prefix: &str) -> Vec<(String, u64, usize)> {
    let needle = prefix.to_ascii_lowercase();
    let min_len = needle.len();
    let mut predictions: Vec<(String, u64, usize)> = Vec::new();
    self.map.prefix_for_each(needle.as_bytes(), |code_bytes, word_bytes, value| {
        // Codes no longer than the prefix are not predictions.
        let longer = code_bytes.len() > min_len;
        if !longer || core::str::from_utf8(code_bytes).is_err() {
            return;
        }
        let Ok(word) = core::str::from_utf8(word_bytes) else {
            return;
        };
        let (_, freq) = unpack(value);
        predictions.push((word.to_string(), freq, code_bytes.len()));
    });
    predictions.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    predictions
}
/// Writes raw `(word, layer, freq)` entries for `code` (ASCII-lowercased) into
/// `out` in dictionary order, after clearing it. Entries whose word is not
/// valid UTF-8 are skipped.
pub fn lookup_with_freq_layer_into(
    &self,
    code: &str,
    out: &mut Vec<(String, Layer, u64)>,
) {
    out.clear();
    let key = code.to_ascii_lowercase();
    self.map.get_for_each(key.as_bytes(), |word, value| {
        let Ok(text) = core::str::from_utf8(word) else {
            return;
        };
        let (layer, freq) = unpack(value);
        out.push((text.to_owned(), layer, freq));
    });
}
/// Every `(code, word, layer, freq)` entry of the static dictionary, in the
/// order `prefix_for_each` visits them with an empty prefix. Entries whose code
/// or word is not valid UTF-8 are skipped.
pub fn all_entries(&self) -> Vec<(String, String, Layer, u64)> {
    let mut entries: Vec<(String, String, Layer, u64)> = Vec::new();
    self.map.prefix_for_each(b"", |code_bytes, word_bytes, value| {
        let decoded = (
            core::str::from_utf8(code_bytes),
            core::str::from_utf8(word_bytes),
        );
        let (Ok(code), Ok(word)) = decoded else {
            return;
        };
        let (layer, freq) = unpack(value);
        entries.push((code.to_owned(), word.to_owned(), layer, freq));
    });
    entries
}
/// All `(code, word)` entries whose code starts with `prefix` (ASCII
/// case-insensitive).
///
/// Results are ranked by descending `layer.base() * pref + freq`, then by code,
/// then by word. NaN scores rank last. Entries whose code or word is not valid
/// UTF-8 are skipped.
pub fn prefix(&self, prefix: &str) -> Vec<(String, String)> {
    let lower = prefix.to_ascii_lowercase();
    let prefs = self
        .l0
        .read()
        .map(|g| g.layer_prefs)
        .unwrap_or(DEFAULT_LAYER_PREFS);
    let mut results: Vec<(String, String, f64)> = Vec::new();
    self.map.prefix_for_each(lower.as_bytes(), |code_bytes, word_bytes, value| {
        if let (Ok(code), Ok(word)) = (
            core::str::from_utf8(code_bytes),
            core::str::from_utf8(word_bytes),
        ) {
            let (layer, freq) = unpack(value);
            let score = layer.base() as f64 * prefs[layer.as_index()] + freq as f64;
            results.push((code.to_string(), word.to_string(), score));
        }
    });
    // A total order is required: `sort_by` may panic if the comparator is
    // inconsistent, which `partial_cmp(..).unwrap_or(Equal)` is with NaN.
    let order_key = |s: f64| if s.is_nan() { f64::NEG_INFINITY } else { s };
    results.sort_by(|a, b| {
        order_key(b.2)
            .total_cmp(&order_key(a.2))
            .then_with(|| a.0.cmp(&b.0))
            .then_with(|| a.1.cmp(&b.1))
    });
    results.into_iter().map(|(c, w, _)| (c, w)).collect()
}
/// Records that `word` was chosen for `code`, pinning it once it has been
/// picked `PROMOTE_THRESHOLD` times.
///
/// `code` is normalized to ASCII lowercase so L0 keys agree with the lowercased
/// codes the lookup methods use. On promotion, every pick counter and any other
/// pin for the code (in any casing) is cleared, so a different word must earn
/// the full threshold again.
///
/// Returns `true` only when this pick caused the promotion. Returns `false` if
/// `(code, word)` is not in the static dictionary, the threshold has not been
/// reached, or the L0 lock is poisoned.
pub fn record_pick(&self, code: &str, word: &str) -> bool {
    if !self.exists_in_l1(code, word) {
        return false;
    }
    let Ok(mut l0) = self.l0.write() else {
        return false;
    };
    let key = code.to_ascii_lowercase();
    let count = l0
        .pick_counts
        .entry((key.clone(), word.to_string()))
        .or_insert(0);
    // Saturate: an imported counter could already sit at u32::MAX.
    *count = (*count).saturating_add(1);
    if *count < PROMOTE_THRESHOLD {
        return false;
    }
    l0.pick_counts.retain(|(c, _), _| !c.eq_ignore_ascii_case(&key));
    l0.pins.retain(|c, _| !c.eq_ignore_ascii_case(&key));
    l0.pins.insert(key, word.to_string());
    true
}

/// Immediately pins `word` for `code` (ASCII-lowercased), bypassing the pick
/// counters.
///
/// Clears every pick counter and any other pin for the code in any casing.
/// Returns `false` if `(code, word)` is not in the static dictionary or the L0
/// lock is poisoned.
pub fn pin(&self, code: &str, word: &str) -> bool {
    if !self.exists_in_l1(code, word) {
        return false;
    }
    let Ok(mut l0) = self.l0.write() else {
        return false;
    };
    let key = code.to_ascii_lowercase();
    l0.pick_counts.retain(|(c, _), _| !c.eq_ignore_ascii_case(&key));
    l0.pins.retain(|c, _| !c.eq_ignore_ascii_case(&key));
    l0.pins.insert(key, word.to_string());
    true
}

/// Removes the pin and all pick counters for `code`, matching keys
/// ASCII case-insensitively so older mixed-case entries are cleared too.
///
/// Returns `true` if anything was removed; `false` if nothing matched or the L0
/// lock is poisoned.
pub fn forget(&self, code: &str) -> bool {
    let Ok(mut l0) = self.l0.write() else {
        return false;
    };
    let pins_before = l0.pins.len();
    l0.pins.retain(|c, _| !c.eq_ignore_ascii_case(code));
    let counts_before = l0.pick_counts.len();
    l0.pick_counts.retain(|(c, _), _| !c.eq_ignore_ascii_case(code));
    l0.pins.len() != pins_before || l0.pick_counts.len() != counts_before
}
/// Sets the score multiplier for `layer`.
///
/// Values that are not finite or are negative (including NaN) are stored as
/// 0.0. Silently does nothing if the L0 lock is poisoned.
pub fn set_layer_pref(&self, layer: Layer, multiplier: f64) {
    let sanitized = match multiplier {
        m if m.is_finite() && m >= 0.0 => m,
        _ => 0.0,
    };
    let Ok(mut guard) = self.l0.write() else {
        return;
    };
    guard.layer_prefs[layer.as_index()] = sanitized;
}

/// Current score multiplier for `layer`, or its entry in `DEFAULT_LAYER_PREFS`
/// if the L0 lock is poisoned.
pub fn layer_pref(&self, layer: Layer) -> f64 {
    let idx = layer.as_index();
    match self.l0.read() {
        Ok(guard) => guard.layer_prefs[idx],
        Err(_) => DEFAULT_LAYER_PREFS[idx],
    }
}
/// Copies the current L0 state into an [`L0Snapshot`].
///
/// `pins` and `pick_counts` are sorted (by code, then word) so that exporting
/// the same state always yields the same order. Iterating the underlying
/// `HashMap`s directly would vary between runs. If the L0 lock is poisoned, an
/// empty snapshot with `DEFAULT_LAYER_PREFS` is returned.
pub fn export_l0(&self) -> L0Snapshot {
    let Ok(l0) = self.l0.read() else {
        return L0Snapshot {
            pins: Vec::new(),
            pick_counts: Vec::new(),
            layer_prefs: DEFAULT_LAYER_PREFS,
        };
    };
    let mut pins: Vec<(String, String)> = l0
        .pins
        .iter()
        .map(|(c, w)| (c.clone(), w.clone()))
        .collect();
    let mut pick_counts: Vec<(String, String, u32)> = l0
        .pick_counts
        .iter()
        .map(|((c, w), n)| (c.clone(), w.clone(), *n))
        .collect();
    let layer_prefs = l0.layer_prefs;
    // Release the lock before sorting; the copies are independent.
    drop(l0);
    pins.sort_unstable();
    pick_counts.sort_unstable();
    L0Snapshot {
        pins,
        pick_counts,
        layer_prefs,
    }
}
pub fn import_l0(&self, snap: L0Snapshot) -> usize {
let valid_pins: Vec<(String, String)> = snap
.pins
.into_iter()
.filter(|(c, w)| self.exists_in_l1(c, w))
.collect();
let valid_counts: Vec<((String, String), u32)> = snap
.pick_counts
.into_iter()
.filter_map(|(c, w, n)| {
if self.exists_in_l1(&c, &w) {
Some(((c, w), n))
} else {
None
}
})
.collect();
let accepted = valid_pins.len();
let Ok(mut l0) = self.l0.write() else {
return 0;
};
l0.pins = valid_pins.into_iter().collect();
l0.pick_counts = valid_counts.into_iter().collect();
l0.layer_prefs = snap.layer_prefs;
accepted
}
/// Whether the static dictionary ("L1") has `word` among the entries for
/// `code` (ASCII-lowercased).
///
/// Compares raw bytes instead of building a `Vec<String>` of every candidate.
/// The result is the same, because equal bytes to a `&str` are valid UTF-8.
fn exists_in_l1(&self, code: &str, word: &str) -> bool {
    let lower = code.to_ascii_lowercase();
    let needle = word.as_bytes();
    let mut found = false;
    self.map.get_for_each(lower.as_bytes(), |candidate, _| {
        found |= candidate == needle;
    });
    found
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn embedded_loads() {
        let d = WubiDict::embedded();
        assert!(d.len() >= 50);
    }

    #[test]
    fn jianma1_g_returns_yi_first() {
        let d = WubiDict::embedded();
        let words = d.lookup("g");
        assert_eq!(words.first().map(String::as_str), Some("一"));
    }

    #[test]
    fn khlg_phrase_outranks_extension_char() {
        let d = WubiDict::embedded();
        let words = d.lookup("khlg");
        let zg = words.iter().position(|w| w == "中国");
        let ext = words.iter().position(|w| w == "䟧");
        if let (Some(zg), Some(ext)) = (zg, ext) {
            assert!(zg < ext, "中国 should rank above 䟧, got {words:?}");
        }
    }

    #[test]
    fn rrrr_keyname_outranks_phrase() {
        let d = WubiDict::embedded();
        let words = d.lookup("rrrr");
        let bai = words.iter().position(|w| w == "白");
        let zhua = words.iter().position(|w| w == "抓拍");
        if let (Some(bai), Some(zhua)) = (bai, zhua) {
            assert!(bai < zhua, "白 should rank above 抓拍, got {words:?}");
        }
    }

    // The promotion tests loop over PROMOTE_THRESHOLD instead of hard-coding 3,
    // because the threshold can be overridden via WUBI_PROMOTE_THRESHOLD.
    #[test]
    fn record_pick_promotes_after_threshold() {
        let d = WubiDict::embedded();
        for _ in 1..PROMOTE_THRESHOLD {
            assert!(!d.record_pick("khlg", "跑车"));
        }
        assert!(d.record_pick("khlg", "跑车"));
        assert_eq!(d.lookup("khlg").first().map(String::as_str), Some("跑车"));
        assert_eq!(d.l0_pin_count(), 1);
        assert_eq!(d.l0_pending_count(), 0);
    }

    #[test]
    fn record_pick_resets_on_promotion_so_others_must_earn_threshold_again() {
        let d = WubiDict::embedded();
        for _ in 0..PROMOTE_THRESHOLD {
            d.record_pick("khlg", "跑车");
        }
        for _ in 1..PROMOTE_THRESHOLD {
            assert!(!d.record_pick("khlg", "中国"));
            assert_eq!(d.lookup("khlg").first().map(String::as_str), Some("跑车"));
        }
        assert!(d.record_pick("khlg", "中国"));
        assert_eq!(d.lookup("khlg").first().map(String::as_str), Some("中国"));
    }

    #[test]
    fn record_pick_rejects_unknown_word() {
        let d = WubiDict::embedded();
        for _ in 0..PROMOTE_THRESHOLD {
            assert!(!d.record_pick("khlg", "this_is_not_a_real_word"));
        }
        assert_eq!(d.l0_pin_count(), 0);
        assert_eq!(d.l0_pending_count(), 0);
    }

    #[test]
    fn pin_force_pins_without_counters() {
        let d = WubiDict::embedded();
        assert!(d.pin("khlg", "跑车"));
        assert_eq!(d.lookup("khlg").first().map(String::as_str), Some("跑车"));
    }

    #[test]
    fn forget_clears_pin_and_counters() {
        let d = WubiDict::embedded();
        d.pin("khlg", "跑车");
        d.record_pick("khlg", "中国");
        assert!(d.forget("khlg"));
        assert_eq!(d.lookup("khlg").first().map(String::as_str), Some("中国"));
        assert_eq!(d.l0_pin_count(), 0);
        assert_eq!(d.l0_pending_count(), 0);
    }

    #[test]
    fn layer_pref_can_demote_a_layer() {
        let d = WubiDict::embedded();
        d.set_layer_pref(Layer::Phrase, 0.0);
        d.set_layer_pref(Layer::Auto, 5.0);
        let words = d.lookup("khlg");
        let ext = words.iter().position(|w| w == "䟧");
        let zg = words.iter().position(|w| w == "中国");
        if let (Some(ext), Some(zg)) = (ext, zg) {
            assert!(
                ext < zg,
                "with Phrase=0 and Auto=5, 䟧 should outrank 中国, got {words:?}"
            );
        }
    }

    #[test]
    fn export_import_roundtrip() {
        let d = WubiDict::embedded();
        d.pin("khlg", "跑车");
        // With a threshold of 1, this single pick is promoted straight to a pin.
        let promoted = d.record_pick("wqvb", "您好");
        let (expected_pins, expected_pending) = if promoted { (2, 0) } else { (1, 1) };
        d.set_layer_pref(Layer::Phrase, 1.5);
        let snap = d.export_l0();
        assert_eq!(snap.pins.len(), expected_pins);
        assert_eq!(snap.pick_counts.len(), expected_pending);
        assert!((snap.layer_prefs[Layer::Phrase.as_index()] - 1.5).abs() < f64::EPSILON);
        d.forget("khlg");
        d.forget("wqvb");
        d.set_layer_pref(Layer::Phrase, 1.0);
        assert_eq!(d.l0_pin_count(), 0);
        let accepted = d.import_l0(snap);
        assert_eq!(accepted, expected_pins);
        assert_eq!(d.lookup("khlg").first().map(String::as_str), Some("跑车"));
        assert!((d.layer_pref(Layer::Phrase) - 1.5).abs() < f64::EPSILON);
    }

    #[test]
    fn import_drops_invalid_entries() {
        let d = WubiDict::embedded();
        let snap = L0Snapshot {
            pins: vec![
                ("khlg".into(), "中国".into()),
                ("khlg".into(), "bogus".into()),
            ],
            pick_counts: vec![("khlg".into(), "ghost".into(), 2)],
            layer_prefs: DEFAULT_LAYER_PREFS,
        };
        let accepted = d.import_l0(snap);
        assert_eq!(accepted, 1);
        assert_eq!(d.l0_pending_count(), 0);
    }

    #[test]
    fn set_layer_pref_clamps_negatives_and_nan() {
        let d = WubiDict::embedded();
        d.set_layer_pref(Layer::Phrase, -3.0);
        assert_eq!(d.layer_pref(Layer::Phrase), 0.0);
        d.set_layer_pref(Layer::Phrase, f64::NAN);
        assert_eq!(d.layer_pref(Layer::Phrase), 0.0);
    }

    const _: () = assert!(PROMOTE_THRESHOLD >= 1);
}