use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::RwLock;
use roboticus_core::RiskLevel;
use crate::tools::{ToolContext, ToolError, ToolRegistry, ToolResult};
/// Origin of a capability: compiled into the host (`BuiltIn`) or contributed
/// by a named plugin (`Plugin(name)`).
///
/// `CapabilityRegistry::register` only lets two capabilities share a name when
/// their sources compare equal, and `CapabilityRegistry::reload_plugin` uses
/// the `Plugin` name to find which entries to replace.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum CapabilitySource {
/// Provided by the host itself rather than by a plugin.
BuiltIn,
/// Provided by the plugin with the given name.
Plugin(String),
}
impl fmt::Display for CapabilitySource {
    /// Renders `built-in` for host capabilities and `plugin:<name>` for plugin
    /// capabilities; this is the text embedded in registration error messages.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::BuiltIn => f.write_str("built-in"),
            Self::Plugin(plugin_name) => {
                f.write_str("plugin:")?;
                f.write_str(plugin_name)
            }
        }
    }
}
/// A named, invocable unit of functionality that can be listed in and
/// executed through a `CapabilityRegistry`.
///
/// Implementors must be `Send + Sync`; the registry keeps them as
/// `Arc<dyn Capability>` behind a `tokio::sync::RwLock`.
#[async_trait]
pub trait Capability: Send + Sync {
/// Registry key. `CapabilityRegistry::register` rejects an empty name.
fn name(&self) -> &str;
/// Human-readable description. `CapabilityRegistry::register` rejects an
/// empty description.
fn description(&self) -> &str;
/// Risk classification, copied into the catalog summary.
fn risk_level(&self) -> RiskLevel;
/// JSON value describing the accepted parameters (presumably a JSON Schema;
/// the fallback used in this file is `{"type": "object"}`).
fn parameters_schema(&self) -> Value;
/// Where this capability comes from; compared when two capabilities
/// register under the same name.
fn source(&self) -> CapabilitySource;
/// Optional name of an associated skill. Defaults to `None`.
fn paired_skill(&self) -> Option<&str> {
None
}
/// Runs the capability with `params` in the given tool context.
async fn execute(&self, params: Value, ctx: &ToolContext) -> Result<ToolResult, ToolError>;
}
/// Owned, serializable snapshot of one capability's metadata, as produced by
/// `CapabilityRegistry::catalog`.
#[derive(Debug, Clone, Serialize)]
pub struct CapabilitySummary {
/// `Capability::name()`.
pub name: String,
/// `Capability::description()`.
pub description: String,
/// `Capability::source()`.
pub source: CapabilitySource,
/// `Capability::paired_skill()`, copied into an owned `String`.
pub paired_skill: Option<String>,
/// `Capability::risk_level()`.
pub risk_level: RiskLevel,
/// `Capability::parameters_schema()`.
pub parameters_schema: Value,
}
/// Reasons a `CapabilityRegistry` refuses to register a capability.
#[derive(Debug)]
pub enum RegistrationError {
/// The name is already held by a capability from a different source.
NameConflict {
/// The contested capability name.
name: String,
/// Source of the capability that currently holds the name.
existing_source: CapabilitySource,
},
/// The capability's metadata failed validation (e.g. empty name or
/// empty description); the string describes which check failed.
InvalidMetadata(String),
}
impl fmt::Display for RegistrationError {
    /// Produces a one-line, human-readable explanation of the failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidMetadata(detail) => {
                write!(f, "invalid capability metadata: {}", detail)
            }
            Self::NameConflict {
                name,
                existing_source,
            } => write!(
                f,
                "capability name conflict: '{}' already registered ({})",
                name, existing_source
            ),
        }
    }
}
impl std::error::Error for RegistrationError {}
/// Name-keyed registry of capabilities.
///
/// Entries live behind a `tokio::sync::RwLock` as `Arc<dyn Capability>`, so
/// lookups hand out cheap `Arc` clones that remain usable after the lock is
/// released.
pub struct CapabilityRegistry {
/// Capabilities keyed by `Capability::name()`.
capabilities: RwLock<HashMap<String, Arc<dyn Capability>>>,
}
impl Default for CapabilityRegistry {
fn default() -> Self {
Self::new()
}
}
impl CapabilityRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            capabilities: RwLock::new(HashMap::new()),
        }
    }

    /// Returns `true` when no capabilities are registered.
    pub async fn is_empty(&self) -> bool {
        self.capabilities.read().await.is_empty()
    }

    /// Validates `cap` and inserts it into `caps` under its name.
    ///
    /// All registration paths go through this function, so single inserts and
    /// batch inserts apply exactly the same rules.
    ///
    /// # Errors
    /// - `RegistrationError::InvalidMetadata` if the name or description is empty.
    /// - `RegistrationError::NameConflict` if the name is already held by a
    ///   capability from a *different* source. Re-registering from the same
    ///   source silently replaces the existing entry.
    fn try_insert(
        caps: &mut HashMap<String, Arc<dyn Capability>>,
        cap: Arc<dyn Capability>,
    ) -> Result<(), RegistrationError> {
        let name = cap.name().to_string();
        if name.is_empty() {
            return Err(RegistrationError::InvalidMetadata(
                "capability name is empty".into(),
            ));
        }
        if cap.description().is_empty() {
            return Err(RegistrationError::InvalidMetadata(
                "capability description is empty".into(),
            ));
        }
        if let Some(existing) = caps.get(&name) {
            let existing_source = existing.source();
            if existing_source != cap.source() {
                return Err(RegistrationError::NameConflict {
                    name,
                    existing_source,
                });
            }
        }
        caps.insert(name, cap);
        Ok(())
    }

    /// Applies `try_insert` to every capability in order. Failures do not
    /// abort the batch; they are returned as `(name, error)` pairs.
    fn try_insert_all(
        caps: &mut HashMap<String, Arc<dyn Capability>>,
        capabilities: impl IntoIterator<Item = Arc<dyn Capability>>,
    ) -> Vec<(String, RegistrationError)> {
        let mut errors = Vec::new();
        for cap in capabilities {
            let name = cap.name().to_string();
            if let Err(e) = Self::try_insert(caps, cap) {
                errors.push((name, e));
            }
        }
        errors
    }

    /// Registers a single capability.
    ///
    /// # Errors
    /// Returns `RegistrationError::InvalidMetadata` for an empty name or
    /// description, and `RegistrationError::NameConflict` when the name is
    /// already held by a capability from a different source. A capability from
    /// the same source replaces the existing entry.
    pub async fn register(&self, cap: Arc<dyn Capability>) -> Result<(), RegistrationError> {
        let mut caps = self.capabilities.write().await;
        Self::try_insert(&mut caps, cap)
    }

    /// Registers a batch of capabilities, returning the ones that were
    /// rejected as `(name, error)` pairs (empty on full success).
    ///
    /// The whole batch is applied under one write lock, so readers never
    /// observe it half-applied.
    pub async fn register_all(
        &self,
        capabilities: Vec<Arc<dyn Capability>>,
    ) -> Vec<(String, RegistrationError)> {
        let mut caps = self.capabilities.write().await;
        Self::try_insert_all(&mut caps, capabilities)
    }

    /// Looks up a capability by name.
    pub async fn get(&self, name: &str) -> Option<Arc<dyn Capability>> {
        self.capabilities.read().await.get(name).cloned()
    }

    /// Returns a metadata summary of every registered capability, sorted by name.
    pub async fn catalog(&self) -> Vec<CapabilitySummary> {
        let mut out: Vec<_> = self
            .capabilities
            .read()
            .await
            .values()
            .map(|c| CapabilitySummary {
                name: c.name().to_string(),
                description: c.description().to_string(),
                source: c.source(),
                paired_skill: c.paired_skill().map(String::from),
                risk_level: c.risk_level(),
                parameters_schema: c.parameters_schema(),
            })
            .collect();
        out.sort_by(|a, b| a.name.cmp(&b.name));
        out
    }

    /// Returns the names of all registered capabilities, sorted alphabetically.
    pub async fn list_names(&self) -> Vec<String> {
        let mut names: Vec<_> = self.capabilities.read().await.keys().cloned().collect();
        // Map keys are unique, so an unstable sort yields the same order.
        names.sort_unstable();
        names
    }

    /// Replaces every capability owned by `plugin_name` with `new_capabilities`.
    ///
    /// Removal and re-registration happen under a single write lock, so
    /// concurrent readers see either the old set or the new set, never a
    /// plugin with its capabilities missing or only partly re-registered.
    /// Rejected capabilities are returned as `(name, error)` pairs; the old
    /// entries are removed regardless.
    pub async fn reload_plugin(
        &self,
        plugin_name: &str,
        new_capabilities: Vec<Arc<dyn Capability>>,
    ) -> Vec<(String, RegistrationError)> {
        let target = CapabilitySource::Plugin(plugin_name.to_string());
        let mut caps = self.capabilities.write().await;
        caps.retain(|_, c| c.source() != target);
        Self::try_insert_all(&mut caps, new_capabilities)
    }

    /// Rebuilds the whole registry from the tools in `registry`, discarding
    /// every previously registered capability.
    ///
    /// Each tool is wrapped in a `ToolRegistryCapability`. The replacement map
    /// is built without holding the lock and then swapped in atomically, so
    /// readers never see an empty or partially synced registry.
    ///
    /// # Errors
    /// Returns a single message listing every tool that failed registration.
    /// The tools that succeeded are still installed.
    pub async fn sync_from_tool_registry(&self, registry: Arc<ToolRegistry>) -> Result<(), String> {
        let mut tools: Vec<_> = registry.list();
        // Process tools, and therefore report errors, in a deterministic order.
        tools.sort_by_key(|t| t.name());

        let mut fresh: HashMap<String, Arc<dyn Capability>> = HashMap::with_capacity(tools.len());
        let mut errors = Vec::new();
        for tool in tools {
            let source = match tool.plugin_owner() {
                Some(p) => CapabilitySource::Plugin(p.to_string()),
                None => CapabilitySource::BuiltIn,
            };
            let cap: Arc<dyn Capability> = Arc::new(ToolRegistryCapability {
                registry: Arc::clone(&registry),
                name: tool.name().to_string(),
                source,
            });
            if let Err(e) = Self::try_insert(&mut fresh, cap) {
                errors.push(e.to_string());
            }
        }

        *self.capabilities.write().await = fresh;

        if errors.is_empty() {
            Ok(())
        } else {
            Err(format!(
                "capability sync partially failed ({} error(s)): {}",
                errors.len(),
                errors.join("; ")
            ))
        }
    }

    /// Alias for [`Self::sync_from_tool_registry`].
    pub async fn resync_tools(&self, registry: Arc<ToolRegistry>) -> Result<(), String> {
        self.sync_from_tool_registry(registry).await
    }
}
/// Adapter that exposes one tool from a shared `ToolRegistry` as a
/// [`Capability`].
///
/// Only the tool's name and source are stored. Description, risk level,
/// schema, paired skill, and execution are looked up in `registry` by name on
/// every call, with fallbacks when the tool is missing.
pub struct ToolRegistryCapability {
/// Registry the backing tool is resolved from on every call.
registry: Arc<ToolRegistry>,
/// Name of the backing tool; also used as the capability name.
name: String,
/// Origin recorded when the adapter was built (built-in or owning plugin).
source: CapabilitySource,
}
#[async_trait]
impl Capability for ToolRegistryCapability {
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Description of the backing tool, or `""` when the tool is no longer
    /// present in the registry.
    fn description(&self) -> &str {
        match self.registry.get(&self.name) {
            Some(tool) => tool.description(),
            None => "",
        }
    }

