cortexai-mcp 0.1.0

Model Context Protocol (MCP) support for Cortex: stdio, SSE, and server transports
Documentation
//! Tool Proxy Handler
//!
//! Exposes each tool in a `ToolRegistry` as an individual MCP tool.
//! When the MCP client calls `tools/list`, every registered Cortex tool
//! appears. When it calls `tools/call`, execution goes through the registry.
//!
//! The tool list is **not cached**: definitions are read from the registry
//! on every `tools/list` call, so the response always reflects the
//! registry's current contents.

#[cfg(feature = "engine")]
use std::sync::Arc;

#[cfg(feature = "engine")]
use cortexai_core::tool::{ExecutionContext, ToolRegistry};
#[cfg(feature = "engine")]
use cortexai_core::types::AgentId;

#[cfg(feature = "engine")]
use crate::error::McpError;
#[cfg(feature = "engine")]
use crate::protocol::{CallToolResult, McpTool, ToolContent};

/// Proxies every tool in a `ToolRegistry` as individual MCP tools.
///
/// Definitions are read from the registry each time they are requested
/// rather than cached, so they always reflect the registry's current
/// contents.
///
/// Cloning is cheap: it only bumps the reference count of the shared
/// registry, so one handler can be handed to several tasks or transports.
#[cfg(feature = "engine")]
#[derive(Clone)]
pub struct ToolProxyHandler {
    /// Shared registry used both to list and to execute the proxied tools.
    registry: Arc<ToolRegistry>,
}

#[cfg(feature = "engine")]
impl ToolProxyHandler {
    /// Create a new proxy handler wrapping the given registry.
    pub fn new(registry: Arc<ToolRegistry>) -> Self {
        Self { registry }
    }

    /// Return MCP tool definitions for every tool in the registry.
    ///
    /// The registry is queried on every call (nothing is cached), so the
    /// result reflects its current contents. Each schema's `name`,
    /// `description` and `parameters` map to the MCP tool's `name`,
    /// `description` and `input_schema`.
    pub fn definitions(&self) -> Vec<McpTool> {
        self.registry
            .list_schemas()
            .into_iter()
            .map(|schema| McpTool {
                name: schema.name.clone(),
                description: Some(schema.description.clone()),
                input_schema: schema.parameters.clone(),
            })
            .collect()
    }

    /// Execute a tool by name through the registry.
    ///
    /// The tool runs with a fresh `ExecutionContext` whose agent id is
    /// `"mcp-proxy"`. On success, the tool's JSON output is returned as a
    /// single pretty-printed text content item with `is_error: false`.
    ///
    /// # Errors
    ///
    /// Returns [`McpError::ToolNotFound`] if no tool named `name` is
    /// registered; that is a protocol-level error.
    ///
    /// A failure *inside* the tool is not a protocol error. It is returned as
    /// `Ok` with `is_error: true` and the error message as text content,
    /// which is how the MCP specification prescribes reporting tool execution
    /// errors, so the client (and the model driving it) can see and react to
    /// the failure.
    pub async fn call(
        &self,
        name: &str,
        arguments: serde_json::Value,
    ) -> Result<CallToolResult, McpError> {
        let tool = self
            .registry
            .get(name)
            .ok_or_else(|| McpError::ToolNotFound(name.to_string()))?;

        let context = ExecutionContext::new(AgentId::new("mcp-proxy"));

        let result = match tool.execute(&context, arguments).await {
            Ok(value) => value,
            Err(e) => {
                // Report in-band (isError: true) rather than as a JSON-RPC
                // error, per the MCP tool error-handling rules.
                return Ok(CallToolResult {
                    content: vec![ToolContent::text(format!("Tool execution error: {}", e))],
                    is_error: true,
                });
            }
        };

        // Pretty-printing a `Value` should not fail; the compact form is a
        // defensive fallback.
        let text = serde_json::to_string_pretty(&result)
            .unwrap_or_else(|_| result.to_string());

        Ok(CallToolResult {
            content: vec![ToolContent::text(text)],
            is_error: false,
        })
    }
}

#[cfg(all(test, feature = "engine"))]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use async_trait::async_trait;
    use cortexai_core::errors::ToolError;
    use cortexai_core::tool::{ExecutionContext, Tool, ToolRegistry, ToolSchema};
    use serde_json::json;

    use super::*;

    /// Builds a non-dangerous, scope-free schema for the fixture tools.
    fn simple_schema(
        name: &str,
        description: &str,
        parameters: serde_json::Value,
    ) -> ToolSchema {
        ToolSchema {
            name: name.to_string(),
            description: description.to_string(),
            parameters,
            dangerous: false,
            metadata: HashMap::new(),
            required_scopes: Vec::new(),
        }
    }

    /// Sums the numeric `a` and `b` arguments; absent or non-numeric values count as 0.
    struct AddTool;

    #[async_trait]
    impl Tool for AddTool {
        fn schema(&self) -> ToolSchema {
            simple_schema(
                "add",
                "Add two numbers",
                json!({
                    "type": "object",
                    "properties": {
                        "a": {"type": "number"},
                        "b": {"type": "number"}
                    },
                    "required": ["a", "b"]
                }),
            )
        }

        async fn execute(
            &self,
            _context: &ExecutionContext,
            arguments: serde_json::Value,
        ) -> Result<serde_json::Value, ToolError> {
            let operand = |key: &str| {
                arguments
                    .get(key)
                    .and_then(serde_json::Value::as_f64)
                    .unwrap_or(0.0)
            };
            Ok(json!({"result": operand("a") + operand("b")}))
        }
    }

    /// Upper-cases the `text` argument; an absent or non-string value yields "".
    struct UppercaseTool;

    #[async_trait]
    impl Tool for UppercaseTool {
        fn schema(&self) -> ToolSchema {
            simple_schema(
                "uppercase",
                "Convert text to uppercase",
                json!({
                    "type": "object",
                    "properties": {
                        "text": {"type": "string"}
                    },
                    "required": ["text"]
                }),
            )
        }

        async fn execute(
            &self,
            _context: &ExecutionContext,
            arguments: serde_json::Value,
        ) -> Result<serde_json::Value, ToolError> {
            let shouted = arguments
                .get("text")
                .and_then(serde_json::Value::as_str)
                .unwrap_or_default()
                .to_uppercase();
            Ok(json!({"result": shouted}))
        }
    }

    /// A proxy handler over a registry holding both fixture tools.
    fn proxy() -> ToolProxyHandler {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(AddTool));
        registry.register(Arc::new(UppercaseTool));
        ToolProxyHandler::new(Arc::new(registry))
    }

    #[tokio::test]
    async fn test_tool_proxy_lists_all_registry_tools() {
        let definitions = proxy().definitions();
        assert_eq!(definitions.len(), 2);

        for expected in ["add", "uppercase"] {
            assert!(
                definitions.iter().any(|d| d.name == expected),
                "missing tool definition: {}",
                expected
            );
        }
    }

    #[tokio::test]
    async fn test_tool_proxy_executes_tool() {
        let outcome = proxy().call("add", json!({"a": 3, "b": 4})).await;
        let result = outcome.unwrap();

        assert!(!result.is_error);
        let text = result
            .content
            .first()
            .and_then(|item| item.as_text())
            .expect("expected a text content item");
        assert!(text.contains("7"));
    }

    #[tokio::test]
    async fn test_tool_proxy_returns_error_for_unknown_tool() {
        let outcome = proxy().call("nonexistent", json!({})).await;
        assert!(outcome.is_err());
    }
}