1use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::io::{BufRead, BufReader, Write as IoWrite};
13use std::process::{Child, Command, Stdio};
14use std::sync::atomic::{AtomicU64, Ordering};
15use thiserror::Error;
16use tokio::sync::Mutex;
17
18use crate::core::auth_generator::{self, AuthCache, GenContext};
19use crate::core::keyring::Keyring;
20use crate::core::manifest::Provider;
21
22#[derive(Error, Debug)]
27#[allow(dead_code)]
28pub enum McpError {
29 #[error("MCP transport error: {0}")]
30 Transport(String),
31 #[error("MCP protocol error (code {code}): {message}")]
32 Protocol { code: i64, message: String },
33 #[error("MCP server did not return tools capability")]
34 NoToolsCapability,
35 #[error("IO error: {0}")]
36 Io(#[from] std::io::Error),
37 #[error("JSON error: {0}")]
38 Json(#[from] serde_json::Error),
39 #[error("HTTP error: {0}")]
40 Http(#[from] reqwest::Error),
41 #[error("MCP initialization failed: {0}")]
42 InitFailed(String),
43 #[error("SSE parse error: {0}")]
44 SseParse(String),
45 #[error("MCP server process exited unexpectedly")]
46 ProcessExited,
47 #[error("Missing MCP configuration: {0}")]
48 Config(String),
49}
50
51#[derive(Debug, Serialize)]
56struct JsonRpcRequest {
57 jsonrpc: &'static str,
58 id: u64,
59 method: String,
60 #[serde(skip_serializing_if = "Option::is_none")]
61 params: Option<Value>,
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct McpToolDef {
75 pub name: String,
76 #[serde(default)]
77 pub description: Option<String>,
78 #[serde(default, rename = "inputSchema")]
79 pub input_schema: Option<Value>,
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct McpContent {
85 #[serde(rename = "type")]
86 pub content_type: String,
87 #[serde(default)]
88 pub text: Option<String>,
89 #[serde(default)]
90 pub data: Option<String>,
91 #[serde(default, rename = "mimeType")]
92 pub mime_type: Option<String>,
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct McpToolResult {
98 pub content: Vec<McpContent>,
99 #[serde(default, rename = "isError")]
100 pub is_error: bool,
101}
102
103enum Transport {
109 Stdio(StdioTransport),
110 Http(HttpTransport),
111}
112
113struct StdioTransport {
115 child: Child,
116 stdin: Option<std::process::ChildStdin>,
118 reader: BufReader<std::process::ChildStdout>,
120}
121
122impl Drop for StdioTransport {
123 fn drop(&mut self) {
124 let _ = self.child.kill();
127 let _ = self.child.wait();
128 }
129}
130
131struct HttpTransport {
133 client: reqwest::Client,
134 url: String,
135 session_id: Option<String>,
137 auth_header_name: String,
139 auth_header: Option<String>,
141 extra_headers: HashMap<String, String>,
143}
144
145pub struct McpClient {
151 transport: Mutex<Transport>,
152 next_id: AtomicU64,
153 cached_tools: Mutex<Option<Vec<McpToolDef>>>,
155 provider_name: String,
157}
158
159impl McpClient {
160 pub async fn connect(provider: &Provider, keyring: &Keyring) -> Result<Self, McpError> {
165 Self::connect_with_gen(provider, keyring, None, None, None).await
166 }
167
168 pub async fn connect_with_gen(
176 provider: &Provider,
177 keyring: &Keyring,
178 gen_ctx: Option<&GenContext>,
179 auth_cache: Option<&AuthCache>,
180 override_mcp_url: Option<&str>,
181 ) -> Result<Self, McpError> {
182 let transport = match provider.mcp_transport_type() {
183 "stdio" => {
184 let command = provider.mcp_command.as_deref().ok_or_else(|| {
185 McpError::Config("mcp_command required for stdio transport".into())
186 })?;
187
188 let mut env_map: HashMap<String, String> = HashMap::new();
190 if let Ok(path) = std::env::var("PATH") {
192 env_map.insert("PATH".to_string(), path);
193 }
194 if let Ok(home) = std::env::var("HOME") {
195 env_map.insert("HOME".to_string(), home);
196 }
197 for (k, v) in &provider.mcp_env {
199 let resolved = resolve_env_value(v, keyring);
200 env_map.insert(k.clone(), resolved);
201 }
202
203 if let Some(gen) = &provider.auth_generator {
205 let default_ctx = GenContext::default();
206 let ctx = gen_ctx.unwrap_or(&default_ctx);
207 let default_cache = AuthCache::new();
208 let cache = auth_cache.unwrap_or(&default_cache);
209 match auth_generator::generate(provider, gen, ctx, keyring, cache).await {
210 Ok(cred) => {
211 env_map.insert("ATI_AUTH_TOKEN".to_string(), cred.value);
212 for (k, v) in &cred.extra_env {
213 env_map.insert(k.clone(), v.clone());
214 }
215 }
216 Err(e) => {
217 return Err(McpError::Config(format!("auth_generator failed: {e}")));
218 }
219 }
220 }
221
222 let mut child = Command::new(command)
223 .args(&provider.mcp_args)
224 .stdin(Stdio::piped())
225 .stdout(Stdio::piped())
226 .stderr(Stdio::piped())
227 .env_clear()
228 .envs(&env_map)
229 .spawn()
230 .map_err(|e| {
231 McpError::Transport(format!("Failed to spawn MCP server '{command}': {e}"))
232 })?;
233
234 let stdin = child
235 .stdin
236 .take()
237 .ok_or_else(|| McpError::Transport("No stdin".into()))?;
238 let stdout = child
239 .stdout
240 .take()
241 .ok_or_else(|| McpError::Transport("No stdout".into()))?;
242 let reader = BufReader::new(stdout);
243
244 Transport::Stdio(StdioTransport {
245 child,
246 stdin: Some(stdin),
247 reader,
248 })
249 }
250 "http" => {
251 let url = override_mcp_url
256 .or(provider.mcp_url.as_deref())
257 .ok_or_else(|| {
258 McpError::Config("mcp_url required for HTTP transport".into())
259 })?;
260
261 let auth_header = if let Some(gen) = &provider.auth_generator {
263 let default_ctx = GenContext::default();
264 let ctx = gen_ctx.unwrap_or(&default_ctx);
265 let default_cache = AuthCache::new();
266 let cache = auth_cache.unwrap_or(&default_cache);
267 match auth_generator::generate(provider, gen, ctx, keyring, cache).await {
268 Ok(cred) => match &provider.auth_type {
269 super::manifest::AuthType::Bearer => {
270 Some(format!("Bearer {}", cred.value))
271 }
272 super::manifest::AuthType::Header => {
273 if let Some(prefix) = &provider.auth_value_prefix {
274 Some(format!("{prefix}{}", cred.value))
275 } else {
276 Some(cred.value)
277 }
278 }
279 _ => Some(cred.value),
280 },
281 Err(e) => {
282 return Err(McpError::Config(format!("auth_generator failed: {e}")));
283 }
284 }
285 } else {
286 build_auth_header(provider, keyring)
287 };
288
289 let client = reqwest::Client::builder()
290 .timeout(std::time::Duration::from_secs(300))
291 .build()?;
292
293 let resolved_url = resolve_env_value(url, keyring);
295
296 let auth_header_name = provider
297 .auth_header_name
298 .clone()
299 .unwrap_or_else(|| "Authorization".to_string());
300
301 Transport::Http(HttpTransport {
302 client,
303 url: resolved_url,
304 session_id: None,
305 auth_header_name,
306 auth_header,
307 extra_headers: provider.extra_headers.clone(),
308 })
309 }
310 other => {
311 return Err(McpError::Config(format!(
312 "Unknown MCP transport: '{other}' (expected 'stdio' or 'http')"
313 )));
314 }
315 };
316
317 let client = McpClient {
318 transport: Mutex::new(transport),
319 next_id: AtomicU64::new(1),
320 cached_tools: Mutex::new(None),
321 provider_name: provider.name.clone(),
322 };
323
324 client.initialize().await?;
326
327 Ok(client)
328 }
329
330 async fn initialize(&self) -> Result<(), McpError> {
332 let params = serde_json::json!({
333 "protocolVersion": "2025-03-26",
334 "capabilities": {},
335 "clientInfo": {
336 "name": "ati",
337 "version": env!("CARGO_PKG_VERSION")
338 }
339 });
340
341 let response = self.send_request("initialize", Some(params)).await?;
342
343 let capabilities = response.get("capabilities").unwrap_or(&Value::Null);
345 if capabilities.get("tools").is_none() {
346 return Err(McpError::NoToolsCapability);
347 }
348
349 self.send_notification("notifications/initialized", None)
353 .await?;
354
355 Ok(())
356 }
357
358 pub async fn list_tools(&self) -> Result<Vec<McpToolDef>, McpError> {
360 {
362 let cache = self.cached_tools.lock().await;
363 if let Some(tools) = cache.as_ref() {
364 return Ok(tools.clone());
365 }
366 }
367
368 let mut all_tools = Vec::new();
369 let mut cursor: Option<String> = None;
370 const MAX_PAGES: usize = 100;
371 const MAX_TOOLS: usize = 10_000;
372
373 for _page in 0..MAX_PAGES {
374 let params = cursor.as_ref().map(|c| serde_json::json!({"cursor": c}));
375 let result = self.send_request("tools/list", params).await?;
376
377 if let Some(tools_val) = result.get("tools") {
378 let tools: Vec<McpToolDef> = serde_json::from_value(tools_val.clone())?;
379 all_tools.extend(tools);
380 }
381
382 if all_tools.len() > MAX_TOOLS {
384 tracing::warn!(max = MAX_TOOLS, "MCP tool count exceeds limit, truncating");
385 all_tools.truncate(MAX_TOOLS);
386 break;
387 }
388
389 match result.get("nextCursor").and_then(|v| v.as_str()) {
391 Some(next) => cursor = Some(next.to_string()),
392 None => break,
393 }
394 }
395
396 {
398 let mut cache = self.cached_tools.lock().await;
399 *cache = Some(all_tools.clone());
400 }
401
402 Ok(all_tools)
403 }
404
405 pub async fn call_tool(
407 &self,
408 name: &str,
409 arguments: HashMap<String, Value>,
410 ) -> Result<McpToolResult, McpError> {
411 let params = serde_json::json!({
412 "name": name,
413 "arguments": arguments,
414 });
415
416 let result = self.send_request("tools/call", Some(params)).await?;
417 let tool_result: McpToolResult = serde_json::from_value(result)?;
418 Ok(tool_result)
419 }
420
421 pub async fn disconnect(&self) {
423 let mut transport = self.transport.lock().await;
424 match &mut *transport {
425 Transport::Stdio(stdio) => {
426 let _ = stdio.stdin.take();
429 let _ = stdio.child.kill();
431 let _ = stdio.child.wait();
432 }
433 Transport::Http(http) => {
434 if let Some(session_id) = &http.session_id {
436 let mut req = http.client.delete(&http.url);
437 req = req.header("Mcp-Session-Id", session_id.as_str());
438 let _ = req.send().await;
439 }
440 }
441 }
442 }
443
444 pub async fn invalidate_cache(&self) {
446 let mut cache = self.cached_tools.lock().await;
447 *cache = None;
448 }
449
450 async fn send_request(&self, method: &str, params: Option<Value>) -> Result<Value, McpError> {
455 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
456 let request = JsonRpcRequest {
457 jsonrpc: "2.0",
458 id,
459 method: method.to_string(),
460 params,
461 };
462
463 let mut transport = self.transport.lock().await;
464 match &mut *transport {
465 Transport::Stdio(stdio) => send_stdio_request(stdio, &request).await,
466 Transport::Http(http) => send_http_request(http, &request, &self.provider_name).await,
467 }
468 }
469
470 async fn send_notification(&self, method: &str, params: Option<Value>) -> Result<(), McpError> {
471 let mut notification = serde_json::json!({
472 "jsonrpc": "2.0",
473 "method": method,
474 });
475 if let Some(p) = params {
476 notification["params"] = p;
477 }
478
479 let mut transport = self.transport.lock().await;
480 match &mut *transport {
481 Transport::Stdio(stdio) => {
482 let stdin = stdio
483 .stdin
484 .as_mut()
485 .ok_or_else(|| McpError::Transport("stdin closed".into()))?;
486 let msg = serde_json::to_string(¬ification)?;
487 stdin.write_all(msg.as_bytes())?;
488 stdin.write_all(b"\n")?;
489 stdin.flush()?;
490 Ok(())
491 }
492 Transport::Http(http) => {
493 let mut req = http
494 .client
495 .post(&http.url)
496 .header("Content-Type", "application/json")
497 .header("Accept", "application/json, text/event-stream")
498 .json(¬ification);
499
500 if let Some(session_id) = &http.session_id {
501 req = req.header("Mcp-Session-Id", session_id.as_str());
502 }
503 if let Some(auth) = &http.auth_header {
504 req = req.header(http.auth_header_name.as_str(), auth.as_str());
505 }
506 for (name, value) in &http.extra_headers {
507 req = req.header(name.as_str(), value.as_str());
508 }
509
510 let resp = req.send().await?;
511 if !resp.status().is_success() {
513 let status = resp.status().as_u16();
514 let body = resp.text().await.unwrap_or_default();
515 return Err(McpError::Transport(format!("HTTP {status}: {body}")));
516 }
517 Ok(())
518 }
519 }
520 }
521}
522
523async fn send_stdio_request(
530 stdio: &mut StdioTransport,
531 request: &JsonRpcRequest,
532) -> Result<Value, McpError> {
533 let stdin = stdio
534 .stdin
535 .as_mut()
536 .ok_or_else(|| McpError::Transport("stdin closed".into()))?;
537
538 let msg = serde_json::to_string(request)?;
540 stdin.write_all(msg.as_bytes())?;
541 stdin.write_all(b"\n")?;
542 stdin.flush()?;
543
544 let request_id = request.id;
545
546 loop {
549 let mut line = String::new();
550 let bytes_read = stdio.reader.read_line(&mut line)?;
551 if bytes_read == 0 {
552 return Err(McpError::ProcessExited);
553 }
554
555 let line = line.trim();
556 if line.is_empty() {
557 continue;
558 }
559
560 let parsed: Value = serde_json::from_str(line)?;
562
563 if let Some(id) = parsed.get("id") {
565 let id_matches = match id {
566 Value::Number(n) => n.as_u64() == Some(request_id),
567 _ => false,
568 };
569
570 if id_matches {
571 if let Some(err) = parsed.get("error") {
573 let code = err.get("code").and_then(|c| c.as_i64()).unwrap_or(-1);
574 let message = err
575 .get("message")
576 .and_then(|m| m.as_str())
577 .unwrap_or("Unknown error");
578 return Err(McpError::Protocol {
579 code,
580 message: message.to_string(),
581 });
582 }
583
584 return parsed
585 .get("result")
586 .cloned()
587 .ok_or_else(|| McpError::Protocol {
588 code: -1,
589 message: "Response missing 'result' field".into(),
590 });
591 }
592 }
593
594 }
596}
597
598async fn send_http_request(
610 http: &mut HttpTransport,
611 request: &JsonRpcRequest,
612 provider_name: &str,
613) -> Result<Value, McpError> {
614 let mut req = http
615 .client
616 .post(&http.url)
617 .header("Content-Type", "application/json")
618 .header("Accept", "application/json, text/event-stream")
619 .json(request);
620
621 if let Some(session_id) = &http.session_id {
623 req = req.header("Mcp-Session-Id", session_id.as_str());
624 }
625
626 if let Some(auth) = &http.auth_header {
628 req = req.header(http.auth_header_name.as_str(), auth.as_str());
629 }
630
631 for (name, value) in &http.extra_headers {
633 req = req.header(name.as_str(), value.as_str());
634 }
635
636 let response = req
637 .send()
638 .await
639 .map_err(|e| McpError::Transport(format!("[{provider_name}] HTTP request failed: {e}")))?;
640
641 if let Some(session_val) = response.headers().get("mcp-session-id") {
643 if let Ok(sid) = session_val.to_str() {
644 http.session_id = Some(sid.to_string());
645 }
646 }
647
648 let status = response.status();
649 if !status.is_success() {
650 let body = response.text().await.unwrap_or_default();
651 return Err(McpError::Transport(format!(
652 "[{provider_name}] HTTP {}: {body}",
653 status.as_u16()
654 )));
655 }
656
657 let content_type = response
659 .headers()
660 .get("content-type")
661 .and_then(|v| v.to_str().ok())
662 .unwrap_or("")
663 .to_lowercase();
664
665 if content_type.contains("text/event-stream") {
666 parse_sse_response(response, request.id).await
668 } else {
669 let body: Value = response.json().await?;
671 extract_jsonrpc_result(&body, request.id)
672 }
673}
674
675const MAX_SSE_BODY_SIZE: usize = 50 * 1024 * 1024;
687
688async fn parse_sse_response(
689 response: reqwest::Response,
690 request_id: u64,
691) -> Result<Value, McpError> {
692 let bytes = response
694 .bytes()
695 .await
696 .map_err(|e| McpError::SseParse(format!("Failed to read SSE stream: {e}")))?;
697 if bytes.len() > MAX_SSE_BODY_SIZE {
698 return Err(McpError::SseParse(format!(
699 "SSE response body exceeds maximum size ({} bytes > {} bytes)",
700 bytes.len(),
701 MAX_SSE_BODY_SIZE,
702 )));
703 }
704 let full_body = String::from_utf8_lossy(&bytes).into_owned();
705
706 let mut current_data = String::new();
708
709 for line in full_body.lines() {
710 if line.starts_with("data:") {
711 let data = line.strip_prefix("data:").unwrap().trim();
712 if !data.is_empty() {
713 current_data.push_str(data);
714 }
715 } else if line.is_empty() && !current_data.is_empty() {
716 match process_sse_data(¤t_data, request_id) {
718 SseParseResult::OurResponse(result) => return result,
719 SseParseResult::NotOurMessage => {}
720 SseParseResult::ParseError(e) => {
721 tracing::warn!(error = %e, "failed to parse SSE data");
722 }
723 }
724 current_data.clear();
725 }
726 }
728
729 if !current_data.is_empty() {
731 if let SseParseResult::OurResponse(result) = process_sse_data(¤t_data, request_id) {
732 return result;
733 }
734 }
735
736 Err(McpError::SseParse(
737 "SSE stream ended without receiving a response for our request".into(),
738 ))
739}
740
741#[derive(Debug)]
742enum SseParseResult {
743 OurResponse(Result<Value, McpError>),
744 NotOurMessage,
745 ParseError(String),
746}
747
748fn process_sse_data(data: &str, request_id: u64) -> SseParseResult {
749 let parsed: Value = match serde_json::from_str(data) {
750 Ok(v) => v,
751 Err(e) => return SseParseResult::ParseError(e.to_string()),
752 };
753
754 let messages = if parsed.is_array() {
756 parsed.as_array().unwrap().clone()
757 } else {
758 vec![parsed]
759 };
760
761 for msg in messages {
762 if let Some(id) = msg.get("id") {
764 let id_matches = match id {
765 Value::Number(n) => n.as_u64() == Some(request_id),
766 _ => false,
767 };
768 if id_matches {
769 return SseParseResult::OurResponse(extract_jsonrpc_result(&msg, request_id));
770 }
771 }
772 }
774
775 SseParseResult::NotOurMessage
776}
777
778fn extract_jsonrpc_result(msg: &Value, _request_id: u64) -> Result<Value, McpError> {
780 if let Some(err) = msg.get("error") {
781 let code = err.get("code").and_then(|c| c.as_i64()).unwrap_or(-1);
782 let message = err
783 .get("message")
784 .and_then(|m| m.as_str())
785 .unwrap_or("Unknown error");
786 return Err(McpError::Protocol {
787 code,
788 message: message.to_string(),
789 });
790 }
791
792 msg.get("result")
793 .cloned()
794 .ok_or_else(|| McpError::Protocol {
795 code: -1,
796 message: "Response missing 'result' field".into(),
797 })
798}
799
800fn resolve_env_value(value: &str, keyring: &Keyring) -> String {
807 let mut result = value.to_string();
808 while let Some(start) = result.find("${") {
810 let rest = &result[start + 2..];
811 if let Some(end) = rest.find('}') {
812 let key_name = &rest[..end];
813 let replacement = keyring.get(key_name).unwrap_or("");
814 if replacement.is_empty() && keyring.get(key_name).is_none() {
815 break;
817 }
818 result = format!("{}{}{}", &result[..start], replacement, &rest[end + 1..]);
819 } else {
820 break; }
822 }
823 result
824}
825
826fn build_auth_header(provider: &Provider, keyring: &Keyring) -> Option<String> {
828 let key_name = provider.auth_key_name.as_deref()?;
829 let key_value = keyring.get(key_name)?;
830
831 match &provider.auth_type {
832 super::manifest::AuthType::Bearer => Some(format!("Bearer {key_value}")),
833 super::manifest::AuthType::Header => {
834 if let Some(prefix) = &provider.auth_value_prefix {
836 Some(format!("{prefix}{key_value}"))
837 } else {
838 Some(key_value.to_string())
839 }
840 }
841 super::manifest::AuthType::Basic => {
842 let encoded = base64::Engine::encode(
843 &base64::engine::general_purpose::STANDARD,
844 format!("{key_value}:"),
845 );
846 Some(format!("Basic {encoded}"))
847 }
848 _ => None,
849 }
850}
851
852pub async fn execute(
863 provider: &Provider,
864 tool_name: &str,
865 args: &HashMap<String, Value>,
866 keyring: &Keyring,
867) -> Result<Value, McpError> {
868 execute_with_gen(provider, tool_name, args, keyring, None, None, None).await
869}
870
871pub async fn execute_with_gen(
878 provider: &Provider,
879 tool_name: &str,
880 args: &HashMap<String, Value>,
881 keyring: &Keyring,
882 gen_ctx: Option<&GenContext>,
883 auth_cache: Option<&AuthCache>,
884 override_mcp_url: Option<&str>,
885) -> Result<Value, McpError> {
886 let client =
887 McpClient::connect_with_gen(provider, keyring, gen_ctx, auth_cache, override_mcp_url)
888 .await?;
889
890 let mcp_tool_name = tool_name
892 .strip_prefix(&format!(
893 "{}{}",
894 provider.name,
895 crate::core::manifest::TOOL_SEP_STR
896 ))
897 .unwrap_or(tool_name);
898
899 let result = client.call_tool(mcp_tool_name, args.clone()).await?;
900
901 let value = mcp_result_to_value(&result);
903
904 client.disconnect().await;
906
907 Ok(value)
908}
909
910fn mcp_result_to_value(result: &McpToolResult) -> Value {
912 if result.content.len() == 1 {
913 let item = &result.content[0];
915 match item.content_type.as_str() {
916 "text" => {
917 if let Some(text) = &item.text {
918 serde_json::from_str(text).unwrap_or_else(|_| Value::String(text.clone()))
920 } else {
921 Value::Null
922 }
923 }
924 "image" | "audio" => {
925 serde_json::json!({
926 "type": item.content_type,
927 "data": item.data,
928 "mimeType": item.mime_type,
929 })
930 }
931 _ => serde_json::to_value(item).unwrap_or(Value::Null),
932 }
933 } else {
934 let items: Vec<Value> = result
936 .content
937 .iter()
938 .map(|c| serde_json::to_value(c).unwrap_or(Value::Null))
939 .collect();
940
941 serde_json::json!({
942 "content": items,
943 "isError": result.is_error,
944 })
945 }
946}
947
948pub async fn discover_all_mcp_tools(
957 registry: &mut crate::core::manifest::ManifestRegistry,
958 keyring: &Keyring,
959) -> usize {
960 use futures::stream::{self, StreamExt};
961
962 let providers: Vec<_> = registry
963 .list_mcp_providers()
964 .into_iter()
965 .map(|p| (p.name.clone(), p.clone()))
966 .collect();
967
968 if providers.is_empty() {
969 return 0;
970 }
971
972 let results: Vec<_> = stream::iter(&providers)
974 .map(|(name, provider)| async move {
975 let result = tokio::time::timeout(
976 std::time::Duration::from_secs(30),
977 discover_one_provider(name, provider, keyring),
978 )
979 .await;
980
981 match result {
982 Ok(Ok(tools)) => Some((name.clone(), tools)),
983 Ok(Err(e)) => {
984 tracing::warn!(provider = %name, error = %e, "MCP tool discovery failed");
985 None
986 }
987 Err(_) => {
988 tracing::warn!(provider = %name, "MCP tool discovery timed out (30s)");
989 None
990 }
991 }
992 })
993 .buffer_unordered(10)
994 .collect()
995 .await;
996
997 let mut total = 0;
999 for (name, tool_defs) in results.into_iter().flatten() {
1000 let count = tool_defs.len();
1001 registry.register_mcp_tools(&name, tool_defs);
1002 tracing::info!(provider = %name, tools = count, "discovered MCP tools");
1003 total += count;
1004 }
1005 total
1006}
1007
1008async fn discover_one_provider(
1010 _name: &str,
1011 provider: &Provider,
1012 keyring: &Keyring,
1013) -> Result<Vec<crate::core::manifest::McpToolDef>, McpError> {
1014 let client = McpClient::connect(provider, keyring).await?;
1015 let tools = client.list_tools().await;
1016 client.disconnect().await;
1017
1018 let tools = tools?;
1019 Ok(tools
1020 .into_iter()
1021 .map(|t| crate::core::manifest::McpToolDef {
1022 name: t.name,
1023 description: t.description,
1024 input_schema: t.input_schema,
1025 })
1026 .collect())
1027}
1028
#[cfg(test)]
mod tests {
    //! Unit tests for the pure helpers: placeholder expansion, result
    //! flattening, JSON-RPC response extraction and SSE payload handling.
    use super::*;

    /// A single `text` content item.
    fn text_item(text: &str) -> McpContent {
        McpContent {
            content_type: "text".into(),
            text: Some(text.into()),
            data: None,
            mime_type: None,
        }
    }

    /// A tool result wrapping `content`.
    fn tool_result(content: Vec<McpContent>, is_error: bool) -> McpToolResult {
        McpToolResult { content, is_error }
    }

    #[test]
    fn test_resolve_env_value_keyring() {
        let keyring = Keyring::empty();
        // An unresolvable placeholder comes back verbatim.
        assert_eq!(resolve_env_value("${missing_key}", &keyring), "${missing_key}");
        assert_eq!(resolve_env_value("plain_value", &keyring), "plain_value");
    }

    #[test]
    fn test_resolve_env_value_inline() {
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("creds");
        std::fs::write(&path, r#"{"my_key":"SECRET123"}"#).unwrap();
        let keyring = Keyring::load_credentials(&path).unwrap();

        let cases = [
            ("${my_key}", "SECRET123"),
            (
                "https://example.com/${my_key}/path",
                "https://example.com/SECRET123/path",
            ),
            ("${my_key}--${my_key}", "SECRET123--SECRET123"),
            (
                "https://example.com/${unknown}/path",
                "https://example.com/${unknown}/path",
            ),
            ("no_placeholder", "no_placeholder"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(resolve_env_value(input, &keyring), expected);
        }
    }

    #[test]
    fn test_mcp_result_to_value_single_text() {
        let result = tool_result(vec![text_item("hello world")], false);
        assert_eq!(
            mcp_result_to_value(&result),
            Value::String("hello world".into())
        );
    }

    #[test]
    fn test_mcp_result_to_value_json_text() {
        let result = tool_result(vec![text_item(r#"{"key":"value"}"#)], false);
        assert_eq!(
            mcp_result_to_value(&result),
            serde_json::json!({"key": "value"})
        );
    }

    #[test]
    fn test_extract_jsonrpc_result_success() {
        let msg = serde_json::json!({"jsonrpc": "2.0", "id": 1, "result": {"tools": []}});
        let result = extract_jsonrpc_result(&msg, 1).unwrap();
        assert_eq!(result, serde_json::json!({"tools": []}));
    }

    #[test]
    fn test_extract_jsonrpc_result_error() {
        let msg = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {"code": -32602, "message": "Invalid params"}
        });
        let err = extract_jsonrpc_result(&msg, 1).unwrap_err();
        assert!(matches!(err, McpError::Protocol { code: -32602, .. }));
    }

    #[test]
    fn test_process_sse_data_matching_response() {
        let data = r#"{"jsonrpc":"2.0","id":5,"result":{"tools":[]}}"#;
        match process_sse_data(data, 5) {
            SseParseResult::OurResponse(Ok(val)) => {
                assert_eq!(val, serde_json::json!({"tools": []}));
            }
            _ => panic!("Expected OurResponse"),
        }
    }

    #[test]
    fn test_process_sse_data_notification() {
        // A notification has no id, so it can never be our response.
        let data = r#"{"jsonrpc":"2.0","method":"notifications/tools/list_changed"}"#;
        match process_sse_data(data, 5) {
            SseParseResult::NotOurMessage => {}
            _ => panic!("Expected NotOurMessage"),
        }
    }

    #[test]
    fn test_process_sse_data_batch() {
        let data = r#"[
            {"jsonrpc":"2.0","method":"notifications/progress","params":{}},
            {"jsonrpc":"2.0","id":3,"result":{"content":[],"isError":false}}
        ]"#;
        match process_sse_data(data, 3) {
            SseParseResult::OurResponse(Ok(val)) => assert!(val.get("content").is_some()),
            _ => panic!("Expected OurResponse from batch"),
        }
    }

    #[test]
    fn test_process_sse_data_invalid_json() {
        match process_sse_data("not valid json {{{}", 1) {
            SseParseResult::ParseError(_) => {}
            other => panic!("Expected ParseError, got: {other:?}"),
        }
    }

    #[test]
    fn test_process_sse_data_wrong_id() {
        let data = r#"{"jsonrpc":"2.0","id":99,"result":{"data":"wrong"}}"#;
        match process_sse_data(data, 1) {
            SseParseResult::NotOurMessage => {}
            _ => panic!("Expected NotOurMessage for wrong ID"),
        }
    }

    #[test]
    fn test_process_sse_data_empty_batch() {
        match process_sse_data("[]", 1) {
            SseParseResult::NotOurMessage => {}
            _ => panic!("Expected NotOurMessage for empty batch"),
        }
    }

    #[test]
    fn test_extract_jsonrpc_result_missing_result() {
        let msg = serde_json::json!({"jsonrpc": "2.0", "id": 1});
        let err = extract_jsonrpc_result(&msg, 1).unwrap_err();
        assert!(matches!(err, McpError::Protocol { code: -1, .. }));
    }

    #[test]
    fn test_extract_jsonrpc_error_defaults() {
        // An empty error object falls back to code -1 and a generic message.
        let msg = serde_json::json!({"jsonrpc": "2.0", "id": 1, "error": {}});
        match extract_jsonrpc_result(&msg, 1).unwrap_err() {
            McpError::Protocol { code, message } => {
                assert_eq!(code, -1);
                assert_eq!(message, "Unknown error");
            }
            _ => panic!("Expected Protocol error"),
        }
    }

    #[test]
    fn test_mcp_result_to_value_error() {
        // A single item is unwrapped even when the tool flagged an error.
        let result = tool_result(vec![text_item("Something went wrong")], true);
        assert_eq!(
            mcp_result_to_value(&result),
            Value::String("Something went wrong".into())
        );
    }

    #[test]
    fn test_mcp_result_to_value_multiple_content() {
        let result = tool_result(vec![text_item("Part 1"), text_item("Part 2")], false);
        let val = mcp_result_to_value(&result);
        assert_eq!(val["content"].as_array().unwrap().len(), 2);
        assert_eq!(val["isError"], false);
    }

    #[test]
    fn test_mcp_result_to_value_empty_content() {
        let val = mcp_result_to_value(&tool_result(vec![], false));
        assert_eq!(val["content"].as_array().unwrap().len(), 0);
        assert_eq!(val["isError"], false);
    }

    #[test]
    fn test_resolve_env_value_unclosed_brace() {
        let keyring = Keyring::empty();
        assert_eq!(resolve_env_value("${unclosed", &keyring), "${unclosed");
    }
}
1283}