1use anyhow::{Context, Result, anyhow};
28use minijinja::{Environment, Error as JinjaError, ErrorKind, Value};
29use rlx_gguf::{GgufFile, MetaValue};
30use std::path::Path;
31
32pub fn auto_chat_template(path: &Path) -> Result<ChatTemplate> {
38 ChatTemplate::from_gguf(path)
39}
40
/// A single conversation turn: who is speaking (`role`) and what was said
/// (`content`).
///
/// Messages compare by value, so they can be used with `==` and `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChatMessage {
    /// Speaker role. The constructors on this type produce `"user"`,
    /// `"system"` and `"assistant"`; the field is public, so any other string
    /// can be set directly.
    pub role: String,
    /// Text of the turn.
    pub content: String,
}

impl ChatMessage {
    /// Shared constructor behind the role-specific helpers.
    fn with_role(role: &str, content: impl Into<String>) -> Self {
        Self {
            role: role.to_owned(),
            content: content.into(),
        }
    }

    /// Creates a message whose role is `"user"`.
    pub fn user(content: impl Into<String>) -> Self {
        Self::with_role("user", content)
    }

    /// Creates a message whose role is `"system"`.
    pub fn system(content: impl Into<String>) -> Self {
        Self::with_role("system", content)
    }

    /// Creates a message whose role is `"assistant"`.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::with_role("assistant", content)
    }
}
69
/// Records where a chat template's source text came from.
///
/// Variants compare by value, so callers can check provenance with `==`
/// instead of pattern matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChatTemplateSource {
    /// The template text was supplied directly by the caller.
    Inline,
    /// The template text was read from GGUF metadata; the payload is the
    /// metadata key it was found under.
    GgufMetadata(String),
}
77
78pub struct ChatTemplate {
80 env: Environment<'static>,
81 source_text: String,
82 source_kind: ChatTemplateSource,
83 bos_token: Option<String>,
84 eos_token: Option<String>,
85}
86
87const TEMPLATE_NAME: &str = "chat";
88
89fn build_env(source: String) -> Result<Environment<'static>> {
90 let mut env = Environment::new();
91 env.add_function(
94 "raise_exception",
95 |msg: String| -> Result<Value, JinjaError> {
96 Err(JinjaError::new(ErrorKind::InvalidOperation, msg))
97 },
98 );
99 env.add_template_owned(TEMPLATE_NAME, source)
100 .context("compiling chat template")?;
101 Ok(env)
102}
103
104impl ChatTemplate {
105 pub fn from_source(src: impl Into<String>) -> Result<Self> {
107 let source_text: String = src.into();
108 let env = build_env(source_text.clone())?;
109 Ok(Self {
110 env,
111 source_text,
112 source_kind: ChatTemplateSource::Inline,
113 bos_token: None,
114 eos_token: None,
115 })
116 }
117
118 pub fn with_tokens(mut self, bos: Option<String>, eos: Option<String>) -> Self {
121 self.bos_token = bos;
122 self.eos_token = eos;
123 self
124 }
125
126 pub fn from_gguf(path: &Path) -> Result<Self> {
129 let raw = GgufFile::from_path(path).with_context(|| format!("opening GGUF {path:?}"))?;
130 Self::from_gguf_file(&raw)
131 }
132
133 pub fn from_gguf_file(raw: &GgufFile) -> Result<Self> {
135 let (key, src) = pick_chat_template_meta(raw).ok_or_else(|| {
136 anyhow!("no tokenizer.chat_template or tokenizer.ggml.chat_template in GGUF metadata")
137 })?;
138 let env = build_env(src.clone())?;
139 let bos = resolve_special_token(raw, "tokenizer.ggml.bos_token_id");
140 let eos = resolve_special_token(raw, "tokenizer.ggml.eos_token_id");
141 Ok(Self {
142 env,
143 source_text: src,
144 source_kind: ChatTemplateSource::GgufMetadata(key.to_owned()),
145 bos_token: bos,
146 eos_token: eos,
147 })
148 }
149
150 pub fn source_text(&self) -> &str {
151 &self.source_text
152 }
153
154 pub fn source_kind(&self) -> &ChatTemplateSource {
155 &self.source_kind
156 }
157
158 pub fn bos_token(&self) -> Option<&str> {
159 self.bos_token.as_deref()
160 }
161
162 pub fn eos_token(&self) -> Option<&str> {
163 self.eos_token.as_deref()
164 }
165
166 pub fn render(&self, messages: &[ChatMessage], add_generation_prompt: bool) -> Result<String> {
172 let msgs: Vec<Value> = messages
173 .iter()
174 .map(|m| {
175 Value::from_serialize(serde_json::json!({
176 "role": m.role,
177 "content": m.content,
178 }))
179 })
180 .collect();
181 let ctx = minijinja::context! {
182 messages => Value::from(msgs),
183 add_generation_prompt => add_generation_prompt,
184 bos_token => self.bos_token.clone().unwrap_or_default(),
185 eos_token => self.eos_token.clone().unwrap_or_default(),
186 };
187 let tmpl = self
188 .env
189 .get_template(TEMPLATE_NAME)
190 .expect("template registered in build_env");
191 tmpl.render(ctx).context("rendering chat template")
192 }
193}
194
195fn pick_chat_template_meta(raw: &GgufFile) -> Option<(&'static str, String)> {
196 for key in ["tokenizer.chat_template", "tokenizer.ggml.chat_template"] {
197 if let Some(MetaValue::String(s)) = raw.metadata.get(key) {
198 return Some((key, s.clone()));
199 }
200 }
201 None
202}
203
204fn resolve_special_token(raw: &GgufFile, id_key: &str) -> Option<String> {
205 let id = raw.metadata.get(id_key).and_then(MetaValue::as_u32)? as usize;
206 let toks = raw.metadata.get("tokenizer.ggml.tokens")?;
207 let MetaValue::Array(arr) = toks else {
208 return None;
209 };
210 match arr.get(id)? {
211 MetaValue::String(s) => Some(s.clone()),
212 _ => None,
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Qwen-style template using `<|im_start|>` / `<|im_end|>` markers.
    const QWEN_TEMPLATE: &str = "{% for m in messages %}<|im_start|>{{ m.role }}\n{{ m.content }}<|im_end|>\n{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}";

    /// Llama-3-style template; emits `bos_token` before the first message.
    const LLAMA3_TEMPLATE: &str = "{% for m in messages %}{% if loop.first %}{{ bos_token }}{% endif %}<|start_header_id|>{{ m.role }}<|end_header_id|>\n\n{{ m.content }}<|eot_id|>{% endfor %}{% if add_generation_prompt %}<|start_header_id|>assistant<|end_header_id|>\n\n{% endif %}";

    /// Gemma-style template; rewrites the `system` role to `user`.
    const GEMMA_TEMPLATE: &str = "{% for m in messages %}{% set role = 'user' if m.role == 'system' else m.role %}<start_of_turn>{{ role }}\n{{ m.content }}<end_of_turn>\n{% endfor %}{% if add_generation_prompt %}<start_of_turn>model\n{% endif %}";

    // GGUF metadata value-type tags written by the helpers below.
    const GGUF_TYPE_UINT32: u32 = 4;
    const GGUF_TYPE_STRING: u32 = 8;
    const GGUF_TYPE_ARRAY: u32 = 9;

    /// Two-turn conversation (system + user) shared by the rendering tests.
    fn sample_conv() -> Vec<ChatMessage> {
        vec![ChatMessage::system("be concise"), ChatMessage::user("hi")]
    }

    /// Appends a length-prefixed string: u64 little-endian byte count, then
    /// the raw bytes.
    fn put_str(buf: &mut Vec<u8>, s: &str) {
        buf.extend_from_slice(&(s.len() as u64).to_le_bytes());
        buf.extend_from_slice(s.as_bytes());
    }

    /// Appends a metadata entry whose value is a string.
    fn put_string_kv(buf: &mut Vec<u8>, key: &str, value: &str) {
        put_str(buf, key);
        buf.extend_from_slice(&GGUF_TYPE_STRING.to_le_bytes());
        put_str(buf, value);
    }

    /// Appends a metadata entry whose value is a `u32`.
    fn put_u32_kv(buf: &mut Vec<u8>, key: &str, value: u32) {
        put_str(buf, key);
        buf.extend_from_slice(&GGUF_TYPE_UINT32.to_le_bytes());
        buf.extend_from_slice(&value.to_le_bytes());
    }

    /// Appends a metadata entry whose value is an array of strings.
    fn put_string_array_kv(buf: &mut Vec<u8>, key: &str, items: &[&str]) {
        put_str(buf, key);
        buf.extend_from_slice(&GGUF_TYPE_ARRAY.to_le_bytes());
        // Element type, then element count, then the elements themselves.
        buf.extend_from_slice(&GGUF_TYPE_STRING.to_le_bytes());
        buf.extend_from_slice(&(items.len() as u64).to_le_bytes());
        for item in items {
            put_str(buf, item);
        }
    }

    #[test]
    fn qwen_template_renders_with_generation_prompt() {
        let t = ChatTemplate::from_source(QWEN_TEMPLATE).unwrap();
        let out = t.render(&sample_conv(), true).unwrap();
        let expected = concat!(
            "<|im_start|>system\nbe concise<|im_end|>\n",
            "<|im_start|>user\nhi<|im_end|>\n",
            "<|im_start|>assistant\n",
        );
        assert_eq!(out, expected);
    }

    #[test]
    fn qwen_template_omits_generation_prompt_when_disabled() {
        let t = ChatTemplate::from_source(QWEN_TEMPLATE).unwrap();
        let out = t.render(&sample_conv(), false).unwrap();
        assert!(out.ends_with("<|im_end|>\n"));
        assert!(!out.contains("<|im_start|>assistant\n"));
    }

    #[test]
    fn llama3_template_uses_bos_token() {
        let t = ChatTemplate::from_source(LLAMA3_TEMPLATE)
            .unwrap()
            .with_tokens(Some("<|begin_of_text|>".into()), Some("<|eot_id|>".into()));
        let out = t.render(&sample_conv(), true).unwrap();
        let expected = concat!(
            "<|begin_of_text|>",
            "<|start_header_id|>system<|end_header_id|>\n\nbe concise<|eot_id|>",
            "<|start_header_id|>user<|end_header_id|>\n\nhi<|eot_id|>",
            "<|start_header_id|>assistant<|end_header_id|>\n\n",
        );
        assert_eq!(out, expected);
        assert_eq!(t.bos_token(), Some("<|begin_of_text|>"));
        assert_eq!(t.eos_token(), Some("<|eot_id|>"));
    }

    #[test]
    fn gemma_template_rewrites_system_to_user() {
        let t = ChatTemplate::from_source(GEMMA_TEMPLATE).unwrap();
        let out = t.render(&sample_conv(), true).unwrap();
        let expected = concat!(
            "<start_of_turn>user\nbe concise<end_of_turn>\n",
            "<start_of_turn>user\nhi<end_of_turn>\n",
            "<start_of_turn>model\n",
        );
        assert_eq!(out, expected);
    }

    #[test]
    fn raise_exception_propagates_as_error() {
        let t = ChatTemplate::from_source("{{ raise_exception('nope') }}").unwrap();
        let err = t.render(&[], false).unwrap_err();
        assert!(format!("{err:#}").contains("nope"));
    }

    #[test]
    fn from_gguf_reads_template_and_special_tokens() {
        let mut buf: Vec<u8> = Vec::new();

        // Header: magic, version 3, one tensor, four metadata entries.
        buf.extend_from_slice(&rlx_gguf::GGUF_MAGIC.to_le_bytes());
        buf.extend_from_slice(&3u32.to_le_bytes());
        buf.extend_from_slice(&1u64.to_le_bytes());
        buf.extend_from_slice(&4u64.to_le_bytes());

        // Metadata: the template, a four-token vocabulary, and BOS/EOS ids
        // pointing at entries 1 and 2 of that vocabulary.
        put_string_kv(&mut buf, "tokenizer.chat_template", QWEN_TEMPLATE);
        put_string_array_kv(
            &mut buf,
            "tokenizer.ggml.tokens",
            &["<pad>", "<bos>", "<eos>", "hi"],
        );
        put_u32_kv(&mut buf, "tokenizer.ggml.bos_token_id", 1);
        put_u32_kv(&mut buf, "tokenizer.ggml.eos_token_id", 2);

        // Tensor info for one tensor "w": 1 dimension of length 4, F32, data
        // offset 0.
        put_str(&mut buf, "w");
        buf.extend_from_slice(&1u32.to_le_bytes());
        buf.extend_from_slice(&4u64.to_le_bytes());
        buf.extend_from_slice(&(rlx_gguf::GgmlType::F32 as u32).to_le_bytes());
        buf.extend_from_slice(&0u64.to_le_bytes());

        // Pad to the alignment boundary, then write the four f32 values.
        while !buf
            .len()
            .is_multiple_of(rlx_gguf::DEFAULT_ALIGNMENT as usize)
        {
            buf.push(0);
        }
        for _ in 0..4 {
            buf.extend_from_slice(&1.0f32.to_le_bytes());
        }

        // Per-process file name so concurrent test runs cannot clobber each
        // other's fixture.
        let file_name = format!("rlx_chat_template_from_gguf_{}.gguf", std::process::id());
        let path = std::env::temp_dir().join(file_name);
        std::fs::write(&path, &buf).unwrap();

        // Delete the fixture as soon as it has been read, before anything
        // below can panic, so a failing assertion does not leak the file.
        let loaded = ChatTemplate::from_gguf(&path);
        std::fs::remove_file(&path).ok();
        let t = loaded.expect("from_gguf");

        assert_eq!(t.bos_token(), Some("<bos>"));
        assert_eq!(t.eos_token(), Some("<eos>"));
        let out = t.render(&sample_conv(), true).unwrap();
        assert!(out.contains("<|im_start|>assistant\n"));
        match t.source_kind() {
            ChatTemplateSource::GgufMetadata(k) => assert_eq!(k, "tokenizer.chat_template"),
            other => panic!("unexpected source: {other:?}"),
        }
    }
}