collet 0.1.1

Relentless agentic coding orchestrator with zero-drop agent loops
Documentation
use std::collections::HashMap;
use std::sync::Arc;

use serde_json::Value;
use tokio::sync::Mutex;
use tracing::{info, warn};

use super::client::McpClient;
use super::config::{self, ResolvedMcpServer};
use super::protocol::McpServerInfo;

/// When total MCP tool count exceeds this threshold, switch to deferred
/// (lazy) loading — tool schemas are omitted from the LLM `tools` array
/// and discovered on-demand via the `tool_search` tool.
///
/// The comparison is strict: a total of exactly `MCP_EAGER_THRESHOLD` tools
/// is still loaded eagerly; deferred mode starts at `MCP_EAGER_THRESHOLD + 1`.
pub const MCP_EAGER_THRESHOLD: usize = 10;

/// Dispatch target for one prefixed MCP tool name.
///
/// Stored as the value in `McpManager::routes`, keyed by the prefixed name
/// `mcp__{server}__{tool}`, so that an LLM tool call can be forwarded to the
/// owning server's client under the tool's un-prefixed name.
#[derive(Debug, Clone)]
struct ToolRoute {
    /// Name of the configured server that owns the tool; used as the key
    /// into `McpManager::clients` when dispatching.
    server_name: String,
    /// Original tool name on the MCP server (without prefix).
    original_name: String,
}

/// Per-server metadata captured during connection.
///
/// Rendered into the system prompt by `McpManager::server_overview` and
/// exposed read-only via `McpManager::server_meta` for tool index building.
#[derive(Debug, Clone)]
pub struct ServerMeta {
    /// Human-readable description (from `McpServerInfo.name`).
    ///
    /// This is the name the server reports from its `initialize` response;
    /// the overview prints it under the server heading when non-empty.
    pub description: String,
    /// Server-level instructions describing purpose and tool usage.
    ///
    /// Copied from `McpServerInfo.instructions`; `None` when the server sent
    /// none. Appended verbatim to the overview.
    pub instructions: Option<String>,
    /// Number of tools exposed by this server (shown in the overview heading).
    pub tool_count: usize,
    /// Prefixed tool names (`mcp__{server}__{tool}`), in the order the
    /// server's tool definitions were registered.
    pub tool_names: Vec<String>,
}

/// Manages connections to all configured MCP servers.
///
/// Holds one client per successfully connected server, the routing table
/// used to dispatch prefixed tool calls, the prefixed tool schemas handed to
/// the LLM, and per-server metadata for the system-prompt overview.
///
/// Each `McpClient` is behind `Arc<Mutex>` because stdio transport
/// requires `&mut self` for send (stdin write + stdout read).
pub struct McpManager {
    /// Server name → connected client. The async mutex is held for the
    /// whole duration of a tool call, so calls to one server run one at a time.
    clients: HashMap<String, Arc<Mutex<McpClient>>>,
    /// tool_name (prefixed `mcp__{server}__{tool}`) → route info.
    routes: HashMap<String, ToolRoute>,
    /// Cached tool definitions in OpenAI function-calling format.
    ///
    /// Each entry's `function.name` has already been rewritten to the
    /// prefixed name, so these can be passed to the LLM directly.
    tool_defs: Vec<Value>,
    /// Per-server metadata for overview generation.
    server_meta: HashMap<String, ServerMeta>,
}

impl McpManager {
    /// Create an empty manager with no servers (used by subagents).
    pub fn empty() -> Self {
        Self {
            clients: HashMap::new(),
            routes: HashMap::new(),
            tool_defs: Vec::new(),
            server_meta: HashMap::new(),
        }
    }

