1use std::collections::BTreeMap;
2use std::fmt;
3use std::process::Stdio;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6
7use agentkit_capabilities::{
8 CapabilityContext, CapabilityError, CapabilityName, CapabilityProvider, Invocable,
9 InvocableOutput, InvocableRequest, InvocableResult, InvocableSpec, PromptContents,
10 PromptDescriptor, PromptId, PromptProvider, ResourceContents, ResourceDescriptor, ResourceId,
11 ResourceProvider,
12};
13use agentkit_core::{
14 DataRef, Item, ItemKind, MetadataMap, Part, TextPart, ToolOutput, ToolResultPart,
15};
16use agentkit_tools_core::{
17 AuthOperation, AuthRequest, AuthResolution, Tool, ToolAnnotations, ToolContext, ToolError,
18 ToolName, ToolRegistry, ToolRequest, ToolResult, ToolSpec,
19};
20use async_trait::async_trait;
21use futures_util::TryStreamExt;
22use reqwest::{Client, Url};
23use serde::{Deserialize, Serialize};
24use serde_json::{Value, json};
25use thiserror::Error;
26use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWriteExt, BufReader};
27use tokio::process::{Child, ChildStdin, ChildStdout, Command};
28use tokio::sync::{Mutex, mpsc, oneshot};
29use tokio::task::JoinHandle;
30use tokio_util::io::StreamReader;
31
32#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
47pub struct McpServerId(pub String);
48
49impl McpServerId {
50 pub fn new(value: impl Into<String>) -> Self {
52 Self(value.into())
53 }
54}
55
56impl fmt::Display for McpServerId {
57 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58 self.0.fmt(f)
59 }
60}
61
/// Launch settings for an MCP server reached through a child process's
/// stdin/stdout pipes.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StdioTransportConfig {
    /// Program to execute.
    pub command: String,
    /// Arguments passed to `command`, in order.
    pub args: Vec<String>,
    /// Extra environment variables for the child, applied in insertion order.
    pub env: Vec<(String, String)>,
    /// Working directory for the child; `None` leaves it unset.
    pub cwd: Option<std::path::PathBuf>,
}

impl StdioTransportConfig {
    /// Creates a config that runs `command` with no arguments, no extra
    /// environment and no working-directory override.
    pub fn new(command: impl Into<String>) -> Self {
        let command: String = command.into();
        StdioTransportConfig {
            command,
            cwd: None,
            env: Vec::new(),
            args: Vec::new(),
        }
    }

    /// Appends one command-line argument (builder style).
    pub fn with_arg(self, arg: impl Into<String>) -> Self {
        let mut updated = self;
        updated.args.push(arg.into());
        updated
    }

    /// Appends one `key=value` environment variable (builder style).
    pub fn with_env(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let mut updated = self;
        let pair = (key.into(), value.into());
        updated.env.push(pair);
        updated
    }

    /// Sets the child's working directory (builder style).
    pub fn with_cwd(self, cwd: impl Into<std::path::PathBuf>) -> Self {
        let mut updated = self;
        updated.cwd = Some(cwd.into());
        updated
    }
}
120
/// Settings for an MCP server reached over HTTP server-sent events.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SseTransportConfig {
    /// URL of the SSE stream opened with `GET`.
    pub url: String,
    /// Extra HTTP headers sent on the stream request and on every `POST`.
    pub headers: Vec<(String, String)>,
}

impl SseTransportConfig {
    /// Creates a config for `url` with no extra headers.
    pub fn new(url: impl Into<String>) -> Self {
        let url: String = url.into();
        SseTransportConfig {
            headers: Vec::new(),
            url,
        }
    }

