1use serde::{Deserialize, Serialize};
16use std::fmt;
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
24#[serde(rename_all = "lowercase")]
25pub enum Role {
26 System,
27 User,
28 Assistant,
29}
30
31impl fmt::Display for Role {
32 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33 match self {
34 Self::System => write!(f, "system"),
35 Self::User => write!(f, "user"),
36 Self::Assistant => write!(f, "assistant"),
37 }
38 }
39}
40
41#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
43pub struct ChatMessage {
44 pub role: Role,
45 pub content: String,
46}
47
48impl ChatMessage {
49 fn with_role(role: Role, content: impl Into<String>) -> Self {
50 Self { role, content: content.into() }
51 }
52
53 pub fn system(content: impl Into<String>) -> Self {
54 Self::with_role(Role::System, content)
55 }
56
57 pub fn user(content: impl Into<String>) -> Self {
58 Self::with_role(Role::User, content)
59 }
60
61 pub fn assistant(content: impl Into<String>) -> Self {
62 Self::with_role(Role::Assistant, content)
63 }
64}
65
66#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
68pub enum TemplateFormat {
69 Llama2,
71 Mistral,
73 ChatML,
75 Alpaca,
77 Vicuna,
79 #[default]
81 Raw,
82}
83
84impl TemplateFormat {
85 #[must_use]
87 pub fn from_model_name(name: &str) -> Self {
88 let lower = name.to_lowercase();
89 if lower.contains("llama-2") || lower.contains("llama2") {
90 Self::Llama2
91 } else if lower.contains("mistral") || lower.contains("mixtral") {
92 Self::Mistral
93 } else if lower.contains("chatml") || lower.contains("openhermes") {
94 Self::ChatML
95 } else if lower.contains("alpaca") {
96 Self::Alpaca
97 } else if lower.contains("vicuna") {
98 Self::Vicuna
99 } else {
100 Self::Raw
101 }
102 }
103}
104
105#[derive(Debug, Clone)]
111pub struct ChatTemplateEngine {
112 format: TemplateFormat,
113 bos_token: Option<String>,
114 eos_token: Option<String>,
115}
116
117impl ChatTemplateEngine {
118 #[must_use]
120 pub fn new(format: TemplateFormat) -> Self {
121 let (bos_token, eos_token) = match format {
122 TemplateFormat::Llama2 | TemplateFormat::Mistral => {
123 (Some("<s>".to_string()), Some("</s>".to_string()))
124 }
125 _ => (None, None),
126 };
127 Self { format, bos_token, eos_token }
128 }
129
130 #[must_use]
132 pub fn from_model(model_name: &str) -> Self {
133 Self::new(TemplateFormat::from_model_name(model_name))
134 }
135
136 #[must_use]
138 pub fn format(&self) -> TemplateFormat {
139 self.format
140 }
141
142 #[must_use]
144 pub fn apply(&self, messages: &[ChatMessage]) -> String {
145 match self.format {
146 TemplateFormat::Llama2 => self.apply_llama2(messages),
147 TemplateFormat::Mistral => self.apply_mistral(messages),
148 TemplateFormat::ChatML => self.apply_chatml(messages),
149 TemplateFormat::Alpaca => self.apply_alpaca(messages),
150 TemplateFormat::Vicuna => self.apply_vicuna(messages),
151 TemplateFormat::Raw => self.apply_raw(messages),
152 }
153 }
154
155 #[must_use]
157 pub fn apply_prompt(&self, prompt: &str) -> String {
158 self.apply(&[ChatMessage::user(prompt)])
159 }
160
161 fn push_bos(&self, result: &mut String) {
163 if let Some(ref bos) = self.bos_token {
164 result.push_str(bos);
165 }
166 }
167
168 fn push_eos(&self, result: &mut String) {
170 if let Some(ref eos) = self.eos_token {
171 result.push_str(eos);
172 }
173 }
174
175 fn apply_llama2(&self, messages: &[ChatMessage]) -> String {
177 let mut result = String::new();
178 self.push_bos(&mut result);
179
180 let mut system_prompt = None;
181 for msg in messages {
182 match msg.role {
183 Role::System => {
184 system_prompt = Some(&msg.content);
185 }
186 Role::User => {
187 result.push_str("[INST] ");
188 if let Some(sys) = system_prompt.take() {
189 result.push_str("<<SYS>>\n");
190 result.push_str(sys);
191 result.push_str("\n<</SYS>>\n\n");
192 }
193 result.push_str(&msg.content);
194 result.push_str(" [/INST]");
195 }
196 Role::Assistant => {
197 result.push(' ');
198 result.push_str(&msg.content);
199 self.push_eos(&mut result);
200 }
201 }
202 }
203
204 result
205 }
206
207 fn apply_mistral(&self, messages: &[ChatMessage]) -> String {
209 let mut result = String::new();
210 self.push_bos(&mut result);
211
212 for msg in messages {
213 match msg.role {
214 Role::System => {
215 result.push_str("[INST] ");
217 result.push_str(&msg.content);
218 result.push_str("\n\n");
219 }
220 Role::User => {
221 if !result.contains("[INST]") {
222 result.push_str("[INST] ");
223 }
224 result.push_str(&msg.content);
225 result.push_str(" [/INST]");
226 }
227 Role::Assistant => {
228 result.push_str(&msg.content);
229 self.push_eos(&mut result);
230 }
231 }
232 }
233
234 result
235 }
236
237 fn apply_chatml(&self, messages: &[ChatMessage]) -> String {
239 let mut result = String::new();
240
241 for msg in messages {
242 result.push_str("<|im_start|>");
243 result.push_str(&msg.role.to_string());
244 result.push('\n');
245 result.push_str(&msg.content);
246 result.push_str("<|im_end|>\n");
247 }
248
249 result.push_str("<|im_start|>assistant\n");
251
252 result
253 }
254
255 fn apply_alpaca(&self, messages: &[ChatMessage]) -> String {
257 let mut result = String::new();
258
259 for msg in messages {
260 match msg.role {
261 Role::System => {
262 result.push_str(&msg.content);
263 result.push_str("\n\n");
264 }
265 Role::User => {
266 result.push_str("### Instruction:\n");
267 result.push_str(&msg.content);
268 result.push_str("\n\n### Response:\n");
269 }
270 Role::Assistant => {
271 result.push_str(&msg.content);
272 result.push('\n');
273 }
274 }
275 }
276
277 result
278 }
279
280 fn apply_vicuna(&self, messages: &[ChatMessage]) -> String {
282 let mut result = String::new();
283
284 for msg in messages {
285 match msg.role {
286 Role::System => {
287 result.push_str(&msg.content);
288 result.push_str("\n\n");
289 }
290 Role::User => {
291 result.push_str("USER: ");
292 result.push_str(&msg.content);
293 result.push_str("\nASSISTANT:");
294 }
295 Role::Assistant => {
296 result.push(' ');
297 result.push_str(&msg.content);
298 result.push('\n');
299 }
300 }
301 }
302
303 result
304 }
305
306 fn apply_raw(&self, messages: &[ChatMessage]) -> String {
308 messages.iter().map(|m| m.content.as_str()).collect::<Vec<_>>().join("\n")
309 }
310}
311
312impl Default for ChatTemplateEngine {
313 fn default() -> Self {
314 Self::new(TemplateFormat::Raw)
315 }
316}
317
#[cfg(test)]
// Test names embed upper-case requirement IDs (e.g. `SERVE_TPL_001`), which the
// `non_snake_case` lint would otherwise flag.
#[allow(non_snake_case)]
mod tests {
    use super::*;

