use crate::schema::Message;
use regex::Regex;
use std::collections::HashMap;
/// A chat prompt made of an ordered list of message templates.
///
/// Message contents may contain `{name}` placeholders (word characters only), which are
/// filled in when the template is formatted.
pub struct ChatPromptTemplate {
    /// The template messages, in the order they are rendered.
    messages: Vec<Message>,
}
impl ChatPromptTemplate {
pub fn new(messages: Vec<Message>) -> Self {
Self { messages }
}
pub fn format(&self, variables: &HashMap<&str, &str>) -> Result<Vec<Message>, String> {
let re = Regex::new(r"\{(\w+)\}").unwrap();
self.messages
.iter()
.map(|msg| {
let mut content = msg.content.clone();
for cap in re.captures_iter(&msg.content) {
let var_name = cap.get(1).unwrap().as_str();
if let Some(value) = variables.get(var_name) {
content = content.replace(&format!("{{{}}}", var_name), value);
} else {
return Err(format!("Missing variable: {} in message", var_name));
}
}
Ok(Message {
content,
message_type: msg.message_type.clone(),
name: msg.name.clone(),
additional_kwargs: msg.additional_kwargs.clone(),
id: msg.id.clone(),
tool_calls: msg.tool_calls.clone(),
})
})
.collect()
}
pub fn variables(&self) -> Vec<String> {
let re = Regex::new(r"\{(\w+)\}").unwrap();
let mut vars = std::collections::HashSet::new();
for msg in &self.messages {
for cap in re.captures_iter(&msg.content) {
vars.insert(cap.get(1).unwrap().as_str().to_string());
}
}
vars.into_iter().collect()
}
pub fn messages(&self) -> &[Message] {
&self.messages
}
pub fn from_messages(messages: impl Into<Vec<Message>>) -> Self {
Self::new(messages.into())
}
}
impl std::fmt::Display for ChatPromptTemplate {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
for msg in &self.messages {
let role = match msg.message_type {
crate::schema::MessageType::System => "System",
crate::schema::MessageType::Human => "Human",
crate::schema::MessageType::AI => "AI",
crate::schema::MessageType::Tool { .. } => "Tool",
};
writeln!(f, "{}: {}", role, msg.content)?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Placeholders in multiple messages are each filled from the variable map.
    #[test]
    fn test_basic_chat_template() {
        let template = ChatPromptTemplate::new(vec![
            Message::system("你是一个{role}助手。"),
            Message::human("你好,我是{name}。"),
        ]);
        let vars = HashMap::from([("role", "编程"), ("name", "小明")]);
        let messages = template.format(&vars).unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content, "你是一个编程助手。");
        assert_eq!(messages[1].content, "你好,我是小明。");
    }

    /// A placeholder used several times is replaced at every occurrence.
    #[test]
    fn test_repeated_placeholder() {
        let template = ChatPromptTemplate::new(vec![Message::human("{name}和{name}")]);
        let vars = HashMap::from([("name", "小明")]);
        let messages = template.format(&vars).unwrap();
        assert_eq!(messages[0].content, "小明和小明");
    }

    /// Text without placeholders passes through, and unused variables are ignored.
    #[test]
    fn test_no_placeholders_and_extra_variables() {
        let template = ChatPromptTemplate::new(vec![Message::system("系统消息")]);
        let vars = HashMap::from([("unused", "值")]);
        let messages = template.format(&vars).unwrap();
        assert_eq!(messages[0].content, "系统消息");
    }

    /// `from_messages` accepts an array and keeps every message.
    #[test]
    fn test_from_messages() {
        let template = ChatPromptTemplate::from_messages([
            Message::system("系统消息"),
            Message::human("用户消息"),
        ]);
        assert_eq!(template.messages().len(), 2);
    }

    /// Every distinct placeholder name across all messages is reported exactly once.
    #[test]
    fn test_get_variables() {
        let template = ChatPromptTemplate::new(vec![
            Message::system("你是一个{role},专精于{domain}。"),
            Message::human("我是{name},请问{question}"),
        ]);
        let vars = template.variables();
        assert_eq!(vars.len(), 4);
        for expected in ["role", "domain", "name", "question"] {
            assert!(vars.iter().any(|v| v == expected), "missing {}", expected);
        }
    }

    /// Duplicate placeholder names, within or across messages, are deduplicated.
    #[test]
    fn test_get_variables_deduplicates() {
        let template = ChatPromptTemplate::new(vec![
            Message::system("{name}{name}"),
            Message::human("{name}"),
        ]);
        assert_eq!(template.variables(), vec!["name".to_string()]);
    }

    /// A missing variable is an error that names the missing placeholder.
    #[test]
    fn test_missing_variable() {
        let template = ChatPromptTemplate::new(vec![Message::human("你好,{name}!今天是{day}。")]);
        let vars = HashMap::from([("name", "小明")]);
        let result = template.format(&vars);
        assert!(result.is_err());
        assert_eq!(
            result.err().as_deref(),
            Some("Missing variable: day in message")
        );
    }
}