Skip to main content

inputx_fsa/
dict.rs

1//! `Dict` — a one-code-to-many-items map: each byte *code* maps to an ordered
2//! list of `(item_bytes, u64)` pairs. Built for IME dictionaries where a code
3//! (a pinyin string / a wubi code) yields several candidate words each with a
4//! score.
5//!
6//! This is the two-level design that makes the index small: the [`Fsa`]
7//! automaton holds only the *codes* (heavy prefix sharing → tiny), and the
8//! items live once in a packed blob the automaton points into. Compared to
9//! flattening `code\0word → score` into a single automaton, this keeps the
10//! unique word bytes out of the state graph — on the shipped pinyin set it is
11//! ~3.98 MB vs ~7.4 MB flattened (and beats the general-purpose `fst` crate's
12//! 4.54 MB), at zero dependencies.
13//!
//! Combined buffer layout: `magic "IXDC" (4) · fsa_len u32 · fsa_bytes · blob`,
//! where `fsa_len` is the byte length of `fsa_bytes`, written little-endian.
//! Blob record (per code, addressed by the automaton's u64 value = byte
//! offset from the start of `blob`):
//! `n_items uvarint · [item_len uvarint · item_bytes · value uvarint]×n`.
17
18use alloc::vec::Vec;
19
20use crate::builder::{write_uvarint, Builder};
21use crate::reader::{rd_u32, rd_uvarint, Fsa, FsaError};
22
23const MAGIC: &[u8; 4] = b"IXDC";
24
25/// Accumulates `(code, item, value)` triples and serializes a [`Dict`].
26#[derive(Default)]
27pub struct DictBuilder {
28    triples: Vec<(Vec<u8>, Vec<u8>, u64)>,
29}
30
31impl DictBuilder {
32    pub fn new() -> Self {
33        Self {
34            triples: Vec::new(),
35        }
36    }
37
38    /// Add `code → (item, value)`. A code may be inserted many times (one per
39    /// item). Insertion order across codes does not matter (the builder
40    /// groups + sorts); within a code, items are stored sorted by value
41    /// descending, then item bytes — so position 0 is the top-scoring item
42    /// and the output is fully deterministic.
43    pub fn insert(&mut self, code: &[u8], item: &[u8], value: u64) {
44        self.triples.push((code.to_vec(), item.to_vec(), value));
45    }
46
47    pub fn finish(mut self) -> Vec<u8> {
48        // Group by code. Sort triples by (code, value desc, item) so each
49        // code's run is contiguous and internally ranked.
50        self.triples.sort_by(|a, b| {
51            a.0.cmp(&b.0)
52                .then(b.2.cmp(&a.2)) // value descending
53                .then(a.1.cmp(&b.1))
54        });
55
56        let mut blob: Vec<u8> = Vec::new();
57        let mut fsa = Builder::new();
58
59        let mut i = 0;
60        while i < self.triples.len() {
61            let code = &self.triples[i].0;
62            // span of this code
63            let mut j = i;
64            while j < self.triples.len() && &self.triples[j].0 == code {
65                j += 1;
66            }
67            let offset = blob.len() as u64;
68            fsa.insert(code, offset);
69            write_uvarint(&mut blob, (j - i) as u64);
70            for t in &self.triples[i..j] {
71                write_uvarint(&mut blob, t.1.len() as u64);
72                blob.extend_from_slice(&t.1);
73                write_uvarint(&mut blob, t.2);
74            }
75            i = j;
76        }
77
78        let fsa_bytes = fsa.finish();
79        let mut out =
80            Vec::with_capacity(8 + fsa_bytes.len() + blob.len());
81        out.extend_from_slice(MAGIC);
82        out.extend_from_slice(&(fsa_bytes.len() as u32).to_le_bytes());
83        out.extend_from_slice(&fsa_bytes);
84        out.extend_from_slice(&blob);
85        out
86    }
87}
88
89/// Read-only two-level dictionary over a byte container.
90pub struct Dict<D> {
91    data: D,
92    fsa_lo: usize,
93    fsa_hi: usize,
94    blob_lo: usize,
95}
96
97impl<D: AsRef<[u8]>> Dict<D> {
98    pub fn new(data: D) -> Result<Self, FsaError> {
99        let b = data.as_ref();
100        if b.len() < 8 {
101            return Err(FsaError::Truncated);
102        }
103        if &b[0..4] != MAGIC {
104            return Err(FsaError::BadMagic);
105        }
106        let fsa_len = rd_u32(b, 4) as usize;
107        let fsa_lo = 8;
108        let fsa_hi = fsa_lo + fsa_len;
109        if b.len() < fsa_hi {
110            return Err(FsaError::Truncated);
111        }
112        // Validate the embedded automaton up front.
113        Fsa::new(&b[fsa_lo..fsa_hi])?;
114        Ok(Self {
115            data,
116            fsa_lo,
117            fsa_hi,
118            blob_lo: fsa_hi,
119        })
120    }
121
122    #[inline]
123    fn fsa(&self) -> Fsa<&[u8]> {
124        // Header already validated in `new`; this re-parse is ~constant.
125        Fsa::new(&self.data.as_ref()[self.fsa_lo..self.fsa_hi])
126            .expect("embedded fsa validated in Dict::new")
127    }
128
129    /// Number of distinct codes.
130    pub fn len(&self) -> u64 {
131        self.fsa().len()
132    }
133
134    pub fn is_empty(&self) -> bool {
135        self.len() == 0
136    }
137
138    /// Items for an exact `code`, in stored (value-desc) order. Empty if the
139    /// code is absent. Allocates a `Vec`; hot paths should prefer the
140    /// allocation-free [`get_for_each`](Self::get_for_each).
141    pub fn get(&self, code: &[u8]) -> Vec<(Vec<u8>, u64)> {
142        let mut out = Vec::new();
143        self.get_for_each(code, |item, val| out.push((item.to_vec(), val)));
144        out
145    }
146
147    /// Streaming variant of [`get`](Self::get): invoke `visit(item, value)`
148    /// for each item of an exact `code` with no result allocation and no
149    /// per-item copy (the `item` slice is valid only for the call). This is
150    /// the per-keystroke entry the IME dict layer should use.
151    pub fn get_for_each<F: FnMut(&[u8], u64)>(&self, code: &[u8], mut visit: F) {
152        let Some(off) = self.fsa().get(code) else {
153            return;
154        };
155        let b = self.data.as_ref();
156        let mut p = self.blob_lo + off as usize;
157        let Some(n) = rd_uvarint(b, &mut p) else { return };
158        for _ in 0..n {
159            let Some(len) = rd_uvarint(b, &mut p).map(|l| l as usize) else { return };
160            let Some(end) = p.checked_add(len) else { return };
161            let Some(item) = b.get(p..end) else { return };
162            p = end;
163            let Some(val) = rd_uvarint(b, &mut p) else { return };
164            visit(item, val);
165        }
166    }
167
168    /// `true` if any code starts with `prefix`.
169    pub fn contains_prefix(&self, prefix: &[u8]) -> bool {
170        self.fsa().contains_prefix(prefix)
171    }
172
173    /// All `(code, item, value)` triples whose code starts with `prefix`,
174    /// codes in sorted order, items in stored order within each code.
175    pub fn prefix(&self, prefix: &[u8]) -> Vec<(Vec<u8>, Vec<u8>, u64)> {
176        let mut out = Vec::new();
177        self.prefix_for_each(prefix, |code, item, val| {
178            out.push((code.to_vec(), item.to_vec(), val))
179        });
180        out
181    }
182
183    /// Streaming variant of [`prefix`](Self::prefix): invoke `visit(code,
184    /// item, value)` per item without materializing the result — the hot
185    /// path (a bare-letter code prefix can match tens of thousands of items).
186    /// The `code` and `item` slices are valid only for the call.
187    pub fn prefix_for_each<F: FnMut(&[u8], &[u8], u64)>(&self, prefix: &[u8], mut visit: F) {
188        let fsa = self.fsa();
189        let b = self.data.as_ref();
190        let blob_lo = self.blob_lo;
191        fsa.prefix_for_each(prefix, |code, off| {
192            let mut p = blob_lo + off as usize;
193            let Some(n) = rd_uvarint(b, &mut p) else { return };
194            for _ in 0..n {
195                let Some(len) = rd_uvarint(b, &mut p).map(|l| l as usize) else { return };
196                let Some(end) = p.checked_add(len) else { return };
197                let Some(item) = b.get(p..end) else { return };
198                p = end;
199                let Some(val) = rd_uvarint(b, &mut p) else { return };
200                visit(code, item, val);
201            }
202        });
203    }
204
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::collections::BTreeMap;

    /// UTF-8 bytes of `s` as an owned vector, to keep expectations readable.
    fn utf8(s: &str) -> Vec<u8> {
        s.as_bytes().to_vec()
    }

    /// Round-trip a tiny pinyin-like dictionary through build + read.
    #[test]
    fn basic() {
        let mut builder = DictBuilder::new();
        builder.insert(b"wo", "我".as_bytes(), 100);
        builder.insert(b"wo", "握".as_bytes(), 40);
        builder.insert(b"women", "我们".as_bytes(), 90);
        let dict = Dict::new(builder.finish()).unwrap();

        // Exact lookups: items come back ranked by value, highest first.
        let ranked = dict.get(b"wo");
        assert_eq!(ranked, vec![(utf8("我"), 100), (utf8("握"), 40)]);
        assert_eq!(dict.get(b"women"), vec![(utf8("我们"), 90)]);
        assert_eq!(dict.get(b"nope"), Vec::<(Vec<u8>, u64)>::new());
        assert_eq!(dict.len(), 2);

        // Prefix "wo" covers both codes, i.e. all three items.
        let hits = dict.prefix(b"wo");
        assert_eq!(hits.len(), 3);
        assert_eq!(hits[0].0, b"wo");
        assert!(dict.contains_prefix(b"wom"));
        assert!(!dict.contains_prefix(b"x"));
    }

    /// Awkward bytes (0x00 / 0xFF), extreme values and the empty code.
    #[test]
    fn edge_cases() {
        let mut builder = DictBuilder::new();
        builder.insert(b"a\x00b", b"\xff\x00", 0);
        builder.insert(b"a\x00b", b"item2", u64::MAX);
        // The empty code is a legal key.
        builder.insert(b"", b"empty-code", 65_536);
        let dict = Dict::new(builder.finish()).unwrap();

        let ranked = dict.get(b"a\x00b");
        assert_eq!(ranked.len(), 2);
        // Ranked by value descending, so the u64::MAX item leads.
        assert_eq!(ranked[0], (b"item2".to_vec(), u64::MAX));
        assert_eq!(ranked[1], (b"\xff\x00".to_vec(), 0));
        assert_eq!(dict.get(b""), vec![(b"empty-code".to_vec(), 65_536)]);
    }

    proptest! {
        #![proptest_config(ProptestConfig { cases: 200, ..ProptestConfig::default() })]

        /// Differential test: the dictionary must agree with a `BTreeMap`
        /// model for exact lookups, streaming lookups and prefix scans.
        #[test]
        fn diff_against_oracle(
            triples in proptest::collection::vec(
                (
                    proptest::collection::vec(b'a'..=b'd', 1..5),
                    proptest::collection::vec(b'x'..=b'z', 1..4),
                    any::<u64>(),
                ),
                0..48,
            ),
            probes in proptest::collection::vec(proptest::collection::vec(b'a'..=b'd', 0..5), 0..16),
        ) {
            // Collapse repeated (code, item) pairs; the last generated value wins.
            let mut latest: BTreeMap<(Vec<u8>, Vec<u8>), u64> = BTreeMap::new();
            for (code, item, value) in &triples {
                latest.insert((code.clone(), item.clone()), *value);
            }

            // Model: code → items ranked by (value desc, item bytes asc).
            let mut oracle: BTreeMap<Vec<u8>, Vec<(Vec<u8>, u64)>> = BTreeMap::new();
            for ((code, item), value) in &latest {
                oracle
                    .entry(code.clone())
                    .or_default()
                    .push((item.clone(), *value));
            }
            for ranked in oracle.values_mut() {
                ranked.sort_by(|(item_a, val_a), (item_b, val_b)| {
                    val_b.cmp(val_a).then(item_a.cmp(item_b))
                });
            }

            // Feed the builder the de-duplicated set so it sees exactly the
            // values the model holds.
            let mut builder = DictBuilder::new();
            for ((code, item), value) in &latest {
                builder.insert(code, item, *value);
            }
            let dict = Dict::new(builder.finish()).unwrap();

            prop_assert_eq!(dict.len(), oracle.len() as u64);
            for (code, expected) in &oracle {
                prop_assert_eq!(&dict.get(code), expected, "get {:?}", code);
                // The streaming API must report exactly what `get` returns.
                let mut streamed: Vec<(Vec<u8>, u64)> = Vec::new();
                dict.get_for_each(code, |item, value| streamed.push((item.to_vec(), value)));
                prop_assert_eq!(&streamed, expected, "get_for_each {:?}", code);
            }

            for p in &probes {
                let mut want: Vec<(Vec<u8>, Vec<u8>, u64)> = Vec::new();
                for (code, items) in &oracle {
                    if !code.starts_with(p) {
                        continue;
                    }
                    for (item, value) in items {
                        want.push((code.clone(), item.clone(), *value));
                    }
                }
                prop_assert_eq!(dict.prefix(p), want, "prefix {:?}", p);
            }
        }
    }
}