Skip to main content

construct/tools/
mcp_client.rs

1//! MCP (Model Context Protocol) client — connects to external tool servers.
2//!
3//! Supports multiple transports: stdio (spawn local process), HTTP, and SSE.
4
5use std::collections::HashMap;
6use std::sync::Arc;
7#[cfg(not(target_has_atomic = "64"))]
8use std::sync::atomic::AtomicU32;
9#[cfg(target_has_atomic = "64")]
10use std::sync::atomic::AtomicU64;
11use std::sync::atomic::Ordering;
12
13use anyhow::{Context, Result, anyhow, bail};
14use serde_json::json;
15use tokio::sync::Mutex;
16use tokio::time::{Duration, timeout};
17
18use crate::config::schema::McpServerConfig;
19use crate::tools::mcp_protocol::{
20    JsonRpcRequest, MCP_PROTOCOL_VERSION, McpToolDef, McpToolsListResult,
21};
22use crate::tools::mcp_transport::{McpTransportConn, create_transport};
23
/// Upper bound (seconds) on how long `McpServer::connect` waits for the server
/// to answer each handshake request (`initialize`, `tools/list`).
///
/// This keeps a hung server from blocking the daemon forever. One minute leaves
/// room for servers with a heavy startup (venv creation, network discovery).
const RECV_TIMEOUT_SECS: u64 = 60;

/// Tool-call timeout (seconds) used when the server config does not set its own
/// `tool_timeout_secs`.
const DEFAULT_TOOL_TIMEOUT_SECS: u64 = 3 * 60;

