1use std::collections::HashMap;
6use std::sync::Arc;
7#[cfg(not(target_has_atomic = "64"))]
8use std::sync::atomic::AtomicU32;
9#[cfg(target_has_atomic = "64")]
10use std::sync::atomic::AtomicU64;
11use std::sync::atomic::Ordering;
12
13use anyhow::{Context, Result, anyhow, bail};
14use serde_json::json;
15use tokio::sync::Mutex;
16use tokio::time::{Duration, timeout};
17
18use crate::config::schema::McpServerConfig;
19use crate::tools::mcp_protocol::{
20 JsonRpcRequest, MCP_PROTOCOL_VERSION, McpToolDef, McpToolsListResult,
21};
22use crate::tools::mcp_transport::{McpTransportConn, create_transport};
23
/// Seconds to wait for each response during the connection handshake
/// (`initialize` and `tools/list`).
const RECV_TIMEOUT_SECS: u64 = 60;

/// Tool-call timeout, in seconds, used when a server config does not set
/// `tool_timeout_secs` (three minutes).
const DEFAULT_TOOL_TIMEOUT_SECS: u64 = 3 * 60;

/// Upper bound, in seconds, applied to any configured tool-call timeout
/// (ten minutes).
const MAX_TOOL_TIMEOUT_SECS: u64 = 10 * 60;

