1use std::cell::RefCell;
15use std::collections::HashMap;
16
17use regex::Regex;
18
19use crate::byte_encoder::{encode_byte_level_chars, METASPACE};
20use crate::map::TokenizerMap;
21
/// Common interface shared by the tokenizers in this crate.
///
/// The `Send` supertrait lets a boxed tokenizer be moved to another thread.
pub trait ITokenizer: Send {
    /// Converts `input` into a sequence of vocabulary ids.
    fn encode(&self, input: &str) -> Vec<u32>;

    /// Identifier of this tokenizer (for `BPETokenizer`, the id of the map it
    /// was built from).
    fn id(&self) -> &str;
}
35
36pub struct BPETokenizer {
42 id: String,
43 vocab: HashMap<String, u32>,
44 merge_ranks: HashMap<String, u32>,
46 pre_tok_regex: Option<Regex>,
47 pre_tok_program: Option<crate::pretok_program::PreTokProgram>,
51 encoder: String,
52 byte_fallback_start: i64,
54 cache: RefCell<HashMap<String, Vec<u32>>>,
57 special_ids: HashMap<String, u32>,
64 special_regex: Option<Regex>,
65}
66
67impl BPETokenizer {
68 pub fn supports(map: &TokenizerMap) -> bool {
70 let has_vocab = map.vocab.as_ref().is_some_and(|v| !v.is_empty());
71 let has_merges = map.merges.as_ref().is_some_and(|v| !v.is_empty());
72 let enc_ok = matches!(map.encoder.as_deref(), Some("byte_level") | Some("metaspace"));
73 has_vocab && has_merges && enc_ok
74 }
75
76 pub fn new(map: &TokenizerMap) -> Result<Self, String> {
79 if !Self::supports(map) {
80 return Err(format!(
81 "BPETokenizer: map \"{}\" lacks vocab/merges/encoder. \
82 Use BPETokenizer::supports(map) to check first, or call \
83 Tokenize::pick(map) which falls back to LongestMatchTokenizer.",
84 map.id
85 ));
86 }
87
88 let vocab = map.vocab.as_ref().expect("supports() checked").clone();
89 let merges = map.merges.as_ref().expect("supports() checked");
90 let encoder = map.encoder.as_ref().expect("supports() checked").clone();
91 let id = map.id.clone();
92 let byte_fallback_start = map.byte_fallback_start.unwrap_or(-1);
93
94 let mut merge_ranks: HashMap<String, u32> = HashMap::with_capacity(merges.len());
95 for (i, m) in merges.iter().enumerate() {
96 merge_ranks.insert(m.clone(), i as u32);
97 }
98
99 let (pre_tok_regex, pre_tok_program) = if encoder == "byte_level" {
104 if let Some(prog) = map.pre_tokenizer_program.as_ref() {
105 if prog.ops.is_empty() {
106 return Err(format!(
107 "BPETokenizer: byte_level map \"{}\" has empty pre_tokenizer_program.",
108 map.id
109 ));
110 }
111 (None, Some(prog.clone()))
112 } else if let Some(pat) = map.pre_tokenizer_pattern.as_ref() {
113 let re = Regex::new(pat)
114 .map_err(|e| format!("BPETokenizer: invalid pre_tokenizer_pattern: {e}"))?;
115 (Some(re), None)
116 } else {
117 return Err(format!(
118 "BPETokenizer: byte_level map \"{}\" missing both pre_tokenizer_program and pre_tokenizer_pattern.",
119 map.id
120 ));
121 }
122 } else {
123 (None, None)
124 };
125
126 let mut special_ids: HashMap<String, u32> = HashMap::new();
136 if let Some(specials) = map.special_tokens.as_ref() {
137 for (name, id) in specials.iter() {
138 special_ids.insert(name.clone(), *id);
139 }
140 }
141 for (tok, id) in vocab.iter() {
142 if special_ids.contains_key(tok) {
143 continue;
144 }
145 if is_delimiter_shape(tok) {
146 special_ids.insert(tok.clone(), *id);
147 }
148 }
149 let special_regex = if special_ids.is_empty() {
150 None
151 } else {
152 let mut keys: Vec<&String> = special_ids.keys().collect();
153 keys.sort_by_key(|k| std::cmp::Reverse(k.len()));
154 let alt = keys
155 .iter()
156 .map(|k| regex::escape(k))
157 .collect::<Vec<_>>()
158 .join("|");
159 Some(
160 Regex::new(&alt)
161 .map_err(|e| format!("BPETokenizer: bad special-token regex: {e}"))?,
162 )
163 };
164
165 Ok(Self {
166 id,
167 vocab,
168 merge_ranks,
169 pre_tok_regex,
170 pre_tok_program,
171 encoder,
172 byte_fallback_start,
173 cache: RefCell::new(HashMap::new()),
174 special_ids,
175 special_regex,
176 })
177 }
178
179 pub fn encode(&self, text: &str) -> Vec<u32> {
181 if text.is_empty() {
182 return Vec::new();
183 }
184
185 if let Some(re) = self.special_regex.as_ref() {
186 let mut ids: Vec<u32> = Vec::new();
187 let mut cursor = 0usize;
188 for m in re.find_iter(text) {
189 if m.start() > cursor {
190 self.encode_chunk(&text[cursor..m.start()], &mut ids);
191 }
192 ids.push(self.special_ids[m.as_str()]);
193 cursor = m.end();
194 }
195 if cursor < text.len() {
196 self.encode_chunk(&text[cursor..], &mut ids);
197 }
198 return ids;
199 }
200
201 let mut ids: Vec<u32> = Vec::new();
202 self.encode_chunk(text, &mut ids);
203 ids
204 }
205
206 fn encode_chunk(&self, text: &str, out: &mut Vec<u32>) {
208 if text.is_empty() {
209 return;
210 }
211 let pieces = self.pre_tokenize(text);
212 for piece in pieces {
213 if let Ok(cache) = self.cache.try_borrow() {
214 if let Some(cached) = cache.get(&piece) {
215 out.extend_from_slice(cached);
216 continue;
217 }
218 }
219 let encoded = self.encode_piece_to_vocab_space(&piece);
220 let merged = self.apply_bpe(encoded);
221 let piece_ids = self.lookup(&merged);
222 if let Ok(mut cache) = self.cache.try_borrow_mut() {
223 cache.insert(piece.clone(), piece_ids.clone());
224 }
225 out.extend_from_slice(&piece_ids);
226 }
227 }
228
229 fn pre_tokenize(&self, text: &str) -> Vec<String> {
232 if self.encoder == "byte_level" {
233 if let Some(prog) = self.pre_tok_program.as_ref() {
234 return crate::pretok_program::run_pretok_program(prog, text);
235 }
236 let re = self.pre_tok_regex.as_ref().expect("byte_level requires regex or program");
237 return re.find_iter(text).map(|m| m.as_str().to_string()).collect();
238 }
239
240 let collapsed = collapse_spaces_and_tabs(text);
244 let parts = split_keep_whitespace(&collapsed);
245 let mut pieces: Vec<String> = Vec::new();
246 for p in parts {
247 if p == " " {
248 continue;
249 }
250 let mut s = String::with_capacity(p.len() + 3);
251 s.push(METASPACE);
252 s.push_str(&p);
253 pieces.push(s);
254 }
255 pieces
256 }
257
258 fn encode_piece_to_vocab_space(&self, piece: &str) -> Vec<String> {
261 if self.encoder == "byte_level" {
262 let bytes = piece.as_bytes();
263 let encoded = encode_byte_level_chars(bytes);
264 return codepoints(&encoded);
265 }
266 codepoints(piece)
267 }
268
269 fn apply_bpe(&self, tokens: Vec<String>) -> Vec<String> {
272 if tokens.len() < 2 {
273 return tokens;
274 }
275 let mut parts = tokens;
276 loop {
277 let mut best_idx: Option<usize> = None;
278 let mut best_rank: u32 = u32::MAX;
279 for i in 0..parts.len() - 1 {
280 let mut key = String::with_capacity(parts[i].len() + 1 + parts[i + 1].len());
282 key.push_str(&parts[i]);
283 key.push(' ');
284 key.push_str(&parts[i + 1]);
285 if let Some(&r) = self.merge_ranks.get(&key) {
286 if r < best_rank {
287 best_rank = r;
288 best_idx = Some(i);
289 }
290 }
291 }
292 let Some(_idx) = best_idx else {
293 break;
294 };
295
296 let left = parts[best_idx.unwrap()].clone();
297 let right = parts[best_idx.unwrap() + 1].clone();
298 let merged = format!("{left}{right}");
299
300 let mut next: Vec<String> = Vec::with_capacity(parts.len());
302 let mut j = 0;
303 while j < parts.len() {
304 if j + 1 < parts.len() && parts[j] == left && parts[j + 1] == right {
305 next.push(merged.clone());
306 j += 2;
307 } else {
308 next.push(parts[j].clone());
309 j += 1;
310 }
311 }
312 parts = next;
313 }
314 parts
315 }
316
317 fn lookup(&self, tokens: &[String]) -> Vec<u32> {
320 let mut ids: Vec<u32> = Vec::with_capacity(tokens.len());
321 for tok in tokens {
322 if let Some(&id) = self.vocab.get(tok) {
323 ids.push(id);
324 continue;
325 }
326 if self.byte_fallback_start >= 0 {
327 for &b in tok.as_bytes() {
328 ids.push((self.byte_fallback_start + b as i64) as u32);
329 }
330 }
331 }
333 ids
334 }
335}
336
337impl ITokenizer for BPETokenizer {
338 fn id(&self) -> &str {
339 &self.id
340 }
341 fn encode(&self, text: &str) -> Vec<u32> {
342 BPETokenizer::encode(self, text)
343 }
344}
345
/// Splits `s` into one owned `String` per Unicode scalar value.
fn codepoints(s: &str) -> Vec<String> {
    let per_char = s.chars().map(String::from);
    per_char.collect()
}
351
/// Replaces every run of spaces and/or tabs with a single ASCII space.
///
/// All other characters, including newlines, are copied unchanged.
fn collapse_spaces_and_tabs(s: &str) -> String {
    let mut collapsed = String::with_capacity(s.len());
    let mut in_blank_run = false;
    for ch in s.chars() {
        let is_blank = matches!(ch, ' ' | '\t');
        if is_blank && in_blank_run {
            continue;
        }
        collapsed.push(if is_blank { ' ' } else { ch });
        in_blank_run = is_blank;
    }
    collapsed
}
368
/// Reports whether `tok` looks like a `<|name|>` delimiter token.
///
/// The name between the `<|` and `|>` markers must be non-empty and contain
/// only ASCII letters, digits, `_` or `-`.
fn is_delimiter_shape(tok: &str) -> bool {
    let Some(inner) = tok
        .strip_prefix("<|")
        .and_then(|rest| rest.strip_suffix("|>"))
    else {
        return false;
    };
    let name_char_ok = |c: char| c.is_ascii_alphanumeric() || matches!(c, '_' | '-');
    !inner.is_empty() && inner.chars().all(name_char_ok)
}
387
/// Splits `s` at whitespace and keeps each whitespace character as its own element.
///
/// Maximal runs of non-whitespace become single elements, and the order of
/// the input is preserved. Empty input yields an empty vector.
fn split_keep_whitespace(s: &str) -> Vec<String> {
    let mut out: Vec<String> = Vec::new();
    let mut word_start: Option<usize> = None;
    for (idx, ch) in s.char_indices() {
        if !ch.is_whitespace() {
            word_start.get_or_insert(idx);
            continue;
        }
        if let Some(start) = word_start.take() {
            out.push(s[start..idx].to_string());
        }
        out.push(ch.to_string());
    }
    if let Some(start) = word_start {
        out.push(s[start..].to_string());
    }
    out
}