Skip to main content

sh_layer2/
prompts.rs

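//! # Prompt Manager
//!
//! Prompt management with support for templating and dynamic generation: named
//! templates whose `{{variable}}` placeholders are filled in from a context map
//! at render time.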
4
5use parking_lot::RwLock;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9use crate::types::Layer2Result;
10
/// A named prompt template whose text may contain `{{variable}}` placeholders.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptTemplate {
    /// Template name; `PromptManager` stores templates keyed by this value.
    pub name: String,
    /// Raw template text, with placeholders written as `{{variable}}`.
    pub template: String,
    /// Human-readable description; starts empty in `new` and is set by `with_description`.
    pub description: String,
    /// Placeholder names that `render` substitutes; populated by `extract_variables`.
    pub variables: Vec<String>,
    /// Arbitrary string key/value metadata; not consulted when rendering.
    pub metadata: HashMap<String, String>,
}
20
21impl PromptTemplate {
22    pub fn new(name: impl Into<String>, template: impl Into<String>) -> Self {
23        Self {
24            name: name.into(),
25            template: template.into(),
26            description: String::new(),
27            variables: Vec::new(),
28            metadata: HashMap::new(),
29        }
30    }
31
32    pub fn with_description(mut self, description: impl Into<String>) -> Self {
33        self.description = description.into();
34        self
35    }
36
37    /// 渲染模板
38    pub fn render(&self, context: &HashMap<String, String>) -> String {
39        let mut result = self.template.clone();
40
41        for var in &self.variables {
42            if let Some(value) = context.get(var) {
43                result = result.replace(&format!("{{{{{}}}}}", var), value);
44            }
45        }
46
47        result
48    }
49
50    /// 提取模板中的变量
51    pub fn extract_variables(&mut self) {
52        use regex::Regex;
53
54        let re = Regex::new(r"\{\{(\w+)\}\}").unwrap();
55        self.variables = re
56            .captures_iter(&self.template)
57            .filter_map(|cap| cap.get(1).map(|m| m.as_str().to_string()))
58            .collect();
59    }
60}
61
/// Interface for a registry of named prompt templates.
///
/// The `Send + Sync` supertraits let an implementation be shared across threads.
pub trait PromptManagerTrait: Send + Sync {
    /// Registers a template.
    fn register(&self, template: PromptTemplate) -> Layer2Result<()>;

    /// Unregisters the template called `name`.
    ///
    /// NOTE(review): the flag presumably reports whether a template was actually
    /// removed (that is what `PromptManager` returns) — confirm for other implementors.
    fn unregister(&self, name: &str) -> Layer2Result<bool>;

    /// Returns the template called `name`, or `None` if it is not registered.
    fn get(&self, name: &str) -> Option<PromptTemplate>;

    /// Renders the template called `name`, filling placeholders from `context`.
    fn render(&self, name: &str, context: &HashMap<String, String>) -> Layer2Result<String>;

    /// Lists the names of all registered templates.
    fn list(&self) -> Vec<String>;

    /// Returns the number of registered templates.
    fn count(&self) -> usize;
}
82
/// In-memory [`PromptManagerTrait`] implementation backed by a locked hash map.
pub struct PromptManager {
    /// Registered templates keyed by template name. The `RwLock` lets the
    /// trait's `&self` methods mutate the map.
    templates: RwLock<HashMap<String, PromptTemplate>>,
}
87
88impl PromptManager {
89    pub fn new() -> Self {
90        Self {
91            templates: RwLock::new(HashMap::new()),
92        }
93    }
94
95    /// 创建带有默认模板的管理器
96    pub fn with_defaults() -> Self {
97        let manager = Self::new();
98
99        // 添加默认模板
100        manager.register_default_templates();
101
102        manager
103    }
104
105    fn register_default_templates(&self) {
106        let system = PromptTemplate::new(
107            "system",
108            "You are a helpful AI assistant. Be concise and accurate.",
109        )
110        .with_description("Default system prompt");
111
112        let code_review = PromptTemplate::new(
113            "code_review",
114            "Review the following code and provide feedback:\n\n{{code}}\n\nFocus on: {{focus_areas}}"
115        )
116        .with_description("Code review prompt template");
117
118        let task = PromptTemplate::new(
119            "task",
120            "Task: {{task_description}}\n\nContext: {{context}}\n\nPlease complete this task.",
121        )
122        .with_description("General task prompt template");
123
124        let _ = self.register(system);
125        let _ = self.register(code_review);
126        let _ = self.register(task);
127    }
128}
129
130impl Default for PromptManager {
131    fn default() -> Self {
132        Self::with_defaults()
133    }
134}
135
136impl PromptManagerTrait for PromptManager {
137    fn register(&self, template: PromptTemplate) -> Layer2Result<()> {
138        let name = template.name.clone();
139        self.templates.write().insert(name, template);
140        Ok(())
141    }
142
143    fn unregister(&self, name: &str) -> Layer2Result<bool> {
144        Ok(self.templates.write().remove(name).is_some())
145    }
146
147    fn get(&self, name: &str) -> Option<PromptTemplate> {
148        self.templates.read().get(name).cloned()
149    }
150
151    fn render(&self, name: &str, context: &HashMap<String, String>) -> Layer2Result<String> {
152        let templates = self.templates.read();
153
154        let template = templates
155            .get(name)
156            .ok_or_else(|| anyhow::anyhow!("Template not found: {}", name))?;
157
158        Ok(template.render(context))
159    }
160
161    fn list(&self) -> Vec<String> {
162        self.templates.read().keys().cloned().collect()
163    }
164
165    fn count(&self) -> usize {
166        self.templates.read().len()
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prompt_template() {
        let mut template = PromptTemplate::new("test", "Hello {{name}}!");
        template.extract_variables();

        let mut context = HashMap::new();
        context.insert("name".to_string(), "World".to_string());

        let result = template.render(&context);
        assert_eq!(result, "Hello World!");
    }

    #[test]
    fn test_prompt_manager() {
        let manager = PromptManager::new();

        let template = PromptTemplate::new("test", "Test template");
        manager.register(template).unwrap();

        assert_eq!(manager.count(), 1);
        assert!(manager.get("test").is_some());
    }

    #[test]
    fn test_prompt_manager_defaults() {
        let manager = PromptManager::with_defaults();

        assert!(manager.get("system").is_some());
        assert!(manager.get("code_review").is_some());
        assert!(manager.get("task").is_some());
    }

    #[test]
    fn test_render_template() {
        let manager = PromptManager::with_defaults();

        // The default templates are listed.
        assert!(!manager.list().is_empty());

        // Rendering by name goes through the manager and substitutes placeholders.
        let mut greeting = PromptTemplate::new("greeting", "Hello {{name}}!");
        greeting.extract_variables();
        manager.register(greeting).unwrap();

        let mut context = HashMap::new();
        context.insert("name".to_string(), "World".to_string());
        assert_eq!(manager.render("greeting", &context).unwrap(), "Hello World!");

        // Unknown template names are reported as errors.
        assert!(manager.render("missing", &context).is_err());
    }

    #[test]
    fn test_unregister() {
        let manager = PromptManager::new();
        manager.register(PromptTemplate::new("temp", "x")).unwrap();

        assert!(manager.unregister("temp").unwrap());
        assert!(!manager.unregister("temp").unwrap());
        assert_eq!(manager.count(), 0);
    }
}