car-engine 0.15.0

Core runtime engine for Common Agent Runtime
Documentation
//! MCP (Model Context Protocol) server integration.
//!
//! Discovers tools from MCP servers via stdin/stdout JSON-RPC and registers
//! them into the canonical tool registry. MCP tools participate in the same
//! capability/permission/policy flow as all other tools.

use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::Mutex;

/// Configuration for launching an MCP server as a child process.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServerConfig {
    /// Display name for this server. Also used as the `{server}` segment of
    /// canonical tool names (`mcp_{server}_{tool}`) and as the routing key.
    pub name: String,
    /// Program executed to launch the server.
    pub command: String,
    /// Arguments passed to `command`; empty when absent from the input.
    #[serde(default)]
    pub args: Vec<String>,
    /// Additional environment variables set on the child process; empty when
    /// absent from the input.
    #[serde(default)]
    pub env: HashMap<String, String>,
    /// Working directory for the child process; when `None` no directory is
    /// set explicitly.
    pub cwd: Option<String>,
}

/// Live connection to a spawned MCP server process, exchanging newline-delimited
/// JSON-RPC messages over the child's stdin and stdout.
pub struct McpServer {
    /// Handle to the spawned server process.
    child: Child,
    /// Buffered writer for outgoing messages (the child's stdin).
    stdin: tokio::io::BufWriter<tokio::process::ChildStdin>,
    /// Buffered reader for incoming messages (the child's stdout).
    stdout: BufReader<tokio::process::ChildStdout>,
    /// Id that will be assigned to the next outgoing JSON-RPC request.
    next_id: u64,
    /// Configuration this server was started from.
    config: McpServerConfig,
}

/// Tool metadata reported by an MCP server in its `tools/list` response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolInfo {
    /// Tool name as known to the server (without any `mcp_{server}_` prefix).
    pub name: String,
    /// Optional human-readable description of the tool.
    pub description: Option<String>,
    /// Schema describing the tool's arguments, if the server provided one.
    /// Mapped to the camelCase `inputSchema` key on the wire.
    #[serde(rename = "inputSchema")]
    pub input_schema: Option<Value>,
}

/// Outgoing JSON-RPC request, written as one line to the server's stdin.
#[derive(Debug, Serialize)]
struct McpRequest {
    /// Protocol marker; this module always sets it to `"2.0"`.
    jsonrpc: &'static str,
    /// JSON-RPC method name, e.g. `tools/list` or `tools/call`.
    method: String,
    /// Method parameters; the `params` key is omitted entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    params: Option<Value>,
    /// Request id, taken from the per-server counter that increments per request.
    id: u64,
}

/// Incoming JSON-RPC response. JSON-RPC expects exactly one of `result` or
/// `error`; both are optional here so either shape deserializes.
#[derive(Debug, Deserialize)]
struct McpResponse {
    /// Successful result payload.
    result: Option<Value>,
    /// Error payload when the request failed.
    error: Option<McpError>,
    /// Echoed request id. Deserialized but never read in this module, hence
    /// the `dead_code` allowance.
    #[allow(dead_code)]
    id: Option<u64>,
}

/// JSON-RPC error object carried by a failed [`McpResponse`].
#[derive(Debug, Deserialize)]
struct McpError {
    /// Numeric JSON-RPC error code. Deserialized but never read in this module,
    /// hence the `dead_code` allowance.
    #[allow(dead_code)]
    code: Option<i64>,
    /// Human-readable error message; surfaced to callers as `MCP error: {message}`.
    message: String,
}

impl McpServer {
    /// Start an MCP server and initialize the connection.
    pub async fn start(config: McpServerConfig) -> Result<Self, String> {
        let mut cmd = Command::new(&config.command);
        cmd.args(&config.args)
            .stdin(std::process::Stdio::piped())
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::piped());

        if let Some(ref cwd) = config.cwd {
            cmd.current_dir(cwd);
        }
        for (k, v) in &config.env {
            cmd.env(k, v);
        }

        let mut child = cmd
            .spawn()
            .map_err(|e| format!("failed to start MCP server '{}': {}", config.name, e))?;

        let stdin = child
            .stdin
            .take()
            .ok_or_else(|| "MCP server has no stdin".to_string())?;
        let stdout = child
            .stdout
            .take()
            .ok_or_else(|| "MCP server has no stdout".to_string())?;