    /// Connect to all enabled MCP servers discovered from config.
    ///
    /// Connection failures are logged but do not prevent other servers
    /// from initializing. Returns the manager even if all servers fail.
    pub async fn connect_all(working_dir: &str) -> Self {
        let servers = config::load_mcp_servers(working_dir).unwrap_or_default();

        // Connect all servers in parallel for faster startup.
        let handles: Vec<_> = servers
            .into_iter()
            .map(|server| {
                tokio::spawn(async move {
                    let result = connect_server(&server).await;
                    (
                        server.name.clone(),
                        server.source.clone(),
                        server.description.clone(),
                        result,
                    )
                })
            })
            .collect();

        let mut clients = HashMap::new();
        let mut routes = HashMap::new();
        let mut tool_defs = Vec::new();
        let mut server_meta = HashMap::new();

        for handle in handles {
            let (server_name, server_source, server_description, result) = match handle.await {
                Ok(r) => r,
                Err(e) => {
                    warn!(error = %e, "MCP server task panicked");
                    continue;
                }
            };

            match result {
                Ok((client, tools, mcp_info)) => {
                    info!(
                        server = %server_name,
                        tool_count = tools.len(),
                        source = ?server_source,
                        description = ?server_description,
                        "MCP server connected"
                    );

                    let mut tool_names = Vec::new();

                    // Register tool definitions with prefixed names.
                    for def in &tools {
                        let original_name =
                            def["function"]["name"].as_str().unwrap_or("").to_string();
                        let prefixed = format!("mcp__{}__{}", server_name, original_name);

                        routes.insert(
                            prefixed.clone(),
                            ToolRoute {
                                server_name: server_name.clone(),
                                original_name: original_name.clone(),
                            },
                        );

                        tool_names.push(prefixed.clone());

                        // Clone the definition with the prefixed name.
                        let mut patched = def.clone();
                        if let Some(func) = patched.get_mut("function") {
                            func["name"] = Value::String(prefixed);
                        }
                        tool_defs.push(patched);
                    }

                    server_meta.insert(
                        server_name.clone(),
                        ServerMeta {
                            description: mcp_info.name.clone(),
                            instructions: mcp_info.instructions.clone(),
                            tool_count: tools.len(),
                            tool_names,
                        },
                    );

                    clients.insert(server_name, Arc::new(Mutex::new(client)));
                }
                Err(e) => {
                    warn!(server = %server_name, error = %e, "Failed to connect MCP server");
                }
            }
        }

        Self {
            clients,
            routes,
            tool_defs,
            server_meta,
        }
    }

    /// All tool definitions (for eager mode).
    pub fn tool_definitions(&self) -> &[Value] {
        &self.tool_defs
    }

    /// Tool definitions for the LLM `tools` array, respecting the eager threshold.
    ///
    /// If total MCP tool count ≤ `MCP_EAGER_THRESHOLD`, returns all definitions.
    /// Otherwise returns an empty Vec (deferred mode — tools discovered via `tool_search`).
    pub fn eager_tool_definitions(&self) -> Vec<Value> {
        if self.is_deferred_mode() {
            Vec::new()
        } else {
            self.tool_defs.clone()
        }
    }

    /// Whether deferred (lazy) loading is active.
    pub fn is_deferred_mode(&self) -> bool {
        self.total_tool_count() > MCP_EAGER_THRESHOLD
    }

    /// Total number of MCP tools across all servers.
    pub fn total_tool_count(&self) -> usize {
        self.tool_defs.len()
    }

    /// Returns true if the given tool name is an MCP tool.
    pub fn is_mcp_tool(&self, name: &str) -> bool {
        self.routes.contains_key(name)
    }

    /// Search tools by keyword query. Matches against prefixed name and description.
    /// Returns full tool schemas for matching tools.
    pub fn search_tools(&self, query: &str) -> Vec<Value> {
        let query_lower = query.to_lowercase();
        let keywords: Vec<&str> = query_lower.split_whitespace().collect();

        self.tool_defs
            .iter()
            .filter(|def| {
                let name = def["function"]["name"]
                    .as_str()
                    .unwrap_or("")
                    .to_lowercase();
                let desc = def["function"]["description"]
                    .as_str()
                    .unwrap_or("")
                    .to_lowercase();
                let haystack = format!("{} {}", name, desc);
                keywords.iter().all(|kw| haystack.contains(kw))
            })
            .cloned()
            .collect()
    }

    /// Generate an overview of all connected MCP servers for the system prompt.
    pub fn server_overview(&self) -> String {
        if self.server_meta.is_empty() {
            return String::new();
        }

        let mut out = String::from("## MCP Servers\n\n");
        for (name, meta) in &self.server_meta {
            out.push_str(&format!("### {name} ({} tools)\n", meta.tool_count,));
            if !meta.description.is_empty() {
                out.push_str(&format!("{}\n", meta.description));
            }
            if let Some(ref instructions) = meta.instructions {
                out.push_str(&format!("\n{instructions}\n"));
            }
            out.push('\n');
        }

        // In deferred mode, list all tool names so the LLM knows what's available.
        if self.is_deferred_mode() {
            out.push_str("<available-deferred-tools>\n");
            for (server_name, meta) in &self.server_meta {
                for tool_name in &meta.tool_names {
                    // Include description if available
                    let desc = self.tool_defs.iter().find_map(|def| {
                        let n = def["function"]["name"].as_str()?;
                        if n == tool_name {
                            def["function"]["description"].as_str()
                        } else {
                            None
                        }
                    });
                    if let Some(d) = desc {
                        out.push_str(&format!("- {tool_name}: {d}\n"));
                    } else {
                        out.push_str(&format!("- {tool_name} (server: {server_name})\n"));
                    }
                }
            }
            out.push_str("</available-deferred-tools>\n\n");
            out.push_str(
                "Use the `tool_search` tool to fetch full schemas for deferred tools before calling them.\n",
            );
        }

        out
    }

