/// A single turn in a chat conversation.
///
/// `role` is free-form text; the constructors on this type produce
/// `"system"`, `"user"` and `"assistant"`. Deriving `PartialEq`/`Eq` lets
/// callers and tests compare whole messages directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChatMessage {
    /// Speaker label for this turn.
    pub role: String,
    /// Text of the turn.
    pub content: String,
}

13impl ChatMessage {
14 pub fn system(content: impl Into<String>) -> Self {
15 Self {
16 role: "system".into(),
17 content: content.into(),
18 }
19 }
20
21 pub fn user(content: impl Into<String>) -> Self {
22 Self {
23 role: "user".into(),
24 content: content.into(),
25 }
26 }
27
28 pub fn assistant(content: impl Into<String>) -> Self {
29 Self {
30 role: "assistant".into(),
31 content: content.into(),
32 }
33 }
34}
35
/// Prompt layout used to turn a list of chat messages into model input.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so the template can be
/// used as a map/set key and in exhaustive comparisons.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChatTemplate {
    /// Turns wrapped in `<|im_start|>` / `<|im_end|>` markers.
    ChatML,
    /// Turns wrapped in `<|start_header_id|>` / `<|eot_id|>` markers after a
    /// leading `<|begin_of_text|>`.
    Llama3,
    /// Rendered by the same formatter as [`ChatTemplate::ChatML`].
    Qwen,
    /// Message contents joined by newlines, with no role markers.
    Raw,
}

