use serde_json::Value;
use std::sync::Arc;
/// A text prompt template.
///
/// Supported syntax:
/// * `{{name}}` / `{{ name }}` — replaced by the value found at `name` in the render
///   context; dotted paths such as `{{user.name}}` or `{{items.0}}` walk objects and arrays.
/// * `{{#if name}}…{{/if}}` — conditional block.
/// * `{{#each list}}…{{this}}…{{/each}}` — repeated once per element of an array.
#[derive(Debug, Clone)]
pub struct PromptTemplate {
    /// Raw template source.
    template: String,
    /// Optional label set via [`PromptTemplate::with_name`].
    name: Option<String>,
    /// Placeholder names extracted from `template` by `extract_variables`.
    variables: Vec<String>,
}

impl PromptTemplate {
    /// Creates a template from its source text, extracting its placeholder names up front.
    pub fn from_string(template: impl Into<String>) -> Self {
        let template = template.into();
        let variables = extract_variables(&template);
        Self {
            template,
            name: None,
            variables,
        }
    }

    /// Loads a template from the file at `path`.
    ///
    /// # Errors
    ///
    /// Returns [`PromptError::IoError`] if the file cannot be read or is not valid UTF-8.
    pub fn from_file(path: &std::path::Path) -> Result<Self, PromptError> {
        // `PromptError` implements `From<std::io::Error>` (via `#[from]`), so `?` converts.
        let content = std::fs::read_to_string(path)?;
        Ok(Self::from_string(content))
    }

