pub mod client;
pub mod config;
pub mod error;
pub mod server;
pub mod sse_proxy;
pub mod transport;
pub use client::McpClient;
pub use config::Config;
pub use error::{Error, Result};
pub use server::{ServerId, ServerProcess, ServerStatus};
pub use sse_proxy::SSEProxyHandle;
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use transport::StdioTransport;
use sse_proxy::types::ServerInfo;
use sse_proxy::{SSEProxy, SSEProxyRunnerAccess};
/// Owns the lifecycle of the MCP servers described by a [`Config`].
///
/// The runner keeps track of the server processes it has started (indexed both
/// by [`ServerId`] and by configured name), the clients created for them, and
/// the handle of an optional SSE proxy.
pub struct McpRunner {
    /// Configuration: server definitions (`mcp_servers`) and optional SSE proxy settings.
    config: Config,

    /// Started server processes, keyed by their id.
    servers: HashMap<ServerId, ServerProcess>,
    /// Configured server name -> id of the process started for it.
    server_names: HashMap<String, ServerId>,

    /// Per-server client cache; a `None` entry records that creating a client failed.
    clients: HashMap<ServerId, Option<McpClient>>,

    /// Handle of the running SSE proxy, if one has been started.
    sse_proxy_handle: Option<SSEProxyHandle>,
}
impl McpRunner {
    /// Upper bound on how long [`McpRunner::get_all_server_tools`] waits for a
    /// single server's tool listing before recording a timeout for it.
    const TOOL_LISTING_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(15);

    /// Builds a runner from the configuration file at `path`.
    ///
    /// No server is started; use [`McpRunner::start_server`] or
    /// [`McpRunner::start_all_servers`] afterwards.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Config::from_file`.
    #[tracing::instrument(skip(path), fields(config_path = ?path.as_ref()))]
    pub fn from_config_file(path: impl AsRef<Path>) -> Result<Self> {
        tracing::info!("Loading configuration from file");
        let config = Config::from_file(path)?;
        Ok(Self::new(config))
    }

    /// Builds a runner from configuration text held in memory.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Config::parse_from_str`.
    #[tracing::instrument(skip(config))]
    pub fn from_config_str(config: &str) -> Result<Self> {
        tracing::info!("Loading configuration from string");
        let config = Config::parse_from_str(config)?;
        Ok(Self::new(config))
    }

    /// Creates a runner for `config` with no servers started, an empty
    /// client cache and no SSE proxy.
    #[tracing::instrument(skip(config), fields(num_servers = config.mcp_servers.len()))]
    pub fn new(config: Config) -> Self {
        tracing::info!("Creating new McpRunner");
        Self {
            config,
            servers: HashMap::new(),
            server_names: HashMap::new(),
            sse_proxy_handle: None,
            clients: HashMap::new(),
        }
    }

