1use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::io::{BufRead, BufReader, Write as IoWrite};
13use std::process::{Child, Command, Stdio};
14use std::sync::atomic::{AtomicU64, Ordering};
15use thiserror::Error;
16use tokio::sync::Mutex;
17
18use crate::core::auth_generator::{self, AuthCache, GenContext};
19use crate::core::keyring::Keyring;
20use crate::core::manifest::Provider;
21
22#[derive(Error, Debug)]
27#[allow(dead_code)]
28pub enum McpError {
29 #[error("MCP transport error: {0}")]
30 Transport(String),
31 #[error("MCP protocol error (code {code}): {message}")]
32 Protocol { code: i64, message: String },
33 #[error("MCP server did not return tools capability")]
34 NoToolsCapability,
35 #[error("IO error: {0}")]
36 Io(#[from] std::io::Error),
37 #[error("JSON error: {0}")]
38 Json(#[from] serde_json::Error),
39 #[error("HTTP error: {0}")]
40 Http(#[from] reqwest::Error),
41 #[error("MCP initialization failed: {0}")]
42 InitFailed(String),
43 #[error("SSE parse error: {0}")]
44 SseParse(String),
45 #[error("MCP server process exited unexpectedly")]
46 ProcessExited,
47 #[error("Missing MCP configuration: {0}")]
48 Config(String),
49}
50
51#[derive(Debug, Serialize)]
56struct JsonRpcRequest {
57 jsonrpc: &'static str,
58 id: u64,
59 method: String,
60 #[serde(skip_serializing_if = "Option::is_none")]
61 params: Option<Value>,
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct McpToolDef {
75 pub name: String,
76 #[serde(default)]
77 pub description: Option<String>,
78 #[serde(default, rename = "inputSchema")]
79 pub input_schema: Option<Value>,
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct McpContent {
85 #[serde(rename = "type")]
86 pub content_type: String,
87 #[serde(default)]
88 pub text: Option<String>,
89 #[serde(default)]
90 pub data: Option<String>,
91 #[serde(default, rename = "mimeType")]
92 pub mime_type: Option<String>,
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct McpToolResult {
98 pub content: Vec<McpContent>,
99 #[serde(default, rename = "isError")]
100 pub is_error: bool,
101}
102
103enum Transport {
109 Stdio(StdioTransport),
110 Http(HttpTransport),
111}
112
/// A spawned MCP server process and the pipes used to talk to it.
///
/// Field order is kept deliberately: fields are dropped in declaration order after
/// `Drop::drop` has killed and reaped `child`.
struct StdioTransport {
    /// The server process; killed and waited on when the transport is dropped.
    child: Child,
    /// Write end of the server's stdin. `None` once it has been closed by `disconnect`.
    stdin: Option<std::process::ChildStdin>,
    /// Buffered reader over the server's stdout, read one line per message.
    reader: BufReader<std::process::ChildStdout>,
}
121
122impl Drop for StdioTransport {
123 fn drop(&mut self) {
124 let _ = self.child.kill();
127 let _ = self.child.wait();
128 }
129}
130
131struct HttpTransport {
133 client: reqwest::Client,
134 url: String,
135 session_id: Option<String>,
137 auth_header_name: String,
139 auth_header: Option<String>,
141 extra_headers: HashMap<String, String>,
143}
144
145pub struct McpClient {
151 transport: Mutex<Transport>,
152 next_id: AtomicU64,
153 cached_tools: Mutex<Option<Vec<McpToolDef>>>,
155 provider_name: String,
157}
158
159impl McpClient {
160 pub async fn connect(provider: &Provider, keyring: &Keyring) -> Result<Self, McpError> {
165 Self::connect_with_gen(provider, keyring, None, None).await
166 }
167
168 pub async fn connect_with_gen(
170 provider: &Provider,
171 keyring: &Keyring,
172 gen_ctx: Option<&GenContext>,
173 auth_cache: Option<&AuthCache>,
174 ) -> Result<Self, McpError> {
175 let transport = match provider.mcp_transport_type() {
176 "stdio" => {
177 let command = provider.mcp_command.as_deref().ok_or_else(|| {
178 McpError::Config("mcp_command required for stdio transport".into())
179 })?;
180
181 let mut env_map: HashMap<String, String> = HashMap::new();
183 if let Ok(path) = std::env::var("PATH") {
185 env_map.insert("PATH".to_string(), path);
186 }
187 if let Ok(home) = std::env::var("HOME") {
188 env_map.insert("HOME".to_string(), home);
189 }
190 for (k, v) in &provider.mcp_env {
192 let resolved = resolve_env_value(v, keyring);
193 env_map.insert(k.clone(), resolved);
194 }
195
196 if let Some(gen) = &provider.auth_generator {
198 let default_ctx = GenContext::default();
199 let ctx = gen_ctx.unwrap_or(&default_ctx);
200 let default_cache = AuthCache::new();
201 let cache = auth_cache.unwrap_or(&default_cache);
202 match auth_generator::generate(provider, gen, ctx, keyring, cache).await {
203 Ok(cred) => {
204 env_map.insert("ATI_AUTH_TOKEN".to_string(), cred.value);
205 for (k, v) in &cred.extra_env {
206 env_map.insert(k.clone(), v.clone());
207 }
208 }
209 Err(e) => {
210 return Err(McpError::Config(format!("auth_generator failed: {e}")));
211 }
212 }
213 }
214
215 let mut child = Command::new(command)
216 .args(&provider.mcp_args)
217 .stdin(Stdio::piped())
218 .stdout(Stdio::piped())
219 .stderr(Stdio::piped())
220 .env_clear()
221 .envs(&env_map)
222 .spawn()
223 .map_err(|e| {
224 McpError::Transport(format!("Failed to spawn MCP server '{command}': {e}"))
225 })?;
226
227 let stdin = child
228 .stdin
229 .take()
230 .ok_or_else(|| McpError::Transport("No stdin".into()))?;
231 let stdout = child
232 .stdout
233 .take()
234 .ok_or_else(|| McpError::Transport("No stdout".into()))?;
235 let reader = BufReader::new(stdout);
236
237 Transport::Stdio(StdioTransport {
238 child,
239 stdin: Some(stdin),
240 reader,
241 })
242 }
243 "http" => {
244 let url = provider.mcp_url.as_deref().ok_or_else(|| {
245 McpError::Config("mcp_url required for HTTP transport".into())
246 })?;
247
248 let auth_header = if let Some(gen) = &provider.auth_generator {
250 let default_ctx = GenContext::default();
251 let ctx = gen_ctx.unwrap_or(&default_ctx);
252 let default_cache = AuthCache::new();
253 let cache = auth_cache.unwrap_or(&default_cache);
254 match auth_generator::generate(provider, gen, ctx, keyring, cache).await {
255 Ok(cred) => match &provider.auth_type {
256 super::manifest::AuthType::Bearer => {
257 Some(format!("Bearer {}", cred.value))
258 }
259 super::manifest::AuthType::Header => {
260 if let Some(prefix) = &provider.auth_value_prefix {
261 Some(format!("{prefix}{}", cred.value))
262 } else {
263 Some(cred.value)
264 }
265 }
266 _ => Some(cred.value),
267 },
268 Err(e) => {
269 return Err(McpError::Config(format!("auth_generator failed: {e}")));
270 }
271 }
272 } else {
273 build_auth_header(provider, keyring)
274 };
275
276 let client = reqwest::Client::builder()
277 .timeout(std::time::Duration::from_secs(300))
278 .build()?;
279
280 let resolved_url = resolve_env_value(url, keyring);
282
283 let auth_header_name = provider
284 .auth_header_name
285 .clone()
286 .unwrap_or_else(|| "Authorization".to_string());
287
288 Transport::Http(HttpTransport {
289 client,
290 url: resolved_url,
291 session_id: None,
292 auth_header_name,
293 auth_header,
294 extra_headers: provider.extra_headers.clone(),
295 })
296 }
297 other => {
298 return Err(McpError::Config(format!(
299 "Unknown MCP transport: '{other}' (expected 'stdio' or 'http')"
300 )));
301 }
302 };
303
304 let client = McpClient {
305 transport: Mutex::new(transport),
306 next_id: AtomicU64::new(1),
307 cached_tools: Mutex::new(None),
308 provider_name: provider.name.clone(),
309 };
310
311 client.initialize().await?;
313
314 Ok(client)
315 }
316
317 async fn initialize(&self) -> Result<(), McpError> {
319 let params = serde_json::json!({
320 "protocolVersion": "2025-03-26",
321 "capabilities": {},
322 "clientInfo": {
323 "name": "ati",
324 "version": env!("CARGO_PKG_VERSION")
325 }
326 });
327
328 let response = self.send_request("initialize", Some(params)).await?;
329
330 let capabilities = response.get("capabilities").unwrap_or(&Value::Null);
332 if capabilities.get("tools").is_none() {
333 return Err(McpError::NoToolsCapability);
334 }
335
336 self.send_notification("notifications/initialized", None)
340 .await?;
341
342 Ok(())
343 }
344
345 pub async fn list_tools(&self) -> Result<Vec<McpToolDef>, McpError> {
347 {
349 let cache = self.cached_tools.lock().await;
350 if let Some(tools) = cache.as_ref() {
351 return Ok(tools.clone());
352 }
353 }
354
355 let mut all_tools = Vec::new();
356 let mut cursor: Option<String> = None;
357 const MAX_PAGES: usize = 100;
358 const MAX_TOOLS: usize = 10_000;
359
360 for _page in 0..MAX_PAGES {
361 let params = cursor.as_ref().map(|c| serde_json::json!({"cursor": c}));
362 let result = self.send_request("tools/list", params).await?;
363
364 if let Some(tools_val) = result.get("tools") {
365 let tools: Vec<McpToolDef> = serde_json::from_value(tools_val.clone())?;
366 all_tools.extend(tools);
367 }
368
369 if all_tools.len() > MAX_TOOLS {
371 tracing::warn!(max = MAX_TOOLS, "MCP tool count exceeds limit, truncating");
372 all_tools.truncate(MAX_TOOLS);
373 break;
374 }
375
376 match result.get("nextCursor").and_then(|v| v.as_str()) {
378 Some(next) => cursor = Some(next.to_string()),
379 None => break,
380 }
381 }
382
383 {
385 let mut cache = self.cached_tools.lock().await;
386 *cache = Some(all_tools.clone());
387 }
388
389 Ok(all_tools)
390 }
391
392 pub async fn call_tool(
394 &self,
395 name: &str,
396 arguments: HashMap<String, Value>,
397 ) -> Result<McpToolResult, McpError> {
398 let params = serde_json::json!({
399 "name": name,
400 "arguments": arguments,
401 });
402
403 let result = self.send_request("tools/call", Some(params)).await?;
404 let tool_result: McpToolResult = serde_json::from_value(result)?;
405 Ok(tool_result)
406 }
407
408 pub async fn disconnect(&self) {
410 let mut transport = self.transport.lock().await;
411 match &mut *transport {
412 Transport::Stdio(stdio) => {
413 let _ = stdio.stdin.take();
416 let _ = stdio.child.kill();
418 let _ = stdio.child.wait();
419 }
420 Transport::Http(http) => {
421 if let Some(session_id) = &http.session_id {
423 let mut req = http.client.delete(&http.url);
424 req = req.header("Mcp-Session-Id", session_id.as_str());
425 let _ = req.send().await;
426 }
427 }
428 }
429 }
430
431 pub async fn invalidate_cache(&self) {
433 let mut cache = self.cached_tools.lock().await;
434 *cache = None;
435 }
436
437 async fn send_request(&self, method: &str, params: Option<Value>) -> Result<Value, McpError> {
442 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
443 let request = JsonRpcRequest {
444 jsonrpc: "2.0",
445 id,
446 method: method.to_string(),
447 params,
448 };
449
450 let mut transport = self.transport.lock().await;
451 match &mut *transport {
452 Transport::Stdio(stdio) => send_stdio_request(stdio, &request).await,
453 Transport::Http(http) => send_http_request(http, &request, &self.provider_name).await,
454 }
455 }
456
457 async fn send_notification(&self, method: &str, params: Option<Value>) -> Result<(), McpError> {
458 let mut notification = serde_json::json!({
459 "jsonrpc": "2.0",
460 "method": method,
461 });
462 if let Some(p) = params {
463 notification["params"] = p;
464 }
465
466 let mut transport = self.transport.lock().await;
467 match &mut *transport {
468 Transport::Stdio(stdio) => {
469 let stdin = stdio
470 .stdin
471 .as_mut()
472 .ok_or_else(|| McpError::Transport("stdin closed".into()))?;
473 let msg = serde_json::to_string(¬ification)?;
474 stdin.write_all(msg.as_bytes())?;
475 stdin.write_all(b"\n")?;
476 stdin.flush()?;
477 Ok(())
478 }
479 Transport::Http(http) => {
480 let mut req = http
481 .client
482 .post(&http.url)
483 .header("Content-Type", "application/json")
484 .header("Accept", "application/json, text/event-stream")
485 .json(¬ification);
486
487 if let Some(session_id) = &http.session_id {
488 req = req.header("Mcp-Session-Id", session_id.as_str());
489 }
490 if let Some(auth) = &http.auth_header {
491 req = req.header(http.auth_header_name.as_str(), auth.as_str());
492 }
493 for (name, value) in &http.extra_headers {
494 req = req.header(name.as_str(), value.as_str());
495 }
496
497 let resp = req.send().await?;
498 if !resp.status().is_success() {
500 let status = resp.status().as_u16();
501 let body = resp.text().await.unwrap_or_default();
502 return Err(McpError::Transport(format!("HTTP {status}: {body}")));
503 }
504 Ok(())
505 }
506 }
507 }
508}
509
510async fn send_stdio_request(
517 stdio: &mut StdioTransport,
518 request: &JsonRpcRequest,
519) -> Result<Value, McpError> {
520 let stdin = stdio
521 .stdin
522 .as_mut()
523 .ok_or_else(|| McpError::Transport("stdin closed".into()))?;
524
525 let msg = serde_json::to_string(request)?;
527 stdin.write_all(msg.as_bytes())?;
528 stdin.write_all(b"\n")?;
529 stdin.flush()?;
530
531 let request_id = request.id;
532
533 loop {
536 let mut line = String::new();
537 let bytes_read = stdio.reader.read_line(&mut line)?;
538 if bytes_read == 0 {
539 return Err(McpError::ProcessExited);
540 }
541
542 let line = line.trim();
543 if line.is_empty() {
544 continue;
545 }
546
547 let parsed: Value = serde_json::from_str(line)?;
549
550 if let Some(id) = parsed.get("id") {
552 let id_matches = match id {
553 Value::Number(n) => n.as_u64() == Some(request_id),
554 _ => false,
555 };
556
557 if id_matches {
558 if let Some(err) = parsed.get("error") {
560 let code = err.get("code").and_then(|c| c.as_i64()).unwrap_or(-1);
561 let message = err
562 .get("message")
563 .and_then(|m| m.as_str())
564 .unwrap_or("Unknown error");
565 return Err(McpError::Protocol {
566 code,
567 message: message.to_string(),
568 });
569 }
570
571 return parsed
572 .get("result")
573 .cloned()
574 .ok_or_else(|| McpError::Protocol {
575 code: -1,
576 message: "Response missing 'result' field".into(),
577 });
578 }
579 }
580
581 }
583}
584
585async fn send_http_request(
597 http: &mut HttpTransport,
598 request: &JsonRpcRequest,
599 provider_name: &str,
600) -> Result<Value, McpError> {
601 let mut req = http
602 .client
603 .post(&http.url)
604 .header("Content-Type", "application/json")
605 .header("Accept", "application/json, text/event-stream")
606 .json(request);
607
608 if let Some(session_id) = &http.session_id {
610 req = req.header("Mcp-Session-Id", session_id.as_str());
611 }
612
613 if let Some(auth) = &http.auth_header {
615 req = req.header(http.auth_header_name.as_str(), auth.as_str());
616 }
617
618 for (name, value) in &http.extra_headers {
620 req = req.header(name.as_str(), value.as_str());
621 }
622
623 let response = req
624 .send()
625 .await
626 .map_err(|e| McpError::Transport(format!("[{provider_name}] HTTP request failed: {e}")))?;
627
628 if let Some(session_val) = response.headers().get("mcp-session-id") {
630 if let Ok(sid) = session_val.to_str() {
631 http.session_id = Some(sid.to_string());
632 }
633 }
634
635 let status = response.status();
636 if !status.is_success() {
637 let body = response.text().await.unwrap_or_default();
638 return Err(McpError::Transport(format!(
639 "[{provider_name}] HTTP {}: {body}",
640 status.as_u16()
641 )));
642 }
643
644 let content_type = response
646 .headers()
647 .get("content-type")
648 .and_then(|v| v.to_str().ok())
649 .unwrap_or("")
650 .to_lowercase();
651
652 if content_type.contains("text/event-stream") {
653 parse_sse_response(response, request.id).await
655 } else {
656 let body: Value = response.json().await?;
658 extract_jsonrpc_result(&body, request.id)
659 }
660}
661
662const MAX_SSE_BODY_SIZE: usize = 50 * 1024 * 1024;
674
675async fn parse_sse_response(
676 response: reqwest::Response,
677 request_id: u64,
678) -> Result<Value, McpError> {
679 let bytes = response
681 .bytes()
682 .await
683 .map_err(|e| McpError::SseParse(format!("Failed to read SSE stream: {e}")))?;
684 if bytes.len() > MAX_SSE_BODY_SIZE {
685 return Err(McpError::SseParse(format!(
686 "SSE response body exceeds maximum size ({} bytes > {} bytes)",
687 bytes.len(),
688 MAX_SSE_BODY_SIZE,
689 )));
690 }
691 let full_body = String::from_utf8_lossy(&bytes).into_owned();
692
693 let mut current_data = String::new();
695
696 for line in full_body.lines() {
697 if line.starts_with("data:") {
698 let data = line.strip_prefix("data:").unwrap().trim();
699 if !data.is_empty() {
700 current_data.push_str(data);
701 }
702 } else if line.is_empty() && !current_data.is_empty() {
703 match process_sse_data(¤t_data, request_id) {
705 SseParseResult::OurResponse(result) => return result,
706 SseParseResult::NotOurMessage => {}
707 SseParseResult::ParseError(e) => {
708 tracing::warn!(error = %e, "failed to parse SSE data");
709 }
710 }
711 current_data.clear();
712 }
713 }
715
716 if !current_data.is_empty() {
718 if let SseParseResult::OurResponse(result) = process_sse_data(¤t_data, request_id) {
719 return result;
720 }
721 }
722
723 Err(McpError::SseParse(
724 "SSE stream ended without receiving a response for our request".into(),
725 ))
726}
727
728#[derive(Debug)]
729enum SseParseResult {
730 OurResponse(Result<Value, McpError>),
731 NotOurMessage,
732 ParseError(String),
733}
734
735fn process_sse_data(data: &str, request_id: u64) -> SseParseResult {
736 let parsed: Value = match serde_json::from_str(data) {
737 Ok(v) => v,
738 Err(e) => return SseParseResult::ParseError(e.to_string()),
739 };
740
741 let messages = if parsed.is_array() {
743 parsed.as_array().unwrap().clone()
744 } else {
745 vec![parsed]
746 };
747
748 for msg in messages {
749 if let Some(id) = msg.get("id") {
751 let id_matches = match id {
752 Value::Number(n) => n.as_u64() == Some(request_id),
753 _ => false,
754 };
755 if id_matches {
756 return SseParseResult::OurResponse(extract_jsonrpc_result(&msg, request_id));
757 }
758 }
759 }
761
762 SseParseResult::NotOurMessage
763}
764
765fn extract_jsonrpc_result(msg: &Value, _request_id: u64) -> Result<Value, McpError> {
767 if let Some(err) = msg.get("error") {
768 let code = err.get("code").and_then(|c| c.as_i64()).unwrap_or(-1);
769 let message = err
770 .get("message")
771 .and_then(|m| m.as_str())
772 .unwrap_or("Unknown error");
773 return Err(McpError::Protocol {
774 code,
775 message: message.to_string(),
776 });
777 }
778
779 msg.get("result")
780 .cloned()
781 .ok_or_else(|| McpError::Protocol {
782 code: -1,
783 message: "Response missing 'result' field".into(),
784 })
785}
786
787fn resolve_env_value(value: &str, keyring: &Keyring) -> String {
794 let mut result = value.to_string();
795 while let Some(start) = result.find("${") {
797 let rest = &result[start + 2..];
798 if let Some(end) = rest.find('}') {
799 let key_name = &rest[..end];
800 let replacement = keyring.get(key_name).unwrap_or("");
801 if replacement.is_empty() && keyring.get(key_name).is_none() {
802 break;
804 }
805 result = format!("{}{}{}", &result[..start], replacement, &rest[end + 1..]);
806 } else {
807 break; }
809 }
810 result
811}
812
813fn build_auth_header(provider: &Provider, keyring: &Keyring) -> Option<String> {
815 let key_name = provider.auth_key_name.as_deref()?;
816 let key_value = keyring.get(key_name)?;
817
818 match &provider.auth_type {
819 super::manifest::AuthType::Bearer => Some(format!("Bearer {key_value}")),
820 super::manifest::AuthType::Header => {
821 if let Some(prefix) = &provider.auth_value_prefix {
823 Some(format!("{prefix}{key_value}"))
824 } else {
825 Some(key_value.to_string())
826 }
827 }
828 super::manifest::AuthType::Basic => {
829 let encoded = base64::Engine::encode(
830 &base64::engine::general_purpose::STANDARD,
831 format!("{key_value}:"),
832 );
833 Some(format!("Basic {encoded}"))
834 }
835 _ => None,
836 }
837}
838
839pub async fn execute(
850 provider: &Provider,
851 tool_name: &str,
852 args: &HashMap<String, Value>,
853 keyring: &Keyring,
854) -> Result<Value, McpError> {
855 execute_with_gen(provider, tool_name, args, keyring, None, None).await
856}
857
858pub async fn execute_with_gen(
860 provider: &Provider,
861 tool_name: &str,
862 args: &HashMap<String, Value>,
863 keyring: &Keyring,
864 gen_ctx: Option<&GenContext>,
865 auth_cache: Option<&AuthCache>,
866) -> Result<Value, McpError> {
867 let client = McpClient::connect_with_gen(provider, keyring, gen_ctx, auth_cache).await?;
868
869 let mcp_tool_name = tool_name
871 .strip_prefix(&format!(
872 "{}{}",
873 provider.name,
874 crate::core::manifest::TOOL_SEP_STR
875 ))
876 .unwrap_or(tool_name);
877
878 let result = client.call_tool(mcp_tool_name, args.clone()).await?;
879
880 let value = mcp_result_to_value(&result);
882
883 client.disconnect().await;
885
886 Ok(value)
887}
888
889fn mcp_result_to_value(result: &McpToolResult) -> Value {
891 if result.content.len() == 1 {
892 let item = &result.content[0];
894 match item.content_type.as_str() {
895 "text" => {
896 if let Some(text) = &item.text {
897 serde_json::from_str(text).unwrap_or_else(|_| Value::String(text.clone()))
899 } else {
900 Value::Null
901 }
902 }
903 "image" | "audio" => {
904 serde_json::json!({
905 "type": item.content_type,
906 "data": item.data,
907 "mimeType": item.mime_type,
908 })
909 }
910 _ => serde_json::to_value(item).unwrap_or(Value::Null),
911 }
912 } else {
913 let items: Vec<Value> = result
915 .content
916 .iter()
917 .map(|c| serde_json::to_value(c).unwrap_or(Value::Null))
918 .collect();
919
920 serde_json::json!({
921 "content": items,
922 "isError": result.is_error,
923 })
924 }
925}
926
927pub async fn discover_all_mcp_tools(
936 registry: &mut crate::core::manifest::ManifestRegistry,
937 keyring: &Keyring,
938) -> usize {
939 use futures::stream::{self, StreamExt};
940
941 let providers: Vec<_> = registry
942 .list_mcp_providers()
943 .into_iter()
944 .map(|p| (p.name.clone(), p.clone()))
945 .collect();
946
947 if providers.is_empty() {
948 return 0;
949 }
950
951 let results: Vec<_> = stream::iter(&providers)
953 .map(|(name, provider)| async move {
954 let result = tokio::time::timeout(
955 std::time::Duration::from_secs(30),
956 discover_one_provider(name, provider, keyring),
957 )
958 .await;
959
960 match result {
961 Ok(Ok(tools)) => Some((name.clone(), tools)),
962 Ok(Err(e)) => {
963 tracing::warn!(provider = %name, error = %e, "MCP tool discovery failed");
964 None
965 }
966 Err(_) => {
967 tracing::warn!(provider = %name, "MCP tool discovery timed out (30s)");
968 None
969 }
970 }
971 })
972 .buffer_unordered(10)
973 .collect()
974 .await;
975
976 let mut total = 0;
978 for (name, tool_defs) in results.into_iter().flatten() {
979 let count = tool_defs.len();
980 registry.register_mcp_tools(&name, tool_defs);
981 tracing::info!(provider = %name, tools = count, "discovered MCP tools");
982 total += count;
983 }
984 total
985}
986
987async fn discover_one_provider(
989 _name: &str,
990 provider: &Provider,
991 keyring: &Keyring,
992) -> Result<Vec<crate::core::manifest::McpToolDef>, McpError> {
993 let client = McpClient::connect(provider, keyring).await?;
994 let tools = client.list_tools().await;
995 client.disconnect().await;
996
997 let tools = tools?;
998 Ok(tools
999 .into_iter()
1000 .map(|t| crate::core::manifest::McpToolDef {
1001 name: t.name,
1002 description: t.description,
1003 input_schema: t.input_schema,
1004 })
1005 .collect())
1006}
1007
#[cfg(test)]
mod tests {
    use super::*;

    /// A single `text` content item carrying `body`.
    fn text_item(body: &str) -> McpContent {
        McpContent {
            content_type: "text".into(),
            text: Some(body.into()),
            data: None,
            mime_type: None,
        }
    }

    // ---- resolve_env_value -------------------------------------------------

    #[test]
    fn test_resolve_env_value_keyring() {
        let keyring = Keyring::empty();
        // A placeholder for a key the keyring does not have is left as written.
        assert_eq!(
            resolve_env_value("${missing_key}", &keyring),
            "${missing_key}"
        );
        assert_eq!(resolve_env_value("plain_value", &keyring), "plain_value");
    }

    #[test]
    fn test_resolve_env_value_inline() {
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("creds");
        std::fs::write(&path, r#"{"my_key":"SECRET123"}"#).unwrap();
        let keyring = Keyring::load_credentials(&path).unwrap();

        let cases = [
            ("${my_key}", "SECRET123"),
            (
                "https://example.com/${my_key}/path",
                "https://example.com/SECRET123/path",
            ),
            ("${my_key}--${my_key}", "SECRET123--SECRET123"),
            (
                "https://example.com/${unknown}/path",
                "https://example.com/${unknown}/path",
            ),
            ("no_placeholder", "no_placeholder"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(resolve_env_value(input, &keyring), expected, "input: {input}");
        }
    }

    #[test]
    fn test_resolve_env_value_unclosed_brace() {
        let keyring = Keyring::empty();
        assert_eq!(resolve_env_value("${unclosed", &keyring), "${unclosed");
    }

    // ---- mcp_result_to_value -----------------------------------------------

    #[test]
    fn test_mcp_result_to_value_single_text() {
        let result = McpToolResult {
            content: vec![text_item("hello world")],
            is_error: false,
        };
        assert_eq!(
            mcp_result_to_value(&result),
            Value::String("hello world".into())
        );
    }

    #[test]
    fn test_mcp_result_to_value_json_text() {
        // Text that parses as JSON comes back as structured JSON.
        let result = McpToolResult {
            content: vec![text_item(r#"{"key":"value"}"#)],
            is_error: false,
        };
        assert_eq!(
            mcp_result_to_value(&result),
            serde_json::json!({"key": "value"})
        );
    }

    #[test]
    fn test_mcp_result_to_value_error() {
        // With a single item, `is_error` does not change the returned value.
        let result = McpToolResult {
            content: vec![text_item("Something went wrong")],
            is_error: true,
        };
        assert_eq!(
            mcp_result_to_value(&result),
            Value::String("Something went wrong".into())
        );
    }

    #[test]
    fn test_mcp_result_to_value_multiple_content() {
        let result = McpToolResult {
            content: vec![text_item("Part 1"), text_item("Part 2")],
            is_error: false,
        };
        let val = mcp_result_to_value(&result);
        assert_eq!(val["content"].as_array().unwrap().len(), 2);
        assert_eq!(val["isError"], false);
    }

    #[test]
    fn test_mcp_result_to_value_empty_content() {
        let result = McpToolResult {
            content: vec![],
            is_error: false,
        };
        let val = mcp_result_to_value(&result);
        assert_eq!(val["content"].as_array().unwrap().len(), 0);
        assert_eq!(val["isError"], false);
    }

    // ---- extract_jsonrpc_result --------------------------------------------

    #[test]
    fn test_extract_jsonrpc_result_success() {
        let msg = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {"tools": []}
        });
        assert_eq!(
            extract_jsonrpc_result(&msg, 1).unwrap(),
            serde_json::json!({"tools": []})
        );
    }

    #[test]
    fn test_extract_jsonrpc_result_error() {
        let msg = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {"code": -32602, "message": "Invalid params"}
        });
        let err = extract_jsonrpc_result(&msg, 1).unwrap_err();
        assert!(matches!(err, McpError::Protocol { code: -32602, .. }));
    }

    #[test]
    fn test_extract_jsonrpc_result_missing_result() {
        let msg = serde_json::json!({"jsonrpc": "2.0", "id": 1});
        let err = extract_jsonrpc_result(&msg, 1).unwrap_err();
        assert!(matches!(err, McpError::Protocol { code: -1, .. }));
    }

    #[test]
    fn test_extract_jsonrpc_error_defaults() {
        // An error object without code/message falls back to -1 / "Unknown error".
        let msg = serde_json::json!({"jsonrpc": "2.0", "id": 1, "error": {}});
        match extract_jsonrpc_result(&msg, 1).unwrap_err() {
            McpError::Protocol { code, message } => {
                assert_eq!(code, -1);
                assert_eq!(message, "Unknown error");
            }
            _ => panic!("Expected Protocol error"),
        }
    }

    // ---- process_sse_data --------------------------------------------------

    #[test]
    fn test_process_sse_data_matching_response() {
        let data = r#"{"jsonrpc":"2.0","id":5,"result":{"tools":[]}}"#;
        match process_sse_data(data, 5) {
            SseParseResult::OurResponse(Ok(val)) => {
                assert_eq!(val, serde_json::json!({"tools": []}))
            }
            _ => panic!("Expected OurResponse"),
        }
    }

    #[test]
    fn test_process_sse_data_notification() {
        let data = r#"{"jsonrpc":"2.0","method":"notifications/tools/list_changed"}"#;
        assert!(
            matches!(process_sse_data(data, 5), SseParseResult::NotOurMessage),
            "Expected NotOurMessage"
        );
    }

    #[test]
    fn test_process_sse_data_batch() {
        let data = r#"[
            {"jsonrpc":"2.0","method":"notifications/progress","params":{}},
            {"jsonrpc":"2.0","id":3,"result":{"content":[],"isError":false}}
        ]"#;
        match process_sse_data(data, 3) {
            SseParseResult::OurResponse(Ok(val)) => assert!(val.get("content").is_some()),
            _ => panic!("Expected OurResponse from batch"),
        }
    }

    #[test]
    fn test_process_sse_data_invalid_json() {
        match process_sse_data("not valid json {{{}", 1) {
            SseParseResult::ParseError(_) => {}
            other => panic!("Expected ParseError, got: {other:?}"),
        }
    }

    #[test]
    fn test_process_sse_data_wrong_id() {
        let data = r#"{"jsonrpc":"2.0","id":99,"result":{"data":"wrong"}}"#;
        assert!(
            matches!(process_sse_data(data, 1), SseParseResult::NotOurMessage),
            "Expected NotOurMessage for wrong ID"
        );
    }

    #[test]
    fn test_process_sse_data_empty_batch() {
        assert!(
            matches!(process_sse_data("[]", 1), SseParseResult::NotOurMessage),
            "Expected NotOurMessage for empty batch"
        );
    }
}