1use std::collections::{BTreeMap, BTreeSet, HashMap};
11use std::sync::Arc;
12use std::sync::RwLock;
13
14use base64::Engine;
15use http::{HeaderName, HeaderValue};
16use rmcp::model::{CallToolRequestParams, ClientInfo, Content, Implementation, RawContent};
17use rmcp::service::{RoleClient, RunningService, ServiceExt};
18use rmcp::transport::child_process::TokioChildProcess;
19use rmcp::transport::streamable_http_client::{
20 StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
21};
22use serde_json::{Value, json};
23use tokio::process::Command;
24use tokio::time::timeout;
25
26use lash_core::{
27 AttachmentCreateMeta, MediaType, ToolCallOutput, ToolContext, ToolDefinition, ToolFailure,
28 ToolFailureClass, ToolFailureSource, ToolResult, ToolRetryDisposition, ToolScheduling,
29 ToolValue,
30};
31
32use crate::config::McpServerConfig;
33use crate::error::McpError;
34use crate::naming;
35
/// Registry of attached MCP servers, keyed by the name each server was attached under.
///
/// Each value is a shared [`McpEntry`] holding the live client connection and the
/// tools imported from that server. The map sits behind a synchronous
/// `std::sync::RwLock`, so the guard must not be held across an `.await`.
pub struct McpConnectionPool {
    entries: RwLock<BTreeMap<String, Arc<McpEntry>>>,
}
41
/// One attached MCP server: its configuration, its running rmcp client, and the
/// tools discovered from it when it was connected.
struct McpEntry {
    /// Name the server was attached under; also passed to `import_tools` when
    /// building the prefixed tool names.
    server_name: String,
    /// Configuration used to connect; `call_tool` reads the per-call timeout from it.
    config: McpServerConfig,
    /// Running client service. `shutdown` takes it out (leaving `None`) and cancels it.
    /// A tokio mutex is used because `shutdown` holds the guard across `.await`.
    service: tokio::sync::Mutex<Option<RunningService<RoleClient, ClientInfo>>>,
    /// Imported tools keyed by their prefixed, agent-facing name.
    imported_tools: BTreeMap<String, ImportedTool>,
}
54
/// A tool discovered on an MCP server, paired with the definition exposed to the agent.
#[derive(Clone)]
struct ImportedTool {
    /// The tool's name as the server knows it; this is what `call_tool` sends to the server.
    original_name: String,
    /// Agent-facing definition registered under the prefixed name.
    definition: ToolDefinition,
}
62
63impl McpConnectionPool {
64 pub fn empty() -> Self {
66 Self {
67 entries: RwLock::new(BTreeMap::new()),
68 }
69 }
70
71 pub async fn connect(
74 servers: BTreeMap<String, McpServerConfig>,
75 ) -> Result<Arc<Self>, McpError> {
76 let pool = Arc::new(Self::empty());
77 for (name, config) in servers {
78 pool.attach(name, config).await?;
79 }
80 Ok(pool)
81 }
82
83 pub async fn attach(
86 self: &Arc<Self>,
87 server_name: String,
88 config: McpServerConfig,
89 ) -> Result<(), McpError> {
90 config.validate(&server_name)?;
91 let entry = McpEntry::connect(server_name.clone(), config).await?;
92 let mut entries = self
93 .entries
94 .write()
95 .expect("MCP pool entries lock poisoned");
96 entries.insert(server_name, Arc::new(entry));
97 Ok(())
98 }
99
100 pub async fn detach(self: &Arc<Self>, server_name: &str) -> Result<(), McpError> {
102 let removed = {
103 let mut entries = self
104 .entries
105 .write()
106 .expect("MCP pool entries lock poisoned");
107 entries.remove(server_name)
108 };
109 if let Some(entry) = removed {
110 entry.shutdown().await;
111 }
112 Ok(())
113 }
114
115 pub fn advertised_tools_blocking(&self) -> Vec<ToolDefinition> {
118 let guard = self.entries.read().expect("MCP pool entries lock poisoned");
119 guard
120 .values()
121 .flat_map(|entry| {
122 entry
123 .imported_tools
124 .values()
125 .map(|tool| tool.definition.clone())
126 })
127 .collect()
128 }
129
130 pub async fn advertised_tools(&self) -> Vec<ToolDefinition> {
132 let guard = self.entries.read().expect("MCP pool entries lock poisoned");
133 guard
134 .values()
135 .flat_map(|entry| {
136 entry
137 .imported_tools
138 .values()
139 .map(|tool| tool.definition.clone())
140 })
141 .collect()
142 }
143
144 pub async fn call_tool(
147 &self,
148 prefixed_name: &str,
149 args: &Value,
150 context: &ToolContext<'_>,
151 ) -> ToolResult {
152 let (entry, original_name) = match self.lookup(prefixed_name).await {
153 Some(found) => found,
154 None => {
155 return ToolResult::err_fmt(format!("Unknown MCP tool: {prefixed_name}"));
156 }
157 };
158
159 let call_timeout = entry.config.call_timeout();
160 let server_name = entry.server_name.clone();
161 let arguments = match args {
162 Value::Object(map) => Some(map.clone()),
163 Value::Null => None,
164 other => {
165 return ToolResult::err_fmt(format!(
166 "MCP tool `{prefixed_name}` expected an object argument, got {}",
167 other
168 ));
169 }
170 };
171
172 let peer = {
180 let service_guard = entry.service.lock().await;
181 match service_guard.as_ref() {
182 Some(service) => service.peer().clone(),
183 None => {
184 return ToolResult::err_fmt(McpError::Protocol(format!(
185 "MCP server `{server_name}` is not connected"
186 )));
187 }
188 }
189 };
190
191 let response = timeout(call_timeout, async {
192 let mut params = CallToolRequestParams::default();
193 params.name = original_name.clone().into();
194 params.arguments = arguments;
195 peer.call_tool(params)
196 .await
197 .map_err(|err| McpError::Protocol(err.to_string()))
198 })
199 .await;
200
201 match response {
202 Ok(Ok(result)) => tool_result_from_rmcp(result, context).await,
203 Ok(Err(err)) => ToolResult::err_fmt(err),
204 Err(_) => ToolResult::err_fmt(McpError::CallTimeout {
205 server: server_name,
206 timeout_ms: call_timeout.as_millis() as u64,
207 }),
208 }
209 }
210
211 async fn lookup(&self, prefixed_name: &str) -> Option<(Arc<McpEntry>, String)> {
212 let guard = self.entries.read().expect("MCP pool entries lock poisoned");
213 for entry in guard.values() {
214 if let Some(tool) = entry.imported_tools.get(prefixed_name) {
215 return Some((Arc::clone(entry), tool.original_name.clone()));
216 }
217 }
218 None
219 }
220
221 pub async fn shutdown_all(&self) {
224 let entries: Vec<Arc<McpEntry>> = {
225 let mut guard = self
226 .entries
227 .write()
228 .expect("MCP pool entries lock poisoned");
229 std::mem::take(&mut *guard).into_values().collect()
230 };
231 for entry in entries {
232 entry.shutdown().await;
233 }
234 }
235}
236
237impl McpEntry {
238 async fn connect(server_name: String, config: McpServerConfig) -> Result<Self, McpError> {
239 let service = timeout(
240 config.startup_timeout(),
241 connect_service(&server_name, &config),
242 )
243 .await
244 .map_err(|_| McpError::StartupTimeout {
245 server: server_name.clone(),
246 timeout_ms: config.startup_timeout().as_millis() as u64,
247 })??;
248
249 let discovery_timeout = config.startup_timeout();
254 let tools = timeout(discovery_timeout, service.peer().list_all_tools())
255 .await
256 .map_err(|_| McpError::StartupTimeout {
257 server: server_name.clone(),
258 timeout_ms: discovery_timeout.as_millis() as u64,
259 })?
260 .map_err(|err| McpError::Protocol(format!("list_tools failed: {err}")))?;
261 let imported_tools = import_tools(&server_name, tools);
262
263 Ok(Self {
264 server_name,
265 config,
266 service: tokio::sync::Mutex::new(Some(service)),
267 imported_tools,
268 })
269 }
270
271 async fn shutdown(&self) {
272 let mut guard = self.service.lock().await;
273 if let Some(service) = guard.take() {
274 let _ = service.cancel().await;
278 }
279 }
280}
281
282async fn connect_service(
283 server_name: &str,
284 config: &McpServerConfig,
285) -> Result<RunningService<RoleClient, ClientInfo>, McpError> {
286 let mut implementation = Implementation::default();
287 implementation.name = "lash".to_string();
288 implementation.version = lash_core::VERSION.to_string();
289 let mut client_info = ClientInfo::default();
290 client_info.client_info = implementation;
291
292 match config {
293 McpServerConfig::Stdio {
294 command,
295 args,
296 env,
297 cwd,
298 ..
299 } => {
300 let mut cmd = Command::new(command);
301 cmd.args(args);
302 if let Some(cwd) = cwd {
303 cmd.current_dir(cwd);
304 }
305 for (key, value) in env {
306 cmd.env(key, value);
307 }
308 let transport = TokioChildProcess::new(cmd).map_err(|err| {
309 McpError::Protocol(format!(
310 "failed to spawn `{command}` for `{server_name}`: {err}"
311 ))
312 })?;
313 client_info.serve(transport).await.map_err(|err| {
314 McpError::Protocol(format!("MCP handshake with `{server_name}`: {err}"))
315 })
316 }
317 McpServerConfig::StreamableHttp { url, headers, .. } => {
318 let custom_headers = build_http_headers(server_name, headers)?;
319 let config = StreamableHttpClientTransportConfig::with_uri(url.as_str())
320 .custom_headers(custom_headers);
321 let transport = StreamableHttpClientTransport::from_config(config);
322 client_info.serve(transport).await.map_err(|err| {
323 McpError::Protocol(format!("MCP handshake with `{server_name}`: {err}"))
324 })
325 }
326 McpServerConfig::Sse { .. } => Err(McpError::Config(format!(
333 "MCP server `{server_name}` uses the legacy `sse` transport, which is not supported \
334 by this build. Use the `streamable_http` transport instead (it speaks the current \
335 MCP HTTP transport and handles SSE responses)."
336 ))),
337 }
338}
339
340fn build_http_headers(
345 server_name: &str,
346 headers: &BTreeMap<String, String>,
347) -> Result<HashMap<HeaderName, HeaderValue>, McpError> {
348 let mut out = HashMap::with_capacity(headers.len());
349 for (name, value) in headers {
350 let header_name = HeaderName::try_from(name.as_str()).map_err(|err| {
351 McpError::Config(format!(
352 "MCP server `{server_name}` has invalid HTTP header name `{name}`: {err}"
353 ))
354 })?;
355 let header_value = HeaderValue::try_from(value.as_str()).map_err(|err| {
356 McpError::Config(format!(
357 "MCP server `{server_name}` has invalid value for HTTP header `{name}`: {err}"
358 ))
359 })?;
360 out.insert(header_name, header_value);
361 }
362 Ok(out)
363}
364
365fn import_tools(
366 server_name: &str,
367 tools: Vec<rmcp::model::Tool>,
368) -> BTreeMap<String, ImportedTool> {
369 let mut used_names = BTreeSet::new();
370 let mut imported = BTreeMap::new();
371 for tool in tools {
372 let original_name = tool.name.to_string();
373 let description = tool
374 .description
375 .as_deref()
376 .map(str::trim)
377 .unwrap_or_default();
378 let input_schema = Value::Object((*tool.input_schema).clone());
379 let output_schema = tool
380 .output_schema
381 .as_ref()
382 .map(|s| Value::Object((**s).clone()))
383 .unwrap_or_else(|| json!({}));
384 let (prefixed, agent_surface) =
385 naming::build_prefixed_name(server_name, &original_name, &mut used_names);
386
387 let description = if description.is_empty() {
388 format!("MCP tool from server `{server_name}`")
389 } else {
390 format!("[MCP {server_name}] {description}")
391 };
392
393 imported.insert(
394 prefixed.clone(),
395 ImportedTool {
396 original_name,
397 definition: ToolDefinition::raw(
398 format!("mcp:{server_name}/{prefixed}"),
399 prefixed,
400 description,
401 input_schema,
402 output_schema,
403 )
404 .with_agent_surface(agent_surface)
405 .with_scheduling(ToolScheduling::Parallel),
406 },
407 );
408 }
409 imported
410}
411
412async fn tool_result_from_rmcp(
413 result: rmcp::model::CallToolResult,
414 context: &ToolContext<'_>,
415) -> ToolResult {
416 let is_error = result.is_error.unwrap_or(false);
417
418 let mut text_parts = Vec::new();
419 let mut content_items: Vec<ToolValue> = Vec::new();
420 let mut has_attachments = false;
421
422 for Content { raw, .. } in result.content {
423 match raw {
424 RawContent::Text(text) => {
425 text_parts.push(text.text.clone());
426 content_items.push(ToolValue::String(text.text));
427 }
428 RawContent::Image(image) => {
429 let data = match base64::engine::general_purpose::STANDARD.decode(image.data) {
430 Ok(bytes) => bytes,
431 Err(err) => {
432 return ToolResult::err_fmt(McpError::Decode(err));
433 }
434 };
435 let Some(media_type) = MediaType::from_mime(&image.mime_type) else {
436 return ToolResult::err_fmt(format!(
437 "Unsupported MCP image MIME type: {}",
438 image.mime_type
439 ));
440 };
441 let reference = match context
442 .attachments()
443 .put(
444 data,
445 AttachmentCreateMeta::new(media_type, None, None, Some("MCP image".into())),
446 )
447 .await
448 {
449 Ok(reference) => reference,
450 Err(err) => {
451 return ToolResult::err_fmt(format!(
452 "Failed to store MCP image attachment: {err}"
453 ));
454 }
455 };
456 has_attachments = true;
457 content_items.push(ToolValue::Attachment(reference));
458 }
459 other => {
460 if let Ok(value) = serde_json::to_value(&other) {
461 content_items.push(ToolValue::from(value));
462 }
463 }
464 }
465 }
466
467 let value = if let Some(structured) = result.structured_content {
468 if !has_attachments {
469 ToolValue::from(structured)
470 } else {
471 ToolValue::Object(
472 [
473 ("structured".to_string(), ToolValue::from(structured)),
474 ("content".to_string(), ToolValue::Array(content_items)),
475 ]
476 .into_iter()
477 .collect(),
478 )
479 }
480 } else if content_items.is_empty() {
481 ToolValue::Null
482 } else if content_items.len() == 1 {
483 content_items.into_iter().next().unwrap_or(ToolValue::Null)
484 } else {
485 ToolValue::Array(content_items)
486 };
487 if is_error {
488 ToolResult::from_output(ToolCallOutput::failure(ToolFailure {
489 class: ToolFailureClass::Execution,
490 code: "mcp_tool_error".into(),
491 message: if text_parts.is_empty() {
492 "MCP tool returned an error".into()
493 } else {
494 text_parts.join("\n\n")
495 },
496 source: ToolFailureSource::Tool,
497 retry: ToolRetryDisposition::Never,
498 raw: Some(value),
499 }))
500 } else {
501 ToolResult::from_output(ToolCallOutput::success(value))
502 }
503}
504
505impl Drop for McpConnectionPool {
506 fn drop(&mut self) {
507 }
513}
514
#[cfg(test)]
mod tests {
    use super::*;

