use crate::error::{Result, SofosError};
use crate::mcp::client::McpClient;
use crate::mcp::config::{SafeModeAccess, load_mcp_config};
use crate::mcp::protocol::{CallToolResult, McpTool, ToolContent};
use colored::Colorize;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Separator placed between the server name and the tool name when a prefixed
/// tool name is built by `prefixed_tool_name`. `validate_mcp_name` rejects any
/// server or tool name that contains this sequence.
pub const MCP_NAME_SEPARATOR: &str = "___";
/// Flattened outcome of an MCP tool call: a text rendering plus any images
/// that came back with it. Built by `format_tool_result`.
#[derive(Debug, Clone)]
pub struct ToolResult {
/// Text rendering of the result; image and resource items are represented
/// by bracketed placeholder lines.
pub text: String,
/// Images collected from the result's content items.
pub images: Vec<ImageData>,
}
/// An image attached to a tool result.
#[derive(Debug, Clone)]
pub enum ImageData {
/// Image carried inline. `data` is presumably the base64 text of the image
/// (NOTE(review): inferred from the variant name and how the size is
/// estimated in `format_tool_result` — confirm against the protocol types).
Base64 { mime_type: String, data: String },
/// Image referenced by URL rather than carried inline.
Url { url: String },
}
impl ImageData {
pub fn outbound_size(&self) -> usize {
match self {
ImageData::Base64 { data, .. } => data.len(),
ImageData::Url { .. } => 0,
}
}
}
/// Registry of connected MCP servers and the tools they expose.
///
/// Every field sits behind an `Arc`, so the manual `Clone` impl produces
/// handles that share the same underlying maps.
pub struct McpManager {
/// Connected clients keyed by server name, guarded by an async mutex.
clients: Arc<Mutex<HashMap<String, Arc<McpClient>>>>,
/// Tools accepted from each server, keyed by server name.
tools_by_server: Arc<HashMap<String, Vec<McpTool>>>,
/// Maps a prefixed tool name (see `prefixed_tool_name`) to its server name.
tool_to_server: Arc<HashMap<String, String>>,
/// Safe-mode access setting taken from each server's configuration.
safe_mode_by_server: Arc<HashMap<String, SafeModeAccess>>,
}
/// Checks that `name` is usable as an MCP server or tool name.
///
/// `kind` is interpolated into the error messages; the value `"server"`
/// additionally enables the character-set rule. Rules are checked in this
/// order and the first violation wins:
///
/// 1. the name must not be empty;
/// 2. the name must not contain [`MCP_NAME_SEPARATOR`];
/// 3. for `kind == "server"` only, every character must be an ASCII letter,
///    an ASCII digit, `_` or `-`.
///
/// # Errors
///
/// Returns `SofosError::McpError` describing the first rule that was violated.
fn validate_mcp_name(kind: &str, name: &str) -> Result<()> {
    let has_disallowed_char =
        || name.chars().any(|c| !(c.is_ascii_alphanumeric() || matches!(c, '_' | '-')));

    let problem = if name.is_empty() {
        Some(format!("MCP {} name is empty", kind))
    } else if name.contains(MCP_NAME_SEPARATOR) {
        Some(format!(
            "MCP {} name '{}' contains the reserved separator '{}'",
            kind, name, MCP_NAME_SEPARATOR
        ))
    } else if kind == "server" && has_disallowed_char() {
        Some(format!(
            "MCP server name '{}' must contain only ASCII letters, digits, '_' or '-'",
            name
        ))
    } else {
        None
    };

    match problem {
        Some(message) => Err(SofosError::McpError(message)),
        None => Ok(()),
    }
}
/// Builds the prefixed tool name for `tool` on `server` by joining the two
/// with [`MCP_NAME_SEPARATOR`], e.g. `("docs", "read")` becomes `"docs___read"`.
///
/// No validation is performed here; inputs are concatenated as given.
pub fn prefixed_tool_name(server: &str, tool: &str) -> String {
    [server, MCP_NAME_SEPARATOR, tool].concat()
}
impl McpManager {
    /// Builds a manager by connecting to every MCP server configured for
    /// `workspace`.
    ///
    /// For each entry returned by `load_mcp_config` the server name is
    /// validated, a client is connected and its tools are listed. A server
    /// that fails any of those steps is reported with `tracing::warn!` and
    /// skipped, so one broken server does not stop the others from loading.
    /// Individual tools are skipped (also with a warning) when their name is
    /// invalid or when their prefixed name has already been registered.
    ///
    /// Returns the manager together with a colorized summary that lists each
    /// registered server and its tool count; the summary is an empty string
    /// when no server was registered.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`; per-server failures are
    /// logged rather than propagated.
    pub async fn new(workspace: PathBuf) -> Result<(Self, String)> {
        let server_configs = load_mcp_config(&workspace);

        let mut clients: HashMap<String, Arc<McpClient>> = HashMap::new();
        let mut tools_by_server: HashMap<String, Vec<McpTool>> = HashMap::new();
        let mut tool_to_server: HashMap<String, String> = HashMap::new();
        let mut safe_mode_by_server: HashMap<String, SafeModeAccess> = HashMap::new();
        let mut bullets = String::new();

        for (server_name, config) in server_configs {
            if let Err(e) = validate_mcp_name("server", &server_name) {
                tracing::warn!(server = %server_name, error = %e, "skipping MCP server");
                continue;
            }

            // Read before `config` is handed over to `connect`.
            let server_safe_mode = config.safe_mode;

            let client = match McpClient::connect(server_name.clone(), config).await {
                Ok(client) => client,
                Err(e) => {
                    tracing::warn!(
                        server = %server_name,
                        error = %e,
                        "failed to connect to MCP server"
                    );
                    continue;
                }
            };

            let tools = match client.list_tools().await {
                Ok(tools) => tools,
                Err(e) => {
                    tracing::warn!(
                        server = %server_name,
                        error = %e,
                        "failed to list tools from MCP server"
                    );
                    continue;
                }
            };

            let mut accepted: Vec<McpTool> = Vec::with_capacity(tools.len());
            for tool in tools {
                if let Err(e) = validate_mcp_name("tool", &tool.name) {
                    // Validation rejects empty names as well as names that
                    // contain the separator, so keep the message generic and
                    // let the `error` field carry the exact reason.
                    tracing::warn!(
                        server = %server_name,
                        tool = %tool.name,
                        error = %e,
                        "skipping MCP tool with invalid name"
                    );
                    continue;
                }

                let prefixed = prefixed_tool_name(&server_name, &tool.name);
                if let Some(existing) = tool_to_server.get(&prefixed) {
                    tracing::warn!(
                        new_server = %server_name,
                        existing_server = %existing,
                        tool = %tool.name,
                        "MCP tool name collision; keeping the first registration"
                    );
                    continue;
                }

                tool_to_server.insert(prefixed, server_name.clone());
                accepted.push(tool);
            }

            let tool_count = accepted.len();
            tools_by_server.insert(server_name.clone(), accepted);
            clients.insert(server_name.clone(), Arc::new(client));
            safe_mode_by_server.insert(server_name.clone(), server_safe_mode);
            bullets.push_str(&format!(
                " {} {} ({} tools)\n",
                "✓".bright_green(),
                server_name.bright_cyan(),
                tool_count
            ));
        }

        let manager = Self {
            clients: Arc::new(Mutex::new(clients)),
            tools_by_server: Arc::new(tools_by_server),
            tool_to_server: Arc::new(tool_to_server),
            safe_mode_by_server: Arc::new(safe_mode_by_server),
        };

        let init_block = if bullets.is_empty() {
            String::new()
        } else {
            format!("{}\n{}", "MCP servers:".bright_green(), bullets)
        };

        Ok((manager, init_block))
    }

