use std::collections::BTreeMap;
use std::sync::Arc;
use async_trait::async_trait;
use lash_core::plugin::{
PluginError, PluginFactory, PluginRegistrar, PluginSessionContext, SessionPlugin,
};
use lash_core::{ToolCall, ToolContract, ToolManifest, ToolProvider, ToolResult};
use crate::config::McpServerConfig;
use crate::error::McpError;
use crate::pool::McpConnectionPool;
/// Plugin factory that exposes tools from MCP servers to sessions.
///
/// Every session plugin built by this factory holds a clone of the same
/// `Arc<McpConnectionPool>`. Attach/detach calls made on the factory therefore act on the
/// pool those sessions use.
pub struct McpPluginFactory {
    /// Pool handle shared (via `Arc::clone`) with every session plugin this factory builds.
    pool: Arc<McpConnectionPool>,
}

impl McpPluginFactory {
    /// Creates a factory whose pool is built by `McpConnectionPool::connect(servers)`.
    ///
    /// # Errors
    ///
    /// Propagates any error produced while connecting the pool.
    pub async fn new(servers: BTreeMap<String, McpServerConfig>) -> Result<Self, McpError> {
        Ok(Self {
            pool: McpConnectionPool::connect(servers).await?,
        })
    }

    /// Creates a factory backed by `McpConnectionPool::empty()`, with no servers configured.
    pub fn empty() -> Self {
        let empty_pool = Arc::new(McpConnectionPool::empty());
        Self { pool: empty_pool }
    }

    /// Returns the shared connection pool handle.
    pub fn pool(&self) -> &Arc<McpConnectionPool> {
        &self.pool
    }

    /// Adds `server_name` with `config` to the shared pool.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by the pool's `attach`.
    pub async fn attach_server(
        &self,
        server_name: String,
        config: McpServerConfig,
    ) -> Result<(), McpError> {
        let shared = &self.pool;
        shared.attach(server_name, config).await
    }

    /// Removes `server_name` from the shared pool.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by the pool's `detach`.
    pub async fn detach_server(&self, server_name: &str) -> Result<(), McpError> {
        let shared = &self.pool;
        shared.detach(server_name).await
    }
}
impl PluginFactory for McpPluginFactory {
    /// Stable plugin identifier.
    fn id(&self) -> &'static str {
        "mcp"
    }

    /// Builds a session plugin that shares this factory's connection pool.
    ///
    /// The session context is ignored. Every session receives a clone of the same pool
    /// handle, and this never returns an error.
    fn build(&self, _ctx: &PluginSessionContext) -> Result<Arc<dyn SessionPlugin>, PluginError> {
        let shared_pool = Arc::clone(&self.pool);
        let plugin: Arc<dyn SessionPlugin> = Arc::new(McpSessionPlugin { pool: shared_pool });
        Ok(plugin)
    }
}
struct McpSessionPlugin {
pool: Arc<McpConnectionPool>,
}
impl SessionPlugin for McpSessionPlugin {
fn id(&self) -> &'static str {
"mcp"
}
fn register(&self, reg: &mut PluginRegistrar) -> Result<(), PluginError> {
reg.tools().provider(Arc::new(McpToolProvider {
pool: Arc::clone(&self.pool),
}) as Arc<dyn ToolProvider>)
}
}
/// [`ToolProvider`] that lists and executes the tools advertised by an
/// [`McpConnectionPool`].
pub struct McpToolProvider {
    /// Pool used both for tool discovery and for dispatching calls.
    pool: Arc<McpConnectionPool>,
}

impl McpToolProvider {
    /// Wraps an existing pool handle; this only stores it and performs no I/O.
    pub fn new(pool: Arc<McpConnectionPool>) -> Self {
        McpToolProvider { pool }
    }
}
#[async_trait]
impl ToolProvider for McpToolProvider {
    /// Returns one manifest per tool currently advertised by the pool.
    ///
    /// NOTE(review): this sync method calls `advertised_tools_blocking`. Confirm that call
    /// is safe when the provider is used from inside an async runtime.
    fn tool_manifests(&self) -> Vec<ToolManifest> {
        let advertised = self.pool.advertised_tools_blocking();
        advertised
            .into_iter()
            .map(|definition| definition.manifest())
            .collect()
    }

    /// Looks up the contract of the advertised tool whose name equals `name`.
    ///
    /// Returns `None` when no advertised tool matches. Each lookup takes a fresh snapshot
    /// of all advertised tools and scans it linearly.
    fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>> {
        let advertised = self.pool.advertised_tools_blocking();
        let matched = advertised
            .into_iter()
            .find(|definition| definition.name() == name)?;
        Some(Arc::new(matched.contract()))
    }

    /// Forwards the call's name, arguments and context to the pool and returns its result.
    async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
        let pool = &self.pool;
        pool.call_tool(call.name, call.args, call.context).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use lash_core::ToolDefinition;
    use serde_json::{Value, json};
    use std::collections::BTreeMap;

