use crate::client::McpClient;
use crate::config::{DEFAULT_WORKERS, SSEProxyConfig};
use crate::error::{Error, Result};
use crate::server::ServerId;
use crate::sse_proxy::auth::Authentication;
use crate::sse_proxy::events::EventManager;
use crate::sse_proxy::handlers;
use crate::sse_proxy::types::{ServerInfo, ServerInfoUpdate};
use actix_cors::Cors;
use actix_web::{
App, HttpServer, middleware,
web::{self, Data},
};
use std::collections::HashMap;
use std::net::ToSocketAddrs;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::{Mutex, mpsc};
use tokio::task::JoinHandle;
use tracing;
/// Callback that resolves a configured server name to its `ServerId`.
type ServerIdRetriever = dyn Fn(&str) -> Result<ServerId> + Send + Sync;
/// Callback that produces an `McpClient` for the given `ServerId`.
type ClientRetriever = dyn Fn(ServerId) -> Result<McpClient> + Send + Sync;
/// Callback returning the server-name allow-list; `None` means
/// `SSEProxy::process_tool_call` performs no allow-list check.
type AllowedServersRetriever = dyn Fn() -> Option<Vec<String>> + Send + Sync;
/// Callback returning configured server names, which `SSEProxy::start_proxy`
/// uses to seed its server cache (presumably the keys of the server config map — confirm).
type ServerConfigKeysRetriever = dyn Fn() -> Vec<String> + Send + Sync;
#[derive(Clone)]
pub struct SSEProxyHandle {
server_tx: mpsc::Sender<ServerInfoUpdate>,
handle: Arc<Mutex<Option<JoinHandle<()>>>>,
config: SSEProxyConfig,
shutdown_flag: Arc<AtomicBool>,
}
impl SSEProxyHandle {
fn new(
server_tx: mpsc::Sender<ServerInfoUpdate>,
handle: JoinHandle<()>,
config: SSEProxyConfig,
shutdown_flag: Arc<AtomicBool>,
) -> Self {
Self {
server_tx,
handle: Arc::new(Mutex::new(Some(handle))),
config,
shutdown_flag,
}
}
pub async fn update_server_info(
&self,
server_name: &str,
server_id: Option<ServerId>,
status: &str,
) -> Result<()> {
let update = ServerInfoUpdate::UpdateServer {
name: server_name.to_string(),
id: server_id,
status: status.to_string(),
};
self.server_tx.send(update).await.map_err(|e| {
Error::Communication(format!("Failed to send server info update to proxy: {}", e))
})
}
pub async fn add_server_info(&self, server_name: &str, server_info: ServerInfo) -> Result<()> {
let update = ServerInfoUpdate::AddServer {
name: server_name.to_string(),
info: server_info,
};
self.server_tx.send(update).await.map_err(|e| {
Error::Communication(format!("Failed to send server info update to proxy: {}", e))
})
}
pub async fn shutdown(&self) -> Result<()> {
self.shutdown_flag.store(true, Ordering::SeqCst);
let _ = self.server_tx.send(ServerInfoUpdate::Shutdown).await;
let mut handle = self.handle.lock().await;
if let Some(h) = handle.take() {
match tokio::time::timeout(std::time::Duration::from_secs(5), h).await {
Ok(result) => {
if let Err(e) = result {
tracing::warn!("Error while joining proxy task: {}", e);
}
}
Err(_) => {
tracing::warn!("Timeout waiting for proxy task to finish");
}
}
}
Ok(())
}
pub fn config(&self) -> &SSEProxyConfig {
&self.config
}
}
/// Callbacks through which the proxy reaches the server runner.
///
/// Every field is an `Arc`, so the derived `Clone` only bumps reference counts.
#[derive(Clone)]
pub struct SSEProxyRunnerAccess {
/// Resolves a server name to its `ServerId`.
pub get_server_id: Arc<ServerIdRetriever>,
/// Produces an `McpClient` for a `ServerId`.
pub get_client: Arc<ClientRetriever>,
/// Returns the optional server-name allow-list (`None` = unrestricted).
pub get_allowed_servers: Arc<AllowedServersRetriever>,
/// Returns the configured server names used to seed the cache.
pub get_server_config_keys: Arc<ServerConfigKeysRetriever>,
}
/// Proxy state: configuration, runner callbacks, event publishing, and a cache
/// of per-server status.
///
/// `Clone` is implemented by hand because `server_rx` cannot be cloned; clones
/// share every `Arc` field but receive a detached receiver.
pub struct SSEProxy {
/// Configuration the proxy was started with (address, port, workers, ...).
config: SSEProxyConfig,
/// Callbacks into the server runner.
runner_access: SSEProxyRunnerAccess,
/// Publishes server-status and tool-call events (`send_server_status`,
/// `send_tool_response`, `send_tool_error`).
event_manager: Arc<EventManager>,
/// Server name -> last known `ServerInfo`.
server_info: Arc<Mutex<HashMap<String, ServerInfo>>>,
/// Receives updates sent through the paired `SSEProxyHandle`.
server_rx: mpsc::Receiver<ServerInfoUpdate>,
/// Checked by `process_updates` on every loop iteration; set when it stops.
shutdown_flag: Arc<AtomicBool>,
}
impl Clone for SSEProxy {
fn clone(&self) -> Self {
let (_, dummy_rx) = mpsc::channel::<ServerInfoUpdate>(1);
Self {
config: self.config.clone(),
runner_access: self.runner_access.clone(),
event_manager: self.event_manager.clone(),
server_info: self.server_info.clone(),
server_rx: dummy_rx, shutdown_flag: self.shutdown_flag.clone(),
}
}
}
impl SSEProxy {
    /// Builds a proxy around `runner_access` and `config`.
    ///
    /// The server-info cache starts empty and the shutdown flag starts cleared.
    /// `server_rx` is the receiving end of the channel whose sender ends up in
    /// the [`SSEProxyHandle`].
    fn new(
        runner_access: SSEProxyRunnerAccess,
        config: SSEProxyConfig,
        server_rx: mpsc::Receiver<ServerInfoUpdate>,
    ) -> Self {
        Self {
            config,
            runner_access,
            // NOTE(review): the meaning of `100` is defined by `EventManager::new`
            // (presumably a buffer capacity) — confirm there.
            event_manager: Arc::new(EventManager::new(100)),
            server_info: Arc::new(Mutex::new(HashMap::new())),
            server_rx,
            shutdown_flag: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Starts the SSE proxy and returns a handle for controlling it.
    ///
    /// Startup proceeds in four steps:
    /// 1. Seeds the server cache.
    /// 2. Binds an Actix Web server to `config.address:config.port`, serving
    ///    `GET /sse` and `POST /sse/messages`.
    /// 3. Spawns the update processor.
    /// 4. Spawns a supervisor task that waits for both the server and the
    ///    update processor.
    ///
    /// The returned [`SSEProxyHandle`] owns the sending side of the update
    /// channel and shares the proxy's shutdown flag. Raising that flag or
    /// sending `Shutdown` stops the update processor, which then gracefully
    /// stops the HTTP server.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Other`] in two cases: the address does not resolve to
    /// any socket address, or the server fails to bind.
    pub async fn start_proxy(
        runner_access: SSEProxyRunnerAccess,
        config: SSEProxyConfig,
    ) -> Result<SSEProxyHandle> {
        let (server_tx, server_rx) = mpsc::channel(32);
        let mut proxy = Self::new(runner_access, config.clone(), server_rx);
        // The handle must hold the *same* flag that `process_updates` polls.
        // A separately allocated `AtomicBool` would make the store in
        // `SSEProxyHandle::shutdown` invisible to the update loop.
        let shutdown_flag = Arc::clone(&proxy.shutdown_flag);

        Self::seed_server_info(&proxy.runner_access, &proxy.server_info).await;

        let addr_str = format!("{}:{}", proxy.config.address, proxy.config.port);
        let addr = Self::resolve_socket_addr(&addr_str)?;
        tracing::info!(address = %addr_str, "Starting SSE proxy server with Actix Web");

        let event_manager = Data::new(proxy.event_manager.clone());
        let config_arc = Arc::new(proxy.config.clone());
        // Handlers get a clone that shares every `Arc`-backed piece of state.
        // Its `server_rx` is a detached dummy, so only `proxy` itself consumes
        // updates.
        let proxy_data = Data::new(Arc::new(Mutex::new(proxy.clone())));

        let workers = proxy.config.workers.unwrap_or(DEFAULT_WORKERS);
        tracing::info!(workers = workers, "Setting number of Actix Web workers");
        let server = HttpServer::new(move || {
            let cors = Cors::default()
                .allow_any_origin()
                .allow_any_method()
                .allow_any_header()
                .max_age(3600);
            let auth_middleware = Authentication::new(config_arc.clone());
            App::new()
                .wrap(middleware::Logger::default())
                .wrap(cors)
                .app_data(event_manager.clone())
                .app_data(proxy_data.clone())
                .app_data(Data::new(config_arc.clone()))
                .wrap(auth_middleware)
                .route("/sse", web::get().to(handlers::sse_main_endpoint))
                .route("/sse/messages", web::post().to(handlers::sse_messages))
        })
        .workers(workers)
        .bind(addr)
        .map_err(|e| Error::Other(format!("Failed to bind server: {}", e)))?
        .run();

        let server_handle = server.handle();
        let server_task = tokio::spawn(server);
        let update_handle = tokio::spawn(async move {
            if let Err(e) = proxy.process_updates(server_handle).await {
                tracing::error!(error = %e, "SSE proxy update processor error");
            }
        });
        tracing::info!("SSE proxy server started successfully");

        // Supervisor: joins both tasks, so `SSEProxyHandle::shutdown` only has
        // to wait on a single handle.
        let handle = tokio::spawn(async move {
            let (server_result, update_result) = tokio::join!(server_task, update_handle);
            if let Err(e) = server_result {
                tracing::error!(error = %e, "Actix server task error");
            }
            if let Err(e) = update_result {
                tracing::error!(error = %e, "Update processor task error");
            }
            tracing::info!("SSE proxy server shut down completely");
        });

        Ok(SSEProxyHandle::new(server_tx, handle, config, shutdown_flag))
    }

    /// Seeds `server_info` with every configured server whose id the runner can
    /// currently resolve, recording each with status `"Running"`.
    ///
    /// Names whose id lookup fails are skipped.
    async fn seed_server_info(
        runner_access: &SSEProxyRunnerAccess,
        server_info: &Mutex<HashMap<String, ServerInfo>>,
    ) {
        let server_names = (runner_access.get_server_config_keys)();
        let mut cache = server_info.lock().await;
        for name in server_names {
            if let Ok(server_id) = (runner_access.get_server_id)(&name) {
                let id_str = format!("{:?}", server_id);
                let info = ServerInfo {
                    name: name.clone(),
                    id: id_str.clone(),
                    status: "Running".to_string(),
                };
                cache.insert(name.clone(), info);
                tracing::debug!(server = %name, id = %id_str, "Added server to initial cache");
            }
        }
        tracing::info!(
            num_servers = cache.len(),
            "Initialized server information cache with running servers"
        );
    }

    /// Resolves `addr_str` (`host:port`) to the first socket address it yields.
    ///
    /// NOTE(review): `ToSocketAddrs` on a hostname performs a blocking DNS
    /// lookup on the calling (async runtime) thread.
    fn resolve_socket_addr(addr_str: &str) -> Result<std::net::SocketAddr> {
        let mut addrs = addr_str
            .to_socket_addrs()
            .map_err(|e| Error::Other(format!("Failed to parse socket address: {}", e)))?;
        addrs
            .next()
            .ok_or_else(|| Error::Other(format!("Could not parse socket address: {}", addr_str)))
    }

    /// Applies incoming `ServerInfoUpdate`s until shutdown, then gracefully
    /// stops the HTTP server behind `server_handle`.
    ///
    /// The loop exits in any of three cases:
    /// - the shutdown flag is raised;
    /// - a `Shutdown` message arrives;
    /// - every sender has been dropped.
    ///
    /// Every cache change is also published through the event manager. This
    /// currently always returns `Ok(())`.
    async fn process_updates(&mut self, server_handle: actix_web::dev::ServerHandle) -> Result<()> {
        tracing::info!("SSE proxy update processor started");
        while !self.shutdown_flag.load(Ordering::SeqCst) {
            // Bounded wait, so the flag is re-checked at least every 100 ms
            // even when no updates arrive.
            let received = tokio::time::timeout(
                tokio::time::Duration::from_millis(100),
                self.server_rx.recv(),
            )
            .await;
            let update = match received {
                Ok(Some(update)) => update,
                Ok(None) => {
                    tracing::info!("Server information channel closed, shutting down proxy");
                    self.shutdown_flag.store(true, Ordering::SeqCst);
                    break;
                }
                // Timed out: loop around to re-check the shutdown flag.
                Err(_) => continue,
            };
            match update {
                ServerInfoUpdate::UpdateServer { name, id, status } => {
                    let mut servers = self.server_info.lock().await;
                    if let Some(server_info) = servers.get_mut(&name) {
                        // Only overwrite the stored id when the update carries one.
                        if let Some(server_id) = id {
                            server_info.id = format!("{:?}", server_id);
                        }
                        server_info.status = status.clone();
                        self.event_manager
                            .send_server_status(&name, &server_info.id, &status);
                        tracing::debug!(server = %name, status = %status, "Updated server status");
                    } else {
                        let server_info = ServerInfo {
                            name: name.clone(),
                            id: id.map_or_else(|| "unknown".to_string(), |id| format!("{:?}", id)),
                            status: status.clone(),
                        };
                        servers.insert(name.clone(), server_info.clone());
                        self.event_manager
                            .send_server_status(&name, &server_info.id, &status);
                        tracing::debug!(server = %name, status = %status, "Added server to cache");
                    }
                }
                ServerInfoUpdate::AddServer { name, info } => {
                    let mut servers = self.server_info.lock().await;
                    servers.insert(name.clone(), info.clone());
                    self.event_manager
                        .send_server_status(&name, &info.id, &info.status);
                    tracing::debug!(server = %name, "Added server to cache");
                }
                ServerInfoUpdate::Shutdown => {
                    tracing::info!("Received shutdown message");
                    self.shutdown_flag.store(true, Ordering::SeqCst);
                    break;
                }
            }
        }
        tracing::info!("Stopping Actix Web server");
        server_handle.stop(true).await;
        tracing::info!("SSE proxy update processor shut down");
        Ok(())
    }

    /// Invokes `tool_name` with `args` on the server named `server_name`.
    ///
    /// The outcome is published through the event manager under `request_id`:
    /// `send_tool_response` on success, `send_tool_error` on every failure.
    /// The same error is also returned to the caller.
    ///
    /// # Errors
    ///
    /// - `Error::Unauthorized` if an allow-list is configured and does not
    ///   contain `server_name`.
    /// - Otherwise, the error from the id lookup, client retrieval, client
    ///   initialization, or the tool call itself.
    pub async fn process_tool_call(
        &self,
        server_name: &str,
        tool_name: &str,
        args: serde_json::Value,
        request_id: &str,
    ) -> Result<()> {
        tracing::debug!(server = %server_name, tool = %tool_name, req_id = %request_id, "Processing tool call");
        if let Some(allowed_servers) = (self.runner_access.get_allowed_servers)() {
            // Compare as &str to avoid allocating a String per call.
            if !allowed_servers.iter().any(|allowed| allowed.as_str() == server_name) {
                tracing::warn!(server = %server_name, "Server not in allowed list");
                self.event_manager.send_tool_error(
                    request_id,
                    "unknown",
                    tool_name,
                    &format!("Server not in allowed list: {}", server_name),
                );
                return Err(Error::Unauthorized(
                    "Server not in allowed list".to_string(),
                ));
            }
        }
        let server_id = match (self.runner_access.get_server_id)(server_name) {
            Ok(id) => id,
            Err(e) => {
                tracing::warn!(server = %server_name, error = %e, "Server not found");
                self.event_manager.send_tool_error(
                    request_id,
                    "unknown",
                    tool_name,
                    &format!("Server not found: {}", server_name),
                );
                return Err(e);
            }
        };
        let server_id_str = format!("{:?}", server_id);
        let client = match (self.runner_access.get_client)(server_id) {
            Ok(c) => c,
            Err(e) => {
                tracing::error!(server_id = ?server_id, error = %e, "Failed to get client");
                self.event_manager.send_tool_error(
                    request_id,
                    &server_id_str,
                    tool_name,
                    &format!("Failed to get client: {}", e),
                );
                return Err(e);
            }
        };
        if let Err(e) = client.initialize().await {
            tracing::error!(server_id = ?server_id, error = %e, "Failed to initialize client");
            self.event_manager.send_tool_error(
                request_id,
                &server_id_str,
                tool_name,
                &format!("Failed to initialize client: {}", e),
            );
            return Err(e);
        }
        let result: Result<serde_json::Value> = client.call_tool(tool_name, &args).await;
        match result {
            Ok(response) => {
                tracing::debug!(req_id = %request_id, "Tool call successful");
                self.event_manager.send_tool_response(
                    request_id,
                    &server_id_str,
                    tool_name,
                    response,
                );
                Ok(())
            }
            Err(e) => {
                tracing::error!(req_id = %request_id, error = %e, "Tool call failed");
                self.event_manager.send_tool_error(
                    request_id,
                    &server_id_str,
                    tool_name,
                    &format!("Tool call failed: {}", e),
                );
                Err(e)
            }
        }
    }

    /// Returns the shared server-name -> `ServerInfo` cache.
    pub fn get_server_info(&self) -> &Arc<Mutex<HashMap<String, ServerInfo>>> {
        &self.server_info
    }

    /// Returns the runner callbacks.
    pub fn get_runner_access(&self) -> &SSEProxyRunnerAccess {
        &self.runner_access
    }

    /// Returns the event manager used to publish status and tool events.
    pub fn event_manager(&self) -> &Arc<EventManager> {
        &self.event_manager
    }

    /// Returns the proxy configuration.
    pub fn config(&self) -> &SSEProxyConfig {
        &self.config
    }
}
/// Bundle of shared proxy state (runner callbacks, event manager, server cache)
/// exposed through read-only getters.
///
/// NOTE(review): no constructor or use site is visible in this section —
/// presumably it is built elsewhere in the module; confirm it is still used.
pub struct SSEProxySharedState {
/// Callbacks into the server runner.
runner_access: SSEProxyRunnerAccess,
/// Publisher for server-status and tool-call events.
event_manager: Arc<EventManager>,
/// Server name -> last known `ServerInfo`.
server_info: Arc<Mutex<HashMap<String, ServerInfo>>>,
}
impl SSEProxySharedState {
    /// Borrows the runner callbacks.
    pub fn runner_access(&self) -> &SSEProxyRunnerAccess {
        let Self { runner_access, .. } = self;
        runner_access
    }

    /// Borrows the shared event manager.
    pub fn event_manager(&self) -> &Arc<EventManager> {
        let Self { event_manager, .. } = self;
        event_manager
    }

    /// Borrows the shared server-name -> `ServerInfo` cache.
    pub fn server_info(&self) -> &Arc<Mutex<HashMap<String, ServerInfo>>> {
        let Self { server_info, .. } = self;
        server_info
    }
}