rten_text/models/
bpe.rs

1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::error::Error;
4use std::fmt;
5use std::fmt::{Debug, Display};
6
7use super::{DecodeError, EncodeError, Model};
8use crate::tokenizer::TokenId;
9use rustc_hash::{FxBuildHasher, FxHashMap};
10
/// Errors that can occur when building a [`Bpe`] tokenizer or encoding or
/// decoding text using it.
///
/// Each variant carries a human-readable description of the offending entry,
/// which is included in the [`Display`] output.
#[derive(Debug)]
pub enum BpeError {
    /// There was an invalid entry in the merge list. This means that either
    /// the entry doesn't have the expected `<token> [SPACE] <token>` format,
    /// or a `<token>` is neither a single character nor the concatenation
    /// of another pair in the merge list.
    InvalidMergeEntry(String),

    /// An entry in the vocab (token string to ID map) is neither a known
    /// special token nor an entry in the merge list.
    InvalidVocabEntry(String),

    /// An entry was not found in the vocabulary.
    MissingVocabEntry(String),
}
28
29impl Display for BpeError {
30    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
31        match self {
32            BpeError::InvalidMergeEntry(entry) => write!(fmt, "invalid merge entry: {}", entry),
33            BpeError::InvalidVocabEntry(entry) => write!(fmt, "invalid vocab entry: {}", entry),
34            BpeError::MissingVocabEntry(entry) => write!(fmt, "missing vocab entry: {}", entry),
35        }
36    }
37}
38
// No underlying source error is wrapped, so the default `Error` methods are
// sufficient.
impl Error for BpeError {}
40
/// Rank of an entry in the BPE merge list.
///
/// The rank of a merge is its index in the merge list. When several merges
/// are applicable, the one with the lowest rank is applied first.
///
/// A newtype is used here to avoid confusing ranks and token IDs.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
struct Rank(u32);
46
/// A sequence of UTF-8 bytes, encoded as a string of printable characters.
/// [`char_to_byte`] provides the mapping between characters and bytes.
///
/// Unlike a Rust `str`, the sequence of bytes does not necessarily form a
/// complete sequence of Unicode characters. The bytes may end in the middle of
/// a character.
pub type EncodedByteSlice<'a> = &'a str;
54
/// Like [`EncodedByteSlice`], but owned: the printable-character encoding of
/// a byte sequence stored in a `String`.
pub type EncodedBytes = String;
57
/// Like [`EncodedByteSlice`], but may be either borrowed or owned.
pub type EncodedBytesCow<'a> = Cow<'a, str>;
60
/// Report whether `c` counts as a printable character for the purposes of
/// the byte <-> character mapping.
///
/// A character is printable unless it is a control character, whitespace or
/// the soft hyphen (U+00AD). For code points < 256 this agrees with Python's
/// `str.isprintable`, except that ASCII space is treated as non-printable.
fn is_printable(c: char) -> bool {
    const SOFT_HYPHEN: char = '\u{ad}';
    !(c.is_control() || c.is_whitespace() || c == SOFT_HYPHEN)
}
68
69/// Return a mapping from byte value to printable character used to represent
70/// the byte.
71///
72/// Based on the `bytes_to_unicode` function in the original GPT-2 encoder -
73/// <https://github.com/openai/gpt-2/blob/master/src/encoder.py>.
74fn byte_to_char() -> [char; 256] {
75    let mut chars = ['\x00'; 256];
76
77    for b in 0..=255u8 {
78        let ch = char::from(b);
79        if is_printable(ch) {
80            chars[b as usize] = ch;
81        }
82    }
83
84    let mut non_printable_count = 0;
85    for b in 0..=255u8 {
86        if !is_printable(char::from(b)) {
87            chars[b as usize] = char::from_u32(256 + non_printable_count).unwrap();
88            non_printable_count += 1;
89        }
90    }
91
92    chars
93}
94
95/// Return a mapping from printable character used to represent bytes to the
96/// corresponding byte value.
97pub fn char_to_byte() -> HashMap<char, u8> {
98    byte_to_char()
99        .iter()
100        .copied()
101        .enumerate()
102        .map(|(byte, ch)| (ch, byte as u8))
103        .collect()
104}
105
106/// Iteratively merge pairs of tokens in `tokens`, using the mappings in `ranks`,
107/// until no more merges are possible.
108///
109/// Returns the number of merged tokens.
110fn bpe_merge(
111    tokens: &mut Vec<TokenId>,
112    merges: &FxHashMap<(TokenId, TokenId), (Rank, TokenId)>,
113) -> usize {
114    loop {
115        // Find the pair of tokens with the lowest rank and merge all occurences
116        // of the pair.
117        let min_pair: Option<((TokenId, TokenId), (Rank, TokenId))> = tokens
118            .windows(2)
119            .filter_map(|pair| {
120                let [first, second] = pair.try_into().unwrap();
121                merges
122                    .get(&(first, second))
123                    .map(|&rank_id| ((first, second), rank_id))
124            })
125            .min_by_key(|((_first, _second), (rank, _merged_id))| *rank);
126
127        let Some(((first, second), (_rank, merged_id))) = min_pair else {
128            break;
129        };
130
131        let mut i = 0;
132        while i < tokens.len() - 1 {
133            if tokens[i] == first && tokens[i + 1] == second {
134                tokens[i] = merged_id;
135                tokens.remove(i + 1);
136            }
137            i += 1;
138        }
139    }
140    tokens.len()
141}
142
/// Mapping from pairs of tokens to the rank and ID of the merged pair.
///
/// The key is the `(first, second)` token IDs of an adjacent pair; the value
/// is the rank of the merge rule and the ID of the token which replaces the
/// pair.
type MergeMap = FxHashMap<(TokenId, TokenId), (Rank, TokenId)>;
145
146/// Build the BPE merge map that associates a rank and token ID to merged pairs
147/// of tokens.
148fn build_merge_map(
149    vocab: &FxHashMap<EncodedBytes, TokenId>,
150    merges: &[(EncodedBytesCow, EncodedBytesCow)],
151) -> Result<MergeMap, BpeError> {
152    let mut merged_str = String::new();
153    let mut merge_map = HashMap::with_capacity_and_hasher(merges.len(), FxBuildHasher);
154
155    for (i, (a, b)) in merges.iter().enumerate() {
156        let a_id = *vocab.get(a.as_ref()).ok_or_else(|| {
157            BpeError::InvalidMergeEntry(format!(
158                "first entry in merge pair \"{a} {b}\" not found in vocab"
159            ))
160        })?;
161        let b_id = *vocab.get(b.as_ref()).ok_or_else(|| {
162            BpeError::InvalidMergeEntry(format!(
163                "second entry in merge pair \"{a} {b}\" not found in vocab"
164            ))
165        })?;
166
167        merged_str.clear();
168        merged_str.push_str(a);
169        merged_str.push_str(b);
170
171        let merged_id = *vocab.get(&merged_str).ok_or_else(|| {
172            BpeError::InvalidMergeEntry(format!("merged pair \"{a} {b}\" not found in vocab"))
173        })?;
174        let rank = Rank(i as u32);
175        merge_map.insert((a_id, b_id), (rank, merged_id));
176    }
177
178    Ok(merge_map)
179}
180
181/// Parse a list of space-separated BPE merge entries into pairs of tokens.
182///
183/// Lines that are empty or contain only a `#version` marker are ignored.
184pub fn merge_pairs_from_lines(
185    lines: &[impl AsRef<str>],
186) -> Vec<(EncodedBytesCow<'static>, EncodedBytesCow<'static>)> {
187    lines
188        .iter()
189        .filter_map(|line| {
190            let line = line.as_ref();
191            if line.starts_with("#version") || line.trim().is_empty() {
192                None
193            } else {
194                // Cloning the string here is OK since this is a legacy code
195                // path that is rarely used.
196                line.split_once(' ')
197                    .map(|(a, b)| (a.to_string().into(), b.to_string().into()))
198            }
199        })
200        .collect()
201}
202
203/// Build a mapping from token bytes to ID using the merge list.
204///
205/// This is used as a fallback when the tokenizer configuration doesn't have a
206/// vocabulary.
207fn build_vocab(
208    merges: &[(EncodedBytesCow, EncodedBytesCow)],
209    end_of_word_suffix: Option<EncodedByteSlice>,
210) -> FxHashMap<EncodedBytes, TokenId> {
211    let mut vocab = FxHashMap::default();
212
213    fn byte_to_rank() -> [Rank; 256] {
214        let mut ranks = [Rank(0); 256];
215
216        let mut rank = 0;
217        for byte in 0..=255u8 {
218            if is_printable(char::from(byte)) {
219                ranks[byte as usize] = Rank(rank);
220                rank += 1;
221            }
222        }
223
224        for byte in 0..=255u8 {
225            if !is_printable(char::from(byte)) {
226                ranks[byte as usize] = Rank(rank);
227                rank += 1;
228            }
229        }
230
231        ranks
232    }
233
234    // The first 256 token IDs are reserved for individual bytes.
235    for (ch, rank) in byte_to_char().into_iter().zip(byte_to_rank()) {
236        vocab.insert(ch.into(), rank.0);
237    }
238
239    // If an end-of-word suffix is used, the next 256 token IDs are bytes that
240    // occur at the end of a word.
241    if let Some(eow_suffix) = end_of_word_suffix {
242        let start_id = vocab.len() as u32;
243        for (ch, rank) in byte_to_char().into_iter().zip(byte_to_rank()) {
244            let mut bytes: EncodedBytes = ch.into();
245            bytes.push_str(eow_suffix);
246            vocab.insert(bytes, start_id + rank.0);
247        }
248    }
249
250    // Assign token IDs to concatenated pairs from the merge list.
251    let start_id = vocab.len() as u32;
252    vocab.extend(
253        merges
254            .iter()
255            .enumerate()
256            .map(|(i, (a, b))| ([a.as_ref(), b.as_ref()].concat(), start_id + i as u32)),
257    );
258
259    vocab
260}
261
/// Configuration for a [`Bpe`] tokenization model.
#[derive(Default)]
pub struct BpeOptions<'a> {
    /// Ordered entries of the merge list. Each entry is a pair of strings
    /// representing byte sequences. See also [`merge_pairs_from_lines`] which
    /// can be used to extract pairs from the space-separated format used in eg.
    /// `merges.txt` files.
    pub merges: &'a [(EncodedBytesCow<'a>, EncodedBytesCow<'a>)],

    /// Mapping between token strings and IDs. If not provided, a vocabulary
    /// is generated from the merge list: the first 256 IDs are assigned to
    /// individual bytes, followed by a further 256 IDs for end-of-word bytes
    /// if a non-empty `end_of_word_suffix` is set, followed by one ID per
    /// merge list entry, in order. For example, with no end-of-word suffix,
    /// if index 10 in the merge list is "foo bar", then the token ID of
    /// "foobar" would be 266.
    pub vocab: Option<FxHashMap<EncodedBytes, TokenId>>,

    /// Map from token ID to token string for tokens which don't appear in
    /// `merges`. These are used for special purposes such as representing the
    /// end of output. They are consulted before the vocabulary when
    /// converting between token IDs and strings.
    pub added_tokens: FxHashMap<TokenId, String>,

    /// A string which is implicitly appended to each substring that is
    /// tokenized, after initial splitting. An empty string is treated the
    /// same as `None`.
    pub end_of_word_suffix: Option<String>,

    /// When encoding a string piece, match the entire piece against the
    /// vocabulary before applying merge rules.
    pub ignore_merges: bool,
}
291
/// Byte Pair Encoding tokenizer used by GPT-2 [^1] and subsequently used by
/// many other models.
///
/// Byte Pair Encoding was introduced by [^2]. Despite the name, the original
/// version operated on characters. The variant used by GPT-2 and other OpenAI
/// models operates on bytes instead. This avoids needing a huge base vocabulary
/// to support Unicode.
///
/// [^1]: Radford, Alec, et al. (2019) "Language models are unsupervised multitask learners."
///       <https://openai.com/research/better-language-models>
///
/// [^2]: Sennrich, Rico, Barry Haddow, and Alexandra Birch. "Neural machine
///       translation of rare words with subword units." arXiv preprint
///       arXiv:1508.07909 (2015).
pub struct Bpe {
    /// Map from a pair of adjacent token IDs to the rank of the merge rule
    /// and the ID of the token that replaces the pair.
    merges: MergeMap,

