use anyhow::{anyhow, Result};
use async_trait::async_trait;
use serde_json::Value;
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::approval::RiskLevel;
use crate::tools::{Tool, ToolDefinition};
use super::client::McpClient;
use super::types::{CallToolResult, Content, Tool as McpTool};
/// Adapts a single tool advertised by an MCP server to the local [`Tool`] trait.
pub struct McpToolWrapper {
    /// Connection to the MCP server that owns this tool; shared (via `Arc`) with the
    /// other wrappers created for the same server.
    client: Arc<McpClient>,
    /// Tool metadata as returned by the server's tool listing (original, unprefixed name).
    tool_def: McpTool,
    /// Name of the server; used as the prefix of the exposed tool name.
    server_name: String,
    /// Definition returned by `definition()`, built once in `new`.
    cached_definition: ToolDefinition,
}
impl McpToolWrapper {
    /// Wraps `tool_def`, advertised by the MCP server `server_name`, as a local tool.
    ///
    /// The exposed name is `"{server_name}_{tool name}"` so tools from different servers
    /// do not collide. If the server supplied no description, `"MCP tool: {tool name}"` is
    /// used instead. The resulting [`ToolDefinition`] is built once here and cloned by
    /// `definition()`.
    pub fn new(client: Arc<McpClient>, tool_def: McpTool, server_name: String) -> Self {
        let description = tool_def
            .description
            .clone()
            .unwrap_or_else(|| format!("MCP tool: {}", tool_def.name));
        let cached_definition = ToolDefinition {
            name: format!("{}_{}", server_name, tool_def.name),
            description,
            parameters: tool_def.input_schema.clone(),
            is_priority: false,
        };
        Self {
            client,
            tool_def,
            server_name,
            cached_definition,
        }
    }

    /// The tool's name as known to the MCP server (without the server-name prefix).
    pub fn original_name(&self) -> &str {
        &self.tool_def.name
    }

    /// The name of the MCP server this tool belongs to.
    pub fn server_name(&self) -> &str {
        &self.server_name
    }

    /// Flattens the content items of a tool-call result into a single string.
    ///
    /// Text items, and embedded resources that carry text, are emitted verbatim. Images and
    /// non-text resources become short `[Image: ...]` / `[Resource: ...]` placeholders.
    /// Items are newline-separated and trailing whitespace is trimmed, so an empty result
    /// yields an empty string.
    fn parse_result(&self, result: CallToolResult) -> String {
        use std::fmt::Write as _;

        let mut output = String::new();
        for content in result.content {
            match content {
                Content::Text { text } => {
                    output.push_str(&text);
                    output.push('\n');
                }
                Content::Image { data, mime_type } => {
                    // NOTE(review): if `data` is base64-encoded (as MCP image content is),
                    // this reports the encoded length, not the decoded byte count.
                    // Writing into a `String` does not fail, so the `fmt::Result` is ignored.
                    let _ = writeln!(output, "[Image: {} ({} bytes)]", mime_type, data.len());
                }
                Content::Resource { resource } => {
                    if let Some(text) = resource.text {
                        output.push_str(&text);
                        output.push('\n');
                    } else if let Some(blob) = resource.blob {
                        let _ = writeln!(
                            output,
                            "[Resource: {} ({} bytes)]",
                            resource.uri,
                            blob.len()
                        );
                    } else {
                        let _ = writeln!(output, "[Resource: {}]", resource.uri);
                    }
                }
            }
        }
        output.trim_end().to_string()
    }
}
/// Classifies an MCP tool name into a [`RiskLevel`] by case-insensitive keyword matching.
///
/// Keywords are matched as substrings, so incidental hits are possible ("get" inside
/// "widget", "list" inside "playlist", "read" inside "thread"). Dangerous keywords are
/// therefore checked first, so a name like `delete_playlist` is never downgraded to
/// `Safe` just because it also contains "list". Every other name is `Mutating`. That
/// includes browser-automation tools such as `browser_navigate` / `browser_click` and
/// any name with no recognised keyword.
fn classify_risk(tool_name: &str) -> RiskLevel {
    let name = tool_name.to_ascii_lowercase();

    if ["write", "delete", "create"]
        .iter()
        .any(|keyword| name.contains(*keyword))
    {
        RiskLevel::Dangerous
    } else if ["read", "list", "get"]
        .iter()
        .any(|keyword| name.contains(*keyword))
    {
        RiskLevel::Safe
    } else {
        RiskLevel::Mutating
    }
}

#[async_trait]
impl Tool for McpToolWrapper {
    /// Returns the definition built in [`McpToolWrapper::new`].
    fn definition(&self) -> ToolDefinition {
        self.cached_definition.clone()
    }

    /// Calls the tool on its MCP server and returns the flattened result text.
    ///
    /// # Errors
    /// Fails if the call itself fails. Also fails if the server flags the result as an
    /// error (`is_error == Some(true)`); the flattened result text then becomes the error
    /// message. Both errors name the exposed (server-prefixed) tool.
    async fn execute(&self, params: Value) -> Result<String> {
        let exposed_name = &self.cached_definition.name;
        tracing::debug!(
            "Executing MCP tool '{}' from server '{}'",
            self.tool_def.name,
            self.server_name
        );
        // The server only knows the tool by its original, unprefixed name.
        let result = self
            .client
            .call_tool(&self.tool_def.name, Some(params))
            .await
            .map_err(|e| anyhow!("MCP tool '{}' failed: {}", exposed_name, e))?;
        if result.is_error.unwrap_or(false) {
            let error_msg = self.parse_result(result);
            return Err(anyhow!(
                "MCP tool '{}' returned an error: {}",
                exposed_name,
                error_msg
            ));
        }
        Ok(self.parse_result(result))
    }

