1use std::collections::{BTreeMap, BTreeSet, HashMap};
11use std::sync::Arc;
12use std::sync::RwLock;
13
14use base64::Engine;
15use http::{HeaderName, HeaderValue};
16use rmcp::model::{CallToolRequestParams, ClientInfo, Content, Implementation, RawContent};
17use rmcp::service::{RoleClient, RunningService, ServiceExt};
18use rmcp::transport::child_process::TokioChildProcess;
19use rmcp::transport::streamable_http_client::{
20 StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
21};
22use serde_json::{Value, json};
23use tokio::process::Command;
24use tokio::time::timeout;
25
26use lash_core::{
27 AttachmentCreateMeta, MediaType, ToolCallOutput, ToolContext, ToolDefinition, ToolFailure,
28 ToolFailureClass, ToolFailureSource, ToolResult, ToolRetryDisposition, ToolScheduling,
29 ToolValue,
30};
31
32use crate::config::McpServerConfig;
33use crate::error::McpError;
34use crate::naming;
35
/// A set of live MCP server connections, keyed by server name, whose tools are exposed to the
/// agent under prefixed names.
///
/// The map sits behind a `std::sync::RwLock`. The methods in this file take the guard only in
/// short synchronous scopes and release it before any `.await`.
pub struct McpConnectionPool {
    /// Connected servers by name. Entries are `Arc`-shared so an in-flight call can keep using an
    /// entry after the read guard is released (or after the entry is removed from the map).
    entries: RwLock<BTreeMap<String, Arc<McpEntry>>>,
}
41
/// One connected MCP server: its configuration, the running rmcp client service, and the tools
/// imported from it at connect time.
struct McpEntry {
    /// Name the server was attached under (the pool's map key).
    server_name: String,
    /// Configuration used to start the server; also supplies the per-call timeout.
    config: McpServerConfig,
    /// The running client service, or `None` once `shutdown` has taken it (`call_tool` then
    /// reports the server as not connected). `call_tool` holds this async mutex only long enough
    /// to clone the peer; `shutdown` holds it while cancelling.
    service: tokio::sync::Mutex<Option<RunningService<RoleClient, ClientInfo>>>,
    /// Tools discovered via `tools/list`, keyed by their prefixed (pool-visible) name. Built once
    /// at connect time; nothing in this file refreshes it afterwards.
    imported_tools: BTreeMap<String, ImportedTool>,
}
54
/// A single tool imported from an MCP server.
#[derive(Clone)]
struct ImportedTool {
    /// The tool's name as the server knows it; sent back as the name in `tools/call`.
    original_name: String,
    /// The definition advertised to the agent under the prefixed name.
    definition: ToolDefinition,
}
62
63impl McpConnectionPool {
64 pub fn empty() -> Self {
66 Self {
67 entries: RwLock::new(BTreeMap::new()),
68 }
69 }
70
71 pub async fn connect(
74 servers: BTreeMap<String, McpServerConfig>,
75 ) -> Result<Arc<Self>, McpError> {
76 let pool = Arc::new(Self::empty());
77 for (name, config) in servers {
78 pool.attach(name, config).await?;
79 }
80 Ok(pool)
81 }
82
83 pub async fn attach(
86 self: &Arc<Self>,
87 server_name: String,
88 config: McpServerConfig,
89 ) -> Result<(), McpError> {
90 config.validate(&server_name)?;
91 let entry = McpEntry::connect(server_name.clone(), config).await?;
92 let mut entries = self
93 .entries
94 .write()
95 .expect("MCP pool entries lock poisoned");
96 entries.insert(server_name, Arc::new(entry));
97 Ok(())
98 }
99
100 pub async fn detach(self: &Arc<Self>, server_name: &str) -> Result<(), McpError> {
102 let removed = {
103 let mut entries = self
104 .entries
105 .write()
106 .expect("MCP pool entries lock poisoned");
107 entries.remove(server_name)
108 };
109 if let Some(entry) = removed {
110 entry.shutdown().await;
111 }
112 Ok(())
113 }
114
115 pub fn advertised_tools_blocking(&self) -> Vec<ToolDefinition> {
118 let guard = self.entries.read().expect("MCP pool entries lock poisoned");
119 guard
120 .values()
121 .flat_map(|entry| {
122 entry
123 .imported_tools
124 .values()
125 .map(|tool| tool.definition.clone())
126 })
127 .collect()
128 }
129
130 pub async fn advertised_tools(&self) -> Vec<ToolDefinition> {
132 let guard = self.entries.read().expect("MCP pool entries lock poisoned");
133 guard
134 .values()
135 .flat_map(|entry| {
136 entry
137 .imported_tools
138 .values()
139 .map(|tool| tool.definition.clone())
140 })
141 .collect()
142 }
143
144 pub async fn call_tool(
147 &self,
148 prefixed_name: &str,
149 args: &Value,
150 context: &ToolContext<'_>,
151 ) -> ToolResult {
152 let (entry, original_name) = match self.lookup(prefixed_name).await {
153 Some(found) => found,
154 None => {
155 return ToolResult::err_fmt(format!("Unknown MCP tool: {prefixed_name}"));
156 }
157 };
158
159 let call_timeout = entry.config.call_timeout();
160 let server_name = entry.server_name.clone();
161 let arguments = match args {
162 Value::Object(map) => Some(map.clone()),
163 Value::Null => None,
164 other => {
165 return ToolResult::err_fmt(format!(
166 "MCP tool `{prefixed_name}` expected an object argument, got {}",
167 other
168 ));
169 }
170 };
171
172 let peer = {
180 let service_guard = entry.service.lock().await;
181 match service_guard.as_ref() {
182 Some(service) => service.peer().clone(),
183 None => {
184 return ToolResult::err_fmt(McpError::Protocol(format!(
185 "MCP server `{server_name}` is not connected"
186 )));
187 }
188 }
189 };
190
191 let response = timeout(call_timeout, async {
192 let mut params = CallToolRequestParams::default();
193 params.name = original_name.clone().into();
194 params.arguments = arguments;
195 peer.call_tool(params)
196 .await
197 .map_err(|err| McpError::Protocol(err.to_string()))
198 })
199 .await;
200
201 match response {
202 Ok(Ok(result)) => tool_result_from_rmcp(result, context),
203 Ok(Err(err)) => ToolResult::err_fmt(err),
204 Err(_) => ToolResult::err_fmt(McpError::CallTimeout {
205 server: server_name,
206 timeout_ms: call_timeout.as_millis() as u64,
207 }),
208 }
209 }
210
211 async fn lookup(&self, prefixed_name: &str) -> Option<(Arc<McpEntry>, String)> {
212 let guard = self.entries.read().expect("MCP pool entries lock poisoned");
213 for entry in guard.values() {
214 if let Some(tool) = entry.imported_tools.get(prefixed_name) {
215 return Some((Arc::clone(entry), tool.original_name.clone()));
216 }
217 }
218 None
219 }
220
221 pub async fn shutdown_all(&self) {
224 let entries: Vec<Arc<McpEntry>> = {
225 let mut guard = self
226 .entries
227 .write()
228 .expect("MCP pool entries lock poisoned");
229 std::mem::take(&mut *guard).into_values().collect()
230 };
231 for entry in entries {
232 entry.shutdown().await;
233 }
234 }
235}
236
237impl McpEntry {
238 async fn connect(server_name: String, config: McpServerConfig) -> Result<Self, McpError> {
239 let service = timeout(
240 config.startup_timeout(),
241 connect_service(&server_name, &config),
242 )
243 .await
244 .map_err(|_| McpError::StartupTimeout {
245 server: server_name.clone(),
246 timeout_ms: config.startup_timeout().as_millis() as u64,
247 })??;
248
249 let discovery_timeout = config.startup_timeout();
254 let tools = timeout(discovery_timeout, service.peer().list_all_tools())
255 .await
256 .map_err(|_| McpError::StartupTimeout {
257 server: server_name.clone(),
258 timeout_ms: discovery_timeout.as_millis() as u64,
259 })?
260 .map_err(|err| McpError::Protocol(format!("list_tools failed: {err}")))?;
261 let imported_tools = import_tools(&server_name, tools);
262
263 Ok(Self {
264 server_name,
265 config,
266 service: tokio::sync::Mutex::new(Some(service)),
267 imported_tools,
268 })
269 }
270
271 async fn shutdown(&self) {
272 let mut guard = self.service.lock().await;
273 if let Some(service) = guard.take() {
274 let _ = service.cancel().await;
278 }
279 }
280}
281
282async fn connect_service(
283 server_name: &str,
284 config: &McpServerConfig,
285) -> Result<RunningService<RoleClient, ClientInfo>, McpError> {
286 let mut implementation = Implementation::default();
287 implementation.name = "lash".to_string();
288 implementation.version = lash_core::VERSION.to_string();
289 let mut client_info = ClientInfo::default();
290 client_info.client_info = implementation;
291
292 match config {
293 McpServerConfig::Stdio {
294 command,
295 args,
296 env,
297 cwd,
298 ..
299 } => {
300 let mut cmd = Command::new(command);
301 cmd.args(args);
302 if let Some(cwd) = cwd {
303 cmd.current_dir(cwd);
304 }
305 for (key, value) in env {
306 cmd.env(key, value);
307 }
308 let transport = TokioChildProcess::new(cmd).map_err(|err| {
309 McpError::Protocol(format!(
310 "failed to spawn `{command}` for `{server_name}`: {err}"
311 ))
312 })?;
313 client_info.serve(transport).await.map_err(|err| {
314 McpError::Protocol(format!("MCP handshake with `{server_name}`: {err}"))
315 })
316 }
317 McpServerConfig::StreamableHttp { url, headers, .. } => {
318 let custom_headers = build_http_headers(server_name, headers)?;
319 let config = StreamableHttpClientTransportConfig::with_uri(url.as_str())
320 .custom_headers(custom_headers);
321 let transport = StreamableHttpClientTransport::from_config(config);
322 client_info.serve(transport).await.map_err(|err| {
323 McpError::Protocol(format!("MCP handshake with `{server_name}`: {err}"))
324 })
325 }
326 McpServerConfig::Sse { .. } => Err(McpError::Config(format!(
333 "MCP server `{server_name}` uses the legacy `sse` transport, which is not supported \
334 by this build. Use the `streamable_http` transport instead (it speaks the current \
335 MCP HTTP transport and handles SSE responses)."
336 ))),
337 }
338}
339
340fn build_http_headers(
345 server_name: &str,
346 headers: &BTreeMap<String, String>,
347) -> Result<HashMap<HeaderName, HeaderValue>, McpError> {
348 let mut out = HashMap::with_capacity(headers.len());
349 for (name, value) in headers {
350 let header_name = HeaderName::try_from(name.as_str()).map_err(|err| {
351 McpError::Config(format!(
352 "MCP server `{server_name}` has invalid HTTP header name `{name}`: {err}"
353 ))
354 })?;
355 let header_value = HeaderValue::try_from(value.as_str()).map_err(|err| {
356 McpError::Config(format!(
357 "MCP server `{server_name}` has invalid value for HTTP header `{name}`: {err}"
358 ))
359 })?;
360 out.insert(header_name, header_value);
361 }
362 Ok(out)
363}
364
365fn import_tools(
366 server_name: &str,
367 tools: Vec<rmcp::model::Tool>,
368) -> BTreeMap<String, ImportedTool> {
369 let mut used_names = BTreeSet::new();
370 let mut imported = BTreeMap::new();
371 for tool in tools {
372 let original_name = tool.name.to_string();
373 let description = tool
374 .description
375 .as_deref()
376 .map(str::trim)
377 .unwrap_or_default();
378 let input_schema = Value::Object((*tool.input_schema).clone());
379 let output_schema = tool
380 .output_schema
381 .as_ref()
382 .map(|s| Value::Object((**s).clone()))
383 .unwrap_or_else(|| json!({}));
384 let (prefixed, agent_surface) =
385 naming::build_prefixed_name(server_name, &original_name, &mut used_names);
386
387 let description = if description.is_empty() {
388 format!("MCP tool from server `{server_name}`")
389 } else {
390 format!("[MCP {server_name}] {description}")
391 };
392
393 imported.insert(
394 prefixed.clone(),
395 ImportedTool {
396 original_name,
397 definition: ToolDefinition::raw(
398 format!("mcp:{server_name}/{prefixed}"),
399 prefixed,
400 description,
401 input_schema,
402 output_schema,
403 )
404 .with_agent_surface(agent_surface)
405 .with_scheduling(ToolScheduling::Parallel),
406 },
407 );
408 }
409 imported
410}
411
412fn tool_result_from_rmcp(
413 result: rmcp::model::CallToolResult,
414 context: &ToolContext<'_>,
415) -> ToolResult {
416 let is_error = result.is_error.unwrap_or(false);
417
418 let mut text_parts = Vec::new();
419 let mut content_items: Vec<ToolValue> = Vec::new();
420 let mut has_attachments = false;
421
422 for Content { raw, .. } in result.content {
423 match raw {
424 RawContent::Text(text) => {
425 text_parts.push(text.text.clone());
426 content_items.push(ToolValue::String(text.text));
427 }
428 RawContent::Image(image) => {
429 let data = match base64::engine::general_purpose::STANDARD.decode(image.data) {
430 Ok(bytes) => bytes,
431 Err(err) => {
432 return ToolResult::err_fmt(McpError::Decode(err));
433 }
434 };
435 let Some(media_type) = MediaType::from_mime(&image.mime_type) else {
436 return ToolResult::err_fmt(format!(
437 "Unsupported MCP image MIME type: {}",
438 image.mime_type
439 ));
440 };
441 let reference = match context.attachments().put(
442 data,
443 AttachmentCreateMeta::new(media_type, None, None, Some("MCP image".into())),
444 ) {
445 Ok(reference) => reference,
446 Err(err) => {
447 return ToolResult::err_fmt(format!(
448 "Failed to store MCP image attachment: {err}"
449 ));
450 }
451 };
452 has_attachments = true;
453 content_items.push(ToolValue::Attachment(reference));
454 }
455 other => {
456 if let Ok(value) = serde_json::to_value(&other) {
457 content_items.push(ToolValue::from(value));
458 }
459 }
460 }
461 }
462
463 let value = if let Some(structured) = result.structured_content {
464 if !has_attachments {
465 ToolValue::from(structured)
466 } else {
467 ToolValue::Object(
468 [
469 ("structured".to_string(), ToolValue::from(structured)),
470 ("content".to_string(), ToolValue::Array(content_items)),
471 ]
472 .into_iter()
473 .collect(),
474 )
475 }
476 } else if content_items.is_empty() {
477 ToolValue::Null
478 } else if content_items.len() == 1 {
479 content_items.into_iter().next().unwrap_or(ToolValue::Null)
480 } else {
481 ToolValue::Array(content_items)
482 };
483 if is_error {
484 ToolResult::from_output(ToolCallOutput::failure(ToolFailure {
485 class: ToolFailureClass::Execution,
486 code: "mcp_tool_error".into(),
487 message: if text_parts.is_empty() {
488 "MCP tool returned an error".into()
489 } else {
490 text_parts.join("\n\n")
491 },
492 source: ToolFailureSource::Tool,
493 retry: ToolRetryDisposition::Never,
494 raw: Some(value),
495 }))
496 } else {
497 ToolResult::from_output(ToolCallOutput::success(value))
498 }
499}
500
/// Dropping the pool performs no explicit cleanup.
///
/// `Drop::drop` cannot `.await`, so the orderly path is only reachable through
/// `McpConnectionPool::shutdown_all` or `McpConnectionPool::detach`. That path is
/// `McpEntry::shutdown`, which awaits `cancel` on each running service. Callers should use one of
/// those before letting the pool go. On a plain drop, whatever teardown happens is left to the
/// entries' own fields. NOTE(review): confirm that rmcp's `RunningService` and the child-process
/// transport stop the server when dropped.
impl Drop for McpConnectionPool {
    fn drop(&mut self) {
        // Intentionally empty: async shutdown cannot run here (see the impl-level comment).
    }
}
510
#[cfg(test)]
mod tests {
    use super::*;

