toktrie_tiktoken/
lib.rs

1use anyhow::{bail, Result};
2use std::sync::Arc;
3use tiktoken_rs::{CoreBPE, Rank};
4use toktrie::{TokEnv, TokRxInfo, TokTrie, TokenId, TokenizerEnv};
5
/// Tokenizer backed by a tiktoken-style BPE vocabulary.
///
/// Pairs a [`CoreBPE`], which performs ordinary BPE encoding, with a [`TokTrie`]
/// built over the same token ids (including special tokens and any placeholder
/// ids), as constructed by `TikTokenBPE::new`.
pub struct TikTokenBPE {
    /// The underlying tiktoken BPE encoder.
    pub bpe: CoreBPE,
    /// Trie over the byte string of every token id in the vocabulary.
    tok_trie: TokTrie,
}
10
11impl TikTokenBPE {
12    pub fn new(
13        encoder: Vec<(Vec<u8>, Rank)>,
14        special_tokens_encoder: Vec<(String, Rank)>,
15        pattern: &str,
16        n_vocab_override: Option<usize>,
17        eos_token: u32,
18    ) -> Result<TikTokenBPE> {
19        let mut n_vocab = encoder.len() + special_tokens_encoder.len();
20        let mut tokens = vec![vec![]; n_vocab];
21
22        for (bytes, idx) in encoder.iter() {
23            while tokens.len() <= *idx as usize {
24                tokens.push(vec![]);
25            }
26            tokens[*idx as usize] = bytes.clone();
27        }
28
29        for (name, idx) in special_tokens_encoder.iter() {
30            while tokens.len() <= *idx as usize {
31                tokens.push(vec![]);
32            }
33            let mut spec_bytes = Vec::with_capacity(name.len() + 1);
34            spec_bytes.push(TokTrie::SPECIAL_TOKEN_MARKER);
35            spec_bytes.extend_from_slice(name.as_bytes());
36            tokens[*idx as usize] = spec_bytes;
37        }
38
39        n_vocab = tokens.len();
40
41        if let Some(n_vocab_override) = n_vocab_override {
42            if n_vocab_override < n_vocab {
43                bail!("vocab size too small; {} vs {}", n_vocab_override, n_vocab);
44            }
45            n_vocab = n_vocab_override;
46            tokens.resize(n_vocab, vec![]);
47        }
48
49        for (i, token) in tokens.iter_mut().enumerate() {
50            if token.is_empty() {
51                let mut name = format!(".<[{i}]>").into_bytes();
52                name[0] = TokTrie::SPECIAL_TOKEN_MARKER;
53                *token = name;
54            }
55        }
56
57        let tok_trie = TokTrie::from(
58            &TokRxInfo {
59                vocab_size: n_vocab as u32,
60                tok_eos: eos_token,
61                tok_end_of_turn: None,
62                tok_unk: None,
63                tok_pad: None,
64                tok_bos: None,
65            },
66            &tokens,
67        );
68
69        let bpe = CoreBPE::new(
70            encoder.into_iter().collect(),
71            special_tokens_encoder.into_iter().collect(),
72            pattern,
73        )?;
74
75        Ok(TikTokenBPE { bpe, tok_trie })
76    }
77
78    pub fn tokrx_info(&self) -> TokRxInfo {
79        *self.tok_trie.info()
80    }
81
82    pub fn to_env(self) -> TokEnv {
83        Arc::new(self)
84    }
85}
86
87impl TokenizerEnv for TikTokenBPE {
88    fn tok_trie(&self) -> &TokTrie {
89        &self.tok_trie
90    }
91
92    fn tokenize_bytes(&self, s: &[u8]) -> Vec<TokenId> {
93        self.tok_trie
94            .tokenize_with_greedy_fallback(s, |s| self.bpe.encode_ordinary(s))
95    }
96
97    fn tokenize_bytes_special(&self, s: &[u8]) -> Vec<TokenId> {
98        self.tok_trie.tokenize_with_greedy_fallback(s, |s| {
99            self.tok_trie
100                .tokenize_with_special(s, |s| self.bpe.encode_ordinary(s))
101        })
102    }
103}