mcp_runner/sse_proxy/
proxy.rs

1//! SSE proxy implementation for MCP servers using Actix Web.
2//!
3//! This module provides the core implementation of the Actix Web-based SSE proxy,
4//! including the main proxy server, runner access functions, and proxy handle.
5
6use crate::client::McpClient;
7use crate::config::{DEFAULT_WORKERS, SSEProxyConfig};
8use crate::error::{Error, Result};
9use crate::server::ServerId;
10use crate::sse_proxy::auth::Authentication;
11use crate::sse_proxy::events::EventManager;
12use crate::sse_proxy::handlers;
13use crate::sse_proxy::types::{ServerInfo, ServerInfoUpdate};
14
15use actix_cors::Cors;
16use actix_web::{
17    App, HttpServer, middleware,
18    web::{self, Data},
19};
20
21use std::collections::HashMap;
22use std::net::ToSocketAddrs;
23use std::sync::Arc;
24use std::sync::atomic::{AtomicBool, Ordering};
25use tokio::sync::{Mutex, mpsc};
26use tokio::task::JoinHandle;
27use tracing;
28
/// Signature of the callback that resolves a server name to its [`ServerId`].
///
/// This is an unsized `dyn Fn` type; in this file it is only used behind an
/// `Arc` (see [`SSEProxyRunnerAccess`]). `Send + Sync` lets that `Arc` be
/// shared across threads.
type ServerIdRetriever = dyn Fn(&str) -> Result<ServerId> + Send + Sync;
/// Signature of the callback that produces an [`McpClient`] for a [`ServerId`].
type ClientRetriever = dyn Fn(ServerId) -> Result<McpClient> + Send + Sync;
/// Signature of the callback that returns the allow-list of server names.
///
/// `None` means no allow-list is configured; `SSEProxy::process_tool_call`
/// only rejects a server when this returns `Some` list that lacks its name.
type AllowedServersRetriever = dyn Fn() -> Option<Vec<String>> + Send + Sync;
/// Signature of the callback that returns the configured server names
/// (used by `SSEProxy::start_proxy` to seed the server-info cache).
type ServerConfigKeysRetriever = dyn Fn() -> Vec<String> + Send + Sync;
37
/// Handle for controlling the SSE proxy
///
/// This handle is stored by the McpRunner to communicate with the SSE proxy.
/// It allows the runner to send updates to the proxy about server status changes
/// and other events without needing to access the proxy directly.
///
/// The handle is `Clone`; every clone shares the same join-handle slot and the
/// same shutdown flag through `Arc`, so whichever clone calls `shutdown` first
/// takes the task handle and later calls find the slot empty.
#[derive(Clone)]
pub struct SSEProxyHandle {
    /// Channel for sending server information updates to the proxy
    server_tx: mpsc::Sender<ServerInfoUpdate>,
    /// Proxy task handle; `None` once `shutdown` has taken it to await the task
    handle: Arc<Mutex<Option<JoinHandle<()>>>>,
    /// Configuration for the proxy (a copy of what was passed to `start_proxy`)
    config: SSEProxyConfig,
    /// Shutdown flag, set to `true` by `shutdown`
    shutdown_flag: Arc<AtomicBool>,
}
54
55impl SSEProxyHandle {
56    /// Create a new SSE proxy handle
57    fn new(
58        server_tx: mpsc::Sender<ServerInfoUpdate>,
59        handle: JoinHandle<()>,
60        config: SSEProxyConfig,
61        shutdown_flag: Arc<AtomicBool>,
62    ) -> Self {
63        Self {
64            server_tx,
65            handle: Arc::new(Mutex::new(Some(handle))),
66            config,
67            shutdown_flag,
68        }
69    }
70
71    /// Update server information in the proxy
72    pub async fn update_server_info(
73        &self,
74        server_name: &str,
75        server_id: Option<ServerId>,
76        status: &str,
77    ) -> Result<()> {
78        let update = ServerInfoUpdate::UpdateServer {
79            name: server_name.to_string(),
80            id: server_id,
81            status: status.to_string(),
82        };
83
84        self.server_tx.send(update).await.map_err(|e| {
85            Error::Communication(format!("Failed to send server info update to proxy: {}", e))
86        })
87    }
88
89    /// Add a new server to the proxy cache
90    pub async fn add_server_info(&self, server_name: &str, server_info: ServerInfo) -> Result<()> {
91        let update = ServerInfoUpdate::AddServer {
92            name: server_name.to_string(),
93            info: server_info,
94        };
95
96        self.server_tx.send(update).await.map_err(|e| {
97            Error::Communication(format!("Failed to send server info update to proxy: {}", e))
98        })
99    }
100
101    /// Shutdown the proxy
102    pub async fn shutdown(&self) -> Result<()> {
103        // Set the shutdown flag to signal the proxy to stop
104        self.shutdown_flag.store(true, Ordering::SeqCst);
105
106        // Send a shutdown message through the channel
107        let _ = self.server_tx.send(ServerInfoUpdate::Shutdown).await;
108
109        // Wait for the proxy task to finish
110        let mut handle = self.handle.lock().await;
111        if let Some(h) = handle.take() {
112            // Wait with a timeout
113            match tokio::time::timeout(std::time::Duration::from_secs(5), h).await {
114                Ok(result) => {
115                    if let Err(e) = result {
116                        tracing::warn!("Error while joining proxy task: {}", e);
117                    }
118                }
119                Err(_) => {
120                    tracing::warn!("Timeout waiting for proxy task to finish");
121                }
122            }
123        }
124
125        Ok(())
126    }
127
128    /// Get the proxy configuration
129    pub fn config(&self) -> &SSEProxyConfig {
130        &self.config
131    }
132}
133
/// Access to McpRunner operations needed by the SSE proxy
///
/// This struct provides a controlled interface to the operations
/// the SSE proxy needs from the McpRunner, rather than giving
/// it direct access to the entire runner.
///
/// Each field is an `Arc`-wrapped closure, so cloning this struct only bumps
/// reference counts and all clones invoke the same callbacks.
#[derive(Clone)]
pub struct SSEProxyRunnerAccess {
    /// Function to get server ID by name
    pub get_server_id: Arc<ServerIdRetriever>,
    /// Function to get a client for a server
    pub get_client: Arc<ClientRetriever>,
    /// Function to get allowed servers if configured (`None` = no allow-list)
    pub get_allowed_servers: Arc<AllowedServersRetriever>,
    /// Function to get server config keys (the configured server names)
    pub get_server_config_keys: Arc<ServerConfigKeysRetriever>,
}
150
/// SSE Proxy server for MCP servers
///
/// Provides an HTTP and SSE proxy that allows web clients to interact with MCP servers
/// using Actix Web. The proxy supports authentication, server listing, tool calls,
/// and resource retrieval.
///
/// The instance that owns the real `server_rx` runs the update loop
/// (`process_updates`); clones receive a disconnected receiver and share the
/// `Arc`-held state.
pub struct SSEProxy {
    /// Configuration for the proxy
    config: SSEProxyConfig,
    /// Direct access to McpRunner for server operations
    runner_access: SSEProxyRunnerAccess,
    /// Event manager for broadcasting events
    event_manager: Arc<EventManager>,
    /// Server information cache, keyed by server name
    server_info: Arc<Mutex<HashMap<String, ServerInfo>>>,
    /// Channel for receiving server updates from McpRunner
    server_rx: mpsc::Receiver<ServerInfoUpdate>,
    /// Shutdown flag polled by the update loop
    shutdown_flag: Arc<AtomicBool>,
}
170
171// Custom Clone implementation for SSEProxy that creates a dummy receiver when cloning
172impl Clone for SSEProxy {
173    fn clone(&self) -> Self {
174        // Create a dummy channel just for the clone - the main instance will still use the real one
175        let (_, dummy_rx) = mpsc::channel::<ServerInfoUpdate>(1);
176
177        Self {
178            config: self.config.clone(),
179            runner_access: self.runner_access.clone(),
180            event_manager: self.event_manager.clone(),
181            server_info: self.server_info.clone(),
182            server_rx: dummy_rx, // Use dummy receiver for clones
183            shutdown_flag: self.shutdown_flag.clone(),
184        }
185    }
186}
187
188impl SSEProxy {
189    /// Create a new SSE proxy instance
190    ///
191    /// # Arguments
192    ///
193    /// * `runner_access` - Functions to access McpRunner operations
194    /// * `config` - Configuration for the SSE proxy
195    /// * `server_rx` - Channel for receiving server information updates
196    ///
197    /// # Returns
198    ///
199    /// A new `SSEProxy` instance
200    fn new(
201        runner_access: SSEProxyRunnerAccess,
202        config: SSEProxyConfig,
203        server_rx: mpsc::Receiver<ServerInfoUpdate>,
204    ) -> Self {
205        // Initialize event manager
206        let event_manager = Arc::new(EventManager::new(100)); // Buffer up to 100 messages
207
208        // Initialize server info cache
209        let server_info = Arc::new(Mutex::new(HashMap::new()));
210
211        Self {
212            config,
213            runner_access,
214            event_manager,
215            server_info,
216            server_rx,
217            shutdown_flag: Arc::new(AtomicBool::new(false)),
218        }
219    }
220
221    /// Start the SSE proxy server
222    ///
223    /// Creates a proxy handle and starts the server in a background task.
224    /// Returns a handle that can be used to control and communicate with the proxy.
225    ///
226    /// # Arguments
227    ///
228    /// * `runner_access` - Functions to access McpRunner operations
229    /// * `config` - Configuration for the SSE proxy
230    ///
231    /// # Returns
232    ///
233    /// A `Result` containing a `SSEProxyHandle` or an error
234    pub async fn start_proxy(
235        runner_access: SSEProxyRunnerAccess,
236        config: SSEProxyConfig,
237    ) -> Result<SSEProxyHandle> {
238        // Create channel for communication between McpRunner and proxy
239        let (server_tx, server_rx) = mpsc::channel(32);
240        let server_tx_clone = server_tx.clone();
241
242        // Create the shutdown flag
243        let shutdown_flag = Arc::new(AtomicBool::new(false));
244        let shutdown_flag_clone = shutdown_flag.clone();
245
246        // Create the proxy instance
247        let mut proxy = Self::new(runner_access.clone(), config.clone(), server_rx);
248
249        // Initialize server info with current server statuses
250        // Get all server names from runner_access
251        let server_names = (runner_access.get_server_config_keys)();
252
253        // Populate server info in a scoped block so the lock is released before moving proxy
254        {
255            // Lock the server info for updating
256            let mut server_info = proxy.server_info.lock().await;
257
258            // Add each server to the info cache
259            for name in &server_names {
260                // Try to get the server ID
261                if let Ok(server_id) = (runner_access.get_server_id)(name) {
262                    // Convert the ID to a string for storing
263                    let id_str = format!("{:?}", server_id);
264
265                    // Create a server info entry with "Running" status
266                    let info = ServerInfo {
267                        name: name.clone(),
268                        id: id_str.clone(),
269                        status: "Running".to_string(),
270                    };
271
272                    // Add to the cache
273                    server_info.insert(name.clone(), info);
274
275                    tracing::debug!(server = %name, id = %id_str, "Added server to initial cache");
276                }
277            }
278
279            tracing::info!(
280                num_servers = server_info.len(),
281                "Initialized server information cache with running servers"
282            );
283        } // Lock is released here when server_info goes out of scope
284
285        // Parse the socket address from the config
286        let addr_str = format!("{}:{}", proxy.config.address, proxy.config.port);
287        let addr = match addr_str.to_socket_addrs() {
288            Ok(mut addrs) => match addrs.next() {
289                Some(addr) => addr,
290                None => {
291                    return Err(Error::Other(format!(
292                        "Could not parse socket address: {}",
293                        addr_str
294                    )));
295                }
296            },
297            Err(e) => {
298                return Err(Error::Other(format!(
299                    "Failed to parse socket address: {}",
300                    e
301                )));
302            }
303        };
304
305        tracing::info!(address = %addr_str, "Starting SSE proxy server with Actix Web");
306
307        // Share the event manager and server info via Actix Data
308        let event_manager = Data::new(proxy.event_manager.clone());
309        let config_arc = Arc::new(proxy.config.clone());
310
311        // Create copies of the fields needed for the handler
312        let runner_access_for_handlers = proxy.runner_access.clone();
313        let server_info_for_handlers = proxy.server_info.clone();
314        let event_mgr_for_handlers = proxy.event_manager.clone();
315        let shutdown_flag_for_handlers = proxy.shutdown_flag.clone();
316
317        // Create a proxy data reference for handlers to use by creating a new SSEProxy instance
318        let proxy_for_handlers = SSEProxy {
319            config: proxy.config.clone(),
320            runner_access: runner_access_for_handlers,
321            event_manager: event_mgr_for_handlers,
322            server_info: server_info_for_handlers,
323            // Create a dummy receiver - the real one stays with self
324            server_rx: {
325                let (_, rx) = mpsc::channel::<ServerInfoUpdate>(1);
326                rx
327            },
328            shutdown_flag: shutdown_flag_for_handlers,
329        };
330
331        let proxy_data = Data::new(Arc::new(Mutex::new(proxy_for_handlers)));
332
333        // Create the HTTP server builder
334        let mut server_builder = HttpServer::new(move || {
335            // Configure CORS
336            let cors = Cors::default()
337                .allow_any_origin()
338                .allow_any_method()
339                .allow_any_header()
340                .max_age(3600);
341
342            // Configure authentication middleware if required
343            let auth_middleware = Authentication::new(config_arc.clone());
344
345            App::new()
346                .wrap(middleware::Logger::default())
347                .wrap(cors)
348                .app_data(event_manager.clone()) // For sse_events handler
349                .app_data(proxy_data.clone()) // Pass the SSEProxy directly
350                .app_data(Data::new(config_arc.clone())) // Pass config if needed by handlers
351                // Apply Authentication middleware unconditionally; its internal logic handles conditions
352                .wrap(auth_middleware)
353                // Define routes
354                .route("/sse", web::get().to(handlers::sse_main_endpoint))
355                .route("/sse/messages", web::post().to(handlers::sse_messages))
356        });
357
358        // Configure workers - use the config value if specified, otherwise default to 4
359        let workers = proxy.config.workers.unwrap_or(DEFAULT_WORKERS);
360        tracing::info!(workers = workers, "Setting number of Actix Web workers");
361        server_builder = server_builder.workers(workers);
362
363        // Bind to the address
364        let server = server_builder
365            .bind(addr)
366            .map_err(|e| Error::Other(format!("Failed to bind server: {}", e)))?
367            .run();
368
369        // Get the server handle for stopping later
370        let server_handle = server.handle();
371
372        // Start two tasks:
373        // 1. Run the Actix server
374        let server_task = tokio::spawn(server);
375
376        // 2. Run the update processing loop
377        let update_handle = tokio::spawn(async move {
378            if let Err(e) = proxy.process_updates(server_handle).await {
379                tracing::error!(error = %e, "SSE proxy update processor error");
380            }
381        });
382
383        tracing::info!("SSE proxy server started successfully");
384
385        // Create the handle that will be returned to the caller
386        let handle = tokio::spawn(async move {
387            // Wait for both tasks to complete
388            let (server_result, update_result) = tokio::join!(server_task, update_handle);
389
390            // Log any errors
391            if let Err(e) = server_result {
392                tracing::error!(error = %e, "Actix server task error");
393            }
394            if let Err(e) = update_result {
395                tracing::error!(error = %e, "Update processor task error");
396            }
397
398            tracing::info!("SSE proxy server shut down completely");
399        });
400
401        // Return a handle to control the proxy
402        Ok(SSEProxyHandle::new(
403            server_tx_clone,
404            handle,
405            config,
406            shutdown_flag_clone,
407        ))
408    }
409
410    /// Process server information updates
411    ///
412    /// This is the main loop that processes server information updates from the channel.
413    ///
414    /// # Returns
415    ///
416    /// A `Result<()>` indicating success or failure
417    async fn process_updates(&mut self, server_handle: actix_web::dev::ServerHandle) -> Result<()> {
418        tracing::info!("SSE proxy update processor started");
419
420        // Main loop to process server information updates
421        while !self.shutdown_flag.load(Ordering::SeqCst) {
422            // Check for server updates with timeout
423            match tokio::time::timeout(
424                tokio::time::Duration::from_millis(100),
425                self.server_rx.recv(),
426            )
427            .await
428            {
429                Ok(Some(update)) => match update {
430                    ServerInfoUpdate::UpdateServer { name, id, status } => {
431                        // Update server info in cache
432                        let mut servers = self.server_info.lock().await;
433
434                        if let Some(server_info) = servers.get_mut(&name) {
435                            if let Some(server_id) = id {
436                                server_info.id = format!("{:?}", server_id);
437                            }
438                            server_info.status = status.clone();
439
440                            // Send server status update event
441                            self.event_manager
442                                .send_server_status(&name, &server_info.id, &status);
443
444                            tracing::debug!(server = %name, status = %status, "Updated server status");
445                        } else {
446                            // Server not in cache yet, add it with default info
447                            let server_info = ServerInfo {
448                                name: name.clone(),
449                                id: id.map_or_else(
450                                    || "unknown".to_string(),
451                                    |id| format!("{:?}", id),
452                                ),
453                                status: status.clone(),
454                            };
455
456                            servers.insert(name.clone(), server_info.clone());
457
458                            // Send server status update event
459                            self.event_manager
460                                .send_server_status(&name, &server_info.id, &status);
461
462                            tracing::debug!(server = %name, status = %status, "Added server to cache");
463                        }
464                    }
465                    ServerInfoUpdate::AddServer { name, info } => {
466                        // Add server to cache
467                        let mut servers = self.server_info.lock().await;
468                        servers.insert(name.clone(), info.clone());
469
470                        // Send server status update event
471                        self.event_manager
472                            .send_server_status(&name, &info.id, &info.status);
473
474                        tracing::debug!(server = %name, "Added server to cache");
475                    }
476                    ServerInfoUpdate::Shutdown => {
477                        tracing::info!("Received shutdown message");
478                        self.shutdown_flag.store(true, Ordering::SeqCst);
479                        break;
480                    }
481                },
482                Ok(None) => {
483                    // Channel closed
484                    tracing::info!("Server information channel closed, shutting down proxy");
485                    self.shutdown_flag.store(true, Ordering::SeqCst);
486                    break;
487                }
488                Err(_) => {
489                    // Timeout - check shutdown flag and continue
490                }
491            }
492        }
493
494        // Stop the server
495        tracing::info!("Stopping Actix Web server");
496        server_handle.stop(true).await;
497
498        tracing::info!("SSE proxy update processor shut down");
499        Ok(())
500    }
501
502    /// Process a tool call request from a client
503    ///
504    /// # Arguments
505    ///
506    /// * `server_name` - Name of the server to call the tool on
507    /// * `tool_name` - Name of the tool to call
508    /// * `args` - Arguments to pass to the tool
509    /// * `request_id` - Request ID for correlation
510    ///
511    /// # Returns
512    ///
513    /// A `Result<()>` indicating success or failure
514    pub async fn process_tool_call(
515        &self,
516        server_name: &str,
517        tool_name: &str,
518        args: serde_json::Value,
519        request_id: &str,
520    ) -> Result<()> {
521        tracing::debug!(server = %server_name, tool = %tool_name, req_id = %request_id, "Processing tool call");
522
523        // Check if this server is allowed
524        if let Some(allowed_servers) = (self.runner_access.get_allowed_servers)() {
525            if !allowed_servers.contains(&server_name.to_string()) {
526                tracing::warn!(server = %server_name, "Server not in allowed list");
527
528                // Send error event
529                self.event_manager.send_tool_error(
530                    request_id,
531                    "unknown", // Server ID is unknown if name isn't allowed/found
532                    tool_name,
533                    &format!("Server not in allowed list: {}", server_name),
534                );
535
536                return Err(Error::Unauthorized(
537                    "Server not in allowed list".to_string(),
538                ));
539            }
540        }
541
542        // Get server ID
543        let server_id = match (self.runner_access.get_server_id)(server_name) {
544            Ok(id) => id,
545            Err(e) => {
546                tracing::warn!(server = %server_name, error = %e, "Server not found");
547
548                // Send error event
549                self.event_manager.send_tool_error(
550                    request_id,
551                    "unknown", // Server ID is unknown
552                    tool_name,
553                    &format!("Server not found: {}", server_name),
554                );
555
556                return Err(e);
557            }
558        };
559        let server_id_str = format!("{:?}", server_id);
560
561        // Get a client
562        let client = match (self.runner_access.get_client)(server_id) {
563            Ok(c) => c,
564            Err(e) => {
565                tracing::error!(server_id = ?server_id, error = %e, "Failed to get client");
566
567                // Send error event
568                self.event_manager.send_tool_error(
569                    request_id,
570                    &server_id_str,
571                    tool_name,
572                    &format!("Failed to get client: {}", e),
573                );
574
575                return Err(e);
576            }
577        };
578
579        // Initialize client
580        if let Err(e) = client.initialize().await {
581            tracing::error!(server_id = ?server_id, error = %e, "Failed to initialize client");
582
583            // Send error event
584            self.event_manager.send_tool_error(
585                request_id,
586                &server_id_str,
587                tool_name,
588                &format!("Failed to initialize client: {}", e),
589            );
590
591            return Err(e);
592        }
593
594        // Call the tool with explicit type annotation for serde_json::Value
595        let result: Result<serde_json::Value> = client.call_tool(tool_name, &args).await;
596
597        match result {
598            Ok(response) => {
599                tracing::debug!(req_id = %request_id, "Tool call successful");
600
601                // Send the raw response to the event manager
602                // The event manager will now format it properly as a JSON-RPC response
603                self.event_manager.send_tool_response(
604                    request_id,
605                    &server_id_str,
606                    tool_name,
607                    response,
608                );
609
610                Ok(())
611            }
612            Err(e) => {
613                tracing::error!(req_id = %request_id, error = %e, "Tool call failed");
614
615                // Send error event
616                self.event_manager.send_tool_error(
617                    request_id,
618                    &server_id_str,
619                    tool_name,
620                    &format!("Tool call failed: {}", e),
621                );
622
623                Err(e)
624            }
625        }
626    }
627
    /// Get the server information cache
    ///
    /// Returns a reference to the shared `Arc`; callers lock the inner mutex
    /// to read or modify the name → `ServerInfo` map.
    pub fn get_server_info(&self) -> &Arc<Mutex<HashMap<String, ServerInfo>>> {
        &self.server_info
    }
632
    /// Get the runner access functions
    ///
    /// Returns the callback bundle this proxy uses to reach the McpRunner.
    pub fn get_runner_access(&self) -> &SSEProxyRunnerAccess {
        &self.runner_access
    }
637
    /// Get the event manager
    ///
    /// Returns a reference to the shared `Arc`, the same event manager the
    /// proxy itself uses to emit server-status and tool events.
    pub fn event_manager(&self) -> &Arc<EventManager> {
        &self.event_manager
    }
642
    /// Get the configuration
    ///
    /// Returns the proxy's own copy of the configuration it was created with.
    pub fn config(&self) -> &SSEProxyConfig {
        &self.config
    }
647}
648
/// Shared state for the SSE proxy
///
/// This struct provides shared state that can be used by Actix Web handlers.
/// It's wrapped in an Arc<Mutex<_>> and passed to handlers via Data.
///
/// NOTE(review): all fields are private and nothing in the visible part of
/// this file constructs this type; `SSEProxy::start_proxy` registers an
/// `SSEProxy` with Actix instead. Confirm whether this struct is still used.
pub struct SSEProxySharedState {
    /// Runner access functions
    runner_access: SSEProxyRunnerAccess,
    /// Event manager
    event_manager: Arc<EventManager>,
    /// Server information cache, keyed by `String`
    server_info: Arc<Mutex<HashMap<String, ServerInfo>>>,
}
661
/// Read-only accessors for the shared state; each returns a borrow of the
/// corresponding field.
impl SSEProxySharedState {
    /// Borrow the runner access callbacks.
    pub fn runner_access(&self) -> &SSEProxyRunnerAccess {
        &self.runner_access
    }

    /// Borrow the `Arc` holding the event manager.
    pub fn event_manager(&self) -> &Arc<EventManager> {
        &self.event_manager
    }

    /// Borrow the `Arc` holding the mutex-protected server information cache.
    pub fn server_info(&self) -> &Arc<Mutex<HashMap<String, ServerInfo>>> {
        &self.server_info
    }
}