1use crate::{Error, Result};
2use pforge_config::{ParamType, PromptDef};
3use serde_json::Value;
4use std::collections::HashMap;
5
6pub struct PromptManager {
8 prompts: HashMap<String, PromptEntry>,
9}
10
11struct PromptEntry {
12 description: String,
13 template: String,
14 arguments: HashMap<String, ParamType>,
15}
16
17impl PromptManager {
18 pub fn new() -> Self {
19 Self {
20 prompts: HashMap::new(),
21 }
22 }
23
24 pub fn register(&mut self, def: PromptDef) -> Result<()> {
26 if self.prompts.contains_key(&def.name) {
27 return Err(Error::Handler(format!(
28 "Prompt '{}' already registered",
29 def.name
30 )));
31 }
32
33 self.prompts.insert(
34 def.name.clone(),
35 PromptEntry {
36 description: def.description,
37 template: def.template,
38 arguments: def.arguments,
39 },
40 );
41
42 Ok(())
43 }
44
45 pub fn render(&self, name: &str, args: HashMap<String, Value>) -> Result<String> {
47 let entry = self
48 .prompts
49 .get(name)
50 .ok_or_else(|| Error::Handler(format!("Prompt '{}' not found", name)))?;
51
52 self.validate_arguments(entry, &args)?;
54
55 self.interpolate(&entry.template, &args)
57 }
58
59 pub fn get_prompt(&self, name: &str) -> Option<PromptMetadata> {
61 self.prompts.get(name).map(|entry| PromptMetadata {
62 description: entry.description.clone(),
63 arguments: entry.arguments.clone(),
64 })
65 }
66
67 pub fn list_prompts(&self) -> Vec<String> {
69 self.prompts.keys().cloned().collect()
70 }
71
72 fn validate_arguments(&self, entry: &PromptEntry, args: &HashMap<String, Value>) -> Result<()> {
74 for (arg_name, param_type) in &entry.arguments {
76 let is_required = match param_type {
77 ParamType::Complex { required, .. } => *required,
78 _ => false,
79 };
80
81 if is_required && !args.contains_key(arg_name) {
82 return Err(Error::Handler(format!(
83 "Required argument '{}' not provided",
84 arg_name
85 )));
86 }
87 }
88
89 Ok(())
91 }
92
93 fn interpolate(&self, template: &str, args: &HashMap<String, Value>) -> Result<String> {
96 let mut result = template.to_string();
97
98 for (key, value) in args {
99 let placeholder = format!("{{{{{}}}}}", key);
100 let replacement = match value {
101 Value::String(s) => s.clone(),
102 Value::Number(n) => n.to_string(),
103 Value::Bool(b) => b.to_string(),
104 Value::Null => String::new(),
105 _ => serde_json::to_string(value)
106 .map_err(|e| Error::Handler(format!("Failed to serialize value: {}", e)))?,
107 };
108
109 result = result.replace(&placeholder, &replacement);
110 }
111
112 if result.contains("{{") && result.contains("}}") {
114 let unresolved: Vec<&str> = result
116 .split("{{")
117 .skip(1)
118 .filter_map(|s| s.split("}}").next())
119 .collect();
120
121 if !unresolved.is_empty() {
122 return Err(Error::Handler(format!(
123 "Unresolved template variables: {}",
124 unresolved.join(", ")
125 )));
126 }
127 }
128
129 Ok(result)
130 }
131}
132
133impl Default for PromptManager {
134 fn default() -> Self {
135 Self::new()
136 }
137}
138
139#[derive(Debug, Clone)]
141pub struct PromptMetadata {
142 pub description: String,
143 pub arguments: HashMap<String, ParamType>,
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use pforge_config::SimpleType;
    use serde_json::json;

    /// Builds a prompt definition from borrowed strings and an argument map.
    fn prompt_def(
        name: &str,
        description: &str,
        template: &str,
        arguments: HashMap<String, ParamType>,
    ) -> PromptDef {
        PromptDef {
            name: name.to_string(),
            description: description.to_string(),
            template: template.to_string(),
            arguments,
        }
    }

    /// Declares a single required string argument called `name`.
    fn required_name_argument(description: Option<&str>) -> HashMap<String, ParamType> {
        let declaration = ParamType::Complex {
            ty: SimpleType::String,
            required: true,
            default: None,
            description: description.map(String::from),
            validation: None,
        };
        std::iter::once(("name".to_string(), declaration)).collect()
    }

    /// Returns a manager that already has `def` registered.
    fn manager_with(def: PromptDef) -> PromptManager {
        let mut manager = PromptManager::new();
        manager.register(def).unwrap();
        manager
    }

    /// Asserts that `result` is an error whose message contains `needle`.
    fn assert_error_contains<T: std::fmt::Debug>(result: Result<T>, needle: &str) {
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains(needle));
    }

    #[test]
    fn test_prompt_registration() {
        let manager = manager_with(prompt_def(
            "greeting",
            "A simple greeting prompt",
            "Hello, {{name}}!",
            HashMap::new(),
        ));

        assert_eq!(manager.list_prompts(), vec!["greeting"]);
    }

    #[test]
    fn test_duplicate_prompt_registration() {
        let def = prompt_def("test", "Test", "{{x}}", HashMap::new());
        let mut manager = manager_with(def.clone());

        // Registering the same name a second time must be rejected.
        assert_error_contains(manager.register(def), "already registered");
    }

    #[test]
    fn test_simple_interpolation() {
        let manager = manager_with(prompt_def(
            "greeting",
            "Greeting",
            "Hello, {{name}}! You are {{age}} years old.",
            HashMap::new(),
        ));

        let args: HashMap<String, Value> = vec![
            ("name".to_string(), json!("Alice")),
            ("age".to_string(), json!(30)),
        ]
        .into_iter()
        .collect();

        let rendered = manager.render("greeting", args).unwrap();
        assert_eq!(rendered, "Hello, Alice! You are 30 years old.");
    }

    #[test]
    fn test_required_argument_validation() {
        let manager = manager_with(prompt_def(
            "greeting",
            "Greeting",
            "Hello, {{name}}!",
            required_name_argument(None),
        ));

        // No arguments supplied although `name` is required.
        assert_error_contains(
            manager.render("greeting", HashMap::new()),
            "Required argument",
        );
    }

    #[test]
    fn test_unresolved_placeholder() {
        let manager = manager_with(prompt_def(
            "test",
            "Test",
            "Hello, {{name}}! Welcome to {{location}}.",
            HashMap::new(),
        ));

        // Only `name` is supplied, so `location` stays unresolved.
        let args: HashMap<String, Value> =
            std::iter::once(("name".to_string(), json!("Alice"))).collect();

        assert_error_contains(manager.render("test", args), "Unresolved template variables");
    }

    #[test]
    fn test_get_prompt_metadata() {
        let manager = manager_with(prompt_def(
            "greeting",
            "A greeting prompt",
            "Hello, {{name}}!",
            required_name_argument(Some("User name")),
        ));

        let metadata = manager.get_prompt("greeting").unwrap();
        assert_eq!(metadata.description, "A greeting prompt");
        assert!(metadata.arguments.contains_key("name"));
    }

    #[test]
    fn test_complex_value_interpolation() {
        let manager = manager_with(prompt_def(
            "test",
            "Test",
            "String: {{str}}, Number: {{num}}, Bool: {{bool}}",
            HashMap::new(),
        ));

        let args: HashMap<String, Value> = vec![
            ("str".to_string(), json!("hello")),
            ("num".to_string(), json!(42)),
            ("bool".to_string(), json!(true)),
        ]
        .into_iter()
        .collect();

        let rendered = manager.render("test", args).unwrap();
        assert_eq!(rendered, "String: hello, Number: 42, Bool: true");
    }
}