use crate::AprenderError;
use minijinja::{context, Environment};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
/// Upper bound on chat-template source size: 100 * 1024 units
/// (NOTE(review): presumably bytes — confirm at the enforcement site).
pub const MAX_TEMPLATE_SIZE: usize = 100 * 1024;
/// Recursion-depth ceiling intended for template rendering.
pub const MAX_RECURSION_DEPTH: usize = 100;
/// Loop-iteration ceiling intended for template rendering
/// (NOTE(review): enforcement is not visible in this section — confirm where it is applied).
pub const MAX_LOOP_ITERATIONS: usize = 10_000;
/// Defuses chat control markers embedded in user-supplied text.
///
/// Each known marker (ChatML, Llama-style `[INST]`/`<<SYS>>`, `<s>`/`</s>`, ...)
/// is rewritten by inserting a single space after its first character, e.g.
/// `<|im_start|>` becomes `< |im_start|>`. The markers are processed in a fixed
/// order, one full-string pass per marker.
#[must_use]
pub fn sanitize_user_content(content: &str) -> String {
    // Order matters only in that it mirrors the historical replace chain.
    const MARKERS: [&str; 11] = [
        "<|im_start|>",
        "<|im_end|>",
        "<|endoftext|>",
        "<|im_sep|>",
        "<|end|>",
        "<s>",
        "</s>",
        "[INST]",
        "[/INST]",
        "<<SYS>>",
        "<</SYS>>",
    ];
    MARKERS.iter().fold(content.to_string(), |text, marker| {
        // Every marker starts with an ASCII character, so byte index 1 is a
        // valid char boundary.
        let (first, rest) = marker.split_at(1);
        text.replace(marker, &format!("{first} {rest}"))
    })
}
/// Reports whether `content` contains any known chat control marker.
///
/// Checks the same marker set that `sanitize_user_content` defuses; matching
/// is exact and case-sensitive.
#[must_use]
pub fn contains_injection_patterns(content: &str) -> bool {
    const MARKERS: [&str; 11] = [
        "<|im_start|>",
        "<|im_end|>",
        "<|endoftext|>",
        "<|im_sep|>",
        "<|end|>",
        "<s>",
        "</s>",
        "[INST]",
        "[/INST]",
        "<<SYS>>",
        "<</SYS>>",
    ];
    for marker in &MARKERS {
        if content.contains(marker) {
            return true;
        }
    }
    false
}
/// One turn of a chat conversation.
///
/// Serialized with serde as `{ "role": ..., "content": ... }`; a slice of these
/// is passed to templates as `messages`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ChatMessage {
    /// Speaker role, e.g. `"system"`, `"user"` or `"assistant"` (free-form string).
    pub role: String,
    /// Message text.
    pub content: String,
}
impl ChatMessage {
#[must_use]
pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
Self {
role: role.into(),
content: content.into(),
}
}
#[must_use]
pub fn system(content: impl Into<String>) -> Self {
Self::new("system", content)
}
#[must_use]
pub fn user(content: impl Into<String>) -> Self {
Self::new("user", content)
}
#[must_use]
pub fn assistant(content: impl Into<String>) -> Self {
Self::new("assistant", content)
}
}
/// Prompt-layout family a chat template belongs to.
///
/// Serialized with serde using lowercase variant names (e.g. `ChatML` <-> `"chatml"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TemplateFormat {
    /// ChatML layout (`<|im_start|>` / `<|im_end|>` turn markers).
    ChatML,
    /// Llama 2 layout (`[INST]` instruction markers).
    Llama2,
    /// Mistral instruct layout.
    Mistral,
    /// Alpaca layout (`### Instruction:` headers).
    Alpaca,
    /// Phi-family layout.
    Phi,
    /// Fallback for templates whose layout is not recognized.
    Custom,
    /// Raw passthrough — presumably no chat markup at all (see `raw_template.rs`; confirm).
    Raw,
}
/// Special token strings associated with a chat template.
///
/// Every field is optional; `Default` leaves all of them `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SpecialTokens {
    /// Beginning-of-sequence token.
    pub bos_token: Option<String>,
    /// End-of-sequence token.
    pub eos_token: Option<String>,
    /// Unknown-token placeholder.
    pub unk_token: Option<String>,
    /// Padding token.
    pub pad_token: Option<String>,
    /// ChatML turn-start marker (e.g. `<|im_start|>`).
    pub im_start_token: Option<String>,
    /// ChatML turn-end marker (e.g. `<|im_end|>`).
    pub im_end_token: Option<String>,
    /// Instruction-start marker — presumably `[INST]`-style layouts; confirm with users.
    pub inst_start: Option<String>,
    /// Instruction-end marker — presumably `[/INST]`-style layouts; confirm with users.
    pub inst_end: Option<String>,
    /// System-prompt start marker — presumably `<<SYS>>`-style layouts; confirm with users.
    pub sys_start: Option<String>,
    /// System-prompt end marker — presumably `<</SYS>>`-style layouts; confirm with users.
    pub sys_end: Option<String>,
}
/// Common interface implemented by every chat prompt formatter.
pub trait ChatTemplateEngine {
    /// Formats a single message with the given `role` into prompt text.
    ///
    /// # Errors
    ///
    /// Returns an [`AprenderError`] when the implementation cannot format the message.
    fn format_message(&self, role: &str, content: &str) -> Result<String, AprenderError>;
    /// Formats an ordered sequence of messages into a single prompt string.
    ///
    /// # Errors
    ///
    /// Returns an [`AprenderError`] when the implementation cannot format the conversation.
    fn format_conversation(&self, messages: &[ChatMessage]) -> Result<String, AprenderError>;
    /// Special tokens used by this engine.
    fn special_tokens(&self) -> &SpecialTokens;
    /// Template family this engine implements.
    fn format(&self) -> TemplateFormat;
    /// Whether this engine reports support for system-role messages (implementation-defined).
    fn supports_system_prompt(&self) -> bool;
}
#[derive(Debug, Deserialize)]
struct TokenizerConfig {
chat_template: Option<String>,
bos_token: Option<String>,
eos_token: Option<String>,
unk_token: Option<String>,
pad_token: Option<String>,
#[serde(flatten)]
#[allow(dead_code)]
extra: HashMap<String, serde_json::Value>,
}
/// Chat template compiled from a Jinja source string, built either directly via
/// `new` or from a tokenizer config's `chat_template` entry.
pub struct HuggingFaceTemplate {
    /// minijinja environment holding the compiled template under the name `"chat"`.
    env: Environment<'static>,
    /// Original template source (shown in the `Debug` output).
    template_str: String,
    /// Tokens supplied to the template at render time (`bos_token`, `eos_token`).
    special_tokens: SpecialTokens,
    /// Detected or caller-supplied template family.
    format: TemplateFormat,
    /// Value reported by `supports_system_prompt`.
    supports_system: bool,
}
impl std::fmt::Debug for HuggingFaceTemplate {
    /// Prints every field except the minijinja environment, which is elided as `..`.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("HuggingFaceTemplate");
        builder.field("template_str", &self.template_str);
        builder.field("special_tokens", &self.special_tokens);
        builder.field("format", &self.format);
        builder.field("supports_system", &self.supports_system);
        builder.finish_non_exhaustive()
    }
}
impl HuggingFaceTemplate {
pub fn new(
template_str: String,
special_tokens: SpecialTokens,
format: TemplateFormat,
) -> Result<Self, AprenderError> {
let mut env = Environment::new();
env.set_recursion_limit(100);
let mut template = Self {
env,
template_str: template_str.clone(),
special_tokens,
format,
supports_system: true, };
template
.env
.add_template_owned("chat", template_str)
.map_err(|e| AprenderError::ValidationError {
message: format!("Invalid template syntax: {e}"),
})?;
Ok(template)
}
pub fn from_tokenizer_config(path: &Path) -> Result<Self, AprenderError> {
let content = std::fs::read_to_string(path).map_err(AprenderError::Io)?;
Self::from_json(&content)
}
pub fn from_json(json: &str) -> Result<Self, AprenderError> {
let config: TokenizerConfig = serde_json::from_str(json).map_err(|e| {
AprenderError::Serialization(format!("Invalid tokenizer config JSON: {e}"))
})?;
let template_str = config
.chat_template
.ok_or_else(|| AprenderError::ValidationError {
message: "No 'chat_template' found in config".to_string(),
})?;
let special_tokens = SpecialTokens {
bos_token: config.bos_token,
eos_token: config.eos_token,
unk_token: config.unk_token,
pad_token: config.pad_token,
..Default::default()
};
let format = Self::detect_format(&template_str, &special_tokens);
Self::new(template_str, special_tokens, format)
}
fn detect_format(template: &str, _special_tokens: &SpecialTokens) -> TemplateFormat {
if template.contains("<|im_start|>") {
return TemplateFormat::ChatML;
}
if template.contains("[INST]") {
return TemplateFormat::Llama2; }
if template.contains("### Instruction:") {
return TemplateFormat::Alpaca;
}
TemplateFormat::Custom
}
}
impl ChatTemplateEngine for HuggingFaceTemplate {
    /// Renders a single message by treating it as a one-message conversation.
    fn format_message(&self, role: &str, content: &str) -> Result<String, AprenderError> {
        // A one-element array borrows as a slice without the heap allocation
        // that `vec!` would make.
        self.format_conversation(&[ChatMessage::new(role, content)])
    }

