lash-plugin-mcp 0.1.0-alpha.42

MCP (Model Context Protocol) plugin for the lash agent runtime.
Documentation
//! [`PluginFactory`] for MCP integration. Holds a shared connection pool
//! across every session built from the same `LashCore`.

use std::collections::BTreeMap;
use std::sync::Arc;

use async_trait::async_trait;
use lash_core::plugin::{
    PluginError, PluginFactory, PluginRegistrar, PluginSessionContext, SessionPlugin,
};
use lash_core::{ToolCall, ToolContract, ToolManifest, ToolProvider, ToolResult};

use crate::config::McpServerConfig;
use crate::error::McpError;
use crate::pool::McpConnectionPool;

/// Plugin factory for MCP. Add once to `LashCoreBuilder` via
/// `.plugin(Arc::new(factory))`.
pub struct McpPluginFactory {
    pool: Arc<McpConnectionPool>,
}

impl McpPluginFactory {
    /// Connect to every configured server eagerly and return a factory whose
    /// pool is ready to use. The pool is `Arc`-shared across sessions; cloning
    /// the factory and adding it to multiple `LashCore`s shares the same
    /// connections.
    pub async fn new(servers: BTreeMap<String, McpServerConfig>) -> Result<Self, McpError> {
        let pool = McpConnectionPool::connect(servers).await?;
        Ok(Self { pool })
    }

    /// Empty pool — useful when servers are added at runtime via
    /// [`McpPluginFactory::attach_server`].
    pub fn empty() -> Self {
        Self {
            pool: Arc::new(McpConnectionPool::empty()),
        }
    }

    /// Direct access to the underlying pool, in case the embedder wants to
    /// inspect or mutate it directly.
    pub fn pool(&self) -> &Arc<McpConnectionPool> {
        &self.pool
    }

    /// Attach a new server at runtime. The new tools become visible to any
    /// session created after this call returns; existing sessions will see
    /// the new tools after their next tool-surface refresh.
    pub async fn attach_server(
        &self,
        server_name: String,
        config: McpServerConfig,
    ) -> Result<(), McpError> {
        self.pool.attach(server_name, config).await
    }

    /// Detach a server at runtime.
    pub async fn detach_server(&self, server_name: &str) -> Result<(), McpError> {
        self.pool.detach(server_name).await
    }
}

impl PluginFactory for McpPluginFactory {
    fn id(&self) -> &'static str {
        "mcp"
    }

    fn build(&self, _ctx: &PluginSessionContext) -> Result<Arc<dyn SessionPlugin>, PluginError> {
        Ok(Arc::new(McpSessionPlugin {
            pool: Arc::clone(&self.pool),
        }))
    }
}

struct McpSessionPlugin {
    pool: Arc<McpConnectionPool>,
}

impl SessionPlugin for McpSessionPlugin {
    fn id(&self) -> &'static str {
        "mcp"
    }

    fn register(&self, reg: &mut PluginRegistrar) -> Result<(), PluginError> {
        reg.tools().provider(Arc::new(McpToolProvider {
            pool: Arc::clone(&self.pool),
        }) as Arc<dyn ToolProvider>)
    }
}

/// The `ToolProvider` actually registered with each session's tool surface.
/// Delegates definitions and execution to the shared pool.
pub struct McpToolProvider {
    pool: Arc<McpConnectionPool>,
}

impl McpToolProvider {
    pub fn new(pool: Arc<McpConnectionPool>) -> Self {
        Self { pool }
    }
}

#[async_trait]
impl ToolProvider for McpToolProvider {
    fn tool_manifests(&self) -> Vec<ToolManifest> {
        self.pool
            .advertised_tools_blocking()
            .into_iter()
            .map(|tool| tool.manifest())
            .collect()
    }

    fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>> {
        self.pool
            .advertised_tools_blocking()
            .into_iter()
            .find(|tool| tool.name() == name)
            .map(|tool| Arc::new(tool.contract()))
    }

    async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
        self.pool
            .call_tool(call.name, call.args, call.context)
            .await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use lash_core::ToolDefinition;
    use serde_json::{Value, json};
    use std::collections::BTreeMap;

