1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use reqwest::header::{HeaderMap, HeaderName, HeaderValue, ACCEPT, CONTENT_TYPE};
7use serde::{Deserialize, Serialize};
8use serde_json::{json, Value};
9use sha2::{Digest, Sha256};
10use tandem_types::ToolResult;
11use tokio::process::{Child, Command};
12use tokio::sync::{Mutex, RwLock};
13
/// MCP protocol revision sent as `protocolVersion` in the `initialize` request.
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";
/// Client name reported in `clientInfo` during `initialize`.
const MCP_CLIENT_NAME: &str = "tandem";
/// Client version reported in `clientInfo`; taken from this crate's Cargo package version.
const MCP_CLIENT_VERSION: &str = env!("CARGO_PKG_VERSION");
17
/// One tool reported by an MCP server's `tools/list`, as cached in the registry state file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolCacheEntry {
    /// Tool name as reported by the server (not namespaced).
    pub tool_name: String,
    /// Tool description; empty when the server did not provide one.
    pub description: String,
    /// The tool's input schema (`inputSchema`); missing in the state file deserializes as the
    /// serde default for `Value`.
    #[serde(default)]
    pub input_schema: Value,
    /// Unix time in milliseconds of the refresh that produced this entry.
    pub fetched_at_ms: u64,
    /// Hex-encoded SHA-256 of the serialized `input_schema`.
    pub schema_hash: String,
}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct McpServer {
30 pub name: String,
31 pub transport: String,
32 #[serde(default = "default_enabled")]
33 pub enabled: bool,
34 pub connected: bool,
35 #[serde(skip_serializing_if = "Option::is_none")]
36 pub pid: Option<u32>,
37 #[serde(skip_serializing_if = "Option::is_none")]
38 pub last_error: Option<String>,
39 #[serde(default)]
40 pub headers: HashMap<String, String>,
41 #[serde(default)]
42 pub tool_cache: Vec<McpToolCacheEntry>,
43 #[serde(default, skip_serializing_if = "Option::is_none")]
44 pub tools_fetched_at_ms: Option<u64>,
45}
46
/// A tool exposed by a registered MCP server, addressed by a registry-wide name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpRemoteTool {
    /// Name of the owning server; left empty on rows produced directly by discovery.
    pub server_name: String,
    /// Tool name as reported by the server.
    pub tool_name: String,
    /// `mcp.<server>.<tool>` with both segments sanitized; left empty on rows produced
    /// directly by discovery.
    pub namespaced_name: String,
    /// Tool description; empty when the server did not provide one.
    pub description: String,
    /// The tool's input schema (`inputSchema`).
    #[serde(default)]
    pub input_schema: Value,
    /// Unix time in milliseconds when the tool was fetched.
    pub fetched_at_ms: u64,
    /// Hex SHA-256 of the serialized input schema; left empty on rows produced directly by
    /// discovery.
    pub schema_hash: String,
}
58
/// Registry of configured MCP servers.
///
/// Cloning is cheap and every clone shares the same state, since all fields sit behind `Arc`.
#[derive(Clone)]
pub struct McpRegistry {
    // Server configuration and status, persisted to `state_file` after each change.
    servers: Arc<RwLock<HashMap<String, McpServer>>>,
    // Child processes spawned for `stdio:` transports, keyed by server name.
    processes: Arc<Mutex<HashMap<String, Child>>>,
    // JSON file the `servers` map is loaded from and written back to.
    state_file: Arc<PathBuf>,
}
65
66impl McpRegistry {
67 pub fn new() -> Self {
68 Self::new_with_state_file(resolve_state_file())
69 }
70
71 pub fn new_with_state_file(state_file: PathBuf) -> Self {
72 let loaded = load_state(&state_file)
73 .into_iter()
74 .map(|(k, mut v)| {
75 v.connected = false;
76 v.pid = None;
77 if v.name.trim().is_empty() {
78 v.name = k.clone();
79 }
80 if v.headers.is_empty() {
81 v.headers = HashMap::new();
82 }
83 (k, v)
84 })
85 .collect::<HashMap<_, _>>();
86 Self {
87 servers: Arc::new(RwLock::new(loaded)),
88 processes: Arc::new(Mutex::new(HashMap::new())),
89 state_file: Arc::new(state_file),
90 }
91 }
92
93 pub async fn list(&self) -> HashMap<String, McpServer> {
94 self.servers.read().await.clone()
95 }
96
97 pub async fn add(&self, name: String, transport: String) {
98 self.add_or_update(name, transport, HashMap::new(), true)
99 .await;
100 }
101
102 pub async fn add_or_update(
103 &self,
104 name: String,
105 transport: String,
106 headers: HashMap<String, String>,
107 enabled: bool,
108 ) {
109 let mut servers = self.servers.write().await;
110 let existing = servers.get(&name).cloned();
111 let existing_tool_cache = existing
112 .as_ref()
113 .map(|row| row.tool_cache.clone())
114 .unwrap_or_default();
115 let existing_fetched_at = existing.as_ref().and_then(|row| row.tools_fetched_at_ms);
116 let server = McpServer {
117 name: name.clone(),
118 transport,
119 enabled,
120 connected: false,
121 pid: None,
122 last_error: None,
123 headers,
124 tool_cache: existing_tool_cache,
125 tools_fetched_at_ms: existing_fetched_at,
126 };
127 servers.insert(name, server);
128 drop(servers);
129 self.persist_state().await;
130 }
131
132 pub async fn set_enabled(&self, name: &str, enabled: bool) -> bool {
133 let mut servers = self.servers.write().await;
134 let Some(server) = servers.get_mut(name) else {
135 return false;
136 };
137 server.enabled = enabled;
138 if !enabled {
139 server.connected = false;
140 server.pid = None;
141 }
142 drop(servers);
143 if !enabled {
144 if let Some(mut child) = self.processes.lock().await.remove(name) {
145 let _ = child.kill().await;
146 let _ = child.wait().await;
147 }
148 }
149 self.persist_state().await;
150 true
151 }
152
153 pub async fn connect(&self, name: &str) -> bool {
154 let server = {
155 let servers = self.servers.read().await;
156 let Some(server) = servers.get(name) else {
157 return false;
158 };
159 server.clone()
160 };
161
162 if !server.enabled {
163 let mut servers = self.servers.write().await;
164 if let Some(entry) = servers.get_mut(name) {
165 entry.connected = false;
166 entry.pid = None;
167 entry.last_error = Some("MCP server is disabled".to_string());
168 }
169 drop(servers);
170 self.persist_state().await;
171 return false;
172 }
173
174 if let Some(command_text) = parse_stdio_transport(&server.transport) {
175 return self.connect_stdio(name, command_text).await;
176 }
177
178 if parse_remote_endpoint(&server.transport).is_some() {
179 return self.refresh(name).await.is_ok();
180 }
181
182 let mut servers = self.servers.write().await;
183 if let Some(entry) = servers.get_mut(name) {
184 entry.connected = true;
185 entry.pid = None;
186 entry.last_error = None;
187 }
188 drop(servers);
189 self.persist_state().await;
190 true
191 }
192
193 pub async fn refresh(&self, name: &str) -> Result<Vec<McpRemoteTool>, String> {
194 let server = {
195 let servers = self.servers.read().await;
196 let Some(server) = servers.get(name) else {
197 return Err("MCP server not found".to_string());
198 };
199 server.clone()
200 };
201
202 if !server.enabled {
203 return Err("MCP server is disabled".to_string());
204 }
205
206 let endpoint = parse_remote_endpoint(&server.transport)
207 .ok_or_else(|| "MCP refresh currently supports HTTP/S transports only".to_string())?;
208
209 let tools = match self.discover_remote_tools(&endpoint, &server.headers).await {
210 Ok(tools) => tools,
211 Err(err) => {
212 let mut servers = self.servers.write().await;
213 if let Some(entry) = servers.get_mut(name) {
214 entry.connected = false;
215 entry.pid = None;
216 entry.last_error = Some(err.clone());
217 }
218 drop(servers);
219 self.persist_state().await;
220 return Err(err);
221 }
222 };
223
224 let now = now_ms();
225 let cache = tools
226 .iter()
227 .map(|tool| McpToolCacheEntry {
228 tool_name: tool.tool_name.clone(),
229 description: tool.description.clone(),
230 input_schema: tool.input_schema.clone(),
231 fetched_at_ms: now,
232 schema_hash: schema_hash(&tool.input_schema),
233 })
234 .collect::<Vec<_>>();
235
236 let mut servers = self.servers.write().await;
237 if let Some(entry) = servers.get_mut(name) {
238 entry.connected = true;
239 entry.pid = None;
240 entry.last_error = None;
241 entry.tool_cache = cache;
242 entry.tools_fetched_at_ms = Some(now);
243 }
244 drop(servers);
245 self.persist_state().await;
246 Ok(self.server_tools(name).await)
247 }
248
249 pub async fn disconnect(&self, name: &str) -> bool {
250 if let Some(mut child) = self.processes.lock().await.remove(name) {
251 let _ = child.kill().await;
252 let _ = child.wait().await;
253 }
254 let mut servers = self.servers.write().await;
255 if let Some(server) = servers.get_mut(name) {
256 server.connected = false;
257 server.pid = None;
258 drop(servers);
259 self.persist_state().await;
260 return true;
261 }
262 false
263 }
264
265 pub async fn list_tools(&self) -> Vec<McpRemoteTool> {
266 let mut out = self
267 .servers
268 .read()
269 .await
270 .values()
271 .filter(|server| server.enabled && server.connected)
272 .flat_map(server_tool_rows)
273 .collect::<Vec<_>>();
274 out.sort_by(|a, b| a.namespaced_name.cmp(&b.namespaced_name));
275 out
276 }
277
278 pub async fn server_tools(&self, name: &str) -> Vec<McpRemoteTool> {
279 let Some(server) = self.servers.read().await.get(name).cloned() else {
280 return Vec::new();
281 };
282 let mut rows = server_tool_rows(&server);
283 rows.sort_by(|a, b| a.namespaced_name.cmp(&b.namespaced_name));
284 rows
285 }
286
287 pub async fn call_tool(
288 &self,
289 server_name: &str,
290 tool_name: &str,
291 args: Value,
292 ) -> Result<ToolResult, String> {
293 let server = {
294 let servers = self.servers.read().await;
295 let Some(server) = servers.get(server_name) else {
296 return Err(format!("MCP server '{server_name}' not found"));
297 };
298 server.clone()
299 };
300
301 if !server.enabled {
302 return Err(format!("MCP server '{server_name}' is disabled"));
303 }
304 if !server.connected {
305 return Err(format!("MCP server '{server_name}' is not connected"));
306 }
307
308 let endpoint = parse_remote_endpoint(&server.transport).ok_or_else(|| {
309 "MCP tools/call currently supports HTTP/S transports only".to_string()
310 })?;
311
312 let request = json!({
313 "jsonrpc": "2.0",
314 "id": format!("call-{}-{}", server_name, now_ms()),
315 "method": "tools/call",
316 "params": {
317 "name": tool_name,
318 "arguments": args
319 }
320 });
321 let response = post_json_rpc(&endpoint, &server.headers, request).await?;
322
323 if let Some(err) = response.get("error") {
324 let message = err
325 .get("message")
326 .and_then(|v| v.as_str())
327 .unwrap_or("MCP tools/call failed");
328 return Err(message.to_string());
329 }
330
331 let result = response.get("result").cloned().unwrap_or(Value::Null);
332 let output = result
333 .get("content")
334 .map(render_mcp_content)
335 .or_else(|| result.get("output").map(|v| v.to_string()))
336 .unwrap_or_else(|| result.to_string());
337
338 Ok(ToolResult {
339 output,
340 metadata: json!({
341 "server": server_name,
342 "tool": tool_name,
343 "result": result
344 }),
345 })
346 }
347
348 async fn connect_stdio(&self, name: &str, command_text: &str) -> bool {
349 match spawn_stdio_process(command_text).await {
350 Ok(child) => {
351 let pid = child.id();
352 self.processes.lock().await.insert(name.to_string(), child);
353 let mut servers = self.servers.write().await;
354 if let Some(server) = servers.get_mut(name) {
355 server.connected = true;
356 server.pid = pid;
357 server.last_error = None;
358 }
359 drop(servers);
360 self.persist_state().await;
361 true
362 }
363 Err(err) => {
364 let mut servers = self.servers.write().await;
365 if let Some(server) = servers.get_mut(name) {
366 server.connected = false;
367 server.pid = None;
368 server.last_error = Some(err);
369 }
370 drop(servers);
371 self.persist_state().await;
372 false
373 }
374 }
375 }
376
377 async fn discover_remote_tools(
378 &self,
379 endpoint: &str,
380 headers: &HashMap<String, String>,
381 ) -> Result<Vec<McpRemoteTool>, String> {
382 let initialize = json!({
383 "jsonrpc": "2.0",
384 "id": "initialize-1",
385 "method": "initialize",
386 "params": {
387 "protocolVersion": MCP_PROTOCOL_VERSION,
388 "capabilities": {},
389 "clientInfo": {
390 "name": MCP_CLIENT_NAME,
391 "version": MCP_CLIENT_VERSION,
392 }
393 }
394 });
395 let init_response = post_json_rpc(endpoint, headers, initialize).await?;
396 if let Some(err) = init_response.get("error") {
397 let message = err
398 .get("message")
399 .and_then(|v| v.as_str())
400 .unwrap_or("MCP initialize failed");
401 return Err(message.to_string());
402 }
403
404 let tools_list = json!({
405 "jsonrpc": "2.0",
406 "id": "tools-list-1",
407 "method": "tools/list",
408 "params": {}
409 });
410 let tools_response = post_json_rpc(endpoint, headers, tools_list).await?;
411 if let Some(err) = tools_response.get("error") {
412 let message = err
413 .get("message")
414 .and_then(|v| v.as_str())
415 .unwrap_or("MCP tools/list failed");
416 return Err(message.to_string());
417 }
418
419 let tools = tools_response
420 .get("result")
421 .and_then(|v| v.get("tools"))
422 .and_then(|v| v.as_array())
423 .ok_or_else(|| "MCP tools/list result missing tools array".to_string())?;
424
425 let now = now_ms();
426 let mut out = Vec::new();
427 for row in tools {
428 let Some(tool_name) = row.get("name").and_then(|v| v.as_str()) else {
429 continue;
430 };
431 let description = row
432 .get("description")
433 .and_then(|v| v.as_str())
434 .unwrap_or("")
435 .to_string();
436 let input_schema = row
437 .get("inputSchema")
438 .or_else(|| row.get("input_schema"))
439 .cloned()
440 .unwrap_or_else(|| json!({"type":"object"}));
441 out.push(McpRemoteTool {
442 server_name: String::new(),
443 tool_name: tool_name.to_string(),
444 namespaced_name: String::new(),
445 description,
446 input_schema,
447 fetched_at_ms: now,
448 schema_hash: String::new(),
449 });
450 }
451
452 Ok(out)
453 }
454
455 async fn persist_state(&self) {
456 let snapshot = self.servers.read().await.clone();
457 if let Some(parent) = self.state_file.parent() {
458 let _ = tokio::fs::create_dir_all(parent).await;
459 }
460 if let Ok(payload) = serde_json::to_string_pretty(&snapshot) {
461 let _ = tokio::fs::write(self.state_file.as_path(), payload).await;
462 }
463 }
464}
465
466impl Default for McpRegistry {
467 fn default() -> Self {
468 Self::new()
469 }
470}
471
/// Serde default for `McpServer::enabled`: entries without the field are treated as enabled.
fn default_enabled() -> bool { true }
475
476fn resolve_state_file() -> PathBuf {
477 if let Ok(path) = std::env::var("TANDEM_MCP_REGISTRY") {
478 return PathBuf::from(path);
479 }
480 if let Ok(state_dir) = std::env::var("TANDEM_STATE_DIR") {
481 let trimmed = state_dir.trim();
482 if !trimmed.is_empty() {
483 return PathBuf::from(trimmed).join("mcp_servers.json");
484 }
485 }
486 if let Some(data_dir) = dirs::data_dir() {
487 return data_dir
488 .join("tandem")
489 .join("data")
490 .join("mcp_servers.json");
491 }
492 dirs::home_dir()
493 .map(|home| home.join(".tandem").join("data").join("mcp_servers.json"))
494 .unwrap_or_else(|| PathBuf::from("mcp_servers.json"))
495}
496
497fn load_state(path: &Path) -> HashMap<String, McpServer> {
498 let Ok(raw) = std::fs::read_to_string(path) else {
499 return HashMap::new();
500 };
501 serde_json::from_str::<HashMap<String, McpServer>>(&raw).unwrap_or_default()
502}
503
/// Extracts the command from a `stdio:<command>` transport string.
///
/// Whitespace around the transport and around the command is ignored. This matches
/// how `parse_remote_endpoint` trims its input. Returns `None` for non-stdio
/// transports. The command may be empty (e.g. `"stdio:"`); the spawner rejects that.
fn parse_stdio_transport(transport: &str) -> Option<&str> {
    transport.trim_start().strip_prefix("stdio:").map(str::trim)
}
507
/// Resolves an HTTP(S) endpoint URL from a transport string.
///
/// Accepts a bare `http://` / `https://` URL, or such a URL wrapped in an
/// `http:` or `https:` prefix (e.g. `http:https://host/mcp`). Whitespace around
/// the transport and around the wrapped URL is ignored. Anything else yields `None`.
fn parse_remote_endpoint(transport: &str) -> Option<String> {
    let is_http_url =
        |candidate: &str| -> bool { candidate.starts_with("http://") || candidate.starts_with("https://") };

    let transport = transport.trim();
    if is_http_url(transport) {
        return Some(transport.to_owned());
    }

    ["http:", "https:"]
        .iter()
        .copied()
        .filter_map(|prefix| transport.strip_prefix(prefix))
        .map(str::trim)
        .find(|candidate| is_http_url(candidate))
        .map(str::to_owned)
}
523
524fn server_tool_rows(server: &McpServer) -> Vec<McpRemoteTool> {
525 let server_slug = sanitize_namespace_segment(&server.name);
526 server
527 .tool_cache
528 .iter()
529 .map(|tool| {
530 let tool_slug = sanitize_namespace_segment(&tool.tool_name);
531 McpRemoteTool {
532 server_name: server.name.clone(),
533 tool_name: tool.tool_name.clone(),
534 namespaced_name: format!("mcp.{server_slug}.{tool_slug}"),
535 description: tool.description.clone(),
536 input_schema: tool.input_schema.clone(),
537 fetched_at_ms: tool.fetched_at_ms,
538 schema_hash: tool.schema_hash.clone(),
539 }
540 })
541 .collect()
542}
543
/// Normalizes `raw` into a lowercase identifier segment.
///
/// ASCII letters and digits are kept, lowercased. Every run of other characters
/// collapses into a single `_`, and leading/trailing `_` are stripped. Returns
/// `"tool"` when nothing alphanumeric remains.
fn sanitize_namespace_segment(raw: &str) -> String {
    let mut slug = String::new();
    for ch in raw.trim().chars() {
        match ch {
            c if c.is_ascii_alphanumeric() => slug.push(c.to_ascii_lowercase()),
            // Collapse runs: only emit `_` if the previous character was not one.
            _ if slug.ends_with('_') => {}
            _ => slug.push('_'),
        }
    }
    match slug.trim_matches('_') {
        "" => String::from("tool"),
        cleaned => cleaned.to_owned(),
    }
}
563
564fn schema_hash(schema: &Value) -> String {
565 let payload = serde_json::to_vec(schema).unwrap_or_default();
566 let mut hasher = Sha256::new();
567 hasher.update(payload);
568 format!("{:x}", hasher.finalize())
569}
570
/// Milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reads earlier than the epoch. The `u128`
/// millisecond count saturates at `u64::MAX` instead of being silently truncated
/// by a narrowing cast.
fn now_ms() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_millis().min(u128::from(u64::MAX)) as u64)
        .unwrap_or(0)
}
577
578fn build_headers(headers: &HashMap<String, String>) -> Result<HeaderMap, String> {
579 let mut map = HeaderMap::new();
580 map.insert(
581 ACCEPT,
582 HeaderValue::from_static("application/json, text/event-stream"),
583 );
584 map.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
585 for (key, value) in headers {
586 let name = HeaderName::from_bytes(key.trim().as_bytes())
587 .map_err(|e| format!("Invalid header name '{key}': {e}"))?;
588 let header = HeaderValue::from_str(value.trim())
589 .map_err(|e| format!("Invalid header value for '{key}': {e}"))?;
590 map.insert(name, header);
591 }
592 Ok(map)
593}
594
595async fn post_json_rpc(
596 endpoint: &str,
597 headers: &HashMap<String, String>,
598 request: Value,
599) -> Result<Value, String> {
600 let client = reqwest::Client::builder()
601 .timeout(std::time::Duration::from_secs(12))
602 .build()
603 .map_err(|e| format!("Failed to build HTTP client: {e}"))?;
604 let response = client
605 .post(endpoint)
606 .headers(build_headers(headers)?)
607 .json(&request)
608 .send()
609 .await
610 .map_err(|e| format!("MCP request failed: {e}"))?;
611 let status = response.status();
612 let payload = response
613 .text()
614 .await
615 .map_err(|e| format!("Failed to read MCP response: {e}"))?;
616 if !status.is_success() {
617 return Err(format!(
618 "MCP endpoint returned HTTP {}: {}",
619 status.as_u16(),
620 payload.chars().take(400).collect::<String>()
621 ));
622 }
623 serde_json::from_str::<Value>(&payload).map_err(|e| format!("Invalid MCP JSON response: {e}"))
624}
625
626fn render_mcp_content(value: &Value) -> String {
627 let Some(items) = value.as_array() else {
628 return value.to_string();
629 };
630 let mut chunks = Vec::new();
631 for item in items {
632 if let Some(text) = item.get("text").and_then(|v| v.as_str()) {
633 chunks.push(text.to_string());
634 continue;
635 }
636 chunks.push(item.to_string());
637 }
638 if chunks.is_empty() {
639 value.to_string()
640 } else {
641 chunks.join("\n")
642 }
643}
644
645async fn spawn_stdio_process(command_text: &str) -> Result<Child, String> {
646 if command_text.is_empty() {
647 return Err("Missing stdio command".to_string());
648 }
649 #[cfg(windows)]
650 let mut command = {
651 let mut cmd = Command::new("powershell");
652 cmd.args(["-NoProfile", "-Command", command_text]);
653 cmd
654 };
655 #[cfg(not(windows))]
656 let mut command = {
657 let mut cmd = Command::new("sh");
658 cmd.args(["-lc", command_text]);
659 cmd
660 };
661 command
662 .stdin(std::process::Stdio::null())
663 .stdout(std::process::Stdio::null())
664 .stderr(std::process::Stdio::null());
665 command.spawn().map_err(|e| e.to_string())
666}
667
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// Unique temp-file path so tests never share (or clobber) registry state.
    fn temp_state_file() -> PathBuf {
        std::env::temp_dir().join(format!("mcp-test-{}.json", Uuid::new_v4()))
    }

    #[tokio::test]
    async fn add_connect_disconnect_non_stdio_server() {
        let file = temp_state_file();
        let registry = McpRegistry::new_with_state_file(file.clone());
        registry
            .add("example".to_string(), "sse:https://example.com".to_string())
            .await;
        assert!(registry.connect("example").await);
        let listed = registry.list().await;
        assert!(listed.get("example").map(|s| s.connected).unwrap_or(false));
        assert!(registry.disconnect("example").await);
        // The registry persists on every change; don't leave the file in the temp dir.
        let _ = std::fs::remove_file(&file);
    }

    #[tokio::test]
    async fn persisted_state_reloads_as_disconnected() {
        let file = temp_state_file();
        let registry = McpRegistry::new_with_state_file(file.clone());
        registry
            .add("example".to_string(), "sse:https://example.com".to_string())
            .await;
        assert!(registry.connect("example").await);

        let reloaded = McpRegistry::new_with_state_file(file.clone());
        let listed = reloaded.list().await;
        let server = listed.get("example").expect("server should be persisted");
        assert!(!server.connected);
        assert_eq!(server.transport, "sse:https://example.com");
        let _ = std::fs::remove_file(&file);
    }

    #[tokio::test]
    async fn connect_fails_for_disabled_server() {
        let file = temp_state_file();
        let registry = McpRegistry::new_with_state_file(file.clone());
        registry
            .add_or_update(
                "example".to_string(),
                "sse:https://example.com".to_string(),
                HashMap::new(),
                false,
            )
            .await;
        assert!(!registry.connect("example").await);
        let listed = registry.list().await;
        let server = listed.get("example").expect("server should be registered");
        assert!(!server.connected);
        assert_eq!(server.last_error.as_deref(), Some("MCP server is disabled"));
        let _ = std::fs::remove_file(&file);
    }

    #[tokio::test]
    async fn disconnect_unknown_server_returns_false() {
        let file = temp_state_file();
        let registry = McpRegistry::new_with_state_file(file.clone());
        assert!(!registry.disconnect("missing").await);
        let _ = std::fs::remove_file(&file);
    }

    #[test]
    fn parse_remote_endpoint_supports_http_prefixes() {
        assert_eq!(
            parse_remote_endpoint("https://mcp.example.com/mcp"),
            Some("https://mcp.example.com/mcp".to_string())
        );
        assert_eq!(
            parse_remote_endpoint("http:https://mcp.example.com/mcp"),
            Some("https://mcp.example.com/mcp".to_string())
        );
    }

    #[test]
    fn sanitize_namespace_segment_collapses_and_falls_back() {
        assert_eq!(sanitize_namespace_segment("My Server!"), "my_server");
        assert_eq!(sanitize_namespace_segment("__a__b"), "a_b");
        assert_eq!(sanitize_namespace_segment("  "), "tool");
    }
}