1use std::collections::HashMap;
2
3use anyhow::{Error, Result};
4use tokenizers::{
5 processors::template::TemplateProcessing,
6 tokenizer::{step_decode_stream, Tokenizer as HfTokenizer},
7};
8use tracing::debug;
9
10use crate::{
11 chat_template::{
12 load_chat_template_from_file, ChatTemplateContentFormat, ChatTemplateParams,
13 ChatTemplateState, ThinkingKeyName, ThinkingToggle,
14 },
15 encoders::{deepseek_v32, deepseek_v4},
16 traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait},
17};
18
/// Chat-template rendering strategy used by `apply_chat_template`.
///
/// The variant is chosen once while loading from a file, based on the
/// `architectures` list in the model directory's `config.json`; every other
/// construction path uses `Jinja`.
#[derive(Clone, Copy, Debug)]
enum Renderer {
    /// Render through the stored Jinja chat template.
    Jinja,
    /// Render with the built-in DeepSeek V3.2 message encoder.
    DeepseekV32,
    /// Render with the built-in DeepSeek V4 message encoder.
    DeepseekV4,
}
25
26pub struct HuggingFaceTokenizer {
28 tokenizer: HfTokenizer,
29 special_tokens: SpecialTokens,
30 vocab: HashMap<String, TokenIdType>,
31 reverse_vocab: HashMap<TokenIdType, String>,
32 chat_template: ChatTemplateState,
33 eos_token_ids: Vec<TokenIdType>,
35 renderer: Renderer,
37}
38
39impl HuggingFaceTokenizer {
40 pub fn from_file(file_path: &str) -> Result<Self> {
42 let path = std::path::Path::new(file_path);
44 let chat_template_path = path
45 .parent()
46 .and_then(crate::factory::discover_chat_template_in_dir);
47 Self::from_file_with_chat_template(file_path, chat_template_path.as_deref())
48 }
49
50 pub fn from_file_with_chat_template(
52 file_path: &str,
53 chat_template_path: Option<&str>,
54 ) -> Result<Self> {
55 let mut tokenizer = HfTokenizer::from_file(file_path)
56 .map_err(|e| Error::msg(format!("Failed to load tokenizer: {e}")))?;
57
58 let vocab = tokenizer.get_vocab(true); let reverse_vocab: HashMap<TokenIdType, String> = vocab
61 .iter()
62 .map(|(token, &id)| (id, token.clone()))
63 .collect();
64
65 let config_result = Self::load_chat_template_and_config(file_path);
67 let mut chat_template_str = config_result.chat_template;
68 let add_bos_token = config_result.add_bos_token;
69 let add_eos_token = config_result.add_eos_token;
70
71 let special_tokens = Self::extract_special_tokens(&tokenizer, &config_result.config_tokens);
73
74 if let Some(template_path) = chat_template_path {
75 chat_template_str = load_chat_template_from_file(template_path)?;
76 }
77
78 let needs_eos = add_eos_token == Some(true);
81 let needs_bos = match add_bos_token {
82 Some(true) => true,
83 Some(false) => false,
84 None => needs_eos && Self::tokenizer_adds_special_tokens(&tokenizer),
86 };
87
88 if needs_bos || needs_eos {
89 if let Some(post_processor) =
90 Self::build_post_processor(needs_bos, needs_eos, &special_tokens, &vocab)
91 {
92 debug!(needs_bos, needs_eos, "Configured post_processor");
93 tokenizer.with_post_processor(Some(post_processor));
94 }
95 }
96
97 let eos_token_ids = std::path::Path::new(file_path)
99 .parent()
100 .map(crate::eos::load_eos_token_ids)
101 .unwrap_or_default();
102
103 let renderer = std::path::Path::new(file_path)
105 .parent()
106 .map(detect_renderer_from_config)
107 .unwrap_or(Renderer::Jinja);
108
109 Ok(HuggingFaceTokenizer {
110 tokenizer,
111 special_tokens,
112 vocab,
113 reverse_vocab,
114 chat_template: ChatTemplateState::new(chat_template_str)?,
115 eos_token_ids,
116 renderer,
117 })
118 }
119
120 fn tokenizer_adds_special_tokens(tokenizer: &HfTokenizer) -> bool {
122 tokenizer
123 .encode("", true)
124 .map(|enc| !enc.get_ids().is_empty())
125 .unwrap_or(false)
126 }
127
128 fn build_post_processor(
131 add_bos_token: bool,
132 add_eos_token: bool,
133 special_tokens: &SpecialTokens,
134 vocab: &HashMap<String, TokenIdType>,
135 ) -> Option<TemplateProcessing> {
136 let mut template = String::with_capacity(32);
139 let mut tokens = Vec::with_capacity(2);
140
141 if add_bos_token {
142 let bos = special_tokens.bos_token.as_ref()?;
143 let bos_id = vocab.get(bos).copied()?;
144 template.push_str(bos);
145 template.push_str(":0 ");
146 tokens.push((bos.clone(), bos_id));
147 }
148
149 template.push_str("$A:0");
150
151 if add_eos_token {
152 let eos = special_tokens.eos_token.as_ref()?;
153 let eos_id = vocab.get(eos).copied()?;
154 template.push(' ');
155 template.push_str(eos);
156 template.push_str(":0");
157 tokens.push((eos.clone(), eos_id));
158 }
159
160 TemplateProcessing::builder()
161 .try_single(template.as_str())
162 .ok()?
163 .special_tokens(tokens)
164 .build()
165 .ok()
166 }
167
168 pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
170 let special_tokens = Self::extract_special_tokens(&tokenizer, &ConfigTokens::default());
171 let vocab = tokenizer.get_vocab(true); let reverse_vocab: HashMap<TokenIdType, String> = vocab
173 .iter()
174 .map(|(token, &id)| (id, token.clone()))
175 .collect();
176
177 HuggingFaceTokenizer {
178 tokenizer,
179 special_tokens,
180 vocab,
181 reverse_vocab,
182 chat_template: ChatTemplateState::empty(),
183 eos_token_ids: Vec::new(), renderer: Renderer::Jinja,
185 }
186 }
187
188 fn extract_special_tokens(
194 tokenizer: &HfTokenizer,
195 config_tokens: &ConfigTokens,
196 ) -> SpecialTokens {
197 let vocab = tokenizer.get_vocab(true);
199
200 let find_token = |patterns: &[&str]| -> Option<String> {
201 for pattern in patterns {
202 if vocab.contains_key(*pattern) {
203 return Some((*pattern).to_string());
204 }
205 }
206 None
207 };
208
209 let additional_special_tokens: Vec<String> = tokenizer
211 .get_added_tokens_decoder()
212 .iter()
213 .filter(|(_id, token)| token.special)
214 .map(|(_id, token)| token.content.clone())
215 .collect();
216
217 SpecialTokens {
219 bos_token: config_tokens
220 .bos_token
221 .clone()
222 .or_else(|| find_token(&["<s>", "<|startoftext|>", "<BOS>", "[CLS]"])),
223 eos_token: config_tokens
224 .eos_token
225 .clone()
226 .or_else(|| find_token(&["</s>", "<|endoftext|>", "<EOS>", "[SEP]"])),
227 unk_token: config_tokens
228 .unk_token
229 .clone()
230 .or_else(|| find_token(&["<unk>", "<UNK>", "[UNK]"])),
231 sep_token: find_token(&["[SEP]", "<sep>", "<SEP>"]),
232 pad_token: config_tokens
233 .pad_token
234 .clone()
235 .or_else(|| find_token(&["<pad>", "<PAD>", "[PAD]"])),
236 cls_token: find_token(&["[CLS]", "<cls>", "<CLS>"]),
237 mask_token: find_token(&["[MASK]", "<mask>", "<MASK>"]),
238 additional_special_tokens,
239 }
240 }
241
242 fn load_chat_template_and_config(tokenizer_path: &str) -> TokenizerConfigResult {
245 (|| {
246 let path = std::path::Path::new(tokenizer_path);
247 let config_path = path.parent()?.join("tokenizer_config.json");
248
249 if !config_path.exists() {
250 return None;
251 }
252
253 let content = std::fs::read_to_string(&config_path).ok()?;
254 let config: serde_json::Value = serde_json::from_str(&content).ok()?;
255
256 let chat_template = config
258 .get("chat_template")
259 .and_then(|v| v.as_str())
260 .map(String::from);
261
262 let add_bos_token = config.get("add_bos_token").and_then(|v| v.as_bool());
263 let add_eos_token = config.get("add_eos_token").and_then(|v| v.as_bool());
264
265 let get_token = |key: &str| -> Option<String> {
267 config.get(key).and_then(|v| {
268 v.as_str()
269 .map(String::from)
270 .or_else(|| v.get("content").and_then(|c| c.as_str()).map(String::from))
271 })
272 };
273
274 let config_tokens = ConfigTokens {
275 bos_token: get_token("bos_token"),
276 eos_token: get_token("eos_token"),
277 unk_token: get_token("unk_token"),
278 pad_token: get_token("pad_token"),
279 };
280
281 Some(TokenizerConfigResult {
282 chat_template,
283 add_bos_token,
284 add_eos_token,
285 config_tokens,
286 })
287 })()
288 .unwrap_or_default()
289 }
290}
291
/// Special-token strings read from `tokenizer_config.json`.
///
/// A field is `None` when the corresponding key was absent or could not be
/// read as a token string; `Default` yields all `None`.
#[derive(Default)]
struct ConfigTokens {
    /// Value of the `bos_token` key.
    bos_token: Option<String>,
    /// Value of the `eos_token` key.
    eos_token: Option<String>,
    /// Value of the `unk_token` key.
    unk_token: Option<String>,
    /// Value of the `pad_token` key.
    pad_token: Option<String>,
}
300
301#[derive(Default)]
303struct TokenizerConfigResult {
304 chat_template: Option<String>,
305 add_bos_token: Option<bool>,
306 add_eos_token: Option<bool>,
307 config_tokens: ConfigTokens,
308}
309
310impl Encoder for HuggingFaceTokenizer {
311 fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding> {
312 self.tokenizer
313 .encode(input, add_special_tokens)
314 .map_err(|e| Error::msg(format!("Encoding failed: {e}")))
315 .map(|encoding| Encoding::Hf(Box::new(encoding)))
316 }
317
318 fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
319 self.tokenizer
320 .encode_batch(inputs.to_vec(), add_special_tokens)
321 .map_err(|e| Error::msg(format!("Batch encoding failed: {e}")))
322 .map(|encodings| {
323 encodings
324 .into_iter()
325 .map(|e| Encoding::Hf(Box::new(e)))
326 .collect()
327 })
328 }
329}
330
331impl Decoder for HuggingFaceTokenizer {
332 fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
333 self.tokenizer
334 .decode(token_ids, skip_special_tokens)
335 .map_err(|e| Error::msg(format!("Decoding failed: {e}")))
336 }
337
338 fn decode_step(
344 &self,
345 token_id: TokenIdType,
346 ids: &mut Vec<TokenIdType>,
347 prefix: &mut String,
348 prefix_index: &mut usize,
349 skip_special_tokens: bool,
350 ) -> Result<Option<String>> {
351 step_decode_stream(
352 &self.tokenizer,
353 vec![token_id],
354 skip_special_tokens,
355 ids,
356 prefix,
357 prefix_index,
358 )
359 .map_err(|e| Error::msg(format!("Decode stream error: {e}")))
360 }
361}
362
363impl TokenizerTrait for HuggingFaceTokenizer {
364 fn vocab_size(&self) -> usize {
365 self.tokenizer.get_vocab_size(false)
366 }
367
368 fn get_special_tokens(&self) -> &SpecialTokens {
369 &self.special_tokens
370 }
371
372 fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
373 self.vocab.get(token).copied()
374 }
375
376 fn id_to_token(&self, id: TokenIdType) -> Option<String> {
377 self.reverse_vocab.get(&id).cloned()
378 }
379
380 fn as_any(&self) -> &dyn std::any::Any {
381 self
382 }
383
384 fn eos_token_ids(&self) -> &[TokenIdType] {
385 &self.eos_token_ids
386 }
387
388 fn apply_chat_template(
389 &self,
390 messages: &[serde_json::Value],
391 params: ChatTemplateParams,
392 ) -> Result<String> {
393 match self.renderer {
394 Renderer::Jinja => {
395 if params.special_tokens.is_some() {
397 return self.chat_template.apply(messages, params);
398 }
399 let params = ChatTemplateParams {
400 special_tokens: Some(&self.special_tokens),
401 ..params
402 };
403 self.chat_template.apply(messages, params)
404 }
405 Renderer::DeepseekV32 => apply_deepseek_v32(messages, ¶ms),
406 Renderer::DeepseekV4 => apply_deepseek_v4(messages, ¶ms),
407 }
408 }
409
410 fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
411 self.chat_template.content_format()
412 }
413
414 fn thinking_toggle(&self) -> ThinkingToggle {
415 match self.renderer {
416 Renderer::DeepseekV32 | Renderer::DeepseekV4 => ThinkingToggle::DefaultOff,
420 Renderer::Jinja => self.chat_template.thinking_toggle(),
421 }
422 }
423
424 fn thinking_key_name(&self) -> Option<ThinkingKeyName> {
425 match self.renderer {
426 Renderer::DeepseekV32 | Renderer::DeepseekV4 => Some(ThinkingKeyName::Thinking),
427 Renderer::Jinja => self.chat_template.thinking_key_name(),
428 }
429 }
430 fn think_in_prefill(&self) -> bool {
431 match self.renderer {
432 Renderer::DeepseekV32 | Renderer::DeepseekV4 => true,
436 Renderer::Jinja => self.chat_template.think_in_prefill(),
437 }
438 }
439
440 fn set_chat_template(&mut self, template: String) -> Result<()> {
441 self.chat_template.set(template)
442 }
443}
444
445fn detect_renderer_from_config(dir: &std::path::Path) -> Renderer {
453 let path = dir.join("config.json");
454 if !path.exists() {
455 return Renderer::Jinja;
456 }
457 let content = match std::fs::read_to_string(&path) {
458 Ok(c) => c,
459 Err(err) => {
460 debug!(?err, ?path, "config.json unreadable; using Jinja renderer");
461 return Renderer::Jinja;
462 }
463 };
464 let value: serde_json::Value = match serde_json::from_str(&content) {
465 Ok(v) => v,
466 Err(err) => {
467 debug!(?err, ?path, "config.json malformed; using Jinja renderer");
468 return Renderer::Jinja;
469 }
470 };
471 let architectures = value.get("architectures").and_then(|v| v.as_array());
472 let arch_strs: Vec<&str> = architectures
473 .map(|a| a.iter().filter_map(|v| v.as_str()).collect())
474 .unwrap_or_default();
475 if arch_strs.contains(&"DeepseekV32ForCausalLM") {
476 debug!(?path, "selected DeepseekV32 chat-template renderer");
477 return Renderer::DeepseekV32;
478 }
479 if arch_strs.contains(&"DeepseekV4ForCausalLM") {
480 debug!(?path, "selected DeepseekV4 chat-template renderer");
481 return Renderer::DeepseekV4;
482 }
483 Renderer::Jinja
484}
485
486fn derive_thinking_mode(params: &ChatTemplateParams) -> deepseek_v32::ThinkingMode {
493 let enabled = params
494 .template_kwargs
495 .and_then(|k| k.get("thinking"))
496 .and_then(serde_json::Value::as_bool)
497 .unwrap_or(false);
498 if enabled {
499 deepseek_v32::ThinkingMode::Thinking
500 } else {
501 deepseek_v32::ThinkingMode::Chat
502 }
503}
504
505fn resolve_drop_thinking(messages: &[serde_json::Value]) -> bool {
508 !messages.iter().any(|m| {
509 let role = m.get("role").and_then(|r| r.as_str());
510 matches!(role, Some("system" | "developer"))
511 && m.get("tools")
512 .and_then(|t| t.as_array())
513 .is_some_and(|arr| !arr.is_empty())
514 })
515}
516fn inject_tools_into_messages(
522 messages: &[serde_json::Value],
523 tools: Option<&[serde_json::Value]>,
524) -> Option<Vec<serde_json::Value>> {
525 let tools = tools?;
526 if tools.is_empty() {
527 return None;
528 }
529 let mut owned: Vec<serde_json::Value> = messages.to_vec();
530 let first_role = owned
531 .first()
532 .and_then(|m| m.get("role"))
533 .and_then(|r| r.as_str());
534 if !matches!(first_role, Some("system" | "developer")) {
535 owned.insert(0, serde_json::json!({ "role": "system", "content": "" }));
536 }
537 if let Some(obj) = owned[0].as_object_mut() {
538 obj.insert("tools".into(), serde_json::Value::Array(tools.to_vec()));
539 }
540 Some(owned)
541}
542
543fn apply_deepseek_v32(
544 messages: &[serde_json::Value],
545 params: &ChatTemplateParams,
546) -> Result<String> {
547 let owned = inject_tools_into_messages(messages, params.tools);
548 let msgs: &[serde_json::Value] = owned.as_deref().unwrap_or(messages);
549 let thinking_mode = derive_thinking_mode(params);
550 let encode_params = deepseek_v32::EncodeParams {
551 add_default_bos_token: true,
552 drop_thinking: resolve_drop_thinking(msgs),
553 };
554 deepseek_v32::encode_messages(msgs, thinking_mode, &encode_params)
555 .map_err(|e| Error::msg(format!("DeepSeek V3.2 encode failed: {e}")))
556}
557fn apply_deepseek_v4(
558 messages: &[serde_json::Value],
559 params: &ChatTemplateParams,
560) -> Result<String> {
561 let owned = inject_tools_into_messages(messages, params.tools);
562 let msgs: &[serde_json::Value] = owned.as_deref().unwrap_or(messages);
563 let thinking_mode = derive_thinking_mode(params);
564 let reasoning_effort = params
565 .template_kwargs
566 .and_then(|k| k.get("reasoning_effort"))
567 .and_then(|v| v.as_str())
568 .and_then(|s| match s {
569 "max" => Some(deepseek_v4::ReasoningEffort::Max),
570 "high" => Some(deepseek_v4::ReasoningEffort::High),
571 _ => None,
572 });
573 let encode_params = deepseek_v4::EncodeParams {
574 add_default_bos_token: true,
575 drop_thinking: resolve_drop_thinking(msgs),
576 reasoning_effort,
577 };
578 deepseek_v4::encode_messages(msgs, thinking_mode, &encode_params)
579 .map_err(|e| Error::msg(format!("DeepSeek V4 encode failed: {e}")))
580}