    /// Attaches a label to the template (builder style).
    #[must_use]
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.name = Some(name.into());
        self
    }

    /// Returns the label set via [`with_name`](Self::with_name), if any.
    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }

    /// Returns the placeholder names found in the template, in first-occurrence order and
    /// without duplicates.
    pub fn variables(&self) -> &[String] {
        &self.variables
    }

    /// Renders the template against `context`, escaping every substituted value with
    /// `escape_prompt`.
    ///
    /// Placeholders that do not resolve in `context` are handed to
    /// `cleanup_unused_variables`, which strips the plain `{{name}}` form.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is kept for API stability.
    pub fn render(&self, context: &Value) -> Result<String, PromptError> {
        Ok(self.render_with(context, true))
    }

    /// Same as [`render`](Self::render) but inserts values verbatim, without escaping.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is kept for API stability.
    pub fn render_unescaped(&self, context: &Value) -> Result<String, PromptError> {
        Ok(self.render_with(context, false))
    }

    /// Wraps a copy of this template in a cheaply clonable, shareable [`CompiledPrompt`].
    pub fn compile(&self) -> CompiledPrompt {
        CompiledPrompt {
            template: Arc::new(self.clone()),
        }
    }

    /// Shared rendering pipeline behind `render` and `render_unescaped`.
    ///
    /// 1. `{{#if}}` / `{{#each}}` blocks are expanded on the raw template first, so a
    ///    substituted value can never open, close or break a block. A multi-line value
    ///    also no longer stops its surrounding block from matching.
    /// 2. A single left-to-right pass replaces each `{{ … }}` placeholder whose trimmed
    ///    text is one of `self.variables` and resolves in `context`. Values inserted by
    ///    this pass are never re-scanned, so a value that itself contains `{{…}}` is
    ///    emitted as-is instead of being substituted or stripped.
    /// 3. Each unresolved placeholder is passed through `cleanup_unused_variables` on its
    ///    own, so cleanup only ever sees template text.
    fn render_with(&self, context: &Value, escape: bool) -> String {
        let expanded = process_if_blocks(&self.template, context);
        let expanded = process_each_blocks(&expanded, context);

        let mut out = String::with_capacity(expanded.len());
        let mut rest = expanded.as_str();
        while let Some(open) = rest.find("{{") {
            let after_open = &rest[open + 2..];
            // An unterminated `{{` is plain text; copy the remainder verbatim below.
            let Some(close) = after_open.find("}}") else {
                break;
            };
            let end = open + 2 + close + 2;
            let placeholder = &rest[open..end];
            out.push_str(&rest[..open]);

            let name = after_open[..close].trim();
            let value = if self.variables.iter().any(|v| v == name) {
                get_json_value(context, name)
            } else {
                None
            };
            match value {
                Some(v) if escape => out.push_str(&escape_prompt(&v)),
                Some(v) => out.push_str(&v),
                None => out.push_str(&cleanup_unused_variables(placeholder)),
            }
            rest = &rest[end..];
        }
        out.push_str(rest);
        out
    }
}
#[derive(Debug, Clone)]
pub struct CompiledPrompt {
template: Arc<PromptTemplate>,
}
impl CompiledPrompt {
pub fn render(&self, context: &Value) -> Result<String, PromptError> {
self.template.render(context)
}
}
/// Errors returned when loading or rendering a prompt template.
///
/// The `#[error]` display strings are in Chinese; an English gloss is given per variant.
#[derive(Debug, thiserror::Error)]
pub enum PromptError {
/// Reading a template file failed ("IO error: …").
/// `#[from]` lets `?` convert a `std::io::Error` into this variant.
#[error("IO 错误:{source}")]
IoError {
#[from]
source: std::io::Error,
},
/// A `serde_json` operation failed ("JSON error: …").
/// `#[from]` lets `?` convert a `serde_json::Error` into this variant.
#[error("JSON 错误:{source}")]
JsonError {
#[from]
source: serde_json::Error,
},
/// A variable was missing ("missing variable: {name}").
/// NOTE(review): not constructed by the code visible here; rendering currently drops
/// unresolved placeholders instead of reporting them.
#[error("缺少变量:{name}")]
MissingVariable { name: String },
/// The template is malformed ("template syntax error: {message}").
/// NOTE(review): not constructed by the code visible here.
#[error("模板语法错误:{message}")]
SyntaxError { message: String },
}
/// Collects the distinct placeholder names of `template`, in first-occurrence order.
///
/// A placeholder is the text between `{{` and the next `}}`. Anything after a `|` is
/// treated as a filter and dropped, and the remaining name is trimmed. The following are
/// skipped:
/// * an unterminated `{{`, which can never be substituted;
/// * block-helper tags such as `{{#if x}}` or `{{/each}}`;
/// * the loop-item placeholder `{{this}}`.
fn extract_variables(template: &str) -> Vec<String> {
    let mut variables: Vec<String> = Vec::new();
    let mut rest = template;
    while let Some(open) = rest.find("{{") {
        let after_open = &rest[open + 2..];
        // The old char-by-char scanner panicked here on input such as "{{a}" and dropped
        // lone '}' characters; searching for the closing delimiter avoids both.
        let Some(close) = after_open.find("}}") else {
            break;
        };
        let inner = &after_open[..close];
        rest = &after_open[close + 2..];

        let name = inner.split('|').next().unwrap_or(inner).trim();
        let is_helper = name.starts_with('#') || name.starts_with('/') || name == "this";
        if !name.is_empty() && !is_helper && !variables.iter().any(|v| v == name) {
            variables.push(name.to_string());
        }
    }
    variables
}
/// Resolves a dot-separated `path` (e.g. `user.name` or `items.0`) inside `context` and
/// returns the value it points at as text.
///
/// Object segments are looked up by key; array segments are parsed as `usize` indices.
/// Returns `None` when a segment is missing or is not a valid index, or when the path
/// tries to descend into a scalar.
///
/// Formatting of the result:
/// * strings are returned without JSON quotes;
/// * numbers and booleans use their usual text form;
/// * `null` becomes `"null"`;
/// * arrays and objects are rendered with `Value`'s `Display` (JSON text).
fn get_json_value(context: &Value, path: &str) -> Option<String> {
    let target = path.split('.').try_fold(context, |node, segment| match node {
        Value::Object(fields) => fields.get(segment),
        Value::Array(items) => segment.parse::<usize>().ok().and_then(|idx| items.get(idx)),
        _ => None,
    })?;

    let text = match target {
        Value::String(s) => s.clone(),
        Value::Number(n) => n.to_string(),
        Value::Bool(b) => b.to_string(),
        Value::Null => "null".to_string(),
        other => other.to_string(),
    };
    Some(text)
}
/// Escapes a context value before it is inserted into a prompt.
///
/// * Every run of three backticks is replaced by three backslash-escaped backticks
///   (presumably so a value cannot open or close a Markdown code fence).
/// * Every double quote becomes `\"`.
/// * Any run of three or more newlines is collapsed to exactly two (one blank line).
///
/// NOTE(review): existing backslashes are not escaped. A value ending in `\` can still
/// swallow a following quote if the prompt wraps it in `"…"`.
fn escape_prompt(text: &str) -> String {
    let escaped = text.replace("```", "\\`\\`\\`").replace('"', "\\\"");

    // A single `replace("\n\n\n", "\n\n")` pass only partially collapses longer runs
    // (4 newlines -> 3, 6 -> 4), so count the newlines in each run instead.
    let mut out = String::with_capacity(escaped.len());
    let mut newlines_in_run = 0usize;
    for ch in escaped.chars() {
        if ch == '\n' {
            newlines_in_run += 1;
            if newlines_in_run > 2 {
                continue;
            }
        } else {
            newlines_in_run = 0;
        }
        out.push(ch);
    }
    out
}
/// Expands every `{{#if path}}…{{/if}}` block in `text`.
///
/// The block body is kept when `path` resolves through `get_json_value` to a string other
/// than `"false"`, `"null"` or `""`. Otherwise the whole block is removed. `path` may be
/// dotted (`user.active`), matching what `get_json_value` supports.
///
/// NOTE(review): truthiness is judged on the rendered text. A JSON string `"false"` or
/// `"null"` therefore also counts as false, while `0`, `[]` and `{}` count as true.
/// Nested `{{#if}}` blocks are not supported: the lazy match pairs an opening tag with
/// the first following `{{/if}}`.
fn process_if_blocks(text: &str, context: &Value) -> String {
    // Compiled once per process. `(?s)` lets `.` cross newlines so multi-line blocks are
    // expanded instead of leaking their tags into the prompt.
    static IF_BLOCK: std::sync::OnceLock<regex::Regex> = std::sync::OnceLock::new();
    let re = IF_BLOCK.get_or_init(|| {
        regex::Regex::new(r"(?s)\{\{#if\s+([\w.]+)\}\}(.*?)\{\{/if\}\}")
            .expect("if-block pattern is a valid regex")
    });

    re.replace_all(text, |caps: &regex::Captures<'_>| {
        let should_include = get_json_value(context, &caps[1])
            .is_some_and(|v| v != "false" && v != "null" && !v.is_empty());
        if should_include {
            caps[2].to_string()
        } else {
            String::new()
        }
    })
    .into_owned()
}
fn process_each_blocks(text: &str, context: &Value) -> String {
let mut result = text.to_string();
let each_pattern = r"\{\{#each\s+(\w+)\}\}(.*?)\{\{/each\}\}";
let re = regex::Regex::new(each_pattern).unwrap();
for cap in re.captures_iter(text) {
let var_name = &cap[1];
let block_content = &cap[2];
if let Some(array_value) = context.get(var_name).and_then(|v| v.as_array()) {
let mut items = Vec::new();
for item in array_value {
let mut item_block = block_content.to_string();
if let Some(item_str) = item.as_str() {
item_block = item_block.replace("{{this}}", item_str);
} else {
item_block = item_block.replace("{{this}}", &item.to_string());
}
items.push(item_block);
}
result = result.replace(&cap[0], &items.join(""));
} else {
result = result.replace(&cap[0], "");
}
}
result
}
fn cleanup_unused_variables(text: &str) -> String {
let re = regex::Regex::new(r"\{\{\s*\w+\s*\}\}").unwrap();
re.replace_all(text, "").to_string()
}
/// Accumulates role-tagged messages and renders them as a single prompt string.
#[derive(Debug, Default)]
pub struct PromptBuilder {
    /// `(role, content)` pairs in insertion order; tool messages use the role `tool:<name>`.
    messages: Vec<(String, String)>,
}

