use crate::tools::{Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType};
use crate::transport::{HttpTransport, McpTransport, ProcessTransport, WebSocketTransport};
use crate::types::McpToolResult;
use crate::{McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo};
use anyhow::Result;
use rust_mcp_schema::Resource;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Semaphore;
use tracing::warn;
/// A live connection to a single MCP server.
///
/// The implementations in this file forward each method to the matching MCP
/// JSON-RPC request over their transport.
#[async_trait::async_trait]
pub trait McpServerConnection: Send + Sync {
/// Identifier of the server (the config `id` for connections built by `McpClient`).
fn server_id(&self) -> &str;
/// Display name of the server (the config `name` for connections built by `McpClient`).
fn server_name(&self) -> &str;
/// Lists the tools advertised by the server.
async fn list_tools(&self) -> Result<Vec<McpToolInfo>>;
/// Invokes tool `name` (the server-side, unprefixed name) with JSON `arguments`
/// and returns the result rendered as a string.
async fn call_tool(&self, name: &str, arguments: serde_json::Value) -> Result<String>;
/// Lists the resources exposed by the server.
async fn list_resources(&self) -> Result<Vec<Resource>>;
/// Reads resource `uri` and returns its content as a string.
async fn read_resource(&self, uri: &str) -> Result<String>;
/// Sends an MCP `ping`; `Ok(())` means the server answered.
async fn ping(&self) -> Result<()>;
/// Closes the underlying transport.
async fn close(&self) -> Result<()>;
}
fn initialize_params() -> Value {
serde_json::json!({
"protocolVersion": rust_mcp_schema::ProtocolVersion::latest().to_string(),
"capabilities": {
"tools": {}
},
"clientInfo": {
"name": "mistral.rs",
"version": env!("CARGO_PKG_VERSION"),
}
})
}
/// Performs the MCP handshake on `transport`.
///
/// It sends `initialize` and, once that request succeeds, the
/// initialization notification. The server's initialize response body is not
/// inspected; only request success matters here.
async fn initialize_transport(transport: &Arc<dyn McpTransport>) -> Result<()> {
    let params = initialize_params();
    let _ = transport.send_request("initialize", params).await?;
    transport.send_initialization_notification().await
}
/// Lists every tool the server advertises via `tools/list`.
///
/// MCP list requests are paginated. A response may carry a `nextCursor`, in
/// which case the next page is requested with `{"cursor": <nextCursor>}` until
/// no cursor comes back. The first request is still sent with `Value::Null`
/// params.
///
/// Each returned [`McpToolInfo`] is tagged with `server_id` / `server_name`. A
/// tool without `inputSchema` gets an empty JSON object as its schema.
///
/// # Errors
/// Fails on transport errors, when a page has no `tools` array ("Invalid tools
/// response format"), or when a tool has no string `name` ("Tool missing name").
async fn list_tools_via_transport(
    transport: &Arc<dyn McpTransport>,
    server_id: &str,
    server_name: &str,
) -> Result<Vec<McpToolInfo>> {
    let mut tool_infos = Vec::new();
    let mut cursor: Option<String> = None;
    loop {
        let params = match &cursor {
            Some(cursor) => serde_json::json!({ "cursor": cursor }),
            None => Value::Null,
        };
        let result = transport.send_request("tools/list", params).await?;
        let tools = result
            .get("tools")
            .and_then(Value::as_array)
            .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
        for tool in tools {
            let name = tool
                .get("name")
                .and_then(Value::as_str)
                .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
                .to_string();
            let description = tool
                .get("description")
                .and_then(Value::as_str)
                .map(str::to_string);
            let input_schema = tool
                .get("inputSchema")
                .cloned()
                .unwrap_or_else(|| Value::Object(serde_json::Map::new()));
            tool_infos.push(McpToolInfo {
                name,
                description,
                input_schema,
                server_id: server_id.to_string(),
                server_name: server_name.to_string(),
            });
        }
        // Stop when there is no further page. Also stop when the server returns an
        // empty cursor or echoes the cursor we just sent, which would otherwise
        // loop forever.
        match result.get("nextCursor").and_then(Value::as_str) {
            Some(next) if !next.is_empty() && cursor.as_deref() != Some(next) => {
                cursor = Some(next.to_string());
            }
            _ => break,
        }
    }
    Ok(tool_infos)
}
/// Invokes tool `name` through `tools/call` and renders the result as text.
///
/// The response is deserialized into [`McpToolResult`]. A result flagged with
/// `isError` becomes an `Err` carrying the rendered result; otherwise the
/// rendered result is returned.
async fn call_tool_via_transport(
    transport: &Arc<dyn McpTransport>,
    name: &str,
    arguments: Value,
) -> Result<String> {
    let request = serde_json::json!({
        "name": name,
        "arguments": arguments
    });
    let response = transport.send_request("tools/call", request).await?;
    let tool_result = serde_json::from_value::<McpToolResult>(response)?;
    match tool_result.is_error {
        Some(true) => Err(anyhow::anyhow!("Tool execution failed: {tool_result}")),
        _ => Ok(tool_result.to_string()),
    }
}
/// Lists every resource the server exposes via `resources/list`.
///
/// The request is paginated like `tools/list`: while a response carries a
/// `nextCursor`, the next page is requested with `{"cursor": <nextCursor>}`. The
/// first request is still sent with `Value::Null` params.
///
/// # Errors
/// Fails on transport errors, when a page has no `resources` array ("Invalid
/// resources response format"), or when an entry does not deserialize into
/// [`Resource`].
async fn list_resources_via_transport(transport: &Arc<dyn McpTransport>) -> Result<Vec<Resource>> {
    let mut resource_list = Vec::new();
    let mut cursor: Option<String> = None;
    loop {
        let params = match &cursor {
            Some(cursor) => serde_json::json!({ "cursor": cursor }),
            None => Value::Null,
        };
        let result = transport.send_request("resources/list", params).await?;
        let resources = result
            .get("resources")
            .and_then(Value::as_array)
            .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
        for resource in resources {
            resource_list.push(serde_json::from_value::<Resource>(resource.clone())?);
        }
        // Stop when there is no further page. Also stop when the server returns an
        // empty cursor or echoes the cursor we just sent, which would otherwise
        // loop forever.
        match result.get("nextCursor").and_then(Value::as_str) {
            Some(next) if !next.is_empty() && cursor.as_deref() != Some(next) => {
                cursor = Some(next.to_string());
            }
            _ => break,
        }
    }
    Ok(resource_list)
}
async fn read_resource_via_transport(
transport: &Arc<dyn McpTransport>,
uri: &str,
) -> Result<String> {
let params = serde_json::json!({ "uri": uri });
let result = transport.send_request("resources/read", params).await?;
if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
if let Some(first_content) = contents.first() {
if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
return Ok(text.to_string());
}
}
}
Err(anyhow::anyhow!("No readable content found in resource"))
}
/// Sends an MCP `ping` with no params and discards the reply.
/// Any transport error is propagated.
async fn ping_transport(transport: &Arc<dyn McpTransport>) -> Result<()> {
    let _pong = transport.send_request("ping", Value::Null).await?;
    Ok(())
}
/// Client that manages connections to one or more MCP servers and exposes
/// their tools as synchronous tool callbacks.
pub struct McpClient {
/// Client configuration, including the per-server configs.
config: McpClientConfig,
/// Live connections keyed by server id.
servers: HashMap<String, Arc<dyn McpServerConnection>>,
/// Registered tools keyed by their public name (prefixed with `tool_prefix` when configured).
tools: HashMap<String, McpToolInfo>,
/// Tool callbacks keyed by the same public tool name as `tools`.
tool_callbacks: HashMap<String, Arc<ToolCallback>>,
/// Callbacks paired with their `Tool` definitions, keyed by public tool name.
tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
/// Limit on concurrent tool calls, shared by `call_tool` and every registered callback.
concurrency_semaphore: Arc<Semaphore>,
}
impl McpClient {
pub fn new(config: McpClientConfig) -> Self {
let max_concurrent = config.max_concurrent_calls.unwrap_or(10);
Self {
config,
servers: HashMap::new(),
tools: HashMap::new(),
tool_callbacks: HashMap::new(),
tool_callbacks_with_tools: HashMap::new(),
concurrency_semaphore: Arc::new(Semaphore::new(max_concurrent)),
}
}
pub async fn initialize(&mut self) -> Result<()> {
for server_config in &self.config.servers {
if server_config.enabled {
let connection = self.create_connection(server_config).await?;
self.servers.insert(server_config.id.clone(), connection);
}
}
if self.config.auto_register_tools {
self.discover_and_register_tools().await?;
}
Ok(())
}
pub fn get_tool_callbacks(&self) -> &HashMap<String, Arc<ToolCallback>> {
&self.tool_callbacks
}
pub fn get_tool_callbacks_with_tools(&self) -> &HashMap<String, ToolCallbackWithTool> {
&self.tool_callbacks_with_tools
}
pub fn get_tools(&self) -> &HashMap<String, McpToolInfo> {
&self.tools
}
pub fn servers(&self) -> &HashMap<String, Arc<dyn McpServerConnection>> {
&self.servers
}
pub fn server(&self, id: &str) -> Option<&Arc<dyn McpServerConnection>> {
self.servers.get(id)
}
pub fn config(&self) -> &McpClientConfig {
&self.config
}
async fn create_connection(
&self,
config: &McpServerConfig,
) -> Result<Arc<dyn McpServerConnection>> {
match &config.source {
McpServerSource::Http {
url,
timeout_secs,
headers,
} => {
let mut merged_headers = headers.clone().unwrap_or_default();
if let Some(token) = &config.bearer_token {
merged_headers.insert("Authorization".to_string(), format!("Bearer {token}"));
}
let connection = HttpMcpConnection::new(
config.id.clone(),
config.name.clone(),
url.clone(),
*timeout_secs,
Some(merged_headers),
)
.await?;
Ok(Arc::new(connection))
}
McpServerSource::Process {
command,
args,
work_dir,
env,
} => {
let connection = ProcessMcpConnection::new(
config.id.clone(),
config.name.clone(),
command.clone(),
args.clone(),
work_dir.clone(),
env.clone(),
)
.await?;
Ok(Arc::new(connection))
}
McpServerSource::WebSocket {
url,
timeout_secs,
headers,
} => {
let mut merged_headers = headers.clone().unwrap_or_default();
if let Some(token) = &config.bearer_token {
merged_headers.insert("Authorization".to_string(), format!("Bearer {token}"));
}
let connection = WebSocketMcpConnection::new(
config.id.clone(),
config.name.clone(),
url.clone(),
*timeout_secs,
Some(merged_headers),
)
.await?;
Ok(Arc::new(connection))
}
}
}
async fn discover_and_register_tools(&mut self) -> Result<()> {
for (server_id, connection) in &self.servers {
let tools = connection.list_tools().await?;
let server_config = self
.config
.servers
.iter()
.find(|s| &s.id == server_id)
.ok_or_else(|| anyhow::anyhow!("Server config not found for {}", server_id))?;
for tool in tools {
let tool_name = if let Some(prefix) = &server_config.tool_prefix {
format!("{}_{}", prefix, tool.name)
} else {
tool.name.clone()
};
let connection_clone = Arc::clone(connection);
let original_tool_name = tool.name.clone();
let semaphore_clone = Arc::clone(&self.concurrency_semaphore);
let timeout_duration =
Duration::from_secs(self.config.tool_timeout_secs.unwrap_or(30));
let callback: Arc<ToolCallback> = Arc::new(move |called_function| {
let connection = Arc::clone(&connection_clone);
let tool_name = original_tool_name.clone();
let semaphore = Arc::clone(&semaphore_clone);
let arguments: serde_json::Value =
serde_json::from_str(&called_function.arguments)?;
let rt = tokio::runtime::Handle::current();
std::thread::spawn(move || {
rt.block_on(async move {
let _permit = semaphore.acquire().await.map_err(|_| {
anyhow::anyhow!("Failed to acquire concurrency permit")
})?;
match tokio::time::timeout(
timeout_duration,
connection.call_tool(&tool_name, arguments),
)
.await
{
Ok(result) => result,
Err(_) => Err(anyhow::anyhow!(
"Tool call timed out after {} seconds",
timeout_duration.as_secs()
)),
}
})
})
.join()
.map_err(|_| anyhow::anyhow!("Tool call thread panicked"))?
});
let function_def = Function {
name: tool_name.clone(),
description: tool.description.clone(),
parameters: Self::convert_mcp_schema_to_parameters(&tool.input_schema),
};
let tool_def = Tool {
tp: ToolType::Function,
function: function_def,
};
self.tool_callbacks
.insert(tool_name.clone(), callback.clone());
self.tool_callbacks_with_tools.insert(
tool_name.clone(),
ToolCallbackWithTool {
callback,
tool: tool_def,
},
);
self.tools.insert(tool_name, tool);
}
}
Ok(())
}
fn convert_mcp_schema_to_parameters(
schema: &serde_json::Value,
) -> Option<HashMap<String, serde_json::Value>> {
match schema {
serde_json::Value::Object(obj) => {
let mut params = HashMap::new();
if let Some(properties) = obj.get("properties") {
if let serde_json::Value::Object(props) = properties {
for (key, value) in props {
params.insert(key.clone(), value.clone());
}
}
} else {
for (key, value) in obj {
params.insert(key.clone(), value.clone());
}
}
if params.is_empty() {
None
} else {
Some(params)
}
}
_ => {
None
}
}
}
fn remove_tools_for_server(&mut self, server_id: &str) {
let tools_to_remove: Vec<String> = self
.tools
.iter()
.filter(|(_, info)| info.server_id == server_id)
.map(|(name, _)| name.clone())
.collect();
for name in tools_to_remove {
self.tools.remove(&name);
self.tool_callbacks.remove(&name);
self.tool_callbacks_with_tools.remove(&name);
}
}
async fn register_tools_for_server(&mut self, server_id: &str) -> Result<()> {
let connection = self
.servers
.get(server_id)
.ok_or_else(|| anyhow::anyhow!("Server not connected: {}", server_id))?
.clone();
let server_config = self
.config
.servers
.iter()
.find(|s| s.id == server_id)
.ok_or_else(|| anyhow::anyhow!("Server config not found for {}", server_id))?
.clone();
let tools = connection.list_tools().await?;
for tool in tools {
let tool_name = if let Some(prefix) = &server_config.tool_prefix {
format!("{}_{}", prefix, tool.name)
} else {
tool.name.clone()
};
let connection_clone = Arc::clone(&connection);
let original_tool_name = tool.name.clone();
let semaphore_clone = Arc::clone(&self.concurrency_semaphore);
let timeout_duration = Duration::from_secs(self.config.tool_timeout_secs.unwrap_or(30));
let callback: Arc<ToolCallback> = Arc::new(move |called_function| {
let connection = Arc::clone(&connection_clone);
let tool_name = original_tool_name.clone();
let semaphore = Arc::clone(&semaphore_clone);
let arguments: serde_json::Value =
serde_json::from_str(&called_function.arguments)?;
let rt = tokio::runtime::Handle::current();
std::thread::spawn(move || {
rt.block_on(async move {
let _permit = semaphore
.acquire()
.await
.map_err(|_| anyhow::anyhow!("Failed to acquire concurrency permit"))?;
match tokio::time::timeout(
timeout_duration,
connection.call_tool(&tool_name, arguments),
)
.await
{
Ok(result) => result,
Err(_) => Err(anyhow::anyhow!(
"Tool call timed out after {} seconds",
timeout_duration.as_secs()
)),
}
})
})
.join()
.map_err(|_| anyhow::anyhow!("Tool call thread panicked"))?
});
let function_def = Function {
name: tool_name.clone(),
description: tool.description.clone(),
parameters: Self::convert_mcp_schema_to_parameters(&tool.input_schema),
};
let tool_def = Tool {
tp: ToolType::Function,
function: function_def,
};
self.tool_callbacks
.insert(tool_name.clone(), callback.clone());
self.tool_callbacks_with_tools.insert(
tool_name.clone(),
ToolCallbackWithTool {
callback,
tool: tool_def,
},
);
self.tools.insert(tool_name, tool);
}
Ok(())
}
pub async fn shutdown(&mut self) -> Result<()> {
for connection in self.servers.values() {
let _ = connection.close().await;
}
self.servers.clear();
self.tools.clear();
self.tool_callbacks.clear();
self.tool_callbacks_with_tools.clear();
Ok(())
}
pub async fn disconnect(&mut self, id: &str) -> Result<()> {
let connection = self
.servers
.remove(id)
.ok_or_else(|| anyhow::anyhow!("Server not connected: {}", id))?;
connection.close().await?;
self.remove_tools_for_server(id);
Ok(())
}
pub async fn reconnect(&mut self, id: &str) -> Result<()> {
let server_config = self
.config
.servers
.iter()
.find(|s| s.id == id)
.ok_or_else(|| anyhow::anyhow!("Server config not found: {}", id))?
.clone();
if let Some(connection) = self.servers.remove(id) {
let _ = connection.close().await;
}
self.remove_tools_for_server(id);
let connection = self.create_connection(&server_config).await?;
self.servers.insert(id.to_string(), connection);
if self.config.auto_register_tools {
self.register_tools_for_server(id).await?;
}
Ok(())
}
pub fn is_connected(&self, id: &str) -> bool {
self.servers.contains_key(id)
}
pub async fn add_server(&mut self, config: McpServerConfig) -> Result<()> {
let id = config.id.clone();
if self.servers.contains_key(&id) {
return Err(anyhow::anyhow!("Server already exists: {}", id));
}
let connection = self.create_connection(&config).await?;
self.servers.insert(id.clone(), connection);
self.config.servers.push(config);
if self.config.auto_register_tools {
self.register_tools_for_server(&id).await?;
}
Ok(())
}
pub async fn remove_server(&mut self, id: &str) -> Result<()> {
if let Some(connection) = self.servers.remove(id) {
let _ = connection.close().await;
}
self.remove_tools_for_server(id);
self.config.servers.retain(|s| s.id != id);
Ok(())
}
pub async fn refresh_tools(&mut self) -> Result<()> {
self.tools.clear();
self.tool_callbacks.clear();
self.tool_callbacks_with_tools.clear();
self.discover_and_register_tools().await
}
pub fn get_tool(&self, name: &str) -> Option<&McpToolInfo> {
self.tools.get(name)
}
pub fn has_tool(&self, name: &str) -> bool {
self.tools.contains_key(name)
}
pub async fn call_tool(&self, name: &str, arguments: serde_json::Value) -> Result<String> {
let tool_info = self
.tools
.get(name)
.ok_or_else(|| anyhow::anyhow!("Tool not found: {}", name))?;
let connection = self
.servers
.get(&tool_info.server_id)
.ok_or_else(|| anyhow::anyhow!("Server not connected: {}", tool_info.server_id))?;
let _permit = self
.concurrency_semaphore
.acquire()
.await
.map_err(|_| anyhow::anyhow!("Failed to acquire concurrency permit"))?;
let timeout_duration = Duration::from_secs(self.config.tool_timeout_secs.unwrap_or(30));
match tokio::time::timeout(
timeout_duration,
connection.call_tool(&tool_info.name, arguments),
)
.await
{
Ok(result) => result,
Err(_) => Err(anyhow::anyhow!(
"Tool call timed out after {} seconds",
timeout_duration.as_secs()
)),
}
}
pub fn tool_count(&self) -> usize {
self.tools.len()
}
pub fn server_count(&self) -> usize {
self.servers.len()
}
pub fn server_ids(&self) -> Vec<&str> {
self.servers.keys().map(|s| s.as_str()).collect()
}
pub async fn ping_all(&self) -> HashMap<String, Result<()>> {
let mut results = HashMap::new();
for (server_id, connection) in &self.servers {
let result = connection.ping().await;
results.insert(server_id.clone(), result);
}
results
}
pub async fn list_all_resources(&self) -> Result<Vec<(String, Resource)>> {
let mut all_resources = Vec::new();
for (server_id, connection) in &self.servers {
match connection.list_resources().await {
Ok(resources) => {
for resource in resources {
all_resources.push((server_id.clone(), resource));
}
}
Err(e) => {
warn!("Failed to list resources from server {}: {}", server_id, e);
}
}
}
Ok(all_resources)
}
}
impl Drop for McpClient {
    /// Best-effort cleanup on drop.
    ///
    /// When dropped inside a Tokio runtime, the remaining connections are moved
    /// into a spawned task that closes each one and ignores close errors. Outside
    /// a runtime, the connections are simply dropped with the client.
    fn drop(&mut self) {
        if let Ok(runtime) = tokio::runtime::Handle::try_current() {
            let remaining = std::mem::take(&mut self.servers);
            runtime.spawn(async move {
                for connection in remaining.into_values() {
                    let _ = connection.close().await;
                }
            });
        }
    }
}
/// Connection to an MCP server reached over HTTP.
pub struct HttpMcpConnection {
/// Server id passed in at construction.
server_id: String,
/// Server display name passed in at construction.
server_name: String,
/// HTTP transport used for every JSON-RPC request.
transport: Arc<dyn McpTransport>,
}
impl HttpMcpConnection {
pub async fn new(
server_id: String,
server_name: String,
url: String,
timeout_secs: Option<u64>,
headers: Option<HashMap<String, String>>,
) -> Result<Self> {
let transport = HttpTransport::new(url, timeout_secs, headers)?;
let connection = Self {
server_id,
server_name,
transport: Arc::new(transport),
};
connection.initialize().await?;
Ok(connection)
}
async fn initialize(&self) -> Result<()> {
initialize_transport(&self.transport).await
}
}
/// Every MCP request delegates to the shared `*_via_transport` helpers using
/// this connection's HTTP transport.
#[async_trait::async_trait]
impl McpServerConnection for HttpMcpConnection {
    fn server_id(&self) -> &str {
        self.server_id.as_str()
    }

