mockforge_ws/
handlers.rs

1//! # Programmable WebSocket Handlers
2//!
3//! This module provides a flexible handler API for scripting WebSocket event flows.
4//! Unlike static replay, handlers allow you to write custom logic for responding to
5//! WebSocket events, manage rooms, and route messages dynamically.
6//!
7//! ## Features
8//!
9//! - **Connection Lifecycle**: `on_connect` and `on_disconnect` hooks
10//! - **Pattern Matching**: Route messages with regex or JSONPath patterns
11//! - **Room Management**: Broadcast messages to groups of connections
12//! - **Passthrough**: Selectively forward messages to upstream servers
13//! - **Hot Reload**: Automatically reload handlers when code changes (via `MOCKFORGE_WS_HOTRELOAD`)
14//!
15//! ## Quick Start
16//!
17//! ```rust,no_run
18//! use mockforge_ws::handlers::{WsHandler, WsContext, WsMessage, HandlerResult};
19//! use async_trait::async_trait;
20//!
21//! struct EchoHandler;
22//!
23//! #[async_trait]
24//! impl WsHandler for EchoHandler {
25//!     async fn on_connect(&self, ctx: &mut WsContext) -> HandlerResult<()> {
26//!         ctx.send_text("Welcome to the echo server!").await?;
27//!         Ok(())
28//!     }
29//!
30//!     async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()> {
31//!         if let WsMessage::Text(text) = msg {
32//!             ctx.send_text(&format!("echo: {}", text)).await?;
33//!         }
34//!         Ok(())
35//!     }
36//! }
37//! ```
38//!
39//! ## Message Pattern Matching
40//!
41//! ```rust,no_run
42//! use mockforge_ws::handlers::{WsHandler, WsContext, WsMessage, HandlerResult, MessagePattern};
43//! use async_trait::async_trait;
44//!
45//! struct ChatHandler;
46//!
47//! #[async_trait]
48//! impl WsHandler for ChatHandler {
49//!     async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()> {
50//!         if let WsMessage::Text(text) = msg {
51//!             // Use pattern matching to route messages
52//!             if let Ok(pattern) = MessagePattern::regex(r"^/join (.+)$") {
53//!                 if pattern.matches(&text) {
54//!                     // Extract room name and join
55//!                     if let Some(room) = text.strip_prefix("/join ") {
56//!                         ctx.join_room(room).await?;
57//!                         ctx.send_text(&format!("Joined room: {}", room)).await?;
58//!                     }
59//!                 }
60//!             }
61//!             // Handle JSON chat messages
62//!             let jsonpath_pattern = MessagePattern::jsonpath("$.type");
63//!             if jsonpath_pattern.matches(&text) {
64//!                 ctx.broadcast_to_room("general", &text).await?;
65//!             }
66//!         }
67//!         Ok(())
68//!     }
69//! }
70//! ```
71
72use async_trait::async_trait;
73use axum::extract::ws::Message;
74use regex::Regex;
75use serde_json::Value;
76use std::collections::{HashMap, HashSet};
77use std::sync::Arc;
78use tokio::sync::{broadcast, RwLock};
79
80/// Result type for handler operations
81pub type HandlerResult<T> = Result<T, HandlerError>;
82
83/// Error type for handler operations
84#[derive(Debug, thiserror::Error)]
85pub enum HandlerError {
86    #[error("Failed to send message: {0}")]
87    SendError(String),
88
89    #[error("Failed to parse JSON: {0}")]
90    JsonError(#[from] serde_json::Error),
91
92    #[error("Pattern matching error: {0}")]
93    PatternError(String),
94
95    #[error("Room operation failed: {0}")]
96    RoomError(String),
97
98    #[error("Connection error: {0}")]
99    ConnectionError(String),
100
101    #[error("Handler error: {0}")]
102    Generic(String),
103}
104
/// A WebSocket frame in an owned, transport-independent form.
///
/// Converts to and from `axum::extract::ws::Message` via `From`. The close
/// code and reason of a close frame are not represented.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WsMessage {
    /// A UTF-8 text frame.
    Text(String),
    /// A binary frame.
    Binary(Vec<u8>),
    /// A ping frame and its payload.
    Ping(Vec<u8>),
    /// A pong frame and its payload.
    Pong(Vec<u8>),
    /// A close frame.
    Close,
}
114
115impl From<Message> for WsMessage {
116    fn from(msg: Message) -> Self {
117        match msg {
118            Message::Text(text) => WsMessage::Text(text.to_string()),
119            Message::Binary(data) => WsMessage::Binary(data.to_vec()),
120            Message::Ping(data) => WsMessage::Ping(data.to_vec()),
121            Message::Pong(data) => WsMessage::Pong(data.to_vec()),
122            Message::Close(_) => WsMessage::Close,
123        }
124    }
125}
126
127impl From<WsMessage> for Message {
128    fn from(msg: WsMessage) -> Self {
129        match msg {
130            WsMessage::Text(text) => Message::Text(text.into()),
131            WsMessage::Binary(data) => Message::Binary(data.into()),
132            WsMessage::Ping(data) => Message::Ping(data.into()),
133            WsMessage::Pong(data) => Message::Pong(data.into()),
134            WsMessage::Close => Message::Close(None),
135        }
136    }
137}
138
139/// Pattern for matching WebSocket messages
140#[derive(Debug, Clone)]
141pub enum MessagePattern {
142    /// Match using regular expression
143    Regex(Regex),
144    /// Match using JSONPath query
145    JsonPath(String),
146    /// Match exact text
147    Exact(String),
148    /// Always matches
149    Any,
150}
151
152impl MessagePattern {
153    /// Create a regex pattern
154    pub fn regex(pattern: &str) -> HandlerResult<Self> {
155        Ok(MessagePattern::Regex(
156            Regex::new(pattern).map_err(|e| HandlerError::PatternError(e.to_string()))?,
157        ))
158    }
159
160    /// Create a JSONPath pattern
161    pub fn jsonpath(query: &str) -> Self {
162        MessagePattern::JsonPath(query.to_string())
163    }
164
165    /// Create an exact match pattern
166    pub fn exact(text: &str) -> Self {
167        MessagePattern::Exact(text.to_string())
168    }
169
170    /// Create a pattern that matches everything
171    pub fn any() -> Self {
172        MessagePattern::Any
173    }
174
175    /// Check if the pattern matches the message
176    pub fn matches(&self, text: &str) -> bool {
177        match self {
178            MessagePattern::Regex(re) => re.is_match(text),
179            MessagePattern::JsonPath(query) => {
180                // Try to parse as JSON and check if path exists
181                if let Ok(json) = serde_json::from_str::<Value>(text) {
182                    // Use jsonpath crate's Selector
183                    if let Ok(selector) = jsonpath::Selector::new(query) {
184                        let results: Vec<_> = selector.find(&json).collect();
185                        !results.is_empty()
186                    } else {
187                        false
188                    }
189                } else {
190                    false
191                }
192            }
193            MessagePattern::Exact(expected) => text == expected,
194            MessagePattern::Any => true,
195        }
196    }
197
198    /// Check if pattern matches and extract value using JSONPath
199    pub fn extract(&self, text: &str, query: &str) -> Option<Value> {
200        if let Ok(json) = serde_json::from_str::<Value>(text) {
201            if let Ok(selector) = jsonpath::Selector::new(query) {
202                let results: Vec<_> = selector.find(&json).collect();
203                results.first().cloned().cloned()
204            } else {
205                None
206            }
207        } else {
208            None
209        }
210    }
211}
212
/// Identifier of a single WebSocket connection (a plain `String` alias).
pub type ConnectionId = String;
215
216/// Room manager for broadcasting messages to groups of connections
217#[derive(Clone)]
218pub struct RoomManager {
219    rooms: Arc<RwLock<HashMap<String, HashSet<ConnectionId>>>>,
220    connections: Arc<RwLock<HashMap<ConnectionId, HashSet<String>>>>,
221    broadcasters: Arc<RwLock<HashMap<String, broadcast::Sender<String>>>>,
222}
223
224impl RoomManager {
225    /// Create a new room manager
226    pub fn new() -> Self {
227        Self {
228            rooms: Arc::new(RwLock::new(HashMap::new())),
229            connections: Arc::new(RwLock::new(HashMap::new())),
230            broadcasters: Arc::new(RwLock::new(HashMap::new())),
231        }
232    }
233
234    /// Join a room
235    pub async fn join(&self, conn_id: &str, room: &str) -> HandlerResult<()> {
236        let mut rooms = self.rooms.write().await;
237        let mut connections = self.connections.write().await;
238
239        rooms
240            .entry(room.to_string())
241            .or_insert_with(HashSet::new)
242            .insert(conn_id.to_string());
243
244        connections
245            .entry(conn_id.to_string())
246            .or_insert_with(HashSet::new)
247            .insert(room.to_string());
248
249        Ok(())
250    }
251
252    /// Leave a room
253    pub async fn leave(&self, conn_id: &str, room: &str) -> HandlerResult<()> {
254        let mut rooms = self.rooms.write().await;
255        let mut connections = self.connections.write().await;
256
257        if let Some(room_members) = rooms.get_mut(room) {
258            room_members.remove(conn_id);
259            if room_members.is_empty() {
260                rooms.remove(room);
261            }
262        }
263
264        if let Some(conn_rooms) = connections.get_mut(conn_id) {
265            conn_rooms.remove(room);
266            if conn_rooms.is_empty() {
267                connections.remove(conn_id);
268            }
269        }
270
271        Ok(())
272    }
273
274    /// Leave all rooms for a connection
275    pub async fn leave_all(&self, conn_id: &str) -> HandlerResult<()> {
276        let mut connections = self.connections.write().await;
277        if let Some(conn_rooms) = connections.remove(conn_id) {
278            let mut rooms = self.rooms.write().await;
279            for room in conn_rooms {
280                if let Some(room_members) = rooms.get_mut(&room) {
281                    room_members.remove(conn_id);
282                    if room_members.is_empty() {
283                        rooms.remove(&room);
284                    }
285                }
286            }
287        }
288        Ok(())
289    }
290
291    /// Get all connections in a room
292    pub async fn get_room_members(&self, room: &str) -> Vec<ConnectionId> {
293        let rooms = self.rooms.read().await;
294        rooms
295            .get(room)
296            .map(|members| members.iter().cloned().collect())
297            .unwrap_or_default()
298    }
299
300    /// Get all rooms for a connection
301    pub async fn get_connection_rooms(&self, conn_id: &str) -> Vec<String> {
302        let connections = self.connections.read().await;
303        connections
304            .get(conn_id)
305            .map(|rooms| rooms.iter().cloned().collect())
306            .unwrap_or_default()
307    }
308
309    /// Get broadcast sender for a room (creates if doesn't exist)
310    pub async fn get_broadcaster(&self, room: &str) -> broadcast::Sender<String> {
311        let mut broadcasters = self.broadcasters.write().await;
312        broadcasters
313            .entry(room.to_string())
314            .or_insert_with(|| {
315                let (tx, _) = broadcast::channel(1024);
316                tx
317            })
318            .clone()
319    }
320}
321
322impl Default for RoomManager {
323    fn default() -> Self {
324        Self::new()
325    }
326}
327
328/// Context provided to handlers for each connection
329pub struct WsContext {
330    /// Unique connection ID
331    pub connection_id: ConnectionId,
332    /// WebSocket path
333    pub path: String,
334    /// Room manager for broadcasting
335    room_manager: RoomManager,
336    /// Sender for outgoing messages
337    message_tx: tokio::sync::mpsc::UnboundedSender<Message>,
338    /// Metadata storage
339    metadata: Arc<RwLock<HashMap<String, Value>>>,
340}
341
342impl WsContext {
343    /// Create a new WebSocket context
344    pub fn new(
345        connection_id: ConnectionId,
346        path: String,
347        room_manager: RoomManager,
348        message_tx: tokio::sync::mpsc::UnboundedSender<Message>,
349    ) -> Self {
350        Self {
351            connection_id,
352            path,
353            room_manager,
354            message_tx,
355            metadata: Arc::new(RwLock::new(HashMap::new())),
356        }
357    }
358
359    /// Send a text message
360    pub async fn send_text(&self, text: &str) -> HandlerResult<()> {
361        self.message_tx
362            .send(Message::Text(text.to_string().into()))
363            .map_err(|e| HandlerError::SendError(e.to_string()))
364    }
365
366    /// Send a binary message
367    pub async fn send_binary(&self, data: Vec<u8>) -> HandlerResult<()> {
368        self.message_tx
369            .send(Message::Binary(data.into()))
370            .map_err(|e| HandlerError::SendError(e.to_string()))
371    }
372
373    /// Send a JSON message
374    pub async fn send_json(&self, value: &Value) -> HandlerResult<()> {
375        let text = serde_json::to_string(value)?;
376        self.send_text(&text).await
377    }
378
379    /// Join a room
380    pub async fn join_room(&self, room: &str) -> HandlerResult<()> {
381        self.room_manager.join(&self.connection_id, room).await
382    }
383
384    /// Leave a room
385    pub async fn leave_room(&self, room: &str) -> HandlerResult<()> {
386        self.room_manager.leave(&self.connection_id, room).await
387    }
388
389    /// Broadcast text to all members in a room
390    pub async fn broadcast_to_room(&self, room: &str, text: &str) -> HandlerResult<()> {
391        let broadcaster = self.room_manager.get_broadcaster(room).await;
392        broadcaster
393            .send(text.to_string())
394            .map_err(|e| HandlerError::RoomError(e.to_string()))?;
395        Ok(())
396    }
397
398    /// Get all rooms this connection is in
399    pub async fn get_rooms(&self) -> Vec<String> {
400        self.room_manager.get_connection_rooms(&self.connection_id).await
401    }
402
403    /// Set metadata value
404    pub async fn set_metadata(&self, key: &str, value: Value) {
405        let mut metadata = self.metadata.write().await;
406        metadata.insert(key.to_string(), value);
407    }
408
409    /// Get metadata value
410    pub async fn get_metadata(&self, key: &str) -> Option<Value> {
411        let metadata = self.metadata.read().await;
412        metadata.get(key).cloned()
413    }
414}
415
416/// Trait for WebSocket message handlers
417#[async_trait]
418pub trait WsHandler: Send + Sync {
419    /// Called when a new WebSocket connection is established
420    async fn on_connect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
421        Ok(())
422    }
423
424    /// Called when a message is received
425    async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()>;
426
427    /// Called when the connection is closed
428    async fn on_disconnect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
429        Ok(())
430    }
431
432    /// Check if this handler should handle the given path
433    fn handles_path(&self, _path: &str) -> bool {
434        true // Default: handle all paths
435    }
436}
437
438/// Pattern-based message router
439pub struct MessageRouter {
440    routes: Vec<(MessagePattern, Box<dyn Fn(String) -> Option<String> + Send + Sync>)>,
441}
442
443impl MessageRouter {
444    /// Create a new message router
445    pub fn new() -> Self {
446        Self { routes: Vec::new() }
447    }
448
449    /// Add a route with a pattern and handler function
450    pub fn on<F>(&mut self, pattern: MessagePattern, handler: F) -> &mut Self
451    where
452        F: Fn(String) -> Option<String> + Send + Sync + 'static,
453    {
454        self.routes.push((pattern, Box::new(handler)));
455        self
456    }
457
458    /// Route a message through the registered handlers
459    pub fn route(&self, text: &str) -> Option<String> {
460        for (pattern, handler) in &self.routes {
461            if pattern.matches(text) {
462                if let Some(response) = handler(text.to_string()) {
463                    return Some(response);
464                }
465            }
466        }
467        None
468    }
469}
470
471impl Default for MessageRouter {
472    fn default() -> Self {
473        Self::new()
474    }
475}
476
477/// Handler registry for managing multiple handlers
478pub struct HandlerRegistry {
479    handlers: Vec<Arc<dyn WsHandler>>,
480    hot_reload_enabled: bool,
481}
482
483impl HandlerRegistry {
484    /// Create a new handler registry
485    pub fn new() -> Self {
486        Self {
487            handlers: Vec::new(),
488            hot_reload_enabled: std::env::var("MOCKFORGE_WS_HOTRELOAD")
489                .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
490                .unwrap_or(false),
491        }
492    }
493
494    /// Create a registry with hot-reload enabled
495    pub fn with_hot_reload() -> Self {
496        Self {
497            handlers: Vec::new(),
498            hot_reload_enabled: true,
499        }
500    }
501
502    /// Check if hot-reload is enabled
503    pub fn is_hot_reload_enabled(&self) -> bool {
504        self.hot_reload_enabled
505    }
506
507    /// Register a handler
508    pub fn register<H: WsHandler + 'static>(&mut self, handler: H) -> &mut Self {
509        self.handlers.push(Arc::new(handler));
510        self
511    }
512
513    /// Get handlers for a specific path
514    pub fn get_handlers(&self, path: &str) -> Vec<Arc<dyn WsHandler>> {
515        self.handlers.iter().filter(|h| h.handles_path(path)).cloned().collect()
516    }
517
518    /// Check if any handler handles the given path
519    pub fn has_handler_for(&self, path: &str) -> bool {
520        self.handlers.iter().any(|h| h.handles_path(path))
521    }
522
523    /// Clear all handlers (useful for hot-reload)
524    pub fn clear(&mut self) {
525        self.handlers.clear();
526    }
527
528    /// Get the number of registered handlers
529    pub fn len(&self) -> usize {
530        self.handlers.len()
531    }
532
533    /// Check if the registry is empty
534    pub fn is_empty(&self) -> bool {
535        self.handlers.is_empty()
536    }
537}
538
539impl Default for HandlerRegistry {
540    fn default() -> Self {
541        Self::new()
542    }
543}
544
545/// Passthrough handler configuration for forwarding messages to upstream servers
546#[derive(Clone)]
547pub struct PassthroughConfig {
548    /// Pattern to match paths for passthrough
549    pub pattern: MessagePattern,
550    /// Upstream URL to forward to
551    pub upstream_url: String,
552}
553
554impl PassthroughConfig {
555    /// Create a new passthrough configuration
556    pub fn new(pattern: MessagePattern, upstream_url: String) -> Self {
557        Self {
558            pattern,
559            upstream_url,
560        }
561    }
562
563    /// Create a passthrough for all messages matching a regex
564    pub fn regex(regex: &str, upstream_url: String) -> HandlerResult<Self> {
565        Ok(Self {
566            pattern: MessagePattern::regex(regex)?,
567            upstream_url,
568        })
569    }
570}
571
572/// Passthrough handler that forwards messages to an upstream server
573pub struct PassthroughHandler {
574    config: PassthroughConfig,
575}
576
577impl PassthroughHandler {
578    /// Create a new passthrough handler
579    pub fn new(config: PassthroughConfig) -> Self {
580        Self { config }
581    }
582
583    /// Check if a message should be passed through
584    pub fn should_passthrough(&self, text: &str) -> bool {
585        self.config.pattern.matches(text)
586    }
587
588    /// Get the upstream URL
589    pub fn upstream_url(&self) -> &str {
590        &self.config.upstream_url
591    }
592}
593
594#[async_trait]
595impl WsHandler for PassthroughHandler {
596    async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()> {
597        if let WsMessage::Text(text) = &msg {
598            if self.should_passthrough(text) {
599                // In a real implementation, this would forward to upstream
600                // For now, we'll just log and echo back
601                ctx.send_text(&format!("PASSTHROUGH({}): {}", self.config.upstream_url, text))
602                    .await?;
603                return Ok(());
604            }
605        }
606        Ok(())
607    }
608}
609
#[cfg(test)]
mod tests {
    use super::*;