// The default must never exceed the cap it is clamped against.
const _: () = assert!(DEFAULT_TOOL_TIMEOUT_SECS <= MAX_TOOL_TIMEOUT_SECS);
35struct McpServerInner {
38 config: McpServerConfig,
39 transport: Box<dyn McpTransportConn>,
40 #[cfg(target_has_atomic = "64")]
41 next_id: AtomicU64,
42 #[cfg(not(target_has_atomic = "64"))]
43 next_id: AtomicU32,
44 tools: Vec<McpToolDef>,
45}
46
47#[derive(Clone)]
51pub struct McpServer {
52 inner: Arc<Mutex<McpServerInner>>,
53}
54
55impl McpServer {
56 pub async fn connect(config: McpServerConfig) -> Result<Self> {
58 let mut transport = create_transport(&config).with_context(|| {
60 format!(
61 "failed to create transport for MCP server `{}`",
62 config.name
63 )
64 })?;
65
66 let id = 1u64;
68 let init_req = JsonRpcRequest::new(
69 id,
70 "initialize",
71 json!({
72 "protocolVersion": MCP_PROTOCOL_VERSION,
73 "capabilities": {},
74 "clientInfo": {
75 "name": "construct",
76 "version": env!("CARGO_PKG_VERSION")
77 }
78 }),
79 );
80
81 let init_resp = timeout(
82 Duration::from_secs(RECV_TIMEOUT_SECS),
83 transport.send_and_recv(&init_req),
84 )
85 .await
86 .with_context(|| {
87 format!(
88 "MCP server `{}` timed out after {}s waiting for initialize response",
89 config.name, RECV_TIMEOUT_SECS
90 )
91 })??;
92
93 if init_resp.error.is_some() {
94 bail!(
95 "MCP server `{}` rejected initialize: {:?}",
96 config.name,
97 init_resp.error
98 );
99 }
100
101 let notif = JsonRpcRequest::notification("notifications/initialized", json!({}));
104 let _ = transport.send_and_recv(¬if).await;
106
107 let id = 2u64;
109 let list_req = JsonRpcRequest::new(id, "tools/list", json!({}));
110
111 let list_resp = timeout(
112 Duration::from_secs(RECV_TIMEOUT_SECS),
113 transport.send_and_recv(&list_req),
114 )
115 .await
116 .with_context(|| {
117 format!(
118 "MCP server `{}` timed out after {}s waiting for tools/list response",
119 config.name, RECV_TIMEOUT_SECS
120 )
121 })??;
122
123 let result = list_resp
124 .result
125 .ok_or_else(|| anyhow!("tools/list returned no result from `{}`", config.name))?;
126 let tool_list: McpToolsListResult = serde_json::from_value(result)
127 .with_context(|| format!("failed to parse tools/list from `{}`", config.name))?;
128
129 let tool_count = tool_list.tools.len();
130
131 let inner = McpServerInner {
132 config,
133 transport,
134 #[cfg(target_has_atomic = "64")]
135 next_id: AtomicU64::new(3), #[cfg(not(target_has_atomic = "64"))]
137 next_id: AtomicU32::new(3), tools: tool_list.tools,
139 };
140
141 tracing::info!(
142 "MCP server `{}` connected — {} tool(s) available",
143 inner.config.name,
144 tool_count
145 );
146
147 Ok(Self {
148 inner: Arc::new(Mutex::new(inner)),
149 })
150 }
151
152 pub async fn tools(&self) -> Vec<McpToolDef> {
154 self.inner.lock().await.tools.clone()
155 }
156
157 pub async fn name(&self) -> String {
159 self.inner.lock().await.config.name.clone()
160 }
161
162 pub async fn call_tool(
164 &self,
165 tool_name: &str,
166 arguments: serde_json::Value,
167 ) -> Result<serde_json::Value> {
168 let mut inner = self.inner.lock().await;
169 let id = inner.next_id.fetch_add(1, Ordering::Relaxed) as u64;
170 let req = JsonRpcRequest::new(
171 id,
172 "tools/call",
173 json!({ "name": tool_name, "arguments": arguments }),
174 );
175
176 let tool_timeout = inner
179 .config
180 .tool_timeout_secs
181 .unwrap_or(DEFAULT_TOOL_TIMEOUT_SECS)
182 .min(MAX_TOOL_TIMEOUT_SECS);
183
184 let resp = timeout(
185 Duration::from_secs(tool_timeout),
186 inner.transport.send_and_recv(&req),
187 )
188 .await
189 .map_err(|_| {
190 anyhow!(
191 "MCP server `{}` timed out after {}s during tool call `{tool_name}`",
192 inner.config.name,
193 tool_timeout
194 )
195 })?
196 .with_context(|| {
197 format!(
198 "MCP server `{}` error during tool call `{tool_name}`",
199 inner.config.name
200 )
201 })?;
202
203 if let Some(err) = resp.error {
204 bail!("MCP tool `{tool_name}` error {}: {}", err.code, err.message);
205 }
206 Ok(resp.result.unwrap_or(serde_json::Value::Null))
207 }
208}
209
210pub struct McpRegistry {
214 servers: Vec<McpServer>,
215 tool_index: HashMap<String, (usize, String)>,
217}
218
219impl McpRegistry {
220 pub async fn connect_all(configs: &[McpServerConfig]) -> Result<Self> {
222 let mut servers = Vec::new();
223 let mut tool_index = HashMap::new();
224
225 for config in configs {
226 match McpServer::connect(config.clone()).await {
227 Ok(server) => {
228 let server_idx = servers.len();
229 let tools = server.tools().await;
231 for tool in &tools {
232 let prefixed = format!("{}__{}", config.name, tool.name);
234 tool_index.insert(prefixed, (server_idx, tool.name.clone()));
235 }
236 servers.push(server);
237 }
238 Err(e) => {
240 tracing::error!("Failed to connect to MCP server `{}`: {:#}", config.name, e);
241 }
242 }
243 }
244
245 Ok(Self {
246 servers,
247 tool_index,
248 })
249 }
250
251 pub fn tool_names(&self) -> Vec<String> {
253 self.tool_index.keys().cloned().collect()
254 }
255
256 pub async fn get_tool_def(&self, prefixed_name: &str) -> Option<McpToolDef> {
258 let (server_idx, original_name) = self.tool_index.get(prefixed_name)?;
259 let inner = self.servers[*server_idx].inner.lock().await;
260 inner
261 .tools
262 .iter()
263 .find(|t| &t.name == original_name)
264 .cloned()
265 }
266
267 pub async fn call_tool(
269 &self,
270 prefixed_name: &str,
271 arguments: serde_json::Value,
272 ) -> Result<String> {
273 let (server_idx, original_name) = self
274 .tool_index
275 .get(prefixed_name)
276 .ok_or_else(|| anyhow!("unknown MCP tool `{prefixed_name}`"))?;
277 let result = self.servers[*server_idx]
278 .call_tool(original_name, arguments)
279 .await?;
280 serde_json::to_string_pretty(&result)
281 .with_context(|| format!("failed to serialize result of MCP tool `{prefixed_name}`"))
282 }
283
284 pub fn is_empty(&self) -> bool {
285 self.servers.is_empty()
286 }
287
288 pub fn server_count(&self) -> usize {
289 self.servers.len()
290 }
291
292 pub fn tool_count(&self) -> usize {
293 self.tool_index.len()
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::schema::McpTransport;

    /// Stdio config named `name` whose `command` points at `command`, with
    /// every other field spelled out as empty/unset.
    fn stdio_config(name: &str, command: &str) -> McpServerConfig {
        McpServerConfig {
            name: name.to_string(),
            command: command.to_string(),
            args: vec![],
            env: std::collections::HashMap::default(),
            tool_timeout_secs: None,
            transport: McpTransport::Stdio,
            url: None,
            headers: std::collections::HashMap::default(),
        }
    }

    /// Config using `transport` with all other fields defaulted (so no URL).
    fn url_less_config(transport: McpTransport) -> McpServerConfig {
        McpServerConfig {
            name: "test".into(),
            transport,
            ..Default::default()
        }
    }

    /// Registry built from an empty config list; panics with `expect_msg`
    /// if that unexpectedly fails.
    async fn empty_registry(expect_msg: &str) -> McpRegistry {
        McpRegistry::connect_all(&[]).await.expect(expect_msg)
    }

    #[test]
    fn tool_name_prefix_format() {
        let (server, tool) = ("filesystem", "read_file");
        let prefixed = format!("{server}__{tool}");
        assert_eq!(prefixed, "filesystem__read_file");
    }

    #[tokio::test]
    async fn connect_nonexistent_command_fails_cleanly() {
        let config = stdio_config(
            "nonexistent",
            "/usr/bin/this_binary_does_not_exist_construct_test",
        );
        let outcome = McpServer::connect(config).await;
        assert!(outcome.is_err());
        let msg = outcome.err().map(|e| e.to_string()).unwrap();
        assert!(msg.contains("failed to create transport"), "got: {msg}");
    }

    #[tokio::test]
    async fn connect_all_nonfatal_on_single_failure() {
        let configs = [stdio_config("bad", "/usr/bin/does_not_exist_zc_test")];
        let registry = McpRegistry::connect_all(&configs)
            .await
            .expect("connect_all should not fail");
        assert!(registry.is_empty());
        assert_eq!(registry.tool_count(), 0);
    }

    #[test]
    fn http_transport_requires_url() {
        assert!(create_transport(&url_less_config(McpTransport::Http)).is_err());
    }

    #[test]
    fn sse_transport_requires_url() {
        assert!(create_transport(&url_less_config(McpTransport::Sse)).is_err());
    }

    #[tokio::test]
    async fn empty_registry_is_empty() {
        let registry = empty_registry("connect_all on empty slice should succeed").await;
        assert!(registry.is_empty());
        assert_eq!(registry.server_count(), 0);
        assert_eq!(registry.tool_count(), 0);
    }

    #[tokio::test]
    async fn empty_registry_tool_names_is_empty() {
        let registry = empty_registry("connect_all should succeed").await;
        assert!(registry.tool_names().is_empty());
    }

    #[tokio::test]
    async fn empty_registry_get_tool_def_returns_none() {
        let registry = empty_registry("connect_all should succeed").await;
        let def = registry.get_tool_def("nonexistent__tool").await;
        assert!(def.is_none());
    }

    #[tokio::test]
    async fn empty_registry_call_tool_unknown_name_returns_error() {
        let registry = empty_registry("connect_all should succeed").await;
        let err = registry
            .call_tool("nonexistent__tool", serde_json::json!({}))
            .await
            .expect_err("should fail for unknown tool");
        let text = err.to_string();
        assert!(text.contains("unknown MCP tool"), "got: {err}");
    }

    #[tokio::test]
    async fn connect_all_empty_gives_zero_servers() {
        let registry = empty_registry("connect_all should succeed").await;
        assert_eq!(registry.server_count(), 0);
        assert_eq!(registry.tool_count(), 0);
        assert!(registry.is_empty());
    }
}