use anyhow::{Context, Result, anyhow};
use minijinja::{Environment, Error as JinjaError, ErrorKind, Value};
use rlx_gguf::{GgufFile, MetaValue};
use std::path::Path;
/// Loads the chat template stored in the metadata of the GGUF file at `path`.
///
/// Convenience alias for [`ChatTemplate::from_gguf`]; any error it produces is
/// returned unchanged.
pub fn auto_chat_template(path: &Path) -> Result<ChatTemplate> {
ChatTemplate::from_gguf(path)
}
/// One turn of a conversation passed to [`ChatTemplate::render`].
///
/// When rendering, each message is exposed to the template as a map with the
/// keys `role` and `content`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChatMessage {
    /// Who produced this turn, e.g. `"system"`, `"user"` or `"assistant"`.
    pub role: String,
    /// Text of the turn.
    pub content: String,
}

impl ChatMessage {
    /// Creates a message with an arbitrary `role`.
    ///
    /// Use this for roles that have no dedicated constructor; the role string
    /// is passed to the template as-is.
    pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
        Self {
            role: role.into(),
            content: content.into(),
        }
    }

    /// Creates a message with role `"user"`.
    pub fn user(content: impl Into<String>) -> Self {
        Self::new("user", content)
    }

    /// Creates a message with role `"system"`.
    pub fn system(content: impl Into<String>) -> Self {
        Self::new("system", content)
    }

    /// Creates a message with role `"assistant"`.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::new("assistant", content)
    }
}
/// Records where a [`ChatTemplate`]'s source text came from.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChatTemplateSource {
    /// The template text was supplied directly by the caller
    /// (see `ChatTemplate::from_source`).
    Inline,
    /// The template text was read from GGUF metadata; the payload is the
    /// metadata key it was found under.
    GgufMetadata(String),
}
/// A compiled chat template together with the special tokens it may reference.
///
/// Derives `Debug` and `Clone` like the other public types in this file, so a
/// template can be logged, used with `Result::unwrap_err`, and duplicated.
#[derive(Debug, Clone)]
pub struct ChatTemplate {
    /// Jinja environment holding the compiled template.
    env: Environment<'static>,
    /// Original, uncompiled template text.
    source_text: String,
    /// Where `source_text` was obtained from.
    source_kind: ChatTemplateSource,
    /// Value exposed to the template as `bos_token`; rendered as an empty string when `None`.
    bos_token: Option<String>,
    /// Value exposed to the template as `eos_token`; rendered as an empty string when `None`.
    eos_token: Option<String>,
}
/// Name under which the single chat template is registered in the Jinja
/// environment and later looked up when rendering.
const TEMPLATE_NAME: &str = "chat";
/// Builds the Jinja environment for a chat template and compiles `source` into it.
///
/// The environment is configured to match how Hugging Face `transformers`
/// compiles chat templates, since templates stored in GGUF metadata are
/// authored against that behaviour:
///
/// * `trim_blocks` and `lstrip_blocks` are enabled, so the newline after a
///   block tag and the indentation in front of one do not leak into the
///   rendered prompt.
/// * A `raise_exception(msg)` function is available; calling it from the
///   template aborts rendering with `msg` as the error.
///
/// # Errors
///
/// Returns an error with the context `"compiling chat template"` when `source`
/// cannot be compiled.
fn build_env(source: String) -> Result<Environment<'static>> {
    let mut env = Environment::new();

    // Configure whitespace handling before the template is added so that it
    // applies when the source is compiled.
    env.set_trim_blocks(true);
    env.set_lstrip_blocks(true);

    env.add_function(
        "raise_exception",
        |msg: String| -> Result<Value, JinjaError> {
            Err(JinjaError::new(ErrorKind::InvalidOperation, msg))
        },
    );

    env.add_template_owned(TEMPLATE_NAME, source)
        .context("compiling chat template")?;
    Ok(env)
}
impl ChatTemplate {
    /// Compiles `src` as an inline template.
    ///
    /// The result has no BOS/EOS tokens attached; add them with
    /// [`ChatTemplate::with_tokens`].
    ///
    /// # Errors
    ///
    /// Fails when the template text does not compile.
    pub fn from_source(src: impl Into<String>) -> Result<Self> {
        let text: String = src.into();
        build_env(text.clone()).map(|env| Self {
            env,
            source_text: text,
            source_kind: ChatTemplateSource::Inline,
            bos_token: None,
            eos_token: None,
        })
    }

