mockforge_ws/
handlers.rs

1//! # Programmable WebSocket Handlers
2//!
3//! This module provides a flexible handler API for scripting WebSocket event flows.
4//! Unlike static replay, handlers allow you to write custom logic for responding to
5//! WebSocket events, manage rooms, and route messages dynamically.
6//!
7//! ## Features
8//!
9//! - **Connection Lifecycle**: `on_connect` and `on_disconnect` hooks
10//! - **Pattern Matching**: Route messages with regex or JSONPath patterns
11//! - **Room Management**: Broadcast messages to groups of connections
12//! - **Passthrough**: Selectively forward messages to upstream servers
13//! - **Hot Reload**: Automatically reload handlers when code changes (via `MOCKFORGE_WS_HOTRELOAD`)
14//!
15//! ## Quick Start
16//!
17//! ```rust,no_run
18//! use mockforge_ws::handlers::{WsHandler, WsContext, WsMessage, HandlerResult};
19//! use async_trait::async_trait;
20//!
21//! struct EchoHandler;
22//!
23//! #[async_trait]
24//! impl WsHandler for EchoHandler {
25//!     async fn on_connect(&self, ctx: &mut WsContext) -> HandlerResult<()> {
26//!         ctx.send_text("Welcome to the echo server!").await?;
27//!         Ok(())
28//!     }
29//!
30//!     async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()> {
31//!         if let WsMessage::Text(text) = msg {
32//!             ctx.send_text(&format!("echo: {}", text)).await?;
33//!         }
34//!         Ok(())
35//!     }
36//! }
37//! ```
38//!
39//! ## Message Pattern Matching
40//!
41//! ```rust,no_run
42//! use mockforge_ws::handlers::{WsHandler, WsContext, WsMessage, HandlerResult, MessagePattern};
43//! use async_trait::async_trait;
44//!
45//! struct ChatHandler;
46//!
47//! #[async_trait]
48//! impl WsHandler for ChatHandler {
49//!     async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()> {
50//!         if let WsMessage::Text(text) = msg {
51//!             // Use pattern matching to route messages
52//!             if let Ok(pattern) = MessagePattern::regex(r"^/join (.+)$") {
53//!                 if pattern.matches(&text) {
54//!                     // Extract room name and join
55//!                     if let Some(room) = text.strip_prefix("/join ") {
56//!                         ctx.join_room(room).await?;
57//!                         ctx.send_text(&format!("Joined room: {}", room)).await?;
58//!                     }
59//!                 }
60//!             }
61//!             // Handle JSON chat messages
62//!             let jsonpath_pattern = MessagePattern::jsonpath("$.type");
63//!             if jsonpath_pattern.matches(&text) {
64//!                 ctx.broadcast_to_room("general", &text).await?;
65//!             }
66//!         }
67//!         Ok(())
68//!     }
69//! }
70//! ```
71
72use async_trait::async_trait;
73use axum::extract::ws::Message;
74use regex::Regex;
75use serde_json::Value;
76use std::collections::{HashMap, HashSet};
77use std::sync::Arc;
78use tokio::sync::{broadcast, RwLock};
79
80/// Result type for handler operations
81pub type HandlerResult<T> = Result<T, HandlerError>;
82
83/// Error type for handler operations
84#[derive(Debug, thiserror::Error)]
85pub enum HandlerError {
86    /// Failed to send WebSocket message
87    #[error("Failed to send message: {0}")]
88    SendError(String),
89
90    /// JSON parsing/serialization error
91    #[error("Failed to parse JSON: {0}")]
92    JsonError(#[from] serde_json::Error),
93
94    /// Pattern matching failure (e.g., route pattern)
95    #[error("Pattern matching error: {0}")]
96    PatternError(String),
97
98    /// Room/group operation failure
99    #[error("Room operation failed: {0}")]
100    RoomError(String),
101
102    /// WebSocket connection error
103    #[error("Connection error: {0}")]
104    ConnectionError(String),
105
106    /// Generic handler error
107    #[error("Handler error: {0}")]
108    Generic(String),
109}
110
111/// WebSocket message wrapper for different message types
112#[derive(Debug, Clone)]
113pub enum WsMessage {
114    /// Text message (UTF-8 string)
115    Text(String),
116    /// Binary message (raw bytes)
117    Binary(Vec<u8>),
118    /// Ping frame (connection keepalive)
119    Ping(Vec<u8>),
120    /// Pong frame (response to ping)
121    Pong(Vec<u8>),
122    /// Close frame (connection termination)
123    Close,
124}
125
126impl From<Message> for WsMessage {
127    fn from(msg: Message) -> Self {
128        match msg {
129            Message::Text(text) => WsMessage::Text(text.to_string()),
130            Message::Binary(data) => WsMessage::Binary(data.to_vec()),
131            Message::Ping(data) => WsMessage::Ping(data.to_vec()),
132            Message::Pong(data) => WsMessage::Pong(data.to_vec()),
133            Message::Close(_) => WsMessage::Close,
134        }
135    }
136}
137
138impl From<WsMessage> for Message {
139    fn from(msg: WsMessage) -> Self {
140        match msg {
141            WsMessage::Text(text) => Message::Text(text.into()),
142            WsMessage::Binary(data) => Message::Binary(data.into()),
143            WsMessage::Ping(data) => Message::Ping(data.into()),
144            WsMessage::Pong(data) => Message::Pong(data.into()),
145            WsMessage::Close => Message::Close(None),
146        }
147    }
148}
149
150/// Pattern for matching WebSocket messages
151#[derive(Debug, Clone)]
152pub enum MessagePattern {
153    /// Match using regular expression
154    Regex(Regex),
155    /// Match using JSONPath query
156    JsonPath(String),
157    /// Match exact text
158    Exact(String),
159    /// Always matches
160    Any,
161}
162
163impl MessagePattern {
164    /// Create a regex pattern
165    pub fn regex(pattern: &str) -> HandlerResult<Self> {
166        Ok(MessagePattern::Regex(
167            Regex::new(pattern).map_err(|e| HandlerError::PatternError(e.to_string()))?,
168        ))
169    }
170
171    /// Create a JSONPath pattern
172    pub fn jsonpath(query: &str) -> Self {
173        MessagePattern::JsonPath(query.to_string())
174    }
175
176    /// Create an exact match pattern
177    pub fn exact(text: &str) -> Self {
178        MessagePattern::Exact(text.to_string())
179    }
180
181    /// Create a pattern that matches everything
182    pub fn any() -> Self {
183        MessagePattern::Any
184    }
185
186    /// Check if the pattern matches the message
187    pub fn matches(&self, text: &str) -> bool {
188        match self {
189            MessagePattern::Regex(re) => re.is_match(text),
190            MessagePattern::JsonPath(query) => {
191                // Try to parse as JSON and check if path exists
192                if let Ok(json) = serde_json::from_str::<Value>(text) {
193                    // Use jsonpath crate's Selector
194                    if let Ok(selector) = jsonpath::Selector::new(query) {
195                        let results: Vec<_> = selector.find(&json).collect();
196                        !results.is_empty()
197                    } else {
198                        false
199                    }
200                } else {
201                    false
202                }
203            }
204            MessagePattern::Exact(expected) => text == expected,
205            MessagePattern::Any => true,
206        }
207    }
208
209    /// Check if pattern matches and extract value using JSONPath
210    pub fn extract(&self, text: &str, query: &str) -> Option<Value> {
211        if let Ok(json) = serde_json::from_str::<Value>(text) {
212            if let Ok(selector) = jsonpath::Selector::new(query) {
213                let results: Vec<_> = selector.find(&json).collect();
214                results.first().cloned().cloned()
215            } else {
216                None
217            }
218        } else {
219            None
220        }
221    }
222}
223
224/// Connection ID type
225pub type ConnectionId = String;
226
227/// Room manager for broadcasting messages to groups of connections
228#[derive(Clone)]
229pub struct RoomManager {
230    rooms: Arc<RwLock<HashMap<String, HashSet<ConnectionId>>>>,
231    connections: Arc<RwLock<HashMap<ConnectionId, HashSet<String>>>>,
232    broadcasters: Arc<RwLock<HashMap<String, broadcast::Sender<String>>>>,
233}
234
235impl RoomManager {
236    /// Create a new room manager
237    pub fn new() -> Self {
238        Self {
239            rooms: Arc::new(RwLock::new(HashMap::new())),
240            connections: Arc::new(RwLock::new(HashMap::new())),
241            broadcasters: Arc::new(RwLock::new(HashMap::new())),
242        }
243    }
244
245    /// Join a room
246    pub async fn join(&self, conn_id: &str, room: &str) -> HandlerResult<()> {
247        let mut rooms = self.rooms.write().await;
248        let mut connections = self.connections.write().await;
249
250        rooms
251            .entry(room.to_string())
252            .or_insert_with(HashSet::new)
253            .insert(conn_id.to_string());
254
255        connections
256            .entry(conn_id.to_string())
257            .or_insert_with(HashSet::new)
258            .insert(room.to_string());
259
260        Ok(())
261    }
262
263    /// Leave a room
264    pub async fn leave(&self, conn_id: &str, room: &str) -> HandlerResult<()> {
265        let mut rooms = self.rooms.write().await;
266        let mut connections = self.connections.write().await;
267
268        if let Some(room_members) = rooms.get_mut(room) {
269            room_members.remove(conn_id);
270            if room_members.is_empty() {
271                rooms.remove(room);
272            }
273        }
274
275        if let Some(conn_rooms) = connections.get_mut(conn_id) {
276            conn_rooms.remove(room);
277            if conn_rooms.is_empty() {
278                connections.remove(conn_id);
279            }
280        }
281
282        Ok(())
283    }
284
285    /// Leave all rooms for a connection
286    pub async fn leave_all(&self, conn_id: &str) -> HandlerResult<()> {
287        let mut connections = self.connections.write().await;
288        if let Some(conn_rooms) = connections.remove(conn_id) {
289            let mut rooms = self.rooms.write().await;
290            for room in conn_rooms {
291                if let Some(room_members) = rooms.get_mut(&room) {
292                    room_members.remove(conn_id);
293                    if room_members.is_empty() {
294                        rooms.remove(&room);
295                    }
296                }
297            }
298        }
299        Ok(())
300    }
301
302    /// Get all connections in a room
303    pub async fn get_room_members(&self, room: &str) -> Vec<ConnectionId> {
304        let rooms = self.rooms.read().await;
305        rooms
306            .get(room)
307            .map(|members| members.iter().cloned().collect())
308            .unwrap_or_default()
309    }
310
311    /// Get all rooms for a connection
312    pub async fn get_connection_rooms(&self, conn_id: &str) -> Vec<String> {
313        let connections = self.connections.read().await;
314        connections
315            .get(conn_id)
316            .map(|rooms| rooms.iter().cloned().collect())
317            .unwrap_or_default()
318    }
319
320    /// Get broadcast sender for a room (creates if doesn't exist)
321    pub async fn get_broadcaster(&self, room: &str) -> broadcast::Sender<String> {
322        let mut broadcasters = self.broadcasters.write().await;
323        broadcasters
324            .entry(room.to_string())
325            .or_insert_with(|| {
326                let (tx, _) = broadcast::channel(1024);
327                tx
328            })
329            .clone()
330    }
331}
332
333impl Default for RoomManager {
334    fn default() -> Self {
335        Self::new()
336    }
337}
338
339/// Context provided to handlers for each connection
340pub struct WsContext {
341    /// Unique connection ID
342    pub connection_id: ConnectionId,
343    /// WebSocket path
344    pub path: String,
345    /// Room manager for broadcasting
346    room_manager: RoomManager,
347    /// Sender for outgoing messages
348    message_tx: tokio::sync::mpsc::UnboundedSender<Message>,
349    /// Metadata storage
350    metadata: Arc<RwLock<HashMap<String, Value>>>,
351}
352
353impl WsContext {
354    /// Create a new WebSocket context
355    pub fn new(
356        connection_id: ConnectionId,
357        path: String,
358        room_manager: RoomManager,
359        message_tx: tokio::sync::mpsc::UnboundedSender<Message>,
360    ) -> Self {
361        Self {
362            connection_id,
363            path,
364            room_manager,
365            message_tx,
366            metadata: Arc::new(RwLock::new(HashMap::new())),
367        }
368    }
369
370    /// Send a text message
371    pub async fn send_text(&self, text: &str) -> HandlerResult<()> {
372        self.message_tx
373            .send(Message::Text(text.to_string().into()))
374            .map_err(|e| HandlerError::SendError(e.to_string()))
375    }
376
377    /// Send a binary message
378    pub async fn send_binary(&self, data: Vec<u8>) -> HandlerResult<()> {
379        self.message_tx
380            .send(Message::Binary(data.into()))
381            .map_err(|e| HandlerError::SendError(e.to_string()))
382    }
383
384    /// Send a JSON message
385    pub async fn send_json(&self, value: &Value) -> HandlerResult<()> {
386        let text = serde_json::to_string(value)?;
387        self.send_text(&text).await
388    }
389
390    /// Join a room
391    pub async fn join_room(&self, room: &str) -> HandlerResult<()> {
392        self.room_manager.join(&self.connection_id, room).await
393    }
394
395    /// Leave a room
396    pub async fn leave_room(&self, room: &str) -> HandlerResult<()> {
397        self.room_manager.leave(&self.connection_id, room).await
398    }
399
400    /// Broadcast text to all members in a room
401    pub async fn broadcast_to_room(&self, room: &str, text: &str) -> HandlerResult<()> {
402        let broadcaster = self.room_manager.get_broadcaster(room).await;
403        broadcaster
404            .send(text.to_string())
405            .map_err(|e| HandlerError::RoomError(e.to_string()))?;
406        Ok(())
407    }
408
409    /// Get all rooms this connection is in
410    pub async fn get_rooms(&self) -> Vec<String> {
411        self.room_manager.get_connection_rooms(&self.connection_id).await
412    }
413
414    /// Set metadata value
415    pub async fn set_metadata(&self, key: &str, value: Value) {
416        let mut metadata = self.metadata.write().await;
417        metadata.insert(key.to_string(), value);
418    }
419
420    /// Get metadata value
421    pub async fn get_metadata(&self, key: &str) -> Option<Value> {
422        let metadata = self.metadata.read().await;
423        metadata.get(key).cloned()
424    }
425}
426
427/// Trait for WebSocket message handlers
428#[async_trait]
429pub trait WsHandler: Send + Sync {
430    /// Called when a new WebSocket connection is established
431    async fn on_connect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
432        Ok(())
433    }
434
435    /// Called when a message is received
436    async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()>;
437
438    /// Called when the connection is closed
439    async fn on_disconnect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
440        Ok(())
441    }
442
443    /// Check if this handler should handle the given path
444    fn handles_path(&self, _path: &str) -> bool {
445        true // Default: handle all paths
446    }
447}
448
449/// Pattern-based message router
450pub struct MessageRouter {
451    routes: Vec<(MessagePattern, Box<dyn Fn(String) -> Option<String> + Send + Sync>)>,
452}
453
454impl MessageRouter {
455    /// Create a new message router
456    pub fn new() -> Self {
457        Self { routes: Vec::new() }
458    }
459
460    /// Add a route with a pattern and handler function
461    pub fn on<F>(&mut self, pattern: MessagePattern, handler: F) -> &mut Self
462    where
463        F: Fn(String) -> Option<String> + Send + Sync + 'static,
464    {
465        self.routes.push((pattern, Box::new(handler)));
466        self
467    }
468
469    /// Route a message through the registered handlers
470    pub fn route(&self, text: &str) -> Option<String> {
471        for (pattern, handler) in &self.routes {
472            if pattern.matches(text) {
473                if let Some(response) = handler(text.to_string()) {
474                    return Some(response);
475                }
476            }
477        }
478        None
479    }
480}
481
482impl Default for MessageRouter {
483    fn default() -> Self {
484        Self::new()
485    }
486}
487
488/// Handler registry for managing multiple handlers
489pub struct HandlerRegistry {
490    handlers: Vec<Arc<dyn WsHandler>>,
491    hot_reload_enabled: bool,
492}
493
494impl HandlerRegistry {
495    /// Create a new handler registry
496    pub fn new() -> Self {
497        Self {
498            handlers: Vec::new(),
499            hot_reload_enabled: std::env::var("MOCKFORGE_WS_HOTRELOAD")
500                .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
501                .unwrap_or(false),
502        }
503    }
504
505    /// Create a registry with hot-reload enabled
506    pub fn with_hot_reload() -> Self {
507        Self {
508            handlers: Vec::new(),
509            hot_reload_enabled: true,
510        }
511    }
512
513    /// Check if hot-reload is enabled
514    pub fn is_hot_reload_enabled(&self) -> bool {
515        self.hot_reload_enabled
516    }
517
518    /// Register a handler
519    pub fn register<H: WsHandler + 'static>(&mut self, handler: H) -> &mut Self {
520        self.handlers.push(Arc::new(handler));
521        self
522    }
523
524    /// Get handlers for a specific path
525    pub fn get_handlers(&self, path: &str) -> Vec<Arc<dyn WsHandler>> {
526        self.handlers.iter().filter(|h| h.handles_path(path)).cloned().collect()
527    }
528
529    /// Check if any handler handles the given path
530    pub fn has_handler_for(&self, path: &str) -> bool {
531        self.handlers.iter().any(|h| h.handles_path(path))
532    }
533
534    /// Clear all handlers (useful for hot-reload)
535    pub fn clear(&mut self) {
536        self.handlers.clear();
537    }
538
539    /// Get the number of registered handlers
540    pub fn len(&self) -> usize {
541        self.handlers.len()
542    }
543
544    /// Check if the registry is empty
545    pub fn is_empty(&self) -> bool {
546        self.handlers.is_empty()
547    }
548}
549
550impl Default for HandlerRegistry {
551    fn default() -> Self {
552        Self::new()
553    }
554}
555
556/// Passthrough handler configuration for forwarding messages to upstream servers
557#[derive(Clone)]
558pub struct PassthroughConfig {
559    /// Pattern to match paths for passthrough
560    pub pattern: MessagePattern,
561    /// Upstream URL to forward to
562    pub upstream_url: String,
563}
564
565impl PassthroughConfig {
566    /// Create a new passthrough configuration
567    pub fn new(pattern: MessagePattern, upstream_url: String) -> Self {
568        Self {
569            pattern,
570            upstream_url,
571        }
572    }
573
574    /// Create a passthrough for all messages matching a regex
575    pub fn regex(regex: &str, upstream_url: String) -> HandlerResult<Self> {
576        Ok(Self {
577            pattern: MessagePattern::regex(regex)?,
578            upstream_url,
579        })
580    }
581}
582
583/// Passthrough handler that forwards messages to an upstream server
584pub struct PassthroughHandler {
585    config: PassthroughConfig,
586}
587
588impl PassthroughHandler {
589    /// Create a new passthrough handler
590    pub fn new(config: PassthroughConfig) -> Self {
591        Self { config }
592    }
593
594    /// Check if a message should be passed through
595    pub fn should_passthrough(&self, text: &str) -> bool {
596        self.config.pattern.matches(text)
597    }
598
599    /// Get the upstream URL
600    pub fn upstream_url(&self) -> &str {
601        &self.config.upstream_url
602    }
603}
604
605#[async_trait]
606impl WsHandler for PassthroughHandler {
607    async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()> {
608        if let WsMessage::Text(text) = &msg {
609            if self.should_passthrough(text) {
610                // In a real implementation, this would forward to upstream
611                // For now, we'll just log and echo back
612                ctx.send_text(&format!("PASSTHROUGH({}): {}", self.config.upstream_url, text))
613                    .await?;
614                return Ok(());
615            }
616        }
617        Ok(())
618    }
619}
620
621#[cfg(test)]
622mod tests {
623    use super::*;
624
625    #[test]
626    fn test_message_pattern_regex() {
627        let pattern = MessagePattern::regex(r"^hello").unwrap();
628        assert!(pattern.matches("hello world"));
629        assert!(!pattern.matches("goodbye world"));
630    }
631
632    #[test]
633    fn test_message_pattern_exact() {
634        let pattern = MessagePattern::exact("hello");
635        assert!(pattern.matches("hello"));
636        assert!(!pattern.matches("hello world"));
637    }
638
639    #[test]
640    fn test_message_pattern_jsonpath() {
641        let pattern = MessagePattern::jsonpath("$.type");
642        assert!(pattern.matches(r#"{"type": "message"}"#));
643        assert!(!pattern.matches(r#"{"name": "test"}"#));
644    }
645
646    #[tokio::test]
647    async fn test_room_manager() {
648        let manager = RoomManager::new();
649
650        // Join rooms
651        manager.join("conn1", "room1").await.unwrap();
652        manager.join("conn1", "room2").await.unwrap();
653        manager.join("conn2", "room1").await.unwrap();
654
655        // Check room members
656        let room1_members = manager.get_room_members("room1").await;
657        assert_eq!(room1_members.len(), 2);
658        assert!(room1_members.contains(&"conn1".to_string()));
659        assert!(room1_members.contains(&"conn2".to_string()));
660
661        // Check connection rooms
662        let conn1_rooms = manager.get_connection_rooms("conn1").await;
663        assert_eq!(conn1_rooms.len(), 2);
664        assert!(conn1_rooms.contains(&"room1".to_string()));
665        assert!(conn1_rooms.contains(&"room2".to_string()));
666
667        // Leave room
668        manager.leave("conn1", "room1").await.unwrap();
669        let room1_members = manager.get_room_members("room1").await;
670        assert_eq!(room1_members.len(), 1);
671        assert!(room1_members.contains(&"conn2".to_string()));
672
673        // Leave all rooms
674        manager.leave_all("conn1").await.unwrap();
675        let conn1_rooms = manager.get_connection_rooms("conn1").await;
676        assert_eq!(conn1_rooms.len(), 0);
677    }
678
679    #[test]
680    fn test_message_router() {
681        let mut router = MessageRouter::new();
682
683        router
684            .on(MessagePattern::exact("ping"), |_| Some("pong".to_string()))
685            .on(MessagePattern::regex(r"^hello").unwrap(), |_| Some("hi there!".to_string()));
686
687        assert_eq!(router.route("ping"), Some("pong".to_string()));
688        assert_eq!(router.route("hello world"), Some("hi there!".to_string()));
689        assert_eq!(router.route("goodbye"), None);
690    }
691}