Skip to main content

tandem_runtime/
mcp.rs

1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use reqwest::header::{HeaderMap, HeaderName, HeaderValue, ACCEPT, CONTENT_TYPE};
7use serde::{Deserialize, Serialize};
8use serde_json::{json, Value};
9use sha2::{Digest, Sha256};
10use tandem_types::ToolResult;
11use tokio::process::{Child, Command};
12use tokio::sync::{Mutex, RwLock};
13
/// MCP protocol revision sent as `protocolVersion` in the `initialize` request.
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";
/// Client name reported in `clientInfo` during `initialize`.
const MCP_CLIENT_NAME: &str = "tandem";
/// Client version reported in `clientInfo`; taken from this crate's Cargo package version at
/// compile time.
const MCP_CLIENT_VERSION: &str = env!("CARGO_PKG_VERSION");
17
/// One tool as cached on an [`McpServer`] after a successful `tools/list` refresh.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolCacheEntry {
    /// Tool name as reported by the server (not namespaced).
    pub tool_name: String,
    /// Tool description; empty when the server provided none.
    pub description: String,
    /// JSON Schema for the tool's arguments; deserializes to `Value::Null` when absent from
    /// stored state.
    #[serde(default)]
    pub input_schema: Value,
    /// Unix-epoch milliseconds at which the tool list was fetched.
    pub fetched_at_ms: u64,
    /// Hex-encoded SHA-256 of the serialized `input_schema`.
    pub schema_hash: String,
}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct McpServer {
30    pub name: String,
31    pub transport: String,
32    #[serde(default = "default_enabled")]
33    pub enabled: bool,
34    pub connected: bool,
35    #[serde(skip_serializing_if = "Option::is_none")]
36    pub pid: Option<u32>,
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub last_error: Option<String>,
39    #[serde(default)]
40    pub headers: HashMap<String, String>,
41    #[serde(default)]
42    pub tool_cache: Vec<McpToolCacheEntry>,
43    #[serde(default, skip_serializing_if = "Option::is_none")]
44    pub tools_fetched_at_ms: Option<u64>,
45}
46
/// A tool exposed by a registered MCP server.
///
/// Rows produced from a server's cache (as returned by `McpRegistry::list_tools` /
/// `McpRegistry::server_tools`) have every field populated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpRemoteTool {
    /// Registry name of the owning server.
    pub server_name: String,
    /// Tool name as reported by the server.
    pub tool_name: String,
    /// `mcp.<server>.<tool>`, with both segments passed through `sanitize_namespace_segment`.
    pub namespaced_name: String,
    /// Tool description; may be empty.
    pub description: String,
    /// JSON Schema for the tool's arguments.
    #[serde(default)]
    pub input_schema: Value,
    /// Unix-epoch milliseconds at which the tool list was fetched.
    pub fetched_at_ms: u64,
    /// Hex-encoded SHA-256 of the serialized `input_schema`.
    pub schema_hash: String,
}
58
/// Registry of configured MCP servers, persisted to a JSON state file.
///
/// Every field sits behind an `Arc`, so clones share the same servers, processes and state
/// file.
#[derive(Clone)]
pub struct McpRegistry {
    /// Server entries keyed by server name.
    servers: Arc<RwLock<HashMap<String, McpServer>>>,
    /// Child processes spawned for `stdio:` transports, keyed by server name.
    processes: Arc<Mutex<HashMap<String, Child>>>,
    /// Path the server map is written to after every mutation.
    state_file: Arc<PathBuf>,
}
65
66impl McpRegistry {
67    pub fn new() -> Self {
68        Self::new_with_state_file(resolve_state_file())
69    }
70
71    pub fn new_with_state_file(state_file: PathBuf) -> Self {
72        let loaded = load_state(&state_file)
73            .into_iter()
74            .map(|(k, mut v)| {
75                v.connected = false;
76                v.pid = None;
77                if v.name.trim().is_empty() {
78                    v.name = k.clone();
79                }
80                if v.headers.is_empty() {
81                    v.headers = HashMap::new();
82                }
83                (k, v)
84            })
85            .collect::<HashMap<_, _>>();
86        Self {
87            servers: Arc::new(RwLock::new(loaded)),
88            processes: Arc::new(Mutex::new(HashMap::new())),
89            state_file: Arc::new(state_file),
90        }
91    }
92
93    pub async fn list(&self) -> HashMap<String, McpServer> {
94        self.servers.read().await.clone()
95    }
96
97    pub async fn add(&self, name: String, transport: String) {
98        self.add_or_update(name, transport, HashMap::new(), true)
99            .await;
100    }
101
102    pub async fn add_or_update(
103        &self,
104        name: String,
105        transport: String,
106        headers: HashMap<String, String>,
107        enabled: bool,
108    ) {
109        let mut servers = self.servers.write().await;
110        let existing = servers.get(&name).cloned();
111        let existing_tool_cache = existing
112            .as_ref()
113            .map(|row| row.tool_cache.clone())
114            .unwrap_or_default();
115        let existing_fetched_at = existing.as_ref().and_then(|row| row.tools_fetched_at_ms);
116        let server = McpServer {
117            name: name.clone(),
118            transport,
119            enabled,
120            connected: false,
121            pid: None,
122            last_error: None,
123            headers,
124            tool_cache: existing_tool_cache,
125            tools_fetched_at_ms: existing_fetched_at,
126        };
127        servers.insert(name, server);
128        drop(servers);
129        self.persist_state().await;
130    }
131
132    pub async fn set_enabled(&self, name: &str, enabled: bool) -> bool {
133        let mut servers = self.servers.write().await;
134        let Some(server) = servers.get_mut(name) else {
135            return false;
136        };
137        server.enabled = enabled;
138        if !enabled {
139            server.connected = false;
140            server.pid = None;
141        }
142        drop(servers);
143        if !enabled {
144            if let Some(mut child) = self.processes.lock().await.remove(name) {
145                let _ = child.kill().await;
146                let _ = child.wait().await;
147            }
148        }
149        self.persist_state().await;
150        true
151    }
152
153    pub async fn connect(&self, name: &str) -> bool {
154        let server = {
155            let servers = self.servers.read().await;
156            let Some(server) = servers.get(name) else {
157                return false;
158            };
159            server.clone()
160        };
161
162        if !server.enabled {
163            let mut servers = self.servers.write().await;
164            if let Some(entry) = servers.get_mut(name) {
165                entry.connected = false;
166                entry.pid = None;
167                entry.last_error = Some("MCP server is disabled".to_string());
168            }
169            drop(servers);
170            self.persist_state().await;
171            return false;
172        }
173
174        if let Some(command_text) = parse_stdio_transport(&server.transport) {
175            return self.connect_stdio(name, command_text).await;
176        }
177
178        if parse_remote_endpoint(&server.transport).is_some() {
179            return self.refresh(name).await.is_ok();
180        }
181
182        let mut servers = self.servers.write().await;
183        if let Some(entry) = servers.get_mut(name) {
184            entry.connected = true;
185            entry.pid = None;
186            entry.last_error = None;
187        }
188        drop(servers);
189        self.persist_state().await;
190        true
191    }
192
193    pub async fn refresh(&self, name: &str) -> Result<Vec<McpRemoteTool>, String> {
194        let server = {
195            let servers = self.servers.read().await;
196            let Some(server) = servers.get(name) else {
197                return Err("MCP server not found".to_string());
198            };
199            server.clone()
200        };
201
202        if !server.enabled {
203            return Err("MCP server is disabled".to_string());
204        }
205
206        let endpoint = parse_remote_endpoint(&server.transport)
207            .ok_or_else(|| "MCP refresh currently supports HTTP/S transports only".to_string())?;
208
209        let tools = match self.discover_remote_tools(&endpoint, &server.headers).await {
210            Ok(tools) => tools,
211            Err(err) => {
212                let mut servers = self.servers.write().await;
213                if let Some(entry) = servers.get_mut(name) {
214                    entry.connected = false;
215                    entry.pid = None;
216                    entry.last_error = Some(err.clone());
217                }
218                drop(servers);
219                self.persist_state().await;
220                return Err(err);
221            }
222        };
223
224        let now = now_ms();
225        let cache = tools
226            .iter()
227            .map(|tool| McpToolCacheEntry {
228                tool_name: tool.tool_name.clone(),
229                description: tool.description.clone(),
230                input_schema: tool.input_schema.clone(),
231                fetched_at_ms: now,
232                schema_hash: schema_hash(&tool.input_schema),
233            })
234            .collect::<Vec<_>>();
235
236        let mut servers = self.servers.write().await;
237        if let Some(entry) = servers.get_mut(name) {
238            entry.connected = true;
239            entry.pid = None;
240            entry.last_error = None;
241            entry.tool_cache = cache;
242            entry.tools_fetched_at_ms = Some(now);
243        }
244        drop(servers);
245        self.persist_state().await;
246        Ok(self.server_tools(name).await)
247    }
248
249    pub async fn disconnect(&self, name: &str) -> bool {
250        if let Some(mut child) = self.processes.lock().await.remove(name) {
251            let _ = child.kill().await;
252            let _ = child.wait().await;
253        }
254        let mut servers = self.servers.write().await;
255        if let Some(server) = servers.get_mut(name) {
256            server.connected = false;
257            server.pid = None;
258            drop(servers);
259            self.persist_state().await;
260            return true;
261        }
262        false
263    }
264
265    pub async fn list_tools(&self) -> Vec<McpRemoteTool> {
266        let mut out = self
267            .servers
268            .read()
269            .await
270            .values()
271            .filter(|server| server.enabled && server.connected)
272            .flat_map(server_tool_rows)
273            .collect::<Vec<_>>();
274        out.sort_by(|a, b| a.namespaced_name.cmp(&b.namespaced_name));
275        out
276    }
277
278    pub async fn server_tools(&self, name: &str) -> Vec<McpRemoteTool> {
279        let Some(server) = self.servers.read().await.get(name).cloned() else {
280            return Vec::new();
281        };
282        let mut rows = server_tool_rows(&server);
283        rows.sort_by(|a, b| a.namespaced_name.cmp(&b.namespaced_name));
284        rows
285    }
286
287    pub async fn call_tool(
288        &self,
289        server_name: &str,
290        tool_name: &str,
291        args: Value,
292    ) -> Result<ToolResult, String> {
293        let server = {
294            let servers = self.servers.read().await;
295            let Some(server) = servers.get(server_name) else {
296                return Err(format!("MCP server '{server_name}' not found"));
297            };
298            server.clone()
299        };
300
301        if !server.enabled {
302            return Err(format!("MCP server '{server_name}' is disabled"));
303        }
304        if !server.connected {
305            return Err(format!("MCP server '{server_name}' is not connected"));
306        }
307
308        let endpoint = parse_remote_endpoint(&server.transport).ok_or_else(|| {
309            "MCP tools/call currently supports HTTP/S transports only".to_string()
310        })?;
311
312        let request = json!({
313            "jsonrpc": "2.0",
314            "id": format!("call-{}-{}", server_name, now_ms()),
315            "method": "tools/call",
316            "params": {
317                "name": tool_name,
318                "arguments": args
319            }
320        });
321        let response = post_json_rpc(&endpoint, &server.headers, request).await?;
322
323        if let Some(err) = response.get("error") {
324            let message = err
325                .get("message")
326                .and_then(|v| v.as_str())
327                .unwrap_or("MCP tools/call failed");
328            return Err(message.to_string());
329        }
330
331        let result = response.get("result").cloned().unwrap_or(Value::Null);
332        let output = result
333            .get("content")
334            .map(render_mcp_content)
335            .or_else(|| result.get("output").map(|v| v.to_string()))
336            .unwrap_or_else(|| result.to_string());
337
338        Ok(ToolResult {
339            output,
340            metadata: json!({
341                "server": server_name,
342                "tool": tool_name,
343                "result": result
344            }),
345        })
346    }
347
348    async fn connect_stdio(&self, name: &str, command_text: &str) -> bool {
349        match spawn_stdio_process(command_text).await {
350            Ok(child) => {
351                let pid = child.id();
352                self.processes.lock().await.insert(name.to_string(), child);
353                let mut servers = self.servers.write().await;
354                if let Some(server) = servers.get_mut(name) {
355                    server.connected = true;
356                    server.pid = pid;
357                    server.last_error = None;
358                }
359                drop(servers);
360                self.persist_state().await;
361                true
362            }
363            Err(err) => {
364                let mut servers = self.servers.write().await;
365                if let Some(server) = servers.get_mut(name) {
366                    server.connected = false;
367                    server.pid = None;
368                    server.last_error = Some(err);
369                }
370                drop(servers);
371                self.persist_state().await;
372                false
373            }
374        }
375    }
376
377    async fn discover_remote_tools(
378        &self,
379        endpoint: &str,
380        headers: &HashMap<String, String>,
381    ) -> Result<Vec<McpRemoteTool>, String> {
382        let initialize = json!({
383            "jsonrpc": "2.0",
384            "id": "initialize-1",
385            "method": "initialize",
386            "params": {
387                "protocolVersion": MCP_PROTOCOL_VERSION,
388                "capabilities": {},
389                "clientInfo": {
390                    "name": MCP_CLIENT_NAME,
391                    "version": MCP_CLIENT_VERSION,
392                }
393            }
394        });
395        let init_response = post_json_rpc(endpoint, headers, initialize).await?;
396        if let Some(err) = init_response.get("error") {
397            let message = err
398                .get("message")
399                .and_then(|v| v.as_str())
400                .unwrap_or("MCP initialize failed");
401            return Err(message.to_string());
402        }
403
404        let tools_list = json!({
405            "jsonrpc": "2.0",
406            "id": "tools-list-1",
407            "method": "tools/list",
408            "params": {}
409        });
410        let tools_response = post_json_rpc(endpoint, headers, tools_list).await?;
411        if let Some(err) = tools_response.get("error") {
412            let message = err
413                .get("message")
414                .and_then(|v| v.as_str())
415                .unwrap_or("MCP tools/list failed");
416            return Err(message.to_string());
417        }
418
419        let tools = tools_response
420            .get("result")
421            .and_then(|v| v.get("tools"))
422            .and_then(|v| v.as_array())
423            .ok_or_else(|| "MCP tools/list result missing tools array".to_string())?;
424
425        let now = now_ms();
426        let mut out = Vec::new();
427        for row in tools {
428            let Some(tool_name) = row.get("name").and_then(|v| v.as_str()) else {
429                continue;
430            };
431            let description = row
432                .get("description")
433                .and_then(|v| v.as_str())
434                .unwrap_or("")
435                .to_string();
436            let input_schema = row
437                .get("inputSchema")
438                .or_else(|| row.get("input_schema"))
439                .cloned()
440                .unwrap_or_else(|| json!({"type":"object"}));
441            out.push(McpRemoteTool {
442                server_name: String::new(),
443                tool_name: tool_name.to_string(),
444                namespaced_name: String::new(),
445                description,
446                input_schema,
447                fetched_at_ms: now,
448                schema_hash: String::new(),
449            });
450        }
451
452        Ok(out)
453    }
454
455    async fn persist_state(&self) {
456        let snapshot = self.servers.read().await.clone();
457        if let Some(parent) = self.state_file.parent() {
458            let _ = tokio::fs::create_dir_all(parent).await;
459        }
460        if let Ok(payload) = serde_json::to_string_pretty(&snapshot) {
461            let _ = tokio::fs::write(self.state_file.as_path(), payload).await;
462        }
463    }
464}
465
466impl Default for McpRegistry {
467    fn default() -> Self {
468        Self::new()
469    }
470}
471
/// Serde default for `McpServer::enabled`: entries without the field are treated as enabled.
fn default_enabled() -> bool {
    true
}
475
476fn resolve_state_file() -> PathBuf {
477    if let Ok(path) = std::env::var("TANDEM_MCP_REGISTRY") {
478        return PathBuf::from(path);
479    }
480    if let Ok(state_dir) = std::env::var("TANDEM_STATE_DIR") {
481        let trimmed = state_dir.trim();
482        if !trimmed.is_empty() {
483            return PathBuf::from(trimmed).join("mcp_servers.json");
484        }
485    }
486    if let Some(data_dir) = dirs::data_dir() {
487        return data_dir
488            .join("tandem")
489            .join("data")
490            .join("mcp_servers.json");
491    }
492    dirs::home_dir()
493        .map(|home| home.join(".tandem").join("data").join("mcp_servers.json"))
494        .unwrap_or_else(|| PathBuf::from("mcp_servers.json"))
495}
496
497fn load_state(path: &Path) -> HashMap<String, McpServer> {
498    let Ok(raw) = std::fs::read_to_string(path) else {
499        return HashMap::new();
500    };
501    serde_json::from_str::<HashMap<String, McpServer>>(&raw).unwrap_or_default()
502}
503
/// Returns the command of a `stdio:<command>` transport, trimmed, or `None` for other
/// transports.
///
/// Leading whitespace before `stdio:` is ignored, consistent with `parse_remote_endpoint`,
/// which trims its input. An empty command (`"stdio:"`) yields `Some("")`.
fn parse_stdio_transport(transport: &str) -> Option<&str> {
    transport
        .trim_start()
        .strip_prefix("stdio:")
        .map(str::trim)
}
507
/// Extracts an HTTP(S) endpoint from a transport string.
///
/// Accepts a bare `http://` / `https://` URL, or one prefixed with `http:` / `https:`
/// (e.g. `http:https://host/mcp`). Surrounding whitespace is ignored. Returns `None`
/// for anything else.
fn parse_remote_endpoint(transport: &str) -> Option<String> {
    fn is_http_url(candidate: &str) -> bool {
        ["http://", "https://"]
            .iter()
            .any(|scheme| candidate.starts_with(*scheme))
    }

    let transport = transport.trim();
    if is_http_url(transport) {
        return Some(transport.to_owned());
    }
    ["http:", "https:"]
        .iter()
        .filter_map(|prefix| transport.strip_prefix(*prefix))
        .map(str::trim)
        .find(|candidate| is_http_url(candidate))
        .map(str::to_owned)
}
523
524fn server_tool_rows(server: &McpServer) -> Vec<McpRemoteTool> {
525    let server_slug = sanitize_namespace_segment(&server.name);
526    server
527        .tool_cache
528        .iter()
529        .map(|tool| {
530            let tool_slug = sanitize_namespace_segment(&tool.tool_name);
531            McpRemoteTool {
532                server_name: server.name.clone(),
533                tool_name: tool.tool_name.clone(),
534                namespaced_name: format!("mcp.{server_slug}.{tool_slug}"),
535                description: tool.description.clone(),
536                input_schema: tool.input_schema.clone(),
537                fetched_at_ms: tool.fetched_at_ms,
538                schema_hash: tool.schema_hash.clone(),
539            }
540        })
541        .collect()
542}
543
/// Normalizes `raw` into a lowercase identifier segment.
///
/// ASCII alphanumerics are lowercased. Every run of other characters (including non-ASCII)
/// collapses to a single `_`. Leading and trailing underscores are stripped, and an empty
/// result becomes `"tool"`.
fn sanitize_namespace_segment(raw: &str) -> String {
    let collapsed = raw.trim().chars().fold(String::new(), |mut acc, ch| {
        if ch.is_ascii_alphanumeric() {
            acc.push(ch.to_ascii_lowercase());
        } else if !acc.ends_with('_') {
            // Only separators ever push '_', so a trailing '_' means we are inside a run.
            acc.push('_');
        }
        acc
    });
    match collapsed.trim_matches('_') {
        "" => "tool".to_string(),
        cleaned => cleaned.to_string(),
    }
}
563
564fn schema_hash(schema: &Value) -> String {
565    let payload = serde_json::to_vec(schema).unwrap_or_default();
566    let mut hasher = Sha256::new();
567    hasher.update(payload);
568    format!("{:x}", hasher.finalize())
569}
570
/// Milliseconds since the Unix epoch; `0` if the system clock is set before 1970.
fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
577
578fn build_headers(headers: &HashMap<String, String>) -> Result<HeaderMap, String> {
579    let mut map = HeaderMap::new();
580    map.insert(
581        ACCEPT,
582        HeaderValue::from_static("application/json, text/event-stream"),
583    );
584    map.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
585    for (key, value) in headers {
586        let name = HeaderName::from_bytes(key.trim().as_bytes())
587            .map_err(|e| format!("Invalid header name '{key}': {e}"))?;
588        let header = HeaderValue::from_str(value.trim())
589            .map_err(|e| format!("Invalid header value for '{key}': {e}"))?;
590        map.insert(name, header);
591    }
592    Ok(map)
593}
594
595async fn post_json_rpc(
596    endpoint: &str,
597    headers: &HashMap<String, String>,
598    request: Value,
599) -> Result<Value, String> {
600    let client = reqwest::Client::builder()
601        .timeout(std::time::Duration::from_secs(12))
602        .build()
603        .map_err(|e| format!("Failed to build HTTP client: {e}"))?;
604    let response = client
605        .post(endpoint)
606        .headers(build_headers(headers)?)
607        .json(&request)
608        .send()
609        .await
610        .map_err(|e| format!("MCP request failed: {e}"))?;
611    let status = response.status();
612    let payload = response
613        .text()
614        .await
615        .map_err(|e| format!("Failed to read MCP response: {e}"))?;
616    if !status.is_success() {
617        return Err(format!(
618            "MCP endpoint returned HTTP {}: {}",
619            status.as_u16(),
620            payload.chars().take(400).collect::<String>()
621        ));
622    }
623    serde_json::from_str::<Value>(&payload).map_err(|e| format!("Invalid MCP JSON response: {e}"))
624}
625
626fn render_mcp_content(value: &Value) -> String {
627    let Some(items) = value.as_array() else {
628        return value.to_string();
629    };
630    let mut chunks = Vec::new();
631    for item in items {
632        if let Some(text) = item.get("text").and_then(|v| v.as_str()) {
633            chunks.push(text.to_string());
634            continue;
635        }
636        chunks.push(item.to_string());
637    }
638    if chunks.is_empty() {
639        value.to_string()
640    } else {
641        chunks.join("\n")
642    }
643}
644
645async fn spawn_stdio_process(command_text: &str) -> Result<Child, String> {
646    if command_text.is_empty() {
647        return Err("Missing stdio command".to_string());
648    }
649    #[cfg(windows)]
650    let mut command = {
651        let mut cmd = Command::new("powershell");
652        cmd.args(["-NoProfile", "-Command", command_text]);
653        cmd
654    };
655    #[cfg(not(windows))]
656    let mut command = {
657        let mut cmd = Command::new("sh");
658        cmd.args(["-lc", command_text]);
659        cmd
660    };
661    command
662        .stdin(std::process::Stdio::null())
663        .stdout(std::process::Stdio::null())
664        .stderr(std::process::Stdio::null());
665    command.spawn().map_err(|e| e.to_string())
666}
667
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// Returns a fresh, unique state-file path under the system temp directory.
    fn temp_state_file() -> PathBuf {
        std::env::temp_dir().join(format!("mcp-test-{}.json", Uuid::new_v4()))
    }

    #[tokio::test]
    async fn add_connect_disconnect_non_stdio_server() {
        let file = temp_state_file();
        let registry = McpRegistry::new_with_state_file(file.clone());
        registry
            .add("example".to_string(), "sse:https://example.com".to_string())
            .await;
        assert!(registry.connect("example").await);
        let listed = registry.list().await;
        assert!(listed.get("example").map(|s| s.connected).unwrap_or(false));
        assert!(registry.disconnect("example").await);
        assert!(!registry.disconnect("missing").await);
        // Every mutation persists state; remove it so repeated runs don't litter temp.
        let _ = std::fs::remove_file(&file);
    }

    #[tokio::test]
    async fn persisted_servers_reload_disconnected() {
        let file = temp_state_file();
        let registry = McpRegistry::new_with_state_file(file.clone());
        registry
            .add("example".to_string(), "sse:https://example.com".to_string())
            .await;
        assert!(registry.connect("example").await);

        let reloaded = McpRegistry::new_with_state_file(file.clone());
        let listed = reloaded.list().await;
        let entry = listed
            .get("example")
            .expect("server should be reloaded from the state file");
        assert_eq!(entry.transport, "sse:https://example.com");
        assert!(entry.enabled);
        assert!(!entry.connected);
        let _ = std::fs::remove_file(&file);
    }

    #[test]
    fn parse_remote_endpoint_supports_http_prefixes() {
        assert_eq!(
            parse_remote_endpoint("https://mcp.example.com/mcp"),
            Some("https://mcp.example.com/mcp".to_string())
        );
        assert_eq!(
            parse_remote_endpoint("http:https://mcp.example.com/mcp"),
            Some("https://mcp.example.com/mcp".to_string())
        );
    }

    #[test]
    fn sanitize_namespace_segment_collapses_separators() {
        assert_eq!(sanitize_namespace_segment("My Server!!"), "my_server");
        assert_eq!(sanitize_namespace_segment("  "), "tool");
    }
}