chat_templates/
lib.rs

1use minijinja::{context, Environment};
2use serde::Serialize;
3
4#[derive(Serialize, Debug)]
5pub struct Message {
6    pub role: String,
7    pub content: String,
8}
9
/// [chatml](https://github.com/MicrosoftDocs/azure-docs/blob/main/articles/ai-services/openai/includes/chat-markup-language.md) jinja template, modified
/// from repo [`chat_templates`](https://github.com/chujiezheng/chat_templates/tree/main/chat_templates)
/// with minijinja [compatible syntax](https://github.com/mitsuhiko/minijinja/blob/main/COMPATIBILITY.md).
///
/// Each message renders as `<|im_start|>{role}\n{content}<|im_end|>\n`; when the
/// `add_generation_prompt` context variable is truthy, `<|im_start|>assistant\n` is appended.
const CHATML_JINJA_TEMPLATE: &str = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}";

/// Name under which [`CHATML_JINJA_TEMPLATE`] is registered in the minijinja environment.
const CHATML_JINJA_TEMPLATE_NAME: &str = "chatml";
16
/// Jinja template for Mistral instruct models.
///
/// The prompt opens with `bos_token`; `user` turns become `[INST] {content} [/INST]` and
/// `assistant` turns become `{content}{eos_token}`. Messages with any other role are
/// skipped, and the template never reads `add_generation_prompt`.
const MISTRAL_INSTRUCT_TEMPLATE: &str = concat!(
    "{{ bos_token }}",
    "{% for message in messages %}",
    "{% if message['role'] == 'user' %}",
    "{{ '[INST] ' + message['content'] + ' [/INST]' }}",
    "{% elif message['role'] == 'assistant' %}",
    "{{ message['content'] + eos_token}}",
    "{% endif %}",
    "{% endfor %}"
);

/// Name under which [`MISTRAL_INSTRUCT_TEMPLATE`] is registered in the minijinja environment.
const MISTRAL_INSTRUCT_TEMPLATE_NAME: &str = "mistral-instruct";
20
/// Name under which [`TAIDE_JINJA_TEMPLATE`] is registered in the minijinja environment.
const TAIDE_JINJA_TEMPLATE_NAME: &str = "taide";

