use std::collections::HashMap;
use anyhow::{Error, Result};
use tokenizers::{
processors::template::TemplateProcessing,
tokenizer::{step_decode_stream, Tokenizer as HfTokenizer},
};
use tracing::debug;
use crate::{
chat_template::{
load_chat_template_from_file, ChatTemplateContentFormat, ChatTemplateParams,
ChatTemplateState, ThinkingKeyName, ThinkingToggle,
},
encoders::{deepseek_v32, deepseek_v4},
traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait},
};
/// Strategy `apply_chat_template` uses to turn chat messages into a prompt string.
///
/// Chosen at load time from the `architectures` list in `config.json`; tokenizers
/// built with `from_tokenizer`, or whose architecture is not recognised, use `Jinja`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Renderer {
    /// Render through the tokenizer's Jinja chat template.
    Jinja,
    /// Render with the native `deepseek_v32` encoder.
    DeepseekV32,
    /// Render with the native `deepseek_v4` encoder.
    DeepseekV4,
}
/// Tokenizer backed by the HuggingFace `tokenizers` crate, bundled with the
/// vocabulary caches, special tokens and chat-template state loaded alongside it.
pub struct HuggingFaceTokenizer {
    /// Underlying HuggingFace tokenizer used for encoding and decoding.
    tokenizer: HfTokenizer,
    /// Token string -> id map built with `get_vocab(true)`; backs `token_to_id`.
    vocab: HashMap<String, TokenIdType>,
    /// Id -> token string map (inverse of `vocab`); backs `id_to_token`.
    reverse_vocab: HashMap<TokenIdType, String>,
    /// Special tokens resolved when the tokenizer was constructed.
    special_tokens: SpecialTokens,
    /// Ids returned by `eos_token_ids`.
    eos_token_ids: Vec<TokenIdType>,
    /// Chat-template state used by the Jinja renderer.
    chat_template: ChatTemplateState,
    /// Renderer that `apply_chat_template` dispatches to.
    renderer: Renderer,
}
impl HuggingFaceTokenizer {
pub fn from_file(file_path: &str) -> Result<Self> {
let path = std::path::Path::new(file_path);
let chat_template_path = path
.parent()
.and_then(crate::factory::discover_chat_template_in_dir);
Self::from_file_with_chat_template(file_path, chat_template_path.as_deref())
}
pub fn from_file_with_chat_template(
file_path: &str,
chat_template_path: Option<&str>,
) -> Result<Self> {
let mut tokenizer = HfTokenizer::from_file(file_path)
.map_err(|e| Error::msg(format!("Failed to load tokenizer: {e}")))?;
let vocab = tokenizer.get_vocab(true); let reverse_vocab: HashMap<TokenIdType, String> = vocab
.iter()
.map(|(token, &id)| (id, token.clone()))
.collect();
let config_result = Self::load_chat_template_and_config(file_path);
let mut chat_template_str = config_result.chat_template;
let add_bos_token = config_result.add_bos_token;
let add_eos_token = config_result.add_eos_token;
let special_tokens = Self::extract_special_tokens(&tokenizer, &config_result.config_tokens);
if let Some(template_path) = chat_template_path {
chat_template_str = load_chat_template_from_file(template_path)?;
}
let needs_eos = add_eos_token == Some(true);
let needs_bos = match add_bos_token {
Some(true) => true,
Some(false) => false,
None => needs_eos && Self::tokenizer_adds_special_tokens(&tokenizer),
};
if needs_bos || needs_eos {
if let Some(post_processor) =
Self::build_post_processor(needs_bos, needs_eos, &special_tokens, &vocab)
{
debug!(needs_bos, needs_eos, "Configured post_processor");
tokenizer.with_post_processor(Some(post_processor));
}
}
let eos_token_ids = std::path::Path::new(file_path)
.parent()
.map(crate::eos::load_eos_token_ids)
.unwrap_or_default();
let renderer = std::path::Path::new(file_path)
.parent()
.map(detect_renderer_from_config)
.unwrap_or(Renderer::Jinja);
Ok(HuggingFaceTokenizer {
tokenizer,
special_tokens,
vocab,
reverse_vocab,
chat_template: ChatTemplateState::new(chat_template_str)?,
eos_token_ids,
renderer,
})
}
fn tokenizer_adds_special_tokens(tokenizer: &HfTokenizer) -> bool {
tokenizer
.encode("", true)
.map(|enc| !enc.get_ids().is_empty())
.unwrap_or(false)
}
fn build_post_processor(
add_bos_token: bool,
add_eos_token: bool,
special_tokens: &SpecialTokens,
vocab: &HashMap<String, TokenIdType>,
) -> Option<TemplateProcessing> {
let mut template = String::with_capacity(32);
let mut tokens = Vec::with_capacity(2);
if add_bos_token {
let bos = special_tokens.bos_token.as_ref()?;
let bos_id = vocab.get(bos).copied()?;
template.push_str(bos);
template.push_str(":0 ");
tokens.push((bos.clone(), bos_id));
}
template.push_str("$A:0");
if add_eos_token {
let eos = special_tokens.eos_token.as_ref()?;
let eos_id = vocab.get(eos).copied()?;
template.push(' ');
template.push_str(eos);
template.push_str(":0");
tokens.push((eos.clone(), eos_id));
}
TemplateProcessing::builder()
.try_single(template.as_str())
.ok()?
.special_tokens(tokens)
.build()
.ok()
}
pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
let special_tokens = Self::extract_special_tokens(&tokenizer, &ConfigTokens::default());
let vocab = tokenizer.get_vocab(true); let reverse_vocab: HashMap<TokenIdType, String> = vocab
.iter()
.map(|(token, &id)| (id, token.clone()))
.collect();
HuggingFaceTokenizer {
tokenizer,
special_tokens,
vocab,
reverse_vocab,
chat_template: ChatTemplateState::empty(),
eos_token_ids: Vec::new(), renderer: Renderer::Jinja,
}
}
fn extract_special_tokens(
tokenizer: &HfTokenizer,
config_tokens: &ConfigTokens,
) -> SpecialTokens {
let vocab = tokenizer.get_vocab(true);
let find_token = |patterns: &[&str]| -> Option<String> {
for pattern in patterns {
if vocab.contains_key(*pattern) {
return Some((*pattern).to_string());
}
}
None
};
let additional_special_tokens: Vec<String> = tokenizer
.get_added_tokens_decoder()
.iter()
.filter(|(_id, token)| token.special)
.map(|(_id, token)| token.content.clone())
.collect();
SpecialTokens {
bos_token: config_tokens
.bos_token
.clone()
.or_else(|| find_token(&["<s>", "<|startoftext|>", "<BOS>", "[CLS]"])),
eos_token: config_tokens
.eos_token
.clone()
.or_else(|| find_token(&["</s>", "<|endoftext|>", "<EOS>", "[SEP]"])),
unk_token: config_tokens
.unk_token
.clone()
.or_else(|| find_token(&["<unk>", "<UNK>", "[UNK]"])),
sep_token: find_token(&["[SEP]", "<sep>", "<SEP>"]),
pad_token: config_tokens
.pad_token
.clone()
.or_else(|| find_token(&["<pad>", "<PAD>", "[PAD]"])),
cls_token: find_token(&["[CLS]", "<cls>", "<CLS>"]),
mask_token: find_token(&["[MASK]", "<mask>", "<MASK>"]),
additional_special_tokens,
}
}
fn load_chat_template_and_config(tokenizer_path: &str) -> TokenizerConfigResult {
(|| {
let path = std::path::Path::new(tokenizer_path);
let config_path = path.parent()?.join("tokenizer_config.json");
if !config_path.exists() {
return None;
}
let content = std::fs::read_to_string(&config_path).ok()?;
let config: serde_json::Value = serde_json::from_str(&content).ok()?;
let chat_template = config
.get("chat_template")
.and_then(|v| v.as_str())
.map(String::from);
let add_bos_token = config.get("add_bos_token").and_then(|v| v.as_bool());
let add_eos_token = config.get("add_eos_token").and_then(|v| v.as_bool());
let get_token = |key: &str| -> Option<String> {
config.get(key).and_then(|v| {
v.as_str()
.map(String::from)
.or_else(|| v.get("content").and_then(|c| c.as_str()).map(String::from))
})
};
let config_tokens = ConfigTokens {
bos_token: get_token("bos_token"),
eos_token: get_token("eos_token"),
unk_token: get_token("unk_token"),
pad_token: get_token("pad_token"),
};
Some(TokenizerConfigResult {
chat_template,
add_bos_token,
add_eos_token,
config_tokens,
})
})()
.unwrap_or_default()
}
}
/// Special-token strings declared in `tokenizer_config.json`.
///
/// A field is `None` when the config does not provide a usable value. In that case
/// the tokenizer falls back to scanning its vocabulary for well-known spellings.
#[derive(Debug, Clone, Default)]
struct ConfigTokens {
    /// Value of the `bos_token` key.
    bos_token: Option<String>,
    /// Value of the `eos_token` key.
    eos_token: Option<String>,
    /// Value of the `unk_token` key.
    unk_token: Option<String>,
    /// Value of the `pad_token` key.
    pad_token: Option<String>,
}
/// Settings read from `tokenizer_config.json`.
///
/// Every field is optional. The `Default` value is what loading falls back to when
/// the config cannot be read.
#[derive(Default)]
struct TokenizerConfigResult {
    /// `add_bos_token` flag, when present as a boolean.
    add_bos_token: Option<bool>,
    /// `add_eos_token` flag, when present as a boolean.
    add_eos_token: Option<bool>,
    /// Chat template text taken from the config, if any.
    chat_template: Option<String>,
    /// Special-token names declared in the config.
    config_tokens: ConfigTokens,
}
impl Encoder for HuggingFaceTokenizer {
fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding> {
self.tokenizer
.encode(input, add_special_tokens)
.map_err(|e| Error::msg(format!("Encoding failed: {e}")))
.map(|encoding| Encoding::Hf(Box::new(encoding)))
}
fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
self.tokenizer
.encode_batch(inputs.to_vec(), add_special_tokens)
.map_err(|e| Error::msg(format!("Batch encoding failed: {e}")))
.map(|encodings| {
encodings
.into_iter()
.map(|e| Encoding::Hf(Box::new(e)))
.collect()
})
}
}
impl Decoder for HuggingFaceTokenizer {
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
self.tokenizer
.decode(token_ids, skip_special_tokens)
.map_err(|e| Error::msg(format!("Decoding failed: {e}")))
}
fn decode_step(
&self,
token_id: TokenIdType,
ids: &mut Vec<TokenIdType>,
prefix: &mut String,
prefix_index: &mut usize,
skip_special_tokens: bool,
) -> Result<Option<String>> {
step_decode_stream(
&self.tokenizer,
vec![token_id],
skip_special_tokens,
ids,
prefix,
prefix_index,
)
.map_err(|e| Error::msg(format!("Decode stream error: {e}")))
}
}
impl TokenizerTrait for HuggingFaceTokenizer {
    /// Vocabulary size reported by `get_vocab_size(false)` (added tokens excluded).
    fn vocab_size(&self) -> usize {
        self.tokenizer.get_vocab_size(false)
    }

