mcpkit_transport/websocket/
server.rs

//! WebSocket transport server implementation.
//!
//! This module provides server-side WebSocket transport for MCP.
//!
//! # Connection Handling
//!
//! [`WebSocketListener::start`] binds the socket and runs the accept loop until
//! [`WebSocketListener::stop`] is called, so it is normally driven from a background
//! task. Upgraded connections are handed out through [`WebSocketListener::accept`];
//! call it in a loop to handle incoming connections:
//!
//! ```ignore
//! let listener = Arc::new(WebSocketListener::new("0.0.0.0:8080"));
//!
//! // `start` only returns once the listener is stopped.
//! let server = Arc::clone(&listener);
//! tokio::spawn(async move { server.start().await });
//!
//! while let Ok(conn) = listener.accept().await {
//!     tokio::spawn(async move {
//!         // Handle the connection (`conn.stream`, `conn.peer_addr`, ...)
//!     });
//! }
//! ```
23
24use std::sync::atomic::{AtomicBool, Ordering};
25
26#[cfg(feature = "websocket")]
27use std::sync::Arc;
28#[cfg(feature = "websocket")]
29use std::sync::atomic::AtomicU64;
30
31use crate::error::TransportError;
32
/// Server-side configuration for WebSocket listeners.
///
/// Build one with [`WebSocketServerConfig::new`] (or `Default::default()`, which is
/// identical) and refine it with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct WebSocketServerConfig {
    /// Allowed origins for DNS rebinding protection.
    /// If empty, origin validation is disabled.
    pub allowed_origins: Vec<String>,
    /// Maximum message size in bytes (16 MiB unless overridden).
    pub max_message_size: usize,
}

impl Default for WebSocketServerConfig {
    /// Same as [`WebSocketServerConfig::new`]: no origin restrictions and a 16 MiB
    /// message limit. (A derived `Default` would have produced a 0-byte limit,
    /// silently disagreeing with `new()`.)
    fn default() -> Self {
        Self::new()
    }
}

impl WebSocketServerConfig {
    /// Message-size limit used by [`new`](Self::new): 16 MiB.
    const DEFAULT_MAX_MESSAGE_SIZE: usize = 16 * 1024 * 1024;

    /// Create a new server configuration with no allowed-origin restrictions and
    /// a 16 MiB maximum message size.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            allowed_origins: Vec::new(),
            max_message_size: Self::DEFAULT_MAX_MESSAGE_SIZE,
        }
    }

    /// Add an allowed origin for DNS rebinding protection.
    #[must_use]
    pub fn with_allowed_origin(mut self, origin: impl Into<String>) -> Self {
        self.allowed_origins.push(origin.into());
        self
    }

    /// Append multiple allowed origins at once (existing entries are kept).
    #[must_use]
    pub fn with_allowed_origins(
        mut self,
        origins: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        self.allowed_origins
            .extend(origins.into_iter().map(Into::into));
        self
    }

    /// Set maximum message size in bytes.
    #[must_use]
    pub const fn with_max_message_size(mut self, size: usize) -> Self {
        self.max_message_size = size;
        self
    }

    /// Check if an origin is allowed.
    ///
    /// Returns `true` when no origins are configured (validation disabled);
    /// otherwise `origin` must exactly equal one configured entry (a plain,
    /// case-sensitive string comparison).
    #[must_use]
    pub fn is_origin_allowed(&self, origin: &str) -> bool {
        self.allowed_origins.is_empty() || self.allowed_origins.iter().any(|o| o == origin)
    }
}
84
/// WebSocket listener for server-side connections.
///
/// [`start`](Self::start) binds the socket and runs the accept loop until
/// [`stop`](Self::stop) is called. Each TCP connection is upgraded to WebSocket in
/// its own task, which enforces the configured origin allow-list and message-size
/// limit. The upgraded connection is then handed out through
/// [`accept`](Self::accept).
///
/// # Example
///
/// ```ignore
/// use std::sync::Arc;
/// use mcpkit_transport::websocket::WebSocketListener;
///
/// let listener = Arc::new(WebSocketListener::new("0.0.0.0:8080"));
///
/// // `start` only returns once the listener is stopped, so drive it in the background.
/// let server = Arc::clone(&listener);
/// tokio::spawn(async move { server.start().await });
///
/// // `accept` fails once `stop` has been called and no already-upgraded
/// // connections remain queued, which ends this loop.
/// while let Ok(conn) = listener.accept().await {
///     tokio::spawn(async move {
///         // Handle `conn.stream`...
///     });
/// }
/// ```
#[cfg(feature = "websocket")]
pub struct WebSocketListener {
    bind_addr: String,
    config: WebSocketServerConfig,
    /// `true` while [`start`](Self::start) is running its accept loop.
    running: AtomicBool,
    /// Set by `stop` and cleared by `start`. Unlike `running`, this tells
    /// "stopped" apart from "not started yet", so `accept` only fails in the
    /// former case.
    stopped: AtomicBool,
    /// Channel for delivering upgraded connections to `accept` callers.
    connection_tx: tokio::sync::mpsc::Sender<AcceptedConnection>,
    /// Receiving end of `connection_tx`, shared by concurrent `accept` calls.
    connection_rx: crate::runtime::AsyncMutex<tokio::sync::mpsc::Receiver<AcceptedConnection>>,
    /// Number of connection-upgrade tasks currently in flight (shared with their guards).
    active_connections: Arc<AtomicU64>,
    /// Shutdown broadcast; the accept loop and waiting `accept` calls subscribe to it.
    shutdown_tx: tokio::sync::broadcast::Sender<()>,
}

