1use std::{
2 collections::HashMap,
3 path::{Path, PathBuf},
4};
5
6use anyhow::{Error, Result};
7use base64::{engine::general_purpose::STANDARD, Engine as _};
8use rustc_hash::FxHashMap;
9use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE};
10
11use crate::{
12 chat_template::{
13 load_chat_template_from_file, ChatTemplateContentFormat, ChatTemplateParams,
14 ChatTemplateState,
15 },
16 factory::discover_chat_template_in_dir,
17 traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait},
18};
19
/// Pre-tokenization split regex handed to `CoreBPE::new` for every tokenizer that is
/// loaded from a tiktoken file on disk (see `TiktokenTokenizer::load_from_path`).
///
/// NOTE(review): the name suggests this mirrors the upstream cl100k_base split
/// pattern; confirm against the tiktoken reference before relying on that.
const CL100K_BASE_PATTERN: &str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";

/// Integer type of a BPE rank: the value stored for each byte sequence in the
/// encoder map and for each added token passed to `CoreBPE::new`.
type Rank = u32;
31
/// Settings extracted from a `tokenizer_config.json` that sits next to a tiktoken file.
///
/// The `Default` value (no special tokens, no added tokens, no chat template) is
/// used when that config file does not exist.
#[derive(Default)]
struct TiktokenConfig {
    /// Named special tokens (bos/eos/unk/...) read from the top-level config keys.
    special_tokens: SpecialTokens,
    /// Token text -> token id, built from the `added_tokens_decoder` object.
    added_tokens: HashMap<String, TokenIdType>,
    /// The `chat_template` value, kept only when it is a JSON string.
    chat_template: Option<String>,
}
44
45fn load_tiktoken_config(config_path: &Path) -> Result<TiktokenConfig> {
47 let content = std::fs::read_to_string(config_path)?;
48 let config: serde_json::Value = serde_json::from_str(&content)?;
49
50 let added_tokens = parse_added_tokens_decoder(&config);
51 let special_tokens = parse_special_tokens(&config);
52
53 let chat_template = config
54 .get("chat_template")
55 .and_then(|v| v.as_str())
56 .map(String::from);
57
58 Ok(TiktokenConfig {
59 special_tokens,
60 added_tokens,
61 chat_template,
62 })
63}
64
65fn parse_added_tokens_decoder(config: &serde_json::Value) -> HashMap<String, TokenIdType> {
69 let mut tokens = HashMap::new();
70 if let Some(added) = config
71 .get("added_tokens_decoder")
72 .and_then(|v| v.as_object())
73 {
74 for (id_str, token_info) in added {
75 if let (Ok(id), Some(content)) = (
76 id_str.parse::<TokenIdType>(),
77 token_info.get("content").and_then(|v| v.as_str()),
78 ) {
79 tokens.insert(content.to_string(), id);
80 }
81 }
82 }
83 tokens
84}
85
86fn parse_special_tokens(config: &serde_json::Value) -> SpecialTokens {
91 let get_str = |key: &str| {
92 config.get(key).and_then(|v| {
93 v.as_str()
94 .map(String::from)
95 .or_else(|| v.get("content").and_then(|c| c.as_str()).map(String::from))
96 })
97 };
98
99 let additional: Vec<String> = config
100 .get("additional_special_tokens")
101 .and_then(|v| v.as_array())
102 .map(|arr| {
103 arr.iter()
104 .filter_map(|v| {
105 v.as_str()
106 .map(String::from)
107 .or_else(|| v.get("content").and_then(|c| c.as_str()).map(String::from))
108 })
109 .collect()
110 })
111 .unwrap_or_default();
112
113 SpecialTokens {
114 bos_token: get_str("bos_token"),
115 eos_token: get_str("eos_token"),
116 unk_token: get_str("unk_token"),
117 sep_token: get_str("sep_token"),
118 pad_token: get_str("pad_token"),
119 cls_token: get_str("cls_token"),
120 mask_token: get_str("mask_token"),
121 additional_special_tokens: additional,
122 }
123}
124
/// Tokenizer backed by a `tiktoken_rs` [`CoreBPE`], created either from a built-in
/// encoding or from a tiktoken file on disk.
pub struct TiktokenTokenizer {
    /// The BPE engine that performs encoding and decoding.
    tokenizer: CoreBPE,
    /// Special tokens reported by `get_special_tokens` and supplied to chat templates
    /// when the caller does not provide its own.
    special_tokens: SpecialTokens,
    /// Token text -> id lookup; left empty by `new`, populated when loading from a file.
    vocab: HashMap<String, TokenIdType>,
    /// Id -> token text lookup; left empty by `new`, populated when loading from a file.
    reverse_vocab: HashMap<TokenIdType, String>,
    /// A per-model constant for built-in encodings, or highest token id + 1 for
    /// file-loaded tokenizers.
    vocab_size: usize,
    /// Chat template state used by `apply_chat_template`.
    chat_template: ChatTemplateState,
}
134
/// Built-in encodings that [`TiktokenTokenizer::new`] can load.
#[derive(Debug, Clone, Copy)]
pub enum TiktokenModel {
    /// Loaded with `cl100k_base()`; selected for names containing "gpt-4", "gpt-3.5" or "turbo".
    Cl100kBase,
    /// Loaded with `p50k_base()`; selected for names containing "davinci-002",
    /// "davinci-003" or "codex".
    P50kBase,
    /// Loaded with `p50k_edit()`; selected for names containing "edit".
    P50kEdit,
    /// Loaded with `r50k_base()`; selected for remaining names containing "davinci",
    /// "curie", "babbage" or "ada".
    R50kBase,
}
147
148impl TiktokenTokenizer {
149 pub fn new(model: TiktokenModel) -> Result<Self> {
151 let tokenizer = match model {
152 TiktokenModel::Cl100kBase => {
153 cl100k_base().map_err(|e| Error::msg(format!("Failed to load cl100k_base: {e}")))?
154 }
155 TiktokenModel::P50kBase => {
156 p50k_base().map_err(|e| Error::msg(format!("Failed to load p50k_base: {e}")))?
157 }
158 TiktokenModel::P50kEdit => {
159 p50k_edit().map_err(|e| Error::msg(format!("Failed to load p50k_edit: {e}")))?
160 }
161 TiktokenModel::R50kBase => {
162 r50k_base().map_err(|e| Error::msg(format!("Failed to load r50k_base: {e}")))?
163 }
164 };
165
166 let special_tokens = Self::get_special_tokens_for_model(model);
167
168 let vocab_size = match model {
169 TiktokenModel::Cl100kBase => 100256,
170 TiktokenModel::P50kBase | TiktokenModel::P50kEdit => 50281,
171 TiktokenModel::R50kBase => 50257,
172 };
173
174 Ok(TiktokenTokenizer {
175 tokenizer,
176 special_tokens,
177 vocab: HashMap::new(),
178 reverse_vocab: HashMap::new(),
179 vocab_size,
180 chat_template: ChatTemplateState::empty(),
181 })
182 }
183
184 pub fn from_dir(dir: &Path) -> Result<Self> {
186 Self::from_dir_with_chat_template(dir, None)
187 }
188
189 pub fn from_dir_with_chat_template(
192 dir: &Path,
193 chat_template_path: Option<&str>,
194 ) -> Result<Self> {
195 let tiktoken_path = find_tiktoken_file(dir)?;
196 Self::load_from_path(&tiktoken_path, chat_template_path)
197 }
198
199 pub fn from_file(tiktoken_path: &Path) -> Result<Self> {
202 Self::from_file_with_chat_template(tiktoken_path, None)
203 }
204
205 pub fn from_file_with_chat_template(
207 tiktoken_path: &Path,
208 chat_template_path: Option<&str>,
209 ) -> Result<Self> {
210 Self::load_from_path(tiktoken_path, chat_template_path)
211 }
212
213 fn load_from_path(tiktoken_path: &Path, chat_template_path: Option<&str>) -> Result<Self> {
215 let tiktoken_path_str = tiktoken_path
217 .to_str()
218 .ok_or_else(|| Error::msg("Tiktoken file path is not valid UTF-8"))?;
219 let encoder = load_tiktoken_bpe(tiktoken_path_str)?;
220
221 let dir = tiktoken_path
223 .parent()
224 .ok_or_else(|| Error::msg("Cannot determine parent directory of tiktoken file"))?;
225 let config_path = dir.join("tokenizer_config.json");
226 let config = if config_path.exists() {
227 load_tiktoken_config(&config_path)?
228 } else {
229 TiktokenConfig::default()
230 };
231
232 let special_tokens_encoder: FxHashMap<String, Rank> = config
234 .added_tokens
235 .iter()
236 .map(|(k, &v)| (k.clone(), v))
237 .collect();
238
239 let vocab_size = encoder
242 .values()
243 .copied()
244 .chain(special_tokens_encoder.values().copied())
245 .max()
246 .map(|id| id as usize + 1)
247 .unwrap_or(0);
248 let (vocab, reverse_vocab) = build_vocab_maps(&encoder, &config.added_tokens);
249 let tokenizer = CoreBPE::new(encoder, special_tokens_encoder, CL100K_BASE_PATTERN)?;
250
251 let chat_template = if let Some(p) = chat_template_path {
254 load_chat_template_from_file(p)?
255 } else {
256 config.chat_template.or_else(|| {
257 discover_chat_template_in_dir(dir)
258 .and_then(|p| load_chat_template_from_file(&p).ok().flatten())
259 })
260 };
261
262 Ok(TiktokenTokenizer {
263 tokenizer,
264 special_tokens: config.special_tokens,
265 vocab,
266 reverse_vocab,
267 vocab_size,
268 chat_template: ChatTemplateState::new(chat_template)?,
269 })
270 }
271
272 pub fn from_model_name(model_name: &str) -> Result<Self> {
274 let model = Self::model_from_name(model_name)?;
275 Self::new(model)
276 }
277
278 fn model_from_name(model_name: &str) -> Result<TiktokenModel> {
280 if model_name.contains("gpt-4")
281 || model_name.contains("gpt-3.5")
282 || model_name.contains("turbo")
283 {
284 Ok(TiktokenModel::Cl100kBase)
285 } else if model_name.contains("davinci-002")
286 || model_name.contains("davinci-003")
287 || model_name.contains("codex")
288 {
289 Ok(TiktokenModel::P50kBase)
290 } else if model_name.contains("edit") {
291 Ok(TiktokenModel::P50kEdit)
292 } else if model_name.contains("davinci")
293 || model_name.contains("curie")
294 || model_name.contains("babbage")
295 || model_name.contains("ada")
296 {
297 Ok(TiktokenModel::R50kBase)
298 } else {
299 Err(anyhow::anyhow!(
300 "Unrecognized OpenAI model name: '{model_name}'. Expected GPT-3, GPT-3.5, GPT-4, or related model names"
301 ))
302 }
303 }
304
305 fn get_special_tokens_for_model(model: TiktokenModel) -> SpecialTokens {
307 match model {
308 TiktokenModel::Cl100kBase => SpecialTokens {
309 bos_token: Some("<|endoftext|>".to_string()),
310 eos_token: Some("<|endoftext|>".to_string()),
311 unk_token: None,
312 sep_token: None,
313 pad_token: Some("<|endoftext|>".to_string()),
314 cls_token: None,
315 mask_token: None,
316 additional_special_tokens: vec![
317 "<|fim_prefix|>".to_string(),
318 "<|fim_middle|>".to_string(),
319 "<|fim_suffix|>".to_string(),
320 "<|endofprompt|>".to_string(),
321 ],
322 },
323 _ => SpecialTokens {
324 bos_token: Some("<|endoftext|>".to_string()),
325 eos_token: Some("<|endoftext|>".to_string()),
326 unk_token: None,
327 sep_token: None,
328 pad_token: Some("<|endoftext|>".to_string()),
329 cls_token: None,
330 mask_token: None,
331 additional_special_tokens: vec![],
332 },
333 }
334 }
335}
336
337fn load_tiktoken_bpe(path: &str) -> Result<FxHashMap<Vec<u8>, Rank>> {
341 let content = std::fs::read_to_string(path)?;
342 let mut encoder =
343 FxHashMap::with_capacity_and_hasher(content.lines().count(), Default::default());
344 for line in content.lines() {
345 if line.is_empty() {
346 continue;
347 }
348 let mut parts = line.split_whitespace();
349 let token_b64 = parts
350 .next()
351 .ok_or_else(|| Error::msg("missing token in tiktoken file"))?;
352 let rank_str = parts
353 .next()
354 .ok_or_else(|| Error::msg("missing rank in tiktoken file"))?;
355 let token_bytes = STANDARD.decode(token_b64)?;
356 let rank: Rank = rank_str.parse()?;
357 encoder.insert(token_bytes, rank);
358 }
359 Ok(encoder)
360}
361
362fn build_vocab_maps(
364 encoder: &FxHashMap<Vec<u8>, Rank>,
365 added_tokens: &HashMap<String, TokenIdType>,
366) -> (HashMap<String, TokenIdType>, HashMap<TokenIdType, String>) {
367 let capacity = encoder.len() + added_tokens.len();
368 let mut vocab = HashMap::with_capacity(capacity);
369 let mut reverse_vocab = HashMap::with_capacity(capacity);
370
371 for (token_bytes, &rank) in encoder {
373 if let Ok(token_str) = std::str::from_utf8(token_bytes) {
374 vocab.insert(token_str.to_string(), rank);
375 reverse_vocab.insert(rank, token_str.to_string());
376 }
377 }
378
379 for (token_str, &id) in added_tokens {
381 vocab.insert(token_str.clone(), id);
382 reverse_vocab.insert(id, token_str.clone());
383 }
384
385 (vocab, reverse_vocab)
386}
387
388fn find_tiktoken_file(dir: &Path) -> Result<PathBuf> {
392 let tiktoken_model = dir.join("tiktoken.model");
393 if tiktoken_model.exists() {
394 return Ok(tiktoken_model);
395 }
396
397 if let Ok(entries) = std::fs::read_dir(dir) {
399 for entry in entries.flatten() {
400 if let Some(name) = entry.file_name().to_str() {
401 if name.ends_with(".tiktoken") {
402 return Ok(entry.path());
403 }
404 }
405 }
406 }
407
408 Err(Error::msg(format!(
409 "No tiktoken model file found in '{}'",
410 dir.display()
411 )))
412}
413
/// Reports whether `dir` contains a tiktoken vocabulary file: either
/// `tiktoken.model` or any entry whose name ends in `.tiktoken`.
///
/// Returns `false` when the directory cannot be read. Entries with non-UTF-8 names
/// are ignored.
pub fn has_tiktoken_file(dir: &Path) -> bool {
    if dir.join("tiktoken.model").exists() {
        return true;
    }

    let Ok(entries) = std::fs::read_dir(dir) else {
        return false;
    };
    for entry in entries.flatten() {
        let matches = entry
            .file_name()
            .to_str()
            .is_some_and(|name| name.ends_with(".tiktoken"));
        if matches {
            return true;
        }
    }
    false
}
430
/// Reports whether the final component of `path` looks like a tiktoken vocabulary
/// file: exactly `tiktoken.model`, or any name ending in `.tiktoken`.
///
/// Only the name is inspected; the file system is not touched. Paths without a
/// final component or with a non-UTF-8 name yield `false`.
pub fn is_tiktoken_file(path: &Path) -> bool {
    match path.file_name().and_then(|name| name.to_str()) {
        Some("tiktoken.model") => true,
        Some(name) => name.ends_with(".tiktoken"),
        None => false,
    }
}
437
438impl Encoder for TiktokenTokenizer {
439 fn encode(&self, input: &str, _add_special_tokens: bool) -> Result<Encoding> {
440 let tokens = self.tokenizer.encode_ordinary(input);
441 Ok(Encoding::Tiktoken(tokens))
442 }
443
444 fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
445 inputs
446 .iter()
447 .map(|input| self.encode(input, add_special_tokens))
448 .collect()
449 }
450}
451
452impl Decoder for TiktokenTokenizer {
453 fn decode(&self, token_ids: &[TokenIdType], _skip_special_tokens: bool) -> Result<String> {
454 match self.tokenizer.decode(token_ids.to_vec()) {
455 Ok(text) => Ok(text),
456 Err(err) => {
457 let bytes: Vec<u8> = self
459 .tokenizer
460 ._decode_native_and_split(token_ids.to_vec())
461 .flatten()
462 .collect();
463 tracing::warn!(
464 error = %err,
465 token_count = token_ids.len(),
466 "tiktoken decode failed; returning lossy UTF-8 fallback"
467 );
468 Ok(String::from_utf8_lossy(&bytes).into_owned())
469 }
470 }
471 }
472}
473
474impl TokenizerTrait for TiktokenTokenizer {
475 fn vocab_size(&self) -> usize {
476 self.vocab_size
477 }
478
479 fn get_special_tokens(&self) -> &SpecialTokens {
480 &self.special_tokens
481 }
482
483 fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
484 self.vocab.get(token).copied()
485 }
486
487 fn id_to_token(&self, id: TokenIdType) -> Option<String> {
488 self.reverse_vocab.get(&id).cloned()
489 }
490
491 fn as_any(&self) -> &dyn std::any::Any {
492 self
493 }
494
495 fn apply_chat_template(
496 &self,
497 messages: &[serde_json::Value],
498 params: ChatTemplateParams,
499 ) -> Result<String> {
500 if params.special_tokens.is_some() {
502 return self.chat_template.apply(messages, params);
503 }
504 let params = ChatTemplateParams {
505 special_tokens: Some(&self.special_tokens),
506 ..params
507 };
508 self.chat_template.apply(messages, params)
509 }
510
511 fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
512 self.chat_template.content_format()
513 }
514
515 fn set_chat_template(&mut self, template: String) -> Result<()> {
516 self.chat_template.set(template)
517 }
518}
519
/// Unit tests for the tiktoken tokenizer: built-in encodings, model-name mapping,
/// file discovery helpers, config parsing and the lossy decode fallback.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::{Decoder, Encoder, Tokenizer};

    /// The built-in cl100k_base tokenizer reports its hard-coded vocabulary size.
    #[test]
    fn test_tiktoken_creation() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        assert_eq!(tokenizer.vocab_size(), 100256);
    }

    /// Representative model names map to the expected built-in encoding.
    #[test]
    fn test_model_from_name() {
        assert!(matches!(
            TiktokenTokenizer::model_from_name("gpt-4").unwrap(),
            TiktokenModel::Cl100kBase
        ));
        assert!(matches!(
            TiktokenTokenizer::model_from_name("gpt-3.5-turbo").unwrap(),
            TiktokenModel::Cl100kBase
        ));
        assert!(matches!(
            TiktokenTokenizer::model_from_name("text-davinci-003").unwrap(),
            TiktokenModel::P50kBase
        ));
        assert!(matches!(
            TiktokenTokenizer::model_from_name("text-davinci-edit-001").unwrap(),
            TiktokenModel::P50kEdit
        ));
        assert!(matches!(
            TiktokenTokenizer::model_from_name("davinci").unwrap(),
            TiktokenModel::R50kBase
        ));
    }

    /// ASCII text survives an encode/decode round trip.
    #[test]
    fn test_encode_decode() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();

        let text = "Hello, world!";
        let encoding = tokenizer.encode(text, false).unwrap();

        let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
        assert_eq!(decoded, text);
    }

    /// Batch encoding preserves input order and each item round-trips.
    #[test]
    fn test_batch_encode() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();

        let texts = vec!["Hello", "World", "Test"];
        let encodings = tokenizer.encode_batch(&texts, false).unwrap();

        assert_eq!(encodings.len(), 3);
        for (i, encoding) in encodings.iter().enumerate() {
            let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
            assert_eq!(decoded, texts[i]);
        }
    }

    /// Built-in tokenizers expose `<|endoftext|>` as their EOS token.
    #[test]
    fn test_special_tokens() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        let special_tokens = tokenizer.get_special_tokens();

        assert!(special_tokens.eos_token.is_some());
        assert_eq!(special_tokens.eos_token.as_ref().unwrap(), "<|endoftext|>");
    }

    /// Names outside the recognized OpenAI families are rejected with a clear message.
    #[test]
    fn test_unrecognized_model_name_returns_error() {
        let result = TiktokenTokenizer::from_model_name("distilgpt-2");
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("Unrecognized OpenAI model name"));
        }

        let result = TiktokenTokenizer::from_model_name("bert-base-uncased");
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("Unrecognized OpenAI model name"));
        }

        let result = TiktokenTokenizer::from_model_name("llama-7b");
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("Unrecognized OpenAI model name"));
        }
    }

    /// Every recognized model family loads successfully.
    #[test]
    fn test_recognized_model_names() {
        assert!(TiktokenTokenizer::from_model_name("gpt-4").is_ok());
        assert!(TiktokenTokenizer::from_model_name("gpt-3.5-turbo").is_ok());
        assert!(TiktokenTokenizer::from_model_name("text-davinci-003").is_ok());
        assert!(TiktokenTokenizer::from_model_name("code-davinci-002").is_ok());
        assert!(TiktokenTokenizer::from_model_name("text-curie-001").is_ok());
        assert!(TiktokenTokenizer::from_model_name("text-babbage-001").is_ok());
        assert!(TiktokenTokenizer::from_model_name("text-ada-001").is_ok());
    }

    /// Built-in tokenizers do not populate the vocabulary lookup maps.
    #[test]
    fn test_builtin_tokenizer_has_empty_vocab_maps() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        assert_eq!(tokenizer.token_to_id("hello"), None);
        assert_eq!(tokenizer.id_to_token(0), None);
    }

    /// A two-line tiktoken file parses into the expected byte -> rank entries.
    #[test]
    fn test_load_tiktoken_bpe() {
        use std::io::Write;
        let dir = tempfile::tempdir().unwrap();
        let file_path = dir.path().join("test.tiktoken");
        let mut f = std::fs::File::create(&file_path).unwrap();
        // "IQ==" is base64 for byte 0x21 and "Ig==" for byte 0x22.
        writeln!(f, "IQ== 0").unwrap();
        writeln!(f, "Ig== 1").unwrap();

        let encoder = load_tiktoken_bpe(file_path.to_str().unwrap()).unwrap();
        assert_eq!(encoder.len(), 2);
        assert_eq!(encoder.get(&vec![0x21u8]), Some(&0));
        assert_eq!(encoder.get(&vec![0x22u8]), Some(&1));
    }

    /// Valid-UTF-8 and added tokens land in both maps; invalid-UTF-8 tokens are skipped.
    #[test]
    fn test_build_vocab_maps() {
        let mut encoder = FxHashMap::default();
        encoder.insert(b"hello".to_vec(), 42u32);
        // Not valid UTF-8, so it must not appear in the text-keyed map.
        encoder.insert(vec![0xFF, 0xFE], 99u32);

        let mut added = HashMap::new();
        added.insert("<|special|>".to_string(), 1000u32);

        let (vocab, reverse_vocab) = build_vocab_maps(&encoder, &added);

        assert_eq!(vocab.get("hello"), Some(&42));
        assert_eq!(reverse_vocab.get(&42), Some(&"hello".to_string()));

        assert!(!vocab.contains_key("\u{FFFD}"));

        assert_eq!(vocab.get("<|special|>"), Some(&1000));
        assert_eq!(reverse_vocab.get(&1000), Some(&"<|special|>".to_string()));
    }

    /// An empty directory has no tiktoken file; adding `tiktoken.model` changes that.
    #[test]
    fn test_has_tiktoken_file() {
        let dir = tempfile::tempdir().unwrap();
        assert!(!has_tiktoken_file(dir.path()));

        std::fs::write(dir.path().join("tiktoken.model"), "test").unwrap();
        assert!(has_tiktoken_file(dir.path()));
    }

    /// `tiktoken.model` is found by its exact name.
    #[test]
    fn test_find_tiktoken_file_model() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("tiktoken.model"), "test").unwrap();
        let found = find_tiktoken_file(dir.path()).unwrap();
        assert_eq!(found.file_name().unwrap(), "tiktoken.model");
    }

    /// A file with the `.tiktoken` extension is found when `tiktoken.model` is absent.
    #[test]
    fn test_find_tiktoken_file_extension() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("vocab.tiktoken"), "test").unwrap();
        let found = find_tiktoken_file(dir.path()).unwrap();
        assert!(found
            .file_name()
            .unwrap()
            .to_str()
            .unwrap()
            .ends_with(".tiktoken"));
    }

    /// Name-based detection accepts tiktoken files and rejects other names.
    #[test]
    fn test_is_tiktoken_file() {
        assert!(is_tiktoken_file(Path::new("tiktoken.model")));
        assert!(is_tiktoken_file(Path::new("vocab.tiktoken")));
        assert!(!is_tiktoken_file(Path::new("tokenizer.json")));
        assert!(!is_tiktoken_file(Path::new("model.bin")));
    }

    /// `added_tokens_decoder` entries are inverted into a content -> id map.
    #[test]
    fn test_parse_added_tokens_decoder() {
        let config: serde_json::Value = serde_json::json!({
            "added_tokens_decoder": {
                "163584": { "content": "[BOS]", "special": true },
                "163585": { "content": "[EOS]", "special": true },
                "163586": { "content": "<|im_end|>", "special": true }
            }
        });
        let tokens = parse_added_tokens_decoder(&config);
        assert_eq!(tokens.get("[BOS]"), Some(&163584));
        assert_eq!(tokens.get("[EOS]"), Some(&163585));
        assert_eq!(tokens.get("<|im_end|>"), Some(&163586));
    }

    /// String-valued special tokens are read directly from the config.
    #[test]
    fn test_parse_special_tokens() {
        let config: serde_json::Value = serde_json::json!({
            "bos_token": "[BOS]",
            "eos_token": "[EOS]",
            "unk_token": "[UNK]",
            "pad_token": "[PAD]",
            "additional_special_tokens": ["<|im_end|>", "<|im_user|>"]
        });
        let special = parse_special_tokens(&config);
        assert_eq!(special.bos_token.as_deref(), Some("[BOS]"));
        assert_eq!(special.eos_token.as_deref(), Some("[EOS]"));
        assert_eq!(special.unk_token.as_deref(), Some("[UNK]"));
        assert_eq!(special.pad_token.as_deref(), Some("[PAD]"));
        assert_eq!(special.additional_special_tokens.len(), 2);
    }

    /// Object-valued special tokens are read from their `content` field.
    #[test]
    fn test_parse_special_tokens_object_valued() {
        let config: serde_json::Value = serde_json::json!({
            "bos_token": {"content": "<s>", "lstrip": false, "rstrip": false, "single_word": false, "special": true},
            "eos_token": "</s>",
            "unk_token": {"content": "<unk>", "special": true}
        });
        let special = parse_special_tokens(&config);
        assert_eq!(special.bos_token.as_deref(), Some("<s>"));
        assert_eq!(special.eos_token.as_deref(), Some("</s>"));
        assert_eq!(special.unk_token.as_deref(), Some("<unk>"));
    }

    /// The default config carries no tokens and no chat template.
    #[test]
    fn test_tiktoken_config_default() {
        let config = TiktokenConfig::default();
        assert!(config.special_tokens.bos_token.is_none());
        assert!(config.added_tokens.is_empty());
        assert!(config.chat_template.is_none());
    }

    /// Decoding only a prefix of a multi-token character must fall back to lossy UTF-8.
    #[test]
    fn test_decode_lossy_fallback_for_invalid_utf8() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();

        // U+1F389 is a 4-byte emoji, written as an escape so the literal cannot be
        // corrupted by a source re-encoding.
        let full_encoding = tokenizer.encode("\u{1F389}", false).unwrap();
        let full_ids = full_encoding.token_ids();
        assert!(
            full_ids.len() > 1,
            "emoji should encode to multiple tokens in cl100k_base"
        );

        // The first token alone carries an incomplete UTF-8 sequence.
        let partial_ids = &full_ids[..1];
        let result = tokenizer.decode(partial_ids, false);
        assert!(
            result.is_ok(),
            "decode of partial UTF-8 should succeed via lossy fallback"
        );
        let decoded = result.unwrap();
        assert!(
            decoded.contains('\u{FFFD}') || decoded.is_empty(),
            "lossy decode should contain replacement char or be empty, got: {decoded:?}"
        );
    }

    /// Complete multi-byte text round-trips exactly, without replacement characters.
    #[test]
    fn test_decode_valid_utf8_does_not_use_fallback() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        // "Hello, " followed by U+4E16 U+754C (two CJK characters) and "!".
        let text = "Hello, \u{4E16}\u{754C}!";
        let encoding = tokenizer.encode(text, false).unwrap();
        let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
        assert_eq!(decoded, text);
    }
}