    /// Risk level derived from the tool's original name; see `classify_risk`.
    fn risk_level(&self) -> RiskLevel {
        classify_risk(&self.tool_def.name)
    }
}
/// Tracks the MCP clients connected through [`McpToolManager::connect_server`] so they
/// can be enumerated and shut down together.
pub struct McpToolManager {
    /// Connected clients, in connection order; guarded by an async `RwLock` so the
    /// manager can be used through a shared reference.
    clients: RwLock<Vec<Arc<McpClient>>>,
}
impl McpToolManager {
pub fn new() -> Self {
Self {
clients: RwLock::new(Vec::new()),
}
}
pub async fn connect_server(
&self,
server_name: impl Into<String>,
config: super::transport::TransportConfig,
) -> Result<Vec<Box<dyn Tool>>> {
let server_name = server_name.into();
let client = Arc::new(McpClient::connect(&server_name, config).await?);
if !client.supports_tools().await {
tracing::warn!("MCP server '{}' does not support tools", server_name);
return Ok(Vec::new());
}
let mcp_tools = client.list_tools().await?;
tracing::info!(
"MCP server '{}' provided {} tools",
server_name,
mcp_tools.len()
);
let tools: Vec<Box<dyn Tool>> = mcp_tools
.into_iter()
.map(|tool| Box::new(McpToolWrapper::new(client.clone(), tool, server_name.clone())) as Box<dyn Tool>)
.collect();
self.clients.write().await.push(client);
Ok(tools)
}
pub async fn server_count(&self) -> usize {
self.clients.read().await.len()
}
pub async fn server_names(&self) -> Vec<String> {
self.clients.read().await.iter()
.map(|c| c.server_name().to_string())
.collect()
}
pub async fn shutdown(&self) {
let clients = self.clients.read().await;
for client in clients.iter() {
if let Err(e) = client.shutdown().await {
tracing::error!("Failed to shutdown MCP server '{}': {}", client.server_name(), e);
}
}
}
}
impl Default for McpToolManager {
fn default() -> Self {
Self::new()
}
}
pub async fn connect_mcp_server(
server_name: impl Into<String>,
config: super::transport::TransportConfig,
) -> Result<Vec<Box<dyn Tool>>> {
let server_name = server_name.into();
let client = McpClient::connect(&server_name, config).await?;
if !client.supports_tools().await {
client.shutdown().await?;
return Ok(Vec::new());
}
let mcp_tools = client.list_tools().await?;
let client = Arc::new(client);
let tools: Vec<Box<dyn Tool>> = mcp_tools
.into_iter()
.map(|tool| Box::new(McpToolWrapper::new(client.clone(), tool, server_name.clone())) as Box<dyn Tool>)
.collect();
Ok(tools)
}
/// Connects to every enabled server in `mcp_config` and collects all of their tools.
///
/// Servers are connected one at a time in name order, so the resulting tool list is
/// deterministic even though `HashMap` iteration order is not. Disabled servers are
/// skipped. Each successfully connected server is registered in the returned
/// [`McpToolManager`].
///
/// # Errors
/// Fails fast on the first server whose transport config cannot be built or whose
/// connection fails. Before returning that error, every server connected so far is shut
/// down through the manager, so they are not simply dropped.
pub async fn connect_mcp_servers_from_config(
    mcp_config: &std::collections::HashMap<String, super::config::McpServerConfig>,
) -> Result<(Vec<Box<dyn Tool>>, McpToolManager)> {
    let manager = McpToolManager::new();
    let mut all_tools: Vec<Box<dyn Tool>> = Vec::new();

    let mut servers: Vec<_> = mcp_config.iter().collect();
    servers.sort_unstable_by(|a, b| a.0.cmp(b.0));

    for (name, config) in servers {
        if !config.enabled {
            tracing::debug!("MCP server '{}' is disabled, skipping", name);
            continue;
        }

        let transport_config = match config.to_transport_config() {
            Ok(transport_config) => transport_config,
            Err(e) => {
                manager.shutdown().await;
                return Err(anyhow!(
                    "Failed to create transport config for '{}': {}",
                    name,
                    e
                ));
            }
        };

        tracing::info!("Connecting to MCP server '{}'...", name);
        // `connect_server` already logs how many tools the server provided.
        match manager.connect_server(name, transport_config).await {
            Ok(tools) => all_tools.extend(tools),
            Err(e) => {
                manager.shutdown().await;
                return Err(e);
            }
        }
    }

    Ok((all_tools, manager))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the keyword-based risk rules on a few representative tool names.
    ///
    /// `McpToolWrapper` requires an `Arc<McpClient>`, so this test does not build one.
    /// Instead it checks a local copy of the keyword rules from `risk_level`; keep the copy
    /// in sync with the real method.
    #[test]
    fn test_mcp_tool_wrapper_definition() {
        fn classify(tool: &str) -> RiskLevel {
            let has = |words: [&str; 3]| words.iter().any(|w| tool.contains(*w));
            if has(["read", "list", "get"]) {
                RiskLevel::Safe
            } else if has(["browser", "navigate", "click"]) {
                RiskLevel::Mutating
            } else if has(["write", "delete", "create"]) {
                RiskLevel::Dangerous
            } else {
                RiskLevel::Mutating
            }
        }

        let cases = [
            ("read_file", RiskLevel::Safe),
            ("browser_navigate", RiskLevel::Mutating),
            ("write_file", RiskLevel::Dangerous),
        ];
        for (name, expected) in cases.iter() {
            assert_eq!(&classify(name), expected);
        }
    }
}