use std::collections::HashMap;
use std::fs;
use anyhow::{Context, Result};
use crate::network_policy::NetworkPolicyDecider;
use super::config::McpConfig;
/// Summary of what changed when the MCP configuration was reloaded.
#[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct McpReloadReport {
    /// Servers that existed in the previous config but not in the new one.
    pub removed: Vec<String>,
    /// Servers present in both configs whose settings differ.
    pub updated: Vec<String>,
    /// Servers holding a ready connection once the reload finished.
    pub connected: Vec<String>,
    /// `(server name, error message)` pairs collected while reconnecting.
    pub connect_errors: Vec<(String, String)>,
}
use super::connection::McpConnection;
use super::types::{McpPrompt, McpResource, McpResourceTemplate, McpTool};
/// Lazily populated set of MCP server connections described by an [`McpConfig`].
pub struct McpPool {
    /// Live connections, keyed by configured server name.
    pub(super) connections: HashMap<String, McpConnection>,
    /// Server definitions plus the global timeouts applied to calls.
    config: McpConfig,
    /// Optional policy handed to every new connection attempt.
    network_policy: Option<NetworkPolicyDecider>,
}
impl McpPool {
pub fn new(config: McpConfig) -> Self {
Self {
connections: HashMap::new(),
config,
network_policy: None,
}
}
pub fn from_config_path(path: &std::path::Path) -> Result<Self> {
let config = if path.exists() {
let contents = fs::read_to_string(path)
.with_context(|| format!("Failed to read MCP config: {}", path.display()))?;
serde_json::from_str(&contents)
.with_context(|| format!("Failed to parse MCP config: {}", path.display()))?
} else {
McpConfig::default()
};
Ok(Self::new(config))
}
pub fn with_network_policy(mut self, policy: NetworkPolicyDecider) -> Self {
self.network_policy = Some(policy);
self
}
pub async fn get_or_connect(&mut self, server_name: &str) -> Result<&mut McpConnection> {
let is_ready = self
.connections
.get(server_name)
.map(|conn| conn.is_ready())
.unwrap_or(false);
if is_ready {
return self
.connections
.get_mut(server_name)
.ok_or_else(|| anyhow::anyhow!("MCP connection disappeared for {server_name}"));
}
self.connections.remove(server_name);
let server_config = self
.config
.servers
.get(server_name)
.ok_or_else(|| anyhow::anyhow!("Failed to find MCP server: {server_name}"))?
.clone();
if !server_config.is_enabled() {
anyhow::bail!("Failed to connect MCP server '{server_name}': server is disabled");
}
let connection = McpConnection::connect_with_policy(
server_name.to_string(),
server_config,
&self.config.timeouts,
self.network_policy.as_ref(),
)
.await?;
self.connections.insert(server_name.to_string(), connection);
self.connections
.get_mut(server_name)
.ok_or_else(|| anyhow::anyhow!("Failed to store MCP connection for {server_name}"))
}
pub async fn connect_all(&mut self) -> Vec<(String, anyhow::Error)> {
let mut errors = Vec::new();
let names: Vec<String> = self
.config
.servers
.keys()
.filter(|n| self.config.servers[*n].is_enabled())
.cloned()
.collect();
for name in names {
if let Err(e) = self.get_or_connect(&name).await {
errors.push((name, e));
}
}
for (name, server_cfg) in &self.config.servers {
if server_cfg.required
&& server_cfg.is_enabled()
&& !self
.connections
.get(name)
.is_some_and(McpConnection::is_ready)
{
errors.push((
name.clone(),
anyhow::anyhow!("required MCP server failed to initialize"),
));
}
}
errors
}
pub fn all_tools(&self) -> Vec<(String, &McpTool)> {
let mut tools = Vec::new();
for (server, conn) in &self.connections {
for tool in conn.tools() {
if !conn.config().is_tool_enabled(&tool.name) {
continue;
}
tools.push((format!("mcp_{}_{}", server, tool.name), tool));
}
}
tools
}
pub fn all_resources(&self) -> Vec<(String, &McpResource)> {
let mut resources = Vec::new();
for (server, conn) in &self.connections {
for resource in conn.resources() {
let safe_name = resource.name.replace(' ', "_").to_lowercase();
resources.push((format!("mcp_{}_{}", server, safe_name), resource));
}
}
resources
}
#[allow(dead_code)] pub fn all_resource_templates(&self) -> Vec<(String, &McpResourceTemplate)> {
let mut templates = Vec::new();
for (server, conn) in &self.connections {
for template in conn.resource_templates() {
let safe_name = template.name.replace(' ', "_").to_lowercase();
templates.push((format!("mcp_{}_{}", server, safe_name), template));
}
}
templates
}
async fn list_resources(&mut self, server: Option<String>) -> Result<Vec<serde_json::Value>> {
if let Some(server_name) = server {
let conn = self.get_or_connect(&server_name).await?;
let resources = conn
.resources()
.iter()
.map(|resource| {
serde_json::json!({
"server": server_name.clone(),
"uri": resource.uri,
"name": resource.name,
"description": resource.description,
"mime_type": resource.mime_type,
})
})
.collect();
return Ok(resources);
}
let _ = self.connect_all().await;
let mut items = Vec::new();
for (server, conn) in &self.connections {
for resource in conn.resources() {
items.push(serde_json::json!({
"server": server,
"uri": resource.uri,
"name": resource.name,
"description": resource.description,
"mime_type": resource.mime_type,
}));
}
}
Ok(items)
}
async fn list_resource_templates(
&mut self,
server: Option<String>,
) -> Result<Vec<serde_json::Value>> {
if let Some(server_name) = server {
let conn = self.get_or_connect(&server_name).await?;
let templates = conn
.resource_templates()
.iter()
.map(|template| {
serde_json::json!({
"server": server_name.clone(),
"uri_template": template.uri_template,
"name": template.name,
"description": template.description,
"mime_type": template.mime_type,
})
})
.collect();
return Ok(templates);
}
let _ = self.connect_all().await;
let mut items = Vec::new();
for (server, conn) in &self.connections {
for template in conn.resource_templates() {
items.push(serde_json::json!({
"server": server,
"uri_template": template.uri_template,
"name": template.name,
"description": template.description,
"mime_type": template.mime_type,
}));
}
}
Ok(items)
}
pub fn all_prompts(&self) -> Vec<(String, &McpPrompt)> {
let mut prompts = Vec::new();
for (server, conn) in &self.connections {
for prompt in conn.prompts() {
prompts.push((format!("mcp_{}_{}", server, prompt.name), prompt));
}
}
prompts
}
pub async fn read_resource(
&mut self,
server_name: &str,
uri: &str,
) -> Result<serde_json::Value> {
let global_timeouts = self.config.timeouts;
let conn = self.get_or_connect(server_name).await?;
let timeout = conn.config().effective_read_timeout(&global_timeouts);
conn.read_resource(uri, timeout).await
}
pub async fn get_prompt(
&mut self,
server_name: &str,
prompt_name: &str,
arguments: serde_json::Value,
) -> Result<serde_json::Value> {
let global_timeouts = self.config.timeouts;
let conn = self.get_or_connect(server_name).await?;
let timeout = conn.config().effective_execute_timeout(&global_timeouts);
conn.get_prompt(prompt_name, arguments, timeout).await
}
pub(super) fn parse_prefixed_name<'a>(
&self,
prefixed_name: &'a str,
) -> Result<(&'a str, &'a str)> {
let rest = prefixed_name
.strip_prefix("mcp_")
.ok_or_else(|| anyhow::anyhow!("Invalid MCP tool name: {prefixed_name}"))?;
let mut servers: Vec<&str> = self.config.servers.keys().map(String::as_str).collect();
servers.sort_by_key(|name| std::cmp::Reverse(name.len()));
for server in servers {
if let Some(tool) = rest
.strip_prefix(server)
.and_then(|tail| tail.strip_prefix('_'))
&& !tool.is_empty()
{
return Ok((&rest[..server.len()], tool));
}
}
rest.split_once('_')
.filter(|(server, tool)| !server.is_empty() && !tool.is_empty())
.ok_or_else(|| anyhow::anyhow!("Invalid MCP tool name format: {prefixed_name}"))
}
pub fn to_api_tools(&self) -> Vec<crate::models::Tool> {
let mut api_tools = Vec::new();
for (name, tool) in self.all_tools() {
api_tools.push(crate::models::Tool {
tool_type: None,
name,
description: tool.description.clone().unwrap_or_default(),
input_schema: tool.input_schema.clone(),
allowed_callers: Some(vec!["direct".to_string()]),
defer_loading: Some(false),
input_examples: None,
strict: None,
cache_control: None,
});
}
if !self.config.servers.is_empty() {
api_tools.push(crate::models::Tool {
tool_type: None,
name: "list_mcp_resources".to_string(),
description: "List available MCP resources across servers (optionally filtered by server).".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"server": { "type": "string", "description": "Optional MCP server name to filter by" }
}
}),
allowed_callers: Some(vec!["direct".to_string()]),
defer_loading: Some(false),
input_examples: None,
strict: None,
cache_control: None,
});
api_tools.push(crate::models::Tool {
tool_type: None,
name: "list_mcp_resource_templates".to_string(),
description: "List available MCP resource templates across servers (optionally filtered by server).".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"server": { "type": "string", "description": "Optional MCP server name to filter by" }
}
}),
allowed_callers: Some(vec!["direct".to_string()]),
defer_loading: Some(false),
input_examples: None,
strict: None,
cache_control: None,
});
}
let resources = self.all_resources();
if !resources.is_empty() {
api_tools.push(crate::models::Tool {
tool_type: None,
name: "mcp_read_resource".to_string(),
description: "Read a resource from an MCP server using its URI".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"server": { "type": "string", "description": "The name of the MCP server" },
"uri": { "type": "string", "description": "The URI of the resource to read" }
},
"required": ["server", "uri"]
}),
allowed_callers: Some(vec!["direct".to_string()]),
defer_loading: Some(false),
input_examples: None,
strict: None,
cache_control: None,
});
api_tools.push(crate::models::Tool {
tool_type: None,
name: "read_mcp_resource".to_string(),
description: "Alias for mcp_read_resource.".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"server": { "type": "string", "description": "The name of the MCP server" },
"uri": { "type": "string", "description": "The URI of the resource to read" }
},
"required": ["server", "uri"]
}),
allowed_callers: Some(vec!["direct".to_string()]),
defer_loading: Some(false),
input_examples: None,
strict: None,
cache_control: None,
});
}
let prompts = self.all_prompts();
if !prompts.is_empty() {
api_tools.push(crate::models::Tool {
tool_type: None,
name: "mcp_get_prompt".to_string(),
description: "Get a prompt from an MCP server".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"server": { "type": "string", "description": "The name of the MCP server" },
"name": { "type": "string", "description": "The name of the prompt" },
"arguments": {
"type": "object",
"description": "Optional arguments for the prompt",
"additionalProperties": { "type": "string" }
}
},
"required": ["server", "name"]
}),
allowed_callers: Some(vec!["direct".to_string()]),
defer_loading: Some(false),
input_examples: None,
strict: None,
cache_control: None,
});
}
api_tools
}
pub async fn call_tool(
&mut self,
prefixed_name: &str,
arguments: serde_json::Value,
) -> Result<serde_json::Value> {
if prefixed_name == "list_mcp_resources" {
let server = arguments
.get("server")
.and_then(|v| v.as_str())
.map(str::to_string);
let resources = self.list_resources(server).await?;
return Ok(serde_json::json!({ "resources": resources }));
}
if prefixed_name == "list_mcp_resource_templates" {
let server = arguments
.get("server")
.and_then(|v| v.as_str())
.map(str::to_string);
let templates = self.list_resource_templates(server).await?;
return Ok(serde_json::json!({ "templates": templates }));
}
if prefixed_name == "mcp_read_resource" {
let server_name = arguments
.get("server")
.and_then(|v| v.as_str())
.context("Missing 'server' argument")?;
let uri = arguments
.get("uri")
.and_then(|v| v.as_str())
.context("Missing 'uri' argument")?;
return self.read_resource(server_name, uri).await;
}
if prefixed_name == "read_mcp_resource" {
let server_name = arguments
.get("server")
.and_then(|v| v.as_str())
.context("Missing 'server' argument")?;
let uri = arguments
.get("uri")
.and_then(|v| v.as_str())
.context("Missing 'uri' argument")?;
return self.read_resource(server_name, uri).await;
}
if prefixed_name == "mcp_get_prompt" {
let server_name = arguments
.get("server")
.and_then(|v| v.as_str())
.context("Missing 'server' argument")?;
let name = arguments
.get("name")
.and_then(|v| v.as_str())
.context("Missing 'name' argument")?;
let args = arguments
.get("arguments")
.cloned()
.unwrap_or(serde_json::json!({}));
return self.get_prompt(server_name, name, args).await;
}
let (server_name, tool_name) = self.parse_prefixed_name(prefixed_name)?;
let global_timeouts = self.config.timeouts;
let conn = self.get_or_connect(server_name).await?;
if !conn.config().is_tool_enabled(tool_name) {
anyhow::bail!("MCP tool '{tool_name}' is disabled for server '{server_name}'");
}
let timeout = conn.config().effective_execute_timeout(&global_timeouts);
let started = std::time::Instant::now();
let result = conn.call_tool(tool_name, arguments, timeout).await;
let duration_ms = started.elapsed().as_millis() as u64;
let (success, err_msg, result_bytes) = match &result {
Ok(value) => (
true,
None,
serde_json::to_string(value).map(|s| s.len()).unwrap_or(0),
),
Err(err) => (false, Some(err.to_string()), 0),
};
super::observability::record_mcp_call(
server_name,
format!("tools/call:{tool_name}"),
duration_ms,
success,
err_msg,
result_bytes,
);
result
}
#[allow(dead_code)] pub fn server_names(&self) -> Vec<&str> {
self.config
.servers
.keys()
.map(std::string::String::as_str)
.collect()
}
pub fn connected_servers(&self) -> Vec<&str> {
self.connections
.iter()
.filter(|(_, c)| c.is_ready())
.map(|(n, _)| n.as_str())
.collect()
}
#[allow(dead_code)] pub fn disconnect_all(&mut self) {
self.connections.clear();
}
pub async fn reload_from_path(&mut self, path: &std::path::Path) -> Result<McpReloadReport> {
let config = if path.exists() {
let contents = fs::read_to_string(path)
.with_context(|| format!("Failed to read MCP config: {}", path.display()))?;
serde_json::from_str(&contents)
.with_context(|| format!("Failed to parse MCP config: {}", path.display()))?
} else {
McpConfig::default()
};
Ok(self.reload_config(config, true).await)
}
pub async fn reload_config(
&mut self,
new_config: McpConfig,
reconnect: bool,
) -> McpReloadReport {
let old_config = std::mem::replace(&mut self.config, new_config);
let mut removed = Vec::new();
let mut updated = Vec::new();
let old_names: std::collections::HashSet<_> = old_config.servers.keys().collect();
let new_names: std::collections::HashSet<_> = self.config.servers.keys().collect();
for name in old_names.difference(&new_names) {
removed.push((*name).clone());
if let Some(mut conn) = self.connections.remove(*name) {
conn.transport.shutdown().await;
}
}
for name in old_names.intersection(&new_names) {
if old_config.servers[*name] != self.config.servers[*name] {
updated.push((*name).clone());
if let Some(mut conn) = self.connections.remove(*name) {
conn.transport.shutdown().await;
}
}
}
let disabled_or_missing: Vec<String> = self
.connections
.keys()
.filter(|name| {
self.config
.servers
.get(*name)
.is_none_or(|cfg| !cfg.is_enabled())
})
.cloned()
.collect();
for name in disabled_or_missing {
if let Some(mut conn) = self.connections.remove(&name) {
conn.transport.shutdown().await;
}
}
let mut connect_errors = Vec::new();
if reconnect {
connect_errors = self
.connect_all()
.await
.into_iter()
.map(|(name, err)| (name, err.to_string()))
.collect();
}
let connected = self
.connected_servers()
.into_iter()
.map(str::to_string)
.collect();
McpReloadReport {
removed,
updated,
connected,
connect_errors,
}
}
#[allow(dead_code)] pub async fn shutdown_all(&mut self) {
let names: Vec<String> = self.connections.keys().cloned().collect();
for name in names {
if let Some(conn) = self.connections.get_mut(&name) {
conn.transport.shutdown().await;
}
}
self.connections.clear();
}
#[allow(dead_code)] pub fn config(&self) -> &McpConfig {
&self.config
}
pub fn is_mcp_tool(name: &str) -> bool {
name.starts_with("mcp_")
|| matches!(
name,
"list_mcp_resources" | "list_mcp_resource_templates" | "read_mcp_resource"
)
}
}