use futures::channel::oneshot;
use mcpkit_core::capability::{
is_version_supported, ClientCapabilities, ClientInfo, InitializeRequest, InitializeResult,
ServerCapabilities, ServerInfo, PROTOCOL_VERSION, SUPPORTED_PROTOCOL_VERSIONS,
};
use mcpkit_core::error::{
HandshakeDetails, JsonRpcError, McpError, TransportContext, TransportDetails,
TransportErrorKind,
};
use mcpkit_core::protocol::{Message, Notification, Request, RequestId, Response};
use mcpkit_core::protocol_version::ProtocolVersion;
use mcpkit_core::types::{
CallToolRequest, CallToolResult, CancelTaskRequest, CompleteRequest, CompleteResult,
CompletionArgument, CompletionRef, CreateMessageRequest, ElicitRequest, GetPromptRequest,
GetPromptResult, GetTaskRequest, ListPromptsResult, ListResourceTemplatesResult,
ListResourcesResult, ListTasksRequest, ListTasksResult, ListToolsResult, Prompt,
ReadResourceRequest, ReadResourceResult, Resource, ResourceContents, ResourceTemplate, Task,
TaskStatus, TaskSummary, Tool,
};
use mcpkit_transport::Transport;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use tracing::{debug, error, info, trace, warn};
use async_lock::RwLock;
#[cfg(feature = "tokio-runtime")]
use tokio::sync::mpsc;
use crate::handler::ClientHandler;
/// An initialized MCP client session bound to a transport `T`.
///
/// Server-initiated requests and notifications are dispatched to the handler `H`
/// (`crate::handler::NoOpHandler` by default) by a background message-router task.
pub struct Client<T: Transport, H: ClientHandler = crate::handler::NoOpHandler> {
/// Shared transport; the background router task holds another reference.
transport: Arc<T>,
/// Server identity reported in the `initialize` result.
server_info: ServerInfo,
/// Capabilities the server advertised in the `initialize` result.
server_caps: ServerCapabilities,
/// Negotiated protocol version (`ProtocolVersion::LATEST` if the server's string
/// could not be parsed).
protocol_version: ProtocolVersion,
/// Identity this client sent during `initialize`.
client_info: ClientInfo,
/// Capabilities this client advertised during `initialize`.
client_caps: ClientCapabilities,
/// Source of outgoing request ids; starts at 1 because id 0 is used by `initialize`.
next_id: AtomicU64,
/// Waiters for in-flight requests keyed by request id; the router completes them.
pending: Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
/// Optional usage instructions returned by the server.
instructions: Option<String>,
/// Callbacks for server requests, notifications and connection events.
handler: Arc<H>,
/// Queue of outgoing messages drained by the router task.
/// NOTE(review): `mpsc` is only imported under the `tokio-runtime` feature while this
/// field is unconditional — confirm the crate is never built without that feature.
outgoing_tx: mpsc::Sender<Message>,
/// `true` while the connection is considered open.
running: Arc<AtomicBool>,
/// Handle of the router task; stored but never awaited (dropping a tokio
/// `JoinHandle` detaches the task rather than cancelling it).
_background_handle: Option<tokio::task::JoinHandle<()>>,
}
impl<T: Transport + 'static> Client<T, crate::handler::NoOpHandler> {
    /// Builds a client from a completed handshake using the default
    /// `crate::handler::NoOpHandler`.
    ///
    /// This is shorthand for `Client::with_handler(.., NoOpHandler)`.
    pub(crate) fn new(
        transport: T,
        init_result: InitializeResult,
        client_info: ClientInfo,
        client_caps: ClientCapabilities,
    ) -> Self {
        let default_handler = crate::handler::NoOpHandler;
        Self::with_handler(transport, init_result, client_info, client_caps, default_handler)
    }
}
impl<T: Transport + 'static, H: ClientHandler + 'static> Client<T, H> {
/// Wraps an already-initialized transport in a client and starts its background work.
///
/// Spawns two tokio tasks: the message router (see `spawn_message_router`) and a
/// one-shot task that runs `handler.on_connected()`. It must therefore be called from
/// within a tokio runtime.
///
/// If the server's protocol version string cannot be parsed, a warning is logged and
/// `ProtocolVersion::LATEST` is used instead.
pub(crate) fn with_handler(
    transport: T,
    init_result: InitializeResult,
    client_info: ClientInfo,
    client_caps: ClientCapabilities,
    handler: H,
) -> Self {
    let shared_transport = Arc::new(transport);
    let waiters = Arc::new(RwLock::new(HashMap::new()));
    let shared_handler = Arc::new(handler);
    let alive = Arc::new(AtomicBool::new(true));

    let protocol_version = match init_result.protocol_version.parse::<ProtocolVersion>() {
        Ok(parsed) => parsed,
        Err(_) => {
            warn!(
                server_version = %init_result.protocol_version,
                fallback_version = %ProtocolVersion::LATEST,
                "Server returned unknown protocol version, falling back to latest supported"
            );
            ProtocolVersion::LATEST
        }
    };

    // Bounded queue feeding the router; `request` waits when it is full.
    let (outgoing_tx, outgoing_rx) = mpsc::channel::<Message>(256);
    let router = Self::spawn_message_router(
        Arc::clone(&shared_transport),
        Arc::clone(&waiters),
        Arc::clone(&shared_handler),
        Arc::clone(&alive),
        outgoing_rx,
    );

    let connected_handler = Arc::clone(&shared_handler);
    tokio::spawn(async move {
        connected_handler.on_connected().await;
    });

    Self {
        transport: shared_transport,
        server_info: init_result.server_info,
        server_caps: init_result.capabilities,
        protocol_version,
        client_info,
        client_caps,
        next_id: AtomicU64::new(1),
        pending: waiters,
        instructions: init_result.instructions,
        handler: shared_handler,
        outgoing_tx,
        running: alive,
        _background_handle: Some(router),
    }
}
/// Spawns the background task that owns all traffic on `transport`.
///
/// The task loops, concurrently forwarding messages queued on `outgoing_rx` to the
/// transport and reading inbound messages, which are dispatched through
/// `handle_incoming_message`. Inbound messages are handled inline on this task, so a
/// slow handler callback delays all other traffic.
///
/// If an outgoing request cannot be written, its pending waiter is removed so the
/// caller's `request` fails instead of waiting forever. When the server closes the
/// connection or the transport errors, `teardown_connection` marks the client
/// disconnected, fails every in-flight request and fires `on_disconnected` at most once.
/// The loop also exits on its next iteration once `running` has been cleared (by `close`).
fn spawn_message_router(
    transport: Arc<T>,
    pending: Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
    handler: Arc<H>,
    running: Arc<AtomicBool>,
    mut outgoing_rx: mpsc::Receiver<Message>,
) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        debug!("Starting client message router");
        loop {
            if !running.load(Ordering::SeqCst) {
                debug!("Message router stopping (client closed)");
                break;
            }
            tokio::select! {
                Some(msg) = outgoing_rx.recv() => {
                    // Remember the request id (if any) before `msg` is moved into `send`,
                    // so a failed write can release the waiting caller.
                    let request_id = if let Message::Request(request) = &msg {
                        Some(request.id.clone())
                    } else {
                        None
                    };
                    if let Err(e) = transport.send(msg).await {
                        error!(?e, "Failed to send message");
                        if let Some(id) = request_id {
                            // Dropping the sender makes the caller's receiver resolve
                            // with an error instead of hanging.
                            pending.write().await.remove(&id);
                        }
                    }
                }
                result = transport.recv() => {
                    match result {
                        Ok(Some(message)) => {
                            Self::handle_incoming_message(
                                message,
                                &pending,
                                &handler,
                                &transport,
                            ).await;
                        }
                        Ok(None) => {
                            info!("Connection closed by server");
                            Self::teardown_connection(&pending, &handler, &running).await;
                            break;
                        }
                        Err(e) => {
                            error!(?e, "Transport error in message router");
                            Self::teardown_connection(&pending, &handler, &running).await;
                            break;
                        }
                    }
                }
            }
        }
        debug!("Message router stopped");
    })
}

