Skip to main content

roboticus_agent/mcp/
client.rs

1//! MCP client — connects to external MCP servers via rmcp.
2//!
3//! Wraps the `rmcp` crate's client API, providing Roboticus-specific
4//! types and error handling. Supports both STDIO and remote streamable
5//! HTTP transports behind the shared MCP client abstraction.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use rmcp::model::{CallToolRequestParams, RawContent};
11use rmcp::transport::{
12    StreamableHttpClientTransport, TokioChildProcess,
13    streamable_http_client::StreamableHttpClientTransportConfig,
14};
15use rmcp::{Peer, RoleClient, ServiceExt};
16use serde_json::Value;
17use tokio::process::Command;
18use tracing::{debug, info, warn};
19
20use roboticus_core::config::{McpServerConfig, McpServerSpec};
21
22/// Errors from MCP client operations.
23#[derive(Debug, thiserror::Error)]
24pub enum McpClientError {
25    #[error("transport error: {0}")]
26    Transport(String),
27    #[error("protocol error: {0}")]
28    Protocol(String),
29    #[error("server error: {0}")]
30    Server(String),
31    #[error("not connected")]
32    NotConnected,
33    #[error("connection failed: {0}")]
34    ConnectionFailed(String),
35}
36
37/// Info about a tool discovered from an external MCP server.
38#[derive(Debug, Clone)]
39pub struct DiscoveredTool {
40    pub name: String,
41    pub description: String,
42    pub input_schema: Value,
43}
44
45/// A live connection to an external MCP server via rmcp.
46///
47/// The `_handle` field holds the `RunningService` (type-erased so we don't
48/// have to propagate the transport generic parameter). Keeping it alive keeps
49/// the child process running and the background I/O task running.
50/// The `peer` field is cloned out of the `RunningService` before we erase it,
51/// so all RPC calls go through the transport-independent `Peer<RoleClient>`.
52pub struct LiveMcpConnection {
53    name: String,
54    tools: Vec<DiscoveredTool>,
55    server_name: String,
56    server_version: String,
57    /// Keeps the RunningService (and its child process) alive.
58    _handle: Box<dyn std::any::Any + Send + Sync>,
59    /// Transport-independent handle for sending requests.
60    peer: Arc<Peer<RoleClient>>,
61}
62
63impl LiveMcpConnection {
64    fn finalize_connection<T>(
65        name: &str,
66        service: T,
67        peer: Arc<Peer<RoleClient>>,
68    ) -> Result<Self, McpClientError>
69    where
70        T: Send + Sync + 'static,
71    {
72        let (server_name, server_version) = peer
73            .peer_info()
74            .map(|info| {
75                (
76                    info.server_info.name.clone(),
77                    info.server_info.version.clone(),
78                )
79            })
80            .unwrap_or_else(|| ("unknown".into(), "".into()));
81
82        Ok(Self {
83            name: name.to_string(),
84            tools: Vec::new(),
85            server_name,
86            server_version,
87            _handle: Box::new(service),
88            peer,
89        })
90    }
91
92    async fn discover_tools(mut self) -> Result<Self, McpClientError> {
93        let rmcp_tools = self
94            .peer
95            .list_all_tools()
96            .await
97            .map_err(|e| McpClientError::Protocol(e.to_string()))?;
98
99        self.tools = rmcp_tools
100            .into_iter()
101            .map(|t| DiscoveredTool {
102                name: t.name.to_string(),
103                description: t.description.clone().unwrap_or_default().to_string(),
104                input_schema: t.schema_as_json_value(),
105            })
106            .collect();
107
108        info!(
109            name = self.name,
110            server_name = self.server_name,
111            tool_count = self.tools.len(),
112            "MCP server connected"
113        );
114
115        Ok(self)
116    }
117
118    fn resolve_auth_header(config: &McpServerConfig) -> Result<Option<String>, McpClientError> {
119        match &config.auth_token_env {
120            Some(var) => std::env::var(var).map(Some).map_err(|e| {
121                McpClientError::ConnectionFailed(format!(
122                    "failed to read auth token env var '{var}' for MCP server '{}': {e}",
123                    config.name
124                ))
125            }),
126            None => Ok(None),
127        }
128    }
129
130    /// Connect to an MCP server via STDIO transport.
131    ///
132    /// Spawns `command` with `args` and `env`, runs the MCP handshake,
133    /// and discovers all tools before returning.
134    pub async fn connect_stdio(
135        name: &str,
136        command: &str,
137        args: &[String],
138        env: &HashMap<String, String>,
139    ) -> Result<Self, McpClientError> {
140        let mut cmd = Command::new(command);
141        cmd.args(args);
142        for (k, v) in env {
143            cmd.env(k, v);
144        }
145
146        let transport =
147            TokioChildProcess::new(cmd).map_err(|e| McpClientError::Transport(e.to_string()))?;
148
149        info!(name, command, "connecting to MCP server via STDIO");
150
151        let service = ()
152            .serve(transport)
153            .await
154            .map_err(|e| McpClientError::ConnectionFailed(e.to_string()))?;
155        let peer = Arc::new(service.peer().clone());
156        Self::finalize_connection(name, service, peer)?
157            .discover_tools()
158            .await
159    }
160
161    /// Connect to an MCP server via remote streamable HTTP transport.
162    pub async fn connect_sse(config: &McpServerConfig, url: &str) -> Result<Self, McpClientError> {
163        let mut transport_config = StreamableHttpClientTransportConfig::with_uri(url.to_string());
164        if let Some(auth_header) = Self::resolve_auth_header(config)? {
165            transport_config = transport_config.auth_header(auth_header);
166        }
167        let transport = StreamableHttpClientTransport::from_config(transport_config);
168
169        info!(
170            name = config.name,
171            url, "connecting to MCP server via remote HTTP"
172        );
173
174        let service = ()
175            .serve(transport)
176            .await
177            .map_err(|e| McpClientError::ConnectionFailed(e.to_string()))?;
178        let peer = Arc::new(service.peer().clone());
179        Self::finalize_connection(&config.name, service, peer)?
180            .discover_tools()
181            .await
182    }
183
184    /// Connect from a [`McpServerConfig`].
185    pub async fn connect(config: &McpServerConfig) -> Result<Self, McpClientError> {
186        match &config.spec {
187            McpServerSpec::Stdio { command, args, env } => {
188                Self::connect_stdio(&config.name, command, args, env).await
189            }
190            McpServerSpec::Sse { url } => Self::connect_sse(config, url).await,
191        }
192    }
193
194    /// Name of this connection (matches the config entry name).
195    pub fn name(&self) -> &str {
196        &self.name
197    }
198
199    /// Tools discovered at connection time.
200    pub fn tools(&self) -> &[DiscoveredTool] {
201        &self.tools
202    }
203
204    /// Name reported by the remote server during handshake.
205    pub fn server_name(&self) -> &str {
206        &self.server_name
207    }
208
209    /// Version reported by the remote server during handshake.
210    pub fn server_version(&self) -> &str {
211        &self.server_version
212    }
213
214    /// Returns `true` if the underlying transport channel is still open.
215    pub fn is_alive(&self) -> bool {
216        !self.peer.is_transport_closed()
217    }
218
219    /// Call a tool on the remote server, returning the result as a JSON value.
220    ///
221    /// `arguments` should be a JSON object (`Value::Object`). If it is not an
222    /// object the call is made with no arguments.
223    pub async fn call_tool(
224        &self,
225        tool_name: &str,
226        arguments: Value,
227    ) -> Result<Value, McpClientError> {
228        debug!(name = self.name, tool_name, "calling MCP tool");
229
230        let params = CallToolRequestParams {
231            meta: None,
232            name: tool_name.to_string().into(),
233            arguments: arguments.as_object().cloned(),
234            task: None,
235        };
236
237        let result = self
238            .peer
239            .call_tool(params)
240            .await
241            .map_err(|e| McpClientError::Server(e.to_string()))?;
242
243        // Collect all text parts from the content list.
244        let text_parts: Vec<String> = result
245            .content
246            .iter()
247            .filter_map(|c| {
248                if let RawContent::Text(t) = &c.raw {
249                    Some(t.text.clone())
250                } else {
251                    None
252                }
253            })
254            .collect();
255
256        Ok(serde_json::json!({
257            "content": text_parts.join("\n"),
258            "is_error": result.is_error.unwrap_or(false),
259        }))
260    }
261
262    /// Send a ping to check liveness.
263    pub async fn ping(&self) -> Result<(), McpClientError> {
264        // Peer<RoleClient> doesn't expose a direct ping() — we check the channel.
265        if self.peer.is_transport_closed() {
266            Err(McpClientError::NotConnected)
267        } else {
268            Ok(())
269        }
270    }
271}
272
273impl std::fmt::Debug for LiveMcpConnection {
274    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275        f.debug_struct("LiveMcpConnection")
276            .field("name", &self.name)
277            .field("server_name", &self.server_name)
278            .field("tool_count", &self.tools.len())
279            .field("alive", &self.is_alive())
280            .finish()
281    }
282}
283
284/// Manages a pool of live MCP client connections.
285///
286/// Thread-safe; wrap in `Arc<RwLock<LiveMcpManager>>` for shared use.
287#[derive(Debug, Default)]
288pub struct LiveMcpManager {
289    connections: HashMap<String, LiveMcpConnection>,
290}
291
292impl LiveMcpManager {
293    pub fn new() -> Self {
294        Self::default()
295    }
296
297    /// Add a successfully connected [`LiveMcpConnection`].
298    pub fn add(&mut self, conn: LiveMcpConnection) {
299        self.connections.insert(conn.name().to_string(), conn);
300    }
301
302    /// Remove a connection by name. Returns the connection if it existed.
303    pub fn remove(&mut self, name: &str) -> Option<LiveMcpConnection> {
304        self.connections.remove(name)
305    }
306
307    pub fn get(&self, name: &str) -> Option<&LiveMcpConnection> {
308        self.connections.get(name)
309    }
310
311    pub fn list(&self) -> Vec<&LiveMcpConnection> {
312        self.connections.values().collect()
313    }
314
315    pub fn alive_count(&self) -> usize {
316        self.connections.values().filter(|c| c.is_alive()).count()
317    }
318
319    pub fn total_count(&self) -> usize {
320        self.connections.len()
321    }
322
323    /// All tools across all alive connections.
324    pub fn all_tools(&self) -> Vec<(&str, &DiscoveredTool)> {
325        self.connections
326            .values()
327            .filter(|c| c.is_alive())
328            .flat_map(|c| c.tools().iter().map(move |t| (c.name(), t)))
329            .collect()
330    }
331
332    /// Connect to all enabled servers in a list of configs, logging warnings
333    /// for any that fail.
334    pub async fn connect_all(&mut self, configs: &[McpServerConfig]) {
335        for cfg in configs {
336            if !cfg.enabled {
337                debug!(name = cfg.name, "skipping disabled MCP server");
338                continue;
339            }
340            match LiveMcpConnection::connect(cfg).await {
341                Ok(conn) => self.add(conn),
342                Err(e) => warn!(name = cfg.name, error = %e, "failed to connect to MCP server"),
343            }
344        }
345    }
346}
347
/// Test-only helpers: an in-memory MCP server exposing a single `echo` tool,
/// plus a function that wires a [`LiveMcpConnection`] to it.
#[cfg(test)]
pub(crate) mod test_support {
    use std::sync::Arc;

