mcp_runner/sse_proxy/
proxy.rs

//! SSE proxy implementation for MCP servers using Actix Web.
//!
//! This module provides the core implementation of the Actix Web-based SSE proxy,
//! including the main proxy server, runner access functions, and proxy handle.
5
6use crate::client::McpClient;
7use crate::config::{DEFAULT_WORKERS, SSEProxyConfig};
8use crate::error::{Error, Result};
9use crate::server::ServerId;
10use crate::sse_proxy::auth::Authentication;
11use crate::sse_proxy::events::EventManager;
12use crate::sse_proxy::handlers;
13use crate::sse_proxy::types::{ServerInfo, ServerInfoUpdate};
14
15use actix_cors::Cors;
16use actix_web::{
17    App, HttpServer, middleware,
18    web::{self, Data},
19};
20
21use std::collections::HashMap;
22use std::net::ToSocketAddrs;
23use std::sync::Arc;
24use std::sync::atomic::{AtomicBool, Ordering};
25use tokio::sync::{Mutex, mpsc};
26use tokio::task::JoinHandle;
27use tracing;
28
29/// Type alias for server ID retrieval function
30type ServerIdRetriever = dyn Fn(&str) -> Result<ServerId> + Send + Sync;
31/// Type alias for client retrieval function
32type ClientRetriever = dyn Fn(ServerId) -> Result<McpClient> + Send + Sync;
33/// Type alias for allowed servers retrieval function
34type AllowedServersRetriever = dyn Fn() -> Option<Vec<String>> + Send + Sync;
35/// Type alias for server config keys retrieval function
36type ServerConfigKeysRetriever = dyn Fn() -> Vec<String> + Send + Sync;
37
38/// Handle for controlling the SSE proxy
39///
40/// This handle is stored by the McpRunner to communicate with the SSE proxy.
41/// It allows the runner to send updates to the proxy about server status changes
42/// and other events without needing to access the proxy directly.
43#[derive(Clone)]
44pub struct SSEProxyHandle {
45    /// Channel for sending server information updates to the proxy
46    server_tx: mpsc::Sender<ServerInfoUpdate>,
47    /// Proxy task handle
48    handle: Arc<Mutex<Option<JoinHandle<()>>>>,
49    /// Configuration for the proxy
50    config: SSEProxyConfig,
51    /// Shutdown flag
52    shutdown_flag: Arc<AtomicBool>,
53}
54
55impl SSEProxyHandle {
56    /// Create a new SSE proxy handle
57    fn new(
58        server_tx: mpsc::Sender<ServerInfoUpdate>,
59        handle: JoinHandle<()>,
60        config: SSEProxyConfig,
61        shutdown_flag: Arc<AtomicBool>,
62    ) -> Self {
63        Self {
64            server_tx,
65            handle: Arc::new(Mutex::new(Some(handle))),
66            config,
67            shutdown_flag,
68        }
69    }
70
71    /// Update server information in the proxy
72    pub async fn update_server_info(
73        &self,
74        server_name: &str,
75        server_id: Option<ServerId>,
76        status: &str,
77    ) -> Result<()> {
78        let update = ServerInfoUpdate::UpdateServer {
79            name: server_name.to_string(),
80            id: server_id,
81            status: status.to_string(),
82        };
83
84        self.server_tx.send(update).await.map_err(|e| {
85            Error::Communication(format!("Failed to send server info update to proxy: {}", e))
86        })
87    }
88
89    /// Add a new server to the proxy cache
90    pub async fn add_server_info(&self, server_name: &str, server_info: ServerInfo) -> Result<()> {
91        let update = ServerInfoUpdate::AddServer {
92            name: server_name.to_string(),
93            info: server_info,
94        };
95
96        self.server_tx.send(update).await.map_err(|e| {
97            Error::Communication(format!("Failed to send server info update to proxy: {}", e))
98        })
99    }
100
101    /// Shutdown the proxy
102    pub async fn shutdown(&self) -> Result<()> {
103        // Set the shutdown flag to signal the proxy to stop
104        self.shutdown_flag.store(true, Ordering::SeqCst);
105
106        // Send a shutdown message through the channel
107        let _ = self.server_tx.send(ServerInfoUpdate::Shutdown).await;
108
109        // Wait for the proxy task to finish
110        let mut handle = self.handle.lock().await;
111        if let Some(h) = handle.take() {
112            // Wait with a timeout
113            match tokio::time::timeout(std::time::Duration::from_secs(5), h).await {
114                Ok(result) => {
115                    if let Err(e) = result {
116                        tracing::warn!("Error while joining proxy task: {}", e);
117                    }
118                }
119                Err(_) => {
120                    tracing::warn!("Timeout waiting for proxy task to finish");
121                }
122            }
123        }
124
125        Ok(())
126    }
127
128    /// Get the proxy configuration
129    pub fn config(&self) -> &SSEProxyConfig {
130        &self.config
131    }
132}
133
134/// Access to McpRunner operations needed by the SSE proxy
135///
136/// This struct provides a controlled interface to the operations
137/// the SSE proxy needs from the McpRunner, rather than giving
138/// it direct access to the entire runner.
139#[derive(Clone)]
140pub struct SSEProxyRunnerAccess {
141    /// Function to get server ID by name
142    pub get_server_id: Arc<ServerIdRetriever>,
143    /// Function to get a client for a server
144    pub get_client: Arc<ClientRetriever>,
145    /// Function to get allowed servers if configured
146    pub get_allowed_servers: Arc<AllowedServersRetriever>,
147    /// Function to get server config keys
148    pub get_server_config_keys: Arc<ServerConfigKeysRetriever>,
149}
150
151/// SSE Proxy server for MCP servers
152///
153/// Provides an HTTP and SSE proxy that allows web clients to interact with MCP servers
154/// using Actix Web. The proxy supports authentication, server listing, tool calls,
155/// and resource retrieval.
156pub struct SSEProxy {
157    /// Configuration for the proxy
158    config: SSEProxyConfig,
159    /// Direct access to McpRunner for server operations
160    runner_access: SSEProxyRunnerAccess,
161    /// Event manager for broadcasting events
162    event_manager: Arc<EventManager>,
163    /// Server information cache
164    server_info: Arc<Mutex<HashMap<String, ServerInfo>>>,
165    /// Channel for receiving server updates from McpRunner
166    server_rx: mpsc::Receiver<ServerInfoUpdate>,
167    /// Shutdown flag
168    shutdown_flag: Arc<AtomicBool>,
169}
170
171// Custom Clone implementation for SSEProxy that creates a dummy receiver when cloning
172impl Clone for SSEProxy {
173    fn clone(&self) -> Self {
174        // Create a dummy channel just for the clone - the main instance will still use the real one
175        let (_, dummy_rx) = mpsc::channel::<ServerInfoUpdate>(1);
176
177        Self {
178            config: self.config.clone(),
179            runner_access: self.runner_access.clone(),
180            event_manager: self.event_manager.clone(),
181            server_info: self.server_info.clone(),
182            server_rx: dummy_rx, // Use dummy receiver for clones
183            shutdown_flag: self.shutdown_flag.clone(),
184        }
185    }
186}
187
188impl SSEProxy {
189    /// Create a new SSE proxy instance
190    ///
191    /// # Arguments
192    ///
193    /// * `runner_access` - Functions to access McpRunner operations
194    /// * `config` - Configuration for the SSE proxy
195    /// * `server_rx` - Channel for receiving server information updates
196    ///
197    /// # Returns
198    ///
199    /// A new `SSEProxy` instance
200    fn new(
201        runner_access: SSEProxyRunnerAccess,
202        config: SSEProxyConfig,
203        server_rx: mpsc::Receiver<ServerInfoUpdate>,
204    ) -> Self {
205        // Initialize event manager
206        let event_manager = Arc::new(EventManager::new(100)); // Buffer up to 100 messages
207
208        // Initialize server info cache
209        let server_info = Arc::new(Mutex::new(HashMap::new()));
210
211        Self {
212            config,
213            runner_access,
214            event_manager,
215            server_info,
216            server_rx,
217            shutdown_flag: Arc::new(AtomicBool::new(false)),
218        }
219    }
220
221    /// Start the SSE proxy server
222    ///
223    /// Creates a proxy handle and starts the server in a background task.
224    /// Returns a handle that can be used to control and communicate with the proxy.
225    ///
226    /// # Arguments
227    ///
228    /// * `runner_access` - Functions to access McpRunner operations
229    /// * `config` - Configuration for the SSE proxy
230    ///
231    /// # Returns
232    ///
233    /// A `Result` containing a `SSEProxyHandle` or an error
234    pub async fn start_proxy(
235        runner_access: SSEProxyRunnerAccess,
236        config: SSEProxyConfig,
237    ) -> Result<SSEProxyHandle> {
238        // Create channel for communication between McpRunner and proxy
239        let (server_tx, server_rx) = mpsc::channel(32);
240        let server_tx_clone = server_tx.clone();
241
242        // Create the shutdown flag
243        let shutdown_flag = Arc::new(AtomicBool::new(false));
244        let shutdown_flag_clone = shutdown_flag.clone();
245
246        // Create the proxy instance
247        let mut proxy = Self::new(runner_access.clone(), config.clone(), server_rx);
248
249        // Initialize server info with current server statuses
250        // Get all server names from runner_access
251        let server_names = (runner_access.get_server_config_keys)();
252
253        // Populate server info in a scoped block so the lock is released before moving proxy
254        {
255            // Lock the server info for updating
256            let mut server_info = proxy.server_info.lock().await;
257
258            // Add each server to the info cache
259            for name in &server_names {
260                // Try to get the server ID
261                if let Ok(server_id) = (runner_access.get_server_id)(name) {
262                    // Convert the ID to a string for storing
263                    let id_str = format!("{:?}", server_id);
264
265                    // Create a server info entry with "Running" status
266                    let info = ServerInfo {
267                        name: name.clone(),
268                        id: id_str.clone(),
269                        status: "Running".to_string(),
270                    };
271
272                    // Add to the cache
273                    server_info.insert(name.clone(), info);
274
275                    tracing::debug!(server = %name, id = %id_str, "Added server to initial cache");
276                }
277            }
278
279            tracing::info!(
280                num_servers = server_info.len(),
281                "Initialized server information cache with running servers"
282            );
283        } // Lock is released here when server_info goes out of scope
284
285        // Parse the socket address from the config
286        let addr_str = format!("{}:{}", proxy.config.address, proxy.config.port);
287        let addr = match addr_str.to_socket_addrs() {
288            Ok(mut addrs) => match addrs.next() {
289                Some(addr) => addr,
290                None => {
291                    return Err(Error::Other(format!(
292                        "Could not parse socket address: {}",
293                        addr_str
294                    )));
295                }
296            },
297            Err(e) => {
298                return Err(Error::Other(format!(
299                    "Failed to parse socket address: {}",
300                    e
301                )));
302            }
303        };
304
305        tracing::info!(address = %addr_str, "Starting SSE proxy server with Actix Web");
306
307        // Share the event manager and server info via Actix Data
308        let event_manager = Data::new(proxy.event_manager.clone());
309        let config_arc = Arc::new(proxy.config.clone());
310
311        // Create copies of the fields needed for the handler
312        let runner_access_for_handlers = proxy.runner_access.clone();
313        let server_info_for_handlers = proxy.server_info.clone();
314        let event_mgr_for_handlers = proxy.event_manager.clone();
315        let shutdown_flag_for_handlers = proxy.shutdown_flag.clone();
316
317        // Create a proxy data reference for handlers to use by creating a new SSEProxy instance
318        let proxy_for_handlers = SSEProxy {
319            config: proxy.config.clone(),
320            runner_access: runner_access_for_handlers,
321            event_manager: event_mgr_for_handlers,
322            server_info: server_info_for_handlers,
323            // Create a dummy receiver - the real one stays with self
324            server_rx: {
325                let (_, rx) = mpsc::channel::<ServerInfoUpdate>(1);
326                rx
327            },
328            shutdown_flag: shutdown_flag_for_handlers,
329        };
330
331        let proxy_data = Data::new(Arc::new(Mutex::new(proxy_for_handlers)));
332
333        // Create the HTTP server builder
334        let mut server_builder = HttpServer::new(move || {
335            // Configure CORS
336            let cors = Cors::default()
337                .allow_any_origin()
338                .allow_any_method()
339                .allow_any_header()
340                .max_age(3600);
341
342            // Configure authentication middleware if required
343            let auth_middleware = Authentication::new(config_arc.clone());
344
345            App::new()
346                .wrap(middleware::Logger::default())
347                .wrap(cors)
348                .app_data(event_manager.clone()) // For sse_events handler
349                .app_data(proxy_data.clone()) // Pass the SSEProxy directly
350                .app_data(Data::new(config_arc.clone())) // Pass config if needed by handlers
351                // Apply Authentication middleware unconditionally; its internal logic handles conditions
352                .wrap(auth_middleware)
353                // Define routes
354                .route("/events", web::get().to(handlers::sse_events))
355                .route("/initialize", web::post().to(handlers::initialize))
356                .route("/tool", web::post().to(handlers::tool_call))
357                .route("/jsonrpc", web::post().to(handlers::tool_call_jsonrpc))
358                .route("/servers", web::get().to(handlers::list_servers))
359                .route(
360                    "/servers/{server}/tools",
361                    web::get().to(handlers::list_server_tools),
362                )
363                .route(
364                    "/servers/{server}/resources",
365                    web::get().to(handlers::list_server_resources),
366                )
367                .route(
368                    "/servers/{server}/resources/{resource}",
369                    web::get().to(handlers::get_server_resource),
370                )
371        });
372
373        // Configure workers - use the config value if specified, otherwise default to 4
374        let workers = proxy.config.workers.unwrap_or(DEFAULT_WORKERS);
375        tracing::info!(workers = workers, "Setting number of Actix Web workers");
376        server_builder = server_builder.workers(workers);
377
378        // Bind to the address
379        let server = server_builder
380            .bind(addr)
381            .map_err(|e| Error::Other(format!("Failed to bind server: {}", e)))?
382            .run();
383
384        // Get the server handle for stopping later
385        let server_handle = server.handle();
386
387        // Start two tasks:
388        // 1. Run the Actix server
389        let server_task = tokio::spawn(server);
390
391        // 2. Run the update processing loop
392        let update_handle = tokio::spawn(async move {
393            if let Err(e) = proxy.process_updates(server_handle).await {
394                tracing::error!(error = %e, "SSE proxy update processor error");
395            }
396        });
397
398        tracing::info!("SSE proxy server started successfully");
399
400        // Create the handle that will be returned to the caller
401        let handle = tokio::spawn(async move {
402            // Wait for both tasks to complete
403            let (server_result, update_result) = tokio::join!(server_task, update_handle);
404
405            // Log any errors
406            if let Err(e) = server_result {
407                tracing::error!(error = %e, "Actix server task error");
408            }
409            if let Err(e) = update_result {
410                tracing::error!(error = %e, "Update processor task error");
411            }
412
413            tracing::info!("SSE proxy server shut down completely");
414        });
415
416        // Return a handle to control the proxy
417        Ok(SSEProxyHandle::new(
418            server_tx_clone,
419            handle,
420            config,
421            shutdown_flag_clone,
422        ))
423    }
424
425    /// Process server information updates
426    ///
427    /// This is the main loop that processes server information updates from the channel.
428    ///
429    /// # Returns
430    ///
431    /// A `Result<()>` indicating success or failure
432    async fn process_updates(&mut self, server_handle: actix_web::dev::ServerHandle) -> Result<()> {
433        tracing::info!("SSE proxy update processor started");
434
435        // Main loop to process server information updates
436        while !self.shutdown_flag.load(Ordering::SeqCst) {
437            // Check for server updates with timeout
438            match tokio::time::timeout(
439                tokio::time::Duration::from_millis(100),
440                self.server_rx.recv(),
441            )
442            .await
443            {
444                Ok(Some(update)) => match update {
445                    ServerInfoUpdate::UpdateServer { name, id, status } => {
446                        // Update server info in cache
447                        let mut servers = self.server_info.lock().await;
448
449                        if let Some(server_info) = servers.get_mut(&name) {
450                            if let Some(server_id) = id {
451                                server_info.id = format!("{:?}", server_id);
452                            }
453                            server_info.status = status.clone();
454
455                            // Send server status update event
456                            self.event_manager
457                                .send_server_status(&name, &server_info.id, &status);
458
459                            tracing::debug!(server = %name, status = %status, "Updated server status");
460                        } else {
461                            // Server not in cache yet, add it with default info
462                            let server_info = ServerInfo {
463                                name: name.clone(),
464                                id: id.map_or_else(
465                                    || "unknown".to_string(),
466                                    |id| format!("{:?}", id),
467                                ),
468                                status: status.clone(),
469                            };
470
471                            servers.insert(name.clone(), server_info.clone());
472
473                            // Send server status update event
474                            self.event_manager
475                                .send_server_status(&name, &server_info.id, &status);
476
477                            tracing::debug!(server = %name, status = %status, "Added server to cache");
478                        }
479                    }
480                    ServerInfoUpdate::AddServer { name, info } => {
481                        // Add server to cache
482                        let mut servers = self.server_info.lock().await;
483                        servers.insert(name.clone(), info.clone());
484
485                        // Send server status update event
486                        self.event_manager
487                            .send_server_status(&name, &info.id, &info.status);
488
489                        tracing::debug!(server = %name, "Added server to cache");
490                    }
491                    ServerInfoUpdate::Shutdown => {
492                        tracing::info!("Received shutdown message");
493                        self.shutdown_flag.store(true, Ordering::SeqCst);
494                        break;
495                    }
496                },
497                Ok(None) => {
498                    // Channel closed
499                    tracing::info!("Server information channel closed, shutting down proxy");
500                    self.shutdown_flag.store(true, Ordering::SeqCst);
501                    break;
502                }
503                Err(_) => {
504                    // Timeout - check shutdown flag and continue
505                }
506            }
507        }
508
509        // Stop the server
510        tracing::info!("Stopping Actix Web server");
511        server_handle.stop(true).await;
512
513        tracing::info!("SSE proxy update processor shut down");
514        Ok(())
515    }
516
517    /// Process a tool call request from a client
518    ///
519    /// # Arguments
520    ///
521    /// * `server_name` - Name of the server to call the tool on
522    /// * `tool_name` - Name of the tool to call
523    /// * `args` - Arguments to pass to the tool
524    /// * `request_id` - Request ID for correlation
525    ///
526    /// # Returns
527    ///
528    /// A `Result<()>` indicating success or failure
529    pub async fn process_tool_call(
530        &self,
531        server_name: &str,
532        tool_name: &str,
533        args: serde_json::Value,
534        request_id: &str,
535    ) -> Result<()> {
536        tracing::debug!(server = %server_name, tool = %tool_name, req_id = %request_id, "Processing tool call");
537
538        // Check if this server is allowed
539        if let Some(allowed_servers) = (self.runner_access.get_allowed_servers)() {
540            if !allowed_servers.contains(&server_name.to_string()) {
541                tracing::warn!(server = %server_name, "Server not in allowed list");
542
543                // Send error event
544                self.event_manager.send_tool_error(
545                    request_id,
546                    "unknown", // Server ID is unknown if name isn't allowed/found
547                    tool_name,
548                    &format!("Server not in allowed list: {}", server_name),
549                );
550
551                return Err(Error::Unauthorized(
552                    "Server not in allowed list".to_string(),
553                ));
554            }
555        }
556
557        // Get server ID
558        let server_id = match (self.runner_access.get_server_id)(server_name) {
559            Ok(id) => id,
560            Err(e) => {
561                tracing::warn!(server = %server_name, error = %e, "Server not found");
562
563                // Send error event
564                self.event_manager.send_tool_error(
565                    request_id,
566                    "unknown", // Server ID is unknown
567                    tool_name,
568                    &format!("Server not found: {}", server_name),
569                );
570
571                return Err(e);
572            }
573        };
574        let server_id_str = format!("{:?}", server_id);
575
576        // Get a client
577        let client = match (self.runner_access.get_client)(server_id) {
578            Ok(c) => c,
579            Err(e) => {
580                tracing::error!(server_id = ?server_id, error = %e, "Failed to get client");
581
582                // Send error event
583                self.event_manager.send_tool_error(
584                    request_id,
585                    &server_id_str,
586                    tool_name,
587                    &format!("Failed to get client: {}", e),
588                );
589
590                return Err(e);
591            }
592        };
593
594        // Initialize client
595        if let Err(e) = client.initialize().await {
596            tracing::error!(server_id = ?server_id, error = %e, "Failed to initialize client");
597
598            // Send error event
599            self.event_manager.send_tool_error(
600                request_id,
601                &server_id_str,
602                tool_name,
603                &format!("Failed to initialize client: {}", e),
604            );
605
606            return Err(e);
607        }
608
609        // Call the tool
610        let result = client.call_tool(tool_name, &args).await;
611
612        match result {
613            Ok(response) => {
614                tracing::debug!(req_id = %request_id, "Tool call successful");
615
616                // Send response event
617                self.event_manager.send_tool_response(
618                    request_id,
619                    &server_id_str,
620                    tool_name,
621                    response,
622                );
623
624                Ok(())
625            }
626            Err(e) => {
627                tracing::error!(req_id = %request_id, error = %e, "Tool call failed");
628
629                // Send error event
630                self.event_manager.send_tool_error(
631                    request_id,
632                    &server_id_str,
633                    tool_name,
634                    &format!("Tool call failed: {}", e),
635                );
636
637                Err(e)
638            }
639        }
640    }
641
642    /// Get the server information cache
643    pub fn get_server_info(&self) -> &Arc<Mutex<HashMap<String, ServerInfo>>> {
644        &self.server_info
645    }
646
647    /// Get the runner access functions
648    pub fn get_runner_access(&self) -> &SSEProxyRunnerAccess {
649        &self.runner_access
650    }
651
652    /// Get the event manager
653    pub fn event_manager(&self) -> &Arc<EventManager> {
654        &self.event_manager
655    }
656
657    /// Get the configuration
658    pub fn config(&self) -> &SSEProxyConfig {
659        &self.config
660    }
661}
662
663/// Shared state for the SSE proxy
664///
665/// This struct provides shared state that can be used by Actix Web handlers.
666/// It's wrapped in an Arc<Mutex<_>> and passed to handlers via Data.
667pub struct SSEProxySharedState {
668    /// Runner access functions
669    runner_access: SSEProxyRunnerAccess,
670    /// Event manager
671    event_manager: Arc<EventManager>,
672    /// Server information cache
673    server_info: Arc<Mutex<HashMap<String, ServerInfo>>>,
674}
675
676impl SSEProxySharedState {
677    pub fn runner_access(&self) -> &SSEProxyRunnerAccess {
678        &self.runner_access
679    }
680
681    pub fn event_manager(&self) -> &Arc<EventManager> {
682        &self.event_manager
683    }
684
685    pub fn server_info(&self) -> &Arc<Mutex<HashMap<String, ServerInfo>>> {
686        &self.server_info
687    }
688}