/// Marks the connection closed and releases everything still waiting on it.
///
/// `running` is cleared before `pending` is drained. Clearing the map drops every
/// waiter's sender, so each in-flight `request` resolves with a "Response channel
/// closed" error. `handler.on_disconnected()` runs only if this call is the one that
/// flipped `running` from `true` to `false`, so it is skipped when `close` already did.
async fn teardown_connection(
    pending: &Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
    handler: &Arc<H>,
    running: &AtomicBool,
) {
    let was_running = running.swap(false, Ordering::SeqCst);
    pending.write().await.clear();
    if was_running {
        handler.on_disconnected().await;
    }
}
/// Dispatches one inbound message according to its kind.
///
/// Server requests are answered through the handler, notifications are forwarded to
/// the handler, and responses complete the matching entry in `pending`.
async fn handle_incoming_message(
    message: Message,
    pending: &Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
    handler: &Arc<H>,
    transport: &Arc<T>,
) {
    match message {
        Message::Request(incoming) => {
            Self::handle_server_request(incoming, handler, transport).await
        }
        Message::Notification(notice) => Self::handle_notification(notice, handler).await,
        Message::Response(reply) => Self::route_response(reply, pending).await,
    }
}
/// Delivers `response` to the request that is waiting for its id.
///
/// Logs a warning if no request with that id is pending, or if the waiting caller
/// has already dropped its receiver.
async fn route_response(
    response: Response,
    pending: &Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
) {
    // The write guard is a temporary of this statement, so the lock is released
    // before the response is handed over.
    let waiter = pending.write().await.remove(&response.id);
    match waiter {
        None => warn!(?response.id, "Received response for unknown request"),
        Some(reply_tx) => {
            trace!(?response.id, "Routing response to pending request");
            if let Err(_undelivered) = reply_tx.send(response) {
                warn!("Pending request receiver dropped");
            }
        }
    }
}
async fn handle_server_request(request: Request, handler: &Arc<H>, transport: &Arc<T>) {
trace!(method = %request.method, "Handling server request");
let response = match request.method.as_ref() {
"sampling/createMessage" => Self::handle_sampling_request(&request, handler).await,
"elicitation/elicit" => Self::handle_elicitation_request(&request, handler).await,
"roots/list" => Self::handle_roots_request(&request, handler).await,
"ping" => {
Response::success(request.id.clone(), serde_json::json!({}))
}
_ => {
warn!(method = %request.method, "Unknown server request method");
Response::error(
request.id.clone(),
JsonRpcError::method_not_found(format!("Unknown method: {}", request.method)),
)
}
};
if let Err(e) = transport.send(Message::Response(response)).await {
error!(?e, "Failed to send response to server request");
}
}
async fn handle_sampling_request(request: &Request, handler: &Arc<H>) -> Response {
let params = match &request.params {
Some(p) => match serde_json::from_value::<CreateMessageRequest>(p.clone()) {
Ok(req) => req,
Err(e) => {
return Response::error(
request.id.clone(),
JsonRpcError::invalid_params(format!("Invalid params: {e}")),
);
}
},
None => {
return Response::error(
request.id.clone(),
JsonRpcError::invalid_params("Missing params for sampling/createMessage"),
);
}
};
match handler.create_message(params).await {
Ok(result) => match serde_json::to_value(result) {
Ok(value) => Response::success(request.id.clone(), value),
Err(e) => Response::error(
request.id.clone(),
JsonRpcError::internal_error(format!("Serialization error: {e}")),
),
},
Err(e) => Response::error(
request.id.clone(),
JsonRpcError::internal_error(e.to_string()),
),
}
}
async fn handle_elicitation_request(request: &Request, handler: &Arc<H>) -> Response {
let params = match &request.params {
Some(p) => match serde_json::from_value::<ElicitRequest>(p.clone()) {
Ok(req) => req,
Err(e) => {
return Response::error(
request.id.clone(),
JsonRpcError::invalid_params(format!("Invalid params: {e}")),
);
}
},
None => {
return Response::error(
request.id.clone(),
JsonRpcError::invalid_params("Missing params for elicitation/elicit"),
);
}
};
match handler.elicit(params).await {
Ok(result) => match serde_json::to_value(result) {
Ok(value) => Response::success(request.id.clone(), value),
Err(e) => Response::error(
request.id.clone(),
JsonRpcError::internal_error(format!("Serialization error: {e}")),
),
},
Err(e) => Response::error(
request.id.clone(),
JsonRpcError::internal_error(e.to_string()),
),
}
}
async fn handle_roots_request(request: &Request, handler: &Arc<H>) -> Response {
match handler.list_roots().await {
Ok(roots) => {
let roots_json: Vec<serde_json::Value> = roots
.into_iter()
.map(|r| {
serde_json::json!({
"uri": r.uri,
"name": r.name
})
})
.collect();
Response::success(
request.id.clone(),
serde_json::json!({ "roots": roots_json }),
)
}
Err(e) => Response::error(
request.id.clone(),
JsonRpcError::internal_error(e.to_string()),
),
}
}
/// Dispatches a server notification to the matching handler callback.
///
/// Recognized methods:
/// * `notifications/cancelled`: only logged at debug level.
/// * `notifications/progress`: forwarded to `on_task_progress` when `progressToken`
///   is a string and `progress` decodes as `TaskProgress`.
/// * `notifications/resources/updated`: forwarded to `on_resource_updated` when
///   `uri` is a string.
/// * `.../list_changed` for resources, tools and prompts: forwarded to the matching
///   callback.
///
/// Any other method is traced and ignored. Notifications with missing or malformed
/// params are dropped silently.
async fn handle_notification(notification: Notification, handler: &Arc<H>) {
    trace!(method = %notification.method, "Received server notification");
    match notification.method.as_ref() {
        "notifications/cancelled" => {
            // Borrow the params: this arm only logs and does not consume them.
            if let Some(request_id) = notification
                .params
                .as_ref()
                .and_then(|params| params.get("requestId"))
            {
                debug!(?request_id, "Server cancelled request");
            }
        }
        "notifications/progress" => {
            if let Some(params) = notification.params {
                if let (Some(task_id), Some(progress)) = (
                    params.get("progressToken").and_then(|v| v.as_str()),
                    params.get("progress"),
                ) {
                    if let Ok(progress) = serde_json::from_value::<
                        mcpkit_core::types::TaskProgress,
                    >(progress.clone())
                    {
                        debug!(task_id = %task_id, "Task progress update");
                        handler.on_task_progress(task_id.into(), progress).await;
                    }
                }
            }
        }
        "notifications/resources/updated" => {
            if let Some(params) = notification.params {
                if let Some(uri) = params.get("uri").and_then(|v| v.as_str()) {
                    debug!(uri = %uri, "Resource updated");
                    handler.on_resource_updated(uri.to_string()).await;
                }
            }
        }
        "notifications/resources/list_changed" => {
            debug!("Resources list changed");
            handler.on_resources_list_changed().await;
        }
        "notifications/tools/list_changed" => {
            debug!("Tools list changed");
            handler.on_tools_list_changed().await;
        }
        "notifications/prompts/list_changed" => {
            debug!("Prompts list changed");
            handler.on_prompts_list_changed().await;
        }
        _ => {
            trace!(method = %notification.method, "Unhandled notification");
        }
    }
}
/// Server identity from the `initialize` result.
pub const fn server_info(&self) -> &ServerInfo {
    &self.server_info
}

