1use crate::client::McpClient;
7use crate::config::{DEFAULT_WORKERS, SSEProxyConfig};
8use crate::error::{Error, Result};
9use crate::server::ServerId;
10use crate::sse_proxy::auth::Authentication;
11use crate::sse_proxy::events::EventManager;
12use crate::sse_proxy::handlers;
13use crate::sse_proxy::types::{ServerInfo, ServerInfoUpdate};
14
15use actix_cors::Cors;
16use actix_web::{
17 App, HttpServer, middleware,
18 web::{self, Data},
19};
20
21use std::collections::HashMap;
22use std::net::ToSocketAddrs;
23use std::sync::Arc;
24use std::sync::atomic::{AtomicBool, Ordering};
25use tokio::sync::{Mutex, mpsc};
26use tokio::task::JoinHandle;
27use tracing;
28
/// Callback that resolves a server name to its [`ServerId`].
type ServerIdRetriever = dyn Fn(&str) -> Result<ServerId> + Send + Sync;
/// Callback that produces an [`McpClient`] for a given [`ServerId`].
type ClientRetriever = dyn Fn(ServerId) -> Result<McpClient> + Send + Sync;
/// Callback returning the optional list of server names that tool calls may target.
/// `process_tool_call` only enforces the list when this returns `Some`.
type AllowedServersRetriever = dyn Fn() -> Option<Vec<String>> + Send + Sync;
/// Callback returning the server names used by `start_proxy` to seed the server info cache.
type ServerConfigKeysRetriever = dyn Fn() -> Vec<String> + Send + Sync;
37
/// Cloneable handle to a running SSE proxy.
///
/// Returned by `SSEProxy::start_proxy`. It is used to push server info updates to the proxy's
/// update processor and to request shutdown.
#[derive(Clone)]
pub struct SSEProxyHandle {
    /// Sending half of the channel that the proxy's update processor receives from.
    server_tx: mpsc::Sender<ServerInfoUpdate>,
    /// Join handle of the spawned proxy task; `shutdown` takes it, leaving `None` behind.
    handle: Arc<Mutex<Option<JoinHandle<()>>>>,
    /// Configuration the proxy was started with.
    config: SSEProxyConfig,
    /// Flag that `shutdown` sets to `true`.
    shutdown_flag: Arc<AtomicBool>,
}
54
55impl SSEProxyHandle {
56 fn new(
58 server_tx: mpsc::Sender<ServerInfoUpdate>,
59 handle: JoinHandle<()>,
60 config: SSEProxyConfig,
61 shutdown_flag: Arc<AtomicBool>,
62 ) -> Self {
63 Self {
64 server_tx,
65 handle: Arc::new(Mutex::new(Some(handle))),
66 config,
67 shutdown_flag,
68 }
69 }
70
71 pub async fn update_server_info(
73 &self,
74 server_name: &str,
75 server_id: Option<ServerId>,
76 status: &str,
77 ) -> Result<()> {
78 let update = ServerInfoUpdate::UpdateServer {
79 name: server_name.to_string(),
80 id: server_id,
81 status: status.to_string(),
82 };
83
84 self.server_tx.send(update).await.map_err(|e| {
85 Error::Communication(format!("Failed to send server info update to proxy: {}", e))
86 })
87 }
88
89 pub async fn add_server_info(&self, server_name: &str, server_info: ServerInfo) -> Result<()> {
91 let update = ServerInfoUpdate::AddServer {
92 name: server_name.to_string(),
93 info: server_info,
94 };
95
96 self.server_tx.send(update).await.map_err(|e| {
97 Error::Communication(format!("Failed to send server info update to proxy: {}", e))
98 })
99 }
100
101 pub async fn shutdown(&self) -> Result<()> {
103 self.shutdown_flag.store(true, Ordering::SeqCst);
105
106 let _ = self.server_tx.send(ServerInfoUpdate::Shutdown).await;
108
109 let mut handle = self.handle.lock().await;
111 if let Some(h) = handle.take() {
112 match tokio::time::timeout(std::time::Duration::from_secs(5), h).await {
114 Ok(result) => {
115 if let Err(e) = result {
116 tracing::warn!("Error while joining proxy task: {}", e);
117 }
118 }
119 Err(_) => {
120 tracing::warn!("Timeout waiting for proxy task to finish");
121 }
122 }
123 }
124
125 Ok(())
126 }
127
128 pub fn config(&self) -> &SSEProxyConfig {
130 &self.config
131 }
132}
133
/// Set of callbacks the proxy uses to look up servers and clients.
///
/// Every callback is held in an `Arc`, so the derived `Clone` only clones the pointers.
#[derive(Clone)]
pub struct SSEProxyRunnerAccess {
    /// Resolves a server name to its `ServerId`.
    pub get_server_id: Arc<ServerIdRetriever>,
    /// Produces an `McpClient` for a `ServerId`.
    pub get_client: Arc<ClientRetriever>,
    /// Returns the optional list of server names that tool calls may target.
    pub get_allowed_servers: Arc<AllowedServersRetriever>,
    /// Returns the server names used to seed the server info cache at startup.
    pub get_server_config_keys: Arc<ServerConfigKeysRetriever>,
}
150
/// SSE proxy state: configuration, callbacks, event manager, server info cache and the
/// receiving end of the update channel.
pub struct SSEProxy {
    /// Proxy configuration (address, port, workers, ...).
    config: SSEProxyConfig,
    /// Callbacks used to look up server ids, clients and server lists.
    runner_access: SSEProxyRunnerAccess,
    /// Event manager used to emit server status and tool response/error events.
    event_manager: Arc<EventManager>,
    /// Cache of server information keyed by server name.
    server_info: Arc<Mutex<HashMap<String, ServerInfo>>>,
    /// Receiving half of the update channel, drained by `process_updates`.
    server_rx: mpsc::Receiver<ServerInfoUpdate>,
    /// Flag checked by `process_updates` on every loop iteration to decide when to stop.
    shutdown_flag: Arc<AtomicBool>,
}
170
171impl Clone for SSEProxy {
173 fn clone(&self) -> Self {
174 let (_, dummy_rx) = mpsc::channel::<ServerInfoUpdate>(1);
176
177 Self {
178 config: self.config.clone(),
179 runner_access: self.runner_access.clone(),
180 event_manager: self.event_manager.clone(),
181 server_info: self.server_info.clone(),
182 server_rx: dummy_rx, shutdown_flag: self.shutdown_flag.clone(),
184 }
185 }
186}
187
188impl SSEProxy {
189 fn new(
201 runner_access: SSEProxyRunnerAccess,
202 config: SSEProxyConfig,
203 server_rx: mpsc::Receiver<ServerInfoUpdate>,
204 ) -> Self {
205 let event_manager = Arc::new(EventManager::new(100)); let server_info = Arc::new(Mutex::new(HashMap::new()));
210
211 Self {
212 config,
213 runner_access,
214 event_manager,
215 server_info,
216 server_rx,
217 shutdown_flag: Arc::new(AtomicBool::new(false)),
218 }
219 }
220
221 pub async fn start_proxy(
235 runner_access: SSEProxyRunnerAccess,
236 config: SSEProxyConfig,
237 ) -> Result<SSEProxyHandle> {
238 let (server_tx, server_rx) = mpsc::channel(32);
240 let server_tx_clone = server_tx.clone();
241
242 let shutdown_flag = Arc::new(AtomicBool::new(false));
244 let shutdown_flag_clone = shutdown_flag.clone();
245
246 let mut proxy = Self::new(runner_access.clone(), config.clone(), server_rx);
248
249 let server_names = (runner_access.get_server_config_keys)();
252
253 {
255 let mut server_info = proxy.server_info.lock().await;
257
258 for name in &server_names {
260 if let Ok(server_id) = (runner_access.get_server_id)(name) {
262 let id_str = format!("{:?}", server_id);
264
265 let info = ServerInfo {
267 name: name.clone(),
268 id: id_str.clone(),
269 status: "Running".to_string(),
270 };
271
272 server_info.insert(name.clone(), info);
274
275 tracing::debug!(server = %name, id = %id_str, "Added server to initial cache");
276 }
277 }
278
279 tracing::info!(
280 num_servers = server_info.len(),
281 "Initialized server information cache with running servers"
282 );
283 } let addr_str = format!("{}:{}", proxy.config.address, proxy.config.port);
287 let addr = match addr_str.to_socket_addrs() {
288 Ok(mut addrs) => match addrs.next() {
289 Some(addr) => addr,
290 None => {
291 return Err(Error::Other(format!(
292 "Could not parse socket address: {}",
293 addr_str
294 )));
295 }
296 },
297 Err(e) => {
298 return Err(Error::Other(format!(
299 "Failed to parse socket address: {}",
300 e
301 )));
302 }
303 };
304
305 tracing::info!(address = %addr_str, "Starting SSE proxy server with Actix Web");
306
307 let event_manager = Data::new(proxy.event_manager.clone());
309 let config_arc = Arc::new(proxy.config.clone());
310
311 let runner_access_for_handlers = proxy.runner_access.clone();
313 let server_info_for_handlers = proxy.server_info.clone();
314 let event_mgr_for_handlers = proxy.event_manager.clone();
315 let shutdown_flag_for_handlers = proxy.shutdown_flag.clone();
316
317 let proxy_for_handlers = SSEProxy {
319 config: proxy.config.clone(),
320 runner_access: runner_access_for_handlers,
321 event_manager: event_mgr_for_handlers,
322 server_info: server_info_for_handlers,
323 server_rx: {
325 let (_, rx) = mpsc::channel::<ServerInfoUpdate>(1);
326 rx
327 },
328 shutdown_flag: shutdown_flag_for_handlers,
329 };
330
331 let proxy_data = Data::new(Arc::new(Mutex::new(proxy_for_handlers)));
332
333 let mut server_builder = HttpServer::new(move || {
335 let cors = Cors::default()
337 .allow_any_origin()
338 .allow_any_method()
339 .allow_any_header()
340 .max_age(3600);
341
342 let auth_middleware = Authentication::new(config_arc.clone());
344
345 App::new()
346 .wrap(middleware::Logger::default())
347 .wrap(cors)
348 .app_data(event_manager.clone()) .app_data(proxy_data.clone()) .app_data(Data::new(config_arc.clone())) .wrap(auth_middleware)
353 .route("/sse", web::get().to(handlers::sse_main_endpoint))
355 .route("/sse/messages", web::post().to(handlers::sse_messages))
356 });
357
358 let workers = proxy.config.workers.unwrap_or(DEFAULT_WORKERS);
360 tracing::info!(workers = workers, "Setting number of Actix Web workers");
361 server_builder = server_builder.workers(workers);
362
363 let server = server_builder
365 .bind(addr)
366 .map_err(|e| Error::Other(format!("Failed to bind server: {}", e)))?
367 .run();
368
369 let server_handle = server.handle();
371
372 let server_task = tokio::spawn(server);
375
376 let update_handle = tokio::spawn(async move {
378 if let Err(e) = proxy.process_updates(server_handle).await {
379 tracing::error!(error = %e, "SSE proxy update processor error");
380 }
381 });
382
383 tracing::info!("SSE proxy server started successfully");
384
385 let handle = tokio::spawn(async move {
387 let (server_result, update_result) = tokio::join!(server_task, update_handle);
389
390 if let Err(e) = server_result {
392 tracing::error!(error = %e, "Actix server task error");
393 }
394 if let Err(e) = update_result {
395 tracing::error!(error = %e, "Update processor task error");
396 }
397
398 tracing::info!("SSE proxy server shut down completely");
399 });
400
401 Ok(SSEProxyHandle::new(
403 server_tx_clone,
404 handle,
405 config,
406 shutdown_flag_clone,
407 ))
408 }
409
410 async fn process_updates(&mut self, server_handle: actix_web::dev::ServerHandle) -> Result<()> {
418 tracing::info!("SSE proxy update processor started");
419
420 while !self.shutdown_flag.load(Ordering::SeqCst) {
422 match tokio::time::timeout(
424 tokio::time::Duration::from_millis(100),
425 self.server_rx.recv(),
426 )
427 .await
428 {
429 Ok(Some(update)) => match update {
430 ServerInfoUpdate::UpdateServer { name, id, status } => {
431 let mut servers = self.server_info.lock().await;
433
434 if let Some(server_info) = servers.get_mut(&name) {
435 if let Some(server_id) = id {
436 server_info.id = format!("{:?}", server_id);
437 }
438 server_info.status = status.clone();
439
440 self.event_manager
442 .send_server_status(&name, &server_info.id, &status);
443
444 tracing::debug!(server = %name, status = %status, "Updated server status");
445 } else {
446 let server_info = ServerInfo {
448 name: name.clone(),
449 id: id.map_or_else(
450 || "unknown".to_string(),
451 |id| format!("{:?}", id),
452 ),
453 status: status.clone(),
454 };
455
456 servers.insert(name.clone(), server_info.clone());
457
458 self.event_manager
460 .send_server_status(&name, &server_info.id, &status);
461
462 tracing::debug!(server = %name, status = %status, "Added server to cache");
463 }
464 }
465 ServerInfoUpdate::AddServer { name, info } => {
466 let mut servers = self.server_info.lock().await;
468 servers.insert(name.clone(), info.clone());
469
470 self.event_manager
472 .send_server_status(&name, &info.id, &info.status);
473
474 tracing::debug!(server = %name, "Added server to cache");
475 }
476 ServerInfoUpdate::Shutdown => {
477 tracing::info!("Received shutdown message");
478 self.shutdown_flag.store(true, Ordering::SeqCst);
479 break;
480 }
481 },
482 Ok(None) => {
483 tracing::info!("Server information channel closed, shutting down proxy");
485 self.shutdown_flag.store(true, Ordering::SeqCst);
486 break;
487 }
488 Err(_) => {
489 }
491 }
492 }
493
494 tracing::info!("Stopping Actix Web server");
496 server_handle.stop(true).await;
497
498 tracing::info!("SSE proxy update processor shut down");
499 Ok(())
500 }
501
502 pub async fn process_tool_call(
515 &self,
516 server_name: &str,
517 tool_name: &str,
518 args: serde_json::Value,
519 request_id: &str,
520 ) -> Result<()> {
521 tracing::debug!(server = %server_name, tool = %tool_name, req_id = %request_id, "Processing tool call");
522
523 if let Some(allowed_servers) = (self.runner_access.get_allowed_servers)() {
525 if !allowed_servers.contains(&server_name.to_string()) {
526 tracing::warn!(server = %server_name, "Server not in allowed list");
527
528 self.event_manager.send_tool_error(
530 request_id,
531 "unknown", tool_name,
533 &format!("Server not in allowed list: {}", server_name),
534 );
535
536 return Err(Error::Unauthorized(
537 "Server not in allowed list".to_string(),
538 ));
539 }
540 }
541
542 let server_id = match (self.runner_access.get_server_id)(server_name) {
544 Ok(id) => id,
545 Err(e) => {
546 tracing::warn!(server = %server_name, error = %e, "Server not found");
547
548 self.event_manager.send_tool_error(
550 request_id,
551 "unknown", tool_name,
553 &format!("Server not found: {}", server_name),
554 );
555
556 return Err(e);
557 }
558 };
559 let server_id_str = format!("{:?}", server_id);
560
561 let client = match (self.runner_access.get_client)(server_id) {
563 Ok(c) => c,
564 Err(e) => {
565 tracing::error!(server_id = ?server_id, error = %e, "Failed to get client");
566
567 self.event_manager.send_tool_error(
569 request_id,
570 &server_id_str,
571 tool_name,
572 &format!("Failed to get client: {}", e),
573 );
574
575 return Err(e);
576 }
577 };
578
579 if let Err(e) = client.initialize().await {
581 tracing::error!(server_id = ?server_id, error = %e, "Failed to initialize client");
582
583 self.event_manager.send_tool_error(
585 request_id,
586 &server_id_str,
587 tool_name,
588 &format!("Failed to initialize client: {}", e),
589 );
590
591 return Err(e);
592 }
593
594 let result: Result<serde_json::Value> = client.call_tool(tool_name, &args).await;
596
597 match result {
598 Ok(response) => {
599 tracing::debug!(req_id = %request_id, "Tool call successful");
600
601 self.event_manager.send_tool_response(
604 request_id,
605 &server_id_str,
606 tool_name,
607 response,
608 );
609
610 Ok(())
611 }
612 Err(e) => {
613 tracing::error!(req_id = %request_id, error = %e, "Tool call failed");
614
615 self.event_manager.send_tool_error(
617 request_id,
618 &server_id_str,
619 tool_name,
620 &format!("Tool call failed: {}", e),
621 );
622
623 Err(e)
624 }
625 }
626 }
627
    /// Returns the shared, mutex-protected cache of server information keyed by server name.
    pub fn get_server_info(&self) -> &Arc<Mutex<HashMap<String, ServerInfo>>> {
        &self.server_info
    }
632
    /// Returns the callbacks this proxy uses to look up servers and clients.
    pub fn get_runner_access(&self) -> &SSEProxyRunnerAccess {
        &self.runner_access
    }
637
    /// Returns the shared event manager used to emit status and tool events.
    pub fn event_manager(&self) -> &Arc<EventManager> {
        &self.event_manager
    }
642
    /// Returns the proxy configuration.
    pub fn config(&self) -> &SSEProxyConfig {
        &self.config
    }
647}
648
/// Bundle of the proxy's shareable state: the lookup callbacks, the event manager and the
/// server info cache.
pub struct SSEProxySharedState {
    /// Callbacks used to look up servers and clients.
    runner_access: SSEProxyRunnerAccess,
    /// Shared event manager.
    event_manager: Arc<EventManager>,
    /// Shared, mutex-protected cache of server information keyed by server name.
    server_info: Arc<Mutex<HashMap<String, ServerInfo>>>,
}
661
impl SSEProxySharedState {
    /// Returns the callbacks used to look up servers and clients.
    pub fn runner_access(&self) -> &SSEProxyRunnerAccess {
        &self.runner_access
    }

    /// Returns the shared event manager.
    pub fn event_manager(&self) -> &Arc<EventManager> {
        &self.event_manager
    }

    /// Returns the shared, mutex-protected cache of server information keyed by server name.
    pub fn server_info(&self) -> &Arc<Mutex<HashMap<String, ServerInfo>>> {
        &self.server_info
    }
}