1use std::collections::HashMap;
8use std::sync::Arc;
9
10use rmcp::model::{CallToolRequestParams, RawContent};
11use rmcp::transport::{
12 StreamableHttpClientTransport, TokioChildProcess,
13 streamable_http_client::StreamableHttpClientTransportConfig,
14};
15use rmcp::{Peer, RoleClient, ServiceExt};
16use serde_json::Value;
17use tokio::process::Command;
18use tracing::{debug, info, warn};
19
20use roboticus_core::config::{McpServerConfig, McpServerSpec};
21
22#[derive(Debug, thiserror::Error)]
24pub enum McpClientError {
25 #[error("transport error: {0}")]
26 Transport(String),
27 #[error("protocol error: {0}")]
28 Protocol(String),
29 #[error("server error: {0}")]
30 Server(String),
31 #[error("not connected")]
32 NotConnected,
33 #[error("connection failed: {0}")]
34 ConnectionFailed(String),
35}
36
37#[derive(Debug, Clone)]
39pub struct DiscoveredTool {
40 pub name: String,
41 pub description: String,
42 pub input_schema: Value,
43}
44
45pub struct LiveMcpConnection {
53 name: String,
54 tools: Vec<DiscoveredTool>,
55 server_name: String,
56 server_version: String,
57 _handle: Box<dyn std::any::Any + Send + Sync>,
59 peer: Arc<Peer<RoleClient>>,
61}
62
63impl LiveMcpConnection {
64 fn finalize_connection<T>(
65 name: &str,
66 service: T,
67 peer: Arc<Peer<RoleClient>>,
68 ) -> Result<Self, McpClientError>
69 where
70 T: Send + Sync + 'static,
71 {
72 let (server_name, server_version) = peer
73 .peer_info()
74 .map(|info| {
75 (
76 info.server_info.name.clone(),
77 info.server_info.version.clone(),
78 )
79 })
80 .unwrap_or_else(|| ("unknown".into(), "".into()));
81
82 Ok(Self {
83 name: name.to_string(),
84 tools: Vec::new(),
85 server_name,
86 server_version,
87 _handle: Box::new(service),
88 peer,
89 })
90 }
91
92 async fn discover_tools(mut self) -> Result<Self, McpClientError> {
93 let rmcp_tools = self
94 .peer
95 .list_all_tools()
96 .await
97 .map_err(|e| McpClientError::Protocol(e.to_string()))?;
98
99 self.tools = rmcp_tools
100 .into_iter()
101 .map(|t| DiscoveredTool {
102 name: t.name.to_string(),
103 description: t.description.clone().unwrap_or_default().to_string(),
104 input_schema: t.schema_as_json_value(),
105 })
106 .collect();
107
108 info!(
109 name = self.name,
110 server_name = self.server_name,
111 tool_count = self.tools.len(),
112 "MCP server connected"
113 );
114
115 Ok(self)
116 }
117
118 fn resolve_auth_header(config: &McpServerConfig) -> Result<Option<String>, McpClientError> {
119 match &config.auth_token_env {
120 Some(var) => std::env::var(var).map(Some).map_err(|e| {
121 McpClientError::ConnectionFailed(format!(
122 "failed to read auth token env var '{var}' for MCP server '{}': {e}",
123 config.name
124 ))
125 }),
126 None => Ok(None),
127 }
128 }
129
130 pub async fn connect_stdio(
135 name: &str,
136 command: &str,
137 args: &[String],
138 env: &HashMap<String, String>,
139 ) -> Result<Self, McpClientError> {
140 let mut cmd = Command::new(command);
141 cmd.args(args);
142 for (k, v) in env {
143 cmd.env(k, v);
144 }
145
146 let transport =
147 TokioChildProcess::new(cmd).map_err(|e| McpClientError::Transport(e.to_string()))?;
148
149 info!(name, command, "connecting to MCP server via STDIO");
150
151 let service = ()
152 .serve(transport)
153 .await
154 .map_err(|e| McpClientError::ConnectionFailed(e.to_string()))?;
155 let peer = Arc::new(service.peer().clone());
156 Self::finalize_connection(name, service, peer)?
157 .discover_tools()
158 .await
159 }
160
161 pub async fn connect_sse(config: &McpServerConfig, url: &str) -> Result<Self, McpClientError> {
163 let mut transport_config = StreamableHttpClientTransportConfig::with_uri(url.to_string());
164 if let Some(auth_header) = Self::resolve_auth_header(config)? {
165 transport_config = transport_config.auth_header(auth_header);
166 }
167 let transport = StreamableHttpClientTransport::from_config(transport_config);
168
169 info!(
170 name = config.name,
171 url, "connecting to MCP server via remote HTTP"
172 );
173
174 let service = ()
175 .serve(transport)
176 .await
177 .map_err(|e| McpClientError::ConnectionFailed(e.to_string()))?;
178 let peer = Arc::new(service.peer().clone());
179 Self::finalize_connection(&config.name, service, peer)?
180 .discover_tools()
181 .await
182 }
183
184 pub async fn connect(config: &McpServerConfig) -> Result<Self, McpClientError> {
186 match &config.spec {
187 McpServerSpec::Stdio { command, args, env } => {
188 Self::connect_stdio(&config.name, command, args, env).await
189 }
190 McpServerSpec::Sse { url } => Self::connect_sse(config, url).await,
191 }
192 }
193
194 pub fn name(&self) -> &str {
196 &self.name
197 }
198
199 pub fn tools(&self) -> &[DiscoveredTool] {
201 &self.tools
202 }
203
204 pub fn server_name(&self) -> &str {
206 &self.server_name
207 }
208
209 pub fn server_version(&self) -> &str {
211 &self.server_version
212 }
213
214 pub fn is_alive(&self) -> bool {
216 !self.peer.is_transport_closed()
217 }
218
219 pub async fn call_tool(
224 &self,
225 tool_name: &str,
226 arguments: Value,
227 ) -> Result<Value, McpClientError> {
228 debug!(name = self.name, tool_name, "calling MCP tool");
229
230 let params = CallToolRequestParams {
231 meta: None,
232 name: tool_name.to_string().into(),
233 arguments: arguments.as_object().cloned(),
234 task: None,
235 };
236
237 let result = self
238 .peer
239 .call_tool(params)
240 .await
241 .map_err(|e| McpClientError::Server(e.to_string()))?;
242
243 let text_parts: Vec<String> = result
245 .content
246 .iter()
247 .filter_map(|c| {
248 if let RawContent::Text(t) = &c.raw {
249 Some(t.text.clone())
250 } else {
251 None
252 }
253 })
254 .collect();
255
256 Ok(serde_json::json!({
257 "content": text_parts.join("\n"),
258 "is_error": result.is_error.unwrap_or(false),
259 }))
260 }
261
262 pub async fn ping(&self) -> Result<(), McpClientError> {
264 if self.peer.is_transport_closed() {
266 Err(McpClientError::NotConnected)
267 } else {
268 Ok(())
269 }
270 }
271}
272
273impl std::fmt::Debug for LiveMcpConnection {
274 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275 f.debug_struct("LiveMcpConnection")
276 .field("name", &self.name)
277 .field("server_name", &self.server_name)
278 .field("tool_count", &self.tools.len())
279 .field("alive", &self.is_alive())
280 .finish()
281 }
282}
283
284#[derive(Debug, Default)]
288pub struct LiveMcpManager {
289 connections: HashMap<String, LiveMcpConnection>,
290}
291
292impl LiveMcpManager {
293 pub fn new() -> Self {
294 Self::default()
295 }
296
297 pub fn add(&mut self, conn: LiveMcpConnection) {
299 self.connections.insert(conn.name().to_string(), conn);
300 }
301
302 pub fn remove(&mut self, name: &str) -> Option<LiveMcpConnection> {
304 self.connections.remove(name)
305 }
306
307 pub fn get(&self, name: &str) -> Option<&LiveMcpConnection> {
308 self.connections.get(name)
309 }
310
311 pub fn list(&self) -> Vec<&LiveMcpConnection> {
312 self.connections.values().collect()
313 }
314
315 pub fn alive_count(&self) -> usize {
316 self.connections.values().filter(|c| c.is_alive()).count()
317 }
318
319 pub fn total_count(&self) -> usize {
320 self.connections.len()
321 }
322
323 pub fn all_tools(&self) -> Vec<(&str, &DiscoveredTool)> {
325 self.connections
326 .values()
327 .filter(|c| c.is_alive())
328 .flat_map(|c| c.tools().iter().map(move |t| (c.name(), t)))
329 .collect()
330 }
331
332 pub async fn connect_all(&mut self, configs: &[McpServerConfig]) {
335 for cfg in configs {
336 if !cfg.enabled {
337 debug!(name = cfg.name, "skipping disabled MCP server");
338 continue;
339 }
340 match LiveMcpConnection::connect(cfg).await {
341 Ok(conn) => self.add(conn),
342 Err(e) => warn!(name = cfg.name, error = %e, "failed to connect to MCP server"),
343 }
344 }
345 }
346}
347
/// Test helpers: an in-process MCP server exposing a single `echo` tool, wired to
/// a `LiveMcpConnection` over an in-memory duplex pipe.
#[cfg(test)]
pub(crate) mod test_support {
    use std::sync::Arc;

