1use parking_lot::RwLock;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9use crate::types::Layer2Result;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct PromptTemplate {
14 pub name: String,
15 pub template: String,
16 pub description: String,
17 pub variables: Vec<String>,
18 pub metadata: HashMap<String, String>,
19}
20
21impl PromptTemplate {
22 pub fn new(name: impl Into<String>, template: impl Into<String>) -> Self {
23 Self {
24 name: name.into(),
25 template: template.into(),
26 description: String::new(),
27 variables: Vec::new(),
28 metadata: HashMap::new(),
29 }
30 }
31
32 pub fn with_description(mut self, description: impl Into<String>) -> Self {
33 self.description = description.into();
34 self
35 }
36
37 pub fn render(&self, context: &HashMap<String, String>) -> String {
39 let mut result = self.template.clone();
40
41 for var in &self.variables {
42 if let Some(value) = context.get(var) {
43 result = result.replace(&format!("{{{{{}}}}}", var), value);
44 }
45 }
46
47 result
48 }
49
50 pub fn extract_variables(&mut self) {
52 use regex::Regex;
53
54 let re = Regex::new(r"\{\{(\w+)\}\}").unwrap();
55 self.variables = re
56 .captures_iter(&self.template)
57 .filter_map(|cap| cap.get(1).map(|m| m.as_str().to_string()))
58 .collect();
59 }
60}
61
/// Registry of named [`PromptTemplate`]s that can be looked up and rendered.
///
/// The `Send + Sync` supertraits allow an implementor to be shared across threads.
pub trait PromptManagerTrait: Send + Sync {
    /// Registers `template`. `PromptManager` keys it by `template.name` and
    /// overwrites any existing entry with that name.
    fn register(&self, template: PromptTemplate) -> Layer2Result<()>;

    /// Removes the template named `name`. `PromptManager` returns `Ok(true)`
    /// if a template was removed and `Ok(false)` if none was registered.
    fn unregister(&self, name: &str) -> Layer2Result<bool>;

    /// Returns an owned copy of the template named `name`, or `None` if it is
    /// not registered.
    fn get(&self, name: &str) -> Option<PromptTemplate>;

    /// Renders the template named `name` using `context` for placeholder
    /// values. `PromptManager` returns an error when no such template exists.
    fn render(&self, name: &str, context: &HashMap<String, String>) -> Layer2Result<String>;

    /// Returns the names of all registered templates (order is
    /// implementation-defined).
    fn list(&self) -> Vec<String>;

    /// Returns the number of registered templates.
    fn count(&self) -> usize;
}
82
/// In-memory [`PromptManagerTrait`] implementation that stores templates in a
/// `HashMap` keyed by template name.
pub struct PromptManager {
    /// Registered templates keyed by name. The `RwLock` lets the `&self`
    /// trait methods mutate the map: reads take a shared lock and
    /// register/unregister take an exclusive one.
    templates: RwLock<HashMap<String, PromptTemplate>>,
}
87
88impl PromptManager {
89 pub fn new() -> Self {
90 Self {
91 templates: RwLock::new(HashMap::new()),
92 }
93 }
94
95 pub fn with_defaults() -> Self {
97 let manager = Self::new();
98
99 manager.register_default_templates();
101
102 manager
103 }
104
105 fn register_default_templates(&self) {
106 let system = PromptTemplate::new(
107 "system",
108 "You are a helpful AI assistant. Be concise and accurate.",
109 )
110 .with_description("Default system prompt");
111
112 let code_review = PromptTemplate::new(
113 "code_review",
114 "Review the following code and provide feedback:\n\n{{code}}\n\nFocus on: {{focus_areas}}"
115 )
116 .with_description("Code review prompt template");
117
118 let task = PromptTemplate::new(
119 "task",
120 "Task: {{task_description}}\n\nContext: {{context}}\n\nPlease complete this task.",
121 )
122 .with_description("General task prompt template");
123
124 let _ = self.register(system);
125 let _ = self.register(code_review);
126 let _ = self.register(task);
127 }
128}
129
130impl Default for PromptManager {
131 fn default() -> Self {
132 Self::with_defaults()
133 }
134}
135
136impl PromptManagerTrait for PromptManager {
137 fn register(&self, template: PromptTemplate) -> Layer2Result<()> {
138 let name = template.name.clone();
139 self.templates.write().insert(name, template);
140 Ok(())
141 }
142
143 fn unregister(&self, name: &str) -> Layer2Result<bool> {
144 Ok(self.templates.write().remove(name).is_some())
145 }
146
147 fn get(&self, name: &str) -> Option<PromptTemplate> {
148 self.templates.read().get(name).cloned()
149 }
150
151 fn render(&self, name: &str, context: &HashMap<String, String>) -> Layer2Result<String> {
152 let templates = self.templates.read();
153
154 let template = templates
155 .get(name)
156 .ok_or_else(|| anyhow::anyhow!("Template not found: {}", name))?;
157
158 Ok(template.render(context))
159 }
160
161 fn list(&self) -> Vec<String> {
162 self.templates.read().keys().cloned().collect()
163 }
164
165 fn count(&self) -> usize {
166 self.templates.read().len()
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prompt_template() {
        let mut template = PromptTemplate::new("test", "Hello {{name}}!");
        template.extract_variables();

        let mut context = HashMap::new();
        context.insert("name".to_string(), "World".to_string());

        let result = template.render(&context);
        assert_eq!(result, "Hello World!");
    }

    #[test]
    fn test_prompt_manager() {
        let manager = PromptManager::new();

        let template = PromptTemplate::new("test", "Test template");
        manager.register(template).unwrap();

        assert_eq!(manager.count(), 1);
        assert!(manager.get("test").is_some());
    }

    #[test]
    fn test_prompt_manager_defaults() {
        let manager = PromptManager::with_defaults();

        assert!(manager.get("system").is_some());
        assert!(manager.get("code_review").is_some());
        assert!(manager.get("task").is_some());
    }

    #[test]
    fn test_render_template() {
        let manager = PromptManager::with_defaults();
        assert!(!manager.list().is_empty());

        // `system` has no placeholders, so rendering returns its text unchanged.
        let rendered = manager.render("system", &HashMap::new()).unwrap();
        assert_eq!(
            rendered,
            "You are a helpful AI assistant. Be concise and accurate."
        );
    }

    #[test]
    fn test_render_missing_template_errors() {
        let manager = PromptManager::new();
        assert!(manager.render("missing", &HashMap::new()).is_err());
    }

    #[test]
    fn test_unregister() {
        let manager = PromptManager::new();
        manager
            .register(PromptTemplate::new("test", "Test template"))
            .unwrap();

        assert!(manager.unregister("test").unwrap());
        assert!(!manager.unregister("test").unwrap());
        assert_eq!(manager.count(), 0);
    }
}