    /// Risk level of the backing tool. Fails closed: reports `Forbidden` when
    /// the tool is missing.
    fn risk_level(&self) -> RiskLevel {
        self.registry
            .get(&self.name)
            .map_or(RiskLevel::Forbidden, |tool| tool.risk_level())
    }

    /// Parameter schema of the backing tool, or a bare object schema when the
    /// tool is missing.
    fn parameters_schema(&self) -> Value {
        match self.registry.get(&self.name) {
            Some(tool) => tool.parameters_schema(),
            None => serde_json::json!({"type": "object"}),
        }
    }

    fn source(&self) -> CapabilitySource {
        self.source.clone()
    }

    /// Paired skill of the backing tool, if the tool exists and declares one.
    fn paired_skill(&self) -> Option<&str> {
        let tool = self.registry.get(&self.name)?;
        tool.paired_skill()
    }

    /// Forwards execution to the backing tool.
    ///
    /// # Errors
    /// Returns a `ToolError` if the tool is no longer in the registry, or
    /// whatever error the tool itself produces.
    async fn execute(&self, params: Value, ctx: &ToolContext) -> Result<ToolResult, ToolError> {
        let Some(tool) = self.registry.get(&self.name) else {
            return Err(ToolError {
                message: format!("tool '{}' not found in ToolRegistry", self.name),
            });
        };
        tool.execute(params, ctx).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::{EchoTool, ToolRegistry};

    /// Minimal in-memory capability used to exercise the registry without a
    /// backing `ToolRegistry`.
    struct StubCapability {
        name: String,
        description: String,
        source: CapabilitySource,
        paired: Option<String>,
    }

    impl StubCapability {
        fn builtin(name: &str) -> Self {
            Self {
                name: name.to_string(),
                description: format!("Description for {name}"),
                source: CapabilitySource::BuiltIn,
                paired: None,
            }
        }

        fn plugin(name: &str, plugin_name: &str) -> Self {
            Self {
                name: name.to_string(),
                description: format!("Plugin capability {name}"),
                source: CapabilitySource::Plugin(plugin_name.to_string()),
                paired: None,
            }
        }

        fn with_paired_skill(mut self, skill: &str) -> Self {
            self.paired = Some(skill.to_string());
            self
        }
    }

    #[async_trait::async_trait]
    impl Capability for StubCapability {
        fn name(&self) -> &str {
            &self.name
        }
        fn description(&self) -> &str {
            &self.description
        }
        fn risk_level(&self) -> roboticus_core::RiskLevel {
            roboticus_core::RiskLevel::Safe
        }
        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({"type": "object"})
        }
        fn source(&self) -> CapabilitySource {
            self.source.clone()
        }
        fn paired_skill(&self) -> Option<&str> {
            self.paired.as_deref()
        }
        async fn execute(
            &self,
            _params: serde_json::Value,
            _ctx: &crate::tools::ToolContext,
        ) -> Result<crate::tools::ToolResult, crate::tools::ToolError> {
            Ok(crate::tools::ToolResult {
                output: format!("executed {}", self.name),
                metadata: None,
            })
        }
    }

