Skip to main content

rlx_cli/
chat.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Chat-template engine for RLX runners.
17//!
18//! Replaces `LlamaModel::apply_chat_template` (llama-cpp-4) end-to-end. Two
19//! sources: an inline Jinja2 string, or `tokenizer.chat_template` (and
20//! `tokenizer.ggml.chat_template`) read directly from a GGUF file's
21//! metadata. Rendering uses `minijinja`.
22//!
23//! BOS/EOS strings are looked up via `tokenizer.ggml.bos_token_id` /
24//! `eos_token_id` against the `tokenizer.ggml.tokens` array (the GGUF
25//! convention).
26
27use anyhow::{Context, Result, anyhow};
28use minijinja::{Environment, Error as JinjaError, ErrorKind, Value};
29use rlx_gguf::{GgufFile, MetaValue};
30use std::path::Path;
31
32/// Convenience for the M3 auto-dispatch family: load the chat template
33/// + BOS/EOS strings directly from a GGUF path.
34///
35/// Alias for [`ChatTemplate::from_gguf`]. Use `rlx_models::run::auto_chat_template(path)`
36/// next to `rlx_models::run::auto_runner(path)`.
37pub fn auto_chat_template(path: &Path) -> Result<ChatTemplate> {
38    ChatTemplate::from_gguf(path)
39}
40
/// One chat turn passed to a chat template.
///
/// `role` is conventionally one of `system`, `user`, `assistant`, or
/// `tool`, but templates can accept any role string. Use
/// [`ChatMessage::new`] for roles without a dedicated constructor.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChatMessage {
    pub role: String,
    pub content: String,
}

