use crate::error::RealizarError;
use minijinja::{context, Environment};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Upper bound on template size: `100 * 1024` (presumably bytes, i.e. 100 KiB).
/// NOTE(review): not referenced in the visible part of this file — confirm where it is enforced.
pub const MAX_TEMPLATE_SIZE: usize = 100 * 1024;
/// Recursion limit installed on the minijinja environment by `HuggingFaceTemplate::new`.
pub const MAX_RECURSION_DEPTH: usize = 100;
/// Upper bound on template loop iterations.
/// NOTE(review): not referenced in the visible part of this file — confirm where it is enforced.
pub const MAX_LOOP_ITERATIONS: usize = 10_000;
/// Neutralizes special-token openers in untrusted message text.
///
/// Every occurrence of `<|` is rewritten to `<`, a zero-width space (U+200B),
/// then `|`, so sequences such as `<|im_start|>` no longer match the literal
/// control markers while remaining visually unchanged. Text without `<|` is
/// returned as an identical copy.
#[must_use]
pub fn sanitize_special_tokens(content: &str) -> String {
    const OPENER: &str = "<|";
    const NEUTRALIZED: &str = "<\u{200B}|";

    // Splitting on the opener and re-joining with the neutralized form is
    // equivalent to a left-to-right, non-overlapping replace.
    let mut segments = content.split(OPENER);
    let mut sanitized = String::with_capacity(content.len());
    if let Some(head) = segments.next() {
        sanitized.push_str(head);
    }
    for segment in segments {
        sanitized.push_str(NEUTRALIZED);
        sanitized.push_str(segment);
    }
    sanitized
}
/// A single turn of a chat conversation: who spoke and what was said.
///
/// Derives serde `Serialize`/`Deserialize`; `HuggingFaceTemplate` passes a list
/// of these to the template as the `messages` context variable.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ChatMessage {
/// Speaker label. The convenience constructors in this file use
/// `"system"`, `"user"` and `"assistant"`; any other string is accepted as-is.
pub role: String,
/// Message text.
pub content: String,
}
impl ChatMessage {
    /// Builds a message from an arbitrary role label and body text.
    ///
    /// Both arguments accept anything convertible into `String`.
    #[must_use]
    pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
        let role = role.into();
        let content = content.into();
        Self { role, content }
    }

    /// Builds a message whose role is `"system"`.
    #[must_use]
    pub fn system(content: impl Into<String>) -> Self {
        Self {
            role: String::from("system"),
            content: content.into(),
        }
    }

    /// Builds a message whose role is `"user"`.
    #[must_use]
    pub fn user(content: impl Into<String>) -> Self {
        Self {
            role: String::from("user"),
            content: content.into(),
        }
    }

    /// Builds a message whose role is `"assistant"`.
    #[must_use]
    pub fn assistant(content: impl Into<String>) -> Self {
        Self {
            role: String::from("assistant"),
            content: content.into(),
        }
    }
}
/// Identifies which prompt layout a template engine produces.
///
/// Serialized by serde with lowercase variant names (`rename_all = "lowercase"`).
/// `Raw` is the `Default` value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum TemplateFormat {
/// ChatML layout built around `<|im_start|>` / `<|im_end|>` markers.
ChatML,
/// Qwen3 variant.
/// NOTE(review): presumably ChatML with thinking disabled — confirm against the Qwen3 implementation.
Qwen3NoThink,
/// Llama 2 layout; auto-detection picks this for templates containing `[INST]`.
Llama2,
/// Zephyr layout.
Zephyr,
/// Mistral layout.
Mistral,
/// Alpaca layout; auto-detection picks this for templates containing `### Instruction:`.
Alpaca,
/// Phi layout.
Phi,
/// Template that matched none of the markers checked by auto-detection.
Custom,
/// Default variant.
/// NOTE(review): presumably means no chat template is applied — confirm with callers.
#[default]
Raw,
}
/// Optional special-token strings associated with a chat template.
///
/// Every field defaults to `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SpecialTokens {
/// Beginning-of-sequence token; exposed to HuggingFace templates as `bos_token`
/// (an empty string is substituted when unset).
pub bos_token: Option<String>,
/// End-of-sequence token; exposed to HuggingFace templates as `eos_token`
/// (an empty string is substituted when unset).
pub eos_token: Option<String>,
/// Unknown-token string, as read from the tokenizer config.
pub unk_token: Option<String>,
/// Padding-token string, as read from the tokenizer config.
pub pad_token: Option<String>,
/// Message-start marker; `ChatMLTemplate` sets it to `<|im_start|>`.
pub im_start_token: Option<String>,
/// Message-end marker; `ChatMLTemplate` sets it to `<|im_end|>`.
pub im_end_token: Option<String>,
/// NOTE(review): presumably the opening instruction delimiter for `[INST]`-style formats — confirm.
pub inst_start: Option<String>,
/// NOTE(review): presumably the closing instruction delimiter for `[INST]`-style formats — confirm.
pub inst_end: Option<String>,
/// NOTE(review): presumably the opening system-prompt delimiter — confirm.
pub sys_start: Option<String>,
/// NOTE(review): presumably the closing system-prompt delimiter — confirm.
pub sys_end: Option<String>,
}
/// Common interface for turning chat messages into a model prompt string.
///
/// Implementors must be `Send + Sync`.
pub trait ChatTemplateEngine: Send + Sync {
/// Formats a single message with the given role and content.
///
/// # Errors
/// Returns `RealizarError` when the implementation cannot render the message.
fn format_message(&self, role: &str, content: &str) -> Result<String, RealizarError>;
/// Formats a whole conversation, in order, into one prompt string.
///
/// The implementations visible in this file also append an assistant
/// generation prompt after the last message.
///
/// # Errors
/// Returns `RealizarError` when the implementation cannot render the conversation.
fn format_conversation(&self, messages: &[ChatMessage]) -> Result<String, RealizarError>;
/// Returns the special tokens associated with this template.
fn special_tokens(&self) -> &SpecialTokens;
/// Returns the layout this engine produces.
fn format(&self) -> TemplateFormat;
/// Reports whether the template accepts a system prompt.
fn supports_system_prompt(&self) -> bool;
}
#[derive(Debug, Deserialize)]
struct TokenizerConfig {
chat_template: Option<String>,
bos_token: Option<String>,
eos_token: Option<String>,
unk_token: Option<String>,
pad_token: Option<String>,
#[serde(flatten)]
#[allow(dead_code)]
extra: HashMap<String, serde_json::Value>,
}
/// Chat template engine backed by a Jinja template rendered with minijinja.
pub struct HuggingFaceTemplate {
/// minijinja environment holding the compiled template under the name `"chat"`.
env: Environment<'static>,
/// Original template source, kept for `Debug` output.
template_str: String,
/// Special tokens; `bos_token` / `eos_token` are exposed to the template at render time.
special_tokens: SpecialTokens,
/// Layout reported by `format()`.
format: TemplateFormat,
/// Value reported by `supports_system_prompt()`; `new` always sets it to `true`.
supports_system: bool,
}
/// Manual `Debug` that prints every field except the minijinja environment;
/// the output is marked non-exhaustive to signal the omission.
impl std::fmt::Debug for HuggingFaceTemplate {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("HuggingFaceTemplate");
        builder.field("template_str", &self.template_str);
        builder.field("special_tokens", &self.special_tokens);
        builder.field("format", &self.format);
        builder.field("supports_system", &self.supports_system);
        builder.finish_non_exhaustive()
    }
}
impl HuggingFaceTemplate {
pub fn new(
template_str: String,
special_tokens: SpecialTokens,
format: TemplateFormat,
) -> Result<Self, RealizarError> {
let mut env = Environment::new();
env.set_recursion_limit(MAX_RECURSION_DEPTH);
let mut template = Self {
env,
template_str: template_str.clone(),
special_tokens,
format,
supports_system: true,
};
template
.env
.add_template_owned("chat", template_str)
.map_err(|e| RealizarError::FormatError {
reason: format!("Invalid template syntax: {e}"),
})?;
Ok(template)
}
pub fn from_json(json: &str) -> Result<Self, RealizarError> {
let config: TokenizerConfig =
serde_json::from_str(json).map_err(|e| RealizarError::FormatError {
reason: format!("Invalid tokenizer config: {e}"),
})?;
let template_str = config
.chat_template
.ok_or_else(|| RealizarError::FormatError {
reason: "No 'chat_template' found in config".to_string(),
})?;
let special_tokens = SpecialTokens {
bos_token: config.bos_token,
eos_token: config.eos_token,
unk_token: config.unk_token,
pad_token: config.pad_token,
..Default::default()
};
let format = Self::detect_format(&template_str);
Self::new(template_str, special_tokens, format)
}
fn detect_format(template: &str) -> TemplateFormat {
if template.contains("<|im_start|>") {
return TemplateFormat::ChatML;
}
if template.contains("[INST]") {
return TemplateFormat::Llama2;
}
if template.contains("### Instruction:") {
return TemplateFormat::Alpaca;
}
TemplateFormat::Custom
}
}
impl ChatTemplateEngine for HuggingFaceTemplate {
fn format_message(&self, role: &str, content: &str) -> Result<String, RealizarError> {
let safe_content = sanitize_special_tokens(content);
let messages = vec![ChatMessage::new(role, safe_content)];
self.format_conversation(&messages)
}
fn format_conversation(&self, messages: &[ChatMessage]) -> Result<String, RealizarError> {
let tmpl = self
.env
.get_template("chat")
.map_err(|e| RealizarError::FormatError {
reason: format!("Template error: {e}"),
})?;
let bos = self.special_tokens.bos_token.as_deref().unwrap_or("");
let eos = self.special_tokens.eos_token.as_deref().unwrap_or("");
let sanitized_messages: Vec<ChatMessage> = messages
.iter()
.map(|m| ChatMessage::new(&m.role, sanitize_special_tokens(&m.content)))
.collect();
let output = tmpl
.render(context!(
messages => sanitized_messages,
add_generation_prompt => true,
bos_token => bos,
eos_token => eos
))
.map_err(|e| RealizarError::FormatError {
reason: format!("Render error: {e}"),
})?;
Ok(output)
}
fn special_tokens(&self) -> &SpecialTokens {
&self.special_tokens
}
fn format(&self) -> TemplateFormat {
self.format
}
fn supports_system_prompt(&self) -> bool {
self.supports_system
}
}
/// Hard-coded ChatML engine that wraps each message in
/// `<|im_start|>` / `<|im_end|>` markers without using a Jinja template.
#[derive(Debug, Clone)]
pub struct ChatMLTemplate {
/// Token set returned by `special_tokens()`; populated by `ChatMLTemplate::new`.
special_tokens: SpecialTokens,
}
impl ChatMLTemplate {
#[must_use]
pub fn new() -> Self {
Self {
special_tokens: SpecialTokens {
bos_token: Some("<|endoftext|>".to_string()),
eos_token: Some("<|im_end|>".to_string()),
im_start_token: Some("<|im_start|>".to_string()),
im_end_token: Some("<|im_end|>".to_string()),
..Default::default()
},
}
}
}
impl Default for ChatMLTemplate {
fn default() -> Self {
Self::new()
}
}
impl ChatTemplateEngine for ChatMLTemplate {
    /// Wraps one message as `<|im_start|>{role}\n{content}<|im_end|>\n`.
    ///
    /// The content is passed through `sanitize_special_tokens`; the role is
    /// inserted unchanged. This implementation never returns an error.
    fn format_message(&self, role: &str, content: &str) -> Result<String, RealizarError> {
        let safe_content = sanitize_special_tokens(content);
        let framed = format!("<|im_start|>{role}\n{safe_content}<|im_end|>\n");
        Ok(framed)
    }

