pforge_runtime/
prompt.rs

1use crate::{Error, Result};
2use pforge_config::{ParamType, PromptDef};
3use serde_json::Value;
4use std::collections::HashMap;
5
6/// Prompt manager handles prompt rendering with template interpolation
7pub struct PromptManager {
8    prompts: HashMap<String, PromptEntry>,
9}
10
11struct PromptEntry {
12    description: String,
13    template: String,
14    arguments: HashMap<String, ParamType>,
15}
16
17impl PromptManager {
18    pub fn new() -> Self {
19        Self {
20            prompts: HashMap::new(),
21        }
22    }
23
24    /// Register a prompt definition
25    pub fn register(&mut self, def: PromptDef) -> Result<()> {
26        if self.prompts.contains_key(&def.name) {
27            return Err(Error::Handler(format!(
28                "Prompt '{}' already registered",
29                def.name
30            )));
31        }
32
33        self.prompts.insert(
34            def.name.clone(),
35            PromptEntry {
36                description: def.description,
37                template: def.template,
38                arguments: def.arguments,
39            },
40        );
41
42        Ok(())
43    }
44
45    /// Render a prompt with given arguments
46    pub fn render(&self, name: &str, args: HashMap<String, Value>) -> Result<String> {
47        let entry = self
48            .prompts
49            .get(name)
50            .ok_or_else(|| Error::Handler(format!("Prompt '{}' not found", name)))?;
51
52        // Validate arguments
53        self.validate_arguments(entry, &args)?;
54
55        // Perform template interpolation
56        self.interpolate(&entry.template, &args)
57    }
58
59    /// Get prompt metadata
60    pub fn get_prompt(&self, name: &str) -> Option<PromptMetadata> {
61        self.prompts.get(name).map(|entry| PromptMetadata {
62            description: entry.description.clone(),
63            arguments: entry.arguments.clone(),
64        })
65    }
66
67    /// List all registered prompts
68    pub fn list_prompts(&self) -> Vec<String> {
69        self.prompts.keys().cloned().collect()
70    }
71
72    /// Validate arguments against schema
73    fn validate_arguments(
74        &self,
75        entry: &PromptEntry,
76        args: &HashMap<String, Value>,
77    ) -> Result<()> {
78        // Check required arguments
79        for (arg_name, param_type) in &entry.arguments {
80            let is_required = match param_type {
81                ParamType::Complex { required, .. } => *required,
82                _ => false,
83            };
84
85            if is_required && !args.contains_key(arg_name) {
86                return Err(Error::Handler(format!(
87                    "Required argument '{}' not provided",
88                    arg_name
89                )));
90            }
91        }
92
93        // Type validation could be added here
94        Ok(())
95    }
96
97    /// Interpolate template with argument values
98    /// Supports {{variable}} syntax
99    fn interpolate(&self, template: &str, args: &HashMap<String, Value>) -> Result<String> {
100        let mut result = template.to_string();
101
102        for (key, value) in args {
103            let placeholder = format!("{{{{{}}}}}", key);
104            let replacement = match value {
105                Value::String(s) => s.clone(),
106                Value::Number(n) => n.to_string(),
107                Value::Bool(b) => b.to_string(),
108                Value::Null => String::new(),
109                _ => serde_json::to_string(value)
110                    .map_err(|e| Error::Handler(format!("Failed to serialize value: {}", e)))?,
111            };
112
113            result = result.replace(&placeholder, &replacement);
114        }
115
116        // Check for unresolved placeholders
117        if result.contains("{{") && result.contains("}}") {
118            // Extract unresolved variable names for better error message
119            let unresolved: Vec<&str> = result
120                .split("{{")
121                .skip(1)
122                .filter_map(|s| s.split("}}").next())
123                .collect();
124
125            if !unresolved.is_empty() {
126                return Err(Error::Handler(format!(
127                    "Unresolved template variables: {}",
128                    unresolved.join(", ")
129                )));
130            }
131        }
132
133        Ok(result)
134    }
135}
136
137impl Default for PromptManager {
138    fn default() -> Self {
139        Self::new()
140    }
141}
142
/// Prompt metadata for discovery.
///
/// Returned by `PromptManager::get_prompt` as an owned copy of the registered
/// description and argument declarations; the template text is not included.
#[derive(Debug, Clone)]
pub struct PromptMetadata {
    /// Human-readable description supplied at registration.
    pub description: String,
    /// Declared arguments, keyed by argument name.
    pub arguments: HashMap<String, ParamType>,
}
149
#[cfg(test)]
mod tests {
    use super::*;
    use pforge_config::SimpleType;
    use serde_json::json;

    /// Build a `PromptDef` from borrowed parts.
    fn prompt_def(
        name: &str,
        description: &str,
        template: &str,
        arguments: HashMap<String, ParamType>,
    ) -> PromptDef {
        PromptDef {
            name: name.to_string(),
            description: description.to_string(),
            template: template.to_string(),
            arguments,
        }
    }

