Skip to main content

mockforge_ws/
handlers.rs

//! # Programmable WebSocket Handlers
//!
//! This module provides a flexible handler API for scripting WebSocket event flows.
//! Unlike static replay, handlers allow you to write custom logic for responding to
//! WebSocket events, manage rooms, and route messages dynamically.
//!
//! ## Features
//!
//! - **Connection Lifecycle**: `on_connect` and `on_disconnect` hooks
//! - **Pattern Matching**: Route messages with regex, JSONPath, or exact-text patterns
//! - **Room Management**: Broadcast messages to groups of connections
//! - **Passthrough**: Selectively forward messages to upstream servers
//! - **Hot Reload**: Automatically reload handlers when code changes (via `MOCKFORGE_WS_HOTRELOAD`)
//!
//! ## Quick Start
//!
//! ```rust,no_run
//! use mockforge_ws::handlers::{WsHandler, WsContext, WsMessage, HandlerResult};
//! use async_trait::async_trait;
//!
//! struct EchoHandler;
//!
//! #[async_trait]
//! impl WsHandler for EchoHandler {
//!     async fn on_connect(&self, ctx: &mut WsContext) -> HandlerResult<()> {
//!         ctx.send_text("Welcome to the echo server!").await?;
//!         Ok(())
//!     }
//!
//!     async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()> {
//!         if let WsMessage::Text(text) = msg {
//!             ctx.send_text(&format!("echo: {}", text)).await?;
//!         }
//!         Ok(())
//!     }
//! }
//! ```
//!
//! ## Message Pattern Matching
//!
//! ```rust,no_run
//! use mockforge_ws::handlers::{WsHandler, WsContext, WsMessage, HandlerResult, MessagePattern};
//! use async_trait::async_trait;
//!
//! struct ChatHandler;
//!
//! #[async_trait]
//! impl WsHandler for ChatHandler {
//!     async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()> {
//!         if let WsMessage::Text(text) = msg {
//!             // Use pattern matching to route messages
//!             if let Ok(pattern) = MessagePattern::regex(r"^/join (.+)$") {
//!                 if pattern.matches(&text) {
//!                     // Extract room name and join
//!                     if let Some(room) = text.strip_prefix("/join ") {
//!                         ctx.join_room(room).await?;
//!                         ctx.send_text(&format!("Joined room: {}", room)).await?;
//!                     }
//!                 }
//!             }
//!             // Handle JSON chat messages
//!             let jsonpath_pattern = MessagePattern::jsonpath("$.type");
//!             if jsonpath_pattern.matches(&text) {
//!                 ctx.broadcast_to_room("general", &text).await?;
//!             }
//!         }
//!         Ok(())
//!     }
//! }
//! ```

