use std::borrow::Cow;
use thiserror::Error;
use super::*;
use crate::{
model::{
ArgumentInfo, CallToolRequest, CallToolRequestParams, CallToolResult,
CancelledNotification, CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage,
ClientNotification, ClientRequest, ClientResult, CompleteRequest, CompleteRequestParams,
CompleteResult, CompletionContext, CompletionInfo, ErrorData, GetPromptRequest,
GetPromptRequestParams, GetPromptResult, InitializeRequest, InitializedNotification,
JsonRpcResponse, ListPromptsRequest, ListPromptsResult, ListResourceTemplatesRequest,
ListResourceTemplatesResult, ListResourcesRequest, ListResourcesResult, ListToolsRequest,
ListToolsResult, PaginatedRequestParams, ProgressNotification, ProgressNotificationParam,
ReadResourceRequest, ReadResourceRequestParams, ReadResourceResult, Reference, RequestId,
RootsListChangedNotification, ServerInfo, ServerJsonRpcMessage, ServerNotification,
ServerRequest, ServerResult, SetLevelRequest, SetLevelRequestParams, SubscribeRequest,
SubscribeRequestParams, UnsubscribeRequest, UnsubscribeRequestParams,
},
transport::DynamicTransportError,
};
/// Errors that can end the client side of the initialization handshake before a
/// `RunningService` is produced.
#[derive(Error, Debug)]
pub enum ClientInitializeError {
    /// An initialize response was expected, but a different message (or nothing) arrived.
    /// Holds the message that was received instead, if any.
    #[error("expect initialized response, but received: {0:?}")]
    ExpectedInitResponse(Option<ServerJsonRpcMessage>),
    /// The response to the initialize request carried a result other than
    /// `ServerResult::InitializeResult`; holds that result.
    #[error("expect initialized result, but received: {0:?}")]
    ExpectedInitResult(Option<ServerResult>),
    /// The response id did not match the initialize request id, as `(expected, got)`.
    #[error("conflict initialized response id: expected {0}, got {1}")]
    ConflictInitResponseId(RequestId, RequestId),
    /// The transport stopped yielding messages (`receive` returned `None`); holds the
    /// caller-supplied description of what was being waited for.
    #[error("connection closed: {0}")]
    ConnectionClosed(String),
    /// A transport operation failed (in this module: sending a message).
    #[error("Send message error {error}, when {context}")]
    TransportError {
        /// The transport's error, wrapped so it no longer names the concrete transport type.
        error: DynamicTransportError,
        /// Short description of the operation that was being attempted.
        context: Cow<'static, str>,
    },
    /// The server answered with a JSON-RPC error instead of a result.
    #[error("JSON-RPC error: {0}")]
    JsonRpcError(ErrorData),
    /// The handshake was aborted because its `CancellationToken` was cancelled.
    #[error("Cancelled")]
    Cancelled,
}
impl ClientInitializeError {
pub fn transport<T: Transport<RoleClient> + 'static>(
error: T::Error,
context: impl Into<Cow<'static, str>>,
) -> Self {
Self::TransportError {
error: DynamicTransportError::new::<T, _>(error),
context: context.into(),
}
}
}
/// Waits for the next message from the server.
///
/// # Errors
/// Returns [`ClientInitializeError::ConnectionClosed`] carrying `context` when the
/// transport yields no further messages.
async fn expect_next_message<T>(
    transport: &mut T,
    context: &str,
) -> Result<ServerJsonRpcMessage, ClientInitializeError>
where
    T: Transport<RoleClient>,
{
    match transport.receive().await {
        Some(message) => Ok(message),
        None => Err(ClientInitializeError::ConnectionClosed(context.to_owned())),
    }
}
async fn expect_response<T, S>(
transport: &mut T,
context: &str,
service: &S,
peer: Peer<RoleClient>,
) -> Result<(ServerResult, RequestId), ClientInitializeError>
where
T: Transport<RoleClient>,
S: Service<RoleClient>,
{
loop {
let message = expect_next_message(transport, context).await?;
match message {
ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => {
break Ok((result, id));
}
ServerJsonRpcMessage::Error(error) => {
break Err(ClientInitializeError::JsonRpcError(error.error));
}
ServerJsonRpcMessage::Notification(mut notification) => {
let ServerNotification::LoggingMessageNotification(logging) =
&mut notification.notification
else {
tracing::warn!(?notification, "Received unexpected message");
continue;
};
let mut context = NotificationContext {
peer: peer.clone(),
meta: Meta::default(),
extensions: Extensions::default(),
};
if let Some(meta) = logging.extensions.get_mut::<Meta>() {
std::mem::swap(&mut context.meta, meta);
}
std::mem::swap(&mut context.extensions, &mut logging.extensions);
if let Err(error) = service
.handle_notification(notification.notification, context)
.await
{
tracing::warn!(?error, "Handle logging before handshake failed.");
}
}
ServerJsonRpcMessage::Request(ref request)
if matches!(request.request, ServerRequest::PingRequest(_)) =>
{
tracing::trace!("Received ping request. Ignored.")
}
_ => tracing::warn!(?message, "Received unexpected message"),
}
}
}
/// Zero-sized marker type selecting the client side of the protocol.
///
/// It is used as the role parameter of `Peer`, `Service`, `RunningService` and related
/// generics.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RoleClient;
// Binds the client role to its message types: the client sends `ClientRequest`s,
// `ClientResult`s and `ClientNotification`s, and receives the `Server*` counterparts.
impl ServiceRole for RoleClient {
    type Req = ClientRequest;
    type Resp = ClientResult;
    type Not = ClientNotification;
    type PeerReq = ServerRequest;
    type PeerResp = ServerResult;
    type PeerNot = ServerNotification;
    type Info = ClientInfo;
    type PeerInfo = ServerInfo;
    type InitializeError = ClientInitializeError;
    const IS_CLIENT: bool = true;
}
/// The client-side `Peer`, through which requests and notifications are sent to the server.
pub type ServerSink = Peer<RoleClient>;
/// Lets any `Service<RoleClient>` be started as a client via `ServiceExt`.
impl<S: Service<RoleClient>> ServiceExt<RoleClient> for S {
    /// Delegates to `serve_client_with_ct`: performs the client initialization handshake
    /// over `transport`, aborting with [`ClientInitializeError::Cancelled`] if `ct` is
    /// cancelled while the handshake is in progress.
    fn serve_with_ct<T, E, A>(
        self,
        transport: T,
        ct: CancellationToken,
    ) -> impl Future<Output = Result<RunningService<RoleClient, Self>, ClientInitializeError>> + Send
    where
        T: IntoTransport<RoleClient, E, A>,
        E: std::error::Error + Send + Sync + 'static,
        Self: Sized,
    {
        serve_client_with_ct(self, transport, ct)
    }
}
pub async fn serve_client<S, T, E, A>(
service: S,
transport: T,
) -> Result<RunningService<RoleClient, S>, ClientInitializeError>
where
S: Service<RoleClient>,
T: IntoTransport<RoleClient, E, A>,
E: std::error::Error + Send + Sync + 'static,
{
serve_client_with_ct(service, transport, Default::default()).await
}
pub async fn serve_client_with_ct<S, T, E, A>(
service: S,
transport: T,
ct: CancellationToken,
) -> Result<RunningService<RoleClient, S>, ClientInitializeError>
where
S: Service<RoleClient>,
T: IntoTransport<RoleClient, E, A>,
E: std::error::Error + Send + Sync + 'static,
{
tokio::select! {
result = serve_client_with_ct_inner(service, transport.into_transport(), ct.clone()) => { result }
_ = ct.cancelled() => {
Err(ClientInitializeError::Cancelled)
}
}
}
async fn serve_client_with_ct_inner<S, T>(
service: S,
transport: T,
ct: CancellationToken,
) -> Result<RunningService<RoleClient, S>, ClientInitializeError>
where
S: Service<RoleClient>,
T: Transport<RoleClient> + 'static,
{
let mut transport = transport.into_transport();
let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
let id = id_provider.next_request_id();
let init_request = InitializeRequest {
method: Default::default(),
params: service.get_info(),
extensions: Default::default(),
};
transport
.send(ClientJsonRpcMessage::request(
ClientRequest::InitializeRequest(init_request),
id.clone(),
))
.await
.map_err(|error| ClientInitializeError::TransportError {
error: DynamicTransportError::new::<T, _>(error),
context: "send initialize request".into(),
})?;
let (peer, peer_rx) = Peer::new(id_provider, None);
let (response, response_id) = expect_response(
&mut transport,
"initialize response",
&service,
peer.clone(),
)
.await?;
if id != response_id {
return Err(ClientInitializeError::ConflictInitResponseId(
id,
response_id,
));
}
let ServerResult::InitializeResult(initialize_result) = response else {
return Err(ClientInitializeError::ExpectedInitResult(Some(response)));
};
peer.set_peer_info(initialize_result);
let notification = ClientJsonRpcMessage::notification(
ClientNotification::InitializedNotification(InitializedNotification {
method: Default::default(),
extensions: Default::default(),
}),
);
transport.send(notification).await.map_err(|error| {
ClientInitializeError::transport::<T>(error, "send initialized notification")
})?;
Ok(serve_inner(service, transport, peer, peer_rx, ct))
}
/// Generates one typed convenience method on `Peer<RoleClient>` per invocation.
///
/// Request arms (`peer_req`) wrap the arguments in `ClientRequest::$Req`, send the request
/// with `send_request`, and unwrap the expected `ServerResult` variant. Any other variant
/// yields `ServiceError::UnexpectedResponse`. Notification arms (`peer_not`) wrap the
/// arguments in `ClientNotification::$Not` and send it with `send_notification`.
///
/// Arm forms (tried top to bottom, as `macro_rules!` always does):
/// - `peer_req name Req() => Resp`: no parameters.
/// - `peer_req name Req(Param) => Resp`: required `params: Param`.
/// - `peer_req name Req(Param)? => Resp`: optional `params: Option<Param>`.
/// - `peer_req name Req(Param)`: required params; expects `ServerResult::EmptyResult` and
///   returns `()`.
/// - `peer_not name Not(Param)`: notification with `params`.
/// - `peer_not name Not`: notification with no params.
macro_rules! method {
    // No-parameter request. Only `method` is set on the request struct (no `params` or
    // `extensions` field is written).
    (peer_req $method:ident $Req:ident() => $Resp: ident ) => {
        pub async fn $method(&self) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    // Request with required params and a typed result.
    (peer_req $method:ident $Req:ident($Param: ident) => $Resp: ident ) => {
        pub async fn $method(&self, params: $Param) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    // Request with optional params (`None` is passed through as-is) and a typed result.
    (peer_req $method:ident $Req:ident($Param: ident)? => $Resp: ident ) => {
        pub async fn $method(&self, params: Option<$Param>) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    // Request with required params whose only valid reply is `ServerResult::EmptyResult`.
    (peer_req $method:ident $Req:ident($Param: ident)) => {
        pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::EmptyResult(_) => Ok(()),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    // Notification carrying params; succeeds once `send_notification` succeeds.
    (peer_not $method:ident $Not:ident($Param: ident)) => {
        pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> {
            self.send_notification(ClientNotification::$Not($Not {
                method: Default::default(),
                params,
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };
    // Notification without params.
    (peer_not $method:ident $Not:ident) => {
        pub async fn $method(&self) -> Result<(), ServiceError> {
            self.send_notification(ClientNotification::$Not($Not {
                method: Default::default(),
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };
}
/// Typed request and notification methods for talking to the server, generated by `method!`.
impl Peer<RoleClient> {
    // Requests with required params and a typed result.
    method!(peer_req complete CompleteRequest(CompleteRequestParams) => CompleteResult);
    // Request that succeeds only on `ServerResult::EmptyResult`.
    method!(peer_req set_level SetLevelRequest(SetLevelRequestParams));
    method!(peer_req get_prompt GetPromptRequest(GetPromptRequestParams) => GetPromptResult);
    // Paginated listings: `params` is optional (`None` means no cursor/meta is sent).
    method!(peer_req list_prompts ListPromptsRequest(PaginatedRequestParams)? => ListPromptsResult);
    method!(peer_req list_resources ListResourcesRequest(PaginatedRequestParams)? => ListResourcesResult);
    method!(peer_req list_resource_templates ListResourceTemplatesRequest(PaginatedRequestParams)? => ListResourceTemplatesResult);
    method!(peer_req read_resource ReadResourceRequest(ReadResourceRequestParams) => ReadResourceResult);
    // Subscription management: both expect `ServerResult::EmptyResult`.
    method!(peer_req subscribe SubscribeRequest(SubscribeRequestParams) );
    method!(peer_req unsubscribe UnsubscribeRequest(UnsubscribeRequestParams));
    method!(peer_req call_tool CallToolRequest(CallToolRequestParams) => CallToolResult);
    method!(peer_req list_tools ListToolsRequest(PaginatedRequestParams)? => ListToolsResult);
    // Client-to-server notifications.
    method!(peer_not notify_cancelled CancelledNotification(CancelledNotificationParam));
    method!(peer_not notify_progress ProgressNotification(ProgressNotificationParam));
    method!(peer_not notify_initialized InitializedNotification);
    method!(peer_not notify_roots_list_changed RootsListChangedNotification);
}
impl Peer<RoleClient> {
pub async fn list_all_tools(&self) -> Result<Vec<crate::model::Tool>, ServiceError> {
let mut tools = Vec::new();
let mut cursor = None;
loop {
let result = self
.list_tools(Some(PaginatedRequestParams { meta: None, cursor }))
.await?;
tools.extend(result.tools);
cursor = result.next_cursor;
if cursor.is_none() {
break;
}
}
Ok(tools)
}
pub async fn list_all_prompts(&self) -> Result<Vec<crate::model::Prompt>, ServiceError> {
let mut prompts = Vec::new();
let mut cursor = None;
loop {
let result = self
.list_prompts(Some(PaginatedRequestParams { meta: None, cursor }))
.await?;
prompts.extend(result.prompts);
cursor = result.next_cursor;
if cursor.is_none() {
break;
}
}
Ok(prompts)
}
pub async fn list_all_resources(&self) -> Result<Vec<crate::model::Resource>, ServiceError> {
let mut resources = Vec::new();
let mut cursor = None;
loop {
let result = self
.list_resources(Some(PaginatedRequestParams { meta: None, cursor }))
.await?;
resources.extend(result.resources);
cursor = result.next_cursor;
if cursor.is_none() {
break;
}
}
Ok(resources)
}
pub async fn list_all_resource_templates(
&self,
) -> Result<Vec<crate::model::ResourceTemplate>, ServiceError> {
let mut resource_templates = Vec::new();
let mut cursor = None;
loop {
let result = self
.list_resource_templates(Some(PaginatedRequestParams { meta: None, cursor }))
.await?;
resource_templates.extend(result.resource_templates);
cursor = result.next_cursor;
if cursor.is_none() {
break;
}
}
Ok(resource_templates)
}
pub async fn complete_prompt_argument(
&self,
prompt_name: impl Into<String>,
argument_name: impl Into<String>,
current_value: impl Into<String>,
context: Option<CompletionContext>,
) -> Result<CompletionInfo, ServiceError> {
let request = CompleteRequestParams {
meta: None,
r#ref: Reference::for_prompt(prompt_name),
argument: ArgumentInfo {
name: argument_name.into(),
value: current_value.into(),
},
context,
};
let result = self.complete(request).await?;
Ok(result.completion)
}
pub async fn complete_resource_argument(
&self,
uri_template: impl Into<String>,
argument_name: impl Into<String>,
current_value: impl Into<String>,
context: Option<CompletionContext>,
) -> Result<CompletionInfo, ServiceError> {
let request = CompleteRequestParams {
meta: None,
r#ref: Reference::for_resource(uri_template),
argument: ArgumentInfo {
name: argument_name.into(),
value: current_value.into(),
},
context,
};
let result = self.complete(request).await?;
Ok(result.completion)
}
pub async fn complete_prompt_simple(
&self,
prompt_name: impl Into<String>,
argument_name: impl Into<String>,
current_value: impl Into<String>,
) -> Result<Vec<String>, ServiceError> {
let completion = self
.complete_prompt_argument(prompt_name, argument_name, current_value, None)
.await?;
Ok(completion.values)
}
pub async fn complete_resource_simple(
&self,
uri_template: impl Into<String>,
argument_name: impl Into<String>,
current_value: impl Into<String>,
) -> Result<Vec<String>, ServiceError> {
let completion = self
.complete_resource_argument(uri_template, argument_name, current_value, None)
.await?;
Ok(completion.values)
}
}