Skip to main content

llm_tokenizer/
tiktoken.rs

1use std::{
2    collections::HashMap,
3    path::{Path, PathBuf},
4};
5
6use anyhow::{Error, Result};
7use base64::{engine::general_purpose::STANDARD, Engine as _};
8use rustc_hash::FxHashMap;
9use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE};
10
11use crate::{
12    chat_template::{
13        load_chat_template_from_file, ChatTemplateContentFormat, ChatTemplateParams,
14        ChatTemplateState, ThinkingKeyName, ThinkingToggle,
15    },
16    factory::discover_chat_template_in_dir,
17    traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait},
18};
19
/// Regex pattern for cl100k_base tokenization.
///
/// This pattern is correct for OpenAI models and most open-source tiktoken models (e.g.
/// DeepSeek, Kimi K2). Some models use a different regex — for example, Kimi K2's native
/// regex includes `\p{Han}` for Chinese character splitting — but encode/decode roundtrips
/// still work correctly because the BPE vocab handles tokenization; the regex only affects
/// exact token boundary placement. A future enhancement could read the model-specific regex
/// from the model's tokenizer metadata instead of always using this one.
const CL100K_BASE_PATTERN: &str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";

/// Token ID / merge rank type used in the byte-level BPE encoder maps.
type Rank = u32;
31
32// ---------------------------------------------------------------------------
33// Tiktoken-specific config parsing (from tokenizer_config.json)
34// ---------------------------------------------------------------------------
35
36/// Parsed `tokenizer_config.json` for tiktoken-based models.
37#[derive(Default)]
38struct TiktokenConfig {
39    special_tokens: SpecialTokens,
40    /// Token string -> ID mapping from `added_tokens_decoder`
41    added_tokens: HashMap<String, TokenIdType>,
42    chat_template: Option<String>,
43}
44
45/// Parse `tokenizer_config.json` for tiktoken-based models.
46fn load_tiktoken_config(config_path: &Path) -> Result<TiktokenConfig> {
47    let content = std::fs::read_to_string(config_path)?;
48    let config: serde_json::Value = serde_json::from_str(&content)?;
49
50    let added_tokens = parse_added_tokens_decoder(&config);
51    let special_tokens = parse_special_tokens(&config);
52
53    let chat_template = config
54        .get("chat_template")
55        .and_then(|v| v.as_str())
56        .map(String::from);
57
58    Ok(TiktokenConfig {
59        special_tokens,
60        added_tokens,
61        chat_template,
62    })
63}
64
65/// Parse `added_tokens_decoder` from config JSON.
66///
67/// Format: `{ "163584": { "content": "[BOS]", "special": true }, ... }`
68fn parse_added_tokens_decoder(config: &serde_json::Value) -> HashMap<String, TokenIdType> {
69    let mut tokens = HashMap::new();
70    if let Some(added) = config
71        .get("added_tokens_decoder")
72        .and_then(|v| v.as_object())
73    {
74        for (id_str, token_info) in added {
75            if let (Ok(id), Some(content)) = (
76                id_str.parse::<TokenIdType>(),
77                token_info.get("content").and_then(|v| v.as_str()),
78            ) {
79                tokens.insert(content.to_string(), id);
80            }
81        }
82    }
83    tokens
84}
85
86/// Extract named special tokens (bos, eos, unk, etc.) from config JSON.
87///
88/// Handles both string-valued tokens (`"bos_token": "<s>"`) and object-valued tokens
89/// (`"bos_token": {"content": "<s>", "lstrip": false, ...}`) found in some HuggingFace models.
90fn parse_special_tokens(config: &serde_json::Value) -> SpecialTokens {
91    let get_str = |key: &str| {
92        config.get(key).and_then(|v| {
93            v.as_str()
94                .map(String::from)
95                .or_else(|| v.get("content").and_then(|c| c.as_str()).map(String::from))
96        })
97    };
98
99    let additional: Vec<String> = config
100        .get("additional_special_tokens")
101        .and_then(|v| v.as_array())
102        .map(|arr| {
103            arr.iter()
104                .filter_map(|v| {
105                    v.as_str()
106                        .map(String::from)
107                        .or_else(|| v.get("content").and_then(|c| c.as_str()).map(String::from))
108                })
109                .collect()
110        })
111        .unwrap_or_default();
112
113    SpecialTokens {
114        bos_token: get_str("bos_token"),
115        eos_token: get_str("eos_token"),
116        unk_token: get_str("unk_token"),
117        sep_token: get_str("sep_token"),
118        pad_token: get_str("pad_token"),
119        cls_token: get_str("cls_token"),
120        mask_token: get_str("mask_token"),
121        additional_special_tokens: additional,
122    }
123}
124
125/// Tiktoken tokenizer wrapper β€” supports both built-in OpenAI encodings and hub-loaded models.
126pub struct TiktokenTokenizer {
127    tokenizer: CoreBPE,
128    special_tokens: SpecialTokens,
129    vocab: HashMap<String, TokenIdType>,
130    reverse_vocab: HashMap<TokenIdType, String>,
131    vocab_size: usize,
132    chat_template: ChatTemplateState,
133}
134
/// Built-in OpenAI tiktoken encodings supported by [`TiktokenTokenizer::new`].
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare variants and use them as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TiktokenModel {
    /// GPT-4, GPT-3.5-turbo, text-embedding-ada-002
    Cl100kBase,
    /// Codex models, text-davinci-002, text-davinci-003
    P50kBase,
    /// Use for edit models like text-davinci-edit-001, code-davinci-edit-001
    P50kEdit,
    /// GPT-3 models like davinci
    R50kBase,
}
147
148impl TiktokenTokenizer {
149    /// Create a new Tiktoken tokenizer for the specified built-in model
150    pub fn new(model: TiktokenModel) -> Result<Self> {
151        let tokenizer = match model {
152            TiktokenModel::Cl100kBase => {
153                cl100k_base().map_err(|e| Error::msg(format!("Failed to load cl100k_base: {e}")))?
154            }
155            TiktokenModel::P50kBase => {
156                p50k_base().map_err(|e| Error::msg(format!("Failed to load p50k_base: {e}")))?
157            }
158            TiktokenModel::P50kEdit => {
159                p50k_edit().map_err(|e| Error::msg(format!("Failed to load p50k_edit: {e}")))?
160            }
161            TiktokenModel::R50kBase => {
162                r50k_base().map_err(|e| Error::msg(format!("Failed to load r50k_base: {e}")))?
163            }
164        };
165
166        let special_tokens = Self::get_special_tokens_for_model(model);
167
168        let vocab_size = match model {
169            TiktokenModel::Cl100kBase => 100256,
170            TiktokenModel::P50kBase | TiktokenModel::P50kEdit => 50281,
171            TiktokenModel::R50kBase => 50257,
172        };
173
174        Ok(TiktokenTokenizer {
175            tokenizer,
176            special_tokens,
177            vocab: HashMap::new(),
178            reverse_vocab: HashMap::new(),
179            vocab_size,
180            chat_template: ChatTemplateState::empty(),
181        })
182    }
183
184    /// Create from a directory containing tiktoken.model + tokenizer_config.json
185    pub fn from_dir(dir: &Path) -> Result<Self> {
186        Self::from_dir_with_chat_template(dir, None)
187    }
188
189    /// Create from a directory with an optional chat template file path.
190    /// Discovers the tiktoken model file automatically via `find_tiktoken_file`.
191    pub fn from_dir_with_chat_template(
192        dir: &Path,
193        chat_template_path: Option<&str>,
194    ) -> Result<Self> {
195        let tiktoken_path = find_tiktoken_file(dir)?;
196        Self::load_from_path(&tiktoken_path, chat_template_path)
197    }
198
199    /// Create from an exact tiktoken file path (`.tiktoken` or `tiktoken.model`).
200    /// Looks for `tokenizer_config.json` in the same directory.
201    pub fn from_file(tiktoken_path: &Path) -> Result<Self> {
202        Self::from_file_with_chat_template(tiktoken_path, None)
203    }
204
205    /// Create from an exact tiktoken file path with an optional chat template.
206    pub fn from_file_with_chat_template(
207        tiktoken_path: &Path,
208        chat_template_path: Option<&str>,
209    ) -> Result<Self> {
210        Self::load_from_path(tiktoken_path, chat_template_path)
211    }
212
213    /// Core loading logic shared by `from_dir` and `from_file` constructors.
214    fn load_from_path(tiktoken_path: &Path, chat_template_path: Option<&str>) -> Result<Self> {
215        // 1. Load BPE encoder from the exact file
216        let tiktoken_path_str = tiktoken_path
217            .to_str()
218            .ok_or_else(|| Error::msg("Tiktoken file path is not valid UTF-8"))?;
219        let encoder = load_tiktoken_bpe(tiktoken_path_str)?;
220
221        // 2. Parse tokenizer_config.json from the same directory
222        let dir = tiktoken_path
223            .parent()
224            .ok_or_else(|| Error::msg("Cannot determine parent directory of tiktoken file"))?;
225        let config_path = dir.join("tokenizer_config.json");
226        let config = if config_path.exists() {
227            load_tiktoken_config(&config_path)?
228        } else {
229            TiktokenConfig::default()
230        };
231
232        // 3. Build special tokens encoder for CoreBPE (needs FxHashMap)
233        let special_tokens_encoder: FxHashMap<String, Rank> = config
234            .added_tokens
235            .iter()
236            .map(|(k, &v)| (k.clone(), v))
237            .collect();
238
239        // 4. Calculate true vocab size from max token ID (handles sparse/reserved IDs),
240        //    build string-based vocab maps (borrows encoder), then pass encoder by value to CoreBPE
241        let vocab_size = encoder
242            .values()
243            .copied()
244            .chain(special_tokens_encoder.values().copied())
245            .max()
246            .map(|id| id as usize + 1)
247            .unwrap_or(0);
248        let (vocab, reverse_vocab) = build_vocab_maps(&encoder, &config.added_tokens);
249        let tokenizer = CoreBPE::new(encoder, special_tokens_encoder, CL100K_BASE_PATTERN)?;
250
251        // 5. Load chat template β€” propagate errors for explicit paths,
252        //    silently fall back for auto-discovery
253        let chat_template = if let Some(p) = chat_template_path {
254            load_chat_template_from_file(p)?
255        } else {
256            config.chat_template.or_else(|| {
257                discover_chat_template_in_dir(dir)
258                    .and_then(|p| load_chat_template_from_file(&p).ok().flatten())
259            })
260        };
261
262        Ok(TiktokenTokenizer {
263            tokenizer,
264            special_tokens: config.special_tokens,
265            vocab,
266            reverse_vocab,
267            vocab_size,
268            chat_template: ChatTemplateState::new(chat_template)?,
269        })
270    }
271
272    /// Create a tokenizer from a model string (e.g., "gpt-4", "gpt-3.5-turbo")
273    pub fn from_model_name(model_name: &str) -> Result<Self> {
274        let model = Self::model_from_name(model_name)?;
275        Self::new(model)
276    }
277
278    /// Determine the appropriate model from a model name
279    fn model_from_name(model_name: &str) -> Result<TiktokenModel> {
280        if model_name.contains("gpt-4")
281            || model_name.contains("gpt-3.5")
282            || model_name.contains("turbo")
283        {
284            Ok(TiktokenModel::Cl100kBase)
285        } else if model_name.contains("davinci-002")
286            || model_name.contains("davinci-003")
287            || model_name.contains("codex")
288        {
289            Ok(TiktokenModel::P50kBase)
290        } else if model_name.contains("edit") {
291            Ok(TiktokenModel::P50kEdit)
292        } else if model_name.contains("davinci")
293            || model_name.contains("curie")
294            || model_name.contains("babbage")
295            || model_name.contains("ada")
296        {
297            Ok(TiktokenModel::R50kBase)
298        } else {
299            Err(anyhow::anyhow!(
300                "Unrecognized OpenAI model name: '{model_name}'. Expected GPT-3, GPT-3.5, GPT-4, or related model names"
301            ))
302        }
303    }
304
305    /// Get special tokens for a specific model
306    fn get_special_tokens_for_model(model: TiktokenModel) -> SpecialTokens {
307        match model {
308            TiktokenModel::Cl100kBase => SpecialTokens {
309                bos_token: Some("<|endoftext|>".to_string()),
310                eos_token: Some("<|endoftext|>".to_string()),
311                unk_token: None,
312                sep_token: None,
313                pad_token: Some("<|endoftext|>".to_string()),
314                cls_token: None,
315                mask_token: None,
316                additional_special_tokens: vec![
317                    "<|fim_prefix|>".to_string(),
318                    "<|fim_middle|>".to_string(),
319                    "<|fim_suffix|>".to_string(),
320                    "<|endofprompt|>".to_string(),
321                ],
322            },
323            _ => SpecialTokens {
324                bos_token: Some("<|endoftext|>".to_string()),
325                eos_token: Some("<|endoftext|>".to_string()),
326                unk_token: None,
327                sep_token: None,
328                pad_token: Some("<|endoftext|>".to_string()),
329                cls_token: None,
330                mask_token: None,
331                additional_special_tokens: vec![],
332            },
333        }
334    }
335}
336
337/// Parse a .tiktoken / tiktoken.model file into a BPE encoder.
338///
339/// Format: each line is `<base64-encoded-token-bytes> <rank>`
340fn load_tiktoken_bpe(path: &str) -> Result<FxHashMap<Vec<u8>, Rank>> {
341    let content = std::fs::read_to_string(path)?;
342    let mut encoder =
343        FxHashMap::with_capacity_and_hasher(content.lines().count(), Default::default());
344    for line in content.lines() {
345        if line.is_empty() {
346            continue;
347        }
348        let mut parts = line.split_whitespace();
349        let token_b64 = parts
350            .next()
351            .ok_or_else(|| Error::msg("missing token in tiktoken file"))?;
352        let rank_str = parts
353            .next()
354            .ok_or_else(|| Error::msg("missing rank in tiktoken file"))?;
355        let token_bytes = STANDARD.decode(token_b64)?;
356        let rank: Rank = rank_str.parse()?;
357        encoder.insert(token_bytes, rank);
358    }
359    Ok(encoder)
360}
361
362/// Build string-level vocab from byte-level encoder + added tokens.
363fn build_vocab_maps(
364    encoder: &FxHashMap<Vec<u8>, Rank>,
365    added_tokens: &HashMap<String, TokenIdType>,
366) -> (HashMap<String, TokenIdType>, HashMap<TokenIdType, String>) {
367    let capacity = encoder.len() + added_tokens.len();
368    let mut vocab = HashMap::with_capacity(capacity);
369    let mut reverse_vocab = HashMap::with_capacity(capacity);
370
371    // BPE tokens (only valid UTF-8 sequences get string entries)
372    for (token_bytes, &rank) in encoder {
373        if let Ok(token_str) = std::str::from_utf8(token_bytes) {
374            vocab.insert(token_str.to_string(), rank);
375            reverse_vocab.insert(rank, token_str.to_string());
376        }
377    }
378
379    // Special/added tokens (always valid UTF-8)
380    for (token_str, &id) in added_tokens {
381        vocab.insert(token_str.clone(), id);
382        reverse_vocab.insert(id, token_str.clone());
383    }
384
385    (vocab, reverse_vocab)
386}
387
388/// Find a tiktoken model file in the given directory.
389///
390/// Looks for `tiktoken.model` first, then any `*.tiktoken` file.
391fn find_tiktoken_file(dir: &Path) -> Result<PathBuf> {
392    let tiktoken_model = dir.join("tiktoken.model");
393    if tiktoken_model.exists() {
394        return Ok(tiktoken_model);
395    }
396
397    // Look for *.tiktoken files
398    if let Ok(entries) = std::fs::read_dir(dir) {
399        for entry in entries.flatten() {
400            if let Some(name) = entry.file_name().to_str() {
401                if name.ends_with(".tiktoken") {
402                    return Ok(entry.path());
403                }
404            }
405        }
406    }
407
408    Err(Error::msg(format!(
409        "No tiktoken model file found in '{}'",
410        dir.display()
411    )))
412}
413
/// Report whether `dir` contains a tiktoken model: a `tiktoken.model` entry, or any
/// entry whose name ends in `.tiktoken`. An unreadable directory yields `false`.
pub fn has_tiktoken_file(dir: &Path) -> bool {
    if dir.join("tiktoken.model").exists() {
        return true;
    }
    let Ok(entries) = std::fs::read_dir(dir) else {
        return false;
    };
    entries.filter_map(|entry| entry.ok()).any(|entry| {
        matches!(entry.file_name().to_str(), Some(name) if name.ends_with(".tiktoken"))
    })
}
430
/// Decide, from the file name alone, whether `path` names a tiktoken model file:
/// exactly `tiktoken.model`, or any name ending in `.tiktoken`.
pub fn is_tiktoken_file(path: &Path) -> bool {
    match path.file_name().and_then(|n| n.to_str()) {
        Some("tiktoken.model") => true,
        Some(name) => name.ends_with(".tiktoken"),
        None => false,
    }
}
437
438impl Encoder for TiktokenTokenizer {
439    fn encode(&self, input: &str, _add_special_tokens: bool) -> Result<Encoding> {
440        let tokens = self.tokenizer.encode_ordinary(input);
441        Ok(Encoding::Tiktoken(tokens))
442    }
443
444    fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
445        inputs
446            .iter()
447            .map(|input| self.encode(input, add_special_tokens))
448            .collect()
449    }
450}
451
452impl Decoder for TiktokenTokenizer {
453    fn decode(&self, token_ids: &[TokenIdType], _skip_special_tokens: bool) -> Result<String> {
454        match self.tokenizer.decode(token_ids.to_vec()) {
455            Ok(text) => Ok(text),
456            Err(err) => {
457                // Fallback to lossy decoding for incomplete UTF-8 sequences
458                let bytes: Vec<u8> = self
459                    .tokenizer
460                    ._decode_native_and_split(token_ids.to_vec())
461                    .flatten()
462                    .collect();
463                tracing::warn!(
464                    error = %err,
465                    token_count = token_ids.len(),
466                    "tiktoken decode failed; returning lossy UTF-8 fallback"
467                );
468                Ok(String::from_utf8_lossy(&bytes).into_owned())
469            }
470        }
471    }
472}
473
474impl TokenizerTrait for TiktokenTokenizer {
475    fn vocab_size(&self) -> usize {
476        self.vocab_size
477    }
478
479    fn get_special_tokens(&self) -> &SpecialTokens {
480        &self.special_tokens
481    }
482
483    fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
484        self.vocab.get(token).copied()
485    }
486
487    fn id_to_token(&self, id: TokenIdType) -> Option<String> {
488        self.reverse_vocab.get(&id).cloned()
489    }
490
491    fn as_any(&self) -> &dyn std::any::Any {
492        self
493    }
494
495    fn apply_chat_template(
496        &self,
497        messages: &[serde_json::Value],
498        params: ChatTemplateParams,
499    ) -> Result<String> {
500        // Inject special tokens if the caller didn't provide them
501        if params.special_tokens.is_some() {
502            return self.chat_template.apply(messages, params);
503        }
504        let params = ChatTemplateParams {
505            special_tokens: Some(&self.special_tokens),
506            ..params
507        };
508        self.chat_template.apply(messages, params)
509    }
510
511    fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
512        self.chat_template.content_format()
513    }
514
515    fn thinking_toggle(&self) -> ThinkingToggle {
516        self.chat_template.thinking_toggle()
517    }
518
519    fn thinking_key_name(&self) -> Option<ThinkingKeyName> {
520        self.chat_template.thinking_key_name()
521    }
522    fn think_in_prefill(&self) -> bool {
523        self.chat_template.think_in_prefill()
524    }
525
526    fn set_chat_template(&mut self, template: String) -> Result<()> {
527        self.chat_template.set(template)
528    }
529}
530
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::{Decoder, Encoder, Tokenizer};

    #[test]
    fn test_tiktoken_creation() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        assert_eq!(tokenizer.vocab_size(), 100256);
    }

    #[test]
    fn test_model_from_name() {
        assert!(matches!(
            TiktokenTokenizer::model_from_name("gpt-4").unwrap(),
            TiktokenModel::Cl100kBase
        ));
        assert!(matches!(
            TiktokenTokenizer::model_from_name("gpt-3.5-turbo").unwrap(),
            TiktokenModel::Cl100kBase
        ));
        assert!(matches!(
            TiktokenTokenizer::model_from_name("text-davinci-003").unwrap(),
            TiktokenModel::P50kBase
        ));
        assert!(matches!(
            TiktokenTokenizer::model_from_name("text-davinci-edit-001").unwrap(),
            TiktokenModel::P50kEdit
        ));
        assert!(matches!(
            TiktokenTokenizer::model_from_name("davinci").unwrap(),
            TiktokenModel::R50kBase
        ));
    }

    #[test]
    fn test_encode_decode() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();

        let text = "Hello, world!";
        let encoding = tokenizer.encode(text, false).unwrap();

        let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
        assert_eq!(decoded, text);
    }

    #[test]
    fn test_batch_encode() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();

        let texts = vec!["Hello", "World", "Test"];
        let encodings = tokenizer.encode_batch(&texts, false).unwrap();

        assert_eq!(encodings.len(), 3);
        for (i, encoding) in encodings.iter().enumerate() {
            let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
            assert_eq!(decoded, texts[i]);
        }
    }

    #[test]
    fn test_special_tokens() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        let special_tokens = tokenizer.get_special_tokens();

        assert!(special_tokens.eos_token.is_some());
        assert_eq!(special_tokens.eos_token.as_ref().unwrap(), "<|endoftext|>");
    }

    #[test]
    fn test_unrecognized_model_name_returns_error() {
        let result = TiktokenTokenizer::from_model_name("distilgpt-2");
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("Unrecognized OpenAI model name"));
        }

        let result = TiktokenTokenizer::from_model_name("bert-base-uncased");
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("Unrecognized OpenAI model name"));
        }

        let result = TiktokenTokenizer::from_model_name("llama-7b");
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("Unrecognized OpenAI model name"));
        }
    }

    #[test]
    fn test_recognized_model_names() {
        assert!(TiktokenTokenizer::from_model_name("gpt-4").is_ok());
        assert!(TiktokenTokenizer::from_model_name("gpt-3.5-turbo").is_ok());
        assert!(TiktokenTokenizer::from_model_name("text-davinci-003").is_ok());
        assert!(TiktokenTokenizer::from_model_name("code-davinci-002").is_ok());
        assert!(TiktokenTokenizer::from_model_name("text-curie-001").is_ok());
        assert!(TiktokenTokenizer::from_model_name("text-babbage-001").is_ok());
        assert!(TiktokenTokenizer::from_model_name("text-ada-001").is_ok());
    }

    #[test]
    fn test_builtin_tokenizer_has_empty_vocab_maps() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        // Built-in path: vocab maps are empty, token_to_id returns None
        assert_eq!(tokenizer.token_to_id("hello"), None);
        assert_eq!(tokenizer.id_to_token(0), None);
    }

    #[test]
    fn test_load_tiktoken_bpe() {
        use std::io::Write;
        let dir = tempfile::tempdir().unwrap();
        let file_path = dir.path().join("test.tiktoken");
        let mut f = std::fs::File::create(&file_path).unwrap();
        // "IQ==" is base64 for byte 0x21 ('!'), rank 0
        // "Ig==" is base64 for byte 0x22 ('"'), rank 1
        writeln!(f, "IQ== 0").unwrap();
        writeln!(f, "Ig== 1").unwrap();

        let encoder = load_tiktoken_bpe(file_path.to_str().unwrap()).unwrap();
        assert_eq!(encoder.len(), 2);
        assert_eq!(encoder.get(&vec![0x21u8]), Some(&0));
        assert_eq!(encoder.get(&vec![0x22u8]), Some(&1));
    }

    #[test]
    fn test_build_vocab_maps() {
        let mut encoder = FxHashMap::default();
        encoder.insert(b"hello".to_vec(), 42u32);
        encoder.insert(vec![0xFF, 0xFE], 99u32); // invalid UTF-8

        let mut added = HashMap::new();
        added.insert("<|special|>".to_string(), 1000u32);

        let (vocab, reverse_vocab) = build_vocab_maps(&encoder, &added);

        // Valid UTF-8 token present
        assert_eq!(vocab.get("hello"), Some(&42));
        assert_eq!(reverse_vocab.get(&42), Some(&"hello".to_string()));

        // Invalid UTF-8 token excluded from vocab
        assert!(!vocab.contains_key("\u{FFFD}")); // not lossy-inserted

        // Added token present
        assert_eq!(vocab.get("<|special|>"), Some(&1000));
        assert_eq!(reverse_vocab.get(&1000), Some(&"<|special|>".to_string()));
    }

    #[test]
    fn test_has_tiktoken_file() {
        let dir = tempfile::tempdir().unwrap();
        assert!(!has_tiktoken_file(dir.path()));

        std::fs::write(dir.path().join("tiktoken.model"), "test").unwrap();
        assert!(has_tiktoken_file(dir.path()));
    }

    #[test]
    fn test_find_tiktoken_file_model() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("tiktoken.model"), "test").unwrap();
        let found = find_tiktoken_file(dir.path()).unwrap();
        assert_eq!(found.file_name().unwrap(), "tiktoken.model");
    }

    #[test]
    fn test_find_tiktoken_file_extension() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("vocab.tiktoken"), "test").unwrap();
        let found = find_tiktoken_file(dir.path()).unwrap();
        assert!(found
            .file_name()
            .unwrap()
            .to_str()
            .unwrap()
            .ends_with(".tiktoken"));
    }

    #[test]
    fn test_is_tiktoken_file() {
        assert!(is_tiktoken_file(Path::new("tiktoken.model")));
        assert!(is_tiktoken_file(Path::new("vocab.tiktoken")));
        assert!(!is_tiktoken_file(Path::new("tokenizer.json")));
        assert!(!is_tiktoken_file(Path::new("model.bin")));
    }

    #[test]
    fn test_parse_added_tokens_decoder() {
        let config: serde_json::Value = serde_json::json!({
            "added_tokens_decoder": {
                "163584": { "content": "[BOS]", "special": true },
                "163585": { "content": "[EOS]", "special": true },
                "163586": { "content": "<|im_end|>", "special": true }
            }
        });
        let tokens = parse_added_tokens_decoder(&config);
        assert_eq!(tokens.get("[BOS]"), Some(&163584));
        assert_eq!(tokens.get("[EOS]"), Some(&163585));
        assert_eq!(tokens.get("<|im_end|>"), Some(&163586));
    }

    #[test]
    fn test_parse_special_tokens() {
        let config: serde_json::Value = serde_json::json!({
            "bos_token": "[BOS]",
            "eos_token": "[EOS]",
            "unk_token": "[UNK]",
            "pad_token": "[PAD]",
            "additional_special_tokens": ["<|im_end|>", "<|im_user|>"]
        });
        let special = parse_special_tokens(&config);
        assert_eq!(special.bos_token.as_deref(), Some("[BOS]"));
        assert_eq!(special.eos_token.as_deref(), Some("[EOS]"));
        assert_eq!(special.unk_token.as_deref(), Some("[UNK]"));
        assert_eq!(special.pad_token.as_deref(), Some("[PAD]"));
        assert_eq!(special.additional_special_tokens.len(), 2);
    }

    #[test]
    fn test_parse_special_tokens_object_valued() {
        let config: serde_json::Value = serde_json::json!({
            "bos_token": {"content": "<s>", "lstrip": false, "rstrip": false, "single_word": false, "special": true},
            "eos_token": "</s>",
            "unk_token": {"content": "<unk>", "special": true}
        });
        let special = parse_special_tokens(&config);
        assert_eq!(special.bos_token.as_deref(), Some("<s>"));
        assert_eq!(special.eos_token.as_deref(), Some("</s>"));
        assert_eq!(special.unk_token.as_deref(), Some("<unk>"));
    }

    #[test]
    fn test_tiktoken_config_default() {
        let config = TiktokenConfig::default();
        assert!(config.special_tokens.bos_token.is_none());
        assert!(config.added_tokens.is_empty());
        assert!(config.chat_template.is_none());
    }

    #[test]
    fn test_decode_lossy_fallback_for_invalid_utf8() {
        // cl100k_base maps individual bytes to token IDs via its byte-level BPE.
        // Encode a multi-byte UTF-8 character, then decode only a prefix of its
        // tokens so the raw bytes form an incomplete (invalid) UTF-8 sequence.
        // The old implementation would return an error; the new one should fall
        // back to lossy decoding and produce U+FFFD replacement characters.
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();

        // "😀" is U+1F600, encoded as 4 UTF-8 bytes: [0xF0, 0x9F, 0x98, 0x80].
        // With cl100k_base this encodes to multiple tokens. Taking a strict
        // subset of those tokens gives bytes that aren't valid UTF-8.
        let full_encoding = tokenizer.encode("😀", false).unwrap();
        let full_ids = full_encoding.token_ids();
        assert!(
            full_ids.len() > 1,
            "emoji should encode to multiple tokens in cl100k_base"
        );

        // Take only the first token — its raw bytes are an incomplete UTF-8 prefix.
        let partial_ids = &full_ids[..1];
        let result = tokenizer.decode(partial_ids, false);
        assert!(
            result.is_ok(),
            "decode of partial UTF-8 should succeed via lossy fallback"
        );
        let decoded = result.unwrap();
        assert!(
            decoded.contains('\u{FFFD}') || decoded.is_empty(),
            "lossy decode should contain replacement char or be empty, got: {decoded:?}"
        );
    }

    #[test]
    fn test_decode_valid_utf8_does_not_use_fallback() {
        // Ensure that valid UTF-8 round-trips through the happy path unchanged.
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        let text = "Hello, 世界!";
        let encoding = tokenizer.encode(text, false).unwrap();
        let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
        assert_eq!(decoded, text);
    }
}