use anyhow::{bail, Context, Result};
use std::collections::HashMap;
use std::path::Path;
/// A named prompt template whose body may contain `{{ variable }}` placeholders.
#[derive(Debug, Clone)]
pub struct PromptTemplate {
    /// Identifier of the template (the file stem when loaded by `TemplateManager`).
    pub name: String,
    /// Raw template text, placeholders included.
    pub content: String,
    /// Distinct placeholder names found in `content`, whitespace-trimmed, in order of first
    /// appearance.
    pub variables: Vec<String>,
}

impl PromptTemplate {
    /// Builds a template from its name and body, pre-computing the placeholder names.
    pub fn parse(name: String, content: String) -> Self {
        Self {
            variables: extract_variables(&content),
            name,
            content,
        }
    }

    /// Renders the template using `vars` to fill placeholders.
    ///
    /// Extra entries in `vars` that the template does not use are ignored.
    ///
    /// # Errors
    ///
    /// Fails if any placeholder listed in `self.variables` has no entry in `vars`. The
    /// message lists all of them, in template order.
    pub fn render(&self, vars: &HashMap<&str, &str>) -> Result<String> {
        let mut absent = Vec::new();
        for variable in &self.variables {
            if !vars.contains_key(variable.as_str()) {
                absent.push(variable.as_str());
            }
        }

        if absent.is_empty() {
            Ok(render_template(&self.content, vars))
        } else {
            bail!(
                "Missing variables for template '{}': {}",
                self.name,
                absent.join(", ")
            )
        }
    }
}
/// Collects the distinct placeholder names used in `template`.
///
/// A placeholder starts at `{{` and ends at the first following `}}`. Its name is the enclosed
/// text with surrounding whitespace trimmed. Empty names (`{{}}`, `{{  }}`) are ignored, as is
/// a trailing `{{` with no closing `}}`. Each name is reported once, in order of first
/// appearance.
fn extract_variables(template: &str) -> Vec<String> {
    let mut already_listed: std::collections::HashSet<&str> = std::collections::HashSet::new();
    let mut names = Vec::new();
    let mut remaining = template;

    while let Some(start) = remaining.find("{{") {
        // The delimiters are ASCII, so these byte offsets always land on char boundaries.
        let body = &remaining[start + 2..];
        let end = match body.find("}}") {
            Some(end) => end,
            None => break, // unterminated placeholder: nothing more to collect
        };

        let name = body[..end].trim();
        if !name.is_empty() && already_listed.insert(name) {
            names.push(name.to_string());
        }
        remaining = &body[end + 2..];
    }

    names
}
/// Replaces each `{{ name }}` placeholder in `template` with `vars[name]`.
///
/// Placeholders are located with the same rules as `extract_variables`. A placeholder runs
/// from `{{` to the first following `}}`. Its name is the enclosed text with surrounding
/// whitespace trimmed, so `{{name}}` and `{{ name }}` are equivalent.
///
/// The template is scanned in a single left-to-right pass, which means:
///
/// * substituted values are emitted verbatim and never re-scanned, so a value that itself
///   contains `{{other}}` is not expanded (and the output does not depend on `HashMap`
///   iteration order);
/// * placeholders whose name is not in `vars`, and an unterminated trailing `{{`, are copied
///   through unchanged.
fn render_template(template: &str, vars: &HashMap<&str, &str>) -> String {
    let mut out = String::with_capacity(template.len());
    let mut rest = template;
    while let Some(open) = rest.find("{{") {
        out.push_str(&rest[..open]);
        // `{{` and `}}` are ASCII, so these byte offsets are always char boundaries.
        let inner = &rest[open + 2..];
        match inner.find("}}") {
            Some(close) => {
                let placeholder = &rest[open..open + close + 4];
                match vars.get(inner[..close].trim()) {
                    Some(value) => out.push_str(value),
                    // Unknown names are left in place so missing data stays visible.
                    None => out.push_str(placeholder),
                }
                rest = &inner[close + 2..];
            }
            None => {
                // Unterminated placeholder: emit the remainder verbatim.
                rest = &rest[open..];
                break;
            }
        }
    }
    out.push_str(rest);
    out
}
/// In-memory registry of [`PromptTemplate`]s keyed by template name.
#[derive(Debug, Clone, Default)]
pub struct TemplateManager {
    templates: HashMap<String, PromptTemplate>,
}

impl TemplateManager {
    /// Creates an empty manager.
    pub fn new() -> Self {
        Self::default()
    }

    /// Loads every `*.md` file directly inside `dir` as a template named after its file stem.
    ///
    /// If `dir` does not exist (per `Path::exists`), an empty manager is returned instead of
    /// an error. The following entries are skipped:
    ///
    /// * entries whose extension is not exactly `md`;
    /// * entries that are not regular files, such as a sub-directory called `drafts.md`;
    /// * entries whose file stem is empty or not valid UTF-8.
    ///
    /// Sub-directories are not traversed.
    ///
    /// # Errors
    ///
    /// Returns an error if `dir` cannot be listed, if iterating its entries fails, or if a
    /// matching file cannot be read as UTF-8 text.
    pub fn load_from_dir(dir: &Path) -> Result<Self> {
        let mut manager = Self::new();
        if !dir.exists() {
            return Ok(manager);
        }
        let entries = std::fs::read_dir(dir)
            .with_context(|| format!("Failed to read templates directory: {}", dir.display()))?;
        for entry in entries {
            let path = entry?.path();
            if path.extension().and_then(|e| e.to_str()) != Some("md") {
                continue;
            }
            // `is_file` follows symlinks, so a link to a template file is still loaded. A
            // directory that merely ends in `.md` is skipped instead of failing the whole
            // load in `read_to_string`.
            if !path.is_file() {
                continue;
            }
            let name = match path.file_stem().and_then(|s| s.to_str()) {
                Some(stem) if !stem.is_empty() => stem.to_string(),
                _ => continue,
            };
            let content = std::fs::read_to_string(&path)
                .with_context(|| format!("Failed to read template: {}", path.display()))?;
            let template = PromptTemplate::parse(name.clone(), content);
            tracing::debug!(name = %name, vars = ?template.variables, "loaded template");
            manager.templates.insert(name, template);
        }
        Ok(manager)
    }