/// Capabilities advertised by the server during initialization.
pub const fn server_capabilities(&self) -> &ServerCapabilities {
    &self.server_caps
}

/// Protocol version in effect for this session.
pub fn protocol_version(&self) -> ProtocolVersion {
    self.protocol_version
}

/// Identity this client announced during initialization.
pub const fn client_info(&self) -> &ClientInfo {
    &self.client_info
}

/// Capabilities this client announced during initialization.
pub const fn client_capabilities(&self) -> &ClientCapabilities {
    &self.client_caps
}

/// Server-provided usage instructions, if the server sent any.
pub fn instructions(&self) -> Option<&str> {
    self.instructions.as_deref()
}

/// Whether the server advertised the `tools` capability.
pub const fn has_tools(&self) -> bool {
    let caps = &self.server_caps;
    caps.has_tools()
}

/// Whether the server advertised the `resources` capability.
pub const fn has_resources(&self) -> bool {
    let caps = &self.server_caps;
    caps.has_resources()
}

/// Whether the server advertised the `prompts` capability.
pub const fn has_prompts(&self) -> bool {
    let caps = &self.server_caps;
    caps.has_prompts()
}

/// Whether the server advertised the `tasks` capability.
pub const fn has_tasks(&self) -> bool {
    let caps = &self.server_caps;
    caps.has_tasks()
}

