Skip to main content

ati/core/
mcp_client.rs

//! MCP client — connects to MCP servers via stdio or Streamable HTTP transport.
//!
//! Implements the MCP protocol (2025-03-26 revision):
//! - JSON-RPC 2.0 message framing
//! - stdio transport: newline-delimited JSON over stdin/stdout
//! - Streamable HTTP transport: POST with `Accept: application/json, text/event-stream`.
//!   The server may respond with plain JSON or an SSE stream; `Mcp-Session-Id` is used for sessions.
//! - Lifecycle: initialize → tools/list → tools/call → shutdown
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::io::{BufRead, BufReader, Write as IoWrite};
13use std::process::{Child, Command, Stdio};
14use std::sync::atomic::{AtomicU64, Ordering};
15use thiserror::Error;
16use tokio::sync::Mutex;
17
18use crate::core::auth_generator::{self, AuthCache, GenContext};
19use crate::core::keyring::Keyring;
20use crate::core::manifest::Provider;
21
22// ---------------------------------------------------------------------------
23// Errors
24// ---------------------------------------------------------------------------
25
26#[derive(Error, Debug)]
27#[allow(dead_code)]
28pub enum McpError {
29    #[error("MCP transport error: {0}")]
30    Transport(String),
31    #[error("MCP protocol error (code {code}): {message}")]
32    Protocol { code: i64, message: String },
33    #[error("MCP server did not return tools capability")]
34    NoToolsCapability,
35    #[error("IO error: {0}")]
36    Io(#[from] std::io::Error),
37    #[error("JSON error: {0}")]
38    Json(#[from] serde_json::Error),
39    #[error("HTTP error: {0}")]
40    Http(#[from] reqwest::Error),
41    #[error("MCP initialization failed: {0}")]
42    InitFailed(String),
43    #[error("SSE parse error: {0}")]
44    SseParse(String),
45    #[error("MCP server process exited unexpectedly")]
46    ProcessExited,
47    #[error("Missing MCP configuration: {0}")]
48    Config(String),
49}
50
51// ---------------------------------------------------------------------------
52// JSON-RPC types
53// ---------------------------------------------------------------------------
54
55#[derive(Debug, Serialize)]
56struct JsonRpcRequest {
57    jsonrpc: &'static str,
58    id: u64,
59    method: String,
60    #[serde(skip_serializing_if = "Option::is_none")]
61    params: Option<Value>,
62}
63
64// Note: We parse JSON-RPC responses manually via serde_json::Value
65// rather than typed deserialization, since responses can be interleaved
66// with notifications and batches in SSE streams.
67
68// ---------------------------------------------------------------------------
69// MCP protocol types
70// ---------------------------------------------------------------------------
71
72/// Tool definition from MCP tools/list response.
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct McpToolDef {
75    pub name: String,
76    #[serde(default)]
77    pub description: Option<String>,
78    #[serde(default, rename = "inputSchema")]
79    pub input_schema: Option<Value>,
80}
81
82/// Content item from MCP tools/call response.
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct McpContent {
85    #[serde(rename = "type")]
86    pub content_type: String,
87    #[serde(default)]
88    pub text: Option<String>,
89    #[serde(default)]
90    pub data: Option<String>,
91    #[serde(default, rename = "mimeType")]
92    pub mime_type: Option<String>,
93}
94
95/// Result from tools/call.
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct McpToolResult {
98    pub content: Vec<McpContent>,
99    #[serde(default, rename = "isError")]
100    pub is_error: bool,
101}
102
103// ---------------------------------------------------------------------------
104// Transport abstraction
105// ---------------------------------------------------------------------------
106
107/// Internal transport enum — stdio or HTTP.
108enum Transport {
109    Stdio(StdioTransport),
110    Http(HttpTransport),
111}
112
/// State for a server reached through a spawned child process.
struct StdioTransport {
    /// Handle to the spawned server process.
    child: Child,
    /// Write side for outgoing JSON-RPC lines. Kept in an `Option` so that
    /// disconnecting can `take()` and drop it, closing the child's stdin.
    stdin: Option<std::process::ChildStdin>,
    /// Buffered read side delivering newline-delimited JSON-RPC messages.
    reader: BufReader<std::process::ChildStdout>,
}
121
122/// Streamable HTTP transport: POST to MCP endpoint.
123struct HttpTransport {
124    client: reqwest::Client,
125    url: String,
126    /// Session ID from Mcp-Session-Id header (set after initialize).
127    session_id: Option<String>,
128    /// Auth header name (default: "Authorization"). Custom for APIs using e.g. "x-api-key".
129    auth_header_name: String,
130    /// Auth header value (e.g., "Bearer <token>") injected on every request.
131    auth_header: Option<String>,
132    /// Extra headers from provider config.
133    extra_headers: HashMap<String, String>,
134}
135
136// ---------------------------------------------------------------------------
137// McpClient
138// ---------------------------------------------------------------------------
139
140/// MCP client that connects to a single MCP server.
141pub struct McpClient {
142    transport: Mutex<Transport>,
143    next_id: AtomicU64,
144    /// Cached tools from tools/list.
145    cached_tools: Mutex<Option<Vec<McpToolDef>>>,
146    /// Provider name (for logging).
147    provider_name: String,
148}
149
150impl McpClient {
151    /// Connect to an MCP server based on the provider's configuration.
152    ///
153    /// For stdio: spawns the subprocess with env vars resolved from keyring.
154    /// For HTTP: creates an HTTP client with auth headers.
155    pub async fn connect(provider: &Provider, keyring: &Keyring) -> Result<Self, McpError> {
156        Self::connect_with_gen(provider, keyring, None, None).await
157    }
158
159    /// Connect to an MCP server, optionally using a dynamic auth generator.
160    pub async fn connect_with_gen(
161        provider: &Provider,
162        keyring: &Keyring,
163        gen_ctx: Option<&GenContext>,
164        auth_cache: Option<&AuthCache>,
165    ) -> Result<Self, McpError> {
166        let transport = match provider.mcp_transport_type() {
167            "stdio" => {
168                let command = provider.mcp_command.as_deref().ok_or_else(|| {
169                    McpError::Config("mcp_command required for stdio transport".into())
170                })?;
171
172                // Resolve env vars: "${key_name}" → keyring value
173                let mut env_map: HashMap<String, String> = HashMap::new();
174                // Selectively pass through essential env vars (don't leak secrets)
175                if let Ok(path) = std::env::var("PATH") {
176                    env_map.insert("PATH".to_string(), path);
177                }
178                if let Ok(home) = std::env::var("HOME") {
179                    env_map.insert("HOME".to_string(), home);
180                }
181                // Add provider-specific env vars (resolved from keyring)
182                for (k, v) in &provider.mcp_env {
183                    let resolved = resolve_env_value(v, keyring);
184                    env_map.insert(k.clone(), resolved);
185                }
186
187                // If auth_generator is configured, run it and inject into env
188                if let Some(gen) = &provider.auth_generator {
189                    let default_ctx = GenContext::default();
190                    let ctx = gen_ctx.unwrap_or(&default_ctx);
191                    let default_cache = AuthCache::new();
192                    let cache = auth_cache.unwrap_or(&default_cache);
193                    match auth_generator::generate(provider, gen, ctx, keyring, cache).await {
194                        Ok(cred) => {
195                            env_map.insert("ATI_AUTH_TOKEN".to_string(), cred.value);
196                            for (k, v) in &cred.extra_env {
197                                env_map.insert(k.clone(), v.clone());
198                            }
199                        }
200                        Err(e) => {
201                            return Err(McpError::Config(format!("auth_generator failed: {e}")));
202                        }
203                    }
204                }
205
206                let mut child = Command::new(command)
207                    .args(&provider.mcp_args)
208                    .stdin(Stdio::piped())
209                    .stdout(Stdio::piped())
210                    .stderr(Stdio::piped())
211                    .env_clear()
212                    .envs(&env_map)
213                    .spawn()
214                    .map_err(|e| {
215                        McpError::Transport(format!("Failed to spawn MCP server '{command}': {e}"))
216                    })?;
217
218                let stdin = child
219                    .stdin
220                    .take()
221                    .ok_or_else(|| McpError::Transport("No stdin".into()))?;
222                let stdout = child
223                    .stdout
224                    .take()
225                    .ok_or_else(|| McpError::Transport("No stdout".into()))?;
226                let reader = BufReader::new(stdout);
227
228                Transport::Stdio(StdioTransport {
229                    child,
230                    stdin: Some(stdin),
231                    reader,
232                })
233            }
234            "http" => {
235                let url = provider.mcp_url.as_deref().ok_or_else(|| {
236                    McpError::Config("mcp_url required for HTTP transport".into())
237                })?;
238
239                // Build auth header: generator takes priority over static keyring
240                let auth_header = if let Some(gen) = &provider.auth_generator {
241                    let default_ctx = GenContext::default();
242                    let ctx = gen_ctx.unwrap_or(&default_ctx);
243                    let default_cache = AuthCache::new();
244                    let cache = auth_cache.unwrap_or(&default_cache);
245                    match auth_generator::generate(provider, gen, ctx, keyring, cache).await {
246                        Ok(cred) => match &provider.auth_type {
247                            super::manifest::AuthType::Bearer => {
248                                Some(format!("Bearer {}", cred.value))
249                            }
250                            super::manifest::AuthType::Header => {
251                                if let Some(prefix) = &provider.auth_value_prefix {
252                                    Some(format!("{prefix}{}", cred.value))
253                                } else {
254                                    Some(cred.value)
255                                }
256                            }
257                            _ => Some(cred.value),
258                        },
259                        Err(e) => {
260                            return Err(McpError::Config(format!("auth_generator failed: {e}")));
261                        }
262                    }
263                } else {
264                    build_auth_header(provider, keyring)
265                };
266
267                let client = reqwest::Client::builder()
268                    .timeout(std::time::Duration::from_secs(300))
269                    .build()?;
270
271                // Resolve ${key_name} placeholders in the URL from keyring
272                let resolved_url = resolve_env_value(url, keyring);
273
274                let auth_header_name = provider
275                    .auth_header_name
276                    .clone()
277                    .unwrap_or_else(|| "Authorization".to_string());
278
279                Transport::Http(HttpTransport {
280                    client,
281                    url: resolved_url,
282                    session_id: None,
283                    auth_header_name,
284                    auth_header,
285                    extra_headers: provider.extra_headers.clone(),
286                })
287            }
288            other => {
289                return Err(McpError::Config(format!(
290                    "Unknown MCP transport: '{other}' (expected 'stdio' or 'http')"
291                )));
292            }
293        };
294
295        let client = McpClient {
296            transport: Mutex::new(transport),
297            next_id: AtomicU64::new(1),
298            cached_tools: Mutex::new(None),
299            provider_name: provider.name.clone(),
300        };
301
302        // Perform MCP initialize handshake
303        client.initialize().await?;
304
305        Ok(client)
306    }
307
308    /// Perform the MCP initialize handshake.
309    async fn initialize(&self) -> Result<(), McpError> {
310        let params = serde_json::json!({
311            "protocolVersion": "2025-03-26",
312            "capabilities": {},
313            "clientInfo": {
314                "name": "ati",
315                "version": env!("CARGO_PKG_VERSION")
316            }
317        });
318
319        let response = self.send_request("initialize", Some(params)).await?;
320
321        // Verify server has tools capability
322        let capabilities = response.get("capabilities").unwrap_or(&Value::Null);
323        if capabilities.get("tools").is_none() {
324            return Err(McpError::NoToolsCapability);
325        }
326
327        // Extract session ID from HTTP transport response (handled inside send_request)
328
329        // Send initialized notification
330        self.send_notification("notifications/initialized", None)
331            .await?;
332
333        Ok(())
334    }
335
336    /// Discover tools via tools/list. Results are cached.
337    pub async fn list_tools(&self) -> Result<Vec<McpToolDef>, McpError> {
338        // Return cached if available
339        {
340            let cache = self.cached_tools.lock().await;
341            if let Some(tools) = cache.as_ref() {
342                return Ok(tools.clone());
343            }
344        }
345
346        let mut all_tools = Vec::new();
347        let mut cursor: Option<String> = None;
348        const MAX_PAGES: usize = 100;
349        const MAX_TOOLS: usize = 10_000;
350
351        for _page in 0..MAX_PAGES {
352            let params = cursor.as_ref().map(|c| serde_json::json!({"cursor": c}));
353            let result = self.send_request("tools/list", params).await?;
354
355            if let Some(tools_val) = result.get("tools") {
356                let tools: Vec<McpToolDef> = serde_json::from_value(tools_val.clone())?;
357                all_tools.extend(tools);
358            }
359
360            // Safety: cap total tools to prevent memory exhaustion
361            if all_tools.len() > MAX_TOOLS {
362                tracing::warn!(max = MAX_TOOLS, "MCP tool count exceeds limit, truncating");
363                all_tools.truncate(MAX_TOOLS);
364                break;
365            }
366
367            // Check for pagination
368            match result.get("nextCursor").and_then(|v| v.as_str()) {
369                Some(next) => cursor = Some(next.to_string()),
370                None => break,
371            }
372        }
373
374        // Cache the result
375        {
376            let mut cache = self.cached_tools.lock().await;
377            *cache = Some(all_tools.clone());
378        }
379
380        Ok(all_tools)
381    }
382
383    /// Execute a tool via tools/call.
384    pub async fn call_tool(
385        &self,
386        name: &str,
387        arguments: HashMap<String, Value>,
388    ) -> Result<McpToolResult, McpError> {
389        let params = serde_json::json!({
390            "name": name,
391            "arguments": arguments,
392        });
393
394        let result = self.send_request("tools/call", Some(params)).await?;
395        let tool_result: McpToolResult = serde_json::from_value(result)?;
396        Ok(tool_result)
397    }
398
399    /// Disconnect from the MCP server.
400    pub async fn disconnect(&self) {
401        let mut transport = self.transport.lock().await;
402        match &mut *transport {
403            Transport::Stdio(stdio) => {
404                // Take ownership of stdin to drop it (signals EOF to child).
405                // After this, stdin is None and the child should exit.
406                let _ = stdio.stdin.take();
407                // Try graceful shutdown, then kill
408                let _ = stdio.child.kill();
409                let _ = stdio.child.wait();
410            }
411            Transport::Http(http) => {
412                // Send HTTP DELETE to terminate session if we have a session ID
413                if let Some(session_id) = &http.session_id {
414                    let mut req = http.client.delete(&http.url);
415                    req = req.header("Mcp-Session-Id", session_id.as_str());
416                    let _ = req.send().await;
417                }
418            }
419        }
420    }
421
422    /// Invalidate cached tools (e.g., after tools/list_changed notification).
423    pub async fn invalidate_cache(&self) {
424        let mut cache = self.cached_tools.lock().await;
425        *cache = None;
426    }
427
428    // -----------------------------------------------------------------------
429    // Internal: send JSON-RPC request and receive response
430    // -----------------------------------------------------------------------
431
432    async fn send_request(&self, method: &str, params: Option<Value>) -> Result<Value, McpError> {
433        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
434        let request = JsonRpcRequest {
435            jsonrpc: "2.0",
436            id,
437            method: method.to_string(),
438            params,
439        };
440
441        let mut transport = self.transport.lock().await;
442        match &mut *transport {
443            Transport::Stdio(stdio) => send_stdio_request(stdio, &request).await,
444            Transport::Http(http) => send_http_request(http, &request, &self.provider_name).await,
445        }
446    }
447
448    async fn send_notification(&self, method: &str, params: Option<Value>) -> Result<(), McpError> {
449        let mut notification = serde_json::json!({
450            "jsonrpc": "2.0",
451            "method": method,
452        });
453        if let Some(p) = params {
454            notification["params"] = p;
455        }
456
457        let mut transport = self.transport.lock().await;
458        match &mut *transport {
459            Transport::Stdio(stdio) => {
460                let stdin = stdio
461                    .stdin
462                    .as_mut()
463                    .ok_or_else(|| McpError::Transport("stdin closed".into()))?;
464                let msg = serde_json::to_string(&notification)?;
465                stdin.write_all(msg.as_bytes())?;
466                stdin.write_all(b"\n")?;
467                stdin.flush()?;
468                Ok(())
469            }
470            Transport::Http(http) => {
471                let mut req = http
472                    .client
473                    .post(&http.url)
474                    .header("Content-Type", "application/json")
475                    .header("Accept", "application/json, text/event-stream")
476                    .json(&notification);
477
478                if let Some(session_id) = &http.session_id {
479                    req = req.header("Mcp-Session-Id", session_id.as_str());
480                }
481                if let Some(auth) = &http.auth_header {
482                    req = req.header(http.auth_header_name.as_str(), auth.as_str());
483                }
484                for (name, value) in &http.extra_headers {
485                    req = req.header(name.as_str(), value.as_str());
486                }
487
488                let resp = req.send().await?;
489                // Notifications should get 202 Accepted
490                if !resp.status().is_success() {
491                    let status = resp.status().as_u16();
492                    let body = resp.text().await.unwrap_or_default();
493                    return Err(McpError::Transport(format!("HTTP {status}: {body}")));
494                }
495                Ok(())
496            }
497        }
498    }
499}
500
501// ---------------------------------------------------------------------------
502// Stdio transport I/O
503// ---------------------------------------------------------------------------
504
505/// Send a JSON-RPC request over stdio and read the response.
506/// Messages are newline-delimited JSON (no embedded newlines).
507async fn send_stdio_request(
508    stdio: &mut StdioTransport,
509    request: &JsonRpcRequest,
510) -> Result<Value, McpError> {
511    let stdin = stdio
512        .stdin
513        .as_mut()
514        .ok_or_else(|| McpError::Transport("stdin closed".into()))?;
515
516    // Serialize and send (newline-delimited)
517    let msg = serde_json::to_string(request)?;
518    stdin.write_all(msg.as_bytes())?;
519    stdin.write_all(b"\n")?;
520    stdin.flush()?;
521
522    let request_id = request.id;
523
524    // Read lines until we get a response matching our request ID.
525    // We may receive notifications interleaved — skip them.
526    loop {
527        let mut line = String::new();
528        let bytes_read = stdio.reader.read_line(&mut line)?;
529        if bytes_read == 0 {
530            return Err(McpError::ProcessExited);
531        }
532
533        let line = line.trim();
534        if line.is_empty() {
535            continue;
536        }
537
538        // Try to parse as JSON-RPC response
539        let parsed: Value = serde_json::from_str(line)?;
540
541        // Check if it's a response (has "id" field matching ours)
542        if let Some(id) = parsed.get("id") {
543            let id_matches = match id {
544                Value::Number(n) => n.as_u64() == Some(request_id),
545                _ => false,
546            };
547
548            if id_matches {
549                // It's our response
550                if let Some(err) = parsed.get("error") {
551                    let code = err.get("code").and_then(|c| c.as_i64()).unwrap_or(-1);
552                    let message = err
553                        .get("message")
554                        .and_then(|m| m.as_str())
555                        .unwrap_or("Unknown error");
556                    return Err(McpError::Protocol {
557                        code,
558                        message: message.to_string(),
559                    });
560                }
561
562                return parsed
563                    .get("result")
564                    .cloned()
565                    .ok_or_else(|| McpError::Protocol {
566                        code: -1,
567                        message: "Response missing 'result' field".into(),
568                    });
569            }
570        }
571
572        // Not our response — it's a notification or someone else's response; skip it.
573    }
574}
575
576// ---------------------------------------------------------------------------
577// HTTP Streamable transport I/O
578// ---------------------------------------------------------------------------
579
580/// Send a JSON-RPC request over Streamable HTTP.
581///
582/// Per MCP spec (2025-03-26):
583/// - POST with Accept: application/json, text/event-stream
584/// - Server may respond with Content-Type: application/json (single response)
585///   or Content-Type: text/event-stream (SSE stream with one or more messages)
586/// - Must handle Mcp-Session-Id header for session management
587async fn send_http_request(
588    http: &mut HttpTransport,
589    request: &JsonRpcRequest,
590    provider_name: &str,
591) -> Result<Value, McpError> {
592    let mut req = http
593        .client
594        .post(&http.url)
595        .header("Content-Type", "application/json")
596        .header("Accept", "application/json, text/event-stream")
597        .json(request);
598
599    // Attach session ID if we have one
600    if let Some(session_id) = &http.session_id {
601        req = req.header("Mcp-Session-Id", session_id.as_str());
602    }
603
604    // Inject auth (using custom header name if configured, e.g. "x-api-key")
605    if let Some(auth) = &http.auth_header {
606        req = req.header(http.auth_header_name.as_str(), auth.as_str());
607    }
608
609    // Inject extra headers from provider config
610    for (name, value) in &http.extra_headers {
611        req = req.header(name.as_str(), value.as_str());
612    }
613
614    let response = req
615        .send()
616        .await
617        .map_err(|e| McpError::Transport(format!("[{provider_name}] HTTP request failed: {e}")))?;
618
619    // Capture session ID from response header (usually set during initialize)
620    if let Some(session_val) = response.headers().get("mcp-session-id") {
621        if let Ok(sid) = session_val.to_str() {
622            http.session_id = Some(sid.to_string());
623        }
624    }
625
626    let status = response.status();
627    if !status.is_success() {
628        let body = response.text().await.unwrap_or_default();
629        return Err(McpError::Transport(format!(
630            "[{provider_name}] HTTP {}: {body}",
631            status.as_u16()
632        )));
633    }
634
635    // Determine response type from Content-Type header
636    let content_type = response
637        .headers()
638        .get("content-type")
639        .and_then(|v| v.to_str().ok())
640        .unwrap_or("")
641        .to_lowercase();
642
643    if content_type.contains("text/event-stream") {
644        // SSE stream — parse events to extract our JSON-RPC response
645        parse_sse_response(response, request.id).await
646    } else {
647        // Plain JSON response
648        let body: Value = response.json().await?;
649        extract_jsonrpc_result(&body, request.id)
650    }
651}
652
653/// Parse an SSE stream from an HTTP response, collecting JSON-RPC messages
654/// until we find the response matching our request ID.
655///
656/// SSE format (per HTML spec):
657///   event: message\n
658///   data: {"jsonrpc":"2.0","id":1,"result":{...}}\n
659///   \n
660///
661/// Each `data:` line contains a JSON-RPC message. The `event:` field is optional.
662/// We may receive notifications and server requests before getting our response.
663/// Maximum SSE response body size (50 MB) to prevent OOM from malicious servers.
664const MAX_SSE_BODY_SIZE: usize = 50 * 1024 * 1024;
665
666async fn parse_sse_response(
667    response: reqwest::Response,
668    request_id: u64,
669) -> Result<Value, McpError> {
670    // Enforce size limit on SSE stream body
671    let bytes = response
672        .bytes()
673        .await
674        .map_err(|e| McpError::SseParse(format!("Failed to read SSE stream: {e}")))?;
675    if bytes.len() > MAX_SSE_BODY_SIZE {
676        return Err(McpError::SseParse(format!(
677            "SSE response body exceeds maximum size ({} bytes > {} bytes)",
678            bytes.len(),
679            MAX_SSE_BODY_SIZE,
680        )));
681    }
682    let full_body = String::from_utf8_lossy(&bytes).into_owned();
683
684    // Parse SSE events
685    let mut current_data = String::new();
686
687    for line in full_body.lines() {
688        if line.starts_with("data:") {
689            let data = line.strip_prefix("data:").unwrap().trim();
690            if !data.is_empty() {
691                current_data.push_str(data);
692            }
693        } else if line.is_empty() && !current_data.is_empty() {
694            // End of event — process the accumulated data
695            match process_sse_data(&current_data, request_id) {
696                SseParseResult::OurResponse(result) => return result,
697                SseParseResult::NotOurMessage => {}
698                SseParseResult::ParseError(e) => {
699                    tracing::warn!(error = %e, "failed to parse SSE data");
700                }
701            }
702            current_data.clear();
703        }
704        // Lines starting with "event:", "id:", "retry:", or ":" are SSE metadata — skip
705    }
706
707    // Handle any remaining data that wasn't terminated by a blank line
708    if !current_data.is_empty() {
709        if let SseParseResult::OurResponse(result) = process_sse_data(&current_data, request_id) {
710            return result;
711        }
712    }
713
714    Err(McpError::SseParse(
715        "SSE stream ended without receiving a response for our request".into(),
716    ))
717}
718
719#[derive(Debug)]
720enum SseParseResult {
721    OurResponse(Result<Value, McpError>),
722    NotOurMessage,
723    ParseError(String),
724}
725
726fn process_sse_data(data: &str, request_id: u64) -> SseParseResult {
727    let parsed: Value = match serde_json::from_str(data) {
728        Ok(v) => v,
729        Err(e) => return SseParseResult::ParseError(e.to_string()),
730    };
731
732    // Could be a single message or a batch (array)
733    let messages = if parsed.is_array() {
734        parsed.as_array().unwrap().clone()
735    } else {
736        vec![parsed]
737    };
738
739    for msg in messages {
740        // Check if it's a response matching our ID
741        if let Some(id) = msg.get("id") {
742            let id_matches = match id {
743                Value::Number(n) => n.as_u64() == Some(request_id),
744                _ => false,
745            };
746            if id_matches {
747                return SseParseResult::OurResponse(extract_jsonrpc_result(&msg, request_id));
748            }
749        }
750        // Otherwise it's a notification or request from server — skip
751    }
752
753    SseParseResult::NotOurMessage
754}
755
756/// Extract the result (or error) from a JSON-RPC response message.
757fn extract_jsonrpc_result(msg: &Value, _request_id: u64) -> Result<Value, McpError> {
758    if let Some(err) = msg.get("error") {
759        let code = err.get("code").and_then(|c| c.as_i64()).unwrap_or(-1);
760        let message = err
761            .get("message")
762            .and_then(|m| m.as_str())
763            .unwrap_or("Unknown error");
764        return Err(McpError::Protocol {
765            code,
766            message: message.to_string(),
767        });
768    }
769
770    msg.get("result")
771        .cloned()
772        .ok_or_else(|| McpError::Protocol {
773            code: -1,
774            message: "Response missing 'result' field".into(),
775        })
776}
777
778// ---------------------------------------------------------------------------
779// Helpers
780// ---------------------------------------------------------------------------
781
782/// Resolve "${key_name}" placeholders in values from the keyring.
783/// Supports both whole-string (`${key}`) and inline (`prefix/${key}/suffix`) patterns.
784fn resolve_env_value(value: &str, keyring: &Keyring) -> String {
785    let mut result = value.to_string();
786    // Find all ${...} patterns and replace them
787    while let Some(start) = result.find("${") {
788        let rest = &result[start + 2..];
789        if let Some(end) = rest.find('}') {
790            let key_name = &rest[..end];
791            let replacement = keyring.get(key_name).unwrap_or("");
792            if replacement.is_empty() && keyring.get(key_name).is_none() {
793                // Key not found — leave the placeholder as-is to avoid breaking the string
794                break;
795            }
796            result = format!("{}{}{}", &result[..start], replacement, &rest[end + 1..]);
797        } else {
798            break; // No closing brace — stop
799        }
800    }
801    result
802}
803
804/// Build an Authorization header value from the provider's auth config.
805fn build_auth_header(provider: &Provider, keyring: &Keyring) -> Option<String> {
806    let key_name = provider.auth_key_name.as_deref()?;
807    let key_value = keyring.get(key_name)?;
808
809    match &provider.auth_type {
810        super::manifest::AuthType::Bearer => Some(format!("Bearer {key_value}")),
811        super::manifest::AuthType::Header => {
812            // For header auth with a custom prefix
813            if let Some(prefix) = &provider.auth_value_prefix {
814                Some(format!("{prefix}{key_value}"))
815            } else {
816                Some(key_value.to_string())
817            }
818        }
819        super::manifest::AuthType::Basic => {
820            let encoded = base64::Engine::encode(
821                &base64::engine::general_purpose::STANDARD,
822                format!("{key_value}:"),
823            );
824            Some(format!("Basic {encoded}"))
825        }
826        _ => None,
827    }
828}
829
830// ---------------------------------------------------------------------------
831// High-level execute function for the call dispatch
832// ---------------------------------------------------------------------------
833
834/// Execute an MCP tool call — high-level entry point for cli/call.rs dispatch.
835///
836/// 1. Connects to the MCP server (or reuses connection via cache — future optimization)
837/// 2. Strips the provider prefix from the tool name (e.g., "github:read_file" → "read_file")
838/// 3. Calls tools/call with the raw MCP tool name
839/// 4. Returns the result as a serde_json::Value
840pub async fn execute(
841    provider: &Provider,
842    tool_name: &str,
843    args: &HashMap<String, Value>,
844    keyring: &Keyring,
845) -> Result<Value, McpError> {
846    execute_with_gen(provider, tool_name, args, keyring, None, None).await
847}
848
849/// Execute an MCP tool call with optional dynamic auth generator.
850pub async fn execute_with_gen(
851    provider: &Provider,
852    tool_name: &str,
853    args: &HashMap<String, Value>,
854    keyring: &Keyring,
855    gen_ctx: Option<&GenContext>,
856    auth_cache: Option<&AuthCache>,
857) -> Result<Value, McpError> {
858    let client = McpClient::connect_with_gen(provider, keyring, gen_ctx, auth_cache).await?;
859
860    // Strip provider prefix: "github:read_file" → "read_file"
861    let mcp_tool_name = tool_name
862        .strip_prefix(&format!(
863            "{}{}",
864            provider.name,
865            crate::core::manifest::TOOL_SEP_STR
866        ))
867        .unwrap_or(tool_name);
868
869    let result = client.call_tool(mcp_tool_name, args.clone()).await?;
870
871    // Convert MCP tool result to a single Value for ATI's output system
872    let value = mcp_result_to_value(&result);
873
874    // Clean up
875    client.disconnect().await;
876
877    Ok(value)
878}
879
880/// Convert an McpToolResult to a serde_json::Value.
881fn mcp_result_to_value(result: &McpToolResult) -> Value {
882    if result.content.len() == 1 {
883        // Single content item — unwrap for cleaner output
884        let item = &result.content[0];
885        match item.content_type.as_str() {
886            "text" => {
887                if let Some(text) = &item.text {
888                    // Try to parse as JSON (many MCP tools return JSON as text)
889                    serde_json::from_str(text).unwrap_or_else(|_| Value::String(text.clone()))
890                } else {
891                    Value::Null
892                }
893            }
894            "image" | "audio" => {
895                serde_json::json!({
896                    "type": item.content_type,
897                    "data": item.data,
898                    "mimeType": item.mime_type,
899                })
900            }
901            _ => serde_json::to_value(item).unwrap_or(Value::Null),
902        }
903    } else {
904        // Multiple content items — return as array
905        let items: Vec<Value> = result
906            .content
907            .iter()
908            .map(|c| serde_json::to_value(c).unwrap_or(Value::Null))
909            .collect();
910
911        serde_json::json!({
912            "content": items,
913            "isError": result.is_error,
914        })
915    }
916}
917
918// ---------------------------------------------------------------------------
919// Tests
920// ---------------------------------------------------------------------------
921
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `text` content item with no binary data or MIME type.
    fn text_content(text: &str) -> McpContent {
        McpContent {
            content_type: "text".into(),
            text: Some(text.into()),
            data: None,
            mime_type: None,
        }
    }

    /// Build a tool result from its content items and error flag.
    fn tool_result(content: Vec<McpContent>, is_error: bool) -> McpToolResult {
        McpToolResult { content, is_error }
    }

    // -- resolve_env_value ---------------------------------------------------

    #[test]
    fn test_resolve_env_value_keyring() {
        let keyring = Keyring::empty();
        // Unknown key: the placeholder comes back unchanged.
        assert_eq!(
            resolve_env_value("${missing_key}", &keyring),
            "${missing_key}"
        );
        // Nothing to resolve in a plain value.
        assert_eq!(resolve_env_value("plain_value", &keyring), "plain_value");
    }

    #[test]
    fn test_resolve_env_value_inline() {
        // Load a keyring holding one secret through the regular credentials loader.
        let tmp = tempfile::TempDir::new().unwrap();
        let creds = tmp.path().join("creds");
        std::fs::write(&creds, r#"{"my_key":"SECRET123"}"#).unwrap();
        let keyring = Keyring::load_credentials(&creds).unwrap();

        let cases = [
            // Whole-string placeholder
            ("${my_key}", "SECRET123"),
            // Inline placeholder
            (
                "https://example.com/${my_key}/path",
                "https://example.com/SECRET123/path",
            ),
            // Several placeholders in one value
            ("${my_key}--${my_key}", "SECRET123--SECRET123"),
            // Unknown key is left as written
            (
                "https://example.com/${unknown}/path",
                "https://example.com/${unknown}/path",
            ),
            // No placeholder at all
            ("no_placeholder", "no_placeholder"),
        ];
        for (input, expected) in cases {
            assert_eq!(resolve_env_value(input, &keyring), expected, "input: {input}");
        }
    }

    #[test]
    fn test_resolve_env_value_unclosed_brace() {
        let keyring = Keyring::empty();
        assert_eq!(resolve_env_value("${unclosed", &keyring), "${unclosed");
    }

    // -- mcp_result_to_value -------------------------------------------------

    #[test]
    fn test_mcp_result_to_value_single_text() {
        let result = tool_result(vec![text_content("hello world")], false);
        assert_eq!(
            mcp_result_to_value(&result),
            Value::String("hello world".into())
        );
    }

    #[test]
    fn test_mcp_result_to_value_json_text() {
        // Text that parses as JSON comes back as structured JSON.
        let result = tool_result(vec![text_content(r#"{"key":"value"}"#)], false);
        assert_eq!(
            mcp_result_to_value(&result),
            serde_json::json!({"key": "value"})
        );
    }

    #[test]
    fn test_mcp_result_to_value_error() {
        // A single text item is unwrapped even when the result is flagged as an error.
        let result = tool_result(vec![text_content("Something went wrong")], true);
        assert_eq!(
            mcp_result_to_value(&result),
            Value::String("Something went wrong".into())
        );
    }

    #[test]
    fn test_mcp_result_to_value_multiple_content() {
        let result = tool_result(vec![text_content("Part 1"), text_content("Part 2")], false);
        let val = mcp_result_to_value(&result);
        // Several items → {"content": [...], "isError": false}
        assert_eq!(val["content"].as_array().unwrap().len(), 2);
        assert_eq!(val["isError"], false);
    }

    #[test]
    fn test_mcp_result_to_value_empty_content() {
        let val = mcp_result_to_value(&tool_result(Vec::new(), false));
        // No items → {"content": [], "isError": false}
        assert_eq!(val["content"].as_array().unwrap().len(), 0);
        assert_eq!(val["isError"], false);
    }

    // -- extract_jsonrpc_result ----------------------------------------------

    #[test]
    fn test_extract_jsonrpc_result_success() {
        let msg = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {"tools": []}
        });
        assert_eq!(
            extract_jsonrpc_result(&msg, 1).unwrap(),
            serde_json::json!({"tools": []})
        );
    }

    #[test]
    fn test_extract_jsonrpc_result_error() {
        let msg = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {"code": -32602, "message": "Invalid params"}
        });
        assert!(matches!(
            extract_jsonrpc_result(&msg, 1),
            Err(McpError::Protocol { code: -32602, .. })
        ));
    }

    #[test]
    fn test_extract_jsonrpc_result_missing_result() {
        let msg = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1
        });
        assert!(matches!(
            extract_jsonrpc_result(&msg, 1),
            Err(McpError::Protocol { code: -1, .. })
        ));
    }

    #[test]
    fn test_extract_jsonrpc_error_defaults() {
        // An error object without code/message falls back to -1 / "Unknown error".
        let msg = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {}
        });
        match extract_jsonrpc_result(&msg, 1) {
            Err(McpError::Protocol { code, message }) => {
                assert_eq!(code, -1);
                assert_eq!(message, "Unknown error");
            }
            other => panic!("Expected Protocol error, got: {other:?}"),
        }
    }

    // -- process_sse_data ----------------------------------------------------

    #[test]
    fn test_process_sse_data_matching_response() {
        let data = r#"{"jsonrpc":"2.0","id":5,"result":{"tools":[]}}"#;
        match process_sse_data(data, 5) {
            SseParseResult::OurResponse(Ok(val)) => {
                assert_eq!(val, serde_json::json!({"tools": []}));
            }
            other => panic!("Expected OurResponse, got: {other:?}"),
        }
    }

    #[test]
    fn test_process_sse_data_notification() {
        // Notifications carry no "id", so they are never taken as our response.
        let data = r#"{"jsonrpc":"2.0","method":"notifications/tools/list_changed"}"#;
        assert!(
            matches!(process_sse_data(data, 5), SseParseResult::NotOurMessage),
            "Expected NotOurMessage"
        );
    }

    #[test]
    fn test_process_sse_data_batch() {
        // Our response is found even when preceded by a notification in a batch.
        let data = r#"[
            {"jsonrpc":"2.0","method":"notifications/progress","params":{}},
            {"jsonrpc":"2.0","id":3,"result":{"content":[],"isError":false}}
        ]"#;
        match process_sse_data(data, 3) {
            SseParseResult::OurResponse(Ok(val)) => assert!(val.get("content").is_some()),
            other => panic!("Expected OurResponse from batch, got: {other:?}"),
        }
    }

    #[test]
    fn test_process_sse_data_invalid_json() {
        let data = "not valid json {{{}";
        let parsed = process_sse_data(data, 1);
        assert!(
            matches!(parsed, SseParseResult::ParseError(_)),
            "Expected ParseError, got: {parsed:?}"
        );
    }

    #[test]
    fn test_process_sse_data_wrong_id() {
        let data = r#"{"jsonrpc":"2.0","id":99,"result":{"data":"wrong"}}"#;
        assert!(
            matches!(process_sse_data(data, 1), SseParseResult::NotOurMessage),
            "Expected NotOurMessage for wrong ID"
        );
    }

    #[test]
    fn test_process_sse_data_empty_batch() {
        assert!(
            matches!(process_sse_data("[]", 1), SseParseResult::NotOurMessage),
            "Expected NotOurMessage for empty batch"
        );
    }
}