oxibonsai_tokenizer/tokenizer.rs
//! High-level OxiBonsai tokenizer: BPE + Unigram + WordPiece + char-level fallback.
//!
//! [`OxiTokenizer`] ties together a [`Vocabulary`], a [`BpeMerges`] table, and
//! a [`TokenizerConfig`] into a complete encode/decode API that is
//! WASM-compatible.
//!
//! When a [`crate::unigram::UnigramVocab`] is attached via
//! [`OxiTokenizer::with_unigram`], encoding switches to Viterbi segmentation
//! instead of BPE.
//!
//! When a [`crate::wordpiece::WordPieceVocab`] is attached via
//! [`OxiTokenizer::with_wordpiece`], encoding switches to greedy WordPiece
//! segmentation, the algorithm used by BERT and its derivatives such as
//! DistilBERT.
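//!
//! A minimal sketch of the character-level mode (no trained data required):
//!
//! ```ignore
//! let tok = OxiTokenizer::char_level_stub(200);
//! let ids = tok.encode("hi").unwrap();
//! assert_eq!(tok.decode(&ids).unwrap(), "hi");
//! ```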

use std::collections::HashSet;

use tracing::debug;

use crate::{
    bpe::{bpe_encode, byte_fallback_id, pretokenize, BpeMerges},
    error::{TokenizerError, TokenizerResult},
    vocab::Vocabulary,
};

// ── TokenizerConfig ───────────────────────────────────────────────────────────

/// Configuration knobs for an [`OxiTokenizer`].
///
/// Marked `#[non_exhaustive]` so that new optional knobs can be added in
/// future minor releases without breaking downstream code. Inside this crate,
/// struct literals with `..Default::default()` continue to work.
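///
/// Because the struct is `#[non_exhaustive]`, external code builds a config by
/// mutating the default rather than with a struct literal. A minimal sketch:
///
/// ```ignore
/// let mut config = TokenizerConfig::default();
/// config.add_bos = true;
/// config.max_length = Some(512);
/// ```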
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct TokenizerConfig {
    /// Whether to prepend a BOS (beginning-of-sequence) token.
    pub add_bos: bool,
    /// Whether to append an EOS (end-of-sequence) token.
    pub add_eos: bool,
    /// Token ID used for BOS.
    pub bos_token_id: u32,
    /// Token ID used for EOS.
    pub eos_token_id: u32,
    /// Token ID used for unknown tokens (fallback).
    pub unk_token_id: u32,
    /// Token ID used for padding.
    pub pad_token_id: u32,
    /// Optional maximum output length (tokens are truncated, not padded).
    pub max_length: Option<usize>,
    /// When `true`, the decoder applies the GPT-2 **bytes ↔ unicode** inverse
    /// map to every token string before emitting bytes (see
    /// [`crate::hf_format`]). When `false`, the legacy `Ġ`-stripping path is
    /// used (same behaviour as 0.1.x).
    ///
    /// `from_json_file` / `OxiTokenizer::from_hf_tokenizer_json` set this to
    /// `true` automatically; hand-built configs default to `false` for
    /// backwards compatibility.
    pub byte_level_decode: bool,
}

impl Default for TokenizerConfig {
    fn default() -> Self {
        Self {
            add_bos: false,
            add_eos: false,
            bos_token_id: 1,
            eos_token_id: 2,
            unk_token_id: 0,
            pad_token_id: 3,
            max_length: None,
            byte_level_decode: false,
        }
    }
}

// ── OxiTokenizer ─────────────────────────────────────────────────────────────

/// Pure Rust BPE / Unigram / WordPiece tokenizer compatible with MeCrab and the WASM target.
///
/// The tokenizer supports:
/// - Standard BPE encoding via a merge table
/// - Viterbi Unigram encoding (HuggingFace `"Unigram"` model type)
/// - Greedy WordPiece encoding (HuggingFace `"WordPiece"` model type, BERT family)
/// - Optional BOS/EOS injection
/// - Byte-fallback for out-of-vocabulary bytes
/// - Character-level mode (no trained vocab needed; useful in tests)
pub struct OxiTokenizer {
    vocab: Vocabulary,
    merges: BpeMerges,
    config: TokenizerConfig,
    /// The set of special token IDs for quick membership tests.
    special_ids: HashSet<u32>,
    /// Optional Unigram vocabulary for Viterbi-based segmentation.
    ///
    /// When `Some` (and no WordPiece vocab is attached), the tokenizer
    /// dispatches to Unigram encoding instead of BPE. When `None`, the BPE
    /// or WordPiece path is used.
    unigram: Option<crate::unigram::UnigramVocab>,
    /// Optional WordPiece vocabulary for BERT-style greedy segmentation.
    ///
    /// When `Some`, the tokenizer dispatches to WordPiece encoding. This
    /// takes precedence over both the Unigram and BPE paths; [`Self::encode`]
    /// checks for a WordPiece vocab first. When `None`, the BPE path (or
    /// Unigram if attached) is used.
    wordpiece: Option<crate::wordpiece::WordPieceVocab>,
}

impl OxiTokenizer {
    /// Construct a tokenizer from pre-built components.
    ///
    /// Sets `unigram` and `wordpiece` to `None`; the BPE path is used for
    /// encoding.
    pub fn new(vocab: Vocabulary, merges: BpeMerges, config: TokenizerConfig) -> Self {
        let special_ids = build_special_ids(&config);
        Self {
            vocab,
            merges,
            config,
            special_ids,
            unigram: None,
            wordpiece: None,
        }
    }

    /// Construct a Unigram tokenizer from pre-built components.
    ///
    /// The `unigram_vocab` is used for Viterbi-based segmentation; the `vocab`
    /// is kept for decode operations (ID → token string). An empty
    /// [`BpeMerges`] table is stored for API consistency.
    pub fn with_unigram(
        vocab: Vocabulary,
        unigram_vocab: crate::unigram::UnigramVocab,
        config: TokenizerConfig,
    ) -> Self {
        let special_ids = build_special_ids(&config);
        Self {
            vocab,
            merges: BpeMerges::new(),
            config,
            special_ids,
            unigram: Some(unigram_vocab),
            wordpiece: None,
        }
    }

    /// Construct a WordPiece tokenizer from pre-built components.
    ///
    /// The `wordpiece_vocab` is used for greedy longest-match-first
    /// segmentation (BERT model family); the `vocab` is kept for decode
    /// operations (ID → token string). An empty [`BpeMerges`] table is
    /// stored for API consistency.
    pub fn with_wordpiece(
        vocab: Vocabulary,
        wordpiece_vocab: crate::wordpiece::WordPieceVocab,
        config: TokenizerConfig,
    ) -> Self {
        let special_ids = build_special_ids(&config);
        Self {
            vocab,
            merges: BpeMerges::new(),
            config,
            special_ids,
            unigram: None,
            wordpiece: Some(wordpiece_vocab),
        }
    }

    /// Return `true` if this tokenizer uses Unigram (Viterbi) segmentation.
    pub fn is_unigram(&self) -> bool {
        self.unigram.is_some()
    }

    /// Return `true` if this tokenizer uses WordPiece (BERT-family) segmentation.
    pub fn is_wordpiece(&self) -> bool {
        self.wordpiece.is_some()
    }

    /// Encode a single text string into a sequence of token IDs.
    ///
    /// Steps:
    /// 1. If a WordPiece vocab is attached, segment the whole text greedily;
    ///    otherwise pre-tokenize into words and encode each word via Unigram
    ///    Viterbi (if attached) or BPE, with byte-fallback for words that BPE
    ///    cannot encode.
    /// 2. Optionally prepend BOS and append EOS.
    /// 3. Optionally truncate to `config.max_length`.
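    ///
    /// A minimal usage sketch with the character-level stub:
    ///
    /// ```ignore
    /// let tok = OxiTokenizer::char_level_stub(200);
    /// let ids = tok.encode("hello world").unwrap();
    /// assert!(!ids.is_empty());
    /// ```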
    pub fn encode(&self, text: &str) -> TokenizerResult<Vec<u32>> {
        debug!(text_len = text.len(), "encoding text");

        let mut ids: Vec<u32> = Vec::new();

        if self.config.add_bos {
            ids.push(self.config.bos_token_id);
        }

        if let Some(wp) = &self.wordpiece {
            // WordPiece path: greedy longest-match-first segmentation of the
            // full text (the WordPieceVocab splits on whitespace internally).
            let wp_ids = wp.encode(text);
            ids.extend_from_slice(&wp_ids);
        } else {
            let words = pretokenize(text);
            for word in &words {
                if let Some(unigram) = &self.unigram {
                    // Unigram path: Viterbi segmentation directly on the word.
                    let word_ids = unigram.encode(word);
                    ids.extend_from_slice(&word_ids);
                } else {
                    // BPE path: apply merge table.
                    let word_ids = bpe_encode(word, &self.vocab, &self.merges);
                    if word_ids.is_empty() {
                        // Byte-fallback path: encode each UTF-8 byte explicitly.
                        for byte in word.as_bytes() {
                            let fallback = byte_fallback_id(*byte);
                            let fallback_id = self.vocab.get_id(&fallback);
                            ids.push(fallback_id.unwrap_or(self.config.unk_token_id));
                        }
                    } else {
                        ids.extend_from_slice(&word_ids);
                    }
                }
            }
        }

        if self.config.add_eos {
            ids.push(self.config.eos_token_id);
        }

        // Truncate if configured.
        if let Some(max) = self.config.max_length {
            ids.truncate(max);
        }

        Ok(ids)
    }

    /// Encode a batch of texts in sequence (returns one `Vec<u32>` per input).
    pub fn encode_batch(&self, texts: &[&str]) -> TokenizerResult<Vec<Vec<u32>>> {
        texts.iter().map(|t| self.encode(t)).collect()
    }

    /// Decode a sequence of token IDs back into a string.
    ///
    /// Special tokens (BOS, EOS, PAD, UNK) are silently skipped.
    /// Byte-fallback tokens (`<0xHH>`) are decoded back to their original byte.
    /// Unknown IDs that are not in the vocabulary produce `\u{FFFD}` (replacement
    /// character) rather than an error, to be maximally robust.
    ///
    /// When `config.byte_level_decode` is `true`, tokens are run through the
    /// full 256-entry GPT-2 **unicode → byte** inverse map (see
    /// [`crate::hf_format`]). Otherwise the legacy `Ġ`-stripping path is used.
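    ///
    /// Round-trip sketch (character-level stub, legacy decode path):
    ///
    /// ```ignore
    /// let tok = OxiTokenizer::char_level_stub(200);
    /// let ids = tok.encode("ab").unwrap();
    /// assert_eq!(tok.decode(&ids).unwrap(), "ab");
    /// ```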
    pub fn decode(&self, ids: &[u32]) -> TokenizerResult<String> {
        let bytes = self.decode_to_bytes(ids);
        String::from_utf8(bytes).map_err(|e| TokenizerError::DecodeFailed(e.to_string()))
    }

    /// Decode to raw bytes, used by both [`Self::decode`] and the streaming
    /// decoder so that the two paths stay byte-for-byte identical.
    pub(crate) fn decode_to_bytes(&self, ids: &[u32]) -> Vec<u8> {
        let mut bytes: Vec<u8> = Vec::with_capacity(ids.len() * 2);

        for &id in ids {
            self.decode_id_into(id, &mut bytes);
        }

        bytes
    }

    /// Append the UTF-8 bytes for a single token ID to `bytes`.
    ///
    /// Special tokens are silently dropped. Unknown IDs produce `\u{FFFD}`.
    pub(crate) fn decode_id_into(&self, id: u32, bytes: &mut Vec<u8>) {
        if self.special_ids.contains(&id) {
            return;
        }

        let token = match self.vocab.get_token(id) {
            Some(t) => t,
            None => {
                bytes.extend_from_slice("\u{FFFD}".as_bytes());
                return;
            }
        };

        // Byte-fallback tokens: `<0xHH>` → raw byte.
        if let Some(byte) = parse_byte_fallback(token) {
            bytes.push(byte);
            return;
        }

        if self.config.byte_level_decode {
            // Full GPT-2 bytes-to-unicode inverse mapping.
            for ch in token.chars() {
                if let Some(b) = crate::hf_format::unicode_to_byte(ch) {
                    bytes.push(b);
                } else {
                    // Non-byte-level character: emit UTF-8 verbatim.
                    let mut buf = [0u8; 4];
                    let s = ch.encode_utf8(&mut buf);
                    bytes.extend_from_slice(s.as_bytes());
                }
            }
        } else {
            // Legacy `Ġ`-stripping path, kept bit-for-bit identical to 0.1.x.
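            // A leading `Ġ` (U+0120) marks a word boundary: it is re-emitted
            // as a space, except at the very start of the output where the
            // space is dropped (e.g. ["Ġhello", "Ġworld"] → "hello world").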
            let stripped = token.trim_start_matches('\u{0120}');
            if token.starts_with('\u{0120}') && !bytes.is_empty() {
                bytes.push(b' ');
            }
            bytes.extend_from_slice(stripped.as_bytes());
        }
    }

    /// Decode a single token ID to its string representation.
    pub fn decode_token(&self, id: u32) -> TokenizerResult<String> {
        self.vocab
            .get_token(id)
            .map(|s| s.to_owned())
            .ok_or_else(|| TokenizerError::DecodeFailed(format!("unknown token id {id}")))
    }

    /// Return the total vocabulary size.
    pub fn vocab_size(&self) -> usize {
        self.vocab.size()
    }

    /// Construct a tokenizer from JSON-encoded vocabulary and merge lists.
    ///
    /// `vocab_json`: `{ "token": id, ... }`
    /// `merges_json`: `[["a", "b"], ...]` (ordered from highest to lowest priority)
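    ///
    /// A sketch mirroring the `from_json_roundtrip` unit test below:
    ///
    /// ```ignore
    /// let vocab_json = r#"{"a":10,"b":11,"ab":20,"<unk>":0,"<bos>":1,"<eos>":2,"<pad>":3}"#;
    /// let merges_json = r#"[["a","b"]]"#;
    /// let tok = OxiTokenizer::from_json(vocab_json, merges_json, TokenizerConfig::default())?;
    /// assert!(tok.encode("ab")?.contains(&20)); // "a" + "b" merge into token 20
    /// ```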
    pub fn from_json(
        vocab_json: &str,
        merges_json: &str,
        config: TokenizerConfig,
    ) -> TokenizerResult<Self> {
        let vocab = Vocabulary::from_json(vocab_json)?;

        let raw_merges: Vec<(String, String)> = serde_json::from_str(merges_json)
            .map_err(|e| TokenizerError::InvalidJson(e.to_string()))?;

        let mut merges = BpeMerges::new();
        for (a, b) in &raw_merges {
            // The merged token name is the concatenation.
            let merged = format!("{a}{b}");
            let result_id = vocab.get_id(&merged).ok_or_else(|| {
                TokenizerError::InvalidVocab(format!("merged token {merged:?} not in vocabulary"))
            })?;
            merges.add_merge(a, b, result_id);
        }

        Ok(Self::new(vocab, merges, config))
    }

    /// Load a tokenizer from a HuggingFace-style `tokenizer.json` file.
    ///
    /// This routes through [`crate::hf_format::HfTokenizerJson`] which:
    ///
    /// 1. Parses the `model.vocab` map (token → id).
    /// 2. Parses the `model.merges` list (both string-pair and array-pair forms).
    /// 3. Picks up the `added_tokens` / `special_tokens` block.
    /// 4. Sets `byte_level_decode = true` on the returned config so that
    ///    decode() correctly reverses the GPT-2 bytes-to-unicode map.
    ///
    /// Any field not expressible in [`TokenizerConfig`] (truncation policy,
    /// normalizer variants, ...) is ignored but does not cause an error, so
    /// that loading a live HF file "just works".
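    ///
    /// Sketch (the path is illustrative):
    ///
    /// ```ignore
    /// let tok = OxiTokenizer::from_json_file("path/to/tokenizer.json")?;
    /// assert!(tok.config().byte_level_decode); // set automatically for HF files
    /// ```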
    pub fn from_json_file(path: impl AsRef<std::path::Path>) -> TokenizerResult<Self> {
        let json = std::fs::read_to_string(path)?;
        Self::from_hf_tokenizer_json(&json)
    }

    /// In-memory variant of [`Self::from_json_file`] that takes the JSON as a
    /// `&str`. Useful for WASM builds and for tests that embed a tokenizer
    /// fixture verbatim.
    pub fn from_hf_tokenizer_json(json: &str) -> TokenizerResult<Self> {
        let parsed = crate::hf_format::HfTokenizerJson::parse(json)?;
        parsed.into_tokenizer()
    }

    /// Begin streaming decode. Returns a [`crate::streaming::StreamingDecoder`]
    /// that keeps UTF-8 state across `push_token` calls, which is essential for
    /// server code that emits one token at a time.
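    ///
    /// Sketch; see [`crate::streaming`] for the exact `push_token` signature,
    /// which is assumed here to hand back any newly completed text:
    ///
    /// ```ignore
    /// let mut dec = tok.streaming_decoder();
    /// for id in ids {
    ///     let chunk = dec.push_token(id);
    ///     // forward `chunk` to the client as soon as it arrives
    /// }
    /// ```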
    pub fn streaming_decoder(&self) -> crate::streaming::StreamingDecoder<'_> {
        crate::streaming::StreamingDecoder::new(self)
    }

    /// Access the tokenizer configuration (read-only).
    pub fn config(&self) -> &TokenizerConfig {
        &self.config
    }

    /// Access the vocabulary (read-only).
    pub fn vocab(&self) -> &Vocabulary {
        &self.vocab
    }

    /// Access the merge table (read-only).
    pub fn merges(&self) -> &BpeMerges {
        &self.merges
    }

    /// Create a character-level tokenizer (no trained merges) for testing
    /// and examples.
    ///
    /// Assigns IDs `4..vocab_size` to printable ASCII characters (space = 4,
    /// '!' = 5, ...) with IDs 0-3 reserved for UNK/BOS/EOS/PAD.
    ///
    /// This tokenizer has no BPE merges: each character is its own token.
    /// The `_stub` suffix is retained for API compatibility.
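    ///
    /// Sketch of per-character encoding:
    ///
    /// ```ignore
    /// let tok = OxiTokenizer::char_level_stub(200);
    /// let ids = tok.encode("ab").unwrap();
    /// assert_eq!(ids.len(), 2); // one ID per character
    /// ```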
    pub fn char_level_stub(vocab_size: usize) -> Self {
        assert!(
            vocab_size >= 4,
            "char_level_stub requires vocab_size >= 4 for special tokens"
        );

        let mut vocab = Vocabulary::new();
        vocab.add_special("<unk>", 0);
        vocab.add_special("<bos>", 1);
        vocab.add_special("<eos>", 2);
        vocab.add_special("<pad>", 3);

        // Fill remaining slots with printable ASCII characters.
        let mut next_id = 4u32;
        for byte in 0x20u8..=0x7Eu8 {
            if next_id as usize >= vocab_size {
                break;
            }
            let ch = char::from(byte);
            vocab.insert(&ch.to_string(), next_id);
            next_id += 1;
        }

        // Also populate byte-fallback tokens for any remaining slots.
        for byte in 0u8..=255u8 {
            if next_id as usize >= vocab_size {
                break;
            }
            let fallback = byte_fallback_id(byte);
            if vocab.get_id(&fallback).is_none() {
                vocab.insert(&fallback, next_id);
                next_id += 1;
            }
        }

        let config = TokenizerConfig {
            add_bos: false,
            add_eos: false,
            bos_token_id: 1,
            eos_token_id: 2,
            unk_token_id: 0,
            pad_token_id: 3,
            max_length: None,
            byte_level_decode: false,
        };

        let merges = BpeMerges::new();
        // Use Self::new, which initialises both unigram and wordpiece to None.
        Self::new(vocab, merges, config)
    }

    // ── Special token helpers ─────────────────────────────────────────────

    /// Return the BOS token ID from the configuration.
    pub fn bos_id(&self) -> u32 {
        self.config.bos_token_id
    }

    /// Return the EOS token ID from the configuration.
    pub fn eos_id(&self) -> u32 {
        self.config.eos_token_id
    }

    /// Return `true` if `id` is one of the configured special token IDs.
    pub fn is_special(&self, id: u32) -> bool {
        self.special_ids.contains(&id)
    }
}

// ── Private helpers ───────────────────────────────────────────────────────────

/// Build the set of special token IDs from a config.
fn build_special_ids(config: &TokenizerConfig) -> HashSet<u32> {
    let mut set = HashSet::new();
    set.insert(config.bos_token_id);
    set.insert(config.eos_token_id);
    set.insert(config.unk_token_id);
    set.insert(config.pad_token_id);
    set
}

/// Parse a byte-fallback token like `<0x41>` and return the byte value.
///
/// Returns `None` if the token is not in the `<0xHH>` format.
fn parse_byte_fallback(token: &str) -> Option<u8> {
    let inner = token.strip_prefix("<0x")?.strip_suffix('>')?;
    if inner.len() != 2 {
        return None;
    }
    u8::from_str_radix(inner, 16).ok()
}

// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn char_level_stub_encode_ascii() {
        let tok = OxiTokenizer::char_level_stub(200);
        let ids = tok.encode("ab").expect("encode should succeed");
        // Each char should map to a consistent non-zero ID.
        assert_eq!(ids.len(), 2);
        assert_ne!(ids[0], 0); // not UNK
        assert_ne!(ids[1], 0);
        assert_ne!(ids[0], ids[1]); // 'a' ≠ 'b'
    }

    #[test]
    fn char_level_stub_bos_eos() {
        let mut tok = OxiTokenizer::char_level_stub(200);
        tok.config.add_bos = true;
        tok.config.add_eos = true;
        tok.special_ids = build_special_ids(&tok.config);
        let ids = tok.encode("hi").expect("encode should succeed");
        assert_eq!(ids[0], 1); // BOS
        assert_eq!(*ids.last().expect("must have last element"), 2); // EOS
    }

    #[test]
    fn char_level_stub_vocab_size() {
        let tok = OxiTokenizer::char_level_stub(50);
        assert!(tok.vocab_size() <= 50);
        assert!(tok.vocab_size() >= 4); // at least special tokens
    }

    #[test]
    fn special_token_detection() {
        let tok = OxiTokenizer::char_level_stub(200);
        assert!(tok.is_special(0)); // UNK
        assert!(tok.is_special(1)); // BOS
        assert!(tok.is_special(2)); // EOS
        assert!(tok.is_special(3)); // PAD
        assert!(!tok.is_special(4)); // first real token
    }

    #[test]
    fn bos_eos_ids_match_config() {
        let tok = OxiTokenizer::char_level_stub(200);
        assert_eq!(tok.bos_id(), 1);
        assert_eq!(tok.eos_id(), 2);
    }

    #[test]
    fn decode_token_roundtrip() {
        let tok = OxiTokenizer::char_level_stub(200);
        // 'a' should map to some ID; we can look it up.
        let ids = tok.encode("a").expect("should encode");
        if let Some(&id) = ids.first() {
            let s = tok.decode_token(id).expect("decode_token should succeed");
            assert_eq!(s, "a");
        }
    }

    #[test]
    fn decode_unknown_id_returns_error() {
        let tok = OxiTokenizer::char_level_stub(50);
        let result = tok.decode_token(99_999);
        assert!(result.is_err());
    }

    #[test]
    fn max_length_truncates() {
        let mut tok = OxiTokenizer::char_level_stub(200);
        tok.config.max_length = Some(3);
        tok.special_ids = build_special_ids(&tok.config);
        let ids = tok.encode("hello world").expect("encode should succeed");
        assert!(ids.len() <= 3);
    }

    #[test]
    fn encode_batch_consistency() {
        let tok = OxiTokenizer::char_level_stub(200);
        let texts = ["ab", "cd", "ef"];
        let batch = tok
            .encode_batch(&texts)
            .expect("batch encode should succeed");
        assert_eq!(batch.len(), 3);
        for (i, ids) in batch.iter().enumerate() {
            let single = tok.encode(texts[i]).expect("single encode should succeed");
            assert_eq!(*ids, single);
        }
    }

    #[test]
    fn parse_byte_fallback_valid() {
        assert_eq!(parse_byte_fallback("<0x41>"), Some(0x41));
        assert_eq!(parse_byte_fallback("<0x00>"), Some(0x00));
        assert_eq!(parse_byte_fallback("<0xFF>"), Some(0xFF));
    }

    #[test]
    fn parse_byte_fallback_invalid() {
        assert_eq!(parse_byte_fallback("hello"), None);
        assert_eq!(parse_byte_fallback("<0x>"), None);
        assert_eq!(parse_byte_fallback("<0x1>"), None);
    }

    #[test]
    fn from_json_roundtrip() {
        let vocab_json = r#"{"a":10,"b":11,"ab":20,"<unk>":0,"<bos>":1,"<eos>":2,"<pad>":3}"#;
        let merges_json = r#"[["a","b"]]"#;
        let config = TokenizerConfig::default();
        let tok = OxiTokenizer::from_json(vocab_json, merges_json, config)
            .expect("from_json should succeed");
        assert_eq!(tok.vocab_size(), 7);
        // Encoding "ab" should produce a single merged token 20.
        let ids = tok.encode("ab").expect("encode should succeed");
        assert!(ids.contains(&20));
    }

    #[test]
    fn is_unigram_false_for_bpe() {
        let tok = OxiTokenizer::char_level_stub(200);
        assert!(!tok.is_unigram());
    }
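
    // Companion to the Unigram check above; the char-level stub attaches
    // neither Unigram nor WordPiece, so both predicates report false.
    #[test]
    fn is_wordpiece_false_for_bpe() {
        let tok = OxiTokenizer::char_level_stub(200);
        assert!(!tok.is_wordpiece());
    }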
}