Skip to main content

forgellm_runtime/
chat.rs

//! Chat template formatting for instruct models.
//!
//! Formats conversations into the prompt format expected by each model family.
//! Templates are based on the model's architecture and common conventions.
5
/// A single message in a conversation.
///
/// Messages compare by value: two messages are equal when both `role` and
/// `content` are equal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChatMessage {
    /// Speaker of the message, e.g. `"system"`, `"user"`, or `"assistant"`.
    pub role: String,
    /// Text of the message.
    pub content: String,
}

impl ChatMessage {
    /// Creates a message with the `system` role.
    pub fn system(content: impl Into<String>) -> Self {
        Self::with_role("system", content)
    }

    /// Creates a message with the `user` role.
    pub fn user(content: impl Into<String>) -> Self {
        Self::with_role("user", content)
    }

    /// Creates a message with the `assistant` role.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::with_role("assistant", content)
    }

    /// Shared constructor: pairs a fixed role name with caller-supplied content.
    fn with_role(role: &str, content: impl Into<String>) -> Self {
        Self {
            role: role.to_owned(),
            content: content.into(),
        }
    }
}
35
/// Chat template format.
///
/// Each variant names a prompt layout used to serialize a conversation. The
/// enum is a plain, fieldless value type, so it is `Copy`, fully comparable
/// (`Eq`), and usable as a `HashMap`/`HashSet` key (`Hash`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChatTemplate {
    /// SmolLM/Llama-style: `<|im_start|>role\ncontent<|im_end|>`
    ChatML,
    /// Llama 3 style: `<|start_header_id|>role<|end_header_id|>\n\ncontent<|eot_id|>`
    Llama3,
    /// Qwen style (same layout as ChatML)
    Qwen,
    /// Raw: just concatenate message contents with no special formatting
    Raw,
}
48
49impl ChatTemplate {
50    /// Detect the appropriate template from a model architecture name.
51    pub fn from_architecture(arch: &str) -> Self {
52        match arch.to_lowercase().as_str() {
53            "llama" => ChatTemplate::ChatML, // SmolLM and many Llama finetunes use ChatML
54            "qwen2" => ChatTemplate::Qwen,
55            "mistral" => ChatTemplate::Raw, // Mistral uses [INST] but we simplify
56            _ => ChatTemplate::ChatML,
57        }
58    }
59
60    /// Format a list of messages into a prompt string.
61    pub fn format(&self, messages: &[ChatMessage]) -> String {
62        match self {
63            ChatTemplate::ChatML | ChatTemplate::Qwen => format_chatml(messages),
64            ChatTemplate::Llama3 => format_llama3(messages),
65            ChatTemplate::Raw => format_raw(messages),
66        }
67    }
68
69    /// Format a single user prompt as a chat conversation.
70    pub fn format_prompt(&self, prompt: &str) -> String {
71        self.format(&[ChatMessage::user(prompt)])
72    }
73
74    /// Format a single user prompt with a system message.
75    pub fn format_with_system(&self, system: &str, prompt: &str) -> String {
76        self.format(&[ChatMessage::system(system), ChatMessage::user(prompt)])
77    }
78}
79
80/// ChatML format: `<|im_start|>role\ncontent<|im_end|>\n`
81fn format_chatml(messages: &[ChatMessage]) -> String {
82    let mut output = String::new();
83    for msg in messages {
84        output.push_str("<|im_start|>");
85        output.push_str(&msg.role);
86        output.push('\n');
87        output.push_str(&msg.content);
88        output.push_str("<|im_end|>\n");
89    }
90    output.push_str("<|im_start|>assistant\n");
91    output
92}
93
94/// Llama 3 format
95fn format_llama3(messages: &[ChatMessage]) -> String {
96    let mut output = String::from("<|begin_of_text|>");
97    for msg in messages {
98        output.push_str("<|start_header_id|>");
99        output.push_str(&msg.role);
100        output.push_str("<|end_header_id|>\n\n");
101        output.push_str(&msg.content);
102        output.push_str("<|eot_id|>");
103    }
104    output.push_str("<|start_header_id|>assistant<|end_header_id|>\n\n");
105    output
106}
107
108/// Raw format: just concatenate
109fn format_raw(messages: &[ChatMessage]) -> String {
110    let mut output = String::new();
111    for msg in messages {
112        if !output.is_empty() {
113            output.push('\n');
114        }
115        output.push_str(&msg.content);
116    }
117    output
118}
119
#[cfg(test)]
mod tests {
    //! Unit tests for template detection and prompt rendering.
    //!
    //! Wherever the expected prompt is fully determined, the tests pin the
    //! exact output rather than using `contains`, so that ordering mistakes,
    //! missing generation headers, or stray text are caught.