// `RefUnwindSafe` is asserted by hand so `&WebSocketListener` stays usable across
// `catch_unwind` boundaries, preserving backwards compatibility with v0.2.5.
// Shared state lives in atomics and channel handles. The `AsyncMutex` guards only
// the connection receiver, which can still be used or dropped after a panic.
#[cfg(feature = "websocket")]
impl std::panic::RefUnwindSafe for WebSocketListener {}

/// An accepted WebSocket connection with metadata.
#[cfg(feature = "websocket")]
pub struct AcceptedConnection {
    /// The WebSocket stream.
    pub stream: tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
    /// Remote peer address.
    pub peer_addr: std::net::SocketAddr,
    /// Connection ID for tracking (sequential per call to `start`).
    pub connection_id: u64,
}

#[cfg(feature = "websocket")]
impl WebSocketListener {
    /// Capacity of the queue of upgraded connections waiting for `accept`.
    const PENDING_CONNECTIONS: usize = 32;

    /// Create a new WebSocket listener with the default [`WebSocketServerConfig`].
    #[must_use]
    pub fn new(bind_addr: impl Into<String>) -> Self {
        Self::with_config(bind_addr, WebSocketServerConfig::new())
    }

    /// Create a new WebSocket listener with configuration.
    #[must_use]
    pub fn with_config(bind_addr: impl Into<String>, config: WebSocketServerConfig) -> Self {
        let (connection_tx, connection_rx) =
            tokio::sync::mpsc::channel(Self::PENDING_CONNECTIONS);
        // Receivers are created on demand via `subscribe`; the initial one is unused.
        let (shutdown_tx, _) = tokio::sync::broadcast::channel(1);
        Self {
            bind_addr: bind_addr.into(),
            config,
            running: AtomicBool::new(false),
            stopped: AtomicBool::new(false),
            connection_tx,
            connection_rx: crate::runtime::AsyncMutex::new(connection_rx),
            active_connections: Arc::new(AtomicU64::new(0)),
            shutdown_tx,
        }
    }

    /// Get the server configuration.
    #[must_use]
    pub const fn config(&self) -> &WebSocketServerConfig {
        &self.config
    }

    /// Get the number of connection upgrades currently in progress.
    ///
    /// A connection is counted from TCP accept until its WebSocket handshake has
    /// finished and the result has been queued for [`accept`](Self::accept), or
    /// until the handshake fails. Connections already handed out by `accept` are
    /// not tracked here.
    #[must_use]
    pub fn active_connections(&self) -> u64 {
        self.active_connections.load(Ordering::Relaxed)
    }

