oxify_connect_llm/
templates.rs

//! Prompt template engine for LLM requests.
//!
//! This module provides a small template engine for building reusable prompt
//! templates with `{{variable}}` substitution, plus a library of built-in templates.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use thiserror::Error;
9
10#[derive(Error, Debug)]
11pub enum TemplateError {
12    #[error("Missing variable: {0}")]
13    MissingVariable(String),
14
15    #[error("Template parsing error: {0}")]
16    ParseError(String),
17
18    #[error("Invalid template syntax: {0}")]
19    SyntaxError(String),
20}
21
/// A reusable prompt template with `{{variable}}` placeholders.
///
/// Construct with [`PromptTemplate::new`] plus the `with_*` builder methods, then
/// render with [`PromptTemplate::render`] (strict: every placeholder needs a value)
/// or [`PromptTemplate::render_partial`] (unresolved placeholders are kept as-is).
/// Serializable via serde; `required_vars` and `description` fall back to their
/// `Default` values (empty list / `None`) when absent from the serialized input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptTemplate {
    /// Template text; each `{{name}}` is replaced by the value supplied for `name`.
    template: String,
    /// Variable names that `render` verifies are present before substituting
    /// (empty by default).
    #[serde(default)]
    required_vars: Vec<String>,
    /// Optional human-readable description of the template.
    #[serde(default)]
    description: Option<String>,
}
34
35impl PromptTemplate {
36    /// Create a new prompt template
37    pub fn new(template: String) -> Self {
38        Self {
39            template,
40            required_vars: Vec::new(),
41            description: None,
42        }
43    }
44
45    /// Set required variables
46    pub fn with_required_vars(mut self, vars: Vec<String>) -> Self {
47        self.required_vars = vars;
48        self
49    }
50
51    /// Set description
52    pub fn with_description(mut self, description: String) -> Self {
53        self.description = Some(description);
54        self
55    }
56
57    /// Render the template with the given variables
58    pub fn render(
59        &self,
60        variables: &HashMap<String, String>,
61    ) -> std::result::Result<String, TemplateError> {
62        // Check for required variables
63        for var in &self.required_vars {
64            if !variables.contains_key(var) {
65                return Err(TemplateError::MissingVariable(var.clone()));
66            }
67        }
68
69        let mut result = self.template.clone();
70
71        // Replace {{variable}} with values
72        for (key, value) in variables {
73            let placeholder = format!("{{{{{}}}}}", key);
74            result = result.replace(&placeholder, value);
75        }
76
77        // Check for any remaining unreplaced variables (strict mode)
78        if result.contains("{{") && result.contains("}}") {
79            let start = result.find("{{").unwrap();
80            let end = result[start..].find("}}").unwrap() + start + 2;
81            let var_name = &result[start + 2..end - 2];
82            return Err(TemplateError::MissingVariable(var_name.to_string()));
83        }
84
85        Ok(result)
86    }
87
88    /// Render the template, allowing missing variables (they will be kept as-is)
89    pub fn render_partial(&self, variables: &HashMap<String, String>) -> String {
90        let mut result = self.template.clone();
91
92        for (key, value) in variables {
93            let placeholder = format!("{{{{{}}}}}", key);
94            result = result.replace(&placeholder, value);
95        }
96
97        result
98    }
99
100    /// Extract all variable names from the template
101    pub fn extract_variables(&self) -> Vec<String> {
102        let mut variables = Vec::new();
103        let mut chars = self.template.chars().peekable();
104
105        while let Some(c) = chars.next() {
106            if c == '{' {
107                if let Some(&'{') = chars.peek() {
108                    chars.next(); // consume second '{'
109                    let mut var_name = String::new();
110
111                    // Read until we find '}}'
112                    while let Some(ch) = chars.next() {
113                        if ch == '}' {
114                            if let Some(&'}') = chars.peek() {
115                                chars.next(); // consume second '}'
116                                variables.push(var_name);
117                                break;
118                            }
119                        }
120                        var_name.push(ch);
121                    }
122                }
123            }
124        }
125
126        variables
127    }
128}
129
130/// Template library for common prompt patterns
131pub struct TemplateLibrary {
132    templates: HashMap<String, PromptTemplate>,
133}
134
135impl Default for TemplateLibrary {
136    fn default() -> Self {
137        Self::new()
138    }
139}
140
141impl TemplateLibrary {
142    /// Create a new template library with common templates
143    pub fn new() -> Self {
144        let mut library = Self {
145            templates: HashMap::new(),
146        };
147
148        // Add common templates
149        library.add_common_templates();
150        library
151    }
152
153    /// Add a template to the library
154    pub fn add(&mut self, name: String, template: PromptTemplate) {
155        self.templates.insert(name, template);
156    }
157
158    /// Get a template by name
159    pub fn get(&self, name: &str) -> Option<&PromptTemplate> {
160        self.templates.get(name)
161    }
162
163    /// Remove a template by name
164    pub fn remove(&mut self, name: &str) -> Option<PromptTemplate> {
165        self.templates.remove(name)
166    }
167
168    /// List all template names
169    pub fn list(&self) -> Vec<String> {
170        self.templates.keys().cloned().collect()
171    }
172
173    /// Add common templates to the library
174    fn add_common_templates(&mut self) {
175        // Code review template
176        self.add(
177            "code_review".to_string(),
178            PromptTemplate::new(
179                "Review the following {{language}} code and provide feedback on:\n\
180                 1. Code quality and best practices\n\
181                 2. Potential bugs or issues\n\
182                 3. Performance improvements\n\
183                 4. Security concerns\n\n\
184                 Code:\n```{{language}}\n{{code}}\n```"
185                    .to_string(),
186            )
187            .with_required_vars(vec!["language".to_string(), "code".to_string()])
188            .with_description("Code review template for analyzing code quality".to_string()),
189        );
190
191        // Summarization template
192        self.add(
193            "summarize".to_string(),
194            PromptTemplate::new(
195                "Summarize the following text in {{style}} style:\n\n{{text}}".to_string(),
196            )
197            .with_required_vars(vec!["text".to_string()])
198            .with_description("Text summarization template".to_string()),
199        );
200
201        // Question answering template
202        self.add(
203            "qa".to_string(),
204            PromptTemplate::new(
205                "Context:\n{{context}}\n\nQuestion: {{question}}\n\nAnswer:".to_string(),
206            )
207            .with_required_vars(vec!["context".to_string(), "question".to_string()])
208            .with_description("Question answering with context".to_string()),
209        );
210
211        // Translation template
212        self.add(
213            "translate".to_string(),
214            PromptTemplate::new(
215                "Translate the following text from {{source_lang}} to {{target_lang}}:\n\n{{text}}"
216                    .to_string(),
217            )
218            .with_required_vars(vec![
219                "source_lang".to_string(),
220                "target_lang".to_string(),
221                "text".to_string(),
222            ])
223            .with_description("Language translation template".to_string()),
224        );
225
226        // Text classification template
227        self.add(
228            "classify".to_string(),
229            PromptTemplate::new(
230                "Classify the following text into one of these categories: {{categories}}\n\n\
231                 Text: {{text}}\n\n\
232                 Category:"
233                    .to_string(),
234            )
235            .with_required_vars(vec!["categories".to_string(), "text".to_string()])
236            .with_description("Text classification template".to_string()),
237        );
238
239        // Data extraction template
240        self.add(
241            "extract".to_string(),
242            PromptTemplate::new(
243                "Extract the following information from the text:\n{{fields}}\n\n\
244                 Text: {{text}}\n\n\
245                 Extracted information (as JSON):"
246                    .to_string(),
247            )
248            .with_required_vars(vec!["fields".to_string(), "text".to_string()])
249            .with_description("Structured data extraction template".to_string()),
250        );
251
252        // Chain of thought template
253        self.add(
254            "chain_of_thought".to_string(),
255            PromptTemplate::new(
256                "{{task}}\n\n\
257                 Let's approach this step-by-step:\n\
258                 1. First, let's understand what we know\n\
259                 2. Then, let's identify what we need to find\n\
260                 3. Finally, let's solve the problem\n\n\
261                 Input: {{input}}"
262                    .to_string(),
263            )
264            .with_required_vars(vec!["task".to_string(), "input".to_string()])
265            .with_description("Chain of thought reasoning template".to_string()),
266        );
267
268        // Few-shot learning template
269        self.add(
270            "few_shot".to_string(),
271            PromptTemplate::new(
272                "{{task_description}}\n\n\
273                 Examples:\n{{examples}}\n\n\
274                 Now, apply the same pattern:\n{{input}}"
275                    .to_string(),
276            )
277            .with_required_vars(vec![
278                "task_description".to_string(),
279                "examples".to_string(),
280                "input".to_string(),
281            ])
282            .with_description("Few-shot learning template".to_string()),
283        );
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a variable map from `(name, value)` pairs.
    fn vars_from(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(name, value)| (name.to_string(), value.to_string()))
            .collect()
    }

    #[test]
    fn test_template_render() {
        let template =
            PromptTemplate::new("Hello {{name}}, you are {{age}} years old.".to_string());
        let vars = vars_from(&[("name", "Alice"), ("age", "30")]);

        let result = template.render(&vars).unwrap();
        assert_eq!(result, "Hello Alice, you are 30 years old.");
    }

    #[test]
    fn test_template_missing_variable() {
        let template = PromptTemplate::new("Hello {{name}}".to_string())
            .with_required_vars(vec!["name".to_string()]);

        let vars = HashMap::new();
        let result = template.render(&vars);
        assert!(result.is_err());
        assert!(matches!(result, Err(TemplateError::MissingVariable(_))));
    }

    #[test]
    fn test_strict_render_reports_unrequired_placeholder() {
        // `name` is not in required_vars, but strict rendering still rejects a
        // placeholder that has no value.
        let template = PromptTemplate::new("Hi {{name}}".to_string());
        match template.render(&HashMap::new()) {
            Err(TemplateError::MissingVariable(name)) => assert_eq!(name, "name"),
            other => panic!("unexpected result: {:?}", other),
        }
    }

    #[test]
    fn test_error_message() {
        let err = TemplateError::MissingVariable("code".to_string());
        assert_eq!(err.to_string(), "Missing variable: code");
    }

    #[test]
    fn test_template_partial_render() {
        let template = PromptTemplate::new("Hello {{name}}, {{greeting}}".to_string());
        let vars = vars_from(&[("name", "Bob")]);

        let result = template.render_partial(&vars);
        assert_eq!(result, "Hello Bob, {{greeting}}");
    }

    #[test]
    fn test_extract_variables() {
        let template = PromptTemplate::new("{{var1}} and {{var2}} and {{var3}}".to_string());
        let vars = template.extract_variables();
        assert_eq!(vars.len(), 3);
        for expected in ["var1", "var2", "var3"].iter() {
            assert!(vars.contains(&expected.to_string()));
        }
    }

    #[test]
    fn test_template_library() {
        let library = TemplateLibrary::new();

        // Test that common templates are loaded
        for name in ["code_review", "summarize", "qa", "translate"].iter() {
            assert!(library.get(name).is_some(), "missing template {}", name);
        }

        let code_review = library.get("code_review").unwrap();
        let vars_needed = code_review.extract_variables();
        assert!(vars_needed.contains(&"language".to_string()));
        assert!(vars_needed.contains(&"code".to_string()));
    }

    #[test]
    fn test_library_list_and_remove() {
        let mut library = TemplateLibrary::new();
        assert!(library.list().contains(&"few_shot".to_string()));

        assert!(library.remove("qa").is_some());
        assert!(library.get("qa").is_none());
        assert!(!library.list().contains(&"qa".to_string()));
        // Removing again is a no-op.
        assert!(library.remove("qa").is_none());
    }

    #[test]
    fn test_code_review_template() {
        let library = TemplateLibrary::new();
        let template = library.get("code_review").unwrap();
        let vars = vars_from(&[
            ("language", "Rust"),
            ("code", "fn main() { println!(\"Hello\"); }"),
        ]);

        let result = template.render(&vars).unwrap();
        assert!(result.contains("Rust"));
        assert!(result.contains("fn main()"));
    }

    #[test]
    fn test_qa_template() {
        let library = TemplateLibrary::new();
        let template = library.get("qa").unwrap();
        let vars = vars_from(&[
            ("context", "The sky is blue."),
            ("question", "What color is the sky?"),
        ]);

        let result = template.render(&vars).unwrap();
        assert!(result.contains("Context"));
        assert!(result.contains("The sky is blue"));
        assert!(result.contains("What color is the sky?"));
    }

    #[test]
    fn test_custom_template_addition() {
        let mut library = TemplateLibrary::new();

        let custom = PromptTemplate::new("Custom: {{value}}".to_string());
        library.add("custom".to_string(), custom);

        let template = library.get("custom").expect("custom template was just added");
        let vars = vars_from(&[("value", "test")]);

        let result = template.render(&vars).unwrap();
        assert_eq!(result, "Custom: test");
    }
}