1use std::collections::{BTreeMap, VecDeque};
2use std::fmt;
3use std::process::Stdio;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::Duration;
7
8use agentkit_capabilities::{
9 CapabilityContext, CapabilityError, CapabilityName, CapabilityProvider, Invocable,
10 InvocableOutput, InvocableRequest, InvocableResult, InvocableSpec, PromptContents,
11 PromptDescriptor, PromptId, PromptProvider, ResourceContents, ResourceDescriptor, ResourceId,
12 ResourceProvider,
13};
14use agentkit_core::{
15 DataRef, Item, ItemKind, MetadataMap, Part, TextPart, ToolOutput, ToolResultPart,
16};
17use agentkit_tools_core::{
18 AuthOperation, AuthRequest, AuthResolution, Tool, ToolAnnotations, ToolContext, ToolError,
19 ToolName, ToolRegistry, ToolRequest, ToolResult, ToolSpec,
20};
21use async_trait::async_trait;
22use futures_util::TryStreamExt;
23use reqwest::{Client, StatusCode, Url};
24use serde::{Deserialize, Serialize};
25use serde_json::{Value, json};
26use thiserror::Error;
27use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWriteExt, BufReader};
28use tokio::process::{Child, ChildStdin, ChildStdout, Command};
29use tokio::sync::{Mutex, mpsc, oneshot};
30use tokio::task::JoinHandle;
31use tokio::time::sleep;
32use tokio_util::io::StreamReader;
33
/// Protocol revision this client proposes in its `initialize` request.
const MCP_LATEST_PROTOCOL_VERSION: &str = "2025-11-25";

/// Protocol revisions accepted in a server's `initialize` response, newest
/// first. The head of the list is the revision the client proposes, so a
/// server that simply echoes it back is always accepted.
const MCP_SUPPORTED_PROTOCOL_VERSIONS: &[&str] = &[
    MCP_LATEST_PROTOCOL_VERSION,
    "2025-06-18",
    "2025-03-26",
    "2024-11-05",
];
37
38#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
53pub struct McpServerId(pub String);
54
55impl McpServerId {
56 pub fn new(value: impl Into<String>) -> Self {
58 Self(value.into())
59 }
60}
61
62impl fmt::Display for McpServerId {
63 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64 self.0.fmt(f)
65 }
66}
67
/// Launch settings for an MCP server reached over a child process's
/// stdin/stdout.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StdioTransportConfig {
    /// Program to execute.
    pub command: String,
    /// Arguments passed to `command`, in order.
    pub args: Vec<String>,
    /// Extra environment variables set on the child process.
    pub env: Vec<(String, String)>,
    /// Working directory for the child; `None` leaves it unchanged.
    pub cwd: Option<std::path::PathBuf>,
}

impl StdioTransportConfig {
    /// Starts a config for `command` with no arguments, no extra environment
    /// and no working-directory override.
    pub fn new(command: impl Into<String>) -> Self {
        let command: String = command.into();
        Self {
            command,
            args: vec![],
            env: vec![],
            cwd: None,
        }
    }

    /// Appends one command-line argument.
    pub fn with_arg(self, arg: impl Into<String>) -> Self {
        let mut config = self;
        config.args.push(arg.into());
        config
    }

    /// Appends one environment variable entry.
    pub fn with_env(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let entry = (key.into(), value.into());
        let mut config = self;
        config.env.push(entry);
        config
    }

    /// Sets the child's working directory, replacing any earlier value.
    pub fn with_cwd(self, cwd: impl Into<std::path::PathBuf>) -> Self {
        Self {
            cwd: Some(cwd.into()),
            ..self
        }
    }
}
126
/// Connection settings for the legacy HTTP+SSE MCP transport.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SseTransportConfig {
    /// URL of the server's SSE stream.
    pub url: String,
    /// Extra HTTP headers added to the stream GET and to every POST.
    pub headers: Vec<(String, String)>,
}

impl SseTransportConfig {
    /// Creates a config for `url` with no extra headers.
    pub fn new(url: impl Into<String>) -> Self {
        let headers = Vec::new();
        Self {
            url: url.into(),
            headers,
        }
    }

    /// Appends one header; earlier entries with the same key are kept.
    pub fn with_header(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let header = (key.into(), value.into());
        let mut config = self;
        config.headers.push(header);
        config
    }
}
164
/// Connection settings for the Streamable HTTP MCP transport.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StreamableHttpTransportConfig {
    /// URL of the server's MCP endpoint.
    pub url: String,
    /// Extra HTTP headers added to every request.
    pub headers: Vec<(String, String)>,
}

impl StreamableHttpTransportConfig {
    /// Creates a config for `url` with no extra headers.
    pub fn new(url: impl Into<String>) -> Self {
        let url: String = url.into();
        StreamableHttpTransportConfig {
            url,
            headers: vec![],
        }
    }