/// Hard safety ceiling (seconds) applied to every tool-call timeout, including
/// explicitly configured ones.
const MAX_TOOL_TIMEOUT_SECS: u64 = 10 * 60;
34
35// ── Internal server state ──────────────────────────────────────────────────
36
37struct McpServerInner {
38    config: McpServerConfig,
39    transport: Box<dyn McpTransportConn>,
40    #[cfg(target_has_atomic = "64")]
41    next_id: AtomicU64,
42    #[cfg(not(target_has_atomic = "64"))]
43    next_id: AtomicU32,
44    tools: Vec<McpToolDef>,
45}
46
47// ── McpServer ──────────────────────────────────────────────────────────────
48
49/// A live connection to one MCP server (any transport).
50#[derive(Clone)]
51pub struct McpServer {
52    inner: Arc<Mutex<McpServerInner>>,
53}
54
55impl McpServer {
56    /// Connect to the server, perform the initialize handshake, and fetch the tool list.
57    pub async fn connect(config: McpServerConfig) -> Result<Self> {
58        // Create transport based on config
59        let mut transport = create_transport(&config).with_context(|| {
60            format!(
61                "failed to create transport for MCP server `{}`",
62                config.name
63            )
64        })?;
65
66        // Initialize handshake
67        let id = 1u64;
68        let init_req = JsonRpcRequest::new(
69            id,
70            "initialize",
71            json!({
72                "protocolVersion": MCP_PROTOCOL_VERSION,
73                "capabilities": {},
74                "clientInfo": {
75                    "name": "construct",
76                    "version": env!("CARGO_PKG_VERSION")
77                }
78            }),
79        );
80
81        let init_resp = timeout(
82            Duration::from_secs(RECV_TIMEOUT_SECS),
83            transport.send_and_recv(&init_req),
84        )
85        .await
86        .with_context(|| {
87            format!(
88                "MCP server `{}` timed out after {}s waiting for initialize response",
89                config.name, RECV_TIMEOUT_SECS
90            )
91        })??;
92
93        if init_resp.error.is_some() {
94            bail!(
95                "MCP server `{}` rejected initialize: {:?}",
96                config.name,
97                init_resp.error
98            );
99        }
100
101        // Notify server that client is initialized (no response expected for notifications)
102        // For notifications, we send but don't wait for response
103        let notif = JsonRpcRequest::notification("notifications/initialized", json!({}));
104        // Best effort - ignore errors for notifications
105        let _ = transport.send_and_recv(&notif).await;
106
107        // Fetch available tools
108        let id = 2u64;
109        let list_req = JsonRpcRequest::new(id, "tools/list", json!({}));
110
111        let list_resp = timeout(
112            Duration::from_secs(RECV_TIMEOUT_SECS),
113            transport.send_and_recv(&list_req),
114        )
115        .await
116        .with_context(|| {
117            format!(
118                "MCP server `{}` timed out after {}s waiting for tools/list response",
119                config.name, RECV_TIMEOUT_SECS
120            )
121        })??;
122
123        let result = list_resp
124            .result
125            .ok_or_else(|| anyhow!("tools/list returned no result from `{}`", config.name))?;
126        let tool_list: McpToolsListResult = serde_json::from_value(result)
127            .with_context(|| format!("failed to parse tools/list from `{}`", config.name))?;
128
129        let tool_count = tool_list.tools.len();
130
131        let inner = McpServerInner {
132            config,
133            transport,
134            #[cfg(target_has_atomic = "64")]
135            next_id: AtomicU64::new(3), // Start at 3 since we used 1 and 2
136            #[cfg(not(target_has_atomic = "64"))]
137            next_id: AtomicU32::new(3), // Start at 3 since we used 1 and 2
138            tools: tool_list.tools,
139        };
140
141        tracing::info!(
142            "MCP server `{}` connected — {} tool(s) available",
143            inner.config.name,
144            tool_count
145        );
146
147        Ok(Self {
148            inner: Arc::new(Mutex::new(inner)),
149        })
150    }
151
152    /// Tools advertised by this server.
153    pub async fn tools(&self) -> Vec<McpToolDef> {
154        self.inner.lock().await.tools.clone()
155    }
156
157    /// Server display name.
158    pub async fn name(&self) -> String {
159        self.inner.lock().await.config.name.clone()
160    }
161
162    /// Call a tool on this server. Returns the raw JSON result.
163    pub async fn call_tool(
164        &self,
165        tool_name: &str,
166        arguments: serde_json::Value,
167    ) -> Result<serde_json::Value> {
168        let mut inner = self.inner.lock().await;
169        let id = inner.next_id.fetch_add(1, Ordering::Relaxed) as u64;
170        let req = JsonRpcRequest::new(
171            id,
172            "tools/call",
173            json!({ "name": tool_name, "arguments": arguments }),
174        );
175
176        // Use per-server tool timeout if configured, otherwise default.
177        // Cap at MAX_TOOL_TIMEOUT_SECS for safety.
178        let tool_timeout = inner
179            .config
180            .tool_timeout_secs
181            .unwrap_or(DEFAULT_TOOL_TIMEOUT_SECS)
182            .min(MAX_TOOL_TIMEOUT_SECS);
183
184        let resp = timeout(
185            Duration::from_secs(tool_timeout),
186            inner.transport.send_and_recv(&req),
187        )
188        .await
189        .map_err(|_| {
190            anyhow!(
191                "MCP server `{}` timed out after {}s during tool call `{tool_name}`",
192                inner.config.name,
193                tool_timeout
194            )
195        })?
196        .with_context(|| {
197            format!(
198                "MCP server `{}` error during tool call `{tool_name}`",
199                inner.config.name
200            )
201        })?;
202
203        if let Some(err) = resp.error {
204            bail!("MCP tool `{tool_name}` error {}: {}", err.code, err.message);
205        }
206        Ok(resp.result.unwrap_or(serde_json::Value::Null))
207    }
208}
209
210// ── McpRegistry ───────────────────────────────────────────────────────────
211
212/// Registry of all connected MCP servers, with a flat tool index.
213pub struct McpRegistry {
214    servers: Vec<McpServer>,
215    /// prefixed_name → (server_index, original_tool_name)
216    tool_index: HashMap<String, (usize, String)>,
217}
218
219impl McpRegistry {
220    /// Connect to all configured servers. Non-fatal: failures are logged and skipped.
221    pub async fn connect_all(configs: &[McpServerConfig]) -> Result<Self> {
222        let mut servers = Vec::new();
223        let mut tool_index = HashMap::new();
224
225        for config in configs {
226            match McpServer::connect(config.clone()).await {
227                Ok(server) => {
228                    let server_idx = servers.len();
229                    // Collect tools while holding the lock once, then release
230                    let tools = server.tools().await;
231                    for tool in &tools {
232                        // Prefix prevents name collisions across servers
233                        let prefixed = format!("{}__{}", config.name, tool.name);
234                        tool_index.insert(prefixed, (server_idx, tool.name.clone()));
235                    }
236                    servers.push(server);
237                }
238                // Non-fatal — log and continue with remaining servers
239                Err(e) => {
240                    tracing::error!("Failed to connect to MCP server `{}`: {:#}", config.name, e);
241                }
242            }
243        }
244
245        Ok(Self {
246            servers,
247            tool_index,
248        })
249    }
250
251    /// All prefixed tool names across all connected servers.
252    pub fn tool_names(&self) -> Vec<String> {
253        self.tool_index.keys().cloned().collect()
254    }
255
256    /// Tool definition for a given prefixed name (cloned).
257    pub async fn get_tool_def(&self, prefixed_name: &str) -> Option<McpToolDef> {
258        let (server_idx, original_name) = self.tool_index.get(prefixed_name)?;
259        let inner = self.servers[*server_idx].inner.lock().await;
260        inner
261            .tools
262            .iter()
263            .find(|t| &t.name == original_name)
264            .cloned()
265    }
266
267    /// Execute a tool by prefixed name.
268    pub async fn call_tool(
269        &self,
270        prefixed_name: &str,
271        arguments: serde_json::Value,
272    ) -> Result<String> {
273        let (server_idx, original_name) = self
274            .tool_index
275            .get(prefixed_name)
276            .ok_or_else(|| anyhow!("unknown MCP tool `{prefixed_name}`"))?;
277        let result = self.servers[*server_idx]
278            .call_tool(original_name, arguments)
279            .await?;
280        serde_json::to_string_pretty(&result)
281            .with_context(|| format!("failed to serialize result of MCP tool `{prefixed_name}`"))
282    }
283
284    pub fn is_empty(&self) -> bool {
285        self.servers.is_empty()
286    }
287
288    pub fn server_count(&self) -> usize {
289        self.servers.len()
290    }
291
292    pub fn tool_count(&self) -> usize {
293        self.tool_index.len()
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::schema::McpTransport;

    /// Stdio server config that runs `command` with every optional field empty.
    fn stdio_config(name: &str, command: &str) -> McpServerConfig {
        McpServerConfig {
            name: name.to_string(),
            command: command.to_string(),
            args: Vec::new(),
            env: std::collections::HashMap::default(),
            tool_timeout_secs: None,
            transport: McpTransport::Stdio,
            url: None,
            headers: std::collections::HashMap::default(),
        }
    }

    /// Config for a network `transport` that deliberately leaves `url` unset.
    fn config_without_url(transport: McpTransport) -> McpServerConfig {
        McpServerConfig {
            name: "test".into(),
            transport,
            ..Default::default()
        }
    }

    #[test]
    fn tool_name_prefix_format() {
        let (server, tool) = ("filesystem", "read_file");
        let prefixed = format!("{server}__{tool}");
        assert_eq!(prefixed, "filesystem__read_file");
    }

    #[tokio::test]
    async fn connect_nonexistent_command_fails_cleanly() {
        // Spawning a missing binary must come back as an error, never a panic.
        let config = stdio_config(
            "nonexistent",
            "/usr/bin/this_binary_does_not_exist_construct_test",
        );
        let outcome = McpServer::connect(config).await;
        assert!(outcome.is_err());
        let msg = outcome.err().unwrap().to_string();
        assert!(msg.contains("failed to create transport"), "got: {msg}");
    }

    #[tokio::test]
    async fn connect_all_nonfatal_on_single_failure() {
        // One unreachable server is logged and skipped; connect_all itself succeeds.
        let configs = [stdio_config("bad", "/usr/bin/does_not_exist_zc_test")];
        let registry = McpRegistry::connect_all(&configs)
            .await
            .expect("connect_all should not fail");
        assert!(registry.is_empty());
        assert_eq!(registry.tool_count(), 0);
    }

    #[test]
    fn http_transport_requires_url() {
        let config = config_without_url(McpTransport::Http);
        assert!(create_transport(&config).is_err());
    }

    #[test]
    fn sse_transport_requires_url() {
        let config = config_without_url(McpTransport::Sse);
        assert!(create_transport(&config).is_err());
    }

    // ── Empty registry (no servers) ────────────────────────────────────────

    #[tokio::test]
    async fn empty_registry_is_empty() {
        let registry = McpRegistry::connect_all(&[])
            .await
            .expect("connect_all on empty slice should succeed");
        assert!(registry.is_empty());
        assert_eq!(registry.server_count(), 0);
        assert_eq!(registry.tool_count(), 0);
    }

    #[tokio::test]
    async fn empty_registry_tool_names_is_empty() {
        let names = McpRegistry::connect_all(&[])
            .await
            .expect("connect_all should succeed")
            .tool_names();
        assert!(names.is_empty());
    }

    #[tokio::test]
    async fn empty_registry_get_tool_def_returns_none() {
        let registry = McpRegistry::connect_all(&[])
            .await
            .expect("connect_all should succeed");
        assert!(registry.get_tool_def("nonexistent__tool").await.is_none());
    }

    #[tokio::test]
    async fn empty_registry_call_tool_unknown_name_returns_error() {
        let registry = McpRegistry::connect_all(&[])
            .await
            .expect("connect_all should succeed");
        let err = registry
            .call_tool("nonexistent__tool", serde_json::json!({}))
            .await
            .expect_err("should fail for unknown tool");
        assert!(err.to_string().contains("unknown MCP tool"), "got: {err}");
    }

    #[tokio::test]
    async fn connect_all_empty_gives_zero_servers() {
        let registry = McpRegistry::connect_all(&[])
            .await
            .expect("connect_all should succeed");
        // All three size accessors must agree that nothing is connected.
        assert_eq!(
            (
                registry.server_count(),
                registry.tool_count(),
                registry.is_empty()
            ),
            (0, 0, true)
        );
    }
}