//! High-level OxiBonsai tokenizer: BPE + Unigram + WordPiece + char-level fallback.
//!
//! [`OxiTokenizer`] ties together a [`Vocabulary`], a [`BpeMerges`] table, and
//! a [`TokenizerConfig`] into a complete encode/decode API that is
//! WASM-compatible.
//!
//! When a [`crate::unigram::UnigramVocab`] is attached via
//! [`OxiTokenizer::with_unigram`], encoding switches to Viterbi segmentation
//! instead of BPE.
//!
//! When a [`crate::wordpiece::WordPieceVocab`] is attached via
//! [`OxiTokenizer::with_wordpiece`], encoding switches to greedy WordPiece
//! segmentation, the algorithm used by BERT and derived models such as
//! DistilBERT.
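//!
//! # Quick start
//!
//! A minimal sketch using the character-level stub (no trained vocab needed;
//! the import path assumes `OxiTokenizer` is re-exported at the crate root):
//!
//! ```
//! use oxibonsai_tokenizer::OxiTokenizer;
//!
//! let tok = OxiTokenizer::char_level_stub(200);
//! let ids = tok.encode("abc").expect("encode should succeed");
//! assert_eq!(ids.len(), 3); // one ID per ASCII character
//! ```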

use std::collections::HashSet;

use tracing::debug;

use crate::{
    bpe::{bpe_encode, byte_fallback_id, pretokenize, BpeMerges},
    error::{TokenizerError, TokenizerResult},
    vocab::Vocabulary,
};

// ── TokenizerConfig ───────────────────────────────────────────────────────────

/// Configuration knobs for an [`OxiTokenizer`].
///
/// Marked `#[non_exhaustive]` so that new optional knobs can be added in
/// future minor releases without breaking downstream code.  Inside this
/// crate, struct literals with `..Default::default()` continue to work.
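///
/// # Example
///
/// Downstream crates cannot use struct-literal syntax (the struct is
/// `#[non_exhaustive]`), but every field is public, so start from
/// `Default::default()` and mutate (crate-root re-export assumed):
///
/// ```
/// use oxibonsai_tokenizer::TokenizerConfig;
///
/// let mut cfg = TokenizerConfig::default();
/// cfg.add_bos = true;
/// cfg.max_length = Some(512);
/// ```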
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct TokenizerConfig {
    /// Whether to prepend a BOS (beginning-of-sequence) token.
    pub add_bos: bool,
    /// Whether to append an EOS (end-of-sequence) token.
    pub add_eos: bool,
    /// Token ID used for BOS.
    pub bos_token_id: u32,
    /// Token ID used for EOS.
    pub eos_token_id: u32,
    /// Token ID used for unknown tokens (fallback).
    pub unk_token_id: u32,
    /// Token ID used for padding.
    pub pad_token_id: u32,
    /// Optional maximum output length (tokens are truncated, not padded).
    pub max_length: Option<usize>,
    /// When `true`, the decoder applies the GPT-2 **bytes ↔ unicode** inverse
    /// map to every token string before emitting bytes (see
    /// [`crate::hf_format`]).  When `false`, the legacy `Ġ`-stripping path is
    /// used (same behaviour as 0.1.x).
    ///
    /// `from_json_file` / `OxiTokenizer::from_hf_tokenizer_json` set this to
    /// `true` automatically; hand-built configs default to `false` for
    /// backwards compatibility.
    pub byte_level_decode: bool,
}

impl Default for TokenizerConfig {
    fn default() -> Self {
        Self {
            add_bos: false,
            add_eos: false,
            bos_token_id: 1,
            eos_token_id: 2,
            unk_token_id: 0,
            pad_token_id: 3,
            max_length: None,
            byte_level_decode: false,
        }
    }
}

// ── OxiTokenizer ─────────────────────────────────────────────────────────────

/// Pure Rust BPE / Unigram / WordPiece tokenizer compatible with MeCrab and the WASM target.
///
/// The tokenizer supports:
/// - Standard BPE encoding via a merge table
/// - Viterbi Unigram encoding (HuggingFace `"Unigram"` model type)
/// - Greedy WordPiece encoding (HuggingFace `"WordPiece"` model type — BERT family)
/// - Optional BOS/EOS injection
/// - Byte-fallback for out-of-vocabulary bytes
/// - Character-level mode (no trained vocab needed — useful in tests)
pub struct OxiTokenizer {
    vocab: Vocabulary,
    merges: BpeMerges,
    config: TokenizerConfig,
    /// The set of special token IDs for quick membership tests.
    special_ids: HashSet<u32>,
    /// Optional Unigram vocabulary for Viterbi-based segmentation.
    ///
    /// When `Some` (and no WordPiece vocab is attached), the tokenizer
    /// dispatches to Unigram encoding instead of BPE.  When `None`, the BPE
    /// or WordPiece path is used.
    unigram: Option<crate::unigram::UnigramVocab>,
    /// Optional WordPiece vocabulary for BERT-style greedy segmentation.
    ///
    /// When `Some`, the tokenizer dispatches to WordPiece encoding.  This
    /// takes precedence over both the Unigram and BPE paths; it is checked
    /// first in [`Self::encode`].  When `None`, the Unigram path (if
    /// attached) or BPE is used.
    wordpiece: Option<crate::wordpiece::WordPieceVocab>,
}

impl OxiTokenizer {
    /// Construct a tokenizer from pre-built components.
    ///
    /// Sets `unigram` and `wordpiece` to `None` — the BPE path is used for
    /// encoding.
    pub fn new(vocab: Vocabulary, merges: BpeMerges, config: TokenizerConfig) -> Self {
        let special_ids = build_special_ids(&config);
        Self {
            vocab,
            merges,
            config,
            special_ids,
            unigram: None,
            wordpiece: None,
        }
    }

    /// Construct a Unigram tokenizer from pre-built components.
    ///
    /// The `unigram_vocab` is used for Viterbi-based segmentation; the `vocab`
    /// is kept for decode operations (ID → token string).  An empty
    /// [`BpeMerges`] table is stored for API consistency.
    pub fn with_unigram(
        vocab: Vocabulary,
        unigram_vocab: crate::unigram::UnigramVocab,
        config: TokenizerConfig,
    ) -> Self {
        let special_ids = build_special_ids(&config);
        Self {
            vocab,
            merges: BpeMerges::new(),
            config,
            special_ids,
            unigram: Some(unigram_vocab),
            wordpiece: None,
        }
    }

    /// Construct a WordPiece tokenizer from pre-built components.
    ///
    /// The `wordpiece_vocab` is used for greedy longest-match-first
    /// segmentation (the BERT model family, e.g. BERT and DistilBERT); the
    /// `vocab` is kept for decode operations (ID → token string).  An empty
    /// [`BpeMerges`] table is stored for API consistency.
    pub fn with_wordpiece(
        vocab: Vocabulary,
        wordpiece_vocab: crate::wordpiece::WordPieceVocab,
        config: TokenizerConfig,
    ) -> Self {
        let special_ids = build_special_ids(&config);
        Self {
            vocab,
            merges: BpeMerges::new(),
            config,
            special_ids,
            unigram: None,
            wordpiece: Some(wordpiece_vocab),
        }
    }

    /// Return `true` if this tokenizer uses Unigram (Viterbi) segmentation.
    pub fn is_unigram(&self) -> bool {
        self.unigram.is_some()
    }

    /// Return `true` if this tokenizer uses WordPiece (BERT-family) segmentation.
    pub fn is_wordpiece(&self) -> bool {
        self.wordpiece.is_some()
    }

    /// Encode a single text string into a sequence of token IDs.
    ///
    /// Steps:
    /// 1. Pre-tokenize into words (the WordPiece path skips this and splits
    ///    on whitespace internally).
    /// 2. Encode via WordPiece (if attached), Unigram Viterbi (if attached),
    ///    or BPE with byte-fallback for unmergeable words.
    /// 3. Optionally prepend BOS and append EOS.
    /// 4. Optionally truncate to `config.max_length`.
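    ///
    /// # Example
    ///
    /// A minimal sketch using the character-level stub (crate-root re-export
    /// of `OxiTokenizer` assumed):
    ///
    /// ```
    /// use oxibonsai_tokenizer::OxiTokenizer;
    ///
    /// let tok = OxiTokenizer::char_level_stub(200);
    /// let ids = tok.encode("hi").expect("encode should succeed");
    /// assert_eq!(ids.len(), 2); // one ID per character, no BOS/EOS by default
    /// ```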
    pub fn encode(&self, text: &str) -> TokenizerResult<Vec<u32>> {
        debug!(text_len = text.len(), "encoding text");

        let mut ids: Vec<u32> = Vec::new();

        if self.config.add_bos {
            ids.push(self.config.bos_token_id);
        }

        if let Some(wp) = &self.wordpiece {
            // WordPiece path: greedy longest-match-first segmentation of the
            // full text (the WordPieceVocab splits on whitespace internally).
            let wp_ids = wp.encode(text);
            ids.extend_from_slice(&wp_ids);
        } else {
            let words = pretokenize(text);
            for word in &words {
                if let Some(unigram) = &self.unigram {
                    // Unigram path: Viterbi segmentation directly on the word.
                    let word_ids = unigram.encode(word);
                    ids.extend_from_slice(&word_ids);
                } else {
                    // BPE path: apply merge table.
                    let word_ids = bpe_encode(word, &self.vocab, &self.merges);
                    if word_ids.is_empty() {
                        // Byte-fallback path: encode each UTF-8 byte explicitly.
                        for byte in word.as_bytes() {
                            let fallback = byte_fallback_id(*byte);
                            let fallback_id = self.vocab.get_id(&fallback);
                            ids.push(fallback_id.unwrap_or(self.config.unk_token_id));
                        }
                    } else {
                        ids.extend_from_slice(&word_ids);
                    }
                }
            }
        }

        if self.config.add_eos {
            ids.push(self.config.eos_token_id);
        }

        // Truncate if configured.
        if let Some(max) = self.config.max_length {
            ids.truncate(max);
        }

        Ok(ids)
    }

    /// Encode a batch of texts in sequence (returns one `Vec<u32>` per input).
    pub fn encode_batch(&self, texts: &[&str]) -> TokenizerResult<Vec<Vec<u32>>> {
        texts.iter().map(|t| self.encode(t)).collect()
    }

    /// Decode a sequence of token IDs back into a string.
    ///
    /// Special tokens (BOS, EOS, PAD, UNK) are silently skipped.
    /// Byte-fallback tokens (`<0xHH>`) are decoded back to their original byte.
    /// Unknown IDs that are not in the vocabulary produce `\u{FFFD}` (replacement
    /// character) rather than an error, to be maximally robust.
    ///
    /// When `config.byte_level_decode` is `true`, tokens are run through the
    /// full 256-entry GPT-2 **unicode → byte** inverse map (see
    /// [`crate::hf_format`]).  Otherwise the legacy `Ġ`-stripping path is used.
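    ///
    /// # Example
    ///
    /// A round trip through the character-level stub (a single word, so no
    /// whitespace handling is involved; crate-root re-export assumed):
    ///
    /// ```
    /// use oxibonsai_tokenizer::OxiTokenizer;
    ///
    /// let tok = OxiTokenizer::char_level_stub(200);
    /// let ids = tok.encode("hi").expect("encode should succeed");
    /// let text = tok.decode(&ids).expect("decode should succeed");
    /// assert_eq!(text, "hi");
    /// ```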
    pub fn decode(&self, ids: &[u32]) -> TokenizerResult<String> {
        let bytes = self.decode_to_bytes(ids);
        String::from_utf8(bytes).map_err(|e| TokenizerError::DecodeFailed(e.to_string()))
    }

    /// Decode to raw bytes — used by both [`Self::decode`] and the streaming
    /// decoder so that the two paths stay byte-for-byte identical.
    pub(crate) fn decode_to_bytes(&self, ids: &[u32]) -> Vec<u8> {
        let mut bytes: Vec<u8> = Vec::with_capacity(ids.len() * 2);

        for &id in ids {
            self.decode_id_into(id, &mut bytes);
        }

        bytes
    }

    /// Append the UTF-8 bytes for a single token ID to `bytes`.
    ///
    /// Special tokens are silently dropped.  Unknown IDs produce `\u{FFFD}`.
    pub(crate) fn decode_id_into(&self, id: u32, bytes: &mut Vec<u8>) {
        if self.special_ids.contains(&id) {
            return;
        }

        let token = match self.vocab.get_token(id) {
            Some(t) => t,
            None => {
                bytes.extend_from_slice("\u{FFFD}".as_bytes());
                return;
            }
        };

        // Byte-fallback tokens: `<0xHH>` → raw byte.
        if let Some(byte) = parse_byte_fallback(token) {
            bytes.push(byte);
            return;
        }

        if self.config.byte_level_decode {
            // Full GPT-2 bytes-to-unicode inverse mapping.
            for ch in token.chars() {
                if let Some(b) = crate::hf_format::unicode_to_byte(ch) {
                    bytes.push(b);
                } else {
                    // Non-byte-level character — emit UTF-8 verbatim.
                    let mut buf = [0u8; 4];
                    let s = ch.encode_utf8(&mut buf);
                    bytes.extend_from_slice(s.as_bytes());
                }
            }
        } else {
            // Legacy `Ġ`-stripping path — kept bit-for-bit identical to 0.1.x.
            let stripped = token.trim_start_matches('\u{0120}');
            if token.starts_with('\u{0120}') && !bytes.is_empty() {
                bytes.push(b' ');
            }
            bytes.extend_from_slice(stripped.as_bytes());
        }
    }

    /// Decode a single token ID to its string representation.
    pub fn decode_token(&self, id: u32) -> TokenizerResult<String> {
        self.vocab
            .get_token(id)
            .map(|s| s.to_owned())
            .ok_or_else(|| TokenizerError::DecodeFailed(format!("unknown token id {id}")))
    }

    /// Return the total vocabulary size.
    pub fn vocab_size(&self) -> usize {
        self.vocab.size()
    }

    /// Construct a tokenizer from JSON-encoded vocabulary and merge lists.
    ///
    /// `vocab_json`: `{ "token": id, ... }`
    /// `merges_json`: `[["a", "b"], ...]` (ordered from highest to lowest priority)
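    ///
    /// # Example
    ///
    /// A minimal sketch mirroring the unit test below (crate-root re-export
    /// assumed):
    ///
    /// ```
    /// use oxibonsai_tokenizer::{OxiTokenizer, TokenizerConfig};
    ///
    /// let vocab = r#"{"a":10,"b":11,"ab":20,"<unk>":0,"<bos>":1,"<eos>":2,"<pad>":3}"#;
    /// let merges = r#"[["a","b"]]"#;
    /// let tok = OxiTokenizer::from_json(vocab, merges, TokenizerConfig::default())
    ///     .expect("from_json should succeed");
    /// // "a" + "b" merges into the single token with ID 20.
    /// assert!(tok.encode("ab").expect("encode").contains(&20));
    /// ```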
    pub fn from_json(
        vocab_json: &str,
        merges_json: &str,
        config: TokenizerConfig,
    ) -> TokenizerResult<Self> {
        let vocab = Vocabulary::from_json(vocab_json)?;

        let raw_merges: Vec<(String, String)> = serde_json::from_str(merges_json)
            .map_err(|e| TokenizerError::InvalidJson(e.to_string()))?;

        let mut merges = BpeMerges::new();
        for (a, b) in &raw_merges {
            // The merged token name is the concatenation.
            let merged = format!("{a}{b}");
            let result_id = vocab.get_id(&merged).ok_or_else(|| {
                TokenizerError::InvalidVocab(format!("merged token {merged:?} not in vocabulary"))
            })?;
            merges.add_merge(a, b, result_id);
        }

        Ok(Self::new(vocab, merges, config))
    }

    /// Load a tokenizer from a HuggingFace-style `tokenizer.json` file.
    ///
    /// This routes through [`crate::hf_format::HfTokenizerJson`] which:
    ///
    /// 1. Parses the `model.vocab` map (token → id).
    /// 2. Parses the `model.merges` list (both string-pair and array-pair forms).
    /// 3. Picks up the `added_tokens` / `special_tokens` block.
    /// 4. Sets `byte_level_decode = true` on the returned config so that
    ///    `decode()` correctly reverses the GPT-2 bytes-to-unicode map.
    ///
    /// Any field not expressible in [`TokenizerConfig`] (truncation policy,
    /// normalizer variants, ...) is ignored but does not cause an error so
    /// that loading a live HF file "just works".
    pub fn from_json_file(path: impl AsRef<std::path::Path>) -> TokenizerResult<Self> {
        let json = std::fs::read_to_string(path)?;
        Self::from_hf_tokenizer_json(&json)
    }

    /// In-memory variant of [`Self::from_json_file`] that takes the JSON as a
    /// `&str`.  Useful for WASM builds and for tests that embed a tokenizer
    /// fixture verbatim.
    pub fn from_hf_tokenizer_json(json: &str) -> TokenizerResult<Self> {
        let parsed = crate::hf_format::HfTokenizerJson::parse(json)?;
        parsed.into_tokenizer()
    }

    /// Begin streaming decode.  Returns a [`crate::streaming::StreamingDecoder`]
    /// that keeps UTF-8 state across `push_token` calls — essential for server
    /// code that emits one token at a time.
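    ///
    /// # Example
    ///
    /// A hypothetical sketch; the exact `StreamingDecoder` API (in particular
    /// the signature and return type of `push_token`) lives in
    /// `crate::streaming` and is assumed here, not confirmed by this file:
    ///
    /// ```ignore
    /// let tok = OxiTokenizer::char_level_stub(200);
    /// let mut dec = tok.streaming_decoder();
    /// for id in tok.encode("hi").unwrap() {
    ///     // push_token is assumed to return any newly completed UTF-8 text.
    ///     let fragment = dec.push_token(id);
    ///     print!("{fragment:?}");
    /// }
    /// ```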
    pub fn streaming_decoder(&self) -> crate::streaming::StreamingDecoder<'_> {
        crate::streaming::StreamingDecoder::new(self)
    }

    /// Access the tokenizer configuration (read-only).
    pub fn config(&self) -> &TokenizerConfig {
        &self.config
    }

    /// Access the vocabulary (read-only).
    pub fn vocab(&self) -> &Vocabulary {
        &self.vocab
    }

    /// Access the merge table (read-only).
    pub fn merges(&self) -> &BpeMerges {
        &self.merges
    }

    /// Create a character-level tokenizer (no trained merges) for testing
    /// and examples.
    ///
    /// Assigns IDs 4..vocab_size to printable ASCII characters (space = 4,
    /// '!' = 5, ...) with IDs 0-3 reserved for UNK/BOS/EOS/PAD.
    ///
    /// This tokenizer has no BPE merges: each character is its own token.
    /// The `_stub` suffix is retained for API compatibility.
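    ///
    /// # Example
    ///
    /// ```
    /// use oxibonsai_tokenizer::OxiTokenizer; // crate-root re-export assumed
    ///
    /// let tok = OxiTokenizer::char_level_stub(50);
    /// assert!(tok.is_special(0) && tok.is_special(3)); // IDs 0-3 reserved
    /// assert!(tok.vocab_size() <= 50);
    /// ```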
    pub fn char_level_stub(vocab_size: usize) -> Self {
        assert!(
            vocab_size >= 4,
            "char_level_stub requires vocab_size >= 4 for special tokens"
        );

        let mut vocab = Vocabulary::new();
        vocab.add_special("<unk>", 0);
        vocab.add_special("<bos>", 1);
        vocab.add_special("<eos>", 2);
        vocab.add_special("<pad>", 3);

        // Fill remaining slots with printable ASCII characters.
        let mut next_id = 4u32;
        for byte in 0x20u8..=0x7Eu8 {
            if next_id as usize >= vocab_size {
                break;
            }
            let ch = char::from(byte);
            vocab.insert(&ch.to_string(), next_id);
            next_id += 1;
        }

        // Also populate byte-fallback tokens for any remaining slots.
        for byte in 0u8..=255u8 {
            if next_id as usize >= vocab_size {
                break;
            }
            let fallback = byte_fallback_id(byte);
            if vocab.get_id(&fallback).is_none() {
                vocab.insert(&fallback, next_id);
                next_id += 1;
            }
        }

        let config = TokenizerConfig {
            add_bos: false,
            add_eos: false,
            bos_token_id: 1,
            eos_token_id: 2,
            unk_token_id: 0,
            pad_token_id: 3,
            max_length: None,
            byte_level_decode: false,
        };

        let merges = BpeMerges::new();
        // Use Self::new which initialises both unigram and wordpiece to None.
        Self::new(vocab, merges, config)
    }

    // ── Special token helpers ─────────────────────────────────────────────

    /// Return the BOS token ID from the configuration.
    pub fn bos_id(&self) -> u32 {
        self.config.bos_token_id
    }

    /// Return the EOS token ID from the configuration.
    pub fn eos_id(&self) -> u32 {
        self.config.eos_token_id
    }

    /// Return `true` if `id` is one of the configured special token IDs.
    pub fn is_special(&self, id: u32) -> bool {
        self.special_ids.contains(&id)
    }
}

// ── Private helpers ───────────────────────────────────────────────────────────

/// Build the set of special token IDs from a config.
fn build_special_ids(config: &TokenizerConfig) -> HashSet<u32> {
    let mut set = HashSet::new();
    set.insert(config.bos_token_id);
    set.insert(config.eos_token_id);
    set.insert(config.unk_token_id);
    set.insert(config.pad_token_id);
    set
}

/// Parse a byte-fallback token like `<0x41>` and return the byte value.
///
/// Returns `None` if the token is not in the `<0xHH>` format.
fn parse_byte_fallback(token: &str) -> Option<u8> {
    let inner = token.strip_prefix("<0x")?.strip_suffix('>')?;
    if inner.len() != 2 {
        return None;
    }
    u8::from_str_radix(inner, 16).ok()
}

// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn char_level_stub_encode_ascii() {
        let tok = OxiTokenizer::char_level_stub(200);
        let ids = tok.encode("ab").expect("encode should succeed");
        // Each char should map to a consistent non-zero ID.
        assert_eq!(ids.len(), 2);
        assert_ne!(ids[0], 0); // not UNK
        assert_ne!(ids[1], 0);
        assert_ne!(ids[0], ids[1]); // 'a' ≠ 'b'
    }

    #[test]
    fn char_level_stub_bos_eos() {
        let mut tok = OxiTokenizer::char_level_stub(200);
        tok.config.add_bos = true;
        tok.config.add_eos = true;
        tok.special_ids = build_special_ids(&tok.config);
        let ids = tok.encode("hi").expect("encode should succeed");
        assert_eq!(ids[0], 1); // BOS
        assert_eq!(*ids.last().expect("must have last element"), 2); // EOS
    }

    #[test]
    fn char_level_stub_vocab_size() {
        let tok = OxiTokenizer::char_level_stub(50);
        assert!(tok.vocab_size() <= 50);
        assert!(tok.vocab_size() >= 4); // at least special tokens
    }

    #[test]
    fn special_token_detection() {
        let tok = OxiTokenizer::char_level_stub(200);
        assert!(tok.is_special(0)); // UNK
        assert!(tok.is_special(1)); // BOS
        assert!(tok.is_special(2)); // EOS
        assert!(tok.is_special(3)); // PAD
        assert!(!tok.is_special(4)); // first real token
    }

    #[test]
    fn bos_eos_ids_match_config() {
        let tok = OxiTokenizer::char_level_stub(200);
        assert_eq!(tok.bos_id(), 1);
        assert_eq!(tok.eos_id(), 2);
    }

    #[test]
    fn decode_token_roundtrip() {
        let tok = OxiTokenizer::char_level_stub(200);
        // 'a' should map to some ID; we can look it up.
        let ids = tok.encode("a").expect("should encode");
        if let Some(&id) = ids.first() {
            let s = tok.decode_token(id).expect("decode_token should succeed");
            assert_eq!(s, "a");
        }
    }

    #[test]
    fn decode_unknown_id_returns_error() {
        let tok = OxiTokenizer::char_level_stub(50);
        let result = tok.decode_token(99_999);
        assert!(result.is_err());
    }

    #[test]
    fn max_length_truncates() {
        let mut tok = OxiTokenizer::char_level_stub(200);
        tok.config.max_length = Some(3);
        tok.special_ids = build_special_ids(&tok.config);
        let ids = tok.encode("hello world").expect("encode should succeed");
        assert!(ids.len() <= 3);
    }

    #[test]
    fn encode_batch_consistency() {
        let tok = OxiTokenizer::char_level_stub(200);
        let texts = ["ab", "cd", "ef"];
        let batch = tok
            .encode_batch(&texts)
            .expect("batch encode should succeed");
        assert_eq!(batch.len(), 3);
        for (i, ids) in batch.iter().enumerate() {
            let single = tok.encode(texts[i]).expect("single encode should succeed");
            assert_eq!(*ids, single);
        }
    }

    #[test]
    fn parse_byte_fallback_valid() {
        assert_eq!(parse_byte_fallback("<0x41>"), Some(0x41));
        assert_eq!(parse_byte_fallback("<0x00>"), Some(0x00));
        assert_eq!(parse_byte_fallback("<0xFF>"), Some(0xFF));
    }

    #[test]
    fn parse_byte_fallback_invalid() {
        assert_eq!(parse_byte_fallback("hello"), None);
        assert_eq!(parse_byte_fallback("<0x>"), None);
        assert_eq!(parse_byte_fallback("<0x1>"), None);
    }

    #[test]
    fn from_json_roundtrip() {
        let vocab_json = r#"{"a":10,"b":11,"ab":20,"<unk>":0,"<bos>":1,"<eos>":2,"<pad>":3}"#;
        let merges_json = r#"[["a","b"]]"#;
        let config = TokenizerConfig::default();
        let tok = OxiTokenizer::from_json(vocab_json, merges_json, config)
            .expect("from_json should succeed");
        assert_eq!(tok.vocab_size(), 7);
        // Encoding "ab" should produce a single merged token 20.
        let ids = tok.encode("ab").expect("encode should succeed");
        assert!(ids.contains(&20));
    }

    #[test]
    fn is_unigram_false_for_bpe() {
        let tok = OxiTokenizer::char_level_stub(200);
        assert!(!tok.is_unigram());
    }
}