use std::collections::HashMap;
use std::sync::Arc;
use rmcp::model::{CallToolRequestParams, RawContent};
use rmcp::transport::{
StreamableHttpClientTransport, TokioChildProcess,
streamable_http_client::StreamableHttpClientTransportConfig,
};
use rmcp::{Peer, RoleClient, ServiceExt};
use serde_json::Value;
use tokio::process::Command;
use tracing::{debug, info, warn};
use roboticus_core::config::{McpServerConfig, McpServerSpec};
/// Errors produced while connecting to, or exchanging requests with, an MCP server.
#[derive(Debug, thiserror::Error)]
pub enum McpClientError {
    /// Establishing the session failed (e.g. the MCP handshake, or reading the
    /// configured auth token).
    #[error("connection failed: {0}")]
    ConnectionFailed(String),
    /// The session's transport has already closed.
    #[error("not connected")]
    NotConnected,
    /// The underlying transport could not be created.
    #[error("transport error: {0}")]
    Transport(String),
    /// An MCP protocol-level exchange failed (e.g. listing the server's tools).
    #[error("protocol error: {0}")]
    Protocol(String),
    /// A request sent to the server failed (e.g. a tool call).
    #[error("server error: {0}")]
    Server(String),
}
/// A tool advertised by an MCP server, captured during tool discovery.
#[derive(Debug, Clone)]
pub struct DiscoveredTool {
    /// Tool name as reported by the server; the identifier passed to
    /// `LiveMcpConnection::call_tool`.
    pub name: String,
    /// Human-readable description; empty when the server supplied none.
    pub description: String,
    /// The tool's input schema as a JSON value.
    pub input_schema: Value,
}
/// An established client session with a single MCP server, plus the tools it
/// advertised when the session was opened.
pub struct LiveMcpConnection {
    /// Local name for this connection (from config); `LiveMcpManager` keys on it.
    name: String,
    /// Tools discovered right after connecting.
    tools: Vec<DiscoveredTool>,
    /// Server name from the handshake info, or "unknown" when unavailable.
    server_name: String,
    /// Server version from the handshake info, or "" when unavailable.
    server_version: String,
    // Type-erased running service. It is never read; it is only held so the
    // service lives as long as this connection (presumably dropping it tears
    // the session down — confirm against rmcp). Fields drop in declaration
    // order, so this is dropped before `peer`; keep that in mind if reordering.
    _handle: Box<dyn std::any::Any + Send + Sync>,
    /// Request handle used for all traffic to the server.
    peer: Arc<Peer<RoleClient>>,
}
impl LiveMcpConnection {
    /// Wraps an initialized rmcp client service into a `LiveMcpConnection`.
    ///
    /// `service` is stored type-erased and never read again; it is held only so
    /// that it lives exactly as long as this connection. The server name and
    /// version come from `peer.peer_info()`, falling back to `"unknown"` and `""`
    /// when that returns `None`. The tool list starts empty — call
    /// [`Self::discover_tools`] to populate it.
    ///
    /// # Errors
    ///
    /// Currently never returns `Err`.
    fn finalize_connection<T>(
        name: &str,
        service: T,
        peer: Arc<Peer<RoleClient>>,
    ) -> Result<Self, McpClientError>
    where
        T: Send + Sync + 'static,
    {
        let (server_name, server_version) = peer
            .peer_info()
            .map(|info| {
                (
                    info.server_info.name.clone(),
                    info.server_info.version.clone(),
                )
            })
            .unwrap_or_else(|| ("unknown".into(), "".into()));
        Ok(Self {
            name: name.to_string(),
            tools: Vec::new(),
            server_name,
            server_version,
            _handle: Box::new(service),
            peer,
        })
    }

    /// Replaces `self.tools` with the server's tool list (`list_all_tools`) and
    /// logs the connection at `info` level.
    ///
    /// # Errors
    ///
    /// Returns [`McpClientError::Protocol`] if listing tools fails. In that case
    /// `self`, including its service handle, is dropped.
    async fn discover_tools(mut self) -> Result<Self, McpClientError> {
        let listed = self
            .peer
            .list_all_tools()
            .await
            .map_err(|e| McpClientError::Protocol(e.to_string()))?;
        self.tools = listed
            .into_iter()
            .map(|t| DiscoveredTool {
                name: t.name.to_string(),
                // Borrow the optional description rather than cloning it first.
                description: t.description.as_deref().unwrap_or_default().to_string(),
                input_schema: t.schema_as_json_value(),
            })
            .collect();
        info!(
            name = self.name,
            server_name = self.server_name,
            tool_count = self.tools.len(),
            "MCP server connected"
        );
        Ok(self)
    }

