use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::RwLock;
use roboticus_core::RiskLevel;
use roboticus_core::config::McpTransport;
use crate::tools::{ToolContext, ToolError, ToolRegistry, ToolResult};
/// Origin of a registered capability.
///
/// The source also selects the naming rule enforced by `CapabilityRegistry`:
/// `Mcp` capabilities must contain a `::` separator in their name, all other
/// sources must not.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum CapabilitySource {
    /// A tool that has no plugin owner (see `sync_from_tool_registry`).
    BuiltIn,
    /// A capability contributed by the named plugin.
    Plugin(String),
    /// A tool exposed by an MCP server, identified by server name and the
    /// transport used to reach it.
    Mcp {
        server: String,
        transport: McpTransport,
    },
}
impl fmt::Display for CapabilitySource {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::BuiltIn => write!(f, "built-in"),
Self::Plugin(p) => write!(f, "plugin:{p}"),
Self::Mcp { server, transport } => {
let t = match transport {
McpTransport::Stdio => "stdio",
McpTransport::Sse => "sse",
McpTransport::Http => "http",
McpTransport::WebSocket => "ws",
};
write!(f, "mcp:{server}({t})")
}
}
}
}
/// A named, invocable unit that can be listed in the catalog and executed
/// with JSON parameters.
#[async_trait]
pub trait Capability: Send + Sync {
    /// Registry key. `CapabilityRegistry` requires it to be non-empty and to
    /// contain `::` exactly when the source is MCP.
    fn name(&self) -> &str;
    /// Human-readable description; must be non-empty to be registered.
    fn description(&self) -> &str;
    /// Risk classification reported for this capability.
    fn risk_level(&self) -> RiskLevel;
    /// Description of the parameters accepted by `execute`, as a JSON value.
    fn parameters_schema(&self) -> Value;
    /// Origin of the capability; also compared when detecting name conflicts.
    fn source(&self) -> CapabilitySource;
    /// Optional name of an associated skill. Defaults to `None`.
    fn paired_skill(&self) -> Option<&str> {
        None
    }
    /// Runs the capability with the given parameters and tool context.
    async fn execute(&self, params: Value, ctx: &ToolContext) -> Result<ToolResult, ToolError>;
}
/// Owned, serializable snapshot of one capability's metadata, as produced by
/// `CapabilityRegistry::catalog`.
#[derive(Debug, Clone, Serialize)]
pub struct CapabilitySummary {
    /// Registry name of the capability.
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// Where the capability comes from.
    pub source: CapabilitySource,
    /// Associated skill name, if any.
    pub paired_skill: Option<String>,
    /// Reported risk classification.
    pub risk_level: RiskLevel,
    /// Parameter description returned by `Capability::parameters_schema`.
    pub parameters_schema: Value,
}
/// Reason a capability could not be added to a `CapabilityRegistry`.
#[derive(Debug)]
pub enum RegistrationError {
    /// The name is already taken by a capability from a different source.
    NameConflict {
        /// The contested capability name.
        name: String,
        /// Source of the capability that currently owns the name.
        existing_source: CapabilitySource,
    },
    /// The capability's metadata breaks a registration rule (for example an
    /// empty name or description, or misuse of the `::` separator); the
    /// message describes which rule.
    InvalidMetadata(String),
}
impl fmt::Display for RegistrationError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::NameConflict {
name,
existing_source,
} => write!(
f,
"capability name conflict: '{name}' already registered ({existing_source})"
),
Self::InvalidMetadata(m) => write!(f, "invalid capability metadata: {m}"),
}
}
}
impl std::error::Error for RegistrationError {}
/// Registry of capabilities keyed by name, guarded by an async `RwLock`.
///
/// Lookups hand out cloned `Arc`s, so callers do not hold the lock while
/// using a capability.
pub struct CapabilityRegistry {
    /// Capabilities keyed by their `name()`.
    capabilities: RwLock<HashMap<String, Arc<dyn Capability>>>,
}
impl Default for CapabilityRegistry {
fn default() -> Self {
Self::new()
}
}
impl CapabilityRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            capabilities: RwLock::new(HashMap::new()),
        }
    }

    /// Returns `true` when no capability is registered.
    pub async fn is_empty(&self) -> bool {
        self.capabilities.read().await.is_empty()
    }

    /// Checks the metadata rules every capability must satisfy before it is
    /// stored: a non-empty name, a non-empty description, and a `::` separator
    /// in the name if and only if the capability comes from an MCP server.
    ///
    /// Shared by every insertion path so they all enforce identical rules.
    fn validate_metadata(cap: &dyn Capability) -> Result<(), RegistrationError> {
        let name = cap.name();
        if name.is_empty() {
            return Err(RegistrationError::InvalidMetadata(
                "capability name is empty".into(),
            ));
        }
        if cap.description().is_empty() {
            return Err(RegistrationError::InvalidMetadata(
                "capability description is empty".into(),
            ));
        }
        let has_separator = name.contains("::");
        let is_mcp = matches!(cap.source(), CapabilitySource::Mcp { .. });
        if is_mcp && !has_separator {
            return Err(RegistrationError::InvalidMetadata(format!(
                "MCP capability '{name}' must use '::' separator (e.g., 'server::tool_name')"
            )));
        }
        if !is_mcp && has_separator {
            return Err(RegistrationError::InvalidMetadata(format!(
                "non-MCP capability '{name}' must not use '::' separator (reserved for MCP)"
            )));
        }
        Ok(())
    }

    /// Fails with [`RegistrationError::NameConflict`] when `caps` already holds
    /// a capability with the same name but a different source. Re-registering
    /// under the same source is allowed and replaces the entry.
    fn check_conflict(
        caps: &HashMap<String, Arc<dyn Capability>>,
        cap: &dyn Capability,
    ) -> Result<(), RegistrationError> {
        let Some(existing) = caps.get(cap.name()) else {
            return Ok(());
        };
        let existing_source = existing.source();
        if existing_source == cap.source() {
            return Ok(());
        }
        Err(RegistrationError::NameConflict {
            name: cap.name().to_string(),
            existing_source,
        })
    }

    /// `true` if `cap` comes from the MCP server named `server_name`, over any
    /// transport.
    fn is_from_mcp_server(cap: &dyn Capability, server_name: &str) -> bool {
        match cap.source() {
            CapabilitySource::Mcp { server, .. } => server == server_name,
            CapabilitySource::BuiltIn | CapabilitySource::Plugin(_) => false,
        }
    }

    /// Registers a single capability, replacing any entry of the same name
    /// that comes from the same source.
    ///
    /// # Errors
    /// [`RegistrationError::InvalidMetadata`] if the metadata rules are
    /// violated; [`RegistrationError::NameConflict`] if the name is owned by a
    /// capability from a different source.
    pub async fn register(&self, cap: Arc<dyn Capability>) -> Result<(), RegistrationError> {
        Self::validate_metadata(cap.as_ref())?;
        let mut caps = self.capabilities.write().await;
        Self::check_conflict(&caps, cap.as_ref())?;
        caps.insert(cap.name().to_string(), cap);
        Ok(())
    }

    /// Registers each capability independently (not atomically), collecting
    /// `(name, error)` pairs for the ones that were rejected.
    pub async fn register_all(
        &self,
        capabilities: Vec<Arc<dyn Capability>>,
    ) -> Vec<(String, RegistrationError)> {
        let mut errors = Vec::new();
        for cap in capabilities {
            let name = cap.name().to_string();
            if let Err(e) = self.register(cap).await {
                errors.push((name, e));
            }
        }
        errors
    }

    /// Looks up a capability by exact name.
    pub async fn get(&self, name: &str) -> Option<Arc<dyn Capability>> {
        self.capabilities.read().await.get(name).cloned()
    }

    /// Returns metadata snapshots of all capabilities, sorted by name.
    pub async fn catalog(&self) -> Vec<CapabilitySummary> {
        let mut out: Vec<_> = self
            .capabilities
            .read()
            .await
            .values()
            .map(|c| CapabilitySummary {
                name: c.name().to_string(),
                description: c.description().to_string(),
                source: c.source(),
                paired_skill: c.paired_skill().map(String::from),
                risk_level: c.risk_level(),
                parameters_schema: c.parameters_schema(),
            })
            .collect();
        out.sort_by(|a, b| a.name.cmp(&b.name));
        out
    }

    /// Returns all registered names, sorted.
    pub async fn list_names(&self) -> Vec<String> {
        let mut names: Vec<_> = self.capabilities.read().await.keys().cloned().collect();
        names.sort();
        names
    }

    /// Replaces every capability owned by `plugin_name` with
    /// `new_capabilities`, under a single write lock.
    ///
    /// Each new capability must pass the usual metadata rules and must report
    /// `CapabilitySource::Plugin(plugin_name)`. Otherwise a later reload of
    /// this plugin would not remove it. Invalid or conflicting capabilities are
    /// skipped and returned as `(name, error)` pairs; the valid ones are still
    /// installed.
    pub async fn reload_plugin(
        &self,
        plugin_name: &str,
        new_capabilities: Vec<Arc<dyn Capability>>,
    ) -> Vec<(String, RegistrationError)> {
        let target = CapabilitySource::Plugin(plugin_name.to_string());
        let mut errors: Vec<(String, RegistrationError)> = Vec::new();
        let mut valid: Vec<Arc<dyn Capability>> = Vec::with_capacity(new_capabilities.len());
        for cap in new_capabilities {
            let checked = Self::validate_metadata(cap.as_ref()).and_then(|()| {
                let source = cap.source();
                if source == target {
                    Ok(())
                } else {
                    Err(RegistrationError::InvalidMetadata(format!(
                        "capability '{}' has source {source} but is being reloaded as {target}",
                        cap.name()
                    )))
                }
            });
            match checked {
                Ok(()) => valid.push(cap),
                Err(e) => errors.push((cap.name().to_string(), e)),
            }
        }

        let mut caps = self.capabilities.write().await;
        caps.retain(|_, c| c.source() != target);
        for cap in valid {
            match Self::check_conflict(&caps, cap.as_ref()) {
                Ok(()) => {
                    caps.insert(cap.name().to_string(), cap);
                }
                Err(e) => errors.push((cap.name().to_string(), e)),
            }
        }
        errors
    }

    /// Atomically replaces every capability from MCP server `server_name` with
    /// `new_capabilities`.
    ///
    /// This is all-or-nothing: every capability is validated, and every name
    /// is checked for conflicts, before the map is modified. On error the
    /// registry is left untouched.
    ///
    /// # Errors
    /// [`RegistrationError::InvalidMetadata`] if a capability breaks the
    /// metadata rules or does not come from `server_name`;
    /// [`RegistrationError::NameConflict`] if a name is owned by a capability
    /// that is not from `server_name`.
    pub async fn reload_mcp_server(
        &self,
        server_name: &str,
        new_capabilities: Vec<Arc<dyn Capability>>,
    ) -> Result<(), RegistrationError> {
        for cap in &new_capabilities {
            Self::validate_metadata(cap.as_ref())?;
            if !Self::is_from_mcp_server(cap.as_ref(), server_name) {
                return Err(RegistrationError::InvalidMetadata(format!(
                    "capability '{}' has source {} but is being reloaded as MCP server '{server_name}'",
                    cap.name(),
                    cap.source()
                )));
            }
        }

        let mut caps = self.capabilities.write().await;
        // Check conflicts before mutating. Entries from this server are about
        // to be removed, so only names held by *other* sources count.
        for cap in &new_capabilities {
            if let Some(existing) = caps.get(cap.name())
                && !Self::is_from_mcp_server(existing.as_ref(), server_name)
            {
                return Err(RegistrationError::NameConflict {
                    name: cap.name().to_string(),
                    existing_source: existing.source(),
                });
            }
        }
        caps.retain(|_, existing| !Self::is_from_mcp_server(existing.as_ref(), server_name));
        caps.extend(
            new_capabilities
                .into_iter()
                .map(|cap| (cap.name().to_string(), cap)),
        );
        Ok(())
    }

    /// Rebuilds all built-in and plugin capabilities from `registry`, under a
    /// single write lock.
    ///
    /// Concurrent readers never observe an empty or half-populated registry.
    /// MCP capabilities are not sourced from the `ToolRegistry`, so they are
    /// preserved; use `reload_mcp_server` to manage them.
    ///
    /// # Errors
    /// Returns a combined message if any tool failed validation or conflicted.
    /// The remaining tools are still registered.
    pub async fn sync_from_tool_registry(&self, registry: Arc<ToolRegistry>) -> Result<(), String> {
        let mut tools: Vec<_> = registry.list();
        tools.sort_by_key(|t| t.name());
        let mut errors = Vec::new();

        let mut caps = self.capabilities.write().await;
        caps.retain(|_, c| matches!(c.source(), CapabilitySource::Mcp { .. }));
        for tool in tools {
            let source = match tool.plugin_owner() {
                Some(p) => CapabilitySource::Plugin(p.to_string()),
                None => CapabilitySource::BuiltIn,
            };
            let cap: Arc<dyn Capability> = Arc::new(ToolRegistryCapability {
                registry: Arc::clone(&registry),
                name: tool.name().to_string(),
                source,
            });
            let outcome = Self::validate_metadata(cap.as_ref())
                .and_then(|()| Self::check_conflict(&caps, cap.as_ref()));
            match outcome {
                Ok(()) => {
                    caps.insert(cap.name().to_string(), cap);
                }
                Err(e) => errors.push(e.to_string()),
            }
        }
        drop(caps);

        if errors.is_empty() {
            Ok(())
        } else {
            Err(format!(
                "capability sync partially failed ({} error(s)): {}",
                errors.len(),
                errors.join("; ")
            ))
        }
    }

    /// Alias for [`CapabilityRegistry::sync_from_tool_registry`].
    pub async fn resync_tools(&self, registry: Arc<ToolRegistry>) -> Result<(), String> {
        self.sync_from_tool_registry(registry).await
    }
}
/// Adapter that exposes a tool held in a shared `ToolRegistry` as a
/// [`Capability`].
///
/// Only the name and source are stored. Description, risk level, schema,
/// paired skill and execution are looked up in `registry` on every call.
pub struct ToolRegistryCapability {
    /// Registry that owns the underlying tool.
    registry: Arc<ToolRegistry>,
    /// Name used to look the tool up in `registry`.
    name: String,
    /// Source reported to the capability registry (plugin or built-in).
    source: CapabilitySource,
}
#[async_trait]
impl Capability for ToolRegistryCapability {
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// The backing tool's description, or `""` if the tool is no longer in the
    /// registry.
    fn description(&self) -> &str {
        match self.registry.get(&self.name) {
            Some(tool) => tool.description(),
            None => "",
        }
    }

