use serde::{Deserialize, Serialize};
/// Prompt markup family used by [`TemplateFamily::render`] to format a conversation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TemplateFamily {
    /// Turns wrapped in `<|im_start|>{role}` … `<|im_end|>` markers.
    ChatML,
    /// Turns introduced by `<|start_header_id|>{role}<|end_header_id|>` and closed by `<|eot_id|>`.
    Llama3,
    /// Plain `{role}: {content}` lines.
    OpenChat,
}
impl TemplateFamily {
pub fn render(&self, system: Option<&str>, messages: &[(String, String)], input: Option<&str>) -> String {
match self {
TemplateFamily::ChatML => {
let mut s = String::new();
if let Some(sys) = system { s.push_str(&format!("<|im_start|>system\n{}<|im_end|>\n", sys)); }
for (role, content) in messages { s.push_str(&format!("<|im_start|>{}\n{}<|im_end|>\n", role, content)); }
if let Some(inp) = input { s.push_str(&format!("<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", inp)); }
s
}
TemplateFamily::Llama3 => {
let mut s = String::new();
if let Some(sys) = system { s.push_str(&format!("<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{}<|eot_id|>", sys)); }
for (role, content) in messages { s.push_str(&format!("<|start_header_id|>{}<|end_header_id|>\n{}<|eot_id|>", role, content)); }
if let Some(inp) = input { s.push_str(&format!("<|start_header_id|>user<|end_header_id|>\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n", inp)); }
s
}
TemplateFamily::OpenChat => {
let mut s = String::new();
for (role, content) in messages { s.push_str(&format!("{}: {}\n", role, content)); }
if let Some(inp) = input { s.push_str(&format!("user: {}\nassistant: ", inp)); } else { s.push_str("assistant: "); }
s
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders a lone `user` message with no system prompt and no pending input.
    fn render_single_user_turn(family: TemplateFamily, text: &str) -> String {
        let history = [(String::from("user"), String::from(text))];
        family.render(None, &history, None)
    }

    /// ChatML output carries the role marker, the message text and the end marker.
    #[test]
    fn test_chatml_render() {
        let rendered = render_single_user_turn(TemplateFamily::ChatML, "Hello");

        assert!(rendered.contains("<|im_start|>user"));
        assert!(rendered.contains("Hello"));
        assert!(rendered.contains("<|im_end|>"));
    }

    /// Llama3 output carries the role header, the message text and the end-of-turn marker.
    #[test]
    fn test_llama3_render() {
        let rendered = render_single_user_turn(TemplateFamily::Llama3, "Test");

        assert!(rendered.contains("<|start_header_id|>user<|end_header_id|>"));
        assert!(rendered.contains("Test"));
        assert!(rendered.contains("<|eot_id|>"));
    }

    /// OpenChat output carries the `role: content` line and the assistant cue.
    #[test]
    fn test_openchat_render() {
        let rendered = render_single_user_turn(TemplateFamily::OpenChat, "Hi");

        assert!(rendered.contains("user: Hi"));
        assert!(rendered.contains("assistant: "));
    }
}