1use std::{
2 collections::HashMap,
3 path::{Path, PathBuf},
4};
5
6use anyhow::{Error, Result};
7use base64::{engine::general_purpose::STANDARD, Engine as _};
8use rustc_hash::FxHashMap;
9use tiktoken_rs::{
10 cl100k_base, o200k_base, p50k_base, p50k_edit, r50k_base,
11 tokenizer::{get_tokenizer, Tokenizer},
12 CoreBPE,
13};
14
15use crate::{
16 chat_template::{
17 load_chat_template_from_file, ChatTemplateContentFormat, ChatTemplateParams,
18 ChatTemplateState, ThinkingKeyName, ThinkingToggle,
19 },
20 encoders::kimi_k25_tools::apply_kimi_k25_tools,
21 factory::discover_chat_template_in_dir,
22 kimi_k2_tokenizer,
23 traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait},
24};
25
/// Strategy used by `apply_chat_template` to turn chat messages into a prompt.
#[derive(Clone, Copy, Debug)]
enum Renderer {
    /// Render through the stored chat template (`ChatTemplateState::apply`).
    Jinja,
    /// Render through `apply_kimi_k25_tools`; selected when `config.json` lists
    /// the `KimiK25ForConditionalGeneration` architecture.
    KimiK25Tools,
}
31
/// Pre-tokenization split pattern handed to `CoreBPE::new` for tiktoken files
/// that are not recognised as Kimi K2 models.
///
/// The pattern relies on a negative look-ahead (`(?!\S)`), so it needs a regex
/// engine with look-around support.
// NOTE(review): presumably mirrors the upstream cl100k_base pattern, as the name
// suggests — confirm against tiktoken before editing.
const CL100K_BASE_PATTERN: &str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";
37
/// Integer type of a BPE rank, i.e. the value stored for every token in the
/// encoder map and in the special-token map passed to `CoreBPE::new`.
type Rank = u32;
39
40#[derive(Default)]
46struct TiktokenConfig {
47 special_tokens: SpecialTokens,
48 added_tokens: HashMap<String, TokenIdType>,
50 chat_template: Option<String>,
51}
52
53fn parse_tiktoken_config(value: &serde_json::Value) -> TiktokenConfig {
55 TiktokenConfig {
56 special_tokens: parse_special_tokens(value),
57 added_tokens: parse_added_tokens_decoder(value),
58 chat_template: value
59 .get("chat_template")
60 .and_then(|v| v.as_str())
61 .map(String::from),
62 }
63}
64
65fn load_tiktoken_config_from_dir(
69 dir: &Path,
70) -> Result<(TiktokenConfig, Option<serde_json::Value>)> {
71 let config_path = dir.join("tokenizer_config.json");
72 if !config_path.exists() {
73 return Ok((TiktokenConfig::default(), None));
74 }
75 let content = std::fs::read_to_string(&config_path)?;
76 let value: serde_json::Value = serde_json::from_str(&content)?;
77 let config = parse_tiktoken_config(&value);
78 Ok((config, Some(value)))
79}
80
81fn parse_added_tokens_decoder(config: &serde_json::Value) -> HashMap<String, TokenIdType> {
85 let mut tokens = HashMap::new();
86 if let Some(added) = config
87 .get("added_tokens_decoder")
88 .and_then(|v| v.as_object())
89 {
90 for (id_str, token_info) in added {
91 if let (Ok(id), Some(content)) = (
92 id_str.parse::<TokenIdType>(),
93 token_info.get("content").and_then(|v| v.as_str()),
94 ) {
95 tokens.insert(content.to_string(), id);
96 }
97 }
98 }
99 tokens
100}
101
102fn parse_special_tokens(config: &serde_json::Value) -> SpecialTokens {
107 let get_str = |key: &str| {
108 config.get(key).and_then(|v| {
109 v.as_str()
110 .map(String::from)
111 .or_else(|| v.get("content").and_then(|c| c.as_str()).map(String::from))
112 })
113 };
114
115 let additional: Vec<String> = config
116 .get("additional_special_tokens")
117 .and_then(|v| v.as_array())
118 .map(|arr| {
119 arr.iter()
120 .filter_map(|v| {
121 v.as_str()
122 .map(String::from)
123 .or_else(|| v.get("content").and_then(|c| c.as_str()).map(String::from))
124 })
125 .collect()
126 })
127 .unwrap_or_default();
128
129 SpecialTokens {
130 bos_token: get_str("bos_token"),
131 eos_token: get_str("eos_token"),
132 unk_token: get_str("unk_token"),
133 sep_token: get_str("sep_token"),
134 pad_token: get_str("pad_token"),
135 cls_token: get_str("cls_token"),
136 mask_token: get_str("mask_token"),
137 additional_special_tokens: additional,
138 }
139}
140
141pub struct TiktokenTokenizer {
143 tokenizer: CoreBPE,
144 special_tokens: SpecialTokens,
145 vocab: HashMap<String, TokenIdType>,
146 reverse_vocab: HashMap<TokenIdType, String>,
147 vocab_size: usize,
148 chat_template: ChatTemplateState,
149 eos_token_ids: Vec<TokenIdType>,
150 renderer: Renderer,
151}
152
/// The bundled `tiktoken_rs` encodings that `TiktokenTokenizer::new` can load.
#[derive(Clone, Copy, Debug)]
pub enum TiktokenModel {
    /// The `o200k_base` encoding.
    O200kBase,
    /// The `cl100k_base` encoding.
    Cl100kBase,
    /// The `p50k_base` encoding.
    P50kBase,
    /// The `p50k_edit` encoding.
    P50kEdit,
    /// The `r50k_base` encoding.
    R50kBase,
}
167
168impl TiktokenTokenizer {
169 pub fn new(model: TiktokenModel) -> Result<Self> {
171 let tokenizer =
172 match model {
173 TiktokenModel::O200kBase => o200k_base()
174 .map_err(|e| Error::msg(format!("Failed to load o200k_base: {e}")))?,
175 TiktokenModel::Cl100kBase => cl100k_base()
176 .map_err(|e| Error::msg(format!("Failed to load cl100k_base: {e}")))?,
177 TiktokenModel::P50kBase => {
178 p50k_base().map_err(|e| Error::msg(format!("Failed to load p50k_base: {e}")))?
179 }
180 TiktokenModel::P50kEdit => {
181 p50k_edit().map_err(|e| Error::msg(format!("Failed to load p50k_edit: {e}")))?
182 }
183 TiktokenModel::R50kBase => {
184 r50k_base().map_err(|e| Error::msg(format!("Failed to load r50k_base: {e}")))?
185 }
186 };
187
188 let special_tokens = Self::get_special_tokens_for_model(model);
189
190 let vocab_size = match model {
191 TiktokenModel::O200kBase => 200019,
192 TiktokenModel::Cl100kBase => 100256,
193 TiktokenModel::P50kBase | TiktokenModel::P50kEdit => 50281,
194 TiktokenModel::R50kBase => 50257,
195 };
196
197 Ok(TiktokenTokenizer {
198 tokenizer,
199 special_tokens,
200 vocab: HashMap::new(),
201 reverse_vocab: HashMap::new(),
202 vocab_size,
203 chat_template: ChatTemplateState::empty(),
204 eos_token_ids: Vec::new(), renderer: Renderer::Jinja,
206 })
207 }
208
209 pub fn from_dir(dir: &Path) -> Result<Self> {
211 Self::from_dir_with_chat_template(dir, None)
212 }
213
214 pub fn from_dir_with_chat_template(
217 dir: &Path,
218 chat_template_path: Option<&str>,
219 ) -> Result<Self> {
220 let tiktoken_path = find_tiktoken_file(dir)?;
221 Self::load_from_path(&tiktoken_path, chat_template_path)
222 }
223
224 pub fn from_file(tiktoken_path: &Path) -> Result<Self> {
227 Self::from_file_with_chat_template(tiktoken_path, None)
228 }
229
230 pub fn from_file_with_chat_template(
232 tiktoken_path: &Path,
233 chat_template_path: Option<&str>,
234 ) -> Result<Self> {
235 Self::load_from_path(tiktoken_path, chat_template_path)
236 }
237
238 fn load_from_path(tiktoken_path: &Path, chat_template_path: Option<&str>) -> Result<Self> {
240 let tiktoken_path_str = tiktoken_path
242 .to_str()
243 .ok_or_else(|| Error::msg("Tiktoken file path is not valid UTF-8"))?;
244 let encoder = load_tiktoken_bpe(tiktoken_path_str)?;
245
246 let dir = tiktoken_path
248 .parent()
249 .ok_or_else(|| Error::msg("Cannot determine parent directory of tiktoken file"))?;
250 let (mut config, tokenizer_config_value) = load_tiktoken_config_from_dir(dir)?;
251
252 let pattern = if kimi_k2_tokenizer::matches(tokenizer_config_value.as_ref(), dir) {
257 kimi_k2_tokenizer::apply_reserved_special_tokens(
258 &mut config.added_tokens,
259 encoder.len(),
260 );
261 kimi_k2_tokenizer::KIMI_K2_PATTERN
262 } else {
263 CL100K_BASE_PATTERN
264 };
265
266 let special_tokens_encoder: FxHashMap<String, Rank> = config
268 .added_tokens
269 .iter()
270 .map(|(k, &v)| (k.clone(), v))
271 .collect();
272
273 let vocab_size = encoder
276 .values()
277 .copied()
278 .chain(special_tokens_encoder.values().copied())
279 .max()
280 .map(|id| id as usize + 1)
281 .unwrap_or(0);
282 let (vocab, reverse_vocab) = build_vocab_maps(&encoder, &config.added_tokens);
283 let tokenizer = CoreBPE::new(encoder, special_tokens_encoder, pattern)?;
284
285 let chat_template = if let Some(p) = chat_template_path {
288 load_chat_template_from_file(p)?
289 } else {
290 config.chat_template.or_else(|| {
291 discover_chat_template_in_dir(dir)
292 .and_then(|p| load_chat_template_from_file(&p).ok().flatten())
293 })
294 };
295
296 let eos_token_ids = crate::eos::load_eos_token_ids(dir);
298
299 let renderer = detect_renderer_from_config(dir);
301
302 Ok(TiktokenTokenizer {
303 tokenizer,
304 special_tokens: config.special_tokens,
305 vocab,
306 reverse_vocab,
307 vocab_size,
308 chat_template: ChatTemplateState::new(chat_template)?,
309 eos_token_ids,
310 renderer,
311 })
312 }
313
314 pub fn from_model_name(model_name: &str) -> Result<Self> {
316 let bare = model_name.rsplit('/').next().unwrap_or(model_name);
317 let model = match get_tokenizer(bare) {
318 Some(Tokenizer::O200kBase) => TiktokenModel::O200kBase,
319 Some(Tokenizer::Cl100kBase) => TiktokenModel::Cl100kBase,
320 Some(Tokenizer::P50kBase) => TiktokenModel::P50kBase,
321 Some(Tokenizer::P50kEdit) => TiktokenModel::P50kEdit,
322 Some(Tokenizer::R50kBase) => TiktokenModel::R50kBase,
323 _ => return Err(anyhow::anyhow!(
324 "Unrecognized OpenAI model name: '{model_name}'. Expected GPT-3, GPT-3.5, GPT-4, GPT-4o, GPT-4.5, GPT-5, o1, o3, o4, or related model names"
325 )),
326 };
327 Self::new(model)
328 }
329
330 fn get_special_tokens_for_model(model: TiktokenModel) -> SpecialTokens {
332 match model {
333 TiktokenModel::Cl100kBase => SpecialTokens {
334 bos_token: Some("<|endoftext|>".to_string()),
335 eos_token: Some("<|endoftext|>".to_string()),
336 unk_token: None,
337 sep_token: None,
338 pad_token: Some("<|endoftext|>".to_string()),
339 cls_token: None,
340 mask_token: None,
341 additional_special_tokens: vec![
342 "<|fim_prefix|>".to_string(),
343 "<|fim_middle|>".to_string(),
344 "<|fim_suffix|>".to_string(),
345 "<|endofprompt|>".to_string(),
346 ],
347 },
348 _ => SpecialTokens {
349 bos_token: Some("<|endoftext|>".to_string()),
350 eos_token: Some("<|endoftext|>".to_string()),
351 unk_token: None,
352 sep_token: None,
353 pad_token: Some("<|endoftext|>".to_string()),
354 cls_token: None,
355 mask_token: None,
356 additional_special_tokens: vec![],
357 },
358 }
359 }
360}
361
362fn load_tiktoken_bpe(path: &str) -> Result<FxHashMap<Vec<u8>, Rank>> {
366 let content = std::fs::read_to_string(path)?;
367 let mut encoder =
368 FxHashMap::with_capacity_and_hasher(content.lines().count(), Default::default());
369 for line in content.lines() {
370 if line.is_empty() {
371 continue;
372 }
373 let mut parts = line.split_whitespace();
374 let token_b64 = parts
375 .next()
376 .ok_or_else(|| Error::msg("missing token in tiktoken file"))?;
377 let rank_str = parts
378 .next()
379 .ok_or_else(|| Error::msg("missing rank in tiktoken file"))?;
380 let token_bytes = STANDARD.decode(token_b64)?;
381 let rank: Rank = rank_str.parse()?;
382 encoder.insert(token_bytes, rank);
383 }
384 Ok(encoder)
385}
386
387fn build_vocab_maps(
389 encoder: &FxHashMap<Vec<u8>, Rank>,
390 added_tokens: &HashMap<String, TokenIdType>,
391) -> (HashMap<String, TokenIdType>, HashMap<TokenIdType, String>) {
392 let capacity = encoder.len() + added_tokens.len();
393 let mut vocab = HashMap::with_capacity(capacity);
394 let mut reverse_vocab = HashMap::with_capacity(capacity);
395
396 for (token_bytes, &rank) in encoder {
398 if let Ok(token_str) = std::str::from_utf8(token_bytes) {
399 vocab.insert(token_str.to_string(), rank);
400 reverse_vocab.insert(rank, token_str.to_string());
401 }
402 }
403
404 for (token_str, &id) in added_tokens {
406 vocab.insert(token_str.clone(), id);
407 reverse_vocab.insert(id, token_str.clone());
408 }
409
410 (vocab, reverse_vocab)
411}
412
413fn find_tiktoken_file(dir: &Path) -> Result<PathBuf> {
417 let tiktoken_model = dir.join("tiktoken.model");
418 if tiktoken_model.exists() {
419 return Ok(tiktoken_model);
420 }
421
422 if let Ok(entries) = std::fs::read_dir(dir) {
424 for entry in entries.flatten() {
425 if let Some(name) = entry.file_name().to_str() {
426 if name.ends_with(".tiktoken") {
427 return Ok(entry.path());
428 }
429 }
430 }
431 }
432
433 Err(Error::msg(format!(
434 "No tiktoken model file found in '{}'",
435 dir.display()
436 )))
437}
438
/// Reports whether `dir` contains a tiktoken model file: either
/// `tiktoken.model` or any entry whose name ends in `.tiktoken`.
///
/// Returns `false` when the directory cannot be listed. Entries that cannot be
/// read, or whose name is not valid UTF-8, are ignored.
pub fn has_tiktoken_file(dir: &Path) -> bool {
    if dir.join("tiktoken.model").exists() {
        return true;
    }

    let Ok(entries) = std::fs::read_dir(dir) else {
        return false;
    };
    for entry in entries.flatten() {
        if let Some(name) = entry.file_name().to_str() {
            if name.ends_with(".tiktoken") {
                return true;
            }
        }
    }
    false
}
455
/// Reports whether `path` is named like a tiktoken model file, i.e. its final
/// component is exactly `tiktoken.model` or ends in `.tiktoken`.
///
/// Only the name is inspected; the file system is not touched. Paths without a
/// final component, or with a non-UTF-8 name, yield `false`.
pub fn is_tiktoken_file(path: &Path) -> bool {
    match path.file_name().and_then(|name| name.to_str()) {
        Some(name) => name == "tiktoken.model" || name.ends_with(".tiktoken"),
        None => false,
    }
}
462
463impl Encoder for TiktokenTokenizer {
464 fn encode(&self, input: &str, _add_special_tokens: bool) -> Result<Encoding> {
465 let tokens = self.tokenizer.encode_with_special_tokens(input);
480 Ok(Encoding::Tiktoken(tokens))
481 }
482
483 fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
484 inputs
485 .iter()
486 .map(|input| self.encode(input, add_special_tokens))
487 .collect()
488 }
489}
490
491impl Decoder for TiktokenTokenizer {
492 fn decode(&self, token_ids: &[TokenIdType], _skip_special_tokens: bool) -> Result<String> {
493 match self.tokenizer.decode(token_ids.to_vec()) {
494 Ok(text) => Ok(text),
495 Err(err) if is_unknown_tiktoken_decode_error(&err) => Err(Error::msg(format!(
496 "tiktoken decode failed for unknown token id: {err}"
497 ))),
498 Err(err) => {
499 let bytes: Vec<u8> = self
501 .tokenizer
502 ._decode_native_and_split(token_ids.to_vec())
503 .flatten()
504 .collect();
505 tracing::warn!(
506 error = %err,
507 token_count = token_ids.len(),
508 "tiktoken decode failed; returning lossy UTF-8 fallback"
509 );
510 Ok(String::from_utf8_lossy(&bytes).into_owned())
511 }
512 }
513 }
514}
515
516fn is_unknown_tiktoken_decode_error(err: &Error) -> bool {
524 err.to_string().starts_with("Invalid token for decoding:")
525}
526
527impl TokenizerTrait for TiktokenTokenizer {
528 fn vocab_size(&self) -> usize {
529 self.vocab_size
530 }
531
532 fn get_special_tokens(&self) -> &SpecialTokens {
533 &self.special_tokens
534 }
535
536 fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
537 self.vocab.get(token).copied()
538 }
539
540 fn id_to_token(&self, id: TokenIdType) -> Option<String> {
541 self.reverse_vocab.get(&id).cloned()
542 }
543
544 fn as_any(&self) -> &dyn std::any::Any {
545 self
546 }
547
548 fn apply_chat_template(
549 &self,
550 messages: &[serde_json::Value],
551 params: ChatTemplateParams,
552 ) -> Result<String> {
553 let params = if params.special_tokens.is_some() {
555 params
556 } else {
557 ChatTemplateParams {
558 special_tokens: Some(&self.special_tokens),
559 ..params
560 }
561 };
562 match self.renderer {
563 Renderer::Jinja => self.chat_template.apply(messages, params),
564 Renderer::KimiK25Tools => apply_kimi_k25_tools(&self.chat_template, messages, params),
565 }
566 }
567
568 fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
569 self.chat_template.content_format()
570 }
571
572 fn thinking_toggle(&self) -> ThinkingToggle {
573 self.chat_template.thinking_toggle()
574 }
575
576 fn thinking_key_name(&self) -> Option<ThinkingKeyName> {
577 self.chat_template.thinking_key_name()
578 }
579 fn eos_token_ids(&self) -> &[TokenIdType] {
580 &self.eos_token_ids
581 }
582
583 fn think_in_prefill(&self) -> bool {
584 self.chat_template.think_in_prefill()
585 }
586
587 fn set_chat_template(&mut self, template: String) -> Result<()> {
588 self.chat_template.set(template)
589 }
590}
591
592fn detect_renderer_from_config(dir: &Path) -> Renderer {
599 let path = dir.join("config.json");
600 if !path.exists() {
601 return Renderer::Jinja;
602 }
603 let content = match std::fs::read_to_string(&path) {
604 Ok(c) => c,
605 Err(err) => {
606 tracing::debug!(?err, ?path, "config.json unreadable; using Jinja renderer");
607 return Renderer::Jinja;
608 }
609 };
610 let value: serde_json::Value = match serde_json::from_str(&content) {
611 Ok(v) => v,
612 Err(err) => {
613 tracing::debug!(?err, ?path, "config.json malformed; using Jinja renderer");
614 return Renderer::Jinja;
615 }
616 };
617 let is_kimi = value
618 .get("architectures")
619 .and_then(|v| v.as_array())
620 .is_some_and(|a| {
621 a.iter()
622 .any(|v| v.as_str() == Some("KimiK25ForConditionalGeneration"))
623 });
624 if is_kimi {
625 tracing::debug!(?path, "selected KimiK25Tools chat-template renderer");
626 return Renderer::KimiK25Tools;
627 }
628 Renderer::Jinja
629}
630
/// Unit tests for the tiktoken tokenizer: bundled encodings, file-based
/// loading helpers, config parsing and the decode fallback paths.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::{Decoder, Encoder, Tokenizer};

    /// Two-entry tiktoken model: `a` -> 0 and `b` -> 1.
    const MINIMAL_TIKTOKEN_MODEL: &str = "YQ== 0\nYg== 1\n";

    /// Creates a temporary model directory holding the minimal tiktoken model,
    /// the given `tokenizer_config.json` and, optionally, a `config.json`.
    fn write_minimal_tiktoken_dir(
        tokenizer_config: &str,
        model_config: Option<&str>,
    ) -> tempfile::TempDir {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("tiktoken.model"), MINIMAL_TIKTOKEN_MODEL).unwrap();
        std::fs::write(dir.path().join("tokenizer_config.json"), tokenizer_config).unwrap();
        if let Some(model_config) = model_config {
            std::fs::write(dir.path().join("config.json"), model_config).unwrap();
        }
        dir
    }

    #[test]
    fn test_tiktoken_creation() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        assert_eq!(tokenizer.vocab_size(), 100256);
    }

    #[test]
    fn test_encode_decode() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();

        let text = "Hello, world!";
        let encoding = tokenizer.encode(text, false).unwrap();

        let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
        assert_eq!(decoded, text);
    }

    #[test]
    fn test_batch_encode() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();

        let texts = vec!["Hello", "World", "Test"];
        let encodings = tokenizer.encode_batch(&texts, false).unwrap();

        assert_eq!(encodings.len(), 3);
        for (i, encoding) in encodings.iter().enumerate() {
            let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
            assert_eq!(decoded, texts[i]);
        }
    }

    #[test]
    fn test_special_tokens() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        let special_tokens = tokenizer.get_special_tokens();

        assert!(special_tokens.eos_token.is_some());
        assert_eq!(special_tokens.eos_token.as_ref().unwrap(), "<|endoftext|>");
    }

    #[test]
    fn test_builtin_tokenizer_has_empty_vocab_maps() {
        // Bundled encodings are constructed with empty lookup maps.
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        assert_eq!(tokenizer.token_to_id("hello"), None);
        assert_eq!(tokenizer.id_to_token(0), None);
    }

    #[test]
    fn test_load_tiktoken_bpe() {
        use std::io::Write;
        let dir = tempfile::tempdir().unwrap();
        let file_path = dir.path().join("test.tiktoken");
        let mut f = std::fs::File::create(&file_path).unwrap();
        // "IQ==" is base64 for `!` (0x21), "Ig==" for `"` (0x22).
        writeln!(f, "IQ== 0").unwrap();
        writeln!(f, "Ig== 1").unwrap();

        let encoder = load_tiktoken_bpe(file_path.to_str().unwrap()).unwrap();
        assert_eq!(encoder.len(), 2);
        assert_eq!(encoder.get(&vec![0x21u8]), Some(&0));
        assert_eq!(encoder.get(&vec![0x22u8]), Some(&1));
    }

    #[test]
    fn test_build_vocab_maps() {
        let mut encoder = FxHashMap::default();
        encoder.insert(b"hello".to_vec(), 42u32);
        // Not valid UTF-8 on its own, so it must be left out of the maps.
        encoder.insert(vec![0xFF, 0xFE], 99u32);

        let mut added = HashMap::new();
        added.insert("<|special|>".to_string(), 1000u32);

        let (vocab, reverse_vocab) = build_vocab_maps(&encoder, &added);

        assert_eq!(vocab.get("hello"), Some(&42));
        assert_eq!(reverse_vocab.get(&42), Some(&"hello".to_string()));

        assert!(!vocab.contains_key("\u{FFFD}"));

        assert_eq!(vocab.get("<|special|>"), Some(&1000));
        assert_eq!(reverse_vocab.get(&1000), Some(&"<|special|>".to_string()));
    }

    #[test]
    fn test_has_tiktoken_file() {
        let dir = tempfile::tempdir().unwrap();
        assert!(!has_tiktoken_file(dir.path()));

        std::fs::write(dir.path().join("tiktoken.model"), "test").unwrap();
        assert!(has_tiktoken_file(dir.path()));
    }

    #[test]
    fn test_find_tiktoken_file_model() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("tiktoken.model"), "test").unwrap();
        let found = find_tiktoken_file(dir.path()).unwrap();
        assert_eq!(found.file_name().unwrap(), "tiktoken.model");
    }

    #[test]
    fn test_find_tiktoken_file_extension() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("vocab.tiktoken"), "test").unwrap();
        let found = find_tiktoken_file(dir.path()).unwrap();
        assert!(found
            .file_name()
            .unwrap()
            .to_str()
            .unwrap()
            .ends_with(".tiktoken"));
    }

    #[test]
    fn test_is_tiktoken_file() {
        assert!(is_tiktoken_file(Path::new("tiktoken.model")));
        assert!(is_tiktoken_file(Path::new("vocab.tiktoken")));
        assert!(!is_tiktoken_file(Path::new("tokenizer.json")));
        assert!(!is_tiktoken_file(Path::new("model.bin")));
    }

    #[test]
    fn test_parse_added_tokens_decoder() {
        let config: serde_json::Value = serde_json::json!({
            "added_tokens_decoder": {
                "163584": { "content": "[BOS]", "special": true },
                "163585": { "content": "[EOS]", "special": true },
                "163586": { "content": "<|im_end|>", "special": true }
            }
        });
        let tokens = parse_added_tokens_decoder(&config);
        assert_eq!(tokens.get("[BOS]"), Some(&163584));
        assert_eq!(tokens.get("[EOS]"), Some(&163585));
        assert_eq!(tokens.get("<|im_end|>"), Some(&163586));
    }

    #[test]
    fn test_tiktoken_unknown_token_decode_returns_error() {
        let dir = write_minimal_tiktoken_dir(
            r#"{
                "added_tokens_decoder": {
                    "2": { "content": "[BOS]", "special": true }
                }
            }"#,
            None,
        );
        let tokenizer = TiktokenTokenizer::from_dir(dir.path()).unwrap();

        // Ids 0..=2 exist in this vocabulary; 4 does not.
        let err = tokenizer.decode(&[4], false).unwrap_err();
        assert!(
            err.to_string()
                .contains("tiktoken decode failed for unknown token id"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_parse_special_tokens() {
        let config: serde_json::Value = serde_json::json!({
            "bos_token": "[BOS]",
            "eos_token": "[EOS]",
            "unk_token": "[UNK]",
            "pad_token": "[PAD]",
            "additional_special_tokens": ["<|im_end|>", "<|im_user|>"]
        });
        let special = parse_special_tokens(&config);
        assert_eq!(special.bos_token.as_deref(), Some("[BOS]"));
        assert_eq!(special.eos_token.as_deref(), Some("[EOS]"));
        assert_eq!(special.unk_token.as_deref(), Some("[UNK]"));
        assert_eq!(special.pad_token.as_deref(), Some("[PAD]"));
        assert_eq!(special.additional_special_tokens.len(), 2);
    }

    #[test]
    fn test_parse_special_tokens_object_valued() {
        let config: serde_json::Value = serde_json::json!({
            "bos_token": {"content": "<s>", "lstrip": false, "rstrip": false, "single_word": false, "special": true},
            "eos_token": "</s>",
            "unk_token": {"content": "<unk>", "special": true}
        });
        let special = parse_special_tokens(&config);
        assert_eq!(special.bos_token.as_deref(), Some("<s>"));
        assert_eq!(special.eos_token.as_deref(), Some("</s>"));
        assert_eq!(special.unk_token.as_deref(), Some("<unk>"));
    }

    #[test]
    fn test_tiktoken_config_default() {
        let config = TiktokenConfig::default();
        assert!(config.special_tokens.bos_token.is_none());
        assert!(config.added_tokens.is_empty());
        assert!(config.chat_template.is_none());
    }

    #[test]
    fn test_load_tiktoken_config_from_dir_missing_file() {
        let dir = tempfile::tempdir().unwrap();
        let (config, value) = load_tiktoken_config_from_dir(dir.path()).unwrap();
        assert!(value.is_none());
        assert!(config.added_tokens.is_empty());
    }

    #[test]
    fn test_decode_lossy_fallback_for_invalid_utf8() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();

        // U+1F389 (party popper), a four-byte UTF-8 character. Written as an
        // escape so the literal cannot be damaged by a re-encoding of this file.
        let full_encoding = tokenizer.encode("\u{1F389}", false).unwrap();
        let full_ids = full_encoding.token_ids();
        assert!(
            full_ids.len() > 1,
            "emoji should encode to multiple tokens in cl100k_base"
        );

        // The first token alone covers only part of the character's bytes.
        let partial_ids = &full_ids[..1];
        let result = tokenizer.decode(partial_ids, false);
        assert!(
            result.is_ok(),
            "decode of partial UTF-8 should succeed via lossy fallback"
        );
        let decoded = result.unwrap();
        assert!(
            decoded.contains('\u{FFFD}') || decoded.is_empty(),
            "lossy decode should contain replacement char or be empty, got: {decoded:?}"
        );
    }

    #[test]
    fn test_decode_valid_utf8_does_not_use_fallback() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        // "Hello, " followed by U+4E16 U+754C ("world" in Chinese) and "!":
        // multi-byte but complete UTF-8, written with escapes for robustness.
        let text = "Hello, \u{4E16}\u{754C}!";
        let encoding = tokenizer.encode(text, false).unwrap();
        let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
        assert_eq!(decoded, text);
    }

    #[test]
    fn test_encode_recognizes_special_tokens_in_input() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        let encoding = tokenizer.encode("hello<|endoftext|>world", false).unwrap();
        let ids = encoding.token_ids();
        assert!(
            ids.contains(&100257),
            "Special token <|endoftext|> should be recognized as single token, got: {ids:?}"
        );
    }
}