1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use reqwest::header::{HeaderMap, HeaderName, HeaderValue, ACCEPT, CONTENT_TYPE};
7use serde::{Deserialize, Serialize};
8use serde_json::{json, Value};
9use sha2::{Digest, Sha256};
10use tandem_types::ToolResult;
11use tokio::process::{Child, Command};
12use tokio::sync::{Mutex, RwLock};
13
14const MCP_PROTOCOL_VERSION: &str = "2025-11-25";
15const MCP_CLIENT_NAME: &str = "tandem";
16const MCP_CLIENT_VERSION: &str = env!("CARGO_PKG_VERSION");
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct McpToolCacheEntry {
20 pub tool_name: String,
21 pub description: String,
22 #[serde(default)]
23 pub input_schema: Value,
24 pub fetched_at_ms: u64,
25 pub schema_hash: String,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct McpServer {
30 pub name: String,
31 pub transport: String,
32 #[serde(default = "default_enabled")]
33 pub enabled: bool,
34 pub connected: bool,
35 #[serde(skip_serializing_if = "Option::is_none")]
36 pub pid: Option<u32>,
37 #[serde(skip_serializing_if = "Option::is_none")]
38 pub last_error: Option<String>,
39 #[serde(default, skip_serializing_if = "Option::is_none")]
40 pub last_auth_challenge: Option<McpAuthChallenge>,
41 #[serde(default, skip_serializing_if = "Option::is_none")]
42 pub mcp_session_id: Option<String>,
43 #[serde(default)]
44 pub headers: HashMap<String, String>,
45 #[serde(default)]
46 pub tool_cache: Vec<McpToolCacheEntry>,
47 #[serde(default, skip_serializing_if = "Option::is_none")]
48 pub tools_fetched_at_ms: Option<u64>,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct McpAuthChallenge {
53 pub challenge_id: String,
54 pub tool_name: String,
55 pub authorization_url: String,
56 pub message: String,
57 pub requested_at_ms: u64,
58 pub status: String,
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct McpRemoteTool {
63 pub server_name: String,
64 pub tool_name: String,
65 pub namespaced_name: String,
66 pub description: String,
67 #[serde(default)]
68 pub input_schema: Value,
69 pub fetched_at_ms: u64,
70 pub schema_hash: String,
71}
72
73#[derive(Clone)]
74pub struct McpRegistry {
75 servers: Arc<RwLock<HashMap<String, McpServer>>>,
76 processes: Arc<Mutex<HashMap<String, Child>>>,
77 state_file: Arc<PathBuf>,
78}
79
80impl McpRegistry {
81 pub fn new() -> Self {
82 Self::new_with_state_file(resolve_state_file())
83 }
84
85 pub fn new_with_state_file(state_file: PathBuf) -> Self {
86 let loaded = load_state(&state_file)
87 .into_iter()
88 .map(|(k, mut v)| {
89 v.connected = false;
90 v.pid = None;
91 if v.name.trim().is_empty() {
92 v.name = k.clone();
93 }
94 if v.headers.is_empty() {
95 v.headers = HashMap::new();
96 }
97 (k, v)
98 })
99 .collect::<HashMap<_, _>>();
100 Self {
101 servers: Arc::new(RwLock::new(loaded)),
102 processes: Arc::new(Mutex::new(HashMap::new())),
103 state_file: Arc::new(state_file),
104 }
105 }
106
107 pub async fn list(&self) -> HashMap<String, McpServer> {
108 self.servers.read().await.clone()
109 }
110
111 pub async fn add(&self, name: String, transport: String) {
112 self.add_or_update(name, transport, HashMap::new(), true)
113 .await;
114 }
115
116 pub async fn add_or_update(
117 &self,
118 name: String,
119 transport: String,
120 headers: HashMap<String, String>,
121 enabled: bool,
122 ) {
123 let mut servers = self.servers.write().await;
124 let existing = servers.get(&name).cloned();
125 let preserve_cache = existing
126 .as_ref()
127 .is_some_and(|row| row.transport == transport && row.headers == headers);
128 let existing_tool_cache = if preserve_cache {
129 existing
130 .as_ref()
131 .map(|row| row.tool_cache.clone())
132 .unwrap_or_default()
133 } else {
134 Vec::new()
135 };
136 let existing_fetched_at = if preserve_cache {
137 existing.as_ref().and_then(|row| row.tools_fetched_at_ms)
138 } else {
139 None
140 };
141 let server = McpServer {
142 name: name.clone(),
143 transport,
144 enabled,
145 connected: false,
146 pid: None,
147 last_error: None,
148 last_auth_challenge: None,
149 mcp_session_id: None,
150 headers,
151 tool_cache: existing_tool_cache,
152 tools_fetched_at_ms: existing_fetched_at,
153 };
154 servers.insert(name, server);
155 drop(servers);
156 self.persist_state().await;
157 }
158
159 pub async fn set_enabled(&self, name: &str, enabled: bool) -> bool {
160 let mut servers = self.servers.write().await;
161 let Some(server) = servers.get_mut(name) else {
162 return false;
163 };
164 server.enabled = enabled;
165 if !enabled {
166 server.connected = false;
167 server.pid = None;
168 server.last_auth_challenge = None;
169 server.mcp_session_id = None;
170 }
171 drop(servers);
172 if !enabled {
173 if let Some(mut child) = self.processes.lock().await.remove(name) {
174 let _ = child.kill().await;
175 let _ = child.wait().await;
176 }
177 }
178 self.persist_state().await;
179 true
180 }
181
182 pub async fn remove(&self, name: &str) -> bool {
183 let removed = {
184 let mut servers = self.servers.write().await;
185 servers.remove(name).is_some()
186 };
187 if !removed {
188 return false;
189 }
190
191 if let Some(mut child) = self.processes.lock().await.remove(name) {
192 let _ = child.kill().await;
193 let _ = child.wait().await;
194 }
195 self.persist_state().await;
196 true
197 }
198
199 pub async fn connect(&self, name: &str) -> bool {
200 let server = {
201 let servers = self.servers.read().await;
202 let Some(server) = servers.get(name) else {
203 return false;
204 };
205 server.clone()
206 };
207
208 if !server.enabled {
209 let mut servers = self.servers.write().await;
210 if let Some(entry) = servers.get_mut(name) {
211 entry.connected = false;
212 entry.pid = None;
213 entry.last_error = Some("MCP server is disabled".to_string());
214 entry.last_auth_challenge = None;
215 entry.mcp_session_id = None;
216 }
217 drop(servers);
218 self.persist_state().await;
219 return false;
220 }
221
222 if let Some(command_text) = parse_stdio_transport(&server.transport) {
223 return self.connect_stdio(name, command_text).await;
224 }
225
226 if parse_remote_endpoint(&server.transport).is_some() {
227 return self.refresh(name).await.is_ok();
228 }
229
230 let mut servers = self.servers.write().await;
231 if let Some(entry) = servers.get_mut(name) {
232 entry.connected = true;
233 entry.pid = None;
234 entry.last_error = None;
235 entry.last_auth_challenge = None;
236 entry.mcp_session_id = None;
237 }
238 drop(servers);
239 self.persist_state().await;
240 true
241 }
242
243 pub async fn refresh(&self, name: &str) -> Result<Vec<McpRemoteTool>, String> {
244 let server = {
245 let servers = self.servers.read().await;
246 let Some(server) = servers.get(name) else {
247 return Err("MCP server not found".to_string());
248 };
249 server.clone()
250 };
251
252 if !server.enabled {
253 return Err("MCP server is disabled".to_string());
254 }
255
256 let endpoint = parse_remote_endpoint(&server.transport)
257 .ok_or_else(|| "MCP refresh currently supports HTTP/S transports only".to_string())?;
258
259 let (tools, session_id) = match self.discover_remote_tools(&endpoint, &server.headers).await
260 {
261 Ok(result) => result,
262 Err(err) => {
263 let mut servers = self.servers.write().await;
264 if let Some(entry) = servers.get_mut(name) {
265 entry.connected = false;
266 entry.pid = None;
267 entry.last_error = Some(err.clone());
268 entry.last_auth_challenge = None;
269 entry.mcp_session_id = None;
270 entry.tool_cache.clear();
271 entry.tools_fetched_at_ms = None;
272 }
273 drop(servers);
274 self.persist_state().await;
275 return Err(err);
276 }
277 };
278
279 let now = now_ms();
280 let cache = tools
281 .iter()
282 .map(|tool| McpToolCacheEntry {
283 tool_name: tool.tool_name.clone(),
284 description: tool.description.clone(),
285 input_schema: tool.input_schema.clone(),
286 fetched_at_ms: now,
287 schema_hash: schema_hash(&tool.input_schema),
288 })
289 .collect::<Vec<_>>();
290
291 let mut servers = self.servers.write().await;
292 if let Some(entry) = servers.get_mut(name) {
293 entry.connected = true;
294 entry.pid = None;
295 entry.last_error = None;
296 entry.last_auth_challenge = None;
297 entry.mcp_session_id = session_id;
298 entry.tool_cache = cache;
299 entry.tools_fetched_at_ms = Some(now);
300 }
301 drop(servers);
302 self.persist_state().await;
303 Ok(self.server_tools(name).await)
304 }
305
306 pub async fn disconnect(&self, name: &str) -> bool {
307 if let Some(mut child) = self.processes.lock().await.remove(name) {
308 let _ = child.kill().await;
309 let _ = child.wait().await;
310 }
311 let mut servers = self.servers.write().await;
312 if let Some(server) = servers.get_mut(name) {
313 server.connected = false;
314 server.pid = None;
315 server.last_auth_challenge = None;
316 server.mcp_session_id = None;
317 drop(servers);
318 self.persist_state().await;
319 return true;
320 }
321 false
322 }
323
324 pub async fn list_tools(&self) -> Vec<McpRemoteTool> {
325 let mut out = self
326 .servers
327 .read()
328 .await
329 .values()
330 .filter(|server| server.enabled && server.connected)
331 .flat_map(server_tool_rows)
332 .collect::<Vec<_>>();
333 out.sort_by(|a, b| a.namespaced_name.cmp(&b.namespaced_name));
334 out
335 }
336
337 pub async fn server_tools(&self, name: &str) -> Vec<McpRemoteTool> {
338 let Some(server) = self.servers.read().await.get(name).cloned() else {
339 return Vec::new();
340 };
341 let mut rows = server_tool_rows(&server);
342 rows.sort_by(|a, b| a.namespaced_name.cmp(&b.namespaced_name));
343 rows
344 }
345
346 pub async fn call_tool(
347 &self,
348 server_name: &str,
349 tool_name: &str,
350 args: Value,
351 ) -> Result<ToolResult, String> {
352 let server = {
353 let servers = self.servers.read().await;
354 let Some(server) = servers.get(server_name) else {
355 return Err(format!("MCP server '{server_name}' not found"));
356 };
357 server.clone()
358 };
359
360 if !server.enabled {
361 return Err(format!("MCP server '{server_name}' is disabled"));
362 }
363 if !server.connected {
364 return Err(format!("MCP server '{server_name}' is not connected"));
365 }
366
367 let endpoint = parse_remote_endpoint(&server.transport).ok_or_else(|| {
368 "MCP tools/call currently supports HTTP/S transports only".to_string()
369 })?;
370 let normalized_args = normalize_mcp_tool_args(&server, tool_name, args);
371
372 let request = json!({
373 "jsonrpc": "2.0",
374 "id": format!("call-{}-{}", server_name, now_ms()),
375 "method": "tools/call",
376 "params": {
377 "name": tool_name,
378 "arguments": normalized_args
379 }
380 });
381 let (response, session_id) = post_json_rpc_with_session(
382 &endpoint,
383 &server.headers,
384 request,
385 server.mcp_session_id.as_deref(),
386 )
387 .await?;
388 if session_id.is_some() {
389 let mut servers = self.servers.write().await;
390 if let Some(row) = servers.get_mut(server_name) {
391 row.mcp_session_id = session_id;
392 }
393 drop(servers);
394 self.persist_state().await;
395 }
396
397 if let Some(err) = response.get("error") {
398 if let Some(challenge) = extract_auth_challenge(err, tool_name) {
399 let output = format!(
400 "{}\n\nAuthorize here: {}",
401 challenge.message, challenge.authorization_url
402 );
403 {
404 let mut servers = self.servers.write().await;
405 if let Some(row) = servers.get_mut(server_name) {
406 row.last_auth_challenge = Some(challenge.clone());
407 row.last_error = None;
408 }
409 }
410 self.persist_state().await;
411 return Ok(ToolResult {
412 output,
413 metadata: json!({
414 "server": server_name,
415 "tool": tool_name,
416 "result": Value::Null,
417 "mcpAuth": {
418 "required": true,
419 "challengeId": challenge.challenge_id,
420 "tool": challenge.tool_name,
421 "authorizationUrl": challenge.authorization_url,
422 "message": challenge.message,
423 "status": challenge.status
424 }
425 }),
426 });
427 }
428 let message = err
429 .get("message")
430 .and_then(|v| v.as_str())
431 .unwrap_or("MCP tools/call failed");
432 return Err(message.to_string());
433 }
434
435 let result = response.get("result").cloned().unwrap_or(Value::Null);
436 let auth_challenge = extract_auth_challenge(&result, tool_name);
437 let output = if let Some(challenge) = auth_challenge.as_ref() {
438 format!(
439 "{}\n\nAuthorize here: {}",
440 challenge.message, challenge.authorization_url
441 )
442 } else {
443 result
444 .get("content")
445 .map(render_mcp_content)
446 .or_else(|| result.get("output").map(|v| v.to_string()))
447 .unwrap_or_else(|| result.to_string())
448 };
449
450 {
451 let mut servers = self.servers.write().await;
452 if let Some(row) = servers.get_mut(server_name) {
453 row.last_auth_challenge = auth_challenge.clone();
454 }
455 }
456 self.persist_state().await;
457
458 let auth_metadata = auth_challenge.as_ref().map(|challenge| {
459 json!({
460 "required": true,
461 "challengeId": challenge.challenge_id,
462 "tool": challenge.tool_name,
463 "authorizationUrl": challenge.authorization_url,
464 "message": challenge.message,
465 "status": challenge.status
466 })
467 });
468
469 Ok(ToolResult {
470 output,
471 metadata: json!({
472 "server": server_name,
473 "tool": tool_name,
474 "result": result,
475 "mcpAuth": auth_metadata
476 }),
477 })
478 }
479
480 async fn connect_stdio(&self, name: &str, command_text: &str) -> bool {
481 match spawn_stdio_process(command_text).await {
482 Ok(child) => {
483 let pid = child.id();
484 self.processes.lock().await.insert(name.to_string(), child);
485 let mut servers = self.servers.write().await;
486 if let Some(server) = servers.get_mut(name) {
487 server.connected = true;
488 server.pid = pid;
489 server.last_error = None;
490 }
491 drop(servers);
492 self.persist_state().await;
493 true
494 }
495 Err(err) => {
496 let mut servers = self.servers.write().await;
497 if let Some(server) = servers.get_mut(name) {
498 server.connected = false;
499 server.pid = None;
500 server.last_error = Some(err);
501 }
502 drop(servers);
503 self.persist_state().await;
504 false
505 }
506 }
507 }
508
509 async fn discover_remote_tools(
510 &self,
511 endpoint: &str,
512 headers: &HashMap<String, String>,
513 ) -> Result<(Vec<McpRemoteTool>, Option<String>), String> {
514 let initialize = json!({
515 "jsonrpc": "2.0",
516 "id": "initialize-1",
517 "method": "initialize",
518 "params": {
519 "protocolVersion": MCP_PROTOCOL_VERSION,
520 "capabilities": {},
521 "clientInfo": {
522 "name": MCP_CLIENT_NAME,
523 "version": MCP_CLIENT_VERSION,
524 }
525 }
526 });
527 let (init_response, mut session_id) =
528 post_json_rpc_with_session(endpoint, headers, initialize, None).await?;
529 if let Some(err) = init_response.get("error") {
530 let message = err
531 .get("message")
532 .and_then(|v| v.as_str())
533 .unwrap_or("MCP initialize failed");
534 return Err(message.to_string());
535 }
536
537 let tools_list = json!({
538 "jsonrpc": "2.0",
539 "id": "tools-list-1",
540 "method": "tools/list",
541 "params": {}
542 });
543 let (tools_response, next_session_id) =
544 post_json_rpc_with_session(endpoint, headers, tools_list, session_id.as_deref())
545 .await?;
546 if next_session_id.is_some() {
547 session_id = next_session_id;
548 }
549 if let Some(err) = tools_response.get("error") {
550 let message = err
551 .get("message")
552 .and_then(|v| v.as_str())
553 .unwrap_or("MCP tools/list failed");
554 return Err(message.to_string());
555 }
556
557 let tools = tools_response
558 .get("result")
559 .and_then(|v| v.get("tools"))
560 .and_then(|v| v.as_array())
561 .ok_or_else(|| "MCP tools/list result missing tools array".to_string())?;
562
563 let now = now_ms();
564 let mut out = Vec::new();
565 for row in tools {
566 let Some(tool_name) = row.get("name").and_then(|v| v.as_str()) else {
567 continue;
568 };
569 let description = row
570 .get("description")
571 .and_then(|v| v.as_str())
572 .unwrap_or("")
573 .to_string();
574 let mut input_schema = row
575 .get("inputSchema")
576 .or_else(|| row.get("input_schema"))
577 .cloned()
578 .unwrap_or_else(|| json!({"type":"object"}));
579 normalize_tool_input_schema(&mut input_schema);
580 out.push(McpRemoteTool {
581 server_name: String::new(),
582 tool_name: tool_name.to_string(),
583 namespaced_name: String::new(),
584 description,
585 input_schema,
586 fetched_at_ms: now,
587 schema_hash: String::new(),
588 });
589 }
590
591 Ok((out, session_id))
592 }
593
594 async fn persist_state(&self) {
595 let snapshot = self.servers.read().await.clone();
596 if let Some(parent) = self.state_file.parent() {
597 let _ = tokio::fs::create_dir_all(parent).await;
598 }
599 if let Ok(payload) = serde_json::to_string_pretty(&snapshot) {
600 let _ = tokio::fs::write(self.state_file.as_path(), payload).await;
601 }
602 }
603}
604
605impl Default for McpRegistry {
606 fn default() -> Self {
607 Self::new()
608 }
609}
610
/// Serde default for `McpServer::enabled`: servers are enabled unless the
/// persisted entry says otherwise.
fn default_enabled() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
614
615fn resolve_state_file() -> PathBuf {
616 if let Ok(path) = std::env::var("TANDEM_MCP_REGISTRY") {
617 return PathBuf::from(path);
618 }
619 if let Ok(state_dir) = std::env::var("TANDEM_STATE_DIR") {
620 let trimmed = state_dir.trim();
621 if !trimmed.is_empty() {
622 return PathBuf::from(trimmed).join("mcp_servers.json");
623 }
624 }
625 if let Some(data_dir) = dirs::data_dir() {
626 return data_dir
627 .join("tandem")
628 .join("data")
629 .join("mcp_servers.json");
630 }
631 dirs::home_dir()
632 .map(|home| home.join(".tandem").join("data").join("mcp_servers.json"))
633 .unwrap_or_else(|| PathBuf::from("mcp_servers.json"))
634}
635
636fn load_state(path: &Path) -> HashMap<String, McpServer> {
637 let Ok(raw) = std::fs::read_to_string(path) else {
638 return HashMap::new();
639 };
640 serde_json::from_str::<HashMap<String, McpServer>>(&raw).unwrap_or_default()
641}
642
/// Extracts the command from a `stdio:<command>` transport descriptor.
///
/// Surrounding whitespace is ignored both around the whole descriptor and
/// around the command, matching how `parse_remote_endpoint` trims its input.
/// Returns `None` when the descriptor does not start with `stdio:`; the
/// returned command may be empty (`"stdio:"`).
fn parse_stdio_transport(transport: &str) -> Option<&str> {
    transport.trim().strip_prefix("stdio:").map(str::trim)
}
646
/// Extracts an HTTP/S endpoint URL from a transport descriptor.
///
/// Accepts either a bare `http://` / `https://` URL or such a URL prefixed
/// with `http:` or `https:` (for example `http:https://host/mcp`).
/// Whitespace around the descriptor and around the prefixed URL is ignored.
/// Returns `None` for anything else.
fn parse_remote_endpoint(transport: &str) -> Option<String> {
    fn is_http_url(text: &str) -> bool {
        text.starts_with("http://") || text.starts_with("https://")
    }

    let candidate = transport.trim();
    if is_http_url(candidate) {
        return Some(candidate.to_owned());
    }
    ["http:", "https:"]
        .iter()
        .filter_map(|prefix| candidate.strip_prefix(*prefix))
        .map(str::trim)
        .find(|rest| is_http_url(rest))
        .map(str::to_owned)
}
662
663fn server_tool_rows(server: &McpServer) -> Vec<McpRemoteTool> {
664 let server_slug = sanitize_namespace_segment(&server.name);
665 server
666 .tool_cache
667 .iter()
668 .map(|tool| {
669 let tool_slug = sanitize_namespace_segment(&tool.tool_name);
670 McpRemoteTool {
671 server_name: server.name.clone(),
672 tool_name: tool.tool_name.clone(),
673 namespaced_name: format!("mcp.{server_slug}.{tool_slug}"),
674 description: tool.description.clone(),
675 input_schema: tool.input_schema.clone(),
676 fetched_at_ms: tool.fetched_at_ms,
677 schema_hash: tool.schema_hash.clone(),
678 }
679 })
680 .collect()
681}
682
/// Turns an arbitrary name into a namespace-safe slug.
///
/// ASCII letters and digits are kept (lowercased); every run of other
/// characters becomes a single `_`; leading and trailing `_` are dropped.
/// A name with no ASCII alphanumerics yields `"tool"`.
fn sanitize_namespace_segment(raw: &str) -> String {
    // Splitting on non-alphanumerics and re-joining the non-empty pieces
    // collapses separator runs and strips them from both ends in one pass.
    let slug = raw
        .split(|ch: char| !ch.is_ascii_alphanumeric())
        .filter(|piece| !piece.is_empty())
        .map(str::to_ascii_lowercase)
        .collect::<Vec<_>>()
        .join("_");
    if slug.is_empty() {
        String::from("tool")
    } else {
        slug
    }
}
702
703fn schema_hash(schema: &Value) -> String {
704 let payload = serde_json::to_vec(schema).unwrap_or_default();
705 let mut hasher = Sha256::new();
706 hasher.update(payload);
707 format!("{:x}", hasher.finalize())
708}
709
710fn extract_auth_challenge(result: &Value, tool_name: &str) -> Option<McpAuthChallenge> {
711 let authorization_url = find_string_by_any_key(
712 result,
713 &["authorization_url", "authorizationUrl", "auth_url"],
714 )?;
715 let message = find_string_by_any_key(result, &["llm_instructions", "message", "text"])
716 .unwrap_or_else(|| "This tool requires authorization before it can run.".to_string());
717 let challenge_id = stable_id_seed(&format!("{tool_name}:{authorization_url}"));
718 Some(McpAuthChallenge {
719 challenge_id,
720 tool_name: tool_name.to_string(),
721 authorization_url,
722 message,
723 requested_at_ms: now_ms(),
724 status: "pending".to_string(),
725 })
726}
727
728fn find_string_by_any_key(value: &Value, keys: &[&str]) -> Option<String> {
729 match value {
730 Value::Object(map) => {
731 for key in keys {
732 if let Some(s) = map.get(*key).and_then(|v| v.as_str()) {
733 let trimmed = s.trim();
734 if !trimmed.is_empty() {
735 return Some(trimmed.to_string());
736 }
737 }
738 }
739 for child in map.values() {
740 if let Some(found) = find_string_by_any_key(child, keys) {
741 return Some(found);
742 }
743 }
744 None
745 }
746 Value::Array(items) => items
747 .iter()
748 .find_map(|item| find_string_by_any_key(item, keys)),
749 _ => None,
750 }
751}
752
753fn stable_id_seed(seed: &str) -> String {
754 let mut hasher = Sha256::new();
755 hasher.update(seed.as_bytes());
756 let encoded = format!("{:x}", hasher.finalize());
757 encoded.chars().take(16).collect()
758}
759
760fn normalize_tool_input_schema(schema: &mut Value) {
761 normalize_schema_node(schema);
762}
763
764fn normalize_schema_node(node: &mut Value) {
765 let Some(obj) = node.as_object_mut() else {
766 return;
767 };
768
769 if let Some(enum_values) = obj.get("enum").and_then(|v| v.as_array()) {
773 let all_strings = enum_values.iter().all(|v| v.is_string());
774 let string_like_type = schema_type_allows_string_enum(obj.get("type"));
775 if !all_strings || !string_like_type {
776 obj.remove("enum");
777 }
778 }
779
780 if let Some(properties) = obj.get_mut("properties").and_then(|v| v.as_object_mut()) {
781 for value in properties.values_mut() {
782 normalize_schema_node(value);
783 }
784 }
785
786 if let Some(items) = obj.get_mut("items") {
787 normalize_schema_node(items);
788 }
789
790 for key in ["anyOf", "oneOf", "allOf"] {
791 if let Some(array) = obj.get_mut(key).and_then(|v| v.as_array_mut()) {
792 for child in array.iter_mut() {
793 normalize_schema_node(child);
794 }
795 }
796 }
797
798 if let Some(additional) = obj.get_mut("additionalProperties") {
799 normalize_schema_node(additional);
800 }
801}
802
803fn schema_type_allows_string_enum(schema_type: Option<&Value>) -> bool {
804 let Some(schema_type) = schema_type else {
805 return true;
807 };
808
809 if let Some(kind) = schema_type.as_str() {
810 return kind == "string";
811 }
812
813 if let Some(kinds) = schema_type.as_array() {
814 let mut saw_string = false;
815 for kind in kinds {
816 let Some(kind) = kind.as_str() else {
817 return false;
818 };
819 if kind == "string" {
820 saw_string = true;
821 continue;
822 }
823 if kind != "null" {
824 return false;
825 }
826 }
827 return saw_string;
828 }
829
830 false
831}
832
/// Current wall-clock time in milliseconds since the Unix epoch, or `0` when
/// the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
839
840fn build_headers(headers: &HashMap<String, String>) -> Result<HeaderMap, String> {
841 let mut map = HeaderMap::new();
842 map.insert(
843 ACCEPT,
844 HeaderValue::from_static("application/json, text/event-stream"),
845 );
846 map.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
847 for (key, value) in headers {
848 let name = HeaderName::from_bytes(key.trim().as_bytes())
849 .map_err(|e| format!("Invalid header name '{key}': {e}"))?;
850 let header = HeaderValue::from_str(value.trim())
851 .map_err(|e| format!("Invalid header value for '{key}': {e}"))?;
852 map.insert(name, header);
853 }
854 Ok(map)
855}
856
857async fn post_json_rpc_with_session(
858 endpoint: &str,
859 headers: &HashMap<String, String>,
860 request: Value,
861 session_id: Option<&str>,
862) -> Result<(Value, Option<String>), String> {
863 let client = reqwest::Client::builder()
864 .timeout(std::time::Duration::from_secs(12))
865 .build()
866 .map_err(|e| format!("Failed to build HTTP client: {e}"))?;
867 let mut req = client.post(endpoint).headers(build_headers(headers)?);
868 if let Some(id) = session_id {
869 let trimmed = id.trim();
870 if !trimmed.is_empty() {
871 req = req.header("Mcp-Session-Id", trimmed);
872 }
873 }
874 let response = req
875 .json(&request)
876 .send()
877 .await
878 .map_err(|e| format!("MCP request failed: {e}"))?;
879 let response_session_id = response
880 .headers()
881 .get("mcp-session-id")
882 .and_then(|v| v.to_str().ok())
883 .map(|v| v.trim().to_string())
884 .filter(|v| !v.is_empty());
885 let status = response.status();
886 let payload = response
887 .text()
888 .await
889 .map_err(|e| format!("Failed to read MCP response: {e}"))?;
890 if !status.is_success() {
891 return Err(format!(
892 "MCP endpoint returned HTTP {}: {}",
893 status.as_u16(),
894 payload.chars().take(400).collect::<String>()
895 ));
896 }
897 let value = serde_json::from_str::<Value>(&payload)
898 .map_err(|e| format!("Invalid MCP JSON response: {e}"))?;
899 Ok((value, response_session_id))
900}
901
902fn render_mcp_content(value: &Value) -> String {
903 let Some(items) = value.as_array() else {
904 return value.to_string();
905 };
906 let mut chunks = Vec::new();
907 for item in items {
908 if let Some(text) = item.get("text").and_then(|v| v.as_str()) {
909 chunks.push(text.to_string());
910 continue;
911 }
912 chunks.push(item.to_string());
913 }
914 if chunks.is_empty() {
915 value.to_string()
916 } else {
917 chunks.join("\n")
918 }
919}
920
921fn normalize_mcp_tool_args(server: &McpServer, tool_name: &str, raw_args: Value) -> Value {
922 let Some(schema) = server
923 .tool_cache
924 .iter()
925 .find(|row| row.tool_name.eq_ignore_ascii_case(tool_name))
926 .map(|row| &row.input_schema)
927 else {
928 return raw_args;
929 };
930
931 let mut args_obj = match raw_args {
932 Value::Object(obj) => obj,
933 other => return other,
934 };
935
936 let properties = schema
937 .get("properties")
938 .and_then(|v| v.as_object())
939 .cloned()
940 .unwrap_or_default();
941 if properties.is_empty() {
942 return Value::Object(args_obj);
943 }
944
945 let mut normalized_existing: HashMap<String, String> = HashMap::new();
947 for key in args_obj.keys() {
948 normalized_existing.insert(normalize_arg_key(key), key.clone());
949 }
950
951 let canonical_keys = properties.keys().cloned().collect::<Vec<_>>();
953 for canonical in &canonical_keys {
954 if args_obj.contains_key(canonical) {
955 continue;
956 }
957 if let Some(existing_key) = normalized_existing.get(&normalize_arg_key(canonical)) {
958 if let Some(value) = args_obj.get(existing_key).cloned() {
959 args_obj.insert(canonical.clone(), value);
960 }
961 }
962 }
963
964 let required = schema
966 .get("required")
967 .and_then(|v| v.as_array())
968 .map(|arr| {
969 arr.iter()
970 .filter_map(|v| v.as_str().map(str::to_string))
971 .collect::<Vec<_>>()
972 })
973 .unwrap_or_default();
974
975 for required_key in required {
976 if args_obj.contains_key(&required_key) {
977 continue;
978 }
979 if let Some(alias_value) = find_required_alias_value(&required_key, &args_obj) {
980 args_obj.insert(required_key, alias_value);
981 }
982 }
983
984 Value::Object(args_obj)
985}
986
987fn find_required_alias_value(
988 required_key: &str,
989 args_obj: &serde_json::Map<String, Value>,
990) -> Option<Value> {
991 let mut alias_candidates = vec![
992 required_key.to_string(),
993 required_key.to_ascii_lowercase(),
994 required_key.replace('_', ""),
995 ];
996
997 if required_key.contains("title") {
999 alias_candidates.extend([
1000 "name".to_string(),
1001 "title".to_string(),
1002 "task_name".to_string(),
1003 "taskname".to_string(),
1004 ]);
1005 }
1006
1007 if let Some(base) = required_key.strip_suffix("_id") {
1009 alias_candidates.extend([base.to_string(), format!("{base}id"), format!("{base}_id")]);
1010 }
1011
1012 let mut by_normalized: HashMap<String, &Value> = HashMap::new();
1013 for (key, value) in args_obj {
1014 by_normalized.insert(normalize_arg_key(key), value);
1015 }
1016
1017 alias_candidates
1018 .into_iter()
1019 .find_map(|candidate| by_normalized.get(&normalize_arg_key(&candidate)).cloned())
1020 .cloned()
1021}
1022
/// Reduces an argument key to its lowercase ASCII letters and digits, so
/// that `task_title`, `taskTitle` and `task-title` all compare equal.
fn normalize_arg_key(key: &str) -> String {
    let mut normalized = String::with_capacity(key.len());
    for ch in key.chars() {
        if ch.is_ascii_alphanumeric() {
            normalized.push(ch.to_ascii_lowercase());
        }
    }
    normalized
}
1029
1030async fn spawn_stdio_process(command_text: &str) -> Result<Child, String> {
1031 if command_text.is_empty() {
1032 return Err("Missing stdio command".to_string());
1033 }
1034 #[cfg(windows)]
1035 let mut command = {
1036 let mut cmd = Command::new("powershell");
1037 cmd.args(["-NoProfile", "-Command", command_text]);
1038 cmd
1039 };
1040 #[cfg(not(windows))]
1041 let mut command = {
1042 let mut cmd = Command::new("sh");
1043 cmd.args(["-lc", command_text]);
1044 cmd
1045 };
1046 command
1047 .stdin(std::process::Stdio::null())
1048 .stdout(std::process::Stdio::null())
1049 .stderr(std::process::Stdio::null());
1050 command.spawn().map_err(|e| e.to_string())
1051}
1052
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// A transport that is neither stdio nor HTTP/S is simply flagged as
    /// connected, and can then be disconnected.
    #[tokio::test]
    async fn add_connect_disconnect_non_stdio_server() {
        let file = std::env::temp_dir().join(format!("mcp-test-{}.json", Uuid::new_v4()));
        let registry = McpRegistry::new_with_state_file(file.clone());
        registry
            .add("example".to_string(), "sse:https://example.com".to_string())
            .await;
        assert!(registry.connect("example").await);
        let listed = registry.list().await;
        assert!(listed.get("example").map(|s| s.connected).unwrap_or(false));
        assert!(registry.disconnect("example").await);
        // The registry persists to `file` on every change; do not leave it
        // behind in the temp directory.
        let _ = std::fs::remove_file(&file);
    }

    /// Bare URLs and `http:`-prefixed URLs both resolve to the endpoint.
    #[test]
    fn parse_remote_endpoint_supports_http_prefixes() {
        assert_eq!(
            parse_remote_endpoint("https://mcp.example.com/mcp"),
            Some("https://mcp.example.com/mcp".to_string())
        );
        assert_eq!(
            parse_remote_endpoint("http:https://mcp.example.com/mcp"),
            Some("https://mcp.example.com/mcp".to_string())
        );
    }

    /// Only all-string enums on string-compatible types survive
    /// normalization, at every nesting level.
    #[test]
    fn normalize_schema_removes_non_string_enums_recursively() {
        let mut schema = json!({
            "type": "object",
            "properties": {
                "good": { "type": "string", "enum": ["a", "b"] },
                "good_nullable": { "type": ["string", "null"], "enum": ["asc", "desc"] },
                "bad_object": { "type": "object", "enum": ["asc", "desc"] },
                "bad_array": { "type": "array", "enum": ["asc", "desc"] },
                "bad_number": { "type": "number", "enum": [1, 2] },
                "bad_mixed": { "enum": ["ok", 1] },
                "nested": {
                    "type": "object",
                    "properties": {
                        "child": { "enum": [true, false] }
                    }
                }
            }
        });

        normalize_tool_input_schema(&mut schema);

        assert!(
            schema["properties"]["good"]["enum"].is_array(),
            "string enums should be preserved"
        );
        assert!(
            schema["properties"]["good_nullable"]["enum"].is_array(),
            "string|null enums should be preserved"
        );
        assert!(
            schema["properties"]["bad_object"]["enum"].is_null(),
            "object enums should be dropped"
        );
        assert!(
            schema["properties"]["bad_array"]["enum"].is_null(),
            "array enums should be dropped"
        );
        assert!(
            schema["properties"]["bad_number"]["enum"].is_null(),
            "non-string enums should be dropped"
        );
        assert!(
            schema["properties"]["bad_mixed"]["enum"].is_null(),
            "mixed enums should be dropped"
        );
        assert!(
            schema["properties"]["nested"]["properties"]["child"]["enum"].is_null(),
            "recursive non-string enums should be dropped"
        );
    }

    /// An authorization URL nested inside `content` is detected.
    #[test]
    fn extract_auth_challenge_from_result_payload() {
        let payload = json!({
            "content": [
                {
                    "type": "text",
                    "llm_instructions": "Authorize Gmail access first.",
                    "authorization_url": "https://example.com/oauth/start"
                }
            ]
        });
        let challenge = extract_auth_challenge(&payload, "gmail_whoami")
            .expect("auth challenge should be detected");
        assert_eq!(challenge.tool_name, "gmail_whoami");
        assert_eq!(
            challenge.authorization_url,
            "https://example.com/oauth/start"
        );
        assert_eq!(challenge.status, "pending");
    }

    /// Without an authorization URL there is no challenge.
    #[test]
    fn extract_auth_challenge_returns_none_without_url() {
        let payload = json!({
            "content": [
                {"type":"text","text":"No authorization needed"}
            ]
        });
        assert!(extract_auth_challenge(&payload, "gmail_whoami").is_none());
    }

    /// `listId` maps to `list_id` by key normalization and `name` maps to
    /// `task_title` through the title alias list.
    #[test]
    fn normalize_mcp_tool_args_maps_clickup_aliases() {
        let server = McpServer {
            name: "arcade".to_string(),
            transport: "https://example.com/mcp".to_string(),
            enabled: true,
            connected: true,
            pid: None,
            last_error: None,
            last_auth_challenge: None,
            mcp_session_id: None,
            headers: HashMap::new(),
            tool_cache: vec![McpToolCacheEntry {
                tool_name: "Clickup_CreateTask".to_string(),
                description: "Create task".to_string(),
                input_schema: json!({
                    "type":"object",
                    "properties":{
                        "list_id":{"type":"string"},
                        "task_title":{"type":"string"}
                    },
                    "required":["list_id","task_title"]
                }),
                fetched_at_ms: 0,
                schema_hash: "x".to_string(),
            }],
            tools_fetched_at_ms: None,
        };

        let normalized = normalize_mcp_tool_args(
            &server,
            "Clickup_CreateTask",
            json!({
                "listId": "123",
                "name": "Prep fish"
            }),
        );
        assert_eq!(
            normalized.get("list_id").and_then(|v| v.as_str()),
            Some("123")
        );
        assert_eq!(
            normalized.get("task_title").and_then(|v| v.as_str()),
            Some("Prep fish")
        );
    }

    /// Case and separator differences vanish under key normalization.
    #[test]
    fn normalize_arg_key_ignores_case_and_separators() {
        assert_eq!(normalize_arg_key("task_title"), "tasktitle");
        assert_eq!(normalize_arg_key("taskTitle"), "tasktitle");
        assert_eq!(normalize_arg_key("task-title"), "tasktitle");
    }
}