    fn server_name(&self) -> &str {
        self.server_name.as_str()
    }

    async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
        let (id, display_name) = (self.server_id.as_str(), self.server_name.as_str());
        list_tools_via_transport(&self.transport, id, display_name).await
    }

    async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
        let transport = &self.transport;
        call_tool_via_transport(transport, name, arguments).await
    }

    async fn list_resources(&self) -> Result<Vec<Resource>> {
        let transport = &self.transport;
        list_resources_via_transport(transport).await
    }

    async fn read_resource(&self, uri: &str) -> Result<String> {
        let transport = &self.transport;
        read_resource_via_transport(transport, uri).await
    }

    async fn ping(&self) -> Result<()> {
        let transport = &self.transport;
        ping_transport(transport).await
    }

    async fn close(&self) -> Result<()> {
        let transport = &self.transport;
        transport.close().await
    }
}
/// Connection to an MCP server reached through a `ProcessTransport`
/// (built from a command, args, working directory and environment).
pub struct ProcessMcpConnection {
/// Server id passed in at construction.
server_id: String,
/// Server display name passed in at construction.
server_name: String,
/// Process transport used for every JSON-RPC request.
transport: Arc<dyn McpTransport>,
}
impl ProcessMcpConnection {
pub async fn new(
server_id: String,
server_name: String,
command: String,
args: Vec<String>,
work_dir: Option<String>,
env: Option<HashMap<String, String>>,
) -> Result<Self> {
let transport = ProcessTransport::new(command, args, work_dir, env).await?;
let connection = Self {
server_id,
server_name,
transport: Arc::new(transport),
};
connection.initialize().await?;
Ok(connection)
}
async fn initialize(&self) -> Result<()> {
initialize_transport(&self.transport).await
}
}
/// Every MCP request delegates to the shared `*_via_transport` helpers using
/// this connection's process transport.
#[async_trait::async_trait]
impl McpServerConnection for ProcessMcpConnection {
    fn server_id(&self) -> &str {
        self.server_id.as_str()
    }