    /// Starts the configured server called `name` and returns its id.
    ///
    /// If a server with this name is already tracked, its existing id is
    /// returned and nothing new is spawned. When an SSE proxy is running it is
    /// informed about the new server on a best-effort basis: proxy failures are
    /// logged, never returned.
    ///
    /// # Errors
    ///
    /// Returns [`Error::ServerNotFound`] if `name` has no entry in the
    /// configuration, or the error reported while starting the process.
    #[tracing::instrument(skip(self), fields(server_name = %name))]
    pub async fn start_server(&mut self, name: &str) -> Result<ServerId> {
        if let Some(id) = self.server_names.get(name) {
            tracing::debug!(server_id = %id, "Server already running");
            return Ok(*id);
        }
        tracing::info!("Attempting to start server");
        let config = self
            .config
            .mcp_servers
            .get(name)
            .ok_or_else(|| {
                tracing::error!("Configuration not found for server");
                Error::ServerNotFound(name.to_string())
            })?
            .clone();
        let mut server = ServerProcess::new(name.to_string(), config);
        let id = server.id();
        tracing::debug!(server_id = %id, "Created ServerProcess instance");
        server.start().await.map_err(|e| {
            tracing::error!(error = %e, "Failed to start server process");
            e
        })?;
        tracing::debug!(server_id = %id, "Storing running server process");
        self.servers.insert(id, server);
        self.server_names.insert(name.to_string(), id);
        if let Some(proxy) = &self.sse_proxy_handle {
            let status = format!("{:?}", ServerStatus::Running);
            if let Err(e) = proxy.update_server_info(name, Some(id), &status).await {
                tracing::warn!(
                    error = %e,
                    server = %name,
                    "Failed to update server info in SSE proxy"
                );
                // Updating failed; fall back to inserting a fresh cache entry.
                let server_info = ServerInfo {
                    name: name.to_string(),
                    id: format!("{:?}", id),
                    status: status.clone(),
                };
                if let Err(e) = proxy.add_server_info(name, server_info.clone()).await {
                    tracing::warn!(
                        error = %e,
                        server = %name,
                        "Failed to add server to SSE proxy cache"
                    );
                } else {
                    tracing::debug!(server = %name, "Added new server to SSE proxy cache");
                }
            } else {
                tracing::debug!(server = %name, "Updated SSE proxy with new server information");
            }
        }
        tracing::info!(server_id = %id, "Server started successfully");
        Ok(id)
    }

    /// Starts every server listed in the configuration.
    ///
    /// Every server is attempted even if an earlier one fails. Servers that did
    /// start stay running (and tracked) when an error is returned.
    ///
    /// # Errors
    ///
    /// With exactly one failure, that server's error is returned unchanged.
    /// With several, they are combined into one [`Error::Other`].
    #[tracing::instrument(skip(self))]
    pub async fn start_all_servers(&mut self) -> Result<Vec<ServerId>> {
        tracing::info!("Starting all configured servers");
        let server_names: Vec<String> = self
            .config
            .mcp_servers
            .keys()
            .map(|k| k.to_string())
            .collect();
        tracing::debug!(servers_to_start = ?server_names);
        let mut ids = Vec::new();
        let mut errors: Vec<(String, Error)> = Vec::new();
        for name in server_names {
            match self.start_server(&name).await {
                Ok(id) => ids.push(id),
                Err(e) => {
                    tracing::error!(server_name = %name, error = %e, "Failed to start server");
                    errors.push((name, e));
                }
            }
        }
        if !errors.is_empty() {
            tracing::warn!(
                num_failed = errors.len(),
                "Some servers failed to start: {:?}",
                errors
                    .iter()
                    .map(|(name, _): &(String, Error)| name.as_str())
                    .collect::<Vec<_>>()
            );
            return Err(Self::combine_errors(errors, "start"));
        }
        tracing::info!(num_started = ids.len(), "Finished starting all servers");
        Ok(ids)
    }

    /// Starts all servers and then, if every server started and an SSE proxy
    /// is configured, the proxy as well.
    ///
    /// Returns the result of [`McpRunner::start_all_servers`] together with a
    /// flag telling whether the proxy is running after this call. A proxy
    /// start failure is logged and reported only through that flag.
    #[tracing::instrument(skip(self))]
    pub async fn start_all_with_proxy(&mut self) -> (Result<Vec<ServerId>>, bool) {
        let server_result = self.start_all_servers().await;
        let proxy_started = if server_result.is_ok() && self.is_sse_proxy_configured() {
            match self.start_sse_proxy().await {
                Ok(_) => {
                    tracing::info!("SSE proxy started automatically");
                    true
                }
                Err(e) => {
                    tracing::warn!(error = %e, "Failed to start SSE proxy");
                    false
                }
            }
        } else {
            if self.is_sse_proxy_configured() {
                tracing::warn!("Not starting SSE proxy because servers failed to start");
            }
            false
        };
        (server_result, proxy_started)
    }

