1use std::collections::HashMap;
26use std::sync::RwLock;
27
28use fst::{IntoStreamer, Map, Streamer};
29
30use crate::layer::{DEFAULT_LAYER_PREFS, LAYER_COUNT, Layer, unpack};
31
32const DICT_BYTES: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/wubi86.fst"));
33
34pub const PROMOTE_THRESHOLD: u32 = parse_threshold_const();
39
40const fn parse_threshold_const() -> u32 {
41 match option_env!("WUBI_PROMOTE_THRESHOLD") {
42 Some(s) => parse_u32_const(s),
43 None => 3,
44 }
45}
46
/// Parses a string of ASCII decimal digits into a `u32`, usable in `const`
/// context.
///
/// Leading zeros are accepted (`"0042"` parses as `42`).
///
/// # Panics
///
/// Panics (a compile error when evaluated in a `const`) if `s` is empty,
/// contains anything other than `0`-`9`, does not fit in a `u32`, or parses
/// to zero.
const fn parse_u32_const(s: &str) -> u32 {
    let bytes = s.as_bytes();
    if bytes.is_empty() {
        panic!("WUBI_PROMOTE_THRESHOLD must not be empty");
    }
    let mut i = 0;
    let mut n: u32 = 0;
    // Iterators are not available in `const fn`, hence the index loop.
    while i < bytes.len() {
        let b = bytes[i];
        if b < b'0' || b > b'9' {
            panic!("WUBI_PROMOTE_THRESHOLD must be ASCII digits");
        }
        let digit = (b - b'0') as u32;
        // Checked arithmetic: plain `n * 10 + digit` only traps when the
        // build has overflow checks enabled, so an over-long value could
        // otherwise wrap around to a small, valid-looking threshold.
        n = match n.checked_mul(10) {
            Some(shifted) => match shifted.checked_add(digit) {
                Some(total) => total,
                None => panic!("WUBI_PROMOTE_THRESHOLD must fit in a u32"),
            },
            None => panic!("WUBI_PROMOTE_THRESHOLD must fit in a u32"),
        };
        i += 1;
    }
    if n == 0 {
        panic!("WUBI_PROMOTE_THRESHOLD must be >= 1");
    }
    n
}
67
68#[derive(Debug, Clone)]
72pub struct L0Snapshot {
73 pub pins: Vec<(String, String)>,
76 pub pick_counts: Vec<(String, String, u32)>,
81 pub layer_prefs: [f64; LAYER_COUNT],
83}
84
85#[derive(Default)]
86struct L0Inner {
87 pins: HashMap<String, String>,
88 pick_counts: HashMap<(String, String), u32>,
89 layer_prefs: [f64; LAYER_COUNT],
90}
91
92impl L0Inner {
93 fn new() -> Self {
94 Self {
95 pins: HashMap::new(),
96 pick_counts: HashMap::new(),
97 layer_prefs: DEFAULT_LAYER_PREFS,
98 }
99 }
100}
101
102pub struct WubiDict {
110 map: Map<&'static [u8]>,
111 l0: RwLock<L0Inner>,
112}
113
114impl WubiDict {
115 pub fn embedded() -> Self {
119 Self {
120 map: Map::new(DICT_BYTES).expect("invalid embedded FST"),
121 l0: RwLock::new(L0Inner::new()),
122 }
123 }
124
125 pub fn len(&self) -> usize {
128 self.map.len()
129 }
130
131 pub fn is_empty(&self) -> bool {
135 self.map.len() == 0
136 }
137
138 pub fn l0_pin_count(&self) -> usize {
140 self.l0.read().map(|g| g.pins.len()).unwrap_or(0)
141 }
142
143 pub fn l0_pending_count(&self) -> usize {
145 self.l0.read().map(|g| g.pick_counts.len()).unwrap_or(0)
146 }
147
148 pub fn lookup(&self, code: &str) -> Vec<String> {
161 let mut out = Vec::new();
162 self.lookup_into(code, &mut out);
163 out
164 }
165
166 pub fn lookup_into(&self, code: &str, out: &mut Vec<String>) {
176 out.clear();
177
178 let lower = code.to_ascii_lowercase();
180 let mut prefix = lower.into_bytes();
181 let prefix_len = prefix.len();
182 prefix.push(0u8);
183
184 let mut upper = prefix.clone();
185 let last = upper.len() - 1;
186 upper[last] = 0x01;
187
188 let prefs = self
189 .l0
190 .read()
191 .map(|g| g.layer_prefs)
192 .unwrap_or(DEFAULT_LAYER_PREFS);
193
194 let mut scratch: Vec<(String, f64)> = Vec::with_capacity(8);
196 let mut stream = self
197 .map
198 .range()
199 .ge(prefix.as_slice())
200 .lt(upper.as_slice())
201 .into_stream();
202 while let Some((key, value)) = stream.next() {
203 if key.len() <= prefix_len + 1 {
204 continue;
205 }
206 let word_bytes = &key[prefix_len + 1..];
207 if let Ok(s) = core::str::from_utf8(word_bytes) {
208 let (layer, freq) = unpack(value);
209 let base = layer.base() as f64;
210 let pref = prefs[layer.as_index()];
211 scratch.push((s.to_string(), base * pref + freq as f64));
212 }
213 }
214 scratch.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
215
216 out.reserve(scratch.len());
217 for (w, _) in scratch.drain(..) {
218 out.push(w);
219 }
220
221 if let Ok(l0) = self.l0.read() {
223 if let Some(pref) = l0.pins.get(code) {
224 if let Some(idx) = out.iter().position(|w| w == pref) {
225 if idx > 0 {
226 let p = out.remove(idx);
227 out.insert(0, p);
228 }
229 }
230 }
231 }
232 }
233
234 pub fn lookup_with_meta(&self, code: &str) -> Vec<(String, Layer, u64)> {
238 let lower = code.to_ascii_lowercase();
239 let mut prefix = lower.into_bytes();
240 let prefix_len = prefix.len();
241 prefix.push(0u8);
242
243 let mut upper = prefix.clone();
244 let last = upper.len() - 1;
245 upper[last] = 0x01;
246
247 let mut stream = self
248 .map
249 .range()
250 .ge(prefix.as_slice())
251 .lt(upper.as_slice())
252 .into_stream();
253
254 let mut results = Vec::new();
255 while let Some((key, value)) = stream.next() {
256 if key.len() <= prefix_len + 1 {
257 continue;
258 }
259 let word_bytes = &key[prefix_len + 1..];
260 if let Ok(s) = core::str::from_utf8(word_bytes) {
261 let (layer, freq) = unpack(value);
262 results.push((s.to_string(), layer, freq));
263 }
264 }
265 results
266 }
267
268 pub fn prefix(&self, prefix: &str) -> Vec<(String, String)> {
272 let lower = prefix.to_ascii_lowercase();
273 let lo = lower.as_bytes().to_vec();
274 let hi = bump_last(&lo);
275
276 let prefs = self
277 .l0
278 .read()
279 .map(|g| g.layer_prefs)
280 .unwrap_or(DEFAULT_LAYER_PREFS);
281
282 let mut stream = self
283 .map
284 .range()
285 .ge(lo.as_slice())
286 .lt(hi.as_slice())
287 .into_stream();
288
289 let mut results: Vec<(String, String, f64)> = Vec::new();
290 while let Some((key, value)) = stream.next() {
291 let Some(sep) = key.iter().position(|b| *b == 0u8) else {
292 continue;
293 };
294 let (code_bytes, rest) = key.split_at(sep);
295 let word_bytes = &rest[1..];
296 if let (Ok(code), Ok(word)) = (
297 core::str::from_utf8(code_bytes),
298 core::str::from_utf8(word_bytes),
299 ) {
300 let (layer, freq) = unpack(value);
301 let score = layer.base() as f64 * prefs[layer.as_index()] + freq as f64;
302 results.push((code.to_string(), word.to_string(), score));
303 }
304 }
305 results.sort_by(|a, b| {
306 b.2.partial_cmp(&a.2)
307 .unwrap_or(std::cmp::Ordering::Equal)
308 .then(a.0.cmp(&b.0))
309 .then(a.1.cmp(&b.1))
310 });
311 results.into_iter().map(|(c, w, _)| (c, w)).collect()
312 }
313
314 pub fn record_pick(&self, code: &str, word: &str) -> bool {
327 if !self.exists_in_l1(code, word) {
328 return false;
329 }
330 let Ok(mut l0) = self.l0.write() else {
331 return false;
332 };
333 let key = (code.to_string(), word.to_string());
334 let count = l0.pick_counts.entry(key).or_insert(0);
335 *count += 1;
336 if *count >= PROMOTE_THRESHOLD {
337 l0.pins.insert(code.to_string(), word.to_string());
338 l0.pick_counts.retain(|(c, _), _| c != code);
339 return true;
340 }
341 false
342 }
343
344 pub fn pin(&self, code: &str, word: &str) -> bool {
347 if !self.exists_in_l1(code, word) {
348 return false;
349 }
350 let Ok(mut l0) = self.l0.write() else {
351 return false;
352 };
353 l0.pins.insert(code.to_string(), word.to_string());
354 l0.pick_counts.retain(|(c, _), _| c != code);
355 true
356 }
357
358 pub fn forget(&self, code: &str) -> bool {
361 let Ok(mut l0) = self.l0.write() else {
362 return false;
363 };
364 let had_pin = l0.pins.remove(code).is_some();
365 let len_before = l0.pick_counts.len();
366 l0.pick_counts.retain(|(c, _), _| c != code);
367 had_pin || l0.pick_counts.len() != len_before
368 }
369
370 pub fn set_layer_pref(&self, layer: Layer, multiplier: f64) {
373 let m = if multiplier.is_finite() && multiplier >= 0.0 {
374 multiplier
375 } else {
376 0.0
377 };
378 if let Ok(mut l0) = self.l0.write() {
379 l0.layer_prefs[layer.as_index()] = m;
380 }
381 }
382
383 pub fn layer_pref(&self, layer: Layer) -> f64 {
386 self.l0
387 .read()
388 .map(|g| g.layer_prefs[layer.as_index()])
389 .unwrap_or(DEFAULT_LAYER_PREFS[layer.as_index()])
390 }
391
392 pub fn export_l0(&self) -> L0Snapshot {
396 let Ok(l0) = self.l0.read() else {
397 return L0Snapshot {
398 pins: Vec::new(),
399 pick_counts: Vec::new(),
400 layer_prefs: DEFAULT_LAYER_PREFS,
401 };
402 };
403 L0Snapshot {
404 pins: l0
405 .pins
406 .iter()
407 .map(|(k, v)| (k.clone(), v.clone()))
408 .collect(),
409 pick_counts: l0
410 .pick_counts
411 .iter()
412 .map(|((c, w), n)| (c.clone(), w.clone(), *n))
413 .collect(),
414 layer_prefs: l0.layer_prefs,
415 }
416 }
417
418 pub fn import_l0(&self, snap: L0Snapshot) -> usize {
422 let valid_pins: Vec<(String, String)> = snap
424 .pins
425 .into_iter()
426 .filter(|(c, w)| self.exists_in_l1(c, w))
427 .collect();
428 let valid_counts: Vec<((String, String), u32)> = snap
429 .pick_counts
430 .into_iter()
431 .filter_map(|(c, w, n)| {
432 if self.exists_in_l1(&c, &w) {
433 Some(((c, w), n))
434 } else {
435 None
436 }
437 })
438 .collect();
439 let accepted = valid_pins.len();
440
441 let Ok(mut l0) = self.l0.write() else {
442 return 0;
443 };
444 l0.pins = valid_pins.into_iter().collect();
445 l0.pick_counts = valid_counts.into_iter().collect();
446 l0.layer_prefs = snap.layer_prefs;
447 accepted
448 }
449
450 fn exists_in_l1(&self, code: &str, word: &str) -> bool {
451 self.lookup_with_meta(code)
452 .iter()
453 .any(|(w, _, _)| w == word)
454 }
455}
456
/// Returns an exclusive upper bound for the range of byte strings that start
/// with `bytes`.
///
/// Trailing `0xFF` bytes cannot be incremented, so they are dropped and the
/// last remaining byte is incremented instead (`[0x61, 0xFF]` -> `[0x62]`).
/// Merely appending to such a prefix would yield a bound that excludes longer
/// keys sharing the prefix.
///
/// When `bytes` is empty or consists solely of `0xFF` there is no finite
/// upper bound; the input with one extra `0xFF` appended is returned as a
/// best-effort bound in that case.
fn bump_last(bytes: &[u8]) -> Vec<u8> {
    match bytes.iter().rposition(|&b| b != 0xFF) {
        Some(idx) => {
            let mut bound = bytes[..=idx].to_vec();
            // Cannot overflow: the byte at `idx` is not 0xFF.
            bound[idx] += 1;
            bound
        }
        None => {
            let mut bound = bytes.to_vec();
            bound.push(0xFF);
            bound
        }
    }
}
468
#[cfg(test)]
mod tests {
    use super::*;