    /// Sort names so assertions do not depend on `HashSet` iteration order.
    fn sorted(mut items: Vec<String>) -> Vec<String> {
        items.sort_unstable();
        items
    }

    #[test]
    fn test_message_pattern_regex() {
        let starts_with_hello = MessagePattern::regex(r"^hello").expect("pattern compiles");
        assert!(starts_with_hello.matches("hello world"));
        assert!(!starts_with_hello.matches("goodbye world"));
    }

    #[test]
    fn test_message_pattern_exact() {
        let exact = MessagePattern::exact("hello");
        assert!(exact.matches("hello"));
        // A prefix match is not enough for an exact pattern.
        assert!(!exact.matches("hello world"));
    }

    #[test]
    fn test_message_pattern_jsonpath() {
        let has_type = MessagePattern::jsonpath("$.type");
        assert!(has_type.matches(r#"{"type": "message"}"#));
        assert!(!has_type.matches(r#"{"name": "test"}"#));
    }

    #[tokio::test]
    async fn test_room_manager() {
        let manager = RoomManager::new();

        for (conn, room) in [("conn1", "room1"), ("conn1", "room2"), ("conn2", "room1")] {
            manager.join(conn, room).await.unwrap();
        }

        assert_eq!(sorted(manager.get_room_members("room1").await), ["conn1", "conn2"]);
        assert_eq!(sorted(manager.get_connection_rooms("conn1").await), ["room1", "room2"]);

        manager.leave("conn1", "room1").await.unwrap();
        assert_eq!(manager.get_room_members("room1").await, ["conn2"]);

        manager.leave_all("conn1").await.unwrap();
        assert!(manager.get_connection_rooms("conn1").await.is_empty());
    }

    #[test]
    fn test_message_router() {
        let mut router = MessageRouter::new();
        router.on(MessagePattern::exact("ping"), |_| Some("pong".to_string()));
        router.on(MessagePattern::regex(r"^hello").unwrap(), |_| Some("hi there!".to_string()));

        let cases = [
            ("ping", Some("pong")),
            ("hello world", Some("hi there!")),
            ("goodbye", None),
        ];
        for (input, expected) in cases {
            assert_eq!(router.route(input).as_deref(), expected, "input: {}", input);
        }
    }
}