Skip to main content

lash_plugin_mcp/
plugin.rs

1//! [`PluginFactory`] for MCP integration. Holds a shared connection pool
2//! across every session built from the same `LashCore`.
3
4use std::collections::BTreeMap;
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use lash_core::plugin::{
9    PluginError, PluginFactory, PluginRegistrar, PluginSessionContext, SessionPlugin,
10};
11use lash_core::{ToolCall, ToolContract, ToolManifest, ToolProvider, ToolResult};
12
13use crate::config::McpServerConfig;
14use crate::error::McpError;
15use crate::pool::McpConnectionPool;
16
17/// Plugin factory for MCP. Add once to `LashCoreBuilder` via
18/// `.plugin(Arc::new(factory))`.
19pub struct McpPluginFactory {
20    pool: Arc<McpConnectionPool>,
21}
22
23impl McpPluginFactory {
24    /// Connect to every configured server eagerly and return a factory whose
25    /// pool is ready to use. The pool is `Arc`-shared across sessions; cloning
26    /// the factory and adding it to multiple `LashCore`s shares the same
27    /// connections.
28    pub async fn new(servers: BTreeMap<String, McpServerConfig>) -> Result<Self, McpError> {
29        let pool = McpConnectionPool::connect(servers).await?;
30        Ok(Self { pool })
31    }
32
33    /// Empty pool — useful when servers are added at runtime via
34    /// [`McpPluginFactory::attach_server`].
35    pub fn empty() -> Self {
36        Self {
37            pool: Arc::new(McpConnectionPool::empty()),
38        }
39    }
40
41    /// Direct access to the underlying pool, in case the embedder wants to
42    /// inspect or mutate it directly.
43    pub fn pool(&self) -> &Arc<McpConnectionPool> {
44        &self.pool
45    }
46
47    /// Attach a new server at runtime. The new tools become visible to any
48    /// session created after this call returns; existing sessions will see
49    /// the new tools after their next tool-surface refresh.
50    pub async fn attach_server(
51        &self,
52        server_name: String,
53        config: McpServerConfig,
54    ) -> Result<(), McpError> {
55        self.pool.attach(server_name, config).await
56    }
57
58    /// Detach a server at runtime.
59    pub async fn detach_server(&self, server_name: &str) -> Result<(), McpError> {
60        self.pool.detach(server_name).await
61    }
62}
63
64impl PluginFactory for McpPluginFactory {
65    fn id(&self) -> &'static str {
66        "mcp"
67    }
68
69    fn build(&self, _ctx: &PluginSessionContext) -> Result<Arc<dyn SessionPlugin>, PluginError> {
70        Ok(Arc::new(McpSessionPlugin {
71            pool: Arc::clone(&self.pool),
72        }))
73    }
74}
75
76struct McpSessionPlugin {
77    pool: Arc<McpConnectionPool>,
78}
79
80impl SessionPlugin for McpSessionPlugin {
81    fn id(&self) -> &'static str {
82        "mcp"
83    }
84
85    fn register(&self, reg: &mut PluginRegistrar) -> Result<(), PluginError> {
86        reg.tools().provider(Arc::new(McpToolProvider {
87            pool: Arc::clone(&self.pool),
88        }) as Arc<dyn ToolProvider>)
89    }
90}
91
92/// The `ToolProvider` actually registered with each session's tool surface.
93/// Delegates definitions and execution to the shared pool.
94pub struct McpToolProvider {
95    pool: Arc<McpConnectionPool>,
96}
97
98impl McpToolProvider {
99    pub fn new(pool: Arc<McpConnectionPool>) -> Self {
100        Self { pool }
101    }
102}
103
104#[async_trait]
105impl ToolProvider for McpToolProvider {
106    fn tool_manifests(&self) -> Vec<ToolManifest> {
107        self.pool
108            .advertised_tools_blocking()
109            .into_iter()
110            .map(|tool| tool.manifest())
111            .collect()
112    }
113
114    fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>> {
115        self.pool
116            .advertised_tools_blocking()
117            .into_iter()
118            .find(|tool| tool.name() == name)
119            .map(|tool| Arc::new(tool.contract()))
120    }
121
122    async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
123        self.pool
124            .call_tool(call.name, call.args, call.context)
125            .await
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use lash_core::ToolDefinition;
    use serde_json::{Value, json};
    use std::collections::BTreeMap;

    /// Pure unit test ported from `crates/lash/src/mcp.rs`. Verifies that a
    /// `ToolDefinition::raw` constructed from an MCP-advertised input schema
    /// keeps the schema verbatim — this is the canonical input contract the
    /// model sees, so any drift here is user-visible.
    #[test]
    fn mcp_definition_preserves_server_schema_as_canonical_input_contract() {
        let schema = json!({
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Search query"
                },
                "filters": {
                    "type": "array",
                    "items": { "type": "string" }
                },
                "strict": {
                    "type": ["boolean", "null"],
                    "default": false
                }
            },
            "required": ["query", "filters"]
        });
        let definition = ToolDefinition::raw(
            "mcp:demo/search",
            "mcp__demo__search",
            "Search",
            schema.clone(),
            json!({}),
        );
        // The schema must survive untouched, and each of its three
        // properties must be reported as a parameter.
        assert_eq!(definition.contract.input_schema, schema);
        assert_eq!(definition.parameter_metadata().len(), 3);
    }

    /// Full stdio integration test: spin up a tiny `sh` mock that emits three
    /// pre-canned JSON-RPC responses (initialize, tools/list, tools/call)
    /// matching rmcp's request-id sequence (0, 1, 2), then verify the pool
    /// imports the advertised tool with the right discovery metadata and
    /// executes it end-to-end.
    ///
    /// The mock server is a POSIX shell script, so the test is ignored on
    /// non-Unix hosts instead of failing when `sh` cannot be spawned. It is
    /// still compiled everywhere, which keeps the imports above in use.
    #[tokio::test]
    #[cfg_attr(not(unix), ignore = "mock MCP server is a POSIX `sh` script")]
    async fn adapter_imports_and_executes_stdio_tools() {
        let initialize = json!({
            "jsonrpc": "2.0",
            "id": 0,
            "result": {
                "protocolVersion": "2024-11-05",
                "capabilities": { "tools": {} },
                "serverInfo": { "name": "demo", "version": "1.0.0" }
            }
        });
        let list = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "tools": [{
                    "name": "search-docs",
                    "description": "Search docs",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "query": { "type": "string" }
                        },
                        "required": ["query"],
                        "additionalProperties": false
                    },
                    "outputSchema": {
                        "type": "object",
                        "properties": {
                            "matches": { "type": "array" }
                        },
                        "required": ["matches"]
                    }
                }]
            }
        });
        let call = json!({
            "jsonrpc": "2.0",
            "id": 2,
            "result": {
                "structuredContent": {
                    "matches": ["matched"]
                },
                "content": [{
                    "type": "text",
                    "text": "{\n  \"matches\": [\"matched\"]\n}"
                }]
            }
        });

        // Read each request line before emitting the matching response —
        // rmcp drops responses that arrive before their request is in flight,
        // so a "dump all responses upfront" mock races against the event
        // loop and the third response never gets matched. Reading one line
        // per request keeps the sequence deterministic.
        // Lines:
        //   1. initialize          → respond with RESP1
        //   2. notifications/initialized (no response)
        //   3. tools/list          → respond with RESP2
        //   4. tools/call          → respond with RESP3
        // The trailing `cat` keeps draining stdin after the last response.
        let script = "\
            read -r _; printf '%s\\n' \"$RESP1\"; \
            read -r _; \
            read -r _; printf '%s\\n' \"$RESP2\"; \
            read -r _; printf '%s\\n' \"$RESP3\"; \
            cat >/dev/null"
            .to_string();

        // The canned responses reach the script through its environment.
        let mut env = BTreeMap::new();
        env.insert("RESP1".to_string(), initialize.to_string());
        env.insert("RESP2".to_string(), list.to_string());
        env.insert("RESP3".to_string(), call.to_string());

        let mut servers = BTreeMap::new();
        servers.insert(
            "docs".to_string(),
            McpServerConfig::Stdio {
                command: "sh".to_string(),
                args: vec!["-c".to_string(), script],
                env,
                cwd: None,
                startup_timeout_ms: 10_000,
                call_timeout_ms: 10_000,
            },
        );

        let factory = McpPluginFactory::new(servers)
            .await
            .expect("factory connects to stdio mock");

        // Discovery: exactly one tool, namespaced by server name, with the
        // server's original tool name kept as an alias.
        let defs = factory.pool().advertised_tools().await;
        assert_eq!(defs.len(), 1, "expected one imported tool, got {defs:?}");
        assert_eq!(defs[0].name(), "mcp__docs__search_docs");
        assert_eq!(
            defs[0].manifest.agent_surface.module_path,
            vec!["docs".to_string()]
        );
        assert_eq!(
            defs[0].manifest.agent_surface.operation.as_deref(),
            Some("search_docs")
        );
        assert_eq!(
            defs[0].manifest.agent_surface.aliases,
            vec!["search-docs".to_string()]
        );
        assert_eq!(
            defs[0]
                .contract
                .input_schema
                .get("properties")
                .and_then(Value::as_object)
                .and_then(|props| props.get("query"))
                .and_then(|query| query.get("type"))
                .cloned(),
            Some(json!("string"))
        );
        assert_eq!(
            defs[0].contract.output_schema,
            json!({
                "type": "object",
                "properties": {
                    "matches": { "type": "array" }
                },
                "required": ["matches"]
            })
        );

        // Execution: the structured content from the mock's tools/call
        // response is what the tool result projects.
        let result = factory
            .pool()
            .call_tool(
                "mcp__docs__search_docs",
                &json!({ "query": "lash" }),
                &lash_core::testing::mock_tool_context(),
            )
            .await;
        assert!(result.is_success(), "{result:?}");
        assert_eq!(
            result.value_for_projection(),
            json!({ "matches": ["matched"] })
        );

        factory.pool().shutdown_all().await;
    }
}