    use rmcp::{
        ServerHandler, ServiceExt,
        handler::server::{router::tool::ToolRouter, wrapper::Parameters},
        model::{ServerCapabilities, ServerInfo},
        schemars, tool, tool_handler, tool_router,
    };

    use super::{LiveMcpConnection, McpClientError};

    /// Minimal MCP server used by the tests.
    #[derive(Debug, Clone)]
    struct TestInMemoryMcpServer {
        tool_router: ToolRouter<Self>,
    }

    impl TestInMemoryMcpServer {
        fn new() -> Self {
            Self {
                tool_router: Self::tool_router(),
            }
        }
    }

    /// Arguments accepted by the `echo` tool.
    #[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
    struct EchoRequest {
        text: String,
    }

    #[tool_router]
    impl TestInMemoryMcpServer {
        #[tool(description = "Echo back the provided text")]
        async fn echo(&self, params: Parameters<EchoRequest>) -> String {
            params.0.text
        }
    }

    #[tool_handler(router = self.tool_router)]
    impl ServerHandler for TestInMemoryMcpServer {
        fn get_info(&self) -> ServerInfo {
            ServerInfo {
                capabilities: ServerCapabilities::builder().enable_tools().build(),
                ..Default::default()
            }
        }
    }

    /// Starts the echo server on one end of an in-memory duplex pipe and
    /// connects a [`LiveMcpConnection`] called `name` to the other end.
    ///
    /// Returns the connection together with the join handle of the spawned
    /// server task so the caller can stop it.
    pub(crate) async fn echo_connection(
        name: &str,
    ) -> Result<(LiveMcpConnection, tokio::task::JoinHandle<()>), McpClientError> {
        let (server_io, client_io) = tokio::io::duplex(4096);

        let server_task = tokio::spawn(async move {
            let running = TestInMemoryMcpServer::new()
                .serve(server_io)
                .await
                .expect("test MCP server should start");
            running
                .waiting()
                .await
                .expect("test MCP server should complete");
        });

        let client = ()
            .serve(client_io)
            .await
            .map_err(|e| McpClientError::ConnectionFailed(e.to_string()))?;
        let peer = Arc::new(client.peer().clone());

        let unpopulated = LiveMcpConnection::finalize_connection(name, client, peer)?;
        let conn = unpopulated.discover_tools().await?;

        Ok((conn, server_task))
    }
}
422
#[cfg(test)]
mod tests {
    use super::*;