    use super::*;

    // ── ChatML ───────────────────────────────────────────────────────────

    #[test]
    fn chatml_single_user() {
        let result = ChatTemplate::ChatML.format_prompt("Hello");
        assert_eq!(
            result,
            "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"
        );
    }

    #[test]
    fn chatml_with_system() {
        let result = ChatTemplate::ChatML.format_with_system("You are helpful.", "Hello");
        assert_eq!(
            result,
            concat!(
                "<|im_start|>system\nYou are helpful.<|im_end|>\n",
                "<|im_start|>user\nHello<|im_end|>\n",
                "<|im_start|>assistant\n"
            )
        );
    }

    #[test]
    fn chatml_multi_turn() {
        let messages = vec![
            ChatMessage::user("What is Rust?"),
            ChatMessage::assistant("A systems programming language."),
            ChatMessage::user("Tell me more."),
        ];
        let result = ChatTemplate::ChatML.format(&messages);
        assert_eq!(
            result,
            concat!(
                "<|im_start|>user\nWhat is Rust?<|im_end|>\n",
                "<|im_start|>assistant\nA systems programming language.<|im_end|>\n",
                "<|im_start|>user\nTell me more.<|im_end|>\n",
                "<|im_start|>assistant\n"
            )
        );
    }

    #[test]
    fn chatml_empty_messages_produces_assistant_header() {
        // An empty message list should still produce the assistant prompt header
        // so the model knows to start generating.
        let result = ChatTemplate::ChatML.format(&[]);
        assert_eq!(
            result, "<|im_start|>assistant\n",
            "empty messages should produce just the assistant header"
        );
    }

    #[test]
    fn chatml_handles_special_characters_in_content() {
        // Content with newlines, angle brackets, and pipe characters should
        // be preserved verbatim (no escaping in ChatML).
        let content = "Here is code:\n```rust\nfn main() { println!(\"<|test|>\"); }\n```";
        let result = ChatTemplate::ChatML.format_prompt(content);
        assert!(
            result.contains(content),
            "special characters in content should be preserved verbatim"
        );
        assert!(result.starts_with("<|im_start|>user\n"));
        assert!(result.ends_with("<|im_end|>\n<|im_start|>assistant\n"));
    }

    #[test]
    fn chatml_multi_turn_preserves_order() {
        // A multi-turn conversation should maintain message order and all
        // role transitions should be correct.
        let messages = vec![
            ChatMessage::system("You are a calculator."),
            ChatMessage::user("What is 2+2?"),
            ChatMessage::assistant("4"),
            ChatMessage::user("And 3+3?"),
        ];
        let result = ChatTemplate::ChatML.format(&messages);

        // Verify ordering: system before first user, assistant before second user.
        let sys_pos = result
            .find("system\nYou are a calculator.")
            .expect("system turn missing");
        let user1_pos = result
            .find("user\nWhat is 2+2?")
            .expect("first user turn missing");
        let asst_pos = result
            .find("assistant\n4")
            .expect("assistant turn missing");
        let user2_pos = result
            .find("user\nAnd 3+3?")
            .expect("second user turn missing");
        let final_asst = result
            .rfind("<|im_start|>assistant\n")
            .expect("final assistant header missing");

        assert!(sys_pos < user1_pos, "system should come before first user");
        assert!(
            user1_pos < asst_pos,
            "first user should come before assistant response"
        );
        assert!(
            asst_pos < user2_pos,
            "assistant response should come before second user"
        );
        assert!(
            user2_pos < final_asst,
            "second user should come before final assistant prompt"
        );
        assert!(
            result.ends_with("<|im_start|>assistant\n"),
            "prompt should end with the open assistant turn"
        );
    }