49impl ChatTemplate {
50 pub fn from_architecture(arch: &str) -> Self {
52 match arch.to_lowercase().as_str() {
53 "llama" => ChatTemplate::ChatML, "qwen2" => ChatTemplate::Qwen,
55 "mistral" => ChatTemplate::Raw, _ => ChatTemplate::ChatML,
57 }
58 }
59
60 pub fn format(&self, messages: &[ChatMessage]) -> String {
62 match self {
63 ChatTemplate::ChatML | ChatTemplate::Qwen => format_chatml(messages),
64 ChatTemplate::Llama3 => format_llama3(messages),
65 ChatTemplate::Raw => format_raw(messages),
66 }
67 }
68
69 pub fn format_prompt(&self, prompt: &str) -> String {
71 self.format(&[ChatMessage::user(prompt)])
72 }
73
74 pub fn format_with_system(&self, system: &str, prompt: &str) -> String {
76 self.format(&[ChatMessage::system(system), ChatMessage::user(prompt)])
77 }
78}
79
80fn format_chatml(messages: &[ChatMessage]) -> String {
82 let mut output = String::new();
83 for msg in messages {
84 output.push_str("<|im_start|>");
85 output.push_str(&msg.role);
86 output.push('\n');
87 output.push_str(&msg.content);
88 output.push_str("<|im_end|>\n");
89 }
90 output.push_str("<|im_start|>assistant\n");
91 output
92}
93
94fn format_llama3(messages: &[ChatMessage]) -> String {
96 let mut output = String::from("<|begin_of_text|>");
97 for msg in messages {
98 output.push_str("<|start_header_id|>");
99 output.push_str(&msg.role);
100 output.push_str("<|end_header_id|>\n\n");
101 output.push_str(&msg.content);
102 output.push_str("<|eot_id|>");
103 }
104 output.push_str("<|start_header_id|>assistant<|end_header_id|>\n\n");
105 output
106}
107
108fn format_raw(messages: &[ChatMessage]) -> String {
110 let mut output = String::new();
111 for msg in messages {
112 if !output.is_empty() {
113 output.push('\n');
114 }
115 output.push_str(&msg.content);
116 }
117 output
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// Each constructor stamps its own role and keeps the content untouched.
    #[test]
    fn message_constructors_set_roles() {
        let system = ChatMessage::system("s");
        let user = ChatMessage::user(String::from("u"));
        let assistant = ChatMessage::assistant("a");

        assert_eq!((system.role.as_str(), system.content.as_str()), ("system", "s"));
        assert_eq!((user.role.as_str(), user.content.as_str()), ("user", "u"));
        assert_eq!(
            (assistant.role.as_str(), assistant.content.as_str()),
            ("assistant", "a")
        );
    }

    #[test]
    fn chatml_single_user() {
        let result = ChatTemplate::ChatML.format_prompt("Hello");
        // Exact match: a stray byte in a prompt template changes tokenization.
        assert_eq!(
            result,
            "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"
        );
    }

    #[test]
    fn chatml_with_system() {
        let result = ChatTemplate::ChatML.format_with_system("You are helpful.", "Hello");
        assert_eq!(
            result,
            "<|im_start|>system\nYou are helpful.<|im_end|>\n\
             <|im_start|>user\nHello<|im_end|>\n\
             <|im_start|>assistant\n"
        );
    }

    #[test]
    fn chatml_multi_turn() {
        let messages = vec![
            ChatMessage::user("What is Rust?"),
            ChatMessage::assistant("A systems programming language."),
            ChatMessage::user("Tell me more."),
        ];
        let result = ChatTemplate::ChatML.format(&messages);
        assert!(result.contains("What is Rust?"));
        assert!(result.contains("A systems programming language."));
        assert!(result.contains("Tell me more."));
        assert!(result.ends_with("<|im_start|>assistant\n"));
    }

    /// `Qwen` is rendered by the ChatML formatter, so outputs must be equal.
    #[test]
    fn qwen_matches_chatml() {
        let messages = vec![
            ChatMessage::system("You are helpful."),
            ChatMessage::user("Hello"),
        ];
        assert_eq!(
            ChatTemplate::Qwen.format(&messages),
            ChatTemplate::ChatML.format(&messages)
        );
    }

    #[test]
    fn llama3_format() {
        let result = ChatTemplate::Llama3.format_prompt("Hello");
        assert_eq!(
            result,
            "<|begin_of_text|>\
             <|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|>\
             <|start_header_id|>assistant<|end_header_id|>\n\n"
        );
    }

    #[test]
    fn detect_from_architecture() {
        assert_eq!(
            ChatTemplate::from_architecture("llama"),
            ChatTemplate::ChatML
        );
        assert_eq!(ChatTemplate::from_architecture("qwen2"), ChatTemplate::Qwen);
        assert_eq!(ChatTemplate::from_architecture("mistral"), ChatTemplate::Raw);
    }

    /// Architecture names are lower-cased before matching.
    #[test]
    fn detect_from_architecture_ignores_case() {
        assert_eq!(ChatTemplate::from_architecture("QWEN2"), ChatTemplate::Qwen);
        assert_eq!(ChatTemplate::from_architecture("Mistral"), ChatTemplate::Raw);
        assert_eq!(
            ChatTemplate::from_architecture("LLaMA"),
            ChatTemplate::ChatML
        );
    }

    #[test]
    fn raw_format() {
        let messages = vec![ChatMessage::user("Hello"), ChatMessage::user("World")];
        let result = ChatTemplate::Raw.format(&messages);
        assert_eq!(result, "Hello\nWorld");
    }

    /// Raw output has no generation prompt, so nothing in means nothing out.
    #[test]
    fn raw_empty_messages_produces_empty_string() {
        assert_eq!(ChatTemplate::Raw.format(&[]), "");
    }

    #[test]
    fn chatml_empty_messages_produces_assistant_header() {
        let result = ChatTemplate::ChatML.format(&[]);
        assert_eq!(
            result, "<|im_start|>assistant\n",
            "empty messages should produce just the assistant header"
        );
    }

    #[test]
    fn llama3_empty_messages_produces_assistant_header() {
        let result = ChatTemplate::Llama3.format(&[]);
        assert_eq!(
            result, "<|begin_of_text|><|start_header_id|>assistant<|end_header_id|>\n\n",
            "empty Llama3 messages should still produce assistant header"
        );
    }

    #[test]
    fn chatml_handles_special_characters_in_content() {
        // Content that looks like markup, including a template-style token.
        let content = "Here is code:\n```rust\nfn main() { println!(\"<|test|>\"); }\n```";
        let result = ChatTemplate::ChatML.format_prompt(content);
        assert!(
            result.contains(content),
            "special characters in content should be preserved verbatim"
        );
    }

    #[test]
    fn chatml_multi_turn_preserves_order() {
        let messages = vec![
            ChatMessage::system("You are a calculator."),
            ChatMessage::user("What is 2+2?"),
            ChatMessage::assistant("4"),
            ChatMessage::user("And 3+3?"),
        ];
        let result = ChatTemplate::ChatML.format(&messages);

        let sys_pos = result
            .find("system\nYou are a calculator.")
            .expect("system turn should be rendered");
        let user1_pos = result
            .find("user\nWhat is 2+2?")
            .expect("first user turn should be rendered");
        let asst_pos = result
            .find("assistant\n4")
            .expect("assistant turn should be rendered");
        let user2_pos = result
            .find("user\nAnd 3+3?")
            .expect("second user turn should be rendered");
        // `rfind` targets the trailing generation prompt, not the earlier
        // assistant turn.
        let final_asst = result
            .rfind("<|im_start|>assistant\n")
            .expect("generation prompt should be rendered");

        assert!(sys_pos < user1_pos, "system should come before first user");
        assert!(
            user1_pos < asst_pos,
            "first user should come before assistant response"
        );
        assert!(
            asst_pos < user2_pos,
            "assistant response should come before second user"
        );
        assert!(
            user2_pos < final_asst,
            "second user should come before final assistant prompt"
        );
    }

    #[test]
    fn from_architecture_unknown_defaults_to_chatml() {
        let template = ChatTemplate::from_architecture("unknown_arch_xyz");
        assert_eq!(template, ChatTemplate::ChatML);
    }
}