/// Jinja template for TAIDE chat models.
///
/// A leading `system` message is wrapped in `<<SYS>>`/`<</SYS>>` markers and prepended to the
/// first remaining message. `user` turns render as `{bos_token}[INST] {content} [/INST]` and
/// `assistant` turns as ` {content} {eos_token}`.
///
/// No generation-prompt suffix is emitted: a ChatML `<|im_start|>assistant\n` marker does not
/// belong to this `[INST]`-style format.
const TAIDE_JINJA_TEMPLATE: &str = concat!(
    // Split off an optional leading system message and wrap it in `<<SYS>>` markers.
    "{% if messages[0]['role'] == 'system' %}",
    "{% set loop_messages = messages[1:] %}",
    "{% set system_message = '<<SYS>>\n' + messages[0]['content'] + '\n<</SYS>>\n\n' %}",
    "{% else %}",
    "{% set loop_messages = messages %}",
    "{% set system_message = '' %}",
    "{% endif %}",
    "{% for message in loop_messages %}",
    // The system block is prepended only to the first remaining message.
    "{% if loop.index0 == 0 %}",
    "{% set content = system_message + message['content'] %}",
    "{% else %}",
    "{% set content = message['content'] %}",
    "{% endif %}",
    "{% if message['role'] == 'user' %}",
    "{{ bos_token + '[INST] ' + content + ' [/INST]'}}",
    "{% elif message['role'] == 'assistant' %}",
    "{{ ' '  + content + ' ' + eos_token }}",
    "{% endif %}",
    "{% endfor %}"
);
24
25/// Apply Chat Markup Language (chatml) template to messages, return the prompt
26fn apply_chatml_template(
27    messages: &Vec<Message>,
28    add_generation_prompt: bool,
29) -> Result<String, ApplyChatMLTemplateError> {
30    let mut env = Environment::new();
31    env.add_template(CHATML_JINJA_TEMPLATE_NAME, CHATML_JINJA_TEMPLATE)
32        .map_err(ApplyChatMLTemplateError::AddTemplateError)?;
33    let template = env
34        .get_template(CHATML_JINJA_TEMPLATE_NAME)
35        .map_err(ApplyChatMLTemplateError::GetTemplateError)?;
36    template
37        .render(context! {
38          messages => messages,
39          add_generation_prompt => add_generation_prompt,
40        })
41        .map_err(ApplyChatMLTemplateError::RenderTemplateError)
42}
43
44fn apply_mistral_instruct_template(
45    messages: &Vec<Message>,
46    add_generation_prompt: bool,
47) -> Result<String, ApplyMistralInstructTemplateError> {
48    let mut env = Environment::new();
49    env.add_template(MISTRAL_INSTRUCT_TEMPLATE_NAME, MISTRAL_INSTRUCT_TEMPLATE)
50        .map_err(ApplyMistralInstructTemplateError::AddTemplateError)?;
51    let template = env
52        .get_template(MISTRAL_INSTRUCT_TEMPLATE_NAME)
53        .map_err(ApplyMistralInstructTemplateError::GetTemplateError)?;
54    template
55        .render(context! {
56          messages => messages,
57          add_generation_prompt => add_generation_prompt,
58          // https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/blob/main/tokenizer_config.json#L31
59          bos_token => "<s>",
60          // https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/blob/main/tokenizer_config.json#L33
61          eos_token => "</s>",
62        })
63        .map_err(ApplyMistralInstructTemplateError::RenderTemplateError)
64}
65
66/// Apply TAIDE template to messages, return the prompt
67fn apply_taide_template(
68  messages: &Vec<Message>,
69) -> Result<String, ApplyTAIDETemplateError> {
70  let mut env = Environment::new();
71  env.add_template(TAIDE_JINJA_TEMPLATE_NAME, TAIDE_JINJA_TEMPLATE)
72      .map_err(ApplyTAIDETemplateError::AddTemplateError)?;
73  let template = env
74      .get_template(TAIDE_JINJA_TEMPLATE_NAME)
75      .map_err(ApplyTAIDETemplateError::GetTemplateError)?;
76  template
77      .render(context! {
78        messages => messages,
79        bos_token => "<s>",
80        eos_token => "</s>",
81      })
82      .map_err(ApplyTAIDETemplateError::RenderTemplateError)
83}
84
/// The built-in chat templates that `apply_template` can render.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChatTemplate {
    /// ChatML: each message rendered as `<|im_start|>{role}\n{content}<|im_end|>\n`.
    ChatML,
    /// Mistral instruct: `<s>` prefix, `[INST] ... [/INST]` user turns, assistant turns ending in `</s>`.
    MistralInstruct,
    /// TAIDE: `<s>[INST] ... [/INST]` user turns with an optional `<<SYS>>` system block.
    TAIDE,
}
91
92/// Apply chat template to messages, return the prompt
93///
94/// # Arguments
95/// * `messages` - a list of messages, each message contains `role` and `content`
96/// * `add_generation_prompt` - if `true`, attach `<|im_start|>assistant\n` at the end of the prompt
97/// * `template` - the jinja template
98///
99pub fn apply_template(
100    template: ChatTemplate,
101    messages: &Vec<Message>,
102    add_generation_prompt: bool,
103) -> Result<String, ApplyTemplateError> {
104    match template {
105        ChatTemplate::ChatML => apply_chatml_template(messages, add_generation_prompt)
106            .map_err(ApplyTemplateError::ApplyChatMLTemplateError),
107        ChatTemplate::MistralInstruct => {
108            apply_mistral_instruct_template(messages, add_generation_prompt)
109                .map_err(ApplyTemplateError::ApplyMistralInstructTemplateError)
110        }
111        ChatTemplate::TAIDE => apply_taide_template(messages)
112            .map_err(ApplyTemplateError::ApplyTAIDETemplateError),
113    }
114}
115
116#[derive(thiserror::Error, Debug)]
117pub enum ApplyChatMLTemplateError {
118    #[error("failed to add template")]
119    AddTemplateError(#[source] minijinja::Error),
120    #[error("failed to get template")]
121    GetTemplateError(#[source] minijinja::Error),
122    #[error("failed to render")]
123    RenderTemplateError(#[source] minijinja::Error),
124}
125
126#[derive(thiserror::Error, Debug)]
127pub enum ApplyMistralInstructTemplateError {
128    #[error("failed to add template")]
129    AddTemplateError(#[source] minijinja::Error),
130    #[error("failed to get template")]
131    GetTemplateError(#[source] minijinja::Error),
132    #[error("failed to render")]
133    RenderTemplateError(#[source] minijinja::Error),
134}
135
136#[derive(thiserror::Error, Debug)]
137pub enum ApplyTAIDETemplateError {
138    #[error("failed to add template")]
139    AddTemplateError(#[source] minijinja::Error),
140    #[error("failed to get template")]
141    GetTemplateError(#[source] minijinja::Error),
142    #[error("failed to render")]
143    RenderTemplateError(#[source] minijinja::Error),
144}
145
146
147#[derive(thiserror::Error, Debug)]
148pub enum ApplyTemplateError {
149    #[error("failed to apply chatml template")]
150    ApplyChatMLTemplateError(#[source] ApplyChatMLTemplateError),
151    #[error("failed to apply mistral instruct template")]
152    ApplyMistralInstructTemplateError(#[source] ApplyMistralInstructTemplateError),
153    #[error("failed to apply taide template")]
154    ApplyTAIDETemplateError(#[source] ApplyTAIDETemplateError),
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// System prompt shared by the ChatML tests.
    const TAX_SYSTEM: &str = "Assistant is an intelligent chatbot designed to help users answer their tax related questions.";

    /// System prompt shared by the TAIDE tests.
    const TAIDE_SYSTEM: &str = "你是一個來自台灣的AI助理,你的名字是 TAIDE。";

    /// ChatML suffix produced when `add_generation_prompt` is `true`.
    const CHATML_GENERATION_PROMPT: &str = "<|im_start|>assistant\n";

    /// Build a [`Message`] from borrowed role and content strings.
    fn msg(role: &str, content: &str) -> Message {
        Message {
            role: role.to_owned(),
            content: content.to_owned(),
        }
    }

    /// `apply_template` does not forward `add_generation_prompt` to the TAIDE template, so
    /// both settings must produce `expected`.
    fn assert_taide_prompt(messages: Vec<Message>, expected: &str) {
        for add_generation_prompt in [true, false] {
            let prompt =
                apply_template(ChatTemplate::TAIDE, &messages, add_generation_prompt).unwrap();
            assert_eq!(prompt, expected);
        }
    }

    #[test]
    fn test_apply_chatml_template_one_shot() {
        let messages = vec![
            msg("system", TAX_SYSTEM),
            msg("user", "Hello, who are you?"),
        ];
        let without_prompt = "<|im_start|>system\nAssistant is an intelligent chatbot designed to help users answer their tax related questions.<|im_end|>\n<|im_start|>user\nHello, who are you?<|im_end|>\n";

        let prompt = apply_template(ChatTemplate::ChatML, &messages, true).unwrap();
        assert_eq!(prompt, [without_prompt, CHATML_GENERATION_PROMPT].concat());

        let prompt = apply_template(ChatTemplate::ChatML, &messages, false).unwrap();
        assert_eq!(prompt, without_prompt);
    }

    #[test]
    fn test_apply_chatml_template_few_shots() {
        let messages = vec![
            msg("system", TAX_SYSTEM),
            msg("user", "When do I need to file my taxes by?"),
            msg("assistant", "In 2023, you will need to file your taxes by April 18th. The date falls after the usual April 15th deadline because April 15th falls on a Saturday in 2023."),
            msg("user", "How can I check the status of my tax refund?"),
        ];
        let without_prompt = "<|im_start|>system\nAssistant is an intelligent chatbot designed to help users answer their tax related questions.<|im_end|>\n<|im_start|>user\nWhen do I need to file my taxes by?<|im_end|>\n<|im_start|>assistant\nIn 2023, you will need to file your taxes by April 18th. The date falls after the usual April 15th deadline because April 15th falls on a Saturday in 2023.<|im_end|>\n<|im_start|>user\nHow can I check the status of my tax refund?<|im_end|>\n";

        let prompt = apply_template(ChatTemplate::ChatML, &messages, true).unwrap();
        assert_eq!(prompt, [without_prompt, CHATML_GENERATION_PROMPT].concat());

        let prompt = apply_template(ChatTemplate::ChatML, &messages, false).unwrap();
        assert_eq!(prompt, without_prompt);
    }

    #[test]
    fn test_apply_mistral_instruct_template_one_shot() {
        let messages = vec![msg("user", "Hello, who are you?")];

        let prompt = apply_template(ChatTemplate::MistralInstruct, &messages, true).unwrap();
        assert_eq!(prompt, "<s>[INST] Hello, who are you? [/INST]");
    }

    #[test]
    fn test_apply_mistral_instruct_template_few_shots() {
        // see https://huggingface.co/docs/transformers/main/chat_templating#introduction
        let messages = vec![
            msg("user", "Hello, who are you?"),
            msg("assistant", "I'm doing great. How can I help you today?"),
            msg("user", "I'd like to show off how chat templating works!"),
            msg("assistant", "Are you sure?"),
            msg("user", "Yes!"),
        ];

        let prompt = apply_template(ChatTemplate::MistralInstruct, &messages, true).unwrap();
        assert_eq!(
            prompt,
            "<s>[INST] Hello, who are you? [/INST]I'm doing great. How can I help you today?</s>[INST] I'd like to show off how chat templating works! [/INST]Are you sure?</s>[INST] Yes! [/INST]"
        );
    }

    #[test]
    fn test_apply_taide_template_one_shot() {
        assert_taide_prompt(vec![msg("user", "你好嗎?")], "<s>[INST] 你好嗎? [/INST]");
    }

    #[test]
    fn test_apply_taide_template_one_shot_with_sys_prompt() {
        assert_taide_prompt(
            vec![msg("system", TAIDE_SYSTEM), msg("user", "你好嗎?")],
            "<s>[INST] <<SYS>>\n你是一個來自台灣的AI助理,你的名字是 TAIDE。\n<</SYS>>\n\n你好嗎? [/INST]",
        );
    }

    #[test]
    fn test_apply_taide_template_few_shot_with_sys_prompt() {
        assert_taide_prompt(
            vec![
                msg("system", TAIDE_SYSTEM),
                msg("user", "你好嗎?"),
                msg("assistant", "我很好。"),
                msg("user", "今天天氣怎樣?"),
            ],
            "<s>[INST] <<SYS>>\n你是一個來自台灣的AI助理,你的名字是 TAIDE。\n<</SYS>>\n\n你好嗎? [/INST] 我很好。 </s><s>[INST] 今天天氣怎樣? [/INST]",
        );
    }

    #[test]
    fn test_apply_taide_template_few_shot_conversation_sys_prompt() {
        assert_taide_prompt(
            vec![
                msg("system", TAIDE_SYSTEM),
                msg("user", "你好嗎?"),
                msg("assistant", "我很好。"),
                msg("user", "今天天氣怎樣?"),
                msg("assistant", "大太陽。"),
                msg("user", "你敢感覺如何?"),
            ],
            "<s>[INST] <<SYS>>\n你是一個來自台灣的AI助理,你的名字是 TAIDE。\n<</SYS>>\n\n你好嗎? [/INST] 我很好。 </s><s>[INST] 今天天氣怎樣? [/INST] 大太陽。 </s><s>[INST] 你敢感覺如何? [/INST]",
        );
    }
}