use regex::Regex;
use std::collections::HashMap;
/// A text template containing `{name}`-style placeholders.
///
/// The template text is stored exactly as supplied; placeholders are looked
/// up only when the template is formatted or its variables are listed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptTemplate {
    /// The raw, unformatted template text.
    template: String,
}
impl PromptTemplate {
pub fn new(template: impl Into<String>) -> Self {
Self {
template: template.into(),
}
}
pub fn format(&self, variables: &HashMap<&str, &str>) -> Result<String, String> {
let mut result = self.template.clone();
let re = Regex::new(r"\{(\w+)\}").unwrap();
for cap in re.captures_iter(&self.template) {
let var_name = cap.get(1).unwrap().as_str();
if let Some(value) = variables.get(var_name) {
result = result.replace(&format!("{{{}}}", var_name), value);
} else {
return Err(format!("Missing variable: {}", var_name));
}
}
Ok(result)
}
pub fn variables(&self) -> Vec<String> {
let re = Regex::new(r"\{(\w+)\}").unwrap();
re.captures_iter(&self.template)
.map(|cap| cap.get(1).unwrap().as_str().to_string())
.collect()
}
pub fn template(&self) -> &str {
&self.template
}
}
/// Displays the raw, unformatted template text.
impl std::fmt::Display for PromptTemplate {
    /// Writes the template text verbatim; formatter flags such as width or
    /// fill are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.template)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a variable map from literal `(name, value)` pairs.
    fn bindings<'a>(pairs: &[(&'a str, &'a str)]) -> HashMap<&'a str, &'a str> {
        pairs.iter().cloned().collect()
    }

    /// A single placeholder is replaced by its value.
    #[test]
    fn test_basic_template() {
        let prompt = PromptTemplate::new("你好,{name}!");
        let values = bindings(&[("name", "小明")]);

        assert_eq!(prompt.format(&values).unwrap(), "你好,小明!");
    }

    /// Several distinct placeholders are each replaced by their own value.
    #[test]
    fn test_multiple_variables() {
        let prompt = PromptTemplate::new("{greeting},{name}!今天是{day}。");
        let values = bindings(&[
            ("greeting", "早上好"),
            ("name", "小红"),
            ("day", "星期一"),
        ]);

        let rendered = prompt.format(&values).unwrap();
        assert_eq!(rendered, "早上好,小红!今天是星期一。");
    }

    /// Formatting fails, naming the placeholder, when a value is absent.
    #[test]
    fn test_missing_variable() {
        let prompt = PromptTemplate::new("你好,{name}!今天是{day}。");
        let values = bindings(&[("name", "小明")]);

        // `unwrap_err` panics on `Ok`, so it also checks that formatting failed.
        let message = prompt.format(&values).unwrap_err();
        assert!(message.contains("day"));
    }

    /// Placeholder names are reported in order of appearance.
    #[test]
    fn test_get_variables() {
        let names = PromptTemplate::new("{a}, {b}, {c}").variables();
        assert_eq!(names, vec!["a", "b", "c"]);
    }
}