    /// Concatenates every message in ChatML framing and appends the
    /// assistant generation prompt `<|im_start|>assistant\n`.
    ///
    /// An empty conversation yields just the generation prompt.
    fn format_conversation(&self, messages: &[ChatMessage]) -> Result<String, RealizarError> {
        // Each turn uses exactly the single-message framing, so delegate.
        let mut prompt = messages
            .iter()
            .map(|turn| self.format_message(&turn.role, &turn.content))
            .collect::<Result<String, RealizarError>>()?;
        prompt.push_str("<|im_start|>assistant\n");
        Ok(prompt)
    }

    /// Returns the fixed ChatML token set.
    fn special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

    /// Always `TemplateFormat::ChatML`.
    fn format(&self) -> TemplateFormat {
        TemplateFormat::ChatML
    }

    /// Always `true`.
    fn supports_system_prompt(&self) -> bool {
        true
    }
}
// The rest of this module is split across sibling files. `include!` splices
// each file in textually at this point, so the items they define share this
// file's scope and imports.
include!("chat_template_qwen3_nothink.rs");
include!("chat_template_llama2.rs");
include!("chat_template_helpers.rs");
include!("chat_template_special_tokens.rs");
include!("chat_template_prop_format.rs");
include!("chat_template_alpaca_format.rs");
include!("chat_template_contract_tests.rs");