1use std::collections::{BTreeMap, VecDeque};
2use std::fmt;
3use std::process::Stdio;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::Duration;
7
8use agentkit_capabilities::{
9 CapabilityContext, CapabilityError, CapabilityName, CapabilityProvider, Invocable,
10 InvocableOutput, InvocableRequest, InvocableResult, InvocableSpec, PromptContents,
11 PromptDescriptor, PromptId, PromptProvider, ResourceContents, ResourceDescriptor, ResourceId,
12 ResourceProvider,
13};
14use agentkit_core::{
15 DataRef, Item, ItemKind, MetadataMap, Part, TextPart, ToolOutput, ToolResultPart,
16};
17use agentkit_http::{
18 HeaderMap, Http, HttpError, HttpRequestBuilder, HttpResponse, StatusCode, header as http_header,
19};
20use agentkit_tools_core::{
21 AuthOperation, AuthRequest, AuthResolution, Tool, ToolAnnotations, ToolContext, ToolError,
22 ToolName, ToolRegistry, ToolRequest, ToolResult, ToolSpec,
23};
24use async_trait::async_trait;
25use futures_util::TryStreamExt;
26use serde::{Deserialize, Serialize};
27use serde_json::{Value, json};
28use thiserror::Error;
29use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWriteExt, BufReader};
30use tokio::process::{Child, ChildStdin, ChildStdout, Command};
31use tokio::sync::{Mutex, mpsc, oneshot};
32use tokio::task::JoinHandle;
33use tokio::time::sleep;
34use tokio_util::io::StreamReader;
35use url::Url;
36
/// Protocol version this client proposes in its `initialize` request.
const MCP_LATEST_PROTOCOL_VERSION: &str = "2025-11-25";
/// Protocol versions accepted when the server answers `initialize` with a
/// negotiated `protocolVersion`; any other value aborts the connection.
/// Listed newest first.
const MCP_SUPPORTED_PROTOCOL_VERSIONS: &[&str] =
    &["2025-11-25", "2025-06-18", "2025-03-26", "2024-11-05"];
40
41#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
56pub struct McpServerId(pub String);
57
58impl McpServerId {
59 pub fn new(value: impl Into<String>) -> Self {
61 Self(value.into())
62 }
63}
64
65impl fmt::Display for McpServerId {
66 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67 self.0.fmt(f)
68 }
69}
70
/// Launch settings for an MCP server reached through a child process's
/// stdin/stdout.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StdioTransportConfig {
    /// Program to execute.
    pub command: String,
    /// Arguments passed to `command`, in order.
    pub args: Vec<String>,
    /// Extra environment variables for the child, applied in order.
    pub env: Vec<(String, String)>,
    /// Working directory for the child; `None` leaves it unchanged.
    pub cwd: Option<std::path::PathBuf>,
}

impl StdioTransportConfig {
    /// Creates a config for `command` with no arguments, no extra environment
    /// and no working-directory override.
    pub fn new(command: impl Into<String>) -> Self {
        let command: String = command.into();
        Self {
            command,
            cwd: None,
            env: vec![],
            args: vec![],
        }
    }

    /// Appends one command-line argument.
    pub fn with_arg(mut self, arg: impl Into<String>) -> Self {
        let arg: String = arg.into();
        self.args.push(arg);
        self
    }

    /// Appends one environment variable assignment.
    pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let pair = (key.into(), value.into());
        self.env.push(pair);
        self
    }

    /// Sets the child's working directory.
    pub fn with_cwd(mut self, cwd: impl Into<std::path::PathBuf>) -> Self {
        let dir: std::path::PathBuf = cwd.into();
        self.cwd = Some(dir);
        self
    }
}
129
130#[derive(Clone, Debug)]
141pub struct SseTransportConfig {
142 pub url: String,
144 pub client: Option<Http>,
147}
148
149impl SseTransportConfig {
150 pub fn new(url: impl Into<String>) -> Self {
152 Self {
153 url: url.into(),
154 client: None,
155 }
156 }
157
158 pub fn with_client(mut self, client: Http) -> Self {
162 self.client = Some(client);
163 self
164 }
165}
166
167#[derive(Clone, Debug)]
176pub struct StreamableHttpTransportConfig {
177 pub url: String,
179 pub client: Option<Http>,
182}
183
184impl StreamableHttpTransportConfig {
185 pub fn new(url: impl Into<String>) -> Self {
187 Self {
188 url: url.into(),
189 client: None,
190 }
191 }
192
193 pub fn with_client(mut self, client: Http) -> Self {
196 self.client = Some(client);
197 self
198 }
199}
200
/// Selects how a connection reaches an MCP server.
///
/// When a connection is opened, each built-in variant is turned into the
/// matching transport factory. `Custom` supplies a caller-provided factory
/// that is used as-is.
#[derive(Clone)]
pub enum McpTransportBinding {
    /// Spawn a local process and exchange newline-delimited JSON over its stdio.
    Stdio(StdioTransportConfig),
    /// POST JSON-RPC messages to a single Streamable HTTP endpoint.
    StreamableHttp(StreamableHttpTransportConfig),
    /// HTTP+SSE: a `GET` event stream plus POSTs to the endpoint it advertises.
    Sse(SseTransportConfig),
    /// Caller-provided transport factory.
    Custom(Arc<dyn McpTransportFactory>),
}
218
219#[derive(Clone)]
240pub struct McpServerConfig {
241 pub id: McpServerId,
243 pub transport: McpTransportBinding,
245 pub metadata: MetadataMap,
247}
248
249impl McpServerConfig {
250 pub fn new(id: impl Into<String>, transport: McpTransportBinding) -> Self {
257 Self {
258 id: McpServerId::new(id),
259 transport,
260 metadata: MetadataMap::new(),
261 }
262 }
263
264 pub fn stdio(id: impl Into<String>, command: impl Into<String>) -> Self {
266 Self::new(
267 id,
268 McpTransportBinding::Stdio(StdioTransportConfig::new(command)),
269 )
270 }
271
272 pub fn sse(id: impl Into<String>, url: impl Into<String>) -> Self {
274 Self::new(id, McpTransportBinding::Sse(SseTransportConfig::new(url)))
275 }
276
277 pub fn streamable_http(id: impl Into<String>, url: impl Into<String>) -> Self {
279 Self::new(
280 id,
281 McpTransportBinding::StreamableHttp(StreamableHttpTransportConfig::new(url)),
282 )
283 }
284
285 pub fn with_metadata(mut self, metadata: MetadataMap) -> Self {
287 self.metadata = metadata;
288 self
289 }
290}
291
/// A single JSON-RPC message handed to or received from a transport.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct McpFrame {
    /// The whole JSON-RPC message object, e.g.
    /// `{"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {}}`.
    pub value: Value,
}
301
/// Opens new transports to an MCP server.
///
/// A connection calls `connect` once per connection attempt and then runs
/// the MCP `initialize` handshake over the returned transport.
#[async_trait]
pub trait McpTransportFactory: Send + Sync {
    /// Establishes a fresh transport (spawns a process, opens a stream, ...).
    async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError>;
}
317
/// A bidirectional JSON-RPC message channel to one MCP server.
#[async_trait]
pub trait McpTransport: Send + Sync {
    /// Delivers one outgoing message.
    async fn send(&mut self, message: McpFrame) -> Result<(), McpError>;
    /// Returns the next incoming message. `Ok(None)` means no further message
    /// is available; `McpConnection` treats that as the transport having closed.
    async fn recv(&mut self) -> Result<Option<McpFrame>, McpError>;
    /// Releases the transport's resources (child process, stream task or
    /// HTTP session, depending on the implementation).
    async fn close(&mut self) -> Result<(), McpError>;
}
335
336pub struct StdioTransportFactory {
341 config: StdioTransportConfig,
342}
343
344impl StdioTransportFactory {
345 pub fn new(config: StdioTransportConfig) -> Self {
347 Self { config }
348 }
349}
350
351#[async_trait]
352impl McpTransportFactory for StdioTransportFactory {
353 async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
354 let mut command = Command::new(&self.config.command);
355 command.args(&self.config.args);
356 command.stdin(Stdio::piped());
357 command.stdout(Stdio::piped());
358 command.stderr(Stdio::inherit());
359
360 if let Some(cwd) = &self.config.cwd {
361 command.current_dir(cwd);
362 }
363
364 for (key, value) in &self.config.env {
365 command.env(key, value);
366 }
367
368 let mut child = command.spawn().map_err(McpError::Io)?;
369 let stdin = child
370 .stdin
371 .take()
372 .ok_or_else(|| McpError::Transport("failed to capture MCP stdin".into()))?;
373 let stdout = child
374 .stdout
375 .take()
376 .ok_or_else(|| McpError::Transport("failed to capture MCP stdout".into()))?;
377
378 Ok(Box::new(StdioTransport {
379 child,
380 stdin,
381 stdout: BufReader::new(stdout),
382 }))
383 }
384}
385
386pub struct SseTransportFactory {
391 config: SseTransportConfig,
392}
393
394impl SseTransportFactory {
395 pub fn new(config: SseTransportConfig) -> Self {
397 Self { config }
398 }
399}
400
401pub struct StreamableHttpTransportFactory {
406 config: StreamableHttpTransportConfig,
407}
408
409impl StreamableHttpTransportFactory {
410 pub fn new(config: StreamableHttpTransportConfig) -> Self {
412 Self { config }
413 }
414}
415
416#[async_trait]
417impl McpTransportFactory for SseTransportFactory {
418 async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
419 let client = resolve_http_client(self.config.client.as_ref())?;
420
421 let response = client
422 .get(self.config.url.as_str())
423 .header("Accept", "text/event-stream")
424 .header("Cache-Control", "no-cache")
425 .send()
426 .await?;
427
428 let status = response.status();
429 if !status.is_success() {
430 let body = response
431 .text()
432 .await
433 .unwrap_or_else(|_| "<unreadable response body>".into());
434 return Err(McpError::Transport(format!(
435 "SSE connection failed with status {status}: {body}"
436 )));
437 }
438
439 let response_url = Url::parse(response.url())
440 .map_err(|error| McpError::Transport(format!("invalid SSE response URL: {error}")))?;
441 let stream = response.bytes_stream().map_err(std::io::Error::other);
442 let reader = BufReader::new(StreamReader::new(stream));
443 let (frame_tx, frame_rx) = mpsc::unbounded_channel();
444 let (endpoint_tx, endpoint_rx) = oneshot::channel();
445 let read_task = tokio::spawn(read_sse_stream(reader, response_url, frame_tx, endpoint_tx));
446
447 let endpoint_url = endpoint_rx
448 .await
449 .map_err(|_| McpError::Transport("SSE stream closed before endpoint event".into()))??;
450
451 Ok(Box::new(SseTransport {
452 client,
453 endpoint_url,
454 frame_rx,
455 read_task,
456 }))
457 }
458}
459
460#[async_trait]
461impl McpTransportFactory for StreamableHttpTransportFactory {
462 async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
463 let client = resolve_http_client(self.config.client.as_ref())?;
464
465 let endpoint_url = Url::parse(&self.config.url)
466 .map_err(|error| McpError::Transport(format!("invalid MCP endpoint URL: {error}")))?;
467
468 Ok(Box::new(StreamableHttpTransport {
469 client,
470 endpoint_url,
471 protocol_version: None,
472 session_id: None,
473 pending_frames: VecDeque::new(),
474 }))
475 }
476}
477
478fn resolve_http_client(configured: Option<&Http>) -> Result<Http, McpError> {
479 if let Some(client) = configured {
480 return Ok(client.clone());
481 }
482 #[cfg(feature = "reqwest-client")]
483 {
484 reqwest::Client::builder()
485 .user_agent(concat!("agentkit-mcp/", env!("CARGO_PKG_VERSION")))
486 .build()
487 .map(Http::new)
488 .map_err(|error| McpError::Http(HttpError::request(error)))
489 }
490 #[cfg(not(feature = "reqwest-client"))]
491 {
492 Err(McpError::Transport(
493 "no HTTP client configured; enable the `reqwest-client` feature or supply one via `with_client`".into(),
494 ))
495 }
496}
497
/// Transport over a child process: JSON lines on stdin/stdout.
struct StdioTransport {
    /// The spawned server process (killed by `close`).
    child: Child,
    /// Pipe to the server's stdin; outgoing frames are written here.
    stdin: ChildStdin,
    /// Buffered pipe from the server's stdout; frames are read line by line.
    stdout: BufReader<ChildStdout>,
}

