1use minijinja::{context, Environment};
2use serde::Serialize;
3
4#[derive(Serialize, Debug)]
5pub struct Message {
6 pub role: String,
7 pub content: String,
8}
9
/// Jinja source of the ChatML prompt format.
///
/// Every message is rendered as `<|im_start|>{role}\n{content}<|im_end|>\n`.
/// When `add_generation_prompt` is truthy, `<|im_start|>assistant\n` is appended
/// after the last message.
const CHATML_JINJA_TEMPLATE: &str = concat!(
    "{% for message in messages %}",
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}",
    "{% endfor %}",
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
);
14
/// Name under which the ChatML template is registered in, and looked up from,
/// the minijinja environment.
const CHATML_JINJA_TEMPLATE_NAME: &str = "chatml";
16
/// Jinja source of the Mistral-instruct prompt format.
///
/// The output starts with `bos_token`. `user` messages are rendered as
/// `[INST] {content} [/INST]` and `assistant` messages as `{content}{eos_token}`.
/// Messages with any other role produce no output, because the `if`/`elif`
/// chain has no `else` branch.
const MISTRAL_INSTRUCT_TEMPLATE: &str = concat!(
    "{{ bos_token }}",
    "{% for message in messages %}",
    "{% if message['role'] == 'user' %}",
    "{{ '[INST] ' + message['content'] + ' [/INST]' }}",
    "{% elif message['role'] == 'assistant' %}",
    "{{ message['content'] + eos_token}}",
    "{% endif %}",
    "{% endfor %}",
);
18
/// Name under which the Mistral-instruct template is registered in, and looked
/// up from, the minijinja environment.
const MISTRAL_INSTRUCT_TEMPLATE_NAME: &str = "mistral-instruct";
20
/// Name under which the TAIDE template is registered in, and looked up from,
/// the minijinja environment.
const TAIDE_JINJA_TEMPLATE_NAME: &str = "taide";
22
/// Jinja source of the TAIDE prompt format.
///
/// If the first message has role `system`, its content is wrapped as
/// `<<SYS>>\n{content}\n<</SYS>>\n\n` and prepended to the content of the first
/// remaining message. `user` messages are then rendered as
/// `{bos_token}[INST] {content} [/INST]` and `assistant` messages as
/// ` {content} {eos_token}`; other roles produce no output.
///
/// NOTE(review): the final `add_generation_prompt` branch emits the ChatML-style
/// `<|im_start|>assistant\n` marker. `apply_taide_template` never passes
/// `add_generation_prompt`, so this looks like a copy-paste leftover — confirm
/// before ever supplying that variable to this template.
const TAIDE_JINJA_TEMPLATE: &str = concat!(
    // Split an optional leading system message from the rest of the conversation.
    "{% if messages[0]['role'] == 'system' %}",
    "{% set loop_messages = messages[1:] %}",
    "{% set system_message = '<<SYS>>\n' + messages[0]['content'] + '\n<</SYS>>\n\n' %}",
    "{% else %}",
    "{% set loop_messages = messages %}",
    "{% set system_message = '' %}",
    "{% endif %}",
    // Render each remaining message; the system block is folded into the first one.
    "{% for message in loop_messages %}",
    "{% if loop.index0 == 0 %}",
    "{% set content = system_message + message['content'] %}",
    "{% else %}",
    "{% set content = message['content'] %}",
    "{% endif %}",
    "{% if message['role'] == 'user' %}",
    "{{ bos_token + '[INST] ' + content + ' [/INST]'}}",
    "{% elif message['role'] == 'assistant' %}",
    "{{ ' ' + content + ' ' + eos_token }}",
    "{% endif %}",
    "{% endfor %}",
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
);
24
25fn apply_chatml_template(
27 messages: &Vec<Message>,
28 add_generation_prompt: bool,
29) -> Result<String, ApplyChatMLTemplateError> {
30 let mut env = Environment::new();
31 env.add_template(CHATML_JINJA_TEMPLATE_NAME, CHATML_JINJA_TEMPLATE)
32 .map_err(ApplyChatMLTemplateError::AddTemplateError)?;
33 let template = env
34 .get_template(CHATML_JINJA_TEMPLATE_NAME)
35 .map_err(ApplyChatMLTemplateError::GetTemplateError)?;
36 template
37 .render(context! {
38 messages => messages,
39 add_generation_prompt => add_generation_prompt,
40 })
41 .map_err(ApplyChatMLTemplateError::RenderTemplateError)
42}
43
44fn apply_mistral_instruct_template(
45 messages: &Vec<Message>,
46 add_generation_prompt: bool,
47) -> Result<String, ApplyMistralInstructTemplateError> {
48 let mut env = Environment::new();
49 env.add_template(MISTRAL_INSTRUCT_TEMPLATE_NAME, MISTRAL_INSTRUCT_TEMPLATE)
50 .map_err(ApplyMistralInstructTemplateError::AddTemplateError)?;
51 let template = env
52 .get_template(MISTRAL_INSTRUCT_TEMPLATE_NAME)
53 .map_err(ApplyMistralInstructTemplateError::GetTemplateError)?;
54 template
55 .render(context! {
56 messages => messages,
57 add_generation_prompt => add_generation_prompt,
58 bos_token => "<s>",
60 eos_token => "</s>",
62 })
63 .map_err(ApplyMistralInstructTemplateError::RenderTemplateError)
64}
65
66fn apply_taide_template(
68 messages: &Vec<Message>,
69) -> Result<String, ApplyTAIDETemplateError> {
70 let mut env = Environment::new();
71 env.add_template(TAIDE_JINJA_TEMPLATE_NAME, TAIDE_JINJA_TEMPLATE)
72 .map_err(ApplyTAIDETemplateError::AddTemplateError)?;
73 let template = env
74 .get_template(TAIDE_JINJA_TEMPLATE_NAME)
75 .map_err(ApplyTAIDETemplateError::GetTemplateError)?;
76 template
77 .render(context! {
78 messages => messages,
79 bos_token => "<s>",
80 eos_token => "</s>",
81 })
82 .map_err(ApplyTAIDETemplateError::RenderTemplateError)
83}
84
/// Prompt formats that [`apply_template`] can render.
///
/// The enum is a plain tag with no payload, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChatTemplate {
    /// ChatML format: `<|im_start|>{role}\n{content}<|im_end|>\n` per message.
    ChatML,
    /// Mistral-instruct format: `[INST] … [/INST]` user turns between `<s>`/`</s>` tokens.
    MistralInstruct,
    /// TAIDE format: `[INST] … [/INST]` user turns with an optional `<<SYS>>` system block.
    TAIDE,
}
91
92pub fn apply_template(
100 template: ChatTemplate,
101 messages: &Vec<Message>,
102 add_generation_prompt: bool,
103) -> Result<String, ApplyTemplateError> {
104 match template {
105 ChatTemplate::ChatML => apply_chatml_template(messages, add_generation_prompt)
106 .map_err(ApplyTemplateError::ApplyChatMLTemplateError),
107 ChatTemplate::MistralInstruct => {
108 apply_mistral_instruct_template(messages, add_generation_prompt)
109 .map_err(ApplyTemplateError::ApplyMistralInstructTemplateError)
110 }
111 ChatTemplate::TAIDE => apply_taide_template(messages)
112 .map_err(ApplyTemplateError::ApplyTAIDETemplateError),
113 }
114}
115
/// Errors returned by `apply_chatml_template`, one variant per failing step.
///
/// Each variant carries the underlying `minijinja::Error` as its source.
#[derive(thiserror::Error, Debug)]
pub enum ApplyChatMLTemplateError {
    /// Registering the ChatML template source with the environment failed.
    #[error("failed to add template")]
    AddTemplateError(#[source] minijinja::Error),
    /// Looking the ChatML template up by name failed.
    #[error("failed to get template")]
    GetTemplateError(#[source] minijinja::Error),
    /// Rendering the ChatML template with the supplied context failed.
    #[error("failed to render")]
    RenderTemplateError(#[source] minijinja::Error),
}
125
/// Errors returned by `apply_mistral_instruct_template`, one variant per failing step.
///
/// Each variant carries the underlying `minijinja::Error` as its source.
#[derive(thiserror::Error, Debug)]
pub enum ApplyMistralInstructTemplateError {
    /// Registering the Mistral-instruct template source with the environment failed.
    #[error("failed to add template")]
    AddTemplateError(#[source] minijinja::Error),
    /// Looking the Mistral-instruct template up by name failed.
    #[error("failed to get template")]
    GetTemplateError(#[source] minijinja::Error),
    /// Rendering the Mistral-instruct template with the supplied context failed.
    #[error("failed to render")]
    RenderTemplateError(#[source] minijinja::Error),
}
135
/// Errors returned by `apply_taide_template`, one variant per failing step.
///
/// Each variant carries the underlying `minijinja::Error` as its source.
#[derive(thiserror::Error, Debug)]
pub enum ApplyTAIDETemplateError {
    /// Registering the TAIDE template source with the environment failed.
    #[error("failed to add template")]
    AddTemplateError(#[source] minijinja::Error),
    /// Looking the TAIDE template up by name failed.
    #[error("failed to get template")]
    GetTemplateError(#[source] minijinja::Error),
    /// Rendering the TAIDE template with the supplied context failed.
    #[error("failed to render")]
    RenderTemplateError(#[source] minijinja::Error),
}
145
146
/// Error returned by `apply_template`.
///
/// There is one variant per [`ChatTemplate`] value; each wraps the
/// template-specific error as its source.
#[derive(thiserror::Error, Debug)]
pub enum ApplyTemplateError {
    /// Rendering with the ChatML template failed.
    #[error("failed to apply chatml template")]
    ApplyChatMLTemplateError(#[source] ApplyChatMLTemplateError),
    /// Rendering with the Mistral-instruct template failed.
    #[error("failed to apply mistral instruct template")]
    ApplyMistralInstructTemplateError(#[source] ApplyMistralInstructTemplateError),
    /// Rendering with the TAIDE template failed.
    #[error("failed to apply taide template")]
    ApplyTAIDETemplateError(#[source] ApplyTAIDETemplateError),
}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// System prompt shared by the ChatML tests.
    const TAX_SYSTEM_PROMPT: &str = "Assistant is an intelligent chatbot designed to help users answer their tax related questions.";

    /// System prompt shared by the TAIDE tests.
    const TAIDE_SYSTEM_PROMPT: &str = "你是一個來自台灣的AI助理,你的名字是 TAIDE。";

    /// Builds an owned [`Message`] from borrowed role and content strings.
    fn msg(role: &str, content: &str) -> Message {
        Message {
            role: role.to_string(),
            content: content.to_string(),
        }
    }

    #[test]
    fn test_apply_chatml_template_one_shot() {
        let messages = vec![
            msg("system", TAX_SYSTEM_PROMPT),
            msg("user", "Hello, who are you?"),
        ];

        // With the generation prompt the output ends with an open assistant turn.
        let with_prompt = apply_template(ChatTemplate::ChatML, &messages, true).unwrap();
        assert_eq!(
            with_prompt,
            "<|im_start|>system\nAssistant is an intelligent chatbot designed to help users answer their tax related questions.<|im_end|>\n<|im_start|>user\nHello, who are you?<|im_end|>\n<|im_start|>assistant\n"
        );

        let without_prompt = apply_template(ChatTemplate::ChatML, &messages, false).unwrap();
        assert_eq!(
            without_prompt,
            "<|im_start|>system\nAssistant is an intelligent chatbot designed to help users answer their tax related questions.<|im_end|>\n<|im_start|>user\nHello, who are you?<|im_end|>\n"
        );
    }

    #[test]
    fn test_apply_chatml_template_few_shots() {
        let messages = vec![
            msg("system", TAX_SYSTEM_PROMPT),
            msg("user", "When do I need to file my taxes by?"),
            msg(
                "assistant",
                "In 2023, you will need to file your taxes by April 18th. The date falls after the usual April 15th deadline because April 15th falls on a Saturday in 2023.",
            ),
            msg("user", "How can I check the status of my tax refund?"),
        ];

        let with_prompt = apply_template(ChatTemplate::ChatML, &messages, true).unwrap();
        assert_eq!(
            with_prompt,
            "<|im_start|>system\nAssistant is an intelligent chatbot designed to help users answer their tax related questions.<|im_end|>\n<|im_start|>user\nWhen do I need to file my taxes by?<|im_end|>\n<|im_start|>assistant\nIn 2023, you will need to file your taxes by April 18th. The date falls after the usual April 15th deadline because April 15th falls on a Saturday in 2023.<|im_end|>\n<|im_start|>user\nHow can I check the status of my tax refund?<|im_end|>\n<|im_start|>assistant\n"
        );

        let without_prompt = apply_template(ChatTemplate::ChatML, &messages, false).unwrap();
        assert_eq!(
            without_prompt,
            "<|im_start|>system\nAssistant is an intelligent chatbot designed to help users answer their tax related questions.<|im_end|>\n<|im_start|>user\nWhen do I need to file my taxes by?<|im_end|>\n<|im_start|>assistant\nIn 2023, you will need to file your taxes by April 18th. The date falls after the usual April 15th deadline because April 15th falls on a Saturday in 2023.<|im_end|>\n<|im_start|>user\nHow can I check the status of my tax refund?<|im_end|>\n"
        );
    }

    #[test]
    fn test_apply_mistral_instruct_template_one_shot() {
        let messages = vec![msg("user", "Hello, who are you?")];

        let prompt = apply_template(ChatTemplate::MistralInstruct, &messages, true).unwrap();
        assert_eq!(prompt, "<s>[INST] Hello, who are you? [/INST]");
    }

    #[test]
    fn test_apply_mistral_instruct_template_few_shots() {
        let messages = vec![
            msg("user", "Hello, who are you?"),
            msg("assistant", "I'm doing great. How can I help you today?"),
            msg("user", "I'd like to show off how chat templating works!"),
            msg("assistant", "Are you sure?"),
            msg("user", "Yes!"),
        ];

        let prompt = apply_template(ChatTemplate::MistralInstruct, &messages, true).unwrap();
        assert_eq!(
            prompt,
            "<s>[INST] Hello, who are you? [/INST]I'm doing great. How can I help you today?</s>[INST] I'd like to show off how chat templating works! [/INST]Are you sure?</s>[INST] Yes! [/INST]"
        );
    }

    #[test]
    fn test_apply_taide_template_one_shot() {
        let messages = vec![msg("user", "你好嗎?")];
        let expected = "<s>[INST] 你好嗎? [/INST]";

        // The TAIDE output is expected to be the same for both flag values.
        for add_generation_prompt in [true, false] {
            let prompt =
                apply_template(ChatTemplate::TAIDE, &messages, add_generation_prompt).unwrap();
            assert_eq!(prompt, expected);
        }
    }

    #[test]
    fn test_apply_taide_template_one_shot_with_sys_prompt() {
        let messages = vec![
            msg("system", TAIDE_SYSTEM_PROMPT),
            msg("user", "你好嗎?"),
        ];
        let expected = "<s>[INST] <<SYS>>\n你是一個來自台灣的AI助理,你的名字是 TAIDE。\n<</SYS>>\n\n你好嗎? [/INST]";

        // The TAIDE output is expected to be the same for both flag values.
        for add_generation_prompt in [true, false] {
            let prompt =
                apply_template(ChatTemplate::TAIDE, &messages, add_generation_prompt).unwrap();
            assert_eq!(prompt, expected);
        }
    }

    #[test]
    fn test_apply_taide_template_few_shot_with_sys_prompt() {
        let messages = vec![
            msg("system", TAIDE_SYSTEM_PROMPT),
            msg("user", "你好嗎?"),
            msg("assistant", "我很好。"),
            msg("user", "今天天氣怎樣?"),
        ];
        let expected = "<s>[INST] <<SYS>>\n你是一個來自台灣的AI助理,你的名字是 TAIDE。\n<</SYS>>\n\n你好嗎? [/INST] 我很好。 </s><s>[INST] 今天天氣怎樣? [/INST]";

        // The TAIDE output is expected to be the same for both flag values.
        for add_generation_prompt in [true, false] {
            let prompt =
                apply_template(ChatTemplate::TAIDE, &messages, add_generation_prompt).unwrap();
            assert_eq!(prompt, expected);
        }
    }

    #[test]
    fn test_apply_taide_template_few_shot_conversation_sys_prompt() {
        let messages = vec![
            msg("system", TAIDE_SYSTEM_PROMPT),
            msg("user", "你好嗎?"),
            msg("assistant", "我很好。"),
            msg("user", "今天天氣怎樣?"),
            msg("assistant", "大太陽。"),
            msg("user", "你敢感覺如何?"),
        ];
        let expected = "<s>[INST] <<SYS>>\n你是一個來自台灣的AI助理,你的名字是 TAIDE。\n<</SYS>>\n\n你好嗎? [/INST] 我很好。 </s><s>[INST] 今天天氣怎樣? [/INST] 大太陽。 </s><s>[INST] 你敢感覺如何? [/INST]";

        // The TAIDE output is expected to be the same for both flag values.
        for add_generation_prompt in [true, false] {
            let prompt =
                apply_template(ChatTemplate::TAIDE, &messages, add_generation_prompt).unwrap();
            assert_eq!(prompt, expected);
        }
    }
}