astrid_plugins/
registry.rs1use std::collections::HashMap;
7
8use tracing::{debug, info};
9
10use crate::error::{PluginError, PluginResult};
11use crate::plugin::{Plugin, PluginId};
12use crate::tool::PluginTool;
13
14fn qualified_tool_name(plugin_id: &PluginId, tool_name: &str) -> String {
19 format!("plugin:{plugin_id}:{tool_name}")
20}
21
22#[derive(Debug, Clone)]
24pub struct PluginToolDefinition {
25 pub name: String,
27 pub description: String,
29 pub input_schema: serde_json::Value,
31}
32
33pub struct PluginRegistry {
38 plugins: HashMap<PluginId, Box<dyn Plugin>>,
39}
40
41impl PluginRegistry {
42 #[must_use]
44 pub fn new() -> Self {
45 Self {
46 plugins: HashMap::new(),
47 }
48 }
49
50 pub fn register(&mut self, plugin: Box<dyn Plugin>) -> PluginResult<()> {
57 let id = plugin.id().clone();
58 if self.plugins.contains_key(&id) {
59 return Err(PluginError::AlreadyRegistered(id));
60 }
61 info!(plugin_id = %id, "Registered plugin");
62 self.plugins.insert(id, plugin);
63 Ok(())
64 }
65
66 pub fn unregister(&mut self, id: &PluginId) -> PluginResult<Box<dyn Plugin>> {
72 let plugin = self
73 .plugins
74 .remove(id)
75 .ok_or_else(|| PluginError::NotFound(id.clone()))?;
76 info!(plugin_id = %id, "Unregistered plugin");
77 Ok(plugin)
78 }
79
80 #[must_use]
82 pub fn get(&self, id: &PluginId) -> Option<&dyn Plugin> {
83 self.plugins.get(id).map(AsRef::as_ref)
84 }
85
86 #[must_use]
88 pub fn get_mut(&mut self, id: &PluginId) -> Option<&mut Box<dyn Plugin>> {
89 self.plugins.get_mut(id)
90 }
91
92 #[must_use]
94 pub fn list(&self) -> Vec<&PluginId> {
95 self.plugins.keys().collect()
96 }
97
98 #[must_use]
100 pub fn len(&self) -> usize {
101 self.plugins.len()
102 }
103
104 #[must_use]
106 pub fn is_empty(&self) -> bool {
107 self.plugins.is_empty()
108 }
109
110 #[must_use]
114 pub fn find_tool(&self, qualified_name: &str) -> Option<(&dyn Plugin, &dyn PluginTool)> {
115 let rest = qualified_name.strip_prefix("plugin:")?;
117 let (plugin_id_str, tool_name) = rest.split_once(':')?;
118
119 let plugin_id = PluginId::from_static(plugin_id_str);
120 let plugin = self.plugins.get(&plugin_id)?;
121
122 let tool = plugin.tools().iter().find(|t| t.name() == tool_name)?;
123
124 debug!(
125 qualified_name,
126 plugin_id = %plugin_id,
127 tool_name,
128 "Found plugin tool"
129 );
130 Some((plugin.as_ref(), tool.as_ref()))
131 }
132
133 #[must_use]
135 pub fn is_plugin_tool(name: &str) -> bool {
136 name.starts_with("plugin:") && name.matches(':').count() == 2
137 }
138
139 #[must_use]
141 pub fn all_tool_definitions(&self) -> Vec<PluginToolDefinition> {
142 let mut defs = Vec::new();
143 for (plugin_id, plugin) in &self.plugins {
144 for tool in plugin.tools() {
145 defs.push(PluginToolDefinition {
146 name: qualified_tool_name(plugin_id, tool.name()),
147 description: tool.description().to_string(),
148 input_schema: tool.input_schema(),
149 });
150 }
151 }
152 defs
153 }
154}
155
156impl Default for PluginRegistry {
157 fn default() -> Self {
158 Self::new()
159 }
160}
161
162impl std::fmt::Debug for PluginRegistry {
163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164 f.debug_struct("PluginRegistry")
165 .field("plugin_count", &self.plugins.len())
166 .field("plugin_ids", &self.list())
167 .finish()
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{PluginContext, PluginToolContext};
    use crate::manifest::{PluginEntryPoint, PluginManifest};
    use crate::plugin::PluginState;

    /// In-memory plugin used to exercise the registry.
    struct TestPlugin {
        id: PluginId,
        manifest: PluginManifest,
        state: PluginState,
        tools: Vec<Box<dyn PluginTool>>,
    }

    impl TestPlugin {
        /// Creates a `Ready` plugin with the given ID that exposes one [`EchoTool`].
        fn new(id: &str) -> Self {
            let plugin_id = PluginId::from_static(id);
            Self {
                manifest: PluginManifest {
                    id: plugin_id.clone(),
                    name: format!("Test Plugin {id}"),
                    version: "0.1.0".into(),
                    description: None,
                    author: None,
                    entry_point: PluginEntryPoint::Wasm {
                        path: "plugin.wasm".into(),
                        hash: None,
                    },
                    capabilities: vec![],
                    config: HashMap::new(),
                },
                id: plugin_id,
                state: PluginState::Ready,
                tools: vec![Box::new(EchoTool)],
            }
        }

        /// Creates a `Ready` plugin with the given ID that exposes no tools.
        fn with_no_tools(id: &str) -> Self {
            Self {
                tools: Vec::new(),
                ..Self::new(id)
            }
        }
    }

    #[async_trait::async_trait]
    impl Plugin for TestPlugin {
        fn id(&self) -> &PluginId {
            &self.id
        }
        fn manifest(&self) -> &PluginManifest {
            &self.manifest
        }
        fn state(&self) -> PluginState {
            self.state.clone()
        }
        async fn load(&mut self, _ctx: &PluginContext) -> PluginResult<()> {
            self.state = PluginState::Ready;
            Ok(())
        }
        async fn unload(&mut self) -> PluginResult<()> {
            self.state = PluginState::Unloaded;
            Ok(())
        }
        fn tools(&self) -> &[Box<dyn PluginTool>] {
            &self.tools
        }
    }

    /// Tool that returns its JSON arguments serialized as a string.
    struct EchoTool;

    #[async_trait::async_trait]
    impl PluginTool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }
        fn description(&self) -> &str {
            "Echoes the input"
        }
        fn input_schema(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "message": { "type": "string" }
                }
            })
        }
        async fn execute(
            &self,
            args: serde_json::Value,
            _ctx: &PluginToolContext,
        ) -> PluginResult<String> {
            Ok(args.to_string())
        }
    }

    /// Builds a registry containing `plugins`, panicking if any registration fails.
    fn registry_with(plugins: Vec<TestPlugin>) -> PluginRegistry {
        let mut registry = PluginRegistry::new();
        for plugin in plugins {
            registry
                .register(Box::new(plugin))
                .expect("test plugin IDs are unique");
        }
        registry
    }

    #[test]
    fn test_register_and_get() {
        let mut registry = PluginRegistry::new();
        assert!(registry.is_empty());

        registry
            .register(Box::new(TestPlugin::new("alpha")))
            .unwrap();
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());

        let id = PluginId::from_static("alpha");
        let plugin = registry.get(&id).expect("alpha should be registered");
        assert_eq!(plugin.id().as_str(), "alpha");
    }

    #[test]
    fn test_get_missing_returns_none() {
        let registry = registry_with(vec![TestPlugin::new("alpha")]);
        assert!(registry.get(&PluginId::from_static("beta")).is_none());
    }

    #[test]
    fn test_default_is_empty() {
        let registry = PluginRegistry::default();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_register_duplicate_fails() {
        let mut registry = registry_with(vec![TestPlugin::new("alpha")]);
        let result = registry.register(Box::new(TestPlugin::new("alpha")));
        assert!(matches!(result, Err(PluginError::AlreadyRegistered(_))));
        // The original registration must survive the rejected duplicate.
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_unregister() {
        let mut registry = registry_with(vec![TestPlugin::new("alpha")]);

        let id = PluginId::from_static("alpha");
        let plugin = registry.unregister(&id).unwrap();
        assert_eq!(plugin.id().as_str(), "alpha");
        assert!(registry.is_empty());
    }

    #[test]
    fn test_unregister_missing_fails() {
        let mut registry = PluginRegistry::new();
        let id = PluginId::from_static("missing");
        assert!(matches!(
            registry.unregister(&id),
            Err(PluginError::NotFound(_))
        ));
    }

    #[test]
    fn test_list_plugins() {
        let registry = registry_with(vec![TestPlugin::new("alpha"), TestPlugin::new("beta")]);

        let mut ids: Vec<&str> = registry.list().iter().map(|id| id.as_str()).collect();
        ids.sort_unstable();
        assert_eq!(ids, vec!["alpha", "beta"]);
    }

    #[test]
    fn test_find_tool() {
        let registry = registry_with(vec![TestPlugin::new("alpha")]);

        let (plugin, tool) = registry
            .find_tool("plugin:alpha:echo")
            .expect("echo tool should resolve");
        assert_eq!(plugin.id().as_str(), "alpha");
        assert_eq!(tool.name(), "echo");
    }

    #[test]
    fn test_find_tool_missing_plugin() {
        let registry = PluginRegistry::new();
        assert!(registry.find_tool("plugin:missing:echo").is_none());
    }

    #[test]
    fn test_find_tool_missing_tool() {
        let registry = registry_with(vec![TestPlugin::new("alpha")]);
        assert!(registry.find_tool("plugin:alpha:missing").is_none());
    }

    #[test]
    fn test_find_tool_extra_segments() {
        // Everything after the second ':' is the tool name, so "echo:extra" is not "echo".
        let registry = registry_with(vec![TestPlugin::new("alpha")]);
        assert!(registry.find_tool("plugin:alpha:echo:extra").is_none());
    }

    #[test]
    fn test_find_tool_invalid_format() {
        let registry = PluginRegistry::new();
        assert!(registry.find_tool("builtin:echo").is_none());
        assert!(registry.find_tool("echo").is_none());
        assert!(registry.find_tool("").is_none());
    }

    #[test]
    fn test_is_plugin_tool() {
        assert!(PluginRegistry::is_plugin_tool("plugin:alpha:echo"));
        assert!(PluginRegistry::is_plugin_tool("plugin:my-plugin:read-file"));
        assert!(!PluginRegistry::is_plugin_tool("read_file"));
        assert!(!PluginRegistry::is_plugin_tool("server:tool"));
        assert!(!PluginRegistry::is_plugin_tool("plugin:only-one-colon"));
    }

    #[test]
    fn test_all_tool_definitions() {
        let registry = registry_with(vec![
            TestPlugin::new("alpha"),
            TestPlugin::with_no_tools("beta"),
        ]);

        let defs = registry.all_tool_definitions();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "plugin:alpha:echo");
        assert_eq!(defs[0].description, "Echoes the input");
        assert_eq!(defs[0].input_schema, EchoTool.input_schema());
    }

    #[test]
    fn test_tool_definitions_resolve_via_find_tool() {
        let registry = registry_with(vec![TestPlugin::new("alpha"), TestPlugin::new("beta")]);

        let defs = registry.all_tool_definitions();
        assert_eq!(defs.len(), 2);
        for def in &defs {
            assert!(PluginRegistry::is_plugin_tool(&def.name));
            let (_, tool) = registry
                .find_tool(&def.name)
                .expect("every advertised tool name should resolve");
            assert_eq!(tool.description(), def.description);
        }
    }

    #[test]
    fn test_get_mut() {
        let mut registry = registry_with(vec![TestPlugin::new("alpha")]);

        let id = PluginId::from_static("alpha");
        let plugin = registry.get_mut(&id).expect("alpha should be registered");
        assert_eq!(plugin.id().as_str(), "alpha");
    }

    #[test]
    fn test_debug_impl() {
        let registry = registry_with(vec![TestPlugin::new("alpha")]);
        let debug = format!("{registry:?}");
        assert!(debug.contains("PluginRegistry"));
        assert!(debug.contains("plugin_count"));
    }
}