    /// Asserts that `model_name` is detected as `expected`.
    fn expect_format(model_name: &str, expected: TemplateFormat) {
        let detected = TemplateFormat::from_model_name(model_name);
        assert_eq!(detected, expected, "model name {model_name:?} should map to {expected:?}");
    }

    /// Asserts that `msg` carries exactly `role` and `content`.
    fn expect_message(msg: &ChatMessage, role: Role, content: &str) {
        assert_eq!((msg.role, msg.content.as_str()), (role, content));
    }

    /// Asserts that every needle occurs somewhere in `rendered`.
    fn expect_all(rendered: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(rendered.contains(needle), "missing {needle:?} in {rendered:?}");
        }
    }

    /// Renders `prompt` as a lone user message in `format`.
    fn render(format: TemplateFormat, prompt: &str) -> String {
        ChatTemplateEngine::new(format).apply_prompt(prompt)
    }

    /// Renders `messages` with a fresh engine for `format`.
    fn render_all(format: TemplateFormat, messages: &[ChatMessage]) -> String {
        ChatTemplateEngine::new(format).apply(messages)
    }

    /// User greeting, assistant reply, then a follow-up user turn.
    fn conversation() -> Vec<ChatMessage> {
        vec![
            ChatMessage::user("Hi!"),
            ChatMessage::assistant("Hello!"),
            ChatMessage::user("How are you?"),
        ]
    }

    #[test]
    fn test_SERVE_TPL_001_role_display() {
        let expected = [
            (Role::System, "system"),
            (Role::User, "user"),
            (Role::Assistant, "assistant"),
        ];
        for (role, text) in expected {
            assert_eq!(role.to_string(), text);
        }
    }

    #[test]
    fn test_SERVE_TPL_001_chat_message_system() {
        let text = "You are a helpful assistant.";
        expect_message(&ChatMessage::system(text), Role::System, text);
    }

    #[test]
    fn test_SERVE_TPL_001_chat_message_user() {
        expect_message(&ChatMessage::user("Hello!"), Role::User, "Hello!");
    }

    #[test]
    fn test_SERVE_TPL_001_chat_message_assistant() {
        expect_message(&ChatMessage::assistant("Hi there!"), Role::Assistant, "Hi there!");
    }

    #[test]
    fn test_SERVE_TPL_002_detect_llama2() {
        for name in ["meta-llama/Llama-2-7b", "llama2-13b"] {
            expect_format(name, TemplateFormat::Llama2);
        }
    }

    #[test]
    fn test_SERVE_TPL_002_detect_mistral() {
        for name in ["mistralai/Mistral-7B", "mixtral-8x7b"] {
            expect_format(name, TemplateFormat::Mistral);
        }
    }

    #[test]
    fn test_SERVE_TPL_002_detect_chatml() {
        for name in ["OpenHermes-2.5", "chatml-model"] {
            expect_format(name, TemplateFormat::ChatML);
        }
    }

    #[test]
    fn test_SERVE_TPL_002_detect_alpaca() {
        expect_format("alpaca-7b", TemplateFormat::Alpaca);
    }

    #[test]
    fn test_SERVE_TPL_002_detect_vicuna() {
        expect_format("vicuna-13b", TemplateFormat::Vicuna);
    }

    #[test]
    fn test_SERVE_TPL_002_detect_raw_fallback() {
        expect_format("unknown-model", TemplateFormat::Raw);
    }

    #[test]
    fn test_SERVE_TPL_003_llama2_simple() {
        expect_all(&render(TemplateFormat::Llama2, "Hello!"), &["[INST]", "[/INST]", "Hello!"]);
    }

    #[test]
    fn test_SERVE_TPL_003_llama2_with_system() {
        let out = render_all(
            TemplateFormat::Llama2,
            &[ChatMessage::system("You are helpful."), ChatMessage::user("Hi!")],
        );
        expect_all(&out, &["<<SYS>>", "You are helpful.", "<</SYS>>", "Hi!"]);
    }

    #[test]
    fn test_SERVE_TPL_003_llama2_bos_token() {
        let out = render(TemplateFormat::Llama2, "Test");
        assert!(out.starts_with("<s>"));
    }

    #[test]
    fn test_SERVE_TPL_004_mistral_simple() {
        expect_all(&render(TemplateFormat::Mistral, "Hello!"), &["[INST]", "[/INST]"]);
    }

    #[test]
    fn test_SERVE_TPL_004_mistral_no_sys_tags() {
        let out = render_all(
            TemplateFormat::Mistral,
            &[ChatMessage::system("Be helpful."), ChatMessage::user("Hi!")],
        );
        assert!(!out.contains("<<SYS>>"));
    }

    #[test]
    fn test_SERVE_TPL_005_chatml_simple() {
        expect_all(
            &render(TemplateFormat::ChatML, "Hello!"),
            &["<|im_start|>user", "<|im_end|>", "<|im_start|>assistant"],
        );
    }

    #[test]
    fn test_SERVE_TPL_005_chatml_with_system() {
        let out = render_all(
            TemplateFormat::ChatML,
            &[ChatMessage::system("You are an AI."), ChatMessage::user("Hi!")],
        );
        expect_all(&out, &["<|im_start|>system", "You are an AI."]);
    }

    #[test]
    fn test_SERVE_TPL_006_alpaca_simple() {
        expect_all(
            &render(TemplateFormat::Alpaca, "What is 2+2?"),
            &["### Instruction:", "### Response:", "What is 2+2?"],
        );
    }

    #[test]
    fn test_SERVE_TPL_007_vicuna_simple() {
        expect_all(&render(TemplateFormat::Vicuna, "Hello!"), &["USER:", "ASSISTANT:"]);
    }

    #[test]
    fn test_SERVE_TPL_008_raw_passthrough() {
        assert_eq!(render(TemplateFormat::Raw, "Hello!"), "Hello!");
    }

    #[test]
    fn test_SERVE_TPL_008_raw_multiple_messages() {
        let out = render_all(TemplateFormat::Raw, &[ChatMessage::user("A"), ChatMessage::user("B")]);
        assert_eq!(out, "A\nB");
    }

    #[test]
    fn test_SERVE_TPL_009_from_model() {
        let detected = ChatTemplateEngine::from_model("meta-llama/Llama-2-7b-chat").format();
        assert_eq!(detected, TemplateFormat::Llama2);
    }

    #[test]
    fn test_SERVE_TPL_009_default() {
        assert_eq!(ChatTemplateEngine::default().format(), TemplateFormat::Raw);
    }

    #[test]
    fn test_SERVE_TPL_010_llama2_multiturn() {
        let out = render_all(TemplateFormat::Llama2, &conversation());
        assert!(out.matches("[INST]").count() >= 2);
    }

    #[test]
    fn test_SERVE_TPL_010_chatml_multiturn() {
        let out = render_all(TemplateFormat::ChatML, &conversation());
        assert!(out.matches("<|im_start|>").count() >= 3);
    }

    #[test]
    fn test_SERVE_TPL_011_vicuna_with_system() {
        let out = render_all(
            TemplateFormat::Vicuna,
            &[ChatMessage::system("You are helpful."), ChatMessage::user("Hi!")],
        );
        expect_all(&out, &["You are helpful.", "USER: Hi!", "ASSISTANT:"]);
    }

    #[test]
    fn test_SERVE_TPL_011_vicuna_with_assistant_response() {
        let out = render_all(
            TemplateFormat::Vicuna,
            &[ChatMessage::user("Hi!"), ChatMessage::assistant("Hello there!")],
        );
        expect_all(&out, &["USER: Hi!", " Hello there!"]);
    }

    #[test]
    fn test_SERVE_TPL_011_vicuna_multiturn() {
        let out = render_all(TemplateFormat::Vicuna, &conversation());
        assert_eq!(out.matches("USER:").count(), 2);
        assert!(out.contains(" Hello!"));
    }

    #[test]
    fn test_SERVE_TPL_011_vicuna_system_and_assistant() {
        let out = render_all(
            TemplateFormat::Vicuna,
            &[
                ChatMessage::system("Be concise."),
                ChatMessage::user("What is 2+2?"),
                ChatMessage::assistant("4"),
                ChatMessage::user("And 3+3?"),
            ],
        );
        expect_all(&out, &["Be concise.", "USER: What is 2+2?", " 4\n", "USER: And 3+3?"]);
    }

    #[test]
    fn test_SERVE_TPL_012_alpaca_with_system() {
        let out = render_all(
            TemplateFormat::Alpaca,
            &[ChatMessage::system("You are a tutor."), ChatMessage::user("Explain gravity.")],
        );
        expect_all(
            &out,
            &["You are a tutor.", "### Instruction:", "Explain gravity.", "### Response:"],
        );
    }

    #[test]
    fn test_SERVE_TPL_012_alpaca_with_assistant_response() {
        let out = render_all(
            TemplateFormat::Alpaca,
            &[ChatMessage::user("What is AI?"), ChatMessage::assistant("Artificial Intelligence.")],
        );
        expect_all(&out, &["### Instruction:", "What is AI?", "Artificial Intelligence.\n"]);
    }

    #[test]
    fn test_SERVE_TPL_012_alpaca_multiturn() {
        let out = render_all(TemplateFormat::Alpaca, &conversation());
        assert_eq!(out.matches("### Instruction:").count(), 2);
        assert!(out.contains("Hello!\n"));
    }

    #[test]
    fn test_SERVE_TPL_012_alpaca_system_and_multiturn() {
        let out = render_all(
            TemplateFormat::Alpaca,
            &[
                ChatMessage::system("Be brief."),
                ChatMessage::user("Define ML."),
                ChatMessage::assistant("Machine Learning."),
                ChatMessage::user("Define AI."),
            ],
        );
        expect_all(
            &out,
            &[
                "Be brief.\n\n",
                "### Instruction:\nDefine ML.",
                "Machine Learning.\n",
                "### Instruction:\nDefine AI.",
            ],
        );
    }

    #[test]
    fn test_SERVE_TPL_013_mistral_multiturn() {
        let out = render_all(TemplateFormat::Mistral, &conversation());
        assert!(out.starts_with("<s>"));
        expect_all(&out, &["[INST] Hi! [/INST]", "Hello!</s>"]);
    }

    #[test]
    fn test_SERVE_TPL_013_mistral_with_system_and_assistant() {
        let out = render_all(
            TemplateFormat::Mistral,
            &[
                ChatMessage::system("You are an expert."),
                ChatMessage::user("Explain ML."),
                ChatMessage::assistant("Machine Learning is..."),
                ChatMessage::user("More detail."),
            ],
        );
        expect_all(
            &out,
            &[
                "[INST] You are an expert.",
                "Explain ML. [/INST]",
                "Machine Learning is...</s>",
                "More detail. [/INST]",
            ],
        );
    }

    #[test]
    fn test_SERVE_TPL_013_mistral_system_prepends_to_first_inst() {
        let out = render_all(
            TemplateFormat::Mistral,
            &[ChatMessage::system("Be helpful."), ChatMessage::user("Hi!")],
        );
        expect_all(&out, &["[INST] Be helpful.", "Hi! [/INST]"]);
    }

    #[test]
    fn test_SERVE_TPL_014_llama2_multiturn_with_assistant() {
        let out = render_all(
            TemplateFormat::Llama2,
            &[
                ChatMessage::system("You are an AI."),
                ChatMessage::user("Hello!"),
                ChatMessage::assistant("Hi!"),
                ChatMessage::user("How are you?"),
            ],
        );
        assert!(out.starts_with("<s>"));
        expect_all(
            &out,
            &[
                "<<SYS>>",
                "You are an AI.",
                "<</SYS>>",
                " Hi!</s>",
                "[INST] How are you? [/INST]",
            ],
        );
    }

    #[test]
    fn test_SERVE_TPL_015_chatml_system_and_multiturn() {
        let out = render_all(
            TemplateFormat::ChatML,
            &[
                ChatMessage::system("Be concise."),
                ChatMessage::user("Hi!"),
                ChatMessage::assistant("Hello!"),
                ChatMessage::user("Bye!"),
            ],
        );
        expect_all(
            &out,
            &[
                "<|im_start|>system\nBe concise.<|im_end|>",
                "<|im_start|>user\nHi!<|im_end|>",
                "<|im_start|>assistant\nHello!<|im_end|>",
                "<|im_start|>user\nBye!<|im_end|>",
            ],
        );
        assert!(out.ends_with("<|im_start|>assistant\n"));
    }
}