use std::borrow::Cow;
#[cfg(feature = "elicitation")]
use std::collections::HashSet;
use thiserror::Error;
#[cfg(feature = "elicitation")]
use url::Url;
use super::*;
#[cfg(feature = "elicitation")]
use crate::model::{
CreateElicitationRequest, CreateElicitationRequestParams, CreateElicitationResult,
ElicitationAction, ElicitationCompletionNotification, ElicitationResponseNotificationParam,
};
use crate::{
model::{
CancelledNotification, CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage,
ClientNotification, ClientRequest, ClientResult, CreateMessageRequest,
CreateMessageRequestParams, CreateMessageResult, ErrorData, ListRootsRequest,
ListRootsResult, LoggingMessageNotification, LoggingMessageNotificationParam,
ProgressNotification, ProgressNotificationParam, PromptListChangedNotification,
ProtocolVersion, ResourceListChangedNotification, ResourceUpdatedNotification,
ResourceUpdatedNotificationParam, ServerInfo, ServerNotification, ServerRequest,
ServerResult, ToolListChangedNotification,
},
transport::DynamicTransportError,
};
/// Marker type selecting the server side of an MCP connection.
///
/// Used as the role parameter of [`Peer`], [`RunningService`] and [`Service`];
/// the concrete message types it selects are bound in its [`ServiceRole`] impl.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RoleServer;
/// Binds the server role to its message types: this side sends `Server*`
/// messages and receives `Client*` messages from the peer.
impl ServiceRole for RoleServer {
    // Messages originated by this (server) side.
    type Req = ServerRequest;
    type Resp = ServerResult;
    type Not = ServerNotification;
    // Messages originated by the remote client.
    type PeerReq = ClientRequest;
    type PeerResp = ClientResult;
    type PeerNot = ClientNotification;
    // Info describing this side and the remote client respectively.
    type Info = ServerInfo;
    type PeerInfo = ClientInfo;
    // Error type for the server-side initialize handshake.
    type InitializeError = ServerInitializeError;
    const IS_CLIENT: bool = false;
}
/// Errors that can occur while a server performs the MCP initialize handshake.
#[derive(Error, Debug)]
pub enum ServerInitializeError {
    /// The first client message was not an `initialize` request; carries the
    /// message that was actually received.
    #[error("expect initialized request, but received: {0:?}")]
    ExpectedInitializeRequest(Option<ClientJsonRpcMessage>),
    /// The message following the initialize response was not the
    /// `initialized` notification; carries the message that was received.
    #[error("expect initialized notification, but received: {0:?}")]
    ExpectedInitializedNotification(Option<ClientJsonRpcMessage>),
    /// The transport closed while waiting for a handshake message; the string
    /// describes which message was being awaited.
    #[error("connection closed: {0}")]
    ConnectionClosed(String),
    /// The service answered the initialize request with something other than
    /// an `InitializeResult`.
    #[error("unexpected initialize result: {0:?}")]
    UnexpectedInitializeResponse(ServerResult),
    /// The service rejected the initialize request with this error (which is
    /// also sent back to the client).
    #[error("initialize failed: {0}")]
    InitializeFailed(ErrorData),
    /// The client's protocol version could not be ordered relative to the
    /// server's.
    #[error("unsupported protocol version: {0}")]
    UnsupportedProtocolVersion(ProtocolVersion),
    /// Sending a handshake message over the transport failed.
    #[error("Send message error {error}, when {context}")]
    TransportError {
        /// Type-erased error from the underlying transport.
        error: DynamicTransportError,
        /// What was being sent when the error occurred.
        context: Cow<'static, str>,
    },
    /// The cancellation token fired before the handshake completed.
    #[error("Cancelled")]
    Cancelled,
}
impl ServerInitializeError {
pub fn transport<T: Transport<RoleServer> + 'static>(
error: T::Error,
context: impl Into<Cow<'static, str>>,
) -> Self {
Self::TransportError {
error: DynamicTransportError::new::<T, _>(error),
context: context.into(),
}
}
}
/// Handle a server uses to send requests and notifications to its client.
pub type ClientSink = Peer<RoleServer>;
// Every server-role `Service` gets `ServiceExt` automatically; serving is
// delegated to `serve_server_with_ct`, which runs the initialize handshake.
impl<S: Service<RoleServer>> ServiceExt<RoleServer> for S {
    fn serve_with_ct<T, E, A>(
        self,
        transport: T,
        ct: CancellationToken,
    ) -> impl Future<Output = Result<RunningService<RoleServer, Self>, ServerInitializeError>> + Send
    where
        T: IntoTransport<RoleServer, E, A>,
        E: std::error::Error + Send + Sync + 'static,
        Self: Sized,
    {
        serve_server_with_ct(self, transport, ct)
    }
}
/// Serves `service` over `transport` as an MCP server.
///
/// Equivalent to [`serve_server_with_ct`] called with a freshly created
/// [`CancellationToken`].
///
/// # Errors
/// Returns a [`ServerInitializeError`] if the initialize handshake fails.
pub async fn serve_server<S, T, E, A>(
    service: S,
    transport: T,
) -> Result<RunningService<RoleServer, S>, ServerInitializeError>
where
    S: Service<RoleServer>,
    T: IntoTransport<RoleServer, E, A>,
    E: std::error::Error + Send + Sync + 'static,
{
    let token = CancellationToken::new();
    serve_server_with_ct(service, transport, token).await
}
/// Receives the next client message, or reports the connection as closed.
///
/// `context` names the message being awaited and ends up in
/// [`ServerInitializeError::ConnectionClosed`] when the transport yields `None`.
async fn expect_next_message<T>(
    transport: &mut T,
    context: &str,
) -> Result<ClientJsonRpcMessage, ServerInitializeError>
where
    T: Transport<RoleServer>,
{
    match transport.receive().await {
        Some(message) => Ok(message),
        None => Err(ServerInitializeError::ConnectionClosed(
            context.to_string(),
        )),
    }
}
/// Receives the next client message and requires it to be a request.
///
/// Returns the request together with its JSON-RPC id. Any other message kind
/// is reported as [`ServerInitializeError::ExpectedInitializeRequest`]
/// carrying the offending message.
async fn expect_request<T>(
    transport: &mut T,
    context: &str,
) -> Result<(ClientRequest, RequestId), ServerInitializeError>
where
    T: Transport<RoleServer>,
{
    let message = expect_next_message(transport, context).await?;
    // `into_request` consumes the message, so keep a copy for the error path.
    let received = message.clone();
    match message.into_request() {
        Some(request_and_id) => Ok(request_and_id),
        None => Err(ServerInitializeError::ExpectedInitializeRequest(Some(
            received,
        ))),
    }
}
/// Receives the next client message and requires it to be a notification.
///
/// Any other message kind is reported as
/// [`ServerInitializeError::ExpectedInitializedNotification`] carrying the
/// offending message.
async fn expect_notification<T>(
    transport: &mut T,
    context: &str,
) -> Result<ClientNotification, ServerInitializeError>
where
    T: Transport<RoleServer>,
{
    let message = expect_next_message(transport, context).await?;
    // `into_notification` consumes the message, so keep a copy for the error path.
    let received = message.clone();
    match message.into_notification() {
        Some(notification) => Ok(notification),
        None => Err(ServerInitializeError::ExpectedInitializedNotification(
            Some(received),
        )),
    }
}
pub async fn serve_server_with_ct<S, T, E, A>(
service: S,
transport: T,
ct: CancellationToken,
) -> Result<RunningService<RoleServer, S>, ServerInitializeError>
where
S: Service<RoleServer>,
T: IntoTransport<RoleServer, E, A>,
E: std::error::Error + Send + Sync + 'static,
{
tokio::select! {
result = serve_server_with_ct_inner(service, transport.into_transport(), ct.clone()) => { result }
_ = ct.cancelled() => {
Err(ServerInitializeError::Cancelled)
}
}
}
/// Runs the server side of the MCP initialize handshake, then starts the
/// service loop.
///
/// Sequence:
/// 1. Receive the first client message; it must be an `InitializeRequest`.
/// 2. Pass it to `service.handle_request`; the result must be an
///    `InitializeResult`. A service error is sent back to the client and
///    returned as [`ServerInitializeError::InitializeFailed`].
/// 3. Negotiate the protocol version. If the client's version orders below the
///    server's, the client's version is used; otherwise the server's is used.
///    Versions that cannot be ordered yield
///    [`ServerInitializeError::UnsupportedProtocolVersion`].
/// 4. Send the initialize response.
/// 5. Receive the next client message; it must be an `InitializedNotification`,
///    which is forwarded to `service.handle_notification` (result ignored).
/// 6. Hand the transport and peer over to `serve_inner`.
async fn serve_server_with_ct_inner<S, T>(
    service: S,
    transport: T,
    ct: CancellationToken,
) -> Result<RunningService<RoleServer, S>, ServerInitializeError>
where
    S: Service<RoleServer>,
    T: Transport<RoleServer> + 'static,
{
    // `T` is already a `Transport`; this presumably resolves to an identity
    // `IntoTransport` conversion — confirm.
    let mut transport = transport.into_transport();
    let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
    // NOTE(review): context says "initialized request" although this awaits
    // the `initialize` request.
    let (request, id) = expect_request(&mut transport, "initialized request").await?;
    let ClientRequest::InitializeRequest(peer_info) = &request else {
        return Err(ServerInitializeError::ExpectedInitializeRequest(Some(
            ClientJsonRpcMessage::request(request, id),
        )));
    };
    // The peer handle records the client's initialize params as its peer info.
    let (peer, peer_rx) = Peer::new(id_provider, Some(peer_info.params.clone()));
    let context = RequestContext {
        // Child token: cancelling `ct` also cancels the initialize handler.
        ct: ct.child_token(),
        id: id.clone(),
        meta: request.get_meta().clone(),
        extensions: request.extensions().clone(),
        peer: peer.clone(),
    };
    let init_response = service.handle_request(request.clone(), context).await;
    let mut init_response = match init_response {
        Ok(ServerResult::InitializeResult(init_response)) => init_response,
        Ok(result) => {
            // NOTE(review): no response is sent to the client on this path
            // before returning.
            return Err(ServerInitializeError::UnexpectedInitializeResponse(result));
        }
        Err(e) => {
            // Report the service's rejection to the client before aborting.
            transport
                .send(ServerJsonRpcMessage::error(e.clone(), id))
                .await
                .map_err(|error| {
                    ServerInitializeError::transport::<T>(error, "sending error response")
                })?;
            return Err(ServerInitializeError::InitializeFailed(e));
        }
    };
    // Version negotiation: answer with the lower of the two versions.
    let peer_protocol_version = peer_info.params.protocol_version.clone();
    let protocol_version = match peer_protocol_version
        .partial_cmp(&init_response.protocol_version)
        .ok_or(ServerInitializeError::UnsupportedProtocolVersion(
            peer_protocol_version,
        ))? {
        std::cmp::Ordering::Less => peer_info.params.protocol_version.clone(),
        _ => init_response.protocol_version,
    };
    init_response.protocol_version = protocol_version;
    transport
        .send(ServerJsonRpcMessage::response(
            ServerResult::InitializeResult(init_response),
            id,
        ))
        .await
        .map_err(|error| {
            ServerInitializeError::transport::<T>(error, "sending initialize response")
        })?;
    let notification = expect_notification(&mut transport, "initialize notification").await?;
    let ClientNotification::InitializedNotification(_) = notification else {
        return Err(ServerInitializeError::ExpectedInitializedNotification(
            Some(ClientJsonRpcMessage::notification(notification)),
        ));
    };
    let context = NotificationContext {
        meta: notification.get_meta().clone(),
        extensions: notification.extensions().clone(),
        peer: peer.clone(),
    };
    // Best-effort: the handler's result for `initialized` is deliberately ignored.
    let _ = service.handle_notification(notification, context).await;
    Ok(serve_inner(service, transport, peer, peer_rx, ct))
}
// Generates thin `Peer<RoleServer>` wrappers around `send_request` /
// `send_notification` for server-to-client message types.
//
// Arms:
// - `peer_req m Req() => Resp`: request without params; returns `Resp`.
// - `peer_req m Req(Param) => Resp`: request with params; returns `Resp`.
// - `peer_req m Req(Param)`: request with params whose reply must be
//   `ClientResult::EmptyResult`; returns `()`.
// - `peer_not m Not(Param)` / `peer_not m Not`: notification with/without params.
// - `peer_req_with_timeout ...`: like `peer_req`, but sent through
//   `send_request_with_option` with the given timeout and no `_meta`.
//
// A reply of any other `ClientResult` variant maps to
// `ServiceError::UnexpectedResponse`.
macro_rules! method {
    (peer_req $method:ident $Req:ident() => $Resp: ident ) => {
        pub async fn $method(&self) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ServerRequest::$Req($Req {
                    method: Default::default(),
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ClientResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident) => $Resp: ident ) => {
        pub async fn $method(&self, params: $Param) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ServerRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ClientResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident)) => {
        pub fn $method(
            &self,
            params: $Param,
        ) -> impl Future<Output = Result<(), ServiceError>> + Send + '_ {
            async move {
                let result = self
                    .send_request(ServerRequest::$Req($Req {
                        method: Default::default(),
                        params,
                        // Every other request-building arm sets `extensions`;
                        // without it this arm would not compile for request
                        // structs of the same shape.
                        extensions: Default::default(),
                    }))
                    .await?;
                match result {
                    ClientResult::EmptyResult(_) => Ok(()),
                    _ => Err(ServiceError::UnexpectedResponse),
                }
            }
        }
    };
    (peer_not $method:ident $Not:ident($Param: ident)) => {
        pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> {
            self.send_notification(ServerNotification::$Not($Not {
                method: Default::default(),
                params,
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };
    (peer_not $method:ident $Not:ident) => {
        pub async fn $method(&self) -> Result<(), ServiceError> {
            self.send_notification(ServerNotification::$Not($Not {
                method: Default::default(),
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };
    (peer_req_with_timeout $method_with_timeout:ident $Req:ident() => $Resp: ident) => {
        pub async fn $method_with_timeout(
            &self,
            timeout: Option<std::time::Duration>,
        ) -> Result<$Resp, ServiceError> {
            let request = ServerRequest::$Req($Req {
                method: Default::default(),
                extensions: Default::default(),
            });
            let options = crate::service::PeerRequestOptions {
                timeout,
                meta: None,
            };
            let result = self
                .send_request_with_option(request, options)
                .await?
                .await_response()
                .await?;
            match result {
                ClientResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req_with_timeout $method_with_timeout:ident $Req:ident($Param: ident) => $Resp: ident) => {
        pub async fn $method_with_timeout(
            &self,
            params: $Param,
            timeout: Option<std::time::Duration>,
        ) -> Result<$Resp, ServiceError> {
            let request = ServerRequest::$Req($Req {
                method: Default::default(),
                params,
                extensions: Default::default(),
            });
            let options = crate::service::PeerRequestOptions {
                timeout,
                meta: None,
            };
            let result = self
                .send_request_with_option(request, options)
                .await?
                .await_response()
                .await?;
            match result {
                ClientResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
}
impl Peer<RoleServer> {
    /// Reports whether the connected client declared the `tools`
    /// sub-capability of `sampling`.
    ///
    /// Returns `false` when no client info is available.
    pub fn supports_sampling_tools(&self) -> bool {
        let Some(client_info) = self.peer_info() else {
            return false;
        };
        client_info
            .capabilities
            .sampling
            .as_ref()
            .is_some_and(|sampling| sampling.tools.is_some())
    }

    /// Asks the client to create a sampled message (`CreateMessageRequest`).
    ///
    /// # Errors
    /// - `ServiceError::McpError` (invalid params), without sending anything,
    ///   if `tools`/`tool_choice` are set but the client lacks the
    ///   sampling-tools capability, or if `params.validate()` fails.
    /// - `ServiceError::UnexpectedResponse` if the reply is not a
    ///   `CreateMessageResult`.
    /// - Any error returned by `send_request`.
    pub async fn create_message(
        &self,
        params: CreateMessageRequestParams,
    ) -> Result<CreateMessageResult, ServiceError> {
        let wants_tools = params.tools.is_some() || params.tool_choice.is_some();
        if wants_tools && !self.supports_sampling_tools() {
            return Err(ServiceError::McpError(ErrorData::invalid_params(
                "tools or toolChoice provided but client does not support sampling tools capability",
                None,
            )));
        }
        if let Err(reason) = params.validate() {
            return Err(ServiceError::McpError(ErrorData::invalid_params(
                reason, None,
            )));
        }
        let request = ServerRequest::CreateMessageRequest(CreateMessageRequest {
            method: Default::default(),
            params,
            extensions: Default::default(),
        });
        match self.send_request(request).await? {
            ClientResult::CreateMessageResult(result) => Ok(*result),
            _ => Err(ServiceError::UnexpectedResponse),
        }
    }

    method!(peer_req list_roots ListRootsRequest() => ListRootsResult);
    #[cfg(feature = "elicitation")]
    method!(peer_req create_elicitation CreateElicitationRequest(CreateElicitationRequestParams) => CreateElicitationResult);
    #[cfg(feature = "elicitation")]
    method!(peer_req_with_timeout create_elicitation_with_timeout CreateElicitationRequest(CreateElicitationRequestParams) => CreateElicitationResult);
    #[cfg(feature = "elicitation")]
    method!(peer_not notify_url_elicitation_completed ElicitationCompletionNotification(ElicitationResponseNotificationParam));
    method!(peer_not notify_cancelled CancelledNotification(CancelledNotificationParam));
    method!(peer_not notify_progress ProgressNotification(ProgressNotificationParam));
    method!(peer_not notify_logging_message LoggingMessageNotification(LoggingMessageNotificationParam));
    method!(peer_not notify_resource_updated ResourceUpdatedNotification(ResourceUpdatedNotificationParam));
    method!(peer_not notify_resource_list_changed ResourceListChangedNotification);
    method!(peer_not notify_tool_list_changed ToolListChangedNotification);
    method!(peer_not notify_prompt_list_changed PromptListChangedNotification);
}
/// Errors returned by the high-level elicitation helpers on `Peer<RoleServer>`.
#[cfg(feature = "elicitation")]
#[derive(Error, Debug)]
pub enum ElicitationError {
    /// The underlying request failed. This variant also wraps an invalid
    /// schema for the requested type, reported as `ServiceError::McpError`.
    #[error("Service error: {0}")]
    Service(#[from] ServiceError),
    /// The client answered with the `Decline` action.
    #[error("User explicitly declined the request")]
    UserDeclined,
    /// The client answered with the `Cancel` action.
    #[error("User cancelled/dismissed the request")]
    UserCancelled,
    /// Accepted content could not be deserialized into the requested type;
    /// `data` holds the raw JSON that was received.
    #[error("Failed to parse response data: {error}\nReceived data: {data}")]
    ParseError {
        error: serde_json::Error,
        data: serde_json::Value,
    },
    /// The client accepted but sent no content.
    #[error("No response content provided")]
    NoContent,
    /// The client did not declare the elicitation mode required by the call.
    #[error("Client does not support elicitation - capability not declared during initialization")]
    CapabilityNotSupported,
}
/// Opt-in marker for types that may be requested through form elicitation.
///
/// Requires `schemars::JsonSchema`. Implement it with the `elicit_safe!` macro.
#[cfg(feature = "elicitation")]
pub trait ElicitationSafe: schemars::JsonSchema {}
/// Implements `ElicitationSafe` for each listed type.
///
/// Accepts a comma-separated list of types with an optional trailing comma,
/// e.g. `elicit_safe!(MyForm, OtherForm);`.
#[cfg(feature = "elicitation")]
#[macro_export]
macro_rules! elicit_safe {
    ($($t:ty),* $(,)?) => {
        $(
            impl $crate::service::ElicitationSafe for $t {}
        )*
    };
}
/// Elicitation interaction modes a client can declare support for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ElicitationMode {
    /// Structured input validated against a requested schema
    /// (`FormElicitationParams`).
    Form,
    /// The client is directed to a URL (`UrlElicitationParams`).
    Url,
}
#[cfg(feature = "elicitation")]
impl Peer<RoleServer> {
    /// Returns the elicitation modes the connected client declared.
    ///
    /// - No client info, or no `elicitation` capability: empty set.
    /// - `elicitation` present with neither `form` nor `url`: `{Form}`.
    /// - Otherwise: `Form` if `form` is declared, `Url` if `url` is declared.
    pub fn supported_elicitation_modes(&self) -> HashSet<ElicitationMode> {
        let Some(client_info) = self.peer_info() else {
            return HashSet::new();
        };
        let Some(elicit_capability) = &client_info.capabilities.elicitation else {
            return HashSet::new();
        };
        let has_form = elicit_capability.form.is_some();
        let has_url = elicit_capability.url.is_some();
        let mut modes = HashSet::new();
        // A bare capability with no sub-mode declared implies form mode only.
        if has_form || !has_url {
            modes.insert(ElicitationMode::Form);
        }
        if has_url {
            modes.insert(ElicitationMode::Url);
        }
        modes
    }

    /// Requests structured input of type `T` from the user, without a timeout.
    ///
    /// See [`Self::elicit_with_timeout`] for return values and errors.
    #[cfg(all(feature = "schemars", feature = "elicitation"))]
    pub async fn elicit<T>(&self, message: impl Into<String>) -> Result<Option<T>, ElicitationError>
    where
        T: ElicitationSafe + for<'de> serde::Deserialize<'de>,
    {
        self.elicit_with_timeout(message, None).await
    }

    /// Requests structured input of type `T` from the user via form
    /// elicitation, waiting at most `timeout` (if given) for the reply.
    ///
    /// Returns `Ok(Some(value))` when the user accepts and the content parses
    /// as `T`.
    ///
    /// # Errors
    /// - `CapabilityNotSupported` if the client does not support form mode.
    /// - `Service` if `T`'s schema is invalid or the request fails.
    /// - `UserDeclined` / `UserCancelled` for the corresponding actions.
    /// - `NoContent` if the user accepted without content.
    /// - `ParseError` if the content does not deserialize into `T`.
    #[cfg(all(feature = "schemars", feature = "elicitation"))]
    pub async fn elicit_with_timeout<T>(
        &self,
        message: impl Into<String>,
        timeout: Option<std::time::Duration>,
    ) -> Result<Option<T>, ElicitationError>
    where
        T: ElicitationSafe + for<'de> serde::Deserialize<'de>,
    {
        if !self
            .supported_elicitation_modes()
            .contains(&ElicitationMode::Form)
        {
            return Err(ElicitationError::CapabilityNotSupported);
        }
        let schema = crate::model::ElicitationSchema::from_type::<T>().map_err(|e| {
            ElicitationError::Service(ServiceError::McpError(crate::ErrorData::invalid_params(
                format!(
                    "Invalid schema for type {}: {}",
                    std::any::type_name::<T>(),
                    e
                ),
                None,
            )))
        })?;
        let response = self
            .create_elicitation_with_timeout(
                CreateElicitationRequestParams::FormElicitationParams {
                    meta: None,
                    message: message.into(),
                    requested_schema: schema,
                },
                timeout,
            )
            .await?;
        match response.action {
            ElicitationAction::Accept => {
                let value = response.content.ok_or(ElicitationError::NoContent)?;
                // Deserialize from a borrow so the raw JSON remains available
                // for the error report without deep-cloning it up front.
                T::deserialize(&value)
                    .map(Some)
                    .map_err(|error| ElicitationError::ParseError { error, data: value })
            }
            ElicitationAction::Decline => Err(ElicitationError::UserDeclined),
            ElicitationAction::Cancel => Err(ElicitationError::UserCancelled),
        }
    }

    /// Asks the client to send the user to `url`, without a timeout.
    ///
    /// See [`Self::elicit_url_with_timeout`].
    #[cfg(feature = "elicitation")]
    pub async fn elicit_url(
        &self,
        message: impl Into<String>,
        url: impl Into<Url>,
        elicitation_id: impl Into<String>,
    ) -> Result<ElicitationAction, ElicitationError> {
        self.elicit_url_with_timeout(message, url, elicitation_id, None)
            .await
    }

    /// Asks the client to send the user to `url` (URL-mode elicitation),
    /// waiting at most `timeout` (if given), and returns the user's action.
    ///
    /// # Errors
    /// - `CapabilityNotSupported` if the client does not support URL mode.
    /// - `Service` if the request fails.
    #[cfg(feature = "elicitation")]
    pub async fn elicit_url_with_timeout(
        &self,
        message: impl Into<String>,
        url: impl Into<Url>,
        elicitation_id: impl Into<String>,
        timeout: Option<std::time::Duration>,
    ) -> Result<ElicitationAction, ElicitationError> {
        if !self
            .supported_elicitation_modes()
            .contains(&ElicitationMode::Url)
        {
            return Err(ElicitationError::CapabilityNotSupported);
        }
        let action = self
            .create_elicitation_with_timeout(
                CreateElicitationRequestParams::UrlElicitationParams {
                    meta: None,
                    message: message.into(),
                    url: url.into().to_string(),
                    elicitation_id: elicitation_id.into(),
                },
                timeout,
            )
            .await?
            .action;
        Ok(action)
    }
}