use std::path::Path;
use anyhow::{Context, Result};
use minijinja::Environment;
use serde::{Deserialize, Serialize};
/// Speaker of a single chat message.
///
/// Serialized with serde as the lowercase variant name (`"system"`, `"user"`,
/// `"assistant"`, `"tool"`).
///
/// The enum carries no data, so it is `Copy` and `Hash` in addition to being
/// comparable; this lets callers pass roles by value and use them as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChatRole {
    /// Serialized as `"system"`.
    System,
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"assistant"`.
    Assistant,
    /// Serialized as `"tool"`.
    Tool,
}
impl std::fmt::Display for ChatRole {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ChatRole::System => f.write_str("system"),
ChatRole::User => f.write_str("user"),
ChatRole::Assistant => f.write_str("assistant"),
ChatRole::Tool => f.write_str("tool"),
}
}
}
/// One message of a conversation: who said it and what was said.
///
/// Derives `PartialEq`/`Eq` so conversations can be compared directly
/// (for example in tests or when de-duplicating history).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Speaker of this message.
    pub role: ChatRole,
    /// Text of the message.
    pub content: String,
}
impl ChatMessage {
pub fn system(content: impl Into<String>) -> Self {
Self {
role: ChatRole::System,
content: content.into(),
}
}
pub fn user(content: impl Into<String>) -> Self {
Self {
role: ChatRole::User,
content: content.into(),
}
}
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: ChatRole::Assistant,
content: content.into(),
}
}
}
pub struct ChatTemplate {
template_src: String,
}
impl ChatTemplate {
pub fn from_tokenizer_config(path: &Path) -> Result<Self> {
let text = std::fs::read_to_string(path).context("Failed to read tokenizer_config.json")?;
let config: serde_json::Value =
serde_json::from_str(&text).context("Invalid tokenizer_config.json")?;
let template_src = config["chat_template"]
.as_str()
.context("No chat_template found in tokenizer_config.json")?
.to_owned();
Ok(Self { template_src })
}
pub fn from_template(template: impl Into<String>) -> Self {
Self {
template_src: template.into(),
}
}
pub fn render(&self, messages: &[ChatMessage], add_generation_prompt: bool) -> Result<String> {
let mut env = Environment::new();
env.add_template("chat", &self.template_src)
.map_err(|e| anyhow::anyhow!("Template parse error: {e}"))?;
let tmpl = env
.get_template("chat")
.map_err(|e| anyhow::anyhow!("Template load error: {e}"))?;
let messages_val: Vec<serde_json::Value> = messages
.iter()
.map(|m| {
serde_json::json!({
"role": m.role.to_string(),
"content": m.content,
})
})
.collect();
let ctx = serde_json::json!({
"messages": messages_val,
"add_generation_prompt": add_generation_prompt,
"bos_token": "<s>",
"eos_token": "</s>",
});
tmpl.render(ctx)
.map_err(|e| anyhow::anyhow!("Template render error: {e}"))
}
}
/// Ready-made chat templates for common prompt formats.
///
/// Every template reads the `messages` variable; all except [`LLAMA2`] also
/// honour `add_generation_prompt`.
pub mod builtin {
    /// ChatML: `<|im_start|>{role}\n{content}<|im_end|>\n` per message.
    pub const CHATML: &str = concat!(
        "{% for message in messages %}",
        "<|im_start|>{{ message['role'] }}\n{{ message['content'] }}<|im_end|>\n",
        "{% endfor %}",
        "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}",
    );

    /// Llama 3: `<|begin_of_text|>` followed by header/`<|eot_id|>` delimited turns.
    pub const LLAMA3: &str = concat!(
        "<|begin_of_text|>",
        "{% for message in messages %}",
        "<|start_header_id|>{{ message['role'] }}<|end_header_id|>\n\n",
        "{{ message['content'] }}<|eot_id|>",
        "{% endfor %}",
        "{% if add_generation_prompt %}",
        "<|start_header_id|>assistant<|end_header_id|>\n\n",
        "{% endif %}",
    );

    /// Llama 2: `[INST] ... [/INST]` user turns, assistant turns closed by `</s>`.
    ///
    /// A leading system message is emitted as
    /// `[INST] <<SYS>>\n{system}\n<</SYS>>\n\n`. When the message right after it
    /// is a user message, that message continues the already-open `[INST]`
    /// block instead of opening a second one. The `messages and ...` guard
    /// avoids indexing into an empty message list.
    pub const LLAMA2: &str = concat!(
        "{% if messages and messages[0]['role'] == 'system' %}",
        "{{ '[INST] <<SYS>>\n' + messages[0]['content'] + '\n<</SYS>>\n\n' }}",
        "{% set messages = messages[1:] %}",
        "{% set inst_open = true %}",
        "{% else %}",
        "{% set inst_open = false %}",
        "{% endif %}",
        "{% for message in messages %}",
        "{% if message['role'] == 'user' %}",
        "{{ ('' if loop.first and inst_open else '[INST] ') + message['content'] + ' [/INST]' }}",
        "{% elif message['role'] == 'assistant' %}{{ message['content'] + '</s>' }}",
        "{% endif %}",
        "{% endfor %}",
    );

    /// Gemma: `<start_of_turn>` / `<end_of_turn>` turns; the assistant is named `model`.
    ///
    /// Only `user` and `assistant` messages produce output.
    pub const GEMMA: &str = concat!(
        "{% for message in messages %}",
        "{% if message['role'] == 'user' %}<start_of_turn>user\n{{ message['content'] }}<end_of_turn>\n",
        "{% elif message['role'] == 'assistant' %}<start_of_turn>model\n{{ message['content'] }}<end_of_turn>\n",
        "{% endif %}",
        "{% endfor %}",
        "{% if add_generation_prompt %}<start_of_turn>model\n{% endif %}",
    );

    /// Zephyr: `<|system|>`, `<|user|>` and `<|assistant|>` turns, each closed by `</s>`.
    pub const ZEPHYR: &str = concat!(
        "{% for message in messages %}",
        "{% if message['role'] == 'system' %}",
        "<|system|>\n{{ message['content'] }}</s>\n",
        "{% elif message['role'] == 'user' %}",
        "<|user|>\n{{ message['content'] }}</s>\n",
        "{% elif message['role'] == 'assistant' %}",
        "<|assistant|>\n{{ message['content'] }}</s>\n",
        "{% endif %}",
        "{% endfor %}",
        "{% if add_generation_prompt %}<|assistant|>\n{% endif %}",
    );
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders `conversation` through `template` with the generation prompt enabled.
    fn render_with_prompt(template: &str, conversation: &[ChatMessage]) -> String {
        ChatTemplate::from_template(template)
            .render(conversation, true)
            .unwrap()
    }

    /// ChatML output contains an opening marker for every turn plus the prompt.
    #[test]
    fn chatml_render() {
        let conversation = [
            ChatMessage::system("You are a helpful assistant."),
            ChatMessage::user("Hello!"),
        ];
        let rendered = render_with_prompt(builtin::CHATML, &conversation);
        for marker in [
            "<|im_start|>system",
            "<|im_start|>user",
            "<|im_start|>assistant",
        ] {
            assert!(rendered.contains(marker));
        }
    }

    /// Llama 3 output starts the text and wraps the user role in header markers.
    #[test]
    fn llama3_render() {
        let conversation = [ChatMessage::user("What is 2+2?")];
        let rendered = render_with_prompt(builtin::LLAMA3, &conversation);
        for marker in [
            "<|begin_of_text|>",
            "<|start_header_id|>user<|end_header_id|>",
        ] {
            assert!(rendered.contains(marker));
        }
    }
}