    /// Appends one HTTP header (builder style).
    pub fn with_header(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let mut updated = self;
        let header = (key.into(), value.into());
        updated.headers.push(header);
        updated
    }
}
158
159#[derive(Clone)]
166pub enum McpTransportBinding {
167 Stdio(StdioTransportConfig),
169 Sse(SseTransportConfig),
171 Custom(Arc<dyn McpTransportFactory>),
173}
174
175#[derive(Clone)]
196pub struct McpServerConfig {
197 pub id: McpServerId,
199 pub transport: McpTransportBinding,
201 pub metadata: MetadataMap,
203}
204
205impl McpServerConfig {
206 pub fn new(id: impl Into<String>, transport: McpTransportBinding) -> Self {
213 Self {
214 id: McpServerId::new(id),
215 transport,
216 metadata: MetadataMap::new(),
217 }
218 }
219
220 pub fn stdio(id: impl Into<String>, command: impl Into<String>) -> Self {
222 Self::new(
223 id,
224 McpTransportBinding::Stdio(StdioTransportConfig::new(command)),
225 )
226 }
227
228 pub fn sse(id: impl Into<String>, url: impl Into<String>) -> Self {
230 Self::new(id, McpTransportBinding::Sse(SseTransportConfig::new(url)))
231 }
232
233 pub fn with_metadata(mut self, metadata: MetadataMap) -> Self {
235 self.metadata = metadata;
236 self
237 }
238}
239
240#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
245pub struct McpFrame {
246 pub value: Value,
248}
249
250#[async_trait]
261pub trait McpTransportFactory: Send + Sync {
262 async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError>;
264}
265
266#[async_trait]
275pub trait McpTransport: Send + Sync {
276 async fn send(&mut self, message: McpFrame) -> Result<(), McpError>;
278 async fn recv(&mut self) -> Result<Option<McpFrame>, McpError>;
280 async fn close(&mut self) -> Result<(), McpError>;
282}
283
284pub struct StdioTransportFactory {
289 config: StdioTransportConfig,
290}
291
292impl StdioTransportFactory {
293 pub fn new(config: StdioTransportConfig) -> Self {
295 Self { config }
296 }
297}
298
299#[async_trait]
300impl McpTransportFactory for StdioTransportFactory {
301 async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
302 let mut command = Command::new(&self.config.command);
303 command.args(&self.config.args);
304 command.stdin(Stdio::piped());
305 command.stdout(Stdio::piped());
306 command.stderr(Stdio::inherit());
307
308 if let Some(cwd) = &self.config.cwd {
309 command.current_dir(cwd);
310 }
311
312 for (key, value) in &self.config.env {
313 command.env(key, value);
314 }
315
316 let mut child = command.spawn().map_err(McpError::Io)?;
317 let stdin = child
318 .stdin
319 .take()
320 .ok_or_else(|| McpError::Transport("failed to capture MCP stdin".into()))?;
321 let stdout = child
322 .stdout
323 .take()
324 .ok_or_else(|| McpError::Transport("failed to capture MCP stdout".into()))?;
325
326 Ok(Box::new(StdioTransport {
327 child,
328 stdin,
329 stdout: BufReader::new(stdout),
330 }))
331 }
332}
333
334pub struct SseTransportFactory {
339 config: SseTransportConfig,
340}
341
342impl SseTransportFactory {
343 pub fn new(config: SseTransportConfig) -> Self {
345 Self { config }
346 }
347}
348
349#[async_trait]
350impl McpTransportFactory for SseTransportFactory {
351 async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
352 let client = Client::builder()
353 .user_agent("agentkit-mcp/0.1.0")
354 .build()
355 .map_err(McpError::Http)?;
356
357 let mut request = client
358 .get(&self.config.url)
359 .header("Accept", "text/event-stream")
360 .header("Cache-Control", "no-cache");
361
362 for (key, value) in &self.config.headers {
363 request = request.header(key, value);
364 }
365
366 let response = request.send().await.map_err(McpError::Http)?;
367 let status = response.status();
368 if !status.is_success() {
369 let body = response
370 .text()
371 .await
372 .unwrap_or_else(|_| "<unreadable response body>".into());
373 return Err(McpError::Transport(format!(
374 "SSE connection failed with status {status}: {body}"
375 )));
376 }
377
378 let response_url = response.url().clone();
379 let stream = response.bytes_stream().map_err(std::io::Error::other);
380 let reader = BufReader::new(StreamReader::new(stream));
381 let (frame_tx, frame_rx) = mpsc::unbounded_channel();
382 let (endpoint_tx, endpoint_rx) = oneshot::channel();
383 let read_task = tokio::spawn(read_sse_stream(reader, response_url, frame_tx, endpoint_tx));
384
385 let endpoint_url = endpoint_rx
386 .await
387 .map_err(|_| McpError::Transport("SSE stream closed before endpoint event".into()))??;
388
389 Ok(Box::new(SseTransport {
390 client,
391 endpoint_url,
392 headers: self.config.headers.clone(),
393 frame_rx,
394 read_task,
395 }))
396 }
397}
398
399struct StdioTransport {
400 child: Child,
401 stdin: ChildStdin,
402 stdout: BufReader<ChildStdout>,
403}
404
405struct SseTransport {
406 client: Client,
407 endpoint_url: Url,
408 headers: Vec<(String, String)>,
409 frame_rx: mpsc::UnboundedReceiver<Result<McpFrame, McpError>>,
410 read_task: JoinHandle<()>,
411}
412
413#[async_trait]
414impl McpTransport for StdioTransport {
415 async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
416 let mut encoded = serde_json::to_vec(&message.value).map_err(McpError::Serialize)?;
417 encoded.push(b'\n');
418 self.stdin.write_all(&encoded).await.map_err(McpError::Io)?;
419 self.stdin.flush().await.map_err(McpError::Io)?;
420 Ok(())
421 }
422
423 async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
424 let mut line = String::new();
425 let read = self
426 .stdout
427 .read_line(&mut line)
428 .await
429 .map_err(McpError::Io)?;
430 if read == 0 {
431 return Ok(None);
432 }
433
434 let value = serde_json::from_str(line.trim()).map_err(McpError::Serialize)?;
435 Ok(Some(McpFrame { value }))
436 }
437
438 async fn close(&mut self) -> Result<(), McpError> {
439 let _ = self.stdin.shutdown().await;
440 let _ = self.child.kill().await;
441 Ok(())
442 }
443}
444
445#[async_trait]
446impl McpTransport for SseTransport {
447 async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
448 let mut request = self
449 .client
450 .post(self.endpoint_url.clone())
451 .header("Content-Type", "application/json");
452
453 for (key, value) in &self.headers {
454 request = request.header(key, value);
455 }
456
457 let response = request
458 .json(&message.value)
459 .send()
460 .await
461 .map_err(McpError::Http)?;
462 let status = response.status();
463 if !status.is_success() {
464 let body = response
465 .text()
466 .await
467 .unwrap_or_else(|_| "<unreadable response body>".into());
468 return Err(McpError::Transport(format!(
469 "SSE POST failed with status {status}: {body}"
470 )));
471 }
472
473 Ok(())
474 }
475
476 async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
477 match self.frame_rx.recv().await {
478 Some(Ok(frame)) => Ok(Some(frame)),
479 Some(Err(error)) => Err(error),
480 None => Ok(None),
481 }
482 }
483
484 async fn close(&mut self) -> Result<(), McpError> {
485 self.read_task.abort();
486 Ok(())
487 }
488}
489
490#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
496pub struct McpToolDescriptor {
497 pub name: String,
499 pub description: Option<String>,
501 pub input_schema: Value,
503 pub metadata: MetadataMap,
505}
506
507#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
512pub struct McpResourceDescriptor {
513 pub id: String,
515 pub name: String,
517 pub description: Option<String>,
519 pub mime_type: Option<String>,
521 pub metadata: MetadataMap,
523}
524
525#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
530pub struct McpPromptDescriptor {
531 pub id: String,
533 pub name: String,
535 pub description: Option<String>,
537 pub input_schema: Value,
539 pub metadata: MetadataMap,
541}
542
543#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
549pub struct McpDiscoverySnapshot {
550 pub server_id: McpServerId,
552 pub tools: Vec<McpToolDescriptor>,
554 pub resources: Vec<McpResourceDescriptor>,
556 pub prompts: Vec<McpPromptDescriptor>,
558 pub metadata: MetadataMap,
560}
561
562pub struct McpConnection {
592 server_id: McpServerId,
593 transport: Mutex<Box<dyn McpTransport>>,
594 auth: Mutex<Option<MetadataMap>>,
595 next_id: AtomicU64,
596}
597
598#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
603pub enum McpOperationResult {
604 Connected(McpDiscoverySnapshot),
606 Tool(Value),
608 Resource(ResourceContents),
610 Prompt(PromptContents),
612}
613
614impl McpConnection {
615 pub async fn connect(config: &McpServerConfig) -> Result<Self, McpError> {
623 Self::connect_with_auth(config, None).await
624 }
625
626 async fn connect_with_auth(
627 config: &McpServerConfig,
628 auth: Option<&MetadataMap>,
629 ) -> Result<Self, McpError> {
630 let factory: Arc<dyn McpTransportFactory> = match &config.transport {
631 McpTransportBinding::Stdio(binding) => {
632 Arc::new(StdioTransportFactory::new(binding.clone()))
633 }
634 McpTransportBinding::Sse(binding) => {
635 Arc::new(SseTransportFactory::new(binding.clone()))
636 }
637 McpTransportBinding::Custom(factory) => factory.clone(),
638 };
639
640 let mut transport = factory.connect().await?;
641 let mut params = serde_json::Map::new();
642 params.insert("protocolVersion".into(), Value::String("2024-11-05".into()));
643 params.insert("capabilities".into(), json!({}));
644 params.insert(
645 "clientInfo".into(),
646 json!({
647 "name": "agentkit-mcp",
648 "version": env!("CARGO_PKG_VERSION")
649 }),
650 );
651 if let Some(auth) = auth {
652 params.insert("auth".into(), metadata_to_value(auth));
653 }
654 let init_params = Value::Object(params.clone());
655 transport
656 .send(McpFrame {
657 value: json!({
658 "jsonrpc": "2.0",
659 "id": 0,
660 "method": "initialize",
661 "params": init_params.clone()
662 }),
663 })
664 .await?;
665 let init_response = transport.recv().await?.ok_or_else(|| {
666 McpError::Transport("transport closed during MCP initialization".into())
667 })?;
668 if let Some(error) = init_response.value.get("error") {
669 if let Some(auth_request) =
670 parse_auth_request(&config.id, "initialize", &init_params, error)
671 {
672 return Err(McpError::AuthRequired(Box::new(auth_request)));
673 }
674 return Err(McpError::Invocation(error.to_string()));
675 }
676 transport
677 .send(McpFrame {
678 value: json!({
679 "jsonrpc": "2.0",
680 "method": "notifications/initialized",
681 "params": {}
682 }),
683 })
684 .await?;
685
686 Ok(Self {
687 server_id: config.id.clone(),
688 transport: Mutex::new(transport),
689 auth: Mutex::new(auth.cloned()),
690 next_id: AtomicU64::new(1),
691 })
692 }
693
694 pub fn server_id(&self) -> &McpServerId {
696 &self.server_id
697 }
698
699 pub async fn close(&self) -> Result<(), McpError> {
705 let mut transport = self.transport.lock().await;
706 transport.close().await
707 }
708
709 pub async fn resolve_auth(&self, resolution: AuthResolution) -> Result<(), McpError> {
719 let mut auth = self.auth.lock().await;
720 match resolution {
721 AuthResolution::Provided { credentials, .. } => {
722 *auth = Some(credentials);
723 }
724 AuthResolution::Cancelled { .. } => {
725 *auth = None;
726 }
727 }
728 Ok(())
729 }
730
731 pub async fn discover(&self) -> Result<McpDiscoverySnapshot, McpError> {
739 Ok(McpDiscoverySnapshot {
740 server_id: self.server_id.clone(),
741 tools: self.list_tools().await?,
742 resources: self.list_resources().await?,
743 prompts: self.list_prompts().await?,
744 metadata: MetadataMap::new(),
745 })
746 }
747
748 pub async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>, McpError> {
754 let result = self.request("tools/list", json!({})).await?;
755 result
756 .get("tools")
757 .and_then(Value::as_array)
758 .cloned()
759 .unwrap_or_default()
760 .into_iter()
761 .map(parse_tool_descriptor)
762 .collect()
763 }
764
765 pub async fn list_resources(&self) -> Result<Vec<McpResourceDescriptor>, McpError> {
771 let result = self.request("resources/list", json!({})).await?;
772 result
773 .get("resources")
774 .and_then(Value::as_array)
775 .cloned()
776 .unwrap_or_default()
777 .into_iter()
778 .map(parse_resource_descriptor)
779 .collect()
780 }
781
782 pub async fn list_prompts(&self) -> Result<Vec<McpPromptDescriptor>, McpError> {
788 let result = self.request("prompts/list", json!({})).await?;
789 result
790 .get("prompts")
791 .and_then(Value::as_array)
792 .cloned()
793 .unwrap_or_default()
794 .into_iter()
795 .map(parse_prompt_descriptor)
796 .collect()
797 }
798
799 pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<Value, McpError> {
811 self.request(
812 "tools/call",
813 json!({
814 "name": name,
815 "arguments": arguments,
816 }),
817 )
818 .await
819 }
820
821 pub async fn read_resource(&self, uri: &str) -> Result<ResourceContents, McpError> {
831 let result = self
832 .request(
833 "resources/read",
834 json!({
835 "uri": uri,
836 }),
837 )
838 .await?;
839 let content = result
840 .get("contents")
841 .and_then(Value::as_array)
842 .and_then(|values| values.first())
843 .cloned()
844 .ok_or_else(|| McpError::Protocol("resources/read returned no contents".into()))?;
845
846 let data = if let Some(text) = content.get("text").and_then(Value::as_str) {
847 DataRef::InlineText(text.into())
848 } else if let Some(found_uri) = content.get("uri").and_then(Value::as_str) {
849 DataRef::Uri(found_uri.into())
850 } else {
851 return Err(McpError::Protocol(
852 "unsupported resource content shape".into(),
853 ));
854 };
855
856 Ok(ResourceContents {
857 data,
858 metadata: MetadataMap::new(),
859 })
860 }
861
862 pub async fn get_prompt(
873 &self,
874 name: &str,
875 arguments: Value,
876 ) -> Result<PromptContents, McpError> {
877 let result = self
878 .request(
879 "prompts/get",
880 json!({
881 "name": name,
882 "arguments": arguments,
883 }),
884 )
885 .await?;
886 let items = result
887 .get("messages")
888 .and_then(Value::as_array)
889 .cloned()
890 .unwrap_or_default()
891 .into_iter()
892 .map(parse_prompt_message)
893 .collect::<Result<Vec<_>, _>>()?;
894
895 Ok(PromptContents {
896 items,
897 metadata: MetadataMap::new(),
898 })
899 }
900
901 async fn request(&self, method: &str, params: Value) -> Result<Value, McpError> {
902 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
903 let params = self.enrich_params(params.clone()).await;
904 let mut transport = self.transport.lock().await;
905 transport
906 .send(McpFrame {
907 value: json!({
908 "jsonrpc": "2.0",
909 "id": id,
910 "method": method,
911 "params": params,
912 }),
913 })
914 .await?;
915
916 loop {
917 let Some(frame) = transport.recv().await? else {
918 return Err(McpError::Transport(
919 "transport closed while waiting for MCP response".into(),
920 ));
921 };
922
923 if frame.value.get("id").and_then(Value::as_u64) != Some(id) {
924 continue;
925 }
926
927 if let Some(error) = frame.value.get("error") {
928 if let Some(auth_request) =
929 parse_auth_request(&self.server_id, method, ¶ms, error)
930 {
931 return Err(McpError::AuthRequired(Box::new(auth_request)));
932 }
933 return Err(McpError::Invocation(error.to_string()));
934 }
935
936 return frame
937 .value
938 .get("result")
939 .cloned()
940 .ok_or_else(|| McpError::Protocol("MCP response missing result".into()));
941 }
942 }
943
944 async fn enrich_params(&self, params: Value) -> Value {
945 let auth = self.auth.lock().await;
946 let Some(auth) = auth.as_ref() else {
947 return params;
948 };
949
950 match params {
951 Value::Object(mut object) => {
952 object
953 .entry("auth")
954 .or_insert_with(|| metadata_to_value(auth));
955 Value::Object(object)
956 }
957 other => other,
958 }
959 }
960
961 pub async fn replay_auth_operation(
971 &self,
972 operation: &AuthOperation,
973 ) -> Result<McpOperationResult, McpError> {
974 match operation {
975 AuthOperation::McpToolCall {
976 server_id,
977 tool_name,
978 input,
979 ..
980 } => {
981 self.ensure_server_match(server_id)?;
982 self.call_tool(tool_name, input.clone())
983 .await
984 .map(McpOperationResult::Tool)
985 }
986 AuthOperation::McpResourceRead {
987 server_id,
988 resource_id,
989 ..
990 } => {
991 self.ensure_server_match(server_id)?;
992 self.read_resource(resource_id)
993 .await
994 .map(McpOperationResult::Resource)
995 }
996 AuthOperation::McpPromptGet {
997 server_id,
998 prompt_id,
999 args,
1000 ..
1001 } => {
1002 self.ensure_server_match(server_id)?;
1003 self.get_prompt(prompt_id, args.clone())
1004 .await
1005 .map(McpOperationResult::Prompt)
1006 }
1007 AuthOperation::ToolCall {
1008 tool_name,
1009 input,
1010 metadata,
1011 ..
1012 } => {
1013 if let Some(server_id) = metadata.get("server_id").and_then(Value::as_str) {
1014 self.ensure_server_match(server_id)?;
1015 }
1016 let tool_name = normalize_mcp_tool_name(self.server_id(), tool_name);
1017 self.call_tool(&tool_name, input.clone())
1018 .await
1019 .map(McpOperationResult::Tool)
1020 }
1021 AuthOperation::McpConnect { .. } => Err(McpError::AuthResolution(
1022 "connect operations must be replayed through the server manager".into(),
1023 )),
1024 AuthOperation::Custom { kind, .. } => Err(McpError::AuthResolution(format!(
1025 "unsupported auth operation for replay: {kind}"
1026 ))),
1027 }
1028 }
1029
1030 fn ensure_server_match(&self, server_id: &str) -> Result<(), McpError> {
1031 if self.server_id.0 == server_id {
1032 Ok(())
1033 } else {
1034 Err(McpError::AuthResolution(format!(
1035 "auth operation targets server {server_id}, but connection is for {}",
1036 self.server_id
1037 )))
1038 }
1039 }
1040}
1041
1042pub struct McpInvocable {
1047 connection: Arc<McpConnection>,
1048 descriptor: McpToolDescriptor,
1049 spec: InvocableSpec,
1050}
1051
1052impl McpInvocable {
1053 pub fn new(connection: Arc<McpConnection>, descriptor: McpToolDescriptor) -> Self {
1060 let spec = InvocableSpec {
1061 name: CapabilityName::new(format!(
1062 "mcp.{}.{}",
1063 connection.server_id(),
1064 descriptor.name
1065 )),
1066 description: descriptor
1067 .description
1068 .clone()
1069 .unwrap_or_else(|| descriptor.name.clone()),
1070 input_schema: descriptor.input_schema.clone(),
1071 metadata: descriptor.metadata.clone(),
1072 };
1073
1074 Self {
1075 connection,
1076 descriptor,
1077 spec,
1078 }
1079 }
1080}
1081
1082#[async_trait]
1083impl Invocable for McpInvocable {
1084 fn spec(&self) -> &InvocableSpec {
1085 &self.spec
1086 }
1087
1088 async fn invoke(
1089 &self,
1090 request: InvocableRequest,
1091 _ctx: &mut CapabilityContext<'_>,
1092 ) -> Result<InvocableResult, CapabilityError> {
1093 let result = self
1094 .connection
1095 .call_tool(&self.descriptor.name, request.input)
1096 .await
1097 .map_err(|error| match error {
1098 McpError::AuthRequired(request) => {
1099 CapabilityError::Unavailable(format!("auth required: {:?}", request))
1100 }
1101 other => CapabilityError::ExecutionFailed(other.to_string()),
1102 })?;
1103
1104 Ok(InvocableResult {
1105 output: value_to_invocable_output(result),
1106 metadata: MetadataMap::new(),
1107 })
1108 }
1109}
1110
1111pub struct McpResourceHandle {
1116 connection: Arc<McpConnection>,
1117 descriptor: ResourceDescriptor,
1118}
1119
1120#[async_trait]
1121impl ResourceProvider for McpResourceHandle {
1122 async fn list_resources(&self) -> Result<Vec<ResourceDescriptor>, CapabilityError> {
1123 Ok(vec![self.descriptor.clone()])
1124 }
1125
1126 async fn read_resource(
1127 &self,
1128 id: &ResourceId,
1129 _ctx: &mut CapabilityContext<'_>,
1130 ) -> Result<ResourceContents, CapabilityError> {
1131 self.connection
1132 .read_resource(&id.0)
1133 .await
1134 .map_err(|error| match error {
1135 McpError::AuthRequired(request) => {
1136 CapabilityError::Unavailable(format!("auth required: {:?}", request))
1137 }
1138 other => CapabilityError::ExecutionFailed(other.to_string()),
1139 })
1140 }
1141}
1142
1143pub struct McpPromptHandle {
1148 connection: Arc<McpConnection>,
1149 descriptor: PromptDescriptor,
1150}
1151
1152#[async_trait]
1153impl PromptProvider for McpPromptHandle {
1154 async fn list_prompts(&self) -> Result<Vec<PromptDescriptor>, CapabilityError> {
1155 Ok(vec![self.descriptor.clone()])
1156 }
1157
1158 async fn get_prompt(
1159 &self,
1160 id: &PromptId,
1161 args: Value,
1162 _ctx: &mut CapabilityContext<'_>,
1163 ) -> Result<PromptContents, CapabilityError> {
1164 self.connection
1165 .get_prompt(&id.0, args)
1166 .await
1167 .map_err(|error| match error {
1168 McpError::AuthRequired(request) => {
1169 CapabilityError::Unavailable(format!("auth required: {:?}", request))
1170 }
1171 other => CapabilityError::ExecutionFailed(other.to_string()),
1172 })
1173 }
1174}
1175
1176pub struct McpCapabilityProvider {
1203 invocables: Vec<Arc<dyn Invocable>>,
1204 resources: Vec<Arc<dyn ResourceProvider>>,
1205 prompts: Vec<Arc<dyn PromptProvider>>,
1206}
1207
1208impl McpCapabilityProvider {
1209 pub fn from_snapshot(connection: Arc<McpConnection>, snapshot: &McpDiscoverySnapshot) -> Self {
1215 let invocables = snapshot
1216 .tools
1217 .iter()
1218 .cloned()
1219 .map(|descriptor| {
1220 Arc::new(McpInvocable::new(connection.clone(), descriptor)) as Arc<dyn Invocable>
1221 })
1222 .collect();
1223
1224 let resources = snapshot
1225 .resources
1226 .iter()
1227 .cloned()
1228 .map(|descriptor| {
1229 Arc::new(McpResourceHandle {
1230 connection: connection.clone(),
1231 descriptor: ResourceDescriptor {
1232 id: ResourceId::new(descriptor.id),
1233 name: descriptor.name,
1234 description: descriptor.description,
1235 mime_type: descriptor.mime_type,
1236 metadata: descriptor.metadata,
1237 },
1238 }) as Arc<dyn ResourceProvider>
1239 })
1240 .collect();
1241
1242 let prompts = snapshot
1243 .prompts
1244 .iter()
1245 .cloned()
1246 .map(|descriptor| {
1247 Arc::new(McpPromptHandle {
1248 connection: connection.clone(),
1249 descriptor: PromptDescriptor {
1250 id: PromptId::new(descriptor.id),
1251 name: descriptor.name,
1252 description: descriptor.description,
1253 input_schema: descriptor.input_schema,
1254 metadata: descriptor.metadata,
1255 },
1256 }) as Arc<dyn PromptProvider>
1257 })
1258 .collect();
1259
1260 Self {
1261 invocables,
1262 resources,
1263 prompts,
1264 }
1265 }
1266
1267 pub fn merge<I>(providers: I) -> Self
1272 where
1273 I: IntoIterator<Item = Self>,
1274 {
1275 let mut invocables = Vec::new();
1276 let mut resources = Vec::new();
1277 let mut prompts = Vec::new();
1278
1279 for provider in providers {
1280 invocables.extend(provider.invocables);
1281 resources.extend(provider.resources);
1282 prompts.extend(provider.prompts);
1283 }
1284
1285 Self {
1286 invocables,
1287 resources,
1288 prompts,
1289 }
1290 }
1291
1292 pub async fn connect(
1301 config: &McpServerConfig,
1302 ) -> Result<(Arc<McpConnection>, Self, McpDiscoverySnapshot), McpError> {
1303 let connection = Arc::new(McpConnection::connect(config).await?);
1304 let snapshot = connection.discover().await?;
1305 let provider = Self::from_snapshot(connection.clone(), &snapshot);
1306
1307 Ok((connection, provider, snapshot))
1308 }
1309}
1310
1311impl CapabilityProvider for McpCapabilityProvider {
1312 fn invocables(&self) -> Vec<Arc<dyn Invocable>> {
1313 self.invocables.clone()
1314 }
1315
1316 fn resources(&self) -> Vec<Arc<dyn ResourceProvider>> {
1317 self.resources.clone()
1318 }
1319
1320 fn prompts(&self) -> Vec<Arc<dyn PromptProvider>> {
1321 self.prompts.clone()
1322 }
1323}
1324
1325#[derive(Clone)]
1331pub struct McpServerHandle {
1332 config: McpServerConfig,
1333 connection: Arc<McpConnection>,
1334 snapshot: McpDiscoverySnapshot,
1335}
1336
1337impl McpServerHandle {
1338 pub fn config(&self) -> &McpServerConfig {
1340 &self.config
1341 }
1342
1343 pub fn server_id(&self) -> &McpServerId {
1345 self.connection.server_id()
1346 }
1347
1348 pub fn connection(&self) -> Arc<McpConnection> {
1350 self.connection.clone()
1351 }
1352
1353 pub fn snapshot(&self) -> &McpDiscoverySnapshot {
1355 &self.snapshot
1356 }
1357
1358 pub fn tool_registry(&self) -> ToolRegistry {
1361 self.snapshot
1362 .tools
1363 .iter()
1364 .cloned()
1365 .fold(ToolRegistry::new(), |registry, descriptor| {
1366 registry.with(McpToolAdapter::new(
1367 self.server_id(),
1368 self.connection.clone(),
1369 descriptor,
1370 ))
1371 })
1372 }
1373
1374 pub fn capability_provider(&self) -> McpCapabilityProvider {
1376 McpCapabilityProvider::from_snapshot(self.connection.clone(), &self.snapshot)
1377 }
1378}
1379
1380#[derive(Default)]
1421pub struct McpServerManager {
1422 configs: BTreeMap<McpServerId, McpServerConfig>,
1423 connections: BTreeMap<McpServerId, McpServerHandle>,
1424 auth: BTreeMap<McpServerId, MetadataMap>,
1425}
1426
1427impl McpServerManager {
1428 pub fn new() -> Self {
1430 Self::default()
1431 }
1432
1433 pub fn with_server(mut self, config: McpServerConfig) -> Self {
1438 self.register_server(config);
1439 self
1440 }
1441
1442 pub fn register_server(&mut self, config: McpServerConfig) -> &mut Self {
1447 self.configs.insert(config.id.clone(), config);
1448 self
1449 }
1450
1451 pub fn connected_server(&self, server_id: &McpServerId) -> Option<&McpServerHandle> {
1453 self.connections.get(server_id)
1454 }
1455
1456 pub fn connected_servers(&self) -> Vec<&McpServerHandle> {
1458 self.connections.values().collect()
1459 }
1460
1461 pub async fn connect_server(
1470 &mut self,
1471 server_id: &McpServerId,
1472 ) -> Result<McpServerHandle, McpError> {
1473 let config = self
1474 .configs
1475 .get(server_id)
1476 .cloned()
1477 .ok_or_else(|| McpError::UnknownServer(server_id.to_string()))?;
1478 let connection =
1479 Arc::new(McpConnection::connect_with_auth(&config, self.auth.get(server_id)).await?);
1480 let snapshot = connection.discover().await?;
1481 let handle = McpServerHandle {
1482 config,
1483 connection,
1484 snapshot,
1485 };
1486 self.connections.insert(server_id.clone(), handle.clone());
1487 Ok(handle)
1488 }
1489
1490 pub async fn connect_all(&mut self) -> Result<Vec<McpServerHandle>, McpError> {
1500 let server_ids = self.configs.keys().cloned().collect::<Vec<_>>();
1501 let mut handles = Vec::with_capacity(server_ids.len());
1502
1503 for server_id in server_ids {
1504 handles.push(self.connect_server(&server_id).await?);
1505 }
1506
1507 Ok(handles)
1508 }
1509
1510 pub async fn refresh_server(
1520 &mut self,
1521 server_id: &McpServerId,
1522 ) -> Result<McpDiscoverySnapshot, McpError> {
1523 let handle = self
1524 .connections
1525 .get_mut(server_id)
1526 .ok_or_else(|| McpError::UnknownServer(server_id.to_string()))?;
1527 let snapshot = handle.connection.discover().await?;
1528 handle.snapshot = snapshot.clone();
1529 Ok(snapshot)
1530 }
1531
1532 pub async fn disconnect_server(&mut self, server_id: &McpServerId) -> Result<(), McpError> {
1541 let Some(handle) = self.connections.remove(server_id) else {
1542 return Err(McpError::UnknownServer(server_id.to_string()));
1543 };
1544 handle.connection.close().await
1545 }
1546
1547 pub async fn resolve_auth(&mut self, resolution: AuthResolution) -> Result<(), McpError> {
1555 let server_id = resolution
1556 .request()
1557 .server_id()
1558 .ok_or_else(|| McpError::AuthResolution("auth resolution missing server id".into()))?;
1559 let server_id = McpServerId::new(server_id);
1560 match &resolution {
1561 AuthResolution::Provided { credentials, .. } => {
1562 self.auth.insert(server_id.clone(), credentials.clone());
1563 }
1564 AuthResolution::Cancelled { .. } => {
1565 self.auth.remove(&server_id);
1566 }
1567 }
1568
1569 if let Some(handle) = self.connections.get(&server_id) {
1570 handle.connection.resolve_auth(resolution).await?;
1571 return Ok(());
1572 }
1573
1574 if self.configs.contains_key(&server_id) {
1575 Ok(())
1576 } else {
1577 Err(McpError::UnknownServer(server_id.to_string()))
1578 }
1579 }
1580
1581 pub async fn resolve_auth_and_resume(
1591 &mut self,
1592 resolution: AuthResolution,
1593 ) -> Result<McpOperationResult, McpError> {
1594 let request = resolution.request().clone();
1595 self.resolve_auth(resolution).await?;
1596 self.replay_auth_request(&request).await
1597 }
1598
1599 pub async fn replay_auth_request(
1609 &mut self,
1610 request: &AuthRequest,
1611 ) -> Result<McpOperationResult, McpError> {
1612 match &request.operation {
1613 AuthOperation::McpConnect { server_id, .. } => {
1614 let server_id = McpServerId::new(server_id);
1615 let handle = self.connect_server(&server_id).await?;
1616 Ok(McpOperationResult::Connected(handle.snapshot.clone()))
1617 }
1618 AuthOperation::McpToolCall { server_id, .. }
1619 | AuthOperation::McpResourceRead { server_id, .. }
1620 | AuthOperation::McpPromptGet { server_id, .. } => {
1621 let connection = self.connection_for_auth_server(server_id).await?;
1622 connection.replay_auth_operation(&request.operation).await
1623 }
1624 AuthOperation::ToolCall { metadata, .. } => {
1625 let server_id = metadata
1626 .get("server_id")
1627 .and_then(Value::as_str)
1628 .ok_or_else(|| {
1629 McpError::AuthResolution(
1630 "tool-call auth replay requires metadata.server_id".into(),
1631 )
1632 })?;
1633 let connection = self.connection_for_auth_server(server_id).await?;
1634 connection.replay_auth_operation(&request.operation).await
1635 }
1636 AuthOperation::Custom { kind, .. } => Err(McpError::AuthResolution(format!(
1637 "unsupported auth operation for replay: {kind}"
1638 ))),
1639 }
1640 }
1641
1642 async fn connection_for_auth_server(
1643 &mut self,
1644 server_id: &str,
1645 ) -> Result<Arc<McpConnection>, McpError> {
1646 let server_id = McpServerId::new(server_id);
1647 if !self.connections.contains_key(&server_id) {
1648 self.connect_server(&server_id).await?;
1649 }
1650 self.connections
1651 .get(&server_id)
1652 .map(McpServerHandle::connection)
1653 .ok_or_else(|| McpError::UnknownServer(server_id.to_string()))
1654 }
1655
1656 pub fn tool_registry(&self) -> ToolRegistry {
1661 self.connections
1662 .values()
1663 .fold(ToolRegistry::new(), |mut registry, handle| {
1664 for tool in handle.snapshot.tools.iter().cloned() {
1665 registry.register(McpToolAdapter::new(
1666 handle.server_id(),
1667 handle.connection.clone(),
1668 tool,
1669 ));
1670 }
1671 registry
1672 })
1673 }
1674
1675 pub fn capability_provider(&self) -> McpCapabilityProvider {
1678 McpCapabilityProvider::merge(
1679 self.connections
1680 .values()
1681 .map(McpServerHandle::capability_provider),
1682 )
1683 }
1684}
1685
/// Exposes a single MCP tool as an agentkit [`Tool`].
///
/// The tool is registered under the namespaced name `mcp.<server_id>.<tool_name>`
/// (see [`McpToolAdapter::new`]), while invocations are forwarded to the server
/// using the original, unprefixed MCP tool name.
pub struct McpToolAdapter {
    /// Descriptor as advertised by the server; `descriptor.name` is the raw MCP
    /// tool name passed to `call_tool`.
    descriptor: McpToolDescriptor,
    /// Shared connection used to issue the call.
    connection: Arc<McpConnection>,
    /// Precomputed spec returned from [`Tool::spec`].
    spec: ToolSpec,
}
1706
1707impl McpToolAdapter {
1708 pub fn new(
1716 server_id: &McpServerId,
1717 connection: Arc<McpConnection>,
1718 descriptor: McpToolDescriptor,
1719 ) -> Self {
1720 let spec = ToolSpec {
1721 name: ToolName::new(format!("mcp.{}.{}", server_id, descriptor.name)),
1722 description: descriptor
1723 .description
1724 .clone()
1725 .unwrap_or_else(|| descriptor.name.clone()),
1726 input_schema: descriptor.input_schema.clone(),
1727 annotations: ToolAnnotations::default(),
1728 metadata: descriptor.metadata.clone(),
1729 };
1730
1731 Self {
1732 descriptor,
1733 connection,
1734 spec,
1735 }
1736 }
1737}
1738
1739#[async_trait]
1740impl Tool for McpToolAdapter {
1741 fn spec(&self) -> &ToolSpec {
1742 &self.spec
1743 }
1744
1745 async fn invoke(
1746 &self,
1747 request: ToolRequest,
1748 _ctx: &mut ToolContext<'_>,
1749 ) -> Result<ToolResult, ToolError> {
1750 let result = self
1751 .connection
1752 .call_tool(&self.descriptor.name, request.input)
1753 .await
1754 .map_err(|error| match error {
1755 McpError::AuthRequired(request) => ToolError::AuthRequired(request),
1756 other => ToolError::ExecutionFailed(other.to_string()),
1757 })?;
1758
1759 Ok(ToolResult {
1760 result: ToolResultPart {
1761 call_id: request.call_id,
1762 output: invocable_output_to_tool_output(value_to_invocable_output(result)),
1763 is_error: false,
1764 metadata: MetadataMap::new(),
1765 },
1766 duration: None,
1767 metadata: MetadataMap::new(),
1768 })
1769 }
1770}
1771
1772fn parse_tool_descriptor(value: Value) -> Result<McpToolDescriptor, McpError> {
1773 Ok(McpToolDescriptor {
1774 name: required_string(&value, "name")?,
1775 description: value
1776 .get("description")
1777 .and_then(Value::as_str)
1778 .map(str::to_owned),
1779 input_schema: value
1780 .get("inputSchema")
1781 .cloned()
1782 .unwrap_or_else(|| json!({ "type": "object" })),
1783 metadata: MetadataMap::new(),
1784 })
1785}
1786
1787fn parse_resource_descriptor(value: Value) -> Result<McpResourceDescriptor, McpError> {
1788 Ok(McpResourceDescriptor {
1789 id: required_string(&value, "uri")?,
1790 name: value
1791 .get("name")
1792 .and_then(Value::as_str)
1793 .map(str::to_owned)
1794 .unwrap_or_else(|| {
1795 value
1796 .get("uri")
1797 .and_then(Value::as_str)
1798 .unwrap_or_default()
1799 .to_string()
1800 }),
1801 description: value
1802 .get("description")
1803 .and_then(Value::as_str)
1804 .map(str::to_owned),
1805 mime_type: value
1806 .get("mimeType")
1807 .and_then(Value::as_str)
1808 .map(str::to_owned),
1809 metadata: MetadataMap::new(),
1810 })
1811}
1812
1813fn parse_prompt_descriptor(value: Value) -> Result<McpPromptDescriptor, McpError> {
1814 let name = required_string(&value, "name")?;
1815 let properties = value
1816 .get("arguments")
1817 .and_then(Value::as_array)
1818 .cloned()
1819 .unwrap_or_default()
1820 .into_iter()
1821 .filter_map(|arg| {
1822 let name = arg.get("name")?.as_str()?.to_string();
1823 Some((name, json!({ "type": "string" })))
1824 })
1825 .collect::<serde_json::Map<String, Value>>();
1826
1827 Ok(McpPromptDescriptor {
1828 id: name.clone(),
1829 name,
1830 description: value
1831 .get("description")
1832 .and_then(Value::as_str)
1833 .map(str::to_owned),
1834 input_schema: json!({
1835 "type": "object",
1836 "properties": properties,
1837 }),
1838 metadata: MetadataMap::new(),
1839 })
1840}
1841
1842fn parse_prompt_message(value: Value) -> Result<Item, McpError> {
1843 let role = value.get("role").and_then(Value::as_str).unwrap_or("user");
1844 let kind = match role {
1845 "assistant" => ItemKind::Assistant,
1846 "system" => ItemKind::System,
1847 _ => ItemKind::User,
1848 };
1849
1850 let content = value.get("content").cloned().unwrap_or(Value::Null);
1851 let text = if let Some(text) = content.get("text").and_then(Value::as_str) {
1852 text.to_string()
1853 } else if let Some(text) = content.as_str() {
1854 text.to_string()
1855 } else {
1856 content.to_string()
1857 };
1858
1859 Ok(Item {
1860 id: None,
1861 kind,
1862 parts: vec![Part::Text(TextPart {
1863 text,
1864 metadata: MetadataMap::new(),
1865 })],
1866 metadata: MetadataMap::new(),
1867 })
1868}
1869
1870fn required_string(value: &Value, field: &str) -> Result<String, McpError> {
1871 value
1872 .get(field)
1873 .and_then(Value::as_str)
1874 .map(str::to_owned)
1875 .ok_or_else(|| McpError::Protocol(format!("missing string field {field}")))
1876}
1877
1878fn value_to_invocable_output(value: Value) -> InvocableOutput {
1879 if let Some(content) = value.get("content").and_then(Value::as_array) {
1880 let text = content
1881 .iter()
1882 .filter_map(|item| item.get("text").and_then(Value::as_str))
1883 .collect::<Vec<_>>()
1884 .join("\n");
1885 if !text.is_empty() {
1886 return InvocableOutput::Text(text);
1887 }
1888 }
1889
1890 if let Some(text) = value.as_str() {
1891 InvocableOutput::Text(text.to_string())
1892 } else {
1893 InvocableOutput::Structured(value)
1894 }
1895}
1896
1897fn invocable_output_to_tool_output(output: InvocableOutput) -> ToolOutput {
1898 match output {
1899 InvocableOutput::Text(text) => ToolOutput::Text(text),
1900 InvocableOutput::Structured(value) => ToolOutput::Structured(value),
1901 InvocableOutput::Items(items) => {
1902 ToolOutput::Parts(items.into_iter().flat_map(|item| item.parts).collect())
1903 }
1904 InvocableOutput::Data(data) => ToolOutput::Structured(json!({ "data": data })),
1905 }
1906}
1907
1908fn metadata_to_value(metadata: &MetadataMap) -> Value {
1909 Value::Object(
1910 metadata
1911 .iter()
1912 .map(|(key, value)| (key.clone(), value.clone()))
1913 .collect(),
1914 )
1915}
1916
1917fn parse_auth_request(
1918 server_id: &McpServerId,
1919 method: &str,
1920 params: &Value,
1921 error: &Value,
1922) -> Option<AuthRequest> {
1923 let code = error.get("code").and_then(Value::as_i64);
1924 let message = error.get("message").and_then(Value::as_str);
1925 let data = error.get("data");
1926
1927 let auth_marker = matches!(code, Some(401 | -32001))
1928 || data
1929 .and_then(|data| data.get("auth_required"))
1930 .and_then(Value::as_bool)
1931 == Some(true)
1932 || data.and_then(|data| data.get("auth")).is_some();
1933
1934 if !auth_marker {
1935 return None;
1936 }
1937
1938 let mut challenge = MetadataMap::new();
1939 challenge.insert("server_id".into(), Value::String(server_id.to_string()));
1940 challenge.insert("method".into(), Value::String(method.into()));
1941
1942 if let Some(code) = code {
1943 challenge.insert("code".into(), Value::Number(code.into()));
1944 }
1945 if let Some(message) = message {
1946 challenge.insert("message".into(), Value::String(message.into()));
1947 }
1948 if let Some(data) = data {
1949 challenge.insert("data".into(), data.clone());
1950 }
1951
1952 Some(AuthRequest {
1953 task_id: None,
1954 id: format!("mcp:{}:{}", server_id, method),
1955 provider: format!("mcp.{}", server_id),
1956 operation: auth_operation_for_method(server_id, method, params),
1957 challenge,
1958 })
1959}
1960
1961fn auth_operation_for_method(
1962 server_id: &McpServerId,
1963 method: &str,
1964 params: &Value,
1965) -> AuthOperation {
1966 match method {
1967 "initialize" => AuthOperation::McpConnect {
1968 server_id: server_id.to_string(),
1969 metadata: MetadataMap::new(),
1970 },
1971 "tools/call" => AuthOperation::McpToolCall {
1972 server_id: server_id.to_string(),
1973 tool_name: params
1974 .get("name")
1975 .and_then(Value::as_str)
1976 .unwrap_or_default()
1977 .to_string(),
1978 input: params
1979 .get("arguments")
1980 .cloned()
1981 .unwrap_or_else(|| json!({})),
1982 metadata: MetadataMap::new(),
1983 },
1984 "resources/read" => AuthOperation::McpResourceRead {
1985 server_id: server_id.to_string(),
1986 resource_id: params
1987 .get("uri")
1988 .and_then(Value::as_str)
1989 .unwrap_or_default()
1990 .to_string(),
1991 metadata: MetadataMap::new(),
1992 },
1993 "prompts/get" => AuthOperation::McpPromptGet {
1994 server_id: server_id.to_string(),
1995 prompt_id: params
1996 .get("name")
1997 .and_then(Value::as_str)
1998 .unwrap_or_default()
1999 .to_string(),
2000 args: params
2001 .get("arguments")
2002 .cloned()
2003 .unwrap_or_else(|| json!({})),
2004 metadata: MetadataMap::new(),
2005 },
2006 other => AuthOperation::Custom {
2007 kind: format!("mcp.{other}"),
2008 payload: params.clone(),
2009 metadata: {
2010 let mut metadata = MetadataMap::new();
2011 metadata.insert("server_id".into(), Value::String(server_id.to_string()));
2012 metadata
2013 },
2014 },
2015 }
2016}
2017
2018fn normalize_mcp_tool_name(server_id: &McpServerId, tool_name: &str) -> String {
2019 let prefix = format!("mcp.{server_id}.");
2020 tool_name
2021 .strip_prefix(&prefix)
2022 .unwrap_or(tool_name)
2023 .to_string()
2024}
2025
2026async fn read_sse_stream<R>(
2027 mut reader: R,
2028 response_url: Url,
2029 frame_tx: mpsc::UnboundedSender<Result<McpFrame, McpError>>,
2030 endpoint_tx: oneshot::Sender<Result<Url, McpError>>,
2031) where
2032 R: AsyncBufRead + Unpin,
2033{
2034 let mut endpoint_tx = Some(endpoint_tx);
2035 let mut event_name: Option<String> = None;
2036 let mut data_lines = Vec::new();
2037
2038 loop {
2039 let mut line = String::new();
2040 match reader.read_line(&mut line).await {
2041 Ok(0) => break,
2042 Ok(_) => {
2043 let line = line.trim_end_matches(['\r', '\n']);
2044 if line.is_empty() {
2045 dispatch_sse_event(
2046 &response_url,
2047 &mut endpoint_tx,
2048 &frame_tx,
2049 event_name.take(),
2050 std::mem::take(&mut data_lines),
2051 );
2052 continue;
2053 }
2054 if line.starts_with(':') {
2055 continue;
2056 }
2057 if let Some(rest) = line.strip_prefix("event:") {
2058 event_name = Some(rest.trim_start().to_string());
2059 continue;
2060 }
2061 if let Some(rest) = line.strip_prefix("data:") {
2062 data_lines.push(rest.trim_start().to_string());
2063 }
2064 }
2065 Err(error) => {
2066 let error = McpError::Io(error);
2067 if let Some(tx) = endpoint_tx.take() {
2068 let _ = tx.send(Err(error));
2069 } else {
2070 let _ = frame_tx.send(Err(error));
2071 }
2072 return;
2073 }
2074 }
2075 }
2076
2077 if event_name.is_some() || !data_lines.is_empty() {
2078 dispatch_sse_event(
2079 &response_url,
2080 &mut endpoint_tx,
2081 &frame_tx,
2082 event_name.take(),
2083 std::mem::take(&mut data_lines),
2084 );
2085 }
2086
2087 if let Some(tx) = endpoint_tx.take() {
2088 let _ = tx.send(Err(McpError::Transport(
2089 "SSE stream ended before endpoint event".into(),
2090 )));
2091 }
2092}
2093
2094fn dispatch_sse_event(
2095 response_url: &Url,
2096 endpoint_tx: &mut Option<oneshot::Sender<Result<Url, McpError>>>,
2097 frame_tx: &mpsc::UnboundedSender<Result<McpFrame, McpError>>,
2098 event_name: Option<String>,
2099 data_lines: Vec<String>,
2100) {
2101 if data_lines.is_empty() {
2102 return;
2103 }
2104
2105 let event_name = event_name.unwrap_or_else(|| "message".into());
2106 let data = data_lines.join("\n");
2107
2108 if event_name == "endpoint" {
2109 if let Some(tx) = endpoint_tx.take() {
2110 let _ = tx.send(resolve_sse_endpoint(response_url, &data));
2111 }
2112 return;
2113 }
2114
2115 if event_name != "message" {
2116 return;
2117 }
2118
2119 let value = serde_json::from_str(&data).map_err(McpError::Serialize);
2120 let _ = frame_tx.send(value.map(|value| McpFrame { value }));
2121}
2122
2123fn resolve_sse_endpoint(response_url: &Url, endpoint: &str) -> Result<Url, McpError> {
2124 response_url
2125 .join(endpoint.trim())
2126 .map_err(|error| McpError::Transport(format!("invalid SSE endpoint URL: {error}")))
2127}
2128
/// Errors produced by the MCP client: I/O, HTTP and serialization failures,
/// transport and protocol problems, authentication interrupts, and server lookups.
#[derive(Debug, Error)]
pub enum McpError {
    /// Underlying I/O failure (e.g. reading a line from an SSE body).
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
    /// HTTP client failure reported by `reqwest`.
    #[error("http error: {0}")]
    Http(#[from] reqwest::Error),
    /// JSON encoding/decoding failure, such as an SSE `message` payload that is
    /// not valid JSON.
    #[error("serialization error: {0}")]
    Serialize(#[from] serde_json::Error),
    /// Transport-level failure described by a message, e.g. an SSE stream that
    /// ended before its `endpoint` event or an unusable endpoint URL.
    #[error("transport error: {0}")]
    Transport(String),
    /// The peer sent data that does not match the expected MCP shape, e.g. a
    /// descriptor missing a required string field.
    #[error("protocol error: {0}")]
    Protocol(String),
    /// The server requires authentication. The boxed request records the
    /// challenge and the operation to replay once credentials are resolved.
    #[error("MCP auth required: {0:?}")]
    AuthRequired(Box<AuthRequest>),
    /// An auth resolution or replay could not be carried out, e.g. an unsupported
    /// operation kind or a tool call missing `metadata.server_id`.
    #[error("auth resolution error: {0}")]
    AuthResolution(String),
    /// Invocation failure described by a message.
    // NOTE(review): not constructed in this part of the file — confirm where it is raised.
    #[error("invocation error: {0}")]
    Invocation(String),
    /// The named server has no configuration or live connection in the manager.
    #[error("unknown MCP server: {0}")]
    UnknownServer(String),
}
2161
#[cfg(test)]
mod tests {
    //! Unit tests driven by in-memory fake transports, plus one loopback HTTP test
    //! for the SSE transport.

    use std::collections::VecDeque;
    use std::sync::{Arc as StdArc, Mutex as StdMutex};

    use super::*;
    use agentkit_tools_core::{PermissionChecker, PermissionDecision, PermissionRequest};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    /// Permission checker that approves every request.
    struct AllowAll;

    impl PermissionChecker for AllowAll {
        fn evaluate(&self, _request: &dyn PermissionRequest) -> PermissionDecision {
            PermissionDecision::Allow
        }
    }

    /// Transport that discards sends and replays a fixed queue of responses.
    struct FakeTransport {
        recv: VecDeque<Value>,
    }

    #[async_trait]
    impl McpTransport for FakeTransport {
        async fn send(&mut self, _message: McpFrame) -> Result<(), McpError> {
            Ok(())
        }

        async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
            Ok(self.recv.pop_front().map(|value| McpFrame { value }))
        }

        async fn close(&mut self) -> Result<(), McpError> {
            Ok(())
        }
    }

    fn fake_connection(responses: Vec<Value>) -> McpConnection {
        McpConnection {
            server_id: McpServerId::new("fake"),
            transport: Mutex::new(Box::new(FakeTransport {
                recv: responses.into(),
            })),
            auth: Mutex::new(None),
            next_id: AtomicU64::new(1),
        }
    }

    /// Factory that hands out one scripted response sequence per `connect` call.
    #[derive(Clone)]
    struct FakeTransportFactory {
        responses: StdArc<StdMutex<VecDeque<Vec<Value>>>>,
    }

    impl FakeTransportFactory {
        fn new(sequences: Vec<Vec<Value>>) -> Self {
            Self {
                responses: StdArc::new(StdMutex::new(sequences.into())),
            }
        }
    }

    #[async_trait]
    impl McpTransportFactory for FakeTransportFactory {
        async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
            let responses =
                self.responses.lock().unwrap().pop_front().ok_or_else(|| {
                    McpError::Transport("no fake transport responses left".into())
                })?;
            Ok(Box::new(FakeTransport {
                recv: responses.into(),
            }))
        }
    }