    /// Picks `word` for `code` exactly `PROMOTE_THRESHOLD` times, asserting
    /// that only the final pick reports a promotion.
    ///
    /// The threshold is configurable at build time, so tests must not assume
    /// a fixed number of picks.
    fn pick_until_promoted(d: &WubiDict, code: &str, word: &str) {
        for _ in 1..PROMOTE_THRESHOLD {
            assert!(
                !d.record_pick(code, word),
                "{word} was promoted before reaching the threshold"
            );
        }
        assert!(
            d.record_pick(code, word),
            "{word} was not promoted at the threshold"
        );
    }

    #[test]
    fn embedded_loads() {
        let d = WubiDict::embedded();
        assert!(d.len() >= 50);
    }

    #[test]
    fn jianma1_g_returns_yi_first() {
        let d = WubiDict::embedded();
        let words = d.lookup("g");
        assert_eq!(words.first().map(String::as_str), Some("一"));
    }

    #[test]
    fn khlg_phrase_outranks_extension_char() {
        let d = WubiDict::embedded();
        let words = d.lookup("khlg");
        let zg = words.iter().position(|w| w == "中国");
        let ext = words.iter().position(|w| w == "䟧");
        if let (Some(zg), Some(ext)) = (zg, ext) {
            assert!(zg < ext, "中国 should rank above 䟧, got {words:?}");
        }
    }