    /// Accept the next incoming connection.
    ///
    /// Waits for the next upgraded WebSocket connection. Calling this before
    /// [`start`](Self::start) simply waits until connections arrive. After
    /// [`stop`](Self::stop), connections that were already queued are still
    /// returned, one per call; after that the call fails.
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::Connection`] ("Listener stopped") once the
    /// listener has been stopped and no queued connection remains.
    ///
    /// # Example
    ///
    /// ```ignore
    /// while let Ok(conn) = listener.accept().await {
    ///     let transport = WebSocketTransport::from_stream(conn.stream);
    ///     // Handle the transport...
    /// }
    /// ```
    pub async fn accept(&self) -> Result<AcceptedConnection, TransportError> {
        // Subscribe before reading `stopped`. `stop` sets the flag *before* it
        // broadcasts, so either the flag is seen below or the signal lands in
        // this receiver. A concurrent `stop` cannot be missed.
        let mut shutdown_rx = self.shutdown_tx.subscribe();
        let mut rx = self.connection_rx.lock().await;

        if self.stopped.load(Ordering::Acquire) {
            return Self::take_queued(&mut rx);
        }

        let received = tokio::select! {
            // Prefer delivering a ready connection over reacting to shutdown.
            biased;
            connection = rx.recv() => Some(connection),
            _ = shutdown_rx.recv() => None,
        };

        match received {
            Some(connection) => connection.ok_or_else(listener_stopped),
            None => Self::take_queued(&mut rx),
        }
    }

    /// After a stop: hand out one already-queued connection, or report the stop.
    fn take_queued(
        rx: &mut tokio::sync::mpsc::Receiver<AcceptedConnection>,
    ) -> Result<AcceptedConnection, TransportError> {
        rx.try_recv().map_err(|_| listener_stopped())
    }