    // A raw MCP definition must keep the server-provided JSON schema verbatim as its input
    // contract, while still exposing one parameter-metadata entry per property.
    #[test]
    fn mcp_definition_preserves_server_schema_as_canonical_input_contract() {
        let schema = json!({
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Search query"
                },
                "filters": {
                    "type": "array",
                    "items": { "type": "string" }
                },
                "strict": {
                    "type": ["boolean", "null"],
                    "default": false
                }
            },
            "required": ["query", "filters"]
        });
        let definition = ToolDefinition::raw(
            "mcp:demo/search",
            "mcp__demo__search",
            "Search",
            schema.clone(),
            json!({}),
        );
        assert_eq!(definition.contract.input_schema, schema);
        assert_eq!(definition.parameter_metadata().len(), 3);
    }

    // End-to-end check against a scripted stdio "server": import one tool, then call it.
    // The mock is a POSIX `sh` script, so the test is ignored (but still compiled) on
    // platforms without one.
    #[tokio::test]
    #[cfg_attr(not(unix), ignore = "stdio mock server is a POSIX `sh` script")]
    async fn adapter_imports_and_executes_stdio_tools() {
        // Canned JSON-RPC responses, replayed in order by the shell script below.
        let initialize = json!({
            "jsonrpc": "2.0",
            "id": 0,
            "result": {
                "protocolVersion": "2024-11-05",
                "capabilities": { "tools": {} },
                "serverInfo": { "name": "demo", "version": "1.0.0" }
            }
        });
        let list = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "tools": [{
                    "name": "search-docs",
                    "description": "Search docs",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "query": { "type": "string" }
                        },
                        "required": ["query"],
                        "additionalProperties": false
                    },
                    "outputSchema": {
                        "type": "object",
                        "properties": {
                            "matches": { "type": "array" }
                        },
                        "required": ["matches"]
                    }
                }]
            }
        });
        let call = json!({
            "jsonrpc": "2.0",
            "id": 2,
            "result": {
                "structuredContent": {
                    "matches": ["matched"]
                },
                "content": [{
                    "type": "text",
                    "text": "{\n \"matches\": [\"matched\"]\n}"
                }]
            }
        });
        // The script answers the 1st incoming line with RESP1 and reads the 2nd line without
        // replying. It answers the 3rd with RESP2 and the 4th with RESP3, then drains stdin.
        // NOTE(review): the silent 2nd read presumably consumes the client's post-initialize
        // notification; confirm against the pool's handshake order.
        let script = "\
            read -r _; printf '%s\\n' \"$RESP1\"; \
            read -r _; \
            read -r _; printf '%s\\n' \"$RESP2\"; \
            read -r _; printf '%s\\n' \"$RESP3\"; \
            cat >/dev/null"
            .to_string();
        let mut env = BTreeMap::new();
        env.insert("RESP1".to_string(), initialize.to_string());
        env.insert("RESP2".to_string(), list.to_string());
        env.insert("RESP3".to_string(), call.to_string());
        let mut servers = BTreeMap::new();
        servers.insert(
            "docs".to_string(),
            McpServerConfig::Stdio {
                command: "sh".to_string(),
                args: vec!["-c".to_string(), script],
                env,
                cwd: None,
                startup_timeout_ms: 10_000,
                call_timeout_ms: 10_000,
            },
        );
        let factory = McpPluginFactory::new(servers)
            .await
            .expect("factory connects to stdio mock");
        let defs = factory.pool().advertised_tools().await;
        assert_eq!(defs.len(), 1, "expected one imported tool, got {defs:?}");
        assert_eq!(defs[0].name(), "mcp__docs__search_docs");
        assert_eq!(
            defs[0].manifest.agent_surface.module_path,
            vec!["docs".to_string()]
        );
        assert_eq!(
            defs[0].manifest.agent_surface.operation.as_deref(),
            Some("search_docs")
        );
        assert_eq!(
            defs[0].manifest.agent_surface.aliases,
            vec!["search-docs".to_string()]
        );
        assert_eq!(
            defs[0]
                .contract
                .input_schema
                .get("properties")
                .and_then(Value::as_object)
                .and_then(|props| props.get("query"))
                .and_then(|query| query.get("type"))
                .cloned(),
            Some(json!("string"))
        );
        assert_eq!(
            defs[0].contract.output_schema,
            json!({
                "type": "object",
                "properties": {
                    "matches": { "type": "array" }
                },
                "required": ["matches"]
            })
        );
        let result = factory
            .pool()
            .call_tool(
                "mcp__docs__search_docs",
                &json!({ "query": "lash" }),
                &lash_core::testing::mock_tool_context(),
            )
            .await;
        assert!(result.is_success(), "{result:?}");
        assert_eq!(
            result.value_for_projection(),
            json!({ "matches": ["matched"] })
        );
        factory.pool().shutdown_all().await;
    }
}