use crate::config::{Config, McpServerConfig};
use crate::mcp::McpConnectionType;
use anyhow::Result;
use std::collections::HashMap;
use std::sync::{Arc, OnceLock, RwLock};
/// Process-wide tool-name → server lookup table.
///
/// Created lazily by the first `initialize_tool_map` call (the only `get_or_init` site);
/// every other accessor uses `get()` and treats an absent table as "not initialized".
///
/// NOTE(review): the `Arc` is not needed for a `'static` value — `OnceLock<RwLock<_>>`
/// would do — but `initialize_tool_map` constructs the `Arc` explicitly, so both would
/// have to change together.
static TOOL_MAP: OnceLock<Arc<RwLock<ToolMapState>>> = OnceLock::new();
/// Contents guarded by the [`TOOL_MAP`] lock.
#[derive(Debug, Clone, Default)]
struct ToolMapState {
/// Tool name → config of the server that provides it. When several servers expose the
/// same name, the one listed first in the config keeps it.
tool_to_server: HashMap<String, McpServerConfig>,
/// `true` once a map has been built; the lookup functions return nothing until then.
initialized: bool,
/// `calculate_config_hash` value of the config the current map was built from; lets
/// `initialize_tool_map` skip a rebuild when the config has not changed.
config_hash: u64,
}
/// Builds (or rebuilds) the global tool-name → server map from `config`.
///
/// If a map already exists and was built from a config with the same
/// `calculate_config_hash` fingerprint, nothing is done. Otherwise every configured
/// server is queried via `build_tool_server_map_internal`, and the result replaces the
/// previous map.
///
/// # Errors
/// Propagates any error returned by `build_tool_server_map_internal`.
///
/// # Panics
/// Panics if the map's `RwLock` is poisoned.
pub async fn initialize_tool_map(config: &Config) -> Result<()> {
    let current_hash = calculate_config_hash(config);
    let shared = TOOL_MAP.get_or_init(|| Arc::new(RwLock::new(ToolMapState::default())));

    // Evaluate the fast-path check under a short-lived read guard. The guard must be
    // gone before the `.await` below, so it is confined to this block.
    let up_to_date = {
        let guard = shared.read().unwrap();
        guard.initialized && guard.config_hash == current_hash
    };
    if up_to_date {
        crate::log_debug!("Tool map already initialized with current config");
        return Ok(());
    }

    crate::log_debug!("Building tool-to-server map...");
    let fresh_map = build_tool_server_map_internal(config).await?;

    // No `.await` follows, so holding the write guard to the end of the function is fine.
    let mut guard = shared.write().unwrap();
    guard.tool_to_server = fresh_map;
    guard.initialized = true;
    guard.config_hash = current_hash;
    crate::log_debug!(
        "Tool map initialized with {} tools",
        guard.tool_to_server.len()
    );
    Ok(())
}
/// Looks up the server configured to provide `tool_name`.
///
/// Returns `None` in two cases:
/// - the map has not been initialized yet, in which case a debug message indicates the
///   caller is expected to fall back to other resolution;
/// - no server exposes a tool with that name.
///
/// On success the stored server config is cloned out of the map.
///
/// # Panics
/// Panics if the map's `RwLock` is poisoned.
pub fn get_server_for_tool(tool_name: &str) -> Option<McpServerConfig> {
    let guard = TOOL_MAP.get()?.read().unwrap();
    if guard.initialized {
        guard.tool_to_server.get(tool_name).cloned()
    } else {
        crate::log_debug!("Tool map not initialized, falling back to original logic");
        None
    }
}
/// Returns the name of the server that provides `tool_name`.
///
/// Returns `None` under the same conditions as [`get_server_for_tool`]: the map is
/// uninitialized, or the tool is unknown.
pub fn get_tool_server_name(tool_name: &str) -> Option<String> {
    let server = get_server_for_tool(tool_name)?;
    Some(server.name().to_string())
}
/// Reports whether `initialize_tool_map` has built a map in this process.
///
/// Returns `false` when the global table has never been created.
///
/// # Panics
/// Panics if the map's `RwLock` is poisoned.
pub fn is_initialized() -> bool {
    match TOOL_MAP.get() {
        Some(shared) => shared.read().unwrap().initialized,
        None => false,
    }
}
/// Lists every tool name currently present in the map.
///
/// Returns an empty vector when the map has not been created or initialized. The order
/// follows `HashMap` iteration and is therefore unspecified.
///
/// # Panics
/// Panics if the map's `RwLock` is poisoned.
pub fn get_all_tool_names() -> Vec<String> {
    let Some(shared) = TOOL_MAP.get() else {
        return Vec::new();
    };
    let guard = shared.read().unwrap();
    if guard.initialized {
        guard.tool_to_server.keys().cloned().collect()
    } else {
        Vec::new()
    }
}
/// Queries every configured MCP server for its functions and maps each function name to
/// the server that provides it.
///
/// Servers are visited in config order. The first server to expose a given name keeps
/// it; later duplicates are ignored.
///
/// Built-in servers are resolved by name (`developer`, `filesystem`, `agent`, `web`).
/// An unknown built-in is logged and contributes no tools.
///
/// HTTP/stdin servers are queried through `get_server_functions_cached`. A server that
/// fails is logged and skipped, so one broken server does not prevent the rest of the
/// map from being built.
///
/// # Errors
/// The visible body never returns `Err`; the `Result` is kept for interface stability.
async fn build_tool_server_map_internal(
    config: &Config,
) -> Result<HashMap<String, McpServerConfig>> {
    let mut tool_map = HashMap::new();

    // The configs are only read here, so iterate by reference instead of deep-cloning
    // the whole server list; each map entry clones just the server it points at.
    for server in &config.mcp.servers {
        let server_functions = match server.connection_type() {
            McpConnectionType::Builtin => match server.name() {
                "developer" => crate::mcp::get_cached_internal_functions(
                    "developer",
                    server.tools(),
                    crate::mcp::dev::get_all_functions,
                ),
                "filesystem" => crate::mcp::get_cached_internal_functions(
                    "filesystem",
                    server.tools(),
                    crate::mcp::fs::get_all_functions,
                ),
                "agent" => {
                    // Built from the full `config` and filtered directly, rather than
                    // going through `get_cached_internal_functions` like the others.
                    let server_functions = crate::mcp::agent::get_all_functions(config);
                    crate::mcp::filter_tools_by_patterns(server_functions, server.tools())
                }
                "web" => crate::mcp::get_cached_internal_functions("web", server.tools(), || {
                    crate::mcp::web::get_all_functions()
                }),
                _ => {
                    crate::log_debug!("Unknown builtin server: {}", server.name());
                    Vec::new()
                }
            },
            McpConnectionType::Http | McpConnectionType::Stdin => {
                match crate::mcp::server::get_server_functions_cached(server).await {
                    Ok(functions) => {
                        crate::mcp::filter_tools_by_patterns(functions, server.tools())
                    }
                    Err(err) => {
                        // Best-effort: an unreachable server must not block the rest of
                        // the map, but the failure should not vanish silently either.
                        crate::log_debug!(
                            "Failed to list tools for server {}: {}",
                            server.name(),
                            err
                        );
                        Vec::new()
                    }
                }
            }
        };

        // First server (in config order) to expose a name keeps it.
        for function in server_functions {
            tool_map
                .entry(function.name)
                .or_insert_with(|| server.clone());
        }
    }

    Ok(tool_map)
}
/// Fingerprints the server list of `config` so `initialize_tool_map` can tell whether the
/// tool map needs rebuilding.
///
/// For each server, in config order, this hashes its name, connection type and `tools()`
/// filter list. The result is only meaningful within a single process: `DefaultHasher`'s
/// algorithm is not guaranteed to stay the same across Rust releases.
///
/// NOTE(review): connection details (e.g. URL/command) are not part of the fingerprint.
/// Neither is the rest of `config`, which `build_tool_server_map_internal` passes to the
/// `agent` built-in. Changing only those will not trigger a rebuild — confirm this is
/// intended.
fn calculate_config_hash(config: &Config) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    config.mcp.servers.iter().for_each(|srv| {
        srv.name().hash(&mut state);
        srv.connection_type().hash(&mut state);
        srv.tools().hash(&mut state);
    });
    state.finish()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Before any `initialize_tool_map` call, every accessor reports an empty map.
    ///
    /// NOTE(review): this reads process-global state (`TOOL_MAP`), so it would fail if
    /// another test in the same binary initialized the map first.
    #[test]
    fn test_tool_map_not_initialized() {
        assert!(get_server_for_tool("test_tool").is_none());
        assert!(get_tool_server_name("test_tool").is_none());
        assert!(!is_initialized());
        assert!(get_all_tool_names().is_empty());
    }
}