sapient_tokenizers/
chat.rs1use std::path::Path;
9
10use anyhow::{Context, Result};
11use minijinja::Environment;
12use serde::{Deserialize, Serialize};
13
14#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
17#[serde(rename_all = "lowercase")]
18pub enum ChatRole {
19 System,
20 User,
21 Assistant,
22 Tool,
23}
24
25impl std::fmt::Display for ChatRole {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 match self {
28 ChatRole::System => f.write_str("system"),
29 ChatRole::User => f.write_str("user"),
30 ChatRole::Assistant => f.write_str("assistant"),
31 ChatRole::Tool => f.write_str("tool"),
32 }
33 }
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct ChatMessage {
40 pub role: ChatRole,
41 pub content: String,
42}
43
44impl ChatMessage {
45 pub fn system(content: impl Into<String>) -> Self {
46 Self {
47 role: ChatRole::System,
48 content: content.into(),
49 }
50 }
51 pub fn user(content: impl Into<String>) -> Self {
52 Self {
53 role: ChatRole::User,
54 content: content.into(),
55 }
56 }
57 pub fn assistant(content: impl Into<String>) -> Self {
58 Self {
59 role: ChatRole::Assistant,
60 content: content.into(),
61 }
62 }
63}
64
65pub struct ChatTemplate {
70 template_src: String,
71}
72
73impl ChatTemplate {
74 pub fn from_tokenizer_config(path: &Path) -> Result<Self> {
76 let text = std::fs::read_to_string(path).context("Failed to read tokenizer_config.json")?;
77 let config: serde_json::Value =
78 serde_json::from_str(&text).context("Invalid tokenizer_config.json")?;
79
80 let template_src = config["chat_template"]
81 .as_str()
82 .context("No chat_template found in tokenizer_config.json")?
83 .to_owned();
84
85 Ok(Self { template_src })
86 }
87
88 pub fn from_template(template: impl Into<String>) -> Self {
90 Self {
91 template_src: template.into(),
92 }
93 }
94
95 pub fn render(&self, messages: &[ChatMessage], add_generation_prompt: bool) -> Result<String> {
100 let mut env = Environment::new();
101
102 env.add_template("chat", &self.template_src)
104 .map_err(|e| anyhow::anyhow!("Template parse error: {e}"))?;
105
106 let tmpl = env
107 .get_template("chat")
108 .map_err(|e| anyhow::anyhow!("Template load error: {e}"))?;
109
110 let messages_val: Vec<serde_json::Value> = messages
112 .iter()
113 .map(|m| {
114 serde_json::json!({
115 "role": m.role.to_string(),
116 "content": m.content,
117 })
118 })
119 .collect();
120
121 let ctx = serde_json::json!({
122 "messages": messages_val,
123 "add_generation_prompt": add_generation_prompt,
124 "bos_token": "<s>",
125 "eos_token": "</s>",
126 });
127
128 tmpl.render(ctx)
129 .map_err(|e| anyhow::anyhow!("Template render error: {e}"))
130 }
131}
132
/// Ready-made template sources for common chat formats.
///
/// Every template iterates over `messages` and reads `message['role']` and
/// `message['content']`; those that support it append an assistant header
/// when `add_generation_prompt` is true.
pub mod builtin {
    /// ChatML: every message is wrapped in `<|im_start|>{role}\n…<|im_end|>\n`.
    pub const CHATML: &str = concat!(
        "{% for message in messages %}",
        "<|im_start|>{{ message['role'] }}\n{{ message['content'] }}<|im_end|>\n",
        "{% endfor %}",
        "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}",
    );

    /// Llama 3: `<|begin_of_text|>` followed by header/`<|eot_id|>` framed
    /// messages.
    pub const LLAMA3: &str = concat!(
        "<|begin_of_text|>",
        "{% for message in messages %}",
        "<|start_header_id|>{{ message['role'] }}<|end_header_id|>\n\n",
        "{{ message['content'] }}<|eot_id|>",
        "{% endfor %}",
        "{% if add_generation_prompt %}",
        "<|start_header_id|>assistant<|end_header_id|>\n\n",
        "{% endif %}",
    );