    /// Configured headers must reach the transport with their values intact. Names are looked up
    /// in lowercase because `HeaderName` normalizes them.
    #[test]
    fn build_http_headers_carries_configured_headers() {
        let mut headers = BTreeMap::new();
        headers.insert(
            "Authorization".to_string(),
            "Bearer secret-token".to_string(),
        );
        headers.insert("X-Tenant".to_string(), "acme".to_string());

        let built = build_http_headers("api", &headers).expect("valid headers convert");

        assert_eq!(
            built
                .get(&HeaderName::from_static("authorization"))
                .map(|v| v.to_str().unwrap()),
            Some("Bearer secret-token"),
            "configured Authorization header must be carried through to the transport"
        );
        assert_eq!(
            built
                .get(&HeaderName::from_static("x-tenant"))
                .map(|v| v.to_str().unwrap()),
            Some("acme")
        );
        assert_eq!(built.len(), 2);
    }

    /// No configured headers yields an empty map rather than an error.
    #[test]
    fn build_http_headers_empty_map_is_empty() {
        let built = build_http_headers("api", &BTreeMap::new()).expect("empty converts");
        assert!(built.is_empty());
    }

    /// A name containing spaces is not a valid HTTP header name and must surface as a config
    /// error.
    #[test]
    fn build_http_headers_rejects_malformed_name() {
        let mut headers = BTreeMap::new();
        headers.insert("Bad Header Name".to_string(), "x".to_string());
        let err = build_http_headers("api", &headers).expect_err("malformed name rejected");
        assert!(
            matches!(err, McpError::Config(_)),
            "expected a config error, got {err:?}"
        );
    }