/// Whether the server advertised the `completions` capability.
pub const fn has_completions(&self) -> bool {
    let caps = &self.server_caps;
    caps.has_completions()
}

/// `false` once the connection has been closed by either side.
pub fn is_connected(&self) -> bool {
    self.running.load(Ordering::SeqCst)
}
/// Lists the server's tools.
///
/// Only the first page is fetched: no cursor is sent, and only the `tools` field of
/// the result is returned.
///
/// # Errors
///
/// `McpError::CapabilityNotSupported` if the server lacks `tools`, or any error
/// from the underlying request.
pub async fn list_tools(&self) -> Result<Vec<Tool>, McpError> {
    self.list_tools_paginated(None)
        .await
        .map(|page| page.tools)
}

/// Fetches one page of tools, starting at `cursor` when one is given.
pub async fn list_tools_paginated(
    &self,
    cursor: Option<&str>,
) -> Result<ListToolsResult, McpError> {
    self.ensure_capability("tools", self.has_tools())?;
    let params = match cursor {
        Some(token) => Some(serde_json::json!({ "cursor": token })),
        None => None,
    };
    self.request("tools/list", params).await
}

/// Invokes the tool `name` with `arguments`; the arguments are always sent.
pub async fn call_tool(
    &self,
    name: impl Into<String>,
    arguments: serde_json::Value,
) -> Result<CallToolResult, McpError> {
    self.ensure_capability("tools", self.has_tools())?;
    let payload = serde_json::to_value(CallToolRequest {
        name: name.into(),
        arguments: Some(arguments),
    })?;
    self.request("tools/call", Some(payload)).await
}
/// Lists the server's resources (first page only; no cursor is sent).
pub async fn list_resources(&self) -> Result<Vec<Resource>, McpError> {
    self.list_resources_paginated(None)
        .await
        .map(|page| page.resources)
}

