candle_examples/chat_template.rs

//! Chat template support for LLM examples
//!
//! This module provides Jinja-based chat template rendering compatible with
//! HuggingFace's `tokenizer.apply_chat_template()` functionality.
//!
//! # Example
//!
//! ```no_run
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! use candle_examples::chat_template::{ChatTemplate, Conversation, Message};
//!
//! // Load template from a model's tokenizer_config.json
//! let template = ChatTemplate::from_tokenizer_config("path/to/tokenizer_config.json")?;
//!
//! // Or use a preset for known models
//! let template = ChatTemplate::chatml(); // SmolLM, Qwen, etc.
//!
//! // Single-turn
//! let messages = vec![
//!     Message::system("You are helpful."),
//!     Message::user("Hello!"),
//! ];
//! let prompt = template.apply_for_generation(&messages)?;
//!
//! // Multi-turn conversation
//! let mut conv = Conversation::new(template, "You are helpful.");
//! let prompt = conv.user_turn("Hello!")?;
//! // ... generate response ...
//! conv.assistant_response("Hi there!");
//! let prompt = conv.user_turn("How are you?")?;
//! # Ok(())
//! # }
//! ```
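//!
//! Fetching `tokenizer_config.json` from the Hugging Face Hub first also
//! works; a minimal sketch, assuming the `hf-hub` crate used elsewhere in the
//! candle examples (the model id here is illustrative):
//!
//! ```no_run
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! use candle_examples::chat_template::ChatTemplate;
//!
//! let path = hf_hub::api::sync::Api::new()?
//!     .model("Qwen/Qwen2.5-0.5B-Instruct".to_string())
//!     .get("tokenizer_config.json")?;
//! let template = ChatTemplate::from_tokenizer_config(path)?;
//! # Ok(())
//! # }
//! ```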

use minijinja::Environment;
use serde::{Deserialize, Serialize};
use std::path::Path;

/// A chat message with role and content
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    pub role: String,
    pub content: String,
}

impl Message {
    pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
        Self {
            role: role.into(),
            content: content.into(),
        }
    }

    pub fn system(content: impl Into<String>) -> Self {
        Self::new("system", content)
    }

    pub fn user(content: impl Into<String>) -> Self {
        Self::new("user", content)
    }

    pub fn assistant(content: impl Into<String>) -> Self {
        Self::new("assistant", content)
    }
}

/// Options for applying a chat template
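///
/// A short builder-style sketch:
///
/// ```
/// use candle_examples::chat_template::ChatTemplateOptions;
///
/// let opts = ChatTemplateOptions::for_generation().with_thinking();
/// assert!(opts.add_generation_prompt);
/// assert!(opts.enable_thinking);
/// ```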
#[derive(Debug, Clone, Default)]
pub struct ChatTemplateOptions {
    /// Add tokens that prompt the model to generate an assistant response
    pub add_generation_prompt: bool,
    /// Continue the final message instead of starting a new one (for prefilling)
    pub continue_final_message: bool,
    /// Enable thinking/reasoning mode (adds `<think>` tags)
    pub enable_thinking: bool,
    /// Custom variables to pass to the template
    pub extra_context: std::collections::HashMap<String, String>,
}

impl ChatTemplateOptions {
    pub fn for_generation() -> Self {
        Self {
            add_generation_prompt: true,
            ..Default::default()
        }
    }

    pub fn for_training() -> Self {
        Self {
            add_generation_prompt: false,
            ..Default::default()
        }
    }

    pub fn with_thinking(mut self) -> Self {
        self.enable_thinking = true;
        self
    }
}

/// Token configuration loaded from tokenizer_config.json
#[derive(Debug, Clone, Default, Deserialize)]
pub struct TokenConfig {
    #[serde(default)]
    pub bos_token: Option<StringOrToken>,
    #[serde(default)]
    pub eos_token: Option<StringOrToken>,
    #[serde(default)]
    pub unk_token: Option<StringOrToken>,
    #[serde(default)]
    pub pad_token: Option<StringOrToken>,
    #[serde(default)]
    pub chat_template: Option<ChatTemplateConfig>,
}

/// Handle both string and object token formats in tokenizer_config.json
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub enum StringOrToken {
    String(String),
    Token { content: String },
}

impl StringOrToken {
    pub fn as_str(&self) -> &str {
        match self {
            StringOrToken::String(s) => s,
            StringOrToken::Token { content } => content,
        }
    }
}

impl Default for StringOrToken {
    fn default() -> Self {
        StringOrToken::String(String::new())
    }
}

/// Chat template can be a single string or multiple named templates
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub enum ChatTemplateConfig {
    Single(String),
    Multiple(Vec<NamedTemplate>),
}

#[derive(Debug, Clone, Deserialize)]
pub struct NamedTemplate {
    pub name: String,
    pub template: String,
}

/// Chat template renderer using MiniJinja
pub struct ChatTemplate {
    env: Environment<'static>,
    bos_token: String,
    eos_token: String,
}
impl ChatTemplate {
    /// Create from a Jinja template string
    pub fn new(
        template: impl Into<String>,
        bos_token: impl Into<String>,
        eos_token: impl Into<String>,
    ) -> Result<Self, ChatTemplateError> {
        let mut env = Environment::new();
        // Add the raise_exception function that HF templates use
        env.add_function(
            "raise_exception",
            |msg: String| -> Result<String, minijinja::Error> {
                Err(minijinja::Error::new(
                    minijinja::ErrorKind::InvalidOperation,
                    msg,
                ))
            },
        );

        env.add_template_owned("chat".to_string(), template.into())
            .map_err(|e| ChatTemplateError::TemplateError(e.to_string()))?;

        Ok(Self {
            env,
            bos_token: bos_token.into(),
            eos_token: eos_token.into(),
        })
    }

    /// Load chat template from a tokenizer_config.json file
    pub fn from_tokenizer_config(path: impl AsRef<Path>) -> Result<Self, ChatTemplateError> {
        let content = std::fs::read_to_string(path.as_ref())
            .map_err(|e| ChatTemplateError::IoError(e.to_string()))?;

        Self::from_tokenizer_config_str(&content)
    }

    /// Load chat template from tokenizer_config.json content
    pub fn from_tokenizer_config_str(json: &str) -> Result<Self, ChatTemplateError> {
        let config: TokenConfig =
            serde_json::from_str(json).map_err(|e| ChatTemplateError::ParseError(e.to_string()))?;

        let template = match config.chat_template {
            Some(ChatTemplateConfig::Single(t)) => t,
            Some(ChatTemplateConfig::Multiple(templates)) => {
                // Use the "default" template if available, otherwise the first one
                templates
                    .iter()
                    .find(|t| t.name == "default")
                    .or_else(|| templates.first())
                    .map(|t| t.template.clone())
                    .ok_or(ChatTemplateError::NoTemplate)?
            }
            None => return Err(ChatTemplateError::NoTemplate),
        };

        let bos = config
            .bos_token
            .map(|t| t.as_str().to_string())
            .unwrap_or_default();
        let eos = config
            .eos_token
            .map(|t| t.as_str().to_string())
            .unwrap_or_default();

        Self::new(template, bos, eos)
    }

    /// ChatML template used by SmolLM, Qwen, and many other models
    pub fn chatml() -> Self {
        let template = r#"
{%- for message in messages %}
{{- '<|im_start|>' + message.role + '\n' + message.content | trim + '<|im_end|>\n' }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}
"#;
        Self::new(template, "", "<|im_end|>").unwrap()
    }

    /// ChatML template with thinking/reasoning support
    pub fn chatml_with_thinking() -> Self {
        let template = r#"
{%- for message in messages %}
{{- '<|im_start|>' + message.role + '\n' + message.content | trim + '<|im_end|>\n' }}
{%- endfor %}
{%- if add_generation_prompt %}
{%- if enable_thinking %}
{{- '<|im_start|>assistant\n<think>\n' }}
{%- else %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}
{%- endif %}
"#;
        Self::new(template, "", "<|im_end|>").unwrap()
    }

    /// Llama 2 chat template
    pub fn llama2() -> Self {
        let template = r#"
{%- if messages[0]['role'] == 'system' %}
    {%- set system_message = '<<SYS>>\n' + messages[0]['content'] + '\n<</SYS>>\n\n' %}
    {%- set messages = messages[1:] %}
{%- else %}
    {%- set system_message = '' %}
{%- endif %}
{%- for message in messages %}
    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
        {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
    {%- endif %}
    {%- if loop.index0 == 0 %}
        {{- bos_token + '[INST] ' + system_message + message['content'] + ' [/INST]' }}
    {%- elif message['role'] == 'user' %}
        {{- bos_token + '[INST] ' + message['content'] + ' [/INST]' }}
    {%- elif message['role'] == 'assistant' %}
        {{- ' ' + message['content'] + ' ' + eos_token }}
    {%- endif %}
{%- endfor %}
"#;
        Self::new(template, "<s>", "</s>").unwrap()
    }

    /// Llama 3 / 3.1 chat template
    pub fn llama3() -> Self {
        let template = r#"
{%- set loop_messages = messages %}
{%- for message in loop_messages %}
    {%- set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' %}
    {%- if loop.index0 == 0 %}
        {{- bos_token + content }}
    {%- else %}
        {{- content }}
    {%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{%- endif %}
"#;
        Self::new(template, "<|begin_of_text|>", "<|eot_id|>").unwrap()
    }

    /// Mistral Instruct template
    pub fn mistral() -> Self {
        let template = r#"
{{- bos_token }}
{%- for message in messages %}
    {%- if message['role'] == 'user' %}
        {{- '[INST] ' + message['content'] + ' [/INST]' }}
    {%- elif message['role'] == 'assistant' %}
        {{- ' ' + message['content'] + eos_token }}
    {%- endif %}
{%- endfor %}
"#;
        Self::new(template, "<s>", "</s>").unwrap()
    }

    /// Gemma template
    pub fn gemma() -> Self {
        // Gemma prompts start with the BOS token, as in the other presets.
        let template = r#"
{{- bos_token }}
{%- for message in messages %}
    {%- if message['role'] == 'user' %}
        {{- '<start_of_turn>user\n' + message['content'] + '<end_of_turn>\n' }}
    {%- elif message['role'] == 'assistant' %}
        {{- '<start_of_turn>model\n' + message['content'] + '<end_of_turn>\n' }}
    {%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<start_of_turn>model\n' }}
{%- endif %}
"#;
        Self::new(template, "<bos>", "<eos>").unwrap()
    }

    /// Apply the chat template to messages
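    ///
    /// A short usage sketch:
    ///
    /// ```
    /// use candle_examples::chat_template::{ChatTemplate, ChatTemplateOptions, Message};
    ///
    /// let template = ChatTemplate::chatml();
    /// let messages = vec![Message::user("Hello")];
    /// let prompt = template
    ///     .apply(&messages, &ChatTemplateOptions::for_generation())
    ///     .unwrap();
    /// assert!(prompt.ends_with("<|im_start|>assistant\n"));
    /// ```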
    pub fn apply(
        &self,
        messages: &[Message],
        options: &ChatTemplateOptions,
    ) -> Result<String, ChatTemplateError> {
        let template = self
            .env
            .get_template("chat")
            .map_err(|e| ChatTemplateError::TemplateError(e.to_string()))?;

        // The standard HF template variables, plus any user-supplied
        // `extra_context` entries flattened in as top-level variables.
        #[derive(Serialize)]
        struct RenderContext<'a> {
            messages: &'a [Message],
            add_generation_prompt: bool,
            continue_final_message: bool,
            enable_thinking: bool,
            bos_token: &'a str,
            eos_token: &'a str,
            #[serde(flatten)]
            extra: &'a std::collections::HashMap<String, String>,
        }

        let result = template
            .render(RenderContext {
                messages,
                add_generation_prompt: options.add_generation_prompt,
                continue_final_message: options.continue_final_message,
                enable_thinking: options.enable_thinking,
                bos_token: &self.bos_token,
                eos_token: &self.eos_token,
                extra: &options.extra_context,
            })
            .map_err(|e| ChatTemplateError::RenderError(e.to_string()))?;

        Ok(result.trim_start().to_string())
    }

    /// Convenience method: apply with `add_generation_prompt = true`
    pub fn apply_for_generation(&self, messages: &[Message]) -> Result<String, ChatTemplateError> {
        self.apply(messages, &ChatTemplateOptions::for_generation())
    }
}

/// Multi-turn conversation manager
pub struct Conversation {
    messages: Vec<Message>,
    template: ChatTemplate,
    options: ChatTemplateOptions,
}

impl Conversation {
    /// Create a new conversation with a system prompt
    pub fn new(template: ChatTemplate, system_prompt: impl Into<String>) -> Self {
        Self {
            messages: vec![Message::system(system_prompt)],
            template,
            options: ChatTemplateOptions::for_generation(),
        }
    }

    /// Create without a system prompt
    pub fn without_system(template: ChatTemplate) -> Self {
        Self {
            messages: Vec::new(),
            template,
            options: ChatTemplateOptions::for_generation(),
        }
    }

    /// Set options (e.g., enable thinking mode)
    pub fn with_options(mut self, options: ChatTemplateOptions) -> Self {
        self.options = options;
        self
    }

    /// Add a user message and return the formatted prompt for generation
    pub fn user_turn(&mut self, content: impl Into<String>) -> Result<String, ChatTemplateError> {
        self.messages.push(Message::user(content));
        self.template.apply(&self.messages, &self.options)
    }

    /// Record the assistant's response after generation
    pub fn assistant_response(&mut self, content: impl Into<String>) {
        self.messages.push(Message::assistant(content));
    }

    /// Add a message with a custom role
    pub fn add_message(&mut self, message: Message) {
        self.messages.push(message);
    }

    /// Get the conversation history
    pub fn messages(&self) -> &[Message] {
        &self.messages
    }

    /// Clear conversation history (keeps system prompt if present)
    pub fn clear(&mut self) {
        let keep_system = self
            .messages
            .first()
            .map_or(false, |m| m.role == "system");
        if keep_system {
            self.messages.truncate(1);
        } else {
            self.messages.clear();
        }
    }

    /// Format entire conversation for display (no generation prompt)
    pub fn format_history(&self) -> Result<String, ChatTemplateError> {
        self.template
            .apply(&self.messages, &ChatTemplateOptions::for_training())
    }
}

/// Errors that can occur with chat templates
#[derive(Debug)]
pub enum ChatTemplateError {
    IoError(String),
    ParseError(String),
    TemplateError(String),
    RenderError(String),
    NoTemplate,
}

impl std::fmt::Display for ChatTemplateError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::IoError(e) => write!(f, "IO error: {}", e),
            Self::ParseError(e) => write!(f, "Parse error: {}", e),
            Self::TemplateError(e) => write!(f, "Template error: {}", e),
            Self::RenderError(e) => write!(f, "Render error: {}", e),
            Self::NoTemplate => write!(f, "No chat_template found in config"),
        }
    }
}

impl std::error::Error for ChatTemplateError {}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chatml_basic() {
        let template = ChatTemplate::chatml();
        let messages = vec![Message::system("You are helpful."), Message::user("Hello")];

        let result = template.apply_for_generation(&messages).unwrap();

        assert!(result.contains("<|im_start|>system\nYou are helpful.<|im_end|>"));
        assert!(result.contains("<|im_start|>user\nHello<|im_end|>"));
        assert!(result.ends_with("<|im_start|>assistant\n"));
    }

    #[test]
    fn test_multi_turn_conversation() {
        let mut conv = Conversation::new(ChatTemplate::chatml(), "You are helpful.");

        let prompt1 = conv.user_turn("Hi").unwrap();
        assert!(prompt1.contains("Hi"));

        conv.assistant_response("Hello!");

        let prompt2 = conv.user_turn("How are you?").unwrap();
        assert!(prompt2.contains("Hi"));
        assert!(prompt2.contains("Hello!"));
        assert!(prompt2.contains("How are you?"));
    }

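    // `clear` keeps the system prompt while dropping the rest of the history.
    #[test]
    fn test_clear_keeps_system_prompt() {
        let mut conv = Conversation::new(ChatTemplate::chatml(), "You are helpful.");
        conv.user_turn("Hi").unwrap();
        conv.assistant_response("Hello!");

        conv.clear();

        assert_eq!(conv.messages().len(), 1);
        assert_eq!(conv.messages()[0].role, "system");
    }

    // `format_history` renders the transcript without a trailing generation
    // prompt.
    #[test]
    fn test_format_history_no_generation_prompt() {
        let mut conv = Conversation::new(ChatTemplate::chatml(), "You are helpful.");
        conv.user_turn("Hi").unwrap();
        conv.assistant_response("Hello!");

        let history = conv.format_history().unwrap();
        assert!(history.contains("<|im_start|>assistant\nHello!<|im_end|>"));
        assert!(!history.ends_with("<|im_start|>assistant\n"));
    }
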
    #[test]
    fn test_thinking_mode() {
        let template = ChatTemplate::chatml_with_thinking();
        let messages = vec![Message::user("Think about this")];

        let result = template
            .apply(
                &messages,
                &ChatTemplateOptions::for_generation().with_thinking(),
            )
            .unwrap();

        assert!(result.contains("<think>"));
    }

    #[test]
    fn test_llama3_format() {
        let template = ChatTemplate::llama3();
        let messages = vec![Message::system("You are helpful."), Message::user("Hello")];

        let result = template.apply_for_generation(&messages).unwrap();

        assert!(result.contains("<|begin_of_text|>"));
        assert!(result.contains("<|start_header_id|>system<|end_header_id|>"));
        assert!(result.contains("<|start_header_id|>user<|end_header_id|>"));
        assert!(result.contains("<|eot_id|>"));
    }

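    // The Llama 2 template calls `raise_exception` on malformed role order,
    // which should surface as a render error.
    #[test]
    fn test_llama2_rejects_non_alternating_roles() {
        let template = ChatTemplate::llama2();
        let messages = vec![Message::user("first"), Message::user("second")];

        assert!(template.apply_for_generation(&messages).is_err());
    }

    // Round-trip the Mistral preset over a short multi-turn exchange.
    #[test]
    fn test_mistral_format() {
        let template = ChatTemplate::mistral();
        let messages = vec![
            Message::user("Hello"),
            Message::assistant("Hi!"),
            Message::user("Bye"),
        ];

        let result = template.apply_for_generation(&messages).unwrap();

        assert!(result.starts_with("<s>[INST] Hello [/INST]"));
        assert!(result.contains(" Hi!</s>"));
        assert!(result.ends_with("[INST] Bye [/INST]"));
    }

    // The Gemma preset should emit the BOS token followed by turn markers.
    #[test]
    fn test_gemma_format() {
        let template = ChatTemplate::gemma();
        let messages = vec![Message::user("Hello")];

        let result = template.apply_for_generation(&messages).unwrap();

        assert!(result.starts_with("<bos><start_of_turn>user\nHello<end_of_turn>"));
        assert!(result.ends_with("<start_of_turn>model\n"));
    }
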
    #[test]
    fn test_from_json_config() {
        let json = r#"{
            "bos_token": "<s>",
            "eos_token": "</s>",
            "chat_template": "{% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}"
        }"#;

        let template = ChatTemplate::from_tokenizer_config_str(json).unwrap();
        let messages = vec![Message::user("test")];
        let result = template.apply_for_generation(&messages).unwrap();

        assert!(result.contains("user: test"));
    }
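
    // Some tokenizer_config.json files store tokens as objects rather than
    // plain strings; `StringOrToken` should accept both forms.
    #[test]
    fn test_token_object_format() {
        let json = r#"{
            "bos_token": {"content": "<s>"},
            "eos_token": "</s>",
            "chat_template": "{{ bos_token }}{% for m in messages %}{{ m.content }}{% endfor %}{{ eos_token }}"
        }"#;

        let template = ChatTemplate::from_tokenizer_config_str(json).unwrap();
        let result = template
            .apply(&[Message::user("hi")], &ChatTemplateOptions::default())
            .unwrap();

        assert_eq!(result, "<s>hi</s>");
    }

    // `extra_context` entries are flattened into the render context as
    // top-level template variables.
    #[test]
    fn test_extra_context_variables() {
        let template = ChatTemplate::new(
            "{{ persona }}: {% for m in messages %}{{ m.content }}{% endfor %}",
            "",
            "",
        )
        .unwrap();

        let mut options = ChatTemplateOptions::default();
        options
            .extra_context
            .insert("persona".to_string(), "pirate".to_string());

        let result = template.apply(&[Message::user("ahoy")], &options).unwrap();
        assert_eq!(result, "pirate: ahoy");
    }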
}