1use crate::{Error, Result};
2use pforge_config::{ParamType, PromptDef};
3use serde_json::Value;
4use std::collections::HashMap;
5
6pub struct PromptManager {
8 prompts: HashMap<String, PromptEntry>,
9}
10
11struct PromptEntry {
12 description: String,
13 template: String,
14 arguments: HashMap<String, ParamType>,
15}
16
17impl PromptManager {
18 pub fn new() -> Self {
19 Self {
20 prompts: HashMap::new(),
21 }
22 }
23
24 pub fn register(&mut self, def: PromptDef) -> Result<()> {
26 if self.prompts.contains_key(&def.name) {
27 return Err(Error::Handler(format!(
28 "Prompt '{}' already registered",
29 def.name
30 )));
31 }
32
33 self.prompts.insert(
34 def.name.clone(),
35 PromptEntry {
36 description: def.description,
37 template: def.template,
38 arguments: def.arguments,
39 },
40 );
41
42 Ok(())
43 }
44
45 pub fn render(&self, name: &str, args: HashMap<String, Value>) -> Result<String> {
47 let entry = self
48 .prompts
49 .get(name)
50 .ok_or_else(|| Error::Handler(format!("Prompt '{}' not found", name)))?;
51
52 self.validate_arguments(entry, &args)?;
54
55 self.interpolate(&entry.template, &args)
57 }
58
59 pub fn get_prompt(&self, name: &str) -> Option<PromptMetadata> {
61 self.prompts.get(name).map(|entry| PromptMetadata {
62 description: entry.description.clone(),
63 arguments: entry.arguments.clone(),
64 })
65 }
66
67 pub fn list_prompts(&self) -> Vec<String> {
69 self.prompts.keys().cloned().collect()
70 }
71
72 fn validate_arguments(
74 &self,
75 entry: &PromptEntry,
76 args: &HashMap<String, Value>,
77 ) -> Result<()> {
78 for (arg_name, param_type) in &entry.arguments {
80 let is_required = match param_type {
81 ParamType::Complex { required, .. } => *required,
82 _ => false,
83 };
84
85 if is_required && !args.contains_key(arg_name) {
86 return Err(Error::Handler(format!(
87 "Required argument '{}' not provided",
88 arg_name
89 )));
90 }
91 }
92
93 Ok(())
95 }
96
97 fn interpolate(&self, template: &str, args: &HashMap<String, Value>) -> Result<String> {
100 let mut result = template.to_string();
101
102 for (key, value) in args {
103 let placeholder = format!("{{{{{}}}}}", key);
104 let replacement = match value {
105 Value::String(s) => s.clone(),
106 Value::Number(n) => n.to_string(),
107 Value::Bool(b) => b.to_string(),
108 Value::Null => String::new(),
109 _ => serde_json::to_string(value)
110 .map_err(|e| Error::Handler(format!("Failed to serialize value: {}", e)))?,
111 };
112
113 result = result.replace(&placeholder, &replacement);
114 }
115
116 if result.contains("{{") && result.contains("}}") {
118 let unresolved: Vec<&str> = result
120 .split("{{")
121 .skip(1)
122 .filter_map(|s| s.split("}}").next())
123 .collect();
124
125 if !unresolved.is_empty() {
126 return Err(Error::Handler(format!(
127 "Unresolved template variables: {}",
128 unresolved.join(", ")
129 )));
130 }
131 }
132
133 Ok(result)
134 }
135}
136
137impl Default for PromptManager {
138 fn default() -> Self {
139 Self::new()
140 }
141}
142
143#[derive(Debug, Clone)]
145pub struct PromptMetadata {
146 pub description: String,
147 pub arguments: HashMap<String, ParamType>,
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use pforge_config::SimpleType;
    use serde_json::json;

    /// Builds a `PromptDef` from borrowed field values.
    fn prompt_def(
        name: &str,
        description: &str,
        template: &str,
        arguments: HashMap<String, ParamType>,
    ) -> PromptDef {
        PromptDef {
            name: name.to_string(),
            description: description.to_string(),
            template: template.to_string(),
            arguments,
        }
    }

    /// Declares a single required string argument called `name`.
    fn required_name_arg(description: Option<&str>) -> HashMap<String, ParamType> {
        let mut declared = HashMap::new();
        declared.insert(
            "name".to_string(),
            ParamType::Complex {
                ty: SimpleType::String,
                required: true,
                default: None,
                description: description.map(str::to_string),
                validation: None,
            },
        );
        declared
    }

    /// Collects `(key, value)` pairs into a render-argument map.
    fn arg_map(pairs: &[(&str, Value)]) -> HashMap<String, Value> {
        pairs
            .iter()
            .map(|(key, value)| (String::from(*key), value.clone()))
            .collect()
    }

    #[test]
    fn test_prompt_registration() {
        let mut manager = PromptManager::new();
        manager
            .register(prompt_def(
                "greeting",
                "A simple greeting prompt",
                "Hello, {{name}}!",
                HashMap::new(),
            ))
            .unwrap();

        assert_eq!(manager.list_prompts(), vec!["greeting"]);
    }

    #[test]
    fn test_duplicate_prompt_registration() {
        let mut manager = PromptManager::new();
        let def = prompt_def("test", "Test", "{{x}}", HashMap::new());

        manager.register(def.clone()).unwrap();
        let err = manager.register(def).unwrap_err();

        assert!(err.to_string().contains("already registered"));
    }

    #[test]
    fn test_simple_interpolation() {
        let mut manager = PromptManager::new();
        manager
            .register(prompt_def(
                "greeting",
                "Greeting",
                "Hello, {{name}}! You are {{age}} years old.",
                HashMap::new(),
            ))
            .unwrap();

        let rendered = manager
            .render(
                "greeting",
                arg_map(&[("name", json!("Alice")), ("age", json!(30))]),
            )
            .unwrap();

        assert_eq!(rendered, "Hello, Alice! You are 30 years old.");
    }

    #[test]
    fn test_required_argument_validation() {
        let mut manager = PromptManager::new();
        manager
            .register(prompt_def(
                "greeting",
                "Greeting",
                "Hello, {{name}}!",
                required_name_arg(None),
            ))
            .unwrap();

        let err = manager.render("greeting", HashMap::new()).unwrap_err();

        assert!(err.to_string().contains("Required argument"));
    }

    #[test]
    fn test_unresolved_placeholder() {
        let mut manager = PromptManager::new();
        manager
            .register(prompt_def(
                "test",
                "Test",
                "Hello, {{name}}! Welcome to {{location}}.",
                HashMap::new(),
            ))
            .unwrap();

        let err = manager
            .render("test", arg_map(&[("name", json!("Alice"))]))
            .unwrap_err();

        assert!(err
            .to_string()
            .contains("Unresolved template variables"));
    }

    #[test]
    fn test_get_prompt_metadata() {
        let mut manager = PromptManager::new();
        manager
            .register(prompt_def(
                "greeting",
                "A greeting prompt",
                "Hello, {{name}}!",
                required_name_arg(Some("User name")),
            ))
            .unwrap();

        let metadata = manager
            .get_prompt("greeting")
            .expect("prompt was registered just above");

        assert_eq!(metadata.description, "A greeting prompt");
        assert!(metadata.arguments.contains_key("name"));
    }

    #[test]
    fn test_complex_value_interpolation() {
        let mut manager = PromptManager::new();
        manager
            .register(prompt_def(
                "test",
                "Test",
                "String: {{str}}, Number: {{num}}, Bool: {{bool}}",
                HashMap::new(),
            ))
            .unwrap();

        let rendered = manager
            .render(
                "test",
                arg_map(&[
                    ("str", json!("hello")),
                    ("num", json!(42)),
                    ("bool", json!(true)),
                ]),
            )
            .unwrap();

        assert_eq!(rendered, "String: hello, Number: 42, Bool: true");
    }
}