1use std::sync::Arc;
7
8use async_trait::async_trait;
9use serde_json::Value;
10use tokio::sync::RwLock;
11
12use roboticus_core::RiskLevel;
13use roboticus_core::config::McpTransport;
14
15use crate::capability::{Capability, CapabilitySource};
16use crate::tools::{ToolContext, ToolError, ToolResult};
17
18use super::client::{DiscoveredTool, LiveMcpConnection};
19
20pub struct McpCapability {
26 prefixed_name: String,
28 server_name: String,
29 tool_name: String,
30 description: String,
31 input_schema: Value,
32 transport: McpTransport,
33 risk_level: RiskLevel,
34 connection: Arc<RwLock<LiveMcpConnection>>,
36}
37
38impl McpCapability {
39 pub fn new(
41 server_name: &str,
42 tool: &DiscoveredTool,
43 transport: McpTransport,
44 connection: Arc<RwLock<LiveMcpConnection>>,
45 ) -> Self {
46 Self {
47 prefixed_name: format!("{server_name}::{}", tool.name),
48 server_name: server_name.to_string(),
49 tool_name: tool.name.clone(),
50 description: tool.description.clone(),
51 input_schema: tool.input_schema.clone(),
52 transport,
53 risk_level: RiskLevel::Caution,
54 connection,
55 }
56 }
57
58 pub fn with_risk_level(mut self, level: RiskLevel) -> Self {
60 self.risk_level = level;
61 self
62 }
63}
64
65#[async_trait]
66impl Capability for McpCapability {
67 fn name(&self) -> &str {
68 &self.prefixed_name
69 }
70
71 fn description(&self) -> &str {
72 &self.description
73 }
74
75 fn risk_level(&self) -> RiskLevel {
76 self.risk_level
77 }
78
79 fn parameters_schema(&self) -> Value {
80 self.input_schema.clone()
81 }
82
83 fn source(&self) -> CapabilitySource {
84 CapabilitySource::Mcp {
85 server: self.server_name.clone(),
86 transport: self.transport.clone(),
87 }
88 }
89
90 async fn execute(&self, params: Value, _ctx: &ToolContext) -> Result<ToolResult, ToolError> {
91 let conn = self.connection.read().await;
92 if !conn.is_alive() {
93 return Err(ToolError {
94 message: format!("MCP server '{}' is not connected", self.server_name),
95 });
96 }
97
98 let result = conn
99 .call_tool(&self.tool_name, params)
100 .await
101 .map_err(|e| ToolError {
102 message: format!("MCP tool '{}' call failed: {e}", self.prefixed_name),
103 })?;
104
105 let is_error = result
107 .get("is_error")
108 .and_then(|v| v.as_bool())
109 .unwrap_or(false);
110 let content = result
111 .get("content")
112 .and_then(|v| v.as_str())
113 .unwrap_or("")
114 .to_string();
115
116 if is_error {
117 Err(ToolError {
118 message: format!("MCP tool error: {content}"),
119 })
120 } else {
121 Ok(ToolResult {
122 output: content,
123 metadata: Some(serde_json::json!({
124 "mcp_server": self.server_name,
125 "mcp_tool": self.tool_name,
126 })),
127 })
128 }
129 }
130}
131
132pub fn bridge_tools(
136 server_name: &str,
137 tools: &[DiscoveredTool],
138 transport: McpTransport,
139 connection: Arc<RwLock<LiveMcpConnection>>,
140) -> Vec<McpCapability> {
141 tools
142 .iter()
143 .map(|tool| {
144 McpCapability::new(
145 server_name,
146 tool,
147 transport.clone(),
148 Arc::clone(&connection),
149 )
150 })
151 .collect()
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::sync::RwLock;

    use crate::mcp::client::test_support;
    use crate::tools::{ToolContext, ToolSandboxSnapshot};
    use roboticus_core::InputAuthority;

    /// Builds a discovered tool with a minimal object schema.
    fn make_tool(name: &str, desc: &str) -> DiscoveredTool {
        DiscoveredTool {
            name: name.into(),
            description: desc.into(),
            input_schema: serde_json::json!({"type": "object"}),
        }
    }

    /// Builds a tool context for tests; `execute` in this module ignores it.
    fn test_ctx() -> ToolContext {
        ToolContext {
            session_id: "test-session".into(),
            agent_id: "test-agent".into(),
            agent_name: "test-agent".into(),
            authority: InputAuthority::Creator,
            workspace_root: std::env::current_dir().unwrap(),
            tool_allowed_paths: vec![],
            channel: None,
            db: None,
            sandbox: ToolSandboxSnapshot::default(),
        }
    }

    /// Exercises `bridge_tools` itself rather than re-deriving the name with `format!`.
    #[tokio::test]
    async fn bridge_tools_produces_correct_names() {
        let (conn, server_handle) = test_support::echo_connection("github").await.unwrap();
        let conn = Arc::new(RwLock::new(conn));
        let tools = [
            make_tool("create_issue", "Create a GitHub issue"),
            make_tool("list_repos", "List repositories"),
        ];

        let caps = bridge_tools("github", &tools, McpTransport::Sse, Arc::clone(&conn));

        let names: Vec<&str> = caps.iter().map(|cap| cap.name()).collect();
        assert_eq!(names, ["github::create_issue", "github::list_repos"]);
        assert_eq!(caps[1].description(), "List repositories");
        // Bridged capabilities keep the default risk level chosen by `McpCapability::new`.
        assert!(caps
            .iter()
            .all(|cap| cap.risk_level() == RiskLevel::Caution));

        server_handle.abort();
        let _ = server_handle.await;
    }

    /// Checks that `McpCapability::new` prefixes the registry name but keeps the raw
    /// tool name for remote calls.
    #[tokio::test]
    async fn prefixed_name_uses_double_colon() {
        let (conn, server_handle) = test_support::echo_connection("linear").await.unwrap();
        let tool = make_tool("create_ticket", "Create a ticket");
        let cap = McpCapability::new(
            "linear",
            &tool,
            McpTransport::Sse,
            Arc::new(RwLock::new(conn)),
        );

        assert_eq!(cap.name(), "linear::create_ticket");
        assert_eq!(cap.tool_name, "create_ticket");
        assert_eq!(cap.server_name, "linear");

        server_handle.abort();
        let _ = server_handle.await;
    }

    #[tokio::test]
    async fn bridge_tools_builds_capabilities_with_expected_metadata() {
        let (conn, server_handle) = test_support::echo_connection("remote-test").await.unwrap();
        let conn = Arc::new(RwLock::new(conn));
        let caps = {
            let read = conn.read().await;
            bridge_tools(
                "remote-test",
                read.tools(),
                McpTransport::Sse,
                Arc::clone(&conn),
            )
        };

        assert_eq!(caps.len(), 1);
        let cap = &caps[0];
        assert_eq!(cap.name(), "remote-test::echo");
        assert_eq!(cap.description(), "Echo back the provided text");
        assert_eq!(cap.parameters_schema()["type"], "object");
        match cap.source() {
            CapabilitySource::Mcp { server, transport } => {
                assert_eq!(server, "remote-test");
                assert!(matches!(transport, McpTransport::Sse));
            }
            other => panic!("expected MCP source, got {other:?}"),
        }

        server_handle.abort();
        let _ = server_handle.await;
    }

    #[tokio::test]
    async fn mcp_capability_executes_remote_tool_and_returns_metadata() {
        let (conn, server_handle) = test_support::echo_connection("remote-test").await.unwrap();
        let conn = Arc::new(RwLock::new(conn));
        let tool = {
            let read = conn.read().await;
            read.tools()[0].clone()
        };
        let cap = McpCapability::new("remote-test", &tool, McpTransport::Sse, Arc::clone(&conn))
            .with_risk_level(RiskLevel::Dangerous);

        let result = cap
            .execute(serde_json::json!({ "text": "hello bridge" }), &test_ctx())
            .await
            .unwrap();
        assert_eq!(cap.risk_level(), RiskLevel::Dangerous);
        assert_eq!(result.output, "hello bridge");
        assert_eq!(
            result.metadata.as_ref().unwrap()["mcp_server"],
            "remote-test"
        );
        assert_eq!(result.metadata.as_ref().unwrap()["mcp_tool"], "echo");

        server_handle.abort();
        let _ = server_handle.await;
    }
}