    #[test]
    fn rrrr_keyname_outranks_phrase() {
        let d = WubiDict::embedded();
        let words = d.lookup("rrrr");
        let bai = words.iter().position(|w| w == "白");
        let zhua = words.iter().position(|w| w == "抓拍");
        if let (Some(bai), Some(zhua)) = (bai, zhua) {
            assert!(bai < zhua, "白 should rank above 抓拍, got {words:?}");
        }
    }

    #[test]
    fn record_pick_promotes_after_threshold() {
        let d = WubiDict::embedded();
        pick_until_promoted(&d, "khlg", "跑车");
        assert_eq!(d.lookup("khlg").first().map(String::as_str), Some("跑车"));
        assert_eq!(d.l0_pin_count(), 1);
        assert_eq!(d.l0_pending_count(), 0);
    }

    #[test]
    fn record_pick_resets_on_promotion_so_others_must_earn_threshold_again() {
        let d = WubiDict::embedded();
        pick_until_promoted(&d, "khlg", "跑车");
        // Until the challenger reaches the threshold, the existing pin holds.
        for _ in 1..PROMOTE_THRESHOLD {
            assert!(!d.record_pick("khlg", "中国"));
            assert_eq!(d.lookup("khlg").first().map(String::as_str), Some("跑车"));
        }
        assert!(d.record_pick("khlg", "中国"));
        assert_eq!(d.lookup("khlg").first().map(String::as_str), Some("中国"));
    }

