Skip to main content

ati/core/
mcp_client.rs

//! MCP client — connects to MCP servers via stdio or Streamable HTTP transport.
//!
//! Implements the MCP protocol (2025-03-26 revision):
//! - JSON-RPC 2.0 message framing
//! - stdio transport: newline-delimited JSON over stdin/stdout
//! - Streamable HTTP transport: POST with `Accept: application/json, text/event-stream`.
//!   The server may respond with plain JSON or an SSE stream; `Mcp-Session-Id` carries the session.
//! - Lifecycle: initialize → notifications/initialized → tools/list → tools/call → disconnect
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::io::{BufRead, BufReader, Write as IoWrite};
13use std::process::{Child, Command, Stdio};
14use std::sync::atomic::{AtomicU64, Ordering};
15use thiserror::Error;
16use tokio::sync::Mutex;
17
18use crate::core::auth_generator::{self, AuthCache, GenContext};
19use crate::core::keyring::Keyring;
20use crate::core::manifest::Provider;
21
22// ---------------------------------------------------------------------------
23// Errors
24// ---------------------------------------------------------------------------
25
26#[derive(Error, Debug)]
27#[allow(dead_code)]
28pub enum McpError {
29    #[error("MCP transport error: {0}")]
30    Transport(String),
31    #[error("MCP protocol error (code {code}): {message}")]
32    Protocol { code: i64, message: String },
33    #[error("MCP server did not return tools capability")]
34    NoToolsCapability,
35    #[error("IO error: {0}")]
36    Io(#[from] std::io::Error),
37    #[error("JSON error: {0}")]
38    Json(#[from] serde_json::Error),
39    #[error("HTTP error: {0}")]
40    Http(#[from] reqwest::Error),
41    #[error("MCP initialization failed: {0}")]
42    InitFailed(String),
43    #[error("SSE parse error: {0}")]
44    SseParse(String),
45    #[error("MCP server process exited unexpectedly")]
46    ProcessExited,
47    #[error("Missing MCP configuration: {0}")]
48    Config(String),
49}
50
51// ---------------------------------------------------------------------------
52// JSON-RPC types
53// ---------------------------------------------------------------------------
54
55#[derive(Debug, Serialize)]
56struct JsonRpcRequest {
57    jsonrpc: &'static str,
58    id: u64,
59    method: String,
60    #[serde(skip_serializing_if = "Option::is_none")]
61    params: Option<Value>,
62}
63
64// Note: We parse JSON-RPC responses manually via serde_json::Value
65// rather than typed deserialization, since responses can be interleaved
66// with notifications and batches in SSE streams.
67
68// ---------------------------------------------------------------------------
69// MCP protocol types
70// ---------------------------------------------------------------------------
71
72/// Tool definition from MCP tools/list response.
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct McpToolDef {
75    pub name: String,
76    #[serde(default)]
77    pub description: Option<String>,
78    #[serde(default, rename = "inputSchema")]
79    pub input_schema: Option<Value>,
80}
81
82/// Content item from MCP tools/call response.
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct McpContent {
85    #[serde(rename = "type")]
86    pub content_type: String,
87    #[serde(default)]
88    pub text: Option<String>,
89    #[serde(default)]
90    pub data: Option<String>,
91    #[serde(default, rename = "mimeType")]
92    pub mime_type: Option<String>,
93}
94
95/// Result from tools/call.
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct McpToolResult {
98    pub content: Vec<McpContent>,
99    #[serde(default, rename = "isError")]
100    pub is_error: bool,
101}
102
103// ---------------------------------------------------------------------------
104// Transport abstraction
105// ---------------------------------------------------------------------------
106
107/// Internal transport enum — stdio or HTTP.
108enum Transport {
109    Stdio(StdioTransport),
110    Http(HttpTransport),
111}
112
/// State for an MCP server running as a child process.
struct StdioTransport {
    /// Process handle; killed and reaped when the transport is dropped.
    child: Child,
    /// Write end of the child's stdin. Becomes `None` once `disconnect` has taken
    /// and dropped it, after which sends fail with "stdin closed".
    stdin: Option<std::process::ChildStdin>,
    /// Buffered read end of the child's stdout; one JSON-RPC message per line.
    reader: BufReader<std::process::ChildStdout>,
}
121
122impl Drop for StdioTransport {
123    fn drop(&mut self) {
124        // Best-effort cleanup: kill the subprocess if it wasn't explicitly disconnected.
125        // Prevents orphan zombie processes when a future is cancelled (e.g., on timeout).
126        let _ = self.child.kill();
127        let _ = self.child.wait();
128    }
129}
130
131/// Streamable HTTP transport: POST to MCP endpoint.
132struct HttpTransport {
133    client: reqwest::Client,
134    url: String,
135    /// Session ID from Mcp-Session-Id header (set after initialize).
136    session_id: Option<String>,
137    /// Auth header name (default: "Authorization"). Custom for APIs using e.g. "x-api-key".
138    auth_header_name: String,
139    /// Auth header value (e.g., "Bearer <token>") injected on every request.
140    auth_header: Option<String>,
141    /// Extra headers from provider config.
142    extra_headers: HashMap<String, String>,
143}
144
145// ---------------------------------------------------------------------------
146// McpClient
147// ---------------------------------------------------------------------------
148
149/// MCP client that connects to a single MCP server.
150pub struct McpClient {
151    transport: Mutex<Transport>,
152    next_id: AtomicU64,
153    /// Cached tools from tools/list.
154    cached_tools: Mutex<Option<Vec<McpToolDef>>>,
155    /// Provider name (for logging).
156    provider_name: String,
157}
158
159impl McpClient {
160    /// Connect to an MCP server based on the provider's configuration.
161    ///
162    /// For stdio: spawns the subprocess with env vars resolved from keyring.
163    /// For HTTP: creates an HTTP client with auth headers.
164    pub async fn connect(provider: &Provider, keyring: &Keyring) -> Result<Self, McpError> {
165        Self::connect_with_gen(provider, keyring, None, None).await
166    }
167
168    /// Connect to an MCP server, optionally using a dynamic auth generator.
169    pub async fn connect_with_gen(
170        provider: &Provider,
171        keyring: &Keyring,
172        gen_ctx: Option<&GenContext>,
173        auth_cache: Option<&AuthCache>,
174    ) -> Result<Self, McpError> {
175        let transport = match provider.mcp_transport_type() {
176            "stdio" => {
177                let command = provider.mcp_command.as_deref().ok_or_else(|| {
178                    McpError::Config("mcp_command required for stdio transport".into())
179                })?;
180
181                // Resolve env vars: "${key_name}" → keyring value
182                let mut env_map: HashMap<String, String> = HashMap::new();
183                // Selectively pass through essential env vars (don't leak secrets)
184                if let Ok(path) = std::env::var("PATH") {
185                    env_map.insert("PATH".to_string(), path);
186                }
187                if let Ok(home) = std::env::var("HOME") {
188                    env_map.insert("HOME".to_string(), home);
189                }
190                // Add provider-specific env vars (resolved from keyring)
191                for (k, v) in &provider.mcp_env {
192                    let resolved = resolve_env_value(v, keyring);
193                    env_map.insert(k.clone(), resolved);
194                }
195
196                // If auth_generator is configured, run it and inject into env
197                if let Some(gen) = &provider.auth_generator {
198                    let default_ctx = GenContext::default();
199                    let ctx = gen_ctx.unwrap_or(&default_ctx);
200                    let default_cache = AuthCache::new();
201                    let cache = auth_cache.unwrap_or(&default_cache);
202                    match auth_generator::generate(provider, gen, ctx, keyring, cache).await {
203                        Ok(cred) => {
204                            env_map.insert("ATI_AUTH_TOKEN".to_string(), cred.value);
205                            for (k, v) in &cred.extra_env {
206                                env_map.insert(k.clone(), v.clone());
207                            }
208                        }
209                        Err(e) => {
210                            return Err(McpError::Config(format!("auth_generator failed: {e}")));
211                        }
212                    }
213                }
214
215                let mut child = Command::new(command)
216                    .args(&provider.mcp_args)
217                    .stdin(Stdio::piped())
218                    .stdout(Stdio::piped())
219                    .stderr(Stdio::piped())
220                    .env_clear()
221                    .envs(&env_map)
222                    .spawn()
223                    .map_err(|e| {
224                        McpError::Transport(format!("Failed to spawn MCP server '{command}': {e}"))
225                    })?;
226
227                let stdin = child
228                    .stdin
229                    .take()
230                    .ok_or_else(|| McpError::Transport("No stdin".into()))?;
231                let stdout = child
232                    .stdout
233                    .take()
234                    .ok_or_else(|| McpError::Transport("No stdout".into()))?;
235                let reader = BufReader::new(stdout);
236
237                Transport::Stdio(StdioTransport {
238                    child,
239                    stdin: Some(stdin),
240                    reader,
241                })
242            }
243            "http" => {
244                let url = provider.mcp_url.as_deref().ok_or_else(|| {
245                    McpError::Config("mcp_url required for HTTP transport".into())
246                })?;
247
248                // Build auth header: generator takes priority over static keyring
249                let auth_header = if let Some(gen) = &provider.auth_generator {
250                    let default_ctx = GenContext::default();
251                    let ctx = gen_ctx.unwrap_or(&default_ctx);
252                    let default_cache = AuthCache::new();
253                    let cache = auth_cache.unwrap_or(&default_cache);
254                    match auth_generator::generate(provider, gen, ctx, keyring, cache).await {
255                        Ok(cred) => match &provider.auth_type {
256                            super::manifest::AuthType::Bearer => {
257                                Some(format!("Bearer {}", cred.value))
258                            }
259                            super::manifest::AuthType::Header => {
260                                if let Some(prefix) = &provider.auth_value_prefix {
261                                    Some(format!("{prefix}{}", cred.value))
262                                } else {
263                                    Some(cred.value)
264                                }
265                            }
266                            _ => Some(cred.value),
267                        },
268                        Err(e) => {
269                            return Err(McpError::Config(format!("auth_generator failed: {e}")));
270                        }
271                    }
272                } else {
273                    build_auth_header(provider, keyring)
274                };
275
276                let client = reqwest::Client::builder()
277                    .timeout(std::time::Duration::from_secs(300))
278                    .build()?;
279
280                // Resolve ${key_name} placeholders in the URL from keyring
281                let resolved_url = resolve_env_value(url, keyring);
282
283                let auth_header_name = provider
284                    .auth_header_name
285                    .clone()
286                    .unwrap_or_else(|| "Authorization".to_string());
287
288                Transport::Http(HttpTransport {
289                    client,
290                    url: resolved_url,
291                    session_id: None,
292                    auth_header_name,
293                    auth_header,
294                    extra_headers: provider.extra_headers.clone(),
295                })
296            }
297            other => {
298                return Err(McpError::Config(format!(
299                    "Unknown MCP transport: '{other}' (expected 'stdio' or 'http')"
300                )));
301            }
302        };
303
304        let client = McpClient {
305            transport: Mutex::new(transport),
306            next_id: AtomicU64::new(1),
307            cached_tools: Mutex::new(None),
308            provider_name: provider.name.clone(),
309        };
310
311        // Perform MCP initialize handshake
312        client.initialize().await?;
313
314        Ok(client)
315    }
316
317    /// Perform the MCP initialize handshake.
318    async fn initialize(&self) -> Result<(), McpError> {
319        let params = serde_json::json!({
320            "protocolVersion": "2025-03-26",
321            "capabilities": {},
322            "clientInfo": {
323                "name": "ati",
324                "version": env!("CARGO_PKG_VERSION")
325            }
326        });
327
328        let response = self.send_request("initialize", Some(params)).await?;
329
330        // Verify server has tools capability
331        let capabilities = response.get("capabilities").unwrap_or(&Value::Null);
332        if capabilities.get("tools").is_none() {
333            return Err(McpError::NoToolsCapability);
334        }
335
336        // Extract session ID from HTTP transport response (handled inside send_request)
337
338        // Send initialized notification
339        self.send_notification("notifications/initialized", None)
340            .await?;
341
342        Ok(())
343    }
344
345    /// Discover tools via tools/list. Results are cached.
346    pub async fn list_tools(&self) -> Result<Vec<McpToolDef>, McpError> {
347        // Return cached if available
348        {
349            let cache = self.cached_tools.lock().await;
350            if let Some(tools) = cache.as_ref() {
351                return Ok(tools.clone());
352            }
353        }
354
355        let mut all_tools = Vec::new();
356        let mut cursor: Option<String> = None;
357        const MAX_PAGES: usize = 100;
358        const MAX_TOOLS: usize = 10_000;
359
360        for _page in 0..MAX_PAGES {
361            let params = cursor.as_ref().map(|c| serde_json::json!({"cursor": c}));
362            let result = self.send_request("tools/list", params).await?;
363
364            if let Some(tools_val) = result.get("tools") {
365                let tools: Vec<McpToolDef> = serde_json::from_value(tools_val.clone())?;
366                all_tools.extend(tools);
367            }
368
369            // Safety: cap total tools to prevent memory exhaustion
370            if all_tools.len() > MAX_TOOLS {
371                tracing::warn!(max = MAX_TOOLS, "MCP tool count exceeds limit, truncating");
372                all_tools.truncate(MAX_TOOLS);
373                break;
374            }
375
376            // Check for pagination
377            match result.get("nextCursor").and_then(|v| v.as_str()) {
378                Some(next) => cursor = Some(next.to_string()),
379                None => break,
380            }
381        }
382
383        // Cache the result
384        {
385            let mut cache = self.cached_tools.lock().await;
386            *cache = Some(all_tools.clone());
387        }
388
389        Ok(all_tools)
390    }
391
392    /// Execute a tool via tools/call.
393    pub async fn call_tool(
394        &self,
395        name: &str,
396        arguments: HashMap<String, Value>,
397    ) -> Result<McpToolResult, McpError> {
398        let params = serde_json::json!({
399            "name": name,
400            "arguments": arguments,
401        });
402
403        let result = self.send_request("tools/call", Some(params)).await?;
404        let tool_result: McpToolResult = serde_json::from_value(result)?;
405        Ok(tool_result)
406    }
407
408    /// Disconnect from the MCP server.
409    pub async fn disconnect(&self) {
410        let mut transport = self.transport.lock().await;
411        match &mut *transport {
412            Transport::Stdio(stdio) => {
413                // Take ownership of stdin to drop it (signals EOF to child).
414                // After this, stdin is None and the child should exit.
415                let _ = stdio.stdin.take();
416                // Try graceful shutdown, then kill
417                let _ = stdio.child.kill();
418                let _ = stdio.child.wait();
419            }
420            Transport::Http(http) => {
421                // Send HTTP DELETE to terminate session if we have a session ID
422                if let Some(session_id) = &http.session_id {
423                    let mut req = http.client.delete(&http.url);
424                    req = req.header("Mcp-Session-Id", session_id.as_str());
425                    let _ = req.send().await;
426                }
427            }
428        }
429    }
430
431    /// Invalidate cached tools (e.g., after tools/list_changed notification).
432    pub async fn invalidate_cache(&self) {
433        let mut cache = self.cached_tools.lock().await;
434        *cache = None;
435    }
436
437    // -----------------------------------------------------------------------
438    // Internal: send JSON-RPC request and receive response
439    // -----------------------------------------------------------------------
440
441    async fn send_request(&self, method: &str, params: Option<Value>) -> Result<Value, McpError> {
442        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
443        let request = JsonRpcRequest {
444            jsonrpc: "2.0",
445            id,
446            method: method.to_string(),
447            params,
448        };
449
450        let mut transport = self.transport.lock().await;
451        match &mut *transport {
452            Transport::Stdio(stdio) => send_stdio_request(stdio, &request).await,
453            Transport::Http(http) => send_http_request(http, &request, &self.provider_name).await,
454        }
455    }
456
457    async fn send_notification(&self, method: &str, params: Option<Value>) -> Result<(), McpError> {
458        let mut notification = serde_json::json!({
459            "jsonrpc": "2.0",
460            "method": method,
461        });
462        if let Some(p) = params {
463            notification["params"] = p;
464        }
465
466        let mut transport = self.transport.lock().await;
467        match &mut *transport {
468            Transport::Stdio(stdio) => {
469                let stdin = stdio
470                    .stdin
471                    .as_mut()
472                    .ok_or_else(|| McpError::Transport("stdin closed".into()))?;
473                let msg = serde_json::to_string(&notification)?;
474                stdin.write_all(msg.as_bytes())?;
475                stdin.write_all(b"\n")?;
476                stdin.flush()?;
477                Ok(())
478            }
479            Transport::Http(http) => {
480                let mut req = http
481                    .client
482                    .post(&http.url)
483                    .header("Content-Type", "application/json")
484                    .header("Accept", "application/json, text/event-stream")
485                    .json(&notification);
486
487                if let Some(session_id) = &http.session_id {
488                    req = req.header("Mcp-Session-Id", session_id.as_str());
489                }
490                if let Some(auth) = &http.auth_header {
491                    req = req.header(http.auth_header_name.as_str(), auth.as_str());
492                }
493                for (name, value) in &http.extra_headers {
494                    req = req.header(name.as_str(), value.as_str());
495                }
496
497                let resp = req.send().await?;
498                // Notifications should get 202 Accepted
499                if !resp.status().is_success() {
500                    let status = resp.status().as_u16();
501                    let body = resp.text().await.unwrap_or_default();
502                    return Err(McpError::Transport(format!("HTTP {status}: {body}")));
503                }
504                Ok(())
505            }
506        }
507    }
508}
509
510// ---------------------------------------------------------------------------
511// Stdio transport I/O
512// ---------------------------------------------------------------------------
513
514/// Send a JSON-RPC request over stdio and read the response.
515/// Messages are newline-delimited JSON (no embedded newlines).
516async fn send_stdio_request(
517    stdio: &mut StdioTransport,
518    request: &JsonRpcRequest,
519) -> Result<Value, McpError> {
520    let stdin = stdio
521        .stdin
522        .as_mut()
523        .ok_or_else(|| McpError::Transport("stdin closed".into()))?;
524
525    // Serialize and send (newline-delimited)
526    let msg = serde_json::to_string(request)?;
527    stdin.write_all(msg.as_bytes())?;
528    stdin.write_all(b"\n")?;
529    stdin.flush()?;
530
531    let request_id = request.id;
532
533    // Read lines until we get a response matching our request ID.
534    // We may receive notifications interleaved — skip them.
535    loop {
536        let mut line = String::new();
537        let bytes_read = stdio.reader.read_line(&mut line)?;
538        if bytes_read == 0 {
539            return Err(McpError::ProcessExited);
540        }
541
542        let line = line.trim();
543        if line.is_empty() {
544            continue;
545        }
546
547        // Try to parse as JSON-RPC response
548        let parsed: Value = serde_json::from_str(line)?;
549
550        // Check if it's a response (has "id" field matching ours)
551        if let Some(id) = parsed.get("id") {
552            let id_matches = match id {
553                Value::Number(n) => n.as_u64() == Some(request_id),
554                _ => false,
555            };
556
557            if id_matches {
558                // It's our response
559                if let Some(err) = parsed.get("error") {
560                    let code = err.get("code").and_then(|c| c.as_i64()).unwrap_or(-1);
561                    let message = err
562                        .get("message")
563                        .and_then(|m| m.as_str())
564                        .unwrap_or("Unknown error");
565                    return Err(McpError::Protocol {
566                        code,
567                        message: message.to_string(),
568                    });
569                }
570
571                return parsed
572                    .get("result")
573                    .cloned()
574                    .ok_or_else(|| McpError::Protocol {
575                        code: -1,
576                        message: "Response missing 'result' field".into(),
577                    });
578            }
579        }
580
581        // Not our response — it's a notification or someone else's response; skip it.
582    }
583}
584
585// ---------------------------------------------------------------------------
586// HTTP Streamable transport I/O
587// ---------------------------------------------------------------------------
588
589/// Send a JSON-RPC request over Streamable HTTP.
590///
591/// Per MCP spec (2025-03-26):
592/// - POST with Accept: application/json, text/event-stream
593/// - Server may respond with Content-Type: application/json (single response)
594///   or Content-Type: text/event-stream (SSE stream with one or more messages)
595/// - Must handle Mcp-Session-Id header for session management
596async fn send_http_request(
597    http: &mut HttpTransport,
598    request: &JsonRpcRequest,
599    provider_name: &str,
600) -> Result<Value, McpError> {
601    let mut req = http
602        .client
603        .post(&http.url)
604        .header("Content-Type", "application/json")
605        .header("Accept", "application/json, text/event-stream")
606        .json(request);
607
608    // Attach session ID if we have one
609    if let Some(session_id) = &http.session_id {
610        req = req.header("Mcp-Session-Id", session_id.as_str());
611    }
612
613    // Inject auth (using custom header name if configured, e.g. "x-api-key")
614    if let Some(auth) = &http.auth_header {
615        req = req.header(http.auth_header_name.as_str(), auth.as_str());
616    }
617
618    // Inject extra headers from provider config
619    for (name, value) in &http.extra_headers {
620        req = req.header(name.as_str(), value.as_str());
621    }
622
623    let response = req
624        .send()
625        .await
626        .map_err(|e| McpError::Transport(format!("[{provider_name}] HTTP request failed: {e}")))?;
627
628    // Capture session ID from response header (usually set during initialize)
629    if let Some(session_val) = response.headers().get("mcp-session-id") {
630        if let Ok(sid) = session_val.to_str() {
631            http.session_id = Some(sid.to_string());
632        }
633    }
634
635    let status = response.status();
636    if !status.is_success() {
637        let body = response.text().await.unwrap_or_default();
638        return Err(McpError::Transport(format!(
639            "[{provider_name}] HTTP {}: {body}",
640            status.as_u16()
641        )));
642    }
643
644    // Determine response type from Content-Type header
645    let content_type = response
646        .headers()
647        .get("content-type")
648        .and_then(|v| v.to_str().ok())
649        .unwrap_or("")
650        .to_lowercase();
651
652    if content_type.contains("text/event-stream") {
653        // SSE stream — parse events to extract our JSON-RPC response
654        parse_sse_response(response, request.id).await
655    } else {
656        // Plain JSON response
657        let body: Value = response.json().await?;
658        extract_jsonrpc_result(&body, request.id)
659    }
660}
661
662/// Parse an SSE stream from an HTTP response, collecting JSON-RPC messages
663/// until we find the response matching our request ID.
664///
665/// SSE format (per HTML spec):
666///   event: message\n
667///   data: {"jsonrpc":"2.0","id":1,"result":{...}}\n
668///   \n
669///
670/// Each `data:` line contains a JSON-RPC message. The `event:` field is optional.
671/// We may receive notifications and server requests before getting our response.
672/// Maximum SSE response body size (50 MB) to prevent OOM from malicious servers.
673const MAX_SSE_BODY_SIZE: usize = 50 * 1024 * 1024;
674
675async fn parse_sse_response(
676    response: reqwest::Response,
677    request_id: u64,
678) -> Result<Value, McpError> {
679    // Enforce size limit on SSE stream body
680    let bytes = response
681        .bytes()
682        .await
683        .map_err(|e| McpError::SseParse(format!("Failed to read SSE stream: {e}")))?;
684    if bytes.len() > MAX_SSE_BODY_SIZE {
685        return Err(McpError::SseParse(format!(
686            "SSE response body exceeds maximum size ({} bytes > {} bytes)",
687            bytes.len(),
688            MAX_SSE_BODY_SIZE,
689        )));
690    }
691    let full_body = String::from_utf8_lossy(&bytes).into_owned();
692
693    // Parse SSE events
694    let mut current_data = String::new();
695
696    for line in full_body.lines() {
697        if line.starts_with("data:") {
698            let data = line.strip_prefix("data:").unwrap().trim();
699            if !data.is_empty() {
700                current_data.push_str(data);
701            }
702        } else if line.is_empty() && !current_data.is_empty() {
703            // End of event — process the accumulated data
704            match process_sse_data(&current_data, request_id) {
705                SseParseResult::OurResponse(result) => return result,
706                SseParseResult::NotOurMessage => {}
707                SseParseResult::ParseError(e) => {
708                    tracing::warn!(error = %e, "failed to parse SSE data");
709                }
710            }
711            current_data.clear();
712        }
713        // Lines starting with "event:", "id:", "retry:", or ":" are SSE metadata — skip
714    }
715
716    // Handle any remaining data that wasn't terminated by a blank line
717    if !current_data.is_empty() {
718        if let SseParseResult::OurResponse(result) = process_sse_data(&current_data, request_id) {
719            return result;
720        }
721    }
722
723    Err(McpError::SseParse(
724        "SSE stream ended without receiving a response for our request".into(),
725    ))
726}
727
728#[derive(Debug)]
729enum SseParseResult {
730    OurResponse(Result<Value, McpError>),
731    NotOurMessage,
732    ParseError(String),
733}
734
735fn process_sse_data(data: &str, request_id: u64) -> SseParseResult {
736    let parsed: Value = match serde_json::from_str(data) {
737        Ok(v) => v,
738        Err(e) => return SseParseResult::ParseError(e.to_string()),
739    };
740
741    // Could be a single message or a batch (array)
742    let messages = if parsed.is_array() {
743        parsed.as_array().unwrap().clone()
744    } else {
745        vec![parsed]
746    };
747
748    for msg in messages {
749        // Check if it's a response matching our ID
750        if let Some(id) = msg.get("id") {
751            let id_matches = match id {
752                Value::Number(n) => n.as_u64() == Some(request_id),
753                _ => false,
754            };
755            if id_matches {
756                return SseParseResult::OurResponse(extract_jsonrpc_result(&msg, request_id));
757            }
758        }
759        // Otherwise it's a notification or request from server — skip
760    }
761
762    SseParseResult::NotOurMessage
763}
764
765/// Extract the result (or error) from a JSON-RPC response message.
766fn extract_jsonrpc_result(msg: &Value, _request_id: u64) -> Result<Value, McpError> {
767    if let Some(err) = msg.get("error") {
768        let code = err.get("code").and_then(|c| c.as_i64()).unwrap_or(-1);
769        let message = err
770            .get("message")
771            .and_then(|m| m.as_str())
772            .unwrap_or("Unknown error");
773        return Err(McpError::Protocol {
774            code,
775            message: message.to_string(),
776        });
777    }
778
779    msg.get("result")
780        .cloned()
781        .ok_or_else(|| McpError::Protocol {
782            code: -1,
783            message: "Response missing 'result' field".into(),
784        })
785}
786
787// ---------------------------------------------------------------------------
788// Helpers
789// ---------------------------------------------------------------------------
790
791/// Resolve "${key_name}" placeholders in values from the keyring.
792/// Supports both whole-string (`${key}`) and inline (`prefix/${key}/suffix`) patterns.
793fn resolve_env_value(value: &str, keyring: &Keyring) -> String {
794    let mut result = value.to_string();
795    // Find all ${...} patterns and replace them
796    while let Some(start) = result.find("${") {
797        let rest = &result[start + 2..];
798        if let Some(end) = rest.find('}') {
799            let key_name = &rest[..end];
800            let replacement = keyring.get(key_name).unwrap_or("");
801            if replacement.is_empty() && keyring.get(key_name).is_none() {
802                // Key not found — leave the placeholder as-is to avoid breaking the string
803                break;
804            }
805            result = format!("{}{}{}", &result[..start], replacement, &rest[end + 1..]);
806        } else {
807            break; // No closing brace — stop
808        }
809    }
810    result
811}
812
813/// Build an Authorization header value from the provider's auth config.
814fn build_auth_header(provider: &Provider, keyring: &Keyring) -> Option<String> {
815    let key_name = provider.auth_key_name.as_deref()?;
816    let key_value = keyring.get(key_name)?;
817
818    match &provider.auth_type {
819        super::manifest::AuthType::Bearer => Some(format!("Bearer {key_value}")),
820        super::manifest::AuthType::Header => {
821            // For header auth with a custom prefix
822            if let Some(prefix) = &provider.auth_value_prefix {
823                Some(format!("{prefix}{key_value}"))
824            } else {
825                Some(key_value.to_string())
826            }
827        }
828        super::manifest::AuthType::Basic => {
829            let encoded = base64::Engine::encode(
830                &base64::engine::general_purpose::STANDARD,
831                format!("{key_value}:"),
832            );
833            Some(format!("Basic {encoded}"))
834        }
835        _ => None,
836    }
837}
838
839// ---------------------------------------------------------------------------
840// High-level execute function for the call dispatch
841// ---------------------------------------------------------------------------
842
843/// Execute an MCP tool call — high-level entry point for cli/call.rs dispatch.
844///
845/// 1. Connects to the MCP server (or reuses connection via cache — future optimization)
846/// 2. Strips the provider prefix from the tool name (e.g., "github:read_file" → "read_file")
847/// 3. Calls tools/call with the raw MCP tool name
848/// 4. Returns the result as a serde_json::Value
849pub async fn execute(
850    provider: &Provider,
851    tool_name: &str,
852    args: &HashMap<String, Value>,
853    keyring: &Keyring,
854) -> Result<Value, McpError> {
855    execute_with_gen(provider, tool_name, args, keyring, None, None).await
856}
857
858/// Execute an MCP tool call with optional dynamic auth generator.
859pub async fn execute_with_gen(
860    provider: &Provider,
861    tool_name: &str,
862    args: &HashMap<String, Value>,
863    keyring: &Keyring,
864    gen_ctx: Option<&GenContext>,
865    auth_cache: Option<&AuthCache>,
866) -> Result<Value, McpError> {
867    let client = McpClient::connect_with_gen(provider, keyring, gen_ctx, auth_cache).await?;
868
869    // Strip provider prefix: "github:read_file" → "read_file"
870    let mcp_tool_name = tool_name
871        .strip_prefix(&format!(
872            "{}{}",
873            provider.name,
874            crate::core::manifest::TOOL_SEP_STR
875        ))
876        .unwrap_or(tool_name);
877
878    let result = client.call_tool(mcp_tool_name, args.clone()).await?;
879
880    // Convert MCP tool result to a single Value for ATI's output system
881    let value = mcp_result_to_value(&result);
882
883    // Clean up
884    client.disconnect().await;
885
886    Ok(value)
887}
888
889/// Convert an McpToolResult to a serde_json::Value.
890fn mcp_result_to_value(result: &McpToolResult) -> Value {
891    if result.content.len() == 1 {
892        // Single content item — unwrap for cleaner output
893        let item = &result.content[0];
894        match item.content_type.as_str() {
895            "text" => {
896                if let Some(text) = &item.text {
897                    // Try to parse as JSON (many MCP tools return JSON as text)
898                    serde_json::from_str(text).unwrap_or_else(|_| Value::String(text.clone()))
899                } else {
900                    Value::Null
901                }
902            }
903            "image" | "audio" => {
904                serde_json::json!({
905                    "type": item.content_type,
906                    "data": item.data,
907                    "mimeType": item.mime_type,
908                })
909            }
910            _ => serde_json::to_value(item).unwrap_or(Value::Null),
911        }
912    } else {
913        // Multiple content items — return as array
914        let items: Vec<Value> = result
915            .content
916            .iter()
917            .map(|c| serde_json::to_value(c).unwrap_or(Value::Null))
918            .collect();
919
920        serde_json::json!({
921            "content": items,
922            "isError": result.is_error,
923        })
924    }
925}
926
927// ---------------------------------------------------------------------------
928// Shared MCP tool discovery (used by both CLI and proxy)
929// ---------------------------------------------------------------------------
930
931/// Discover tools from all MCP providers concurrently and register them.
932///
933/// Each provider has a 30-second timeout. Failures are logged and skipped.
934/// Returns the number of tools discovered.
935pub async fn discover_all_mcp_tools(
936    registry: &mut crate::core::manifest::ManifestRegistry,
937    keyring: &Keyring,
938) -> usize {
939    use futures::stream::{self, StreamExt};
940
941    let providers: Vec<_> = registry
942        .list_mcp_providers()
943        .into_iter()
944        .map(|p| (p.name.clone(), p.clone()))
945        .collect();
946
947    if providers.is_empty() {
948        return 0;
949    }
950
951    // Discover concurrently (up to 10 at a time), with per-provider timeout
952    let results: Vec<_> = stream::iter(&providers)
953        .map(|(name, provider)| async move {
954            let result = tokio::time::timeout(
955                std::time::Duration::from_secs(30),
956                discover_one_provider(name, provider, keyring),
957            )
958            .await;
959
960            match result {
961                Ok(Ok(tools)) => Some((name.clone(), tools)),
962                Ok(Err(e)) => {
963                    tracing::warn!(provider = %name, error = %e, "MCP tool discovery failed");
964                    None
965                }
966                Err(_) => {
967                    tracing::warn!(provider = %name, "MCP tool discovery timed out (30s)");
968                    None
969                }
970            }
971        })
972        .buffer_unordered(10)
973        .collect()
974        .await;
975
976    // Register discovered tools (sequential — fast, just index inserts)
977    let mut total = 0;
978    for (name, tool_defs) in results.into_iter().flatten() {
979        let count = tool_defs.len();
980        registry.register_mcp_tools(&name, tool_defs);
981        tracing::info!(provider = %name, tools = count, "discovered MCP tools");
982        total += count;
983    }
984    total
985}
986
987/// Discover tools from a single MCP provider.
988async fn discover_one_provider(
989    _name: &str,
990    provider: &Provider,
991    keyring: &Keyring,
992) -> Result<Vec<crate::core::manifest::McpToolDef>, McpError> {
993    let client = McpClient::connect(provider, keyring).await?;
994    let tools = client.list_tools().await;
995    client.disconnect().await;
996
997    let tools = tools?;
998    Ok(tools
999        .into_iter()
1000        .map(|t| crate::core::manifest::McpToolDef {
1001            name: t.name,
1002            description: t.description,
1003            input_schema: t.input_schema,
1004        })
1005        .collect())
1006}
1007
1008// ---------------------------------------------------------------------------
1009// Tests
1010// ---------------------------------------------------------------------------
1011
#[cfg(test)]
mod tests {
    // Unit tests for the pure helpers in this module: placeholder resolution,
    // MCP result conversion, JSON-RPC result extraction and SSE data parsing.
    use super::*;

    #[test]
    fn test_resolve_env_value_keyring() {
        let keyring = Keyring::empty();
        // No key in keyring — should return the raw value
        assert_eq!(
            resolve_env_value("${missing_key}", &keyring),
            "${missing_key}"
        );
        // Plain value — no resolution
        assert_eq!(resolve_env_value("plain_value", &keyring), "plain_value");
        // Empty input stays empty
        assert_eq!(resolve_env_value("", &keyring), "");
    }

    #[test]
    fn test_resolve_env_value_inline() {
        // Build a keyring with a test key via load_credentials
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("creds");
        std::fs::write(&path, r#"{"my_key":"SECRET123"}"#).unwrap();
        let keyring = Keyring::load_credentials(&path).unwrap();

        // Whole-string
        assert_eq!(resolve_env_value("${my_key}", &keyring), "SECRET123");
        // Inline
        assert_eq!(
            resolve_env_value("https://example.com/${my_key}/path", &keyring),
            "https://example.com/SECRET123/path"
        );
        // Multiple placeholders
        assert_eq!(
            resolve_env_value("${my_key}--${my_key}", &keyring),
            "SECRET123--SECRET123"
        );
        // Missing key stays as-is
        assert_eq!(
            resolve_env_value("https://example.com/${unknown}/path", &keyring),
            "https://example.com/${unknown}/path"
        );
        // No placeholder
        assert_eq!(
            resolve_env_value("no_placeholder", &keyring),
            "no_placeholder"
        );
    }

    #[test]
    fn test_mcp_result_to_value_single_text() {
        let result = McpToolResult {
            content: vec![McpContent {
                content_type: "text".into(),
                text: Some("hello world".into()),
                data: None,
                mime_type: None,
            }],
            is_error: false,
        };
        assert_eq!(
            mcp_result_to_value(&result),
            Value::String("hello world".into())
        );
    }

    #[test]
    fn test_mcp_result_to_value_json_text() {
        let result = McpToolResult {
            content: vec![McpContent {
                content_type: "text".into(),
                text: Some(r#"{"key":"value"}"#.into()),
                data: None,
                mime_type: None,
            }],
            is_error: false,
        };
        let val = mcp_result_to_value(&result);
        assert_eq!(val, serde_json::json!({"key": "value"}));
    }

    #[test]
    fn test_mcp_result_to_value_single_text_without_text() {
        // A single "text" item with no text payload unwraps to null.
        let result = McpToolResult {
            content: vec![McpContent {
                content_type: "text".into(),
                text: None,
                data: None,
                mime_type: None,
            }],
            is_error: false,
        };
        assert_eq!(mcp_result_to_value(&result), Value::Null);
    }

    #[test]
    fn test_mcp_result_to_value_single_image() {
        // A single image item is reshaped into {type, data, mimeType}.
        let result = McpToolResult {
            content: vec![McpContent {
                content_type: "image".into(),
                text: None,
                data: None,
                mime_type: None,
            }],
            is_error: false,
        };
        assert_eq!(
            mcp_result_to_value(&result),
            serde_json::json!({"type": "image", "data": null, "mimeType": null})
        );
    }

    #[test]
    fn test_extract_jsonrpc_result_success() {
        let msg = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {"tools": []}
        });
        let result = extract_jsonrpc_result(&msg, 1).unwrap();
        assert_eq!(result, serde_json::json!({"tools": []}));
    }

    #[test]
    fn test_extract_jsonrpc_result_error() {
        let msg = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {"code": -32602, "message": "Invalid params"}
        });
        let err = extract_jsonrpc_result(&msg, 1).unwrap_err();
        assert!(matches!(err, McpError::Protocol { code: -32602, .. }));
    }

    #[test]
    fn test_process_sse_data_matching_response() {
        let data = r#"{"jsonrpc":"2.0","id":5,"result":{"tools":[]}}"#;
        match process_sse_data(data, 5) {
            SseParseResult::OurResponse(Ok(val)) => {
                assert_eq!(val, serde_json::json!({"tools": []}));
            }
            _ => panic!("Expected OurResponse"),
        }
    }

    #[test]
    fn test_process_sse_data_notification() {
        // Notifications don't have "id" — should be skipped
        let data = r#"{"jsonrpc":"2.0","method":"notifications/tools/list_changed"}"#;
        match process_sse_data(data, 5) {
            SseParseResult::NotOurMessage => {}
            _ => panic!("Expected NotOurMessage"),
        }
    }

    #[test]
    fn test_process_sse_data_batch() {
        let data = r#"[
            {"jsonrpc":"2.0","method":"notifications/progress","params":{}},
            {"jsonrpc":"2.0","id":3,"result":{"content":[],"isError":false}}
        ]"#;
        match process_sse_data(data, 3) {
            SseParseResult::OurResponse(Ok(val)) => {
                assert!(val.get("content").is_some());
            }
            _ => panic!("Expected OurResponse from batch"),
        }
    }

    #[test]
    fn test_process_sse_data_invalid_json() {
        let data = "not valid json {{{}";
        match process_sse_data(data, 1) {
            SseParseResult::ParseError(_) => {}
            other => panic!("Expected ParseError, got: {other:?}"),
        }
    }

    #[test]
    fn test_process_sse_data_wrong_id() {
        let data = r#"{"jsonrpc":"2.0","id":99,"result":{"data":"wrong"}}"#;
        match process_sse_data(data, 1) {
            SseParseResult::NotOurMessage => {}
            _ => panic!("Expected NotOurMessage for wrong ID"),
        }
    }

    #[test]
    fn test_process_sse_data_empty_batch() {
        let data = "[]";
        match process_sse_data(data, 1) {
            SseParseResult::NotOurMessage => {}
            _ => panic!("Expected NotOurMessage for empty batch"),
        }
    }

    #[test]
    fn test_extract_jsonrpc_result_missing_result() {
        let msg = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1
        });
        let err = extract_jsonrpc_result(&msg, 1).unwrap_err();
        assert!(matches!(err, McpError::Protocol { code: -1, .. }));
    }

    #[test]
    fn test_extract_jsonrpc_error_defaults() {
        // Error with missing code and message fields
        let msg = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {}
        });
        let err = extract_jsonrpc_result(&msg, 1).unwrap_err();
        match err {
            McpError::Protocol { code, message } => {
                assert_eq!(code, -1);
                assert_eq!(message, "Unknown error");
            }
            _ => panic!("Expected Protocol error"),
        }
    }

    #[test]
    fn test_mcp_result_to_value_error() {
        let result = McpToolResult {
            content: vec![McpContent {
                content_type: "text".into(),
                text: Some("Something went wrong".into()),
                data: None,
                mime_type: None,
            }],
            is_error: true,
        };
        let val = mcp_result_to_value(&result);
        // Single-item results are unwrapped; is_error is not carried over.
        assert_eq!(val, Value::String("Something went wrong".into()));
    }

    #[test]
    fn test_mcp_result_to_value_multiple_content() {
        let result = McpToolResult {
            content: vec![
                McpContent {
                    content_type: "text".into(),
                    text: Some("Part 1".into()),
                    data: None,
                    mime_type: None,
                },
                McpContent {
                    content_type: "text".into(),
                    text: Some("Part 2".into()),
                    data: None,
                    mime_type: None,
                },
            ],
            is_error: false,
        };
        let val = mcp_result_to_value(&result);
        // Multiple items → {"content": [...], "isError": false}
        let content_arr = val["content"].as_array().unwrap();
        assert_eq!(content_arr.len(), 2);
        assert_eq!(val["isError"], false);
    }

    #[test]
    fn test_mcp_result_to_value_empty_content() {
        let result = McpToolResult {
            content: vec![],
            is_error: false,
        };
        let val = mcp_result_to_value(&result);
        // Empty content → {"content": [], "isError": false}
        assert_eq!(val["content"].as_array().unwrap().len(), 0);
        assert_eq!(val["isError"], false);
    }

    #[test]
    fn test_resolve_env_value_unclosed_brace() {
        let keyring = Keyring::empty();
        assert_eq!(resolve_env_value("${unclosed", &keyring), "${unclosed");
    }
}