use std::collections::HashMap;
use std::sync::Arc;
use futures_util::future::join_all;
use serde_json::Value;
use synwire_core::mcp::traits::{McpServerStatus, McpTransport};
use tokio::sync::RwLock;
use crate::callbacks::McpCallbacks;
use crate::error::McpAdapterError;
use crate::session::McpClientSession;
/// How to reach a single MCP server; turned into a transport by `Connection::into_transport`.
///
/// `Debug` is implemented by hand so that secrets never end up in logs: `auth_token` values
/// are shown as `Some("<redacted>")` and `env` values of `Stdio` are replaced by
/// `"<redacted>"` (the variable names are still shown, sorted).
#[derive(Clone)]
#[non_exhaustive]
pub enum Connection {
    /// Server spoken to over stdio; `command`, `args` and `env` are handed to
    /// `StdioMcpTransport`.
    Stdio {
        command: String,
        args: Vec<String>,
        env: HashMap<String, String>,
    },
    /// Server-Sent-Events endpoint; currently served by the same `HttpMcpTransport` as
    /// `StreamableHttp`.
    Sse {
        url: String,
        auth_token: Option<String>,
        /// Optional timeout forwarded to the transport (seconds, per the field name).
        timeout_secs: Option<u64>,
    },
    /// Streamable-HTTP endpoint, served by `HttpMcpTransport`.
    StreamableHttp {
        url: String,
        auth_token: Option<String>,
        /// Optional timeout forwarded to the transport (seconds, per the field name).
        timeout_secs: Option<u64>,
    },
    /// WebSocket endpoint, served by `WebSocketMcpTransport`.
    WebSocket {
        url: String,
        auth_token: Option<String>,
    },
}

