toktrie_hf_tokenizers/lib.rs

use anyhow::{anyhow, bail, Result};
use std::{
    collections::{HashMap, HashSet},
    path::Path,
    sync::Arc,
};
use tokenizers::{normalizers::Sequence, NormalizerWrapper, Tokenizer};
use toktrie::{TokEnv, TokRxInfo, TokTrie, TokenId, TokenizerEnv};

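/// Wrapper around a HuggingFace `tokenizers::Tokenizer` that recovers the
/// raw bytes of every token, so that a [`TokTrie`] can be built over the
/// vocabulary.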
pub struct ByteTokenizer {
    pub hf_model: String,
    pub hf_tokenizer: Tokenizer,
    info: TokRxInfo,
    token_bytes: Vec<Vec<u8>>,
}

// useful when debugging this: https://www.cogsci.ed.ac.uk/~richard/utf-8.cgi

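/// Characters that the GPT-2 byte-level encoding maps to themselves:
/// the printable ASCII range and most printable Latin-1 characters.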
fn is_self_mapped(c: char) -> bool {
    matches!(c, '!'..='~' | '\u{00A1}'..='\u{00AC}' | '\u{00AE}'..='\u{00FF}')
}

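/// Inverse of the GPT-2 byte-level byte-to-char mapping: the 188 characters
/// accepted by `is_self_mapped` stand for themselves, while the remaining
/// 68 byte values are assigned code points from U+0100 upwards, in byte
/// order (e.g. 'Ġ', U+0120, maps to the space byte 0x20).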
fn build_char_map() -> HashMap<char, u8> {
    let mut res = HashMap::default();
    let mut k = 0x100u32;
    for byte in 0..=255u8 {
        let c = byte as char;
        if is_self_mapped(c) {
            res.insert(c, byte);
        } else {
            res.insert(char::from_u32(k).unwrap(), byte);
            k += 1;
        }
    }
    res
}

impl ByteTokenizer {
    pub fn from_file(name: impl AsRef<Path>) -> Result<ByteTokenizer> {
        let name_str = name.as_ref().display().to_string();
        let tok = Tokenizer::from_file(name)
            .map_err(|e| anyhow!("error loading tokenizer: {}: {}", name_str, e))?;
        ByteTokenizer::from_tokenizer(tok)
    }

    pub fn from_json_bytes(bytes: &[u8]) -> Result<ByteTokenizer> {
        let tok =
            Tokenizer::from_bytes(bytes).map_err(|e| anyhow!("error loading tokenizer: {}", e))?;
        ByteTokenizer::from_tokenizer(tok)
    }

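    /// Inspect the tokenizer's decoder to determine how token names encode
    /// bytes (GPT-2 style byte-level vs. SentencePiece style byte-fallback)
    /// and compute the byte representation of every token in the vocabulary.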
    pub fn from_tokenizer(mut hft: Tokenizer) -> Result<ByteTokenizer> {
        let mut is_byte_level = false;
        let mut is_byte_fallback = false;
        let mut space_ch = ' ';

        // remove any "Prepend" normalizer (it would prepend a space to
        // every string we tokenize)
        if let Some(n) = hft.get_normalizer() {
            let n = match n {
                NormalizerWrapper::Sequence(x) => NormalizerWrapper::Sequence(Sequence::new(
                    x.get_normalizers()
                        .iter()
                        .filter_map(|n| match n {
                            NormalizerWrapper::Prepend(_) => None,
                            _ => Some(n.clone()),
                        })
                        .collect(),
                )),
                _ => n.clone(),
            };
            hft.with_normalizer(Some(n));
        }

        if let Some(d) = hft.get_decoder() {
            // DecoderWrapper::Sequence() doesn't let one access the decoders
            // so we resort to json munching
            let v = serde_json::to_value(d).unwrap();
            if v["type"].as_str() == Some("ByteLevel") {
                is_byte_level = true;
            } else if v["type"].as_str() == Some("Sequence") {
                if let Some(decoders) = v["decoders"].as_array() {
                    for decoder in decoders {
                        if decoder["type"].as_str() == Some("ByteFallback") {
                            is_byte_fallback = true;
                        } else if decoder["type"].as_str() == Some("Replace")
                            && decoder["content"].as_str() == Some(" ")
                        {
                            // a Replace(X -> " ") decoder reveals the
                            // SentencePiece-style space marker (usually '▁')
                            if let Some(s) = decoder["pattern"]["String"].as_str() {
                                let s: Vec<char> = s.chars().collect();
                                if s.len() == 1 {
                                    space_ch = s[0];
                                }
                            }
                        }
                    }
                }
            }
        }

        if !is_byte_fallback && !is_byte_level {
            bail!("can't determine decoder type: {:?}", hft.get_decoder());
        }

        let vocab_size = hft.get_vocab_size(true) as u32;
        let added = hft.get_added_tokens_decoder();

        let mut res = ByteTokenizer {
            // placeholder name; not derived from the tokenizer file
            hf_model: "foobar".to_string(),
            info: TokRxInfo::new(vocab_size, 0),
            token_bytes: (0..vocab_size).map(|_| Vec::new()).collect(),
            hf_tokenizer: hft,
        };

        let mut specials = HashSet::new();

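        // map well-known special tokens onto the EOS / end-of-turn / UNK /
        // PAD slots of TokRxInfo; non-special added tokens contribute their
        // literal bytes like regular vocabulary entries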
        for (id, info) in added.iter() {
            if info.special {
                match info.content.as_str() {
                    "</s>"
                    | "<|endoftext|>"
                    | "<|end_of_text|>"
                    | "<|end▁of▁sentence|>" // funky bars from DeepSeek tokenizer
                    | "<eos>" => res.info.tok_eos = *id,

                    "<|end|>" | "<|eot_id|>" | "<|im_end|>" => res.info.tok_end_of_turn = Some(*id),
                    "<unk>" | "<|unk|>" => res.info.tok_unk = Some(*id),
                    "<pad>" | "<|pad|>" => res.info.tok_pad = Some(*id),
                    _ => {}
                }
                specials.insert(*id);
            } else {
                res.token_bytes[*id as usize] = info.content.clone().into_bytes();
            }
        }

        let char_map = build_char_map();

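        // compute the raw bytes of every token: special tokens get a marker
        // byte prefixed, byte-fallback vocabularies spell raw bytes as
        // "<0xNN>", and byte-level vocabularies encode each byte as a mapped
        // unicode character (see build_char_map above)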
        for tok_id in 0..vocab_size {
            if let Some(tok_name) = res.hf_tokenizer.id_to_token(tok_id) {
                let bytes = if specials.contains(&tok_id) {
                    let mut bytes = tok_name.as_bytes().to_vec();
                    bytes.insert(0, TokTrie::SPECIAL_TOKEN_MARKER);
                    bytes
                } else if is_byte_fallback {
                    if tok_name.len() == 6 && tok_name.starts_with("<0x") && tok_name.ends_with(">")
                    {
                        // parse the hex byte value out of "<0xNN>"
                        let hex_str = &tok_name[3..5];
                        let byte = u8::from_str_radix(hex_str, 16).unwrap();
                        vec![byte]
                    } else {
                        assert!(!tok_name.starts_with("<0x"));
                        let tok_name = tok_name.replace(space_ch, " ");
                        tok_name.as_bytes().to_vec()
                    }
                } else if is_byte_level {
                    let bytes: Result<Vec<u8>> = tok_name
                        .chars()
                        .map(|c| {
                            char_map
                                .get(&c)
                                .copied()
                                .ok_or_else(|| anyhow!("missing char: {}", c))
                        })
                        .collect();
                    match bytes {
                        Ok(b) => b,
                        Err(e) => {
                            log::warn!("error: {} for {:?}", e, tok_name);
                            continue;
                        }
                    }
                } else {
                    unreachable!("decoder type was checked above")
                };
                res.token_bytes[tok_id as usize] = bytes;
            } else {
                log::warn!("missing token: {}", tok_id);
            }
        }

        Ok(res)
    }

    pub fn tokrx_info(&self) -> TokRxInfo {
        self.info
    }

    pub fn token_bytes(&self) -> Vec<Vec<u8>> {
        self.token_bytes.clone()
    }

    pub fn set_eos_token(&mut self, tok_id: u32) {
        self.info.tok_eos = tok_id;
    }

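    /// Consume the tokenizer and wrap it in a shared, trie-backed [`TokEnv`].
    /// `n_vocab`, if given, must be at least the tokenizer's vocabulary size
    /// and pads it with empty tokens (e.g. when the model has more logit
    /// slots than the tokenizer has tokens).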
    pub fn into_tok_env(self, n_vocab: Option<usize>) -> Result<TokEnv> {
        let b = ByteTokenizerEnv::new(self, n_vocab)?;
        Ok(b.to_env())
    }
}

pub struct ByteTokenizerEnv {
    pub tokenizer: ByteTokenizer,
    pub tok_trie: TokTrie,
}

impl ByteTokenizerEnv {
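    /// Build a [`TokTrie`] over the byte representation of every token,
    /// optionally padding the vocabulary with empty tokens up to `n_vocab`.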
    pub fn new(tokenizer: ByteTokenizer, n_vocab: Option<usize>) -> Result<ByteTokenizerEnv> {
        let mut info = tokenizer.tokrx_info();
        let mut token_bytes = tokenizer.token_bytes();
        if let Some(n_vocab) = n_vocab {
            if n_vocab < token_bytes.len() {
                bail!("vocab size too small; {} vs {}", n_vocab, token_bytes.len());
            }
            while n_vocab > token_bytes.len() {
                token_bytes.push(Vec::new());
            }
            info.vocab_size = n_vocab as u32;
        }
        let tok_trie = TokTrie::from(&info, &token_bytes);
        Ok(ByteTokenizerEnv {
            tokenizer,
            tok_trie,
        })
    }

    pub fn to_env(self) -> TokEnv {
        Arc::new(self)
    }
}

impl TokenizerEnv for ByteTokenizerEnv {
    fn tok_trie(&self) -> &TokTrie {
        &self.tok_trie
    }

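    // Tokenization goes through the trie: well-formed UTF-8 stretches are
    // handed to the HF tokenizer via the closure below, while any remaining
    // bytes (e.g. partial UTF-8 sequences) fall back to greedy
    // longest-match against the trie.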
    fn tokenize_bytes(&self, s: &[u8]) -> Vec<TokenId> {
        self.tok_trie.tokenize_with_greedy_fallback(s, |s| {
            self.tokenizer
                .hf_tokenizer
                .encode(s, false)
                .expect("tokenizer error")
                .get_ids()
                .to_vec()
        })
    }
}
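
// A minimal sanity-check sketch (not part of the original file) for the
// byte-level character map built above. The expected values follow the
// GPT-2 byte-level convention (space 0x20 <-> 'Ġ', newline 0x0A <-> 'Ċ'),
// which is an assumption about the mapping rather than something this
// crate documents.
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn char_map_is_a_bijection_over_all_bytes() {
        let map = build_char_map();
        // every byte gets exactly one character...
        assert_eq!(map.len(), 256);
        // ...and no two bytes share a character
        let distinct: HashSet<u8> = map.values().copied().collect();
        assert_eq!(distinct.len(), 256);
    }

    #[test]
    fn char_map_follows_gpt2_conventions() {
        let map = build_char_map();
        // printable ASCII maps to itself
        assert_eq!(map[&'A'], b'A');
        // space is the 33rd non-self-mapped byte, so it lands on U+0120 'Ġ'
        assert_eq!(map[&'\u{0120}'], 0x20);
        // newline is the 11th, landing on U+010A 'Ċ'
        assert_eq!(map[&'\u{010A}'], b'\n');
    }
}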