Skip to main content

oxi/
templates.rs

1//! Prompt template system for oxi
2//!
3//! Reusable prompt templates stored as Markdown files with `{{variable}}`
4//! placeholders.  Templates live under `~/.oxi/templates/` and each `.md`
5//! file becomes a named template (filename sans extension).
6
7use anyhow::{bail, Context, Result};
8use std::collections::HashMap;
9use std::path::Path;
10
/// A single prompt template loaded from disk.
///
/// Normally built with [`PromptTemplate::parse`], which derives
/// [`variables`](Self::variables) from [`content`](Self::content).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptTemplate {
    /// Template name. For templates loaded by [`TemplateManager::load_from_dir`]
    /// this is the filename without its `.md` extension.
    pub name: String,
    /// Raw template content containing `{{variable}}` placeholders.
    pub content: String,
    /// Variable names found in `content`. As produced by `parse`, these are
    /// whitespace-trimmed, deduplicated, and in order of first appearance.
    pub variables: Vec<String>,
}
21
22impl PromptTemplate {
23    /// Parse a template from raw text, extracting `{{variable}}` names.
24    pub fn parse(name: String, content: String) -> Self {
25        let variables = extract_variables(&content);
26        Self {
27            name,
28            content,
29            variables,
30        }
31    }
32
33    /// Render the template, replacing every `{{key}}` with the provided value.
34    pub fn render(&self, vars: &HashMap<&str, &str>) -> Result<String> {
35        // Check for missing variables
36        let missing: Vec<&str> = self
37            .variables
38            .iter()
39            .filter(|v| !vars.contains_key(v.as_str()))
40            .map(|v| v.as_str())
41            .collect();
42
43        if !missing.is_empty() {
44            bail!(
45                "Missing variables for template '{}': {}",
46                self.name,
47                missing.join(", ")
48            );
49        }
50
51        Ok(render_template(&self.content, vars))
52    }
53}
54
/// Collect the distinct `{{variable}}` names in `template`, in order of first appearance.
///
/// A placeholder runs from a `{{` to the next `}}`. The text between them is
/// whitespace-trimmed, and empty names are ignored. A `{{` with no later `}}`
/// ends the scan.
fn extract_variables(template: &str) -> Vec<String> {
    let mut seen = std::collections::HashSet::new();
    let mut names = Vec::new();
    let mut rest = template;

    while let Some(open) = rest.find("{{") {
        let body = &rest[open + 2..];
        let close = match body.find("}}") {
            Some(close) => close,
            // Unterminated placeholder: no `}}` follows, so nothing later can close.
            None => break,
        };
        let name = body[..close].trim();
        if !name.is_empty() && seen.insert(name) {
            names.push(name.to_string());
        }
        rest = &body[close + 2..];
    }

    names
}
90
/// Replace `{{key}}` placeholders in `template` with values from `vars`.
///
/// The template is scanned once, left to right, using the same placeholder
/// rules as [`extract_variables`]. A placeholder runs from `{{` to the next
/// `}}`, and its name is whitespace-trimmed, so `{{ name }}` is filled by the
/// `name` key.
///
/// Substituted values are copied verbatim and never re-scanned. A value that
/// itself contains `{{...}}` (e.g. inserted code or user text) is therefore not
/// expanded, and the output does not depend on `HashMap` iteration order.
///
/// Placeholders whose name has no entry in `vars`, and an unterminated trailing
/// `{{`, are left unchanged.
fn render_template(template: &str, vars: &HashMap<&str, &str>) -> String {
    let mut out = String::with_capacity(template.len());
    let mut rest = template;

    while let Some(open) = rest.find("{{") {
        out.push_str(&rest[..open]);
        let placeholder = &rest[open..];
        let inner = &placeholder[2..];
        let close = match inner.find("}}") {
            Some(close) => close,
            None => {
                // No closing `}}`: the rest of the text is literal.
                out.push_str(placeholder);
                return out;
            }
        };
        match vars.get(inner[..close].trim()) {
            Some(value) => out.push_str(value),
            // Unknown name: keep the placeholder text exactly as written
            // (`{{` + name + `}}` is `close + 4` bytes).
            None => out.push_str(&placeholder[..close + 4]),
        }
        rest = &inner[close + 2..];
    }

    out.push_str(rest);
    out
}
99
100/// Manages a collection of named prompt templates.
101#[derive(Debug, Clone, Default)]
102pub struct TemplateManager {
103    templates: HashMap<String, PromptTemplate>,
104}
105
106impl TemplateManager {
107    /// Create an empty template manager.
108    pub fn new() -> Self {
109        Self::default()
110    }
111
112    /// Load all `.md` files from a directory as templates.
113    ///
114    /// Each file becomes a template named after the file (without extension).
115    /// Returns an error only if the directory cannot be read.
116    pub fn load_from_dir(dir: &Path) -> Result<Self> {
117        let mut manager = Self::new();
118
119        if !dir.exists() {
120            // No templates directory is fine — start empty.
121            return Ok(manager);
122        }
123
124        let entries = std::fs::read_dir(dir)
125            .with_context(|| format!("Failed to read templates directory: {}", dir.display()))?;
126
127        for entry in entries {
128            let entry = entry?;
129            let path = entry.path();
130
131            // Only load .md files
132            if path.extension().and_then(|e| e.to_str()) != Some("md") {
133                continue;
134            }
135
136            let name = path
137                .file_stem()
138                .and_then(|s| s.to_str())
139                .unwrap_or("")
140                .to_string();
141
142            if name.is_empty() {
143                continue;
144            }
145
146            let content = std::fs::read_to_string(&path)
147                .with_context(|| format!("Failed to read template: {}", path.display()))?;
148
149            let template = PromptTemplate::parse(name.clone(), content);
150            tracing::debug!(name = %name, vars = ?template.variables, "loaded template");
151            manager.templates.insert(name, template);
152        }
153
154        Ok(manager)
155    }
156
157    /// Render a template by name with the given variables.
158    pub fn render(&self, name: &str, vars: HashMap<&str, &str>) -> Result<String> {
159        let template = self
160            .templates
161            .get(name)
162            .with_context(|| format!("Template '{}' not found", name))?;
163        template.render(&vars)
164    }
165
166    /// Get a template by name.
167    pub fn get(&self, name: &str) -> Option<&PromptTemplate> {
168        self.templates.get(name)
169    }
170
171    /// List all loaded template names.
172    pub fn template_names(&self) -> Vec<&str> {
173        let mut names: Vec<&str> = self.templates.keys().map(|s| s.as_str()).collect();
174        names.sort();
175        names
176    }
177
178    /// Number of loaded templates.
179    pub fn len(&self) -> usize {
180        self.templates.len()
181    }
182
183    /// Whether any templates are loaded.
184    pub fn is_empty(&self) -> bool {
185        self.templates.is_empty()
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_variables_simple() {
        let found = extract_variables("Hello {{name}}, welcome to {{place}}!");
        assert_eq!(found, ["name", "place"]);
    }

    #[test]
    fn test_extract_variables_dedup() {
        let found = extract_variables("{{x}} + {{x}} = {{y}}");
        assert_eq!(found, ["x", "y"]);
    }

    #[test]
    fn test_extract_variables_none() {
        assert!(extract_variables("No variables here.").is_empty());
    }

    #[test]
    fn test_extract_variables_whitespace() {
        let found = extract_variables("{{ name }} and {{  place  }}");
        assert_eq!(found, ["name", "place"]);
    }

    #[test]
    fn test_render_template_basic() {
        let vars = HashMap::from([("name", "world"), ("lang", "Rust")]);
        let rendered = render_template("Hello {{name}}, write {{lang}}!", &vars);
        assert_eq!(rendered, "Hello world, write Rust!");
    }

    #[test]
    fn test_render_template_no_match() {
        let vars = HashMap::new();
        let rendered = render_template("No {{vars}} to replace", &vars);
        assert_eq!(rendered, "No {{vars}} to replace");
    }

    #[test]
    fn test_prompt_template_parse_and_render() {
        let template = PromptTemplate::parse(
            String::from("greet"),
            String::from("Hello {{name}}, your role is {{role}}."),
        );
        assert_eq!(template.name, "greet");
        assert_eq!(template.variables, ["name", "role"]);

        let vars = HashMap::from([("name", "Alice"), ("role", "admin")]);
        let rendered = template.render(&vars).unwrap();
        assert_eq!(rendered, "Hello Alice, your role is admin.");
    }

    #[test]
    fn test_prompt_template_missing_vars() {
        let template = PromptTemplate::parse(
            String::from("greet"),
            String::from("Hello {{name}}, role: {{role}}."),
        );

        // "role" is deliberately left out.
        let vars = HashMap::from([("name", "Alice")]);
        let message = template.render(&vars).unwrap_err().to_string();
        assert!(message.contains("Missing variables"));
        assert!(message.contains("role"));
    }

    #[test]
    fn test_template_manager_new() {
        let manager = TemplateManager::new();
        assert!(manager.is_empty());
        assert_eq!(manager.len(), 0);
    }

    #[test]
    fn test_template_manager_get_and_render() {
        let mut manager = TemplateManager::new();
        manager.templates.insert(
            String::from("review"),
            PromptTemplate::parse(
                String::from("review"),
                String::from("Review this {{type}}: {{content}}"),
            ),
        );

        assert_eq!(manager.template_names(), ["review"]);

        let vars = HashMap::from([("type", "PR"), ("content", "my changes")]);
        let rendered = manager.render("review", vars).unwrap();
        assert_eq!(rendered, "Review this PR: my changes");
    }

    #[test]
    fn test_template_manager_not_found() {
        let manager = TemplateManager::new();
        let message = manager
            .render("nonexistent", HashMap::new())
            .unwrap_err()
            .to_string();
        assert!(message.contains("not found"));
    }

    #[test]
    fn test_load_from_dir_missing() {
        let manager =
            TemplateManager::load_from_dir(Path::new("/nonexistent/templates")).unwrap();
        assert!(manager.is_empty());
    }

    #[test]
    fn test_load_from_dir_with_files() {
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path();

        for (file, body) in [
            ("greet.md", "Hello {{name}}!"),
            ("review.md", "Review {{lang}} code."),
            // Non-md file should be skipped
            ("notes.txt", "Not a template"),
        ] {
            std::fs::write(root.join(file), body).unwrap();
        }

        let manager = TemplateManager::load_from_dir(root).unwrap();
        assert_eq!(manager.len(), 2);
        // `template_names` already returns the names sorted.
        assert_eq!(manager.template_names(), ["greet", "review"]);

        let vars = HashMap::from([("name", "world")]);
        assert_eq!(manager.render("greet", vars).unwrap(), "Hello world!");
    }
}