
mtc_inc_bpe/normalize.rs
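//! Builds [`NormalizedDict`]: a [`Dictionary`] wrapper that assigns each token
//! a merge priority, separates atomic tokens (chosen by a caller-supplied
//! predicate) from tokens produced by canonical merge rules, and rejects
//! dictionaries whose tokens cannot be re-derived by a proper BPE run over
//! their atomic spelling.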

use std::iter::FusedIterator;

use derive_more::Deref;
use rapidhash::{HashMapExt, RapidHashMap};
use thiserror::Error;

use crate::{
    Dictionary, RuleId, TokenId, bpe_with_heap_last_merge, dict::RuleIdVec, typed_vec::TypedVec,
    vocab::TokenIdVec,
};

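/// Errors reported while building a [`NormalizedDict`].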
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum NormalizedDictBuildError {
    #[error("multiple atomic token sequences for token {token_id} ({seq_a:?} vs {seq_b:?})")]
    MultipleAtomicTokenSeq {
        token_id: TokenId,
        seq_a: Vec<TokenId>,
        seq_b: Vec<TokenId>,
    },
    #[error("improper rules for token {token_id} (proper result: {proper:?})")]
    ImproperDict {
        token_id: TokenId,
        proper: Vec<TokenId>,
    },
}

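/// A [`Dictionary`] annotated with one merge priority per token.
///
/// `priorities[t]` is the id of the canonical rule that produces `t`, an
/// atomic-token priority (`ATOMIC_TOKEN_PRIORITY + t`) if `t` is atomic, or
/// `RuleId::MAX` if `t` is neither.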
#[derive(Clone, Debug, Deref)]
pub struct NormalizedDict {
    #[deref]
    dict: Dictionary,
    pub(crate) priorities: TypedVec<TokenId, RuleId>,
    #[cfg(test)]
    pub(crate) canonical_rules: RapidHashMap<(TokenId, TokenId), RuleId>,
}

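/// Base priority for atomic tokens: the midpoint of the `RuleId` range,
/// i.e. `(RuleId::MAX >> 1) + 1`. An atomic token `t` receives the priority
/// `ATOMIC_TOKEN_PRIORITY + t`, which compares greater than any ordinary
/// rule id used as a priority.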
pub(crate) const ATOMIC_TOKEN_PRIORITY: RuleId = {
    let mut priority = RuleId::MAX;
    *priority.inner_mut() = (priority.inner() >> 1) + 1;
    priority
};

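/// Inverse of the atomic-priority encoding: recovers the token id from a
/// priority in the atomic range.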
#[inline(always)]
fn to_atomic_token_id(rule_id: RuleId) -> TokenId {
    debug_assert!(rule_id >= ATOMIC_TOKEN_PRIORITY);
    TokenId::new((rule_id - ATOMIC_TOKEN_PRIORITY).inner())
}

impl NormalizedDict {
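    /// Builds a `NormalizedDict` from `dict`, using `is_atomic` to decide
    /// which tokens count as atomic.
    ///
    /// Fails if some token can be spelled by two different atomic sequences,
    /// or if a token's rules disagree with a proper BPE run over its atomic
    /// sequence.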
    pub fn new<F: FnMut(&Dictionary, TokenId, &[u8]) -> bool>(
        dict: Dictionary,
        mut is_atomic: F,
    ) -> Result<Self, NormalizedDictBuildError> {
        let capacity = dict.num_of_tokens();
        let mut priorities = TypedVec::new_with(RuleId::MAX, capacity);
        let mut canonical_rules = RapidHashMap::with_capacity(capacity.as_usize());

        let mut atomic_seqs = TypedVec::new_with(TokenIdVec::new(), capacity);

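        // Mark atomic tokens: each one's atomic sequence is just itself, and
        // its priority is placed in the atomic range.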
        for (token_id, priority) in priorities.enumerate_mut() {
            let token = &dict[token_id];
            if token.is_empty() {
                continue;
            }
            if is_atomic(&dict, token_id, token) {
                atomic_seqs[token_id].push(token_id);
                debug_assert!(token_id.as_usize() < ATOMIC_TOKEN_PRIORITY.as_usize());
                let mut p = ATOMIC_TOKEN_PRIORITY;
                *p.inner_mut() += token_id.inner();
                *priority = p;
            }
        }

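        // Propagate atomic sequences bottom-up: visiting tokens by increasing
        // byte length, concatenate the operands' sequences for every rule that
        // produces the token, and reject tokens that end up with two different
        // atomic spellings.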
        let mut token_to_rules = TypedVec::new_with(RuleIdVec::new(), capacity);
        for (rule_id, rule) in dict.rules.enumerate() {
            token_to_rules[rule.merged].push(rule_id);
        }
        for token_id in {
            let mut order: Vec<_> = dict.tokens.keys().collect();
            order.sort_by_key(|&i| dict[i].len());
            order
        } {
            for &rule_id in &token_to_rules[token_id] {
                let rule = &dict[rule_id];
                if atomic_seqs[rule.pre].is_empty() || atomic_seqs[rule.suc].is_empty() {
                    continue;
                }
                let mut seq = atomic_seqs[rule.pre].clone();
                seq.extend_from_slice(&atomic_seqs[rule.suc]);
                let slot = &mut atomic_seqs[token_id];
                if !slot.is_empty() && *slot != seq {
                    return Err(NormalizedDictBuildError::MultipleAtomicTokenSeq {
                        token_id,
                        seq_a: slot.to_vec(),
                        seq_b: seq.to_vec(),
                    });
                }
                *slot = seq;
            }
        }
        drop(token_to_rules);

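        // Properness check: for every token whose "improper" BPE run over its
        // atomic sequence reduces to the token itself, the "proper" run must
        // agree; the rule performing the last merge is recorded for the debug
        // cross-check against the canonical rules found below.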
        let mut validation = TypedVec::new_with(false, dict.num_of_rules());
        for (token_id, seq) in atomic_seqs.enumerate() {
            if seq.is_empty() {
                continue;
            }
            let improper = bpe_with_heap_last_merge::<true>(&dict, seq.to_vec());
            if improper.0 != vec![token_id] {
                continue;
            }
            let proper = bpe_with_heap_last_merge::<false>(&dict, seq.to_vec());
            if proper != improper {
                return Err(NormalizedDictBuildError::ImproperDict {
                    token_id,
                    proper: proper.0,
                });
            }
            if let Some(last_rule_id) = proper.1 {
                validation[last_rule_id] = true;
            }
        }
        drop(atomic_seqs);

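        // Select canonical rules. A rule is skipped if its product already has
        // a priority or either operand is non-canonical; otherwise the seam
        // between the two operands is walked downwards and the rule is
        // rejected whenever an already-registered canonical rule would merge
        // across that seam first.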
        'outer: for (id, rule) in dict.rules.enumerate() {
            let mut left = priorities[rule.pre];
            let mut right = priorities[rule.suc];
            if priorities[rule.merged] != RuleId::MAX || left == RuleId::MAX || right == RuleId::MAX
            {
                continue;
            }
            while left < ATOMIC_TOKEN_PRIORITY || right < ATOMIC_TOKEN_PRIORITY {
                let (u, v): (TokenId, TokenId);
                if left == right {
                    u = dict[left].suc;
                    v = dict[right].pre;
                } else if left >= ATOMIC_TOKEN_PRIORITY {
                    u = to_atomic_token_id(left);
                    v = dict[right].pre;
                    debug_assert_eq!(left, priorities[u]);
                } else if right >= ATOMIC_TOKEN_PRIORITY {
                    u = dict[left].suc;
                    v = to_atomic_token_id(right);
                    debug_assert_eq!(right, priorities[v]);
                } else if left > right {
                    u = dict[left].suc;
                    v = dict[right].merged;
                    debug_assert_eq!(right, priorities[v]);
                } else {
                    u = dict[left].merged;
                    v = dict[right].pre;
                    debug_assert_eq!(left, priorities[u]);
                }
                if let Some(&mid) = canonical_rules.get(&(u, v)) {
                    debug_assert!(priorities[u] >= ATOMIC_TOKEN_PRIORITY || mid > priorities[u]);
                    debug_assert!(priorities[v] >= ATOMIC_TOKEN_PRIORITY || mid > priorities[v]);
                    if left == right || right == priorities[v] {
                        if mid < left {
                            continue 'outer;
                        }
                    } else if mid <= right {
                        continue 'outer;
                    }
                }
                if left < ATOMIC_TOKEN_PRIORITY {
                    left = priorities[u];
                }
                if right < ATOMIC_TOKEN_PRIORITY {
                    right = priorities[v];
                }
                debug_assert_ne!(left, RuleId::MAX);
                debug_assert_ne!(right, RuleId::MAX);
            }
            priorities[rule.merged] = id;
            let res = canonical_rules.insert((rule.pre, rule.suc), id);
            debug_assert!(res.is_none());
            debug_assert!(validation[id]);
            validation[id] = false;
        }

        debug_assert!(validation.into_iter().all(|i| !i));

        Ok(Self {
            dict,
            priorities,
            #[cfg(test)]
            canonical_rules,
        })
    }

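    /// Normalizes `dict`, treating every single-byte token as atomic.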
    #[inline]
    pub fn new_in_bytes(dict: Dictionary) -> Result<Self, NormalizedDictBuildError> {
        Self::new(dict, |_, _, b| b.len() == 1)
    }

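    /// Normalizes `dict`, treating every token that encodes exactly one UTF-8
    /// character as atomic.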
    #[inline]
    pub fn new_in_utf8(dict: Dictionary) -> Result<Self, NormalizedDictBuildError> {
        Self::new(dict, |_, _, b| {
            if b.len() > 4 {
                return false;
            }
            std::str::from_utf8(b).is_ok_and(|s| s.chars().count() == 1)
        })
    }

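    /// Returns the priority of `token_id`, or `RuleId::MAX` if it is out of
    /// range or has no priority.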
    #[inline(always)]
    pub fn priority(&self, token_id: TokenId) -> RuleId {
        self.priorities
            .get(token_id)
            .copied()
            .unwrap_or(RuleId::MAX)
    }

    #[inline(always)]
    pub fn is_atomic(&self, token_id: TokenId) -> bool {
        self.is_canonical(token_id) && self.priorities[token_id] >= ATOMIC_TOKEN_PRIORITY
    }

    #[inline(always)]
    pub fn is_canonical(&self, token_id: TokenId) -> bool {
        self.priority(token_id) != RuleId::MAX
    }

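    /// Iterates over all token byte strings, yielding an empty slice for every
    /// token that is not canonical.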
    #[inline(always)]
    pub fn iter_canonical_or_empty_tokens(
        &self,
    ) -> impl DoubleEndedIterator<Item = &[u8]> + ExactSizeIterator + FusedIterator {
        self.tokens.enumerate().map(|(token_id, bytes)| {
            if self.is_canonical(token_id) {
                bytes.as_ref()
            } else {
                &[]
            }
        })
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        Dictionary, NormalizedDict, NormalizedDictBuildError, RuleId, Vocab, bpe_with_heap,
        test_utils::{bytes_into_tokens, utf8_into_tokens},
    };

    fn build_dict<T: AsRef<[u8]>, R: IntoIterator<Item = (T, T)>>(
        vocab: &Vocab,
        rules: R,
    ) -> Dictionary {
        Dictionary::new_from_token_pair(vocab.clone(), rules).unwrap()
    }

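    // Normalizes with single-byte atoms and checks that exactly the canonical
    // merged tokens re-derive themselves under proper BPE; returns `None` for
    // improper dictionaries.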
    fn build_in_bytes(dict: &Dictionary) -> Option<NormalizedDict> {
        let dict = match NormalizedDict::new_in_bytes(dict.clone()) {
            Ok(dict) => dict,
            Err(NormalizedDictBuildError::ImproperDict { .. }) => {
                return None;
            }
            Err(e) => {
                dbg!(e);
                unreachable!();
            }
        };
        for rule in &dict.rules {
            let token_id = rule.merged;
            assert!(!dict.is_atomic(token_id));
            let seq = &dict[token_id];
            let res = bpe_with_heap::<false>(&dict, bytes_into_tokens(&dict, seq, 0usize));
            assert!(dict.is_canonical(token_id) ^ (res != vec![token_id]));
        }
        Some(dict)
    }

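    // Same check with one-UTF-8-character atoms; non-UTF-8 merged tokens must
    // simply be non-canonical.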
    fn build_in_utf8(dict: &Dictionary) -> Option<NormalizedDict> {
        let dict = match NormalizedDict::new_in_utf8(dict.clone()) {
            Ok(dict) => dict,
            Err(NormalizedDictBuildError::ImproperDict { .. }) => {
                return None;
            }
            Err(e) => {
                dbg!(e);
                unreachable!();
            }
        };
        for rule in &dict.rules {
            let token_id = rule.merged;
            let seq = match std::str::from_utf8(&dict[token_id]) {
                Ok(seq) => seq,
                Err(_) => {
                    assert!(!dict.is_canonical(token_id));
                    continue;
                }
            };
            assert!(!dict.is_atomic(token_id));
            let res = bpe_with_heap::<false>(&dict, utf8_into_tokens(&dict, seq, 0usize));
            assert!(dict.is_canonical(token_id) ^ (res != vec![token_id]));
        }
        Some(dict)
    }

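    // Asserts that the canonical rules recorded during construction are
    // exactly the expected rule ids.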
    fn canonical_rules<R: IntoIterator<Item = u32>>(dict: &NormalizedDict, rules: R) {
        let mut rules: Vec<_> = rules.into_iter().map(RuleId::new).collect();
        rules.sort();
        let mut expected: Vec<_> = dict.canonical_rules.values().copied().collect();
        expected.sort();
        assert_eq!(rules, expected);
    }

    fn build_and_test_rules<R: IntoIterator<Item = u32> + Clone>(dict: &Dictionary, rules: R) {
        if let Some(normalized) = build_in_bytes(dict) {
            canonical_rules(&normalized, rules.clone());
        }
        if let Some(normalized) = build_in_utf8(dict) {
            canonical_rules(&normalized, rules);
        }
    }

    #[test]
    fn test_normalized_dict() {
        let vocab = Vocab::new([
            b"" as &[_],
            b"a",
            b"b",
            b"c",
            b"d",
            b"cd",
            b"bcd",
            b"abcd",
            "你".as_bytes(),
            "好".as_bytes(),
            "呀".as_bytes(),
            "你好".as_bytes(),
            "你好呀".as_bytes(),
            "好你".as_bytes(),
            b"\xe4",
            b"\xbd",
            b"\xa0",
            b"\xbd\xa0",
            b"aa",
            b"aaa",
            b"aaaa",
            b"aaaaa",
        ])
        .unwrap();

        let dict = build_dict(&vocab, [("c", "d"), ("b", "cd"), ("a", "bcd")]);
        build_and_test_rules(&dict, [0, 1, 2]);

        let dict = build_dict(
            &vocab,
            [(b"\xbd" as &[_], b"\xa0" as &[_]), (b"\xe4", b"\xbd\xa0")],
        );
        let normalized = build_in_bytes(&dict).unwrap();
        canonical_rules(&normalized, [0, 1]);

        let dict = build_dict(&vocab, [("aa", "a"), ("a", "a")]);
        build_and_test_rules(&dict, [1]);

        let dict = build_dict(&vocab, [("a", "aa"), ("a", "a")]);
        build_and_test_rules(&dict, [1]);

        let dict = build_dict(&vocab, [("a", "a"), ("aa", "a")]);
        build_and_test_rules(&dict, [0, 1]);

        let dict = build_dict(&vocab, [("a", "a"), ("a", "aa")]);
        build_and_test_rules(&dict, [0]);

        let dict = build_dict(
            &vocab,
            [
                ("a", "a"),
                ("aa", "a"),
                ("a", "aa"),
                ("aa", "aa"),
                ("a", "aaa"),
                ("aaa", "a"),
            ],
        );
        build_and_test_rules(&dict, [0, 1, 3]);

        let dict = build_dict(&vocab, [("a", "a"), ("aa", "a"), ("aaa", "a")]);
        build_and_test_rules(&dict, [0, 1]);

        let dict = build_dict(&vocab, [("a", "a"), ("aa", "a"), ("aa", "aa")]);
        build_and_test_rules(&dict, [0, 1, 2]);
        let dict = build_dict(&vocab, [("a", "a"), ("aa", "aa"), ("aa", "a")]);
        build_and_test_rules(&dict, [0, 1, 2]);

        let dict = build_dict(
            &vocab,
            [
                ("a", "a"),
                ("aa", "aa"),
                ("aa", "a"),
                ("aaa", "aa"),
                ("aa", "aaa"),
                ("aaaa", "a"),
            ],
        );
        build_and_test_rules(&dict, [0, 1, 2, 5]);

        let dict = build_dict(
            &vocab,
            [
                ("a", "a"),
                ("aa", "a"),
                ("aa", "aa"),
                ("aaa", "aa"),
                ("aa", "aaa"),
                ("aaaa", "a"),
            ],
        );
        build_and_test_rules(&dict, [0, 1, 2, 4]);

        let dict = build_dict(&vocab, [("你", "好"), ("你好", "呀")]);
        let normalized = build_in_utf8(&dict).unwrap();
        canonical_rules(&normalized, [0, 1]);
        let dict = build_dict(&vocab, [("你", "好"), ("你好", "呀"), ("好", "你")]);
        let normalized = build_in_utf8(&dict).unwrap();
        canonical_rules(&normalized, [0, 1, 2]);
        let dict = build_dict(&vocab, [("你", "好"), ("好", "你"), ("你好", "呀")]);
        let normalized = build_in_utf8(&dict).unwrap();
        canonical_rules(&normalized, [0, 1, 2]);
        let dict = build_dict(&vocab, [("好", "你"), ("你", "好"), ("你好", "呀")]);
        let normalized = build_in_utf8(&dict).unwrap();
        canonical_rules(&normalized, [0, 1, 2]);
        let dict = build_dict(&vocab, [("你好", "呀"), ("你", "好"), ("好", "你")]);
        assert!(build_in_utf8(&dict).is_none());
        let dict = build_dict(&vocab, [("你好", "呀"), ("好", "你"), ("你", "好")]);
        assert!(build_in_utf8(&dict).is_none());
        let dict = build_dict(&vocab, [("好", "你"), ("你好", "呀"), ("你", "好")]);
        assert!(build_in_utf8(&dict).is_none());

        let vocab = Vocab::new([
            b"" as &[_],
            b"a",
            b"abc",
            b"abcde",
            b"abcdef",
            b"b",
            b"ba",
            b"bc",
            b"bcdef",
            b"c",
            b"cd",
            b"cde",
            b"cdefg",
            b"d",
            b"de",
            b"def",
            b"e",
            b"ef",
            b"efg",
            b"f",
            b"g",
        ])
        .unwrap();
        let dict = build_dict(
            &vocab,
            [
                ("b", "c"),
                ("e", "f"),
                ("d", "e"),
                ("c", "d"),
                ("d", "ef"),
                ("b", "a"),
                ("a", "bc"),
                ("abc", "de"),
                ("abc", "def"),
                ("bc", "def"),
                ("c", "de"),
                ("ef", "g"),
                ("cd", "efg"),
            ],
        );
        build_and_test_rules(&dict, 0..13);
        let dict = build_dict(
            &vocab,
            [
                ("b", "c"),
                ("e", "f"),
                ("d", "e"),
                ("c", "d"),
                ("d", "ef"),
                ("a", "bc"),
                ("b", "a"),
                ("abc", "de"),
                ("abc", "def"),
                ("bc", "def"),
                ("c", "de"),
                ("ef", "g"),
                ("cd", "efg"),
            ],
        );
        build_and_test_rules(&dict, 0..13);
    }

    #[test]
    fn test_normalized_dict_invalid() {
        let dict = Dictionary::new_from_id_pair(
            Vocab::new([b"a" as &[_], b"aa"]).unwrap(),
            [(0usize, 0usize)],
        )
        .unwrap();
        let res = NormalizedDict::new(dict.clone(), |_, _, b| b.len() == 1);
        assert!(res.is_ok());
        let res = NormalizedDict::new(dict, |_, _, _| true);
        assert!(res.is_err());
    }
}
509}