use std::iter::FusedIterator;

use derive_more::Deref;
use rapidhash::{HashMapExt, RapidHashMap};
use thiserror::Error;

use crate::{
    Dictionary, RuleId, TokenId, bpe_with_heap_last_merge, dict::RuleIdVec, typed_vec::TypedVec,
    vocab::TokenIdVec,
};

/// Errors that can arise while building a [`NormalizedDict`].
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum NormalizedDictBuildError {
    /// Two merge rules derive the same token from different atomic token
    /// sequences, so the token has no unique atomic decomposition.
    #[error("multiple atomic token sequences for token {token_id} ({seq_a:?} vs {seq_b:?})")]
    MultipleAtomicTokenSeq {
        token_id: TokenId,
        seq_a: Vec<TokenId>,
        seq_b: Vec<TokenId>,
    },
    /// Proper and improper BPE (the two modes of `bpe_with_heap_last_merge`)
    /// disagree on the result of merging this token's atomic sequence.
    #[error("improper rules for token {token_id} (proper result: {proper:?})")]
    ImproperDict {
        token_id: TokenId,
        proper: Vec<TokenId>,
    },
}

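/// A [`Dictionary`] whose tokens have been classified as atomic, canonical
/// (producible through exactly one canonical merge rule), or unreachable,
/// with a merge priority recorded for every reachable token.
///
/// Derefs to the underlying [`Dictionary`].
///
/// # Example
///
/// A minimal sketch; the vocabulary and rule below are made up for
/// illustration (see the tests at the bottom of this module for fuller
/// coverage):
///
/// ```ignore
/// let vocab = Vocab::new([b"" as &[_], b"a", b"b", b"ab"]).unwrap();
/// let dict = Dictionary::new_from_token_pair(vocab, [("a", "b")]).unwrap();
/// let normalized = NormalizedDict::new_in_bytes(dict).unwrap();
/// assert!(normalized.is_atomic(TokenId::new(1))); // "a" is a single byte
/// assert!(normalized.is_canonical(TokenId::new(3))); // "ab", via rule 0
/// ```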
#[derive(Clone, Debug, Deref)]
pub struct NormalizedDict {
    #[deref]
    dict: Dictionary,
    /// Per-token priority: the id of the canonical rule that produces the
    /// token, `ATOMIC_TOKEN_PRIORITY + token_id` for an atomic token, or
    /// `RuleId::MAX` for a token that is neither.
    pub(crate) priorities: TypedVec<TokenId, RuleId>,
    /// Maps each canonical `(pre, suc)` pair to the rule that merges it.
    #[cfg(test)]
    pub(crate) canonical_rules: RapidHashMap<(TokenId, TokenId), RuleId>,
}

/// Base priority for atomic tokens: one past the midpoint of the `RuleId`
/// range. An atomic token with id `t` receives priority
/// `ATOMIC_TOKEN_PRIORITY + t`, which sorts above every real rule id.
pub(crate) const ATOMIC_TOKEN_PRIORITY: RuleId = {
    let mut priority = RuleId::MAX;
    *priority.inner_mut() = (priority.inner() >> 1) + 1;
    priority
};

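// For illustration: if `RuleId` wraps a `u32` (an assumption made only in
// this comment), `ATOMIC_TOKEN_PRIORITY` is `0x8000_0000`, atomic token 5
// gets priority `0x8000_0005`, and `to_atomic_token_id` below inverts that.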
/// Recovers the token id that was folded into an atomic priority.
#[inline(always)]
fn to_atomic_token_id(rule_id: RuleId) -> TokenId {
    debug_assert!(rule_id >= ATOMIC_TOKEN_PRIORITY);
    TokenId::new((rule_id - ATOMIC_TOKEN_PRIORITY).inner())
}

impl NormalizedDict {
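    /// Builds a [`NormalizedDict`], treating a token as atomic whenever
    /// `is_atomic` returns `true` for its id and bytes.
    ///
    /// # Example
    ///
    /// A minimal sketch with a made-up predicate (every token of at most
    /// two bytes is treated as atomic; illustration only):
    ///
    /// ```ignore
    /// let normalized = NormalizedDict::new(dict, |_, _, bytes| bytes.len() <= 2)?;
    /// ```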
    pub fn new<F: FnMut(&Dictionary, TokenId, &[u8]) -> bool>(
        dict: Dictionary,
        mut is_atomic: F,
    ) -> Result<Self, NormalizedDictBuildError> {
        let capacity = dict.num_of_tokens();
        let mut priorities = TypedVec::new_with(RuleId::MAX, capacity);
        let mut canonical_rules = RapidHashMap::with_capacity(capacity.as_usize());

        // The atomic decomposition of each token; empty until known.
        let mut atomic_seqs = TypedVec::new_with(TokenIdVec::new(), capacity);

        // Seed atomic tokens: each decomposes to itself and receives the
        // priority `ATOMIC_TOKEN_PRIORITY + token_id`.
        for (token_id, priority) in priorities.enumerate_mut() {
            let token = &dict[token_id];
            if token.is_empty() {
                continue;
            }
            if is_atomic(&dict, token_id, token) {
                atomic_seqs[token_id].push(token_id);
                debug_assert!(token_id.as_usize() < ATOMIC_TOKEN_PRIORITY.as_usize());
                let mut p = ATOMIC_TOKEN_PRIORITY;
                *p.inner_mut() += token_id.inner();
                *priority = p;
            }
        }

        // Propagate atomic decompositions bottom-up. Visiting tokens in
        // order of increasing byte length resolves a rule's (shorter)
        // operands before the merged token itself.
        let mut token_to_rules = TypedVec::new_with(RuleIdVec::new(), capacity);
        for (rule_id, rule) in dict.rules.enumerate() {
            token_to_rules[rule.merged].push(rule_id);
        }
        for token_id in {
            let mut order: Vec<_> = dict.tokens.keys().collect();
            order.sort_by_key(|&i| dict[i].len());
            order
        } {
            for &rule_id in &token_to_rules[token_id] {
                let rule = &dict[rule_id];
                if atomic_seqs[rule.pre].is_empty() || atomic_seqs[rule.suc].is_empty() {
                    continue;
                }
                let mut seq = atomic_seqs[rule.pre].clone();
                seq.extend_from_slice(&atomic_seqs[rule.suc]);
                let slot = &mut atomic_seqs[token_id];
                // Every rule deriving a token must agree on its decomposition.
                if !slot.is_empty() && *slot != seq {
                    return Err(NormalizedDictBuildError::MultipleAtomicTokenSeq {
                        token_id,
                        seq_a: slot.to_vec(),
                        seq_b: seq.to_vec(),
                    });
                }
                *slot = seq;
            }
        }
        drop(token_to_rules);

        // For every decomposable token, check that proper and improper BPE
        // agree on its atomic sequence, and remember the final merge of each
        // successful proper run for the cross-check below.
        let mut validation = TypedVec::new_with(false, dict.num_of_rules());
        for (token_id, seq) in atomic_seqs.enumerate() {
            if seq.is_empty() {
                continue;
            }
            let improper = bpe_with_heap_last_merge::<true>(&dict, seq.to_vec());
            if improper.0 != vec![token_id] {
                continue;
            }
            let proper = bpe_with_heap_last_merge::<false>(&dict, seq.to_vec());
            if proper != improper {
                return Err(NormalizedDictBuildError::ImproperDict {
                    token_id,
                    proper: proper.0,
                });
            }
            if let Some(last_rule_id) = proper.1 {
                validation[last_rule_id] = true;
            }
        }
        drop(atomic_seqs);

        // Decide which rules are canonical. A rule can be canonical only if
        // both operands are themselves canonical or atomic, and no earlier
        // canonical merge across the seam between the two operands'
        // derivations preempts it. The loop below walks that seam top-down,
        // one derivation level at a time.
        'outer: for (id, rule) in dict.rules.enumerate() {
            let mut left = priorities[rule.pre];
            let mut right = priorities[rule.suc];
            if priorities[rule.merged] != RuleId::MAX || left == RuleId::MAX || right == RuleId::MAX
            {
                continue;
            }
            while left < ATOMIC_TOKEN_PRIORITY || right < ATOMIC_TOKEN_PRIORITY {
                // `(u, v)`: the tokens adjacent at the seam on the current
                // level. Descend into the side built by the later rule (both
                // sides when the priorities are equal); atomic sides stay put.
                let (u, v): (TokenId, TokenId);
                if left == right {
                    u = dict[left].suc;
                    v = dict[right].pre;
                } else if left >= ATOMIC_TOKEN_PRIORITY {
                    u = to_atomic_token_id(left);
                    v = dict[right].pre;
                    debug_assert_eq!(left, priorities[u]);
                } else if right >= ATOMIC_TOKEN_PRIORITY {
                    u = dict[left].suc;
                    v = to_atomic_token_id(right);
                    debug_assert_eq!(right, priorities[v]);
                } else if left > right {
                    u = dict[left].suc;
                    v = dict[right].merged;
                    debug_assert_eq!(right, priorities[v]);
                } else {
                    u = dict[left].merged;
                    v = dict[right].pre;
                    debug_assert_eq!(left, priorities[u]);
                }
                // If the canonical merge of `(u, v)` would fire before the
                // merge at the current level, this rule can never apply
                // during proper BPE, so it is not canonical.
                if let Some(&mid) = canonical_rules.get(&(u, v)) {
                    debug_assert!(priorities[u] >= ATOMIC_TOKEN_PRIORITY || mid > priorities[u]);
                    debug_assert!(priorities[v] >= ATOMIC_TOKEN_PRIORITY || mid > priorities[v]);
                    if left == right || right == priorities[v] {
                        if mid < left {
                            continue 'outer;
                        }
                    } else if mid <= right {
                        continue 'outer;
                    }
                }
                if left < ATOMIC_TOKEN_PRIORITY {
                    left = priorities[u];
                }
                if right < ATOMIC_TOKEN_PRIORITY {
                    right = priorities[v];
                }
                debug_assert_ne!(left, RuleId::MAX);
                debug_assert_ne!(right, RuleId::MAX);
            }
            // The rule survived: record it as the canonical producer of its
            // merged token.
            priorities[rule.merged] = id;
            let res = canonical_rules.insert((rule.pre, rule.suc), id);
            debug_assert!(res.is_none());
            // Cross-check: this rule was the last merge of some proper run.
            debug_assert!(validation[id]);
            validation[id] = false;
        }

        // Conversely, every recorded last merge must have become canonical.
        debug_assert!(validation.into_iter().all(|i| !i));

        Ok(Self {
            dict,
            priorities,
            #[cfg(test)]
            canonical_rules,
        })
    }
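    /// Normalizes `dict`, treating every single-byte token as atomic.
    ///
    /// # Example
    ///
    /// A minimal sketch (made-up vocabulary, for illustration only):
    ///
    /// ```ignore
    /// let vocab = Vocab::new([b"" as &[_], b"a", b"b", b"ab"]).unwrap();
    /// let dict = Dictionary::new_from_token_pair(vocab, [("a", "b")]).unwrap();
    /// let normalized = NormalizedDict::new_in_bytes(dict)?;
    /// ```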
    #[inline]
    pub fn new_in_bytes(dict: Dictionary) -> Result<Self, NormalizedDictBuildError> {
        Self::new(dict, |_, _, b| b.len() == 1)
    }
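    /// Normalizes `dict`, treating every token that encodes exactly one
    /// Unicode scalar value as atomic.
    ///
    /// # Example
    ///
    /// A minimal sketch (made-up vocabulary; the multi-byte token "你"
    /// counts as atomic here, unlike under [`Self::new_in_bytes`]):
    ///
    /// ```ignore
    /// let vocab = Vocab::new([b"" as &[_], "你".as_bytes(), "好".as_bytes(), "你好".as_bytes()])
    ///     .unwrap();
    /// let dict = Dictionary::new_from_token_pair(vocab, [("你", "好")]).unwrap();
    /// let normalized = NormalizedDict::new_in_utf8(dict)?;
    /// assert!(normalized.is_atomic(TokenId::new(1)));
    /// ```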
    #[inline]
    pub fn new_in_utf8(dict: Dictionary) -> Result<Self, NormalizedDictBuildError> {
        Self::new(dict, |_, _, b| {
            if b.len() > 4 {
                return false;
            }
            std::str::from_utf8(b).is_ok_and(|s| s.chars().count() == 1)
        })
    }
    /// Returns the merge priority of `token_id`, or `RuleId::MAX` if the
    /// token is neither atomic nor canonical (or the id is out of range).
    #[inline(always)]
    pub fn priority(&self, token_id: TokenId) -> RuleId {
        self.priorities
            .get(token_id)
            .copied()
            .unwrap_or(RuleId::MAX)
    }
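    /// Whether `token_id` was classified as atomic by the predicate passed
    /// to [`Self::new`].
    ///
    /// A sketch of the invariant this encodes (illustration only):
    ///
    /// ```ignore
    /// if normalized.is_atomic(token_id) {
    ///     // Atomic tokens are canonical and carry a priority in the
    ///     // upper half of the `RuleId` range.
    ///     assert!(normalized.is_canonical(token_id));
    ///     assert!(normalized.priority(token_id) >= ATOMIC_TOKEN_PRIORITY);
    /// }
    /// ```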
    #[inline(always)]
    pub fn is_atomic(&self, token_id: TokenId) -> bool {
        self.is_canonical(token_id) && self.priorities[token_id] >= ATOMIC_TOKEN_PRIORITY
    }

    /// Whether `token_id` carries a priority at all, i.e. it is atomic or
    /// produced by a canonical rule.
    #[inline(always)]
    pub fn is_canonical(&self, token_id: TokenId) -> bool {
        self.priority(token_id) != RuleId::MAX
    }
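    /// Iterates over all tokens in id order, yielding each canonical (or
    /// atomic) token's bytes and an empty slice for every other token.
    ///
    /// # Example
    ///
    /// A sketch, with `normalized` built as in the examples above
    /// (illustration only):
    ///
    /// ```ignore
    /// for (i, bytes) in normalized.iter_canonical_or_empty_tokens().enumerate() {
    ///     if bytes.is_empty() {
    ///         println!("token {i} is empty or not canonical");
    ///     }
    /// }
    /// ```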
    #[inline(always)]
    pub fn iter_canonical_or_empty_tokens(
        &self,
    ) -> impl DoubleEndedIterator<Item = &[u8]> + ExactSizeIterator + FusedIterator {
        self.tokens.enumerate().map(|(token_id, bytes)| {
            if self.is_canonical(token_id) {
                bytes.as_ref()
            } else {
                &[]
            }
        })
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        Dictionary, NormalizedDict, NormalizedDictBuildError, RuleId, Vocab, bpe_with_heap,
        test_utils::{bytes_into_tokens, utf8_into_tokens},
    };

    fn build_dict<T: AsRef<[u8]>, R: IntoIterator<Item = (T, T)>>(
        vocab: &Vocab,
        rules: R,
    ) -> Dictionary {
        Dictionary::new_from_token_pair(vocab.clone(), rules).unwrap()
    }

    /// Normalizes `dict` over single bytes and cross-checks every merged
    /// token: canonical tokens must round-trip through proper BPE on their
    /// bytes, non-canonical ones must not. Returns `None` for improper
    /// dictionaries.
    fn build_in_bytes(dict: &Dictionary) -> Option<NormalizedDict> {
        let dict = match NormalizedDict::new_in_bytes(dict.clone()) {
            Ok(dict) => dict,
            Err(NormalizedDictBuildError::ImproperDict { .. }) => {
                return None;
            }
            Err(e) => {
                dbg!(e);
                unreachable!();
            }
        };
        for rule in &dict.rules {
            let token_id = rule.merged;
            assert!(!dict.is_atomic(token_id));
            let seq = &dict[token_id];
            let res = bpe_with_heap::<false>(&dict, bytes_into_tokens(&dict, seq, 0usize));
            assert!(dict.is_canonical(token_id) ^ (res != vec![token_id]));
        }
        Some(dict)
    }

    /// Like `build_in_bytes`, but over UTF-8 characters; merged tokens that
    /// are not valid UTF-8 must be non-canonical.
    fn build_in_utf8(dict: &Dictionary) -> Option<NormalizedDict> {
        let dict = match NormalizedDict::new_in_utf8(dict.clone()) {
            Ok(dict) => dict,
            Err(NormalizedDictBuildError::ImproperDict { .. }) => {
                return None;
            }
            Err(e) => {
                dbg!(e);
                unreachable!();
            }
        };
        for rule in &dict.rules {
            let token_id = rule.merged;
            let seq = match std::str::from_utf8(&dict[token_id]) {
                Ok(seq) => seq,
                Err(_) => {
                    assert!(!dict.is_canonical(token_id));
                    continue;
                }
            };
            assert!(!dict.is_atomic(token_id));
            let res = bpe_with_heap::<false>(&dict, utf8_into_tokens(&dict, seq, 0usize));
            assert!(dict.is_canonical(token_id) ^ (res != vec![token_id]));
        }
        Some(dict)
    }

    /// Asserts that the canonical rules recorded in `dict` are exactly
    /// `rules`.
    fn canonical_rules<R: IntoIterator<Item = u32>>(dict: &NormalizedDict, rules: R) {
        let mut rules: Vec<_> = rules.into_iter().map(RuleId::new).collect();
        rules.sort();
        let mut expected: Vec<_> = dict.canonical_rules.values().copied().collect();
        expected.sort();
        assert_eq!(rules, expected);
    }

    /// Normalizes `dict` both ways (bytes and UTF-8) and, where the build
    /// succeeds, checks the canonical rule set.
    fn build_and_test_rules<R: IntoIterator<Item = u32> + Clone>(dict: &Dictionary, rules: R) {
        if let Some(normalized) = build_in_bytes(dict) {
            canonical_rules(&normalized, rules.clone());
        }
        if let Some(normalized) = build_in_utf8(dict) {
            canonical_rules(&normalized, rules);
        }
    }

    #[test]
    fn test_normalized_dict() {
        let vocab = Vocab::new([
            b"" as &[_],
            b"a",
            b"b",
            b"c",
            b"d",
            b"cd",
            b"bcd",
            b"abcd",
            "你".as_bytes(),
            "好".as_bytes(),
            "呀".as_bytes(),
            "你好".as_bytes(),
            "你好呀".as_bytes(),
            "好你".as_bytes(),
            b"\xe4",
            b"\xbd",
            b"\xa0",
            b"\xbd\xa0",
            b"aa",
            b"aaa",
            b"aaaa",
            b"aaaaa",
        ])
        .unwrap();

        // A simple chain: each rule builds on the previous one, so all
        // three stay canonical.
        let dict = build_dict(&vocab, [("c", "d"), ("b", "cd"), ("a", "bcd")]);
        build_and_test_rules(&dict, [0, 1, 2]);

        // Builds "你" (UTF-8 `e4 bd a0`) from raw bytes; the fragments are
        // not valid UTF-8, so only the byte normalizer applies.
        let dict = build_dict(
            &vocab,
            [(b"\xbd" as &[_], b"\xa0" as &[_]), (b"\xe4", b"\xbd\xa0")],
        );
        let normalized = build_in_bytes(&dict).unwrap();
        canonical_rules(&normalized, [0, 1]);

        // Chains of "a" merges exercise preemption between overlapping
        // rules; only the listed subset remains canonical in each case.
        let dict = build_dict(&vocab, [("aa", "a"), ("a", "a")]);
        build_and_test_rules(&dict, [1]);

        let dict = build_dict(&vocab, [("a", "aa"), ("a", "a")]);
        build_and_test_rules(&dict, [1]);

        let dict = build_dict(&vocab, [("a", "a"), ("aa", "a")]);
        build_and_test_rules(&dict, [0, 1]);

        let dict = build_dict(&vocab, [("a", "a"), ("a", "aa")]);
        build_and_test_rules(&dict, [0]);

        let dict = build_dict(
            &vocab,
            [
                ("a", "a"),
                ("aa", "a"),
                ("a", "aa"),
                ("aa", "aa"),
                ("a", "aaa"),
                ("aaa", "a"),
            ],
        );
        build_and_test_rules(&dict, [0, 1, 3]);

        let dict = build_dict(&vocab, [("a", "a"), ("aa", "a"), ("aaa", "a")]);
        build_and_test_rules(&dict, [0, 1]);

        let dict = build_dict(&vocab, [("a", "a"), ("aa", "a"), ("aa", "aa")]);
        build_and_test_rules(&dict, [0, 1, 2]);
        let dict = build_dict(&vocab, [("a", "a"), ("aa", "aa"), ("aa", "a")]);
        build_and_test_rules(&dict, [0, 1, 2]);

        let dict = build_dict(
            &vocab,
            [
                ("a", "a"),
                ("aa", "aa"),
                ("aa", "a"),
                ("aaa", "aa"),
                ("aa", "aaa"),
                ("aaaa", "a"),
            ],
        );
        build_and_test_rules(&dict, [0, 1, 2, 5]);

        let dict = build_dict(
            &vocab,
            [
                ("a", "a"),
                ("aa", "a"),
                ("aa", "aa"),
                ("aaa", "aa"),
                ("aa", "aaa"),
                ("aaaa", "a"),
            ],
        );
        build_and_test_rules(&dict, [0, 1, 2, 4]);

        // Rule order matters: the dictionary is improper whenever
        // ("你好", "呀") precedes the rule ("你", "好") that builds "你好".
        let dict = build_dict(&vocab, [("你", "好"), ("你好", "呀")]);
        let normalized = build_in_utf8(&dict).unwrap();
        canonical_rules(&normalized, [0, 1]);
        let dict = build_dict(&vocab, [("你", "好"), ("你好", "呀"), ("好", "你")]);
        let normalized = build_in_utf8(&dict).unwrap();
        canonical_rules(&normalized, [0, 1, 2]);
        let dict = build_dict(&vocab, [("你", "好"), ("好", "你"), ("你好", "呀")]);
        let normalized = build_in_utf8(&dict).unwrap();
        canonical_rules(&normalized, [0, 1, 2]);
        let dict = build_dict(&vocab, [("好", "你"), ("你", "好"), ("你好", "呀")]);
        let normalized = build_in_utf8(&dict).unwrap();
        canonical_rules(&normalized, [0, 1, 2]);
        let dict = build_dict(&vocab, [("你好", "呀"), ("你", "好"), ("好", "你")]);
        assert!(build_in_utf8(&dict).is_none());
        let dict = build_dict(&vocab, [("你好", "呀"), ("好", "你"), ("你", "好")]);
        assert!(build_in_utf8(&dict).is_none());
        let dict = build_dict(&vocab, [("好", "你"), ("你好", "呀"), ("你", "好")]);
        assert!(build_in_utf8(&dict).is_none());

        // A larger dictionary in which all 13 rules are canonical; swapping
        // the two independent rules ("b", "a") and ("a", "bc") does not
        // change that.
        let vocab = Vocab::new([
            b"" as &[_],
            b"a",
            b"abc",
            b"abcde",
            b"abcdef",
            b"b",
            b"ba",
            b"bc",
            b"bcdef",
            b"c",
            b"cd",
            b"cde",
            b"cdefg",
            b"d",
            b"de",
            b"def",
            b"e",
            b"ef",
            b"efg",
            b"f",
            b"g",
        ])
        .unwrap();
        let dict = build_dict(
            &vocab,
            [
                ("b", "c"),
                ("e", "f"),
                ("d", "e"),
                ("c", "d"),
                ("d", "ef"),
                ("b", "a"),
                ("a", "bc"),
                ("abc", "de"),
                ("abc", "def"),
                ("bc", "def"),
                ("c", "de"),
                ("ef", "g"),
                ("cd", "efg"),
            ],
        );
        build_and_test_rules(&dict, 0..13);
        let dict = build_dict(
            &vocab,
            [
                ("b", "c"),
                ("e", "f"),
                ("d", "e"),
                ("c", "d"),
                ("d", "ef"),
                ("a", "bc"),
                ("b", "a"),
                ("abc", "de"),
                ("abc", "def"),
                ("bc", "def"),
                ("c", "de"),
                ("ef", "g"),
                ("cd", "efg"),
            ],
        );
        build_and_test_rules(&dict, 0..13);
    }

    #[test]
    fn test_normalized_dict_invalid() {
        // With only single-byte tokens atomic, the rule (a, a) -> "aa" is
        // fine; if every token is atomic, "aa" decomposes both to itself
        // and to [a, a], so the build must fail.
        let dict = Dictionary::new_from_id_pair(
            Vocab::new([b"a" as &[_], b"aa"]).unwrap(),
            [(0usize, 0usize)],
        )
        .unwrap();
        let res = NormalizedDict::new(dict.clone(), |_, _, b| b.len() == 1);
        assert!(res.is_ok());
        let res = NormalizedDict::new(dict, |_, _, _| true);
        assert!(res.is_err());
    }
}