mcpkit_transport/websocket/
server.rs1use std::sync::atomic::{AtomicBool, Ordering};
25
26#[cfg(feature = "websocket")]
27use std::sync::Arc;
28#[cfg(feature = "websocket")]
29use std::sync::atomic::AtomicU64;
30
31use crate::error::TransportError;
32
/// Server-side policy applied to incoming WebSocket connections.
///
/// `Default` is equivalent to [`WebSocketServerConfig::new`].
#[derive(Debug, Clone)]
pub struct WebSocketServerConfig {
    /// Origins permitted to connect, compared by exact string equality.
    ///
    /// An empty list disables origin checking: every origin is accepted.
    pub allowed_origins: Vec<String>,
    /// Maximum size of a single WebSocket message, in bytes.
    pub max_message_size: usize,
}

// Implemented by hand (instead of derived) so that `default()` and `new()` agree; a derived
// impl would set `max_message_size` to 0.
impl Default for WebSocketServerConfig {
    fn default() -> Self {
        Self::new()
    }
}

impl WebSocketServerConfig {
    /// Message-size limit used by [`new`](Self::new) and `Default`: 16 MiB.
    pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 16 * 1024 * 1024;

    /// Creates a configuration that accepts any origin and limits messages to
    /// [`DEFAULT_MAX_MESSAGE_SIZE`](Self::DEFAULT_MAX_MESSAGE_SIZE) bytes.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            allowed_origins: Vec::new(),
            max_message_size: Self::DEFAULT_MAX_MESSAGE_SIZE,
        }
    }

    /// Adds one origin to the allow-list.
    #[must_use]
    pub fn with_allowed_origin(mut self, origin: impl Into<String>) -> Self {
        self.allowed_origins.push(origin.into());
        self
    }

    /// Appends every origin in `origins` to the allow-list.
    #[must_use]
    pub fn with_allowed_origins(
        mut self,
        origins: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        self.allowed_origins
            .extend(origins.into_iter().map(Into::into));
        self
    }

    /// Sets the maximum message size, in bytes.
    #[must_use]
    pub const fn with_max_message_size(mut self, size: usize) -> Self {
        self.max_message_size = size;
        self
    }

    /// Returns `true` if `origin` may connect.
    ///
    /// Always `true` when the allow-list is empty; otherwise `origin` must equal one of the
    /// listed entries exactly (no normalization, case-sensitive).
    #[must_use]
    pub fn is_origin_allowed(&self, origin: &str) -> bool {
        self.allowed_origins.is_empty() || self.allowed_origins.iter().any(|o| o == origin)
    }
}
/// WebSocket server listener.
///
/// [`start`](WebSocketListener::start) runs the TCP accept loop and performs each handshake
/// on its own task; connections that complete the handshake are handed to
/// [`accept`](WebSocketListener::accept) through an internal bounded queue.
#[cfg(feature = "websocket")]
pub struct WebSocketListener {
    /// Address handed to `TcpListener::bind` by `start`.
    bind_addr: String,
    /// Origin policy and message-size limit used during handshakes.
    config: WebSocketServerConfig,
    /// Set while the accept loop runs; cleared by `stop` and when the loop exits.
    running: AtomicBool,
    /// Producer side of the queue of handshaken connections; cloned into handshake tasks.
    connection_tx: tokio::sync::mpsc::Sender<AcceptedConnection>,
    /// Consumer side of the same queue, locked by `accept` while it waits.
    connection_rx: crate::runtime::AsyncMutex<tokio::sync::mpsc::Receiver<AcceptedConnection>>,
    /// Counter shared with per-connection guards; exposed via `active_connections`.
    active_connections: Arc<AtomicU64>,
    /// Shutdown signal for the accept loop: installed by `start`, taken by `stop`.
    shutdown_tx: crate::runtime::AsyncMutex<Option<tokio::sync::broadcast::Sender<()>>>,
}
// Marker impl: declares that observing a `WebSocketListener` through a shared reference after
// a caught panic is acceptable. `RefUnwindSafe` is a safe trait, so this is an API promise,
// not a memory-safety claim.
// NOTE(review): presumably needed because the `AsyncMutex` fields make the type
// `!RefUnwindSafe` by default, and callers capture `&WebSocketListener` inside
// `std::panic::catch_unwind`; confirm against the callers.
#[cfg(feature = "websocket")]
impl ::std::panic::RefUnwindSafe for WebSocketListener {}
/// A client connection that has completed the WebSocket handshake.
#[cfg(feature = "websocket")]
#[derive(Debug)]
pub struct AcceptedConnection {
    /// The upgraded WebSocket stream over the client's TCP connection.
    pub stream: tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
    /// Remote address reported when the TCP connection was accepted.
    pub peer_addr: std::net::SocketAddr,
    /// Sequence number assigned in TCP-accept order by the current run of `start`,
    /// starting at 0. IDs of connections whose handshake failed are never delivered, so
    /// the values seen by `accept` may have gaps.
    pub connection_id: u64,
}
#[cfg(feature = "websocket")]
impl WebSocketListener {
    /// Capacity of the queue holding handshaken connections until [`accept`](Self::accept)
    /// takes them. When it is full, handshake tasks wait before handing off.
    const CONNECTION_QUEUE_CAPACITY: usize = 32;

