use anyhow::{anyhow, Result};
use serde_json::{json, Value};
use std::collections::HashMap;
use tokio::sync::RwLock;
use super::transport::{create_transport, Transport, TransportConfig};
use super::types::*;
/// Client for a single MCP server, exchanging JSON-RPC 2.0 messages over a
/// boxed [`Transport`].
///
/// Every piece of mutable state is wrapped in a `tokio::sync::RwLock`, which lets
/// all client operations take `&self`.
pub struct McpClient {
    /// Name this server is known by; used in log lines and error messages.
    server_name: String,
    /// Message channel to the server, produced by `create_transport`.
    transport: Box<dyn Transport>,
    /// Becomes `true` once the `initialize` handshake has finished.
    initialized: RwLock<bool>,
    /// Most recently issued JSON-RPC request id; bumped before each request.
    request_id: RwLock<i64>,
    /// Capabilities the server advertised in its `initialize` result.
    capabilities: RwLock<Option<ServerCapabilities>>,
    /// Server implementation info (name/version) from the `initialize` result.
    server_info: RwLock<Option<Implementation>>,
    /// Tools returned by the latest successful `tools/list` request.
    tools_cache: RwLock<Vec<Tool>>,
}
impl McpClient {
    /// Creates the transport described by `config`, then runs the MCP `initialize`
    /// handshake before returning the client.
    ///
    /// # Errors
    ///
    /// Fails if the transport cannot be created or if initialization fails. On an
    /// initialization failure the transport is closed (best effort) before the
    /// initialization error is returned, so no half-open connection is leaked.
    pub async fn connect(
        server_name: impl Into<String>,
        config: TransportConfig,
    ) -> Result<Self> {
        let server_name = server_name.into();
        let transport = create_transport(&server_name, &config).await?;
        let client = Self {
            server_name,
            transport,
            capabilities: RwLock::new(None),
            server_info: RwLock::new(None),
            tools_cache: RwLock::new(Vec::new()),
            request_id: RwLock::new(0),
            initialized: RwLock::new(false),
        };
        if let Err(err) = client.initialize().await {
            // The handshake error is what the caller needs; a close failure is only logged.
            if let Err(close_err) = client.transport.close().await {
                tracing::warn!(
                    "Failed to close MCP transport for '{}' after initialization error: {}",
                    client.server_name,
                    close_err
                );
            }
            return Err(err);
        }
        Ok(client)
    }

    /// Name this client was created with (used in logs and error messages).
    pub fn server_name(&self) -> &str {
        &self.server_name
    }

    /// Whether the `initialize` handshake, including the
    /// `notifications/initialized` notification, has completed.
    pub async fn is_initialized(&self) -> bool {
        *self.initialized.read().await
    }

    /// Capabilities the server advertised during `initialize`, if it has run.
    pub async fn capabilities(&self) -> Option<ServerCapabilities> {
        self.capabilities.read().await.clone()
    }

    /// Server name/version reported during `initialize`, if it has run.
    pub async fn server_info(&self) -> Option<Implementation> {
        self.server_info.read().await.clone()
    }

    /// Returns an error unless the `initialize` handshake has completed.
    async fn ensure_initialized(&self) -> Result<()> {
        if self.is_initialized().await {
            Ok(())
        } else {
            Err(anyhow!("MCP client not initialized"))
        }
    }

    /// Applies `check` to the advertised capabilities; `false` before `initialize`.
    async fn has_capability(&self, check: impl FnOnce(&ServerCapabilities) -> bool) -> bool {
        self.capabilities
            .read()
            .await
            .as_ref()
            .map(check)
            .unwrap_or(false)
    }

    /// Sends a JSON-RPC request and waits for the response with the same id,
    /// deserializing its `result` into `T`.
    ///
    /// While waiting:
    /// - server-initiated requests are answered via [`Self::handle_server_request`];
    /// - notifications are logged and skipped;
    /// - responses carrying another id are discarded.
    ///
    /// The `request_id` write guard is held for the whole round trip. All callers
    /// share one receive stream, and mismatched responses are dropped. Without the
    /// guard, two overlapping calls could consume each other's replies, and one
    /// would block on `receive()` indefinitely. Holding it makes concurrent callers
    /// take turns.
    ///
    /// # Errors
    ///
    /// Fails on a transport error, a JSON-RPC error reply for this id, or a
    /// `result` that cannot be deserialized into `T`.
    async fn send_request<T: serde::de::DeserializeOwned>(
        &self,
        method: &str,
        params: Option<Value>,
    ) -> Result<T> {
        let mut id_guard = self.request_id.write().await;
        *id_guard += 1;
        let id = RequestId::Number(*id_guard);

        let request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: id.clone(),
            method: method.to_string(),
            params,
        };
        let message = serde_json::to_string(&request)?;
        tracing::debug!("MCP request to '{}': {}", self.server_name, message);
        self.transport.notify(&message).await?;

