mod context;
mod conversions;
use std::collections::HashMap;
use async_trait::async_trait;
use tracing::{info, warn};
use systemprompt_database::DbPool;
use systemprompt_traits::{
ToolCallRequest, ToolCallResult, ToolContext, ToolDefinition, ToolProvider, ToolProviderError,
ToolProviderResult,
};
use crate::services::client::{
McpClient, rewrite_url_for_internal_use, validate_connection, validate_connection_by_url,
};
use crate::services::registry::RegistryManager;
use context::{create_request_context, load_agent_servers};
use conversions::{to_tool_definition, to_tool_result};
#[derive(Debug, Clone)]
pub struct McpToolProvider {
db_pool: DbPool,
}
impl McpToolProvider {
pub const fn new(db_pool: DbPool) -> Self {
Self { db_pool }
}
pub const fn db_pool(&self) -> &DbPool {
&self.db_pool
}
}
#[async_trait]
impl ToolProvider for McpToolProvider {
async fn list_tools(
&self,
agent_name: &str,
context: &ToolContext,
) -> ToolProviderResult<Vec<ToolDefinition>> {
let assigned_servers = load_agent_servers(agent_name).map_err(|e| {
ToolProviderError::ConfigurationError(format!("Failed to load agent config: {e}"))
})?;
info!(
agent = agent_name,
servers = %assigned_servers.join(", "),
"Listing tools for agent from MCP servers"
);
let request_ctx = create_request_context(context)?;
let mut all_tools = Vec::new();
for server_name in &assigned_servers {
match McpClient::list_tools(server_name, &request_ctx).await {
Ok(tools) => {
info!(
server = server_name,
tool_count = tools.len(),
"Loaded tools from MCP server"
);
for tool in tools {
all_tools.push(to_tool_definition(&tool));
}
},
Err(e) => {
warn!(
server = server_name,
error = %e,
"Failed to list tools from MCP server"
);
},
}
}
info!(
agent = agent_name,
total_tools = all_tools.len(),
"Total tools loaded for agent"
);
Ok(all_tools)
}
async fn call_tool(
&self,
request: &ToolCallRequest,
service_id: &str,
context: &ToolContext,
) -> ToolProviderResult<ToolCallResult> {
let request_ctx = create_request_context(context)?;
info!(
tool = &request.name,
service = service_id,
"Executing tool via MCP"
);
let result = McpClient::call_tool(
service_id,
request.name.clone(),
Some(request.arguments.clone()),
&request_ctx,
&self.db_pool,
)
.await
.map_err(|e| ToolProviderError::ExecutionFailed(e.to_string()))?;
Ok(to_tool_result(&result))
}
async fn refresh_connections(&self, agent_name: &str) -> ToolProviderResult<()> {
let assigned_servers = load_agent_servers(agent_name).map_err(|e| {
ToolProviderError::ConfigurationError(format!("Failed to load agent config: {e}"))
})?;
info!(
agent = agent_name,
servers = %assigned_servers.join(", "),
"Refreshing MCP connections for agent"
);
RegistryManager::validate().map_err(|e| {
ToolProviderError::Internal(format!("Failed to validate registry: {e}"))
})?;
let api_server_url = systemprompt_models::Config::get()
.map_err(|e| ToolProviderError::Internal(format!("Failed to get configuration: {e}")))?
.api_server_url
.clone();
for server_name in assigned_servers {
validate_server_connection(&server_name, &api_server_url).await;
}
Ok(())
}
async fn health_check(&self) -> ToolProviderResult<HashMap<String, bool>> {
let mut health_status = HashMap::new();
let config_api_server_url = systemprompt_models::Config::get()
.map_err(|e| ToolProviderError::Internal(format!("Failed to get configuration: {e}")))?
.api_server_url
.clone();
if let Ok(servers) = RegistryManager::get_enabled_servers() {
for server in servers {
let is_healthy =
check_server_health(&server.name, server.port, &config_api_server_url).await;
health_status.insert(server.name, is_healthy);
}
}
Ok(health_status)
}
}
async fn validate_server_connection(server_name: &str, api_server_url: &str) {
if let Ok(Some(server_config)) = RegistryManager::find_server(server_name) {
let host = &server_config.host;
let port = server_config.port;
let result = if port == 0 {
let url = server_config.endpoint(api_server_url);
let url = rewrite_url_for_internal_use(&url);
validate_connection_by_url(server_name, &url).await
} else {
validate_connection(server_name, host, port).await
};
match result {
Ok(result) if result.success => {
info!(server = server_name, "MCP server connection validated");
},
Ok(result) => {
warn!(
server = server_name,
error = result.error_message.as_deref().unwrap_or("[no error]"),
"MCP server connection validation failed"
);
},
Err(e) => {
warn!(
server = server_name,
error = %e,
"Failed to validate MCP server connection"
);
},
}
}
}
async fn check_server_health(server_name: &str, server_port: u16, api_server_url: &str) -> bool {
let url = format!("{}/api/v1/mcp/{}/mcp", api_server_url, server_name);
let Ok(parsed_url) = url::Url::parse(&url) else {
return false;
};
let host = parsed_url.host_str().unwrap_or("127.0.0.1");
let actual_port = if server_port > 0 {
server_port
} else {
parsed_url.port().unwrap_or(80)
};
validate_connection(server_name, host, actual_port)
.await
.is_ok_and(|r| r.success)
}