    /// Start listening for connections.
    ///
    /// Binds [`bind_addr`](Self::bind_addr) and runs the accept loop on the current
    /// task until [`stop`](Self::stop) is called. Spawn it when you need to call
    /// [`accept`](Self::accept) concurrently. Each TCP connection is upgraded in
    /// its own task (origin validation plus message-size limit) and queued for
    /// `accept`.
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::Connection`] if the address cannot be bound.
    pub async fn start(&self) -> Result<(), TransportError> {
        let listener = tokio::net::TcpListener::bind(&self.bind_addr)
            .await
            .map_err(|e| TransportError::Connection {
                message: format!("Failed to bind WebSocket listener: {e}"),
            })?;

        self.stopped.store(false, Ordering::Release);
        self.running.store(true, Ordering::Release);
        // Subscribe once, before the loop. A receiver re-created on every
        // iteration could miss a `stop` signal sent between iterations. The loop
        // would then stay blocked until the next TCP connection arrived.
        let mut shutdown_rx = self.shutdown_tx.subscribe();
        tracing::info!(addr = %self.bind_addr, "WebSocket listener started");

        // Shared by every upgrade task instead of cloning the Vec per connection.
        let allowed_origins: Arc<[String]> = Arc::from(self.config.allowed_origins.as_slice());
        let max_message_size = self.config.max_message_size;
        let mut next_connection_id: u64 = 0;

        while self.running.load(Ordering::Acquire) {
            tokio::select! {
                accept_result = listener.accept() => {
                    match accept_result {
                        Ok((stream, addr)) => {
                            let conn_id = next_connection_id;
                            next_connection_id = next_connection_id.wrapping_add(1);
                            self.spawn_upgrade(
                                stream,
                                addr,
                                conn_id,
                                Arc::clone(&allowed_origins),
                                max_message_size,
                            );
                        }
                        Err(e) => {
                            if self.running.load(Ordering::Acquire) {
                                tracing::error!(error = %e, "Error accepting connection");
                            }
                        }
                    }
                }
                _ = shutdown_rx.recv() => {
                    tracing::info!("WebSocket listener shutting down");
                    break;
                }
            }
        }

        self.running.store(false, Ordering::Release);
        Ok(())
    }

    /// Spawn the task that performs the WebSocket handshake for one TCP
    /// connection and queues the result for [`accept`](Self::accept).
    fn spawn_upgrade(
        &self,
        stream: tokio::net::TcpStream,
        addr: std::net::SocketAddr,
        conn_id: u64,
        allowed_origins: Arc<[String]>,
        max_message_size: usize,
    ) {
        use tokio_tungstenite::tungstenite::handshake::server::{Request, Response};

        tracing::debug!(peer = %addr, "Accepting WebSocket connection");

        let tx = self.connection_tx.clone();
        // Counted until the task finishes; the guard undoes the increment on drop.
        self.active_connections.fetch_add(1, Ordering::Relaxed);
        let guard = ActiveConnectionGuard {
            counter: Arc::clone(&self.active_connections),
        };

        tokio::spawn(async move {
            let _guard = guard;

            let callback = |request: &Request, response: Response| {
                validate_origin(&allowed_origins, request, addr).map(|()| response)
            };
            let ws_config = websocket_config(max_message_size);

            match tokio_tungstenite::accept_hdr_async_with_config(stream, callback, Some(ws_config))
                .await
            {
                Ok(ws_stream) => {
                    tracing::info!(
                        peer = %addr,
                        connection_id = conn_id,
                        "WebSocket connection established"
                    );

                    let connection = AcceptedConnection {
                        stream: ws_stream,
                        peer_addr: addr,
                        connection_id: conn_id,
                    };

                    if tx.send(connection).await.is_err() {
                        tracing::warn!(
                            connection_id = conn_id,
                            "Connection channel closed, dropping connection"
                        );
                    }
                }
                Err(e) => {
                    tracing::error!(
                        peer = %addr,
                        error = %e,
                        "WebSocket handshake failed"
                    );
                }
            }
        });
    }

    /// Stop the listener gracefully.
    ///
    /// Ends the accept loop in [`start`](Self::start). Pending and future
    /// [`accept`](Self::accept) calls fail once the queue of already-upgraded
    /// connections is empty. Existing connections remain active until explicitly
    /// closed.
    pub async fn stop(&self) {
        self.running.store(false, Ordering::Release);
        // Must be set before broadcasting; `accept` relies on this order.
        self.stopped.store(true, Ordering::Release);
        // `send` only fails when nobody is subscribed. In that case there is no
        // one to wake, and the flags above already tell `start`/`accept` to stop.
        let _ = self.shutdown_tx.send(());
        tracing::info!(
            active_connections = self.active_connections(),
            "WebSocket listener stopped"
        );
    }

    /// Check if the listener is running.
    #[must_use]
    pub fn is_running(&self) -> bool {
        self.running.load(Ordering::Acquire)
    }

    /// Get the bind address.
    #[must_use]
    pub fn bind_addr(&self) -> &str {
        &self.bind_addr
    }
}

/// Error reported by `accept` once the listener has been stopped.
#[cfg(feature = "websocket")]
fn listener_stopped() -> TransportError {
    TransportError::Connection {
        message: "Listener stopped".to_string(),
    }
}

/// Handshake configuration derived from [`WebSocketServerConfig::max_message_size`].
///
/// A size of `0` leaves tungstenite's built-in default limit in place rather
/// than installing a limit that would reject every non-empty message.
#[cfg(feature = "websocket")]
fn websocket_config(
    max_message_size: usize,
) -> tokio_tungstenite::tungstenite::protocol::WebSocketConfig {
    let mut ws_config = tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default();
    if max_message_size > 0 {
        ws_config.max_message_size = Some(max_message_size);
    }
    ws_config
}

/// Enforce the origin allow-list during the WebSocket handshake.
///
/// With an empty allow-list every request passes. Otherwise the request must
/// carry an `Origin` header exactly equal to one of `allowed_origins`. A header
/// value that is not visible ASCII is compared as the empty string. Rejections
/// are logged and answered with HTTP 403.
#[cfg(feature = "websocket")]
fn validate_origin(
    allowed_origins: &[String],
    request: &tokio_tungstenite::tungstenite::handshake::server::Request,
    peer: std::net::SocketAddr,
) -> Result<(), tokio_tungstenite::tungstenite::handshake::server::ErrorResponse> {
    if allowed_origins.is_empty() {
        return Ok(());
    }

    match request.headers().get("origin") {
        Some(origin) => {
            let origin_str = origin.to_str().unwrap_or("");
            if allowed_origins.iter().any(|o| o == origin_str) {
                Ok(())
            } else {
                tracing::warn!(
                    peer = %peer,
                    origin = %origin_str,
                    "Rejecting WebSocket connection from disallowed origin"
                );
                Err(forbidden("Origin not allowed"))
            }
        }
        None => {
            // No origin header - reject because origins are configured.
            tracing::warn!(
                peer = %peer,
                "Rejecting WebSocket connection with missing Origin header"
            );
            Err(forbidden("Origin header required"))
        }
    }
}

/// Build the HTTP 403 response returned for a rejected handshake.
#[cfg(feature = "websocket")]
fn forbidden(reason: &str) -> tokio_tungstenite::tungstenite::handshake::server::ErrorResponse {
    tokio_tungstenite::tungstenite::handshake::server::Response::builder()
        .status(403)
        .body(Some(reason.to_string()))
        .expect("failed to build HTTP 403 response")
}
362
/// Token that keeps one connection-upgrade task counted in the listener's
/// shared active-connection counter.
///
/// The listener increments the counter before building the guard. Dropping the
/// guard performs the matching decrement, so the count stays balanced however
/// the owning task finishes.
#[cfg(feature = "websocket")]
struct ActiveConnectionGuard {
    counter: Arc<AtomicU64>,
}

#[cfg(feature = "websocket")]
impl Drop for ActiveConnectionGuard {
    fn drop(&mut self) {
        // Relaxed matches the ordering used for the increment and for reads of the count.
        let shared_count = &self.counter;
        shared_count.fetch_sub(1, Ordering::Relaxed);
    }
}
377
378/// Stub listener when websocket feature is disabled.
379#[cfg(not(feature = "websocket"))]
380pub struct WebSocketListener {
381    bind_addr: String,
382    config: WebSocketServerConfig,
383    running: AtomicBool,
384}
385
386/// Stub for `AcceptedConnection` when websocket feature is disabled.
387#[cfg(not(feature = "websocket"))]
388pub struct AcceptedConnection {
389    _private: (),
390}
391
392#[cfg(not(feature = "websocket"))]
393impl WebSocketListener {
394    /// Create a new WebSocket listener.
395    #[must_use]
396    pub fn new(bind_addr: impl Into<String>) -> Self {
397        Self {
398            bind_addr: bind_addr.into(),
399            config: WebSocketServerConfig::new(),
400            running: AtomicBool::new(false),
401        }
402    }
403
404    /// Create a new WebSocket listener with configuration.
405    #[must_use]
406    pub fn with_config(bind_addr: impl Into<String>, config: WebSocketServerConfig) -> Self {
407        Self {
408            bind_addr: bind_addr.into(),
409            config,
410            running: AtomicBool::new(false),
411        }
412    }
413
414    /// Get the server configuration.
415    #[must_use]
416    pub const fn config(&self) -> &WebSocketServerConfig {
417        &self.config
418    }
419
420    /// Get the number of active connections (always 0 when feature disabled).
421    #[must_use]
422    pub fn active_connections(&self) -> u64 {
423        0
424    }
425
426    /// Accept a connection (stub - always returns error).
427    pub async fn accept(&self) -> Result<AcceptedConnection, TransportError> {
428        Err(TransportError::Connection {
429            message: "WebSocket transport requires the 'websocket' feature".to_string(),
430        })
431    }
432
433    /// Start listening (stub).
434    pub async fn start(&self) -> Result<(), TransportError> {
435        Err(TransportError::Connection {
436            message: "WebSocket transport requires the 'websocket' feature".to_string(),
437        })
438    }
439
440    /// Stop the listener.
441    pub async fn stop(&self) {
442        self.running.store(false, Ordering::Release);
443    }
444
445    /// Check if the listener is running.
446    #[must_use]
447    pub fn is_running(&self) -> bool {
448        self.running.load(Ordering::Acquire)
449    }
450
451    /// Get the bind address.
452    #[must_use]
453    pub fn bind_addr(&self) -> &str {
454        &self.bind_addr
455    }
456}
457
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_listener_creation() {
        let addr = "0.0.0.0:8080";
        let listener = WebSocketListener::new(addr);

        assert_eq!(listener.bind_addr(), addr, "bind address should be stored verbatim");
        assert!(
            !listener.is_running(),
            "a freshly constructed listener must not report running"
        );
    }
}