    #[test]
    fn capability_source_display_builtin() {
        let src = CapabilitySource::BuiltIn;
        assert_eq!(src.to_string(), "built-in");
    }

    #[test]
    fn capability_source_display_plugin() {
        let src = CapabilitySource::Plugin("my-plugin".to_string());
        assert_eq!(src.to_string(), "plugin:my-plugin");
    }

    #[test]
    fn capability_source_equality() {
        assert_eq!(CapabilitySource::BuiltIn, CapabilitySource::BuiltIn);
        assert_ne!(
            CapabilitySource::BuiltIn,
            CapabilitySource::Plugin("x".to_string())
        );
        assert_eq!(
            CapabilitySource::Plugin("a".to_string()),
            CapabilitySource::Plugin("a".to_string())
        );
        assert_ne!(
            CapabilitySource::Plugin("a".to_string()),
            CapabilitySource::Plugin("b".to_string())
        );
    }

    #[test]
    fn registration_error_name_conflict_display() {
        let err = RegistrationError::NameConflict {
            name: "tool-x".to_string(),
            existing_source: CapabilitySource::BuiltIn,
        };
        let msg = err.to_string();
        assert!(msg.contains("tool-x"));
        assert!(msg.contains("built-in"));
    }

    #[test]
    fn registration_error_invalid_metadata_display() {
        let err = RegistrationError::InvalidMetadata("name is empty".to_string());
        assert!(err.to_string().contains("name is empty"));
    }