    /// A value with an embedded newline must be rejected as a config error rather than passed
    /// through to the transport.
    #[test]
    fn build_http_headers_rejects_malformed_value() {
        let mut headers = BTreeMap::new();
        headers.insert("X-Bad".to_string(), "line1\nline2".to_string());
        let err = build_http_headers("api", &headers).expect_err("malformed value rejected");
        assert!(
            matches!(err, McpError::Config(_)),
            "expected a config error, got {err:?}"
        );
    }

    /// The legacy `sse` transport is refused up front with a config error that names the
    /// supported `streamable_http` transport. The URL is never contacted because the `Sse` arm
    /// of `connect_service` returns before any I/O.
    #[tokio::test]
    async fn sse_transport_reports_clear_unsupported_error() {
        let err = connect_service("legacy", &McpServerConfig::sse("http://localhost:9/sse"))
            .await
            .expect_err("sse transport must error, not connect");
        match err {
            McpError::Config(msg) => {
                assert!(
                    msg.contains("streamable_http"),
                    "error should point operators at the supported transport: {msg}"
                );
            }
            other => panic!("expected a config error for sse, got {other:?}"),
        }
    }

    /// A server that completes the handshake but never answers `tools/list` must fail with
    /// `StartupTimeout` instead of hanging `McpEntry::connect` forever.
    #[tokio::test]
    async fn discovery_hang_surfaces_startup_timeout() {
        let initialize = json!({
            "jsonrpc": "2.0",
            "id": 0,
            "result": {
                "protocolVersion": "2024-11-05",
                "capabilities": { "tools": {} },
                "serverInfo": { "name": "demo", "version": "1.0.0" }
            }
        });

        // Mock server, one stdin line per client message:
        //   1. initialize request       -> reply with RESP1
        //   2. next client message      -> ignored (presumably `notifications/initialized`)
        //   3. tools/list request       -> deliberately never answered
        // then drain stdin until the client closes it.
        let script = "\
            read -r _; printf '%s\\n' \"$RESP1\"; \
            read -r _; \
            read -r _; \
            cat >/dev/null"
            .to_string();

        let mut env = BTreeMap::new();
        env.insert("RESP1".to_string(), initialize.to_string());

        // A short startup timeout keeps the test fast; it also bounds discovery.
        let config = McpServerConfig::Stdio {
            command: "sh".to_string(),
            args: vec!["-c".to_string(), script],
            env,
            cwd: None,
            startup_timeout_ms: 750,
            call_timeout_ms: 10_000,
        };

        match McpEntry::connect("hangs".to_string(), config).await {
            Err(McpError::StartupTimeout { .. }) => {}
            Err(other) => panic!("expected StartupTimeout from a hung tools/list, got {other:?}"),
            Ok(_) => panic!("a hung tools/list must not connect"),
        }
    }

