1use crate::client::McpClient;
7use crate::config::{DEFAULT_WORKERS, SSEProxyConfig};
8use crate::error::{Error, Result};
9use crate::server::ServerId;
10use crate::sse_proxy::auth::Authentication;
11use crate::sse_proxy::events::EventManager;
12use crate::sse_proxy::handlers;
13use crate::sse_proxy::types::{ServerInfo, ServerInfoUpdate};
14
15use actix_cors::Cors;
16use actix_web::{
17 App, HttpServer, middleware,
18 web::{self, Data},
19};
20
21use std::collections::HashMap;
22use std::net::ToSocketAddrs;
23use std::sync::Arc;
24use std::sync::atomic::{AtomicBool, Ordering};
25use tokio::sync::{Mutex, mpsc};
26use tokio::task::JoinHandle;
27use tracing;
28
29type ServerIdRetriever = dyn Fn(&str) -> Result<ServerId> + Send + Sync;
31type ClientRetriever = dyn Fn(ServerId) -> Result<McpClient> + Send + Sync;
33type AllowedServersRetriever = dyn Fn() -> Option<Vec<String>> + Send + Sync;
35type ServerConfigKeysRetriever = dyn Fn() -> Vec<String> + Send + Sync;
37
38#[derive(Clone)]
44pub struct SSEProxyHandle {
45 server_tx: mpsc::Sender<ServerInfoUpdate>,
47 handle: Arc<Mutex<Option<JoinHandle<()>>>>,
49 config: SSEProxyConfig,
51 shutdown_flag: Arc<AtomicBool>,
53}
54
55impl SSEProxyHandle {
56 fn new(
58 server_tx: mpsc::Sender<ServerInfoUpdate>,
59 handle: JoinHandle<()>,
60 config: SSEProxyConfig,
61 shutdown_flag: Arc<AtomicBool>,
62 ) -> Self {
63 Self {
64 server_tx,
65 handle: Arc::new(Mutex::new(Some(handle))),
66 config,
67 shutdown_flag,
68 }
69 }
70
71 pub async fn update_server_info(
73 &self,
74 server_name: &str,
75 server_id: Option<ServerId>,
76 status: &str,
77 ) -> Result<()> {
78 let update = ServerInfoUpdate::UpdateServer {
79 name: server_name.to_string(),
80 id: server_id,
81 status: status.to_string(),
82 };
83
84 self.server_tx.send(update).await.map_err(|e| {
85 Error::Communication(format!("Failed to send server info update to proxy: {}", e))
86 })
87 }
88
89 pub async fn add_server_info(&self, server_name: &str, server_info: ServerInfo) -> Result<()> {
91 let update = ServerInfoUpdate::AddServer {
92 name: server_name.to_string(),
93 info: server_info,
94 };
95
96 self.server_tx.send(update).await.map_err(|e| {
97 Error::Communication(format!("Failed to send server info update to proxy: {}", e))
98 })
99 }
100
101 pub async fn shutdown(&self) -> Result<()> {
103 self.shutdown_flag.store(true, Ordering::SeqCst);
105
106 let _ = self.server_tx.send(ServerInfoUpdate::Shutdown).await;
108
109 let mut handle = self.handle.lock().await;
111 if let Some(h) = handle.take() {
112 match tokio::time::timeout(std::time::Duration::from_secs(5), h).await {
114 Ok(result) => {
115 if let Err(e) = result {
116 tracing::warn!("Error while joining proxy task: {}", e);
117 }
118 }
119 Err(_) => {
120 tracing::warn!("Timeout waiting for proxy task to finish");
121 }
122 }
123 }
124
125 Ok(())
126 }
127
128 pub fn config(&self) -> &SSEProxyConfig {
130 &self.config
131 }
132}
133
134#[derive(Clone)]
140pub struct SSEProxyRunnerAccess {
141 pub get_server_id: Arc<ServerIdRetriever>,
143 pub get_client: Arc<ClientRetriever>,
145 pub get_allowed_servers: Arc<AllowedServersRetriever>,
147 pub get_server_config_keys: Arc<ServerConfigKeysRetriever>,
149}
150
151pub struct SSEProxy {
157 config: SSEProxyConfig,
159 runner_access: SSEProxyRunnerAccess,
161 event_manager: Arc<EventManager>,
163 server_info: Arc<Mutex<HashMap<String, ServerInfo>>>,
165 server_rx: mpsc::Receiver<ServerInfoUpdate>,
167 shutdown_flag: Arc<AtomicBool>,
169}
170
171impl Clone for SSEProxy {
173 fn clone(&self) -> Self {
174 let (_, dummy_rx) = mpsc::channel::<ServerInfoUpdate>(1);
176
177 Self {
178 config: self.config.clone(),
179 runner_access: self.runner_access.clone(),
180 event_manager: self.event_manager.clone(),
181 server_info: self.server_info.clone(),
182 server_rx: dummy_rx, shutdown_flag: self.shutdown_flag.clone(),
184 }
185 }
186}
187
188impl SSEProxy {
189 fn new(
201 runner_access: SSEProxyRunnerAccess,
202 config: SSEProxyConfig,
203 server_rx: mpsc::Receiver<ServerInfoUpdate>,
204 ) -> Self {
205 let event_manager = Arc::new(EventManager::new(100)); let server_info = Arc::new(Mutex::new(HashMap::new()));
210
211 Self {
212 config,
213 runner_access,
214 event_manager,
215 server_info,
216 server_rx,
217 shutdown_flag: Arc::new(AtomicBool::new(false)),
218 }
219 }
220
221 pub async fn start_proxy(
235 runner_access: SSEProxyRunnerAccess,
236 config: SSEProxyConfig,
237 ) -> Result<SSEProxyHandle> {
238 let (server_tx, server_rx) = mpsc::channel(32);
240 let server_tx_clone = server_tx.clone();
241
242 let shutdown_flag = Arc::new(AtomicBool::new(false));
244 let shutdown_flag_clone = shutdown_flag.clone();
245
246 let mut proxy = Self::new(runner_access.clone(), config.clone(), server_rx);
248
249 let server_names = (runner_access.get_server_config_keys)();
252
253 {
255 let mut server_info = proxy.server_info.lock().await;
257
258 for name in &server_names {
260 if let Ok(server_id) = (runner_access.get_server_id)(name) {
262 let id_str = format!("{:?}", server_id);
264
265 let info = ServerInfo {
267 name: name.clone(),
268 id: id_str.clone(),
269 status: "Running".to_string(),
270 };
271
272 server_info.insert(name.clone(), info);
274
275 tracing::debug!(server = %name, id = %id_str, "Added server to initial cache");
276 }
277 }
278
279 tracing::info!(
280 num_servers = server_info.len(),
281 "Initialized server information cache with running servers"
282 );
283 } let addr_str = format!("{}:{}", proxy.config.address, proxy.config.port);
287 let addr = match addr_str.to_socket_addrs() {
288 Ok(mut addrs) => match addrs.next() {
289 Some(addr) => addr,
290 None => {
291 return Err(Error::Other(format!(
292 "Could not parse socket address: {}",
293 addr_str
294 )));
295 }
296 },
297 Err(e) => {
298 return Err(Error::Other(format!(
299 "Failed to parse socket address: {}",
300 e
301 )));
302 }
303 };
304
305 tracing::info!(address = %addr_str, "Starting SSE proxy server with Actix Web");
306
307 let event_manager = Data::new(proxy.event_manager.clone());
309 let config_arc = Arc::new(proxy.config.clone());
310
311 let runner_access_for_handlers = proxy.runner_access.clone();
313 let server_info_for_handlers = proxy.server_info.clone();
314 let event_mgr_for_handlers = proxy.event_manager.clone();
315 let shutdown_flag_for_handlers = proxy.shutdown_flag.clone();
316
317 let proxy_for_handlers = SSEProxy {
319 config: proxy.config.clone(),
320 runner_access: runner_access_for_handlers,
321 event_manager: event_mgr_for_handlers,
322 server_info: server_info_for_handlers,
323 server_rx: {
325 let (_, rx) = mpsc::channel::<ServerInfoUpdate>(1);
326 rx
327 },
328 shutdown_flag: shutdown_flag_for_handlers,
329 };
330
331 let proxy_data = Data::new(Arc::new(Mutex::new(proxy_for_handlers)));
332
333 let mut server_builder = HttpServer::new(move || {
335 let cors = Cors::default()
337 .allow_any_origin()
338 .allow_any_method()
339 .allow_any_header()
340 .max_age(3600);
341
342 let auth_middleware = Authentication::new(config_arc.clone());
344
345 App::new()
346 .wrap(middleware::Logger::default())
347 .wrap(cors)
348 .app_data(event_manager.clone()) .app_data(proxy_data.clone()) .app_data(Data::new(config_arc.clone())) .wrap(auth_middleware)
353 .route("/events", web::get().to(handlers::sse_events))
355 .route("/initialize", web::post().to(handlers::initialize))
356 .route("/tool", web::post().to(handlers::tool_call))
357 .route("/jsonrpc", web::post().to(handlers::tool_call_jsonrpc))
358 .route("/servers", web::get().to(handlers::list_servers))
359 .route(
360 "/servers/{server}/tools",
361 web::get().to(handlers::list_server_tools),
362 )
363 .route(
364 "/servers/{server}/resources",
365 web::get().to(handlers::list_server_resources),
366 )
367 .route(
368 "/servers/{server}/resources/{resource}",
369 web::get().to(handlers::get_server_resource),
370 )
371 });
372
373 let workers = proxy.config.workers.unwrap_or(DEFAULT_WORKERS);
375 tracing::info!(workers = workers, "Setting number of Actix Web workers");
376 server_builder = server_builder.workers(workers);
377
378 let server = server_builder
380 .bind(addr)
381 .map_err(|e| Error::Other(format!("Failed to bind server: {}", e)))?
382 .run();
383
384 let server_handle = server.handle();
386
387 let server_task = tokio::spawn(server);
390
391 let update_handle = tokio::spawn(async move {
393 if let Err(e) = proxy.process_updates(server_handle).await {
394 tracing::error!(error = %e, "SSE proxy update processor error");
395 }
396 });
397
398 tracing::info!("SSE proxy server started successfully");
399
400 let handle = tokio::spawn(async move {
402 let (server_result, update_result) = tokio::join!(server_task, update_handle);
404
405 if let Err(e) = server_result {
407 tracing::error!(error = %e, "Actix server task error");
408 }
409 if let Err(e) = update_result {
410 tracing::error!(error = %e, "Update processor task error");
411 }
412
413 tracing::info!("SSE proxy server shut down completely");
414 });
415
416 Ok(SSEProxyHandle::new(
418 server_tx_clone,
419 handle,
420 config,
421 shutdown_flag_clone,
422 ))
423 }
424
425 async fn process_updates(&mut self, server_handle: actix_web::dev::ServerHandle) -> Result<()> {
433 tracing::info!("SSE proxy update processor started");
434
435 while !self.shutdown_flag.load(Ordering::SeqCst) {
437 match tokio::time::timeout(
439 tokio::time::Duration::from_millis(100),
440 self.server_rx.recv(),
441 )
442 .await
443 {
444 Ok(Some(update)) => match update {
445 ServerInfoUpdate::UpdateServer { name, id, status } => {
446 let mut servers = self.server_info.lock().await;
448
449 if let Some(server_info) = servers.get_mut(&name) {
450 if let Some(server_id) = id {
451 server_info.id = format!("{:?}", server_id);
452 }
453 server_info.status = status.clone();
454
455 self.event_manager
457 .send_server_status(&name, &server_info.id, &status);
458
459 tracing::debug!(server = %name, status = %status, "Updated server status");
460 } else {
461 let server_info = ServerInfo {
463 name: name.clone(),
464 id: id.map_or_else(
465 || "unknown".to_string(),
466 |id| format!("{:?}", id),
467 ),
468 status: status.clone(),
469 };
470
471 servers.insert(name.clone(), server_info.clone());
472
473 self.event_manager
475 .send_server_status(&name, &server_info.id, &status);
476
477 tracing::debug!(server = %name, status = %status, "Added server to cache");
478 }
479 }
480 ServerInfoUpdate::AddServer { name, info } => {
481 let mut servers = self.server_info.lock().await;
483 servers.insert(name.clone(), info.clone());
484
485 self.event_manager
487 .send_server_status(&name, &info.id, &info.status);
488
489 tracing::debug!(server = %name, "Added server to cache");
490 }
491 ServerInfoUpdate::Shutdown => {
492 tracing::info!("Received shutdown message");
493 self.shutdown_flag.store(true, Ordering::SeqCst);
494 break;
495 }
496 },
497 Ok(None) => {
498 tracing::info!("Server information channel closed, shutting down proxy");
500 self.shutdown_flag.store(true, Ordering::SeqCst);
501 break;
502 }
503 Err(_) => {
504 }
506 }
507 }
508
509 tracing::info!("Stopping Actix Web server");
511 server_handle.stop(true).await;
512
513 tracing::info!("SSE proxy update processor shut down");
514 Ok(())
515 }
516
517 pub async fn process_tool_call(
530 &self,
531 server_name: &str,
532 tool_name: &str,
533 args: serde_json::Value,
534 request_id: &str,
535 ) -> Result<()> {
536 tracing::debug!(server = %server_name, tool = %tool_name, req_id = %request_id, "Processing tool call");
537
538 if let Some(allowed_servers) = (self.runner_access.get_allowed_servers)() {
540 if !allowed_servers.contains(&server_name.to_string()) {
541 tracing::warn!(server = %server_name, "Server not in allowed list");
542
543 self.event_manager.send_tool_error(
545 request_id,
546 "unknown", tool_name,
548 &format!("Server not in allowed list: {}", server_name),
549 );
550
551 return Err(Error::Unauthorized(
552 "Server not in allowed list".to_string(),
553 ));
554 }
555 }
556
557 let server_id = match (self.runner_access.get_server_id)(server_name) {
559 Ok(id) => id,
560 Err(e) => {
561 tracing::warn!(server = %server_name, error = %e, "Server not found");
562
563 self.event_manager.send_tool_error(
565 request_id,
566 "unknown", tool_name,
568 &format!("Server not found: {}", server_name),
569 );
570
571 return Err(e);
572 }
573 };
574 let server_id_str = format!("{:?}", server_id);
575
576 let client = match (self.runner_access.get_client)(server_id) {
578 Ok(c) => c,
579 Err(e) => {
580 tracing::error!(server_id = ?server_id, error = %e, "Failed to get client");
581
582 self.event_manager.send_tool_error(
584 request_id,
585 &server_id_str,
586 tool_name,
587 &format!("Failed to get client: {}", e),
588 );
589
590 return Err(e);
591 }
592 };
593
594 if let Err(e) = client.initialize().await {
596 tracing::error!(server_id = ?server_id, error = %e, "Failed to initialize client");
597
598 self.event_manager.send_tool_error(
600 request_id,
601 &server_id_str,
602 tool_name,
603 &format!("Failed to initialize client: {}", e),
604 );
605
606 return Err(e);
607 }
608
609 let result = client.call_tool(tool_name, &args).await;
611
612 match result {
613 Ok(response) => {
614 tracing::debug!(req_id = %request_id, "Tool call successful");
615
616 self.event_manager.send_tool_response(
618 request_id,
619 &server_id_str,
620 tool_name,
621 response,
622 );
623
624 Ok(())
625 }
626 Err(e) => {
627 tracing::error!(req_id = %request_id, error = %e, "Tool call failed");
628
629 self.event_manager.send_tool_error(
631 request_id,
632 &server_id_str,
633 tool_name,
634 &format!("Tool call failed: {}", e),
635 );
636
637 Err(e)
638 }
639 }
640 }
641
642 pub fn get_server_info(&self) -> &Arc<Mutex<HashMap<String, ServerInfo>>> {
644 &self.server_info
645 }
646
647 pub fn get_runner_access(&self) -> &SSEProxyRunnerAccess {
649 &self.runner_access
650 }
651
652 pub fn event_manager(&self) -> &Arc<EventManager> {
654 &self.event_manager
655 }
656
657 pub fn config(&self) -> &SSEProxyConfig {
659 &self.config
660 }
661}
662
663pub struct SSEProxySharedState {
668 runner_access: SSEProxyRunnerAccess,
670 event_manager: Arc<EventManager>,
672 server_info: Arc<Mutex<HashMap<String, ServerInfo>>>,
674}
675
676impl SSEProxySharedState {
677 pub fn runner_access(&self) -> &SSEProxyRunnerAccess {
678 &self.runner_access
679 }
680
681 pub fn event_manager(&self) -> &Arc<EventManager> {
682 &self.event_manager
683 }
684
685 pub fn server_info(&self) -> &Arc<Mutex<HashMap<String, ServerInfo>>> {
686 &self.server_info
687 }
688}