    fn server_name(&self) -> &str {
        self.server_name.as_str()
    }

    async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
        let (id, display_name) = (self.server_id.as_str(), self.server_name.as_str());
        list_tools_via_transport(&self.transport, id, display_name).await
    }

    async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
        let transport = &self.transport;
        call_tool_via_transport(transport, name, arguments).await
    }

    async fn list_resources(&self) -> Result<Vec<Resource>> {
        let transport = &self.transport;
        list_resources_via_transport(transport).await
    }

    async fn read_resource(&self, uri: &str) -> Result<String> {
        let transport = &self.transport;
        read_resource_via_transport(transport, uri).await
    }

    async fn ping(&self) -> Result<()> {
        let transport = &self.transport;
        ping_transport(transport).await
    }

    async fn close(&self) -> Result<()> {
        let transport = &self.transport;
        transport.close().await
    }
}
/// Connection to an MCP server reached over a WebSocket.
pub struct WebSocketMcpConnection {
/// Server id passed in at construction.
server_id: String,
/// Server display name passed in at construction.
server_name: String,
/// WebSocket transport used for every JSON-RPC request.
transport: Arc<dyn McpTransport>,
}
impl WebSocketMcpConnection {
pub async fn new(
server_id: String,
server_name: String,
url: String,
timeout_secs: Option<u64>,
headers: Option<HashMap<String, String>>,
) -> Result<Self> {
let transport = WebSocketTransport::new(url, timeout_secs, headers).await?;
let connection = Self {
server_id,
server_name,
transport: Arc::new(transport),
};
connection.initialize().await?;
Ok(connection)
}
async fn initialize(&self) -> Result<()> {
initialize_transport(&self.transport).await
}
}
/// Every MCP request delegates to the shared `*_via_transport` helpers using
/// this connection's WebSocket transport.
#[async_trait::async_trait]
impl McpServerConnection for WebSocketMcpConnection {
    fn server_id(&self) -> &str {
        self.server_id.as_str()
    }

    fn server_name(&self) -> &str {
        self.server_name.as_str()
    }

    async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
        let (id, display_name) = (self.server_id.as_str(), self.server_name.as_str());
        list_tools_via_transport(&self.transport, id, display_name).await
    }

    async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
        let transport = &self.transport;
        call_tool_via_transport(transport, name, arguments).await
    }

    async fn list_resources(&self) -> Result<Vec<Resource>> {
        let transport = &self.transport;
        list_resources_via_transport(transport).await
    }

    async fn read_resource(&self, uri: &str) -> Result<String> {
        let transport = &self.transport;
        read_resource_via_transport(transport, uri).await
    }

    async fn ping(&self) -> Result<()> {
        let transport = &self.transport;
        ping_transport(transport).await
    }

    async fn close(&self) -> Result<()> {
        let transport = &self.transport;
        transport.close().await
    }
}