        let mut server = Self {
            config,
            child,
            stdin: tokio::io::BufWriter::new(stdin),
            stdout: BufReader::new(stdout),
            next_id: 1,
        };

        // Send initialize
        server
            .send_request(
                "initialize",
                Some(serde_json::json!({
                    "protocolVersion": "2024-11-05",
                    "capabilities": {},
                    "clientInfo": {
                        "name": "car-runtime",
                        "version": env!("CARGO_PKG_VERSION")
                    }
                })),
            )
            .await?;

        // Send initialized notification (no id, per MCP spec)
        let notification = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "notifications/initialized"
        });
        let msg =
            serde_json::to_string(&notification).map_err(|e| format!("serialize error: {e}"))?;
        server
            .stdin
            .write_all(msg.as_bytes())
            .await
            .map_err(|e| format!("write error: {e}"))?;
        server
            .stdin
            .write_all(b"\n")
            .await
            .map_err(|e| format!("write error: {e}"))?;
        server
            .stdin
            .flush()
            .await
            .map_err(|e| format!("flush error: {e}"))?;

        Ok(server)
    }

    async fn send_request(&mut self, method: &str, params: Option<Value>) -> Result<Value, String> {
        let id = self.next_id;
        self.next_id += 1;

        let req = McpRequest {
            jsonrpc: "2.0",
            method: method.to_string(),
            params,
            id,
        };

        let msg = serde_json::to_string(&req).map_err(|e| format!("serialize error: {e}"))?;

        self.stdin
            .write_all(msg.as_bytes())
            .await
            .map_err(|e| format!("write to MCP server: {e}"))?;
        self.stdin
            .write_all(b"\n")
            .await
            .map_err(|e| format!("write newline: {e}"))?;
        self.stdin
            .flush()
            .await
            .map_err(|e| format!("flush: {e}"))?;

        // Read response line
        let mut line = String::new();
        self.stdout
            .read_line(&mut line)
            .await
            .map_err(|e| format!("read from MCP server: {e}"))?;

        let resp: McpResponse = serde_json::from_str(&line)
            .map_err(|e| format!("invalid MCP response: {e} (raw: {})", line.trim()))?;

        if let Some(err) = resp.error {
            return Err(format!("MCP error: {}", err.message));
        }

        resp.result
            .ok_or_else(|| "MCP server returned no result".to_string())
    }

    /// Discover tools from this MCP server.
    pub async fn list_tools(&mut self) -> Result<Vec<McpToolInfo>, String> {
        let result = self.send_request("tools/list", None).await?;
        let tools = result
            .get("tools")
            .and_then(|t| t.as_array())
            .cloned()
            .unwrap_or_default();

        tools
            .into_iter()
            .map(|t| serde_json::from_value(t).map_err(|e| format!("invalid tool definition: {e}")))
            .collect()
    }

    /// Call a tool on this MCP server.
    pub async fn call_tool(&mut self, name: &str, arguments: Value) -> Result<Value, String> {
        let result = self
            .send_request(
                "tools/call",
                Some(serde_json::json!({
                    "name": name,
                    "arguments": arguments,
                })),
            )
            .await?;

        // Extract text content from MCP response format
        if let Some(content) = result.get("content").and_then(|c| c.as_array()) {
            let texts: Vec<&str> = content
                .iter()
                .filter_map(|block| {
                    if block.get("type").and_then(|t| t.as_str()) == Some("text") {
                        block.get("text").and_then(|t| t.as_str())
                    } else {
                        None
                    }
                })
                .collect();
            if !texts.is_empty() {
                return Ok(Value::String(texts.join("\n")));
            }
        }

        Ok(result)
    }

    /// Shut down the MCP server gracefully.
    pub async fn shutdown(mut self) {
        let _ = self.stdin.shutdown().await;
        let _ = self.child.kill().await;
        let _ = self.child.wait().await;
    }

    /// Get the server name.
    pub fn name(&self) -> &str {
        &self.config.name
    }
}

/// Tool executor that dispatches calls to connected MCP servers and defers to an
/// optional fallback executor for any tool it has no route for.
pub struct McpToolExecutor {
    /// Routing table from tool name (canonical `mcp_{server}_{tool}` or bare
    /// tool name) to the name of the server that serves it.
    tool_routes: Arc<Mutex<HashMap<String, String>>>,
    /// Connected servers keyed by server name; each connection has its own lock.
    servers: Arc<Mutex<HashMap<String, Arc<Mutex<McpServer>>>>>,
    /// Executor consulted for tools without an MCP route, when configured.
    fallback: Option<Arc<dyn super::ToolExecutor>>,
}

