llm_tokenizer/
huggingface.rs

1use std::collections::HashMap;
2
3use anyhow::{Error, Result};
4use tokenizers::{processors::template::TemplateProcessing, tokenizer::Tokenizer as HfTokenizer};
5use tracing::debug;
6
7use crate::{
8    chat_template::{
9        detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams,
10        ChatTemplateProcessor,
11    },
12    traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait},
13};
14
15/// HuggingFace tokenizer wrapper
16pub struct HuggingFaceTokenizer {
17    tokenizer: HfTokenizer,
18    special_tokens: SpecialTokens,
19    vocab: HashMap<String, TokenIdType>,
20    reverse_vocab: HashMap<TokenIdType, String>,
21    chat_template: Option<String>,
22    /// Detected chat template content format (computed once at initialization)
23    content_format: ChatTemplateContentFormat,
24}
25
26impl HuggingFaceTokenizer {
27    /// Create a tokenizer from a HuggingFace tokenizer JSON file
28    pub fn from_file(file_path: &str) -> Result<Self> {
29        // Try to auto-discover chat template if not explicitly provided
30        let path = std::path::Path::new(file_path);
31        let chat_template_path = path
32            .parent()
33            .and_then(crate::factory::discover_chat_template_in_dir);
34        Self::from_file_with_chat_template(file_path, chat_template_path.as_deref())
35    }
36
37    /// Create a tokenizer from a HuggingFace tokenizer JSON file with an optional chat template
38    pub fn from_file_with_chat_template(
39        file_path: &str,
40        chat_template_path: Option<&str>,
41    ) -> Result<Self> {
42        let mut tokenizer = HfTokenizer::from_file(file_path)
43            .map_err(|e| Error::msg(format!("Failed to load tokenizer: {}", e)))?;
44
45        // Extract special tokens
46        let special_tokens = Self::extract_special_tokens(&tokenizer);
47
48        // Build vocab mappings (include special tokens to get added_tokens like <|im_start|>)
49        let vocab = tokenizer.get_vocab(true); // true = include special tokens and added_tokens
50        let reverse_vocab: HashMap<TokenIdType, String> = vocab
51            .iter()
52            .map(|(token, &id)| (id, token.clone()))
53            .collect();
54
55        // Load chat template and tokenizer config
56        let (chat_template, add_bos_token, add_eos_token) =
57            if let Some(template_path) = chat_template_path {
58                // Load from specified .jinja file
59                (
60                    Self::load_chat_template_from_file(template_path)?,
61                    None,
62                    None,
63                )
64            } else {
65                // Try to load from tokenizer_config.json
66                Self::load_chat_template_and_config(file_path)
67            };
68
69        // Detect content format once at initialization
70        let content_format = if let Some(ref template) = chat_template {
71            detect_chat_template_content_format(template)
72        } else {
73            ChatTemplateContentFormat::String // Default if no template
74        };
75
76        // Configure post_processor based on tokenizer_config.json (matches Python transformers)
77        // Only modify when at least one setting is explicitly true
78        let needs_eos = add_eos_token == Some(true);
79        let needs_bos = match add_bos_token {
80            Some(true) => true,
81            Some(false) => false,
82            // Not set: preserve existing behavior from tokenizer.json
83            None => needs_eos && Self::tokenizer_adds_special_tokens(&tokenizer),
84        };
85
86        if needs_bos || needs_eos {
87            if let Some(post_processor) =
88                Self::build_post_processor(needs_bos, needs_eos, &special_tokens, &vocab)
89            {
90                debug!(needs_bos, needs_eos, "Configured post_processor");
91                tokenizer.with_post_processor(Some(post_processor));
92            }
93        }
94
95        Ok(HuggingFaceTokenizer {
96            tokenizer,
97            special_tokens,
98            vocab,
99            reverse_vocab,
100            chat_template,
101            content_format,
102        })
103    }
104
105    /// Check if the tokenizer's post_processor adds special tokens (e.g., BOS)
106    fn tokenizer_adds_special_tokens(tokenizer: &HfTokenizer) -> bool {
107        tokenizer
108            .encode("", true)
109            .map(|enc| !enc.get_ids().is_empty())
110            .unwrap_or(false)
111    }
112
113    /// Build a TemplateProcessing post_processor (matches Python transformers' update_post_processor)
114    /// Template format: "{bos}:0 $A:0 {eos}:0" with optional BOS/EOS based on config
115    fn build_post_processor(
116        add_bos_token: bool,
117        add_eos_token: bool,
118        special_tokens: &SpecialTokens,
119        vocab: &HashMap<String, TokenIdType>,
120    ) -> Option<TemplateProcessing> {
121        // Build template string exactly like Python:
122        // single = f"{(bos + ':0 ') if add_bos_token else ''}$A:0{(' ' + eos + ':0') if add_eos_token else ''}"
123        let mut template = String::with_capacity(32);
124        let mut tokens = Vec::with_capacity(2);
125
126        if add_bos_token {
127            let bos = special_tokens.bos_token.as_ref()?;
128            let bos_id = vocab.get(bos).copied()?;
129            template.push_str(bos);
130            template.push_str(":0 ");
131            tokens.push((bos.clone(), bos_id));
132        }
133
134        template.push_str("$A:0");
135
136        if add_eos_token {
137            let eos = special_tokens.eos_token.as_ref()?;
138            let eos_id = vocab.get(eos).copied()?;
139            template.push(' ');
140            template.push_str(eos);
141            template.push_str(":0");
142            tokens.push((eos.clone(), eos_id));
143        }
144
145        TemplateProcessing::builder()
146            .try_single(template.as_str())
147            .ok()?
148            .special_tokens(tokens)
149            .build()
150            .ok()
151    }
152
153    /// Create from an existing HuggingFace tokenizer
154    pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
155        let special_tokens = Self::extract_special_tokens(&tokenizer);
156        let vocab = tokenizer.get_vocab(true); // true = include special tokens and added_tokens
157        let reverse_vocab: HashMap<TokenIdType, String> = vocab
158            .iter()
159            .map(|(token, &id)| (id, token.clone()))
160            .collect();
161
162        HuggingFaceTokenizer {
163            tokenizer,
164            special_tokens,
165            vocab,
166            reverse_vocab,
167            chat_template: None,
168            content_format: ChatTemplateContentFormat::String, // Default
169        }
170    }
171
172    /// Extract special tokens from the tokenizer
173    fn extract_special_tokens(tokenizer: &HfTokenizer) -> SpecialTokens {
174        // Get vocab with special tokens included (added_tokens like <|im_start|>)
175        let vocab = tokenizer.get_vocab(true);
176
177        let find_token = |patterns: &[&str]| -> Option<String> {
178            for pattern in patterns {
179                if vocab.contains_key(*pattern) {
180                    return Some(pattern.to_string());
181                }
182            }
183            None
184        };
185
186        // Extract additional special tokens using the tokenizers library API
187        let additional_special_tokens: Vec<String> = tokenizer
188            .get_added_tokens_decoder()
189            .iter()
190            .filter(|(_id, token)| token.special) // Only tokens marked as special: true
191            .map(|(_id, token)| token.content.clone())
192            .collect();
193
194        SpecialTokens {
195            bos_token: find_token(&["<s>", "<|startoftext|>", "<BOS>", "[CLS]"]),
196            eos_token: find_token(&["</s>", "<|endoftext|>", "<EOS>", "[SEP]"]),
197            unk_token: find_token(&["<unk>", "<UNK>", "[UNK]"]),
198            sep_token: find_token(&["[SEP]", "<sep>", "<SEP>"]),
199            pad_token: find_token(&["<pad>", "<PAD>", "[PAD]"]),
200            cls_token: find_token(&["[CLS]", "<cls>", "<CLS>"]),
201            mask_token: find_token(&["[MASK]", "<mask>", "<MASK>"]),
202            additional_special_tokens,
203        }
204    }
205
206    /// Load chat template and special token settings from tokenizer_config.json
207    /// Returns Option<bool> to distinguish between explicit false vs not set
208    fn load_chat_template_and_config(
209        tokenizer_path: &str,
210    ) -> (Option<String>, Option<bool>, Option<bool>) {
211        (|| {
212            let path = std::path::Path::new(tokenizer_path);
213            let config_path = path.parent()?.join("tokenizer_config.json");
214
215            if !config_path.exists() {
216                return None;
217            }
218
219            let config_str = config_path.to_str()?;
220            let content = std::fs::read_to_string(&config_path).ok()?;
221            let config: serde_json::Value = serde_json::from_str(&content).ok()?;
222
223            let chat_template = super::chat_template::load_chat_template_from_config(config_str)
224                .ok()
225                .flatten();
226
227            let add_bos_token = config.get("add_bos_token").and_then(|v| v.as_bool());
228            let add_eos_token = config.get("add_eos_token").and_then(|v| v.as_bool());
229
230            Some((chat_template, add_bos_token, add_eos_token))
231        })()
232        .unwrap_or((None, None, None))
233    }
234
235    /// Load chat template from a file (.jinja or .json containing Jinja)
236    fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
237        use std::fs;
238
239        let content = fs::read_to_string(template_path)
240            .map_err(|e| Error::msg(format!("Failed to read chat template file: {}", e)))?;
241
242        // Check if it's a JSON file containing a Jinja template
243        if template_path.ends_with(".json") {
244            // Parse JSON and extract the template string
245            let json_value: serde_json::Value = serde_json::from_str(&content)
246                .map_err(|e| Error::msg(format!("Failed to parse chat_template.json: {}", e)))?;
247
248            if let Some(template_str) = json_value.as_str() {
249                return Ok(Some(template_str.to_string()));
250            } else if let Some(obj) = json_value.as_object() {
251                if let Some(template_value) = obj.get("chat_template") {
252                    if let Some(template_str) = template_value.as_str() {
253                        return Ok(Some(template_str.to_string()));
254                    }
255                }
256            }
257
258            return Err(Error::msg(
259                "chat_template.json does not contain a valid template",
260            ));
261        }
262
263        // Otherwise it's a plain .jinja file
264        // Clean up the template (similar to Python implementation)
265        let template = content.trim().replace("\\n", "\n");
266
267        Ok(Some(template))
268    }
269
270    /// Set or override the chat template
271    pub fn set_chat_template(&mut self, template: String) {
272        // Detect format for the new template
273        self.content_format = detect_chat_template_content_format(&template);
274        self.chat_template = Some(template);
275    }
276
277    /// Get the content format expected by the chat template
278    pub fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
279        self.content_format
280    }
281
282    /// Apply chat template if available
283    ///
284    /// Takes transformed JSON Values (already transformed based on content format)
285    pub fn apply_chat_template(
286        &self,
287        messages: &[serde_json::Value],
288        params: ChatTemplateParams,
289    ) -> Result<String> {
290        if let Some(ref template) = self.chat_template {
291            let processor = ChatTemplateProcessor::new(template.clone());
292            processor.apply_chat_template(messages, params)
293        } else {
294            Err(Error::msg(
295                "Cannot use chat template functions because tokenizer.chat_template is not set and no template \
296                argument was passed! For information about writing templates and setting the \
297                tokenizer.chat_template attribute, please see the documentation at \
298                https://huggingface.co/docs/transformers/main/en/chat_templating"
299            ))
300        }
301    }
302}
303
304impl Encoder for HuggingFaceTokenizer {
305    fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding> {
306        self.tokenizer
307            .encode(input, add_special_tokens)
308            .map_err(|e| Error::msg(format!("Encoding failed: {}", e)))
309            .map(|encoding| Encoding::Hf(Box::new(encoding)))
310    }
311
312    fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
313        self.tokenizer
314            .encode_batch(inputs.to_vec(), add_special_tokens)
315            .map_err(|e| Error::msg(format!("Batch encoding failed: {}", e)))
316            .map(|encodings| {
317                encodings
318                    .into_iter()
319                    .map(|e| Encoding::Hf(Box::new(e)))
320                    .collect()
321            })
322    }
323}
324
325impl Decoder for HuggingFaceTokenizer {
326    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
327        self.tokenizer
328            .decode(token_ids, skip_special_tokens)
329            .map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
330    }
331}
332
333impl TokenizerTrait for HuggingFaceTokenizer {
334    fn vocab_size(&self) -> usize {
335        self.tokenizer.get_vocab_size(false)
336    }
337
338    fn get_special_tokens(&self) -> &SpecialTokens {
339        &self.special_tokens
340    }
341
342    fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
343        self.vocab.get(token).copied()
344    }
345
346    fn id_to_token(&self, id: TokenIdType) -> Option<String> {
347        self.reverse_vocab.get(&id).cloned()
348    }
349
350    fn as_any(&self) -> &dyn std::any::Any {
351        self
352    }
353}
354
#[cfg(test)]
mod tests {
    // Unit tests for the helpers that do not need a real tokenizer.json.
    // End-to-end tests that load a real tokenizer belong in integration tests.

