Skip to main content

forgellm_runtime/
chat.rs

//! Chat template formatting for instruct models.
//!
//! Formats conversations into the prompt format expected by each model family.
//! The template is chosen from the model's architecture name, following common
//! community conventions.
5
/// A single message in a conversation.
///
/// `role` is copied verbatim into role-aware templates (for example ChatML's
/// `<|im_start|>role` header); the provided constructors use `"system"`,
/// `"user"`, and `"assistant"`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ChatMessage {
    /// Speaker role, e.g. `"system"`, `"user"`, or `"assistant"`.
    pub role: String,
    /// Message text; templates insert it into the prompt without escaping.
    pub content: String,
}
12
13impl ChatMessage {
14    pub fn system(content: impl Into<String>) -> Self {
15        Self {
16            role: "system".into(),
17            content: content.into(),
18        }
19    }
20
21    pub fn user(content: impl Into<String>) -> Self {
22        Self {
23            role: "user".into(),
24            content: content.into(),
25        }
26    }
27
28    pub fn assistant(content: impl Into<String>) -> Self {
29        Self {
30            role: "assistant".into(),
31            content: content.into(),
32        }
33    }
34}
35
/// Prompt template used to render a conversation for a model family.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChatTemplate {
    /// ChatML (used by SmolLM and many Llama fine-tunes): each message is
    /// `<|im_start|>role\ncontent<|im_end|>\n`, followed by an open assistant turn.
    ChatML,
    /// Llama 3 style: `<|begin_of_text|>` once, then
    /// `<|start_header_id|>role<|end_header_id|>\n\ncontent<|eot_id|>` per message,
    /// followed by an open assistant header. Not auto-detected by
    /// `ChatTemplate::from_architecture`; select it explicitly.
    Llama3,
    /// Qwen style; rendered identically to [`ChatTemplate::ChatML`].
    Qwen,
    /// Raw: message contents joined with newlines. Roles are dropped and no
    /// special tokens are added.
    Raw,
}
48
49impl ChatTemplate {
50    /// Detect the appropriate template from a model architecture name.
51    pub fn from_architecture(arch: &str) -> Self {
52        match arch.to_lowercase().as_str() {
53            "llama" => ChatTemplate::ChatML, // SmolLM and many Llama finetunes use ChatML
54            "qwen2" => ChatTemplate::Qwen,
55            "mistral" => ChatTemplate::Raw, // Mistral uses [INST] but we simplify
56            _ => ChatTemplate::ChatML,
57        }
58    }
59
60    /// Format a list of messages into a prompt string.
61    pub fn format(&self, messages: &[ChatMessage]) -> String {
62        match self {
63            ChatTemplate::ChatML | ChatTemplate::Qwen => format_chatml(messages),
64            ChatTemplate::Llama3 => format_llama3(messages),
65            ChatTemplate::Raw => format_raw(messages),
66        }
67    }
68
69    /// Format a single user prompt as a chat conversation.
70    pub fn format_prompt(&self, prompt: &str) -> String {
71        self.format(&[ChatMessage::user(prompt)])
72    }
73
74    /// Format a single user prompt with a system message.
75    pub fn format_with_system(&self, system: &str, prompt: &str) -> String {
76        self.format(&[ChatMessage::system(system), ChatMessage::user(prompt)])
77    }
78}
79
80/// ChatML format: `<|im_start|>role\ncontent<|im_end|>\n`
81fn format_chatml(messages: &[ChatMessage]) -> String {
82    let mut output = String::new();
83    for msg in messages {
84        output.push_str("<|im_start|>");
85        output.push_str(&msg.role);
86        output.push('\n');
87        output.push_str(&msg.content);
88        output.push_str("<|im_end|>\n");
89    }
90    output.push_str("<|im_start|>assistant\n");
91    output
92}
93
94/// Llama 3 format
95fn format_llama3(messages: &[ChatMessage]) -> String {
96    let mut output = String::from("<|begin_of_text|>");
97    for msg in messages {
98        output.push_str("<|start_header_id|>");
99        output.push_str(&msg.role);
100        output.push_str("<|end_header_id|>\n\n");
101        output.push_str(&msg.content);
102        output.push_str("<|eot_id|>");
103    }
104    output.push_str("<|start_header_id|>assistant<|end_header_id|>\n\n");
105    output
106}
107
108/// Raw format: just concatenate
109fn format_raw(messages: &[ChatMessage]) -> String {
110    let mut output = String::new();
111    for msg in messages {
112        if !output.is_empty() {
113            output.push('\n');
114        }
115        output.push_str(&msg.content);
116    }
117    output
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// Exact ChatML rendering of a single "Hello" user turn.
    const CHATML_HELLO: &str = "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n";

    #[test]
    fn message_constructors_set_roles() {
        let cases = [
            (ChatMessage::system("s"), "system"),
            (ChatMessage::user("u"), "user"),
            (ChatMessage::assistant("a"), "assistant"),
        ];
        for (message, role) in &cases {
            assert_eq!(message.role, *role);
        }
        assert_eq!(ChatMessage::user(String::from("owned")).content, "owned");
    }

    #[test]
    fn chatml_single_user() {
        assert_eq!(ChatTemplate::ChatML.format_prompt("Hello"), CHATML_HELLO);
    }

    #[test]
    fn chatml_with_system() {
        let result = ChatTemplate::ChatML.format_with_system("You are helpful.", "Hello");
        assert_eq!(
            result,
            concat!(
                "<|im_start|>system\nYou are helpful.<|im_end|>\n",
                "<|im_start|>user\nHello<|im_end|>\n",
                "<|im_start|>assistant\n"
            )
        );
    }

    #[test]
    fn chatml_multi_turn() {
        let messages = [
            ChatMessage::user("What is Rust?"),
            ChatMessage::assistant("A systems programming language."),
            ChatMessage::user("Tell me more."),
        ];
        assert_eq!(
            ChatTemplate::ChatML.format(&messages),
            concat!(
                "<|im_start|>user\nWhat is Rust?<|im_end|>\n",
                "<|im_start|>assistant\nA systems programming language.<|im_end|>\n",
                "<|im_start|>user\nTell me more.<|im_end|>\n",
                "<|im_start|>assistant\n"
            )
        );
    }

    #[test]
    fn chatml_empty_conversation_opens_assistant_turn() {
        assert_eq!(ChatTemplate::ChatML.format(&[]), "<|im_start|>assistant\n");
    }

    #[test]
    fn qwen_renders_like_chatml() {
        let messages = [ChatMessage::system("Be brief."), ChatMessage::user("Hi")];
        assert_eq!(
            ChatTemplate::Qwen.format(&messages),
            ChatTemplate::ChatML.format(&messages)
        );
    }

    #[test]
    fn llama3_format() {
        assert_eq!(
            ChatTemplate::Llama3.format_prompt("Hello"),
            concat!(
                "<|begin_of_text|>",
                "<|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|>",
                "<|start_header_id|>assistant<|end_header_id|>\n\n"
            )
        );
    }

    #[test]
    fn detect_from_architecture() {
        let cases = [
            ("llama", ChatTemplate::ChatML),
            ("LLaMA", ChatTemplate::ChatML),
            ("qwen2", ChatTemplate::Qwen),
            ("Qwen2", ChatTemplate::Qwen),
            ("mistral", ChatTemplate::Raw),
            ("MISTRAL", ChatTemplate::Raw),
            ("unknown-arch", ChatTemplate::ChatML),
        ];
        for &(arch, expected) in &cases {
            assert_eq!(
                ChatTemplate::from_architecture(arch),
                expected,
                "arch = {}",
                arch
            );
        }
    }

    #[test]
    fn raw_format() {
        let messages = [ChatMessage::user("Hello"), ChatMessage::user("World")];
        assert_eq!(ChatTemplate::Raw.format(&messages), "Hello\nWorld");
    }

    #[test]
    fn raw_drops_roles_and_special_tokens() {
        assert_eq!(ChatTemplate::Raw.format_with_system("Sys", "Hi"), "Sys\nHi");
        assert_eq!(ChatTemplate::Raw.format(&[]), "");
    }
}