// llm_tokenizer/tiktoken.rs
1use std::{
2    collections::HashMap,
3    path::{Path, PathBuf},
4};
5
6use anyhow::{Error, Result};
7use base64::{engine::general_purpose::STANDARD, Engine as _};
8use rustc_hash::FxHashMap;
9use tiktoken_rs::{
10    cl100k_base, o200k_base, p50k_base, p50k_edit, r50k_base,
11    tokenizer::{get_tokenizer, Tokenizer},
12    CoreBPE,
13};
14
15use crate::{
16    chat_template::{
17        load_chat_template_from_file, ChatTemplateContentFormat, ChatTemplateParams,
18        ChatTemplateState, ThinkingKeyName, ThinkingToggle,
19    },
20    encoders::kimi_k25_tools::apply_kimi_k25_tools,
21    factory::discover_chat_template_in_dir,
22    kimi_k2_tokenizer,
23    traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait},
24};
25
/// Strategy used by `apply_chat_template` to turn messages into a prompt.
#[derive(Clone, Copy, Debug)]
enum Renderer {
    /// Render the stored chat template directly.
    Jinja,
    /// Delegate rendering to `apply_kimi_k25_tools` (chosen when `config.json`
    /// lists `KimiK25ForConditionalGeneration`).
    KimiK25Tools,
}
31
/// Regex pattern for cl100k_base tokenization.
///
/// This is the pre-tokenization split pattern handed to `CoreBPE::new` for every
/// tiktoken file loaded from disk. The only exception is when
/// `kimi_k2_tokenizer::matches` selects `KIMI_K2_PATTERN` inside `load_from_path`.
/// Built-in encodings created via `TiktokenTokenizer::new` do not use this constant.
///
/// The original author notes this pattern is correct for OpenAI models and most
/// open-source tiktoken models.
const CL100K_BASE_PATTERN: &str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";
37
/// Rank (token id) of a BPE entry, as parsed from a tiktoken file and passed to
/// `CoreBPE::new`.
type Rank = u32;
39
40// ---------------------------------------------------------------------------
41// Tiktoken-specific config parsing (from tokenizer_config.json)
42// ---------------------------------------------------------------------------
43
44/// Parsed `tokenizer_config.json` for tiktoken-based models.
45#[derive(Default)]
46struct TiktokenConfig {
47    special_tokens: SpecialTokens,
48    /// Token string -> ID mapping from `added_tokens_decoder`
49    added_tokens: HashMap<String, TokenIdType>,
50    chat_template: Option<String>,
51}
52
53/// Parse an already-loaded `tokenizer_config.json` value into a `TiktokenConfig`.
54fn parse_tiktoken_config(value: &serde_json::Value) -> TiktokenConfig {
55    TiktokenConfig {
56        special_tokens: parse_special_tokens(value),
57        added_tokens: parse_added_tokens_decoder(value),
58        chat_template: value
59            .get("chat_template")
60            .and_then(|v| v.as_str())
61            .map(String::from),
62    }
63}
64
65/// Load `tokenizer_config.json` from `dir`, returning both the parsed
66/// `TiktokenConfig` and the raw JSON value (so callers like Kimi detection
67/// can inspect the same parse without re-reading the file).
68fn load_tiktoken_config_from_dir(
69    dir: &Path,
70) -> Result<(TiktokenConfig, Option<serde_json::Value>)> {
71    let config_path = dir.join("tokenizer_config.json");
72    if !config_path.exists() {
73        return Ok((TiktokenConfig::default(), None));
74    }
75    let content = std::fs::read_to_string(&config_path)?;
76    let value: serde_json::Value = serde_json::from_str(&content)?;
77    let config = parse_tiktoken_config(&value);
78    Ok((config, Some(value)))
79}
80
81/// Parse `added_tokens_decoder` from config JSON.
82///
83/// Format: `{ "163584": { "content": "[BOS]", "special": true }, ... }`
84fn parse_added_tokens_decoder(config: &serde_json::Value) -> HashMap<String, TokenIdType> {
85    let mut tokens = HashMap::new();
86    if let Some(added) = config
87        .get("added_tokens_decoder")
88        .and_then(|v| v.as_object())
89    {
90        for (id_str, token_info) in added {
91            if let (Ok(id), Some(content)) = (
92                id_str.parse::<TokenIdType>(),
93                token_info.get("content").and_then(|v| v.as_str()),
94            ) {
95                tokens.insert(content.to_string(), id);
96            }
97        }
98    }
99    tokens
100}
101
102/// Extract named special tokens (bos, eos, unk, etc.) from config JSON.
103///
104/// Handles both string-valued tokens (`"bos_token": "<s>"`) and object-valued tokens
105/// (`"bos_token": {"content": "<s>", "lstrip": false, ...}`) found in some HuggingFace models.
106fn parse_special_tokens(config: &serde_json::Value) -> SpecialTokens {
107    let get_str = |key: &str| {
108        config.get(key).and_then(|v| {
109            v.as_str()
110                .map(String::from)
111                .or_else(|| v.get("content").and_then(|c| c.as_str()).map(String::from))
112        })
113    };
114
115    let additional: Vec<String> = config
116        .get("additional_special_tokens")
117        .and_then(|v| v.as_array())
118        .map(|arr| {
119            arr.iter()
120                .filter_map(|v| {
121                    v.as_str()
122                        .map(String::from)
123                        .or_else(|| v.get("content").and_then(|c| c.as_str()).map(String::from))
124                })
125                .collect()
126        })
127        .unwrap_or_default();
128
129    SpecialTokens {
130        bos_token: get_str("bos_token"),
131        eos_token: get_str("eos_token"),
132        unk_token: get_str("unk_token"),
133        sep_token: get_str("sep_token"),
134        pad_token: get_str("pad_token"),
135        cls_token: get_str("cls_token"),
136        mask_token: get_str("mask_token"),
137        additional_special_tokens: additional,
138    }
139}
140
141/// Tiktoken tokenizer wrapper β€” supports both built-in OpenAI encodings and hub-loaded models.
142pub struct TiktokenTokenizer {
143    tokenizer: CoreBPE,
144    special_tokens: SpecialTokens,
145    vocab: HashMap<String, TokenIdType>,
146    reverse_vocab: HashMap<TokenIdType, String>,
147    vocab_size: usize,
148    chat_template: ChatTemplateState,
149    eos_token_ids: Vec<TokenIdType>,
150    renderer: Renderer,
151}
152
/// Built-in OpenAI encodings bundled with `tiktoken-rs`.
///
/// Pass one to [`TiktokenTokenizer::new`], or let
/// [`TiktokenTokenizer::from_model_name`] choose it from a model name.
#[derive(Debug, Clone, Copy)]
pub enum TiktokenModel {
    /// `o200k_base`: GPT-4o, o1, o3, o4, GPT-4.5, GPT-5 (the 200k-vocab family).
    O200kBase,
    /// `cl100k_base`: GPT-4, GPT-3.5-turbo, text-embedding-ada-002.
    Cl100kBase,
    /// `p50k_base`: Codex models, text-davinci-002, text-davinci-003.
    P50kBase,
    /// `p50k_edit`: edit models such as text-davinci-edit-001 and code-davinci-edit-001.
    P50kEdit,
    /// `r50k_base`: GPT-3 models such as davinci.
    R50kBase,
}
167
168impl TiktokenTokenizer {
169    /// Create a new Tiktoken tokenizer for the specified built-in model
170    pub fn new(model: TiktokenModel) -> Result<Self> {
171        let tokenizer =
172            match model {
173                TiktokenModel::O200kBase => o200k_base()
174                    .map_err(|e| Error::msg(format!("Failed to load o200k_base: {e}")))?,
175                TiktokenModel::Cl100kBase => cl100k_base()
176                    .map_err(|e| Error::msg(format!("Failed to load cl100k_base: {e}")))?,
177                TiktokenModel::P50kBase => {
178                    p50k_base().map_err(|e| Error::msg(format!("Failed to load p50k_base: {e}")))?
179                }
180                TiktokenModel::P50kEdit => {
181                    p50k_edit().map_err(|e| Error::msg(format!("Failed to load p50k_edit: {e}")))?
182                }
183                TiktokenModel::R50kBase => {
184                    r50k_base().map_err(|e| Error::msg(format!("Failed to load r50k_base: {e}")))?
185                }
186            };
187
188        let special_tokens = Self::get_special_tokens_for_model(model);
189
190        let vocab_size = match model {
191            TiktokenModel::O200kBase => 200019,
192            TiktokenModel::Cl100kBase => 100256,
193            TiktokenModel::P50kBase | TiktokenModel::P50kEdit => 50281,
194            TiktokenModel::R50kBase => 50257,
195        };
196
197        Ok(TiktokenTokenizer {
198            tokenizer,
199            special_tokens,
200            vocab: HashMap::new(),
201            reverse_vocab: HashMap::new(),
202            vocab_size,
203            chat_template: ChatTemplateState::empty(),
204            eos_token_ids: Vec::new(), // No directory path in from_model
205            renderer: Renderer::Jinja,
206        })
207    }
208
209    /// Create from a directory containing tiktoken.model + tokenizer_config.json
210    pub fn from_dir(dir: &Path) -> Result<Self> {
211        Self::from_dir_with_chat_template(dir, None)
212    }
213
214    /// Create from a directory with an optional chat template file path.
215    /// Discovers the tiktoken model file automatically via `find_tiktoken_file`.
216    pub fn from_dir_with_chat_template(
217        dir: &Path,
218        chat_template_path: Option<&str>,
219    ) -> Result<Self> {
220        let tiktoken_path = find_tiktoken_file(dir)?;
221        Self::load_from_path(&tiktoken_path, chat_template_path)
222    }
223
224    /// Create from an exact tiktoken file path (`.tiktoken` or `tiktoken.model`).
225    /// Looks for `tokenizer_config.json` in the same directory.
226    pub fn from_file(tiktoken_path: &Path) -> Result<Self> {
227        Self::from_file_with_chat_template(tiktoken_path, None)
228    }
229
230    /// Create from an exact tiktoken file path with an optional chat template.
231    pub fn from_file_with_chat_template(
232        tiktoken_path: &Path,
233        chat_template_path: Option<&str>,
234    ) -> Result<Self> {
235        Self::load_from_path(tiktoken_path, chat_template_path)
236    }
237
238    /// Core loading logic shared by `from_dir` and `from_file` constructors.
239    fn load_from_path(tiktoken_path: &Path, chat_template_path: Option<&str>) -> Result<Self> {
240        // 1. Load BPE encoder from the exact file
241        let tiktoken_path_str = tiktoken_path
242            .to_str()
243            .ok_or_else(|| Error::msg("Tiktoken file path is not valid UTF-8"))?;
244        let encoder = load_tiktoken_bpe(tiktoken_path_str)?;
245
246        // 2. Parse tokenizer_config.json from the same directory
247        let dir = tiktoken_path
248            .parent()
249            .ok_or_else(|| Error::msg("Cannot determine parent directory of tiktoken file"))?;
250        let (mut config, tokenizer_config_value) = load_tiktoken_config_from_dir(dir)?;
251
252        // Kimi-K2/K2.5/K2.6 specialize the regex and pre-fill 256 reserved
253        // special-token slots starting at `len(mergeable_ranks)`; all other
254        // tiktoken models use the cl100k pattern unchanged. Reuse the
255        // already-parsed tokenizer_config.json so we don't re-read it.
256        let pattern = if kimi_k2_tokenizer::matches(tokenizer_config_value.as_ref(), dir) {
257            kimi_k2_tokenizer::apply_reserved_special_tokens(
258                &mut config.added_tokens,
259                encoder.len(),
260            );
261            kimi_k2_tokenizer::KIMI_K2_PATTERN
262        } else {
263            CL100K_BASE_PATTERN
264        };
265
266        // 3. Build special tokens encoder for CoreBPE (needs FxHashMap)
267        let special_tokens_encoder: FxHashMap<String, Rank> = config
268            .added_tokens
269            .iter()
270            .map(|(k, &v)| (k.clone(), v))
271            .collect();
272
273        // 4. Calculate true vocab size from max token ID (handles sparse/reserved IDs),
274        //    build string-based vocab maps (borrows encoder), then pass encoder by value to CoreBPE
275        let vocab_size = encoder
276            .values()
277            .copied()
278            .chain(special_tokens_encoder.values().copied())
279            .max()
280            .map(|id| id as usize + 1)
281            .unwrap_or(0);
282        let (vocab, reverse_vocab) = build_vocab_maps(&encoder, &config.added_tokens);
283        let tokenizer = CoreBPE::new(encoder, special_tokens_encoder, pattern)?;
284
285        // 5. Load chat template β€” propagate errors for explicit paths,
286        //    silently fall back for auto-discovery
287        let chat_template = if let Some(p) = chat_template_path {
288            load_chat_template_from_file(p)?
289        } else {
290            config.chat_template.or_else(|| {
291                discover_chat_template_in_dir(dir)
292                    .and_then(|p| load_chat_template_from_file(&p).ok().flatten())
293            })
294        };
295
296        // Load merged EOS token IDs from config.json + generation_config.json
297        let eos_token_ids = crate::eos::load_eos_token_ids(dir);
298
299        // Detect which chat-template renderer to use based on config.json::architectures
300        let renderer = detect_renderer_from_config(dir);
301
302        Ok(TiktokenTokenizer {
303            tokenizer,
304            special_tokens: config.special_tokens,
305            vocab,
306            reverse_vocab,
307            vocab_size,
308            chat_template: ChatTemplateState::new(chat_template)?,
309            eos_token_ids,
310            renderer,
311        })
312    }
313
314    /// Create a tokenizer from a model string (e.g., "gpt-4", "gpt-3.5-turbo")
315    pub fn from_model_name(model_name: &str) -> Result<Self> {
316        let bare = model_name.rsplit('/').next().unwrap_or(model_name);
317        let model = match get_tokenizer(bare) {
318            Some(Tokenizer::O200kBase) => TiktokenModel::O200kBase,
319            Some(Tokenizer::Cl100kBase) => TiktokenModel::Cl100kBase,
320            Some(Tokenizer::P50kBase) => TiktokenModel::P50kBase,
321            Some(Tokenizer::P50kEdit) => TiktokenModel::P50kEdit,
322            Some(Tokenizer::R50kBase) => TiktokenModel::R50kBase,
323            _ => return Err(anyhow::anyhow!(
324                "Unrecognized OpenAI model name: '{model_name}'. Expected GPT-3, GPT-3.5, GPT-4, GPT-4o, GPT-4.5, GPT-5, o1, o3, o4, or related model names"
325            )),
326        };
327        Self::new(model)
328    }
329
330    /// Get special tokens for a specific model
331    fn get_special_tokens_for_model(model: TiktokenModel) -> SpecialTokens {
332        match model {
333            TiktokenModel::Cl100kBase => SpecialTokens {
334                bos_token: Some("<|endoftext|>".to_string()),
335                eos_token: Some("<|endoftext|>".to_string()),
336                unk_token: None,
337                sep_token: None,
338                pad_token: Some("<|endoftext|>".to_string()),
339                cls_token: None,
340                mask_token: None,
341                additional_special_tokens: vec![
342                    "<|fim_prefix|>".to_string(),
343                    "<|fim_middle|>".to_string(),
344                    "<|fim_suffix|>".to_string(),
345                    "<|endofprompt|>".to_string(),
346                ],
347            },
348            _ => SpecialTokens {
349                bos_token: Some("<|endoftext|>".to_string()),
350                eos_token: Some("<|endoftext|>".to_string()),
351                unk_token: None,
352                sep_token: None,
353                pad_token: Some("<|endoftext|>".to_string()),
354                cls_token: None,
355                mask_token: None,
356                additional_special_tokens: vec![],
357            },
358        }
359    }
360}
361
362/// Parse a .tiktoken / tiktoken.model file into a BPE encoder.
363///
364/// Format: each line is `<base64-encoded-token-bytes> <rank>`
365fn load_tiktoken_bpe(path: &str) -> Result<FxHashMap<Vec<u8>, Rank>> {
366    let content = std::fs::read_to_string(path)?;
367    let mut encoder =
368        FxHashMap::with_capacity_and_hasher(content.lines().count(), Default::default());
369    for line in content.lines() {
370        if line.is_empty() {
371            continue;
372        }
373        let mut parts = line.split_whitespace();
374        let token_b64 = parts
375            .next()
376            .ok_or_else(|| Error::msg("missing token in tiktoken file"))?;
377        let rank_str = parts
378            .next()
379            .ok_or_else(|| Error::msg("missing rank in tiktoken file"))?;
380        let token_bytes = STANDARD.decode(token_b64)?;
381        let rank: Rank = rank_str.parse()?;
382        encoder.insert(token_bytes, rank);
383    }
384    Ok(encoder)
385}
386
387/// Build string-level vocab from byte-level encoder + added tokens.
388fn build_vocab_maps(
389    encoder: &FxHashMap<Vec<u8>, Rank>,
390    added_tokens: &HashMap<String, TokenIdType>,
391) -> (HashMap<String, TokenIdType>, HashMap<TokenIdType, String>) {
392    let capacity = encoder.len() + added_tokens.len();
393    let mut vocab = HashMap::with_capacity(capacity);
394    let mut reverse_vocab = HashMap::with_capacity(capacity);
395
396    // BPE tokens (only valid UTF-8 sequences get string entries)
397    for (token_bytes, &rank) in encoder {
398        if let Ok(token_str) = std::str::from_utf8(token_bytes) {
399            vocab.insert(token_str.to_string(), rank);
400            reverse_vocab.insert(rank, token_str.to_string());
401        }
402    }
403
404    // Special/added tokens (always valid UTF-8)
405    for (token_str, &id) in added_tokens {
406        vocab.insert(token_str.clone(), id);
407        reverse_vocab.insert(id, token_str.clone());
408    }
409
410    (vocab, reverse_vocab)
411}
412
413/// Find a tiktoken model file in the given directory.
414///
415/// Looks for `tiktoken.model` first, then any `*.tiktoken` file.
416fn find_tiktoken_file(dir: &Path) -> Result<PathBuf> {
417    let tiktoken_model = dir.join("tiktoken.model");
418    if tiktoken_model.exists() {
419        return Ok(tiktoken_model);
420    }
421
422    // Look for *.tiktoken files
423    if let Ok(entries) = std::fs::read_dir(dir) {
424        for entry in entries.flatten() {
425            if let Some(name) = entry.file_name().to_str() {
426                if name.ends_with(".tiktoken") {
427                    return Ok(entry.path());
428                }
429            }
430        }
431    }
432
433    Err(Error::msg(format!(
434        "No tiktoken model file found in '{}'",
435        dir.display()
436    )))
437}
438
/// Report whether `dir` holds a tiktoken model file: `tiktoken.model` or any
/// `*.tiktoken` entry.
///
/// An unreadable directory counts as "no" unless `tiktoken.model` is visible
/// through it.
pub fn has_tiktoken_file(dir: &Path) -> bool {
    let has_named_model = dir.join("tiktoken.model").exists();
    has_named_model
        || std::fs::read_dir(dir).is_ok_and(|listing| {
            listing.flatten().any(|entry| {
                entry
                    .file_name()
                    .to_str()
                    .is_some_and(|name| name.ends_with(".tiktoken"))
            })
        })
}
455
/// Decide from its file name alone whether `path` names a tiktoken model file:
/// exactly `tiktoken.model`, or anything ending in `.tiktoken`.
///
/// Paths without a (UTF-8) file name component are never tiktoken files.
pub fn is_tiktoken_file(path: &Path) -> bool {
    match path.file_name().and_then(std::ffi::OsStr::to_str) {
        Some("tiktoken.model") => true,
        Some(name) => name.ends_with(".tiktoken"),
        None => false,
    }
}
462
463impl Encoder for TiktokenTokenizer {
464    fn encode(&self, input: &str, _add_special_tokens: bool) -> Result<Encoding> {
465        // Always use encode_with_special_tokens so that special token strings
466        // in the input (e.g., <|media_pad|> from chat templates) are recognized
467        // as single tokens rather than split into BPE sub-tokens.
468        //
469        // NOTE: We intentionally ignore `add_special_tokens` here because the
470        // flag has different semantics across backends. For HuggingFace it
471        // controls BOS/EOS prepend/append (tiktoken has no such concept).
472        // For tiktoken, encode_ordinary vs encode_with_special_tokens controls
473        // whether special-token *patterns* in the input are recognized.
474        // All callers that encode chat-template-rendered text pass `false`
475        // (meaning "don't add BOS/EOS"), but tiktoken must still recognize
476        // the special tokens the template inserted. A proper fix requires
477        // redesigning the Encoder trait to separate "add wrapper tokens" from
478        // "recognize special-token patterns".
479        let tokens = self.tokenizer.encode_with_special_tokens(input);
480        Ok(Encoding::Tiktoken(tokens))
481    }
482
483    fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
484        inputs
485            .iter()
486            .map(|input| self.encode(input, add_special_tokens))
487            .collect()
488    }
489}
490
491impl Decoder for TiktokenTokenizer {
492    fn decode(&self, token_ids: &[TokenIdType], _skip_special_tokens: bool) -> Result<String> {
493        match self.tokenizer.decode(token_ids.to_vec()) {
494            Ok(text) => Ok(text),
495            Err(err) if is_unknown_tiktoken_decode_error(&err) => Err(Error::msg(format!(
496                "tiktoken decode failed for unknown token id: {err}"
497            ))),
498            Err(err) => {
499                // Fallback to lossy decoding for incomplete UTF-8 sequences
500                let bytes: Vec<u8> = self
501                    .tokenizer
502                    ._decode_native_and_split(token_ids.to_vec())
503                    .flatten()
504                    .collect();
505                tracing::warn!(
506                    error = %err,
507                    token_count = token_ids.len(),
508                    "tiktoken decode failed; returning lossy UTF-8 fallback"
509                );
510                Ok(String::from_utf8_lossy(&bytes).into_owned())
511            }
512        }
513    }
514}
515
516/// Detect tiktoken's "unknown token id" error so we can surface a clean error
517/// instead of letting the lossy-decode fallback panic on a missing key.
518///
519/// We match on the `Display` string because tiktoken-rs's `DecodeKeyError` lives
520/// in a private `vendor_tiktoken` module and isn't re-exported (as of 0.9.1),
521/// so a typed `downcast_ref` is not available. The message format is stable β€”
522/// see `vendor_tiktoken::DecodeKeyError::fmt` upstream.
523fn is_unknown_tiktoken_decode_error(err: &Error) -> bool {
524    err.to_string().starts_with("Invalid token for decoding:")
525}
526
527impl TokenizerTrait for TiktokenTokenizer {
528    fn vocab_size(&self) -> usize {
529        self.vocab_size
530    }
531
532    fn get_special_tokens(&self) -> &SpecialTokens {
533        &self.special_tokens
534    }
535
536    fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
537        self.vocab.get(token).copied()
538    }
539
540    fn id_to_token(&self, id: TokenIdType) -> Option<String> {
541        self.reverse_vocab.get(&id).cloned()
542    }
543
544    fn as_any(&self) -> &dyn std::any::Any {
545        self
546    }
547
548    fn apply_chat_template(
549        &self,
550        messages: &[serde_json::Value],
551        params: ChatTemplateParams,
552    ) -> Result<String> {
553        // Inject special tokens if the caller didn't provide them
554        let params = if params.special_tokens.is_some() {
555            params
556        } else {
557            ChatTemplateParams {
558                special_tokens: Some(&self.special_tokens),
559                ..params
560            }
561        };
562        match self.renderer {
563            Renderer::Jinja => self.chat_template.apply(messages, params),
564            Renderer::KimiK25Tools => apply_kimi_k25_tools(&self.chat_template, messages, params),
565        }
566    }
567
568    fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
569        self.chat_template.content_format()
570    }
571
572    fn thinking_toggle(&self) -> ThinkingToggle {
573        self.chat_template.thinking_toggle()
574    }
575
576    fn thinking_key_name(&self) -> Option<ThinkingKeyName> {
577        self.chat_template.thinking_key_name()
578    }
579    fn eos_token_ids(&self) -> &[TokenIdType] {
580        &self.eos_token_ids
581    }
582
583    fn think_in_prefill(&self) -> bool {
584        self.chat_template.think_in_prefill()
585    }
586
587    fn set_chat_template(&mut self, template: String) -> Result<()> {
588        self.chat_template.set(template)
589    }
590}
591
592// ---------------------------------------------------------------------------
593// Renderer detection (config.json::architectures)
594// ---------------------------------------------------------------------------
595/// Inspect the sibling `config.json` to decide which chat-template renderer to
596/// use. Missing / unreadable / malformed config falls back to `Renderer::Jinja`
597/// silently with a debug log, mirroring `huggingface.rs::detect_renderer_from_config`.
598fn detect_renderer_from_config(dir: &Path) -> Renderer {
599    let path = dir.join("config.json");
600    if !path.exists() {
601        return Renderer::Jinja;
602    }
603    let content = match std::fs::read_to_string(&path) {
604        Ok(c) => c,
605        Err(err) => {
606            tracing::debug!(?err, ?path, "config.json unreadable; using Jinja renderer");
607            return Renderer::Jinja;
608        }
609    };
610    let value: serde_json::Value = match serde_json::from_str(&content) {
611        Ok(v) => v,
612        Err(err) => {
613            tracing::debug!(?err, ?path, "config.json malformed; using Jinja renderer");
614            return Renderer::Jinja;
615        }
616    };
617    let is_kimi = value
618        .get("architectures")
619        .and_then(|v| v.as_array())
620        .is_some_and(|a| {
621            a.iter()
622                .any(|v| v.as_str() == Some("KimiK25ForConditionalGeneration"))
623        });
624    if is_kimi {
625        tracing::debug!(?path, "selected KimiK25Tools chat-template renderer");
626        return Renderer::KimiK25Tools;
627    }
628    Renderer::Jinja
629}
630
631#[cfg(test)]
632mod tests {
633    use super::*;
634    use crate::traits::{Decoder, Encoder, Tokenizer};
635
636    const MINIMAL_TIKTOKEN_MODEL: &str = "YQ== 0\nYg== 1\n";
637
638    fn write_minimal_tiktoken_dir(
639        tokenizer_config: &str,
640        model_config: Option<&str>,
641    ) -> tempfile::TempDir {
642        let dir = tempfile::tempdir().unwrap();
643        std::fs::write(dir.path().join("tiktoken.model"), MINIMAL_TIKTOKEN_MODEL).unwrap();
644        std::fs::write(dir.path().join("tokenizer_config.json"), tokenizer_config).unwrap();
645        if let Some(model_config) = model_config {
646            std::fs::write(dir.path().join("config.json"), model_config).unwrap();
647        }
648        dir
649    }
650
651    #[test]
652    fn test_tiktoken_creation() {
653        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
654        assert_eq!(tokenizer.vocab_size(), 100256);
655    }
656
657    #[test]
658    fn test_encode_decode() {
659        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
660
661        let text = "Hello, world!";
662        let encoding = tokenizer.encode(text, false).unwrap();
663
664        let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
665        assert_eq!(decoded, text);
666    }
667
668    #[test]
669    fn test_batch_encode() {
670        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
671
672        let texts = vec!["Hello", "World", "Test"];
673        let encodings = tokenizer.encode_batch(&texts, false).unwrap();
674
675        assert_eq!(encodings.len(), 3);
676        for (i, encoding) in encodings.iter().enumerate() {
677            let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
678            assert_eq!(decoded, texts[i]);
679        }
680    }
681
682    #[test]
683    fn test_special_tokens() {
684        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
685        let special_tokens = tokenizer.get_special_tokens();
686
687        assert!(special_tokens.eos_token.is_some());
688        assert_eq!(special_tokens.eos_token.as_ref().unwrap(), "<|endoftext|>");
689    }
690
691    #[test]
692    fn test_builtin_tokenizer_has_empty_vocab_maps() {
693        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
694        // Built-in path: vocab maps are empty, token_to_id returns None
695        assert_eq!(tokenizer.token_to_id("hello"), None);
696        assert_eq!(tokenizer.id_to_token(0), None);
697    }
698
699    #[test]
700    fn test_load_tiktoken_bpe() {
701        use std::io::Write;
702        let dir = tempfile::tempdir().unwrap();
703        let file_path = dir.path().join("test.tiktoken");
704        let mut f = std::fs::File::create(&file_path).unwrap();
705        // "IQ==" is base64 for byte 0x21 ('!'), rank 0
706        // "Ig==" is base64 for byte 0x22 ('"'), rank 1
707        writeln!(f, "IQ== 0").unwrap();
708        writeln!(f, "Ig== 1").unwrap();
709
710        let encoder = load_tiktoken_bpe(file_path.to_str().unwrap()).unwrap();
711        assert_eq!(encoder.len(), 2);
712        assert_eq!(encoder.get(&vec![0x21u8]), Some(&0));
713        assert_eq!(encoder.get(&vec![0x22u8]), Some(&1));
714    }
715
716    #[test]
717    fn test_build_vocab_maps() {
718        let mut encoder = FxHashMap::default();
719        encoder.insert(b"hello".to_vec(), 42u32);
720        encoder.insert(vec![0xFF, 0xFE], 99u32); // invalid UTF-8
721
722        let mut added = HashMap::new();
723        added.insert("<|special|>".to_string(), 1000u32);
724
725        let (vocab, reverse_vocab) = build_vocab_maps(&encoder, &added);
726
727        // Valid UTF-8 token present
728        assert_eq!(vocab.get("hello"), Some(&42));
729        assert_eq!(reverse_vocab.get(&42), Some(&"hello".to_string()));
730
731        // Invalid UTF-8 token excluded from vocab
732        assert!(!vocab.contains_key("\u{FFFD}")); // not lossy-inserted
733
734        // Added token present
735        assert_eq!(vocab.get("<|special|>"), Some(&1000));
736        assert_eq!(reverse_vocab.get(&1000), Some(&"<|special|>".to_string()));
737    }
738
739    #[test]
740    fn test_has_tiktoken_file() {
741        let dir = tempfile::tempdir().unwrap();
742        assert!(!has_tiktoken_file(dir.path()));
743
744        std::fs::write(dir.path().join("tiktoken.model"), "test").unwrap();
745        assert!(has_tiktoken_file(dir.path()));
746    }
747
748    #[test]
749    fn test_find_tiktoken_file_model() {
750        let dir = tempfile::tempdir().unwrap();
751        std::fs::write(dir.path().join("tiktoken.model"), "test").unwrap();
752        let found = find_tiktoken_file(dir.path()).unwrap();
753        assert_eq!(found.file_name().unwrap(), "tiktoken.model");
754    }
755
756    #[test]
757    fn test_find_tiktoken_file_extension() {
758        let dir = tempfile::tempdir().unwrap();
759        std::fs::write(dir.path().join("vocab.tiktoken"), "test").unwrap();
760        let found = find_tiktoken_file(dir.path()).unwrap();
761        assert!(found
762            .file_name()
763            .unwrap()
764            .to_str()
765            .unwrap()
766            .ends_with(".tiktoken"));
767    }
768
769    #[test]
770    fn test_is_tiktoken_file() {
771        assert!(is_tiktoken_file(Path::new("tiktoken.model")));
772        assert!(is_tiktoken_file(Path::new("vocab.tiktoken")));
773        assert!(!is_tiktoken_file(Path::new("tokenizer.json")));
774        assert!(!is_tiktoken_file(Path::new("model.bin")));
775    }
776
777    #[test]
778    fn test_parse_added_tokens_decoder() {
779        let config: serde_json::Value = serde_json::json!({
780            "added_tokens_decoder": {
781                "163584": { "content": "[BOS]", "special": true },
782                "163585": { "content": "[EOS]", "special": true },
783                "163586": { "content": "<|im_end|>", "special": true }
784            }
785        });
786        let tokens = parse_added_tokens_decoder(&config);
787        assert_eq!(tokens.get("[BOS]"), Some(&163584));
788        assert_eq!(tokens.get("[EOS]"), Some(&163585));
789        assert_eq!(tokens.get("<|im_end|>"), Some(&163586));
790    }
791
792    #[test]
793    fn test_tiktoken_unknown_token_decode_returns_error() {
794        let dir = write_minimal_tiktoken_dir(
795            r#"{
796                "added_tokens_decoder": {
797                    "2": { "content": "[BOS]", "special": true }
798                }
799            }"#,
800            None,
801        );
802        let tokenizer = TiktokenTokenizer::from_dir(dir.path()).unwrap();
803
804        let err = tokenizer.decode(&[4], false).unwrap_err();
805        assert!(
806            err.to_string()
807                .contains("tiktoken decode failed for unknown token id"),
808            "unexpected error: {err}"
809        );
810    }
811
812    #[test]
813    fn test_parse_special_tokens() {
814        let config: serde_json::Value = serde_json::json!({
815            "bos_token": "[BOS]",
816            "eos_token": "[EOS]",
817            "unk_token": "[UNK]",
818            "pad_token": "[PAD]",
819            "additional_special_tokens": ["<|im_end|>", "<|im_user|>"]
820        });
821        let special = parse_special_tokens(&config);
822        assert_eq!(special.bos_token.as_deref(), Some("[BOS]"));
823        assert_eq!(special.eos_token.as_deref(), Some("[EOS]"));
824        assert_eq!(special.unk_token.as_deref(), Some("[UNK]"));
825        assert_eq!(special.pad_token.as_deref(), Some("[PAD]"));
826        assert_eq!(special.additional_special_tokens.len(), 2);
827    }
828
829    #[test]
830    fn test_parse_special_tokens_object_valued() {
831        let config: serde_json::Value = serde_json::json!({
832            "bos_token": {"content": "<s>", "lstrip": false, "rstrip": false, "single_word": false, "special": true},
833            "eos_token": "</s>",
834            "unk_token": {"content": "<unk>", "special": true}
835        });
836        let special = parse_special_tokens(&config);
837        assert_eq!(special.bos_token.as_deref(), Some("<s>"));
838        assert_eq!(special.eos_token.as_deref(), Some("</s>"));
839        assert_eq!(special.unk_token.as_deref(), Some("<unk>"));
840    }
841
842    #[test]
843    fn test_tiktoken_config_default() {
844        let config = TiktokenConfig::default();
845        assert!(config.special_tokens.bos_token.is_none());
846        assert!(config.added_tokens.is_empty());
847        assert!(config.chat_template.is_none());
848    }
849
850    #[test]
851    fn test_load_tiktoken_config_from_dir_missing_file() {
852        let dir = tempfile::tempdir().unwrap();
853        let (config, value) = load_tiktoken_config_from_dir(dir.path()).unwrap();
854        assert!(value.is_none());
855        assert!(config.added_tokens.is_empty());
856    }
857
858    #[test]
859    fn test_decode_lossy_fallback_for_invalid_utf8() {
860        // cl100k_base maps individual bytes to token IDs via its byte-level BPE.
861        // Encode a multi-byte UTF-8 character, then decode only a prefix of its
862        // tokens so the raw bytes form an incomplete (invalid) UTF-8 sequence.
863        // The old implementation would return an error; the new one should fall
864        // back to lossy decoding and produce U+FFFD replacement characters.
865        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
866
867        // "πŸ˜€" is U+1F600, encoded as 4 UTF-8 bytes: [0xF0, 0x9F, 0x98, 0x80].
868        // With cl100k_base this encodes to multiple tokens. Taking a strict
869        // subset of those tokens gives bytes that aren't valid UTF-8.
870        let full_encoding = tokenizer.encode("πŸ˜€", false).unwrap();
871        let full_ids = full_encoding.token_ids();
872        assert!(
873            full_ids.len() > 1,
874            "emoji should encode to multiple tokens in cl100k_base"
875        );
876
877        // Take only the first token β€” its raw bytes are an incomplete UTF-8 prefix.
878        let partial_ids = &full_ids[..1];
879        let result = tokenizer.decode(partial_ids, false);
880        assert!(
881            result.is_ok(),
882            "decode of partial UTF-8 should succeed via lossy fallback"
883        );
884        let decoded = result.unwrap();
885        assert!(
886            decoded.contains('\u{FFFD}') || decoded.is_empty(),
887            "lossy decode should contain replacement char or be empty, got: {decoded:?}"
888        );
889    }
890
891    #[test]
892    fn test_decode_valid_utf8_does_not_use_fallback() {
893        // Ensure that valid UTF-8 round-trips through the happy path unchanged.
894        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
895        let text = "Hello, δΈ–η•Œ!";
896        let encoding = tokenizer.encode(text, false).unwrap();
897        let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
898        assert_eq!(decoded, text);
899    }
900
901    #[test]
902    fn test_encode_recognizes_special_tokens_in_input() {
903        // encode_with_special_tokens must recognize special token strings
904        // so that chat-template-rendered text (containing e.g. <|endoftext|>)
905        // produces single token IDs, not BPE sub-tokens.
906        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
907        // <|endoftext|> is token 100257 in cl100k_base
908        // Note: add_special_tokens is intentionally ignored for tiktoken
909        // (see Encoder impl comment), so both true and false produce the same result.
910        let encoding = tokenizer.encode("hello<|endoftext|>world", false).unwrap();
911        let ids = encoding.token_ids();
912        assert!(
913            ids.contains(&100257),
914            "Special token <|endoftext|> should be recognized as single token, got: {ids:?}"
915        );
916    }
917}