/// Fetches one page of resources, starting at `cursor` when one is given.
pub async fn list_resources_paginated(
    &self,
    cursor: Option<&str>,
) -> Result<ListResourcesResult, McpError> {
    self.ensure_capability("resources", self.has_resources())?;
    let params = match cursor {
        Some(token) => Some(serde_json::json!({ "cursor": token })),
        None => None,
    };
    self.request("resources/list", params).await
}

/// Lists the server's resource templates (no cursor is sent).
pub async fn list_resource_templates(&self) -> Result<Vec<ResourceTemplate>, McpError> {
    self.ensure_capability("resources", self.has_resources())?;
    self.request::<ListResourceTemplatesResult>("resources/templates/list", None)
        .await
        .map(|listing| listing.resource_templates)
}

/// Reads the resource at `uri` and returns its contents.
pub async fn read_resource(
    &self,
    uri: impl Into<String>,
) -> Result<Vec<ResourceContents>, McpError> {
    self.ensure_capability("resources", self.has_resources())?;
    let payload = serde_json::to_value(ReadResourceRequest { uri: uri.into() })?;
    let read: ReadResourceResult = self.request("resources/read", Some(payload)).await?;
    Ok(read.contents)
}
/// Lists the server's prompts (first page only; no cursor is sent).
pub async fn list_prompts(&self) -> Result<Vec<Prompt>, McpError> {
    self.list_prompts_paginated(None)
        .await
        .map(|page| page.prompts)
}

