use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::fmt;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use agentkit_capabilities::{
CapabilityContext, CapabilityError, CapabilityProvider, Invocable, PromptContents,
PromptDescriptor, PromptId, PromptProvider, ResourceContents, ResourceDescriptor, ResourceId,
ResourceProvider,
};
use agentkit_core::{
DataRef, Item, ItemKind, MediaPart, MetadataMap, Modality, Part, TextPart, ToolOutput,
ToolResultPart,
};
use agentkit_tools_core::{
AllowAllPermissions, CatalogReader, CatalogWriter, PermissionChecker, Tool, ToolAnnotations,
ToolCapabilityProvider, ToolContext, ToolError, ToolName, ToolRegistry, ToolRequest,
ToolResult, ToolSpec, dynamic_catalog,
};
use async_trait::async_trait;
use futures_util::future::{join_all, try_join_all};
use futures_util::stream::BoxStream;
use http::{HeaderName, HeaderValue};
use rmcp::ServiceExt;
use rmcp::handler::client::ClientHandler;
use rmcp::model as rmcp_model;
use rmcp::service::{ClientInitializeError, Peer, RoleClient, RunningService, ServiceError};
use rmcp::transport::streamable_http_client::{
AuthRequiredError, InsufficientScopeError, StreamableHttpClient as RmcpStreamableHttpClient,
StreamableHttpClientTransportConfig as RmcpStreamableHttpClientTransportConfig,
StreamableHttpError, StreamableHttpPostResponse,
};
use rmcp::transport::{
ConfigureCommandExt, DynamicTransportError, StreamableHttpClientTransport, TokioChildProcess,
};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use sse_stream::{Error as SseError, Sse};
use thiserror::Error;
use tokio::sync::{Mutex, broadcast, mpsc};
pub use rmcp::model::{
Annotations as McpAnnotations, AudioContent, CallToolResult,
CancelledNotificationParam as McpCancelledNotificationParam,
ClientCapabilities as McpClientCapabilities, Content,
CreateElicitationRequestParams as McpCreateElicitationRequestParams,
CreateElicitationResult as McpCreateElicitationResult,
CreateMessageRequestParams as McpCreateMessageRequestParams,
CreateMessageResult as McpCreateMessageResult, ElicitationAction as McpElicitationAction,
ElicitationCapability as McpElicitationCapability, EmbeddedResource,
FormElicitationCapability as McpFormElicitationCapability, GetPromptResult, ImageContent,
Implementation as McpImplementation, ListRootsResult as McpListRootsResult,
LoggingLevel as McpLoggingLevel,
LoggingMessageNotificationParam as McpLoggingMessageNotificationParam,
ProgressNotificationParam as McpProgressNotificationParam, Prompt as McpPrompt, PromptArgument,
PromptMessage, PromptMessageContent, PromptMessageRole, RawAudioContent, RawContent,
RawEmbeddedResource, RawImageContent, RawResource as McpRawResource, RawTextContent,
ReadResourceResult, Resource as McpResource, ResourceContents as McpResourceContents,
ResourceUpdatedNotificationParam as McpResourceUpdatedNotificationParam, Root as McpRoot,
RootsCapabilities as McpRootsCapabilities, SamplingCapability as McpSamplingCapability,
SamplingMessage as McpSamplingMessage, SetLevelRequestParams as McpSetLevelRequestParams,
TextContent, Tool as McpTool, ToolAnnotations as McpToolAnnotations,
UrlElicitationCapability as McpUrlElicitationCapability,
};
pub use rmcp::model::ClientJsonRpcMessage;
pub use rmcp::transport::streamable_http_client::{
StreamableHttpError as McpStreamableHttpError,
StreamableHttpPostResponse as McpStreamableHttpPostResponse,
};
pub use sse_stream::{Error as McpSseError, Sse as McpSse};
pub type McpToolDescriptor = McpTool;
pub type McpResourceDescriptor = McpResource;
pub type McpPromptDescriptor = McpPrompt;
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct AuthRequest {
pub id: String,
pub provider: String,
pub operation: AuthOperation,
pub challenge: MetadataMap,
}
impl AuthRequest {
pub fn server_id(&self) -> Option<&str> {
self.operation.server_id()
}
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum AuthOperation {
McpConnect {
server_id: String,
metadata: MetadataMap,
},
McpToolCall {
server_id: String,
tool_name: String,
input: Value,
metadata: MetadataMap,
},
McpResourceRead {
server_id: String,
resource_id: String,
metadata: MetadataMap,
},
McpPromptGet {
server_id: String,
prompt_id: String,
args: Value,
metadata: MetadataMap,
},
McpOther {
server_id: String,
method: String,
params: Value,
metadata: MetadataMap,
},
}
impl AuthOperation {
pub fn server_id(&self) -> Option<&str> {
match self {
Self::McpConnect { server_id, .. }
| Self::McpToolCall { server_id, .. }
| Self::McpResourceRead { server_id, .. }
| Self::McpPromptGet { server_id, .. }
| Self::McpOther { server_id, .. } => Some(server_id.as_str()),
}
}
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum AuthResolution {
Provided {
request: AuthRequest,
credentials: MetadataMap,
},
Cancelled { request: AuthRequest },
}
impl AuthResolution {
pub fn provided(request: AuthRequest, credentials: MetadataMap) -> Self {
Self::Provided {
request,
credentials,
}
}
pub fn cancelled(request: AuthRequest) -> Self {
Self::Cancelled { request }
}
pub fn request(&self) -> &AuthRequest {
match self {
Self::Provided { request, .. } | Self::Cancelled { request } => request,
}
}
}
#[async_trait]
pub trait McpAuthResponder: Send + Sync + 'static {
async fn resolve(&self, request: AuthRequest) -> Result<AuthResolution, McpError>;
}
#[derive(Debug, Clone, thiserror::Error)]
pub enum McpInvocationError {
#[error("url elicitation required: {message}")]
UrlElicitation {
message: String,
data: Option<UrlElicitationData>,
raw_data: Option<serde_json::Value>,
},
#[error("invalid request: {message}")]
InvalidRequest {
message: String,
data: Option<serde_json::Value>,
},
#[error("method not found: {message}")]
MethodNotFound {
message: String,
data: Option<serde_json::Value>,
},
#[error("invalid params: {message}")]
InvalidParams {
message: String,
data: Option<serde_json::Value>,
},
#[error("internal error: {message}")]
InternalError {
message: String,
data: Option<serde_json::Value>,
},
#[error("parse error: {message}")]
ParseError {
message: String,
data: Option<serde_json::Value>,
},
#[error("resource not found: {message}")]
ResourceNotFound {
message: String,
data: Option<serde_json::Value>,
},
#[error("mcp error code {code}: {message}")]
Other {
code: i32,
message: String,
data: Option<serde_json::Value>,
},
}
#[derive(Debug, Clone, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct UrlElicitationData {
pub url: String,
pub elicitation_id: String,
#[serde(default)]
pub message: Option<String>,
}
impl McpInvocationError {
pub fn from_error_data(err: rmcp::model::ErrorData) -> Self {
let rmcp::model::ErrorData {
code,
message,
data,
} = err;
let message = message.into_owned();
match code {
rmcp::model::ErrorCode::URL_ELICITATION_REQUIRED => {
let typed = data.as_ref().and_then(|value| {
serde_json::from_value::<UrlElicitationData>(value.clone()).ok()
});
Self::UrlElicitation {
message,
data: typed,
raw_data: data,
}
}
rmcp::model::ErrorCode::INVALID_REQUEST => Self::InvalidRequest { message, data },
rmcp::model::ErrorCode::METHOD_NOT_FOUND => Self::MethodNotFound { message, data },
rmcp::model::ErrorCode::INVALID_PARAMS => Self::InvalidParams { message, data },
rmcp::model::ErrorCode::INTERNAL_ERROR => Self::InternalError { message, data },
rmcp::model::ErrorCode::PARSE_ERROR => Self::ParseError { message, data },
rmcp::model::ErrorCode::RESOURCE_NOT_FOUND => Self::ResourceNotFound { message, data },
other => Self::Other {
code: other.0,
message,
data,
},
}
}
pub fn code(&self) -> i32 {
match self {
Self::UrlElicitation { .. } => rmcp::model::ErrorCode::URL_ELICITATION_REQUIRED.0,
Self::InvalidRequest { .. } => rmcp::model::ErrorCode::INVALID_REQUEST.0,
Self::MethodNotFound { .. } => rmcp::model::ErrorCode::METHOD_NOT_FOUND.0,
Self::InvalidParams { .. } => rmcp::model::ErrorCode::INVALID_PARAMS.0,
Self::InternalError { .. } => rmcp::model::ErrorCode::INTERNAL_ERROR.0,
Self::ParseError { .. } => rmcp::model::ErrorCode::PARSE_ERROR.0,
Self::ResourceNotFound { .. } => rmcp::model::ErrorCode::RESOURCE_NOT_FOUND.0,
Self::Other { code, .. } => *code,
}
}
}
#[async_trait]
pub trait McpErrorResponder: Send + Sync + 'static {
async fn handle(
&self,
error: &McpInvocationError,
ctx: McpErrorContext<'_>,
) -> ErrorResponderOutcome;
}
pub struct McpErrorContext<'a> {
pub server_id: &'a McpServerId,
pub method: &'a McpMethod,
pub input: Option<&'a serde_json::Value>,
}
pub enum ErrorResponderOutcome {
SynthesizeResult(CallToolResult),
PassThrough,
}
#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct McpServerId(pub String);
impl McpServerId {
pub fn new(value: impl Into<String>) -> Self {
Self(value.into())
}
}
impl fmt::Display for McpServerId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StdioTransportConfig {
pub command: String,
pub args: Vec<String>,
pub env: Vec<(String, String)>,
pub cwd: Option<std::path::PathBuf>,
}
impl StdioTransportConfig {
pub fn new(command: impl Into<String>) -> Self {
Self {
command: command.into(),
args: Vec::new(),
env: Vec::new(),
cwd: None,
}
}
pub fn with_arg(mut self, arg: impl Into<String>) -> Self {
self.args.push(arg.into());
self
}
pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.env.push((key.into(), value.into()));
self
}
pub fn with_cwd(mut self, cwd: impl Into<std::path::PathBuf>) -> Self {
self.cwd = Some(cwd.into());
self
}
}
#[derive(Clone, Default)]
pub struct StreamableHttpTransportConfig {
pub url: String,
pub bearer_token: Option<String>,
pub headers: Vec<(HeaderName, HeaderValue)>,
pub http_client: Option<Arc<dyn McpHttpClient>>,
}
impl fmt::Debug for StreamableHttpTransportConfig {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("StreamableHttpTransportConfig")
.field("url", &self.url)
.field(
"bearer_token",
&self.bearer_token.as_deref().map(|_| "<redacted>"),
)
.field("headers", &self.headers)
.field(
"http_client",
&self.http_client.as_ref().map(|_| "<custom>"),
)
.finish()
}
}
impl StreamableHttpTransportConfig {
pub fn new(url: impl Into<String>) -> Self {
Self {
url: url.into(),
bearer_token: None,
headers: Vec::new(),
http_client: None,
}
}
pub fn with_bearer_token(mut self, token: impl Into<String>) -> Self {
self.bearer_token = Some(token.into());
self
}
pub fn with_http_client(mut self, client: Arc<dyn McpHttpClient>) -> Self {
self.http_client = Some(client);
self
}
pub fn with_header<N, V>(mut self, name: N, value: V) -> Result<Self, McpError>
where
N: TryInto<HeaderName>,
N::Error: fmt::Display,
V: TryInto<HeaderValue>,
V::Error: fmt::Display,
{
let name = name
.try_into()
.map_err(|error| McpError::Transport(format!("invalid HTTP header name: {error}")))?;
let value = value
.try_into()
.map_err(|error| McpError::Transport(format!("invalid HTTP header value: {error}")))?;
self.headers.push((name, value));
Ok(self)
}
}
pub type McpSseStream = BoxStream<'static, Result<Sse, SseError>>;
#[async_trait]
pub trait McpHttpClient: Send + Sync + 'static {
async fn post_message(
&self,
uri: Arc<str>,
message: ClientJsonRpcMessage,
session_id: Option<Arc<str>>,
auth_header: Option<String>,
custom_headers: HashMap<HeaderName, HeaderValue>,
) -> Result<StreamableHttpPostResponse, StreamableHttpError<reqwest::Error>>;
async fn delete_session(
&self,
uri: Arc<str>,
session_id: Arc<str>,
auth_header: Option<String>,
custom_headers: HashMap<HeaderName, HeaderValue>,
) -> Result<(), StreamableHttpError<reqwest::Error>>;
async fn get_stream(
&self,
uri: Arc<str>,
session_id: Arc<str>,
last_event_id: Option<String>,
auth_header: Option<String>,
custom_headers: HashMap<HeaderName, HeaderValue>,
) -> Result<McpSseStream, StreamableHttpError<reqwest::Error>>;
}
#[derive(Clone)]
struct DynHttpClient(Arc<dyn McpHttpClient>);
impl RmcpStreamableHttpClient for DynHttpClient {
type Error = reqwest::Error;
async fn post_message(
&self,
uri: Arc<str>,
message: ClientJsonRpcMessage,
session_id: Option<Arc<str>>,
auth_header: Option<String>,
custom_headers: HashMap<HeaderName, HeaderValue>,
) -> Result<StreamableHttpPostResponse, StreamableHttpError<reqwest::Error>> {
self.0
.post_message(uri, message, session_id, auth_header, custom_headers)
.await
}
async fn delete_session(
&self,
uri: Arc<str>,
session_id: Arc<str>,
auth_header: Option<String>,
custom_headers: HashMap<HeaderName, HeaderValue>,
) -> Result<(), StreamableHttpError<reqwest::Error>> {
self.0
.delete_session(uri, session_id, auth_header, custom_headers)
.await
}
async fn get_stream(
&self,
uri: Arc<str>,
session_id: Arc<str>,
last_event_id: Option<String>,
auth_header: Option<String>,
custom_headers: HashMap<HeaderName, HeaderValue>,
) -> Result<McpSseStream, StreamableHttpError<reqwest::Error>> {
self.0
.get_stream(uri, session_id, last_event_id, auth_header, custom_headers)
.await
}
}
#[derive(Clone, Debug)]
pub enum McpTransportBinding {
Stdio(StdioTransportConfig),
StreamableHttp(StreamableHttpTransportConfig),
}
#[derive(Clone, Debug)]
pub struct McpServerConfig {
pub id: McpServerId,
pub transport: McpTransportBinding,
pub metadata: MetadataMap,
}
impl McpServerConfig {
pub fn new(id: impl Into<String>, transport: McpTransportBinding) -> Self {
Self {
id: McpServerId::new(id),
transport,
metadata: MetadataMap::new(),
}
}
pub fn stdio(id: impl Into<String>, command: impl Into<String>) -> Self {
Self::new(
id,
McpTransportBinding::Stdio(StdioTransportConfig::new(command)),
)
}
pub fn streamable_http(id: impl Into<String>, url: impl Into<String>) -> Self {
Self::new(
id,
McpTransportBinding::StreamableHttp(StreamableHttpTransportConfig::new(url)),
)
}
pub fn with_metadata(mut self, metadata: MetadataMap) -> Self {
self.metadata = metadata;
self
}
}
type CustomNamespace = Arc<dyn Fn(&McpServerId, &str) -> String + Send + Sync>;
#[derive(Clone, Default)]
pub enum McpToolNamespace {
#[default]
Default,
None,
Custom(CustomNamespace),
}
impl fmt::Debug for McpToolNamespace {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Default => f.write_str("McpToolNamespace::Default"),
Self::None => f.write_str("McpToolNamespace::None"),
Self::Custom(_) => f.write_str("McpToolNamespace::Custom(<fn>)"),
}
}
}
impl McpToolNamespace {
pub fn custom(f: impl Fn(&McpServerId, &str) -> String + Send + Sync + 'static) -> Self {
Self::Custom(Arc::new(f))
}
pub fn apply(&self, server_id: &McpServerId, tool_name: &str) -> String {
match self {
Self::Default => format!("mcp_{server_id}_{tool_name}"),
Self::None => tool_name.to_string(),
Self::Custom(f) => f(server_id, tool_name),
}
}
pub fn unapply(&self, server_id: &McpServerId, agentkit_name: &str) -> Option<String> {
match self {
Self::Default => agentkit_name
.strip_prefix(&format!("mcp_{server_id}_"))
.map(str::to_string),
Self::None => Some(agentkit_name.to_string()),
Self::Custom(_) => None,
}
}
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct McpDiscoverySnapshot {
pub server_id: McpServerId,
pub tools: Vec<McpTool>,
pub resources: Vec<McpResource>,
pub prompts: Vec<McpPrompt>,
pub metadata: MetadataMap,
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum McpCatalogEvent {
ServerConnected { server_id: McpServerId },
ServerDisconnected { server_id: McpServerId },
ToolsChanged {
server_id: McpServerId,
added: Vec<String>,
removed: Vec<String>,
changed: Vec<String>,
},
ResourcesChanged {
server_id: McpServerId,
added: Vec<String>,
removed: Vec<String>,
changed: Vec<String>,
},
PromptsChanged {
server_id: McpServerId,
added: Vec<String>,
removed: Vec<String>,
changed: Vec<String>,
},
AuthChanged { server_id: McpServerId },
RefreshFailed {
server_id: McpServerId,
message: String,
},
}
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct McpServerCapabilities {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tools: Option<ToolsCapability>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub resources: Option<ResourcesCapability>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prompts: Option<PromptsCapability>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub logging: Option<LoggingCapability>,
}
impl McpServerCapabilities {
pub fn all() -> Self {
Self {
tools: Some(ToolsCapability::default()),
resources: Some(ResourcesCapability::default()),
prompts: Some(PromptsCapability::default()),
logging: Some(LoggingCapability::default()),
}
}
}
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolsCapability {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub list_changed: Option<bool>,
}
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesCapability {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub subscribe: Option<bool>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub list_changed: Option<bool>,
}
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptsCapability {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub list_changed: Option<bool>,
}
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct LoggingCapability {}
#[allow(clippy::enum_variant_names)]
#[derive(Clone, Debug)]
pub enum McpServerNotification {
ToolsChanged,
ResourcesChanged,
PromptsChanged,
}
#[derive(Clone, Debug)]
pub enum McpServerEvent {
Progress(McpProgressNotificationParam),
Logging(McpLoggingMessageNotificationParam),
ResourceUpdated(McpResourceUpdatedNotificationParam),
ToolListChanged,
ResourceListChanged,
PromptListChanged,
Cancelled(McpCancelledNotificationParam),
}
#[async_trait]
pub trait McpSamplingResponder: Send + Sync + 'static {
async fn create_message(
&self,
params: McpCreateMessageRequestParams,
) -> Result<McpCreateMessageResult, McpError>;
}
#[async_trait]
pub trait McpElicitationResponder: Send + Sync + 'static {
async fn create_elicitation(
&self,
params: McpCreateElicitationRequestParams,
) -> Result<McpCreateElicitationResult, McpError>;
}
#[async_trait]
pub trait McpRootsProvider: Send + Sync + 'static {
async fn list_roots(&self) -> Result<Vec<McpRoot>, McpError>;
}
const DEFAULT_EVENTS_CAPACITY: usize = 128;
pub struct McpClientChannels {
pub notifications: mpsc::UnboundedReceiver<McpServerNotification>,
pub events: broadcast::Sender<McpServerEvent>,
}
#[derive(Clone)]
pub struct McpClientHandler {
info: rmcp_model::ClientInfo,
notifications: mpsc::UnboundedSender<McpServerNotification>,
events: broadcast::Sender<McpServerEvent>,
sampling: Option<Arc<dyn McpSamplingResponder>>,
elicitation: Option<Arc<dyn McpElicitationResponder>>,
roots: Option<Arc<dyn McpRootsProvider>>,
}
impl ClientHandler for McpClientHandler {
fn create_message(
&self,
params: rmcp_model::CreateMessageRequestParams,
_context: rmcp::service::RequestContext<RoleClient>,
) -> impl Future<Output = Result<rmcp_model::CreateMessageResult, rmcp_model::ErrorData>>
+ rmcp::service::MaybeSendFuture
+ '_ {
let responder = self.sampling.clone();
async move {
match responder {
Some(responder) => responder.create_message(params).await.map_err(Into::into),
None => Err(rmcp_model::ErrorData::method_not_found::<
rmcp_model::CreateMessageRequestMethod,
>()),
}
}
}
fn list_roots(
&self,
_context: rmcp::service::RequestContext<RoleClient>,
) -> impl Future<Output = Result<rmcp_model::ListRootsResult, rmcp_model::ErrorData>>
+ rmcp::service::MaybeSendFuture
+ '_ {
let provider = self.roots.clone();
async move {
match provider {
Some(provider) => provider
.list_roots()
.await
.map(McpListRootsResult::new)
.map_err(Into::into),
None => Ok(McpListRootsResult::default()),
}
}
}
fn create_elicitation(
&self,
params: rmcp_model::CreateElicitationRequestParams,
_context: rmcp::service::RequestContext<RoleClient>,
) -> impl Future<Output = Result<rmcp_model::CreateElicitationResult, rmcp_model::ErrorData>>
+ rmcp::service::MaybeSendFuture
+ '_ {
let responder = self.elicitation.clone();
async move {
match responder {
Some(responder) => responder
.create_elicitation(params)
.await
.map_err(Into::into),
None => Ok(McpCreateElicitationResult::new(
McpElicitationAction::Decline,
)),
}
}
}
fn on_progress(
&self,
params: rmcp_model::ProgressNotificationParam,
_context: rmcp::service::NotificationContext<RoleClient>,
) -> impl Future<Output = ()> + rmcp::service::MaybeSendFuture + '_ {
let _ = self.events.send(McpServerEvent::Progress(params));
std::future::ready(())
}
fn on_logging_message(
&self,
params: rmcp_model::LoggingMessageNotificationParam,
_context: rmcp::service::NotificationContext<RoleClient>,
) -> impl Future<Output = ()> + rmcp::service::MaybeSendFuture + '_ {
let _ = self.events.send(McpServerEvent::Logging(params));
std::future::ready(())
}
fn on_resource_updated(
&self,
params: rmcp_model::ResourceUpdatedNotificationParam,
_context: rmcp::service::NotificationContext<RoleClient>,
) -> impl Future<Output = ()> + rmcp::service::MaybeSendFuture + '_ {
let _ = self.events.send(McpServerEvent::ResourceUpdated(params));
std::future::ready(())
}
fn on_cancelled(
&self,
params: rmcp_model::CancelledNotificationParam,
_context: rmcp::service::NotificationContext<RoleClient>,
) -> impl Future<Output = ()> + rmcp::service::MaybeSendFuture + '_ {
let _ = self.events.send(McpServerEvent::Cancelled(params));
std::future::ready(())
}
fn on_tool_list_changed(
&self,
_context: rmcp::service::NotificationContext<RoleClient>,
) -> impl Future<Output = ()> + rmcp::service::MaybeSendFuture + '_ {
let _ = self.notifications.send(McpServerNotification::ToolsChanged);
let _ = self.events.send(McpServerEvent::ToolListChanged);
std::future::ready(())
}
fn on_resource_list_changed(
&self,
_context: rmcp::service::NotificationContext<RoleClient>,
) -> impl Future<Output = ()> + rmcp::service::MaybeSendFuture + '_ {
let _ = self
.notifications
.send(McpServerNotification::ResourcesChanged);
let _ = self.events.send(McpServerEvent::ResourceListChanged);
std::future::ready(())
}
fn on_prompt_list_changed(
&self,
_context: rmcp::service::NotificationContext<RoleClient>,
) -> impl Future<Output = ()> + rmcp::service::MaybeSendFuture + '_ {
let _ = self
.notifications
.send(McpServerNotification::PromptsChanged);
let _ = self.events.send(McpServerEvent::PromptListChanged);
std::future::ready(())
}
fn get_info(&self) -> rmcp_model::ClientInfo {
self.info.clone()
}
}
impl From<McpError> for rmcp_model::ErrorData {
fn from(error: McpError) -> Self {
rmcp_model::ErrorData::internal_error(error.to_string(), None)
}
}
type RmcpClientService = RunningService<RoleClient, McpClientHandler>;
#[derive(Clone, Default)]
pub struct McpHandlerConfig {
pub sampling: Option<Arc<dyn McpSamplingResponder>>,
pub elicitation: Option<Arc<dyn McpElicitationResponder>>,
pub roots: Option<Arc<dyn McpRootsProvider>>,
pub auth: Option<Arc<dyn McpAuthResponder>>,
pub error_responder: Option<Arc<dyn McpErrorResponder>>,
pub events_capacity: Option<usize>,
}
impl McpHandlerConfig {
pub fn new() -> Self {
Self::default()
}
pub fn with_sampling_responder(mut self, responder: Arc<dyn McpSamplingResponder>) -> Self {
self.sampling = Some(responder);
self
}
pub fn with_elicitation_responder(
mut self,
responder: Arc<dyn McpElicitationResponder>,
) -> Self {
self.elicitation = Some(responder);
self
}
pub fn with_roots_provider(mut self, provider: Arc<dyn McpRootsProvider>) -> Self {
self.roots = Some(provider);
self
}
pub fn with_auth_responder(mut self, responder: Arc<dyn McpAuthResponder>) -> Self {
self.auth = Some(responder);
self
}
pub fn with_error_responder(mut self, responder: Arc<dyn McpErrorResponder>) -> Self {
self.error_responder = Some(responder);
self
}
pub fn with_events_capacity(mut self, capacity: usize) -> Self {
self.events_capacity = Some(capacity);
self
}
pub fn build(&self) -> (McpClientHandler, McpClientChannels) {
self.build_inner(None)
}
pub fn build_with(
&self,
events: broadcast::Sender<McpServerEvent>,
) -> (McpClientHandler, McpClientChannels) {
self.build_inner(Some(events))
}
fn build_inner(
&self,
events: Option<broadcast::Sender<McpServerEvent>>,
) -> (McpClientHandler, McpClientChannels) {
let (notifications_tx, notifications_rx) = mpsc::unbounded_channel();
let events_tx = events.unwrap_or_else(|| {
let capacity = self.events_capacity.unwrap_or(DEFAULT_EVENTS_CAPACITY);
let (tx, _) = broadcast::channel(capacity);
tx
});
let mut capabilities = rmcp_model::ClientCapabilities::default();
if self.sampling.is_some() {
capabilities.sampling = Some(McpSamplingCapability::default());
}
if self.elicitation.is_some() {
capabilities.elicitation = Some(McpElicitationCapability {
form: Some(McpFormElicitationCapability::default()),
url: None,
});
}
if self.roots.is_some() {
capabilities.roots = Some(McpRootsCapabilities::default());
}
let handler = McpClientHandler {
info: rmcp_model::ClientInfo::new(
capabilities,
rmcp_model::Implementation::new("agentkit-mcp", env!("CARGO_PKG_VERSION"))
.with_title("agentkit MCP client"),
)
.with_protocol_version(rmcp_model::ProtocolVersion::LATEST),
notifications: notifications_tx,
events: events_tx.clone(),
sampling: self.sampling.clone(),
elicitation: self.elicitation.clone(),
roots: self.roots.clone(),
};
(
handler,
McpClientChannels {
notifications: notifications_rx,
events: events_tx,
},
)
}
}
pub struct McpConnection {
server_id: McpServerId,
config: Option<McpServerConfig>,
inner: Mutex<RmcpClientService>,
peer: RwLock<Peer<RoleClient>>,
auth: Mutex<Option<MetadataMap>>,
notifications: Mutex<mpsc::UnboundedReceiver<McpServerNotification>>,
events: broadcast::Sender<McpServerEvent>,
handler_config: McpHandlerConfig,
capabilities: McpServerCapabilities,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum McpOperationResult {
Connected(McpDiscoverySnapshot),
Tool(CallToolResult),
Resource(ReadResourceResult),
Prompt(GetPromptResult),
}
impl McpConnection {
pub async fn connect(config: &McpServerConfig) -> Result<Self, McpError> {
Self::connect_with_auth(config, None, McpHandlerConfig::default()).await
}
pub async fn connect_with_handler(
config: &McpServerConfig,
handler_config: McpHandlerConfig,
) -> Result<Self, McpError> {
Self::connect_with_auth(config, None, handler_config).await
}
async fn connect_with_auth(
config: &McpServerConfig,
auth: Option<&MetadataMap>,
handler_config: McpHandlerConfig,
) -> Result<Self, McpError> {
let (handler, channels) = handler_config.build();
let McpClientChannels {
notifications: notification_rx,
events: events_tx,
} = channels;
let (service, capabilities) = match &config.transport {
McpTransportBinding::Stdio(binding) => {
connect_rmcp_stdio(config, binding, handler).await?
}
McpTransportBinding::StreamableHttp(binding) => {
connect_rmcp_streamable_http(config, binding, auth, handler).await?
}
};
let peer = service.peer().clone();
Ok(Self {
server_id: config.id.clone(),
config: Some(config.clone()),
inner: Mutex::new(service),
peer: RwLock::new(peer),
auth: Mutex::new(auth.cloned()),
notifications: Mutex::new(notification_rx),
events: events_tx,
handler_config,
capabilities,
})
}
pub fn from_running_service(
server_id: impl Into<McpServerId>,
service: RmcpClientService,
notifications: mpsc::UnboundedReceiver<McpServerNotification>,
) -> Self {
let (events_tx, _) = broadcast::channel(DEFAULT_EVENTS_CAPACITY);
Self::from_running_service_with_events(server_id, service, notifications, events_tx)
}
pub fn from_running_service_with_events(
server_id: impl Into<McpServerId>,
service: RmcpClientService,
notifications: mpsc::UnboundedReceiver<McpServerNotification>,
events: broadcast::Sender<McpServerEvent>,
) -> Self {
Self::from_running_service_with_events_and_handler_config(
server_id,
service,
notifications,
events,
McpHandlerConfig::default(),
)
}
pub fn from_running_service_with_events_and_handler_config(
server_id: impl Into<McpServerId>,
service: RmcpClientService,
notifications: mpsc::UnboundedReceiver<McpServerNotification>,
events: broadcast::Sender<McpServerEvent>,
handler_config: McpHandlerConfig,
) -> Self {
let capabilities = service
.peer_info()
.map(|info| rmcp_server_capabilities_to_agentkit(&info.capabilities))
.unwrap_or_default();
let peer = service.peer().clone();
Self {
server_id: server_id.into(),
config: None,
inner: Mutex::new(service),
peer: RwLock::new(peer),
auth: Mutex::new(None),
notifications: Mutex::new(notifications),
events,
handler_config,
capabilities,
}
}
async fn reconnect_inner(&self, auth: Option<&MetadataMap>) -> Result<(), McpError> {
let Some(config) = self.config.clone() else {
return Ok(());
};
let (handler, channels) = self.handler_config.build_with(self.events.clone());
let McpClientChannels {
notifications: notification_rx,
..
} = channels;
let (service, _capabilities) = match &config.transport {
McpTransportBinding::Stdio(binding) => {
connect_rmcp_stdio(&config, binding, handler).await?
}
McpTransportBinding::StreamableHttp(binding) => {
connect_rmcp_streamable_http(&config, binding, auth, handler).await?
}
};
let new_peer = service.peer().clone();
*self.notifications.lock().await = notification_rx;
*self.inner.lock().await = service;
*self.peer.write().expect("MCP peer lock poisoned") = new_peer;
Ok(())
}
fn peer(&self) -> Peer<RoleClient> {
self.peer.read().expect("MCP peer lock poisoned").clone()
}
pub fn server_id(&self) -> &McpServerId {
&self.server_id
}
pub fn capabilities(&self) -> &McpServerCapabilities {
&self.capabilities
}
pub fn handler_config(&self) -> &McpHandlerConfig {
&self.handler_config
}
pub fn subscribe_events(&self) -> broadcast::Receiver<McpServerEvent> {
self.events.subscribe()
}
pub async fn subscribe_resource(&self, uri: impl Into<String>) -> Result<(), McpError> {
let uri = uri.into();
self.peer()
.subscribe(rmcp_model::SubscribeRequestParams::new(uri.clone()))
.await
.map_err(|error| {
rmcp_operation_error(
&self.server_id,
McpMethod::ResourcesSubscribe { uri },
error,
)
})
}
pub async fn unsubscribe_resource(&self, uri: impl Into<String>) -> Result<(), McpError> {
let uri = uri.into();
self.peer()
.unsubscribe(rmcp_model::UnsubscribeRequestParams::new(uri.clone()))
.await
.map_err(|error| {
rmcp_operation_error(
&self.server_id,
McpMethod::ResourcesUnsubscribe { uri },
error,
)
})
}
pub async fn set_logging_level(&self, level: McpLoggingLevel) -> Result<(), McpError> {
self.peer()
.set_level(rmcp_model::SetLevelRequestParams::new(level))
.await
.map_err(|error| {
rmcp_operation_error(
&self.server_id,
McpMethod::LoggingSetLevel {
level: format!("{level:?}"),
},
error,
)
})
}
pub async fn notify_cancelled(
&self,
params: McpCancelledNotificationParam,
) -> Result<(), McpError> {
self.peer()
.notify_cancelled(params)
.await
.map_err(rmcp_service_error)
}
pub async fn notify_roots_list_changed(&self) -> Result<(), McpError> {
self.peer()
.notify_roots_list_changed()
.await
.map_err(rmcp_service_error)
}
pub async fn close(&self) -> Result<(), McpError> {
let mut inner = self.inner.lock().await;
inner
.close()
.await
.map(|_| ())
.map_err(|error| McpError::Transport(format!("rmcp service close failed: {error}")))
}
pub async fn resolve_auth(&self, resolution: AuthResolution) -> Result<(), McpError> {
let mut auth_slot = self.auth.lock().await;
match resolution {
AuthResolution::Provided { credentials, .. } => {
*auth_slot = Some(credentials);
}
AuthResolution::Cancelled { .. } => {
*auth_slot = None;
}
}
let snapshot = auth_slot.clone();
drop(auth_slot);
if self.config.is_some() {
self.reconnect_inner(snapshot.as_ref()).await?;
}
Ok(())
}
pub async fn discover(&self) -> Result<McpDiscoverySnapshot, McpError> {
let tools = async {
match self.capabilities.tools {
Some(_) => self.list_tools().await,
None => Ok(Vec::new()),
}
};
let resources = async {
match self.capabilities.resources {
Some(_) => self.list_resources().await,
None => Ok(Vec::new()),
}
};
let prompts = async {
match self.capabilities.prompts {
Some(_) => self.list_prompts().await,
None => Ok(Vec::new()),
}
};
let (tools, resources, prompts) = tokio::try_join!(tools, resources, prompts)?;
Ok(McpDiscoverySnapshot {
server_id: self.server_id.clone(),
tools,
resources,
prompts,
metadata: MetadataMap::new(),
})
}
async fn drain_notifications(&self) -> Vec<McpServerNotification> {
let mut notifications = self.notifications.lock().await;
let mut drained = Vec::new();
while let Ok(notification) = notifications.try_recv() {
drained.push(notification);
}
drained
}
pub async fn list_tools(&self) -> Result<Vec<McpTool>, McpError> {
self.peer()
.list_all_tools()
.await
.map_err(rmcp_service_error)
}
pub async fn list_resources(&self) -> Result<Vec<McpResource>, McpError> {
self.peer()
.list_all_resources()
.await
.map_err(rmcp_service_error)
}
pub async fn list_prompts(&self) -> Result<Vec<McpPrompt>, McpError> {
self.peer()
.list_all_prompts()
.await
.map_err(rmcp_service_error)
}
pub async fn call_tool(
&self,
name: &str,
arguments: Value,
) -> Result<CallToolResult, McpError> {
let arguments_for_auth = arguments.clone();
let mut params = rmcp_model::CallToolRequestParams::new(name.to_string());
if !arguments.is_null() {
params =
params.with_arguments(value_to_json_object(arguments, "tools/call arguments")?);
}
let name_owned = name.to_string();
let span = tracing::info_span!(
"mcp.call_tool",
"otel.name" = %format!("mcp.call_tool {name}"),
"mcp.server.id" = %self.server_id,
"mcp.tool.name" = %name,
"error.type" = tracing::field::Empty,
);
use tracing::Instrument;
let result = self.peer().call_tool(params).instrument(span.clone()).await;
match result {
Ok(result) => {
if result.is_error == Some(true) {
span.record("error.type", "tool_error");
}
Ok(result)
}
Err(error) => {
span.record("error.type", "mcp_error");
Err(rmcp_operation_error(
&self.server_id,
McpMethod::ToolsCall {
name: name_owned,
arguments: arguments_for_auth,
},
error,
))
}
}
}
pub async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult, McpError> {
let uri_owned = uri.to_string();
self.peer()
.read_resource(rmcp_model::ReadResourceRequestParams::new(uri))
.await
.map_err(|error| {
rmcp_operation_error(
&self.server_id,
McpMethod::ResourcesRead { uri: uri_owned },
error,
)
})
}
pub async fn get_prompt(
&self,
name: &str,
arguments: Value,
) -> Result<GetPromptResult, McpError> {
let arguments_for_auth = arguments.clone();
let name_owned = name.to_string();
let mut params = rmcp_model::GetPromptRequestParams::new(name);
if !arguments.is_null() {
params =
params.with_arguments(value_to_json_object(arguments, "prompts/get arguments")?);
}
self.peer().get_prompt(params).await.map_err(|error| {
rmcp_operation_error(
&self.server_id,
McpMethod::PromptsGet {
name: name_owned,
arguments: arguments_for_auth,
},
error,
)
})
}
}
async fn connect_rmcp_stdio(
config: &McpServerConfig,
binding: &StdioTransportConfig,
handler: McpClientHandler,
) -> Result<(RmcpClientService, McpServerCapabilities), McpError> {
let transport = TokioChildProcess::new(
tokio::process::Command::new(&binding.command).configure(|command| {
command.args(&binding.args);
if let Some(cwd) = &binding.cwd {
command.current_dir(cwd);
}
for (key, value) in &binding.env {
command.env(key, value);
}
}),
)
.map_err(McpError::Io)?;
let service = handler
.serve(transport)
.await
.map_err(|error| rmcp_initialize_error(config, error))?;
let capabilities = service
.peer_info()
.map(|info| rmcp_server_capabilities_to_agentkit(&info.capabilities))
.unwrap_or_default();
Ok((service, capabilities))
}
async fn connect_rmcp_streamable_http(
config: &McpServerConfig,
binding: &StreamableHttpTransportConfig,
auth: Option<&MetadataMap>,
handler: McpClientHandler,
) -> Result<(RmcpClientService, McpServerCapabilities), McpError> {
let auth_header = auth
.and_then(bearer_token_from_metadata)
.or_else(|| binding.bearer_token.clone());
let mut rmcp_config = RmcpStreamableHttpClientTransportConfig::with_uri(binding.url.clone());
if let Some(auth_header) = auth_header {
rmcp_config = rmcp_config.auth_header(auth_header);
}
rmcp_config = rmcp_config.custom_headers(binding.headers.iter().cloned().collect());
let result = match binding.http_client.as_ref() {
Some(client) => {
let transport = StreamableHttpClientTransport::with_client(
DynHttpClient(client.clone()),
rmcp_config,
);
handler.serve(transport).await
}
None => {
let transport = StreamableHttpClientTransport::from_config(rmcp_config);
handler.serve(transport).await
}
};
let service = result.map_err(|error| rmcp_initialize_error(config, error))?;
let capabilities = service
.peer_info()
.map(|info| rmcp_server_capabilities_to_agentkit(&info.capabilities))
.unwrap_or_default();
Ok((service, capabilities))
}
pub struct McpResourceHandle {
connection: Arc<McpConnection>,
descriptor: ResourceDescriptor,
}
#[async_trait]
impl ResourceProvider for McpResourceHandle {
async fn list_resources(&self) -> Result<Vec<ResourceDescriptor>, CapabilityError> {
Ok(vec![self.descriptor.clone()])
}
async fn read_resource(
&self,
id: &ResourceId,
_ctx: &mut CapabilityContext<'_>,
) -> Result<ResourceContents, CapabilityError> {
let result = self
.connection
.read_resource(&id.0)
.await
.map_err(|error| match error {
McpError::AuthRequired(request) => {
CapabilityError::Unavailable(format!("auth required: {:?}", request))
}
other => CapabilityError::ExecutionFailed(other.to_string()),
})?;
read_resource_result_to_capabilities(result)
.map_err(|error| CapabilityError::ExecutionFailed(error.to_string()))
}
}
pub struct McpPromptHandle {
connection: Arc<McpConnection>,
descriptor: PromptDescriptor,
}
#[async_trait]
impl PromptProvider for McpPromptHandle {
async fn list_prompts(&self) -> Result<Vec<PromptDescriptor>, CapabilityError> {
Ok(vec![self.descriptor.clone()])
}
async fn get_prompt(
&self,
id: &PromptId,
args: Value,
_ctx: &mut CapabilityContext<'_>,
) -> Result<PromptContents, CapabilityError> {
let result =
self.connection
.get_prompt(&id.0, args)
.await
.map_err(|error| match error {
McpError::AuthRequired(request) => {
CapabilityError::Unavailable(format!("auth required: {:?}", request))
}
other => CapabilityError::ExecutionFailed(other.to_string()),
})?;
Ok(get_prompt_result_to_capabilities(result))
}
}
pub struct McpCapabilityProvider {
invocables: Vec<Arc<dyn Invocable>>,
resources: Vec<Arc<dyn ResourceProvider>>,
prompts: Vec<Arc<dyn PromptProvider>>,
}
impl McpCapabilityProvider {
pub fn from_snapshot(connection: Arc<McpConnection>, snapshot: &McpDiscoverySnapshot) -> Self {
Self::from_snapshot_with_namespace(connection, snapshot, &McpToolNamespace::Default)
}
pub fn from_snapshot_with_namespace(
connection: Arc<McpConnection>,
snapshot: &McpDiscoverySnapshot,
namespace: &McpToolNamespace,
) -> Self {
let server_id = connection.server_id().clone();
let registry =
snapshot
.tools
.iter()
.cloned()
.fold(ToolRegistry::new(), |registry, tool| {
registry.with(McpToolAdapter::with_namespace(
&server_id,
connection.clone(),
tool,
namespace,
))
});
let permissions: Arc<dyn PermissionChecker> = Arc::new(AllowAllPermissions);
let resources_arc: Arc<dyn agentkit_tools_core::ToolResources> = Arc::new(());
let invocables =
ToolCapabilityProvider::from_registry(®istry, permissions, resources_arc)
.invocables();
let resources = snapshot
.resources
.iter()
.cloned()
.map(|resource| {
Arc::new(McpResourceHandle {
connection: connection.clone(),
descriptor: resource_descriptor_from_rmcp(resource),
}) as Arc<dyn ResourceProvider>
})
.collect();
let prompts = snapshot
.prompts
.iter()
.cloned()
.map(|prompt| {
Arc::new(McpPromptHandle {
connection: connection.clone(),
descriptor: prompt_descriptor_from_rmcp(prompt),
}) as Arc<dyn PromptProvider>
})
.collect();
Self {
invocables,
resources,
prompts,
}
}
pub fn merge<I>(providers: I) -> Self
where
I: IntoIterator<Item = Self>,
{
let mut invocables = Vec::new();
let mut resources = Vec::new();
let mut prompts = Vec::new();
for provider in providers {
invocables.extend(provider.invocables);
resources.extend(provider.resources);
prompts.extend(provider.prompts);
}
Self {
invocables,
resources,
prompts,
}
}
pub async fn connect(
config: &McpServerConfig,
) -> Result<(Arc<McpConnection>, Self, McpDiscoverySnapshot), McpError> {
let connection = Arc::new(McpConnection::connect(config).await?);
let snapshot = connection.discover().await?;
let provider = Self::from_snapshot(connection.clone(), &snapshot);
Ok((connection, provider, snapshot))
}
}
impl CapabilityProvider for McpCapabilityProvider {
fn invocables(&self) -> Vec<Arc<dyn Invocable>> {
self.invocables.clone()
}
fn resources(&self) -> Vec<Arc<dyn ResourceProvider>> {
self.resources.clone()
}
fn prompts(&self) -> Vec<Arc<dyn PromptProvider>> {
self.prompts.clone()
}
}
#[derive(Clone)]
pub struct McpServerHandle {
config: McpServerConfig,
connection: Arc<McpConnection>,
snapshot: McpDiscoverySnapshot,
namespace: McpToolNamespace,
}
impl McpServerHandle {
pub fn config(&self) -> &McpServerConfig {
&self.config
}
pub fn server_id(&self) -> &McpServerId {
self.connection.server_id()
}
pub fn connection(&self) -> Arc<McpConnection> {
self.connection.clone()
}
pub fn snapshot(&self) -> &McpDiscoverySnapshot {
&self.snapshot
}
pub fn namespace(&self) -> &McpToolNamespace {
&self.namespace
}
pub fn tool_registry(&self) -> ToolRegistry {
self.snapshot
.tools
.iter()
.cloned()
.fold(ToolRegistry::new(), |registry, tool| {
registry.with(McpToolAdapter::with_namespace(
self.server_id(),
self.connection.clone(),
tool,
&self.namespace,
))
})
}
pub fn capability_provider(&self) -> McpCapabilityProvider {
McpCapabilityProvider::from_snapshot_with_namespace(
self.connection.clone(),
&self.snapshot,
&self.namespace,
)
}
}
#[derive(Debug)]
pub struct McpServerConnectionError {
pub server_id: McpServerId,
pub error: McpError,
}
#[must_use = "inspect `failed` before ignoring the settled MCP connection result"]
pub struct McpConnectAllSettled {
pub connected: Vec<McpServerHandle>,
pub failed: Vec<McpServerConnectionError>,
}
impl McpConnectAllSettled {
pub fn all_connected(&self) -> bool {
self.failed.is_empty()
}
pub fn has_failures(&self) -> bool {
!self.failed.is_empty()
}
pub fn connected(&self) -> &[McpServerHandle] {
&self.connected
}
pub fn failed(&self) -> &[McpServerConnectionError] {
&self.failed
}
pub fn into_parts(self) -> (Vec<McpServerHandle>, Vec<McpServerConnectionError>) {
(self.connected, self.failed)
}
}
impl fmt::Debug for McpConnectAllSettled {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let connected = self
.connected
.iter()
.map(|handle| handle.server_id())
.collect::<Vec<_>>();
f.debug_struct("McpConnectAllSettled")
.field("connected", &connected)
.field("failed", &self.failed)
.finish()
}
}
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct McpServerOptions {
pub connect_timeout: Option<Duration>,
}
impl McpServerOptions {
pub fn new() -> Self {
Self::default()
}
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.connect_timeout = Some(timeout);
self
}
}
pub struct McpServerManager {
configs: BTreeMap<McpServerId, McpServerConfig>,
options: BTreeMap<McpServerId, McpServerOptions>,
connections: BTreeMap<McpServerId, McpServerHandle>,
auth: BTreeMap<McpServerId, MetadataMap>,
catalog_tx: broadcast::Sender<McpCatalogEvent>,
namespace: McpToolNamespace,
handler_config: McpHandlerConfig,
catalog_writer: CatalogWriter,
server_tools: BTreeMap<McpServerId, BTreeSet<ToolName>>,
}
impl Default for McpServerManager {
fn default() -> Self {
let (catalog_tx, _) = broadcast::channel(128);
let (catalog_writer, _) = dynamic_catalog("mcp");
Self {
configs: BTreeMap::new(),
options: BTreeMap::new(),
connections: BTreeMap::new(),
auth: BTreeMap::new(),
catalog_tx,
namespace: McpToolNamespace::Default,
handler_config: McpHandlerConfig::default(),
catalog_writer,
server_tools: BTreeMap::new(),
}
}
}
impl McpServerManager {
pub fn new() -> Self {
Self::default()
}
pub fn with_namespace(mut self, namespace: McpToolNamespace) -> Self {
self.namespace = namespace;
self
}
pub fn set_namespace(&mut self, namespace: McpToolNamespace) -> &mut Self {
self.namespace = namespace;
self
}
pub fn namespace(&self) -> &McpToolNamespace {
&self.namespace
}
pub fn with_handler_config(mut self, handler_config: McpHandlerConfig) -> Self {
self.handler_config = handler_config;
self
}
pub fn set_handler_config(&mut self, handler_config: McpHandlerConfig) -> &mut Self {
self.handler_config = handler_config;
self
}
pub fn handler_config(&self) -> &McpHandlerConfig {
&self.handler_config
}
pub fn with_server(mut self, config: McpServerConfig) -> Self {
self.register_server(config);
self
}
pub fn with_server_options(
mut self,
config: McpServerConfig,
options: McpServerOptions,
) -> Self {
self.register_server_with_options(config, options);
self
}
pub fn register_server(&mut self, config: McpServerConfig) -> &mut Self {
let id = config.id.clone();
self.configs.insert(id.clone(), config);
self.options.entry(id).or_default();
self
}
pub fn register_server_with_options(
&mut self,
config: McpServerConfig,
options: McpServerOptions,
) -> &mut Self {
let id = config.id.clone();
self.configs.insert(id.clone(), config);
self.options.insert(id, options);
self
}
pub fn connected_server(&self, server_id: &McpServerId) -> Option<&McpServerHandle> {
self.connections.get(server_id)
}
pub fn connected_servers(&self) -> Vec<&McpServerHandle> {
self.connections.values().collect()
}
pub fn subscribe_catalog_events(&self) -> broadcast::Receiver<McpCatalogEvent> {
self.catalog_tx.subscribe()
}
fn emit_catalog_event(&self, event: McpCatalogEvent) {
let _ = self.catalog_tx.send(event);
}
async fn discover_with_options(
connection: &McpConnection,
options: &McpServerOptions,
) -> Result<McpDiscoverySnapshot, McpError> {
match options.connect_timeout {
Some(timeout) => tokio::time::timeout(timeout, connection.discover())
.await
.map_err(|_| McpError::Timeout {
operation: "discover",
duration: timeout,
})?,
None => connection.discover().await,
}
}
async fn connect_and_discover(
config: &McpServerConfig,
auth: Option<&MetadataMap>,
handler_config: McpHandlerConfig,
options: &McpServerOptions,
) -> Result<(Arc<McpConnection>, McpDiscoverySnapshot), McpError> {
let connect = async {
let connection =
Arc::new(McpConnection::connect_with_auth(config, auth, handler_config).await?);
let snapshot = connection.discover().await?;
Ok((connection, snapshot))
};
match options.connect_timeout {
Some(timeout) => {
tokio::time::timeout(timeout, connect)
.await
.map_err(|_| McpError::Timeout {
operation: "connect",
duration: timeout,
})?
}
None => connect.await,
}
}
pub async fn connect_server(
&mut self,
server_id: &McpServerId,
) -> Result<McpServerHandle, McpError> {
let config = self
.configs
.get(server_id)
.cloned()
.ok_or_else(|| McpError::UnknownServer(server_id.to_string()))?;
let options = self.options.get(server_id).cloned().unwrap_or_default();
let (connection, snapshot) = Self::connect_and_discover(
&config,
self.auth.get(server_id),
self.handler_config.clone(),
&options,
)
.await?;
let handle = McpServerHandle {
config,
connection,
snapshot,
namespace: self.namespace.clone(),
};
self.connections.insert(server_id.clone(), handle.clone());
self.register_server_tools(server_id, &handle.snapshot);
self.emit_catalog_event(McpCatalogEvent::ServerConnected {
server_id: server_id.clone(),
});
Ok(handle)
}
pub async fn connect_all(&mut self) -> Result<Vec<McpServerHandle>, McpError> {
let plans: Vec<(
McpServerId,
McpServerConfig,
McpServerOptions,
Option<MetadataMap>,
)> = self
.configs
.iter()
.map(|(id, cfg)| {
(
id.clone(),
cfg.clone(),
self.options.get(id).cloned().unwrap_or_default(),
self.auth.get(id).cloned(),
)
})
.collect();
let handler_config = self.handler_config.clone();
let namespace = self.namespace.clone();
let futures = plans.into_iter().map(|(server_id, config, options, auth)| {
let handler_config = handler_config.clone();
let namespace = namespace.clone();
async move {
let (connection, snapshot) =
Self::connect_and_discover(&config, auth.as_ref(), handler_config, &options)
.await?;
Ok::<(McpServerId, McpServerHandle), McpError>((
server_id,
McpServerHandle {
config,
connection,
snapshot,
namespace,
},
))
}
});
let results = try_join_all(futures).await?;
let mut handles = Vec::with_capacity(results.len());
let mut connected: Vec<(McpServerId, McpDiscoverySnapshot)> =
Vec::with_capacity(results.len());
for (server_id, handle) in results {
connected.push((server_id.clone(), handle.snapshot.clone()));
self.connections.insert(server_id, handle.clone());
handles.push(handle);
}
for (server_id, snapshot) in &connected {
self.register_server_tools(server_id, snapshot);
}
for (server_id, _) in connected {
self.emit_catalog_event(McpCatalogEvent::ServerConnected { server_id });
}
Ok(handles)
}
pub async fn connect_all_settled(&mut self) -> McpConnectAllSettled {
let plans: Vec<(
McpServerId,
McpServerConfig,
McpServerOptions,
Option<MetadataMap>,
)> = self
.configs
.iter()
.map(|(id, cfg)| {
(
id.clone(),
cfg.clone(),
self.options.get(id).cloned().unwrap_or_default(),
self.auth.get(id).cloned(),
)
})
.collect();
let handler_config = self.handler_config.clone();
let namespace = self.namespace.clone();
let futures = plans.into_iter().map(|(server_id, config, options, auth)| {
let handler_config = handler_config.clone();
let namespace = namespace.clone();
async move {
let result = async {
let (connection, snapshot) = Self::connect_and_discover(
&config,
auth.as_ref(),
handler_config,
&options,
)
.await?;
Ok::<McpServerHandle, McpError>(McpServerHandle {
config,
connection,
snapshot,
namespace,
})
}
.await;
(server_id, result)
}
});
let results = join_all(futures).await;
let mut connected = Vec::new();
let mut failures = Vec::new();
let mut connected_snapshots = Vec::new();
for (server_id, result) in results {
match result {
Ok(handle) => {
connected_snapshots.push((server_id.clone(), handle.snapshot.clone()));
self.connections.insert(server_id, handle.clone());
connected.push(handle);
}
Err(error) => {
failures.push(McpServerConnectionError { server_id, error });
}
}
}
for (server_id, snapshot) in &connected_snapshots {
self.register_server_tools(server_id, snapshot);
}
for (server_id, _) in connected_snapshots {
self.emit_catalog_event(McpCatalogEvent::ServerConnected { server_id });
}
McpConnectAllSettled {
connected,
failed: failures,
}
}
pub async fn refresh_server(
&mut self,
server_id: &McpServerId,
) -> Result<McpDiscoverySnapshot, McpError> {
let handle = self
.connections
.get_mut(server_id)
.ok_or_else(|| McpError::UnknownServer(server_id.to_string()))?;
let options = self.options.get(server_id).cloned().unwrap_or_default();
let previous = handle.snapshot.clone();
let snapshot = match Self::discover_with_options(&handle.connection, &options).await {
Ok(snapshot) => snapshot,
Err(error) => {
self.emit_catalog_event(McpCatalogEvent::RefreshFailed {
server_id: server_id.clone(),
message: error.to_string(),
});
return Err(error);
}
};
handle.snapshot = snapshot.clone();
let events = diff_discovery_snapshots(server_id, &previous, &snapshot);
if !events.is_empty() {
self.apply_catalog_events(server_id, &snapshot, &events);
for event in events {
self.emit_catalog_event(event);
}
}
Ok(snapshot)
}
pub async fn refresh_changed_catalogs(&mut self) -> Result<Vec<McpCatalogEvent>, McpError> {
let server_ids = self.connections.keys().cloned().collect::<Vec<_>>();
let mut emitted = Vec::new();
for server_id in server_ids {
let Some(connection) = self
.connections
.get(&server_id)
.map(McpServerHandle::connection)
else {
continue;
};
let notifications = connection.drain_notifications().await;
if notifications.is_empty() {
continue;
}
let handle = self
.connections
.get_mut(&server_id)
.ok_or_else(|| McpError::UnknownServer(server_id.to_string()))?;
let options = self.options.get(&server_id).cloned().unwrap_or_default();
let previous = handle.snapshot.clone();
let snapshot = match Self::discover_with_options(&handle.connection, &options).await {
Ok(snapshot) => snapshot,
Err(error) => {
let event = McpCatalogEvent::RefreshFailed {
server_id: server_id.clone(),
message: error.to_string(),
};
self.emit_catalog_event(event.clone());
emitted.push(event);
return Err(error);
}
};
handle.snapshot = snapshot.clone();
let events = diff_discovery_snapshots(&server_id, &previous, &snapshot);
if !events.is_empty() {
self.apply_catalog_events(&server_id, &snapshot, &events);
for event in events {
self.emit_catalog_event(event.clone());
emitted.push(event);
}
}
}
Ok(emitted)
}
pub async fn disconnect_server(&mut self, server_id: &McpServerId) -> Result<(), McpError> {
let Some(handle) = self.connections.remove(server_id) else {
return Err(McpError::UnknownServer(server_id.to_string()));
};
handle.connection.close().await?;
self.unregister_server_tools(server_id);
self.emit_catalog_event(McpCatalogEvent::ServerDisconnected {
server_id: server_id.clone(),
});
Ok(())
}
pub async fn resolve_auth(&mut self, resolution: AuthResolution) -> Result<(), McpError> {
let server_id = resolution
.request()
.server_id()
.ok_or_else(|| McpError::AuthResolution("auth resolution missing server id".into()))?;
let server_id = McpServerId::new(server_id);
match &resolution {
AuthResolution::Provided { credentials, .. } => {
self.auth.insert(server_id.clone(), credentials.clone());
}
AuthResolution::Cancelled { .. } => {
self.auth.remove(&server_id);
}
}
if let Some(handle) = self.connections.get(&server_id) {
handle.connection.resolve_auth(resolution).await?;
} else if !self.configs.contains_key(&server_id) {
return Err(McpError::UnknownServer(server_id.to_string()));
}
self.emit_catalog_event(McpCatalogEvent::AuthChanged { server_id });
Ok(())
}
pub fn tool_registry(&self) -> ToolRegistry {
self.connections
.values()
.fold(ToolRegistry::new(), |mut registry, handle| {
for tool in handle.snapshot.tools.iter().cloned() {
registry.register(McpToolAdapter::with_namespace(
handle.server_id(),
handle.connection.clone(),
tool,
&self.namespace,
));
}
registry
})
}
pub fn source(&self) -> CatalogReader {
self.catalog_writer.reader()
}
fn apply_catalog_events(
&mut self,
server_id: &McpServerId,
snapshot: &McpDiscoverySnapshot,
events: &[McpCatalogEvent],
) {
for event in events {
if let McpCatalogEvent::ToolsChanged {
added,
removed,
changed,
..
} = event
{
self.apply_server_tool_diff(server_id, snapshot, added, removed, changed);
}
}
}
fn register_server_tools(&mut self, server_id: &McpServerId, snapshot: &McpDiscoverySnapshot) {
let connection = match self.connections.get(server_id) {
Some(handle) => handle.connection.clone(),
None => return,
};
let previous = self.server_tools.remove(server_id).unwrap_or_default();
let mut names = BTreeSet::new();
for tool in &snapshot.tools {
let adapter = McpToolAdapter::with_namespace(
server_id,
connection.clone(),
tool.clone(),
&self.namespace,
);
names.insert(adapter.spec().name.clone());
self.catalog_writer.upsert(Arc::new(adapter));
}
for stale in previous.difference(&names) {
self.catalog_writer.remove(stale);
}
self.server_tools.insert(server_id.clone(), names);
}
fn unregister_server_tools(&mut self, server_id: &McpServerId) {
let Some(names) = self.server_tools.remove(server_id) else {
return;
};
for name in names {
self.catalog_writer.remove(&name);
}
}
fn apply_server_tool_diff(
&mut self,
server_id: &McpServerId,
snapshot: &McpDiscoverySnapshot,
added: &[String],
removed: &[String],
changed: &[String],
) {
let connection = match self.connections.get(server_id) {
Some(handle) => handle.connection.clone(),
None => return,
};
let names = self.server_tools.entry(server_id.clone()).or_default();
for raw_name in removed {
let agentkit_name = ToolName::new(self.namespace.apply(server_id, raw_name));
if names.remove(&agentkit_name) {
self.catalog_writer.remove(&agentkit_name);
}
}
let upsert_one = |raw_name: &str| -> Option<(ToolName, McpToolAdapter)> {
let tool = snapshot
.tools
.iter()
.find(|tool| tool.name.as_ref() == raw_name)?
.clone();
let adapter = McpToolAdapter::with_namespace(
server_id,
connection.clone(),
tool,
&self.namespace,
);
Some((adapter.spec().name.clone(), adapter))
};
for raw_name in added.iter().chain(changed.iter()) {
if let Some((agentkit_name, adapter)) = upsert_one(raw_name) {
names.insert(agentkit_name);
self.catalog_writer.upsert(Arc::new(adapter));
}
}
}
pub fn capability_provider(&self) -> McpCapabilityProvider {
McpCapabilityProvider::merge(
self.connections
.values()
.map(McpServerHandle::capability_provider),
)
}
}
fn diff_discovery_snapshots(
server_id: &McpServerId,
previous: &McpDiscoverySnapshot,
current: &McpDiscoverySnapshot,
) -> Vec<McpCatalogEvent> {
let mut events = Vec::new();
let (added, removed, changed) = diff_named_items(
previous.tools.iter().map(|item| (item.name.as_ref(), item)),
current.tools.iter().map(|item| (item.name.as_ref(), item)),
);
if !added.is_empty() || !removed.is_empty() || !changed.is_empty() {
events.push(McpCatalogEvent::ToolsChanged {
server_id: server_id.clone(),
added,
removed,
changed,
});
}
let (added, removed, changed) = diff_named_items(
previous
.resources
.iter()
.map(|item| (item.uri.as_str(), item)),
current
.resources
.iter()
.map(|item| (item.uri.as_str(), item)),
);
if !added.is_empty() || !removed.is_empty() || !changed.is_empty() {
events.push(McpCatalogEvent::ResourcesChanged {
server_id: server_id.clone(),
added,
removed,
changed,
});
}
let (added, removed, changed) = diff_named_items(
previous
.prompts
.iter()
.map(|item| (item.name.as_str(), item)),
current
.prompts
.iter()
.map(|item| (item.name.as_str(), item)),
);
if !added.is_empty() || !removed.is_empty() || !changed.is_empty() {
events.push(McpCatalogEvent::PromptsChanged {
server_id: server_id.clone(),
added,
removed,
changed,
});
}
events
}
fn diff_named_items<'a, T>(
previous: impl IntoIterator<Item = (&'a str, &'a T)>,
current: impl IntoIterator<Item = (&'a str, &'a T)>,
) -> (Vec<String>, Vec<String>, Vec<String>)
where
T: PartialEq + 'a,
{
let mut prev: Vec<(&str, &T)> = previous.into_iter().collect();
let mut curr: Vec<(&str, &T)> = current.into_iter().collect();
prev.sort_unstable_by_key(|(name, _)| *name);
curr.sort_unstable_by_key(|(name, _)| *name);
let mut added = Vec::new();
let mut removed = Vec::new();
let mut changed = Vec::new();
let (mut i, mut j) = (0, 0);
while i < prev.len() && j < curr.len() {
match prev[i].0.cmp(curr[j].0) {
std::cmp::Ordering::Less => {
removed.push(prev[i].0.to_string());
i += 1;
}
std::cmp::Ordering::Greater => {
added.push(curr[j].0.to_string());
j += 1;
}
std::cmp::Ordering::Equal => {
if prev[i].1 != curr[j].1 {
changed.push(curr[j].0.to_string());
}
i += 1;
j += 1;
}
}
}
while i < prev.len() {
removed.push(prev[i].0.to_string());
i += 1;
}
while j < curr.len() {
added.push(curr[j].0.to_string());
j += 1;
}
(added, removed, changed)
}
pub struct McpToolAdapter {
tool_name: String,
connection: Arc<McpConnection>,
spec: ToolSpec,
}
impl McpToolAdapter {
pub fn new(server_id: &McpServerId, connection: Arc<McpConnection>, tool: McpTool) -> Self {
Self::with_namespace(server_id, connection, tool, &McpToolNamespace::Default)
}
pub fn with_namespace(
server_id: &McpServerId,
connection: Arc<McpConnection>,
tool: McpTool,
namespace: &McpToolNamespace,
) -> Self {
let spec = tool_spec_from_tool(server_id, &tool, namespace);
Self {
tool_name: tool.name.into_owned(),
connection,
spec,
}
}
async fn handle_invocation_error(
&self,
err: McpInvocationError,
input: &Value,
) -> Result<CallToolResult, ToolError> {
let Some(responder) = self.connection.handler_config().error_responder.clone() else {
return Err(ToolError::ExecutionFailed(err.to_string()));
};
let method = McpMethod::ToolsCall {
name: self.tool_name.clone(),
arguments: input.clone(),
};
let ctx = McpErrorContext {
server_id: self.connection.server_id(),
method: &method,
input: Some(input),
};
match responder.handle(&err, ctx).await {
ErrorResponderOutcome::SynthesizeResult(result) => Ok(result),
ErrorResponderOutcome::PassThrough => Err(ToolError::ExecutionFailed(err.to_string())),
}
}
}
#[async_trait]
impl Tool for McpToolAdapter {
fn spec(&self) -> &ToolSpec {
&self.spec
}
async fn invoke(
&self,
request: ToolRequest,
_ctx: &mut ToolContext<'_>,
) -> Result<ToolResult, ToolError> {
let input = request.input;
let result = match self
.connection
.call_tool(&self.tool_name, input.clone())
.await
{
Ok(result) => result,
Err(McpError::AuthRequired(auth_request)) => {
let responder = self
.connection
.handler_config()
.auth
.clone()
.ok_or_else(|| {
ToolError::ExecutionFailed(
"MCP server requires auth but no McpAuthResponder is registered".into(),
)
})?;
let resolution = responder.resolve(*auth_request).await.map_err(|error| {
ToolError::ExecutionFailed(format!("auth responder failed: {error}"))
})?;
match &resolution {
AuthResolution::Provided { .. } => {
self.connection
.resolve_auth(resolution.clone())
.await
.map_err(|error| {
ToolError::ExecutionFailed(format!(
"applying auth resolution failed: {error}"
))
})?;
}
AuthResolution::Cancelled { .. } => {
return Err(ToolError::ExecutionFailed(
"user cancelled MCP auth flow".into(),
));
}
}
match self
.connection
.call_tool(&self.tool_name, input.clone())
.await
{
Ok(result) => result,
Err(McpError::AuthRequired(req)) => {
return Err(ToolError::ExecutionFailed(format!(
"MCP auth challenge unresolved after retry: {}",
req.id
)));
}
Err(McpError::Invocation(err)) => {
self.handle_invocation_error(err, &input).await?
}
Err(other) => return Err(ToolError::ExecutionFailed(other.to_string())),
}
}
Err(McpError::Invocation(err)) => self.handle_invocation_error(err, &input).await?,
Err(other) => return Err(ToolError::ExecutionFailed(other.to_string())),
};
let is_error = result.is_error.unwrap_or(false);
Ok(ToolResult {
result: ToolResultPart {
call_id: request.call_id,
output: call_tool_result_to_tool_output(result),
is_error,
metadata: MetadataMap::new(),
},
duration: None,
metadata: MetadataMap::new(),
})
}
}
fn rmcp_server_capabilities_to_agentkit(
capabilities: &rmcp_model::ServerCapabilities,
) -> McpServerCapabilities {
McpServerCapabilities {
tools: capabilities.tools.as_ref().map(|tools| ToolsCapability {
list_changed: tools.list_changed,
}),
resources: capabilities
.resources
.as_ref()
.map(|resources| ResourcesCapability {
subscribe: resources.subscribe,
list_changed: resources.list_changed,
}),
prompts: capabilities
.prompts
.as_ref()
.map(|prompts| PromptsCapability {
list_changed: prompts.list_changed,
}),
logging: capabilities.logging.as_ref().map(|_| LoggingCapability {}),
}
}
fn tool_spec_from_tool(
server_id: &McpServerId,
tool: &McpTool,
namespace: &McpToolNamespace,
) -> ToolSpec {
ToolSpec {
name: ToolName::new(namespace.apply(server_id, &tool.name)),
description: tool
.description
.as_ref()
.map(|d| d.to_string())
.unwrap_or_else(|| tool.name.to_string()),
input_schema: Value::Object((*tool.input_schema).clone()),
output_schema: tool
.output_schema
.as_ref()
.map(|schema| Value::Object((**schema).clone())),
annotations: tool_annotations_from_rmcp(tool.annotations.as_ref()),
metadata: MetadataMap::new(),
}
}
fn tool_annotations_from_rmcp(annotations: Option<&McpToolAnnotations>) -> ToolAnnotations {
let Some(annotations) = annotations else {
return ToolAnnotations::default();
};
ToolAnnotations {
read_only_hint: annotations.read_only_hint.unwrap_or(false),
destructive_hint: annotations.destructive_hint.unwrap_or(false),
idempotent_hint: annotations.idempotent_hint.unwrap_or(false),
needs_approval_hint: false,
supports_streaming_hint: false,
}
}
fn resource_descriptor_from_rmcp(resource: McpResource) -> ResourceDescriptor {
let raw = resource.raw;
ResourceDescriptor {
id: ResourceId::new(raw.uri),
name: raw.name,
description: raw.description,
mime_type: raw.mime_type,
metadata: MetadataMap::new(),
}
}
fn prompt_descriptor_from_rmcp(prompt: McpPrompt) -> PromptDescriptor {
let arguments = prompt.arguments.unwrap_or_default();
let mut required = Vec::new();
let properties = arguments
.into_iter()
.map(|argument| {
let mut schema = serde_json::Map::new();
schema.insert("type".into(), Value::String("string".into()));
if let Some(description) = argument.description {
schema.insert("description".into(), Value::String(description));
}
if argument.required.unwrap_or(false) {
required.push(Value::String(argument.name.clone()));
}
(argument.name, Value::Object(schema))
})
.collect::<serde_json::Map<String, Value>>();
let mut input_schema = serde_json::Map::new();
input_schema.insert("type".into(), Value::String("object".into()));
input_schema.insert("properties".into(), Value::Object(properties));
if !required.is_empty() {
input_schema.insert("required".into(), Value::Array(required));
}
PromptDescriptor {
id: PromptId::new(prompt.name.clone()),
name: prompt.name,
description: prompt.description,
input_schema: Value::Object(input_schema),
metadata: MetadataMap::new(),
}
}
fn read_resource_result_to_capabilities(
result: ReadResourceResult,
) -> Result<ResourceContents, McpError> {
let content = result
.contents
.into_iter()
.next()
.ok_or_else(|| McpError::Protocol("resources/read returned no contents".into()))?;
Ok(resource_contents_to_capabilities(content))
}
fn resource_contents_to_capabilities(content: McpResourceContents) -> ResourceContents {
let mut metadata = MetadataMap::new();
let data = match content {
McpResourceContents::TextResourceContents {
text, mime_type, ..
} => {
if let Some(mime) = mime_type {
metadata.insert("mime_type".into(), Value::String(mime));
}
DataRef::InlineText(text)
}
McpResourceContents::BlobResourceContents {
blob,
mime_type,
uri,
..
} => {
if let Some(mime) = mime_type {
metadata.insert("mime_type".into(), Value::String(mime));
}
metadata.insert("uri".into(), Value::String(uri));
DataRef::InlineText(blob)
}
};
ResourceContents { data, metadata }
}
fn get_prompt_result_to_capabilities(result: GetPromptResult) -> PromptContents {
let items = result
.messages
.into_iter()
.map(prompt_message_to_item)
.collect();
let mut metadata = MetadataMap::new();
if let Some(description) = result.description {
metadata.insert("description".into(), Value::String(description));
}
PromptContents { items, metadata }
}
fn prompt_message_to_item(message: PromptMessage) -> Item {
let kind = match message.role {
PromptMessageRole::Assistant => ItemKind::Assistant,
PromptMessageRole::User => ItemKind::User,
};
Item {
id: None,
kind,
parts: vec![prompt_message_content_to_part(message.content)],
metadata: MetadataMap::new(),
usage: None,
finish_reason: None,
created_at: None,
}
}
fn prompt_message_content_to_part(content: PromptMessageContent) -> Part {
match content {
PromptMessageContent::Text { text } => Part::Text(TextPart::new(text)),
PromptMessageContent::Image { image } => Part::Media(MediaPart::new(
Modality::Image,
image.mime_type.clone(),
DataRef::InlineText(image.data.clone()),
)),
PromptMessageContent::Resource { resource } => {
let agentkit_resource = resource_contents_to_capabilities(resource.resource.clone());
agentkit_part_from_resource(agentkit_resource)
}
PromptMessageContent::ResourceLink { link } => Part::Text(TextPart::new(link.uri.clone())),
}
}
fn agentkit_part_from_resource(resource: ResourceContents) -> Part {
let mime = resource
.metadata
.get("mime_type")
.and_then(Value::as_str)
.unwrap_or("text/plain")
.to_string();
Part::Media(MediaPart::new(Modality::Binary, mime, resource.data))
}
fn call_tool_result_to_tool_output(result: CallToolResult) -> ToolOutput {
if let Some(structured) = result.structured_content {
return ToolOutput::Structured(structured);
}
let parts = call_tool_content_to_parts(result.content);
if parts.iter().all(|part| matches!(part, Part::Text(_))) {
let text = parts
.iter()
.filter_map(|part| match part {
Part::Text(text) => Some(text.text.clone()),
_ => None,
})
.collect::<Vec<_>>()
.join("\n");
ToolOutput::Text(text)
} else {
ToolOutput::Parts(parts)
}
}
fn call_tool_content_to_parts(contents: Vec<Content>) -> Vec<Part> {
contents.into_iter().map(content_to_part).collect()
}
fn content_to_part(content: Content) -> Part {
match content.raw {
RawContent::Text(text) => Part::Text(TextPart::new(text.text)),
RawContent::Image(image) => Part::Media(MediaPart::new(
Modality::Image,
image.mime_type,
DataRef::InlineText(image.data),
)),
RawContent::Audio(audio) => Part::Media(MediaPart::new(
Modality::Audio,
audio.mime_type,
DataRef::InlineText(audio.data),
)),
RawContent::Resource(embedded) => {
agentkit_part_from_resource(resource_contents_to_capabilities(embedded.resource))
}
RawContent::ResourceLink(link) => Part::Text(TextPart::new(link.uri)),
}
}
fn value_to_json_object(value: Value, context: &str) -> Result<rmcp_model::JsonObject, McpError> {
match value {
Value::Object(object) => Ok(object),
Value::Null => Ok(serde_json::Map::new()),
other => Err(McpError::Protocol(format!(
"{context} must be a JSON object, got {other}"
))),
}
}
fn bearer_token_from_metadata(metadata: &MetadataMap) -> Option<String> {
["bearer_token", "access_token", "token", "api_key"]
.into_iter()
.find_map(|key| metadata.get(key).and_then(Value::as_str).map(str::to_owned))
}
fn rmcp_initialize_error(config: &McpServerConfig, error: ClientInitializeError) -> McpError {
if let Some(signal) = match &error {
ClientInitializeError::TransportError { error: dyn_err, .. } => {
transport_auth_signal(dyn_err)
}
_ => None,
} {
return McpError::AuthRequired(Box::new(auth_request_from_signal(
&config.id,
McpMethod::Initialize,
signal,
&error.to_string(),
)));
}
McpError::Transport(error.to_string())
}
fn rmcp_service_error(error: ServiceError) -> McpError {
service_error_to_mcp_error(error)
}
fn rmcp_operation_error(
server_id: &McpServerId,
method: McpMethod,
error: ServiceError,
) -> McpError {
if let Some(signal) = service_auth_signal(&error) {
return McpError::AuthRequired(Box::new(auth_request_from_signal(
server_id,
method,
signal,
&error.to_string(),
)));
}
service_error_to_mcp_error(error)
}
fn service_error_to_mcp_error(error: ServiceError) -> McpError {
match error {
ServiceError::McpError(data) => {
McpError::Invocation(McpInvocationError::from_error_data(data))
}
other => McpError::Transport(other.to_string()),
}
}
#[derive(Debug)]
enum AuthSignal {
Required {
www_authenticate: Option<String>,
},
InsufficientScope {
www_authenticate: Option<String>,
required_scope: Option<String>,
},
}
fn service_auth_signal(error: &ServiceError) -> Option<AuthSignal> {
match error {
ServiceError::TransportSend(dyn_err) => transport_auth_signal(dyn_err),
_ => None,
}
}
fn transport_auth_signal(error: &DynamicTransportError) -> Option<AuthSignal> {
let inner = error
.error
.downcast_ref::<StreamableHttpError<reqwest::Error>>()?;
match inner {
StreamableHttpError::AuthRequired(AuthRequiredError {
www_authenticate_header,
..
}) => Some(AuthSignal::Required {
www_authenticate: Some(www_authenticate_header.clone()),
}),
StreamableHttpError::InsufficientScope(InsufficientScopeError {
www_authenticate_header,
required_scope,
..
}) => Some(AuthSignal::InsufficientScope {
www_authenticate: Some(www_authenticate_header.clone()),
required_scope: required_scope.clone(),
}),
_ => None,
}
}
fn auth_request_from_signal(
server_id: &McpServerId,
method: McpMethod,
signal: AuthSignal,
message: &str,
) -> AuthRequest {
let method_name = method.method_name();
let mut challenge = MetadataMap::new();
challenge.insert("server_id".into(), Value::String(server_id.to_string()));
challenge.insert("method".into(), Value::String(method_name.into()));
challenge.insert("message".into(), Value::String(message.into()));
challenge.insert("flow_kind".into(), Value::String("http_bearer".into()));
match signal {
AuthSignal::Required { www_authenticate } => {
if let Some(header) = www_authenticate {
challenge.insert("www_authenticate".into(), Value::String(header));
}
}
AuthSignal::InsufficientScope {
www_authenticate,
required_scope,
} => {
challenge.insert("insufficient_scope".into(), Value::Bool(true));
if let Some(header) = www_authenticate {
challenge.insert("www_authenticate".into(), Value::String(header));
}
if let Some(scope) = required_scope {
challenge.insert("required_scope".into(), Value::String(scope));
}
}
}
AuthRequest {
id: format!("mcp:{}:{}", server_id, method_name),
provider: format!("mcp.{}", server_id),
operation: method.into_auth_operation(server_id),
challenge,
}
}
#[derive(Debug, Clone)]
pub enum McpMethod {
Initialize,
ToolsCall {
name: String,
arguments: Value,
},
ResourcesRead {
uri: String,
},
ResourcesSubscribe {
uri: String,
},
ResourcesUnsubscribe {
uri: String,
},
PromptsGet {
name: String,
arguments: Value,
},
LoggingSetLevel {
level: String,
},
}
impl McpMethod {
pub fn method_name(&self) -> &'static str {
match self {
Self::Initialize => "initialize",
Self::ToolsCall { .. } => "tools/call",
Self::ResourcesRead { .. } => "resources/read",
Self::ResourcesSubscribe { .. } => "resources/subscribe",
Self::ResourcesUnsubscribe { .. } => "resources/unsubscribe",
Self::PromptsGet { .. } => "prompts/get",
Self::LoggingSetLevel { .. } => "logging/setLevel",
}
}
fn into_auth_operation(self, server_id: &McpServerId) -> AuthOperation {
let server = server_id.to_string();
match self {
Self::Initialize => AuthOperation::McpConnect {
server_id: server,
metadata: MetadataMap::new(),
},
Self::ToolsCall { name, arguments } => AuthOperation::McpToolCall {
server_id: server,
tool_name: name,
input: arguments,
metadata: MetadataMap::new(),
},
Self::ResourcesRead { uri } => AuthOperation::McpResourceRead {
server_id: server,
resource_id: uri,
metadata: MetadataMap::new(),
},
Self::PromptsGet { name, arguments } => AuthOperation::McpPromptGet {
server_id: server,
prompt_id: name,
args: arguments,
metadata: MetadataMap::new(),
},
other @ (Self::ResourcesSubscribe { .. }
| Self::ResourcesUnsubscribe { .. }
| Self::LoggingSetLevel { .. }) => {
let method = other.method_name().to_string();
AuthOperation::McpOther {
server_id: server,
method,
params: other.into_params_json(),
metadata: MetadataMap::new(),
}
}
}
}
fn into_params_json(self) -> Value {
match self {
Self::Initialize => json!({}),
Self::ToolsCall { name, arguments } => json!({ "name": name, "arguments": arguments }),
Self::ResourcesRead { uri } => json!({ "uri": uri }),
Self::ResourcesSubscribe { uri } => json!({ "uri": uri }),
Self::ResourcesUnsubscribe { uri } => json!({ "uri": uri }),
Self::PromptsGet { name, arguments } => {
json!({ "name": name, "arguments": arguments })
}
Self::LoggingSetLevel { level } => json!({ "level": level }),
}
}
}
#[derive(Debug, Error)]
pub enum McpError {
#[error("io error: {0}")]
Io(#[from] std::io::Error),
#[error("serialization error: {0}")]
Serialize(#[from] serde_json::Error),
#[error("transport error: {0}")]
Transport(String),
#[error("{operation} timed out after {duration:?}")]
Timeout {
operation: &'static str,
duration: Duration,
},
#[error("protocol error: {0}")]
Protocol(String),
#[error("MCP auth required: {0:?}")]
AuthRequired(Box<AuthRequest>),
#[error("auth resolution error: {0}")]
AuthResolution(String),
#[error("invocation error: {0}")]
Invocation(McpInvocationError),
#[error("unknown MCP server: {0}")]
UnknownServer(String),
}
impl From<&str> for McpServerId {
fn from(value: &str) -> Self {
Self::new(value)
}
}
impl From<String> for McpServerId {
fn from(value: String) -> Self {
Self::new(value)
}
}