    // Configured headers must reach the transport. Header names are looked up
    // below by their lowercase form.
    #[test]
    fn build_http_headers_carries_configured_headers() {
        let mut headers = BTreeMap::new();
        headers.insert(
            "Authorization".to_string(),
            "Bearer secret-token".to_string(),
        );
        headers.insert("X-Tenant".to_string(), "acme".to_string());

        let built = build_http_headers("api", &headers).expect("valid headers convert");

        assert_eq!(
            built
                .get(&HeaderName::from_static("authorization"))
                .map(|v| v.to_str().unwrap()),
            Some("Bearer secret-token"),
            "configured Authorization header must be carried through to the transport"
        );
        assert_eq!(
            built
                .get(&HeaderName::from_static("x-tenant"))
                .map(|v| v.to_str().unwrap()),
            Some("acme")
        );
        assert_eq!(built.len(), 2);
    }

    // No configured headers -> no custom headers.
    #[test]
    fn build_http_headers_empty_map_is_empty() {
        let built = build_http_headers("api", &BTreeMap::new()).expect("empty converts");
        assert!(built.is_empty());
    }

    // Spaces are not allowed in HTTP header names; expect a config error, not a panic.
    #[test]
    fn build_http_headers_rejects_malformed_name() {
        let mut headers = BTreeMap::new();
        headers.insert("Bad Header Name".to_string(), "x".to_string());
        let err = build_http_headers("api", &headers).expect_err("malformed name rejected");
        assert!(
            matches!(err, McpError::Config(_)),
            "expected a config error, got {err:?}"
        );
    }