/// Fetches one page of prompts, starting at `cursor` when one is given.
pub async fn list_prompts_paginated(
    &self,
    cursor: Option<&str>,
) -> Result<ListPromptsResult, McpError> {
    self.ensure_capability("prompts", self.has_prompts())?;
    let params = match cursor {
        Some(token) => Some(serde_json::json!({ "cursor": token })),
        None => None,
    };
    self.request("prompts/list", params).await
}

/// Renders the prompt `name`, optionally with template `arguments`.
pub async fn get_prompt(
    &self,
    name: impl Into<String>,
    arguments: Option<serde_json::Map<String, serde_json::Value>>,
) -> Result<GetPromptResult, McpError> {
    self.ensure_capability("prompts", self.has_prompts())?;
    let payload = serde_json::to_value(GetPromptRequest {
        name: name.into(),
        arguments,
    })?;
    self.request("prompts/get", Some(payload)).await
}
/// Lists the server's tasks with no filter and no params sent.
pub async fn list_tasks(&self) -> Result<Vec<TaskSummary>, McpError> {
    self.ensure_capability("tasks", self.has_tasks())?;
    self.request::<ListTasksResult>("tasks/list", None)
        .await
        .map(|listing| listing.tasks)
}

/// Lists tasks, optionally restricted to `status` and starting at `cursor`.
pub async fn list_tasks_filtered(
    &self,
    status: Option<TaskStatus>,
    cursor: Option<&str>,
) -> Result<ListTasksResult, McpError> {
    self.ensure_capability("tasks", self.has_tasks())?;
    let filter = ListTasksRequest {
        status,
        cursor: cursor.map(str::to_owned),
    };
    let payload = serde_json::to_value(filter)?;
    self.request("tasks/list", Some(payload)).await
}

/// Fetches the task identified by `id`.
pub async fn get_task(&self, id: impl Into<String>) -> Result<Task, McpError> {
    self.ensure_capability("tasks", self.has_tasks())?;
    let payload = serde_json::to_value(GetTaskRequest {
        id: id.into().into(),
    })?;
    self.request("tasks/get", Some(payload)).await
}

/// Asks the server to cancel the task `id`; the reply body is ignored.
pub async fn cancel_task(&self, id: impl Into<String>) -> Result<(), McpError> {
    self.ensure_capability("tasks", self.has_tasks())?;
    let payload = serde_json::to_value(CancelTaskRequest {
        id: id.into().into(),
    })?;
    self.request::<serde_json::Value>("tasks/cancel", Some(payload))
        .await
        .map(|_| ())
}
pub async fn complete_prompt_argument(
&self,
prompt_name: impl Into<String>,
argument_name: impl Into<String>,
current_value: impl Into<String>,
) -> Result<CompleteResult, McpError> {
self.ensure_capability("completions", self.has_completions())?;
let request = CompleteRequest {
ref_: CompletionRef::prompt(prompt_name),
argument: CompletionArgument {
name: argument_name.into(),
value: current_value.into(),
},
};
self.request("completion/complete", Some(serde_json::to_value(request)?))
.await
}
pub async fn complete_resource_argument(
&self,
resource_uri: impl Into<String>,
argument_name: impl Into<String>,
current_value: impl Into<String>,
) -> Result<CompleteResult, McpError> {
self.ensure_capability("completions", self.has_completions())?;
let request = CompleteRequest {
ref_: CompletionRef::resource(resource_uri),
argument: CompletionArgument {
name: argument_name.into(),
value: current_value.into(),
},
};
self.request("completion/complete", Some(serde_json::to_value(request)?))
.await
}
/// Subscribes to update notifications for the resource at `uri`.
///
/// # Errors
///
/// `McpError::CapabilityNotSupported` if the server lacks `resources` or
/// resource subscriptions, or any error from the underlying request.
pub async fn subscribe_resource(&self, uri: impl Into<String>) -> Result<(), McpError> {
    self.ensure_resource_subscribe()?;
    let params = serde_json::json!({ "uri": uri.into() });
    self.request::<serde_json::Value>("resources/subscribe", Some(params))
        .await
        .map(|_| ())
}

