1use std::{
2 collections::HashMap,
3 path::{Path, PathBuf},
4};
5
6use anyhow::{Error, Result};
7use base64::{engine::general_purpose::STANDARD, Engine as _};
8use rustc_hash::FxHashMap;
9use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE};
10
11use crate::{
12 chat_template::{
13 load_chat_template_from_file, ChatTemplateContentFormat, ChatTemplateParams,
14 ChatTemplateState, ThinkingKeyName, ThinkingToggle,
15 },
16 factory::discover_chat_template_in_dir,
17 traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait},
18};
19
/// Pre-tokenization split regex handed to `CoreBPE::new` for every tiktoken file loaded
/// from disk (see `TiktokenTokenizer::load_from_path`).
///
/// This is the `cl100k_base` split pattern. NOTE(review): it is applied regardless of which
/// model the file belongs to; models trained with a different split regex will pre-tokenize
/// differently — confirm this is acceptable for every supported model.
const CL100K_BASE_PATTERN: &str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";

/// Integer rank (token id) stored in the BPE tables passed to `CoreBPE`.
type Rank = u32;
31
32#[derive(Default)]
38struct TiktokenConfig {
39 special_tokens: SpecialTokens,
40 added_tokens: HashMap<String, TokenIdType>,
42 chat_template: Option<String>,
43}
44
45fn load_tiktoken_config(config_path: &Path) -> Result<TiktokenConfig> {
47 let content = std::fs::read_to_string(config_path)?;
48 let config: serde_json::Value = serde_json::from_str(&content)?;
49
50 let added_tokens = parse_added_tokens_decoder(&config);
51 let special_tokens = parse_special_tokens(&config);
52
53 let chat_template = config
54 .get("chat_template")
55 .and_then(|v| v.as_str())
56 .map(String::from);
57
58 Ok(TiktokenConfig {
59 special_tokens,
60 added_tokens,
61 chat_template,
62 })
63}
64
65fn parse_added_tokens_decoder(config: &serde_json::Value) -> HashMap<String, TokenIdType> {
69 let mut tokens = HashMap::new();
70 if let Some(added) = config
71 .get("added_tokens_decoder")
72 .and_then(|v| v.as_object())
73 {
74 for (id_str, token_info) in added {
75 if let (Ok(id), Some(content)) = (
76 id_str.parse::<TokenIdType>(),
77 token_info.get("content").and_then(|v| v.as_str()),
78 ) {
79 tokens.insert(content.to_string(), id);
80 }
81 }
82 }
83 tokens
84}
85
86fn parse_special_tokens(config: &serde_json::Value) -> SpecialTokens {
91 let get_str = |key: &str| {
92 config.get(key).and_then(|v| {
93 v.as_str()
94 .map(String::from)
95 .or_else(|| v.get("content").and_then(|c| c.as_str()).map(String::from))
96 })
97 };
98
99 let additional: Vec<String> = config
100 .get("additional_special_tokens")
101 .and_then(|v| v.as_array())
102 .map(|arr| {
103 arr.iter()
104 .filter_map(|v| {
105 v.as_str()
106 .map(String::from)
107 .or_else(|| v.get("content").and_then(|c| c.as_str()).map(String::from))
108 })
109 .collect()
110 })
111 .unwrap_or_default();
112
113 SpecialTokens {
114 bos_token: get_str("bos_token"),
115 eos_token: get_str("eos_token"),
116 unk_token: get_str("unk_token"),
117 sep_token: get_str("sep_token"),
118 pad_token: get_str("pad_token"),
119 cls_token: get_str("cls_token"),
120 mask_token: get_str("mask_token"),
121 additional_special_tokens: additional,
122 }
123}
124
125pub struct TiktokenTokenizer {
127 tokenizer: CoreBPE,
128 special_tokens: SpecialTokens,
129 vocab: HashMap<String, TokenIdType>,
130 reverse_vocab: HashMap<TokenIdType, String>,
131 vocab_size: usize,
132 chat_template: ChatTemplateState,
133}
134
/// Built-in OpenAI encodings that can be constructed without any files on disk.
///
/// Besides `Debug`/`Clone`/`Copy`, the enum is comparable and hashable so callers can match
/// on it with `==` or use it as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TiktokenModel {
    /// `cl100k_base` (GPT-3.5 / GPT-4 family names in `model_from_name`).
    Cl100kBase,
    /// `p50k_base` (`davinci-002`/`davinci-003`/`codex` names in `model_from_name`).
    P50kBase,
    /// `p50k_edit` (names containing `edit` in `model_from_name`).
    P50kEdit,
    /// `r50k_base` (remaining `davinci`/`curie`/`babbage`/`ada` names in `model_from_name`).
    R50kBase,
}
147
148impl TiktokenTokenizer {
149 pub fn new(model: TiktokenModel) -> Result<Self> {
151 let tokenizer = match model {
152 TiktokenModel::Cl100kBase => {
153 cl100k_base().map_err(|e| Error::msg(format!("Failed to load cl100k_base: {e}")))?
154 }
155 TiktokenModel::P50kBase => {
156 p50k_base().map_err(|e| Error::msg(format!("Failed to load p50k_base: {e}")))?
157 }
158 TiktokenModel::P50kEdit => {
159 p50k_edit().map_err(|e| Error::msg(format!("Failed to load p50k_edit: {e}")))?
160 }
161 TiktokenModel::R50kBase => {
162 r50k_base().map_err(|e| Error::msg(format!("Failed to load r50k_base: {e}")))?
163 }
164 };
165
166 let special_tokens = Self::get_special_tokens_for_model(model);
167
168 let vocab_size = match model {
169 TiktokenModel::Cl100kBase => 100256,
170 TiktokenModel::P50kBase | TiktokenModel::P50kEdit => 50281,
171 TiktokenModel::R50kBase => 50257,
172 };
173
174 Ok(TiktokenTokenizer {
175 tokenizer,
176 special_tokens,
177 vocab: HashMap::new(),
178 reverse_vocab: HashMap::new(),
179 vocab_size,
180 chat_template: ChatTemplateState::empty(),
181 })
182 }
183
184 pub fn from_dir(dir: &Path) -> Result<Self> {
186 Self::from_dir_with_chat_template(dir, None)
187 }
188
189 pub fn from_dir_with_chat_template(
192 dir: &Path,
193 chat_template_path: Option<&str>,
194 ) -> Result<Self> {
195 let tiktoken_path = find_tiktoken_file(dir)?;
196 Self::load_from_path(&tiktoken_path, chat_template_path)
197 }
198
199 pub fn from_file(tiktoken_path: &Path) -> Result<Self> {
202 Self::from_file_with_chat_template(tiktoken_path, None)
203 }
204
205 pub fn from_file_with_chat_template(
207 tiktoken_path: &Path,
208 chat_template_path: Option<&str>,
209 ) -> Result<Self> {
210 Self::load_from_path(tiktoken_path, chat_template_path)
211 }
212
213 fn load_from_path(tiktoken_path: &Path, chat_template_path: Option<&str>) -> Result<Self> {
215 let tiktoken_path_str = tiktoken_path
217 .to_str()
218 .ok_or_else(|| Error::msg("Tiktoken file path is not valid UTF-8"))?;
219 let encoder = load_tiktoken_bpe(tiktoken_path_str)?;
220
221 let dir = tiktoken_path
223 .parent()
224 .ok_or_else(|| Error::msg("Cannot determine parent directory of tiktoken file"))?;
225 let config_path = dir.join("tokenizer_config.json");
226 let config = if config_path.exists() {
227 load_tiktoken_config(&config_path)?
228 } else {
229 TiktokenConfig::default()
230 };
231
232 let special_tokens_encoder: FxHashMap<String, Rank> = config
234 .added_tokens
235 .iter()
236 .map(|(k, &v)| (k.clone(), v))
237 .collect();
238
239 let vocab_size = encoder
242 .values()
243 .copied()
244 .chain(special_tokens_encoder.values().copied())
245 .max()
246 .map(|id| id as usize + 1)
247 .unwrap_or(0);
248 let (vocab, reverse_vocab) = build_vocab_maps(&encoder, &config.added_tokens);
249 let tokenizer = CoreBPE::new(encoder, special_tokens_encoder, CL100K_BASE_PATTERN)?;
250
251 let chat_template = if let Some(p) = chat_template_path {
254 load_chat_template_from_file(p)?
255 } else {
256 config.chat_template.or_else(|| {
257 discover_chat_template_in_dir(dir)
258 .and_then(|p| load_chat_template_from_file(&p).ok().flatten())
259 })
260 };
261
262 Ok(TiktokenTokenizer {
263 tokenizer,
264 special_tokens: config.special_tokens,
265 vocab,
266 reverse_vocab,
267 vocab_size,
268 chat_template: ChatTemplateState::new(chat_template)?,
269 })
270 }
271
272 pub fn from_model_name(model_name: &str) -> Result<Self> {
274 let model = Self::model_from_name(model_name)?;
275 Self::new(model)
276 }
277
278 fn model_from_name(model_name: &str) -> Result<TiktokenModel> {
280 if model_name.contains("gpt-4")
281 || model_name.contains("gpt-3.5")
282 || model_name.contains("turbo")
283 {
284 Ok(TiktokenModel::Cl100kBase)
285 } else if model_name.contains("davinci-002")
286 || model_name.contains("davinci-003")
287 || model_name.contains("codex")
288 {
289 Ok(TiktokenModel::P50kBase)
290 } else if model_name.contains("edit") {
291 Ok(TiktokenModel::P50kEdit)
292 } else if model_name.contains("davinci")
293 || model_name.contains("curie")
294 || model_name.contains("babbage")
295 || model_name.contains("ada")
296 {
297 Ok(TiktokenModel::R50kBase)
298 } else {
299 Err(anyhow::anyhow!(
300 "Unrecognized OpenAI model name: '{model_name}'. Expected GPT-3, GPT-3.5, GPT-4, or related model names"
301 ))
302 }
303 }
304
305 fn get_special_tokens_for_model(model: TiktokenModel) -> SpecialTokens {
307 match model {
308 TiktokenModel::Cl100kBase => SpecialTokens {
309 bos_token: Some("<|endoftext|>".to_string()),
310 eos_token: Some("<|endoftext|>".to_string()),
311 unk_token: None,
312 sep_token: None,
313 pad_token: Some("<|endoftext|>".to_string()),
314 cls_token: None,
315 mask_token: None,
316 additional_special_tokens: vec![
317 "<|fim_prefix|>".to_string(),
318 "<|fim_middle|>".to_string(),
319 "<|fim_suffix|>".to_string(),
320 "<|endofprompt|>".to_string(),
321 ],
322 },
323 _ => SpecialTokens {
324 bos_token: Some("<|endoftext|>".to_string()),
325 eos_token: Some("<|endoftext|>".to_string()),
326 unk_token: None,
327 sep_token: None,
328 pad_token: Some("<|endoftext|>".to_string()),
329 cls_token: None,
330 mask_token: None,
331 additional_special_tokens: vec![],
332 },
333 }
334 }
335}
336
337fn load_tiktoken_bpe(path: &str) -> Result<FxHashMap<Vec<u8>, Rank>> {
341 let content = std::fs::read_to_string(path)?;
342 let mut encoder =
343 FxHashMap::with_capacity_and_hasher(content.lines().count(), Default::default());
344 for line in content.lines() {
345 if line.is_empty() {
346 continue;
347 }
348 let mut parts = line.split_whitespace();
349 let token_b64 = parts
350 .next()
351 .ok_or_else(|| Error::msg("missing token in tiktoken file"))?;
352 let rank_str = parts
353 .next()
354 .ok_or_else(|| Error::msg("missing rank in tiktoken file"))?;
355 let token_bytes = STANDARD.decode(token_b64)?;
356 let rank: Rank = rank_str.parse()?;
357 encoder.insert(token_bytes, rank);
358 }
359 Ok(encoder)
360}
361
362fn build_vocab_maps(
364 encoder: &FxHashMap<Vec<u8>, Rank>,
365 added_tokens: &HashMap<String, TokenIdType>,
366) -> (HashMap<String, TokenIdType>, HashMap<TokenIdType, String>) {
367 let capacity = encoder.len() + added_tokens.len();
368 let mut vocab = HashMap::with_capacity(capacity);
369 let mut reverse_vocab = HashMap::with_capacity(capacity);
370
371 for (token_bytes, &rank) in encoder {
373 if let Ok(token_str) = std::str::from_utf8(token_bytes) {
374 vocab.insert(token_str.to_string(), rank);
375 reverse_vocab.insert(rank, token_str.to_string());
376 }
377 }
378
379 for (token_str, &id) in added_tokens {
381 vocab.insert(token_str.clone(), id);
382 reverse_vocab.insert(id, token_str.clone());
383 }
384
385 (vocab, reverse_vocab)
386}
387
388fn find_tiktoken_file(dir: &Path) -> Result<PathBuf> {
392 let tiktoken_model = dir.join("tiktoken.model");
393 if tiktoken_model.exists() {
394 return Ok(tiktoken_model);
395 }
396
397 if let Ok(entries) = std::fs::read_dir(dir) {
399 for entry in entries.flatten() {
400 if let Some(name) = entry.file_name().to_str() {
401 if name.ends_with(".tiktoken") {
402 return Ok(entry.path());
403 }
404 }
405 }
406 }
407
408 Err(Error::msg(format!(
409 "No tiktoken model file found in '{}'",
410 dir.display()
411 )))
412}
413
/// Reports whether `dir` contains a loadable tiktoken file.
///
/// Returns true when `dir` holds a regular file named `tiktoken.model` or ending in
/// `.tiktoken`; these are the files the loader can read. Directories with a matching name
/// do not count. An unreadable `dir` yields `false`.
pub fn has_tiktoken_file(dir: &Path) -> bool {
    if dir.join("tiktoken.model").is_file() {
        return true;
    }
    std::fs::read_dir(dir)
        .map(|entries| {
            entries.flatten().any(|entry| {
                entry
                    .file_name()
                    .to_str()
                    .is_some_and(|name| name.ends_with(".tiktoken"))
                    && entry.path().is_file()
            })
        })
        .unwrap_or(false)
}
430
/// Returns true when `path` names a tiktoken file by its file name alone.
///
/// Matching names are `tiktoken.model` or anything ending in `.tiktoken`. The filesystem is
/// not consulted, and non-UTF-8 names never match.
pub fn is_tiktoken_file(path: &Path) -> bool {
    let Some(file_name) = path.file_name().and_then(|n| n.to_str()) else {
        return false;
    };
    match file_name {
        "tiktoken.model" => true,
        other => other.ends_with(".tiktoken"),
    }
}
437
438impl Encoder for TiktokenTokenizer {
439 fn encode(&self, input: &str, _add_special_tokens: bool) -> Result<Encoding> {
440 let tokens = self.tokenizer.encode_ordinary(input);
441 Ok(Encoding::Tiktoken(tokens))
442 }
443
444 fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
445 inputs
446 .iter()
447 .map(|input| self.encode(input, add_special_tokens))
448 .collect()
449 }
450}
451
452impl Decoder for TiktokenTokenizer {
453 fn decode(&self, token_ids: &[TokenIdType], _skip_special_tokens: bool) -> Result<String> {
454 match self.tokenizer.decode(token_ids.to_vec()) {
455 Ok(text) => Ok(text),
456 Err(err) => {
457 let bytes: Vec<u8> = self
459 .tokenizer
460 ._decode_native_and_split(token_ids.to_vec())
461 .flatten()
462 .collect();
463 tracing::warn!(
464 error = %err,
465 token_count = token_ids.len(),
466 "tiktoken decode failed; returning lossy UTF-8 fallback"
467 );
468 Ok(String::from_utf8_lossy(&bytes).into_owned())
469 }
470 }
471 }
472}
473
474impl TokenizerTrait for TiktokenTokenizer {
475 fn vocab_size(&self) -> usize {
476 self.vocab_size
477 }
478
479 fn get_special_tokens(&self) -> &SpecialTokens {
480 &self.special_tokens
481 }
482
483 fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
484 self.vocab.get(token).copied()
485 }
486
487 fn id_to_token(&self, id: TokenIdType) -> Option<String> {
488 self.reverse_vocab.get(&id).cloned()
489 }
490
491 fn as_any(&self) -> &dyn std::any::Any {
492 self
493 }
494
495 fn apply_chat_template(
496 &self,
497 messages: &[serde_json::Value],
498 params: ChatTemplateParams,
499 ) -> Result<String> {
500 if params.special_tokens.is_some() {
502 return self.chat_template.apply(messages, params);
503 }
504 let params = ChatTemplateParams {
505 special_tokens: Some(&self.special_tokens),
506 ..params
507 };
508 self.chat_template.apply(messages, params)
509 }
510
511 fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
512 self.chat_template.content_format()
513 }
514
515 fn thinking_toggle(&self) -> ThinkingToggle {
516 self.chat_template.thinking_toggle()
517 }
518
519 fn thinking_key_name(&self) -> Option<ThinkingKeyName> {
520 self.chat_template.thinking_key_name()
521 }
522 fn think_in_prefill(&self) -> bool {
523 self.chat_template.think_in_prefill()
524 }
525
526 fn set_chat_template(&mut self, template: String) -> Result<()> {
527 self.chat_template.set(template)
528 }
529}
530
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::{Decoder, Encoder, Tokenizer};

    #[test]
    fn test_tiktoken_creation() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        assert_eq!(tokenizer.vocab_size(), 100256);
    }

    #[test]
    fn test_model_from_name() {
        assert!(matches!(
            TiktokenTokenizer::model_from_name("gpt-4").unwrap(),
            TiktokenModel::Cl100kBase
        ));
        assert!(matches!(
            TiktokenTokenizer::model_from_name("gpt-3.5-turbo").unwrap(),
            TiktokenModel::Cl100kBase
        ));
        assert!(matches!(
            TiktokenTokenizer::model_from_name("text-davinci-003").unwrap(),
            TiktokenModel::P50kBase
        ));
        assert!(matches!(
            TiktokenTokenizer::model_from_name("text-davinci-edit-001").unwrap(),
            TiktokenModel::P50kEdit
        ));
        assert!(matches!(
            TiktokenTokenizer::model_from_name("davinci").unwrap(),
            TiktokenModel::R50kBase
        ));
    }

    #[test]
    fn test_encode_decode() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();

        let text = "Hello, world!";
        let encoding = tokenizer.encode(text, false).unwrap();

        let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
        assert_eq!(decoded, text);
    }

    #[test]
    fn test_batch_encode() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();

        let texts = vec!["Hello", "World", "Test"];
        let encodings = tokenizer.encode_batch(&texts, false).unwrap();

        assert_eq!(encodings.len(), 3);
        for (i, encoding) in encodings.iter().enumerate() {
            let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
            assert_eq!(decoded, texts[i]);
        }
    }

    #[test]
    fn test_special_tokens() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        let special_tokens = tokenizer.get_special_tokens();

        assert!(special_tokens.eos_token.is_some());
        assert_eq!(special_tokens.eos_token.as_ref().unwrap(), "<|endoftext|>");
    }

    #[test]
    fn test_unrecognized_model_name_returns_error() {
        for name in ["distilgpt-2", "bert-base-uncased", "llama-7b"] {
            let result = TiktokenTokenizer::from_model_name(name);
            assert!(result.is_err(), "{name} should not be recognized");
            if let Err(e) = result {
                assert!(e.to_string().contains("Unrecognized OpenAI model name"));
            }
        }
    }

    #[test]
    fn test_recognized_model_names() {
        assert!(TiktokenTokenizer::from_model_name("gpt-4").is_ok());
        assert!(TiktokenTokenizer::from_model_name("gpt-3.5-turbo").is_ok());
        assert!(TiktokenTokenizer::from_model_name("text-davinci-003").is_ok());
        assert!(TiktokenTokenizer::from_model_name("code-davinci-002").is_ok());
        assert!(TiktokenTokenizer::from_model_name("text-curie-001").is_ok());
        assert!(TiktokenTokenizer::from_model_name("text-babbage-001").is_ok());
        assert!(TiktokenTokenizer::from_model_name("text-ada-001").is_ok());
    }

    #[test]
    fn test_builtin_tokenizer_has_empty_vocab_maps() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        // Built-in encodings do not populate the string vocabularies.
        assert_eq!(tokenizer.token_to_id("hello"), None);
        assert_eq!(tokenizer.id_to_token(0), None);
    }

    #[test]
    fn test_load_tiktoken_bpe() {
        use std::io::Write;
        let dir = tempfile::tempdir().unwrap();
        let file_path = dir.path().join("test.tiktoken");
        let mut f = std::fs::File::create(&file_path).unwrap();
        // "IQ==" is base64 for b"!" and "Ig==" for b"\"".
        writeln!(f, "IQ== 0").unwrap();
        writeln!(f, "Ig== 1").unwrap();

        let encoder = load_tiktoken_bpe(file_path.to_str().unwrap()).unwrap();
        assert_eq!(encoder.len(), 2);
        assert_eq!(encoder.get(&vec![0x21u8]), Some(&0));
        assert_eq!(encoder.get(&vec![0x22u8]), Some(&1));
    }

    #[test]
    fn test_build_vocab_maps() {
        let mut encoder = FxHashMap::default();
        encoder.insert(b"hello".to_vec(), 42u32);
        // Not valid UTF-8 on its own, so it must be left out of the string maps.
        encoder.insert(vec![0xFF, 0xFE], 99u32);
        let mut added = HashMap::new();
        added.insert("<|special|>".to_string(), 1000u32);

        let (vocab, reverse_vocab) = build_vocab_maps(&encoder, &added);

        assert_eq!(vocab.get("hello"), Some(&42));
        assert_eq!(reverse_vocab.get(&42), Some(&"hello".to_string()));

        assert!(!vocab.contains_key("\u{FFFD}"));
        assert!(!reverse_vocab.contains_key(&99));

        assert_eq!(vocab.get("<|special|>"), Some(&1000));
        assert_eq!(reverse_vocab.get(&1000), Some(&"<|special|>".to_string()));
    }

    #[test]
    fn test_has_tiktoken_file() {
        let dir = tempfile::tempdir().unwrap();
        assert!(!has_tiktoken_file(dir.path()));

        std::fs::write(dir.path().join("tiktoken.model"), "test").unwrap();
        assert!(has_tiktoken_file(dir.path()));
    }

    #[test]
    fn test_find_tiktoken_file_model() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("tiktoken.model"), "test").unwrap();
        let found = find_tiktoken_file(dir.path()).unwrap();
        assert_eq!(found.file_name().unwrap(), "tiktoken.model");
    }

    #[test]
    fn test_find_tiktoken_file_extension() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("vocab.tiktoken"), "test").unwrap();
        let found = find_tiktoken_file(dir.path()).unwrap();
        assert!(found
            .file_name()
            .unwrap()
            .to_str()
            .unwrap()
            .ends_with(".tiktoken"));
    }

    #[test]
    fn test_is_tiktoken_file() {
        assert!(is_tiktoken_file(Path::new("tiktoken.model")));
        assert!(is_tiktoken_file(Path::new("vocab.tiktoken")));
        assert!(!is_tiktoken_file(Path::new("tokenizer.json")));
        assert!(!is_tiktoken_file(Path::new("model.bin")));
    }

    #[test]
    fn test_parse_added_tokens_decoder() {
        let config: serde_json::Value = serde_json::json!({
            "added_tokens_decoder": {
                "163584": { "content": "[BOS]", "special": true },
                "163585": { "content": "[EOS]", "special": true },
                "163586": { "content": "<|im_end|>", "special": true }
            }
        });
        let tokens = parse_added_tokens_decoder(&config);
        assert_eq!(tokens.get("[BOS]"), Some(&163584));
        assert_eq!(tokens.get("[EOS]"), Some(&163585));
        assert_eq!(tokens.get("<|im_end|>"), Some(&163586));
    }

    #[test]
    fn test_parse_special_tokens() {
        let config: serde_json::Value = serde_json::json!({
            "bos_token": "[BOS]",
            "eos_token": "[EOS]",
            "unk_token": "[UNK]",
            "pad_token": "[PAD]",
            "additional_special_tokens": ["<|im_end|>", "<|im_user|>"]
        });
        let special = parse_special_tokens(&config);
        assert_eq!(special.bos_token.as_deref(), Some("[BOS]"));
        assert_eq!(special.eos_token.as_deref(), Some("[EOS]"));
        assert_eq!(special.unk_token.as_deref(), Some("[UNK]"));
        assert_eq!(special.pad_token.as_deref(), Some("[PAD]"));
        assert_eq!(special.additional_special_tokens.len(), 2);
    }

    #[test]
    fn test_parse_special_tokens_object_valued() {
        let config: serde_json::Value = serde_json::json!({
            "bos_token": {"content": "<s>", "lstrip": false, "rstrip": false, "single_word": false, "special": true},
            "eos_token": "</s>",
            "unk_token": {"content": "<unk>", "special": true}
        });
        let special = parse_special_tokens(&config);
        assert_eq!(special.bos_token.as_deref(), Some("<s>"));
        assert_eq!(special.eos_token.as_deref(), Some("</s>"));
        assert_eq!(special.unk_token.as_deref(), Some("<unk>"));
    }

    #[test]
    fn test_tiktoken_config_default() {
        let config = TiktokenConfig::default();
        assert!(config.special_tokens.bos_token.is_none());
        assert!(config.added_tokens.is_empty());
        assert!(config.chat_template.is_none());
    }

    #[test]
    fn test_decode_lossy_fallback_for_invalid_utf8() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();

        // Decoding only the first token of a single character that cl100k_base splits
        // across several tokens yields an incomplete UTF-8 sequence. Several rare
        // characters (written as escapes so source encoding cannot mangle them) are tried
        // so the test does not hinge on any single one being split.
        let candidates = ["\u{1F980}", "\u{1F9A0}", "\u{10348}", "\u{1D11E}"];
        let full_ids = candidates
            .iter()
            .map(|text| tokenizer.encode(text, false).unwrap().token_ids().to_vec())
            .find(|ids| ids.len() > 1)
            .expect("at least one candidate character should encode to multiple tokens in cl100k_base");

        let partial_ids = &full_ids[..1];
        let result = tokenizer.decode(partial_ids, false);
        assert!(
            result.is_ok(),
            "decode of partial UTF-8 should succeed via lossy fallback"
        );
        let decoded = result.unwrap();
        assert!(
            decoded.contains('\u{FFFD}') || decoded.is_empty(),
            "lossy decode should contain replacement char or be empty, got: {decoded:?}"
        );
    }

    #[test]
    fn test_decode_valid_utf8_does_not_use_fallback() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        // "Hello, 世界!" written with escapes to stay encoding-safe.
        let text = "Hello, \u{4e16}\u{754c}!";
        let encoding = tokenizer.encode(text, false).unwrap();
        let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
        assert_eq!(decoded, text);
    }
}