    /// Stops the server with the given id and forgets about it.
    ///
    /// The server is removed from tracking (and its cached client dropped)
    /// before the process is stopped. It is therefore no longer tracked even if
    /// stopping fails. On success a running SSE proxy is told the server is
    /// `"Stopped"` (best effort).
    ///
    /// # Errors
    ///
    /// Returns [`Error::ServerNotFound`] if no tracked server has this id, or
    /// the error reported while stopping the process.
    #[tracing::instrument(skip(self), fields(server_id = %id))]
    pub async fn stop_server(&mut self, id: ServerId) -> Result<()> {
        tracing::info!("Attempting to stop server");
        if let Some(mut server) = self.servers.remove(&id) {
            let name = server.name().to_string();
            tracing::debug!(server_name = %name, "Found server process to stop");
            self.server_names.remove(&name);
            // The cached client (or the "creation failed" marker) belongs to the
            // process being stopped; drop it so the cache never keeps entries for
            // servers that are no longer tracked.
            self.clients.remove(&id);
            server.stop().await.map_err(|e| {
                tracing::error!(error = %e, "Failed to stop server process");
                e
            })?;
            if let Some(proxy) = &self.sse_proxy_handle {
                if let Err(e) = proxy.update_server_info(&name, None, "Stopped").await {
                    tracing::warn!(
                        error = %e,
                        server = %name,
                        "Failed to update SSE proxy with server stopped status"
                    );
                } else {
                    tracing::debug!(server = %name, "Updated SSE proxy with server stopped status");
                }
            }
            tracing::info!("Server stopped successfully");
            Ok(())
        } else {
            tracing::warn!("Attempted to stop a server that was not found or not running");
            Err(Error::ServerNotFound(format!("{:?}", id)))
        }
    }

    /// Shuts down the SSE proxy (if running) and then stops every tracked server.
    ///
    /// Proxy shutdown errors are logged only. The proxy handle is taken before
    /// the servers are stopped, so `stop_server` finds no proxy to update.
    ///
    /// # Errors
    ///
    /// With exactly one failed server, that error is returned unchanged. With
    /// several, they are combined into one [`Error::Other`].
    #[tracing::instrument(skip(self))]
    pub async fn stop_all_servers(&mut self) -> Result<()> {
        tracing::info!("Stopping all servers and proxy if running");
        let server_ids: Vec<ServerId> = self.servers.keys().copied().collect();
        if let Some(proxy_handle) = self.sse_proxy_handle.take() {
            tracing::info!("Stopping SSE proxy");
            if let Err(e) = proxy_handle.shutdown().await {
                tracing::warn!(error = %e, "Error shutting down SSE proxy");
            }
            tracing::info!("SSE proxy stopped");
        }
        let mut errors: Vec<(String, Error)> = Vec::new();
        for id in server_ids {
            if let Err(e) = self.stop_server(id).await {
                tracing::error!(server_id = ?id, error = %e, "Failed to stop server");
                errors.push((format!("{:?}", id), e));
            }
        }
        if errors.is_empty() {
            tracing::info!("All servers stopped successfully");
            Ok(())
        } else {
            tracing::warn!(error_count = errors.len(), "Some servers failed to stop");
            Err(Self::combine_errors(errors, "stop"))
        }
    }

    /// Folds per-server failures (`label`, error) into a single error.
    ///
    /// A lone failure is returned unchanged so callers can still match on its
    /// variant. Several failures become one [`Error::Other`] whose message is
    /// `"Multiple servers failed to <action>: <label>: <err>; ..."`.
    /// Callers only invoke this with a non-empty list.
    fn combine_errors(mut errors: Vec<(String, Error)>, action: &str) -> Error {
        if errors.len() == 1 {
            return errors.remove(0).1;
        }
        let details = errors
            .iter()
            .map(|(label, e)| format!("{}: {}", label, e))
            .collect::<Vec<_>>()
            .join("; ");
        Error::Other(format!("Multiple servers failed to {}: {}", action, details))
    }

