1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use reqwest::header::{HeaderMap, HeaderName, HeaderValue, ACCEPT, CONTENT_TYPE};
7use serde::{Deserialize, Serialize};
8use serde_json::{json, Value};
9use sha2::{Digest, Sha256};
10use tandem_types::ToolResult;
11use tokio::process::{Child, Command};
12use tokio::sync::{Mutex, RwLock};
13
/// Protocol version string sent as `protocolVersion` in the MCP `initialize` request.
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";
/// Client name reported in `clientInfo.name` of the `initialize` request.
const MCP_CLIENT_NAME: &str = "tandem";
/// Client version reported in `clientInfo.version`; taken from the crate version at compile time.
const MCP_CLIENT_VERSION: &str = env!("CARGO_PKG_VERSION");
17
/// One tool description cached for a server; stored in `McpServer::tool_cache`
/// and persisted with the rest of the server entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolCacheEntry {
    /// Tool name as reported by the server's `tools/list` response.
    pub tool_name: String,
    /// Tool description as reported by the server (empty when none was given).
    pub description: String,
    /// Input schema of the tool as raw JSON; falls back to `Value::default()`
    /// when the field is absent during deserialization.
    #[serde(default)]
    pub input_schema: Value,
    /// Timestamp (milliseconds since the Unix epoch) of the refresh that produced this entry.
    pub fetched_at_ms: u64,
    /// Lowercase hex SHA-256 of the serialized `input_schema` (see `schema_hash`).
    pub schema_hash: String,
}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct McpServer {
30 pub name: String,
31 pub transport: String,
32 #[serde(default = "default_enabled")]
33 pub enabled: bool,
34 pub connected: bool,
35 #[serde(skip_serializing_if = "Option::is_none")]
36 pub pid: Option<u32>,
37 #[serde(skip_serializing_if = "Option::is_none")]
38 pub last_error: Option<String>,
39 #[serde(default)]
40 pub headers: HashMap<String, String>,
41 #[serde(default)]
42 pub tool_cache: Vec<McpToolCacheEntry>,
43 #[serde(default, skip_serializing_if = "Option::is_none")]
44 pub tools_fetched_at_ms: Option<u64>,
45}
46
/// A tool exposed by an MCP server, as handed to callers of the registry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpRemoteTool {
    /// Name of the server that owns the tool (left empty by raw discovery results).
    pub server_name: String,
    /// Tool name as reported by the server.
    pub tool_name: String,
    /// Registry-wide name of the form `mcp.<server-slug>.<tool-slug>`
    /// (left empty by raw discovery results).
    pub namespaced_name: String,
    /// Tool description as reported by the server (empty when none was given).
    pub description: String,
    /// Input schema of the tool as raw JSON; falls back to `Value::default()`
    /// when the field is absent during deserialization.
    #[serde(default)]
    pub input_schema: Value,
    /// Timestamp (milliseconds since the Unix epoch) at which the tool was fetched.
    pub fetched_at_ms: u64,
    /// Lowercase hex SHA-256 of the serialized `input_schema`
    /// (left empty by raw discovery results).
    pub schema_hash: String,
}
58
/// Registry of MCP servers. Cloning is cheap: all clones share the same
/// underlying state through `Arc`.
#[derive(Clone)]
pub struct McpRegistry {
    /// Registered servers keyed by server name.
    servers: Arc<RwLock<HashMap<String, McpServer>>>,
    /// Child processes spawned for `stdio:` transports, keyed by server name.
    processes: Arc<Mutex<HashMap<String, Child>>>,
    /// Path of the JSON file the server map is persisted to.
    state_file: Arc<PathBuf>,
}
65
66impl McpRegistry {
67 pub fn new() -> Self {
68 Self::new_with_state_file(resolve_state_file())
69 }
70
71 pub fn new_with_state_file(state_file: PathBuf) -> Self {
72 let loaded = load_state(&state_file)
73 .into_iter()
74 .map(|(k, mut v)| {
75 v.connected = false;
76 v.pid = None;
77 if v.name.trim().is_empty() {
78 v.name = k.clone();
79 }
80 if v.headers.is_empty() {
81 v.headers = HashMap::new();
82 }
83 (k, v)
84 })
85 .collect::<HashMap<_, _>>();
86 Self {
87 servers: Arc::new(RwLock::new(loaded)),
88 processes: Arc::new(Mutex::new(HashMap::new())),
89 state_file: Arc::new(state_file),
90 }
91 }
92
93 pub async fn list(&self) -> HashMap<String, McpServer> {
94 self.servers.read().await.clone()
95 }
96
97 pub async fn add(&self, name: String, transport: String) {
98 self.add_or_update(name, transport, HashMap::new(), true).await;
99 }
100
101 pub async fn add_or_update(
102 &self,
103 name: String,
104 transport: String,
105 headers: HashMap<String, String>,
106 enabled: bool,
107 ) {
108 let mut servers = self.servers.write().await;
109 let existing = servers.get(&name).cloned();
110 let existing_tool_cache = existing
111 .as_ref()
112 .map(|row| row.tool_cache.clone())
113 .unwrap_or_default();
114 let existing_fetched_at = existing.as_ref().and_then(|row| row.tools_fetched_at_ms);
115 let server = McpServer {
116 name: name.clone(),
117 transport,
118 enabled,
119 connected: false,
120 pid: None,
121 last_error: None,
122 headers,
123 tool_cache: existing_tool_cache,
124 tools_fetched_at_ms: existing_fetched_at,
125 };
126 servers.insert(name, server);
127 drop(servers);
128 self.persist_state().await;
129 }
130
131 pub async fn set_enabled(&self, name: &str, enabled: bool) -> bool {
132 let mut servers = self.servers.write().await;
133 let Some(server) = servers.get_mut(name) else {
134 return false;
135 };
136 server.enabled = enabled;
137 if !enabled {
138 server.connected = false;
139 server.pid = None;
140 }
141 drop(servers);
142 if !enabled {
143 if let Some(mut child) = self.processes.lock().await.remove(name) {
144 let _ = child.kill().await;
145 let _ = child.wait().await;
146 }
147 }
148 self.persist_state().await;
149 true
150 }
151
152 pub async fn connect(&self, name: &str) -> bool {
153 let server = {
154 let servers = self.servers.read().await;
155 let Some(server) = servers.get(name) else {
156 return false;
157 };
158 server.clone()
159 };
160
161 if !server.enabled {
162 let mut servers = self.servers.write().await;
163 if let Some(entry) = servers.get_mut(name) {
164 entry.connected = false;
165 entry.pid = None;
166 entry.last_error = Some("MCP server is disabled".to_string());
167 }
168 drop(servers);
169 self.persist_state().await;
170 return false;
171 }
172
173 if let Some(command_text) = parse_stdio_transport(&server.transport) {
174 return self.connect_stdio(name, command_text).await;
175 }
176
177 if parse_remote_endpoint(&server.transport).is_some() {
178 return self.refresh(name).await.is_ok();
179 }
180
181 let mut servers = self.servers.write().await;
182 if let Some(entry) = servers.get_mut(name) {
183 entry.connected = true;
184 entry.pid = None;
185 entry.last_error = None;
186 }
187 drop(servers);
188 self.persist_state().await;
189 true
190 }
191
192 pub async fn refresh(&self, name: &str) -> Result<Vec<McpRemoteTool>, String> {
193 let server = {
194 let servers = self.servers.read().await;
195 let Some(server) = servers.get(name) else {
196 return Err("MCP server not found".to_string());
197 };
198 server.clone()
199 };
200
201 if !server.enabled {
202 return Err("MCP server is disabled".to_string());
203 }
204
205 let endpoint = parse_remote_endpoint(&server.transport)
206 .ok_or_else(|| "MCP refresh currently supports HTTP/S transports only".to_string())?;
207
208 let tools = match self.discover_remote_tools(&endpoint, &server.headers).await {
209 Ok(tools) => tools,
210 Err(err) => {
211 let mut servers = self.servers.write().await;
212 if let Some(entry) = servers.get_mut(name) {
213 entry.connected = false;
214 entry.pid = None;
215 entry.last_error = Some(err.clone());
216 }
217 drop(servers);
218 self.persist_state().await;
219 return Err(err);
220 }
221 };
222
223 let now = now_ms();
224 let cache = tools
225 .iter()
226 .map(|tool| McpToolCacheEntry {
227 tool_name: tool.tool_name.clone(),
228 description: tool.description.clone(),
229 input_schema: tool.input_schema.clone(),
230 fetched_at_ms: now,
231 schema_hash: schema_hash(&tool.input_schema),
232 })
233 .collect::<Vec<_>>();
234
235 let mut servers = self.servers.write().await;
236 if let Some(entry) = servers.get_mut(name) {
237 entry.connected = true;
238 entry.pid = None;
239 entry.last_error = None;
240 entry.tool_cache = cache;
241 entry.tools_fetched_at_ms = Some(now);
242 }
243 drop(servers);
244 self.persist_state().await;
245 Ok(self.server_tools(name).await)
246 }
247
248 pub async fn disconnect(&self, name: &str) -> bool {
249 if let Some(mut child) = self.processes.lock().await.remove(name) {
250 let _ = child.kill().await;
251 let _ = child.wait().await;
252 }
253 let mut servers = self.servers.write().await;
254 if let Some(server) = servers.get_mut(name) {
255 server.connected = false;
256 server.pid = None;
257 drop(servers);
258 self.persist_state().await;
259 return true;
260 }
261 false
262 }
263
264 pub async fn list_tools(&self) -> Vec<McpRemoteTool> {
265 let mut out = self
266 .servers
267 .read()
268 .await
269 .values()
270 .filter(|server| server.enabled && server.connected)
271 .flat_map(server_tool_rows)
272 .collect::<Vec<_>>();
273 out.sort_by(|a, b| a.namespaced_name.cmp(&b.namespaced_name));
274 out
275 }
276
277 pub async fn server_tools(&self, name: &str) -> Vec<McpRemoteTool> {
278 let Some(server) = self.servers.read().await.get(name).cloned() else {
279 return Vec::new();
280 };
281 let mut rows = server_tool_rows(&server);
282 rows.sort_by(|a, b| a.namespaced_name.cmp(&b.namespaced_name));
283 rows
284 }
285
286 pub async fn call_tool(
287 &self,
288 server_name: &str,
289 tool_name: &str,
290 args: Value,
291 ) -> Result<ToolResult, String> {
292 let server = {
293 let servers = self.servers.read().await;
294 let Some(server) = servers.get(server_name) else {
295 return Err(format!("MCP server '{server_name}' not found"));
296 };
297 server.clone()
298 };
299
300 if !server.enabled {
301 return Err(format!("MCP server '{server_name}' is disabled"));
302 }
303 if !server.connected {
304 return Err(format!("MCP server '{server_name}' is not connected"));
305 }
306
307 let endpoint = parse_remote_endpoint(&server.transport)
308 .ok_or_else(|| "MCP tools/call currently supports HTTP/S transports only".to_string())?;
309
310 let request = json!({
311 "jsonrpc": "2.0",
312 "id": format!("call-{}-{}", server_name, now_ms()),
313 "method": "tools/call",
314 "params": {
315 "name": tool_name,
316 "arguments": args
317 }
318 });
319 let response = post_json_rpc(&endpoint, &server.headers, request).await?;
320
321 if let Some(err) = response.get("error") {
322 let message = err
323 .get("message")
324 .and_then(|v| v.as_str())
325 .unwrap_or("MCP tools/call failed");
326 return Err(message.to_string());
327 }
328
329 let result = response.get("result").cloned().unwrap_or(Value::Null);
330 let output = result
331 .get("content")
332 .map(render_mcp_content)
333 .or_else(|| result.get("output").map(|v| v.to_string()))
334 .unwrap_or_else(|| result.to_string());
335
336 Ok(ToolResult {
337 output,
338 metadata: json!({
339 "server": server_name,
340 "tool": tool_name,
341 "result": result
342 }),
343 })
344 }
345
346 async fn connect_stdio(&self, name: &str, command_text: &str) -> bool {
347 match spawn_stdio_process(command_text).await {
348 Ok(child) => {
349 let pid = child.id();
350 self.processes.lock().await.insert(name.to_string(), child);
351 let mut servers = self.servers.write().await;
352 if let Some(server) = servers.get_mut(name) {
353 server.connected = true;
354 server.pid = pid;
355 server.last_error = None;
356 }
357 drop(servers);
358 self.persist_state().await;
359 true
360 }
361 Err(err) => {
362 let mut servers = self.servers.write().await;
363 if let Some(server) = servers.get_mut(name) {
364 server.connected = false;
365 server.pid = None;
366 server.last_error = Some(err);
367 }
368 drop(servers);
369 self.persist_state().await;
370 false
371 }
372 }
373 }
374
375 async fn discover_remote_tools(
376 &self,
377 endpoint: &str,
378 headers: &HashMap<String, String>,
379 ) -> Result<Vec<McpRemoteTool>, String> {
380 let initialize = json!({
381 "jsonrpc": "2.0",
382 "id": "initialize-1",
383 "method": "initialize",
384 "params": {
385 "protocolVersion": MCP_PROTOCOL_VERSION,
386 "capabilities": {},
387 "clientInfo": {
388 "name": MCP_CLIENT_NAME,
389 "version": MCP_CLIENT_VERSION,
390 }
391 }
392 });
393 let init_response = post_json_rpc(endpoint, headers, initialize).await?;
394 if let Some(err) = init_response.get("error") {
395 let message = err
396 .get("message")
397 .and_then(|v| v.as_str())
398 .unwrap_or("MCP initialize failed");
399 return Err(message.to_string());
400 }
401
402 let tools_list = json!({
403 "jsonrpc": "2.0",
404 "id": "tools-list-1",
405 "method": "tools/list",
406 "params": {}
407 });
408 let tools_response = post_json_rpc(endpoint, headers, tools_list).await?;
409 if let Some(err) = tools_response.get("error") {
410 let message = err
411 .get("message")
412 .and_then(|v| v.as_str())
413 .unwrap_or("MCP tools/list failed");
414 return Err(message.to_string());
415 }
416
417 let tools = tools_response
418 .get("result")
419 .and_then(|v| v.get("tools"))
420 .and_then(|v| v.as_array())
421 .ok_or_else(|| "MCP tools/list result missing tools array".to_string())?;
422
423 let now = now_ms();
424 let mut out = Vec::new();
425 for row in tools {
426 let Some(tool_name) = row.get("name").and_then(|v| v.as_str()) else {
427 continue;
428 };
429 let description = row
430 .get("description")
431 .and_then(|v| v.as_str())
432 .unwrap_or("")
433 .to_string();
434 let input_schema = row
435 .get("inputSchema")
436 .or_else(|| row.get("input_schema"))
437 .cloned()
438 .unwrap_or_else(|| json!({"type":"object"}));
439 out.push(McpRemoteTool {
440 server_name: String::new(),
441 tool_name: tool_name.to_string(),
442 namespaced_name: String::new(),
443 description,
444 input_schema,
445 fetched_at_ms: now,
446 schema_hash: String::new(),
447 });
448 }
449
450 Ok(out)
451 }
452
453 async fn persist_state(&self) {
454 let snapshot = self.servers.read().await.clone();
455 if let Some(parent) = self.state_file.parent() {
456 let _ = tokio::fs::create_dir_all(parent).await;
457 }
458 if let Ok(payload) = serde_json::to_string_pretty(&snapshot) {
459 let _ = tokio::fs::write(self.state_file.as_path(), payload).await;
460 }
461 }
462}
463
464impl Default for McpRegistry {
465 fn default() -> Self {
466 Self::new()
467 }
468}
469
/// Serde default for `McpServer::enabled`: servers are enabled unless stated otherwise.
fn default_enabled() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
473
/// Returns the path of the registry state file.
///
/// The `TANDEM_MCP_REGISTRY` environment variable overrides the location.
/// A value that is unset, not valid Unicode, empty or whitespace-only is
/// ignored, because an empty path can never be written to. In that case the
/// relative path `.tandem/mcp_servers.json` is used.
fn resolve_state_file() -> PathBuf {
    match std::env::var("TANDEM_MCP_REGISTRY") {
        Ok(path) if !path.trim().is_empty() => PathBuf::from(path),
        _ => PathBuf::from(".tandem").join("mcp_servers.json"),
    }
}
480
481fn load_state(path: &Path) -> HashMap<String, McpServer> {
482 let Ok(raw) = std::fs::read_to_string(path) else {
483 return HashMap::new();
484 };
485 serde_json::from_str::<HashMap<String, McpServer>>(&raw).unwrap_or_default()
486}
487
/// Extracts the command line from a `stdio:<command>` transport descriptor.
///
/// Surrounding whitespace of both the descriptor and the command is ignored,
/// matching how `parse_remote_endpoint` treats its input. Returns `None` when
/// the descriptor does not start with `stdio:`. A bare `stdio:` yields an
/// empty command.
fn parse_stdio_transport(transport: &str) -> Option<&str> {
    transport.trim().strip_prefix("stdio:").map(str::trim)
}
491
/// Extracts the HTTP/S endpoint URL from a transport descriptor.
///
/// Accepts either a plain `http://` / `https://` URL or such a URL prefixed
/// with `http:` or `https:` (e.g. `http:https://host/mcp`). Surrounding
/// whitespace is ignored. Anything else yields `None`.
fn parse_remote_endpoint(transport: &str) -> Option<String> {
    fn is_http_url(candidate: &str) -> bool {
        candidate.starts_with("http://") || candidate.starts_with("https://")
    }

    let descriptor = transport.trim();
    if is_http_url(descriptor) {
        return Some(descriptor.to_owned());
    }

    ["http:", "https:"]
        .iter()
        .filter_map(|scheme_tag| descriptor.strip_prefix(*scheme_tag))
        .map(str::trim)
        .find(|remainder| is_http_url(remainder))
        .map(str::to_owned)
}
507
508fn server_tool_rows(server: &McpServer) -> Vec<McpRemoteTool> {
509 let server_slug = sanitize_namespace_segment(&server.name);
510 server
511 .tool_cache
512 .iter()
513 .map(|tool| {
514 let tool_slug = sanitize_namespace_segment(&tool.tool_name);
515 McpRemoteTool {
516 server_name: server.name.clone(),
517 tool_name: tool.tool_name.clone(),
518 namespaced_name: format!("mcp.{server_slug}.{tool_slug}"),
519 description: tool.description.clone(),
520 input_schema: tool.input_schema.clone(),
521 fetched_at_ms: tool.fetched_at_ms,
522 schema_hash: tool.schema_hash.clone(),
523 }
524 })
525 .collect()
526}
527
/// Turns an arbitrary name into a namespace-safe slug.
///
/// ASCII letters and digits are kept (lowercased); every run of other
/// characters becomes a single `_`, and leading/trailing underscores are
/// dropped. A name without any ASCII alphanumerics becomes `"tool"`.
fn sanitize_namespace_segment(raw: &str) -> String {
    // Splitting on the separators and re-joining the non-empty pieces collapses
    // separator runs and discards leading/trailing ones in a single pass.
    let slug = raw
        .split(|ch: char| !ch.is_ascii_alphanumeric())
        .filter(|piece| !piece.is_empty())
        .map(str::to_ascii_lowercase)
        .collect::<Vec<_>>()
        .join("_");

    if slug.is_empty() {
        "tool".to_string()
    } else {
        slug
    }
}
547
548fn schema_hash(schema: &Value) -> String {
549 let payload = serde_json::to_vec(schema).unwrap_or_default();
550 let mut hasher = Sha256::new();
551 hasher.update(payload);
552 format!("{:x}", hasher.finalize())
553}
554
/// Current wall-clock time in milliseconds since the Unix epoch,
/// or `0` when the system clock reads earlier than the epoch.
fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
561
562fn build_headers(headers: &HashMap<String, String>) -> Result<HeaderMap, String> {
563 let mut map = HeaderMap::new();
564 map.insert(
565 ACCEPT,
566 HeaderValue::from_static("application/json, text/event-stream"),
567 );
568 map.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
569 for (key, value) in headers {
570 let name = HeaderName::from_bytes(key.trim().as_bytes())
571 .map_err(|e| format!("Invalid header name '{key}': {e}"))?;
572 let header = HeaderValue::from_str(value.trim())
573 .map_err(|e| format!("Invalid header value for '{key}': {e}"))?;
574 map.insert(name, header);
575 }
576 Ok(map)
577}
578
579async fn post_json_rpc(
580 endpoint: &str,
581 headers: &HashMap<String, String>,
582 request: Value,
583) -> Result<Value, String> {
584 let client = reqwest::Client::builder()
585 .timeout(std::time::Duration::from_secs(12))
586 .build()
587 .map_err(|e| format!("Failed to build HTTP client: {e}"))?;
588 let response = client
589 .post(endpoint)
590 .headers(build_headers(headers)?)
591 .json(&request)
592 .send()
593 .await
594 .map_err(|e| format!("MCP request failed: {e}"))?;
595 let status = response.status();
596 let payload = response
597 .text()
598 .await
599 .map_err(|e| format!("Failed to read MCP response: {e}"))?;
600 if !status.is_success() {
601 return Err(format!(
602 "MCP endpoint returned HTTP {}: {}",
603 status.as_u16(),
604 payload.chars().take(400).collect::<String>()
605 ));
606 }
607 serde_json::from_str::<Value>(&payload)
608 .map_err(|e| format!("Invalid MCP JSON response: {e}"))
609}
610
611fn render_mcp_content(value: &Value) -> String {
612 let Some(items) = value.as_array() else {
613 return value.to_string();
614 };
615 let mut chunks = Vec::new();
616 for item in items {
617 if let Some(text) = item.get("text").and_then(|v| v.as_str()) {
618 chunks.push(text.to_string());
619 continue;
620 }
621 chunks.push(item.to_string());
622 }
623 if chunks.is_empty() {
624 value.to_string()
625 } else {
626 chunks.join("\n")
627 }
628}
629
630async fn spawn_stdio_process(command_text: &str) -> Result<Child, String> {
631 if command_text.is_empty() {
632 return Err("Missing stdio command".to_string());
633 }
634 #[cfg(windows)]
635 let mut command = {
636 let mut cmd = Command::new("powershell");
637 cmd.args(["-NoProfile", "-Command", command_text]);
638 cmd
639 };
640 #[cfg(not(windows))]
641 let mut command = {
642 let mut cmd = Command::new("sh");
643 cmd.args(["-lc", command_text]);
644 cmd
645 };
646 command
647 .stdin(std::process::Stdio::null())
648 .stdout(std::process::Stdio::null())
649 .stderr(std::process::Stdio::null());
650 command.spawn().map_err(|e| e.to_string())
651}
652
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// A transport that is neither stdio nor HTTP/S is simply marked connected,
    /// so the whole add → connect → disconnect cycle runs without any I/O
    /// besides the state file.
    #[tokio::test]
    async fn add_connect_disconnect_non_stdio_server() {
        let file = std::env::temp_dir().join(format!("mcp-test-{}.json", Uuid::new_v4()));
        let registry = McpRegistry::new_with_state_file(file.clone());
        registry
            .add("example".to_string(), "sse:https://example.com".to_string())
            .await;
        assert!(registry.connect("example").await);
        let listed = registry.list().await;
        assert!(listed.get("example").map(|s| s.connected).unwrap_or(false));
        assert!(registry.disconnect("example").await);
        let listed = registry.list().await;
        assert!(!listed.get("example").map(|s| s.connected).unwrap_or(true));
        // The registry persisted its state on every step; do not leave the
        // file behind in the temp directory.
        let _ = std::fs::remove_file(&file);
    }

    /// Both the bare URL and the `http:`-prefixed form resolve to the URL;
    /// other descriptors are not remote endpoints.
    #[test]
    fn parse_remote_endpoint_supports_http_prefixes() {
        assert_eq!(
            parse_remote_endpoint("https://mcp.example.com/mcp"),
            Some("https://mcp.example.com/mcp".to_string())
        );
        assert_eq!(
            parse_remote_endpoint("http:https://mcp.example.com/mcp"),
            Some("https://mcp.example.com/mcp".to_string())
        );
        assert_eq!(parse_remote_endpoint("sse:https://example.com"), None);
        assert_eq!(parse_remote_endpoint("stdio:echo hi"), None);
    }
}