    /// Reads the auth token for `config` from the environment variable named by
    /// `config.auth_token_env`, if one is configured.
    ///
    /// Returns `Ok(None)` when no variable is configured.
    ///
    /// # Errors
    ///
    /// Returns [`McpClientError::ConnectionFailed`] when the variable is unset or
    /// its value is not valid Unicode. The message names the variable but never
    /// includes its value.
    fn resolve_auth_header(config: &McpServerConfig) -> Result<Option<String>, McpClientError> {
        match &config.auth_token_env {
            None => Ok(None),
            Some(var) => std::env::var(var).map(Some).map_err(|e| {
                // Do not format `VarError` itself: its Display for `NotUnicode`
                // embeds the raw value, which here is a secret token, and this
                // error is later logged by `LiveMcpManager::connect_all`.
                let reason = match e {
                    std::env::VarError::NotPresent => "environment variable not found",
                    std::env::VarError::NotUnicode(_) => {
                        "environment variable was not valid unicode"
                    }
                };
                McpClientError::ConnectionFailed(format!(
                    "failed to read auth token env var '{var}' for MCP server '{}': {reason}",
                    config.name
                ))
            }),
        }
    }

    /// Spawns `command` with `args` and talks MCP to it over the child's stdio.
    ///
    /// Entries in `env` are added on top of the environment the child inherits.
    /// After the handshake, the server's tools are discovered.
    ///
    /// # Errors
    ///
    /// - [`McpClientError::Transport`] if the child-process transport cannot be
    ///   created.
    /// - [`McpClientError::ConnectionFailed`] if the MCP handshake fails.
    /// - [`McpClientError::Protocol`] if tool discovery fails.
    pub async fn connect_stdio(
        name: &str,
        command: &str,
        args: &[String],
        env: &HashMap<String, String>,
    ) -> Result<Self, McpClientError> {
        let mut cmd = Command::new(command);
        cmd.args(args).envs(env);
        let transport =
            TokioChildProcess::new(cmd).map_err(|e| McpClientError::Transport(e.to_string()))?;
        info!(name, command, "connecting to MCP server via STDIO");
        let service = ()
            .serve(transport)
            .await
            .map_err(|e| McpClientError::ConnectionFailed(e.to_string()))?;
        let peer = Arc::new(service.peer().clone());
        Self::finalize_connection(name, service, peer)?
            .discover_tools()
            .await
    }

    /// Connects to a remote MCP server at `url`, then discovers its tools.
    ///
    /// Despite the `sse` name, this uses rmcp's Streamable HTTP client transport.
    /// If `config.auth_token_env` is set, the token read from that variable is
    /// passed to the transport as its auth header.
    ///
    /// # Errors
    ///
    /// - [`McpClientError::ConnectionFailed`] if the auth token cannot be read or
    ///   the MCP handshake fails.
    /// - [`McpClientError::Protocol`] if tool discovery fails.
    pub async fn connect_sse(config: &McpServerConfig, url: &str) -> Result<Self, McpClientError> {
        let mut transport_config = StreamableHttpClientTransportConfig::with_uri(url.to_string());
        if let Some(auth_header) = Self::resolve_auth_header(config)? {
            transport_config = transport_config.auth_header(auth_header);
        }
        let transport = StreamableHttpClientTransport::from_config(transport_config);
        info!(
            name = config.name,
            url, "connecting to MCP server via remote HTTP"
        );
        let service = ()
            .serve(transport)
            .await
            .map_err(|e| McpClientError::ConnectionFailed(e.to_string()))?;
        let peer = Arc::new(service.peer().clone());
        Self::finalize_connection(&config.name, service, peer)?
            .discover_tools()
            .await
    }

    /// Connects using whichever transport `config.spec` describes.
    ///
    /// # Errors
    ///
    /// Propagates the errors of [`Self::connect_stdio`] or [`Self::connect_sse`].
    pub async fn connect(config: &McpServerConfig) -> Result<Self, McpClientError> {
        match &config.spec {
            McpServerSpec::Stdio { command, args, env } => {
                Self::connect_stdio(&config.name, command, args, env).await
            }
            McpServerSpec::Sse { url } => Self::connect_sse(config, url).await,
        }
    }