impl std::fmt::Debug for Connection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Placeholder printed instead of any secret value.
        const REDACTED: &str = "<redacted>";
        match self {
            Self::Stdio { command, args, env } => {
                // Environment values commonly carry API keys: keep the names (sorted, so the
                // output is deterministic) and hide the values.
                let env_redacted: std::collections::BTreeMap<&String, &str> =
                    env.keys().map(|key| (key, REDACTED)).collect();
                f.debug_struct("Stdio")
                    .field("command", command)
                    .field("args", args)
                    .field("env", &env_redacted)
                    .finish()
            }
            Self::Sse {
                url,
                auth_token,
                timeout_secs,
            } => f
                .debug_struct("Sse")
                .field("url", url)
                .field("auth_token", &auth_token.as_ref().map(|_| REDACTED))
                .field("timeout_secs", timeout_secs)
                .finish(),
            Self::StreamableHttp {
                url,
                auth_token,
                timeout_secs,
            } => f
                .debug_struct("StreamableHttp")
                .field("url", url)
                .field("auth_token", &auth_token.as_ref().map(|_| REDACTED))
                .field("timeout_secs", timeout_secs)
                .finish(),
            Self::WebSocket { url, auth_token } => f
                .debug_struct("WebSocket")
                .field("url", url)
                .field("auth_token", &auth_token.as_ref().map(|_| REDACTED))
                .finish(),
        }
    }
}
impl Connection {
    /// Consumes the connection description and builds the matching transport, labelled `name`.
    ///
    /// `Stdio` maps to `StdioMcpTransport`, `WebSocket` to `WebSocketMcpTransport`, and both
    /// `Sse` and `StreamableHttp` to `HttpMcpTransport`.
    ///
    /// # Errors
    ///
    /// For `Sse` / `StreamableHttp`, propagates the error from `HttpMcpTransport::try_new`
    /// (converted into `AgentError`). The other variants cannot fail here.
    pub fn into_transport(
        self,
        name: &str,
    ) -> Result<Box<dyn McpTransport>, synwire_core::agents::error::AgentError> {
        let boxed: Box<dyn McpTransport> = match self {
            Self::WebSocket { url, auth_token } => {
                Box::new(crate::transport::WebSocketMcpTransport::new(name, url, auth_token))
            }
            Self::Stdio { command, args, env } => {
                Box::new(synwire_agent::mcp::StdioMcpTransport::new(name, command, args, env))
            }
            // Both HTTP-flavoured variants share one transport implementation.
            Self::Sse {
                url,
                auth_token,
                timeout_secs,
            }
            | Self::StreamableHttp {
                url,
                auth_token,
                timeout_secs,
            } => {
                let http = synwire_agent::mcp::HttpMcpTransport::try_new(
                    name,
                    url,
                    auth_token,
                    timeout_secs,
                )?;
                Box::new(http)
            }
        };
        Ok(boxed)
    }
}
/// Per-server state kept by `MultiServerMcpClient` for every server that connected.
///
/// `Debug` is derived and prints both fields (it requires `McpClientSession: Debug`).
#[derive(Debug)]
struct ServerEntry {
    /// Live MCP session; provides the server's cached tool list and its transport.
    session: McpClientSession,
    /// Prefix prepended as `"{prefix}/"` to this server's tool names, if any.
    tool_name_prefix: Option<String>,
}
/// Configuration consumed by `MultiServerMcpClient::connect`.
///
/// Build it with `MultiServerMcpClientConfig::new()` and the `with_*` builder methods.
#[derive(Debug, Default)]
pub struct MultiServerMcpClientConfig {
/// Servers to connect to, keyed by the server name used in logs and tool routing.
pub servers: HashMap<String, Connection>,
/// Prefix applied to the tools of every server that has no entry in `server_prefixes`.
pub global_tool_prefix: Option<String>,
/// Per-server tool-name prefixes, keyed by server name; these take precedence over
/// `global_tool_prefix`.
pub server_prefixes: HashMap<String, String>,
}
impl MultiServerMcpClientConfig {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn with_server(mut self, name: impl Into<String>, connection: Connection) -> Self {
let _ = self.servers.insert(name.into(), connection);
self
}
#[must_use]
pub fn with_server_prefix(
mut self,
server_name: impl Into<String>,
prefix: impl Into<String>,
) -> Self {
let _ = self
.server_prefixes
.insert(server_name.into(), prefix.into());
self
}
#[must_use]
pub fn with_global_prefix(mut self, prefix: impl Into<String>) -> Self {
self.global_tool_prefix = Some(prefix.into());
self
}
}
/// Client that aggregates the tools of several MCP servers behind one interface.
///
/// Created with `MultiServerMcpClient::connect`; tools are listed with
/// `get_tool_descriptors` and invoked by their exposed name with `call_tool`.
pub struct MultiServerMcpClient {
/// Successfully connected servers keyed by server name (servers that failed to connect are
/// not present).
servers: Arc<RwLock<HashMap<String, ServerEntry>>>,
/// Callbacks supplied to `connect`, returned by `callbacks()`.
callbacks: Arc<McpCallbacks>,
}
impl std::fmt::Debug for MultiServerMcpClient {
    /// Prints only the callbacks; the server map sits behind an async lock, so it is omitted
    /// and the output is marked non-exhaustive (`..`).
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("MultiServerMcpClient");
        builder.field("callbacks", &self.callbacks);
        builder.finish_non_exhaustive()
    }
}
impl MultiServerMcpClient {
pub async fn connect(
config: MultiServerMcpClientConfig,
callbacks: McpCallbacks,
) -> Result<Self, McpAdapterError> {
let callbacks = Arc::new(callbacks);
let MultiServerMcpClientConfig {
servers,
server_prefixes,
global_tool_prefix,
} = config;
let connect_futures: Vec<_> = servers
.into_iter()
.map(|(name, conn)| {
let prefix = server_prefixes
.get(&name)
.cloned()
.or_else(|| global_tool_prefix.clone());
let transport_result = conn.into_transport(&name);
async move {
let transport: Arc<dyn McpTransport> = match transport_result {
Ok(t) => Arc::from(t),
Err(e) => {
tracing::error!(server = %name, error = %e, "Failed to build transport");
return None;
}
};
match McpClientSession::connect(name.clone(), transport).await {
Ok(mut session) => {
if let Err(e) = session.populate_tool_cache().await {
tracing::warn!(
server = %name,
error = %e,
"Failed to populate tool cache"
);
}
Some((
name,
ServerEntry {
session,
tool_name_prefix: prefix,
},
))
}
Err(e) => {
tracing::error!(
server = %name,
error = %e,
"Failed to connect to MCP server"
);
None
}
}
}
})
.collect();
let results = join_all(connect_futures).await;
let servers: HashMap<String, ServerEntry> = results.into_iter().flatten().collect();
tracing::info!(connected = servers.len(), "MultiServerMcpClient connected");
Ok(Self {
servers: Arc::new(RwLock::new(servers)),
callbacks,
})
}
pub async fn get_tool_descriptors(&self) -> Vec<AggregatedToolDescriptor> {
let servers = self.servers.read().await;
let mut tools = Vec::new();
for (server_name, entry) in servers.iter() {
for descriptor in entry.session.cached_tools() {
let exposed_name = entry.tool_name_prefix.as_ref().map_or_else(
|| descriptor.name.clone(),
|prefix| format!("{prefix}/{}", descriptor.name),
);
tools.push(AggregatedToolDescriptor {
exposed_name,
server_name: server_name.clone(),
original_name: descriptor.name.clone(),
description: descriptor.description.clone(),
input_schema: descriptor.input_schema.clone(),
});
}
}
drop(servers);
tools
}
#[allow(clippy::significant_drop_tightening)]
pub async fn health(&self) -> Vec<McpServerStatus> {
let servers = self.servers.read().await;
let status_futures: Vec<_> = servers
.values()
.map(|entry| entry.session.status())
.collect();
join_all(status_futures).await
}
pub async fn call_tool(
&self,
exposed_tool_name: &str,
arguments: Value,
) -> Result<Value, McpAdapterError> {
let (server_name, original_name, transport) = {
let servers = self.servers.read().await;
let routing = servers.iter().find_map(|(server_name, entry)| {
for descriptor in entry.session.cached_tools() {
let exposed = entry.tool_name_prefix.as_ref().map_or_else(
|| descriptor.name.clone(),
|prefix| format!("{prefix}/{}", descriptor.name),
);
if exposed == exposed_tool_name {
return Some((server_name.clone(), descriptor.name.clone()));
}
}
None
});
let (server_name, original_name) =
routing.ok_or_else(|| McpAdapterError::ToolNotFound {
name: exposed_tool_name.to_owned(),
})?;
let transport = servers
.get(&server_name)
.ok_or_else(|| McpAdapterError::ServerNotFound {
name: server_name.clone(),
})?
.session
.transport()
.clone();
drop(servers);
(server_name, original_name, transport)
};
transport
.call_tool(&original_name, arguments)
.await
.map_err(|e| McpAdapterError::Transport {
message: format!("Tool '{original_name}' on server '{server_name}' failed: {e}"),
})
}
#[must_use]
pub fn callbacks(&self) -> &McpCallbacks {
&self.callbacks
}
}
/// One tool offered by a server of a `MultiServerMcpClient`, as returned by
/// `MultiServerMcpClient::get_tool_descriptors`.
#[derive(Debug, Clone)]
pub struct AggregatedToolDescriptor {
/// Name to pass to `MultiServerMcpClient::call_tool`: `"{prefix}/{original_name}"` when the
/// owning server has a tool-name prefix, otherwise `original_name`.
pub exposed_name: String,
/// Name (config key) of the server that provides the tool.
pub server_name: String,
/// Tool name as reported by the server itself.
pub original_name: String,
/// Tool description copied from the server's cached descriptor.
pub description: String,
/// Input schema copied from the server's cached descriptor (presumably a JSON Schema
/// object, as in MCP `inputSchema` — confirm against the descriptor type).
pub input_schema: Value,
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::pagination::PaginationCursor;

