//! collet 0.1.0
//!
//! Relentless agentic coding orchestrator with zero-drop agent loops.
use std::sync::atomic::{AtomicU64, Ordering};

use anyhow::{Context, Result};
use serde_json::{Value, json};

use super::protocol::{
    JsonRpcNotification, JsonRpcRequest, MCP_PROTOCOL_VERSION, METHOD_INITIALIZE,
    METHOD_NOTIFICATIONS_INITIALIZED, METHOD_TOOLS_CALL, METHOD_TOOLS_LIST, McpServerInfo,
    McpToolDefinition, McpToolResult,
};
use super::transport::{HttpTransport, McpTransport, StdioTransport};

// ---------------------------------------------------------------------------
// McpClient
// ---------------------------------------------------------------------------

/// Client side of a single MCP server connection.
///
/// Owns one transport (stdio or HTTP) and remembers what the server told us:
/// the `initialize` handshake result and the tools returned by the most
/// recent `tools/list` call.
pub struct McpClient {
    /// Channel used to exchange JSON-RPC messages with the server.
    transport: McpTransport,
    /// Source of JSON-RPC request ids; starts at 1 and increments per request.
    id_counter: AtomicU64,
    /// Cached tool definitions from the server.
    tools: Vec<McpToolDefinition>,
    /// Server info from the initialize handshake (`None` until `initialize` succeeds).
    server_info: Option<McpServerInfo>,
}

impl McpClient {
    // -- constructors -------------------------------------------------------

    pub fn connect_stdio(command: &str, args: &[&str]) -> Result<Self> {
        let transport = StdioTransport::new(command, args)?;
        Ok(Self {
            transport: McpTransport::Stdio(Box::new(transport)),
            id_counter: AtomicU64::new(1),
            tools: Vec::new(),
            server_info: None,
        })
    }

    pub fn connect_stdio_with_env(
        command: &str,
        args: &[&str],
        env: &std::collections::HashMap<String, String>,
    ) -> Result<Self> {
        let transport = StdioTransport::with_env(command, args, env)?;
        Ok(Self {
            transport: McpTransport::Stdio(Box::new(transport)),
            id_counter: AtomicU64::new(1),
            tools: Vec::new(),
            server_info: None,
        })
    }

    pub fn connect_http(url: &str) -> Result<Self> {
        Self::connect_http_with_headers(url, &std::collections::HashMap::new())
    }

    pub fn connect_http_with_headers(
        url: &str,
        headers: &std::collections::HashMap<String, String>,
    ) -> Result<Self> {
        let transport = if headers.is_empty() {
            HttpTransport::new(url)
        } else {
            HttpTransport::with_headers(url, headers)
        };
        Ok(Self {
            transport: McpTransport::Http(transport),
            id_counter: AtomicU64::new(1),
            tools: Vec::new(),
            server_info: None,
        })
    }

    /// Return the child process PID (stdio transport only).
    pub fn pid(&self) -> Option<u32> {
        self.transport.pid()
    }

    /// Return the target URL for HTTP transports (None for stdio).
    pub fn target_url(&self) -> Option<String> {
        self.transport.target_url()
    }

    // -- helpers ------------------------------------------------------------

    fn next_id(&self) -> u64 {
        self.id_counter.fetch_add(1, Ordering::Relaxed)
    }

    // -- MCP protocol methods -----------------------------------------------

    /// Send the `initialize` handshake, then the mandatory
    /// `notifications/initialized` notification, and return server info.
    ///
    /// Per the MCP lifecycle spec, the client MUST send
    /// `notifications/initialized` after receiving the `initialize` result and
    /// before issuing any other request. Skipping it causes some servers
    /// (alcove, context7, ...) to reject `tools/list` and `tools/call`
    /// outright.
    pub async fn initialize(&mut self) -> Result<McpServerInfo> {
        let req = JsonRpcRequest::new(
            self.next_id(),
            METHOD_INITIALIZE,
            Some(json!({
                "protocolVersion": MCP_PROTOCOL_VERSION,
                "capabilities": {
                    "roots": { "listChanged": false },
                    "sampling": {}
                },
                "clientInfo": {
                    "name": "collet",
                    "version": env!("CARGO_PKG_VERSION")
                }
            })),
        );

        let resp = self
            .transport
            .send(&req)
            .await
            .context("MCP initialize failed")?;

        if let Some(err) = resp.error {
            anyhow::bail!("MCP initialize error {}: {}", err.code, err.message);
        }

        // Parse initialize result manually — `serverInfo.{name, version}` is
        // nested per spec, not a flat field at the root.
        let result = resp.result.unwrap_or_default();
        let server_info_obj = result.get("serverInfo");
        let info = McpServerInfo {
            name: server_info_obj
                .and_then(|s| s.get("name"))
                .and_then(|v| v.as_str())
                .unwrap_or_default()
                .to_string(),
            version: server_info_obj
                .and_then(|s| s.get("version"))
                .and_then(|v| v.as_str())
                .unwrap_or_default()
                .to_string(),
            protocol_version: result
                .get("protocolVersion")
                .and_then(|v| v.as_str())
                .unwrap_or_default()
                .to_string(),
            capabilities: result.get("capabilities").cloned().unwrap_or(Value::Null),
            instructions: result
                .get("instructions")
                .and_then(|v| v.as_str())
                .map(|s| s.to_string()),
        };

        // Send the mandatory `notifications/initialized`. Failure here is
        // logged but not fatal — some servers tolerate its absence and we'd
        // rather try `tools/list` than abort the whole connection.
        let note = JsonRpcNotification::new(METHOD_NOTIFICATIONS_INITIALIZED, None);
        if let Err(e) = self.transport.send_notification(&note).await {
            tracing::warn!(error = %e, "MCP notifications/initialized send failed");
        }

        self.server_info = Some(info.clone());
        Ok(info)
    }