    /// Returns the sorted names of registered servers whose safe-mode
    /// availability equals `included`: pass `true` for the servers that are
    /// available in safe mode, `false` for those that are not.
    pub fn server_names_for_safe_mode(&self, included: bool) -> Vec<String> {
        let mut names: Vec<String> = self
            .safe_mode_by_server
            .iter()
            .filter(|(_, access)| access.is_available_in_safe_mode() == included)
            .map(|(name, _)| name.clone())
            .collect();
        names.sort();
        names
    }

    /// Reports whether `server` may be used in safe mode.
    ///
    /// A server that is not registered falls back to
    /// `SafeModeAccess::default()`.
    pub fn is_server_available_in_safe_mode(&self, server: &str) -> bool {
        self.safe_mode_by_server
            .get(server)
            .copied()
            .unwrap_or_default()
            .is_available_in_safe_mode()
    }

    /// Returns every registered MCP tool as an API tool definition.
    ///
    /// This implementation performs no awaiting and always returns `Ok`.
    pub async fn get_all_tools(&self) -> Result<Vec<crate::api::Tool>> {
        Ok(self.collect_tools(false))
    }

    /// Returns the MCP tools of servers that are available in safe mode.
    ///
    /// This implementation performs no awaiting and always returns `Ok`.
    pub async fn get_safe_mode_tools(&self) -> Result<Vec<crate::api::Tool>> {
        Ok(self.collect_tools(true))
    }