    /// Renders the template called `name` with the given variables.
    ///
    /// # Errors
    ///
    /// Fails if no template called `name` is registered, or if the template references
    /// variables absent from `vars`.
    pub fn render(&self, name: &str, vars: HashMap<&str, &str>) -> Result<String> {
        let template = self
            .templates
            .get(name)
            .with_context(|| format!("Template '{}' not found", name))?;
        template.render(&vars)
    }

    /// Returns the template called `name`, if registered.
    pub fn get(&self, name: &str) -> Option<&PromptTemplate> {
        self.templates.get(name)
    }

    /// Returns all registered template names in ascending lexicographic order.
    pub fn template_names(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.templates.keys().map(String::as_str).collect();
        // Map keys are unique, so an unstable sort yields the same order without allocating.
        names.sort_unstable();
        names
    }

    /// Number of registered templates.
    pub fn len(&self) -> usize {
        self.templates.len()
    }

    /// Returns `true` when no templates are registered.
    pub fn is_empty(&self) -> bool {
        self.templates.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a variable map from `(name, value)` pairs.
    fn vars_of<'a>(pairs: &[(&'a str, &'a str)]) -> HashMap<&'a str, &'a str> {
        pairs.iter().copied().collect()
    }

    #[test]
    fn test_extract_variables_simple() {
        assert_eq!(
            extract_variables("Hello {{name}}, welcome to {{place}}!"),
            vec!["name", "place"]
        );
    }

    #[test]
    fn test_extract_variables_dedup() {
        assert_eq!(extract_variables("{{x}} + {{x}} = {{y}}"), vec!["x", "y"]);
    }

    #[test]
    fn test_extract_variables_none() {
        assert!(extract_variables("No variables here.").is_empty());
    }

    #[test]
    fn test_extract_variables_whitespace() {
        assert_eq!(
            extract_variables("{{ name }} and {{ place }}"),
            vec!["name", "place"]
        );
    }

    #[test]
    fn test_render_template_basic() {
        let vars = vars_of(&[("name", "world"), ("lang", "Rust")]);
        assert_eq!(
            render_template("Hello {{name}}, write {{lang}}!", &vars),
            "Hello world, write Rust!"
        );
    }

    #[test]
    fn test_render_template_no_match() {
        let vars = HashMap::new();
        assert_eq!(
            render_template("No {{vars}} to replace", &vars),
            "No {{vars}} to replace"
        );
    }

    #[test]
    fn test_prompt_template_parse_and_render() {
        let tmpl = PromptTemplate::parse(
            "greet".to_string(),
            "Hello {{name}}, your role is {{role}}.".to_string(),
        );
        assert_eq!(tmpl.name, "greet");
        assert_eq!(tmpl.variables, vec!["name", "role"]);

        let rendered = tmpl
            .render(&vars_of(&[("name", "Alice"), ("role", "admin")]))
            .unwrap();
        assert_eq!(rendered, "Hello Alice, your role is admin.");
    }

    #[test]
    fn test_prompt_template_missing_vars() {
        let tmpl = PromptTemplate::parse(
            "greet".to_string(),
            "Hello {{name}}, role: {{role}}.".to_string(),
        );
        let message = tmpl
            .render(&vars_of(&[("name", "Alice")]))
            .unwrap_err()
            .to_string();
        assert!(message.contains("Missing variables"));
        assert!(message.contains("role"));
    }

    #[test]
    fn test_template_manager_new() {
        let mgr = TemplateManager::new();
        assert!(mgr.is_empty());
        assert_eq!(mgr.len(), 0);
    }

    #[test]
    fn test_template_manager_get_and_render() {
        let mut mgr = TemplateManager::new();
        mgr.templates.insert(
            "review".to_string(),
            PromptTemplate::parse(
                "review".to_string(),
                "Review this {{type}}: {{content}}".to_string(),
            ),
        );
        assert_eq!(mgr.template_names(), vec!["review"]);

        let rendered = mgr
            .render("review", vars_of(&[("type", "PR"), ("content", "my changes")]))
            .unwrap();
        assert_eq!(rendered, "Review this PR: my changes");
    }

    #[test]
    fn test_template_manager_not_found() {
        let mgr = TemplateManager::new();
        let message = mgr
            .render("nonexistent", HashMap::new())
            .unwrap_err()
            .to_string();
        assert!(message.contains("not found"));
    }

    #[test]
    fn test_load_from_dir_missing() {
        let mgr = TemplateManager::load_from_dir(Path::new("/nonexistent/templates")).unwrap();
        assert!(mgr.is_empty());
    }

    #[test]
    fn test_load_from_dir_with_files() {
        let dir = tempfile::tempdir().unwrap();
        let root = dir.path();
        for &(file, body) in &[
            ("greet.md", "Hello {{name}}!"),
            ("review.md", "Review {{lang}} code."),
            ("notes.txt", "Not a template"),
        ] {
            std::fs::write(root.join(file), body).unwrap();
        }

        let mgr = TemplateManager::load_from_dir(root).unwrap();
        assert_eq!(mgr.len(), 2);
        assert_eq!(mgr.template_names(), vec!["greet", "review"]);
        assert_eq!(
            mgr.render("greet", vars_of(&[("name", "world")])).unwrap(),
            "Hello world!"
        );
    }
}