Skip to main content

codec_rs/
tokenize.rs

1// SPDX-License-Identifier: MIT
2//! Pure-Rust BPE encoder. Text → token IDs.
3//!
4//! Required for the bidirectional Codec endpoint where the client wants
5//! to send token-ID prompts (zero text on the wire in either direction).
6//!
7//! ## Algorithm (for both byte_level and metaspace BPE)
8//!
9//! 1. Pre-tokenize: split input into pieces (regex for byte_level; whitespace for metaspace).
10//! 2. Encode each piece into the vocab's character space (GPT-2 byte chars or `▁`-prefixed).
11//! 3. Apply BPE merges greedily by priority — match HuggingFace reference.
12//! 4. Look up final tokens in `vocab`. Tokens not in vocab fall back to byte tokens (metaspace path).
13
14use std::cell::RefCell;
15use std::collections::HashMap;
16
17use regex::Regex;
18
19use crate::byte_encoder::{encode_byte_level_chars, METASPACE};
20use crate::map::TokenizerMap;
21
/// Shared contract for every tokenizer implementation.
///
/// Implementors include [`BPETokenizer`] and
/// [`crate::longest_match::LongestMatchTokenizer`].
///
/// `Sync` is intentionally not a supertrait. `BPETokenizer` stores its encode
/// cache in a `RefCell` (the counterpart of the .NET `Dictionary`), so it cannot
/// be shared by reference across threads. Put it behind a `Mutex` when several
/// threads need the same instance.
pub trait ITokenizer: Send {
    /// Converts `text` into a sequence of token IDs.
    fn encode(&self, text: &str) -> Vec<u32>;

