1use std::collections::HashMap;
8use std::fmt;
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use tokio::sync::RwLock;
15
16use roboticus_core::RiskLevel;
17use roboticus_core::config::McpTransport;
18
19use crate::tools::{ToolContext, ToolError, ToolRegistry, ToolResult};
20
/// Origin of a registered capability.
///
/// The registry compares sources to decide whether a same-named registration
/// is a replacement (equal source) or a conflict (different source).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum CapabilitySource {
    /// Rendered as `built-in` by `Display`.
    BuiltIn,
    /// Capability owned by the plugin with the given name.
    Plugin(String),
    /// Capability exposed by an MCP server.
    Mcp {
        /// Server name; `reload_mcp_server` matches existing entries on it.
        server: String,
        /// Transport associated with the server.
        transport: McpTransport,
    },
}
31
32impl fmt::Display for CapabilitySource {
33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 match self {
35 Self::BuiltIn => write!(f, "built-in"),
36 Self::Plugin(p) => write!(f, "plugin:{p}"),
37 Self::Mcp { server, transport } => {
38 let t = match transport {
39 McpTransport::Stdio => "stdio",
40 McpTransport::Sse => "sse",
41 McpTransport::Http => "http",
42 McpTransport::WebSocket => "ws",
43 };
44 write!(f, "mcp:{server}({t})")
45 }
46 }
47 }
48}
49
/// An executable unit that can be stored in a [`CapabilityRegistry`].
///
/// Implementors must be `Send + Sync` because the registry shares them as
/// `Arc<dyn Capability>`.
#[async_trait]
pub trait Capability: Send + Sync {
    /// Name of the capability; the registry uses it as the map key.
    /// Registration rejects an empty name.
    fn name(&self) -> &str;

    /// Description of the capability. Registration rejects an empty description.
    fn description(&self) -> &str;

    /// Risk classification of the capability.
    fn risk_level(&self) -> RiskLevel;

    /// JSON value describing the accepted parameters
    /// (presumably a JSON Schema object — confirm against implementors).
    fn parameters_schema(&self) -> Value;

    /// Where this capability comes from.
    fn source(&self) -> CapabilitySource;

    /// Optional name of a skill paired with this capability; `None` by default.
    fn paired_skill(&self) -> Option<&str> {
        None
    }