    use super::*;
    use tokenizers::tokenizer::PostProcessor;

    /// SpecialTokens with only BOS/EOS set.
    fn specials(bos: Option<&str>, eos: Option<&str>) -> SpecialTokens {
        SpecialTokens {
            bos_token: bos.map(str::to_string),
            eos_token: eos.map(str::to_string),
            unk_token: None,
            sep_token: None,
            pad_token: None,
            cls_token: None,
            mask_token: None,
            additional_special_tokens: Vec::new(),
        }
    }

    fn make_vocab(entries: &[(&str, TokenIdType)]) -> HashMap<String, TokenIdType> {
        entries
            .iter()
            .map(|&(token, id)| (token.to_string(), id))
            .collect()
    }

    /// Write `contents` to a per-process temp file and return its path.
    fn write_temp(name: &str, contents: &str) -> std::path::PathBuf {
        let path = std::env::temp_dir().join(format!(
            "hf_tokenizer_test_{}_{}",
            std::process::id(),
            name
        ));
        std::fs::write(&path, contents).expect("temp dir should be writable");
        path
    }

    #[test]
    fn post_processor_adds_requested_tokens() {
        let tokens = specials(Some("<s>"), Some("</s>"));
        let vocab = make_vocab(&[("<s>", 1), ("</s>", 2)]);

        let both = HuggingFaceTokenizer::build_post_processor(true, true, &tokens, &vocab)
            .expect("BOS and EOS resolve");
        assert_eq!(both.added_tokens(false), 2);

        let bos_only = HuggingFaceTokenizer::build_post_processor(true, false, &tokens, &vocab)
            .expect("BOS resolves");
        assert_eq!(bos_only.added_tokens(false), 1);

        let eos_only = HuggingFaceTokenizer::build_post_processor(false, true, &tokens, &vocab)
            .expect("EOS resolves");
        assert_eq!(eos_only.added_tokens(false), 1);
    }