    /// Converts the registered MCP tools into API tool definitions.
    ///
    /// When `safe_mode` is `true`, servers that are not available in safe
    /// mode are left out. Servers are visited in name order and each server's
    /// tools keep the order in which they were registered, so the result is
    /// the same on every call and every run.
    fn collect_tools(&self, safe_mode: bool) -> Vec<crate::api::Tool> {
        // `HashMap` iteration order is arbitrary; sort by server name so the
        // advertised tool list is deterministic.
        let mut servers: Vec<(&String, &Vec<McpTool>)> = self.tools_by_server.iter().collect();
        servers.sort_unstable_by(|a, b| a.0.cmp(b.0));

        let mut all_tools = Vec::new();
        for (server_name, tools) in servers {
            if safe_mode && !self.is_server_available_in_safe_mode(server_name) {
                continue;
            }
            all_tools.extend(tools.iter().map(|mcp_tool| crate::api::Tool::Regular {
                name: prefixed_tool_name(server_name, &mcp_tool.name),
                description: format!("[MCP: {}] {}", server_name, mcp_tool.description),
                input_schema: mcp_tool.input_schema.clone(),
                cache_control: None,
            }));
        }
        all_tools
    }

    /// Calls the MCP tool registered under the prefixed name `tool_name` with
    /// `input` and returns its formatted result.
    ///
    /// The server prefix is stripped before the call, so the server receives
    /// the tool's original name.
    ///
    /// # Errors
    ///
    /// Returns `SofosError::McpError` when `tool_name` is not registered or
    /// when its server has no connected client, and propagates any error from
    /// the client's `call_tool`.
    pub async fn execute_tool(
        &self,
        tool_name: &str,
        input: &serde_json::Value,
    ) -> Result<ToolResult> {
        let server_name = self
            .tool_to_server
            .get(tool_name)
            .ok_or_else(|| SofosError::McpError(format!("Unknown MCP tool: {}", tool_name)))?;

        // Clone the `Arc` inside a short scope so the mutex guard is released
        // before the tool call is awaited.
        let client = {
            let clients = self.clients.lock().await;
            clients.get(server_name).cloned().ok_or_else(|| {
                SofosError::McpError(format!("MCP server '{}' not connected", server_name))
            })?
        };

        let prefix = format!("{}{}", server_name, MCP_NAME_SEPARATOR);
        let original_tool_name = tool_name.strip_prefix(prefix.as_str()).unwrap_or(tool_name);

        let result = client
            .call_tool(original_tool_name, Some(input.clone()))
            .await?;
        Ok(format_tool_result(result))
    }

    /// Reports whether `tool_name` is a registered (prefixed) MCP tool name.
    pub fn is_mcp_tool(&self, tool_name: &str) -> bool {
        self.tool_to_server.contains_key(tool_name)
    }