    // A raw newline is not allowed in an HTTP header value; expect a config error.
    #[test]
    fn build_http_headers_rejects_malformed_value() {
        let mut headers = BTreeMap::new();
        headers.insert("X-Bad".to_string(), "line1\nline2".to_string());
        let err = build_http_headers("api", &headers).expect_err("malformed value rejected");
        assert!(
            matches!(err, McpError::Config(_)),
            "expected a config error, got {err:?}"
        );
    }

    // The legacy `sse` transport must fail fast with a config error whose
    // message points operators at `streamable_http`.
    #[tokio::test]
    async fn sse_transport_reports_clear_unsupported_error() {
        let err = connect_service("legacy", &McpServerConfig::sse("http://localhost:9/sse"))
            .await
            .expect_err("sse transport must error, not connect");
        match err {
            McpError::Config(msg) => {
                assert!(
                    msg.contains("streamable_http"),
                    "error should point operators at the supported transport: {msg}"
                );
            }
            other => panic!("expected a config error for sse, got {other:?}"),
        }
    }

    // A server that completes the handshake but never answers the tool listing
    // must produce `StartupTimeout`, not hang forever.
    #[tokio::test]
    async fn discovery_hang_surfaces_startup_timeout() {
        // Canned reply to the client's `initialize` request (JSON-RPC id 0).
        let initialize = json!({
            "jsonrpc": "2.0",
            "id": 0,
            "result": {
                "protocolVersion": "2024-11-05",
                "capabilities": { "tools": {} },
                "serverInfo": { "name": "demo", "version": "1.0.0" }
            }
        });

        // Mock stdio server, one `read -r _` per client line:
        // 1. read `initialize` and reply with $RESP1;
        // 2. read the next client message (presumably the initialized notification);
        // 3. read the tools listing request and never answer it;
        // 4. `cat` keeps stdin drained so the process stays alive.
        let script = "\
            read -r _; printf '%s\\n' \"$RESP1\"; \
            read -r _; \
            read -r _; \
            cat >/dev/null"
            .to_string();

        let mut env = BTreeMap::new();
        env.insert("RESP1".to_string(), initialize.to_string());

        // Short startup timeout so discovery times out quickly.
        let config = McpServerConfig::Stdio {
            command: "sh".to_string(),
            args: vec!["-c".to_string(), script],
            env,
            cwd: None,
            startup_timeout_ms: 750,
            call_timeout_ms: 10_000,
        };

        match McpEntry::connect("hangs".to_string(), config).await {
            Err(McpError::StartupTimeout { .. }) => {}
            Err(other) => panic!("expected StartupTimeout from a hung tools/list, got {other:?}"),
            Ok(_) => panic!("a hung tools/list must not connect"),
        }
    }

