use std::collections::HashMap;
use std::sync::RwLock;
use fst::{IntoStreamer, Map, Streamer};
use crate::layer::{DEFAULT_LAYER_PREFS, LAYER_COUNT, Layer, unpack};
const DICT_BYTES: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/wubi86.fst"));
/// Number of `WubiDict::record_pick` calls on the same `(code, word)` pair
/// needed before that word is pinned as the top candidate for its code.
///
/// Defaults to 3. It can be overridden at build time by setting the
/// `WUBI_PROMOTE_THRESHOLD` environment variable to a positive decimal integer;
/// an invalid value fails compilation.
pub const PROMOTE_THRESHOLD: u32 = parse_threshold_const();
/// Resolves the promotion threshold during compilation.
///
/// `option_env!` reads `WUBI_PROMOTE_THRESHOLD` from the *build* environment.
/// When it is present, the value is validated by `parse_u32_const`; otherwise
/// the threshold falls back to 3 picks.
const fn parse_threshold_const() -> u32 {
    if let Some(raw) = option_env!("WUBI_PROMOTE_THRESHOLD") {
        return parse_u32_const(raw);
    }
    3
}
/// Parses a non-empty string of ASCII decimal digits into a positive `u32`.
///
/// This is a `const fn` so the `WUBI_PROMOTE_THRESHOLD` override can be validated
/// at compile time. Leading zeros are accepted (`"007"` parses as 7).
///
/// # Panics
///
/// Panics if `s` is empty, contains anything other than `0`-`9`, does not fit
/// in a `u32`, or parses to zero. In a `const` context each panic becomes a
/// compile error carrying the message below.
const fn parse_u32_const(s: &str) -> u32 {
    let bytes = s.as_bytes();
    if bytes.is_empty() {
        panic!("WUBI_PROMOTE_THRESHOLD must not be empty");
    }
    let mut i = 0;
    let mut n: u32 = 0;
    while i < bytes.len() {
        let b = bytes[i];
        if !b.is_ascii_digit() {
            panic!("WUBI_PROMOTE_THRESHOLD must be ASCII digits");
        }
        let digit = (b - b'0') as u32;
        // Checked arithmetic: plain `n * 10 + digit` would report a generic
        // arithmetic-overflow error that does not name the offending variable.
        // At runtime in release builds, it would silently wrap.
        n = match n.checked_mul(10) {
            Some(shifted) => match shifted.checked_add(digit) {
                Some(next) => next,
                None => panic!("WUBI_PROMOTE_THRESHOLD does not fit in u32"),
            },
            None => panic!("WUBI_PROMOTE_THRESHOLD does not fit in u32"),
        };
        i += 1;
    }
    if n == 0 {
        panic!("WUBI_PROMOTE_THRESHOLD must be >= 1");
    }
    n
}
/// Plain-data copy of a dictionary's mutable L0 state.
///
/// Produced by `WubiDict::export_l0` and consumed by `WubiDict::import_l0`,
/// e.g. to persist user learning between sessions.
#[derive(Debug, Clone)]
pub struct L0Snapshot {
    /// `(code, word)` pairs whose word is pinned as the top candidate for `code`.
    pub pins: Vec<(String, String)>,
    /// `(code, word, count)` picks that have not yet reached the promotion threshold.
    pub pick_counts: Vec<(String, String, u32)>,
    /// Per-layer score multipliers, indexed by `Layer::as_index()`.
    pub layer_prefs: [f64; LAYER_COUNT],
}
/// Mutable per-user overlay ("L0") kept alongside the immutable FST.
///
/// `Default` is implemented by hand rather than derived. A derived impl would
/// zero every layer preference, and since scores are `base * pref + freq` that
/// would silently drop all layer weighting. `new` starts from
/// `DEFAULT_LAYER_PREFS`, and `Default` delegates to it.
struct L0Inner {
    /// Word pinned as the top candidate, keyed by code.
    pins: HashMap<String, String>,
    /// Pending pick counts per `(code, word)` pair; cleared for a code once a
    /// word is pinned for it.
    pick_counts: HashMap<(String, String), u32>,
    /// Per-layer score multipliers, indexed by `Layer::as_index()`.
    layer_prefs: [f64; LAYER_COUNT],
}

