use regex::Regex;
use std::sync::LazyLock;
// Regexes used to translate Go-template actions into brace placeholders.
// Each one is compiled lazily on first access through `LazyLock`; the patterns are fixed
// literals, so the `unwrap()` calls can only panic if a pattern here is edited into something
// invalid.

// Plain variable actions with optional inner whitespace and no trim markers,
// e.g. `{{ .System }}` or `{{.Prompt}}`.
static RE_SYSTEM: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\{\{\s*\.System\s*\}\}").unwrap());
static RE_PROMPT: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\{\{\s*\.Prompt\s*\}\}").unwrap());
static RE_RESPONSE: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\{\{\s*\.Response\s*\}\}").unwrap());
static RE_FIRST: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\{\{\s*\.First\s*\}\}").unwrap());
static RE_CONTENT: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\{\{\s*\.Content\s*\}\}").unwrap());
static RE_ROLE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\{\{\s*\.Role\s*\}\}").unwrap());
// Variable actions carrying Go trim markers. The `*1` patterns require a leading `-`
// (`{{- .X }}` / `{{- .X -}}`); the `*2` patterns require a trailing `-`
// (`{{ .X -}}` / `{{- .X -}}`). Together they cover every trimmed spelling.
static RE_TRIM_SYSTEM1: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\{\{-\s*\.System\s*-?\}\}").unwrap());
static RE_TRIM_SYSTEM2: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\{\{-?\s*\.System\s*-\}\}").unwrap());
static RE_TRIM_PROMPT1: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\{\{-\s*\.Prompt\s*-?\}\}").unwrap());
static RE_TRIM_PROMPT2: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\{\{-?\s*\.Prompt\s*-\}\}").unwrap());
static RE_TRIM_RESPONSE1: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\{\{-\s*\.Response\s*-?\}\}").unwrap());
static RE_TRIM_RESPONSE2: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\{\{-?\s*\.Response\s*-\}\}").unwrap());
// Control-flow actions (`if`, `else`, `end`, `range`), each with optional trim markers on
// either side.
static RE_IF_SYSTEM: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\{\{-?\s*if\s+\.System\s*-?\}\}").unwrap());
static RE_IF_FIRST: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\{\{-?\s*if\s+\.First\s*-?\}\}").unwrap());
static RE_IF_NOT_FIRST: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\{\{-?\s*if\s+not\s+\.First\s*-?\}\}").unwrap());
static RE_ELSE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\{\{-?\s*else\s*-?\}\}").unwrap());
static RE_END: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\{\{-?\s*end\s*-?\}\}").unwrap());
static RE_RANGE_MESSAGES: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\{\{-?\s*range\s+\.Messages\s*-?\}\}").unwrap());
// Catch-all for any `{{ ... }}` action whose body contains no `}`; used to strip whatever the
// more specific patterns above did not rewrite.
static RE_REMAINING: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\{\{-?[^}]*-?\}\}").unwrap());
// A run of three or more consecutive newlines.
static RE_MULTI_NEWLINE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\n{3,}").unwrap());
/// A prompt template in brace-placeholder form, together with the special tokens found in it.
///
/// The `template` text uses `{system}`, `{user}` / `{prompt}` and `{assistant}` placeholders and
/// the `{if_system}` ... `{end_if_system}` markers that `ChatTemplate::apply` understands.
#[derive(Debug, Clone)]
pub struct ChatTemplate {
/// Template text with brace placeholders.
pub template: String,
/// Beginning-of-sequence token found in the template text, if any.
pub bos_token: Option<String>,
/// End-of-sequence token found in the template text, if any.
pub eos_token: Option<String>,
/// Known end-of-turn markers that occur in the template text.
pub stop_sequences: Vec<String>,
}
impl ChatTemplate {
    /// Builds a [`ChatTemplate`] from an Ollama Go-template string by running it through
    /// `GoTemplateConverter`.
    pub fn from_ollama_template(go_template: &str) -> Self {
        GoTemplateConverter::new().convert(go_template)
    }