    /// List tools exposed by the MCP server.
    pub async fn list_tools(&mut self) -> Result<Vec<McpToolDefinition>> {
        let req = JsonRpcRequest::new(self.next_id(), METHOD_TOOLS_LIST, Some(json!({})));

        let resp = self
            .transport
            .send(&req)
            .await
            .context("MCP tools/list failed")?;

        if let Some(err) = resp.error {
            anyhow::bail!("MCP tools/list error {}: {}", err.code, err.message);
        }

        let result = resp.result.unwrap_or(json!({"tools": []}));
        let tools_value = result.get("tools").cloned().unwrap_or(Value::Array(vec![]));
        let tools: Vec<McpToolDefinition> =
            serde_json::from_value(tools_value).context("Failed to parse MCP tool list")?;

        self.tools = tools.clone();
        Ok(tools)
    }

    /// Call a tool on the MCP server and return concatenated text content.
    pub async fn call_tool(&mut self, name: &str, arguments: Value) -> Result<String> {
        let req = JsonRpcRequest::new(
            self.next_id(),
            METHOD_TOOLS_CALL,
            Some(json!({
                "name": name,
                "arguments": arguments,
            })),
        );

        let resp = self
            .transport
            .send(&req)
            .await
            .with_context(|| format!("MCP tools/call failed for tool '{name}'"))?;

        if let Some(err) = resp.error {
            anyhow::bail!(
                "MCP tools/call error for '{}': {} — {}",
                name,
                err.code,
                err.message
            );
        }

        let result = resp.result.unwrap_or_default();
        let tool_result: McpToolResult =
            serde_json::from_value(result).context("Failed to parse McpToolResult")?;

        // Concatenate all text content blocks.
        let text: String = tool_result
            .content
            .iter()
            .filter(|c| c.type_ == "text")
            .map(|c| c.text.as_str())
            .collect::<Vec<_>>()
            .join("\n");

        // MCP tools signal failure via `isError` on the result, not via the
        // JSON-RPC error envelope. Surface it as an error so callers don't
        // mistake "boom" responses for successful output.
        if tool_result.is_error {
            anyhow::bail!(
                "MCP tool '{name}' reported isError=true: {}",
                if text.is_empty() {
                    "<no content>"
                } else {
                    text.as_str()
                }
            );
        }

        Ok(text)
    }

    /// Convert cached MCP tool definitions into the collet tool definition
    /// format (the `serde_json::Value` shape used by `tools/registry.rs`).
    pub fn tool_definitions(&self) -> Vec<Value> {
        self.tools
            .iter()
            .map(|t| {
                json!({
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.input_schema,
                    }
                })
            })
            .collect()
    }

    /// Return the server instructions from the initialize handshake, if any.
    pub fn server_instructions(&self) -> Option<&str> {
        self.server_info
            .as_ref()
            .and_then(|i| i.instructions.as_deref())
    }

    /// Return cached MCP tool definitions (raw).
    pub fn raw_tools(&self) -> &[McpToolDefinition] {
        &self.tools
    }

    /// Shut down the underlying transport (kills child process for stdio).
    pub async fn shutdown(&mut self) -> Result<()> {
        self.transport.shutdown().await.map_err(anyhow::Error::from)
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::protocol::McpToolDefinition;
    use serde_json::json;

    /// Shorthand for building a tool definition from its three fields.
    fn tool(name: &str, description: &str, input_schema: serde_json::Value) -> McpToolDefinition {
        McpToolDefinition {
            name: name.to_string(),
            description: description.to_string(),
            input_schema,
        }
    }

    /// Construct a client whose tool cache is already filled, so the
    /// conversion helpers can be exercised without a live server. The HTTP
    /// transport points at a placeholder URL and is never contacted.
    fn client_with_tools(tools: Vec<McpToolDefinition>) -> McpClient {
        McpClient {
            transport: McpTransport::Http(HttpTransport::new("http://unused")),
            id_counter: AtomicU64::new(1),
            tools,
            server_info: None,
        }
    }

    #[test]
    fn test_tool_definitions_basic() {
        let schema = json!({
            "type": "object",
            "properties": {
                "path": { "type": "string" }
            },
            "required": ["path"]
        });
        let client = client_with_tools(vec![tool("read_file", "Read a file", schema)]);

        let defs = client.tool_definitions();
        assert_eq!(defs.len(), 1);

        let only = &defs[0];
        let function = &only["function"];
        assert_eq!(only["type"], "function");
        assert_eq!(function["name"], "read_file");
        assert_eq!(function["description"], "Read a file");
        assert!(function["parameters"]["properties"]["path"].is_object());
    }

    #[test]
    fn test_tool_definitions_multiple() {
        let first = tool("tool_a", "First tool", json!({"type": "object"}));
        let second = tool(
            "tool_b",
            "Second tool",
            json!({"type": "object", "properties": {}}),
        );
        let client = client_with_tools(vec![first, second]);

        let names: Vec<serde_json::Value> = client
            .tool_definitions()
            .iter()
            .map(|d| d["function"]["name"].clone())
            .collect();
        assert_eq!(names.len(), 2);
        assert_eq!(names[0], "tool_a");
        assert_eq!(names[1], "tool_b");
    }

    #[test]
    fn test_tool_definitions_empty() {
        assert!(client_with_tools(Vec::new()).tool_definitions().is_empty());
    }
}