    #[test]
    fn record_pick_rejects_unknown_word() {
        let d = WubiDict::embedded();
        for _ in 0..PROMOTE_THRESHOLD {
            assert!(!d.record_pick("khlg", "this_is_not_a_real_word"));
        }
        assert_eq!(d.l0_pin_count(), 0);
        assert_eq!(d.l0_pending_count(), 0);
    }

    #[test]
    fn pin_force_pins_without_counters() {
        let d = WubiDict::embedded();
        assert!(d.pin("khlg", "跑车"));
        assert_eq!(d.lookup("khlg").first().map(String::as_str), Some("跑车"));
    }

    #[test]
    fn forget_clears_pin_and_counters() {
        let d = WubiDict::embedded();
        d.pin("khlg", "跑车");
        d.record_pick("khlg", "中国");
        assert!(d.forget("khlg"));
        assert_eq!(d.lookup("khlg").first().map(String::as_str), Some("中国"));
        assert_eq!(d.l0_pin_count(), 0);
        assert_eq!(d.l0_pending_count(), 0);
    }

    #[test]
    fn layer_pref_can_demote_a_layer() {
        let d = WubiDict::embedded();
        d.set_layer_pref(Layer::Phrase, 0.0);
        d.set_layer_pref(Layer::Auto, 5.0);
        let words = d.lookup("khlg");
        let ext = words.iter().position(|w| w == "䟧");
        let zg = words.iter().position(|w| w == "中国");
        if let (Some(ext), Some(zg)) = (ext, zg) {
            assert!(
                ext < zg,
                "with Phrase=0 and Auto=5, 䟧 should outrank 中国, got {words:?}"
            );
        }
    }

