1use anyhow::{bail, Context, Result};
8use std::collections::HashMap;
9use std::path::Path;
10
/// A named prompt template whose text may contain `{{variable}}` placeholders.
#[derive(Debug, Clone)]
pub struct PromptTemplate {
    /// Identifier of the template (when loaded by `TemplateManager::load_from_dir`,
    /// this is the file stem, e.g. `review` for `review.md`).
    pub name: String,
    /// Raw template text, placeholders included.
    pub content: String,
    /// Distinct placeholder names found in `content` (whitespace-trimmed, in order
    /// of first appearance) as computed by `PromptTemplate::parse`.
    pub variables: Vec<String>,
}
21
22impl PromptTemplate {
23 pub fn parse(name: String, content: String) -> Self {
25 let variables = extract_variables(&content);
26 Self {
27 name,
28 content,
29 variables,
30 }
31 }
32
33 pub fn render(&self, vars: &HashMap<&str, &str>) -> Result<String> {
35 let missing: Vec<&str> = self
37 .variables
38 .iter()
39 .filter(|v| !vars.contains_key(v.as_str()))
40 .map(|v| v.as_str())
41 .collect();
42
43 if !missing.is_empty() {
44 bail!(
45 "Missing variables for template '{}': {}",
46 self.name,
47 missing.join(", ")
48 );
49 }
50
51 Ok(render_template(&self.content, vars))
52 }
53}
54
/// Collects the distinct placeholder names used in `template`, in order of
/// first appearance.
///
/// A placeholder is the text between an opening `{{` and the next `}}`, with
/// surrounding whitespace trimmed. Empty placeholders (`{{}}`, `{{  }}`) are
/// ignored. An opening `{{` with no later `}}` ends the scan, because the
/// remainder of the text is treated as an unterminated placeholder.
fn extract_variables(template: &str) -> Vec<String> {
    let mut names: Vec<String> = Vec::new();
    let mut already: std::collections::HashSet<String> = std::collections::HashSet::new();
    let mut rest = template;

    while let Some(open) = rest.find("{{") {
        let body = &rest[open + 2..];
        let close = match body.find("}}") {
            Some(close) => close,
            None => break,
        };

        let name = body[..close].trim();
        // `insert` returns false for names we've already recorded, which keeps
        // `names` de-duplicated while preserving first-seen order.
        if !name.is_empty() && already.insert(name.to_string()) {
            names.push(name.to_string());
        }

        rest = &body[close + 2..];
    }

    names
}
90
/// Substitutes `{{name}}` placeholders in `template` with values from `vars`.
///
/// The template is scanned once, left to right, using the same rules that
/// `extract_variables` uses to discover placeholders. A placeholder is the text
/// between `{{` and the next `}}`, and its name is that text with surrounding
/// whitespace trimmed, so `{{name}}` and `{{ name }}` both resolve to `name`.
///
/// - Placeholders whose name is not in `vars` are left in the output verbatim.
/// - An unterminated `{{` is copied through unchanged.
/// - Substituted values are never re-scanned. A value that itself contains
///   `{{other}}` is emitted literally rather than expanded, so the result does
///   not depend on `HashMap` iteration order. This also prevents caller-supplied
///   content from injecting template variables.
fn render_template(template: &str, vars: &HashMap<&str, &str>) -> String {
    let mut out = String::with_capacity(template.len());
    let mut rest = template;

    while let Some(open) = rest.find("{{") {
        let after_open = &rest[open + 2..];
        let close = match after_open.find("}}") {
            Some(close) => close,
            // No closing braces: the remainder (including this `{{`) is literal.
            None => break,
        };

        out.push_str(&rest[..open]);
        let placeholder_end = open + 2 + close + 2;
        match vars.get(after_open[..close].trim()) {
            Some(value) => out.push_str(value),
            None => out.push_str(&rest[open..placeholder_end]),
        }
        rest = &rest[placeholder_end..];
    }

    out.push_str(rest);
    out
}
99
/// Registry of [`PromptTemplate`]s, looked up by template name.
#[derive(Debug, Clone, Default)]
pub struct TemplateManager {
    /// Templates keyed by their name (for directory-loaded templates, the file stem).
    templates: HashMap<String, PromptTemplate>,
}
105
106impl TemplateManager {
107 pub fn new() -> Self {
109 Self::default()
110 }
111
112 pub fn load_from_dir(dir: &Path) -> Result<Self> {
117 let mut manager = Self::new();
118
119 if !dir.exists() {
120 return Ok(manager);
122 }
123
124 let entries = std::fs::read_dir(dir)
125 .with_context(|| format!("Failed to read templates directory: {}", dir.display()))?;
126
127 for entry in entries {
128 let entry = entry?;
129 let path = entry.path();
130
131 if path.extension().and_then(|e| e.to_str()) != Some("md") {
133 continue;
134 }
135
136 let name = path
137 .file_stem()
138 .and_then(|s| s.to_str())
139 .unwrap_or("")
140 .to_string();
141
142 if name.is_empty() {
143 continue;
144 }
145
146 let content = std::fs::read_to_string(&path)
147 .with_context(|| format!("Failed to read template: {}", path.display()))?;
148
149 let template = PromptTemplate::parse(name.clone(), content);
150 tracing::debug!(name = %name, vars = ?template.variables, "loaded template");
151 manager.templates.insert(name, template);
152 }
153
154 Ok(manager)
155 }
156
157 pub fn render(&self, name: &str, vars: HashMap<&str, &str>) -> Result<String> {
159 let template = self
160 .templates
161 .get(name)
162 .with_context(|| format!("Template '{}' not found", name))?;
163 template.render(&vars)
164 }
165
166 pub fn get(&self, name: &str) -> Option<&PromptTemplate> {
168 self.templates.get(name)
169 }
170
171 pub fn template_names(&self) -> Vec<&str> {
173 let mut names: Vec<&str> = self.templates.keys().map(|s| s.as_str()).collect();
174 names.sort();
175 names
176 }
177
178 pub fn len(&self) -> usize {
180 self.templates.len()
181 }
182
183 pub fn is_empty(&self) -> bool {
185 self.templates.is_empty()
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    // --- extract_variables: placeholder discovery ---

    #[test]
    fn test_extract_variables_simple() {
        let vars = extract_variables("Hello {{name}}, welcome to {{place}}!");
        assert_eq!(vars, vec!["name", "place"]);
    }

    // Repeated placeholders are reported once, in first-seen order.
    #[test]
    fn test_extract_variables_dedup() {
        let vars = extract_variables("{{x}} + {{x}} = {{y}}");
        assert_eq!(vars, vec!["x", "y"]);
    }

    #[test]
    fn test_extract_variables_none() {
        let vars = extract_variables("No variables here.");
        assert!(vars.is_empty());
    }

    // Whitespace inside the braces is trimmed from the reported name.
    #[test]
    fn test_extract_variables_whitespace() {
        let vars = extract_variables("{{ name }} and {{ place }}");
        assert_eq!(vars, vec!["name", "place"]);
    }

    // --- render_template: raw substitution without validation ---

    #[test]
    fn test_render_template_basic() {
        let mut vars = HashMap::new();
        vars.insert("name", "world");
        vars.insert("lang", "Rust");

        let result = render_template("Hello {{name}}, write {{lang}}!", &vars);
        assert_eq!(result, "Hello world, write Rust!");
    }

    // Placeholders with no supplied value are left in the output as-is.
    #[test]
    fn test_render_template_no_match() {
        let vars = HashMap::new();
        let result = render_template("No {{vars}} to replace", &vars);
        assert_eq!(result, "No {{vars}} to replace");
    }

    // --- PromptTemplate: parse + validated render ---

    #[test]
    fn test_prompt_template_parse_and_render() {
        let tmpl = PromptTemplate::parse(
            "greet".to_string(),
            "Hello {{name}}, your role is {{role}}.".to_string(),
        );
        assert_eq!(tmpl.name, "greet");
        assert_eq!(tmpl.variables, vec!["name", "role"]);

        let mut vars = HashMap::new();
        vars.insert("name", "Alice");
        vars.insert("role", "admin");
        let result = tmpl.render(&vars).unwrap();
        assert_eq!(result, "Hello Alice, your role is admin.");
    }

    // render() must fail and name the absent variable rather than emit a
    // partially-filled prompt.
    #[test]
    fn test_prompt_template_missing_vars() {
        let tmpl = PromptTemplate::parse(
            "greet".to_string(),
            "Hello {{name}}, role: {{role}}.".to_string(),
        );

        let mut vars = HashMap::new();
        vars.insert("name", "Alice");
        let err = tmpl.render(&vars).unwrap_err();
        assert!(err.to_string().contains("Missing variables"));
        assert!(err.to_string().contains("role"));
    }

    // --- TemplateManager: registry behaviour ---

    #[test]
    fn test_template_manager_new() {
        let mgr = TemplateManager::new();
        assert!(mgr.is_empty());
        assert_eq!(mgr.len(), 0);
    }

    // Inserts directly into the private map (visible from this child module)
    // to test lookup/render without touching the filesystem.
    #[test]
    fn test_template_manager_get_and_render() {
        let mut mgr = TemplateManager::new();
        let tmpl = PromptTemplate::parse(
            "review".to_string(),
            "Review this {{type}}: {{content}}".to_string(),
        );
        mgr.templates.insert("review".to_string(), tmpl);

        assert_eq!(mgr.template_names(), vec!["review"]);

        let mut vars = HashMap::new();
        vars.insert("type", "PR");
        vars.insert("content", "my changes");
        let result = mgr.render("review", vars).unwrap();
        assert_eq!(result, "Review this PR: my changes");
    }

    #[test]
    fn test_template_manager_not_found() {
        let mgr = TemplateManager::new();
        let err = mgr.render("nonexistent", HashMap::new()).unwrap_err();
        assert!(err.to_string().contains("not found"));
    }

    // --- load_from_dir: filesystem loading ---

    // A directory that does not exist yields an empty manager, not an error.
    #[test]
    fn test_load_from_dir_missing() {
        let mgr = TemplateManager::load_from_dir(Path::new("/nonexistent/templates")).unwrap();
        assert!(mgr.is_empty());
    }

    // Only `.md` files are loaded; `notes.txt` must be ignored.
    #[test]
    fn test_load_from_dir_with_files() {
        let dir = tempfile::tempdir().unwrap();
        let templates_dir = dir.path();

        std::fs::write(
            templates_dir.join("greet.md"),
            "Hello {{name}}!",
        ).unwrap();
        std::fs::write(
            templates_dir.join("review.md"),
            "Review {{lang}} code.",
        ).unwrap();
        std::fs::write(
            templates_dir.join("notes.txt"),
            "Not a template",
        ).unwrap();

        let mgr = TemplateManager::load_from_dir(templates_dir).unwrap();
        assert_eq!(mgr.len(), 2);

        let mut names = mgr.template_names();
        names.sort();
        assert_eq!(names, vec!["greet", "review"]);

        let mut vars = HashMap::new();
        vars.insert("name", "world");
        assert_eq!(mgr.render("greet", vars).unwrap(), "Hello world!");
    }
}