Skip to main content

batuta/agent/tool/
mcp_client.rs

1//! MCP Client Tool — wraps external MCP server tools.
2//!
//! Each `McpClientTool` represents a single tool discovered from
//! an external MCP server. The tool proxies `execute` calls through
//! an `McpTransport` trait, which abstracts over the wire transport
//! (stdio is implemented in this module; SSE/HTTP are possible future transports).
6//!
7//! # Privacy Enforcement (Poka-Yoke)
8//!
9//! MCP servers are subject to `PrivacyTier` rules:
10//! - **Sovereign**: Only `stdio` transport allowed (local process)
11//! - **Private/Standard**: All transports allowed
12//!
13//! # References
14//!
15//! - arXiv:2505.02279 — MCP interoperability survey
16//! - arXiv:2503.23278 — MCP security analysis
17
18use std::time::Duration;
19
20use async_trait::async_trait;
21
22use super::{Tool, ToolResult};
23use crate::agent::capability::Capability;
24use crate::agent::driver::ToolDefinition;
25
26/// Transport abstraction for MCP server communication.
27///
28/// Separates the tool from the transport layer so that:
29/// - Tests use `MockMcpTransport`
30/// - Production uses `StdioMcpTransport` (Phase 2: `pmcp::Client` v2.3)
31/// - Future: SSE / WebSocket transports (both available in pmcp v2.3)
32#[async_trait]
33pub trait McpTransport: Send + Sync {
34    /// Call a tool on the MCP server.
35    async fn call_tool(&self, tool_name: &str, input: serde_json::Value) -> Result<String, String>;
36
37    /// Server name for capability matching.
38    fn server_name(&self) -> &str;
39}
40
41/// MCP client tool that proxies calls to an external MCP server.
42pub struct McpClientTool {
43    /// MCP server name (for capability matching).
44    server_name: String,
45    /// Tool name on the MCP server.
46    tool_name: String,
47    /// Tool description.
48    description: String,
49    /// JSON Schema for tool input.
50    input_schema: serde_json::Value,
51    /// Transport for calling the MCP server.
52    transport: Box<dyn McpTransport>,
53    /// Execution timeout.
54    timeout: Duration,
55}
56
57impl McpClientTool {
58    /// Create a new MCP client tool.
59    pub fn new(
60        server_name: impl Into<String>,
61        tool_name: impl Into<String>,
62        description: impl Into<String>,
63        input_schema: serde_json::Value,
64        transport: Box<dyn McpTransport>,
65    ) -> Self {
66        Self {
67            server_name: server_name.into(),
68            tool_name: tool_name.into(),
69            description: description.into(),
70            input_schema,
71            transport,
72            timeout: Duration::from_secs(60),
73        }
74    }
75
76    /// Set the execution timeout.
77    #[must_use]
78    pub fn with_timeout(mut self, timeout: Duration) -> Self {
79        self.timeout = timeout;
80        self
81    }
82
83    /// The prefixed tool name: `mcp_{server}_{tool}`.
84    fn prefixed_name(&self) -> String {
85        format!("mcp_{}_{}", self.server_name, self.tool_name)
86    }
87}
88
89#[async_trait]
90impl Tool for McpClientTool {
91    fn name(&self) -> &'static str {
92        // Leak the name to get 'static lifetime.
93        // This is safe because tool names live for the process.
94        Box::leak(self.prefixed_name().into_boxed_str())
95    }
96
97    fn definition(&self) -> ToolDefinition {
98        ToolDefinition {
99            name: self.prefixed_name(),
100            description: format!("[MCP:{}] {}", self.server_name, self.description),
101            input_schema: self.input_schema.clone(),
102        }
103    }
104
105    async fn execute(&self, input: serde_json::Value) -> ToolResult {
106        match self.transport.call_tool(&self.tool_name, input).await {
107            Ok(content) => ToolResult::success(content),
108            Err(e) => ToolResult::error(format!(
109                "MCP call to {}:{} failed: {}",
110                self.server_name, self.tool_name, e
111            )),
112        }
113    }
114
115    fn required_capability(&self) -> Capability {
116        Capability::Mcp { server: self.server_name.clone(), tool: self.tool_name.clone() }
117    }
118
119    fn timeout(&self) -> Duration {
120        self.timeout
121    }
122}
123
/// Stdio MCP transport — runs the server as a local subprocess and exchanges JSON-RPC 2.0
/// messages over its stdin/stdout.
///
/// The server is expected to answer MCP `tools/list` and `tools/call` requests. The subprocess
/// is **not** kept alive between calls: every request spawns `command` afresh, writes a single
/// request line to its stdin, closes stdin, and waits for the process to exit before parsing
/// what it printed on stdout.
///
/// # Privacy
///
/// This transport is allowed in the Sovereign tier because it talks to a local process over
/// pipes; the transport itself performs no network I/O. It does not restrict what the
/// subprocess itself may do.
pub struct StdioMcpTransport {
    /// Server name reported through `McpTransport::server_name`.
    server: String,
    /// Program followed by its arguments, e.g. `["node", "server.js"]`.
    command: Vec<String>,
}