        loop {
            let response = self.transport.receive().await?;
            tracing::debug!("MCP message from '{}': {}", self.server_name, response);

            // A request from the server (e.g. `ping`): answer it and keep waiting.
            if let Ok(server_req) = serde_json::from_str::<JsonRpcRequest>(&response) {
                self.handle_server_request(&server_req).await?;
                continue;
            }
            if let Ok(success) = serde_json::from_str::<JsonRpcResponse>(&response) {
                if success.id != id {
                    // Reply to some other request id; not ours.
                    continue;
                }
                return serde_json::from_value(success.result)
                    .map_err(|e| anyhow!("Failed to parse result: {}", e));
            }
            if let Ok(error) = serde_json::from_str::<JsonRpcError>(&response) {
                if error.id != id {
                    continue;
                }
                return Err(anyhow!(
                    "MCP error from '{}': [{}] {}",
                    self.server_name,
                    error.error.code,
                    error.error.message
                ));
            }
            if let Ok(notification) = serde_json::from_str::<JsonRpcNotification>(&response) {
                tracing::debug!(
                    "MCP notification from '{}': {}",
                    self.server_name,
                    notification.method
                );
                continue;
            }
            tracing::warn!("Unexpected MCP message format: {}", response);
        }
    }

    /// Replies to a request initiated by the server.
    ///
    /// Behavior by method:
    /// - `roots/list` is answered with an empty root list;
    /// - `ping` is answered with an empty result;
    /// - anything else gets a JSON-RPC "Method not found" (-32601) error.
    async fn handle_server_request(&self, request: &JsonRpcRequest) -> Result<()> {
        tracing::debug!("MCP server request '{}': {}", self.server_name, request.method);
        let message = match request.method.as_str() {
            "roots/list" => serde_json::to_string(&JsonRpcResponse {
                jsonrpc: "2.0".to_string(),
                id: request.id.clone(),
                result: json!({ "roots": [] }),
            })?,
            "ping" => serde_json::to_string(&JsonRpcResponse {
                jsonrpc: "2.0".to_string(),
                id: request.id.clone(),
                result: json!({}),
            })?,
            _ => {
                tracing::warn!("Unhandled MCP server request: {}", request.method);
                serde_json::to_string(&JsonRpcError {
                    jsonrpc: "2.0".to_string(),
                    id: request.id.clone(),
                    error: JsonRpcErrorDetail {
                        // JSON-RPC 2.0 reserved code for "Method not found".
                        code: -32601,
                        message: "Method not found".to_string(),
                        data: None,
                    },
                })?
            }
        };
        self.transport.notify(&message).await?;
        Ok(())
    }

    /// Sends a JSON-RPC notification (no id, so no response is awaited).
    async fn send_notification(&self, method: &str, params: Option<Value>) -> Result<()> {
        let notification = JsonRpcNotification {
            jsonrpc: "2.0".to_string(),
            method: method.to_string(),
            params,
        };
        let message = serde_json::to_string(&notification)?;
        self.transport.notify(&message).await?;
        Ok(())
    }

    /// Performs the MCP handshake.
    ///
    /// Steps:
    /// 1. Sends `initialize` with protocol version `2024-11-05`.
    /// 2. Stores the returned capabilities and server info.
    /// 3. Sends `notifications/initialized`.
    /// 4. Marks the client as initialized.
    async fn initialize(&self) -> Result<()> {
        tracing::info!("Initializing MCP server '{}'", self.server_name);
        let params = InitializeParams {
            capabilities: ClientCapabilities {
                roots: Some(RootsCapability {
                    list_changed: Some(false),
                }),
                ..Default::default()
            },
            client_info: Implementation::default(),
            protocol_version: Some("2024-11-05".to_string()),
        };
        let result: InitializeResult = self
            .send_request("initialize", Some(serde_json::to_value(params)?))
            .await?;
        // Log before moving the fields out of `result`, so no clones are needed.
        tracing::info!(
            "MCP server '{}' initialized: {} v{}",
            self.server_name,
            result.server_info.name,
            result.server_info.version
        );
        *self.capabilities.write().await = Some(result.capabilities);
        *self.server_info.write().await = Some(result.server_info);
        self.send_notification("notifications/initialized", None).await?;
        *self.initialized.write().await = true;
        Ok(())
    }

    /// Fetches the tool list via `tools/list` and refreshes the cache read by
    /// [`Self::cached_tools`].
    ///
    /// # Errors
    ///
    /// Fails if the client is not initialized or the request fails. The cache is
    /// left unchanged in that case.
    pub async fn list_tools(&self) -> Result<Vec<Tool>> {
        self.ensure_initialized().await?;
        let result: ListToolsResult = self.send_request("tools/list", None).await?;
        *self.tools_cache.write().await = result.tools.clone();
        Ok(result.tools)
    }

    /// Copy of the tools from the last successful [`Self::list_tools`] call
    /// (empty until one succeeds).
    pub async fn cached_tools(&self) -> Vec<Tool> {
        self.tools_cache.read().await.clone()
    }

    /// Invokes tool `name` via `tools/call` with optional JSON `arguments`.
    ///
    /// # Errors
    ///
    /// Fails if the client is not initialized or the request fails.
    pub async fn call_tool(
        &self,
        name: &str,
        arguments: Option<Value>,
    ) -> Result<CallToolResult> {
        self.ensure_initialized().await?;
        let params = CallToolParams {
            name: name.to_string(),
            arguments,
        };
        self.send_request("tools/call", Some(serde_json::to_value(params)?))
            .await
    }

    /// Whether the server advertised the `tools` capability.
    pub async fn supports_tools(&self) -> bool {
        self.has_capability(|c| c.tools.is_some()).await
    }

    /// Lists resources via `resources/list`.
    ///
    /// # Errors
    ///
    /// Fails if the client is not initialized or the request fails.
    pub async fn list_resources(&self) -> Result<Vec<Resource>> {
        self.ensure_initialized().await?;
        let result: ListResourcesResult = self.send_request("resources/list", None).await?;
        Ok(result.resources)
    }

    /// Reads resource `uri` via `resources/read`, returning the raw JSON result.
    ///
    /// # Errors
    ///
    /// Fails if the client is not initialized or the request fails.
    pub async fn read_resource(&self, uri: &str) -> Result<Value> {
        self.ensure_initialized().await?;
        self.send_request("resources/read", Some(json!({ "uri": uri })))
            .await
    }

    /// Whether the server advertised the `resources` capability.
    pub async fn supports_resources(&self) -> bool {
        self.has_capability(|c| c.resources.is_some()).await
    }

    /// Lists prompts via `prompts/list`.
    ///
    /// # Errors
    ///
    /// Fails if the client is not initialized or the request fails.
    pub async fn list_prompts(&self) -> Result<Vec<Prompt>> {
        self.ensure_initialized().await?;
        let result: ListPromptsResult = self.send_request("prompts/list", None).await?;
        Ok(result.prompts)
    }

    /// Fetches prompt `name` via `prompts/get`.
    ///
    /// `arguments`, when given, is sent as a string-to-string map under the
    /// `arguments` key.
    ///
    /// # Errors
    ///
    /// Fails if the client is not initialized or the request fails.
    pub async fn get_prompt(
        &self,
        name: &str,
        arguments: Option<HashMap<String, String>>,
    ) -> Result<Value> {
        self.ensure_initialized().await?;
        let mut params = json!({ "name": name });
        if let Some(args) = arguments {
            params["arguments"] = serde_json::to_value(args)?;
        }
        self.send_request("prompts/get", Some(params)).await
    }

    /// Whether the server advertised the `prompts` capability.
    pub async fn supports_prompts(&self) -> bool {
        self.has_capability(|c| c.prompts.is_some()).await
    }

    /// Asks the server to change its log level via `logging/setLevel`.
    ///
    /// # Errors
    ///
    /// Fails if the client is not initialized, the request fails, or the server
    /// replies with a JSON-RPC error.
    pub async fn set_logging_level(&self, level: LogLevel) -> Result<()> {
        self.ensure_initialized().await?;
        let params = SetLoggingLevelParams { level };
        // The acknowledgement is expected to be an empty result object (`{}`, the
        // same shape this client sends for `ping`). serde_json only deserializes
        // `()` from `null`, so requesting `()` here made every call fail. Accept
        // any JSON value and discard it instead.
        let _ack: Value = self
            .send_request("logging/setLevel", Some(serde_json::to_value(params)?))
            .await?;
        Ok(())
    }

    /// Closes the underlying transport. No MCP message is sent beforehand.
    pub async fn shutdown(&self) -> Result<()> {
        tracing::info!("Shutting down MCP server '{}'", self.server_name);
        self.transport.close().await
    }
}
/// Fluent builder that collects a server name and a transport choice, then
/// produces a connected [`McpClient`].
pub struct McpClientBuilder {
    /// Transport settings; replaced wholesale by `stdio` / `sse`.
    config: TransportConfig,
    /// Name forwarded to [`McpClient::connect`].
    server_name: String,
}
impl McpClientBuilder {
    /// Starts a builder for a server called `name`.
    ///
    /// The transport initially defaults to a stdio config with an empty command
    /// and no arguments. Pick a real transport with `stdio` or `sse` before
    /// calling `connect`.
    pub fn new(name: impl Into<String>) -> Self {
        let placeholder = TransportConfig::stdio("", vec![]);
        Self {
            server_name: name.into(),
            config: placeholder,
        }
    }

    /// Selects a stdio transport that launches `command` with `args`.
    pub fn stdio(self, command: impl Into<String>, args: Vec<String>) -> Self {
        Self {
            config: TransportConfig::stdio(command, args),
            ..self
        }
    }

    /// Selects an SSE transport pointed at `url`.
    pub fn sse(self, url: impl Into<String>) -> Self {
        Self {
            config: TransportConfig::sse(url),
            ..self
        }
    }

    /// Consumes the builder and connects via [`McpClient::connect`].
    pub async fn connect(self) -> Result<McpClient> {
        let Self {
            server_name,
            config,
        } = self;
        McpClient::connect(server_name, config).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder keeps its server name after a stdio transport is selected.
    #[test]
    fn test_client_builder() {
        let npx_args: Vec<String> = ["-y", "@playwright/mcp"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let builder = McpClientBuilder::new("test").stdio("npx", npx_args);
        assert_eq!(builder.server_name, "test");
    }
}