72use async_trait::async_trait;
73use axum::extract::ws::Message;
74use futures_util::{SinkExt, StreamExt};
75use regex::Regex;
76use serde_json::Value;
77use std::collections::{HashMap, HashSet};
78use std::sync::Arc;
79use tokio::sync::{broadcast, Mutex, RwLock};
80
81/// Result type for handler operations
82pub type HandlerResult<T> = Result<T, HandlerError>;
83
84/// Error type for handler operations
85#[derive(Debug, thiserror::Error)]
86pub enum HandlerError {
87    /// Failed to send WebSocket message
88    #[error("Failed to send message: {0}")]
89    SendError(String),
90
91    /// JSON parsing/serialization error
92    #[error("Failed to parse JSON: {0}")]
93    JsonError(#[from] serde_json::Error),
94
95    /// Pattern matching failure (e.g., route pattern)
96    #[error("Pattern matching error: {0}")]
97    PatternError(String),
98
99    /// Room/group operation failure
100    #[error("Room operation failed: {0}")]
101    RoomError(String),
102
103    /// WebSocket connection error
104    #[error("Connection error: {0}")]
105    ConnectionError(String),
106
107    /// Generic handler error
108    #[error("Handler error: {0}")]
109    Generic(String),
110}
111
/// Handler-facing representation of a single WebSocket frame.
///
/// Payloads are owned (`String` / `Vec<u8>`), so handlers can keep or move them freely.
#[derive(Debug, Clone)]
pub enum WsMessage {
    /// UTF-8 text payload.
    Text(String),
    /// Raw byte payload.
    Binary(Vec<u8>),
    /// Keepalive ping carrying an optional payload.
    Ping(Vec<u8>),
    /// Reply to a ping, carrying its payload.
    Pong(Vec<u8>),
    /// The peer is closing the connection (close code and reason are not retained).
    Close,
}
126
127impl From<Message> for WsMessage {
128    fn from(msg: Message) -> Self {
129        match msg {
130            Message::Text(text) => WsMessage::Text(text.to_string()),
131            Message::Binary(data) => WsMessage::Binary(data.to_vec()),
132            Message::Ping(data) => WsMessage::Ping(data.to_vec()),
133            Message::Pong(data) => WsMessage::Pong(data.to_vec()),
134            Message::Close(_) => WsMessage::Close,
135        }
136    }
137}
138
139impl From<WsMessage> for Message {
140    fn from(msg: WsMessage) -> Self {
141        match msg {
142            WsMessage::Text(text) => Message::Text(text.into()),
143            WsMessage::Binary(data) => Message::Binary(data.into()),
144            WsMessage::Ping(data) => Message::Ping(data.into()),
145            WsMessage::Pong(data) => Message::Pong(data.into()),
146            WsMessage::Close => Message::Close(None),
147        }
148    }
149}
150
151/// Pattern for matching WebSocket messages
152#[derive(Debug, Clone)]
153pub enum MessagePattern {
154    /// Match using regular expression
155    Regex(Regex),
156    /// Match using JSONPath query
157    JsonPath(String),
158    /// Match exact text
159    Exact(String),
160    /// Always matches
161    Any,
162}
163
164impl MessagePattern {
165    /// Create a regex pattern
166    pub fn regex(pattern: &str) -> HandlerResult<Self> {
167        Ok(MessagePattern::Regex(
168            Regex::new(pattern).map_err(|e| HandlerError::PatternError(e.to_string()))?,
169        ))
170    }
171
172    /// Create a JSONPath pattern
173    pub fn jsonpath(query: &str) -> Self {
174        MessagePattern::JsonPath(query.to_string())
175    }
176
177    /// Create an exact match pattern
178    pub fn exact(text: &str) -> Self {
179        MessagePattern::Exact(text.to_string())
180    }
181
182    /// Create a pattern that matches everything
183    pub fn any() -> Self {
184        MessagePattern::Any
185    }
186
187    /// Check if the pattern matches the message
188    pub fn matches(&self, text: &str) -> bool {
189        match self {
190            MessagePattern::Regex(re) => re.is_match(text),
191            MessagePattern::JsonPath(query) => {
192                // Try to parse as JSON and check if path exists
193                if let Ok(json) = serde_json::from_str::<Value>(text) {
194                    // Use jsonpath crate's Selector
195                    if let Ok(selector) = jsonpath::Selector::new(query) {
196                        let results: Vec<_> = selector.find(&json).collect();
197                        !results.is_empty()
198                    } else {
199                        false
200                    }
201                } else {
202                    false
203                }
204            }
205            MessagePattern::Exact(expected) => text == expected,
206            MessagePattern::Any => true,
207        }
208    }
209
210    /// Check if pattern matches and extract value using JSONPath
211    pub fn extract(&self, text: &str, query: &str) -> Option<Value> {
212        if let Ok(json) = serde_json::from_str::<Value>(text) {
213            if let Ok(selector) = jsonpath::Selector::new(query) {
214                let results: Vec<_> = selector.find(&json).collect();
215                results.first().cloned().cloned()
216            } else {
217                None
218            }
219        } else {
220            None
221        }
222    }
223}
224
225/// Connection ID type
226pub type ConnectionId = String;
227
228/// Room manager for broadcasting messages to groups of connections
229#[derive(Clone)]
230pub struct RoomManager {
231    rooms: Arc<RwLock<HashMap<String, HashSet<ConnectionId>>>>,
232    connections: Arc<RwLock<HashMap<ConnectionId, HashSet<String>>>>,
233    broadcasters: Arc<RwLock<HashMap<String, broadcast::Sender<String>>>>,
234}
235
236impl RoomManager {
237    /// Create a new room manager
238    pub fn new() -> Self {
239        Self {
240            rooms: Arc::new(RwLock::new(HashMap::new())),
241            connections: Arc::new(RwLock::new(HashMap::new())),
242            broadcasters: Arc::new(RwLock::new(HashMap::new())),
243        }
244    }
245
246    /// Join a room
247    pub async fn join(&self, conn_id: &str, room: &str) -> HandlerResult<()> {
248        let mut rooms = self.rooms.write().await;
249        let mut connections = self.connections.write().await;
250
251        rooms
252            .entry(room.to_string())
253            .or_insert_with(HashSet::new)
254            .insert(conn_id.to_string());
255
256        connections
257            .entry(conn_id.to_string())
258            .or_insert_with(HashSet::new)
259            .insert(room.to_string());
260
261        Ok(())
262    }
263
264    /// Leave a room
265    pub async fn leave(&self, conn_id: &str, room: &str) -> HandlerResult<()> {
266        let mut rooms = self.rooms.write().await;
267        let mut connections = self.connections.write().await;
268
269        if let Some(room_members) = rooms.get_mut(room) {
270            room_members.remove(conn_id);
271            if room_members.is_empty() {
272                rooms.remove(room);
273            }
274        }
275
276        if let Some(conn_rooms) = connections.get_mut(conn_id) {
277            conn_rooms.remove(room);
278            if conn_rooms.is_empty() {
279                connections.remove(conn_id);
280            }
281        }
282
283        Ok(())
284    }
285
286    /// Leave all rooms for a connection
287    pub async fn leave_all(&self, conn_id: &str) -> HandlerResult<()> {
288        let mut connections = self.connections.write().await;
289        if let Some(conn_rooms) = connections.remove(conn_id) {
290            let mut rooms = self.rooms.write().await;
291            for room in conn_rooms {
292                if let Some(room_members) = rooms.get_mut(&room) {
293                    room_members.remove(conn_id);
294                    if room_members.is_empty() {
295                        rooms.remove(&room);
296                    }
297                }
298            }
299        }
300        Ok(())
301    }
302
303    /// Get all connections in a room
304    pub async fn get_room_members(&self, room: &str) -> Vec<ConnectionId> {
305        let rooms = self.rooms.read().await;
306        rooms
307            .get(room)
308            .map(|members| members.iter().cloned().collect())
309            .unwrap_or_default()
310    }
311
312    /// Get all rooms for a connection
313    pub async fn get_connection_rooms(&self, conn_id: &str) -> Vec<String> {
314        let connections = self.connections.read().await;
315        connections
316            .get(conn_id)
317            .map(|rooms| rooms.iter().cloned().collect())
318            .unwrap_or_default()
319    }
320
321    /// Get broadcast sender for a room (creates if doesn't exist)
322    pub async fn get_broadcaster(&self, room: &str) -> broadcast::Sender<String> {
323        let mut broadcasters = self.broadcasters.write().await;
324        broadcasters
325            .entry(room.to_string())
326            .or_insert_with(|| {
327                let (tx, _) = broadcast::channel(1024);
328                tx
329            })
330            .clone()
331    }
332}
333
334impl Default for RoomManager {
335    fn default() -> Self {
336        Self::new()
337    }
338}
339
340/// Context provided to handlers for each connection
341pub struct WsContext {
342    /// Unique connection ID
343    pub connection_id: ConnectionId,
344    /// WebSocket path
345    pub path: String,
346    /// Room manager for broadcasting
347    room_manager: RoomManager,
348    /// Sender for outgoing messages
349    message_tx: tokio::sync::mpsc::UnboundedSender<Message>,
350    /// Metadata storage
351    metadata: Arc<RwLock<HashMap<String, Value>>>,
352}
353
354impl WsContext {
355    /// Create a new WebSocket context
356    pub fn new(
357        connection_id: ConnectionId,
358        path: String,
359        room_manager: RoomManager,
360        message_tx: tokio::sync::mpsc::UnboundedSender<Message>,
361    ) -> Self {
362        Self {
363            connection_id,
364            path,
365            room_manager,
366            message_tx,
367            metadata: Arc::new(RwLock::new(HashMap::new())),
368        }
369    }
370
371    /// Send a text message
372    pub async fn send_text(&self, text: &str) -> HandlerResult<()> {
373        self.message_tx
374            .send(Message::Text(text.to_string().into()))
375            .map_err(|e| HandlerError::SendError(e.to_string()))
376    }
377
378    /// Send a binary message
379    pub async fn send_binary(&self, data: Vec<u8>) -> HandlerResult<()> {
380        self.message_tx
381            .send(Message::Binary(data.into()))
382            .map_err(|e| HandlerError::SendError(e.to_string()))
383    }
384
385    /// Send a JSON message
386    pub async fn send_json(&self, value: &Value) -> HandlerResult<()> {
387        let text = serde_json::to_string(value)?;
388        self.send_text(&text).await
389    }
390
391    /// Join a room
392    pub async fn join_room(&self, room: &str) -> HandlerResult<()> {
393        self.room_manager.join(&self.connection_id, room).await
394    }
395
396    /// Leave a room
397    pub async fn leave_room(&self, room: &str) -> HandlerResult<()> {
398        self.room_manager.leave(&self.connection_id, room).await
399    }
400
401    /// Broadcast text to all members in a room
402    pub async fn broadcast_to_room(&self, room: &str, text: &str) -> HandlerResult<()> {
403        let broadcaster = self.room_manager.get_broadcaster(room).await;
404        broadcaster
405            .send(text.to_string())
406            .map_err(|e| HandlerError::RoomError(e.to_string()))?;
407        Ok(())
408    }
409
410    /// Get all rooms this connection is in
411    pub async fn get_rooms(&self) -> Vec<String> {
412        self.room_manager.get_connection_rooms(&self.connection_id).await
413    }
414
415    /// Set metadata value
416    pub async fn set_metadata(&self, key: &str, value: Value) {
417        let mut metadata = self.metadata.write().await;
418        metadata.insert(key.to_string(), value);
419    }
420
421    /// Get metadata value
422    pub async fn get_metadata(&self, key: &str) -> Option<Value> {
423        let metadata = self.metadata.read().await;
424        metadata.get(key).cloned()
425    }
426}
427
428/// Trait for WebSocket message handlers
429#[async_trait]
430pub trait WsHandler: Send + Sync {
431    /// Called when a new WebSocket connection is established
432    async fn on_connect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
433        Ok(())
434    }
435
436    /// Called when a message is received
437    async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()>;
438
439    /// Called when the connection is closed
440    async fn on_disconnect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
441        Ok(())
442    }
443
444    /// Check if this handler should handle the given path
445    fn handles_path(&self, _path: &str) -> bool {
446        true // Default: handle all paths
447    }
448}
449
450/// Handler function type for message routing
451type MessageHandler = Box<dyn Fn(String) -> Option<String> + Send + Sync>;
452
453/// Pattern-based message router
454pub struct MessageRouter {
455    routes: Vec<(MessagePattern, MessageHandler)>,
456}
457
458impl MessageRouter {
459    /// Create a new message router
460    pub fn new() -> Self {
461        Self { routes: Vec::new() }
462    }
463
464    /// Add a route with a pattern and handler function
465    pub fn on<F>(&mut self, pattern: MessagePattern, handler: F) -> &mut Self
466    where
467        F: Fn(String) -> Option<String> + Send + Sync + 'static,
468    {
469        self.routes.push((pattern, Box::new(handler)));
470        self
471    }
472
473    /// Route a message through the registered handlers
474    pub fn route(&self, text: &str) -> Option<String> {
475        for (pattern, handler) in &self.routes {
476            if pattern.matches(text) {
477                if let Some(response) = handler(text.to_string()) {
478                    return Some(response);
479                }
480            }
481        }
482        None
483    }
484}
485
486impl Default for MessageRouter {
487    fn default() -> Self {
488        Self::new()
489    }
490}
491
492/// Handler registry for managing multiple handlers
493pub struct HandlerRegistry {
494    handlers: Vec<Arc<dyn WsHandler>>,
495    hot_reload_enabled: bool,
496}
497
498impl HandlerRegistry {
499    /// Create a new handler registry
500    pub fn new() -> Self {
501        Self {
502            handlers: Vec::new(),
503            hot_reload_enabled: std::env::var("MOCKFORGE_WS_HOTRELOAD")
504                .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
505                .unwrap_or(false),
506        }
507    }
508
509    /// Create a registry with hot-reload enabled
510    pub fn with_hot_reload() -> Self {
511        Self {
512            handlers: Vec::new(),
513            hot_reload_enabled: true,
514        }
515    }
516
517    /// Check if hot-reload is enabled
518    pub fn is_hot_reload_enabled(&self) -> bool {
519        self.hot_reload_enabled
520    }
521
522    /// Register a handler
523    pub fn register<H: WsHandler + 'static>(&mut self, handler: H) -> &mut Self {
524        self.handlers.push(Arc::new(handler));
525        self
526    }
527
528    /// Get handlers for a specific path
529    pub fn get_handlers(&self, path: &str) -> Vec<Arc<dyn WsHandler>> {
530        self.handlers.iter().filter(|h| h.handles_path(path)).cloned().collect()
531    }
532
533    /// Check if any handler handles the given path
534    pub fn has_handler_for(&self, path: &str) -> bool {
535        self.handlers.iter().any(|h| h.handles_path(path))
536    }
537
538    /// Clear all handlers (useful for hot-reload)
539    pub fn clear(&mut self) {
540        self.handlers.clear();
541    }
542
543    /// Get the number of registered handlers
544    pub fn len(&self) -> usize {
545        self.handlers.len()
546    }
547
548    /// Check if the registry is empty
549    pub fn is_empty(&self) -> bool {
550        self.handlers.is_empty()
551    }
552}
553
554impl Default for HandlerRegistry {
555    fn default() -> Self {
556        Self::new()
557    }
558}
559
560/// Passthrough handler configuration for forwarding messages to upstream servers
561#[derive(Clone)]
562pub struct PassthroughConfig {
563    /// Pattern to match paths for passthrough
564    pub pattern: MessagePattern,
565    /// Upstream URL to forward to
566    pub upstream_url: String,
567}
568
569impl PassthroughConfig {
570    /// Create a new passthrough configuration
571    pub fn new(pattern: MessagePattern, upstream_url: String) -> Self {
572        Self {
573            pattern,
574            upstream_url,
575        }
576    }
577
578    /// Create a passthrough for all messages matching a regex
579    pub fn regex(regex: &str, upstream_url: String) -> HandlerResult<Self> {
580        Ok(Self {
581            pattern: MessagePattern::regex(regex)?,
582            upstream_url,
583        })
584    }
585}
586
587/// Passthrough handler that forwards messages to an upstream WebSocket server
588pub struct PassthroughHandler {
589    config: PassthroughConfig,
590    /// Upstream WebSocket write half, established lazily on first message
591    upstream_tx: Mutex<Option<UpstreamSender>>,
592}
593
594type UpstreamSender = futures_util::stream::SplitSink<
595    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
596    tokio_tungstenite::tungstenite::Message,
597>;
598
599impl PassthroughHandler {
600    /// Create a new passthrough handler
601    pub fn new(config: PassthroughConfig) -> Self {
602        Self {
603            config,
604            upstream_tx: Mutex::new(None),
605        }
606    }
607
608    /// Check if a message should be passed through
609    pub fn should_passthrough(&self, text: &str) -> bool {
610        self.config.pattern.matches(text)
611    }
612
613    /// Get the upstream URL
614    pub fn upstream_url(&self) -> &str {
615        &self.config.upstream_url
616    }
617
618    /// Connect to the upstream WebSocket server and spawn a reader task that
619    /// relays messages back to the client.
620    async fn ensure_connected(
621        &self,
622        client_tx: &tokio::sync::mpsc::UnboundedSender<Message>,
623    ) -> HandlerResult<()> {
624        let mut guard = self.upstream_tx.lock().await;
625        if guard.is_some() {
626            return Ok(());
627        }
628
629        let url = &self.config.upstream_url;
630        tracing::info!(upstream = %url, "Connecting to upstream WebSocket server");
631
632        let (ws_stream, _response) = tokio_tungstenite::connect_async(url)
633            .await
634            .map_err(|e| HandlerError::ConnectionError(format!("Upstream connect failed: {e}")))?;
635
636        let (write, mut read) = ws_stream.split();
637        *guard = Some(write);
638
639        // Spawn a task that forwards upstream → client
640        let client_tx = client_tx.clone();
641        tokio::spawn(async move {
642            while let Some(Ok(msg)) = read.next().await {
643                let axum_msg = match msg {
644                    tokio_tungstenite::tungstenite::Message::Text(t) => {
645                        Message::Text(t.to_string().into())
646                    }
647                    tokio_tungstenite::tungstenite::Message::Binary(b) => {
648                        Message::Binary(b.to_vec().into())
649                    }
650                    tokio_tungstenite::tungstenite::Message::Ping(p) => {
651                        Message::Ping(p.to_vec().into())
652                    }
653                    tokio_tungstenite::tungstenite::Message::Pong(p) => {
654                        Message::Pong(p.to_vec().into())
655                    }
656                    tokio_tungstenite::tungstenite::Message::Close(_) => {
657                        break;
658                    }
659                    tokio_tungstenite::tungstenite::Message::Frame(_) => continue,
660                };
661                if client_tx.send(axum_msg).is_err() {
662                    break;
663                }
664            }
665            tracing::debug!("Upstream reader task finished");
666        });
667
668        Ok(())
669    }
670}
671
672#[async_trait]
673impl WsHandler for PassthroughHandler {
674    async fn on_connect(&self, ctx: &mut WsContext) -> HandlerResult<()> {
675        self.ensure_connected(&ctx.message_tx).await
676    }
677
678    async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()> {
679        match &msg {
680            WsMessage::Text(text) if self.should_passthrough(text) => {
681                self.ensure_connected(&ctx.message_tx).await?;
682                let mut guard = self.upstream_tx.lock().await;
683                if let Some(ref mut writer) = *guard {
684                    writer
685                        .send(tokio_tungstenite::tungstenite::Message::Text(text.clone().into()))
686                        .await
687                        .map_err(|e| {
688                            HandlerError::SendError(format!("Upstream send failed: {e}"))
689                        })?;
690                }
691            }
692            WsMessage::Binary(data) => {
693                self.ensure_connected(&ctx.message_tx).await?;
694                let mut guard = self.upstream_tx.lock().await;
695                if let Some(ref mut writer) = *guard {
696                    writer
697                        .send(tokio_tungstenite::tungstenite::Message::Binary(data.clone().into()))
698                        .await
699                        .map_err(|e| {
700                            HandlerError::SendError(format!("Upstream send failed: {e}"))
701                        })?;
702                }
703            }
704            _ => {}
705        }
706        Ok(())
707    }
708
709    async fn on_disconnect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
710        let mut guard = self.upstream_tx.lock().await;
711        if let Some(mut writer) = guard.take() {
712            let _ = writer.send(tokio_tungstenite::tungstenite::Message::Close(None)).await;
713        }
714        Ok(())
715    }
716}
717
718#[cfg(test)]
719mod tests {
720    use super::*;
721
722    // ==================== WsMessage Tests ====================
723
724    #[test]
725    fn test_ws_message_text_from_axum() {
726        let axum_msg = Message::Text("hello".to_string().into());
727        let ws_msg: WsMessage = axum_msg.into();
728        match ws_msg {
729            WsMessage::Text(text) => assert_eq!(text, "hello"),
730            _ => panic!("Expected Text message"),
731        }
732    }
733
734    #[test]
735    fn test_ws_message_binary_from_axum() {
736        let data = vec![1, 2, 3, 4];
737        let axum_msg = Message::Binary(data.clone().into());
738        let ws_msg: WsMessage = axum_msg.into();
739        match ws_msg {
740            WsMessage::Binary(bytes) => assert_eq!(bytes, data),
741            _ => panic!("Expected Binary message"),
742        }
743    }
744
745    #[test]
746    fn test_ws_message_ping_from_axum() {
747        let data = vec![1, 2];
748        let axum_msg = Message::Ping(data.clone().into());
749        let ws_msg: WsMessage = axum_msg.into();
750        match ws_msg {
751            WsMessage::Ping(bytes) => assert_eq!(bytes, data),
752            _ => panic!("Expected Ping message"),
753        }
754    }
755
756    #[test]
757    fn test_ws_message_pong_from_axum() {
758        let data = vec![3, 4];
759        let axum_msg = Message::Pong(data.clone().into());
760        let ws_msg: WsMessage = axum_msg.into();
761        match ws_msg {
762            WsMessage::Pong(bytes) => assert_eq!(bytes, data),
763            _ => panic!("Expected Pong message"),
764        }
765    }
766
767    #[test]
768    fn test_ws_message_close_from_axum() {
769        let axum_msg = Message::Close(None);
770        let ws_msg: WsMessage = axum_msg.into();
771        assert!(matches!(ws_msg, WsMessage::Close));
772    }
773
774    #[test]
775    fn test_ws_message_text_to_axum() {
776        let ws_msg = WsMessage::Text("hello".to_string());
777        let axum_msg: Message = ws_msg.into();
778        assert!(matches!(axum_msg, Message::Text(_)));
779    }
780
781    #[test]
782    fn test_ws_message_binary_to_axum() {
783        let ws_msg = WsMessage::Binary(vec![1, 2, 3]);
784        let axum_msg: Message = ws_msg.into();
785        assert!(matches!(axum_msg, Message::Binary(_)));
786    }
787
788    #[test]
789    fn test_ws_message_close_to_axum() {
790        let ws_msg = WsMessage::Close;
791        let axum_msg: Message = ws_msg.into();
792        assert!(matches!(axum_msg, Message::Close(_)));
793    }
794
795    // ==================== MessagePattern Tests ====================
796
797    #[test]
798    fn test_message_pattern_regex() {
799        let pattern = MessagePattern::regex(r"^hello").unwrap();
800        assert!(pattern.matches("hello world"));
801        assert!(!pattern.matches("goodbye world"));
802    }
803
804    #[test]
805    fn test_message_pattern_regex_invalid() {
806        let result = MessagePattern::regex(r"[invalid");
807        assert!(result.is_err());
808    }
809
810    #[test]
811    fn test_message_pattern_exact() {
812        let pattern = MessagePattern::exact("hello");
813        assert!(pattern.matches("hello"));
814        assert!(!pattern.matches("hello world"));
815    }
816
817    #[test]
818    fn test_message_pattern_jsonpath() {
819        let pattern = MessagePattern::jsonpath("$.type");
820        assert!(pattern.matches(r#"{"type": "message"}"#));
821        assert!(!pattern.matches(r#"{"name": "test"}"#));
822    }
823
824    #[test]
825    fn test_message_pattern_jsonpath_nested() {
826        let pattern = MessagePattern::jsonpath("$.user.name");
827        assert!(pattern.matches(r#"{"user": {"name": "John"}}"#));
828        assert!(!pattern.matches(r#"{"user": {"email": "john@example.com"}}"#));
829    }
830
831    #[test]
832    fn test_message_pattern_jsonpath_invalid_json() {
833        let pattern = MessagePattern::jsonpath("$.type");
834        assert!(!pattern.matches("not json"));
835    }
836
837    #[test]
838    fn test_message_pattern_any() {
839        let pattern = MessagePattern::any();
840        assert!(pattern.matches("anything"));
841        assert!(pattern.matches(""));
842        assert!(pattern.matches(r#"{"json": true}"#));
843    }
844
845    #[test]
846    fn test_message_pattern_extract() {
847        let pattern = MessagePattern::jsonpath("$.type");
848        let result = pattern.extract(r#"{"type": "greeting", "data": "hello"}"#, "$.type");
849        assert_eq!(result, Some(serde_json::json!("greeting")));
850    }
851
852    #[test]
853    fn test_message_pattern_extract_nested() {
854        let pattern = MessagePattern::any();
855        let result = pattern.extract(r#"{"user": {"id": 123}}"#, "$.user.id");
856        assert_eq!(result, Some(serde_json::json!(123)));
857    }
858
859    #[test]
860    fn test_message_pattern_extract_not_found() {
861        let pattern = MessagePattern::any();
862        let result = pattern.extract(r#"{"type": "message"}"#, "$.nonexistent");
863        assert!(result.is_none());
864    }
865
866    #[test]
867    fn test_message_pattern_extract_invalid_json() {
868        let pattern = MessagePattern::any();
869        let result = pattern.extract("not json", "$.type");
870        assert!(result.is_none());
871    }
872
873    // ==================== RoomManager Tests ====================
874
875    #[tokio::test]
876    async fn test_room_manager() {
877        let manager = RoomManager::new();
878
879        // Join rooms
880        manager.join("conn1", "room1").await.unwrap();
881        manager.join("conn1", "room2").await.unwrap();
882        manager.join("conn2", "room1").await.unwrap();
883
884        // Check room members
885        let room1_members = manager.get_room_members("room1").await;
886        assert_eq!(room1_members.len(), 2);
887        assert!(room1_members.contains(&"conn1".to_string()));
888        assert!(room1_members.contains(&"conn2".to_string()));
889
890        // Check connection rooms
891        let conn1_rooms = manager.get_connection_rooms("conn1").await;
892        assert_eq!(conn1_rooms.len(), 2);
893        assert!(conn1_rooms.contains(&"room1".to_string()));
894        assert!(conn1_rooms.contains(&"room2".to_string()));
895
896        // Leave room
897        manager.leave("conn1", "room1").await.unwrap();
898        let room1_members = manager.get_room_members("room1").await;
899        assert_eq!(room1_members.len(), 1);
900        assert!(room1_members.contains(&"conn2".to_string()));
901
902        // Leave all rooms
903        manager.leave_all("conn1").await.unwrap();
904        let conn1_rooms = manager.get_connection_rooms("conn1").await;
905        assert_eq!(conn1_rooms.len(), 0);
906    }
907
908    #[tokio::test]
909    async fn test_room_manager_default() {
910        let manager = RoomManager::default();
911        // Should work the same as new()
912        manager.join("conn1", "room1").await.unwrap();
913        let members = manager.get_room_members("room1").await;
914        assert_eq!(members.len(), 1);
915    }
916
917    #[tokio::test]
918    async fn test_room_manager_empty_room() {
919        let manager = RoomManager::new();
920        let members = manager.get_room_members("nonexistent").await;
921        assert!(members.is_empty());
922    }
923
924    #[tokio::test]
925    async fn test_room_manager_empty_connection() {
926        let manager = RoomManager::new();
927        let rooms = manager.get_connection_rooms("nonexistent").await;
928        assert!(rooms.is_empty());
929    }
930
931    #[tokio::test]
932    async fn test_room_manager_leave_nonexistent() {
933        let manager = RoomManager::new();
934        // Should not error when leaving a room we're not in
935        let result = manager.leave("conn1", "room1").await;
936        assert!(result.is_ok());
937    }
938
939    #[tokio::test]
940    async fn test_room_manager_broadcaster() {
941        let manager = RoomManager::new();
942        manager.join("conn1", "room1").await.unwrap();
943
944        let broadcaster = manager.get_broadcaster("room1").await;
945        let mut receiver = broadcaster.subscribe();
946
947        // Send a message
948        broadcaster.send("hello".to_string()).unwrap();
949
950        // Receive it
951        let msg = receiver.recv().await.unwrap();
952        assert_eq!(msg, "hello");
953    }
954
955    #[tokio::test]
956    async fn test_room_manager_room_cleanup_on_last_leave() {
957        let manager = RoomManager::new();
958        manager.join("conn1", "room1").await.unwrap();
959        manager.leave("conn1", "room1").await.unwrap();
960
961        // Room should be cleaned up
962        let members = manager.get_room_members("room1").await;
963        assert!(members.is_empty());
964    }
965
966    // ==================== MessageRouter Tests ====================
967
968    #[test]
969    fn test_message_router() {
970        let mut router = MessageRouter::new();
971
972        router
973            .on(MessagePattern::exact("ping"), |_| Some("pong".to_string()))
974            .on(MessagePattern::regex(r"^hello").unwrap(), |_| Some("hi there!".to_string()));
975
976        assert_eq!(router.route("ping"), Some("pong".to_string()));
977        assert_eq!(router.route("hello world"), Some("hi there!".to_string()));
978        assert_eq!(router.route("goodbye"), None);
979    }
980
981    #[test]
982    fn test_message_router_default() {
983        let router = MessageRouter::default();
984        // Empty router returns None for all messages
985        assert_eq!(router.route("anything"), None);
986    }
987
988    #[test]
989    fn test_message_router_first_match_wins() {
990        let mut router = MessageRouter::new();
991        router
992            .on(MessagePattern::any(), |_| Some("first".to_string()))
993            .on(MessagePattern::any(), |_| Some("second".to_string()));
994
995        assert_eq!(router.route("test"), Some("first".to_string()));
996    }
997
998    #[test]
999    fn test_message_router_handler_returns_none() {
1000        let mut router = MessageRouter::new();
1001        router
1002            .on(MessagePattern::exact("skip"), |_| None)
1003            .on(MessagePattern::any(), |_| Some("fallback".to_string()));
1004
1005        // First pattern matches but returns None, so it continues to next
1006        assert_eq!(router.route("skip"), Some("fallback".to_string()));
1007    }
1008
1009    // ==================== HandlerRegistry Tests ====================
1010
1011    struct TestHandler;
1012
1013    #[async_trait]
1014    impl WsHandler for TestHandler {
1015        async fn on_message(&self, _ctx: &mut WsContext, _msg: WsMessage) -> HandlerResult<()> {
1016            Ok(())
1017        }
1018    }
1019
1020    struct PathSpecificHandler {
1021        path: String,
1022    }
1023
1024    #[async_trait]
1025    impl WsHandler for PathSpecificHandler {
1026        async fn on_message(&self, _ctx: &mut WsContext, _msg: WsMessage) -> HandlerResult<()> {
1027            Ok(())
1028        }
1029
1030        fn handles_path(&self, path: &str) -> bool {
1031            path == self.path
1032        }
1033    }
1034
1035    #[test]
1036    fn test_handler_registry_new() {
1037        let registry = HandlerRegistry::new();
1038        assert!(registry.is_empty());
1039        assert_eq!(registry.len(), 0);
1040    }
1041
1042    #[test]
1043    fn test_handler_registry_default() {
1044        let registry = HandlerRegistry::default();
1045        assert!(registry.is_empty());
1046    }
1047
1048    #[test]
1049    fn test_handler_registry_register() {
1050        let mut registry = HandlerRegistry::new();
1051        registry.register(TestHandler);
1052        assert!(!registry.is_empty());
1053        assert_eq!(registry.len(), 1);
1054    }
1055
1056    #[test]
1057    fn test_handler_registry_get_handlers() {
1058        let mut registry = HandlerRegistry::new();
1059        registry.register(TestHandler);
1060
1061        let handlers = registry.get_handlers("/any/path");
1062        assert_eq!(handlers.len(), 1);
1063    }
1064
1065    #[test]
1066    fn test_handler_registry_path_filtering() {
1067        let mut registry = HandlerRegistry::new();
1068        registry.register(PathSpecificHandler {
1069            path: "/ws/chat".to_string(),
1070        });
1071        registry.register(PathSpecificHandler {
1072            path: "/ws/events".to_string(),
1073        });
1074
1075        let chat_handlers = registry.get_handlers("/ws/chat");
1076        assert_eq!(chat_handlers.len(), 1);
1077
1078        let events_handlers = registry.get_handlers("/ws/events");
1079        assert_eq!(events_handlers.len(), 1);
1080
1081        let other_handlers = registry.get_handlers("/ws/other");
1082        assert!(other_handlers.is_empty());
1083    }
1084
1085    #[test]
1086    fn test_handler_registry_has_handler_for() {
1087        let mut registry = HandlerRegistry::new();
1088        registry.register(PathSpecificHandler {
1089            path: "/ws/chat".to_string(),
1090        });
1091
1092        assert!(registry.has_handler_for("/ws/chat"));
1093        assert!(!registry.has_handler_for("/ws/other"));
1094    }
1095
1096    #[test]
1097    fn test_handler_registry_clear() {
1098        let mut registry = HandlerRegistry::new();
1099        registry.register(TestHandler);
1100        registry.register(TestHandler);
1101        assert_eq!(registry.len(), 2);
1102
1103        registry.clear();
1104        assert!(registry.is_empty());
1105    }
1106
1107    #[test]
1108    fn test_handler_registry_with_hot_reload() {
1109        let registry = HandlerRegistry::with_hot_reload();
1110        assert!(registry.is_hot_reload_enabled());
1111    }
1112
1113    // ==================== PassthroughConfig Tests ====================
1114
1115    #[test]
1116    fn test_passthrough_config_new() {
1117        let config =
1118            PassthroughConfig::new(MessagePattern::any(), "ws://upstream:8080".to_string());
1119        assert_eq!(config.upstream_url, "ws://upstream:8080");
1120    }
1121
1122    #[test]
1123    fn test_passthrough_config_regex() {
1124        let config =
1125            PassthroughConfig::regex(r"^forward", "ws://upstream:8080".to_string()).unwrap();
1126        assert!(config.pattern.matches("forward this"));
1127        assert!(!config.pattern.matches("don't forward"));
1128    }
1129
1130    #[test]
1131    fn test_passthrough_config_regex_invalid() {
1132        let result = PassthroughConfig::regex(r"[invalid", "ws://upstream:8080".to_string());
1133        assert!(result.is_err());
1134    }
1135
1136    // ==================== PassthroughHandler Tests ====================
1137
1138    #[test]
1139    fn test_passthrough_handler_should_passthrough() {
1140        let config =
1141            PassthroughConfig::regex(r"^proxy:", "ws://upstream:8080".to_string()).unwrap();
1142        let handler = PassthroughHandler::new(config);
1143
1144        assert!(handler.should_passthrough("proxy:hello"));
1145        assert!(!handler.should_passthrough("hello"));
1146    }
1147
1148    #[test]
1149    fn test_passthrough_handler_upstream_url() {
1150        let config =
1151            PassthroughConfig::new(MessagePattern::any(), "ws://upstream:8080".to_string());
1152        let handler = PassthroughHandler::new(config);
1153
1154        assert_eq!(handler.upstream_url(), "ws://upstream:8080");
1155    }
1156
1157    // ==================== HandlerError Tests ====================
1158
1159    #[test]
1160    fn test_handler_error_send_error() {
1161        let err = HandlerError::SendError("connection closed".to_string());
1162        assert!(err.to_string().contains("send message"));
1163        assert!(err.to_string().contains("connection closed"));
1164    }
1165
1166    #[test]
1167    fn test_handler_error_json_error() {
1168        let json_err = serde_json::from_str::<serde_json::Value>("invalid").unwrap_err();
1169        let err = HandlerError::JsonError(json_err);
1170        assert!(err.to_string().contains("JSON"));
1171    }
1172
1173    #[test]
1174    fn test_handler_error_pattern_error() {
1175        let err = HandlerError::PatternError("invalid regex".to_string());
1176        assert!(err.to_string().contains("Pattern"));
1177    }
1178
1179    #[test]
1180    fn test_handler_error_room_error() {
1181        let err = HandlerError::RoomError("room full".to_string());
1182        assert!(err.to_string().contains("Room"));
1183    }
1184
1185    #[test]
1186    fn test_handler_error_connection_error() {
1187        let err = HandlerError::ConnectionError("timeout".to_string());
1188        assert!(err.to_string().contains("Connection"));
1189    }
1190
1191    #[test]
1192    fn test_handler_error_generic() {
1193        let err = HandlerError::Generic("something went wrong".to_string());
1194        assert!(err.to_string().contains("something went wrong"));
1195    }
1196
1197    // ==================== WsContext Tests ====================
1198
1199    #[tokio::test]
1200    async fn test_ws_context_metadata() {
1201        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
1202        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);
1203
1204        // Set and get metadata
1205        ctx.set_metadata("user", serde_json::json!({"id": 1})).await;
1206        let value = ctx.get_metadata("user").await;
1207        assert_eq!(value, Some(serde_json::json!({"id": 1})));
1208
1209        // Get nonexistent key
1210        let missing = ctx.get_metadata("nonexistent").await;
1211        assert!(missing.is_none());
1212    }
1213
1214    #[tokio::test]
1215    async fn test_ws_context_send_text() {
1216        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1217        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);
1218
1219        ctx.send_text("hello").await.unwrap();
1220
1221        let msg = rx.recv().await.unwrap();
1222        assert!(matches!(msg, Message::Text(_)));
1223    }
1224
1225    #[tokio::test]
1226    async fn test_ws_context_send_binary() {
1227        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1228        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);
1229
1230        ctx.send_binary(vec![1, 2, 3]).await.unwrap();
1231
1232        let msg = rx.recv().await.unwrap();
1233        assert!(matches!(msg, Message::Binary(_)));
1234    }
1235
1236    #[tokio::test]
1237    async fn test_ws_context_send_json() {
1238        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1239        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);
1240
1241        ctx.send_json(&serde_json::json!({"type": "test"})).await.unwrap();
1242
1243        let msg = rx.recv().await.unwrap();
1244        assert!(matches!(msg, Message::Text(_)));
1245    }
1246
1247    #[tokio::test]
1248    async fn test_ws_context_rooms() {
1249        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
1250        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);
1251
1252        // Join rooms
1253        ctx.join_room("chat").await.unwrap();
1254        ctx.join_room("notifications").await.unwrap();
1255
1256        let rooms = ctx.get_rooms().await;
1257        assert_eq!(rooms.len(), 2);
1258
1259        // Leave room
1260        ctx.leave_room("chat").await.unwrap();
1261        let rooms = ctx.get_rooms().await;
1262        assert_eq!(rooms.len(), 1);
1263    }
1264}