/// Alpaca-style prompt formatting: system text is emitted as-is, while user and
/// assistant turns go under `### Instruction:` / `### Response:` headers.
impl ChatTemplateEngine for AlpacaTemplate {
    /// Renders one message according to its role.
    ///
    /// Only `user` content is run through `sanitize_user_content`; every other
    /// role uses `content` verbatim. A role other than `system`, `user` or
    /// `assistant` yields the content with no header and no trailing blank line.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`.
    fn format_message(&self, role: &str, content: &str) -> Result<String, AprenderError> {
        let rendered = match role {
            "system" => format!("{content}\n\n"),
            "user" => {
                let cleaned = sanitize_user_content(content);
                format!("### Instruction:\n{cleaned}\n\n")
            }
            "assistant" => format!("### Response:\n{content}\n\n"),
            _ => content.to_string(),
        };
        Ok(rendered)
    }

    /// Renders a whole conversation and appends an open `### Response:` header.
    ///
    /// Messages whose role is not `system`, `user` or `assistant` are skipped.
    /// The returned prompt always ends with `### Response:\n`, even when
    /// `messages` is empty.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`.
    fn format_conversation(&self, messages: &[ChatMessage]) -> Result<String, AprenderError> {
        let mut prompt = String::new();
        for message in messages {
            let role = message.role.as_str();
            // Unrecognized roles contribute nothing to the conversation prompt.
            if !matches!(role, "system" | "user" | "assistant") {
                continue;
            }
            // For these three roles the per-message rendering is exactly the
            // section this conversation needs.
            prompt.push_str(&self.format_message(role, &message.content)?);
        }
        prompt.push_str("### Response:\n");
        Ok(prompt)
    }

    /// Borrows the special tokens stored on this template.
    fn special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

    /// Identifies this engine as `TemplateFormat::Alpaca`.
    fn format(&self) -> TemplateFormat {
        TemplateFormat::Alpaca
    }

    /// Reports that system prompts are accepted (always `true`).
    fn supports_system_prompt(&self) -> bool {
        true
    }
}
/// Template that applies no chat markup: its `ChatTemplateEngine`
/// implementation returns message content unchanged.
#[derive(Debug, Clone, Default)]
pub struct RawTemplate {
/// Token set returned by `special_tokens()`.
special_tokens: SpecialTokens,
}
impl RawTemplate {
    /// Creates a `RawTemplate` equal to its `Default` value.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }
}
/// Pass-through formatting: no headers, separators or role markers are added.
impl ChatTemplateEngine for RawTemplate {
    /// Returns a copy of `content`; the role is ignored.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`.
    fn format_message(&self, _role: &str, content: &str) -> Result<String, AprenderError> {
        Ok(String::from(content))
    }

    /// Concatenates the content of every message, in order, with nothing in
    /// between. An empty slice produces an empty string.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`.
    fn format_conversation(&self, messages: &[ChatMessage]) -> Result<String, AprenderError> {
        let mut joined = String::new();
        for message in messages {
            joined.push_str(message.content.as_str());
        }
        Ok(joined)
    }

    /// Borrows the special tokens stored on this template.
    fn special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

    /// Identifies this engine as `TemplateFormat::Raw`.
    fn format(&self) -> TemplateFormat {
        TemplateFormat::Raw
    }

    /// Reports that system prompts are accepted (always `true`).
    fn supports_system_prompt(&self) -> bool {
        true
    }
}
/// Guesses a template format from a model name.
///
/// The name is lowercased and searched for known substrings. Groups are tried
/// in a fixed order and the first one that matches wins: ChatML, Mistral,
/// Llama2, Phi, Alpaca. A name that matches none of them yields
/// `TemplateFormat::Raw`.
#[must_use]
pub fn detect_format_from_name(model_name: &str) -> TemplateFormat {
    let lowered = model_name.to_lowercase();
    let mentions_any =
        |needles: &[&str]| needles.iter().any(|needle| lowered.contains(*needle));

    if mentions_any(&["qwen", "openhermes", "yi-"]) {
        TemplateFormat::ChatML
    } else if mentions_any(&["mistral", "mixtral"]) {
        // Tried ahead of the Llama2 group, so a name matching both resolves to Mistral.
        TemplateFormat::Mistral
    } else if mentions_any(&["llama", "vicuna", "tinyllama"]) {
        // "tinyllama" is already covered by "llama"; it is kept for explicitness.
        TemplateFormat::Llama2
    } else if mentions_any(&["phi-", "phi2", "phi3"]) {
        // A bare "phi" is not in the list, so e.g. "dolphin" does not match.
        TemplateFormat::Phi
    } else if mentions_any(&["alpaca"]) {
        TemplateFormat::Alpaca
    } else {
        TemplateFormat::Raw
    }
}
/// Guesses a template format from which special tokens are present.
///
/// Either of `im_start_token` / `im_end_token` being set selects ChatML.
/// Otherwise either of `inst_start` / `inst_end` being set selects Llama2.
/// With none of the four set the result is `TemplateFormat::Raw`.
#[must_use]
pub fn detect_format_from_tokens(special_tokens: &SpecialTokens) -> TemplateFormat {
    let has_im_marker =
        special_tokens.im_start_token.is_some() || special_tokens.im_end_token.is_some();
    let has_inst_marker =
        special_tokens.inst_start.is_some() || special_tokens.inst_end.is_some();

    // The ChatML markers take priority when both kinds are present.
    match (has_im_marker, has_inst_marker) {
        (true, _) => TemplateFormat::ChatML,
        (false, true) => TemplateFormat::Llama2,
        (false, false) => TemplateFormat::Raw,
    }
}
/// Builds a boxed template engine for the requested format.
///
/// `Custom` and `Raw` both map to a `RawTemplate`; every other variant maps to
/// the template type of the same name, constructed with its `new()`.
#[must_use]
pub fn create_template(format: TemplateFormat) -> Box<dyn ChatTemplateEngine + Send + Sync> {
    match format {
        // No dedicated engine is built for `Custom` here, so it shares the
        // pass-through engine with `Raw`.
        TemplateFormat::Raw | TemplateFormat::Custom => Box::new(RawTemplate::new()),
        TemplateFormat::Alpaca => Box::new(AlpacaTemplate::new()),
        TemplateFormat::Phi => Box::new(PhiTemplate::new()),
        TemplateFormat::Mistral => Box::new(MistralTemplate::new()),
        TemplateFormat::Llama2 => Box::new(Llama2Template::new()),
        TemplateFormat::ChatML => Box::new(ChatMLTemplate::new()),
    }
}
/// Picks a template engine for a model based solely on its name.
///
/// The name is classified with `detect_format_from_name` and the resulting
/// format is handed to `create_template`.
#[must_use]
pub fn auto_detect_template(model_name: &str) -> Box<dyn ChatTemplateEngine + Send + Sync> {
    create_template(detect_format_from_name(model_name))
}
// Unit tests live in the out-of-line `tests` module, which is compiled only
// when the `test` cfg is active.
#[cfg(test)]
mod tests;