    use rmcp::{
        ServerHandler, ServiceExt,
        handler::server::{router::tool::ToolRouter, wrapper::Parameters},
        model::{ServerCapabilities, ServerInfo},
        schemars, tool, tool_handler, tool_router,
    };

    use super::{LiveMcpConnection, McpClientError};

    /// Minimal MCP server whose only tool is `echo`.
    #[derive(Debug, Clone)]
    struct TestInMemoryMcpServer {
        tool_router: ToolRouter<Self>,
    }

    impl TestInMemoryMcpServer {
        fn new() -> Self {
            let tool_router = Self::tool_router();
            Self { tool_router }
        }
    }

    /// Input for the `echo` tool.
    #[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
    struct EchoRequest {
        text: String,
    }

    #[tool_router]
    impl TestInMemoryMcpServer {
        #[tool(description = "Echo back the provided text")]
        async fn echo(&self, params: Parameters<EchoRequest>) -> String {
            let Parameters(EchoRequest { text }) = params;
            text
        }
    }

    #[tool_handler(router = self.tool_router)]
    impl ServerHandler for TestInMemoryMcpServer {
        fn get_info(&self) -> ServerInfo {
            let capabilities = ServerCapabilities::builder().enable_tools().build();
            ServerInfo {
                capabilities,
                ..Default::default()
            }
        }
    }

