/// One turn of a conversation: who spoke (`role`) and what was said (`content`).
///
/// Both fields are plain owned strings; nothing is validated or escaped here.
/// Equality and hashing compare `role` and `content` by value.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ChatMessage {
    /// Speaker label. The constructors on this type use `"system"`, `"user"`
    /// and `"assistant"`, but the field is public and accepts any string.
    pub role: String,
    /// Text of the message.
    pub content: String,
}

impl ChatMessage {
    /// Shared constructor body: pairs a fixed role label with the given content.
    fn with_role(role: &str, content: impl Into<String>) -> Self {
        Self {
            role: role.to_owned(),
            content: content.into(),
        }
    }

    /// Creates a message whose role is `"system"`.
    pub fn system(content: impl Into<String>) -> Self {
        Self::with_role("system", content)
    }

    /// Creates a message whose role is `"user"`.
    pub fn user(content: impl Into<String>) -> Self {
        Self::with_role("user", content)
    }

    /// Creates a message whose role is `"assistant"`.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::with_role("assistant", content)
    }
}
/// Selects how a list of [`ChatMessage`]s is rendered into a single prompt string.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChatTemplate {
    /// Each message is wrapped as `<|im_start|>{role}\n{content}<|im_end|>\n`,
    /// followed by an open `<|im_start|>assistant\n` header.
    ChatML,
    /// `<|begin_of_text|>`, then each message as
    /// `<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>`,
    /// followed by an open assistant header.
    Llama3,
    /// Rendered with exactly the same formatter as [`ChatTemplate::ChatML`].
    Qwen,
    /// Message contents only, separated by newlines; roles are dropped.
    Raw,
}

impl ChatTemplate {
    /// Picks a template from a model architecture name, ignoring ASCII case.
    ///
    /// * `"qwen2"` selects [`ChatTemplate::Qwen`]
    /// * `"mistral"` selects [`ChatTemplate::Raw`]
    /// * anything else, including `"llama"`, selects [`ChatTemplate::ChatML`]
    ///
    /// NOTE(review): this function never returns [`ChatTemplate::Llama3`]; the
    /// `"llama"` -> ChatML mapping is pinned by the existing unit tests, so it is
    /// kept as-is. Confirm that is still the intended default.
    pub fn from_architecture(arch: &str) -> Self {
        // Compare in place instead of allocating a lowercase copy of `arch`.
        let is = |name: &str| arch.eq_ignore_ascii_case(name);

        if is("qwen2") {
            ChatTemplate::Qwen
        } else if is("mistral") {
            ChatTemplate::Raw
        } else {
            ChatTemplate::ChatML
        }
    }

    /// Renders `messages` with this template and returns the prompt text.
    pub fn format(&self, messages: &[ChatMessage]) -> String {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match *self {
            ChatTemplate::ChatML | ChatTemplate::Qwen => format_chatml(messages),
            ChatTemplate::Llama3 => format_llama3(messages),
            ChatTemplate::Raw => format_raw(messages),
        }
    }

    /// Renders a conversation consisting of a single user message.
    pub fn format_prompt(&self, prompt: &str) -> String {
        let conversation = [ChatMessage::user(prompt)];
        self.format(&conversation)
    }