impl ChatMessage {
    /// Build a message with an arbitrary `role`.
    pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
        Self {
            role: role.into(),
            content: content.into(),
        }
    }

    /// Build a `user` turn.
    pub fn user(content: impl Into<String>) -> Self {
        Self::new("user", content)
    }

    /// Build a `system` turn.
    pub fn system(content: impl Into<String>) -> Self {
        Self::new("system", content)
    }

    /// Build an `assistant` turn.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::new("assistant", content)
    }

    /// Build a `tool` turn.
    pub fn tool(content: impl Into<String>) -> Self {
        Self::new("tool", content)
    }
}
69
/// Where a [`ChatTemplate`] was loaded from. Useful for diagnostics and
/// for letting a caller round-trip the source string into config.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChatTemplateSource {
    /// Compiled from a caller-supplied Jinja string.
    Inline,
    /// Read from GGUF metadata; holds the metadata key the template came from.
    GgufMetadata(String),
}
77
78/// Compiled Jinja chat template + BOS/EOS strings.
79pub struct ChatTemplate {
80    env: Environment<'static>,
81    source_text: String,
82    source_kind: ChatTemplateSource,
83    bos_token: Option<String>,
84    eos_token: Option<String>,
85}
86
/// Name under which `build_env` registers the chat template and under
/// which `ChatTemplate::render` looks it up again.
const TEMPLATE_NAME: &str = "chat";
88
89fn build_env(source: String) -> Result<Environment<'static>> {
90    let mut env = Environment::new();
91    // HF templates occasionally call `raise_exception(msg)` for invariant
92    // checks (e.g. "system must come first"). Wire it to a Jinja error.
93    env.add_function(
94        "raise_exception",
95        |msg: String| -> Result<Value, JinjaError> {
96            Err(JinjaError::new(ErrorKind::InvalidOperation, msg))
97        },
98    );
99    env.add_template_owned(TEMPLATE_NAME, source)
100        .context("compiling chat template")?;
101    Ok(env)
102}
103
104impl ChatTemplate {
105    /// Compile a chat template from a raw Jinja string.
106    pub fn from_source(src: impl Into<String>) -> Result<Self> {
107        let source_text: String = src.into();
108        let env = build_env(source_text.clone())?;
109        Ok(Self {
110            env,
111            source_text,
112            source_kind: ChatTemplateSource::Inline,
113            bos_token: None,
114            eos_token: None,
115        })
116    }
117
118    /// Override BOS/EOS strings (passed to the template as `bos_token` /
119    /// `eos_token` Jinja variables).
120    pub fn with_tokens(mut self, bos: Option<String>, eos: Option<String>) -> Self {
121        self.bos_token = bos;
122        self.eos_token = eos;
123        self
124    }
125
126    /// Load template + BOS/EOS from a GGUF file. Reads
127    /// `tokenizer.chat_template` first, then `tokenizer.ggml.chat_template`.
128    pub fn from_gguf(path: &Path) -> Result<Self> {
129        let raw = GgufFile::from_path(path).with_context(|| format!("opening GGUF {path:?}"))?;
130        Self::from_gguf_file(&raw)
131    }
132
133    /// Same as [`from_gguf`](Self::from_gguf), but reuses an already-parsed file.
134    pub fn from_gguf_file(raw: &GgufFile) -> Result<Self> {
135        let (key, src) = pick_chat_template_meta(raw).ok_or_else(|| {
136            anyhow!("no tokenizer.chat_template or tokenizer.ggml.chat_template in GGUF metadata")
137        })?;
138        let env = build_env(src.clone())?;
139        let bos = resolve_special_token(raw, "tokenizer.ggml.bos_token_id");
140        let eos = resolve_special_token(raw, "tokenizer.ggml.eos_token_id");
141        Ok(Self {
142            env,
143            source_text: src,
144            source_kind: ChatTemplateSource::GgufMetadata(key.to_owned()),
145            bos_token: bos,
146            eos_token: eos,
147        })
148    }
149
150    pub fn source_text(&self) -> &str {
151        &self.source_text
152    }
153
154    pub fn source_kind(&self) -> &ChatTemplateSource {
155        &self.source_kind
156    }
157
158    pub fn bos_token(&self) -> Option<&str> {
159        self.bos_token.as_deref()
160    }
161
162    pub fn eos_token(&self) -> Option<&str> {
163        self.eos_token.as_deref()
164    }
165
166    /// Render the template with the given messages.
167    ///
168    /// The template sees Jinja variables: `messages` (list of
169    /// `{role, content}` maps), `add_generation_prompt` (bool), and
170    /// `bos_token` / `eos_token` strings (empty if unknown).
171    pub fn render(&self, messages: &[ChatMessage], add_generation_prompt: bool) -> Result<String> {
172        let msgs: Vec<Value> = messages
173            .iter()
174            .map(|m| {
175                Value::from_serialize(serde_json::json!({
176                    "role": m.role,
177                    "content": m.content,
178                }))
179            })
180            .collect();
181        let ctx = minijinja::context! {
182            messages => Value::from(msgs),
183            add_generation_prompt => add_generation_prompt,
184            bos_token => self.bos_token.clone().unwrap_or_default(),
185            eos_token => self.eos_token.clone().unwrap_or_default(),
186        };
187        let tmpl = self
188            .env
189            .get_template(TEMPLATE_NAME)
190            .expect("template registered in build_env");
191        tmpl.render(ctx).context("rendering chat template")
192    }
193}
194
195fn pick_chat_template_meta(raw: &GgufFile) -> Option<(&'static str, String)> {
196    for key in ["tokenizer.chat_template", "tokenizer.ggml.chat_template"] {
197        if let Some(MetaValue::String(s)) = raw.metadata.get(key) {
198            return Some((key, s.clone()));
199        }
200    }
201    None
202}
203
204fn resolve_special_token(raw: &GgufFile, id_key: &str) -> Option<String> {
205    let id = raw.metadata.get(id_key).and_then(MetaValue::as_u32)? as usize;
206    let toks = raw.metadata.get("tokenizer.ggml.tokens")?;
207    let MetaValue::Array(arr) = toks else {
208        return None;
209    };
210    match arr.get(id)? {
211        MetaValue::String(s) => Some(s.clone()),
212        _ => None,
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal Qwen / ChatML-style template — same shape as Qwen3's, simplified
    // enough that test failures point at our rendering plumbing not at
    // upstream Jinja quirks. Whitespace-trim markers are intentionally
    // avoided so the literal `\n` inside the template survives.
    const QWEN_TEMPLATE: &str = "{% for m in messages %}<|im_start|>{{ m.role }}\n{{ m.content }}<|im_end|>\n{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}";

    // Minimal Llama-3-style template using bos_token + headers.
    const LLAMA3_TEMPLATE: &str = "{% for m in messages %}{% if loop.first %}{{ bos_token }}{% endif %}<|start_header_id|>{{ m.role }}<|end_header_id|>\n\n{{ m.content }}<|eot_id|>{% endfor %}{% if add_generation_prompt %}<|start_header_id|>assistant<|end_header_id|>\n\n{% endif %}";

    // Minimal Gemma-style template.
    const GEMMA_TEMPLATE: &str = "{% for m in messages %}{% set role = 'user' if m.role == 'system' else m.role %}<start_of_turn>{{ role }}\n{{ m.content }}<end_of_turn>\n{% endfor %}{% if add_generation_prompt %}<start_of_turn>model\n{% endif %}";

    fn sample_conv() -> Vec<ChatMessage> {
        vec![ChatMessage::system("be concise"), ChatMessage::user("hi")]
    }

    #[test]
    fn qwen_template_renders_with_generation_prompt() {
        let t = ChatTemplate::from_source(QWEN_TEMPLATE).unwrap();
        let out = t.render(&sample_conv(), true).unwrap();
        let expected = "<|im_start|>system\nbe concise<|im_end|>\n\
                        <|im_start|>user\nhi<|im_end|>\n\
                        <|im_start|>assistant\n";
        assert_eq!(out, expected);
    }

    #[test]
    fn qwen_template_omits_generation_prompt_when_disabled() {
        let t = ChatTemplate::from_source(QWEN_TEMPLATE).unwrap();
        let out = t.render(&sample_conv(), false).unwrap();
        assert!(out.ends_with("<|im_end|>\n"));
        assert!(!out.contains("<|im_start|>assistant\n"));
    }

    #[test]
    fn llama3_template_uses_bos_token() {
        let t = ChatTemplate::from_source(LLAMA3_TEMPLATE)
            .unwrap()
            .with_tokens(Some("<|begin_of_text|>".into()), Some("<|eot_id|>".into()));
        let out = t.render(&sample_conv(), true).unwrap();
        let expected = "<|begin_of_text|>\
                        <|start_header_id|>system<|end_header_id|>\n\nbe concise<|eot_id|>\
                        <|start_header_id|>user<|end_header_id|>\n\nhi<|eot_id|>\
                        <|start_header_id|>assistant<|end_header_id|>\n\n";
        assert_eq!(out, expected);
        assert_eq!(t.bos_token(), Some("<|begin_of_text|>"));
        assert_eq!(t.eos_token(), Some("<|eot_id|>"));
    }

    #[test]
    fn gemma_template_rewrites_system_to_user() {
        let t = ChatTemplate::from_source(GEMMA_TEMPLATE).unwrap();
        let out = t.render(&sample_conv(), true).unwrap();
        let expected = "<start_of_turn>user\nbe concise<end_of_turn>\n\
                        <start_of_turn>user\nhi<end_of_turn>\n\
                        <start_of_turn>model\n";
        assert_eq!(out, expected);
    }

    #[test]
    fn raise_exception_propagates_as_error() {
        let t = ChatTemplate::from_source("{{ raise_exception('nope') }}").unwrap();
        let err = t.render(&[], false).unwrap_err();
        assert!(format!("{err:#}").contains("nope"));
    }

    /// Builds a minimal GGUF in a temp file with a chat_template + token
    /// table, then verifies BOS/EOS resolve and rendering works.
    #[test]
    fn from_gguf_reads_template_and_special_tokens() {
        // Removes the fixture on drop, so a failing assertion does not leave
        // it behind in the temp directory.
        struct TempFile(std::path::PathBuf);
        impl Drop for TempFile {
            fn drop(&mut self) {
                let _ = std::fs::remove_file(&self.0);
            }
        }

        // We build a v3 GGUF with four metadata keys:
        //   tokenizer.chat_template      (String)
        //   tokenizer.ggml.tokens        (Array of String)
        //   tokenizer.ggml.bos_token_id  (U32)
        //   tokenizer.ggml.eos_token_id  (U32)
        // and one tiny f32 tensor so the file passes the loader.
        let mut buf: Vec<u8> = Vec::new();
        buf.extend_from_slice(&rlx_gguf::GGUF_MAGIC.to_le_bytes());
        buf.extend_from_slice(&3u32.to_le_bytes());
        buf.extend_from_slice(&1u64.to_le_bytes()); // tensor count
        buf.extend_from_slice(&4u64.to_le_bytes()); // kv count

        let write_string_kv = |buf: &mut Vec<u8>, k: &str, v: &str| {
            buf.extend_from_slice(&(k.len() as u64).to_le_bytes());
            buf.extend_from_slice(k.as_bytes());
            // type = String(8)
            buf.extend_from_slice(&8u32.to_le_bytes());
            buf.extend_from_slice(&(v.len() as u64).to_le_bytes());
            buf.extend_from_slice(v.as_bytes());
        };
        let write_u32_kv = |buf: &mut Vec<u8>, k: &str, v: u32| {
            buf.extend_from_slice(&(k.len() as u64).to_le_bytes());
            buf.extend_from_slice(k.as_bytes());
            // type = U32(4)
            buf.extend_from_slice(&4u32.to_le_bytes());
            buf.extend_from_slice(&v.to_le_bytes());
        };
        let write_string_array_kv = |buf: &mut Vec<u8>, k: &str, items: &[&str]| {
            buf.extend_from_slice(&(k.len() as u64).to_le_bytes());
            buf.extend_from_slice(k.as_bytes());
            // type = Array(9)
            buf.extend_from_slice(&9u32.to_le_bytes());
            // element type = String(8)
            buf.extend_from_slice(&8u32.to_le_bytes());
            // length (u64)
            buf.extend_from_slice(&(items.len() as u64).to_le_bytes());
            for s in items {
                buf.extend_from_slice(&(s.len() as u64).to_le_bytes());
                buf.extend_from_slice(s.as_bytes());
            }
        };

        write_string_kv(&mut buf, "tokenizer.chat_template", QWEN_TEMPLATE);
        write_string_array_kv(
            &mut buf,
            "tokenizer.ggml.tokens",
            &["<pad>", "<bos>", "<eos>", "hi"],
        );
        write_u32_kv(&mut buf, "tokenizer.ggml.bos_token_id", 1);
        write_u32_kv(&mut buf, "tokenizer.ggml.eos_token_id", 2);

        // tiny f32 tensor: name, n_dims = 1, dims = [4], type, data offset 0
        let name = "w";
        buf.extend_from_slice(&(name.len() as u64).to_le_bytes());
        buf.extend_from_slice(name.as_bytes());
        buf.extend_from_slice(&1u32.to_le_bytes());
        buf.extend_from_slice(&4u64.to_le_bytes());
        buf.extend_from_slice(&(rlx_gguf::GgmlType::F32 as u32).to_le_bytes());
        buf.extend_from_slice(&0u64.to_le_bytes());
        while !buf
            .len()
            .is_multiple_of(rlx_gguf::DEFAULT_ALIGNMENT as usize)
        {
            buf.push(0);
        }
        for _ in 0..4 {
            buf.extend_from_slice(&1.0f32.to_le_bytes());
        }

        // Per-process file name so concurrent test runs cannot clobber each other.
        let fixture = TempFile(std::env::temp_dir().join(format!(
            "rlx_chat_template_from_gguf_{}.gguf",
            std::process::id()
        )));
        std::fs::write(&fixture.0, &buf).unwrap();

        let t = ChatTemplate::from_gguf(&fixture.0).expect("from_gguf");
        assert_eq!(t.bos_token(), Some("<bos>"));
        assert_eq!(t.eos_token(), Some("<eos>"));
        let out = t.render(&sample_conv(), true).unwrap();
        assert!(out.contains("<|im_start|>assistant\n"));
        match t.source_kind() {
            ChatTemplateSource::GgufMetadata(k) => assert_eq!(k, "tokenizer.chat_template"),
            other => panic!("unexpected source: {other:?}"),
        }
    }
}