1use minijinja::{context, Environment};
36use serde::{Deserialize, Serialize};
37use std::path::Path;
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct Message {
42 pub role: String,
43 pub content: String,
44}
45
46impl Message {
47 pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
48 Self {
49 role: role.into(),
50 content: content.into(),
51 }
52 }
53
54 pub fn system(content: impl Into<String>) -> Self {
55 Self::new("system", content)
56 }
57
58 pub fn user(content: impl Into<String>) -> Self {
59 Self::new("user", content)
60 }
61
62 pub fn assistant(content: impl Into<String>) -> Self {
63 Self::new("assistant", content)
64 }
65}
66
/// Switches that influence how a conversation is rendered.
///
/// The default value has every flag off and no extra context.
/// `PartialEq`/`Eq` are derived so two option sets can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ChatTemplateOptions {
    /// Exposed to the template as `add_generation_prompt`; templates typically
    /// use it to decide whether to open a fresh assistant turn at the end.
    pub add_generation_prompt: bool,
    /// Exposed to the template as `continue_final_message`.
    pub continue_final_message: bool,
    /// Exposed to the template as `enable_thinking`.
    pub enable_thinking: bool,
    /// Additional named string values carried alongside the flags.
    // NOTE(review): whether these reach the template is decided by the renderer
    // (`ChatTemplate::apply`) — confirm there.
    pub extra_context: std::collections::HashMap<String, String>,
}
79
80impl ChatTemplateOptions {
81 pub fn for_generation() -> Self {
82 Self {
83 add_generation_prompt: true,
84 ..Default::default()
85 }
86 }
87
88 pub fn for_training() -> Self {
89 Self {
90 add_generation_prompt: false,
91 ..Default::default()
92 }
93 }
94
95 pub fn with_thinking(mut self) -> Self {
96 self.enable_thinking = true;
97 self
98 }
99}
100
101#[derive(Debug, Clone, Default, Deserialize)]
103pub struct TokenConfig {
104 #[serde(default)]
105 pub bos_token: Option<StringOrToken>,
106 #[serde(default)]
107 pub eos_token: Option<StringOrToken>,
108 #[serde(default)]
109 pub unk_token: Option<StringOrToken>,
110 #[serde(default)]
111 pub pad_token: Option<StringOrToken>,
112 #[serde(default)]
113 pub chat_template: Option<ChatTemplateConfig>,
114}
115
116#[derive(Debug, Clone, Deserialize)]
118#[serde(untagged)]
119pub enum StringOrToken {
120 String(String),
121 Token { content: String },
122}
123
124impl StringOrToken {
125 pub fn as_str(&self) -> &str {
126 match self {
127 StringOrToken::String(s) => s,
128 StringOrToken::Token { content } => content,
129 }
130 }
131}
132
133impl Default for StringOrToken {
134 fn default() -> Self {
135 StringOrToken::String(String::new())
136 }
137}
138
139#[derive(Debug, Clone, Deserialize)]
141#[serde(untagged)]
142pub enum ChatTemplateConfig {
143 Single(String),
144 Multiple(Vec<NamedTemplate>),
145}
146
147#[derive(Debug, Clone, Deserialize)]
148pub struct NamedTemplate {
149 pub name: String,
150 pub template: String,
151}
152
153pub struct ChatTemplate {
155 env: Environment<'static>,
156 bos_token: String,
157 eos_token: String,
158}
159
160impl ChatTemplate {
161 pub fn new(
163 template: impl Into<String>,
164 bos_token: impl Into<String>,
165 eos_token: impl Into<String>,
166 ) -> Result<Self, ChatTemplateError> {
167 let mut env = Environment::new();
168 env.add_function("raise_exception", |msg: String| -> Result<String, _> {
170 Err(minijinja::Error::new(
171 minijinja::ErrorKind::InvalidOperation,
172 msg,
173 ))
174 });
175
176 env.add_template_owned("chat".to_string(), template.into())
177 .map_err(|e| ChatTemplateError::TemplateError(e.to_string()))?;
178
179 Ok(Self {
180 env,
181 bos_token: bos_token.into(),
182 eos_token: eos_token.into(),
183 })
184 }
185
186 pub fn from_tokenizer_config(path: impl AsRef<Path>) -> Result<Self, ChatTemplateError> {
188 let content = std::fs::read_to_string(path.as_ref())
189 .map_err(|e| ChatTemplateError::IoError(e.to_string()))?;
190
191 Self::from_tokenizer_config_str(&content)
192 }
193
194 pub fn from_tokenizer_config_str(json: &str) -> Result<Self, ChatTemplateError> {
196 let config: TokenConfig =
197 serde_json::from_str(json).map_err(|e| ChatTemplateError::ParseError(e.to_string()))?;
198
199 let template = match config.chat_template {
200 Some(ChatTemplateConfig::Single(t)) => t,
201 Some(ChatTemplateConfig::Multiple(templates)) => {
202 templates
204 .iter()
205 .find(|t| t.name == "default")
206 .or_else(|| templates.first())
207 .map(|t| t.template.clone())
208 .ok_or(ChatTemplateError::NoTemplate)?
209 }
210 None => return Err(ChatTemplateError::NoTemplate),
211 };
212
213 let bos = config
214 .bos_token
215 .map(|t| t.as_str().to_string())
216 .unwrap_or_default();
217 let eos = config
218 .eos_token
219 .map(|t| t.as_str().to_string())
220 .unwrap_or_default();
221
222 Self::new(template, bos, eos)
223 }
224
225 pub fn chatml() -> Self {
227 let template = r#"
228{%- for message in messages %}
229{{- '<|im_start|>' + message.role + '\n' + message.content | trim + '<|im_end|>\n' }}
230{%- endfor %}
231{%- if add_generation_prompt %}
232{{- '<|im_start|>assistant\n' }}
233{%- endif %}
234"#;
235 Self::new(template, "", "<|im_end|>").unwrap()
236 }
237
238 pub fn chatml_with_thinking() -> Self {
240 let template = r#"
241{%- for message in messages %}
242{{- '<|im_start|>' + message.role + '\n' + message.content | trim + '<|im_end|>\n' }}
243{%- endfor %}
244{%- if add_generation_prompt %}
245{%- if enable_thinking %}
246{{- '<|im_start|>assistant\n<think>\n' }}
247{%- else %}
248{{- '<|im_start|>assistant\n' }}
249{%- endif %}
250{%- endif %}
251"#;
252 Self::new(template, "", "<|im_end|>").unwrap()
253 }
254
255 pub fn llama2() -> Self {
257 let template = r#"
258{%- if messages[0]['role'] == 'system' %}
259 {%- set system_message = '<<SYS>>\n' + messages[0]['content'] + '\n<</SYS>>\n\n' %}
260 {%- set messages = messages[1:] %}
261{%- else %}
262 {%- set system_message = '' %}
263{%- endif %}
264{%- for message in messages %}
265 {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
266 {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
267 {%- endif %}
268 {%- if loop.index0 == 0 %}
269 {{- bos_token + '[INST] ' + system_message + message['content'] + ' [/INST]' }}
270 {%- elif message['role'] == 'user' %}
271 {{- bos_token + '[INST] ' + message['content'] + ' [/INST]' }}
272 {%- elif message['role'] == 'assistant' %}
273 {{- ' ' + message['content'] + ' ' + eos_token }}
274 {%- endif %}
275{%- endfor %}
276"#;
277 Self::new(template, "<s>", "</s>").unwrap()
278 }
279
280 pub fn llama3() -> Self {
282 let template = r#"
283{%- set loop_messages = messages %}
284{%- for message in loop_messages %}
285 {%- set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' %}
286 {%- if loop.index0 == 0 %}
287 {{- bos_token + content }}
288 {%- else %}
289 {{- content }}
290 {%- endif %}
291{%- endfor %}
292{%- if add_generation_prompt %}
293 {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
294{%- endif %}
295"#;
296 Self::new(template, "<|begin_of_text|>", "<|eot_id|>").unwrap()
297 }
298
299 pub fn mistral() -> Self {
301 let template = r#"
302{{- bos_token }}
303{%- for message in messages %}
304 {%- if message['role'] == 'user' %}
305 {{- '[INST] ' + message['content'] + ' [/INST]' }}
306 {%- elif message['role'] == 'assistant' %}
307 {{- ' ' + message['content'] + eos_token }}
308 {%- endif %}
309{%- endfor %}
310"#;
311 Self::new(template, "<s>", "</s>").unwrap()
312 }
313
314 pub fn gemma() -> Self {
316 let template = r#"
317{%- for message in messages %}
318 {%- if message['role'] == 'user' %}
319 {{- '<start_of_turn>user\n' + message['content'] + '<end_of_turn>\n' }}
320 {%- elif message['role'] == 'assistant' %}
321 {{- '<start_of_turn>model\n' + message['content'] + '<end_of_turn>\n' }}
322 {%- endif %}
323{%- endfor %}
324{%- if add_generation_prompt %}
325 {{- '<start_of_turn>model\n' }}
326{%- endif %}
327"#;
328 Self::new(template, "<bos>", "<eos>").unwrap()
329 }
330
331 pub fn apply(
333 &self,
334 messages: &[Message],
335 options: &ChatTemplateOptions,
336 ) -> Result<String, ChatTemplateError> {
337 let template = self
338 .env
339 .get_template("chat")
340 .map_err(|e| ChatTemplateError::TemplateError(e.to_string()))?;
341
342 let result = template
343 .render(context! {
344 messages => messages,
345 add_generation_prompt => options.add_generation_prompt,
346 continue_final_message => options.continue_final_message,
347 enable_thinking => options.enable_thinking,
348 bos_token => &self.bos_token,
349 eos_token => &self.eos_token,
350 })
351 .map_err(|e| ChatTemplateError::RenderError(e.to_string()))?;
352
353 Ok(result.trim_start().to_string())
354 }
355
356 pub fn apply_for_generation(&self, messages: &[Message]) -> Result<String, ChatTemplateError> {
358 self.apply(messages, &ChatTemplateOptions::for_generation())
359 }
360}
361
362pub struct Conversation {
364 messages: Vec<Message>,
365 template: ChatTemplate,
366 options: ChatTemplateOptions,
367}
368
369impl Conversation {
370 pub fn new(template: ChatTemplate, system_prompt: impl Into<String>) -> Self {
372 Self {
373 messages: vec![Message::system(system_prompt)],
374 template,
375 options: ChatTemplateOptions::for_generation(),
376 }
377 }
378
379 pub fn without_system(template: ChatTemplate) -> Self {
381 Self {
382 messages: Vec::new(),
383 template,
384 options: ChatTemplateOptions::for_generation(),
385 }
386 }
387
388 pub fn with_options(mut self, options: ChatTemplateOptions) -> Self {
390 self.options = options;
391 self
392 }
393
394 pub fn user_turn(&mut self, content: impl Into<String>) -> Result<String, ChatTemplateError> {
396 self.messages.push(Message::user(content));
397 self.template.apply(&self.messages, &self.options)
398 }
399
400 pub fn assistant_response(&mut self, content: impl Into<String>) {
402 self.messages.push(Message::assistant(content));
403 }
404
405 pub fn add_message(&mut self, message: Message) {
407 self.messages.push(message);
408 }
409
410 pub fn messages(&self) -> &[Message] {
412 &self.messages
413 }
414
415 pub fn clear(&mut self) {
417 if let Some(first) = self.messages.first() {
418 if first.role == "system" {
419 let system = self.messages.remove(0);
420 self.messages.clear();
421 self.messages.push(system);
422 return;
423 }
424 }
425 self.messages.clear();
426 }
427
428 pub fn format_history(&self) -> Result<String, ChatTemplateError> {
430 self.template
431 .apply(&self.messages, &ChatTemplateOptions::for_training())
432 }
433}
434
/// Everything that can go wrong while loading or rendering a chat template.
///
/// Each payload is the text of the underlying error. `Clone`, `PartialEq`
/// and `Eq` are derived so errors can be stored, compared and asserted on.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChatTemplateError {
    /// The configuration file could not be read.
    IoError(String),
    /// The configuration text could not be deserialized.
    ParseError(String),
    /// The template could not be compiled or looked up.
    TemplateError(String),
    /// Rendering the template failed.
    RenderError(String),
    /// The configuration contains no usable chat template.
    NoTemplate,
}
444
445impl std::fmt::Display for ChatTemplateError {
446 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
447 match self {
448 Self::IoError(e) => write!(f, "IO error: {}", e),
449 Self::ParseError(e) => write!(f, "Parse error: {}", e),
450 Self::TemplateError(e) => write!(f, "Template error: {}", e),
451 Self::RenderError(e) => write!(f, "Render error: {}", e),
452 Self::NoTemplate => write!(f, "No chat_template found in config"),
453 }
454 }
455}
456
457impl std::error::Error for ChatTemplateError {}
458
#[cfg(test)]
mod tests {
    use super::*;

    /// ChatML wraps each message in `<|im_start|>` / `<|im_end|>` and ends
    /// with an open assistant turn when generating.
    #[test]
    fn test_chatml_basic() {
        let history = vec![Message::system("You are helpful."), Message::user("Hello")];

        let rendered = ChatTemplate::chatml()
            .apply_for_generation(&history)
            .unwrap();

        assert!(rendered.contains("<|im_start|>system\nYou are helpful.<|im_end|>"));
        assert!(rendered.contains("<|im_start|>user\nHello<|im_end|>"));
        assert!(rendered.ends_with("<|im_start|>assistant\n"));
    }

    /// Every prompt contains the whole history accumulated so far.
    #[test]
    fn test_multi_turn_conversation() {
        let mut chat = Conversation::new(ChatTemplate::chatml(), "You are helpful.");

        let first_prompt = chat.user_turn("Hi").unwrap();
        assert!(first_prompt.contains("Hi"));

        chat.assistant_response("Hello!");

        let second_prompt = chat.user_turn("How are you?").unwrap();
        assert!(second_prompt.contains("Hi"));
        assert!(second_prompt.contains("Hello!"));
        assert!(second_prompt.contains("How are you?"));
    }

    /// With thinking enabled the generation prompt opens a `<think>` block.
    #[test]
    fn test_thinking_mode() {
        let options = ChatTemplateOptions::for_generation().with_thinking();
        let history = vec![Message::user("Think about this")];

        let rendered = ChatTemplate::chatml_with_thinking()
            .apply(&history, &options)
            .unwrap();

        assert!(rendered.contains("<think>"));
    }

    /// Llama 3 output carries BOS, per-role headers and end-of-turn markers.
    #[test]
    fn test_llama3_format() {
        let history = vec![Message::system("You are helpful."), Message::user("Hello")];

        let rendered = ChatTemplate::llama3()
            .apply_for_generation(&history)
            .unwrap();

        assert!(rendered.contains("<|begin_of_text|>"));
        assert!(rendered.contains("<|start_header_id|>system<|end_header_id|>"));
        assert!(rendered.contains("<|start_header_id|>user<|end_header_id|>"));
        assert!(rendered.contains("<|eot_id|>"));
    }

    /// A template given as a plain string in the config is picked up and used.
    #[test]
    fn test_from_json_config() {
        let config = r#"{
            "bos_token": "<s>",
            "eos_token": "</s>",
            "chat_template": "{% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}"
        }"#;

        let loaded = ChatTemplate::from_tokenizer_config_str(config).unwrap();
        let rendered = loaded
            .apply_for_generation(&[Message::user("test")])
            .unwrap();

        assert!(rendered.contains("user: test"));
    }
}