    /// The backing tool's risk level. Fails closed to `Forbidden` when the
    /// tool is missing.
    fn risk_level(&self) -> RiskLevel {
        match self.registry.get(&self.name) {
            Some(tool) => tool.risk_level(),
            None => RiskLevel::Forbidden,
        }
    }

    /// The backing tool's schema, or a bare `{"type": "object"}` when missing.
    fn parameters_schema(&self) -> Value {
        match self.registry.get(&self.name) {
            Some(tool) => tool.parameters_schema(),
            None => serde_json::json!({"type": "object"}),
        }
    }

    fn source(&self) -> CapabilitySource {
        self.source.clone()
    }

    /// The backing tool's paired skill, if the tool exists and declares one.
    fn paired_skill(&self) -> Option<&str> {
        match self.registry.get(&self.name) {
            Some(tool) => tool.paired_skill(),
            None => None,
        }
    }

    /// Runs the backing tool. Errors if it has disappeared from the registry.
    async fn execute(&self, params: Value, ctx: &ToolContext) -> Result<ToolResult, ToolError> {
        let Some(tool) = self.registry.get(&self.name) else {
            return Err(ToolError {
                message: format!("tool '{}' not found in ToolRegistry", self.name),
            });
        };
        tool.execute(params, ctx).await
    }
}
/// Unit tests for capability sources, registration rules and reload paths.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::ToolRegistry;

    #[tokio::test]
    async fn sync_populates_catalog() {
        use crate::tools::EchoTool;
        let mut reg = ToolRegistry::new();
        reg.register(Box::new(EchoTool));
        let reg = Arc::new(reg);
        let caps = CapabilityRegistry::new();
        // `&reg` was previously mis-encoded as `®`, which did not compile.
        caps.sync_from_tool_registry(Arc::clone(&reg))
            .await
            .unwrap();
        assert!(!caps.is_empty().await);
        let names = caps.list_names().await;
        assert!(names.iter().any(|n| n == "echo"));
    }

    #[test]
    fn mcp_source_display_stdio() {
        let source = CapabilitySource::Mcp {
            server: "github".into(),
            transport: McpTransport::Stdio,
        };
        assert_eq!(source.to_string(), "mcp:github(stdio)");
    }

    #[test]
    fn mcp_source_display_sse() {
        let source = CapabilitySource::Mcp {
            server: "linear".into(),
            transport: McpTransport::Sse,
        };
        assert_eq!(source.to_string(), "mcp:linear(sse)");
    }

    #[test]
    fn mcp_source_display_http() {
        let source = CapabilitySource::Mcp {
            server: "sentry".into(),
            transport: McpTransport::Http,
        };
        assert_eq!(source.to_string(), "mcp:sentry(http)");
    }

    #[test]
    fn mcp_source_display_websocket() {
        let source = CapabilitySource::Mcp {
            server: "relay".into(),
            transport: McpTransport::WebSocket,
        };
        assert_eq!(source.to_string(), "mcp:relay(ws)");
    }

    /// Minimal capability with configurable name and source.
    struct StubCap {
        name: String,
        source: CapabilitySource,
    }

    #[async_trait::async_trait]
    impl Capability for StubCap {
        fn name(&self) -> &str {
            &self.name
        }
        fn description(&self) -> &str {
            "stub"
        }
        fn risk_level(&self) -> roboticus_core::RiskLevel {
            roboticus_core::RiskLevel::Safe
        }
        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({"type": "object"})
        }
        fn source(&self) -> CapabilitySource {
            self.source.clone()
        }
        async fn execute(
            &self,
            _params: serde_json::Value,
            _ctx: &crate::tools::ToolContext,
        ) -> Result<crate::tools::ToolResult, crate::tools::ToolError> {
            Ok(crate::tools::ToolResult {
                output: "stub".into(),
                metadata: None,
            })
        }
    }

    #[tokio::test]
    async fn register_rejects_builtin_with_separator() {
        let reg = CapabilityRegistry::new();
        let cap = Arc::new(StubCap {
            name: "ns::tool".into(),
            source: CapabilitySource::BuiltIn,
        });
        let err = reg.register(cap).await.unwrap_err();
        assert!(
            matches!(err, RegistrationError::InvalidMetadata(_)),
            "expected InvalidMetadata, got: {err}"
        );
        assert!(err.to_string().contains("reserved for MCP"));
    }

    #[tokio::test]
    async fn register_rejects_plugin_with_separator() {
        let reg = CapabilityRegistry::new();
        let cap = Arc::new(StubCap {
            name: "ns::tool".into(),
            source: CapabilitySource::Plugin("myplugin".into()),
        });
        let err = reg.register(cap).await.unwrap_err();
        assert!(
            matches!(err, RegistrationError::InvalidMetadata(_)),
            "expected InvalidMetadata, got: {err}"
        );
        assert!(err.to_string().contains("reserved for MCP"));
    }

    #[tokio::test]
    async fn register_rejects_mcp_without_separator() {
        let reg = CapabilityRegistry::new();
        let cap = Arc::new(StubCap {
            name: "tool_name".into(),
            source: CapabilitySource::Mcp {
                server: "github".into(),
                transport: McpTransport::Stdio,
            },
        });
        let err = reg.register(cap).await.unwrap_err();
        assert!(
            matches!(err, RegistrationError::InvalidMetadata(_)),
            "expected InvalidMetadata, got: {err}"
        );
        assert!(err.to_string().contains("must use '::' separator"));
    }

    #[tokio::test]
    async fn register_allows_mcp_with_separator() {
        let reg = CapabilityRegistry::new();
        let cap = Arc::new(StubCap {
            name: "github::create_issue".into(),
            source: CapabilitySource::Mcp {
                server: "github".into(),
                transport: McpTransport::Stdio,
            },
        });
        reg.register(cap).await.unwrap();
        assert!(reg.get("github::create_issue").await.is_some());
    }

    #[tokio::test]
    async fn register_allows_builtin_without_separator() {
        let reg = CapabilityRegistry::new();
        let cap = Arc::new(StubCap {
            name: "bash".into(),
            source: CapabilitySource::BuiltIn,
        });
        reg.register(cap).await.unwrap();
        assert!(reg.get("bash").await.is_some());
    }

    /// Builds an MCP stub named `server::tool` sourced from `server` over stdio.
    fn make_mcp_cap(server: &str, tool: &str) -> Arc<StubCap> {
        Arc::new(StubCap {
            name: format!("{server}::{tool}"),
            source: CapabilitySource::Mcp {
                server: server.into(),
                transport: McpTransport::Stdio,
            },
        })
    }

    #[tokio::test]
    async fn atomic_reload_swaps_all_at_once() {
        let registry = CapabilityRegistry::new();
        let old_cap = make_mcp_cap("myserver", "old_tool");
        registry.register(old_cap).await.unwrap();
        assert!(registry.get("myserver::old_tool").await.is_some());
        let new_cap = make_mcp_cap("myserver", "new_tool");
        registry
            .reload_mcp_server("myserver", vec![new_cap])
            .await
            .unwrap();
        let summaries = registry.catalog().await;
        assert!(
            summaries.iter().any(|s| s.name == "myserver::new_tool"),
            "new tool should be in the catalog"
        );
        assert!(
            !summaries.iter().any(|s| s.name == "myserver::old_tool"),
            "old tool should have been removed"
        );
    }

    #[tokio::test]
    async fn atomic_reload_rejects_cap_without_separator() {
        let registry = CapabilityRegistry::new();
        let bad_cap = Arc::new(StubCap {
            name: "notnamespaced".into(),
            source: CapabilitySource::Mcp {
                server: "myserver".into(),
                transport: McpTransport::Stdio,
            },
        });
        let err = registry
            .reload_mcp_server("myserver", vec![bad_cap])
            .await
            .unwrap_err();
        assert!(
            matches!(err, RegistrationError::InvalidMetadata(_)),
            "expected InvalidMetadata, got: {err}"
        );
        assert!(err.to_string().contains("must use '::' separator"));
    }

    #[tokio::test]
    async fn atomic_reload_only_removes_matching_server() {
        let registry = CapabilityRegistry::new();
        let cap_a = make_mcp_cap("server_a", "tool1");
        let cap_b = make_mcp_cap("server_b", "tool2");
        registry.register(cap_a).await.unwrap();
        registry.register(cap_b).await.unwrap();
        let new_cap = make_mcp_cap("server_a", "tool_new");
        registry
            .reload_mcp_server("server_a", vec![new_cap])
            .await
            .unwrap();
        assert!(
            registry.get("server_b::tool2").await.is_some(),
            "server_b tools should not be touched"
        );
        assert!(
            registry.get("server_a::tool_new").await.is_some(),
            "new server_a tool should be present"
        );
        assert!(
            registry.get("server_a::tool1").await.is_none(),
            "old server_a tool should be gone"
        );
    }

    #[tokio::test]
    async fn reload_plugin_holds_lock_atomically() {
        let registry = CapabilityRegistry::new();
        let old_cap = Arc::new(StubCap {
            name: "old_action".into(),
            source: CapabilitySource::Plugin("myplugin".into()),
        });
        registry.register(old_cap).await.unwrap();
        let new_cap = Arc::new(StubCap {
            name: "new_action".into(),
            source: CapabilitySource::Plugin("myplugin".into()),
        });
        let errors = registry.reload_plugin("myplugin", vec![new_cap]).await;
        assert!(errors.is_empty(), "unexpected errors: {errors:?}");
        let names = registry.list_names().await;
        assert!(
            names.contains(&"new_action".to_string()),
            "new tool should be registered"
        );
        assert!(
            !names.contains(&"old_action".to_string()),
            "old tool should be removed"
        );
    }
}