1use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::io::{BufRead, BufReader, Write as IoWrite};
13use std::process::{Child, Command, Stdio};
14use std::sync::atomic::{AtomicU64, Ordering};
15use thiserror::Error;
16use tokio::sync::Mutex;
17
18use crate::core::auth_generator::{self, AuthCache, GenContext};
19use crate::core::keyring::Keyring;
20use crate::core::manifest::Provider;
21
22#[derive(Error, Debug)]
27#[allow(dead_code)]
28pub enum McpError {
29 #[error("MCP transport error: {0}")]
30 Transport(String),
31 #[error("MCP protocol error (code {code}): {message}")]
32 Protocol { code: i64, message: String },
33 #[error("MCP server did not return tools capability")]
34 NoToolsCapability,
35 #[error("IO error: {0}")]
36 Io(#[from] std::io::Error),
37 #[error("JSON error: {0}")]
38 Json(#[from] serde_json::Error),
39 #[error("HTTP error: {0}")]
40 Http(#[from] reqwest::Error),
41 #[error("MCP initialization failed: {0}")]
42 InitFailed(String),
43 #[error("SSE parse error: {0}")]
44 SseParse(String),
45 #[error("MCP server process exited unexpectedly")]
46 ProcessExited,
47 #[error("Missing MCP configuration: {0}")]
48 Config(String),
49}
50
51#[derive(Debug, Serialize)]
56struct JsonRpcRequest {
57 jsonrpc: &'static str,
58 id: u64,
59 method: String,
60 #[serde(skip_serializing_if = "Option::is_none")]
61 params: Option<Value>,
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct McpToolDef {
75 pub name: String,
76 #[serde(default)]
77 pub description: Option<String>,
78 #[serde(default, rename = "inputSchema")]
79 pub input_schema: Option<Value>,
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct McpContent {
85 #[serde(rename = "type")]
86 pub content_type: String,
87 #[serde(default)]
88 pub text: Option<String>,
89 #[serde(default)]
90 pub data: Option<String>,
91 #[serde(default, rename = "mimeType")]
92 pub mime_type: Option<String>,
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct McpToolResult {
98 pub content: Vec<McpContent>,
99 #[serde(default, rename = "isError")]
100 pub is_error: bool,
101}
102
103enum Transport {
109 Stdio(StdioTransport),
110 Http(HttpTransport),
111}
112
/// State for an MCP server running as a child process.
struct StdioTransport {
    /// The spawned server process.
    child: Child,
    /// Write side of the child's stdin; `None` once closed (dropping it signals EOF).
    stdin: Option<std::process::ChildStdin>,
    /// Buffered read side of the child's stdout.
    reader: BufReader<std::process::ChildStdout>,
}

impl Drop for StdioTransport {
    /// Closes stdin, then kills and reaps the child.
    ///
    /// Error paths that drop the client without calling `disconnect` would otherwise leave
    /// the server process running. If the child has already been killed and waited on
    /// (e.g. by `disconnect`), both calls are harmless and their results are ignored.
    fn drop(&mut self) {
        drop(self.stdin.take());
        let _ = self.child.kill();
        // Reap the process so it does not linger as a zombie.
        let _ = self.child.wait();
    }
}
121
122struct HttpTransport {
124 client: reqwest::Client,
125 url: String,
126 session_id: Option<String>,
128 auth_header_name: String,
130 auth_header: Option<String>,
132 extra_headers: HashMap<String, String>,
134}
135
136pub struct McpClient {
142 transport: Mutex<Transport>,
143 next_id: AtomicU64,
144 cached_tools: Mutex<Option<Vec<McpToolDef>>>,
146 provider_name: String,
148}
149
150impl McpClient {
151 pub async fn connect(provider: &Provider, keyring: &Keyring) -> Result<Self, McpError> {
156 Self::connect_with_gen(provider, keyring, None, None).await
157 }
158
159 pub async fn connect_with_gen(
161 provider: &Provider,
162 keyring: &Keyring,
163 gen_ctx: Option<&GenContext>,
164 auth_cache: Option<&AuthCache>,
165 ) -> Result<Self, McpError> {
166 let transport = match provider.mcp_transport_type() {
167 "stdio" => {
168 let command = provider.mcp_command.as_deref().ok_or_else(|| {
169 McpError::Config("mcp_command required for stdio transport".into())
170 })?;
171
172 let mut env_map: HashMap<String, String> = HashMap::new();
174 if let Ok(path) = std::env::var("PATH") {
176 env_map.insert("PATH".to_string(), path);
177 }
178 if let Ok(home) = std::env::var("HOME") {
179 env_map.insert("HOME".to_string(), home);
180 }
181 for (k, v) in &provider.mcp_env {
183 let resolved = resolve_env_value(v, keyring);
184 env_map.insert(k.clone(), resolved);
185 }
186
187 if let Some(gen) = &provider.auth_generator {
189 let default_ctx = GenContext::default();
190 let ctx = gen_ctx.unwrap_or(&default_ctx);
191 let default_cache = AuthCache::new();
192 let cache = auth_cache.unwrap_or(&default_cache);
193 match auth_generator::generate(provider, gen, ctx, keyring, cache).await {
194 Ok(cred) => {
195 env_map.insert("ATI_AUTH_TOKEN".to_string(), cred.value);
196 for (k, v) in &cred.extra_env {
197 env_map.insert(k.clone(), v.clone());
198 }
199 }
200 Err(e) => {
201 return Err(McpError::Config(format!("auth_generator failed: {e}")));
202 }
203 }
204 }
205
206 let mut child = Command::new(command)
207 .args(&provider.mcp_args)
208 .stdin(Stdio::piped())
209 .stdout(Stdio::piped())
210 .stderr(Stdio::piped())
211 .env_clear()
212 .envs(&env_map)
213 .spawn()
214 .map_err(|e| {
215 McpError::Transport(format!("Failed to spawn MCP server '{command}': {e}"))
216 })?;
217
218 let stdin = child
219 .stdin
220 .take()
221 .ok_or_else(|| McpError::Transport("No stdin".into()))?;
222 let stdout = child
223 .stdout
224 .take()
225 .ok_or_else(|| McpError::Transport("No stdout".into()))?;
226 let reader = BufReader::new(stdout);
227
228 Transport::Stdio(StdioTransport {
229 child,
230 stdin: Some(stdin),
231 reader,
232 })
233 }
234 "http" => {
235 let url = provider.mcp_url.as_deref().ok_or_else(|| {
236 McpError::Config("mcp_url required for HTTP transport".into())
237 })?;
238
239 let auth_header = if let Some(gen) = &provider.auth_generator {
241 let default_ctx = GenContext::default();
242 let ctx = gen_ctx.unwrap_or(&default_ctx);
243 let default_cache = AuthCache::new();
244 let cache = auth_cache.unwrap_or(&default_cache);
245 match auth_generator::generate(provider, gen, ctx, keyring, cache).await {
246 Ok(cred) => match &provider.auth_type {
247 super::manifest::AuthType::Bearer => {
248 Some(format!("Bearer {}", cred.value))
249 }
250 super::manifest::AuthType::Header => {
251 if let Some(prefix) = &provider.auth_value_prefix {
252 Some(format!("{prefix}{}", cred.value))
253 } else {
254 Some(cred.value)
255 }
256 }
257 _ => Some(cred.value),
258 },
259 Err(e) => {
260 return Err(McpError::Config(format!("auth_generator failed: {e}")));
261 }
262 }
263 } else {
264 build_auth_header(provider, keyring)
265 };
266
267 let client = reqwest::Client::builder()
268 .timeout(std::time::Duration::from_secs(300))
269 .build()?;
270
271 let resolved_url = resolve_env_value(url, keyring);
273
274 let auth_header_name = provider
275 .auth_header_name
276 .clone()
277 .unwrap_or_else(|| "Authorization".to_string());
278
279 Transport::Http(HttpTransport {
280 client,
281 url: resolved_url,
282 session_id: None,
283 auth_header_name,
284 auth_header,
285 extra_headers: provider.extra_headers.clone(),
286 })
287 }
288 other => {
289 return Err(McpError::Config(format!(
290 "Unknown MCP transport: '{other}' (expected 'stdio' or 'http')"
291 )));
292 }
293 };
294
295 let client = McpClient {
296 transport: Mutex::new(transport),
297 next_id: AtomicU64::new(1),
298 cached_tools: Mutex::new(None),
299 provider_name: provider.name.clone(),
300 };
301
302 client.initialize().await?;
304
305 Ok(client)
306 }
307
308 async fn initialize(&self) -> Result<(), McpError> {
310 let params = serde_json::json!({
311 "protocolVersion": "2025-03-26",
312 "capabilities": {},
313 "clientInfo": {
314 "name": "ati",
315 "version": env!("CARGO_PKG_VERSION")
316 }
317 });
318
319 let response = self.send_request("initialize", Some(params)).await?;
320
321 let capabilities = response.get("capabilities").unwrap_or(&Value::Null);
323 if capabilities.get("tools").is_none() {
324 return Err(McpError::NoToolsCapability);
325 }
326
327 self.send_notification("notifications/initialized", None)
331 .await?;
332
333 Ok(())
334 }
335
336 pub async fn list_tools(&self) -> Result<Vec<McpToolDef>, McpError> {
338 {
340 let cache = self.cached_tools.lock().await;
341 if let Some(tools) = cache.as_ref() {
342 return Ok(tools.clone());
343 }
344 }
345
346 let mut all_tools = Vec::new();
347 let mut cursor: Option<String> = None;
348 const MAX_PAGES: usize = 100;
349 const MAX_TOOLS: usize = 10_000;
350
351 for _page in 0..MAX_PAGES {
352 let params = cursor.as_ref().map(|c| serde_json::json!({"cursor": c}));
353 let result = self.send_request("tools/list", params).await?;
354
355 if let Some(tools_val) = result.get("tools") {
356 let tools: Vec<McpToolDef> = serde_json::from_value(tools_val.clone())?;
357 all_tools.extend(tools);
358 }
359
360 if all_tools.len() > MAX_TOOLS {
362 tracing::warn!(max = MAX_TOOLS, "MCP tool count exceeds limit, truncating");
363 all_tools.truncate(MAX_TOOLS);
364 break;
365 }
366
367 match result.get("nextCursor").and_then(|v| v.as_str()) {
369 Some(next) => cursor = Some(next.to_string()),
370 None => break,
371 }
372 }
373
374 {
376 let mut cache = self.cached_tools.lock().await;
377 *cache = Some(all_tools.clone());
378 }
379
380 Ok(all_tools)
381 }
382
383 pub async fn call_tool(
385 &self,
386 name: &str,
387 arguments: HashMap<String, Value>,
388 ) -> Result<McpToolResult, McpError> {
389 let params = serde_json::json!({
390 "name": name,
391 "arguments": arguments,
392 });
393
394 let result = self.send_request("tools/call", Some(params)).await?;
395 let tool_result: McpToolResult = serde_json::from_value(result)?;
396 Ok(tool_result)
397 }
398
399 pub async fn disconnect(&self) {
401 let mut transport = self.transport.lock().await;
402 match &mut *transport {
403 Transport::Stdio(stdio) => {
404 let _ = stdio.stdin.take();
407 let _ = stdio.child.kill();
409 let _ = stdio.child.wait();
410 }
411 Transport::Http(http) => {
412 if let Some(session_id) = &http.session_id {
414 let mut req = http.client.delete(&http.url);
415 req = req.header("Mcp-Session-Id", session_id.as_str());
416 let _ = req.send().await;
417 }
418 }
419 }
420 }
421
422 pub async fn invalidate_cache(&self) {
424 let mut cache = self.cached_tools.lock().await;
425 *cache = None;
426 }
427
428 async fn send_request(&self, method: &str, params: Option<Value>) -> Result<Value, McpError> {
433 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
434 let request = JsonRpcRequest {
435 jsonrpc: "2.0",
436 id,
437 method: method.to_string(),
438 params,
439 };
440
441 let mut transport = self.transport.lock().await;
442 match &mut *transport {
443 Transport::Stdio(stdio) => send_stdio_request(stdio, &request).await,
444 Transport::Http(http) => send_http_request(http, &request, &self.provider_name).await,
445 }
446 }
447
448 async fn send_notification(&self, method: &str, params: Option<Value>) -> Result<(), McpError> {
449 let mut notification = serde_json::json!({
450 "jsonrpc": "2.0",
451 "method": method,
452 });
453 if let Some(p) = params {
454 notification["params"] = p;
455 }
456
457 let mut transport = self.transport.lock().await;
458 match &mut *transport {
459 Transport::Stdio(stdio) => {
460 let stdin = stdio
461 .stdin
462 .as_mut()
463 .ok_or_else(|| McpError::Transport("stdin closed".into()))?;
464 let msg = serde_json::to_string(¬ification)?;
465 stdin.write_all(msg.as_bytes())?;
466 stdin.write_all(b"\n")?;
467 stdin.flush()?;
468 Ok(())
469 }
470 Transport::Http(http) => {
471 let mut req = http
472 .client
473 .post(&http.url)
474 .header("Content-Type", "application/json")
475 .header("Accept", "application/json, text/event-stream")
476 .json(¬ification);
477
478 if let Some(session_id) = &http.session_id {
479 req = req.header("Mcp-Session-Id", session_id.as_str());
480 }
481 if let Some(auth) = &http.auth_header {
482 req = req.header(http.auth_header_name.as_str(), auth.as_str());
483 }
484 for (name, value) in &http.extra_headers {
485 req = req.header(name.as_str(), value.as_str());
486 }
487
488 let resp = req.send().await?;
489 if !resp.status().is_success() {
491 let status = resp.status().as_u16();
492 let body = resp.text().await.unwrap_or_default();
493 return Err(McpError::Transport(format!("HTTP {status}: {body}")));
494 }
495 Ok(())
496 }
497 }
498 }
499}
500
501async fn send_stdio_request(
508 stdio: &mut StdioTransport,
509 request: &JsonRpcRequest,
510) -> Result<Value, McpError> {
511 let stdin = stdio
512 .stdin
513 .as_mut()
514 .ok_or_else(|| McpError::Transport("stdin closed".into()))?;
515
516 let msg = serde_json::to_string(request)?;
518 stdin.write_all(msg.as_bytes())?;
519 stdin.write_all(b"\n")?;
520 stdin.flush()?;
521
522 let request_id = request.id;
523
524 loop {
527 let mut line = String::new();
528 let bytes_read = stdio.reader.read_line(&mut line)?;
529 if bytes_read == 0 {
530 return Err(McpError::ProcessExited);
531 }
532
533 let line = line.trim();
534 if line.is_empty() {
535 continue;
536 }
537
538 let parsed: Value = serde_json::from_str(line)?;
540
541 if let Some(id) = parsed.get("id") {
543 let id_matches = match id {
544 Value::Number(n) => n.as_u64() == Some(request_id),
545 _ => false,
546 };
547
548 if id_matches {
549 if let Some(err) = parsed.get("error") {
551 let code = err.get("code").and_then(|c| c.as_i64()).unwrap_or(-1);
552 let message = err
553 .get("message")
554 .and_then(|m| m.as_str())
555 .unwrap_or("Unknown error");
556 return Err(McpError::Protocol {
557 code,
558 message: message.to_string(),
559 });
560 }
561
562 return parsed
563 .get("result")
564 .cloned()
565 .ok_or_else(|| McpError::Protocol {
566 code: -1,
567 message: "Response missing 'result' field".into(),
568 });
569 }
570 }
571
572 }
574}
575
576async fn send_http_request(
588 http: &mut HttpTransport,
589 request: &JsonRpcRequest,
590 provider_name: &str,
591) -> Result<Value, McpError> {
592 let mut req = http
593 .client
594 .post(&http.url)
595 .header("Content-Type", "application/json")
596 .header("Accept", "application/json, text/event-stream")
597 .json(request);
598
599 if let Some(session_id) = &http.session_id {
601 req = req.header("Mcp-Session-Id", session_id.as_str());
602 }
603
604 if let Some(auth) = &http.auth_header {
606 req = req.header(http.auth_header_name.as_str(), auth.as_str());
607 }
608
609 for (name, value) in &http.extra_headers {
611 req = req.header(name.as_str(), value.as_str());
612 }
613
614 let response = req
615 .send()
616 .await
617 .map_err(|e| McpError::Transport(format!("[{provider_name}] HTTP request failed: {e}")))?;
618
619 if let Some(session_val) = response.headers().get("mcp-session-id") {
621 if let Ok(sid) = session_val.to_str() {
622 http.session_id = Some(sid.to_string());
623 }
624 }
625
626 let status = response.status();
627 if !status.is_success() {
628 let body = response.text().await.unwrap_or_default();
629 return Err(McpError::Transport(format!(
630 "[{provider_name}] HTTP {}: {body}",
631 status.as_u16()
632 )));
633 }
634
635 let content_type = response
637 .headers()
638 .get("content-type")
639 .and_then(|v| v.to_str().ok())
640 .unwrap_or("")
641 .to_lowercase();
642
643 if content_type.contains("text/event-stream") {
644 parse_sse_response(response, request.id).await
646 } else {
647 let body: Value = response.json().await?;
649 extract_jsonrpc_result(&body, request.id)
650 }
651}
652
/// Upper bound, in bytes, on an SSE response body (50 MiB = 50 × 1024 × 1024).
const MAX_SSE_BODY_SIZE: usize = 52_428_800;
665
666async fn parse_sse_response(
667 response: reqwest::Response,
668 request_id: u64,
669) -> Result<Value, McpError> {
670 let bytes = response
672 .bytes()
673 .await
674 .map_err(|e| McpError::SseParse(format!("Failed to read SSE stream: {e}")))?;
675 if bytes.len() > MAX_SSE_BODY_SIZE {
676 return Err(McpError::SseParse(format!(
677 "SSE response body exceeds maximum size ({} bytes > {} bytes)",
678 bytes.len(),
679 MAX_SSE_BODY_SIZE,
680 )));
681 }
682 let full_body = String::from_utf8_lossy(&bytes).into_owned();
683
684 let mut current_data = String::new();
686
687 for line in full_body.lines() {
688 if line.starts_with("data:") {
689 let data = line.strip_prefix("data:").unwrap().trim();
690 if !data.is_empty() {
691 current_data.push_str(data);
692 }
693 } else if line.is_empty() && !current_data.is_empty() {
694 match process_sse_data(¤t_data, request_id) {
696 SseParseResult::OurResponse(result) => return result,
697 SseParseResult::NotOurMessage => {}
698 SseParseResult::ParseError(e) => {
699 tracing::warn!(error = %e, "failed to parse SSE data");
700 }
701 }
702 current_data.clear();
703 }
704 }
706
707 if !current_data.is_empty() {
709 if let SseParseResult::OurResponse(result) = process_sse_data(¤t_data, request_id) {
710 return result;
711 }
712 }
713
714 Err(McpError::SseParse(
715 "SSE stream ended without receiving a response for our request".into(),
716 ))
717}
718
719#[derive(Debug)]
720enum SseParseResult {
721 OurResponse(Result<Value, McpError>),
722 NotOurMessage,
723 ParseError(String),
724}
725
726fn process_sse_data(data: &str, request_id: u64) -> SseParseResult {
727 let parsed: Value = match serde_json::from_str(data) {
728 Ok(v) => v,
729 Err(e) => return SseParseResult::ParseError(e.to_string()),
730 };
731
732 let messages = if parsed.is_array() {
734 parsed.as_array().unwrap().clone()
735 } else {
736 vec![parsed]
737 };
738
739 for msg in messages {
740 if let Some(id) = msg.get("id") {
742 let id_matches = match id {
743 Value::Number(n) => n.as_u64() == Some(request_id),
744 _ => false,
745 };
746 if id_matches {
747 return SseParseResult::OurResponse(extract_jsonrpc_result(&msg, request_id));
748 }
749 }
750 }
752
753 SseParseResult::NotOurMessage
754}
755
756fn extract_jsonrpc_result(msg: &Value, _request_id: u64) -> Result<Value, McpError> {
758 if let Some(err) = msg.get("error") {
759 let code = err.get("code").and_then(|c| c.as_i64()).unwrap_or(-1);
760 let message = err
761 .get("message")
762 .and_then(|m| m.as_str())
763 .unwrap_or("Unknown error");
764 return Err(McpError::Protocol {
765 code,
766 message: message.to_string(),
767 });
768 }
769
770 msg.get("result")
771 .cloned()
772 .ok_or_else(|| McpError::Protocol {
773 code: -1,
774 message: "Response missing 'result' field".into(),
775 })
776}
777
778fn resolve_env_value(value: &str, keyring: &Keyring) -> String {
785 let mut result = value.to_string();
786 while let Some(start) = result.find("${") {
788 let rest = &result[start + 2..];
789 if let Some(end) = rest.find('}') {
790 let key_name = &rest[..end];
791 let replacement = keyring.get(key_name).unwrap_or("");
792 if replacement.is_empty() && keyring.get(key_name).is_none() {
793 break;
795 }
796 result = format!("{}{}{}", &result[..start], replacement, &rest[end + 1..]);
797 } else {
798 break; }
800 }
801 result
802}
803
804fn build_auth_header(provider: &Provider, keyring: &Keyring) -> Option<String> {
806 let key_name = provider.auth_key_name.as_deref()?;
807 let key_value = keyring.get(key_name)?;
808
809 match &provider.auth_type {
810 super::manifest::AuthType::Bearer => Some(format!("Bearer {key_value}")),
811 super::manifest::AuthType::Header => {
812 if let Some(prefix) = &provider.auth_value_prefix {
814 Some(format!("{prefix}{key_value}"))
815 } else {
816 Some(key_value.to_string())
817 }
818 }
819 super::manifest::AuthType::Basic => {
820 let encoded = base64::Engine::encode(
821 &base64::engine::general_purpose::STANDARD,
822 format!("{key_value}:"),
823 );
824 Some(format!("Basic {encoded}"))
825 }
826 _ => None,
827 }
828}
829
830pub async fn execute(
841 provider: &Provider,
842 tool_name: &str,
843 args: &HashMap<String, Value>,
844 keyring: &Keyring,
845) -> Result<Value, McpError> {
846 execute_with_gen(provider, tool_name, args, keyring, None, None).await
847}
848
849pub async fn execute_with_gen(
851 provider: &Provider,
852 tool_name: &str,
853 args: &HashMap<String, Value>,
854 keyring: &Keyring,
855 gen_ctx: Option<&GenContext>,
856 auth_cache: Option<&AuthCache>,
857) -> Result<Value, McpError> {
858 let client = McpClient::connect_with_gen(provider, keyring, gen_ctx, auth_cache).await?;
859
860 let mcp_tool_name = tool_name
862 .strip_prefix(&format!(
863 "{}{}",
864 provider.name,
865 crate::core::manifest::TOOL_SEP_STR
866 ))
867 .unwrap_or(tool_name);
868
869 let result = client.call_tool(mcp_tool_name, args.clone()).await?;
870
871 let value = mcp_result_to_value(&result);
873
874 client.disconnect().await;
876
877 Ok(value)
878}
879
880fn mcp_result_to_value(result: &McpToolResult) -> Value {
882 if result.content.len() == 1 {
883 let item = &result.content[0];
885 match item.content_type.as_str() {
886 "text" => {
887 if let Some(text) = &item.text {
888 serde_json::from_str(text).unwrap_or_else(|_| Value::String(text.clone()))
890 } else {
891 Value::Null
892 }
893 }
894 "image" | "audio" => {
895 serde_json::json!({
896 "type": item.content_type,
897 "data": item.data,
898 "mimeType": item.mime_type,
899 })
900 }
901 _ => serde_json::to_value(item).unwrap_or(Value::Null),
902 }
903 } else {
904 let items: Vec<Value> = result
906 .content
907 .iter()
908 .map(|c| serde_json::to_value(c).unwrap_or(Value::Null))
909 .collect();
910
911 serde_json::json!({
912 "content": items,
913 "isError": result.is_error,
914 })
915 }
916}
917
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `text` content item with no binary payload.
    fn text_item(text: &str) -> McpContent {
        McpContent {
            content_type: "text".into(),
            text: Some(text.into()),
            data: None,
            mime_type: None,
        }
    }

    /// Wraps content items into a tool result.
    fn tool_result(content: Vec<McpContent>, is_error: bool) -> McpToolResult {
        McpToolResult { content, is_error }
    }

    #[test]
    fn test_resolve_env_value_keyring() {
        let keyring = Keyring::empty();
        // An unknown placeholder is left untouched.
        assert_eq!(
            resolve_env_value("${missing_key}", &keyring),
            "${missing_key}"
        );
        assert_eq!(resolve_env_value("plain_value", &keyring), "plain_value");
    }

    #[test]
    fn test_resolve_env_value_inline() {
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("creds");
        std::fs::write(&path, r#"{"my_key":"SECRET123"}"#).unwrap();
        let keyring = Keyring::load_credentials(&path).unwrap();

        let cases = [
            ("${my_key}", "SECRET123"),
            (
                "https://example.com/${my_key}/path",
                "https://example.com/SECRET123/path",
            ),
            ("${my_key}--${my_key}", "SECRET123--SECRET123"),
            (
                "https://example.com/${unknown}/path",
                "https://example.com/${unknown}/path",
            ),
            ("no_placeholder", "no_placeholder"),
        ];
        for (input, expected) in cases {
            assert_eq!(resolve_env_value(input, &keyring), expected);
        }
    }

    #[test]
    fn test_mcp_result_to_value_single_text() {
        let result = tool_result(vec![text_item("hello world")], false);
        assert_eq!(
            mcp_result_to_value(&result),
            Value::String("hello world".into())
        );
    }

    #[test]
    fn test_mcp_result_to_value_json_text() {
        let result = tool_result(vec![text_item(r#"{"key":"value"}"#)], false);
        assert_eq!(
            mcp_result_to_value(&result),
            serde_json::json!({"key": "value"})
        );
    }

    #[test]
    fn test_extract_jsonrpc_result_success() {
        let msg = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {"tools": []}
        });
        let extracted = extract_jsonrpc_result(&msg, 1).unwrap();
        assert_eq!(extracted, serde_json::json!({"tools": []}));
    }

    #[test]
    fn test_extract_jsonrpc_result_error() {
        let msg = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {"code": -32602, "message": "Invalid params"}
        });
        let err = extract_jsonrpc_result(&msg, 1).unwrap_err();
        assert!(matches!(err, McpError::Protocol { code: -32602, .. }));
    }

    #[test]
    fn test_process_sse_data_matching_response() {
        let data = r#"{"jsonrpc":"2.0","id":5,"result":{"tools":[]}}"#;
        if let SseParseResult::OurResponse(Ok(val)) = process_sse_data(data, 5) {
            assert_eq!(val, serde_json::json!({"tools": []}));
        } else {
            panic!("Expected OurResponse");
        }
    }

    #[test]
    fn test_process_sse_data_notification() {
        // A notification has no id, so it can never be our response.
        let data = r#"{"jsonrpc":"2.0","method":"notifications/tools/list_changed"}"#;
        assert!(
            matches!(process_sse_data(data, 5), SseParseResult::NotOurMessage),
            "Expected NotOurMessage"
        );
    }

    #[test]
    fn test_process_sse_data_batch() {
        let data = r#"[
            {"jsonrpc":"2.0","method":"notifications/progress","params":{}},
            {"jsonrpc":"2.0","id":3,"result":{"content":[],"isError":false}}
        ]"#;
        if let SseParseResult::OurResponse(Ok(val)) = process_sse_data(data, 3) {
            assert!(val.get("content").is_some());
        } else {
            panic!("Expected OurResponse from batch");
        }
    }

    #[test]
    fn test_process_sse_data_invalid_json() {
        match process_sse_data("not valid json {{{}", 1) {
            SseParseResult::ParseError(_) => {}
            other => panic!("Expected ParseError, got: {other:?}"),
        }
    }

    #[test]
    fn test_process_sse_data_wrong_id() {
        let data = r#"{"jsonrpc":"2.0","id":99,"result":{"data":"wrong"}}"#;
        assert!(
            matches!(process_sse_data(data, 1), SseParseResult::NotOurMessage),
            "Expected NotOurMessage for wrong ID"
        );
    }

    #[test]
    fn test_process_sse_data_empty_batch() {
        assert!(
            matches!(process_sse_data("[]", 1), SseParseResult::NotOurMessage),
            "Expected NotOurMessage for empty batch"
        );
    }

    #[test]
    fn test_extract_jsonrpc_result_missing_result() {
        let msg = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1
        });
        let err = extract_jsonrpc_result(&msg, 1).unwrap_err();
        assert!(matches!(err, McpError::Protocol { code: -1, .. }));
    }

    #[test]
    fn test_extract_jsonrpc_error_defaults() {
        // An empty error object falls back to the default code and message.
        let msg = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {}
        });
        if let McpError::Protocol { code, message } = extract_jsonrpc_result(&msg, 1).unwrap_err() {
            assert_eq!(code, -1);
            assert_eq!(message, "Unknown error");
        } else {
            panic!("Expected Protocol error");
        }
    }

    #[test]
    fn test_mcp_result_to_value_error() {
        let result = tool_result(vec![text_item("Something went wrong")], true);
        assert_eq!(
            mcp_result_to_value(&result),
            Value::String("Something went wrong".into())
        );
    }

    #[test]
    fn test_mcp_result_to_value_multiple_content() {
        let result = tool_result(vec![text_item("Part 1"), text_item("Part 2")], false);
        let val = mcp_result_to_value(&result);
        assert_eq!(val["content"].as_array().unwrap().len(), 2);
        assert_eq!(val["isError"], false);
    }

    #[test]
    fn test_mcp_result_to_value_empty_content() {
        let val = mcp_result_to_value(&tool_result(vec![], false));
        assert_eq!(val["content"].as_array().unwrap().len(), 0);
        assert_eq!(val["isError"], false);
    }

    #[test]
    fn test_resolve_env_value_unclosed_brace() {
        let keyring = Keyring::empty();
        assert_eq!(resolve_env_value("${unclosed", &keyring), "${unclosed");
    }
}
1172}