impl McpToolExecutor {
    pub fn new() -> Self {
        Self {
            servers: Arc::new(Mutex::new(HashMap::new())),
            tool_routes: Arc::new(Mutex::new(HashMap::new())),
            fallback: None,
        }
    }

    pub fn with_fallback(mut self, fallback: Arc<dyn super::ToolExecutor>) -> Self {
        self.fallback = Some(fallback);
        self
    }

    /// Add an MCP server and discover its tools.
    /// Returns the list of discovered tool names (canonical form: `mcp_{server}_{tool}`).
    pub async fn add_server(&self, mut server: McpServer) -> Result<Vec<String>, String> {
        let server_name = server.config.name.clone();
        let tools = server.list_tools().await?;

        let tool_names: Vec<String> = tools
            .iter()
            .map(|t| format!("mcp_{}_{}", server_name, t.name))
            .collect();

        // Register tool routes
        {
            let mut routes = self.tool_routes.lock().await;
            for (info, canonical_name) in tools.iter().zip(tool_names.iter()) {
                routes.insert(canonical_name.clone(), server_name.clone());
                // Also register the bare name for convenience
                routes.insert(info.name.clone(), server_name.clone());
            }
        }

        // Store server
        self.servers
            .lock()
            .await
            .insert(server_name, Arc::new(Mutex::new(server)));

        Ok(tool_names)
    }

    /// Get tool schemas from all connected MCP servers.
    pub async fn tool_schemas(&self) -> Vec<(String, car_ir::ToolSchema)> {
        let mut schemas = Vec::new();
        let servers = self.servers.lock().await;
        for (server_name, server) in servers.iter() {
            let mut srv = server.lock().await;
            if let Ok(tools) = srv.list_tools().await {
                for tool in tools {
                    let canonical_name = format!("mcp_{}_{}", server_name, tool.name);
                    schemas.push((
                        server_name.clone(),
                        car_ir::ToolSchema {
                            name: canonical_name,
                            description: tool.description.unwrap_or_default(),
                            parameters: tool
                                .input_schema
                                .unwrap_or(serde_json::json!({"type": "object"})),
                            returns: None,
                            idempotent: false,
                            cache_ttl_secs: None,
                            rate_limit: None,
                        },
                    ));
                }
            }
        }
        schemas
    }

    /// Shut down all MCP servers.
    pub async fn shutdown_all(&self) {
        let mut servers = self.servers.lock().await;
        // Dropping the Arc<Mutex<McpServer>> will drop the Child, killing the process.
        servers.drain();
    }
}

impl Default for McpToolExecutor {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait::async_trait]
impl super::ToolExecutor for McpToolExecutor {
    /// Execute `tool` with an empty action id.
    async fn execute(&self, tool: &str, params: &Value) -> Result<Value, String> {
        self.execute_with_action(tool, params, "").await
    }

    /// Route `tool` to the MCP server that registered it, otherwise to the fallback.
    ///
    /// Both canonical (`mcp_{server}_{tool}`) and bare tool names are accepted;
    /// the canonical prefix is stripped before calling the server. `action_id`
    /// is only forwarded to the fallback executor.
    ///
    /// # Errors
    ///
    /// Returns the MCP server's error, the fallback's error, or
    /// `unknown MCP tool` when neither can handle `tool`.
    async fn execute_with_action(
        &self,
        tool: &str,
        params: &Value,
        action_id: &str,
    ) -> Result<Value, String> {
        let route = self.tool_routes.lock().await.get(tool).cloned();

        if let Some(server_name) = route {
            // Clone the per-server handle and release the registry lock before
            // calling, so a slow tool does not block calls to other servers or
            // add_server/shutdown_all for its whole duration.
            let server = self.servers.lock().await.get(&server_name).cloned();
            if let Some(server) = server {
                let prefix = format!("mcp_{server_name}_");
                let bare_name = tool.strip_prefix(prefix.as_str()).unwrap_or(tool);
                return server.lock().await.call_tool(bare_name, params.clone()).await;
            }
        }

        if let Some(ref fallback) = self.fallback {
            return fallback.execute_with_action(tool, params, action_id).await;
        }

        Err(format!("unknown MCP tool: '{}'", tool))
    }
}