1pub mod added_tokens;
2pub mod decoders;
3pub mod json_structs;
4pub mod models;
5pub mod normalizers;
6pub mod post_processors;
7pub mod pre_tokenized;
8pub mod pre_tokenizers;
9
10use std::{fs, path::Path};
11
12use hf_hub::api::sync::{Api, ApiBuilder};
13use rayon::prelude::*;
14use serde_json::Value;
15
16pub use self::{
17 added_tokens::{AddedTokenInfo, AddedTokens},
18 json_structs::{
19 AddedTokenConfig, DecoderConfig, DecoderKind, ModelConfig, ModelKind, NormalizerConfig,
20 NormalizerKind, PostProcessorConfig, PostProcessorKind, PreTokenizerConfig,
21 PreTokenizerKind, TokenizerJson,
22 },
23 models::Model,
24 normalizers::{Nfc, Normalizer},
25 post_processors::PostProcessor,
26 pre_tokenizers::{ByteLevel, PreTokenizer, Split, SplitBehavior},
27};
28
29use self::{
30 added_tokens::Segment,
31 decoders::Decoder,
32 pre_tokenized::{PreTokenizedString, Split as PtSplit},
33};
34
/// Errors returned while loading a [`Tokenizer`] or encoding/decoding with it.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// A request to the Hugging Face Hub (client construction or file fetch)
    /// failed.
    #[error("failed to download tokenizer files: {0}")]
    Hub(#[from] hf_hub::api::sync::ApiError),

    /// Reading a tokenizer file from disk failed.
    #[error("failed to read tokenizer files: {0}")]
    Io(#[from] std::io::Error),

    /// The tokenizer JSON could not be deserialized into a `TokenizerJson`.
    #[error("failed to parse tokenizer files: {0}")]
    Json(#[from] serde_json::Error),

    /// An error from the normalizer component (e.g. while building it from
    /// its config).
    #[error("normalizer error: {0}")]
    Normalizer(#[from] normalizers::Error),

    /// Building or running the pre-tokenizer failed.
    #[error("pre-tokenizer error: {0}")]
    PreTokenizer(#[from] pre_tokenizers::Error),

    /// An error from the post-processor component (e.g. while building it
    /// from its config).
    #[error("post-processor error: {0}")]
    PostProcessor(#[from] post_processors::Error),

    /// Building or running the decoder failed.
    #[error("decoder error: {0}")]
    Decoder(#[from] decoders::Error),

    /// A message-only error raised while building the added-token table,
    /// building the model, or tokenizing text with the model.
    #[error("model error: {0}")]
    Model(String),

    /// A Hub model identifier was rejected locally, before any request was
    /// made.
    #[error("invalid model identifier: {0}")]
    InvalidIdentifier(String),
}
65
/// A tokenizer assembled from a Hugging Face `tokenizer.json` description.
///
/// Encoding splits out added tokens, normalizes the remaining text,
/// pre-tokenizes it, runs the model, and finally applies the optional
/// post-processor. Decoding maps ids back to token strings and runs the
/// optional decoder (or concatenates the strings when there is none).
pub struct Tokenizer {
    /// Added-token table used to split the raw input before normalization and
    /// to resolve/flag ids; `None` when `AddedTokens::from_configs` returned
    /// no table.
    added_tokens: Option<AddedTokens>,
    /// Optional normalizer applied to each text segment between added tokens.
    normalizer: Option<Normalizer>,
    /// Optional pre-tokenizer; bypassed during encoding when `split_only` is set.
    pre_tokenizer: Option<PreTokenizer>,
    /// Vocabulary model that turns pre-tokenized pieces into ids.
    model: Model,
    /// Optional post-processor applied to the final id sequence.
    post_processor: Option<PostProcessor>,
    /// Optional decoder used to join token strings when decoding.
    decoder: Option<Decoder>,
    /// Copy of the leading `Split` step when the pre-tokenizer is exactly
    /// `Sequence[Split, ByteLevel]` with a bulk-only `ByteLevel`. When set,
    /// encoding runs only this split and hands the pieces to the model's
    /// fused batch tokenizer (which presumably applies the byte-level mapping
    /// itself — confirm in `models`).
    split_only: Option<PreTokenizer>,
}
78
79fn make_api(token: Option<&str>) -> Result<Api, hf_hub::api::sync::ApiError> {
83 match token {
84 Some(t) => ApiBuilder::new().with_token(Some(t.to_owned())).build(),
85 None => Api::new(),
86 }
87}
88
89fn validate_model_id(model: &str) -> Result<(), Error> {
90 if model.contains("..") {
91 return Err(Error::InvalidIdentifier(
92 "model identifier must not contain \"..\"".into(),
93 ));
94 }
95 Ok(())
96}
97
98impl Tokenizer {
99 fn build(json: TokenizerJson) -> Result<Self, Error> {
101 let added_tokens = AddedTokens::from_configs(&json.added_tokens).map_err(Error::Model)?;
102 let normalizer = json.normalizer.map(Normalizer::from_config).transpose()?;
103 let pre_tokenizer = json
104 .pre_tokenizer
105 .map(PreTokenizer::from_config)
106 .transpose()?;
107 let model = Model::from_config(json.model).map_err(Error::Model)?;
108 let post_processor = json
109 .post_processor
110 .map(PostProcessor::from_config)
111 .transpose()?;
112 let decoder = json.decoder.map(Decoder::from_config).transpose()?;
113
114 let split_only = Self::detect_fused_byte_level(&pre_tokenizer);
116
117 Ok(Self {
118 added_tokens,
119 normalizer,
120 pre_tokenizer,
121 model,
122 post_processor,
123 decoder,
124 split_only,
125 })
126 }
127
128 fn detect_fused_byte_level(pt: &Option<PreTokenizer>) -> Option<PreTokenizer> {
131 let PreTokenizer::Sequence(steps) = pt.as_ref()? else {
132 return None;
133 };
134 if steps.len() != 2 {
135 return None;
136 }
137 let is_split = matches!(&steps[0], PreTokenizer::Split(_));
138 let is_bulk_bl = matches!(&steps[1], PreTokenizer::ByteLevel(bl) if bl.is_bulk_only());
139 if is_split && is_bulk_bl {
140 Some(steps[0].clone())
141 } else {
142 None
143 }
144 }
145
146 pub fn from_json(json: Value) -> Result<Self, Error> {
148 let json: TokenizerJson = serde_json::from_value(json)?;
149 Self::build(json)
150 }
151
152 pub fn from_file(path: &Path) -> Result<Self, Error> {
154 let json: TokenizerJson = serde_json::from_str(&fs::read_to_string(path)?)?;
155 Self::build(json)
156 }
157
158 pub fn from_model(model: &str) -> Result<Self, Error> {
165 Self::from_model_with_token(model, None)
166 }
167
168 pub fn from_model_with_token(model: &str, token: Option<&str>) -> Result<Self, Error> {
172 validate_model_id(model)?;
173 let api = make_api(token)?;
174 let repo = api.model(model.to_string());
175 let json_path = repo.get("tokenizer.json")?;
176 let raw = fs::read_to_string(json_path)?;
177 let json: TokenizerJson = serde_json::from_str(&raw)?;
178 Self::build(json)
179 }
180
181 pub fn download_tokenizer_json(model: &str) -> Result<String, Error> {
185 validate_model_id(model)?;
186 let api = make_api(None)?;
187 let repo = api.model(model.to_string());
188 let json_path = repo.get("tokenizer.json")?;
189 Ok(fs::read_to_string(json_path)?)
190 }
191
192 pub fn normalizer(&self) -> Option<&Normalizer> {
194 self.normalizer.as_ref()
195 }
196
197 pub fn pre_tokenizer(&self) -> Option<&PreTokenizer> {
199 self.pre_tokenizer.as_ref()
200 }
201
202 pub fn post_processor(&self) -> Option<&PostProcessor> {
204 self.post_processor.as_ref()
205 }
206
207 pub fn model(&self) -> &Model {
209 &self.model
210 }
211
212 pub fn added_tokens(&self) -> Option<&AddedTokens> {
214 self.added_tokens.as_ref()
215 }
216
217 pub fn decoder(&self) -> Option<&Decoder> {
219 self.decoder.as_ref()
220 }
221
222 pub fn encode(&self, input: &str) -> Result<Vec<u32>, Error> {
227 self.encode_with_special_tokens(input, false)
228 }
229
230 pub fn encode_with_special_tokens(
235 &self,
236 input: &str,
237 add_special_tokens: bool,
238 ) -> Result<Vec<u32>, Error> {
239 if input.is_empty() {
240 return if add_special_tokens {
241 Ok(self.post_process(Vec::new(), true))
242 } else {
243 Ok(Vec::new())
244 };
245 }
246
247 let mut pts = self.build_pre_tokenized(input);
249
250 if let Some(ref split) = self.split_only {
252 split.pre_tokenize(&mut pts)?;
253 let ids = pts
254 .tokenize_batched(|buf, splits, out| {
255 self.model.tokenize_batch_fused(buf, splits, out)
256 })
257 .map_err(Error::Model)?;
258 return Ok(self.post_process(ids, add_special_tokens));
259 }
260
261 if let Some(ref pt) = self.pre_tokenizer {
263 pt.pre_tokenize(&mut pts)?;
264 }
265
266 let ids = pts
268 .tokenize(|text, out| self.model.tokenize_into(text, out))
269 .map_err(Error::Model)?;
270
271 Ok(self.post_process(ids, add_special_tokens))
273 }
274
275 pub fn encode_batch<S: AsRef<str> + Sync>(
277 &self,
278 inputs: &[S],
279 add_special_tokens: bool,
280 ) -> Result<Vec<Vec<u32>>, Error> {
281 inputs
282 .par_iter()
283 .map(|input| self.encode_with_special_tokens(input.as_ref(), add_special_tokens))
284 .collect()
285 }
286
287 pub fn set_post_processor(&mut self, pp: Option<PostProcessor>) {
290 self.post_processor = pp;
291 }
292
293 pub fn post_process(&self, ids: Vec<u32>, add_special_tokens: bool) -> Vec<u32> {
294 match &self.post_processor {
295 Some(pp) => pp.post_process_single(ids, add_special_tokens),
296 None => ids,
297 }
298 }
299
300 pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String, Error> {
307 let mut tokens = Vec::with_capacity(ids.len());
308 for &id in ids {
309 if skip_special_tokens
310 && let Some(ref at) = self.added_tokens
311 && at.is_special(id)
312 {
313 continue;
314 }
315 if let Some(token_str) = self.id_to_token(id) {
320 tokens.push(token_str.to_string());
321 }
322 }
323
324 match &self.decoder {
325 Some(dec) => dec.decode(tokens).map_err(Error::Decoder),
326 None => Ok(tokens.join("")),
327 }
328 }
329
330 pub fn decode_tokens(&self, tokens: Vec<String>) -> Result<String, Error> {
336 match &self.decoder {
337 Some(dec) => dec.decode(tokens).map_err(Error::Decoder),
338 None => Ok(tokens.join("")),
339 }
340 }
341
342 pub fn decode_batch(
344 &self,
345 sentences: &[&[u32]],
346 skip_special_tokens: bool,
347 ) -> Result<Vec<String>, Error> {
348 sentences
349 .iter()
350 .map(|ids| self.decode(ids, skip_special_tokens))
351 .collect()
352 }
353
354 pub fn id_to_token(&self, id: u32) -> Option<&str> {
359 if let Some(ref at) = self.added_tokens
360 && let Some(s) = at.id_to_token(id)
361 {
362 return Some(s);
363 }
364 self.model.id_to_token(id)
365 }
366
367 pub fn token_to_id(&self, token: &str) -> Option<u32> {
372 if let Some(ref at) = self.added_tokens
373 && let Some(id) = at.token_to_id(token)
374 {
375 return Some(id);
376 }
377 self.model.token_to_id(token)
378 }
379
380 pub fn vocab_size(&self) -> usize {
382 let model_size = self.model.vocab_size();
383 let added_size = self.added_tokens.as_ref().map_or(0, |at| at.len());
384 model_size + added_size
385 }
386
387 pub fn is_special_token(&self, id: u32) -> bool {
389 self.added_tokens
390 .as_ref()
391 .is_some_and(|added_tokens| added_tokens.is_special(id))
392 }
393
394 pub fn build_pre_tokenized(&self, input: &str) -> PreTokenizedString {
399 let segments = match &self.added_tokens {
400 Some(at) => at.split(input),
401 None => vec![Segment::Text(input)],
402 };
403
404 if segments.len() == 1
407 && let Segment::Text(text) = segments[0]
408 {
409 let normalized = match &self.normalizer {
410 Some(n) => n.normalize(text),
411 None => std::borrow::Cow::Borrowed(text),
412 };
413 return match normalized {
414 std::borrow::Cow::Borrowed(_) => PreTokenizedString::from_text(text),
415 std::borrow::Cow::Owned(s) => {
416 let len = s.len();
417 PreTokenizedString::new(
418 s,
419 vec![PtSplit {
420 range: 0..len,
421 token_id: None,
422 }],
423 )
424 }
425 };
426 }
427
428 let mut buffer = String::with_capacity(input.len());
429 let mut splits = Vec::new();
430
431 for seg in &segments {
432 match seg {
433 Segment::Token(id) => {
434 let start = buffer.len();
435 splits.push(PtSplit {
436 range: start..start,
437 token_id: Some(*id),
438 });
439 }
440 Segment::Text(text) => {
441 if text.is_empty() {
442 continue;
443 }
444 let normalized = match &self.normalizer {
445 Some(n) => n.normalize(text),
446 None => std::borrow::Cow::Borrowed(*text),
447 };
448 let start = buffer.len();
449 buffer.push_str(&normalized);
450 let end = buffer.len();
451 splits.push(PtSplit {
452 range: start..end,
453 token_id: None,
454 });
455 }
456 }
457 }
458
459 PreTokenizedString::new(buffer, splits)
460 }
461}
462
/// State for decoding a token stream incrementally; each step is performed by
/// [`decode_stream_step`].
pub struct DecodeStream {
    /// Forwarded to [`Tokenizer::decode`] on every step.
    skip_special_tokens: bool,
    /// Window of ids decoded on each step: retained context ids followed by
    /// ids whose text has not been emitted yet.
    ids: Vec<u32>,
    /// Standalone decoding of `ids[..prefix_index]`; stripped from the front
    /// of each step's decoding to isolate the new text.
    prefix: String,
    /// Number of leading `ids` covered by `prefix`.
    prefix_index: usize,
}
477
478impl DecodeStream {
479 pub fn new(ids: Vec<u32>, skip_special_tokens: bool) -> Self {
480 Self {
481 skip_special_tokens,
482 ids,
483 prefix: String::new(),
484 prefix_index: 0,
485 }
486 }
487
488 pub fn step(
489 &mut self,
490 tokenizer: &Tokenizer,
491 token_ids: Vec<u32>,
492 ) -> Result<Option<String>, String> {
493 decode_stream_step(
494 tokenizer,
495 token_ids,
496 self.skip_special_tokens,
497 &mut self.ids,
498 &mut self.prefix,
499 &mut self.prefix_index,
500 )
501 }
502}
503
504pub fn decode_stream_step(
523 tokenizer: &Tokenizer,
524 token_ids: Vec<u32>,
525 skip_special_tokens: bool,
526 ids: &mut Vec<u32>,
527 prefix: &mut String,
528 prefix_index: &mut usize,
529) -> Result<Option<String>, String> {
530 const REPLACEMENT: char = '\u{FFFD}';
531
532 if prefix.is_empty() && !ids.is_empty() {
535 let s = tokenizer
536 .decode(ids, skip_special_tokens)
537 .map_err(|e| e.to_string())?;
538 if !s.ends_with(REPLACEMENT) {
539 *prefix = s;
540 *prefix_index = ids.len();
541 }
542 }
543
544 ids.extend(token_ids);
545
546 let string = tokenizer
547 .decode(ids, skip_special_tokens)
548 .map_err(|e| e.to_string())?;
549
550 if string.len() > prefix.len() && !string.ends_with(REPLACEMENT) {
551 if !string.starts_with(prefix.as_str()) {
552 return Err(format!(
553 "Invalid prefix encountered while decoding stream. \
554 Expected prefix: '{}', Actual string: '{}'",
555 prefix, string,
556 ));
557 }
558 let new_text = string[prefix.len()..].to_string();
559 let new_prefix_index = ids.len() - *prefix_index;
560 *ids = ids.drain(*prefix_index..).collect();
561 *prefix = tokenizer
562 .decode(ids, skip_special_tokens)
563 .map_err(|e| e.to_string())?;
564 *prefix_index = new_prefix_index;
565 Ok(Some(new_text))
566 } else {
567 Ok(None)
568 }
569}
570
571#[cfg(test)]
572mod tests {
573 use super::*;
574
575 const HF_MODELS: &[&str] = &[
576 "Qwen/Qwen3-0.6B",
577 "zai-org/GLM-4.7",
578 "deepseek-ai/DeepSeek-V3.2",
579 "MiniMaxAI/MiniMax-M2.1",
580 "openai/gpt-oss-120b",
581 "mistralai/Mistral-Nemo-Instruct-2407",
582 "Qwen/Qwen3-235B-A22B-Instruct-2507",
583 "Qwen/Qwen3-Coder-480B-A35B-Instruct",
584 "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16",
585 "nvidia/Qwen3-Nemotron-235B-A22B-GenRM",
586 "hoangquan456/Kimi-K2.5",
587 ];
588
589 #[test]
594 fn parse_hf_json() {
595 let api = make_api(None).unwrap();
596 for model in HF_MODELS {
597 let repo = api.model(model.to_string());
598 let json_path = repo
599 .get("tokenizer.json")
600 .unwrap_or_else(|e| panic!("{model}: {e}"));
601 let json: TokenizerJson = serde_json::from_str(&fs::read_to_string(json_path).unwrap())
602 .unwrap_or_else(|e| panic!("{model}: {e}"));
603 assert!(
604 !matches!(json.model, ModelConfig::Other(_)),
605 "{model}: model parsed as Other",
606 );
607 }
608 }
609
610 #[test]
612 fn encode_batch_matches_sequential() {
613 let model = "MiniMaxAI/MiniMax-M2.1";
614 let ours = Tokenizer::from_model(model).unwrap();
615
616 let inputs = &["Hello, world!", "The quick brown fox", "Test", ""];
617 let batch_results = ours.encode_batch(inputs, false).unwrap();
618
619 for (input, batch_result) in inputs.iter().zip(&batch_results) {
620 let sequential_result = ours.encode(input).unwrap();
621 assert_eq!(
622 batch_result, &sequential_result,
623 "batch mismatch for {input:?}"
624 );
625 }
626 }
627
628 #[test]
630 fn vocab_access() {
631 let model = "MiniMaxAI/MiniMax-M2.1";
632 let ours = Tokenizer::from_model(model).unwrap();
633
634 assert!(ours.vocab_size() > 0);
635
636 let token_str = ours.id_to_token(0).expect("token 0 should exist");
637 let id = ours
638 .token_to_id(token_str)
639 .expect("reverse lookup should work");
640 assert_eq!(id, 0);
641 }
642
643 #[test]
644 fn public_added_token_accessors_expose_added_vocab() {
645 let tok = Tokenizer::from_model("Qwen/Qwen3-0.6B").unwrap();
646 let added_tokens = tok.added_tokens().expect("expected added tokens");
647
648 let think_id = tok.token_to_id("<think>").expect("<think> should exist");
649 assert_eq!(added_tokens.token_to_id("<think>"), Some(think_id));
650 assert_eq!(added_tokens.id_to_token(think_id), Some("<think>"));
651
652 let mut entries: Vec<_> = added_tokens.iter().collect();
653 entries.sort_by_key(|entry| entry.id);
654 let special_entry = entries
655 .iter()
656 .find(|entry| entry.special)
657 .expect("expected at least one special added token");
658 assert!(tok.is_special_token(special_entry.id));
659 assert!(
660 entries
661 .iter()
662 .any(|entry| entry.id == think_id && entry.content == "<think>"),
663 "added-token iterator should expose <think>"
664 );
665 }
666
667 const CORPUS: &[&str] = &[
672 "",
674 " ",
675 " ",
676 "\n",
677 "\t",
678 "\r\n",
679 "a",
681 "Z",
682 "0",
683 "!",
684 "\u{00e9}", "\u{4e2d}", "Hello, world!",
688 "The quick brown fox jumps over the lazy dog.",
689 "A short sentence.",
690 " leading spaces",
692 "trailing spaces ",
693 " both sides ",
694 "multiple internal spaces",
695 "tabs\there\tand\tthere",
696 "line\none\nline\ntwo",
697 "windows\r\nline\r\nendings",
698 "mixed\n\ttabs and\r\nnewlines with spaces",
699 "42",
701 "3.14159",
702 "1,000,000",
703 "0xFF",
704 "1e-10",
705 "Numbers 1234567890 and mixed ABC123def",
706 "Hello!!! How are you???",
708 "@user #hashtag $100 %50 ^caret & *star",
709 "a-b_c.d,e;f:g",
710 "(parentheses) [brackets] {braces}",
711 "\"double quotes\" 'single quotes' `backticks`",
712 "path/to/file.txt",
713 "https://example.com/path?q=test&lang=en#section",
714 "Special chars: @#$%^&*()_+-=[]{}|;':\",./<>?",
715 "caf\u{00e9} r\u{00e9}sum\u{00e9} na\u{00ef}ve",
717 "\u{00fc}ber stra\u{00df}e gr\u{00f6}\u{00df}e",
718 "se\u{00f1}or ni\u{00f1}o a\u{00f1}o",
719 "\u{4f60}\u{597d}\u{4e16}\u{754c}", "\u{3053}\u{3093}\u{306b}\u{3061}\u{306f}", "\u{c548}\u{b155}\u{d558}\u{c138}\u{c694}", "\u{041f}\u{0440}\u{0438}\u{0432}\u{0435}\u{0442} \u{043c}\u{0438}\u{0440}",
725 "\u{0645}\u{0631}\u{062d}\u{0628}\u{0627}",
727 "\u{0928}\u{092e}\u{0938}\u{094d}\u{0924}\u{0947}",
729 "\u{1f600}\u{1f680}\u{2764}\u{fe0f}",
731 "\u{1f468}\u{200d}\u{1f469}\u{200d}\u{1f467}\u{200d}\u{1f466}",
732 "\u{1f1fa}\u{1f1f8}", "e\u{0301}", "n\u{0303}", "a\u{0308}", "Hello \u{4e16}\u{754c} \u{041c}\u{0438}\u{0440}!",
739 "User123 wrote: \u{4f60}\u{597d}!",
740 "fn main() { println!(\"hello\"); }",
742 "def foo(x: int) -> str:\n return str(x)",
743 "SELECT * FROM users WHERE id = 1;",
744 "if (x > 0 && y < 10) { z = x + y; }",
745 "<html><body><p>Hello</p></body></html>",
746 "#include <stdio.h>\nint main() { return 0; }",
747 "import numpy as np\nx = np.array([1, 2, 3])",
748 "{\"key\": \"value\", \"number\": 42, \"array\": [1, 2, 3]}",
750 "[{\"id\": 1}, {\"id\": 2}]",
751 "aaaaaaaaaa",
753 "abababababababab",
754 "the the the the the the the the",
755 "....",
756 "----",
757 " ",
758 "\n\n\n\n",
759 "This is a longer sentence with various elements: numbers (42, 3.14), \
761 symbols (@#$), Unicode (caf\u{00e9}, \u{4f60}\u{597d}), and more.",
762 "The year 2024 was notable for advances in AI. Models like GPT-4 and \
763 Claude demonstrated remarkable capabilities in reasoning, coding, and \
764 multilingual understanding.",
765 "a b c d e f g h i j k l m n o p q r s t u v w x y z",
767 "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
768 "0123456789",
769 "a\nb\nc\n",
771 "# Heading\n\n- item 1\n- item 2\n\n```code```",
772 "\u{ffff}", "\u{0080}", "\u{07ff}", "\u{0800}", "\u{10000}", "\u{fffd}", "\u{feff}Hello", "\u{0000}", "abc\u{0000}def", "\u{fffe}", "\u{fdd0}", "\u{200b}\u{200c}\u{200d}", "\u{202e}Hello\u{202c}", "\u{0001}\u{0002}\u{001f}\u{007f}", "\u{0300}", "a\u{0300}\u{0301}\u{0302}\u{0303}\u{0304}", "\u{e000}\u{f8ff}", "\u{01c5}\u{01c8}\u{01cb}", "\u{2028}\u{2029}", "\u{fff9}\u{fffa}\u{fffb}", "\u{d7ff}\u{10ffff}", "ab",
796 "abc",
797 "abcd",
798 "aaa",
799 "aaaa",
800 "aaaaa",
801 "**bold** *italic* ~~strikethrough~~ __underline__",
803 "```rust\nfn main() {}\n```",
804 "> blockquote\n>> nested",
805 "| col1 | col2 |\n|------|------|\n| a | b |",
806 ];
807
808 fn compare_encode_decode(model_name: &str, corpus: &[&str]) -> Vec<String> {
812 let hf = tokenizers::Tokenizer::from_pretrained(model_name, None)
813 .unwrap_or_else(|e| panic!("{model_name}: HF load failed: {e}"));
814 let ours = Tokenizer::from_model(model_name)
815 .unwrap_or_else(|e| panic!("{model_name}: fastokens load failed: {e}"));
816
817 let mut failures = Vec::new();
818 for &input in corpus {
819 let hf_enc = hf
820 .encode(input, false)
821 .unwrap_or_else(|e| panic!("{model_name}: HF encode({input:?}): {e}"));
822 let hf_ids = hf_enc.get_ids().to_vec();
823 let our_ids = match ours.encode(input) {
824 Ok(ids) => ids,
825 Err(e) => {
826 failures.push(format!(" encode error on {input:?}: {e}"));
827 continue;
828 }
829 };
830 if our_ids != hf_ids {
831 failures.push(format!(
832 " encode mismatch on {input:?}: got {} tokens, expected {}\n\
833 \x20 ours: {:?}\n\
834 \x20 hf: {:?}",
835 our_ids.len(),
836 hf_ids.len(),
837 &our_ids[..our_ids.len().min(20)],
838 &hf_ids[..hf_ids.len().min(20)],
839 ));
840 }
841
842 if input.is_empty() || hf_ids.is_empty() {
844 continue;
845 }
846 let hf_decoded = match hf.decode(&hf_ids, false) {
847 Ok(d) => d,
848 Err(_) => continue,
849 };
850 let our_decoded = match ours.decode(&hf_ids, false) {
851 Ok(d) => d,
852 Err(e) => {
853 failures.push(format!(" decode error on {input:?}: {e}"));
854 continue;
855 }
856 };
857 if our_decoded != hf_decoded {
858 failures.push(format!(
859 " decode mismatch on {input:?}:\n\
860 \x20 ours: {:?}\n\
861 \x20 hf: {:?}",
862 &our_decoded[..our_decoded.len().min(100)],
863 &hf_decoded[..hf_decoded.len().min(100)],
864 ));
865 }
866 }
867 failures
868 }
869
870 #[test]
873 fn correctness_minimax_m2_1() {
874 let f = compare_encode_decode("MiniMaxAI/MiniMax-M2.1", CORPUS);
875 assert!(f.is_empty(), "MiniMaxAI/MiniMax-M2.1:\n{}", f.join("\n"));
876 }
877
878 #[test]
879 fn correctness_nemotron() {
880 let f = compare_encode_decode("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", CORPUS);
881 assert!(
882 f.is_empty(),
883 "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16:\n{}",
884 f.join("\n")
885 );
886 }
887
888 #[test]
889 fn correctness_deepseek_v3_2() {
890 let f = compare_encode_decode("deepseek-ai/DeepSeek-V3.2", CORPUS);
891 assert!(f.is_empty(), "deepseek-ai/DeepSeek-V3.2:\n{}", f.join("\n"));
892 }
893
894 #[test]
895 fn correctness_gpt_oss() {
896 let f = compare_encode_decode("openai/gpt-oss-120b", CORPUS);
897 assert!(f.is_empty(), "openai/gpt-oss-120b:\n{}", f.join("\n"));
898 }
899
900 #[test]
901 fn ignore_merges_glm47() {
902 let model = "zai-org/GLM-4.7";
903 let hf = tokenizers::Tokenizer::from_pretrained(model, None).unwrap();
904 let ours = Tokenizer::from_model(model).unwrap();
905
906 let text = " имущества";
910 let hf_ids = hf.encode(text, false).unwrap().get_ids().to_vec();
911 let our_ids = ours.encode(text).unwrap();
912 assert_eq!(
913 our_ids, hf_ids,
914 "ignore_merges mismatch on {text:?}: ours={our_ids:?} hf={hf_ids:?}"
915 );
916
917 let vocab_size = hf.get_vocab_size(false) as u64;
919 let random_ids: Vec<u32> = (0..5000)
920 .map(|i| {
921 ((i as u64).wrapping_mul(6364136223846793005).wrapping_add(1) % vocab_size) as u32
922 })
923 .collect();
924 let text = hf.decode(&random_ids, true).unwrap();
925 let hf_enc = hf.encode(text.as_str(), false).unwrap().get_ids().to_vec();
926 let our_enc = ours.encode(&text).unwrap();
927 assert_eq!(
928 our_enc,
929 hf_enc,
930 "ignore_merges random-decode mismatch: {} vs {} tokens",
931 our_enc.len(),
932 hf_enc.len()
933 );
934 }
935
936 #[test]
937 fn correctness_qwen3() {
938 let f = compare_encode_decode("Qwen/Qwen3-0.6B", CORPUS);
939 assert!(f.is_empty(), "Qwen/Qwen3-0.6B:\n{}", f.join("\n"));
940 }
941
942 #[test]
943 fn correctness_mistral_nemo() {
944 let f = compare_encode_decode("mistralai/Mistral-Nemo-Instruct-2407", CORPUS);
945 assert!(
946 f.is_empty(),
947 "mistralai/Mistral-Nemo-Instruct-2407:\n{}",
948 f.join("\n")
949 );
950 }
951
952 #[test]
953 fn correctness_qwen3_nemotron() {
954 let f = compare_encode_decode("nvidia/Qwen3-Nemotron-235B-A22B-GenRM", CORPUS);
955 assert!(
956 f.is_empty(),
957 "nvidia/Qwen3-Nemotron-235B-A22B-GenRM:\n{}",
958 f.join("\n")
959 );
960 }
961
962 #[test]
963 fn correctness_kimi_k2_5() {
964 let f = compare_encode_decode("hoangquan456/Kimi-K2.5", CORPUS);
965 assert!(f.is_empty(), "hoangquan456/Kimi-K2.5:\n{}", f.join("\n"));
966 }
967
968 #[test]
973 fn cache_consistency() {
974 let model = "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16";
975 let ours = Tokenizer::from_model(model).unwrap();
976
977 let inputs = &[
978 "Hello, world!",
979 "The quick brown fox jumps over the lazy dog.",
980 "caf\u{00e9} r\u{00e9}sum\u{00e9}",
981 "\u{4f60}\u{597d}\u{4e16}\u{754c}",
982 "fn main() { println!(\"hello\"); }",
983 "a b c d e f g h i j k l m n o p",
984 "aaaaaaaaaa bbbbbbbbbb cccccccccc",
985 ];
986
987 for &input in inputs {
988 let first = ours.encode(input).unwrap();
989 let second = ours.encode(input).unwrap();
990 assert_eq!(first, second, "cache inconsistency for {input:?}");
991 let third = ours.encode(input).unwrap();
993 assert_eq!(first, third, "cache inconsistency (3rd call) for {input:?}");
994 }
995 }
996
997 #[test]
1000 fn cache_consistency_fused() {
1001 let model = "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16";
1002 let ours = Tokenizer::from_model(model).unwrap();
1003
1004 assert!(ours.split_only.is_some(), "expected fused path for {model}",);
1006
1007 let input = "The year 2024 was notable for advances in AI. Models like \
1009 GPT-4 and Claude demonstrated remarkable capabilities.";
1010 let baseline = ours.encode(input).unwrap();
1011 for i in 0..20 {
1012 let result = ours.encode(input).unwrap();
1013 assert_eq!(result, baseline, "fused cache drift on iteration {i}");
1014 }
1015 }
1016
1017 #[test]
1022 fn added_tokens_minimax() {
1023 let corpus = &[
1024 "<filename>",
1025 "open <filename> for reading",
1026 "<filename><reponame>",
1027 "printf(\"%s <filename>\\n\")",
1028 "<think>Let me reason about this.</think>",
1029 "<think>load <filename> from <reponame></think>",
1030 "<file> is not <filename>",
1031 "<fim_prefix>code here<fim_suffix>more code<fim_middle>",
1032 ];
1033 let f = compare_encode_decode("MiniMaxAI/MiniMax-M2.1", corpus);
1034 assert!(
1035 f.is_empty(),
1036 "MiniMaxAI/MiniMax-M2.1 added tokens:\n{}",
1037 f.join("\n")
1038 );
1039 }
1040
1041 #[test]
1043 fn added_tokens_deepseek() {
1044 let corpus = &[
1045 "<|begin▁of▁sentence|>Hello",
1046 "Hello<|end▁of▁sentence|>",
1047 "<|User|>What is 2+2?<|Assistant|>4<|end▁of▁sentence|>",
1048 "Normal text without special tokens",
1049 "<|tool▁calls▁begin|>call<|tool▁calls▁end|>",
1050 ];
1051 let f = compare_encode_decode("deepseek-ai/DeepSeek-V3.2", corpus);
1052 assert!(
1053 f.is_empty(),
1054 "deepseek-ai/DeepSeek-V3.2 added tokens:\n{}",
1055 f.join("\n")
1056 );
1057 }
1058
1059 #[test]
1061 fn added_tokens_qwen3() {
1062 let corpus = &[
1063 "<|im_start|>system\nYou are a helpful assistant.<|im_end|>",
1064 "<|im_start|>user\nHello!<|im_end|>",
1065 "<|endoftext|>",
1066 "Plain text with no special tokens at all.",
1067 ];
1068 let f = compare_encode_decode("Qwen/Qwen3-0.6B", corpus);
1069 assert!(
1070 f.is_empty(),
1071 "Qwen/Qwen3-0.6B added tokens:\n{}",
1072 f.join("\n")
1073 );
1074 }
1075
1076 #[test]
1084 fn token_to_id_searches_added_tokens() {
1085 let tok = Tokenizer::from_model("Qwen/Qwen3-0.6B").unwrap();
1086 for token in &[
1088 "<|image_pad|>",
1089 "<|vision_start|>",
1090 "<|vision_end|>",
1091 "<|im_start|>",
1092 ] {
1093 let id = tok.token_to_id(token);
1094 assert!(id.is_some(), "token_to_id({token:?}) returned None");
1095 assert_eq!(tok.id_to_token(id.unwrap()), Some(*token));
1097 }
1098 }
1099
1100 #[test]
1111 fn added_tokens_qwen3vl_vision_sequence() {
1112 let corpus = &[
1113 "<|vision_start|><|image_pad|><|vision_end|>",
1115 "<|image_pad|>",
1117 "<|vision_start|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|vision_end|>",
1119 "<|vision_start|><|image_pad|><|vision_end|>\nDescribe this image.",
1121 ];
1122 let f = compare_encode_decode("Qwen/Qwen3.5-27B", corpus);
1123 assert!(
1124 f.is_empty(),
1125 "Qwen/Qwen3.5-27B VL vision sequence:\n{}",
1126 f.join("\n")
1127 );
1128 }
1129
1130 #[test]
1132 fn added_tokens_nemotron() {
1133 let corpus = &[
1134 "<|begin_of_text|>Hello world",
1135 "Hello<|end_of_text|>",
1136 "<|start_header_id|>system<|end_header_id|>\n\nYou are helpful.<|eot_id|>",
1137 "<|start_header_id|>user<|end_header_id|>\n\nHi!<|eot_id|>",
1138 "No special tokens here.",
1139 ];
1140 let f = compare_encode_decode("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", corpus);
1141 assert!(
1142 f.is_empty(),
1143 "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 added tokens:\n{}",
1144 f.join("\n")
1145 );
1146 }
1147
1148 #[test]
1153 fn long_input_correctness() {
1154 let model_name = "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16";
1155 let hf = tokenizers::Tokenizer::from_pretrained(model_name, None).unwrap();
1156 let ours = Tokenizer::from_model(model_name).unwrap();
1157
1158 let block = "The quick brown fox jumps over the lazy dog. \
1160 Numbers: 42, 3.14, 1000. Code: fn main() {} \
1161 Unicode: caf\u{00e9}, \u{4f60}\u{597d}. \
1162 Special: @#$%^&*(). ";
1163 let input: String = block.repeat(100);
1164 assert!(input.len() > 8000);
1165
1166 let hf_ids = hf.encode(input.as_str(), false).unwrap().get_ids().to_vec();
1167 let our_ids = ours.encode(&input).unwrap();
1168 assert_eq!(
1169 our_ids,
1170 hf_ids,
1171 "long input mismatch: {} vs {} tokens",
1172 our_ids.len(),
1173 hf_ids.len(),
1174 );
1175 }
1176
1177 #[test]
1179 fn long_input_correctness_minimax() {
1180 let model_name = "MiniMaxAI/MiniMax-M2.1";
1181 let hf = tokenizers::Tokenizer::from_pretrained(model_name, None).unwrap();
1182 let ours = Tokenizer::from_model(model_name).unwrap();
1183
1184 let block = "The quick brown fox jumps over the lazy dog. \
1185 Numbers: 42, 3.14, 1000. Code: fn main() {} \
1186 Unicode: caf\u{00e9}, \u{4f60}\u{597d}. \
1187 Special: @#$%^&*(). ";
1188 let input: String = block.repeat(100);
1189
1190 let hf_ids = hf.encode(input.as_str(), false).unwrap().get_ids().to_vec();
1191 let our_ids = ours.encode(&input).unwrap();
1192 assert_eq!(
1193 our_ids,
1194 hf_ids,
1195 "long input mismatch: {} vs {} tokens",
1196 our_ids.len(),
1197 hf_ids.len(),
1198 );
1199 }
1200
1201 use std::sync::OnceLock;
1204
1205 struct ExtendedCorpus {
1206 longbench: Vec<String>,
1207 sharegpt: Vec<String>,
1208 }
1209
1210 fn extended_corpus() -> &'static ExtendedCorpus {
1211 static CORPUS: OnceLock<ExtendedCorpus> = OnceLock::new();
1212 CORPUS.get_or_init(|| {
1213 let api = Api::new().unwrap();
1214
1215 let lb_repo = api.dataset("zai-org/LongBench-v2".to_string());
1217 let lb_path = lb_repo.get("data.json").unwrap();
1218 let lb_data: Vec<serde_json::Value> =
1219 serde_json::from_str(&fs::read_to_string(lb_path).unwrap()).unwrap();
1220 let longbench: Vec<String> = lb_data
1221 .iter()
1222 .filter_map(|item| {
1223 let ctx = item.get("context")?.as_str()?;
1224 if ctx.is_empty() {
1225 None
1226 } else {
1227 Some(ctx.to_string())
1228 }
1229 })
1230 .collect();
1231
1232 let sg_repo = api.dataset("RyokoAI/ShareGPT52K".to_string());
1234 let sg_path = sg_repo.get("sg_90k_part1.json").unwrap();
1235 let sg_data: Vec<serde_json::Value> =
1236 serde_json::from_str(&fs::read_to_string(sg_path).unwrap()).unwrap();
1237 let sharegpt: Vec<String> = sg_data
1238 .iter()
1239 .filter_map(|item| {
1240 let messages = item.get("conversations")?.as_array()?;
1241 let parts: Vec<String> = messages
1242 .iter()
1243 .filter_map(|msg| {
1244 let role = msg
1245 .get("from")
1246 .and_then(|v| v.as_str())
1247 .unwrap_or("unknown");
1248 let value = msg.get("value").and_then(|v| v.as_str())?;
1249 if value.is_empty() {
1250 return None;
1251 }
1252 Some(format!("[{role}]: {value}"))
1253 })
1254 .collect();
1255 if parts.is_empty() {
1256 None
1257 } else {
1258 Some(parts.join("\n\n"))
1259 }
1260 })
1261 .collect();
1262
1263 ExtendedCorpus {
1264 longbench,
1265 sharegpt,
1266 }
1267 })
1268 }
1269
1270 fn compare_encode_decode_batched(
1272 model_name: &str,
1273 corpus: &[String],
1274 batch_size: usize,
1275 progress: bool,
1276 ) -> Vec<String> {
1277 let hf = tokenizers::Tokenizer::from_pretrained(model_name, None)
1278 .unwrap_or_else(|e| panic!("{model_name}: HF load failed: {e}"));
1279 let ours = Tokenizer::from_model(model_name)
1280 .unwrap_or_else(|e| panic!("{model_name}: fastokens load failed: {e}"));
1281
1282 let total = corpus.len();
1283 let mut processed = 0usize;
1284 let mut failures = Vec::new();
1285 for chunk in corpus.chunks(batch_size) {
1286 let hf_results: Vec<Vec<u32>> = chunk
1287 .iter()
1288 .map(|input| {
1289 hf.encode(input.as_str(), false)
1290 .unwrap_or_else(|e| panic!("{model_name}: HF encode: {e}"))
1291 .get_ids()
1292 .to_vec()
1293 })
1294 .collect();
1295
1296 let our_results = match ours.encode_batch(chunk, false) {
1297 Ok(r) => r,
1298 Err(e) => {
1299 failures.push(format!(" encode_batch error: {e}"));
1300 continue;
1301 }
1302 };
1303
1304 for (i, (hf_ids, our_ids)) in hf_results.iter().zip(our_results.iter()).enumerate() {
1305 let input = &chunk[i];
1306 let input_preview = {
1307 let mut end = input.len().min(80);
1308 while end < input.len() && !input.is_char_boundary(end) {
1309 end += 1;
1310 }
1311 &input[..end]
1312 };
1313
1314 if our_ids != hf_ids {
1315 failures.push(format!(
1316 " encode mismatch on {:?}: got {} tokens, expected {}\n\
1317 \x20 ours: {:?}\n\
1318 \x20 hf: {:?}",
1319 input_preview,
1320 our_ids.len(),
1321 hf_ids.len(),
1322 &our_ids[..our_ids.len().min(20)],
1323 &hf_ids[..hf_ids.len().min(20)],
1324 ));
1325 }
1326
1327 if hf_ids.is_empty() || input.is_empty() {
1329 continue;
1330 }
1331 let hf_decoded = match hf.decode(hf_ids, false) {
1332 Ok(d) => d,
1333 Err(_) => continue,
1334 };
1335 let our_decoded = match ours.decode(hf_ids, false) {
1336 Ok(d) => d,
1337 Err(e) => {
1338 failures.push(format!(" decode error on {input_preview:?}: {e}"));
1339 continue;
1340 }
1341 };
1342 if our_decoded != hf_decoded {
1343 failures.push(format!(
1344 " decode mismatch on {input_preview:?}:\n\
1345 \x20 ours: {:?}\n\
1346 \x20 hf: {:?}",
1347 &our_decoded[..our_decoded.len().min(100)],
1348 &hf_decoded[..hf_decoded.len().min(100)],
1349 ));
1350 }
1351 }
1352 processed += chunk.len();
1353 if progress {
1354 eprint!(
1355 "\r {model_name}: {processed}/{total} ({:.0}%)",
1356 processed as f64 / total as f64 * 100.0,
1357 );
1358 }
1359 }
1360 if progress {
1361 eprintln!();
1362 }
1363 failures
1364 }
1365
1366 fn run_extended(model_name: &str) {
1367 let progress = std::env::var("EXTENDED_PROGRESS").is_ok();
1368 let corpus = extended_corpus();
1369 if progress {
1370 eprintln!(
1371 " {model_name}: longbench ({} samples)",
1372 corpus.longbench.len()
1373 );
1374 }
1375 let mut failures =
1376 compare_encode_decode_batched(model_name, &corpus.longbench, 10, progress);
1377 if progress {
1378 eprintln!(
1379 " {model_name}: sharegpt ({} samples)",
1380 corpus.sharegpt.len()
1381 );
1382 }
1383 failures.extend(compare_encode_decode_batched(
1384 model_name,
1385 &corpus.sharegpt,
1386 10,
1387 progress,
1388 ));
1389 assert!(
1390 failures.is_empty(),
1391 "{model_name} extended ({} failures):\n{}",
1392 failures.len(),
1393 failures.join("\n"),
1394 );
1395 }
1396
/// Declares one extended-corpus test per `name => "model"` entry. Each
/// generated test is `#[test]` + `#[ignore]`, so it only runs when ignored
/// tests are requested (e.g. `cargo test -- --ignored`). Its body is a single
/// `run_extended` call for the given model.
macro_rules! extended_tests {
    ($($(#[$meta:meta])* $name:ident => $model:literal;)*) => {
        $(
            $(#[$meta])*
            #[test]
            #[ignore]
            fn $name() {
                run_extended($model);
            }
        )*
    };
}

extended_tests! {
    extended_minimax_m2_1 => "MiniMaxAI/MiniMax-M2.1";
    extended_nemotron => "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16";
    extended_deepseek_v3_2 => "deepseek-ai/DeepSeek-V3.2";
    extended_gpt_oss => "openai/gpt-oss-120b";
    extended_qwen3 => "Qwen/Qwen3-0.6B";
    extended_mistral_nemo => "mistralai/Mistral-Nemo-Instruct-2407";
    extended_qwen3_nemotron => "nvidia/Qwen3-Nemotron-235B-A22B-GenRM";
    extended_mistral_large => "mistralai/Mistral-Large-3-675B-Instruct-2512";
    /// NOTE(review): runs the same model as `extended_qwen3`; confirm whether
    /// a different small Qwen checkpoint was intended here.
    extended_qwen_small => "Qwen/Qwen3-0.6B";
}
1450
/// Every model in `HF_MODELS` must reproduce each sample text exactly after
/// `encode_with_special_tokens(.., false)` followed by `decode(.., false)`.
///
/// Load, encode and decode errors are reported as failures alongside text
/// mismatches. Previously, encode/decode errors were silently dropped via
/// `.ok()?`, so a broken pipeline could pass by skipping samples.
#[test]
fn encode_decode_roundtrip_all_models() {
    let texts = &[
        "Hello, world!",
        "日本語テスト",
        "The quick brown fox jumps over the lazy dog.",
        "fn main() { println!(\"hello\"); }",
        " leading and trailing spaces ",
        "line1\nline2\ttabbed",
        "0123456789",
        "🌍🎉✨",
    ];
    let failures: Vec<String> = HF_MODELS
        .iter()
        .flat_map(|model| {
            let tok = match Tokenizer::from_model(model) {
                Ok(t) => t,
                Err(e) => return vec![format!("{model}: load error: {e}")],
            };
            texts
                .iter()
                .filter_map(|text| {
                    let ids = match tok.encode_with_special_tokens(text, false) {
                        Ok(ids) => ids,
                        Err(e) => return Some(format!("{model}: encode error on {text:?}: {e}")),
                    };
                    let decoded = match tok.decode(&ids, false) {
                        Ok(d) => d,
                        Err(e) => return Some(format!("{model}: decode error on {text:?}: {e}")),
                    };
                    (decoded != *text).then(|| format!("{model}: {text:?} → {decoded:?}"))
                })
                .collect()
        })
        .collect();
    assert!(
        failures.is_empty(),
        "encode→decode roundtrip failures:\n{}",
        failures.join("\n")
    );
}
1493
/// With `add_special_tokens = true`, Mistral-Nemo must prepend its `<s>` BOS id
/// and otherwise produce the same ids. For Qwen3 the flag must change nothing.
#[test]
fn add_bos_token() {
    // Mistral-Nemo: the flag controls a leading `<s>`.
    let mistral = Tokenizer::from_model("mistralai/Mistral-Nemo-Instruct-2407").unwrap();
    let bos_id = mistral.token_to_id("<s>").expect("<s> not in vocabulary");
    let encode_mistral =
        |flag: bool| mistral.encode_with_special_tokens("hello world", flag).unwrap();
    let with_bos = encode_mistral(true);
    let without_bos = encode_mistral(false);

    assert_eq!(
        with_bos.first().copied(),
        Some(bos_id),
        "first token should be BOS when add_special_tokens=true"
    );
    assert_ne!(
        without_bos.first().copied(),
        Some(bos_id),
        "BOS should be absent when add_special_tokens=false"
    );
    // Apart from the BOS, the ids must be identical.
    assert_eq!(&with_bos[1..], without_bos.as_slice());

    // Qwen3: the flag must be a no-op.
    let qwen = Tokenizer::from_model("Qwen/Qwen3-0.6B").unwrap();
    let encode_qwen = |flag: bool| qwen.encode_with_special_tokens("hello world", flag).unwrap();
    let with_flag = encode_qwen(true);
    let without_flag = encode_qwen(false);
    assert_eq!(
        with_flag, without_flag,
        "Qwen3 has no BOS post-processor — add_special_tokens should have no effect"
    );
}
1541
/// Decoding Mistral-Nemo ids that include special tokens must give back the
/// plain text when `skip_special_tokens = true`. It must give a different,
/// text-containing string when `false`.
#[test]
fn decode_skip_special_tokens() {
    let model = "mistralai/Mistral-Nemo-Instruct-2407";
    let text = "hello world";
    let tok = Tokenizer::from_model(model).unwrap();

    let ids_with = tok.encode_with_special_tokens(text, true).unwrap();
    let plain_len = tok.encode_with_special_tokens(text, false).unwrap().len();
    assert!(ids_with.len() > plain_len, "expected BOS/EOS from {model}");

    let stripped = tok.decode(&ids_with, true).unwrap();
    assert_eq!(stripped, text);

    let verbatim = tok.decode(&ids_with, false).unwrap();
    assert_ne!(verbatim, text);
    assert!(verbatim.contains(text));
}
1563
/// `decode_batch` must return one output per input. Each output must equal
/// sequential `decode` of the same ids and the original sentence.
///
/// The length check prevents a vacuous pass: `zip` would otherwise stop at
/// the shorter side if `decode_batch` dropped entries.
#[test]
fn decode_batch_matches_sequential() {
    let tok = Tokenizer::from_model("Qwen/Qwen3-0.6B").unwrap();
    let sentences = &["first sentence", "second sentence", "日本語テスト", ""];
    let id_batches: Vec<Vec<u32>> = sentences
        .iter()
        .map(|s| tok.encode_with_special_tokens(s, false).unwrap())
        .collect();
    let refs: Vec<&[u32]> = id_batches.iter().map(Vec::as_slice).collect();
    let batch_out = tok.decode_batch(&refs, false).unwrap();
    assert_eq!(
        batch_out.len(),
        sentences.len(),
        "decode_batch returned the wrong number of outputs"
    );
    for ((out, ids), expected) in batch_out.iter().zip(id_batches.iter()).zip(sentences.iter()) {
        let sequential = tok.decode(ids, false).unwrap();
        assert_eq!(out, &sequential, "batch vs sequential decode differ for {expected:?}");
        assert_eq!(out, expected);
    }
}
1579
/// Decoding the token strings obtained via `id_to_token` must give the same
/// text as decoding the ids directly.
#[test]
fn decode_tokens_matches_decode_by_id() {
    let tok = Tokenizer::from_model("Qwen/Qwen3-0.6B").unwrap();
    let samples = ["Hello, world!", "The quick brown fox", "🌍 emoji"];
    for text in samples {
        let ids = tok.encode_with_special_tokens(text, false).unwrap();
        let mut pieces: Vec<String> = Vec::with_capacity(ids.len());
        for &id in &ids {
            pieces.push(tok.id_to_token(id).unwrap().to_string());
        }
        let by_id = tok.decode(&ids, false).unwrap();
        let by_token = tok.decode_tokens(pieces).unwrap();
        assert_eq!(by_id, by_token, "mismatch for {text:?}");
    }
}
1595
/// The empty string encodes to no ids, and no ids decode to the empty string.
#[test]
fn empty_string_encode_decode() {
    let tok = Tokenizer::from_model("Qwen/Qwen3-0.6B").unwrap();
    let encoded = tok.encode_with_special_tokens("", false).unwrap();
    assert!(encoded.is_empty(), "expected no tokens for empty string");
    let decoded = tok.decode(&[], false).unwrap();
    assert_eq!(decoded, "");
}
1604
/// Encoding the decoded text again must reproduce exactly the first ids.
#[test]
fn encode_is_stable_after_decode() {
    let tok = Tokenizer::from_model("Qwen/Qwen3-0.6B").unwrap();
    let encode = |s: &str| tok.encode_with_special_tokens(s, false).unwrap();
    for text in ["hello world", "日本語テスト", "fn foo() {}"] {
        let ids1 = encode(text);
        let decoded = tok.decode(&ids1, false).unwrap();
        let ids2 = encode(decoded.as_str());
        assert_eq!(ids1, ids2, "encode not stable after decode for {text:?}");
    }
}
1616
/// For every model in `HF_MODELS`, `post_process(ids, false)` must return the
/// ids unchanged.
#[test]
fn post_process_false_is_identity_all_models() {
    let payload: Vec<u32> = vec![100, 200, 300];
    for model in HF_MODELS {
        let tok = Tokenizer::from_model(model).unwrap();
        let processed = tok.post_process(payload.clone(), false);
        assert_eq!(
            processed, payload,
            "{model}: post_process(false) should be identity"
        );
    }
}
1630
/// For Mistral-Nemo, `post_process(ids, true)` must add tokens around the
/// payload while keeping the payload contiguous. `false` must leave it untouched.
#[test]
fn post_process_true_adds_special_tokens() {
    let tok = Tokenizer::from_model("mistralai/Mistral-Nemo-Instruct-2407").unwrap();
    let payload: Vec<u32> = vec![10, 20, 30];
    let plain = tok.post_process(payload.clone(), false);
    let decorated = tok.post_process(payload.clone(), true);

    assert_eq!(plain, payload);
    assert!(
        decorated.len() > plain.len(),
        "expected special tokens to be added"
    );

    let payload_is_contiguous = decorated
        .windows(payload.len())
        .any(|window| window == payload.as_slice());
    assert!(
        payload_is_contiguous,
        "payload should appear contiguously in post-processed output"
    );
}
1652
/// An id outside the vocabulary (`u32::MAX`) decodes to nothing rather than
/// erroring.
#[test]
fn decode_unknown_id_is_skipped() {
    let tok = Tokenizer::from_model("Qwen/Qwen3-0.6B").unwrap();
    let decoded = tok.decode(&[u32::MAX], false).unwrap();
    assert_eq!(decoded, "");
}
1659
/// An unknown id between valid ids is dropped. The result equals decoding the
/// valid runs on either side separately and concatenating them.
#[test]
fn decode_mixed_valid_and_unknown_ids() {
    let tok = Tokenizer::from_model("Qwen/Qwen3-0.6B").unwrap();
    let hello_ids = tok.encode_with_special_tokens("hello", false).unwrap();
    let world_ids = tok.encode_with_special_tokens(" world", false).unwrap();

    let mut mixed = hello_ids.clone();
    mixed.push(u32::MAX);
    mixed.extend_from_slice(&world_ids);

    let expected = format!(
        "{}{}",
        tok.decode(&hello_ids, false).unwrap(),
        tok.decode(&world_ids, false).unwrap()
    );
    assert_eq!(tok.decode(&mixed, false).unwrap(), expected);
}
1677
/// For each model in `HF_MODELS`, every probe id that maps to a token must map
/// back to the same id.
///
/// Probe ids for which `id_to_token` returns `None` are skipped. A token that
/// `token_to_id` cannot find is reported as a failure. Previously it was
/// silently skipped by `?`, which hid real id↔token inconsistencies.
#[test]
fn token_id_roundtrip_all_models() {
    let probe_ids = [0u32, 1, 2, 100, 1000, 10_000];
    let failures: Vec<String> = HF_MODELS
        .iter()
        .flat_map(|model| {
            let tok = match Tokenizer::from_model(model) {
                Ok(t) => t,
                Err(e) => return vec![format!("{model}: load error: {e}")],
            };
            probe_ids
                .iter()
                .filter_map(|&id| {
                    let token = tok.id_to_token(id)?;
                    match tok.token_to_id(token) {
                        Some(back) if back == id => None,
                        Some(back) => Some(format!("{model}: id {id} → {token:?} → {back}")),
                        None => Some(format!("{model}: id {id} → {token:?} → <not found>")),
                    }
                })
                .collect()
        })
        .collect();
    assert!(
        failures.is_empty(),
        "id↔token roundtrip failures:\n{}",
        failures.join("\n")
    );
}
1709
1710 const STREAM_MODEL: &str = "Qwen/Qwen3-0.6B";
1713
1714 fn stream_tok() -> Tokenizer {
1715 Tokenizer::from_model(STREAM_MODEL).expect("failed to load tokenizer")
1716 }
1717
1718 fn stream_collect(tok: &Tokenizer, ids: &[u32], skip: bool) -> (String, usize) {
1719 let mut buf = Vec::new();
1720 let mut prefix = String::new();
1721 let mut prefix_index = 0usize;
1722 let mut out = String::new();
1723 for &id in ids {
1724 let chunk: Option<String> = super::decode_stream_step(
1725 tok,
1726 vec![id],
1727 skip,
1728 &mut buf,
1729 &mut prefix,
1730 &mut prefix_index,
1731 )
1732 .unwrap();
1733 if let Some(c) = chunk {
1734 out.push_str(&c);
1735 }
1736 }
1737 (out, buf.len())
1738 }
1739
/// Streaming an ASCII sentence token-by-token reproduces it exactly.
#[test]
fn decode_stream_reconstructs_ascii() {
    let text = "Hello, world! This is a streaming decode test.";
    let tok = stream_tok();
    let ids = tok.encode_with_special_tokens(text, false).unwrap();
    let streamed = stream_collect(&tok, &ids, false).0;
    assert_eq!(streamed, text);
}
1748
/// Streaming mixed-script text (CJK, emoji, Cyrillic) token-by-token
/// reproduces it exactly.
#[test]
fn decode_stream_reconstructs_unicode() {
    let text = "日本語テスト: こんにちは 🌍 — привет мир";
    let tok = stream_tok();
    let ids = tok.encode_with_special_tokens(text, false).unwrap();
    let streamed = stream_collect(&tok, &ids, false).0;
    assert_eq!(streamed, text);
}
1757
/// Streaming a line of source code token-by-token reproduces it exactly.
#[test]
fn decode_stream_reconstructs_code() {
    let text = r#"fn main() { println!("hello"); }"#;
    let tok = stream_tok();
    let ids = tok.encode_with_special_tokens(text, false).unwrap();
    let streamed = stream_collect(&tok, &ids, false).0;
    assert_eq!(streamed, text);
}
1766
/// Streaming zero ids emits nothing and leaves the buffer empty.
#[test]
fn decode_stream_empty_ids_no_output() {
    let tok = stream_tok();
    let (text, pending) = stream_collect(&tok, &[], false);
    assert!(text.is_empty());
    assert_eq!(pending, 0);
}
1774
/// Streaming just the first token of "hello" already emits some text.
#[test]
fn decode_stream_single_token() {
    let tok = stream_tok();
    let ids = tok.encode_with_special_tokens("hello", false).unwrap();
    assert!(!ids.is_empty());
    let first_only = &ids[..1];
    let (text, _) = stream_collect(&tok, first_only, false);
    assert!(!text.is_empty());
}
1783
/// A single `decode_stream_step` fed all ids at once must emit the same text
/// as feeding them one at a time.
#[test]
fn decode_stream_batch_step_matches_sequential() {
    let tok = stream_tok();
    let ids = tok
        .encode_with_special_tokens("The quick brown fox jumps over the lazy dog.", false)
        .unwrap();
    let sequential = stream_collect(&tok, &ids, false).0;

    let (mut buf, mut prefix, mut prefix_index) = (Vec::new(), String::new(), 0usize);
    let step = super::decode_stream_step(
        &tok,
        ids,
        false,
        &mut buf,
        &mut prefix,
        &mut prefix_index,
    )
    .unwrap();
    let batch: String = step.unwrap_or_default();
    assert_eq!(sequential, batch);
}
1805
/// When the stream buffer is pre-seeded with prompt ids, only the text of the
/// continuation ids fed afterwards is emitted.
#[test]
fn decode_stream_pre_seeded_only_returns_new_tokens() {
    let tok = stream_tok();
    let prompt = "The capital of France is";
    let cont = " Paris.";
    let prompt_ids = tok.encode_with_special_tokens(prompt, false).unwrap();
    let cont_ids = tok.encode_with_special_tokens(cont, false).unwrap();

    // The prompt ids are only needed as the initial buffer contents.
    let mut buf = prompt_ids;
    let mut prefix = String::new();
    let mut prefix_index = 0usize;
    let streamed: String = cont_ids
        .iter()
        .filter_map(|&id| {
            super::decode_stream_step(
                &tok,
                vec![id],
                false,
                &mut buf,
                &mut prefix,
                &mut prefix_index,
            )
            .unwrap()
        })
        .collect();
    assert_eq!(streamed, cont);
}
1833
/// Streaming Mistral-Nemo ids that include special tokens with `skip = true`
/// yields the plain text. With `skip = false`, the output still contains it.
#[test]
fn decode_stream_skip_special_tokens() {
    let tok = Tokenizer::from_model("mistralai/Mistral-Nemo-Instruct-2407").unwrap();
    let text = "hello";
    let ids_with = tok.encode_with_special_tokens(text, true).unwrap();
    let bare_len = tok.encode_with_special_tokens(text, false).unwrap().len();
    assert!(ids_with.len() > bare_len, "expected BOS/EOS tokens");

    let verbatim = stream_collect(&tok, &ids_with, false).0;
    let stripped = stream_collect(&tok, &ids_with, true).0;
    assert_eq!(stripped, text);
    assert!(verbatim.contains(&stripped));
}
1849
/// After streaming a long, repetitive input, fewer than 10 ids remain in the
/// stream buffer.
#[test]
fn decode_stream_buffer_does_not_grow_unboundedly() {
    let tok = stream_tok();
    let repeated = "word ".repeat(80);
    let ids = tok
        .encode_with_special_tokens(repeated.trim(), false)
        .unwrap();
    let final_buf_len = stream_collect(&tok, &ids, false).1;
    assert!(
        final_buf_len < 10,
        "buffer grew to {final_buf_len} entries after {} tokens",
        ids.len()
    );
}
1862
/// Every chunk the stream emits is non-empty, and the chunks concatenate back
/// to the original text.
#[test]
fn decode_stream_chunks_are_non_empty_and_concatenate() {
    let text = "one two three four five six seven eight nine ten";
    let tok = stream_tok();
    let ids = tok.encode_with_special_tokens(text, false).unwrap();

    let (mut buf, mut prefix, mut prefix_index) = (Vec::new(), String::new(), 0usize);
    let mut pieces: Vec<String> = Vec::new();
    for id in ids.iter().copied() {
        let step = super::decode_stream_step(
            &tok,
            vec![id],
            false,
            &mut buf,
            &mut prefix,
            &mut prefix_index,
        )
        .unwrap();
        if let Some(piece) = step {
            assert!(!piece.is_empty(), "stream emitted an empty chunk");
            pieces.push(piece);
        }
    }
    assert_eq!(pieces.concat(), text);
}
1889
/// Feeding an id outside the vocabulary (`u32::MAX`) to a fresh stream
/// returns `Ok` rather than an error.
#[test]
fn decode_stream_unknown_id_does_not_error() {
    let tok = stream_tok();
    let (mut buf, mut prefix, mut prefix_index) = (Vec::new(), String::new(), 0usize);
    let result = super::decode_stream_step(
        &tok,
        vec![u32::MAX],
        false,
        &mut buf,
        &mut prefix,
        &mut prefix_index,
    );
    assert!(result.is_ok(), "expected Ok, got {result:?}");
}
1910
/// A stream step whose decoded buffer does not start with the supplied
/// `prefix` must fail. The error message must start with
/// "Invalid prefix encountered".
///
/// Previously the assertion was wrapped in `if let Err(..)`, so the test also
/// passed when no error was produced at all.
#[test]
fn decode_stream_invalid_prefix_error_message() {
    let tok = stream_tok();
    // Long enough that the decoded buffer is longer than the 7-byte bogus
    // prefix below, regardless of how the text is split into tokens.
    let ids = tok.encode_with_special_tokens("hello world", false).unwrap();
    let mut buf = ids.clone();
    let mut prefix = "ZZZZZZZ".to_string();
    let mut prefix_index = 0usize;
    let result: Result<Option<String>, String> = super::decode_stream_step(
        &tok,
        vec![*ids.last().unwrap()],
        false,
        &mut buf,
        &mut prefix,
        &mut prefix_index,
    );
    let msg = result.expect_err("expected an invalid-prefix error");
    assert!(
        msg.starts_with("Invalid prefix encountered"),
        "unexpected error: {msg:?}"
    );
}
1933}