    /// Local name of this connection.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Tools discovered when the connection was established.
    pub fn tools(&self) -> &[DiscoveredTool] {
        &self.tools
    }

    /// Server name from the handshake, or `"unknown"`.
    pub fn server_name(&self) -> &str {
        &self.server_name
    }

    /// Server version from the handshake, or `""`.
    pub fn server_version(&self) -> &str {
        &self.server_version
    }

    /// `true` while the underlying transport has not closed.
    pub fn is_alive(&self) -> bool {
        !self.peer.is_transport_closed()
    }

    /// Calls `tool_name` on the server and flattens its text output.
    ///
    /// `arguments` must be a JSON object, which is sent as the tool arguments, or
    /// `null`, which sends no arguments. The returned value has the shape
    /// `{"content": <all text parts joined by "\n">, "is_error": <bool>}`.
    /// Non-text content items in the result are dropped.
    ///
    /// # Errors
    ///
    /// - [`McpClientError::Protocol`] if `arguments` is neither an object nor
    ///   `null`. The tool is not called.
    /// - [`McpClientError::Server`] if the request fails.
    pub async fn call_tool(
        &self,
        tool_name: &str,
        arguments: Value,
    ) -> Result<Value, McpClientError> {
        debug!(name = self.name, tool_name, "calling MCP tool");
        // MCP tool arguments are a JSON object. Reject other shapes instead of
        // silently calling the tool with no arguments at all.
        let arguments = match arguments {
            Value::Object(map) => Some(map),
            Value::Null => None,
            other => {
                return Err(McpClientError::Protocol(format!(
                    "arguments for MCP tool '{tool_name}' must be a JSON object or null, got {}",
                    Self::json_kind(&other)
                )));
            }
        };
        let params = CallToolRequestParams {
            meta: None,
            name: tool_name.to_string().into(),
            arguments,
            task: None,
        };
        let result = self
            .peer
            .call_tool(params)
            .await
            .map_err(|e| McpClientError::Server(e.to_string()))?;
        let text_parts: Vec<String> = result
            .content
            .iter()
            .filter_map(|c| {
                if let RawContent::Text(t) = &c.raw {
                    Some(t.text.clone())
                } else {
                    None
                }
            })
            .collect();
        Ok(serde_json::json!({
            "content": text_parts.join("\n"),
            "is_error": result.is_error.unwrap_or(false),
        }))
    }

    /// Human-readable JSON type name, used in argument-validation errors.
    fn json_kind(value: &Value) -> &'static str {
        match value {
            Value::Null => "null",
            Value::Bool(_) => "a boolean",
            Value::Number(_) => "a number",
            Value::String(_) => "a string",
            Value::Array(_) => "an array",
            Value::Object(_) => "an object",
        }
    }

    /// Checks that the transport is still open.
    ///
    /// This is a local check only. No request is sent, so an open transport to an
    /// unresponsive server still yields `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns [`McpClientError::NotConnected`] once the transport has closed.
    pub async fn ping(&self) -> Result<(), McpClientError> {
        if self.is_alive() {
            Ok(())
        } else {
            Err(McpClientError::NotConnected)
        }
    }
}
impl std::fmt::Debug for LiveMcpConnection {
    /// Compact summary: identity, tool count and liveness. The tool schemas and
    /// the type-erased service handle are deliberately left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tool_count = self.tools.len();
        let alive = self.is_alive();
        let mut summary = f.debug_struct("LiveMcpConnection");
        summary.field("name", &self.name);
        summary.field("server_name", &self.server_name);
        summary.field("tool_count", &tool_count);
        summary.field("alive", &alive);
        summary.finish()
    }
}
/// Registry of live MCP connections, keyed by connection name.
#[derive(Debug, Default)]
pub struct LiveMcpManager {
    connections: HashMap<String, LiveMcpConnection>,
}
impl LiveMcpManager {
    /// Creates an empty manager.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `conn` under its name. An existing connection with the same
    /// name is replaced and dropped.
    pub fn add(&mut self, conn: LiveMcpConnection) {
        self.connections.insert(conn.name().to_string(), conn);
    }