impl PromptBuilder {
    /// Creates an empty builder.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends a `system` message.
    #[must_use]
    pub fn system(self, content: impl Into<String>) -> Self {
        self.push("system", content.into())
    }

    /// Appends a `user` message.
    #[must_use]
    pub fn user(self, content: impl Into<String>) -> Self {
        self.push("user", content.into())
    }

    /// Appends an `assistant` message.
    #[must_use]
    pub fn assistant(self, content: impl Into<String>) -> Self {
        self.push("assistant", content.into())
    }

    /// Appends a tool result, stored under the role `tool:<tool_name>`.
    #[must_use]
    pub fn tool(self, content: impl Into<String>, tool_name: impl Into<String>) -> Self {
        let role = format!("tool:{}", tool_name.into());
        self.push(role, content.into())
    }

    /// Renders every message as `<role>content</role>`, joined with `\n`.
    ///
    /// NOTE(review): content is inserted verbatim, so text containing e.g. `</user>` can
    /// break the tag structure.
    #[must_use]
    pub fn build(&self) -> String {
        self.messages
            .iter()
            .map(|(role, content)| format!("<{role}>{content}</{role}>"))
            .collect::<Vec<_>>()
            .join("\n")
    }

    /// Returns a copy of the accumulated `(role, content)` pairs.
    #[must_use]
    pub fn build_messages(&self) -> Vec<(String, String)> {
        self.messages.clone()
    }