    #[test]
    fn qwen_matches_chatml() {
        // The Qwen variant is documented as identical to ChatML.
        let messages = vec![
            ChatMessage::system("You are helpful."),
            ChatMessage::user("Hello"),
        ];
        assert_eq!(
            ChatTemplate::Qwen.format(&messages),
            ChatTemplate::ChatML.format(&messages)
        );
    }

    // ── Llama 3 ──────────────────────────────────────────────────────────

    #[test]
    fn llama3_format() {
        let result = ChatTemplate::Llama3.format_prompt("Hello");
        assert_eq!(
            result,
            concat!(
                "<|begin_of_text|>",
                "<|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|>",
                "<|start_header_id|>assistant<|end_header_id|>\n\n"
            )
        );
    }

    #[test]
    fn llama3_with_system() {
        let result = ChatTemplate::Llama3.format_with_system("You are helpful.", "Hello");
        assert_eq!(
            result,
            concat!(
                "<|begin_of_text|>",
                "<|start_header_id|>system<|end_header_id|>\n\nYou are helpful.<|eot_id|>",
                "<|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|>",
                "<|start_header_id|>assistant<|end_header_id|>\n\n"
            )
        );
    }

    #[test]
    fn llama3_empty_messages_produces_assistant_header() {
        let result = ChatTemplate::Llama3.format(&[]);
        assert_eq!(
            result, "<|begin_of_text|><|start_header_id|>assistant<|end_header_id|>\n\n",
            "empty Llama3 messages should still produce assistant header"
        );
    }

    // ── Raw ──────────────────────────────────────────────────────────────

    #[test]
    fn raw_format() {
        let messages = vec![ChatMessage::user("Hello"), ChatMessage::user("World")];
        let result = ChatTemplate::Raw.format(&messages);
        assert_eq!(result, "Hello\nWorld");
    }

    #[test]
    fn raw_ignores_roles() {
        let result = ChatTemplate::Raw.format_with_system("You are helpful.", "Hello");
        assert_eq!(result, "You are helpful.\nHello");
    }

    #[test]
    fn raw_empty_messages_produces_empty_prompt() {
        assert_eq!(ChatTemplate::Raw.format(&[]), "");
    }

    // ── Architecture detection ───────────────────────────────────────────

    #[test]
    fn detect_from_architecture() {
        assert_eq!(
            ChatTemplate::from_architecture("llama"),
            ChatTemplate::ChatML
        );
        assert_eq!(ChatTemplate::from_architecture("qwen2"), ChatTemplate::Qwen);
        assert_eq!(ChatTemplate::from_architecture("mistral"), ChatTemplate::Raw);
    }

    #[test]
    fn detect_from_architecture_ignores_case() {
        assert_eq!(
            ChatTemplate::from_architecture("LLaMA"),
            ChatTemplate::ChatML
        );
        assert_eq!(ChatTemplate::from_architecture("Qwen2"), ChatTemplate::Qwen);
        assert_eq!(ChatTemplate::from_architecture("MISTRAL"), ChatTemplate::Raw);
    }

    #[test]
    fn from_architecture_unknown_defaults_to_chatml() {
        // Unknown architectures should default to ChatML rather than panicking.
        assert_eq!(
            ChatTemplate::from_architecture("unknown_arch_xyz"),
            ChatTemplate::ChatML
        );
        assert_eq!(ChatTemplate::from_architecture(""), ChatTemplate::ChatML);
    }
}