Skip to main content

llm_tokenizer/
huggingface.rs

1use std::collections::HashMap;
2
3use anyhow::{Error, Result};
4use tokenizers::{processors::template::TemplateProcessing, tokenizer::Tokenizer as HfTokenizer};
5use tracing::debug;
6
7use crate::{
8    chat_template::{
9        load_chat_template_from_file, ChatTemplateContentFormat, ChatTemplateParams,
10        ChatTemplateState,
11    },
12    traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait},
13};
14
/// HuggingFace tokenizer wrapper
///
/// Bundles the underlying `tokenizers` tokenizer with lookup tables and
/// chat-template state that are derived once at construction time.
pub struct HuggingFaceTokenizer {
    /// Underlying HuggingFace tokenizer used for encoding and decoding.
    tokenizer: HfTokenizer,
    /// Well-known special tokens (BOS, EOS, ...) detected for this tokenizer.
    special_tokens: SpecialTokens,
    /// Token text -> token id, built with added/special tokens included.
    vocab: HashMap<String, TokenIdType>,
    /// Token id -> token text; the inverse of `vocab`.
    reverse_vocab: HashMap<TokenIdType, String>,
    /// Chat template state backing the chat-template related trait methods.
    chat_template: ChatTemplateState,
}
23
24impl HuggingFaceTokenizer {
25    /// Create a tokenizer from a HuggingFace tokenizer JSON file
26    pub fn from_file(file_path: &str) -> Result<Self> {
27        // Try to auto-discover chat template if not explicitly provided
28        let path = std::path::Path::new(file_path);
29        let chat_template_path = path
30            .parent()
31            .and_then(crate::factory::discover_chat_template_in_dir);
32        Self::from_file_with_chat_template(file_path, chat_template_path.as_deref())
33    }
34
35    /// Create a tokenizer from a HuggingFace tokenizer JSON file with an optional chat template
36    pub fn from_file_with_chat_template(
37        file_path: &str,
38        chat_template_path: Option<&str>,
39    ) -> Result<Self> {
40        let mut tokenizer = HfTokenizer::from_file(file_path)
41            .map_err(|e| Error::msg(format!("Failed to load tokenizer: {e}")))?;
42
43        // Extract special tokens
44        let special_tokens = Self::extract_special_tokens(&tokenizer);
45
46        // Build vocab mappings (include special tokens to get added_tokens like <|im_start|>)
47        let vocab = tokenizer.get_vocab(true); // true = include special tokens and added_tokens
48        let reverse_vocab: HashMap<TokenIdType, String> = vocab
49            .iter()
50            .map(|(token, &id)| (id, token.clone()))
51            .collect();
52
53        // Always load tokenizer_config.json for add_bos_token/add_eos_token,
54        // then override only the chat template string when an explicit path is provided.
55        let (mut chat_template_str, add_bos_token, add_eos_token) =
56            Self::load_chat_template_and_config(file_path);
57
58        if let Some(template_path) = chat_template_path {
59            chat_template_str = load_chat_template_from_file(template_path)?;
60        }
61
62        // Configure post_processor based on tokenizer_config.json (matches Python transformers)
63        // Only modify when at least one setting is explicitly true
64        let needs_eos = add_eos_token == Some(true);
65        let needs_bos = match add_bos_token {
66            Some(true) => true,
67            Some(false) => false,
68            // Not set: preserve existing behavior from tokenizer.json
69            None => needs_eos && Self::tokenizer_adds_special_tokens(&tokenizer),
70        };
71
72        if needs_bos || needs_eos {
73            if let Some(post_processor) =
74                Self::build_post_processor(needs_bos, needs_eos, &special_tokens, &vocab)
75            {
76                debug!(needs_bos, needs_eos, "Configured post_processor");
77                tokenizer.with_post_processor(Some(post_processor));
78            }
79        }
80
81        Ok(HuggingFaceTokenizer {
82            tokenizer,
83            special_tokens,
84            vocab,
85            reverse_vocab,
86            chat_template: ChatTemplateState::new(chat_template_str)?,
87        })
88    }
89
90    /// Check if the tokenizer's post_processor adds special tokens (e.g., BOS)
91    fn tokenizer_adds_special_tokens(tokenizer: &HfTokenizer) -> bool {
92        tokenizer
93            .encode("", true)
94            .map(|enc| !enc.get_ids().is_empty())
95            .unwrap_or(false)
96    }
97
98    /// Build a TemplateProcessing post_processor (matches Python transformers' update_post_processor)
99    /// Template format: "{bos}:0 $A:0 {eos}:0" with optional BOS/EOS based on config
100    fn build_post_processor(
101        add_bos_token: bool,
102        add_eos_token: bool,
103        special_tokens: &SpecialTokens,
104        vocab: &HashMap<String, TokenIdType>,
105    ) -> Option<TemplateProcessing> {
106        // Build template string exactly like Python:
107        // single = f"{(bos + ':0 ') if add_bos_token else ''}$A:0{(' ' + eos + ':0') if add_eos_token else ''}"
108        let mut template = String::with_capacity(32);
109        let mut tokens = Vec::with_capacity(2);
110
111        if add_bos_token {
112            let bos = special_tokens.bos_token.as_ref()?;
113            let bos_id = vocab.get(bos).copied()?;
114            template.push_str(bos);
115            template.push_str(":0 ");
116            tokens.push((bos.clone(), bos_id));
117        }
118
119        template.push_str("$A:0");
120
121        if add_eos_token {
122            let eos = special_tokens.eos_token.as_ref()?;
123            let eos_id = vocab.get(eos).copied()?;
124            template.push(' ');
125            template.push_str(eos);
126            template.push_str(":0");
127            tokens.push((eos.clone(), eos_id));
128        }
129
130        TemplateProcessing::builder()
131            .try_single(template.as_str())
132            .ok()?
133            .special_tokens(tokens)
134            .build()
135            .ok()
136    }
137
138    /// Create from an existing HuggingFace tokenizer
139    pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
140        let special_tokens = Self::extract_special_tokens(&tokenizer);
141        let vocab = tokenizer.get_vocab(true); // true = include special tokens and added_tokens
142        let reverse_vocab: HashMap<TokenIdType, String> = vocab
143            .iter()
144            .map(|(token, &id)| (id, token.clone()))
145            .collect();
146
147        HuggingFaceTokenizer {
148            tokenizer,
149            special_tokens,
150            vocab,
151            reverse_vocab,
152            chat_template: ChatTemplateState::empty(),
153        }
154    }
155
156    /// Extract special tokens from the tokenizer
157    fn extract_special_tokens(tokenizer: &HfTokenizer) -> SpecialTokens {
158        // Get vocab with special tokens included (added_tokens like <|im_start|>)
159        let vocab = tokenizer.get_vocab(true);
160
161        let find_token = |patterns: &[&str]| -> Option<String> {
162            for pattern in patterns {
163                if vocab.contains_key(*pattern) {
164                    return Some((*pattern).to_string());
165                }
166            }
167            None
168        };
169
170        // Extract additional special tokens using the tokenizers library API
171        let additional_special_tokens: Vec<String> = tokenizer
172            .get_added_tokens_decoder()
173            .iter()
174            .filter(|(_id, token)| token.special) // Only tokens marked as special: true
175            .map(|(_id, token)| token.content.clone())
176            .collect();
177
178        SpecialTokens {
179            bos_token: find_token(&["<s>", "<|startoftext|>", "<BOS>", "[CLS]"]),
180            eos_token: find_token(&["</s>", "<|endoftext|>", "<EOS>", "[SEP]"]),
181            unk_token: find_token(&["<unk>", "<UNK>", "[UNK]"]),
182            sep_token: find_token(&["[SEP]", "<sep>", "<SEP>"]),
183            pad_token: find_token(&["<pad>", "<PAD>", "[PAD]"]),
184            cls_token: find_token(&["[CLS]", "<cls>", "<CLS>"]),
185            mask_token: find_token(&["[MASK]", "<mask>", "<MASK>"]),
186            additional_special_tokens,
187        }
188    }
189
190    /// Load chat template and special token settings from tokenizer_config.json
191    /// Returns Option<bool> to distinguish between explicit false vs not set
192    fn load_chat_template_and_config(
193        tokenizer_path: &str,
194    ) -> (Option<String>, Option<bool>, Option<bool>) {
195        (|| {
196            let path = std::path::Path::new(tokenizer_path);
197            let config_path = path.parent()?.join("tokenizer_config.json");
198
199            if !config_path.exists() {
200                return None;
201            }
202
203            let config_str = config_path.to_str()?;
204            let content = std::fs::read_to_string(&config_path).ok()?;
205            let config: serde_json::Value = serde_json::from_str(&content).ok()?;
206
207            let chat_template = super::chat_template::load_chat_template_from_config(config_str)
208                .ok()
209                .flatten();
210
211            let add_bos_token = config.get("add_bos_token").and_then(|v| v.as_bool());
212            let add_eos_token = config.get("add_eos_token").and_then(|v| v.as_bool());
213
214            Some((chat_template, add_bos_token, add_eos_token))
215        })()
216        .unwrap_or((None, None, None))
217    }
218}
219
220impl Encoder for HuggingFaceTokenizer {
221    fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding> {
222        self.tokenizer
223            .encode(input, add_special_tokens)
224            .map_err(|e| Error::msg(format!("Encoding failed: {e}")))
225            .map(|encoding| Encoding::Hf(Box::new(encoding)))
226    }
227
228    fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
229        self.tokenizer
230            .encode_batch(inputs.to_vec(), add_special_tokens)
231            .map_err(|e| Error::msg(format!("Batch encoding failed: {e}")))
232            .map(|encodings| {
233                encodings
234                    .into_iter()
235                    .map(|e| Encoding::Hf(Box::new(e)))
236                    .collect()
237            })
238    }
239}
240
241impl Decoder for HuggingFaceTokenizer {
242    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
243        self.tokenizer
244            .decode(token_ids, skip_special_tokens)
245            .map_err(|e| Error::msg(format!("Decoding failed: {e}")))
246    }
247}
248
249impl TokenizerTrait for HuggingFaceTokenizer {
250    fn vocab_size(&self) -> usize {
251        self.tokenizer.get_vocab_size(false)
252    }
253
254    fn get_special_tokens(&self) -> &SpecialTokens {
255        &self.special_tokens
256    }
257
258    fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
259        self.vocab.get(token).copied()
260    }
261
262    fn id_to_token(&self, id: TokenIdType) -> Option<String> {
263        self.reverse_vocab.get(&id).cloned()
264    }
265
266    fn as_any(&self) -> &dyn std::any::Any {
267        self
268    }
269
270    fn apply_chat_template(
271        &self,
272        messages: &[serde_json::Value],
273        params: ChatTemplateParams,
274    ) -> Result<String> {
275        self.chat_template.apply(messages, params)
276    }
277
278    fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
279        self.chat_template.content_format()
280    }
281
282    fn set_chat_template(&mut self, template: String) -> Result<()> {
283        self.chat_template.set(template)
284    }
285}