    /// Returns the current status of the tracked server `id`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::ServerNotFound`] if no tracked server has this id.
    #[tracing::instrument(skip(self), fields(server_id = %id))]
    pub fn server_status(&self, id: ServerId) -> Result<ServerStatus> {
        tracing::debug!("Getting server status");
        self.servers
            .get(&id)
            .map(|server| {
                let status = server.status();
                tracing::trace!(status = ?status);
                status
            })
            .ok_or_else(|| {
                tracing::warn!("Status requested for unknown server");
                Error::ServerNotFound(format!("{:?}", id))
            })
    }

    /// Looks up the id of the started server registered under `name`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::ServerNotFound`] if no started server has that name.
    #[tracing::instrument(skip(self), fields(server_name = %name))]
    pub fn get_server_id(&self, name: &str) -> Result<ServerId> {
        tracing::debug!("Getting server ID by name");
        self.server_names.get(name).copied().ok_or_else(|| {
            tracing::warn!("Server ID requested for unknown server name");
            Error::ServerNotFound(name.to_string())
        })
    }

    /// Creates a client that talks to server `id` over its stdin/stdout and
    /// caches a copy of it.
    ///
    /// The process pipes are taken out of the `ServerProcess`, so this succeeds
    /// at most once per server. Afterwards the cached copy is the one to use.
    ///
    /// # Errors
    ///
    /// * [`Error::ClientAlreadyCached`] if a client was already created for `id`.
    /// * [`Error::ServerNotFound`] if a previous attempt failed, or `id` is not
    ///   a tracked server.
    /// * The error from taking stdin/stdout; `id` is then marked as failed.
    #[tracing::instrument(skip(self), fields(server_id = %id))]
    pub fn get_client(&mut self, id: ServerId) -> Result<McpClient> {
        tracing::info!("Getting client for server");
        match self.clients.get(&id) {
            Some(Some(_)) => {
                tracing::debug!("Client already exists in cache");
                return Err(Error::ClientAlreadyCached);
            }
            Some(None) => {
                tracing::warn!("Previously failed to create client for this server");
                return Err(Error::ServerNotFound(format!(
                    "{:?} (client creation previously failed)",
                    id
                )));
            }
            None => {}
        }
        let server = self.servers.get_mut(&id).ok_or_else(|| {
            tracing::error!("Client requested for unknown or stopped server");
            Error::ServerNotFound(format!("{:?}", id))
        })?;
        let server_name = server.name().to_string();
        tracing::debug!(server_name = %server_name, "Found server process");
        tracing::debug!("Taking stdin/stdout from server process");
        let stdin = match server.take_stdin() {
            Ok(stdin) => stdin,
            Err(e) => {
                tracing::error!(error = %e, "Failed to take stdin from server");
                self.clients.insert(id, None);
                return Err(e);
            }
        };
        // NOTE(review): if this fails, the stdin taken above is dropped here.
        let stdout = match server.take_stdout() {
            Ok(stdout) => stdout,
            Err(e) => {
                tracing::error!(error = %e, "Failed to take stdout from server");
                self.clients.insert(id, None);
                return Err(e);
            }
        };
        tracing::debug!("Creating StdioTransport and McpClient");
        let transport = StdioTransport::new(server_name.clone(), stdin, stdout);
        let client = McpClient::new(server_name, transport);
        self.clients.insert(id, Some(client.clone()));
        tracing::info!("Client created successfully");
        Ok(client)
    }