impl L0Inner {
    /// Creates an empty overlay with the default layer preferences.
    fn new() -> Self {
        Self {
            pins: HashMap::new(),
            pick_counts: HashMap::new(),
            layer_prefs: DEFAULT_LAYER_PREFS,
        }
    }
}

impl Default for L0Inner {
    fn default() -> Self {
        Self::new()
    }
}
/// Wubi-86 dictionary backed by an FST embedded in the binary (the immutable
/// "L1" layer). On top of it sits a small mutable "L0" overlay of user pins,
/// pick counters and per-layer ranking preferences.
pub struct WubiDict {
    /// FST whose keys are `code`, a NUL byte, then `word`. Values pack a layer
    /// and a frequency, decoded with `unpack`.
    map: Map<&'static [u8]>,
    /// Mutable user state behind a lock so `&self` methods can update it.
    l0: RwLock<L0Inner>,
}
impl WubiDict {
pub fn embedded() -> Self {
Self {
map: Map::new(DICT_BYTES).expect("invalid embedded FST"),
l0: RwLock::new(L0Inner::new()),
}
}
pub fn len(&self) -> usize {
self.map.len()
}
pub fn is_empty(&self) -> bool {
self.map.len() == 0
}
pub fn l0_pin_count(&self) -> usize {
self.l0.read().map(|g| g.pins.len()).unwrap_or(0)
}
pub fn l0_pending_count(&self) -> usize {
self.l0.read().map(|g| g.pick_counts.len()).unwrap_or(0)
}
pub fn lookup(&self, code: &str) -> Vec<String> {
let mut out = Vec::new();
self.lookup_into(code, &mut out);
out
}
pub fn lookup_into(&self, code: &str, out: &mut Vec<String>) {
out.clear();
let lower = code.to_ascii_lowercase();
let mut prefix = lower.into_bytes();
let prefix_len = prefix.len();
prefix.push(0u8);
let mut upper = prefix.clone();
let last = upper.len() - 1;
upper[last] = 0x01;
let prefs = self
.l0
.read()
.map(|g| g.layer_prefs)
.unwrap_or(DEFAULT_LAYER_PREFS);
let mut scratch: Vec<(String, f64)> = Vec::with_capacity(8);
let mut stream = self
.map
.range()
.ge(prefix.as_slice())
.lt(upper.as_slice())
.into_stream();
while let Some((key, value)) = stream.next() {
if key.len() <= prefix_len + 1 {
continue;
}
let word_bytes = &key[prefix_len + 1..];
if let Ok(s) = core::str::from_utf8(word_bytes) {
let (layer, freq) = unpack(value);
let base = layer.base() as f64;
let pref = prefs[layer.as_index()];
scratch.push((s.to_string(), base * pref + freq as f64));
}
}
scratch.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
out.reserve(scratch.len());
for (w, _) in scratch.drain(..) {
out.push(w);
}
if let Ok(l0) = self.l0.read() {
if let Some(pref) = l0.pins.get(code) {
if let Some(idx) = out.iter().position(|w| w == pref) {
if idx > 0 {
let p = out.remove(idx);
out.insert(0, p);
}
}
}
}
}
pub fn lookup_with_meta(&self, code: &str) -> Vec<(String, Layer, u64)> {
let lower = code.to_ascii_lowercase();
let mut prefix = lower.into_bytes();
let prefix_len = prefix.len();
prefix.push(0u8);
let mut upper = prefix.clone();
let last = upper.len() - 1;
upper[last] = 0x01;
let mut stream = self
.map
.range()
.ge(prefix.as_slice())
.lt(upper.as_slice())
.into_stream();
let mut results = Vec::new();
while let Some((key, value)) = stream.next() {
if key.len() <= prefix_len + 1 {
continue;
}
let word_bytes = &key[prefix_len + 1..];
if let Ok(s) = core::str::from_utf8(word_bytes) {
let (layer, freq) = unpack(value);
results.push((s.to_string(), layer, freq));
}
}
results
}
pub fn prefix(&self, prefix: &str) -> Vec<(String, String)> {
let lower = prefix.to_ascii_lowercase();
let lo = lower.as_bytes().to_vec();
let hi = bump_last(&lo);
let prefs = self
.l0
.read()
.map(|g| g.layer_prefs)
.unwrap_or(DEFAULT_LAYER_PREFS);
let mut stream = self
.map
.range()
.ge(lo.as_slice())
.lt(hi.as_slice())
.into_stream();
let mut results: Vec<(String, String, f64)> = Vec::new();
while let Some((key, value)) = stream.next() {
let Some(sep) = key.iter().position(|b| *b == 0u8) else {
continue;
};
let (code_bytes, rest) = key.split_at(sep);
let word_bytes = &rest[1..];
if let (Ok(code), Ok(word)) = (
core::str::from_utf8(code_bytes),
core::str::from_utf8(word_bytes),
) {
let (layer, freq) = unpack(value);
let score = layer.base() as f64 * prefs[layer.as_index()] + freq as f64;
results.push((code.to_string(), word.to_string(), score));
}
}
results.sort_by(|a, b| {
b.2.partial_cmp(&a.2)
.unwrap_or(std::cmp::Ordering::Equal)
.then(a.0.cmp(&b.0))
.then(a.1.cmp(&b.1))
});
results.into_iter().map(|(c, w, _)| (c, w)).collect()
}
pub fn record_pick(&self, code: &str, word: &str) -> bool {
if !self.exists_in_l1(code, word) {
return false;
}
let Ok(mut l0) = self.l0.write() else {
return false;
};
let key = (code.to_string(), word.to_string());
let count = l0.pick_counts.entry(key).or_insert(0);
*count += 1;
if *count >= PROMOTE_THRESHOLD {
l0.pins.insert(code.to_string(), word.to_string());
l0.pick_counts.retain(|(c, _), _| c != code);
return true;
}
false
}
pub fn pin(&self, code: &str, word: &str) -> bool {
if !self.exists_in_l1(code, word) {
return false;
}
let Ok(mut l0) = self.l0.write() else {
return false;
};
l0.pins.insert(code.to_string(), word.to_string());
l0.pick_counts.retain(|(c, _), _| c != code);
true
}
pub fn forget(&self, code: &str) -> bool {
let Ok(mut l0) = self.l0.write() else {
return false;
};
let had_pin = l0.pins.remove(code).is_some();
let len_before = l0.pick_counts.len();
l0.pick_counts.retain(|(c, _), _| c != code);
had_pin || l0.pick_counts.len() != len_before
}
pub fn set_layer_pref(&self, layer: Layer, multiplier: f64) {
let m = if multiplier.is_finite() && multiplier >= 0.0 {
multiplier
} else {
0.0
};
if let Ok(mut l0) = self.l0.write() {
l0.layer_prefs[layer.as_index()] = m;
}
}
pub fn layer_pref(&self, layer: Layer) -> f64 {
self.l0
.read()
.map(|g| g.layer_prefs[layer.as_index()])
.unwrap_or(DEFAULT_LAYER_PREFS[layer.as_index()])
}
pub fn export_l0(&self) -> L0Snapshot {
let Ok(l0) = self.l0.read() else {
return L0Snapshot {
pins: Vec::new(),
pick_counts: Vec::new(),
layer_prefs: DEFAULT_LAYER_PREFS,
};
};
L0Snapshot {
pins: l0
.pins
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect(),
pick_counts: l0
.pick_counts
.iter()
.map(|((c, w), n)| (c.clone(), w.clone(), *n))
.collect(),
layer_prefs: l0.layer_prefs,
}
}
pub fn import_l0(&self, snap: L0Snapshot) -> usize {
let valid_pins: Vec<(String, String)> = snap
.pins
.into_iter()
.filter(|(c, w)| self.exists_in_l1(c, w))
.collect();
let valid_counts: Vec<((String, String), u32)> = snap
.pick_counts
.into_iter()
.filter_map(|(c, w, n)| {
if self.exists_in_l1(&c, &w) {
Some(((c, w), n))
} else {
None
}
})
.collect();
let accepted = valid_pins.len();
let Ok(mut l0) = self.l0.write() else {
return 0;
};
l0.pins = valid_pins.into_iter().collect();
l0.pick_counts = valid_counts.into_iter().collect();
l0.layer_prefs = snap.layer_prefs;
accepted
}
fn exists_in_l1(&self, code: &str, word: &str) -> bool {
self.lookup_with_meta(code)
.iter()
.any(|(w, _, _)| w == word)
}
}
/// Returns the smallest byte string that sorts strictly after every string
/// starting with `bytes`. It is used as the exclusive upper bound of a prefix
/// range scan.
///
/// A trailing `0xFF` cannot be incremented, so trailing `0xFF` bytes are
/// dropped and the last remaining byte is bumped instead (for example,
/// `[0x61, 0xFF]` becomes `[0x62]`).
///
/// When `bytes` is empty or consists only of `0xFF`, no finite bound exists.
/// The result is then `bytes` followed by one `0xFF`, which excludes keys that
/// continue with `0xFF`. This cannot happen for UTF-8 text, because `0xFF` is
/// never a valid UTF-8 byte.
fn bump_last(bytes: &[u8]) -> Vec<u8> {
    match bytes.iter().rposition(|&b| b != 0xFF) {
        Some(idx) => {
            let mut upper = bytes[..=idx].to_vec();
            upper[idx] += 1;
            upper
        }
        None => {
            let mut upper = bytes.to_vec();
            upper.push(0xFF);
            upper
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn embedded_loads() {
        let d = WubiDict::embedded();
        assert!(d.len() >= 50);
    }

    #[test]
    fn jianma1_g_returns_yi_first() {
        let d = WubiDict::embedded();
        let words = d.lookup("g");
        assert_eq!(words.first().map(String::as_str), Some("一"));
    }

    #[test]
    fn khlg_phrase_outranks_extension_char() {
        let d = WubiDict::embedded();
        let words = d.lookup("khlg");
        let zg = words.iter().position(|w| w == "中国");
        let ext = words.iter().position(|w| w == "䟧");
        if let (Some(zg), Some(ext)) = (zg, ext) {
            assert!(zg < ext, "中国 should rank above 䟧, got {words:?}");
        }
    }

    #[test]
    fn rrrr_keyname_outranks_phrase() {
        let d = WubiDict::embedded();
        let words = d.lookup("rrrr");
        let bai = words.iter().position(|w| w == "白");
        let zhua = words.iter().position(|w| w == "抓拍");
        if let (Some(bai), Some(zhua)) = (bai, zhua) {
            assert!(bai < zhua, "白 should rank above 抓拍, got {words:?}");
        }
    }

    // The tests below derive pick counts from PROMOTE_THRESHOLD instead of
    // hard-coding 3, so they stay valid when WUBI_PROMOTE_THRESHOLD is overridden.

    #[test]
    fn record_pick_promotes_after_threshold() {
        let d = WubiDict::embedded();
        for _ in 1..PROMOTE_THRESHOLD {
            assert!(!d.record_pick("khlg", "跑车"));
        }
        assert!(d.record_pick("khlg", "跑车"));
        assert_eq!(d.lookup("khlg").first().map(String::as_str), Some("跑车"));
        assert_eq!(d.l0_pin_count(), 1);
        assert_eq!(d.l0_pending_count(), 0);
    }

    #[test]
    fn record_pick_resets_on_promotion_so_others_must_earn_threshold_again() {
        let d = WubiDict::embedded();
        for _ in 0..PROMOTE_THRESHOLD {
            d.record_pick("khlg", "跑车");
        }
        for _ in 1..PROMOTE_THRESHOLD {
            assert!(!d.record_pick("khlg", "中国"));
            assert_eq!(d.lookup("khlg").first().map(String::as_str), Some("跑车"));
        }
        assert!(d.record_pick("khlg", "中国"));
        assert_eq!(d.lookup("khlg").first().map(String::as_str), Some("中国"));
    }

    #[test]
    fn record_pick_rejects_unknown_word() {
        let d = WubiDict::embedded();
        for _ in 0..PROMOTE_THRESHOLD {
            assert!(!d.record_pick("khlg", "this_is_not_a_real_word"));
        }
        assert_eq!(d.l0_pin_count(), 0);
        assert_eq!(d.l0_pending_count(), 0);
    }

    #[test]
    fn pin_force_pins_without_counters() {
        let d = WubiDict::embedded();
        assert!(d.pin("khlg", "跑车"));
        assert_eq!(d.lookup("khlg").first().map(String::as_str), Some("跑车"));
    }

    #[test]
    fn forget_clears_pin_and_counters() {
        let d = WubiDict::embedded();
        d.pin("khlg", "跑车");
        d.record_pick("khlg", "中国");
        assert!(d.forget("khlg"));
        assert_eq!(d.lookup("khlg").first().map(String::as_str), Some("中国"));
        assert_eq!(d.l0_pin_count(), 0);
        assert_eq!(d.l0_pending_count(), 0);
    }

    #[test]
    fn layer_pref_can_demote_a_layer() {
        let d = WubiDict::embedded();
        d.set_layer_pref(Layer::Phrase, 0.0);
        d.set_layer_pref(Layer::Auto, 5.0);
        let words = d.lookup("khlg");
        let ext = words.iter().position(|w| w == "䟧");
        let zg = words.iter().position(|w| w == "中国");
        if let (Some(ext), Some(zg)) = (ext, zg) {
            assert!(
                ext < zg,
                "with Phrase=0 and Auto=5, 䟧 should outrank 中国, got {words:?}"
            );
        }
    }

    #[test]
    fn export_import_roundtrip() {
        let d = WubiDict::embedded();
        d.pin("khlg", "跑车");
        d.record_pick("wqvb", "您好");
        d.set_layer_pref(Layer::Phrase, 1.5);
        // A single pick only stays pending when the threshold is above 1;
        // otherwise it is promoted straight to a pin.
        let (want_pins, want_pending) = if PROMOTE_THRESHOLD > 1 { (1, 1) } else { (2, 0) };
        let snap = d.export_l0();
        assert_eq!(snap.pins.len(), want_pins);
        assert_eq!(snap.pick_counts.len(), want_pending);
        assert!((snap.layer_prefs[Layer::Phrase.as_index()] - 1.5).abs() < f64::EPSILON);
        d.forget("khlg");
        d.forget("wqvb");
        d.set_layer_pref(Layer::Phrase, 1.0);
        assert_eq!(d.l0_pin_count(), 0);
        let accepted = d.import_l0(snap);
        assert_eq!(accepted, want_pins);
        assert_eq!(d.lookup("khlg").first().map(String::as_str), Some("跑车"));
        assert!((d.layer_pref(Layer::Phrase) - 1.5).abs() < f64::EPSILON);
    }

    #[test]
    fn import_drops_invalid_entries() {
        let d = WubiDict::embedded();
        let snap = L0Snapshot {
            pins: vec![
                ("khlg".into(), "中国".into()),
                ("khlg".into(), "bogus".into()),
            ],
            pick_counts: vec![("khlg".into(), "ghost".into(), 2)],
            layer_prefs: DEFAULT_LAYER_PREFS,
        };
        let accepted = d.import_l0(snap);
        assert_eq!(accepted, 1);
        assert_eq!(d.l0_pending_count(), 0);
    }

    #[test]
    fn set_layer_pref_clamps_negatives_and_nan() {
        let d = WubiDict::embedded();
        d.set_layer_pref(Layer::Phrase, -3.0);
        assert_eq!(d.layer_pref(Layer::Phrase), 0.0);
        d.set_layer_pref(Layer::Phrase, f64::NAN);
        assert_eq!(d.layer_pref(Layer::Phrase), 0.0);
    }

    #[test]
    fn promote_threshold_is_at_least_one() {
        assert!(PROMOTE_THRESHOLD >= 1);
    }
}