    /// Shared implementation of the consuming role methods. The methods are
    /// `#[must_use]` because dropping the returned builder silently loses the message.
    fn push(mut self, role: impl Into<String>, content: String) -> Self {
        self.messages.push((role.into(), content));
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_prompt_template_basic() {
        let template = PromptTemplate::from_string("你好,{{name}}!");
        let result = template.render(&json!({"name": "世界"})).unwrap();
        assert_eq!(result, "你好,世界!");
    }

    /// Several flat variables in one template (this was previously named `..._nested`,
    /// but it never used a nested path).
    #[test]
    fn test_prompt_template_multiple_variables() {
        let template = PromptTemplate::from_string("{{user_name}}: {{user_email}}");
        let result = template
            .render(&json!({
                "user_name": "张三",
                "user_email": "zhangsan@example.com"
            }))
            .unwrap();
        assert_eq!(result, "张三: zhangsan@example.com");
    }

    /// Dotted paths walk nested objects and index into arrays.
    #[test]
    fn test_prompt_template_nested() {
        let template = PromptTemplate::from_string("{{user.name}} <{{user.emails.1}}>");
        let result = template
            .render(&json!({
                "user": {
                    "name": "张三",
                    "emails": ["a@example.com", "zhangsan@example.com"]
                }
            }))
            .unwrap();
        assert_eq!(result, "张三 <zhangsan@example.com>");
    }

    #[test]
    fn test_prompt_template_missing_variable_removed() {
        let template = PromptTemplate::from_string("你好,{{name}}!");
        assert_eq!(template.render(&json!({})).unwrap(), "你好,!");
    }

    #[test]
    fn test_prompt_template_if() {
        let template = PromptTemplate::from_string("你好{{#if name}},{{name}}{{/if}}!");
        let result1 = template.render(&json!({"name": "张三"})).unwrap();
        assert_eq!(result1, "你好,张三!");
        let result2 = template.render(&json!({})).unwrap();
        assert_eq!(result2, "你好!");
    }

    #[test]
    fn test_prompt_template_if_false_and_null() {
        let template = PromptTemplate::from_string("A{{#if flag}}B{{/if}}C");
        assert_eq!(template.render(&json!({"flag": false})).unwrap(), "AC");
        assert_eq!(template.render(&json!({"flag": null})).unwrap(), "AC");
        assert_eq!(template.render(&json!({"flag": true})).unwrap(), "ABC");
    }

    /// Three independent variables (previously misnamed `..._each`; it never used `{{#each}}`).
    #[test]
    fn test_prompt_template_three_variables() {
        let template = PromptTemplate::from_string("项目:{{item1}}, {{item2}}, {{item3}}");
        let result = template
            .render(&json!({
                "item1": "项目 A",
                "item2": "项目 B",
                "item3": "项目 C"
            }))
            .unwrap();
        assert!(result.contains("项目 A"));
        assert!(result.contains("项目 B"));
        assert!(result.contains("项目 C"));
    }

    #[test]
    fn test_prompt_template_each() {
        let template = PromptTemplate::from_string("项目:{{#each items}}[{{this}}]{{/each}}");
        let result = template.render(&json!({"items": ["A", "B", 3]})).unwrap();
        assert_eq!(result, "项目:[A][B][3]");
        // A missing list removes the whole block.
        assert_eq!(template.render(&json!({})).unwrap(), "项目:");
    }

    #[test]
    fn test_prompt_escape() {
        let template = PromptTemplate::from_string("{{content}}");
        let result = template
            .render(&json!({
                "content": "```python\ncode\n```"
            }))
            .unwrap();
        assert!(result.contains("\\`\\`\\`"));
    }

    #[test]
    fn test_prompt_render_unescaped() {
        let template = PromptTemplate::from_string("{{content}}");
        let context = json!({"content": "say \"hi\""});
        assert_eq!(template.render(&context).unwrap(), "say \\\"hi\\\"");
        assert_eq!(template.render_unescaped(&context).unwrap(), "say \"hi\"");
    }

    #[test]
    fn test_compiled_prompt_matches_template() {
        let template = PromptTemplate::from_string("{{a}}-{{b}}");
        let context = json!({"a": 1, "b": true});
        assert_eq!(template.render(&context).unwrap(), "1-true");
        assert_eq!(
            template.compile().render(&context).unwrap(),
            template.render(&context).unwrap()
        );
    }

    #[test]
    fn test_prompt_builder() {
        let prompt = PromptBuilder::new()
            .system("你是助手")
            .user("你好")
            .tool("结果", "search")
            .build();
        assert!(prompt.contains("<system>你是助手</system>"));
        assert!(prompt.contains("<user>你好</user>"));
        assert!(prompt.contains("<tool:search>结果</tool:search>"));
    }
}