use std::collections::HashMap;
use crate::error::SynwireError;
use crate::messages::Message;
use crate::prompts::PromptValue;
/// One entry of a [`ChatPromptTemplate`], turned into zero or more messages
/// by `ChatPromptTemplate::format_messages`.
///
/// The `System`, `Human` and `AI` variants carry template text in which
/// `{name}` spans are replaced with caller-supplied variables.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum MessageTemplate {
/// Template text rendered into a single system message.
System(String),
/// Template text rendered into a single human message.
Human(String),
/// Template text rendered into a single AI message.
AI(String),
/// Name of a variable whose value is expanded in place: a JSON array of
/// objects with string `role` and `content` fields yields one message per
/// recognised item, any other value becomes one human message, and a
/// variable that is not supplied contributes no messages.
Placeholder(String),
}
/// An ordered sequence of [`MessageTemplate`] entries together with the
/// variable names referenced by their template text.
#[derive(Debug, Clone)]
pub struct ChatPromptTemplate {
/// Entries in the order they are rendered by `format_messages`.
messages: Vec<MessageTemplate>,
/// Distinct `{name}` variables found in the system/human/AI templates, in
/// order of first appearance; placeholder names are not included.
input_variables: Vec<String>,
}
/// Lists the variable names referenced by `template`, in order of appearance.
///
/// A variable is the text between a `{` and the next `}` that follows it.
/// Empty names (`{}`) are dropped, repeated names are reported once per
/// occurrence, and scanning stops at a `{` that is never closed.
fn extract_variables(template: &str) -> Vec<String> {
    let mut names = Vec::new();
    let mut cursor = template;
    // Each step peels off "…{name}" and continues after the closing brace;
    // the chain yields `None` when either brace is missing.
    while let Some((name, tail)) = cursor
        .split_once('{')
        .and_then(|(_, after_open)| after_open.split_once('}'))
    {
        if !name.is_empty() {
            names.push(String::from(name));
        }
        cursor = tail;
    }
    names
}
/// Appends the messages described by a placeholder `value` to `out`.
///
/// When `value` parses as a JSON array, each element that has a string
/// `role` and a string `content` is converted to a message: `system`,
/// `human`/`user`, or `ai`/`assistant`. Elements with missing fields or any
/// other role are skipped. Any value that is not a JSON array is appended
/// verbatim as a single human message.
fn expand_placeholder(value: &str, out: &mut Vec<Message>) {
    let Ok(serde_json::Value::Array(entries)) = serde_json::from_str::<serde_json::Value>(value)
    else {
        // Not a JSON array (including invalid JSON): treat it as plain text.
        out.push(Message::human(value));
        return;
    };

    out.extend(entries.iter().filter_map(|entry| {
        let role = entry.get("role")?.as_str()?;
        let content = entry.get("content")?.as_str()?;
        match role {
            "system" => Some(Message::system(content)),
            "human" | "user" => Some(Message::human(content)),
            "ai" | "assistant" => Some(Message::ai(content)),
            _ => None,
        }
    }));
}
/// Replaces every `{name}` span in `template` with the matching entry of
/// `variables`.
///
/// The template is scanned once, left to right, using the same brace rules as
/// `extract_variables`: a variable is the text between a `{` and the next `}`.
/// Substituted values are copied verbatim and never rescanned, so a value that
/// itself contains `{other}` cannot trigger a second substitution.
///
/// An empty `{}` and an unterminated `{` are copied to the output unchanged.
///
/// # Errors
///
/// Returns `SynwireError::Prompt` naming the first variable (in template
/// order) that has no entry in `variables`.
fn substitute(template: &str, variables: &HashMap<String, String>) -> Result<String, SynwireError> {
    let mut rendered = String::with_capacity(template.len());
    let mut rest = template;

    while let Some((literal, after_open)) = rest.split_once('{') {
        rendered.push_str(literal);

        let Some((name, after_close)) = after_open.split_once('}') else {
            // Unterminated brace: keep it and everything after it as literal text.
            rendered.push('{');
            rest = after_open;
            break;
        };

        if name.is_empty() {
            // `{}` names no variable, so it stays in the output untouched.
            rendered.push_str("{}");
        } else {
            let value = variables.get(name).ok_or_else(|| SynwireError::Prompt {
                message: format!("missing required variable '{name}'"),
            })?;
            rendered.push_str(value);
        }
        rest = after_close;
    }

    rendered.push_str(rest);
    Ok(rendered)
}
impl ChatPromptTemplate {
pub fn from_messages(messages: Vec<MessageTemplate>) -> Self {
let mut seen = std::collections::HashSet::new();
let mut input_variables = Vec::new();
for msg in &messages {
let tpl = match msg {
MessageTemplate::System(t) | MessageTemplate::Human(t) | MessageTemplate::AI(t) => {
t.as_str()
}
MessageTemplate::Placeholder(_) => continue,
};
for var in extract_variables(tpl) {
if seen.insert(var.clone()) {
input_variables.push(var);
}
}
}
Self {
messages,
input_variables,
}
}
pub fn input_variables(&self) -> &[String] {
&self.input_variables
}
pub fn format_messages(
&self,
variables: &HashMap<String, String>,
) -> Result<Vec<Message>, SynwireError> {
let mut result = Vec::with_capacity(self.messages.len());
for msg in &self.messages {
match msg {
MessageTemplate::System(tpl) => {
let text = substitute(tpl, variables)?;
result.push(Message::system(text));
}
MessageTemplate::Human(tpl) => {
let text = substitute(tpl, variables)?;
result.push(Message::human(text));
}
MessageTemplate::AI(tpl) => {
let text = substitute(tpl, variables)?;
result.push(Message::ai(text));
}
MessageTemplate::Placeholder(name) => {
if let Some(value) = variables.get(name.as_str()) {
expand_placeholder(value, &mut result);
}
}
}
}
Ok(result)
}
pub fn to_prompt_value(
&self,
variables: &HashMap<String, String>,
) -> Result<PromptValue, SynwireError> {
let messages = self.format_messages(variables)?;
Ok(PromptValue::Messages(messages))
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds an owned variable map from borrowed key/value pairs.
    fn vars_from(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_owned(), value.to_owned()))
            .collect()
    }

    /// A template consisting of a single placeholder named `name`.
    fn placeholder_only(name: &str) -> ChatPromptTemplate {
        ChatPromptTemplate::from_messages(vec![MessageTemplate::Placeholder(name.into())])
    }

    /// System and human templates are rendered with their variables filled in.
    #[test]
    fn test_chat_prompt_template_format_messages() {
        let tpl = ChatPromptTemplate::from_messages(vec![
            MessageTemplate::System("You are {role}".into()),
            MessageTemplate::Human("{question}".into()),
        ]);
        let vars = vars_from(&[
            ("role", "a helpful assistant"),
            ("question", "What is Rust?"),
        ]);

        let messages = tpl.format_messages(&vars).unwrap();

        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].message_type(), "system");
        assert_eq!(
            messages[0].content().as_text(),
            "You are a helpful assistant"
        );
        assert_eq!(messages[1].message_type(), "human");
        assert_eq!(messages[1].content().as_text(), "What is Rust?");
    }

    /// `to_prompt_value` yields the same messages as `format_messages`.
    #[test]
    fn test_chat_prompt_template_to_prompt_value() {
        let tpl =
            ChatPromptTemplate::from_messages(vec![MessageTemplate::Human("Hello {name}".into())]);
        let vars = vars_from(&[("name", "World")]);

        let pv = tpl.to_prompt_value(&vars).unwrap();
        let messages = pv.to_messages();

        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].content().as_text(), "Hello World");
    }

    /// A template variable with no supplied value is reported by name.
    #[test]
    fn test_chat_prompt_template_missing_variable() {
        let tpl =
            ChatPromptTemplate::from_messages(vec![MessageTemplate::Human("{question}".into())]);
        let no_vars: HashMap<String, String> = HashMap::new();

        let msg = tpl.format_messages(&no_vars).unwrap_err().to_string();

        assert!(
            msg.contains("question"),
            "error should mention the missing variable, got: {msg}"
        );
    }

    /// Variables are returned in the order they appear in the text.
    #[test]
    fn test_extract_variables() {
        let vars = extract_variables("Hello {name}, you are {age} years old");
        assert_eq!(vars, vec!["name", "age"]);
    }

    /// Text without braces yields no variables.
    #[test]
    fn test_extract_variables_empty() {
        assert!(extract_variables("No variables here").is_empty());
    }

    /// `from_messages` gathers variables across all message templates.
    #[test]
    fn test_input_variables_auto_extracted() {
        let tpl = ChatPromptTemplate::from_messages(vec![
            MessageTemplate::System("You are {role}".into()),
            MessageTemplate::Human("{question} about {topic}".into()),
        ]);
        assert_eq!(tpl.input_variables(), &["role", "question", "topic"]);
    }

    /// A placeholder whose variable is absent contributes no messages.
    #[test]
    fn test_placeholder_missing_variable_skipped() {
        let tpl = ChatPromptTemplate::from_messages(vec![
            MessageTemplate::System("Hello".into()),
            MessageTemplate::Placeholder("history".into()),
            MessageTemplate::Human("{question}".into()),
        ]);
        let vars = vars_from(&[("question", "Hi")]);

        let messages = tpl.format_messages(&vars).unwrap();

        assert_eq!(messages.len(), 2);
    }

    /// A JSON array placeholder expands to one message per array element.
    #[test]
    fn test_placeholder_json_array_expansion() {
        let tpl = ChatPromptTemplate::from_messages(vec![
            MessageTemplate::System("You are helpful.".into()),
            MessageTemplate::Placeholder("history".into()),
            MessageTemplate::Human("{question}".into()),
        ]);
        let history = serde_json::json!([
            {"role": "human", "content": "What is 2+2?"},
            {"role": "ai", "content": "4"},
        ])
        .to_string();
        let vars = vars_from(&[("history", history.as_str()), ("question", "And 3+3?")]);

        let messages = tpl.format_messages(&vars).unwrap();

        assert_eq!(messages.len(), 4);
        assert_eq!(messages[0].message_type(), "system");
        assert_eq!(messages[1].message_type(), "human");
        assert_eq!(messages[1].content().as_text(), "What is 2+2?");
        assert_eq!(messages[2].message_type(), "ai");
        assert_eq!(messages[2].content().as_text(), "4");
        assert_eq!(messages[3].message_type(), "human");
        assert_eq!(messages[3].content().as_text(), "And 3+3?");
    }

    /// A placeholder value that is not a JSON array becomes a human message.
    #[test]
    fn test_placeholder_plain_string_becomes_human_message() {
        let tpl = placeholder_only("input");
        let vars = vars_from(&[("input", "Tell me a joke")]);

        let messages = tpl.format_messages(&vars).unwrap();

        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].message_type(), "human");
        assert_eq!(messages[0].content().as_text(), "Tell me a joke");
    }

    /// `user` and `assistant` are accepted as aliases for `human` and `ai`.
    #[test]
    fn test_placeholder_recognises_user_and_assistant_roles() {
        let tpl = placeholder_only("history");
        let history = serde_json::json!([
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there"},
            {"role": "system", "content": "Be concise"},
        ])
        .to_string();
        let vars = vars_from(&[("history", history.as_str())]);

        let messages = tpl.format_messages(&vars).unwrap();

        assert_eq!(messages.len(), 3);
        for (message, expected) in messages.iter().zip(["human", "ai", "system"]) {
            assert_eq!(message.message_type(), expected);
        }
    }

    /// Array items with an unrecognised role are dropped from the output.
    #[test]
    fn test_placeholder_skips_items_with_unknown_role() {
        let tpl = placeholder_only("history");
        let history = serde_json::json!([
            {"role": "human", "content": "Hi"},
            {"role": "tool", "content": "result"},
            {"role": "ai", "content": "Done"},
        ])
        .to_string();
        let vars = vars_from(&[("history", history.as_str())]);

        let messages = tpl.format_messages(&vars).unwrap();

        assert_eq!(messages.len(), 2);
        for (message, expected) in messages.iter().zip(["human", "ai"]) {
            assert_eq!(message.message_type(), expected);
        }
    }
}