    /// Access per-server metadata (for tool index building).
    pub fn server_meta(&self) -> &HashMap<String, ServerMeta> {
        &self.server_meta
    }

    /// Call an MCP tool by its prefixed name.
    pub async fn call_tool(
        &self,
        prefixed_name: &str,
        arguments: &str,
    ) -> crate::common::Result<String> {
        let route = self.routes.get(prefixed_name).ok_or_else(|| {
            crate::common::AgentError::InvalidArgument(format!(
                "Unknown MCP tool: {}",
                prefixed_name
            ))
        })?;

        let client_lock = self.clients.get(&route.server_name).ok_or_else(|| {
            crate::common::AgentError::Internal(format!(
                "MCP server '{}' not connected",
                route.server_name
            ))
        })?;

        let args: Value = if arguments.trim().is_empty() {
            Value::Object(Default::default())
        } else {
            match serde_json::from_str(arguments) {
                Ok(v) => v,
                Err(e) => {
                    warn!(
                        tool = %prefixed_name,
                        arguments = %arguments,
                        error = %e,
                        "Failed to parse MCP tool arguments — sending empty object"
                    );
                    Value::Object(Default::default())
                }
            }
        };

        let mut client = client_lock.lock().await;
        client
            .call_tool(&route.original_name, args)
            .await
            .map_err(|e| crate::common::AgentError::Internal(format!("MCP call failed: {e}")))
    }

    /// Shut down all connected MCP servers.
    pub async fn shutdown_all(&self) {
        for (name, client_lock) in &self.clients {
            let mut client = client_lock.lock().await;
            if let Err(e) = client.shutdown().await {
                warn!(server = %name, error = %e, "MCP server shutdown error");
            }
        }
    }

    /// Collect PIDs of all stdio MCP server child processes.
    pub async fn child_pids(&self) -> Vec<u32> {
        let mut pids = Vec::new();
        for client in self.clients.values() {
            if let Some(pid) = client.lock().await.pid() {
                pids.push(pid);
            }
        }
        pids
    }

    /// Number of connected servers.
    pub fn server_count(&self) -> usize {
        self.clients.len()
    }
}

/// Bring up a single MCP server: open its transport, run the `initialize`
/// handshake, then fetch its tool list.
///
/// When `command` is set the stdio transport is used, even if `url` is also
/// set. Otherwise `url` selects the HTTP transport. A server with neither is
/// rejected. The handshake is bounded by a 30 s timeout and tool listing by
/// 15 s.
///
/// On success, yields three things:
/// - the live client,
/// - the tool definitions from `McpClient::tool_definitions`,
/// - the server info returned by `initialize`.
///
/// NOTE(review): on a handshake or listing failure or timeout, the client is
/// simply dropped. Whether that terminates a spawned stdio child depends on
/// `McpClient`'s drop behavior. Confirm.
async fn connect_server(
    server: &ResolvedMcpServer,
) -> anyhow::Result<(McpClient, Vec<Value>, McpServerInfo)> {
    const INIT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
    const LIST_TOOLS_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(15);

    let mut client = match (&server.command, &server.url) {
        (Some(program), _) => {
            let argv: Vec<&str> = server.args.iter().map(|arg| arg.as_str()).collect();
            if !server.env.is_empty() {
                McpClient::connect_stdio_with_env(program, &argv, &server.env)?
            } else {
                McpClient::connect_stdio(program, &argv)?
            }
        }
        (None, Some(endpoint)) => {
            if !server.headers.is_empty() {
                McpClient::connect_http_with_headers(endpoint, &server.headers)?
            } else {
                McpClient::connect_http(endpoint)?
            }
        }
        (None, None) => {
            anyhow::bail!("MCP server '{}' has neither command nor url", server.name);
        }
    };

    // Only HTTP clients report a target; stdio clients yield nothing here.
    if let Some(target) = client.target_url() {
        tracing::debug!(server = %server.name, url = %target, "MCP HTTP transport target");
    }

    let info = match tokio::time::timeout(INIT_TIMEOUT, client.initialize()).await {
        Ok(handshake) => handshake?,
        Err(_elapsed) => {
            anyhow::bail!("MCP initialize timed out for '{}'", server.name);
        }
    };

    match tokio::time::timeout(LIST_TOOLS_TIMEOUT, client.list_tools()).await {
        Ok(listing) => {
            listing?;
        }
        Err(_elapsed) => {
            anyhow::bail!("MCP list_tools timed out for '{}'", server.name);
        }
    }

    let definitions = client.tool_definitions();
    let raw_tool_count = client.raw_tools().len();
    if let Some(text) = client.server_instructions() {
        tracing::debug!(
            server = %server.name,
            raw_tools = raw_tool_count,
            instructions_len = text.len(),
            "MCP server has system instructions"
        );
    }
    Ok((client, definitions, info))
}