1use std::collections::HashMap;
2
3use anyhow::{Error, Result};
4use tokenizers::{processors::template::TemplateProcessing, tokenizer::Tokenizer as HfTokenizer};
5use tracing::debug;
6
7use crate::{
8 chat_template::{
9 detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams,
10 ChatTemplateProcessor,
11 },
12 traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait},
13};
14
15pub struct HuggingFaceTokenizer {
17 tokenizer: HfTokenizer,
18 special_tokens: SpecialTokens,
19 vocab: HashMap<String, TokenIdType>,
20 reverse_vocab: HashMap<TokenIdType, String>,
21 chat_template: Option<String>,
22 content_format: ChatTemplateContentFormat,
24}
25
26impl HuggingFaceTokenizer {
27 pub fn from_file(file_path: &str) -> Result<Self> {
29 let path = std::path::Path::new(file_path);
31 let chat_template_path = path
32 .parent()
33 .and_then(crate::factory::discover_chat_template_in_dir);
34 Self::from_file_with_chat_template(file_path, chat_template_path.as_deref())
35 }
36
37 pub fn from_file_with_chat_template(
39 file_path: &str,
40 chat_template_path: Option<&str>,
41 ) -> Result<Self> {
42 let mut tokenizer = HfTokenizer::from_file(file_path)
43 .map_err(|e| Error::msg(format!("Failed to load tokenizer: {}", e)))?;
44
45 let special_tokens = Self::extract_special_tokens(&tokenizer);
47
48 let vocab = tokenizer.get_vocab(true); let reverse_vocab: HashMap<TokenIdType, String> = vocab
51 .iter()
52 .map(|(token, &id)| (id, token.clone()))
53 .collect();
54
55 let (chat_template, add_bos_token, add_eos_token) =
57 if let Some(template_path) = chat_template_path {
58 (
60 Self::load_chat_template_from_file(template_path)?,
61 None,
62 None,
63 )
64 } else {
65 Self::load_chat_template_and_config(file_path)
67 };
68
69 let content_format = if let Some(ref template) = chat_template {
71 detect_chat_template_content_format(template)
72 } else {
73 ChatTemplateContentFormat::String };
75
76 let needs_eos = add_eos_token == Some(true);
79 let needs_bos = match add_bos_token {
80 Some(true) => true,
81 Some(false) => false,
82 None => needs_eos && Self::tokenizer_adds_special_tokens(&tokenizer),
84 };
85
86 if needs_bos || needs_eos {
87 if let Some(post_processor) =
88 Self::build_post_processor(needs_bos, needs_eos, &special_tokens, &vocab)
89 {
90 debug!(needs_bos, needs_eos, "Configured post_processor");
91 tokenizer.with_post_processor(Some(post_processor));
92 }
93 }
94
95 Ok(HuggingFaceTokenizer {
96 tokenizer,
97 special_tokens,
98 vocab,
99 reverse_vocab,
100 chat_template,
101 content_format,
102 })
103 }
104
105 fn tokenizer_adds_special_tokens(tokenizer: &HfTokenizer) -> bool {
107 tokenizer
108 .encode("", true)
109 .map(|enc| !enc.get_ids().is_empty())
110 .unwrap_or(false)
111 }
112
113 fn build_post_processor(
116 add_bos_token: bool,
117 add_eos_token: bool,
118 special_tokens: &SpecialTokens,
119 vocab: &HashMap<String, TokenIdType>,
120 ) -> Option<TemplateProcessing> {
121 let mut template = String::with_capacity(32);
124 let mut tokens = Vec::with_capacity(2);
125
126 if add_bos_token {
127 let bos = special_tokens.bos_token.as_ref()?;
128 let bos_id = vocab.get(bos).copied()?;
129 template.push_str(bos);
130 template.push_str(":0 ");
131 tokens.push((bos.clone(), bos_id));
132 }
133
134 template.push_str("$A:0");
135
136 if add_eos_token {
137 let eos = special_tokens.eos_token.as_ref()?;
138 let eos_id = vocab.get(eos).copied()?;
139 template.push(' ');
140 template.push_str(eos);
141 template.push_str(":0");
142 tokens.push((eos.clone(), eos_id));
143 }
144
145 TemplateProcessing::builder()
146 .try_single(template.as_str())
147 .ok()?
148 .special_tokens(tokens)
149 .build()
150 .ok()
151 }
152
153 pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
155 let special_tokens = Self::extract_special_tokens(&tokenizer);
156 let vocab = tokenizer.get_vocab(true); let reverse_vocab: HashMap<TokenIdType, String> = vocab
158 .iter()
159 .map(|(token, &id)| (id, token.clone()))
160 .collect();
161
162 HuggingFaceTokenizer {
163 tokenizer,
164 special_tokens,
165 vocab,
166 reverse_vocab,
167 chat_template: None,
168 content_format: ChatTemplateContentFormat::String, }
170 }
171
172 fn extract_special_tokens(tokenizer: &HfTokenizer) -> SpecialTokens {
174 let vocab = tokenizer.get_vocab(true);
176
177 let find_token = |patterns: &[&str]| -> Option<String> {
178 for pattern in patterns {
179 if vocab.contains_key(*pattern) {
180 return Some(pattern.to_string());
181 }
182 }
183 None
184 };
185
186 let additional_special_tokens: Vec<String> = tokenizer
188 .get_added_tokens_decoder()
189 .iter()
190 .filter(|(_id, token)| token.special) .map(|(_id, token)| token.content.clone())
192 .collect();
193
194 SpecialTokens {
195 bos_token: find_token(&["<s>", "<|startoftext|>", "<BOS>", "[CLS]"]),
196 eos_token: find_token(&["</s>", "<|endoftext|>", "<EOS>", "[SEP]"]),
197 unk_token: find_token(&["<unk>", "<UNK>", "[UNK]"]),
198 sep_token: find_token(&["[SEP]", "<sep>", "<SEP>"]),
199 pad_token: find_token(&["<pad>", "<PAD>", "[PAD]"]),
200 cls_token: find_token(&["[CLS]", "<cls>", "<CLS>"]),
201 mask_token: find_token(&["[MASK]", "<mask>", "<MASK>"]),
202 additional_special_tokens,
203 }
204 }
205
206 fn load_chat_template_and_config(
209 tokenizer_path: &str,
210 ) -> (Option<String>, Option<bool>, Option<bool>) {
211 (|| {
212 let path = std::path::Path::new(tokenizer_path);
213 let config_path = path.parent()?.join("tokenizer_config.json");
214
215 if !config_path.exists() {
216 return None;
217 }
218
219 let config_str = config_path.to_str()?;
220 let content = std::fs::read_to_string(&config_path).ok()?;
221 let config: serde_json::Value = serde_json::from_str(&content).ok()?;
222
223 let chat_template = super::chat_template::load_chat_template_from_config(config_str)
224 .ok()
225 .flatten();
226
227 let add_bos_token = config.get("add_bos_token").and_then(|v| v.as_bool());
228 let add_eos_token = config.get("add_eos_token").and_then(|v| v.as_bool());
229
230 Some((chat_template, add_bos_token, add_eos_token))
231 })()
232 .unwrap_or((None, None, None))
233 }
234
235 fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
237 use std::fs;
238
239 let content = fs::read_to_string(template_path)
240 .map_err(|e| Error::msg(format!("Failed to read chat template file: {}", e)))?;
241
242 if template_path.ends_with(".json") {
244 let json_value: serde_json::Value = serde_json::from_str(&content)
246 .map_err(|e| Error::msg(format!("Failed to parse chat_template.json: {}", e)))?;
247
248 if let Some(template_str) = json_value.as_str() {
249 return Ok(Some(template_str.to_string()));
250 } else if let Some(obj) = json_value.as_object() {
251 if let Some(template_value) = obj.get("chat_template") {
252 if let Some(template_str) = template_value.as_str() {
253 return Ok(Some(template_str.to_string()));
254 }
255 }
256 }
257
258 return Err(Error::msg(
259 "chat_template.json does not contain a valid template",
260 ));
261 }
262
263 let template = content.trim().replace("\\n", "\n");
266
267 Ok(Some(template))
268 }
269
270 pub fn set_chat_template(&mut self, template: String) {
272 self.content_format = detect_chat_template_content_format(&template);
274 self.chat_template = Some(template);
275 }
276
277 pub fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
279 self.content_format
280 }
281
282 pub fn apply_chat_template(
286 &self,
287 messages: &[serde_json::Value],
288 params: ChatTemplateParams,
289 ) -> Result<String> {
290 if let Some(ref template) = self.chat_template {
291 let processor = ChatTemplateProcessor::new(template.clone());
292 processor.apply_chat_template(messages, params)
293 } else {
294 Err(Error::msg(
295 "Cannot use chat template functions because tokenizer.chat_template is not set and no template \
296 argument was passed! For information about writing templates and setting the \
297 tokenizer.chat_template attribute, please see the documentation at \
298 https://huggingface.co/docs/transformers/main/en/chat_templating"
299 ))
300 }
301 }
302}
303
304impl Encoder for HuggingFaceTokenizer {
305 fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding> {
306 self.tokenizer
307 .encode(input, add_special_tokens)
308 .map_err(|e| Error::msg(format!("Encoding failed: {}", e)))
309 .map(|encoding| Encoding::Hf(Box::new(encoding)))
310 }
311
312 fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
313 self.tokenizer
314 .encode_batch(inputs.to_vec(), add_special_tokens)
315 .map_err(|e| Error::msg(format!("Batch encoding failed: {}", e)))
316 .map(|encodings| {
317 encodings
318 .into_iter()
319 .map(|e| Encoding::Hf(Box::new(e)))
320 .collect()
321 })
322 }
323}
324
325impl Decoder for HuggingFaceTokenizer {
326 fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
327 self.tokenizer
328 .decode(token_ids, skip_special_tokens)
329 .map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
330 }
331}
332
333impl TokenizerTrait for HuggingFaceTokenizer {
334 fn vocab_size(&self) -> usize {
335 self.tokenizer.get_vocab_size(false)
336 }
337
338 fn get_special_tokens(&self) -> &SpecialTokens {
339 &self.special_tokens
340 }
341
342 fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
343 self.vocab.get(token).copied()
344 }
345
346 fn id_to_token(&self, id: TokenIdType) -> Option<String> {
347 self.reverse_vocab.get(&id).cloned()
348 }
349
350 fn as_any(&self) -> &dyn std::any::Any {
351 self
352 }
353}
354
355#[cfg(test)]
356mod tests {
357 }