    /// Starts the echo server on a spawned task and returns a connected client,
    /// with tools already discovered, plus the server task's handle so the caller
    /// can abort it.
    pub(crate) async fn echo_connection(
        name: &str,
    ) -> Result<(LiveMcpConnection, tokio::task::JoinHandle<()>), McpClientError> {
        let (server_io, client_io) = tokio::io::duplex(4096);

        let server_task = tokio::spawn(async move {
            let running = TestInMemoryMcpServer::new()
                .serve(server_io)
                .await
                .expect("test MCP server should start");
            running
                .waiting()
                .await
                .expect("test MCP server should complete");
        });

        let client = ()
            .serve(client_io)
            .await
            .map_err(|e| McpClientError::ConnectionFailed(e.to_string()))?;
        let peer = Arc::new(client.peer().clone());
        let connection = LiveMcpConnection::finalize_connection(name, client, peer)?
            .discover_tools()
            .await?;

        Ok((connection, server_task))
    }
}
422
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn discovered_tool_fields() {
        let schema = serde_json::json!({"type": "object"});
        let tool = DiscoveredTool {
            name: String::from("test_tool"),
            description: String::from("A test tool"),
            input_schema: schema,
        };
        assert_eq!(tool.name, "test_tool");
        assert_eq!(tool.description, "A test tool");
    }

    #[test]
    fn mcp_client_error_display() {
        assert_eq!(McpClientError::NotConnected.to_string(), "not connected");

        // Every payload-carrying variant must surface its payload in the message.
        let cases = [
            (McpClientError::Transport("pipe broken".into()), "pipe broken"),
            (McpClientError::ConnectionFailed("refused".into()), "refused"),
            (McpClientError::Protocol("bad json".into()), "bad json"),
            (McpClientError::Server("timeout".into()), "timeout"),
        ];
        for (err, needle) in cases {
            let rendered = err.to_string();
            assert!(
                rendered.contains(needle),
                "{rendered:?} should mention {needle:?}"
            );
        }
    }

    #[test]
    fn live_mcp_manager_defaults() {
        let manager = LiveMcpManager::new();
        assert_eq!(manager.total_count(), 0);
        assert_eq!(manager.alive_count(), 0);
        assert!(manager.list().is_empty());
        assert!(manager.all_tools().is_empty());
    }

    #[tokio::test]
    async fn connect_stdio_non_mcp_fails() {
        let no_env = HashMap::new();
        let outcome = LiveMcpConnection::connect_stdio("test-false", "false", &[], &no_env).await;

        assert!(
            outcome.is_err(),
            "`false` doesn't speak MCP — expected an error, got: {:?}",
            outcome
        );
    }

    #[tokio::test]
    async fn in_memory_connection_discovers_tools_and_calls_remote_server() {
        let (conn, server_handle) = test_support::echo_connection("remote-test")
            .await
            .expect("in-memory MCP connection should be established");
        assert!(conn.is_alive());

        let tool_names: Vec<&str> = conn.tools().iter().map(|t| t.name.as_str()).collect();
        assert_eq!(tool_names, ["echo"]);

        let reply = conn
            .call_tool("echo", serde_json::json!({ "text": "hello over http" }))
            .await
            .expect("echo tool call should succeed");
        assert_eq!(reply["content"], "hello over http");
        assert_eq!(reply["is_error"], false);

        server_handle.abort();
        let _ = server_handle.await;
    }
}