use serde::{Deserialize, Serialize};
use std::collections::HashMap;
pub type Result<T> = std::result::Result<T, GuidanceError>;
#[derive(Debug, thiserror::Error)]
pub enum GuidanceError {
#[error("Template error: {0}")]
Template(String),
#[error("Variable not found: {0}")]
VariableNotFound(String),
#[error("Invalid format: {0}")]
InvalidFormat(String),
}
/// A named prompt template whose `content` may contain `{{key}}` placeholders.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptTemplate {
    /// Identifier of the template; used as its key when stored in a `TemplateRegistry`.
    pub name: String,
    /// Template text; `{{key}}` placeholders are filled in when the template is rendered.
    pub content: String,
    /// Fallback placeholder values, used when the caller does not supply a value.
    /// `#[serde(default)]` lets serialized templates omit this field (it becomes empty).
    #[serde(default)]
    pub defaults: HashMap<String, String>,
}
impl PromptTemplate {
pub fn new(name: impl Into<String>, content: impl Into<String>) -> Self {
Self {
name: name.into(),
content: content.into(),
defaults: HashMap::new(),
}
}
pub fn with_default(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.defaults.insert(key.into(), value.into());
self
}
pub fn render(&self, variables: &HashMap<String, String>) -> Result<String> {
let mut result = self.content.clone();
let mut all_vars = self.defaults.clone();
all_vars.extend(variables.clone());
for (key, value) in all_vars {
let placeholder = format!("{{{{{}}}}}", key);
result = result.replace(&placeholder, &value);
}
if result.contains("{{") && result.contains("}}") {
return Err(GuidanceError::Template(
"Template contains unresolved placeholders".to_string(),
));
}
Ok(result)
}
}
/// An in-memory collection of [`PromptTemplate`]s keyed by their `name`.
#[derive(Debug, Default)]
pub struct TemplateRegistry {
    templates: HashMap<String, PromptTemplate>,
}

impl TemplateRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `template` under its `name`.
    ///
    /// A template already registered under the same name is replaced.
    pub fn register(&mut self, template: PromptTemplate) {
        self.templates.insert(template.name.clone(), template);
    }

    /// Returns the template registered under `name`, if any.
    pub fn get(&self, name: &str) -> Option<&PromptTemplate> {
        self.templates.get(name)
    }

    /// Renders the template registered under `name` with `variables`.
    ///
    /// # Errors
    ///
    /// * [`GuidanceError::VariableNotFound`] carrying `name` when no template is
    ///   registered under that name. This variant is kept for compatibility with
    ///   existing callers, because the error enum has no dedicated template-not-found variant.
    /// * Any error returned by the template's own `render`.
    pub fn render(&self, name: &str, variables: &HashMap<String, String>) -> Result<String> {
        let template = self
            .get(name)
            .ok_or_else(|| GuidanceError::VariableNotFound(name.to_string()))?;
        template.render(variables)
    }

    /// Returns the names of all registered templates in ascending order.
    ///
    /// The names are sorted so the result is deterministic; `HashMap` iteration
    /// order is unspecified.
    pub fn list(&self) -> Vec<String> {
        let mut names: Vec<String> = self.templates.keys().cloned().collect();
        names.sort_unstable();
        names
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned variable map from string-slice pairs.
    fn vars(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    #[test]
    fn test_template_creation() {
        let template = PromptTemplate::new("test", "Hello {{name}}!");
        assert_eq!(template.name, "test");
        assert_eq!(template.content, "Hello {{name}}!");
    }

    #[test]
    fn test_template_render() {
        let template = PromptTemplate::new("test", "Hello {{name}}!");
        let mut vars = HashMap::new();
        vars.insert("name".to_string(), "World".to_string());
        let result = template.render(&vars).unwrap();
        assert_eq!(result, "Hello World!");
    }

    #[test]
    fn test_template_defaults() {
        let template =
            PromptTemplate::new("test", "Hello {{name}}!").with_default("name", "Default");
        let result = template.render(&HashMap::new()).unwrap();
        assert_eq!(result, "Hello Default!");
    }

    #[test]
    fn test_template_override_default() {
        let template =
            PromptTemplate::new("test", "Hello {{name}}!").with_default("name", "Default");
        let mut vars = HashMap::new();
        vars.insert("name".to_string(), "Custom".to_string());
        let result = template.render(&vars).unwrap();
        assert_eq!(result, "Hello Custom!");
    }

    #[test]
    fn test_template_multiple_placeholders() {
        let template = PromptTemplate::new("test", "{{greeting}}, {{name}}!");
        let result = template
            .render(&vars(&[("greeting", "Hi"), ("name", "Bob")]))
            .unwrap();
        assert_eq!(result, "Hi, Bob!");
    }

    #[test]
    fn test_template_without_placeholders() {
        let template = PromptTemplate::new("test", "plain text");
        assert_eq!(template.render(&HashMap::new()).unwrap(), "plain text");
    }

    #[test]
    fn test_template_unresolved_placeholder_errors() {
        let template = PromptTemplate::new("test", "Hello {{name}}!");
        assert!(matches!(
            template.render(&HashMap::new()),
            Err(GuidanceError::Template(_))
        ));
    }

    #[test]
    fn test_template_unclosed_braces_are_literal() {
        let template = PromptTemplate::new("test", "a {{ b");
        assert_eq!(template.render(&HashMap::new()).unwrap(), "a {{ b");
    }

    #[test]
    fn test_registry() {
        let mut registry = TemplateRegistry::new();
        let template = PromptTemplate::new("greeting", "Hello {{name}}!");
        registry.register(template);
        assert!(registry.get("greeting").is_some());
        assert_eq!(registry.list().len(), 1);
    }

    #[test]
    fn test_registry_render() {
        let mut registry = TemplateRegistry::new();
        let template = PromptTemplate::new("greeting", "Hello {{name}}!");
        registry.register(template);
        let mut vars = HashMap::new();
        vars.insert("name".to_string(), "World".to_string());
        let result = registry.render("greeting", &vars).unwrap();
        assert_eq!(result, "Hello World!");
    }

    #[test]
    fn test_registry_missing_template() {
        let registry = TemplateRegistry::new();
        match registry.render("missing", &HashMap::new()) {
            Err(GuidanceError::VariableNotFound(name)) => assert_eq!(name, "missing"),
            other => panic!("unexpected result: {:?}", other),
        }
    }

    #[test]
    fn test_registry_register_replaces_same_name() {
        let mut registry = TemplateRegistry::new();
        registry.register(PromptTemplate::new("greeting", "old"));
        registry.register(PromptTemplate::new("greeting", "new"));
        assert_eq!(registry.list().len(), 1);
        assert_eq!(registry.get("greeting").unwrap().content, "new");
    }

    #[test]
    fn test_registry_list_names() {
        let mut registry = TemplateRegistry::new();
        registry.register(PromptTemplate::new("b", "B"));
        registry.register(PromptTemplate::new("a", "A"));
        let mut names = registry.list();
        names.sort();
        assert_eq!(names, vec!["a".to_string(), "b".to_string()]);
    }
}