    #[tokio::test]
    async fn new_registry_is_empty() {
        let reg = CapabilityRegistry::new();
        assert!(reg.is_empty().await);
    }

    #[tokio::test]
    async fn default_registry_is_empty() {
        let reg = CapabilityRegistry::default();
        assert!(reg.is_empty().await);
    }

    #[tokio::test]
    async fn register_valid_capability_succeeds() {
        let reg = CapabilityRegistry::new();
        let cap: Arc<dyn Capability> = Arc::new(StubCapability::builtin("tool-a"));
        reg.register(cap).await.unwrap();
        assert!(!reg.is_empty().await);
    }

    #[tokio::test]
    async fn register_rejects_empty_name() {
        let reg = CapabilityRegistry::new();
        let bad = Arc::new(StubCapability {
            name: String::new(),
            description: "desc".into(),
            source: CapabilitySource::BuiltIn,
            paired: None,
        });
        let err = reg.register(bad as Arc<dyn Capability>).await.unwrap_err();
        assert!(matches!(err, RegistrationError::InvalidMetadata(_)));
    }

    #[tokio::test]
    async fn register_rejects_empty_description() {
        let reg = CapabilityRegistry::new();
        let bad = Arc::new(StubCapability {
            name: "tool-x".into(),
            description: String::new(),
            source: CapabilitySource::BuiltIn,
            paired: None,
        });
        let err = reg.register(bad as Arc<dyn Capability>).await.unwrap_err();
        assert!(matches!(err, RegistrationError::InvalidMetadata(_)));
    }

