use anyhow::{anyhow, Result};
use async_trait::async_trait;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
use super::client::McpClient;
use super::config::{McpConfig, McpServerConfig};
use super::proxy::McpToolWrapper;
use crate::tools::{ToolDefinition, Tool};
use crate::approval::RiskLevel;
/// A single MCP tool whose backing server is only connected when the tool is
/// first executed.
///
/// The [`ToolDefinition`] is supplied up front, so the tool can be advertised
/// without contacting the server; `execute` triggers the connection.
pub struct LazyMcpTool {
    /// Name/description/schema advertised before the server is running.
    definition: ToolDefinition,
    /// Name of the MCP server that provides this tool.
    server_name: String,
    /// Configuration turned into a transport config when the server is started.
    server_config: McpServerConfig,
    /// Cached wrapper around the remote tool; `None` until the server is started.
    actual_tool: Arc<Mutex<Option<Arc<McpToolWrapper>>>>,
    /// Connected client, retained so the server can be shut down; `None` until started.
    client: Arc<Mutex<Option<Arc<McpClient>>>>,
}
impl LazyMcpTool {
pub fn new(
server_name: String,
server_config: McpServerConfig,
tool_name: String,
tool_description: String,
tool_input_schema: Value,
) -> Self {
let definition = ToolDefinition {
name: tool_name,
description: tool_description,
parameters: tool_input_schema,
is_priority: false,
};
Self {
definition,
server_name,
server_config,
actual_tool: Arc::new(Mutex::new(None)),
client: Arc::new(Mutex::new(None)),
}
}
async fn ensure_started(&self) -> Result<Arc<McpToolWrapper>> {
{
let tool_lock = self.actual_tool.lock().await;
if let Some(tool) = tool_lock.as_ref() {
return Ok(tool.clone());
}
}
tracing::info!("Starting lazy MCP server '{}' for tool '{}'",
self.server_name, self.definition.name);
let transport_config = self.server_config.to_transport_config()?;
let client = Arc::new(McpClient::connect(&self.server_name, transport_config).await?);
let mcp_tools = client.list_tools().await?;
let tool_name_without_prefix = self.definition.name.clone();
let mcp_tool = mcp_tools.into_iter()
.find(|t| {
let name = t.name.clone();
name == tool_name_without_prefix ||
name == format!("{}_{}", self.server_name, tool_name_without_prefix)
})
.ok_or_else(|| anyhow!(
"Tool '{}' not found in MCP server '{}'",
self.definition.name, self.server_name
))?;
let wrapper = Arc::new(McpToolWrapper::new(
client.clone(),
mcp_tool,
self.server_name.clone()
));
{
let mut tool_lock = self.actual_tool.lock().await;
*tool_lock = Some(wrapper.clone());
}
{
let mut client_lock = self.client.lock().await;
*client_lock = Some(client);
}
tracing::info!("Lazy MCP server '{}' started successfully", self.server_name);
Ok(wrapper)
}
pub async fn shutdown(&self) -> Result<()> {
let client_lock = self.client.lock().await;
if let Some(client) = client_lock.as_ref() {
client.shutdown().await?;
tracing::info!("Lazy MCP server '{}' stopped", self.server_name);
}
Ok(())
}
}
#[async_trait]
impl Tool for LazyMcpTool {
fn definition(&self) -> ToolDefinition {
self.definition.clone()
}
async fn execute(&self, params: Value) -> Result<String> {
let tool = self.ensure_started().await?;
tool.execute(params).await
}
fn risk_level(&self) -> RiskLevel {
RiskLevel::Mutating
}
}
/// Holds an [`McpConfig`] for discovering [`LazyMcpTool`]s.
///
/// Discovery is not implemented yet: `discover_tools` returns an empty list.
pub struct LazyMcpLoader {
    config: McpConfig,
    // No code in this module reads this field; it is always initialized empty
    // in `from_config`. Hence the `dead_code` allowance.
    #[allow(dead_code)]
    tools: Vec<LazyMcpTool>,
}
impl LazyMcpLoader {
    /// Wraps `config` in a loader that has not registered any tools yet.
    pub fn from_config(config: McpConfig) -> Self {
        Self {
            config,
            tools: vec![],
        }
    }

    /// Placeholder for tool discovery.
    ///
    /// Currently always yields an empty list and never contacts a server.
    pub async fn discover_tools(&self) -> Result<Vec<LazyMcpTool>> {
        let discovered: Vec<LazyMcpTool> = vec![];
        Ok(discovered)
    }

