1use std::collections::HashMap;
2
3use anyhow::{Error, Result};
4use tokenizers::{
5 processors::template::TemplateProcessing,
6 tokenizer::{step_decode_stream, Tokenizer as HfTokenizer},
7};
8use tracing::debug;
9
10use crate::{
11 chat_template::{
12 load_chat_template_from_file, ChatTemplateContentFormat, ChatTemplateParams,
13 ChatTemplateState, ThinkingKeyName, ThinkingToggle,
14 },
15 traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait},
16};
17
18pub struct HuggingFaceTokenizer {
20 tokenizer: HfTokenizer,
21 special_tokens: SpecialTokens,
22 vocab: HashMap<String, TokenIdType>,
23 reverse_vocab: HashMap<TokenIdType, String>,
24 chat_template: ChatTemplateState,
25}
26
27impl HuggingFaceTokenizer {
28 pub fn from_file(file_path: &str) -> Result<Self> {
30 let path = std::path::Path::new(file_path);
32 let chat_template_path = path
33 .parent()
34 .and_then(crate::factory::discover_chat_template_in_dir);
35 Self::from_file_with_chat_template(file_path, chat_template_path.as_deref())
36 }
37
38 pub fn from_file_with_chat_template(
40 file_path: &str,
41 chat_template_path: Option<&str>,
42 ) -> Result<Self> {
43 let mut tokenizer = HfTokenizer::from_file(file_path)
44 .map_err(|e| Error::msg(format!("Failed to load tokenizer: {e}")))?;
45
46 let vocab = tokenizer.get_vocab(true); let reverse_vocab: HashMap<TokenIdType, String> = vocab
49 .iter()
50 .map(|(token, &id)| (id, token.clone()))
51 .collect();
52
53 let config_result = Self::load_chat_template_and_config(file_path);
55 let mut chat_template_str = config_result.chat_template;
56 let add_bos_token = config_result.add_bos_token;
57 let add_eos_token = config_result.add_eos_token;
58
59 let special_tokens = Self::extract_special_tokens(&tokenizer, &config_result.config_tokens);
61
62 if let Some(template_path) = chat_template_path {
63 chat_template_str = load_chat_template_from_file(template_path)?;
64 }
65
66 let needs_eos = add_eos_token == Some(true);
69 let needs_bos = match add_bos_token {
70 Some(true) => true,
71 Some(false) => false,
72 None => needs_eos && Self::tokenizer_adds_special_tokens(&tokenizer),
74 };
75
76 if needs_bos || needs_eos {
77 if let Some(post_processor) =
78 Self::build_post_processor(needs_bos, needs_eos, &special_tokens, &vocab)
79 {
80 debug!(needs_bos, needs_eos, "Configured post_processor");
81 tokenizer.with_post_processor(Some(post_processor));
82 }
83 }
84
85 Ok(HuggingFaceTokenizer {
86 tokenizer,
87 special_tokens,
88 vocab,
89 reverse_vocab,
90 chat_template: ChatTemplateState::new(chat_template_str)?,
91 })
92 }
93
94 fn tokenizer_adds_special_tokens(tokenizer: &HfTokenizer) -> bool {
96 tokenizer
97 .encode("", true)
98 .map(|enc| !enc.get_ids().is_empty())
99 .unwrap_or(false)
100 }
101
102 fn build_post_processor(
105 add_bos_token: bool,
106 add_eos_token: bool,
107 special_tokens: &SpecialTokens,
108 vocab: &HashMap<String, TokenIdType>,
109 ) -> Option<TemplateProcessing> {
110 let mut template = String::with_capacity(32);
113 let mut tokens = Vec::with_capacity(2);
114
115 if add_bos_token {
116 let bos = special_tokens.bos_token.as_ref()?;
117 let bos_id = vocab.get(bos).copied()?;
118 template.push_str(bos);
119 template.push_str(":0 ");
120 tokens.push((bos.clone(), bos_id));
121 }
122
123 template.push_str("$A:0");
124
125 if add_eos_token {
126 let eos = special_tokens.eos_token.as_ref()?;
127 let eos_id = vocab.get(eos).copied()?;
128 template.push(' ');
129 template.push_str(eos);
130 template.push_str(":0");
131 tokens.push((eos.clone(), eos_id));
132 }
133
134 TemplateProcessing::builder()
135 .try_single(template.as_str())
136 .ok()?
137 .special_tokens(tokens)
138 .build()
139 .ok()
140 }
141
142 pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
144 let special_tokens = Self::extract_special_tokens(&tokenizer, &ConfigTokens::default());
145 let vocab = tokenizer.get_vocab(true); let reverse_vocab: HashMap<TokenIdType, String> = vocab
147 .iter()
148 .map(|(token, &id)| (id, token.clone()))
149 .collect();
150
151 HuggingFaceTokenizer {
152 tokenizer,
153 special_tokens,
154 vocab,
155 reverse_vocab,
156 chat_template: ChatTemplateState::empty(),
157 }
158 }
159
160 fn extract_special_tokens(
166 tokenizer: &HfTokenizer,
167 config_tokens: &ConfigTokens,
168 ) -> SpecialTokens {
169 let vocab = tokenizer.get_vocab(true);
171
172 let find_token = |patterns: &[&str]| -> Option<String> {
173 for pattern in patterns {
174 if vocab.contains_key(*pattern) {
175 return Some((*pattern).to_string());
176 }
177 }
178 None
179 };
180
181 let additional_special_tokens: Vec<String> = tokenizer
183 .get_added_tokens_decoder()
184 .iter()
185 .filter(|(_id, token)| token.special)
186 .map(|(_id, token)| token.content.clone())
187 .collect();
188
189 SpecialTokens {
191 bos_token: config_tokens
192 .bos_token
193 .clone()
194 .or_else(|| find_token(&["<s>", "<|startoftext|>", "<BOS>", "[CLS]"])),
195 eos_token: config_tokens
196 .eos_token
197 .clone()
198 .or_else(|| find_token(&["</s>", "<|endoftext|>", "<EOS>", "[SEP]"])),
199 unk_token: config_tokens
200 .unk_token
201 .clone()
202 .or_else(|| find_token(&["<unk>", "<UNK>", "[UNK]"])),
203 sep_token: find_token(&["[SEP]", "<sep>", "<SEP>"]),
204 pad_token: config_tokens
205 .pad_token
206 .clone()
207 .or_else(|| find_token(&["<pad>", "<PAD>", "[PAD]"])),
208 cls_token: find_token(&["[CLS]", "<cls>", "<CLS>"]),
209 mask_token: find_token(&["[MASK]", "<mask>", "<MASK>"]),
210 additional_special_tokens,
211 }
212 }
213
214 fn load_chat_template_and_config(tokenizer_path: &str) -> TokenizerConfigResult {
217 (|| {
218 let path = std::path::Path::new(tokenizer_path);
219 let config_path = path.parent()?.join("tokenizer_config.json");
220
221 if !config_path.exists() {
222 return None;
223 }
224
225 let content = std::fs::read_to_string(&config_path).ok()?;
226 let config: serde_json::Value = serde_json::from_str(&content).ok()?;
227
228 let chat_template = config
230 .get("chat_template")
231 .and_then(|v| v.as_str())
232 .map(String::from);
233
234 let add_bos_token = config.get("add_bos_token").and_then(|v| v.as_bool());
235 let add_eos_token = config.get("add_eos_token").and_then(|v| v.as_bool());
236
237 let get_token = |key: &str| -> Option<String> {
239 config.get(key).and_then(|v| {
240 v.as_str()
241 .map(String::from)
242 .or_else(|| v.get("content").and_then(|c| c.as_str()).map(String::from))
243 })
244 };
245
246 let config_tokens = ConfigTokens {
247 bos_token: get_token("bos_token"),
248 eos_token: get_token("eos_token"),
249 unk_token: get_token("unk_token"),
250 pad_token: get_token("pad_token"),
251 };
252
253 Some(TokenizerConfigResult {
254 chat_template,
255 add_bos_token,
256 add_eos_token,
257 config_tokens,
258 })
259 })()
260 .unwrap_or_default()
261 }
262}
263
/// Special-token strings read from `tokenizer_config.json`.
///
/// A field is `None` when the config does not specify that token; the tokenizer then falls
/// back to well-known vocabulary entries. `Debug` is derived so parsed configs can be
/// inspected in diagnostics.
#[derive(Debug, Default)]
struct ConfigTokens {
    /// Beginning-of-sequence token text.
    bos_token: Option<String>,
    /// End-of-sequence token text.
    eos_token: Option<String>,
    /// Unknown-token text.
    unk_token: Option<String>,
    /// Padding token text.
    pad_token: Option<String>,
}
272
273#[derive(Default)]
275struct TokenizerConfigResult {
276 chat_template: Option<String>,
277 add_bos_token: Option<bool>,
278 add_eos_token: Option<bool>,
279 config_tokens: ConfigTokens,
280}
281
282impl Encoder for HuggingFaceTokenizer {
283 fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding> {
284 self.tokenizer
285 .encode(input, add_special_tokens)
286 .map_err(|e| Error::msg(format!("Encoding failed: {e}")))
287 .map(|encoding| Encoding::Hf(Box::new(encoding)))
288 }
289
290 fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
291 self.tokenizer
292 .encode_batch(inputs.to_vec(), add_special_tokens)
293 .map_err(|e| Error::msg(format!("Batch encoding failed: {e}")))
294 .map(|encodings| {
295 encodings
296 .into_iter()
297 .map(|e| Encoding::Hf(Box::new(e)))
298 .collect()
299 })
300 }
301}
302
303impl Decoder for HuggingFaceTokenizer {
304 fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
305 self.tokenizer
306 .decode(token_ids, skip_special_tokens)
307 .map_err(|e| Error::msg(format!("Decoding failed: {e}")))
308 }
309
310 fn decode_step(
316 &self,
317 token_id: TokenIdType,
318 ids: &mut Vec<TokenIdType>,
319 prefix: &mut String,
320 prefix_index: &mut usize,
321 skip_special_tokens: bool,
322 ) -> Result<Option<String>> {
323 step_decode_stream(
324 &self.tokenizer,
325 vec![token_id],
326 skip_special_tokens,
327 ids,
328 prefix,
329 prefix_index,
330 )
331 .map_err(|e| Error::msg(format!("Decode stream error: {e}")))
332 }
333}
334
335impl TokenizerTrait for HuggingFaceTokenizer {
336 fn vocab_size(&self) -> usize {
337 self.tokenizer.get_vocab_size(false)
338 }
339
340 fn get_special_tokens(&self) -> &SpecialTokens {
341 &self.special_tokens
342 }
343
344 fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
345 self.vocab.get(token).copied()
346 }
347
348 fn id_to_token(&self, id: TokenIdType) -> Option<String> {
349 self.reverse_vocab.get(&id).cloned()
350 }
351
352 fn as_any(&self) -> &dyn std::any::Any {
353 self
354 }
355
356 fn apply_chat_template(
357 &self,
358 messages: &[serde_json::Value],
359 params: ChatTemplateParams,
360 ) -> Result<String> {
361 if params.special_tokens.is_some() {
363 return self.chat_template.apply(messages, params);
364 }
365 let params = ChatTemplateParams {
366 special_tokens: Some(&self.special_tokens),
367 ..params
368 };
369 self.chat_template.apply(messages, params)
370 }
371
372 fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
373 self.chat_template.content_format()
374 }
375
376 fn thinking_toggle(&self) -> ThinkingToggle {
377 self.chat_template.thinking_toggle()
378 }
379
380 fn thinking_key_name(&self) -> Option<ThinkingKeyName> {
381 self.chat_template.thinking_key_name()
382 }
383 fn think_in_prefill(&self) -> bool {
384 self.chat_template.think_in_prefill()
385 }
386
387 fn set_chat_template(&mut self, template: String) -> Result<()> {
388 self.chat_template.set(template)
389 }
390}