Skip to main content

mockforge_ws/
handlers.rs

1//! # Programmable WebSocket Handlers
2//!
3//! This module provides a flexible handler API for scripting WebSocket event flows.
4//! Unlike static replay, handlers allow you to write custom logic for responding to
5//! WebSocket events, manage rooms, and route messages dynamically.
6//!
7//! ## Features
8//!
9//! - **Connection Lifecycle**: `on_connect` and `on_disconnect` hooks
10//! - **Pattern Matching**: Route messages with regex or JSONPath patterns
11//! - **Room Management**: Broadcast messages to groups of connections
12//! - **Passthrough**: Selectively forward messages to upstream servers
13//! - **Hot Reload**: Automatically reload handlers when code changes (via `MOCKFORGE_WS_HOTRELOAD`)
14//!
15//! ## Quick Start
16//!
17//! ```rust,no_run
18//! use mockforge_ws::handlers::{WsHandler, WsContext, WsMessage, HandlerResult};
19//! use async_trait::async_trait;
20//!
21//! struct EchoHandler;
22//!
23//! #[async_trait]
24//! impl WsHandler for EchoHandler {
25//!     async fn on_connect(&self, ctx: &mut WsContext) -> HandlerResult<()> {
26//!         ctx.send_text("Welcome to the echo server!").await?;
27//!         Ok(())
28//!     }
29//!
30//!     async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()> {
31//!         if let WsMessage::Text(text) = msg {
32//!             ctx.send_text(&format!("echo: {}", text)).await?;
33//!         }
34//!         Ok(())
35//!     }
36//! }
37//! ```
38//!
39//! ## Message Pattern Matching
40//!
41//! ```rust,no_run
42//! use mockforge_ws::handlers::{WsHandler, WsContext, WsMessage, HandlerResult, MessagePattern};
43//! use async_trait::async_trait;
44//!
45//! struct ChatHandler;
46//!
47//! #[async_trait]
48//! impl WsHandler for ChatHandler {
49//!     async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()> {
50//!         if let WsMessage::Text(text) = msg {
51//!             // Use pattern matching to route messages
52//!             if let Ok(pattern) = MessagePattern::regex(r"^/join (.+)$") {
53//!                 if pattern.matches(&text) {
54//!                     // Extract room name and join
55//!                     if let Some(room) = text.strip_prefix("/join ") {
56//!                         ctx.join_room(room).await?;
57//!                         ctx.send_text(&format!("Joined room: {}", room)).await?;
58//!                     }
59//!                 }
60//!             }
61//!             // Handle JSON chat messages
62//!             let jsonpath_pattern = MessagePattern::jsonpath("$.type");
63//!             if jsonpath_pattern.matches(&text) {
64//!                 ctx.broadcast_to_room("general", &text).await?;
65//!             }
66//!         }
67//!         Ok(())
68//!     }
69//! }
70//! ```
71
72use async_trait::async_trait;
73use axum::extract::ws::Message;
74use futures_util::{SinkExt, StreamExt};
75use regex::Regex;
76use serde_json::Value;
77use std::collections::{HashMap, HashSet};
78use std::sync::Arc;
79use tokio::sync::{broadcast, Mutex, RwLock};
80
81/// Result type for handler operations
82pub type HandlerResult<T> = Result<T, HandlerError>;
83
84/// Error type for handler operations
85#[derive(Debug, thiserror::Error)]
86pub enum HandlerError {
87    /// Failed to send WebSocket message
88    #[error("Failed to send message: {0}")]
89    SendError(String),
90
91    /// JSON parsing/serialization error
92    #[error("Failed to parse JSON: {0}")]
93    JsonError(#[from] serde_json::Error),
94
95    /// Pattern matching failure (e.g., route pattern)
96    #[error("Pattern matching error: {0}")]
97    PatternError(String),
98
99    /// Room/group operation failure
100    #[error("Room operation failed: {0}")]
101    RoomError(String),
102
103    /// WebSocket connection error
104    #[error("Connection error: {0}")]
105    ConnectionError(String),
106
107    /// Generic handler error
108    #[error("Handler error: {0}")]
109    Generic(String),
110}
111
/// Transport-agnostic WebSocket message handed to and produced by handlers.
///
/// Derives `PartialEq`/`Eq` so handlers and tests can compare messages directly
/// (e.g. `assert_eq!(msg, WsMessage::Text(..))`) instead of pattern-matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WsMessage {
    /// Text message (UTF-8 string)
    Text(String),
    /// Binary message (raw bytes)
    Binary(Vec<u8>),
    /// Ping frame with its payload
    Ping(Vec<u8>),
    /// Pong frame with its payload
    Pong(Vec<u8>),
    /// Close frame; any close code/reason is not carried
    Close,
}
126
127impl From<Message> for WsMessage {
128    fn from(msg: Message) -> Self {
129        match msg {
130            Message::Text(text) => WsMessage::Text(text.to_string()),
131            Message::Binary(data) => WsMessage::Binary(data.to_vec()),
132            Message::Ping(data) => WsMessage::Ping(data.to_vec()),
133            Message::Pong(data) => WsMessage::Pong(data.to_vec()),
134            Message::Close(_) => WsMessage::Close,
135        }
136    }
137}
138
139impl From<WsMessage> for Message {
140    fn from(msg: WsMessage) -> Self {
141        match msg {
142            WsMessage::Text(text) => Message::Text(text.into()),
143            WsMessage::Binary(data) => Message::Binary(data.into()),
144            WsMessage::Ping(data) => Message::Ping(data.into()),
145            WsMessage::Pong(data) => Message::Pong(data.into()),
146            WsMessage::Close => Message::Close(None),
147        }
148    }
149}
150
151/// Pattern for matching WebSocket messages
152#[derive(Debug, Clone)]
153pub enum MessagePattern {
154    /// Match using regular expression
155    Regex(Regex),
156    /// Match using JSONPath query
157    JsonPath(String),
158    /// Match exact text
159    Exact(String),
160    /// Always matches
161    Any,
162}
163
164impl MessagePattern {
165    /// Create a regex pattern
166    pub fn regex(pattern: &str) -> HandlerResult<Self> {
167        Ok(MessagePattern::Regex(
168            Regex::new(pattern).map_err(|e| HandlerError::PatternError(e.to_string()))?,
169        ))
170    }
171
172    /// Create a JSONPath pattern
173    pub fn jsonpath(query: &str) -> Self {
174        MessagePattern::JsonPath(query.to_string())
175    }
176
177    /// Create an exact match pattern
178    pub fn exact(text: &str) -> Self {
179        MessagePattern::Exact(text.to_string())
180    }
181
182    /// Create a pattern that matches everything
183    pub fn any() -> Self {
184        MessagePattern::Any
185    }
186
187    /// Check if the pattern matches the message
188    pub fn matches(&self, text: &str) -> bool {
189        match self {
190            MessagePattern::Regex(re) => re.is_match(text),
191            MessagePattern::JsonPath(query) => {
192                // Try to parse as JSON and check if path exists
193                if let Ok(json) = serde_json::from_str::<Value>(text) {
194                    // Use jsonpath crate's Selector
195                    if let Ok(selector) = jsonpath::Selector::new(query) {
196                        let results: Vec<_> = selector.find(&json).collect();
197                        !results.is_empty()
198                    } else {
199                        false
200                    }
201                } else {
202                    false
203                }
204            }
205            MessagePattern::Exact(expected) => text == expected,
206            MessagePattern::Any => true,
207        }
208    }
209
210    /// Check if pattern matches and extract value using JSONPath
211    pub fn extract(&self, text: &str, query: &str) -> Option<Value> {
212        if let Ok(json) = serde_json::from_str::<Value>(text) {
213            if let Ok(selector) = jsonpath::Selector::new(query) {
214                let results: Vec<_> = selector.find(&json).collect();
215                results.first().cloned().cloned()
216            } else {
217                None
218            }
219        } else {
220            None
221        }
222    }
223}
224
/// Identifier assigned to a single WebSocket connection (an owned string).
pub type ConnectionId = String;
227
228/// Room manager for broadcasting messages to groups of connections
229#[derive(Clone)]
230pub struct RoomManager {
231    rooms: Arc<RwLock<HashMap<String, HashSet<ConnectionId>>>>,
232    connections: Arc<RwLock<HashMap<ConnectionId, HashSet<String>>>>,
233    broadcasters: Arc<RwLock<HashMap<String, broadcast::Sender<String>>>>,
234}
235
236impl RoomManager {
237    /// Create a new room manager
238    pub fn new() -> Self {
239        Self {
240            rooms: Arc::new(RwLock::new(HashMap::new())),
241            connections: Arc::new(RwLock::new(HashMap::new())),
242            broadcasters: Arc::new(RwLock::new(HashMap::new())),
243        }
244    }
245
246    /// Join a room
247    pub async fn join(&self, conn_id: &str, room: &str) -> HandlerResult<()> {
248        let mut rooms = self.rooms.write().await;
249        let mut connections = self.connections.write().await;
250
251        rooms
252            .entry(room.to_string())
253            .or_insert_with(HashSet::new)
254            .insert(conn_id.to_string());
255
256        connections
257            .entry(conn_id.to_string())
258            .or_insert_with(HashSet::new)
259            .insert(room.to_string());
260
261        Ok(())
262    }
263
264    /// Leave a room
265    pub async fn leave(&self, conn_id: &str, room: &str) -> HandlerResult<()> {
266        let mut rooms = self.rooms.write().await;
267        let mut connections = self.connections.write().await;
268
269        if let Some(room_members) = rooms.get_mut(room) {
270            room_members.remove(conn_id);
271            if room_members.is_empty() {
272                rooms.remove(room);
273            }
274        }
275
276        if let Some(conn_rooms) = connections.get_mut(conn_id) {
277            conn_rooms.remove(room);
278            if conn_rooms.is_empty() {
279                connections.remove(conn_id);
280            }
281        }
282
283        Ok(())
284    }
285
286    /// Leave all rooms for a connection
287    pub async fn leave_all(&self, conn_id: &str) -> HandlerResult<()> {
288        let mut connections = self.connections.write().await;
289        if let Some(conn_rooms) = connections.remove(conn_id) {
290            let mut rooms = self.rooms.write().await;
291            for room in conn_rooms {
292                if let Some(room_members) = rooms.get_mut(&room) {
293                    room_members.remove(conn_id);
294                    if room_members.is_empty() {
295                        rooms.remove(&room);
296                    }
297                }
298            }
299        }
300        Ok(())
301    }
302
303    /// Get all connections in a room
304    pub async fn get_room_members(&self, room: &str) -> Vec<ConnectionId> {
305        let rooms = self.rooms.read().await;
306        rooms
307            .get(room)
308            .map(|members| members.iter().cloned().collect())
309            .unwrap_or_default()
310    }
311
312    /// Get all rooms for a connection
313    pub async fn get_connection_rooms(&self, conn_id: &str) -> Vec<String> {
314        let connections = self.connections.read().await;
315        connections
316            .get(conn_id)
317            .map(|rooms| rooms.iter().cloned().collect())
318            .unwrap_or_default()
319    }
320
321    /// Get broadcast sender for a room (creates if doesn't exist)
322    pub async fn get_broadcaster(&self, room: &str) -> broadcast::Sender<String> {
323        let mut broadcasters = self.broadcasters.write().await;
324        broadcasters
325            .entry(room.to_string())
326            .or_insert_with(|| {
327                let (tx, _) = broadcast::channel(1024);
328                tx
329            })
330            .clone()
331    }
332}
333
334impl Default for RoomManager {
335    fn default() -> Self {
336        Self::new()
337    }
338}
339
340/// Context provided to handlers for each connection
341pub struct WsContext {
342    /// Unique connection ID
343    pub connection_id: ConnectionId,
344    /// WebSocket path
345    pub path: String,
346    /// Room manager for broadcasting
347    room_manager: RoomManager,
348    /// Sender for outgoing messages
349    message_tx: tokio::sync::mpsc::UnboundedSender<Message>,
350    /// Metadata storage
351    metadata: Arc<RwLock<HashMap<String, Value>>>,
352}
353
354impl WsContext {
355    /// Create a new WebSocket context
356    pub fn new(
357        connection_id: ConnectionId,
358        path: String,
359        room_manager: RoomManager,
360        message_tx: tokio::sync::mpsc::UnboundedSender<Message>,
361    ) -> Self {
362        Self {
363            connection_id,
364            path,
365            room_manager,
366            message_tx,
367            metadata: Arc::new(RwLock::new(HashMap::new())),
368        }
369    }
370
371    /// Send a text message
372    pub async fn send_text(&self, text: &str) -> HandlerResult<()> {
373        self.message_tx
374            .send(Message::Text(text.to_string().into()))
375            .map_err(|e| HandlerError::SendError(e.to_string()))
376    }
377
378    /// Send a binary message
379    pub async fn send_binary(&self, data: Vec<u8>) -> HandlerResult<()> {
380        self.message_tx
381            .send(Message::Binary(data.into()))
382            .map_err(|e| HandlerError::SendError(e.to_string()))
383    }
384
385    /// Send a JSON message
386    pub async fn send_json(&self, value: &Value) -> HandlerResult<()> {
387        let text = serde_json::to_string(value)?;
388        self.send_text(&text).await
389    }
390
391    /// Join a room
392    pub async fn join_room(&self, room: &str) -> HandlerResult<()> {
393        self.room_manager.join(&self.connection_id, room).await
394    }
395
396    /// Leave a room
397    pub async fn leave_room(&self, room: &str) -> HandlerResult<()> {
398        self.room_manager.leave(&self.connection_id, room).await
399    }
400
401    /// Broadcast text to all members in a room
402    pub async fn broadcast_to_room(&self, room: &str, text: &str) -> HandlerResult<()> {
403        let broadcaster = self.room_manager.get_broadcaster(room).await;
404        broadcaster
405            .send(text.to_string())
406            .map_err(|e| HandlerError::RoomError(e.to_string()))?;
407        Ok(())
408    }
409
410    /// Get all rooms this connection is in
411    pub async fn get_rooms(&self) -> Vec<String> {
412        self.room_manager.get_connection_rooms(&self.connection_id).await
413    }
414
415    /// Set metadata value
416    pub async fn set_metadata(&self, key: &str, value: Value) {
417        let mut metadata = self.metadata.write().await;
418        metadata.insert(key.to_string(), value);
419    }
420
421    /// Get metadata value
422    pub async fn get_metadata(&self, key: &str) -> Option<Value> {
423        let metadata = self.metadata.read().await;
424        metadata.get(key).cloned()
425    }
426}
427
428/// Trait for WebSocket message handlers
429#[async_trait]
430pub trait WsHandler: Send + Sync {
431    /// Called when a new WebSocket connection is established
432    async fn on_connect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
433        Ok(())
434    }
435
436    /// Called when a message is received
437    async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()>;
438
439    /// Called when the connection is closed
440    async fn on_disconnect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
441        Ok(())
442    }
443
444    /// Check if this handler should handle the given path
445    fn handles_path(&self, _path: &str) -> bool {
446        true // Default: handle all paths
447    }
448}
449
450/// Pattern-based message router
451pub struct MessageRouter {
452    routes: Vec<(MessagePattern, Box<dyn Fn(String) -> Option<String> + Send + Sync>)>,
453}
454
455impl MessageRouter {
456    /// Create a new message router
457    pub fn new() -> Self {
458        Self { routes: Vec::new() }
459    }
460
461    /// Add a route with a pattern and handler function
462    pub fn on<F>(&mut self, pattern: MessagePattern, handler: F) -> &mut Self
463    where
464        F: Fn(String) -> Option<String> + Send + Sync + 'static,
465    {
466        self.routes.push((pattern, Box::new(handler)));
467        self
468    }
469
470    /// Route a message through the registered handlers
471    pub fn route(&self, text: &str) -> Option<String> {
472        for (pattern, handler) in &self.routes {
473            if pattern.matches(text) {
474                if let Some(response) = handler(text.to_string()) {
475                    return Some(response);
476                }
477            }
478        }
479        None
480    }
481}
482
483impl Default for MessageRouter {
484    fn default() -> Self {
485        Self::new()
486    }
487}
488
489/// Handler registry for managing multiple handlers
490pub struct HandlerRegistry {
491    handlers: Vec<Arc<dyn WsHandler>>,
492    hot_reload_enabled: bool,
493}
494
495impl HandlerRegistry {
496    /// Create a new handler registry
497    pub fn new() -> Self {
498        Self {
499            handlers: Vec::new(),
500            hot_reload_enabled: std::env::var("MOCKFORGE_WS_HOTRELOAD")
501                .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
502                .unwrap_or(false),
503        }
504    }
505
506    /// Create a registry with hot-reload enabled
507    pub fn with_hot_reload() -> Self {
508        Self {
509            handlers: Vec::new(),
510            hot_reload_enabled: true,
511        }
512    }
513
514    /// Check if hot-reload is enabled
515    pub fn is_hot_reload_enabled(&self) -> bool {
516        self.hot_reload_enabled
517    }
518
519    /// Register a handler
520    pub fn register<H: WsHandler + 'static>(&mut self, handler: H) -> &mut Self {
521        self.handlers.push(Arc::new(handler));
522        self
523    }
524
525    /// Get handlers for a specific path
526    pub fn get_handlers(&self, path: &str) -> Vec<Arc<dyn WsHandler>> {
527        self.handlers.iter().filter(|h| h.handles_path(path)).cloned().collect()
528    }
529
530    /// Check if any handler handles the given path
531    pub fn has_handler_for(&self, path: &str) -> bool {
532        self.handlers.iter().any(|h| h.handles_path(path))
533    }
534
535    /// Clear all handlers (useful for hot-reload)
536    pub fn clear(&mut self) {
537        self.handlers.clear();
538    }
539
540    /// Get the number of registered handlers
541    pub fn len(&self) -> usize {
542        self.handlers.len()
543    }
544
545    /// Check if the registry is empty
546    pub fn is_empty(&self) -> bool {
547        self.handlers.is_empty()
548    }
549}
550
551impl Default for HandlerRegistry {
552    fn default() -> Self {
553        Self::new()
554    }
555}
556
557/// Passthrough handler configuration for forwarding messages to upstream servers
558#[derive(Clone)]
559pub struct PassthroughConfig {
560    /// Pattern to match paths for passthrough
561    pub pattern: MessagePattern,
562    /// Upstream URL to forward to
563    pub upstream_url: String,
564}
565
566impl PassthroughConfig {
567    /// Create a new passthrough configuration
568    pub fn new(pattern: MessagePattern, upstream_url: String) -> Self {
569        Self {
570            pattern,
571            upstream_url,
572        }
573    }
574
575    /// Create a passthrough for all messages matching a regex
576    pub fn regex(regex: &str, upstream_url: String) -> HandlerResult<Self> {
577        Ok(Self {
578            pattern: MessagePattern::regex(regex)?,
579            upstream_url,
580        })
581    }
582}
583
584/// Passthrough handler that forwards messages to an upstream WebSocket server
585pub struct PassthroughHandler {
586    config: PassthroughConfig,
587    /// Upstream WebSocket write half, established lazily on first message
588    upstream_tx: Mutex<Option<UpstreamSender>>,
589}
590
591type UpstreamSender = futures_util::stream::SplitSink<
592    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
593    tokio_tungstenite::tungstenite::Message,
594>;
595
596impl PassthroughHandler {
597    /// Create a new passthrough handler
598    pub fn new(config: PassthroughConfig) -> Self {
599        Self {
600            config,
601            upstream_tx: Mutex::new(None),
602        }
603    }
604
605    /// Check if a message should be passed through
606    pub fn should_passthrough(&self, text: &str) -> bool {
607        self.config.pattern.matches(text)
608    }
609
610    /// Get the upstream URL
611    pub fn upstream_url(&self) -> &str {
612        &self.config.upstream_url
613    }
614
615    /// Connect to the upstream WebSocket server and spawn a reader task that
616    /// relays messages back to the client.
617    async fn ensure_connected(
618        &self,
619        client_tx: &tokio::sync::mpsc::UnboundedSender<Message>,
620    ) -> HandlerResult<()> {
621        let mut guard = self.upstream_tx.lock().await;
622        if guard.is_some() {
623            return Ok(());
624        }
625
626        let url = &self.config.upstream_url;
627        tracing::info!(upstream = %url, "Connecting to upstream WebSocket server");
628
629        let (ws_stream, _response) = tokio_tungstenite::connect_async(url)
630            .await
631            .map_err(|e| HandlerError::ConnectionError(format!("Upstream connect failed: {e}")))?;
632
633        let (write, mut read) = ws_stream.split();
634        *guard = Some(write);
635
636        // Spawn a task that forwards upstream → client
637        let client_tx = client_tx.clone();
638        tokio::spawn(async move {
639            while let Some(Ok(msg)) = read.next().await {
640                let axum_msg = match msg {
641                    tokio_tungstenite::tungstenite::Message::Text(t) => {
642                        Message::Text(t.to_string().into())
643                    }
644                    tokio_tungstenite::tungstenite::Message::Binary(b) => {
645                        Message::Binary(b.to_vec().into())
646                    }
647                    tokio_tungstenite::tungstenite::Message::Ping(p) => {
648                        Message::Ping(p.to_vec().into())
649                    }
650                    tokio_tungstenite::tungstenite::Message::Pong(p) => {
651                        Message::Pong(p.to_vec().into())
652                    }
653                    tokio_tungstenite::tungstenite::Message::Close(_) => {
654                        break;
655                    }
656                    tokio_tungstenite::tungstenite::Message::Frame(_) => continue,
657                };
658                if client_tx.send(axum_msg).is_err() {
659                    break;
660                }
661            }
662            tracing::debug!("Upstream reader task finished");
663        });
664
665        Ok(())
666    }
667}
668
669#[async_trait]
670impl WsHandler for PassthroughHandler {
671    async fn on_connect(&self, ctx: &mut WsContext) -> HandlerResult<()> {
672        self.ensure_connected(&ctx.message_tx).await
673    }
674
675    async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()> {
676        match &msg {
677            WsMessage::Text(text) if self.should_passthrough(text) => {
678                self.ensure_connected(&ctx.message_tx).await?;
679                let mut guard = self.upstream_tx.lock().await;
680                if let Some(ref mut writer) = *guard {
681                    writer
682                        .send(tokio_tungstenite::tungstenite::Message::Text(text.clone().into()))
683                        .await
684                        .map_err(|e| {
685                            HandlerError::SendError(format!("Upstream send failed: {e}"))
686                        })?;
687                }
688            }
689            WsMessage::Binary(data) => {
690                self.ensure_connected(&ctx.message_tx).await?;
691                let mut guard = self.upstream_tx.lock().await;
692                if let Some(ref mut writer) = *guard {
693                    writer
694                        .send(tokio_tungstenite::tungstenite::Message::Binary(data.clone().into()))
695                        .await
696                        .map_err(|e| {
697                            HandlerError::SendError(format!("Upstream send failed: {e}"))
698                        })?;
699                }
700            }
701            _ => {}
702        }
703        Ok(())
704    }
705
706    async fn on_disconnect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
707        let mut guard = self.upstream_tx.lock().await;
708        if let Some(mut writer) = guard.take() {
709            let _ = writer.send(tokio_tungstenite::tungstenite::Message::Close(None)).await;
710        }
711        Ok(())
712    }
713}
714
715#[cfg(test)]
716mod tests {
717    use super::*;
718
719    // ==================== WsMessage Tests ====================
720
721    #[test]
722    fn test_ws_message_text_from_axum() {
723        let axum_msg = Message::Text("hello".to_string().into());
724        let ws_msg: WsMessage = axum_msg.into();
725        match ws_msg {
726            WsMessage::Text(text) => assert_eq!(text, "hello"),
727            _ => panic!("Expected Text message"),
728        }
729    }
730
731    #[test]
732    fn test_ws_message_binary_from_axum() {
733        let data = vec![1, 2, 3, 4];
734        let axum_msg = Message::Binary(data.clone().into());
735        let ws_msg: WsMessage = axum_msg.into();
736        match ws_msg {
737            WsMessage::Binary(bytes) => assert_eq!(bytes, data),
738            _ => panic!("Expected Binary message"),
739        }
740    }
741
742    #[test]
743    fn test_ws_message_ping_from_axum() {
744        let data = vec![1, 2];
745        let axum_msg = Message::Ping(data.clone().into());
746        let ws_msg: WsMessage = axum_msg.into();
747        match ws_msg {
748            WsMessage::Ping(bytes) => assert_eq!(bytes, data),
749            _ => panic!("Expected Ping message"),
750        }
751    }
752
753    #[test]
754    fn test_ws_message_pong_from_axum() {
755        let data = vec![3, 4];
756        let axum_msg = Message::Pong(data.clone().into());
757        let ws_msg: WsMessage = axum_msg.into();
758        match ws_msg {
759            WsMessage::Pong(bytes) => assert_eq!(bytes, data),
760            _ => panic!("Expected Pong message"),
761        }
762    }
763
764    #[test]
765    fn test_ws_message_close_from_axum() {
766        let axum_msg = Message::Close(None);
767        let ws_msg: WsMessage = axum_msg.into();
768        assert!(matches!(ws_msg, WsMessage::Close));
769    }
770
771    #[test]
772    fn test_ws_message_text_to_axum() {
773        let ws_msg = WsMessage::Text("hello".to_string());
774        let axum_msg: Message = ws_msg.into();
775        assert!(matches!(axum_msg, Message::Text(_)));
776    }
777
778    #[test]
779    fn test_ws_message_binary_to_axum() {
780        let ws_msg = WsMessage::Binary(vec![1, 2, 3]);
781        let axum_msg: Message = ws_msg.into();
782        assert!(matches!(axum_msg, Message::Binary(_)));
783    }
784
785    #[test]
786    fn test_ws_message_close_to_axum() {
787        let ws_msg = WsMessage::Close;
788        let axum_msg: Message = ws_msg.into();
789        assert!(matches!(axum_msg, Message::Close(_)));
790    }
791
792    // ==================== MessagePattern Tests ====================
793
794    #[test]
795    fn test_message_pattern_regex() {
796        let pattern = MessagePattern::regex(r"^hello").unwrap();
797        assert!(pattern.matches("hello world"));
798        assert!(!pattern.matches("goodbye world"));
799    }
800
801    #[test]
802    fn test_message_pattern_regex_invalid() {
803        let result = MessagePattern::regex(r"[invalid");
804        assert!(result.is_err());
805    }
806
807    #[test]
808    fn test_message_pattern_exact() {
809        let pattern = MessagePattern::exact("hello");
810        assert!(pattern.matches("hello"));
811        assert!(!pattern.matches("hello world"));
812    }
813
814    #[test]
815    fn test_message_pattern_jsonpath() {
816        let pattern = MessagePattern::jsonpath("$.type");
817        assert!(pattern.matches(r#"{"type": "message"}"#));
818        assert!(!pattern.matches(r#"{"name": "test"}"#));
819    }
820
821    #[test]
822    fn test_message_pattern_jsonpath_nested() {
823        let pattern = MessagePattern::jsonpath("$.user.name");
824        assert!(pattern.matches(r#"{"user": {"name": "John"}}"#));
825        assert!(!pattern.matches(r#"{"user": {"email": "john@example.com"}}"#));
826    }
827
828    #[test]
829    fn test_message_pattern_jsonpath_invalid_json() {
830        let pattern = MessagePattern::jsonpath("$.type");
831        assert!(!pattern.matches("not json"));
832    }
833
834    #[test]
835    fn test_message_pattern_any() {
836        let pattern = MessagePattern::any();
837        assert!(pattern.matches("anything"));
838        assert!(pattern.matches(""));
839        assert!(pattern.matches(r#"{"json": true}"#));
840    }
841
842    #[test]
843    fn test_message_pattern_extract() {
844        let pattern = MessagePattern::jsonpath("$.type");
845        let result = pattern.extract(r#"{"type": "greeting", "data": "hello"}"#, "$.type");
846        assert_eq!(result, Some(serde_json::json!("greeting")));
847    }
848
849    #[test]
850    fn test_message_pattern_extract_nested() {
851        let pattern = MessagePattern::any();
852        let result = pattern.extract(r#"{"user": {"id": 123}}"#, "$.user.id");
853        assert_eq!(result, Some(serde_json::json!(123)));
854    }
855
856    #[test]
857    fn test_message_pattern_extract_not_found() {
858        let pattern = MessagePattern::any();
859        let result = pattern.extract(r#"{"type": "message"}"#, "$.nonexistent");
860        assert!(result.is_none());
861    }
862
863    #[test]
864    fn test_message_pattern_extract_invalid_json() {
865        let pattern = MessagePattern::any();
866        let result = pattern.extract("not json", "$.type");
867        assert!(result.is_none());
868    }
869
870    // ==================== RoomManager Tests ====================
871
872    #[tokio::test]
873    async fn test_room_manager() {
874        let manager = RoomManager::new();
875
876        // Join rooms
877        manager.join("conn1", "room1").await.unwrap();
878        manager.join("conn1", "room2").await.unwrap();
879        manager.join("conn2", "room1").await.unwrap();
880
881        // Check room members
882        let room1_members = manager.get_room_members("room1").await;
883        assert_eq!(room1_members.len(), 2);
884        assert!(room1_members.contains(&"conn1".to_string()));
885        assert!(room1_members.contains(&"conn2".to_string()));
886
887        // Check connection rooms
888        let conn1_rooms = manager.get_connection_rooms("conn1").await;
889        assert_eq!(conn1_rooms.len(), 2);
890        assert!(conn1_rooms.contains(&"room1".to_string()));
891        assert!(conn1_rooms.contains(&"room2".to_string()));
892
893        // Leave room
894        manager.leave("conn1", "room1").await.unwrap();
895        let room1_members = manager.get_room_members("room1").await;
896        assert_eq!(room1_members.len(), 1);
897        assert!(room1_members.contains(&"conn2".to_string()));
898
899        // Leave all rooms
900        manager.leave_all("conn1").await.unwrap();
901        let conn1_rooms = manager.get_connection_rooms("conn1").await;
902        assert_eq!(conn1_rooms.len(), 0);
903    }
904
905    #[tokio::test]
906    async fn test_room_manager_default() {
907        let manager = RoomManager::default();
908        // Should work the same as new()
909        manager.join("conn1", "room1").await.unwrap();
910        let members = manager.get_room_members("room1").await;
911        assert_eq!(members.len(), 1);
912    }
913
914    #[tokio::test]
915    async fn test_room_manager_empty_room() {
916        let manager = RoomManager::new();
917        let members = manager.get_room_members("nonexistent").await;
918        assert!(members.is_empty());
919    }
920
921    #[tokio::test]
922    async fn test_room_manager_empty_connection() {
923        let manager = RoomManager::new();
924        let rooms = manager.get_connection_rooms("nonexistent").await;
925        assert!(rooms.is_empty());
926    }
927
928    #[tokio::test]
929    async fn test_room_manager_leave_nonexistent() {
930        let manager = RoomManager::new();
931        // Should not error when leaving a room we're not in
932        let result = manager.leave("conn1", "room1").await;
933        assert!(result.is_ok());
934    }
935
936    #[tokio::test]
937    async fn test_room_manager_broadcaster() {
938        let manager = RoomManager::new();
939        manager.join("conn1", "room1").await.unwrap();
940
941        let broadcaster = manager.get_broadcaster("room1").await;
942        let mut receiver = broadcaster.subscribe();
943
944        // Send a message
945        broadcaster.send("hello".to_string()).unwrap();
946
947        // Receive it
948        let msg = receiver.recv().await.unwrap();
949        assert_eq!(msg, "hello");
950    }
951
952    #[tokio::test]
953    async fn test_room_manager_room_cleanup_on_last_leave() {
954        let manager = RoomManager::new();
955        manager.join("conn1", "room1").await.unwrap();
956        manager.leave("conn1", "room1").await.unwrap();
957
958        // Room should be cleaned up
959        let members = manager.get_room_members("room1").await;
960        assert!(members.is_empty());
961    }
962
963    // ==================== MessageRouter Tests ====================
964
965    #[test]
966    fn test_message_router() {
967        let mut router = MessageRouter::new();
968
969        router
970            .on(MessagePattern::exact("ping"), |_| Some("pong".to_string()))
971            .on(MessagePattern::regex(r"^hello").unwrap(), |_| Some("hi there!".to_string()));
972
973        assert_eq!(router.route("ping"), Some("pong".to_string()));
974        assert_eq!(router.route("hello world"), Some("hi there!".to_string()));
975        assert_eq!(router.route("goodbye"), None);
976    }
977
978    #[test]
979    fn test_message_router_default() {
980        let router = MessageRouter::default();
981        // Empty router returns None for all messages
982        assert_eq!(router.route("anything"), None);
983    }
984
985    #[test]
986    fn test_message_router_first_match_wins() {
987        let mut router = MessageRouter::new();
988        router
989            .on(MessagePattern::any(), |_| Some("first".to_string()))
990            .on(MessagePattern::any(), |_| Some("second".to_string()));
991
992        assert_eq!(router.route("test"), Some("first".to_string()));
993    }
994
995    #[test]
996    fn test_message_router_handler_returns_none() {
997        let mut router = MessageRouter::new();
998        router
999            .on(MessagePattern::exact("skip"), |_| None)
1000            .on(MessagePattern::any(), |_| Some("fallback".to_string()));
1001
1002        // First pattern matches but returns None, so it continues to next
1003        assert_eq!(router.route("skip"), Some("fallback".to_string()));
1004    }
1005
1006    // ==================== HandlerRegistry Tests ====================
1007
1008    struct TestHandler;
1009
1010    #[async_trait]
1011    impl WsHandler for TestHandler {
1012        async fn on_message(&self, _ctx: &mut WsContext, _msg: WsMessage) -> HandlerResult<()> {
1013            Ok(())
1014        }
1015    }
1016
1017    struct PathSpecificHandler {
1018        path: String,
1019    }
1020
1021    #[async_trait]
1022    impl WsHandler for PathSpecificHandler {
1023        async fn on_message(&self, _ctx: &mut WsContext, _msg: WsMessage) -> HandlerResult<()> {
1024            Ok(())
1025        }
1026
1027        fn handles_path(&self, path: &str) -> bool {
1028            path == self.path
1029        }
1030    }
1031
1032    #[test]
1033    fn test_handler_registry_new() {
1034        let registry = HandlerRegistry::new();
1035        assert!(registry.is_empty());
1036        assert_eq!(registry.len(), 0);
1037    }
1038
1039    #[test]
1040    fn test_handler_registry_default() {
1041        let registry = HandlerRegistry::default();
1042        assert!(registry.is_empty());
1043    }
1044
1045    #[test]
1046    fn test_handler_registry_register() {
1047        let mut registry = HandlerRegistry::new();
1048        registry.register(TestHandler);
1049        assert!(!registry.is_empty());
1050        assert_eq!(registry.len(), 1);
1051    }
1052
1053    #[test]
1054    fn test_handler_registry_get_handlers() {
1055        let mut registry = HandlerRegistry::new();
1056        registry.register(TestHandler);
1057
1058        let handlers = registry.get_handlers("/any/path");
1059        assert_eq!(handlers.len(), 1);
1060    }
1061
1062    #[test]
1063    fn test_handler_registry_path_filtering() {
1064        let mut registry = HandlerRegistry::new();
1065        registry.register(PathSpecificHandler {
1066            path: "/ws/chat".to_string(),
1067        });
1068        registry.register(PathSpecificHandler {
1069            path: "/ws/events".to_string(),
1070        });
1071
1072        let chat_handlers = registry.get_handlers("/ws/chat");
1073        assert_eq!(chat_handlers.len(), 1);
1074
1075        let events_handlers = registry.get_handlers("/ws/events");
1076        assert_eq!(events_handlers.len(), 1);
1077
1078        let other_handlers = registry.get_handlers("/ws/other");
1079        assert!(other_handlers.is_empty());
1080    }
1081
1082    #[test]
1083    fn test_handler_registry_has_handler_for() {
1084        let mut registry = HandlerRegistry::new();
1085        registry.register(PathSpecificHandler {
1086            path: "/ws/chat".to_string(),
1087        });
1088
1089        assert!(registry.has_handler_for("/ws/chat"));
1090        assert!(!registry.has_handler_for("/ws/other"));
1091    }
1092
1093    #[test]
1094    fn test_handler_registry_clear() {
1095        let mut registry = HandlerRegistry::new();
1096        registry.register(TestHandler);
1097        registry.register(TestHandler);
1098        assert_eq!(registry.len(), 2);
1099
1100        registry.clear();
1101        assert!(registry.is_empty());
1102    }
1103
1104    #[test]
1105    fn test_handler_registry_with_hot_reload() {
1106        let registry = HandlerRegistry::with_hot_reload();
1107        assert!(registry.is_hot_reload_enabled());
1108    }
1109
1110    // ==================== PassthroughConfig Tests ====================
1111
1112    #[test]
1113    fn test_passthrough_config_new() {
1114        let config =
1115            PassthroughConfig::new(MessagePattern::any(), "ws://upstream:8080".to_string());
1116        assert_eq!(config.upstream_url, "ws://upstream:8080");
1117    }
1118
1119    #[test]
1120    fn test_passthrough_config_regex() {
1121        let config =
1122            PassthroughConfig::regex(r"^forward", "ws://upstream:8080".to_string()).unwrap();
1123        assert!(config.pattern.matches("forward this"));
1124        assert!(!config.pattern.matches("don't forward"));
1125    }
1126
1127    #[test]
1128    fn test_passthrough_config_regex_invalid() {
1129        let result = PassthroughConfig::regex(r"[invalid", "ws://upstream:8080".to_string());
1130        assert!(result.is_err());
1131    }
1132
1133    // ==================== PassthroughHandler Tests ====================
1134
1135    #[test]
1136    fn test_passthrough_handler_should_passthrough() {
1137        let config =
1138            PassthroughConfig::regex(r"^proxy:", "ws://upstream:8080".to_string()).unwrap();
1139        let handler = PassthroughHandler::new(config);
1140
1141        assert!(handler.should_passthrough("proxy:hello"));
1142        assert!(!handler.should_passthrough("hello"));
1143    }
1144
1145    #[test]
1146    fn test_passthrough_handler_upstream_url() {
1147        let config =
1148            PassthroughConfig::new(MessagePattern::any(), "ws://upstream:8080".to_string());
1149        let handler = PassthroughHandler::new(config);
1150
1151        assert_eq!(handler.upstream_url(), "ws://upstream:8080");
1152    }
1153
1154    // ==================== HandlerError Tests ====================
1155
1156    #[test]
1157    fn test_handler_error_send_error() {
1158        let err = HandlerError::SendError("connection closed".to_string());
1159        assert!(err.to_string().contains("send message"));
1160        assert!(err.to_string().contains("connection closed"));
1161    }
1162
1163    #[test]
1164    fn test_handler_error_json_error() {
1165        let json_err = serde_json::from_str::<serde_json::Value>("invalid").unwrap_err();
1166        let err = HandlerError::JsonError(json_err);
1167        assert!(err.to_string().contains("JSON"));
1168    }
1169
1170    #[test]
1171    fn test_handler_error_pattern_error() {
1172        let err = HandlerError::PatternError("invalid regex".to_string());
1173        assert!(err.to_string().contains("Pattern"));
1174    }
1175
1176    #[test]
1177    fn test_handler_error_room_error() {
1178        let err = HandlerError::RoomError("room full".to_string());
1179        assert!(err.to_string().contains("Room"));
1180    }
1181
1182    #[test]
1183    fn test_handler_error_connection_error() {
1184        let err = HandlerError::ConnectionError("timeout".to_string());
1185        assert!(err.to_string().contains("Connection"));
1186    }
1187
1188    #[test]
1189    fn test_handler_error_generic() {
1190        let err = HandlerError::Generic("something went wrong".to_string());
1191        assert!(err.to_string().contains("something went wrong"));
1192    }
1193
1194    // ==================== WsContext Tests ====================
1195
1196    #[tokio::test]
1197    async fn test_ws_context_metadata() {
1198        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
1199        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);
1200
1201        // Set and get metadata
1202        ctx.set_metadata("user", serde_json::json!({"id": 1})).await;
1203        let value = ctx.get_metadata("user").await;
1204        assert_eq!(value, Some(serde_json::json!({"id": 1})));
1205
1206        // Get nonexistent key
1207        let missing = ctx.get_metadata("nonexistent").await;
1208        assert!(missing.is_none());
1209    }
1210
1211    #[tokio::test]
1212    async fn test_ws_context_send_text() {
1213        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1214        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);
1215
1216        ctx.send_text("hello").await.unwrap();
1217
1218        let msg = rx.recv().await.unwrap();
1219        assert!(matches!(msg, Message::Text(_)));
1220    }
1221
1222    #[tokio::test]
1223    async fn test_ws_context_send_binary() {
1224        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1225        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);
1226
1227        ctx.send_binary(vec![1, 2, 3]).await.unwrap();
1228
1229        let msg = rx.recv().await.unwrap();
1230        assert!(matches!(msg, Message::Binary(_)));
1231    }
1232
1233    #[tokio::test]
1234    async fn test_ws_context_send_json() {
1235        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1236        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);
1237
1238        ctx.send_json(&serde_json::json!({"type": "test"})).await.unwrap();
1239
1240        let msg = rx.recv().await.unwrap();
1241        assert!(matches!(msg, Message::Text(_)));
1242    }
1243
1244    #[tokio::test]
1245    async fn test_ws_context_rooms() {
1246        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
1247        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);
1248
1249        // Join rooms
1250        ctx.join_room("chat").await.unwrap();
1251        ctx.join_room("notifications").await.unwrap();
1252
1253        let rooms = ctx.get_rooms().await;
1254        assert_eq!(rooms.len(), 2);
1255
1256        // Leave room
1257        ctx.leave_room("chat").await.unwrap();
1258        let rooms = ctx.get_rooms().await;
1259        assert_eq!(rooms.len(), 1);
1260    }
1261}