    /// Removes and returns the connection registered as `name`, if any.
    pub fn remove(&mut self, name: &str) -> Option<LiveMcpConnection> {
        self.connections.remove(name)
    }

    /// Looks up the connection registered as `name`.
    pub fn get(&self, name: &str) -> Option<&LiveMcpConnection> {
        self.connections.get(name)
    }

    /// All registered connections, alive or not, sorted by name.
    ///
    /// The sort matters because `HashMap` iteration order is randomized per
    /// process. A fixed order keeps listings reproducible across runs.
    pub fn list(&self) -> Vec<&LiveMcpConnection> {
        let mut conns: Vec<&LiveMcpConnection> = self.connections.values().collect();
        // Each name is its own map key, so names are unique and an unstable sort
        // is still deterministic.
        conns.sort_unstable_by(|a, b| a.name().cmp(b.name()));
        conns
    }

    /// Number of registered connections whose transport is still open.
    pub fn alive_count(&self) -> usize {
        self.connections.values().filter(|c| c.is_alive()).count()
    }

    /// Number of registered connections, alive or not.
    pub fn total_count(&self) -> usize {
        self.connections.len()
    }

    /// Every tool from every alive connection, as `(connection name, tool)` pairs.
    ///
    /// Connections are visited in name order (see [`Self::list`]). Within a
    /// connection, tools keep the order they were discovered in, so the result is
    /// deterministic.
    pub fn all_tools(&self) -> Vec<(&str, &DiscoveredTool)> {
        self.list()
            .into_iter()
            .filter(|c| c.is_alive())
            .flat_map(|c| c.tools().iter().map(move |t| (c.name(), t)))
            .collect()
    }

    /// Connects to each enabled server in `configs` and registers it.
    ///
    /// Servers are connected one at a time in the given order. Disabled entries
    /// are skipped (logged at `debug`). Connection failures are logged at `warn`
    /// and skipped; they are not returned.
    ///
    /// NOTE(review): no timeout is applied here, so a server that never finishes
    /// its handshake blocks this call. Confirm whether rmcp enforces one.
    pub async fn connect_all(&mut self, configs: &[McpServerConfig]) {
        for cfg in configs {
            if !cfg.enabled {
                debug!(name = cfg.name, "skipping disabled MCP server");
                continue;
            }
            match LiveMcpConnection::connect(cfg).await {
                Ok(conn) => self.add(conn),
                Err(e) => warn!(name = cfg.name, error = %e, "failed to connect to MCP server"),
            }
        }
    }
}
/// In-process MCP server fixture shared by tests (`pub(crate)`).
///
/// `echo_connection` wires a `LiveMcpConnection` to a tiny echo server over an
/// in-memory pipe, so tests need no external process or network.
#[cfg(test)]
pub(crate) mod test_support {
    use std::sync::Arc;
    use rmcp::{
        ServerHandler, ServiceExt,
        handler::server::{router::tool::ToolRouter, wrapper::Parameters},
        model::{ServerCapabilities, ServerInfo},
        schemars, tool, tool_handler, tool_router,
    };
    use super::{LiveMcpConnection, McpClientError};