    /// Two concurrent `call_tool`s on the same server must both be in flight at once. The mock
    /// answers neither call until it has read both requests. If the calls were serialized (e.g.
    /// by holding the service mutex across the request), the second request would never be sent.
    /// The first call would then time out.
    #[tokio::test]
    async fn concurrent_calls_are_not_serialized_by_the_service_mutex() {
        let initialize = json!({
            "jsonrpc": "2.0",
            "id": 0,
            "result": {
                "protocolVersion": "2024-11-05",
                "capabilities": { "tools": {} },
                "serverInfo": { "name": "demo", "version": "1.0.0" }
            }
        });
        let list = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "tools": [{
                    "name": "ping",
                    "description": "Ping",
                    "inputSchema": { "type": "object", "properties": {} }
                }]
            }
        });
        // Responses for the two tool calls; the ids assume they are sent as requests 2 and 3.
        let call2 = json!({ "jsonrpc": "2.0", "id": 2, "result": { "content": [{ "type": "text", "text": "pong" }] } });
        let call3 = json!({ "jsonrpc": "2.0", "id": 3, "result": { "content": [{ "type": "text", "text": "pong" }] } });

        // Mock server:
        //   1. initialize             -> RESP1
        //   2. next client message    -> ignored (presumably `notifications/initialized`)
        //   3. tools/list             -> RESP2
        //   4. first tools/call       -> held
        //   5. second tools/call      -> held
        // then answer both calls (RESP3, RESP4) and drain stdin.
        let script = "\
            read -r _; printf '%s\\n' \"$RESP1\"; \
            read -r _; \
            read -r _; printf '%s\\n' \"$RESP2\"; \
            read -r _; \
            read -r _; \
            printf '%s\\n' \"$RESP3\"; \
            printf '%s\\n' \"$RESP4\"; \
            cat >/dev/null"
            .to_string();

        let mut env = BTreeMap::new();
        env.insert("RESP1".to_string(), initialize.to_string());
        env.insert("RESP2".to_string(), list.to_string());
        env.insert("RESP3".to_string(), call2.to_string());
        env.insert("RESP4".to_string(), call3.to_string());

        let mut servers = BTreeMap::new();
        servers.insert(
            "svc".to_string(),
            McpServerConfig::Stdio {
                command: "sh".to_string(),
                args: vec!["-c".to_string(), script],
                env,
                cwd: None,
                startup_timeout_ms: 10_000,
                call_timeout_ms: 5_000,
            },
        );

        let pool = McpConnectionPool::connect(servers)
            .await
            .expect("connects to concurrency mock");

        let ctx = lash_core::testing::mock_tool_context();
        let args = json!({});
        let (a, b) = tokio::join!(
            pool.call_tool("mcp__svc__ping", &args, &ctx),
            pool.call_tool("mcp__svc__ping", &args, &ctx),
        );
        assert!(a.is_success(), "first concurrent call failed: {a:?}");
        assert!(b.is_success(), "second concurrent call failed: {b:?}");

        pool.shutdown_all().await;
    }
}