1use alloc::vec::Vec;
19
20use crate::builder::{write_uvarint, Builder};
21use crate::reader::{rd_u32, rd_uvarint, Fsa, FsaError};
22
/// Four-byte signature at the start of every serialized dictionary image.
const MAGIC: &[u8; 4] = &[b'I', b'X', b'D', b'C'];
24
/// Accumulates `(code, item, value)` entries in memory and serializes them into
/// the byte image read by [`Dict`].
#[derive(Debug, Clone, Default)]
pub struct DictBuilder {
    /// Queued entries as `(code, item, value)`, in insertion order.
    triples: Vec<(Vec<u8>, Vec<u8>, u64)>,
}
30
31impl DictBuilder {
32 pub fn new() -> Self {
33 Self {
34 triples: Vec::new(),
35 }
36 }
37
38 pub fn insert(&mut self, code: &[u8], item: &[u8], value: u64) {
44 self.triples.push((code.to_vec(), item.to_vec(), value));
45 }
46
47 pub fn finish(mut self) -> Vec<u8> {
48 self.triples.sort_by(|a, b| {
51 a.0.cmp(&b.0)
52 .then(b.2.cmp(&a.2)) .then(a.1.cmp(&b.1))
54 });
55
56 let mut blob: Vec<u8> = Vec::new();
57 let mut fsa = Builder::new();
58
59 let mut i = 0;
60 while i < self.triples.len() {
61 let code = &self.triples[i].0;
62 let mut j = i;
64 while j < self.triples.len() && &self.triples[j].0 == code {
65 j += 1;
66 }
67 let offset = blob.len() as u64;
68 fsa.insert(code, offset);
69 write_uvarint(&mut blob, (j - i) as u64);
70 for t in &self.triples[i..j] {
71 write_uvarint(&mut blob, t.1.len() as u64);
72 blob.extend_from_slice(&t.1);
73 write_uvarint(&mut blob, t.2);
74 }
75 i = j;
76 }
77
78 let fsa_bytes = fsa.finish();
79 let mut out =
80 Vec::with_capacity(8 + fsa_bytes.len() + blob.len());
81 out.extend_from_slice(MAGIC);
82 out.extend_from_slice(&(fsa_bytes.len() as u32).to_le_bytes());
83 out.extend_from_slice(&fsa_bytes);
84 out.extend_from_slice(&blob);
85 out
86 }
87}
88
/// Read-only view over a serialized dictionary image held in any byte
/// container `D` (e.g. `Vec<u8>` or `&[u8]`).
///
/// The image is a header, an embedded FSA section, and an entry blob; the
/// fields below record where each section lives inside `data`.
#[derive(Clone, Debug)]
pub struct Dict<D> {
    /// Backing bytes of the whole image.
    data: D,
    /// Start of the FSA section within `data` (just past the header).
    fsa_lo: usize,
    /// End (exclusive) of the FSA section within `data`.
    fsa_hi: usize,
    /// Start of the entry blob; offsets stored in the FSA are relative to this.
    blob_lo: usize,
}
96
97impl<D: AsRef<[u8]>> Dict<D> {
98 pub fn new(data: D) -> Result<Self, FsaError> {
99 let b = data.as_ref();
100 if b.len() < 8 {
101 return Err(FsaError::Truncated);
102 }
103 if &b[0..4] != MAGIC {
104 return Err(FsaError::BadMagic);
105 }
106 let fsa_len = rd_u32(b, 4) as usize;
107 let fsa_lo = 8;
108 let fsa_hi = fsa_lo + fsa_len;
109 if b.len() < fsa_hi {
110 return Err(FsaError::Truncated);
111 }
112 Fsa::new(&b[fsa_lo..fsa_hi])?;
114 Ok(Self {
115 data,
116 fsa_lo,
117 fsa_hi,
118 blob_lo: fsa_hi,
119 })
120 }
121
122 #[inline]
123 fn fsa(&self) -> Fsa<&[u8]> {
124 Fsa::new(&self.data.as_ref()[self.fsa_lo..self.fsa_hi])
126 .expect("embedded fsa validated in Dict::new")
127 }
128
129 pub fn len(&self) -> u64 {
131 self.fsa().len()
132 }
133
134 pub fn is_empty(&self) -> bool {
135 self.len() == 0
136 }
137
138 pub fn get(&self, code: &[u8]) -> Vec<(Vec<u8>, u64)> {
142 let mut out = Vec::new();
143 self.get_for_each(code, |item, val| out.push((item.to_vec(), val)));
144 out
145 }
146
147 pub fn get_for_each<F: FnMut(&[u8], u64)>(&self, code: &[u8], mut visit: F) {
152 let Some(off) = self.fsa().get(code) else {
153 return;
154 };
155 let b = self.data.as_ref();
156 let mut p = self.blob_lo + off as usize;
157 let Some(n) = rd_uvarint(b, &mut p) else { return };
158 for _ in 0..n {
159 let Some(len) = rd_uvarint(b, &mut p).map(|l| l as usize) else { return };
160 let Some(end) = p.checked_add(len) else { return };
161 let Some(item) = b.get(p..end) else { return };
162 p = end;
163 let Some(val) = rd_uvarint(b, &mut p) else { return };
164 visit(item, val);
165 }
166 }
167
168 pub fn contains_prefix(&self, prefix: &[u8]) -> bool {
170 self.fsa().contains_prefix(prefix)
171 }
172
173 pub fn prefix(&self, prefix: &[u8]) -> Vec<(Vec<u8>, Vec<u8>, u64)> {
176 let mut out = Vec::new();
177 self.prefix_for_each(prefix, |code, item, val| {
178 out.push((code.to_vec(), item.to_vec(), val))
179 });
180 out
181 }
182
183 pub fn prefix_for_each<F: FnMut(&[u8], &[u8], u64)>(&self, prefix: &[u8], mut visit: F) {
188 let fsa = self.fsa();
189 let b = self.data.as_ref();
190 let blob_lo = self.blob_lo;
191 fsa.prefix_for_each(prefix, |code, off| {
192 let mut p = blob_lo + off as usize;
193 let Some(n) = rd_uvarint(b, &mut p) else { return };
194 for _ in 0..n {
195 let Some(len) = rd_uvarint(b, &mut p).map(|l| l as usize) else { return };
196 let Some(end) = p.checked_add(len) else { return };
197 let Some(item) = b.get(p..end) else { return };
198 p = end;
199 let Some(val) = rd_uvarint(b, &mut p) else { return };
200 visit(code, item, val);
201 }
202 });
203 }
204
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;

    #[test]
    fn basic() {
        let mut b = DictBuilder::new();
        b.insert(b"wo", "我".as_bytes(), 100);
        b.insert(b"wo", "握".as_bytes(), 40);
        b.insert(b"women", "我们".as_bytes(), 90);
        let dict = Dict::new(b.finish()).unwrap();
        // Entries under one code come back ordered by descending value.
        assert_eq!(
            dict.get(b"wo"),
            vec![("我".as_bytes().to_vec(), 100), ("握".as_bytes().to_vec(), 40)]
        );
        assert_eq!(dict.get(b"women"), vec![("我们".as_bytes().to_vec(), 90)]);
        assert_eq!(dict.get(b"nope"), Vec::<(Vec<u8>, u64)>::new());
        // `len` counts distinct codes, not entries.
        assert_eq!(dict.len(), 2);
        let pre = dict.prefix(b"wo");
        // Both entries under "wo" plus the single entry under "women".
        assert_eq!(pre.len(), 3);
        assert_eq!(pre[0].0, b"wo");
        assert!(dict.contains_prefix(b"wom"));
        assert!(!dict.contains_prefix(b"x"));
    }

    #[test]
    fn edge_cases() {
        let mut b = DictBuilder::new();
        // Codes and items are arbitrary bytes, including NUL and non-UTF-8.
        b.insert(b"a\x00b", b"\xff\x00", 0);
        b.insert(b"a\x00b", b"item2", u64::MAX);
        // The empty code is a valid key.
        b.insert(b"", b"empty-code", 65_536);
        let dict = Dict::new(b.finish()).unwrap();
        let wo = dict.get(b"a\x00b");
        assert_eq!(wo.len(), 2);
        // Highest value first, including the u64::MAX extreme.
        assert_eq!(wo[0], (b"item2".to_vec(), u64::MAX));
        assert_eq!(wo[1], (b"\xff\x00".to_vec(), 0));
        assert_eq!(dict.get(b""), vec![(b"empty-code".to_vec(), 65_536)]);
    }

    #[test]
    fn empty_dict() {
        let dict = Dict::new(DictBuilder::new().finish()).unwrap();
        assert_eq!(dict.len(), 0);
        assert!(dict.is_empty());
        assert!(dict.get(b"").is_empty());
        assert!(dict.prefix(b"").is_empty());
    }

    #[test]
    fn rejects_bad_header() {
        // Shorter than the 8-byte header.
        assert!(matches!(Dict::new(&b"IXDC"[..]), Err(FsaError::Truncated)));
        // Wrong magic.
        assert!(matches!(Dict::new(&b"NOPE\0\0\0\0"[..]), Err(FsaError::BadMagic)));
        // Declared FSA length runs past the end of the buffer.
        assert!(matches!(Dict::new(&b"IXDC\x05\0\0\0ab"[..]), Err(FsaError::Truncated)));
        // Maximal declared length: must be rejected, not overflow `usize` math.
        assert!(matches!(Dict::new(&b"IXDC\xff\xff\xff\xff"[..]), Err(FsaError::Truncated)));
    }

    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig { cases: 200, ..ProptestConfig::default() })]

        #[test]
        fn diff_against_oracle(
            triples in proptest::collection::vec(
                (
                    proptest::collection::vec(b'a'..=b'd', 1..5),
                    proptest::collection::vec(b'x'..=b'z', 1..4),
                    any::<u64>(),
                ),
                0..48,
            ),
            probes in proptest::collection::vec(proptest::collection::vec(b'a'..=b'd', 0..5), 0..16),
        ) {
            // Keep one value per (code, item). The builder does not merge
            // duplicates, so both the oracle and the builder get this deduped input.
            let mut latest: BTreeMap<(Vec<u8>, Vec<u8>), u64> = BTreeMap::new();
            for (c, it, v) in &triples {
                latest.insert((c.clone(), it.clone()), *v);
            }
            // Expected contents: per code, items by descending value then ascending bytes.
            let mut oracle: BTreeMap<Vec<u8>, Vec<(Vec<u8>, u64)>> = BTreeMap::new();
            for ((c, it), v) in &latest {
                oracle.entry(c.clone()).or_default().push((it.clone(), *v));
            }
            for items in oracle.values_mut() {
                items.sort_by(|a, b| b.1.cmp(&a.1).then(a.0.cmp(&b.0)));
            }

            let mut b = DictBuilder::new();
            for ((c, it), v) in &latest {
                b.insert(c, it, *v);
            }
            let dict = Dict::new(b.finish()).unwrap();

            prop_assert_eq!(dict.len(), oracle.len() as u64);
            for (c, items) in &oracle {
                prop_assert_eq!(&dict.get(c), items, "get {:?}", c);
                // The streaming API must agree with the collecting one.
                let mut streamed: Vec<(Vec<u8>, u64)> = Vec::new();
                dict.get_for_each(c, |it, v| streamed.push((it.to_vec(), v)));
                prop_assert_eq!(&streamed, items, "get_for_each {:?}", c);
            }
            for p in &probes {
                // BTreeMap iteration gives codes in lexicographic order.
                let want: Vec<(Vec<u8>, Vec<u8>, u64)> = oracle
                    .iter()
                    .filter(|(c, _)| c.starts_with(p))
                    .flat_map(|(c, items)| items.iter().map(move |(it, v)| (c.clone(), it.clone(), *v)))
                    .collect();
                prop_assert_eq!(dict.prefix(p), want, "prefix {:?}", p);
            }
        }
    }
}