Skip to main content

roboticus_agent/
capability.rs

//! Unified capability registry — a catalog of invocable tools with metadata and dispatch.
//!
//! Runtime tools live in [`crate::tools::ToolRegistry`]. [`CapabilityRegistry`] mirrors that
//! registry for LLM schema export and optional capability-aware execution;
//! [`CapabilityRegistry::sync_from_tool_registry`] is the single path that copies tools into it.
6
7use std::collections::HashMap;
8use std::fmt;
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use tokio::sync::RwLock;
15
16use roboticus_core::RiskLevel;
17use roboticus_core::config::McpTransport;
18
19use crate::tools::{ToolContext, ToolError, ToolRegistry, ToolResult};
20
21/// Where a capability was registered from (built-in binary vs plugin bridge vs MCP server).
22#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
23pub enum CapabilitySource {
24    BuiltIn,
25    Plugin(String),
26    Mcp {
27        server: String,
28        transport: McpTransport,
29    },
30}
31
32impl fmt::Display for CapabilitySource {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        match self {
35            Self::BuiltIn => write!(f, "built-in"),
36            Self::Plugin(p) => write!(f, "plugin:{p}"),
37            Self::Mcp { server, transport } => {
38                let t = match transport {
39                    McpTransport::Stdio => "stdio",
40                    McpTransport::Sse => "sse",
41                    McpTransport::Http => "http",
42                    McpTransport::WebSocket => "ws",
43                };
44                write!(f, "mcp:{server}({t})")
45            }
46        }
47    }
48}
49
50/// Executable capability surface (tool) visible to policy and the LLM catalog.
51#[async_trait]
52pub trait Capability: Send + Sync {
53    fn name(&self) -> &str;
54    fn description(&self) -> &str;
55    fn risk_level(&self) -> RiskLevel;
56    fn parameters_schema(&self) -> Value;
57    fn source(&self) -> CapabilitySource;
58
59    /// Optional companion skill id/path for capability discovery (e.g. plugin `paired_skill`).
60    fn paired_skill(&self) -> Option<&str> {
61        None
62    }
63
64    async fn execute(&self, params: Value, ctx: &ToolContext) -> Result<ToolResult, ToolError>;
65}
66
67/// Serializable summary for admin/API and LLM tool list builders.
68#[derive(Debug, Clone, Serialize)]
69pub struct CapabilitySummary {
70    pub name: String,
71    pub description: String,
72    pub source: CapabilitySource,
73    pub paired_skill: Option<String>,
74    pub risk_level: RiskLevel,
75    pub parameters_schema: Value,
76}
77
78#[derive(Debug)]
79pub enum RegistrationError {
80    NameConflict {
81        name: String,
82        existing_source: CapabilitySource,
83    },
84    InvalidMetadata(String),
85}
86
87impl fmt::Display for RegistrationError {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        match self {
90            Self::NameConflict {
91                name,
92                existing_source,
93            } => write!(
94                f,
95                "capability name conflict: '{name}' already registered ({existing_source})"
96            ),
97            Self::InvalidMetadata(m) => write!(f, "invalid capability metadata: {m}"),
98        }
99    }
100}
101
102impl std::error::Error for RegistrationError {}
103
104/// Holds all runtime capabilities keyed by tool name.
105pub struct CapabilityRegistry {
106    capabilities: RwLock<HashMap<String, Arc<dyn Capability>>>,
107}
108
109impl Default for CapabilityRegistry {
110    fn default() -> Self {
111        Self::new()
112    }
113}
114
115impl CapabilityRegistry {
116    pub fn new() -> Self {
117        Self {
118            capabilities: RwLock::new(HashMap::new()),
119        }
120    }
121
122    pub async fn is_empty(&self) -> bool {
123        self.capabilities.read().await.is_empty()
124    }
125
126    pub async fn register(&self, cap: Arc<dyn Capability>) -> Result<(), RegistrationError> {
127        let name = cap.name().to_string();
128        if name.is_empty() {
129            return Err(RegistrationError::InvalidMetadata(
130                "capability name is empty".into(),
131            ));
132        }
133        if cap.description().is_empty() {
134            return Err(RegistrationError::InvalidMetadata(
135                "capability description is empty".into(),
136            ));
137        }
138
139        let has_separator = name.contains("::");
140        let is_mcp = matches!(cap.source(), CapabilitySource::Mcp { .. });
141        if is_mcp && !has_separator {
142            return Err(RegistrationError::InvalidMetadata(format!(
143                "MCP capability '{name}' must use '::' separator (e.g., 'server::tool_name')"
144            )));
145        }
146        if !is_mcp && has_separator {
147            return Err(RegistrationError::InvalidMetadata(format!(
148                "non-MCP capability '{name}' must not use '::' separator (reserved for MCP)"
149            )));
150        }
151
152        let mut caps = self.capabilities.write().await;
153        if let Some(existing) = caps.get(&name)
154            && existing.source() != cap.source()
155        {
156            return Err(RegistrationError::NameConflict {
157                name,
158                existing_source: existing.source(),
159            });
160        }
161        caps.insert(name, cap);
162        Ok(())
163    }
164
165    pub async fn register_all(
166        &self,
167        capabilities: Vec<Arc<dyn Capability>>,
168    ) -> Vec<(String, RegistrationError)> {
169        let mut errors = Vec::new();
170        for cap in capabilities {
171            let name = cap.name().to_string();
172            if let Err(e) = self.register(cap).await {
173                errors.push((name, e));
174            }
175        }
176        errors
177    }
178
179    pub async fn get(&self, name: &str) -> Option<Arc<dyn Capability>> {
180        self.capabilities.read().await.get(name).cloned()
181    }
182
183    pub async fn catalog(&self) -> Vec<CapabilitySummary> {
184        let mut out: Vec<_> = self
185            .capabilities
186            .read()
187            .await
188            .values()
189            .map(|c| CapabilitySummary {
190                name: c.name().to_string(),
191                description: c.description().to_string(),
192                source: c.source(),
193                paired_skill: c.paired_skill().map(String::from),
194                risk_level: c.risk_level(),
195                parameters_schema: c.parameters_schema(),
196            })
197            .collect();
198        out.sort_by(|a, b| a.name.cmp(&b.name));
199        out
200    }
201
202    pub async fn list_names(&self) -> Vec<String> {
203        let mut names: Vec<_> = self.capabilities.read().await.keys().cloned().collect();
204        names.sort();
205        names
206    }
207
208    /// Remove all capabilities previously attributed to `plugin_name`, then register `new_capabilities`.
209    ///
210    /// Holds the write lock for the entire operation so the pipeline never
211    /// sees a partial tool set during hot-reload.
212    pub async fn reload_plugin(
213        &self,
214        plugin_name: &str,
215        new_capabilities: Vec<Arc<dyn Capability>>,
216    ) -> Vec<(String, RegistrationError)> {
217        let target = CapabilitySource::Plugin(plugin_name.to_string());
218
219        // Validate all incoming capabilities before acquiring the lock so we
220        // never leave the registry in a half-replaced state.
221        let mut errors: Vec<(String, RegistrationError)> = Vec::new();
222        let mut valid: Vec<Arc<dyn Capability>> = Vec::new();
223        for cap in new_capabilities {
224            let name = cap.name().to_string();
225            if name.is_empty() {
226                errors.push((
227                    name,
228                    RegistrationError::InvalidMetadata("capability name is empty".into()),
229                ));
230                continue;
231            }
232            if cap.description().is_empty() {
233                errors.push((
234                    name,
235                    RegistrationError::InvalidMetadata("capability description is empty".into()),
236                ));
237                continue;
238            }
239            if cap.name().contains("::") {
240                errors.push((
241                    name,
242                    RegistrationError::InvalidMetadata(format!(
243                        "non-MCP capability '{}' must not use '::' separator (reserved for MCP)",
244                        cap.name()
245                    )),
246                ));
247                continue;
248            }
249            valid.push(cap);
250        }
251
252        // Hold the write lock for the entire remove-then-insert cycle to
253        // prevent the pipeline from observing a partial (empty) tool set.
254        let mut caps = self.capabilities.write().await;
255        caps.retain(|_, c| c.source() != target);
256        for cap in valid {
257            let name = cap.name().to_string();
258            // Plugin tools must not conflict with a different source.
259            if let Some(existing) = caps.get(&name)
260                && existing.source() != cap.source()
261            {
262                errors.push((
263                    name,
264                    RegistrationError::NameConflict {
265                        name: cap.name().to_string(),
266                        existing_source: existing.source(),
267                    },
268                ));
269                continue;
270            }
271            caps.insert(name, cap);
272        }
273        drop(caps);
274
275        errors
276    }
277
278    /// Atomically replace all capabilities from a specific MCP server.
279    ///
280    /// Holds the write lock for the entire operation so the pipeline never
281    /// sees a partial tool set during hot-reload.
282    pub async fn reload_mcp_server(
283        &self,
284        server_name: &str,
285        new_capabilities: Vec<Arc<dyn Capability>>,
286    ) -> Result<(), RegistrationError> {
287        // Validate all new capabilities before taking the lock so we never
288        // leave the registry in a half-replaced state.
289        for cap in &new_capabilities {
290            if cap.name().is_empty() {
291                return Err(RegistrationError::InvalidMetadata(
292                    "capability name is empty".into(),
293                ));
294            }
295            if cap.description().is_empty() {
296                return Err(RegistrationError::InvalidMetadata(
297                    "capability description is empty".into(),
298                ));
299            }
300            if !cap.name().contains("::") {
301                return Err(RegistrationError::InvalidMetadata(format!(
302                    "MCP capability '{}' must use '::' separator",
303                    cap.name()
304                )));
305            }
306        }
307
308        let mut caps = self.capabilities.write().await;
309
310        // Remove all existing capabilities from this server.
311        caps.retain(|_, existing| {
312            !matches!(existing.source(), CapabilitySource::Mcp { server, .. } if server == server_name)
313        });
314
315        // Insert all new capabilities atomically under the same lock.
316        for cap in new_capabilities {
317            let name = cap.name().to_string();
318            caps.insert(name, cap);
319        }
320
321        Ok(())
322    }
323
324    /// Replace the catalog with one entry per tool in `registry` (stable name order).
325    pub async fn sync_from_tool_registry(&self, registry: Arc<ToolRegistry>) -> Result<(), String> {
326        let mut caps = self.capabilities.write().await;
327        caps.clear();
328        drop(caps);
329
330        let mut tools: Vec<_> = registry.list();
331        tools.sort_by_key(|t| t.name());
332        let mut errors = Vec::new();
333        for tool in tools {
334            let name = tool.name().to_string();
335            let source = match tool.plugin_owner() {
336                Some(p) => CapabilitySource::Plugin(p.to_string()),
337                None => CapabilitySource::BuiltIn,
338            };
339            let cap = Arc::new(ToolRegistryCapability {
340                registry: Arc::clone(&registry),
341                name,
342                source,
343            });
344            if let Err(e) = self.register(cap).await {
345                errors.push(e.to_string());
346            }
347        }
348        if errors.is_empty() {
349            Ok(())
350        } else {
351            Err(format!(
352                "capability sync partially failed ({} error(s)): {}",
353                errors.len(),
354                errors.join("; ")
355            ))
356        }
357    }
358
359    /// Rebuild capabilities from tools (e.g. after hot-loading plugins into `ToolRegistry`).
360    pub async fn resync_tools(&self, registry: Arc<ToolRegistry>) -> Result<(), String> {
361        self.sync_from_tool_registry(registry).await
362    }
363}
364
365/// [`Capability`] backed by name lookup in a [`ToolRegistry`].
366pub struct ToolRegistryCapability {
367    registry: Arc<ToolRegistry>,
368    name: String,
369    source: CapabilitySource,
370}
371
372#[async_trait]
373impl Capability for ToolRegistryCapability {
374    fn name(&self) -> &str {
375        &self.name
376    }
377
378    fn description(&self) -> &str {
379        self.registry
380            .get(&self.name)
381            .map(|t| t.description())
382            .unwrap_or("")
383    }
384
385    fn risk_level(&self) -> RiskLevel {
386        self.registry
387            .get(&self.name)
388            .map(|t| t.risk_level())
389            .unwrap_or(RiskLevel::Forbidden)
390    }
391
392    fn parameters_schema(&self) -> Value {
393        self.registry
394            .get(&self.name)
395            .map(|t| t.parameters_schema())
396            .unwrap_or_else(|| serde_json::json!({"type": "object"}))
397    }
398
399    fn source(&self) -> CapabilitySource {
400        self.source.clone()
401    }
402
403    fn paired_skill(&self) -> Option<&str> {
404        self.registry.get(&self.name).and_then(|t| t.paired_skill())
405    }
406
407    async fn execute(&self, params: Value, ctx: &ToolContext) -> Result<ToolResult, ToolError> {
408        let tool = self.registry.get(&self.name).ok_or_else(|| ToolError {
409            message: format!("tool '{}' not found in ToolRegistry", self.name),
410        })?;
411        tool.execute(params, ctx).await
412    }
413}
414
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::ToolRegistry;

    #[tokio::test]
    async fn sync_populates_catalog() {
        use crate::tools::EchoTool;

        let mut reg = ToolRegistry::new();
        reg.register(Box::new(EchoTool));
        let reg = Arc::new(reg);
        let caps = CapabilityRegistry::new();
        caps.sync_from_tool_registry(Arc::clone(&reg))
            .await
            .unwrap();
        assert!(!caps.is_empty().await);
        let names = caps.list_names().await;
        assert!(names.iter().any(|n| n == "echo"));
    }

    // CapabilitySource::Mcp display formatting

    #[test]
    fn mcp_source_display_stdio() {
        let source = CapabilitySource::Mcp {
            server: "github".into(),
            transport: McpTransport::Stdio,
        };
        assert_eq!(source.to_string(), "mcp:github(stdio)");
    }

    #[test]
    fn mcp_source_display_sse() {
        let source = CapabilitySource::Mcp {
            server: "linear".into(),
            transport: McpTransport::Sse,
        };
        assert_eq!(source.to_string(), "mcp:linear(sse)");
    }

    #[test]
    fn mcp_source_display_http() {
        let source = CapabilitySource::Mcp {
            server: "sentry".into(),
            transport: McpTransport::Http,
        };
        assert_eq!(source.to_string(), "mcp:sentry(http)");
    }

    #[test]
    fn mcp_source_display_websocket() {
        let source = CapabilitySource::Mcp {
            server: "relay".into(),
            transport: McpTransport::WebSocket,
        };
        assert_eq!(source.to_string(), "mcp:relay(ws)");
    }

    // Registration rules: metadata validation, `::` separator, name conflicts

    /// A minimal capability stub for testing registration rules.
    struct StubCap {
        name: String,
        source: CapabilitySource,
    }

    #[async_trait::async_trait]
    impl Capability for StubCap {
        fn name(&self) -> &str {
            &self.name
        }
        fn description(&self) -> &str {
            "stub"
        }
        fn risk_level(&self) -> roboticus_core::RiskLevel {
            roboticus_core::RiskLevel::Safe
        }
        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({"type": "object"})
        }
        fn source(&self) -> CapabilitySource {
            self.source.clone()
        }
        async fn execute(
            &self,
            _params: serde_json::Value,
            _ctx: &crate::tools::ToolContext,
        ) -> Result<crate::tools::ToolResult, crate::tools::ToolError> {
            Ok(crate::tools::ToolResult {
                output: "stub".into(),
                metadata: None,
            })
        }
    }

    /// Shorthand for building a [`StubCap`] with the given name and source.
    fn stub_cap(name: &str, source: CapabilitySource) -> Arc<StubCap> {
        Arc::new(StubCap {
            name: name.into(),
            source,
        })
    }

    #[tokio::test]
    async fn register_rejects_builtin_with_separator() {
        let reg = CapabilityRegistry::new();
        let cap = Arc::new(StubCap {
            name: "ns::tool".into(),
            source: CapabilitySource::BuiltIn,
        });
        let err = reg.register(cap).await.unwrap_err();
        assert!(
            matches!(err, RegistrationError::InvalidMetadata(_)),
            "expected InvalidMetadata, got: {err}"
        );
        assert!(err.to_string().contains("reserved for MCP"));
    }

    #[tokio::test]
    async fn register_rejects_plugin_with_separator() {
        let reg = CapabilityRegistry::new();
        let cap = Arc::new(StubCap {
            name: "ns::tool".into(),
            source: CapabilitySource::Plugin("myplugin".into()),
        });
        let err = reg.register(cap).await.unwrap_err();
        assert!(
            matches!(err, RegistrationError::InvalidMetadata(_)),
            "expected InvalidMetadata, got: {err}"
        );
        assert!(err.to_string().contains("reserved for MCP"));
    }

    #[tokio::test]
    async fn register_rejects_mcp_without_separator() {
        let reg = CapabilityRegistry::new();
        let cap = Arc::new(StubCap {
            name: "tool_name".into(),
            source: CapabilitySource::Mcp {
                server: "github".into(),
                transport: McpTransport::Stdio,
            },
        });
        let err = reg.register(cap).await.unwrap_err();
        assert!(
            matches!(err, RegistrationError::InvalidMetadata(_)),
            "expected InvalidMetadata, got: {err}"
        );
        assert!(err.to_string().contains("must use '::' separator"));
    }

    #[tokio::test]
    async fn register_allows_mcp_with_separator() {
        let reg = CapabilityRegistry::new();
        let cap = Arc::new(StubCap {
            name: "github::create_issue".into(),
            source: CapabilitySource::Mcp {
                server: "github".into(),
                transport: McpTransport::Stdio,
            },
        });
        reg.register(cap).await.unwrap();
        assert!(reg.get("github::create_issue").await.is_some());
    }

    #[tokio::test]
    async fn register_allows_builtin_without_separator() {
        let reg = CapabilityRegistry::new();
        let cap = Arc::new(StubCap {
            name: "bash".into(),
            source: CapabilitySource::BuiltIn,
        });
        reg.register(cap).await.unwrap();
        assert!(reg.get("bash").await.is_some());
    }

    #[tokio::test]
    async fn register_rejects_empty_name() {
        let reg = CapabilityRegistry::new();
        let err = reg
            .register(stub_cap("", CapabilitySource::BuiltIn))
            .await
            .unwrap_err();
        assert!(
            matches!(err, RegistrationError::InvalidMetadata(_)),
            "expected InvalidMetadata, got: {err}"
        );
        assert!(err.to_string().contains("capability name is empty"));
    }

    #[tokio::test]
    async fn register_rejects_name_conflict_across_sources() {
        let reg = CapabilityRegistry::new();
        reg.register(stub_cap("bash", CapabilitySource::BuiltIn))
            .await
            .unwrap();
        let err = reg
            .register(stub_cap("bash", CapabilitySource::Plugin("p".into())))
            .await
            .unwrap_err();
        match err {
            RegistrationError::NameConflict {
                name,
                existing_source,
            } => {
                assert_eq!(name, "bash");
                assert_eq!(existing_source, CapabilitySource::BuiltIn);
            }
            other => panic!("expected NameConflict, got: {other}"),
        }
    }

    #[tokio::test]
    async fn register_same_source_replaces_entry() {
        let reg = CapabilityRegistry::new();
        reg.register(stub_cap("bash", CapabilitySource::BuiltIn))
            .await
            .unwrap();
        reg.register(stub_cap("bash", CapabilitySource::BuiltIn))
            .await
            .unwrap();
        assert_eq!(reg.list_names().await, vec!["bash".to_string()]);
    }

    #[tokio::test]
    async fn catalog_is_sorted_by_name() {
        let reg = CapabilityRegistry::new();
        let zeta = stub_cap("zeta", CapabilitySource::BuiltIn);
        let alpha = stub_cap("alpha", CapabilitySource::BuiltIn);
        let mid = stub_cap("mid", CapabilitySource::Plugin("p".into()));
        let errors = reg.register_all(vec![zeta, alpha, mid]).await;
        assert!(errors.is_empty(), "unexpected errors: {errors:?}");
        let names: Vec<String> = reg.catalog().await.into_iter().map(|s| s.name).collect();
        assert_eq!(names, ["alpha", "mid", "zeta"]);
    }

    // Atomic reload_mcp_server

    fn make_mcp_cap(server: &str, tool: &str) -> Arc<StubCap> {
        Arc::new(StubCap {
            name: format!("{server}::{tool}"),
            source: CapabilitySource::Mcp {
                server: server.into(),
                transport: McpTransport::Stdio,
            },
        })
    }

    #[tokio::test]
    async fn atomic_reload_swaps_all_at_once() {
        let registry = CapabilityRegistry::new();

        // Register an old MCP tool.
        let old_cap = make_mcp_cap("myserver", "old_tool");
        registry.register(old_cap).await.unwrap();
        assert!(registry.get("myserver::old_tool").await.is_some());

        // Reload with a new tool set.
        let new_cap = make_mcp_cap("myserver", "new_tool");
        registry
            .reload_mcp_server("myserver", vec![new_cap])
            .await
            .unwrap();

        // Old tool must be gone; new tool must be present.
        let summaries = registry.catalog().await;
        assert!(
            summaries.iter().any(|s| s.name == "myserver::new_tool"),
            "new tool should be in the catalog"
        );
        assert!(
            !summaries.iter().any(|s| s.name == "myserver::old_tool"),
            "old tool should have been removed"
        );
    }

    #[tokio::test]
    async fn atomic_reload_rejects_cap_without_separator() {
        let registry = CapabilityRegistry::new();
        let bad_cap = Arc::new(StubCap {
            name: "notnamespaced".into(),
            source: CapabilitySource::Mcp {
                server: "myserver".into(),
                transport: McpTransport::Stdio,
            },
        });
        let err = registry
            .reload_mcp_server("myserver", vec![bad_cap])
            .await
            .unwrap_err();
        assert!(
            matches!(err, RegistrationError::InvalidMetadata(_)),
            "expected InvalidMetadata, got: {err}"
        );
        assert!(err.to_string().contains("must use '::' separator"));
    }

    #[tokio::test]
    async fn atomic_reload_only_removes_matching_server() {
        let registry = CapabilityRegistry::new();

        // Register tools from two different servers.
        let cap_a = make_mcp_cap("server_a", "tool1");
        let cap_b = make_mcp_cap("server_b", "tool2");
        registry.register(cap_a).await.unwrap();
        registry.register(cap_b).await.unwrap();

        // Reload only server_a.
        let new_cap = make_mcp_cap("server_a", "tool_new");
        registry
            .reload_mcp_server("server_a", vec![new_cap])
            .await
            .unwrap();

        // server_b tool must still be present.
        assert!(
            registry.get("server_b::tool2").await.is_some(),
            "server_b tools should not be touched"
        );
        assert!(
            registry.get("server_a::tool_new").await.is_some(),
            "new server_a tool should be present"
        );
        assert!(
            registry.get("server_a::tool1").await.is_none(),
            "old server_a tool should be gone"
        );
    }

    #[tokio::test]
    async fn atomic_reload_failure_leaves_existing_tools_intact() {
        let registry = CapabilityRegistry::new();
        registry
            .register(make_mcp_cap("myserver", "keep"))
            .await
            .unwrap();

        let bad_cap = Arc::new(StubCap {
            name: String::new(),
            source: CapabilitySource::Mcp {
                server: "myserver".into(),
                transport: McpTransport::Stdio,
            },
        });
        let err = registry
            .reload_mcp_server("myserver", vec![bad_cap])
            .await
            .unwrap_err();
        assert!(err.to_string().contains("capability name is empty"));
        assert!(
            registry.get("myserver::keep").await.is_some(),
            "a rejected reload must not remove the server's existing tools"
        );
    }

    // reload_plugin holds the write lock across remove-then-insert

    #[tokio::test]
    async fn reload_plugin_holds_lock_atomically() {
        let registry = CapabilityRegistry::new();

        // Register an old plugin tool.
        let old_cap = Arc::new(StubCap {
            name: "old_action".into(),
            source: CapabilitySource::Plugin("myplugin".into()),
        });
        registry.register(old_cap).await.unwrap();

        // Reload with a new tool.
        let new_cap = Arc::new(StubCap {
            name: "new_action".into(),
            source: CapabilitySource::Plugin("myplugin".into()),
        });
        let errors = registry.reload_plugin("myplugin", vec![new_cap]).await;
        assert!(errors.is_empty(), "unexpected errors: {errors:?}");

        let names = registry.list_names().await;
        assert!(
            names.contains(&"new_action".to_string()),
            "new tool should be registered"
        );
        assert!(
            !names.contains(&"old_action".to_string()),
            "old tool should be removed"
        );
    }

    #[tokio::test]
    async fn reload_plugin_reports_invalid_and_keeps_valid_tools() {
        let registry = CapabilityRegistry::new();
        let bad = stub_cap("ns::bad", CapabilitySource::Plugin("myplugin".into()));
        let good = stub_cap("good", CapabilitySource::Plugin("myplugin".into()));
        let errors = registry.reload_plugin("myplugin", vec![bad, good]).await;

        assert_eq!(errors.len(), 1, "unexpected errors: {errors:?}");
        assert_eq!(errors[0].0, "ns::bad");
        assert!(errors[0].1.to_string().contains("reserved for MCP"));
        assert!(registry.get("good").await.is_some());
        assert!(registry.get("ns::bad").await.is_none());
    }

    #[tokio::test]
    async fn reload_plugin_does_not_shadow_other_sources() {
        let registry = CapabilityRegistry::new();
        registry
            .register(stub_cap("bash", CapabilitySource::BuiltIn))
            .await
            .unwrap();

        let clash = stub_cap("bash", CapabilitySource::Plugin("myplugin".into()));
        let errors = registry.reload_plugin("myplugin", vec![clash]).await;

        assert_eq!(errors.len(), 1, "unexpected errors: {errors:?}");
        assert!(
            matches!(errors[0].1, RegistrationError::NameConflict { .. }),
            "expected NameConflict, got: {}",
            errors[0].1
        );
        let kept = registry
            .get("bash")
            .await
            .expect("built-in tool should survive the plugin reload");
        assert_eq!(kept.source(), CapabilitySource::BuiltIn);
    }
}