mod builder;
mod handler;
mod state;
pub use builder::ClientBuilder;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
use async_trait::async_trait;
use dashmap::DashMap;
use log::{debug, error};
use uuid::Uuid;
use crate::error::Error;
use crate::protocol::{
Implementation, RequestId, ProgressToken,
JSONRPCMessage, JSONRPCRequest, JSONRPCNotification,
resources::{Resource, ResourceTemplate, TextResourceContents, BlobResourceContents},
prompts::{Prompt, PromptMessage},
tools::{Tool, CallToolResult},
roots::{Root},
completion::CompleteResult,
logging::LoggingLevel,
};
use crate::transport::Transport;
use self::state::{ClientState, PendingRequest};
use self::handler::ClientMessageHandler;
/// Client-side capability flags configured through [`ClientOptions`].
///
/// All flags default to `false` and `experimental` defaults to empty.
#[derive(Clone, Debug, Default)]
pub struct ClientCapabilities {
    /// Whether the client supports the roots feature.
    pub roots: bool,
    /// Whether the client may send `notifications/roots/list_changed`
    /// (also requires `roots`).
    pub roots_list_changed: bool,
    /// Whether the client supports sampling.
    pub sampling: bool,
    /// Names of experimental, non-standard capabilities.
    pub experimental: Vec<String>,
}
/// Configuration supplied to [`Client::new`].
#[derive(Clone, Debug)]
pub struct ClientOptions {
    /// Name/version sent to the server as `clientInfo` during `initialize`.
    pub implementation: Implementation,
    /// Client-side capability flags.
    pub capabilities: ClientCapabilities,
    /// Forwarded to the message handler; NOTE(review): exact effect on
    /// roots-changed handling lives in `handler.rs` — confirm there.
    pub auto_acknowledge_roots_changed: bool,
    /// Time to wait for each response, in milliseconds; `0` waits forever.
    pub default_timeout_ms: u64,
}
impl Default for ClientOptions {
    /// Identifies the client as `mcpx-client` at this crate's version,
    /// enables no optional capabilities, auto-acknowledges roots changes,
    /// and uses a 30 second request timeout.
    fn default() -> Self {
        let implementation = Implementation::new("mcpx-client", env!("CARGO_PKG_VERSION"));
        Self {
            implementation,
            capabilities: Default::default(),
            auto_acknowledge_roots_changed: true,
            default_timeout_ms: 30_000,
        }
    }
}
/// Events delivered to the application through the receiver returned by
/// [`Client::new`].
#[derive(Clone, Debug)]
pub enum ClientEvent {
    /// The `initialize` handshake completed successfully.
    Connected {
        /// `serverInfo` from the initialize result.
        server_info: Implementation,
        /// `protocolVersion` from the initialize result.
        protocol_version: String,
        /// Parsed `capabilities` from the initialize result.
        capabilities: ServerCapabilities,
        /// Optional `instructions` text from the initialize result.
        instructions: Option<String>,
    },
    /// The connection ended.
    Disconnected {
        /// Human-readable cause.
        reason: String,
    },
    /// The server's resource list changed (presumably emitted by the
    /// message handler on the matching notification — confirm).
    ResourcesChanged,
    /// The server's prompt list changed (presumably handler-emitted).
    PromptsChanged,
    /// The server's tool list changed (presumably handler-emitted).
    ToolsChanged,
    /// The roots list changed (presumably handler-emitted).
    RootsChanged,
    /// A subscribed resource was updated.
    ResourceUpdated {
        /// URI of the updated resource.
        uri: String,
    },
    /// A log message sent by the server.
    LogMessage {
        /// Severity of the message.
        level: LoggingLevel,
        /// Optional logger name.
        logger: Option<String>,
        /// Arbitrary JSON payload.
        data: serde_json::Value,
    },
    /// A progress report for an in-flight request.
    Progress {
        /// Request the progress belongs to.
        request_id: RequestId,
        /// Progress token attached to the request.
        token: ProgressToken,
        /// Progress so far.
        progress: f64,
        /// Total amount of work, if known.
        total: Option<f64>,
        /// Optional human-readable status.
        message: Option<String>,
    },
    /// An error surfaced asynchronously.
    Error {
        /// The error that occurred.
        error: Error,
    },
}
/// Capabilities the server advertised in its `initialize` result.
///
/// A top-level flag (`logging`, `prompts`, …) is set when the matching key
/// is present; `*_list_changed` / `resources_subscribe` are set only when the
/// nested key is a JSON `true`.
#[derive(Clone, Debug, Default)]
pub struct ServerCapabilities {
    /// `logging` key present.
    pub logging: bool,
    /// `completions` key present.
    pub completions: bool,
    /// `prompts` key present.
    pub prompts: bool,
    /// `prompts.listChanged` is `true`.
    pub prompts_list_changed: bool,
    /// `resources` key present.
    pub resources: bool,
    /// `resources.listChanged` is `true`.
    pub resources_list_changed: bool,
    /// `resources.subscribe` is `true`.
    pub resources_subscribe: bool,
    /// `tools` key present.
    pub tools: bool,
    /// `tools.listChanged` is `true`.
    pub tools_list_changed: bool,
    /// Keys of the `experimental` object.
    pub experimental: Vec<String>,
}
/// An MCP client bound to one transport.
///
/// The `Arc`-wrapped fields are shared with the message handler created in
/// [`Client::new`] and with the receive loop spawned by [`Client::connect`].
pub struct Client {
    /// Random UUID string assigned at construction.
    id: String,
    /// Connection lifecycle state (initializing / initialized / disconnected).
    state: Arc<RwLock<ClientState>>,
    /// Requests awaiting a response, keyed by JSON-RPC id.
    pending_requests: Arc<DashMap<RequestId, PendingRequest>>,
    /// Sending half of the event channel returned by [`Client::new`].
    event_sender: mpsc::Sender<ClientEvent>,
    /// Underlying message transport.
    transport: Arc<Box<dyn Transport + Send + Sync>>,
    /// Options supplied at construction.
    options: ClientOptions,
    /// Server capabilities; `None` until `initialize` stores them.
    server_capabilities: Arc<RwLock<Option<ServerCapabilities>>>,
    /// Dispatcher for inbound messages.
    handler: Arc<ClientMessageHandler>,
}
/// Asynchronous callback interface for consuming [`ClientEvent`]s.
///
/// NOTE(review): nothing in this module invokes `on_event`; callers are
/// presumably expected to drain the event receiver and forward events.
#[async_trait]
pub trait EventListener: Send + Sync {
    /// Handles one client event.
    async fn on_event(&self, event: ClientEvent);
}
impl Client {
/// Creates a client over `transport`, returning it together with the
/// receiving half of its event channel (capacity 100).
///
/// No I/O is performed; call [`Client::connect`] to open the transport and
/// run the `initialize` handshake.
pub fn new(
    transport: Box<dyn Transport + Send + Sync>,
    options: ClientOptions,
) -> (Self, mpsc::Receiver<ClientEvent>) {
    // Shared state is built first so the handler and the client observe
    // the very same instances.
    let state = Arc::new(RwLock::new(ClientState::new()));
    let pending_requests = Arc::new(DashMap::new());
    let server_capabilities = Arc::new(RwLock::new(None));
    let (event_sender, event_receiver) = mpsc::channel(100);

    let handler = Arc::new(ClientMessageHandler::new(
        Arc::clone(&state),
        Arc::clone(&pending_requests),
        event_sender.clone(),
        Arc::clone(&server_capabilities),
        options.clone(),
    ));

    let client = Self {
        id: Uuid::new_v4().to_string(),
        state,
        pending_requests,
        event_sender,
        transport: Arc::new(transport),
        options,
        server_capabilities,
        handler,
    };
    (client, event_receiver)
}
/// Returns the UUID string assigned to this client in [`Client::new`].
pub fn id(&self) -> &str {
    self.id.as_str()
}
pub async fn connect(&self) -> Result<(), Error> {
self.transport.connect().await?;
let transport = self.transport.clone();
let handler = self.handler.clone();
let state = self.state.clone();
tokio::spawn(async move {
debug!("Starting message processing loop");
while let Some(msg) = transport.receive().await {
match msg {
Ok(message) => {
if let Err(e) = handler.handle_message(message).await {
error!("Error handling message: {}", e);
}
}
Err(e) => {
error!("Error receiving message: {}", e);
break;
}
}
}
debug!("Message processing loop ended");
let mut state = state.write().await;
state.set_disconnected();
});
self.initialize().await?;
Ok(())
}
/// Closes the underlying transport.
///
/// Client state is not changed here; presumably the receive loop started by
/// [`Client::connect`] observes the closed transport and marks the client
/// disconnected — NOTE(review): confirm against the transport implementation.
pub async fn disconnect(&self) -> Result<(), Error> {
    match self.transport.disconnect().await {
        Ok(_) => Ok(()),
        Err(err) => Err(err.into()),
    }
}
async fn initialize(&self) -> Result<(), Error> {
let mut state = self.state.write().await;
if state.is_initializing() || state.is_initialized() {
return Ok(());
}
state.set_initializing();
drop(state);
let request = JSONRPCRequest {
jsonrpc: "2.0".to_string(),
id: "init".into(),
method: "initialize".to_string(),
params: Some(serde_json::json!({
"protocolVersion": crate::protocol::LATEST_PROTOCOL_VERSION,
"clientInfo": self.options.implementation,
"capabilities": {
"sampling": {},
"roots": {
"listChanged": self.options.capabilities.roots_list_changed,
}
}
})),
};
let response = self.send_request(request).await?;
let result = match response {
JSONRPCMessage::Response(resp) => resp.result,
JSONRPCMessage::Error(err) => {
return Err(Error::ServerError(
err.error.code,
err.error.message,
err.error.data,
));
}
_ => return Err(Error::ProtocolError("Unexpected response type".to_string())),
};
let server_info: Implementation = serde_json::from_value(
result["serverInfo"].clone(),
)
.map_err(|e| Error::ParseError(e.to_string()))?;
let protocol_version = result["protocolVersion"]
.as_str()
.ok_or_else(|| Error::ParseError("Missing protocolVersion".to_string()))?
.to_string();
let instructions = result
.get("instructions")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let capabilities = self.parse_server_capabilities(&result["capabilities"]);
{
let mut server_caps = self.server_capabilities.write().await;
*server_caps = Some(capabilities.clone());
}
{
let mut state = self.state.write().await;
state.set_initialized();
}
let notification = JSONRPCNotification {
jsonrpc: "2.0".to_string(),
method: "notifications/initialized".to_string(),
params: None,
};
self.send_notification(notification).await?;
self.event_sender
.send(ClientEvent::Connected {
server_info,
protocol_version,
capabilities,
instructions,
})
.await
.map_err(|_| Error::InternalError("Failed to send event".to_string()))?;
Ok(())
}
fn parse_server_capabilities(&self, json: &serde_json::Value) -> ServerCapabilities {
let mut capabilities = ServerCapabilities::default();
if json.get("logging").is_some() {
capabilities.logging = true;
}
if json.get("completions").is_some() {
capabilities.completions = true;
}
if let Some(prompts) = json.get("prompts") {
capabilities.prompts = true;
if let Some(list_changed) = prompts.get("listChanged") {
capabilities.prompts_list_changed = list_changed.as_bool().unwrap_or(false);
}
}
if let Some(resources) = json.get("resources") {
capabilities.resources = true;
if let Some(list_changed) = resources.get("listChanged") {
capabilities.resources_list_changed = list_changed.as_bool().unwrap_or(false);
}
if let Some(subscribe) = resources.get("subscribe") {
capabilities.resources_subscribe = subscribe.as_bool().unwrap_or(false);
}
}
if let Some(tools) = json.get("tools") {
capabilities.tools = true;
if let Some(list_changed) = tools.get("listChanged") {
capabilities.tools_list_changed = list_changed.as_bool().unwrap_or(false);
}
}
if let Some(experimental) = json.get("experimental") {
if let Some(obj) = experimental.as_object() {
capabilities.experimental = obj.keys().map(|k| k.clone()).collect();
}
}
capabilities
}
pub async fn send_request(&self, request: JSONRPCRequest) -> Result<JSONRPCMessage, Error> {
let (sender, receiver) = tokio::sync::oneshot::channel();
let pending = PendingRequest {
sender,
method: request.method.clone(),
start_time: std::time::Instant::now(),
};
let request_id = request.id.clone();
self.pending_requests.insert(request_id.clone(), pending);
let message = JSONRPCMessage::Request(request);
self.transport.send(message).await?;
let timeout_ms = self.options.default_timeout_ms;
let response = if timeout_ms > 0 {
match tokio::time::timeout(
std::time::Duration::from_millis(timeout_ms),
receiver,
)
.await
{
Ok(r) => r.map_err(|_| Error::InternalError("Response channel closed".to_string()))?,
Err(_) => {
self.pending_requests.remove(&request_id);
return Err(Error::Timeout(format!(
"Request timed out after {} ms",
timeout_ms
)));
}
}
} else {
receiver.await.map_err(|_| Error::InternalError("Response channel closed".to_string()))?
};
Ok(response)
}
/// Writes `notification` to the transport as a JSON-RPC notification.
/// No response is awaited.
pub async fn send_notification(&self, notification: JSONRPCNotification) -> Result<(), Error> {
    self.transport
        .send(JSONRPCMessage::Notification(notification))
        .await
}
pub async fn list_resources(&self) -> Result<Vec<Resource>, Error> {
let caps = self.server_capabilities.read().await;
if let Some(caps) = &*caps {
if !caps.resources {
return Err(Error::UnsupportedFeature("Resources".to_string()));
}
} else {
return Err(Error::NotInitialized);
}
let request = JSONRPCRequest {
jsonrpc: "2.0".to_string(),
id: Uuid::new_v4().to_string().into(),
method: "resources/list".to_string(),
params: None,
};
let response = self.send_request(request).await?;
match response {
JSONRPCMessage::Response(resp) => {
let resources = resp.result["resources"]
.as_array()
.ok_or_else(|| Error::ParseError("Missing resources array".to_string()))?;
let mut result = Vec::with_capacity(resources.len());
for resource in resources {
let res: Resource = serde_json::from_value(resource.clone())
.map_err(|e| Error::ParseError(e.to_string()))?;
result.push(res);
}
Ok(result)
}
JSONRPCMessage::Error(err) => Err(Error::ServerError(
err.error.code,
err.error.message,
err.error.data,
)),
_ => Err(Error::ProtocolError("Unexpected response type".to_string())),
}
}
pub async fn list_resource_templates(&self) -> Result<Vec<ResourceTemplate>, Error> {
let caps = self.server_capabilities.read().await;
if let Some(caps) = &*caps {
if !caps.resources {
return Err(Error::UnsupportedFeature("Resources".to_string()));
}
} else {
return Err(Error::NotInitialized);
}
let request = JSONRPCRequest {
jsonrpc: "2.0".to_string(),
id: Uuid::new_v4().to_string().into(),
method: "resources/templates/list".to_string(),
params: None,
};
let response = self.send_request(request).await?;
match response {
JSONRPCMessage::Response(resp) => {
let templates = resp.result["resourceTemplates"]
.as_array()
.ok_or_else(|| Error::ParseError("Missing resourceTemplates array".to_string()))?;
let mut result = Vec::with_capacity(templates.len());
for template in templates {
let tmpl: ResourceTemplate = serde_json::from_value(template.clone())
.map_err(|e| Error::ParseError(e.to_string()))?;
result.push(tmpl);
}
Ok(result)
}
JSONRPCMessage::Error(err) => Err(Error::ServerError(
err.error.code,
err.error.message,
err.error.data,
)),
_ => Err(Error::ProtocolError("Unexpected response type".to_string())),
}
}
pub async fn read_resource(&self, uri: &str) -> Result<Vec<ResourceContent>, Error> {
let caps = self.server_capabilities.read().await;
if let Some(caps) = &*caps {
if !caps.resources {
return Err(Error::UnsupportedFeature("Resources".to_string()));
}
} else {
return Err(Error::NotInitialized);
}
let request = JSONRPCRequest {
jsonrpc: "2.0".to_string(),
id: Uuid::new_v4().to_string().into(),
method: "resources/read".to_string(),
params: Some(serde_json::json!({ "uri": uri })),
};
let response = self.send_request(request).await?;
match response {
JSONRPCMessage::Response(resp) => {
let contents = resp.result["contents"]
.as_array()
.ok_or_else(|| Error::ParseError("Missing contents array".to_string()))?;
let mut result = Vec::with_capacity(contents.len());
for content in contents {
if content.get("text").is_some() {
let text: TextResourceContents = serde_json::from_value(content.clone())
.map_err(|e| Error::ParseError(e.to_string()))?;
result.push(ResourceContent::Text(text));
} else if content.get("blob").is_some() {
let blob: BlobResourceContents = serde_json::from_value(content.clone())
.map_err(|e| Error::ParseError(e.to_string()))?;
result.push(ResourceContent::Blob(blob));
} else {
return Err(Error::ParseError("Unknown resource content type".to_string()));
}
}
Ok(result)
}
JSONRPCMessage::Error(err) => Err(Error::ServerError(
err.error.code,
err.error.message,
err.error.data,
)),
_ => Err(Error::ProtocolError("Unexpected response type".to_string())),
}
}
pub async fn subscribe_resource(&self, uri: &str) -> Result<(), Error> {
let caps = self.server_capabilities.read().await;
if let Some(caps) = &*caps {
if !caps.resources || !caps.resources_subscribe {
return Err(Error::UnsupportedFeature("Resource subscriptions".to_string()));
}
} else {
return Err(Error::NotInitialized);
}
let request = JSONRPCRequest {
jsonrpc: "2.0".to_string(),
id: Uuid::new_v4().to_string().into(),
method: "resources/subscribe".to_string(),
params: Some(serde_json::json!({ "uri": uri })),
};
let response = self.send_request(request).await?;
match response {
JSONRPCMessage::Response(_) => Ok(()),
JSONRPCMessage::Error(err) => Err(Error::ServerError(
err.error.code,
err.error.message,
err.error.data,
)),
_ => Err(Error::ProtocolError("Unexpected response type".to_string())),
}
}
pub async fn unsubscribe_resource(&self, uri: &str) -> Result<(), Error> {
let caps = self.server_capabilities.read().await;
if let Some(caps) = &*caps {
if !caps.resources || !caps.resources_subscribe {
return Err(Error::UnsupportedFeature("Resource subscriptions".to_string()));
}
} else {
return Err(Error::NotInitialized);
}
let request = JSONRPCRequest {
jsonrpc: "2.0".to_string(),
id: Uuid::new_v4().to_string().into(),
method: "resources/unsubscribe".to_string(),
params: Some(serde_json::json!({ "uri": uri })),
};
let response = self.send_request(request).await?;
match response {
JSONRPCMessage::Response(_) => Ok(()),
JSONRPCMessage::Error(err) => Err(Error::ServerError(
err.error.code,
err.error.message,
err.error.data,
)),
_ => Err(Error::ProtocolError("Unexpected response type".to_string())),
}
}
pub async fn list_prompts(&self) -> Result<Vec<Prompt>, Error> {
let caps = self.server_capabilities.read().await;
if let Some(caps) = &*caps {
if !caps.prompts {
return Err(Error::UnsupportedFeature("Prompts".to_string()));
}
} else {
return Err(Error::NotInitialized);
}
let request = JSONRPCRequest {
jsonrpc: "2.0".to_string(),
id: Uuid::new_v4().to_string().into(),
method: "prompts/list".to_string(),
params: None,
};
let response = self.send_request(request).await?;
match response {
JSONRPCMessage::Response(resp) => {
let prompts = resp.result["prompts"]
.as_array()
.ok_or_else(|| Error::ParseError("Missing prompts array".to_string()))?;
let mut result = Vec::with_capacity(prompts.len());
for prompt in prompts {
let p: Prompt = serde_json::from_value(prompt.clone())
.map_err(|e| Error::ParseError(e.to_string()))?;
result.push(p);
}
Ok(result)
}
JSONRPCMessage::Error(err) => Err(Error::ServerError(
err.error.code,
err.error.message,
err.error.data,
)),
_ => Err(Error::ProtocolError("Unexpected response type".to_string())),
}
}
pub async fn get_prompt(
&self,
name: &str,
arguments: Option<std::collections::HashMap<String, String>>,
) -> Result<Vec<PromptMessage>, Error> {
let caps = self.server_capabilities.read().await;
if let Some(caps) = &*caps {
if !caps.prompts {
return Err(Error::UnsupportedFeature("Prompts".to_string()));
}
} else {
return Err(Error::NotInitialized);
}
let mut params = serde_json::json!({ "name": name });
if let Some(args) = arguments {
params["arguments"] = serde_json::to_value(args)
.map_err(|e| Error::InternalError(e.to_string()))?;
}
let request = JSONRPCRequest {
jsonrpc: "2.0".to_string(),
id: Uuid::new_v4().to_string().into(),
method: "prompts/get".to_string(),
params: Some(params),
};
let response = self.send_request(request).await?;
match response {
JSONRPCMessage::Response(resp) => {
let messages = resp.result["messages"]
.as_array()
.ok_or_else(|| Error::ParseError("Missing messages array".to_string()))?;
let mut result = Vec::with_capacity(messages.len());
for message in messages {
let msg: PromptMessage = serde_json::from_value(message.clone())
.map_err(|e| Error::ParseError(e.to_string()))?;
result.push(msg);
}
Ok(result)
}
JSONRPCMessage::Error(err) => Err(Error::ServerError(
err.error.code,
err.error.message,
err.error.data,
)),
_ => Err(Error::ProtocolError("Unexpected response type".to_string())),
}
}
pub async fn list_tools(&self) -> Result<Vec<Tool>, Error> {
let caps = self.server_capabilities.read().await;
if let Some(caps) = &*caps {
if !caps.tools {
return Err(Error::UnsupportedFeature("Tools".to_string()));
}
} else {
return Err(Error::NotInitialized);
}
let request = JSONRPCRequest {
jsonrpc: "2.0".to_string(),
id: Uuid::new_v4().to_string().into(),
method: "tools/list".to_string(),
params: None,
};
let response = self.send_request(request).await?;
match response {
JSONRPCMessage::Response(resp) => {
let tools = resp.result["tools"]
.as_array()
.ok_or_else(|| Error::ParseError("Missing tools array".to_string()))?;
let mut result = Vec::with_capacity(tools.len());
for tool in tools {
let t: Tool = serde_json::from_value(tool.clone())
.map_err(|e| Error::ParseError(e.to_string()))?;
result.push(t);
}
Ok(result)
}
JSONRPCMessage::Error(err) => Err(Error::ServerError(
err.error.code,
err.error.message,
err.error.data,
)),
_ => Err(Error::ProtocolError("Unexpected response type".to_string())),
}
}
pub async fn call_tool(
&self,
name: &str,
arguments: Option<serde_json::Value>,
) -> Result<CallToolResult, Error> {
let caps = self.server_capabilities.read().await;
if let Some(caps) = &*caps {
if !caps.tools {
return Err(Error::UnsupportedFeature("Tools".to_string()));
}
} else {
return Err(Error::NotInitialized);
}
let mut params = serde_json::json!({ "name": name });
if let Some(args) = arguments {
params["arguments"] = args;
}
let request = JSONRPCRequest {
jsonrpc: "2.0".to_string(),
id: Uuid::new_v4().to_string().into(),
method: "tools/call".to_string(),
params: Some(params),
};
let response = self.send_request(request).await?;
match response {
JSONRPCMessage::Response(resp) => {
let result: CallToolResult = serde_json::from_value(resp.result)
.map_err(|e| Error::ParseError(e.to_string()))?;
Ok(result)
}
JSONRPCMessage::Error(err) => Err(Error::ServerError(
err.error.code,
err.error.message,
err.error.data,
)),
_ => Err(Error::ProtocolError("Unexpected response type".to_string())),
}
}
pub async fn set_logging_level(&self, level: LoggingLevel) -> Result<(), Error> {
let caps = self.server_capabilities.read().await;
if let Some(caps) = &*caps {
if !caps.logging {
return Err(Error::UnsupportedFeature("Logging".to_string()));
}
} else {
return Err(Error::NotInitialized);
}
let request = JSONRPCRequest {
jsonrpc: "2.0".to_string(),
id: Uuid::new_v4().to_string().into(),
method: "logging/setLevel".to_string(),
params: Some(serde_json::json!({ "level": level })),
};
let response = self.send_request(request).await?;
match response {
JSONRPCMessage::Response(_) => Ok(()),
JSONRPCMessage::Error(err) => Err(Error::ServerError(
err.error.code,
err.error.message,
err.error.data,
)),
_ => Err(Error::ProtocolError("Unexpected response type".to_string())),
}
}
pub async fn get_completions(
&self,
reference_type: CompletionReferenceType,
reference_name: &str,
argument_name: &str,
argument_value: &str,
) -> Result<CompleteResult, Error> {
let caps = self.server_capabilities.read().await;
if let Some(caps) = &*caps {
if !caps.completions {
return Err(Error::UnsupportedFeature("Completions".to_string()));
}
} else {
return Err(Error::NotInitialized);
}
let ref_obj = match reference_type {
CompletionReferenceType::Prompt => {
serde_json::json!({
"type": "ref/prompt",
"name": reference_name
})
}
CompletionReferenceType::Resource => {
serde_json::json!({
"type": "ref/resource",
"uri": reference_name
})
}
};
let request = JSONRPCRequest {
jsonrpc: "2.0".to_string(),
id: Uuid::new_v4().to_string().into(),
method: "completion/complete".to_string(),
params: Some(serde_json::json!({
"ref": ref_obj,
"argument": {
"name": argument_name,
"value": argument_value
}
})),
};
let response = self.send_request(request).await?;
match response {
JSONRPCMessage::Response(resp) => {
let result: CompleteResult = serde_json::from_value(resp.result)
.map_err(|e| Error::ParseError(e.to_string()))?;
Ok(result)
}
JSONRPCMessage::Error(err) => Err(Error::ServerError(
err.error.code,
err.error.message,
err.error.data,
)),
_ => Err(Error::ProtocolError("Unexpected response type".to_string())),
}
}
pub async fn list_roots(&self) -> Result<Vec<Root>, Error> {
let request = JSONRPCRequest {
jsonrpc: "2.0".to_string(),
id: Uuid::new_v4().to_string().into(),
method: "roots/list".to_string(),
params: None,
};
let response = self.send_request(request).await?;
match response {
JSONRPCMessage::Response(resp) => {
let roots = resp.result["roots"]
.as_array()
.ok_or_else(|| Error::ParseError("Missing roots array".to_string()))?;
let mut result = Vec::with_capacity(roots.len());
for root in roots {
let r: Root = serde_json::from_value(root.clone())
.map_err(|e| Error::ParseError(e.to_string()))?;
result.push(r);
}
Ok(result)
}
JSONRPCMessage::Error(err) => Err(Error::ServerError(
err.error.code,
err.error.message,
err.error.data,
)),
_ => Err(Error::ProtocolError("Unexpected response type".to_string())),
}
}
/// Tells the server the client's roots changed
/// (`notifications/roots/list_changed`).
///
/// # Errors
/// [`Error::UnsupportedFeature`] unless both `capabilities.roots` and
/// `capabilities.roots_list_changed` are enabled in the client options;
/// otherwise any transport error from sending.
pub async fn notify_roots_changed(&self) -> Result<(), Error> {
    let caps = &self.options.capabilities;
    let enabled = caps.roots && caps.roots_list_changed;
    if !enabled {
        return Err(Error::UnsupportedFeature("Roots list changed notifications".to_string()));
    }
    self.send_notification(JSONRPCNotification {
        jsonrpc: "2.0".to_string(),
        method: "notifications/roots/list_changed".to_string(),
        params: None,
    })
    .await
}
/// Sends a `ping` request; succeeds on any successful reply.
///
/// # Errors
/// Errors from [`Client::send_request`], or [`Error::ServerError`] for an
/// error reply.
pub async fn ping(&self) -> Result<(), Error> {
    let probe = JSONRPCRequest {
        jsonrpc: "2.0".to_string(),
        id: Uuid::new_v4().to_string().into(),
        method: "ping".to_string(),
        params: None,
    };
    match self.send_request(probe).await? {
        JSONRPCMessage::Error(err) => Err(Error::ServerError(
            err.error.code,
            err.error.message,
            err.error.data,
        )),
        JSONRPCMessage::Response(_) => Ok(()),
        _ => Err(Error::ProtocolError("Unexpected response type".to_string())),
    }
}
/// Notifies the server that the request `request_id` is no longer needed
/// (`notifications/cancelled`), optionally with a human-readable `reason`.
///
/// The local pending-request entry is left untouched, so the original
/// caller keeps waiting until a reply, timeout, or disconnect.
pub async fn cancel_request(
    &self,
    request_id: RequestId,
    reason: Option<String>,
) -> Result<(), Error> {
    let mut params = serde_json::json!({ "requestId": request_id });
    // `reason` is an optional string in the protocol: omit it when absent
    // rather than sending an explicit `null`, which is not a string and may
    // fail schema validation on the server.
    if let Some(reason) = reason {
        params["reason"] = serde_json::Value::String(reason);
    }
    let notification = JSONRPCNotification {
        jsonrpc: "2.0".to_string(),
        method: "notifications/cancelled".to_string(),
        params: Some(params),
    };
    self.send_notification(notification).await
}
}
/// One entry of a `resources/read` result. [`Client::read_resource`] picks
/// the variant by whether the entry carries a `text` or a `blob` key.
#[derive(Clone, Debug)]
pub enum ResourceContent {
    /// Entry that has a `text` key.
    Text(TextResourceContents),
    /// Entry that has a `blob` key (and no `text` key).
    Blob(BlobResourceContents),
}
/// Kind of object a completion request refers to.
///
/// [`Client::get_completions`] sends `Prompt` as a `ref/prompt` reference
/// (name in `name`) and `Resource` as a `ref/resource` reference (name in `uri`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompletionReferenceType {
    /// Refer to a prompt by name.
    Prompt,
    /// Refer to a resource by URI.
    Resource,
}