1use super::types::*;
6use std::collections::HashMap;
7use std::sync::{Arc, RwLock};
8
9#[derive(Debug, Clone)]
11pub struct ToolDefinition {
12 pub name: String,
13 pub description: String,
14 pub parameters: serde_json::Value,
15}
16
17pub struct PluginToolAPI {
19 plugin_name: String,
20 tools: Arc<RwLock<HashMap<String, Vec<ToolDefinition>>>>,
21}
22
23impl PluginToolAPI {
24 pub fn new(
25 plugin_name: &str,
26 tools: Arc<RwLock<HashMap<String, Vec<ToolDefinition>>>>,
27 ) -> Self {
28 Self {
29 plugin_name: plugin_name.to_string(),
30 tools,
31 }
32 }
33
34 pub fn register(&self, tool: ToolDefinition) {
36 if let Ok(mut tools) = self.tools.write() {
37 tools
38 .entry(self.plugin_name.clone())
39 .or_default()
40 .push(tool);
41 }
42 }
43
44 pub fn unregister(&self, tool_name: &str) {
46 if let Ok(mut tools) = self.tools.write() {
47 if let Some(list) = tools.get_mut(&self.plugin_name) {
48 list.retain(|t| t.name != tool_name);
49 }
50 }
51 }
52
53 pub fn get_registered(&self) -> Vec<ToolDefinition> {
55 self.tools
56 .read()
57 .ok()
58 .and_then(|t| t.get(&self.plugin_name).cloned())
59 .unwrap_or_default()
60 }
61}
62
63pub struct PluginCommandAPI {
65 plugin_name: String,
66 commands: Arc<RwLock<HashMap<String, Vec<CommandDefinition>>>>,
67}
68
69impl PluginCommandAPI {
70 pub fn new(
71 plugin_name: &str,
72 commands: Arc<RwLock<HashMap<String, Vec<CommandDefinition>>>>,
73 ) -> Self {
74 Self {
75 plugin_name: plugin_name.to_string(),
76 commands,
77 }
78 }
79
80 pub fn register(&self, command: CommandDefinition) {
82 if let Ok(mut commands) = self.commands.write() {
83 commands
84 .entry(self.plugin_name.clone())
85 .or_default()
86 .push(command);
87 }
88 }
89
90 pub fn unregister(&self, command_name: &str) {
92 if let Ok(mut commands) = self.commands.write() {
93 if let Some(list) = commands.get_mut(&self.plugin_name) {
94 list.retain(|c| c.name != command_name);
95 }
96 }
97 }
98
99 pub fn get_registered(&self) -> Vec<CommandDefinition> {
101 self.commands
102 .read()
103 .ok()
104 .and_then(|c| c.get(&self.plugin_name).cloned())
105 .unwrap_or_default()
106 }
107}
108
109pub struct PluginSkillAPI {
111 plugin_name: String,
112 skills: Arc<RwLock<HashMap<String, Vec<SkillDefinition>>>>,
113}
114
115impl PluginSkillAPI {
116 pub fn new(
117 plugin_name: &str,
118 skills: Arc<RwLock<HashMap<String, Vec<SkillDefinition>>>>,
119 ) -> Self {
120 Self {
121 plugin_name: plugin_name.to_string(),
122 skills,
123 }
124 }
125
126 pub fn register(&self, skill: SkillDefinition) {
128 if let Ok(mut skills) = self.skills.write() {
129 skills
130 .entry(self.plugin_name.clone())
131 .or_default()
132 .push(skill);
133 }
134 }
135
136 pub fn unregister(&self, skill_name: &str) {
138 if let Ok(mut skills) = self.skills.write() {
139 if let Some(list) = skills.get_mut(&self.plugin_name) {
140 list.retain(|s| s.name != skill_name);
141 }
142 }
143 }
144
145 pub fn get_registered(&self) -> Vec<SkillDefinition> {
147 self.skills
148 .read()
149 .ok()
150 .and_then(|s| s.get(&self.plugin_name).cloned())
151 .unwrap_or_default()
152 }
153}
154
155pub struct PluginHookAPI {
157 plugin_name: String,
158 hooks: Arc<RwLock<HashMap<String, Vec<HookDefinition>>>>,
159}
160
161impl PluginHookAPI {
162 pub fn new(
163 plugin_name: &str,
164 hooks: Arc<RwLock<HashMap<String, Vec<HookDefinition>>>>,
165 ) -> Self {
166 Self {
167 plugin_name: plugin_name.to_string(),
168 hooks,
169 }
170 }
171
172 pub fn register(&self, hook: HookDefinition) {
174 if let Ok(mut hooks) = self.hooks.write() {
175 hooks
176 .entry(self.plugin_name.clone())
177 .or_default()
178 .push(hook);
179 }
180 }
181
182 pub fn unregister(&self, hook_type: PluginHookType) {
184 if let Ok(mut hooks) = self.hooks.write() {
185 if let Some(list) = hooks.get_mut(&self.plugin_name) {
186 list.retain(|h| h.hook_type != hook_type);
187 }
188 }
189 }
190
191 pub fn get_registered(&self) -> Vec<HookDefinition> {
193 self.hooks
194 .read()
195 .ok()
196 .and_then(|h| h.get(&self.plugin_name).cloned())
197 .unwrap_or_default()
198 }
199}
200
201pub struct PluginRegistry {
203 pub tools: Arc<RwLock<HashMap<String, Vec<ToolDefinition>>>>,
204 pub commands: Arc<RwLock<HashMap<String, Vec<CommandDefinition>>>>,
205 pub skills: Arc<RwLock<HashMap<String, Vec<SkillDefinition>>>>,
206 pub hooks: Arc<RwLock<HashMap<String, Vec<HookDefinition>>>>,
207}
208
209impl PluginRegistry {
210 pub fn new() -> Self {
211 Self {
212 tools: Arc::new(RwLock::new(HashMap::new())),
213 commands: Arc::new(RwLock::new(HashMap::new())),
214 skills: Arc::new(RwLock::new(HashMap::new())),
215 hooks: Arc::new(RwLock::new(HashMap::new())),
216 }
217 }
218
219 pub fn get_all_tools(&self) -> Vec<ToolDefinition> {
221 self.tools
222 .read()
223 .map(|t| t.values().flatten().cloned().collect())
224 .unwrap_or_default()
225 }
226
227 pub fn get_all_commands(&self) -> Vec<CommandDefinition> {
229 self.commands
230 .read()
231 .map(|c| c.values().flatten().cloned().collect())
232 .unwrap_or_default()
233 }
234
235 pub fn get_all_skills(&self) -> Vec<SkillDefinition> {
237 self.skills
238 .read()
239 .map(|s| s.values().flatten().cloned().collect())
240 .unwrap_or_default()
241 }
242
243 pub fn get_hooks_by_type(&self, hook_type: PluginHookType) -> Vec<HookDefinition> {
245 self.hooks
246 .read()
247 .map(|h| {
248 h.values()
249 .flatten()
250 .filter(|hook| hook.hook_type == hook_type)
251 .cloned()
252 .collect()
253 })
254 .unwrap_or_default()
255 }
256
257 pub fn clear_plugin(&self, plugin_name: &str) {
259 if let Ok(mut tools) = self.tools.write() {
260 tools.remove(plugin_name);
261 }
262 if let Ok(mut commands) = self.commands.write() {
263 commands.remove(plugin_name);
264 }
265 if let Ok(mut skills) = self.skills.write() {
266 skills.remove(plugin_name);
267 }
268 if let Ok(mut hooks) = self.hooks.write() {
269 hooks.remove(plugin_name);
270 }
271 }
272}
273
274impl Default for PluginRegistry {
275 fn default() -> Self {
276 Self::new()
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tool with an empty JSON parameter object.
    fn tool(name: &str) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            description: format!("Tool {name}"),
            parameters: serde_json::json!({}),
        }
    }

    /// Builds a command with no usage string and no examples.
    fn command(name: &str) -> CommandDefinition {
        CommandDefinition {
            name: name.to_string(),
            description: format!("Command {name}"),
            usage: None,
            examples: vec![],
        }
    }

    /// Builds an uncategorised skill with no examples or parameters.
    fn skill(name: &str) -> SkillDefinition {
        SkillDefinition {
            name: name.to_string(),
            description: format!("Skill {name}"),
            prompt: "Test prompt".to_string(),
            category: None,
            examples: vec![],
            parameters: vec![],
        }
    }

    #[test]
    fn test_plugin_registry_new() {
        let registry = PluginRegistry::new();
        assert!(registry.get_all_tools().is_empty());
        assert!(registry.get_all_commands().is_empty());
        assert!(registry.get_all_skills().is_empty());
        assert!(registry
            .get_hooks_by_type(PluginHookType::BeforeToolCall)
            .is_empty());
        assert!(registry
            .get_hooks_by_type(PluginHookType::AfterToolCall)
            .is_empty());
    }

    #[test]
    fn test_tool_api_register() {
        let registry = PluginRegistry::new();
        let tool_api = PluginToolAPI::new("test-plugin", Arc::clone(&registry.tools));

        let tool = ToolDefinition {
            name: "test-tool".to_string(),
            description: "A test tool".to_string(),
            parameters: serde_json::json!({}),
        };

        tool_api.register(tool);

        let tools = tool_api.get_registered();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "test-tool");
    }

    #[test]
    fn test_tool_api_unregister() {
        let registry = PluginRegistry::new();
        let tool_api = PluginToolAPI::new("test-plugin", Arc::clone(&registry.tools));

        tool_api.register(ToolDefinition {
            name: "tool1".to_string(),
            description: "Tool 1".to_string(),
            parameters: serde_json::json!({}),
        });
        tool_api.register(ToolDefinition {
            name: "tool2".to_string(),
            description: "Tool 2".to_string(),
            parameters: serde_json::json!({}),
        });

        assert_eq!(tool_api.get_registered().len(), 2);

        tool_api.unregister("tool1");

        let tools = tool_api.get_registered();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "tool2");
    }

    #[test]
    fn test_tool_api_scoped_per_plugin() {
        let registry = PluginRegistry::new();
        let api1 = PluginToolAPI::new("plugin1", Arc::clone(&registry.tools));
        let api2 = PluginToolAPI::new("plugin2", Arc::clone(&registry.tools));

        api1.register(tool("shared"));
        api2.register(tool("shared"));

        // Unregistering through one plugin's handle must not touch the other plugin.
        api1.unregister("shared");

        assert!(api1.get_registered().is_empty());
        let remaining = api2.get_registered();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].name, "shared");
    }

    #[test]
    fn test_command_api_register() {
        let registry = PluginRegistry::new();
        let cmd_api = PluginCommandAPI::new("test-plugin", Arc::clone(&registry.commands));

        let cmd = CommandDefinition {
            name: "test-cmd".to_string(),
            description: "A test command".to_string(),
            usage: Some("/test-cmd".to_string()),
            examples: vec!["example1".to_string()],
        };

        cmd_api.register(cmd);

        let cmds = cmd_api.get_registered();
        assert_eq!(cmds.len(), 1);
        assert_eq!(cmds[0].name, "test-cmd");
    }

    #[test]
    fn test_command_api_unregister() {
        let registry = PluginRegistry::new();
        let cmd_api = PluginCommandAPI::new("test-plugin", Arc::clone(&registry.commands));

        cmd_api.register(command("cmd1"));
        cmd_api.register(command("cmd2"));
        cmd_api.unregister("cmd1");

        let cmds = cmd_api.get_registered();
        assert_eq!(cmds.len(), 1);
        assert_eq!(cmds[0].name, "cmd2");
    }

    #[test]
    fn test_skill_api_register() {
        let registry = PluginRegistry::new();
        let skill_api = PluginSkillAPI::new("test-plugin", Arc::clone(&registry.skills));

        let skill = SkillDefinition {
            name: "test-skill".to_string(),
            description: "A test skill".to_string(),
            prompt: "Test prompt".to_string(),
            category: Some("test".to_string()),
            examples: vec!["example1".to_string()],
            parameters: vec![],
        };

        skill_api.register(skill);

        let skills = skill_api.get_registered();
        assert_eq!(skills.len(), 1);
        assert_eq!(skills[0].name, "test-skill");
    }

    #[test]
    fn test_skill_api_unregister() {
        let registry = PluginRegistry::new();
        let skill_api = PluginSkillAPI::new("test-plugin", Arc::clone(&registry.skills));

        skill_api.register(skill("skill1"));
        skill_api.register(skill("skill2"));
        skill_api.unregister("skill2");

        let skills = skill_api.get_registered();
        assert_eq!(skills.len(), 1);
        assert_eq!(skills[0].name, "skill1");
    }

    #[test]
    fn test_hook_api_register() {
        let registry = PluginRegistry::new();
        let hook_api = PluginHookAPI::new("test-plugin", Arc::clone(&registry.hooks));

        let hook = HookDefinition {
            hook_type: PluginHookType::BeforeToolCall,
            priority: 10,
        };

        hook_api.register(hook);

        let hooks = hook_api.get_registered();
        assert_eq!(hooks.len(), 1);
        assert_eq!(hooks[0].hook_type, PluginHookType::BeforeToolCall);
    }

    #[test]
    fn test_hook_api_unregister_removes_all_of_type() {
        let registry = PluginRegistry::new();
        let hook_api = PluginHookAPI::new("test-plugin", Arc::clone(&registry.hooks));

        hook_api.register(HookDefinition {
            hook_type: PluginHookType::BeforeToolCall,
            priority: 10,
        });
        hook_api.register(HookDefinition {
            hook_type: PluginHookType::AfterToolCall,
            priority: 20,
        });
        hook_api.register(HookDefinition {
            hook_type: PluginHookType::BeforeToolCall,
            priority: 5,
        });

        hook_api.unregister(PluginHookType::BeforeToolCall);

        let hooks = hook_api.get_registered();
        assert_eq!(hooks.len(), 1);
        assert_eq!(hooks[0].hook_type, PluginHookType::AfterToolCall);
    }

    #[test]
    fn test_registry_get_all() {
        let registry = PluginRegistry::new();

        let tool_api1 = PluginToolAPI::new("plugin1", Arc::clone(&registry.tools));
        let tool_api2 = PluginToolAPI::new("plugin2", Arc::clone(&registry.tools));

        tool_api1.register(ToolDefinition {
            name: "tool1".to_string(),
            description: "Tool 1".to_string(),
            parameters: serde_json::json!({}),
        });
        tool_api2.register(ToolDefinition {
            name: "tool2".to_string(),
            description: "Tool 2".to_string(),
            parameters: serde_json::json!({}),
        });

        let all_tools = registry.get_all_tools();
        assert_eq!(all_tools.len(), 2);
    }

    #[test]
    fn test_registry_get_hooks_by_type() {
        let registry = PluginRegistry::new();
        let hook_api = PluginHookAPI::new("test-plugin", Arc::clone(&registry.hooks));

        hook_api.register(HookDefinition {
            hook_type: PluginHookType::BeforeToolCall,
            priority: 10,
        });
        hook_api.register(HookDefinition {
            hook_type: PluginHookType::AfterToolCall,
            priority: 20,
        });
        hook_api.register(HookDefinition {
            hook_type: PluginHookType::BeforeToolCall,
            priority: 5,
        });

        let before_hooks = registry.get_hooks_by_type(PluginHookType::BeforeToolCall);
        assert_eq!(before_hooks.len(), 2);

        let after_hooks = registry.get_hooks_by_type(PluginHookType::AfterToolCall);
        assert_eq!(after_hooks.len(), 1);
    }

    #[test]
    fn test_registry_clear_plugin() {
        let registry = PluginRegistry::new();

        let tool_api = PluginToolAPI::new("test-plugin", Arc::clone(&registry.tools));
        let cmd_api = PluginCommandAPI::new("test-plugin", Arc::clone(&registry.commands));
        let skill_api = PluginSkillAPI::new("test-plugin", Arc::clone(&registry.skills));
        let hook_api = PluginHookAPI::new("test-plugin", Arc::clone(&registry.hooks));
        let other_tool_api = PluginToolAPI::new("other-plugin", Arc::clone(&registry.tools));

        tool_api.register(tool("tool1"));
        cmd_api.register(command("cmd1"));
        skill_api.register(skill("skill1"));
        hook_api.register(HookDefinition {
            hook_type: PluginHookType::BeforeToolCall,
            priority: 10,
        });
        other_tool_api.register(tool("tool2"));

        assert_eq!(registry.get_all_tools().len(), 2);
        assert_eq!(registry.get_all_commands().len(), 1);
        assert_eq!(registry.get_all_skills().len(), 1);
        assert_eq!(registry.get_hooks_by_type(PluginHookType::BeforeToolCall).len(), 1);

        registry.clear_plugin("test-plugin");

        // Only the cleared plugin's entries go; other plugins keep theirs.
        let tools = registry.get_all_tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "tool2");
        assert!(registry.get_all_commands().is_empty());
        assert!(registry.get_all_skills().is_empty());
        assert!(registry
            .get_hooks_by_type(PluginHookType::BeforeToolCall)
            .is_empty());
    }
}