    /// Renders `messages` through the compiled `"chat"` template.
    ///
    /// The render context provides `messages`, `add_generation_prompt` (always
    /// `true`), and `bos_token` / `eos_token` (empty strings when unset).
    ///
    /// # Errors
    ///
    /// Returns [`AprenderError::ValidationError`] if the template cannot be
    /// retrieved or rendering fails.
    fn format_conversation(&self, messages: &[ChatMessage]) -> Result<String, AprenderError> {
        let tmpl = self
            .env
            .get_template("chat")
            .map_err(|e| AprenderError::ValidationError {
                message: format!("Template retrieval error: {e}"),
            })?;
        let bos = self.special_tokens.bos_token.as_deref().unwrap_or("");
        let eos = self.special_tokens.eos_token.as_deref().unwrap_or("");
        tmpl.render(context!(
            messages => messages,
            add_generation_prompt => true,
            bos_token => bos,
            eos_token => eos
        ))
        .map_err(|e| AprenderError::ValidationError {
            message: format!("Template render error: {e}"),
        })
    }

    fn special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

    fn format(&self) -> TemplateFormat {
        self.format
    }

    fn supports_system_prompt(&self) -> bool {
        self.supports_system
    }
}
#[derive(Debug, Clone)]
pub struct ChatMLTemplate {
special_tokens: SpecialTokens,
}
impl ChatMLTemplate {
#[must_use]
pub fn new() -> Self {
Self {
special_tokens: SpecialTokens {
bos_token: Some("<|endoftext|>".to_string()),
eos_token: Some("<|im_end|>".to_string()),
im_start_token: Some("<|im_start|>".to_string()),
im_end_token: Some("<|im_end|>".to_string()),
..Default::default()
},
}
}
#[must_use]
pub fn with_tokens(special_tokens: SpecialTokens) -> Self {
Self { special_tokens }
}
}
include!("template.rs");
include!("raw_template.rs");