    #[test]
    fn post_processor_skipped_when_token_unresolved() {
        let vocab = make_vocab(&[("<s>", 1)]);

        // BOS requested but no BOS token was detected.
        assert!(
            HuggingFaceTokenizer::build_post_processor(true, false, &specials(None, None), &vocab)
                .is_none()
        );

        // EOS detected by name but missing from the vocabulary.
        let tokens = specials(Some("<s>"), Some("</s>"));
        assert!(HuggingFaceTokenizer::build_post_processor(false, true, &tokens, &vocab).is_none());
    }

    #[test]
    fn json_chat_template_accepts_string_or_object() {
        let as_string = write_temp("string.json", r#""{{ messages }}""#);
        let as_object = write_temp("object.json", r#"{"chat_template": "{{ messages }}"}"#);
        let invalid = write_temp("invalid.json", r#"{"other": 1}"#);

        for path in [&as_string, &as_object] {
            let loaded =
                HuggingFaceTokenizer::load_chat_template_from_file(path.to_str().unwrap())
                    .expect("valid template json");
            assert_eq!(loaded.as_deref(), Some("{{ messages }}"));
        }
        assert!(
            HuggingFaceTokenizer::load_chat_template_from_file(invalid.to_str().unwrap()).is_err()
        );

        for path in [as_string, as_object, invalid] {
            let _ = std::fs::remove_file(path);
        }
    }

    #[test]
    fn jinja_chat_template_is_trimmed_and_unescaped() {
        let path = write_temp("template.jinja", "  a\\nb \n");
        let loaded = HuggingFaceTokenizer::load_chat_template_from_file(path.to_str().unwrap())
            .expect("readable template");
        assert_eq!(loaded.as_deref(), Some("a\nb"));
        let _ = std::fs::remove_file(path);
    }
}