impl Default for ChatMLTemplate {
fn default() -> Self {
Self::new()
}
}
/// ChatML rendering: each turn is emitted as
/// `<|im_start|>{role}\n{content}<|im_end|>\n`.
impl ChatTemplateEngine for ChatMLTemplate {
    /// Renders one message as a single ChatML turn.
    ///
    /// The content is passed through `sanitize_user_content` only when `role`
    /// is exactly `"user"`; for every other role it is emitted verbatim.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`.
    fn format_message(&self, role: &str, content: &str) -> Result<String, AprenderError> {
        let body = match role {
            "user" => sanitize_user_content(content),
            _ => content.to_owned(),
        };
        Ok(["<|im_start|>", role, "\n", body.as_str(), "<|im_end|>\n"].concat())
    }

    /// Renders every message as a ChatML turn, in order, and then appends
    /// `<|im_start|>assistant\n` so the prompt ends with an open assistant turn.
    ///
    /// As in `format_message`, only `"user"` content is sanitized.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`.
    fn format_conversation(&self, messages: &[ChatMessage]) -> Result<String, AprenderError> {
        let mut prompt = String::new();
        for message in messages {
            let role = message.role.as_str();
            prompt.push_str("<|im_start|>");
            prompt.push_str(role);
            prompt.push('\n');
            if role == "user" {
                prompt.push_str(&sanitize_user_content(&message.content));
            } else {
                prompt.push_str(&message.content);
            }
            prompt.push_str("<|im_end|>\n");
        }
        prompt.push_str("<|im_start|>assistant\n");
        Ok(prompt)
    }

    /// Borrows the special tokens stored on this template.
    fn special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

    /// Identifies this engine as `TemplateFormat::ChatML`.
    fn format(&self) -> TemplateFormat {
        TemplateFormat::ChatML
    }

    /// Reports `true`: system messages are rendered like any other role.
    fn supports_system_prompt(&self) -> bool {
        true
    }
}
#[derive(Debug, Clone)]
pub struct Llama2Template {
special_tokens: SpecialTokens,
}
impl Llama2Template {
#[must_use]
pub fn new() -> Self {
Self {
special_tokens: SpecialTokens {
bos_token: Some("<s>".to_string()),
eos_token: Some("</s>".to_string()),
inst_start: Some("[INST]".to_string()),
inst_end: Some("[/INST]".to_string()),
sys_start: Some("<<SYS>>".to_string()),
sys_end: Some("<</SYS>>".to_string()),
..Default::default()
},
}
}
}
impl Default for Llama2Template {
fn default() -> Self {
Self::new()
}
}
/// Llama 2 chat rendering built from `<s>`/`</s>`, `[INST]`/`[/INST]` and
/// `<<SYS>>`/`<</SYS>>`.
impl ChatTemplateEngine for Llama2Template {
    /// Renders a single message in isolation.
    ///
    /// * `"system"` becomes `<<SYS>>\n{content}\n<</SYS>>\n\n`
    /// * `"user"` becomes `[INST] {content} [/INST]`, with the content first
    ///   passed through `sanitize_user_content`
    /// * `"assistant"` becomes ` {content}</s>`
    /// * any other role yields the content unchanged
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`.
    fn format_message(&self, role: &str, content: &str) -> Result<String, AprenderError> {
        let rendered = match role {
            "system" => format!("<<SYS>>\n{content}\n<</SYS>>\n\n"),
            "user" => format!("[INST] {} [/INST]", sanitize_user_content(content)),
            "assistant" => format!(" {content}</s>"),
            _ => content.to_string(),
        };
        Ok(rendered)
    }

    /// Renders a whole conversation as one prompt string.
    ///
    /// The output always begins with exactly one `<s>`. Each user message is
    /// sanitized and wrapped in `[INST] … [/INST]`; the most recent system
    /// message seen so far is embedded inside the next user block as
    /// `<<SYS>>\n…\n<</SYS>>\n\n` and then consumed. Each assistant message is
    /// appended as ` {content}</s>`, and the first user turn after an
    /// assistant reply opens a new sequence with another `<s>`.
    ///
    /// Messages with any other role are ignored, and a system message that is
    /// never followed by a user message does not appear in the output.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`.
    fn format_conversation(&self, messages: &[ChatMessage]) -> Result<String, AprenderError> {
        let mut result = String::from("<s>");
        let mut system_prompt: Option<String> = None;
        // Set once an assistant reply has closed the current sequence with
        // `</s>`. Keying the extra `<s>` on this (rather than on the message
        // index) avoids a doubled `<s><s>` when the conversation starts with a
        // system message or an ignored role.
        let mut needs_bos = false;

        for msg in messages {
            match msg.role.as_str() {
                "system" => {
                    // Held back until the next user turn, which embeds it.
                    system_prompt = Some(msg.content.clone());
                }
                "user" => {
                    let safe_content = sanitize_user_content(&msg.content);
                    if needs_bos {
                        result.push_str("<s>");
                        needs_bos = false;
                    }
                    result.push_str("[INST] ");
                    if let Some(sys) = system_prompt.take() {
                        result.push_str("<<SYS>>\n");
                        result.push_str(&sys);
                        result.push_str("\n<</SYS>>\n\n");
                    }
                    result.push_str(&safe_content);
                    result.push_str(" [/INST]");
                }
                "assistant" => {
                    result.push(' ');
                    result.push_str(&msg.content);
                    result.push_str("</s>");
                    needs_bos = true;
                }
                _ => {}
            }
        }
        Ok(result)
    }

