1use anyhow::{bail, Result};
2use std::sync::Arc;
3use tiktoken_rs::{CoreBPE, Rank};
4use toktrie::{TokEnv, TokRxInfo, TokTrie, TokenId, TokenizerEnv};
5
/// A tiktoken BPE encoder bundled with a [`TokTrie`].
pub struct TikTokenBPE {
    /// The tiktoken BPE encoder.
    pub bpe: CoreBPE,
    /// The token trie; returned by `tok_trie()` and used for tokenization.
    tok_trie: TokTrie,
}
10
11impl TikTokenBPE {
12 pub fn new(
13 encoder: Vec<(Vec<u8>, Rank)>,
14 special_tokens_encoder: Vec<(String, Rank)>,
15 pattern: &str,
16 n_vocab_override: Option<usize>,
17 eos_token: u32,
18 ) -> Result<TikTokenBPE> {
19 let mut n_vocab = encoder.len() + special_tokens_encoder.len();
20 let mut tokens = vec![vec![]; n_vocab];
21
22 for (bytes, idx) in encoder.iter() {
23 while tokens.len() <= *idx as usize {
24 tokens.push(vec![]);
25 }
26 tokens[*idx as usize] = bytes.clone();
27 }
28
29 for (name, idx) in special_tokens_encoder.iter() {
30 while tokens.len() <= *idx as usize {
31 tokens.push(vec![]);
32 }
33 let mut spec_bytes = Vec::with_capacity(name.len() + 1);
34 spec_bytes.push(TokTrie::SPECIAL_TOKEN_MARKER);
35 spec_bytes.extend_from_slice(name.as_bytes());
36 tokens[*idx as usize] = spec_bytes;
37 }
38
39 n_vocab = tokens.len();
40
41 if let Some(n_vocab_override) = n_vocab_override {
42 if n_vocab_override < n_vocab {
43 bail!("vocab size too small; {} vs {}", n_vocab_override, n_vocab);
44 }
45 n_vocab = n_vocab_override;
46 tokens.resize(n_vocab, vec![]);
47 }
48
49 for (i, token) in tokens.iter_mut().enumerate() {
50 if token.is_empty() {
51 let mut name = format!(".<[{i}]>").into_bytes();
52 name[0] = TokTrie::SPECIAL_TOKEN_MARKER;
53 *token = name;
54 }
55 }
56
57 let tok_trie = TokTrie::from(
58 &TokRxInfo {
59 vocab_size: n_vocab as u32,
60 tok_eos: eos_token,
61 tok_end_of_turn: None,
62 tok_unk: None,
63 tok_pad: None,
64 tok_bos: None,
65 },
66 &tokens,
67 );
68
69 let bpe = CoreBPE::new(
70 encoder.into_iter().collect(),
71 special_tokens_encoder.into_iter().collect(),
72 pattern,
73 )?;
74
75 Ok(TikTokenBPE { bpe, tok_trie })
76 }
77
78 pub fn tokrx_info(&self) -> TokRxInfo {
79 *self.tok_trie.info()
80 }
81
82 pub fn to_env(self) -> TokEnv {
83 Arc::new(self)
84 }
85}
86
87impl TokenizerEnv for TikTokenBPE {
88 fn tok_trie(&self) -> &TokTrie {
89 &self.tok_trie
90 }
91
92 fn tokenize_bytes(&self, s: &[u8]) -> Vec<TokenId> {
93 self.tok_trie
94 .tokenize_with_greedy_fallback(s, |s| self.bpe.encode_ordinary(s))
95 }
96
97 fn tokenize_bytes_special(&self, s: &[u8]) -> Vec<TokenId> {
98 self.tok_trie.tokenize_with_greedy_fallback(s, |s| {
99 self.tok_trie
100 .tokenize_with_special(s, |s| self.bpe.encode_ordinary(s))
101 })
102 }
103}