/// A single turn in a chat conversation: a role tag plus the turn's text.
///
/// `role` is a free-form string; the constructors below produce the
/// `"system"`, `"user"` and `"assistant"` tags.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChatMessage {
    /// Speaker tag, e.g. `"system"`, `"user"` or `"assistant"`.
    pub role: String,
    /// Text of the message.
    pub content: String,
}

impl ChatMessage {
    /// Builds a message with the given role tag; shared by the named constructors.
    fn with_role(role: &str, content: impl Into<String>) -> Self {
        Self {
            role: role.to_owned(),
            content: content.into(),
        }
    }

    /// Creates a message whose role is `"system"`.
    pub fn system(content: impl Into<String>) -> Self {
        Self::with_role("system", content)
    }

    /// Creates a message whose role is `"user"`.
    pub fn user(content: impl Into<String>) -> Self {
        Self::with_role("user", content)
    }

    /// Creates a message whose role is `"assistant"`.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::with_role("assistant", content)
    }
}
35
/// Prompt layout used to serialize a conversation into a single string.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChatTemplate {
    /// Turns wrapped as `<|im_start|>{role}\n{content}<|im_end|>\n`.
    ChatML,
    /// `<|begin_of_text|>` followed by
    /// `<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>` turns.
    Llama3,
    /// Qwen models; currently rendered exactly like [`ChatTemplate::ChatML`].
    Qwen,
    /// Message contents separated by newlines, with no role markup.
    Raw,
}
48
49impl ChatTemplate {
50 pub fn from_architecture(arch: &str) -> Self {
52 match arch.to_lowercase().as_str() {
53 "llama" => ChatTemplate::ChatML, "qwen2" => ChatTemplate::Qwen,
55 "mistral" => ChatTemplate::Raw, _ => ChatTemplate::ChatML,
57 }
58 }
59
60 pub fn format(&self, messages: &[ChatMessage]) -> String {
62 match self {
63 ChatTemplate::ChatML | ChatTemplate::Qwen => format_chatml(messages),
64 ChatTemplate::Llama3 => format_llama3(messages),
65 ChatTemplate::Raw => format_raw(messages),
66 }
67 }
68
69 pub fn format_prompt(&self, prompt: &str) -> String {
71 self.format(&[ChatMessage::user(prompt)])
72 }
73
74 pub fn format_with_system(&self, system: &str, prompt: &str) -> String {
76 self.format(&[ChatMessage::system(system), ChatMessage::user(prompt)])
77 }
78}
79
80fn format_chatml(messages: &[ChatMessage]) -> String {
82 let mut output = String::new();
83 for msg in messages {
84 output.push_str("<|im_start|>");
85 output.push_str(&msg.role);
86 output.push('\n');
87 output.push_str(&msg.content);
88 output.push_str("<|im_end|>\n");
89 }
90 output.push_str("<|im_start|>assistant\n");
91 output
92}
93
94fn format_llama3(messages: &[ChatMessage]) -> String {
96 let mut output = String::from("<|begin_of_text|>");
97 for msg in messages {
98 output.push_str("<|start_header_id|>");
99 output.push_str(&msg.role);
100 output.push_str("<|end_header_id|>\n\n");
101 output.push_str(&msg.content);
102 output.push_str("<|eot_id|>");
103 }
104 output.push_str("<|start_header_id|>assistant<|end_header_id|>\n\n");
105 output
106}
107
108fn format_raw(messages: &[ChatMessage]) -> String {
110 let mut output = String::new();
111 for msg in messages {
112 if !output.is_empty() {
113 output.push('\n');
114 }
115 output.push_str(&msg.content);
116 }
117 output
118}
119
/// Unit tests for message construction, template detection and prompt rendering.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn message_constructors_set_roles() {
        assert_eq!(ChatMessage::system("s").role, "system");
        assert_eq!(ChatMessage::user("u").role, "user");
        let reply = ChatMessage::assistant("a");
        assert_eq!(reply.role, "assistant");
        assert_eq!(reply.content, "a");
    }

    #[test]
    fn chatml_single_user() {
        let template = ChatTemplate::ChatML;
        let result = template.format_prompt("Hello");
        assert!(result.contains("<|im_start|>user\nHello<|im_end|>"));
        assert!(result.ends_with("<|im_start|>assistant\n"));
    }

    #[test]
    fn chatml_with_system() {
        let template = ChatTemplate::ChatML;
        let result = template.format_with_system("You are helpful.", "Hello");
        assert!(result.contains("<|im_start|>system\nYou are helpful.<|im_end|>"));
        assert!(result.contains("<|im_start|>user\nHello<|im_end|>"));
    }

    #[test]
    fn chatml_exact_output() {
        let result = ChatTemplate::ChatML.format_with_system("You are helpful.", "Hello");
        assert_eq!(
            result,
            concat!(
                "<|im_start|>system\nYou are helpful.<|im_end|>\n",
                "<|im_start|>user\nHello<|im_end|>\n",
                "<|im_start|>assistant\n"
            )
        );
    }

    #[test]
    fn chatml_multi_turn() {
        let template = ChatTemplate::ChatML;
        let messages = vec![
            ChatMessage::user("What is Rust?"),
            ChatMessage::assistant("A systems programming language."),
            ChatMessage::user("Tell me more."),
        ];
        let result = template.format(&messages);
        assert!(result.contains("What is Rust?"));
        assert!(result.contains("A systems programming language."));
        assert!(result.contains("Tell me more."));
        assert!(result.ends_with("<|im_start|>assistant\n"));
        assert!(result.contains("<|im_start|>assistant\nA systems programming language.<|im_end|>"));

        // Turns must appear in the order they were given.
        let first = result.find("What is Rust?").unwrap();
        let reply = result.find("A systems programming language.").unwrap();
        let follow_up = result.find("Tell me more.").unwrap();
        assert!(first < reply && reply < follow_up);
    }

    #[test]
    fn qwen_renders_like_chatml() {
        let messages = [ChatMessage::system("sys"), ChatMessage::user("hi")];
        assert_eq!(
            ChatTemplate::Qwen.format(&messages),
            ChatTemplate::ChatML.format(&messages)
        );
    }

    #[test]
    fn llama3_format() {
        let template = ChatTemplate::Llama3;
        let result = template.format_prompt("Hello");
        assert!(result.contains("<|start_header_id|>user<|end_header_id|>"));
        assert!(result.contains("Hello"));
        assert!(result.contains("<|start_header_id|>assistant<|end_header_id|>"));
    }

    #[test]
    fn llama3_exact_output() {
        let result = ChatTemplate::Llama3.format_with_system("Be brief.", "Hi");
        assert_eq!(
            result,
            concat!(
                "<|begin_of_text|>",
                "<|start_header_id|>system<|end_header_id|>\n\nBe brief.<|eot_id|>",
                "<|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|>",
                "<|start_header_id|>assistant<|end_header_id|>\n\n"
            )
        );
    }

    #[test]
    fn empty_conversation() {
        assert_eq!(ChatTemplate::ChatML.format(&[]), "<|im_start|>assistant\n");
        assert_eq!(
            ChatTemplate::Llama3.format(&[]),
            "<|begin_of_text|><|start_header_id|>assistant<|end_header_id|>\n\n"
        );
        assert_eq!(ChatTemplate::Raw.format(&[]), "");
    }

    #[test]
    fn detect_from_architecture() {
        assert_eq!(
            ChatTemplate::from_architecture("llama"),
            ChatTemplate::ChatML
        );
        assert_eq!(ChatTemplate::from_architecture("qwen2"), ChatTemplate::Qwen);
    }

    #[test]
    fn detect_from_architecture_edge_cases() {
        assert_eq!(ChatTemplate::from_architecture("mistral"), ChatTemplate::Raw);
        // Matching ignores case.
        assert_eq!(ChatTemplate::from_architecture("Qwen2"), ChatTemplate::Qwen);
        assert_eq!(ChatTemplate::from_architecture("MISTRAL"), ChatTemplate::Raw);
        assert_eq!(ChatTemplate::from_architecture("LLAMA"), ChatTemplate::ChatML);
        // Unknown and empty names fall back to ChatML.
        assert_eq!(ChatTemplate::from_architecture("gemma"), ChatTemplate::ChatML);
        assert_eq!(ChatTemplate::from_architecture(""), ChatTemplate::ChatML);
    }

    #[test]
    fn raw_format() {
        let template = ChatTemplate::Raw;
        let messages = vec![ChatMessage::user("Hello"), ChatMessage::user("World")];
        let result = template.format(&messages);
        assert_eq!(result, "Hello\nWorld");
    }
}