    /// Borrows the special tokens stored on this template.
    fn special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

    /// Identifies this engine as `TemplateFormat::Llama2`.
    fn format(&self) -> TemplateFormat {
        TemplateFormat::Llama2
    }

    /// Reports `true`: system messages are rendered inside `<<SYS>>` markers.
    fn supports_system_prompt(&self) -> bool {
        true
    }
}
#[derive(Debug, Clone)]
pub struct MistralTemplate {
special_tokens: SpecialTokens,
}
impl MistralTemplate {
#[must_use]
pub fn new() -> Self {
Self {
special_tokens: SpecialTokens {
bos_token: Some("<s>".to_string()),
eos_token: Some("</s>".to_string()),
inst_start: Some("[INST]".to_string()),
inst_end: Some("[/INST]".to_string()),
..Default::default()
},
}
}
}
impl Default for MistralTemplate {
fn default() -> Self {
Self::new()
}
}
/// Mistral chat rendering built from `<s>`, `[INST]`/`[/INST]` and `</s>`.
impl ChatTemplateEngine for MistralTemplate {
    /// Renders a single message in isolation.
    ///
    /// * `"user"` becomes `[INST] {content} [/INST]`, with the content first
    ///   passed through `sanitize_user_content`
    /// * `"assistant"` becomes ` {content}</s>`
    /// * `"system"` becomes the content followed by a blank line (`\n\n`)
    /// * any other role yields the content unchanged
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`.
    fn format_message(&self, role: &str, content: &str) -> Result<String, AprenderError> {
        let rendered = match role {
            "user" => format!("[INST] {} [/INST]", sanitize_user_content(content)),
            "assistant" => format!(" {content}</s>"),
            "system" => format!("{content}\n\n"),
            _ => content.to_owned(),
        };
        Ok(rendered)
    }

    /// Renders a conversation as `<s>` followed by its user and assistant turns.
    ///
    /// User messages are sanitized and wrapped in `[INST] … [/INST]`;
    /// assistant messages are appended as ` {content}</s>`. Messages with any
    /// other role, including `"system"`, are skipped.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`.
    fn format_conversation(&self, messages: &[ChatMessage]) -> Result<String, AprenderError> {
        let mut prompt = String::from("<s>");
        for message in messages {
            match message.role.as_str() {
                "user" => {
                    let cleaned = sanitize_user_content(&message.content);
                    prompt.extend(["[INST] ", cleaned.as_str(), " [/INST]"]);
                }
                "assistant" => {
                    prompt.extend([" ", message.content.as_str(), "</s>"]);
                }
                _ => {}
            }
        }
        Ok(prompt)
    }