    /// Reads one HTTP/1.1 request from `socket`: the head plus a `content-length` body.
    ///
    /// A single `read` may return only part of the request (clients can flush the
    /// headers and the body separately), so keep reading until the declared body has
    /// arrived or the peer closes the connection.
    async fn read_http_request(socket: &mut tokio::net::TcpStream) -> String {
        let mut data: Vec<u8> = Vec::new();
        let mut chunk = [0_u8; 4096];
        loop {
            if let Some(head_end) = data.windows(4).position(|window| window == b"\r\n\r\n") {
                let head = String::from_utf8_lossy(&data[..head_end]);
                let content_length = head
                    .lines()
                    .filter_map(|line| line.split_once(':'))
                    .find(|(name, _)| name.trim().eq_ignore_ascii_case("content-length"))
                    .and_then(|(_, value)| value.trim().parse::<usize>().ok())
                    .unwrap_or(0);
                if data.len() >= head_end + 4 + content_length {
                    break;
                }
            }
            let read = socket.read(&mut chunk).await.unwrap();
            if read == 0 {
                break;
            }
            data.extend_from_slice(&chunk[..read]);
        }
        String::from_utf8_lossy(&data).into_owned()
    }

    #[tokio::test]
    async fn discovery_parses_snapshot() {
        let connection = fake_connection(vec![
            json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [{ "name": "echo", "description": "Echo", "inputSchema": {"type": "object"} }] } }),
            json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [{ "uri": "file:///tmp/example.txt", "name": "example.txt", "mimeType": "text/plain" }] } }),
            json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [{ "name": "summarize", "description": "Summarize", "arguments": [{ "name": "path" }] }] } }),
        ]);

        let snapshot = connection.discover().await.unwrap();
        assert_eq!(snapshot.tools[0].name, "echo");
        assert_eq!(snapshot.resources[0].id, "file:///tmp/example.txt");
        assert_eq!(snapshot.prompts[0].id, "summarize");
    }

    #[tokio::test]
    async fn tool_adapter_returns_text_output() {
        let connection = Arc::new(fake_connection(vec![json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": { "content": [{ "type": "text", "text": "pong" }] }
        })]));
        let server_id = connection.server_id().clone();
        let adapter = McpToolAdapter::new(
            &server_id,
            connection,
            McpToolDescriptor {
                name: "echo".into(),
                description: Some("Echo".into()),
                input_schema: json!({ "type": "object" }),
                metadata: MetadataMap::new(),
            },
        );
        let metadata = MetadataMap::new();
        let mut ctx = ToolContext {
            capability: CapabilityContext {
                session_id: None,
                turn_id: None,
                metadata: &metadata,
            },
            permissions: &AllowAll,
            resources: &(),
            cancellation: None,
        };

        let result = adapter
            .invoke(
                ToolRequest {
                    call_id: "call-1".into(),
                    tool_name: ToolName::new("mcp.fake.echo"),
                    input: json!({}),
                    session_id: "session-1".into(),
                    turn_id: "turn-1".into(),
                    metadata: MetadataMap::new(),
                },
                &mut ctx,
            )
            .await
            .unwrap();

        assert_eq!(result.result.output, ToolOutput::Text("pong".into()));
    }

    #[tokio::test]
    async fn request_surfaces_auth_required_errors() {
        let connection = fake_connection(vec![json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {
                "code": -32001,
                "message": "authentication required",
                "data": {
                    "auth_required": true,
                    "scope": "secrets.read"
                }
            }
        })]);

        let error = connection.call_tool("echo", json!({})).await.unwrap_err();
        match error {
            McpError::AuthRequired(request) => {
                assert_eq!(request.provider, "mcp.fake");
                assert_eq!(
                    request.challenge.get("method"),
                    Some(&Value::String("tools/call".into()))
                );
                assert!(matches!(
                    request.operation,
                    AuthOperation::McpToolCall { ref tool_name, .. } if tool_name == "echo"
                ));
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[tokio::test]
    async fn tool_adapter_maps_auth_required_into_tool_error() {
        let connection = Arc::new(fake_connection(vec![json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {
                "code": -32001,
                "message": "authentication required",
                "data": { "auth_required": true }
            }
        })]));
        let server_id = connection.server_id().clone();
        let adapter = McpToolAdapter::new(
            &server_id,
            connection,
            McpToolDescriptor {
                name: "echo".into(),
                description: Some("Echo".into()),
                input_schema: json!({ "type": "object" }),
                metadata: MetadataMap::new(),
            },
        );
        let metadata = MetadataMap::new();
        let mut ctx = ToolContext {
            capability: CapabilityContext {
                session_id: None,
                turn_id: None,
                metadata: &metadata,
            },
            permissions: &AllowAll,
            resources: &(),
            cancellation: None,
        };

        let error = adapter
            .invoke(
                ToolRequest {
                    call_id: "call-1".into(),
                    tool_name: ToolName::new("mcp.fake.echo"),
                    input: json!({}),
                    session_id: "session-1".into(),
                    turn_id: "turn-1".into(),
                    metadata: MetadataMap::new(),
                },
                &mut ctx,
            )
            .await
            .unwrap_err();

        match error {
            ToolError::AuthRequired(request) => {
                assert_eq!(request.provider, "mcp.fake");
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    /// Transport that replays scripted responses and records every frame sent.
    struct RecordingTransport {
        recv: VecDeque<Value>,
        sent: StdArc<StdMutex<Vec<Value>>>,
    }

    #[async_trait]
    impl McpTransport for RecordingTransport {
        async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
            self.sent.lock().unwrap().push(message.value);
            Ok(())
        }

        async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
            Ok(self.recv.pop_front().map(|value| McpFrame { value }))
        }

        async fn close(&mut self) -> Result<(), McpError> {
            Ok(())
        }
    }

    /// Factory whose transports share one log of sent frames for later inspection.
    #[derive(Clone)]
    struct RecordingTransportFactory {
        responses: StdArc<StdMutex<VecDeque<Vec<Value>>>>,
        sent: StdArc<StdMutex<Vec<Value>>>,
    }

    impl RecordingTransportFactory {
        fn new(sequences: Vec<Vec<Value>>) -> Self {
            Self {
                responses: StdArc::new(StdMutex::new(sequences.into())),
                sent: StdArc::new(StdMutex::new(Vec::new())),
            }
        }

        fn sent(&self) -> Vec<Value> {
            self.sent.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl McpTransportFactory for RecordingTransportFactory {
        async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
            let responses = self.responses.lock().unwrap().pop_front().ok_or_else(|| {
                McpError::Transport("no recording transport responses left".into())
            })?;
            Ok(Box::new(RecordingTransport {
                recv: responses.into(),
                sent: self.sent.clone(),
            }))
        }
    }

    #[tokio::test]
    async fn connection_includes_resolved_auth_in_future_requests() {
        let factory = RecordingTransportFactory::new(vec![vec![
            json!({ "jsonrpc": "2.0", "id": 0, "result": { "capabilities": {} } }),
            json!({ "jsonrpc": "2.0", "id": 1, "result": { "content": [{ "type": "text", "text": "ok" }] } }),
        ]]);
        let config = McpServerConfig::new(
            "recording",
            McpTransportBinding::Custom(Arc::new(factory.clone())),
        );
        let connection = McpConnection::connect(&config).await.unwrap();
        let mut auth = MetadataMap::new();
        auth.insert("token".into(), json!("secret-token"));
        let request = AuthRequest {
            task_id: None,
            id: "auth-recording-tool".into(),
            provider: "mcp.recording".into(),
            operation: AuthOperation::McpToolCall {
                server_id: "recording".into(),
                tool_name: "echo".into(),
                input: json!({}),
                metadata: MetadataMap::new(),
            },
            challenge: MetadataMap::new(),
        };
        connection
            .resolve_auth(agentkit_tools_core::AuthResolution::Provided {
                request,
                credentials: auth,
            })
            .await
            .unwrap();

        let _ = connection.call_tool("echo", json!({})).await.unwrap();
        let sent = factory.sent();
        assert!(
            sent.iter().any(|value| {
                value
                    .get("params")
                    .and_then(|params| params.get("auth"))
                    .and_then(|auth| auth.get("token"))
                    == Some(&json!("secret-token"))
            }),
            "expected an MCP request to include the resolved auth payload, saw {:?}",
            sent
        );
    }

    #[tokio::test]
    async fn manager_reuses_stored_auth_on_connect() {
        let factory = RecordingTransportFactory::new(vec![vec![
            json!({ "jsonrpc": "2.0", "id": 0, "result": { "capabilities": {} } }),
            json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [] } }),
            json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
            json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
        ]]);
        let server_id = McpServerId::new("recording");
        let mut manager = McpServerManager::new().with_server(McpServerConfig::new(
            server_id.to_string(),
            McpTransportBinding::Custom(Arc::new(factory.clone())),
        ));
        let mut auth = MetadataMap::new();
        auth.insert("token".into(), json!("seed-token"));
        let request = AuthRequest {
            task_id: None,
            id: "auth-recording-connect".into(),
            provider: "mcp.recording".into(),
            operation: AuthOperation::McpConnect {
                server_id: server_id.to_string(),
                metadata: MetadataMap::new(),
            },
            challenge: MetadataMap::new(),
        };
        manager
            .resolve_auth(agentkit_tools_core::AuthResolution::Provided {
                request,
                credentials: auth,
            })
            .await
            .unwrap();

        manager.connect_server(&server_id).await.unwrap();
        let sent = factory.sent();
        assert!(
            sent.iter().any(|value| {
                value.get("method").and_then(Value::as_str) == Some("initialize")
                    && value
                        .get("params")
                        .and_then(|params| params.get("auth"))
                        .and_then(|auth| auth.get("token"))
                        == Some(&json!("seed-token"))
            }),
            "expected initialize to include stored auth, saw {:?}",
            sent
        );
    }

    #[tokio::test]
    async fn manager_resolves_auth_and_replays_resource_read() {
        let factory = RecordingTransportFactory::new(vec![vec![
            json!({ "jsonrpc": "2.0", "id": 0, "result": { "capabilities": {} } }),
            json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [] } }),
            json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
            json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
            json!({
                "jsonrpc": "2.0",
                "id": 4,
                "result": {
                    "contents": [
                        {
                            "uri": "file:///tmp/secret.txt",
                            "text": "secret from resource"
                        }
                    ]
                }
            }),
        ]]);
        let server_id = McpServerId::new("recording");
        let mut manager = McpServerManager::new().with_server(McpServerConfig::new(
            server_id.to_string(),
            McpTransportBinding::Custom(Arc::new(factory.clone())),
        ));
        let mut auth = MetadataMap::new();
        auth.insert("token".into(), json!("resource-token"));
        let request = AuthRequest {
            task_id: None,
            id: "auth-recording-resource".into(),
            provider: "mcp.recording".into(),
            operation: AuthOperation::McpResourceRead {
                server_id: server_id.to_string(),
                resource_id: "file:///tmp/secret.txt".into(),
                metadata: MetadataMap::new(),
            },
            challenge: MetadataMap::new(),
        };

        let result = manager
            .resolve_auth_and_resume(agentkit_tools_core::AuthResolution::Provided {
                request,
                credentials: auth,
            })
            .await
            .unwrap();

        match result {
            McpOperationResult::Resource(contents) => {
                assert_eq!(
                    contents.data,
                    DataRef::InlineText("secret from resource".into())
                );
            }
            other => panic!("unexpected replay result: {other:?}"),
        }

        let sent = factory.sent();
        assert!(
            sent.iter().any(|value| {
                value.get("method").and_then(Value::as_str) == Some("resources/read")
                    && value
                        .get("params")
                        .and_then(|params| params.get("auth"))
                        .and_then(|auth| auth.get("token"))
                        == Some(&json!("resource-token"))
            }),
            "expected resources/read to include resolved auth, saw {:?}",
            sent
        );
    }

    #[tokio::test]
    async fn manager_resolves_auth_and_replays_connect() {
        let factory = RecordingTransportFactory::new(vec![vec![
            json!({ "jsonrpc": "2.0", "id": 0, "result": { "capabilities": {} } }),
            json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [] } }),
            json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
            json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
        ]]);
        let server_id = McpServerId::new("recording");
        let mut manager = McpServerManager::new().with_server(McpServerConfig::new(
            server_id.to_string(),
            McpTransportBinding::Custom(Arc::new(factory.clone())),
        ));
        let mut auth = MetadataMap::new();
        auth.insert("token".into(), json!("connect-token"));
        let request = AuthRequest {
            task_id: None,
            id: "auth-recording-connect-replay".into(),
            provider: "mcp.recording".into(),
            operation: AuthOperation::McpConnect {
                server_id: server_id.to_string(),
                metadata: MetadataMap::new(),
            },
            challenge: MetadataMap::new(),
        };

        let result = manager
            .resolve_auth_and_resume(agentkit_tools_core::AuthResolution::Provided {
                request,
                credentials: auth,
            })
            .await
            .unwrap();

        match result {
            McpOperationResult::Connected(snapshot) => {
                assert_eq!(snapshot.server_id, server_id);
            }
            other => panic!("unexpected replay result: {other:?}"),
        }
    }

    #[tokio::test]
    async fn sse_transport_posts_messages_and_receives_frames() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let address = listener.local_addr().unwrap();
        let requests = StdArc::new(StdMutex::new(Vec::new()));
        let captured = requests.clone();

        let server = tokio::spawn(async move {
            for _ in 0..2 {
                let (mut socket, _) = listener.accept().await.unwrap();
                // Read the complete request: a single `read` could miss the POST body.
                let request = read_http_request(&mut socket).await;

                if request.starts_with("GET /sse ") {
                    let body = concat!(
                        "event: endpoint\n",
                        "data: /messages\n\n",
                        "event: message\n",
                        "data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}\n\n"
                    );
                    let response = format!(
                        "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
                        body.len(),
                        body
                    );
                    socket.write_all(response.as_bytes()).await.unwrap();
                } else {
                    captured.lock().unwrap().push(request);
                    socket
                        .write_all(
                            b"HTTP/1.1 202 Accepted\r\ncontent-length: 0\r\nconnection: close\r\n\r\n",
                        )
                        .await
                        .unwrap();
                }
            }
        });

        let factory =
            SseTransportFactory::new(SseTransportConfig::new(format!("http://{address}/sse")));
        let mut transport = factory.connect().await.unwrap();
        transport
            .send(McpFrame {
                value: json!({
                    "jsonrpc": "2.0",
                    "id": 1,
                    "method": "tools/list",
                    "params": {}
                }),
            })
            .await
            .unwrap();
        let frame = transport.recv().await.unwrap().unwrap();
        transport.close().await.unwrap();
        server.await.unwrap();

        assert_eq!(frame.value["result"]["tools"], json!([]));
        let requests = requests.lock().unwrap();
        assert_eq!(requests.len(), 1);
        assert!(requests[0].starts_with("POST /messages "));
        assert!(requests[0].contains("\"method\":\"tools/list\""));
    }

    #[tokio::test]
    async fn server_manager_connects_refreshes_and_aggregates_tools() {
        let alpha = McpServerConfig::new(
            "alpha",
            McpTransportBinding::Custom(Arc::new(FakeTransportFactory::new(vec![vec![
                json!({ "jsonrpc": "2.0", "id": 0, "result": { "capabilities": {} } }),
                json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [{ "name": "echo", "description": "Echo", "inputSchema": {"type": "object"} }] } }),
                json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
                json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
                json!({ "jsonrpc": "2.0", "id": 4, "result": { "tools": [{ "name": "echo_v2", "description": "Echo 2", "inputSchema": {"type": "object"} }] } }),
                json!({ "jsonrpc": "2.0", "id": 5, "result": { "resources": [] } }),
                json!({ "jsonrpc": "2.0", "id": 6, "result": { "prompts": [] } }),
            ]]))),
        );
        let beta = McpServerConfig::new(
            "beta",
            McpTransportBinding::Custom(Arc::new(FakeTransportFactory::new(vec![vec![
                json!({ "jsonrpc": "2.0", "id": 0, "result": { "capabilities": {} } }),
                json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [{ "name": "search", "description": "Search", "inputSchema": {"type": "object"} }] } }),
                json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
                json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
            ]]))),
        );

        let mut manager = McpServerManager::new().with_server(alpha).with_server(beta);

        let handles = manager.connect_all().await.unwrap();
        assert_eq!(handles.len(), 2);
        assert_eq!(
            manager
                .tool_registry()
                .specs()
                .into_iter()
                .map(|spec| spec.name.0)
                .collect::<Vec<_>>(),
            vec!["mcp.alpha.echo".to_string(), "mcp.beta.search".to_string()]
        );

        let refreshed = manager
            .refresh_server(&McpServerId::new("alpha"))
            .await
            .unwrap();
        assert_eq!(refreshed.tools[0].name, "echo_v2");
        assert_eq!(
            manager
                .connected_server(&McpServerId::new("alpha"))
                .unwrap()
                .snapshot()
                .tools[0]
                .name,
            "echo_v2"
        );

        let capabilities = manager.capability_provider();
        assert_eq!(capabilities.invocables().len(), 2);

        manager
            .disconnect_server(&McpServerId::new("alpha"))
            .await
            .unwrap();
        assert!(
            manager
                .connected_server(&McpServerId::new("alpha"))
                .is_none()
        );
    }
}