    #[tokio::test]
    async fn register_conflicts_across_different_sources() {
        let reg = CapabilityRegistry::new();
        let builtin: Arc<dyn Capability> = Arc::new(StubCapability::builtin("tool-x"));
        reg.register(builtin).await.unwrap();
        let plugin: Arc<dyn Capability> = Arc::new(StubCapability::plugin("tool-x", "my-plugin"));
        let err = reg.register(plugin).await.unwrap_err();
        assert!(matches!(err, RegistrationError::NameConflict { .. }));
    }

    #[tokio::test]
    async fn register_same_source_overwrites_without_error() {
        let reg = CapabilityRegistry::new();
        let cap1: Arc<dyn Capability> = Arc::new(StubCapability::builtin("tool-x"));
        reg.register(cap1).await.unwrap();
        let cap2: Arc<dyn Capability> = Arc::new(StubCapability::builtin("tool-x"));
        reg.register(cap2).await.unwrap();
        let names = reg.list_names().await;
        let count = names.iter().filter(|n| n.as_str() == "tool-x").count();
        assert_eq!(count, 1, "should not duplicate same-source re-registration");
    }

    #[tokio::test]
    async fn register_all_returns_empty_errors_on_success() {
        let reg = CapabilityRegistry::new();
        let caps: Vec<Arc<dyn Capability>> = vec![
            Arc::new(StubCapability::builtin("a")),
            Arc::new(StubCapability::builtin("b")),
        ];
        let errors = reg.register_all(caps).await;
        assert!(errors.is_empty());
        let names = reg.list_names().await;
        assert_eq!(names, vec!["a", "b"]);
    }

    #[tokio::test]
    async fn register_all_collects_errors_for_failed_registrations() {
        let reg = CapabilityRegistry::new();
        reg.register(Arc::new(StubCapability::builtin("b")) as Arc<dyn Capability>)
            .await
            .unwrap();
        let caps: Vec<Arc<dyn Capability>> = vec![
            Arc::new(StubCapability::builtin("a")),
            Arc::new(StubCapability::plugin("b", "my-plugin")),
        ];
        let errors = reg.register_all(caps).await;
        assert_eq!(errors.len(), 1);
        assert_eq!(errors[0].0, "b");
    }

