1use std::collections::{BTreeMap, VecDeque};
2use std::fmt;
3use std::process::Stdio;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::Duration;
7
8use agentkit_capabilities::{
9 CapabilityContext, CapabilityError, CapabilityName, CapabilityProvider, Invocable,
10 InvocableOutput, InvocableRequest, InvocableResult, InvocableSpec, PromptContents,
11 PromptDescriptor, PromptId, PromptProvider, ResourceContents, ResourceDescriptor, ResourceId,
12 ResourceProvider,
13};
14use agentkit_core::{
15 DataRef, Item, ItemKind, MetadataMap, Part, TextPart, ToolOutput, ToolResultPart,
16};
17use agentkit_tools_core::{
18 AuthOperation, AuthRequest, AuthResolution, Tool, ToolAnnotations, ToolContext, ToolError,
19 ToolName, ToolRegistry, ToolRequest, ToolResult, ToolSpec,
20};
21use async_trait::async_trait;
22use futures_util::TryStreamExt;
23use reqwest::{Client, StatusCode, Url};
24use serde::{Deserialize, Serialize};
25use serde_json::{Value, json};
26use thiserror::Error;
27use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWriteExt, BufReader};
28use tokio::process::{Child, ChildStdin, ChildStdout, Command};
29use tokio::sync::{Mutex, mpsc, oneshot};
30use tokio::task::JoinHandle;
31use tokio::time::sleep;
32use tokio_util::io::StreamReader;
33
/// Protocol version this client proposes in its `initialize` request.
const MCP_LATEST_PROTOCOL_VERSION: &str = "2025-11-25";
/// Protocol versions accepted in the server's `initialize` result; any other
/// negotiated `protocolVersion` aborts the handshake with a protocol error.
const MCP_SUPPORTED_PROTOCOL_VERSIONS: &[&str] =
    &["2025-11-25", "2025-06-18", "2025-03-26", "2024-11-05"];
37
38#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
53pub struct McpServerId(pub String);
54
55impl McpServerId {
56 pub fn new(value: impl Into<String>) -> Self {
58 Self(value.into())
59 }
60}
61
62impl fmt::Display for McpServerId {
63 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64 self.0.fmt(f)
65 }
66}
67
/// Launch settings for an MCP server reached over a child process's stdin/stdout.
///
/// Start from [`StdioTransportConfig::new`] and chain the `with_*` setters.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StdioTransportConfig {
    /// Program to execute.
    pub command: String,
    /// Arguments passed to `command`, in order.
    pub args: Vec<String>,
    /// Extra environment variables set on the child, in insertion order.
    pub env: Vec<(String, String)>,
    /// Working directory for the child; `None` leaves it unchanged.
    pub cwd: Option<std::path::PathBuf>,
}

impl StdioTransportConfig {
    /// Creates a config for `command` with no args, no extra env and no cwd.
    pub fn new(command: impl Into<String>) -> Self {
        Self {
            cwd: None,
            env: Vec::new(),
            args: Vec::new(),
            command: command.into(),
        }
    }

    /// Appends one command-line argument.
    pub fn with_arg(mut self, arg: impl Into<String>) -> Self {
        let arg: String = arg.into();
        self.args.push(arg);
        self
    }

    /// Appends one `key=value` environment variable for the child.
    pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let entry: (String, String) = (key.into(), value.into());
        self.env.push(entry);
        self
    }

    /// Sets the child's working directory.
    pub fn with_cwd(mut self, cwd: impl Into<std::path::PathBuf>) -> Self {
        let dir: std::path::PathBuf = cwd.into();
        self.cwd = Some(dir);
        self
    }
}
126
/// Settings for the legacy HTTP+SSE MCP transport.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SseTransportConfig {
    /// URL of the server's event stream.
    pub url: String,
    /// Extra HTTP headers, in insertion order.
    pub headers: Vec<(String, String)>,
}

impl SseTransportConfig {
    /// Creates a config for `url` with no extra headers.
    pub fn new(url: impl Into<String>) -> Self {
        let url: String = url.into();
        Self {
            headers: Vec::new(),
            url,
        }
    }

    /// Appends one extra HTTP header.
    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let header: (String, String) = (key.into(), value.into());
        self.headers.push(header);
        self
    }
}
164
/// Settings for the Streamable HTTP MCP transport.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StreamableHttpTransportConfig {
    /// Single MCP endpoint URL that receives POST/GET/DELETE requests.
    pub url: String,
    /// Extra HTTP headers, in insertion order.
    pub headers: Vec<(String, String)>,
}

impl StreamableHttpTransportConfig {
    /// Creates a config for `url` with no extra headers.
    pub fn new(url: impl Into<String>) -> Self {
        let url: String = url.into();
        Self {
            headers: Vec::new(),
            url,
        }
    }