    /// Llama 2: `[INST] … [/INST]` user turns, assistant turns closed by
    /// `</s>`.
    ///
    /// A leading system message is emitted as a `<<SYS>>` block that opens
    /// the first `[INST]`. The first user turn after it therefore must not
    /// open a second `[INST]`; the `system_opened_inst` flag tracks that.
    /// Messages with other roles are skipped, and `add_generation_prompt` is
    /// not used.
    pub const LLAMA2: &str = concat!(
        "{% set system_opened_inst = false %}",
        "{% if messages[0]['role'] == 'system' %}",
        "{{ '[INST] <<SYS>>\n' + messages[0]['content'] + '\n<</SYS>>\n\n' }}",
        "{% set messages = messages[1:] %}",
        "{% set system_opened_inst = true %}",
        "{% endif %}",
        "{% for message in messages %}",
        "{% if message['role'] == 'user' %}",
        // Only open a new `[INST]` when the system block has not already done so.
        "{% if not (loop.first and system_opened_inst) %}{{ '[INST] ' }}{% endif %}",
        "{{ message['content'] + ' [/INST]' }}",
        "{% elif message['role'] == 'assistant' %}{{ message['content'] + '</s>' }}",
        "{% endif %}",
        "{% endfor %}",
    );

    /// Gemma: `<start_of_turn>` / `<end_of_turn>` framing, with the assistant
    /// rendered as `model`. Only `user` and `assistant` messages are emitted.
    pub const GEMMA: &str = concat!(
        "{% for message in messages %}",
        "{% if message['role'] == 'user' %}<start_of_turn>user\n{{ message['content'] }}<end_of_turn>\n",
        "{% elif message['role'] == 'assistant' %}<start_of_turn>model\n{{ message['content'] }}<end_of_turn>\n",
        "{% endif %}",
        "{% endfor %}",
        "{% if add_generation_prompt %}<start_of_turn>model\n{% endif %}",
    );

    /// Zephyr: `<|system|>`, `<|user|>` and `<|assistant|>` sections, each
    /// closed by `</s>`. Messages with other roles are skipped.
    pub const ZEPHYR: &str = concat!(
        "{% for message in messages %}",
        "{% if message['role'] == 'system' %}",
        "<|system|>\n{{ message['content'] }}</s>\n",
        "{% elif message['role'] == 'user' %}",
        "<|user|>\n{{ message['content'] }}</s>\n",
        "{% elif message['role'] == 'assistant' %}",
        "<|assistant|>\n{{ message['content'] }}</s>\n",
        "{% endif %}",
        "{% endfor %}",
        "{% if add_generation_prompt %}<|assistant|>\n{% endif %}",
    );
}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders `conversation` through `source` with the generation prompt
    /// enabled, panicking if rendering fails.
    fn render_with_prompt(source: &str, conversation: &[ChatMessage]) -> String {
        ChatTemplate::from_template(source)
            .render(conversation, true)
            .unwrap()
    }

    /// ChatML output contains a header for each message plus the trailing
    /// assistant header.
    #[test]
    fn chatml_render() {
        let conversation = [
            ChatMessage::system("You are a helpful assistant."),
            ChatMessage::user("Hello!"),
        ];
        let rendered = render_with_prompt(builtin::CHATML, &conversation);

        for marker in [
            "<|im_start|>system",
            "<|im_start|>user",
            "<|im_start|>assistant",
        ] {
            assert!(rendered.contains(marker));
        }
    }

    /// Llama 3 output starts the text and frames the user turn with header
    /// tokens.
    #[test]
    fn llama3_render() {
        let conversation = [ChatMessage::user("What is 2+2?")];
        let rendered = render_with_prompt(builtin::LLAMA3, &conversation);

        for marker in [
            "<|begin_of_text|>",
            "<|start_header_id|>user<|end_header_id|>",
        ] {
            assert!(rendered.contains(marker));
        }
    }
}