    /// Borrows the configuration this loader was created from.
    pub fn config(&self) -> &McpConfig {
        let config = &self.config;
        config
    }
}
/// On-demand handle for one configured MCP server.
///
/// Nothing is connected at construction time. `start` connects, lists the
/// server's tools and caches them together with the client.
pub struct McpToolPlaceholder {
    /// Server name (also its key in `McpToolRegistry`).
    server_name: String,
    /// Configuration turned into a transport config when the server is started.
    server_config: McpServerConfig,
    /// Wrapped tools; `Some` once the server has been started.
    tools: Arc<RwLock<Option<Vec<Arc<McpToolWrapper>>>>>,
    /// Connected client, retained for `shutdown`; `None` until started.
    client: Arc<RwLock<Option<Arc<McpClient>>>>,
}
impl McpToolPlaceholder {
pub fn new(server_name: String, server_config: McpServerConfig) -> Self {
Self {
server_name,
server_config,
tools: Arc::new(RwLock::new(None)),
client: Arc::new(RwLock::new(None)),
}
}
pub async fn start(&self) -> Result<Vec<Arc<McpToolWrapper>>> {
{
let tools_lock = self.tools.read().await;
if let Some(tools) = tools_lock.as_ref() {
return Ok(tools.clone());
}
}
tracing::info!("Starting MCP server '{}' on demand", self.server_name);
let transport_config = self.server_config.to_transport_config()?;
let client = Arc::new(McpClient::connect(&self.server_name, transport_config).await?);
if !client.supports_tools().await {
return Err(anyhow!("MCP server '{}' does not support tools", self.server_name));
}
let mcp_tools = client.list_tools().await?;
let tools: Vec<Arc<McpToolWrapper>> = mcp_tools.into_iter()
.map(|tool| Arc::new(McpToolWrapper::new(
client.clone(),
tool,
self.server_name.clone()
)))
.collect();
{
let mut tools_lock = self.tools.write().await;
*tools_lock = Some(tools.clone());
}
{
let mut client_lock = self.client.write().await;
*client_lock = Some(client);
}
tracing::info!("MCP server '{}' started with {} tools", self.server_name, tools.len());
Ok(tools)
}
pub async fn shutdown(&self) -> Result<()> {
let client_lock = self.client.read().await;
if let Some(client) = client_lock.as_ref() {
client.shutdown().await?;
}
Ok(())
}
pub async fn is_started(&self) -> bool {
self.tools.read().await.is_some()
}
pub fn server_name(&self) -> &str {
&self.server_name
}
}
/// Set of MCP servers keyed by server name, each wrapped in a placeholder that
/// is only started on demand.
pub struct McpToolRegistry {
    /// Server name -> shared on-demand handle for that server.
    placeholders: HashMap<String, Arc<McpToolPlaceholder>>,
}
impl McpToolRegistry {
pub fn from_config(config: &McpConfig) -> Self {
let placeholders = config.servers
.iter()
.filter(|(_, cfg)| cfg.enabled)
.map(|(name, cfg)| {
(name.clone(), Arc::new(McpToolPlaceholder::new(
name.clone(),
cfg.clone()
)))
})
.collect();
Self { placeholders }
}
pub fn new() -> Self {
Self {
placeholders: HashMap::new(),
}
}
pub fn add_server(&mut self, name: String, config: McpServerConfig) {
self.placeholders.insert(
name.clone(),
Arc::new(McpToolPlaceholder::new(name, config))
);
}
pub async fn remove_server(&mut self, name: &str) -> Result<()> {
if let Some(placeholder) = self.placeholders.remove(name) {
placeholder.shutdown().await?;
}
Ok(())
}
pub fn get_server(&self, name: &str) -> Option<Arc<McpToolPlaceholder>> {
self.placeholders.get(name).cloned()
}
pub fn server_names(&self) -> Vec<&String> {
self.placeholders.keys().collect()
}
pub async fn started_servers(&self) -> Vec<String> {
let mut started = Vec::new();
for (name, placeholder) in &self.placeholders {
if placeholder.is_started().await {
started.push(name.clone());
}
}
started
}
pub async fn server_status(&self) -> HashMap<String, ServerStatus> {
let mut status = HashMap::new();
for (name, placeholder) in &self.placeholders {
let is_started = placeholder.is_started().await;
status.insert(name.clone(), ServerStatus {
name: name.clone(),
is_started,
tool_count: if is_started {
placeholder.tools.read().await.as_ref().map(|t| t.len()).unwrap_or(0)
} else {
0
},
});
}
status
}
pub async fn start_all(&self) -> Result<HashMap<String, Vec<Arc<McpToolWrapper>>>> {
let mut results = HashMap::new();
for (name, placeholder) in &self.placeholders {
let tools = placeholder.start().await?;
results.insert(name.clone(), tools);
}
Ok(results)
}
pub async fn shutdown_all(&self) {
for placeholder in self.placeholders.values() {
if let Err(e) = placeholder.shutdown().await {
tracing::error!("Failed to shutdown MCP server '{}': {}",
placeholder.server_name(), e);
}
}
}
pub fn add_from_cli_arg(&mut self, arg: &str) -> Result<()> {
let (name, command, args) = parse_cli_mcp_arg(arg)?;
let config = McpServerConfig {
command: Some(command),
args,
url: None,
enabled: true,
..Default::default()
};
self.add_server(name, config);
Ok(())
}
pub fn add_from_cli_args(&mut self, args: &[String]) -> Result<()> {
for arg in args {
self.add_from_cli_arg(arg)?;
}
Ok(())
}
}
impl Default for McpToolRegistry {
fn default() -> Self {
Self::new()
}
}
/// Snapshot of one server's state, as produced by `McpToolRegistry::server_status`.
#[derive(Debug, Clone)]
pub struct ServerStatus {
    /// Server name (the same string as the map key it is stored under).
    pub name: String,
    /// Whether the server has been started and its tools cached.
    pub is_started: bool,
    /// Number of tools the server exposed; `0` when not started.
    pub tool_count: usize,
}
/// Parses a command-line MCP server spec into `(name, command, args)`.
///
/// Accepted forms:
/// - `name:command arg1 arg2` sets an explicit server name before the first `:`.
/// - `command arg1 arg2` derives the name from the command's file name
///   (see `default_server_name`).
///
/// The text before the first `:` counts as a name only when it looks like one:
/// ASCII letters, digits, `-`, `_` or `.`, with no whitespace. This keeps
/// unnamed commands whose arguments contain colons (e.g.
/// `docker run -p 8080:80 img`, `server --url http://host`) from being split
/// in the wrong place. A single letter immediately followed by `:\` or `:/` is
/// read as a Windows drive (`C:\bin\server.exe`), not as a name.
///
/// The command part is split with shell-style quoting via `shell_words::split`.
///
/// # Errors
/// Returns an error for:
/// - an empty argument;
/// - an empty name (`":cmd"`);
/// - a command line `shell_words` cannot parse (e.g. unbalanced quotes);
/// - a missing or empty command.
fn parse_cli_mcp_arg(arg: &str) -> Result<(String, String, Vec<String>)> {
    let arg = arg.trim();
    if arg.is_empty() {
        return Err(anyhow!("Empty MCP server argument"));
    }

    let (explicit_name, command_line) = match arg.split_once(':') {
        Some((name, _)) if name.trim().is_empty() => {
            return Err(anyhow!("Empty MCP server name in '{}'", arg));
        }
        Some((name, rest)) if looks_like_server_name(name.trim(), rest) => {
            (Some(name.trim()), rest.trim())
        }
        // No colon, or the colon belongs to the command itself.
        _ => (None, arg),
    };

    let mut parts = shell_words::split(command_line)
        .map_err(|e| anyhow!("Failed to parse command: {}", e))?
        .into_iter();
    let command = match parts.next() {
        Some(command) if !command.is_empty() => command,
        _ => {
            return Err(match explicit_name {
                Some(name) => anyhow!("Missing command for MCP server '{}'", name),
                None => anyhow!("Empty MCP server argument"),
            });
        }
    };
    let args: Vec<String> = parts.collect();
    let name = match explicit_name {
        Some(name) => name.to_string(),
        None => default_server_name(&command),
    };
    Ok((name, command, args))
}

/// Returns `true` when `candidate` (the trimmed text before the first `:`)
/// should be read as an explicit server name. `rest` is the raw text after the `:`.
fn looks_like_server_name(candidate: &str, rest: &str) -> bool {
    let identifier = !candidate.is_empty()
        && candidate
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '.'));
    // `C:\bin\server.exe` / `C:/bin/server.exe` is a Windows drive, not `name:command`.
    let drive_letter = candidate.len() == 1
        && candidate.chars().all(|c| c.is_ascii_alphabetic())
        && rest.starts_with(|c: char| c == '\\' || c == '/');
    identifier && !drive_letter
}

/// Derives a server name from `command` when none was given.
///
/// Uses the file name without directories or a trailing extension, e.g.
/// `/opt/mcp/server.py` -> `server` and `npx` -> `npx`. Falls back to the full
/// command if that would be empty.
fn default_server_name(command: &str) -> String {
    let file_name = command
        .rsplit(|c: char| c == '/' || c == '\\')
        .next()
        .unwrap_or(command);
    let stem = match file_name.rfind('.') {
        Some(dot) if dot > 0 => &file_name[..dot],
        _ => file_name,
    };
    if stem.is_empty() {
        command.to_string()
    } else {
        stem.to_string()
    }
}