mockforge_ws/
handlers.rs

1//! # Programmable WebSocket Handlers
2//!
3//! This module provides a flexible handler API for scripting WebSocket event flows.
4//! Unlike static replay, handlers allow you to write custom logic for responding to
5//! WebSocket events, manage rooms, and route messages dynamically.
6//!
7//! ## Features
8//!
9//! - **Connection Lifecycle**: `on_connect` and `on_disconnect` hooks
10//! - **Pattern Matching**: Route messages with regex or JSONPath patterns
11//! - **Room Management**: Broadcast messages to groups of connections
12//! - **Passthrough**: Selectively forward messages to upstream servers
13//! - **Hot Reload**: Automatically reload handlers when code changes (via `MOCKFORGE_WS_HOTRELOAD`)
14//!
15//! ## Quick Start
16//!
17//! ```rust,no_run
18//! use mockforge_ws::handlers::{WsHandler, WsContext, WsMessage, HandlerResult};
19//! use async_trait::async_trait;
20//!
21//! struct EchoHandler;
22//!
23//! #[async_trait]
24//! impl WsHandler for EchoHandler {
25//!     async fn on_connect(&self, ctx: &mut WsContext) -> HandlerResult<()> {
26//!         ctx.send_text("Welcome to the echo server!").await?;
27//!         Ok(())
28//!     }
29//!
30//!     async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()> {
31//!         if let WsMessage::Text(text) = msg {
32//!             ctx.send_text(&format!("echo: {}", text)).await?;
33//!         }
34//!         Ok(())
35//!     }
36//! }
37//! ```
38//!
39//! ## Message Pattern Matching
40//!
41//! ```rust,no_run
42//! use mockforge_ws::handlers::{WsHandler, WsContext, WsMessage, HandlerResult, MessagePattern};
43//! use async_trait::async_trait;
44//!
45//! struct ChatHandler;
46//!
47//! #[async_trait]
48//! impl WsHandler for ChatHandler {
49//!     async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()> {
50//!         if let WsMessage::Text(text) = msg {
51//!             // Use pattern matching to route messages
52//!             if let Ok(pattern) = MessagePattern::regex(r"^/join (.+)$") {
53//!                 if pattern.matches(&text) {
54//!                     // Extract room name and join
55//!                     if let Some(room) = text.strip_prefix("/join ") {
56//!                         ctx.join_room(room).await?;
57//!                         ctx.send_text(&format!("Joined room: {}", room)).await?;
58//!                     }
59//!                 }
60//!             }
61//!             // Handle JSON chat messages
62//!             let jsonpath_pattern = MessagePattern::jsonpath("$.type");
63//!             if jsonpath_pattern.matches(&text) {
64//!                 ctx.broadcast_to_room("general", &text).await?;
65//!             }
66//!         }
67//!         Ok(())
68//!     }
69//! }
70//! ```
71
72use async_trait::async_trait;
73use axum::extract::ws::Message;
74use regex::Regex;
75use serde_json::Value;
76use std::collections::{HashMap, HashSet};
77use std::sync::Arc;
78use tokio::sync::{broadcast, RwLock};
79
80/// Result type for handler operations
81pub type HandlerResult<T> = Result<T, HandlerError>;
82
83/// Error type for handler operations
84#[derive(Debug, thiserror::Error)]
85pub enum HandlerError {
86    /// Failed to send WebSocket message
87    #[error("Failed to send message: {0}")]
88    SendError(String),
89
90    /// JSON parsing/serialization error
91    #[error("Failed to parse JSON: {0}")]
92    JsonError(#[from] serde_json::Error),
93
94    /// Pattern matching failure (e.g., route pattern)
95    #[error("Pattern matching error: {0}")]
96    PatternError(String),
97
98    /// Room/group operation failure
99    #[error("Room operation failed: {0}")]
100    RoomError(String),
101
102    /// WebSocket connection error
103    #[error("Connection error: {0}")]
104    ConnectionError(String),
105
106    /// Generic handler error
107    #[error("Handler error: {0}")]
108    Generic(String),
109}
110
/// Handler-facing representation of a WebSocket frame.
///
/// Payloads are owned (`String` / `Vec<u8>`), and the close variant carries
/// no data.
#[derive(Debug, Clone)]
pub enum WsMessage {
    /// A UTF-8 text frame.
    Text(String),
    /// A binary frame holding raw bytes.
    Binary(Vec<u8>),
    /// A ping frame together with its payload.
    Ping(Vec<u8>),
    /// A pong frame together with its payload.
    Pong(Vec<u8>),
    /// A close frame; no close code or reason is retained.
    Close,
}
125
126impl From<Message> for WsMessage {
127    fn from(msg: Message) -> Self {
128        match msg {
129            Message::Text(text) => WsMessage::Text(text.to_string()),
130            Message::Binary(data) => WsMessage::Binary(data.to_vec()),
131            Message::Ping(data) => WsMessage::Ping(data.to_vec()),
132            Message::Pong(data) => WsMessage::Pong(data.to_vec()),
133            Message::Close(_) => WsMessage::Close,
134        }
135    }
136}
137
138impl From<WsMessage> for Message {
139    fn from(msg: WsMessage) -> Self {
140        match msg {
141            WsMessage::Text(text) => Message::Text(text.into()),
142            WsMessage::Binary(data) => Message::Binary(data.into()),
143            WsMessage::Ping(data) => Message::Ping(data.into()),
144            WsMessage::Pong(data) => Message::Pong(data.into()),
145            WsMessage::Close => Message::Close(None),
146        }
147    }
148}
149
150/// Pattern for matching WebSocket messages
151#[derive(Debug, Clone)]
152pub enum MessagePattern {
153    /// Match using regular expression
154    Regex(Regex),
155    /// Match using JSONPath query
156    JsonPath(String),
157    /// Match exact text
158    Exact(String),
159    /// Always matches
160    Any,
161}
162
163impl MessagePattern {
164    /// Create a regex pattern
165    pub fn regex(pattern: &str) -> HandlerResult<Self> {
166        Ok(MessagePattern::Regex(
167            Regex::new(pattern).map_err(|e| HandlerError::PatternError(e.to_string()))?,
168        ))
169    }
170
171    /// Create a JSONPath pattern
172    pub fn jsonpath(query: &str) -> Self {
173        MessagePattern::JsonPath(query.to_string())
174    }
175
176    /// Create an exact match pattern
177    pub fn exact(text: &str) -> Self {
178        MessagePattern::Exact(text.to_string())
179    }
180
181    /// Create a pattern that matches everything
182    pub fn any() -> Self {
183        MessagePattern::Any
184    }
185
186    /// Check if the pattern matches the message
187    pub fn matches(&self, text: &str) -> bool {
188        match self {
189            MessagePattern::Regex(re) => re.is_match(text),
190            MessagePattern::JsonPath(query) => {
191                // Try to parse as JSON and check if path exists
192                if let Ok(json) = serde_json::from_str::<Value>(text) {
193                    // Use jsonpath crate's Selector
194                    if let Ok(selector) = jsonpath::Selector::new(query) {
195                        let results: Vec<_> = selector.find(&json).collect();
196                        !results.is_empty()
197                    } else {
198                        false
199                    }
200                } else {
201                    false
202                }
203            }
204            MessagePattern::Exact(expected) => text == expected,
205            MessagePattern::Any => true,
206        }
207    }
208
209    /// Check if pattern matches and extract value using JSONPath
210    pub fn extract(&self, text: &str, query: &str) -> Option<Value> {
211        if let Ok(json) = serde_json::from_str::<Value>(text) {
212            if let Ok(selector) = jsonpath::Selector::new(query) {
213                let results: Vec<_> = selector.find(&json).collect();
214                results.first().cloned().cloned()
215            } else {
216                None
217            }
218        } else {
219            None
220        }
221    }
222}
223
/// Identifier of a single WebSocket connection, used as the key when tracking
/// room membership.
pub type ConnectionId = std::string::String;
226
227/// Room manager for broadcasting messages to groups of connections
228#[derive(Clone)]
229pub struct RoomManager {
230    rooms: Arc<RwLock<HashMap<String, HashSet<ConnectionId>>>>,
231    connections: Arc<RwLock<HashMap<ConnectionId, HashSet<String>>>>,
232    broadcasters: Arc<RwLock<HashMap<String, broadcast::Sender<String>>>>,
233}
234
235impl RoomManager {
236    /// Create a new room manager
237    pub fn new() -> Self {
238        Self {
239            rooms: Arc::new(RwLock::new(HashMap::new())),
240            connections: Arc::new(RwLock::new(HashMap::new())),
241            broadcasters: Arc::new(RwLock::new(HashMap::new())),
242        }
243    }
244
245    /// Join a room
246    pub async fn join(&self, conn_id: &str, room: &str) -> HandlerResult<()> {
247        let mut rooms = self.rooms.write().await;
248        let mut connections = self.connections.write().await;
249
250        rooms
251            .entry(room.to_string())
252            .or_insert_with(HashSet::new)
253            .insert(conn_id.to_string());
254
255        connections
256            .entry(conn_id.to_string())
257            .or_insert_with(HashSet::new)
258            .insert(room.to_string());
259
260        Ok(())
261    }
262
263    /// Leave a room
264    pub async fn leave(&self, conn_id: &str, room: &str) -> HandlerResult<()> {
265        let mut rooms = self.rooms.write().await;
266        let mut connections = self.connections.write().await;
267
268        if let Some(room_members) = rooms.get_mut(room) {
269            room_members.remove(conn_id);
270            if room_members.is_empty() {
271                rooms.remove(room);
272            }
273        }
274
275        if let Some(conn_rooms) = connections.get_mut(conn_id) {
276            conn_rooms.remove(room);
277            if conn_rooms.is_empty() {
278                connections.remove(conn_id);
279            }
280        }
281
282        Ok(())
283    }
284
285    /// Leave all rooms for a connection
286    pub async fn leave_all(&self, conn_id: &str) -> HandlerResult<()> {
287        let mut connections = self.connections.write().await;
288        if let Some(conn_rooms) = connections.remove(conn_id) {
289            let mut rooms = self.rooms.write().await;
290            for room in conn_rooms {
291                if let Some(room_members) = rooms.get_mut(&room) {
292                    room_members.remove(conn_id);
293                    if room_members.is_empty() {
294                        rooms.remove(&room);
295                    }
296                }
297            }
298        }
299        Ok(())
300    }
301
302    /// Get all connections in a room
303    pub async fn get_room_members(&self, room: &str) -> Vec<ConnectionId> {
304        let rooms = self.rooms.read().await;
305        rooms
306            .get(room)
307            .map(|members| members.iter().cloned().collect())
308            .unwrap_or_default()
309    }
310
311    /// Get all rooms for a connection
312    pub async fn get_connection_rooms(&self, conn_id: &str) -> Vec<String> {
313        let connections = self.connections.read().await;
314        connections
315            .get(conn_id)
316            .map(|rooms| rooms.iter().cloned().collect())
317            .unwrap_or_default()
318    }
319
320    /// Get broadcast sender for a room (creates if doesn't exist)
321    pub async fn get_broadcaster(&self, room: &str) -> broadcast::Sender<String> {
322        let mut broadcasters = self.broadcasters.write().await;
323        broadcasters
324            .entry(room.to_string())
325            .or_insert_with(|| {
326                let (tx, _) = broadcast::channel(1024);
327                tx
328            })
329            .clone()
330    }
331}
332
333impl Default for RoomManager {
334    fn default() -> Self {
335        Self::new()
336    }
337}
338
339/// Context provided to handlers for each connection
340pub struct WsContext {
341    /// Unique connection ID
342    pub connection_id: ConnectionId,
343    /// WebSocket path
344    pub path: String,
345    /// Room manager for broadcasting
346    room_manager: RoomManager,
347    /// Sender for outgoing messages
348    message_tx: tokio::sync::mpsc::UnboundedSender<Message>,
349    /// Metadata storage
350    metadata: Arc<RwLock<HashMap<String, Value>>>,
351}
352
353impl WsContext {
354    /// Create a new WebSocket context
355    pub fn new(
356        connection_id: ConnectionId,
357        path: String,
358        room_manager: RoomManager,
359        message_tx: tokio::sync::mpsc::UnboundedSender<Message>,
360    ) -> Self {
361        Self {
362            connection_id,
363            path,
364            room_manager,
365            message_tx,
366            metadata: Arc::new(RwLock::new(HashMap::new())),
367        }
368    }
369
370    /// Send a text message
371    pub async fn send_text(&self, text: &str) -> HandlerResult<()> {
372        self.message_tx
373            .send(Message::Text(text.to_string().into()))
374            .map_err(|e| HandlerError::SendError(e.to_string()))
375    }
376
377    /// Send a binary message
378    pub async fn send_binary(&self, data: Vec<u8>) -> HandlerResult<()> {
379        self.message_tx
380            .send(Message::Binary(data.into()))
381            .map_err(|e| HandlerError::SendError(e.to_string()))
382    }
383
384    /// Send a JSON message
385    pub async fn send_json(&self, value: &Value) -> HandlerResult<()> {
386        let text = serde_json::to_string(value)?;
387        self.send_text(&text).await
388    }
389
390    /// Join a room
391    pub async fn join_room(&self, room: &str) -> HandlerResult<()> {
392        self.room_manager.join(&self.connection_id, room).await
393    }
394
395    /// Leave a room
396    pub async fn leave_room(&self, room: &str) -> HandlerResult<()> {
397        self.room_manager.leave(&self.connection_id, room).await
398    }
399
400    /// Broadcast text to all members in a room
401    pub async fn broadcast_to_room(&self, room: &str, text: &str) -> HandlerResult<()> {
402        let broadcaster = self.room_manager.get_broadcaster(room).await;
403        broadcaster
404            .send(text.to_string())
405            .map_err(|e| HandlerError::RoomError(e.to_string()))?;
406        Ok(())
407    }
408
409    /// Get all rooms this connection is in
410    pub async fn get_rooms(&self) -> Vec<String> {
411        self.room_manager.get_connection_rooms(&self.connection_id).await
412    }
413
414    /// Set metadata value
415    pub async fn set_metadata(&self, key: &str, value: Value) {
416        let mut metadata = self.metadata.write().await;
417        metadata.insert(key.to_string(), value);
418    }
419
420    /// Get metadata value
421    pub async fn get_metadata(&self, key: &str) -> Option<Value> {
422        let metadata = self.metadata.read().await;
423        metadata.get(key).cloned()
424    }
425}
426
427/// Trait for WebSocket message handlers
428#[async_trait]
429pub trait WsHandler: Send + Sync {
430    /// Called when a new WebSocket connection is established
431    async fn on_connect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
432        Ok(())
433    }
434
435    /// Called when a message is received
436    async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()>;
437
438    /// Called when the connection is closed
439    async fn on_disconnect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
440        Ok(())
441    }
442
443    /// Check if this handler should handle the given path
444    fn handles_path(&self, _path: &str) -> bool {
445        true // Default: handle all paths
446    }
447}
448
449/// Pattern-based message router
450pub struct MessageRouter {
451    routes: Vec<(MessagePattern, Box<dyn Fn(String) -> Option<String> + Send + Sync>)>,
452}
453
454impl MessageRouter {
455    /// Create a new message router
456    pub fn new() -> Self {
457        Self { routes: Vec::new() }
458    }
459
460    /// Add a route with a pattern and handler function
461    pub fn on<F>(&mut self, pattern: MessagePattern, handler: F) -> &mut Self
462    where
463        F: Fn(String) -> Option<String> + Send + Sync + 'static,
464    {
465        self.routes.push((pattern, Box::new(handler)));
466        self
467    }
468
469    /// Route a message through the registered handlers
470    pub fn route(&self, text: &str) -> Option<String> {
471        for (pattern, handler) in &self.routes {
472            if pattern.matches(text) {
473                if let Some(response) = handler(text.to_string()) {
474                    return Some(response);
475                }
476            }
477        }
478        None
479    }
480}
481
482impl Default for MessageRouter {
483    fn default() -> Self {
484        Self::new()
485    }
486}
487
488/// Handler registry for managing multiple handlers
489pub struct HandlerRegistry {
490    handlers: Vec<Arc<dyn WsHandler>>,
491    hot_reload_enabled: bool,
492}
493
494impl HandlerRegistry {
495    /// Create a new handler registry
496    pub fn new() -> Self {
497        Self {
498            handlers: Vec::new(),
499            hot_reload_enabled: std::env::var("MOCKFORGE_WS_HOTRELOAD")
500                .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
501                .unwrap_or(false),
502        }
503    }
504
505    /// Create a registry with hot-reload enabled
506    pub fn with_hot_reload() -> Self {
507        Self {
508            handlers: Vec::new(),
509            hot_reload_enabled: true,
510        }
511    }
512
513    /// Check if hot-reload is enabled
514    pub fn is_hot_reload_enabled(&self) -> bool {
515        self.hot_reload_enabled
516    }
517
518    /// Register a handler
519    pub fn register<H: WsHandler + 'static>(&mut self, handler: H) -> &mut Self {
520        self.handlers.push(Arc::new(handler));
521        self
522    }
523
524    /// Get handlers for a specific path
525    pub fn get_handlers(&self, path: &str) -> Vec<Arc<dyn WsHandler>> {
526        self.handlers.iter().filter(|h| h.handles_path(path)).cloned().collect()
527    }
528
529    /// Check if any handler handles the given path
530    pub fn has_handler_for(&self, path: &str) -> bool {
531        self.handlers.iter().any(|h| h.handles_path(path))
532    }
533
534    /// Clear all handlers (useful for hot-reload)
535    pub fn clear(&mut self) {
536        self.handlers.clear();
537    }
538
539    /// Get the number of registered handlers
540    pub fn len(&self) -> usize {
541        self.handlers.len()
542    }
543
544    /// Check if the registry is empty
545    pub fn is_empty(&self) -> bool {
546        self.handlers.is_empty()
547    }
548}
549
550impl Default for HandlerRegistry {
551    fn default() -> Self {
552        Self::new()
553    }
554}
555
556/// Passthrough handler configuration for forwarding messages to upstream servers
557#[derive(Clone)]
558pub struct PassthroughConfig {
559    /// Pattern to match paths for passthrough
560    pub pattern: MessagePattern,
561    /// Upstream URL to forward to
562    pub upstream_url: String,
563}
564
565impl PassthroughConfig {
566    /// Create a new passthrough configuration
567    pub fn new(pattern: MessagePattern, upstream_url: String) -> Self {
568        Self {
569            pattern,
570            upstream_url,
571        }
572    }
573
574    /// Create a passthrough for all messages matching a regex
575    pub fn regex(regex: &str, upstream_url: String) -> HandlerResult<Self> {
576        Ok(Self {
577            pattern: MessagePattern::regex(regex)?,
578            upstream_url,
579        })
580    }
581}
582
583/// Passthrough handler that forwards messages to an upstream server
584pub struct PassthroughHandler {
585    config: PassthroughConfig,
586}
587
588impl PassthroughHandler {
589    /// Create a new passthrough handler
590    pub fn new(config: PassthroughConfig) -> Self {
591        Self { config }
592    }
593
594    /// Check if a message should be passed through
595    pub fn should_passthrough(&self, text: &str) -> bool {
596        self.config.pattern.matches(text)
597    }
598
599    /// Get the upstream URL
600    pub fn upstream_url(&self) -> &str {
601        &self.config.upstream_url
602    }
603}
604
605#[async_trait]
606impl WsHandler for PassthroughHandler {
607    async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()> {
608        if let WsMessage::Text(text) = &msg {
609            if self.should_passthrough(text) {
610                // In a real implementation, this would forward to upstream
611                // For now, we'll just log and echo back
612                ctx.send_text(&format!("PASSTHROUGH({}): {}", self.config.upstream_url, text))
613                    .await?;
614                return Ok(());
615            }
616        }
617        Ok(())
618    }
619}
620
621#[cfg(test)]
622mod tests {
623    use super::*;
624
625    // ==================== WsMessage Tests ====================
626
627    #[test]
628    fn test_ws_message_text_from_axum() {
629        let axum_msg = Message::Text("hello".to_string().into());
630        let ws_msg: WsMessage = axum_msg.into();
631        match ws_msg {
632            WsMessage::Text(text) => assert_eq!(text, "hello"),
633            _ => panic!("Expected Text message"),
634        }
635    }
636
637    #[test]
638    fn test_ws_message_binary_from_axum() {
639        let data = vec![1, 2, 3, 4];
640        let axum_msg = Message::Binary(data.clone().into());
641        let ws_msg: WsMessage = axum_msg.into();
642        match ws_msg {
643            WsMessage::Binary(bytes) => assert_eq!(bytes, data),
644            _ => panic!("Expected Binary message"),
645        }
646    }
647
648    #[test]
649    fn test_ws_message_ping_from_axum() {
650        let data = vec![1, 2];
651        let axum_msg = Message::Ping(data.clone().into());
652        let ws_msg: WsMessage = axum_msg.into();
653        match ws_msg {
654            WsMessage::Ping(bytes) => assert_eq!(bytes, data),
655            _ => panic!("Expected Ping message"),
656        }
657    }
658
659    #[test]
660    fn test_ws_message_pong_from_axum() {
661        let data = vec![3, 4];
662        let axum_msg = Message::Pong(data.clone().into());
663        let ws_msg: WsMessage = axum_msg.into();
664        match ws_msg {
665            WsMessage::Pong(bytes) => assert_eq!(bytes, data),
666            _ => panic!("Expected Pong message"),
667        }
668    }
669
670    #[test]
671    fn test_ws_message_close_from_axum() {
672        let axum_msg = Message::Close(None);
673        let ws_msg: WsMessage = axum_msg.into();
674        assert!(matches!(ws_msg, WsMessage::Close));
675    }
676
677    #[test]
678    fn test_ws_message_text_to_axum() {
679        let ws_msg = WsMessage::Text("hello".to_string());
680        let axum_msg: Message = ws_msg.into();
681        assert!(matches!(axum_msg, Message::Text(_)));
682    }
683
684    #[test]
685    fn test_ws_message_binary_to_axum() {
686        let ws_msg = WsMessage::Binary(vec![1, 2, 3]);
687        let axum_msg: Message = ws_msg.into();
688        assert!(matches!(axum_msg, Message::Binary(_)));
689    }
690
691    #[test]
692    fn test_ws_message_close_to_axum() {
693        let ws_msg = WsMessage::Close;
694        let axum_msg: Message = ws_msg.into();
695        assert!(matches!(axum_msg, Message::Close(_)));
696    }
697
698    // ==================== MessagePattern Tests ====================
699
700    #[test]
701    fn test_message_pattern_regex() {
702        let pattern = MessagePattern::regex(r"^hello").unwrap();
703        assert!(pattern.matches("hello world"));
704        assert!(!pattern.matches("goodbye world"));
705    }
706
707    #[test]
708    fn test_message_pattern_regex_invalid() {
709        let result = MessagePattern::regex(r"[invalid");
710        assert!(result.is_err());
711    }
712
713    #[test]
714    fn test_message_pattern_exact() {
715        let pattern = MessagePattern::exact("hello");
716        assert!(pattern.matches("hello"));
717        assert!(!pattern.matches("hello world"));
718    }
719
720    #[test]
721    fn test_message_pattern_jsonpath() {
722        let pattern = MessagePattern::jsonpath("$.type");
723        assert!(pattern.matches(r#"{"type": "message"}"#));
724        assert!(!pattern.matches(r#"{"name": "test"}"#));
725    }
726
727    #[test]
728    fn test_message_pattern_jsonpath_nested() {
729        let pattern = MessagePattern::jsonpath("$.user.name");
730        assert!(pattern.matches(r#"{"user": {"name": "John"}}"#));
731        assert!(!pattern.matches(r#"{"user": {"email": "john@example.com"}}"#));
732    }
733
734    #[test]
735    fn test_message_pattern_jsonpath_invalid_json() {
736        let pattern = MessagePattern::jsonpath("$.type");
737        assert!(!pattern.matches("not json"));
738    }
739
740    #[test]
741    fn test_message_pattern_any() {
742        let pattern = MessagePattern::any();
743        assert!(pattern.matches("anything"));
744        assert!(pattern.matches(""));
745        assert!(pattern.matches(r#"{"json": true}"#));
746    }
747
748    #[test]
749    fn test_message_pattern_extract() {
750        let pattern = MessagePattern::jsonpath("$.type");
751        let result = pattern.extract(r#"{"type": "greeting", "data": "hello"}"#, "$.type");
752        assert_eq!(result, Some(serde_json::json!("greeting")));
753    }
754
755    #[test]
756    fn test_message_pattern_extract_nested() {
757        let pattern = MessagePattern::any();
758        let result = pattern.extract(r#"{"user": {"id": 123}}"#, "$.user.id");
759        assert_eq!(result, Some(serde_json::json!(123)));
760    }
761
762    #[test]
763    fn test_message_pattern_extract_not_found() {
764        let pattern = MessagePattern::any();
765        let result = pattern.extract(r#"{"type": "message"}"#, "$.nonexistent");
766        assert!(result.is_none());
767    }
768
769    #[test]
770    fn test_message_pattern_extract_invalid_json() {
771        let pattern = MessagePattern::any();
772        let result = pattern.extract("not json", "$.type");
773        assert!(result.is_none());
774    }
775
776    // ==================== RoomManager Tests ====================
777
778    #[tokio::test]
779    async fn test_room_manager() {
780        let manager = RoomManager::new();
781
782        // Join rooms
783        manager.join("conn1", "room1").await.unwrap();
784        manager.join("conn1", "room2").await.unwrap();
785        manager.join("conn2", "room1").await.unwrap();
786
787        // Check room members
788        let room1_members = manager.get_room_members("room1").await;
789        assert_eq!(room1_members.len(), 2);
790        assert!(room1_members.contains(&"conn1".to_string()));
791        assert!(room1_members.contains(&"conn2".to_string()));
792
793        // Check connection rooms
794        let conn1_rooms = manager.get_connection_rooms("conn1").await;
795        assert_eq!(conn1_rooms.len(), 2);
796        assert!(conn1_rooms.contains(&"room1".to_string()));
797        assert!(conn1_rooms.contains(&"room2".to_string()));
798
799        // Leave room
800        manager.leave("conn1", "room1").await.unwrap();
801        let room1_members = manager.get_room_members("room1").await;
802        assert_eq!(room1_members.len(), 1);
803        assert!(room1_members.contains(&"conn2".to_string()));
804
805        // Leave all rooms
806        manager.leave_all("conn1").await.unwrap();
807        let conn1_rooms = manager.get_connection_rooms("conn1").await;
808        assert_eq!(conn1_rooms.len(), 0);
809    }
810
811    #[tokio::test]
812    async fn test_room_manager_default() {
813        let manager = RoomManager::default();
814        // Should work the same as new()
815        manager.join("conn1", "room1").await.unwrap();
816        let members = manager.get_room_members("room1").await;
817        assert_eq!(members.len(), 1);
818    }
819
820    #[tokio::test]
821    async fn test_room_manager_empty_room() {
822        let manager = RoomManager::new();
823        let members = manager.get_room_members("nonexistent").await;
824        assert!(members.is_empty());
825    }
826
827    #[tokio::test]
828    async fn test_room_manager_empty_connection() {
829        let manager = RoomManager::new();
830        let rooms = manager.get_connection_rooms("nonexistent").await;
831        assert!(rooms.is_empty());
832    }
833
834    #[tokio::test]
835    async fn test_room_manager_leave_nonexistent() {
836        let manager = RoomManager::new();
837        // Should not error when leaving a room we're not in
838        let result = manager.leave("conn1", "room1").await;
839        assert!(result.is_ok());
840    }
841
842    #[tokio::test]
843    async fn test_room_manager_broadcaster() {
844        let manager = RoomManager::new();
845        manager.join("conn1", "room1").await.unwrap();
846
847        let broadcaster = manager.get_broadcaster("room1").await;
848        let mut receiver = broadcaster.subscribe();
849
850        // Send a message
851        broadcaster.send("hello".to_string()).unwrap();
852
853        // Receive it
854        let msg = receiver.recv().await.unwrap();
855        assert_eq!(msg, "hello");
856    }
857
858    #[tokio::test]
859    async fn test_room_manager_room_cleanup_on_last_leave() {
860        let manager = RoomManager::new();
861        manager.join("conn1", "room1").await.unwrap();
862        manager.leave("conn1", "room1").await.unwrap();
863
864        // Room should be cleaned up
865        let members = manager.get_room_members("room1").await;
866        assert!(members.is_empty());
867    }
868
869    // ==================== MessageRouter Tests ====================
870
871    #[test]
872    fn test_message_router() {
873        let mut router = MessageRouter::new();
874
875        router
876            .on(MessagePattern::exact("ping"), |_| Some("pong".to_string()))
877            .on(MessagePattern::regex(r"^hello").unwrap(), |_| Some("hi there!".to_string()));
878
879        assert_eq!(router.route("ping"), Some("pong".to_string()));
880        assert_eq!(router.route("hello world"), Some("hi there!".to_string()));
881        assert_eq!(router.route("goodbye"), None);
882    }
883
884    #[test]
885    fn test_message_router_default() {
886        let router = MessageRouter::default();
887        // Empty router returns None for all messages
888        assert_eq!(router.route("anything"), None);
889    }
890
891    #[test]
892    fn test_message_router_first_match_wins() {
893        let mut router = MessageRouter::new();
894        router
895            .on(MessagePattern::any(), |_| Some("first".to_string()))
896            .on(MessagePattern::any(), |_| Some("second".to_string()));
897
898        assert_eq!(router.route("test"), Some("first".to_string()));
899    }
900
901    #[test]
902    fn test_message_router_handler_returns_none() {
903        let mut router = MessageRouter::new();
904        router
905            .on(MessagePattern::exact("skip"), |_| None)
906            .on(MessagePattern::any(), |_| Some("fallback".to_string()));
907
908        // First pattern matches but returns None, so it continues to next
909        assert_eq!(router.route("skip"), Some("fallback".to_string()));
910    }
911
912    // ==================== HandlerRegistry Tests ====================
913
914    struct TestHandler;
915
916    #[async_trait]
917    impl WsHandler for TestHandler {
918        async fn on_message(&self, _ctx: &mut WsContext, _msg: WsMessage) -> HandlerResult<()> {
919            Ok(())
920        }
921    }
922
923    struct PathSpecificHandler {
924        path: String,
925    }
926
927    #[async_trait]
928    impl WsHandler for PathSpecificHandler {
929        async fn on_message(&self, _ctx: &mut WsContext, _msg: WsMessage) -> HandlerResult<()> {
930            Ok(())
931        }
932
933        fn handles_path(&self, path: &str) -> bool {
934            path == self.path
935        }
936    }
937
938    #[test]
939    fn test_handler_registry_new() {
940        let registry = HandlerRegistry::new();
941        assert!(registry.is_empty());
942        assert_eq!(registry.len(), 0);
943    }
944
945    #[test]
946    fn test_handler_registry_default() {
947        let registry = HandlerRegistry::default();
948        assert!(registry.is_empty());
949    }
950
951    #[test]
952    fn test_handler_registry_register() {
953        let mut registry = HandlerRegistry::new();
954        registry.register(TestHandler);
955        assert!(!registry.is_empty());
956        assert_eq!(registry.len(), 1);
957    }
958
959    #[test]
960    fn test_handler_registry_get_handlers() {
961        let mut registry = HandlerRegistry::new();
962        registry.register(TestHandler);
963
964        let handlers = registry.get_handlers("/any/path");
965        assert_eq!(handlers.len(), 1);
966    }
967
968    #[test]
969    fn test_handler_registry_path_filtering() {
970        let mut registry = HandlerRegistry::new();
971        registry.register(PathSpecificHandler {
972            path: "/ws/chat".to_string(),
973        });
974        registry.register(PathSpecificHandler {
975            path: "/ws/events".to_string(),
976        });
977
978        let chat_handlers = registry.get_handlers("/ws/chat");
979        assert_eq!(chat_handlers.len(), 1);
980
981        let events_handlers = registry.get_handlers("/ws/events");
982        assert_eq!(events_handlers.len(), 1);
983
984        let other_handlers = registry.get_handlers("/ws/other");
985        assert!(other_handlers.is_empty());
986    }
987
988    #[test]
989    fn test_handler_registry_has_handler_for() {
990        let mut registry = HandlerRegistry::new();
991        registry.register(PathSpecificHandler {
992            path: "/ws/chat".to_string(),
993        });
994
995        assert!(registry.has_handler_for("/ws/chat"));
996        assert!(!registry.has_handler_for("/ws/other"));
997    }
998
999    #[test]
1000    fn test_handler_registry_clear() {
1001        let mut registry = HandlerRegistry::new();
1002        registry.register(TestHandler);
1003        registry.register(TestHandler);
1004        assert_eq!(registry.len(), 2);
1005
1006        registry.clear();
1007        assert!(registry.is_empty());
1008    }
1009
1010    #[test]
1011    fn test_handler_registry_with_hot_reload() {
1012        let registry = HandlerRegistry::with_hot_reload();
1013        assert!(registry.is_hot_reload_enabled());
1014    }
1015
1016    // ==================== PassthroughConfig Tests ====================
1017
1018    #[test]
1019    fn test_passthrough_config_new() {
1020        let config =
1021            PassthroughConfig::new(MessagePattern::any(), "ws://upstream:8080".to_string());
1022        assert_eq!(config.upstream_url, "ws://upstream:8080");
1023    }
1024
1025    #[test]
1026    fn test_passthrough_config_regex() {
1027        let config =
1028            PassthroughConfig::regex(r"^forward", "ws://upstream:8080".to_string()).unwrap();
1029        assert!(config.pattern.matches("forward this"));
1030        assert!(!config.pattern.matches("don't forward"));
1031    }
1032
1033    #[test]
1034    fn test_passthrough_config_regex_invalid() {
1035        let result = PassthroughConfig::regex(r"[invalid", "ws://upstream:8080".to_string());
1036        assert!(result.is_err());
1037    }
1038
1039    // ==================== PassthroughHandler Tests ====================
1040
1041    #[test]
1042    fn test_passthrough_handler_should_passthrough() {
1043        let config =
1044            PassthroughConfig::regex(r"^proxy:", "ws://upstream:8080".to_string()).unwrap();
1045        let handler = PassthroughHandler::new(config);
1046
1047        assert!(handler.should_passthrough("proxy:hello"));
1048        assert!(!handler.should_passthrough("hello"));
1049    }
1050
1051    #[test]
1052    fn test_passthrough_handler_upstream_url() {
1053        let config =
1054            PassthroughConfig::new(MessagePattern::any(), "ws://upstream:8080".to_string());
1055        let handler = PassthroughHandler::new(config);
1056
1057        assert_eq!(handler.upstream_url(), "ws://upstream:8080");
1058    }
1059
1060    // ==================== HandlerError Tests ====================
1061
1062    #[test]
1063    fn test_handler_error_send_error() {
1064        let err = HandlerError::SendError("connection closed".to_string());
1065        assert!(err.to_string().contains("send message"));
1066        assert!(err.to_string().contains("connection closed"));
1067    }
1068
1069    #[test]
1070    fn test_handler_error_json_error() {
1071        let json_err = serde_json::from_str::<serde_json::Value>("invalid").unwrap_err();
1072        let err = HandlerError::JsonError(json_err);
1073        assert!(err.to_string().contains("JSON"));
1074    }
1075
1076    #[test]
1077    fn test_handler_error_pattern_error() {
1078        let err = HandlerError::PatternError("invalid regex".to_string());
1079        assert!(err.to_string().contains("Pattern"));
1080    }
1081
1082    #[test]
1083    fn test_handler_error_room_error() {
1084        let err = HandlerError::RoomError("room full".to_string());
1085        assert!(err.to_string().contains("Room"));
1086    }
1087
1088    #[test]
1089    fn test_handler_error_connection_error() {
1090        let err = HandlerError::ConnectionError("timeout".to_string());
1091        assert!(err.to_string().contains("Connection"));
1092    }
1093
1094    #[test]
1095    fn test_handler_error_generic() {
1096        let err = HandlerError::Generic("something went wrong".to_string());
1097        assert!(err.to_string().contains("something went wrong"));
1098    }
1099
1100    // ==================== WsContext Tests ====================
1101
1102    #[tokio::test]
1103    async fn test_ws_context_metadata() {
1104        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
1105        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);
1106
1107        // Set and get metadata
1108        ctx.set_metadata("user", serde_json::json!({"id": 1})).await;
1109        let value = ctx.get_metadata("user").await;
1110        assert_eq!(value, Some(serde_json::json!({"id": 1})));
1111
1112        // Get nonexistent key
1113        let missing = ctx.get_metadata("nonexistent").await;
1114        assert!(missing.is_none());
1115    }
1116
1117    #[tokio::test]
1118    async fn test_ws_context_send_text() {
1119        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1120        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);
1121
1122        ctx.send_text("hello").await.unwrap();
1123
1124        let msg = rx.recv().await.unwrap();
1125        assert!(matches!(msg, Message::Text(_)));
1126    }
1127
1128    #[tokio::test]
1129    async fn test_ws_context_send_binary() {
1130        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1131        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);
1132
1133        ctx.send_binary(vec![1, 2, 3]).await.unwrap();
1134
1135        let msg = rx.recv().await.unwrap();
1136        assert!(matches!(msg, Message::Binary(_)));
1137    }
1138
1139    #[tokio::test]
1140    async fn test_ws_context_send_json() {
1141        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1142        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);
1143
1144        ctx.send_json(&serde_json::json!({"type": "test"})).await.unwrap();
1145
1146        let msg = rx.recv().await.unwrap();
1147        assert!(matches!(msg, Message::Text(_)));
1148    }
1149
1150    #[tokio::test]
1151    async fn test_ws_context_rooms() {
1152        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
1153        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);
1154
1155        // Join rooms
1156        ctx.join_room("chat").await.unwrap();
1157        ctx.join_room("notifications").await.unwrap();
1158
1159        let rooms = ctx.get_rooms().await;
1160        assert_eq!(rooms.len(), 2);
1161
1162        // Leave room
1163        ctx.leave_room("chat").await.unwrap();
1164        let rooms = ctx.get_rooms().await;
1165        assert_eq!(rooms.len(), 1);
1166    }
1167}