    /// Every transport variant can be built from plain configuration values.
    #[test]
    fn connection_enum_variants_exist() {
        let variants = [
            Connection::Stdio {
                command: "mcp-server".into(),
                args: Vec::new(),
                env: HashMap::new(),
            },
            Connection::WebSocket {
                url: "ws://localhost:3000".into(),
                auth_token: None,
            },
            Connection::Sse {
                url: "http://localhost:3000/sse".into(),
                auth_token: None,
                timeout_secs: None,
            },
            Connection::StreamableHttp {
                url: "http://localhost:3000".into(),
                auth_token: None,
                timeout_secs: None,
            },
        ];
        assert_eq!(variants.len(), 4);
    }

    /// The builder methods populate the corresponding config fields.
    #[test]
    fn config_builder() {
        let websocket = Connection::WebSocket {
            url: "ws://localhost:3000".into(),
            auth_token: None,
        };
        let config = MultiServerMcpClientConfig::new()
            .with_server("s1", websocket)
            .with_server_prefix("s1", "srv1")
            .with_global_prefix("global");
        assert!(config.servers.contains_key("s1"));
        assert_eq!(
            config.server_prefixes.get("s1").map(String::as_str),
            Some("srv1")
        );
        assert_eq!(config.global_tool_prefix.as_deref(), Some("global"));
    }

    /// A cursor advances while a next-page token is present and stops once it is absent.
    #[test]
    fn pagination_used_in_client_context() {
        let mut cursor = PaginationCursor::new();
        let advanced_with_token = cursor.advance(Some("token1".into()));
        assert!(advanced_with_token);
        let advanced_without_token = cursor.advance(None);
        assert!(!advanced_without_token);
    }
}