    /// Appends one extra HTTP header.
    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let header: (String, String) = (key.into(), value.into());
        self.headers.push(header);
        self
    }
}
202
/// Selects how an MCP server is reached.
///
/// The three built-in variants are turned into the matching transport factory
/// when a connection is opened; `Custom` supplies a caller-provided factory.
#[derive(Clone)]
pub enum McpTransportBinding {
    /// Spawn a local process and exchange JSON-RPC over its stdin/stdout.
    Stdio(StdioTransportConfig),
    /// POST JSON-RPC messages to a single Streamable HTTP endpoint.
    StreamableHttp(StreamableHttpTransportConfig),
    /// Legacy HTTP+SSE: read an event stream, POST to the endpoint it announces.
    Sse(SseTransportConfig),
    /// Use a caller-supplied transport factory.
    Custom(Arc<dyn McpTransportFactory>),
}
220
221#[derive(Clone)]
242pub struct McpServerConfig {
243 pub id: McpServerId,
245 pub transport: McpTransportBinding,
247 pub metadata: MetadataMap,
249}
250
251impl McpServerConfig {
252 pub fn new(id: impl Into<String>, transport: McpTransportBinding) -> Self {
259 Self {
260 id: McpServerId::new(id),
261 transport,
262 metadata: MetadataMap::new(),
263 }
264 }
265
266 pub fn stdio(id: impl Into<String>, command: impl Into<String>) -> Self {
268 Self::new(
269 id,
270 McpTransportBinding::Stdio(StdioTransportConfig::new(command)),
271 )
272 }
273
274 pub fn sse(id: impl Into<String>, url: impl Into<String>) -> Self {
276 Self::new(id, McpTransportBinding::Sse(SseTransportConfig::new(url)))
277 }
278
279 pub fn streamable_http(id: impl Into<String>, url: impl Into<String>) -> Self {
281 Self::new(
282 id,
283 McpTransportBinding::StreamableHttp(StreamableHttpTransportConfig::new(url)),
284 )
285 }
286
287 pub fn with_metadata(mut self, metadata: MetadataMap) -> Self {
289 self.metadata = metadata;
290 self
291 }
292}
293
/// One raw JSON-RPC message as carried by an [`McpTransport`].
///
/// The transport does not interpret `value`; request/response matching happens
/// in the connection layer.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct McpFrame {
    /// The JSON-RPC message object.
    pub value: Value,
}
303
/// Opens new [`McpTransport`] instances.
///
/// One `connect` call is made per connection attempt.
#[async_trait]
pub trait McpTransportFactory: Send + Sync {
    /// Establishes a transport; no MCP handshake is performed here.
    async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError>;
}
319
/// A bidirectional channel carrying MCP JSON-RPC frames.
#[async_trait]
pub trait McpTransport: Send + Sync {
    /// Sends one frame to the server.
    async fn send(&mut self, message: McpFrame) -> Result<(), McpError>;
    /// Returns the next frame from the server. `Ok(None)` means no further frame
    /// is available; callers in this module treat it as the transport closing.
    async fn recv(&mut self) -> Result<Option<McpFrame>, McpError>;
    /// Releases the transport's underlying resources (process, stream, or session).
    async fn close(&mut self) -> Result<(), McpError>;
}
337
338pub struct StdioTransportFactory {
343 config: StdioTransportConfig,
344}
345
346impl StdioTransportFactory {
347 pub fn new(config: StdioTransportConfig) -> Self {
349 Self { config }
350 }
351}
352
353#[async_trait]
354impl McpTransportFactory for StdioTransportFactory {
355 async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
356 let mut command = Command::new(&self.config.command);
357 command.args(&self.config.args);
358 command.stdin(Stdio::piped());
359 command.stdout(Stdio::piped());
360 command.stderr(Stdio::inherit());
361
362 if let Some(cwd) = &self.config.cwd {
363 command.current_dir(cwd);
364 }
365
366 for (key, value) in &self.config.env {
367 command.env(key, value);
368 }
369
370 let mut child = command.spawn().map_err(McpError::Io)?;
371 let stdin = child
372 .stdin
373 .take()
374 .ok_or_else(|| McpError::Transport("failed to capture MCP stdin".into()))?;
375 let stdout = child
376 .stdout
377 .take()
378 .ok_or_else(|| McpError::Transport("failed to capture MCP stdout".into()))?;
379
380 Ok(Box::new(StdioTransport {
381 child,
382 stdin,
383 stdout: BufReader::new(stdout),
384 }))
385 }
386}
387
388pub struct SseTransportFactory {
393 config: SseTransportConfig,
394}
395
396impl SseTransportFactory {
397 pub fn new(config: SseTransportConfig) -> Self {
399 Self { config }
400 }
401}
402
403pub struct StreamableHttpTransportFactory {
408 config: StreamableHttpTransportConfig,
409}
410
411impl StreamableHttpTransportFactory {
412 pub fn new(config: StreamableHttpTransportConfig) -> Self {
414 Self { config }
415 }
416}
417
418#[async_trait]
419impl McpTransportFactory for SseTransportFactory {
420 async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
421 let client = Client::builder()
422 .user_agent(concat!("agentkit-mcp/", env!("CARGO_PKG_VERSION")))
423 .build()
424 .map_err(McpError::Http)?;
425
426 let mut request = client
427 .get(&self.config.url)
428 .header("Accept", "text/event-stream")
429 .header("Cache-Control", "no-cache");
430
431 for (key, value) in &self.config.headers {
432 request = request.header(key, value);
433 }
434
435 let response = request.send().await.map_err(McpError::Http)?;
436 let status = response.status();
437 if !status.is_success() {
438 let body = response
439 .text()
440 .await
441 .unwrap_or_else(|_| "<unreadable response body>".into());
442 return Err(McpError::Transport(format!(
443 "SSE connection failed with status {status}: {body}"
444 )));
445 }
446
447 let response_url = response.url().clone();
448 let stream = response.bytes_stream().map_err(std::io::Error::other);
449 let reader = BufReader::new(StreamReader::new(stream));
450 let (frame_tx, frame_rx) = mpsc::unbounded_channel();
451 let (endpoint_tx, endpoint_rx) = oneshot::channel();
452 let read_task = tokio::spawn(read_sse_stream(reader, response_url, frame_tx, endpoint_tx));
453
454 let endpoint_url = endpoint_rx
455 .await
456 .map_err(|_| McpError::Transport("SSE stream closed before endpoint event".into()))??;
457
458 Ok(Box::new(SseTransport {
459 client,
460 endpoint_url,
461 headers: self.config.headers.clone(),
462 frame_rx,
463 read_task,
464 }))
465 }
466}
467
468#[async_trait]
469impl McpTransportFactory for StreamableHttpTransportFactory {
470 async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
471 let client = Client::builder()
472 .user_agent(concat!("agentkit-mcp/", env!("CARGO_PKG_VERSION")))
473 .build()
474 .map_err(McpError::Http)?;
475
476 let endpoint_url = Url::parse(&self.config.url)
477 .map_err(|error| McpError::Transport(format!("invalid MCP endpoint URL: {error}")))?;
478
479 Ok(Box::new(StreamableHttpTransport {
480 client,
481 endpoint_url,
482 headers: self.config.headers.clone(),
483 protocol_version: None,
484 session_id: None,
485 pending_frames: VecDeque::new(),
486 }))
487 }
488}
489
/// Transport over a spawned child process's stdin/stdout.
struct StdioTransport {
    /// The spawned server process; killed by `close`.
    child: Child,
    /// Child's stdin; outgoing frames are written here one JSON line each.
    stdin: ChildStdin,
    /// Buffered child's stdout; incoming frames are read line by line.
    stdout: BufReader<ChildStdout>,
}
495
/// Legacy HTTP+SSE transport: frames arrive on a background-read event stream,
/// outgoing frames are POSTed to the endpoint announced by that stream.
struct SseTransport {
    /// HTTP client used for the POSTs.
    client: Client,
    /// URL received in the stream's endpoint event; target of every POST.
    endpoint_url: Url,
    /// Extra configured headers added to every POST.
    headers: Vec<(String, String)>,
    /// Frames (or read errors) forwarded by the background reader task.
    frame_rx: mpsc::UnboundedReceiver<Result<McpFrame, McpError>>,
    /// Background task reading the event stream; aborted by `close`.
    read_task: JoinHandle<()>,
}
503
/// Streamable HTTP transport: each `send` is a POST whose reply (JSON or SSE)
/// is read eagerly and queued for `recv`.
struct StreamableHttpTransport {
    /// HTTP client shared by all requests.
    client: Client,
    /// The single MCP endpoint (POST / resume GET / DELETE).
    endpoint_url: Url,
    /// Extra configured headers added to every request.
    headers: Vec<(String, String)>,
    /// Protocol version taken from the `initialize` result, once known.
    protocol_version: Option<String>,
    /// `MCP-Session-Id` captured from the `initialize` reply, if the server sent one.
    session_id: Option<String>,
    /// Frames already received but not yet handed out by `recv`.
    pending_frames: VecDeque<McpFrame>,
}
512
513#[async_trait]
514impl McpTransport for StdioTransport {
515 async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
516 let mut encoded = serde_json::to_vec(&message.value).map_err(McpError::Serialize)?;
517 encoded.push(b'\n');
518 self.stdin.write_all(&encoded).await.map_err(McpError::Io)?;
519 self.stdin.flush().await.map_err(McpError::Io)?;
520 Ok(())
521 }
522
523 async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
524 let mut line = String::new();
525 let read = self
526 .stdout
527 .read_line(&mut line)
528 .await
529 .map_err(McpError::Io)?;
530 if read == 0 {
531 return Ok(None);
532 }
533
534 let value = serde_json::from_str(line.trim()).map_err(McpError::Serialize)?;
535 Ok(Some(McpFrame { value }))
536 }
537
538 async fn close(&mut self) -> Result<(), McpError> {
539 let _ = self.stdin.shutdown().await;
540 let _ = self.child.kill().await;
541 Ok(())
542 }
543}
544
545#[async_trait]
546impl McpTransport for SseTransport {
547 async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
548 let mut request = self
549 .client
550 .post(self.endpoint_url.clone())
551 .header("Content-Type", "application/json");
552
553 for (key, value) in &self.headers {
554 request = request.header(key, value);
555 }
556
557 let response = request
558 .json(&message.value)
559 .send()
560 .await
561 .map_err(McpError::Http)?;
562 let status = response.status();
563 if !status.is_success() {
564 let body = response
565 .text()
566 .await
567 .unwrap_or_else(|_| "<unreadable response body>".into());
568 return Err(McpError::Transport(format!(
569 "SSE POST failed with status {status}: {body}"
570 )));
571 }
572
573 Ok(())
574 }
575
576 async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
577 match self.frame_rx.recv().await {
578 Some(Ok(frame)) => Ok(Some(frame)),
579 Some(Err(error)) => Err(error),
580 None => Ok(None),
581 }
582 }
583
584 async fn close(&mut self) -> Result<(), McpError> {
585 self.read_task.abort();
586 Ok(())
587 }
588}
589
590#[async_trait]
591impl McpTransport for StreamableHttpTransport {
592 async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
593 let is_request = is_jsonrpc_request(&message.value);
594 let request_id = message.value.get("id").cloned();
595 let is_initialize =
596 message.value.get("method").and_then(Value::as_str) == Some("initialize");
597
598 let mut request = self
599 .client
600 .post(self.endpoint_url.clone())
601 .header("Content-Type", "application/json")
602 .header("Accept", "application/json, text/event-stream");
603
604 request = apply_streamable_http_headers(
605 request,
606 &self.headers,
607 self.protocol_version.as_deref(),
608 self.session_id.as_deref(),
609 );
610
611 let response = request
612 .json(&message.value)
613 .send()
614 .await
615 .map_err(McpError::Http)?;
616
617 if is_initialize {
618 self.capture_session_id(response.headers());
619 }
620
621 let status = response.status();
622 if !status.is_success() {
623 return Err(
624 streamable_http_status_error("Streamable HTTP POST", status, response).await,
625 );
626 }
627
628 if !is_request {
629 return Ok(());
630 }
631
632 let content_type = response
633 .headers()
634 .get(reqwest::header::CONTENT_TYPE)
635 .and_then(|value| value.to_str().ok())
636 .unwrap_or_default()
637 .to_string();
638
639 if content_type.starts_with("application/json") {
640 let value = response.json::<Value>().await.map_err(McpError::Http)?;
641 self.maybe_update_protocol_version(&message.value, &value)?;
642 self.pending_frames.push_back(McpFrame { value });
643 return Ok(());
644 }
645
646 if !content_type.starts_with("text/event-stream") {
647 let body = response
648 .text()
649 .await
650 .unwrap_or_else(|_| "<unreadable response body>".into());
651 return Err(McpError::Transport(format!(
652 "unexpected Streamable HTTP response content type {content_type:?}: {body}"
653 )));
654 }
655
656 let request_id = request_id.ok_or_else(|| {
657 McpError::Protocol("JSON-RPC request over Streamable HTTP is missing an id".into())
658 })?;
659 self.collect_streamable_http_response(response, &message.value, &request_id)
660 .await
661 }
662
663 async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
664 Ok(self.pending_frames.pop_front())
665 }
666
667 async fn close(&mut self) -> Result<(), McpError> {
668 let Some(session_id) = self.session_id.clone() else {
669 return Ok(());
670 };
671
672 let mut request = self.client.delete(self.endpoint_url.clone());
673 request = apply_streamable_http_headers(
674 request,
675 &self.headers,
676 self.protocol_version.as_deref(),
677 Some(session_id.as_str()),
678 );
679
680 let response = request.send().await.map_err(McpError::Http)?;
681 if response.status().is_success()
682 || response.status() == StatusCode::METHOD_NOT_ALLOWED
683 || response.status() == StatusCode::NOT_FOUND
684 {
685 self.session_id = None;
686 return Ok(());
687 }
688
689 Err(
690 streamable_http_status_error("Streamable HTTP DELETE", response.status(), response)
691 .await,
692 )
693 }
694}
695
696impl StreamableHttpTransport {
697 async fn collect_streamable_http_response(
698 &mut self,
699 response: reqwest::Response,
700 request_message: &Value,
701 request_id: &Value,
702 ) -> Result<(), McpError> {
703 let mut retry_delay = Duration::from_millis(0);
704 let mut last_event_id = None;
705 let mut saw_response = false;
706
707 saw_response |= self
708 .read_streamable_http_events(
709 response,
710 request_message,
711 request_id,
712 &mut last_event_id,
713 &mut retry_delay,
714 )
715 .await?;
716
717 while !saw_response && last_event_id.is_some() {
718 if !retry_delay.is_zero() {
719 sleep(retry_delay).await;
720 }
721
722 let response = self
723 .resume_streamable_http_stream(last_event_id.as_deref().unwrap())
724 .await?;
725 saw_response |= self
726 .read_streamable_http_events(
727 response,
728 request_message,
729 request_id,
730 &mut last_event_id,
731 &mut retry_delay,
732 )
733 .await?;
734 }
735
736 Ok(())
737 }
738
739 async fn read_streamable_http_events(
740 &mut self,
741 response: reqwest::Response,
742 request_message: &Value,
743 request_id: &Value,
744 last_event_id: &mut Option<String>,
745 retry_delay: &mut Duration,
746 ) -> Result<bool, McpError> {
747 let stream = response.bytes_stream().map_err(std::io::Error::other);
748 let mut reader = BufReader::new(StreamReader::new(stream));
749 let mut saw_response = false;
750
751 while let Some(event) = read_next_sse_event(&mut reader).await? {
752 if let Some(id) = event.id.clone() {
753 *last_event_id = Some(id);
754 }
755 if let Some(retry_ms) = event.retry_ms {
756 *retry_delay = Duration::from_millis(retry_ms);
757 }
758
759 let Some(frame) = streamable_http_event_to_frame(event)? else {
760 continue;
761 };
762
763 self.maybe_update_protocol_version(request_message, &frame.value)?;
764 if frame.value.get("id") == Some(request_id) {
765 saw_response = true;
766 }
767 self.pending_frames.push_back(frame);
768 }
769
770 Ok(saw_response)
771 }
772
773 async fn resume_streamable_http_stream(
774 &self,
775 last_event_id: &str,
776 ) -> Result<reqwest::Response, McpError> {
777 let mut request = self
778 .client
779 .get(self.endpoint_url.clone())
780 .header("Accept", "text/event-stream")
781 .header("Cache-Control", "no-cache")
782 .header("Last-Event-ID", last_event_id);
783
784 request = apply_streamable_http_headers(
785 request,
786 &self.headers,
787 self.protocol_version.as_deref(),
788 self.session_id.as_deref(),
789 );
790
791 let response = request.send().await.map_err(McpError::Http)?;
792 let status = response.status();
793 if !status.is_success() {
794 return Err(
795 streamable_http_status_error("Streamable HTTP GET", status, response).await,
796 );
797 }
798
799 let content_type = response
800 .headers()
801 .get(reqwest::header::CONTENT_TYPE)
802 .and_then(|value| value.to_str().ok())
803 .unwrap_or_default();
804 if !content_type.starts_with("text/event-stream") {
805 let content_type = content_type.to_string();
806 let body = response
807 .text()
808 .await
809 .unwrap_or_else(|_| "<unreadable response body>".into());
810 return Err(McpError::Transport(format!(
811 "Streamable HTTP GET expected text/event-stream, got {content_type:?}: {body}"
812 )));
813 }
814
815 Ok(response)
816 }
817
818 fn maybe_update_protocol_version(
819 &mut self,
820 request_message: &Value,
821 response_value: &Value,
822 ) -> Result<(), McpError> {
823 if request_message.get("method").and_then(Value::as_str) != Some("initialize") {
824 return Ok(());
825 }
826
827 let protocol_version = response_value
828 .get("result")
829 .and_then(|result| result.get("protocolVersion"))
830 .and_then(Value::as_str);
831
832 if let Some(protocol_version) = protocol_version {
833 self.protocol_version = Some(protocol_version.to_string());
834 }
835
836 Ok(())
837 }
838
839 fn capture_session_id(&mut self, headers: &reqwest::header::HeaderMap) {
840 self.session_id = headers
841 .get("MCP-Session-Id")
842 .and_then(|value| value.to_str().ok())
843 .map(|value| value.to_string());
844 }
845}
846
/// A tool advertised by an MCP server (an entry of the `tools/list` result).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct McpToolDescriptor {
    /// Tool name as known to the server; used for `tools/call`.
    pub name: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// JSON Schema describing the tool's arguments.
    pub input_schema: Value,
    /// Additional metadata carried alongside the descriptor.
    pub metadata: MetadataMap,
}
863
/// A resource advertised by an MCP server (an entry of the `resources/list` result).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct McpResourceDescriptor {
    /// Resource identifier; presumably the resource URI passed to
    /// `resources/read` (TODO confirm against `parse_resource_descriptor`).
    pub id: String,
    /// Display name.
    pub name: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Optional MIME type of the resource contents.
    pub mime_type: Option<String>,
    /// Additional metadata carried alongside the descriptor.
    pub metadata: MetadataMap,
}
881
/// A prompt advertised by an MCP server (an entry of the `prompts/list` result).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct McpPromptDescriptor {
    /// Prompt identifier; presumably the name sent to `prompts/get` (TODO confirm).
    pub id: String,
    /// Display name.
    pub name: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Schema describing the prompt's arguments.
    pub input_schema: Value,
    /// Additional metadata carried alongside the descriptor.
    pub metadata: MetadataMap,
}
899
/// Everything one server advertised at discovery time (see `McpConnection::discover`).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct McpDiscoverySnapshot {
    /// Server the snapshot was taken from.
    pub server_id: McpServerId,
    /// Tools from `tools/list`.
    pub tools: Vec<McpToolDescriptor>,
    /// Resources from `resources/list`.
    pub resources: Vec<McpResourceDescriptor>,
    /// Prompts from `prompts/list`.
    pub prompts: Vec<McpPromptDescriptor>,
    /// Extra metadata; `discover` leaves it empty.
    pub metadata: MetadataMap,
}
918
/// An initialized session with one MCP server.
///
/// Requests are serialized through the transport mutex: each request holds it
/// from send until its response has been received.
pub struct McpConnection {
    /// Id of the connected server.
    server_id: McpServerId,
    /// Underlying transport, already past the `initialize` handshake.
    transport: Mutex<Box<dyn McpTransport>>,
    /// Credentials merged into request params under `auth`, when set.
    auth: Mutex<Option<MetadataMap>>,
    /// Next JSON-RPC request id; starts at 1 because `initialize` uses id 0.
    next_id: AtomicU64,
}
954
/// Result of replaying an MCP operation (e.g. after an auth prompt was resolved).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum McpOperationResult {
    /// A (re)connection finished; presumably produced by the server manager,
    /// since `McpConnection::replay_auth_operation` rejects connect operations.
    Connected(McpDiscoverySnapshot),
    /// Raw `result` of a `tools/call`.
    Tool(Value),
    /// Contents returned by `resources/read`.
    Resource(ResourceContents),
    /// Messages returned by `prompts/get`.
    Prompt(PromptContents),
}
970
971impl McpConnection {
972 pub async fn connect(config: &McpServerConfig) -> Result<Self, McpError> {
980 Self::connect_with_auth(config, None).await
981 }
982
983 async fn connect_with_auth(
984 config: &McpServerConfig,
985 auth: Option<&MetadataMap>,
986 ) -> Result<Self, McpError> {
987 let factory: Arc<dyn McpTransportFactory> = match &config.transport {
988 McpTransportBinding::Stdio(binding) => {
989 Arc::new(StdioTransportFactory::new(binding.clone()))
990 }
991 McpTransportBinding::StreamableHttp(binding) => {
992 Arc::new(StreamableHttpTransportFactory::new(binding.clone()))
993 }
994 McpTransportBinding::Sse(binding) => {
995 Arc::new(SseTransportFactory::new(binding.clone()))
996 }
997 McpTransportBinding::Custom(factory) => factory.clone(),
998 };
999
1000 let mut transport = factory.connect().await?;
1001 let mut params = serde_json::Map::new();
1002 params.insert(
1003 "protocolVersion".into(),
1004 Value::String(MCP_LATEST_PROTOCOL_VERSION.into()),
1005 );
1006 params.insert("capabilities".into(), json!({}));
1007 params.insert(
1008 "clientInfo".into(),
1009 json!({
1010 "name": "agentkit-mcp",
1011 "version": env!("CARGO_PKG_VERSION")
1012 }),
1013 );
1014 if let Some(auth) = auth {
1015 params.insert("auth".into(), metadata_to_value(auth));
1016 }
1017 let init_params = Value::Object(params.clone());
1018 transport
1019 .send(McpFrame {
1020 value: json!({
1021 "jsonrpc": "2.0",
1022 "id": 0,
1023 "method": "initialize",
1024 "params": init_params.clone()
1025 }),
1026 })
1027 .await?;
1028 let init_response = transport.recv().await?.ok_or_else(|| {
1029 McpError::Transport("transport closed during MCP initialization".into())
1030 })?;
1031 if let Some(error) = init_response.value.get("error") {
1032 if let Some(auth_request) =
1033 parse_auth_request(&config.id, "initialize", &init_params, error)
1034 {
1035 return Err(McpError::AuthRequired(Box::new(auth_request)));
1036 }
1037 return Err(McpError::Invocation(error.to_string()));
1038 }
1039 let negotiated_protocol_version = init_response
1040 .value
1041 .get("result")
1042 .and_then(|result| result.get("protocolVersion"))
1043 .and_then(Value::as_str)
1044 .ok_or_else(|| {
1045 McpError::Protocol("initialize response missing result.protocolVersion".into())
1046 })?;
1047 if !MCP_SUPPORTED_PROTOCOL_VERSIONS.contains(&negotiated_protocol_version) {
1048 return Err(McpError::Protocol(format!(
1049 "unsupported MCP protocol version negotiated during initialize: {negotiated_protocol_version}"
1050 )));
1051 }
1052 transport
1053 .send(McpFrame {
1054 value: json!({
1055 "jsonrpc": "2.0",
1056 "method": "notifications/initialized",
1057 "params": {}
1058 }),
1059 })
1060 .await?;
1061
1062 Ok(Self {
1063 server_id: config.id.clone(),
1064 transport: Mutex::new(transport),
1065 auth: Mutex::new(auth.cloned()),
1066 next_id: AtomicU64::new(1),
1067 })
1068 }
1069
1070 pub fn server_id(&self) -> &McpServerId {
1072 &self.server_id
1073 }
1074
1075 pub async fn close(&self) -> Result<(), McpError> {
1081 let mut transport = self.transport.lock().await;
1082 transport.close().await
1083 }
1084
1085 pub async fn resolve_auth(&self, resolution: AuthResolution) -> Result<(), McpError> {
1095 let mut auth = self.auth.lock().await;
1096 match resolution {
1097 AuthResolution::Provided { credentials, .. } => {
1098 *auth = Some(credentials);
1099 }
1100 AuthResolution::Cancelled { .. } => {
1101 *auth = None;
1102 }
1103 }
1104 Ok(())
1105 }
1106
1107 pub async fn discover(&self) -> Result<McpDiscoverySnapshot, McpError> {
1115 Ok(McpDiscoverySnapshot {
1116 server_id: self.server_id.clone(),
1117 tools: self.list_tools().await?,
1118 resources: self.list_resources().await?,
1119 prompts: self.list_prompts().await?,
1120 metadata: MetadataMap::new(),
1121 })
1122 }
1123
1124 pub async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>, McpError> {
1130 let result = self.request("tools/list", json!({})).await?;
1131 result
1132 .get("tools")
1133 .and_then(Value::as_array)
1134 .cloned()
1135 .unwrap_or_default()
1136 .into_iter()
1137 .map(parse_tool_descriptor)
1138 .collect()
1139 }
1140
1141 pub async fn list_resources(&self) -> Result<Vec<McpResourceDescriptor>, McpError> {
1147 let result = self.request("resources/list", json!({})).await?;
1148 result
1149 .get("resources")
1150 .and_then(Value::as_array)
1151 .cloned()
1152 .unwrap_or_default()
1153 .into_iter()
1154 .map(parse_resource_descriptor)
1155 .collect()
1156 }
1157
1158 pub async fn list_prompts(&self) -> Result<Vec<McpPromptDescriptor>, McpError> {
1164 let result = self.request("prompts/list", json!({})).await?;
1165 result
1166 .get("prompts")
1167 .and_then(Value::as_array)
1168 .cloned()
1169 .unwrap_or_default()
1170 .into_iter()
1171 .map(parse_prompt_descriptor)
1172 .collect()
1173 }
1174
1175 pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<Value, McpError> {
1187 self.request(
1188 "tools/call",
1189 json!({
1190 "name": name,
1191 "arguments": arguments,
1192 }),
1193 )
1194 .await
1195 }
1196
1197 pub async fn read_resource(&self, uri: &str) -> Result<ResourceContents, McpError> {
1207 let result = self
1208 .request(
1209 "resources/read",
1210 json!({
1211 "uri": uri,
1212 }),
1213 )
1214 .await?;
1215 let content = result
1216 .get("contents")
1217 .and_then(Value::as_array)
1218 .and_then(|values| values.first())
1219 .cloned()
1220 .ok_or_else(|| McpError::Protocol("resources/read returned no contents".into()))?;
1221
1222 let data = if let Some(text) = content.get("text").and_then(Value::as_str) {
1223 DataRef::InlineText(text.into())
1224 } else if let Some(found_uri) = content.get("uri").and_then(Value::as_str) {
1225 DataRef::Uri(found_uri.into())
1226 } else {
1227 return Err(McpError::Protocol(
1228 "unsupported resource content shape".into(),
1229 ));
1230 };
1231
1232 Ok(ResourceContents {
1233 data,
1234 metadata: MetadataMap::new(),
1235 })
1236 }
1237
1238 pub async fn get_prompt(
1249 &self,
1250 name: &str,
1251 arguments: Value,
1252 ) -> Result<PromptContents, McpError> {
1253 let result = self
1254 .request(
1255 "prompts/get",
1256 json!({
1257 "name": name,
1258 "arguments": arguments,
1259 }),
1260 )
1261 .await?;
1262 let items = result
1263 .get("messages")
1264 .and_then(Value::as_array)
1265 .cloned()
1266 .unwrap_or_default()
1267 .into_iter()
1268 .map(parse_prompt_message)
1269 .collect::<Result<Vec<_>, _>>()?;
1270
1271 Ok(PromptContents {
1272 items,
1273 metadata: MetadataMap::new(),
1274 })
1275 }
1276
1277 async fn request(&self, method: &str, params: Value) -> Result<Value, McpError> {
1278 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
1279 let params = self.enrich_params(params.clone()).await;
1280 let mut transport = self.transport.lock().await;
1281 transport
1282 .send(McpFrame {
1283 value: json!({
1284 "jsonrpc": "2.0",
1285 "id": id,
1286 "method": method,
1287 "params": params,
1288 }),
1289 })
1290 .await?;
1291
1292 loop {
1293 let Some(frame) = transport.recv().await? else {
1294 return Err(McpError::Transport(
1295 "transport closed while waiting for MCP response".into(),
1296 ));
1297 };
1298
1299 if frame.value.get("id").and_then(Value::as_u64) != Some(id) {
1300 continue;
1301 }
1302
1303 if let Some(error) = frame.value.get("error") {
1304 if let Some(auth_request) =
1305 parse_auth_request(&self.server_id, method, ¶ms, error)
1306 {
1307 return Err(McpError::AuthRequired(Box::new(auth_request)));
1308 }
1309 return Err(McpError::Invocation(error.to_string()));
1310 }
1311
1312 return frame
1313 .value
1314 .get("result")
1315 .cloned()
1316 .ok_or_else(|| McpError::Protocol("MCP response missing result".into()));
1317 }
1318 }
1319
1320 async fn enrich_params(&self, params: Value) -> Value {
1321 let auth = self.auth.lock().await;
1322 let Some(auth) = auth.as_ref() else {
1323 return params;
1324 };
1325
1326 match params {
1327 Value::Object(mut object) => {
1328 object
1329 .entry("auth")
1330 .or_insert_with(|| metadata_to_value(auth));
1331 Value::Object(object)
1332 }
1333 other => other,
1334 }
1335 }
1336
1337 pub async fn replay_auth_operation(
1347 &self,
1348 operation: &AuthOperation,
1349 ) -> Result<McpOperationResult, McpError> {
1350 match operation {
1351 AuthOperation::McpToolCall {
1352 server_id,
1353 tool_name,
1354 input,
1355 ..
1356 } => {
1357 self.ensure_server_match(server_id)?;
1358 self.call_tool(tool_name, input.clone())
1359 .await
1360 .map(McpOperationResult::Tool)
1361 }
1362 AuthOperation::McpResourceRead {
1363 server_id,
1364 resource_id,
1365 ..
1366 } => {
1367 self.ensure_server_match(server_id)?;
1368 self.read_resource(resource_id)
1369 .await
1370 .map(McpOperationResult::Resource)
1371 }
1372 AuthOperation::McpPromptGet {
1373 server_id,
1374 prompt_id,
1375 args,
1376 ..
1377 } => {
1378 self.ensure_server_match(server_id)?;
1379 self.get_prompt(prompt_id, args.clone())
1380 .await
1381 .map(McpOperationResult::Prompt)
1382 }
1383 AuthOperation::ToolCall {
1384 tool_name,
1385 input,
1386 metadata,
1387 ..
1388 } => {
1389 if let Some(server_id) = metadata.get("server_id").and_then(Value::as_str) {
1390 self.ensure_server_match(server_id)?;
1391 }
1392 let tool_name = normalize_mcp_tool_name(self.server_id(), tool_name);
1393 self.call_tool(&tool_name, input.clone())
1394 .await
1395 .map(McpOperationResult::Tool)
1396 }
1397 AuthOperation::McpConnect { .. } => Err(McpError::AuthResolution(
1398 "connect operations must be replayed through the server manager".into(),
1399 )),
1400 AuthOperation::Custom { kind, .. } => Err(McpError::AuthResolution(format!(
1401 "unsupported auth operation for replay: {kind}"
1402 ))),
1403 }
1404 }
1405
1406 fn ensure_server_match(&self, server_id: &str) -> Result<(), McpError> {
1407 if self.server_id.0 == server_id {
1408 Ok(())
1409 } else {
1410 Err(McpError::AuthResolution(format!(
1411 "auth operation targets server {server_id}, but connection is for {}",
1412 self.server_id
1413 )))
1414 }
1415 }
1416}
1417
1418pub struct McpInvocable {
1423 connection: Arc<McpConnection>,
1424 descriptor: McpToolDescriptor,
1425 spec: InvocableSpec,
1426}
1427
1428impl McpInvocable {
1429 pub fn new(connection: Arc<McpConnection>, descriptor: McpToolDescriptor) -> Self {
1436 let spec = InvocableSpec {
1437 name: CapabilityName::new(format!(
1438 "mcp.{}.{}",
1439 connection.server_id(),
1440 descriptor.name
1441 )),
1442 description: descriptor
1443 .description
1444 .clone()
1445 .unwrap_or_else(|| descriptor.name.clone()),
1446 input_schema: descriptor.input_schema.clone(),
1447 metadata: descriptor.metadata.clone(),
1448 };
1449
1450 Self {
1451 connection,
1452 descriptor,
1453 spec,
1454 }
1455 }
1456}
1457
1458#[async_trait]
1459impl Invocable for McpInvocable {
1460 fn spec(&self) -> &InvocableSpec {
1461 &self.spec
1462 }
1463
1464 async fn invoke(
1465 &self,
1466 request: InvocableRequest,
1467 _ctx: &mut CapabilityContext<'_>,
1468 ) -> Result<InvocableResult, CapabilityError> {
1469 let result = self
1470 .connection
1471 .call_tool(&self.descriptor.name, request.input)
1472 .await
1473 .map_err(|error| match error {
1474 McpError::AuthRequired(request) => {
1475 CapabilityError::Unavailable(format!("auth required: {:?}", request))
1476 }
1477 other => CapabilityError::ExecutionFailed(other.to_string()),
1478 })?;
1479
1480 Ok(InvocableResult {
1481 output: value_to_invocable_output(result),
1482 metadata: MetadataMap::new(),
1483 })
1484 }
1485}
1486
1487pub struct McpResourceHandle {
1492 connection: Arc<McpConnection>,
1493 descriptor: ResourceDescriptor,
1494}
1495
1496#[async_trait]
1497impl ResourceProvider for McpResourceHandle {
1498 async fn list_resources(&self) -> Result<Vec<ResourceDescriptor>, CapabilityError> {
1499 Ok(vec![self.descriptor.clone()])
1500 }
1501
1502 async fn read_resource(
1503 &self,
1504 id: &ResourceId,
1505 _ctx: &mut CapabilityContext<'_>,
1506 ) -> Result<ResourceContents, CapabilityError> {
1507 self.connection
1508 .read_resource(&id.0)
1509 .await
1510 .map_err(|error| match error {
1511 McpError::AuthRequired(request) => {
1512 CapabilityError::Unavailable(format!("auth required: {:?}", request))
1513 }
1514 other => CapabilityError::ExecutionFailed(other.to_string()),
1515 })
1516 }
1517}
1518
/// [`PromptProvider`] that exposes one MCP prompt discovered on a server.
pub struct McpPromptHandle {
    /// Shared connection used to fetch prompt contents.
    connection: Arc<McpConnection>,
    /// The single descriptor reported by `list_prompts`.
    descriptor: PromptDescriptor,
}
1527
1528#[async_trait]
1529impl PromptProvider for McpPromptHandle {
1530 async fn list_prompts(&self) -> Result<Vec<PromptDescriptor>, CapabilityError> {
1531 Ok(vec![self.descriptor.clone()])
1532 }
1533
1534 async fn get_prompt(
1535 &self,
1536 id: &PromptId,
1537 args: Value,
1538 _ctx: &mut CapabilityContext<'_>,
1539 ) -> Result<PromptContents, CapabilityError> {
1540 self.connection
1541 .get_prompt(&id.0, args)
1542 .await
1543 .map_err(|error| match error {
1544 McpError::AuthRequired(request) => {
1545 CapabilityError::Unavailable(format!("auth required: {:?}", request))
1546 }
1547 other => CapabilityError::ExecutionFailed(other.to_string()),
1548 })
1549 }
1550}
1551
/// [`CapabilityProvider`] over the tools, resources, and prompts of one or more MCP
/// servers.
pub struct McpCapabilityProvider {
    /// One invocable per discovered tool.
    invocables: Vec<Arc<dyn Invocable>>,
    /// One provider per discovered resource.
    resources: Vec<Arc<dyn ResourceProvider>>,
    /// One provider per discovered prompt.
    prompts: Vec<Arc<dyn PromptProvider>>,
}
1583
1584impl McpCapabilityProvider {
1585 pub fn from_snapshot(connection: Arc<McpConnection>, snapshot: &McpDiscoverySnapshot) -> Self {
1591 let invocables = snapshot
1592 .tools
1593 .iter()
1594 .cloned()
1595 .map(|descriptor| {
1596 Arc::new(McpInvocable::new(connection.clone(), descriptor)) as Arc<dyn Invocable>
1597 })
1598 .collect();
1599
1600 let resources = snapshot
1601 .resources
1602 .iter()
1603 .cloned()
1604 .map(|descriptor| {
1605 Arc::new(McpResourceHandle {
1606 connection: connection.clone(),
1607 descriptor: ResourceDescriptor {
1608 id: ResourceId::new(descriptor.id),
1609 name: descriptor.name,
1610 description: descriptor.description,
1611 mime_type: descriptor.mime_type,
1612 metadata: descriptor.metadata,
1613 },
1614 }) as Arc<dyn ResourceProvider>
1615 })
1616 .collect();
1617
1618 let prompts = snapshot
1619 .prompts
1620 .iter()
1621 .cloned()
1622 .map(|descriptor| {
1623 Arc::new(McpPromptHandle {
1624 connection: connection.clone(),
1625 descriptor: PromptDescriptor {
1626 id: PromptId::new(descriptor.id),
1627 name: descriptor.name,
1628 description: descriptor.description,
1629 input_schema: descriptor.input_schema,
1630 metadata: descriptor.metadata,
1631 },
1632 }) as Arc<dyn PromptProvider>
1633 })
1634 .collect();
1635
1636 Self {
1637 invocables,
1638 resources,
1639 prompts,
1640 }
1641 }
1642
1643 pub fn merge<I>(providers: I) -> Self
1648 where
1649 I: IntoIterator<Item = Self>,
1650 {
1651 let mut invocables = Vec::new();
1652 let mut resources = Vec::new();
1653 let mut prompts = Vec::new();
1654
1655 for provider in providers {
1656 invocables.extend(provider.invocables);
1657 resources.extend(provider.resources);
1658 prompts.extend(provider.prompts);
1659 }
1660
1661 Self {
1662 invocables,
1663 resources,
1664 prompts,
1665 }
1666 }
1667
1668 pub async fn connect(
1677 config: &McpServerConfig,
1678 ) -> Result<(Arc<McpConnection>, Self, McpDiscoverySnapshot), McpError> {
1679 let connection = Arc::new(McpConnection::connect(config).await?);
1680 let snapshot = connection.discover().await?;
1681 let provider = Self::from_snapshot(connection.clone(), &snapshot);
1682
1683 Ok((connection, provider, snapshot))
1684 }
1685}
1686
1687impl CapabilityProvider for McpCapabilityProvider {
1688 fn invocables(&self) -> Vec<Arc<dyn Invocable>> {
1689 self.invocables.clone()
1690 }
1691
1692 fn resources(&self) -> Vec<Arc<dyn ResourceProvider>> {
1693 self.resources.clone()
1694 }
1695
1696 fn prompts(&self) -> Vec<Arc<dyn PromptProvider>> {
1697 self.prompts.clone()
1698 }
1699}
1700
/// A connected MCP server: its config, live connection, and latest discovery snapshot.
///
/// Cloning shares the underlying connection, which is held in an `Arc`.
#[derive(Clone)]
pub struct McpServerHandle {
    /// Configuration the server was connected with.
    config: McpServerConfig,
    /// Live connection shared with any adapters built from this handle.
    connection: Arc<McpConnection>,
    /// Most recent discovery result; updated by `McpServerManager::refresh_server`.
    snapshot: McpDiscoverySnapshot,
}
1712
1713impl McpServerHandle {
1714 pub fn config(&self) -> &McpServerConfig {
1716 &self.config
1717 }
1718
1719 pub fn server_id(&self) -> &McpServerId {
1721 self.connection.server_id()
1722 }
1723
1724 pub fn connection(&self) -> Arc<McpConnection> {
1726 self.connection.clone()
1727 }
1728
1729 pub fn snapshot(&self) -> &McpDiscoverySnapshot {
1731 &self.snapshot
1732 }
1733
1734 pub fn tool_registry(&self) -> ToolRegistry {
1737 self.snapshot
1738 .tools
1739 .iter()
1740 .cloned()
1741 .fold(ToolRegistry::new(), |registry, descriptor| {
1742 registry.with(McpToolAdapter::new(
1743 self.server_id(),
1744 self.connection.clone(),
1745 descriptor,
1746 ))
1747 })
1748 }
1749
1750 pub fn capability_provider(&self) -> McpCapabilityProvider {
1752 McpCapabilityProvider::from_snapshot(self.connection.clone(), &self.snapshot)
1753 }
1754}
1755
/// Registry of MCP server configurations and the connections made from them.
///
/// Also caches credentials supplied through auth resolution. These are passed to the
/// connection whenever a server is (re)connected.
#[derive(Default)]
pub struct McpServerManager {
    /// Registered server configurations, keyed by server id.
    configs: BTreeMap<McpServerId, McpServerConfig>,
    /// Live connections, keyed by server id.
    connections: BTreeMap<McpServerId, McpServerHandle>,
    /// Cached credentials from resolved auth challenges, keyed by server id.
    auth: BTreeMap<McpServerId, MetadataMap>,
}
1802
1803impl McpServerManager {
1804 pub fn new() -> Self {
1806 Self::default()
1807 }
1808
1809 pub fn with_server(mut self, config: McpServerConfig) -> Self {
1814 self.register_server(config);
1815 self
1816 }
1817
1818 pub fn register_server(&mut self, config: McpServerConfig) -> &mut Self {
1823 self.configs.insert(config.id.clone(), config);
1824 self
1825 }
1826
1827 pub fn connected_server(&self, server_id: &McpServerId) -> Option<&McpServerHandle> {
1829 self.connections.get(server_id)
1830 }
1831
1832 pub fn connected_servers(&self) -> Vec<&McpServerHandle> {
1834 self.connections.values().collect()
1835 }
1836
1837 pub async fn connect_server(
1846 &mut self,
1847 server_id: &McpServerId,
1848 ) -> Result<McpServerHandle, McpError> {
1849 let config = self
1850 .configs
1851 .get(server_id)
1852 .cloned()
1853 .ok_or_else(|| McpError::UnknownServer(server_id.to_string()))?;
1854 let connection =
1855 Arc::new(McpConnection::connect_with_auth(&config, self.auth.get(server_id)).await?);
1856 let snapshot = connection.discover().await?;
1857 let handle = McpServerHandle {
1858 config,
1859 connection,
1860 snapshot,
1861 };
1862 self.connections.insert(server_id.clone(), handle.clone());
1863 Ok(handle)
1864 }
1865
1866 pub async fn connect_all(&mut self) -> Result<Vec<McpServerHandle>, McpError> {
1876 let server_ids = self.configs.keys().cloned().collect::<Vec<_>>();
1877 let mut handles = Vec::with_capacity(server_ids.len());
1878
1879 for server_id in server_ids {
1880 handles.push(self.connect_server(&server_id).await?);
1881 }
1882
1883 Ok(handles)
1884 }
1885
1886 pub async fn refresh_server(
1896 &mut self,
1897 server_id: &McpServerId,
1898 ) -> Result<McpDiscoverySnapshot, McpError> {
1899 let handle = self
1900 .connections
1901 .get_mut(server_id)
1902 .ok_or_else(|| McpError::UnknownServer(server_id.to_string()))?;
1903 let snapshot = handle.connection.discover().await?;
1904 handle.snapshot = snapshot.clone();
1905 Ok(snapshot)
1906 }
1907
1908 pub async fn disconnect_server(&mut self, server_id: &McpServerId) -> Result<(), McpError> {
1917 let Some(handle) = self.connections.remove(server_id) else {
1918 return Err(McpError::UnknownServer(server_id.to_string()));
1919 };
1920 handle.connection.close().await
1921 }
1922
1923 pub async fn resolve_auth(&mut self, resolution: AuthResolution) -> Result<(), McpError> {
1931 let server_id = resolution
1932 .request()
1933 .server_id()
1934 .ok_or_else(|| McpError::AuthResolution("auth resolution missing server id".into()))?;
1935 let server_id = McpServerId::new(server_id);
1936 match &resolution {
1937 AuthResolution::Provided { credentials, .. } => {
1938 self.auth.insert(server_id.clone(), credentials.clone());
1939 }
1940 AuthResolution::Cancelled { .. } => {
1941 self.auth.remove(&server_id);
1942 }
1943 }
1944
1945 if let Some(handle) = self.connections.get(&server_id) {
1946 handle.connection.resolve_auth(resolution).await?;
1947 return Ok(());
1948 }
1949
1950 if self.configs.contains_key(&server_id) {
1951 Ok(())
1952 } else {
1953 Err(McpError::UnknownServer(server_id.to_string()))
1954 }
1955 }
1956
1957 pub async fn resolve_auth_and_resume(
1967 &mut self,
1968 resolution: AuthResolution,
1969 ) -> Result<McpOperationResult, McpError> {
1970 let request = resolution.request().clone();
1971 self.resolve_auth(resolution).await?;
1972 self.replay_auth_request(&request).await
1973 }
1974
1975 pub async fn replay_auth_request(
1985 &mut self,
1986 request: &AuthRequest,
1987 ) -> Result<McpOperationResult, McpError> {
1988 match &request.operation {
1989 AuthOperation::McpConnect { server_id, .. } => {
1990 let server_id = McpServerId::new(server_id);
1991 let handle = self.connect_server(&server_id).await?;
1992 Ok(McpOperationResult::Connected(handle.snapshot.clone()))
1993 }
1994 AuthOperation::McpToolCall { server_id, .. }
1995 | AuthOperation::McpResourceRead { server_id, .. }
1996 | AuthOperation::McpPromptGet { server_id, .. } => {
1997 let connection = self.connection_for_auth_server(server_id).await?;
1998 connection.replay_auth_operation(&request.operation).await
1999 }
2000 AuthOperation::ToolCall { metadata, .. } => {
2001 let server_id = metadata
2002 .get("server_id")
2003 .and_then(Value::as_str)
2004 .ok_or_else(|| {
2005 McpError::AuthResolution(
2006 "tool-call auth replay requires metadata.server_id".into(),
2007 )
2008 })?;
2009 let connection = self.connection_for_auth_server(server_id).await?;
2010 connection.replay_auth_operation(&request.operation).await
2011 }
2012 AuthOperation::Custom { kind, .. } => Err(McpError::AuthResolution(format!(
2013 "unsupported auth operation for replay: {kind}"
2014 ))),
2015 }
2016 }
2017
2018 async fn connection_for_auth_server(
2019 &mut self,
2020 server_id: &str,
2021 ) -> Result<Arc<McpConnection>, McpError> {
2022 let server_id = McpServerId::new(server_id);
2023 if !self.connections.contains_key(&server_id) {
2024 self.connect_server(&server_id).await?;
2025 }
2026 self.connections
2027 .get(&server_id)
2028 .map(McpServerHandle::connection)
2029 .ok_or_else(|| McpError::UnknownServer(server_id.to_string()))
2030 }
2031
2032 pub fn tool_registry(&self) -> ToolRegistry {
2037 self.connections
2038 .values()
2039 .fold(ToolRegistry::new(), |mut registry, handle| {
2040 for tool in handle.snapshot.tools.iter().cloned() {
2041 registry.register(McpToolAdapter::new(
2042 handle.server_id(),
2043 handle.connection.clone(),
2044 tool,
2045 ));
2046 }
2047 registry
2048 })
2049 }
2050
2051 pub fn capability_provider(&self) -> McpCapabilityProvider {
2054 McpCapabilityProvider::merge(
2055 self.connections
2056 .values()
2057 .map(McpServerHandle::capability_provider),
2058 )
2059 }
2060}
2061
/// [`Tool`] adapter that exposes one MCP tool in a [`ToolRegistry`] under the name
/// `mcp.<server_id>.<tool_name>`.
pub struct McpToolAdapter {
    /// Descriptor discovered from the server; `name` is the MCP-side tool name.
    descriptor: McpToolDescriptor,
    /// Shared connection used to forward tool calls.
    connection: Arc<McpConnection>,
    /// Precomputed spec returned by [`Tool::spec`].
    spec: ToolSpec,
}
2082
2083impl McpToolAdapter {
2084 pub fn new(
2092 server_id: &McpServerId,
2093 connection: Arc<McpConnection>,
2094 descriptor: McpToolDescriptor,
2095 ) -> Self {
2096 let spec = ToolSpec {
2097 name: ToolName::new(format!("mcp.{}.{}", server_id, descriptor.name)),
2098 description: descriptor
2099 .description
2100 .clone()
2101 .unwrap_or_else(|| descriptor.name.clone()),
2102 input_schema: descriptor.input_schema.clone(),
2103 annotations: ToolAnnotations::default(),
2104 metadata: descriptor.metadata.clone(),
2105 };
2106
2107 Self {
2108 descriptor,
2109 connection,
2110 spec,
2111 }
2112 }
2113}
2114
2115#[async_trait]
2116impl Tool for McpToolAdapter {
2117 fn spec(&self) -> &ToolSpec {
2118 &self.spec
2119 }
2120
2121 async fn invoke(
2122 &self,
2123 request: ToolRequest,
2124 _ctx: &mut ToolContext<'_>,
2125 ) -> Result<ToolResult, ToolError> {
2126 let result = self
2127 .connection
2128 .call_tool(&self.descriptor.name, request.input)
2129 .await
2130 .map_err(|error| match error {
2131 McpError::AuthRequired(request) => ToolError::AuthRequired(request),
2132 other => ToolError::ExecutionFailed(other.to_string()),
2133 })?;
2134
2135 Ok(ToolResult {
2136 result: ToolResultPart {
2137 call_id: request.call_id,
2138 output: invocable_output_to_tool_output(value_to_invocable_output(result)),
2139 is_error: false,
2140 metadata: MetadataMap::new(),
2141 },
2142 duration: None,
2143 metadata: MetadataMap::new(),
2144 })
2145 }
2146}
2147
2148fn parse_tool_descriptor(value: Value) -> Result<McpToolDescriptor, McpError> {
2149 Ok(McpToolDescriptor {
2150 name: required_string(&value, "name")?,
2151 description: value
2152 .get("description")
2153 .and_then(Value::as_str)
2154 .map(str::to_owned),
2155 input_schema: value
2156 .get("inputSchema")
2157 .cloned()
2158 .unwrap_or_else(|| json!({ "type": "object" })),
2159 metadata: MetadataMap::new(),
2160 })
2161}
2162
2163fn parse_resource_descriptor(value: Value) -> Result<McpResourceDescriptor, McpError> {
2164 Ok(McpResourceDescriptor {
2165 id: required_string(&value, "uri")?,
2166 name: value
2167 .get("name")
2168 .and_then(Value::as_str)
2169 .map(str::to_owned)
2170 .unwrap_or_else(|| {
2171 value
2172 .get("uri")
2173 .and_then(Value::as_str)
2174 .unwrap_or_default()
2175 .to_string()
2176 }),
2177 description: value
2178 .get("description")
2179 .and_then(Value::as_str)
2180 .map(str::to_owned),
2181 mime_type: value
2182 .get("mimeType")
2183 .and_then(Value::as_str)
2184 .map(str::to_owned),
2185 metadata: MetadataMap::new(),
2186 })
2187}
2188
2189fn parse_prompt_descriptor(value: Value) -> Result<McpPromptDescriptor, McpError> {
2190 let name = required_string(&value, "name")?;
2191 let properties = value
2192 .get("arguments")
2193 .and_then(Value::as_array)
2194 .cloned()
2195 .unwrap_or_default()
2196 .into_iter()
2197 .filter_map(|arg| {
2198 let name = arg.get("name")?.as_str()?.to_string();
2199 Some((name, json!({ "type": "string" })))
2200 })
2201 .collect::<serde_json::Map<String, Value>>();
2202
2203 Ok(McpPromptDescriptor {
2204 id: name.clone(),
2205 name,
2206 description: value
2207 .get("description")
2208 .and_then(Value::as_str)
2209 .map(str::to_owned),
2210 input_schema: json!({
2211 "type": "object",
2212 "properties": properties,
2213 }),
2214 metadata: MetadataMap::new(),
2215 })
2216}
2217
2218fn parse_prompt_message(value: Value) -> Result<Item, McpError> {
2219 let role = value.get("role").and_then(Value::as_str).unwrap_or("user");
2220 let kind = match role {
2221 "assistant" => ItemKind::Assistant,
2222 "system" => ItemKind::System,
2223 _ => ItemKind::User,
2224 };
2225
2226 let content = value.get("content").cloned().unwrap_or(Value::Null);
2227 let text = if let Some(text) = content.get("text").and_then(Value::as_str) {
2228 text.to_string()
2229 } else if let Some(text) = content.as_str() {
2230 text.to_string()
2231 } else {
2232 content.to_string()
2233 };
2234
2235 Ok(Item {
2236 id: None,
2237 kind,
2238 parts: vec![Part::Text(TextPart {
2239 text,
2240 metadata: MetadataMap::new(),
2241 })],
2242 metadata: MetadataMap::new(),
2243 })
2244}
2245
2246fn required_string(value: &Value, field: &str) -> Result<String, McpError> {
2247 value
2248 .get(field)
2249 .and_then(Value::as_str)
2250 .map(str::to_owned)
2251 .ok_or_else(|| McpError::Protocol(format!("missing string field {field}")))
2252}
2253
2254fn value_to_invocable_output(value: Value) -> InvocableOutput {
2255 if let Some(content) = value.get("content").and_then(Value::as_array) {
2256 let text = content
2257 .iter()
2258 .filter_map(|item| item.get("text").and_then(Value::as_str))
2259 .collect::<Vec<_>>()
2260 .join("\n");
2261 if !text.is_empty() {
2262 return InvocableOutput::Text(text);
2263 }
2264 }
2265
2266 if let Some(text) = value.as_str() {
2267 InvocableOutput::Text(text.to_string())
2268 } else {
2269 InvocableOutput::Structured(value)
2270 }
2271}
2272
2273fn invocable_output_to_tool_output(output: InvocableOutput) -> ToolOutput {
2274 match output {
2275 InvocableOutput::Text(text) => ToolOutput::Text(text),
2276 InvocableOutput::Structured(value) => ToolOutput::Structured(value),
2277 InvocableOutput::Items(items) => {
2278 ToolOutput::Parts(items.into_iter().flat_map(|item| item.parts).collect())
2279 }
2280 InvocableOutput::Data(data) => ToolOutput::Structured(json!({ "data": data })),
2281 }
2282}
2283
2284fn metadata_to_value(metadata: &MetadataMap) -> Value {
2285 Value::Object(
2286 metadata
2287 .iter()
2288 .map(|(key, value)| (key.clone(), value.clone()))
2289 .collect(),
2290 )
2291}
2292
2293fn parse_auth_request(
2294 server_id: &McpServerId,
2295 method: &str,
2296 params: &Value,
2297 error: &Value,
2298) -> Option<AuthRequest> {
2299 let code = error.get("code").and_then(Value::as_i64);
2300 let message = error.get("message").and_then(Value::as_str);
2301 let data = error.get("data");
2302
2303 let auth_marker = matches!(code, Some(401 | -32001))
2304 || data
2305 .and_then(|data| data.get("auth_required"))
2306 .and_then(Value::as_bool)
2307 == Some(true)
2308 || data.and_then(|data| data.get("auth")).is_some();
2309
2310 if !auth_marker {
2311 return None;
2312 }
2313
2314 let mut challenge = MetadataMap::new();
2315 challenge.insert("server_id".into(), Value::String(server_id.to_string()));
2316 challenge.insert("method".into(), Value::String(method.into()));
2317
2318 if let Some(code) = code {
2319 challenge.insert("code".into(), Value::Number(code.into()));
2320 }
2321 if let Some(message) = message {
2322 challenge.insert("message".into(), Value::String(message.into()));
2323 }
2324 if let Some(data) = data {
2325 challenge.insert("data".into(), data.clone());
2326 }
2327
2328 Some(AuthRequest {
2329 task_id: None,
2330 id: format!("mcp:{}:{}", server_id, method),
2331 provider: format!("mcp.{}", server_id),
2332 operation: auth_operation_for_method(server_id, method, params),
2333 challenge,
2334 })
2335}
2336
2337fn auth_operation_for_method(
2338 server_id: &McpServerId,
2339 method: &str,
2340 params: &Value,
2341) -> AuthOperation {
2342 match method {
2343 "initialize" => AuthOperation::McpConnect {
2344 server_id: server_id.to_string(),
2345 metadata: MetadataMap::new(),
2346 },
2347 "tools/call" => AuthOperation::McpToolCall {
2348 server_id: server_id.to_string(),
2349 tool_name: params
2350 .get("name")
2351 .and_then(Value::as_str)
2352 .unwrap_or_default()
2353 .to_string(),
2354 input: params
2355 .get("arguments")
2356 .cloned()
2357 .unwrap_or_else(|| json!({})),
2358 metadata: MetadataMap::new(),
2359 },
2360 "resources/read" => AuthOperation::McpResourceRead {
2361 server_id: server_id.to_string(),
2362 resource_id: params
2363 .get("uri")
2364 .and_then(Value::as_str)
2365 .unwrap_or_default()
2366 .to_string(),
2367 metadata: MetadataMap::new(),
2368 },
2369 "prompts/get" => AuthOperation::McpPromptGet {
2370 server_id: server_id.to_string(),
2371 prompt_id: params
2372 .get("name")
2373 .and_then(Value::as_str)
2374 .unwrap_or_default()
2375 .to_string(),
2376 args: params
2377 .get("arguments")
2378 .cloned()
2379 .unwrap_or_else(|| json!({})),
2380 metadata: MetadataMap::new(),
2381 },
2382 other => AuthOperation::Custom {
2383 kind: format!("mcp.{other}"),
2384 payload: params.clone(),
2385 metadata: {
2386 let mut metadata = MetadataMap::new();
2387 metadata.insert("server_id".into(), Value::String(server_id.to_string()));
2388 metadata
2389 },
2390 },
2391 }
2392}
2393
2394fn normalize_mcp_tool_name(server_id: &McpServerId, tool_name: &str) -> String {
2395 let prefix = format!("mcp.{server_id}.");
2396 tool_name
2397 .strip_prefix(&prefix)
2398 .unwrap_or(tool_name)
2399 .to_string()
2400}
2401
2402async fn read_sse_stream<R>(
2403 mut reader: R,
2404 response_url: Url,
2405 frame_tx: mpsc::UnboundedSender<Result<McpFrame, McpError>>,
2406 endpoint_tx: oneshot::Sender<Result<Url, McpError>>,
2407) where
2408 R: AsyncBufRead + Unpin,
2409{
2410 let mut endpoint_tx = Some(endpoint_tx);
2411 loop {
2412 match read_next_sse_event(&mut reader).await {
2413 Ok(Some(event)) => {
2414 if let Some(endpoint) = legacy_sse_event_to_endpoint(&response_url, &event) {
2415 if let Some(tx) = endpoint_tx.take() {
2416 let _ = tx.send(endpoint);
2417 }
2418 continue;
2419 }
2420
2421 if let Some(frame) = legacy_sse_event_to_frame(event) {
2422 let _ = frame_tx.send(frame);
2423 }
2424 }
2425 Ok(None) => break,
2426 Err(error) => {
2427 if let Some(tx) = endpoint_tx.take() {
2428 let _ = tx.send(Err(error));
2429 } else {
2430 let _ = frame_tx.send(Err(error));
2431 }
2432 return;
2433 }
2434 }
2435 }
2436
2437 if let Some(tx) = endpoint_tx.take() {
2438 let _ = tx.send(Err(McpError::Transport(
2439 "SSE stream ended before endpoint event".into(),
2440 )));
2441 }
2442}
2443
2444fn resolve_sse_endpoint(response_url: &Url, endpoint: &str) -> Result<Url, McpError> {
2445 response_url
2446 .join(endpoint.trim())
2447 .map_err(|error| McpError::Transport(format!("invalid SSE endpoint URL: {error}")))
2448}
2449
/// One Server-Sent Event as assembled by `read_next_sse_event`.
#[derive(Debug)]
struct SseEvent {
    /// Value of the `event:` field, if any. The frame converters treat `None` as the
    /// default `message` type.
    event_name: Option<String>,
    /// All `data:` lines of the event, joined with `\n`.
    data: String,
    /// Value of the `id:` field, if any.
    id: Option<String>,
    /// Parsed `retry:` value (milliseconds), if one was parsed.
    retry_ms: Option<u64>,
}
2457
2458async fn read_next_sse_event<R>(reader: &mut R) -> Result<Option<SseEvent>, McpError>
2459where
2460 R: AsyncBufRead + Unpin,
2461{
2462 let mut event_name = None;
2463 let mut data_lines = Vec::new();
2464 let mut id = None;
2465 let mut retry_ms = None;
2466
2467 loop {
2468 let mut line = String::new();
2469 let read = reader.read_line(&mut line).await.map_err(McpError::Io)?;
2470 if read == 0 {
2471 if event_name.is_none() && data_lines.is_empty() && id.is_none() && retry_ms.is_none() {
2472 return Ok(None);
2473 }
2474 return Ok(Some(SseEvent {
2475 event_name,
2476 data: data_lines.join("\n"),
2477 id,
2478 retry_ms,
2479 }));
2480 }
2481
2482 let line = line.trim_end_matches(['\r', '\n']);
2483 if line.is_empty() {
2484 if event_name.is_none() && data_lines.is_empty() && id.is_none() && retry_ms.is_none() {
2485 continue;
2486 }
2487 return Ok(Some(SseEvent {
2488 event_name,
2489 data: data_lines.join("\n"),
2490 id,
2491 retry_ms,
2492 }));
2493 }
2494
2495 if line.starts_with(':') {
2496 continue;
2497 }
2498
2499 if let Some(rest) = line.strip_prefix("event:") {
2500 event_name = Some(rest.trim_start().to_string());
2501 continue;
2502 }
2503 if let Some(rest) = line.strip_prefix("data:") {
2504 data_lines.push(rest.trim_start().to_string());
2505 continue;
2506 }
2507 if let Some(rest) = line.strip_prefix("id:") {
2508 id = Some(rest.trim_start().to_string());
2509 continue;
2510 }
2511 if let Some(rest) = line.strip_prefix("retry:") {
2512 retry_ms = rest.trim_start().parse().ok();
2513 }
2514 }
2515}
2516
2517fn legacy_sse_event_to_endpoint(
2518 response_url: &Url,
2519 event: &SseEvent,
2520) -> Option<Result<Url, McpError>> {
2521 if event.event_name.as_deref() != Some("endpoint") {
2522 return None;
2523 }
2524 if event.data.is_empty() {
2525 return Some(Err(McpError::Transport(
2526 "legacy SSE endpoint event is missing data".into(),
2527 )));
2528 }
2529 Some(resolve_sse_endpoint(response_url, &event.data))
2530}
2531
2532fn legacy_sse_event_to_frame(event: SseEvent) -> Option<Result<McpFrame, McpError>> {
2533 let event_name = event.event_name.unwrap_or_else(|| "message".into());
2534 if event_name != "message" || event.data.is_empty() {
2535 return None;
2536 }
2537
2538 Some(
2539 serde_json::from_str(&event.data)
2540 .map_err(McpError::Serialize)
2541 .map(|value| McpFrame { value }),
2542 )
2543}
2544
2545fn streamable_http_event_to_frame(event: SseEvent) -> Result<Option<McpFrame>, McpError> {
2546 let event_name = event.event_name.unwrap_or_else(|| "message".into());
2547 if event_name != "message" || event.data.is_empty() {
2548 return Ok(None);
2549 }
2550
2551 let value = serde_json::from_str(&event.data).map_err(McpError::Serialize)?;
2552 Ok(Some(McpFrame { value }))
2553}
2554
2555fn is_jsonrpc_request(value: &Value) -> bool {
2556 value.get("method").is_some() && value.get("id").is_some()
2557}
2558
2559fn apply_streamable_http_headers(
2560 mut request: reqwest::RequestBuilder,
2561 headers: &[(String, String)],
2562 protocol_version: Option<&str>,
2563 session_id: Option<&str>,
2564) -> reqwest::RequestBuilder {
2565 for (key, value) in headers {
2566 request = request.header(key, value);
2567 }
2568
2569 if let Some(protocol_version) = protocol_version {
2570 request = request.header("MCP-Protocol-Version", protocol_version);
2571 }
2572 if let Some(session_id) = session_id {
2573 request = request.header("MCP-Session-Id", session_id);
2574 }
2575
2576 request
2577}
2578
2579async fn streamable_http_status_error(
2580 operation: &str,
2581 status: StatusCode,
2582 response: reqwest::Response,
2583) -> McpError {
2584 let body = response
2585 .text()
2586 .await
2587 .unwrap_or_else(|_| "<unreadable response body>".into());
2588 McpError::Transport(format!("{operation} failed with status {status}: {body}"))
2589}
2590
/// Errors returned by the MCP client, its transports, and the server manager.
#[derive(Debug, Error)]
pub enum McpError {
    #[error("io error: {0}")]
    /// I/O failure, e.g. while reading an SSE stream.
    Io(#[from] std::io::Error),
    #[error("http error: {0}")]
    /// HTTP client failure reported by `reqwest`.
    Http(#[from] reqwest::Error),
    #[error("serialization error: {0}")]
    /// JSON encoding or decoding failure.
    Serialize(#[from] serde_json::Error),
    #[error("transport error: {0}")]
    /// Transport-level failure described by a message, for example an unexpected HTTP
    /// status or an invalid SSE endpoint.
    Transport(String),
    #[error("protocol error: {0}")]
    /// The peer sent a message with an unexpected shape, such as a missing required
    /// string field.
    Protocol(String),
    #[error("MCP auth required: {0:?}")]
    /// The server requires authentication. Carries the [`AuthRequest`] describing the
    /// challenge and the operation to replay once it is resolved. It is boxed, presumably
    /// to keep `McpError` small.
    AuthRequired(Box<AuthRequest>),
    #[error("auth resolution error: {0}")]
    /// An auth resolution or replay could not be applied, e.g. a missing or mismatched
    /// server id or an unsupported operation.
    AuthResolution(String),
    #[error("invocation error: {0}")]
    /// Failure while invoking an operation, described by a message.
    Invocation(String),
    #[error("unknown MCP server: {0}")]
    /// The referenced server id is not registered, or is not connected where a live
    /// connection is required.
    UnknownServer(String),
}
2623
2624#[cfg(test)]
2625mod tests {
2626 use std::collections::VecDeque;
2627 use std::sync::{Arc as StdArc, Mutex as StdMutex};
2628
2629 use super::*;
2630 use agentkit_tools_core::{PermissionChecker, PermissionDecision, PermissionRequest};
2631 use tokio::io::{AsyncReadExt, AsyncWriteExt};
2632 use tokio::net::TcpListener;
2633
2634 struct AllowAll;
2635
2636 impl PermissionChecker for AllowAll {
2637 fn evaluate(&self, _request: &dyn PermissionRequest) -> PermissionDecision {
2638 PermissionDecision::Allow
2639 }
2640 }
2641
2642 struct FakeTransport {
2643 recv: VecDeque<Value>,
2644 }
2645
2646 #[async_trait]
2647 impl McpTransport for FakeTransport {
2648 async fn send(&mut self, _message: McpFrame) -> Result<(), McpError> {
2649 Ok(())
2650 }
2651
2652 async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
2653 Ok(self.recv.pop_front().map(|value| McpFrame { value }))
2654 }
2655
2656 async fn close(&mut self) -> Result<(), McpError> {
2657 Ok(())
2658 }
2659 }
2660
2661 fn fake_connection(responses: Vec<Value>) -> McpConnection {
2662 McpConnection {
2663 server_id: McpServerId::new("fake"),
2664 transport: Mutex::new(Box::new(FakeTransport {
2665 recv: responses.into(),
2666 })),
2667 auth: Mutex::new(None),
2668 next_id: AtomicU64::new(1),
2669 }
2670 }
2671
2672 #[derive(Clone)]
2673 struct FakeTransportFactory {
2674 responses: StdArc<StdMutex<VecDeque<Vec<Value>>>>,
2675 }
2676
2677 impl FakeTransportFactory {
2678 fn new(sequences: Vec<Vec<Value>>) -> Self {
2679 Self {
2680 responses: StdArc::new(StdMutex::new(sequences.into())),
2681 }
2682 }
2683 }
2684
2685 #[async_trait]
2686 impl McpTransportFactory for FakeTransportFactory {
2687 async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
2688 let responses =
2689 self.responses.lock().unwrap().pop_front().ok_or_else(|| {
2690 McpError::Transport("no fake transport responses left".into())
2691 })?;
2692 Ok(Box::new(FakeTransport {
2693 recv: responses.into(),
2694 }))
2695 }
2696 }
2697
2698 #[tokio::test]
2699 async fn discovery_parses_snapshot() {
2700 let connection = fake_connection(vec![
2701 json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [{ "name": "echo", "description": "Echo", "inputSchema": {"type": "object"} }] } }),
2702 json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [{ "uri": "file:///tmp/example.txt", "name": "example.txt", "mimeType": "text/plain" }] } }),
2703 json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [{ "name": "summarize", "description": "Summarize", "arguments": [{ "name": "path" }] }] } }),
2704 ]);
2705
2706 let snapshot = connection.discover().await.unwrap();
2707 assert_eq!(snapshot.tools[0].name, "echo");
2708 assert_eq!(snapshot.resources[0].id, "file:///tmp/example.txt");
2709 assert_eq!(snapshot.prompts[0].id, "summarize");
2710 }
2711
2712 #[tokio::test]
2713 async fn tool_adapter_returns_text_output() {
2714 let connection = Arc::new(fake_connection(vec![json!({
2715 "jsonrpc": "2.0",
2716 "id": 1,
2717 "result": { "content": [{ "type": "text", "text": "pong" }] }
2718 })]));
2719 let server_id = connection.server_id().clone();
2720 let adapter = McpToolAdapter::new(
2721 &server_id,
2722 connection,
2723 McpToolDescriptor {
2724 name: "echo".into(),
2725 description: Some("Echo".into()),
2726 input_schema: json!({ "type": "object" }),
2727 metadata: MetadataMap::new(),
2728 },
2729 );
2730 let metadata = MetadataMap::new();
2731 let mut ctx = ToolContext {
2732 capability: CapabilityContext {
2733 session_id: None,
2734 turn_id: None,
2735 metadata: &metadata,
2736 },
2737 permissions: &AllowAll,
2738 resources: &(),
2739 cancellation: None,
2740 };
2741
2742 let result = adapter
2743 .invoke(
2744 ToolRequest {
2745 call_id: "call-1".into(),
2746 tool_name: ToolName::new("mcp.fake.echo"),
2747 input: json!({}),
2748 session_id: "session-1".into(),
2749 turn_id: "turn-1".into(),
2750 metadata: MetadataMap::new(),
2751 },
2752 &mut ctx,
2753 )
2754 .await
2755 .unwrap();
2756
2757 assert_eq!(result.result.output, ToolOutput::Text("pong".into()));
2758 }
2759
2760 #[tokio::test]
2761 async fn request_surfaces_auth_required_errors() {
2762 let connection = fake_connection(vec![json!({
2763 "jsonrpc": "2.0",
2764 "id": 1,
2765 "error": {
2766 "code": -32001,
2767 "message": "authentication required",
2768 "data": {
2769 "auth_required": true,
2770 "scope": "secrets.read"
2771 }
2772 }
2773 })]);
2774
2775 let error = connection.call_tool("echo", json!({})).await.unwrap_err();
2776 match error {
2777 McpError::AuthRequired(request) => {
2778 assert_eq!(request.provider, "mcp.fake");
2779 assert_eq!(
2780 request.challenge.get("method"),
2781 Some(&Value::String("tools/call".into()))
2782 );
2783 assert!(matches!(
2784 request.operation,
2785 AuthOperation::McpToolCall { ref tool_name, .. } if tool_name == "echo"
2786 ));
2787 }
2788 other => panic!("unexpected error: {other:?}"),
2789 }
2790 }
2791
2792 #[tokio::test]
2793 async fn tool_adapter_maps_auth_required_into_tool_error() {
2794 let connection = Arc::new(fake_connection(vec![json!({
2795 "jsonrpc": "2.0",
2796 "id": 1,
2797 "error": {
2798 "code": -32001,
2799 "message": "authentication required",
2800 "data": { "auth_required": true }
2801 }
2802 })]));
2803 let server_id = connection.server_id().clone();
2804 let adapter = McpToolAdapter::new(
2805 &server_id,
2806 connection,
2807 McpToolDescriptor {
2808 name: "echo".into(),
2809 description: Some("Echo".into()),
2810 input_schema: json!({ "type": "object" }),
2811 metadata: MetadataMap::new(),
2812 },
2813 );
2814 let metadata = MetadataMap::new();
2815 let mut ctx = ToolContext {
2816 capability: CapabilityContext {
2817 session_id: None,
2818 turn_id: None,
2819 metadata: &metadata,
2820 },
2821 permissions: &AllowAll,
2822 resources: &(),
2823 cancellation: None,
2824 };
2825
2826 let error = adapter
2827 .invoke(
2828 ToolRequest {
2829 call_id: "call-1".into(),
2830 tool_name: ToolName::new("mcp.fake.echo"),
2831 input: json!({}),
2832 session_id: "session-1".into(),
2833 turn_id: "turn-1".into(),
2834 metadata: MetadataMap::new(),
2835 },
2836 &mut ctx,
2837 )
2838 .await
2839 .unwrap_err();
2840
2841 match error {
2842 ToolError::AuthRequired(request) => {
2843 assert_eq!(request.provider, "mcp.fake");
2844 }
2845 other => panic!("unexpected error: {other:?}"),
2846 }
2847 }
2848
2849 struct RecordingTransport {
2850 recv: VecDeque<Value>,
2851 sent: StdArc<StdMutex<Vec<Value>>>,
2852 }
2853
2854 #[async_trait]
2855 impl McpTransport for RecordingTransport {
2856 async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
2857 self.sent.lock().unwrap().push(message.value);
2858 Ok(())
2859 }
2860
2861 async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
2862 Ok(self.recv.pop_front().map(|value| McpFrame { value }))
2863 }
2864
2865 async fn close(&mut self) -> Result<(), McpError> {
2866 Ok(())
2867 }
2868 }
2869
2870 #[derive(Clone)]
2871 struct RecordingTransportFactory {
2872 responses: StdArc<StdMutex<VecDeque<Vec<Value>>>>,
2873 sent: StdArc<StdMutex<Vec<Value>>>,
2874 }
2875
2876 impl RecordingTransportFactory {
2877 fn new(sequences: Vec<Vec<Value>>) -> Self {
2878 Self {
2879 responses: StdArc::new(StdMutex::new(sequences.into())),
2880 sent: StdArc::new(StdMutex::new(Vec::new())),
2881 }
2882 }
2883
2884 fn sent(&self) -> Vec<Value> {
2885 self.sent.lock().unwrap().clone()
2886 }
2887 }
2888
2889 #[async_trait]
2890 impl McpTransportFactory for RecordingTransportFactory {
2891 async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
2892 let responses = self.responses.lock().unwrap().pop_front().ok_or_else(|| {
2893 McpError::Transport("no recording transport responses left".into())
2894 })?;
2895 Ok(Box::new(RecordingTransport {
2896 recv: responses.into(),
2897 sent: self.sent.clone(),
2898 }))
2899 }
2900 }
2901
2902 #[tokio::test]
2903 async fn connection_includes_resolved_auth_in_future_requests() {
2904 let factory = RecordingTransportFactory::new(vec![vec![
2905 json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "recording", "version": "1.0.0" } } }),
2906 json!({ "jsonrpc": "2.0", "id": 1, "result": { "content": [{ "type": "text", "text": "ok" }] } }),
2907 ]]);
2908 let config = McpServerConfig::new(
2909 "recording",
2910 McpTransportBinding::Custom(Arc::new(factory.clone())),
2911 );
2912 let connection = McpConnection::connect(&config).await.unwrap();
2913 let mut auth = MetadataMap::new();
2914 auth.insert("token".into(), json!("secret-token"));
2915 let request = AuthRequest {
2916 task_id: None,
2917 id: "auth-recording-tool".into(),
2918 provider: "mcp.recording".into(),
2919 operation: AuthOperation::McpToolCall {
2920 server_id: "recording".into(),
2921 tool_name: "echo".into(),
2922 input: json!({}),
2923 metadata: MetadataMap::new(),
2924 },
2925 challenge: MetadataMap::new(),
2926 };
2927 connection
2928 .resolve_auth(agentkit_tools_core::AuthResolution::Provided {
2929 request,
2930 credentials: auth,
2931 })
2932 .await
2933 .unwrap();
2934
2935 let _ = connection.call_tool("echo", json!({})).await.unwrap();
2936 let sent = factory.sent();
2937 assert!(
2938 sent.iter().any(|value| {
2939 value
2940 .get("params")
2941 .and_then(|params| params.get("auth"))
2942 .and_then(|auth| auth.get("token"))
2943 == Some(&json!("secret-token"))
2944 }),
2945 "expected an MCP request to include the resolved auth payload, saw {:?}",
2946 sent
2947 );
2948 }
2949
2950 #[tokio::test]
2951 async fn manager_reuses_stored_auth_on_connect() {
2952 let factory = RecordingTransportFactory::new(vec![vec![
2953 json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "recording", "version": "1.0.0" } } }),
2954 json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [] } }),
2955 json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
2956 json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
2957 ]]);
2958 let server_id = McpServerId::new("recording");
2959 let mut manager = McpServerManager::new().with_server(McpServerConfig::new(
2960 server_id.to_string(),
2961 McpTransportBinding::Custom(Arc::new(factory.clone())),
2962 ));
2963 let mut auth = MetadataMap::new();
2964 auth.insert("token".into(), json!("seed-token"));
2965 let request = AuthRequest {
2966 task_id: None,
2967 id: "auth-recording-connect".into(),
2968 provider: "mcp.recording".into(),
2969 operation: AuthOperation::McpConnect {
2970 server_id: server_id.to_string(),
2971 metadata: MetadataMap::new(),
2972 },
2973 challenge: MetadataMap::new(),
2974 };
2975 manager
2976 .resolve_auth(agentkit_tools_core::AuthResolution::Provided {
2977 request,
2978 credentials: auth,
2979 })
2980 .await
2981 .unwrap();
2982
2983 manager.connect_server(&server_id).await.unwrap();
2984 let sent = factory.sent();
2985 assert!(
2986 sent.iter().any(|value| {
2987 value.get("method").and_then(Value::as_str) == Some("initialize")
2988 && value
2989 .get("params")
2990 .and_then(|params| params.get("auth"))
2991 .and_then(|auth| auth.get("token"))
2992 == Some(&json!("seed-token"))
2993 }),
2994 "expected initialize to include stored auth, saw {:?}",
2995 sent
2996 );
2997 }
2998
2999 #[tokio::test]
3000 async fn manager_resolves_auth_and_replays_resource_read() {
3001 let factory = RecordingTransportFactory::new(vec![vec![
3002 json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "recording", "version": "1.0.0" } } }),
3003 json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [] } }),
3004 json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
3005 json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
3006 json!({
3007 "jsonrpc": "2.0",
3008 "id": 4,
3009 "result": {
3010 "contents": [
3011 {
3012 "uri": "file:///tmp/secret.txt",
3013 "text": "secret from resource"
3014 }
3015 ]
3016 }
3017 }),
3018 ]]);
3019 let server_id = McpServerId::new("recording");
3020 let mut manager = McpServerManager::new().with_server(McpServerConfig::new(
3021 server_id.to_string(),
3022 McpTransportBinding::Custom(Arc::new(factory.clone())),
3023 ));
3024 let mut auth = MetadataMap::new();
3025 auth.insert("token".into(), json!("resource-token"));
3026 let request = AuthRequest {
3027 task_id: None,
3028 id: "auth-recording-resource".into(),
3029 provider: "mcp.recording".into(),
3030 operation: AuthOperation::McpResourceRead {
3031 server_id: server_id.to_string(),
3032 resource_id: "file:///tmp/secret.txt".into(),
3033 metadata: MetadataMap::new(),
3034 },
3035 challenge: MetadataMap::new(),
3036 };
3037
3038 let result = manager
3039 .resolve_auth_and_resume(agentkit_tools_core::AuthResolution::Provided {
3040 request,
3041 credentials: auth,
3042 })
3043 .await
3044 .unwrap();
3045
3046 match result {
3047 McpOperationResult::Resource(contents) => {
3048 assert_eq!(
3049 contents.data,
3050 DataRef::InlineText("secret from resource".into())
3051 );
3052 }
3053 other => panic!("unexpected replay result: {other:?}"),
3054 }
3055
3056 let sent = factory.sent();
3057 assert!(
3058 sent.iter().any(|value| {
3059 value.get("method").and_then(Value::as_str) == Some("resources/read")
3060 && value
3061 .get("params")
3062 .and_then(|params| params.get("auth"))
3063 .and_then(|auth| auth.get("token"))
3064 == Some(&json!("resource-token"))
3065 }),
3066 "expected resources/read to include resolved auth, saw {:?}",
3067 sent
3068 );
3069 }
3070
3071 #[tokio::test]
3072 async fn manager_resolves_auth_and_replays_connect() {
3073 let factory = RecordingTransportFactory::new(vec![vec![
3074 json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "recording", "version": "1.0.0" } } }),
3075 json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [] } }),
3076 json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
3077 json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
3078 ]]);
3079 let server_id = McpServerId::new("recording");
3080 let mut manager = McpServerManager::new().with_server(McpServerConfig::new(
3081 server_id.to_string(),
3082 McpTransportBinding::Custom(Arc::new(factory.clone())),
3083 ));
3084 let mut auth = MetadataMap::new();
3085 auth.insert("token".into(), json!("connect-token"));
3086 let request = AuthRequest {
3087 task_id: None,
3088 id: "auth-recording-connect-replay".into(),
3089 provider: "mcp.recording".into(),
3090 operation: AuthOperation::McpConnect {
3091 server_id: server_id.to_string(),
3092 metadata: MetadataMap::new(),
3093 },
3094 challenge: MetadataMap::new(),
3095 };
3096
3097 let result = manager
3098 .resolve_auth_and_resume(agentkit_tools_core::AuthResolution::Provided {
3099 request,
3100 credentials: auth,
3101 })
3102 .await
3103 .unwrap();
3104
3105 match result {
3106 McpOperationResult::Connected(snapshot) => {
3107 assert_eq!(snapshot.server_id, server_id);
3108 }
3109 other => panic!("unexpected replay result: {other:?}"),
3110 }
3111 }
3112
3113 #[tokio::test]
3114 async fn sse_transport_posts_messages_and_receives_frames() {
3115 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3116 let address = listener.local_addr().unwrap();
3117 let requests = StdArc::new(StdMutex::new(Vec::new()));
3118 let captured = requests.clone();
3119
3120 let server = tokio::spawn(async move {
3121 for _ in 0..2 {
3122 let (mut socket, _) = listener.accept().await.unwrap();
3123 let mut buffer = vec![0_u8; 4096];
3124 let read = socket.read(&mut buffer).await.unwrap();
3125 let request = String::from_utf8_lossy(&buffer[..read]).to_string();
3126
3127 if request.starts_with("GET /sse ") {
3128 let body = concat!(
3129 "event: endpoint\n",
3130 "data: /messages\n\n",
3131 "event: message\n",
3132 "data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}\n\n"
3133 );
3134 let response = format!(
3135 "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3136 body.len(),
3137 body
3138 );
3139 socket.write_all(response.as_bytes()).await.unwrap();
3140 } else {
3141 captured.lock().unwrap().push(request);
3142 socket
3143 .write_all(
3144 b"HTTP/1.1 202 Accepted\r\ncontent-length: 0\r\nconnection: close\r\n\r\n",
3145 )
3146 .await
3147 .unwrap();
3148 }
3149 }
3150 });
3151
3152 let factory =
3153 SseTransportFactory::new(SseTransportConfig::new(format!("http://{address}/sse")));
3154 let mut transport = factory.connect().await.unwrap();
3155 transport
3156 .send(McpFrame {
3157 value: json!({
3158 "jsonrpc": "2.0",
3159 "id": 1,
3160 "method": "tools/list",
3161 "params": {}
3162 }),
3163 })
3164 .await
3165 .unwrap();
3166 let frame = transport.recv().await.unwrap().unwrap();
3167 transport.close().await.unwrap();
3168 server.await.unwrap();
3169
3170 assert_eq!(frame.value["result"]["tools"], json!([]));
3171 let requests = requests.lock().unwrap();
3172 assert_eq!(requests.len(), 1);
3173 assert!(requests[0].starts_with("POST /messages "));
3174 assert!(requests[0].contains("\"method\":\"tools/list\""));
3175 }
3176
3177 #[tokio::test]
3178 async fn streamable_http_connection_tracks_session_and_protocol_headers() {
3179 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3180 let address = listener.local_addr().unwrap();
3181 let requests = StdArc::new(StdMutex::new(Vec::new()));
3182 let captured = requests.clone();
3183
3184 let server = tokio::spawn(async move {
3185 for _ in 0..4 {
3186 let (mut socket, _) = listener.accept().await.unwrap();
3187 let mut buffer = vec![0_u8; 8192];
3188 let read = socket.read(&mut buffer).await.unwrap();
3189 let request = String::from_utf8_lossy(&buffer[..read]).to_string();
3190 captured.lock().unwrap().push(request.clone());
3191
3192 let response = if request.contains("\"method\":\"initialize\"") {
3193 let body = "{\"jsonrpc\":\"2.0\",\"id\":0,\"result\":{\"protocolVersion\":\"2025-11-25\",\"capabilities\":{},\"serverInfo\":{\"name\":\"remote\",\"version\":\"1.0.0\"}}}";
3194 format!(
3195 "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\nMCP-Session-Id: session-123\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3196 body.len(),
3197 body
3198 )
3199 } else if request.contains("\"method\":\"notifications/initialized\"") {
3200 "HTTP/1.1 202 Accepted\r\ncontent-length: 0\r\nconnection: close\r\n\r\n"
3201 .to_string()
3202 } else if request.starts_with("DELETE /mcp ") {
3203 "HTTP/1.1 204 No Content\r\ncontent-length: 0\r\nconnection: close\r\n\r\n"
3204 .to_string()
3205 } else {
3206 let body = "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}";
3207 format!(
3208 "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3209 body.len(),
3210 body
3211 )
3212 };
3213
3214 socket.write_all(response.as_bytes()).await.unwrap();
3215 }
3216 });
3217
3218 let config = McpServerConfig::new(
3219 "remote",
3220 McpTransportBinding::StreamableHttp(StreamableHttpTransportConfig::new(format!(
3221 "http://{address}/mcp"
3222 ))),
3223 );
3224 let connection = McpConnection::connect(&config).await.unwrap();
3225 let _ = connection.list_tools().await.unwrap();
3226 connection.close().await.unwrap();
3227 server.await.unwrap();
3228
3229 let requests = requests.lock().unwrap();
3230 assert_eq!(requests.len(), 4);
3231 let normalized = requests
3232 .iter()
3233 .map(|request| request.to_ascii_lowercase())
3234 .collect::<Vec<_>>();
3235 assert!(requests[0].starts_with("POST /mcp "));
3236 assert!(!requests[0].contains("MCP-Session-Id:"));
3237 assert!(normalized[1].contains("mcp-session-id: session-123"));
3238 assert!(normalized[1].contains("mcp-protocol-version: 2025-11-25"));
3239 assert!(normalized[2].contains("mcp-session-id: session-123"));
3240 assert!(normalized[2].contains("mcp-protocol-version: 2025-11-25"));
3241 assert!(requests[3].starts_with("DELETE /mcp "));
3242 assert!(normalized[3].contains("mcp-session-id: session-123"));
3243 }
3244
3245 #[tokio::test]
3246 async fn streamable_http_transport_resumes_sse_streams_until_response_arrives() {
3247 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3248 let address = listener.local_addr().unwrap();
3249 let requests = StdArc::new(StdMutex::new(Vec::new()));
3250 let captured = requests.clone();
3251
3252 let server = tokio::spawn(async move {
3253 for _ in 0..2 {
3254 let (mut socket, _) = listener.accept().await.unwrap();
3255 let mut buffer = vec![0_u8; 8192];
3256 let read = socket.read(&mut buffer).await.unwrap();
3257 let request = String::from_utf8_lossy(&buffer[..read]).to_string();
3258 captured.lock().unwrap().push(request.clone());
3259
3260 let response = if request.starts_with("POST /mcp ") {
3261 let body = concat!(
3262 "id: evt-1\n",
3263 "event: message\n",
3264 "data: {\"jsonrpc\":\"2.0\",\"method\":\"notifications/message\",\"params\":{\"phase\":\"stream-start\"}}\n\n"
3265 );
3266 format!(
3267 "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3268 body.len(),
3269 body
3270 )
3271 } else {
3272 let body = concat!(
3273 "id: evt-2\n",
3274 "event: message\n",
3275 "data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}\n\n"
3276 );
3277 format!(
3278 "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3279 body.len(),
3280 body
3281 )
3282 };
3283
3284 socket.write_all(response.as_bytes()).await.unwrap();
3285 }
3286 });
3287
3288 let factory = StreamableHttpTransportFactory::new(StreamableHttpTransportConfig::new(
3289 format!("http://{address}/mcp"),
3290 ));
3291 let mut transport = factory.connect().await.unwrap();
3292 transport
3293 .send(McpFrame {
3294 value: json!({
3295 "jsonrpc": "2.0",
3296 "id": 1,
3297 "method": "tools/list",
3298 "params": {}
3299 }),
3300 })
3301 .await
3302 .unwrap();
3303
3304 let first = transport.recv().await.unwrap().unwrap();
3305 let second = transport.recv().await.unwrap().unwrap();
3306 transport.close().await.unwrap();
3307 server.await.unwrap();
3308
3309 assert_eq!(
3310 first.value["method"],
3311 Value::String("notifications/message".into())
3312 );
3313 assert_eq!(second.value["result"]["tools"], json!([]));
3314
3315 let requests = requests.lock().unwrap();
3316 assert_eq!(requests.len(), 2);
3317 assert!(requests[0].starts_with("POST /mcp "));
3318 assert!(requests[1].starts_with("GET /mcp "));
3319 assert!(
3320 requests[1].contains("last-event-id: evt-1")
3321 || requests[1].contains("Last-Event-ID: evt-1")
3322 );
3323 }
3324
3325 #[tokio::test]
3326 async fn server_manager_connects_refreshes_and_aggregates_tools() {
3327 let alpha = McpServerConfig::new(
3328 "alpha",
3329 McpTransportBinding::Custom(Arc::new(FakeTransportFactory::new(vec![vec![
3330 json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "alpha", "version": "1.0.0" } } }),
3331 json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [{ "name": "echo", "description": "Echo", "inputSchema": {"type": "object"} }] } }),
3332 json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
3333 json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
3334 json!({ "jsonrpc": "2.0", "id": 4, "result": { "tools": [{ "name": "echo_v2", "description": "Echo 2", "inputSchema": {"type": "object"} }] } }),
3335 json!({ "jsonrpc": "2.0", "id": 5, "result": { "resources": [] } }),
3336 json!({ "jsonrpc": "2.0", "id": 6, "result": { "prompts": [] } }),
3337 ]]))),
3338 );
3339 let beta = McpServerConfig::new(
3340 "beta",
3341 McpTransportBinding::Custom(Arc::new(FakeTransportFactory::new(vec![vec![
3342 json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "beta", "version": "1.0.0" } } }),
3343 json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [{ "name": "search", "description": "Search", "inputSchema": {"type": "object"} }] } }),
3344 json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
3345 json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
3346 ]]))),
3347 );
3348
3349 let mut manager = McpServerManager::new().with_server(alpha).with_server(beta);
3350
3351 let handles = manager.connect_all().await.unwrap();
3352 assert_eq!(handles.len(), 2);
3353 assert_eq!(
3354 manager
3355 .tool_registry()
3356 .specs()
3357 .into_iter()
3358 .map(|spec| spec.name.0)
3359 .collect::<Vec<_>>(),
3360 vec!["mcp.alpha.echo".to_string(), "mcp.beta.search".to_string()]
3361 );
3362
3363 let refreshed = manager
3364 .refresh_server(&McpServerId::new("alpha"))
3365 .await
3366 .unwrap();
3367 assert_eq!(refreshed.tools[0].name, "echo_v2");
3368 assert_eq!(
3369 manager
3370 .connected_server(&McpServerId::new("alpha"))
3371 .unwrap()
3372 .snapshot()
3373 .tools[0]
3374 .name,
3375 "echo_v2"
3376 );
3377
3378 let capabilities = manager.capability_provider();
3379 assert_eq!(capabilities.invocables().len(), 2);
3380
3381 manager
3382 .disconnect_server(&McpServerId::new("alpha"))
3383 .await
3384 .unwrap();
3385 assert!(
3386 manager
3387 .connected_server(&McpServerId::new("alpha"))
3388 .is_none()
3389 );
3390 }
3391}