pforge_runtime/
prompt.rs

1use crate::{Error, Result};
2use pforge_config::{ParamType, PromptDef};
3use serde_json::Value;
4use std::collections::HashMap;
5
6/// Prompt manager handles prompt rendering with template interpolation
7pub struct PromptManager {
8    prompts: HashMap<String, PromptEntry>,
9}
10
11struct PromptEntry {
12    description: String,
13    template: String,
14    arguments: HashMap<String, ParamType>,
15}
16
17impl PromptManager {
18    pub fn new() -> Self {
19        Self {
20            prompts: HashMap::new(),
21        }
22    }
23
24    /// Register a prompt definition
25    pub fn register(&mut self, def: PromptDef) -> Result<()> {
26        if self.prompts.contains_key(&def.name) {
27            return Err(Error::Handler(format!(
28                "Prompt '{}' already registered",
29                def.name
30            )));
31        }
32
33        self.prompts.insert(
34            def.name.clone(),
35            PromptEntry {
36                description: def.description,
37                template: def.template,
38                arguments: def.arguments,
39            },
40        );
41
42        Ok(())
43    }
44
45    /// Render a prompt with given arguments
46    pub fn render(&self, name: &str, args: HashMap<String, Value>) -> Result<String> {
47        let entry = self
48            .prompts
49            .get(name)
50            .ok_or_else(|| Error::Handler(format!("Prompt '{}' not found", name)))?;
51
52        // Validate arguments
53        self.validate_arguments(entry, &args)?;
54
55        // Perform template interpolation
56        self.interpolate(&entry.template, &args)
57    }
58
59    /// Get prompt metadata
60    pub fn get_prompt(&self, name: &str) -> Option<PromptMetadata> {
61        self.prompts.get(name).map(|entry| PromptMetadata {
62            description: entry.description.clone(),
63            arguments: entry.arguments.clone(),
64        })
65    }
66
67    /// List all registered prompts
68    pub fn list_prompts(&self) -> Vec<String> {
69        self.prompts.keys().cloned().collect()
70    }
71
72    /// Validate arguments against schema
73    fn validate_arguments(&self, entry: &PromptEntry, args: &HashMap<String, Value>) -> Result<()> {
74        // Check required arguments
75        for (arg_name, param_type) in &entry.arguments {
76            let is_required = match param_type {
77                ParamType::Complex { required, .. } => *required,
78                _ => false,
79            };
80
81            if is_required && !args.contains_key(arg_name) {
82                return Err(Error::Handler(format!(
83                    "Required argument '{}' not provided",
84                    arg_name
85                )));
86            }
87        }
88
89        // Type validation could be added here
90        Ok(())
91    }
92
93    /// Interpolate template with argument values
94    /// Supports {{variable}} syntax
95    fn interpolate(&self, template: &str, args: &HashMap<String, Value>) -> Result<String> {
96        let mut result = template.to_string();
97
98        for (key, value) in args {
99            let placeholder = format!("{{{{{}}}}}", key);
100            let replacement = match value {
101                Value::String(s) => s.clone(),
102                Value::Number(n) => n.to_string(),
103                Value::Bool(b) => b.to_string(),
104                Value::Null => String::new(),
105                _ => serde_json::to_string(value)
106                    .map_err(|e| Error::Handler(format!("Failed to serialize value: {}", e)))?,
107            };
108
109            result = result.replace(&placeholder, &replacement);
110        }
111
112        // Check for unresolved placeholders
113        if result.contains("{{") && result.contains("}}") {
114            // Extract unresolved variable names for better error message
115            let unresolved: Vec<&str> = result
116                .split("{{")
117                .skip(1)
118                .filter_map(|s| s.split("}}").next())
119                .collect();
120
121            if !unresolved.is_empty() {
122                return Err(Error::Handler(format!(
123                    "Unresolved template variables: {}",
124                    unresolved.join(", ")
125                )));
126            }
127        }
128
129        Ok(result)
130    }
131}
132
133impl Default for PromptManager {
134    fn default() -> Self {
135        Self::new()
136    }
137}
138
139/// Prompt metadata for discovery
140#[derive(Debug, Clone)]
141pub struct PromptMetadata {
142    pub description: String,
143    pub arguments: HashMap<String, ParamType>,
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use pforge_config::SimpleType;
    use serde_json::json;

    /// Builds a `PromptDef` from borrowed parts.
    fn prompt_def(
        name: &str,
        description: &str,
        template: &str,
        arguments: HashMap<String, ParamType>,
    ) -> PromptDef {
        PromptDef {
            name: name.to_string(),
            description: description.to_string(),
            template: template.to_string(),
            arguments,
        }
    }

    /// A one-entry argument schema declaring `name` as a required string.
    fn required_string_arg(name: &str, description: Option<&str>) -> HashMap<String, ParamType> {
        let mut schema = HashMap::new();
        schema.insert(
            name.to_string(),
            ParamType::Complex {
                ty: SimpleType::String,
                required: true,
                default: None,
                description: description.map(str::to_string),
                validation: None,
            },
        );
        schema
    }

    /// Collects `(key, value)` pairs into the argument map `render` expects.
    fn args_of(pairs: &[(&str, Value)]) -> HashMap<String, Value> {
        pairs
            .iter()
            .map(|(key, value)| (key.to_string(), value.clone()))
            .collect()
    }

    /// Asserts that `result` failed with a message containing `needle`.
    fn assert_err_contains<T: std::fmt::Debug>(result: Result<T>, needle: &str) {
        let message = result.expect_err("expected an error").to_string();
        assert!(
            message.contains(needle),
            "error message {:?} does not contain {:?}",
            message,
            needle
        );
    }

    #[test]
    fn test_prompt_registration() {
        let mut manager = PromptManager::new();
        manager
            .register(prompt_def(
                "greeting",
                "A simple greeting prompt",
                "Hello, {{name}}!",
                HashMap::new(),
            ))
            .unwrap();

        assert_eq!(manager.list_prompts(), vec!["greeting"]);
    }

    #[test]
    fn test_duplicate_prompt_registration() {
        let mut manager = PromptManager::new();
        let def = prompt_def("test", "Test", "{{x}}", HashMap::new());

        manager.register(def.clone()).unwrap();
        assert_err_contains(manager.register(def), "already registered");
    }

    #[test]
    fn test_simple_interpolation() {
        let mut manager = PromptManager::new();
        manager
            .register(prompt_def(
                "greeting",
                "Greeting",
                "Hello, {{name}}! You are {{age}} years old.",
                HashMap::new(),
            ))
            .unwrap();

        let args = args_of(&[("name", json!("Alice")), ("age", json!(30))]);
        let rendered = manager.render("greeting", args).unwrap();
        assert_eq!(rendered, "Hello, Alice! You are 30 years old.");
    }

    #[test]
    fn test_required_argument_validation() {
        let mut manager = PromptManager::new();
        manager
            .register(prompt_def(
                "greeting",
                "Greeting",
                "Hello, {{name}}!",
                required_string_arg("name", None),
            ))
            .unwrap();

        assert_err_contains(manager.render("greeting", HashMap::new()), "Required argument");
    }

    #[test]
    fn test_unresolved_placeholder() {
        let mut manager = PromptManager::new();
        manager
            .register(prompt_def(
                "test",
                "Test",
                "Hello, {{name}}! Welcome to {{location}}.",
                HashMap::new(),
            ))
            .unwrap();

        // 'location' is deliberately left out.
        let args = args_of(&[("name", json!("Alice"))]);
        assert_err_contains(manager.render("test", args), "Unresolved template variables");
    }

    #[test]
    fn test_get_prompt_metadata() {
        let mut manager = PromptManager::new();
        manager
            .register(prompt_def(
                "greeting",
                "A greeting prompt",
                "Hello, {{name}}!",
                required_string_arg("name", Some("User name")),
            ))
            .unwrap();

        let metadata = manager
            .get_prompt("greeting")
            .expect("prompt was just registered");
        assert_eq!(metadata.description, "A greeting prompt");
        assert!(metadata.arguments.contains_key("name"));
    }

    #[test]
    fn test_complex_value_interpolation() {
        let mut manager = PromptManager::new();
        manager
            .register(prompt_def(
                "test",
                "Test",
                "String: {{str}}, Number: {{num}}, Bool: {{bool}}",
                HashMap::new(),
            ))
            .unwrap();

        let args = args_of(&[
            ("str", json!("hello")),
            ("num", json!(42)),
            ("bool", json!(true)),
        ]);
        let rendered = manager.render("test", args).unwrap();
        assert_eq!(rendered, "String: hello, Number: 42, Bool: true");
    }
}