    /// A required string argument with an optional description.
    fn required_string(description: Option<&str>) -> ParamType {
        ParamType::Complex {
            ty: SimpleType::String,
            required: true,
            default: None,
            description: description.map(str::to_string),
            validation: None,
        }
    }

    /// Build an argument map from `(name, value)` pairs.
    fn args(pairs: &[(&str, Value)]) -> HashMap<String, Value> {
        pairs
            .iter()
            .map(|(key, value)| (key.to_string(), value.clone()))
            .collect()
    }

    /// Register a single argument-less prompt and return the manager.
    fn manager_with(name: &str, template: &str) -> PromptManager {
        let mut manager = PromptManager::new();
        manager
            .register(prompt_def(name, "Test", template, HashMap::new()))
            .unwrap();
        manager
    }

    #[test]
    fn test_prompt_registration() {
        let mut manager = PromptManager::new();
        let def = prompt_def(
            "greeting",
            "A simple greeting prompt",
            "Hello, {{name}}!",
            HashMap::new(),
        );

        manager.register(def).unwrap();
        assert_eq!(manager.list_prompts(), vec!["greeting"]);
    }

    #[test]
    fn test_duplicate_prompt_registration() {
        let mut manager = PromptManager::new();
        let def = prompt_def("test", "Test", "{{x}}", HashMap::new());

        manager.register(def.clone()).unwrap();
        let result = manager.register(def);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("already registered"));
    }

    #[test]
    fn test_simple_interpolation() {
        let manager = manager_with("greeting", "Hello, {{name}}! You are {{age}} years old.");

        let result = manager
            .render("greeting", args(&[("name", json!("Alice")), ("age", json!(30))]))
            .unwrap();
        assert_eq!(result, "Hello, Alice! You are 30 years old.");
    }

    #[test]
    fn test_required_argument_validation() {
        let mut manager = PromptManager::new();
        let mut arguments = HashMap::new();
        arguments.insert("name".to_string(), required_string(None));
        manager
            .register(prompt_def("greeting", "Greeting", "Hello, {{name}}!", arguments))
            .unwrap();

        let result = manager.render("greeting", HashMap::new());
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Required argument"));
    }

    #[test]
    fn test_unresolved_placeholder() {
        let manager = manager_with("test", "Hello, {{name}}! Welcome to {{location}}.");

        // 'location' is deliberately missing.
        let result = manager.render("test", args(&[("name", json!("Alice"))]));
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Unresolved template variables"));
    }

    #[test]
    fn test_get_prompt_metadata() {
        let mut manager = PromptManager::new();
        let mut arguments = HashMap::new();
        arguments.insert("name".to_string(), required_string(Some("User name")));
        manager
            .register(prompt_def(
                "greeting",
                "A greeting prompt",
                "Hello, {{name}}!",
                arguments,
            ))
            .unwrap();

        let metadata = manager.get_prompt("greeting").unwrap();
        assert_eq!(metadata.description, "A greeting prompt");
        assert!(metadata.arguments.contains_key("name"));
    }

    #[test]
    fn test_get_unknown_prompt_metadata() {
        let manager = PromptManager::new();
        assert!(manager.get_prompt("missing").is_none());
    }

    #[test]
    fn test_render_unknown_prompt() {
        let manager = PromptManager::new();
        let result = manager.render("missing", HashMap::new());
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not found"));
    }

    #[test]
    fn test_complex_value_interpolation() {
        let manager = manager_with("test", "String: {{str}}, Number: {{num}}, Bool: {{bool}}");

        let result = manager
            .render(
                "test",
                args(&[
                    ("str", json!("hello")),
                    ("num", json!(42)),
                    ("bool", json!(true)),
                ]),
            )
            .unwrap();
        assert_eq!(result, "String: hello, Number: 42, Bool: true");
    }

    #[test]
    fn test_null_and_array_values() {
        let manager = manager_with("test", "[{{none}}] {{list}}");

        let result = manager
            .render("test", args(&[("none", json!(null)), ("list", json!([1, 2]))]))
            .unwrap();
        assert_eq!(result, "[] [1,2]");
    }

    #[test]
    fn test_unclosed_braces_are_literal() {
        let manager = manager_with("test", "Hello, {{name}}! {{ not closed");

        let result = manager
            .render("test", args(&[("name", json!("Alice"))]))
            .unwrap();
        assert_eq!(result, "Hello, Alice! {{ not closed");
    }

    #[test]
    fn test_extra_braces_are_kept() {
        let manager = manager_with("test", "{{{name}}}");

        let result = manager
            .render("test", args(&[("name", json!("Alice"))]))
            .unwrap();
        assert_eq!(result, "{Alice}");
    }
}