    /// Pure unit test ported from `crates/lash/src/mcp.rs`. Verifies that a
    /// `ToolDefinition::raw` constructed from an MCP-advertised input schema
    /// keeps the schema verbatim — this is the canonical input contract the
    /// model sees, so any drift here is user-visible.
    #[test]
    fn mcp_definition_preserves_server_schema_as_canonical_input_contract() {
        let schema = json!({
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Search query"
                },
                "filters": {
                    "type": "array",
                    "items": { "type": "string" }
                },
                "strict": {
                    "type": ["boolean", "null"],
                    "default": false
                }
            },
            "required": ["query", "filters"]
        });
        let definition = ToolDefinition::raw(
            "mcp:demo/search",
            "mcp__demo__search",
            "Search",
            schema.clone(),
            json!({}),
        );
        assert_eq!(definition.contract.input_schema, schema);
        assert_eq!(definition.parameter_metadata().len(), 3);
    }

    /// Full stdio integration test: spin up a tiny `sh` mock that emits three
    /// pre-canned JSON-RPC responses (initialize, tools/list, tools/call)
    /// matching rmcp's request-id sequence (0, 1, 2), then verify the pool
    /// imports the advertised tool with the right discovery metadata and
    /// executes it end-to-end.
    #[tokio::test]
    async fn adapter_imports_and_executes_stdio_tools() {
        let initialize = json!({
            "jsonrpc": "2.0",
            "id": 0,
            "result": {
                "protocolVersion": "2024-11-05",
                "capabilities": { "tools": {} },
                "serverInfo": { "name": "demo", "version": "1.0.0" }
            }
        });
        let list = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "tools": [{
                    "name": "search-docs",
                    "description": "Search docs",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "query": { "type": "string" }
                        },
                        "required": ["query"],
                        "additionalProperties": false
                    },
                    "outputSchema": {
                        "type": "object",
                        "properties": {
                            "matches": { "type": "array" }
                        },
                        "required": ["matches"]
                    }
                }]
            }
        });
        let call = json!({
            "jsonrpc": "2.0",
            "id": 2,
            "result": {
                "structuredContent": {
                    "matches": ["matched"]
                },
                "content": [{
                    "type": "text",
                    "text": "{\n  \"matches\": [\"matched\"]\n}"
                }]
            }
        });

        // Read each request line before emitting the matching response —
        // rmcp drops responses that arrive before their request is in flight,
        // so a "dump all responses upfront" mock races against the event
        // loop and the third response never gets matched. Reading one line
        // per request keeps the sequence deterministic.
        // Lines:
        //   1. initialize          → respond with RESP1
        //   2. notifications/initialized (no response)
        //   3. tools/list          → respond with RESP2
        //   4. tools/call          → respond with RESP3
        let script = "\
            read -r _; printf '%s\\n' \"$RESP1\"; \
            read -r _; \
            read -r _; printf '%s\\n' \"$RESP2\"; \
            read -r _; printf '%s\\n' \"$RESP3\"; \
            cat >/dev/null"
            .to_string();

        let mut env = BTreeMap::new();
        env.insert("RESP1".to_string(), initialize.to_string());
        env.insert("RESP2".to_string(), list.to_string());
        env.insert("RESP3".to_string(), call.to_string());

        let mut servers = BTreeMap::new();
        servers.insert(
            "docs".to_string(),
            McpServerConfig::Stdio {
                command: "sh".to_string(),
                args: vec!["-c".to_string(), script],
                env,
                cwd: None,
                startup_timeout_ms: 10_000,
                call_timeout_ms: 10_000,
            },
        );

        let factory = McpPluginFactory::new(servers)
            .await
            .expect("factory connects to stdio mock");

        let defs = factory.pool().advertised_tools().await;
        assert_eq!(defs.len(), 1, "expected one imported tool, got {defs:?}");
        assert_eq!(defs[0].name(), "mcp__docs__search_docs");
        assert_eq!(
            defs[0].manifest.agent_surface.module_path,
            vec!["docs".to_string()]
        );
        assert_eq!(
            defs[0].manifest.agent_surface.operation.as_deref(),
            Some("search_docs")
        );
        assert_eq!(
            defs[0].manifest.agent_surface.aliases,
            vec!["search-docs".to_string()]
        );
        assert_eq!(
            defs[0]
                .contract
                .input_schema
                .get("properties")
                .and_then(Value::as_object)
                .and_then(|props| props.get("query"))
                .and_then(|query| query.get("type"))
                .cloned(),
            Some(json!("string"))
        );
        assert_eq!(
            defs[0].contract.output_schema,
            json!({
                "type": "object",
                "properties": {
                    "matches": { "type": "array" }
                },
                "required": ["matches"]
            })
        );

        let result = factory
            .pool()
            .call_tool(
                "mcp__docs__search_docs",
                &json!({ "query": "lash" }),
                &lash_core::testing::mock_tool_context(),
            )
            .await;
        assert!(result.is_success(), "{result:?}");
        assert_eq!(
            result.value_for_projection(),
            json!({ "matches": ["matched"] })
        );

        factory.pool().shutdown_all().await;
    }
}