    /// Appends one header; earlier entries with the same key are kept.
    pub fn with_header(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let mut updated = self;
        let (name, val) = (key.into(), value.into());
        updated.headers.push((name, val));
        updated
    }
}
202
/// How to reach an MCP server.
#[derive(Clone)]
pub enum McpTransportBinding {
    /// Spawn a child process and exchange newline-delimited JSON over its
    /// stdin/stdout.
    Stdio(StdioTransportConfig),
    /// Streamable HTTP: each message is POSTed; replies come back as JSON or
    /// as an SSE stream.
    StreamableHttp(StreamableHttpTransportConfig),
    /// Legacy HTTP+SSE: a long-lived GET event stream plus a POST endpoint
    /// announced by the server.
    Sse(SseTransportConfig),
    /// Caller-supplied transport factory.
    Custom(Arc<dyn McpTransportFactory>),
}
220
/// Everything needed to connect to one MCP server.
#[derive(Clone)]
pub struct McpServerConfig {
    /// Identifier of the server; copied into the resulting connection.
    pub id: McpServerId,
    /// Transport used to reach the server.
    pub transport: McpTransportBinding,
    /// Free-form metadata attached to this config.
    pub metadata: MetadataMap,
}
250
251impl McpServerConfig {
252 pub fn new(id: impl Into<String>, transport: McpTransportBinding) -> Self {
259 Self {
260 id: McpServerId::new(id),
261 transport,
262 metadata: MetadataMap::new(),
263 }
264 }
265
266 pub fn stdio(id: impl Into<String>, command: impl Into<String>) -> Self {
268 Self::new(
269 id,
270 McpTransportBinding::Stdio(StdioTransportConfig::new(command)),
271 )
272 }
273
274 pub fn sse(id: impl Into<String>, url: impl Into<String>) -> Self {
276 Self::new(id, McpTransportBinding::Sse(SseTransportConfig::new(url)))
277 }
278
279 pub fn streamable_http(id: impl Into<String>, url: impl Into<String>) -> Self {
281 Self::new(
282 id,
283 McpTransportBinding::StreamableHttp(StreamableHttpTransportConfig::new(url)),
284 )
285 }
286
287 pub fn with_metadata(mut self, metadata: MetadataMap) -> Self {
289 self.metadata = metadata;
290 self
291 }
292}
293
/// One raw JSON-RPC message exchanged with an MCP server.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct McpFrame {
    /// The complete JSON-RPC message (request, notification or response).
    pub value: Value,
}
303
/// Produces connected [`McpTransport`] instances.
///
/// Built-in transports have their own factories; implement this trait to
/// plug in a custom transport via [`McpTransportBinding::Custom`].
#[async_trait]
pub trait McpTransportFactory: Send + Sync {
    /// Establishes a new transport connection.
    async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError>;
}
319
/// Message channel to an MCP server.
#[async_trait]
pub trait McpTransport: Send + Sync {
    /// Sends one JSON-RPC message to the server.
    async fn send(&mut self, message: McpFrame) -> Result<(), McpError>;
    /// Returns the next message from the server, or `Ok(None)` when no
    /// further message is available (`McpConnection` treats that as the
    /// transport having closed).
    async fn recv(&mut self) -> Result<Option<McpFrame>, McpError>;
    /// Shuts the transport down.
    async fn close(&mut self) -> Result<(), McpError>;
}
337
338pub struct StdioTransportFactory {
343 config: StdioTransportConfig,
344}
345
346impl StdioTransportFactory {
347 pub fn new(config: StdioTransportConfig) -> Self {
349 Self { config }
350 }
351}
352
353#[async_trait]
354impl McpTransportFactory for StdioTransportFactory {
355 async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
356 let mut command = Command::new(&self.config.command);
357 command.args(&self.config.args);
358 command.stdin(Stdio::piped());
359 command.stdout(Stdio::piped());
360 command.stderr(Stdio::inherit());
361
362 if let Some(cwd) = &self.config.cwd {
363 command.current_dir(cwd);
364 }
365
366 for (key, value) in &self.config.env {
367 command.env(key, value);
368 }
369
370 let mut child = command.spawn().map_err(McpError::Io)?;
371 let stdin = child
372 .stdin
373 .take()
374 .ok_or_else(|| McpError::Transport("failed to capture MCP stdin".into()))?;
375 let stdout = child
376 .stdout
377 .take()
378 .ok_or_else(|| McpError::Transport("failed to capture MCP stdout".into()))?;
379
380 Ok(Box::new(StdioTransport {
381 child,
382 stdin,
383 stdout: BufReader::new(stdout),
384 }))
385 }
386}
387
388pub struct SseTransportFactory {
393 config: SseTransportConfig,
394}
395
396impl SseTransportFactory {
397 pub fn new(config: SseTransportConfig) -> Self {
399 Self { config }
400 }
401}
402
403pub struct StreamableHttpTransportFactory {
408 config: StreamableHttpTransportConfig,
409}
410
411impl StreamableHttpTransportFactory {
412 pub fn new(config: StreamableHttpTransportConfig) -> Self {
414 Self { config }
415 }
416}
417
418#[async_trait]
419impl McpTransportFactory for SseTransportFactory {
420 async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
421 let client = Client::builder()
422 .user_agent(concat!("agentkit-mcp/", env!("CARGO_PKG_VERSION")))
423 .build()
424 .map_err(McpError::Http)?;
425
426 let mut request = client
427 .get(&self.config.url)
428 .header("Accept", "text/event-stream")
429 .header("Cache-Control", "no-cache");
430
431 for (key, value) in &self.config.headers {
432 request = request.header(key, value);
433 }
434
435 let response = request.send().await.map_err(McpError::Http)?;
436 let status = response.status();
437 if !status.is_success() {
438 let body = response
439 .text()
440 .await
441 .unwrap_or_else(|_| "<unreadable response body>".into());
442 return Err(McpError::Transport(format!(
443 "SSE connection failed with status {status}: {body}"
444 )));
445 }
446
447 let response_url = response.url().clone();
448 let stream = response.bytes_stream().map_err(std::io::Error::other);
449 let reader = BufReader::new(StreamReader::new(stream));
450 let (frame_tx, frame_rx) = mpsc::unbounded_channel();
451 let (endpoint_tx, endpoint_rx) = oneshot::channel();
452 let read_task = tokio::spawn(read_sse_stream(reader, response_url, frame_tx, endpoint_tx));
453
454 let endpoint_url = endpoint_rx
455 .await
456 .map_err(|_| McpError::Transport("SSE stream closed before endpoint event".into()))??;
457
458 Ok(Box::new(SseTransport {
459 client,
460 endpoint_url,
461 headers: self.config.headers.clone(),
462 frame_rx,
463 read_task,
464 }))
465 }
466}
467
468#[async_trait]
469impl McpTransportFactory for StreamableHttpTransportFactory {
470 async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
471 let client = Client::builder()
472 .user_agent(concat!("agentkit-mcp/", env!("CARGO_PKG_VERSION")))
473 .build()
474 .map_err(McpError::Http)?;
475
476 let endpoint_url = Url::parse(&self.config.url)
477 .map_err(|error| McpError::Transport(format!("invalid MCP endpoint URL: {error}")))?;
478
479 Ok(Box::new(StreamableHttpTransport {
480 client,
481 endpoint_url,
482 headers: self.config.headers.clone(),
483 protocol_version: None,
484 session_id: None,
485 pending_frames: VecDeque::new(),
486 }))
487 }
488}
489
/// Transport over a spawned child process. Messages are written to its stdin
/// and read line by line from its stdout.
struct StdioTransport {
    /// Handle to the spawned server process.
    child: Child,
    /// Pipe into the server's stdin.
    stdin: ChildStdin,
    /// Buffered pipe from the server's stdout.
    stdout: BufReader<ChildStdout>,
}
495
/// Legacy HTTP+SSE transport. Outgoing messages are POSTed to
/// `endpoint_url`; incoming messages are forwarded by a background task that
/// reads the SSE stream.
struct SseTransport {
    /// HTTP client used for POSTs.
    client: Client,
    /// POST endpoint resolved from the server's endpoint event.
    endpoint_url: Url,
    /// Extra headers added to each POST.
    headers: Vec<(String, String)>,
    /// Frames (or read errors) forwarded by `read_task`.
    frame_rx: mpsc::UnboundedReceiver<Result<McpFrame, McpError>>,
    /// Background task reading the SSE stream; aborted on `close`.
    read_task: JoinHandle<()>,
}
503
/// Streamable HTTP transport. Each outgoing message is POSTed to
/// `endpoint_url`. Replies (JSON bodies or SSE streams) are buffered in
/// `pending_frames`, from which `recv` hands them out.
struct StreamableHttpTransport {
    /// HTTP client shared by all requests.
    client: Client,
    /// The server's MCP endpoint.
    endpoint_url: Url,
    /// Extra headers added to every request.
    headers: Vec<(String, String)>,
    /// Protocol version from the initialize result; passed to
    /// `apply_streamable_http_headers` on later requests.
    protocol_version: Option<String>,
    /// `MCP-Session-Id` captured from the initialize response, if any.
    session_id: Option<String>,
    /// Frames received but not yet returned by `recv`.
    pending_frames: VecDeque<McpFrame>,
}
512
513#[async_trait]
514impl McpTransport for StdioTransport {
515 async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
516 let mut encoded = serde_json::to_vec(&message.value).map_err(McpError::Serialize)?;
517 encoded.push(b'\n');
518 self.stdin.write_all(&encoded).await.map_err(McpError::Io)?;
519 self.stdin.flush().await.map_err(McpError::Io)?;
520 Ok(())
521 }
522
523 async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
524 let mut line = String::new();
525 let read = self
526 .stdout
527 .read_line(&mut line)
528 .await
529 .map_err(McpError::Io)?;
530 if read == 0 {
531 return Ok(None);
532 }
533
534 let value = serde_json::from_str(line.trim()).map_err(McpError::Serialize)?;
535 Ok(Some(McpFrame { value }))
536 }
537
538 async fn close(&mut self) -> Result<(), McpError> {
539 let _ = self.stdin.shutdown().await;
540 let _ = self.child.kill().await;
541 Ok(())
542 }
543}
544
545#[async_trait]
546impl McpTransport for SseTransport {
547 async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
548 let mut request = self
549 .client
550 .post(self.endpoint_url.clone())
551 .header("Content-Type", "application/json");
552
553 for (key, value) in &self.headers {
554 request = request.header(key, value);
555 }
556
557 let response = request
558 .json(&message.value)
559 .send()
560 .await
561 .map_err(McpError::Http)?;
562 let status = response.status();
563 if !status.is_success() {
564 let body = response
565 .text()
566 .await
567 .unwrap_or_else(|_| "<unreadable response body>".into());
568 return Err(McpError::Transport(format!(
569 "SSE POST failed with status {status}: {body}"
570 )));
571 }
572
573 Ok(())
574 }
575
576 async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
577 match self.frame_rx.recv().await {
578 Some(Ok(frame)) => Ok(Some(frame)),
579 Some(Err(error)) => Err(error),
580 None => Ok(None),
581 }
582 }
583
584 async fn close(&mut self) -> Result<(), McpError> {
585 self.read_task.abort();
586 Ok(())
587 }
588}
589
590#[async_trait]
591impl McpTransport for StreamableHttpTransport {
592 async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
593 let is_request = is_jsonrpc_request(&message.value);
594 let request_id = message.value.get("id").cloned();
595 let is_initialize =
596 message.value.get("method").and_then(Value::as_str) == Some("initialize");
597
598 let mut request = self
599 .client
600 .post(self.endpoint_url.clone())
601 .header("Content-Type", "application/json")
602 .header("Accept", "application/json, text/event-stream");
603
604 request = apply_streamable_http_headers(
605 request,
606 &self.headers,
607 self.protocol_version.as_deref(),
608 self.session_id.as_deref(),
609 );
610
611 let response = request
612 .json(&message.value)
613 .send()
614 .await
615 .map_err(McpError::Http)?;
616
617 if is_initialize {
618 self.capture_session_id(response.headers());
619 }
620
621 let status = response.status();
622 if !status.is_success() {
623 return Err(
624 streamable_http_status_error("Streamable HTTP POST", status, response).await,
625 );
626 }
627
628 if !is_request {
629 return Ok(());
630 }
631
632 let content_type = response
633 .headers()
634 .get(reqwest::header::CONTENT_TYPE)
635 .and_then(|value| value.to_str().ok())
636 .unwrap_or_default()
637 .to_string();
638
639 if content_type.starts_with("application/json") {
640 let value = response.json::<Value>().await.map_err(McpError::Http)?;
641 self.maybe_update_protocol_version(&message.value, &value)?;
642 self.pending_frames.push_back(McpFrame { value });
643 return Ok(());
644 }
645
646 if !content_type.starts_with("text/event-stream") {
647 let body = response
648 .text()
649 .await
650 .unwrap_or_else(|_| "<unreadable response body>".into());
651 return Err(McpError::Transport(format!(
652 "unexpected Streamable HTTP response content type {content_type:?}: {body}"
653 )));
654 }
655
656 let request_id = request_id.ok_or_else(|| {
657 McpError::Protocol("JSON-RPC request over Streamable HTTP is missing an id".into())
658 })?;
659 self.collect_streamable_http_response(response, &message.value, &request_id)
660 .await
661 }
662
663 async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
664 Ok(self.pending_frames.pop_front())
665 }
666
667 async fn close(&mut self) -> Result<(), McpError> {
668 let Some(session_id) = self.session_id.clone() else {
669 return Ok(());
670 };
671
672 let mut request = self.client.delete(self.endpoint_url.clone());
673 request = apply_streamable_http_headers(
674 request,
675 &self.headers,
676 self.protocol_version.as_deref(),
677 Some(session_id.as_str()),
678 );
679
680 let response = request.send().await.map_err(McpError::Http)?;
681 if response.status().is_success()
682 || response.status() == StatusCode::METHOD_NOT_ALLOWED
683 || response.status() == StatusCode::NOT_FOUND
684 {
685 self.session_id = None;
686 return Ok(());
687 }
688
689 Err(
690 streamable_http_status_error("Streamable HTTP DELETE", response.status(), response)
691 .await,
692 )
693 }
694}
695
696impl StreamableHttpTransport {
697 async fn collect_streamable_http_response(
698 &mut self,
699 response: reqwest::Response,
700 request_message: &Value,
701 request_id: &Value,
702 ) -> Result<(), McpError> {
703 let mut retry_delay = Duration::from_millis(0);
704 let mut last_event_id = None;
705 let mut saw_response = false;
706
707 saw_response |= self
708 .read_streamable_http_events(
709 response,
710 request_message,
711 request_id,
712 &mut last_event_id,
713 &mut retry_delay,
714 )
715 .await?;
716
717 while !saw_response && last_event_id.is_some() {
718 if !retry_delay.is_zero() {
719 sleep(retry_delay).await;
720 }
721
722 let response = self
723 .resume_streamable_http_stream(last_event_id.as_deref().unwrap())
724 .await?;
725 saw_response |= self
726 .read_streamable_http_events(
727 response,
728 request_message,
729 request_id,
730 &mut last_event_id,
731 &mut retry_delay,
732 )
733 .await?;
734 }
735
736 Ok(())
737 }
738
739 async fn read_streamable_http_events(
740 &mut self,
741 response: reqwest::Response,
742 request_message: &Value,
743 request_id: &Value,
744 last_event_id: &mut Option<String>,
745 retry_delay: &mut Duration,
746 ) -> Result<bool, McpError> {
747 let stream = response.bytes_stream().map_err(std::io::Error::other);
748 let mut reader = BufReader::new(StreamReader::new(stream));
749 let mut saw_response = false;
750
751 while let Some(event) = read_next_sse_event(&mut reader).await? {
752 if let Some(id) = event.id.clone() {
753 *last_event_id = Some(id);
754 }
755 if let Some(retry_ms) = event.retry_ms {
756 *retry_delay = Duration::from_millis(retry_ms);
757 }
758
759 let Some(frame) = streamable_http_event_to_frame(event)? else {
760 continue;
761 };
762
763 self.maybe_update_protocol_version(request_message, &frame.value)?;
764 if frame.value.get("id") == Some(request_id) {
765 saw_response = true;
766 }
767 self.pending_frames.push_back(frame);
768 }
769
770 Ok(saw_response)
771 }
772
773 async fn resume_streamable_http_stream(
774 &self,
775 last_event_id: &str,
776 ) -> Result<reqwest::Response, McpError> {
777 let mut request = self
778 .client
779 .get(self.endpoint_url.clone())
780 .header("Accept", "text/event-stream")
781 .header("Cache-Control", "no-cache")
782 .header("Last-Event-ID", last_event_id);
783
784 request = apply_streamable_http_headers(
785 request,
786 &self.headers,
787 self.protocol_version.as_deref(),
788 self.session_id.as_deref(),
789 );
790
791 let response = request.send().await.map_err(McpError::Http)?;
792 let status = response.status();
793 if !status.is_success() {
794 return Err(
795 streamable_http_status_error("Streamable HTTP GET", status, response).await,
796 );
797 }
798
799 let content_type = response
800 .headers()
801 .get(reqwest::header::CONTENT_TYPE)
802 .and_then(|value| value.to_str().ok())
803 .unwrap_or_default();
804 if !content_type.starts_with("text/event-stream") {
805 let content_type = content_type.to_string();
806 let body = response
807 .text()
808 .await
809 .unwrap_or_else(|_| "<unreadable response body>".into());
810 return Err(McpError::Transport(format!(
811 "Streamable HTTP GET expected text/event-stream, got {content_type:?}: {body}"
812 )));
813 }
814
815 Ok(response)
816 }
817
818 fn maybe_update_protocol_version(
819 &mut self,
820 request_message: &Value,
821 response_value: &Value,
822 ) -> Result<(), McpError> {
823 if request_message.get("method").and_then(Value::as_str) != Some("initialize") {
824 return Ok(());
825 }
826
827 let protocol_version = response_value
828 .get("result")
829 .and_then(|result| result.get("protocolVersion"))
830 .and_then(Value::as_str);
831
832 if let Some(protocol_version) = protocol_version {
833 self.protocol_version = Some(protocol_version.to_string());
834 }
835
836 Ok(())
837 }
838
839 fn capture_session_id(&mut self, headers: &reqwest::header::HeaderMap) {
840 self.session_id = headers
841 .get("MCP-Session-Id")
842 .and_then(|value| value.to_str().ok())
843 .map(|value| value.to_string());
844 }
845}
846
/// A tool advertised by an MCP server via `tools/list`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct McpToolDescriptor {
    /// Tool name as the server knows it; sent back in `tools/call`.
    pub name: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Schema describing the tool's input; copied into the invocable spec.
    pub input_schema: Value,
    /// Additional metadata; copied into the invocable spec.
    pub metadata: MetadataMap,
}
863
/// A resource advertised by an MCP server via `resources/list`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct McpResourceDescriptor {
    /// Resource identifier (presumably the resource URI — confirm against
    /// `parse_resource_descriptor`).
    pub id: String,
    /// Display name.
    pub name: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Optional MIME type of the resource contents.
    pub mime_type: Option<String>,
    /// Additional metadata.
    pub metadata: MetadataMap,
}
881
/// A prompt advertised by an MCP server via `prompts/list`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct McpPromptDescriptor {
    /// Prompt identifier.
    pub id: String,
    /// Display name.
    pub name: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Schema describing the prompt's arguments.
    pub input_schema: Value,
    /// Additional metadata.
    pub metadata: MetadataMap,
}
899
/// Everything `McpConnection::discover` learned about one server.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct McpDiscoverySnapshot {
    /// Server the snapshot was taken from.
    pub server_id: McpServerId,
    /// Tools reported by `tools/list`.
    pub tools: Vec<McpToolDescriptor>,
    /// Resources reported by `resources/list`.
    pub resources: Vec<McpResourceDescriptor>,
    /// Prompts reported by `prompts/list`.
    pub prompts: Vec<McpPromptDescriptor>,
    /// Additional metadata (empty when produced by `discover`).
    pub metadata: MetadataMap,
}
918
/// An initialized session with one MCP server.
///
/// Requests take the transport lock for the whole send/receive round trip,
/// so they are processed one at a time.
pub struct McpConnection {
    /// Server this connection talks to.
    server_id: McpServerId,
    /// Underlying transport.
    transport: Mutex<Box<dyn McpTransport>>,
    /// Credentials merged into request params under `"auth"`, when set.
    auth: Mutex<Option<MetadataMap>>,
    /// Next JSON-RPC request id; id 0 is used by the initialize handshake.
    next_id: AtomicU64,
}
954
/// Outcome of an MCP operation, e.g. one replayed after auth was resolved.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum McpOperationResult {
    /// Outcome of a connect operation, carrying the discovery snapshot.
    Connected(McpDiscoverySnapshot),
    /// Raw `result` of a `tools/call` request.
    Tool(Value),
    /// Contents returned by `resources/read`.
    Resource(ResourceContents),
    /// Contents returned by `prompts/get`.
    Prompt(PromptContents),
}
970
971impl McpConnection {
972 pub async fn connect(config: &McpServerConfig) -> Result<Self, McpError> {
980 Self::connect_with_auth(config, None).await
981 }
982
983 async fn connect_with_auth(
984 config: &McpServerConfig,
985 auth: Option<&MetadataMap>,
986 ) -> Result<Self, McpError> {
987 let factory: Arc<dyn McpTransportFactory> = match &config.transport {
988 McpTransportBinding::Stdio(binding) => {
989 Arc::new(StdioTransportFactory::new(binding.clone()))
990 }
991 McpTransportBinding::StreamableHttp(binding) => {
992 Arc::new(StreamableHttpTransportFactory::new(binding.clone()))
993 }
994 McpTransportBinding::Sse(binding) => {
995 Arc::new(SseTransportFactory::new(binding.clone()))
996 }
997 McpTransportBinding::Custom(factory) => factory.clone(),
998 };
999
1000 let mut transport = factory.connect().await?;
1001 let mut params = serde_json::Map::new();
1002 params.insert(
1003 "protocolVersion".into(),
1004 Value::String(MCP_LATEST_PROTOCOL_VERSION.into()),
1005 );
1006 params.insert("capabilities".into(), json!({}));
1007 params.insert(
1008 "clientInfo".into(),
1009 json!({
1010 "name": "agentkit-mcp",
1011 "version": env!("CARGO_PKG_VERSION")
1012 }),
1013 );
1014 if let Some(auth) = auth {
1015 params.insert("auth".into(), metadata_to_value(auth));
1016 }
1017 let init_params = Value::Object(params.clone());
1018 transport
1019 .send(McpFrame {
1020 value: json!({
1021 "jsonrpc": "2.0",
1022 "id": 0,
1023 "method": "initialize",
1024 "params": init_params.clone()
1025 }),
1026 })
1027 .await?;
1028 let init_response = transport.recv().await?.ok_or_else(|| {
1029 McpError::Transport("transport closed during MCP initialization".into())
1030 })?;
1031 if let Some(error) = init_response.value.get("error") {
1032 if let Some(auth_request) =
1033 parse_auth_request(&config.id, "initialize", &init_params, error)
1034 {
1035 return Err(McpError::AuthRequired(Box::new(auth_request)));
1036 }
1037 return Err(McpError::Invocation(error.to_string()));
1038 }
1039 let negotiated_protocol_version = init_response
1040 .value
1041 .get("result")
1042 .and_then(|result| result.get("protocolVersion"))
1043 .and_then(Value::as_str)
1044 .ok_or_else(|| {
1045 McpError::Protocol("initialize response missing result.protocolVersion".into())
1046 })?;
1047 if !MCP_SUPPORTED_PROTOCOL_VERSIONS.contains(&negotiated_protocol_version) {
1048 return Err(McpError::Protocol(format!(
1049 "unsupported MCP protocol version negotiated during initialize: {negotiated_protocol_version}"
1050 )));
1051 }
1052 transport
1053 .send(McpFrame {
1054 value: json!({
1055 "jsonrpc": "2.0",
1056 "method": "notifications/initialized",
1057 "params": {}
1058 }),
1059 })
1060 .await?;
1061
1062 Ok(Self {
1063 server_id: config.id.clone(),
1064 transport: Mutex::new(transport),
1065 auth: Mutex::new(auth.cloned()),
1066 next_id: AtomicU64::new(1),
1067 })
1068 }
1069
1070 pub fn server_id(&self) -> &McpServerId {
1072 &self.server_id
1073 }
1074
1075 pub async fn close(&self) -> Result<(), McpError> {
1081 let mut transport = self.transport.lock().await;
1082 transport.close().await
1083 }
1084
1085 pub async fn resolve_auth(&self, resolution: AuthResolution) -> Result<(), McpError> {
1095 let mut auth = self.auth.lock().await;
1096 match resolution {
1097 AuthResolution::Provided { credentials, .. } => {
1098 *auth = Some(credentials);
1099 }
1100 AuthResolution::Cancelled { .. } => {
1101 *auth = None;
1102 }
1103 }
1104 Ok(())
1105 }
1106
1107 pub async fn discover(&self) -> Result<McpDiscoverySnapshot, McpError> {
1115 Ok(McpDiscoverySnapshot {
1116 server_id: self.server_id.clone(),
1117 tools: self.list_tools().await?,
1118 resources: self.list_resources().await?,
1119 prompts: self.list_prompts().await?,
1120 metadata: MetadataMap::new(),
1121 })
1122 }
1123
1124 pub async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>, McpError> {
1130 let result = self.request("tools/list", json!({})).await?;
1131 result
1132 .get("tools")
1133 .and_then(Value::as_array)
1134 .cloned()
1135 .unwrap_or_default()
1136 .into_iter()
1137 .map(parse_tool_descriptor)
1138 .collect()
1139 }
1140
1141 pub async fn list_resources(&self) -> Result<Vec<McpResourceDescriptor>, McpError> {
1147 let result = self.request("resources/list", json!({})).await?;
1148 result
1149 .get("resources")
1150 .and_then(Value::as_array)
1151 .cloned()
1152 .unwrap_or_default()
1153 .into_iter()
1154 .map(parse_resource_descriptor)
1155 .collect()
1156 }
1157
1158 pub async fn list_prompts(&self) -> Result<Vec<McpPromptDescriptor>, McpError> {
1164 let result = self.request("prompts/list", json!({})).await?;
1165 result
1166 .get("prompts")
1167 .and_then(Value::as_array)
1168 .cloned()
1169 .unwrap_or_default()
1170 .into_iter()
1171 .map(parse_prompt_descriptor)
1172 .collect()
1173 }
1174
1175 pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<Value, McpError> {
1187 self.request(
1188 "tools/call",
1189 json!({
1190 "name": name,
1191 "arguments": arguments,
1192 }),
1193 )
1194 .await
1195 }
1196
1197 pub async fn read_resource(&self, uri: &str) -> Result<ResourceContents, McpError> {
1207 let result = self
1208 .request(
1209 "resources/read",
1210 json!({
1211 "uri": uri,
1212 }),
1213 )
1214 .await?;
1215 let content = result
1216 .get("contents")
1217 .and_then(Value::as_array)
1218 .and_then(|values| values.first())
1219 .cloned()
1220 .ok_or_else(|| McpError::Protocol("resources/read returned no contents".into()))?;
1221
1222 let data = if let Some(text) = content.get("text").and_then(Value::as_str) {
1223 DataRef::InlineText(text.into())
1224 } else if let Some(found_uri) = content.get("uri").and_then(Value::as_str) {
1225 DataRef::Uri(found_uri.into())
1226 } else {
1227 return Err(McpError::Protocol(
1228 "unsupported resource content shape".into(),
1229 ));
1230 };
1231
1232 Ok(ResourceContents {
1233 data,
1234 metadata: MetadataMap::new(),
1235 })
1236 }
1237
1238 pub async fn get_prompt(
1249 &self,
1250 name: &str,
1251 arguments: Value,
1252 ) -> Result<PromptContents, McpError> {
1253 let result = self
1254 .request(
1255 "prompts/get",
1256 json!({
1257 "name": name,
1258 "arguments": arguments,
1259 }),
1260 )
1261 .await?;
1262 let items = result
1263 .get("messages")
1264 .and_then(Value::as_array)
1265 .cloned()
1266 .unwrap_or_default()
1267 .into_iter()
1268 .map(parse_prompt_message)
1269 .collect::<Result<Vec<_>, _>>()?;
1270
1271 Ok(PromptContents {
1272 items,
1273 metadata: MetadataMap::new(),
1274 })
1275 }
1276
1277 async fn request(&self, method: &str, params: Value) -> Result<Value, McpError> {
1278 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
1279 let params = self.enrich_params(params.clone()).await;
1280 let mut transport = self.transport.lock().await;
1281 transport
1282 .send(McpFrame {
1283 value: json!({
1284 "jsonrpc": "2.0",
1285 "id": id,
1286 "method": method,
1287 "params": params,
1288 }),
1289 })
1290 .await?;
1291
1292 loop {
1293 let Some(frame) = transport.recv().await? else {
1294 return Err(McpError::Transport(
1295 "transport closed while waiting for MCP response".into(),
1296 ));
1297 };
1298
1299 if frame.value.get("id").and_then(Value::as_u64) != Some(id) {
1300 continue;
1301 }
1302
1303 if let Some(error) = frame.value.get("error") {
1304 if let Some(auth_request) =
1305 parse_auth_request(&self.server_id, method, ¶ms, error)
1306 {
1307 return Err(McpError::AuthRequired(Box::new(auth_request)));
1308 }
1309 return Err(McpError::Invocation(error.to_string()));
1310 }
1311
1312 return frame
1313 .value
1314 .get("result")
1315 .cloned()
1316 .ok_or_else(|| McpError::Protocol("MCP response missing result".into()));
1317 }
1318 }
1319
1320 async fn enrich_params(&self, params: Value) -> Value {
1321 let auth = self.auth.lock().await;
1322 let Some(auth) = auth.as_ref() else {
1323 return params;
1324 };
1325
1326 match params {
1327 Value::Object(mut object) => {
1328 object
1329 .entry("auth")
1330 .or_insert_with(|| metadata_to_value(auth));
1331 Value::Object(object)
1332 }
1333 other => other,
1334 }
1335 }
1336
1337 pub async fn replay_auth_operation(
1347 &self,
1348 operation: &AuthOperation,
1349 ) -> Result<McpOperationResult, McpError> {
1350 match operation {
1351 AuthOperation::McpToolCall {
1352 server_id,
1353 tool_name,
1354 input,
1355 ..
1356 } => {
1357 self.ensure_server_match(server_id)?;
1358 self.call_tool(tool_name, input.clone())
1359 .await
1360 .map(McpOperationResult::Tool)
1361 }
1362 AuthOperation::McpResourceRead {
1363 server_id,
1364 resource_id,
1365 ..
1366 } => {
1367 self.ensure_server_match(server_id)?;
1368 self.read_resource(resource_id)
1369 .await
1370 .map(McpOperationResult::Resource)
1371 }
1372 AuthOperation::McpPromptGet {
1373 server_id,
1374 prompt_id,
1375 args,
1376 ..
1377 } => {
1378 self.ensure_server_match(server_id)?;
1379 self.get_prompt(prompt_id, args.clone())
1380 .await
1381 .map(McpOperationResult::Prompt)
1382 }
1383 AuthOperation::ToolCall {
1384 tool_name,
1385 input,
1386 metadata,
1387 ..
1388 } => {
1389 if let Some(server_id) = metadata.get("server_id").and_then(Value::as_str) {
1390 self.ensure_server_match(server_id)?;
1391 }
1392 let tool_name = normalize_mcp_tool_name(self.server_id(), tool_name);
1393 self.call_tool(&tool_name, input.clone())
1394 .await
1395 .map(McpOperationResult::Tool)
1396 }
1397 AuthOperation::McpConnect { .. } => Err(McpError::AuthResolution(
1398 "connect operations must be replayed through the server manager".into(),
1399 )),
1400 AuthOperation::Custom { kind, .. } => Err(McpError::AuthResolution(format!(
1401 "unsupported auth operation for replay: {kind}"
1402 ))),
1403 }
1404 }
1405
1406 fn ensure_server_match(&self, server_id: &str) -> Result<(), McpError> {
1407 if self.server_id.0 == server_id {
1408 Ok(())
1409 } else {
1410 Err(McpError::AuthResolution(format!(
1411 "auth operation targets server {server_id}, but connection is for {}",
1412 self.server_id
1413 )))
1414 }
1415 }
1416}
1417
1418pub struct McpInvocable {
1425 connection: Arc<McpConnection>,
1426 descriptor: McpToolDescriptor,
1427 spec: InvocableSpec,
1428}
1429
1430impl McpInvocable {
1431 pub fn new(connection: Arc<McpConnection>, descriptor: McpToolDescriptor) -> Self {
1438 let spec = InvocableSpec {
1439 name: CapabilityName::new(format!(
1440 "mcp_{}_{}",
1441 connection.server_id(),
1442 descriptor.name
1443 )),
1444 description: descriptor
1445 .description
1446 .clone()
1447 .unwrap_or_else(|| descriptor.name.clone()),
1448 input_schema: descriptor.input_schema.clone(),
1449 metadata: descriptor.metadata.clone(),
1450 };
1451
1452 Self {
1453 connection,
1454 descriptor,
1455 spec,
1456 }
1457 }
1458}
1459
1460#[async_trait]
1461impl Invocable for McpInvocable {
1462 fn spec(&self) -> &InvocableSpec {
1463 &self.spec
1464 }
1465
1466 async fn invoke(
1467 &self,
1468 request: InvocableRequest,
1469 _ctx: &mut CapabilityContext<'_>,
1470 ) -> Result<InvocableResult, CapabilityError> {
1471 let result = self
1472 .connection
1473 .call_tool(&self.descriptor.name, request.input)
1474 .await
1475 .map_err(|error| match error {
1476 McpError::AuthRequired(request) => {
1477 CapabilityError::Unavailable(format!("auth required: {:?}", request))
1478 }
1479 other => CapabilityError::ExecutionFailed(other.to_string()),
1480 })?;
1481
1482 Ok(InvocableResult {
1483 output: value_to_invocable_output(result),
1484 metadata: MetadataMap::new(),
1485 })
1486 }
1487}
1488
1489pub struct McpResourceHandle {
1494 connection: Arc<McpConnection>,
1495 descriptor: ResourceDescriptor,
1496}
1497
1498#[async_trait]
1499impl ResourceProvider for McpResourceHandle {
1500 async fn list_resources(&self) -> Result<Vec<ResourceDescriptor>, CapabilityError> {
1501 Ok(vec![self.descriptor.clone()])
1502 }
1503
1504 async fn read_resource(
1505 &self,
1506 id: &ResourceId,
1507 _ctx: &mut CapabilityContext<'_>,
1508 ) -> Result<ResourceContents, CapabilityError> {
1509 self.connection
1510 .read_resource(&id.0)
1511 .await
1512 .map_err(|error| match error {
1513 McpError::AuthRequired(request) => {
1514 CapabilityError::Unavailable(format!("auth required: {:?}", request))
1515 }
1516 other => CapabilityError::ExecutionFailed(other.to_string()),
1517 })
1518 }
1519}
1520
1521pub struct McpPromptHandle {
1526 connection: Arc<McpConnection>,
1527 descriptor: PromptDescriptor,
1528}
1529
1530#[async_trait]
1531impl PromptProvider for McpPromptHandle {
1532 async fn list_prompts(&self) -> Result<Vec<PromptDescriptor>, CapabilityError> {
1533 Ok(vec![self.descriptor.clone()])
1534 }
1535
1536 async fn get_prompt(
1537 &self,
1538 id: &PromptId,
1539 args: Value,
1540 _ctx: &mut CapabilityContext<'_>,
1541 ) -> Result<PromptContents, CapabilityError> {
1542 self.connection
1543 .get_prompt(&id.0, args)
1544 .await
1545 .map_err(|error| match error {
1546 McpError::AuthRequired(request) => {
1547 CapabilityError::Unavailable(format!("auth required: {:?}", request))
1548 }
1549 other => CapabilityError::ExecutionFailed(other.to_string()),
1550 })
1551 }
1552}
1553
1554pub struct McpCapabilityProvider {
1581 invocables: Vec<Arc<dyn Invocable>>,
1582 resources: Vec<Arc<dyn ResourceProvider>>,
1583 prompts: Vec<Arc<dyn PromptProvider>>,
1584}
1585
1586impl McpCapabilityProvider {
1587 pub fn from_snapshot(connection: Arc<McpConnection>, snapshot: &McpDiscoverySnapshot) -> Self {
1593 let invocables = snapshot
1594 .tools
1595 .iter()
1596 .cloned()
1597 .map(|descriptor| {
1598 Arc::new(McpInvocable::new(connection.clone(), descriptor)) as Arc<dyn Invocable>
1599 })
1600 .collect();
1601
1602 let resources = snapshot
1603 .resources
1604 .iter()
1605 .cloned()
1606 .map(|descriptor| {
1607 Arc::new(McpResourceHandle {
1608 connection: connection.clone(),
1609 descriptor: ResourceDescriptor {
1610 id: ResourceId::new(descriptor.id),
1611 name: descriptor.name,
1612 description: descriptor.description,
1613 mime_type: descriptor.mime_type,
1614 metadata: descriptor.metadata,
1615 },
1616 }) as Arc<dyn ResourceProvider>
1617 })
1618 .collect();
1619
1620 let prompts = snapshot
1621 .prompts
1622 .iter()
1623 .cloned()
1624 .map(|descriptor| {
1625 Arc::new(McpPromptHandle {
1626 connection: connection.clone(),
1627 descriptor: PromptDescriptor {
1628 id: PromptId::new(descriptor.id),
1629 name: descriptor.name,
1630 description: descriptor.description,
1631 input_schema: descriptor.input_schema,
1632 metadata: descriptor.metadata,
1633 },
1634 }) as Arc<dyn PromptProvider>
1635 })
1636 .collect();
1637
1638 Self {
1639 invocables,
1640 resources,
1641 prompts,
1642 }
1643 }
1644
1645 pub fn merge<I>(providers: I) -> Self
1650 where
1651 I: IntoIterator<Item = Self>,
1652 {
1653 let mut invocables = Vec::new();
1654 let mut resources = Vec::new();
1655 let mut prompts = Vec::new();
1656
1657 for provider in providers {
1658 invocables.extend(provider.invocables);
1659 resources.extend(provider.resources);
1660 prompts.extend(provider.prompts);
1661 }
1662
1663 Self {
1664 invocables,
1665 resources,
1666 prompts,
1667 }
1668 }
1669
1670 pub async fn connect(
1679 config: &McpServerConfig,
1680 ) -> Result<(Arc<McpConnection>, Self, McpDiscoverySnapshot), McpError> {
1681 let connection = Arc::new(McpConnection::connect(config).await?);
1682 let snapshot = connection.discover().await?;
1683 let provider = Self::from_snapshot(connection.clone(), &snapshot);
1684
1685 Ok((connection, provider, snapshot))
1686 }
1687}
1688
1689impl CapabilityProvider for McpCapabilityProvider {
1690 fn invocables(&self) -> Vec<Arc<dyn Invocable>> {
1691 self.invocables.clone()
1692 }
1693
1694 fn resources(&self) -> Vec<Arc<dyn ResourceProvider>> {
1695 self.resources.clone()
1696 }
1697
1698 fn prompts(&self) -> Vec<Arc<dyn PromptProvider>> {
1699 self.prompts.clone()
1700 }
1701}
1702
1703#[derive(Clone)]
1709pub struct McpServerHandle {
1710 config: McpServerConfig,
1711 connection: Arc<McpConnection>,
1712 snapshot: McpDiscoverySnapshot,
1713}
1714
1715impl McpServerHandle {
1716 pub fn config(&self) -> &McpServerConfig {
1718 &self.config
1719 }
1720
1721 pub fn server_id(&self) -> &McpServerId {
1723 self.connection.server_id()
1724 }
1725
1726 pub fn connection(&self) -> Arc<McpConnection> {
1728 self.connection.clone()
1729 }
1730
1731 pub fn snapshot(&self) -> &McpDiscoverySnapshot {
1733 &self.snapshot
1734 }
1735
1736 pub fn tool_registry(&self) -> ToolRegistry {
1739 self.snapshot
1740 .tools
1741 .iter()
1742 .cloned()
1743 .fold(ToolRegistry::new(), |registry, descriptor| {
1744 registry.with(McpToolAdapter::new(
1745 self.server_id(),
1746 self.connection.clone(),
1747 descriptor,
1748 ))
1749 })
1750 }
1751
1752 pub fn capability_provider(&self) -> McpCapabilityProvider {
1754 McpCapabilityProvider::from_snapshot(self.connection.clone(), &self.snapshot)
1755 }
1756}
1757
1758#[derive(Default)]
1799pub struct McpServerManager {
1800 configs: BTreeMap<McpServerId, McpServerConfig>,
1801 connections: BTreeMap<McpServerId, McpServerHandle>,
1802 auth: BTreeMap<McpServerId, MetadataMap>,
1803}
1804
1805impl McpServerManager {
1806 pub fn new() -> Self {
1808 Self::default()
1809 }
1810
1811 pub fn with_server(mut self, config: McpServerConfig) -> Self {
1816 self.register_server(config);
1817 self
1818 }
1819
1820 pub fn register_server(&mut self, config: McpServerConfig) -> &mut Self {
1825 self.configs.insert(config.id.clone(), config);
1826 self
1827 }
1828
1829 pub fn connected_server(&self, server_id: &McpServerId) -> Option<&McpServerHandle> {
1831 self.connections.get(server_id)
1832 }
1833
1834 pub fn connected_servers(&self) -> Vec<&McpServerHandle> {
1836 self.connections.values().collect()
1837 }
1838
1839 pub async fn connect_server(
1848 &mut self,
1849 server_id: &McpServerId,
1850 ) -> Result<McpServerHandle, McpError> {
1851 let config = self
1852 .configs
1853 .get(server_id)
1854 .cloned()
1855 .ok_or_else(|| McpError::UnknownServer(server_id.to_string()))?;
1856 let connection =
1857 Arc::new(McpConnection::connect_with_auth(&config, self.auth.get(server_id)).await?);
1858 let snapshot = connection.discover().await?;
1859 let handle = McpServerHandle {
1860 config,
1861 connection,
1862 snapshot,
1863 };
1864 self.connections.insert(server_id.clone(), handle.clone());
1865 Ok(handle)
1866 }
1867
1868 pub async fn connect_all(&mut self) -> Result<Vec<McpServerHandle>, McpError> {
1878 let server_ids = self.configs.keys().cloned().collect::<Vec<_>>();
1879 let mut handles = Vec::with_capacity(server_ids.len());
1880
1881 for server_id in server_ids {
1882 handles.push(self.connect_server(&server_id).await?);
1883 }
1884
1885 Ok(handles)
1886 }
1887
1888 pub async fn refresh_server(
1898 &mut self,
1899 server_id: &McpServerId,
1900 ) -> Result<McpDiscoverySnapshot, McpError> {
1901 let handle = self
1902 .connections
1903 .get_mut(server_id)
1904 .ok_or_else(|| McpError::UnknownServer(server_id.to_string()))?;
1905 let snapshot = handle.connection.discover().await?;
1906 handle.snapshot = snapshot.clone();
1907 Ok(snapshot)
1908 }
1909
1910 pub async fn disconnect_server(&mut self, server_id: &McpServerId) -> Result<(), McpError> {
1919 let Some(handle) = self.connections.remove(server_id) else {
1920 return Err(McpError::UnknownServer(server_id.to_string()));
1921 };
1922 handle.connection.close().await
1923 }
1924
1925 pub async fn resolve_auth(&mut self, resolution: AuthResolution) -> Result<(), McpError> {
1933 let server_id = resolution
1934 .request()
1935 .server_id()
1936 .ok_or_else(|| McpError::AuthResolution("auth resolution missing server id".into()))?;
1937 let server_id = McpServerId::new(server_id);
1938 match &resolution {
1939 AuthResolution::Provided { credentials, .. } => {
1940 self.auth.insert(server_id.clone(), credentials.clone());
1941 }
1942 AuthResolution::Cancelled { .. } => {
1943 self.auth.remove(&server_id);
1944 }
1945 }
1946
1947 if let Some(handle) = self.connections.get(&server_id) {
1948 handle.connection.resolve_auth(resolution).await?;
1949 return Ok(());
1950 }
1951
1952 if self.configs.contains_key(&server_id) {
1953 Ok(())
1954 } else {
1955 Err(McpError::UnknownServer(server_id.to_string()))
1956 }
1957 }
1958
1959 pub async fn resolve_auth_and_resume(
1969 &mut self,
1970 resolution: AuthResolution,
1971 ) -> Result<McpOperationResult, McpError> {
1972 let request = resolution.request().clone();
1973 self.resolve_auth(resolution).await?;
1974 self.replay_auth_request(&request).await
1975 }
1976
1977 pub async fn replay_auth_request(
1987 &mut self,
1988 request: &AuthRequest,
1989 ) -> Result<McpOperationResult, McpError> {
1990 match &request.operation {
1991 AuthOperation::McpConnect { server_id, .. } => {
1992 let server_id = McpServerId::new(server_id);
1993 let handle = self.connect_server(&server_id).await?;
1994 Ok(McpOperationResult::Connected(handle.snapshot.clone()))
1995 }
1996 AuthOperation::McpToolCall { server_id, .. }
1997 | AuthOperation::McpResourceRead { server_id, .. }
1998 | AuthOperation::McpPromptGet { server_id, .. } => {
1999 let connection = self.connection_for_auth_server(server_id).await?;
2000 connection.replay_auth_operation(&request.operation).await
2001 }
2002 AuthOperation::ToolCall { metadata, .. } => {
2003 let server_id = metadata
2004 .get("server_id")
2005 .and_then(Value::as_str)
2006 .ok_or_else(|| {
2007 McpError::AuthResolution(
2008 "tool-call auth replay requires metadata.server_id".into(),
2009 )
2010 })?;
2011 let connection = self.connection_for_auth_server(server_id).await?;
2012 connection.replay_auth_operation(&request.operation).await
2013 }
2014 AuthOperation::Custom { kind, .. } => Err(McpError::AuthResolution(format!(
2015 "unsupported auth operation for replay: {kind}"
2016 ))),
2017 }
2018 }
2019
2020 async fn connection_for_auth_server(
2021 &mut self,
2022 server_id: &str,
2023 ) -> Result<Arc<McpConnection>, McpError> {
2024 let server_id = McpServerId::new(server_id);
2025 if !self.connections.contains_key(&server_id) {
2026 self.connect_server(&server_id).await?;
2027 }
2028 self.connections
2029 .get(&server_id)
2030 .map(McpServerHandle::connection)
2031 .ok_or_else(|| McpError::UnknownServer(server_id.to_string()))
2032 }
2033
2034 pub fn tool_registry(&self) -> ToolRegistry {
2039 self.connections
2040 .values()
2041 .fold(ToolRegistry::new(), |mut registry, handle| {
2042 for tool in handle.snapshot.tools.iter().cloned() {
2043 registry.register(McpToolAdapter::new(
2044 handle.server_id(),
2045 handle.connection.clone(),
2046 tool,
2047 ));
2048 }
2049 registry
2050 })
2051 }
2052
2053 pub fn capability_provider(&self) -> McpCapabilityProvider {
2056 McpCapabilityProvider::merge(
2057 self.connections
2058 .values()
2059 .map(McpServerHandle::capability_provider),
2060 )
2061 }
2062}
2063
2064pub struct McpToolAdapter {
2080 descriptor: McpToolDescriptor,
2081 connection: Arc<McpConnection>,
2082 spec: ToolSpec,
2083}
2084
2085impl McpToolAdapter {
2086 pub fn new(
2094 server_id: &McpServerId,
2095 connection: Arc<McpConnection>,
2096 descriptor: McpToolDescriptor,
2097 ) -> Self {
2098 let spec = ToolSpec {
2099 name: ToolName::new(format!("mcp_{}_{}", server_id, descriptor.name)),
2100 description: descriptor
2101 .description
2102 .clone()
2103 .unwrap_or_else(|| descriptor.name.clone()),
2104 input_schema: descriptor.input_schema.clone(),
2105 annotations: ToolAnnotations::default(),
2106 metadata: descriptor.metadata.clone(),
2107 };
2108
2109 Self {
2110 descriptor,
2111 connection,
2112 spec,
2113 }
2114 }
2115}
2116
2117#[async_trait]
2118impl Tool for McpToolAdapter {
2119 fn spec(&self) -> &ToolSpec {
2120 &self.spec
2121 }
2122
2123 async fn invoke(
2124 &self,
2125 request: ToolRequest,
2126 _ctx: &mut ToolContext<'_>,
2127 ) -> Result<ToolResult, ToolError> {
2128 let result = self
2129 .connection
2130 .call_tool(&self.descriptor.name, request.input)
2131 .await
2132 .map_err(|error| match error {
2133 McpError::AuthRequired(request) => ToolError::AuthRequired(request),
2134 other => ToolError::ExecutionFailed(other.to_string()),
2135 })?;
2136
2137 Ok(ToolResult {
2138 result: ToolResultPart {
2139 call_id: request.call_id,
2140 output: invocable_output_to_tool_output(value_to_invocable_output(result)),
2141 is_error: false,
2142 metadata: MetadataMap::new(),
2143 },
2144 duration: None,
2145 metadata: MetadataMap::new(),
2146 })
2147 }
2148}
2149
2150fn parse_tool_descriptor(value: Value) -> Result<McpToolDescriptor, McpError> {
2151 Ok(McpToolDescriptor {
2152 name: required_string(&value, "name")?,
2153 description: value
2154 .get("description")
2155 .and_then(Value::as_str)
2156 .map(str::to_owned),
2157 input_schema: value
2158 .get("inputSchema")
2159 .cloned()
2160 .unwrap_or_else(|| json!({ "type": "object" })),
2161 metadata: MetadataMap::new(),
2162 })
2163}
2164
2165fn parse_resource_descriptor(value: Value) -> Result<McpResourceDescriptor, McpError> {
2166 Ok(McpResourceDescriptor {
2167 id: required_string(&value, "uri")?,
2168 name: value
2169 .get("name")
2170 .and_then(Value::as_str)
2171 .map(str::to_owned)
2172 .unwrap_or_else(|| {
2173 value
2174 .get("uri")
2175 .and_then(Value::as_str)
2176 .unwrap_or_default()
2177 .to_string()
2178 }),
2179 description: value
2180 .get("description")
2181 .and_then(Value::as_str)
2182 .map(str::to_owned),
2183 mime_type: value
2184 .get("mimeType")
2185 .and_then(Value::as_str)
2186 .map(str::to_owned),
2187 metadata: MetadataMap::new(),
2188 })
2189}
2190
2191fn parse_prompt_descriptor(value: Value) -> Result<McpPromptDescriptor, McpError> {
2192 let name = required_string(&value, "name")?;
2193 let properties = value
2194 .get("arguments")
2195 .and_then(Value::as_array)
2196 .cloned()
2197 .unwrap_or_default()
2198 .into_iter()
2199 .filter_map(|arg| {
2200 let name = arg.get("name")?.as_str()?.to_string();
2201 Some((name, json!({ "type": "string" })))
2202 })
2203 .collect::<serde_json::Map<String, Value>>();
2204
2205 Ok(McpPromptDescriptor {
2206 id: name.clone(),
2207 name,
2208 description: value
2209 .get("description")
2210 .and_then(Value::as_str)
2211 .map(str::to_owned),
2212 input_schema: json!({
2213 "type": "object",
2214 "properties": properties,
2215 }),
2216 metadata: MetadataMap::new(),
2217 })
2218}
2219
2220fn parse_prompt_message(value: Value) -> Result<Item, McpError> {
2221 let role = value.get("role").and_then(Value::as_str).unwrap_or("user");
2222 let kind = match role {
2223 "assistant" => ItemKind::Assistant,
2224 "system" => ItemKind::System,
2225 _ => ItemKind::User,
2226 };
2227
2228 let content = value.get("content").cloned().unwrap_or(Value::Null);
2229 let text = if let Some(text) = content.get("text").and_then(Value::as_str) {
2230 text.to_string()
2231 } else if let Some(text) = content.as_str() {
2232 text.to_string()
2233 } else {
2234 content.to_string()
2235 };
2236
2237 Ok(Item {
2238 id: None,
2239 kind,
2240 parts: vec![Part::Text(TextPart {
2241 text,
2242 metadata: MetadataMap::new(),
2243 })],
2244 metadata: MetadataMap::new(),
2245 })
2246}
2247
2248fn required_string(value: &Value, field: &str) -> Result<String, McpError> {
2249 value
2250 .get(field)
2251 .and_then(Value::as_str)
2252 .map(str::to_owned)
2253 .ok_or_else(|| McpError::Protocol(format!("missing string field {field}")))
2254}
2255
2256fn value_to_invocable_output(value: Value) -> InvocableOutput {
2257 if let Some(content) = value.get("content").and_then(Value::as_array) {
2258 let text = content
2259 .iter()
2260 .filter_map(|item| item.get("text").and_then(Value::as_str))
2261 .collect::<Vec<_>>()
2262 .join("\n");
2263 if !text.is_empty() {
2264 return InvocableOutput::Text(text);
2265 }
2266 }
2267
2268 if let Some(text) = value.as_str() {
2269 InvocableOutput::Text(text.to_string())
2270 } else {
2271 InvocableOutput::Structured(value)
2272 }
2273}
2274
2275fn invocable_output_to_tool_output(output: InvocableOutput) -> ToolOutput {
2276 match output {
2277 InvocableOutput::Text(text) => ToolOutput::Text(text),
2278 InvocableOutput::Structured(value) => ToolOutput::Structured(value),
2279 InvocableOutput::Items(items) => {
2280 ToolOutput::Parts(items.into_iter().flat_map(|item| item.parts).collect())
2281 }
2282 InvocableOutput::Data(data) => ToolOutput::Structured(json!({ "data": data })),
2283 }
2284}
2285
2286fn metadata_to_value(metadata: &MetadataMap) -> Value {
2287 Value::Object(
2288 metadata
2289 .iter()
2290 .map(|(key, value)| (key.clone(), value.clone()))
2291 .collect(),
2292 )
2293}
2294
2295fn parse_auth_request(
2296 server_id: &McpServerId,
2297 method: &str,
2298 params: &Value,
2299 error: &Value,
2300) -> Option<AuthRequest> {
2301 let code = error.get("code").and_then(Value::as_i64);
2302 let message = error.get("message").and_then(Value::as_str);
2303 let data = error.get("data");
2304
2305 let auth_marker = matches!(code, Some(401 | -32001))
2306 || data
2307 .and_then(|data| data.get("auth_required"))
2308 .and_then(Value::as_bool)
2309 == Some(true)
2310 || data.and_then(|data| data.get("auth")).is_some();
2311
2312 if !auth_marker {
2313 return None;
2314 }
2315
2316 let mut challenge = MetadataMap::new();
2317 challenge.insert("server_id".into(), Value::String(server_id.to_string()));
2318 challenge.insert("method".into(), Value::String(method.into()));
2319
2320 if let Some(code) = code {
2321 challenge.insert("code".into(), Value::Number(code.into()));
2322 }
2323 if let Some(message) = message {
2324 challenge.insert("message".into(), Value::String(message.into()));
2325 }
2326 if let Some(data) = data {
2327 challenge.insert("data".into(), data.clone());
2328 }
2329
2330 Some(AuthRequest {
2331 task_id: None,
2332 id: format!("mcp:{}:{}", server_id, method),
2333 provider: format!("mcp.{}", server_id),
2334 operation: auth_operation_for_method(server_id, method, params),
2335 challenge,
2336 })
2337}
2338
2339fn auth_operation_for_method(
2340 server_id: &McpServerId,
2341 method: &str,
2342 params: &Value,
2343) -> AuthOperation {
2344 match method {
2345 "initialize" => AuthOperation::McpConnect {
2346 server_id: server_id.to_string(),
2347 metadata: MetadataMap::new(),
2348 },
2349 "tools/call" => AuthOperation::McpToolCall {
2350 server_id: server_id.to_string(),
2351 tool_name: params
2352 .get("name")
2353 .and_then(Value::as_str)
2354 .unwrap_or_default()
2355 .to_string(),
2356 input: params
2357 .get("arguments")
2358 .cloned()
2359 .unwrap_or_else(|| json!({})),
2360 metadata: MetadataMap::new(),
2361 },
2362 "resources/read" => AuthOperation::McpResourceRead {
2363 server_id: server_id.to_string(),
2364 resource_id: params
2365 .get("uri")
2366 .and_then(Value::as_str)
2367 .unwrap_or_default()
2368 .to_string(),
2369 metadata: MetadataMap::new(),
2370 },
2371 "prompts/get" => AuthOperation::McpPromptGet {
2372 server_id: server_id.to_string(),
2373 prompt_id: params
2374 .get("name")
2375 .and_then(Value::as_str)
2376 .unwrap_or_default()
2377 .to_string(),
2378 args: params
2379 .get("arguments")
2380 .cloned()
2381 .unwrap_or_else(|| json!({})),
2382 metadata: MetadataMap::new(),
2383 },
2384 other => AuthOperation::Custom {
2385 kind: format!("mcp.{other}"),
2386 payload: params.clone(),
2387 metadata: {
2388 let mut metadata = MetadataMap::new();
2389 metadata.insert("server_id".into(), Value::String(server_id.to_string()));
2390 metadata
2391 },
2392 },
2393 }
2394}
2395
2396fn normalize_mcp_tool_name(server_id: &McpServerId, tool_name: &str) -> String {
2397 let prefix = format!("mcp_{server_id}_");
2398 tool_name
2399 .strip_prefix(&prefix)
2400 .unwrap_or(tool_name)
2401 .to_string()
2402}
2403
2404async fn read_sse_stream<R>(
2405 mut reader: R,
2406 response_url: Url,
2407 frame_tx: mpsc::UnboundedSender<Result<McpFrame, McpError>>,
2408 endpoint_tx: oneshot::Sender<Result<Url, McpError>>,
2409) where
2410 R: AsyncBufRead + Unpin,
2411{
2412 let mut endpoint_tx = Some(endpoint_tx);
2413 loop {
2414 match read_next_sse_event(&mut reader).await {
2415 Ok(Some(event)) => {
2416 if let Some(endpoint) = legacy_sse_event_to_endpoint(&response_url, &event) {
2417 if let Some(tx) = endpoint_tx.take() {
2418 let _ = tx.send(endpoint);
2419 }
2420 continue;
2421 }
2422
2423 if let Some(frame) = legacy_sse_event_to_frame(event) {
2424 let _ = frame_tx.send(frame);
2425 }
2426 }
2427 Ok(None) => break,
2428 Err(error) => {
2429 if let Some(tx) = endpoint_tx.take() {
2430 let _ = tx.send(Err(error));
2431 } else {
2432 let _ = frame_tx.send(Err(error));
2433 }
2434 return;
2435 }
2436 }
2437 }
2438
2439 if let Some(tx) = endpoint_tx.take() {
2440 let _ = tx.send(Err(McpError::Transport(
2441 "SSE stream ended before endpoint event".into(),
2442 )));
2443 }
2444}
2445
2446fn resolve_sse_endpoint(response_url: &Url, endpoint: &str) -> Result<Url, McpError> {
2447 response_url
2448 .join(endpoint.trim())
2449 .map_err(|error| McpError::Transport(format!("invalid SSE endpoint URL: {error}")))
2450}
2451
/// One event assembled from the lines of a server-sent-events stream.
#[derive(Debug)]
struct SseEvent {
    /// Value of the `event` field, if any.
    event_name: Option<String>,
    /// All `data` field values joined with `\n`.
    data: String,
    /// Value of the `id` field, if any.
    id: Option<String>,
    /// Value of the `retry` field in milliseconds, if it parsed as a number.
    retry_ms: Option<u64>,
}
2459
2460async fn read_next_sse_event<R>(reader: &mut R) -> Result<Option<SseEvent>, McpError>
2461where
2462 R: AsyncBufRead + Unpin,
2463{
2464 let mut event_name = None;
2465 let mut data_lines = Vec::new();
2466 let mut id = None;
2467 let mut retry_ms = None;
2468
2469 loop {
2470 let mut line = String::new();
2471 let read = reader.read_line(&mut line).await.map_err(McpError::Io)?;
2472 if read == 0 {
2473 if event_name.is_none() && data_lines.is_empty() && id.is_none() && retry_ms.is_none() {
2474 return Ok(None);
2475 }
2476 return Ok(Some(SseEvent {
2477 event_name,
2478 data: data_lines.join("\n"),
2479 id,
2480 retry_ms,
2481 }));
2482 }
2483
2484 let line = line.trim_end_matches(['\r', '\n']);
2485 if line.is_empty() {
2486 if event_name.is_none() && data_lines.is_empty() && id.is_none() && retry_ms.is_none() {
2487 continue;
2488 }
2489 return Ok(Some(SseEvent {
2490 event_name,
2491 data: data_lines.join("\n"),
2492 id,
2493 retry_ms,
2494 }));
2495 }
2496
2497 if line.starts_with(':') {
2498 continue;
2499 }
2500
2501 if let Some(rest) = line.strip_prefix("event:") {
2502 event_name = Some(rest.trim_start().to_string());
2503 continue;
2504 }
2505 if let Some(rest) = line.strip_prefix("data:") {
2506 data_lines.push(rest.trim_start().to_string());
2507 continue;
2508 }
2509 if let Some(rest) = line.strip_prefix("id:") {
2510 id = Some(rest.trim_start().to_string());
2511 continue;
2512 }
2513 if let Some(rest) = line.strip_prefix("retry:") {
2514 retry_ms = rest.trim_start().parse().ok();
2515 }
2516 }
2517}
2518
2519fn legacy_sse_event_to_endpoint(
2520 response_url: &Url,
2521 event: &SseEvent,
2522) -> Option<Result<Url, McpError>> {
2523 if event.event_name.as_deref() != Some("endpoint") {
2524 return None;
2525 }
2526 if event.data.is_empty() {
2527 return Some(Err(McpError::Transport(
2528 "legacy SSE endpoint event is missing data".into(),
2529 )));
2530 }
2531 Some(resolve_sse_endpoint(response_url, &event.data))
2532}
2533
2534fn legacy_sse_event_to_frame(event: SseEvent) -> Option<Result<McpFrame, McpError>> {
2535 let event_name = event.event_name.unwrap_or_else(|| "message".into());
2536 if event_name != "message" || event.data.is_empty() {
2537 return None;
2538 }
2539
2540 Some(
2541 serde_json::from_str(&event.data)
2542 .map_err(McpError::Serialize)
2543 .map(|value| McpFrame { value }),
2544 )
2545}
2546
2547fn streamable_http_event_to_frame(event: SseEvent) -> Result<Option<McpFrame>, McpError> {
2548 let event_name = event.event_name.unwrap_or_else(|| "message".into());
2549 if event_name != "message" || event.data.is_empty() {
2550 return Ok(None);
2551 }
2552
2553 let value = serde_json::from_str(&event.data).map_err(McpError::Serialize)?;
2554 Ok(Some(McpFrame { value }))
2555}
2556
2557fn is_jsonrpc_request(value: &Value) -> bool {
2558 value.get("method").is_some() && value.get("id").is_some()
2559}
2560
2561fn apply_streamable_http_headers(
2562 mut request: reqwest::RequestBuilder,
2563 headers: &[(String, String)],
2564 protocol_version: Option<&str>,
2565 session_id: Option<&str>,
2566) -> reqwest::RequestBuilder {
2567 for (key, value) in headers {
2568 request = request.header(key, value);
2569 }
2570
2571 if let Some(protocol_version) = protocol_version {
2572 request = request.header("MCP-Protocol-Version", protocol_version);
2573 }
2574 if let Some(session_id) = session_id {
2575 request = request.header("MCP-Session-Id", session_id);
2576 }
2577
2578 request
2579}
2580
2581async fn streamable_http_status_error(
2582 operation: &str,
2583 status: StatusCode,
2584 response: reqwest::Response,
2585) -> McpError {
2586 let body = response
2587 .text()
2588 .await
2589 .unwrap_or_else(|_| "<unreadable response body>".into());
2590 McpError::Transport(format!("{operation} failed with status {status}: {body}"))
2591}
2592
/// Errors produced by the MCP connection, transports, adapters, and server manager.
#[derive(Debug, Error)]
pub enum McpError {
    /// Underlying I/O failure, e.g. while reading lines of an SSE stream.
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
    /// Error returned by the `reqwest` HTTP client.
    #[error("http error: {0}")]
    Http(#[from] reqwest::Error),
    /// JSON (de)serialization failure, e.g. an SSE `data` payload that is not valid JSON.
    #[error("serialization error: {0}")]
    Serialize(#[from] serde_json::Error),
    /// Transport-level failure with a description. Examples: an invalid SSE endpoint
    /// URL, a stream that ended before the endpoint event, or a non-success HTTP status.
    #[error("transport error: {0}")]
    Transport(String),
    /// The peer sent a message with an unexpected shape, such as a missing required
    /// string field.
    #[error("protocol error: {0}")]
    Protocol(String),
    /// The server requires authentication. Carries the request that can be resolved
    /// and then replayed.
    #[error("MCP auth required: {0:?}")]
    AuthRequired(Box<AuthRequest>),
    /// An auth resolution or replay could not be applied. Examples: a missing server
    /// id, a server mismatch, or an unsupported operation.
    #[error("auth resolution error: {0}")]
    AuthResolution(String),
    /// Invocation failure carrying a message (not constructed in this part of the
    /// file; presumably used by invocation paths elsewhere — verify).
    #[error("invocation error: {0}")]
    Invocation(String),
    /// No registered configuration or live connection exists for the given server id.
    #[error("unknown MCP server: {0}")]
    UnknownServer(String),
}
2625
2626#[cfg(test)]
2627mod tests {
2628 use std::collections::VecDeque;
2629 use std::sync::{Arc as StdArc, Mutex as StdMutex};
2630
2631 use super::*;
2632 use agentkit_tools_core::{PermissionChecker, PermissionDecision, PermissionRequest};
2633 use tokio::io::{AsyncReadExt, AsyncWriteExt};
2634 use tokio::net::TcpListener;
2635
2636 struct AllowAll;
2637
2638 impl PermissionChecker for AllowAll {
2639 fn evaluate(&self, _request: &dyn PermissionRequest) -> PermissionDecision {
2640 PermissionDecision::Allow
2641 }
2642 }
2643
2644 struct FakeTransport {
2645 recv: VecDeque<Value>,
2646 }
2647
2648 #[async_trait]
2649 impl McpTransport for FakeTransport {
2650 async fn send(&mut self, _message: McpFrame) -> Result<(), McpError> {
2651 Ok(())
2652 }
2653
2654 async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
2655 Ok(self.recv.pop_front().map(|value| McpFrame { value }))
2656 }
2657
2658 async fn close(&mut self) -> Result<(), McpError> {
2659 Ok(())
2660 }
2661 }
2662
2663 fn fake_connection(responses: Vec<Value>) -> McpConnection {
2664 McpConnection {
2665 server_id: McpServerId::new("fake"),
2666 transport: Mutex::new(Box::new(FakeTransport {
2667 recv: responses.into(),
2668 })),
2669 auth: Mutex::new(None),
2670 next_id: AtomicU64::new(1),
2671 }
2672 }
2673
2674 #[derive(Clone)]
2675 struct FakeTransportFactory {
2676 responses: StdArc<StdMutex<VecDeque<Vec<Value>>>>,
2677 }
2678
2679 impl FakeTransportFactory {
2680 fn new(sequences: Vec<Vec<Value>>) -> Self {
2681 Self {
2682 responses: StdArc::new(StdMutex::new(sequences.into())),
2683 }
2684 }
2685 }
2686
2687 #[async_trait]
2688 impl McpTransportFactory for FakeTransportFactory {
2689 async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
2690 let responses =
2691 self.responses.lock().unwrap().pop_front().ok_or_else(|| {
2692 McpError::Transport("no fake transport responses left".into())
2693 })?;
2694 Ok(Box::new(FakeTransport {
2695 recv: responses.into(),
2696 }))
2697 }
2698 }
2699
2700 #[tokio::test]
2701 async fn discovery_parses_snapshot() {
2702 let connection = fake_connection(vec![
2703 json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [{ "name": "echo", "description": "Echo", "inputSchema": {"type": "object"} }] } }),
2704 json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [{ "uri": "file:///tmp/example.txt", "name": "example.txt", "mimeType": "text/plain" }] } }),
2705 json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [{ "name": "summarize", "description": "Summarize", "arguments": [{ "name": "path" }] }] } }),
2706 ]);
2707
2708 let snapshot = connection.discover().await.unwrap();
2709 assert_eq!(snapshot.tools[0].name, "echo");
2710 assert_eq!(snapshot.resources[0].id, "file:///tmp/example.txt");
2711 assert_eq!(snapshot.prompts[0].id, "summarize");
2712 }
2713
2714 #[tokio::test]
2715 async fn tool_adapter_returns_text_output() {
2716 let connection = Arc::new(fake_connection(vec![json!({
2717 "jsonrpc": "2.0",
2718 "id": 1,
2719 "result": { "content": [{ "type": "text", "text": "pong" }] }
2720 })]));
2721 let server_id = connection.server_id().clone();
2722 let adapter = McpToolAdapter::new(
2723 &server_id,
2724 connection,
2725 McpToolDescriptor {
2726 name: "echo".into(),
2727 description: Some("Echo".into()),
2728 input_schema: json!({ "type": "object" }),
2729 metadata: MetadataMap::new(),
2730 },
2731 );
2732 let metadata = MetadataMap::new();
2733 let mut ctx = ToolContext {
2734 capability: CapabilityContext {
2735 session_id: None,
2736 turn_id: None,
2737 metadata: &metadata,
2738 },
2739 permissions: &AllowAll,
2740 resources: &(),
2741 cancellation: None,
2742 };
2743
2744 let result = adapter
2745 .invoke(
2746 ToolRequest {
2747 call_id: "call-1".into(),
2748 tool_name: ToolName::new("mcp_fake_echo"),
2749 input: json!({}),
2750 session_id: "session-1".into(),
2751 turn_id: "turn-1".into(),
2752 metadata: MetadataMap::new(),
2753 },
2754 &mut ctx,
2755 )
2756 .await
2757 .unwrap();
2758
2759 assert_eq!(result.result.output, ToolOutput::Text("pong".into()));
2760 }
2761
2762 #[tokio::test]
2763 async fn request_surfaces_auth_required_errors() {
2764 let connection = fake_connection(vec![json!({
2765 "jsonrpc": "2.0",
2766 "id": 1,
2767 "error": {
2768 "code": -32001,
2769 "message": "authentication required",
2770 "data": {
2771 "auth_required": true,
2772 "scope": "secrets.read"
2773 }
2774 }
2775 })]);
2776
2777 let error = connection.call_tool("echo", json!({})).await.unwrap_err();
2778 match error {
2779 McpError::AuthRequired(request) => {
2780 assert_eq!(request.provider, "mcp.fake");
2781 assert_eq!(
2782 request.challenge.get("method"),
2783 Some(&Value::String("tools/call".into()))
2784 );
2785 assert!(matches!(
2786 request.operation,
2787 AuthOperation::McpToolCall { ref tool_name, .. } if tool_name == "echo"
2788 ));
2789 }
2790 other => panic!("unexpected error: {other:?}"),
2791 }
2792 }
2793
2794 #[tokio::test]
2795 async fn tool_adapter_maps_auth_required_into_tool_error() {
2796 let connection = Arc::new(fake_connection(vec![json!({
2797 "jsonrpc": "2.0",
2798 "id": 1,
2799 "error": {
2800 "code": -32001,
2801 "message": "authentication required",
2802 "data": { "auth_required": true }
2803 }
2804 })]));
2805 let server_id = connection.server_id().clone();
2806 let adapter = McpToolAdapter::new(
2807 &server_id,
2808 connection,
2809 McpToolDescriptor {
2810 name: "echo".into(),
2811 description: Some("Echo".into()),
2812 input_schema: json!({ "type": "object" }),
2813 metadata: MetadataMap::new(),
2814 },
2815 );
2816 let metadata = MetadataMap::new();
2817 let mut ctx = ToolContext {
2818 capability: CapabilityContext {
2819 session_id: None,
2820 turn_id: None,
2821 metadata: &metadata,
2822 },
2823 permissions: &AllowAll,
2824 resources: &(),
2825 cancellation: None,
2826 };
2827
2828 let error = adapter
2829 .invoke(
2830 ToolRequest {
2831 call_id: "call-1".into(),
2832 tool_name: ToolName::new("mcp_fake_echo"),
2833 input: json!({}),
2834 session_id: "session-1".into(),
2835 turn_id: "turn-1".into(),
2836 metadata: MetadataMap::new(),
2837 },
2838 &mut ctx,
2839 )
2840 .await
2841 .unwrap_err();
2842
2843 match error {
2844 ToolError::AuthRequired(request) => {
2845 assert_eq!(request.provider, "mcp.fake");
2846 }
2847 other => panic!("unexpected error: {other:?}"),
2848 }
2849 }
2850
2851 struct RecordingTransport {
2852 recv: VecDeque<Value>,
2853 sent: StdArc<StdMutex<Vec<Value>>>,
2854 }
2855
2856 #[async_trait]
2857 impl McpTransport for RecordingTransport {
2858 async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
2859 self.sent.lock().unwrap().push(message.value);
2860 Ok(())
2861 }
2862
2863 async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
2864 Ok(self.recv.pop_front().map(|value| McpFrame { value }))
2865 }
2866
2867 async fn close(&mut self) -> Result<(), McpError> {
2868 Ok(())
2869 }
2870 }
2871
2872 #[derive(Clone)]
2873 struct RecordingTransportFactory {
2874 responses: StdArc<StdMutex<VecDeque<Vec<Value>>>>,
2875 sent: StdArc<StdMutex<Vec<Value>>>,
2876 }
2877
2878 impl RecordingTransportFactory {
2879 fn new(sequences: Vec<Vec<Value>>) -> Self {
2880 Self {
2881 responses: StdArc::new(StdMutex::new(sequences.into())),
2882 sent: StdArc::new(StdMutex::new(Vec::new())),
2883 }
2884 }
2885
2886 fn sent(&self) -> Vec<Value> {
2887 self.sent.lock().unwrap().clone()
2888 }
2889 }
2890
2891 #[async_trait]
2892 impl McpTransportFactory for RecordingTransportFactory {
2893 async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
2894 let responses = self.responses.lock().unwrap().pop_front().ok_or_else(|| {
2895 McpError::Transport("no recording transport responses left".into())
2896 })?;
2897 Ok(Box::new(RecordingTransport {
2898 recv: responses.into(),
2899 sent: self.sent.clone(),
2900 }))
2901 }
2902 }
2903
2904 #[tokio::test]
2905 async fn connection_includes_resolved_auth_in_future_requests() {
2906 let factory = RecordingTransportFactory::new(vec![vec![
2907 json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "recording", "version": "1.0.0" } } }),
2908 json!({ "jsonrpc": "2.0", "id": 1, "result": { "content": [{ "type": "text", "text": "ok" }] } }),
2909 ]]);
2910 let config = McpServerConfig::new(
2911 "recording",
2912 McpTransportBinding::Custom(Arc::new(factory.clone())),
2913 );
2914 let connection = McpConnection::connect(&config).await.unwrap();
2915 let mut auth = MetadataMap::new();
2916 auth.insert("token".into(), json!("secret-token"));
2917 let request = AuthRequest {
2918 task_id: None,
2919 id: "auth-recording-tool".into(),
2920 provider: "mcp.recording".into(),
2921 operation: AuthOperation::McpToolCall {
2922 server_id: "recording".into(),
2923 tool_name: "echo".into(),
2924 input: json!({}),
2925 metadata: MetadataMap::new(),
2926 },
2927 challenge: MetadataMap::new(),
2928 };
2929 connection
2930 .resolve_auth(agentkit_tools_core::AuthResolution::Provided {
2931 request,
2932 credentials: auth,
2933 })
2934 .await
2935 .unwrap();
2936
2937 let _ = connection.call_tool("echo", json!({})).await.unwrap();
2938 let sent = factory.sent();
2939 assert!(
2940 sent.iter().any(|value| {
2941 value
2942 .get("params")
2943 .and_then(|params| params.get("auth"))
2944 .and_then(|auth| auth.get("token"))
2945 == Some(&json!("secret-token"))
2946 }),
2947 "expected an MCP request to include the resolved auth payload, saw {:?}",
2948 sent
2949 );
2950 }
2951
2952 #[tokio::test]
2953 async fn manager_reuses_stored_auth_on_connect() {
2954 let factory = RecordingTransportFactory::new(vec![vec![
2955 json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "recording", "version": "1.0.0" } } }),
2956 json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [] } }),
2957 json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
2958 json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
2959 ]]);
2960 let server_id = McpServerId::new("recording");
2961 let mut manager = McpServerManager::new().with_server(McpServerConfig::new(
2962 server_id.to_string(),
2963 McpTransportBinding::Custom(Arc::new(factory.clone())),
2964 ));
2965 let mut auth = MetadataMap::new();
2966 auth.insert("token".into(), json!("seed-token"));
2967 let request = AuthRequest {
2968 task_id: None,
2969 id: "auth-recording-connect".into(),
2970 provider: "mcp.recording".into(),
2971 operation: AuthOperation::McpConnect {
2972 server_id: server_id.to_string(),
2973 metadata: MetadataMap::new(),
2974 },
2975 challenge: MetadataMap::new(),
2976 };
2977 manager
2978 .resolve_auth(agentkit_tools_core::AuthResolution::Provided {
2979 request,
2980 credentials: auth,
2981 })
2982 .await
2983 .unwrap();
2984
2985 manager.connect_server(&server_id).await.unwrap();
2986 let sent = factory.sent();
2987 assert!(
2988 sent.iter().any(|value| {
2989 value.get("method").and_then(Value::as_str) == Some("initialize")
2990 && value
2991 .get("params")
2992 .and_then(|params| params.get("auth"))
2993 .and_then(|auth| auth.get("token"))
2994 == Some(&json!("seed-token"))
2995 }),
2996 "expected initialize to include stored auth, saw {:?}",
2997 sent
2998 );
2999 }
3000
3001 #[tokio::test]
3002 async fn manager_resolves_auth_and_replays_resource_read() {
3003 let factory = RecordingTransportFactory::new(vec![vec![
3004 json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "recording", "version": "1.0.0" } } }),
3005 json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [] } }),
3006 json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
3007 json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
3008 json!({
3009 "jsonrpc": "2.0",
3010 "id": 4,
3011 "result": {
3012 "contents": [
3013 {
3014 "uri": "file:///tmp/secret.txt",
3015 "text": "secret from resource"
3016 }
3017 ]
3018 }
3019 }),
3020 ]]);
3021 let server_id = McpServerId::new("recording");
3022 let mut manager = McpServerManager::new().with_server(McpServerConfig::new(
3023 server_id.to_string(),
3024 McpTransportBinding::Custom(Arc::new(factory.clone())),
3025 ));
3026 let mut auth = MetadataMap::new();
3027 auth.insert("token".into(), json!("resource-token"));
3028 let request = AuthRequest {
3029 task_id: None,
3030 id: "auth-recording-resource".into(),
3031 provider: "mcp.recording".into(),
3032 operation: AuthOperation::McpResourceRead {
3033 server_id: server_id.to_string(),
3034 resource_id: "file:///tmp/secret.txt".into(),
3035 metadata: MetadataMap::new(),
3036 },
3037 challenge: MetadataMap::new(),
3038 };
3039
3040 let result = manager
3041 .resolve_auth_and_resume(agentkit_tools_core::AuthResolution::Provided {
3042 request,
3043 credentials: auth,
3044 })
3045 .await
3046 .unwrap();
3047
3048 match result {
3049 McpOperationResult::Resource(contents) => {
3050 assert_eq!(
3051 contents.data,
3052 DataRef::InlineText("secret from resource".into())
3053 );
3054 }
3055 other => panic!("unexpected replay result: {other:?}"),
3056 }
3057
3058 let sent = factory.sent();
3059 assert!(
3060 sent.iter().any(|value| {
3061 value.get("method").and_then(Value::as_str) == Some("resources/read")
3062 && value
3063 .get("params")
3064 .and_then(|params| params.get("auth"))
3065 .and_then(|auth| auth.get("token"))
3066 == Some(&json!("resource-token"))
3067 }),
3068 "expected resources/read to include resolved auth, saw {:?}",
3069 sent
3070 );
3071 }
3072
3073 #[tokio::test]
3074 async fn manager_resolves_auth_and_replays_connect() {
3075 let factory = RecordingTransportFactory::new(vec![vec![
3076 json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "recording", "version": "1.0.0" } } }),
3077 json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [] } }),
3078 json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
3079 json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
3080 ]]);
3081 let server_id = McpServerId::new("recording");
3082 let mut manager = McpServerManager::new().with_server(McpServerConfig::new(
3083 server_id.to_string(),
3084 McpTransportBinding::Custom(Arc::new(factory.clone())),
3085 ));
3086 let mut auth = MetadataMap::new();
3087 auth.insert("token".into(), json!("connect-token"));
3088 let request = AuthRequest {
3089 task_id: None,
3090 id: "auth-recording-connect-replay".into(),
3091 provider: "mcp.recording".into(),
3092 operation: AuthOperation::McpConnect {
3093 server_id: server_id.to_string(),
3094 metadata: MetadataMap::new(),
3095 },
3096 challenge: MetadataMap::new(),
3097 };
3098
3099 let result = manager
3100 .resolve_auth_and_resume(agentkit_tools_core::AuthResolution::Provided {
3101 request,
3102 credentials: auth,
3103 })
3104 .await
3105 .unwrap();
3106
3107 match result {
3108 McpOperationResult::Connected(snapshot) => {
3109 assert_eq!(snapshot.server_id, server_id);
3110 }
3111 other => panic!("unexpected replay result: {other:?}"),
3112 }
3113 }
3114
3115 #[tokio::test]
3116 async fn sse_transport_posts_messages_and_receives_frames() {
3117 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3118 let address = listener.local_addr().unwrap();
3119 let requests = StdArc::new(StdMutex::new(Vec::new()));
3120 let captured = requests.clone();
3121
3122 let server = tokio::spawn(async move {
3123 for _ in 0..2 {
3124 let (mut socket, _) = listener.accept().await.unwrap();
3125 let mut buffer = vec![0_u8; 4096];
3126 let read = socket.read(&mut buffer).await.unwrap();
3127 let request = String::from_utf8_lossy(&buffer[..read]).to_string();
3128
3129 if request.starts_with("GET /sse ") {
3130 let body = concat!(
3131 "event: endpoint\n",
3132 "data: /messages\n\n",
3133 "event: message\n",
3134 "data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}\n\n"
3135 );
3136 let response = format!(
3137 "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3138 body.len(),
3139 body
3140 );
3141 socket.write_all(response.as_bytes()).await.unwrap();
3142 } else {
3143 captured.lock().unwrap().push(request);
3144 socket
3145 .write_all(
3146 b"HTTP/1.1 202 Accepted\r\ncontent-length: 0\r\nconnection: close\r\n\r\n",
3147 )
3148 .await
3149 .unwrap();
3150 }
3151 }
3152 });
3153
3154 let factory =
3155 SseTransportFactory::new(SseTransportConfig::new(format!("http://{address}/sse")));
3156 let mut transport = factory.connect().await.unwrap();
3157 transport
3158 .send(McpFrame {
3159 value: json!({
3160 "jsonrpc": "2.0",
3161 "id": 1,
3162 "method": "tools/list",
3163 "params": {}
3164 }),
3165 })
3166 .await
3167 .unwrap();
3168 let frame = transport.recv().await.unwrap().unwrap();
3169 transport.close().await.unwrap();
3170 server.await.unwrap();
3171
3172 assert_eq!(frame.value["result"]["tools"], json!([]));
3173 let requests = requests.lock().unwrap();
3174 assert_eq!(requests.len(), 1);
3175 assert!(requests[0].starts_with("POST /messages "));
3176 assert!(requests[0].contains("\"method\":\"tools/list\""));
3177 }
3178
3179 #[tokio::test]
3180 async fn streamable_http_connection_tracks_session_and_protocol_headers() {
3181 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3182 let address = listener.local_addr().unwrap();
3183 let requests = StdArc::new(StdMutex::new(Vec::new()));
3184 let captured = requests.clone();
3185
3186 let server = tokio::spawn(async move {
3187 for _ in 0..4 {
3188 let (mut socket, _) = listener.accept().await.unwrap();
3189 let mut buffer = vec![0_u8; 8192];
3190 let read = socket.read(&mut buffer).await.unwrap();
3191 let request = String::from_utf8_lossy(&buffer[..read]).to_string();
3192 captured.lock().unwrap().push(request.clone());
3193
3194 let response = if request.contains("\"method\":\"initialize\"") {
3195 let body = "{\"jsonrpc\":\"2.0\",\"id\":0,\"result\":{\"protocolVersion\":\"2025-11-25\",\"capabilities\":{},\"serverInfo\":{\"name\":\"remote\",\"version\":\"1.0.0\"}}}";
3196 format!(
3197 "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\nMCP-Session-Id: session-123\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3198 body.len(),
3199 body
3200 )
3201 } else if request.contains("\"method\":\"notifications/initialized\"") {
3202 "HTTP/1.1 202 Accepted\r\ncontent-length: 0\r\nconnection: close\r\n\r\n"
3203 .to_string()
3204 } else if request.starts_with("DELETE /mcp ") {
3205 "HTTP/1.1 204 No Content\r\ncontent-length: 0\r\nconnection: close\r\n\r\n"
3206 .to_string()
3207 } else {
3208 let body = "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}";
3209 format!(
3210 "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3211 body.len(),
3212 body
3213 )
3214 };
3215
3216 socket.write_all(response.as_bytes()).await.unwrap();
3217 }
3218 });
3219
3220 let config = McpServerConfig::new(
3221 "remote",
3222 McpTransportBinding::StreamableHttp(StreamableHttpTransportConfig::new(format!(
3223 "http://{address}/mcp"
3224 ))),
3225 );
3226 let connection = McpConnection::connect(&config).await.unwrap();
3227 let _ = connection.list_tools().await.unwrap();
3228 connection.close().await.unwrap();
3229 server.await.unwrap();
3230
3231 let requests = requests.lock().unwrap();
3232 assert_eq!(requests.len(), 4);
3233 let normalized = requests
3234 .iter()
3235 .map(|request| request.to_ascii_lowercase())
3236 .collect::<Vec<_>>();
3237 assert!(requests[0].starts_with("POST /mcp "));
3238 assert!(!requests[0].contains("MCP-Session-Id:"));
3239 assert!(normalized[1].contains("mcp-session-id: session-123"));
3240 assert!(normalized[1].contains("mcp-protocol-version: 2025-11-25"));
3241 assert!(normalized[2].contains("mcp-session-id: session-123"));
3242 assert!(normalized[2].contains("mcp-protocol-version: 2025-11-25"));
3243 assert!(requests[3].starts_with("DELETE /mcp "));
3244 assert!(normalized[3].contains("mcp-session-id: session-123"));
3245 }
3246
3247 #[tokio::test]
3248 async fn streamable_http_transport_resumes_sse_streams_until_response_arrives() {
3249 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3250 let address = listener.local_addr().unwrap();
3251 let requests = StdArc::new(StdMutex::new(Vec::new()));
3252 let captured = requests.clone();
3253
3254 let server = tokio::spawn(async move {
3255 for _ in 0..2 {
3256 let (mut socket, _) = listener.accept().await.unwrap();
3257 let mut buffer = vec![0_u8; 8192];
3258 let read = socket.read(&mut buffer).await.unwrap();
3259 let request = String::from_utf8_lossy(&buffer[..read]).to_string();
3260 captured.lock().unwrap().push(request.clone());
3261
3262 let response = if request.starts_with("POST /mcp ") {
3263 let body = concat!(
3264 "id: evt-1\n",
3265 "event: message\n",
3266 "data: {\"jsonrpc\":\"2.0\",\"method\":\"notifications/message\",\"params\":{\"phase\":\"stream-start\"}}\n\n"
3267 );
3268 format!(
3269 "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3270 body.len(),
3271 body
3272 )
3273 } else {
3274 let body = concat!(
3275 "id: evt-2\n",
3276 "event: message\n",
3277 "data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}\n\n"
3278 );
3279 format!(
3280 "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3281 body.len(),
3282 body
3283 )
3284 };
3285
3286 socket.write_all(response.as_bytes()).await.unwrap();
3287 }
3288 });
3289
3290 let factory = StreamableHttpTransportFactory::new(StreamableHttpTransportConfig::new(
3291 format!("http://{address}/mcp"),
3292 ));
3293 let mut transport = factory.connect().await.unwrap();
3294 transport
3295 .send(McpFrame {
3296 value: json!({
3297 "jsonrpc": "2.0",
3298 "id": 1,
3299 "method": "tools/list",
3300 "params": {}
3301 }),
3302 })
3303 .await
3304 .unwrap();
3305
3306 let first = transport.recv().await.unwrap().unwrap();
3307 let second = transport.recv().await.unwrap().unwrap();
3308 transport.close().await.unwrap();
3309 server.await.unwrap();
3310
3311 assert_eq!(
3312 first.value["method"],
3313 Value::String("notifications/message".into())
3314 );
3315 assert_eq!(second.value["result"]["tools"], json!([]));
3316
3317 let requests = requests.lock().unwrap();
3318 assert_eq!(requests.len(), 2);
3319 assert!(requests[0].starts_with("POST /mcp "));
3320 assert!(requests[1].starts_with("GET /mcp "));
3321 assert!(
3322 requests[1].contains("last-event-id: evt-1")
3323 || requests[1].contains("Last-Event-ID: evt-1")
3324 );
3325 }
3326
3327 #[tokio::test]
3328 async fn server_manager_connects_refreshes_and_aggregates_tools() {
3329 let alpha = McpServerConfig::new(
3330 "alpha",
3331 McpTransportBinding::Custom(Arc::new(FakeTransportFactory::new(vec![vec![
3332 json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "alpha", "version": "1.0.0" } } }),
3333 json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [{ "name": "echo", "description": "Echo", "inputSchema": {"type": "object"} }] } }),
3334 json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
3335 json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
3336 json!({ "jsonrpc": "2.0", "id": 4, "result": { "tools": [{ "name": "echo_v2", "description": "Echo 2", "inputSchema": {"type": "object"} }] } }),
3337 json!({ "jsonrpc": "2.0", "id": 5, "result": { "resources": [] } }),
3338 json!({ "jsonrpc": "2.0", "id": 6, "result": { "prompts": [] } }),
3339 ]]))),
3340 );
3341 let beta = McpServerConfig::new(
3342 "beta",
3343 McpTransportBinding::Custom(Arc::new(FakeTransportFactory::new(vec![vec![
3344 json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "beta", "version": "1.0.0" } } }),
3345 json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [{ "name": "search", "description": "Search", "inputSchema": {"type": "object"} }] } }),
3346 json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
3347 json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
3348 ]]))),
3349 );
3350
3351 let mut manager = McpServerManager::new().with_server(alpha).with_server(beta);
3352
3353 let handles = manager.connect_all().await.unwrap();
3354 assert_eq!(handles.len(), 2);
3355 assert_eq!(
3356 manager
3357 .tool_registry()
3358 .specs()
3359 .into_iter()
3360 .map(|spec| spec.name.0)
3361 .collect::<Vec<_>>(),
3362 vec!["mcp_alpha_echo".to_string(), "mcp_beta_search".to_string()]
3363 );
3364
3365 let refreshed = manager
3366 .refresh_server(&McpServerId::new("alpha"))
3367 .await
3368 .unwrap();
3369 assert_eq!(refreshed.tools[0].name, "echo_v2");
3370 assert_eq!(
3371 manager
3372 .connected_server(&McpServerId::new("alpha"))
3373 .unwrap()
3374 .snapshot()
3375 .tools[0]
3376 .name,
3377 "echo_v2"
3378 );
3379
3380 let capabilities = manager.capability_provider();
3381 assert_eq!(capabilities.invocables().len(), 2);
3382
3383 manager
3384 .disconnect_server(&McpServerId::new("alpha"))
3385 .await
3386 .unwrap();
3387 assert!(
3388 manager
3389 .connected_server(&McpServerId::new("alpha"))
3390 .is_none()
3391 );
3392 }
3393}