    /// Returns the name of the server that registered the prefixed tool name
    /// `tool_name`, or `None` when the tool is unknown.
    pub fn server_for_tool(&self, tool_name: &str) -> Option<&str> {
        self.tool_to_server.get(tool_name).map(String::as_str)
    }
}
/// Flattens a raw MCP `CallToolResult` into a [`ToolResult`].
///
/// The text output starts with an `Error from MCP tool:` line when the result
/// is flagged as an error. Content items are then rendered in order:
///
/// * text items are appended verbatim, followed by a newline;
/// * image items add an `[Image: <mime> (<n> KB)]` placeholder line and are
///   also collected as `ImageData::Base64`;
/// * resource items add a `[Resource: <uri>]` line, followed by the MIME type
///   and text body when those are present.
fn format_tool_result(result: CallToolResult) -> ToolResult {
    let mut rendered = if matches!(result.is_error, Some(true)) {
        String::from("Error from MCP tool:\n")
    } else {
        String::new()
    };
    let mut attachments = Vec::new();

    for item in result.content {
        match item {
            ToolContent::Text { text } => {
                rendered += &text;
                rendered += "\n";
            }
            ToolContent::Image { data, mime_type } => {
                // Size is estimated from the payload length before `data` is
                // moved into the collected image.
                let approx_kb = crate::tools::utils::base64_approx_decoded_kb(data.len());
                rendered += &format!("[Image: {} ({} KB)]\n", mime_type, approx_kb);
                attachments.push(ImageData::Base64 { mime_type, data });
            }
            ToolContent::Resource {
                uri,
                mime_type,
                text,
            } => {
                rendered += &format!("[Resource: {}]\n", uri);
                if let Some(mime) = mime_type {
                    rendered += &format!("MIME type: {}\n", mime);
                }
                if let Some(body) = text {
                    rendered += &body;
                    rendered += "\n";
                }
            }
        }
    }

    ToolResult {
        text: rendered,
        images: attachments,
    }
}
impl Clone for McpManager {
    /// Creates another handle onto the same manager state.
    ///
    /// Every field is an `Arc`, so cloning only bumps reference counts; the
    /// clone and the original share the same maps and client registry.
    fn clone(&self) -> Self {
        let Self {
            clients,
            tools_by_server,
            tool_to_server,
            safe_mode_by_server,
        } = self;

        Self {
            clients: Arc::clone(clients),
            tools_by_server: Arc::clone(tools_by_server),
            tool_to_server: Arc::clone(tool_to_server),
            safe_mode_by_server: Arc::clone(safe_mode_by_server),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The prefixed name is `<server>___<tool>`.
    #[test]
    fn prefixed_name_uses_triple_underscore_separator() {
        assert_eq!(prefixed_tool_name("docs", "read"), "docs___read");
        assert_eq!(
            prefixed_tool_name("github", "create_issue"),
            "github___create_issue"
        );
    }

    /// Empty names and names containing the separator are rejected for both
    /// servers and tools.
    #[test]
    fn validate_rejects_reserved_separator_in_names() {
        assert!(validate_mcp_name("server", "good").is_ok());
        assert!(validate_mcp_name("server", "with___sep").is_err());
        assert!(validate_mcp_name("tool", "with___sep").is_err());
        assert!(validate_mcp_name("tool", "").is_err());
        assert!(validate_mcp_name("server", "").is_err());
    }

    /// The character-set rule applies to server names only; tool names are
    /// checked for emptiness and the separator, nothing else.
    #[test]
    fn validate_restricts_server_names_to_ascii_word_chars() {
        assert!(validate_mcp_name("server", "my-server_1").is_ok());
        assert!(validate_mcp_name("server", "bad name").is_err());
        assert!(validate_mcp_name("server", "naïve").is_err());
        assert!(validate_mcp_name("server", "dots.are.out").is_err());
        assert!(validate_mcp_name("tool", "has space").is_ok());
        assert!(validate_mcp_name("tool", "dots.are.fine").is_ok());
    }

    /// Single underscores in either part must not make two different
    /// (server, tool) pairs map to the same prefixed name.
    #[test]
    fn collisions_from_underscores_no_longer_overlap() {
        let a = prefixed_tool_name("a", "b_c");
        let b = prefixed_tool_name("a_b", "c");
        assert_ne!(a, b);
        assert_eq!(a, "a___b_c");
        assert_eq!(b, "a_b___c");
    }
}