// ati/core/mcp_client.rs

//! MCP client — connects to MCP servers via stdio or Streamable HTTP transport.
//!
//! Implements the MCP protocol (2025-03-26 revision):
//! - JSON-RPC 2.0 message framing
//! - stdio transport: newline-delimited JSON over stdin/stdout
//! - Streamable HTTP transport: POST with `Accept: application/json, text/event-stream`.
//!   The server may respond with JSON or an SSE stream. Supports `Mcp-Session-Id` for sessions.
//! - Lifecycle: initialize → tools/list → tools/call → shutdown
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::io::{BufRead, BufReader, Write as IoWrite};
13use std::process::{Child, Command, Stdio};
14use std::sync::atomic::{AtomicU64, Ordering};
15use thiserror::Error;
16use tokio::sync::Mutex;
17
18use crate::core::auth_generator::{self, AuthCache, GenContext};
19use crate::core::keyring::Keyring;
20use crate::core::manifest::Provider;

// ---------------------------------------------------------------------------
// Errors
// ---------------------------------------------------------------------------

26#[derive(Error, Debug)]
27#[allow(dead_code)]
28pub enum McpError {
29    #[error("MCP transport error: {0}")]
30    Transport(String),
31    #[error("MCP protocol error (code {code}): {message}")]
32    Protocol { code: i64, message: String },
33    #[error("MCP server did not return tools capability")]
34    NoToolsCapability,
35    #[error("IO error: {0}")]
36    Io(#[from] std::io::Error),
37    #[error("JSON error: {0}")]
38    Json(#[from] serde_json::Error),
39    #[error("HTTP error: {0}")]
40    Http(#[from] reqwest::Error),
41    #[error("MCP initialization failed: {0}")]
42    InitFailed(String),
43    #[error("SSE parse error: {0}")]
44    SseParse(String),
45    #[error("MCP server process exited unexpectedly")]
46    ProcessExited,
47    #[error("Missing MCP configuration: {0}")]
48    Config(String),
49}

// ---------------------------------------------------------------------------
// JSON-RPC types
// ---------------------------------------------------------------------------

55#[derive(Debug, Serialize)]
56struct JsonRpcRequest {
57    jsonrpc: &'static str,
58    id: u64,
59    method: String,
60    #[serde(skip_serializing_if = "Option::is_none")]
61    params: Option<Value>,
62}

// Note: JSON-RPC responses are parsed manually via `serde_json::Value` rather
// than through typed deserialization, because a response can arrive interleaved
// with notifications and server requests (on stdio and in SSE streams), and an
// SSE event may carry a batch (JSON array) of messages.

// ---------------------------------------------------------------------------
// MCP protocol types
// ---------------------------------------------------------------------------

72/// Tool definition from MCP tools/list response.
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct McpToolDef {
75    pub name: String,
76    #[serde(default)]
77    pub description: Option<String>,
78    #[serde(default, rename = "inputSchema")]
79    pub input_schema: Option<Value>,
80}
81
82/// Content item from MCP tools/call response.
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct McpContent {
85    #[serde(rename = "type")]
86    pub content_type: String,
87    #[serde(default)]
88    pub text: Option<String>,
89    #[serde(default)]
90    pub data: Option<String>,
91    #[serde(default, rename = "mimeType")]
92    pub mime_type: Option<String>,
93}
94
95/// Result from tools/call.
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct McpToolResult {
98    pub content: Vec<McpContent>,
99    #[serde(default, rename = "isError")]
100    pub is_error: bool,
101}
102
// ---------------------------------------------------------------------------
// Transport abstraction
// ---------------------------------------------------------------------------

107/// Internal transport enum — stdio or HTTP.
108enum Transport {
109    Stdio(StdioTransport),
110    Http(HttpTransport),
111}
112
/// State for the stdio transport: the spawned server process and its pipes.
struct StdioTransport {
    /// Handle to the server subprocess.
    child: Child,
    /// Write end for outgoing JSON-RPC messages. Wrapped in `Option` so that
    /// disconnecting can `take()` and drop it, closing the pipe (EOF for the child).
    stdin: Option<std::process::ChildStdin>,
    /// Buffered read end; the server writes one JSON message per line.
    reader: BufReader<std::process::ChildStdout>,
}

122impl Drop for StdioTransport {
123    fn drop(&mut self) {
124        // Best-effort cleanup: kill the subprocess if it wasn't explicitly disconnected.
125        // Prevents orphan zombie processes when a future is cancelled (e.g., on timeout).
126        let _ = self.child.kill();
127        let _ = self.child.wait();
128    }
129}
130
131/// Streamable HTTP transport: POST to MCP endpoint.
132struct HttpTransport {
133    client: reqwest::Client,
134    url: String,
135    /// Session ID from Mcp-Session-Id header (set after initialize).
136    session_id: Option<String>,
137    /// Auth header name (default: "Authorization"). Custom for APIs using e.g. "x-api-key".
138    auth_header_name: String,
139    /// Auth header value (e.g., "Bearer <token>") injected on every request.
140    auth_header: Option<String>,
141    /// Extra headers from provider config.
142    extra_headers: HashMap<String, String>,
143}
144
// ---------------------------------------------------------------------------
// McpClient
// ---------------------------------------------------------------------------

149/// MCP client that connects to a single MCP server.
150pub struct McpClient {
151    transport: Mutex<Transport>,
152    next_id: AtomicU64,
153    /// Cached tools from tools/list.
154    cached_tools: Mutex<Option<Vec<McpToolDef>>>,
155    /// Provider name (for logging).
156    provider_name: String,
157}
158
159impl McpClient {
160    /// Connect to an MCP server based on the provider's configuration.
161    ///
162    /// For stdio: spawns the subprocess with env vars resolved from keyring.
163    /// For HTTP: creates an HTTP client with auth headers.
164    pub async fn connect(provider: &Provider, keyring: &Keyring) -> Result<Self, McpError> {
165        Self::connect_with_gen(provider, keyring, None, None, None).await
166    }
167
168    /// Connect to an MCP server, optionally using a dynamic auth generator.
169    ///
170    /// `override_mcp_url`: when `Some`, overrides the provider's static
171    /// `mcp_url` for this connection. Used by the proxy to honour a
172    /// sandbox-supplied `X-Ati-Upstream-Url` after validating it against the
173    /// per-provider allowlist (issue #124). Ignored on stdio transport
174    /// (which has no URL — guarded at manifest load time).
175    pub async fn connect_with_gen(
176        provider: &Provider,
177        keyring: &Keyring,
178        gen_ctx: Option<&GenContext>,
179        auth_cache: Option<&AuthCache>,
180        override_mcp_url: Option<&str>,
181    ) -> Result<Self, McpError> {
182        let transport = match provider.mcp_transport_type() {
183            "stdio" => {
184                let command = provider.mcp_command.as_deref().ok_or_else(|| {
185                    McpError::Config("mcp_command required for stdio transport".into())
186                })?;
187
188                // Resolve env vars: "${key_name}" → keyring value
189                let mut env_map: HashMap<String, String> = HashMap::new();
190                // Selectively pass through essential env vars (don't leak secrets)
191                if let Ok(path) = std::env::var("PATH") {
192                    env_map.insert("PATH".to_string(), path);
193                }
194                if let Ok(home) = std::env::var("HOME") {
195                    env_map.insert("HOME".to_string(), home);
196                }
197                // Add provider-specific env vars (resolved from keyring)
198                for (k, v) in &provider.mcp_env {
199                    let resolved = resolve_env_value(v, keyring);
200                    env_map.insert(k.clone(), resolved);
201                }
202
203                // If auth_generator is configured, run it and inject into env
204                if let Some(gen) = &provider.auth_generator {
205                    let default_ctx = GenContext::default();
206                    let ctx = gen_ctx.unwrap_or(&default_ctx);
207                    let default_cache = AuthCache::new();
208                    let cache = auth_cache.unwrap_or(&default_cache);
209                    match auth_generator::generate(provider, gen, ctx, keyring, cache).await {
210                        Ok(cred) => {
211                            env_map.insert("ATI_AUTH_TOKEN".to_string(), cred.value);
212                            for (k, v) in &cred.extra_env {
213                                env_map.insert(k.clone(), v.clone());
214                            }
215                        }
216                        Err(e) => {
217                            return Err(McpError::Config(format!("auth_generator failed: {e}")));
218                        }
219                    }
220                }
221
222                let mut child = Command::new(command)
223                    .args(&provider.mcp_args)
224                    .stdin(Stdio::piped())
225                    .stdout(Stdio::piped())
226                    .stderr(Stdio::piped())
227                    .env_clear()
228                    .envs(&env_map)
229                    .spawn()
230                    .map_err(|e| {
231                        McpError::Transport(format!("Failed to spawn MCP server '{command}': {e}"))
232                    })?;
233
234                let stdin = child
235                    .stdin
236                    .take()
237                    .ok_or_else(|| McpError::Transport("No stdin".into()))?;
238                let stdout = child
239                    .stdout
240                    .take()
241                    .ok_or_else(|| McpError::Transport("No stdout".into()))?;
242                let reader = BufReader::new(stdout);
243
244                Transport::Stdio(StdioTransport {
245                    child,
246                    stdin: Some(stdin),
247                    reader,
248                })
249            }
250            "http" => {
251                // Pre-validated sandbox-supplied URL wins over the static
252                // manifest field. The proxy has already glob-matched the
253                // override against the operator's allowlist before calling
254                // us; we trust the override here (issue #124).
255                let url = override_mcp_url
256                    .or(provider.mcp_url.as_deref())
257                    .ok_or_else(|| {
258                        McpError::Config("mcp_url required for HTTP transport".into())
259                    })?;
260
261                // Build auth header: generator takes priority over static keyring
262                let auth_header = if let Some(gen) = &provider.auth_generator {
263                    let default_ctx = GenContext::default();
264                    let ctx = gen_ctx.unwrap_or(&default_ctx);
265                    let default_cache = AuthCache::new();
266                    let cache = auth_cache.unwrap_or(&default_cache);
267                    match auth_generator::generate(provider, gen, ctx, keyring, cache).await {
268                        Ok(cred) => match &provider.auth_type {
269                            super::manifest::AuthType::Bearer => {
270                                Some(format!("Bearer {}", cred.value))
271                            }
272                            super::manifest::AuthType::Header => {
273                                if let Some(prefix) = &provider.auth_value_prefix {
274                                    Some(format!("{prefix}{}", cred.value))
275                                } else {
276                                    Some(cred.value)
277                                }
278                            }
279                            _ => Some(cred.value),
280                        },
281                        Err(e) => {
282                            return Err(McpError::Config(format!("auth_generator failed: {e}")));
283                        }
284                    }
285                } else {
286                    build_auth_header(provider, keyring)
287                };
288
289                let client = reqwest::Client::builder()
290                    .timeout(std::time::Duration::from_secs(300))
291                    .build()?;
292
293                // Resolve ${key_name} placeholders in the URL from keyring
294                let resolved_url = resolve_env_value(url, keyring);
295
296                let auth_header_name = provider
297                    .auth_header_name
298                    .clone()
299                    .unwrap_or_else(|| "Authorization".to_string());
300
301                Transport::Http(HttpTransport {
302                    client,
303                    url: resolved_url,
304                    session_id: None,
305                    auth_header_name,
306                    auth_header,
307                    extra_headers: provider.extra_headers.clone(),
308                })
309            }
310            other => {
311                return Err(McpError::Config(format!(
312                    "Unknown MCP transport: '{other}' (expected 'stdio' or 'http')"
313                )));
314            }
315        };
316
317        let client = McpClient {
318            transport: Mutex::new(transport),
319            next_id: AtomicU64::new(1),
320            cached_tools: Mutex::new(None),
321            provider_name: provider.name.clone(),
322        };
323
324        // Perform MCP initialize handshake
325        client.initialize().await?;
326
327        Ok(client)
328    }
329
330    /// Perform the MCP initialize handshake.
331    async fn initialize(&self) -> Result<(), McpError> {
332        let params = serde_json::json!({
333            "protocolVersion": "2025-03-26",
334            "capabilities": {},
335            "clientInfo": {
336                "name": "ati",
337                "version": env!("CARGO_PKG_VERSION")
338            }
339        });
340
341        let response = self.send_request("initialize", Some(params)).await?;
342
343        // Verify server has tools capability
344        let capabilities = response.get("capabilities").unwrap_or(&Value::Null);
345        if capabilities.get("tools").is_none() {
346            return Err(McpError::NoToolsCapability);
347        }
348
349        // Extract session ID from HTTP transport response (handled inside send_request)
350
351        // Send initialized notification
352        self.send_notification("notifications/initialized", None)
353            .await?;
354
355        Ok(())
356    }
357
358    /// Discover tools via tools/list. Results are cached.
359    pub async fn list_tools(&self) -> Result<Vec<McpToolDef>, McpError> {
360        // Return cached if available
361        {
362            let cache = self.cached_tools.lock().await;
363            if let Some(tools) = cache.as_ref() {
364                return Ok(tools.clone());
365            }
366        }
367
368        let mut all_tools = Vec::new();
369        let mut cursor: Option<String> = None;
370        const MAX_PAGES: usize = 100;
371        const MAX_TOOLS: usize = 10_000;
372
373        for _page in 0..MAX_PAGES {
374            let params = cursor.as_ref().map(|c| serde_json::json!({"cursor": c}));
375            let result = self.send_request("tools/list", params).await?;
376
377            if let Some(tools_val) = result.get("tools") {
378                let tools: Vec<McpToolDef> = serde_json::from_value(tools_val.clone())?;
379                all_tools.extend(tools);
380            }
381
382            // Safety: cap total tools to prevent memory exhaustion
383            if all_tools.len() > MAX_TOOLS {
384                tracing::warn!(max = MAX_TOOLS, "MCP tool count exceeds limit, truncating");
385                all_tools.truncate(MAX_TOOLS);
386                break;
387            }
388
389            // Check for pagination
390            match result.get("nextCursor").and_then(|v| v.as_str()) {
391                Some(next) => cursor = Some(next.to_string()),
392                None => break,
393            }
394        }
395
396        // Cache the result
397        {
398            let mut cache = self.cached_tools.lock().await;
399            *cache = Some(all_tools.clone());
400        }
401
402        Ok(all_tools)
403    }
404
405    /// Execute a tool via tools/call.
406    pub async fn call_tool(
407        &self,
408        name: &str,
409        arguments: HashMap<String, Value>,
410    ) -> Result<McpToolResult, McpError> {
411        let params = serde_json::json!({
412            "name": name,
413            "arguments": arguments,
414        });
415
416        let result = self.send_request("tools/call", Some(params)).await?;
417        let tool_result: McpToolResult = serde_json::from_value(result)?;
418        Ok(tool_result)
419    }
420
421    /// Disconnect from the MCP server.
422    pub async fn disconnect(&self) {
423        let mut transport = self.transport.lock().await;
424        match &mut *transport {
425            Transport::Stdio(stdio) => {
426                // Take ownership of stdin to drop it (signals EOF to child).
427                // After this, stdin is None and the child should exit.
428                let _ = stdio.stdin.take();
429                // Try graceful shutdown, then kill
430                let _ = stdio.child.kill();
431                let _ = stdio.child.wait();
432            }
433            Transport::Http(http) => {
434                // Send HTTP DELETE to terminate session if we have a session ID
435                if let Some(session_id) = &http.session_id {
436                    let mut req = http.client.delete(&http.url);
437                    req = req.header("Mcp-Session-Id", session_id.as_str());
438                    let _ = req.send().await;
439                }
440            }
441        }
442    }
443
444    /// Invalidate cached tools (e.g., after tools/list_changed notification).
445    pub async fn invalidate_cache(&self) {
446        let mut cache = self.cached_tools.lock().await;
447        *cache = None;
448    }
449
450    // -----------------------------------------------------------------------
451    // Internal: send JSON-RPC request and receive response
452    // -----------------------------------------------------------------------
453
454    async fn send_request(&self, method: &str, params: Option<Value>) -> Result<Value, McpError> {
455        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
456        let request = JsonRpcRequest {
457            jsonrpc: "2.0",
458            id,
459            method: method.to_string(),
460            params,
461        };
462
463        let mut transport = self.transport.lock().await;
464        match &mut *transport {
465            Transport::Stdio(stdio) => send_stdio_request(stdio, &request).await,
466            Transport::Http(http) => send_http_request(http, &request, &self.provider_name).await,
467        }
468    }
469
470    async fn send_notification(&self, method: &str, params: Option<Value>) -> Result<(), McpError> {
471        let mut notification = serde_json::json!({
472            "jsonrpc": "2.0",
473            "method": method,
474        });
475        if let Some(p) = params {
476            notification["params"] = p;
477        }
478
479        let mut transport = self.transport.lock().await;
480        match &mut *transport {
481            Transport::Stdio(stdio) => {
482                let stdin = stdio
483                    .stdin
484                    .as_mut()
485                    .ok_or_else(|| McpError::Transport("stdin closed".into()))?;
486                let msg = serde_json::to_string(&notification)?;
487                stdin.write_all(msg.as_bytes())?;
488                stdin.write_all(b"\n")?;
489                stdin.flush()?;
490                Ok(())
491            }
492            Transport::Http(http) => {
493                let mut req = http
494                    .client
495                    .post(&http.url)
496                    .header("Content-Type", "application/json")
497                    .header("Accept", "application/json, text/event-stream")
498                    .json(&notification);
499
500                if let Some(session_id) = &http.session_id {
501                    req = req.header("Mcp-Session-Id", session_id.as_str());
502                }
503                if let Some(auth) = &http.auth_header {
504                    req = req.header(http.auth_header_name.as_str(), auth.as_str());
505                }
506                for (name, value) in &http.extra_headers {
507                    req = req.header(name.as_str(), value.as_str());
508                }
509
510                let resp = req.send().await?;
511                // Notifications should get 202 Accepted
512                if !resp.status().is_success() {
513                    let status = resp.status().as_u16();
514                    let body = resp.text().await.unwrap_or_default();
515                    return Err(McpError::Transport(format!("HTTP {status}: {body}")));
516                }
517                Ok(())
518            }
519        }
520    }
521}
522
// ---------------------------------------------------------------------------
// Stdio transport I/O
// ---------------------------------------------------------------------------

527/// Send a JSON-RPC request over stdio and read the response.
528/// Messages are newline-delimited JSON (no embedded newlines).
529async fn send_stdio_request(
530    stdio: &mut StdioTransport,
531    request: &JsonRpcRequest,
532) -> Result<Value, McpError> {
533    let stdin = stdio
534        .stdin
535        .as_mut()
536        .ok_or_else(|| McpError::Transport("stdin closed".into()))?;
537
538    // Serialize and send (newline-delimited)
539    let msg = serde_json::to_string(request)?;
540    stdin.write_all(msg.as_bytes())?;
541    stdin.write_all(b"\n")?;
542    stdin.flush()?;
543
544    let request_id = request.id;
545
546    // Read lines until we get a response matching our request ID.
547    // We may receive notifications interleaved — skip them.
548    loop {
549        let mut line = String::new();
550        let bytes_read = stdio.reader.read_line(&mut line)?;
551        if bytes_read == 0 {
552            return Err(McpError::ProcessExited);
553        }
554
555        let line = line.trim();
556        if line.is_empty() {
557            continue;
558        }
559
560        // Try to parse as JSON-RPC response
561        let parsed: Value = serde_json::from_str(line)?;
562
563        // Check if it's a response (has "id" field matching ours)
564        if let Some(id) = parsed.get("id") {
565            let id_matches = match id {
566                Value::Number(n) => n.as_u64() == Some(request_id),
567                _ => false,
568            };
569
570            if id_matches {
571                // It's our response
572                if let Some(err) = parsed.get("error") {
573                    let code = err.get("code").and_then(|c| c.as_i64()).unwrap_or(-1);
574                    let message = err
575                        .get("message")
576                        .and_then(|m| m.as_str())
577                        .unwrap_or("Unknown error");
578                    return Err(McpError::Protocol {
579                        code,
580                        message: message.to_string(),
581                    });
582                }
583
584                return parsed
585                    .get("result")
586                    .cloned()
587                    .ok_or_else(|| McpError::Protocol {
588                        code: -1,
589                        message: "Response missing 'result' field".into(),
590                    });
591            }
592        }
593
594        // Not our response — it's a notification or someone else's response; skip it.
595    }
596}
597
// ---------------------------------------------------------------------------
// HTTP Streamable transport I/O
// ---------------------------------------------------------------------------

602/// Send a JSON-RPC request over Streamable HTTP.
603///
604/// Per MCP spec (2025-03-26):
605/// - POST with Accept: application/json, text/event-stream
606/// - Server may respond with Content-Type: application/json (single response)
607///   or Content-Type: text/event-stream (SSE stream with one or more messages)
608/// - Must handle Mcp-Session-Id header for session management
609async fn send_http_request(
610    http: &mut HttpTransport,
611    request: &JsonRpcRequest,
612    provider_name: &str,
613) -> Result<Value, McpError> {
614    let mut req = http
615        .client
616        .post(&http.url)
617        .header("Content-Type", "application/json")
618        .header("Accept", "application/json, text/event-stream")
619        .json(request);
620
621    // Attach session ID if we have one
622    if let Some(session_id) = &http.session_id {
623        req = req.header("Mcp-Session-Id", session_id.as_str());
624    }
625
626    // Inject auth (using custom header name if configured, e.g. "x-api-key")
627    if let Some(auth) = &http.auth_header {
628        req = req.header(http.auth_header_name.as_str(), auth.as_str());
629    }
630
631    // Inject extra headers from provider config
632    for (name, value) in &http.extra_headers {
633        req = req.header(name.as_str(), value.as_str());
634    }
635
636    let response = req
637        .send()
638        .await
639        .map_err(|e| McpError::Transport(format!("[{provider_name}] HTTP request failed: {e}")))?;
640
641    // Capture session ID from response header (usually set during initialize)
642    if let Some(session_val) = response.headers().get("mcp-session-id") {
643        if let Ok(sid) = session_val.to_str() {
644            http.session_id = Some(sid.to_string());
645        }
646    }
647
648    let status = response.status();
649    if !status.is_success() {
650        let body = response.text().await.unwrap_or_default();
651        return Err(McpError::Transport(format!(
652            "[{provider_name}] HTTP {}: {body}",
653            status.as_u16()
654        )));
655    }
656
657    // Determine response type from Content-Type header
658    let content_type = response
659        .headers()
660        .get("content-type")
661        .and_then(|v| v.to_str().ok())
662        .unwrap_or("")
663        .to_lowercase();
664
665    if content_type.contains("text/event-stream") {
666        // SSE stream — parse events to extract our JSON-RPC response
667        parse_sse_response(response, request.id).await
668    } else {
669        // Plain JSON response
670        let body: Value = response.json().await?;
671        extract_jsonrpc_result(&body, request.id)
672    }
673}
674
675/// Parse an SSE stream from an HTTP response, collecting JSON-RPC messages
676/// until we find the response matching our request ID.
677///
678/// SSE format (per HTML spec):
679///   event: message\n
680///   data: {"jsonrpc":"2.0","id":1,"result":{...}}\n
681///   \n
682///
683/// Each `data:` line contains a JSON-RPC message. The `event:` field is optional.
684/// We may receive notifications and server requests before getting our response.
685/// Maximum SSE response body size (50 MB) to prevent OOM from malicious servers.
686const MAX_SSE_BODY_SIZE: usize = 50 * 1024 * 1024;
687
688async fn parse_sse_response(
689    response: reqwest::Response,
690    request_id: u64,
691) -> Result<Value, McpError> {
692    // Enforce size limit on SSE stream body
693    let bytes = response
694        .bytes()
695        .await
696        .map_err(|e| McpError::SseParse(format!("Failed to read SSE stream: {e}")))?;
697    if bytes.len() > MAX_SSE_BODY_SIZE {
698        return Err(McpError::SseParse(format!(
699            "SSE response body exceeds maximum size ({} bytes > {} bytes)",
700            bytes.len(),
701            MAX_SSE_BODY_SIZE,
702        )));
703    }
704    let full_body = String::from_utf8_lossy(&bytes).into_owned();
705
706    // Parse SSE events
707    let mut current_data = String::new();
708
709    for line in full_body.lines() {
710        if line.starts_with("data:") {
711            let data = line.strip_prefix("data:").unwrap().trim();
712            if !data.is_empty() {
713                current_data.push_str(data);
714            }
715        } else if line.is_empty() && !current_data.is_empty() {
716            // End of event — process the accumulated data
717            match process_sse_data(&current_data, request_id) {
718                SseParseResult::OurResponse(result) => return result,
719                SseParseResult::NotOurMessage => {}
720                SseParseResult::ParseError(e) => {
721                    tracing::warn!(error = %e, "failed to parse SSE data");
722                }
723            }
724            current_data.clear();
725        }
726        // Lines starting with "event:", "id:", "retry:", or ":" are SSE metadata — skip
727    }
728
729    // Handle any remaining data that wasn't terminated by a blank line
730    if !current_data.is_empty() {
731        if let SseParseResult::OurResponse(result) = process_sse_data(&current_data, request_id) {
732            return result;
733        }
734    }
735
736    Err(McpError::SseParse(
737        "SSE stream ended without receiving a response for our request".into(),
738    ))
739}
740
741#[derive(Debug)]
742enum SseParseResult {
743    OurResponse(Result<Value, McpError>),
744    NotOurMessage,
745    ParseError(String),
746}
747
748fn process_sse_data(data: &str, request_id: u64) -> SseParseResult {
749    let parsed: Value = match serde_json::from_str(data) {
750        Ok(v) => v,
751        Err(e) => return SseParseResult::ParseError(e.to_string()),
752    };
753
754    // Could be a single message or a batch (array)
755    let messages = if parsed.is_array() {
756        parsed.as_array().unwrap().clone()
757    } else {
758        vec![parsed]
759    };
760
761    for msg in messages {
762        // Check if it's a response matching our ID
763        if let Some(id) = msg.get("id") {
764            let id_matches = match id {
765                Value::Number(n) => n.as_u64() == Some(request_id),
766                _ => false,
767            };
768            if id_matches {
769                return SseParseResult::OurResponse(extract_jsonrpc_result(&msg, request_id));
770            }
771        }
772        // Otherwise it's a notification or request from server — skip
773    }
774
775    SseParseResult::NotOurMessage
776}
777
778/// Extract the result (or error) from a JSON-RPC response message.
779fn extract_jsonrpc_result(msg: &Value, _request_id: u64) -> Result<Value, McpError> {
780    if let Some(err) = msg.get("error") {
781        let code = err.get("code").and_then(|c| c.as_i64()).unwrap_or(-1);
782        let message = err
783            .get("message")
784            .and_then(|m| m.as_str())
785            .unwrap_or("Unknown error");
786        return Err(McpError::Protocol {
787            code,
788            message: message.to_string(),
789        });
790    }
791
792    msg.get("result")
793        .cloned()
794        .ok_or_else(|| McpError::Protocol {
795            code: -1,
796            message: "Response missing 'result' field".into(),
797        })
798}
799
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

804/// Resolve "${key_name}" placeholders in values from the keyring.
805/// Supports both whole-string (`${key}`) and inline (`prefix/${key}/suffix`) patterns.
806fn resolve_env_value(value: &str, keyring: &Keyring) -> String {
807    let mut result = value.to_string();
808    // Find all ${...} patterns and replace them
809    while let Some(start) = result.find("${") {
810        let rest = &result[start + 2..];
811        if let Some(end) = rest.find('}') {
812            let key_name = &rest[..end];
813            let replacement = keyring.get(key_name).unwrap_or("");
814            if replacement.is_empty() && keyring.get(key_name).is_none() {
815                // Key not found — leave the placeholder as-is to avoid breaking the string
816                break;
817            }
818            result = format!("{}{}{}", &result[..start], replacement, &rest[end + 1..]);
819        } else {
820            break; // No closing brace — stop
821        }
822    }
823    result
824}
825
826/// Build an Authorization header value from the provider's auth config.
827fn build_auth_header(provider: &Provider, keyring: &Keyring) -> Option<String> {
828    let key_name = provider.auth_key_name.as_deref()?;
829    let key_value = keyring.get(key_name)?;
830
831    match &provider.auth_type {
832        super::manifest::AuthType::Bearer => Some(format!("Bearer {key_value}")),
833        super::manifest::AuthType::Header => {
834            // For header auth with a custom prefix
835            if let Some(prefix) = &provider.auth_value_prefix {
836                Some(format!("{prefix}{key_value}"))
837            } else {
838                Some(key_value.to_string())
839            }
840        }
841        super::manifest::AuthType::Basic => {
842            let encoded = base64::Engine::encode(
843                &base64::engine::general_purpose::STANDARD,
844                format!("{key_value}:"),
845            );
846            Some(format!("Basic {encoded}"))
847        }
848        _ => None,
849    }
850}
851
852// ---------------------------------------------------------------------------
853// High-level execute function for the call dispatch
854// ---------------------------------------------------------------------------
855
856/// Execute an MCP tool call — high-level entry point for cli/call.rs dispatch.
857///
858/// 1. Connects to the MCP server (or reuses connection via cache — future optimization)
859/// 2. Strips the provider prefix from the tool name (e.g., "github:read_file" → "read_file")
860/// 3. Calls tools/call with the raw MCP tool name
861/// 4. Returns the result as a serde_json::Value
862pub async fn execute(
863    provider: &Provider,
864    tool_name: &str,
865    args: &HashMap<String, Value>,
866    keyring: &Keyring,
867) -> Result<Value, McpError> {
868    execute_with_gen(provider, tool_name, args, keyring, None, None, None).await
869}
870
871/// Execute an MCP tool call with optional dynamic auth generator.
872///
873/// `override_mcp_url`: when `Some`, overrides `provider.mcp_url` for this
874/// call. The proxy passes a sandbox-supplied URL here after glob-matching
875/// it against the operator's per-provider allowlist (issue #124). Ignored
876/// on stdio transport (guarded at manifest load time).
877pub async fn execute_with_gen(
878    provider: &Provider,
879    tool_name: &str,
880    args: &HashMap<String, Value>,
881    keyring: &Keyring,
882    gen_ctx: Option<&GenContext>,
883    auth_cache: Option<&AuthCache>,
884    override_mcp_url: Option<&str>,
885) -> Result<Value, McpError> {
886    let client =
887        McpClient::connect_with_gen(provider, keyring, gen_ctx, auth_cache, override_mcp_url)
888            .await?;
889
890    // Strip provider prefix: "github:read_file" → "read_file"
891    let mcp_tool_name = tool_name
892        .strip_prefix(&format!(
893            "{}{}",
894            provider.name,
895            crate::core::manifest::TOOL_SEP_STR
896        ))
897        .unwrap_or(tool_name);
898
899    let result = client.call_tool(mcp_tool_name, args.clone()).await?;
900
901    // Convert MCP tool result to a single Value for ATI's output system
902    let value = mcp_result_to_value(&result);
903
904    // Clean up
905    client.disconnect().await;
906
907    Ok(value)
908}
909
910/// Convert an McpToolResult to a serde_json::Value.
911fn mcp_result_to_value(result: &McpToolResult) -> Value {
912    if result.content.len() == 1 {
913        // Single content item — unwrap for cleaner output
914        let item = &result.content[0];
915        match item.content_type.as_str() {
916            "text" => {
917                if let Some(text) = &item.text {
918                    // Try to parse as JSON (many MCP tools return JSON as text)
919                    serde_json::from_str(text).unwrap_or_else(|_| Value::String(text.clone()))
920                } else {
921                    Value::Null
922                }
923            }
924            "image" | "audio" => {
925                serde_json::json!({
926                    "type": item.content_type,
927                    "data": item.data,
928                    "mimeType": item.mime_type,
929                })
930            }
931            _ => serde_json::to_value(item).unwrap_or(Value::Null),
932        }
933    } else {
934        // Multiple content items — return as array
935        let items: Vec<Value> = result
936            .content
937            .iter()
938            .map(|c| serde_json::to_value(c).unwrap_or(Value::Null))
939            .collect();
940
941        serde_json::json!({
942            "content": items,
943            "isError": result.is_error,
944        })
945    }
946}
947
948// ---------------------------------------------------------------------------
949// Shared MCP tool discovery (used by both CLI and proxy)
950// ---------------------------------------------------------------------------
951
952/// Discover tools from all MCP providers concurrently and register them.
953///
954/// Each provider has a 30-second timeout. Failures are logged and skipped.
955/// Returns the number of tools discovered.
956pub async fn discover_all_mcp_tools(
957    registry: &mut crate::core::manifest::ManifestRegistry,
958    keyring: &Keyring,
959) -> usize {
960    use futures::stream::{self, StreamExt};
961
962    let providers: Vec<_> = registry
963        .list_mcp_providers()
964        .into_iter()
965        .map(|p| (p.name.clone(), p.clone()))
966        .collect();
967
968    if providers.is_empty() {
969        return 0;
970    }
971
972    // Discover concurrently (up to 10 at a time), with per-provider timeout
973    let results: Vec<_> = stream::iter(&providers)
974        .map(|(name, provider)| async move {
975            let result = tokio::time::timeout(
976                std::time::Duration::from_secs(30),
977                discover_one_provider(name, provider, keyring),
978            )
979            .await;
980
981            match result {
982                Ok(Ok(tools)) => Some((name.clone(), tools)),
983                Ok(Err(e)) => {
984                    tracing::warn!(provider = %name, error = %e, "MCP tool discovery failed");
985                    None
986                }
987                Err(_) => {
988                    tracing::warn!(provider = %name, "MCP tool discovery timed out (30s)");
989                    None
990                }
991            }
992        })
993        .buffer_unordered(10)
994        .collect()
995        .await;
996
997    // Register discovered tools (sequential — fast, just index inserts)
998    let mut total = 0;
999    for (name, tool_defs) in results.into_iter().flatten() {
1000        let count = tool_defs.len();
1001        registry.register_mcp_tools(&name, tool_defs);
1002        tracing::info!(provider = %name, tools = count, "discovered MCP tools");
1003        total += count;
1004    }
1005    total
1006}
1007
1008/// Discover tools from a single MCP provider.
1009async fn discover_one_provider(
1010    _name: &str,
1011    provider: &Provider,
1012    keyring: &Keyring,
1013) -> Result<Vec<crate::core::manifest::McpToolDef>, McpError> {
1014    let client = McpClient::connect(provider, keyring).await?;
1015    let tools = client.list_tools().await;
1016    client.disconnect().await;
1017
1018    let tools = tools?;
1019    Ok(tools
1020        .into_iter()
1021        .map(|t| crate::core::manifest::McpToolDef {
1022            name: t.name,
1023            description: t.description,
1024            input_schema: t.input_schema,
1025        })
1026        .collect())
1027}
1028
1029// ---------------------------------------------------------------------------
1030// Tests
1031// ---------------------------------------------------------------------------
1032
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a plain `text` content item.
    fn text_item(text: &str) -> McpContent {
        McpContent {
            content_type: "text".into(),
            text: Some(text.into()),
            data: None,
            mime_type: None,
        }
    }

    /// Wrap content items in a tool result.
    fn tool_result(content: Vec<McpContent>, is_error: bool) -> McpToolResult {
        McpToolResult { content, is_error }
    }

    // -- resolve_env_value ----------------------------------------------------

    #[test]
    fn test_resolve_env_value_keyring() {
        let keyring = Keyring::empty();
        // Unknown key: the placeholder comes back untouched.
        assert_eq!(
            resolve_env_value("${missing_key}", &keyring),
            "${missing_key}"
        );
        // Nothing to substitute.
        assert_eq!(resolve_env_value("plain_value", &keyring), "plain_value");
    }

    #[test]
    fn test_resolve_env_value_inline() {
        // A real keyring loaded from a temporary credentials file.
        let dir = tempfile::TempDir::new().unwrap();
        let creds_path = dir.path().join("creds");
        std::fs::write(&creds_path, r#"{"my_key":"SECRET123"}"#).unwrap();
        let keyring = Keyring::load_credentials(&creds_path).unwrap();

        let cases = [
            // Whole-string
            ("${my_key}", "SECRET123"),
            // Inline
            (
                "https://example.com/${my_key}/path",
                "https://example.com/SECRET123/path",
            ),
            // Multiple placeholders
            ("${my_key}--${my_key}", "SECRET123--SECRET123"),
            // Missing key stays as-is
            (
                "https://example.com/${unknown}/path",
                "https://example.com/${unknown}/path",
            ),
            // No placeholder
            ("no_placeholder", "no_placeholder"),
        ];
        for (input, expected) in cases {
            assert_eq!(resolve_env_value(input, &keyring), expected);
        }
    }

    #[test]
    fn test_resolve_env_value_unclosed_brace() {
        let keyring = Keyring::empty();
        assert_eq!(resolve_env_value("${unclosed", &keyring), "${unclosed");
    }

    // -- mcp_result_to_value --------------------------------------------------

    #[test]
    fn test_mcp_result_to_value_single_text() {
        let result = tool_result(vec![text_item("hello world")], false);
        assert_eq!(
            mcp_result_to_value(&result),
            Value::String("hello world".into())
        );
    }

    #[test]
    fn test_mcp_result_to_value_json_text() {
        let result = tool_result(vec![text_item(r#"{"key":"value"}"#)], false);
        assert_eq!(
            mcp_result_to_value(&result),
            serde_json::json!({"key": "value"})
        );
    }

    #[test]
    fn test_mcp_result_to_value_error() {
        // A single text item is unwrapped even when the call reported an error.
        let result = tool_result(vec![text_item("Something went wrong")], true);
        assert_eq!(
            mcp_result_to_value(&result),
            Value::String("Something went wrong".into())
        );
    }

    #[test]
    fn test_mcp_result_to_value_multiple_content() {
        let result = tool_result(vec![text_item("Part 1"), text_item("Part 2")], false);
        let val = mcp_result_to_value(&result);
        // Several items → {"content": [...], "isError": false}
        assert_eq!(val["content"].as_array().unwrap().len(), 2);
        assert_eq!(val["isError"], false);
    }

    #[test]
    fn test_mcp_result_to_value_empty_content() {
        let val = mcp_result_to_value(&tool_result(Vec::new(), false));
        // No items → {"content": [], "isError": false}
        assert_eq!(val["content"].as_array().unwrap().len(), 0);
        assert_eq!(val["isError"], false);
    }

    // -- extract_jsonrpc_result -----------------------------------------------

    #[test]
    fn test_extract_jsonrpc_result_success() {
        let msg = serde_json::json!({"jsonrpc": "2.0", "id": 1, "result": {"tools": []}});
        assert_eq!(
            extract_jsonrpc_result(&msg, 1).unwrap(),
            serde_json::json!({"tools": []})
        );
    }

    #[test]
    fn test_extract_jsonrpc_result_error() {
        let msg = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {"code": -32602, "message": "Invalid params"}
        });
        let err = extract_jsonrpc_result(&msg, 1).unwrap_err();
        assert!(matches!(err, McpError::Protocol { code: -32602, .. }));
    }

    #[test]
    fn test_extract_jsonrpc_result_missing_result() {
        let msg = serde_json::json!({"jsonrpc": "2.0", "id": 1});
        let err = extract_jsonrpc_result(&msg, 1).unwrap_err();
        assert!(matches!(err, McpError::Protocol { code: -1, .. }));
    }

    #[test]
    fn test_extract_jsonrpc_error_defaults() {
        // `error` present but with neither `code` nor `message`.
        let msg = serde_json::json!({"jsonrpc": "2.0", "id": 1, "error": {}});
        match extract_jsonrpc_result(&msg, 1).unwrap_err() {
            McpError::Protocol { code, message } => {
                assert_eq!(code, -1);
                assert_eq!(message, "Unknown error");
            }
            _ => panic!("Expected Protocol error"),
        }
    }

    // -- process_sse_data -----------------------------------------------------

    #[test]
    fn test_process_sse_data_matching_response() {
        let data = r#"{"jsonrpc":"2.0","id":5,"result":{"tools":[]}}"#;
        match process_sse_data(data, 5) {
            SseParseResult::OurResponse(Ok(val)) => {
                assert_eq!(val, serde_json::json!({"tools": []}));
            }
            _ => panic!("Expected OurResponse"),
        }
    }

    #[test]
    fn test_process_sse_data_notification() {
        // Notifications carry no "id", so they are never our response.
        let data = r#"{"jsonrpc":"2.0","method":"notifications/tools/list_changed"}"#;
        assert!(
            matches!(process_sse_data(data, 5), SseParseResult::NotOurMessage),
            "Expected NotOurMessage"
        );
    }

    #[test]
    fn test_process_sse_data_batch() {
        let data = r#"[
            {"jsonrpc":"2.0","method":"notifications/progress","params":{}},
            {"jsonrpc":"2.0","id":3,"result":{"content":[],"isError":false}}
        ]"#;
        match process_sse_data(data, 3) {
            SseParseResult::OurResponse(Ok(val)) => {
                assert!(val.get("content").is_some());
            }
            _ => panic!("Expected OurResponse from batch"),
        }
    }

    #[test]
    fn test_process_sse_data_invalid_json() {
        match process_sse_data("not valid json {{{}", 1) {
            SseParseResult::ParseError(_) => {}
            other => panic!("Expected ParseError, got: {other:?}"),
        }
    }

    #[test]
    fn test_process_sse_data_wrong_id() {
        let data = r#"{"jsonrpc":"2.0","id":99,"result":{"data":"wrong"}}"#;
        assert!(
            matches!(process_sse_data(data, 1), SseParseResult::NotOurMessage),
            "Expected NotOurMessage for wrong ID"
        );
    }

    #[test]
    fn test_process_sse_data_empty_batch() {
        assert!(
            matches!(process_sse_data("[]", 1), SseParseResult::NotOurMessage),
            "Expected NotOurMessage for empty batch"
        );
    }
}