    /// Creates a listener for `bind_addr` with the default [`WebSocketServerConfig`].
    ///
    /// Nothing is bound until [`start`](Self::start) is called.
    #[must_use]
    pub fn new(bind_addr: impl Into<String>) -> Self {
        Self::with_config(bind_addr, WebSocketServerConfig::new())
    }

    /// Creates a listener for `bind_addr` with an explicit configuration.
    ///
    /// Nothing is bound until [`start`](Self::start) is called.
    #[must_use]
    pub fn with_config(bind_addr: impl Into<String>, config: WebSocketServerConfig) -> Self {
        let (connection_tx, connection_rx) =
            tokio::sync::mpsc::channel(Self::CONNECTION_QUEUE_CAPACITY);
        Self {
            bind_addr: bind_addr.into(),
            config,
            running: AtomicBool::new(false),
            connection_tx,
            connection_rx: crate::runtime::AsyncMutex::new(connection_rx),
            active_connections: Arc::new(AtomicU64::new(0)),
            shutdown_tx: crate::runtime::AsyncMutex::new(None),
        }
    }

    /// Returns the configuration this listener was created with.
    #[must_use]
    pub const fn config(&self) -> &WebSocketServerConfig {
        &self.config
    }

    /// Returns the current value of the connection counter.
    ///
    /// NOTE(review): the counter is incremented when a TCP connection is accepted and
    /// decremented when that connection's handshake task finishes. The task finishes right
    /// after the connection is queued for [`accept`](Self::accept), or after it is rejected.
    /// The value therefore counts in-flight handshakes, not connections that are still open;
    /// confirm this is the intended meaning.
    #[must_use]
    pub fn active_connections(&self) -> u64 {
        self.active_connections.load(Ordering::Relaxed)
    }

