use std::sync::Arc;
use async_trait::async_trait;
use serde_json::Value;
use tokio::sync::RwLock;
use roboticus_core::RiskLevel;
use roboticus_core::config::McpTransport;
use crate::capability::{Capability, CapabilitySource};
use crate::tools::{ToolContext, ToolError, ToolResult};
use super::client::{DiscoveredTool, LiveMcpConnection};
/// One tool discovered on an MCP server, adapted to the [`Capability`] trait.
pub struct McpCapability {
    /// Registry name, built as `"{server_name}::{tool_name}"`.
    prefixed_name: String,
    /// Name of the MCP server that exposes the tool.
    server_name: String,
    /// Tool name as reported by the server; this is what `call_tool` receives.
    tool_name: String,
    /// Description reported by the server for this tool.
    description: String,
    /// Parameter schema reported by the server, returned by `parameters_schema`.
    input_schema: Value,
    /// Risk classification; `RiskLevel::Caution` unless overridden via `with_risk_level`.
    risk_level: RiskLevel,
    /// Transport used to reach the server, surfaced through `Capability::source`.
    transport: McpTransport,
    /// Connection handle; `bridge_tools` gives every capability of a server a clone
    /// of the same `Arc`.
    connection: Arc<RwLock<LiveMcpConnection>>,
}
impl McpCapability {
pub fn new(
server_name: &str,
tool: &DiscoveredTool,
transport: McpTransport,
connection: Arc<RwLock<LiveMcpConnection>>,
) -> Self {
Self {
prefixed_name: format!("{server_name}::{}", tool.name),
server_name: server_name.to_string(),
tool_name: tool.name.clone(),
description: tool.description.clone(),
input_schema: tool.input_schema.clone(),
transport,
risk_level: RiskLevel::Caution,
connection,
}
}
pub fn with_risk_level(mut self, level: RiskLevel) -> Self {
self.risk_level = level;
self
}
}
#[async_trait]
impl Capability for McpCapability {
fn name(&self) -> &str {
&self.prefixed_name
}
fn description(&self) -> &str {
&self.description
}
fn risk_level(&self) -> RiskLevel {
self.risk_level
}
fn parameters_schema(&self) -> Value {
self.input_schema.clone()
}
fn source(&self) -> CapabilitySource {
CapabilitySource::Mcp {
server: self.server_name.clone(),
transport: self.transport.clone(),
}
}
async fn execute(&self, params: Value, _ctx: &ToolContext) -> Result<ToolResult, ToolError> {
let conn = self.connection.read().await;
if !conn.is_alive() {
return Err(ToolError {
message: format!("MCP server '{}' is not connected", self.server_name),
});
}
let result = conn
.call_tool(&self.tool_name, params)
.await
.map_err(|e| ToolError {
message: format!("MCP tool '{}' call failed: {e}", self.prefixed_name),
})?;
let is_error = result
.get("is_error")
.and_then(|v| v.as_bool())
.unwrap_or(false);
let content = result
.get("content")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
if is_error {
Err(ToolError {
message: format!("MCP tool error: {content}"),
})
} else {
Ok(ToolResult {
output: content,
metadata: Some(serde_json::json!({
"mcp_server": self.server_name,
"mcp_tool": self.tool_name,
})),
})
}
}
}
/// Wraps every tool discovered on `server_name` as an [`McpCapability`].
///
/// All returned capabilities share `connection` through `Arc::clone` and carry a
/// clone of `transport`. The output preserves the order of `tools`.
pub fn bridge_tools(
    server_name: &str,
    tools: &[DiscoveredTool],
    transport: McpTransport,
    connection: Arc<RwLock<LiveMcpConnection>>,
) -> Vec<McpCapability> {
    let mut capabilities = Vec::with_capacity(tools.len());
    for discovered in tools {
        let shared_connection = Arc::clone(&connection);
        capabilities.push(McpCapability::new(
            server_name,
            discovered,
            transport.clone(),
            shared_connection,
        ));
    }
    capabilities
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::sync::RwLock;
    use crate::mcp::client::test_support;
    use crate::tools::{ToolContext, ToolSandboxSnapshot};
    use roboticus_core::InputAuthority;

    /// Builds a `DiscoveredTool` with a minimal object schema.
    fn make_tool(name: &str, desc: &str) -> DiscoveredTool {
        DiscoveredTool {
            name: name.into(),
            description: desc.into(),
            input_schema: serde_json::json!({"type": "object"}),
        }
    }

    /// Minimal `ToolContext` for invoking capabilities in tests.
    fn test_ctx() -> ToolContext {
        ToolContext {
            session_id: "test-session".into(),
            agent_id: "test-agent".into(),
            agent_name: "test-agent".into(),
            authority: InputAuthority::Creator,
            workspace_root: std::env::current_dir().unwrap(),
            tool_allowed_paths: vec![],
            channel: None,
            db: None,
            sandbox: ToolSandboxSnapshot::default(),
        }
    }

    #[tokio::test]
    async fn bridge_tools_produces_correct_names() {
        // bridge_tools only stores the connection, so the echo server can back tools
        // it does not actually expose; nothing is executed here.
        let (conn, server_handle) = test_support::echo_connection("remote-test").await.unwrap();
        let conn = Arc::new(RwLock::new(conn));
        let tools = [
            make_tool("create_issue", "Create a GitHub issue"),
            make_tool("list_repos", "List repositories"),
        ];

        let caps = bridge_tools("github", &tools, McpTransport::Sse, Arc::clone(&conn));

        let names: Vec<&str> = caps.iter().map(|cap| cap.name()).collect();
        assert_eq!(names, ["github::create_issue", "github::list_repos"]);
        assert_eq!(caps[0].description(), "Create a GitHub issue");
        assert_eq!(caps[1].description(), "List repositories");

        server_handle.abort();
        let _ = server_handle.await;
    }

    #[tokio::test]
    async fn prefixed_name_uses_double_colon() {
        let (conn, server_handle) = test_support::echo_connection("remote-test").await.unwrap();
        let conn = Arc::new(RwLock::new(conn));

        let cap = McpCapability::new(
            "linear",
            &make_tool("create_ticket", "Create a ticket"),
            McpTransport::Sse,
            conn,
        );

        assert_eq!(cap.name(), "linear::create_ticket");
        // Without with_risk_level, capabilities default to Caution.
        assert_eq!(cap.risk_level(), RiskLevel::Caution);

        server_handle.abort();
        let _ = server_handle.await;
    }

    #[tokio::test]
    async fn bridge_tools_builds_capabilities_with_expected_metadata() {
        let (conn, server_handle) = test_support::echo_connection("remote-test").await.unwrap();
        let conn = Arc::new(RwLock::new(conn));
        let caps = {
            let read = conn.read().await;
            bridge_tools(
                "remote-test",
                read.tools(),
                McpTransport::Sse,
                Arc::clone(&conn),
            )
        };
        assert_eq!(caps.len(), 1);
        let cap = &caps[0];
        assert_eq!(cap.name(), "remote-test::echo");
        assert_eq!(cap.description(), "Echo back the provided text");
        assert_eq!(cap.parameters_schema()["type"], "object");
        match cap.source() {
            CapabilitySource::Mcp { server, transport } => {
                assert_eq!(server, "remote-test");
                assert!(matches!(transport, McpTransport::Sse));
            }
            other => panic!("expected MCP source, got {other:?}"),
        }
        server_handle.abort();
        let _ = server_handle.await;
    }

    #[tokio::test]
    async fn mcp_capability_executes_remote_tool_and_returns_metadata() {
        let (conn, server_handle) = test_support::echo_connection("remote-test").await.unwrap();
        let conn = Arc::new(RwLock::new(conn));
        let tool = {
            let read = conn.read().await;
            read.tools()[0].clone()
        };
        let cap = McpCapability::new("remote-test", &tool, McpTransport::Sse, Arc::clone(&conn))
            .with_risk_level(RiskLevel::Dangerous);
        let result = cap
            .execute(serde_json::json!({ "text": "hello bridge" }), &test_ctx())
            .await
            .unwrap();
        assert_eq!(cap.risk_level(), RiskLevel::Dangerous);
        assert_eq!(result.output, "hello bridge");
        assert_eq!(
            result.metadata.as_ref().unwrap()["mcp_server"],
            "remote-test"
        );
        assert_eq!(result.metadata.as_ref().unwrap()["mcp_tool"], "echo");
        server_handle.abort();
        let _ = server_handle.await;
    }
}