    /// A hand-built `DiscoveredTool` exposes the values it was given.
    #[test]
    fn discovered_tool_fields() {
        let schema = serde_json::json!({"type": "object"});
        let tool = DiscoveredTool {
            name: "test_tool".into(),
            description: "A test tool".into(),
            input_schema: schema,
        };

        assert_eq!(tool.name, "test_tool");
        assert_eq!(tool.description, "A test tool");
    }

    /// Every error variant renders its message / payload via `Display`.
    #[test]
    fn mcp_client_error_display() {
        assert_eq!(McpClientError::NotConnected.to_string(), "not connected");

        let cases = [
            (
                McpClientError::Transport("pipe broken".into()),
                "pipe broken",
            ),
            (McpClientError::ConnectionFailed("refused".into()), "refused"),
            (McpClientError::Protocol("bad json".into()), "bad json"),
            (McpClientError::Server("timeout".into()), "timeout"),
        ];
        for (err, needle) in cases {
            assert!(err.to_string().contains(needle));
        }
    }

    /// A freshly created manager is empty in every respect.
    #[test]
    fn live_mcp_manager_defaults() {
        let manager = LiveMcpManager::new();

        assert_eq!(manager.total_count(), 0);
        assert_eq!(manager.alive_count(), 0);
        assert!(manager.list().is_empty());
        assert!(manager.all_tools().is_empty());
    }

    #[tokio::test]
    async fn connect_stdio_non_mcp_fails() {
        // `false` exits immediately with code 1, so the transport closes
        // before any initialize response arrives and the handshake fails.
        let no_args: &[String] = &[];
        let no_env = HashMap::new();

        let result = LiveMcpConnection::connect_stdio("test-false", "false", no_args, &no_env).await;

        assert!(
            result.is_err(),
            "`false` doesn't speak MCP — expected an error, got: {:?}",
            result
        );
    }

    /// End-to-end over the in-memory transport: discovery finds the `echo`
    /// tool and calling it returns the text that was sent.
    #[tokio::test]
    async fn in_memory_connection_discovers_tools_and_calls_remote_server() {
        let (conn, server_handle) = test_support::echo_connection("remote-test").await.unwrap();

        assert!(conn.is_alive());
        let tools = conn.tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "echo");

        let arguments = serde_json::json!({ "text": "hello over http" });
        let result = conn.call_tool("echo", arguments).await.unwrap();
        assert_eq!(result["content"], "hello over http");
        assert_eq!(result["is_error"], false);

        // Stop the server task and wait for the cancellation to land.
        server_handle.abort();
        let _ = server_handle.await;
    }
}