Skip to main content

llm_tokenizer/
huggingface.rs

1use std::collections::HashMap;
2
3use anyhow::{Error, Result};
4use tokenizers::{
5    processors::template::TemplateProcessing,
6    tokenizer::{step_decode_stream, Tokenizer as HfTokenizer},
7};
8use tracing::debug;
9
10use crate::{
11    chat_template::{
12        load_chat_template_from_file, ChatTemplateContentFormat, ChatTemplateParams,
13        ChatTemplateState, ThinkingKeyName, ThinkingToggle,
14    },
15    traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait},
16};
17
18/// HuggingFace tokenizer wrapper
19pub struct HuggingFaceTokenizer {
20    tokenizer: HfTokenizer,
21    special_tokens: SpecialTokens,
22    vocab: HashMap<String, TokenIdType>,
23    reverse_vocab: HashMap<TokenIdType, String>,
24    chat_template: ChatTemplateState,
25}
26
27impl HuggingFaceTokenizer {
28    /// Create a tokenizer from a HuggingFace tokenizer JSON file
29    pub fn from_file(file_path: &str) -> Result<Self> {
30        // Try to auto-discover chat template if not explicitly provided
31        let path = std::path::Path::new(file_path);
32        let chat_template_path = path
33            .parent()
34            .and_then(crate::factory::discover_chat_template_in_dir);
35        Self::from_file_with_chat_template(file_path, chat_template_path.as_deref())
36    }
37
38    /// Create a tokenizer from a HuggingFace tokenizer JSON file with an optional chat template
39    pub fn from_file_with_chat_template(
40        file_path: &str,
41        chat_template_path: Option<&str>,
42    ) -> Result<Self> {
43        let mut tokenizer = HfTokenizer::from_file(file_path)
44            .map_err(|e| Error::msg(format!("Failed to load tokenizer: {e}")))?;
45
46        // Build vocab mappings (include special tokens to get added_tokens like <|im_start|>)
47        let vocab = tokenizer.get_vocab(true); // true = include special tokens and added_tokens
48        let reverse_vocab: HashMap<TokenIdType, String> = vocab
49            .iter()
50            .map(|(token, &id)| (id, token.clone()))
51            .collect();
52
53        // Load tokenizer_config.json once for chat template, add_bos/eos, and special tokens
54        let config_result = Self::load_chat_template_and_config(file_path);
55        let mut chat_template_str = config_result.chat_template;
56        let add_bos_token = config_result.add_bos_token;
57        let add_eos_token = config_result.add_eos_token;
58
59        // Extract special tokens — config values override vocab pattern matching
60        let special_tokens = Self::extract_special_tokens(&tokenizer, &config_result.config_tokens);
61
62        if let Some(template_path) = chat_template_path {
63            chat_template_str = load_chat_template_from_file(template_path)?;
64        }
65
66        // Configure post_processor based on tokenizer_config.json (matches Python transformers)
67        // Only modify when at least one setting is explicitly true
68        let needs_eos = add_eos_token == Some(true);
69        let needs_bos = match add_bos_token {
70            Some(true) => true,
71            Some(false) => false,
72            // Not set: preserve existing behavior from tokenizer.json
73            None => needs_eos && Self::tokenizer_adds_special_tokens(&tokenizer),
74        };
75
76        if needs_bos || needs_eos {
77            if let Some(post_processor) =
78                Self::build_post_processor(needs_bos, needs_eos, &special_tokens, &vocab)
79            {
80                debug!(needs_bos, needs_eos, "Configured post_processor");
81                tokenizer.with_post_processor(Some(post_processor));
82            }
83        }
84
85        Ok(HuggingFaceTokenizer {
86            tokenizer,
87            special_tokens,
88            vocab,
89            reverse_vocab,
90            chat_template: ChatTemplateState::new(chat_template_str)?,
91        })
92    }
93
94    /// Check if the tokenizer's post_processor adds special tokens (e.g., BOS)
95    fn tokenizer_adds_special_tokens(tokenizer: &HfTokenizer) -> bool {
96        tokenizer
97            .encode("", true)
98            .map(|enc| !enc.get_ids().is_empty())
99            .unwrap_or(false)
100    }
101
102    /// Build a TemplateProcessing post_processor (matches Python transformers' update_post_processor)
103    /// Template format: "{bos}:0 $A:0 {eos}:0" with optional BOS/EOS based on config
104    fn build_post_processor(
105        add_bos_token: bool,
106        add_eos_token: bool,
107        special_tokens: &SpecialTokens,
108        vocab: &HashMap<String, TokenIdType>,
109    ) -> Option<TemplateProcessing> {
110        // Build template string exactly like Python:
111        // single = f"{(bos + ':0 ') if add_bos_token else ''}$A:0{(' ' + eos + ':0') if add_eos_token else ''}"
112        let mut template = String::with_capacity(32);
113        let mut tokens = Vec::with_capacity(2);
114
115        if add_bos_token {
116            let bos = special_tokens.bos_token.as_ref()?;
117            let bos_id = vocab.get(bos).copied()?;
118            template.push_str(bos);
119            template.push_str(":0 ");
120            tokens.push((bos.clone(), bos_id));
121        }
122
123        template.push_str("$A:0");
124
125        if add_eos_token {
126            let eos = special_tokens.eos_token.as_ref()?;
127            let eos_id = vocab.get(eos).copied()?;
128            template.push(' ');
129            template.push_str(eos);
130            template.push_str(":0");
131            tokens.push((eos.clone(), eos_id));
132        }
133
134        TemplateProcessing::builder()
135            .try_single(template.as_str())
136            .ok()?
137            .special_tokens(tokens)
138            .build()
139            .ok()
140    }
141
142    /// Create from an existing HuggingFace tokenizer
143    pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
144        let special_tokens = Self::extract_special_tokens(&tokenizer, &ConfigTokens::default());
145        let vocab = tokenizer.get_vocab(true); // true = include special tokens and added_tokens
146        let reverse_vocab: HashMap<TokenIdType, String> = vocab
147            .iter()
148            .map(|(token, &id)| (id, token.clone()))
149            .collect();
150
151        HuggingFaceTokenizer {
152            tokenizer,
153            special_tokens,
154            vocab,
155            reverse_vocab,
156            chat_template: ChatTemplateState::empty(),
157        }
158    }
159
160    /// Extract special tokens from the tokenizer, using config values when available.
161    ///
162    /// Prefers explicit values from `tokenizer_config.json` (e.g., `"bos_token": "<|begin_of_text|>"`)
163    /// over pattern matching against the vocabulary, since models like Llama 4 use non-standard
164    /// token names that aren't in the hardcoded pattern list.
165    fn extract_special_tokens(
166        tokenizer: &HfTokenizer,
167        config_tokens: &ConfigTokens,
168    ) -> SpecialTokens {
169        // Get vocab with special tokens included (added_tokens like <|im_start|>)
170        let vocab = tokenizer.get_vocab(true);
171
172        let find_token = |patterns: &[&str]| -> Option<String> {
173            for pattern in patterns {
174                if vocab.contains_key(*pattern) {
175                    return Some((*pattern).to_string());
176                }
177            }
178            None
179        };
180
181        // Extract additional special tokens using the tokenizers library API
182        let additional_special_tokens: Vec<String> = tokenizer
183            .get_added_tokens_decoder()
184            .iter()
185            .filter(|(_id, token)| token.special)
186            .map(|(_id, token)| token.content.clone())
187            .collect();
188
189        // Config values take priority over pattern matching
190        SpecialTokens {
191            bos_token: config_tokens
192                .bos_token
193                .clone()
194                .or_else(|| find_token(&["<s>", "<|startoftext|>", "<BOS>", "[CLS]"])),
195            eos_token: config_tokens
196                .eos_token
197                .clone()
198                .or_else(|| find_token(&["</s>", "<|endoftext|>", "<EOS>", "[SEP]"])),
199            unk_token: config_tokens
200                .unk_token
201                .clone()
202                .or_else(|| find_token(&["<unk>", "<UNK>", "[UNK]"])),
203            sep_token: find_token(&["[SEP]", "<sep>", "<SEP>"]),
204            pad_token: config_tokens
205                .pad_token
206                .clone()
207                .or_else(|| find_token(&["<pad>", "<PAD>", "[PAD]"])),
208            cls_token: find_token(&["[CLS]", "<cls>", "<CLS>"]),
209            mask_token: find_token(&["[MASK]", "<mask>", "<MASK>"]),
210            additional_special_tokens,
211        }
212    }
213
214    /// Load chat template, special token settings, and token strings from tokenizer_config.json.
215    /// Reads the file once and extracts everything needed by the tokenizer constructor.
216    fn load_chat_template_and_config(tokenizer_path: &str) -> TokenizerConfigResult {
217        (|| {
218            let path = std::path::Path::new(tokenizer_path);
219            let config_path = path.parent()?.join("tokenizer_config.json");
220
221            if !config_path.exists() {
222                return None;
223            }
224
225            let content = std::fs::read_to_string(&config_path).ok()?;
226            let config: serde_json::Value = serde_json::from_str(&content).ok()?;
227
228            // Extract chat template directly from parsed config (avoid re-reading the file)
229            let chat_template = config
230                .get("chat_template")
231                .and_then(|v| v.as_str())
232                .map(String::from);
233
234            let add_bos_token = config.get("add_bos_token").and_then(|v| v.as_bool());
235            let add_eos_token = config.get("add_eos_token").and_then(|v| v.as_bool());
236
237            // Extract special token strings (handles both "string" and {"content": "string"})
238            let get_token = |key: &str| -> Option<String> {
239                config.get(key).and_then(|v| {
240                    v.as_str()
241                        .map(String::from)
242                        .or_else(|| v.get("content").and_then(|c| c.as_str()).map(String::from))
243                })
244            };
245
246            let config_tokens = ConfigTokens {
247                bos_token: get_token("bos_token"),
248                eos_token: get_token("eos_token"),
249                unk_token: get_token("unk_token"),
250                pad_token: get_token("pad_token"),
251            };
252
253            Some(TokenizerConfigResult {
254                chat_template,
255                add_bos_token,
256                add_eos_token,
257                config_tokens,
258            })
259        })()
260        .unwrap_or_default()
261    }
262}
263
/// Special token strings as declared in `tokenizer_config.json`.
///
/// Every field is `None` when the config does not provide that token; the
/// `Default` value therefore represents "no overrides".
#[derive(Default)]
struct ConfigTokens {
    /// Beginning-of-sequence token (`"bos_token"`).
    bos_token: Option<String>,
    /// End-of-sequence token (`"eos_token"`).
    eos_token: Option<String>,
    /// Unknown-token placeholder (`"unk_token"`).
    unk_token: Option<String>,
    /// Padding token (`"pad_token"`).
    pad_token: Option<String>,
}
272
273/// Result of parsing tokenizer_config.json.
274#[derive(Default)]
275struct TokenizerConfigResult {
276    chat_template: Option<String>,
277    add_bos_token: Option<bool>,
278    add_eos_token: Option<bool>,
279    config_tokens: ConfigTokens,
280}
281
282impl Encoder for HuggingFaceTokenizer {
283    fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding> {
284        self.tokenizer
285            .encode(input, add_special_tokens)
286            .map_err(|e| Error::msg(format!("Encoding failed: {e}")))
287            .map(|encoding| Encoding::Hf(Box::new(encoding)))
288    }
289
290    fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
291        self.tokenizer
292            .encode_batch(inputs.to_vec(), add_special_tokens)
293            .map_err(|e| Error::msg(format!("Batch encoding failed: {e}")))
294            .map(|encodings| {
295                encodings
296                    .into_iter()
297                    .map(|e| Encoding::Hf(Box::new(e)))
298                    .collect()
299            })
300    }
301}
302
303impl Decoder for HuggingFaceTokenizer {
304    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
305        self.tokenizer
306            .decode(token_ids, skip_special_tokens)
307            .map_err(|e| Error::msg(format!("Decoding failed: {e}")))
308    }
309
310    /// Native incremental decode using the HF `step_decode_stream`.
311    ///
312    /// This delegates to the same algorithm the default trait method uses, but
313    /// the two internal `decode()` calls go directly through the concrete
314    /// `TokenizerImpl` rather than through `dyn Decoder` vtable dispatch.
315    fn decode_step(
316        &self,
317        token_id: TokenIdType,
318        ids: &mut Vec<TokenIdType>,
319        prefix: &mut String,
320        prefix_index: &mut usize,
321        skip_special_tokens: bool,
322    ) -> Result<Option<String>> {
323        step_decode_stream(
324            &self.tokenizer,
325            vec![token_id],
326            skip_special_tokens,
327            ids,
328            prefix,
329            prefix_index,
330        )
331        .map_err(|e| Error::msg(format!("Decode stream error: {e}")))
332    }
333}
334
335impl TokenizerTrait for HuggingFaceTokenizer {
336    fn vocab_size(&self) -> usize {
337        self.tokenizer.get_vocab_size(false)
338    }
339
340    fn get_special_tokens(&self) -> &SpecialTokens {
341        &self.special_tokens
342    }
343
344    fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
345        self.vocab.get(token).copied()
346    }
347
348    fn id_to_token(&self, id: TokenIdType) -> Option<String> {
349        self.reverse_vocab.get(&id).cloned()
350    }
351
352    fn as_any(&self) -> &dyn std::any::Any {
353        self
354    }
355
356    fn apply_chat_template(
357        &self,
358        messages: &[serde_json::Value],
359        params: ChatTemplateParams,
360    ) -> Result<String> {
361        // Inject special tokens if the caller didn't provide them
362        if params.special_tokens.is_some() {
363            return self.chat_template.apply(messages, params);
364        }
365        let params = ChatTemplateParams {
366            special_tokens: Some(&self.special_tokens),
367            ..params
368        };
369        self.chat_template.apply(messages, params)
370    }
371
372    fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
373        self.chat_template.content_format()
374    }
375
376    fn thinking_toggle(&self) -> ThinkingToggle {
377        self.chat_template.thinking_toggle()
378    }
379
380    fn thinking_key_name(&self) -> Option<ThinkingKeyName> {
381        self.chat_template.thinking_key_name()
382    }
383    fn think_in_prefill(&self) -> bool {
384        self.chat_template.think_in_prefill()
385    }
386
387    fn set_chat_template(&mut self, template: String) -> Result<()> {
388        self.chat_template.set(template)
389    }
390}