1use std::collections::BTreeMap;
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use lash_core::plugin::{
9 PluginError, PluginFactory, PluginRegistrar, PluginSessionContext, SessionPlugin,
10};
11use lash_core::{ToolCall, ToolContract, ToolManifest, ToolProvider, ToolResult};
12
13use crate::config::McpServerConfig;
14use crate::error::McpError;
15use crate::pool::McpConnectionPool;
16
17pub struct McpPluginFactory {
20 pool: Arc<McpConnectionPool>,
21}
22
23impl McpPluginFactory {
24 pub async fn new(servers: BTreeMap<String, McpServerConfig>) -> Result<Self, McpError> {
29 let pool = McpConnectionPool::connect(servers).await?;
30 Ok(Self { pool })
31 }
32
33 pub fn empty() -> Self {
36 Self {
37 pool: Arc::new(McpConnectionPool::empty()),
38 }
39 }
40
41 pub fn pool(&self) -> &Arc<McpConnectionPool> {
44 &self.pool
45 }
46
47 pub async fn attach_server(
51 &self,
52 server_name: String,
53 config: McpServerConfig,
54 ) -> Result<(), McpError> {
55 self.pool.attach(server_name, config).await
56 }
57
58 pub async fn detach_server(&self, server_name: &str) -> Result<(), McpError> {
60 self.pool.detach(server_name).await
61 }
62}
63
64impl PluginFactory for McpPluginFactory {
65 fn id(&self) -> &'static str {
66 "mcp"
67 }
68
69 fn build(&self, _ctx: &PluginSessionContext) -> Result<Arc<dyn SessionPlugin>, PluginError> {
70 Ok(Arc::new(McpSessionPlugin {
71 pool: Arc::clone(&self.pool),
72 }))
73 }
74}
75
76struct McpSessionPlugin {
77 pool: Arc<McpConnectionPool>,
78}
79
80impl SessionPlugin for McpSessionPlugin {
81 fn id(&self) -> &'static str {
82 "mcp"
83 }
84
85 fn register(&self, reg: &mut PluginRegistrar) -> Result<(), PluginError> {
86 reg.tools().provider(Arc::new(McpToolProvider {
87 pool: Arc::clone(&self.pool),
88 }) as Arc<dyn ToolProvider>)
89 }
90}
91
92pub struct McpToolProvider {
95 pool: Arc<McpConnectionPool>,
96}
97
98impl McpToolProvider {
99 pub fn new(pool: Arc<McpConnectionPool>) -> Self {
100 Self { pool }
101 }
102}
103
104#[async_trait]
105impl ToolProvider for McpToolProvider {
106 fn tool_manifests(&self) -> Vec<ToolManifest> {
107 self.pool
108 .advertised_tools_blocking()
109 .into_iter()
110 .map(|tool| tool.manifest())
111 .collect()
112 }
113
114 fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>> {
115 self.pool
116 .advertised_tools_blocking()
117 .into_iter()
118 .find(|tool| tool.name() == name)
119 .map(|tool| Arc::new(tool.contract()))
120 }
121
122 async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
123 self.pool
124 .call_tool(call.name, call.args, call.context)
125 .await
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use lash_core::ToolDefinition;
    use serde_json::{Value, json};
    use std::collections::BTreeMap;

    /// A raw MCP tool definition must keep the server-supplied JSON schema unchanged as its
    /// input contract and expose one parameter-metadata entry per schema property.
    #[test]
    fn mcp_definition_preserves_server_schema_as_canonical_input_contract() {
        // Mixes a described string, an array of strings, and a nullable boolean with a default,
        // with two of the three properties required.
        let schema = json!({
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Search query"
                },
                "filters": {
                    "type": "array",
                    "items": { "type": "string" }
                },
                "strict": {
                    "type": ["boolean", "null"],
                    "default": false
                }
            },
            "required": ["query", "filters"]
        });
        let definition = ToolDefinition::raw(
            "mcp:demo/search",
            "mcp__demo__search",
            "Search",
            schema.clone(),
            json!({}),
        );
        // The input contract is the exact schema that was passed in.
        assert_eq!(definition.contract.input_schema, schema);
        // Three properties in the schema above -> three metadata entries.
        assert_eq!(definition.parameter_metadata().len(), 3);
    }

    /// End-to-end test of importing and executing a tool from a stdio MCP server.
    ///
    /// The "server" is `sh -c <script>`. Each `read -r _` consumes one line written by the
    /// client, and each `printf` replays a canned JSON-RPC response taken from the `RESP*`
    /// environment variables set up below.
    #[tokio::test]
    async fn adapter_imports_and_executes_stdio_tools() {
        // Canned initialize response (id 0): protocol version, tools capability, server info.
        let initialize = json!({
            "jsonrpc": "2.0",
            "id": 0,
            "result": {
                "protocolVersion": "2024-11-05",
                "capabilities": { "tools": {} },
                "serverInfo": { "name": "demo", "version": "1.0.0" }
            }
        });
        // Canned tool-list response (id 1): one tool with a hyphenated name plus input and
        // output schemas.
        let list = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "tools": [{
                    "name": "search-docs",
                    "description": "Search docs",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "query": { "type": "string" }
                        },
                        "required": ["query"],
                        "additionalProperties": false
                    },
                    "outputSchema": {
                        "type": "object",
                        "properties": {
                            "matches": { "type": "array" }
                        },
                        "required": ["matches"]
                    }
                }]
            }
        });
        // Canned tool-call response (id 2): structured content plus a text block that encodes
        // the same JSON.
        let call = json!({
            "jsonrpc": "2.0",
            "id": 2,
            "result": {
                "structuredContent": {
                    "matches": ["matched"]
                },
                "content": [{
                    "type": "text",
                    "text": "{\n  \"matches\": [\"matched\"]\n}"
                }]
            }
        });

        // Script order:
        // - reply to the 1st client line with RESP1;
        // - read the 2nd line without replying (presumably the client's initialized
        //   notification — NOTE(review): confirm);
        // - reply to the 3rd and 4th lines with RESP2 and RESP3;
        // - drain stdin until EOF.
        // The trailing `\` escapes join everything into one shell command line.
        let script = "\
            read -r _; printf '%s\\n' \"$RESP1\"; \
            read -r _; \
            read -r _; printf '%s\\n' \"$RESP2\"; \
            read -r _; printf '%s\\n' \"$RESP3\"; \
            cat >/dev/null"
            .to_string();

        let mut env = BTreeMap::new();
        env.insert("RESP1".to_string(), initialize.to_string());
        env.insert("RESP2".to_string(), list.to_string());
        env.insert("RESP3".to_string(), call.to_string());

        let mut servers = BTreeMap::new();
        servers.insert(
            "docs".to_string(),
            McpServerConfig::Stdio {
                command: "sh".to_string(),
                args: vec!["-c".to_string(), script],
                env,
                cwd: None,
                startup_timeout_ms: 10_000,
                call_timeout_ms: 10_000,
            },
        );

        let factory = McpPluginFactory::new(servers)
            .await
            .expect("factory connects to stdio mock");

        // Import naming:
        // - the tool is exposed as `mcp__<server>__<tool>`, with `-` turned into `_`;
        // - the server name becomes the module path;
        // - the original MCP name is kept as an alias.
        let defs = factory.pool().advertised_tools().await;
        assert_eq!(defs.len(), 1, "expected one imported tool, got {defs:?}");
        assert_eq!(defs[0].name(), "mcp__docs__search_docs");
        assert_eq!(
            defs[0].manifest.agent_surface.module_path,
            vec!["docs".to_string()]
        );
        assert_eq!(
            defs[0].manifest.agent_surface.operation.as_deref(),
            Some("search_docs")
        );
        assert_eq!(
            defs[0].manifest.agent_surface.aliases,
            vec!["search-docs".to_string()]
        );
        // The input and output schemas from the tool list are carried into the contract.
        assert_eq!(
            defs[0]
                .contract
                .input_schema
                .get("properties")
                .and_then(Value::as_object)
                .and_then(|props| props.get("query"))
                .and_then(|query| query.get("type"))
                .cloned(),
            Some(json!("string"))
        );
        assert_eq!(
            defs[0].contract.output_schema,
            json!({
                "type": "object",
                "properties": {
                    "matches": { "type": "array" }
                },
                "required": ["matches"]
            })
        );

        // Executing the imported tool must succeed. Its projected value equals the
        // structuredContent payload; the text block encodes the same JSON, so either source
        // satisfies this check.
        let result = factory
            .pool()
            .call_tool(
                "mcp__docs__search_docs",
                &json!({ "query": "lash" }),
                &lash_core::testing::mock_tool_context(),
            )
            .await;
        assert!(result.is_success(), "{result:?}");
        assert_eq!(
            result.value_for_projection(),
            json!({ "matches": ["matched"] })
        );

        // Shut the pool down explicitly at the end of the test.
        factory.pool().shutdown_all().await;
    }
}