    fn get_special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

    /// Looks `token` up in the cached vocabulary.
    fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
        self.vocab.get(token).copied()
    }

    /// Looks `id` up in the cached reverse vocabulary.
    fn id_to_token(&self, id: TokenIdType) -> Option<String> {
        self.reverse_vocab.get(&id).cloned()
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn eos_token_ids(&self) -> &[TokenIdType] {
        &self.eos_token_ids
    }

    /// Renders `messages` into a prompt using the renderer chosen at load time.
    ///
    /// For the Jinja renderer, caller-supplied `special_tokens` take precedence.
    /// Otherwise this tokenizer's own special tokens are passed to the template.
    fn apply_chat_template(
        &self,
        messages: &[serde_json::Value],
        params: ChatTemplateParams,
    ) -> Result<String> {
        match self.renderer {
            Renderer::Jinja => {
                let params = ChatTemplateParams {
                    special_tokens: params.special_tokens.or(Some(&self.special_tokens)),
                    ..params
                };
                self.chat_template.apply(messages, params)
            }
            Renderer::DeepseekV32 => apply_deepseek_v32(messages, &params),
            Renderer::DeepseekV4 => apply_deepseek_v4(messages, &params),
        }
    }

    fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
        self.chat_template.content_format()
    }

    /// DeepSeek renderers report `DefaultOff`; Jinja defers to the template state.
    fn thinking_toggle(&self) -> ThinkingToggle {
        match self.renderer {
            Renderer::DeepseekV32 | Renderer::DeepseekV4 => ThinkingToggle::DefaultOff,
            Renderer::Jinja => self.chat_template.thinking_toggle(),
        }
    }

    /// DeepSeek renderers read the `thinking` kwarg (see `derive_thinking_mode`).
    fn thinking_key_name(&self) -> Option<ThinkingKeyName> {
        match self.renderer {
            Renderer::DeepseekV32 | Renderer::DeepseekV4 => Some(ThinkingKeyName::Thinking),
            Renderer::Jinja => self.chat_template.thinking_key_name(),
        }
    }

    fn think_in_prefill(&self) -> bool {
        match self.renderer {
            Renderer::DeepseekV32 | Renderer::DeepseekV4 => true,
            Renderer::Jinja => self.chat_template.think_in_prefill(),
        }
    }

    /// Replaces the Jinja chat template. This does not change `renderer`, so DeepSeek
    /// renderers keep ignoring it.
    fn set_chat_template(&mut self, template: String) -> Result<()> {
        self.chat_template.set(template)
    }
}
/// Chooses the chat-template renderer from `<dir>/config.json`.
///
/// `DeepseekV32ForCausalLM` in `architectures` selects `DeepseekV32`, checked first.
/// Otherwise `DeepseekV4ForCausalLM` selects `DeepseekV4`. A missing, unreadable or
/// malformed config, or any other architecture, yields `Jinja`.
fn detect_renderer_from_config(dir: &std::path::Path) -> Renderer {
    let path = dir.join("config.json");
    if !path.exists() {
        return Renderer::Jinja;
    }

    let parsed: Option<serde_json::Value> = std::fs::read_to_string(&path)
        .map_err(|err| {
            debug!(?err, ?path, "config.json unreadable; using Jinja renderer");
        })
        .ok()
        .and_then(|content| {
            serde_json::from_str::<serde_json::Value>(&content)
                .map_err(|err| {
                    debug!(?err, ?path, "config.json malformed; using Jinja renderer");
                })
                .ok()
        });
    let Some(config) = parsed else {
        return Renderer::Jinja;
    };

    // True when the `architectures` array contains `arch` as a string entry.
    let declares = |arch: &str| {
        config
            .get("architectures")
            .and_then(serde_json::Value::as_array)
            .is_some_and(|list| list.iter().any(|entry| entry.as_str() == Some(arch)))
    };

    if declares("DeepseekV32ForCausalLM") {
        debug!(?path, "selected DeepseekV32 chat-template renderer");
        Renderer::DeepseekV32
    } else if declares("DeepseekV4ForCausalLM") {
        debug!(?path, "selected DeepseekV4 chat-template renderer");
        Renderer::DeepseekV4
    } else {
        Renderer::Jinja
    }
}
/// Maps the `thinking` template kwarg to a DeepSeek thinking mode.
///
/// Only an explicit boolean `true` enables `Thinking`. A missing kwarg, a non-bool
/// value or `false` yields `Chat`.
fn derive_thinking_mode(params: &ChatTemplateParams) -> deepseek_v32::ThinkingMode {
    let thinking_requested = params
        .template_kwargs
        .and_then(|kwargs| kwargs.get("thinking"))
        .and_then(serde_json::Value::as_bool);
    match thinking_requested {
        Some(true) => deepseek_v32::ThinkingMode::Thinking,
        Some(false) | None => deepseek_v32::ThinkingMode::Chat,
    }
}
fn resolve_drop_thinking(messages: &[serde_json::Value]) -> bool {
!messages.iter().any(|m| {
let role = m.get("role").and_then(|r| r.as_str());
matches!(role, Some("system" | "developer"))
&& m.get("tools")
.and_then(|t| t.as_array())
.is_some_and(|arr| !arr.is_empty())
})
}
/// Returns a copy of `messages` with `tools` attached to the leading system message.
///
/// If the first message is not `system`/`developer`, an empty system message is
/// prepended to hold the tools. Returns `None` (caller keeps the originals) when
/// `tools` is absent or empty.
fn inject_tools_into_messages(
    messages: &[serde_json::Value],
    tools: Option<&[serde_json::Value]>,
) -> Option<Vec<serde_json::Value>> {
    let tools = tools.filter(|list| !list.is_empty())?;

    let leads_with_system = messages
        .first()
        .and_then(|m| m.get("role"))
        .and_then(serde_json::Value::as_str)
        .is_some_and(|role| matches!(role, "system" | "developer"));

    let mut augmented = Vec::with_capacity(messages.len() + 1);
    if !leads_with_system {
        augmented.push(serde_json::json!({ "role": "system", "content": "" }));
    }
    augmented.extend_from_slice(messages);

    if let Some(head) = augmented.first_mut().and_then(|m| m.as_object_mut()) {
        head.insert("tools".to_owned(), serde_json::Value::Array(tools.to_vec()));
    }
    Some(augmented)
}
fn apply_deepseek_v32(
messages: &[serde_json::Value],
params: &ChatTemplateParams,
) -> Result<String> {
let owned = inject_tools_into_messages(messages, params.tools);
let msgs: &[serde_json::Value] = owned.as_deref().unwrap_or(messages);
let thinking_mode = derive_thinking_mode(params);
let encode_params = deepseek_v32::EncodeParams {
add_default_bos_token: true,
drop_thinking: resolve_drop_thinking(msgs),
};
deepseek_v32::encode_messages(msgs, thinking_mode, &encode_params)
.map_err(|e| Error::msg(format!("DeepSeek V3.2 encode failed: {e}")))
}
fn apply_deepseek_v4(
messages: &[serde_json::Value],
params: &ChatTemplateParams,
) -> Result<String> {
let owned = inject_tools_into_messages(messages, params.tools);
let msgs: &[serde_json::Value] = owned.as_deref().unwrap_or(messages);
let thinking_mode = derive_thinking_mode(params);
let reasoning_effort = params
.template_kwargs
.and_then(|k| k.get("reasoning_effort"))
.and_then(|v| v.as_str())
.and_then(|s| match s {
"max" => Some(deepseek_v4::ReasoningEffort::Max),
"high" => Some(deepseek_v4::ReasoningEffort::High),
_ => None,
});
let encode_params = deepseek_v4::EncodeParams {
add_default_bos_token: true,
drop_thinking: resolve_drop_thinking(msgs),
reasoning_effort,
};
deepseek_v4::encode_messages(msgs, thinking_mode, &encode_params)
.map_err(|e| Error::msg(format!("DeepSeek V4 encode failed: {e}")))
}