    /// Waits for the next connection that completed the WebSocket handshake.
    ///
    /// Connections already queued are returned even if a shutdown is signalled at the same
    /// time.
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::Connection`] ("Listener stopped") if the accept loop
    /// shuts down while this call is waiting. If no accept loop is running when the call
    /// starts (not started yet, or already stopped), it waits on the queue alone.
    pub async fn accept(&self) -> Result<AcceptedConnection, TransportError> {
        fn stopped() -> TransportError {
            TransportError::Connection {
                message: "Listener stopped".to_string(),
            }
        }

        // Subscribe before blocking so that a `stop()` issued while we wait wakes us.
        // `self.connection_tx` keeps the queue open for the listener's whole lifetime, so
        // `recv()` alone would never report the listener as stopped.
        let shutdown_rx = self
            .shutdown_tx
            .lock()
            .await
            .as_ref()
            .map(|tx| tx.subscribe());

        let mut rx = self.connection_rx.lock().await;
        match shutdown_rx {
            Some(mut shutdown_rx) => tokio::select! {
                // Poll the queue first so an already-queued connection wins over shutdown.
                biased;
                connection = rx.recv() => connection.ok_or_else(stopped),
                _ = shutdown_rx.recv() => Err(stopped()),
            },
            None => rx.recv().await.ok_or_else(stopped),
        }
    }

    /// Binds [`bind_addr`](Self::bind_addr) and runs the accept loop until
    /// [`stop`](Self::stop) is called.
    ///
    /// Each accepted TCP stream is upgraded on its own task. Connections that pass the
    /// handshake, including the origin check, are queued for [`accept`](Self::accept).
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::Connection`] if the address cannot be bound. Errors on
    /// individual connections are logged and do not end the loop.
    pub async fn start(&self) -> Result<(), TransportError> {
        let listener = tokio::net::TcpListener::bind(&self.bind_addr)
            .await
            .map_err(|e| TransportError::Connection {
                message: format!("Failed to bind WebSocket listener: {e}"),
            })?;

        self.running.store(true, Ordering::Release);
        tracing::info!(addr = %self.bind_addr, "WebSocket listener started");

        // Subscribe exactly once, before the loop. A broadcast receiver only sees values sent
        // after it subscribed, so re-subscribing every iteration could miss a `stop()` that
        // lands between iterations and leave the loop parked in `listener.accept()`.
        let (shutdown_tx, mut shutdown_rx) = tokio::sync::broadcast::channel::<()>(1);
        *self.shutdown_tx.lock().await = Some(shutdown_tx);

        let mut next_connection_id: u64 = 0;

        while self.running.load(Ordering::Acquire) {
            tokio::select! {
                accept_result = listener.accept() => {
                    match accept_result {
                        Ok((stream, addr)) => {
                            let connection_id = next_connection_id;
                            next_connection_id = next_connection_id.wrapping_add(1);
                            self.spawn_handshake(stream, addr, connection_id);
                        }
                        // NOTE(review): persistent errors (e.g. too many open files) make
                        // this loop retry immediately; consider a short back-off.
                        Err(e) => {
                            if self.running.load(Ordering::Acquire) {
                                tracing::error!(error = %e, "Error accepting connection");
                            }
                        }
                    }
                }
                _ = shutdown_rx.recv() => {
                    tracing::info!("WebSocket listener shutting down");
                    break;
                }
            }
        }

        self.running.store(false, Ordering::Release);
        // Drop the shutdown sender if `stop()` has not already taken it. Any `accept()`
        // subscribed to it then sees the channel close instead of waiting forever.
        *self.shutdown_tx.lock().await = None;
        Ok(())
    }

    /// Upgrades one accepted TCP stream on a new task and queues the result for
    /// [`accept`](Self::accept).
    ///
    /// The handshake enforces [`WebSocketServerConfig::allowed_origins`] and applies
    /// [`WebSocketServerConfig::max_message_size`]. The connection counter is incremented
    /// here; a guard moved into the task decrements it when the task finishes.
    fn spawn_handshake(
        &self,
        stream: tokio::net::TcpStream,
        addr: std::net::SocketAddr,
        connection_id: u64,
    ) {
        use tokio_tungstenite::tungstenite::handshake::server::{ErrorResponse, Request, Response};
        use tokio_tungstenite::tungstenite::http::StatusCode;
        use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;

        /// Builds the HTTP 403 reply sent when the origin policy rejects a handshake.
        fn forbidden(reason: &str) -> ErrorResponse {
            let mut response = ErrorResponse::new(Some(reason.to_string()));
            *response.status_mut() = StatusCode::FORBIDDEN;
            response
        }

        tracing::debug!(peer = %addr, "Accepting WebSocket connection");

        let allowed_origins = self.config.allowed_origins.clone();
        let tx = self.connection_tx.clone();

        // Apply the configured message-size limit to the protocol layer. A value of 0 keeps
        // tungstenite's own default rather than rejecting every message.
        let mut ws_config = WebSocketConfig::default();
        if self.config.max_message_size > 0 {
            ws_config.max_message_size = Some(self.config.max_message_size);
        }

        self.active_connections.fetch_add(1, Ordering::Relaxed);
        let guard = ActiveConnectionGuard {
            counter: Arc::clone(&self.active_connections),
        };

        tokio::spawn(async move {
            // Held for the whole task; decrements the connection counter on drop.
            let _guard = guard;

            let callback =
                |request: &Request, response: Response| -> Result<Response, ErrorResponse> {
                    // An empty allow-list disables origin checking.
                    if allowed_origins.is_empty() {
                        return Ok(response);
                    }
                    match request.headers().get("origin") {
                        None => {
                            tracing::warn!(
                                peer = %addr,
                                "Rejecting WebSocket connection with missing Origin header"
                            );
                            Err(forbidden("Origin header required"))
                        }
                        Some(origin) => {
                            // Header values that are not visible ASCII compare as "".
                            let origin_str = origin.to_str().unwrap_or("");
                            if allowed_origins.iter().any(|o| o == origin_str) {
                                Ok(response)
                            } else {
                                tracing::warn!(
                                    peer = %addr,
                                    origin = %origin_str,
                                    "Rejecting WebSocket connection from disallowed origin"
                                );
                                Err(forbidden("Origin not allowed"))
                            }
                        }
                    }
                };

            match tokio_tungstenite::accept_hdr_async_with_config(stream, callback, Some(ws_config))
                .await
            {
                Ok(ws_stream) => {
                    tracing::info!(
                        peer = %addr,
                        connection_id = connection_id,
                        "WebSocket connection established"
                    );

                    let connection = AcceptedConnection {
                        stream: ws_stream,
                        peer_addr: addr,
                        connection_id,
                    };

                    if tx.send(connection).await.is_err() {
                        tracing::warn!(
                            connection_id = connection_id,
                            "Connection channel closed, dropping connection"
                        );
                    }
                }
                Err(e) => {
                    tracing::error!(
                        peer = %addr,
                        error = %e,
                        "WebSocket handshake failed"
                    );
                }
            }
        });
    }

    /// Signals the accept loop started by [`start`](Self::start) to exit.
    ///
    /// Handshake tasks already spawned are not cancelled; connections they complete are
    /// still queued.
    pub async fn stop(&self) {
        self.running.store(false, Ordering::Release);
        if let Some(tx) = self.shutdown_tx.lock().await.take() {
            let _ = tx.send(());
        }
        tracing::info!(
            active_connections = self.active_connections(),
            "WebSocket listener stopped"
        );
    }

    /// Returns `true` while the accept loop is running.
    #[must_use]
    pub fn is_running(&self) -> bool {
        self.running.load(Ordering::Acquire)
    }

    /// Returns the address this listener binds to.
    #[must_use]
    pub fn bind_addr(&self) -> &str {
        &self.bind_addr
    }
}
/// Decrements the listener's shared connection counter when dropped.
///
/// The listener increments the counter before creating the guard and moves the guard into
/// the per-connection task, so the decrement happens when that task finishes.
#[cfg(feature = "websocket")]
struct ActiveConnectionGuard {
    /// Counter shared with the owning `WebSocketListener`.
    counter: Arc<AtomicU64>,
}