    /// Starts the SSE proxy described by the configuration and stores its handle.
    ///
    /// Calling this while a proxy is already running is a no-op (`Ok(())`).
    /// Without the guard, the live handle would be overwritten without being
    /// shut down.
    ///
    /// NOTE(review): the closures handed to the proxy capture a *clone* of the
    /// runner taken at this moment. Servers started or stopped later are not
    /// visible through `get_server_id` / `get_client`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Other`] if no SSE proxy is configured, or the error
    /// from `SSEProxy::start_proxy`.
    #[tracing::instrument(skip(self))]
    pub async fn start_sse_proxy(&mut self) -> Result<()> {
        if let Some(proxy_config) = &self.config.sse_proxy {
            if self.sse_proxy_handle.is_some() {
                tracing::warn!("SSE proxy already running, not starting another instance");
                return Ok(());
            }
            tracing::info!("Initializing SSE proxy server");
            let runner_access = SSEProxyRunnerAccess {
                get_server_id: Arc::new({
                    let self_clone = self.clone();
                    move |name: &str| self_clone.get_server_id(name)
                }),
                get_client: Arc::new({
                    let self_clone = self.clone();
                    move |id: ServerId| {
                        let servers = &self_clone.servers;
                        if let Some(server) = servers.get(&id) {
                            let server_name = server.name().to_string();
                            match McpClient::connect(&server_name, &self_clone.config) {
                                Ok(client) => Ok(client),
                                Err(e) => {
                                    tracing::error!(error = %e, server_id = ?id, "Failed to create client for SSE proxy");
                                    Err(e)
                                }
                            }
                        } else {
                            Err(Error::ServerNotFound(format!("{:?}", id)))
                        }
                    }
                }),
                get_allowed_servers: Arc::new({
                    let config = self.config.clone();
                    move || {
                        config
                            .sse_proxy
                            .as_ref()
                            .and_then(|proxy_config| proxy_config.allowed_servers.clone())
                    }
                }),
                get_server_config_keys: Arc::new({
                    let config = self.config.clone();
                    move || config.mcp_servers.keys().cloned().collect()
                }),
            };
            let proxy_config_owned = proxy_config.clone();
            let proxy_handle = SSEProxy::start_proxy(runner_access, proxy_config_owned).await?;
            self.sse_proxy_handle = Some(proxy_handle);
            tracing::info!(
                "SSE proxy server started on {}:{}",
                proxy_config.address,
                proxy_config.port
            );
            Ok(())
        } else {
            tracing::warn!("SSE proxy not configured, skipping start");
            Err(Error::Other(
                "SSE proxy not configured in config".to_string(),
            ))
        }
    }

    /// Reports whether the configuration contains an SSE proxy section.
    #[tracing::instrument(skip(self))]
    pub fn is_sse_proxy_configured(&self) -> bool {
        self.config.sse_proxy.is_some()
    }

    /// Returns the SSE proxy configuration.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Other`] if no SSE proxy is configured.
    #[tracing::instrument(skip(self))]
    pub fn get_sse_proxy_config(&self) -> Result<&config::SSEProxyConfig> {
        tracing::debug!("Getting SSE proxy configuration");
        self.config.sse_proxy.as_ref().ok_or_else(|| {
            tracing::warn!("SSE proxy configuration requested but not configured");
            Error::Other("SSE proxy not configured".to_string())
        })
    }

    /// Returns the handle of the running SSE proxy.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Other`] if no proxy has been started (or it was stopped).
    #[tracing::instrument(skip(self))]
    pub fn get_sse_proxy_handle(&self) -> Result<&SSEProxyHandle> {
        tracing::debug!("Getting SSE proxy handle");
        self.sse_proxy_handle.as_ref().ok_or_else(|| {
            tracing::warn!("SSE proxy handle requested but no proxy is running");
            Error::Other("SSE proxy not running".to_string())
        })
    }

    /// Collects the current status of every tracked server, keyed by server name.
    #[tracing::instrument(skip(self))]
    pub fn get_all_server_statuses(&self) -> HashMap<String, ServerStatus> {
        tracing::debug!("Getting status for all running servers");
        let mut statuses = HashMap::new();
        for (server_name, server_id) in &self.server_names {
            if let Some(server) = self.servers.get(server_id) {
                let status = server.status();
                statuses.insert(server_name.clone(), status);
                tracing::trace!(server = %server_name, status = ?status);
            }
        }
        tracing::debug!(num_servers = statuses.len(), "Collected server statuses");
        statuses
    }