    #[tokio::test]
    async fn get_returns_registered_capability() {
        let reg = CapabilityRegistry::new();
        reg.register(Arc::new(StubCapability::builtin("my-tool")) as Arc<dyn Capability>)
            .await
            .unwrap();
        let found = reg.get("my-tool").await;
        assert!(found.is_some());
        assert_eq!(found.unwrap().name(), "my-tool");
    }

    #[tokio::test]
    async fn get_returns_none_for_missing() {
        let reg = CapabilityRegistry::new();
        assert!(reg.get("nonexistent").await.is_none());
    }

    #[tokio::test]
    async fn catalog_is_sorted_by_name() {
        let reg = CapabilityRegistry::new();
        for name in ["zebra", "alpha", "middle"] {
            reg.register(Arc::new(StubCapability::builtin(name)) as Arc<dyn Capability>)
                .await
                .unwrap();
        }
        let catalog = reg.catalog().await;
        let names: Vec<&str> = catalog.iter().map(|c| c.name.as_str()).collect();
        assert_eq!(names, vec!["alpha", "middle", "zebra"]);
    }

    #[tokio::test]
    async fn catalog_includes_paired_skill() {
        let reg = CapabilityRegistry::new();
        let cap = StubCapability::builtin("paired-tool").with_paired_skill("my-skill.md");
        reg.register(Arc::new(cap) as Arc<dyn Capability>)
            .await
            .unwrap();
        let catalog = reg.catalog().await;
        assert_eq!(catalog[0].paired_skill, Some("my-skill.md".to_string()));
    }

    #[tokio::test]
    async fn catalog_empty_when_no_capabilities() {
        let reg = CapabilityRegistry::new();
        assert!(reg.catalog().await.is_empty());
    }

    #[tokio::test]
    async fn list_names_sorted_alphabetically() {
        let reg = CapabilityRegistry::new();
        for name in ["c", "a", "b"] {
            reg.register(Arc::new(StubCapability::builtin(name)) as Arc<dyn Capability>)
                .await
                .unwrap();
        }
        let names = reg.list_names().await;
        assert_eq!(names, vec!["a", "b", "c"]);
    }

    #[tokio::test]
    async fn reload_plugin_removes_old_and_adds_new() {
        let reg = CapabilityRegistry::new();
        reg.register(
            Arc::new(StubCapability::plugin("old-cap-1", "plugin-a")) as Arc<dyn Capability>
        )
        .await
        .unwrap();
        reg.register(
            Arc::new(StubCapability::plugin("old-cap-2", "plugin-a")) as Arc<dyn Capability>
        )
        .await
        .unwrap();
        reg.register(
            Arc::new(StubCapability::plugin("other-cap", "plugin-b")) as Arc<dyn Capability>
        )
        .await
        .unwrap();
        let new_caps: Vec<Arc<dyn Capability>> =
            vec![Arc::new(StubCapability::plugin("new-cap", "plugin-a"))];
        let errors = reg.reload_plugin("plugin-a", new_caps).await;
        assert!(errors.is_empty());
        let names = reg.list_names().await;
        assert!(
            !names.contains(&"old-cap-1".to_string()),
            "old cap 1 should be gone"
        );
        assert!(
            !names.contains(&"old-cap-2".to_string()),
            "old cap 2 should be gone"
        );
        assert!(
            names.contains(&"new-cap".to_string()),
            "new cap should be present"
        );
        assert!(
            names.contains(&"other-cap".to_string()),
            "other plugin should be untouched"
        );
    }