    /// Replaces both special tokens, consuming and returning the template.
    pub fn with_tokens(self, bos: Option<String>, eos: Option<String>) -> Self {
        Self {
            bos_token: bos,
            eos_token: eos,
            ..self
        }
    }

    /// Opens the GGUF file at `path` and builds a template from its metadata.
    ///
    /// # Errors
    ///
    /// Fails when the file cannot be opened as GGUF, or for any reason
    /// [`ChatTemplate::from_gguf_file`] fails.
    pub fn from_gguf(path: &Path) -> Result<Self> {
        GgufFile::from_path(path)
            .with_context(|| format!("opening GGUF {path:?}"))
            .and_then(|gguf| Self::from_gguf_file(&gguf))
    }

    /// Builds a template from an already-parsed GGUF file.
    ///
    /// The template text comes from the chat-template metadata entry; the BOS
    /// and EOS tokens are resolved from the token-id entries and stay `None`
    /// when they cannot be resolved.
    ///
    /// # Errors
    ///
    /// Fails when no chat-template metadata entry is present or when the
    /// template text does not compile.
    pub fn from_gguf_file(raw: &GgufFile) -> Result<Self> {
        let Some((key, template_text)) = pick_chat_template_meta(raw) else {
            return Err(anyhow!(
                "no tokenizer.chat_template or tokenizer.ggml.chat_template in GGUF metadata"
            ));
        };
        let env = build_env(template_text.clone())?;
        Ok(Self {
            env,
            source_text: template_text,
            source_kind: ChatTemplateSource::GgufMetadata(key.to_owned()),
            bos_token: resolve_special_token(raw, "tokenizer.ggml.bos_token_id"),
            eos_token: resolve_special_token(raw, "tokenizer.ggml.eos_token_id"),
        })
    }

    /// The uncompiled template text.
    pub fn source_text(&self) -> &str {
        self.source_text.as_str()
    }

    /// Where the template text was obtained from.
    pub fn source_kind(&self) -> &ChatTemplateSource {
        &self.source_kind
    }

    /// The BOS token exposed to the template, if one is set.
    pub fn bos_token(&self) -> Option<&str> {
        self.bos_token.as_deref()
    }

    /// The EOS token exposed to the template, if one is set.
    pub fn eos_token(&self) -> Option<&str> {
        self.eos_token.as_deref()
    }