    /// Borrows the special tokens stored on this template.
    fn special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

    /// Identifies this engine as `TemplateFormat::Mistral`.
    fn format(&self) -> TemplateFormat {
        TemplateFormat::Mistral
    }

    /// Reports `false`: `format_conversation` skips system messages.
    fn supports_system_prompt(&self) -> bool {
        false
    }
}
/// Chat template that renders turns with plain `Instruct:` / `Output:` labels
/// and carries default `SpecialTokens`.
#[derive(Debug, Clone)]
pub struct PhiTemplate {
    /// Left at `SpecialTokens::default()` by `new()`.
    special_tokens: SpecialTokens,
}

impl PhiTemplate {
    /// Creates a template holding default special tokens.
    #[must_use]
    pub fn new() -> Self {
        let special_tokens = SpecialTokens::default();
        Self { special_tokens }
    }
}

/// The default `PhiTemplate` is identical to `PhiTemplate::new()`.
impl Default for PhiTemplate {
    fn default() -> Self {
        PhiTemplate::new()
    }
}
/// Phi chat rendering: line-oriented turns labelled `Instruct:` and `Output:`.
impl ChatTemplateEngine for PhiTemplate {
    /// Renders a single message as one line.
    ///
    /// * `"user"` becomes `Instruct: {content}\n`, with the content first
    ///   passed through `sanitize_user_content`
    /// * `"assistant"` becomes `Output: {content}\n`
    /// * `"system"` becomes `{content}\n`
    /// * any other role yields the content unchanged (no trailing newline)
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`.
    fn format_message(&self, role: &str, content: &str) -> Result<String, AprenderError> {
        let rendered = match role {
            "user" => format!("Instruct: {}\n", sanitize_user_content(content)),
            "assistant" => format!("Output: {content}\n"),
            "system" => format!("{content}\n"),
            _ => content.to_owned(),
        };
        Ok(rendered)
    }

    /// Renders every system, user and assistant message as one labelled line,
    /// then appends a bare `Output:` so the prompt ends where a reply begins.
    ///
    /// Only user content is sanitized. Messages with any other role are skipped.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`.
    fn format_conversation(&self, messages: &[ChatMessage]) -> Result<String, AprenderError> {
        use std::borrow::Cow;

        let mut prompt = String::new();
        for message in messages {
            // Each rendered line has the same shape: label, body, newline.
            let (label, body): (&str, Cow<'_, str>) = match message.role.as_str() {
                "system" => ("", Cow::Borrowed(message.content.as_str())),
                "user" => (
                    "Instruct: ",
                    Cow::Owned(sanitize_user_content(&message.content)),
                ),
                "assistant" => ("Output: ", Cow::Borrowed(message.content.as_str())),
                _ => continue,
            };
            prompt.push_str(label);
            prompt.push_str(&body);
            prompt.push('\n');
        }
        prompt.push_str("Output:");
        Ok(prompt)
    }

    /// Borrows the special tokens stored on this template.
    fn special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

    /// Identifies this engine as `TemplateFormat::Phi`.
    fn format(&self) -> TemplateFormat {
        TemplateFormat::Phi
    }

    /// Reports `true`: system messages are emitted as an unlabelled line.
    fn supports_system_prompt(&self) -> bool {
        true
    }
}
/// Chat template type for the Alpaca format; it carries default `SpecialTokens`.
#[derive(Debug, Clone)]
pub struct AlpacaTemplate {
    /// Left at `SpecialTokens::default()` by `new()`.
    special_tokens: SpecialTokens,
}

impl AlpacaTemplate {
    /// Creates a template holding default special tokens.
    #[must_use]
    pub fn new() -> Self {
        let special_tokens = SpecialTokens::default();
        Self { special_tokens }
    }
}

/// The default `AlpacaTemplate` is identical to `AlpacaTemplate::new()`.
impl Default for AlpacaTemplate {
    fn default() -> Self {
        AlpacaTemplate::new()
    }
}