1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use thiserror::Error;
9
10#[derive(Error, Debug)]
11pub enum TemplateError {
12 #[error("Missing variable: {0}")]
13 MissingVariable(String),
14
15 #[error("Template parsing error: {0}")]
16 ParseError(String),
17
18 #[error("Invalid template syntax: {0}")]
19 SyntaxError(String),
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct PromptTemplate {
25 template: String,
27 #[serde(default)]
29 required_vars: Vec<String>,
30 #[serde(default)]
32 description: Option<String>,
33}
34
35impl PromptTemplate {
36 pub fn new(template: String) -> Self {
38 Self {
39 template,
40 required_vars: Vec::new(),
41 description: None,
42 }
43 }
44
45 pub fn with_required_vars(mut self, vars: Vec<String>) -> Self {
47 self.required_vars = vars;
48 self
49 }
50
51 pub fn with_description(mut self, description: String) -> Self {
53 self.description = Some(description);
54 self
55 }
56
57 pub fn render(
59 &self,
60 variables: &HashMap<String, String>,
61 ) -> std::result::Result<String, TemplateError> {
62 for var in &self.required_vars {
64 if !variables.contains_key(var) {
65 return Err(TemplateError::MissingVariable(var.clone()));
66 }
67 }
68
69 let mut result = self.template.clone();
70
71 for (key, value) in variables {
73 let placeholder = format!("{{{{{}}}}}", key);
74 result = result.replace(&placeholder, value);
75 }
76
77 if result.contains("{{") && result.contains("}}") {
79 let start = result.find("{{").unwrap();
80 let end = result[start..].find("}}").unwrap() + start + 2;
81 let var_name = &result[start + 2..end - 2];
82 return Err(TemplateError::MissingVariable(var_name.to_string()));
83 }
84
85 Ok(result)
86 }
87
88 pub fn render_partial(&self, variables: &HashMap<String, String>) -> String {
90 let mut result = self.template.clone();
91
92 for (key, value) in variables {
93 let placeholder = format!("{{{{{}}}}}", key);
94 result = result.replace(&placeholder, value);
95 }
96
97 result
98 }
99
100 pub fn extract_variables(&self) -> Vec<String> {
102 let mut variables = Vec::new();
103 let mut chars = self.template.chars().peekable();
104
105 while let Some(c) = chars.next() {
106 if c == '{' {
107 if let Some(&'{') = chars.peek() {
108 chars.next(); let mut var_name = String::new();
110
111 while let Some(ch) = chars.next() {
113 if ch == '}' {
114 if let Some(&'}') = chars.peek() {
115 chars.next(); variables.push(var_name);
117 break;
118 }
119 }
120 var_name.push(ch);
121 }
122 }
123 }
124 }
125
126 variables
127 }
128}
129
130pub struct TemplateLibrary {
132 templates: HashMap<String, PromptTemplate>,
133}
134
135impl Default for TemplateLibrary {
136 fn default() -> Self {
137 Self::new()
138 }
139}
140
141impl TemplateLibrary {
142 pub fn new() -> Self {
144 let mut library = Self {
145 templates: HashMap::new(),
146 };
147
148 library.add_common_templates();
150 library
151 }
152
153 pub fn add(&mut self, name: String, template: PromptTemplate) {
155 self.templates.insert(name, template);
156 }
157
158 pub fn get(&self, name: &str) -> Option<&PromptTemplate> {
160 self.templates.get(name)
161 }
162
163 pub fn remove(&mut self, name: &str) -> Option<PromptTemplate> {
165 self.templates.remove(name)
166 }
167
168 pub fn list(&self) -> Vec<String> {
170 self.templates.keys().cloned().collect()
171 }
172
173 fn add_common_templates(&mut self) {
175 self.add(
177 "code_review".to_string(),
178 PromptTemplate::new(
179 "Review the following {{language}} code and provide feedback on:\n\
180 1. Code quality and best practices\n\
181 2. Potential bugs or issues\n\
182 3. Performance improvements\n\
183 4. Security concerns\n\n\
184 Code:\n```{{language}}\n{{code}}\n```"
185 .to_string(),
186 )
187 .with_required_vars(vec!["language".to_string(), "code".to_string()])
188 .with_description("Code review template for analyzing code quality".to_string()),
189 );
190
191 self.add(
193 "summarize".to_string(),
194 PromptTemplate::new(
195 "Summarize the following text in {{style}} style:\n\n{{text}}".to_string(),
196 )
197 .with_required_vars(vec!["text".to_string()])
198 .with_description("Text summarization template".to_string()),
199 );
200
201 self.add(
203 "qa".to_string(),
204 PromptTemplate::new(
205 "Context:\n{{context}}\n\nQuestion: {{question}}\n\nAnswer:".to_string(),
206 )
207 .with_required_vars(vec!["context".to_string(), "question".to_string()])
208 .with_description("Question answering with context".to_string()),
209 );
210
211 self.add(
213 "translate".to_string(),
214 PromptTemplate::new(
215 "Translate the following text from {{source_lang}} to {{target_lang}}:\n\n{{text}}"
216 .to_string(),
217 )
218 .with_required_vars(vec![
219 "source_lang".to_string(),
220 "target_lang".to_string(),
221 "text".to_string(),
222 ])
223 .with_description("Language translation template".to_string()),
224 );
225
226 self.add(
228 "classify".to_string(),
229 PromptTemplate::new(
230 "Classify the following text into one of these categories: {{categories}}\n\n\
231 Text: {{text}}\n\n\
232 Category:"
233 .to_string(),
234 )
235 .with_required_vars(vec!["categories".to_string(), "text".to_string()])
236 .with_description("Text classification template".to_string()),
237 );
238
239 self.add(
241 "extract".to_string(),
242 PromptTemplate::new(
243 "Extract the following information from the text:\n{{fields}}\n\n\
244 Text: {{text}}\n\n\
245 Extracted information (as JSON):"
246 .to_string(),
247 )
248 .with_required_vars(vec!["fields".to_string(), "text".to_string()])
249 .with_description("Structured data extraction template".to_string()),
250 );
251
252 self.add(
254 "chain_of_thought".to_string(),
255 PromptTemplate::new(
256 "{{task}}\n\n\
257 Let's approach this step-by-step:\n\
258 1. First, let's understand what we know\n\
259 2. Then, let's identify what we need to find\n\
260 3. Finally, let's solve the problem\n\n\
261 Input: {{input}}"
262 .to_string(),
263 )
264 .with_required_vars(vec!["task".to_string(), "input".to_string()])
265 .with_description("Chain of thought reasoning template".to_string()),
266 );
267
268 self.add(
270 "few_shot".to_string(),
271 PromptTemplate::new(
272 "{{task_description}}\n\n\
273 Examples:\n{{examples}}\n\n\
274 Now, apply the same pattern:\n{{input}}"
275 .to_string(),
276 )
277 .with_required_vars(vec![
278 "task_description".to_string(),
279 "examples".to_string(),
280 "input".to_string(),
281 ])
282 .with_description("Few-shot learning template".to_string()),
283 );
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned variable map from borrowed `(name, value)` pairs.
    fn vars_of(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_string(), value.to_string()))
            .collect()
    }

    #[test]
    fn test_template_render() {
        let vars = vars_of(&[("name", "Alice"), ("age", "30")]);
        let rendered =
            PromptTemplate::new("Hello {{name}}, you are {{age}} years old.".to_string())
                .render(&vars)
                .unwrap();
        assert_eq!(rendered, "Hello Alice, you are 30 years old.");
    }

    #[test]
    fn test_template_missing_variable() {
        let outcome = PromptTemplate::new("Hello {{name}}".to_string())
            .with_required_vars(vec!["name".to_string()])
            .render(&HashMap::new());
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(TemplateError::MissingVariable(_))));
    }

    #[test]
    fn test_template_partial_render() {
        let vars = vars_of(&[("name", "Bob")]);
        let partial =
            PromptTemplate::new("Hello {{name}}, {{greeting}}".to_string()).render_partial(&vars);
        assert_eq!(partial, "Hello Bob, {{greeting}}");
    }

    #[test]
    fn test_extract_variables() {
        let found = PromptTemplate::new("{{var1}} and {{var2}} and {{var3}}".to_string())
            .extract_variables();
        assert_eq!(found.len(), 3);
        for expected in ["var1", "var2", "var3"].iter() {
            assert!(found.contains(&expected.to_string()));
        }
    }

    #[test]
    fn test_template_library() {
        let library = TemplateLibrary::new();

        for name in ["code_review", "summarize", "qa", "translate"].iter() {
            assert!(library.get(name).is_some());
        }

        let vars_needed = library.get("code_review").unwrap().extract_variables();
        assert!(vars_needed.contains(&"language".to_string()));
        assert!(vars_needed.contains(&"code".to_string()));
    }

    #[test]
    fn test_code_review_template() {
        let library = TemplateLibrary::new();
        let vars = vars_of(&[
            ("language", "Rust"),
            ("code", "fn main() { println!(\"Hello\"); }"),
        ]);

        let rendered = library.get("code_review").unwrap().render(&vars).unwrap();
        assert!(rendered.contains("Rust"));
        assert!(rendered.contains("fn main()"));
    }

    #[test]
    fn test_qa_template() {
        let library = TemplateLibrary::new();
        let vars = vars_of(&[
            ("context", "The sky is blue."),
            ("question", "What color is the sky?"),
        ]);

        let prompt = library.get("qa").unwrap().render(&vars).unwrap();
        for fragment in ["Context", "The sky is blue", "What color is the sky?"].iter() {
            assert!(prompt.contains(*fragment));
        }
    }

    #[test]
    fn test_custom_template_addition() {
        let mut library = TemplateLibrary::new();
        library.add(
            "custom".to_string(),
            PromptTemplate::new("Custom: {{value}}".to_string()),
        );
        assert!(library.get("custom").is_some());

        let vars = vars_of(&[("value", "test")]);
        let rendered = library.get("custom").unwrap().render(&vars).unwrap();
        assert_eq!(rendered, "Custom: test");
    }
}