    /// Renders `messages` into a prompt string.
    ///
    /// The template sees four variables: `messages` (a list of maps with
    /// `role` and `content`), `add_generation_prompt`, `bos_token` and
    /// `eos_token`. A missing BOS/EOS token is passed as an empty string.
    ///
    /// # Errors
    ///
    /// Fails when the template raises an error while rendering.
    ///
    /// # Panics
    ///
    /// Panics if the template is not registered in the environment; every
    /// constructor registers it through `build_env`.
    pub fn render(&self, messages: &[ChatMessage], add_generation_prompt: bool) -> Result<String> {
        let mut turns: Vec<Value> = Vec::with_capacity(messages.len());
        for message in messages {
            turns.push(Value::from_serialize(serde_json::json!({
                "role": message.role,
                "content": message.content,
            })));
        }

        let bos = self.bos_token.clone().unwrap_or_default();
        let eos = self.eos_token.clone().unwrap_or_default();
        let ctx = minijinja::context! {
            messages => Value::from(turns),
            add_generation_prompt => add_generation_prompt,
            bos_token => bos,
            eos_token => eos,
        };

        self.env
            .get_template(TEMPLATE_NAME)
            .expect("template registered in build_env")
            .render(ctx)
            .context("rendering chat template")
    }
}
/// Finds the chat template text in the GGUF metadata.
///
/// Checks `tokenizer.chat_template` first and `tokenizer.ggml.chat_template`
/// second, returning the first key whose value is a string together with a
/// copy of that string. Returns `None` when neither key holds a string.
fn pick_chat_template_meta(raw: &GgufFile) -> Option<(&'static str, String)> {
    const CANDIDATE_KEYS: [&str; 2] = ["tokenizer.chat_template", "tokenizer.ggml.chat_template"];

    CANDIDATE_KEYS
        .iter()
        .copied()
        .find_map(|key| match raw.metadata.get(key) {
            Some(MetaValue::String(text)) => Some((key, text.clone())),
            _ => None,
        })
}
/// Looks up the text of the special token whose id is stored under `id_key`.
///
/// Reads the id from the metadata entry `id_key`, then indexes the
/// `tokenizer.ggml.tokens` array with it. Returns `None` when the id entry is
/// missing or not readable as `u32`, the token array is missing or not an
/// array, the id is out of range, or the entry at that index is not a string.
fn resolve_special_token(raw: &GgufFile, id_key: &str) -> Option<String> {
    let index = raw.metadata.get(id_key)?.as_u32()? as usize;

    let MetaValue::Array(vocab) = raw.metadata.get("tokenizer.ggml.tokens")? else {
        return None;
    };

    if let MetaValue::String(token) = vocab.get(index)? {
        Some(token.clone())
    } else {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Simplified templates in the style of three model families, used to check
    // rendering end to end.
    const QWEN_TEMPLATE: &str = "{% for m in messages %}<|im_start|>{{ m.role }}\n{{ m.content }}<|im_end|>\n{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}";
    const LLAMA3_TEMPLATE: &str = "{% for m in messages %}{% if loop.first %}{{ bos_token }}{% endif %}<|start_header_id|>{{ m.role }}<|end_header_id|>\n\n{{ m.content }}<|eot_id|>{% endfor %}{% if add_generation_prompt %}<|start_header_id|>assistant<|end_header_id|>\n\n{% endif %}";
    const GEMMA_TEMPLATE: &str = "{% for m in messages %}{% set role = 'user' if m.role == 'system' else m.role %}<start_of_turn>{{ role }}\n{{ m.content }}<end_of_turn>\n{% endfor %}{% if add_generation_prompt %}<start_of_turn>model\n{% endif %}";

    /// Deletes the wrapped file when dropped, so the fixture is cleaned up even
    /// when an assertion in the test panics.
    struct TempFile(std::path::PathBuf);

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best effort: the file may already be gone.
            std::fs::remove_file(&self.0).ok();
        }
    }

    /// A two-turn conversation: one system message followed by one user message.
    fn sample_conv() -> Vec<ChatMessage> {
        vec![ChatMessage::system("be concise"), ChatMessage::user("hi")]
    }

    #[test]
    fn qwen_template_renders_with_generation_prompt() {
        let t = ChatTemplate::from_source(QWEN_TEMPLATE).unwrap();
        let out = t.render(&sample_conv(), true).unwrap();
        let expected = "<|im_start|>system\nbe concise<|im_end|>\n\
                        <|im_start|>user\nhi<|im_end|>\n\
                        <|im_start|>assistant\n";
        assert_eq!(out, expected);
    }

    #[test]
    fn qwen_template_omits_generation_prompt_when_disabled() {
        let t = ChatTemplate::from_source(QWEN_TEMPLATE).unwrap();
        let out = t.render(&sample_conv(), false).unwrap();
        assert!(out.ends_with("<|im_end|>\n"));
        assert!(!out.contains("<|im_start|>assistant\n"));
    }

    #[test]
    fn llama3_template_uses_bos_token() {
        let t = ChatTemplate::from_source(LLAMA3_TEMPLATE)
            .unwrap()
            .with_tokens(Some("<|begin_of_text|>".into()), Some("<|eot_id|>".into()));
        let out = t.render(&sample_conv(), true).unwrap();
        let expected = "<|begin_of_text|>\
                        <|start_header_id|>system<|end_header_id|>\n\nbe concise<|eot_id|>\
                        <|start_header_id|>user<|end_header_id|>\n\nhi<|eot_id|>\
                        <|start_header_id|>assistant<|end_header_id|>\n\n";
        assert_eq!(out, expected);
        assert_eq!(t.bos_token(), Some("<|begin_of_text|>"));
        assert_eq!(t.eos_token(), Some("<|eot_id|>"));
    }

    #[test]
    fn gemma_template_rewrites_system_to_user() {
        let t = ChatTemplate::from_source(GEMMA_TEMPLATE).unwrap();
        let out = t.render(&sample_conv(), true).unwrap();
        let expected = "<start_of_turn>user\nbe concise<end_of_turn>\n\
                        <start_of_turn>user\nhi<end_of_turn>\n\
                        <start_of_turn>model\n";
        assert_eq!(out, expected);
    }

    #[test]
    fn raise_exception_propagates_as_error() {
        let t = ChatTemplate::from_source("{{ raise_exception('nope') }}").unwrap();
        let err = t.render(&[], false).unwrap_err();
        assert!(format!("{err:#}").contains("nope"));
    }

    #[test]
    fn from_gguf_reads_template_and_special_tokens() {
        // Hand-assemble a minimal GGUF image: header, four metadata entries,
        // one tensor description, padding, then the tensor data.
        let mut buf: Vec<u8> = Vec::new();

        // Header: magic, format version 3, tensor count (1), metadata entry count (4).
        buf.extend_from_slice(&rlx_gguf::GGUF_MAGIC.to_le_bytes());
        buf.extend_from_slice(&3u32.to_le_bytes());
        buf.extend_from_slice(&1u64.to_le_bytes());
        buf.extend_from_slice(&4u64.to_le_bytes());

        // Each metadata entry is: length-prefixed key, value-type tag, value.
        // Tag 8 marks a length-prefixed string value.
        let write_string_kv = |buf: &mut Vec<u8>, k: &str, v: &str| {
            buf.extend_from_slice(&(k.len() as u64).to_le_bytes());
            buf.extend_from_slice(k.as_bytes());
            buf.extend_from_slice(&8u32.to_le_bytes());
            buf.extend_from_slice(&(v.len() as u64).to_le_bytes());
            buf.extend_from_slice(v.as_bytes());
        };
        // Tag 4 marks a little-endian u32 value.
        let write_u32_kv = |buf: &mut Vec<u8>, k: &str, v: u32| {
            buf.extend_from_slice(&(k.len() as u64).to_le_bytes());
            buf.extend_from_slice(k.as_bytes());
            buf.extend_from_slice(&4u32.to_le_bytes());
            buf.extend_from_slice(&v.to_le_bytes());
        };
        // Tag 9 marks an array: element-type tag (8 = string), element count,
        // then the length-prefixed elements.
        let write_string_array_kv = |buf: &mut Vec<u8>, k: &str, items: &[&str]| {
            buf.extend_from_slice(&(k.len() as u64).to_le_bytes());
            buf.extend_from_slice(k.as_bytes());
            buf.extend_from_slice(&9u32.to_le_bytes());
            buf.extend_from_slice(&8u32.to_le_bytes());
            buf.extend_from_slice(&(items.len() as u64).to_le_bytes());
            for s in items {
                buf.extend_from_slice(&(s.len() as u64).to_le_bytes());
                buf.extend_from_slice(s.as_bytes());
            }
        };

        write_string_kv(&mut buf, "tokenizer.chat_template", QWEN_TEMPLATE);
        write_string_array_kv(
            &mut buf,
            "tokenizer.ggml.tokens",
            &["<pad>", "<bos>", "<eos>", "hi"],
        );
        write_u32_kv(&mut buf, "tokenizer.ggml.bos_token_id", 1);
        write_u32_kv(&mut buf, "tokenizer.ggml.eos_token_id", 2);

        // Tensor description: name "w", one dimension of length 4, F32, data offset 0.
        let name = "w";
        buf.extend_from_slice(&(name.len() as u64).to_le_bytes());
        buf.extend_from_slice(name.as_bytes());
        buf.extend_from_slice(&1u32.to_le_bytes());
        buf.extend_from_slice(&4u64.to_le_bytes());
        buf.extend_from_slice(&(rlx_gguf::GgmlType::F32 as u32).to_le_bytes());
        buf.extend_from_slice(&0u64.to_le_bytes());

        // Pad up to the alignment boundary before the tensor data.
        while !buf
            .len()
            .is_multiple_of(rlx_gguf::DEFAULT_ALIGNMENT as usize)
        {
            buf.push(0);
        }
        for _ in 0..4 {
            buf.extend_from_slice(&1.0f32.to_le_bytes());
        }

        // Per-process filename so concurrent test runs sharing a temp dir do
        // not overwrite each other's fixture; the guard removes it on exit.
        let file = TempFile(std::env::temp_dir().join(format!(
            "rlx_chat_template_from_gguf_{}.gguf",
            std::process::id()
        )));
        std::fs::write(&file.0, &buf).unwrap();

        let t = ChatTemplate::from_gguf(&file.0).expect("from_gguf");
        assert_eq!(t.bos_token(), Some("<bos>"));
        assert_eq!(t.eos_token(), Some("<eos>"));
        let out = t.render(&sample_conv(), true).unwrap();
        assert!(out.contains("<|im_start|>assistant\n"));
        match t.source_kind() {
            ChatTemplateSource::GgufMetadata(k) => assert_eq!(k, "tokenizer.chat_template"),
            other => panic!("unexpected source: {other:?}"),
        }
    }
}