#[cfg(feature = "websocket")]
impl Drop for ActiveConnectionGuard {
    fn drop(&mut self) {
        let shared: &AtomicU64 = &self.counter;
        shared.fetch_sub(1, Ordering::Relaxed);
    }
}
378#[cfg(not(feature = "websocket"))]
380pub struct WebSocketListener {
381 bind_addr: String,
382 config: WebSocketServerConfig,
383 running: AtomicBool,
384}
385
/// Placeholder for an accepted connection when the `websocket` feature is disabled.
///
/// Its only field is private, so code outside this module cannot construct it; the stub
/// listener never yields one.
#[cfg(not(feature = "websocket"))]
#[derive(Debug)]
pub struct AcceptedConnection {
    _private: (),
}
392#[cfg(not(feature = "websocket"))]
393impl WebSocketListener {
394 #[must_use]
396 pub fn new(bind_addr: impl Into<String>) -> Self {
397 Self {
398 bind_addr: bind_addr.into(),
399 config: WebSocketServerConfig::new(),
400 running: AtomicBool::new(false),
401 }
402 }
403
404 #[must_use]
406 pub fn with_config(bind_addr: impl Into<String>, config: WebSocketServerConfig) -> Self {
407 Self {
408 bind_addr: bind_addr.into(),
409 config,
410 running: AtomicBool::new(false),
411 }
412 }
413
414 #[must_use]
416 pub const fn config(&self) -> &WebSocketServerConfig {
417 &self.config
418 }
419
420 #[must_use]
422 pub fn active_connections(&self) -> u64 {
423 0
424 }
425
426 pub async fn accept(&self) -> Result<AcceptedConnection, TransportError> {
428 Err(TransportError::Connection {
429 message: "WebSocket transport requires the 'websocket' feature".to_string(),
430 })
431 }
432
433 pub async fn start(&self) -> Result<(), TransportError> {
435 Err(TransportError::Connection {
436 message: "WebSocket transport requires the 'websocket' feature".to_string(),
437 })
438 }
439
440 pub async fn stop(&self) {
442 self.running.store(false, Ordering::Release);
443 }
444
445 #[must_use]
447 pub fn is_running(&self) -> bool {
448 self.running.load(Ordering::Acquire)
449 }
450
451 #[must_use]
453 pub fn bind_addr(&self) -> &str {
454 &self.bind_addr
455 }
456}
457
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_listener_creation() {
        let listener = WebSocketListener::new("0.0.0.0:8080");
        assert_eq!(listener.bind_addr(), "0.0.0.0:8080");
        assert!(!listener.is_running());
        assert_eq!(listener.active_connections(), 0);
    }

    #[test]
    fn test_listener_with_config_keeps_config() {
        let config = WebSocketServerConfig::new()
            .with_allowed_origin("https://example.com")
            .with_max_message_size(1024);
        let listener = WebSocketListener::with_config("127.0.0.1:0", config);
        assert_eq!(
            listener.config().allowed_origins,
            vec!["https://example.com".to_string()]
        );
        assert_eq!(listener.config().max_message_size, 1024);
        assert!(!listener.is_running());
    }

    #[test]
    fn test_empty_allow_list_accepts_any_origin() {
        let config = WebSocketServerConfig::new();
        assert!(config.is_origin_allowed("https://anything.example"));
        assert!(config.is_origin_allowed(""));
    }

    #[test]
    fn test_allow_list_requires_exact_match() {
        let config = WebSocketServerConfig::new()
            .with_allowed_origins(["https://a.example", "https://b.example"]);
        assert!(config.is_origin_allowed("https://a.example"));
        assert!(config.is_origin_allowed("https://b.example"));
        assert!(!config.is_origin_allowed("https://c.example"));
        assert!(!config.is_origin_allowed("https://a.example/"));
        assert!(!config.is_origin_allowed("HTTPS://A.EXAMPLE"));
    }
}