    // Two concurrent calls to the same server must both be in flight at once.
    // The mock only replies to either call after it has read BOTH call requests.
    // If `call_tool` serialized calls (e.g. by holding the service mutex for the
    // whole call), the second request would not be sent until the first reply
    // arrived, so the first call would run into its 5 s timeout.
    #[tokio::test]
    async fn concurrent_calls_are_not_serialized_by_the_service_mutex() {
        // Reply to `initialize` (id 0).
        let initialize = json!({
            "jsonrpc": "2.0",
            "id": 0,
            "result": {
                "protocolVersion": "2024-11-05",
                "capabilities": { "tools": {} },
                "serverInfo": { "name": "demo", "version": "1.0.0" }
            }
        });
        // Reply to the tools listing (id 1) advertising a single `ping` tool.
        let list = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "tools": [{
                    "name": "ping",
                    "description": "Ping",
                    "inputSchema": { "type": "object", "properties": {} }
                }]
            }
        });
        // Replies to the two tool calls (ids 2 and 3).
        let call2 = json!({ "jsonrpc": "2.0", "id": 2, "result": { "content": [{ "type": "text", "text": "pong" }] } });
        let call3 = json!({ "jsonrpc": "2.0", "id": 3, "result": { "content": [{ "type": "text", "text": "pong" }] } });

        // Mock stdio server: answer initialize, consume the next client message,
        // answer the tools listing, then read both call requests before writing
        // either call reply.
        let script = "\
            read -r _; printf '%s\\n' \"$RESP1\"; \
            read -r _; \
            read -r _; printf '%s\\n' \"$RESP2\"; \
            read -r _; \
            read -r _; \
            printf '%s\\n' \"$RESP3\"; \
            printf '%s\\n' \"$RESP4\"; \
            cat >/dev/null"
            .to_string();

        let mut env = BTreeMap::new();
        env.insert("RESP1".to_string(), initialize.to_string());
        env.insert("RESP2".to_string(), list.to_string());
        env.insert("RESP3".to_string(), call2.to_string());
        env.insert("RESP4".to_string(), call3.to_string());

        let mut servers = BTreeMap::new();
        servers.insert(
            "svc".to_string(),
            McpServerConfig::Stdio {
                command: "sh".to_string(),
                args: vec!["-c".to_string(), script],
                env,
                cwd: None,
                startup_timeout_ms: 10_000,
                call_timeout_ms: 5_000,
            },
        );

        let pool = McpConnectionPool::connect(servers)
            .await
            .expect("connects to concurrency mock");

        // `mcp__svc__ping` is assumed to be the prefixed name `naming` builds
        // for server `svc`, tool `ping`.
        let ctx = lash_core::testing::mock_tool_context();
        let args = json!({});
        let (a, b) = tokio::join!(
            pool.call_tool("mcp__svc__ping", &args, &ctx),
            pool.call_tool("mcp__svc__ping", &args, &ctx),
        );
        assert!(a.is_success(), "first concurrent call failed: {a:?}");
        assert!(b.is_success(), "second concurrent call failed: {b:?}");

        pool.shutdown_all().await;
    }
}