Skip to main content

aster/plugins/
registry.rs

//! Plugin registry.
//!
//! Manages the tools, commands, skills and hooks that plugins register.
4
use super::types::*;
use std::collections::HashMap;
use std::sync::{Arc, PoisonError, RwLock};
8
9/// 工具定义(简化版)
10#[derive(Debug, Clone)]
11pub struct ToolDefinition {
12    pub name: String,
13    pub description: String,
14    pub parameters: serde_json::Value,
15}
16
17/// 插件工具 API
18pub struct PluginToolAPI {
19    plugin_name: String,
20    tools: Arc<RwLock<HashMap<String, Vec<ToolDefinition>>>>,
21}
22
23impl PluginToolAPI {
24    pub fn new(
25        plugin_name: &str,
26        tools: Arc<RwLock<HashMap<String, Vec<ToolDefinition>>>>,
27    ) -> Self {
28        Self {
29            plugin_name: plugin_name.to_string(),
30            tools,
31        }
32    }
33
34    /// 注册工具
35    pub fn register(&self, tool: ToolDefinition) {
36        if let Ok(mut tools) = self.tools.write() {
37            tools
38                .entry(self.plugin_name.clone())
39                .or_default()
40                .push(tool);
41        }
42    }
43
44    /// 注销工具
45    pub fn unregister(&self, tool_name: &str) {
46        if let Ok(mut tools) = self.tools.write() {
47            if let Some(list) = tools.get_mut(&self.plugin_name) {
48                list.retain(|t| t.name != tool_name);
49            }
50        }
51    }
52
53    /// 获取已注册的工具
54    pub fn get_registered(&self) -> Vec<ToolDefinition> {
55        self.tools
56            .read()
57            .ok()
58            .and_then(|t| t.get(&self.plugin_name).cloned())
59            .unwrap_or_default()
60    }
61}
62
63/// 插件命令 API
64pub struct PluginCommandAPI {
65    plugin_name: String,
66    commands: Arc<RwLock<HashMap<String, Vec<CommandDefinition>>>>,
67}
68
69impl PluginCommandAPI {
70    pub fn new(
71        plugin_name: &str,
72        commands: Arc<RwLock<HashMap<String, Vec<CommandDefinition>>>>,
73    ) -> Self {
74        Self {
75            plugin_name: plugin_name.to_string(),
76            commands,
77        }
78    }
79
80    /// 注册命令
81    pub fn register(&self, command: CommandDefinition) {
82        if let Ok(mut commands) = self.commands.write() {
83            commands
84                .entry(self.plugin_name.clone())
85                .or_default()
86                .push(command);
87        }
88    }
89
90    /// 注销命令
91    pub fn unregister(&self, command_name: &str) {
92        if let Ok(mut commands) = self.commands.write() {
93            if let Some(list) = commands.get_mut(&self.plugin_name) {
94                list.retain(|c| c.name != command_name);
95            }
96        }
97    }
98
99    /// 获取已注册的命令
100    pub fn get_registered(&self) -> Vec<CommandDefinition> {
101        self.commands
102            .read()
103            .ok()
104            .and_then(|c| c.get(&self.plugin_name).cloned())
105            .unwrap_or_default()
106    }
107}
108
109/// 插件技能 API
110pub struct PluginSkillAPI {
111    plugin_name: String,
112    skills: Arc<RwLock<HashMap<String, Vec<SkillDefinition>>>>,
113}
114
115impl PluginSkillAPI {
116    pub fn new(
117        plugin_name: &str,
118        skills: Arc<RwLock<HashMap<String, Vec<SkillDefinition>>>>,
119    ) -> Self {
120        Self {
121            plugin_name: plugin_name.to_string(),
122            skills,
123        }
124    }
125
126    /// 注册技能
127    pub fn register(&self, skill: SkillDefinition) {
128        if let Ok(mut skills) = self.skills.write() {
129            skills
130                .entry(self.plugin_name.clone())
131                .or_default()
132                .push(skill);
133        }
134    }
135
136    /// 注销技能
137    pub fn unregister(&self, skill_name: &str) {
138        if let Ok(mut skills) = self.skills.write() {
139            if let Some(list) = skills.get_mut(&self.plugin_name) {
140                list.retain(|s| s.name != skill_name);
141            }
142        }
143    }
144
145    /// 获取已注册的技能
146    pub fn get_registered(&self) -> Vec<SkillDefinition> {
147        self.skills
148            .read()
149            .ok()
150            .and_then(|s| s.get(&self.plugin_name).cloned())
151            .unwrap_or_default()
152    }
153}
154
155/// 插件钩子 API
156pub struct PluginHookAPI {
157    plugin_name: String,
158    hooks: Arc<RwLock<HashMap<String, Vec<HookDefinition>>>>,
159}
160
161impl PluginHookAPI {
162    pub fn new(
163        plugin_name: &str,
164        hooks: Arc<RwLock<HashMap<String, Vec<HookDefinition>>>>,
165    ) -> Self {
166        Self {
167            plugin_name: plugin_name.to_string(),
168            hooks,
169        }
170    }
171
172    /// 注册钩子
173    pub fn register(&self, hook: HookDefinition) {
174        if let Ok(mut hooks) = self.hooks.write() {
175            hooks
176                .entry(self.plugin_name.clone())
177                .or_default()
178                .push(hook);
179        }
180    }
181
182    /// 注销钩子
183    pub fn unregister(&self, hook_type: PluginHookType) {
184        if let Ok(mut hooks) = self.hooks.write() {
185            if let Some(list) = hooks.get_mut(&self.plugin_name) {
186                list.retain(|h| h.hook_type != hook_type);
187            }
188        }
189    }
190
191    /// 获取已注册的钩子
192    pub fn get_registered(&self) -> Vec<HookDefinition> {
193        self.hooks
194            .read()
195            .ok()
196            .and_then(|h| h.get(&self.plugin_name).cloned())
197            .unwrap_or_default()
198    }
199}
200
201/// 全局注册表
202pub struct PluginRegistry {
203    pub tools: Arc<RwLock<HashMap<String, Vec<ToolDefinition>>>>,
204    pub commands: Arc<RwLock<HashMap<String, Vec<CommandDefinition>>>>,
205    pub skills: Arc<RwLock<HashMap<String, Vec<SkillDefinition>>>>,
206    pub hooks: Arc<RwLock<HashMap<String, Vec<HookDefinition>>>>,
207}
208
209impl PluginRegistry {
210    pub fn new() -> Self {
211        Self {
212            tools: Arc::new(RwLock::new(HashMap::new())),
213            commands: Arc::new(RwLock::new(HashMap::new())),
214            skills: Arc::new(RwLock::new(HashMap::new())),
215            hooks: Arc::new(RwLock::new(HashMap::new())),
216        }
217    }
218
219    /// 获取所有工具
220    pub fn get_all_tools(&self) -> Vec<ToolDefinition> {
221        self.tools
222            .read()
223            .map(|t| t.values().flatten().cloned().collect())
224            .unwrap_or_default()
225    }
226
227    /// 获取所有命令
228    pub fn get_all_commands(&self) -> Vec<CommandDefinition> {
229        self.commands
230            .read()
231            .map(|c| c.values().flatten().cloned().collect())
232            .unwrap_or_default()
233    }
234
235    /// 获取所有技能
236    pub fn get_all_skills(&self) -> Vec<SkillDefinition> {
237        self.skills
238            .read()
239            .map(|s| s.values().flatten().cloned().collect())
240            .unwrap_or_default()
241    }
242
243    /// 获取指定类型的所有钩子
244    pub fn get_hooks_by_type(&self, hook_type: PluginHookType) -> Vec<HookDefinition> {
245        self.hooks
246            .read()
247            .map(|h| {
248                h.values()
249                    .flatten()
250                    .filter(|hook| hook.hook_type == hook_type)
251                    .cloned()
252                    .collect()
253            })
254            .unwrap_or_default()
255    }
256
257    /// 清理插件的所有注册
258    pub fn clear_plugin(&self, plugin_name: &str) {
259        if let Ok(mut tools) = self.tools.write() {
260            tools.remove(plugin_name);
261        }
262        if let Ok(mut commands) = self.commands.write() {
263            commands.remove(plugin_name);
264        }
265        if let Ok(mut skills) = self.skills.write() {
266            skills.remove(plugin_name);
267        }
268        if let Ok(mut hooks) = self.hooks.write() {
269            hooks.remove(plugin_name);
270        }
271    }
272}
273
274impl Default for PluginRegistry {
275    fn default() -> Self {
276        Self::new()
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ToolDefinition` whose parameters are an empty JSON object.
    fn tool(name: &str, description: &str) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            description: description.to_string(),
            parameters: serde_json::json!({}),
        }
    }

    #[test]
    fn test_plugin_registry_new() {
        let empty = PluginRegistry::new();
        assert!(empty.get_all_tools().is_empty());
        assert!(empty.get_all_commands().is_empty());
        assert!(empty.get_all_skills().is_empty());
    }

    #[test]
    fn test_tool_api_register() {
        let registry = PluginRegistry::new();
        let api = PluginToolAPI::new("test-plugin", Arc::clone(&registry.tools));

        api.register(tool("test-tool", "A test tool"));

        let registered = api.get_registered();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].name, "test-tool");
    }

    #[test]
    fn test_tool_api_unregister() {
        let registry = PluginRegistry::new();
        let api = PluginToolAPI::new("test-plugin", Arc::clone(&registry.tools));

        api.register(tool("tool1", "Tool 1"));
        api.register(tool("tool2", "Tool 2"));
        assert_eq!(api.get_registered().len(), 2);

        api.unregister("tool1");

        let remaining = api.get_registered();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].name, "tool2");
    }

    #[test]
    fn test_command_api_register() {
        let registry = PluginRegistry::new();
        let api = PluginCommandAPI::new("test-plugin", Arc::clone(&registry.commands));

        api.register(CommandDefinition {
            name: "test-cmd".to_string(),
            description: "A test command".to_string(),
            usage: Some("/test-cmd".to_string()),
            examples: vec!["example1".to_string()],
        });

        let registered = api.get_registered();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].name, "test-cmd");
    }

    #[test]
    fn test_skill_api_register() {
        let registry = PluginRegistry::new();
        let api = PluginSkillAPI::new("test-plugin", Arc::clone(&registry.skills));

        api.register(SkillDefinition {
            name: "test-skill".to_string(),
            description: "A test skill".to_string(),
            prompt: "Test prompt".to_string(),
            category: Some("test".to_string()),
            examples: vec!["example1".to_string()],
            parameters: vec![],
        });

        let registered = api.get_registered();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].name, "test-skill");
    }

    #[test]
    fn test_hook_api_register() {
        let registry = PluginRegistry::new();
        let api = PluginHookAPI::new("test-plugin", Arc::clone(&registry.hooks));

        api.register(HookDefinition {
            hook_type: PluginHookType::BeforeToolCall,
            priority: 10,
        });

        let registered = api.get_registered();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].hook_type, PluginHookType::BeforeToolCall);
    }

    #[test]
    fn test_registry_get_all() {
        let registry = PluginRegistry::new();

        // Two different plugins each contribute one tool to the shared map.
        let first = PluginToolAPI::new("plugin1", Arc::clone(&registry.tools));
        let second = PluginToolAPI::new("plugin2", Arc::clone(&registry.tools));
        first.register(tool("tool1", "Tool 1"));
        second.register(tool("tool2", "Tool 2"));

        assert_eq!(registry.get_all_tools().len(), 2);
    }

    #[test]
    fn test_registry_get_hooks_by_type() {
        let registry = PluginRegistry::new();
        let api = PluginHookAPI::new("test-plugin", Arc::clone(&registry.hooks));

        for (hook_type, priority) in [
            (PluginHookType::BeforeToolCall, 10),
            (PluginHookType::AfterToolCall, 20),
            (PluginHookType::BeforeToolCall, 5),
        ] {
            api.register(HookDefinition { hook_type, priority });
        }

        let before = registry.get_hooks_by_type(PluginHookType::BeforeToolCall);
        assert_eq!(before.len(), 2);

        let after = registry.get_hooks_by_type(PluginHookType::AfterToolCall);
        assert_eq!(after.len(), 1);
    }

    #[test]
    fn test_registry_clear_plugin() {
        let registry = PluginRegistry::new();
        let tools = PluginToolAPI::new("test-plugin", Arc::clone(&registry.tools));
        let commands = PluginCommandAPI::new("test-plugin", Arc::clone(&registry.commands));

        tools.register(tool("tool1", "Tool 1"));
        commands.register(CommandDefinition {
            name: "cmd1".to_string(),
            description: "Command 1".to_string(),
            usage: None,
            examples: vec![],
        });

        assert_eq!(registry.get_all_tools().len(), 1);
        assert_eq!(registry.get_all_commands().len(), 1);

        registry.clear_plugin("test-plugin");

        assert!(registry.get_all_tools().is_empty());
        assert!(registry.get_all_commands().is_empty());
    }
}