Skip to main content

llm_tokenizer/
huggingface.rs

1use std::collections::HashMap;
2
3use anyhow::{Error, Result};
4use tokenizers::{
5    processors::template::TemplateProcessing,
6    tokenizer::{step_decode_stream, Tokenizer as HfTokenizer},
7};
8use tracing::debug;
9
10use crate::{
11    chat_template::{
12        load_chat_template_from_file, ChatTemplateContentFormat, ChatTemplateParams,
13        ChatTemplateState, ThinkingKeyName, ThinkingToggle,
14    },
15    encoders::{deepseek_v32, deepseek_v4},
16    traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait},
17};
18
/// Which implementation renders chat templates for a loaded model.
///
/// Chosen at construction time: when loading from a file it is derived from
/// the sibling `config.json::architectures` (see `detect_renderer_from_config`);
/// otherwise, and for any unrecognised architecture, `Jinja` is used.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Renderer {
    /// Render with the model's Jinja chat template.
    Jinja,
    /// Render with the native DeepSeek V3.2 encoder (`encoders::deepseek_v32`).
    DeepseekV32,
    /// Render with the native DeepSeek V4 encoder (`encoders::deepseek_v4`).
    DeepseekV4,
}
25
26/// HuggingFace tokenizer wrapper
27pub struct HuggingFaceTokenizer {
28    tokenizer: HfTokenizer,
29    special_tokens: SpecialTokens,
30    vocab: HashMap<String, TokenIdType>,
31    reverse_vocab: HashMap<TokenIdType, String>,
32    chat_template: ChatTemplateState,
33    /// EOS token IDs from config.json + generation_config.json
34    eos_token_ids: Vec<TokenIdType>,
35    /// Which renderer applies chat templates for this model.
36    renderer: Renderer,
37}
38
39impl HuggingFaceTokenizer {
40    /// Create a tokenizer from a HuggingFace tokenizer JSON file
41    pub fn from_file(file_path: &str) -> Result<Self> {
42        // Try to auto-discover chat template if not explicitly provided
43        let path = std::path::Path::new(file_path);
44        let chat_template_path = path
45            .parent()
46            .and_then(crate::factory::discover_chat_template_in_dir);
47        Self::from_file_with_chat_template(file_path, chat_template_path.as_deref())
48    }
49
50    /// Create a tokenizer from a HuggingFace tokenizer JSON file with an optional chat template
51    pub fn from_file_with_chat_template(
52        file_path: &str,
53        chat_template_path: Option<&str>,
54    ) -> Result<Self> {
55        let mut tokenizer = HfTokenizer::from_file(file_path)
56            .map_err(|e| Error::msg(format!("Failed to load tokenizer: {e}")))?;
57
58        // Build vocab mappings (include special tokens to get added_tokens like <|im_start|>)
59        let vocab = tokenizer.get_vocab(true); // true = include special tokens and added_tokens
60        let reverse_vocab: HashMap<TokenIdType, String> = vocab
61            .iter()
62            .map(|(token, &id)| (id, token.clone()))
63            .collect();
64
65        // Load tokenizer_config.json once for chat template, add_bos/eos, and special tokens
66        let config_result = Self::load_chat_template_and_config(file_path);
67        let mut chat_template_str = config_result.chat_template;
68        let add_bos_token = config_result.add_bos_token;
69        let add_eos_token = config_result.add_eos_token;
70
71        // Extract special tokens — config values override vocab pattern matching
72        let special_tokens = Self::extract_special_tokens(&tokenizer, &config_result.config_tokens);
73
74        if let Some(template_path) = chat_template_path {
75            chat_template_str = load_chat_template_from_file(template_path)?;
76        }
77
78        // Configure post_processor based on tokenizer_config.json (matches Python transformers)
79        // Only modify when at least one setting is explicitly true
80        let needs_eos = add_eos_token == Some(true);
81        let needs_bos = match add_bos_token {
82            Some(true) => true,
83            Some(false) => false,
84            // Not set: preserve existing behavior from tokenizer.json
85            None => needs_eos && Self::tokenizer_adds_special_tokens(&tokenizer),
86        };
87
88        if needs_bos || needs_eos {
89            if let Some(post_processor) =
90                Self::build_post_processor(needs_bos, needs_eos, &special_tokens, &vocab)
91            {
92                debug!(needs_bos, needs_eos, "Configured post_processor");
93                tokenizer.with_post_processor(Some(post_processor));
94            }
95        }
96
97        // Load merged EOS token IDs from config.json + generation_config.json
98        let eos_token_ids = std::path::Path::new(file_path)
99            .parent()
100            .map(crate::eos::load_eos_token_ids)
101            .unwrap_or_default();
102
103        // Detect a custom Python-encoder model from config.json::architectures.
104        let renderer = std::path::Path::new(file_path)
105            .parent()
106            .map(detect_renderer_from_config)
107            .unwrap_or(Renderer::Jinja);
108
109        Ok(HuggingFaceTokenizer {
110            tokenizer,
111            special_tokens,
112            vocab,
113            reverse_vocab,
114            chat_template: ChatTemplateState::new(chat_template_str)?,
115            eos_token_ids,
116            renderer,
117        })
118    }
119
120    /// Check if the tokenizer's post_processor adds special tokens (e.g., BOS)
121    fn tokenizer_adds_special_tokens(tokenizer: &HfTokenizer) -> bool {
122        tokenizer
123            .encode("", true)
124            .map(|enc| !enc.get_ids().is_empty())
125            .unwrap_or(false)
126    }
127
128    /// Build a TemplateProcessing post_processor (matches Python transformers' update_post_processor)
129    /// Template format: "{bos}:0 $A:0 {eos}:0" with optional BOS/EOS based on config
130    fn build_post_processor(
131        add_bos_token: bool,
132        add_eos_token: bool,
133        special_tokens: &SpecialTokens,
134        vocab: &HashMap<String, TokenIdType>,
135    ) -> Option<TemplateProcessing> {
136        // Build template string exactly like Python:
137        // single = f"{(bos + ':0 ') if add_bos_token else ''}$A:0{(' ' + eos + ':0') if add_eos_token else ''}"
138        let mut template = String::with_capacity(32);
139        let mut tokens = Vec::with_capacity(2);
140
141        if add_bos_token {
142            let bos = special_tokens.bos_token.as_ref()?;
143            let bos_id = vocab.get(bos).copied()?;
144            template.push_str(bos);
145            template.push_str(":0 ");
146            tokens.push((bos.clone(), bos_id));
147        }
148
149        template.push_str("$A:0");
150
151        if add_eos_token {
152            let eos = special_tokens.eos_token.as_ref()?;
153            let eos_id = vocab.get(eos).copied()?;
154            template.push(' ');
155            template.push_str(eos);
156            template.push_str(":0");
157            tokens.push((eos.clone(), eos_id));
158        }
159
160        TemplateProcessing::builder()
161            .try_single(template.as_str())
162            .ok()?
163            .special_tokens(tokens)
164            .build()
165            .ok()
166    }
167
168    /// Create from an existing HuggingFace tokenizer
169    pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
170        let special_tokens = Self::extract_special_tokens(&tokenizer, &ConfigTokens::default());
171        let vocab = tokenizer.get_vocab(true); // true = include special tokens and added_tokens
172        let reverse_vocab: HashMap<TokenIdType, String> = vocab
173            .iter()
174            .map(|(token, &id)| (id, token.clone()))
175            .collect();
176
177        HuggingFaceTokenizer {
178            tokenizer,
179            special_tokens,
180            vocab,
181            reverse_vocab,
182            chat_template: ChatTemplateState::empty(),
183            eos_token_ids: Vec::new(), // No directory path in from_tokenizer
184            renderer: Renderer::Jinja,
185        }
186    }
187
188    /// Extract special tokens from the tokenizer, using config values when available.
189    ///
190    /// Prefers explicit values from `tokenizer_config.json` (e.g., `"bos_token": "<|begin_of_text|>"`)
191    /// over pattern matching against the vocabulary, since models like Llama 4 use non-standard
192    /// token names that aren't in the hardcoded pattern list.
193    fn extract_special_tokens(
194        tokenizer: &HfTokenizer,
195        config_tokens: &ConfigTokens,
196    ) -> SpecialTokens {
197        // Get vocab with special tokens included (added_tokens like <|im_start|>)
198        let vocab = tokenizer.get_vocab(true);
199
200        let find_token = |patterns: &[&str]| -> Option<String> {
201            for pattern in patterns {
202                if vocab.contains_key(*pattern) {
203                    return Some((*pattern).to_string());
204                }
205            }
206            None
207        };
208
209        // Extract additional special tokens using the tokenizers library API
210        let additional_special_tokens: Vec<String> = tokenizer
211            .get_added_tokens_decoder()
212            .iter()
213            .filter(|(_id, token)| token.special)
214            .map(|(_id, token)| token.content.clone())
215            .collect();
216
217        // Config values take priority over pattern matching
218        SpecialTokens {
219            bos_token: config_tokens
220                .bos_token
221                .clone()
222                .or_else(|| find_token(&["<s>", "<|startoftext|>", "<BOS>", "[CLS]"])),
223            eos_token: config_tokens
224                .eos_token
225                .clone()
226                .or_else(|| find_token(&["</s>", "<|endoftext|>", "<EOS>", "[SEP]"])),
227            unk_token: config_tokens
228                .unk_token
229                .clone()
230                .or_else(|| find_token(&["<unk>", "<UNK>", "[UNK]"])),
231            sep_token: find_token(&["[SEP]", "<sep>", "<SEP>"]),
232            pad_token: config_tokens
233                .pad_token
234                .clone()
235                .or_else(|| find_token(&["<pad>", "<PAD>", "[PAD]"])),
236            cls_token: find_token(&["[CLS]", "<cls>", "<CLS>"]),
237            mask_token: find_token(&["[MASK]", "<mask>", "<MASK>"]),
238            additional_special_tokens,
239        }
240    }
241
242    /// Load chat template, special token settings, and token strings from tokenizer_config.json.
243    /// Reads the file once and extracts everything needed by the tokenizer constructor.
244    fn load_chat_template_and_config(tokenizer_path: &str) -> TokenizerConfigResult {
245        (|| {
246            let path = std::path::Path::new(tokenizer_path);
247            let config_path = path.parent()?.join("tokenizer_config.json");
248
249            if !config_path.exists() {
250                return None;
251            }
252
253            let content = std::fs::read_to_string(&config_path).ok()?;
254            let config: serde_json::Value = serde_json::from_str(&content).ok()?;
255
256            // Extract chat template directly from parsed config (avoid re-reading the file)
257            let chat_template = config
258                .get("chat_template")
259                .and_then(|v| v.as_str())
260                .map(String::from);
261
262            let add_bos_token = config.get("add_bos_token").and_then(|v| v.as_bool());
263            let add_eos_token = config.get("add_eos_token").and_then(|v| v.as_bool());
264
265            // Extract special token strings (handles both "string" and {"content": "string"})
266            let get_token = |key: &str| -> Option<String> {
267                config.get(key).and_then(|v| {
268                    v.as_str()
269                        .map(String::from)
270                        .or_else(|| v.get("content").and_then(|c| c.as_str()).map(String::from))
271                })
272            };
273
274            let config_tokens = ConfigTokens {
275                bos_token: get_token("bos_token"),
276                eos_token: get_token("eos_token"),
277                unk_token: get_token("unk_token"),
278                pad_token: get_token("pad_token"),
279            };
280
281            Some(TokenizerConfigResult {
282                chat_template,
283                add_bos_token,
284                add_eos_token,
285                config_tokens,
286            })
287        })()
288        .unwrap_or_default()
289    }
290}
291
/// Special-token strings declared in `tokenizer_config.json`.
///
/// A field stays `None` when its key is absent, or when the value is neither a
/// string nor an object with a string `content` member.
#[derive(Default)]
struct ConfigTokens {
    /// `pad_token` entry.
    pad_token: Option<String>,
    /// `unk_token` entry.
    unk_token: Option<String>,
    /// `eos_token` entry.
    eos_token: Option<String>,
    /// `bos_token` entry.
    bos_token: Option<String>,
}
300
301/// Result of parsing tokenizer_config.json.
302#[derive(Default)]
303struct TokenizerConfigResult {
304    chat_template: Option<String>,
305    add_bos_token: Option<bool>,
306    add_eos_token: Option<bool>,
307    config_tokens: ConfigTokens,
308}
309
310impl Encoder for HuggingFaceTokenizer {
311    fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding> {
312        self.tokenizer
313            .encode(input, add_special_tokens)
314            .map_err(|e| Error::msg(format!("Encoding failed: {e}")))
315            .map(|encoding| Encoding::Hf(Box::new(encoding)))
316    }
317
318    fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
319        self.tokenizer
320            .encode_batch(inputs.to_vec(), add_special_tokens)
321            .map_err(|e| Error::msg(format!("Batch encoding failed: {e}")))
322            .map(|encodings| {
323                encodings
324                    .into_iter()
325                    .map(|e| Encoding::Hf(Box::new(e)))
326                    .collect()
327            })
328    }
329}
330
331impl Decoder for HuggingFaceTokenizer {
332    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
333        self.tokenizer
334            .decode(token_ids, skip_special_tokens)
335            .map_err(|e| Error::msg(format!("Decoding failed: {e}")))
336    }
337
338    /// Native incremental decode using the HF `step_decode_stream`.
339    ///
340    /// This delegates to the same algorithm the default trait method uses, but
341    /// the two internal `decode()` calls go directly through the concrete
342    /// `TokenizerImpl` rather than through `dyn Decoder` vtable dispatch.
343    fn decode_step(
344        &self,
345        token_id: TokenIdType,
346        ids: &mut Vec<TokenIdType>,
347        prefix: &mut String,
348        prefix_index: &mut usize,
349        skip_special_tokens: bool,
350    ) -> Result<Option<String>> {
351        step_decode_stream(
352            &self.tokenizer,
353            vec![token_id],
354            skip_special_tokens,
355            ids,
356            prefix,
357            prefix_index,
358        )
359        .map_err(|e| Error::msg(format!("Decode stream error: {e}")))
360    }
361}
362
363impl TokenizerTrait for HuggingFaceTokenizer {
364    fn vocab_size(&self) -> usize {
365        self.tokenizer.get_vocab_size(false)
366    }
367
368    fn get_special_tokens(&self) -> &SpecialTokens {
369        &self.special_tokens
370    }
371
372    fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
373        self.vocab.get(token).copied()
374    }
375
376    fn id_to_token(&self, id: TokenIdType) -> Option<String> {
377        self.reverse_vocab.get(&id).cloned()
378    }
379
380    fn as_any(&self) -> &dyn std::any::Any {
381        self
382    }
383
384    fn eos_token_ids(&self) -> &[TokenIdType] {
385        &self.eos_token_ids
386    }
387
388    fn apply_chat_template(
389        &self,
390        messages: &[serde_json::Value],
391        params: ChatTemplateParams,
392    ) -> Result<String> {
393        match self.renderer {
394            Renderer::Jinja => {
395                // Inject special tokens if the caller didn't provide them.
396                if params.special_tokens.is_some() {
397                    return self.chat_template.apply(messages, params);
398                }
399                let params = ChatTemplateParams {
400                    special_tokens: Some(&self.special_tokens),
401                    ..params
402                };
403                self.chat_template.apply(messages, params)
404            }
405            Renderer::DeepseekV32 => apply_deepseek_v32(messages, &params),
406            Renderer::DeepseekV4 => apply_deepseek_v4(messages, &params),
407        }
408    }
409
410    fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
411        self.chat_template.content_format()
412    }
413
414    fn thinking_toggle(&self) -> ThinkingToggle {
415        match self.renderer {
416            // DeepSeek V3.2 and V4 encoders gate thinking on the `thinking`
417            // kwarg, default off. The Jinja processor has no knowledge of
418            // the native encoder so we must report it directly.
419            Renderer::DeepseekV32 | Renderer::DeepseekV4 => ThinkingToggle::DefaultOff,
420            Renderer::Jinja => self.chat_template.thinking_toggle(),
421        }
422    }
423
424    fn thinking_key_name(&self) -> Option<ThinkingKeyName> {
425        match self.renderer {
426            Renderer::DeepseekV32 | Renderer::DeepseekV4 => Some(ThinkingKeyName::Thinking),
427            Renderer::Jinja => self.chat_template.thinking_key_name(),
428        }
429    }
430    fn think_in_prefill(&self) -> bool {
431        match self.renderer {
432            // Both encoders emit `<|Assistant|><think>` at the end of the
433            // prompt when thinking mode is on; the completion therefore starts
434            // mid-reasoning and the parser must be told so.
435            Renderer::DeepseekV32 | Renderer::DeepseekV4 => true,
436            Renderer::Jinja => self.chat_template.think_in_prefill(),
437        }
438    }
439
440    fn set_chat_template(&mut self, template: String) -> Result<()> {
441        self.chat_template.set(template)
442    }
443}
444
445// ---------------------------------------------------------------------------
446// Renderer detection (config.json::architectures)
447// ---------------------------------------------------------------------------
448/// Inspect the sibling `config.json` to decide which chat-template renderer to
449/// use. A missing or malformed file falls back to [`Renderer::Jinja`] without
450/// erroring (debug-logged), preserving backward compatibility for every model
451/// not in the architecture list.
452fn detect_renderer_from_config(dir: &std::path::Path) -> Renderer {
453    let path = dir.join("config.json");
454    if !path.exists() {
455        return Renderer::Jinja;
456    }
457    let content = match std::fs::read_to_string(&path) {
458        Ok(c) => c,
459        Err(err) => {
460            debug!(?err, ?path, "config.json unreadable; using Jinja renderer");
461            return Renderer::Jinja;
462        }
463    };
464    let value: serde_json::Value = match serde_json::from_str(&content) {
465        Ok(v) => v,
466        Err(err) => {
467            debug!(?err, ?path, "config.json malformed; using Jinja renderer");
468            return Renderer::Jinja;
469        }
470    };
471    let architectures = value.get("architectures").and_then(|v| v.as_array());
472    let arch_strs: Vec<&str> = architectures
473        .map(|a| a.iter().filter_map(|v| v.as_str()).collect())
474        .unwrap_or_default();
475    if arch_strs.contains(&"DeepseekV32ForCausalLM") {
476        debug!(?path, "selected DeepseekV32 chat-template renderer");
477        return Renderer::DeepseekV32;
478    }
479    if arch_strs.contains(&"DeepseekV4ForCausalLM") {
480        debug!(?path, "selected DeepseekV4 chat-template renderer");
481        return Renderer::DeepseekV4;
482    }
483    Renderer::Jinja
484}
485
486// ---------------------------------------------------------------------------
487// DeepSeek V3.2 / V4 dispatch shims
488// ---------------------------------------------------------------------------
489/// Derive the V3.2 / V4 thinking mode from `template_kwargs`. Only the
490/// `thinking` key is honored, matching sglang's DeepSeek serving path and
491/// the `ThinkingKeyName::Thinking` contract reported by this tokenizer.
492fn derive_thinking_mode(params: &ChatTemplateParams) -> deepseek_v32::ThinkingMode {
493    let enabled = params
494        .template_kwargs
495        .and_then(|k| k.get("thinking"))
496        .and_then(serde_json::Value::as_bool)
497        .unwrap_or(false);
498    if enabled {
499        deepseek_v32::ThinkingMode::Thinking
500    } else {
501        deepseek_v32::ThinkingMode::Chat
502    }
503}
504
505/// Per DeepSeek's encoding README, preserve all reasoning when a system or
506/// developer message declares `tools`; otherwise drop earlier reasoning.
507fn resolve_drop_thinking(messages: &[serde_json::Value]) -> bool {
508    !messages.iter().any(|m| {
509        let role = m.get("role").and_then(|r| r.as_str());
510        matches!(role, Some("system" | "developer"))
511            && m.get("tools")
512                .and_then(|t| t.as_array())
513                .is_some_and(|arr| !arr.is_empty())
514    })
515}
516/// Attach `tools` to a leading system/developer message so the V3.2/V4
517/// encoder can render the tools block. Mirrors the wrapper step in
518/// vllm's `vllm/tokenizers/deepseek_v32.py` and sglang's V4 serving path.
519/// Returns `None` when no rewrite is needed so callers can pass the input
520/// slice directly in the common path.
521fn inject_tools_into_messages(
522    messages: &[serde_json::Value],
523    tools: Option<&[serde_json::Value]>,
524) -> Option<Vec<serde_json::Value>> {
525    let tools = tools?;
526    if tools.is_empty() {
527        return None;
528    }
529    let mut owned: Vec<serde_json::Value> = messages.to_vec();
530    let first_role = owned
531        .first()
532        .and_then(|m| m.get("role"))
533        .and_then(|r| r.as_str());
534    if !matches!(first_role, Some("system" | "developer")) {
535        owned.insert(0, serde_json::json!({ "role": "system", "content": "" }));
536    }
537    if let Some(obj) = owned[0].as_object_mut() {
538        obj.insert("tools".into(), serde_json::Value::Array(tools.to_vec()));
539    }
540    Some(owned)
541}
542
543fn apply_deepseek_v32(
544    messages: &[serde_json::Value],
545    params: &ChatTemplateParams,
546) -> Result<String> {
547    let owned = inject_tools_into_messages(messages, params.tools);
548    let msgs: &[serde_json::Value] = owned.as_deref().unwrap_or(messages);
549    let thinking_mode = derive_thinking_mode(params);
550    let encode_params = deepseek_v32::EncodeParams {
551        add_default_bos_token: true,
552        drop_thinking: resolve_drop_thinking(msgs),
553    };
554    deepseek_v32::encode_messages(msgs, thinking_mode, &encode_params)
555        .map_err(|e| Error::msg(format!("DeepSeek V3.2 encode failed: {e}")))
556}
557fn apply_deepseek_v4(
558    messages: &[serde_json::Value],
559    params: &ChatTemplateParams,
560) -> Result<String> {
561    let owned = inject_tools_into_messages(messages, params.tools);
562    let msgs: &[serde_json::Value] = owned.as_deref().unwrap_or(messages);
563    let thinking_mode = derive_thinking_mode(params);
564    let reasoning_effort = params
565        .template_kwargs
566        .and_then(|k| k.get("reasoning_effort"))
567        .and_then(|v| v.as_str())
568        .and_then(|s| match s {
569            "max" => Some(deepseek_v4::ReasoningEffort::Max),
570            "high" => Some(deepseek_v4::ReasoningEffort::High),
571            _ => None,
572        });
573    let encode_params = deepseek_v4::EncodeParams {
574        add_default_bos_token: true,
575        drop_thinking: resolve_drop_thinking(msgs),
576        reasoning_effort,
577    };
578    deepseek_v4::encode_messages(msgs, thinking_mode, &encode_params)
579        .map_err(|e| Error::msg(format!("DeepSeek V4 encode failed: {e}")))
580}