/// HTTP+SSE transport: frames arrive on an event stream read by a background
/// task and are sent as POSTs to the advertised endpoint.
struct SseTransport {
    /// Client used for the POST requests.
    client: Http,
    /// Endpoint announced by the server's `endpoint` event; messages are POSTed here.
    endpoint_url: Url,
    /// Frames (or stream errors) forwarded by `read_task`.
    frame_rx: mpsc::UnboundedReceiver<Result<McpFrame, McpError>>,
    /// Background reader of the SSE stream; aborted by `close`.
    read_task: JoinHandle<()>,
}

/// Streamable HTTP transport: each message is a POST, and replies are buffered
/// until `recv` drains them.
struct StreamableHttpTransport {
    /// Client used for POST/GET/DELETE requests.
    client: Http,
    /// The single MCP endpoint.
    endpoint_url: Url,
    /// Protocol version from the `initialize` result; sent on later requests.
    protocol_version: Option<String>,
    /// Session id from the `MCP-Session-Id` header of the `initialize` reply.
    session_id: Option<String>,
    /// Frames received in HTTP replies, returned in order by `recv`.
    pending_frames: VecDeque<McpFrame>,
}
518
519#[async_trait]
520impl McpTransport for StdioTransport {
521 async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
522 let mut encoded = serde_json::to_vec(&message.value).map_err(McpError::Serialize)?;
523 encoded.push(b'\n');
524 self.stdin.write_all(&encoded).await.map_err(McpError::Io)?;
525 self.stdin.flush().await.map_err(McpError::Io)?;
526 Ok(())
527 }
528
529 async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
530 let mut line = String::new();
531 let read = self
532 .stdout
533 .read_line(&mut line)
534 .await
535 .map_err(McpError::Io)?;
536 if read == 0 {
537 return Ok(None);
538 }
539
540 let value = serde_json::from_str(line.trim()).map_err(McpError::Serialize)?;
541 Ok(Some(McpFrame { value }))
542 }
543
544 async fn close(&mut self) -> Result<(), McpError> {
545 let _ = self.stdin.shutdown().await;
546 let _ = self.child.kill().await;
547 Ok(())
548 }
549}
550
551#[async_trait]
552impl McpTransport for SseTransport {
553 async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
554 let response = self
555 .client
556 .post(self.endpoint_url.as_str())
557 .header("Content-Type", "application/json")
558 .json(&message.value)
559 .send()
560 .await?;
561 let status = response.status();
562 if !status.is_success() {
563 let body = response
564 .text()
565 .await
566 .unwrap_or_else(|_| "<unreadable response body>".into());
567 return Err(McpError::Transport(format!(
568 "SSE POST failed with status {status}: {body}"
569 )));
570 }
571
572 Ok(())
573 }
574
575 async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
576 match self.frame_rx.recv().await {
577 Some(Ok(frame)) => Ok(Some(frame)),
578 Some(Err(error)) => Err(error),
579 None => Ok(None),
580 }
581 }
582
583 async fn close(&mut self) -> Result<(), McpError> {
584 self.read_task.abort();
585 Ok(())
586 }
587}
588
589#[async_trait]
590impl McpTransport for StreamableHttpTransport {
591 async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
592 let is_request = is_jsonrpc_request(&message.value);
593 let request_id = message.value.get("id").cloned();
594 let is_initialize =
595 message.value.get("method").and_then(Value::as_str) == Some("initialize");
596
597 let mut request = self
598 .client
599 .post(self.endpoint_url.as_str())
600 .header("Content-Type", "application/json")
601 .header("Accept", "application/json, text/event-stream");
602
603 request = apply_streamable_http_headers(
604 request,
605 self.protocol_version.as_deref(),
606 self.session_id.as_deref(),
607 );
608
609 let response = request.json(&message.value).send().await?;
610
611 if is_initialize {
612 self.capture_session_id(response.headers());
613 }
614
615 let status = response.status();
616 if !status.is_success() {
617 return Err(
618 streamable_http_status_error("Streamable HTTP POST", status, response).await,
619 );
620 }
621
622 if !is_request {
623 return Ok(());
624 }
625
626 let content_type = response
627 .headers()
628 .get(http_header::CONTENT_TYPE)
629 .and_then(|value| value.to_str().ok())
630 .unwrap_or_default()
631 .to_string();
632
633 if content_type.starts_with("application/json") {
634 let value: Value = response.json().await?;
635 self.maybe_update_protocol_version(&message.value, &value)?;
636 self.pending_frames.push_back(McpFrame { value });
637 return Ok(());
638 }
639
640 if !content_type.starts_with("text/event-stream") {
641 let body = response
642 .text()
643 .await
644 .unwrap_or_else(|_| "<unreadable response body>".into());
645 return Err(McpError::Transport(format!(
646 "unexpected Streamable HTTP response content type {content_type:?}: {body}"
647 )));
648 }
649
650 let request_id = request_id.ok_or_else(|| {
651 McpError::Protocol("JSON-RPC request over Streamable HTTP is missing an id".into())
652 })?;
653 self.collect_streamable_http_response(response, &message.value, &request_id)
654 .await
655 }
656
657 async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
658 Ok(self.pending_frames.pop_front())
659 }
660
661 async fn close(&mut self) -> Result<(), McpError> {
662 let Some(session_id) = self.session_id.clone() else {
663 return Ok(());
664 };
665
666 let mut request = self.client.delete(self.endpoint_url.as_str());
667 request = apply_streamable_http_headers(
668 request,
669 self.protocol_version.as_deref(),
670 Some(session_id.as_str()),
671 );
672
673 let response = request.send().await?;
674 if response.status().is_success()
675 || response.status() == StatusCode::METHOD_NOT_ALLOWED
676 || response.status() == StatusCode::NOT_FOUND
677 {
678 self.session_id = None;
679 return Ok(());
680 }
681
682 Err(
683 streamable_http_status_error("Streamable HTTP DELETE", response.status(), response)
684 .await,
685 )
686 }
687}
688
689impl StreamableHttpTransport {
690 async fn collect_streamable_http_response(
691 &mut self,
692 response: HttpResponse,
693 request_message: &Value,
694 request_id: &Value,
695 ) -> Result<(), McpError> {
696 let mut retry_delay = Duration::from_millis(0);
697 let mut last_event_id = None;
698 let mut saw_response = false;
699
700 saw_response |= self
701 .read_streamable_http_events(
702 response,
703 request_message,
704 request_id,
705 &mut last_event_id,
706 &mut retry_delay,
707 )
708 .await?;
709
710 while !saw_response && last_event_id.is_some() {
711 if !retry_delay.is_zero() {
712 sleep(retry_delay).await;
713 }
714
715 let response = self
716 .resume_streamable_http_stream(last_event_id.as_deref().unwrap())
717 .await?;
718 saw_response |= self
719 .read_streamable_http_events(
720 response,
721 request_message,
722 request_id,
723 &mut last_event_id,
724 &mut retry_delay,
725 )
726 .await?;
727 }
728
729 Ok(())
730 }
731
732 async fn read_streamable_http_events(
733 &mut self,
734 response: HttpResponse,
735 request_message: &Value,
736 request_id: &Value,
737 last_event_id: &mut Option<String>,
738 retry_delay: &mut Duration,
739 ) -> Result<bool, McpError> {
740 let stream = response.bytes_stream().map_err(std::io::Error::other);
741 let mut reader = BufReader::new(StreamReader::new(stream));
742 let mut saw_response = false;
743
744 while let Some(event) = read_next_sse_event(&mut reader).await? {
745 if let Some(id) = event.id.clone() {
746 *last_event_id = Some(id);
747 }
748 if let Some(retry_ms) = event.retry_ms {
749 *retry_delay = Duration::from_millis(retry_ms);
750 }
751
752 let Some(frame) = streamable_http_event_to_frame(event)? else {
753 continue;
754 };
755
756 self.maybe_update_protocol_version(request_message, &frame.value)?;
757 if frame.value.get("id") == Some(request_id) {
758 saw_response = true;
759 }
760 self.pending_frames.push_back(frame);
761 }
762
763 Ok(saw_response)
764 }
765
766 async fn resume_streamable_http_stream(
767 &self,
768 last_event_id: &str,
769 ) -> Result<HttpResponse, McpError> {
770 let mut request = self
771 .client
772 .get(self.endpoint_url.as_str())
773 .header("Accept", "text/event-stream")
774 .header("Cache-Control", "no-cache")
775 .header("Last-Event-ID", last_event_id);
776
777 request = apply_streamable_http_headers(
778 request,
779 self.protocol_version.as_deref(),
780 self.session_id.as_deref(),
781 );
782
783 let response = request.send().await?;
784 let status = response.status();
785 if !status.is_success() {
786 return Err(
787 streamable_http_status_error("Streamable HTTP GET", status, response).await,
788 );
789 }
790
791 let content_type = response
792 .headers()
793 .get(http_header::CONTENT_TYPE)
794 .and_then(|value| value.to_str().ok())
795 .unwrap_or_default();
796 if !content_type.starts_with("text/event-stream") {
797 let content_type = content_type.to_string();
798 let body = response
799 .text()
800 .await
801 .unwrap_or_else(|_| "<unreadable response body>".into());
802 return Err(McpError::Transport(format!(
803 "Streamable HTTP GET expected text/event-stream, got {content_type:?}: {body}"
804 )));
805 }
806
807 Ok(response)
808 }
809
810 fn maybe_update_protocol_version(
811 &mut self,
812 request_message: &Value,
813 response_value: &Value,
814 ) -> Result<(), McpError> {
815 if request_message.get("method").and_then(Value::as_str) != Some("initialize") {
816 return Ok(());
817 }
818
819 let protocol_version = response_value
820 .get("result")
821 .and_then(|result| result.get("protocolVersion"))
822 .and_then(Value::as_str);
823
824 if let Some(protocol_version) = protocol_version {
825 self.protocol_version = Some(protocol_version.to_string());
826 }
827
828 Ok(())
829 }
830
831 fn capture_session_id(&mut self, headers: &HeaderMap) {
832 self.session_id = headers
833 .get("MCP-Session-Id")
834 .and_then(|value| value.to_str().ok())
835 .map(|value| value.to_string());
836 }
837}
838
/// A tool advertised by an MCP server in its `tools/list` reply.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct McpToolDescriptor {
    /// Tool name as known to the server; sent as `name` in `tools/call`.
    pub name: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Schema of the tool's `arguments`, as reported by the server
    /// (presumably JSON Schema — confirm against `parse_tool_descriptor`).
    pub input_schema: Value,
    /// Additional metadata carried with the descriptor.
    pub metadata: MetadataMap,
}
855
/// A resource advertised by an MCP server in its `resources/list` reply.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct McpResourceDescriptor {
    /// Resource identifier; it is passed as `uri` to `resources/read` when the
    /// resource is read through `McpResourceHandle`.
    pub id: String,
    /// Display name.
    pub name: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Optional MIME type reported by the server.
    pub mime_type: Option<String>,
    /// Additional metadata carried with the descriptor.
    pub metadata: MetadataMap,
}
873
/// A prompt advertised by an MCP server in its `prompts/list` reply.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct McpPromptDescriptor {
    /// Prompt identifier (presumably the name sent as `name` in
    /// `prompts/get` — confirm against `parse_prompt_descriptor`).
    pub id: String,
    /// Display name.
    pub name: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Description of the prompt's arguments as built by the parser
    /// (NOTE(review): exact shape not visible here — confirm).
    pub input_schema: Value,
    /// Additional metadata carried with the descriptor.
    pub metadata: MetadataMap,
}
891
/// Everything discovered on one server, as returned by `McpConnection::discover`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct McpDiscoverySnapshot {
    /// Server the snapshot was taken from.
    pub server_id: McpServerId,
    /// Tools from `tools/list`.
    pub tools: Vec<McpToolDescriptor>,
    /// Resources from `resources/list`.
    pub resources: Vec<McpResourceDescriptor>,
    /// Prompts from `prompts/list`.
    pub prompts: Vec<McpPromptDescriptor>,
    /// Extra metadata (`discover` leaves this empty).
    pub metadata: MetadataMap,
}
910
/// An initialized session with one MCP server.
pub struct McpConnection {
    /// Server this connection belongs to.
    server_id: McpServerId,
    /// The transport, behind an async mutex. `request` holds the lock from
    /// sending a request until its response arrives, so requests on one
    /// connection are serialized.
    transport: Mutex<Box<dyn McpTransport>>,
    /// Credentials added to request params under `auth` (see `enrich_params`);
    /// replaced by `resolve_auth`.
    auth: Mutex<Option<MetadataMap>>,
    /// Next JSON-RPC request id. `initialize` uses id 0, so this starts at 1.
    next_id: AtomicU64,
}
946
/// Outcome of an MCP operation, such as one replayed by
/// `McpConnection::replay_auth_operation`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum McpOperationResult {
    /// A connection was established (presumably produced by the server
    /// manager after replaying a connect — not constructed in this file section).
    Connected(McpDiscoverySnapshot),
    /// Raw `result` of a `tools/call`.
    Tool(Value),
    /// Contents returned by `resources/read`.
    Resource(ResourceContents),
    /// Messages returned by `prompts/get`.
    Prompt(PromptContents),
}
962
963impl McpConnection {
964 pub async fn connect(config: &McpServerConfig) -> Result<Self, McpError> {
972 Self::connect_with_auth(config, None).await
973 }
974
975 async fn connect_with_auth(
976 config: &McpServerConfig,
977 auth: Option<&MetadataMap>,
978 ) -> Result<Self, McpError> {
979 let factory: Arc<dyn McpTransportFactory> = match &config.transport {
980 McpTransportBinding::Stdio(binding) => {
981 Arc::new(StdioTransportFactory::new(binding.clone()))
982 }
983 McpTransportBinding::StreamableHttp(binding) => {
984 Arc::new(StreamableHttpTransportFactory::new(binding.clone()))
985 }
986 McpTransportBinding::Sse(binding) => {
987 Arc::new(SseTransportFactory::new(binding.clone()))
988 }
989 McpTransportBinding::Custom(factory) => factory.clone(),
990 };
991
992 let mut transport = factory.connect().await?;
993 let mut params = serde_json::Map::new();
994 params.insert(
995 "protocolVersion".into(),
996 Value::String(MCP_LATEST_PROTOCOL_VERSION.into()),
997 );
998 params.insert("capabilities".into(), json!({}));
999 params.insert(
1000 "clientInfo".into(),
1001 json!({
1002 "name": "agentkit-mcp",
1003 "version": env!("CARGO_PKG_VERSION")
1004 }),
1005 );
1006 if let Some(auth) = auth {
1007 params.insert("auth".into(), metadata_to_value(auth));
1008 }
1009 let init_params = Value::Object(params.clone());
1010 transport
1011 .send(McpFrame {
1012 value: json!({
1013 "jsonrpc": "2.0",
1014 "id": 0,
1015 "method": "initialize",
1016 "params": init_params.clone()
1017 }),
1018 })
1019 .await?;
1020 let init_response = transport.recv().await?.ok_or_else(|| {
1021 McpError::Transport("transport closed during MCP initialization".into())
1022 })?;
1023 if let Some(error) = init_response.value.get("error") {
1024 if let Some(auth_request) =
1025 parse_auth_request(&config.id, "initialize", &init_params, error)
1026 {
1027 return Err(McpError::AuthRequired(Box::new(auth_request)));
1028 }
1029 return Err(McpError::Invocation(error.to_string()));
1030 }
1031 let negotiated_protocol_version = init_response
1032 .value
1033 .get("result")
1034 .and_then(|result| result.get("protocolVersion"))
1035 .and_then(Value::as_str)
1036 .ok_or_else(|| {
1037 McpError::Protocol("initialize response missing result.protocolVersion".into())
1038 })?;
1039 if !MCP_SUPPORTED_PROTOCOL_VERSIONS.contains(&negotiated_protocol_version) {
1040 return Err(McpError::Protocol(format!(
1041 "unsupported MCP protocol version negotiated during initialize: {negotiated_protocol_version}"
1042 )));
1043 }
1044 transport
1045 .send(McpFrame {
1046 value: json!({
1047 "jsonrpc": "2.0",
1048 "method": "notifications/initialized",
1049 "params": {}
1050 }),
1051 })
1052 .await?;
1053
1054 Ok(Self {
1055 server_id: config.id.clone(),
1056 transport: Mutex::new(transport),
1057 auth: Mutex::new(auth.cloned()),
1058 next_id: AtomicU64::new(1),
1059 })
1060 }
1061
1062 pub fn server_id(&self) -> &McpServerId {
1064 &self.server_id
1065 }
1066
1067 pub async fn close(&self) -> Result<(), McpError> {
1073 let mut transport = self.transport.lock().await;
1074 transport.close().await
1075 }
1076
1077 pub async fn resolve_auth(&self, resolution: AuthResolution) -> Result<(), McpError> {
1087 let mut auth = self.auth.lock().await;
1088 match resolution {
1089 AuthResolution::Provided { credentials, .. } => {
1090 *auth = Some(credentials);
1091 }
1092 AuthResolution::Cancelled { .. } => {
1093 *auth = None;
1094 }
1095 }
1096 Ok(())
1097 }
1098
1099 pub async fn discover(&self) -> Result<McpDiscoverySnapshot, McpError> {
1107 Ok(McpDiscoverySnapshot {
1108 server_id: self.server_id.clone(),
1109 tools: self.list_tools().await?,
1110 resources: self.list_resources().await?,
1111 prompts: self.list_prompts().await?,
1112 metadata: MetadataMap::new(),
1113 })
1114 }
1115
1116 pub async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>, McpError> {
1122 let result = self.request("tools/list", json!({})).await?;
1123 result
1124 .get("tools")
1125 .and_then(Value::as_array)
1126 .cloned()
1127 .unwrap_or_default()
1128 .into_iter()
1129 .map(parse_tool_descriptor)
1130 .collect()
1131 }
1132
1133 pub async fn list_resources(&self) -> Result<Vec<McpResourceDescriptor>, McpError> {
1139 let result = self.request("resources/list", json!({})).await?;
1140 result
1141 .get("resources")
1142 .and_then(Value::as_array)
1143 .cloned()
1144 .unwrap_or_default()
1145 .into_iter()
1146 .map(parse_resource_descriptor)
1147 .collect()
1148 }
1149
1150 pub async fn list_prompts(&self) -> Result<Vec<McpPromptDescriptor>, McpError> {
1156 let result = self.request("prompts/list", json!({})).await?;
1157 result
1158 .get("prompts")
1159 .and_then(Value::as_array)
1160 .cloned()
1161 .unwrap_or_default()
1162 .into_iter()
1163 .map(parse_prompt_descriptor)
1164 .collect()
1165 }
1166
1167 pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<Value, McpError> {
1179 self.request(
1180 "tools/call",
1181 json!({
1182 "name": name,
1183 "arguments": arguments,
1184 }),
1185 )
1186 .await
1187 }
1188
1189 pub async fn read_resource(&self, uri: &str) -> Result<ResourceContents, McpError> {
1199 let result = self
1200 .request(
1201 "resources/read",
1202 json!({
1203 "uri": uri,
1204 }),
1205 )
1206 .await?;
1207 let content = result
1208 .get("contents")
1209 .and_then(Value::as_array)
1210 .and_then(|values| values.first())
1211 .cloned()
1212 .ok_or_else(|| McpError::Protocol("resources/read returned no contents".into()))?;
1213
1214 let data = if let Some(text) = content.get("text").and_then(Value::as_str) {
1215 DataRef::InlineText(text.into())
1216 } else if let Some(found_uri) = content.get("uri").and_then(Value::as_str) {
1217 DataRef::Uri(found_uri.into())
1218 } else {
1219 return Err(McpError::Protocol(
1220 "unsupported resource content shape".into(),
1221 ));
1222 };
1223
1224 Ok(ResourceContents {
1225 data,
1226 metadata: MetadataMap::new(),
1227 })
1228 }
1229
1230 pub async fn get_prompt(
1241 &self,
1242 name: &str,
1243 arguments: Value,
1244 ) -> Result<PromptContents, McpError> {
1245 let result = self
1246 .request(
1247 "prompts/get",
1248 json!({
1249 "name": name,
1250 "arguments": arguments,
1251 }),
1252 )
1253 .await?;
1254 let items = result
1255 .get("messages")
1256 .and_then(Value::as_array)
1257 .cloned()
1258 .unwrap_or_default()
1259 .into_iter()
1260 .map(parse_prompt_message)
1261 .collect::<Result<Vec<_>, _>>()?;
1262
1263 Ok(PromptContents {
1264 items,
1265 metadata: MetadataMap::new(),
1266 })
1267 }
1268
1269 async fn request(&self, method: &str, params: Value) -> Result<Value, McpError> {
1270 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
1271 let params = self.enrich_params(params.clone()).await;
1272 let mut transport = self.transport.lock().await;
1273 transport
1274 .send(McpFrame {
1275 value: json!({
1276 "jsonrpc": "2.0",
1277 "id": id,
1278 "method": method,
1279 "params": params,
1280 }),
1281 })
1282 .await?;
1283
1284 loop {
1285 let Some(frame) = transport.recv().await? else {
1286 return Err(McpError::Transport(
1287 "transport closed while waiting for MCP response".into(),
1288 ));
1289 };
1290
1291 if frame.value.get("id").and_then(Value::as_u64) != Some(id) {
1292 continue;
1293 }
1294
1295 if let Some(error) = frame.value.get("error") {
1296 if let Some(auth_request) =
1297 parse_auth_request(&self.server_id, method, ¶ms, error)
1298 {
1299 return Err(McpError::AuthRequired(Box::new(auth_request)));
1300 }
1301 return Err(McpError::Invocation(error.to_string()));
1302 }
1303
1304 return frame
1305 .value
1306 .get("result")
1307 .cloned()
1308 .ok_or_else(|| McpError::Protocol("MCP response missing result".into()));
1309 }
1310 }
1311
1312 async fn enrich_params(&self, params: Value) -> Value {
1313 let auth = self.auth.lock().await;
1314 let Some(auth) = auth.as_ref() else {
1315 return params;
1316 };
1317
1318 match params {
1319 Value::Object(mut object) => {
1320 object
1321 .entry("auth")
1322 .or_insert_with(|| metadata_to_value(auth));
1323 Value::Object(object)
1324 }
1325 other => other,
1326 }
1327 }
1328
1329 pub async fn replay_auth_operation(
1339 &self,
1340 operation: &AuthOperation,
1341 ) -> Result<McpOperationResult, McpError> {
1342 match operation {
1343 AuthOperation::McpToolCall {
1344 server_id,
1345 tool_name,
1346 input,
1347 ..
1348 } => {
1349 self.ensure_server_match(server_id)?;
1350 self.call_tool(tool_name, input.clone())
1351 .await
1352 .map(McpOperationResult::Tool)
1353 }
1354 AuthOperation::McpResourceRead {
1355 server_id,
1356 resource_id,
1357 ..
1358 } => {
1359 self.ensure_server_match(server_id)?;
1360 self.read_resource(resource_id)
1361 .await
1362 .map(McpOperationResult::Resource)
1363 }
1364 AuthOperation::McpPromptGet {
1365 server_id,
1366 prompt_id,
1367 args,
1368 ..
1369 } => {
1370 self.ensure_server_match(server_id)?;
1371 self.get_prompt(prompt_id, args.clone())
1372 .await
1373 .map(McpOperationResult::Prompt)
1374 }
1375 AuthOperation::ToolCall {
1376 tool_name,
1377 input,
1378 metadata,
1379 ..
1380 } => {
1381 if let Some(server_id) = metadata.get("server_id").and_then(Value::as_str) {
1382 self.ensure_server_match(server_id)?;
1383 }
1384 let tool_name = normalize_mcp_tool_name(self.server_id(), tool_name);
1385 self.call_tool(&tool_name, input.clone())
1386 .await
1387 .map(McpOperationResult::Tool)
1388 }
1389 AuthOperation::McpConnect { .. } => Err(McpError::AuthResolution(
1390 "connect operations must be replayed through the server manager".into(),
1391 )),
1392 AuthOperation::Custom { kind, .. } => Err(McpError::AuthResolution(format!(
1393 "unsupported auth operation for replay: {kind}"
1394 ))),
1395 }
1396 }
1397
1398 fn ensure_server_match(&self, server_id: &str) -> Result<(), McpError> {
1399 if self.server_id.0 == server_id {
1400 Ok(())
1401 } else {
1402 Err(McpError::AuthResolution(format!(
1403 "auth operation targets server {server_id}, but connection is for {}",
1404 self.server_id
1405 )))
1406 }
1407 }
1408}
1409
1410pub struct McpInvocable {
1417 connection: Arc<McpConnection>,
1418 descriptor: McpToolDescriptor,
1419 spec: InvocableSpec,
1420}
1421
1422impl McpInvocable {
1423 pub fn new(connection: Arc<McpConnection>, descriptor: McpToolDescriptor) -> Self {
1430 let spec = InvocableSpec {
1431 name: CapabilityName::new(format!(
1432 "mcp_{}_{}",
1433 connection.server_id(),
1434 descriptor.name
1435 )),
1436 description: descriptor
1437 .description
1438 .clone()
1439 .unwrap_or_else(|| descriptor.name.clone()),
1440 input_schema: descriptor.input_schema.clone(),
1441 metadata: descriptor.metadata.clone(),
1442 };
1443
1444 Self {
1445 connection,
1446 descriptor,
1447 spec,
1448 }
1449 }
1450}
1451
1452#[async_trait]
1453impl Invocable for McpInvocable {
1454 fn spec(&self) -> &InvocableSpec {
1455 &self.spec
1456 }
1457
1458 async fn invoke(
1459 &self,
1460 request: InvocableRequest,
1461 _ctx: &mut CapabilityContext<'_>,
1462 ) -> Result<InvocableResult, CapabilityError> {
1463 let result = self
1464 .connection
1465 .call_tool(&self.descriptor.name, request.input)
1466 .await
1467 .map_err(|error| match error {
1468 McpError::AuthRequired(request) => {
1469 CapabilityError::Unavailable(format!("auth required: {:?}", request))
1470 }
1471 other => CapabilityError::ExecutionFailed(other.to_string()),
1472 })?;
1473
1474 Ok(InvocableResult {
1475 output: value_to_invocable_output(result),
1476 metadata: MetadataMap::new(),
1477 })
1478 }
1479}
1480
1481pub struct McpResourceHandle {
1486 connection: Arc<McpConnection>,
1487 descriptor: ResourceDescriptor,
1488}
1489
1490#[async_trait]
1491impl ResourceProvider for McpResourceHandle {
1492 async fn list_resources(&self) -> Result<Vec<ResourceDescriptor>, CapabilityError> {
1493 Ok(vec![self.descriptor.clone()])
1494 }
1495
1496 async fn read_resource(
1497 &self,
1498 id: &ResourceId,
1499 _ctx: &mut CapabilityContext<'_>,
1500 ) -> Result<ResourceContents, CapabilityError> {
1501 self.connection
1502 .read_resource(&id.0)
1503 .await
1504 .map_err(|error| match error {
1505 McpError::AuthRequired(request) => {
1506 CapabilityError::Unavailable(format!("auth required: {:?}", request))
1507 }
1508 other => CapabilityError::ExecutionFailed(other.to_string()),
1509 })
1510 }
1511}
1512
/// Exposes one discovered MCP prompt through the [`PromptProvider`] interface.
pub struct McpPromptHandle {
    /// Shared connection used to service `get_prompt` calls.
    connection: Arc<McpConnection>,
    /// The single descriptor reported by `list_prompts`.
    descriptor: PromptDescriptor,
}
1521
1522#[async_trait]
1523impl PromptProvider for McpPromptHandle {
1524 async fn list_prompts(&self) -> Result<Vec<PromptDescriptor>, CapabilityError> {
1525 Ok(vec![self.descriptor.clone()])
1526 }
1527
1528 async fn get_prompt(
1529 &self,
1530 id: &PromptId,
1531 args: Value,
1532 _ctx: &mut CapabilityContext<'_>,
1533 ) -> Result<PromptContents, CapabilityError> {
1534 self.connection
1535 .get_prompt(&id.0, args)
1536 .await
1537 .map_err(|error| match error {
1538 McpError::AuthRequired(request) => {
1539 CapabilityError::Unavailable(format!("auth required: {:?}", request))
1540 }
1541 other => CapabilityError::ExecutionFailed(other.to_string()),
1542 })
1543 }
1544}
1545
/// Aggregates MCP-backed invocables, resources, and prompts (from one or more
/// servers) behind the [`CapabilityProvider`] interface.
pub struct McpCapabilityProvider {
    /// One invocable per discovered tool.
    invocables: Vec<Arc<dyn Invocable>>,
    /// One provider per discovered resource.
    resources: Vec<Arc<dyn ResourceProvider>>,
    /// One provider per discovered prompt.
    prompts: Vec<Arc<dyn PromptProvider>>,
}
1577
1578impl McpCapabilityProvider {
1579 pub fn from_snapshot(connection: Arc<McpConnection>, snapshot: &McpDiscoverySnapshot) -> Self {
1585 let invocables = snapshot
1586 .tools
1587 .iter()
1588 .cloned()
1589 .map(|descriptor| {
1590 Arc::new(McpInvocable::new(connection.clone(), descriptor)) as Arc<dyn Invocable>
1591 })
1592 .collect();
1593
1594 let resources = snapshot
1595 .resources
1596 .iter()
1597 .cloned()
1598 .map(|descriptor| {
1599 Arc::new(McpResourceHandle {
1600 connection: connection.clone(),
1601 descriptor: ResourceDescriptor {
1602 id: ResourceId::new(descriptor.id),
1603 name: descriptor.name,
1604 description: descriptor.description,
1605 mime_type: descriptor.mime_type,
1606 metadata: descriptor.metadata,
1607 },
1608 }) as Arc<dyn ResourceProvider>
1609 })
1610 .collect();
1611
1612 let prompts = snapshot
1613 .prompts
1614 .iter()
1615 .cloned()
1616 .map(|descriptor| {
1617 Arc::new(McpPromptHandle {
1618 connection: connection.clone(),
1619 descriptor: PromptDescriptor {
1620 id: PromptId::new(descriptor.id),
1621 name: descriptor.name,
1622 description: descriptor.description,
1623 input_schema: descriptor.input_schema,
1624 metadata: descriptor.metadata,
1625 },
1626 }) as Arc<dyn PromptProvider>
1627 })
1628 .collect();
1629
1630 Self {
1631 invocables,
1632 resources,
1633 prompts,
1634 }
1635 }
1636
1637 pub fn merge<I>(providers: I) -> Self
1642 where
1643 I: IntoIterator<Item = Self>,
1644 {
1645 let mut invocables = Vec::new();
1646 let mut resources = Vec::new();
1647 let mut prompts = Vec::new();
1648
1649 for provider in providers {
1650 invocables.extend(provider.invocables);
1651 resources.extend(provider.resources);
1652 prompts.extend(provider.prompts);
1653 }
1654
1655 Self {
1656 invocables,
1657 resources,
1658 prompts,
1659 }
1660 }
1661
1662 pub async fn connect(
1671 config: &McpServerConfig,
1672 ) -> Result<(Arc<McpConnection>, Self, McpDiscoverySnapshot), McpError> {
1673 let connection = Arc::new(McpConnection::connect(config).await?);
1674 let snapshot = connection.discover().await?;
1675 let provider = Self::from_snapshot(connection.clone(), &snapshot);
1676
1677 Ok((connection, provider, snapshot))
1678 }
1679}
1680
1681impl CapabilityProvider for McpCapabilityProvider {
1682 fn invocables(&self) -> Vec<Arc<dyn Invocable>> {
1683 self.invocables.clone()
1684 }
1685
1686 fn resources(&self) -> Vec<Arc<dyn ResourceProvider>> {
1687 self.resources.clone()
1688 }
1689
1690 fn prompts(&self) -> Vec<Arc<dyn PromptProvider>> {
1691 self.prompts.clone()
1692 }
1693}
1694
/// A connected MCP server: its configuration, live connection, and most recent
/// discovery snapshot.
///
/// Cloning shares the connection (`Arc`) and clones the config and snapshot values.
#[derive(Clone)]
pub struct McpServerHandle {
    /// Configuration the connection was opened with.
    config: McpServerConfig,
    /// Live, shared connection to the server.
    connection: Arc<McpConnection>,
    /// Result of the latest discovery (updated by `McpServerManager::refresh_server`).
    snapshot: McpDiscoverySnapshot,
}
1706
1707impl McpServerHandle {
1708 pub fn config(&self) -> &McpServerConfig {
1710 &self.config
1711 }
1712
1713 pub fn server_id(&self) -> &McpServerId {
1715 self.connection.server_id()
1716 }
1717
1718 pub fn connection(&self) -> Arc<McpConnection> {
1720 self.connection.clone()
1721 }
1722
1723 pub fn snapshot(&self) -> &McpDiscoverySnapshot {
1725 &self.snapshot
1726 }
1727
1728 pub fn tool_registry(&self) -> ToolRegistry {
1731 self.snapshot
1732 .tools
1733 .iter()
1734 .cloned()
1735 .fold(ToolRegistry::new(), |registry, descriptor| {
1736 registry.with(McpToolAdapter::new(
1737 self.server_id(),
1738 self.connection.clone(),
1739 descriptor,
1740 ))
1741 })
1742 }
1743
1744 pub fn capability_provider(&self) -> McpCapabilityProvider {
1746 McpCapabilityProvider::from_snapshot(self.connection.clone(), &self.snapshot)
1747 }
1748}
1749
/// Tracks registered MCP server configurations, their live connections, and any
/// credentials supplied through auth resolution.
#[derive(Default)]
pub struct McpServerManager {
    /// Registered configurations, keyed by server id.
    configs: BTreeMap<McpServerId, McpServerConfig>,
    /// Currently connected servers, keyed by server id.
    connections: BTreeMap<McpServerId, McpServerHandle>,
    /// Credentials from resolved auth requests; passed to `connect_with_auth`.
    auth: BTreeMap<McpServerId, MetadataMap>,
}
1796
1797impl McpServerManager {
1798 pub fn new() -> Self {
1800 Self::default()
1801 }
1802
1803 pub fn with_server(mut self, config: McpServerConfig) -> Self {
1808 self.register_server(config);
1809 self
1810 }
1811
1812 pub fn register_server(&mut self, config: McpServerConfig) -> &mut Self {
1817 self.configs.insert(config.id.clone(), config);
1818 self
1819 }
1820
1821 pub fn connected_server(&self, server_id: &McpServerId) -> Option<&McpServerHandle> {
1823 self.connections.get(server_id)
1824 }
1825
1826 pub fn connected_servers(&self) -> Vec<&McpServerHandle> {
1828 self.connections.values().collect()
1829 }
1830
1831 pub async fn connect_server(
1840 &mut self,
1841 server_id: &McpServerId,
1842 ) -> Result<McpServerHandle, McpError> {
1843 let config = self
1844 .configs
1845 .get(server_id)
1846 .cloned()
1847 .ok_or_else(|| McpError::UnknownServer(server_id.to_string()))?;
1848 let connection =
1849 Arc::new(McpConnection::connect_with_auth(&config, self.auth.get(server_id)).await?);
1850 let snapshot = connection.discover().await?;
1851 let handle = McpServerHandle {
1852 config,
1853 connection,
1854 snapshot,
1855 };
1856 self.connections.insert(server_id.clone(), handle.clone());
1857 Ok(handle)
1858 }
1859
1860 pub async fn connect_all(&mut self) -> Result<Vec<McpServerHandle>, McpError> {
1870 let server_ids = self.configs.keys().cloned().collect::<Vec<_>>();
1871 let mut handles = Vec::with_capacity(server_ids.len());
1872
1873 for server_id in server_ids {
1874 handles.push(self.connect_server(&server_id).await?);
1875 }
1876
1877 Ok(handles)
1878 }
1879
1880 pub async fn refresh_server(
1890 &mut self,
1891 server_id: &McpServerId,
1892 ) -> Result<McpDiscoverySnapshot, McpError> {
1893 let handle = self
1894 .connections
1895 .get_mut(server_id)
1896 .ok_or_else(|| McpError::UnknownServer(server_id.to_string()))?;
1897 let snapshot = handle.connection.discover().await?;
1898 handle.snapshot = snapshot.clone();
1899 Ok(snapshot)
1900 }
1901
1902 pub async fn disconnect_server(&mut self, server_id: &McpServerId) -> Result<(), McpError> {
1911 let Some(handle) = self.connections.remove(server_id) else {
1912 return Err(McpError::UnknownServer(server_id.to_string()));
1913 };
1914 handle.connection.close().await
1915 }
1916
1917 pub async fn resolve_auth(&mut self, resolution: AuthResolution) -> Result<(), McpError> {
1925 let server_id = resolution
1926 .request()
1927 .server_id()
1928 .ok_or_else(|| McpError::AuthResolution("auth resolution missing server id".into()))?;
1929 let server_id = McpServerId::new(server_id);
1930 match &resolution {
1931 AuthResolution::Provided { credentials, .. } => {
1932 self.auth.insert(server_id.clone(), credentials.clone());
1933 }
1934 AuthResolution::Cancelled { .. } => {
1935 self.auth.remove(&server_id);
1936 }
1937 }
1938
1939 if let Some(handle) = self.connections.get(&server_id) {
1940 handle.connection.resolve_auth(resolution).await?;
1941 return Ok(());
1942 }
1943
1944 if self.configs.contains_key(&server_id) {
1945 Ok(())
1946 } else {
1947 Err(McpError::UnknownServer(server_id.to_string()))
1948 }
1949 }
1950
1951 pub async fn resolve_auth_and_resume(
1961 &mut self,
1962 resolution: AuthResolution,
1963 ) -> Result<McpOperationResult, McpError> {
1964 let request = resolution.request().clone();
1965 self.resolve_auth(resolution).await?;
1966 self.replay_auth_request(&request).await
1967 }
1968
1969 pub async fn replay_auth_request(
1979 &mut self,
1980 request: &AuthRequest,
1981 ) -> Result<McpOperationResult, McpError> {
1982 match &request.operation {
1983 AuthOperation::McpConnect { server_id, .. } => {
1984 let server_id = McpServerId::new(server_id);
1985 let handle = self.connect_server(&server_id).await?;
1986 Ok(McpOperationResult::Connected(handle.snapshot.clone()))
1987 }
1988 AuthOperation::McpToolCall { server_id, .. }
1989 | AuthOperation::McpResourceRead { server_id, .. }
1990 | AuthOperation::McpPromptGet { server_id, .. } => {
1991 let connection = self.connection_for_auth_server(server_id).await?;
1992 connection.replay_auth_operation(&request.operation).await
1993 }
1994 AuthOperation::ToolCall { metadata, .. } => {
1995 let server_id = metadata
1996 .get("server_id")
1997 .and_then(Value::as_str)
1998 .ok_or_else(|| {
1999 McpError::AuthResolution(
2000 "tool-call auth replay requires metadata.server_id".into(),
2001 )
2002 })?;
2003 let connection = self.connection_for_auth_server(server_id).await?;
2004 connection.replay_auth_operation(&request.operation).await
2005 }
2006 AuthOperation::Custom { kind, .. } => Err(McpError::AuthResolution(format!(
2007 "unsupported auth operation for replay: {kind}"
2008 ))),
2009 }
2010 }
2011
2012 async fn connection_for_auth_server(
2013 &mut self,
2014 server_id: &str,
2015 ) -> Result<Arc<McpConnection>, McpError> {
2016 let server_id = McpServerId::new(server_id);
2017 if !self.connections.contains_key(&server_id) {
2018 self.connect_server(&server_id).await?;
2019 }
2020 self.connections
2021 .get(&server_id)
2022 .map(McpServerHandle::connection)
2023 .ok_or_else(|| McpError::UnknownServer(server_id.to_string()))
2024 }
2025
2026 pub fn tool_registry(&self) -> ToolRegistry {
2031 self.connections
2032 .values()
2033 .fold(ToolRegistry::new(), |mut registry, handle| {
2034 for tool in handle.snapshot.tools.iter().cloned() {
2035 registry.register(McpToolAdapter::new(
2036 handle.server_id(),
2037 handle.connection.clone(),
2038 tool,
2039 ));
2040 }
2041 registry
2042 })
2043 }
2044
2045 pub fn capability_provider(&self) -> McpCapabilityProvider {
2048 McpCapabilityProvider::merge(
2049 self.connections
2050 .values()
2051 .map(McpServerHandle::capability_provider),
2052 )
2053 }
2054}
2055
/// Adapts one MCP tool to the [`Tool`] interface used by [`ToolRegistry`].
pub struct McpToolAdapter {
    /// Server-side tool descriptor; `descriptor.name` is the name sent to the server.
    descriptor: McpToolDescriptor,
    /// Shared connection used to issue `call_tool` requests.
    connection: Arc<McpConnection>,
    /// Tool spec exposed to the registry, named `mcp_<server_id>_<tool_name>`.
    spec: ToolSpec,
}
2076
2077impl McpToolAdapter {
2078 pub fn new(
2086 server_id: &McpServerId,
2087 connection: Arc<McpConnection>,
2088 descriptor: McpToolDescriptor,
2089 ) -> Self {
2090 let spec = ToolSpec {
2091 name: ToolName::new(format!("mcp_{}_{}", server_id, descriptor.name)),
2092 description: descriptor
2093 .description
2094 .clone()
2095 .unwrap_or_else(|| descriptor.name.clone()),
2096 input_schema: descriptor.input_schema.clone(),
2097 annotations: ToolAnnotations::default(),
2098 metadata: descriptor.metadata.clone(),
2099 };
2100
2101 Self {
2102 descriptor,
2103 connection,
2104 spec,
2105 }
2106 }
2107}
2108
2109#[async_trait]
2110impl Tool for McpToolAdapter {
2111 fn spec(&self) -> &ToolSpec {
2112 &self.spec
2113 }
2114
2115 async fn invoke(
2116 &self,
2117 request: ToolRequest,
2118 _ctx: &mut ToolContext<'_>,
2119 ) -> Result<ToolResult, ToolError> {
2120 let result = self
2121 .connection
2122 .call_tool(&self.descriptor.name, request.input)
2123 .await
2124 .map_err(|error| match error {
2125 McpError::AuthRequired(request) => ToolError::AuthRequired(request),
2126 other => ToolError::ExecutionFailed(other.to_string()),
2127 })?;
2128
2129 Ok(ToolResult {
2130 result: ToolResultPart {
2131 call_id: request.call_id,
2132 output: invocable_output_to_tool_output(value_to_invocable_output(result)),
2133 is_error: false,
2134 metadata: MetadataMap::new(),
2135 },
2136 duration: None,
2137 metadata: MetadataMap::new(),
2138 })
2139 }
2140}
2141
2142fn parse_tool_descriptor(value: Value) -> Result<McpToolDescriptor, McpError> {
2143 Ok(McpToolDescriptor {
2144 name: required_string(&value, "name")?,
2145 description: value
2146 .get("description")
2147 .and_then(Value::as_str)
2148 .map(str::to_owned),
2149 input_schema: value
2150 .get("inputSchema")
2151 .cloned()
2152 .unwrap_or_else(|| json!({ "type": "object" })),
2153 metadata: MetadataMap::new(),
2154 })
2155}
2156
2157fn parse_resource_descriptor(value: Value) -> Result<McpResourceDescriptor, McpError> {
2158 Ok(McpResourceDescriptor {
2159 id: required_string(&value, "uri")?,
2160 name: value
2161 .get("name")
2162 .and_then(Value::as_str)
2163 .map(str::to_owned)
2164 .unwrap_or_else(|| {
2165 value
2166 .get("uri")
2167 .and_then(Value::as_str)
2168 .unwrap_or_default()
2169 .to_string()
2170 }),
2171 description: value
2172 .get("description")
2173 .and_then(Value::as_str)
2174 .map(str::to_owned),
2175 mime_type: value
2176 .get("mimeType")
2177 .and_then(Value::as_str)
2178 .map(str::to_owned),
2179 metadata: MetadataMap::new(),
2180 })
2181}
2182
2183fn parse_prompt_descriptor(value: Value) -> Result<McpPromptDescriptor, McpError> {
2184 let name = required_string(&value, "name")?;
2185 let properties = value
2186 .get("arguments")
2187 .and_then(Value::as_array)
2188 .cloned()
2189 .unwrap_or_default()
2190 .into_iter()
2191 .filter_map(|arg| {
2192 let name = arg.get("name")?.as_str()?.to_string();
2193 Some((name, json!({ "type": "string" })))
2194 })
2195 .collect::<serde_json::Map<String, Value>>();
2196
2197 Ok(McpPromptDescriptor {
2198 id: name.clone(),
2199 name,
2200 description: value
2201 .get("description")
2202 .and_then(Value::as_str)
2203 .map(str::to_owned),
2204 input_schema: json!({
2205 "type": "object",
2206 "properties": properties,
2207 }),
2208 metadata: MetadataMap::new(),
2209 })
2210}
2211
2212fn parse_prompt_message(value: Value) -> Result<Item, McpError> {
2213 let role = value.get("role").and_then(Value::as_str).unwrap_or("user");
2214 let kind = match role {
2215 "assistant" => ItemKind::Assistant,
2216 "system" => ItemKind::System,
2217 _ => ItemKind::User,
2218 };
2219
2220 let content = value.get("content").cloned().unwrap_or(Value::Null);
2221 let text = if let Some(text) = content.get("text").and_then(Value::as_str) {
2222 text.to_string()
2223 } else if let Some(text) = content.as_str() {
2224 text.to_string()
2225 } else {
2226 content.to_string()
2227 };
2228
2229 Ok(Item {
2230 id: None,
2231 kind,
2232 parts: vec![Part::Text(TextPart {
2233 text,
2234 metadata: MetadataMap::new(),
2235 })],
2236 metadata: MetadataMap::new(),
2237 })
2238}
2239
2240fn required_string(value: &Value, field: &str) -> Result<String, McpError> {
2241 value
2242 .get(field)
2243 .and_then(Value::as_str)
2244 .map(str::to_owned)
2245 .ok_or_else(|| McpError::Protocol(format!("missing string field {field}")))
2246}
2247
2248fn value_to_invocable_output(value: Value) -> InvocableOutput {
2249 if let Some(content) = value.get("content").and_then(Value::as_array) {
2250 let text = content
2251 .iter()
2252 .filter_map(|item| item.get("text").and_then(Value::as_str))
2253 .collect::<Vec<_>>()
2254 .join("\n");
2255 if !text.is_empty() {
2256 return InvocableOutput::Text(text);
2257 }
2258 }
2259
2260 if let Some(text) = value.as_str() {
2261 InvocableOutput::Text(text.to_string())
2262 } else {
2263 InvocableOutput::Structured(value)
2264 }
2265}
2266
2267fn invocable_output_to_tool_output(output: InvocableOutput) -> ToolOutput {
2268 match output {
2269 InvocableOutput::Text(text) => ToolOutput::Text(text),
2270 InvocableOutput::Structured(value) => ToolOutput::Structured(value),
2271 InvocableOutput::Items(items) => {
2272 ToolOutput::Parts(items.into_iter().flat_map(|item| item.parts).collect())
2273 }
2274 InvocableOutput::Data(data) => ToolOutput::Structured(json!({ "data": data })),
2275 }
2276}
2277
2278fn metadata_to_value(metadata: &MetadataMap) -> Value {
2279 Value::Object(
2280 metadata
2281 .iter()
2282 .map(|(key, value)| (key.clone(), value.clone()))
2283 .collect(),
2284 )
2285}
2286
2287fn parse_auth_request(
2288 server_id: &McpServerId,
2289 method: &str,
2290 params: &Value,
2291 error: &Value,
2292) -> Option<AuthRequest> {
2293 let code = error.get("code").and_then(Value::as_i64);
2294 let message = error.get("message").and_then(Value::as_str);
2295 let data = error.get("data");
2296
2297 let auth_marker = matches!(code, Some(401 | -32001))
2298 || data
2299 .and_then(|data| data.get("auth_required"))
2300 .and_then(Value::as_bool)
2301 == Some(true)
2302 || data.and_then(|data| data.get("auth")).is_some();
2303
2304 if !auth_marker {
2305 return None;
2306 }
2307
2308 let mut challenge = MetadataMap::new();
2309 challenge.insert("server_id".into(), Value::String(server_id.to_string()));
2310 challenge.insert("method".into(), Value::String(method.into()));
2311
2312 if let Some(code) = code {
2313 challenge.insert("code".into(), Value::Number(code.into()));
2314 }
2315 if let Some(message) = message {
2316 challenge.insert("message".into(), Value::String(message.into()));
2317 }
2318 if let Some(data) = data {
2319 challenge.insert("data".into(), data.clone());
2320 }
2321
2322 Some(AuthRequest {
2323 task_id: None,
2324 id: format!("mcp:{}:{}", server_id, method),
2325 provider: format!("mcp.{}", server_id),
2326 operation: auth_operation_for_method(server_id, method, params),
2327 challenge,
2328 })
2329}
2330
2331fn auth_operation_for_method(
2332 server_id: &McpServerId,
2333 method: &str,
2334 params: &Value,
2335) -> AuthOperation {
2336 match method {
2337 "initialize" => AuthOperation::McpConnect {
2338 server_id: server_id.to_string(),
2339 metadata: MetadataMap::new(),
2340 },
2341 "tools/call" => AuthOperation::McpToolCall {
2342 server_id: server_id.to_string(),
2343 tool_name: params
2344 .get("name")
2345 .and_then(Value::as_str)
2346 .unwrap_or_default()
2347 .to_string(),
2348 input: params
2349 .get("arguments")
2350 .cloned()
2351 .unwrap_or_else(|| json!({})),
2352 metadata: MetadataMap::new(),
2353 },
2354 "resources/read" => AuthOperation::McpResourceRead {
2355 server_id: server_id.to_string(),
2356 resource_id: params
2357 .get("uri")
2358 .and_then(Value::as_str)
2359 .unwrap_or_default()
2360 .to_string(),
2361 metadata: MetadataMap::new(),
2362 },
2363 "prompts/get" => AuthOperation::McpPromptGet {
2364 server_id: server_id.to_string(),
2365 prompt_id: params
2366 .get("name")
2367 .and_then(Value::as_str)
2368 .unwrap_or_default()
2369 .to_string(),
2370 args: params
2371 .get("arguments")
2372 .cloned()
2373 .unwrap_or_else(|| json!({})),
2374 metadata: MetadataMap::new(),
2375 },
2376 other => AuthOperation::Custom {
2377 kind: format!("mcp.{other}"),
2378 payload: params.clone(),
2379 metadata: {
2380 let mut metadata = MetadataMap::new();
2381 metadata.insert("server_id".into(), Value::String(server_id.to_string()));
2382 metadata
2383 },
2384 },
2385 }
2386}
2387
2388fn normalize_mcp_tool_name(server_id: &McpServerId, tool_name: &str) -> String {
2389 let prefix = format!("mcp_{server_id}_");
2390 tool_name
2391 .strip_prefix(&prefix)
2392 .unwrap_or(tool_name)
2393 .to_string()
2394}
2395
2396async fn read_sse_stream<R>(
2397 mut reader: R,
2398 response_url: Url,
2399 frame_tx: mpsc::UnboundedSender<Result<McpFrame, McpError>>,
2400 endpoint_tx: oneshot::Sender<Result<Url, McpError>>,
2401) where
2402 R: AsyncBufRead + Unpin,
2403{
2404 let mut endpoint_tx = Some(endpoint_tx);
2405 loop {
2406 match read_next_sse_event(&mut reader).await {
2407 Ok(Some(event)) => {
2408 if let Some(endpoint) = legacy_sse_event_to_endpoint(&response_url, &event) {
2409 if let Some(tx) = endpoint_tx.take() {
2410 let _ = tx.send(endpoint);
2411 }
2412 continue;
2413 }
2414
2415 if let Some(frame) = legacy_sse_event_to_frame(event) {
2416 let _ = frame_tx.send(frame);
2417 }
2418 }
2419 Ok(None) => break,
2420 Err(error) => {
2421 if let Some(tx) = endpoint_tx.take() {
2422 let _ = tx.send(Err(error));
2423 } else {
2424 let _ = frame_tx.send(Err(error));
2425 }
2426 return;
2427 }
2428 }
2429 }
2430
2431 if let Some(tx) = endpoint_tx.take() {
2432 let _ = tx.send(Err(McpError::Transport(
2433 "SSE stream ended before endpoint event".into(),
2434 )));
2435 }
2436}
2437
2438fn resolve_sse_endpoint(response_url: &Url, endpoint: &str) -> Result<Url, McpError> {
2439 response_url
2440 .join(endpoint.trim())
2441 .map_err(|error| McpError::Transport(format!("invalid SSE endpoint URL: {error}")))
2442}
2443
/// One parsed Server-Sent Event.
#[derive(Debug)]
struct SseEvent {
    /// Value of the `event` field. When `None`, callers treat the event as `"message"`.
    event_name: Option<String>,
    /// All `data` field values, joined with `\n`.
    data: String,
    /// Value of the last `id` field, if any.
    id: Option<String>,
    /// Parsed `retry` field (milliseconds, per the field name), if any.
    retry_ms: Option<u64>,
}
2451
2452async fn read_next_sse_event<R>(reader: &mut R) -> Result<Option<SseEvent>, McpError>
2453where
2454 R: AsyncBufRead + Unpin,
2455{
2456 let mut event_name = None;
2457 let mut data_lines = Vec::new();
2458 let mut id = None;
2459 let mut retry_ms = None;
2460
2461 loop {
2462 let mut line = String::new();
2463 let read = reader.read_line(&mut line).await.map_err(McpError::Io)?;
2464 if read == 0 {
2465 if event_name.is_none() && data_lines.is_empty() && id.is_none() && retry_ms.is_none() {
2466 return Ok(None);
2467 }
2468 return Ok(Some(SseEvent {
2469 event_name,
2470 data: data_lines.join("\n"),
2471 id,
2472 retry_ms,
2473 }));
2474 }
2475
2476 let line = line.trim_end_matches(['\r', '\n']);
2477 if line.is_empty() {
2478 if event_name.is_none() && data_lines.is_empty() && id.is_none() && retry_ms.is_none() {
2479 continue;
2480 }
2481 return Ok(Some(SseEvent {
2482 event_name,
2483 data: data_lines.join("\n"),
2484 id,
2485 retry_ms,
2486 }));
2487 }
2488
2489 if line.starts_with(':') {
2490 continue;
2491 }
2492
2493 if let Some(rest) = line.strip_prefix("event:") {
2494 event_name = Some(rest.trim_start().to_string());
2495 continue;
2496 }
2497 if let Some(rest) = line.strip_prefix("data:") {
2498 data_lines.push(rest.trim_start().to_string());
2499 continue;
2500 }
2501 if let Some(rest) = line.strip_prefix("id:") {
2502 id = Some(rest.trim_start().to_string());
2503 continue;
2504 }
2505 if let Some(rest) = line.strip_prefix("retry:") {
2506 retry_ms = rest.trim_start().parse().ok();
2507 }
2508 }
2509}
2510
2511fn legacy_sse_event_to_endpoint(
2512 response_url: &Url,
2513 event: &SseEvent,
2514) -> Option<Result<Url, McpError>> {
2515 if event.event_name.as_deref() != Some("endpoint") {
2516 return None;
2517 }
2518 if event.data.is_empty() {
2519 return Some(Err(McpError::Transport(
2520 "legacy SSE endpoint event is missing data".into(),
2521 )));
2522 }
2523 Some(resolve_sse_endpoint(response_url, &event.data))
2524}
2525
2526fn legacy_sse_event_to_frame(event: SseEvent) -> Option<Result<McpFrame, McpError>> {
2527 let event_name = event.event_name.unwrap_or_else(|| "message".into());
2528 if event_name != "message" || event.data.is_empty() {
2529 return None;
2530 }
2531
2532 Some(
2533 serde_json::from_str(&event.data)
2534 .map_err(McpError::Serialize)
2535 .map(|value| McpFrame { value }),
2536 )
2537}
2538
2539fn streamable_http_event_to_frame(event: SseEvent) -> Result<Option<McpFrame>, McpError> {
2540 let event_name = event.event_name.unwrap_or_else(|| "message".into());
2541 if event_name != "message" || event.data.is_empty() {
2542 return Ok(None);
2543 }
2544
2545 let value = serde_json::from_str(&event.data).map_err(McpError::Serialize)?;
2546 Ok(Some(McpFrame { value }))
2547}
2548
2549fn is_jsonrpc_request(value: &Value) -> bool {
2550 value.get("method").is_some() && value.get("id").is_some()
2551}
2552
2553fn apply_streamable_http_headers(
2554 mut request: HttpRequestBuilder,
2555 protocol_version: Option<&str>,
2556 session_id: Option<&str>,
2557) -> HttpRequestBuilder {
2558 if let Some(protocol_version) = protocol_version {
2559 request = request.header("MCP-Protocol-Version", protocol_version);
2560 }
2561 if let Some(session_id) = session_id {
2562 request = request.header("MCP-Session-Id", session_id);
2563 }
2564
2565 request
2566}
2567
2568async fn streamable_http_status_error(
2569 operation: &str,
2570 status: StatusCode,
2571 response: HttpResponse,
2572) -> McpError {
2573 let body = response
2574 .text()
2575 .await
2576 .unwrap_or_else(|_| "<unreadable response body>".into());
2577 McpError::Transport(format!("{operation} failed with status {status}: {body}"))
2578}
2579
/// Errors produced by MCP connections, transports, and the server manager.
#[derive(Debug, Error)]
pub enum McpError {
    /// Underlying I/O failure, e.g. while reading an SSE stream.
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
    /// HTTP client failure.
    #[error("http error: {0}")]
    Http(#[from] HttpError),
    /// JSON (de)serialization failure, e.g. a malformed SSE `data` payload.
    #[error("serialization error: {0}")]
    Serialize(#[from] serde_json::Error),
    /// Transport-level failure that is not a raw I/O error.
    ///
    /// Examples: an invalid SSE endpoint, or an unexpected HTTP status.
    #[error("transport error: {0}")]
    Transport(String),
    /// A server message did not have the expected shape (e.g. a missing required field).
    #[error("protocol error: {0}")]
    Protocol(String),
    /// The server demanded authentication.
    ///
    /// The boxed request describes the challenge and the operation to replay once
    /// auth is resolved.
    #[error("MCP auth required: {0:?}")]
    AuthRequired(Box<AuthRequest>),
    /// An auth resolution could not be applied or its operation could not be replayed.
    #[error("auth resolution error: {0}")]
    AuthResolution(String),
    /// Invocation-level failure.
    #[error("invocation error: {0}")]
    Invocation(String),
    /// No configuration or connection exists for the given server id.
    #[error("unknown MCP server: {0}")]
    UnknownServer(String),
}
2612
2613#[cfg(test)]
2614mod tests {
2615 use std::collections::VecDeque;
2616 use std::sync::{Arc as StdArc, Mutex as StdMutex};
2617
2618 use super::*;
2619 use agentkit_tools_core::{PermissionChecker, PermissionDecision, PermissionRequest};
2620 use tokio::io::{AsyncReadExt, AsyncWriteExt};
2621 use tokio::net::TcpListener;
2622
2623 struct AllowAll;
2624
2625 impl PermissionChecker for AllowAll {
2626 fn evaluate(&self, _request: &dyn PermissionRequest) -> PermissionDecision {
2627 PermissionDecision::Allow
2628 }
2629 }
2630
2631 struct FakeTransport {
2632 recv: VecDeque<Value>,
2633 }
2634
2635 #[async_trait]
2636 impl McpTransport for FakeTransport {
2637 async fn send(&mut self, _message: McpFrame) -> Result<(), McpError> {
2638 Ok(())
2639 }
2640
2641 async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
2642 Ok(self.recv.pop_front().map(|value| McpFrame { value }))
2643 }
2644
2645 async fn close(&mut self) -> Result<(), McpError> {
2646 Ok(())
2647 }
2648 }
2649
2650 fn fake_connection(responses: Vec<Value>) -> McpConnection {
2651 McpConnection {
2652 server_id: McpServerId::new("fake"),
2653 transport: Mutex::new(Box::new(FakeTransport {
2654 recv: responses.into(),
2655 })),
2656 auth: Mutex::new(None),
2657 next_id: AtomicU64::new(1),
2658 }
2659 }
2660
2661 #[derive(Clone)]
2662 struct FakeTransportFactory {
2663 responses: StdArc<StdMutex<VecDeque<Vec<Value>>>>,
2664 }
2665
2666 impl FakeTransportFactory {
2667 fn new(sequences: Vec<Vec<Value>>) -> Self {
2668 Self {
2669 responses: StdArc::new(StdMutex::new(sequences.into())),
2670 }
2671 }
2672 }
2673
2674 #[async_trait]
2675 impl McpTransportFactory for FakeTransportFactory {
2676 async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
2677 let responses =
2678 self.responses.lock().unwrap().pop_front().ok_or_else(|| {
2679 McpError::Transport("no fake transport responses left".into())
2680 })?;
2681 Ok(Box::new(FakeTransport {
2682 recv: responses.into(),
2683 }))
2684 }
2685 }
2686
2687 #[tokio::test]
2688 async fn discovery_parses_snapshot() {
2689 let connection = fake_connection(vec![
2690 json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [{ "name": "echo", "description": "Echo", "inputSchema": {"type": "object"} }] } }),
2691 json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [{ "uri": "file:///tmp/example.txt", "name": "example.txt", "mimeType": "text/plain" }] } }),
2692 json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [{ "name": "summarize", "description": "Summarize", "arguments": [{ "name": "path" }] }] } }),
2693 ]);
2694
2695 let snapshot = connection.discover().await.unwrap();
2696 assert_eq!(snapshot.tools[0].name, "echo");
2697 assert_eq!(snapshot.resources[0].id, "file:///tmp/example.txt");
2698 assert_eq!(snapshot.prompts[0].id, "summarize");
2699 }
2700
2701 #[tokio::test]
2702 async fn tool_adapter_returns_text_output() {
2703 let connection = Arc::new(fake_connection(vec![json!({
2704 "jsonrpc": "2.0",
2705 "id": 1,
2706 "result": { "content": [{ "type": "text", "text": "pong" }] }
2707 })]));
2708 let server_id = connection.server_id().clone();
2709 let adapter = McpToolAdapter::new(
2710 &server_id,
2711 connection,
2712 McpToolDescriptor {
2713 name: "echo".into(),
2714 description: Some("Echo".into()),
2715 input_schema: json!({ "type": "object" }),
2716 metadata: MetadataMap::new(),
2717 },
2718 );
2719 let metadata = MetadataMap::new();
2720 let mut ctx = ToolContext {
2721 capability: CapabilityContext {
2722 session_id: None,
2723 turn_id: None,
2724 metadata: &metadata,
2725 },
2726 permissions: &AllowAll,
2727 resources: &(),
2728 cancellation: None,
2729 };
2730
2731 let result = adapter
2732 .invoke(
2733 ToolRequest {
2734 call_id: "call-1".into(),
2735 tool_name: ToolName::new("mcp_fake_echo"),
2736 input: json!({}),
2737 session_id: "session-1".into(),
2738 turn_id: "turn-1".into(),
2739 metadata: MetadataMap::new(),
2740 },
2741 &mut ctx,
2742 )
2743 .await
2744 .unwrap();
2745
2746 assert_eq!(result.result.output, ToolOutput::Text("pong".into()));
2747 }
2748
2749 #[tokio::test]
2750 async fn request_surfaces_auth_required_errors() {
2751 let connection = fake_connection(vec![json!({
2752 "jsonrpc": "2.0",
2753 "id": 1,
2754 "error": {
2755 "code": -32001,
2756 "message": "authentication required",
2757 "data": {
2758 "auth_required": true,
2759 "scope": "secrets.read"
2760 }
2761 }
2762 })]);
2763
2764 let error = connection.call_tool("echo", json!({})).await.unwrap_err();
2765 match error {
2766 McpError::AuthRequired(request) => {
2767 assert_eq!(request.provider, "mcp.fake");
2768 assert_eq!(
2769 request.challenge.get("method"),
2770 Some(&Value::String("tools/call".into()))
2771 );
2772 assert!(matches!(
2773 request.operation,
2774 AuthOperation::McpToolCall { ref tool_name, .. } if tool_name == "echo"
2775 ));
2776 }
2777 other => panic!("unexpected error: {other:?}"),
2778 }
2779 }
2780
2781 #[tokio::test]
2782 async fn tool_adapter_maps_auth_required_into_tool_error() {
2783 let connection = Arc::new(fake_connection(vec![json!({
2784 "jsonrpc": "2.0",
2785 "id": 1,
2786 "error": {
2787 "code": -32001,
2788 "message": "authentication required",
2789 "data": { "auth_required": true }
2790 }
2791 })]));
2792 let server_id = connection.server_id().clone();
2793 let adapter = McpToolAdapter::new(
2794 &server_id,
2795 connection,
2796 McpToolDescriptor {
2797 name: "echo".into(),
2798 description: Some("Echo".into()),
2799 input_schema: json!({ "type": "object" }),
2800 metadata: MetadataMap::new(),
2801 },
2802 );
2803 let metadata = MetadataMap::new();
2804 let mut ctx = ToolContext {
2805 capability: CapabilityContext {
2806 session_id: None,
2807 turn_id: None,
2808 metadata: &metadata,
2809 },
2810 permissions: &AllowAll,
2811 resources: &(),
2812 cancellation: None,
2813 };
2814
2815 let error = adapter
2816 .invoke(
2817 ToolRequest {
2818 call_id: "call-1".into(),
2819 tool_name: ToolName::new("mcp_fake_echo"),
2820 input: json!({}),
2821 session_id: "session-1".into(),
2822 turn_id: "turn-1".into(),
2823 metadata: MetadataMap::new(),
2824 },
2825 &mut ctx,
2826 )
2827 .await
2828 .unwrap_err();
2829
2830 match error {
2831 ToolError::AuthRequired(request) => {
2832 assert_eq!(request.provider, "mcp.fake");
2833 }
2834 other => panic!("unexpected error: {other:?}"),
2835 }
2836 }
2837
2838 struct RecordingTransport {
2839 recv: VecDeque<Value>,
2840 sent: StdArc<StdMutex<Vec<Value>>>,
2841 }
2842
2843 #[async_trait]
2844 impl McpTransport for RecordingTransport {
2845 async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
2846 self.sent.lock().unwrap().push(message.value);
2847 Ok(())
2848 }
2849
2850 async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
2851 Ok(self.recv.pop_front().map(|value| McpFrame { value }))
2852 }
2853
2854 async fn close(&mut self) -> Result<(), McpError> {
2855 Ok(())
2856 }
2857 }
2858
2859 #[derive(Clone)]
2860 struct RecordingTransportFactory {
2861 responses: StdArc<StdMutex<VecDeque<Vec<Value>>>>,
2862 sent: StdArc<StdMutex<Vec<Value>>>,
2863 }
2864
2865 impl RecordingTransportFactory {
2866 fn new(sequences: Vec<Vec<Value>>) -> Self {
2867 Self {
2868 responses: StdArc::new(StdMutex::new(sequences.into())),
2869 sent: StdArc::new(StdMutex::new(Vec::new())),
2870 }
2871 }
2872
2873 fn sent(&self) -> Vec<Value> {
2874 self.sent.lock().unwrap().clone()
2875 }
2876 }
2877
2878 #[async_trait]
2879 impl McpTransportFactory for RecordingTransportFactory {
2880 async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
2881 let responses = self.responses.lock().unwrap().pop_front().ok_or_else(|| {
2882 McpError::Transport("no recording transport responses left".into())
2883 })?;
2884 Ok(Box::new(RecordingTransport {
2885 recv: responses.into(),
2886 sent: self.sent.clone(),
2887 }))
2888 }
2889 }
2890
2891 #[tokio::test]
2892 async fn connection_includes_resolved_auth_in_future_requests() {
2893 let factory = RecordingTransportFactory::new(vec![vec![
2894 json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "recording", "version": "1.0.0" } } }),
2895 json!({ "jsonrpc": "2.0", "id": 1, "result": { "content": [{ "type": "text", "text": "ok" }] } }),
2896 ]]);
2897 let config = McpServerConfig::new(
2898 "recording",
2899 McpTransportBinding::Custom(Arc::new(factory.clone())),
2900 );
2901 let connection = McpConnection::connect(&config).await.unwrap();
2902 let mut auth = MetadataMap::new();
2903 auth.insert("token".into(), json!("secret-token"));
2904 let request = AuthRequest {
2905 task_id: None,
2906 id: "auth-recording-tool".into(),
2907 provider: "mcp.recording".into(),
2908 operation: AuthOperation::McpToolCall {
2909 server_id: "recording".into(),
2910 tool_name: "echo".into(),
2911 input: json!({}),
2912 metadata: MetadataMap::new(),
2913 },
2914 challenge: MetadataMap::new(),
2915 };
2916 connection
2917 .resolve_auth(agentkit_tools_core::AuthResolution::Provided {
2918 request,
2919 credentials: auth,
2920 })
2921 .await
2922 .unwrap();
2923
2924 let _ = connection.call_tool("echo", json!({})).await.unwrap();
2925 let sent = factory.sent();
2926 assert!(
2927 sent.iter().any(|value| {
2928 value
2929 .get("params")
2930 .and_then(|params| params.get("auth"))
2931 .and_then(|auth| auth.get("token"))
2932 == Some(&json!("secret-token"))
2933 }),
2934 "expected an MCP request to include the resolved auth payload, saw {:?}",
2935 sent
2936 );
2937 }
2938
2939 #[tokio::test]
2940 async fn manager_reuses_stored_auth_on_connect() {
2941 let factory = RecordingTransportFactory::new(vec![vec![
2942 json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "recording", "version": "1.0.0" } } }),
2943 json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [] } }),
2944 json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
2945 json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
2946 ]]);
2947 let server_id = McpServerId::new("recording");
2948 let mut manager = McpServerManager::new().with_server(McpServerConfig::new(
2949 server_id.to_string(),
2950 McpTransportBinding::Custom(Arc::new(factory.clone())),
2951 ));
2952 let mut auth = MetadataMap::new();
2953 auth.insert("token".into(), json!("seed-token"));
2954 let request = AuthRequest {
2955 task_id: None,
2956 id: "auth-recording-connect".into(),
2957 provider: "mcp.recording".into(),
2958 operation: AuthOperation::McpConnect {
2959 server_id: server_id.to_string(),
2960 metadata: MetadataMap::new(),
2961 },
2962 challenge: MetadataMap::new(),
2963 };
2964 manager
2965 .resolve_auth(agentkit_tools_core::AuthResolution::Provided {
2966 request,
2967 credentials: auth,
2968 })
2969 .await
2970 .unwrap();
2971
2972 manager.connect_server(&server_id).await.unwrap();
2973 let sent = factory.sent();
2974 assert!(
2975 sent.iter().any(|value| {
2976 value.get("method").and_then(Value::as_str) == Some("initialize")
2977 && value
2978 .get("params")
2979 .and_then(|params| params.get("auth"))
2980 .and_then(|auth| auth.get("token"))
2981 == Some(&json!("seed-token"))
2982 }),
2983 "expected initialize to include stored auth, saw {:?}",
2984 sent
2985 );
2986 }
2987
2988 #[tokio::test]
2989 async fn manager_resolves_auth_and_replays_resource_read() {
2990 let factory = RecordingTransportFactory::new(vec![vec![
2991 json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "recording", "version": "1.0.0" } } }),
2992 json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [] } }),
2993 json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
2994 json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
2995 json!({
2996 "jsonrpc": "2.0",
2997 "id": 4,
2998 "result": {
2999 "contents": [
3000 {
3001 "uri": "file:///tmp/secret.txt",
3002 "text": "secret from resource"
3003 }
3004 ]
3005 }
3006 }),
3007 ]]);
3008 let server_id = McpServerId::new("recording");
3009 let mut manager = McpServerManager::new().with_server(McpServerConfig::new(
3010 server_id.to_string(),
3011 McpTransportBinding::Custom(Arc::new(factory.clone())),
3012 ));
3013 let mut auth = MetadataMap::new();
3014 auth.insert("token".into(), json!("resource-token"));
3015 let request = AuthRequest {
3016 task_id: None,
3017 id: "auth-recording-resource".into(),
3018 provider: "mcp.recording".into(),
3019 operation: AuthOperation::McpResourceRead {
3020 server_id: server_id.to_string(),
3021 resource_id: "file:///tmp/secret.txt".into(),
3022 metadata: MetadataMap::new(),
3023 },
3024 challenge: MetadataMap::new(),
3025 };
3026
3027 let result = manager
3028 .resolve_auth_and_resume(agentkit_tools_core::AuthResolution::Provided {
3029 request,
3030 credentials: auth,
3031 })
3032 .await
3033 .unwrap();
3034
3035 match result {
3036 McpOperationResult::Resource(contents) => {
3037 assert_eq!(
3038 contents.data,
3039 DataRef::InlineText("secret from resource".into())
3040 );
3041 }
3042 other => panic!("unexpected replay result: {other:?}"),
3043 }
3044
3045 let sent = factory.sent();
3046 assert!(
3047 sent.iter().any(|value| {
3048 value.get("method").and_then(Value::as_str) == Some("resources/read")
3049 && value
3050 .get("params")
3051 .and_then(|params| params.get("auth"))
3052 .and_then(|auth| auth.get("token"))
3053 == Some(&json!("resource-token"))
3054 }),
3055 "expected resources/read to include resolved auth, saw {:?}",
3056 sent
3057 );
3058 }
3059
3060 #[tokio::test]
3061 async fn manager_resolves_auth_and_replays_connect() {
3062 let factory = RecordingTransportFactory::new(vec![vec![
3063 json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "recording", "version": "1.0.0" } } }),
3064 json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [] } }),
3065 json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
3066 json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
3067 ]]);
3068 let server_id = McpServerId::new("recording");
3069 let mut manager = McpServerManager::new().with_server(McpServerConfig::new(
3070 server_id.to_string(),
3071 McpTransportBinding::Custom(Arc::new(factory.clone())),
3072 ));
3073 let mut auth = MetadataMap::new();
3074 auth.insert("token".into(), json!("connect-token"));
3075 let request = AuthRequest {
3076 task_id: None,
3077 id: "auth-recording-connect-replay".into(),
3078 provider: "mcp.recording".into(),
3079 operation: AuthOperation::McpConnect {
3080 server_id: server_id.to_string(),
3081 metadata: MetadataMap::new(),
3082 },
3083 challenge: MetadataMap::new(),
3084 };
3085
3086 let result = manager
3087 .resolve_auth_and_resume(agentkit_tools_core::AuthResolution::Provided {
3088 request,
3089 credentials: auth,
3090 })
3091 .await
3092 .unwrap();
3093
3094 match result {
3095 McpOperationResult::Connected(snapshot) => {
3096 assert_eq!(snapshot.server_id, server_id);
3097 }
3098 other => panic!("unexpected replay result: {other:?}"),
3099 }
3100 }
3101
3102 #[tokio::test]
3103 async fn sse_transport_posts_messages_and_receives_frames() {
3104 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3105 let address = listener.local_addr().unwrap();
3106 let requests = StdArc::new(StdMutex::new(Vec::new()));
3107 let captured = requests.clone();
3108
3109 let server = tokio::spawn(async move {
3110 for _ in 0..2 {
3111 let (mut socket, _) = listener.accept().await.unwrap();
3112 let mut buffer = vec![0_u8; 4096];
3113 let read = socket.read(&mut buffer).await.unwrap();
3114 let request = String::from_utf8_lossy(&buffer[..read]).to_string();
3115
3116 if request.starts_with("GET /sse ") {
3117 let body = concat!(
3118 "event: endpoint\n",
3119 "data: /messages\n\n",
3120 "event: message\n",
3121 "data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}\n\n"
3122 );
3123 let response = format!(
3124 "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3125 body.len(),
3126 body
3127 );
3128 socket.write_all(response.as_bytes()).await.unwrap();
3129 } else {
3130 captured.lock().unwrap().push(request);
3131 socket
3132 .write_all(
3133 b"HTTP/1.1 202 Accepted\r\ncontent-length: 0\r\nconnection: close\r\n\r\n",
3134 )
3135 .await
3136 .unwrap();
3137 }
3138 }
3139 });
3140
3141 let factory =
3142 SseTransportFactory::new(SseTransportConfig::new(format!("http://{address}/sse")));
3143 let mut transport = factory.connect().await.unwrap();
3144 transport
3145 .send(McpFrame {
3146 value: json!({
3147 "jsonrpc": "2.0",
3148 "id": 1,
3149 "method": "tools/list",
3150 "params": {}
3151 }),
3152 })
3153 .await
3154 .unwrap();
3155 let frame = transport.recv().await.unwrap().unwrap();
3156 transport.close().await.unwrap();
3157 server.await.unwrap();
3158
3159 assert_eq!(frame.value["result"]["tools"], json!([]));
3160 let requests = requests.lock().unwrap();
3161 assert_eq!(requests.len(), 1);
3162 assert!(requests[0].starts_with("POST /messages "));
3163 assert!(requests[0].contains("\"method\":\"tools/list\""));
3164 }
3165
3166 #[tokio::test]
3167 async fn streamable_http_connection_tracks_session_and_protocol_headers() {
3168 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3169 let address = listener.local_addr().unwrap();
3170 let requests = StdArc::new(StdMutex::new(Vec::new()));
3171 let captured = requests.clone();
3172
3173 let server = tokio::spawn(async move {
3174 for _ in 0..4 {
3175 let (mut socket, _) = listener.accept().await.unwrap();
3176 let mut buffer = vec![0_u8; 8192];
3177 let read = socket.read(&mut buffer).await.unwrap();
3178 let request = String::from_utf8_lossy(&buffer[..read]).to_string();
3179 captured.lock().unwrap().push(request.clone());
3180
3181 let response = if request.contains("\"method\":\"initialize\"") {
3182 let body = "{\"jsonrpc\":\"2.0\",\"id\":0,\"result\":{\"protocolVersion\":\"2025-11-25\",\"capabilities\":{},\"serverInfo\":{\"name\":\"remote\",\"version\":\"1.0.0\"}}}";
3183 format!(
3184 "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\nMCP-Session-Id: session-123\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3185 body.len(),
3186 body
3187 )
3188 } else if request.contains("\"method\":\"notifications/initialized\"") {
3189 "HTTP/1.1 202 Accepted\r\ncontent-length: 0\r\nconnection: close\r\n\r\n"
3190 .to_string()
3191 } else if request.starts_with("DELETE /mcp ") {
3192 "HTTP/1.1 204 No Content\r\ncontent-length: 0\r\nconnection: close\r\n\r\n"
3193 .to_string()
3194 } else {
3195 let body = "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}";
3196 format!(
3197 "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3198 body.len(),
3199 body
3200 )
3201 };
3202
3203 socket.write_all(response.as_bytes()).await.unwrap();
3204 }
3205 });
3206
3207 let config = McpServerConfig::new(
3208 "remote",
3209 McpTransportBinding::StreamableHttp(StreamableHttpTransportConfig::new(format!(
3210 "http://{address}/mcp"
3211 ))),
3212 );
3213 let connection = McpConnection::connect(&config).await.unwrap();
3214 let _ = connection.list_tools().await.unwrap();
3215 connection.close().await.unwrap();
3216 server.await.unwrap();
3217
3218 let requests = requests.lock().unwrap();
3219 assert_eq!(requests.len(), 4);
3220 let normalized = requests
3221 .iter()
3222 .map(|request| request.to_ascii_lowercase())
3223 .collect::<Vec<_>>();
3224 assert!(requests[0].starts_with("POST /mcp "));
3225 assert!(!requests[0].contains("MCP-Session-Id:"));
3226 assert!(normalized[1].contains("mcp-session-id: session-123"));
3227 assert!(normalized[1].contains("mcp-protocol-version: 2025-11-25"));
3228 assert!(normalized[2].contains("mcp-session-id: session-123"));
3229 assert!(normalized[2].contains("mcp-protocol-version: 2025-11-25"));
3230 assert!(requests[3].starts_with("DELETE /mcp "));
3231 assert!(normalized[3].contains("mcp-session-id: session-123"));
3232 }
3233
3234 #[tokio::test]
3235 async fn streamable_http_transport_resumes_sse_streams_until_response_arrives() {
3236 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3237 let address = listener.local_addr().unwrap();
3238 let requests = StdArc::new(StdMutex::new(Vec::new()));
3239 let captured = requests.clone();
3240
3241 let server = tokio::spawn(async move {
3242 for _ in 0..2 {
3243 let (mut socket, _) = listener.accept().await.unwrap();
3244 let mut buffer = vec![0_u8; 8192];
3245 let read = socket.read(&mut buffer).await.unwrap();
3246 let request = String::from_utf8_lossy(&buffer[..read]).to_string();
3247 captured.lock().unwrap().push(request.clone());
3248
3249 let response = if request.starts_with("POST /mcp ") {
3250 let body = concat!(
3251 "id: evt-1\n",
3252 "event: message\n",
3253 "data: {\"jsonrpc\":\"2.0\",\"method\":\"notifications/message\",\"params\":{\"phase\":\"stream-start\"}}\n\n"
3254 );
3255 format!(
3256 "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3257 body.len(),
3258 body
3259 )
3260 } else {
3261 let body = concat!(
3262 "id: evt-2\n",
3263 "event: message\n",
3264 "data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}\n\n"
3265 );
3266 format!(
3267 "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3268 body.len(),
3269 body
3270 )
3271 };
3272
3273 socket.write_all(response.as_bytes()).await.unwrap();
3274 }
3275 });
3276
3277 let factory = StreamableHttpTransportFactory::new(StreamableHttpTransportConfig::new(
3278 format!("http://{address}/mcp"),
3279 ));
3280 let mut transport = factory.connect().await.unwrap();
3281 transport
3282 .send(McpFrame {
3283 value: json!({
3284 "jsonrpc": "2.0",
3285 "id": 1,
3286 "method": "tools/list",
3287 "params": {}
3288 }),
3289 })
3290 .await
3291 .unwrap();
3292
3293 let first = transport.recv().await.unwrap().unwrap();
3294 let second = transport.recv().await.unwrap().unwrap();
3295 transport.close().await.unwrap();
3296 server.await.unwrap();
3297
3298 assert_eq!(
3299 first.value["method"],
3300 Value::String("notifications/message".into())
3301 );
3302 assert_eq!(second.value["result"]["tools"], json!([]));
3303
3304 let requests = requests.lock().unwrap();
3305 assert_eq!(requests.len(), 2);
3306 assert!(requests[0].starts_with("POST /mcp "));
3307 assert!(requests[1].starts_with("GET /mcp "));
3308 assert!(
3309 requests[1].contains("last-event-id: evt-1")
3310 || requests[1].contains("Last-Event-ID: evt-1")
3311 );
3312 }
3313
3314 #[tokio::test]
3315 async fn server_manager_connects_refreshes_and_aggregates_tools() {
3316 let alpha = McpServerConfig::new(
3317 "alpha",
3318 McpTransportBinding::Custom(Arc::new(FakeTransportFactory::new(vec![vec![
3319 json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "alpha", "version": "1.0.0" } } }),
3320 json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [{ "name": "echo", "description": "Echo", "inputSchema": {"type": "object"} }] } }),
3321 json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
3322 json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
3323 json!({ "jsonrpc": "2.0", "id": 4, "result": { "tools": [{ "name": "echo_v2", "description": "Echo 2", "inputSchema": {"type": "object"} }] } }),
3324 json!({ "jsonrpc": "2.0", "id": 5, "result": { "resources": [] } }),
3325 json!({ "jsonrpc": "2.0", "id": 6, "result": { "prompts": [] } }),
3326 ]]))),
3327 );
3328 let beta = McpServerConfig::new(
3329 "beta",
3330 McpTransportBinding::Custom(Arc::new(FakeTransportFactory::new(vec![vec![
3331 json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "beta", "version": "1.0.0" } } }),
3332 json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [{ "name": "search", "description": "Search", "inputSchema": {"type": "object"} }] } }),
3333 json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
3334 json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
3335 ]]))),
3336 );
3337
3338 let mut manager = McpServerManager::new().with_server(alpha).with_server(beta);
3339
3340 let handles = manager.connect_all().await.unwrap();
3341 assert_eq!(handles.len(), 2);
3342 assert_eq!(
3343 manager
3344 .tool_registry()
3345 .specs()
3346 .into_iter()
3347 .map(|spec| spec.name.0)
3348 .collect::<Vec<_>>(),
3349 vec!["mcp_alpha_echo".to_string(), "mcp_beta_search".to_string()]
3350 );
3351
3352 let refreshed = manager
3353 .refresh_server(&McpServerId::new("alpha"))
3354 .await
3355 .unwrap();
3356 assert_eq!(refreshed.tools[0].name, "echo_v2");
3357 assert_eq!(
3358 manager
3359 .connected_server(&McpServerId::new("alpha"))
3360 .unwrap()
3361 .snapshot()
3362 .tools[0]
3363 .name,
3364 "echo_v2"
3365 );
3366
3367 let capabilities = manager.capability_provider();
3368 assert_eq!(capabilities.invocables().len(), 2);
3369
3370 manager
3371 .disconnect_server(&McpServerId::new("alpha"))
3372 .await
3373 .unwrap();
3374 assert!(
3375 manager
3376 .connected_server(&McpServerId::new("alpha"))
3377 .is_none()
3378 );
3379 }
3380}