    #[tokio::test]
    async fn reload_plugin_with_empty_list_removes_all_plugin_caps() {
        let reg = CapabilityRegistry::new();
        reg.register(Arc::new(StubCapability::plugin("cap-x", "plugin-c")) as Arc<dyn Capability>)
            .await
            .unwrap();
        reg.register(Arc::new(StubCapability::builtin("builtin-cap")) as Arc<dyn Capability>)
            .await
            .unwrap();
        let errors = reg.reload_plugin("plugin-c", vec![]).await;
        assert!(errors.is_empty());
        let names = reg.list_names().await;
        assert!(!names.contains(&"cap-x".to_string()));
        assert!(names.contains(&"builtin-cap".to_string()));
    }

    #[tokio::test]
    async fn reload_plugin_reports_conflict_with_builtin() {
        let reg = CapabilityRegistry::new();
        reg.register(Arc::new(StubCapability::builtin("shared")) as Arc<dyn Capability>)
            .await
            .unwrap();
        reg.register(Arc::new(StubCapability::plugin("old", "plugin-d")) as Arc<dyn Capability>)
            .await
            .unwrap();
        let new_caps: Vec<Arc<dyn Capability>> = vec![
            Arc::new(StubCapability::plugin("shared", "plugin-d")),
            Arc::new(StubCapability::plugin("fresh", "plugin-d")),
        ];
        let errors = reg.reload_plugin("plugin-d", new_caps).await;
        assert_eq!(errors.len(), 1);
        assert_eq!(errors[0].0, "shared");
        assert!(matches!(errors[0].1, RegistrationError::NameConflict { .. }));
        let shared = reg.get("shared").await.unwrap();
        assert_eq!(shared.source(), CapabilitySource::BuiltIn);
        assert_eq!(reg.list_names().await, vec!["fresh", "shared"]);
    }

    #[tokio::test]
    async fn sync_populates_catalog() {
        let mut reg = ToolRegistry::new();
        reg.register(Box::new(EchoTool));
        let reg = Arc::new(reg);
        let caps = CapabilityRegistry::new();
        caps.sync_from_tool_registry(Arc::clone(&reg))
            .await
            .unwrap();
        assert!(!caps.is_empty().await);
        let names = caps.list_names().await;
        assert!(names.iter().any(|n| n == "echo"));
    }

    #[tokio::test]
    async fn sync_drops_capabilities_absent_from_tool_registry() {
        let mut reg = ToolRegistry::new();
        reg.register(Box::new(EchoTool));
        let reg = Arc::new(reg);
        let caps = CapabilityRegistry::new();
        caps.register(Arc::new(StubCapability::builtin("stale")) as Arc<dyn Capability>)
            .await
            .unwrap();
        caps.sync_from_tool_registry(Arc::clone(&reg))
            .await
            .unwrap();
        assert!(caps.get("stale").await.is_none());
        assert!(caps.get("echo").await.is_some());
    }

    #[tokio::test]
    async fn resync_tools_replaces_previous_catalog() {
        let mut reg = ToolRegistry::new();
        reg.register(Box::new(EchoTool));
        let reg = Arc::new(reg);
        let caps = CapabilityRegistry::new();
        caps.sync_from_tool_registry(Arc::clone(&reg))
            .await
            .unwrap();
        let initial_count = caps.list_names().await.len();
        caps.resync_tools(Arc::clone(&reg)).await.unwrap();
        assert_eq!(caps.list_names().await.len(), initial_count);
    }

    #[tokio::test]
    async fn tool_registry_capability_missing_tool_returns_forbidden() {
        let empty_reg = Arc::new(ToolRegistry::new());
        let cap = ToolRegistryCapability {
            registry: Arc::clone(&empty_reg),
            name: "nonexistent".to_string(),
            source: CapabilitySource::BuiltIn,
        };
        let risk = cap.risk_level();
        assert_eq!(risk, roboticus_core::RiskLevel::Forbidden);
    }

    #[tokio::test]
    async fn tool_registry_capability_missing_tool_description_empty() {
        let empty_reg = Arc::new(ToolRegistry::new());
        let cap = ToolRegistryCapability {
            registry: Arc::clone(&empty_reg),
            name: "ghost".to_string(),
            source: CapabilitySource::BuiltIn,
        };
        assert_eq!(cap.description(), "");
    }
}