    /// Renders the template for one system/user exchange.
    ///
    /// * `system` - optional system prompt. When present, the `{if_system}` /
    ///   `{end_if_system}` markers are dropped and `{system}` is replaced by the text. When
    ///   absent, the first `{if_system}` ... `{end_if_system}` block is removed and any other
    ///   `{system}` placeholder renders as the empty string (an unset Go template field renders
    ///   empty too) instead of leaking the literal placeholder into the prompt.
    /// * `user` - text substituted for every `{user}` and `{prompt}` placeholder.
    /// * `assistant_prefix` - when `true`, `{assistant}` placeholders are removed and the rest
    ///   of the template is kept; when `false`, the output stops just before the first
    ///   `{assistant}` placeholder.
    ///
    /// All structural decisions (conditional markers, the `{assistant}` cut) are made on the raw
    /// template *before* any caller text is inserted, and the text is then substituted in a
    /// single pass. Caller-supplied text is therefore never re-scanned: a user message that
    /// happens to contain `{assistant}` no longer truncates the prompt, and a system prompt
    /// containing `{user}` is emitted verbatim.
    pub fn apply(&self, system: Option<&str>, user: &str, assistant_prefix: bool) -> String {
        // Step 1: resolve the system conditional on the template itself.
        let mut skeleton = if system.is_some() {
            self.template
                .replace("{if_system}", "")
                .replace("{end_if_system}", "")
        } else {
            remove_conditional_block(&self.template, "{if_system}", "{end_if_system}")
        };

        // Step 2: resolve the assistant slot, still on template text only.
        if assistant_prefix {
            skeleton = skeleton.replace("{assistant}", "");
        } else if let Some(cut) = skeleton.find("{assistant}") {
            skeleton.truncate(cut);
        }

        // Step 3: single left-to-right substitution of the content placeholders.
        let system_text = system.unwrap_or("");
        let mut rendered =
            String::with_capacity(skeleton.len() + system_text.len() + user.len());
        let mut rest = skeleton.as_str();
        while let Some(open) = rest.find('{') {
            rendered.push_str(&rest[..open]);
            let tail = &rest[open..];
            if let Some(after) = tail.strip_prefix("{system}") {
                rendered.push_str(system_text);
                rest = after;
            } else if let Some(after) = tail
                .strip_prefix("{user}")
                .or_else(|| tail.strip_prefix("{prompt}"))
            {
                rendered.push_str(user);
                rest = after;
            } else {
                // Not one of our placeholders: keep the brace and continue after it.
                rendered.push('{');
                rest = &tail[1..];
            }
        }
        rendered.push_str(rest);
        rendered
    }
}
/// Removes the first `start_marker` ... `end_marker` block (markers included) from `s`.
///
/// The end marker is searched for only *after* the start marker. Looking for it from the
/// beginning of the string would pair the start marker with an earlier, unrelated end marker
/// and produce overlapping slices, duplicating part of the text.
///
/// Returns `s` unchanged when there is no start marker or no end marker following it.
fn remove_conditional_block(s: &str, start_marker: &str, end_marker: &str) -> String {
    let Some(start) = s.find(start_marker) else {
        return s.to_string();
    };
    let body_start = start + start_marker.len();
    match s[body_start..].find(end_marker) {
        Some(offset) => {
            let resume = body_start + offset + end_marker.len();
            format!("{}{}", &s[..start], &s[resume..])
        }
        None => s.to_string(),
    }
}
struct GoTemplateConverter;
impl GoTemplateConverter {
fn new() -> Self {
Self
}
fn convert(&self, go_template: &str) -> ChatTemplate {
let mut template = go_template.to_string();
let mut stop_sequences = Vec::new();
template = self.convert_variables(&template);
template = self.convert_conditionals(&template);
template = self.convert_ranges(&template);
stop_sequences.extend(self.extract_stop_sequences(&template));
template = self.clean_whitespace_markers(&template);
let bos_token = self.detect_bos_token(&template);
let eos_token = self.detect_eos_token(&template);
ChatTemplate {
template,
bos_token,
eos_token,
stop_sequences,
}
}
fn convert_variables(&self, template: &str) -> String {
let mut result = template.to_string();
result = RE_SYSTEM.replace_all(&result, "{system}").to_string();
result = RE_PROMPT.replace_all(&result, "{user}").to_string();
result = RE_RESPONSE.replace_all(&result, "{assistant}").to_string();
result = RE_FIRST.replace_all(&result, "").to_string();
result = RE_CONTENT.replace_all(&result, "{content}").to_string();
result = RE_ROLE.replace_all(&result, "{role}").to_string();
result = RE_TRIM_SYSTEM1.replace_all(&result, "{system}").to_string();
result = RE_TRIM_SYSTEM2.replace_all(&result, "{system}").to_string();
result = RE_TRIM_PROMPT1.replace_all(&result, "{user}").to_string();
result = RE_TRIM_PROMPT2.replace_all(&result, "{user}").to_string();
result = RE_TRIM_RESPONSE1
.replace_all(&result, "{assistant}")
.to_string();
result = RE_TRIM_RESPONSE2
.replace_all(&result, "{assistant}")
.to_string();
result
}
fn convert_conditionals(&self, template: &str) -> String {
let mut result = template.to_string();
result = RE_IF_SYSTEM.replace_all(&result, "{if_system}").to_string();
result = RE_IF_FIRST.replace_all(&result, "{if_first}").to_string();
result = RE_IF_NOT_FIRST
.replace_all(&result, "{if_not_first}")
.to_string();
result = RE_ELSE.replace_all(&result, "{else}").to_string();
result = RE_END.replace_all(&result, "{end_if_system}").to_string();
result
}
fn convert_ranges(&self, template: &str) -> String {
let mut result = template.to_string();
result = RE_RANGE_MESSAGES
.replace_all(&result, "{foreach_message}")
.to_string();
result = result.replace(
"{end_if_system}{end_if_system}",
"{end_foreach}{end_if_system}",
);
result
}
fn clean_whitespace_markers(&self, template: &str) -> String {
let mut result = template.to_string();
result = RE_REMAINING.replace_all(&result, "").to_string();
result = RE_MULTI_NEWLINE.replace_all(&result, "\n\n").to_string();
result.trim().to_string()
}
fn extract_stop_sequences(&self, template: &str) -> Vec<String> {
let mut sequences = Vec::new();
let markers = [
"<|end|>",
"<|eot_id|>",
"<|im_end|>",
"<|start_header_id|>", "</s>",
"[/INST]",
"<</SYS>>",
"\n\nHuman:",
"\n\nAssistant:",
];
for marker in markers {
if template.contains(marker) {
sequences.push(marker.to_string());
}
}
sequences
}
fn detect_bos_token(&self, template: &str) -> Option<String> {
let bos_patterns = ["<s>", "<|begin_of_text|>", "<|startoftext|>"];
for pattern in bos_patterns {
if template.starts_with(pattern) || template.contains(pattern) {
return Some(pattern.to_string());
}
}
None
}
fn detect_eos_token(&self, template: &str) -> Option<String> {
let eos_patterns = ["</s>", "<|end_of_text|>", "<|endoftext|>", "<|eot_id|>"];
for pattern in eos_patterns {
if template.ends_with(pattern) || template.contains(pattern) {
return Some(pattern.to_string());
}
}
None
}
}
/// Renders `messages` through `template` as a single prompt.
///
/// * `messages` - `(role, content)` pairs. The first pair whose role is `"system"` supplies the
///   system prompt; the **last** pair whose role is `"user"` supplies the user text.
/// * `add_generation_prompt` - negated and passed as the `assistant_prefix` argument of
///   `ChatTemplate::apply`.
///
/// Only `"user"` messages are considered for the user slot. Previously an `"assistant"` message
/// at the end of the conversation was picked up instead and rendered as if the user had said it.
///
/// Returns an empty string when `messages` contains no user message.
pub fn format_chat(
    template: &ChatTemplate,
    messages: &[(String, String)],
    add_generation_prompt: bool,
) -> String {
    let system_msg = messages
        .iter()
        .find(|(role, _)| role == "system")
        .map(|(_, content)| content.as_str());

    messages
        .iter()
        .rev()
        .find(|(role, _)| role == "user")
        .map(|(_, content)| template.apply(system_msg, content, !add_generation_prompt))
        .unwrap_or_default()
}
// Unit tests for the Go-template conversion and for `ChatTemplate::apply`.
// The multi-line raw strings below are test fixtures; their exact bytes (including the line
// breaks) are part of the input.
#[cfg(test)]
mod tests {
use super::*;
// Plain `.System` / `.Prompt` / `.Response` actions must become `{system}` / `{user}` /
// `{assistant}` placeholders.
#[test]
fn test_simple_template_conversion() {
let go_template = r#"{{ .System }}
{{ .Prompt }}
{{ .Response }}"#;
let template = ChatTemplate::from_ollama_template(go_template);
assert!(template.template.contains("{system}"));
assert!(template.template.contains("{user}"));
assert!(template.template.contains("{assistant}"));
}
// An `if .System` block must yield the `{if_system}` marker, keep the `{system}` placeholder,
// and `<|end|>` must be reported as a stop sequence.
#[test]
fn test_conditional_template() {
let go_template = r#"{{- if .System }}<|system|>{{ .System }}<|end|>{{- end }}
<|user|>{{ .Prompt }}<|end|>
<|assistant|>"#;
let template = ChatTemplate::from_ollama_template(go_template);
assert!(template.template.contains("{if_system}"));
assert!(template.template.contains("{system}"));
assert!(template.stop_sequences.contains(&"<|end|>".to_string()));
}
// A Llama-3-style template must produce a BOS token and list `<|eot_id|>` among the stop
// sequences.
#[test]
fn test_llama3_style_template() {
let go_template = r#"<|begin_of_text|>{{- if .System }}<|start_header_id|>system<|end_header_id|>
{{ .System }}<|eot_id|>{{- end }}<|start_header_id|>user<|end_header_id|>
{{ .Prompt }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"#;
let template = ChatTemplate::from_ollama_template(go_template);
assert!(template.bos_token.is_some());
assert!(template.stop_sequences.contains(&"<|eot_id|>".to_string()));
}
// Applying a hand-built template with a system prompt must include both the system text and
// the user text in the output.
#[test]
fn test_apply_template() {
let template = ChatTemplate {
template: "{if_system}<|system|>{system}<|end|>{end_if_system}<|user|>{user}<|end|><|assistant|>{assistant}".to_string(),
bos_token: None,
eos_token: Some("<|end|>".to_string()),
stop_sequences: vec!["<|end|>".to_string()],
};
let result = template.apply(Some("You are helpful."), "Hello!", true);
assert!(result.contains("You are helpful."));
assert!(result.contains("Hello!"));
}
}