    /// Returns the identifier of the vocabulary backing this tokenizer.
    fn id(&self) -> &str;
}
35
36/// Pure-Rust BPE encoder.
37///
38/// Construct via [`BPETokenizer::new`]; check
39/// [`BPETokenizer::supports`] first if you don't know whether the map has
40/// the data BPE needs (use [`crate::Tokenize::pick`] which falls back).
41pub struct BPETokenizer {
42    id: String,
43    vocab: HashMap<String, u32>,
44    /// Map of `"left right"` → priority rank (lower = higher priority).
45    merge_ranks: HashMap<String, u32>,
46    pre_tok_regex: Option<Regex>,
47    /// Compiled pre-tokenizer program; preferred over the regex when present.
48    /// Bypasses the regex engine entirely — unblocks GPT-2-family maps whose
49    /// `(?i:...)` and `(?!\S)` syntax the `regex` crate doesn't support.
50    pre_tok_program: Option<crate::pretok_program::PreTokProgram>,
51    encoder: String,
52    /// `i64` so a missing fallback (-1) is comparable safely against IDs.
53    byte_fallback_start: i64,
54    /// Per-piece encode cache. Mutex-free RefCell — BPETokenizer is `!Sync`
55    /// but `Send`. (Translator only needs `Send` and constructs its own.)
56    cache: RefCell<HashMap<String, Vec<u32>>>,
57    /// Special-token scanner. Built from `map.special_tokens` plus any vocab
58    /// key in `<|body|>` shape with a non-empty identifier-like body. HF's
59    /// reference tokenizer splits input on registered specials BEFORE running
60    /// BPE — emit each match as the atomic vocab ID, BPE the surrounding
61    /// text. Required for chat templates (`<|im_start|>...<|im_end|>`),
62    /// tool-call delimiters, FIM markers, etc. to round-trip with HF.
63    special_ids: HashMap<String, u32>,
64    special_regex: Option<Regex>,
65}
66
67impl BPETokenizer {
68    /// Returns true if the map has the data BPETokenizer needs.
69    pub fn supports(map: &TokenizerMap) -> bool {
70        let has_vocab = map.vocab.as_ref().is_some_and(|v| !v.is_empty());
71        let has_merges = map.merges.as_ref().is_some_and(|v| !v.is_empty());
72        let enc_ok = matches!(map.encoder.as_deref(), Some("byte_level") | Some("metaspace"));
73        has_vocab && has_merges && enc_ok
74    }
75
76    /// Construct a `BPETokenizer` from a map. Returns an error message if
77    /// the map lacks the required vocab/merges/encoder.
78    pub fn new(map: &TokenizerMap) -> Result<Self, String> {
79        if !Self::supports(map) {
80            return Err(format!(
81                "BPETokenizer: map \"{}\" lacks vocab/merges/encoder. \
82                 Use BPETokenizer::supports(map) to check first, or call \
83                 Tokenize::pick(map) which falls back to LongestMatchTokenizer.",
84                map.id
85            ));
86        }
87
88        let vocab = map.vocab.as_ref().expect("supports() checked").clone();
89        let merges = map.merges.as_ref().expect("supports() checked");
90        let encoder = map.encoder.as_ref().expect("supports() checked").clone();
91        let id = map.id.clone();
92        let byte_fallback_start = map.byte_fallback_start.unwrap_or(-1);
93
94        let mut merge_ranks: HashMap<String, u32> = HashMap::with_capacity(merges.len());
95        for (i, m) in merges.iter().enumerate() {
96            merge_ranks.insert(m.clone(), i as u32);
97        }
98
99        // Pre-tokenizer: prefer the compiled program when present, otherwise
100        // fall back to the legacy regex. Programs bypass the regex engine
101        // entirely — required for GPT-2-family maps because `regex` doesn't
102        // support `(?i:...)` inline-flag groups or `(?!\S)` lookaround.
103        let (pre_tok_regex, pre_tok_program) = if encoder == "byte_level" {
104            if let Some(prog) = map.pre_tokenizer_program.as_ref() {
105                if prog.ops.is_empty() {
106                    return Err(format!(
107                        "BPETokenizer: byte_level map \"{}\" has empty pre_tokenizer_program.",
108                        map.id
109                    ));
110                }
111                (None, Some(prog.clone()))
112            } else if let Some(pat) = map.pre_tokenizer_pattern.as_ref() {
113                let re = Regex::new(pat)
114                    .map_err(|e| format!("BPETokenizer: invalid pre_tokenizer_pattern: {e}"))?;
115                (Some(re), None)
116            } else {
117                return Err(format!(
118                    "BPETokenizer: byte_level map \"{}\" missing both pre_tokenizer_program and pre_tokenizer_pattern.",
119                    map.id
120                ));
121            }
122        } else {
123            (None, None)
124        };
125
126        // Build the special-token scanner. Accept entries from
127        // `map.special_tokens` AND any vocab key in `<|body|>` shape
128        // with a non-empty identifier-like body — older maps shipped
129        // before a chat-template revision may carry the delimiters in
130        // `vocab` but not in `special_tokens`. Length-descending regex
131        // alternation order so longer delimiters match before shorter
132        // prefixes. Without this pre-scan, `<|im_start|>` would
133        // tokenise byte-by-byte instead of as the single atomic vocab
134        // ID (151644 for Qwen-2.5).
135        let mut special_ids: HashMap<String, u32> = HashMap::new();
136        if let Some(specials) = map.special_tokens.as_ref() {
137            for (name, id) in specials.iter() {
138                special_ids.insert(name.clone(), *id);
139            }
140        }
141        for (tok, id) in vocab.iter() {
142            if special_ids.contains_key(tok) {
143                continue;
144            }
145            if is_delimiter_shape(tok) {
146                special_ids.insert(tok.clone(), *id);
147            }
148        }
149        let special_regex = if special_ids.is_empty() {
150            None
151        } else {
152            let mut keys: Vec<&String> = special_ids.keys().collect();
153            keys.sort_by_key(|k| std::cmp::Reverse(k.len()));
154            let alt = keys
155                .iter()
156                .map(|k| regex::escape(k))
157                .collect::<Vec<_>>()
158                .join("|");
159            Some(
160                Regex::new(&alt)
161                    .map_err(|e| format!("BPETokenizer: bad special-token regex: {e}"))?,
162            )
163        };
164
165        Ok(Self {
166            id,
167            vocab,
168            merge_ranks,
169            pre_tok_regex,
170            pre_tok_program,
171            encoder,
172            byte_fallback_start,
173            cache: RefCell::new(HashMap::new()),
174            special_ids,
175            special_regex,
176        })
177    }
178
179    /// Encode text → token IDs.
180    pub fn encode(&self, text: &str) -> Vec<u32> {
181        if text.is_empty() {
182            return Vec::new();
183        }
184
185        if let Some(re) = self.special_regex.as_ref() {
186            let mut ids: Vec<u32> = Vec::new();
187            let mut cursor = 0usize;
188            for m in re.find_iter(text) {
189                if m.start() > cursor {
190                    self.encode_chunk(&text[cursor..m.start()], &mut ids);
191                }
192                ids.push(self.special_ids[m.as_str()]);
193                cursor = m.end();
194            }
195            if cursor < text.len() {
196                self.encode_chunk(&text[cursor..], &mut ids);
197            }
198            return ids;
199        }
200
201        let mut ids: Vec<u32> = Vec::new();
202        self.encode_chunk(text, &mut ids);
203        ids
204    }
205
206    /// BPE-encode a chunk of plain text into `out`.
207    fn encode_chunk(&self, text: &str, out: &mut Vec<u32>) {
208        if text.is_empty() {
209            return;
210        }
211        let pieces = self.pre_tokenize(text);
212        for piece in pieces {
213            if let Ok(cache) = self.cache.try_borrow() {
214                if let Some(cached) = cache.get(&piece) {
215                    out.extend_from_slice(cached);
216                    continue;
217                }
218            }
219            let encoded = self.encode_piece_to_vocab_space(&piece);
220            let merged = self.apply_bpe(encoded);
221            let piece_ids = self.lookup(&merged);
222            if let Ok(mut cache) = self.cache.try_borrow_mut() {
223                cache.insert(piece.clone(), piece_ids.clone());
224            }
225            out.extend_from_slice(&piece_ids);
226        }
227    }
228
229    // ── Pre-tokenization ────────────────────────────────────────────────────
230
231    fn pre_tokenize(&self, text: &str) -> Vec<String> {
232        if self.encoder == "byte_level" {
233            if let Some(prog) = self.pre_tok_program.as_ref() {
234                return crate::pretok_program::run_pretok_program(prog, text);
235            }
236            let re = self.pre_tok_regex.as_ref().expect("byte_level requires regex or program");
237            return re.find_iter(text).map(|m| m.as_str().to_string()).collect();
238        }
239
240        // Metaspace: split on whitespace, prefix every word with ▁.
241        // Collapse internal runs of spaces/tabs to a single space first
242        // (matches .NET).
243        let collapsed = collapse_spaces_and_tabs(text);
244        let parts = split_keep_whitespace(&collapsed);
245        let mut pieces: Vec<String> = Vec::new();
246        for p in parts {
247            if p == " " {
248                continue;
249            }
250            let mut s = String::with_capacity(p.len() + 3);
251            s.push(METASPACE);
252            s.push_str(&p);
253            pieces.push(s);
254        }
255        pieces
256    }
257
258    // ── Step 2: piece → vocab character space ──────────────────────────────
259
260    fn encode_piece_to_vocab_space(&self, piece: &str) -> Vec<String> {
261        if self.encoder == "byte_level" {
262            let bytes = piece.as_bytes();
263            let encoded = encode_byte_level_chars(bytes);
264            return codepoints(&encoded);
265        }
266        codepoints(piece)
267    }
268
269    // ── Step 3: BPE merges ─────────────────────────────────────────────────
270
271    fn apply_bpe(&self, tokens: Vec<String>) -> Vec<String> {
272        if tokens.len() < 2 {
273            return tokens;
274        }
275        let mut parts = tokens;
276        loop {
277            let mut best_idx: Option<usize> = None;
278            let mut best_rank: u32 = u32::MAX;
279            for i in 0..parts.len() - 1 {
280                // Build "left right" without alloc churn — small strings only here.
281                let mut key = String::with_capacity(parts[i].len() + 1 + parts[i + 1].len());
282                key.push_str(&parts[i]);
283                key.push(' ');
284                key.push_str(&parts[i + 1]);
285                if let Some(&r) = self.merge_ranks.get(&key) {
286                    if r < best_rank {
287                        best_rank = r;
288                        best_idx = Some(i);
289                    }
290                }
291            }
292            let Some(_idx) = best_idx else {
293                break;
294            };
295
296            let left = parts[best_idx.unwrap()].clone();
297            let right = parts[best_idx.unwrap() + 1].clone();
298            let merged = format!("{left}{right}");
299
300            // Merge ALL non-overlapping occurrences in one pass — matches HF.
301            let mut next: Vec<String> = Vec::with_capacity(parts.len());
302            let mut j = 0;
303            while j < parts.len() {
304                if j + 1 < parts.len() && parts[j] == left && parts[j + 1] == right {
305                    next.push(merged.clone());
306                    j += 2;
307                } else {
308                    next.push(parts[j].clone());
309                    j += 1;
310                }
311            }
312            parts = next;
313        }
314        parts
315    }
316
317    // ── Step 4: vocab lookup with byte fallback ────────────────────────────
318
319    fn lookup(&self, tokens: &[String]) -> Vec<u32> {
320        let mut ids: Vec<u32> = Vec::with_capacity(tokens.len());
321        for tok in tokens {
322            if let Some(&id) = self.vocab.get(tok) {
323                ids.push(id);
324                continue;
325            }
326            if self.byte_fallback_start >= 0 {
327                for &b in tok.as_bytes() {
328                    ids.push((self.byte_fallback_start + b as i64) as u32);
329                }
330            }
331            // For byte_level this is unreachable for valid UTF-8 input.
332        }
333        ids
334    }
335}
336
337impl ITokenizer for BPETokenizer {
338    fn id(&self) -> &str {
339        &self.id
340    }
341    fn encode(&self, text: &str) -> Vec<u32> {
342        BPETokenizer::encode(self, text)
343    }
344}
345
346// Helpers ------------------------------------------------------------------
347
/// Splits `s` into one owned `String` per Unicode scalar value, in order.
fn codepoints(s: &str) -> Vec<String> {
    let mut symbols = Vec::new();
    for ch in s.chars() {
        symbols.push(String::from(ch));
    }
    symbols
}
351
/// Replaces every run of spaces and/or tabs with a single ASCII space.
/// Other whitespace, such as `\n`, is copied through unchanged.
fn collapse_spaces_and_tabs(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut in_run = false;
    for ch in s.chars() {
        let blank = matches!(ch, ' ' | '\t');
        if !blank {
            result.push(ch);
        } else if !in_run {
            result.push(' ');
        }
        in_run = blank;
    }
    result
}
368
/// Reports whether `tok` looks like a `<|body|>` delimiter.
///
/// `body` must be non-empty and consist only of ASCII letters, digits, `_` or
/// `-`. That shape covers every shipped chat-template and tool-call delimiter.
/// It rejects odd BPE vocab entries that merely share the `<|`/`|>` ends, such
/// as Falcon's `<|>` (id 61799).
fn is_delimiter_shape(tok: &str) -> bool {
    let Some(body) = tok.strip_prefix("<|").and_then(|rest| rest.strip_suffix("|>")) else {
        return false;
    };
    let ident_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_' || b == b'-';
    !body.is_empty() && body.bytes().all(ident_byte)
}
387
/// Splits `s` at whitespace and keeps every whitespace char as a segment of
/// its own. Runs of non-whitespace become single segments, and no empty
/// segments are produced. Mirrors .NET `Regex.Split(text, "(\\s)")`.
fn split_keep_whitespace(s: &str) -> Vec<String> {
    let mut segments = Vec::new();
    let mut word_start: Option<usize> = None;
    for (idx, ch) in s.char_indices() {
        if ch.is_whitespace() {
            if let Some(start) = word_start.take() {
                segments.push(s[start..idx].to_string());
            }
            segments.push(ch.to_string());
        } else if word_start.is_none() {
            word_start = Some(idx);
        }
    }
    if let Some(start) = word_start {
        segments.push(s[start..].to_string());
    }
    segments
}