/// Cancels a previous subscription to the resource at `uri`.
///
/// Capability checks and errors match `subscribe_resource`.
pub async fn unsubscribe_resource(&self, uri: impl Into<String>) -> Result<(), McpError> {
    self.ensure_resource_subscribe()?;
    let params = serde_json::json!({ "uri": uri.into() });
    self.request::<serde_json::Value>("resources/unsubscribe", Some(params))
        .await
        .map(|_| ())
}

/// Fails unless the server supports both `resources` and resource subscriptions.
fn ensure_resource_subscribe(&self) -> Result<(), McpError> {
    self.ensure_capability("resources", self.has_resources())?;
    if self.server_caps.has_resource_subscribe() {
        Ok(())
    } else {
        Err(McpError::CapabilityNotSupported {
            capability: "resources.subscribe".to_string(),
            available: self.available_capabilities().into_boxed_slice(),
        })
    }
}
/// Sends a `ping` request and waits for the reply; the reply body is ignored.
pub async fn ping(&self) -> Result<(), McpError> {
    self.request::<serde_json::Value>("ping", None)
        .await
        .map(|_| ())
}
/// Closes the connection and consumes the client.
///
/// Marks the client disconnected and then closes the transport. `on_disconnected` is
/// called only if this call is the one that flipped `running` from `true` to `false`,
/// so it does not fire again when the connection had already been reported closed
/// (for example by the message router after the server hung up).
///
/// # Errors
///
/// Returns `McpError::Transport` with kind `ConnectionClosed` if closing the
/// transport fails.
pub async fn close(self) -> Result<(), McpError> {
    debug!("Closing client connection");
    if self.running.swap(false, Ordering::SeqCst) {
        self.handler.on_disconnected().await;
    }
    self.transport.close().await.map_err(|e| {
        McpError::Transport(Box::new(TransportDetails {
            kind: TransportErrorKind::ConnectionClosed,
            message: e.to_string(),
            context: TransportContext::default(),
            source: None,
        }))
    })
}
fn next_request_id(&self) -> RequestId {
RequestId::Number(self.next_id.fetch_add(1, Ordering::SeqCst))
}
async fn request<R: serde::de::DeserializeOwned>(
&self,
method: &str,
params: Option<serde_json::Value>,
) -> Result<R, McpError> {
if !self.is_connected() {
return Err(McpError::Transport(Box::new(TransportDetails {
kind: TransportErrorKind::ConnectionClosed,
message: "Client is not connected".to_string(),
context: TransportContext::default(),
source: None,
})));
}
let id = self.next_request_id();
let request = if let Some(params) = params {
Request::with_params(method.to_string(), id.clone(), params)
} else {
Request::new(method.to_string(), id.clone())
};
trace!(?id, method, "Sending request");
let (tx, rx) = oneshot::channel();
{
let mut pending = self.pending.write().await;
pending.insert(id.clone(), tx);
}
self.outgoing_tx
.send(Message::Request(request))
.await
.map_err(|_| {
McpError::Transport(Box::new(TransportDetails {
kind: TransportErrorKind::WriteFailed,
message: "Failed to send request (channel closed)".to_string(),
context: TransportContext::default(),
source: None,
}))
})?;
let response = rx.await.map_err(|_| {
McpError::Transport(Box::new(TransportDetails {
kind: TransportErrorKind::ConnectionClosed,
message: "Response channel closed (server may have disconnected)".to_string(),
context: TransportContext::default(),
source: None,
}))
})?;
if let Some(error) = response.error {
return Err(McpError::Internal {
message: error.message,
source: None,
});
}
let result = response.result.ok_or_else(|| McpError::Internal {
message: "Response contained neither result nor error".to_string(),
source: None,
})?;
serde_json::from_value(result).map_err(McpError::from)
}
/// Returns `Ok(())` when `supported` is true. Otherwise returns
/// `McpError::CapabilityNotSupported` naming `name` and listing the capabilities the
/// server did advertise.
fn ensure_capability(&self, name: &str, supported: bool) -> Result<(), McpError> {
    if !supported {
        return Err(McpError::CapabilityNotSupported {
            capability: name.to_string(),
            available: self.available_capabilities().into_boxed_slice(),
        });
    }
    Ok(())
}