impl StdioMcpTransport {
    /// Create a stdio transport for the given server.
    ///
    /// `command` is the full command line (e.g., `["node", "server.js"]`); its first element is
    /// the program to spawn. An empty `command` is accepted here, but every request sent through
    /// the transport will then fail.
    pub fn new(server: impl Into<String>, command: Vec<String>) -> Self {
        Self { server: server.into(), command }
    }
}
146
147#[async_trait]
148impl McpTransport for StdioMcpTransport {
149    async fn call_tool(&self, tool_name: &str, input: serde_json::Value) -> Result<String, String> {
150        let request = serde_json::json!({
151            "jsonrpc": "2.0",
152            "id": 1,
153            "method": "tools/call",
154            "params": {
155                "name": tool_name,
156                "arguments": input,
157            }
158        });
159        let response = self.send_jsonrpc(&request).await?;
160        let result = response.get("result").ok_or("no result in response")?;
161        // MCP tools/call returns { content: [{ text: "..." }] }
162        if let Some(content) = result.get("content") {
163            if let Some(arr) = content.as_array() {
164                let texts: Vec<&str> =
165                    arr.iter().filter_map(|c| c.get("text").and_then(|t| t.as_str())).collect();
166                if !texts.is_empty() {
167                    return Ok(texts.join("\n"));
168                }
169            }
170        }
171        Ok(serde_json::to_string(result)
172            .unwrap_or_else(|e| format!(r#"{{"error": "serialize: {e}"}}"#)))
173    }
174
175    fn server_name(&self) -> &str {
176        &self.server
177    }
178}
179
180/// Discovered tool info from MCP `tools/list`.
181#[derive(Debug, Clone)]
182pub struct DiscoveredTool {
183    /// Tool name on the MCP server.
184    pub name: String,
185    /// Human-readable description.
186    pub description: String,
187    /// JSON Schema for input parameters.
188    pub input_schema: serde_json::Value,
189}
190
191impl StdioMcpTransport {
192    /// Discover available tools via MCP `tools/list`.
193    pub async fn discover_tools(&self) -> Result<Vec<DiscoveredTool>, String> {
194        let request = serde_json::json!({
195            "jsonrpc": "2.0",
196            "id": 1,
197            "method": "tools/list",
198            "params": {}
199        });
200        let response = self.send_jsonrpc(&request).await?;
201        let result = response.get("result").ok_or("no result in tools/list response")?;
202        let tools =
203            result.get("tools").and_then(|t| t.as_array()).ok_or("no tools array in response")?;
204        let mut discovered = Vec::new();
205        for tool in tools {
206            let name = tool.get("name").and_then(|n| n.as_str()).unwrap_or("").to_string();
207            let desc = tool.get("description").and_then(|d| d.as_str()).unwrap_or("").to_string();
208            let schema = tool.get("inputSchema").cloned().unwrap_or(serde_json::json!({}));
209            if !name.is_empty() {
210                discovered.push(DiscoveredTool { name, description: desc, input_schema: schema });
211            }
212        }
213        Ok(discovered)
214    }
215
216    /// Send a JSON-RPC request and return the parsed response.
217    async fn send_jsonrpc(&self, request: &serde_json::Value) -> Result<serde_json::Value, String> {
218        if self.command.is_empty() {
219            return Err("stdio transport: empty command".into());
220        }
221        let request_str =
222            serde_json::to_string(request).map_err(|e| format!("serialize request: {e}"))?;
223        let mut child = tokio::process::Command::new(&self.command[0])
224            .args(&self.command[1..])
225            .stdin(std::process::Stdio::piped())
226            .stdout(std::process::Stdio::piped())
227            .stderr(std::process::Stdio::piped())
228            .kill_on_drop(true)
229            .spawn()
230            .map_err(|e| format!("spawn {}: {e}", self.command[0]))?;
231        if let Some(mut stdin) = child.stdin.take() {
232            use tokio::io::AsyncWriteExt;
233            stdin
234                .write_all(request_str.as_bytes())
235                .await
236                .map_err(|e| format!("write stdin: {e}"))?;
237            stdin.write_all(b"\n").await.map_err(|e| format!("write newline: {e}"))?;
238            drop(stdin);
239        }
240        let result = child.wait_with_output().await.map_err(|e| format!("wait: {e}"))?;
241        if !result.status.success() {
242            let stderr = String::from_utf8_lossy(&result.stderr);
243            return Err(format!("process exited {}: {}", result.status, stderr.trim()));
244        }
245        let stdout = String::from_utf8_lossy(&result.stdout);
246        let response: serde_json::Value =
247            serde_json::from_str(stdout.trim()).map_err(|e| format!("parse response: {e}"))?;
248        if let Some(error) = response.get("error") {
249            let msg = error.get("message").and_then(|m| m.as_str()).unwrap_or("unknown error");
250            return Err(msg.to_string());
251        }
252        Ok(response)
253    }
254}
255
/// Build `McpClientTool`s for every permitted tool of every stdio MCP server in the manifest.
///
/// Servers whose transport is not `stdio` are skipped. Each remaining server's subprocess is
/// queried with `tools/list`; if discovery fails, a warning is logged and that server is
/// skipped. A discovered tool is kept only when the server's `capabilities` contain `"*"` or
/// the tool's exact name. All tools of one server share a single transport instance.
#[cfg(feature = "agents-mcp")]
pub async fn discover_mcp_tools(
    manifest: &crate::agent::manifest::AgentManifest,
) -> Vec<McpClientTool> {
    use crate::agent::manifest::McpTransport;
    use std::sync::Arc;

    let mut registered = Vec::new();
    let stdio_servers =
        manifest.mcp_servers.iter().filter(|s| matches!(s.transport, McpTransport::Stdio));
    for server in stdio_servers {
        let shared = Arc::new(StdioMcpTransport::new(&server.name, server.command.clone()));
        let listing = match shared.discover_tools().await {
            Ok(listing) => listing,
            Err(e) => {
                tracing::warn!(server = %server.name, error = %e, "MCP tool discovery failed");
                continue;
            }
        };
        for info in listing {
            let permitted = server.capabilities.iter().any(|cap| cap == "*" || cap == &info.name);
            if permitted {
                registered.push(McpClientTool::new(
                    &server.name,
                    &info.name,
                    &info.description,
                    info.input_schema,
                    Box::new(SharedTransport(Arc::clone(&shared))),
                ));
            } else {
                tracing::debug!(
                    server = %server.name,
                    tool = %info.name,
                    "MCP tool not in capabilities, skipping"
                );
            }
        }
    }
    registered
}
306
/// Adapter that lets several `McpClientTool`s share one `StdioMcpTransport`.
///
/// `McpClientTool` owns a `Box<dyn McpTransport>`; boxing this `Arc` newtype lets every tool
/// discovered from the same server point at the same transport.
#[cfg(feature = "agents-mcp")]
struct SharedTransport(std::sync::Arc<StdioMcpTransport>);

#[cfg(feature = "agents-mcp")]
#[async_trait]
impl McpTransport for SharedTransport {
    /// Delegates to the shared `StdioMcpTransport`.
    async fn call_tool(&self, tool_name: &str, input: serde_json::Value) -> Result<String, String> {
        <StdioMcpTransport as McpTransport>::call_tool(&self.0, tool_name, input).await
    }

    /// Delegates to the shared `StdioMcpTransport`.
    fn server_name(&self) -> &str {
        <StdioMcpTransport as McpTransport>::server_name(&self.0)
    }
}
321
/// Scripted MCP transport for tests that replays pre-configured results in order.
pub struct MockMcpTransport {
    /// Server name reported by `server_name`.
    server: String,
    /// Remaining canned results, consumed front-to-back by `call_tool`.
    responses: std::sync::Mutex<Vec<Result<String, String>>>,
}

impl MockMcpTransport {
    /// Create a mock for `server` that hands out `responses` one per call, in order.
    pub fn new(server: impl Into<String>, responses: Vec<Result<String, String>>) -> Self {
        let server = server.into();
        let responses = std::sync::Mutex::new(responses);
        Self { server, responses }
    }
}
334
335#[async_trait]
336impl McpTransport for MockMcpTransport {
337    async fn call_tool(
338        &self,
339        _tool_name: &str,
340        _input: serde_json::Value,
341    ) -> Result<String, String> {
342        let mut responses = self.responses.lock().expect("mock transport lock");
343        if responses.is_empty() {
344            Err("mock transport exhausted".into())
345        } else {
346            responses.remove(0)
347        }
348    }
349
350    fn server_name(&self) -> &str {
351        &self.server
352    }
353}
354
355#[cfg(test)]
356#[path = "mcp_client_tests.rs"]
357mod tests;