    /// Runs the capability with `params` in the given tool context.
    async fn execute(&self, params: Value, ctx: &ToolContext) -> Result<ToolResult, ToolError>;
}
66
/// Owned, serializable snapshot of one capability's metadata, as produced by
/// `CapabilityRegistry::catalog`.
#[derive(Debug, Clone, Serialize)]
pub struct CapabilitySummary {
    /// Copy of `Capability::name`.
    pub name: String,
    /// Copy of `Capability::description`.
    pub description: String,
    /// Value of `Capability::source`.
    pub source: CapabilitySource,
    /// Owned copy of `Capability::paired_skill`, if any.
    pub paired_skill: Option<String>,
    /// Value of `Capability::risk_level`.
    pub risk_level: RiskLevel,
    /// Value of `Capability::parameters_schema`.
    pub parameters_schema: Value,
}
77
/// Reasons a capability can be refused by the registry.
#[derive(Debug)]
pub enum RegistrationError {
    /// A capability with the same name is already registered by a different source.
    NameConflict {
        /// The contested capability name.
        name: String,
        /// Source of the capability that already holds the name.
        existing_source: CapabilitySource,
    },
    /// The capability's metadata failed validation; the string explains why.
    InvalidMetadata(String),
}
86
87impl fmt::Display for RegistrationError {
88 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89 match self {
90 Self::NameConflict {
91 name,
92 existing_source,
93 } => write!(
94 f,
95 "capability name conflict: '{name}' already registered ({existing_source})"
96 ),
97 Self::InvalidMetadata(m) => write!(f, "invalid capability metadata: {m}"),
98 }
99 }
100}
101
// Marker impl: `RegistrationError` wraps no underlying error, so the default
// `source()` (returning `None`) is kept.
impl std::error::Error for RegistrationError {}
103
/// Async registry of capabilities keyed by name.
pub struct CapabilityRegistry {
    /// Capability name -> shared capability, guarded by an async read/write lock.
    capabilities: RwLock<HashMap<String, Arc<dyn Capability>>>,
}
108
109impl Default for CapabilityRegistry {
110 fn default() -> Self {
111 Self::new()
112 }
113}
114
115impl CapabilityRegistry {
116 pub fn new() -> Self {
117 Self {
118 capabilities: RwLock::new(HashMap::new()),
119 }
120 }
121
122 pub async fn is_empty(&self) -> bool {
123 self.capabilities.read().await.is_empty()
124 }
125
126 pub async fn register(&self, cap: Arc<dyn Capability>) -> Result<(), RegistrationError> {
127 let name = cap.name().to_string();
128 if name.is_empty() {
129 return Err(RegistrationError::InvalidMetadata(
130 "capability name is empty".into(),
131 ));
132 }
133 if cap.description().is_empty() {
134 return Err(RegistrationError::InvalidMetadata(
135 "capability description is empty".into(),
136 ));
137 }
138
139 let has_separator = name.contains("::");
140 let is_mcp = matches!(cap.source(), CapabilitySource::Mcp { .. });
141 if is_mcp && !has_separator {
142 return Err(RegistrationError::InvalidMetadata(format!(
143 "MCP capability '{name}' must use '::' separator (e.g., 'server::tool_name')"
144 )));
145 }
146 if !is_mcp && has_separator {
147 return Err(RegistrationError::InvalidMetadata(format!(
148 "non-MCP capability '{name}' must not use '::' separator (reserved for MCP)"
149 )));
150 }
151
152 let mut caps = self.capabilities.write().await;
153 if let Some(existing) = caps.get(&name)
154 && existing.source() != cap.source()
155 {
156 return Err(RegistrationError::NameConflict {
157 name,
158 existing_source: existing.source(),
159 });
160 }
161 caps.insert(name, cap);
162 Ok(())
163 }
164
165 pub async fn register_all(
166 &self,
167 capabilities: Vec<Arc<dyn Capability>>,
168 ) -> Vec<(String, RegistrationError)> {
169 let mut errors = Vec::new();
170 for cap in capabilities {
171 let name = cap.name().to_string();
172 if let Err(e) = self.register(cap).await {
173 errors.push((name, e));
174 }
175 }
176 errors
177 }
178
179 pub async fn get(&self, name: &str) -> Option<Arc<dyn Capability>> {
180 self.capabilities.read().await.get(name).cloned()
181 }
182
183 pub async fn catalog(&self) -> Vec<CapabilitySummary> {
184 let mut out: Vec<_> = self
185 .capabilities
186 .read()
187 .await
188 .values()
189 .map(|c| CapabilitySummary {
190 name: c.name().to_string(),
191 description: c.description().to_string(),
192 source: c.source(),
193 paired_skill: c.paired_skill().map(String::from),
194 risk_level: c.risk_level(),
195 parameters_schema: c.parameters_schema(),
196 })
197 .collect();
198 out.sort_by(|a, b| a.name.cmp(&b.name));
199 out
200 }
201
202 pub async fn list_names(&self) -> Vec<String> {
203 let mut names: Vec<_> = self.capabilities.read().await.keys().cloned().collect();
204 names.sort();
205 names
206 }
207
208 pub async fn reload_plugin(
213 &self,
214 plugin_name: &str,
215 new_capabilities: Vec<Arc<dyn Capability>>,
216 ) -> Vec<(String, RegistrationError)> {
217 let target = CapabilitySource::Plugin(plugin_name.to_string());
218
219 let mut errors: Vec<(String, RegistrationError)> = Vec::new();
222 let mut valid: Vec<Arc<dyn Capability>> = Vec::new();
223 for cap in new_capabilities {
224 let name = cap.name().to_string();
225 if name.is_empty() {
226 errors.push((
227 name,
228 RegistrationError::InvalidMetadata("capability name is empty".into()),
229 ));
230 continue;
231 }
232 if cap.description().is_empty() {
233 errors.push((
234 name,
235 RegistrationError::InvalidMetadata("capability description is empty".into()),
236 ));
237 continue;
238 }
239 if cap.name().contains("::") {
240 errors.push((
241 name,
242 RegistrationError::InvalidMetadata(format!(
243 "non-MCP capability '{}' must not use '::' separator (reserved for MCP)",
244 cap.name()
245 )),
246 ));
247 continue;
248 }
249 valid.push(cap);
250 }
251
252 let mut caps = self.capabilities.write().await;
255 caps.retain(|_, c| c.source() != target);
256 for cap in valid {
257 let name = cap.name().to_string();
258 if let Some(existing) = caps.get(&name)
260 && existing.source() != cap.source()
261 {
262 errors.push((
263 name,
264 RegistrationError::NameConflict {
265 name: cap.name().to_string(),
266 existing_source: existing.source(),
267 },
268 ));
269 continue;
270 }
271 caps.insert(name, cap);
272 }
273 drop(caps);
274
275 errors
276 }
277
278 pub async fn reload_mcp_server(
283 &self,
284 server_name: &str,
285 new_capabilities: Vec<Arc<dyn Capability>>,
286 ) -> Result<(), RegistrationError> {
287 for cap in &new_capabilities {
290 if cap.name().is_empty() {
291 return Err(RegistrationError::InvalidMetadata(
292 "capability name is empty".into(),
293 ));
294 }
295 if cap.description().is_empty() {
296 return Err(RegistrationError::InvalidMetadata(
297 "capability description is empty".into(),
298 ));
299 }
300 if !cap.name().contains("::") {
301 return Err(RegistrationError::InvalidMetadata(format!(
302 "MCP capability '{}' must use '::' separator",
303 cap.name()
304 )));
305 }
306 }
307
308 let mut caps = self.capabilities.write().await;
309
310 caps.retain(|_, existing| {
312 !matches!(existing.source(), CapabilitySource::Mcp { server, .. } if server == server_name)
313 });
314
315 for cap in new_capabilities {
317 let name = cap.name().to_string();
318 caps.insert(name, cap);
319 }
320
321 Ok(())
322 }
323
324 pub async fn sync_from_tool_registry(&self, registry: Arc<ToolRegistry>) -> Result<(), String> {
326 let mut caps = self.capabilities.write().await;
327 caps.clear();
328 drop(caps);
329
330 let mut tools: Vec<_> = registry.list();
331 tools.sort_by_key(|t| t.name());
332 let mut errors = Vec::new();
333 for tool in tools {
334 let name = tool.name().to_string();
335 let source = match tool.plugin_owner() {
336 Some(p) => CapabilitySource::Plugin(p.to_string()),
337 None => CapabilitySource::BuiltIn,
338 };
339 let cap = Arc::new(ToolRegistryCapability {
340 registry: Arc::clone(®istry),
341 name,
342 source,
343 });
344 if let Err(e) = self.register(cap).await {
345 errors.push(e.to_string());
346 }
347 }
348 if errors.is_empty() {
349 Ok(())
350 } else {
351 Err(format!(
352 "capability sync partially failed ({} error(s)): {}",
353 errors.len(),
354 errors.join("; ")
355 ))
356 }
357 }
358
359 pub async fn resync_tools(&self, registry: Arc<ToolRegistry>) -> Result<(), String> {
361 self.sync_from_tool_registry(registry).await
362 }
363}
364
/// [`Capability`] adapter for a tool stored in a [`ToolRegistry`].
///
/// Only the tool's name is kept; the tool itself is looked up in the registry
/// each time a trait method is called.
pub struct ToolRegistryCapability {
    /// Registry that holds the wrapped tool.
    registry: Arc<ToolRegistry>,
    /// Name used to look the tool up in `registry`.
    name: String,
    /// Source reported for this capability.
    source: CapabilitySource,
}
371
372#[async_trait]
373impl Capability for ToolRegistryCapability {
374 fn name(&self) -> &str {
375 &self.name
376 }
377
378 fn description(&self) -> &str {
379 self.registry
380 .get(&self.name)
381 .map(|t| t.description())
382 .unwrap_or("")
383 }
384
385 fn risk_level(&self) -> RiskLevel {
386 self.registry
387 .get(&self.name)
388 .map(|t| t.risk_level())
389 .unwrap_or(RiskLevel::Forbidden)
390 }
391
392 fn parameters_schema(&self) -> Value {
393 self.registry
394 .get(&self.name)
395 .map(|t| t.parameters_schema())
396 .unwrap_or_else(|| serde_json::json!({"type": "object"}))
397 }
398
399 fn source(&self) -> CapabilitySource {
400 self.source.clone()
401 }
402
403 fn paired_skill(&self) -> Option<&str> {
404 self.registry.get(&self.name).and_then(|t| t.paired_skill())
405 }
406
407 async fn execute(&self, params: Value, ctx: &ToolContext) -> Result<ToolResult, ToolError> {
408 let tool = self.registry.get(&self.name).ok_or_else(|| ToolError {
409 message: format!("tool '{}' not found in ToolRegistry", self.name),
410 })?;
411 tool.execute(params, ctx).await
412 }
413}
414
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::ToolRegistry;

    /// Asserts that `err` is `InvalidMetadata` and that its message contains `fragment`.
    fn assert_invalid_metadata(err: &RegistrationError, fragment: &str) {
        assert!(
            matches!(err, RegistrationError::InvalidMetadata(_)),
            "expected InvalidMetadata, got: {err}"
        );
        assert!(err.to_string().contains(fragment));
    }

    #[tokio::test]
    async fn sync_populates_catalog() {
        use crate::tools::EchoTool;

        let mut reg = ToolRegistry::new();
        reg.register(Box::new(EchoTool));
        let reg = Arc::new(reg);
        let caps = CapabilityRegistry::new();
        caps.sync_from_tool_registry(Arc::clone(&reg))
            .await
            .unwrap();
        assert!(!caps.is_empty().await);
        let names = caps.list_names().await;
        assert!(names.iter().any(|n| n == "echo"));
    }

    // --- Display of MCP sources, one test per transport ---

    #[test]
    fn mcp_source_display_stdio() {
        let source = CapabilitySource::Mcp {
            server: "github".into(),
            transport: McpTransport::Stdio,
        };
        assert_eq!(source.to_string(), "mcp:github(stdio)");
    }

    #[test]
    fn mcp_source_display_sse() {
        let source = CapabilitySource::Mcp {
            server: "linear".into(),
            transport: McpTransport::Sse,
        };
        assert_eq!(source.to_string(), "mcp:linear(sse)");
    }

    #[test]
    fn mcp_source_display_http() {
        let source = CapabilitySource::Mcp {
            server: "sentry".into(),
            transport: McpTransport::Http,
        };
        assert_eq!(source.to_string(), "mcp:sentry(http)");
    }

    #[test]
    fn mcp_source_display_websocket() {
        let source = CapabilitySource::Mcp {
            server: "relay".into(),
            transport: McpTransport::WebSocket,
        };
        assert_eq!(source.to_string(), "mcp:relay(ws)");
    }

    /// Minimal capability with a configurable name and source, used to
    /// exercise the registry's validation rules.
    struct StubCap {
        name: String,
        source: CapabilitySource,
    }

    #[async_trait::async_trait]
    impl Capability for StubCap {
        fn name(&self) -> &str {
            &self.name
        }
        fn description(&self) -> &str {
            "stub"
        }
        fn risk_level(&self) -> roboticus_core::RiskLevel {
            roboticus_core::RiskLevel::Safe
        }
        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({"type": "object"})
        }
        fn source(&self) -> CapabilitySource {
            self.source.clone()
        }
        async fn execute(
            &self,
            _params: serde_json::Value,
            _ctx: &crate::tools::ToolContext,
        ) -> Result<crate::tools::ToolResult, crate::tools::ToolError> {
            Ok(crate::tools::ToolResult {
                output: "stub".into(),
                metadata: None,
            })
        }
    }

    // --- `register` separator rules ---

    #[tokio::test]
    async fn register_rejects_builtin_with_separator() {
        let reg = CapabilityRegistry::new();
        let cap = Arc::new(StubCap {
            name: "ns::tool".into(),
            source: CapabilitySource::BuiltIn,
        });
        let err = reg.register(cap).await.unwrap_err();
        assert_invalid_metadata(&err, "reserved for MCP");
    }

    #[tokio::test]
    async fn register_rejects_plugin_with_separator() {
        let reg = CapabilityRegistry::new();
        let cap = Arc::new(StubCap {
            name: "ns::tool".into(),
            source: CapabilitySource::Plugin("myplugin".into()),
        });
        let err = reg.register(cap).await.unwrap_err();
        assert_invalid_metadata(&err, "reserved for MCP");
    }

    #[tokio::test]
    async fn register_rejects_mcp_without_separator() {
        let reg = CapabilityRegistry::new();
        let cap = Arc::new(StubCap {
            name: "tool_name".into(),
            source: CapabilitySource::Mcp {
                server: "github".into(),
                transport: McpTransport::Stdio,
            },
        });
        let err = reg.register(cap).await.unwrap_err();
        assert_invalid_metadata(&err, "must use '::' separator");
    }

    #[tokio::test]
    async fn register_allows_mcp_with_separator() {
        let reg = CapabilityRegistry::new();
        let cap = Arc::new(StubCap {
            name: "github::create_issue".into(),
            source: CapabilitySource::Mcp {
                server: "github".into(),
                transport: McpTransport::Stdio,
            },
        });
        reg.register(cap).await.unwrap();
        assert!(reg.get("github::create_issue").await.is_some());
    }

    #[tokio::test]
    async fn register_allows_builtin_without_separator() {
        let reg = CapabilityRegistry::new();
        let cap = Arc::new(StubCap {
            name: "bash".into(),
            source: CapabilitySource::BuiltIn,
        });
        reg.register(cap).await.unwrap();
        assert!(reg.get("bash").await.is_some());
    }

    /// Builds a stdio MCP stub named `<server>::<tool>`.
    fn make_mcp_cap(server: &str, tool: &str) -> Arc<StubCap> {
        Arc::new(StubCap {
            name: format!("{server}::{tool}"),
            source: CapabilitySource::Mcp {
                server: server.into(),
                transport: McpTransport::Stdio,
            },
        })
    }

    // --- `reload_mcp_server` ---

    #[tokio::test]
    async fn atomic_reload_swaps_all_at_once() {
        let registry = CapabilityRegistry::new();

        let old_cap = make_mcp_cap("myserver", "old_tool");
        registry.register(old_cap).await.unwrap();
        assert!(registry.get("myserver::old_tool").await.is_some());

        let new_cap = make_mcp_cap("myserver", "new_tool");
        registry
            .reload_mcp_server("myserver", vec![new_cap])
            .await
            .unwrap();

        let summaries = registry.catalog().await;
        assert!(
            summaries.iter().any(|s| s.name == "myserver::new_tool"),
            "new tool should be in the catalog"
        );
        assert!(
            !summaries.iter().any(|s| s.name == "myserver::old_tool"),
            "old tool should have been removed"
        );
    }

    #[tokio::test]
    async fn atomic_reload_rejects_cap_without_separator() {
        let registry = CapabilityRegistry::new();
        let bad_cap = Arc::new(StubCap {
            name: "notnamespaced".into(),
            source: CapabilitySource::Mcp {
                server: "myserver".into(),
                transport: McpTransport::Stdio,
            },
        });
        let err = registry
            .reload_mcp_server("myserver", vec![bad_cap])
            .await
            .unwrap_err();
        assert_invalid_metadata(&err, "must use '::' separator");
    }

    #[tokio::test]
    async fn atomic_reload_only_removes_matching_server() {
        let registry = CapabilityRegistry::new();

        let cap_a = make_mcp_cap("server_a", "tool1");
        let cap_b = make_mcp_cap("server_b", "tool2");
        registry.register(cap_a).await.unwrap();
        registry.register(cap_b).await.unwrap();

        let new_cap = make_mcp_cap("server_a", "tool_new");
        registry
            .reload_mcp_server("server_a", vec![new_cap])
            .await
            .unwrap();

        assert!(
            registry.get("server_b::tool2").await.is_some(),
            "server_b tools should not be touched"
        );
        assert!(
            registry.get("server_a::tool_new").await.is_some(),
            "new server_a tool should be present"
        );
        assert!(
            registry.get("server_a::tool1").await.is_none(),
            "old server_a tool should be gone"
        );
    }

    // --- `reload_plugin` ---

    #[tokio::test]
    async fn reload_plugin_holds_lock_atomically() {
        let registry = CapabilityRegistry::new();

        let old_cap = Arc::new(StubCap {
            name: "old_action".into(),
            source: CapabilitySource::Plugin("myplugin".into()),
        });
        registry.register(old_cap).await.unwrap();

        let new_cap = Arc::new(StubCap {
            name: "new_action".into(),
            source: CapabilitySource::Plugin("myplugin".into()),
        });
        let errors = registry.reload_plugin("myplugin", vec![new_cap]).await;
        assert!(errors.is_empty(), "unexpected errors: {errors:?}");

        let names = registry.list_names().await;
        assert!(
            names.contains(&"new_action".to_string()),
            "new tool should be registered"
        );
        assert!(
            !names.contains(&"old_action".to_string()),
            "old tool should be removed"
        );
    }
}