    /// Renders a conversation consisting of a system message followed by a
    /// user message.
    pub fn format_with_system(&self, system: &str, prompt: &str) -> String {
        let conversation = [ChatMessage::system(system), ChatMessage::user(prompt)];
        self.format(&conversation)
    }
}
/// Renders `messages` in ChatML framing.
///
/// Every message becomes `<|im_start|>{role}\n{content}<|im_end|>\n`; role and
/// content are inserted verbatim. The result always ends with an open
/// `<|im_start|>assistant\n` header, even when `messages` is empty.
fn format_chatml(messages: &[ChatMessage]) -> String {
    let mut prompt = String::new();
    for turn in messages {
        // One framed turn: opening tag, role line, body, closing tag.
        prompt.extend([
            "<|im_start|>",
            turn.role.as_str(),
            "\n",
            turn.content.as_str(),
            "<|im_end|>\n",
        ]);
    }
    // Leave the assistant turn open so generation continues from here.
    prompt += "<|im_start|>assistant\n";
    prompt
}
/// Renders `messages` in Llama 3 header framing.
///
/// Output starts with `<|begin_of_text|>`, then each message is written as
/// `<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>` with
/// role and content inserted verbatim. An open assistant header is appended
/// last, even when `messages` is empty.
fn format_llama3(messages: &[ChatMessage]) -> String {
    const HEADER_OPEN: &str = "<|start_header_id|>";
    const HEADER_CLOSE: &str = "<|end_header_id|>\n\n";

    let mut prompt = String::from("<|begin_of_text|>");
    for turn in messages {
        prompt += HEADER_OPEN;
        prompt += turn.role.as_str();
        prompt += HEADER_CLOSE;
        prompt += turn.content.as_str();
        prompt += "<|eot_id|>";
    }

    // Trailing header with no body: the model is expected to fill it in.
    prompt += HEADER_OPEN;
    prompt += "assistant";
    prompt += HEADER_CLOSE;
    prompt
}
/// Concatenates message contents, ignoring roles and adding no template markup.
///
/// A `'\n'` is written before a message only when something has already been
/// written. As a consequence, leading messages with empty content contribute
/// nothing at all (not even a separator), while an empty message after
/// non-empty text still yields a blank line.
fn format_raw(messages: &[ChatMessage]) -> String {
    messages.iter().fold(String::new(), |mut joined, turn| {
        // Separator depends on accumulated text, not on the message index.
        if !joined.is_empty() {
            joined.push('\n');
        }
        joined += turn.content.as_str();
        joined
    })
}
/// Unit tests for the chat templates.
///
/// Wherever the full output is short enough to read, it is pinned exactly
/// rather than probed with `contains`, so framing regressions (missing
/// terminators, reordered turns, dropped headers) fail loudly.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn chatml_single_user() {
        let result = ChatTemplate::ChatML.format_prompt("Hello");
        assert_eq!(
            result,
            "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"
        );
    }

    #[test]
    fn chatml_with_system() {
        let result = ChatTemplate::ChatML.format_with_system("You are helpful.", "Hello");
        assert_eq!(
            result,
            concat!(
                "<|im_start|>system\nYou are helpful.<|im_end|>\n",
                "<|im_start|>user\nHello<|im_end|>\n",
                "<|im_start|>assistant\n",
            )
        );
    }

    #[test]
    fn chatml_multi_turn() {
        let messages = vec![
            ChatMessage::user("What is Rust?"),
            ChatMessage::assistant("A systems programming language."),
            ChatMessage::user("Tell me more."),
        ];
        let result = ChatTemplate::ChatML.format(&messages);
        assert_eq!(
            result,
            concat!(
                "<|im_start|>user\nWhat is Rust?<|im_end|>\n",
                "<|im_start|>assistant\nA systems programming language.<|im_end|>\n",
                "<|im_start|>user\nTell me more.<|im_end|>\n",
                "<|im_start|>assistant\n",
            )
        );
    }

    #[test]
    fn qwen_renders_identically_to_chatml() {
        let messages = vec![
            ChatMessage::system("Be brief."),
            ChatMessage::user("Hi"),
        ];
        assert_eq!(
            ChatTemplate::Qwen.format(&messages),
            ChatTemplate::ChatML.format(&messages)
        );
    }

    #[test]
    fn llama3_format() {
        let result = ChatTemplate::Llama3.format_prompt("Hello");
        assert_eq!(
            result,
            concat!(
                "<|begin_of_text|>",
                "<|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|>",
                "<|start_header_id|>assistant<|end_header_id|>\n\n",
            )
        );
    }

    #[test]
    fn llama3_with_system() {
        let result = ChatTemplate::Llama3.format_with_system("Be brief.", "Hello");
        assert_eq!(
            result,
            concat!(
                "<|begin_of_text|>",
                "<|start_header_id|>system<|end_header_id|>\n\nBe brief.<|eot_id|>",
                "<|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|>",
                "<|start_header_id|>assistant<|end_header_id|>\n\n",
            )
        );
    }

    #[test]
    fn detect_from_architecture() {
        assert_eq!(
            ChatTemplate::from_architecture("llama"),
            ChatTemplate::ChatML
        );
        assert_eq!(ChatTemplate::from_architecture("qwen2"), ChatTemplate::Qwen);
        assert_eq!(ChatTemplate::from_architecture("mistral"), ChatTemplate::Raw);
    }

    #[test]
    fn from_architecture_ignores_case() {
        assert_eq!(ChatTemplate::from_architecture("QWEN2"), ChatTemplate::Qwen);
        assert_eq!(ChatTemplate::from_architecture("Mistral"), ChatTemplate::Raw);
        assert_eq!(ChatTemplate::from_architecture("LLaMA"), ChatTemplate::ChatML);
    }

    #[test]
    fn raw_format() {
        let messages = vec![ChatMessage::user("Hello"), ChatMessage::user("World")];
        let result = ChatTemplate::Raw.format(&messages);
        assert_eq!(result, "Hello\nWorld");
    }

    #[test]
    fn raw_drops_roles_and_handles_empty_input() {
        assert_eq!(ChatTemplate::Raw.format(&[]), "");
        assert_eq!(
            ChatTemplate::Raw.format_with_system("Be brief.", "Hello"),
            "Be brief.\nHello"
        );
    }

    #[test]
    fn chatml_empty_messages_produces_assistant_header() {
        let result = ChatTemplate::ChatML.format(&[]);
        assert_eq!(
            result, "<|im_start|>assistant\n",
            "empty messages should produce just the assistant header"
        );
    }

    #[test]
    fn llama3_empty_messages_produces_assistant_header() {
        let result = ChatTemplate::Llama3.format(&[]);
        assert_eq!(
            result, "<|begin_of_text|><|start_header_id|>assistant<|end_header_id|>\n\n",
            "empty Llama3 messages should still produce assistant header"
        );
    }

    #[test]
    fn chatml_handles_special_characters_in_content() {
        let content = "Here is code:\n```rust\nfn main() { println!(\"<|test|>\"); }\n```";
        let result = ChatTemplate::ChatML.format_prompt(content);
        assert!(
            result.contains(content),
            "special characters in content should be preserved verbatim"
        );
        assert!(result.starts_with("<|im_start|>user\n"));
        assert!(result.ends_with("<|im_end|>\n<|im_start|>assistant\n"));
    }

    #[test]
    fn chatml_multi_turn_preserves_order() {
        let messages = vec![
            ChatMessage::system("You are a calculator."),
            ChatMessage::user("What is 2+2?"),
            ChatMessage::assistant("4"),
            ChatMessage::user("And 3+3?"),
        ];
        let result = ChatTemplate::ChatML.format(&messages);

        let sys_pos = result
            .find("system\nYou are a calculator.")
            .expect("system turn should be present");
        let user1_pos = result
            .find("user\nWhat is 2+2?")
            .expect("first user turn should be present");
        let asst_pos = result
            .find("assistant\n4")
            .expect("assistant turn should be present");
        let user2_pos = result
            .find("user\nAnd 3+3?")
            .expect("second user turn should be present");
        let final_asst = result
            .rfind("<|im_start|>assistant\n")
            .expect("trailing assistant header should be present");

        assert!(sys_pos < user1_pos, "system should come before first user");
        assert!(
            user1_pos < asst_pos,
            "first user should come before assistant response"
        );
        assert!(
            asst_pos < user2_pos,
            "assistant response should come before second user"
        );
        assert!(
            user2_pos < final_asst,
            "second user should come before final assistant prompt"
        );
    }

    #[test]
    fn from_architecture_unknown_defaults_to_chatml() {
        let template = ChatTemplate::from_architecture("unknown_arch_xyz");
        assert_eq!(template, ChatTemplate::ChatML);
        assert_eq!(ChatTemplate::from_architecture(""), ChatTemplate::ChatML);
    }
}