Skip to main content

roboticus_plugin_sdk/
registry.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use tokio::sync::Mutex;
7use tracing::{debug, warn};
8
9use roboticus_core::{RoboticusError, Result};
10
11use crate::{Plugin, PluginStatus, ToolDef, ToolResult};
12
/// Policy for tool permissions that a plugin declares but that are not listed
/// in `allowed`.
///
/// In strict mode such a permission is a hard error. The plugin is rejected at
/// registration, and a call to the tool is rejected at execution time. In
/// permissive mode the call proceeds and a warning is logged. Permissions are
/// compared ASCII case-insensitively.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PermissionPolicy {
    /// When `true`, disallowed permissions are errors instead of warnings.
    pub strict: bool,
    /// Permissions granted to plugins; `PluginRegistry::new` lowercases them.
    pub allowed: Vec<String>,
}
18
/// Serializable summary of one registered plugin, as produced by
/// `PluginRegistry::list_plugins`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginInfo {
    /// Name reported by the plugin's `name()`.
    pub name: String,
    /// Version string reported by the plugin's `version()`.
    pub version: String,
    /// Lifecycle status tracked by the registry when the listing was taken.
    pub status: PluginStatus,
    /// Tool definitions reported by the plugin's `tools()`.
    pub tools: Vec<ToolDef>,
}
26
27struct PluginEntry {
28    plugin: Arc<tokio::sync::Mutex<Box<dyn Plugin>>>,
29    status: PluginStatus,
30}
31
32/// Central registry that owns all loaded plugin instances.
33///
34/// ## Lock acquisition pattern
35///
36/// This registry uses a two-level locking scheme:
37///
38/// 1. **Outer lock** (`self.plugins`): A `tokio::sync::Mutex<HashMap<...>>` that
39///    guards the plugin map itself (registration, removal, iteration).
40/// 2. **Inner lock** (each `PluginEntry::plugin`): A per-plugin
41///    `Arc<tokio::sync::Mutex<Box<dyn Plugin>>>` that guards access to individual
42///    plugin instances.
43///
44/// Several methods (e.g., `execute_tool`, `find_tool`, `list_plugins`,
45/// `list_all_tools`) acquire the outer lock and then, while still holding it,
46/// acquire one or more inner plugin locks. This nested acquisition is safe from
47/// deadlocks because the inner locks are never held when attempting to acquire
48/// the outer lock. However, it means that a slow plugin `init()` or
49/// `execute_tool()` call can block all other registry operations for the
50/// duration.
51///
52/// `tools()` on the `Plugin` trait is expected to be non-blocking (it returns a
53/// `Vec<ToolDef>` synchronously) so the inner lock contention during tool
54/// lookups should be negligible. If plugin execution latency becomes a concern,
55/// consider cloning the `Arc` outside the outer lock and releasing the outer
56/// lock before awaiting the inner one (as `execute_tool` already does).
57pub struct PluginRegistry {
58    plugins: Mutex<HashMap<String, PluginEntry>>,
59    allow_list: Vec<String>,
60    deny_list: Vec<String>,
61    permission_policy: PermissionPolicy,
62}
63
64impl PluginRegistry {
65    pub fn new(
66        allow_list: Vec<String>,
67        deny_list: Vec<String>,
68        permission_policy: PermissionPolicy,
69    ) -> Self {
70        let normalized_allowed: Vec<String> = permission_policy
71            .allowed
72            .into_iter()
73            .map(|p| p.to_ascii_lowercase())
74            .collect();
75        Self {
76            plugins: Mutex::new(HashMap::new()),
77            allow_list,
78            deny_list,
79            permission_policy: PermissionPolicy {
80                strict: permission_policy.strict,
81                allowed: normalized_allowed,
82            },
83        }
84    }
85
86    pub fn is_allowed(&self, name: &str) -> bool {
87        if self.deny_list.iter().any(|d| d == name) {
88            return false;
89        }
90        if self.allow_list.is_empty() {
91            return true;
92        }
93        self.allow_list.iter().any(|a| a == name)
94    }
95
96    pub async fn register(&self, plugin: Box<dyn Plugin>) -> Result<()> {
97        let name = plugin.name().to_string();
98
99        if !self.is_allowed(&name) {
100            return Err(RoboticusError::Config(format!(
101                "plugin '{name}' is not allowed by policy"
102            )));
103        }
104
105        if self.permission_policy.strict {
106            for tool in plugin.tools() {
107                for perm in tool.permissions {
108                    let normalized = perm.to_ascii_lowercase();
109                    if !self
110                        .permission_policy
111                        .allowed
112                        .iter()
113                        .any(|p| p == &normalized)
114                    {
115                        return Err(RoboticusError::Config(format!(
116                            "plugin '{name}' tool '{}' declares permission '{perm}' not in allowed_permissions",
117                            tool.name
118                        )));
119                    }
120                }
121            }
122        }
123
124        debug!(name = %name, version = %plugin.version(), "registering plugin");
125
126        let entry = PluginEntry {
127            plugin: Arc::new(tokio::sync::Mutex::new(plugin)),
128            status: PluginStatus::Loaded,
129        };
130
131        let mut plugins = self.plugins.lock().await;
132        plugins.insert(name, entry);
133        Ok(())
134    }
135
136    pub async fn init_all(&self) -> Vec<String> {
137        let mut errors = Vec::new();
138        let mut plugins = self.plugins.lock().await;
139
140        for (name, entry) in plugins.iter_mut() {
141            let mut plugin = entry.plugin.lock().await;
142            match plugin.init().await {
143                Ok(()) => {
144                    entry.status = PluginStatus::Active;
145                    debug!(name = %name, "plugin initialized");
146                }
147                Err(e) => {
148                    entry.status = PluginStatus::Error;
149                    warn!(name = %name, error = %e, "plugin init failed");
150                    errors.push(format!("{name}: {e}"));
151                }
152            }
153        }
154
155        errors
156    }
157
158    pub async fn execute_tool(&self, tool_name: &str, input: &Value) -> Result<ToolResult> {
159        let plugin_arc = {
160            let plugins = self.plugins.lock().await;
161            let mut found = None;
162            for entry in plugins.values() {
163                if entry.status != PluginStatus::Active {
164                    continue;
165                }
166                let p = entry.plugin.lock().await;
167                if p.tools().iter().any(|t| t.name == tool_name) {
168                    drop(p);
169                    found = Some(Arc::clone(&entry.plugin));
170                    break;
171                }
172            }
173            found
174        };
175
176        let plugin_arc = match plugin_arc {
177            Some(p) => p,
178            None => {
179                return Err(RoboticusError::Tool {
180                    tool: tool_name.to_string(),
181                    message: "no plugin provides this tool".into(),
182                });
183            }
184        };
185
186        // Check permission policy before executing.
187        let plugin = plugin_arc.lock().await;
188        let tool_permissions: Vec<String> = plugin
189            .tools()
190            .iter()
191            .find(|t| t.name == tool_name)
192            .map(|t| t.permissions.clone())
193            .unwrap_or_default();
194
195        for perm in &tool_permissions {
196            let normalized = perm.to_ascii_lowercase();
197            if !self
198                .permission_policy
199                .allowed
200                .iter()
201                .any(|p| p == &normalized)
202            {
203                if self.permission_policy.strict {
204                    return Err(RoboticusError::Tool {
205                        tool: tool_name.to_string(),
206                        message: format!(
207                            "permission '{perm}' is not allowed by policy (strict mode)"
208                        ),
209                    });
210                }
211                warn!(
212                    tool = %tool_name,
213                    permission = %perm,
214                    "tool requires permission not in allowed list (permissive mode)"
215                );
216            }
217        }
218
219        plugin.execute_tool(tool_name, input).await
220    }
221
222    pub async fn shutdown_all(&self) {
223        let mut plugins = self.plugins.lock().await;
224        for (name, entry) in plugins.iter_mut() {
225            let mut plugin = entry.plugin.lock().await;
226            if let Err(e) = plugin.shutdown().await {
227                warn!(name = %name, error = %e, "plugin shutdown failed");
228            }
229            entry.status = PluginStatus::Disabled;
230        }
231    }
232
233    pub async fn find_tool(&self, tool_name: &str) -> Option<(String, ToolDef)> {
234        let plugins = self.plugins.lock().await;
235        for (plugin_name, entry) in plugins.iter() {
236            if entry.status != PluginStatus::Active {
237                continue;
238            }
239            let plugin = entry.plugin.lock().await;
240            for tool in plugin.tools() {
241                if tool.name == tool_name {
242                    return Some((plugin_name.clone(), tool));
243                }
244            }
245        }
246        None
247    }
248
249    pub async fn list_plugins(&self) -> Vec<PluginInfo> {
250        let plugins = self.plugins.lock().await;
251        let mut result = Vec::new();
252        for entry in plugins.values() {
253            let plugin = entry.plugin.lock().await;
254            result.push(PluginInfo {
255                name: plugin.name().to_string(),
256                version: plugin.version().to_string(),
257                status: entry.status,
258                tools: plugin.tools(),
259            });
260        }
261        result
262    }
263
264    pub async fn list_all_tools(&self) -> Vec<(String, ToolDef)> {
265        let plugins = self.plugins.lock().await;
266        let mut tools = Vec::new();
267        for (name, entry) in plugins.iter() {
268            if entry.status != PluginStatus::Active {
269                continue;
270            }
271            let plugin = entry.plugin.lock().await;
272            for tool in plugin.tools() {
273                tools.push((name.clone(), tool));
274            }
275        }
276        tools
277    }
278
279    pub async fn disable_plugin(&self, name: &str) -> Result<()> {
280        let mut plugins = self.plugins.lock().await;
281        let entry = plugins
282            .get_mut(name)
283            .ok_or_else(|| RoboticusError::Config(format!("plugin '{name}' not found")))?;
284        entry.status = PluginStatus::Disabled;
285        debug!(name, "plugin disabled");
286        Ok(())
287    }
288
289    pub async fn enable_plugin(&self, name: &str) -> Result<()> {
290        let mut plugins = self.plugins.lock().await;
291        let entry = plugins
292            .get_mut(name)
293            .ok_or_else(|| RoboticusError::Config(format!("plugin '{name}' not found")))?;
294        entry.status = PluginStatus::Active;
295        debug!(name, "plugin enabled");
296        Ok(())
297    }
298
299    /// Removes a plugin from the registry entirely, shutting it down first.
300    ///
301    /// Unlike `disable_plugin` (which keeps the entry around so it can be
302    /// re-enabled), `unregister` drops the plugin and frees all associated
303    /// resources. This should be used when a plugin is permanently removed
304    /// (e.g., uninstalled or revoked by policy).
305    pub async fn unregister(&self, name: &str) -> Result<()> {
306        let mut plugins = self.plugins.lock().await;
307        let entry = plugins
308            .remove(name)
309            .ok_or_else(|| RoboticusError::Config(format!("plugin '{name}' not found")))?;
310        // Best-effort shutdown -- log but do not propagate errors.
311        let mut plugin = entry.plugin.lock().await;
312        if let Err(e) = plugin.shutdown().await {
313            warn!(name, error = %e, "plugin shutdown failed during unregister");
314        }
315        debug!(name, "plugin unregistered");
316        Ok(())
317    }
318
319    pub async fn plugin_count(&self) -> usize {
320        let plugins = self.plugins.lock().await;
321        plugins.len()
322    }
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Permissive policy that grants no permissions; used by most tests.
    fn permissive_policy() -> PermissionPolicy {
        PermissionPolicy {
            strict: false,
            allowed: vec![],
        }
    }

    /// Registry with empty allow/deny lists and the given permission policy.
    fn registry_with(strict: bool, allowed: &[&str]) -> PluginRegistry {
        PluginRegistry::new(
            vec![],
            vec![],
            PermissionPolicy {
                strict,
                allowed: allowed.iter().map(|p| (*p).to_string()).collect(),
            },
        )
    }

    /// Registry with empty allow/deny lists and a permissive, empty policy.
    fn permissive_registry() -> PluginRegistry {
        PluginRegistry::new(vec![], vec![], permissive_policy())
    }

    struct MockPlugin {
        name: String,
        init_fail: bool,
    }

    impl MockPlugin {
        fn new(name: &str) -> Self {
            Self {
                name: name.into(),
                init_fail: false,
            }
        }
        fn failing(name: &str) -> Self {
            Self {
                name: name.into(),
                init_fail: true,
            }
        }
    }

    #[async_trait]
    impl Plugin for MockPlugin {
        fn name(&self) -> &str {
            &self.name
        }
        fn version(&self) -> &str {
            "1.0.0"
        }
        fn tools(&self) -> Vec<ToolDef> {
            vec![ToolDef {
                name: format!("{}_tool", self.name),
                description: "mock tool".into(),
                parameters: serde_json::json!({}),
                risk_level: roboticus_core::RiskLevel::Safe,
                permissions: vec![],
            }]
        }
        async fn init(&mut self) -> Result<()> {
            if self.init_fail {
                Err(RoboticusError::Config("init failed".into()))
            } else {
                Ok(())
            }
        }
        async fn execute_tool(&self, tool_name: &str, _input: &Value) -> Result<ToolResult> {
            Ok(ToolResult {
                success: true,
                output: format!("executed {tool_name}"),
                metadata: None,
            })
        }
        async fn shutdown(&mut self) -> Result<()> {
            Ok(())
        }
    }

    #[test]
    fn allow_deny_lists() {
        let reg = PluginRegistry::new(vec![], vec!["blocked".into()], permissive_policy());
        assert!(reg.is_allowed("anything"));
        assert!(!reg.is_allowed("blocked"));

        let reg2 = PluginRegistry::new(vec!["only_this".into()], vec![], permissive_policy());
        assert!(reg2.is_allowed("only_this"));
        assert!(!reg2.is_allowed("other"));
    }

    #[test]
    fn deny_list_takes_precedence_over_allow_list() {
        let reg = PluginRegistry::new(
            vec!["dual".into()],
            vec!["dual".into()],
            permissive_policy(),
        );
        assert!(!reg.is_allowed("dual"));
    }

    #[tokio::test]
    async fn register_and_list() {
        let reg = permissive_registry();
        reg.register(Box::new(MockPlugin::new("test")))
            .await
            .unwrap();
        assert_eq!(reg.plugin_count().await, 1);
        let plugins = reg.list_plugins().await;
        assert_eq!(plugins[0].name, "test");
        assert_eq!(plugins[0].status, PluginStatus::Loaded);
    }

    #[tokio::test]
    async fn register_denied_fails() {
        let reg = PluginRegistry::new(vec![], vec!["bad".into()], permissive_policy());
        let result = reg.register(Box::new(MockPlugin::new("bad"))).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn init_all_activates() {
        let reg = permissive_registry();
        reg.register(Box::new(MockPlugin::new("p1"))).await.unwrap();
        let errors = reg.init_all().await;
        assert!(errors.is_empty());
        let plugins = reg.list_plugins().await;
        assert_eq!(plugins[0].status, PluginStatus::Active);
    }

    #[tokio::test]
    async fn init_failure_marks_error() {
        let reg = permissive_registry();
        reg.register(Box::new(MockPlugin::failing("bad")))
            .await
            .unwrap();
        let errors = reg.init_all().await;
        assert_eq!(errors.len(), 1);
        let plugins = reg.list_plugins().await;
        assert_eq!(plugins[0].status, PluginStatus::Error);
    }

    #[tokio::test]
    async fn execute_tool_found() {
        let reg = permissive_registry();
        reg.register(Box::new(MockPlugin::new("p1"))).await.unwrap();
        reg.init_all().await;
        let result = reg
            .execute_tool("p1_tool", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("p1_tool"));
    }

    #[tokio::test]
    async fn execute_tool_not_found() {
        let reg = permissive_registry();
        let result = reg
            .execute_tool("nonexistent", &serde_json::json!({}))
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn find_tool() {
        let reg = permissive_registry();
        reg.register(Box::new(MockPlugin::new("alpha")))
            .await
            .unwrap();
        reg.init_all().await;
        let found = reg.find_tool("alpha_tool").await;
        assert!(found.is_some());
        let (plugin_name, tool) = found.unwrap();
        assert_eq!(plugin_name, "alpha");
        assert_eq!(tool.name, "alpha_tool");
    }

    #[tokio::test]
    async fn disable_enable_plugin() {
        let reg = permissive_registry();
        reg.register(Box::new(MockPlugin::new("p"))).await.unwrap();
        reg.init_all().await;

        reg.disable_plugin("p").await.unwrap();
        let plugins = reg.list_plugins().await;
        assert_eq!(plugins[0].status, PluginStatus::Disabled);

        let result = reg.execute_tool("p_tool", &serde_json::json!({})).await;
        assert!(result.is_err());

        reg.enable_plugin("p").await.unwrap();
        let result = reg.execute_tool("p_tool", &serde_json::json!({})).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn enable_disable_unknown_plugin_fails() {
        let reg = permissive_registry();
        assert!(reg.disable_plugin("ghost").await.is_err());
        assert!(reg.enable_plugin("ghost").await.is_err());
    }

    #[tokio::test]
    async fn shutdown_all_disables_plugins() {
        let reg = permissive_registry();
        reg.register(Box::new(MockPlugin::new("p"))).await.unwrap();
        reg.init_all().await;

        reg.shutdown_all().await;
        let plugins = reg.list_plugins().await;
        assert_eq!(plugins[0].status, PluginStatus::Disabled);
        let result = reg.execute_tool("p_tool", &serde_json::json!({})).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn unregister_removes_plugin() {
        let reg = permissive_registry();
        reg.register(Box::new(MockPlugin::new("removable")))
            .await
            .unwrap();
        assert_eq!(reg.plugin_count().await, 1);

        reg.unregister("removable").await.unwrap();
        assert_eq!(reg.plugin_count().await, 0);
    }

    #[tokio::test]
    async fn unregister_nonexistent_fails() {
        let reg = permissive_registry();
        let result = reg.unregister("ghost").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn unregister_makes_tool_unavailable() {
        let reg = permissive_registry();
        reg.register(Box::new(MockPlugin::new("p1"))).await.unwrap();
        reg.init_all().await;

        // Tool should be available before unregister.
        assert!(reg.find_tool("p1_tool").await.is_some());

        reg.unregister("p1").await.unwrap();

        // Tool should be gone after unregister.
        assert!(reg.find_tool("p1_tool").await.is_none());
        let result = reg.execute_tool("p1_tool", &serde_json::json!({})).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn list_all_tools() {
        let reg = permissive_registry();
        reg.register(Box::new(MockPlugin::new("a"))).await.unwrap();
        reg.register(Box::new(MockPlugin::new("b"))).await.unwrap();
        reg.init_all().await;
        let tools = reg.list_all_tools().await;
        assert_eq!(tools.len(), 2);
    }

    #[tokio::test]
    async fn list_all_tools_skips_inactive_plugins() {
        let reg = permissive_registry();
        reg.register(Box::new(MockPlugin::new("a"))).await.unwrap();
        reg.register(Box::new(MockPlugin::new("b"))).await.unwrap();
        reg.init_all().await;
        reg.disable_plugin("b").await.unwrap();

        let tools = reg.list_all_tools().await;
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].0, "a");
        // list_plugins still reports every registered plugin.
        assert_eq!(reg.list_plugins().await.len(), 2);
    }

    /// A mock plugin whose tool declares specific permissions.
    struct PermissionMockPlugin {
        name: String,
        permissions: Vec<String>,
    }

    impl PermissionMockPlugin {
        fn new(name: &str, permissions: Vec<String>) -> Self {
            Self {
                name: name.into(),
                permissions,
            }
        }
    }

    #[async_trait]
    impl Plugin for PermissionMockPlugin {
        fn name(&self) -> &str {
            &self.name
        }
        fn version(&self) -> &str {
            "1.0.0"
        }
        fn tools(&self) -> Vec<ToolDef> {
            vec![ToolDef {
                name: format!("{}_tool", self.name),
                description: "mock tool with permissions".into(),
                parameters: serde_json::json!({}),
                risk_level: roboticus_core::RiskLevel::Safe,
                permissions: self.permissions.clone(),
            }]
        }
        async fn init(&mut self) -> Result<()> {
            Ok(())
        }
        async fn execute_tool(&self, tool_name: &str, _input: &Value) -> Result<ToolResult> {
            Ok(ToolResult {
                success: true,
                output: format!("executed {tool_name}"),
                metadata: None,
            })
        }
        async fn shutdown(&mut self) -> Result<()> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn strict_mode_blocks_unauthorized_plugin() {
        let reg = registry_with(true, &[]);
        // Strict mode rejects unauthorized plugins at registration time (fail-fast).
        let result = reg
            .register(Box::new(PermissionMockPlugin::new(
                "net",
                vec!["network".into()],
            )))
            .await;
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(err_msg.contains("permission"));
    }

    #[tokio::test]
    async fn permissive_mode_allows_unauthorized_plugin() {
        let reg = registry_with(false, &[]);
        reg.register(Box::new(PermissionMockPlugin::new(
            "net",
            vec!["network".into()],
        )))
        .await
        .unwrap();
        reg.init_all().await;

        let result = reg.execute_tool("net_tool", &serde_json::json!({})).await;
        assert!(result.is_ok());
        assert!(result.unwrap().success);
    }

    #[tokio::test]
    async fn allowed_permissions_pass_strict_check() {
        let reg = registry_with(true, &["network"]);
        reg.register(Box::new(PermissionMockPlugin::new(
            "net",
            vec!["network".into()],
        )))
        .await
        .unwrap();
        reg.init_all().await;

        let result = reg.execute_tool("net_tool", &serde_json::json!({})).await;
        assert!(result.is_ok());
        assert!(result.unwrap().success);
    }

    #[tokio::test]
    async fn permission_matching_is_case_insensitive() {
        let reg = registry_with(true, &["Network"]);
        reg.register(Box::new(PermissionMockPlugin::new(
            "net",
            vec!["NETWORK".into()],
        )))
        .await
        .unwrap();
        reg.init_all().await;

        let result = reg.execute_tool("net_tool", &serde_json::json!({})).await;
        assert!(result.unwrap().success);
    }
}