    #[test]
    fn export_import_roundtrip() {
        let d = WubiDict::embedded();
        d.pin("khlg", "跑车");
        // With a threshold of 1 a single pick is promoted straight to a pin
        // instead of leaving a pending counter behind.
        let promoted = d.record_pick("wqvb", "您好");
        let (expected_pins, expected_pending) = if promoted { (2, 0) } else { (1, 1) };
        d.set_layer_pref(Layer::Phrase, 1.5);
        let snap = d.export_l0();
        assert_eq!(snap.pins.len(), expected_pins);
        assert_eq!(snap.pick_counts.len(), expected_pending);
        assert!((snap.layer_prefs[Layer::Phrase.as_index()] - 1.5).abs() < f64::EPSILON);

        d.forget("khlg");
        d.forget("wqvb");
        d.set_layer_pref(Layer::Phrase, 1.0);
        assert_eq!(d.l0_pin_count(), 0);

        let accepted = d.import_l0(snap);
        assert_eq!(accepted, expected_pins);
        assert_eq!(d.lookup("khlg").first().map(String::as_str), Some("跑车"));
        assert!((d.layer_pref(Layer::Phrase) - 1.5).abs() < f64::EPSILON);
    }

    #[test]
    fn import_drops_invalid_entries() {
        let d = WubiDict::embedded();
        let snap = L0Snapshot {
            pins: vec![
                ("khlg".into(), "中国".into()),
                ("khlg".into(), "bogus".into()),
            ],
            pick_counts: vec![("khlg".into(), "ghost".into(), 2)],
            layer_prefs: DEFAULT_LAYER_PREFS,
        };
        let accepted = d.import_l0(snap);
        assert_eq!(accepted, 1);
        assert_eq!(d.l0_pending_count(), 0);
    }

    #[test]
    fn set_layer_pref_clamps_negatives_and_nan() {
        let d = WubiDict::embedded();
        d.set_layer_pref(Layer::Phrase, -3.0);
        assert_eq!(d.layer_pref(Layer::Phrase), 0.0);
        d.set_layer_pref(Layer::Phrase, f64::NAN);
        assert_eq!(d.layer_pref(Layer::Phrase), 0.0);
    }

    // Compile-time guard: a zero threshold would promote on every pick.
    const _: () = assert!(PROMOTE_THRESHOLD >= 1);
}