    /// Map from byte values to token IDs.
    byte_to_token_id: [TokenId; 256],

    /// Map from byte values to printable character representation used in
    /// vocabulary.
    byte_to_char: [char; 256],

    /// Map from token ID to the token's encoded byte string. This is the
    /// inverse of the vocabulary.
    token_id_to_encoded_bytes: FxHashMap<TokenId, EncodedBytes>,

    /// Map from encoded byte string to token ID. This is only retained when
    /// `ignore_merges` is set, and is `None` otherwise.
    vocab: Option<FxHashMap<EncodedBytes, TokenId>>,

    /// Map from token ID to content for special tokens (eg. end-of-string).
    added_tokens: FxHashMap<TokenId, String>,

    /// A suffix which is implicitly appended to each string piece to be
    /// tokenized.
    ///
    /// This was originally introduced for CLIP's tokenizer.
    /// See <https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py>.
    end_of_word_suffix: Option<String>,

    /// When encoding a string piece, match the entire piece against the
    /// vocabulary before applying merge rules.
    ignore_merges: bool,
}
334
335impl Bpe {
336    /// Create a new Byte Pair Encoding tokenizer using the given configuration.
337    pub fn new(config: BpeOptions) -> Result<Bpe, BpeError> {
338        let BpeOptions {
339            merges,
340            vocab,
341            added_tokens,
342            mut end_of_word_suffix,
343            ignore_merges,
344        } = config;
345
346        // Normalize empty end-of-word suffix to `None`.
347        end_of_word_suffix.take_if(|suffix| suffix.is_empty());
348
349        let vocab = vocab.unwrap_or_else(|| build_vocab(merges, end_of_word_suffix.as_deref()));
350
351        let merges = build_merge_map(&vocab, merges)?;
352
353        // Build byte -> token ID mapping for encoding.
354        let mut byte_to_token_id = [0; 256];
355        for (i, ch) in byte_to_char().into_iter().enumerate() {
356            let mut ch_buf = [0u8; 4];
357            let ch_str = ch.encode_utf8(&mut ch_buf);
358            if let Some(id) = vocab.get(ch_str).copied() {
359                byte_to_token_id[i] = id;
360            } else {
361                return Err(BpeError::MissingVocabEntry(ch_str.to_string()));
362            }
363        }
364
365        // If the `ignore_merges` flag is set for this tokenizer, we'll need
366        // to use the vocabulary during encoding.
367        //
368        // Otherwise we can re-use the hash/string allocations for
369        // `token_id_to_encoded_bytes`.
370        let (vocab, token_id_to_encoded_bytes) = if ignore_merges {
371            let token_id_to_encoded_bytes = vocab
372                .iter()
373                .map(|(token, id)| (*id, token.clone()))
374                .collect();
375            (Some(vocab), token_id_to_encoded_bytes)
376        } else {
377            let token_id_to_encoded_bytes =
378                vocab.into_iter().map(|(token, id)| (id, token)).collect();
379            (None, token_id_to_encoded_bytes)
380        };
381
382        Ok(Bpe {
383            added_tokens,
384            byte_to_char: byte_to_char(),
385            byte_to_token_id,
386            end_of_word_suffix,
387            ignore_merges,
388            merges,
389            token_id_to_encoded_bytes,
390            vocab,
391        })
392    }
393
394    /// Encode a string as a sequence of tokens.
395    ///
396    /// `end_of_word` specifies whether to apply end-of-word processing rules
397    /// to the initial tokenization of piece.
398    fn encode_piece(&self, piece: &str, end_of_word: bool) -> Vec<TokenId> {
399        // If `ignore_merges` is set, check for the entire string in the vocab
400        // before using merges.
401        if self.ignore_merges
402            && let Some(vocab) = self.vocab.as_ref()
403        {
404            let encoded: EncodedBytes = piece
405                .as_bytes()
406                .iter()
407                .map(|&b| self.byte_to_char[b as usize])
408                .collect();
409            if let Some(&id) = vocab.get(&encoded) {
410                return [id].into();
411            }
412        }
413
414        // Start with one token per byte.
415        let mut tokens: Vec<TokenId> = piece
416            .as_bytes()
417            .iter()
418            .map(|&b| self.byte_to_token_id[b as usize])
419            .collect();
420
421        // If the end-of-word suffix is enabled, replace the last byte's token
422        // with the one that corresponds to "{byte}{end_of_word_suffix}".
423        if self.end_of_word_suffix.is_some()
424            && end_of_word
425            && let Some(last) = tokens.pop()
426        {
427            tokens.push(last + 256);
428        }
429
430        // Iteratively merge tokens together until no more are possible.
431        bpe_merge(&mut tokens, &self.merges);
432
433        tokens
434    }
435}
436
437impl Model for Bpe {
438    fn get_token_str(&self, id: TokenId) -> Option<String> {
439        if let Some(tok_str) = self.added_tokens.get(&id) {
440            return Some(tok_str.to_string());
441        }
442        self.token_id_to_encoded_bytes.get(&id).cloned()
443    }
444
445    fn get_token_id(&self, mut text: &str) -> Option<TokenId> {
446        if let Some((&id, _str)) = self.added_tokens.iter().find(|(_id, str)| *str == text) {
447            return Some(id);
448        }
449
450        // Determine the end-of-word context. eg. In CLIP's tokenizer, the
451        // trailing "</w>" in "from</w>" indicates that it should be treated as
452        // occurring at the end of a piece from the initial split.
453        let mut end_of_word = false;
454        if let Some(suffix) = self.end_of_word_suffix.as_deref()
455            && text.ends_with(suffix)
456        {
457            text = &text[..text.len() - suffix.len()];
458            end_of_word = true;
459        }
460
461        let tokens = self.encode_piece(text, end_of_word);
462        if tokens.len() == 1 {
463            Some(tokens[0])
464        } else {
465            None
466        }
467    }
468
469    fn encode_with_offsets(
470        &self,
471        piece: &str,
472        on_token: &mut dyn FnMut(usize, TokenId),
473    ) -> Result<(), EncodeError> {
474        if piece.is_empty() {
475            return Ok(());
476        }
477        for token in self.encode_piece(piece, true /* end_of_word */) {
478            on_token(0, token)
479        }
480        Ok(())
481    }
482
483    fn decode(&self, ids: &[TokenId]) -> Result<String, DecodeError> {
484        let char_to_byte = char_to_byte();
485
486        let mut bytes = Vec::new();
487        for &id in ids {
488            if let Some(tok_str) = self.added_tokens.get(&id) {
489                bytes.extend(tok_str.as_bytes());
490            } else if let Some(encoded_bytes) = self.token_id_to_encoded_bytes.get(&id) {
491                bytes.extend(
492                    encoded_bytes
493                        .chars()
494                        .map(|ch| char_to_byte.get(&ch).copied().unwrap()),
495                );
496            } else {
497                return Err(DecodeError::InvalidTokenId(id));
498            }
499        }
500
501        String::from_utf8(bytes).map_err(|_| DecodeError::InvalidUtf8)
502    }
503}
504
505#[cfg(test)]
506mod tests {
507    use rten_testing::TestCases;
508    use rustc_hash::FxHashMap;
509
510    use super::{Bpe, BpeOptions, EncodedBytes, merge_pairs_from_lines};
511    use crate::pre_tokenizers::Split;
512    use crate::tokenizer::{TokenId, Tokenizer};
513
    // The opening lines of the merge list from GPT 2, including its
    // `#version` header line. Each remaining line is one merge entry of the
    // form `<token> <token>`.
    const MINI_GPT2: &str = "
#version: 0.2
Ġ t
Ġ a
h e
i n
r e
o n
Ġt he
e r
Ġ s
a t
Ġ w
Ġ o
e n
Ġ c
i t
i s
a n
o r
e s
Ġ b
e d
Ġ f
in g";
540
541    fn added_tokens() -> FxHashMap<TokenId, String> {
542        [(50256, "<|endoftext|>")]
543            .into_iter()
544            .map(|(id, str)| (id, str.to_string()))
545            .collect()
546    }
547
548    /// Generate a map from encoded token string to token ID.
549    ///
550    /// The token IDs are chosen to be different than the ones that would be
551    /// automatically generated based on the merge list, if the vocabulary was
552    /// not supplied.
553    fn gen_vocab() -> FxHashMap<EncodedBytes, TokenId> {
554        let mut next_token_id = 1000;
555        let mut vocab = minimal_vocab(next_token_id);
556        next_token_id += vocab.len() as u32;
557
558        for line in MINI_GPT2.lines().map(|l| l.trim()) {
559            if line.starts_with("#version") || line.is_empty() {
560                continue;
561            }
562            let token_str: EncodedBytes = line.chars().filter(|ch| *ch != ' ').collect();
563            vocab.insert(token_str, next_token_id);
564            next_token_id += 1;
565        }
566
567        vocab
568    }
569
570    /// Generate the simplest valid vocabulary.
571    fn minimal_vocab(start_token_id: u32) -> FxHashMap<EncodedBytes, TokenId> {
572        let mut vocab = FxHashMap::default();
573        let mut next_token_id = start_token_id;
574        for ch in super::char_to_byte().keys() {
575            vocab.insert(ch.to_string(), next_token_id);
576            next_token_id += 1;
577        }
578        vocab
579    }
580
581    #[test]
582    fn test_encode() {
583        #[derive(Debug)]
584        struct Case<'a> {
585            text: &'a str,
586            expected_tokens: &'a [&'a str],
587            merges: &'a str,
588            vocab: Option<FxHashMap<EncodedBytes, TokenId>>,
589            end_of_word_suffix: Option<String>,
590            ignore_merges: bool,
591        }
592
593        impl<'a> Default for Case<'a> {
594            fn default() -> Self {
595                Self {
596                    text: "",
597                    expected_tokens: &[],
598                    merges: "",
599                    vocab: None,
600                    end_of_word_suffix: None,
601                    ignore_merges: false,
602                }
603            }
604        }
605
606        let cases = [
607            // Minimal test using a snippet of the GPT-2 merge list.
608            Case {
609                text: "the cat is in the bed",
610                expected_tokens: &[
611                    "t", "he", "Ġc", "at", "Ġ", "is", "Ġ", "in", "Ġthe", "Ġb", "ed",
612                ],
613                merges: MINI_GPT2,
614                ..Default::default()
615            },
616            // Test several levels of merging.
617            Case {
618                text: "--------",
619                expected_tokens: &["--------"],
620                merges: "
621- -
622-- --
623---- ----
624-------- --------
625",
626                ..Default::default()
627            },
628            // End-of-word suffix
629            Case {
630                text: "barbar",
631                expected_tokens: &["bar", "bar</w>"],
632                merges: "
633b a
634ba r
635ba r</w>
636",
637                end_of_word_suffix: Some("</w>".to_string()),
638                ..Default::default()
639            },
640            // Empty end-of-word suffix. Treated as `None` for compatibility
641            // with some tokenizer.json files which represent the EOW suffix
642            // using `""` instead of `null`.
643            Case {
644                text: "barbar",
645                expected_tokens: &["bar", "bar"],
646                merges: "
647b a
648ba r",
649                end_of_word_suffix: Some("".to_string()),
650                ..Default::default()
651            },
652            // `ignore_merges` option enabled
653            Case {
654                text: "foobar",
655                expected_tokens: &["foobar"],
656                ignore_merges: true,
657                vocab: {
658                    let mut vocab = minimal_vocab(0);
659                    vocab.insert("foobar".to_string(), vocab.len() as u32);
660                    Some(vocab)
661                },
662                ..Default::default()
663            },
664        ];
665
666        cases.test_each(|case| {
667            let Case {
668                text,
669                expected_tokens: tokens,
670                merges,
671                vocab,
672                end_of_word_suffix,
673                ignore_merges,
674            } = case;
675
676            let merges: Vec<&str> = merges.lines().collect();
677            let merge_pairs = merge_pairs_from_lines(&merges);
678            let bpe_opts = BpeOptions {
679                merges: &merge_pairs,
680                vocab: vocab.clone(),
681                end_of_word_suffix: end_of_word_suffix.clone(),
682                ignore_merges: *ignore_merges,
683                added_tokens: Default::default(),
684            };
685            let model = Bpe::new(bpe_opts).unwrap();
686            let tokenizer = Tokenizer::new(model, Default::default())
687                .with_pre_tokenizer(Box::new(Split::gpt2()));
688            let encoded = tokenizer.encode(*text, None).unwrap();
689            assert_eq!(
690                tokenizer.model().get_tokens(encoded.token_ids()).unwrap(),
691                *tokens
692            );
693        })
694    }
695
696    #[test]
697    fn test_get_token_str() {
698        #[derive(Debug)]
699        struct Case<'a> {
700            input: &'a str,
701            encoded_str: &'a str,
702        }
703
704        let cases = [
705            // Printable ASCII text. Encoded string is same as input.
706            Case {
707                input: "a",
708                encoded_str: "a",
709            },
710            // Non-printable or non-ASCII text. Encoded string will use
711            // printable characters to represent these bytes.
712            Case {
713                input: " ",
714                encoded_str: "Ġ",
715            },
716            // Added tokens.
717            Case {
718                input: "<|endoftext|>",
719                encoded_str: "<|endoftext|>",
720            },
721        ];
722
723        let merges: Vec<&str> = MINI_GPT2.lines().collect();
724        let merge_pairs = merge_pairs_from_lines(&merges);
725
726        cases.test_each(|case| {
727            let bpe_opts = BpeOptions {
728                merges: &merge_pairs,
729                added_tokens: added_tokens(),
730                ..Default::default()
731            };
732            let model = Bpe::new(bpe_opts).unwrap();
733            let tokenizer = Tokenizer::new(model, Default::default())
734                .with_pre_tokenizer(Box::new(Split::gpt2()));
735
736            let tok_id = tokenizer.model().get_token_id(case.input).unwrap();
737            let token_str = tokenizer.model().get_token_str(tok_id).unwrap();
738            assert_eq!(token_str, case.encoded_str);
739        })
740    }
741
742    #[test]
743    fn test_decode() {
744        #[derive(Debug)]
745        struct Case<'a> {
746            text: &'a str,
747            add_eos: bool,
748            expected: &'a str,
749            vocab: Option<FxHashMap<EncodedBytes, TokenId>>,
750        }
751
752        let vocab = gen_vocab();
753
754        let cases = [
755            Case {
756                text: "foo bar",
757                add_eos: false,
758                expected: "foo bar",
759                vocab: None,
760            },
761            Case {
762                text: "foo bar",
763                add_eos: true,
764                expected: "foo bar<|endoftext|>",
765                vocab: None,
766            },
767            Case {
768                text: "the cat is in the bed",
769                add_eos: false,
770                expected: "the cat is in the bed",
771                vocab: None,
772            },
773            Case {
774                text: "the cat is in the bed",
775                add_eos: false,
776                expected: "the cat is in the bed",
777                vocab: Some(vocab),
778            },
779        ];
780
781        cases.test_each(|case| {
782            let Case {
783                text,
784                add_eos,
785                expected,
786                vocab,
787            } = case;
788
789            let merges: Vec<&str> = MINI_GPT2.lines().collect();
790            let merge_pairs = merge_pairs_from_lines(&merges);
791            let bpe_opts = BpeOptions {
792                merges: &merge_pairs,
793                vocab: vocab.clone(),
794                added_tokens: added_tokens(),
795                ..Default::default()
796            };
797            let model = Bpe::new(bpe_opts).unwrap();
798            let tokenizer = Tokenizer::new(model, Default::default())
799                .with_pre_tokenizer(Box::new(Split::gpt2()));
800
801            let encoded = tokenizer.encode(*text, None).unwrap();
802            let mut token_ids = encoded.token_ids().to_vec();
803            if *add_eos {
804                // The `<|endoftext|>` token ID from GPT-2.
805                token_ids.push(50256);
806            }
807            let decoded = tokenizer.decode(&token_ids).unwrap();
808            assert_eq!(decoded, *expected);
809        })
810    }
811}