/// Names of the top-level server capabilities that are present, in a fixed order:
/// tools, resources, prompts, tasks, completions.
fn available_capabilities(&self) -> Vec<String> {
    let table = [
        ("tools", self.has_tools()),
        ("resources", self.has_resources()),
        ("prompts", self.has_prompts()),
        ("tasks", self.has_tasks()),
        ("completions", self.has_completions()),
    ];
    table
        .iter()
        .filter(|(_, present)| *present)
        .map(|(label, _)| label.to_string())
        .collect()
}
}
pub(crate) async fn initialize<T: Transport>(
transport: &T,
client_info: &ClientInfo,
capabilities: &ClientCapabilities,
) -> Result<InitializeResult, McpError> {
debug!(
protocol_version = %PROTOCOL_VERSION,
supported_versions = ?SUPPORTED_PROTOCOL_VERSIONS,
"Initializing MCP connection"
);
let request = InitializeRequest::new(client_info.clone(), capabilities.clone());
let init_request = Request::with_params(
"initialize".to_string(),
RequestId::Number(0),
serde_json::to_value(&request)?,
);
transport
.send(Message::Request(init_request))
.await
.map_err(|e| {
McpError::Transport(Box::new(TransportDetails {
kind: TransportErrorKind::WriteFailed,
message: format!("Failed to send initialize: {e}"),
context: TransportContext::default(),
source: None,
}))
})?;
let response = loop {
match transport.recv().await {
Ok(Some(Message::Response(r))) if r.id == RequestId::Number(0) => break r,
Ok(Some(_)) => {} Ok(None) => {
return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
message: "Connection closed during initialization".to_string(),
client_version: Some(PROTOCOL_VERSION.to_string()),
server_version: None,
source: None,
})));
}
Err(e) => {
return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
message: format!("Transport error during initialization: {e}"),
client_version: Some(PROTOCOL_VERSION.to_string()),
server_version: None,
source: None,
})));
}
}
};
if let Some(error) = response.error {
return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
message: error.message,
client_version: Some(PROTOCOL_VERSION.to_string()),
server_version: None,
source: None,
})));
}
let result: InitializeResult = response
.result
.map(serde_json::from_value)
.transpose()?
.ok_or_else(|| {
McpError::HandshakeFailed(Box::new(HandshakeDetails {
message: "Empty initialize result".to_string(),
client_version: Some(PROTOCOL_VERSION.to_string()),
server_version: None,
source: None,
}))
})?;
let server_version = &result.protocol_version;
if !is_version_supported(server_version) {
warn!(
server_version = %server_version,
supported = ?SUPPORTED_PROTOCOL_VERSIONS,
"Server returned unsupported protocol version"
);
return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
message: format!(
"Unsupported protocol version: server returned '{server_version}', but client only supports {SUPPORTED_PROTOCOL_VERSIONS:?}"
),
client_version: Some(PROTOCOL_VERSION.to_string()),
server_version: Some(server_version.clone()),
source: None,
})));
}
debug!(
server = %result.server_info.name,
server_version = %result.server_info.version,
protocol_version = %result.protocol_version,
"Received initialize result with compatible protocol version"
);
let notification = Notification::new("notifications/initialized");
transport
.send(Message::Notification(notification))
.await
.map_err(|e| {
McpError::Transport(Box::new(TransportDetails {
kind: TransportErrorKind::WriteFailed,
message: format!("Failed to send initialized: {e}"),
context: TransportContext::default(),
source: None,
}))
})?;
debug!("MCP initialization complete");
Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The request-id counter hands out consecutive values starting at 1.
    #[test]
    fn test_request_id_generation() {
        let counter = AtomicU64::new(1);
        let first = counter.fetch_add(1, Ordering::SeqCst);
        let second = counter.fetch_add(1, Ordering::SeqCst);
        assert_eq!((first, second), (1, 2));
    }
}