    // Minimal MCP server exposing a single `echo` tool.
    // Plain `//` comments are used on the macro-annotated items in this module:
    // rmcp/schemars derives may read `///` docs (e.g. into schema descriptions).
    #[derive(Debug, Clone)]
    struct TestInMemoryMcpServer {
        // Tool dispatch table from `Self::tool_router()`, which is presumably
        // generated by the `#[tool_router]` macro below.
        tool_router: ToolRouter<Self>,
    }
    impl TestInMemoryMcpServer {
        fn new() -> Self {
            Self {
                tool_router: Self::tool_router(),
            }
        }
    }
    // Input of the `echo` tool; its JSON schema comes from the `JsonSchema` derive.
    #[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
    struct EchoRequest {
        text: String,
    }
    #[tool_router]
    impl TestInMemoryMcpServer {
        // Returns the request's `text` unchanged.
        #[tool(description = "Echo back the provided text")]
        async fn echo(&self, params: Parameters<EchoRequest>) -> String {
            params.0.text
        }
    }
    #[tool_handler(router = self.tool_router)]
    impl ServerHandler for TestInMemoryMcpServer {
        // Advertise tool support; every other ServerInfo field is defaulted.
        fn get_info(&self) -> ServerInfo {
            ServerInfo {
                capabilities: ServerCapabilities::builder().enable_tools().build(),
                ..Default::default()
            }
        }
    }
    /// Starts an echo server on a spawned task and connects a `LiveMcpConnection`
    /// named `name` to it over an in-memory duplex pipe (4096-byte buffer).
    ///
    /// Returns the connection, with tools already discovered, and the server
    /// task's `JoinHandle`. The server task panics if the server fails to start
    /// or finishes with an error.
    ///
    /// # Errors
    ///
    /// Client handshake failures become `McpClientError::ConnectionFailed`. Tool
    /// discovery errors are propagated as-is.
    pub(crate) async fn echo_connection(
        name: &str,
    ) -> Result<(LiveMcpConnection, tokio::task::JoinHandle<()>), McpClientError> {
        // One end for the server task, the other for the client below.
        let (server_transport, client_transport) = tokio::io::duplex(4096);
        let server_handle = tokio::spawn(async move {
            let server = TestInMemoryMcpServer::new()
                .serve(server_transport)
                .await
                .expect("test MCP server should start");
            server
                .waiting()
                .await
                .expect("test MCP server should complete");
        });
        // NOTE(review): on the error returns below, `server_handle` is dropped
        // without being aborted, so the spawned server task is left detached.
        let service = ()
            .serve(client_transport)
            .await
            .map_err(|e| McpClientError::ConnectionFailed(e.to_string()))?;
        let peer = Arc::new(service.peer().clone());
        let conn = LiveMcpConnection::finalize_connection(name, service, peer)?
            .discover_tools()
            .await?;
        Ok((conn, server_handle))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `DiscoveredTool` is plain data: its fields read back as constructed.
    #[test]
    fn discovered_tool_fields() {
        let tool = DiscoveredTool {
            name: String::from("test_tool"),
            description: String::from("A test tool"),
            input_schema: serde_json::json!({"type": "object"}),
        };
        let DiscoveredTool {
            name, description, ..
        } = &tool;
        assert_eq!(name, "test_tool");
        assert_eq!(description, "A test tool");
    }

    /// Every error variant renders its payload (or fixed text) via Display.
    #[test]
    fn mcp_client_error_display() {
        assert_eq!(McpClientError::NotConnected.to_string(), "not connected");
        let cases = [
            (McpClientError::Transport("pipe broken".into()), "pipe broken"),
            (McpClientError::ConnectionFailed("refused".into()), "refused"),
            (McpClientError::Protocol("bad json".into()), "bad json"),
            (McpClientError::Server("timeout".into()), "timeout"),
        ];
        for (err, needle) in cases {
            let rendered = err.to_string();
            assert!(
                rendered.contains(needle),
                "{rendered:?} should mention {needle:?}"
            );
        }
    }

    /// A fresh manager is empty in every view.
    #[test]
    fn live_mcp_manager_defaults() {
        let manager = LiveMcpManager::new();
        assert_eq!((manager.total_count(), manager.alive_count()), (0, 0));
        assert!(manager.list().is_empty());
        assert!(manager.all_tools().is_empty());
    }

    /// A process that exits without speaking MCP must not yield a connection.
    #[tokio::test]
    async fn connect_stdio_non_mcp_fails() {
        let no_env = HashMap::new();
        let outcome =
            LiveMcpConnection::connect_stdio("test-false", "false", &[], &no_env).await;
        assert!(
            outcome.is_err(),
            "`false` doesn't speak MCP — expected an error, got: {:?}",
            outcome
        );
    }

    /// End-to-end against the in-memory echo server: discovery, then a tool call.
    #[tokio::test]
    async fn in_memory_connection_discovers_tools_and_calls_remote_server() {
        let (conn, server_task) = test_support::echo_connection("remote-test")
            .await
            .expect("in-memory echo connection should come up");
        assert!(conn.is_alive());
        let tool_names: Vec<&str> = conn.tools().iter().map(|t| t.name.as_str()).collect();
        assert_eq!(tool_names, ["echo"]);

        let reply = conn
            .call_tool("echo", serde_json::json!({ "text": "hello over http" }))
            .await
            .expect("echo call should succeed");
        assert_eq!(reply["content"], "hello over http");
        assert_eq!(reply["is_error"], false);

        server_task.abort();
        let _ = server_task.await;
    }
}