    /// Initializes a client for the started server `name` and lists its tools.
    ///
    /// A cached client is reused when one exists. Otherwise a new one is created
    /// through [`McpRunner::get_client`]. Every outcome, success or failure, is
    /// logged once at the end.
    ///
    /// # Errors
    ///
    /// Returns [`Error::ServerNotFound`] for an unknown name, the client
    /// creation error, or the error from initializing / listing tools.
    #[tracing::instrument(skip(self), fields(server_name = %name))]
    pub async fn get_server_tools(&mut self, name: &str) -> Result<Vec<client::Tool>> {
        tracing::info!("Getting tools for server '{}'", name);
        let server_id = self.get_server_id(name)?;
        let result = if let Some(Some(cached)) = self.clients.get(&server_id) {
            tracing::debug!("Using cached client");
            Self::initialize_and_list_tools(cached).await
        } else {
            match self.get_client(server_id) {
                Ok(created) => Self::initialize_and_list_tools(&created).await,
                Err(e) => {
                    tracing::error!(error = %e, "Failed to get client");
                    Err(e)
                }
            }
        };
        match &result {
            Ok(tools) => {
                tracing::info!(server = %name, num_tools = tools.len(), "Successfully retrieved tools");
            }
            Err(e) => {
                tracing::error!(server = %name, error = %e, "Failed to get tools");
            }
        }
        result
    }

    /// Calls `initialize` and then `list_tools` on `mcp_client`, logging
    /// whichever step fails.
    async fn initialize_and_list_tools(mcp_client: &McpClient) -> Result<Vec<client::Tool>> {
        mcp_client.initialize().await.map_err(|e| {
            tracing::error!(error = %e, "Failed to initialize client");
            e
        })?;
        mcp_client.list_tools().await.map_err(|e| {
            tracing::error!(error = %e, "Failed to list tools for server");
            e
        })
    }

    /// Lists the tools of every started server, one server at a time.
    ///
    /// Each server gets at most `TOOL_LISTING_TIMEOUT` (15 s). A timeout is
    /// recorded as [`Error::Timeout`] for that server and the loop continues.
    #[tracing::instrument(skip(self))]
    pub async fn get_all_server_tools(&mut self) -> HashMap<String, Result<Vec<client::Tool>>> {
        tracing::debug!("Getting tools for all running servers");
        let mut all_tools = HashMap::new();
        let server_names: Vec<String> = self.server_names.keys().cloned().collect();
        for name in server_names {
            tracing::debug!(server = %name, "Getting tools");
            let result =
                tokio::time::timeout(Self::TOOL_LISTING_TIMEOUT, self.get_server_tools(&name))
                    .await;
            let final_result = match result {
                Ok(inner_result) => inner_result,
                Err(_) => {
                    tracing::warn!(server = %name, "Timed out getting tools");
                    Err(Error::Timeout(format!(
                        "Tool listing for server '{}' timed out",
                        name
                    )))
                }
            };
            all_tools.insert(name, final_result);
        }
        tracing::debug!(
            num_servers = all_tools.len(),
            "Collected tools for all servers"
        );
        all_tools
    }

    /// Builds a name -> [`ServerInfo`] map for every tracked server.
    ///
    /// The id and status are rendered with their `Debug` formatting.
    #[tracing::instrument(skip(self))]
    fn get_server_info_snapshot(&self) -> HashMap<String, ServerInfo> {
        tracing::debug!("Creating server information snapshot for SSE proxy");
        let mut server_info = HashMap::new();
        for (name, id) in &self.server_names {
            if let Some(server) = self.servers.get(id) {
                let status = server.status();
                server_info.insert(
                    name.clone(),
                    ServerInfo {
                        name: name.clone(),
                        id: format!("{:?}", id),
                        status: format!("{:?}", status),
                    },
                );
                tracing::trace!(server = %name, status = ?status, "Added server to snapshot");
            }
        }
        server_info
    }
}
impl Clone for McpRunner {
fn clone(&self) -> Self {
Self {
config: self.config.clone(),
servers: self.servers.clone(),
server_names: self.server_names.clone(),
sse_proxy_handle: self.sse_proxy_handle.clone(),
clients: HashMap::new(), }
}
}